diff --git a/saev/training.py b/saev/training.py
index 8c12cd6..1b6f163 100644
--- a/saev/training.py
+++ b/saev/training.py
@@ -421,22 +421,24 @@ def split_cfgs(cfgs: list[config.Train]) -> list[list[config.Train]]:
     Returns:
         A list of lists, where the configs in each sublist do not differ in any keys that are in `CANNOT_PARALLELIZE`. This means that each sublist is a valid "parallel" set of configs for `train`.
     """
-
-    seen = collections.defaultdict(list)
+    # Group configs by their values for CANNOT_PARALLELIZE keys
+    groups = {}
     for cfg in cfgs:
         dct = dataclasses.asdict(cfg)
         dct = helpers.flattened(dct)
-        for key, value in dct.items():
-            seen[key].append(value)
-
-    bad_keys = {}
-    for key, values in seen.items():
-        if key in CANNOT_PARALLELIZE and len(set(values)) != 1:
-            bad_keys[key] = values
-
-    if bad_keys:
-        msg = ", ".join(f"'{key}': {values}" for key, values in bad_keys.items())
-        raise ValueError(f"Cannot parallelize training over: {msg}")
+
+        # Create a key tuple from the values of CANNOT_PARALLELIZE keys
+        key_values = []
+        for key in sorted(CANNOT_PARALLELIZE):
+            key_values.append((key, dct[key]))
+        group_key = tuple(key_values)
+
+        if group_key not in groups:
+            groups[group_key] = []
+        groups[group_key].append(cfg)
+
+    # Convert groups dict to list of lists
+    return list(groups.values())
 
 
 ##############