nccl/src/collectives/device/reduce_scatter.h

/*************************************************************************
 * Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#include "devcomm.h"
#include "collectives.h"
#include "primitives.h"
namespace {
  template<typename T, typename RedOp, typename Proto>
  __device__ __forceinline__ void runRing(ncclWorkElem *args) {
    const int tid = threadIdx.x;
    const int nthreads = args->nWarps*WARP_SIZE;
    const int bid = args->bid;
    const int nChannels = args->nChannels;
    ncclRing *ring = &ncclShmem.channel.ring;
    int const *ringRanks = ring->userRanks;
    const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice*args->slicesPerChunk);
    // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
    const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2);
    const int nranks = ncclShmem.comm.nRanks;
    const ssize_t loopSize = nChannels*chunkSize;
    const ssize_t size = args->count;
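
    // size is the per-rank output count (args->count); sendbuff therefore
    // spans nranks*size elements. Each outer-loop iteration below covers
    // loopSize = nChannels*chunkSize elements of a block, one chunk per channel.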
    Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0>
      prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg,
            args->stepsPerSlice, args->slicesPerChunk);
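    // FanSymmetric<1>: exactly one receive peer (ring->prev) and one send peer
    // (ring->next). The prims.send / recvReduceSend / recvReduceCopy calls
    // below take element offsets into sendbuff (plus a recvbuff offset for the
    // final copy), not byte offsets.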

    for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
      ssize_t realChunkSize;
      if (Proto::Id == NCCL_PROTO_SIMPLE) {
        realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels));
        realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T));
      }
      else if (Proto::Id == NCCL_PROTO_LL)
        realChunkSize = size-gridOffset < loopSize ? args->lastChunkSize : chunkSize;
      else if (Proto::Id == NCCL_PROTO_LL128)
        realChunkSize = min(divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128, chunkSize);
      realChunkSize = int(realChunkSize);
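      // Each channel (bid) takes its own realChunkSize-sized slice of this
      // loop iteration; the branches above clamp the chunk to what remains
      // and round it to the protocol's transfer granularity.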
      ssize_t chunkOffset = gridOffset + bid*int(realChunkSize);

      /////////////// begin ReduceScatter steps ///////////////
      ssize_t offset;
      int nelem = min(realChunkSize, size-chunkOffset);
      int rankDest;

      // step 0: push data to next GPU
      rankDest = ringRanks[nranks-1];
      offset = chunkOffset + rankDest * size;
      prims.send(offset, nelem);

      // k-2 steps: reduce and copy to next GPU
      for (int j=2; j<nranks; ++j) {
        rankDest = ringRanks[nranks-j];
        offset = chunkOffset + rankDest * size;
        prims.recvReduceSend(offset, nelem);
      }

      // step k-1: reduce this buffer and data, which will produce the final result
      rankDest = ringRanks[0];
      offset = chunkOffset + rankDest * size;
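      // recvReduceCopy reduces the received data with the local contribution
      // at `offset` and writes the finished block to recvbuff at chunkOffset;
      // postOp=true applies any deferred post-op carried in redOpArg (used,
      // for example, by ncclAvg).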
      prims.recvReduceCopy(offset, chunkOffset, nelem, /*postOp=*/true);
    }
  }
}
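
// Per-protocol entry points: each RunWorkElement specialization below simply
// instantiates runRing with the matching protocol type.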
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    using Proto = ProtoSimple<>;
    runRing<T, RedOp, Proto>(args);
  }
};

template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    runRing<T, RedOp, ProtoLL>(args);
  }
};

template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    runRing<T, RedOp, ProtoLL128>(args);
  }
};