Local inference (#27)
* First pass at load and run

* Return first decoded entry, load docstring

* Initial implementation for adapter config patcher

* Add adapter config patcher to util init

* First pass at script for loading and running a model with CLI

* Add base model override to cli for inference script

* Return immediately if no overrides are given

* Add adapter config overrides to inference script

* CLI support for processing one or more texts

* Docstring updates for load / run

* Refactor train into tuned model classmethod

* Move inference CLI to a separate script

* Infer device for inference

* adapter config docstrings

Signed-off-by: Alex-Brooks <[email protected]>

* Add inference instructions

Signed-off-by: Alex-Brooks <[email protected]>

* Add max new tokens as an arg to run inference

Signed-off-by: Alex-Brooks <[email protected]>

* Split inference and tuning back apart

Signed-off-by: Alex-Brooks <[email protected]>

* Consolidate inference cli and tuned model class

Signed-off-by: Alex-Brooks <[email protected]>

* Consolidate adapter config patcher into inference script

Signed-off-by: Alex-Brooks <[email protected]>

* Move inference script outside of tuning package

Signed-off-by: Alex-Brooks <[email protected]>

* Update readme inference instructions

Signed-off-by: Alex-Brooks <[email protected]>

---------

Signed-off-by: Alex-Brooks <[email protected]>
alex-jw-brooks authored Feb 6, 2024
1 parent 304b179 commit fc07060
Showing 3 changed files with 307 additions and 25 deletions.
50 changes: 50 additions & 0 deletions README.md
@@ -113,6 +113,56 @@ tuning/sft_trainer.py \

For `GPTBigCode` models, Hugging Face has enabled Flash v2 and one can simply replace the `'LlamaDecoderLayer'` with `'GPTBigCodeBlock'` in `tuning/config/fsdp_config.json` for proper sharding of the model.
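
One way to make that swap from the command line is shown below (a convenience sketch using GNU `sed`; you can of course edit the file by hand instead):

```bash
sed -i 's/LlamaDecoderLayer/GPTBigCodeBlock/' tuning/config/fsdp_config.json
```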

## Inference
Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options, run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time.

### Running a single example
If you want to run a single example through a model, you can pass it with the `--text` flag.

```bash
python scripts/run_inference.py \
--model my_checkpoint \
--text "This is a text the model will run inference on" \
--max_new_tokens 50 \
--out_file result.json
```

### Running multiple examples
To run multiple examples, pass the path to a file containing one source text per line. Example:

Contents of `source_texts.txt`
```
This is the first text to be processed.
And this is the second text to be processed.
```

```bash
python scripts/run_inference.py \
--model my_checkpoint \
--text_file source_texts.txt \
--max_new_tokens 50 \
--out_file result.json
```

### Inference Results Format
After running the inference script, the specified `--out_file` will be a JSON file in which each entry contains the original input string and the predicted output string, as shown below. Note that due to the implementation of `.generate()` in Transformers, the input string will generally be contained in the output string as well.
```
[
    {
        "input": "{{Your input string goes here}}",
        "output": "{{Generated result for your input string goes here}}"
    },
    ...
]
```
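
To spot-check the results, you can query the output file with standard tools; for example, with `jq` (assuming it is installed and you wrote results to `result.json`), the following prints the first predicted output:

```bash
jq -r '.[0].output' result.json
```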

### Changing the Base Model for Inference
If you tuned a model using a *local* base model, then a machine-specific path will be saved by Peft into your checkpoint's `adapter_config.json`. This can be problematic if you are running inference on a different machine than the one you used for tuning.

As a workaround, the inference CLI provides a `--base_model_name_or_path` argument, which lets you pass a different base model to run inference with. This will patch the `base_model_name_or_path` in your checkpoint's `adapter_config.json` while the model is being loaded, and restore it to its original value afterwards. Alternatively, if you like, you can change the config's value yourself.
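
For example, a typical invocation looks like the following (the checkpoint and base model paths are placeholders for your own):

```bash
python scripts/run_inference.py \
--model my_checkpoint \
--base_model_name_or_path path/or/hub-id/of/base/model \
--text "This is a text the model will run inference on" \
--max_new_tokens 50 \
--out_file result.json
```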

NOTE: This can also be an issue for tokenizers (via the `tokenizer_name_or_path` config entry). We currently do not allow tokenizer patching, since the tokenizer can also be explicitly configured within the base model and checkpoint model, but we may choose to expose an override for `tokenizer_name_or_path` in the future.

## Validation

We can use [`lm-evaluation-harness`](https://github.com/EleutherAI/lm-evaluation-harness) from EleutherAI to evaluate the tuned model. For example, for the Llama-13B model, using the above command and the checkpoint at the end of epoch 5, we measured an MMLU score of `53.9`, compared to `52.8` for the base model.
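
For reference, a rough sketch of how such an evaluation might be launched with a recent version of the harness is shown below; the exact CLI flags and the `hf` backend name depend on the harness version you have installed, and the model path is a placeholder. For a PEFT checkpoint, recent harness versions also accept a `peft=<adapter_path>` entry in `--model_args`, but check your installed version.

```bash
lm_eval --model hf \
--model_args pretrained=path/to/tuned/model \
--tasks mmlu \
--num_fewshot 5 \
--output_path eval_results
```
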
234 changes: 234 additions & 0 deletions scripts/run_inference.py
@@ -0,0 +1,234 @@
"""CLI for running loading a tuned model and running one or more inference calls on it.
NOTE: For the moment, this script is intentionally written to contain all dependencies for two
reasons:
- to keep it portable and not deal with managing multiple local packages.
- because we don't currently plan on supporting inference as a library; i.e., this is only for
testing.
If these things change in the future, we should consider breaking it up.
"""
import argparse
import json
import os
from peft import AutoPeftModelForCausalLM
import torch
from tqdm import tqdm
from transformers import AutoTokenizer


### Utilities
class AdapterConfigPatcher:
"""Adapter config patcher is a context manager for patching overrides into a config;
it will locate the adapter_config.json in a file and patch a dict of provided overrides
when inside the dict block, and restore them when it leaves. This DOES actually write to
the file, so it's NOT safe to use in parallel inference runs.
Example:
overrides = {"base_model_name_or_path": "foo"}
with AdapterConfigPatcher(checkpoint_path, overrides):
# When loaded in this block, the config's base_model_name_or_path is "foo"
peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
"""
def __init__(self, checkpoint_path: str, overrides: dict):
self.checkpoint_path = checkpoint_path
self.overrides = overrides
self.config_path = AdapterConfigPatcher._locate_adapter_config(self.checkpoint_path)
# Values that we will patch later on
self.patched_values = {}

@staticmethod
def _locate_adapter_config(checkpoint_path: str) -> str:
"""Given a path to a tuned checkpoint, tries to find the adapter_config
that will be loaded through the Peft auto model API.
Args:
checkpoint_path: str
Checkpoint model, which presumably holds an adapter config.
Returns:
str
Path to the located adapter_config file.
"""
config_path = os.path.join(checkpoint_path, "adapter_config.json")
if not os.path.isfile(config_path):
raise FileNotFoundError(f"Could not locate adapter config: {config_path}")
return config_path

def _apply_config_changes(self, overrides: dict) -> dict:
"""Applies a patch to a config with some override dict, returning the values
that we patched over so that they may be restored later.
Args:
overrides: dict
Overrides to write into the adapter_config.json. Currently, we
require all override keys to be defined in the config being patched.
Returns:
dict
Dict containing the values that we patched over.
"""
# If we have no overrides, this context manager is a noop; no need to do anything
if not overrides:
return {}
with open(self.config_path, "r") as config_file:
adapter_config = json.load(config_file)
overridden_values = self._get_old_config_values(adapter_config, overrides)
adapter_config = {**adapter_config, **overrides}
with open(self.config_path, "w") as config_file:
json.dump(adapter_config, config_file, indent=4)
return overridden_values

@staticmethod
def _get_old_config_values(adapter_config: dict, overrides: dict) -> dict:
"""Grabs the existing config subdict that we are going to clobber from the
loaded adapter_config.
Args:
adapter_config: dict
Adapter config whose values we are interested in patching.
overrides: dict
Dict of overrides, containing keys definined in the adapter_config with
new values.
Returns:
dict
The subdictionary of adapter_config, containing the keys in overrides,
with the values that we are going to replace.
"""
# For now, we only expect to patch the base model; this may change in the future,
# but ensure that anything we are patching is defined in the original config
if not set(overrides.keys()).issubset(set(adapter_config.keys())):
raise KeyError("Adapter config overrides must be set in the config being patched")
return {key: adapter_config[key] for key in overrides}

def __enter__(self):
"""Apply the config overrides and saved the patched values."""
self.patched_values = self._apply_config_changes(self.overrides)

def __exit__(self, exc_type, exc_value, exc_tb):
"""Apply the patched values over our exported overrides."""
self._apply_config_changes(self.patched_values)


### Funcs for loading and running models
class TunedCausalLM:
def __init__(self, model, tokenizer, device):
self.peft_model = model
self.tokenizer = tokenizer
self.device = device

@classmethod
def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "TunedCausalLM":
"""Loads an instance of this model.
Args:
checkpoint_path: str
Checkpoint model to be loaded, which is a directory containing an
adapter_config.json.
base_model_name_or_path: str [Default: None]
Override for the base model to be used.
By default, the paths for the base model and tokenizer are contained within the adapter
config of the tuned model. Note that in this context, a path may refer to a model to be
downloaded from HF hub, or a local path on disk, the latter of which we must be careful
with when using a model that was written on a different device.
Returns:
TunedCausalLM
An instance of this class on which we can run inference.
"""
overrides = {"base_model_name_or_path": base_model_name_or_path} if base_model_name_or_path is not None else {}
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
# Apply the configs to the adapter config of this model; if no overrides
# are provided, then the context manager doesn't have any effect.
with AdapterConfigPatcher(checkpoint_path, overrides):
try:
peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
except OSError as e:
print("Failed to initialize checkpoint model!")
raise e
device = "cuda" if torch.cuda.is_available() else None
print(f"Inferred device: {device}")
peft_model.to(device)
return cls(peft_model, tokenizer, device)


def run(self, text: str, *, max_new_tokens: int) -> str:
"""Runs inference on an instance of this model.
Args:
text: str
Text on which we want to run inference.
max_new_tokens: int
Max new tokens to use for inference.
Returns:
str
Text generation result.
"""
tok_res = self.tokenizer(text, return_tensors="pt")
input_ids = tok_res.input_ids.to(self.device)

peft_outputs = self.peft_model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)
decoded_result = self.tokenizer.batch_decode(peft_outputs, skip_special_tokens=False)[0]
return decoded_result


### Main & arg parsing
def main():
parser = argparse.ArgumentParser(
description="Loads a tuned model and runs an inference call(s) through it"
)
parser.add_argument("--model", help="Path to tuned model to be loaded", required=True)
parser.add_argument(
"--out_file",
help="JSON file to write results to",
default="inference_result.json",
)
parser.add_argument(
"--base_model_name_or_path",
help="Override for base model to be used [default: value in model adapter_config.json]",
default=None
)
parser.add_argument(
"--max_new_tokens",
help="max new tokens to use for inference",
type=int,
default=20,
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--text", help="Text to run inference on")
group.add_argument("--text_file", help="File to be processed where each line is a text to run inference on")
args = parser.parse_args()
# If we passed a file, check if it exists before doing anything else
if args.text_file and not os.path.isfile(args.text_file):
raise FileNotFoundError(f"Text file: {args.text_file} does not exist!")

# Load the model
loaded_model = TunedCausalLM.load(
checkpoint_path=args.model,
base_model_name_or_path=args.base_model_name_or_path,
)

# Run inference on the text; if multiple were provided, process them all
if args.text:
texts = [args.text]
else:
with open(args.text_file, "r") as text_file:
texts = [line.strip() for line in text_file.readlines()]

# TODO: we should add batch inference support
results = [
{"input": text, "output": loaded_model.run(text, max_new_tokens=args.max_new_tokens)}
for text in tqdm(texts)
]

# Export the results to a file
with open(args.out_file, "w") as out_file:
json.dump(results, out_file, sort_keys=True, indent=4)

print(f"Exported results to: {args.out_file}")

if __name__ == "__main__":
main()
48 changes: 23 additions & 25 deletions tuning/sft_trainer.py
@@ -1,25 +1,21 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaTokenizerFast, GPTNeoXTokenizerFast, GPT2Tokenizer
import os
from typing import Optional, Union

import datasets
import fire
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import transformers
from peft.utils.other import fsdp_auto_wrap_policy
import torch
import datasets

from tuning.data import tokenizer_data_utils
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaTokenizerFast, GPTNeoXTokenizerFast, GPT2Tokenizer
from transformers.utils import logging
from transformers import TrainerCallback
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from tuning.aim_loader import get_aimstack_callback
from tuning.config import configs, peft_config
from tuning.data import tokenizer_data_utils
from tuning.utils.config_utils import get_hf_peft_config
from tuning.utils.data_type_utils import get_torch_dtype

from tuning.aim_loader import get_aimstack_callback
from transformers.utils import logging
from dataclasses import asdict
from typing import Optional, Union

from peft import LoraConfig
import os
from transformers import TrainerCallback
from peft.utils.other import fsdp_auto_wrap_policy

class PeftSavingCallback(TrainerCallback):
def on_save(self, args, state, control, **kwargs):
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
@@ -29,21 +25,22 @@ def on_save(self, args, state, control, **kwargs):
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))



def train(
model_args: configs.ModelArguments,
data_args: configs.DataArguments,
train_args: configs.TrainingArguments,
peft_config: Optional[Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]] = None,
):
model_args: configs.ModelArguments,
data_args: configs.DataArguments,
train_args: configs.TrainingArguments,
peft_config: Optional[Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]] = None,
):
"""Call the SFTTrainer
Args:
model_args: tuning.config.configs.ModelArguments
data_args: tuning.config.configs.DataArguments
train_args: tuning.config.configs.TrainingArguments
peft_config: peft_config.LoraConfig for Lora tuning | \
peft_config.PromptTuningConfig for prompt tuning | \
None for fine tuning
peft_config.PromptTuningConfig for prompt tuning | \
None for fine tuning
The peft configuration to pass to trainer
"""
run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
@@ -62,7 +59,7 @@ def train(
train_args.fsdp_config = {'xla':False}

task_type = "CAUSAL_LM"
model = transformers.AutoModelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=train_args.cache_dir,
torch_dtype=get_torch_dtype(model_args.torch_dtype),
@@ -74,7 +71,7 @@ def train(
model.gradient_checkpointing_enable()

# TODO: Move these to a config as well
tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=train_args.cache_dir,
use_fast = True
@@ -170,6 +167,7 @@ def train(
trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
trainer.train()


def main(**kwargs):
parser = transformers.HfArgumentParser(dataclass_types=(configs.ModelArguments,
configs.DataArguments,