fix(parameter_sync): enable multi-thread all-gather and all-to-all (#200)

* add log for first sync_parameters

* enable multi-thread `all-gather` and `all-to-all`
haolin-nju authored Jan 6, 2025
1 parent 6e56d78 commit b9cc669
Showing 2 changed files with 14 additions and 4 deletions.
7 changes: 6 additions & 1 deletion chatlearn/runtime/engine.py
@@ -293,8 +293,13 @@ def learn(self):
             self.runtime_args.max_relay_episode,
             self.runtime_args.relay_episode_offset)
         logger.info(f"{LOG_START} " + get_full_proc_memory_info('Before first param sync'))
+        self.timers("sync_parameters").start()
         self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync)
-        logger.info(f"{LOG_START} " + get_full_proc_memory_info('After first param sync'))
+        self.timers("sync_parameters").stop()
+        logger.info(
+            f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} " \
+            + get_full_proc_memory_info('After first param sync')
+        )
         self._data_loader = data_loader
         for episode_id in range(self._start_episode, self.runtime_args.num_episode):
             if self.runtime_args.nsys:
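For context on the timing change above, here is a minimal, hypothetical sketch of wrapping an expensive call in a named timer and logging a one-line summary afterwards. The SimpleTimers class below is an illustration only and does not reproduce ChatLearn's actual Timers API.

# Hypothetical illustration only; ChatLearn's Timers API differs in detail.
import time

class SimpleTimers:
    """Minimal named-timer registry storing elapsed seconds per name."""
    def __init__(self):
        self._start = {}
        self._elapsed = {}

    def start(self, name):
        self._start[name] = time.perf_counter()

    def stop(self, name):
        self._elapsed[name] = time.perf_counter() - self._start.pop(name)

    def log(self, names):
        return " | ".join(f"{n}: {self._elapsed[n]:.2f} sec" for n in names)

timers = SimpleTimers()
timers.start("sync_parameters")
time.sleep(0.1)  # stand-in for the first parameter synchronization
timers.stop("sync_parameters")
print(f"sync_parameters summary {timers.log(names=['sync_parameters'])}")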
11 changes: 8 additions & 3 deletions chatlearn/synchronizer/parameter_sync.py
@@ -994,14 +994,18 @@ def validate_sync_results_parallel(self, actor_mappings_list:List, requires_grad
         else:
             execute_in_parallel(self.validate_sync_results, args)

-    def _calculate_max_workers(self, sorted_send_actors, actor_mapping):
+    def _calculate_max_workers(self, sorted_send_actors, actor_mappings=None):
         max_workers = get_args().runtime_args.param_sync_max_workers
         if max_workers is None:
             max_workers = max(self.src_model.total_gpu // 8, 1)
         if max_workers == -1:
             if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST:
                 max_workers = len(sorted_send_actors)
             else:
+                assert actor_mappings is not None, (
+                    "actor_mappings should not be None when max_workers is -1 and "
+                    "communication type for parameter synchronization is not broadcast."
+                )
                 max_workers = len(sorted_send_actors) * len(actor_mappings[sorted_send_actors[0]])
         return max_workers
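To make the worker-count rule above easier to follow, here is a self-contained sketch of the same logic with the config lookup and model attributes replaced by plain arguments. The function name and parameters are illustrative, not part of ChatLearn.

# Illustrative stand-in for _calculate_max_workers; argument names are assumptions.
def calculate_max_workers(configured_max_workers, total_gpu, send_actors,
                          actor_mappings=None, is_broadcast=True):
    # An explicit config value wins; None falls back to one worker per 8 source GPUs.
    max_workers = configured_max_workers
    if max_workers is None:
        max_workers = max(total_gpu // 8, 1)
    # -1 means "one worker per transfer that can run in parallel".
    if max_workers == -1:
        if is_broadcast:
            max_workers = len(send_actors)
        else:
            assert actor_mappings is not None, "actor_mappings required when not broadcasting"
            # One worker per (send actor, receive actor) pair.
            max_workers = len(send_actors) * len(actor_mappings[send_actors[0]])
    return max_workers

# Example: 16 source GPUs and no explicit setting -> 16 // 8 = 2 workers.
print(calculate_max_workers(None, 16, ["send0", "send1"]))  # 2
# Example: -1 with 2 send actors, each mapped to 4 receivers, non-broadcast -> 8 workers.
print(calculate_max_workers(-1, 16, ["send0", "send1"],
                            actor_mappings={"send0": ["r0", "r1", "r2", "r3"],
                                            "send1": ["r4", "r5", "r6", "r7"]},
                            is_broadcast=False))  # 8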

@@ -1387,19 +1391,20 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False):
         if self.concurrent_comm:
             assert self.dst_model.use_vllm_backend

+            max_workers = self._calculate_max_workers(self.send_actors_to_regroup_routed_experts)
             if self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER:
                 # allgather routed experts only
                 self.sync_allgather_multi_threads(
                     [self.send_actors_to_regroup_routed_experts],
-                    max_workers=1,
+                    max_workers=max_workers,
                     requires_grad=requires_grad,
                     group_name=self.group_name + "_allgather",
                     filter_fn=self.routed_experts_filter)
             elif self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL:
                 # alltoall routed experts only
                 self.sync_alltoall_multi_threads(
                     [self.send_actors_to_regroup_routed_experts],
-                    max_workers=1,
+                    max_workers=max_workers,
                     requires_grad=requires_grad,
                     filter_fn=self.routed_experts_filter)
