Increase depth of FIFO buffers and make chunksteps/slicesteps adjustable.
This commit is contained in:
parent 28189e2df8
commit c324f771db
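Summary of the change (editorial sketch, grounded in the hunks below): NCCL_STEPS grows from 8 to 32 (a deeper FIFO of proportionally smaller steps), and the Simple/LL/LL128 chunk/slice geometry moves from compile-time constants to per-operation runtime values driven by new NCCL_PARAM knobs. A minimal usage sketch, assuming NCCL_PARAM's usual convention that a parameter declared with name "RING_CHUNKSTEPS" is read from the NCCL_RING_CHUNKSTEPS environment variable (the helper function below is hypothetical):

#include <cstdlib>
#include "nccl.h"

// Hypothetical launcher helper exercising the knobs added in this commit.
// Values must pass getStepInfo() (added below): at most NCCL_STEPS/2,
// chunkSteps a multiple of sliceSteps, ratio <= NCCL_MAX_SLICE_PER_CHUNK.
void commInitWithTunedSteps(ncclComm_t* comm, int nranks, ncclUniqueId id, int rank) {
  setenv("NCCL_RING_CHUNKSTEPS", "16", /*overwrite=*/1);  // default: NCCL_STEPS/2
  setenv("NCCL_RING_SLICESTEPS", "8",  /*overwrite=*/1);  // default: NCCL_STEPS/4
  ncclCommInitRank(comm, nranks, id, rank);
}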
@@ -20,6 +20,6 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun

   struct ncclInfo info = { ncclFuncAllGather, "AllGather",
     sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
-    ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS };
+    };
   return ncclEnqueueCheck(&info);
 }

@@ -26,6 +26,6 @@ ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,

   struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
     sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
-    ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS };
+    };
   return ncclEnqueueCheck(&info);
 }
@@ -24,7 +24,7 @@ ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, n

   struct ncclInfo info = { ncclFuncBroadcast, "Broadcast",
     sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */
-    BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS };
+    };
   return ncclEnqueueCheck(&info);
 }
 /* Deprecated original "in place" function, similar to MPI */
@@ -17,7 +17,7 @@ namespace {
   const int nChannels = args->nChannels;
   ncclRing *ring = &ncclShmem.channel.ring;
   const int *ringRanks = ring->userRanks;
-  const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLGATHER_CHUNKSTEPS : 1));
+  const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice*args->slicesPerChunk);
   // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
   const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2);
   const int nranks = ncclShmem.comm.nRanks;

@@ -27,7 +27,8 @@ namespace {
   T *inputBuf = (T*)args->sendbuff;
   T *outputBuf = (T*)args->recvbuff;
   Primitives<T, RedOp, FanSymmetric<1>, 1, Proto, 0> prims
-    (tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg);
+    (tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg,
+     args->stepsPerSlice, args->slicesPerChunk);

   for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
     ssize_t realChunkSize;

@@ -79,7 +80,7 @@ namespace {
 template<typename T, typename RedOp>
 struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
   __device__ __forceinline__ void run(ncclWorkElem *args) {
-    using Proto = ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>;
+    using Proto = ProtoSimple<>;
     runRing<T, RedOp, Proto>(args);
   }
 };
@@ -17,7 +17,7 @@ namespace {
   const int nChannels = args->nChannels;
   ncclRing *ring = &ncclShmem.channel.ring;
   int ringIx = ring->index;
-  const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLREDUCE_CHUNKSTEPS : 1));
+  const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice*args->slicesPerChunk);
   const int nranks = ncclShmem.comm.nRanks;
   const ssize_t loopSize = nChannels*nranks*chunkSize;
   const ssize_t size = args->count;

@@ -31,7 +31,8 @@ namespace {
   }

   Primitives<T, RedOp, FanSymmetric<1>, 1, Proto, 0> prims
-    (tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg);
+    (tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg,
+     args->stepsPerSlice, args->slicesPerChunk);

   for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
     ssize_t realChunkSize;

@@ -103,7 +104,7 @@ namespace {
   ncclTree *tree = &ncclShmem.channel.tree;
   ssize_t chunkSize = int(
     Proto::Id == NCCL_PROTO_SIMPLE ? args->lastChunkSize
-                 /* LL & LL128 */ : Proto::calcBytePerStep()/sizeof(T));
+                 /* LL & LL128 */ : Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice);
   const ssize_t minChunkSize = int(
     Proto::Id == NCCL_PROTO_SIMPLE ? (nthreads-2*WARP_SIZE)*8*(sizeof(uint64_t)/sizeof(T))
                  /* LL & LL128 */ : nthreads*(Proto::calcBytePerGrain()/sizeof(T)));
@@ -115,7 +116,9 @@ namespace {

   { // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
     Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0> prims
-      (tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg);
+      (tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg,
+       args->stepsPerSlice, args->slicesPerChunk);

     if (tree->up == -1) {
       for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
         ssize_t offset = gridOffset + bid*int(chunkSize);

@@ -141,7 +144,8 @@ namespace {

   { // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
     Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/1, Proto, 0> prims
-      (tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg);
+      (tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg,
+       args->stepsPerSlice, args->slicesPerChunk);
     if (tree->up == -1) {
       for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
         ssize_t offset = gridOffset + bid*int(chunkSize);

@@ -175,7 +179,7 @@ namespace {
   ncclTree *tree = &ncclShmem.channel.tree;
   ssize_t chunkSize = int(
     Proto::Id != NCCL_PROTO_LL ? args->lastChunkSize
-                               : Proto::calcBytePerStep()/sizeof(T));
+                               : Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice);
   const ssize_t minChunkSize = int(
     Proto::Id == NCCL_PROTO_SIMPLE ? (nthreads - 2*WARP_SIZE)*8*(sizeof(uint64_t)/sizeof(T)) :
     Proto::Id == NCCL_PROTO_LL ? nthreads*(Proto::calcBytePerGrain()/sizeof(T))
@@ -199,7 +203,9 @@ namespace {
   if (tree->up == -1) {
     // Reduce and broadcast. Max number of recv is 3, max number of send is 3
     Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/1, Proto, 0>
-      prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff, args->redOpArg);
+      prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff, args->redOpArg,
+            args->stepsPerSlice, args->slicesPerChunk);

     for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
       ssize_t offset = gridOffset + bid*int(chunkSize);
       int nelem = min(chunkSize, size-offset);

@@ -216,7 +222,8 @@ namespace {
      * but the ctor above for tree roots would be DirectRecv=0 DirectSend=1.
      */
     Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/1, Proto, 0>
-      prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg, 0*Proto::MaxGroupWidth);
+      prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg,
+            args->stepsPerSlice, args->slicesPerChunk, 0*Proto::MaxGroupWidth);
     if (tree->down[0] == -1) {
       for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
         ssize_t offset = gridOffset + bid*int(chunkSize);

@@ -235,7 +242,8 @@ namespace {
   else {
     // Broadcast down. Max number of recv is 1, max number of send is 3 (binary tree + local)
     Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/1, Proto, 0>
-      prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg, 1*Proto::MaxGroupWidth);
+      prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg,
+            args->stepsPerSlice, args->slicesPerChunk, 1*Proto::MaxGroupWidth);
     if (tree->down[0] == -1) {
       for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
         ssize_t offset = gridOffset + bid*int(chunkSize);
@@ -257,8 +265,7 @@ namespace {
 template<typename T, typename RedOp>
 struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
   __device__ __forceinline__ void run(ncclWorkElem *args) {
-    using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>;
-    runRing<T, RedOp, Proto>(args);
+    runRing<T, RedOp, ProtoSimple<>>(args);
   }
 };

@@ -266,9 +273,9 @@ template<typename T, typename RedOp>
 struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE> {
   __device__ __forceinline__ void run(ncclWorkElem *args) {
     #if CUDART_VERSION >= 11020 && CUDART_VERSION < 11040 && __CUDA_ARCH__ >= 800
-    runTreeUpDown<T, RedOp, ProtoSimple<1, 1>>(args);
+    runTreeUpDown<T, RedOp, ProtoSimple<>>(args);
     #else
-    runTreeSplit<T, RedOp, ProtoSimple<1, 1>>(args);
+    runTreeSplit<T, RedOp, ProtoSimple<>>(args);
     #endif
   }
 };
@@ -295,13 +302,14 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
     const int tidStartScatter = tidStartBcast + nThreadsBcast;
     const int tidStartReduce = tidStartScatter + nThreadsScatter;

-    using Proto = ProtoSimple<1, 1>;
+    using Proto = ProtoSimple<>;

     if (tid >= tidStartScatter && tid < tidStartReduce && hasUp) {
       // Scatter
       int group = (2*Proto::MaxGroupWidth) | (1<<16);
       Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/1, Proto, 0>
-        prims(tid-tidStartScatter, nThreadsScatter, NULL, direct->up, args->sendbuff, args->recvbuff, args->redOpArg, group, args);
+        prims(tid-tidStartScatter, nThreadsScatter, NULL, direct->up, args->sendbuff, args->recvbuff, args->redOpArg,
+              args->stepsPerSlice, args->slicesPerChunk, group, args);
       for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
         ssize_t offset = gridOffset + bid*direct->nHeads*chunkSize;
         int nelem = min(direct->nHeads*chunkSize, size-offset);

@@ -316,7 +324,8 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
       if (hasDn) {
         // Reduce, send to network
         Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 1>, /*Direct=*/1, Proto, 0>
-          prims(tid-tidStartReduce, nThreadsReduce, direct->down, &direct->out, args->sendbuff, args->recvbuff, args->redOpArg, group, args);
+          prims(tid-tidStartReduce, nThreadsReduce, direct->down, &direct->out, args->sendbuff, args->recvbuff, args->redOpArg,
+                args->stepsPerSlice, args->slicesPerChunk, group, args);
         for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
           ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize;
           int nelem = min(chunkSize, size-offset);

@@ -329,7 +338,8 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
       } else {
         // Directly send to network
         Primitives<T, RedOp, FanAsymmetric<0, 1>, /*Direct=*/0, Proto, 0>
-          prims(tid-tidStartReduce, nThreadsReduce, nullptr, &direct->out, args->sendbuff, args->recvbuff, args->redOpArg, group);
+          prims(tid-tidStartReduce, nThreadsReduce, nullptr, &direct->out, args->sendbuff, args->recvbuff, args->redOpArg,
+                args->stepsPerSlice, args->slicesPerChunk, group);
         for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
           ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize;
           int nelem = min(chunkSize, size-offset);

@@ -340,7 +350,8 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
       // Gather
       int group = (0*Proto::MaxGroupWidth) | (0<<16);
       Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 0>, /*Direct=*/1, Proto, 0>
-        prims(tid, nThreadsGather, direct->up, NULL, args->sendbuff, args->recvbuff, args->redOpArg, group, args);
+        prims(tid, nThreadsGather, direct->up, NULL, args->sendbuff, args->recvbuff, args->redOpArg,
+              args->stepsPerSlice, args->slicesPerChunk, group, args);
       for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
         ssize_t offset = gridOffset + bid*direct->nHeads*chunkSize;
         int nelem = min(direct->nHeads*chunkSize, size-offset);

@@ -351,7 +362,8 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
       if (hasDn) {
         // Recv from network, broadcast
         Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/1, Proto, 0>
-          prims(tid-tidStartBcast, nThreadsBcast, &direct->out, direct->down, args->sendbuff, args->recvbuff, args->redOpArg, group, args);
+          prims(tid-tidStartBcast, nThreadsBcast, &direct->out, direct->down, args->sendbuff, args->recvbuff, args->redOpArg,
+                args->stepsPerSlice, args->slicesPerChunk, group, args);
         for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
           ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize;
           int nelem = min(chunkSize, size-offset);

@@ -360,7 +372,8 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
       } else {
         // Recv from network (no post thread needed)
         Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto, 0>
-          prims(tid-tidStartBcast, nThreadsBcast, &direct->out, nullptr, args->sendbuff, args->recvbuff, args->redOpArg, group);
+          prims(tid-tidStartBcast, nThreadsBcast, &direct->out, nullptr, args->sendbuff, args->recvbuff, args->redOpArg,
+                args->stepsPerSlice, args->slicesPerChunk, group);
         for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
           ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize;
           int nelem = min(chunkSize, size-offset);
@@ -387,7 +400,7 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_CHAIN, NCCL
   if (nthreadsSplit >= 256) nthreadsSplit += 64;

   int group, send, recv, groupTid, groupNthreads;
-  using Proto = ProtoSimple<1, 1>;
+  using Proto = ProtoSimple<>;
   if (tid < nthreadsSplit) {
     group = (0*Proto::MaxGroupWidth) | (1<<16);
     recv = tree->down[0];

@@ -403,7 +416,8 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_CHAIN, NCCL
   }

   Primitives<T, RedOp, FanSymmetric<1>, /*Direct=*/1, Proto, 0>
-    prims(groupTid, groupNthreads, &recv, &send, args->sendbuff, args->recvbuff, args->redOpArg, group);
+    prims(groupTid, groupNthreads, &recv, &send, args->sendbuff, args->recvbuff, args->redOpArg,
+          args->stepsPerSlice, args->slicesPerChunk, group);

   if (tid < nthreadsSplit) {
     if (recv == -1) {
@@ -16,7 +16,7 @@ namespace {
   const int bid = args->bid;
   const int nChannels = args->nChannels;
   ncclRing *ring = &ncclShmem.channel.ring;
-  const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? BROADCAST_CHUNKSTEPS : 1));
+  const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice);
   const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T)));
   const ssize_t loopSize = nChannels*chunkSize;
   const ssize_t size = args->count;

@@ -27,7 +27,8 @@ namespace {
   T *inputBuf = (T*)args->sendbuff;
   T *outputBuf = (T*)args->recvbuff;
   Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0>
-    prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg);
+    prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg,
+          args->stepsPerSlice, args->slicesPerChunk);

   for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
     ssize_t realChunkSize;

@@ -62,7 +63,7 @@ namespace {
 template<typename T, typename RedOp>
 struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
   __device__ __forceinline__ void run(ncclWorkElem *args) {
-    using Proto = ProtoSimple<BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS>;
+    using Proto = ProtoSimple<>;
     runRing<T, RedOp, Proto>(args);
   }
 };
@@ -20,11 +20,9 @@
  * to how that protocol operates with a consistent interface so that our
  * algorithm code can operate protocol parametrically.
  */
-template<int SlicePerChunk_1, int StepPerSlice_1, int Unroll_1 = COLL_UNROLL>
+template<int Unroll_1 = COLL_UNROLL>
 struct ProtoSimple {
   static constexpr int Id = NCCL_PROTO_SIMPLE;
-  static constexpr int SlicePerChunk = SlicePerChunk_1;
-  static constexpr int StepPerSlice = StepPerSlice_1;
   static constexpr int Unroll = Unroll_1;

   // Data bytes (no flags etc) in one step of the fifo queue.
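The hunk above is the heart of the refactor: SlicePerChunk and StepPerSlice leave the ProtoSimple type and travel through the Primitives constructor instead. A before/after sketch of a typical call site (a composite of hunks in this diff, not a new API):

// Before: geometry baked into the type at compile time.
using ProtoOld = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>;

// After: the type carries only the unroll factor; geometry arrives at runtime.
using Proto = ProtoSimple<>;
Primitives<T, RedOp, FanSymmetric<1>, 1, Proto, 0> prims
    (tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg,
     args->stepsPerSlice, args->slicesPerChunk);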
@@ -19,6 +19,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
   const int wid;
   const int group;
   const int stepLines;
+  const int stepsPerChunk;
   Fan fan;
   T *userBufs[2];
   struct ncclConnInfo* recvConn = NULL;

@@ -40,8 +41,8 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
   inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepLines; }
   inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); }
   inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); }
-  inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); }
-  inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); }
+  inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+stepsPerChunk); }
+  inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+stepsPerChunk); }

   inline __device__ void barrier() {
     if (nthreads == WARP_SIZE)

@@ -64,34 +65,34 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
   inline __device__ void waitSend(int nbytes) {
     if (sendConnHeadPtr) {
       int spins = 0;
-      while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
+      while (sendConnHeadCache + NCCL_STEPS < sendConnHead + stepsPerChunk) {
         sendConnHeadCache = *sendConnHeadPtr;
         if (checkAbort(spins, 1)) break;
       }
       if (sendConnFifoPtr) {
-        int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes;
+        int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepsPerChunk*stepLines*sizeof(union ncclLLFifoLine) : nbytes;
         sendConnFifoPtr[sendConnHead%NCCL_STEPS] = size;
       }
-      sendConnHead += 1;
+      sendConnHead += stepsPerChunk;
     }
     barrier();
   }

   inline __device__ void incRecv(int i) {
-    recvStep[i] += 1;
+    recvStep[i] += stepsPerChunk;
   }
   inline __device__ void postRecv() {
     barrier();
-    if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1;
+    if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += stepsPerChunk;
   }

   inline __device__ void incSend(int i, int offset) {
     // LL Cleanup : write all flags in the slice to make sure we don't have
     // data corruption when flag loops over.
     if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) {
-      for (int o = offset; o<stepLines; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i));
+      for (int o = offset; o<stepLines*stepsPerChunk; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i));
     }
-    sendStep[i]++;
+    sendStep[i] += stepsPerChunk;
   }

   __device__ uint64_t readLL(int offset, int i) {
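Because LL steps now advance by stepsPerChunk instead of 1, every flag derived from a step counter (recvFlag/sendFlag above) assumes the counters stay multiples of stepsPerChunk; the roundUp calls in the next hunks establish that when a connection is loaded. A standalone sketch of the invariant (hypothetical helper, simplified types):

#include <cassert>
#include <cstdint>

constexpr uint64_t roundUpTo(uint64_t x, uint64_t m) { return (x + m - 1) / m * m; }

void stepAccounting(uint64_t connStep, int stepsPerChunk, int nChunks) {
  uint64_t step = roundUpTo(connStep, stepsPerChunk);  // as in loadRecvConn/loadSendConn
  for (int c = 0; c < nChunks; c++) {
    assert(step % stepsPerChunk == 0);  // so NCCL_LL_FLAG(step+stepsPerChunk) is uniform
    step += stepsPerChunk;              // as in incRecv/incSend/postRecv
  }
}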
@@ -296,6 +297,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
   __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
     recvBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL];
     recvStep[i] = conn->step;
+    recvStep[i] = roundUp(recvStep[i], stepsPerChunk);
     if (wid == i) recvConn = conn;
   }
   __device__ __forceinline__ void loadRecvSync() {

@@ -308,6 +310,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
   __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
     sendBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL];
     sendStep[i] = conn->step;
+    sendStep[i] = roundUp(sendStep[i], stepsPerChunk);
     if (wid == i) sendConn = conn;
   }
   __device__ __forceinline__ void loadSendSync() {

@@ -322,10 +325,12 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
  public:
   __device__ Primitives(
       const int tid, const int nthreads, int const *recvPeers, int const *sendPeers,
-      void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0
+      void const *inputBuf, void *outputBuf, uint64_t redOpArg, int stepsPerChunk,
+      int ignored, int group=0
     ):
     redOp(redOpArg),
     tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group&(uint16_t)0xFFFF),
+    stepsPerChunk(stepsPerChunk),
     stepLines(ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)) {
     int connIndex = group >> 16;
     auto *channel = &ncclShmem.channel;
@@ -23,6 +23,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
   const int warpInBlock; // warp index in thread block
   const bool flagThread;
   const int group;
+  const int stepsPerChunk;
   Fan fan;
   T *userBufs[2];
   struct ncclConnInfo* recvConn = NULL;

@@ -46,8 +47,8 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
   inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepSize; }
   inline __device__ uint64_t* recvPtr(int i) { return recvBuff[i]+recvOffset(i); }
   inline __device__ uint64_t* sendPtr(int i) { return sendBuff[i]+sendOffset(i); }
-  inline __device__ uint64_t recvFlag(int i) { return recvStep[i]+1; }
-  inline __device__ uint64_t sendFlag(int i) { return sendStep[i]+1; }
+  inline __device__ uint64_t recvFlag(int i) { return recvStep[i]+stepsPerChunk; }
+  inline __device__ uint64_t sendFlag(int i) { return sendStep[i]+stepsPerChunk; }

   inline __device__ void barrier() {
     asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(15-group));

@@ -67,22 +68,22 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
   inline __device__ void waitSend(int nbytes) {
     if (sendConnHeadPtr) {
       int spins = 0;
-      while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
+      while (sendConnHeadCache + NCCL_STEPS < sendConnHead + stepsPerChunk) {
         sendConnHeadCache = *sendConnHeadPtr;
         if (checkAbort(spins, wid, 1)) break;
       }
       if (sendConnFifoPtr) {
         sendConnFifoPtr[sendStep[wid]%NCCL_STEPS] = nbytes;
       }
-      sendConnHead += 1;
+      sendConnHead += stepsPerChunk;
     }
   }

   inline __device__ void postRecv() {
-    if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1;
+    if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += stepsPerChunk;
   }
   inline __device__ void postSend() {
-    if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += 1; }
+    if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += stepsPerChunk; }
   }

   template<int WordPerThread>
@@ -315,15 +316,16 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
   }

     barrier();
-    if (SEND) for (int i=0; i < MaxSend; i++) sendStep[i] += 1;
+    if (SEND) for (int i=0; i < MaxSend; i++) sendStep[i] += stepsPerChunk;
     if (SEND) postSend();
-    if (RECV) for (int i=0; i < MaxRecv; i++) recvStep[i] += 1;
+    if (RECV) for (int i=0; i < MaxRecv; i++) recvStep[i] += stepsPerChunk;
     if (RECV) postRecv();
   }

   __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
     recvBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128];
     recvStep[i] = conn->step;
+    recvStep[i] = roundUp(recvStep[i], stepsPerChunk);
     if (wid == i) recvConn = conn;
   }
   __device__ __forceinline__ void loadRecvSync() {

@@ -336,6 +338,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
   __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
     sendBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128];
     sendStep[i] = conn->step;
+    sendStep[i] = roundUp(sendStep[i], stepsPerChunk);
     if (wid == i) sendConn = conn;
   }
   __device__ __forceinline__ void loadSendSync() {

@@ -356,12 +359,14 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
  public:
   __device__ Primitives(
       const int tid, const int nthreads, int const *recvPeers, int const *sendPeers,
-      void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0
+      void const *inputBuf, void *outputBuf, uint64_t redOpArg, int stepsPerChunk,
+      int ignored, int group=0
     ):
     redOp(redOpArg),
     tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE),
     warpInBlock(threadIdx.x/WARP_SIZE),
     flagThread((tid%8)==7), group(group&(uint16_t)0xFFFF),
+    stepsPerChunk(stepsPerChunk),
     stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) {
     int connIndex = group >> 16;
     auto *channel = &ncclShmem.channel;
@@ -4,10 +4,9 @@
  * See LICENSE.txt for license information
  ************************************************************************/

-template<typename T, typename RedOp, typename Fan, int Direct,
-         int SlicePerChunk, int StepPerSlice, int Unroll, int P2p>
+template<typename T, typename RedOp, typename Fan, int Direct, int Unroll, int P2p>
 class Primitives<
-  T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, Unroll>, P2p
+  T, RedOp, Fan, Direct, ProtoSimple<Unroll>, P2p
   > {
   static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
   static constexpr int Input=0, Output=1;

@@ -27,6 +26,8 @@ class Primitives<
   int nthreads;
   int nworkers;
   const int stepSize;
+  const int stepsPerSlice;
+  const int slicesPerChunk;
   Fan fan;
   int index; // Peer index I'm responsible for
   int flags;
@@ -79,10 +80,10 @@ class Primitives<
     if (((flags & (Recv*RoleWaitRecv)) && !noRecvWait) ||
         ((flags & (Send*RoleWaitSend)) && !noSendWait)) {
       int spins = 0;
-      while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) {
+      while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + stepsPerSlice) {
        connStepCache = *connStepPtr;
         if (checkAbort(spins)) break;
-        //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
+        //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+stepsPerSlice));
       }
     }

@@ -114,14 +115,14 @@ class Primitives<
       else {
         ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
       }
-      step += StepPerSlice;
+      step += stepsPerSlice;
     }
   }

   template<int Recv, int Send>
   inline __device__ void postPeer() {
     if (flags & (Recv*RolePostRecv | Send*RolePostSend)) {
-      step += StepPerSlice;
+      step += stepsPerSlice;
       *connStepPtr = step;
     }
   }

@@ -136,8 +137,8 @@ class Primitives<
     constexpr int Dst = DstBuf != -1;

     nelem = nelem < 0 ? 0 : nelem;
-    int sliceSize = stepSize*StepPerSlice;
-    sliceSize = max(divUp(nelem, 16*SlicePerChunk)*16, sliceSize/32);
+    int sliceSize = stepSize*stepsPerSlice;
+    sliceSize = max(divUp(nelem, 16*slicesPerChunk)*16, sliceSize/32);
     int slice = 0;
     int offset = 0;
@@ -165,12 +166,7 @@ class Primitives<
     // barrier();
     // post();
     // } // Since we no longer unroll, new branch added here
-    #if __CUDA_ARCH__ < 700
-    // Yeah, so all that above don't matter a lick on older hardware.
-    #pragma unroll SlicePerChunk
-    #else
     #pragma unroll 1
-    #endif
     do {
       sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset;
       if (Src && (flags & (SrcBuf==Input ? RoleInput : RoleOutput)))
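Why the #if/#pragma unroll SlicePerChunk block above collapses to a plain #pragma unroll 1: an unroll factor must be a compile-time constant, and SlicePerChunk was one (a template parameter) while slicesPerChunk is now a runtime member. An illustrative contrast (hypothetical functions, not from the diff):

template<int N>
__device__ void unrollByTemplate() {
  #pragma unroll N            // OK: N is an integral constant expression
  for (int i = 0; i < N; i++) { /* ... */ }
}

__device__ void unrollByRuntime(int n) {
  #pragma unroll 1            // forced: n is only known at run time
  for (int i = 0; i < n; i++) { /* ... */ }
}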
@@ -214,7 +210,7 @@ class Primitives<
       postPeer<Recv, Send>();
       offset += sliceSize;
       slice += 1;
-    } while (slice < SlicePerChunk && offset < nelem);
+    } while (slice < slicesPerChunk && offset < nelem);
   }

   // Non-workers come straight here. Workers too but only once the remaining

@@ -222,7 +218,7 @@ class Primitives<
   // worker perf is the limiter, perf-wise this loop is effectively unentered,
   // hence just a single branch insn.
   #pragma unroll 1
-  while (slice < SlicePerChunk) {
+  while (slice < slicesPerChunk) {
     sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset;
     { // Only workers could have Wait roles so we know the slice must be empty
       // since we've exited the loop above.

@@ -246,11 +242,11 @@ class Primitives<
     constexpr int DirectRecv = 1 && Direct && DirectRecv1;
     constexpr int DirectSend = 1 && Direct && DirectSend1;
     int offset = 0; // slice offset
-    int sliceSize = stepSize*StepPerSlice;
-    int dataSize = max(DIVUP(peerElem, 16*SlicePerChunk)*16, sliceSize/32); // per-peer slice size
+    int sliceSize = stepSize*stepsPerSlice;
+    int dataSize = max(DIVUP(peerElem, 16*slicesPerChunk)*16, sliceSize/32); // per-peer slice size

     #pragma unroll
-    for (int slice=0; slice<SlicePerChunk; ++slice) {
+    for (int slice=0; slice<slicesPerChunk; ++slice) {
       int realSize = max(0, min(dataSize, peerElem-offset));
       if (tid < nworkers) {
         if (Send) {
@@ -313,7 +309,7 @@ class Primitives<
     if (flags & (RoleWaitRecv|RolePostRecv)) {
       auto *conn = &peer->recv[connIndex];
       step = conn->step;
-      step = roundUp(step, SlicePerChunk*StepPerSlice);
+      step = roundUp(step, slicesPerChunk*stepsPerSlice);
       if (flags & RolePostRecv) {
         connStepPtr = conn->head;
         *connStepPtr = step; // Return credits in case we rounded up.

@@ -353,7 +349,7 @@ class Primitives<
     if (flags & (RoleWaitSend|RolePostSend)) {
       auto *conn = &peer->send[connIndex];
       step = conn->step;
-      step = roundUp(step, SlicePerChunk*StepPerSlice);
+      step = roundUp(step, slicesPerChunk*stepsPerSlice);
       if (flags & RolePostSend) {
         connStepPtr = conn->tail;
       }

@@ -395,9 +391,11 @@ class Primitives<
  public:
   __device__ Primitives(
       int tid, int nthreads, int const *recvPeers, int const *sendPeers,
-      void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint32_t group=0, struct ncclWorkElem* e = nullptr
+      void const *inputBuf, void *outputBuf, uint64_t redOpArg,
+      int stepsPerSlice, int slicesPerChunk,
+      uint32_t group=0, struct ncclWorkElem* e = nullptr
     ):
-    tid(tid),
+    tid(tid), stepsPerSlice(stepsPerSlice), slicesPerChunk(slicesPerChunk),
     stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)) {

     // For send operations, we need an extra warp to overlap the threadfence and the copy
@@ -16,7 +16,7 @@ namespace {
   const int bid = args->bid;
   const int nChannels = args->nChannels;
   ncclRing *ring = &ncclShmem.channel.ring;
-  const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCE_CHUNKSTEPS : 1));
+  const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice);
   const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T)));
   const int nranks = ncclShmem.comm.nRanks;
   const ssize_t loopSize = nChannels*chunkSize;

@@ -26,7 +26,8 @@ namespace {
   const int root = args->root;

   Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0>
-    prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg);
+    prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg,
+          args->stepsPerSlice, args->slicesPerChunk);

   auto calcChunkSize = [&]__device__(ssize_t gridOffset)->int {
     int realChunkSize;

@@ -71,7 +72,7 @@ namespace {
 template<typename T, typename RedOp>
 struct RunWorkElement<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
   __device__ __forceinline__ void run(ncclWorkElem *args) {
-    using Proto = ProtoSimple<REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS>;
+    using Proto = ProtoSimple<>;
     runRing<T, RedOp, Proto>(args);
   }
 };
@@ -17,7 +17,7 @@ namespace {
   const int nChannels = args->nChannels;
   ncclRing *ring = &ncclShmem.channel.ring;
   int const *ringRanks = ring->userRanks;
-  const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCESCATTER_CHUNKSTEPS : 1));
+  const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * args->stepsPerSlice*args->slicesPerChunk);
   // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
   const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2);
   const int nranks = ncclShmem.comm.nRanks;

@@ -25,7 +25,8 @@ namespace {
   const ssize_t size = args->count;

   Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0>
-    prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg);
+    prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg,
+          args->stepsPerSlice, args->slicesPerChunk);

   for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
     ssize_t realChunkSize;

@@ -69,7 +70,7 @@ namespace {
 template<typename T, typename RedOp>
 struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
   __device__ __forceinline__ void run(ncclWorkElem *args) {
-    using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>;
+    using Proto = ProtoSimple<>;
     runRing<T, RedOp, Proto>(args);
   }
 };
@@ -25,7 +25,8 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
       if (args->proto == NCCL_PROTO_LL) chunkSize /= 2;
       int const peer = args->peer;
       Primitives<T, RedOp, FanAsymmetric<0, 1>, 1, Proto, 1> prims
-        (tid, nthreads, nullptr, &peer, buff, nullptr, /*redOpArg(ignored)=*/0, group);
+        (tid, nthreads, nullptr, &peer, buff, nullptr, /*redOpArg(ignored)=*/0,
+         1<<args->stepsPerChunkPow2, 1, group);
       size_t offset = 0;
       do {
         int nelem = min(size_t(chunkSize), count-offset);

@@ -44,7 +45,8 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
       if (args->proto == NCCL_PROTO_LL) chunkSize /= 2; // This is to account for chunkEffectiveSize
       int const peer = args->peer;
       Primitives<T, RedOp, FanAsymmetric<1, 0>, 1, Proto, 1> prims
-        (tid, nthreads, &peer, nullptr, nullptr, buff, /*redOpArg(ignored)=*/0, group);
+        (tid, nthreads, &peer, nullptr, nullptr, buff, /*redOpArg(ignored)=*/0,
+         1<<args->stepsPerChunkPow2, 1, group);
       size_t offset = 0;
       do {
         int nelem = min(size_t(chunkSize), count-offset);

@@ -79,13 +81,13 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
       if (args->proto == NCCL_PROTO_LL) {
         runRecv<ProtoLL>(tid, nthreads, group, args);
       } else {
-        runRecv<ProtoSimple<1,1>>(tid, nthreads, group, args);
+        runRecv<ProtoSimple<>>(tid, nthreads, group, args);
       }
     } else {
       if (args->proto == NCCL_PROTO_LL) {
         runSend<ProtoLL>(tid, nthreads, group, args);
       } else {
-        runSend<ProtoSimple<1,1>>(tid, nthreads, group, args);
+        runSend<ProtoSimple<>>(tid, nthreads, group, args);
       }
     }
   }
@@ -28,6 +28,6 @@ ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count,

   struct ncclInfo info = { ncclFuncReduce, "Reduce",
     sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */
-    REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS };
+    };
   return ncclEnqueueCheck(&info);
 }

@@ -26,6 +26,6 @@ ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recv

   struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter",
     sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */
-    REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS };
+    };
   return ncclEnqueueCheck(&info);
 }
@@ -25,8 +25,8 @@ ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatyp
   NVTX3_FUNC_WITH_PARAMS(Send, SendRecvSchema, payload)

   struct ncclInfo info = { ncclFuncSend, "Send",
-    NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
-    1, 1 };
+    NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream /* Args */
+    };
   ncclResult_t ret;
   NCCLCHECK(ncclGroupStart());
   ret = ncclEnqueueCheck(&info);

@@ -43,7 +43,7 @@ ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int

   struct ncclInfo info = { ncclFuncRecv, "Recv",
     NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
-    1, 1 };
+    };
   ncclResult_t ret;
   NCCLCHECK(ncclGroupStart());
   ret = ncclEnqueueCheck(&info);
@@ -332,6 +332,7 @@ static ncclResult_t addCollToPlan(
 }

 NCCL_PARAM(P2pLLThreshold, "P2P_LL_THRESHOLD", 16384);
+NCCL_PARAM(P2pChunkSteps, "P2P_CHUNKSTEPS", (NCCL_STEPS/8));

 // Put p2p op in plan assuming there is space in nWorkBudget, so you must
 // ensure *nWorkBudget >= 1 upon entry.

@@ -343,8 +344,10 @@ static ncclResult_t addP2pToPlan(
     isSendNotRecv ? ncclFuncSend : ncclFuncRecv,
     isSendNotRecv ? "Send" : "Recv",
     nullptr, addr, bytes, ncclInt8, ncclSum, peer, comm, (cudaStream_t)0,
-    /*Args*/1, 1
+    (int)ncclParamP2pChunkSteps(), (int)ncclParamP2pChunkSteps()
   };
+  // Shared buffers do not support *steps>1.
+  if (comm->nNodes > 1) info.sliceSteps = info.chunkSteps = 1;

   int channelId;
   NCCLCHECK(ncclChannelCompute(comm, peer, chunk%comm->p2pnChannelsPerPeer, info.coll, &channelId));

@@ -363,6 +366,7 @@ static ncclResult_t addP2pToPlan(
   elem.peer = peer;
   elem.nWarps = NCCL_MAX_NTHREADS/WARP_SIZE;
   elem.p2pType = isSendNotRecv ? ncclWorkP2pTypeSend : ncclWorkP2pTypeRecv;
+  elem.stepsPerChunkPow2 = log2i(info.chunkSteps);
   elem.buffLo32 = uint32_t(reinterpret_cast<uintptr_t>(addr));
   elem.buffHi32 = reinterpret_cast<uintptr_t>(addr)>>32;
   elem.countLo32 = uint32_t(bytes);
@@ -609,7 +613,7 @@ static ncclResult_t scheduleP2pTasksToPlan(

   // Compute how much to split operations
   // Natural step size matching buffer steps.
-  ssize_t stepSize = comm->p2pChunkSize;
+  ssize_t chunkSize = comm->p2pChunkSize;
   // Try to use all channels
   int nChannelsMax = comm->p2pnChannelsPerPeer;
   int nChannelsMin = nChannelsMax;

@@ -643,8 +647,8 @@ static ncclResult_t scheduleP2pTasksToPlan(
       char* sendPtr = send ? (char*)send->buff : nullptr;
       ssize_t recvBytes = recv ? recv->bytes : 0;
       ssize_t sendBytes = send ? send->bytes : 0;
-      ssize_t minSize = stepSize/8;
-      ssize_t maxSize = comm->nNodes > 1 ? stepSize : stepSize*32;
+      ssize_t minSize = chunkSize/8;
+      ssize_t maxSize = comm->nNodes > 1 ? chunkSize : chunkSize*32;
       ssize_t recvChunkBytesMax = calcP2pChunkSize(recvBytes, nChannelsMin, nChannelsMax, minSize, maxSize);
       ssize_t sendChunkBytesMax = calcP2pChunkSize(sendBytes, nChannelsMin, nChannelsMax, minSize, maxSize);
       // Zero size send/recv are syncs, encode here with -1.
@@ -1258,6 +1262,40 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) {
   return ncclSuccess;
 }

+NCCL_PARAM(LLChunkSteps, "LL_CHUNKSTEPS", (NCCL_STEPS/8));
+NCCL_PARAM(LL128ChunkSteps, "LL128_CHUNKSTEPS", (NCCL_STEPS/8));
+NCCL_PARAM(PipelineChunkSteps, "PIPELINE_CHUNKSTEPS", (NCCL_STEPS/8));
+NCCL_PARAM(RingChunkSteps, "RING_CHUNKSTEPS", (NCCL_STEPS/2));
+NCCL_PARAM(RingSliceSteps, "RING_SLICESTEPS", (NCCL_STEPS/4));
+
+static ncclResult_t getStepInfo(struct ncclInfo* info) {
+  if (info->protocol == NCCL_PROTO_LL) {
+    info->chunkSteps = info->sliceSteps = ncclParamLLChunkSteps();
+  } else if (info->protocol == NCCL_PROTO_LL128) {
+    info->chunkSteps = info->sliceSteps = ncclParamLL128ChunkSteps();
+  } else { /* SIMPLE */
+    if (info->algorithm == NCCL_ALGO_TREE || info->coll == ncclFuncBroadcast || info->coll == ncclFuncReduce) {
+      info->chunkSteps = info->sliceSteps = ncclParamPipelineChunkSteps();
+    } else {
+      info->chunkSteps = ncclParamRingChunkSteps();
+      info->sliceSteps = ncclParamRingSliceSteps();
+    }
+  }
+  if (info->chunkSteps > NCCL_STEPS/2 || info->sliceSteps > NCCL_STEPS/2) {
+    WARN("Invalid chunkSteps=%d/sliceSteps=%d, must be at most NCCL_STEPS/2=%d\n", info->chunkSteps, info->sliceSteps, NCCL_STEPS/2);
+    return ncclInvalidUsage;
+  }
+  if (info->chunkSteps % info->sliceSteps) {
+    WARN("Invalid chunkSteps=%d, must be a multiple of sliceSteps=%d\n", info->chunkSteps, info->sliceSteps);
+    return ncclInvalidUsage;
+  }
+  if (info->chunkSteps / info->sliceSteps > NCCL_MAX_SLICE_PER_CHUNK) {
+    WARN("Invalid chunkSteps=%d, must be at most sliceSteps*%d=%d\n", info->chunkSteps, NCCL_MAX_SLICE_PER_CHUNK, info->sliceSteps*NCCL_MAX_SLICE_PER_CHUNK);
+    return ncclInvalidUsage;
+  }
+  return ncclSuccess;
+}
+
 static ncclResult_t computeColl(struct ncclInfo* info /* input */, int* workFuncIndex, struct ncclWorkElem* work, struct ncclProxyOp* proxyOp /* output */) {
   int collNetTypeSupport = 0;
   // Check whether algo and proto have been preset (as in aggregation case)
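To make the three checks in getStepInfo() concrete, here is a standalone mirror of the validation with a few worked values (hypothetical helper, not in the diff; with NCCL_STEPS == 32 the cap is 16 and NCCL_MAX_SLICE_PER_CHUNK is 8):

constexpr bool stepsValid(int chunkSteps, int sliceSteps) {
  return chunkSteps <= 32/2 && sliceSteps <= 32/2   // NCCL_STEPS/2
      && chunkSteps % sliceSteps == 0
      && chunkSteps / sliceSteps <= 8;              // NCCL_MAX_SLICE_PER_CHUNK
}
static_assert(stepsValid(16, 8),  "ring Simple defaults: NCCL_STEPS/2, NCCL_STEPS/4");
static_assert(stepsValid(4, 4),   "LL/LL128/pipeline default: NCCL_STEPS/8");
static_assert(!stepsValid(20, 4), "chunkSteps above NCCL_STEPS/2");
static_assert(!stepsValid(12, 8), "chunkSteps not a multiple of sliceSteps");
static_assert(!stepsValid(16, 1), "chunk/slice ratio above NCCL_MAX_SLICE_PER_CHUNK");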
@@ -1270,6 +1308,7 @@ comp_next:
   // Set nstepsPerLoop and nchunksPerLoop
   NCCLCHECK(getPatternInfo(info));
   NCCLCHECK(getLoopInfo(info));
+  NCCLCHECK(getStepInfo(info));

   work->sendbuff = info->sendbuff;
   work->recvbuff = info->recvbuff;

@@ -1279,6 +1318,8 @@ comp_next:
   work->nWarps = info->nThreads / WARP_SIZE;
   work->redOpArg = info->opFull.scalarArg;
   work->redOpArgIsPtr = info->opFull.scalarArgIsPtr;
+  work->stepsPerSlice = info->sliceSteps;
+  work->slicesPerChunk = info->chunkSteps/info->sliceSteps;

   if (info->comm->nRanks == 1) {
     // one-rank reduce index

@@ -1289,8 +1330,8 @@ comp_next:
   *workFuncIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol);

   int stepSize = info->comm->buffSizes[info->protocol]/NCCL_STEPS;
-  int chunkSteps = (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) ? info->chunkSteps : 1;
-  int sliceSteps = (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) ? info->sliceSteps : 1;
+  int chunkSteps = info->chunkSteps;
+  int sliceSteps = info->sliceSteps;
   int chunkSize = stepSize*chunkSteps;

   // Compute lastChunkSize
@@ -1320,7 +1361,7 @@ comp_next:
     while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2;
     work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
   } else if (info->protocol == NCCL_PROTO_LL) {
-    const ssize_t sliceSize = stepSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine);
+    const ssize_t sliceSize = chunkSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine);
     const ssize_t loopSize = info->nChannels*info->nchunksPerLoop*(ssize_t)sliceSize;
     work->lastChunkSize = DIVUP((info->nBytes-(info->nBytes/loopSize)*loopSize), info->nChannels*info->nchunksPerLoop);
     ALIGN_SIZE(work->lastChunkSize, info->nThreads*sizeof(uint64_t));
@@ -109,16 +109,5 @@ extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, float)();
 extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, double)();

-// CHUNKSIZE must be a multiple of SLICESIZE
-#define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4)
-#define ALLREDUCE_CHUNKSTEPS (NCCL_STEPS/2)
-#define ALLGATHER_SLICESTEPS (NCCL_STEPS/4)
-#define ALLGATHER_CHUNKSTEPS (NCCL_STEPS/2)
-#define REDUCESCATTER_SLICESTEPS (NCCL_STEPS/4)
-#define REDUCESCATTER_CHUNKSTEPS (NCCL_STEPS/2)
-#define BROADCAST_SLICESTEPS 1
-#define BROADCAST_CHUNKSTEPS 1
-#define REDUCE_SLICESTEPS 1
-#define REDUCE_CHUNKSTEPS 1
-#define NCCL_MAX_SLICE_PER_CHUNK 2 // max value for CHUNKSTEPS/SLICESTEPS, must accord with above
+#define NCCL_MAX_SLICE_PER_CHUNK 8 // max value for CHUNKSTEPS/SLICESTEPS, must accord with below

 #endif
@@ -29,7 +29,7 @@ extern const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS];
 extern const char* ncclProtoStr[NCCL_NUM_PROTOCOLS];

 #define NCCL_MAX_OPS 2048
-#define NCCL_STEPS 8
+#define NCCL_STEPS 32

 union ncclLLFifoLine {
   /* Flags have to be *after* data, because otherwise, an incomplete receive

@@ -51,13 +51,13 @@ union ncclLLFifoLine {
 #define NCCL_MAX_NTHREADS 640
 #define NCCL_SIMPLE_MAX_NTHREADS 512
 #define NCCL_LL_MAX_NTHREADS 512
-#define NCCL_LL_LINES_PER_THREAD 8
+#define NCCL_LL_LINES_PER_THREAD 64 // Must be a multiple of NCCL_STEPS
 #ifdef TEST_LL_CLEANUP
 #define NCCL_LL_CLEAN_MASK 0x078 // Set to 0x100 to disable cleanup
 #define NCCL_LL_FLAG_MAX 0x100
 #define NCCL_LL_FLAG(a) ((uint32_t)((a) % NCCL_LL_FLAG_MAX))
 #else
-#define NCCL_LL_CLEAN_MASK 0x7ffffff8
+#define NCCL_LL_CLEAN_MASK 0x7fffffe0
 #define NCCL_LL_FLAG(a) ((uint32_t)(a))
 #endif
 // Make sure the clean mask will last for at least NCCL_NSTEPS

@@ -68,7 +68,7 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK
 #define NCCL_LL128_DATAELEMS (NCCL_LL128_LINEELEMS-1)

 #define NCCL_LL128_MAX_NTHREADS 640
-#define NCCL_LL128_ELEMS_PER_THREAD 120
+#define NCCL_LL128_ELEMS_PER_THREAD 960 // Must be a multiple of NCCL_STEPS

 #define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 8
 #define NCCL_LL128_SHMEM_SIZE (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS)
@@ -199,6 +199,8 @@ struct ncclWorkElem {
   uint32_t root;
   uint8_t bid;
   uint8_t nChannels;
+  uint8_t stepsPerSlice;
+  uint8_t slicesPerChunk;
   uint64_t redOpArg;
 };

@@ -212,7 +214,8 @@ struct ncclWorkElemP2p {
   enum ncclWorkP2PType p2pType;
   uint8_t nWarps;
   uint8_t warpStart;
-  uint8_t ngroups;
+  uint8_t ngroups:5;
+  uint8_t stepsPerChunkPow2:3;
   // Important not to use any fields with greater than 4-byte alignment since
   // we need sizeof(ncclWorkElemP2p)==28, but that would be padded up to 32 if
   // there were 8-byte fields.
@@ -17,7 +17,7 @@
 #define NCCL_PTR_DMABUF 0x4

 // Maximum number of requests per comm object
-#define NCCL_NET_MAX_REQUESTS 8
+#define NCCL_NET_MAX_REQUESTS 32

 typedef enum {NCCL_LOG_NONE=0, NCCL_LOG_VERSION=1, NCCL_LOG_WARN=2, NCCL_LOG_INFO=3, NCCL_LOG_ABORT=4, NCCL_LOG_TRACE=5} ncclDebugLogLevel;
 typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCCL_GRAPH=32, NCCL_TUNING=64, NCCL_ENV=128, NCCL_ALLOC=256, NCCL_CALL=512, NCCL_ALL=~0} ncclDebugLogSubSys;
@@ -463,8 +463,8 @@ static ncclResult_t setupChannel(struct ncclComm* comm, int channelId, int rank,
   return ncclSuccess;
 }

-#define DEFAULT_LL_BUFFSIZE (NCCL_LL_LINES_PER_THREAD*NCCL_LL_MAX_NTHREADS*NCCL_STEPS*sizeof(union ncclLLFifoLine))
-#define DEFAULT_LL128_BUFFSIZE (NCCL_LL128_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS*NCCL_STEPS*sizeof(uint64_t))
+#define DEFAULT_LL_BUFFSIZE (NCCL_LL_LINES_PER_THREAD*NCCL_LL_MAX_NTHREADS*sizeof(union ncclLLFifoLine))
+#define DEFAULT_LL128_BUFFSIZE (NCCL_LL128_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS*sizeof(uint64_t))
 #define DEFAULT_BUFFSIZE (1 << 22) /* 4MiB */
 #define DEFAULT_BUFFSIZE_ARM (1 << 20) /* 1MiB */
 NCCL_PARAM(BuffSize, "BUFFSIZE", -2);
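A quick consistency check on the new defaults (my arithmetic, not in the diff): dropping the NCCL_STEPS factor from these macros while scaling the per-thread constants leaves the default buffer sizes exactly where they were; only the step granularity within them changes.

static_assert(64*512 == 8*512*8,
    "LL: new lines/thread (64) x threads == old (8) x threads x old NCCL_STEPS (8)");
static_assert(960*640 == 120*640*8,
    "LL128: new elems/thread (960) x threads == old (120) x threads x old NCCL_STEPS (8)");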
src/proxy.cc
@@ -430,15 +430,15 @@ ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op)
   int channelId = info->channelId;
   struct ncclChannel* channel = info->comm->channels+channelId;
   op->channelId = channelId;
-  op->sliceSteps = 1;
-  op->chunkSteps = 1;
+  op->sliceSteps = info->sliceSteps;
+  op->chunkSteps = info->chunkSteps;
   op->dtype = info->datatype;
   op->protocol = info->protocol;

   int stepSize = info->comm->buffSizes[op->protocol]/NCCL_STEPS;
+  int chunkSize = stepSize * info->chunkSteps;

-  if (op->protocol == NCCL_PROTO_SIMPLE) stepSize = info->comm->p2pChunkSize;
-  info->chunkSize = stepSize;
+  if (op->protocol == NCCL_PROTO_SIMPLE) chunkSize = info->comm->p2pChunkSize;
   op->root = info->root;

   struct ncclChannelPeer* peer = channel->peers + op->root;

@@ -446,24 +446,27 @@ ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op)
     op->pattern = ncclPatternSend;
     if (op->root != info->comm->rank && peer->send[1].transportComm == &netTransport.send) {
       // Tune chunk size for the network
-      if (info->count < stepSize) info->chunkSize /= 4;
-      else if (info->count < 8*stepSize) info->chunkSize /= 2;
+      if (info->count < chunkSize) chunkSize /= 4;
+      else if (info->count < 8*chunkSize) chunkSize /= 2;
     }
   } else if (info->coll == ncclFuncRecv) {
     op->pattern = ncclPatternRecv;
     if (op->root != info->comm->rank && peer->recv[1].transportComm == &netTransport.recv) {
       // Tune chunk size for the network
-      if (info->count < stepSize) info->chunkSize /= 4;
-      else if (info->count < 8*stepSize) info->chunkSize /= 2;
+      if (info->count < chunkSize) chunkSize /= 4;
+      else if (info->count < 8*chunkSize) chunkSize /= 2;
     }
   } else {
     WARN("P2p operation is neither send or recv");
     return ncclInternalError;
   }
   if (ncclParamChunkSize() != 0) {
-    info->chunkSize = ncclParamChunkSize();
+    chunkSize = ncclParamChunkSize();
   }
-  op->chunkSize = info->chunkSize;
+  if (chunkSize > stepSize*info->chunkSteps) {
+    chunkSize = stepSize*info->chunkSteps;
+  }
+  op->chunkSize = info->chunkSize = chunkSize;

   // Compute nSteps for proxies
   int chunkEffectiveSize = op->chunkSize;

@@ -471,7 +474,7 @@ ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op)
     chunkEffectiveSize /= 2;
   }

-  op->nbytes = stepSize;
+  op->nbytes = chunkSize;
   op->nsteps = DIVUP(info->count, chunkEffectiveSize);
   if (op->nsteps == 0) op->nsteps = 1;
@@ -837,7 +837,6 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
   args->idle = 1;
   if (args->state == ncclProxyOpProgress) {
     int p = args->protocol;
-    int maxDepth = std::min(NCCL_STEPS, NCCL_SHARED_STEPS/args->nsubs);
     for (int s=0; s<args->nsubs; s++) {
       struct ncclProxySubArgs* sub = args->subs+s;
       if (sub->done == sub->nsteps) continue;

@@ -848,6 +847,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
       int buffSize = stepSize*args->sliceSteps;
      if (sub->nbytes < buffSize) buffSize = sub->nbytes;
       // Post buffers to the GPU
+      int maxDepth = resources->shared ? NCCL_SHARED_STEPS/args->nsubs : NCCL_STEPS;
       if (sub->posted < sub->nsteps && sub->posted < sub->done + maxDepth) {
         int buffSlot = (sub->base+sub->posted)%NCCL_STEPS;
         if (resources->shared) {

@@ -883,7 +883,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
         if (!ready) {
           // When data is in sysmem, we need to wait until all flags are correct since the GPU only
           // called threadfence()
-          uint64_t flag = sub->base+sub->transmitted+1;
+          uint64_t flag = sub->base+sub->transmitted+args->sliceSteps;
           int nFifoLines = DIVUP(sizesFifo[buffSlot], sizeof(uint64_t)*NCCL_LL128_LINEELEMS);
           volatile uint64_t* lines = (volatile uint64_t*)buff;
           ready = 1;

@@ -892,7 +892,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
           }
         }
         } else if (p == NCCL_PROTO_LL) {
-          uint32_t flag = NCCL_LL_FLAG(sub->base+sub->transmitted+1);
+          uint32_t flag = NCCL_LL_FLAG(sub->base+sub->transmitted+args->sliceSteps);
           int nFifoLines = DIVUP(size, sizeof(union ncclLLFifoLine));
           union ncclLLFifoLine* lines = (union ncclLLFifoLine*)buff;
           for (int i=0; i<nFifoLines; i++) {
@@ -987,7 +987,6 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
   args->idle = 1;
   if (args->state == ncclProxyOpProgress) {
     int p = args->protocol;
-    int maxDepth = std::min(NCCL_STEPS, NCCL_SHARED_STEPS/args->nsubs);
     for (int s=0; s<args->nsubs; s+=args->subs[s].groupSize) {
       struct ncclProxySubArgs* subGroup = args->subs+s;
       int subCount = 0;

@@ -999,8 +998,9 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
       for (int i=0; i<subGroup->groupSize; i++) {
         struct ncclProxySubArgs* sub = subGroup + i;
         if (sub->posted < sub->nsteps) {
-          if (sub->posted >= sub->done + maxDepth) { subCount = 0; break; }
           struct recvResources* resources = (struct recvResources*) (sub->connection->transportResources);
+          int maxDepth = resources->shared ? NCCL_SHARED_STEPS/args->nsubs : NCCL_STEPS;
+          if (sub->posted >= sub->done + maxDepth) { subCount = 0; break; }
           int stepSize = resources->buffSizes[p] / NCCL_STEPS;
           char* localBuff = NCCL_NET_MAP_GET_POINTER(&resources->map, cpu, buffs[p]);
           int buffSlot = (sub->base+sub->posted)%NCCL_STEPS;