Skip to content

Commit

Permalink
train_utils media token fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Oscar Lo committed Dec 2, 2023
1 parent c5feb97 commit dbb1ad8
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion open_flamingo/train/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from data_utils import DataInfo
import random
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn as nn


def train_one_epoch(
Expand Down Expand Up @@ -77,8 +79,11 @@ def train_one_epoch(
batch_metadata_to_log[
f"{datasets[dataset_ix].name}_num_tokens"
] = attention_mask.sum().item()
model = unwrap_model(model)
model.media_token_id = 400
model = DDP(model)
batch_metadata_to_log[f"{datasets[dataset_ix].name}_num_images"] = (
(input_ids == model.media_token_id).sum().item()
(input_ids == model.module.media_token_id).sum().item()
)

# forward pass
Expand Down Expand Up @@ -188,6 +193,16 @@ def random_seed(seed=42, rank=0):
random.seed(seed + rank)


def unwrap_model(model):
    """
    Return the bare module underneath a parallel wrapper.

    If *model* is an ``nn.DataParallel`` or
    ``nn.parallel.DistributedDataParallel`` container, its ``.module``
    attribute is returned; otherwise *model* is returned unchanged.
    """
    parallel_wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    is_wrapped = isinstance(model, parallel_wrappers)
    return model.module if is_wrapped else model


################################
# Helper functions for logging #
################################
Expand Down

0 comments on commit dbb1ad8

Please sign in to comment.