Refactor utils #1 (#161)
* remove unused

* add default layer_norm

* remove unused

* remove llava_bunny and idefics2 custom configs

* refactor molmo and qwen2 config

* add deprecation warning

* refactor update model configs

* refactor sanitize weights

* refactor class_predicate

* move custom config logic to from_dict

* uncomment

* fix config name

* rename aligner to projector

* fix tests

* remove module from update list

* add trust_remote_code as kwarg

* update BaseImageProcessor

* refactor image processor

* pin latest transformers

* bump version

* refactor prepare inputs

* simplify image loading

* fix load_image and refactor load_config

* make skip_non_divisible a default

* default to skipping non-divisible layers and rename model inputs

* refactor condition

* fix language input only

* add fetch KV

* Increase default max tokens to 256

* refactor generate, generate step and stream

* fix high usage and add language only support (#163)
Blaizzy authored Dec 30, 2024
1 parent 050a6d7 commit 78920b0
Showing 19 changed files with 325 additions and 490 deletions.
8 changes: 1 addition & 7 deletions mlx_vlm/convert.py
@@ -51,13 +51,7 @@ def configure_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--skip-vision",
help="Skip vision quantization.",
action="store_true",
default=False,
)
parser.add_argument(
"--skip-vision-non-divisible",
help="Skip layers that are not divisible by 64 in vision encoder.",
help="Skip vision module quantization.",
action="store_true",
default=False,
)
22 changes: 9 additions & 13 deletions mlx_vlm/generate.py
@@ -7,7 +7,7 @@
DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit"
DEFAULT_IMAGE = []
DEFAULT_PROMPT = "What are these?"
DEFAULT_MAX_TOKENS = 100
DEFAULT_MAX_TOKENS = 256
DEFAULT_TEMP = 0.5
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
@@ -64,22 +64,19 @@ def parse_arguments():

def get_model_and_processors(model_path, adapter_path):
model_path = get_model_path(model_path)
config = load_config(model_path)
config = load_config(model_path, trust_remote_code=True)
model, processor = load(
model_path, {"trust_remote_code": True}, adapter_path=adapter_path
model_path, adapter_path=adapter_path, lazy=False, trust_remote_code=True
)
image_processor = load_image_processor(model_path)
return model, processor, image_processor, config
return model, processor, config


def main():
args = parse_arguments()
if isinstance(args.image, str):
args.image = [args.image]

model, processor, image_processor, config = get_model_and_processors(
args.model, args.adapter_path
)
model, processor, config = get_model_and_processors(args.model, args.adapter_path)

prompt = codecs.decode(args.prompt, "unicode_escape")

@@ -95,12 +92,11 @@ def main():
output = generate(
model,
processor,
args.image,
prompt,
image_processor,
args.temp,
args.max_tokens,
args.verbose,
image=args.image,
temp=args.temp,
max_tokens=args.max_tokens,
verbose=args.verbose,
**kwargs,
)
if not args.verbose:
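
After this refactor, get_model_and_processors no longer returns a separate image_processor, and generate takes the prompt positionally with images and sampling options as keyword arguments. A minimal sketch of the new call style, assuming the package-level load/generate helpers and load_config from mlx_vlm.utils; the model path, image URL, and prompt are placeholders, and the CLI additionally formats the prompt with the model's chat template before this call:

from mlx_vlm import generate, load
from mlx_vlm.utils import load_config

model_path = "mlx-community/nanoLLaVA-1.5-8bit"   # placeholder model
model, processor = load(model_path, trust_remote_code=True)
config = load_config(model_path, trust_remote_code=True)

output = generate(
    model,
    processor,
    "What are these?",                       # prompt (positional)
    image=["http://example.com/cats.jpg"],   # placeholder image URL
    temp=0.5,
    max_tokens=256,                          # new default
    verbose=True,
)
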
12 changes: 11 additions & 1 deletion mlx_vlm/models/base.py
@@ -4,6 +4,7 @@

import mlx.core as mx
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor as ImageProcessor
from transformers.image_processing_utils import get_size_dict
from transformers.image_utils import ChannelDimension, PILImageResampling

@@ -22,7 +23,7 @@ def expand2square(pil_img, background_color):
return result


class BaseImageProcessor(ABC):
class BaseImageProcessor(ImageProcessor):
def __init__(
self,
image_mean=(0.5, 0.5, 0.5),
@@ -72,6 +73,9 @@ def update_and_fetch(self, keys, values):
self.update(keys, values)
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

def fetch(self):
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

def update(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
@@ -127,6 +131,9 @@ def update_and_fetch(self, keys, values):
self.cache_length += keys.shape[2]
return self.keys, self.values

def fetch(self):
return self.keys, self.values

def update(self, keys, values):
"""Update cache with new key/value tensors without returning.
@@ -167,6 +174,9 @@ def _trim(self, trim_size, v, append=None):
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)

def fetch(self):
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

def update_and_fetch(self, keys, values):
prev = self.offset
B, _, S = keys.shape[:3]
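
The fetch() method added to each cache class returns the keys/values stored so far without appending anything, so layers whose key/value tensors never change (for example cross-attention over fixed vision features) can reuse the prefill result on every decode step instead of recomputing it. A rough usage sketch, assuming the KVCache class defined in this module; the constructor arguments and shapes here are illustrative only and may differ from the actual signature:

import mlx.core as mx
from mlx_vlm.models.base import KVCache

cache = KVCache(head_dim=64, n_kv_heads=8)   # illustrative constructor args

# Prefill: write the keys/values once.
keys = mx.random.normal((1, 8, 10, 64))
values = mx.random.normal((1, 8, 10, 64))
cache.update_and_fetch(keys, values)

# Later decode steps: re-read the cached tensors without growing the cache.
cached_keys, cached_values = cache.fetch()
print(cached_keys.shape)   # (1, 8, 10, 64)
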
11 changes: 10 additions & 1 deletion mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py
@@ -62,6 +62,10 @@ class ModelConfig:

@classmethod
def from_dict(cls, params):
if "language_config" in params:
params["text_config"] = params["language_config"]
del params["language_config"]

return cls(
**{
k: v
@@ -224,7 +228,12 @@ def __init__(self, config: ModelConfig):
f"len(candidate_resolutions) should be larger than 0, but got {len(candidate_resolutions)}"
)
tile_variants_num = len(candidate_resolutions)
# self.tile_indicators = mx.array(mx.random(size=(tile_variants_num + 1, config.aligner.params.n_embed)) * embed_std)
self.tile_indicators = mx.array(
mx.random.normal(
(tile_variants_num + 1, config.projector_config.n_embed)
)
* embed_std
)
else:
raise ValueError(
f"tile tag should be either 1D or 2D, but got {self.tile_tag}"
2 changes: 1 addition & 1 deletion mlx_vlm/models/idefics2/idefics2.py
@@ -212,7 +212,7 @@ def get_input_embeddings(
pixel_attention_mask: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.embed_tokens(input_ids)

inputs_embeds = self.language_model.embed_tokens(input_ids)

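
This change (mirrored in the other models below) is what makes language-only inputs work: when no pixel_values are given, get_input_embeddings now returns the token embeddings directly instead of calling the whole language model. A short sketch of that path, assuming model is an already-loaded Idefics2 Model instance and using placeholder token ids:

import mlx.core as mx

input_ids = mx.array([[1, 2, 3, 4]])   # placeholder token ids
embeds = model.get_input_embeddings(input_ids, pixel_values=None)
# embeds has shape (batch, seq_len, hidden_size); the model's __call__ then
# passes it to the language model, so text-only prompts no longer fail.
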
4 changes: 2 additions & 2 deletions mlx_vlm/models/idefics2/vision.py
@@ -16,7 +16,7 @@ class VisionConfig:
num_attention_heads: int
image_size: int
patch_size: int
layer_norm_eps: float
layer_norm_eps: float = 1e-6
num_channels: int = 3

@classmethod
@@ -217,7 +217,7 @@ def __init__(self, config: VisionConfig):
super().__init__()
self.config = config
self.model_type = config.model_type
if self.model_type != "idefics2":
if self.model_type not in ["idefics2", "idefics2_vision"]:
raise ValueError(f"Unsupported model type: {self.model_type}")
self.embeddings = VisionEmbeddings(config)
self.encoder = Encoder(config)
2 changes: 1 addition & 1 deletion mlx_vlm/models/idefics3/idefics3.py
@@ -98,7 +98,7 @@ def get_input_embeddings(
pixel_attention_mask: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.embed_tokens(input_ids)

inputs_embeds = self.language_model.embed_tokens(input_ids)

2 changes: 1 addition & 1 deletion mlx_vlm/models/llava/llava.py
@@ -70,7 +70,7 @@ def get_input_embeddings(
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.model.embed_tokens(input_ids)

# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
3 changes: 1 addition & 2 deletions mlx_vlm/models/llava_bunny/llava_bunny.py
@@ -35,7 +35,6 @@ class ModelConfig:
auto_map: dict
hidden_size: int
mm_hidden_size: int
mm_vision_tower: str
mm_projector_type: str = "mlp2x_gelu"
ignore_index: int = -100
image_token_index: int = -200
@@ -132,7 +131,7 @@ def get_input_embeddings(
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.model.embed_tokens(input_ids)

inputs_embeds = self.language_model.model.embed_tokens(input_ids)

2 changes: 1 addition & 1 deletion mlx_vlm/models/llava_next/llava_next.py
@@ -75,7 +75,7 @@ def get_input_embeddings(
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.model.embed_tokens(input_ids)

# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
19 changes: 6 additions & 13 deletions mlx_vlm/models/mllama/language.py
@@ -80,12 +80,12 @@ def __call__(
) -> mx.array:

bsz, q_len, _ = hidden_states.shape
query_states = (
query = (
self.q_proj(hidden_states)
.reshape(bsz, q_len, self.num_heads, self.head_dim)
.transpose(0, 2, 1, 3)
)
query_states = self.q_norm(query_states)
query_states = self.q_norm(query)

if cross_attention_states is not None:
key_states = (
@@ -99,14 +99,11 @@ def __call__(
.transpose(0, 2, 1, 3)
)
key_states = self.k_norm(key_states)
if cache is not None:
key_states, value_states = cache.update_and_fetch(
key_states, value_states
)
elif cache is not None and cache.offset > 0:
key_states, value_states = cache.fetch()
else:
raise ValueError(
"Cross attention states must be provided for cross attention layer."
)
key_states, value_states = mx.split(query, 2, axis=1)
key_states = self.k_norm(key_states)

attn_output = mx.fast.scaled_dot_product_attention(
query_states,
@@ -338,10 +335,6 @@ def __call__(

for idx, (decoder_layer, c) in enumerate(zip(self.layers, cache)):
if idx in self.config.cross_attention_layers:
if cross_attention_states is None:
raise ValueError(
f"Cross attention states must be provided for layer {idx}"
)
layer_outputs = decoder_layer(
hidden_states,
cross_attention_states=cross_attention_states,
2 changes: 1 addition & 1 deletion mlx_vlm/models/multi_modality/__init__.py
@@ -1,9 +1,9 @@
from .multi_modality import (
AlignerConfig,
ImageProcessor,
LanguageModel,
Model,
ModelConfig,
ProjectorConfig,
TextConfig,
VisionConfig,
VisionModel,
26 changes: 15 additions & 11 deletions mlx_vlm/models/multi_modality/multi_modality.py
@@ -10,16 +10,16 @@
import numpy as np
from huggingface_hub import snapshot_download
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_processing_utils import BatchFeature
from transformers.image_utils import to_numpy_array

from ..base import expand2square
from ..base import BaseImageProcessor, expand2square
from .language import LanguageModel, TextConfig
from .vision import VisionConfig, VisionModel


@dataclass
class AlignerConfig:
class ProjectorConfig:
cls: str
model_type: str
params: dict
@@ -39,7 +39,7 @@ def from_dict(cls, params):
class ModelConfig:
text_config: TextConfig
vision_config: VisionConfig
aligner_config: AlignerConfig
projector_config: ProjectorConfig
model_type: str
ignore_index: int = -100
image_token_index: int = 100015
@@ -51,6 +51,10 @@ class ModelConfig:

@classmethod
def from_dict(cls, params):
if "aligner_config" in params:
params["projector_config"] = params["aligner_config"]
del params["aligner_config"]

return cls(
**{
k: v
@@ -174,15 +178,15 @@ class MlpProjector(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()

if config.aligner_config.params["projector_type"] == "mlp_gelu":
if config.projector_config.params["projector_type"] == "mlp_gelu":
self.layers = [
nn.Linear(
config.vision_config.hidden_size,
config.text_config.hidden_size,
bias=True,
)
]
mlp_depth = config.aligner_config.params["depth"]
mlp_depth = config.projector_config.params["depth"]
for _ in range(1, mlp_depth):
self.layers.append(nn.GELU())
self.layers.append(
@@ -193,10 +197,10 @@ def __init__(self, config: ModelConfig):
)
)
elif (
config.aligner_config.params["projector_type"]
config.projector_config.params["projector_type"]
== "low_high_hybrid_split_mlp_gelu"
):
mlp_depth = config.aligner_config.params["depth"]
mlp_depth = config.projector_config.params["depth"]
self.high_up_proj = nn.Linear(
config.vision_config.hidden_size, config.text_config.hidden_size // 2
)
@@ -214,7 +218,7 @@ def __init__(self, config: ModelConfig):
)

else:
projector_type = config.aligner_config.params["projector_type"]
projector_type = config.projector_config.params["projector_type"]
raise ValueError(f"Unknown projector type: {projector_type}")

def __call__(self, x: Union[mx.array, Tuple]) -> mx.array:
@@ -398,8 +402,8 @@ def from_pretrained(path_or_hf_repo: str):
model_config = ModelConfig.from_dict(model_config)

model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
model_config.aligner_config = AlignerConfig.from_dict(
model_config.aligner_config
model_config.projector_config = ProjectorConfig.from_dict(
model_config.projector_config
)
model_config.text_config = TextConfig.from_dict(model_config.text_config)

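
The aligner-to-projector rename stays backward compatible through ModelConfig.from_dict: configs that still carry an "aligner_config" key are read into the new projector_config field. A plain-dict sketch of that remapping; the values below are placeholders, not a real checkpoint config:

raw_config = {
    "model_type": "multi_modality",
    "aligner_config": {                       # old key from existing checkpoints
        "cls": "MlpProjector",
        "model_type": "aligner",
        "params": {"projector_type": "mlp_gelu", "depth": 2},
    },
}

# The same remapping from_dict performs before filtering fields:
if "aligner_config" in raw_config:
    raw_config["projector_config"] = raw_config.pop("aligner_config")

print(sorted(raw_config))   # ['model_type', 'projector_config']
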
4 changes: 2 additions & 2 deletions mlx_vlm/models/qwen2_vl/language.py
@@ -34,8 +34,8 @@ def __post_init__(self):
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")

if not self.rope_scaling["type"] == "mrope":
raise ValueError(f"rope_scaling type must be 'mrope'")
if not self.rope_scaling["type"] in ["mrope", "default"]:
raise ValueError(f"rope_scaling type must be 'mrope' or 'default'")

@classmethod
def from_dict(cls, params):
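
With the relaxed check, rope_scaling entries whose type has been normalized to "default" (rather than "mrope") no longer raise during config validation. A small sketch of what is now accepted; the mrope_section values are placeholders:

accepted = [
    {"type": "mrope", "mrope_section": [16, 24, 24]},
    {"type": "default", "mrope_section": [16, 24, 24]},
]
rejected = {"type": "linear", "factor": 2.0}

for scaling in accepted + [rejected]:
    ok = scaling["type"] in ["mrope", "default"]
    print(scaling["type"], "->", "accepted" if ok else "raises ValueError")
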