JAX
JAX 是用纯 Python 编写的,但它依赖于 XLA,需要作为 jaxlib 包安装。使用以下指示信息来安装 pip 或 conda 的二进制包。
- pip安装
pip install --upgrade pip
CUDA 12 installation Note: wheels only available on linux. pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
CUDA 11 installation Note: wheels only available on linux. pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- 如果 JAX 检测到错误版本的 CUDA 库,则有几件事要检查:
- 确保未设置,因为可以 覆盖 CUDA 库。LD_LIBRARY_PATH``LD_LIBRARY_PATH 确保安装的 CUDA 库是 JAX 请求的库。 重新运行上面的安装命令应该可以工作。
- Conda 安装
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia