ncclGroups containing operations of mixed datatype, element, or collective would induce a crash.
commit 5f2f2f670f
parent 7e51592129
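
For context, a minimal sketch (not from the commit) of the kind of grouped launch this fixes: two collectives with different datatypes and collective types aggregated under one group. The communicator, stream, and device buffers are assumed to be set up elsewhere, and error checking is elided.

#include <nccl.h>

// Before this fix, aggregating heterogeneous operations like these into a
// single kernel launch could race device-side barriers and crash.
void mixedGroup(ncclComm_t comm, cudaStream_t stream,
                float* fbuf, void* hbuf, size_t n) {
  ncclGroupStart();
  ncclAllReduce(fbuf, fbuf, n, ncclFloat, ncclSum, comm, stream);      // fp32 allreduce
  ncclBroadcast(hbuf, hbuf, n, ncclFloat16, /*root=*/0, comm, stream); // fp16 broadcast
  ncclGroupEnd();
}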
@@ -71,9 +71,30 @@ template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto>
 struct RunWork {
   __device__ void run(ncclWork *w) {
     int tid = threadIdx.x;
-    #pragma unroll 1
-    for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++) {
-      if (tid < w->elems[e].nThreads)
+    /* Some invariants that must hold:
+     * 1. All elems[] have same funcIndex.
+     * 2. All elems[] have same nThreads.
+     * 3. The thread-to-group relation (as in prims group numbers) is the same
+     *    for all elems[].
+     *
+     * If (1) isn't true then we might be in the wrong function since dispatch
+     * on ncclFuncs[w->elems[0].funcIndex] is how we got here.
+     *
+     * If (2) or (3) aren't true, then threads from different work elements
+     * could race for barrier resources (barrier numbers 0...15) which is fatal.
+     *
+     * Important, to ensure (3), implementations of
+     * `RunWorkElement<Fn,T,RedOp,Algo,Proto>::run()` may only use values which
+     * are the same for all elems[] when deciding how to map threads to groups,
+     * such as the following:
+     *   Fn, T, RedOp, Algo, Proto, nThreads
+     *
+     * This last one is difficult to enforce and diagnosing it is a headache.
+     * Device-side developers, consider yourselves warned.
+     */
+    if (tid < w->elems[0].nThreads) {
+      #pragma unroll 1
+      for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++)
         RunWorkElement<Fn, T, RedOp, Algo, Proto>().run(&w->elems[e]);
     }
   }
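
The invariants listed in the new comment can also be checked host-side before launch. Below is a minimal standalone sketch of such a check, assuming a simplified stand-in for ncclWorkElem (hypothetical struct, not the real one):

#include <cstdint>

struct Elem { uint16_t funcIndex; uint16_t nThreads; uint8_t active; }; // stand-in only

// True when every active element agrees with elems[0] on the fields that
// function dispatch (invariant 1) and barrier layout (invariant 2) depend on.
static bool workIsHomogeneous(const Elem* elems, int n) {
  for (int e = 1; e < n && elems[e].active != 0; e++) {
    if (elems[e].funcIndex != elems[0].funcIndex) return false; // invariant (1)
    if (elems[e].nThreads  != elems[0].nThreads)  return false; // invariant (2)
  }
  return true;
}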
@@ -681,29 +681,29 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) {
   // Reduce the per-channel size if we cannot fully utilize the channels
   while (comm->asyncTotalSize < channelSize * comm->nChannels && channelSize > NCCL_MIN_CHANNEL_SIZE) channelSize /= 2;
   int channelUsed = 0;
-  ncclFunc_t commonColl = ncclNumFuncs;
-  int fastPath = 1;
+  int homogeneous = 1;
   int allCollNetSupport = comm->collNetSupport;
   for (int c = 0; c < comm->asyncOpCount; c++) {
     struct ncclInfo* info = comm->asyncOps+c;
     info->nChannels = std::min(std::max(1, (int)DIVUP(info->nBytes, channelSize)), comm->nChannels); // assign number of channels
     channelUsed += info->nChannels;
     // We can use fast path if all collectives are the same
-    if (commonColl == ncclNumFuncs) commonColl = info->coll;
-    else if (commonColl != info->coll) fastPath = 0;
-    else if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport));
+    homogeneous &= info->coll == comm->asyncOps[0].coll &&
+                   info->op == comm->asyncOps[0].op &&
+                   info->datatype == comm->asyncOps[0].datatype;
+    if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport));
   }
   // Compute algo, proto, nthreads for the entire kernel
   struct ncclInfo total;
   total.comm = comm;
-  total.coll = commonColl;
+  total.coll = comm->asyncOps[0].coll;
   total.nBytes = comm->asyncTotalSize;
   total.nChannels = std::min(channelUsed, comm->nChannels);
   int perChannelOps = DIVUP(channelUsed, total.nChannels);
-  if (fastPath) NCCLCHECK(getAlgoInfo(&total, allCollNetSupport, perChannelOps));
+  if (homogeneous) NCCLCHECK(getAlgoInfo(&total, allCollNetSupport, perChannelOps));
   for (int c = 0; c < comm->asyncOpCount; c++) {
     struct ncclInfo* info = comm->asyncOps+c;
-    if (fastPath) {
+    if (homogeneous) {
       info->algorithm = total.algorithm;
       info->protocol = total.protocol;
       info->nThreads = total.nThreads;
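
The homogeneous &= fold replaces the old three-way if/else chain: every op is compared against asyncOps[0] on collective, reduction op, and datatype at once, so a mixed group now falls off the single-kernel fast path instead of crashing. A self-contained illustration of the fold, using hypothetical enums rather than NCCL's types:

#include <cstdio>

enum Coll { AllReduce, Broadcast };   // hypothetical stand-ins for ncclFunc_t etc.
enum RedOp { Sum, Prod };
enum DType { F32, F16 };
struct Op { Coll coll; RedOp op; DType datatype; };

int main() {
  Op ops[3] = {{AllReduce, Sum, F32}, {AllReduce, Sum, F32}, {AllReduce, Sum, F16}};
  int homogeneous = 1;
  for (int c = 0; c < 3; c++) {
    // One mismatch against ops[0] clears the flag for good.
    homogeneous &= ops[c].coll == ops[0].coll &&
                   ops[c].op == ops[0].op &&
                   ops[c].datatype == ops[0].datatype;
  }
  printf("homogeneous = %d\n", homogeneous); // prints 0: the F16 op breaks it
  return 0;
}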
@@ -883,7 +883,11 @@ ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem*
   int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
   struct ncclWork* w = channel->workFifo+opIndex;
   int segment = -1;
-  if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) {
+  if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0 &&
+      // All elems in work must have same (funcIndex,nThreads),
+      // see "src/collectives/device/common.h"
+      w->elems[0].funcIndex == work->funcIndex &&
+      w->elems[0].nThreads == work->nThreads) {
     // Try to pack more segments into a single operation
     segment = getSegment(COLL_SEGMENT, 0, w);
   }
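
On the enqueue side, the widened guard means an element may only be packed into an existing ncclWork when it matches the resident elements on (funcIndex, nThreads), which is exactly what the device-side invariants in common.h require. A hypothetical predicate capturing the same condition:

#include <cstdint>

struct WorkElem { uint16_t funcIndex; uint16_t nThreads; }; // simplified stand-in

// Packing is allowed only if the last slot is free (hasFreeSlot) and the
// incoming element matches the work's resident funcIndex and nThreads.
static bool canPackInto(const WorkElem& resident, const WorkElem& incoming,
                        bool hasFreeSlot) {
  return hasFreeSlot &&
         resident.funcIndex == incoming.funcIndex &&
         resident.nThreads == incoming.nThreads;
}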