Fix crash with CollnetChain on some node topologies
Fix hang when interleaving the capture of different graphs
Fix hang during init in multi-threaded mode
Fix potential data corruption with LL128 protocol on unaligned buffers
Fix CPU usage during preconnect
Fix double-free in the error path for ncclCommInitAll
Work around hang on H100 with Ring/LL128 on 2 GPUs
This commit is contained in:
Sylvain Jeaugey 2022-10-25 00:55:55 -07:00
parent da8152e57a
commit cb111f764a
11 changed files with 281 additions and 147 deletions

View File

@ -1,6 +1,6 @@
##### version
NCCL_MAJOR := 2
NCCL_MINOR := 15
NCCL_PATCH := 1
NCCL_PATCH := 5
NCCL_SUFFIX :=
PKG_REVISION := 1

View File

@ -20,11 +20,11 @@ ncclResult_t initChannel(struct ncclComm* comm, int channelId) {
// The extra on nRanks+1 is for collnet root (i.e. network)
channel->peers = ncclMemoryStackAlloc<struct ncclChannelPeer>(&comm->memPermanent, nRanks+1);
NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nRanks+1, comm->deviceStream.stream));
NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nRanks+1, comm->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, channel->devPeers);
channel->ring.userRanks = ncclMemoryStackAlloc<int>(&comm->memPermanent, nRanks);
NCCLCHECK(ncclCudaCallocAsync(&channel->devRingUserRanks, nRanks, comm->deviceStream.stream));
NCCLCHECK(ncclCudaCallocAsync(&channel->devRingUserRanks, nRanks, comm->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, channel->devRingUserRanks);
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream));

View File

@ -15,11 +15,12 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
static constexpr int Input=0, Output=1;
RedOp redOp;
const int tid;
const int nthreads;
const int wid;
const int tid; // thread index in primitives group
const int nthreads; // thread count in primitives group
const int wid; // lane index in warp
const int stepSize;
const int warp;
const int warp; // warp index in primitives group
const int warpInBlock; // warp index in thread block
const bool flagThread;
const int group;
Fan fan;
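For context on the LL128 data-corruption fix above: `warp` is derived from the group-relative thread index `tid`, while the new `warpInBlock` is derived from `threadIdx.x`, and the per-block `ncclShmem.ll128warp[]` staging buffers are laid out per block warp. A minimal sketch (hypothetical thread numbers, assuming the group starts on a warp boundary; not part of the diff) of why the two indices differ:

// Suppose a primitives group occupies threads [64, 160) of the block.
int tid = threadIdx.x - 64;                   // thread index within the group
int wid = tid % WARP_SIZE;                    // lane index (same in both views)
int warp = tid / WARP_SIZE;                   // 0 for threadIdx.x in [64, 96)
int warpInBlock = threadIdx.x / WARP_SIZE;    // 2 for the same threads
// ncclShmem.ll128warp[] holds one staging slice per warp of the block, so it
// must be indexed with warpInBlock; indexing it with the group-relative warp
// would alias a slice owned by another group and could corrupt data.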
@ -108,7 +109,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
// buffer into shmem.
int misalignment = reinterpret_cast<uintptr_t>(src) % 16;
uint64_t *src8 = reinterpret_cast<uint64_t*>(reinterpret_cast<uintptr_t>(src) & -uintptr_t(16));
uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warp]);
uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warpInBlock]);
#pragma unroll
for(int g=0; g < WordPerThread/2; g++)
if((g*WARP_SIZE + wid)*16 < misalignment + eltN*sizeof(T))
@ -152,7 +153,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
}
// Write to dst if 16-byte aligned, shmem otherwise.
int misalignment = reinterpret_cast<uintptr_t>(dst)%16;
uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warp]);
uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warpInBlock]);
#pragma unroll
for(int g=0; g < WordPerThread/2; g++) {
int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
@ -166,7 +167,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
__syncwarp();
// Write rest from shmem to dst. No need to coalesce stores to 16-bytes,
// the hardware keeps up fine.
T *shm = (T*)ncclShmem.ll128warp[warp];
T *shm = (T*)ncclShmem.ll128warp[warpInBlock];
int skip = misalignment == 0 ? eltN & -EltPer16B : 0;
for(int i=skip+wid; i < eltN; i += WARP_SIZE)
dst[i] = shm[i];
@ -215,7 +216,6 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
/************************ Recv rest *********************/
if (RECV) {
{ // Consume data from first recv
uint64_t* ptr = recvPtr(0)+ll128Offset;
#pragma unroll
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
v[u] = SRC ? MULTI<RedOp, T>()(redOp, vr[u], v[u]) : vr[u];
@ -360,6 +360,7 @@ public:
):
redOp(redOpArg),
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE),
warpInBlock(threadIdx.x/WARP_SIZE),
flagThread((tid%8)==7), group(group&(uint16_t)0xFFFF),
stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) {
int connIndex = group >> 16;

View File

@ -853,6 +853,7 @@ static ncclResult_t hostStreamPlanTask(struct ncclComm* comm, struct ncclKernelP
}
static void CUDART_CB hostStreamPlanCallback(void *plan_) {
NVTX3_FUNC_RANGE_IN(nccl_domain);
struct ncclKernelPlan* plan = (struct ncclKernelPlan*)plan_;
ncclResult_t result = hostStreamPlanTask(plan->comm, plan);
if (result != ncclSuccess) {

View File

@ -15,7 +15,7 @@
/******************************************************************/
ncclResult_t ncclTopoPreset(struct ncclComm* comm,
struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph,
struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph,
struct ncclTopoRanks* topoRanks) {
int rank = comm->rank;
int localRanks = comm->topo->nodes[GPU].count;
@ -37,6 +37,7 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
int* ringIntra = ringGraph->intra+c*localRanks;
int* treeIntra = treeGraph->intra+c*localRanks;
int* collNetIntra = collNetGraph->intra+c*localRanks;
for (int i=0; i<localRanks; i++) {
if (ringIntra[i] == rank) {
@ -55,8 +56,10 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
topoRanks->treeToChild1[c] = treeIntra[child1Index];
channel->tree.up = i == 0 ? -1 : treeIntra[i-1];
channel->tree.down[0] = i == localRanks-1 ? -1 : treeIntra[i+1];
channel->collnetChain.up = i == 0 ? comm->nRanks : treeIntra[i-1];
channel->collnetChain.down[0] = i == localRanks-1 ? -1 : treeIntra[i+1];
}
if (collNetIntra[i] == rank) {
channel->collnetChain.up = i == 0 ? comm->nRanks : collNetIntra[i-1];
channel->collnetChain.down[0] = i == localRanks-1 ? -1 : collNetIntra[i+1];
}
}
topoRanks->ringPrev[c] = channel->ring.prev;
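For the CollnetChain fix above, the chain neighbors are now taken from the collnet graph's intra-node rank ordering rather than the tree graph's. A hypothetical illustration (made-up orderings, not from the diff) of why the two can disagree:

// treeIntra    = {0, 1, 2, 3}   // tree graph's intra-node rank order
// collNetIntra = {0, 2, 1, 3}   // collnet graph's intra-node rank order
// For rank 1, the collnet chain neighbors must follow collNetIntra:
//   up = collNetIntra[1] = 2,  down[0] = collNetIntra[3] = 3
// The old code used treeIntra (up = 0, down[0] = 2), which is wrong whenever
// the two graphs order the local GPUs differently.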

View File

@ -227,8 +227,16 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
int pEnable = protoEnable[p];
if (pEnable == 2 && p == NCCL_PROTO_LL128) {
// Enable LL128 by default only on Volta/Ampere/Hopper+NVLink. Other cases are not tested and may cause silent data corruption.
pEnable = (graphs[a]->typeInter <= PATH_PXB) && graphs[a]->typeIntra <= PATH_NVL &&
((minCompCap == 70 && maxCompCap == 70) || (minCompCap == 80 && maxCompCap == 80) || (minCompCap == 90 && maxCompCap == 90)) ? 1 : 0;
pEnable = 1;
pEnable &= (graphs[a]->typeInter <= PATH_PXB);
pEnable &= (graphs[a]->typeIntra <= PATH_NVL);
pEnable &= (minCompCap == maxCompCap);
switch (minCompCap) {
case 70: pEnable &= 1; break;
case 80: pEnable &= 1; break;
case 90: pEnable &= !(CUDART_VERSION == 11080 && c == ncclFuncAllReduce && a == NCCL_ALGO_RING && comm->nRanks == 2); break;
default: pEnable &= 0; break;
}
}
if (pEnable == 0) comm->bandwidths[c][a][p] = 0;
// Only disable algo for Allreduce since others only have one
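Reading the new logic as a single boolean expression (an equivalent restatement for clarity, not code from the diff):

pEnable = (graphs[a]->typeInter <= PATH_PXB)
       && (graphs[a]->typeIntra <= PATH_NVL)
       && (minCompCap == maxCompCap)
       && (minCompCap == 70 || minCompCap == 80 || minCompCap == 90)
       && !(minCompCap == 90 && CUDART_VERSION == 11080
            && c == ncclFuncAllReduce && a == NCCL_ALGO_RING && comm->nRanks == 2);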

View File

@ -331,6 +331,8 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_) {
job = job->next;
} while (job != nullptr);
// Let preconnect threads progress.
if (jobsDone == false) usleep(1);
} while (jobsDone == false);
if (ret != ncclSuccess) goto fail;
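The CPU-usage fix is the added usleep(1): the group launch thread polls the asynchronous preconnect jobs, and without a yield it spins a core at 100% until they finish. The shape of the loop after the change (a simplified sketch; the per-job bookkeeping that updates jobsDone and ret is elided):

bool jobsDone;
do {
  jobsDone = true;
  // ... walk the async job list; a job still in progress clears jobsDone,
  //     a failed job updates ret ...
  if (jobsDone == false) usleep(1);   // yield ~1us instead of busy-spinning
} while (jobsDone == false);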

View File

@ -101,7 +101,7 @@ struct ncclTopoRanks {
};
ncclResult_t ncclTopoPreset(struct ncclComm* comm,
struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph,
struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph,
struct ncclTopoRanks* topoRanks);
ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns,

View File

@ -18,7 +18,7 @@
struct ncclCudaGraph {
#if CUDART_VERSION >= 11030
cudaGraph_t graph;
uint64_t graphId;
unsigned long long graphId;
#endif
};
@ -57,36 +57,29 @@ ncclResult_t ncclCudaGraphAddDestructor(struct ncclCudaGraph graph, cudaHostFn_t
* streams unfit for the use of serializing access to a persistent resource.
* Strong streams have been introduced to address this need.
*
* Constraints of using strong streams:
* - All updates to a strong stream must be enclosed by an Acquire/Release pair.
*
* - Operations that enqueue work to the strong stream need to be enclosed by
* ncclStrongStream[Acquire/Release] pairs. Acquire/release act like fences;
* the strong stream is not stateful, so there is no harm in redundant acquires
* or releases.
* - The Acquire, Release, and all updates take a ncclCudaGraph parameter
* indicating the currently capturing graph (or none). This parameter must be
* the same for the entire sequence of {Acquire; ...; Release}.
*
* - An {Acquire; ...; Release} sequence must not be concurrent with any
* other operations against the strong stream including graph launches which
* reference this stream.
*
* - All strong stream functions take a "graph" parameter which must reference
* the currently capturing graph, or null if none.
*/
struct ncclStrongStream;
ncclResult_t ncclStrongStreamConstruct(struct ncclStrongStream* ss);
ncclResult_t ncclStrongStreamDestruct(struct ncclStrongStream* ss);
// Has this strong stream ever been captured in a graph.
bool ncclStrongStreamEverCaptured(struct ncclStrongStream* ss);
// Acquire-fence the strong stream.
ncclResult_t ncclStrongStreamAcquire(
struct ncclCudaGraph graph, struct ncclStrongStream* ss
);
// Acquire-fence the strong stream assuming no graph is capturing. This permits
// the caller to enqueue directly to the `ss->stream` member using native CUDA
// calls. Strong stream must be released via:
// the caller to enqueue directly to the `ss->cudaStream` member using native CUDA
// calls. Strong stream still must be released via:
// ncclStrongStreamRelease(ncclCudaGraphNone(), ss);
ncclResult_t ncclStrongStreamAcquireUncaptured(struct ncclStrongStream* ss);
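A minimal usage sketch of the contract described above (hypothetical names: userStream, myKernel, grid, block, args; comm stands for an ncclComm whose deviceStream member is a strong stream, as used elsewhere in this commit; error handling via NCCLCHECK as in the rest of the codebase). The same graph value obtained at the start must be passed to every call in the sequence:

struct ncclCudaGraph graph;
NCCLCHECK(ncclCudaGetCapturingGraph(&graph, userStream));   // graph.graph == nullptr if not capturing
NCCLCHECK(ncclStrongStreamAcquire(graph, &comm->deviceStream));
NCCLCHECK(ncclStrongStreamLaunchKernel(graph, &comm->deviceStream,
    (void*)myKernel, grid, block, args, /*sharedMemBytes=*/0));
NCCLCHECK(ncclStrongStreamRelease(graph, &comm->deviceStream));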
@ -103,6 +96,7 @@ ncclResult_t ncclStrongStreamLaunchKernel(
struct ncclCudaGraph graph, struct ncclStrongStream* ss,
void* fn, dim3 grid, dim3 block, void** args, size_t sharedMemBytes
);
// Cause `a` to wait for the current state `b`. Both `a` and `b` must be acquired.
ncclResult_t ncclStrongStreamWaitStream(
struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b
@ -121,21 +115,23 @@ ncclResult_t ncclStrongStreamSynchronize(struct ncclStrongStream* ss);
////////////////////////////////////////////////////////////////////////////////
struct ncclStrongStreamGraph; // internal to ncclStrongStream
struct ncclStrongStream {
cudaStream_t stream;
cudaEvent_t event;
#if CUDART_VERSION >= 11030
cudaGraphNode_t node; // null if never captured, otherwise never null again
uint64_t graphId:63, eventIsLagging:1;
#endif
// Used when not graph capturing.
cudaStream_t cudaStream;
#if CUDART_VERSION >= 11030
// The event used to establish order between graphs and streams. During acquire
// this event is waited on, during release it is recorded to.
cudaEvent_t serialEvent;
// This stream ever appeared in a graph capture.
bool everCaptured;
// Tracks whether serialEvent needs to be recorded to upon Release().
bool serialEventNeedsRecord;
struct ncclStrongStreamGraph* graphHead;
#else
cudaEvent_t scratchEvent;
#endif
};
inline bool ncclStrongStreamEverCaptured(struct ncclStrongStream* ss) {
#if CUDART_VERSION >= 11030
return ss->node != nullptr;
#else
return false;
#endif
}
#endif

View File

@ -364,7 +364,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
int nRanks = comm->nRanks;
struct ncclDevCommAndChannels *devCommAndChans, tmpCommAndChans;
NCCLCHECK(ncclCudaCallocAsync(&devCommAndChans, 1, comm->deviceStream.stream));
NCCLCHECK(ncclCudaCallocAsync(&devCommAndChans, 1, comm->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, devCommAndChans);
comm->devComm = &devCommAndChans->comm;
tmpCommAndChans.comm.rank = comm->rank;
@ -410,12 +410,12 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
tmpCommAndChans.channels[c].workFifoDone = &comm->workFifoDone[c];
if (comm->channels[c].ring.userRanks != nullptr) {
NCCLCHECK(ncclCudaMemcpyAsync(tmpCommAndChans.channels[c].ring.userRanks, comm->channels[c].ring.userRanks, nRanks, comm->deviceStream.stream));
NCCLCHECK(ncclCudaMemcpyAsync(tmpCommAndChans.channels[c].ring.userRanks, comm->channels[c].ring.userRanks, nRanks, comm->deviceStream.cudaStream));
}
}
NCCLCHECK(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->deviceStream.stream));
CUDACHECK(cudaStreamSynchronize(comm->deviceStream.stream));
NCCLCHECK(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->deviceStream.cudaStream));
CUDACHECK(cudaStreamSynchronize(comm->deviceStream.cudaStream));
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream));
return ncclSuccess;
}
@ -649,7 +649,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
allGather3Data[rank].collNetSupport = comm->collNetSupport;
comm->nChannels = std::min(treeGraph.nChannels, ringGraph.nChannels);
NCCLCHECK(ncclTopoPreset(comm, &treeGraph, &ringGraph, &allGather3Data[rank].topoRanks));
NCCLCHECK(ncclTopoPreset(comm, &treeGraph, &ringGraph, &collNetGraph, &allGather3Data[rank].topoRanks));
NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGather3Data, sizeof(*allGather3Data)));
@ -1037,6 +1037,8 @@ collnet_cleanup:
}
}
NCCLCHECKGOTO(devCommSetup(comm), ret, affinity_restore);
/* Local intra-node barrier */
NCCLCHECK(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0]));
@ -1087,7 +1089,6 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
}
NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank), res, cleanup);
NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup);
NCCLCHECKGOTO(devCommSetup(*newcomm), res, cleanup);
// update communicator state
comm->initState = ncclSuccess;
@ -1214,6 +1215,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
gpuFlags[devlist[i]] = 1;
}
free(gpuFlags);
gpuFlags = nullptr;
}
ncclUniqueId uniqueId;
@ -1225,11 +1227,9 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
}
NCCLCHECKGOTO(ncclGroupEnd(), ret, fail);
exit:
return ret;
fail:
if (gpuFlags) free(gpuFlags);
goto exit;
free(gpuFlags);
return ret;
}
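The ncclCommInitAll double-free fix follows the usual free-then-null idiom so that the single free() on the shared exit path is harmless. An illustrative standalone sketch of the idiom (not the actual init.cc code):

#include <stdlib.h>

static int example(int ndev) {
  int ret = 0;
  int* gpuFlags = (int*)calloc(ndev, sizeof(int));
  if (gpuFlags == NULL) return -1;
  // ... validate the device list using gpuFlags ...
  free(gpuFlags);
  gpuFlags = NULL;                          // early free done; the free() below becomes a no-op
  if (ndev <= 0) { ret = -1; goto fail; }   // a later failure path
fail:
  free(gpuFlags);                           // free(NULL) is defined to do nothing
  return ret;
}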
ncclResult_t ncclCommSetAsyncError(ncclComm_t comm, ncclResult_t nextState) {

View File

@ -9,24 +9,52 @@
#include "checks.h"
#include "param.h"
// Tracks the chain of graph nodes for a given graph capture, identified by
// its graph id. This state has to live for as long as captured work is being
// submitted. CUDA doesn't have a mechanism to inform us when the user ends capture,
// so the best we can do is get notified when the graph is destroyed.
struct ncclStrongStreamGraph {
struct ncclStrongStreamGraph* next;
// Atomically exchanged to false by either the main thread or the graph destructor
// callback; the last to arrive deletes the node.
bool alive;
unsigned long long graphId;
// For each graph we track the "tip" of the chain of graph nodes. A linear
// chain would always have just one node at its tip, but since we have to merge
// in chains from other streams (via ncclStrongStreamWaitStream) some spots
// in the chain can be wider than a single node and thus need a list, so we
// maintain a dynamically sized array of tip nodes.
int tipCount, tipCapacity;
cudaGraphNode_t* tipNodes;
};
static void ncclStrongStreamGraphDelete(struct ncclStrongStreamGraph* g) {
free(g->tipNodes);
free(g);
}
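To make the "tip" bookkeeping concrete, a hedged walkthrough (derived from the functions later in this file) of how the tip list of a strong stream A evolves while graph G is capturing, with graph-mixing support enabled:

// ncclStrongStreamAcquire(G, A)       -> A.tips = { eventWaitNode(A.serialEvent) }
// ncclStrongStreamLaunchKernel(G, A)  -> A.tips = { kernelNode }      (new node depends on the previous tips)
// ncclStrongStreamWaitStream(G, A, B) -> A.tips = A.tips U B.tips     (mergeTips; duplicates skipped)
// ncclStrongStreamRelease(G, A)       -> A.tips = { eventRecordNode(A.serialEvent) }
// Only ncclStrongStreamWaitStream can widen the tip beyond one node, which is
// why tipNodes is a dynamically sized array rather than a single cudaGraphNode_t.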
////////////////////////////////////////////////////////////////////////////////
ncclResult_t ncclCudaGetCapturingGraph(
struct ncclCudaGraph* graph, cudaStream_t stream
) {
#if CUDART_VERSION >= 11030
#if CUDART_VERSION >= 10000 // cudaStreamGetCaptureInfo
int driver;
NCCLCHECK(ncclCudaDriverVersion(&driver));
if (driver < 11030) {
if (CUDART_VERSION < 11030 || driver < 11030) {
cudaStreamCaptureStatus status;
unsigned long long gid;
graph->graph = nullptr;
CUDACHECK(cudaStreamGetCaptureInfo(stream, &status, &gid));
#if CUDART_VERSION >= 11030
graph->graph = nullptr;
graph->graphId = ULLONG_MAX;
#endif
if (status != cudaStreamCaptureStatusNone) {
WARN("The installed CUDA driver is older than the minimum version (R465) required for NCCL's CUDA Graphs support");
WARN("NCCL cannot be captured in a graph if either it wasn't built with CUDA runtime >= 11.3 or if the installed CUDA driver < R465.");
return ncclInvalidUsage;
}
} else {
#if CUDART_VERSION >= 11030
cudaStreamCaptureStatus status;
unsigned long long gid;
CUDACHECK(cudaStreamGetCaptureInfo_v2(stream, &status, &gid, &graph->graph, nullptr, nullptr));
@ -35,6 +63,7 @@ ncclResult_t ncclCudaGetCapturingGraph(
gid = ULLONG_MAX;
}
graph->graphId = gid;
#endif
}
#endif
return ncclSuccess;
@ -57,51 +86,113 @@ ncclResult_t ncclCudaGraphAddDestructor(struct ncclCudaGraph graph, cudaHostFn_t
////////////////////////////////////////////////////////////////////////////////
ncclResult_t ncclStrongStreamConstruct(struct ncclStrongStream* ss) {
CUDACHECK(cudaStreamCreateWithFlags(&ss->stream, cudaStreamNonBlocking));
CUDACHECK(cudaEventCreateWithFlags(&ss->event, cudaEventDisableTiming));
CUDACHECK(cudaStreamCreateWithFlags(&ss->cudaStream, cudaStreamNonBlocking));
#if CUDART_VERSION >= 11030
ss->node = nullptr;
ss->graphId = (1ull<<(8*sizeof(long long)-1))-1;
ss->eventIsLagging = 0;
CUDACHECK(cudaEventCreateWithFlags(&ss->serialEvent, cudaEventDisableTiming));
ss->everCaptured = false;
ss->serialEventNeedsRecord = false;
ss->graphHead = nullptr;
#else
CUDACHECK(cudaEventCreateWithFlags(&ss->scratchEvent, cudaEventDisableTiming));
#endif
return ncclSuccess;
}
static void graphDestructor(void* arg) {
struct ncclStrongStreamGraph* g = (struct ncclStrongStreamGraph*)arg;
if (false == __atomic_exchange_n(&g->alive, false, __ATOMIC_ACQ_REL)) {
// Last to arrive deletes list node.
ncclStrongStreamGraphDelete(g);
}
}
ncclResult_t ncclStrongStreamDestruct(struct ncclStrongStream* ss) {
CUDACHECK(cudaStreamDestroy(ss->cudaStream));
#if CUDART_VERSION >= 11030
CUDACHECK(cudaEventDestroy(ss->event));
CUDACHECK(cudaEventDestroy(ss->serialEvent));
// Delete list of per-graph chains.
struct ncclStrongStreamGraph* g = ss->graphHead;
while (g != nullptr) {
struct ncclStrongStreamGraph* next = g->next;
if (false == __atomic_exchange_n(&g->alive, false, __ATOMIC_ACQ_REL)) {
// Last to arrive deletes list node.
ncclStrongStreamGraphDelete(g);
}
g = next;
}
#else
CUDACHECK(cudaEventDestroy(ss->scratchEvent));
#endif
CUDACHECK(cudaStreamDestroy(ss->stream));
return ncclSuccess;
}
NCCL_PARAM(GraphMixingSupport, "GRAPH_MIXING_SUPPORT", 1)
static void ensureTips(struct ncclStrongStreamGraph* g, int n) {
if (g->tipCapacity < n) {
g->tipNodes = (cudaGraphNode_t*)realloc(g->tipNodes, n*sizeof(cudaGraphNode_t));
g->tipCapacity = n;
}
}
ncclResult_t ncclStrongStreamAcquire(
struct ncclCudaGraph graph, struct ncclStrongStream* ss
) {
#if CUDART_VERSION >= 11030
bool mixing = ncclParamGraphMixingSupport();
if (graph.graph == nullptr) {
if (mixing && ncclStrongStreamEverCaptured(ss)) {
CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0));
ss->eventIsLagging = 0;
if (mixing && ss->everCaptured) {
CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
ss->serialEventNeedsRecord = false;
}
} else {
if (ss->graphId != graph.graphId) {
if (mixing && ss->eventIsLagging) {
ss->everCaptured = true;
// Find the current graph in our list of graphs if it exists.
struct ncclStrongStreamGraph** pg = &ss->graphHead;
struct ncclStrongStreamGraph* g;
while (*pg != nullptr) {
g = *pg;
if (g->graphId == graph.graphId) {
// Move to front of list so that operations after acquire don't have to search the list.
*pg = g->next;
g->next = ss->graphHead;
ss->graphHead = g;
return ncclSuccess;
} else if (false == __atomic_load_n(&g->alive, __ATOMIC_ACQUIRE)) {
// Unrelated graph that has been destroyed. Remove and delete.
*pg = g->next;
ncclStrongStreamGraphDelete(g);
} else {
pg = &g->next;
}
}
// This is a new graph so add to the list.
g = (struct ncclStrongStreamGraph*)malloc(sizeof(struct ncclStrongStreamGraph));
g->graphId = graph.graphId;
g->tipNodes = nullptr;
g->tipCapacity = 0;
g->tipCount = 0;
g->next = ss->graphHead;
ss->graphHead = g;
g->alive = true;
NCCLCHECK(ncclCudaGraphAddDestructor(graph, graphDestructor, (void*)g));
if (mixing && ss->serialEventNeedsRecord) {
// Can only be here if previous release was for uncaptured work that
// elided updating the event because no capture had yet occurred.
CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0));
CUDACHECK(cudaEventRecord(ss->event, ss->stream));
CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
CUDACHECK(cudaEventRecord(ss->serialEvent, ss->cudaStream));
}
ss->graphId = graph.graphId;
ss->eventIsLagging = 0;
ss->serialEventNeedsRecord = false;
// First node in the chain must be a wait on the serialEvent.
if (mixing) {
CUDACHECK(cudaGraphAddEventWaitNode(&ss->node, graph.graph, nullptr, 0, ss->event));
ensureTips(g, 1);
CUDACHECK(cudaGraphAddEventWaitNode(&g->tipNodes[0], graph.graph, nullptr, 0, ss->serialEvent));
g->tipCount = 1;
} else {
CUDACHECK(cudaGraphAddEmptyNode(&ss->node, graph.graph, nullptr, 0));
}
g->tipCount = 0;
}
}
#endif
@ -111,26 +202,38 @@ ncclResult_t ncclStrongStreamAcquire(
ncclResult_t ncclStrongStreamAcquireUncaptured(struct ncclStrongStream* ss) {
#if CUDART_VERSION >= 11030
bool mixing = ncclParamGraphMixingSupport();
if (mixing && ncclStrongStreamEverCaptured(ss)) {
CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0));
if (mixing && ss->everCaptured) {
CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
}
ss->eventIsLagging = 1; // Assume the caller is going to add work to stream.
ss->serialEventNeedsRecord = true; // Assume the caller is going to add work to stream.
#endif
return ncclSuccess;
}
static ncclResult_t checkGraphId(struct ncclStrongStreamGraph* g, unsigned long long id) {
if (g == nullptr || g->graphId != id) {
WARN("Expected graph id=%llu was not at head of strong stream's internal list.", id);
return ncclInternalError;
}
return ncclSuccess;
}
ncclResult_t ncclStrongStreamRelease(struct ncclCudaGraph graph, struct ncclStrongStream* ss) {
#if CUDART_VERSION >= 11030
bool mixing = ncclParamGraphMixingSupport();
if (mixing && ss->eventIsLagging) {
if (mixing && ss->serialEventNeedsRecord) {
if (graph.graph == nullptr) {
if (ncclStrongStreamEverCaptured(ss)) {
CUDACHECK(cudaEventRecord(ss->event, ss->stream));
ss->eventIsLagging = 0;
if (ss->everCaptured) {
CUDACHECK(cudaEventRecord(ss->serialEvent, ss->cudaStream));
ss->serialEventNeedsRecord = false;
}
} else {
CUDACHECK(cudaGraphAddEventRecordNode(&ss->node, graph.graph, &ss->node, 1, ss->event));
ss->eventIsLagging = 0;
struct ncclStrongStreamGraph* g = ss->graphHead;
NCCLCHECK(checkGraphId(g, graph.graphId));
ensureTips(g, 1);
CUDACHECK(cudaGraphAddEventRecordNode(&g->tipNodes[0], graph.graph, g->tipNodes, g->tipCount, ss->serialEvent));
g->tipCount = 1;
ss->serialEventNeedsRecord = false;
}
}
#endif
@ -142,16 +245,20 @@ ncclResult_t ncclStrongStreamLaunchHost(
) {
#if CUDART_VERSION >= 11030
if (graph.graph == nullptr) {
CUDACHECK(cudaLaunchHostFunc(ss->stream, fn, arg));
CUDACHECK(cudaLaunchHostFunc(ss->cudaStream, fn, arg));
} else {
cudaHostNodeParams p;
p.fn = fn;
p.userData = arg;
CUDACHECK(cudaGraphAddHostNode(&ss->node, graph.graph, &ss->node, 1, &p));
struct ncclStrongStreamGraph* g = ss->graphHead;
NCCLCHECK(checkGraphId(g, graph.graphId));
ensureTips(g, 1);
CUDACHECK(cudaGraphAddHostNode(&g->tipNodes[0], graph.graph, g->tipNodes, g->tipCount, &p));
g->tipCount = 1;
}
ss->eventIsLagging = 1;
ss->serialEventNeedsRecord = true;
#else
CUDACHECK(cudaLaunchHostFunc(ss->stream, fn, arg));
CUDACHECK(cudaLaunchHostFunc(ss->cudaStream, fn, arg));
#endif
return ncclSuccess;
}
@ -162,9 +269,8 @@ ncclResult_t ncclStrongStreamLaunchKernel(
) {
#if CUDART_VERSION >= 11030
if (graph.graph == nullptr) {
CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->stream));
CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->cudaStream));
} else {
cudaGraphNode_t tip = ss->node;
cudaKernelNodeParams p;
p.func = fn;
p.gridDim = grid;
@ -172,33 +278,53 @@ ncclResult_t ncclStrongStreamLaunchKernel(
p.kernelParams = args;
p.sharedMemBytes = sharedMemBytes;
p.extra = nullptr;
CUDACHECK(cudaGraphAddKernelNode(&ss->node, graph.graph, &tip, 1, &p));
struct ncclStrongStreamGraph* g = ss->graphHead;
NCCLCHECK(checkGraphId(g, graph.graphId));
ensureTips(g, 1);
CUDACHECK(cudaGraphAddKernelNode(&g->tipNodes[0], graph.graph, g->tipNodes, g->tipCount, &p));
g->tipCount = 1;
}
ss->eventIsLagging = 1;
ss->serialEventNeedsRecord = true;
#else
CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->stream));
CUDACHECK(cudaLaunchKernel(fn, grid, block, args, sharedMemBytes, ss->cudaStream));
#endif
return ncclSuccess;
}
// Merge node list `b` into list `a` but don't add duplicates.
static void mergeTips(struct ncclStrongStreamGraph* a, cudaGraphNode_t const* bNodes, int bn) {
int an = a->tipCount;
ensureTips(a, an + bn);
for (int bi=0; bi < bn; bi++) {
for (int ai=0; ai < an; ai++) {
if (a->tipNodes[ai] == bNodes[bi]) goto next_b;
}
a->tipNodes[a->tipCount++] = bNodes[bi];
next_b:;
}
}
ncclResult_t ncclStrongStreamWaitStream(
struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b
) {
#if CUDART_VERSION >= 11030
if (graph.graph == nullptr) {
if (b->eventIsLagging) {
b->eventIsLagging = 0;
CUDACHECK(cudaEventRecord(b->event, b->stream));
if (b->serialEventNeedsRecord) {
b->serialEventNeedsRecord = false;
CUDACHECK(cudaEventRecord(b->serialEvent, b->cudaStream));
}
CUDACHECK(cudaStreamWaitEvent(a->stream, b->event, 0));
CUDACHECK(cudaStreamWaitEvent(a->cudaStream, b->serialEvent, 0));
} else {
cudaGraphNode_t pair[2] = {a->node, b->node};
CUDACHECK(cudaGraphAddEmptyNode(&a->node, graph.graph, pair, 2));
struct ncclStrongStreamGraph* ag = a->graphHead;
NCCLCHECK(checkGraphId(ag, graph.graphId));
struct ncclStrongStreamGraph* bg = b->graphHead;
NCCLCHECK(checkGraphId(bg, graph.graphId));
mergeTips(ag, bg->tipNodes, bg->tipCount);
}
a->eventIsLagging = 1;
a->serialEventNeedsRecord = true;
#else
CUDACHECK(cudaEventRecord(b->event, b->stream));
CUDACHECK(cudaStreamWaitEvent(a->stream, b->event, 0));
CUDACHECK(cudaEventRecord(b->scratchEvent, b->cudaStream));
CUDACHECK(cudaStreamWaitEvent(a->cudaStream, b->scratchEvent, 0));
#endif
return ncclSuccess;
}
@ -208,35 +334,29 @@ ncclResult_t ncclStrongStreamWaitStream(
) {
#if CUDART_VERSION >= 11030
if (graph.graph == nullptr) {
CUDACHECK(cudaEventRecord(a->event, b));
CUDACHECK(cudaStreamWaitEvent(a->stream, a->event, 0));
// We used a->event to record b so it no longer reflects anything about a.
a->eventIsLagging = 1;
// It is ok to use a->serialEvent to record b since we'll be setting
// a->serialEventNeedsRecord so the event won't be considered accurate
// until re-recorded.
CUDACHECK(cudaEventRecord(a->serialEvent, b));
CUDACHECK(cudaStreamWaitEvent(a->cudaStream, a->serialEvent, 0));
} else {
cudaStreamCaptureStatus status;
unsigned long long gid1;
cudaGraphNode_t const* deps;
size_t depN = 0;
CUDACHECK(cudaStreamGetCaptureInfo_v2(b, &status, &gid1, nullptr, &deps, &depN));
if (status != cudaStreamCaptureStatusActive || graph.graphId != gid1) {
unsigned long long bGraphId;
cudaGraphNode_t const* bNodes;
size_t bCount = 0;
CUDACHECK(cudaStreamGetCaptureInfo_v2(b, &status, &bGraphId, nullptr, &bNodes, &bCount));
if (status != cudaStreamCaptureStatusActive || graph.graphId != bGraphId) {
WARN("Stream is not being captured by the expected graph.");
return ncclInvalidUsage;
}
if (depN > 0 && (depN > 1 || deps[0] != a->node)) {
cudaGraphNode_t tie;
if (depN == 1) {
tie = deps[0];
} else {
CUDACHECK(cudaGraphAddEmptyNode(&tie, graph.graph, deps, depN));
}
cudaGraphNode_t pair[2] = {a->node, tie};
CUDACHECK(cudaGraphAddEmptyNode(&a->node, graph.graph, pair, 2));
a->eventIsLagging = 1;
}
struct ncclStrongStreamGraph* ag = a->graphHead;
NCCLCHECK(checkGraphId(ag, graph.graphId));
mergeTips(ag, bNodes, bCount);
}
a->serialEventNeedsRecord = true;
#else
CUDACHECK(cudaEventRecord(a->event, b));
CUDACHECK(cudaStreamWaitEvent(a->stream, a->event, 0));
CUDACHECK(cudaEventRecord(a->scratchEvent, b));
CUDACHECK(cudaStreamWaitEvent(a->cudaStream, a->scratchEvent, 0));
#endif
return ncclSuccess;
}
@ -246,25 +366,28 @@ ncclResult_t ncclStrongStreamWaitStream(
) {
#if CUDART_VERSION >= 11030
if (graph.graph == nullptr) {
if (b->eventIsLagging) {
b->eventIsLagging = 0;
CUDACHECK(cudaEventRecord(b->event, b->stream));
if (b->serialEventNeedsRecord) {
b->serialEventNeedsRecord = false;
CUDACHECK(cudaEventRecord(b->serialEvent, b->cudaStream));
}
CUDACHECK(cudaStreamWaitEvent(a, b->event, 0));
CUDACHECK(cudaStreamWaitEvent(a, b->serialEvent, 0));
} else {
CUDACHECK(cudaStreamUpdateCaptureDependencies(a, &b->node, 1, cudaStreamAddCaptureDependencies));
struct ncclStrongStreamGraph* bg = b->graphHead;
NCCLCHECK(checkGraphId(bg, graph.graphId));
CUDACHECK(cudaStreamUpdateCaptureDependencies(a, bg->tipNodes, bg->tipCount, cudaStreamAddCaptureDependencies));
}
#else
CUDACHECK(cudaEventRecord(b->event, b->stream));
CUDACHECK(cudaStreamWaitEvent(a, b->event, 0));
CUDACHECK(cudaEventRecord(b->scratchEvent, b->cudaStream));
CUDACHECK(cudaStreamWaitEvent(a, b->scratchEvent, 0));
#endif
return ncclSuccess;
}
ncclResult_t ncclStrongStreamSynchronize(struct ncclStrongStream* ss) {
#if CUDART_VERSION >= 11030
CUDACHECK(cudaStreamWaitEvent(ss->stream, ss->event, 0));
CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->serialEvent, 0));
ss->serialEventNeedsRecord = false;
#endif
CUDACHECK(cudaStreamSynchronize(ss->stream));
CUDACHECK(cudaStreamSynchronize(ss->cudaStream));
return ncclSuccess;
}