
KeyError: 'image_adapter.pos_embed' #48

Open
Xy-unu opened this issue Jan 29, 2024 · 2 comments

Xy-unu commented Jan 29, 2024

```
2024-01-28 15:51:36,396 - mmseg - INFO - load checkpoint from local path: /data/onepeace_seg_cocostuff2ade20k.pth
Traceback (most recent call last):
  File "/root/ONE-PEACE/one_peace_vision/seg/train.py", line 243, in <module>
    main()
  File "/root/ONE-PEACE/one_peace_vision/seg/train.py", line 203, in main
    model.init_weights()
  File "/opt/conda/envs/onepeace/lib/python3.8/site-packages/mmcv/runner/base_module.py", line 116, in init_weights
    m.init_weights()
  File "/root/ONE-PEACE/one_peace_vision/seg/mmseg_custom/models/backbones/onepeace.py", line 571, in init_weights
    state_dict = self.resize_abs_pos_embed(model)
  File "/root/ONE-PEACE/one_peace_vision/seg/mmseg_custom/models/backbones/onepeace.py", line 467, in resize_abs_pos_embed
    pos_embed_checkpoint = checkpoint['image_adapter.pos_embed']
KeyError: 'image_adapter.pos_embed'
```

I haven't changed the model structure, but this error occurs as soon as `init_weights` loads the pre-trained weights.
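To check which key the checkpoint actually contains before patching anything, a small helper like the one below can list the candidate keys. This is a sketch: the name `find_pos_embed_keys` is hypothetical, and it assumes the file loads as a dict that is either a flat state dict or wraps one under a `'state_dict'` or `'model'` entry (common PyTorch/mmcv conventions, not confirmed for this particular file).

```python
def find_pos_embed_keys(checkpoint):
    """Return the sorted keys that end in 'image_adapter.pos_embed'.

    Hypothetical helper. Assumes `checkpoint` is a flat state dict or
    wraps one under 'state_dict' or 'model'.
    """
    state_dict = checkpoint
    # Unwrap the first known wrapper entry, if any.
    for wrapper in ("state_dict", "model"):
        if isinstance(state_dict, dict) and wrapper in state_dict:
            state_dict = state_dict[wrapper]
            break
    return sorted(k for k in state_dict if k.endswith("image_adapter.pos_embed"))


# Usage (requires torch; the path is taken from the log above):
#   import torch
#   ckpt = torch.load("/data/onepeace_seg_cocostuff2ade20k.pth", map_location="cpu")
#   print(find_pos_embed_keys(ckpt))
```

If this prints a prefixed key such as `backbone.image_adapter.pos_embed`, the hard-coded lookup in `resize_abs_pos_embed` cannot succeed without a change.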

logicwong self-assigned this Feb 5, 2024
AndrewTKent commented:

@xxxxyliu Same situation for me. Let me know if you figured out a solution.

logicwong assigned simonJJJ and unassigned logicwong Apr 8, 2024

AndrewTKent commented Sep 20, 2024

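@xxxxyliu Here is a workaround. The checkpoint appears to store the position embedding under `backbone.image_adapter.pos_embed` rather than `image_adapter.pos_embed`, so patch `resize_abs_pos_embed` in `mmseg_custom/models/backbones/onepeace.py` to accept either key: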

```python
# Needed at the top of onepeace.py (may already be imported there)
import torch
from mmcv.runner import get_dist_info  # mmcv 1.x location, as in the traceback


def resize_abs_pos_embed(self, checkpoint):
    # A full segmentor checkpoint prefixes backbone keys with 'backbone.';
    # fall back to the unprefixed key otherwise.
    pos_embed_key = ('backbone.image_adapter.pos_embed'
                     if 'backbone.image_adapter.pos_embed' in checkpoint
                     else 'image_adapter.pos_embed')

    # Assumes pos_embed is 2-D: (num_tokens, embedding_size)
    pos_embed_checkpoint = checkpoint[pos_embed_key]
    embedding_size = pos_embed_checkpoint.shape[-1]
    bucket_size = self.image_adapter.bucket_size
    num_patches = bucket_size ** 2
    # Tokens that are not grid positions (e.g. class/dist tokens)
    num_extra_tokens = self.image_adapter.pos_embed.shape[-2] - num_patches

    # Side lengths of the original and new square position grids
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)

    rank, _ = get_dist_info()

    if orig_size != new_size:
        if rank == 0:
            print(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size}")

        # Keep class_token and dist_token unchanged
        extra_tokens = pos_embed_checkpoint[:num_extra_tokens]

        # Only the position tokens are interpolated:
        # (orig*orig, D) -> (1, D, orig, orig) -> (1, D, new, new) -> (new*new, D)
        pos_tokens = pos_embed_checkpoint[num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)

        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
        checkpoint[pos_embed_key] = new_pos_embed

    return checkpoint
```
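The patch in this comment makes the `pos_embed` lookup succeed, but the `backbone.` prefix suggests the file is a full segmentor checkpoint in which every backbone weight carries that prefix. If so, the other backbone keys would not match the backbone's parameter names either. One alternative, sketched here as an assumption rather than the maintainers' fix, is to strip the prefix from the loaded checkpoint dict before it reaches `resize_abs_pos_embed`, so the original unpatched lookup of `'image_adapter.pos_embed'` works. The helper name `strip_prefix` and the exact call site in `init_weights` are hypothetical.

```python
def strip_prefix(state_dict, prefix="backbone."):
    """Hypothetical helper: keep only keys starting with `prefix`, with the
    prefix removed (e.g. 'backbone.image_adapter.pos_embed' becomes
    'image_adapter.pos_embed'). Keys without the prefix, such as a
    segmentation head's 'decode_head.*' weights, are dropped.
    If no key carries the prefix, an unchanged copy is returned.
    """
    if not any(k.startswith(prefix) for k in state_dict):
        return dict(state_dict)
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
```

For example, `checkpoint = strip_prefix(checkpoint)` just before the call to `resize_abs_pos_embed`. Dropping the head weights is intended here, since only the backbone is being initialized.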
