diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 5c3d6ce611cc..86a8c5fbb732 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -466,12 +466,15 @@ def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Union[str, List[str]] = None, height: int = 720, width: int = 1280, num_frames: int = 129, num_inference_steps: int = 50, sigmas: List[float] = None, guidance_scale: float = 6.0, + true_cfg_scale: float = 1.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -590,6 +593,7 @@ def __call__( batch_size = prompt_embeds.shape[0] # 3. Encode input prompt + do_true_cfg = true_cfg_scale > 1.0 and negative_prompt is not None prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, @@ -601,12 +605,29 @@ def __call__( device=device, max_sequence_length=max_sequence_length, ) + if do_true_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=None, + pooled_prompt_embeds=None, + prompt_attention_mask=None, + device=device, + max_sequence_length=max_sequence_length, + ) transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) if pooled_prompt_embeds is not None: pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + if do_true_cfg: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + if negative_pooled_prompt_embeds is not None: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) # 4. Prepare timesteps sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas @@ -658,6 +679,18 @@ def __call__( attention_kwargs=attention_kwargs, return_dict=False, )[0] + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]