From 68dfaf6d5ab29f4d2bb7ebd3c47b4c081ccb64ac Mon Sep 17 00:00:00 2001
From: sahil suneja
Date: Tue, 10 Sep 2024 22:07:20 +0000
Subject: [PATCH] remove embed_* variants of model architectures; assuming
 base model implementation supports returning embedding vectors

Signed-off-by: sahil suneja
---
 fms_fsdp/config/training.py          |   2 +-
 scripts/README_SPECULATOR.md         |   2 +-
 speculator/train_speculator_utils.py | 142 ---------------------------
 3 files changed, 2 insertions(+), 144 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 1d072958..ff738c1f 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -62,7 +62,7 @@ class train_config:
 
     # speculator training
     tp_size: int = 8
-    model_arch: str = "embedllama"
+    model_arch: str = "llama"
     model_path: str = "/path/to/model/"
     n_speculator_heads: int = 3
     speculator_width: int = 4096
diff --git a/scripts/README_SPECULATOR.md b/scripts/README_SPECULATOR.md
index c1e850d2..552140b6 100644
--- a/scripts/README_SPECULATOR.md
+++ b/scripts/README_SPECULATOR.md
@@ -1,6 +1,6 @@
 ### Following parameters are relevant for speculator training:
 
-- *model_arch*: architecture of the base model (one of: embedllama, embedmixtral, embedgpt_bigcode-- FMS implementations extending the base arch to also emit embedding vector together with the model output. See 'EmbedLLaMA' in train_spculator_utils.py)
+- *model_arch*: architecture of the base model (one of: llama, mixtral, gpt_bigcode)
 
 - *model_variant*: identifier with which a specific variant (e.g., 7b) is registered for the model architecture. See 'example model registrations' in train_spculator_utils.py.
 
diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py
index 87b4e7b2..cc0fd5de 100644
--- a/speculator/train_speculator_utils.py
+++ b/speculator/train_speculator_utils.py
@@ -7,14 +7,6 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from fms.models import register_model
-from fms.models.gpt_bigcode import GPTBigCode
-from fms.models.gpt_bigcode import _20b_config as _gpt_bigcode_20b_config
-from fms.models.gpt_bigcode import _hf_sd_to_fms_sd as _gptbigcode_hf_sd_to_fms_sd
-from fms.models.llama import LLaMA
-from fms.models.llama import _hf_sd_to_fms_sd as _llama_hf_sd_to_fms_sd
-from fms.models.mixtral import Mixtral, MixtralConfig
-from fms.models.mixtral import _hf_sd_to_fms_sd as _mixtral_hf_sd_to_fms_sd
 from fms.utils import serialization, tokenizers
 from fms.utils.generation import _make_cache_contiguous
 from torch.nn import CrossEntropyLoss
@@ -431,137 +423,3 @@ def train_speculator(
         tokens_seen=elapsed_tokens + n_tok,
         is_compiled=cfg.use_torch_compile,
     )
-
-
-class EmbedGPTBigCode(GPTBigCode):
-    # Overrides the forward function of GPTBigCode to allow returning embedding vectors
-    def forward(
-        self,
-        x: torch.LongTensor,
-        mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None,
-        use_cache: bool = False,
-        attn_algorithm: Optional[str] = None,
-        include_embeds: bool = False,
-    ):
-        output, cache = self.base_model(
-            x,
-            mask,
-            position_ids=position_ids,
-            past_key_value_states=past_key_value_states,
-            use_cache=use_cache,
-            attn_algorithm=attn_algorithm,
-        )
-
-        preds = self.head(output)
-
-        out = [preds]
-        if use_cache:
-            out.append(cache)
-        if include_embeds:
-            out.append(output)
-        if len(out) == 1:
-            return out[0]
-        return out
-
-
-class EmbedLLaMA(LLaMA):
-    # Overrides the forward function of LLaMA to allow returning embedding vectors
-    def forward(
-        self,
-        x,
-        mask=None,
-        position_ids=None,
-        past_key_value_states=None,
-        use_cache=False,
-        only_last_token=False,
-        attn_algorithm=None,
-        include_embeds=False,
-    ):
-        output, cache = self._helper(
-            x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm
-        )
-
-        if only_last_token:
-            output = output[:, -1, :]
-        preds = self.shared(output, reverse=True)
-
-        out = [preds]
-        if use_cache:
-            out.append(cache)
-        if include_embeds:
-            out.append(output)
-        if len(out) == 1:
-            return out[0]
-        return out
-
-
-class EmbedMixtral(Mixtral):  # FMS impl of Mixtral
-    # Overrides the forward function of Mixtral to allow returning embedding vectors
-    def forward(
-        self,
-        x,
-        mask=None,
-        position_ids=None,
-        past_key_value_states=None,
-        use_cache=False,
-        only_last_token=False,
-        attn_algorithm=None,
-        include_embeds=False,
-    ):
-        output, cache = self.base_model(
-            x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm
-        )
-
-        if only_last_token:
-            output = output[:, -1, :]
-        preds = self.head(output)
-
-        out = [preds]
-        if use_cache:
-            out.append(cache)
-        if include_embeds:
-            out.append(output)
-        if len(out) == 1:
-            return out[0]
-        return out
-
-
-def _gpt_bigcode_factory_factory(config):
-    def factory(**kwargs):
-        return EmbedGPTBigCode(config, **kwargs)
-
-    return factory
-
-
-def _llama_factory_factory(config):
-    def factory(**kwargs):
-        return EmbedLLaMA(config, **kwargs)
-
-    return factory
-
-
-def _mixtral_factory_factory(config):
-    def factory(**kwargs):
-        return EmbedMixtral(config, **kwargs)
-
-    return factory
-
-
-# example model registrations
-register_model(
-    "embedgpt_bigcode", "20b", _gpt_bigcode_factory_factory(_gpt_bigcode_20b_config)
-)
-serialization.register_adapter("embedgpt_bigcode", "hf", _gptbigcode_hf_sd_to_fms_sd)
-
-register_model(
-    "embedllama", "7b", _llama_factory_factory(get_model_config("llama2_7b"))
-)
-register_model(
-    "embedllama", "8b", _llama_factory_factory(get_model_config("llama3_8b"))
-)
-serialization.register_adapter("embedllama", "hf", _llama_hf_sd_to_fms_sd)
-
-register_model("embedmixtral", "8x7b", _mixtral_factory_factory(MixtralConfig()))
-serialization.register_adapter("embedmixtral", "hf", _mixtral_hf_sd_to_fms_sd)
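
Note: the sketch below is illustrative and not part of the patch. It shows the calling convention the change assumes, namely that the upstream FMS base models' forward() accepts an include_embeds flag and appends the hidden states to its output, exactly as the removed Embed* wrappers did. The architecture, variant, path, and vocab-size values are placeholders.

import torch
from fms.models import get_model

# Hypothetical setup: load a base FMS model the same way the speculator
# training path does ("llama" / "7b" / model_path are placeholder values).
model = get_model("llama", "7b", model_path="/path/to/model/")
model.eval()

inputs = torch.randint(0, 32000, (1, 16))  # dummy token ids; vocab size assumed

# Assumed return convention, mirroring the removed Embed* wrappers:
# [logits, kv_cache, embeds] when both use_cache and include_embeds are set.
logits, kv_cache, embeds = model(inputs, use_cache=True, include_embeds=True)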