diff --git a/src/collectives/all_gather.cc b/src/collectives/all_gather.cc index 97ec981..738be95 100644 --- a/src/collectives/all_gather.cc +++ b/src/collectives/all_gather.cc @@ -20,6 +20,6 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun struct ncclInfo info = { ncclFuncAllGather, "AllGather", sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */ - ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS }; + }; return ncclEnqueueCheck(&info); } diff --git a/src/collectives/all_reduce.cc b/src/collectives/all_reduce.cc index 8ac61a2..c854e65 100644 --- a/src/collectives/all_reduce.cc +++ b/src/collectives/all_reduce.cc @@ -26,6 +26,6 @@ ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, struct ncclInfo info = { ncclFuncAllReduce, "AllReduce", sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */ - ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS }; + }; return ncclEnqueueCheck(&info); } diff --git a/src/collectives/broadcast.cc b/src/collectives/broadcast.cc index c73502e..c4222d5 100644 --- a/src/collectives/broadcast.cc +++ b/src/collectives/broadcast.cc @@ -24,7 +24,7 @@ ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, n struct ncclInfo info = { ncclFuncBroadcast, "Broadcast", sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */ - BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS }; + }; return ncclEnqueueCheck(&info); } /* Deprecated original "in place" function, similar to MPI */ diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h index 4e82dd6..ddab804 100644 --- a/src/collectives/device/all_gather.h +++ b/src/collectives/device/all_gather.h @@ -17,7 +17,7 @@ namespace { const int nChannels = args->nChannels; ncclRing *ring = &ncclShmem.channel.ring; const int *ringRanks = ring->userRanks; - const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLGATHER_CHUNKSTEPS : 1)); + const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice*args->slicesPerChunk); // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2); const int nranks = ncclShmem.comm.nRanks; @@ -27,7 +27,8 @@ namespace { T *inputBuf = (T*)args->sendbuff; T *outputBuf = (T*)args->recvbuff; Primitives, 1, Proto, 0> prims - (tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg); + (tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t realChunkSize; @@ -79,7 +80,7 @@ namespace { template struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { - using Proto = ProtoSimple; + using Proto = ProtoSimple<>; runRing(args); } }; diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index 3f12e5e..65573b2 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -17,7 +17,7 @@ namespace { const int nChannels = args->nChannels; ncclRing *ring = &ncclShmem.channel.ring; int ringIx = ring->index; - const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? 
ALLREDUCE_CHUNKSTEPS : 1)); + const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice*args->slicesPerChunk); const int nranks = ncclShmem.comm.nRanks; const ssize_t loopSize = nChannels*nranks*chunkSize; const ssize_t size = args->count; @@ -31,7 +31,8 @@ namespace { } Primitives, 1, Proto, 0> prims - (tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg); + (tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t realChunkSize; @@ -103,7 +104,7 @@ namespace { ncclTree *tree = &ncclShmem.channel.tree; ssize_t chunkSize = int( Proto::Id == NCCL_PROTO_SIMPLE ? args->lastChunkSize - /* LL & LL128 */ : Proto::calcBytePerStep()/sizeof(T)); + /* LL & LL128 */ : Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice); const ssize_t minChunkSize = int( Proto::Id == NCCL_PROTO_SIMPLE ? (nthreads-2*WARP_SIZE)*8*(sizeof(uint64_t)/sizeof(T)) /* LL & LL128 */ : nthreads*(Proto::calcBytePerGrain()/sizeof(T))); @@ -115,7 +116,9 @@ namespace { { // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) Primitives, /*Direct=*/0, Proto, 0> prims - (tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg); + (tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk); + if (tree->up == -1) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*int(chunkSize); @@ -141,7 +144,8 @@ namespace { { // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) Primitives, /*Direct=*/1, Proto, 0> prims - (tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg); + (tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk); if (tree->up == -1) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*int(chunkSize); @@ -175,7 +179,7 @@ namespace { ncclTree *tree = &ncclShmem.channel.tree; ssize_t chunkSize = int( Proto::Id != NCCL_PROTO_LL ? args->lastChunkSize - : Proto::calcBytePerStep()/sizeof(T)); + : Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice); const ssize_t minChunkSize = int( Proto::Id == NCCL_PROTO_SIMPLE ? (nthreads - 2*WARP_SIZE)*8*(sizeof(uint64_t)/sizeof(T)) : Proto::Id == NCCL_PROTO_LL ? nthreads*(Proto::calcBytePerGrain()/sizeof(T)) @@ -199,7 +203,9 @@ namespace { if (tree->up == -1) { // Reduce and broadcast. Max number of recv is 3, max number of send is 3 Primitives, /*Direct=*/1, Proto, 0> - prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff, args->redOpArg); + prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*int(chunkSize); int nelem = min(chunkSize, size-offset); @@ -216,7 +222,8 @@ namespace { * but the ctor above for tree roots would be DirectRecv=0 DirectSend=1. 
*/ Primitives, /*Direct=*/1, Proto, 0> - prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg, 0*Proto::MaxGroupWidth); + prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk, 0*Proto::MaxGroupWidth); if (tree->down[0] == -1) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*int(chunkSize); @@ -235,7 +242,8 @@ namespace { else { // Broadcast down. Max number of recv is 1, max number of send is 3 (binary tree + local) Primitives, /*Direct=*/1, Proto, 0> - prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg, 1*Proto::MaxGroupWidth); + prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk, 1*Proto::MaxGroupWidth); if (tree->down[0] == -1) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*int(chunkSize); @@ -257,8 +265,7 @@ namespace { template struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { - using Proto = ProtoSimple; - runRing(args); + runRing>(args); } }; @@ -266,9 +273,9 @@ template struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { #if CUDART_VERSION >= 11020 && CUDART_VERSION < 11040 && __CUDA_ARCH__ >= 800 - runTreeUpDown>(args); + runTreeUpDown>(args); #else - runTreeSplit>(args); + runTreeSplit>(args); #endif } }; @@ -295,13 +302,14 @@ struct RunWorkElement; + using Proto = ProtoSimple<>; if (tid >= tidStartScatter && tid < tidStartReduce && hasUp) { // Scatter int group = (2*Proto::MaxGroupWidth) | (1<<16); Primitives, /*Direct=*/1, Proto, 0> - prims(tid-tidStartScatter, nThreadsScatter, NULL, direct->up, args->sendbuff, args->recvbuff, args->redOpArg, group, args); + prims(tid-tidStartScatter, nThreadsScatter, NULL, direct->up, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk, group, args); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*direct->nHeads*chunkSize; int nelem = min(direct->nHeads*chunkSize, size-offset); @@ -316,7 +324,8 @@ struct RunWorkElement, /*Direct=*/1, Proto, 0> - prims(tid-tidStartReduce, nThreadsReduce, direct->down, &direct->out, args->sendbuff, args->recvbuff, args->redOpArg, group, args); + prims(tid-tidStartReduce, nThreadsReduce, direct->down, &direct->out, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk, group, args); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize; int nelem = min(chunkSize, size-offset); @@ -329,7 +338,8 @@ struct RunWorkElement, /*Direct=*/0, Proto, 0> - prims(tid-tidStartReduce, nThreadsReduce, nullptr, &direct->out, args->sendbuff, args->recvbuff, args->redOpArg, group); + prims(tid-tidStartReduce, nThreadsReduce, nullptr, &direct->out, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk, group); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize; int nelem = min(chunkSize, size-offset); @@ -340,7 +350,8 @@ struct RunWorkElement, /*Direct=*/1, Proto, 0> - 
prims(tid, nThreadsGather, direct->up, NULL, args->sendbuff, args->recvbuff, args->redOpArg, group, args); + prims(tid, nThreadsGather, direct->up, NULL, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk, group, args); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*direct->nHeads*chunkSize; int nelem = min(direct->nHeads*chunkSize, size-offset); @@ -351,7 +362,8 @@ struct RunWorkElement, /*Direct=*/1, Proto, 0> - prims(tid-tidStartBcast, nThreadsBcast, &direct->out, direct->down, args->sendbuff, args->recvbuff, args->redOpArg, group, args); + prims(tid-tidStartBcast, nThreadsBcast, &direct->out, direct->down, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk, group, args); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize; int nelem = min(chunkSize, size-offset); @@ -360,7 +372,8 @@ struct RunWorkElement, /*Direct=*/0, Proto, 0> - prims(tid-tidStartBcast, nThreadsBcast, &direct->out, nullptr, args->sendbuff, args->recvbuff, args->redOpArg, group); + prims(tid-tidStartBcast, nThreadsBcast, &direct->out, nullptr, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk, group); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize; int nelem = min(chunkSize, size-offset); @@ -387,7 +400,7 @@ struct RunWorkElement= 256) nthreadsSplit += 64; int group, send, recv, groupTid, groupNthreads; - using Proto = ProtoSimple<1, 1>; + using Proto = ProtoSimple<>; if (tid < nthreadsSplit) { group = (0*Proto::MaxGroupWidth) | (1<<16); recv = tree->down[0]; @@ -403,7 +416,8 @@ struct RunWorkElement, /*Direct=*/1, Proto, 0> - prims(groupTid, groupNthreads, &recv, &send, args->sendbuff, args->recvbuff, args->redOpArg, group); + prims(groupTid, groupNthreads, &recv, &send, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk, group); if (tid < nthreadsSplit) { if (recv == -1) { diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h index ebe4381..0a59617 100644 --- a/src/collectives/device/broadcast.h +++ b/src/collectives/device/broadcast.h @@ -16,7 +16,7 @@ namespace { const int bid = args->bid; const int nChannels = args->nChannels; ncclRing *ring = &ncclShmem.channel.ring; - const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? 
BROADCAST_CHUNKSTEPS : 1)); + const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice); const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))); const ssize_t loopSize = nChannels*chunkSize; const ssize_t size = args->count; @@ -27,7 +27,8 @@ namespace { T *inputBuf = (T*)args->sendbuff; T *outputBuf = (T*)args->recvbuff; Primitives, 0, Proto, 0> - prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg); + prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t realChunkSize; @@ -62,7 +63,7 @@ namespace { template struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { - using Proto = ProtoSimple; + using Proto = ProtoSimple<>; runRing(args); } }; diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index ccc0d22..d4d79ff 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -20,11 +20,9 @@ * to how that protocol operates with a consistent interface so that our * algorithm code can operate protocol parametrically. */ -template +template struct ProtoSimple { static constexpr int Id = NCCL_PROTO_SIMPLE; - static constexpr int SlicePerChunk = SlicePerChunk_1; - static constexpr int StepPerSlice = StepPerSlice_1; static constexpr int Unroll = Unroll_1; // Data bytes (no flags etc) in one step of the fifo queue. diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h index 60f64ff..0bcef5c 100644 --- a/src/collectives/device/prims_ll.h +++ b/src/collectives/device/prims_ll.h @@ -19,6 +19,7 @@ class Primitives: const int wid; const int group; const int stepLines; + const int stepsPerChunk; Fan fan; T *userBufs[2]; struct ncclConnInfo* recvConn = NULL; @@ -40,8 +41,8 @@ class Primitives: inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepLines; } inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); } inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); } - inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); } - inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); } + inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+stepsPerChunk); } + inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+stepsPerChunk); } inline __device__ void barrier() { if (nthreads == WARP_SIZE) @@ -64,34 +65,34 @@ class Primitives: inline __device__ void waitSend(int nbytes) { if (sendConnHeadPtr) { int spins = 0; - while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) { + while (sendConnHeadCache + NCCL_STEPS < sendConnHead + stepsPerChunk) { sendConnHeadCache = *sendConnHeadPtr; if (checkAbort(spins, 1)) break; } if (sendConnFifoPtr) { - int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes; + int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? 
stepsPerChunk*stepLines*sizeof(union ncclLLFifoLine) : nbytes; sendConnFifoPtr[sendConnHead%NCCL_STEPS] = size; } - sendConnHead += 1; + sendConnHead += stepsPerChunk; } barrier(); } inline __device__ void incRecv(int i) { - recvStep[i] += 1; + recvStep[i] += stepsPerChunk; } inline __device__ void postRecv() { barrier(); - if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1; + if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += stepsPerChunk; } inline __device__ void incSend(int i, int offset) { // LL Cleanup : write all flags in the slice to make sure we don't have // data corruption when flag loops over. if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) { - for (int o = offset; o: __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) { recvBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL]; recvStep[i] = conn->step; + recvStep[i] = roundUp(recvStep[i], stepsPerChunk); if (wid == i) recvConn = conn; } __device__ __forceinline__ void loadRecvSync() { @@ -308,6 +310,7 @@ class Primitives: __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) { sendBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL]; sendStep[i] = conn->step; + sendStep[i] = roundUp(sendStep[i], stepsPerChunk); if (wid == i) sendConn = conn; } __device__ __forceinline__ void loadSendSync() { @@ -322,10 +325,12 @@ class Primitives: public: __device__ Primitives( const int tid, const int nthreads, int const *recvPeers, int const *sendPeers, - void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0 + void const *inputBuf, void *outputBuf, uint64_t redOpArg, int stepsPerChunk, + int ignored, int group=0 ): redOp(redOpArg), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group&(uint16_t)0xFFFF), + stepsPerChunk(stepsPerChunk), stepLines(ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)) { int connIndex = group >> 16; auto *channel = &ncclShmem.channel; diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h index 773a921..4ea6b33 100644 --- a/src/collectives/device/prims_ll128.h +++ b/src/collectives/device/prims_ll128.h @@ -23,6 +23,7 @@ class Primitives: const int warpInBlock; // warp index in thread block const bool flagThread; const int group; + const int stepsPerChunk; Fan fan; T *userBufs[2]; struct ncclConnInfo* recvConn = NULL; @@ -46,8 +47,8 @@ class Primitives: inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepSize; } inline __device__ uint64_t* recvPtr(int i) { return recvBuff[i]+recvOffset(i); } inline __device__ uint64_t* sendPtr(int i) { return sendBuff[i]+sendOffset(i); } - inline __device__ uint64_t recvFlag(int i) { return recvStep[i]+1; } - inline __device__ uint64_t sendFlag(int i) { return sendStep[i]+1; } + inline __device__ uint64_t recvFlag(int i) { return recvStep[i]+stepsPerChunk; } + inline __device__ uint64_t sendFlag(int i) { return sendStep[i]+stepsPerChunk; } inline __device__ void barrier() { asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(15-group)); @@ -67,22 +68,22 @@ class Primitives: inline __device__ void waitSend(int nbytes) { if (sendConnHeadPtr) { int spins = 0; - while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) { + while (sendConnHeadCache + NCCL_STEPS < sendConnHead + stepsPerChunk) { sendConnHeadCache = *sendConnHeadPtr; if (checkAbort(spins, wid, 1)) break; } if (sendConnFifoPtr) { sendConnFifoPtr[sendStep[wid]%NCCL_STEPS] = nbytes; } - sendConnHead += 1; + sendConnHead += 
stepsPerChunk; } } inline __device__ void postRecv() { - if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1; + if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += stepsPerChunk; } inline __device__ void postSend() { - if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += 1; } + if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += stepsPerChunk; } } template @@ -315,15 +316,16 @@ class Primitives: } barrier(); - if (SEND) for (int i=0; i < MaxSend; i++) sendStep[i] += 1; + if (SEND) for (int i=0; i < MaxSend; i++) sendStep[i] += stepsPerChunk; if (SEND) postSend(); - if (RECV) for (int i=0; i < MaxRecv; i++) recvStep[i] += 1; + if (RECV) for (int i=0; i < MaxRecv; i++) recvStep[i] += stepsPerChunk; if (RECV) postRecv(); } __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) { recvBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128]; recvStep[i] = conn->step; + recvStep[i] = roundUp(recvStep[i], stepsPerChunk); if (wid == i) recvConn = conn; } __device__ __forceinline__ void loadRecvSync() { @@ -336,6 +338,7 @@ class Primitives: __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) { sendBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128]; sendStep[i] = conn->step; + sendStep[i] = roundUp(sendStep[i], stepsPerChunk); if (wid == i) sendConn = conn; } __device__ __forceinline__ void loadSendSync() { @@ -356,12 +359,14 @@ class Primitives: public: __device__ Primitives( const int tid, const int nthreads, int const *recvPeers, int const *sendPeers, - void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0 + void const *inputBuf, void *outputBuf, uint64_t redOpArg, int stepsPerChunk, + int ignored, int group=0 ): redOp(redOpArg), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), warpInBlock(threadIdx.x/WARP_SIZE), flagThread((tid%8)==7), group(group&(uint16_t)0xFFFF), + stepsPerChunk(stepsPerChunk), stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) { int connIndex = group >> 16; auto *channel = &ncclShmem.channel; diff --git a/src/collectives/device/prims_simple.h b/src/collectives/device/prims_simple.h index 9d2d19a..44ab91c 100644 --- a/src/collectives/device/prims_simple.h +++ b/src/collectives/device/prims_simple.h @@ -4,10 +4,9 @@ * See LICENSE.txt for license information ************************************************************************/ -template +template class Primitives< - T, RedOp, Fan, Direct, ProtoSimple, P2p + T, RedOp, Fan, Direct, ProtoSimple, P2p > { static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; static constexpr int Input=0, Output=1; @@ -27,6 +26,8 @@ class Primitives< int nthreads; int nworkers; const int stepSize; + const int stepsPerSlice; + const int slicesPerChunk; Fan fan; int index; // Peer index I'm responsible for int flags; @@ -79,10 +80,10 @@ class Primitives< if (((flags & (Recv*RoleWaitRecv)) && !noRecvWait) || ((flags & (Send*RoleWaitSend)) && !noSendWait)) { int spins = 0; - while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) { + while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + stepsPerSlice) { connStepCache = *connStepPtr; if (checkAbort(spins)) break; - //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? 
NCCL_STEPS : 0)), int(step+StepPerSlice)); + //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+stepsPerSlice)); } } @@ -114,14 +115,14 @@ class Primitives< else { ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize; } - step += StepPerSlice; + step += stepsPerSlice; } } template inline __device__ void postPeer() { if (flags & (Recv*RolePostRecv | Send*RolePostSend)) { - step += StepPerSlice; + step += stepsPerSlice; *connStepPtr = step; } } @@ -136,8 +137,8 @@ class Primitives< constexpr int Dst = DstBuf != -1; nelem = nelem < 0 ? 0 : nelem; - int sliceSize = stepSize*StepPerSlice; - sliceSize = max(divUp(nelem, 16*SlicePerChunk)*16, sliceSize/32); + int sliceSize = stepSize*stepsPerSlice; + sliceSize = max(divUp(nelem, 16*slicesPerChunk)*16, sliceSize/32); int slice = 0; int offset = 0; @@ -165,12 +166,7 @@ class Primitives< // barrier(); // post(); // } // Since we no longer unroll, new branch added here - #if __CUDA_ARCH__ < 700 - // Yeah, so all that above don't matter a lick on older hardware. - #pragma unroll SlicePerChunk - #else - #pragma unroll 1 - #endif + #pragma unroll 1 do { sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset; if (Src && (flags & (SrcBuf==Input ? RoleInput : RoleOutput))) @@ -214,7 +210,7 @@ class Primitives< postPeer(); offset += sliceSize; slice += 1; - } while (slice < SlicePerChunk && offset < nelem); + } while (slice < slicesPerChunk && offset < nelem); } // Non-workers come straight here. Workers too but only once the remaining @@ -222,7 +218,7 @@ class Primitives< // worker perf is the limiter, perf-wise this loop is effectively unentered, // hence just a single branch insn. #pragma unroll 1 - while (slice < SlicePerChunk) { + while (slice < slicesPerChunk) { sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset; { // Only workers could have Wait roles so we know the slice must be empty // since we've exited the loop above. @@ -246,11 +242,11 @@ class Primitives< constexpr int DirectRecv = 1 && Direct && DirectRecv1; constexpr int DirectSend = 1 && Direct && DirectSend1; int offset = 0; // slice offset - int sliceSize = stepSize*StepPerSlice; - int dataSize = max(DIVUP(peerElem, 16*SlicePerChunk)*16, sliceSize/32); // per-peer slice size + int sliceSize = stepSize*stepsPerSlice; + int dataSize = max(DIVUP(peerElem, 16*slicesPerChunk)*16, sliceSize/32); // per-peer slice size #pragma unroll - for (int slice=0; slicerecv[connIndex]; step = conn->step; - step = roundUp(step, SlicePerChunk*StepPerSlice); + step = roundUp(step, slicesPerChunk*stepsPerSlice); if (flags & RolePostRecv) { connStepPtr = conn->head; *connStepPtr = step; // Return credits in case we rounded up. 
@@ -353,7 +349,7 @@ class Primitives< if (flags & (RoleWaitSend|RolePostSend)) { auto *conn = &peer->send[connIndex]; step = conn->step; - step = roundUp(step, SlicePerChunk*StepPerSlice); + step = roundUp(step, slicesPerChunk*stepsPerSlice); if (flags & RolePostSend) { connStepPtr = conn->tail; } @@ -395,9 +391,11 @@ class Primitives< public: __device__ Primitives( int tid, int nthreads, int const *recvPeers, int const *sendPeers, - void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint32_t group=0, struct ncclWorkElem* e = nullptr + void const *inputBuf, void *outputBuf, uint64_t redOpArg, + int stepsPerSlice, int slicesPerChunk, + uint32_t group=0, struct ncclWorkElem* e = nullptr ): - tid(tid), + tid(tid), stepsPerSlice(stepsPerSlice), slicesPerChunk(slicesPerChunk), stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)) { // For send operations, we need an extra warp to overlap the threadfence and the copy diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h index 0927037..d745cdf 100644 --- a/src/collectives/device/reduce.h +++ b/src/collectives/device/reduce.h @@ -16,7 +16,7 @@ namespace { const int bid = args->bid; const int nChannels = args->nChannels; ncclRing *ring = &ncclShmem.channel.ring; - const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCE_CHUNKSTEPS : 1)); + const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice); const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))); const int nranks = ncclShmem.comm.nRanks; const ssize_t loopSize = nChannels*chunkSize; @@ -26,7 +26,8 @@ namespace { const int root = args->root; Primitives, 0, Proto, 0> - prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg); + prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk); auto calcChunkSize = [&]__device__(ssize_t gridOffset)->int { int realChunkSize; @@ -71,7 +72,7 @@ namespace { template struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { - using Proto = ProtoSimple; + using Proto = ProtoSimple<>; runRing(args); } }; diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h index 754889a..afd65eb 100644 --- a/src/collectives/device/reduce_scatter.h +++ b/src/collectives/device/reduce_scatter.h @@ -17,7 +17,7 @@ namespace { const int nChannels = args->nChannels; ncclRing *ring = &ncclShmem.channel.ring; int const *ringRanks = ring->userRanks; - const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCESCATTER_CHUNKSTEPS : 1)); + const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice*args->slicesPerChunk); // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. 
const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2); const int nranks = ncclShmem.comm.nRanks; @@ -25,7 +25,8 @@ namespace { const ssize_t size = args->count; Primitives, 0, Proto, 0> - prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg); + prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg, + args->stepsPerSlice, args->slicesPerChunk); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t realChunkSize; @@ -69,7 +70,7 @@ namespace { template struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { - using Proto = ProtoSimple; + using Proto = ProtoSimple<>; runRing(args); } }; diff --git a/src/collectives/device/sendrecv.h b/src/collectives/device/sendrecv.h index ec1e20c..7120994 100644 --- a/src/collectives/device/sendrecv.h +++ b/src/collectives/device/sendrecv.h @@ -25,7 +25,8 @@ struct RunWork { if (args->proto == NCCL_PROTO_LL) chunkSize /= 2; int const peer = args->peer; Primitives, 1, Proto, 1> prims - (tid, nthreads, nullptr, &peer, buff, nullptr, /*redOpArg(ignored)=*/0, group); + (tid, nthreads, nullptr, &peer, buff, nullptr, /*redOpArg(ignored)=*/0, + 1<stepsPerChunkPow2, 1, group); size_t offset = 0; do { int nelem = min(size_t(chunkSize), count-offset); @@ -44,7 +45,8 @@ struct RunWork { if (args->proto == NCCL_PROTO_LL) chunkSize /= 2; // This is to account for chunkEffectiveSize int const peer = args->peer; Primitives, 1, Proto, 1> prims - (tid, nthreads, &peer, nullptr, nullptr, buff, /*redOpArg(ignored)=*/0, group); + (tid, nthreads, &peer, nullptr, nullptr, buff, /*redOpArg(ignored)=*/0, + 1<stepsPerChunkPow2, 1, group); size_t offset = 0; do { int nelem = min(size_t(chunkSize), count-offset); @@ -79,13 +81,13 @@ struct RunWork { if (args->proto == NCCL_PROTO_LL) { runRecv(tid, nthreads, group, args); } else { - runRecv>(tid, nthreads, group, args); + runRecv>(tid, nthreads, group, args); } } else { if (args->proto == NCCL_PROTO_LL) { runSend(tid, nthreads, group, args); } else { - runSend>(tid, nthreads, group, args); + runSend>(tid, nthreads, group, args); } } } diff --git a/src/collectives/reduce.cc b/src/collectives/reduce.cc index 6335516..ba027b6 100644 --- a/src/collectives/reduce.cc +++ b/src/collectives/reduce.cc @@ -28,6 +28,6 @@ ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count, struct ncclInfo info = { ncclFuncReduce, "Reduce", sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */ - REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS }; + }; return ncclEnqueueCheck(&info); } diff --git a/src/collectives/reduce_scatter.cc b/src/collectives/reduce_scatter.cc index 5242545..4a13499 100644 --- a/src/collectives/reduce_scatter.cc +++ b/src/collectives/reduce_scatter.cc @@ -26,6 +26,6 @@ ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recv struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter", sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */ - REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS }; + }; return ncclEnqueueCheck(&info); } diff --git a/src/collectives/sendrecv.cc b/src/collectives/sendrecv.cc index 9a81b0a..b5828ae 100644 --- a/src/collectives/sendrecv.cc +++ b/src/collectives/sendrecv.cc @@ -25,8 +25,8 @@ ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatyp NVTX3_FUNC_WITH_PARAMS(Send, SendRecvSchema, payload) struct ncclInfo info = { ncclFuncSend, "Send", - NULL, 
(void*)sendbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */ - 1, 1 }; + NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream /* Args */ + }; ncclResult_t ret; NCCLCHECK(ncclGroupStart()); ret = ncclEnqueueCheck(&info); @@ -43,7 +43,7 @@ ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int struct ncclInfo info = { ncclFuncRecv, "Recv", NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */ - 1, 1 }; + }; ncclResult_t ret; NCCLCHECK(ncclGroupStart()); ret = ncclEnqueueCheck(&info); diff --git a/src/enqueue.cc b/src/enqueue.cc index 0744e09..25d520e 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -332,6 +332,7 @@ static ncclResult_t addCollToPlan( } NCCL_PARAM(P2pLLThreshold, "P2P_LL_THRESHOLD", 16384); +NCCL_PARAM(P2pChunkSteps, "P2P_CHUNKSTEPS", (NCCL_STEPS/8)); // Put p2p op in plan assuming there is space in nWorkBudget, so you must // ensure *nWorkBudget >= 1 upon entry. @@ -343,8 +344,10 @@ static ncclResult_t addP2pToPlan( isSendNotRecv ? ncclFuncSend : ncclFuncRecv, isSendNotRecv ? "Send" : "Recv", nullptr, addr, bytes, ncclInt8, ncclSum, peer, comm, (cudaStream_t)0, - /*Args*/1, 1 + (int)ncclParamP2pChunkSteps(), (int)ncclParamP2pChunkSteps() }; + // Shared buffers do not support *steps>1. + if (comm->nNodes > 1) info.sliceSteps = info.chunkSteps = 1; int channelId; NCCLCHECK(ncclChannelCompute(comm, peer, chunk%comm->p2pnChannelsPerPeer, info.coll, &channelId)); @@ -363,6 +366,7 @@ static ncclResult_t addP2pToPlan( elem.peer = peer; elem.nWarps = NCCL_MAX_NTHREADS/WARP_SIZE; elem.p2pType = isSendNotRecv ? ncclWorkP2pTypeSend : ncclWorkP2pTypeRecv; + elem.stepsPerChunkPow2 = log2i(info.chunkSteps); elem.buffLo32 = uint32_t(reinterpret_cast(addr)); elem.buffHi32 = reinterpret_cast(addr)>>32; elem.countLo32 = uint32_t(bytes); @@ -609,7 +613,7 @@ static ncclResult_t scheduleP2pTasksToPlan( // Compute how much to split operations // Natural step size matching buffer steps. - ssize_t stepSize = comm->p2pChunkSize; + ssize_t chunkSize = comm->p2pChunkSize; // Try to use all channels int nChannelsMax = comm->p2pnChannelsPerPeer; int nChannelsMin = nChannelsMax; @@ -643,8 +647,8 @@ static ncclResult_t scheduleP2pTasksToPlan( char* sendPtr = send ? (char*)send->buff : nullptr; ssize_t recvBytes = recv ? recv->bytes : 0; ssize_t sendBytes = send ? send->bytes : 0; - ssize_t minSize = stepSize/8; - ssize_t maxSize = comm->nNodes > 1 ? stepSize : stepSize*32; + ssize_t minSize = chunkSize/8; + ssize_t maxSize = comm->nNodes > 1 ? chunkSize : chunkSize*32; ssize_t recvChunkBytesMax = calcP2pChunkSize(recvBytes, nChannelsMin, nChannelsMax, minSize, maxSize); ssize_t sendChunkBytesMax = calcP2pChunkSize(sendBytes, nChannelsMin, nChannelsMax, minSize, maxSize); // Zero size send/recv are syncs, encode here with -1. 
@@ -1258,6 +1262,40 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) {
   return ncclSuccess;
 }
 
