Skip to content

Commit

Permalink
feat: add flash attn to inference and eval scripts (#132)
Browse files Browse the repository at this point in the history
* add flash attn to inference and eval scripts

Signed-off-by: Anh-Uong <[email protected]>

* load model with torch_dtype bfloat16

Signed-off-by: Anh-Uong <[email protected]>

---------

Signed-off-by: Anh-Uong <[email protected]>
  • Loading branch information
anhuong authored May 2, 2024
1 parent c2f2f8c commit dd29d49
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
7 changes: 6 additions & 1 deletion scripts/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def parse_and_validate_args():
action="store_true",
)
parser.add_argument("--purge_results", action=argparse.BooleanOptionalAction)
parser.add_argument(
"--use_flash_attn",
help="Whether to load the model using Flash Attention 2",
action="store_true",
)
parsed_args = parser.parse_args()

print(f"Multiclass / multioutput delimiter: {parsed_args.delimiter}")
Expand Down Expand Up @@ -441,7 +446,7 @@ def export_experiment_info(

if __name__ == "__main__":
args = parse_and_validate_args()
tuned_model = TunedCausalLM.load(args.model)
tuned_model = TunedCausalLM.load(args.model, use_flash_attn=args.use_flash_attn)
eval_data = datasets.load_dataset(
"json", data_files=args.data_path, split=args.split
)
Expand Down
27 changes: 24 additions & 3 deletions scripts/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ def __init__(self, model, tokenizer, device):

@classmethod
def load(
cls, checkpoint_path: str, base_model_name_or_path: str = None
cls,
checkpoint_path: str,
base_model_name_or_path: str = None,
use_flash_attn: bool = False,
) -> "TunedCausalLM":
"""Loads an instance of this model.
Expand All @@ -152,6 +155,8 @@ def load(
adapter_config.json.
base_model_name_or_path: str [Default: None]
Override for the base model to be used.
use_flash_attn: bool [Default: False]
Whether to load the model using flash attention.
By default, the paths for the base model and tokenizer are contained within the adapter
config of the tuned model. Note that in this context, a path may refer to a model to be
Expand All @@ -173,14 +178,24 @@ def load(
try:
with AdapterConfigPatcher(checkpoint_path, overrides):
try:
model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
model = AutoPeftModelForCausalLM.from_pretrained(
checkpoint_path,
attn_implementation="flash_attention_2"
if use_flash_attn
else None,
torch_dtype=torch.bfloat16 if use_flash_attn else None,
)
except OSError as e:
print("Failed to initialize checkpoint model!")
raise e
except FileNotFoundError:
print("No adapter config found! Loading as a merged model...")
# Unable to find the adapter config; fall back to loading as a merged model
model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
model = AutoModelForCausalLM.from_pretrained(
checkpoint_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=torch.bfloat16 if use_flash_attn else None,
)

device = "cuda" if torch.cuda.is_available() else None
print(f"Inferred device: {device}")
Expand Down Expand Up @@ -246,6 +261,11 @@ def main():
type=int,
default=20,
)
parser.add_argument(
"--use_flash_attn",
help="Whether to load the model using Flash Attention 2",
action="store_true",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--text", help="Text to run inference on")
group.add_argument(
Expand All @@ -261,6 +281,7 @@ def main():
loaded_model = TunedCausalLM.load(
checkpoint_path=args.model,
base_model_name_or_path=args.base_model_name_or_path,
use_flash_attn=args.use_flash_attn,
)

# Run inference on the text; if multiple were provided, process them all
Expand Down

0 comments on commit dd29d49

Please sign in to comment.