
Conversation

Wohox (Contributor) commented Jan 22, 2026

Description

This PR adds get_backward_dw_params for TE modules, which helps manage parameter hooks.

For Megatron-LM, get_backward_dw_params will be called once the wgrad CUDA graph is executed. Currently, the backward_post_hook of the wgrad computation is discarded, which causes parameters to skip grad reduce.
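
Below is a minimal, hedged sketch of the intended consumer-side pattern. Only get_backward_dw_params() comes from this PR; the function name, the register_grad_reduce_hook callback, and the te_modules iterable are illustrative assumptions, not Megatron-LM's actual code:

# Hedged sketch of the Megatron-LM side. Only get_backward_dw_params() comes
# from this PR; the surrounding names are illustrative placeholders.
def hook_backward_dw_params(te_modules, register_grad_reduce_hook):
    # Called around wgrad CUDA graph execution. The backward post hook of the
    # wgrad computation is discarded in that path, so the grad-reduce hook is
    # registered against the parameters that backward_dw() will update instead.
    for module in te_modules:
        for param in module.get_backward_dw_params():
            register_grad_reduce_hook(param)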

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add get_backward_dw_params to TE modules, returning the parameters involved in the delayed weight gradient computation (backward_dw)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps bot (Contributor) commented Jan 22, 2026

Greptile Summary

This PR adds a get_backward_dw_params() method to TransformerEngineBaseModule that returns the parameters involved in delayed weight gradient computation. This API is needed for Megatron-LM's CUDA graph execution with wgrad computation, where parameter hooks need to be registered to ensure proper gradient reduction.

The new method (sketched just after this list):

  • Returns a list containing weight parameters via noop_cat(self._get_weight_tensors())
  • Conditionally appends bias parameters if self.use_bias is true
  • Mirrors the parameter access pattern in the existing backward_dw() method
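
The following is a hedged reconstruction of the method from these bullets and the review excerpt further down; the exact docstring, typing, and formatting in the actual diff may differ:

# Reconstruction of the new TransformerEngineBaseModule method (a sketch,
# not the exact diff). noop_cat, _get_weight_tensors, use_bias, and
# bias_names are existing TE internals referenced by the review.
def get_backward_dw_params(self):
    """Get the parameters for the backward weight gradient computation."""
    params = []
    # Weight parameters, concatenated the same way backward_dw() accesses them.
    params.append(noop_cat(self._get_weight_tensors()))
    if self.use_bias:
        # Bias parameters are only involved when the module has a bias.
        params.append(noop_cat([getattr(self, name) for name in self.bias_names]))
    return params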

Key Change:

  • Enables external frameworks like Megatron-LM to query which parameters will be modified during backward_dw() execution, allowing them to register appropriate hooks before CUDA graph capture

Potential Issue:

  • The method unconditionally returns weight parameters, but backward_dw() only accesses weights when not self.fuse_wgrad_accumulation. This discrepancy may cause unnecessary hooks to be registered when weight gradient accumulation is fused.

Confidence Score: 3/5

  • This PR introduces a helper API with a potential logic issue that may cause incorrect hook registration in certain configurations
  • Score reflects that while the implementation follows the pattern from backward_dw(), there is a critical discrepancy: weight parameters are unconditionally returned regardless of the fuse_wgrad_accumulation setting, which could lead to hooks being registered on parameters that won't actually receive gradients. The PR description lacks detail about testing with different configuration combinations.
  • Check transformer_engine/pytorch/module/base.py line 1537 to verify the logic handles fuse_wgrad_accumulation correctly

Important Files Changed

Filename: transformer_engine/pytorch/module/base.py
Overview: Added the get_backward_dw_params method, which returns the weight and bias parameters for the delayed weight gradient computation. The method mirrors the logic of backward_dw() but unconditionally returns weight parameters regardless of the fuse_wgrad_accumulation setting.

Sequence Diagram

sequenceDiagram
    participant MegatronLM as Megatron-LM
    participant TEModule as TE Module
    participant Hooks as Parameter Hooks
    participant BackwardDW as backward_dw()
    
    Note over MegatronLM,BackwardDW: Setup Phase
    MegatronLM->>TEModule: register_wgrad_accumulation_and_reduce_hooks(hook_fn)
    TEModule->>TEModule: Store hook in wgrad_accumulation_and_reduce_hooks[]
    
    Note over MegatronLM,BackwardDW: Before CUDA Graph Execution
    MegatronLM->>TEModule: get_backward_dw_params()
    TEModule->>TEModule: noop_cat(self._get_weight_tensors())
    alt use_bias == True
        TEModule->>TEModule: noop_cat([getattr(self, name) for name in self.bias_names])
    end
    TEModule-->>MegatronLM: Return [weight_params, bias_params?]
    MegatronLM->>Hooks: Register backward_post_hook on returned params
    
    Note over MegatronLM,BackwardDW: CUDA Graph Execution
    MegatronLM->>BackwardDW: Execute backward_dw() in CUDA graph
    BackwardDW->>BackwardDW: Pop wgrad from store
    alt fuse_wgrad_accumulation == False
        BackwardDW->>BackwardDW: weight_tensor.grad = wgrad
    end
    alt use_bias == True
        BackwardDW->>BackwardDW: bias_tensor.grad = bgrad
    end
    BackwardDW->>Hooks: Call wgrad_accumulation_and_reduce_hooks
    Hooks->>MegatronLM: Trigger parameter grad reduce

greptile-apps bot (Contributor) commented Jan 22, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

Wohox (Contributor, Author) commented Jan 22, 2026

@buptzyb @lhb8125 Please help review this PR, thanks!

greptile-apps bot (Contributor) left a comment

1 file reviewed, 1 comment


From transformer_engine/pytorch/module/base.py (excerpt of the added method):

Get the parameters for the backward weight gradient computation.
"""
params = []
params.append(noop_cat(self._get_weight_tensors()))
greptile-apps bot (Contributor) commented:

Logic: in backward_dw() (lines 1520-1522), weight tensors are only accessed when not self.fuse_wgrad_accumulation, but this method unconditionally returns weight parameters. Depending on Megatron-LM's usage, this could cause hooks to be registered on parameters that should not have them when fuse_wgrad_accumulation=True.
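
For illustration only, a hedged variant that mirrors the condition in backward_dw(); whether weights should actually be omitted when fuse_wgrad_accumulation=True depends on how Megatron-LM consumes the returned list, so this is a sketch of the reviewer's point, not a proposed patch:

# Illustrative sketch of the guarded variant the comment alludes to.
def get_backward_dw_params(self):
    """Return only the parameters whose .grad backward_dw() will write."""
    params = []
    if not self.fuse_wgrad_accumulation:
        # backward_dw() only assigns weight_tensor.grad in this branch.
        params.append(noop_cat(self._get_weight_tensors()))
    if self.use_bias:
        params.append(noop_cat([getattr(self, name) for name in self.bias_names]))
    return params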
