14 changes: 11 additions & 3 deletions src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -83,7 +83,7 @@ def get_qwen_prompt_embeds(
split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
max_seq_len = tokenizer_max_length
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
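
The change above pads every prompt in the batch out to the fixed tokenizer_max_length instead of to the longest sequence in the batch, so prompt_embeds always comes out with shape (batch, tokenizer_max_length, hidden_dim). A minimal, self-contained sketch of that padding step with toy tensors (the names and sizes below are illustrative, not the library's):

import torch

tokenizer_max_length = 8          # stand-in for the new fixed limit (1024 in the diff)
hidden_dim = 4

# per-prompt hidden states of different true lengths, as produced after masking and prefix dropping
split_hidden_states = [torch.randn(3, hidden_dim), torch.randn(5, hidden_dim)]

max_seq_len = tokenizer_max_length  # fixed, no longer max(e.size(0) for e in split_hidden_states)
prompt_embeds = torch.stack(
    [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
prompt_embeds_mask = torch.stack(
    [torch.cat([torch.ones(u.size(0), dtype=torch.long), torch.zeros(max_seq_len - u.size(0), dtype=torch.long)])
     for u in split_hidden_states]
)

print(prompt_embeds.shape)        # torch.Size([2, 8, 4]); fixed length, not the batch max (5)
print(prompt_embeds_mask.sum(1))  # tensor([3, 5]); real token counts preserved in the mask
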
@@ -103,6 +103,7 @@ def get_qwen_prompt_embeds_edit(
image: Optional[torch.Tensor] = None,
prompt_template_encode: str = QWENIMAGE_EDIT_PROMPT_TEMPLATE,
prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX,
tokenizer_max_length: int = 1024,
device: Optional[torch.device] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -130,7 +131,7 @@ def get_qwen_prompt_embeds_edit(
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
max_seq_len = tokenizer_max_length
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
@@ -151,6 +152,7 @@ def get_qwen_prompt_embeds_edit_plus(
prompt_template_encode: str = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE,
img_template_encode: str = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE,
prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX,
tokenizer_max_length: int = 1024,
device: Optional[torch.device] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -186,7 +188,7 @@ def get_qwen_prompt_embeds_edit_plus(
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
max_seq_len = tokenizer_max_length
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
@@ -767,6 +769,7 @@ def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE),
ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX),
ConfigSpec(name="tokenizer_max_length", default=1024),
]

@property
@@ -838,6 +841,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
image=block_state.resized_image,
prompt_template_encode=components.config.prompt_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
tokenizer_max_length=components.config.tokenizer_max_length,
device=device,
)

@@ -852,6 +856,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
image=block_state.resized_image,
prompt_template_encode=components.config.prompt_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
tokenizer_max_length=components.config.tokenizer_max_length,
device=device,
)

@@ -890,6 +895,7 @@ def expected_configs(self) -> List[ConfigSpec]:
ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE),
ConfigSpec(name="img_template_encode", default=QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE),
ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX),
ConfigSpec(name="tokenizer_max_length", default=1024),
]

@property
@@ -962,6 +968,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
prompt_template_encode=components.config.prompt_template_encode,
img_template_encode=components.config.img_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
tokenizer_max_length=components.config.tokenizer_max_length,
device=device,
)

@@ -978,6 +985,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
prompt_template_encode=components.config.prompt_template_encode,
img_template_encode=components.config.img_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
tokenizer_max_length=components.config.tokenizer_max_length,
device=device,
)
)
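
Since both the prompt and negative-prompt branches above now pad to the same components.config.tokenizer_max_length, the two embedding tensors share a sequence length and can be batched together without re-padding. A small illustration with toy shapes and a hypothetical encode_fixed stand-in (not the actual get_qwen_prompt_embeds_edit_plus signature):

import torch

tokenizer_max_length, hidden_dim = 8, 4

def encode_fixed(true_len):
    # stand-in for the encoder helpers: real tokens followed by zero padding up to the fixed length
    e = torch.randn(true_len, hidden_dim)
    return torch.cat([e, e.new_zeros(tokenizer_max_length - true_len, hidden_dim)]).unsqueeze(0)

prompt_embeds = encode_fixed(true_len=5)           # e.g. the positive prompt
negative_prompt_embeds = encode_fixed(true_len=3)  # e.g. the negative prompt

# identical sequence lengths, so the two branches can be concatenated along the batch dimension
cfg_batch = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
print(cfg_batch.shape)  # torch.Size([2, 8, 4])
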
9 changes: 6 additions & 3 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -190,6 +190,7 @@ def _get_qwen_prompt_embeds(
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 1024,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
@@ -200,7 +201,7 @@
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]
txt_tokens = self.tokenizer(
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
@@ -211,7 +212,7 @@
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
max_seq_len = max_sequence_length
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
@@ -251,7 +252,9 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
prompt, device, max_sequence_length=max_sequence_length
)

prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
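
In the tokenizer call above, max_length is max_sequence_length + drop_idx because the first drop_idx tokens belong to the prompt template prefix and are stripped afterwards (e[drop_idx:]), so at most max_sequence_length tokens remain to be padded. A quick toy check of that arithmetic (illustrative values, not the pipeline code):

import torch

max_sequence_length, drop_idx = 6, 3           # toy values; the diff uses 1024 and the template start index
token_ids = torch.arange(15)                   # an over-long "tokenized prompt", template prefix included

truncated = token_ids[: max_sequence_length + drop_idx]  # what max_length=... in the tokenizer call keeps
prompt_part = truncated[drop_idx:]                       # template prefix dropped, as in e[drop_idx:]
# at most max_sequence_length tokens survive, so new_zeros(max_seq_len - u.size(0), ...) never goes negative;
# this over-long toy prompt is clipped to exactly the budget
assert prompt_part.size(0) == max_sequence_length
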
(diff continues in another pipeline file; filename not shown)
@@ -254,6 +254,7 @@ def _get_qwen_prompt_embeds(
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 1024,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
@@ -264,7 +265,7 @@
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]
txt_tokens = self.tokenizer(
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
@@ -275,7 +276,7 @@
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
max_seq_len = max_sequence_length
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
@@ -316,7 +317,9 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
prompt, device, max_sequence_length=max_sequence_length
)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
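
For context, the repeat above is the usual diffusers pattern for duplicating prompt embeddings once per requested image; the follow-up view to (batch_size * num_images_per_prompt, seq_len, -1) falls outside the visible hunk, so the sketch below reconstructs the idea with toy tensors rather than quoting the pipeline:

import torch

batch_size, seq_len, hidden_dim, num_images_per_prompt = 2, 5, 3, 4
prompt_embeds = torch.randn(batch_size, seq_len, hidden_dim)

prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)                  # (2, 20, 3)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
print(prompt_embeds.shape)  # torch.Size([8, 5, 3]); one copy of each prompt per requested image
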
(diff continues in another pipeline file; filename not shown)
@@ -236,6 +236,7 @@ def _get_qwen_prompt_embeds(
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 1024,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
@@ -246,7 +247,7 @@
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]
txt_tokens = self.tokenizer(
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
@@ -257,7 +258,7 @@
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
max_seq_len = max_sequence_length
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
@@ -297,7 +298,9 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
prompt, device, max_sequence_length=max_sequence_length
)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
9 changes: 7 additions & 2 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
@@ -229,6 +229,7 @@ def _get_qwen_prompt_embeds(
image: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 1024,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
@@ -257,8 +258,10 @@
hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
# Truncate if longer than max_sequence_length
split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
max_seq_len = max_sequence_length
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
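
In the edit variants the model inputs come from a multimodal template call that is not shown in this hunk, so the newly added line clips each per-prompt sequence to max_sequence_length before padding; without it, u.new_zeros(max_seq_len - u.size(0), ...) would be asked for a negative size whenever a sequence exceeded the fixed budget. A minimal clip-then-pad sketch with toy shapes (not the library code):

import torch

max_sequence_length, hidden_dim = 4, 2
split_hidden_states = [torch.randn(6, hidden_dim),   # longer than the budget, so it gets clipped
                       torch.randn(2, hidden_dim)]   # shorter, so it is only padded

split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e
                       for e in split_hidden_states]
prompt_embeds = torch.stack(
    [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in split_hidden_states]
)
print(prompt_embeds.shape)  # torch.Size([2, 4, 2]); every sequence exactly max_sequence_length long
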
@@ -301,7 +304,9 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
prompt, image, device, max_sequence_length=max_sequence_length
)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
(diff continues in another pipeline file; filename not shown)
@@ -240,6 +240,7 @@ def _get_qwen_prompt_embeds(
image: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 1024,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
@@ -268,8 +269,10 @@
hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
# Truncate if longer than max_sequence_length
split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
max_seq_len = max_sequence_length
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
@@ -313,7 +316,9 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
prompt, image, device, max_sequence_length=max_sequence_length
)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
(diff continues in another pipeline file; filename not shown)
@@ -232,6 +232,7 @@ def _get_qwen_prompt_embeds(
image: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 1024,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
@@ -270,8 +271,10 @@
hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
# Truncate if longer than max_sequence_length
split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
max_seq_len = max_sequence_length
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
@@ -315,7 +318,9 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
prompt, image, device, max_sequence_length=max_sequence_length
)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
(diff continues in another pipeline file; filename not shown)
@@ -197,6 +197,7 @@ def _get_qwen_prompt_embeds(
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 1024,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
@@ -207,7 +208,7 @@
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]
txt_tokens = self.tokenizer(
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
@@ -218,7 +219,7 @@
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
max_seq_len = max_sequence_length
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
@@ -294,7 +295,9 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
prompt, device, max_sequence_length=max_sequence_length
)

prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
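
With embeddings now always padded to max_sequence_length, the slice above is effectively a no-op and prompt_embeds_mask is what identifies the real tokens. A small sketch (toy tensors, not the pipeline code) of recovering the true lengths, or trimming back to the batch maximum if a caller prefers the old, tighter shape:

import torch

max_sequence_length, hidden_dim = 8, 4
true_lens = [3, 5]

prompt_embeds = torch.zeros(len(true_lens), max_sequence_length, hidden_dim)
prompt_embeds_mask = torch.zeros(len(true_lens), max_sequence_length, dtype=torch.long)
for i, n in enumerate(true_lens):
    prompt_embeds[i, :n] = torch.randn(n, hidden_dim)
    prompt_embeds_mask[i, :n] = 1

# the mask recovers each prompt's real length ...
print(prompt_embeds_mask.sum(dim=1))  # tensor([3, 5])

# ... and can trim the fixed padding back to the batch maximum if memory matters
batch_max = int(prompt_embeds_mask.sum(dim=1).max())
trimmed = prompt_embeds[:, :batch_max]
print(trimmed.shape)                  # torch.Size([2, 5, 4])
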