Skip to content

Commit

Permalink
compact code following @limiteinductive suggestion https://github.com…
Browse files — browse the repository at this point in the history
  • Loading branch information
piercus committed Feb 15, 2024
1 parent 2a03655 commit 298fb35
Showing 1 changed file with 4 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,11 @@ def compute_clip_text_embedding(self, text: str | list[str], negative_text: str
negative_text: The negative prompt to compute the CLIP text embedding of.
If not provided, the negative prompt is assumed to be empty (i.e., `""`).
"""
text = [text] if isinstance(text, str) else text
negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
conditional_embedding = self.clip_text_encoder(text)
if text == negative_text:
negative_embedding = conditional_embedding
else:
if isinstance(text, list) and isinstance(negative_text, list):
assert len(text) == len(
negative_text
), "The length of the text list and negative_text should be the same"

if isinstance(negative_text, str) and isinstance(text, list):
negative_text = [negative_text] * len(text)

negative_embedding = self.clip_text_encoder(negative_text)

negative_embedding = self.clip_text_encoder(negative_text) if negative_text else conditional_embedding
return cat((negative_embedding, conditional_embedding))

def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
Expand Down

0 comments on commit 298fb35

Please sign in to comment.