[model][NPU]:Wan model rope use torch.complex64 in NPU #1191
Conversation
Summary of Changes (Gemini Code Assist): This pull request focuses on enhancing the performance of the Wan model's RoPE implementation when running on NPU devices. It achieves this by reducing the precision of complex number operations from complex128 to complex64 on NPU.
Code Review
This pull request aims to improve performance on NPU by using torch.complex64 for RoPE operations. While the intent is correct, the implementation is incomplete in both diffsynth/models/wan_video_dit.py and diffsynth/utils/xfuser/xdit_context_parallel.py. In both files, only one of the operands for a complex multiplication is converted to complex64, while the other remains complex128. Due to PyTorch's type promotion rules, the operation will still be performed in complex128, negating the intended performance benefit. I've provided comments and a suggestion to fix this issue by ensuring both operands are consistently complex64 on NPU devices. The documentation updates appear correct.
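The promotion behavior can be checked with a short standalone snippet (a minimal sketch, independent of this PR's code): multiplying a complex64 tensor by a complex128 tensor produces a complex128 result, so converting only one operand does not keep the math in complex64.

    import torch

    # PyTorch type promotion: mixing complex64 with complex128 promotes the result
    # to complex128, so converting only `freqs` leaves the multiply in complex128.
    a = torch.view_as_complex(torch.randn(4, 2, dtype=torch.float32))  # complex64
    b = torch.view_as_complex(torch.randn(4, 2, dtype=torch.float64))  # complex128
    print(a.dtype, b.dtype, (a * b).dtype)
    # torch.complex64 torch.complex128 torch.complex128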
In diffsynth/models/wan_video_dit.py:

    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))
    freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs
    x_out = torch.view_as_real(x_out * freqs).flatten(2)
The current implementation converts freqs to complex64 on NPU, but x_out remains complex128 (as it's created from x.to(torch.float64)). During the multiplication x_out * freqs, PyTorch will promote freqs back to complex128, which means the operation will still use complex128 and not achieve the intended performance improvement on NPU. To fix this, x_out should also be complex64 when running on NPU. The suggested change ensures both tensors are of the correct data type for the operation.
Suggested change:

Before:

    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))
    freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs
    x_out = torch.view_as_real(x_out * freqs).flatten(2)

After:

    x_out = torch.view_as_complex(x.to(torch.float32 if IS_NPU_AVAILABLE else torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))
    freqs = freqs.to(x_out.dtype)
    x_out = torch.view_as_real(x_out * freqs).flatten(2)
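For context, a self-contained sketch of the suggested behavior is below. The function name apply_rope, the tensor shapes, and the module-level IS_NPU_AVAILABLE flag are illustrative assumptions, not the repository's exact code:

    import torch

    IS_NPU_AVAILABLE = False  # assumed flag; the repository detects NPU availability elsewhere

    def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, heads, head_dim) real tensor; freqs: complex rotary factors
        # broadcastable against (batch, seq, heads, head_dim // 2).
        # Keeping both operands in complex64 on NPU avoids promotion back to complex128.
        compute_dtype = torch.float32 if IS_NPU_AVAILABLE else torch.float64
        x_out = torch.view_as_complex(
            x.to(compute_dtype).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
        )
        freqs = freqs.to(x_out.dtype)  # complex64 on NPU, complex128 otherwise
        return torch.view_as_real(x_out * freqs).flatten(2)  # (batch, seq, heads * head_dim)

With IS_NPU_AVAILABLE set, both x_out and freqs are complex64, so the rotary multiply stays in complex64 end to end.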
In diffsynth/utils/xfuser/xdit_context_parallel.py:

    freqs = pad_freqs(freqs, s_per_rank * sp_size)
    freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
    freqs_rank = freqs_rank.to(torch.complex64) if IS_NPU_AVAILABLE else freqs_rank
This change converts freqs_rank to complex64 on NPU, but the x_out tensor, defined on line 46, is still created as complex128 (from x.to(torch.float64)). During the multiplication x_out * freqs_rank on line 54, PyTorch will promote freqs_rank back to complex128. This means the operation will still use complex128, which negates the performance benefit this change is trying to achieve on NPU. To fix this, x_out should also be complex64. You should modify line 46 to use x.to(torch.float32) when IS_NPU_AVAILABLE is true.
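As a rough illustration of that point (illustrative shapes and flag, not the repository code), building x_out from float32 alongside the converted freqs_rank keeps the whole sequence-parallel RoPE multiply in complex64:

    import torch

    IS_NPU_AVAILABLE = True  # assumed flag for illustration
    compute_dtype = torch.float32 if IS_NPU_AVAILABLE else torch.float64

    # Hypothetical shapes: x is (batch, seq_per_rank, heads, head_dim),
    # freqs_rank is (seq_per_rank, 1, head_dim // 2) complex.
    x = torch.randn(1, 8, 4, 16)
    x_out = torch.view_as_complex(
        x.to(compute_dtype).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
    )
    freqs_rank = torch.view_as_complex(torch.randn(8, 1, 8, 2, dtype=compute_dtype))
    freqs_rank = freqs_rank.to(torch.complex64) if IS_NPU_AVAILABLE else freqs_rank

    out = torch.view_as_real(x_out * freqs_rank).flatten(2)
    print(x_out.dtype, freqs_rank.dtype, out.dtype)
    # torch.complex64 torch.complex64 torch.float32 -- no complex128 on the NPU path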
In Wan's RoPE, the NPU dispatches complex128 operations to AICPU, resulting in poor performance. Therefore, the complex128 data needs to be converted to complex64.
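As background for this motivation (illustrative only, not code from this PR): RoPE frequency tables are commonly built with torch.polar, and the complex dtype it produces follows the float dtype of its inputs, so the table ends up complex128 whenever the angles are kept in float64.

    import torch

    # Illustrative only: torch.polar(abs, angle) yields complex64 for float32 inputs
    # and complex128 for float64 inputs, so the RoPE table's precision is decided here.
    dim, max_len = 64, 1024
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angles = torch.outer(torch.arange(max_len, dtype=torch.float64), inv_freq)

    freqs_f64 = torch.polar(torch.ones_like(angles), angles)                # complex128
    freqs_f32 = torch.polar(torch.ones_like(angles, dtype=torch.float32),
                            angles.to(torch.float32))                       # complex64
    print(freqs_f64.dtype, freqs_f32.dtype)  # torch.complex128 torch.complex64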