A two-stage copy design with scratch buffer #445

Open
wants to merge 9 commits into base: anantharamus/broadcast-amd
168 changes: 117 additions & 51 deletions apps/nccl/src/broadcast.hpp
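
As I read the diff below, the kernel now assigns one peer per thread block (peerIdx = blockIdx.x) and moves data in two stages through a scratch buffer: the root pushes one chunk per block into each peer's scratch, each non-root rank relays the chunk it received to the remaining peers, and every rank then unpacks its scratch into recvbuff. Scratch space is recycled by rewinding the write offset every nLoopToSync iterations. The host-side sketch below exercises only the chunking arithmetic; WARP_SIZE, SCRATCH_SIZE, and the exact definitions of unitBytesPerBlock, unitBytes, and nLoopToSync live outside this hunk, so every formula marked "assumed" is illustrative, not the PR's actual value.

// Standalone C++ sketch of the assumed per-block chunking math. Only the
// unitBytesPerThread selection is visible in the diff; everything marked
// "assumed" is a guess at definitions that live outside this hunk.
#include <cassert>
#include <cstddef>

int main() {
  const size_t nThread = 1024;              // threads per block (launch bounds)
  const size_t nBlock = 7;                  // assumed: gridDim.x == nPeer
  const size_t bytes = 64ull << 20;         // example payload per GPU
  const size_t SCRATCH_SIZE = 32ull << 20;  // assumed placeholder scratch size

  // Visible in the diff: large payloads use 64 bytes per thread, small ones 16.
  const size_t unitBytesPerThread = (bytes >= nThread * 64) ? 64 : 16;
  // Assumed: bytes moved per block and per whole-grid iteration.
  const size_t unitBytesPerBlock = unitBytesPerThread * nThread;
  const size_t unitBytes = unitBytesPerBlock * nBlock;
  const size_t nLoop = bytes / unitBytes;
  const size_t nLoopToSync = SCRATCH_SIZE / unitBytes;  // assumed derivation

  // The kernel rewinds the scratch offset (scratchSub = -i * unitBytes) every
  // nLoopToSync iterations, so scratch writes never run past SCRATCH_SIZE.
  size_t rewound = 0;
  for (size_t i = 1; i < nLoop; ++i) {
    if (i % nLoopToSync == 0) rewound = i * unitBytes;  // kernel: scratchSub = -i * unitBytes
    for (size_t b = 0; b < nBlock; ++b) {
      const size_t offset = b * unitBytesPerBlock + i * unitBytes;
      assert(offset - rewound + unitBytesPerBlock <= SCRATCH_SIZE);
    }
  }
  return 0;
}

If that reading is right, the extra signal/wait plus deviceSyncer.sync at i % nLoopToSync == 0 is what makes the rewind safe: all peers must have drained the previous scratch window before it is overwritten.
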
@@ -25,6 +25,7 @@ __global__ void __launch_bounds__(1024, 1)
const size_t nWarp = nThread / WARP_SIZE;
const size_t nPeer = nRanksPerNode - 1;
const size_t chanOffset = nPeer * blockIdx.x;
const size_t peerIdx = blockIdx.x;  // Peer index handled by this block.

__shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> smChans[NRANKS_PER_NODE - 1];
if (threadIdx.x < nPeer) {
@@ -35,12 +36,15 @@ __global__ void __launch_bounds__(1024, 1)
__syncthreads();

const size_t peerRootIdx = (root == rank) ? nPeer : ((root < rank) ? root : (root - 1));
const size_t rootsmaller = (root < rank) ? 1 : 0;

const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
const size_t bytes = bytesPerGPU;
size_t unitBytesPerThread;
if (bytes * nPeer >= nThread * 64) {
if (bytes >= nThread * 64) {
unitBytesPerThread = 64;
// unitBytesPerThread = 16;
// unitBytesPerThread = 32;
} else {
unitBytesPerThread = 16;
}
@@ -53,93 +57,155 @@

size_t scratchSub = 0;

// printf("nLoop = %ld, bytes = %ld, unitBytes = %ld, bytes mod unitBytes = %ld \n", nLoop, bytes, unitBytes,
// bytes % unitBytes);

// First loop will always fit the scratch size.
if (nLoop > 0) {
// First loop unrolling
const size_t offset = blockIdx.x * unitBytesPerBlock;
if (rank == root) {
const size_t offset = blockIdx.x * unitBytesPerBlock;
char* send_ = reinterpret_cast<char*>(sendbuff);
for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
smChans[peerIdx].copy<16, false>(dst + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
}
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
smChans[peerIdx].copy<16, false>(dst + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[threadIdx.x].signal();
if constexpr (IsOutOfPlace) {
char* recv_ = reinterpret_cast<char*>(recvbuff);
smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
}

} else { // rank != root.
if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
__syncthreads();
const size_t offset = (rank - rootsmaller) * unitBytesPerBlock;
if (blockIdx.x == (rank - rootsmaller) && threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.

// Step 2.
char* recv_ = reinterpret_cast<char*>(recvbuff);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x,
blockDim.x);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
if (peerIdx != peerRootIdx) {
smChans[peerIdx].copy<16, false>(dst + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
}
__syncthreads();
if (threadIdx.x != peerRootIdx && threadIdx.x < nPeer) {
smChans[threadIdx.x].signal();
smChans[threadIdx.x].wait();
}
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.
//__syncthreads();
{
Review comment: Is this bracket needed?
const size_t offset = blockIdx.x * unitBytesPerBlock;
smChans[peerIdx].copy<16, false>(recv_ + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
}
}
}

for (size_t i = 1; i < nLoop; ++i) {
const size_t offset = blockIdx.x * unitBytesPerBlock + i * unitBytes;
if (i % nLoopToSync == 0) { // Sync to reuse scratch buff
scratchSub = -i * unitBytes;
deviceSyncer.sync(gridDim.x);
if (threadIdx.x < nPeer) {
smChans[threadIdx.x].relaxedSignal();
smChans[threadIdx.x].signal();
smChans[threadIdx.x].wait();
}
}
if (rank == root) {
const size_t offset = blockIdx.x * unitBytesPerBlock + i * unitBytes;
char* send_ = reinterpret_cast<char*>(sendbuff);
for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
smChans[peerIdx].copy<16, false>(dst + offset + scratchSub, send_ + offset, unitBytesPerBlock, threadIdx.x,
blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
}
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.

smChans[peerIdx].copy<16, false>(dst + offset + scratchSub, send_ + offset, unitBytesPerBlock, threadIdx.x,
blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[threadIdx.x].signal();
if constexpr (IsOutOfPlace) {
char* recv_ = reinterpret_cast<char*>(recvbuff);
smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
}
} else { // rank != root.
if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
__syncthreads();
const size_t offset = (rank - rootsmaller) * unitBytesPerBlock + i * unitBytes;
if (blockIdx.x == (rank - rootsmaller) && threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.

// Step 2.
char* recv_ = reinterpret_cast<char*>(recvbuff);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset + scratchSub, unitBytesPerBlock,
threadIdx.x, blockDim.x);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
if (peerIdx != peerRootIdx) {
smChans[peerIdx].copy<16, false>(dst + offset + scratchSub, scratch_ + offset + scratchSub, unitBytesPerBlock,
threadIdx.x, blockDim.x);
}
__syncthreads();
if (threadIdx.x != peerRootIdx && threadIdx.x < nPeer) {
smChans[threadIdx.x].signal();
smChans[threadIdx.x].wait();
}
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.
{
const size_t offset = blockIdx.x * unitBytesPerBlock + i * unitBytes;
smChans[peerIdx].copy<16, false>(recv_ + offset, scratch_ + offset + scratchSub, unitBytesPerBlock, threadIdx.x,
blockDim.x);
}
}
}

// Remainder loop will also fit the scratch buff since we subtract unitBytes from SCRATCH_SIZE.
if (bytes % unitBytes > 0) { // remainder.
const size_t offset = blockIdx.x * unitBytesPerBlock + nLoop * unitBytes;
const size_t remainBytes = (offset < bytes) ? (bytes - offset) : 0;
if (remainBytes > 0) {
if (rank == root) {
char* send_ = reinterpret_cast<char*>(sendbuff);
for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
smChans[peerIdx].copy<16, true>(dst + offset + scratchSub, send_ + offset, remainBytes, threadIdx.x,
blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
}
if constexpr (IsOutOfPlace) {
char* recv_ = reinterpret_cast<char*>(recvbuff);
smChans[0].copy<16, true>(recv_ + offset, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
}
} else { // rank != root.
if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
__syncthreads();
// const size_t remainTotalBytes = bytes - nLoop * unitBytes;
// const size_t nblocks_to_use_base = remainTotalBytes / unitBytesPerBlock;
// const size_t nblocks_to_use =
// (remainTotalBytes % unitBytesPerBlock) ? nblocks_to_use_base + 1 : nblocks_to_use_base;

// printf("nLoop = %ld, bytes = %ld, nblocks_to_use = %ld\n", nLoop, bytes, nblocks_to_use);

// if (blockIdx.x < nblocks_to_use) {
if (rank == root) {
const size_t offset = blockIdx.x * unitBytesPerBlock + nLoop * unitBytes;
const size_t remainBytes =
offset < bytes ? ((bytes - offset) > unitBytesPerBlock ? unitBytesPerBlock : (bytes - offset)) : 0;
char* send_ = reinterpret_cast<char*>(sendbuff);
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.

smChans[peerIdx].copy<16, true>(dst + offset + scratchSub, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[threadIdx.x].signal();
if constexpr (IsOutOfPlace) {
char* recv_ = reinterpret_cast<char*>(recvbuff);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
smChans[peerRootIdx].copy<16, true>(recv_ + offset, scratch_ + offset + scratchSub, remainBytes, threadIdx.x,
blockDim.x);
smChans[0].copy<16, true>(recv_ + offset, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
}
} // remainBytes > 0.

} else { // rank != root.
const size_t offset = (rank - rootsmaller) * unitBytesPerBlock + nLoop * unitBytes;
const size_t remainBytes =
(offset < bytes) ? ((bytes - offset) > unitBytesPerBlock ? unitBytesPerBlock : (bytes - offset)) : 0;

if (blockIdx.x == (rank - rootsmaller) && threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.
__syncthreads();

// Step 2.
char* recv_ = reinterpret_cast<char*>(recvbuff);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
if (peerIdx != peerRootIdx) {
smChans[peerIdx].copy<16, true>(dst + offset + scratchSub, scratch_ + offset + scratchSub, remainBytes,
threadIdx.x, blockDim.x);
}
__syncthreads();
if (threadIdx.x != peerRootIdx && threadIdx.x < nPeer) {
smChans[threadIdx.x].signal();
smChans[threadIdx.x].wait();
}
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.
{
const size_t offset = blockIdx.x * unitBytesPerBlock + nLoop * unitBytes;
const size_t remainBytes =
(offset < bytes) ? ((bytes - offset) > unitBytesPerBlock ? unitBytesPerBlock : (bytes - offset)) : 0;
smChans[peerIdx].copy<16, true>(recv_ + offset, scratch_ + offset + scratchSub, remainBytes, threadIdx.x,
blockDim.x);
}
}
//} // remainBytes > 0.
}

deviceSyncer.sync(gridDim.x);
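
For completeness, here is a toy host-side model of the two-stage data path as I read it from the root and non-root branches above. It is a sketch only: signaling, the remainder handling, and scratch wrap-around are elided, plain std::vector buffers stand in for the SmChannel scratch buffers, and the assumption is again that each block owns exactly one chunk and gridDim.x == nPeer.

// Toy host model of the two-stage copy; helper names here are hypothetical.
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  const int nRanks = 8, root = 3, nPeer = nRanks - 1;
  const int chunk = 4, bytes = nPeer * chunk;  // one chunk per non-root rank

  std::vector<char> send(bytes);
  for (int i = 0; i < bytes; ++i) send[i] = static_cast<char>(i);
  std::vector<std::vector<char>> scratch(nRanks, std::vector<char>(bytes, 0));
  std::vector<std::vector<char>> recv(nRanks, std::vector<char>(bytes, 0));

  // Chunk index a non-root rank owns (mirrors rank - rootsmaller in the kernel).
  auto peerIdxOf = [&](int r) { return r < root ? r : r - 1; };

  // Stage 1: the root writes chunk peerIdxOf(r) into rank r's scratch buffer
  // (in the kernel: block b copies into peer b's scratch via smChans[b], then signals).
  for (int r = 0; r < nRanks; ++r) {
    if (r == root) continue;
    const int b = peerIdxOf(r);
    std::copy_n(send.data() + b * chunk, chunk, scratch[r].data() + b * chunk);
  }
  std::copy_n(send.data(), bytes, recv[root].data());  // root's IsOutOfPlace copy

  // Stage 2: each non-root rank relays the chunk it owns from its scratch to
  // every other non-root rank's scratch, then unpacks scratch into recvbuff.
  for (int r = 0; r < nRanks; ++r) {
    if (r == root) continue;
    const int b = peerIdxOf(r);
    for (int p = 0; p < nRanks; ++p) {
      if (p == root || p == r) continue;
      std::copy_n(scratch[r].data() + b * chunk, chunk, scratch[p].data() + b * chunk);
    }
  }
  for (int r = 0; r < nRanks; ++r) {
    if (r == root) continue;
    std::copy_n(scratch[r].data(), bytes, recv[r].data());
    assert(recv[r] == send);  // every rank ends up with the root's payload
  }
  return 0;
}

In the kernel, the relay and the final unpack are separated by the signal/wait on smChans plus a deviceSyncer.sync; the strictly sequential loops above stand in for that ordering.
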