Skip to content

Commit

Permalink
fix a bug in StableDiffusionUpscalePipeline when prompt is None (#…
Browse files Browse the repository at this point in the history
…4278)

* fix batch_size

* add test

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
  • Loading branch information
2 people authored and patrickvonplaten committed Jul 27, 2023
1 parent 49c9517 commit a982916
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,13 @@ def check_inputs(

# verify batch size of prompt and image are same if image is a list or tensor or numpy array
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
if isinstance(prompt, str):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
else:
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]

if isinstance(image, list):
image_batch_size = len(image)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,68 @@ def test_stable_diffusion_upscale_batch(self):
image = output.images
assert image.shape[0] == 2

def test_stable_diffusion_upscale_prompt_embeds(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet_upscale
low_res_scheduler = DDPMScheduler()
scheduler = DDIMScheduler(prediction_type="v_prediction")
vae = self.dummy_vae
text_encoder = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))

# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionUpscalePipeline(
unet=unet,
low_res_scheduler=low_res_scheduler,
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
max_noise_level=350,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
[prompt],
image=low_res_image,
generator=generator,
guidance_scale=6.0,
noise_level=20,
num_inference_steps=2,
output_type="np",
)

image = output.images

generator = torch.Generator(device=device).manual_seed(0)
prompt_embeds = sd_pipe._encode_prompt(prompt, device, 1, False)
image_from_prompt_embeds = sd_pipe(
prompt_embeds=prompt_embeds,
image=[low_res_image],
generator=generator,
guidance_scale=6.0,
noise_level=20,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]

image_slice = image[0, -3:, -3:, -1]
image_from_prompt_embeds_slice = image_from_prompt_embeds[0, -3:, -3:, -1]

expected_height_width = low_res_image.size[0] * 4
assert image.shape == (1, expected_height_width, expected_height_width, 3)
expected_slice = np.array([0.3113, 0.3910, 0.4272, 0.4859, 0.5061, 0.4652, 0.5362, 0.5715, 0.5661])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_prompt_embeds_slice.flatten() - expected_slice).max() < 1e-2

@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
def test_stable_diffusion_upscale_fp16(self):
"""Test that stable diffusion upscale works with fp16"""
Expand Down

0 comments on commit a982916

Please sign in to comment.