ncclGroups containing operations of mixed datatype, reduction op, or collective
would induce a crash.
John Bachan 2021-08-31 14:33:48 -07:00
parent 7e51592129
commit 5f2f2f670f
2 changed files with 37 additions and 12 deletions
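
For context, a minimal sketch of the kind of group that triggered the crash. The buffers, comms[2], and streams[2] are hypothetical and assumed to be set up elsewhere; only the public NCCL group API is used:

    #include <nccl.h>
    // Hypothetical reproducer sketch: two collectives that differ only in
    // datatype. Before this fix, both could be fused into one ncclWork, letting
    // their work elements share (and race on) kernel barrier resources.
    ncclGroupStart();
    ncclAllReduce(sendA, recvA, count, ncclFloat32, ncclSum, comms[0], streams[0]);
    ncclAllReduce(sendB, recvB, count, ncclInt32, ncclSum, comms[1], streams[1]); // mixed datatype
    ncclGroupEnd();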

src/collectives/device/common.h

@@ -71,9 +71,30 @@ template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto>
 struct RunWork {
   __device__ void run(ncclWork *w) {
     int tid = threadIdx.x;
-    #pragma unroll 1
-    for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++) {
-      if (tid < w->elems[e].nThreads)
+    /* Some invariants that must hold:
+     * 1. All elems[] have the same funcIndex.
+     * 2. All elems[] have the same nThreads.
+     * 3. The thread-to-group relation (as in prims group numbers) is the same
+     *    for all elems[].
+     *
+     * If (1) isn't true, then we might be in the wrong function, since
+     * dispatch on ncclFuncs[w->elems[0].funcIndex] is how we got here.
+     *
+     * If (2) or (3) aren't true, then threads from different work elements
+     * could race for barrier resources (barrier numbers 0...15), which is fatal.
+     *
+     * Important: to ensure (3), implementations of
+     * `RunWorkElement<Fn,T,RedOp,Algo,Proto>::run()` may only use values which
+     * are the same for all elems[] when deciding how to map threads to groups,
+     * such as the following:
+     *   Fn, T, RedOp, Algo, Proto, nThreads
+     *
+     * This last one is difficult to enforce, and diagnosing violations is a
+     * headache. Device-side developers, consider yourselves warned.
+     */
+    if (tid < w->elems[0].nThreads) {
+      #pragma unroll 1
+      for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++)
         RunWorkElement<Fn, T, RedOp, Algo, Proto>().run(&w->elems[e]);
     }
   }
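
Invariants (1) and (2) are mechanically checkable. As an illustration, a debug-only sketch (an assumption, not part of this commit) that a device-side developer could run before the element loop; it uses only the ncclWork fields referenced above:

    #include <cassert>
    // Debug sketch: verify that all active work elements agree on funcIndex
    // and nThreads before any of them runs.
    __device__ void checkWorkInvariants(ncclWork* w) {
      for (int e = 1; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++) {
        assert(w->elems[e].funcIndex == w->elems[0].funcIndex); // invariant (1)
        assert(w->elems[e].nThreads == w->elems[0].nThreads);   // invariant (2)
      }
    }

Invariant (3) has no such check, which is why the comment singles it out as a headache to diagnose.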

src/enqueue.cc

@@ -681,29 +681,29 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) {
   // Reduce the per-channel size if we cannot fully utilize the channels
   while (comm->asyncTotalSize < channelSize * comm->nChannels && channelSize > NCCL_MIN_CHANNEL_SIZE) channelSize /= 2;
   int channelUsed = 0;
-  ncclFunc_t commonColl = ncclNumFuncs;
-  int fastPath = 1;
+  int homogeneous = 1;
   int allCollNetSupport = comm->collNetSupport;
   for (int c = 0; c < comm->asyncOpCount; c++) {
     struct ncclInfo* info = comm->asyncOps+c;
     info->nChannels = std::min(std::max(1, (int)DIVUP(info->nBytes, channelSize)), comm->nChannels); // assign number of channels
     channelUsed += info->nChannels;
     // We can use fast path if all collectives are the same
-    if (commonColl == ncclNumFuncs) commonColl = info->coll;
-    else if (commonColl != info->coll) fastPath = 0;
-    else if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport));
+    homogeneous &= info->coll == comm->asyncOps[0].coll &&
+        info->op == comm->asyncOps[0].op &&
+        info->datatype == comm->asyncOps[0].datatype;
+    if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport));
   }
   // Compute algo, proto, nthreads for the entire kernel
   struct ncclInfo total;
   total.comm = comm;
-  total.coll = commonColl;
+  total.coll = comm->asyncOps[0].coll;
   total.nBytes = comm->asyncTotalSize;
   total.nChannels = std::min(channelUsed, comm->nChannels);
   int perChannelOps = DIVUP(channelUsed, total.nChannels);
-  if (fastPath) NCCLCHECK(getAlgoInfo(&total, allCollNetSupport, perChannelOps));
+  if (homogeneous) NCCLCHECK(getAlgoInfo(&total, allCollNetSupport, perChannelOps));
   for (int c = 0; c < comm->asyncOpCount; c++) {
     struct ncclInfo* info = comm->asyncOps+c;
-    if (fastPath) {
+    if (homogeneous) {
       info->algorithm = total.algorithm;
       info->protocol = total.protocol;
       info->nThreads = total.nThreads;
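
To see why the old `fastPath` was too permissive, consider a hypothetical group of two AllReduce ops differing only in datatype (the struct and field values below are illustrative, not real NCCL enums):

    // Worked example: both ops have the same collective, so the old check
    // (coll comparison only) kept fastPath == 1 and fused them. The new check
    // also compares op and datatype, so homogeneous drops to 0.
    struct Op { int coll, op, datatype; };
    Op ops[2] = { {0 /*AllReduce*/, 0 /*Sum*/, 7 /*float*/},
                  {0 /*AllReduce*/, 0 /*Sum*/, 2 /*int32*/} };
    int homogeneous = 1;
    for (int c = 0; c < 2; c++)
      homogeneous &= ops[c].coll == ops[0].coll &&
                     ops[c].op == ops[0].op &&
                     ops[c].datatype == ops[0].datatype;
    // homogeneous == 0 here: the shared algo/proto/nThreads computed in
    // `total` will not be applied across the ops.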
@@ -883,7 +883,11 @@ ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem*
   int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
   struct ncclWork* w = channel->workFifo+opIndex;
   int segment = -1;
-  if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) {
+  if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0 &&
+      // All elems in work must have same (funcIndex,nThreads),
+      // see "src/collectives/device/common.h"
+      w->elems[0].funcIndex == work->funcIndex &&
+      w->elems[0].nThreads == work->nThreads) {
     // Try to pack more segments into a single operation
     segment = getSegment(COLL_SEGMENT, 0, w);
   }
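
The effect of the strengthened guard, as a hedged sketch (a hypothetical helper, not NCCL's actual getSegment): a new element may only join an existing ncclWork when the work isn't full and its (funcIndex, nThreads) pair matches the elements already packed there:

    // Hypothetical helper illustrating the new packing rule.
    static bool canPack(const ncclWork* w, const ncclWorkElem* e) {
      if (w->elems[NCCL_MAX_WORK_ELEMENTS-1].active != 0) return false; // work full
      return w->elems[0].funcIndex == e->funcIndex  // same device function
          && w->elems[0].nThreads == e->nThreads;   // same thread count
    }

This enforces on the host side exactly the invariants (1) and (2) that the device code in src/collectives/device/common.h now relies on.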