+NCCL_PARAM(LLChunkSteps, "LL_CHUNKSTEPS", (NCCL_STEPS/8));
+NCCL_PARAM(LL128ChunkSteps, "LL128_CHUNKSTEPS", (NCCL_STEPS/8));
+NCCL_PARAM(PipelineChunkSteps, "PIPELINE_CHUNKSTEPS", (NCCL_STEPS/8));
+NCCL_PARAM(RingChunkSteps, "RING_CHUNKSTEPS", (NCCL_STEPS/2));
+NCCL_PARAM(RingSliceSteps, "RING_SLICESTEPS", (NCCL_STEPS/4));
+
+static ncclResult_t getStepInfo(struct ncclInfo* info) {
+  if (info->protocol == NCCL_PROTO_LL) {
+    info->chunkSteps = info->sliceSteps = ncclParamLLChunkSteps();
+  } else if (info->protocol == NCCL_PROTO_LL128) {
+    info->chunkSteps = info->sliceSteps = ncclParamLL128ChunkSteps();
+  } else { /* SIMPLE */
+    if (info->algorithm == NCCL_ALGO_TREE || info->coll == ncclFuncBroadcast || info->coll == ncclFuncReduce) {
+      info->chunkSteps = info->sliceSteps = ncclParamPipelineChunkSteps();
+    } else {
+      info->chunkSteps = ncclParamRingChunkSteps();
+      info->sliceSteps = ncclParamRingSliceSteps();
+    }
+  }
+  if (info->chunkSteps > NCCL_STEPS/2 || info->sliceSteps > NCCL_STEPS/2) {
+    WARN("Invalid chunkSteps=%d/sliceSteps=%d, must be at most NCCL_STEPS/2=%d\n", info->chunkSteps, info->sliceSteps, NCCL_STEPS/2);
+    return ncclInvalidUsage;
+  }
+  if (info->chunkSteps % info->sliceSteps) {
+    WARN("Invalid chunkSteps=%d, must be a multiple of sliceSteps=%d\n", info->chunkSteps, info->sliceSteps);
+    return ncclInvalidUsage;
+  }
+  if (info->chunkSteps / info->sliceSteps > NCCL_MAX_SLICE_PER_CHUNK) {
+    WARN("Invalid chunkSteps=%d, must be at most sliceSteps*%d=%d\n", info->chunkSteps, NCCL_MAX_SLICE_PER_CHUNK, info->sliceSteps*NCCL_MAX_SLICE_PER_CHUNK);
+    return ncclInvalidUsage;
+  }
+  return ncclSuccess;
+}
+
 static ncclResult_t computeColl(struct ncclInfo* info /* input */, int* workFuncIndex, struct ncclWorkElem* work, struct ncclProxyOp* proxyOp /* output */) {
   int collNetTypeSupport = 0;
   // Check whether algo and proto have been preset (as in aggregation case)
@@ -1270,6 +1308,7 @@ comp_next:
   // Set nstepsPerLoop and nchunksPerLoop
   NCCLCHECK(getPatternInfo(info));
   NCCLCHECK(getLoopInfo(info));
+  NCCLCHECK(getStepInfo(info));
 
   work->sendbuff = info->sendbuff;
   work->recvbuff = info->recvbuff;
@@ -1279,6 +1318,8 @@ comp_next:
   work->nWarps = info->nThreads / WARP_SIZE;
   work->redOpArg = info->opFull.scalarArg;
   work->redOpArgIsPtr = info->opFull.scalarArgIsPtr;
+  work->stepsPerSlice = info->sliceSteps;
+  work->slicesPerChunk = info->chunkSteps/info->sliceSteps;
 
   if (info->comm->nRanks == 1) {
     // one-rank reduce index
@@ -1289,8 +1330,8 @@ comp_next:
   *workFuncIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol);
 
   int stepSize = info->comm->buffSizes[info->protocol]/NCCL_STEPS;
-  int chunkSteps = (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) ? info->chunkSteps : 1;
-  int sliceSteps = (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) ? info->sliceSteps : 1;
+  int chunkSteps = info->chunkSteps;
+  int sliceSteps = info->sliceSteps;
   int chunkSize = stepSize*chunkSteps;
 
   // Compute lastChunkSize
@@ -1320,7 +1361,7 @@ comp_next:
     while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2;
     work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
   } else if (info->protocol == NCCL_PROTO_LL) {
-    const ssize_t sliceSize = stepSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine);
+    const ssize_t sliceSize = chunkSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine);
     const ssize_t loopSize = info->nChannels*info->nchunksPerLoop*(ssize_t)sliceSize;
     work->lastChunkSize = DIVUP((info->nBytes-(info->nBytes/loopSize)*loopSize), info->nChannels*info->nchunksPerLoop);
     ALIGN_SIZE(work->lastChunkSize, info->nThreads*sizeof(uint64_t));
