Skip to content

Commit

Permalink
refactor: Improve config parallelization validation with detailed error reporting
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelstevens committed Dec 5, 2024
1 parent 50cb4e3 commit 2355738
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions saev/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,14 +414,29 @@ def __iter__(self):
def split_cfgs(cfgs: list[config.Train]) -> list[list[config.Train]]:
    """
    Splits configs into groups that can be parallelized.

    Arguments:
        cfgs: A list of configs from a sweep file.

    Returns:
        A list of lists, where the configs in each sublist do not differ in any keys that are in `CANNOT_PARALLELIZE`. This means that each sublist is a valid "parallel" set of configs for `train`. Because any disagreement on such a key raises, the result is either empty (no configs given) or a single group holding every config.

    Raises:
        ValueError: If the configs disagree on any key in `CANNOT_PARALLELIZE`. The message lists each offending key together with the values seen for it.
    """
    # Collect, for every flattened key, the value from each config in input order.
    seen = collections.defaultdict(list)
    for cfg in cfgs:
        # NOTE(review): helpers.flattened presumably collapses the nested dict
        # from asdict() into a single level of keys -- confirm against helpers.
        dct = helpers.flattened(dataclasses.asdict(cfg))
        for key, value in dct.items():
            seen[key].append(value)

    bad_keys = {}
    for key, values in seen.items():
        if key not in CANNOT_PARALLELIZE:
            continue
        first = values[0]
        # Compare by identity/equality rather than building a set, so that
        # unhashable values (e.g. list fields) do not raise TypeError.
        if any(v is not first and v != first for v in values):
            bad_keys[key] = values

    if bad_keys:
        msg = ", ".join(f"'{key}': {values}" for key, values in bad_keys.items())
        raise ValueError(f"Cannot parallelize training over: {msg}")

    # Validation passed: all configs agree on every non-parallelizable key, so
    # together they form one parallel group.
    return [list(cfgs)] if cfgs else []


##############
Expand Down

0 comments on commit 2355738

Please sign in to comment.