JAX

JAX 是用纯 Python 编写的,但它依赖于 XLA,需要作为 jaxlib 包安装。使用以下指示信息来安装 pip 或 conda 的二进制包。


alt text

  • pip安装
    pip install --upgrade pip

CUDA 12 installation Note: wheels only available on linux. pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


CUDA 11 installation Note: wheels only available on linux. pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

  • 如果 JAX 检测到错误版本的 CUDA 库,则有几件事要检查:
  • 确保未设置,因为可以 覆盖 CUDA 库。LD_LIBRARY_PATH``LD_LIBRARY_PATH 确保安装的 CUDA 库是 JAX 请求的库。 重新运行上面的安装命令应该可以工作。
  • Conda 安装
    conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia