Skip to content

Commit

Permalink
Update loader.py
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhanchen committed Jun 13, 2024
1 parent 82f10cb commit 08424f0
Showing 1 changed file with 33 additions and 24 deletions.
57 changes: 33 additions & 24 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,18 @@ def _get_model_name(model_name, load_in_4bit = True):
class FastLanguageModel(FastLlamaModel):
@staticmethod
def from_pretrained(
model_name = "unsloth/llama-3-8b-bnb-4bit",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
trust_remote_code = False,
use_gradient_checkpointing = True,
resize_model_vocab = None,
revision = None,
model_name = "unsloth/llama-3-8b-bnb-4bit",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
trust_remote_code = False,
use_gradient_checkpointing = "unsloth",
resize_model_vocab = None,
revision = None,
*args, **kwargs,
):
if token is None and "HF_TOKEN" in os.environ:
Expand Down Expand Up @@ -141,23 +141,24 @@ def from_pretrained(
pass

model, tokenizer = dispatch_model.from_pretrained(
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = dispatch_model,
tokenizer_name = tokenizer_name,
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = dispatch_model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
revision = revision if not is_peft else None,
revision = revision if not is_peft else None,
*args, **kwargs,
)

if resize_model_vocab is not None:
model.resize_token_embeddings(resize_model_vocab)
pass

# In case the model supports tagging, add the unsloth tag.
if hasattr(model, "add_model_tags"):
Expand Down Expand Up @@ -187,8 +188,16 @@ def from_pretrained(
pass

if is_peft:
# From https://github.com/huggingface/peft/issues/184
# Now add PEFT adapters
model = PeftModel.from_pretrained(model, old_model_name, token = token, revision = revision)
model.enable_input_require_grads()
model = PeftModel.from_pretrained(
model,
old_model_name,
token = token,
revision = revision,
is_trainable = True,
)
# Patch it as well!
model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)
pass
Expand Down

0 comments on commit 08424f0

Please sign in to comment.