[Pytorch] Add get_backward_dw_params api for TE module #2614
Conversation
Greptile Summary
This PR adds a `get_backward_dw_params()` method to TE modules. The new method returns the module's weight parameters (and bias parameters, when present) as `noop_cat` views so the caller can manage hooks on them.
Key Change: a `get_backward_dw_params()` API that exposes the parameters whose gradients are produced by `backward_dw()`.
Potential Issue: weight parameters are returned unconditionally, even though `backward_dw()` only writes weight `.grad` when `fuse_wgrad_accumulation` is False (see the review comment below).
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
```mermaid
sequenceDiagram
participant MegatronLM as Megatron-LM
participant TEModule as TE Module
participant Hooks as Parameter Hooks
participant BackwardDW as backward_dw()
Note over MegatronLM,BackwardDW: Setup Phase
MegatronLM->>TEModule: register_wgrad_accumulation_and_reduce_hooks(hook_fn)
TEModule->>TEModule: Store hook in wgrad_accumulation_and_reduce_hooks[]
Note over MegatronLM,BackwardDW: Before CUDA Graph Execution
MegatronLM->>TEModule: get_backward_dw_params()
TEModule->>TEModule: noop_cat(self._get_weight_tensors())
alt use_bias == True
TEModule->>TEModule: noop_cat([getattr(self, name) for name in self.bias_names])
end
TEModule-->>MegatronLM: Return [weight_params, bias_params?]
MegatronLM->>Hooks: Register backward_post_hook on returned params
Note over MegatronLM,BackwardDW: CUDA Graph Execution
MegatronLM->>BackwardDW: Execute backward_dw() in CUDA graph
BackwardDW->>BackwardDW: Pop wgrad from store
alt fuse_wgrad_accumulation == False
BackwardDW->>BackwardDW: weight_tensor.grad = wgrad
end
alt use_bias == True
BackwardDW->>BackwardDW: bias_tensor.grad = bgrad
end
BackwardDW->>Hooks: Call wgrad_accumulation_and_reduce_hooks
Hooks->>MegatronLM: Trigger parameter grad reduce
```
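For orientation, below is a minimal usage sketch from the caller's side. It assumes a collection of TE modules that expose the new API; `modules`, `te_modules`, and `reduce_grads` are illustrative placeholders, not actual Megatron-LM code.

```python
# Hedged sketch: collect the parameters whose weight gradients are produced
# later by backward_dw(), so their grad reduce can be deferred until after the
# wgrad CUDA graph has run. `modules` and `reduce_grads` are placeholders.
def collect_deferred_wgrad_params(modules):
    deferred = []
    for module in modules:
        if hasattr(module, "get_backward_dw_params"):
            deferred.extend(module.get_backward_dw_params())
    return deferred

# ...after the wgrad CUDA graph (i.e. after backward_dw() has filled .grad on
# these parameters), the caller can run its own reduce logic, e.g.:
# reduce_grads(collect_deferred_wgrad_params(te_modules))
```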
1 file reviewed, 1 comment
```python
        Get the parameters for the backward weight gradient computation.
        """
        params = []
        params.append(noop_cat(self._get_weight_tensors()))
```
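Combined with the sequence diagram above, the full method plausibly looks like the sketch below. The `use_bias` and `bias_names` attributes are taken from the diagram and are assumptions, not verified against the diff.

```python
def get_backward_dw_params(self) -> list:
    """
    Get the parameters for the backward weight gradient computation.

    Sketch reconstructed from the visible fragment and the sequence diagram;
    attribute names (use_bias, bias_names) are assumptions.
    """
    params = []
    # Single no-op concatenated view over the module's weight tensors.
    params.append(noop_cat(self._get_weight_tensors()))
    if self.use_bias:
        # Same treatment for the bias tensors when the module carries biases.
        params.append(noop_cat([getattr(self, name) for name in self.bias_names]))
    return params
```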
logic: in `backward_dw()` (lines 1520-1522), weight tensors are only accessed when `not self.fuse_wgrad_accumulation`, but this method unconditionally returns weight parameters. Depending on Megatron-LM's usage, this could cause hooks to be registered on parameters that shouldn't have them when `fuse_wgrad_accumulation=True`.
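If that concern is confirmed, one possible adjustment (a sketch only, not part of this PR) would be to gate the weight parameters on the same flag that `backward_dw()` checks:

```python
def get_backward_dw_params(self) -> list:
    """Conditional variant suggested by the review comment; sketch only."""
    params = []
    # Mirror backward_dw(): weight .grad is only written when wgrad
    # accumulation is not fused into the GEMM.
    if not self.fuse_wgrad_accumulation:
        params.append(noop_cat(self._get_weight_tensors()))
    if self.use_bias:
        params.append(noop_cat([getattr(self, name) for name in self.bias_names]))
    return params
```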
Description
This PR adds `get_backward_dw_params` for TE modules, which helps manage the hooks of parameters.
For Megatron-LM, `get_backward_dw_params` will be called once the wgrad CUDA graph is executed. Currently the backward_post_hook of the wgrad computation is discarded, which causes parameters to skip grad reduce.
Type of change
Changes
Please list the changes introduced in this PR:
- Add a `get_backward_dw_params` API for TE modules.
Checklist: