Skip to content

Commit

Permalink
SFT improvements (labeling fixes, different packing implementations) (#1240)
Browse files Browse the repository at this point in the history

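* - Add alternative packing implementations (`unpacked` and `pack_until_overflow`) alongside the default `packed`
- Fix label datasets so that the validation and test splits can have them too, not only the training split
- Fix label masking in `_get_batch` so that it also includes the masking produced by `get_ltor_masks_and_position_ids`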

* Update arguments.py to use train_label_data_paths instead of label_data_paths

* - fix precommit
  • Loading branch information
dmahan93 authored Aug 27, 2024
1 parent 591563d commit c786367
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 54 deletions.
37 changes: 34 additions & 3 deletions megatron/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def build_the_dataset(
data_prefix,
name,
data_impl,
pack_impl,
allow_chopped,
num_samples,
seq_length,
seed,
Expand Down Expand Up @@ -83,6 +85,8 @@ def build_the_dataset(
num_samples,
seq_length,
seed,
pack_impl=pack_impl,
allow_chopped=allow_chopped,
build_index_mappings=build_index_mappings,
label_dataset=label_dataset,
)
Expand All @@ -93,6 +97,8 @@ def build_train_valid_test_datasets(
data_prefix,
use_shared_fs,
data_impl,
pack_impl,
allow_chopped,
splits_string,
train_valid_test_num_samples,
seq_length,
Expand Down Expand Up @@ -138,6 +144,8 @@ def build_dataset(index, name):
train_valid_test_num_samples[index],
seq_length,
seed,
pack_impl=pack_impl,
allow_chopped=allow_chopped,
use_shared_fs=use_shared_fs,
)
return dataset
Expand Down Expand Up @@ -204,12 +212,25 @@ def build_weighted_datasets(
):
# build individual datasets
train_datasets, valid_datasets, test_datasets = [], [], []
for i, (train_path, label_path, valid_path, test_path) in enumerate(
for i, (
train_path,
train_label_path,
valid_path,
valid_label_path,
test_path,
test_label_path,
) in enumerate(
zip_longest(
neox_args.train_data_paths,
neox_args.label_data_paths if neox_args.label_data_paths else [],
neox_args.train_label_data_paths
if neox_args.train_label_data_paths
else [],
neox_args.valid_data_paths,
neox_args.valid_label_data_paths
if neox_args.valid_label_data_paths
else [],
neox_args.test_data_paths,
neox_args.test_label_data_paths if neox_args.test_label_data_paths else [],
)
):
if train_path:
Expand All @@ -218,12 +239,14 @@ def build_weighted_datasets(
data_prefix=train_path,
name=f"train_{i}",
data_impl=neox_args.data_impl,
pack_impl=neox_args.pack_impl,
allow_chopped=neox_args.allow_chopped,
num_samples=train_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=label_path,
label_prefix=train_label_path,
)
)

Expand All @@ -233,11 +256,14 @@ def build_weighted_datasets(
data_prefix=valid_path,
name=f"valid_{i}",
data_impl=neox_args.data_impl,
pack_impl=neox_args.pack_impl,
allow_chopped=neox_args.allow_chopped,
num_samples=valid_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=valid_label_path,
)
)

Expand All @@ -247,11 +273,14 @@ def build_weighted_datasets(
data_prefix=test_path,
name=f"test_{i}",
data_impl=neox_args.data_impl,
pack_impl=neox_args.pack_impl,
allow_chopped=neox_args.allow_chopped,
num_samples=test_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=test_label_path,
)
)
return train_datasets, valid_datasets, test_datasets
Expand Down Expand Up @@ -414,6 +443,8 @@ def build_train_valid_test_data_iterators(neox_args):
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
pack_impl=neox_args.pack_impl,
allow_chopped=neox_args.allow_chopped,
)

# Build dataloaders.
Expand Down
188 changes: 152 additions & 36 deletions megatron/data/gpt2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,19 @@ def __init__(
num_samples,
seq_length,
seed,
pack_impl="packed",
allow_chopped=True,
build_index_mappings=True,
use_shared_fs=True,
label_dataset=None,
):

self.name = name
self.pack_impl = pack_impl
self.allow_chopped = allow_chopped
self.indexed_dataset = indexed_dataset
self.label_dataset = label_dataset
self.seq_length = seq_length

# Checks
assert np.min(documents) >= 0
Expand All @@ -56,10 +61,13 @@ def __init__(
data_prefix,
documents,
self.indexed_dataset.sizes,
self.label_dataset,
num_samples,
seq_length,
seed,
self.pack_impl,
use_shared_fs=use_shared_fs,
allow_chopped=self.allow_chopped,
)
self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1
self.sample_idx_len = self.sample_idx.shape[0] - 1
Expand Down Expand Up @@ -113,8 +121,38 @@ def __getitem__(self, idx):
samples.append(np.concatenate(sample_list))

if len(datasets) == 1:
if len(samples[0]) < (self.seq_length + 1):
# Pad with -100s so the masking function can ignore these.
samples[0] = np.pad(
samples[0],
(0, (self.seq_length + 1) - len(samples[0])),
mode="constant",
constant_values=-100,
)
elif len(samples[0]) > (self.seq_length + 1):
# Check for overflow and truncate.
samples[0] = samples[0][: (self.seq_length + 1)]
return {"text": np.array(samples[0], dtype=np.int64)}
else:
if len(samples[0]) < (self.seq_length + 1):
# Pad with 0s, can use any number since it's masked.
samples[0] = np.pad(
samples[0],
(0, (self.seq_length + 1) - len(samples[0])),
mode="constant",
constant_values=0,
)
# pad with -100s so we can mask it out
samples[1] = np.pad(
samples[1],
(0, (self.seq_length + 1) - len(samples[1])),
mode="constant",
constant_values=-100,
)
elif len(samples[0]) > (self.seq_length + 1):
# Check for overflow and truncate.
samples[0] = samples[0][: (self.seq_length + 1)]
samples[1] = samples[1][: (self.seq_length + 1)]
return {
"text": np.array(samples[0], dtype=np.int64),
"label": np.array(samples[1], dtype=np.int64),
Expand All @@ -132,10 +170,13 @@ def _build_index_mappings(
data_prefix,
documents,
sizes,
label_dataset,
num_samples,
seq_length,
seed,
packing_impl,
use_shared_fs=True,
allow_chopped=True,
):
"""Build doc-idx, sample-idx, and shuffle-idx.
doc-idx: is an array (ordered) of documents to be used in training.
Expand All @@ -155,6 +196,9 @@ def _build_index_mappings(
_filename += "_{}ns".format(num_samples)
_filename += "_{}sl".format(seq_length)
_filename += "_{}s".format(seed)
_filename += "_{}pi".format(packing_impl)
if allow_chopped:
_filename += "_ac"
doc_idx_filename = _filename + "_doc_idx.npy"
sample_idx_filename = _filename + "_sample_idx.npy"
shuffle_idx_filename = _filename + "_shuffle_idx.npy"
Expand All @@ -177,44 +221,116 @@ def _build_index_mappings(
)
# doc-idx.
start_time = time.time()
doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
print_rank_0(
" > elapsed time to build and save doc-idx mapping "
"(seconds): {:4f}".format(time.time() - start_time)
)
# sample-idx.
start_time = time.time()
# Use C++ implementation for speed.
from megatron.data import helpers

assert doc_idx.dtype == np.int32
assert sizes.dtype == np.int32

num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length
if 2 * (num_samples + 1) < np.iinfo(np.int32).max:
sample_idx = helpers.build_sample_idx_int32(
sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
if packing_impl == "packed":
doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
print_rank_0(
" > elapsed time to build and save doc-idx mapping "
"(seconds): {:4f}".format(time.time() - start_time)
)
else:
sample_idx = helpers.build_sample_idx_int64(
sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
# sample-idx.
start_time = time.time()
# Use C++ implementation for speed.
from megatron.data import helpers

assert doc_idx.dtype == np.int32
assert sizes.dtype == np.int32

num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length
if 2 * (num_samples + 1) < np.iinfo(np.int32).max:
sample_idx = helpers.build_sample_idx_int32(
sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
)
else:
sample_idx = helpers.build_sample_idx_int64(
sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
print_rank_0(
" > elapsed time to build and save sample-idx mapping "
"(seconds): {:4f}".format(time.time() - start_time)
)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
print_rank_0(
" > elapsed time to build and save sample-idx mapping "
"(seconds): {:4f}".format(time.time() - start_time)
)
# shuffle-idx.
start_time = time.time()
# -1 is due to data structure used to retrieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
print_rank_0(
" > elapsed time to build and save shuffle-idx mapping"
" (seconds): {:4f}".format(time.time() - start_time)
)
# shuffle-idx.
start_time = time.time()
# -1 is due to data structure used to retrieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
print_rank_0(
" > elapsed time to build and save shuffle-idx mapping"
" (seconds): {:4f}".format(time.time() - start_time)
)
elif packing_impl == "pack_until_overflow":
# Naively pack whole documents into a sample until adding the next one would exceed
# seq_length + 1 tokens, then start a new sample beginning with that document.
shuffle_idx = np.arange(num_samples) # Shuffle index around epochs
np_rng.shuffle(shuffle_idx)
sample_idx = []
doc_idx = []
# Iterate over files until we have enough samples.
temp_shuffle_idx = np.arange(len(documents))
np_rng.shuffle(temp_shuffle_idx)
running_length = 0
curr_shuffle_idx = 0
while len(sample_idx) < num_samples:
if not allow_chopped:
# +1 since we shift left/right by 1
if sizes[temp_shuffle_idx[curr_shuffle_idx]] > seq_length + 1:
curr_shuffle_idx += 1
continue
# First, check if we need to skip this item...
if label_dataset is not None:
if np.all(
label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[
: seq_length + 1
]
== -100
):
curr_shuffle_idx += 1
continue
doc_length = sizes[temp_shuffle_idx[curr_shuffle_idx]]
if running_length == 0:
sample_idx.append(np.array([len(doc_idx), 0]))
doc_idx.append(temp_shuffle_idx[curr_shuffle_idx])
running_length += doc_length
else:
if running_length + doc_length > (seq_length + 1):
running_length = doc_length
sample_idx.append(np.array([len(doc_idx), 0]))
else:
running_length += doc_length
doc_idx.append(temp_shuffle_idx[curr_shuffle_idx])
curr_shuffle_idx += 1
if curr_shuffle_idx == len(documents):
curr_shuffle_idx = 0
np_rng.shuffle(temp_shuffle_idx)
sample_idx.append(np.array([len(doc_idx), 0]))
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
elif packing_impl == "unpacked":
# Unpacked data, one sample per document.
shuffle_idx = np.arange(num_samples) # Shuffle index around epochs
np_rng.shuffle(shuffle_idx)
sample_idx = np.zeros((num_samples + 1, 2), dtype=np.int64)
sample_idx[:, 0] = np.array([i for i in range(num_samples + 1)])
sample_idx[:, 1] = 0
doc_idx = list()
doc_i = 0
while len(doc_idx) <= num_samples:
if not allow_chopped:
# +1 since we shift left/right by 1
if sizes[doc_i] > seq_length + 1:
doc_i = (doc_i + 1) % len(documents)
continue
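# Skip documents whose first seq_length labels are all -100 (i.e. fully masked).
# NOTE(review): unlike the pack_until_overflow branch, label_dataset is not
# checked for None here -- confirm "unpacked" is only used with label data.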
if np.all(label_dataset.get(doc_i)[:seq_length] == -100):
doc_i = (doc_i + 1) % len(documents)
continue
doc_idx.append(doc_i)
doc_i = (doc_i + 1) % len(documents)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)

# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
Expand Down
6 changes: 2 additions & 4 deletions megatron/neox_arguments/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,10 +1121,8 @@ def calculate_derived(self):
if self.test_data_paths and (self.test_data_weights is None):
self.test_data_weights = [1.0] * len(self.test_data_paths)

if self.label_data_paths:
err_str = (
"Must use `label_data_paths` with `train_data_paths`, not `data_path`"
)
if self.train_label_data_paths:
err_str = "Must use `train_label_data_paths` with `train_data_paths`, not `data_path`"
assert self.train_data_paths and not self.data_path, err_str

# if a sample input file is provided, default text_gen_type type to input-file
Expand Down
Loading

0 comments on commit c786367

Please sign in to comment.