remove deepspeed, some fixes, and llava
Anas Awadalla committed Feb 22, 2024
1 parent eb6b8aa commit 1e75320
Showing 22 changed files with 203 additions and 392 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -102,7 +102,7 @@ To instantiate an OpenFlamingo model with one of our released weights, initializ
from huggingface_hub import hf_hub_download
import torch

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")
checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-4B-vitl-rpj3b", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)
```
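
The snippet above loads released weights into an already-constructed `model`. For context, a minimal sketch of how that model might be built with `create_model_and_transforms`; only `clip_vision_encoder_path`, `model_family`, and the return values are visible in this diff, so the remaining keyword names and values are assumptions — consult the full README for the exact call.

```python
# Hedged sketch: keyword names below are assumptions not confirmed by this diff.
from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",                                # vision backbone (assumed value)
    clip_vision_encoder_pretrained="openai",                            # assumed pretrained tag
    lang_model_path="togethercomputer/RedPajama-INCITE-Base-3B-v1",     # assumed LM for the 4B rpj3b checkpoint
    tokenizer_path="togethercomputer/RedPajama-INCITE-Base-3B-v1",      # assumed tokenizer path
    model_family="flamingo",                                            # one of SUPPORTED_MODEL_FAMILIES
)
```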

1 change: 1 addition & 0 deletions open_flamingo/__init__.py
@@ -1,4 +1,5 @@
from .src.flamingo import Flamingo
from .src.kosmos import Kosmos
from .src.blip import BLIP
from .src.llava import Llava
from .src.factory import create_model_and_transforms, SUPPORTED_MODEL_FAMILIES
2 changes: 0 additions & 2 deletions open_flamingo/eval/README.md
@@ -30,8 +30,6 @@ To help standardize VLM evaluations, we have implemented EvalModel wrappers for
## Distributed evaluation
Our codebase uses DistributedDataParallel to parallelize evaluation by default, so please make sure to set the `MASTER_ADDR` and `MASTER_PORT` environment variables or use `torchrun` (see sample scripts section below).

We have also implemented distributed evaluation using Deepspeed, which additionally shards model parameters across GPUs for memory savings. To use Deepspeed instead of DDP, use the `--deepspeed` flag.

We also support evaluating at a lower precision using the `--precision` flag. We find minimal difference between evaluating at full precision vs. amp_bf16.
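
For reference, a minimal sketch (not repo code) of what the `MASTER_ADDR` / `MASTER_PORT` variables feed into when launching without `torchrun`; the evaluation entrypoint handles the equivalent setup itself before each EvalModel wraps its model in DDP via `init_distributed()`.

```python
# Illustrative sketch only; evaluate.py performs process-group setup internally.
import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# rank and world_size normally come from the launcher (e.g. torchrun);
# "gloo" lets this sketch run on CPU -- GPU evaluation would typically use "nccl".
dist.init_process_group(backend="gloo", rank=0, world_size=1)
```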

## Sample scripts
33 changes: 4 additions & 29 deletions open_flamingo/eval/eval_models/eval_model.py
@@ -31,7 +31,7 @@ def get_eval_model(name, *args, **kwargs):
class BaseEvalModel(abc.ABC):
"""Base class encapsulating functionality needed to evaluate a model."""

def __init__(self, model_args: List[str], init_on_device=False):
def __init__(self, model_args: List[str]):
"""Initialize model.
Args:
@@ -59,17 +59,6 @@ def __init__(self, model_args: List[str], init_on_device=False):
self.autocast = get_autocast(self.precision)
self.cast_dtype = get_cast_dtype(self.precision)

# initialization context
if init_on_device:
# for deepspeed, must init on device, or likely CPU OOM
import deepspeed

self.init_ctx = deepspeed.OnDevice(
dtype=self.cast_dtype, device=self.device
)
else:
self.init_ctx = suppress()

@property
def required_args(self):
"""Return list of required arguments to initialize model."""
@@ -83,23 +72,9 @@ def _check_init(self):
assert hasattr(self, "tokenizer"), "Tokenizer has not been initialized"
self.tokenizer.padding_side = "left"

def init_distributed(self, world_size=None, use_deepspeed=False):
"""Wrap model as DDP or deepspeed."""
if use_deepspeed:
assert "amp" not in self.precision, "Deepspeed does not support amp"
import deepspeed

self.ds_engine = deepspeed.init_inference(
self.model,
mp_size=world_size,
dtype=self.cast_dtype,
checkpoint=None,
replace_with_kernel_inject=True,
)
self.model = self.ds_engine.module
self.autocast = get_autocast(None)
else:
self.model = DDP(self.model, device_ids=[self.device])
def init_distributed(self):
"""Wrap model as DDP."""
self.model = DDP(self.model, device_ids=[self.device])

def __call__(
self,
10 changes: 1 addition & 9 deletions open_flamingo/eval/evaluate.py
@@ -394,12 +394,6 @@
action="store_true",
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
)
parser.add_argument(
"--deepspeed",
default=False,
action="store_true",
help="Whether to use deepspeed for distributed inference.",
)


def main():
@@ -414,11 +408,9 @@ def main():
model_args["device"] = device_id

# initialize model
eval_model = get_eval_model(args.model, model_args, init_on_device=args.deepspeed)
eval_model = get_eval_model(args.model, model_args, init_on_device=False)
eval_model.init_distributed(
local_rank=args.local_rank,
world_size=args.world_size,
use_deepspeed=args.deepspeed,
)

# Validate args
77 changes: 0 additions & 77 deletions open_flamingo/scripts/run_eval_deepspeed.sh

This file was deleted.

41 changes: 0 additions & 41 deletions open_flamingo/scripts/run_train_deepspeed.sh

This file was deleted.

1 change: 1 addition & 0 deletions open_flamingo/src/__init__.py
@@ -0,0 +1 @@
from .helpers import VLMOutputWithPast
21 changes: 17 additions & 4 deletions open_flamingo/src/cross_attn_lm.py
@@ -88,15 +88,16 @@ def init_cross_attention_layers(
"""
Add gated cross attn layers to the decoder.
"""
self.old_decoder_blocks = self._get_decoder_layers()
old_decoder_blocks = self._get_decoder_layers()
self.decoder_block_class = old_decoder_blocks[0].__class__
self.gated_cross_attn_layers = nn.ModuleList(
[
GatedCrossAttentionBlock(
dim=lang_hidden_size, dim_visual=vis_hidden_size
)
if (layer_idx + 1) % cross_attn_every_n_layers == 0
else None
for layer_idx, _ in enumerate(self._get_decoder_layers())
for layer_idx, _ in enumerate(old_decoder_blocks)
]
)
self._set_decoder_layers(
@@ -106,7 +107,7 @@ def init_cross_attention_layers(
gated_cross_attn_layer, decoder_layer, gradient_checkpointing
)
for gated_cross_attn_layer, decoder_layer in zip(
self.gated_cross_attn_layers, self.old_decoder_blocks
self.gated_cross_attn_layers, old_decoder_blocks
)
]
)
@@ -119,11 +120,14 @@ def _condition_media_before_forward(
vision_tokens: torch.Tensor = None,
past_media_locations: torch.Tensor = None,
past_vision_tokens: torch.Tensor = None,
num_beams: int = 1,
):
"""Each xattn layer needs to save the vision tokens and the locations of the media tokens in the language sequence"""
assert (
self.initialized_cross_attention
), "Cross attention layers have not been initialized. "

# concat with past
if past_media_locations is not None and past_vision_tokens is not None:
if vision_tokens is not None:
updated_vision_tokens = torch.cat(
@@ -146,6 +150,15 @@ def _condition_media_before_forward(
updated_vision_tokens = vision_tokens
updated_media_locations = input_ids == self.media_token_id

# repeat the vision tokens and media locations for each beam
updated_vision_tokens = updated_vision_tokens.repeat_interleave(
num_beams, dim=0
)
updated_media_locations = updated_media_locations.repeat_interleave(
num_beams, dim=0
)

# condition
for layer in self._get_decoder_layers():
layer.condition_vis_x(updated_vision_tokens)
layer.condition_media_locations(updated_media_locations)
@@ -157,4 +170,4 @@ def is_conditioned(self) -> bool:
def clear_conditioned_layers(self):
for layer in self._get_decoder_layers():
layer.condition_vis_x(None)
layer.condition_media_locations(None)
layer.condition_media_locations(None)
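
The new `num_beams` argument above repeats the cached vision tokens and media locations once per beam so they stay aligned with the text inputs that `generate()` expands during beam search. A standalone illustration of the `repeat_interleave` behaviour (toy shapes, not repo code):

```python
import torch

vision_tokens = torch.randn(2, 3, 8)  # (batch, num_media_tokens, dim) -- toy shapes
num_beams = 4

# each batch element is duplicated num_beams times along dim 0,
# giving an effective batch of batch_size * num_beams
expanded = vision_tokens.repeat_interleave(num_beams, dim=0)
print(expanded.shape)  # torch.Size([8, 3, 8])
```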
63 changes: 21 additions & 42 deletions open_flamingo/src/factory.py
@@ -1,16 +1,21 @@
from typing import Optional
import torch.nn as nn

from transformers import AutoModelForCausalLM, AutoTokenizer
import open_clip

from .flamingo import Flamingo
from .kosmos import Kosmos
from .blip import BLIP
from .llava import Llava
from .utils import hasattr_recursive, setattr_recursive

SUPPORTED_MODEL_FAMILIES = ("flamingo", "kosmos", "blip")

SUPPORTED_MODEL_FAMILIES = ("flamingo", "kosmos", "blip", "llava")
MODEL_FAMILY_TO_CLASS = {
"flamingo": Flamingo,
"kosmos": Kosmos,
"blip": BLIP,
"llava": Llava,
}

def create_model_and_transforms(
clip_vision_encoder_path: str,
@@ -83,41 +88,16 @@ def create_model_and_transforms(
if decoder_layers_attr_name is None:
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_model)

if model_family == "flamingo":
model = Flamingo(
vision_encoder=vision_encoder,
lang_model=lang_model,
vis_feature_dim=vis_hidden_dim,
initial_tokenizer_len=len(text_tokenizer),
gradient_checkpointing=gradient_checkpointing,
decoder_layers_attr_name=decoder_layers_attr_name,
pad_token_id=text_tokenizer.pad_token_id,
**model_kwargs,
)

elif model_family == "kosmos":
model = Kosmos(
vision_encoder=vision_encoder,
lang_model=lang_model,
vis_feature_dim=vis_hidden_dim,
initial_tokenizer_len=len(text_tokenizer),
gradient_checkpointing=gradient_checkpointing,
pad_token_id=text_tokenizer.pad_token_id,
decoder_layers_attr_name=decoder_layers_attr_name,
**model_kwargs,
)

elif model_family == "blip":
model = BLIP(
vision_encoder=vision_encoder,
lang_model=lang_model,
vis_feature_dim=vis_hidden_dim,
initial_tokenizer_len=len(text_tokenizer),
gradient_checkpointing=gradient_checkpointing,
pad_token_id=text_tokenizer.pad_token_id,
decoder_layers_attr_name=decoder_layers_attr_name,
**model_kwargs,
)
model = MODEL_FAMILY_TO_CLASS[model_family](
vision_encoder=vision_encoder,
lang_model=lang_model,
vis_feature_dim=vis_hidden_dim,
initial_tokenizer_len=len(text_tokenizer),
gradient_checkpointing=gradient_checkpointing,
decoder_layers_attr_name=decoder_layers_attr_name,
pad_token_id=text_tokenizer.pad_token_id,
**model_kwargs,
)

# add special tokens to the tokenizer and language models
text_tokenizer.add_special_tokens(
@@ -130,7 +110,6 @@
for v in model.special_tokens.values()
}
)

# freeze appropriate parameters
model.set_trainable()

@@ -139,8 +118,8 @@
print(
f"{model_family} model initialized with {model.num_trainable_params:,} trainable parameters"
)
print(f"========== Trainable Parameters\n{model.num_trainable_params_per_module}")
print(f"========== Total Parameters\n{model.num_params_per_module}\n==========")
print(f"==========Trainable Parameters\n{model.num_trainable_params_per_module}")
print(f"==========Total Parameters\n{model.num_params_per_module}\n==========")
return model, image_processor, text_tokenizer


@@ -220,4 +199,4 @@ def has_fn(model, fn_name):
getattr(model, fn_name)()
return True
except:
return False
return False
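
The per-family `if`/`elif` branches above are replaced by the `MODEL_FAMILY_TO_CLASS` dispatch table, so adding a family (here, `llava`) only needs a new import and one dict entry. A small sketch of validating a family name against the exported tuple; the `check_family` helper is hypothetical and added only for illustration:

```python
from open_flamingo import SUPPORTED_MODEL_FAMILIES  # ("flamingo", "kosmos", "blip", "llava")

def check_family(model_family: str) -> None:
    # hypothetical helper for illustration; not part of the repo
    if model_family not in SUPPORTED_MODEL_FAMILIES:
        raise ValueError(f"unknown model family: {model_family!r}")

check_family("llava")  # accepted after this commit
```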