[Metal] Add Flash Attention VJP for training #2995
Note: I'm working on https://github.com/mlx-node/mlx-node and trying to port some features from trl.
Implement fused backward pass (VJP) for scaled_dot_product_attention on Metal GPU, enabling efficient training without falling back to unfused attention.

- **dQ Kernel** (steel_attention_vjp_dq.h): Computes query gradients
  - Outer loop over KV blocks, inner accumulation for dQ
  - Uses log2 domain for numerical stability
- **dK/dV Kernel** (steel_attention_vjp_dkv.h): Computes key/value gradients
  - K-row ownership model eliminates atomic operations
  - Each simdgroup owns exclusive K rows to prevent races
- **Vector VJP Kernel** (sdpa_vector_vjp.h):
  - Optimized path for short sequences (L ≤ 8)
  - Uses shared memory for efficient reduction
- Float32 accumulators for half/bfloat16 precision
- Logsumexp caching from forward pass
- Proper GQA (grouped query attention) support
- Causal mask support
- Comprehensive test coverage for all code paths
- No gradient support for mask or attention sinks (falls back to unfused)
- Requires logsumexp from forward pass (training mode only)
- Head dimension D=256 not supported in vector VJP (threadgroup memory)

Co-Authored-By: Claude <[email protected]>
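For reference, the gradients these kernels produce match the standard attention backward pass. Below is a minimal single-head NumPy sketch (illustrative function names, no mask or GQA) that, like the fused kernels, recomputes the softmax probabilities from the cached logsumexp instead of storing them:

```python
import numpy as np

def sdpa_forward(q, k, v, scale):
    # q: (L, D), k and v: (Lk, D); single head, no mask.
    s = scale * (q @ k.T)
    m = s.max(axis=-1, keepdims=True)
    lse = m + np.log(np.exp(s - m).sum(axis=-1, keepdims=True))  # cached for the VJP
    p = np.exp(s - lse)                                          # softmax(s)
    return p @ v, lse

def sdpa_vjp(q, k, v, scale, lse, d_out):
    # Recompute P from q, k and the cached logsumexp.
    s = scale * (q @ k.T)
    p = np.exp(s - lse)
    dv = p.T @ d_out                              # dV = P^T dO
    dp = d_out @ v.T                              # dP = dO V^T
    delta = (dp * p).sum(axis=-1, keepdims=True)  # per-row correction term
    ds = p * (dp - delta)                         # softmax backward
    dq = scale * (ds @ k)                         # dQ
    dk = scale * (ds.T @ q)                       # dK
    return dq, dk, dv
```

Splitting dQ from dK/dV lets each kernel parallelize over the rows it actually writes (Q rows for dQ, K rows for dK/dV), which is how the fused pass avoids atomic accumulation.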
Summary
Implements a fused backward pass (VJP) for scaled_dot_product_attention on Metal GPU. This enables efficient gradient computation during training without falling back to unfused (decomposed) attention operations.

Changes
New Files
- mlx/backend/metal/kernels/sdpa_vector_vjp.h: Vector VJP kernel for short sequences (L ≤ 8)
- mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dq.h: STEEL dQ gradient kernel
- mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.h: STEEL dK/dV gradient kernel

Modified Files

- mlx/backend/metal/scaled_dot_product_attention.cpp: VJP dispatch logic (+840 lines)
- mlx/fast.cpp, mlx/fast_primitives.h: Logsumexp caching, VJP routing
- python/tests/test_fast_sdpa.py: Comprehensive VJP tests (+220 lines)

Implementation Notes
Uses a two-kernel approach to avoid atomic operations (a rough sketch of the path selection follows this list):

- dQ kernel (steel_attention_vjp_dq.h): computes query gradients; outer loop over KV blocks with inner accumulation for dQ, working in the log2 domain for numerical stability.
- dK/dV kernel (steel_attention_vjp_dkv.h): computes key/value gradients; a K-row ownership model gives each simdgroup exclusive K rows, so no atomic operations or races.
- Vector VJP (sdpa_vector_vjp.h): optimized path for short sequences (L ≤ 8), using shared memory for the reduction.
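The routing between these paths can be summarized roughly as below. This is an illustrative Python sketch with made-up names, not the actual dispatch code (which lives in mlx/backend/metal/scaled_dot_product_attention.cpp):

```python
def choose_sdpa_vjp_path(seq_len, head_dim, has_array_mask, has_sinks, has_cached_lse):
    """Hypothetical sketch of the VJP path selection described above."""
    # Array-mask and attention-sink gradients are not fused, and the fused
    # VJP needs the logsumexp cached by the forward pass (training mode).
    if has_array_mask or has_sinks or not has_cached_lse:
        return "unfused fallback"
    # Short sequences take the vector kernel (not available at D = 256).
    if seq_len <= 8 and head_dim < 256:
        return "vector VJP (sdpa_vector_vjp.h)"
    # Otherwise run the two STEEL kernels: dQ, then dK/dV.
    return "STEEL VJP (steel_attention_vjp_dq.h + steel_attention_vjp_dkv.h)"
```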
Key Features

- Float32 accumulators for half/bfloat16 inputs
- Logsumexp caching from the forward pass
- GQA (grouped query attention) support
- Causal mask support

Limitations

- No gradient support for mask or attention sinks (falls back to unfused attention)
- Requires logsumexp from the forward pass (training mode only)
- Head dimension D=256 not supported in the vector VJP (threadgroup memory)
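On the GQA point: each query head maps to its shared KV head via the usual grouped-query scheme, and the dK/dV gradients for a KV head accumulate contributions from every query head in its group. An illustrative mapping (not the kernel code):

```python
def kv_head_for_query_head(q_head, n_q_heads, n_kv_heads):
    # Each KV head is shared by a contiguous group of query heads,
    # e.g. 32 query heads over 8 KV heads gives a group size of 4.
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size
```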
Test Plan
- test_sdpa_grad passes
- test_sdpa_grad_vector_path: short sequences (L = 1, 4, 7, 8)
- test_sdpa_grad_steel_path: longer sequences (L = 16, 32, 128, 256)
- test_sdpa_grad_head_dims: head dimensions (D = 32, 64, 96, 128)
- test_sdpa_grad_gqa: GQA configurations (4:1, 8:1, 16:1, MHA)
- test_sdpa_grad_dtypes: float16, bfloat16, float32
- test_sdpa_grad_edge_cases: L = 1, non-power-of-2 lengths, large batch, qL ≠ kvL

All 21 SDPA tests pass (1 skipped for an unrelated disabled feature).
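As a quick smoke test of the training path (a minimal sketch; shapes, scale, and the loss are arbitrary), gradients can be taken directly through mx.fast.scaled_dot_product_attention:

```python
import mlx.core as mx

# Toy check that gradients flow through the fused SDPA.
B, H, L, D = 1, 8, 128, 64
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))

def loss_fn(q, k, v):
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5)
    return out.sum()

dq, dk, dv = mx.grad(loss_fn, argnums=(0, 1, 2))(q, k, v)
mx.eval(dq, dk, dv)
print(dq.shape, dk.shape, dv.shape)
```

With these shapes and no mask, this is the case the STEEL VJP kernels cover; a sequence with L ≤ 8 would take the vector path instead.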