Remove embed_* variants of model architectures in speculator training #115

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion fms_fsdp/config/training.py
@@ -62,7 +62,7 @@ class train_config:

     # speculator training
     tp_size: int = 8
-    model_arch: str = "embedllama"
+    model_arch: str = "llama"
     model_path: str = "/path/to/model/"
     n_speculator_heads: int = 3
     speculator_width: int = 4096
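For illustration, a minimal sketch of overriding these speculator-training defaults after the rename, assuming `train_config` is a plain dataclass importable from `fms_fsdp.config.training`; the import path and the `model_variant` field (described in README_SPECULATOR.md rather than shown in this hunk) are assumptions.

```python
# Sketch only: field names mirror the train_config hunk above; the import path
# and the model_variant field are assumptions, not confirmed by this diff.
from fms_fsdp.config.training import train_config

cfg = train_config(
    model_arch="llama",            # base FMS architecture name (previously "embedllama")
    model_variant="7b",            # variant registered for that architecture (per README_SPECULATOR.md)
    model_path="/path/to/model/",
    n_speculator_heads=3,
    speculator_width=4096,
)
```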
2 changes: 1 addition & 1 deletion scripts/README_SPECULATOR.md
@@ -1,6 +1,6 @@
 ### The following parameters are relevant for speculator training:
 
-- *model_arch*: architecture of the base model (one of: embedllama, embedmixtral, embedgpt_bigcode -- FMS implementations extending the base arch to also emit the embedding vector together with the model output. See 'EmbedLLaMA' in train_speculator_utils.py)
+- *model_arch*: architecture of the base model (one of: llama, mixtral, gpt_bigcode)
 
 - *model_variant*: identifier with which a specific variant (e.g., 7b) is registered for the model architecture. See 'example model registrations' in train_speculator_utils.py.
 
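To make the parameter descriptions above concrete, here is a hedged sketch of how `model_arch` and `model_variant` might resolve to a stock FMS model once the `embed*` aliases are gone; `fms.models.get_model` and its keyword arguments are assumptions about the FMS API, not something this PR shows.

```python
# Hedged sketch: resolve cfg.model_arch / cfg.model_variant to a stock FMS model.
# The get_model call and its keywords are assumptions about the FMS API.
from fms.models import get_model

model = get_model(
    "llama",                       # cfg.model_arch after this PR (was "embedllama")
    "7b",                          # cfg.model_variant
    model_path="/path/to/model/",
    source="hf",                   # use the HF state-dict adapter registered for "llama"
)
```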
142 changes: 0 additions & 142 deletions speculator/train_speculator_utils.py
@@ -7,14 +7,6 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from fms.models import register_model
-from fms.models.gpt_bigcode import GPTBigCode
-from fms.models.gpt_bigcode import _20b_config as _gpt_bigcode_20b_config
-from fms.models.gpt_bigcode import _hf_sd_to_fms_sd as _gptbigcode_hf_sd_to_fms_sd
-from fms.models.llama import LLaMA
-from fms.models.llama import _hf_sd_to_fms_sd as _llama_hf_sd_to_fms_sd
-from fms.models.mixtral import Mixtral, MixtralConfig
-from fms.models.mixtral import _hf_sd_to_fms_sd as _mixtral_hf_sd_to_fms_sd
 from fms.utils import serialization, tokenizers
 from fms.utils.generation import _make_cache_contiguous
 from torch.nn import CrossEntropyLoss
@@ -431,137 +423,3 @@ def train_speculator(
tokens_seen=elapsed_tokens + n_tok,
is_compiled=cfg.use_torch_compile,
)


-class EmbedGPTBigCode(GPTBigCode):
-    # Overrides the forward function of GPTBigCode to allow returning embedding vectors
-    def forward(
-        self,
-        x: torch.LongTensor,
-        mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None,
-        use_cache: bool = False,
-        attn_algorithm: Optional[str] = None,
-        include_embeds: bool = False,
-    ):
-        output, cache = self.base_model(
-            x,
-            mask,
-            position_ids=position_ids,
-            past_key_value_states=past_key_value_states,
-            use_cache=use_cache,
-            attn_algorithm=attn_algorithm,
-        )
-
-        preds = self.head(output)
-
-        out = [preds]
-        if use_cache:
-            out.append(cache)
-        if include_embeds:
-            out.append(output)
-        if len(out) == 1:
-            return out[0]
-        return out
-
-
-class EmbedLLaMA(LLaMA):
-    # Overrides the forward function of LLaMA to allow returning embedding vectors
-    def forward(
-        self,
-        x,
-        mask=None,
-        position_ids=None,
-        past_key_value_states=None,
-        use_cache=False,
-        only_last_token=False,
-        attn_algorithm=None,
-        include_embeds=False,
-    ):
-        output, cache = self._helper(
-            x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm
-        )
-
-        if only_last_token:
-            output = output[:, -1, :]
-        preds = self.shared(output, reverse=True)
-
-        out = [preds]
-        if use_cache:
-            out.append(cache)
-        if include_embeds:
-            out.append(output)
-        if len(out) == 1:
-            return out[0]
-        return out
-
-
-class EmbedMixtral(Mixtral):  # FMS impl of Mixtral
-    # Overrides the forward function of Mixtral to allow returning embedding vectors
-    def forward(
-        self,
-        x,
-        mask=None,
-        position_ids=None,
-        past_key_value_states=None,
-        use_cache=False,
-        only_last_token=False,
-        attn_algorithm=None,
-        include_embeds=False,
-    ):
-        output, cache = self.base_model(
-            x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm
-        )
-
-        if only_last_token:
-            output = output[:, -1, :]
-        preds = self.head(output)
-
-        out = [preds]
-        if use_cache:
-            out.append(cache)
-        if include_embeds:
-            out.append(output)
-        if len(out) == 1:
-            return out[0]
-        return out
-
-
-def _gpt_bigcode_factory_factory(config):
-    def factory(**kwargs):
-        return EmbedGPTBigCode(config, **kwargs)
-
-    return factory
-
-
-def _llama_factory_factory(config):
-    def factory(**kwargs):
-        return EmbedLLaMA(config, **kwargs)
-
-    return factory
-
-
-def _mixtral_factory_factory(config):
-    def factory(**kwargs):
-        return EmbedMixtral(config, **kwargs)
-
-    return factory
-
-
-# example model registrations
-register_model(
-    "embedgpt_bigcode", "20b", _gpt_bigcode_factory_factory(_gpt_bigcode_20b_config)
-)
-serialization.register_adapter("embedgpt_bigcode", "hf", _gptbigcode_hf_sd_to_fms_sd)
-
-register_model(
-    "embedllama", "7b", _llama_factory_factory(get_model_config("llama2_7b"))
-)
-register_model(
-    "embedllama", "8b", _llama_factory_factory(get_model_config("llama3_8b"))
-)
-serialization.register_adapter("embedllama", "hf", _llama_hf_sd_to_fms_sd)
-
-register_model("embedmixtral", "8x7b", _mixtral_factory_factory(MixtralConfig()))
-serialization.register_adapter("embedmixtral", "hf", _mixtral_hf_sd_to_fms_sd)
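For reference, the calling convention the deleted `Embed*` wrappers exposed, reconstructed from the `forward()` bodies above; `model` and `input_ids` are placeholders, and this diff does not show how speculator training obtains hidden-state embeddings after the removal.

```python
# Reconstructed from the removed forward() implementations: with
# include_embeds=True the wrappers returned the embedding vectors alongside
# the logits (and the KV cache when use_cache=True).
out = model(input_ids, include_embeds=True)
preds, embeds = out                      # [logits, hidden-state embeddings]

out = model(input_ids, use_cache=True, include_embeds=True)
preds, cache, embeds = out               # [logits, kv-cache, embeddings]
```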