Commit

Merge pull request #15 from probcomp/alexlew-hf-updates
Use BitsAndBytesConfig instead of load_in_8bit and token instead of u…
alex-lew authored Jul 17, 2024
2 parents 4921bfe + a8d3c6d commit d72096e
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions hfppl/llms.py
@@ -1,7 +1,7 @@
 """Utilities for working with HuggingFace language models, including caching and auto-batching."""
 
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import asyncio
 import string
 
@@ -266,18 +266,22 @@ def from_pretrained(cls, model_id, auth_token=False, load_in_8bit=True):
         Returns:
             model (hfppl.llms.CachedCausalLM): the LLaMPPL-compatible interface to the HuggingFace model.
         """
+        bnb_config = BitsAndBytesConfig(load_in_8bit=load_in_8bit)
+
         if not auth_token:
             tok = AutoTokenizer.from_pretrained(model_id)
             mod = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map="auto", load_in_8bit=load_in_8bit
+                model_id,
+                device_map="auto",
+                quantization_config=bnb_config,
             )
         else:
-            tok = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
+            tok = AutoTokenizer.from_pretrained(model_id, token=auth_token)
             mod = AutoModelForCausalLM.from_pretrained(
                 model_id,
-                use_auth_token=auth_token,
+                token=auth_token,
                 device_map="auto",
-                load_in_8bit=load_in_8bit,
+                quantization_config=bnb_config,
             )
 
         return CachedCausalLM(mod, tok)
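
For context, a minimal usage sketch of the classmethod after this change (not part of the commit; the model id and access token below are placeholders):

# Sketch only: the model id and token values are placeholders, not from the commit.
from hfppl.llms import CachedCausalLM

# With this commit, from_pretrained builds an 8-bit BitsAndBytesConfig from the
# load_in_8bit flag and forwards the HuggingFace access token via `token`
# rather than the deprecated `use_auth_token`.
lm = CachedCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model id
    auth_token="hf_...",         # placeholder access token
    load_in_8bit=True,
)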
