diff --git a/hfppl/llms.py b/hfppl/llms.py
index 1e554ec..1dcb334 100644
--- a/hfppl/llms.py
+++ b/hfppl/llms.py
@@ -1,7 +1,7 @@
 """Utilities for working with HuggingFace language models, including caching and auto-batching."""
 
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import asyncio
 import string
 
@@ -266,18 +266,22 @@ def from_pretrained(cls, model_id, auth_token=False, load_in_8bit=True):
         Returns:
             model (hfppl.llms.CachedCausalLM): the LLaMPPL-compatible interface to the HuggingFace model.
         """
+        bnb_config = BitsAndBytesConfig(load_in_8bit=load_in_8bit)
+
         if not auth_token:
             tok = AutoTokenizer.from_pretrained(model_id)
             mod = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map="auto", load_in_8bit=load_in_8bit
+                model_id,
+                device_map="auto",
+                quantization_config=bnb_config,
             )
         else:
-            tok = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
+            tok = AutoTokenizer.from_pretrained(model_id, token=auth_token)
             mod = AutoModelForCausalLM.from_pretrained(
                 model_id,
-                use_auth_token=auth_token,
+                token=auth_token,
                 device_map="auto",
-                load_in_8bit=load_in_8bit,
+                quantization_config=bnb_config,
             )
 
         return CachedCausalLM(mod, tok)
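
For context, a minimal usage sketch of the updated entry point follows. The model id and token string are placeholders (neither appears in the patch); the public call signature of from_pretrained is unchanged, only the internals now go through BitsAndBytesConfig and the token= keyword.

# Hypothetical usage; the model id and token value below are placeholders.
from hfppl.llms import CachedCausalLM

lm = CachedCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    auth_token="hf_...",   # forwarded to the tokenizer and model as token=
    load_in_8bit=True,     # wrapped in BitsAndBytesConfig(load_in_8bit=True) internally
)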