Fix/wan2.1 flash attention#153
Conversation
There was a problem hiding this comment.
Pull request overview
Updates AMD inference Dockerfiles to adjust FlashAttention build/install behavior (notably for Wan2.1) and expands the supported ROCm arch list for Mochi.
Changes:
- Replaces the pinned/parameterized FlashAttention wheel build in the Wan2.1 Dockerfile with a direct
setup.py installfrom an unpinned ROCm/flash-attention clone. - Adds
gfx950to thePYTORCH_ROCM_ARCHlist in the Mochi inference Dockerfile.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| docker/pyt_wan2.1_inference.ubuntu.amd.Dockerfile | Changes FlashAttention installation steps for Wan2.1 image builds. |
| docker/pyt_mochi_inference.ubuntu.amd.Dockerfile | Updates the ROCm architecture list used when building FlashAttention. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| #ARG BUILD_FA="1" | ||
| #ARG FA_BRANCH="v3.0.0.r1-cktile" | ||
| #ARG FA_REPO="https://github.com/ROCm/flash-attention.git" | ||
| #RUN if [ "$BUILD_FA" = "1" ]; then \ | ||
| # cd ${WORKSPACE_DIR} \ | ||
| # && pip uninstall -y flash-attention \ | ||
| # && rm -rf flash-attention \ | ||
| # && git clone ${FA_REPO} \ | ||
| # && cd flash-attention \ | ||
| # && git checkout ${FA_BRANCH} \ | ||
| # && git submodule update --init \ | ||
| # && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \ | ||
| # && pip install dist/*.whl \ | ||
| # && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \ | ||
| # fi | ||
| # install flash attention | ||
| ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" | ||
|
|
||
| RUN git clone https://github.com/ROCm/flash-attention.git &&\ | ||
| cd flash-attention &&\ | ||
| python setup.py install |
| #ARG BUILD_FA="1" | ||
| #ARG FA_BRANCH="v3.0.0.r1-cktile" | ||
| #ARG FA_REPO="https://github.com/ROCm/flash-attention.git" | ||
| #RUN if [ "$BUILD_FA" = "1" ]; then \ | ||
| # cd ${WORKSPACE_DIR} \ | ||
| # && pip uninstall -y flash-attention \ | ||
| # && rm -rf flash-attention \ | ||
| # && git clone ${FA_REPO} \ | ||
| # && cd flash-attention \ | ||
| # && git checkout ${FA_BRANCH} \ | ||
| # && git submodule update --init \ | ||
| # && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \ | ||
| # && pip install dist/*.whl \ | ||
| # && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \ | ||
| # fi | ||
| # install flash attention | ||
| ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" | ||
|
|
||
| RUN git clone https://github.com/ROCm/flash-attention.git &&\ | ||
| cd flash-attention &&\ | ||
| python setup.py install |
| ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" | ||
| ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201 | ||
| ARG PYTORCH_ROCM_ARCH=gfx950;gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201 |
| # install flash attention | ||
| ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" | ||
|
|
||
| RUN git clone https://github.com/ROCm/flash-attention.git &&\ |
There was a problem hiding this comment.
Please use FA_BRANCH & FA_REPO arguments. These are meant to build using build arguments with whatever branch is needed. Already existing branch is the latest tag from Flash-attention, any specific reason to remove it?
There was a problem hiding this comment.
@lcskrishna as per the steps mentioned on SWDEV-564747, thus why the args. has removed, and also please refer the steps mentioned in this repo - https://github.com/Dao-AILab/flash-attention.
1. Replace deprecated apt-key usage by securely adding the ROCm GPG key to /etc/apt/keyrings, configuring the ROCm APT repository with proper signing, and updating package lists. 2. Comment out the fixed Flash Attention branch argument to allow flexibility in selecting or defaulting the repository version during build.
vadseshu
left a comment
There was a problem hiding this comment.
“Disables the fixed Flash Attention branch argument, enabling more flexible version selection during builds and reducing tight coupling to a specific release.”
Motivation
Updated wan2.1 dockerfile with FA steps taken from ROCM FA repo.
Technical Details
Test Plan
Test Result
Submission Checklist