From e88e0ecb6e3aaf3020fc9bb99ba48d84c8e7fc74 Mon Sep 17 00:00:00 2001 From: Corey Derochie Date: Thu, 26 Sep 2024 17:59:41 -0500 Subject: [PATCH] Added documentation to explain the flow better. --- src/misc/msccl/msccl_status.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/misc/msccl/msccl_status.cc b/src/misc/msccl/msccl_status.cc index c01563b16..d8ed2d698 100644 --- a/src/misc/msccl/msccl_status.cc +++ b/src/misc/msccl/msccl_status.cc @@ -29,6 +29,9 @@ static vector> rankStates; static inline mscclRankState& mscclGetRankState(int rank, int rankCount = -1) { static thread_local shared_ptr threadRankState; + // Calling code can allocate states for the number of ranks at an appropriate time. + // It is assumed that all threads will call this function simultaneously with the + // same rankCount, which would avoid race conditions later in the function. if (rankCount > 0) { lock_guard lock(rankStatesMutex); if (rankStates.size() < rankCount) { @@ -37,6 +40,7 @@ static inline mscclRankState& mscclGetRankState(int rank, int rankCount = -1) { } if (rank < 0 || rank >= rankStates.size()) { + // threadRankState is used when no rank state can be returned (rank<0 or rank not in rankStates) if (!threadRankState) { threadRankState.reset(new mscclRankState()); } @@ -44,6 +48,7 @@ static inline mscclRankState& mscclGetRankState(int rank, int rankCount = -1) { } if (!rankStates[rank]) { + // When no state is yet assigned to a rank, use the current thread's threadRankState. if (!threadRankState) { threadRankState.reset(new mscclRankState()); } @@ -51,6 +56,8 @@ static inline mscclRankState& mscclGetRankState(int rank, int rankCount = -1) { } if (!threadRankState) { + // Cache this rank's state in threadRankState in case this thread calls with rank<0 later. + // NOTE: When multiple ranks share a thread, only the first rank in will be used for rank<0. threadRankState = rankStates[rank]; }