ncclGroup's containing operations of mixed datatype, element, or collective would induce crash.
This commit is contained in:
parent
7e51592129
commit
5f2f2f670f
@ -71,9 +71,30 @@ template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto>
|
||||
struct RunWork {
|
||||
__device__ void run(ncclWork *w) {
|
||||
int tid = threadIdx.x;
|
||||
#pragma unroll 1
|
||||
for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++) {
|
||||
if (tid < w->elems[e].nThreads)
|
||||
/* Some invariants that must hold:
|
||||
* 1. All elems[] have same funcIndex.
|
||||
* 2. All elems[] have same nThreads.
|
||||
* 3. The thread-to-group relation (as in prims group numbers) is the same
|
||||
* for all elems[].
|
||||
*
|
||||
* If (1) isn't true then we might be in the wrong function since dispatch
|
||||
* on ncclFuncs[w->elems[0].funcIndex] is how we got here.
|
||||
*
|
||||
* If (2) or (3) aren't true, then threads from different work elements
|
||||
* could race for barrier resources (barrier numbers 0...15) which is fatal.
|
||||
*
|
||||
* Important, to ensure (3), implementations of
|
||||
* `RunWorkElement<Fn,T,RedOp,Algo,Proto>::run()` may only use values which
|
||||
* are the same for all elems[] when deciding how to map threads to groups,
|
||||
* such as the following:
|
||||
* Fn, T, RedOp, Algo, Proto, nThreads
|
||||
*
|
||||
* This last one is difficult to enforce and diagnosing it is a headeache.
|
||||
* Device-side developers, consider yourselves warned.
|
||||
*/
|
||||
if (tid < w->elems[0].nThreads) {
|
||||
#pragma unroll 1
|
||||
for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++)
|
||||
RunWorkElement<Fn, T, RedOp, Algo, Proto>().run(&w->elems[e]);
|
||||
}
|
||||
}
|
||||
|
@ -681,29 +681,29 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) {
|
||||
// Reduce the per-channel size if we cannot fully utilize the channels
|
||||
while (comm->asyncTotalSize < channelSize * comm->nChannels && channelSize > NCCL_MIN_CHANNEL_SIZE) channelSize /= 2;
|
||||
int channelUsed = 0;
|
||||
ncclFunc_t commonColl = ncclNumFuncs;
|
||||
int fastPath = 1;
|
||||
int homogeneous = 1;
|
||||
int allCollNetSupport = comm->collNetSupport;
|
||||
for (int c = 0; c < comm->asyncOpCount; c++) {
|
||||
struct ncclInfo* info = comm->asyncOps+c;
|
||||
info->nChannels = std::min(std::max(1, (int)DIVUP(info->nBytes, channelSize)), comm->nChannels); // assign number of channels
|
||||
channelUsed += info->nChannels;
|
||||
// We can use fast path if all collectives are the same
|
||||
if (commonColl == ncclNumFuncs) commonColl = info->coll;
|
||||
else if (commonColl != info->coll) fastPath = 0;
|
||||
else if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport));
|
||||
homogeneous &= info->coll == comm->asyncOps[0].coll &&
|
||||
info->op == comm->asyncOps[0].op &&
|
||||
info->datatype == comm->asyncOps[0].datatype;
|
||||
if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport));
|
||||
}
|
||||
// Compute algo, proto, nthreads for the entire kernel
|
||||
struct ncclInfo total;
|
||||
total.comm = comm;
|
||||
total.coll = commonColl;
|
||||
total.coll = comm->asyncOps[0].coll;
|
||||
total.nBytes = comm->asyncTotalSize;
|
||||
total.nChannels = std::min(channelUsed, comm->nChannels);
|
||||
int perChannelOps = DIVUP(channelUsed, total.nChannels);
|
||||
if (fastPath) NCCLCHECK(getAlgoInfo(&total, allCollNetSupport, perChannelOps));
|
||||
if (homogeneous) NCCLCHECK(getAlgoInfo(&total, allCollNetSupport, perChannelOps));
|
||||
for (int c = 0; c < comm->asyncOpCount; c++) {
|
||||
struct ncclInfo* info = comm->asyncOps+c;
|
||||
if (fastPath) {
|
||||
if (homogeneous) {
|
||||
info->algorithm = total.algorithm;
|
||||
info->protocol = total.protocol;
|
||||
info->nThreads = total.nThreads;
|
||||
@ -883,7 +883,11 @@ ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem*
|
||||
int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
|
||||
struct ncclWork* w = channel->workFifo+opIndex;
|
||||
int segment = -1;
|
||||
if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) {
|
||||
if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0 &&
|
||||
// All elems in work must have same (funcIndex,nThreads),
|
||||
// see "src/collectives/device/common.h"
|
||||
w->elems[0].funcIndex == work->funcIndex &&
|
||||
w->elems[0].nThreads == work->nThreads) {
|
||||
// Try to pack more segments into a single operation
|
||||
segment = getSegment(COLL_SEGMENT, 0, w);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user