Skip to content

Commit

Permalink
fix: Refactor split_cfgs to correctly group configs for parallel tr…
Browse files Browse the repository at this point in the history
…aining
  • Loading branch information
samuelstevens committed Dec 5, 2024
1 parent 2355738 commit 2023fba
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions saev/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,22 +421,24 @@ def split_cfgs(cfgs: list[config.Train]) -> list[list[config.Train]]:
Returns:
A list of lists, where the configs in each sublist do not differ in any keys that are in `CANNOT_PARALLELIZE`. This means that each sublist is a valid "parallel" set of configs for `train`.
"""

seen = collections.defaultdict(list)
# Group configs by their values for CANNOT_PARALLELIZE keys
groups = {}
for cfg in cfgs:
dct = dataclasses.asdict(cfg)
dct = helpers.flattened(dct)
for key, value in dct.items():
seen[key].append(value)

bad_keys = {}
for key, values in seen.items():
if key in CANNOT_PARALLELIZE and len(set(values)) != 1:
bad_keys[key] = values

if bad_keys:
msg = ", ".join(f"'{key}': {values}" for key, values in bad_keys.items())
raise ValueError(f"Cannot parallelize training over: {msg}")

# Create a key tuple from the values of CANNOT_PARALLELIZE keys
key_values = []
for key in sorted(CANNOT_PARALLELIZE):
key_values.append((key, dct[key]))
group_key = tuple(key_values)

if group_key not in groups:
groups[group_key] = []
groups[group_key].append(cfg)

# Convert groups dict to list of lists
return list(groups.values())


##############
Expand Down

0 comments on commit 2023fba

Please sign in to comment.