Updated hyperparameters for finetuning
* Updated CLI hyperparameters for fine tuning

* minor fix

* Comment out parameters that have not yet been implemented

* minor fix: default value for n_epochs
orangetin authored Jun 29, 2023
1 parent 71277be commit 1faa861
Showing 2 changed files with 76 additions and 70 deletions.
113 changes: 59 additions & 54 deletions src/together/commands/finetune.py
@@ -33,13 +33,13 @@ def _add_create(parser: argparse._SubParsersAction[argparse.ArgumentParser]) ->
         required=True,
         type=str,
     )
-    create_finetune_parser.add_argument(
-        "--validation-file",
-        "-v",
-        default=None,
-        help="The ID of an uploaded file that contains validation data.",
-        type=str,
-    )
+    # create_finetune_parser.add_argument(
+    #     "--validation-file",
+    #     "-v",
+    #     default=None,
+    #     help="The ID of an uploaded file that contains validation data.",
+    #     type=str,
+    # )
     create_finetune_parser.add_argument(
         "--model",
         "-m",
@@ -57,53 +57,57 @@ def _add_create(parser: argparse._SubParsersAction[argparse.ArgumentParser]) ->
     create_finetune_parser.add_argument(
         "--batch-size",
         "-b",
-        default=None,
+        default=32,
         help="The batch size to use for training.",
         type=int,
     )
     create_finetune_parser.add_argument(
-        "--learning-rate-multiplier",
-        "-lrm",
-        default=None,
+        "--learning-rate",
+        "-lr",
+        default=0.00001,
         help="The learning rate multiplier to use for training.",
         type=float,
     )
-    create_finetune_parser.add_argument(
-        "--prompt-loss-weight",
-        "-plw",
-        default=0.01,
-        help="The weight to use for loss on the prompt tokens.",
-        type=float,
-    )
-    create_finetune_parser.add_argument(
-        "--compute-classification-metrics",
-        "-ccm",
-        default=False,
-        action="store_true",
-        help="Calculate classification-specific metrics using the validation set.",
-    )
-    create_finetune_parser.add_argument(
-        "--classification-n-classes",
-        "-cnc",
-        default=None,
-        help="The number of classes in a classification task.",
-        type=int,
-    )
-    create_finetune_parser.add_argument(
-        "--classification-positive-class",
-        "-cpc",
-        default=None,
-        help="The positive class in binary classification.",
-        type=str,
-    )
-    create_finetune_parser.add_argument(
-        "--classification-betas",
-        "-cb",
-        default=None,
-        help="Calculate F-beta scores at the specified beta values.",
-        nargs="+",
-        type=float,
-    )
+    # create_finetune_parser.add_argument(
+    #     "--warmup-steps",
+    #     "-ws",
+    #     default=0,
+    #     help="Warmup steps",
+    #     type=int,
+    # )
+    # create_finetune_parser.add_argument(
+    #     "--train-warmup-steps",
+    #     "-tws",
+    #     default=0,
+    #     help="Train warmup steps",
+    #     type=int,
+    # )
+    # create_finetune_parser.add_argument(
+    #     "--sequence-length",
+    #     "-sl",
+    #     default=2048,
+    #     help="Max sequence length",
+    #     type=int,
+    # )
+    # create_finetune_parser.add_argument(
+    #     "--seed",
+    #     default=42,
+    #     help="Training seed",
+    #     type=int,
+    # )
+    # create_finetune_parser.add_argument(
+    #     "--fp32",
+    #     help="Enable FP32 training. Defaults to false (FP16 training).",
+    #     default=False,
+    #     action="store_true",
+    # )
+    # create_finetune_parser.add_argument(
+    #     "--checkpoint-steps",
+    #     "-b",
+    #     default=0,
+    #     help="Number of steps between each checkpoint. Defaults to 0 = checkpoints per epoch.",
+    #     type=int,
+    # )
     create_finetune_parser.add_argument(
         "--suffix",
         "-s",
@@ -244,16 +248,17 @@ def _run_create(args: argparse.Namespace) -> None:
 
     response = finetune.create_finetune(
         training_file=args.training_file, # training file_id
-        validation_file=args.validation_file, # validation file_id
+        # validation_file=args.validation_file, # validation file_id
         model=args.model,
         n_epochs=args.n_epochs,
         batch_size=args.batch_size,
-        learning_rate_multiplier=args.learning_rate_multiplier,
-        prompt_loss_weight=args.prompt_loss_weight,
-        compute_classification_metrics=args.compute_classification_metrics,
-        classification_n_classes=args.classification_n_classes,
-        classification_positive_class=args.classification_positive_class,
-        classification_betas=args.classification_betas,
+        learning_rate=args.learning_rate,
+        # warmup_steps=args.warmup_steps,
+        # train_warmup_steps=args.train_warmup_steps,
+        # seq_length=args.sequence_length,
+        # seed=args.seed,
+        # fp16=not args.fp32,
+        # checkpoint_steps=args.checkpoint_steps,
         suffix=args.suffix,
     )
 
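As a rough illustration of what the new CLI defaults mean in practice, the sketch below builds a standalone argparse parser that mirrors two of the flags changed above. It is not the repository's parser; the program name is a placeholder, and only flags visible in this diff are reproduced.

import argparse

# Standalone parser mirroring the updated --batch-size and --learning-rate flags.
parser = argparse.ArgumentParser(prog="create_finetune_example")
parser.add_argument("--batch-size", "-b", default=32, type=int,
                    help="The batch size to use for training.")
parser.add_argument("--learning-rate", "-lr", default=0.00001, type=float,
                    help="Learning rate (flag renamed from --learning-rate-multiplier).")

args = parser.parse_args([])   # no flags passed on the command line
print(args.batch_size)         # 32 (previously the default was None)
print(args.learning_rate)      # 1e-05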
33 changes: 17 additions & 16 deletions src/together/finetune.py
@@ -29,30 +29,31 @@ def __init__(
     def create_finetune(
         self,
         training_file: str, # training file_id
-        validation_file: Optional[str] = None, # validation file_id
+        # validation_file: Optional[str] = None, # validation file_id
         model: Optional[str] = None,
-        n_epochs: Optional[int] = 4,
-        batch_size: Optional[int] = None,
-        learning_rate_multiplier: Optional[float] = None,
-        prompt_loss_weight: Optional[float] = 0.01,
-        compute_classification_metrics: Optional[bool] = False,
-        classification_n_classes: Optional[int] = None,
-        classification_positive_class: Optional[str] = None,
-        classification_betas: Optional[List[Any]] = None,
+        n_epochs: Optional[int] = 1,
+        batch_size: Optional[int] = 32,
+        learning_rate: Optional[float] = 0.00001,
+        # warmup_steps: Optional[int] = 0,
+        # train_warmup_steps: Optional[int] = 0,
+        # seq_length: Optional[int] = 2048,
+        # seed: Optional[int] = 42,
+        # fp16: Optional[bool] = True,
+        # checkpoint_steps: Optional[int] = None,
         suffix: Optional[str] = None,
     ) -> Dict[Any, Any]:
         parameter_payload = {
             "training_file": training_file,
-            "validation_file": validation_file,
+            # "validation_file": validation_file,
             "model": model,
             "n_epochs": n_epochs,
             "batch_size": batch_size,
-            "learning_rate_multiplier": learning_rate_multiplier,
-            "prompt_loss_weight": prompt_loss_weight,
-            "compute_classification_metrics": compute_classification_metrics,
-            "classification_n_classes": classification_n_classes,
-            "classification_positive_class": classification_positive_class,
-            "classification_betas": classification_betas,
+            "learning_rate": learning_rate,
+            # "warmup_steps": warmup_steps,
+            # "train_warmup_steps": train_warmup_steps,
+            # "seq_length": seq_length,
+            # "seed": seed,
+            # "fp16": fp16,
             "suffix": suffix,
         }
 
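For reference, a minimal sketch of the request payload that the updated create_finetune() builds when called with only a training file and a model, so the new defaults apply. The file ID and model name are illustrative placeholders, not values from this diff, and the commented-out fields are omitted.

# Payload with the new defaults (n_epochs=1, batch_size=32, learning_rate=1e-5);
# "file-abc123" and the model name below are placeholders.
parameter_payload = {
    "training_file": "file-abc123",
    "model": "togethercomputer/RedPajama-INCITE-7B-Base",
    "n_epochs": 1,
    "batch_size": 32,
    "learning_rate": 0.00001,
    "suffix": None,
}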
