KeyError: 'image_adapter.pos_embed' #48
Comments
@xxxxyliu Same situation for me - let me know if you figured out the solution.
@xxxxyliu Here we go:

    def resize_abs_pos_embed(self, checkpoint):
        # Check for the correct key in the checkpoint
        pos_embed_key = 'backbone.image_adapter.pos_embed' if 'backbone.image_adapter.pos_embed' in checkpoint else 'image_adapter.pos_embed'
        pos_embed_checkpoint = checkpoint[pos_embed_key]
        embedding_size = pos_embed_checkpoint.shape[-1]
        bucket_size = self.image_adapter.bucket_size
        num_patches = bucket_size ** 2
        num_extra_tokens = self.image_adapter.pos_embed.shape[-2] - num_patches
        # Calculate original and new sizes for position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(num_patches ** 0.5)
        # Keep class_token and dist_token unchanged
        rank, _ = get_dist_info()
        if orig_size != new_size:
            if rank == 0:
                print(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size}")
            extra_tokens = pos_embed_checkpoint[:num_extra_tokens]
            # Only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
            checkpoint[pos_embed_key] = new_pos_embed
        return checkpoint
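If neither of the two keys above is present, the lookup will still raise the same KeyError, so it may help to first check which prefix the checkpoint actually uses. Below is a minimal sketch of such a check. The helper name `find_pos_embed_key` is hypothetical and not part of this repository, and plain dicts stand in for a loaded state dict. It assumes the weights sit at the top level of the checkpoint; if they are nested under another key, that inner dict would need to be passed in instead.

```python
def find_pos_embed_key(checkpoint, suffix="image_adapter.pos_embed"):
    """Return the single checkpoint key that equals or ends with `suffix`.

    Hypothetical helper for illustration; not part of the project's code.
    """
    matches = [k for k in checkpoint if k == suffix or k.endswith("." + suffix)]
    if len(matches) == 1:
        return matches[0]
    if not matches:
        # List related keys so the prefix mismatch is easier to spot.
        hints = sorted(k for k in checkpoint if "pos_embed" in k)
        raise KeyError(
            f"no key ending in {suffix!r}; keys containing 'pos_embed': {hints}"
        )
    raise KeyError(f"ambiguous keys for {suffix!r}: {sorted(matches)}")


# Plain dicts stand in for a loaded state dict (assumption for the demo).
plain = {"image_adapter.pos_embed": "tensor-a"}
prefixed = {"backbone.image_adapter.pos_embed": "tensor-b", "backbone.head.weight": "w"}

print(find_pos_embed_key(plain))     # image_adapter.pos_embed
print(find_pos_embed_key(prefixed))  # backbone.image_adapter.pos_embed
```

The returned key could then replace the hard-coded `pos_embed_key` choice in the function above, and the error message shows which `pos_embed` keys exist when nothing matches.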
I haven't changed the model structure, but I'm encountering an error when using pre-trained weights.