diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 4b66dd32e521..a94f8fdc4394 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -83,7 +83,7 @@ def get_qwen_prompt_embeds( split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = tokenizer_max_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -103,6 +103,7 @@ def get_qwen_prompt_embeds_edit( image: Optional[torch.Tensor] = None, prompt_template_encode: str = QWENIMAGE_EDIT_PROMPT_TEMPLATE, prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX, + tokenizer_max_length: int = 1024, device: Optional[torch.device] = None, ): prompt = [prompt] if isinstance(prompt, str) else prompt @@ -130,7 +131,7 @@ def get_qwen_prompt_embeds_edit( split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = tokenizer_max_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -151,6 +152,7 @@ def get_qwen_prompt_embeds_edit_plus( prompt_template_encode: str = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE, img_template_encode: str = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE, prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX, + tokenizer_max_length: int = 1024, device: Optional[torch.device] = None, ): prompt = [prompt] if isinstance(prompt, str) else prompt @@ -186,7 +188,7 @@ def get_qwen_prompt_embeds_edit_plus( split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = tokenizer_max_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -767,6 +769,7 @@ def expected_configs(self) -> List[ConfigSpec]: return [ ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE), ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX), + ConfigSpec(name="tokenizer_max_length", default=1024), ] @property @@ -838,6 +841,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): image=block_state.resized_image, prompt_template_encode=components.config.prompt_template_encode, prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + tokenizer_max_length=components.config.tokenizer_max_length, device=device, ) @@ -852,6 +856,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): image=block_state.resized_image, prompt_template_encode=components.config.prompt_template_encode, prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + tokenizer_max_length=components.config.tokenizer_max_length, device=device, ) @@ -890,6 +895,7 @@ def expected_configs(self) -> List[ConfigSpec]: ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE), ConfigSpec(name="img_template_encode", default=QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE), ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX), + ConfigSpec(name="tokenizer_max_length", default=1024), ] @property @@ -962,6 +968,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): prompt_template_encode=components.config.prompt_template_encode, img_template_encode=components.config.img_template_encode, prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + tokenizer_max_length=components.config.tokenizer_max_length, device=device, ) @@ -978,6 +985,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): prompt_template_encode=components.config.prompt_template_encode, img_template_encode=components.config.img_template_encode, prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + tokenizer_max_length=components.config.tokenizer_max_length, device=device, ) ) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index bc3ce84e1019..dbb50a2607dd 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -190,6 +190,7 @@ def _get_qwen_prompt_embeds( prompt: Union[str, List[str]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -200,7 +201,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt" ).to(device) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, @@ -211,7 +212,7 @@ def _get_qwen_prompt_embeds( split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -251,7 +252,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index ce6fc974a56e..2bcacdce571d 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -254,6 +254,7 @@ def _get_qwen_prompt_embeds( prompt: Union[str, List[str]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -264,7 +265,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt" ).to(device) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, @@ -275,7 +276,7 @@ def _get_qwen_prompt_embeds( split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -316,7 +317,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 77d78a5ca7a1..8e6b7c45c665 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -236,6 +236,7 @@ def _get_qwen_prompt_embeds( prompt: Union[str, List[str]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -246,7 +247,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt" ).to(self.device) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, @@ -257,7 +258,7 @@ def _get_qwen_prompt_embeds( split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -297,7 +298,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index dd723460a59e..89b53364849a 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -229,6 +229,7 @@ def _get_qwen_prompt_embeds( image: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -257,8 +258,10 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + # Truncate if longer than max_sequence_length + split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -301,7 +304,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, image, device, max_sequence_length=max_sequence_length + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index cf467203a9d2..33e10c44d7de 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -240,6 +240,7 @@ def _get_qwen_prompt_embeds( image: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -268,8 +269,10 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + # Truncate if longer than max_sequence_length + split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -313,7 +316,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, image, device, max_sequence_length=max_sequence_length + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 257e2d846c7c..9fd02676474c 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -232,6 +232,7 @@ def _get_qwen_prompt_embeds( image: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -270,8 +271,10 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + # Truncate if longer than max_sequence_length + split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -315,7 +318,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, image, device, max_sequence_length=max_sequence_length + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index e0b41b8b8799..27651d2effd5 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -197,6 +197,7 @@ def _get_qwen_prompt_embeds( prompt: Union[str, List[str]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -207,7 +208,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt" ).to(device) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, @@ -218,7 +219,7 @@ def _get_qwen_prompt_embeds( split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -294,7 +295,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 83f02539b1ba..99cbfe4a46cd 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -207,6 +207,7 @@ def _get_qwen_prompt_embeds( prompt: Union[str, List[str]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -217,7 +218,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt" ).to(device) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, @@ -228,7 +229,7 @@ def _get_qwen_prompt_embeds( split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -305,7 +306,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] diff --git a/tests/pipelines/qwenimage/test_qwenimage.py b/tests/pipelines/qwenimage/test_qwenimage.py index 8ebfe7d08bc1..bbc95072f5be 100644 --- a/tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/pipelines/qwenimage/test_qwenimage.py @@ -160,7 +160,7 @@ def test_inference(self): self.assertEqual(generated_image.shape, (3, 32, 32)) # fmt: off - expected_slice = torch.tensor([0.56331, 0.63677, 0.6015, 0.56369, 0.58166, 0.55277, 0.57176, 0.63261, 0.41466, 0.35561, 0.56229, 0.48334, 0.49714, 0.52622, 0.40872, 0.50208]) + expected_slice = torch.tensor([0.5633, 0.6416, 0.6035, 0.5617, 0.5813, 0.5502, 0.5718, 0.6345, 0.4164, 0.3563, 0.5630, 0.4849, 0.4979, 0.5269, 0.4096, 0.5020]) # fmt: on generated_slice = generated_image.flatten() @@ -234,3 +234,61 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2): expected_diff_max, "VAE tiling should not affect the inference results", ) + + def test_prompt_embeds_padding(self): + """Test that prompt embeddings are padded to tokenizer_max_length (1024) instead of batch max.""" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + + # Test 1: Short prompt should be padded to 1024, not to its actual length + short_prompt = "test" + prompt_embeds, prompt_embeds_mask = pipe.encode_prompt( + prompt=short_prompt, + device=torch_device, + num_images_per_prompt=1, + max_sequence_length=1024, + ) + + # Should be padded to 1024 (tokenizer_max_length), not to the actual token count + self.assertEqual( + prompt_embeds.shape[1], + 1024, + f"Short prompt should be padded to 1024, got {prompt_embeds.shape[1]}", + ) + self.assertEqual( + prompt_embeds_mask.shape[1], + 1024, + f"Mask should be 1024 length, got {prompt_embeds_mask.shape[1]}", + ) + + # Test 2: Batch with different lengths should all be padded to same length (1024) + batch_prompts = ["short", "a much longer prompt here"] + prompt_embeds_batch, mask_batch = pipe.encode_prompt( + prompt=batch_prompts, + device=torch_device, + num_images_per_prompt=1, + max_sequence_length=1024, + ) + + self.assertEqual(prompt_embeds_batch.shape[0], 2, "Batch size should be 2") + self.assertEqual( + prompt_embeds_batch.shape[1], + 1024, + f"All prompts in batch should be padded to 1024, got {prompt_embeds_batch.shape[1]}", + ) + + # Test 3: With default max_sequence_length (512), should still pad to 1024 internally + # then truncate to 512 + prompt_embeds_512, mask_512 = pipe.encode_prompt( + prompt=short_prompt, + device=torch_device, + num_images_per_prompt=1, + max_sequence_length=512, + ) + + self.assertEqual( + prompt_embeds_512.shape[1], + 512, + f"With max_sequence_length=512, should truncate to 512, got {prompt_embeds_512.shape[1]}", + )