Flash Attention 3 手动编译安装

系统信息

  • 系统镜像:docker pull nvidia/cuda:12.4.1-cudnn-devel-rockylinux8
  • python 版本:3.10
  • pytorch 版本:2.6.0+cu124

源码编译安装

注意:目前 Flash Attention 3 还是 beta 版本。

# 下载最新源码
git clone https://github.com/Dao-AILab/flash-attention.git
cd ./flash-attention
git submodule update --init --recursive

# 安装依赖
yum install -y gcc-toolset-11
pip install ninja
source /opt/rh/gcc-toolset-11/enable

# 编译安装
# 由于 Flash Attention 3 还处于 beta 阶段,每次编译的版本号均一样
# 建议添加编译分支的 commit id 用于区分具体的版本
cd ./hopper
MAX_JOBS=4 \
  NVCC_THREADS=2 \
  FLASH_ATTENTION_FORCE_BUILD=TRUE \
  FLASH_ATTN_LOCAL_VERSION=$(git rev-parse --short HEAD) \
  pip wheel \
  -v \
  --no-build-isolation \
  --no-cache-dir \
  --no-deps \
  .
pip install ./flash_attn_3-*.whl

单测(可选)

注意:所有单测执行下来一般会有几百个测例失败,属于正常情况,只要不影响所使用模型的训练、推理精度,可以忽略未通过的测例。

pip install pytest
export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py

单卡 H800 测试结果如下(大部份是显存不够 torch.OutOfMemoryError):

537 failed, 110502 passed, 89088 skipped in 2544.67s (0:42:24)

注意事项

  1. Hugging Face 的 transformers 已支持使用 flash_attention_3 进行训练推理,但实测相比 flash_attention_2flash_attention_3 在训练过程中 loss 收敛还不太稳定,建议仅用于尝鲜测试。
Comment