Fix crash with CollnetChain on some node topologies
Fix hang when interleaving the capture of different graphs
Fix hang during init in multi-threaded mode
Fix potential data corruption with LL128 protocol on unaligned buffers
Fix excessive CPU usage during preconnect
Fix double-free in the error path of ncclCommInitAll
Work around hang on H100 with Ring/LL128 on 2 GPUs
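For reference on the graph-capture fixes in this commit, the usual shape of capturing an NCCL collective into a CUDA graph is sketched below; the names stream, comm, sendbuff, recvbuff and count are placeholders, and the snippet is illustrative rather than code from this change. The hang being fixed showed up when captures like this one, targeting different graphs, were interleaved.

cudaGraph_t graph;
cudaGraphExec_t graphExec;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
// Collectives issued during capture are recorded into the graph, not executed.
ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm, stream);
cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0);
cudaGraphLaunch(graphExec, stream);  // replays the captured work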
Sylvain Jeaugey 2022-10-25 00:55:55 -07:00
parent da8152e57a
commit cb111f764a
11 changed files with 281 additions and 147 deletions

View File

@@ -1,6 +1,6 @@
##### version ##### version
NCCL_MAJOR := 2 NCCL_MAJOR := 2
NCCL_MINOR := 15 NCCL_MINOR := 15
NCCL_PATCH := 1 NCCL_PATCH := 5
NCCL_SUFFIX := NCCL_SUFFIX :=
PKG_REVISION := 1 PKG_REVISION := 1

View File

@@ -20,11 +20,11 @@ ncclResult_t initChannel(struct ncclComm* comm, int channelId) {
// The extra on nRanks+1 is for collnet root (i.e. network) // The extra on nRanks+1 is for collnet root (i.e. network)
channel->peers = ncclMemoryStackAlloc<struct ncclChannelPeer>(&comm->memPermanent, nRanks+1); channel->peers = ncclMemoryStackAlloc<struct ncclChannelPeer>(&comm->memPermanent, nRanks+1);
NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nRanks+1, comm->deviceStream.stream)); NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nRanks+1, comm->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, channel->devPeers); ncclCommPushCudaFree(comm, channel->devPeers);
channel->ring.userRanks = ncclMemoryStackAlloc<int>(&comm->memPermanent, nRanks); channel->ring.userRanks = ncclMemoryStackAlloc<int>(&comm->memPermanent, nRanks);
NCCLCHECK(ncclCudaCallocAsync(&channel->devRingUserRanks, nRanks, comm->deviceStream.stream)); NCCLCHECK(ncclCudaCallocAsync(&channel->devRingUserRanks, nRanks, comm->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, channel->devRingUserRanks); ncclCommPushCudaFree(comm, channel->devRingUserRanks);
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream)); NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream));

View File

@@ -15,11 +15,12 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
static constexpr int Input=0, Output=1; static constexpr int Input=0, Output=1;
RedOp redOp; RedOp redOp;
const int tid; const int tid; // thread index in primitives group
const int nthreads; const int nthreads; // thread count in primitives group
const int wid; const int wid; // lane index in warp
const int stepSize; const int stepSize;
const int warp; const int warp; // warp index in primitives group
const int warpInBlock; // warp index in thread block
const bool flagThread; const bool flagThread;
const int group; const int group;
Fan fan; Fan fan;
@@ -108,7 +109,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
// buffer into shmem. // buffer into shmem.
int misalignment = reinterpret_cast<uintptr_t>(src) % 16; int misalignment = reinterpret_cast<uintptr_t>(src) % 16;
uint64_t *src8 = reinterpret_cast<uint64_t*>(reinterpret_cast<uintptr_t>(src) & -uintptr_t(16)); uint64_t *src8 = reinterpret_cast<uint64_t*>(reinterpret_cast<uintptr_t>(src) & -uintptr_t(16));
uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warp]); uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warpInBlock]);
#pragma unroll #pragma unroll
for(int g=0; g < WordPerThread/2; g++) for(int g=0; g < WordPerThread/2; g++)
if((g*WARP_SIZE + wid)*16 < misalignment + eltN*sizeof(T)) if((g*WARP_SIZE + wid)*16 < misalignment + eltN*sizeof(T))
@@ -152,7 +153,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
} }
// Write to dst if 16-byte aligned, shmem otherwise. // Write to dst if 16-byte aligned, shmem otherwise.
int misalignment = reinterpret_cast<uintptr_t>(dst)%16; int misalignment = reinterpret_cast<uintptr_t>(dst)%16;
uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warp]); uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warpInBlock]);
#pragma unroll #pragma unroll
for(int g=0; g < WordPerThread/2; g++) { for(int g=0; g < WordPerThread/2; g++) {
int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8); int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
@@ -166,7 +167,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
__syncwarp(); __syncwarp();
// Write rest from shmem to dst. No need to coalesce stores to 16-bytes, // Write rest from shmem to dst. No need to coalesce stores to 16-bytes,
// the hardware keeps up fine. // the hardware keeps up fine.
T *shm = (T*)ncclShmem.ll128warp[warp]; T *shm = (T*)ncclShmem.ll128warp[warpInBlock];
int skip = misalignment == 0 ? eltN & -EltPer16B : 0; int skip = misalignment == 0 ? eltN & -EltPer16B : 0;
for(int i=skip+wid; i < eltN; i += WARP_SIZE) for(int i=skip+wid; i < eltN; i += WARP_SIZE)
dst[i] = shm[i]; dst[i] = shm[i];
@@ -215,7 +216,6 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
/************************ Recv rest *********************/ /************************ Recv rest *********************/
if (RECV) { if (RECV) {
{ // Consume data from first recv { // Consume data from first recv
uint64_t* ptr = recvPtr(0)+ll128Offset;
#pragma unroll #pragma unroll
for (int u=0; u<ELEMS_PER_THREAD; u+=2) { for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
v[u] = SRC ? MULTI<RedOp, T>()(redOp, vr[u], v[u]) : vr[u]; v[u] = SRC ? MULTI<RedOp, T>()(redOp, vr[u], v[u]) : vr[u];
@@ -360,6 +360,7 @@ public:
): ):
redOp(redOpArg), redOp(redOpArg),
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE),
warpInBlock(threadIdx.x/WARP_SIZE),
flagThread((tid%8)==7), group(group&(uint16_t)0xFFFF), flagThread((tid%8)==7), group(group&(uint16_t)0xFFFF),
stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) { stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) {
int connIndex = group >> 16; int connIndex = group >> 16;
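The warp/warpInBlock split above is what the LL128 unaligned-buffer corruption fix hinges on: ncclShmem.ll128warp[] is a per-thread-block staging area indexed by a warp's position in the block, while `warp` counts warps relative to the primitives group. A sketch of the two indices (not code from the diff):

// tid is the thread's index within the primitives group, so:
//   warp        = tid / WARP_SIZE;          // warp index within the group
//   warpInBlock = threadIdx.x / WARP_SIZE;  // warp index within the whole block
// When a group does not start at warp 0 of the block, indexing the shared-memory
// staging buffers with `warp` lets two groups alias the same buffer; indexing
// with warpInBlock keeps every warp on its own slot.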

View File

@@ -853,6 +853,7 @@ static ncclResult_t hostStreamPlanTask(struct ncclComm* comm, struct ncclKernelP
} }
static void CUDART_CB hostStreamPlanCallback(void *plan_) { static void CUDART_CB hostStreamPlanCallback(void *plan_) {
NVTX3_FUNC_RANGE_IN(nccl_domain);
struct ncclKernelPlan* plan = (struct ncclKernelPlan*)plan_; struct ncclKernelPlan* plan = (struct ncclKernelPlan*)plan_;
ncclResult_t result = hostStreamPlanTask(plan->comm, plan); ncclResult_t result = hostStreamPlanTask(plan->comm, plan);
if (result != ncclSuccess) { if (result != ncclSuccess) {

View File

@@ -15,7 +15,7 @@
/******************************************************************/ /******************************************************************/
ncclResult_t ncclTopoPreset(struct ncclComm* comm, ncclResult_t ncclTopoPreset(struct ncclComm* comm,
struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph,
struct ncclTopoRanks* topoRanks) { struct ncclTopoRanks* topoRanks) {
int rank = comm->rank; int rank = comm->rank;
int localRanks = comm->topo->nodes[GPU].count; int localRanks = comm->topo->nodes[GPU].count;
@@ -37,6 +37,7 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
int* ringIntra = ringGraph->intra+c*localRanks; int* ringIntra = ringGraph->intra+c*localRanks;
int* treeIntra = treeGraph->intra+c*localRanks; int* treeIntra = treeGraph->intra+c*localRanks;
int* collNetIntra = collNetGraph->intra+c*localRanks;
for (int i=0; i<localRanks; i++) { for (int i=0; i<localRanks; i++) {
if (ringIntra[i] == rank) { if (ringIntra[i] == rank) {
@@ -55,8 +56,10 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
topoRanks->treeToChild1[c] = treeIntra[child1Index]; topoRanks->treeToChild1[c] = treeIntra[child1Index];
channel->tree.up = i == 0 ? -1 : treeIntra[i-1]; channel->tree.up = i == 0 ? -1 : treeIntra[i-1];
channel->tree.down[0] = i == localRanks-1 ? -1 : treeIntra[i+1]; channel->tree.down[0] = i == localRanks-1 ? -1 : treeIntra[i+1];
channel->collnetChain.up = i == 0 ? comm->nRanks : treeIntra[i-1]; }
channel->collnetChain.down[0] = i == localRanks-1 ? -1 : treeIntra[i+1]; if (collNetIntra[i] == rank) {
channel->collnetChain.up = i == 0 ? comm->nRanks : collNetIntra[i-1];
channel->collnetChain.down[0] = i == localRanks-1 ? -1 : collNetIntra[i+1];
} }
} }
topoRanks->ringPrev[c] = channel->ring.prev; topoRanks->ringPrev[c] = channel->ring.prev;
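The hunk above is the CollnetChain fix from the commit message: collnetChain.up/down are now derived from collNetGraph->intra rather than treeGraph->intra. A hypothetical layout showing why the distinction matters on some node topologies:

// Illustrative 4-GPU node (made-up orderings, not from the commit):
//   treeIntra    = {0, 1, 2, 3}   // local rank order chosen for the tree graph
//   collNetIntra = {0, 2, 1, 3}   // local rank order chosen for the collnet graph
// Rank 1 sits at position 1 in the tree order but position 2 in the collnet order,
// so building the chain from treeIntra links it to the wrong neighbors; each rank
// now locates itself in collNetIntra before reading its i-1 / i+1 neighbors.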

View File

@@ -227,8 +227,16 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
int pEnable = protoEnable[p]; int pEnable = protoEnable[p];
if (pEnable == 2 && p == NCCL_PROTO_LL128) { if (pEnable == 2 && p == NCCL_PROTO_LL128) {
// Enable LL128 by default only on Volta/Ampere/Hopper+NVLink. Other cases are not tested and may cause silent data corruption. // Enable LL128 by default only on Volta/Ampere/Hopper+NVLink. Other cases are not tested and may cause silent data corruption.
pEnable = (graphs[a]->typeInter <= PATH_PXB) && graphs[a]->typeIntra <= PATH_NVL && pEnable = 1;
((minCompCap == 70 && maxCompCap == 70) || (minCompCap == 80 && maxCompCap == 80) || (minCompCap == 90 && maxCompCap == 90)) ? 1 : 0; pEnable &= (graphs[a]->typeInter <= PATH_PXB);
pEnable &= (graphs[a]->typeIntra <= PATH_NVL);
pEnable &= (minCompCap == maxCompCap);
switch (minCompCap) {
case 70: pEnable &= 1; break;
case 80: pEnable &= 1; break;
case 90: pEnable &= !(CUDART_VERSION == 11080 && c == ncclFuncAllReduce && a == NCCL_ALGO_RING && comm->nRanks == 2); break;
default: pEnable &= 0; break;
}
} }
if (pEnable == 0) comm->bandwidths[c][a][p] = 0; if (pEnable == 0) comm->bandwidths[c][a][p] = 0;
// Only disable algo for Allreduce since others only have one // Only disable algo for Allreduce since others only have one
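For readability, the rewritten default-enable logic above (the pEnable == 2 case) collapses to a single predicate; this is a restatement using the same variables, not code from the commit:

bool ll128ByDefault =
    graphs[a]->typeInter <= PATH_PXB &&
    graphs[a]->typeIntra <= PATH_NVL &&
    minCompCap == maxCompCap &&
    (minCompCap == 70 || minCompCap == 80 ||
     (minCompCap == 90 && !(CUDART_VERSION == 11080 && c == ncclFuncAllReduce &&
                            a == NCCL_ALGO_RING && comm->nRanks == 2)));
// LL128 stays off by default except on homogeneous Volta/Ampere/Hopper connected
// over NVLink, and on Hopper it is additionally withheld for the CUDA 11.8 +
// AllReduce + Ring + 2-rank combination (the H100 hang workaround).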

View File

@@ -331,6 +331,8 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_) {
job = job->next; job = job->next;
} while (job != nullptr); } while (job != nullptr);
// Let preconnect threads progress.
if (jobsDone == false) usleep(1);
} while (jobsDone == false); } while (jobsDone == false);
if (ret != ncclSuccess) goto fail; if (ret != ncclSuccess) goto fail;
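The usleep(1) added above is the preconnect CPU-usage fix: the main thread polls the asynchronous preconnect jobs in a loop, and previously spun at full speed while the worker threads did the actual connecting. A minimal, self-contained sketch of the pattern (allJobsDone is a hypothetical stand-in for walking the job list):

#include <unistd.h>  // usleep

static void waitForJobs(bool (*allJobsDone)(void)) {
  bool jobsDone;
  do {
    jobsDone = allJobsDone();
    // Let preconnect threads progress: back off briefly instead of hard-spinning.
    if (jobsDone == false) usleep(1);
  } while (jobsDone == false);
}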

View File

@@ -101,7 +101,7 @@ struct ncclTopoRanks {
}; };
ncclResult_t ncclTopoPreset(struct ncclComm* comm, ncclResult_t ncclTopoPreset(struct ncclComm* comm,
struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph,
struct ncclTopoRanks* topoRanks); struct ncclTopoRanks* topoRanks);
ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns, ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns,

View File

@@ -18,7 +18,7 @@
struct ncclCudaGraph { struct ncclCudaGraph {
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
cudaGraph_t graph; cudaGraph_t graph;
uint64_t graphId; unsigned long long graphId;
#endif #endif
}; };
@@ -57,36 +57,29 @@ ncclResult_t ncclCudaGraphAddDestructor(struct ncclCudaGraph graph, cudaHostFn_t
* streams unfit for the use of serializing access to a persistent resource. * streams unfit for the use of serializing access to a persistent resource.
* Strong streams have been introduced to address this need. * Strong streams have been introduced to address this need.
* *
* Constraints of using strong streams: * - All updates to a strong stream must be enclosed by a Acquire/Release pair.
* *
* - Operations that enqueue work to the strong stream need to be enclosed by * - The Acquire, Release, and all updates take a ncclCudaGraph parameter
* ncclStrongStream[Acquire/Release] pairs. Acquire/release act like fences, * indicating the currently capturing graph (or none). This parameter must be
* the strong stream is not stateful so there is no harm in redundant acquire * the same for the entire sequence of {Acquire; ...; Release}.
* or releases.
* *
* - An {Acquire; ...; Release} sequence must not be concurrent with any * - An {Acquire; ...; Release} sequence must not be concurrent with any
* other operations against the strong stream including graph launches which * other operations against the strong stream including graph launches which
* reference this stream. * reference this stream.
*
* - All strong stream functions take a "graph" parameter which must reference
* the currently capturing graph, or null if none.
*/ */
struct ncclStrongStream; struct ncclStrongStream;
ncclResult_t ncclStrongStreamConstruct(struct ncclStrongStream* ss); ncclResult_t ncclStrongStreamConstruct(struct ncclStrongStream* ss);
ncclResult_t ncclStrongStreamDestruct(struct ncclStrongStream* ss); ncclResult_t ncclStrongStreamDestruct(struct ncclStrongStream* ss);
// Has this strong stream ever been captured in a graph.
bool ncclStrongStreamEverCaptured(struct ncclStrongStream* ss);
// Acquire-fence the strong stream. // Acquire-fence the strong stream.
ncclResult_t ncclStrongStreamAcquire( ncclResult_t ncclStrongStreamAcquire(
struct ncclCudaGraph graph, struct ncclStrongStream* ss struct ncclCudaGraph graph, struct ncclStrongStream* ss
); );
// Acquire-fence the strong stream assuming no graph is capturing. This permits // Acquire-fence the strong stream assuming no graph is capturing. This permits
// the caller to enqueue directly to the `ss->stream` member using native CUDA // the caller to enqueue directly to the `ss->cudaStream` member using native CUDA
// calls. Strong stream must be released via: // calls. Strong stream still must be released via:
// ncclStrongStreamRelease(ncclCudaGraphNone(), ss); // ncclStrongStreamRelease(ncclCudaGraphNone(), ss);
ncclResult_t ncclStrongStreamAcquireUncaptured(struct ncclStrongStream* ss); ncclResult_t ncclStrongStreamAcquireUncaptured(struct ncclStrongStream* ss);
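Taken together, the constraints documented above imply a fixed call shape; a minimal sketch, assuming a communicator-owned strong stream comm->deviceStream and a user stream that may or may not be capturing (not code from this commit):

struct ncclCudaGraph graph;
NCCLCHECK(ncclCudaGetCapturingGraph(&graph, userStream));        // capturing graph of userStream, or none
NCCLCHECK(ncclStrongStreamAcquire(graph, &comm->deviceStream));
// ... enqueue with ncclStrongStreamLaunchKernel / ncclStrongStreamLaunchHost,
//     always passing the same `graph` ...
NCCLCHECK(ncclStrongStreamRelease(graph, &comm->deviceStream));  // same graph as the Acquire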
@@ -103,6 +96,7 @@ ncclResult_t ncclStrongStreamLaunchKernel(
struct ncclCudaGraph graph, struct ncclStrongStream* ss, struct ncclCudaGraph graph, struct ncclStrongStream* ss,
void* fn, dim3 grid, dim3 block, void** args, size_t sharedMemBytes void* fn, dim3 grid, dim3 block, void** args, size_t sharedMemBytes
); );
// Cause `a` to wait for the current state `b`. Both `a` and `b` must be acquired. // Cause `a` to wait for the current state `b`. Both `a` and `b` must be acquired.
ncclResult_t ncclStrongStreamWaitStream( ncclResult_t ncclStrongStreamWaitStream(
struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b
@@ -121,21 +115,23 @@ ncclResult_t ncclStrongStreamSynchronize(struct ncclStrongStream* ss);
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
struct ncclStrongStreamGraph; // internal to ncclStrongStream
struct ncclStrongStream { struct ncclStrongStream {
cudaStream_t stream; // Used when not graph capturing.
cudaEvent_t event; cudaStream_t cudaStream;
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
cudaGraphNode_t node; // null if never captured, otherwise never null again // The event used to establish order between graphs and streams. During acquire
uint64_t graphId:63, eventIsLagging:1; // this event is waited on, during release it is recorded to.
#endif cudaEvent_t serialEvent;
// This stream ever appeared in a graph capture.
bool everCaptured;
// Tracks whether serialEvent needs to be recorded to upon Release().
bool serialEventNeedsRecord;
struct ncclStrongStreamGraph* graphHead;
#else
cudaEvent_t scratchEvent;
#endif
}; };
inline bool ncclStrongStreamEverCaptured(struct ncclStrongStream* ss) {
#if CUDART_VERSION >= 11030
return ss->node != nullptr;
#else
return false;
#endif
}
#endif #endif

View File

@@ -364,7 +364,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
int nRanks = comm->nRanks; int nRanks = comm->nRanks;
struct ncclDevCommAndChannels *devCommAndChans, tmpCommAndChans; struct ncclDevCommAndChannels *devCommAndChans, tmpCommAndChans;
NCCLCHECK(ncclCudaCallocAsync(&devCommAndChans, 1, comm->deviceStream.stream)); NCCLCHECK(ncclCudaCallocAsync(&devCommAndChans, 1, comm->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, devCommAndChans); ncclCommPushCudaFree(comm, devCommAndChans);
comm->devComm = &devCommAndChans->comm; comm->devComm = &devCommAndChans->comm;
tmpCommAndChans.comm.rank = comm->rank; tmpCommAndChans.comm.rank = comm->rank;
@@ -410,12 +410,12 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
tmpCommAndChans.channels[c].workFifoDone = &comm->workFifoDone[c]; tmpCommAndChans.channels[c].workFifoDone = &comm->workFifoDone[c];
if (comm->channels[c].ring.userRanks != nullptr) { if (comm->channels[c].ring.userRanks != nullptr) {
NCCLCHECK(ncclCudaMemcpyAsync(tmpCommAndChans.channels[c].ring.userRanks, comm->channels[c].ring.userRanks, nRanks, comm->deviceStream.stream)); NCCLCHECK(ncclCudaMemcpyAsync(tmpCommAndChans.channels[c].ring.userRanks, comm->channels[c].ring.userRanks, nRanks, comm->deviceStream.cudaStream));
} }
} }
NCCLCHECK(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->deviceStream.stream)); NCCLCHECK(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->deviceStream.cudaStream));
CUDACHECK(cudaStreamSynchronize(comm->deviceStream.stream)); CUDACHECK(cudaStreamSynchronize(comm->deviceStream.cudaStream));
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream)); NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream));
return ncclSuccess; return ncclSuccess;
} }
@@ -649,7 +649,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
allGather3Data[rank].collNetSupport = comm->collNetSupport; allGather3Data[rank].collNetSupport = comm->collNetSupport;
comm->nChannels = std::min(treeGraph.nChannels, ringGraph.nChannels); comm->nChannels = std::min(treeGraph.nChannels, ringGraph.nChannels);
NCCLCHECK(ncclTopoPreset(comm, &treeGraph, &ringGraph, &allGather3Data[rank].topoRanks)); NCCLCHECK(ncclTopoPreset(comm, &treeGraph, &ringGraph, &collNetGraph, &allGather3Data[rank].topoRanks));
NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGather3Data, sizeof(*allGather3Data))); NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGather3Data, sizeof(*allGather3Data)));
@@ -1037,6 +1037,8 @@ collnet_cleanup:
} }
} }
NCCLCHECKGOTO(devCommSetup(comm), ret, affinity_restore);
/* Local intra-node barrier */ /* Local intra-node barrier */
NCCLCHECK(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0])); NCCLCHECK(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0]));
@@ -1087,7 +1089,6 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
} }
NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank), res, cleanup); NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank), res, cleanup);
NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup); NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup);
NCCLCHECKGOTO(devCommSetup(*newcomm), res, cleanup);
// update communicator state // update communicator state
comm->initState = ncclSuccess; comm->initState = ncclSuccess;
@@ -1214,6 +1215,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
gpuFlags[devlist[i]] = 1; gpuFlags[devlist[i]] = 1;
} }
free(gpuFlags); free(gpuFlags);
gpuFlags = nullptr;
} }
ncclUniqueId uniqueId; ncclUniqueId uniqueId;
@@ -1225,11 +1227,9 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
} }
NCCLCHECKGOTO(ncclGroupEnd(), ret, fail); NCCLCHECKGOTO(ncclGroupEnd(), ret, fail);
exit:
return ret;
fail: fail:
if (gpuFlags) free(gpuFlags); free(gpuFlags);
goto exit; return ret;
} }
ncclResult_t ncclCommSetAsyncError(ncclComm_t comm, ncclResult_t nextState) { ncclResult_t ncclCommSetAsyncError(ncclComm_t comm, ncclResult_t nextState) {
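The ncclCommInitAll changes above are the double-free fix: the success path frees gpuFlags and nulls the pointer, then falls through to fail:, whose unconditional free is a no-op on nullptr. A compressed, hypothetical restatement of the resulting cleanup flow (not the actual function body):

static ncclResult_t initAllCleanupSketch(void) {
  ncclResult_t ret = ncclSuccess;
  int* gpuFlags = (int*)calloc(8, sizeof(int));
  if (gpuFlags == nullptr) { ret = ncclSystemError; goto fail; }
  // ... validate the device list using gpuFlags ...
  free(gpuFlags);
  gpuFlags = nullptr;   // later error paths may still jump to fail
  // ... ncclGroupStart / per-device ncclCommInitRank / ncclGroupEnd, goto fail on error ...
fail:
  free(gpuFlags);       // free(nullptr) is a no-op, so every path frees exactly once
  return ret;
}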

View File

@@ -9,24 +9,52 @@
#include "checks.h" #include "checks.h"
#include "param.h" #include "param.h"
// Tracks the chain of graph nodes for a given graph captured identified by
// its graph id. This state has to live for as long as captured work is being
// submitted. CUDA doesn't have mechanism to inform us when the user ends capture
// so the best we can do is get notified when the graph is destroyed.
struct ncclStrongStreamGraph {
struct ncclStrongStreamGraph* next;
// Atomically exchanged to false by both the main thread or the graph destructor
// callback. The last to arrive deletes the node.
bool alive;
unsigned long long graphId;
// For each graph we track the "tip" of the chain of graph nodes. A linear
// chain would always have just one node at its tip, but since we have to merge
// in chains from other streams (via ncclStrongStreamWaitStream) some spots
// in the chain can be wider than a single node and thus need a list, so we
// maintain a dynamically sized array of tip nodes.
int tipCount, tipCapacity;
cudaGraphNode_t* tipNodes;
};
static void ncclStrongStreamGraphDelete(struct ncclStrongStreamGraph* g) {
free(g->tipNodes);
free(g);
}
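// Hedged aside (not part of the commit): the "notified when the graph is
// destroyed" hook above presumably goes through ncclCudaGraphAddDestructor,
// which would attach the callback with CUDA user objects roughly like this:
//   cudaUserObject_t object;
//   CUDACHECK(cudaUserObjectCreate(&object, arg, fn, /*initialRefcount=*/1,
//                                  cudaUserObjectNoDestructorSync));
//   // Move the reference into the graph; fn(arg) runs once the graph is destroyed.
//   CUDACHECK(cudaGraphRetainUserObject(graph.graph, object, 1, cudaGraphUserObjectMove));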
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
ncclResult_t ncclCudaGetCapturingGraph( ncclResult_t ncclCudaGetCapturingGraph(
struct ncclCudaGraph* graph, cudaStream_t stream struct ncclCudaGraph* graph, cudaStream_t stream
) { ) {
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 10000 // cudaStreamGetCaptureInfo
int driver; int driver;
NCCLCHECK(ncclCudaDriverVersion(&driver)); NCCLCHECK(ncclCudaDriverVersion(&driver));
if (driver < 11030) { if (CUDART_VERSION < 11030 || driver < 11030) {
cudaStreamCaptureStatus status; cudaStreamCaptureStatus status;
unsigned long long gid; unsigned long long gid;
graph->graph = nullptr;
CUDACHECK(cudaStreamGetCaptureInfo(stream, &status, &gid)); CUDACHECK(cudaStreamGetCaptureInfo(stream, &status, &gid));
#if CUDART_VERSION >= 11030
graph->graph = nullptr;
graph->graphId = ULLONG_MAX;
#endif
if (status != cudaStreamCaptureStatusNone) { if (status != cudaStreamCaptureStatusNone) {
WARN("The installed CUDA driver is older than the minimum version (R465) required for NCCL's CUDA Graphs support"); WARN("NCCL cannot be captured in a graph if either it wasn't built with CUDA runtime >= 11.3 or if the installed CUDA driver < R465.");
return ncclInvalidUsage; return ncclInvalidUsage;
} }
} else { } else {
#if CUDART_VERSION >= 11030
cudaStreamCaptureStatus status; cudaStreamCaptureStatus status;
unsigned long long gid; unsigned long long gid;
CUDACHECK(cudaStreamGetCaptureInfo_v2(stream, &status, &gid, &graph->graph, nullptr, nullptr)); CUDACHECK(cudaStreamGetCaptureInfo_v2(stream, &status, &gid, &graph->graph, nullptr, nullptr));
@@ -35,6 +63,7 @@ ncclResult_t ncclCudaGetCapturingGraph(
gid = ULLONG_MAX; gid = ULLONG_MAX;
} }
graph->graphId = gid; graph->graphId = gid;
#endif
} }
#endif #endif
return ncclSuccess; return ncclSuccess;
@@ -57,51 +86,113 @@ ncclResult_t ncclCudaGraphAddDestructor(struct ncclCudaGraph graph, cudaHostFn_t
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
ncclResult_t ncclStrongStreamConstruct(struct ncclStrongStream* ss) { ncclResult_t ncclStrongStreamConstruct(struct ncclStrongStream* ss) {
CUDACHECK(cudaStreamCreateWithFlags(&ss->stream, cudaStreamNonBlocking)); CUDACHECK(cudaStreamCreateWithFlags(&ss->cudaStream, cudaStreamNonBlocking));
CUDACHECK(cudaEventCreateWithFlags(&ss->event, cudaEventDisableTiming));
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
ss->node = nullptr; CUDACHECK(cudaEventCreateWithFlags(&ss->serialEvent, cudaEventDisableTiming));
ss->graphId = (1ull<<(8*sizeof(long long)-1))-1; ss->everCaptured = false;
ss->eventIsLagging = 0; ss->serialEventNeedsRecord = false;
ss->graphHead = nullptr;
#else
CUDACHECK(cudaEventCreateWithFlags(&ss->scratchEvent, cudaEventDisableTiming));
#endif #endif
return ncclSuccess; return ncclSuccess;
} }
static void graphDestructor(void* arg) {
struct ncclStrongStreamGraph* g = (struct ncclStrongStreamGraph*)arg;
if (false == __atomic_exchange_n(&g->alive, false, __ATOMIC_ACQ_REL)) {
// Last to arrive deletes list node.
ncclStrongStreamGraphDelete(g);
}
}
ncclResult_t ncclStrongStreamDestruct(struct ncclStrongStream* ss) { ncclResult_t ncclStrongStreamDestruct(struct ncclStrongStream* ss) {
CUDACHECK(cudaStreamDestroy(ss->cudaStream));
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
CUDACHECK(cudaEventDestroy(ss->event)); CUDACHECK(cudaEventDestroy(ss->serialEvent));
// Delete list of per-graph chains.
struct ncclStrongStreamGraph* g = ss->graphHead;
while (g != nullptr) {
struct ncclStrongStreamGraph* next = g->next;
if (false == __atomic_exchange_n(&g->alive, false, __ATOMIC_ACQ_REL)) {
// Last to arrive deletes list node.
ncclStrongStreamGraphDelete(g);
}
g = next;
}
#else
CUDACHECK(cudaEventDestroy(ss->scratchEvent));
#endif #endif
CUDACHECK(cudaStreamDestroy(ss->stream));
return ncclSuccess; return ncclSuccess;
} }
NCCL_PARAM(GraphMixingSupport, "GRAPH_MIXING_SUPPORT", 1) NCCL_PARAM(GraphMixingSupport, "GRAPH_MIXING_SUPPORT", 1)
static void ensureTips(struct ncclStrongStreamGraph* g, int n) {
if (g->tipCapacity < n) {
g->tipNodes = (cudaGraphNode_t*)realloc(g->tipNodes, n*sizeof(cudaGraphNode_t));
g->tipCapacity = n;
}
}
ncclResult_t ncclStrongStreamAcquire( ncclResult_t ncclStrongStreamAcquire(
struct ncclCudaGraph graph, struct ncclStrongStream* ss struct ncclCudaGraph graph, struct ncclStrongStream* ss
) { ) {
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
bool mixing = ncclParamGraphMixingSupport(); bool mixing = ncclParamGraphMixingSupport();
if (graph.graph == nullptr) { if (graph.graph == nullptr) {
if (mixing && ncclStrongStreamEverCaptured(ss)) { if (mixing && ss->everCaptured) {
CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0)); CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
ss->eventIsLagging = 0; ss->serialEventNeedsRecord = false;
} }
} else { } else {
if (ss->graphId != graph.graphId) { ss->everCaptured = true;
if (mixing && ss->eventIsLagging) { // Find the current graph in our list of graphs if it exists.
struct ncclStrongStreamGraph** pg = &ss->graphHead;
struct ncclStrongStreamGraph* g;
while (*pg != nullptr) {
g = *pg;
if (g->graphId == graph.graphId) {
// Move to front of list so that operations after acquire don't have to search the list.
*pg = g->next;
g->next = ss->graphHead;
ss->graphHead = g;
return ncclSuccess;
} else if (false == __atomic_load_n(&g->alive, __ATOMIC_ACQUIRE)) {
// Unrelated graph that has been destroyed. Remove and delete.
*pg = g->next;
ncclStrongStreamGraphDelete(g);
} else {
pg = &g->next;
}
}
// This is a new graph so add to the list.
g = (struct ncclStrongStreamGraph*)malloc(sizeof(struct ncclStrongStreamGraph));
g->graphId = graph.graphId;
g->tipNodes = nullptr;
g->tipCapacity = 0;
g->tipCount = 0;
g->next = ss->graphHead;
ss->graphHead = g;
g->alive = true;
NCCLCHECK(ncclCudaGraphAddDestructor(graph, graphDestructor, (void*)g));
if (mixing && ss->serialEventNeedsRecord) {
// Can only be here if previous release was for uncaptured work that // Can only be here if previous release was for uncaptured work that
// elided updating the event because no capture had yet occurred. // elided updating the event because no capture had yet occurred.
CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0)); CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
CUDACHECK(cudaEventRecord(ss->event, ss->stream)); CUDACHECK(cudaEventRecord(ss->serialEvent, ss->cudaStream));
} }
ss->graphId = graph.graphId; ss->serialEventNeedsRecord = false;
ss->eventIsLagging = 0;
// First node in the chain must be a wait on the serialEvent.
if (mixing) { if (mixing) {
CUDACHECK(cudaGraphAddEventWaitNode(&ss->node, graph.graph, nullptr, 0, ss->event)); ensureTips(g, 1);
CUDACHECK(cudaGraphAddEventWaitNode(&g->tipNodes[0], graph.graph, nullptr, 0, ss->serialEvent));
g->tipCount = 1;
} else { } else {
CUDACHECK(cudaGraphAddEmptyNode(&ss->node, graph.graph, nullptr, 0)); g->tipCount = 0;
}
} }
} }
#endif #endif
@@ -111,26 +202,38 @@ ncclResult_t ncclStrongStreamAcquire(
ncclResult_t ncclStrongStreamAcquireUncaptured(struct ncclStrongStream* ss) { ncclResult_t ncclStrongStreamAcquireUncaptured(struct ncclStrongStream* ss) {
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
bool mixing = ncclParamGraphMixingSupport(); bool mixing = ncclParamGraphMixingSupport();
if (mixing && ncclStrongStreamEverCaptured(ss)) { if (mixing && ss->everCaptured) {
CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0)); CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
} }
ss->eventIsLagging = 1; // Assume the caller is going to add work to stream. ss->serialEventNeedsRecord = true; // Assume the caller is going to add work to stream.
#endif #endif
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t checkGraphId(struct ncclStrongStreamGraph* g, unsigned long long id) {
if (g == nullptr || g->graphId != id) {
WARN("Expected graph id=%llu was not at head of strong stream's internal list.", id);
return ncclInternalError;
}
return ncclSuccess;
}
ncclResult_t ncclStrongStreamRelease(struct ncclCudaGraph graph, struct ncclStrongStream* ss) { ncclResult_t ncclStrongStreamRelease(struct ncclCudaGraph graph, struct ncclStrongStream* ss) {
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
bool mixing = ncclParamGraphMixingSupport(); bool mixing = ncclParamGraphMixingSupport();
if (mixing && ss->eventIsLagging) { if (mixing && ss->serialEventNeedsRecord) {
if (graph.graph == nullptr) { if (graph.graph == nullptr) {
if (ncclStrongStreamEverCaptured(ss)) { if (ss->everCaptured) {
CUDACHECK(cudaEventRecord(ss->event, ss->stream)); CUDACHECK(cudaEventRecord(ss->serialEvent, ss->cudaStream));
ss->eventIsLagging = 0; ss->serialEventNeedsRecord = false;
} }
} else { } else {
CUDACHECK(cudaGraphAddEventRecordNode(&ss->node, graph.graph, &ss->node, 1, ss->event)); struct ncclStrongStreamGraph* g = ss->graphHead;
ss->eventIsLagging = 0; NCCLCHECK(checkGraphId(g, graph.graphId));
ensureTips(g, 1);
CUDACHECK(cudaGraphAddEventRecordNode(&g->tipNodes[0], graph.graph, g->tipNodes, g->tipCount, ss->serialEvent));
g->tipCount = 1;
ss->serialEventNeedsRecord = false;
} }
} }
#endif #endif
@@ -142,16 +245,20 @@ ncclResult_t ncclStrongStreamLaunchHost(
) { ) {
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
if (graph.graph == nullptr) { if (graph.graph == nullptr) {
CUDACHECK(cudaLaunchHostFunc(ss->stream, fn, arg)); CUDACHECK(cudaLaunchHostFunc(ss->cudaStream, fn, arg));
} else { } else {
cudaHostNodeParams p; cudaHostNodeParams p;
p.fn = fn; p.fn = fn;
p.userData = arg; p.userData = arg;
CUDACHECK(cudaGraphAddHostNode(&ss->node, graph.graph, &ss->node, 1, &p)); struct ncclStrongStreamGraph* g = ss->graphHead;
NCCLCHECK(checkGraphId(g, graph.graphId));
ensureTips(g, 1);
CUDACHECK(cudaGraphAddHostNode(&g->tipNodes[0], graph.graph, g->tipNodes, g->tipCount, &p));
g->tipCount = 1;
} }
ss->eventIsLagging = 1; ss->serialEventNeedsRecord = true;
#else #else
CUDACHECK(cudaLaunchHostFunc(ss->stream, fn, arg)); CUDACHECK(cudaLaunchHostFunc(ss->cudaStream, fn, arg));
#endif #endif
return ncclSuccess; return ncclSuccess;
} }
@@ -162,9 +269,8 @@ ncclResult_t ncclStrongStreamLaunchKernel(
) { ) {
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
if (graph.graph == nullptr) { if (graph.graph == nullptr) {
CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->stream)); CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->cudaStream));
} else { } else {
cudaGraphNode_t tip = ss->node;
cudaKernelNodeParams p; cudaKernelNodeParams p;
p.func = fn; p.func = fn;
p.gridDim = grid; p.gridDim = grid;
@ -172,33 +278,53 @@ ncclResult_t ncclStrongStreamLaunchKernel(
p.kernelParams = args; p.kernelParams = args;
p.sharedMemBytes = sharedMemBytes; p.sharedMemBytes = sharedMemBytes;
p.extra = nullptr; p.extra = nullptr;
CUDACHECK(cudaGraphAddKernelNode(&ss->node, graph.graph, &tip, 1, &p)); struct ncclStrongStreamGraph* g = ss->graphHead;
NCCLCHECK(checkGraphId(g, graph.graphId));
ensureTips(g, 1);
CUDACHECK(cudaGraphAddKernelNode(&g->tipNodes[0], graph.graph, g->tipNodes, g->tipCount, &p));
g->tipCount = 1;
} }
ss->eventIsLagging = 1; ss->serialEventNeedsRecord = true;
#else #else
CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->stream)); CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->cudaStream));
#endif #endif
return ncclSuccess; return ncclSuccess;
} }
// Merge node list `b` into list `a` but don't add duplicates.
static void mergeTips(struct ncclStrongStreamGraph* a, cudaGraphNode_t const* bNodes, int bn) {
int an = a->tipCount;
ensureTips(a, an + bn);
for (int bi=0; bi < bn; bi++) {
for (int ai=0; ai < an; ai++) {
if (a->tipNodes[ai] == bNodes[bi]) goto next_b;
}
a->tipNodes[a->tipCount++] = bNodes[bi];
next_b:;
}
}
ncclResult_t ncclStrongStreamWaitStream( ncclResult_t ncclStrongStreamWaitStream(
struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b
) { ) {
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
if (graph.graph == nullptr) { if (graph.graph == nullptr) {
if (b->eventIsLagging) { if (b->serialEventNeedsRecord) {
b->eventIsLagging = 0; b->serialEventNeedsRecord = false;
CUDACHECK(cudaEventRecord(b->event, b->stream)); CUDACHECK(cudaEventRecord(b->serialEvent, b->cudaStream));
} }
CUDACHECK(cudaStreamWaitEvent(a->stream, b->event, 0)); CUDACHECK(cudaStreamWaitEvent(a->cudaStream, b->serialEvent, 0));
} else { } else {
cudaGraphNode_t pair[2] = {a->node, b->node}; struct ncclStrongStreamGraph* ag = a->graphHead;
CUDACHECK(cudaGraphAddEmptyNode(&a->node, graph.graph, pair, 2)); NCCLCHECK(checkGraphId(ag, graph.graphId));
struct ncclStrongStreamGraph* bg = b->graphHead;
NCCLCHECK(checkGraphId(bg, graph.graphId));
mergeTips(ag, bg->tipNodes, bg->tipCount);
} }
a->eventIsLagging = 1; a->serialEventNeedsRecord = true;
#else #else
CUDACHECK(cudaEventRecord(b->event, b->stream)); CUDACHECK(cudaEventRecord(b->scratchEvent, b->cudaStream));
CUDACHECK(cudaStreamWaitEvent(a->stream, b->event, 0)); CUDACHECK(cudaStreamWaitEvent(a->cudaStream, b->scratchEvent, 0));
#endif #endif
return ncclSuccess; return ncclSuccess;
} }
@@ -208,35 +334,29 @@ ncclResult_t ncclStrongStreamWaitStream(
) { ) {
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
if (graph.graph == nullptr) { if (graph.graph == nullptr) {
CUDACHECK(cudaEventRecord(a->event, b)); // It is ok to use a->serialEvent to record b since we'll be setting
CUDACHECK(cudaStreamWaitEvent(a->stream, a->event, 0)); // a->serialEventNeedsRecord so the event won't be considered accurate
// We used a->event to record b so it no longer reflects anything about a. // until re-recorded.
a->eventIsLagging = 1; CUDACHECK(cudaEventRecord(a->serialEvent, b));
CUDACHECK(cudaStreamWaitEvent(a->cudaStream, a->serialEvent, 0));
} else { } else {
cudaStreamCaptureStatus status; cudaStreamCaptureStatus status;
unsigned long long gid1; unsigned long long bGraphId;
cudaGraphNode_t const* deps; cudaGraphNode_t const* bNodes;
size_t depN = 0; size_t bCount = 0;
CUDACHECK(cudaStreamGetCaptureInfo_v2(b, &status, &gid1, nullptr, &deps, &depN)); CUDACHECK(cudaStreamGetCaptureInfo_v2(b, &status, &bGraphId, nullptr, &bNodes, &bCount));
if (status != cudaStreamCaptureStatusActive || graph.graphId != gid1) { if (status != cudaStreamCaptureStatusActive || graph.graphId != bGraphId) {
WARN("Stream is not being captured by the expected graph."); WARN("Stream is not being captured by the expected graph.");
return ncclInvalidUsage; return ncclInvalidUsage;
} }
if (depN > 0 && (depN > 1 || deps[0] != a->node)) { struct ncclStrongStreamGraph* ag = a->graphHead;
cudaGraphNode_t tie; NCCLCHECK(checkGraphId(ag, graph.graphId));
if (depN == 1) { mergeTips(ag, bNodes, bCount);
tie = deps[0];
} else {
CUDACHECK(cudaGraphAddEmptyNode(&tie, graph.graph, deps, depN));
}
cudaGraphNode_t pair[2] = {a->node, tie};
CUDACHECK(cudaGraphAddEmptyNode(&a->node, graph.graph, pair, 2));
a->eventIsLagging = 1;
}
} }
a->serialEventNeedsRecord = true;
#else #else
CUDACHECK(cudaEventRecord(a->event, b)); CUDACHECK(cudaEventRecord(a->scratchEvent, b));
CUDACHECK(cudaStreamWaitEvent(a->stream, a->event, 0)); CUDACHECK(cudaStreamWaitEvent(a->cudaStream, a->scratchEvent, 0));
#endif #endif
return ncclSuccess; return ncclSuccess;
} }
@ -246,25 +366,28 @@ ncclResult_t ncclStrongStreamWaitStream(
) { ) {
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
if (graph.graph == nullptr) { if (graph.graph == nullptr) {
if (b->eventIsLagging) { if (b->serialEventNeedsRecord) {
b->eventIsLagging = 0; b->serialEventNeedsRecord = false;
CUDACHECK(cudaEventRecord(b->event, b->stream)); CUDACHECK(cudaEventRecord(b->serialEvent, b->cudaStream));
} }
CUDACHECK(cudaStreamWaitEvent(a, b->event, 0)); CUDACHECK(cudaStreamWaitEvent(a, b->serialEvent, 0));
} else { } else {
CUDACHECK(cudaStreamUpdateCaptureDependencies(a, &b->node, 1, cudaStreamAddCaptureDependencies)); struct ncclStrongStreamGraph* bg = b->graphHead;
NCCLCHECK(checkGraphId(bg, graph.graphId));
CUDACHECK(cudaStreamUpdateCaptureDependencies(a, bg->tipNodes, bg->tipCount, cudaStreamAddCaptureDependencies));
} }
#else #else
CUDACHECK(cudaEventRecord(b->event, b->stream)); CUDACHECK(cudaEventRecord(b->scratchEvent, b->cudaStream));
CUDACHECK(cudaStreamWaitEvent(a, b->event, 0)); CUDACHECK(cudaStreamWaitEvent(a, b->scratchEvent, 0));
#endif #endif
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t ncclStrongStreamSynchronize(struct ncclStrongStream* ss) { ncclResult_t ncclStrongStreamSynchronize(struct ncclStrongStream* ss) {
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0)); CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
ss->serialEventNeedsRecord = false;
#endif #endif
CUDACHECK(cudaStreamSynchronize(ss->stream)); CUDACHECK(cudaStreamSynchronize(ss->cudaStream));
return ncclSuccess; return ncclSuccess;
} }