diff --git a/mlx_vlm/convert.py b/mlx_vlm/convert.py
index 11d2959..5952a88 100644
--- a/mlx_vlm/convert.py
+++ b/mlx_vlm/convert.py
@@ -51,13 +51,7 @@ def configure_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument(
         "--skip-vision",
-        help="Skip vision quantization.",
-        action="store_true",
-        default=False,
-    )
-    parser.add_argument(
-        "--skip-vision-non-divisible",
-        help="Skip layers that are not divisible by 64 in vision encoder.",
+        help="Skip vision module quantization.",
         action="store_true",
         default=False,
     )
diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py
index b91d60a..549424b 100644
--- a/mlx_vlm/generate.py
+++ b/mlx_vlm/generate.py
@@ -7,7 +7,7 @@
 DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit"
 DEFAULT_IMAGE = []
 DEFAULT_PROMPT = "What are these?"
-DEFAULT_MAX_TOKENS = 100
+DEFAULT_MAX_TOKENS = 256
 DEFAULT_TEMP = 0.5
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
@@ -64,12 +64,11 @@ def parse_arguments():
 
 def get_model_and_processors(model_path, adapter_path):
     model_path = get_model_path(model_path)
-    config = load_config(model_path)
+    config = load_config(model_path, trust_remote_code=True)
     model, processor = load(
-        model_path, {"trust_remote_code": True}, adapter_path=adapter_path
+        model_path, adapter_path=adapter_path, lazy=False, trust_remote_code=True
     )
-    image_processor = load_image_processor(model_path)
-    return model, processor, image_processor, config
+    return model, processor, config
 
 
 def main():
@@ -77,9 +76,7 @@ def main():
     if isinstance(args.image, str):
         args.image = [args.image]
 
-    model, processor, image_processor, config = get_model_and_processors(
-        args.model, args.adapter_path
-    )
+    model, processor, config = get_model_and_processors(args.model, args.adapter_path)
 
     prompt = codecs.decode(args.prompt, "unicode_escape")
 
@@ -95,12 +92,11 @@ def main():
     output = generate(
         model,
         processor,
-        args.image,
         prompt,
-        image_processor,
-        args.temp,
-        args.max_tokens,
-        args.verbose,
+        image=args.image,
+        temp=args.temp,
+        max_tokens=args.max_tokens,
+        verbose=args.verbose,
         **kwargs,
     )
     if not args.verbose:
diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py
index f7440a3..ee47eb6 100644
--- a/mlx_vlm/models/base.py
+++ b/mlx_vlm/models/base.py
@@ -4,6 +4,7 @@
 import mlx.core as mx
 from PIL import Image
+from transformers.image_processing_utils import BaseImageProcessor as ImageProcessor
 from transformers.image_processing_utils import get_size_dict
 from transformers.image_utils import ChannelDimension, PILImageResampling
 
@@ -22,7 +23,7 @@ def expand2square(pil_img, background_color):
         return result
 
 
-class BaseImageProcessor(ABC):
+class BaseImageProcessor(ImageProcessor):
     def __init__(
         self,
         image_mean=(0.5, 0.5, 0.5),
@@ -72,6 +73,9 @@ def update_and_fetch(self, keys, values):
         self.update(keys, values)
         return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
 
+    def fetch(self):
+        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+
     def update(self, keys, values):
         prev = self.offset
         if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
@@ -127,6 +131,9 @@ def update_and_fetch(self, keys, values):
         self.cache_length += keys.shape[2]
         return self.keys, self.values
 
+    def fetch(self):
+        return self.keys, self.values
+
     def update(self, keys, values):
         """Update cache with new key/value tensors without returning.
@@ -167,6 +174,9 @@ def _trim(self, trim_size, v, append=None):
             to_cat.append(append)
         return mx.concatenate(to_cat, axis=2)
 
+    def fetch(self):
+        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+
     def update_and_fetch(self, keys, values):
         prev = self.offset
         B, _, S = keys.shape[:3]
diff --git a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py
index 075fb0b..1796083 100644
--- a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py
+++ b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py
@@ -62,6 +62,10 @@ class ModelConfig:
 
     @classmethod
     def from_dict(cls, params):
+        if "language_config" in params:
+            params["text_config"] = params["language_config"]
+            del params["language_config"]
+
         return cls(
             **{
                 k: v
@@ -224,7 +228,12 @@ def __init__(self, config: ModelConfig):
                     f"len(candidate_resolutions) should be larger than 0, but got {len(candidate_resolutions)}"
                 )
             tile_variants_num = len(candidate_resolutions)
-            # self.tile_indicators = mx.array(mx.random(size=(tile_variants_num + 1, config.aligner.params.n_embed)) * embed_std)
+            self.tile_indicators = mx.array(
+                mx.random.normal(
+                    (tile_variants_num + 1, config.projector_config.n_embed)
+                )
+                * embed_std
+            )
         else:
             raise ValueError(
                 f"tile tag should be either 1D or 2D, but got {self.tile_tag}"
diff --git a/mlx_vlm/models/idefics2/idefics2.py b/mlx_vlm/models/idefics2/idefics2.py
index 52085dd..6de5642 100644
--- a/mlx_vlm/models/idefics2/idefics2.py
+++ b/mlx_vlm/models/idefics2/idefics2.py
@@ -212,7 +212,7 @@ def get_input_embeddings(
         pixel_attention_mask: Optional[mx.array] = None,
     ):
         if pixel_values is None:
-            return self.language_model(input_ids)
+            return self.language_model.embed_tokens(input_ids)
 
         inputs_embeds = self.language_model.embed_tokens(input_ids)
 
diff --git a/mlx_vlm/models/idefics2/vision.py b/mlx_vlm/models/idefics2/vision.py
index 2c7fb91..fb9bfd3 100644
--- a/mlx_vlm/models/idefics2/vision.py
+++ b/mlx_vlm/models/idefics2/vision.py
@@ -16,7 +16,7 @@ class VisionConfig:
     num_attention_heads: int
     image_size: int
     patch_size: int
-    layer_norm_eps: float
+    layer_norm_eps: float = 1e-6
     num_channels: int = 3
 
     @classmethod
@@ -217,7 +217,7 @@ def __init__(self, config: VisionConfig):
         super().__init__()
         self.config = config
         self.model_type = config.model_type
-        if self.model_type != "idefics2":
+        if self.model_type not in ["idefics2", "idefics2_vision"]:
             raise ValueError(f"Unsupported model type: {self.model_type}")
         self.embeddings = VisionEmbeddings(config)
         self.encoder = Encoder(config)
diff --git a/mlx_vlm/models/idefics3/idefics3.py b/mlx_vlm/models/idefics3/idefics3.py
index 936a344..025909a 100644
--- a/mlx_vlm/models/idefics3/idefics3.py
+++ b/mlx_vlm/models/idefics3/idefics3.py
@@ -98,7 +98,7 @@ def get_input_embeddings(
         pixel_attention_mask: Optional[mx.array] = None,
     ):
         if pixel_values is None:
-            return self.language_model(input_ids)
+            return self.language_model.embed_tokens(input_ids)
 
         inputs_embeds = self.language_model.embed_tokens(input_ids)
 
diff --git a/mlx_vlm/models/llava/llava.py b/mlx_vlm/models/llava/llava.py
index 39298a9..696d9db 100644
--- a/mlx_vlm/models/llava/llava.py
+++ b/mlx_vlm/models/llava/llava.py
@@ -70,7 +70,7 @@ def get_input_embeddings(
         pixel_values: Optional[mx.array] = None,
     ):
         if pixel_values is None:
-            return self.language_model(input_ids)
+            return self.language_model.model.embed_tokens(input_ids)
 
         # Get the input embeddings from the language model
         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py
index bbbb5cd..a9c6ad0 100644
--- a/mlx_vlm/models/llava_bunny/llava_bunny.py
+++ b/mlx_vlm/models/llava_bunny/llava_bunny.py
@@ -35,7 +35,6 @@ class ModelConfig:
     auto_map: dict
     hidden_size: int
     mm_hidden_size: int
-    mm_vision_tower: str
     mm_projector_type: str = "mlp2x_gelu"
     ignore_index: int = -100
     image_token_index: int = -200
@@ -132,7 +131,7 @@ def get_input_embeddings(
         pixel_values: Optional[mx.array] = None,
     ):
         if pixel_values is None:
-            return self.language_model(input_ids)
+            return self.language_model.model.embed_tokens(input_ids)
 
         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
 
diff --git a/mlx_vlm/models/llava_next/llava_next.py b/mlx_vlm/models/llava_next/llava_next.py
index 29abea1..f10649f 100644
--- a/mlx_vlm/models/llava_next/llava_next.py
+++ b/mlx_vlm/models/llava_next/llava_next.py
@@ -75,7 +75,7 @@ def get_input_embeddings(
         pixel_values: Optional[mx.array] = None,
     ):
         if pixel_values is None:
-            return self.language_model(input_ids)
+            return self.language_model.model.embed_tokens(input_ids)
 
         # Get the input embeddings from the language model
         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
diff --git a/mlx_vlm/models/mllama/language.py b/mlx_vlm/models/mllama/language.py
index 0b53b39..de89b51 100644
--- a/mlx_vlm/models/mllama/language.py
+++ b/mlx_vlm/models/mllama/language.py
@@ -80,12 +80,12 @@ def __call__(
     ) -> mx.array:
         bsz, q_len, _ = hidden_states.shape
 
-        query_states = (
+        query = (
             self.q_proj(hidden_states)
             .reshape(bsz, q_len, self.num_heads, self.head_dim)
             .transpose(0, 2, 1, 3)
         )
-        query_states = self.q_norm(query_states)
+        query_states = self.q_norm(query)
 
         if cross_attention_states is not None:
             key_states = (
@@ -99,14 +99,11 @@ def __call__(
                 .transpose(0, 2, 1, 3)
             )
             key_states = self.k_norm(key_states)
-            if cache is not None:
-                key_states, value_states = cache.update_and_fetch(
-                    key_states, value_states
-                )
+        elif cache is not None and cache.offset > 0:
+            key_states, value_states = cache.fetch()
         else:
-            raise ValueError(
-                "Cross attention states must be provided for cross attention layer."
-            )
+            key_states, value_states = mx.split(query, 2, axis=1)
+            key_states = self.k_norm(key_states)
 
         attn_output = mx.fast.scaled_dot_product_attention(
             query_states,
@@ -338,10 +335,6 @@ def __call__(
 
         for idx, (decoder_layer, c) in enumerate(zip(self.layers, cache)):
             if idx in self.config.cross_attention_layers:
-                if cross_attention_states is None:
-                    raise ValueError(
-                        f"Cross attention states must be provided for layer {idx}"
-                    )
                 layer_outputs = decoder_layer(
                     hidden_states,
                     cross_attention_states=cross_attention_states,
diff --git a/mlx_vlm/models/multi_modality/__init__.py b/mlx_vlm/models/multi_modality/__init__.py
index daf0a68..ee1c212 100644
--- a/mlx_vlm/models/multi_modality/__init__.py
+++ b/mlx_vlm/models/multi_modality/__init__.py
@@ -1,9 +1,9 @@
 from .multi_modality import (
-    AlignerConfig,
     ImageProcessor,
     LanguageModel,
     Model,
     ModelConfig,
+    ProjectorConfig,
     TextConfig,
     VisionConfig,
     VisionModel,
diff --git a/mlx_vlm/models/multi_modality/multi_modality.py b/mlx_vlm/models/multi_modality/multi_modality.py
index c1d9df9..130edf6 100644
--- a/mlx_vlm/models/multi_modality/multi_modality.py
+++ b/mlx_vlm/models/multi_modality/multi_modality.py
@@ -10,16 +10,16 @@
 import numpy as np
 from huggingface_hub import snapshot_download
 from PIL import Image
-from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.image_processing_utils import BatchFeature
 from transformers.image_utils import to_numpy_array
 
-from ..base import expand2square
+from ..base import BaseImageProcessor, expand2square
 from .language import LanguageModel, TextConfig
 from .vision import VisionConfig, VisionModel
 
 
 @dataclass
-class AlignerConfig:
+class ProjectorConfig:
     cls: str
     model_type: str
     params: dict
@@ -39,7 +39,7 @@ def from_dict(cls, params):
 class ModelConfig:
     text_config: TextConfig
     vision_config: VisionConfig
-    aligner_config: AlignerConfig
+    projector_config: ProjectorConfig
     model_type: str
     ignore_index: int = -100
     image_token_index: int = 100015
@@ -51,6 +51,10 @@ class ModelConfig:
 
     @classmethod
     def from_dict(cls, params):
+        if "aligner_config" in params:
+            params["projector_config"] = params["aligner_config"]
+            del params["aligner_config"]
+
         return cls(
             **{
                 k: v
@@ -174,7 +178,7 @@ class MlpProjector(nn.Module):
     def __init__(self, config: ModelConfig):
         super().__init__()
 
-        if config.aligner_config.params["projector_type"] == "mlp_gelu":
+        if config.projector_config.params["projector_type"] == "mlp_gelu":
            self.layers = [
                 nn.Linear(
                     config.vision_config.hidden_size,
@@ -182,7 +186,7 @@ def __init__(self, config: ModelConfig):
                     bias=True,
                 )
             ]
-            mlp_depth = config.aligner_config.params["depth"]
+            mlp_depth = config.projector_config.params["depth"]
             for _ in range(1, mlp_depth):
                 self.layers.append(nn.GELU())
                 self.layers.append(
@@ -193,10 +197,10 @@ def __init__(self, config: ModelConfig):
                     )
                 )
         elif (
-            config.aligner_config.params["projector_type"]
+            config.projector_config.params["projector_type"]
             == "low_high_hybrid_split_mlp_gelu"
         ):
-            mlp_depth = config.aligner_config.params["depth"]
+            mlp_depth = config.projector_config.params["depth"]
             self.high_up_proj = nn.Linear(
                 config.vision_config.hidden_size, config.text_config.hidden_size // 2
             )
@@ -214,7 +218,7 @@ def __init__(self, config: ModelConfig):
             )
 
         else:
-            projector_type = config.aligner_config.params["projector_type"]
+            projector_type = config.projector_config.params["projector_type"]
             raise ValueError(f"Unknown projector type: {projector_type}")
 
     def __call__(self, x: Union[mx.array, Tuple]) -> mx.array:
@@ -398,8 +402,8 @@ def from_pretrained(path_or_hf_repo: str):
     model_config = ModelConfig.from_dict(model_config)
 
     model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
-    model_config.aligner_config = AlignerConfig.from_dict(
-        model_config.aligner_config
+    model_config.projector_config = ProjectorConfig.from_dict(
+        model_config.projector_config
     )
     model_config.text_config = TextConfig.from_dict(model_config.text_config)
 
diff --git a/mlx_vlm/models/qwen2_vl/language.py b/mlx_vlm/models/qwen2_vl/language.py
index 96b970f..1295de2 100644
--- a/mlx_vlm/models/qwen2_vl/language.py
+++ b/mlx_vlm/models/qwen2_vl/language.py
@@ -34,8 +34,8 @@ def __post_init__(self):
             if not all(key in self.rope_scaling for key in required_keys):
                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
 
-            if not self.rope_scaling["type"] == "mrope":
-                raise ValueError(f"rope_scaling type must be 'mrope'")
+            if not self.rope_scaling["type"] in ["mrope", "default"]:
+                raise ValueError(f"rope_scaling type must be 'mrope' or 'default'")
 
     @classmethod
     def from_dict(cls, params):
diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py
index 7ad3ea4..bc90762 100644
--- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py
+++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py
@@ -27,6 +27,12 @@ class ModelConfig:
 
     @classmethod
     def from_dict(cls, params):
+        # Copy text config parameters from root level
+        excluded_keys = {"vision_config"}
+        params["text_config"] = dict(
+            filter(lambda x: x[0] not in excluded_keys, params.items())
+        )
+
         return cls(
             **{
                 k: v
@@ -51,7 +57,10 @@ def get_input_embeddings(
     ):
 
         if pixel_values is None:
-            return self.language_model(input_ids)
+            return self.language_model.model.embed_tokens(input_ids)
+
+        dtype = self.vision_tower.patch_embed.proj.weight.dtype
+        pixel_values = pixel_values.astype(dtype)
 
         # Get the input embeddings from the language model
         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
@@ -91,10 +100,8 @@ def __call__(
         **kwargs,
     ):
         image_grid_thw = kwargs.pop("image_grid_thw", None)
-        image_grid_thw = mx.array(image_grid_thw)
-
-        dtype = self.vision_tower.patch_embed.proj.weight.dtype
-        pixel_values = pixel_values.astype(dtype)
+        if image_grid_thw is not None:
+            image_grid_thw = mx.array(image_grid_thw)
 
         input_embddings = self.get_input_embeddings(
             input_ids, pixel_values, image_grid_thw
diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py
index 7f30d05..dc56dcb 100644
--- a/mlx_vlm/tests/test_models.py
+++ b/mlx_vlm/tests/test_models.py
@@ -132,7 +132,6 @@ def test_llava_bunny(self):
             },
             hidden_size=1024,
             mm_hidden_size=1152,
-            mm_vision_tower="google/siglip-so400m-patch14-384",
             mm_projector_type="mlp2x_gelu",
             ignore_index=-100,
             image_token_index=-200,
@@ -499,9 +498,9 @@ def test_multi_modality(self):
             params={},
         )
 
-        aligner_config = multi_modality.AlignerConfig(
+        projector_config = multi_modality.ProjectorConfig(
             cls="MlpProjector",
-            model_type="aligner",
+            model_type="projector",
             params={
                 "depth": 2,
                 "input_dim": 1024,
@@ -513,7 +512,7 @@ def test_multi_modality(self):
         config = multi_modality.ModelConfig(
             text_config=text_config,
             vision_config=vision_config,
-            aligner_config=aligner_config,
+            projector_config=projector_config,
             model_type="multi_modality",
             ignore_index=-100,
             image_token_index=100015,
diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py
index 9e57956..218ffef 100644
--- a/mlx_vlm/utils.py
+++ b/mlx_vlm/utils.py
@@ -5,11 +5,11 @@
 import logging
 import shutil
 import time
-from dataclasses import asdict
+from dataclasses import dataclass
 from io import BytesIO
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -17,7 +17,7 @@
 import requests
 from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten, tree_unflatten
-from PIL import Image
+from PIL import Image, ImageOps
 from transformers import (
     AutoConfig,
     AutoProcessor,
@@ -35,11 +35,17 @@
 MAX_FILE_SIZE_GB = 5
 
-linear_class_predicate = (
-    lambda m: isinstance(m, nn.Linear)
-    and m.weight.shape[0]
-    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
-)
+
+@dataclass
+class GenerationResult:
+    text: str
+    token: Optional[int]
+    logprobs: Optional[List[float]]
+    prompt_tokens: int
+    generation_tokens: int
+    prompt_tps: float
+    generation_tps: float
+    peak_memory: float
 
 
 def get_model_and_args(config: dict):
@@ -96,7 +102,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     return model_path
 
 
-def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
+def load_model(model_path: Path, lazy: bool = False, **kwargs) -> nn.Module:
     """
     Load and initialize the model from a given path.
 
@@ -113,8 +119,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         FileNotFoundError: If the weight files (.safetensors) are not found.
         ValueError: If the model class or args class are not found or cannot be instantiated.
     """
-
-    config = load_config(model_path)
+    config = load_config(model_path, **kwargs)
     quantization = config.get("quantization", None)
 
     weight_files = glob.glob(str(model_path / "*.safetensors"))
@@ -144,109 +149,35 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     for wf in weight_files:
         weights.update(mx.load(wf))
 
-    if "language_config" in config:
-        config["text_config"] = config["language_config"]
-        del config["language_config"]
-
     model_class, model_type = get_model_and_args(config=config)
 
-    if "vision_config" in config:
-        skip_vision = config["vision_config"].get("skip_vision", False)
-        skip_vision_non_divisible = config["vision_config"].get(
-            "skip_vision_non_divisible", False
-        )
-    else:
-        skip_vision = False
-        skip_vision_non_divisible = False
-
-    if model_type == "llava_bunny":
-        vision_config = AutoConfig.from_pretrained(config["mm_vision_tower"])
-        text_config = AutoConfig.from_pretrained(config["language_model"])
-        vision_config = vision_config.to_dict()
-        text_config = text_config.to_dict()
-        config["vision_config"] = {
-            **vision_config["vision_config"],
-            **config.get("vision_config", {}),
-        }
-        config["text_config"] = text_config
-    if model_type == "idefics2":
-        config = AutoConfig.from_pretrained(model_path).to_dict()
-    if model_type == "phi3_v":
-        config["vision_config"] = config["img_processor"]
-        config["text_config"] = {}
-    if model_type == "qwen2_vl":
-        config["text_config"] = {
-            k: v for k, v in config.items() if k != "vision_config"
-        }
-
-    if model_type == "molmo":
-        intermediate_size = None
-        if "vision_config" in config and "intermediate_size" in config["vision_config"]:
-            intermediate_size = config["vision_config"]["intermediate_size"]
+    # Initialize text and vision configs if not present
+    config.setdefault("text_config", {})
+    config.setdefault("vision_config", {})
 
-        config["text_config"] = asdict(model_class.TextConfig())
-        config["vision_config"] = asdict(model_class.VisionConfig())
-        config["vision_config"]["intermediate_size"] = intermediate_size
-
-    config["vision_config"]["skip_vision"] = skip_vision
-    config["vision_config"]["skip_vision_non_divisible"] = skip_vision_non_divisible
+    # Get vision config settings with defaults
+    vision_config = config.get("vision_config", {})
+    skip_vision = vision_config.get("skip_vision", False)
 
+    # Initialize model config and update it with module configs
     model_config = model_class.ModelConfig.from_dict(config)
-
-    model_config.vision_config = model_class.VisionConfig.from_dict(
-        config["vision_config"]
-    )
-
-    model_config.text_config = model_class.TextConfig.from_dict(config["text_config"])
-
-    if hasattr(model_config, "perceiver_config"):
-        model_config.perceiver_config = model_class.PerceiverConfig.from_dict(
-            config["perceiver_config"]
-        )
-    if hasattr(model_config, "aligner_config"):
-        model_config.aligner_config = model_class.AlignerConfig.from_dict(
-            config["aligner_config"]
-        )
-
-    if hasattr(model_config, "projector_config"):
-        model_config.projector_config = model_class.ProjectorConfig.from_dict(
-            config["projector_config"]
-        )
+    modules = ["text", "vision", "perceiver", "projector"]
+    model_config = update_module_configs(model_config, model_class, config, modules)
 
     model = model_class.Model(model_config)
 
-    if hasattr(model, "sanitize"):
-        weights = model.sanitize(weights)
-
-    if hasattr(model_class.VisionModel, "sanitize"):
-        weights = model_class.VisionModel(model_config.vision_config).sanitize(
-            weights=weights
-        )
-
-    if hasattr(model_class.LanguageModel, "sanitize"):
-        weights = model_class.LanguageModel(model_config.text_config).sanitize(
-            weights=weights
-        )
+    # Sanitize weights
+    weights = sanitize_weights(model, weights)
+    weights = sanitize_weights(
+        model_class.VisionModel, weights, model_config.vision_config
+    )
+    weights = sanitize_weights(
+        model_class.LanguageModel, weights, model_config.text_config
+    )
 
     if (quantization := config.get("quantization", None)) is not None:
-        # Handle legacy models which may not have everything quantized
-        skip_vision = config.get("vision_config", {}).get("skip_vision", False)
-        skip_vision_non_divisible = config.get("vision_config", {}).get(
-            "skip_vision_non_divisible", False
-        )
-        if skip_vision:
-            class_predicate = lambda _, m: not (
-                "vision_model" in m.name or "vision_tower" in m.name
-            )
-        elif skip_vision_non_divisible:
-            class_predicate = (
-                lambda _, m: hasattr(m, "to_quantized") and m.weight.shape[-1] % 64 == 0
-            )
-        else:
-            class_predicate = (
-                lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
-                and f"{p}.scales" in weights
-            )
+        # Handle legacy models which may not have everything quantized`
+        class_predicate = get_class_predicate(skip_vision, weights)
 
         nn.quantize(
             model,
@@ -262,11 +193,58 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     return model
 
 
+def sanitize_weights(model_obj, weights, config=None):
+    """Helper function to sanitize weights if the model has a sanitize method"""
+    if hasattr(model_obj, "sanitize"):
+        if config is not None:
+            model_obj = model_obj(config)
+        weights = model_obj.sanitize(weights)
+    return weights
+
+
+def update_module_configs(model_config, model_class, config, modules):
+    """Updates configuration for model modules like text and vision modules.
+
+    Args:
+        model_config: The model configuration object that will be updated
+        model_class: The model class containing component config classes
+        config: Dictionary containing configuration parameters
+        modules: List of module names to update configs for (e.g. ["text", "vision"])
+
+    Returns:
+        The updated model_config object
+    """
+    for config_name in modules:
+        config_attr = f"{config_name}_config"
+        if hasattr(model_config, config_attr):
+            config_class = getattr(model_class, f"{config_name.title()}Config")
+            setattr(
+                model_config, config_attr, config_class.from_dict(config[config_attr])
+            )
+    return model_config
+
+
+def get_class_predicate(skip_vision, weights=None):
+    if skip_vision:
+        return lambda _, m: not ("vision_model" in m.name or "vision_tower" in m.name)
+    else:
+        if weights:
+            return lambda p, m: (
+                hasattr(m, "to_quantized")
+                and m.weight.shape[-1] % 64 == 0
+                and f"{p}.scales" in weights
+            )
+        else:
+            return (
+                lambda _, m: hasattr(m, "to_quantized") and m.weight.shape[-1] % 64 == 0
+            )
+
+
 def load(
     path_or_hf_repo: str,
-    processor_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
+    **kwargs,
 ) -> Tuple[nn.Module, Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -289,36 +267,56 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)
 
-    model = load_model(model_path, lazy)
+    model = load_model(model_path, lazy, **kwargs)
     if adapter_path is not None:
         # TODO: Support more modules than just language_model
         model = apply_lora_layers(model, adapter_path)
         model.eval()
 
-    processor = load_processor(model_path, processor_config=processor_config)
+    image_processor = load_image_processor(model_path, **kwargs)
+    processor = load_processor(model_path, True, **kwargs)
+
+    if image_processor is not None:
+        processor.image_processor = image_processor
 
     return model, processor
 
 
-def load_config(model_path: Union[str, Path]) -> dict:
+def load_config(model_path: Union[str, Path], **kwargs) -> dict:
+    """Load model configuration from a path or Hugging Face repo.
+
+    Args:
+        model_path: Local path or Hugging Face repo ID to load config from
+        **kwargs: Additional keyword arguments to pass to the config loader
+
+    Returns:
+        dict: Model configuration
+    Raises:
+        FileNotFoundError: If config.json is not found at the path
+    """
     if isinstance(model_path, str):
         model_path = get_model_path(model_path)
 
     try:
-        with open(model_path / "config.json", "r") as f:
-            config = json.load(f)
-    except FileNotFoundError:
-        logging.error(f"Config file not found in {model_path}")
-        raise
-    return config
+        return AutoConfig.from_pretrained(model_path, **kwargs).to_dict()
+    except ValueError:
+        try:
+            with open(model_path / "config.json", encoding="utf-8") as f:
+                return json.load(f)
+        except FileNotFoundError as exc:
+            raise FileNotFoundError(f"Config not found at {model_path}") from exc
 
 
-def load_image_processor(model_path: Union[str, Path]) -> BaseImageProcessor:
+def load_image_processor(model_path: Union[str, Path], **kwargs) -> BaseImageProcessor:
     if isinstance(model_path, str):
         model_path = get_model_path(model_path)
 
-    config = load_config(model_path)
+    if not kwargs:
+        config = load_config(model_path, trust_remote_code=True)
+    else:
+        config = load_config(model_path, **kwargs)
+
     model_class, _ = get_model_and_args(config)
     image_processor = None
 
@@ -336,9 +334,10 @@ def load_image_processor(model_path: Union[str, Path]) -> BaseImageProcessor:
 
 
 def load_processor(
-    model_path, processor_config={"trust_remote_code": True}, add_detokenizer=True
+    model_path, add_detokenizer=True, **kwargs
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-    processor = AutoProcessor.from_pretrained(model_path, **processor_config)
+
+    processor = AutoProcessor.from_pretrained(model_path, **kwargs)
     if add_detokenizer:
         detokenizer_class = load_tokenizer(model_path, return_tokenizer=False)
         if "tokenizer" in processor.__dict__.keys():
@@ -349,11 +348,11 @@ def load_processor(
 
 
 def fetch_from_hub(
-    model_path: Path, lazy: bool = False
+    model_path: Path, lazy: bool = False, **kwargs
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model = load_model(model_path, lazy)
-    config = load_config(model_path)
-    processor = load_processor(model_path, add_detokenizer=False)
+    model = load_model(model_path, lazy, **kwargs)
+    config = load_config(model_path, **kwargs)
+    processor = load_processor(model_path, add_detokenizer=False, **kwargs)
     return model, config, processor
 
 
@@ -541,7 +540,6 @@ def quantize_model(
     q_group_size: int,
     q_bits: int,
     skip_vision: bool = False,
-    skip_vision_non_divisible: bool = False,
 ) -> Tuple[dict, dict]:
     """
     Applies quantization to the model weights.
@@ -559,9 +557,6 @@
     quantized_config = copy.deepcopy(config)
     quantized_config.setdefault("vision_config", {})
 
-    # Get vision model size
-    vision_size = _get_vision_size(model)
-
     # Apply quantization
     if skip_vision:
         # Quantize only non-vision modules
@@ -569,38 +564,18 @@
             model,
             q_group_size,
             q_bits,
-            class_predicate=lambda x: not (
-                "vision_model" in x.name or "vision_tower" in x.name
-            ),
+            class_predicate=get_class_predicate(skip_vision),
         )
         quantized_config["vision_config"]["skip_vision"] = skip_vision
-    elif skip_vision_non_divisible:
+
+    else:
         # Quantize only layers with to_quantized method and divisible by 64
         nn.quantize(
             model,
             q_group_size,
             q_bits,
-            class_predicate=lambda _, m: hasattr(m, "to_quantized")
-            and m.weight.shape[-1] % 64 == 0,
+            class_predicate=get_class_predicate(skip_vision),
         )
-        quantized_config["vision_config"][
-            "skip_vision_non_divisible"
-        ] = skip_vision_non_divisible
-    else:
-        # Pad vision model if needed
-        _pad_vision_model(model, vision_size)
-        if hasattr(model.config.vision_config, "intermediate_size"):
-            _update_vision_config(
-                quantized_config, vision_size, key="intermediate_size"
-            )
-        elif hasattr(model.config.vision_config, "hidden_size"):
-            _update_vision_config(quantized_config, vision_size, key="hidden_size")
-        else:
-            raise ValueError(
-                "No intermediate_size or hidden_size found in vision_config"
-            )
-
-        nn.quantize(model, q_group_size, q_bits)
 
     # Update config and get weights
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
@@ -609,45 +584,6 @@
     return quantized_weights, quantized_config
 
 
-def _get_vision_size(model: nn.Module) -> int:
-    """Get vision model intermediate/hidden size."""
-    if hasattr(model.config.vision_config, "intermediate_size"):
-        return model.config.vision_config.intermediate_size
-    return model.config.vision_config.hidden_size
-
-
-def _pad_vision_model(model: nn.Module, vision_size: int, divisor: int = 64) -> None:
-    """Pad vision model layers to be divisible by divisor."""
-    if any(vision_size % size != 0 for size in [64, 128]):
-        for name, module in model.named_modules():
-            if isinstance(module, nn.Linear) and (
-                "vision_model" in name or "vision_tower" in name
-            ):
-                out_features, in_features = module.weight.shape
-
-                new_out = (
-                    ((out_features // divisor) + 1) * divisor
-                    if out_features % divisor
-                    else out_features
-                )
-                new_in = (
-                    ((in_features // divisor) + 1) * divisor
-                    if in_features % divisor
-                    else in_features
-                )
-
-                if new_out != out_features or new_in != in_features:
-                    new_weight = mx.zeros((new_out, new_in), dtype=module.weight.dtype)
-                    new_bias = mx.zeros((new_out), dtype=module.bias.dtype)
-
-                    new_weight[:out_features, :in_features] = module.weight
-                    module.weight = new_weight
-
-                    if hasattr(module, "bias"):
-                        new_bias[:out_features] = module.bias
-                        module.bias = new_bias
-
-
 def _update_vision_config(
     config: dict,
     value: Union[int, bool],
@@ -731,10 +667,13 @@ def convert(
     dequantize: bool = False,
     skip_vision: bool = False,
     skip_vision_non_divisible: bool = False,
+    trust_remote_code: bool = True,
 ):
     print("[INFO] Loading")
     model_path = get_model_path(hf_path, revision=revision)
-    model, config, processor = fetch_from_hub(model_path, lazy=True)
+    model, config, processor = fetch_from_hub(
+        model_path, lazy=True, trust_remote_code=trust_remote_code
+    )
 
     weights = dict(tree_flatten(model.parameters()))
     dtype = mx.float16 if quantize else getattr(mx, dtype)
@@ -773,35 +712,36 @@
         upload_to_hub(mlx_path, upload_repo, hf_path)
 
 
-def load_image(image_source: Union[str, Path, BytesIO]):
+def load_image(image_source: Union[str, Path, BytesIO], timeout: int = 10):
     """
     Helper function to load an image from either a URL or file.
     """
-    if isinstance(image_source, BytesIO):
+    if isinstance(image_source, BytesIO) or Path(image_source).is_file():
         # for base64 encoded images
         try:
-            return Image.open(image_source)
+            image = Image.open(image_source)
         except IOError as e:
-            raise ValueError(f"Failed to load image from BytesIO with error: {e}")
+            raise ValueError(
+                f"Failed to load image from {image_source} with error: {e}"
+            ) from e
     elif image_source.startswith(("http://", "https://")):
         try:
-            response = requests.get(image_source, stream=True)
+            response = requests.get(image_source, stream=True, timeout=timeout)
             response.raise_for_status()
-            return Image.open(response.raw)
+            image = Image.open(response.raw)
         except Exception as e:
             raise ValueError(
                 f"Failed to load image from URL: {image_source} with error {e}"
-            )
+            ) from e
-    elif Path(image_source).is_file():
-        try:
-            return Image.open(image_source)
-        except IOError as e:
-            raise ValueError(f"Failed to load image {image_source} with error: {e}")
     else:
         raise ValueError(
             f"The image {image_source} must be a valid URL or existing file."
         )
 
+    image = ImageOps.exif_transpose(image)
+    image = image.convert("RGB")
+    return image
+
 
 def resize_image(img, max_size):
     ratio = min(max_size[0] / img.width, max_size[1] / img.height)
@@ -812,34 +752,27 @@ def resize_image(img, max_size):
 
 
 def process_image(img, resize_shape, image_processor):
     if isinstance(img, str):
         img = load_image(img)
-    if resize_shape is not None and image_processor is None:
+    if resize_shape is not None and not isinstance(image_processor, BaseImageProcessor):
         img = resize_image(img, resize_shape)
     return img
 
 
-def prepare_inputs(
-    image_processor, processor, images, prompts, image_token_index, resize_shape=None
-):
-    from transformers.image_utils import load_image
+def prepare_inputs(processor, images, prompts, image_token_index, resize_shape=None):
 
-    mask = None
     if not isinstance(images, list):
         images = [images]
 
     # Process images
+    image_processor = (
+        processor.image_processor if hasattr(processor, "image_processor") else None
+    )
     images = [process_image(img, resize_shape, image_processor) for img in images]
 
-    image_grid_thw = None
-    image_sizes = None
-    aspect_ratio_ids = None
-    aspect_ratio_mask = None
-    cross_attention_mask = None
-    image_input_idx = None
-    image_masks = None
-    images_spatial_crop = None
-    images_seq_mask = None
+    model_inputs = {}
 
-    if image_processor is not None:
+    if hasattr(processor, "image_processor") and isinstance(
+        processor.image_processor, BaseImageProcessor
+    ):
         if not isinstance(prompts, list):
             prompts = [prompts]
 
@@ -861,14 +794,13 @@ def prepare_inputs(
             padding = [processor.pad_token_id] * (max_length - len(ids))
             input_ids.append(mx.array(ids + padding))
 
-        input_ids = mx.array(input_ids)
-
-        pixel_values = image_processor.preprocess(images=images)
-        pixel_values = mx.array(np.stack(pixel_values))
+        model_inputs["input_ids"] = mx.array(input_ids)
+        pixel_values = processor.image_processor.preprocess(images=images)
+        model_inputs["pixel_values"] = mx.array(np.stack(pixel_values))
+        model_inputs["attention_mask"] = mx.array(
+            [(ids != processor.pad_token_id) for ids in input_ids]
+        ).astype(mx.int32)
 
-        mask = mx.array([(ids != processor.pad_token_id) for ids in input_ids]).astype(
-            mx.int32
-        )
     else:
         processor.tokenizer.pad_token = processor.tokenizer.eos_token
         if hasattr(processor, "process"):
@@ -888,61 +820,17 @@ def prepare_inputs(
                 pixel_values = inputs["pixel_values"]
             else:
                 pixel_values = mx.array(inputs["pixel_values"])
 
-        input_ids = mx.array(inputs["input_ids"])
-        mask = (
+
+        model_inputs["pixel_values"] = pixel_values
+        model_inputs["attention_mask"] = (
             mx.array(inputs["attention_mask"]) if "attention_mask" in inputs else None
         )
+        # Convert inputs to model_inputs with mx.array if present
+        for key, value in inputs.items():
+            if key not in model_inputs and not isinstance(value, (str, list)):
+                model_inputs[key] = mx.array(value)
 
-        image_input_idx = inputs.get("image_input_idx", None)
-        if image_input_idx is not None:
-            image_input_idx = mx.array(image_input_idx)
-
-        image_masks = inputs.get("image_masks", None)
-        if image_masks is not None:
-            image_masks = mx.array(image_masks)
-
-        image_sizes = inputs.get("image_sizes", None)
-        if image_sizes is not None:
-            image_sizes = mx.array(image_sizes)
-
-        image_grid_thw = inputs.get("image_grid_thw", None)
-        if image_grid_thw is not None:
-            image_grid_thw = mx.array(image_grid_thw)
-
-        aspect_ratio_ids = inputs.get("aspect_ratio_ids", None)
-        if aspect_ratio_ids is not None:
-            aspect_ratio_ids = mx.array(aspect_ratio_ids)
-
-        aspect_ratio_mask = inputs.get("aspect_ratio_mask", None)
-        if aspect_ratio_mask is not None:
-            aspect_ratio_mask = mx.array(aspect_ratio_mask)
-
-        cross_attention_mask = inputs.get("cross_attention_mask", None)
-        if cross_attention_mask is not None:
-            cross_attention_mask = mx.array(cross_attention_mask)
-
-        images_spatial_crop = inputs.get("images_spatial_crop", None)
-        if images_spatial_crop is not None:
-            images_spatial_crop = mx.array(images_spatial_crop)
-
-        images_seq_mask = inputs.get("images_seq_mask", None)
-        if images_seq_mask is not None:
-            images_seq_mask = mx.array(images_seq_mask)
-
-    return (
-        input_ids,
-        pixel_values,
-        mask,
-        image_grid_thw,
-        image_sizes,
-        aspect_ratio_ids,
-        aspect_ratio_mask,
-        cross_attention_mask,
-        image_input_idx,
-        image_masks,
-        images_spatial_crop,
-        images_seq_mask,
-    )
+    return model_inputs
 
 
 def generate_step(
@@ -950,6 +838,7 @@ def generate_step(
     model: nn.Module,
     pixel_values,
     mask,
+    max_tokens: int = 256,
     temp: float = 0.0,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = 20,
@@ -1076,22 +965,26 @@ def _step(y, **kwargs):
     else:
         kwargs = {}
 
+    n = 0
     while True:
-        next_y, next_logprobs = _step(y, **kwargs)
-        mx.async_eval(next_y)
-        if "decoder_input_ids" in kwargs:
-            kwargs["decoder_input_ids"] = next_y[None]
-        yield y.item(), logprobs
-        y, logprobs = next_y, next_logprobs
+        if n != max_tokens:
+            next_y, next_logprobs = _step(y, **kwargs)
+            mx.async_eval(next_y)
+            if "decoder_input_ids" in kwargs:
+                kwargs["decoder_input_ids"] = next_y[None]
+            yield y.item(), logprobs
+            y, logprobs = next_y, next_logprobs
+        if n == max_tokens:
+            break
+
+        n += 1
 
 
 def stream_generate(
     model: nn.Module,
     processor: PreTrainedTokenizer,
-    image: str,
     prompt: str,
-    image_processor=None,
-    max_tokens: int = 100,
+    image: Union[str, List[str]] = None,
     **kwargs,
 ) -> Union[str, Generator[str, None, None]]:
     """
@@ -1107,70 +1000,76 @@ def stream_generate(
 
     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing text.
""" - - if image_processor is not None: - tokenizer = processor - else: - tokenizer = processor.tokenizer + tokenizer = processor if hasattr(processor, "encode") else processor.tokenizer + prompt_tokens = mx.array(tokenizer.encode(prompt)) resize_shape = kwargs.pop("resize_shape", None) - if hasattr(model.config, "image_token_index"): - image_token_index = model.config.image_token_index - else: - image_token_index = None + image_token_index = getattr(model.config, "image_token_index", None) - # Prepare inputs - inputs = prepare_inputs( - image_processor, processor, image, prompt, image_token_index, resize_shape - ) - input_ids, pixel_values, mask = inputs[:3] - kwargs = { - k: v - for k, v in zip( - [ - "image_grid_thw", - "image_sizes", - "aspect_ratio_ids", - "aspect_ratio_mask", - "cross_attention_mask", - "images_spatial_crop", - "images_seq_mask", - ], - inputs[3:], + if not image: + input_ids = prompt_tokens[None, :] + pixel_values = mask = None + kwargs = {} + else: + inputs = prepare_inputs( + processor, image, prompt, image_token_index, resize_shape ) - } + input_ids = inputs["input_ids"] + pixel_values = inputs["pixel_values"] + mask = inputs["attention_mask"] + kwargs = { + k: v + for k, v in inputs.items() + if k not in ["input_ids", "pixel_values", "attention_mask"] + } detokenizer = processor.detokenizer - detokenizer.reset() - for (token, _), n in zip( - generate_step(input_ids, model, pixel_values, mask, **kwargs), - range(max_tokens), + tic = time.perf_counter() + for n, (token, logprobs) in enumerate( + generate_step(input_ids, model, pixel_values, mask, **kwargs) ): + if n == 0: + prompt_time = time.perf_counter() - tic + prompt_tps = input_ids.size / prompt_time + tic = time.perf_counter() + if token == tokenizer.eos_token_id: break + detokenizer.add_token(token) # Yield the last segment if streaming - yield detokenizer.last_segment + yield GenerationResult( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=input_ids.size, + generation_tokens=n + 1, + prompt_tps=prompt_tps, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + ) detokenizer.finalize() - yield detokenizer.last_segment + yield GenerationResult( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=input_ids.size, + generation_tokens=n + 1, + prompt_tps=prompt_tps, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + ) def generate( model: nn.Module, processor: PreTrainedTokenizer, - image: str, prompt: str, - image_processor=None, - temp: float = 0.0, - max_tokens: int = 100, + image: Union[str, List[str]] = None, verbose: bool = False, - formatter: Optional[Callable] = None, - repetition_penalty: Optional[float] = 1.1, - repetition_context_size: Optional[int] = 20, - top_p: float = 1.0, **kwargs, ) -> str: """ @@ -1195,102 +1094,27 @@ def generate( print("Image:", image, "\n") print("Prompt:", prompt) - if image_processor is not None: - prompt_tokens = mx.array(processor.encode(prompt)) - tokenizer = processor - else: - prompt_tokens = mx.array(processor.tokenizer.encode(prompt)) - tokenizer = processor.tokenizer - - resize_shape = kwargs.pop("resize_shape", None) - if hasattr(model.config, "image_token_index"): - image_token_index = model.config.image_token_index - else: - image_token_index = None - - if image == []: - input_ids = prompt_tokens[None, :] - pixel_values = None - mask = None - kwargs = {} - else: - # Prepare inputs - 
-        inputs = prepare_inputs(
-            image_processor, processor, image, prompt, image_token_index, resize_shape
-        )
-        input_ids, pixel_values, mask = inputs[:3]
-        kwargs = {
-            k: v
-            for k, v in zip(
-                [
-                    "image_grid_thw",
-                    "image_sizes",
-                    "aspect_ratio_ids",
-                    "aspect_ratio_mask",
-                    "cross_attention_mask",
-                    "image_input_idx",
-                    "image_masks",
-                    "images_spatial_crop",
-                    "images_seq_mask",
-                ],
-                inputs[3:],
-            )
-        }
-
-    # Initialize timing and detokenizer
-    tic = time.perf_counter()
-    detokenizer = processor.detokenizer
-    detokenizer.reset()
-
-    # Generate tokens
-    generator = generate_step(
-        input_ids,
-        model,
-        pixel_values,
-        mask,
-        temp,
-        repetition_penalty,
-        repetition_context_size,
-        top_p,
-        **kwargs,
-    )
-
-    for (token, prob), n in zip(generator, range(max_tokens)):
-
-        if n == 0:
-            prompt_time = time.perf_counter() - tic
-            tic = time.perf_counter()
-        # TODO: Fix as first token
-        # Handle special case for DeepSeek-vl-7b-chat and PaliGemma models
-        # These models may generate EOS token as the first token (n == 0)
-        # For all other cases, break the loop when EOS is encountered after the first token
-        if token == tokenizer.eos_token_id and n > 0:
-            break
-
-        detokenizer.add_token(token)
-
+    text = ""
+    last_response = None
+    for response in stream_generate(model, processor, prompt, image, **kwargs):
         if verbose:
-            if formatter:
-                # We have to finalize so that the prob corresponds to the last segment
-                detokenizer.finalize()
-                formatter(detokenizer.last_segment, prob.item())
-            else:
-                print(detokenizer.last_segment, end="", flush=True)
-
-    token_count = n + 1
-    detokenizer.finalize()
+            print(response.text, end="", flush=True)
+        text += response.text
+        last_response = response
 
     if verbose:
-        print(detokenizer.last_segment, flush=True)
-        gen_time = time.perf_counter() - tic
-        print("=" * 10)
-        if token_count == 0:
-            print("No tokens generated for this prompt")
+        print("\n" + "=" * 10)
+        if len(text) == 0:
+            print("No text generated for this prompt")
             return
-        prompt_tps = prompt_tokens.size / prompt_time
-        gen_tps = (token_count - 1) / gen_time
-
-        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+        print(
+            f"Prompt: {last_response.prompt_tokens} tokens, "
+            f"{last_response.prompt_tps:.3f} tokens-per-sec"
+        )
+        print(
+            f"Generation: {last_response.generation_tokens} tokens, "
+            f"{last_response.generation_tps:.3f} tokens-per-sec"
+        )
+        print(f"Peak memory: {last_response.peak_memory:.3f} GB")
 
-    return detokenizer.text
+    return text
diff --git a/mlx_vlm/version.py b/mlx_vlm/version.py
index 0a8da88..f1380ee 100644
--- a/mlx_vlm/version.py
+++ b/mlx_vlm/version.py
@@ -1 +1 @@
-__version__ = "0.1.6"
+__version__ = "0.1.7"
diff --git a/requirements.txt b/requirements.txt
index c6634bb..1b773cb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ mlx>=0.18.1
 datasets>=2.19.1
 tqdm>=4.66.2
 numpy>=1.23.4
-transformers>=4.45.1
+transformers>=4.47.1
 scipy==1.13.1
 gradio>=4.44.0
 Pillow>=10.3.0
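
A minimal usage sketch of the keyword-based API introduced above. The checkpoint mirrors DEFAULT_MODEL_PATH from generate.py; the image URL is a placeholder, and the call mirrors how main() in generate.py invokes the new generate/stream_generate signatures.

    # Sketch only: load, generate, and stream_generate are defined in mlx_vlm/utils.py (this patch).
    from mlx_vlm.utils import load, generate, stream_generate

    # load() now forwards keyword arguments (e.g. trust_remote_code) to the config/processor loaders.
    model, processor = load("mlx-community/nanoLLaVA-1.5-8bit", trust_remote_code=True)

    # One-shot generation: image, temp, and max_tokens are keyword arguments, as in generate.py's main().
    text = generate(
        model,
        processor,
        "What are these?",
        image=["https://example.com/cats.jpg"],  # placeholder URL
        max_tokens=256,
        temp=0.5,
        verbose=True,
    )

    # Streaming now yields GenerationResult records instead of bare strings.
    for result in stream_generate(
        model,
        processor,
        "What are these?",
        image=["https://example.com/cats.jpg"],  # placeholder URL
    ):
        print(result.text, end="", flush=True)
    print(f"\n{result.generation_tps:.2f} tokens-per-sec, {result.peak_memory:.2f} GB peak")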