From cb111f764a6d46370f24f75101d6b219bb2dda54 Mon Sep 17 00:00:00 2001
From: Sylvain Jeaugey <sjeaugey@nvidia.com>
Date: Tue, 25 Oct 2022 00:55:55 -0700
Subject: [PATCH] 2.15.5-1

Fix crash with CollnetChain on some node topologies
Fix hang when interleaving the capture of different graphs
Fix hang during init in multi-threaded mode
Fix potential data corruption with LL128 protocol on unaligned buffers.
Fix CPU usage during preconnect
Fix double-free in the error path for ncclCommInitAll
Workaround hang on H100 with Ring/LL128 on 2 GPUs.
---
 makefiles/version.mk                 |   2 +-
 src/channel.cc                       |   4 +-
 src/collectives/device/prims_ll128.h |  17 +-
 src/enqueue.cc                       |   1 +
 src/graph/connect.cc                 |   9 +-
 src/graph/tuning.cc                  |  12 +-
 src/group.cc                         |   2 +
 src/include/graph.h                  |   2 +-
 src/include/strongstream.h           |  52 +++--
 src/init.cc                          |  20 +-
 src/misc/strongstream.cc             | 307 +++++++++++++++++++--------
 11 files changed, 281 insertions(+), 147 deletions(-)

diff --git a/makefiles/version.mk b/makefiles/version.mk
index 977d763..be64e9a 100644
--- a/makefiles/version.mk
+++ b/makefiles/version.mk
@@ -1,6 +1,6 @@
 ##### version
 NCCL_MAJOR := 2
 NCCL_MINOR := 15
-NCCL_PATCH := 1
+NCCL_PATCH := 5
 NCCL_SUFFIX :=
 PKG_REVISION := 1
diff --git a/src/channel.cc b/src/channel.cc
index c1254f1..0514076 100644
--- a/src/channel.cc
+++ b/src/channel.cc
@@ -20,11 +20,11 @@ ncclResult_t initChannel(struct ncclComm* comm, int channelId) {
 
   // The extra on nRanks+1 is for collnet root (i.e. network)
   channel->peers = ncclMemoryStackAlloc<struct ncclChannelPeer>(&comm->memPermanent, nRanks+1);
-  NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nRanks+1, comm->deviceStream.stream));
+  NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nRanks+1, comm->deviceStream.cudaStream));
   ncclCommPushCudaFree(comm, channel->devPeers);
 
   channel->ring.userRanks = ncclMemoryStackAlloc<int>(&comm->memPermanent, nRanks);
-  NCCLCHECK(ncclCudaCallocAsync(&channel->devRingUserRanks, nRanks, comm->deviceStream.stream));
+  NCCLCHECK(ncclCudaCallocAsync(&channel->devRingUserRanks, nRanks, comm->deviceStream.cudaStream));
   ncclCommPushCudaFree(comm, channel->devRingUserRanks);
 
   NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream));
diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h
index 3136940..773a921 100644
--- a/src/collectives/device/prims_ll128.h
+++ b/src/collectives/device/prims_ll128.h
@@ -15,11 +15,12 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
   static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
   static constexpr int Input=0, Output=1;
   RedOp redOp;
-  const int tid;
-  const int nthreads;
-  const int wid;
+  const int tid; // thread index in primitives group
+  const int nthreads; // thread count in primitives group
+  const int wid; // lane index in warp
   const int stepSize;
-  const int warp;
+  const int warp; // warp index in primitives group
+  const int warpInBlock; // warp index in thread block
   const bool flagThread;
   const int group;
   Fan fan;
@@ -108,7 +109,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
       // buffer into shmem.
       int misalignment = reinterpret_cast<uintptr_t>(src) % 16;
       uint64_t *src8 = reinterpret_cast<uint64_t*>(reinterpret_cast<uintptr_t>(src) & -uintptr_t(16));
-      uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warp]);
+      uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warpInBlock]);
       #pragma unroll
       for(int g=0; g < WordPerThread/2; g++)
         if((g*WARP_SIZE + wid)*16 < misalignment + eltN*sizeof(T))
@@ -152,7 +153,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
     }
     // Write to dst if 16-byte aligned, shmem otherwise.
     int misalignment = reinterpret_cast<uintptr_t>(dst)%16;
-    uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warp]);
+    uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warpInBlock]);
     #pragma unroll
     for(int g=0; g < WordPerThread/2; g++) {
       int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
@@ -166,7 +167,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
     __syncwarp();
     // Write rest from shmem to dst. No need to coalesce stores to 16-bytes,
     // the hardware keeps up fine.
-    T *shm = (T*)ncclShmem.ll128warp[warp];
+    T *shm = (T*)ncclShmem.ll128warp[warpInBlock];
     int skip = misalignment == 0 ? eltN & -EltPer16B : 0;
     for(int i=skip+wid; i < eltN; i += WARP_SIZE)
       dst[i] = shm[i];
@@ -215,7 +216,6 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
     /************************ Recv rest *********************/
     if (RECV) {
       { // Consume data from first recv
-        uint64_t* ptr = recvPtr(0)+ll128Offset;
         #pragma unroll
         for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
           v[u] = SRC ? MULTI<RedOp, T>()(redOp, vr[u], v[u]) : vr[u];
@@ -360,6 +360,7 @@ public:
     ):
     redOp(redOpArg),
     tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE),
+    warpInBlock(threadIdx.x/WARP_SIZE),
     flagThread((tid%8)==7), group(group&(uint16_t)0xFFFF),
     stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) {
     int connIndex = group >> 16;
diff --git a/src/enqueue.cc b/src/enqueue.cc
index 0db55bf..8bac73f 100644
--- a/src/enqueue.cc
+++ b/src/enqueue.cc
@@ -853,6 +853,7 @@ static ncclResult_t hostStreamPlanTask(struct ncclComm* comm, struct ncclKernelP
 }
 
 static void CUDART_CB hostStreamPlanCallback(void *plan_) {
+  NVTX3_FUNC_RANGE_IN(nccl_domain);
   struct ncclKernelPlan* plan = (struct ncclKernelPlan*)plan_;
   ncclResult_t result = hostStreamPlanTask(plan->comm, plan);
   if (result != ncclSuccess) {
diff --git a/src/graph/connect.cc b/src/graph/connect.cc
index 01ff282..ccf1e04 100644
--- a/src/graph/connect.cc
+++ b/src/graph/connect.cc
@@ -15,7 +15,7 @@
 /******************************************************************/
 ncclResult_t ncclTopoPreset(struct ncclComm* comm,
-    struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph,
+    struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph,
     struct ncclTopoRanks* topoRanks) {
   int rank = comm->rank;
   int localRanks = comm->topo->nodes[GPU].count;
@@ -37,6 +37,7 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
 
     int* ringIntra = ringGraph->intra+c*localRanks;
     int* treeIntra = treeGraph->intra+c*localRanks;
+    int* collNetIntra = collNetGraph->intra+c*localRanks;
 
     for (int i=0; i<localRanks; i++) {
       if (ringIntra[i] == rank) {
@@ -51,8 +52,10 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
         topoRanks->treeToChild1[c] = treeIntra[child1Index];
         channel->tree.up = i == 0 ? -1 : treeIntra[i-1];
         channel->tree.down[0] = i == localRanks-1 ? -1 : treeIntra[i+1];
-        channel->collnetChain.up = i == 0 ? comm->nRanks : treeIntra[i-1];
-        channel->collnetChain.down[0] = i == localRanks-1 ? -1 : treeIntra[i+1];
+      }
+      if (collNetIntra[i] == rank) {
+        channel->collnetChain.up = i == 0 ? comm->nRanks : collNetIntra[i-1];
+        channel->collnetChain.down[0] = i == localRanks-1 ? -1 : collNetIntra[i+1];
       }
     }
     topoRanks->ringPrev[c] = channel->ring.prev;
diff --git a/src/graph/tuning.cc b/src/graph/tuning.cc
index 07a2104..18afc03 100644
--- a/src/graph/tuning.cc
+++ b/src/graph/tuning.cc
@@ -227,8 +227,16 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
       int pEnable = protoEnable[p];
       if (pEnable == 2 && p == NCCL_PROTO_LL128) {
         // Enable LL128 by default only on Volta/Ampere/Hopper+NVLink. Other cases are not tested and may cause silent data corruption.
-        pEnable = (graphs[a]->typeInter <= PATH_PXB) && graphs[a]->typeIntra <= PATH_NVL &&
-          ((minCompCap == 70 && maxCompCap == 70) || (minCompCap == 80 && maxCompCap == 80) || (minCompCap == 90 && maxCompCap == 90)) ? 1 : 0;
+        pEnable = 1;
+        pEnable &= (graphs[a]->typeInter <= PATH_PXB);
+        pEnable &= (graphs[a]->typeIntra <= PATH_NVL);
+        pEnable &= (minCompCap == maxCompCap);
+        switch (minCompCap) {
+        case 70: pEnable &= 1; break;
+        case 80: pEnable &= 1; break;
+        case 90: pEnable &= !(CUDART_VERSION == 11080 && c == ncclFuncAllReduce && a == NCCL_ALGO_RING && comm->nRanks == 2); break;
+        default: pEnable &= 0; break;
+        }
       }
       if (pEnable == 0) comm->bandwidths[c][a][p] = 0;
       // Only disable algo for Allreduce since others only have one
diff --git a/src/group.cc b/src/group.cc
index d246f28..ff416e3 100644
--- a/src/group.cc
+++ b/src/group.cc
@@ -331,6 +331,8 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_) {
         job = job->next;
       } while (job != nullptr);
 
+      // Let preconnect threads progress.
+      if (jobsDone == false) usleep(1);
     } while (jobsDone == false);
 
     if (ret != ncclSuccess) goto fail;
diff --git a/src/include/graph.h b/src/include/graph.h
index 26c1e76..91e85e7 100644
--- a/src/include/graph.h
+++ b/src/include/graph.h
@@ -101,7 +101,7 @@ struct ncclTopoRanks {
 };
 
 ncclResult_t ncclTopoPreset(struct ncclComm* comm,
-    struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph,
+    struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph,
    struct ncclTopoRanks* topoRanks);
 
 ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns,
diff --git a/src/include/strongstream.h b/src/include/strongstream.h
index 74df610..16b6e07 100644
--- a/src/include/strongstream.h
+++ b/src/include/strongstream.h
@@ -18,7 +18,7 @@
 struct ncclCudaGraph {
 #if CUDART_VERSION >= 11030
   cudaGraph_t graph;
-  uint64_t graphId;
+  unsigned long long graphId;
 #endif
 };
 
@@ -57,36 +57,29 @@ ncclResult_t ncclCudaGraphAddDestructor(struct ncclCudaGraph graph, cudaHostFn_t
 * streams unfit for the use of serializing access to a persistent resource.
 * Strong streams have been introduced to address this need.
 *
- * Constraints of using strong streams:
+ * - All updates to a strong stream must be enclosed by an Acquire/Release pair.
 *
- * - Operations that enqueue work to the strong stream need to be enclosed by
- *   ncclStrongStream[Acquire/Release] pairs. Acquire/release act like fences,
- *   the strong stream is not stateful so there is no harm in redundant acquire
- *   or releases.
+ * - The Acquire, Release, and all updates take an ncclCudaGraph parameter
+ *   indicating the currently capturing graph (or none). This parameter must be
+ *   the same for the entire sequence of {Acquire; ...; Release}.
 *
 * - An {Acquire; ...; Release} sequence must not be concurrent with any
 *   other operations against the strong stream including graph launches which
 *   reference this stream.
- *
- * - All strong stream functions take a "graph" parameter which must reference
- *   the currently capturing graph, or null if none.
 */
 struct ncclStrongStream;
 
 ncclResult_t ncclStrongStreamConstruct(struct ncclStrongStream* ss);
 ncclResult_t ncclStrongStreamDestruct(struct ncclStrongStream* ss);
 
-// Has this strong stream ever been captured in a graph.
-bool ncclStrongStreamEverCaptured(struct ncclStrongStream* ss);
-
 // Acquire-fence the strong stream.
 ncclResult_t ncclStrongStreamAcquire(
   struct ncclCudaGraph graph, struct ncclStrongStream* ss
 );
 
 // Acquire-fence the strong stream assuming no graph is capturing. This permits
-// the caller to enqueue directly to the `ss->stream` member using native CUDA
-// calls. Strong stream must be released via:
+// the caller to enqueue directly to the `ss->cudaStream` member using native CUDA
+// calls. Strong stream still must be released via:
 //   ncclStrongStreamRelease(ncclCudaGraphNone(), ss);
 ncclResult_t ncclStrongStreamAcquireUncaptured(struct ncclStrongStream* ss);
@@ -103,6 +96,7 @@ ncclResult_t ncclStrongStreamLaunchKernel(
   struct ncclCudaGraph graph, struct ncclStrongStream* ss,
   void* fn, dim3 grid, dim3 block, void** args, size_t sharedMemBytes
 );
+
 // Cause `a` to wait for the current state `b`. Both `a` and `b` must be acquired.
 ncclResult_t ncclStrongStreamWaitStream(
   struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b
@@ -121,21 +115,23 @@ ncclResult_t ncclStrongStreamSynchronize(struct ncclStrongStream* ss);
 
 ////////////////////////////////////////////////////////////////////////////////
 
+struct ncclStrongStreamGraph; // internal to ncclStrongStream
+
 struct ncclStrongStream {
-  cudaStream_t stream;
-  cudaEvent_t event;
-  #if CUDART_VERSION >= 11030
-  cudaGraphNode_t node; // null if never captured, otherwise never null again
-  uint64_t graphId:63, eventIsLagging:1;
-  #endif
+  // Used when not graph capturing.
+  cudaStream_t cudaStream;
+#if CUDART_VERSION >= 11030
+  // The event used to establish order between graphs and streams. During acquire
+  // this event is waited on, during release it is recorded to.
+  cudaEvent_t serialEvent;
+  // True if this stream has ever appeared in a graph capture.
+  bool everCaptured;
+  // Tracks whether serialEvent needs to be recorded to upon Release().
+  bool serialEventNeedsRecord;
+  struct ncclStrongStreamGraph* graphHead;
+#else
+  cudaEvent_t scratchEvent;
+#endif
 };
 
-inline bool ncclStrongStreamEverCaptured(struct ncclStrongStream* ss) {
-  #if CUDART_VERSION >= 11030
-    return ss->node != nullptr;
-  #else
-    return false;
-  #endif
-}
-
 #endif
diff --git a/src/init.cc b/src/init.cc
index 86fc9df..ab0a064 100644
--- a/src/init.cc
+++ b/src/init.cc
@@ -364,7 +364,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
   int nRanks = comm->nRanks;
   struct ncclDevCommAndChannels *devCommAndChans, tmpCommAndChans;
 
-  NCCLCHECK(ncclCudaCallocAsync(&devCommAndChans, 1, comm->deviceStream.stream));
+  NCCLCHECK(ncclCudaCallocAsync(&devCommAndChans, 1, comm->deviceStream.cudaStream));
   ncclCommPushCudaFree(comm, devCommAndChans);
   comm->devComm = &devCommAndChans->comm;
   tmpCommAndChans.comm.rank = comm->rank;
@@ -410,12 +410,12 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
     tmpCommAndChans.channels[c].workFifoDone = &comm->workFifoDone[c];
 
     if (comm->channels[c].ring.userRanks != nullptr) {
-      NCCLCHECK(ncclCudaMemcpyAsync(tmpCommAndChans.channels[c].ring.userRanks, comm->channels[c].ring.userRanks, nRanks, comm->deviceStream.stream));
+      NCCLCHECK(ncclCudaMemcpyAsync(tmpCommAndChans.channels[c].ring.userRanks, comm->channels[c].ring.userRanks, nRanks, comm->deviceStream.cudaStream));
     }
   }
 
-  NCCLCHECK(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->deviceStream.stream));
-  CUDACHECK(cudaStreamSynchronize(comm->deviceStream.stream));
+  NCCLCHECK(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->deviceStream.cudaStream));
+  CUDACHECK(cudaStreamSynchronize(comm->deviceStream.cudaStream));
   NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream));
   return ncclSuccess;
 }
@@ -649,7 +649,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
   allGather3Data[rank].collNetSupport = comm->collNetSupport;
 
   comm->nChannels = std::min(treeGraph.nChannels, ringGraph.nChannels);
-  NCCLCHECK(ncclTopoPreset(comm, &treeGraph, &ringGraph, &allGather3Data[rank].topoRanks));
+  NCCLCHECK(ncclTopoPreset(comm, &treeGraph, &ringGraph, &collNetGraph, &allGather3Data[rank].topoRanks));
 
   NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGather3Data, sizeof(*allGather3Data)));
 
@@ -1037,6 +1037,8 @@ collnet_cleanup:
     }
   }
 
+  NCCLCHECKGOTO(devCommSetup(comm), ret, affinity_restore);
+
   /* Local intra-node barrier */
   NCCLCHECK(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0]));
 
@@ -1087,7 +1089,6 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
   }
   NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank), res, cleanup);
   NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup);
-  NCCLCHECKGOTO(devCommSetup(*newcomm), res, cleanup);
 
   // update communicator state
   comm->initState = ncclSuccess;
@@ -1214,6 +1215,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
       gpuFlags[devlist[i]] = 1;
     }
     free(gpuFlags);
+    gpuFlags = nullptr;
   }
 
   ncclUniqueId uniqueId;
@@ -1225,11 +1227,9 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
   }
   NCCLCHECKGOTO(ncclGroupEnd(), ret, fail);
 
-exit:
-  return ret;
 fail:
-  if (gpuFlags) free(gpuFlags);
-  goto exit;
+  free(gpuFlags);
+  return ret;
 }
 
 ncclResult_t ncclCommSetAsyncError(ncclComm_t comm, ncclResult_t nextState) {
diff --git a/src/misc/strongstream.cc b/src/misc/strongstream.cc
index 0524223..d07698b 100644
--- a/src/misc/strongstream.cc
+++ b/src/misc/strongstream.cc
@@ -9,32 +9,61 @@
 #include "checks.h"
 #include "param.h"
 
+// Tracks the chain of graph nodes for a given captured graph, identified by
+// its graph id. This state has to live for as long as captured work is being
+// submitted. CUDA doesn't have a mechanism to inform us when the user ends capture
+// so the best we can do is get notified when the graph is destroyed.
+struct ncclStrongStreamGraph {
+  struct ncclStrongStreamGraph* next;
+  // Atomically exchanged to false by either the main thread or the graph destructor
+  // callback. The last to arrive deletes the node.
+  bool alive;
+  unsigned long long graphId;
+  // For each graph we track the "tip" of the chain of graph nodes. A linear
+  // chain would always have just one node at its tip, but since we have to merge
+  // in chains from other streams (via ncclStrongStreamWaitStream) some spots
+  // in the chain can be wider than a single node and thus need a list, so we
+  // maintain a dynamically sized array of tip nodes.
+  int tipCount, tipCapacity;
+  cudaGraphNode_t* tipNodes;
+};
+
+static void ncclStrongStreamGraphDelete(struct ncclStrongStreamGraph* g) {
+  free(g->tipNodes);
+  free(g);
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 ncclResult_t ncclCudaGetCapturingGraph(
     struct ncclCudaGraph* graph, cudaStream_t stream
   ) {
-  #if CUDART_VERSION >= 11030
+  #if CUDART_VERSION >= 10000 // cudaStreamGetCaptureInfo
  int driver;
  NCCLCHECK(ncclCudaDriverVersion(&driver));
-  if (driver < 11030) {
+  if (CUDART_VERSION < 11030 || driver < 11030) {
     cudaStreamCaptureStatus status;
     unsigned long long gid;
-    graph->graph = nullptr;
     CUDACHECK(cudaStreamGetCaptureInfo(stream, &status, &gid));
+    #if CUDART_VERSION >= 11030
+    graph->graph = nullptr;
+    graph->graphId = ULLONG_MAX;
+    #endif
     if (status != cudaStreamCaptureStatusNone) {
-      WARN("The installed CUDA driver is older than the minimum version (R465) required for NCCL's CUDA Graphs support");
+      WARN("NCCL cannot be captured in a graph if either it wasn't built with CUDA runtime >= 11.3 or if the installed CUDA driver < R465.");
       return ncclInvalidUsage;
     }
   } else {
-    cudaStreamCaptureStatus status;
-    unsigned long long gid;
-    CUDACHECK(cudaStreamGetCaptureInfo_v2(stream, &status, &gid, &graph->graph, nullptr, nullptr));
-    if (status != cudaStreamCaptureStatusActive) {
-      graph->graph = nullptr;
-      gid = ULLONG_MAX;
-    }
-    graph->graphId = gid;
+    #if CUDART_VERSION >= 11030
+    cudaStreamCaptureStatus status;
+    unsigned long long gid;
+    CUDACHECK(cudaStreamGetCaptureInfo_v2(stream, &status, &gid, &graph->graph, nullptr, nullptr));
+    if (status != cudaStreamCaptureStatusActive) {
+      graph->graph = nullptr;
+      gid = ULLONG_MAX;
+    }
+    graph->graphId = gid;
+    #endif
   }
   #endif
   return ncclSuccess;
@@ -57,52 +86,114 @@ ncclResult_t ncclCudaGraphAddDestructor(struct ncclCudaGraph graph, cudaHostFn_t
 
 ////////////////////////////////////////////////////////////////////////////////
 
 ncclResult_t ncclStrongStreamConstruct(struct ncclStrongStream* ss) {
-  CUDACHECK(cudaStreamCreateWithFlags(&ss->stream, cudaStreamNonBlocking));
-  CUDACHECK(cudaEventCreateWithFlags(&ss->event, cudaEventDisableTiming));
+  CUDACHECK(cudaStreamCreateWithFlags(&ss->cudaStream, cudaStreamNonBlocking));
   #if CUDART_VERSION >= 11030
-  ss->node = nullptr;
-  ss->graphId = (1ull<<(8*sizeof(long long)-1))-1;
-  ss->eventIsLagging = 0;
+  CUDACHECK(cudaEventCreateWithFlags(&ss->serialEvent, cudaEventDisableTiming));
+  ss->everCaptured = false;
+  ss->serialEventNeedsRecord = false;
+  ss->graphHead = nullptr;
+  #else
+  CUDACHECK(cudaEventCreateWithFlags(&ss->scratchEvent, cudaEventDisableTiming));
   #endif
   return ncclSuccess;
 }
 
+static void graphDestructor(void* arg) {
+  struct ncclStrongStreamGraph* g = (struct ncclStrongStreamGraph*)arg;
+  if (false == __atomic_exchange_n(&g->alive, false, __ATOMIC_ACQ_REL)) {
+    // Last to arrive deletes list node.
+    ncclStrongStreamGraphDelete(g);
+  }
+}
+
 ncclResult_t ncclStrongStreamDestruct(struct ncclStrongStream* ss) {
+  CUDACHECK(cudaStreamDestroy(ss->cudaStream));
   #if CUDART_VERSION >= 11030
-  CUDACHECK(cudaEventDestroy(ss->event));
+  CUDACHECK(cudaEventDestroy(ss->serialEvent));
+  // Delete list of per-graph chains.
+  struct ncclStrongStreamGraph* g = ss->graphHead;
+  while (g != nullptr) {
+    struct ncclStrongStreamGraph* next = g->next;
+    if (false == __atomic_exchange_n(&g->alive, false, __ATOMIC_ACQ_REL)) {
+      // Last to arrive deletes list node.
+      ncclStrongStreamGraphDelete(g);
+    }
+    g = next;
+  }
+  #else
+  CUDACHECK(cudaEventDestroy(ss->scratchEvent));
   #endif
-  CUDACHECK(cudaStreamDestroy(ss->stream));
   return ncclSuccess;
 }
 
 NCCL_PARAM(GraphMixingSupport, "GRAPH_MIXING_SUPPORT", 1)
 
+static void ensureTips(struct ncclStrongStreamGraph* g, int n) {
+  if (g->tipCapacity < n) {
+    g->tipNodes = (cudaGraphNode_t*)realloc(g->tipNodes, n*sizeof(cudaGraphNode_t));
+    g->tipCapacity = n;
+  }
+}
+
 ncclResult_t ncclStrongStreamAcquire(
     struct ncclCudaGraph graph, struct ncclStrongStream* ss
   ) {
   #if CUDART_VERSION >= 11030
   bool mixing = ncclParamGraphMixingSupport();
   if (graph.graph == nullptr) {
-    if (mixing && ncclStrongStreamEverCaptured(ss)) {
-      CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0));
-      ss->eventIsLagging = 0;
+    if (mixing && ss->everCaptured) {
+      CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
+      ss->serialEventNeedsRecord = false;
     }
   } else {
-    if (ss->graphId != graph.graphId) {
-      if (mixing && ss->eventIsLagging) {
-        // Can only be here if previous release was for uncaptured work that
-        // elided updating the event because no capture had yet occurred.
-        CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0));
-        CUDACHECK(cudaEventRecord(ss->event, ss->stream));
-      }
-      ss->graphId = graph.graphId;
-      ss->eventIsLagging = 0;
-      if (mixing) {
-        CUDACHECK(cudaGraphAddEventWaitNode(&ss->node, graph.graph, nullptr, 0, ss->event));
+    ss->everCaptured = true;
+    // Find the current graph in our list of graphs if it exists.
+    struct ncclStrongStreamGraph** pg = &ss->graphHead;
+    struct ncclStrongStreamGraph* g;
+    while (*pg != nullptr) {
+      g = *pg;
+      if (g->graphId == graph.graphId) {
+        // Move to front of list so that operations after acquire don't have to search the list.
+        *pg = g->next;
+        g->next = ss->graphHead;
+        ss->graphHead = g;
+        return ncclSuccess;
+      } else if (false == __atomic_load_n(&g->alive, __ATOMIC_ACQUIRE)) {
+        // Unrelated graph that has been destroyed. Remove and delete.
+        *pg = g->next;
+        ncclStrongStreamGraphDelete(g);
       } else {
-        CUDACHECK(cudaGraphAddEmptyNode(&ss->node, graph.graph, nullptr, 0));
+        pg = &g->next;
       }
     }
+
+    // This is a new graph so add to the list.
+    g = (struct ncclStrongStreamGraph*)malloc(sizeof(struct ncclStrongStreamGraph));
+    g->graphId = graph.graphId;
+    g->tipNodes = nullptr;
+    g->tipCapacity = 0;
+    g->tipCount = 0;
+    g->next = ss->graphHead;
+    ss->graphHead = g;
+    g->alive = true;
+    NCCLCHECK(ncclCudaGraphAddDestructor(graph, graphDestructor, (void*)g));
+
+    if (mixing && ss->serialEventNeedsRecord) {
+      // Can only be here if previous release was for uncaptured work that
+      // elided updating the event because no capture had yet occurred.
+      CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
+      CUDACHECK(cudaEventRecord(ss->serialEvent, ss->cudaStream));
+    }
+    ss->serialEventNeedsRecord = false;
+
+    // First node in the chain must be a wait on the serialEvent.
+    if (mixing) {
+      ensureTips(g, 1);
+      CUDACHECK(cudaGraphAddEventWaitNode(&g->tipNodes[0], graph.graph, nullptr, 0, ss->serialEvent));
+      g->tipCount = 1;
+    } else {
+      g->tipCount = 0;
+    }
   }
   #endif
   return ncclSuccess;
 }
@@ -111,26 +202,38 @@ ncclResult_t ncclStrongStreamAcquire(
 ncclResult_t ncclStrongStreamAcquireUncaptured(struct ncclStrongStream* ss) {
   #if CUDART_VERSION >= 11030
   bool mixing = ncclParamGraphMixingSupport();
-  if (mixing && ncclStrongStreamEverCaptured(ss)) {
-    CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0));
+  if (mixing && ss->everCaptured) {
+    CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
   }
-  ss->eventIsLagging = 1; // Assume the caller is going to add work to stream.
+  ss->serialEventNeedsRecord = true; // Assume the caller is going to add work to stream.
   #endif
   return ncclSuccess;
 }
 
+static ncclResult_t checkGraphId(struct ncclStrongStreamGraph* g, unsigned long long id) {
+  if (g == nullptr || g->graphId != id) {
+    WARN("Expected graph id=%llu was not at head of strong stream's internal list.", id);
+    return ncclInternalError;
+  }
+  return ncclSuccess;
+}
+
 ncclResult_t ncclStrongStreamRelease(struct ncclCudaGraph graph, struct ncclStrongStream* ss) {
   #if CUDART_VERSION >= 11030
   bool mixing = ncclParamGraphMixingSupport();
-  if (mixing && ss->eventIsLagging) {
+  if (mixing && ss->serialEventNeedsRecord) {
     if (graph.graph == nullptr) {
-      if (ncclStrongStreamEverCaptured(ss)) {
-        CUDACHECK(cudaEventRecord(ss->event, ss->stream));
-        ss->eventIsLagging = 0;
+      if (ss->everCaptured) {
+        CUDACHECK(cudaEventRecord(ss->serialEvent, ss->cudaStream));
+        ss->serialEventNeedsRecord = false;
       }
     } else {
-      CUDACHECK(cudaGraphAddEventRecordNode(&ss->node, graph.graph, &ss->node, 1, ss->event));
-      ss->eventIsLagging = 0;
+      struct ncclStrongStreamGraph* g = ss->graphHead;
+      NCCLCHECK(checkGraphId(g, graph.graphId));
+      ensureTips(g, 1);
+      CUDACHECK(cudaGraphAddEventRecordNode(&g->tipNodes[0], graph.graph, g->tipNodes, g->tipCount, ss->serialEvent));
+      g->tipCount = 1;
+      ss->serialEventNeedsRecord = false;
     }
   }
   #endif
@@ -142,16 +245,20 @@ ncclResult_t ncclStrongStreamLaunchHost(
   ) {
   #if CUDART_VERSION >= 11030
   if (graph.graph == nullptr) {
-    CUDACHECK(cudaLaunchHostFunc(ss->stream, fn, arg));
+    CUDACHECK(cudaLaunchHostFunc(ss->cudaStream, fn, arg));
   } else {
     cudaHostNodeParams p;
     p.fn = fn;
     p.userData = arg;
-    CUDACHECK(cudaGraphAddHostNode(&ss->node, graph.graph, &ss->node, 1, &p));
+    struct ncclStrongStreamGraph* g = ss->graphHead;
+    NCCLCHECK(checkGraphId(g, graph.graphId));
+    ensureTips(g, 1);
+    CUDACHECK(cudaGraphAddHostNode(&g->tipNodes[0], graph.graph, g->tipNodes, g->tipCount, &p));
+    g->tipCount = 1;
   }
-  ss->eventIsLagging = 1;
+  ss->serialEventNeedsRecord = true;
   #else
-  CUDACHECK(cudaLaunchHostFunc(ss->stream, fn, arg));
+  CUDACHECK(cudaLaunchHostFunc(ss->cudaStream, fn, arg));
   #endif
   return ncclSuccess;
 }
 
@@ -162,9 +269,8 @@ ncclResult_t ncclStrongStreamLaunchKernel(
   ) {
   #if CUDART_VERSION >= 11030
   if (graph.graph == nullptr) {
-    CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->stream));
+    CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->cudaStream));
   } else {
-    cudaGraphNode_t tip = ss->node;
     cudaKernelNodeParams p;
     p.func = fn;
     p.gridDim = grid;
@@ -172,33 +278,53 @@ ncclResult_t ncclStrongStreamLaunchKernel(
     p.kernelParams = args;
     p.sharedMemBytes = sharedMemBytes;
     p.extra = nullptr;
-    CUDACHECK(cudaGraphAddKernelNode(&ss->node, graph.graph, &tip, 1, &p));
+    struct ncclStrongStreamGraph* g = ss->graphHead;
+    NCCLCHECK(checkGraphId(g, graph.graphId));
+    ensureTips(g, 1);
+    CUDACHECK(cudaGraphAddKernelNode(&g->tipNodes[0], graph.graph, g->tipNodes, g->tipCount, &p));
+    g->tipCount = 1;
   }
-  ss->eventIsLagging = 1;
+  ss->serialEventNeedsRecord = true;
   #else
-  CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->stream));
+  CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->cudaStream));
   #endif
   return ncclSuccess;
 }
 
+// Merge node list `b` into list `a` but don't add duplicates.
+static void mergeTips(struct ncclStrongStreamGraph* a, cudaGraphNode_t const* bNodes, int bn) {
+  int an = a->tipCount;
+  ensureTips(a, an + bn);
+  for (int bi=0; bi < bn; bi++) {
+    for (int ai=0; ai < an; ai++) {
+      if (a->tipNodes[ai] == bNodes[bi]) goto next_b;
+    }
+    a->tipNodes[a->tipCount++] = bNodes[bi];
+  next_b:;
+  }
+}
+
 ncclResult_t ncclStrongStreamWaitStream(
     struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b
   ) {
   #if CUDART_VERSION >= 11030
   if (graph.graph == nullptr) {
-    if (b->eventIsLagging) {
-      b->eventIsLagging = 0;
-      CUDACHECK(cudaEventRecord(b->event, b->stream));
+    if (b->serialEventNeedsRecord) {
+      b->serialEventNeedsRecord = false;
+      CUDACHECK(cudaEventRecord(b->serialEvent, b->cudaStream));
     }
-    CUDACHECK(cudaStreamWaitEvent(a->stream, b->event, 0));
+    CUDACHECK(cudaStreamWaitEvent(a->cudaStream, b->serialEvent, 0));
   } else {
-    cudaGraphNode_t pair[2] = {a->node, b->node};
-    CUDACHECK(cudaGraphAddEmptyNode(&a->node, graph.graph, pair, 2));
+    struct ncclStrongStreamGraph* ag = a->graphHead;
+    NCCLCHECK(checkGraphId(ag, graph.graphId));
+    struct ncclStrongStreamGraph* bg = b->graphHead;
+    NCCLCHECK(checkGraphId(bg, graph.graphId));
+    mergeTips(ag, bg->tipNodes, bg->tipCount);
   }
-  a->eventIsLagging = 1;
+  a->serialEventNeedsRecord = true;
   #else
-  CUDACHECK(cudaEventRecord(b->event, b->stream));
-  CUDACHECK(cudaStreamWaitEvent(a->stream, b->event, 0));
+  CUDACHECK(cudaEventRecord(b->scratchEvent, b->cudaStream));
+  CUDACHECK(cudaStreamWaitEvent(a->cudaStream, b->scratchEvent, 0));
   #endif
   return ncclSuccess;
 }
 
@@ -208,35 +334,29 @@ ncclResult_t ncclStrongStreamWaitStream(
   ) {
   #if CUDART_VERSION >= 11030
   if (graph.graph == nullptr) {
-    CUDACHECK(cudaEventRecord(a->event, b));
-    CUDACHECK(cudaStreamWaitEvent(a->stream, a->event, 0));
-    // We used a->event to record b so it no longer reflects anything about a.
-    a->eventIsLagging = 1;
+    // It is ok to use a->serialEvent to record b since we'll be setting
+    // a->serialEventNeedsRecord so the event won't be considered accurate
+    // until re-recorded.
+    CUDACHECK(cudaEventRecord(a->serialEvent, b));
+    CUDACHECK(cudaStreamWaitEvent(a->cudaStream, a->serialEvent, 0));
   } else {
     cudaStreamCaptureStatus status;
-    unsigned long long gid1;
-    cudaGraphNode_t const* deps;
-    size_t depN = 0;
-    CUDACHECK(cudaStreamGetCaptureInfo_v2(b, &status, &gid1, nullptr, &deps, &depN));
-    if (status != cudaStreamCaptureStatusActive || graph.graphId != gid1) {
+    unsigned long long bGraphId;
+    cudaGraphNode_t const* bNodes;
+    size_t bCount = 0;
+    CUDACHECK(cudaStreamGetCaptureInfo_v2(b, &status, &bGraphId, nullptr, &bNodes, &bCount));
+    if (status != cudaStreamCaptureStatusActive || graph.graphId != bGraphId) {
       WARN("Stream is not being captured by the expected graph.");
       return ncclInvalidUsage;
     }
-    if (depN > 0 && (depN > 1 || deps[0] != a->node)) {
-      cudaGraphNode_t tie;
-      if (depN == 1) {
-        tie = deps[0];
-      } else {
-        CUDACHECK(cudaGraphAddEmptyNode(&tie, graph.graph, deps, depN));
-      }
-      cudaGraphNode_t pair[2] = {a->node, tie};
-      CUDACHECK(cudaGraphAddEmptyNode(&a->node, graph.graph, pair, 2));
-      a->eventIsLagging = 1;
-    }
+    struct ncclStrongStreamGraph* ag = a->graphHead;
+    NCCLCHECK(checkGraphId(ag, graph.graphId));
+    mergeTips(ag, bNodes, bCount);
   }
+  a->serialEventNeedsRecord = true;
   #else
-  CUDACHECK(cudaEventRecord(a->event, b));
-  CUDACHECK(cudaStreamWaitEvent(a->stream, a->event, 0));
+  CUDACHECK(cudaEventRecord(a->scratchEvent, b));
+  CUDACHECK(cudaStreamWaitEvent(a->cudaStream, a->scratchEvent, 0));
   #endif
   return ncclSuccess;
 }
 
@@ -246,25 +366,28 @@ ncclResult_t ncclStrongStreamWaitStream(
   ) {
   #if CUDART_VERSION >= 11030
   if (graph.graph == nullptr) {
-    if (b->eventIsLagging) {
-      b->eventIsLagging = 0;
-      CUDACHECK(cudaEventRecord(b->event, b->stream));
+    if (b->serialEventNeedsRecord) {
+      b->serialEventNeedsRecord = false;
+      CUDACHECK(cudaEventRecord(b->serialEvent, b->cudaStream));
     }
-    CUDACHECK(cudaStreamWaitEvent(a, b->event, 0));
+    CUDACHECK(cudaStreamWaitEvent(a, b->serialEvent, 0));
   } else {
-    CUDACHECK(cudaStreamUpdateCaptureDependencies(a, &b->node, 1, cudaStreamAddCaptureDependencies));
+    struct ncclStrongStreamGraph* bg = b->graphHead;
+    NCCLCHECK(checkGraphId(bg, graph.graphId));
+    CUDACHECK(cudaStreamUpdateCaptureDependencies(a, bg->tipNodes, bg->tipCount, cudaStreamAddCaptureDependencies));
   }
   #else
-  CUDACHECK(cudaEventRecord(b->event, b->stream));
-  CUDACHECK(cudaStreamWaitEvent(a, b->event, 0));
+  CUDACHECK(cudaEventRecord(b->scratchEvent, b->cudaStream));
+  CUDACHECK(cudaStreamWaitEvent(a, b->scratchEvent, 0));
   #endif
   return ncclSuccess;
 }
 
 ncclResult_t ncclStrongStreamSynchronize(struct ncclStrongStream* ss) {
   #if CUDART_VERSION >= 11030
-  CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0));
+  CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
+  ss->serialEventNeedsRecord = false;
   #endif
-  CUDACHECK(cudaStreamSynchronize(ss->stream));
+  CUDACHECK(cudaStreamSynchronize(ss->cudaStream));
   return ncclSuccess;
 }
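For readers of the strongstream.h changes above: the Acquire/Release contract for uncaptured work mirrors what devCommSetup does in this patch. The sketch below is illustrative only — the ncclStrongStream functions, NCCLCHECK/CUDACHECK, and comm->deviceStream are NCCL internals rather than public API, and exampleFillDevBuffer is a hypothetical helper.

// Minimal sketch, assuming NCCL-internal headers are available.
ncclResult_t exampleFillDevBuffer(struct ncclComm* comm, float* devBuf, size_t n) {
  // Acquire assuming no graph is capturing; native CUDA calls may then be
  // enqueued directly on the underlying cudaStream member.
  NCCLCHECK(ncclStrongStreamAcquireUncaptured(&comm->deviceStream));
  CUDACHECK(cudaMemsetAsync(devBuf, 0, n*sizeof(float), comm->deviceStream.cudaStream));
  CUDACHECK(cudaStreamSynchronize(comm->deviceStream.cudaStream));
  // Release with the same graph parameter (none) that was used to acquire.
  NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream));
  return ncclSuccess;
}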
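graphDestructor and ncclStrongStreamDestruct race to free each ncclStrongStreamGraph node, and the atomic exchange on `alive` decides the winner: the first to arrive flips the flag, the last to arrive sees it already false and frees. A standalone sketch of the pattern, using a hypothetical Node type:

#include <cstdlib>

struct Node { bool alive; /* ...payload... */ }; // malloc'd with alive = true

// Called exactly once by each of the two owners, in any order, on any thread.
static void releaseNode(Node* n) {
  // __atomic_exchange_n returns the previous value; exactly one of the two
  // callers observes false here and performs the deletion.
  if (false == __atomic_exchange_n(&n->alive, false, __ATOMIC_ACQ_REL)) {
    free(n); // Last to arrive deletes the node.
  }
}

Two owners, one flag — no reference count or lock is needed for this handoff.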
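The reason tip nodes are tracked per graphId is that one communicator's strong streams may be captured into several graphs, including graphs whose captures are open at the same time — the capture-interleaving hang this release fixes. A schematic of that user-side pattern (captureInterleaved is a hypothetical helper; capture-mode choice, buffer setup, and error checking are application-dependent):

#include <cuda_runtime.h>
#include <nccl.h>

// Two captures are open simultaneously against the same communicator, so the
// strong streams observe two distinct graph ids and keep one node chain each.
void captureInterleaved(ncclComm_t comm, cudaStream_t s1, cudaStream_t s2,
                        float* sendbuf, float* recvbuf, size_t count,
                        cudaGraph_t* g1, cudaGraph_t* g2) {
  cudaStreamBeginCapture(s1, cudaStreamCaptureModeThreadLocal);
  cudaStreamBeginCapture(s2, cudaStreamCaptureModeThreadLocal);
  ncclAllReduce(sendbuf, recvbuf, count, ncclFloat, ncclSum, comm, s1);
  ncclAllReduce(sendbuf, recvbuf, count, ncclFloat, ncclSum, comm, s2);
  cudaStreamEndCapture(s1, g1);
  cudaStreamEndCapture(s2, g2);
}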