feat: support streaming datasets #233

Draft · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion tests/test_sft_trainer.py
@@ -140,7 +140,7 @@ def test_parse_arguments_defaults(job_config):
    )
    assert str(model_args.torch_dtype) == "torch.bfloat16"
    assert model_args.use_flash_attn is False
-    assert training_args.save_strategy.value == "epoch"
+    assert training_args.save_strategy == "epoch"


def test_parse_arguments_peft_method(job_config):
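Presumably the assertion changed because the `__post_init__` override added in configs.py (below) replaces transformers' own `__post_init__`, so `save_strategy` is left as a plain string rather than coerced to an `IntervalStrategy` enum. For illustration (assuming transformers is installed), both forms compare equal to the string, since the enum is str-backed:

    from transformers.trainer_utils import IntervalStrategy

    # IntervalStrategy subclasses str, so both comparisons hold:
    assert IntervalStrategy.EPOCH == "epoch"        # enum form (old behavior)
    assert IntervalStrategy.EPOCH.value == "epoch"  # .value form used by the old assertion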
17 changes: 16 additions & 1 deletion tuning/config/configs.py
@@ -82,7 +82,9 @@ class DataArguments:


@dataclass
-class TrainingArguments(transformers.TrainingArguments):
+class TrainingArguments(
+    transformers.TrainingArguments
+):  # pylint: disable=too-many-instance-attributes
    cache_dir: Optional[str] = field(default=None)
    # optim: str = field(default=DEFAULT_OPTIMIZER)
    max_seq_length: int = field(
@@ -122,6 +124,19 @@ class TrainingArguments(transformers.TrainingArguments):
+ "Requires additional configs, see tuning.configs/tracker_configs.py"
},
)
+    streaming: bool = field(
+        default=False,
+        metadata={"help": "Set to True to stream data during training"},
+    )
+
+    def __post_init__(self):
+        # When using iterable datasets, a data dispatch strategy must be provided.
+        # With split_batches=True the data is fetched only by the rank 0 process
+        # and then distributed across the worker processes.
+        # Related: https://github.com/huggingface/accelerate/issues/2023
+        if self.streaming:
+            self.accelerator_config = {"split_batches": True}


@dataclass
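For reference, a minimal standalone sketch of what the streaming branch above effectively configures (assumes transformers >= 4.38, where `TrainingArguments` accepts an `accelerator_config` dict; all values here are illustrative):

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",      # illustrative path
        max_steps=100,         # iterable datasets need an explicit step budget
        accelerator_config={"split_batches": True},  # rank 0 fetches; batches are split across workers
    )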
15 changes: 14 additions & 1 deletion tuning/sft_trainer.py
@@ -62,7 +62,11 @@
    USER_ERROR_EXIT_CODE,
    write_termination_log,
)
-from tuning.utils.preprocessing_utils import get_data_collator, validate_data_args
+from tuning.utils.preprocessing_utils import (
+    get_data_collator,
+    validate_data_args,
+    validate_train_args,
+)


def train(
@@ -242,6 +246,8 @@ def train(

    # Validate if data args are set properly
    validate_data_args(data_args, packing)
+    validate_train_args(train_args=train_args)
+
    data_collator = get_data_collator(packing, data_args.response_template, tokenizer)

# load the data by parsing JSON
@@ -311,6 +317,13 @@
    }
    training_args = SFTConfig(**transformer_kwargs)

+    if train_args.streaming:
+        formatted_train_dataset = formatted_train_dataset.to_iterable_dataset()
+        if formatted_validation_dataset:
+            formatted_validation_dataset = (
+                formatted_validation_dataset.to_iterable_dataset()
+            )
+

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
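The conversion above relies on `datasets.Dataset.to_iterable_dataset`; a tiny self-contained illustration with toy data (assumes datasets >= 2.11, where the method was introduced):

    from datasets import Dataset

    ds = Dataset.from_dict({"text": ["a", "b", "c"]})  # map-style dataset
    streamed = ds.to_iterable_dataset()                # lazy IterableDataset
    print(next(iter(streamed)))                        # {'text': 'a'}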
7 changes: 7 additions & 0 deletions tuning/utils/preprocessing_utils.py
@@ -25,6 +25,13 @@
from tuning.config import configs


+def validate_train_args(train_args: configs.TrainingArguments):
+    if train_args.streaming:
+        # IterableDatasets do not yet support epoch-based training
+        if train_args.max_steps == -1:
+            raise ValueError("IterableDatasets only support max_steps for training")


def validate_data_args(data_args: configs.DataArguments, packing: bool):

    assert isinstance(
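A usage sketch for the new validator (assumes this repo's `tuning` package is importable; the values are hypothetical):

    from tuning.config import configs
    from tuning.utils.preprocessing_utils import validate_train_args

    args = configs.TrainingArguments(output_dir="out", streaming=True)  # max_steps defaults to -1
    validate_train_args(train_args=args)  # raises ValueError until max_steps is set explicitly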