Tweak ArmorRM implementation, add args to CLI (#194)
natolambert authored Oct 4, 2024
1 parent e42d40f commit c8f3fd1
Showing 4 changed files with 64 additions and 10 deletions.
4 changes: 2 additions & 2 deletions rewardbench/models/__init__.py
@@ -152,8 +152,8 @@
         "model_builder": AutoModelForSequenceClassification.from_pretrained,
         "pipeline_builder": ArmoRMPipeline,
         "quantized": False,
-        "custom_dialogue": True,
-        "model_type": "Custom Classifier",
+        "custom_dialogue": False,
+        "model_type": "Sequence Classifier",
         "torch_dtype": torch.bfloat16,
     },
     "Ray2333/GRM-Gemma-2B-sftreg": {
29 changes: 29 additions & 0 deletions rewardbench/models/armorm.py
@@ -5,6 +5,35 @@


 class ArmoRMPipeline:
+    def __init__(self, task, model, tokenizer):
+        self.task = task
+        self.model = model.eval()
+        self.tokenizer = tokenizer
+
+    def __call__(self, samples, return_inputs=False, **kwargs):
+        _ = kwargs.get("batch_size", 1)
+        truncation = kwargs.get("truncation", True)
+        padding = kwargs.get("padding", True)
+        max_length = kwargs.get("max_length", 2048)
+        inputs = self.tokenizer(
+            samples,
+            truncation=truncation,
+            max_length=max_length,
+            padding=padding,
+            # return_special_tokens_mask=True,
+            return_tensors="pt",
+        ).to("cuda")
+
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        if return_inputs:
+            return outputs.logits, inputs
+        else:
+            return outputs.logits
+
+
+# Moved to newer implementation that doesn't require "Custom Dialogue" tag
+class LegacyArmoRMPipeline:
     def __init__(self, task, model, tokenizer):
         self.task = task
         self.model = model
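For context, a minimal sketch of how the new ArmoRMPipeline can be driven outside the harness. The checkpoint name and chat content below are illustrative, and a CUDA device is assumed, since __call__ moves inputs to "cuda":

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from rewardbench.models.armorm import ArmoRMPipeline

model_id = "RLHFlow/ArmoRM-Llama3-8B-v0.1"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")

pipe = ArmoRMPipeline("text-classification", model, tokenizer)

# the harness passes chat-template-formatted strings, not raw text
sample = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
    tokenize=False,
)
scores = pipe([sample], max_length=2048)  # logits tensor, one score per sample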
38 changes: 31 additions & 7 deletions rewardbench/rewardbench.py
@@ -21,7 +21,7 @@
 import time
 from dataclasses import dataclass
 from pprint import pformat
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Literal, Optional, Union

 import numpy as np
 import pkg_resources
@@ -41,6 +41,7 @@
     REWARD_MODEL_CONFIG,
     check_tokenizer_chat_template,
     load_and_process_dataset,
+    torch_dtype_mapping,
 )


@@ -85,6 +86,10 @@ class Args:
     """The batch size to use."""
     max_length: int = 512
     """The max length to use."""
+    torch_dtype: Literal["float16", "bfloat16", "float32", "float64"] = "float16"
+    """PyTorch dtype (default: float16)"""
+    attn_implementation: Optional[Literal["eager", "sdpa", "flash_attention_2"]] = None
+    """Attention implementation to use (default: None)"""

     # system args
     load_json: bool = False
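With these two fields on Args, the new options pass straight through the CLI; assuming the package's rewardbench entry point (model flag illustrative), an invocation looks like:

rewardbench --model Qwen/Qwen1.5-0.5B --torch_dtype bfloat16 --attn_implementation flash_attention_2

Note that flash_attention_2 additionally requires the flash-attn package to be installed and a half-precision dtype.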
@@ -274,11 +279,22 @@ def rewardbench(args: Args):
     custom_dialogue = config["custom_dialogue"]
     pipeline_builder = config["pipeline_builder"]
     _ = config["model_type"]
+    torch_dtype = config.get("torch_dtype", None)
     if custom_dialogue:
         raise NotImplementedError("Custom dialogue not implemented yet for simpler data formatting.")

     model_builder = config["model_builder"]

+    # Handle datatype
+    args.torch_dtype = torch_dtype_mapping(args.torch_dtype)
+    # if datatype is not in config (the default), check args
+    if torch_dtype is None:
+        # if datatype is bfloat16, then manually turn off quantization (done with bitsandbytes)
+        if args.torch_dtype == torch.bfloat16:
+            quantized = False
+            logger.info("Disabling quantization for bfloat16 datatype")
+        torch_dtype = args.torch_dtype
+
     #########################
     # load dataset
     #########################
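torch_dtype_mapping itself is not shown in this diff; presumably it resolves the CLI string to a torch dtype, along these lines (a sketch, not the repository's code):

import torch

def torch_dtype_mapping(name: str) -> torch.dtype:
    # map the CLI string onto the corresponding torch dtype
    mapping = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
        "float64": torch.float64,
    }
    return mapping[name]

# e.g. torch_dtype_mapping("bfloat16") -> torch.bfloat16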
@@ -344,7 +360,7 @@

     model_kwargs = {
         "load_in_8bit": True,
-        "device_map": "auto",
+        "device_map": "auto" if torch.cuda.is_available() else "cpu",
         "torch_dtype": torch.float16 if torch.cuda.is_available() else None,
     }
     model = model_builder(
@@ -408,11 +424,19 @@ def rewardbench(args: Args):
         model_kwargs = {
             "load_in_8bit": True,
             "device_map": {"": current_device},
-            "torch_dtype": torch.float16 if torch.cuda.is_available() else None,
+            "torch_dtype": torch_dtype if torch.cuda.is_available() else None,
         }
     else:
-        # note, device map auto does not work for quantized models
-        model_kwargs = {"device_map": "auto"}
+        # note, device map auto does not work for bitsandbytes quantized models
+        model_kwargs = {
+            "device_map": "auto",
+            "torch_dtype": torch_dtype,
+        }
+
+    # if attn_implementation is not specified, this falls back to Hugging Face's default
+    # strategy (which chooses between sdpa and eager depending on pytorch version)
+    if args.attn_implementation:
+        model_kwargs["attn_implementation"] = args.attn_implementation

     model = model_builder(
         args.model, **model_kwargs, revision=args.revision, trust_remote_code=args.trust_remote_code
@@ -472,8 +496,8 @@ def rewardbench(args: Args):
             score_rejected_batch = [result["score"] for result in rewards_rejected]
         # for classes that directly output scores (custom code)
         else:
-            score_chosen_batch = rewards_chosen.cpu().numpy().tolist()
-            score_rejected_batch = rewards_rejected.cpu().numpy().tolist()
+            score_chosen_batch = rewards_chosen.float().cpu().numpy().tolist()
+            score_rejected_batch = rewards_rejected.float().cpu().numpy().tolist()

         # log results
         [
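The .float() upcast above matters because NumPy has no bfloat16 dtype, so calling .numpy() directly on a bf16 reward tensor raises; a minimal reproduction:

import torch

t = torch.tensor([0.5], dtype=torch.bfloat16)
# t.cpu().numpy()          # TypeError: Got unsupported ScalarType BFloat16
t.float().cpu().numpy()    # array([0.5], dtype=float32)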
3 changes: 2 additions & 1 deletion scripts/run_rm.py
@@ -153,6 +153,7 @@ def main():
     model_builder = config["model_builder"]
     pipeline_builder = config["pipeline_builder"]
     torch_dtype = config.get("torch_dtype", None)
+
     # if datatype is not in config (the default), check args
     if torch_dtype is None:
         # if datatype is bfloat16, then manually turn off quantization (done with bitsandbytes)
@@ -211,7 +212,7 @@
         }
     else:
         model_kwargs = {
-            "device_map": "auto",
+            "device_map": "auto" if torch.cuda.is_available() else "cpu",
             "torch_dtype": torch_dtype,
         }
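The dtype/quantization resolution now shared by both entry points, pulled out as a standalone sketch (the function name is hypothetical; names mirror the diff, but this is not the repository's code verbatim):

import torch

def resolve_dtype(config_dtype, args_dtype, quantized):
    # a dtype set in the model config wins; otherwise fall back to the CLI value
    if config_dtype is None:
        # 8-bit bitsandbytes loading is disabled for bfloat16, per the diff's comment
        if args_dtype == torch.bfloat16:
            quantized = False
        config_dtype = args_dtype
    return config_dtype, quantized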
