
Fix Value dimension in ImageCrossAttention #188

Merged · 1 commit into main on Jan 17, 2024

Conversation

hugojarkoff (Contributor)

This is a minor issue, but while reading the codebase I noticed that the ImageCrossAttention uses the wrong in_features dimension in the second fl.Linear layer:

class ImageCrossAttention(fl.Chain):
    def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
            ...      
                    # This first Linear corresponds to the keys K' 
                    fl.Linear(
                      ...
                    ),
                    ...
                    # This second Linear corresponds to the values V'
                    fl.Linear(
                        in_features=text_cross_attention.key_embedding_dim,
                        out_features=text_cross_attention.inner_dim,
                        bias=text_cross_attention.use_bias,
                        device=text_cross_attention.device,
                        dtype=text_cross_attention.dtype,
                    ),
                ),
            ...

Should (IIUC) be changed to:

class ImageCrossAttention(fl.Chain):
    def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
            ...      
                    # This first Linear corresponds to the keys K' 
                    fl.Linear(
                      ...
                    ),
                    ...
                    # This second Linear corresponds to the values V'
                    fl.Linear(
                        in_features=text_cross_attention.value_embedding_dim,
                        out_features=text_cross_attention.inner_dim,
                        bias=text_cross_attention.use_bias,
                        device=text_cross_attention.device,
                        dtype=text_cross_attention.dtype,
                    ),
                ),
            ...

In practice, and IINM, this shouldn't change anything in the context of Image Cross-Attention, since the key and value embedding dims are the same there.
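
For context, here is a minimal sketch in plain PyTorch (not the refiners fl API; all dimensions below are made up) of why the value projection's in_features must match the value embedding dim, and why the old code still ran: in image cross-attention the keys and values are projected from the same image embedding, so the two dims coincide.

import torch
import torch.nn as nn

# Hypothetical dimensions, mirroring the parameters referenced in the PR.
inner_dim = 8            # heads * head_dim of the attention layer
key_embedding_dim = 4    # dim of the embedding the keys are projected from
value_embedding_dim = 4  # same source as the keys in image cross-attention

to_k = nn.Linear(key_embedding_dim, inner_dim, bias=False)
to_v = nn.Linear(value_embedding_dim, inner_dim, bias=False)  # fixed: value dim

# Keys and values both come from the image embedding, so using
# key_embedding_dim for to_v happened to give the same shape.
image_embedding = torch.randn(1, 3, value_embedding_dim)
k = to_k(image_embedding)
v = to_v(image_embedding)  # would only fail at runtime if the two dims differed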

@hugojarkoff hugojarkoff requested a review from deltheil January 17, 2024 14:57
@limiteinductive limiteinductive self-requested a review January 17, 2024 15:40
@limiteinductive limiteinductive merged commit a6a9c8b into main Jan 17, 2024
1 check passed
@limiteinductive limiteinductive deleted the pr/fix-mistake-image-cross-attention branch January 17, 2024 15:46