
Conversation

Contributor

@Feng0w0 commented Jan 12, 2026

In Wan's RoPE, the NPU offloads complex128 operations to AICPU, resulting in poor performance. Therefore, the complex128 data needs to be converted to complex64.
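For context, a minimal sketch of the idea, assuming a standard RoPE precomputation (the function name and signature below are hypothetical, not the repository's actual code): rotations can be computed in float64 for accuracy and then stored as complex64 so the complex multiplication stays off the AICPU.

```python
import torch

def build_rope_freqs(dim, max_pos, theta=10000.0, use_complex64=True):
    # Hypothetical helper: compute angles in float64 for accuracy, then store
    # the rotations as complex64 so NPU kernels avoid the AICPU fallback that
    # complex128 arithmetic triggers.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float64), inv_freq)
    freqs = torch.polar(torch.ones_like(angles), angles)  # complex128 at this point
    return freqs.to(torch.complex64) if use_complex64 else freqs
```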

@gemini-code-assist
Contributor

Summary of Changes

Hello @Feng0w0, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the performance of the Wan model's Rope implementation when running on NPU devices. It achieves this by strategically reducing the precision of complex number operations from complex128 to complex64, which avoids inefficient fallback to AICPU for higher-precision computations. Additionally, the PR includes minor documentation updates to align with recent changes in the NPU training script directory structure.

Highlights

  • NPU Performance Optimization: The core change involves converting torch.complex128 data to torch.complex64 within the rope_apply function for NPU (Neural Processing Unit) operations. This is a conditional change, applied only when IS_NPU_AVAILABLE is true, to prevent performance degradation caused by complex128 operations being offloaded to AICPU.
  • NPU Availability Check: Introduced an import of IS_NPU_AVAILABLE (from diffsynth.core.device.npu_compatible_device and diffsynth.core.device) to enable conditional logic for NPU-specific optimizations; a minimal sketch of such a guard appears after this list.
  • Documentation Update: Updated the documentation (both English and Chinese versions) to reflect a change in the directory name for NPU training script samples from npu_scripts to npu_training.
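The availability flag above is imported from the repository's device module; the sketch below shows what such a guard typically looks like, purely as an illustration (the flag definition here is an assumption, not the contents of diffsynth.core.device):

```python
import torch

# Hypothetical stand-in for an NPU-availability guard; in this PR the flag is
# imported from diffsynth.core.device.npu_compatible_device rather than defined inline.
try:
    import torch_npu  # noqa: F401  # Ascend NPU backend for PyTorch
    IS_NPU_AVAILABLE = torch.npu.is_available()
except (ImportError, AttributeError):
    IS_NPU_AVAILABLE = False
```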


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize your Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the root of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in sharing feedback on your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check its output and use code with caution.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request aims to improve performance on NPU by using torch.complex64 for RoPE operations. While the intent is correct, the implementation is incomplete in both diffsynth/models/wan_video_dit.py and diffsynth/utils/xfuser/xdit_context_parallel.py. In both files, only one of the operands for a complex multiplication is converted to complex64, while the other remains complex128. Due to PyTorch's type promotion rules, the operation will still be performed in complex128, negating the intended performance benefit. I've provided comments and a suggestion to fix this issue by ensuring both operands are consistently complex64 on NPU devices. The documentation updates appear correct.
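A quick, standalone illustration of the type-promotion behavior described above (plain PyTorch, independent of the repository code):

```python
import torch

a = torch.zeros(4, dtype=torch.complex128)
b = torch.zeros(4, dtype=torch.complex64)
print((a * b).dtype)                                           # torch.complex128
print(torch.promote_types(torch.complex128, torch.complex64))  # torch.complex128
```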

Comment on lines 95 to 98
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
    x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs
x_out = torch.view_as_real(x_out * freqs).flatten(2)

Severity: high

The current implementation converts freqs to complex64 on NPU, but x_out remains complex128 (as it's created from x.to(torch.float64)). During the multiplication x_out * freqs, PyTorch will promote freqs back to complex128, which means the operation will still use complex128 and not achieve the intended performance improvement on NPU. To fix this, x_out should also be complex64 when running on NPU. The suggested change ensures both tensors are of the correct data type for the operation.

Suggested change
- x_out = torch.view_as_complex(x.to(torch.float64).reshape(
-     x.shape[0], x.shape[1], x.shape[2], -1, 2))
- freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs
- x_out = torch.view_as_real(x_out * freqs).flatten(2)
+ x_out = torch.view_as_complex(x.to(torch.float32 if IS_NPU_AVAILABLE else torch.float64).reshape(
+     x.shape[0], x.shape[1], x.shape[2], -1, 2))
+ freqs = freqs.to(x_out.dtype)
+ x_out = torch.view_as_real(x_out * freqs).flatten(2)
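As a sanity check of the suggested pattern, the standalone snippet below (shapes are arbitrary and IS_NPU_AVAILABLE is hard-coded to emulate the NPU branch on CPU) confirms that both operands of the multiplication end up as complex64:

```python
import torch

IS_NPU_AVAILABLE = True  # assumption: force the NPU branch for this dtype check

x = torch.randn(1, 4, 8, 16)
x_out = torch.view_as_complex(
    x.to(torch.float32 if IS_NPU_AVAILABLE else torch.float64)
     .reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = torch.polar(torch.ones(8), torch.arange(8, dtype=torch.float32))  # complex64
freqs = freqs.to(x_out.dtype)  # stays aligned with x_out on every device
out = torch.view_as_real(x_out * freqs).flatten(2)
print(x_out.dtype, freqs.dtype, out.dtype)  # torch.complex64 torch.complex64 torch.float32
```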

freqs = pad_freqs(freqs, s_per_rank * sp_size)
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]

freqs_rank = freqs_rank.to(torch.complex64) if IS_NPU_AVAILABLE else freqs_rank

Severity: high

This change converts freqs_rank to complex64 on NPU, but the x_out tensor, defined on line 46, is still created as complex128 (from x.to(torch.float64)). During the multiplication x_out * freqs_rank on line 54, PyTorch will promote freqs_rank back to complex128. This means the operation will still use complex128, which negates the performance benefit this change is trying to achieve on NPU. To fix this, x_out should also be complex64. You should modify line 46 to use x.to(torch.float32) when IS_NPU_AVAILABLE is true.
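For reference, the analogous fix could look like the sketch below (the function name, arguments, and shapes are hypothetical; only the dtype handling mirrors the suggestion for wan_video_dit.py):

```python
import torch

IS_NPU_AVAILABLE = True  # assumption: in the real code this flag comes from diffsynth.core.device

def rope_apply_rank_sketch(x, freqs_rank):
    # Build x_out from float32 on NPU so view_as_complex yields complex64
    # instead of complex128.
    x_out = torch.view_as_complex(
        x.to(torch.float32 if IS_NPU_AVAILABLE else torch.float64)
         .reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2))
    # Match freqs_rank to x_out's dtype rather than hard-coding complex64,
    # so both operands of the multiplication always share the same complex dtype.
    freqs_rank = freqs_rank.to(x_out.dtype)
    return torch.view_as_real(x_out * freqs_rank).flatten(2)
```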