diff --git a/src/include/collectives.h b/src/include/collectives.h
index f50a379..25c14f9 100644
--- a/src/include/collectives.h
+++ b/src/include/collectives.h
@@ -109,16 +109,5 @@ extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, float)();
 extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, double)();
 
 // CHUNKSIZE must be a multiple of SLICESIZE
-#define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4)
-#define ALLREDUCE_CHUNKSTEPS (NCCL_STEPS/2)
-#define ALLGATHER_SLICESTEPS (NCCL_STEPS/4)
-#define ALLGATHER_CHUNKSTEPS (NCCL_STEPS/2)
-#define REDUCESCATTER_SLICESTEPS (NCCL_STEPS/4)
-#define REDUCESCATTER_CHUNKSTEPS (NCCL_STEPS/2)
-#define BROADCAST_SLICESTEPS 1
-#define BROADCAST_CHUNKSTEPS 1
-#define REDUCE_SLICESTEPS 1
-#define REDUCE_CHUNKSTEPS 1
-#define NCCL_MAX_SLICE_PER_CHUNK 2 // max value for CHUNKSTEPS/SLICESTEPS, must accord with above
-
+#define NCCL_MAX_SLICE_PER_CHUNK 8 // max value for CHUNKSTEPS/SLICESTEPS, must accord with below
 #endif
