diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py index 3f8b4f162..39051a8f8 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py @@ -40,7 +40,7 @@ def __init__( def init_context(self) -> Contexts: return {"text_encoder_pooling": {"end_of_text_index": []}} - def __call__(self, text: str) -> tuple[Float[Tensor, "batch 77 1280"], Float[Tensor, "batch 1280"]]: + def __call__(self, text: str | list[str]) -> tuple[Float[Tensor, "batch 77 1280"], Float[Tensor, "batch 1280"]]: return super().__call__(text) @property diff --git a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py index 0b3c7feee..4563c1531 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py @@ -103,7 +103,7 @@ def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: @no_grad() -def test_double_text_encoder_batch2(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None: +def test_double_text_encoder_batch2(double_text_encoder: DoubleTextEncoder) -> None: manual_seed(seed=0) prompt1 = "A photo of a pizza." prompt2 = "A giant duck."