diff --git a/src/include/devcomm.h b/src/include/devcomm.h
index 53d6838..5c74244 100644
--- a/src/include/devcomm.h
+++ b/src/include/devcomm.h
@@ -29,7 +29,7 @@ extern const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS];
 extern const char* ncclProtoStr[NCCL_NUM_PROTOCOLS];
 
 #define NCCL_MAX_OPS 2048
-#define NCCL_STEPS 8
+#define NCCL_STEPS 32
 
 union ncclLLFifoLine {
   /* Flags have to be *after* data, because otherwise, an incomplete receive
@@ -51,13 +51,13 @@ union ncclLLFifoLine {
 #define NCCL_MAX_NTHREADS 640
 #define NCCL_SIMPLE_MAX_NTHREADS 512
 #define NCCL_LL_MAX_NTHREADS 512
-#define NCCL_LL_LINES_PER_THREAD 8
+#define NCCL_LL_LINES_PER_THREAD 64 // Must be a multiple of NCCL_STEPS
 #ifdef TEST_LL_CLEANUP
 #define NCCL_LL_CLEAN_MASK 0x078 // Set to 0x100 to disable cleanup
 #define NCCL_LL_FLAG_MAX 0x100
 #define NCCL_LL_FLAG(a) ((uint32_t)((a) % NCCL_LL_FLAG_MAX))
 #else
-#define NCCL_LL_CLEAN_MASK 0x7ffffff8
+#define NCCL_LL_CLEAN_MASK 0x7fffffe0
 #define NCCL_LL_FLAG(a) ((uint32_t)(a))
 #endif
 // Make sure the clean mask will last for at least NCCL_NSTEPS
@@ -68,7 +68,7 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK
 #define NCCL_LL128_DATAELEMS (NCCL_LL128_LINEELEMS-1)
 
 #define NCCL_LL128_MAX_NTHREADS 640
-#define NCCL_LL128_ELEMS_PER_THREAD 120
+#define NCCL_LL128_ELEMS_PER_THREAD 960 // Must be a multiple of NCCL_STEPS
 
 #define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 8
 #define NCCL_LL128_SHMEM_SIZE (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS)
@@ -199,6 +199,8 @@ struct ncclWorkElem {
   uint32_t root;
   uint8_t bid;
   uint8_t nChannels;
+  uint8_t stepsPerSlice;
+  uint8_t slicesPerChunk;
   uint64_t redOpArg;
 };
 
@@ -212,7 +214,8 @@ struct ncclWorkElemP2p {
enum ncclWorkP2PType p2pType; uint8_t nWarps; uint8_t warpStart; - uint8_t ngroups; + uint8_t ngroups:5; + uint8_t stepsPerChunkPow2:3; // Important not to use any fields with greater than 4-byte alignment since // we need sizeof(ncclWorkElemP2p)==28, but that would be padded up to 32 if // there were 8-byte fields. diff --git a/src/include/nccl_net.h b/src/include/nccl_net.h index 255a44e..b2725c0 100644 --- a/src/include/nccl_net.h +++ b/src/include/nccl_net.h @@ -17,7 +17,7 @@ #define NCCL_PTR_DMABUF 0x4 // Maximum number of requests per comm object -#define NCCL_NET_MAX_REQUESTS 8 +#define NCCL_NET_MAX_REQUESTS 32 typedef enum {NCCL_LOG_NONE=0, NCCL_LOG_VERSION=1, NCCL_LOG_WARN=2, NCCL_LOG_INFO=3, NCCL_LOG_ABORT=4, NCCL_LOG_TRACE=5} ncclDebugLogLevel; typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCCL_GRAPH=32, NCCL_TUNING=64, NCCL_ENV=128, NCCL_ALLOC=256, NCCL_CALL=512, NCCL_ALL=~0} ncclDebugLogSubSys; diff --git a/src/init.cc b/src/init.cc index 91a8793..2b39afa 100644 --- a/src/init.cc +++ b/src/init.cc @@ -463,8 +463,8 @@ static ncclResult_t setupChannel(struct ncclComm* comm, int channelId, int rank, return ncclSuccess; } -#define DEFAULT_LL_BUFFSIZE (NCCL_LL_LINES_PER_THREAD*NCCL_LL_MAX_NTHREADS*NCCL_STEPS*sizeof(union ncclLLFifoLine)) -#define DEFAULT_LL128_BUFFSIZE (NCCL_LL128_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS*NCCL_STEPS*sizeof(uint64_t)) +#define DEFAULT_LL_BUFFSIZE (NCCL_LL_LINES_PER_THREAD*NCCL_LL_MAX_NTHREADS*sizeof(union ncclLLFifoLine)) +#define DEFAULT_LL128_BUFFSIZE (NCCL_LL128_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS*sizeof(uint64_t)) #define DEFAULT_BUFFSIZE (1 << 22) /* 4MiB */ #define DEFAULT_BUFFSIZE_ARM (1 << 20) /* 1MiB */ NCCL_PARAM(BuffSize, "BUFFSIZE", -2); diff --git a/src/proxy.cc b/src/proxy.cc index 2103b7a..13b98a3 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -430,15 +430,15 @@ ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op) int channelId = info->channelId; struct ncclChannel* channel = info->comm->channels+channelId; op->channelId = channelId; - op->sliceSteps = 1; - op->chunkSteps = 1; + op->sliceSteps = info->sliceSteps; + op->chunkSteps = info->chunkSteps; op->dtype = info->datatype; op->protocol = info->protocol; int stepSize = info->comm->buffSizes[op->protocol]/NCCL_STEPS; + int chunkSize = stepSize * info->chunkSteps; - if (op->protocol == NCCL_PROTO_SIMPLE) stepSize = info->comm->p2pChunkSize; - info->chunkSize = stepSize; + if (op->protocol == NCCL_PROTO_SIMPLE) chunkSize = info->comm->p2pChunkSize; op->root = info->root; struct ncclChannelPeer* peer = channel->peers + op->root; @@ -446,24 +446,27 @@ ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op) op->pattern = ncclPatternSend; if (op->root != info->comm->rank && peer->send[1].transportComm == &netTransport.send) { // Tune chunk size for the network - if (info->count < stepSize) info->chunkSize /= 4; - else if (info->count < 8*stepSize) info->chunkSize /= 2; + if (info->count < chunkSize) chunkSize /= 4; + else if (info->count < 8*chunkSize) chunkSize /= 2; } } else if (info->coll == ncclFuncRecv) { op->pattern = ncclPatternRecv; if (op->root != info->comm->rank && peer->recv[1].transportComm == &netTransport.recv) { // Tune chunk size for the network - if (info->count < stepSize) info->chunkSize /= 4; - else if (info->count < 8*stepSize) info->chunkSize /= 2; + if (info->count < chunkSize) chunkSize /= 4; + else if (info->count < 8*chunkSize) chunkSize /= 2; } } else { WARN("P2p operation is 
neither send or recv"); return ncclInternalError; } if (ncclParamChunkSize() != 0) { - info->chunkSize = ncclParamChunkSize(); + chunkSize = ncclParamChunkSize(); } - op->chunkSize = info->chunkSize; + if (chunkSize > stepSize*info->chunkSteps) { + chunkSize = stepSize*info->chunkSteps; + } + op->chunkSize = info->chunkSize = chunkSize; // Compute nSteps for proxies int chunkEffectiveSize = op->chunkSize; @@ -471,7 +474,7 @@ ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op) chunkEffectiveSize /= 2; } - op->nbytes = stepSize; + op->nbytes = chunkSize; op->nsteps = DIVUP(info->count, chunkEffectiveSize); if (op->nsteps == 0) op->nsteps = 1; diff --git a/src/transport/net.cc b/src/transport/net.cc index bdb2e2d..66fcfcf 100644 --- a/src/transport/net.cc +++ b/src/transport/net.cc @@ -837,7 +837,6 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg args->idle = 1; if (args->state == ncclProxyOpProgress) { int p = args->protocol; - int maxDepth = std::min(NCCL_STEPS, NCCL_SHARED_STEPS/args->nsubs); for (int s=0; snsubs; s++) { struct ncclProxySubArgs* sub = args->subs+s; if (sub->done == sub->nsteps) continue; @@ -848,6 +847,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg int buffSize = stepSize*args->sliceSteps; if (sub->nbytes < buffSize) buffSize = sub->nbytes; // Post buffers to the GPU + int maxDepth = resources->shared ? NCCL_SHARED_STEPS/args->nsubs : NCCL_STEPS; if (sub->posted < sub->nsteps && sub->posted < sub->done + maxDepth) { int buffSlot = (sub->base+sub->posted)%NCCL_STEPS; if (resources->shared) { @@ -883,7 +883,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg if (!ready) { // When data is in sysmem, we need to wait until all flags are correct since the GPU only // called threadfence() - uint64_t flag = sub->base+sub->transmitted+1; + uint64_t flag = sub->base+sub->transmitted+args->sliceSteps; int nFifoLines = DIVUP(sizesFifo[buffSlot], sizeof(uint64_t)*NCCL_LL128_LINEELEMS); volatile uint64_t* lines = (volatile uint64_t*)buff; ready = 1; @@ -892,7 +892,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg } } } else if (p == NCCL_PROTO_LL) { - uint32_t flag = NCCL_LL_FLAG(sub->base+sub->transmitted+1); + uint32_t flag = NCCL_LL_FLAG(sub->base+sub->transmitted+args->sliceSteps); int nFifoLines = DIVUP(size, sizeof(union ncclLLFifoLine)); union ncclLLFifoLine* lines = (union ncclLLFifoLine*)buff; for (int i=0; iidle = 1; if (args->state == ncclProxyOpProgress) { int p = args->protocol; - int maxDepth = std::min(NCCL_STEPS, NCCL_SHARED_STEPS/args->nsubs); for (int s=0; snsubs; s+=args->subs[s].groupSize) { struct ncclProxySubArgs* subGroup = args->subs+s; int subCount = 0; @@ -999,8 +998,9 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg for (int i=0; igroupSize; i++) { struct ncclProxySubArgs* sub = subGroup + i; if (sub->posted < sub->nsteps) { - if (sub->posted >= sub->done + maxDepth) { subCount = 0; break; } struct recvResources* resources = (struct recvResources*) (sub->connection->transportResources); + int maxDepth = resources->shared ? NCCL_SHARED_STEPS/args->nsubs : NCCL_STEPS; + if (sub->posted >= sub->done + maxDepth) { subCount = 0; break; } int stepSize = resources->buffSizes[p] / NCCL_STEPS; char* localBuff = NCCL_NET_MAP_GET_POINTER(&resources->map, cpu, buffs[p]); int buffSlot = (sub->base+sub->posted)%NCCL_STEPS;