Add support for IB SHARP to NVLS (NVLink SHARP algorithm).
Add NVLS+Tree algorithm.
Add support for memory management using cuMem* functions.
Use all NICs for Send/Receive operations on systems with more than
one NIC per GPU (#804).
Add ncclCommSplit primitive, with resource sharing option in config (see
the usage sketch after this list).
Fix alltoallv hang (#788).
Increase number of channels on H100 when we're not limited by NVLink.
Improve error reporting in case of IB failure, printing local and
remote ID (#779).
Add build option to allow compilation against RDMA includes instead
of dynamically loading IB verbs symbols (#802).
Fix context creation for progress thread (#803).
NET/IB: add option to use multiple QPs in round-robin mode.
Fix tree performance issue when NVB is disabled on HCM topologies.
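
ncclCommSplit is new in this release. As a minimal usage sketch (not taken from
this commit, error handling elided): assuming the caller already knows a node
identifier and a per-node local rank for each process (the nodeId/localRank
variables below are hypothetical placeholders), a communicator can be split per
node while letting the child share the parent's resources through the new
splitShare config field:

#include <nccl.h>

// Minimal sketch: `parent` is an existing communicator; `nodeId` and
// `localRank` are placeholders assumed to be computed by the caller.
ncclResult_t splitByNode(ncclComm_t parent, int nodeId, int localRank, ncclComm_t* sub) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  config.splitShare = 1;  // share the parent's proxy/channel resources with the child comm
  // color groups ranks into sub-communicators; key orders ranks within each group.
  return ncclCommSplit(parent, /*color=*/nodeId, /*key=*/localRank, sub, &config);
}

Ranks that should not join any child communicator can pass NCCL_SPLIT_NOCOLOR
as the color.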
Sylvain Jeaugey 2023-04-03 05:32:07 -07:00
parent 9b7d5edbfc
commit d97a32fac8
64 changed files with 4758 additions and 3131 deletions

View File

@@ -12,6 +12,7 @@ DEBUG ?= 0
 TRACE ?= 0
 PROFAPI ?= 1
 NVTX ?= 1
+RDMA_CORE ?= 0
 NVCC = $(CUDA_HOME)/bin/nvcc
@@ -106,3 +107,7 @@ endif
 ifneq ($(PROFAPI), 0)
 CXXFLAGS += -DPROFAPI
 endif
+ifneq ($(RDMA_CORE), 0)
+CXXFLAGS += -DNCCL_BUILD_RDMA_CORE=1
+endif

View File

@@ -1,6 +1,6 @@
 ##### version
 NCCL_MAJOR := 2
-NCCL_MINOR := 17
+NCCL_MINOR := 18
 NCCL_PATCH := 1
 NCCL_SUFFIX :=
 PKG_REVISION := 1

View File

@@ -10,7 +10,7 @@ include ../makefiles/version.mk
 ##### src files
 INCEXPORTS := nccl.h nccl_net.h
 LIBSRCFILES := init.cc init_nvtx.cc channel.cc bootstrap.cc transport.cc enqueue.cc group.cc debug.cc proxy.cc net.cc \
-               misc/cudawrap.cc misc/nvmlwrap.cc misc/ibvwrap.cc misc/gdrwrap.cc \
+               misc/cudawrap.cc misc/nvmlwrap.cc misc/ibvsymbols.cc misc/ibvwrap.cc misc/gdrwrap.cc \
                misc/utils.cc misc/argcheck.cc misc/socket.cc misc/shmutils.cc misc/profiler.cc misc/param.cc misc/strongstream.cc \
                misc/ipcsocket.cc \
                transport/p2p.cc transport/shm.cc transport/net.cc transport/net_socket.cc transport/net_ib.cc transport/coll_net.cc transport/nvls.cc \

View File

@@ -305,6 +305,74 @@ ncclResult_t bootstrapInit(struct ncclBootstrapHandle* handle, struct ncclComm*
   return ncclSuccess;
 }
+ncclResult_t bootstrapSplit(struct ncclBootstrapHandle* handle, struct ncclComm* comm, struct ncclComm* parent, int color, int key, int* parentRanks) {
+  ncclResult_t ret = ncclSuccess;
+  int rank = comm->rank;
+  int nranks = comm->nRanks;
+  int prev, next;
+  ncclSocketAddress listenAddr, tmpAddr;
+  struct ncclSocket* proxySocket;
+  struct bootstrapState* state;
+  NCCLCHECKGOTO(ncclCalloc(&state, 1), ret, fail);
+  state->rank = rank;
+  state->nranks = nranks;
+  state->abortFlag = comm->abortFlag;
+  comm->bootstrap = state;
+  comm->magic = state->magic = handle->magic;
+  prev = parentRanks[(rank-1+nranks)%nranks];
+  next = parentRanks[(rank+1)%nranks];
+  // Setup my sockets for the allgather ring and other p2p connections
+  NCCLCHECKGOTO(ncclSocketInit(&state->listenSock, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag, 0), ret, fail);
+  NCCLCHECKGOTO(ncclSocketInit(&state->ringRecvSocket, NULL, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag, 0), ret, fail);
+  // Create socket for other ranks to contact me
+  NCCLCHECKGOTO(ncclSocketListen(&state->listenSock), ret, fail);
+  // Get addr from next rank
+  NCCLCHECKGOTO(ncclSocketGetAddr(&state->listenSock, &listenAddr), ret, fail);
+  NCCLCHECKGOTO(bootstrapSend(parent->bootstrap, prev, -2, &listenAddr, sizeof(union ncclSocketAddress)), ret, fail);
+  NCCLCHECKGOTO(bootstrapRecv(parent->bootstrap, next, -2, &tmpAddr, sizeof(union ncclSocketAddress)), ret, fail);
+  NCCLCHECKGOTO(ncclSocketInit(&state->ringSendSocket, &tmpAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag, 0), ret, fail);
+  NCCLCHECKGOTO(ncclSocketConnect(&state->ringSendSocket), ret, fail);
+  // Accept the connect request from the previous rank in the AllGather ring
+  NCCLCHECKGOTO(ncclSocketAccept(&state->ringRecvSocket, &state->listenSock), ret, fail);
+  // AllGather all listen handlers
+  NCCLCHECKGOTO(ncclCalloc(&state->peerCommAddresses, nranks), ret, fail);
+  memcpy(state->peerCommAddresses+rank, &listenAddr, sizeof(union ncclSocketAddress));
+  NCCLCHECKGOTO(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union ncclSocketAddress)), ret, fail);
+  if (parent->config.splitShare) {
+    /* map local rank to top parent local rank. */
+    for (int i = 0; i < nranks; ++i) {
+      comm->topParentRanks[i] = parent->topParentRanks[parentRanks[i]];
+    }
+    comm->proxyState = parent->sharedRes->proxyState;
+    ncclAtomicRefCountIncrement(&parent->sharedRes->proxyState->refCount);
+  } else {
+    // Create the service proxy
+    NCCLCHECKGOTO(ncclCalloc(&state->peerProxyAddresses, nranks), ret, fail);
+    NCCLCHECKGOTO(ncclCalloc(&proxySocket, 1), ret, fail);
+    NCCLCHECKGOTO(ncclSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeProxy, comm->abortFlag, 0), ret, fail);
+    NCCLCHECKGOTO(ncclSocketListen(proxySocket), ret, fail);
+    NCCLCHECKGOTO(ncclSocketGetAddr(proxySocket, &tmpAddr), ret, fail);
+    memcpy(state->peerProxyAddresses + rank, &tmpAddr, sizeof(union ncclSocketAddress));
+    NCCLCHECKGOTO(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union ncclSocketAddress)), ret, fail);
+    NCCLCHECKGOTO(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses), ret, fail);
+  }
+  INFO(NCCL_INIT, "bootstrapSplit: rank %d nranks %d color %d key %d prev %d next %d - DONE", rank, nranks, color, key, prev, next);
+exit:
+  return ret;
+fail:
+  goto exit;
+}
 ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
   struct bootstrapState* state = (struct bootstrapState*)commState;
   char* data = (char*)allData;
@@ -336,7 +404,7 @@ ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int s
   struct bootstrapState* state = (struct bootstrapState*)commState;
   struct ncclSocket sock;
-  NCCLCHECKGOTO(ncclSocketInit(&sock, state->peerCommAddresses+peer, state->magic, ncclSocketTypeBootstrap, state->abortFlag), ret, fail);
+  NCCLCHECKGOTO(ncclSocketInit(&sock, state->peerCommAddresses+peer, state->magic, ncclSocketTypeBootstrap), ret, fail);
   NCCLCHECKGOTO(ncclSocketConnect(&sock), ret, fail);
   NCCLCHECKGOTO(bootstrapNetSend(&sock, &state->rank, sizeof(int)), ret, fail);
   NCCLCHECKGOTO(bootstrapNetSend(&sock, &tag, sizeof(int)), ret, fail);
@@ -397,7 +465,7 @@ ncclResult_t bootstrapIntraNodeBroadcast(void* commState, int *ranks, int rank,
     }
   }
   else {
-    NCCLCHECK(bootstrapRecv(commState, ranks[root], /*tag=*/rank, bcastData, size));
+    NCCLCHECK(bootstrapRecv(commState, ranks[root], /*tag=*/ranks[rank], bcastData, size));
   }
   TRACE(NCCL_INIT, "rank %d nranks %d root %d size %d - DONE", rank, nranks, root, size);

View File

@ -17,30 +17,120 @@ ncclResult_t initChannel(struct ncclComm* comm, int channelId) {
channel->id = channelId; channel->id = channelId;
channel->workFifoSent = 0; channel->workFifoSent = 0;
NCCLCHECK(ncclStrongStreamAcquireUncaptured(&comm->deviceStream)); struct ncclSharedResources* sharedRes = comm->sharedRes;
// The extra on nRanks+1 is for collnet root (i.e. network) NCCLCHECK(ncclStrongStreamAcquireUncaptured(&sharedRes->deviceStream));
channel->peers = ncclMemoryStackAlloc<struct ncclChannelPeer>(&comm->memPermanent, nPeers);
NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nPeers, comm->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, channel->devPeers);
channel->ring.userRanks = ncclMemoryStackAlloc<int>(&comm->memPermanent, nRanks); if (channel->peers == NULL) {
NCCLCHECK(ncclCudaCallocAsync(&channel->devRingUserRanks, nRanks, comm->deviceStream.cudaStream)); // The extra on nRanks+1 is for collnet root (i.e. network)
ncclCommPushCudaFree(comm, channel->devRingUserRanks); // Allocate everything related to sharedRes with ncclCalloc as this can be
// shared between communicators hence should not be tied to comm.
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream)); if (sharedRes->peers[channelId] == NULL) {
NCCLCHECK(ncclCalloc(sharedRes->peers + channelId, sharedRes->tpNRanks));
for (int r=0; r < nPeers; ++r) { }
for (int b=0; b < NCCL_MAX_CONNS; b++) { channel->peers = ncclMemoryStackAlloc<struct ncclChannelPeer*>(&comm->memPermanent, nPeers);
channel->peers[r].send[b].comm = comm; for (int r = 0; r < nRanks; r++) {
channel->peers[r].recv[b].comm = comm; channel->peers[r] = comm->sharedRes->peers[channelId] + comm->topParentRanks[r];
ncclAtomicRefCountIncrement(&channel->peers[r]->refCount);
} }
} }
if (channel->devPeers == NULL) {
if (sharedRes->devPeers[channelId] == NULL) {
NCCLCHECK(ncclCudaCallocAsync(sharedRes->devPeers + channelId, sharedRes->tpNRanks, sharedRes->deviceStream.cudaStream));
}
/* channel->devPeers is not shared, so just free it when calling commFree() */
NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nPeers, sharedRes->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, channel->devPeers);
for (int r = 0; r < nRanks; r++) {
uintptr_t addr = (uintptr_t)(comm->sharedRes->devPeers[channelId] + comm->topParentRanks[r]);
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + r), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
}
}
channel->ring.userRanks = ncclMemoryStackAlloc<int>(&comm->memPermanent, nRanks);
NCCLCHECK(ncclCudaCallocAsync(&channel->devRingUserRanks, nRanks, sharedRes->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, channel->devRingUserRanks);
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &sharedRes->deviceStream));
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t freeChannel(struct ncclChannel* channel, int nRanks) { ncclResult_t initNvlsChannel(struct ncclComm* comm, int channelId, struct ncclComm* parent, bool share) {
struct ncclChannel* channel = &comm->channels[channelId];
struct ncclSharedResources* sharedRes = comm->sharedRes;
if (channel->nvlsPeers != NULL)
return ncclSuccess;
if (channel->id == -1)
NCCLCHECK(initChannel(comm, channelId));
NCCLCHECK(ncclStrongStreamAcquireUncaptured(&sharedRes->deviceStream));
if (share) {
channel->nvlsPeers = parent->channels[channelId].nvlsPeers;
channel->nvlsDevPeers = parent->channels[channelId].nvlsDevPeers;
for (int r = 0; r < comm->localRanks; ++r) {
int tr = comm->topParentLocalRanks[r];
uintptr_t addr = (uintptr_t)(parent->channels[channelId].nvlsDevPeers + tr);
channel->peers[comm->nRanks + 1 + r] = parent->channels[channelId].nvlsPeers + tr;
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks + 1 + r), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
ncclAtomicRefCountIncrement(&parent->channels[channelId].nvlsPeers[tr].refCount);
}
} else {
NCCLCHECK(ncclCalloc(&channel->nvlsPeers, comm->localRanks));
NCCLCHECK(ncclCudaCallocAsync(&channel->nvlsDevPeers, comm->localRanks, sharedRes->deviceStream.cudaStream));
for (int r = 0; r < comm->localRanks; ++r) {
uintptr_t addr = (uintptr_t)(channel->nvlsDevPeers + r);
channel->peers[comm->nRanks + 1 + r] = channel->nvlsPeers + r;
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks + 1 + r), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
ncclAtomicRefCountIncrement(&channel->nvlsPeers[r].refCount);
}
}
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &sharedRes->deviceStream));
return ncclSuccess;
}
ncclResult_t initCollnetChannel(struct ncclComm* comm, int channelId, struct ncclComm* parent, bool share) {
struct ncclChannel* channel = &comm->channels[channelId];
struct ncclSharedResources* sharedRes = comm->sharedRes;
uintptr_t addr;
if (channel->collnetPeers != NULL)
return ncclSuccess;
if (channel->id == -1)
NCCLCHECK(initChannel(comm, channelId));
NCCLCHECK(ncclStrongStreamAcquireUncaptured(&sharedRes->deviceStream));
if (share) {
channel->collnetPeers = parent->channels[channelId].collnetPeers;
channel->collnetDevPeers = parent->channels[channelId].collnetDevPeers;
addr = (uintptr_t)parent->channels[channelId].collnetDevPeers;
channel->peers[comm->nRanks] = parent->channels[channelId].collnetPeers;
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
ncclAtomicRefCountIncrement(&parent->channels[channelId].collnetPeers->refCount);
} else {
NCCLCHECK(ncclCalloc(&channel->collnetPeers, 1));
NCCLCHECK(ncclCudaCallocAsync(&channel->collnetDevPeers, 1, sharedRes->deviceStream.cudaStream));
addr = (uintptr_t)channel->collnetDevPeers;
channel->peers[comm->nRanks] = channel->collnetPeers;
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
ncclAtomicRefCountIncrement(&channel->collnetPeers->refCount);
}
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &sharedRes->deviceStream));
return ncclSuccess;
}
ncclResult_t freeChannel(struct ncclChannel* channel, int nRanks, int collnetNRanks, int nvlsNRanks) {
int nPeers = nRanks + collnetNRanks + nvlsNRanks;
/* channel peers are only valid when async init thread completes commAlloc() and /* channel peers are only valid when async init thread completes commAlloc() and
* the channel is intialized with initChannel(); if either is not done, this channel * the channel is intialized with initChannel(); if either is not done, this channel
* should never be free. */ * should never be free. */
@ -48,18 +138,23 @@ ncclResult_t freeChannel(struct ncclChannel* channel, int nRanks) {
// Free transport proxy resources // Free transport proxy resources
// Note: free all send resources first due to CollNet arrangement // Note: free all send resources first due to CollNet arrangement
for (int r=0; r<nRanks+1; r++) { for (int r = 0; r < nPeers; r++) {
struct ncclChannelPeer* peer = channel->peers+r; struct ncclChannelPeer* peer = channel->peers[r];
for (int b=0; b<NCCL_MAX_CONNS; b++) { if (peer) {
if (peer->send[b].transportComm) NCCLCHECK(peer->send[b].transportComm->free(peer->send+b)); if (ncclAtomicRefCountDecrement(&peer->refCount) == 0) {
for (int b=0; b<NCCL_MAX_CONNS; b++) {
if (peer->send[b].transportComm) NCCLCHECK(peer->send[b].transportComm->free(peer->send+b));
if (peer->recv[b].transportComm) NCCLCHECK(peer->recv[b].transportComm->free(peer->recv+b));
}
if (r == nRanks) {
free(channel->collnetPeers);
ncclCudaFree(channel->collnetDevPeers);
} else if (r == nPeers - 1) {
free(channel->nvlsPeers);
ncclCudaFree(channel->nvlsDevPeers);
}
}
} }
} }
for (int r=0; r<nRanks+1; r++) {
struct ncclChannelPeer* peer = channel->peers+r;
for (int b=0; b<NCCL_MAX_CONNS; b++) {
if (peer->recv[b].transportComm) NCCLCHECK(peer->recv[b].transportComm->free(peer->recv+b));
}
}
return ncclSuccess; return ncclSuccess;
} }

View File

@@ -55,7 +55,7 @@ namespace {
     if (inputBuf + chunkOffset == outputBuf + offset) { // In place
       prims.directSend(chunkOffset, offset, nelem);
     } else {
-      prims.directCopySend(chunkOffset, offset, offset, nelem);
+      prims.directCopySend(chunkOffset, offset, nelem);
     }
     // k-2 steps: copy to next GPU
@@ -63,7 +63,7 @@ namespace {
       rankDest = ringRanks[nranks-j];
       offset = chunkOffset + rankDest * size;
-      prims.directRecvCopySend(offset, offset, nelem);
+      prims.directRecvCopySend(offset, nelem);
     }
     // Make final copy from buffer to dest.
@@ -118,19 +118,19 @@ struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SI
     if (tid < tidEndGather) {
       // Gather
-      int group = (0*Proto::MaxGroupWidth) | (0<<16);
       Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
-        prims(tid, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff, args->redOpArg, group, args);
+        prims(tid, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff,
+              args->redOpArg, 0*Proto::MaxGroupWidth, 0, 0);
       for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
         ssize_t offset = gridOffset + bid*chunkSize;
         int nelem = min(chunkSize, size-offset);
         prims.gather(offset, nvls->nHeads*size, nelem, size, -1, 0);
       }
     } else if (tid < tidEndBcast) {
-      int group = (3*Proto::MaxGroupWidth) | (1<<16);
-      // Bcast through MC
+      // Bcast through NVLS
       Primitives<T, RedOp, FanAsymmetric<0, 1>, /*Direct=*/0, Proto, 0>
-        prims(tid-tidEndGather, nThreadsBcast, NULL, &nvls->down, args->sendbuff, NULL, args->redOpArg, group, args);
+        prims(tid-tidEndGather, nThreadsBcast, NULL, &nvls->down, args->sendbuff, NULL,
+              args->redOpArg, 3*Proto::MaxGroupWidth, 1, 1);
       for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
         ssize_t offset = gridOffset + bid*chunkSize;
         int nelem = min(chunkSize, size-offset);

View File

@ -76,14 +76,14 @@ namespace {
chunk = ringIx + 0; chunk = ringIx + 0;
offset = calcOffset(chunk); offset = calcOffset(chunk);
nelem = min(realChunkSize, size-offset); nelem = min(realChunkSize, size-offset);
prims.directRecvReduceCopySend(offset, offset, offset, nelem, /*postOp=*/true); prims.directRecvReduceCopySend(offset, offset, nelem, /*postOp=*/true);
// k-2 steps: copy to next GPU // k-2 steps: copy to next GPU
for (int j=1; j<nranks-1; ++j) { for (int j=1; j<nranks-1; ++j) {
chunk = modRanks(ringIx + nranks-j); chunk = modRanks(ringIx + nranks-j);
offset = calcOffset(chunk); offset = calcOffset(chunk);
nelem = min(realChunkSize, size-offset); nelem = min(realChunkSize, size-offset);
prims.directRecvCopySend(offset, offset, nelem); prims.directRecvCopySend(offset, nelem);
} }
// Make final copy from buffer to dest. // Make final copy from buffer to dest.
@ -146,7 +146,7 @@ namespace {
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*int(chunkSize); ssize_t offset = gridOffset + bid*int(chunkSize);
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);
prims.directSendFromOutput(offset, offset, nelem); prims.directSendFromOutput(offset, nelem);
} }
} }
else if (tree->down[0] == -1) { else if (tree->down[0] == -1) {
@ -160,7 +160,7 @@ namespace {
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*int(chunkSize); ssize_t offset = gridOffset + bid*int(chunkSize);
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);
prims.directRecvCopySend(offset, offset, nelem); prims.directRecvCopySend(offset, nelem);
} }
} }
} }
@ -203,7 +203,7 @@ namespace {
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*int(chunkSize); ssize_t offset = gridOffset + bid*int(chunkSize);
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);
prims.directRecvReduceCopySend(offset, offset, offset, nelem, /*doPost=*/true); prims.directRecvReduceCopySend(offset, offset, nelem, /*doPost=*/true);
} }
} }
else if (tid < nthreadsSplit) { else if (tid < nthreadsSplit) {
@ -235,7 +235,8 @@ namespace {
else { else {
// Broadcast down. Max number of recv is 1, max number of send is 3 (binary tree + local) // Broadcast down. Max number of recv is 1, max number of send is 3 (binary tree + local)
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/1, Proto, 0> Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/1, Proto, 0>
prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg, 1*Proto::MaxGroupWidth); prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff,
args->redOpArg, 1*Proto::MaxGroupWidth);
if (tree->down[0] == -1) { if (tree->down[0] == -1) {
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*int(chunkSize); ssize_t offset = gridOffset + bid*int(chunkSize);
@ -247,7 +248,7 @@ namespace {
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*int(chunkSize); ssize_t offset = gridOffset + bid*int(chunkSize);
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);
prims.directRecvCopySend(offset, offset, nelem); prims.directRecvCopySend(offset, nelem);
} }
} }
} }
@ -299,9 +300,9 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
if (tid >= tidStartScatter && tid < tidStartReduce && hasUp) { if (tid >= tidStartScatter && tid < tidStartReduce && hasUp) {
// Scatter // Scatter
int group = (2*Proto::MaxGroupWidth) | (1<<16);
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/1, Proto, 0> Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/1, Proto, 0>
prims(tid-tidStartScatter, nThreadsScatter, NULL, direct->up, args->sendbuff, args->recvbuff, args->redOpArg, group, args); prims(tid-tidStartScatter, nThreadsScatter, NULL, direct->up, args->sendbuff, args->recvbuff,
args->redOpArg, 2*Proto::MaxGroupWidth, 1, 1, args);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*direct->nHeads*chunkSize; ssize_t offset = gridOffset + bid*direct->nHeads*chunkSize;
int nelem = min(direct->nHeads*chunkSize, size-offset); int nelem = min(direct->nHeads*chunkSize, size-offset);
@ -312,16 +313,16 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
} }
} }
} else if (tid >= tidStartReduce && direct->out != -1) { } else if (tid >= tidStartReduce && direct->out != -1) {
int group = (3*Proto::MaxGroupWidth) | (1<<16);
if (hasDn) { if (hasDn) {
// Reduce, send to network // Reduce, send to network
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 1>, /*Direct=*/1, Proto, 0> Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 1>, /*Direct=*/1, Proto, 0>
prims(tid-tidStartReduce, nThreadsReduce, direct->down, &direct->out, args->sendbuff, args->recvbuff, args->redOpArg, group, args); prims(tid-tidStartReduce, nThreadsReduce, direct->down, &direct->out, args->sendbuff, args->recvbuff,
args->redOpArg, 3*Proto::MaxGroupWidth, 1, 1, args);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize; ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize;
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);
if (args->regUsed) { if (args->regUsed) {
prims.directRecvReduceSend(offset, offset, nelem); prims.directRecvReduceSend(offset, nelem);
} else { } else {
prims.recvReduceSend(offset, nelem); prims.recvReduceSend(offset, nelem);
} }
@ -329,7 +330,8 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
} else { } else {
// Directly send to network // Directly send to network
Primitives<T, RedOp, FanAsymmetric<0, 1>, /*Direct=*/0, Proto, 0> Primitives<T, RedOp, FanAsymmetric<0, 1>, /*Direct=*/0, Proto, 0>
prims(tid-tidStartReduce, nThreadsReduce, nullptr, &direct->out, args->sendbuff, args->recvbuff, args->redOpArg, group); prims(tid-tidStartReduce, nThreadsReduce, nullptr, &direct->out, args->sendbuff, args->recvbuff,
args->redOpArg, 3*Proto::MaxGroupWidth, 1, 1);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize; ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize;
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);
@ -338,29 +340,30 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
} }
} else if (tid < tidStartBcast && hasUp) { } else if (tid < tidStartBcast && hasUp) {
// Gather // Gather
int group = (0*Proto::MaxGroupWidth) | (0<<16);
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 0>, /*Direct=*/1, Proto, 0> Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 0>, /*Direct=*/1, Proto, 0>
prims(tid, nThreadsGather, direct->up, NULL, args->sendbuff, args->recvbuff, args->redOpArg, group, args); prims(tid, nThreadsGather, direct->up, NULL, args->sendbuff, args->recvbuff,
args->redOpArg, 0*Proto::MaxGroupWidth, 0, 0, args);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*direct->nHeads*chunkSize; ssize_t offset = gridOffset + bid*direct->nHeads*chunkSize;
int nelem = min(direct->nHeads*chunkSize, size-offset); int nelem = min(direct->nHeads*chunkSize, size-offset);
prims.directGather(offset, nelem, chunkSize, chunkSize, direct->headRank, direct->shift); prims.directGather(offset, nelem, chunkSize, chunkSize, direct->headRank, direct->shift);
} }
} else if (tid >= tidStartBcast && tid < tidStartScatter && direct->out != -1) { } else if (tid >= tidStartBcast && tid < tidStartScatter && direct->out != -1) {
int group = (1*Proto::MaxGroupWidth) | (0<<16);
if (hasDn) { if (hasDn) {
// Recv from network, broadcast // Recv from network, broadcast
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/1, Proto, 0> Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/1, Proto, 0>
prims(tid-tidStartBcast, nThreadsBcast, &direct->out, direct->down, args->sendbuff, args->recvbuff, args->redOpArg, group, args); prims(tid-tidStartBcast, nThreadsBcast, &direct->out, direct->down, args->sendbuff, args->recvbuff,
args->redOpArg, 1*Proto::MaxGroupWidth, 0, 0, args);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize; ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize;
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);
prims.recvCopyDirectSend(offset, offset, nelem, /*postOp=*/true); prims.recvCopyDirectSend(offset, nelem, /*postOp=*/true);
} }
} else { } else {
// Recv from network (no post thread needed) // Recv from network (no post thread needed)
Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto, 0> Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto, 0>
prims(tid-tidStartBcast, nThreadsBcast, &direct->out, nullptr, args->sendbuff, args->recvbuff, args->redOpArg, group); prims(tid-tidStartBcast, nThreadsBcast, &direct->out, nullptr, args->sendbuff, args->recvbuff,
args->redOpArg, 1*Proto::MaxGroupWidth, 0, 0);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize; ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize;
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);
@ -383,23 +386,27 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SI
const ssize_t size = args->count; const ssize_t size = args->count;
const ssize_t loopSize = nChannels*nvls->nHeads*chunkSize; const ssize_t loopSize = nChannels*nvls->nHeads*chunkSize;
const int nranks = ncclShmem.comm.nRanks; const int nranks = ncclShmem.comm.nRanks;
const int reduceWarps = nranks <= 6 ? 6 : 4; const bool hasOut = nvls->out != -1;
const int copyWarps = ((NCCL_MAX_NTHREADS/WARP_SIZE) - reduceWarps)/2; const int reduceWarps = hasOut ? 3 : nranks <= 6 ? 7 : 5;
const int bcastWarps = hasOut ? 2 : 0;
const int scatterWarps = ((NCCL_MAX_NTHREADS/WARP_SIZE) - reduceWarps - bcastWarps + 1)/2;
const int gatherWarps = ((NCCL_MAX_NTHREADS/WARP_SIZE) - reduceWarps - bcastWarps)/2;
const int nThreadsScatter = copyWarps*WARP_SIZE; const int nThreadsScatter = scatterWarps*WARP_SIZE;
const int nThreadsGather = (copyWarps-1)*WARP_SIZE; const int nThreadsGather = gatherWarps*WARP_SIZE;
const int nThreadsReduce = (reduceWarps+1)*WARP_SIZE; const int nThreadsReduce = reduceWarps*WARP_SIZE;
const int nThreadsBcast = (bcastWarps)*WARP_SIZE;
const int tidEndScatter = nThreadsScatter; const int tidEndScatter = nThreadsScatter;
const int tidEndGather = tidEndScatter + nThreadsGather; const int tidEndGather = tidEndScatter + nThreadsGather;
const int tidEndReduce = tidEndGather + nThreadsReduce; const int tidEndReduce = tidEndGather + nThreadsReduce;
const int tidEndBcast = tidEndReduce + nThreadsBcast;
using Proto = ProtoSimple<1, 1, COLL_UNROLL, /*NVLS=*/true>;
if (tid < tidEndScatter) { if (tid < tidEndScatter) {
// Scatter // Scatter
int group = (0*Proto::MaxGroupWidth) | (0<<16); using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0> Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, args->recvbuff, args->redOpArg, group, args); prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL,
args->redOpArg, 0*Proto::MaxGroupWidth, 1, 1);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*nvls->nHeads*chunkSize; ssize_t offset = gridOffset + bid*nvls->nHeads*chunkSize;
int nelem = min(nvls->nHeads*chunkSize, size-offset); int nelem = min(nvls->nHeads*chunkSize, size-offset);
@ -407,19 +414,136 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SI
} }
} else if (tid < tidEndGather) { } else if (tid < tidEndGather) {
// Gather // Gather
int group = (2*Proto::MaxGroupWidth) | (0<<16); using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0> Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
prims(tid-tidEndScatter, nThreadsGather, nvls->up, NULL, args->sendbuff, args->recvbuff, args->redOpArg, group, args); prims(tid-tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff,
args->redOpArg, 1*Proto::MaxGroupWidth, 1, 1);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*nvls->nHeads*chunkSize; ssize_t offset = gridOffset + bid*nvls->nHeads*chunkSize;
int nelem = min(nvls->nHeads*chunkSize, size-offset); int nelem = min(nvls->nHeads*chunkSize, size-offset);
prims.gather(offset, nelem, chunkSize, chunkSize, -1, 0); prims.gather(offset, nelem, chunkSize, chunkSize, -1, 0);
} }
} else if (tid < tidEndReduce) { } else if (tid < tidEndReduce && nvls->headRank != -1) {
int group = (3*Proto::MaxGroupWidth) | (1<<16); if (!hasOut) {
// Reduce, broadcast through NVLS // Reduce, broadcast through NVLS
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 1>;
Primitives<T, RedOp, FanSymmetric<1>, /*Direct=*/0, Proto, 0>
prims(tid-tidEndGather, nThreadsReduce, &nvls->down, &nvls->down, NULL, NULL,
args->redOpArg, 2*Proto::MaxGroupWidth, 0, 0);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + (bid*nvls->nHeads+nvls->headRank)*chunkSize;
int nelem = min(chunkSize, size-offset);
prims.recvSend(nelem);
}
} else {
// Reduce, send to network
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
Primitives<T, RedOp, FanSymmetric<1>, /*Direct=*/0, Proto, 0>
prims(tid-tidEndGather, nThreadsReduce, &nvls->down, &nvls->out, NULL, NULL,
args->redOpArg, 2*Proto::MaxGroupWidth, 0, 1);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + (bid*nvls->nHeads+nvls->headRank)*chunkSize;
int nelem = min(chunkSize, size-offset);
prims.recvSend(nelem);
}
}
} else if (tid < tidEndBcast && nvls->headRank != -1) {
// Recv from network, broadcast
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 0, 1>;
Primitives<T, RedOp, FanSymmetric<1>, /*Direct=*/0, Proto, 0> Primitives<T, RedOp, FanSymmetric<1>, /*Direct=*/0, Proto, 0>
prims(tid-tidEndGather, nThreadsReduce, &nvls->down, &nvls->down, args->sendbuff, args->recvbuff, args->redOpArg, group, args); prims(tid-tidEndReduce, nThreadsBcast, &nvls->out, &nvls->down, NULL, NULL,
args->redOpArg, 3*Proto::MaxGroupWidth, 0, 0);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + (bid*nvls->nHeads+nvls->headRank)*chunkSize;
int nelem = min(chunkSize, size-offset);
prims.recvSend(nelem);
}
}
#endif // NCCL_NVLS_ENABLED
}
};
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS_TREE, NCCL_PROTO_SIMPLE> {
__device__ __forceinline__ void run(ncclWorkElem *args) {
#if NCCL_NVLS_ENABLED
const int tid = threadIdx.x;
const int bid = args->bid;
const int nChannels = args->nChannels;
struct ncclNvls* nvls = &ncclShmem.channel.nvls;
const int treeUp = nvls->treeUp;
const int* treeDown = nvls->treeDown;
const ssize_t chunkSize = int(args->lastChunkSize);
const ssize_t size = args->count;
const ssize_t loopSize = nChannels*nvls->nHeads*chunkSize;
const int nranks = ncclShmem.comm.nRanks;
const bool hasUp = treeUp != -1;
const int reduceWarps = hasUp ? 5 : nranks <= 6 ? 7 : 5;
const int bcastWarps = hasUp ? 4 : 0;
const int scatterWarps = ((NCCL_MAX_NTHREADS/WARP_SIZE) - reduceWarps - bcastWarps + 1)/2;
const int gatherWarps = ((NCCL_MAX_NTHREADS/WARP_SIZE) - reduceWarps - bcastWarps)/2;
const int nThreadsScatter = scatterWarps*WARP_SIZE;
const int nThreadsGather = gatherWarps*WARP_SIZE;
const int nThreadsReduce = reduceWarps*WARP_SIZE;
const int nThreadsBcast = (bcastWarps)*WARP_SIZE;
const int tidEndScatter = nThreadsScatter;
const int tidEndGather = tidEndScatter + nThreadsGather;
const int tidEndReduce = tidEndGather + nThreadsReduce;
const int tidEndBcast = tidEndReduce + nThreadsBcast;
if (tid < tidEndScatter) {
// Scatter
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL,
args->redOpArg, 0*Proto::MaxGroupWidth, 1, 1);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*nvls->nHeads*chunkSize;
int nelem = min(nvls->nHeads*chunkSize, size-offset);
prims.scatter(offset, nelem, chunkSize, chunkSize, -1, 0);
}
} else if (tid < tidEndGather) {
// Gather
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
prims(tid-tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff,
args->redOpArg, 1*Proto::MaxGroupWidth, 1, 1);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*nvls->nHeads*chunkSize;
int nelem = min(nvls->nHeads*chunkSize, size-offset);
prims.gather(offset, nelem, chunkSize, chunkSize, -1, 0);
}
} else if (tid < tidEndReduce && nvls->headRank != -1) {
if (!hasUp) {
// Reduce and Broadcast
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 1>;
Primitives<T, RedOp, FanSymmetric<3>, /*Direct=*/0, Proto, 0>
prims(tid-tidEndGather, nThreadsReduce, treeDown, treeDown, NULL, NULL,
args->redOpArg, 2*Proto::MaxGroupWidth, 0, 0);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + (bid*nvls->nHeads+nvls->headRank)*chunkSize;
int nelem = min(chunkSize, size-offset);
prims.recvSend(nelem);
}
} else {
// Reduce, send to network
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
Primitives<T, RedOp, FanAsymmetric<3, 1>, /*Direct=*/0, Proto, 0>
prims(tid-tidEndGather, nThreadsReduce, treeDown, &treeUp, NULL, NULL,
args->redOpArg, 2*Proto::MaxGroupWidth, 0, 0);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + (bid*nvls->nHeads+nvls->headRank)*chunkSize;
int nelem = min(chunkSize, size-offset);
prims.recvSend(nelem);
}
}
} else if (tid < tidEndBcast && nvls->headRank != -1) {
// Recv from network, broadcast
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 0, 1>;
Primitives<T, RedOp, FanAsymmetric<1, 3>, /*Direct=*/0, Proto, 0>
prims(tid-tidEndReduce, nThreadsBcast, &treeUp, treeDown, NULL, NULL,
args->redOpArg, 3*Proto::MaxGroupWidth, 0, 0);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + (bid*nvls->nHeads+nvls->headRank)*chunkSize; ssize_t offset = gridOffset + (bid*nvls->nHeads+nvls->headRank)*chunkSize;
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);
@ -445,16 +569,20 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_CHAIN, NCCL
int nthreadsSplit = nthreads/2; int nthreadsSplit = nthreads/2;
if (nthreadsSplit >= 256) nthreadsSplit += 64; if (nthreadsSplit >= 256) nthreadsSplit += 64;
int group, send, recv, groupTid, groupNthreads; int group, connIndex, send, recv, groupTid, groupNthreads;
using Proto = ProtoSimple<1, 1>; using Proto = ProtoSimple<1, 1>;
if (tid < nthreadsSplit) { if (tid < nthreadsSplit) {
group = (0*Proto::MaxGroupWidth) | (1<<16); // Reduce up the chain
group = 0;
connIndex = 1;
recv = tree->down[0]; recv = tree->down[0];
send = tree->up; send = tree->up;
groupTid = tid; groupTid = tid;
groupNthreads = nthreadsSplit; groupNthreads = nthreadsSplit;
} else { } else {
group = (1*Proto::MaxGroupWidth); // Broadcast down the chain
group = 1;
connIndex = 0;
recv = tree->up; recv = tree->up;
send = tree->down[0]; send = tree->down[0];
groupTid = tid - nthreadsSplit; groupTid = tid - nthreadsSplit;
@ -462,7 +590,8 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_CHAIN, NCCL
} }
Primitives<T, RedOp, FanSymmetric<1>, /*Direct=*/1, Proto, 0> Primitives<T, RedOp, FanSymmetric<1>, /*Direct=*/1, Proto, 0>
prims(groupTid, groupNthreads, &recv, &send, args->sendbuff, args->recvbuff, args->redOpArg, group); prims(groupTid, groupNthreads, &recv, &send, args->sendbuff, args->recvbuff,
args->redOpArg, group*Proto::MaxGroupWidth, connIndex, connIndex);
if (tid < nthreadsSplit) { if (tid < nthreadsSplit) {
if (recv == -1) { if (recv == -1) {
@ -490,7 +619,7 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_CHAIN, NCCL
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*int(chunkSize); ssize_t offset = gridOffset + bid*int(chunkSize);
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);
prims.directRecvCopySend(offset, offset, nelem); prims.directRecvCopySend(offset, nelem);
} }
} }
} }

View File

@@ -22,7 +22,6 @@ struct ncclShmemGroup {
   ncclConnInfo *sendConns[NCCL_MAX_NVLS_ARITY];
   void* srcs[NCCL_MAX_NVLS_ARITY+1];
   void* dsts[NCCL_MAX_NVLS_ARITY+1];
-  int nvlsRecv;
 };
 struct ncclShmemData {
@@ -237,7 +236,8 @@ __device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)() { \
   IMPL_COLL4(func, RING, devredop, type, ncclType) \
   IMPL_COLL4(func, COLLNET_DIRECT, devredop, type, ncclType) \
   IMPL_COLL4(func, COLLNET_CHAIN, devredop, type, ncclType) \
-  IMPL_COLL4(func, NVLS, devredop, type, ncclType)
+  IMPL_COLL4(func, NVLS, devredop, type, ncclType) \
+  IMPL_COLL4(func, NVLS_TREE, devredop, type, ncclType)
 #if NCCL_TYPE == 0
 #define IMPL_COLL2(func, devredop) IMPL_COLL3(func, devredop, int8_t, ncclInt8)

View File

@ -26,7 +26,8 @@ inline __device__ int loadInt(int* ptr) {
} }
template<typename RedFn, typename T, int Unroll, int BytePerPack, template<typename RedFn, typename T, int Unroll, int BytePerPack,
int MinSrcs, int MaxSrcs, int MinDsts, int MaxDsts, int PreOpSrcs, int MultimemSrcs, int MinSrcs, int MaxSrcs,
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
typename IntBytes> typename IntBytes>
__device__ __forceinline__ void reduceCopyPacks( __device__ __forceinline__ void reduceCopyPacks(
int nThreads, int &thread, int nThreads, int &thread,
@ -35,6 +36,7 @@ __device__ __forceinline__ void reduceCopyPacks(
IntBytes &nBytesBehind, IntBytes &nBytesAhead IntBytes &nBytesBehind, IntBytes &nBytesAhead
) { ) {
static_assert(std::is_signed<IntBytes>::value, "IntBytes must be a signed integral type."); static_assert(std::is_signed<IntBytes>::value, "IntBytes must be a signed integral type.");
if (BytePerPack == 0) __trap();
// A hunk is the amount of contiguous data a warp consumes per loop iteration // A hunk is the amount of contiguous data a warp consumes per loop iteration
// assuming all threads partake. // assuming all threads partake.
@ -47,15 +49,15 @@ __device__ __forceinline__ void reduceCopyPacks(
IntBytes threadBytesBehind = nBytesBehind + (warp*BytePerHunk + lane*BytePerPack); IntBytes threadBytesBehind = nBytesBehind + (warp*BytePerHunk + lane*BytePerPack);
IntBytes threadBytesAhead = nBytesAhead - (warp*BytePerHunk + lane*BytePerPack); IntBytes threadBytesAhead = nBytesAhead - (warp*BytePerHunk + lane*BytePerPack);
// Number of hunks to be consumed over all warps. // Number of hunks to be consumed over all warps.
IntBytes nHunksAhead = nBytesAhead/BytePerHunk; IntBytes nHunksAhead = nBytesAhead/(BytePerHunk + !BytePerHunk);
// Advance collective position. // Advance collective position.
nBytesBehind += nHunksAhead*BytePerHunk; nBytesBehind += nHunksAhead*BytePerHunk;
nBytesAhead -= nHunksAhead*BytePerHunk; nBytesAhead -= nHunksAhead*BytePerHunk;
if (Unroll==1 && BytePerPack <= nBytesAhead) { if (Unroll==1 && BytePerPack <= nBytesAhead) {
// Only Unroll=1 can do partial hunks (where not all threads partake). // Only Unroll=1 can do partial hunks (where not all threads partake).
nHunksAhead += 1; nHunksAhead += 1;
nBytesBehind += nBytesAhead - (nBytesAhead%BytePerPack); nBytesBehind += nBytesAhead - (nBytesAhead%(BytePerPack + !BytePerPack));
nBytesAhead = nBytesAhead%BytePerPack; nBytesAhead = nBytesAhead%(BytePerPack + !BytePerPack);
} }
nHunksAhead -= warp; nHunksAhead -= warp;
@ -77,8 +79,13 @@ __device__ __forceinline__ void reduceCopyPacks(
{ RedFn preFn(0 < PreOpSrcs ? preOpArgs[0] : 0); { RedFn preFn(0 < PreOpSrcs ? preOpArgs[0] : 0);
#pragma unroll Unroll #pragma unroll Unroll
for (int u=0; u < Unroll; u++) { for (int u=0; u < Unroll; u++) {
// Use volatile loads in case credits are polled for with volatile (instead of acquire). if (0 < MultimemSrcs) {
acc[u] = ld_volatile_global<BytePerPack>(minSrcs[0]); // applyLoadMultimem uses relaxed semantics for same reason we use volatile below.
acc[u] = applyLoadMultimem<RedFn, BytePerPack>(preFn, minSrcs[0]);
} else {
// Use volatile loads in case credits are polled for with volatile (instead of acquire).
acc[u] = ld_volatile_global<BytePerPack>(minSrcs[0]);
}
minSrcs[0] += WARP_SIZE*BytePerPack; minSrcs[0] += WARP_SIZE*BytePerPack;
if (0 < PreOpSrcs) acc[u] = applyPreOp(preFn, acc[u]); if (0 < PreOpSrcs) acc[u] = applyPreOp(preFn, acc[u]);
} }
@ -90,8 +97,13 @@ __device__ __forceinline__ void reduceCopyPacks(
RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0); RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0);
#pragma unroll Unroll #pragma unroll Unroll
for (int u=0; u < Unroll; u++) { for (int u=0; u < Unroll; u++) {
// Use volatile loads in case credits are polled for with volatile (instead of acquire). if (s < MultimemSrcs) {
tmp[u] = ld_volatile_global<BytePerPack>(minSrcs[s]); // applyLoadMultimem uses relaxed semantics for same reason we use volatile below.
acc[u] = applyLoadMultimem<RedFn, BytePerPack>(preFn, minSrcs[s]);
} else {
// Use volatile loads in case credits are polled for with volatile (instead of acquire).
tmp[u] = ld_volatile_global<BytePerPack>(minSrcs[s]);
}
minSrcs[s] += WARP_SIZE*BytePerPack; minSrcs[s] += WARP_SIZE*BytePerPack;
} }
#pragma unroll Unroll #pragma unroll Unroll
@ -128,7 +140,11 @@ __device__ __forceinline__ void reduceCopyPacks(
for (int d=0; d < MinDsts; d++) { for (int d=0; d < MinDsts; d++) {
#pragma unroll Unroll #pragma unroll Unroll
for (int u=0; u < Unroll; u++) { for (int u=0; u < Unroll; u++) {
st_global<BytePerPack>(minDsts[d], acc[u]); if (d < MultimemDsts) {
multimem_st_global(minDsts[d], acc[u]);
} else {
st_global<BytePerPack>(minDsts[d], acc[u]);
}
minDsts[d] += WARP_SIZE*BytePerPack; minDsts[d] += WARP_SIZE*BytePerPack;
} }
} }
@ -165,213 +181,61 @@ __device__ __forceinline__ void reduceCopyPacks(
} }
template<int Unroll, typename RedFn, typename T, template<int Unroll, typename RedFn, typename T,
int MinSrcs, int MaxSrcs, int MinDsts, int MaxDsts, int PreOpSrcs, int MultimemSrcs, int MinSrcs, int MaxSrcs,
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
typename IntBytes> typename IntBytes>
__device__ __forceinline__ void ReduceOrCopyMulti( __device__ __forceinline__ void reduceCopy(
int thread, int nThreads, int thread, int nThreads,
uint64_t redArg, uint64_t *preOpArgs, bool postOp, uint64_t redArg, uint64_t *preOpArgs, bool postOp,
int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs, int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs,
IntBytes nElts IntBytes nElts
) { ) {
static_assert(MultimemSrcs <= MinSrcs && MultimemDsts <= MinDsts, "Multimem pointers cannot exceed respective Min values.");
//int nWarps = nThreads/WARP_SIZE; //int nWarps = nThreads/WARP_SIZE;
//int warp = thread/WARP_SIZE; //int warp = thread/WARP_SIZE;
int lane = thread%WARP_SIZE; int lane = thread%WARP_SIZE;
// If a multimem src is present then our biggest pack size is limited to what
// Check that all is 16B aligned. If not don't use 16B load/stores. // is supported for this redfn/type.
int aligned = 1; constexpr int BigPackSize = (MultimemSrcs == 0) ? 16 : LoadMultimem_BigPackSize<RedFn>::BigPackSize;
if (lane < nSrcs) aligned &= 0 == cvta_to_global(srcPtrs[lane])%16;
if (lane < nDsts) aligned &= 0 == cvta_to_global(dstPtrs[lane])%16;
aligned = __all_sync(~0u, aligned);
IntBytes nBytesBehind = 0; IntBytes nBytesBehind = 0;
IntBytes nBytesAhead = nElts*sizeof(T); IntBytes nBytesAhead = nElts*sizeof(T);
if (aligned) {
reduceCopyPacks<RedFn, T, Unroll, /*BytePerPack=*/16,
MinSrcs, MaxSrcs, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
if (nBytesAhead == 0) return;
reduceCopyPacks<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/16, #if __cpp_if_constexpr
MinSrcs, MaxSrcs, MinDsts, MaxDsts, PreOpSrcs> if constexpr (BigPackSize > sizeof(T)) {
(nThreads, /*&*/thread, redArg, preOpArgs, postOp, #else
nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); if (BigPackSize > sizeof(T)) {
if (nBytesAhead == 0) return; #endif
// Check that all pointers are BigPackSize aligned.
bool aligned = true;
if (lane < nSrcs) aligned &= 0 == cvta_to_global(srcPtrs[lane]) % (BigPackSize + !BigPackSize);
if (lane < nDsts) aligned &= 0 == cvta_to_global(dstPtrs[lane]) % (BigPackSize + !BigPackSize);
aligned = __all_sync(~0u, aligned);
if (aligned) {
reduceCopyPacks<RedFn, T, Unroll, BigPackSize,
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
if (nBytesAhead == 0) return;
reduceCopyPacks<RedFn, T, /*Unroll=*/1, BigPackSize,
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
if (nBytesAhead == 0) return;
}
} }
reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T), reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
MinSrcs, MaxSrcs, MinDsts, MaxDsts, PreOpSrcs> MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp, (nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
if (nBytesAhead == 0) return; if (nBytesAhead == 0) return;
reduceCopyPacks<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T), reduceCopyPacks<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),
MinSrcs, MaxSrcs, MinDsts, MaxDsts, PreOpSrcs> MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp, (nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
} }
// Copies from srcAddr to dstAddr using multimem load/store. The amount copied
// will be at most Unroll*BytePerPack*WARP_SIZE. If Partial=1, then the amount
// will be the min() of that and nBytesAhead. If srcAddr is not BytePerPack
// aligned then the amount copied will be less by (srcAddr%BytePerPack) since
// we begin loads at the first pack containing the first element.
template<typename RedFn, typename T, int Unroll, int BytePerPack,
bool SrcAligned, // is srcAddr aligned to BytePerPack
bool DstAligned, // are dstAddr and nBytesAhead both aligned to BytePerPack
bool Partial, // is this a possibly partial hunk
typename IntBytes>
__device__ __forceinline__ void copyMultimemMultimem_WarpUnrolled(
int lane, RedFn redFn, bool postOp, uintptr_t srcAddr, uintptr_t dstAddr,
IntBytes nBytesAhead, uint32_t scratchAddr
) {
int srcMisalign = SrcAligned ? 0 : srcAddr%BytePerPack;
srcAddr -= srcMisalign;
BytePack<BytePerPack> reg[Unroll];
int offset = lane*BytePerPack;
#pragma unroll Unroll
for (int u=0; u < Unroll; u++) {
if (!Partial || (offset < srcMisalign + nBytesAhead)) {
reg[u] = applyLoadMultimem(redFn, srcAddr+offset);
if (postOp) reg[u] = applyPostOp(redFn, reg[u]);
}
offset += WARP_SIZE*BytePerPack;
}
if (SrcAligned && DstAligned) {
offset = lane*BytePerPack;
#pragma unroll Unroll
for (int u=0; u < Unroll; u++) {
if (!Partial || offset < nBytesAhead) {
multimem_st_global<BytePerPack>(dstAddr+offset, reg[u]);
}
offset += WARP_SIZE*BytePerPack;
}
} else {
__syncwarp();
offset = lane*BytePerPack;
#pragma unroll Unroll
for (int u=0; u < Unroll; u++) {
if (!Partial || (offset < srcMisalign + nBytesAhead)) {
st_shared<BytePerPack>(scratchAddr+offset, reg[u]);
}
offset += WARP_SIZE*BytePerPack;
}
__syncwarp();
if (!SrcAligned) {
// Ignore the beginning of the first pack corresponding to bytes overread
// due to misalignment.
nBytesAhead = min(nBytesAhead, Unroll*WARP_SIZE*BytePerPack - srcMisalign);
}
copyGlobalShared_WarpUnrolled
<sizeof(T), /*MaxBytes=*/Unroll*WARP_SIZE*BytePerPack, /*Multimem=*/1>
(lane, dstAddr, scratchAddr+srcMisalign, nBytesAhead);
}
}
// copyMultimemMultimem_IfEnabled has two overloads: the enabled case whose first arg
// has type `std::true_type` and the disabled case with first arg `std::false_type`.
// This is to guard the template instantiations of Apply_LoadMultimem on types/ops where
// they aren't supported. A nicer approach is to use C++17's "if constexpr".
template<typename RedFn, typename IntBytes>
__device__ __forceinline__ void copyMultimemMultimem_IfEnabled(
std::false_type enabled/*=false*/,
int thread, int nThreads, uint64_t redArg, bool postOp,
void *srcPtr, void *dstPtr, IntBytes nElts, uint32_t warpScratchAddr
) {
// nop
}
template<typename RedFn, typename IntBytes>
__device__ __forceinline__ void copyMultimemMultimem_IfEnabled(
std::true_type enabled/*=true*/,
int thread, int nThreads, uint64_t redArg, bool postOp,
void *srcPtr, void *dstPtr, IntBytes nElts, uint32_t warpScratchAddr
) {
static_assert(std::is_signed<IntBytes>::value, "IntBytes must be a signed integral type.");
constexpr int BytePerPack = Apply_LoadMultimem<RedFn>::PackSize;
using T = typename RedFn::EltType;
constexpr int Unroll = ncclNvlsUnroll(BytePerPack);
constexpr int BytePerHunk = Unroll*WARP_SIZE*BytePerPack;
int nWarps = nThreads/WARP_SIZE;
int warp = thread/WARP_SIZE;
int lane = thread%WARP_SIZE;
RedFn redFn(redArg);
uintptr_t srcAddr = cvta_to_global(srcPtr);
uintptr_t dstAddr = cvta_to_global(dstPtr);
IntBytes warpBytesAhead = nElts*sizeof(T);
bool partialHunkIsFront;
// First handle misalignment of srcAddr.
if ((BytePerPack != sizeof(T)) && (srcAddr%BytePerPack != 0)) {
// If srcAddr isn't pack aligned then the first hunk processed will be short
// the same number of bytes as srcAddr's misalignment.
if (warp == 0) {
partialHunkIsFront = true;
goto PartialHunk; // "call" PartialHunk()
PartialHunkFrontReturn:
warp = nWarps;
}
warp -= 1; // Rotate warp numbers for load balancing
int advanced = BytePerHunk-(srcAddr%BytePerPack); // since copyMultimemMultimem_WarpUnrolled shorts by the misalignment
srcAddr += advanced; // srcAddr is now pack aligned
dstAddr += advanced;
warpBytesAhead -= advanced;
}
warpBytesAhead -= warp*BytePerHunk;
srcAddr += warp*BytePerHunk;
dstAddr += warp*BytePerHunk;
// Now that srcAddr is pack aligned detect if dstAddr is pack aligned.
if ((BytePerPack == sizeof(T)) || (dstAddr%BytePerPack == 0)) {
while (BytePerHunk <= warpBytesAhead) {
copyMultimemMultimem_WarpUnrolled
<RedFn, T, Unroll, BytePerPack, /*SrcAligned=*/true, /*DstAligned=*/true, /*Partial=*/false>
(lane, redFn, postOp, srcAddr, dstAddr, warpBytesAhead, warpScratchAddr);
srcAddr += nWarps*BytePerHunk;
dstAddr += nWarps*BytePerHunk;
warpBytesAhead -= nWarps*BytePerHunk;
}
} else {
while (BytePerHunk <= warpBytesAhead) {
copyMultimemMultimem_WarpUnrolled
<RedFn, T, Unroll, BytePerPack, /*SrcAligned=*/true, /*DstAligned=*/false, /*Partial=*/false>
(lane, redFn, postOp, srcAddr, dstAddr, warpBytesAhead, warpScratchAddr);
srcAddr += nWarps*BytePerHunk;
dstAddr += nWarps*BytePerHunk;
warpBytesAhead -= nWarps*BytePerHunk;
}
}
if (0 < warpBytesAhead) {
partialHunkIsFront = false;
goto PartialHunk; // "call" PartialHunk()
PartialHunkBackReturn:;
}
return;
PartialHunk:
// We have to handle a partial hunk possibly at the front and back of the
// buffer. We generate the code once here since its a lot of instructions,
// and then simulate function calls with gotos.
copyMultimemMultimem_WarpUnrolled
<RedFn, T, Unroll, BytePerPack, /*SrcAligned=*/false, /*DstAligned=*/false, /*Partial=*/true>
(lane, redFn, postOp, srcAddr, dstAddr, warpBytesAhead, warpScratchAddr);
if (partialHunkIsFront) goto PartialHunkFrontReturn;
goto PartialHunkBackReturn;
}
template<typename RedFn, typename IntBytes>
__device__ __forceinline__ void copyMultimemMultimem(
int thread, int nThreads, uint64_t redArg, bool postOp,
void *srcPtr, void *dstPtr, IntBytes nElts, uint32_t warpScratchAddr
) {
constexpr bool Enabled = Apply_LoadMultimem<RedFn>::PackSize != 0;
copyMultimemMultimem_IfEnabled<RedFn>(
/*enabled=*/std::integral_constant<bool, Enabled>(),
thread, nThreads, redArg, postOp, srcPtr, dstPtr, nElts, warpScratchAddr);
}
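// Illustrative sketch only (assumes a C++17 device compiler): the guard above could be
// expressed with "if constexpr" instead of the std::true_type/std::false_type overload
// pair. The body is elided and this sketch is not referenced anywhere.
template<typename RedFn, typename IntBytes>
__device__ __forceinline__ void copyMultimemMultimem_Cpp17Sketch(
  int thread, int nThreads, uint64_t redArg, bool postOp,
  void *srcPtr, void *dstPtr, IntBytes nElts, uint32_t warpScratchAddr
) {
  if constexpr (Apply_LoadMultimem<RedFn>::PackSize != 0) {
    // Same body as the std::true_type overload; the discarded branch below is never
    // instantiated, so unsupported type/op combinations still compile.
  } else {
    // nop, matching the std::false_type overload.
  }
}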
#endif // COMMON_KERNEL_H_ #endif // COMMON_KERNEL_H_

@ -23,7 +23,8 @@ __shared__ ncclShmemData ncclShmem;
NCCL_FUNC5(func, RING, devredop, type, nullify), \ NCCL_FUNC5(func, RING, devredop, type, nullify), \
NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \ NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \
NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \ NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \
NCCL_FUNC5(func, NVLS, devredop, type, nullify) NCCL_FUNC5(func, NVLS, devredop, type, nullify), \
NCCL_FUNC5(func, NVLS_TREE, devredop, type, nullify)
#if defined(__CUDA_BF16_TYPES_EXIST__) #if defined(__CUDA_BF16_TYPES_EXIST__)
// Must be consistent with ncclDataType_t // Must be consistent with ncclDataType_t

@ -37,7 +37,7 @@ namespace {
dst += i0; dst += i0;
void *vsrc = (void*)src; void *vsrc = (void*)src;
void *vdst = (void*)dst; void *vdst = (void*)dst;
ReduceOrCopyMulti<COLL_UNROLL, RedOp, T, 1, 1, 1, 1, /*PreOpSrcs=*/1> reduceCopy<COLL_UNROLL, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/1>
(tid, tn, we->redOpArg, &(we->redOpArg), true, 1, &vsrc, 1, &vdst, i1-i0); (tid, tn, we->redOpArg, &(we->redOpArg), true, 1, &vsrc, 1, &vdst, i1-i0);
} }
} }
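// Note on the new signature (inferred from the call sites in this patch): the old
// ReduceOrCopyMulti<Unroll, RedOp, T, MinSrcs, MaxSrcs, MinDsts, MaxDsts, PreOpSrcs>
// becomes reduceCopy<Unroll, RedOp, T, MultimemSrcs, MinSrcs, MaxSrcs,
//                    MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>,
// so a plain one-source/one-destination copy like the one above maps from
// <..., 1,1, 1,1, P> to <..., 0,1,1, 0,1,1, P> (no multimem on either side).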

@ -7,6 +7,8 @@
#ifndef OP128_H_ #ifndef OP128_H_
#define OP128_H_ #define OP128_H_
#include <type_traits>
inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) { inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) {
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];"
: "=l"(v0), "=l"(v1) : "l"(ptr)); : "=l"(v0), "=l"(v1) : "l"(ptr));
@ -94,6 +96,8 @@ __device__ __forceinline__ T* cvta_from_global(uintptr_t gptr) {
template<int Size> template<int Size>
union BytePack; union BytePack;
template<> template<>
union BytePack<0> {};
template<>
union BytePack<1> { union BytePack<1> {
uint8_t u8, native; uint8_t u8, native;
}; };
@ -129,14 +133,26 @@ union alignas(16) BytePack<16> {
}; };
template<typename T> template<typename T>
__device__ __forceinline__ BytePack<sizeof(T)> toPack(T value) { struct BytePackOf {
union { BytePack<sizeof(T)> p; T v; }; static constexpr int Size = sizeof(T);
using Pack = BytePack<Size>;
};
template<>
struct BytePackOf<BytePack<0>> {
static constexpr int Size = 0;
using Pack = BytePack<0>;
};
template<typename T>
__device__ __forceinline__ typename BytePackOf<T>::Pack toPack(T value) {
union { typename BytePackOf<T>::Pack p; T v; };
v = value; v = value;
return p; return p;
} }
template<typename T> template<typename T>
__device__ __forceinline__ T fromPack(BytePack<sizeof(T)> pack) { __device__ __forceinline__ T fromPack(typename BytePackOf<T>::Pack pack) {
union { BytePack<sizeof(T)> p; T v; }; union { typename BytePackOf<T>::Pack p; T v; };
p = pack; p = pack;
return v; return v;
} }
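// Minimal usage sketch of the pack helpers above (illustrative; not used anywhere):
__device__ inline float packRoundTripExample(float x) {
  // toPack reinterprets the bits of x as a BytePack<4>; fromPack restores them.
  BytePack<sizeof(float)> p = toPack(x);
  return fromPack<float>(p);  // bitwise-identical to x
}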
@ -151,6 +167,13 @@ template<int Size> __device__ BytePack<Size> ld_volatile_shared(uint32_t addr);
template<int Size> __device__ void st_global(uintptr_t addr, BytePack<Size> value); template<int Size> __device__ void st_global(uintptr_t addr, BytePack<Size> value);
template<int Size> __device__ void st_shared(uint32_t addr, BytePack<Size> value); template<int Size> __device__ void st_shared(uint32_t addr, BytePack<Size> value);
template<> __device__ __forceinline__ BytePack<0> ld_global<0>(uintptr_t addr) { return {}; }
template<> __device__ __forceinline__ BytePack<0> ld_volatile_global<0>(uintptr_t addr) { return {}; }
template<> __device__ __forceinline__ BytePack<0> ld_shared<0>(uint32_t addr) { return {}; }
template<> __device__ __forceinline__ BytePack<0> ld_volatile_shared<0>(uint32_t addr) { return {}; }
template<> __device__ __forceinline__ void st_global<0>(uintptr_t addr, BytePack<0> value) {}
template<> __device__ __forceinline__ void st_shared<0>(uint32_t addr, BytePack<0> value) {}
// Used to define implementations for above prototypes. // Used to define implementations for above prototypes.
#define DEFINE_ld_st(bytes, data_cxx_ty, data_ptx_ty, data_reg_ty, space, addr_cxx_ty, addr_reg_ty) \ #define DEFINE_ld_st(bytes, data_cxx_ty, data_ptx_ty, data_reg_ty, space, addr_cxx_ty, addr_reg_ty) \
template<> \ template<> \
@ -275,6 +298,18 @@ __device__ __forceinline__ void multimem_st_global(uintptr_t addr, BytePack<Size
#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 #if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
template<> template<>
__device__ __forceinline__ void multimem_st_global<0>(uintptr_t addr, BytePack<0> val) {
// nop
}
template<>
__device__ __forceinline__ void multimem_st_global<1>(uintptr_t addr, BytePack<1> val) {
asm volatile("st.global.b8 [%0], %1;" :: "l"(addr), "r"((uint32_t)val.u8) : "memory");
}
template<>
__device__ __forceinline__ void multimem_st_global<2>(uintptr_t addr, BytePack<2> val) {
asm volatile("st.global.b16 [%0], %1;" :: "l"(addr), "h"(val.u16) : "memory");
}
template<>
__device__ __forceinline__ void multimem_st_global<4>(uintptr_t addr, BytePack<4> val) { __device__ __forceinline__ void multimem_st_global<4>(uintptr_t addr, BytePack<4> val) {
asm volatile("multimem.st.global.b32 [%0], %1;" :: "l"(addr), "r"(val.u32) : "memory"); asm volatile("multimem.st.global.b32 [%0], %1;" :: "l"(addr), "r"(val.u32) : "memory");
} }
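// Illustrative sketch (assumes sm_90 and CUDA >= 12.1, where these specializations
// exist, and that mcAddr is a multicast-mapped global address obtained elsewhere):
__device__ inline void multimemStoreExample(uintptr_t mcAddr, float v) {
  // Broadcast-store the 32-bit value to every device mapped by the multicast address.
  multimem_st_global<4>(mcAddr, toPack(v));
}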

@ -21,13 +21,14 @@
* to how that protocol operates with a consistent interface so that our * to how that protocol operates with a consistent interface so that our
* algorithm code can operate protocol parametrically. * algorithm code can operate protocol parametrically.
*/ */
template<int SlicePerChunk_1, int StepPerSlice_1, int Unroll_1 = COLL_UNROLL, bool NVLS_1 = false> template<int SlicePerChunk_1, int StepPerSlice_1, int Unroll_1 = COLL_UNROLL, int MultimemSrcs_1 = 0, int MultimemDsts_1 = 0>
struct ProtoSimple { struct ProtoSimple {
static constexpr int Id = NCCL_PROTO_SIMPLE; static constexpr int Id = NCCL_PROTO_SIMPLE;
static constexpr int SlicePerChunk = SlicePerChunk_1; static constexpr int SlicePerChunk = SlicePerChunk_1;
static constexpr int StepPerSlice = StepPerSlice_1; static constexpr int StepPerSlice = StepPerSlice_1;
static constexpr int Unroll = Unroll_1; static constexpr int Unroll = Unroll_1;
static constexpr bool NVLS = NVLS_1; static constexpr int MultimemSrcs = MultimemSrcs_1;
static constexpr int MultimemDsts = MultimemDsts_1;
// Data bytes (no flags etc) in one step of the fifo queue. // Data bytes (no flags etc) in one step of the fifo queue.
__device__ static int calcBytePerStep() { __device__ static int calcBytePerStep() {
@ -39,9 +40,6 @@ struct ProtoSimple {
} }
// Group width is how many consecutive group values a subchannel occupies. // Group width is how many consecutive group values a subchannel occupies.
static constexpr int MaxGroupWidth = 2; static constexpr int MaxGroupWidth = 2;
__device__ static int calcGroupWidth(bool send, int nthreads) {
return send && nthreads-WARP_SIZE >= 64 ? 2 : 1;
}
}; };
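// Illustrative instantiation (placeholder values, not the ones the algorithms pick):
// a Simple protocol whose destination side is multicast (NVLS) memory would be
// selected as follows.
using ProtoSimpleMcDstSketch = ProtoSimple</*SlicePerChunk=*/1, /*StepPerSlice=*/1,
                                           COLL_UNROLL, /*MultimemSrcs=*/0, /*MultimemDsts=*/1>;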
struct ProtoLL { struct ProtoLL {
@ -57,9 +55,6 @@ struct ProtoLL {
} }
// Group width is how many consecutive group values a subchannel occupies. // Group width is how many consecutive group values a subchannel occupies.
static constexpr int MaxGroupWidth = 1; static constexpr int MaxGroupWidth = 1;
__device__ static int calcGroupWidth(bool send, int nthreads) {
return 1;
}
}; };
struct ProtoLL128 { struct ProtoLL128 {
@ -75,9 +70,6 @@ struct ProtoLL128 {
} }
// Group width is how many consecutive group values a subchannel occupies. // Group width is how many consecutive group values a subchannel occupies.
static constexpr int MaxGroupWidth = 1; static constexpr int MaxGroupWidth = 1;
__device__ static int calcGroupWidth(bool send, int nthreads) {
return 1;
}
}; };
/* Fan (as in fan-in & fan-out) classes hold recv and send counts. The template /* Fan (as in fan-in & fan-out) classes hold recv and send counts. The template
@ -117,22 +109,22 @@ class Primitives;
// Used by LL & LL128 to implement direct members in the naive way. // Used by LL & LL128 to implement direct members in the naive way.
template<typename RealPrimitives> template<typename RealPrimitives>
struct PrimitivesWithoutDirect { struct PrimitivesWithoutDirect {
__device__ void directSend(intptr_t inpIx, intptr_t remoteOutIx, int eltN) { __device__ void directSend(intptr_t inpIx, intptr_t outIx, int eltN) {
static_cast<RealPrimitives*>(this)->send(inpIx, eltN); static_cast<RealPrimitives*>(this)->send(inpIx, eltN);
} }
__device__ void directSendFromOutput(intptr_t outIx, intptr_t remoteOutIx, int eltN) { __device__ void directSendFromOutput(intptr_t outIx, int eltN) {
static_cast<RealPrimitives*>(this)->sendFromOutput(outIx, eltN); static_cast<RealPrimitives*>(this)->sendFromOutput(outIx, eltN);
} }
__device__ void directRecv(intptr_t outIx, int eltN) { __device__ void directRecv(intptr_t outIx, int eltN) {
static_cast<RealPrimitives*>(this)->recv(outIx, eltN, /*postOp=*/false); static_cast<RealPrimitives*>(this)->recv(outIx, eltN, /*postOp=*/false);
} }
__device__ void directCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) { __device__ void directCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
static_cast<RealPrimitives*>(this)->copySend(inpIx, outIx, eltN, postOp); static_cast<RealPrimitives*>(this)->copySend(inpIx, outIx, eltN, postOp);
} }
__device__ void directRecvCopySend(intptr_t outIx, intptr_t remoteOutIx, int eltN) { __device__ void directRecvCopySend(intptr_t outIx, int eltN) {
static_cast<RealPrimitives*>(this)->recvCopySend(outIx, eltN, /*postOp=*/false); static_cast<RealPrimitives*>(this)->recvCopySend(outIx, eltN, /*postOp=*/false);
} }
__device__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) { __device__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
// Direct is only for the send part // Direct is only for the send part
static_cast<RealPrimitives*>(this)->recvReduceCopySend(inpIx, outIx, eltN, postOp); static_cast<RealPrimitives*>(this)->recvReduceCopySend(inpIx, outIx, eltN, postOp);
} }

@ -322,22 +322,22 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p>:
public: public:
__device__ Primitives( __device__ Primitives(
const int tid, const int nthreads, int const *recvPeers, int const *sendPeers, const int tid, const int nthreads, int const *recvPeers, int const *sendPeers,
void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0 void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint8_t group=0,
uint8_t connIndexRecv=0, uint8_t connIndexSend=0
): ):
redOp(redOpArg), redOp(redOpArg),
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group&(uint16_t)0xFFFF), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group),
stepLines(ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)) { stepLines(ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)) {
int connIndex = group >> 16;
auto *channel = &ncclShmem.channel; auto *channel = &ncclShmem.channel;
// If we are going to support oneshot collNet + LL, then we would need to add connector index here // If we are going to support oneshot collNet + LL, then we would need to add connector index here
int nrecv=0, nsend=0; int nrecv=0, nsend=0;
// We compare with Fan::MaxRecv here because this->MaxRecv is always at least 1 // We compare with Fan::MaxRecv here because this->MaxRecv is always at least 1
while (nrecv < Fan::MaxRecv && recvPeers[nrecv] >= 0) { while (nrecv < Fan::MaxRecv && recvPeers[nrecv] >= 0) {
loadRecvConn(&channel->peers[recvPeers[nrecv]].recv[connIndex], nrecv); loadRecvConn(&channel->peers[recvPeers[nrecv]]->recv[connIndexRecv], nrecv);
nrecv++; nrecv++;
} }
while (nsend < MaxSend && sendPeers[nsend] >= 0) { while (nsend < MaxSend && sendPeers[nsend] >= 0) {
loadSendConn(&channel->peers[sendPeers[nsend]].send[connIndex], nsend); loadSendConn(&channel->peers[sendPeers[nsend]]->send[connIndexSend], nsend);
nsend++; nsend++;
} }
this->fan = Fan(nrecv, nsend); this->fan = Fan(nrecv, nsend);

@ -363,22 +363,22 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
public: public:
__device__ Primitives( __device__ Primitives(
const int tid, const int nthreads, int const *recvPeers, int const *sendPeers, const int tid, const int nthreads, int const *recvPeers, int const *sendPeers,
void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0 void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint8_t group=0,
uint8_t connIndexRecv=0, uint8_t connIndexSend=0
): ):
redOp(redOpArg), redOp(redOpArg),
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE),
warpInBlock(threadIdx.x/WARP_SIZE), warpInBlock(threadIdx.x/WARP_SIZE),
flagThread((tid%8)==7), group(group&(uint16_t)0xFFFF), flagThread((tid%8)==7), group(group),
stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) { stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) {
int connIndex = group >> 16;
auto *channel = &ncclShmem.channel; auto *channel = &ncclShmem.channel;
int nrecv=0, nsend=0; int nrecv=0, nsend=0;
while (nrecv < MaxRecv && recvPeers[nrecv] >= 0) { while (nrecv < MaxRecv && recvPeers[nrecv] >= 0) {
loadRecvConn(&channel->peers[recvPeers[nrecv]].recv[connIndex], nrecv); loadRecvConn(&channel->peers[recvPeers[nrecv]]->recv[connIndexRecv], nrecv);
nrecv++; nrecv++;
} }
while (nsend < MaxSend && sendPeers[nsend] >= 0) { while (nsend < MaxSend && sendPeers[nsend] >= 0) {
loadSendConn(&channel->peers[sendPeers[nsend]].send[connIndex], nsend); loadSendConn(&channel->peers[sendPeers[nsend]]->send[connIndexSend], nsend);
nsend++; nsend++;
} }
this->fan = Fan(nrecv, nsend); this->fan = Fan(nrecv, nsend);

@ -5,9 +5,9 @@
************************************************************************/ ************************************************************************/
template<typename T, typename RedOp, typename Fan, int Direct, template<typename T, typename RedOp, typename Fan, int Direct,
int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, bool NVLS> int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts>
class Primitives< class Primitives<
T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, Unroll, NVLS>, P2p T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, Unroll, MultimemSrcs, MultimemDsts>, P2p
> { > {
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
static constexpr int Input=0, Output=1; static constexpr int Input=0, Output=1;
@ -23,10 +23,9 @@ class Primitives<
DirectWrite = 0x200, DirectWrite = 0x200,
DirectRead = 0x400, DirectRead = 0x400,
ThreadsSynced = 0x800, ThreadsSynced = 0x800,
NvlsMinPolling = 0x1000, NvlsMinPolling = 0x1000;
NvlsRecv = 0x2000;
const int tid, tidInBlock; const int tid, tidInBlock;
int nthreads; const int nthreads;
int nworkers; int nworkers;
const int stepSize; const int stepSize;
Fan fan; Fan fan;
@ -107,19 +106,19 @@ class Primitives<
inline __device__ uint64_t loadStepValue(uint64_t* ptr) { inline __device__ uint64_t loadStepValue(uint64_t* ptr) {
#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 #if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
if (NVLS && (flags & NvlsMinPolling)) { if (flags & NvlsMinPolling) {
uint64_t ans; uint64_t ans;
asm("multimem.ld_reduce.acquire.sys.global.min.u64 %0, [%1];" : "=l"(ans) : "l"(cvta_to_global(ptr))); asm("multimem.ld_reduce.acquire.sys.global.min.u64 %0, [%1];" : "=l"(ans) : "l"(cvta_to_global(ptr)));
return ans; return ans;
} }
#endif #endif
// volatile is faster than acquire but not as correct. Make sure ReduceOrCopyMulti // volatile is faster than acquire but not as correct. Make sure reduceCopy
// loads data using volatile so it doesn't see stale data in L1. // loads data using volatile so it doesn't see stale data in L1.
return ld_volatile_global(ptr); return ld_volatile_global(ptr);
} }
template <int DirectRecv, int DirectSend, int Recv, int Send, int Src, int Dst> template <int DirectRecv, int DirectSend, int Recv, int Send, int Src, int Dst>
__device__ __forceinline__ void waitPeer(intptr_t dstIx, intptr_t remoteIx, int offset, int nelts) { __device__ __forceinline__ void waitPeer(intptr_t srcIx, intptr_t dstIx, int offset, int nelts) {
const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send; const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
const bool noRecvWait = DirectRecv && Src && (flags & DirectRead); // no wait when directly reading from remote input const bool noRecvWait = DirectRecv && Src && (flags & DirectRead); // no wait when directly reading from remote input
const bool noSendWait = DirectSend && (flags & (DirectRead|DirectWrite)); // no wait in empty send (e.g. directScatter) or direct remote write const bool noSendWait = DirectSend && (flags & (DirectRead|DirectWrite)); // no wait in empty send (e.g. directScatter) or direct remote write
@ -143,7 +142,7 @@ class Primitives<
ptrs[index] = connEltsFifo + loadInt(connOffsFifoPtr + (step%NCCL_STEPS))/sizeof(T); ptrs[index] = connEltsFifo + loadInt(connOffsFifoPtr + (step%NCCL_STEPS))/sizeof(T);
else if (isSendNotRecv && DirectSend) { else if (isSendNotRecv && DirectSend) {
if (flags & DirectWrite) { if (flags & DirectWrite) {
ptrs[index] = directBuff + remoteIx + offset; ptrs[index] = directBuff + dstIx + offset;
} else if (flags & DirectRead) { // empty send } else if (flags & DirectRead) { // empty send
ptrs[index] = nullptr; ptrs[index] = nullptr;
} else { } else {
@ -151,7 +150,7 @@ class Primitives<
} }
} else if (!isSendNotRecv && DirectRecv) { } else if (!isSendNotRecv && DirectRecv) {
if (flags & DirectRead) { if (flags & DirectRead) {
ptrs[index] = directBuff + remoteIx + offset; ptrs[index] = directBuff + srcIx + offset;
} else if (flags & DirectWrite) { } else if (flags & DirectWrite) {
ptrs[index] = directBuff + dstIx + offset; // send to next from my output buffer ptrs[index] = directBuff + dstIx + offset; // send to next from my output buffer
} else { } else {
@ -176,7 +175,7 @@ class Primitives<
template <int DirectRecv1, int DirectSend1, int Recv, int Send, int SrcBuf, int DstBuf> template <int DirectRecv1, int DirectSend1, int Recv, int Send, int SrcBuf, int DstBuf>
__device__ __forceinline__ void genericOp( __device__ __forceinline__ void genericOp(
intptr_t srcIx, intptr_t dstIx, intptr_t remoteIx, int nelem, bool postOp intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp
) { ) {
constexpr int DirectRecv = 1 && Direct && DirectRecv1; constexpr int DirectRecv = 1 && Direct && DirectRecv1;
constexpr int DirectSend = 1 && Direct && DirectSend1; constexpr int DirectSend = 1 && Direct && DirectSend1;
@ -225,20 +224,15 @@ class Primitives<
ncclShmem.groups[group].srcs[0] = userBuff + srcIx + offset; ncclShmem.groups[group].srcs[0] = userBuff + srcIx + offset;
if (Dst && (flags & (DstBuf==Input ? RoleInput : RoleOutput))) if (Dst && (flags & (DstBuf==Input ? RoleInput : RoleOutput)))
ncclShmem.groups[group].dsts[0] = userBuff + dstIx + offset; ncclShmem.groups[group].dsts[0] = userBuff + dstIx + offset;
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(dstIx, remoteIx, offset, sliceSize); waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(srcIx, dstIx, offset, sliceSize);
subBarrier(); subBarrier();
/* if user abort the kernel, we don't need to actually perform copy/reduce; just set size /* if user abort the kernel, we don't need to actually perform copy/reduce; just set size
* to 0 to avoid unnecessary workload. */ * to 0 to avoid unnecessary workload. */
int workSize = ncclShmem.aborted ? 0 : sliceSize; int workSize = ncclShmem.aborted ? 0 : sliceSize;
if (NVLS && ncclShmem.groups[group].nvlsRecv) { if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) {
void* src = ncclShmem.groups[group].srcs[0];
void* dst = ncclShmem.groups[group].dsts[0];
copyMultimemMultimem<RedOp>(tid, nworkers, ncclShmem.redOpArgs[0], postOp, src, dst, workSize,
cvta_to_shared(ncclScratchForWarp(tidInBlock/WARP_SIZE)));
} else if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) {
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
if (Send) { if (Send) {
ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, MaxSend, /*PreOpSrcs*/0> reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, MaxSend, /*PreOpSrcs*/0>
(tid, nworkers, /*redArg*/0, /*preOpArgs*/nullptr, /*postOp*/false, (tid, nworkers, /*redArg*/0, /*preOpArgs*/nullptr, /*postOp*/false,
1, ncclShmem.groups[group].srcs, 1, ncclShmem.groups[group].srcs,
fan.nsend(), ncclShmem.groups[group].dsts+1, fan.nsend(), ncclShmem.groups[group].dsts+1,
@ -246,7 +240,7 @@ class Primitives<
} }
} else if (DirectSend && !DirectRecv && SrcBuf != Input && ncclShmem.groups[group].dsts[Dst] == nullptr) { } else if (DirectSend && !DirectRecv && SrcBuf != Input && ncclShmem.groups[group].dsts[Dst] == nullptr) {
// For broadcast in CollNet to do empty send // For broadcast in CollNet to do empty send
ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1, /*PreOpSrcs*/0> reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs*/0>
(tid, nworkers, ncclShmem.redOpArgs[0], nullptr, postOp, (tid, nworkers, ncclShmem.redOpArgs[0], nullptr, postOp,
Recv, ncclShmem.groups[group].srcs, Recv, ncclShmem.groups[group].srcs,
Dst, ncclShmem.groups[group].dsts, Dst, ncclShmem.groups[group].dsts,
@ -254,7 +248,9 @@ class Primitives<
} else { } else {
constexpr int PreOpSrcs = SrcBuf != Input ? 0 : constexpr int PreOpSrcs = SrcBuf != Input ? 0 :
DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1; DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1;
ReduceOrCopyMulti<Unroll, RedOp, T, Recv+Src, Recv*MaxRecv+Src, Send+Dst, Send*MaxSend+Dst, PreOpSrcs> reduceCopy<Unroll, RedOp, T,
MultimemSrcs, Recv+Src, Recv*MaxRecv+Src,
MultimemDsts, Send+Dst, Send*MaxSend+Dst, PreOpSrcs>
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
Recv*fan.nrecv()+Src, ncclShmem.groups[group].srcs, Recv*fan.nrecv()+Src, ncclShmem.groups[group].srcs,
Send*fan.nsend()+Dst, ncclShmem.groups[group].dsts, Send*fan.nsend()+Dst, ncclShmem.groups[group].dsts,
@ -319,7 +315,7 @@ class Primitives<
void* src0 = (T*)ncclShmem.groups[group].srcs[0] + pOffset; void* src0 = (T*)ncclShmem.groups[group].srcs[0] + pOffset;
int realPeerSize = min(realSize, totalElem-pOffset); int realPeerSize = min(realSize, totalElem-pOffset);
if (realPeerSize > 0 && ncclShmem.groups[group].dsts[i] != nullptr) { if (realPeerSize > 0 && ncclShmem.groups[group].dsts[i] != nullptr) {
ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1, PreOpSrcs>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize); reduceCopy<Unroll, RedOp, T, 0,1,1, 0,1,1, PreOpSrcs>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize);
// Mark for threadfence at the end // Mark for threadfence at the end
fenceNeeded |= true; fenceNeeded |= true;
} }
@ -342,7 +338,7 @@ class Primitives<
if (skip >= 0 && i >= skip) pOffset += peerElem; if (skip >= 0 && i >= skip) pOffset += peerElem;
void* dst0 = (T*)ncclShmem.groups[group].dsts[0] + pOffset; void* dst0 = (T*)ncclShmem.groups[group].dsts[0] + pOffset;
int realPeerSize = min(realSize, totalElem-pOffset); int realPeerSize = min(realSize, totalElem-pOffset);
if (realPeerSize > 0) ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1, /*PreOpSrcs=*/0>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize); if (realPeerSize > 0) reduceCopy<Unroll, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize);
} }
} }
} }
@ -364,14 +360,7 @@ class Primitives<
} }
if (flags & RoleWaitRecv) { if (flags & RoleWaitRecv) {
ncclShmem.groups[group].recvConns[index] = conn; // WaitRecv role saves since that's who needs it in setDataPtrs() ncclShmem.groups[group].recvConns[index] = conn; // WaitRecv role saves since that's who needs it in setDataPtrs()
if ((index == 0) && (flags & RoleWaitRecv)) { flags |= (conn->flags & NCCL_NVLS_MIN_POLL) ? NvlsMinPolling : 0;
if (conn->flags & NCCL_NVLS_MIN_POLL) {
flags |= NvlsMinPolling;
ncclShmem.groups[group].nvlsRecv = 1;
} else {
ncclShmem.groups[group].nvlsRecv = 0;
}
}
connStepPtr = conn->tail; connStepPtr = conn->tail;
connStepCache = loadStepValue(connStepPtr); connStepCache = loadStepValue(connStepPtr);
flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0; flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0;
@ -448,16 +437,14 @@ class Primitives<
public: public:
__device__ Primitives( __device__ Primitives(
int tid, int nthreads, int const *recvPeers, int const *sendPeers, int tid, int nthreads, int const *recvPeers, int const *sendPeers,
void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint32_t group=0, struct ncclWorkElem* e = nullptr void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint8_t group=0,
uint8_t connIndexRecv = 0, uint8_t connIndexSend = 0, struct ncclWorkElem* e = nullptr
): ):
tid(tid), tidInBlock(threadIdx.x), tid(tid), nthreads(nthreads), tidInBlock(threadIdx.x), group(group),
stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)) { stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)) {
// For send operations, we need an extra warp to overlap the threadfence and the copy // For send operations, we need an extra warp to overlap the threadfence and the copy
this->nthreads = nthreads;
this->nworkers = nthreads - (MaxSend > 0 && nthreads-WARP_SIZE >= 64 ? WARP_SIZE : 0); this->nworkers = nthreads - (MaxSend > 0 && nthreads-WARP_SIZE >= 64 ? WARP_SIZE : 0);
this->group = group & (uint16_t)0xFFFF;
int connIndex = group >> 16;
int nrecv=0, nsend=0; int nrecv=0, nsend=0;
while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++; while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++;
@ -487,8 +474,8 @@ class Primitives<
if (flags & (RoleWaitRecv|RolePostRecv)) peer = recvPeers[index]; if (flags & (RoleWaitRecv|RolePostRecv)) peer = recvPeers[index];
if (flags & (RoleWaitSend|RolePostSend)) peer = sendPeers[index]; if (flags & (RoleWaitSend|RolePostSend)) peer = sendPeers[index];
loadRecvConn(&ncclShmem.channel.peers[peer], connIndex, e); loadRecvConn(ncclShmem.channel.peers[peer], connIndexRecv, e);
loadSendConn(&ncclShmem.channel.peers[peer], connIndex, e); loadSendConn(ncclShmem.channel.peers[peer], connIndexSend, e);
setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclWorkElemReg*)e); setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclWorkElemReg*)e);
} }
@ -593,62 +580,62 @@ class Primitives<
} }
__device__ __forceinline__ void send(intptr_t inpIx, int eltN) { __device__ __forceinline__ void send(intptr_t inpIx, int eltN) {
genericOp<0, 0, 0, 1, Input, -1>(inpIx, -1, -1, eltN, false); genericOp<0, 0, 0, 1, Input, -1>(inpIx, -1, eltN, false);
} }
__device__ __forceinline__ void sendFromOutput(intptr_t outIx, int eltN) { __device__ __forceinline__ void sendFromOutput(intptr_t outIx, int eltN) {
genericOp<0, 0, 0, 1, Output, -1>(outIx, -1, -1, eltN, false); genericOp<0, 0, 0, 1, Output, -1>(outIx, -1, eltN, false);
} }
__device__ __forceinline__ void directSend(intptr_t inpIx, intptr_t remoteOutIx, int eltN) { __device__ __forceinline__ void directSend(intptr_t inpIx, intptr_t outIx, int eltN) {
genericOp<0, 1, 0, 1, Input, -1>(inpIx, -1, remoteOutIx, eltN, false); genericOp<0, 1, 0, 1, Input, -1>(inpIx, outIx, eltN, false);
} }
__device__ __forceinline__ void directSendFromOutput(intptr_t outIx, intptr_t remoteOutIx, int eltN) { __device__ __forceinline__ void directSendFromOutput(intptr_t outIx, int eltN) {
genericOp<0, 1, 0, 1, Output, -1>(outIx, -1, remoteOutIx, eltN, false); genericOp<0, 1, 0, 1, Output, -1>(outIx, outIx, eltN, false);
} }
__device__ __forceinline__ void recv(intptr_t outIx, int eltN, bool postOp=false) { __device__ __forceinline__ void recv(intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 0, 1, 0, -1, Output>(-1, outIx, -1, eltN, postOp); genericOp<0, 0, 1, 0, -1, Output>(-1, outIx, eltN, postOp);
} }
__device__ __forceinline__ void directRecv(intptr_t outIx, int eltN) { __device__ __forceinline__ void directRecv(intptr_t outIx, int eltN) {
genericOp<1, 0, 1, 0, -1, Output>(-1, outIx, -1, eltN, /*postOp=*/false); genericOp<1, 0, 1, 0, -1, Output>(-1, outIx, eltN, /*postOp=*/false);
} }
__device__ __forceinline__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { __device__ __forceinline__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 0, 0, 1, Input, Output>(inpIx, outIx, -1, eltN, postOp); genericOp<0, 0, 0, 1, Input, Output>(inpIx, outIx, eltN, postOp);
} }
__device__ __forceinline__ void directCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) { __device__ __forceinline__ void directCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 1, 0, 1, Input, Output>(inpIx, outIx, remoteOutIx, eltN, postOp); genericOp<0, 1, 0, 1, Input, Output>(inpIx, outIx, eltN, postOp);
} }
__device__ __forceinline__ void recvSend(int eltN, bool postOp=false) { __device__ __forceinline__ void recvSend(int eltN, bool postOp=false) {
genericOp<0, 0, 1, 1, -1, -1>(-1, -1, -1, eltN, postOp); genericOp<0, 0, 1, 1, -1, -1>(-1, -1, eltN, postOp);
} }
__device__ __forceinline__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) { __device__ __forceinline__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 0, 1, 1, -1, Output>(-1, outIx, -1, eltN, postOp); genericOp<0, 0, 1, 1, -1, Output>(-1, outIx, eltN, postOp);
} }
__device__ __forceinline__ void directRecvCopySend(intptr_t outIx, intptr_t remoteOutIx, int eltN) { __device__ __forceinline__ void directRecvCopySend(intptr_t outIx, int eltN) {
genericOp<1, 1, 1, 1, -1, Output>(-1, outIx, remoteOutIx, eltN, false); genericOp<1, 1, 1, 1, -1, Output>(-1, outIx, eltN, false);
} }
__device__ __forceinline__ void recvCopyDirectSend(intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) { __device__ __forceinline__ void recvCopyDirectSend(intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 1, 1, 1, -1, Output>(-1, outIx, remoteOutIx, eltN, postOp); genericOp<0, 1, 1, 1, -1, Output>(-1, outIx, eltN, postOp);
} }
__device__ __forceinline__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { __device__ __forceinline__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 0, 1, 0, Input, Output>(inpIx, outIx, -1, eltN, postOp); genericOp<0, 0, 1, 0, Input, Output>(inpIx, outIx, eltN, postOp);
} }
__device__ __forceinline__ void recvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) { __device__ __forceinline__ void recvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) {
genericOp<0, 0, 1, 1, Input, -1>(inpIx, -1, -1, eltN, postOp); genericOp<0, 0, 1, 1, Input, -1>(inpIx, -1, eltN, postOp);
} }
__device__ __forceinline__ void directRecvReduceSend(intptr_t inpIx, intptr_t remoteInpIx, int eltN, bool postOp=false) { __device__ __forceinline__ void directRecvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) {
genericOp<1, 0, 1, 1, Input, -1>(inpIx, -1, remoteInpIx, eltN, postOp); genericOp<1, 0, 1, 1, Input, -1>(inpIx, -1, eltN, postOp);
} }
__device__ __forceinline__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { __device__ __forceinline__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 0, 1, 1, Input, Output>(inpIx, outIx, -1, eltN, postOp); genericOp<0, 0, 1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
} }
__device__ __forceinline__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) { __device__ __forceinline__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
// Direct is only for the send part // Direct is only for the send part
genericOp<0, 1, 1, 1, Input, Output>(inpIx, outIx, remoteOutIx, eltN, postOp); genericOp<0, 1, 1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
} }
__device__ __forceinline__ void __device__ __forceinline__ void

@ -55,9 +55,14 @@ struct Apply_PostOp/*{
static BytePack<EltPerPack*sizeof(T)> postOp(Fn fn, BytePack<EltPerPack*sizeof(T)> a); static BytePack<EltPerPack*sizeof(T)> postOp(Fn fn, BytePack<EltPerPack*sizeof(T)> a);
}*/; }*/;
template<typename Fn> template<typename Fn>
struct LoadMultimem_BigPackSize/*{
// If non-zero, then this and sizeof(T) are valid pack sizes for LoadMultimem,
// otherwise there are no valid pack sizes for LoadMultimem.
static constexpr int BigPackSize = 0;
}*/;
template<typename Fn, int BytePerPack>
struct Apply_LoadMultimem/*{ struct Apply_LoadMultimem/*{
static constexpr int PackSize; // 0 if not implemented static BytePack<BytePerPack> load(Fn fn, uintptr_t addr);
static BytePack<PackSize> load(Fn fn, uintptr_t addr);
}*/; }*/;
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -69,7 +74,7 @@ struct Apply_LoadMultimem/*{
template<typename Fn, typename Pack> template<typename Fn, typename Pack>
__device__ __forceinline__ Pack applyReduce(Fn fn, Pack a, Pack b) { __device__ __forceinline__ Pack applyReduce(Fn fn, Pack a, Pack b) {
return fromPack<Pack>( return fromPack<Pack>(
Apply_Reduce<Fn, sizeof(Pack)/sizeof(typename Fn::EltType)> Apply_Reduce<Fn, BytePackOf<Pack>::Size/sizeof(typename Fn::EltType)>
::reduce(fn, toPack(a), toPack(b)) ::reduce(fn, toPack(a), toPack(b))
); );
} }
@ -77,7 +82,7 @@ __device__ __forceinline__ Pack applyReduce(Fn fn, Pack a, Pack b) {
template<typename Fn, typename Pack> template<typename Fn, typename Pack>
__device__ __forceinline__ Pack applyPreOp(Fn fn, Pack a) { __device__ __forceinline__ Pack applyPreOp(Fn fn, Pack a) {
return fromPack<Pack>( return fromPack<Pack>(
Apply_PreOp<Fn, sizeof(Pack)/sizeof(typename Fn::EltType)> Apply_PreOp<Fn, BytePackOf<Pack>::Size/sizeof(typename Fn::EltType)>
::preOp(fn, toPack(a)) ::preOp(fn, toPack(a))
); );
} }
@ -85,19 +90,27 @@ __device__ __forceinline__ Pack applyPreOp(Fn fn, Pack a) {
template<typename Fn, typename Pack> template<typename Fn, typename Pack>
__device__ __forceinline__ Pack applyPostOp(Fn fn, Pack a) { __device__ __forceinline__ Pack applyPostOp(Fn fn, Pack a) {
return fromPack<Pack>( return fromPack<Pack>(
Apply_PostOp<Fn, sizeof(Pack)/sizeof(typename Fn::EltType)> Apply_PostOp<Fn, BytePackOf<Pack>::Size/sizeof(typename Fn::EltType)>
::postOp(fn, toPack(a)) ::postOp(fn, toPack(a))
); );
} }
template<typename Fn> template<typename Fn, int BytePerPack>
__device__ __forceinline__ BytePack<Apply_LoadMultimem<Fn>::PackSize> applyLoadMultimem(Fn fn, uintptr_t addr) { __device__ __forceinline__ BytePack<BytePerPack> applyLoadMultimem(Fn fn, uintptr_t addr) {
return Apply_LoadMultimem<Fn>::load(fn, addr); return Apply_LoadMultimem<Fn, BytePerPack>::load(fn, addr);
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Apply_Reduce // Apply_Reduce
// Nonsensical base case
template<typename Fn>
struct Apply_Reduce<Fn, /*EltPerPack=*/0> {
__device__ static BytePack<0> reduce(Fn fn, BytePack<0> a, BytePack<0> b) {
return {};
}
};
// General recursive definition (EltPerPack > 1). This is how we iterate over // General recursive definition (EltPerPack > 1). This is how we iterate over
// all elements in a pack of any size, by breaking it into halves. Eventually // all elements in a pack of any size, by breaking it into halves. Eventually
// we'll hit a base case (a more specific template specialization which takes // we'll hit a base case (a more specific template specialization which takes
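// A simplified sketch of that halving recursion, consistent with the BytePack "half"
// sub-packs (illustrative only; the full general definition is outside this hunk):
//   template<typename Fn, int EltPerPack>
//   struct Apply_Reduce {
//     template<int Size>
//     __device__ static BytePack<Size> reduce(Fn fn, BytePack<Size> a, BytePack<Size> b) {
//       a.half[0] = Apply_Reduce<Fn, EltPerPack/2>::reduce(fn, a.half[0], b.half[0]);
//       a.half[1] = Apply_Reduce<Fn, EltPerPack/2>::reduce(fn, a.half[1], b.half[1]);
//       return a;
//     }
//   };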
@ -282,6 +295,14 @@ struct Apply_PreOp<Fn, /*EltPerPack=*/1> {
return a; return a;
} }
}; };
// Base case definition (EltPerPack == 0) is nonsense!
template<typename Fn>
struct Apply_PreOp<Fn, /*EltPerPack=*/0> {
static constexpr bool IsIdentity = true;
__device__ static BytePack<0> preOp(Fn fn, BytePack<0> a) {
return {};
}
};
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Apply_PostOp // Apply_PostOp
@ -315,6 +336,14 @@ struct Apply_PostOp<Fn, /*EltPerPack=*/1> {
return a; return a;
} }
}; };
// Base case definition (EltPerPack == 0) is nonsense!
template<typename Fn>
struct Apply_PostOp<Fn, /*EltPerPack=*/0> {
static constexpr bool IsIdentity = true;
__device__ static BytePack<0> postOp(Fn fn, BytePack<0> a) {
return {};
}
};
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -505,11 +534,6 @@ struct Apply_PostOp<FuncSumPostDiv<T>, /*EltPerPack=*/1> {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Apply_LoadMultimem // Apply_LoadMultimem
template<typename Fn>
struct Apply_LoadMultimem {
static constexpr int PackSize = 0; // Indicates not implemented
};
#define SIZEOF_BytePack_field_u16 2 #define SIZEOF_BytePack_field_u16 2
#define PTX_REG_BytePack_field_u16 "h" #define PTX_REG_BytePack_field_u16 "h"
@ -521,11 +545,11 @@ struct Apply_LoadMultimem {
#define DEFINE_Apply_LoadMultimem(Fn, T, op, ptx_ty, pack_field) \ #define DEFINE_Apply_LoadMultimem(Fn, T, op, ptx_ty, pack_field) \
template<> \ template<> \
struct Apply_LoadMultimem<Fn<T>> { \ struct Apply_LoadMultimem<Fn<T>, SIZEOF_BytePack_field_##pack_field> { \
static constexpr int PackSize = 1*(SIZEOF_BytePack_field_##pack_field); \ static constexpr int PackSize = SIZEOF_BytePack_field_##pack_field; \
__device__ static BytePack<PackSize> load(Fn<T> fn, uintptr_t addr) { \ __device__ static BytePack<PackSize> load(Fn<T> fn, uintptr_t addr) { \
BytePack<PackSize> ans; \ BytePack<PackSize> ans; \
asm("multimem.ld_reduce.global." #op "." #ptx_ty " %0, [%1];" \ asm("multimem.ld_reduce.relaxed.sys.global." #op "." #ptx_ty " %0, [%1];" \
: "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field) \ : "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field) \
: "l"(addr)); \ : "l"(addr)); \
return ans; \ return ans; \
@ -533,11 +557,11 @@ struct Apply_LoadMultimem {
}; };
#define DEFINE_Apply_LoadMultimem_v4(Fn, T, op, ptx_ty, pack_field) \ #define DEFINE_Apply_LoadMultimem_v4(Fn, T, op, ptx_ty, pack_field) \
template<> \ template<> \
struct Apply_LoadMultimem<Fn<T>> { \ struct Apply_LoadMultimem<Fn<T>, 4*(SIZEOF_BytePack_field_##pack_field)> { \
static constexpr int PackSize = 4*(SIZEOF_BytePack_field_##pack_field); \ static constexpr int PackSize = 4*(SIZEOF_BytePack_field_##pack_field); \
__device__ static BytePack<PackSize> load(Fn<T> fn, uintptr_t addr) { \ __device__ static BytePack<PackSize> load(Fn<T> fn, uintptr_t addr) { \
BytePack<PackSize> ans; \ BytePack<PackSize> ans; \
asm("multimem.ld_reduce.global." #op ".v4." #ptx_ty " {%0,%1,%2,%3}, [%4];" \ asm("multimem.ld_reduce.relaxed.sys.global." #op ".v4." #ptx_ty " {%0,%1,%2,%3}, [%4];" \
: "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[0]), \ : "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[0]), \
"=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[1]), \ "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[1]), \
"=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[2]), \ "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[2]), \
@ -546,8 +570,45 @@ struct Apply_LoadMultimem {
return ans; \ return ans; \
} \ } \
}; };
#define DEFINE_Apply_LoadMultimem_v4x2_and_subhalf(Fn, T, op, ptx_ty, pack_field) \
DEFINE_Apply_LoadMultimem_v4(Fn, T, op, ptx_ty, pack_field) \
template<> \
struct Apply_LoadMultimem<Fn<T>, sizeof(T)> { \
__device__ static BytePack<sizeof(T)> load(Fn<T> fn, uintptr_t addr) { \
BytePack<2*sizeof(T)> tmp; \
asm("multimem.ld_reduce.relaxed.sys.global." #op "." #ptx_ty " %0, [%1];" \
: "=" PTX_REG_BytePack_field_##pack_field(tmp.pack_field) \
: "l"(addr & -uintptr_t(sizeof(T)))); \
return tmp.half[(addr/sizeof(T))%2]; \
} \
};
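// For concreteness, the sizeof(T) specialization generated above for FuncSum<half>
// expands to roughly the following (hand-written approximation): it performs the
// aligned f16x2 multimem load covering addr and then returns the requested half.
//   template<>
//   struct Apply_LoadMultimem<FuncSum<half>, 2> {
//     __device__ static BytePack<2> load(FuncSum<half> fn, uintptr_t addr) {
//       BytePack<4> tmp;
//       asm("multimem.ld_reduce.relaxed.sys.global.add.f16x2 %0, [%1];"
//           : "=r"(tmp.u32) : "l"(addr & -uintptr_t(4)));
//       return tmp.half[(addr/2)%2];
//     }
//   };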
template<typename Fn, int BytePerPack>
struct Apply_LoadMultimem {
__device__ static BytePack<BytePerPack> load(Fn fn, uintptr_t addr) {
__trap();
return {};
}
};
#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 #if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
template<typename Fn>
struct LoadMultimem_BigPackSize {
using T = typename Fn::EltType;
static constexpr bool IsSum = std::is_same<Fn, FuncSum<T>>::value ||
std::is_same<Fn, FuncPreMulSum<T>>::value ||
std::is_same<Fn, FuncSumPostDiv<T>>::value;
static constexpr bool IsMinOrMax = std::is_same<Fn, FuncMin<T>>::value ||
std::is_same<Fn, FuncMax<T>>::value;
static constexpr bool IsFloat = IsFloatingPoint<T>::value;
static constexpr int BigPackSize =
IsFloat && IsSum && sizeof(T) < 8 ? 16 :
IsFloat && IsSum ? 8 :
IsFloat && IsMinOrMax && sizeof(T)==2 ? 16 :
!IsFloat && (IsSum||IsMinOrMax) && sizeof(T)>=4 ? sizeof(T) :
/*multimem.ld_reduce not supported:*/ 0;
};
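// Sanity-check sketch of a few instantiations of the selection above (these follow
// directly from the constexpr logic and are illustrative only):
static_assert(LoadMultimem_BigPackSize<FuncSum<float>>::BigPackSize == 16,
              "float sum: 2/4-byte float sums use 16-byte (v4) packs");
static_assert(LoadMultimem_BigPackSize<FuncSum<double>>::BigPackSize == 8,
              "double sum: 8-byte float sums use element-sized packs");
static_assert(LoadMultimem_BigPackSize<FuncMin<half>>::BigPackSize == 16,
              "half min: 2-byte float min/max uses 16-byte (v4) packs");
static_assert(LoadMultimem_BigPackSize<FuncMax<uint32_t>>::BigPackSize == 4,
              "u32 max: integer sum/min/max uses element-sized packs");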
DEFINE_Apply_LoadMultimem(FuncSum, uint32_t, add, u32, u32) DEFINE_Apply_LoadMultimem(FuncSum, uint32_t, add, u32, u32)
DEFINE_Apply_LoadMultimem(FuncMin, uint32_t, min, u32, u32) DEFINE_Apply_LoadMultimem(FuncMin, uint32_t, min, u32, u32)
DEFINE_Apply_LoadMultimem(FuncMax, uint32_t, max, u32, u32) DEFINE_Apply_LoadMultimem(FuncMax, uint32_t, max, u32, u32)
@ -564,23 +625,30 @@ struct Apply_LoadMultimem {
DEFINE_Apply_LoadMultimem(FuncMin, int64_t, min, s64, u64) DEFINE_Apply_LoadMultimem(FuncMin, int64_t, min, s64, u64)
DEFINE_Apply_LoadMultimem(FuncMax, int64_t, max, s64, u64) DEFINE_Apply_LoadMultimem(FuncMax, int64_t, max, s64, u64)
DEFINE_Apply_LoadMultimem(FuncSum, float, add, f32, u32)
DEFINE_Apply_LoadMultimem_v4(FuncSum, float, add, f32, u32) DEFINE_Apply_LoadMultimem_v4(FuncSum, float, add, f32, u32)
DEFINE_Apply_LoadMultimem(FuncSum, double, add, f64, u64) DEFINE_Apply_LoadMultimem(FuncSum, double, add, f64, u64)
DEFINE_Apply_LoadMultimem_v4(FuncSum, half, add, f16x2, u32) DEFINE_Apply_LoadMultimem_v4x2_and_subhalf(FuncSum, half, add, f16x2, u32)
DEFINE_Apply_LoadMultimem_v4(FuncMin, half, min, f16x2, u32) DEFINE_Apply_LoadMultimem_v4x2_and_subhalf(FuncMin, half, min, f16x2, u32)
DEFINE_Apply_LoadMultimem_v4(FuncMax, half, max, f16x2, u32) DEFINE_Apply_LoadMultimem_v4x2_and_subhalf(FuncMax, half, max, f16x2, u32)
#if defined(__CUDA_BF16_TYPES_EXIST__) #if defined(__CUDA_BF16_TYPES_EXIST__)
DEFINE_Apply_LoadMultimem_v4(FuncSum, __nv_bfloat16, add, bf16x2, u32) DEFINE_Apply_LoadMultimem_v4x2_and_subhalf(FuncSum, __nv_bfloat16, add, bf16x2, u32)
DEFINE_Apply_LoadMultimem_v4(FuncMin, __nv_bfloat16, min, bf16x2, u32) DEFINE_Apply_LoadMultimem_v4x2_and_subhalf(FuncMin, __nv_bfloat16, min, bf16x2, u32)
DEFINE_Apply_LoadMultimem_v4(FuncMax, __nv_bfloat16, max, bf16x2, u32) DEFINE_Apply_LoadMultimem_v4x2_and_subhalf(FuncMax, __nv_bfloat16, max, bf16x2, u32)
#endif #endif
#else
template<typename Fn>
struct LoadMultimem_BigPackSize {
static constexpr int BigPackSize = 0;
};
#endif #endif
#undef DEFINE_Apply_LoadMultimem #undef DEFINE_Apply_LoadMultimem
#undef DEFINE_Apply_LoadMultimem_v4 #undef DEFINE_Apply_LoadMultimem_v4
#undef DEFINE_Apply_LoadMultimem_v4x2_and_subhalf
#undef SIZEOF_BytePack_field_u64 #undef SIZEOF_BytePack_field_u64
#undef PTX_REG_BytePack_field_u64 #undef PTX_REG_BytePack_field_u64
#undef SIZEOF_BytePack_field_u32 #undef SIZEOF_BytePack_field_u32

@ -108,19 +108,19 @@ struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROT
if (tid < tidEndScatter) { if (tid < tidEndScatter) {
// Scatter // Scatter
int group = (0*Proto::MaxGroupWidth) | (0<<16);
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0> Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL, args->redOpArg, group, args); prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL,
args->redOpArg, 0*Proto::MaxGroupWidth, 0, 0);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*chunkSize; ssize_t offset = gridOffset + bid*chunkSize;
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);
prims.scatter(offset, nvls->nHeads*size, nelem, size, -1, 0); prims.scatter(offset, nvls->nHeads*size, nelem, size, -1, 0);
} }
} else if (tid < tidEndReduce) { } else if (tid < tidEndReduce) {
int group = (3*Proto::MaxGroupWidth) | (1<<16); // Reduce through NVLS
// Reduce through MC
Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto, 0> Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto, 0>
prims(tid-tidEndScatter, nThreadsReduce, &nvls->down, NULL, NULL, args->recvbuff, args->redOpArg, group, args); prims(tid-tidEndScatter, nThreadsReduce, &nvls->down, NULL, NULL, args->recvbuff,
args->redOpArg, 3*Proto::MaxGroupWidth, 1, 1);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*chunkSize; ssize_t offset = gridOffset + bid*chunkSize;
int nelem = min(chunkSize, size-offset); int nelem = min(chunkSize, size-offset);

@ -11,14 +11,14 @@
template<typename T, typename RedOp> template<typename T, typename RedOp>
struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> { struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
template<typename Proto> template<typename Proto>
__device__ void runSend(const int tid, const int nthreads, const int group, struct ncclWorkElemP2p* args) { __device__ void runSend(const int tid, const int nthreads, const uint8_t group, struct ncclWorkElemP2p* args) {
void* buff = reinterpret_cast<void*>(uintptr_t(args->buffHi32)<<32 | args->buffLo32); void* buff = reinterpret_cast<void*>(uintptr_t(args->buffHi32)<<32 | args->buffLo32);
ssize_t count = reinterpret_cast<size_t>(size_t(args->countHi32)<<32 | args->countLo32); ssize_t count = reinterpret_cast<size_t>(size_t(args->countHi32)<<32 | args->countLo32);
if (args->peer == ncclShmem.comm.rank) { if (args->peer == ncclShmem.comm.rank) {
struct ncclWorkElemP2p* recvArgs = args-1; struct ncclWorkElemP2p* recvArgs = args-1;
void* recvBuff = reinterpret_cast<void*>(uintptr_t(recvArgs->buffHi32)<<32 | recvArgs->buffLo32); void* recvBuff = reinterpret_cast<void*>(uintptr_t(recvArgs->buffHi32)<<32 | recvArgs->buffLo32);
if (buff != recvBuff) { if (buff != recvBuff) {
ReduceOrCopyMulti<COLL_UNROLL, RedOp, T, 1, 1, 1, 1, /*PreOpSrcs=*/0> reduceCopy<COLL_UNROLL, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>
(tid, nthreads, 0, nullptr, false, 1, &buff, 1, &recvBuff, count); (tid, nthreads, 0, nullptr, false, 1, &buff, 1, &recvBuff, count);
} }
} else { } else {
@ -26,7 +26,7 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
if (args->proto == NCCL_PROTO_LL) chunkSize /= 2; if (args->proto == NCCL_PROTO_LL) chunkSize /= 2;
int const peer = args->peer; int const peer = args->peer;
Primitives<T, RedOp, FanAsymmetric<0, 1>, 1, Proto, 1> prims Primitives<T, RedOp, FanAsymmetric<0, 1>, 1, Proto, 1> prims
(tid, nthreads, nullptr, &peer, buff, nullptr, /*redOpArg(ignored)=*/0, group); (tid, nthreads, nullptr, &peer, buff, nullptr, /*redOpArg(ignored)=*/0, group, 1, 1);
size_t offset = 0; size_t offset = 0;
do { do {
int nelem = min(size_t(chunkSize), count-offset); int nelem = min(size_t(chunkSize), count-offset);
@ -37,7 +37,7 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
} }
template<typename Proto> template<typename Proto>
__device__ void runRecv(const int tid, const int nthreads, const int group, struct ncclWorkElemP2p* args) { __device__ void runRecv(const int tid, const int nthreads, const uint8_t group, struct ncclWorkElemP2p* args) {
if (args->peer != ncclShmem.comm.rank) { if (args->peer != ncclShmem.comm.rank) {
void* buff = reinterpret_cast<void*>(uintptr_t(args->buffHi32)<<32 | args->buffLo32); void* buff = reinterpret_cast<void*>(uintptr_t(args->buffHi32)<<32 | args->buffLo32);
ssize_t count = reinterpret_cast<size_t>(size_t(args->countHi32)<<32 | args->countLo32); ssize_t count = reinterpret_cast<size_t>(size_t(args->countHi32)<<32 | args->countLo32);
@ -45,7 +45,7 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
if (args->proto == NCCL_PROTO_LL) chunkSize /= 2; // This is to account for chunkEffectiveSize if (args->proto == NCCL_PROTO_LL) chunkSize /= 2; // This is to account for chunkEffectiveSize
int const peer = args->peer; int const peer = args->peer;
Primitives<T, RedOp, FanAsymmetric<1, 0>, 1, Proto, 1> prims Primitives<T, RedOp, FanAsymmetric<1, 0>, 1, Proto, 1> prims
(tid, nthreads, &peer, nullptr, nullptr, buff, /*redOpArg(ignored)=*/0, group); (tid, nthreads, &peer, nullptr, nullptr, buff, /*redOpArg(ignored)=*/0, group, 1, 1);
size_t offset = 0; size_t offset = 0;
do { do {
int nelem = min(size_t(chunkSize), count-offset); int nelem = min(size_t(chunkSize), count-offset);
@ -65,11 +65,10 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
// warpStarts were rounded thanks to int division, but for group number we need to round the other way around // warpStarts were rounded thanks to int division, but for group number we need to round the other way around
// So we mirror wid then mirror again the group. // So we mirror wid then mirror again the group.
#define NWARPS (NCCL_MAX_NTHREADS/WARP_SIZE) #define NWARPS (NCCL_MAX_NTHREADS/WARP_SIZE)
int group = ngroups-1- (NWARPS-1-wid) * ngroups / NWARPS; uint8_t group = ngroups-1- (NWARPS-1-wid) * ngroups / NWARPS;
args += group; args += group;
tid -= args->warpStart * WARP_SIZE; tid -= args->warpStart * WARP_SIZE;
int nthreads = args->nWarps * WARP_SIZE; int nthreads = args->nWarps * WARP_SIZE;
group |= 1<<16; // Used to select connIndex 1
if (args->p2pType == ncclWorkP2pTypeUnused) return; if (args->p2pType == ncclWorkP2pTypeUnused) return;
if (tid >= nthreads || args->peer == -1) return; if (tid >= nthreads || args->peer == -1) return;
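// In short (schematic, peers/buffers elided): the connection index that used to be
// packed into the upper bits of `group` is now an explicit constructor argument.
//   2.17-style:  Primitives<...> prims(tid, nthreads, nullptr, &peer, buff, nullptr,
//                                      /*redOpArg=*/0, group | (1<<16));
//   2.18-style:  Primitives<...> prims(tid, nthreads, nullptr, &peer, buff, nullptr,
//                                      /*redOpArg=*/0, group, /*connIndexRecv=*/1,
//                                      /*connIndexSend=*/1);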

@ -74,6 +74,8 @@ void ncclDebugInit() {
mask = NCCL_ALLOC; mask = NCCL_ALLOC;
} else if (strcasecmp(subsys, "CALL") == 0) { } else if (strcasecmp(subsys, "CALL") == 0) {
mask = NCCL_CALL; mask = NCCL_CALL;
} else if (strcasecmp(subsys, "PROXY") == 0) {
mask = NCCL_PROXY;
} else if (strcasecmp(subsys, "NVLS") == 0) { } else if (strcasecmp(subsys, "NVLS") == 0) {
mask = NCCL_NVLS; mask = NCCL_NVLS;
} else if (strcasecmp(subsys, "ALL") == 0) { } else if (strcasecmp(subsys, "ALL") == 0) {

@ -33,7 +33,8 @@ struct ncclKernelMatch {
NCCL_FUNC5(func, RING, devredop, type, specialized), \ NCCL_FUNC5(func, RING, devredop, type, specialized), \
NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, specialized), \ NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, specialized), \
NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, specialized), \ NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, specialized), \
NCCL_FUNC5(func, NVLS, devredop, type, specialized) NCCL_FUNC5(func, NVLS, devredop, type, specialized), \
NCCL_FUNC5(func, NVLS_TREE, devredop, type, specialized)
#ifdef __CUDA_BF16_TYPES_EXIST__ #ifdef __CUDA_BF16_TYPES_EXIST__
#define HAVE_BFLOAT16 1 #define HAVE_BFLOAT16 1
@ -215,12 +216,13 @@ static void finishWork(struct ncclWork* work) {
static void appendWorkElemP2p( static void appendWorkElemP2p(
struct ncclComm* comm, struct ncclKernelPlan* plan, int channelId, struct ncclComm* comm, struct ncclKernelPlan* plan, int channelId,
struct ncclWorkElemP2p const *elem struct ncclWorkElemP2p const *elem, bool fuseOk
) { ) {
constexpr int funcIndex = FUNC_INDEX_P2P; constexpr int funcIndex = FUNC_INDEX_P2P;
struct ncclKernelPlan::Channel* chan = &plan->channels[channelId]; struct ncclKernelPlan::Channel* chan = &plan->channels[channelId];
struct ncclWorkList* q = ncclIntruQueueTail(&chan->workQueue); struct ncclWorkList* q = ncclIntruQueueTail(&chan->workQueue);
if (q && funcIndex == q->work.header.funcIndex) { if (q && funcIndex == q->work.header.funcIndex) {
if (!fuseOk) goto NewWork;
if (chan->p2pTailElem[elem->p2pType-1] < NCCL_MAX_WORK_ELEMENTS_P2P) { if (chan->p2pTailElem[elem->p2pType-1] < NCCL_MAX_WORK_ELEMENTS_P2P) {
for (int e = -2 + chan->p2pTailElem[elem->p2pType-1]; e >= 0; e -= 2) { for (int e = -2 + chan->p2pTailElem[elem->p2pType-1]; e >= 0; e -= 2) {
// Can't have multiple elements of the same ncclWork communicate with the // Can't have multiple elements of the same ncclWork communicate with the
@ -349,7 +351,7 @@ NCCL_PARAM(P2pLLThreshold, "P2P_LL_THRESHOLD", 16384);
// ensure *nWorkBudget >= 1 upon entry. // ensure *nWorkBudget >= 1 upon entry.
static ncclResult_t addP2pToPlan( static ncclResult_t addP2pToPlan(
struct ncclComm* comm, struct ncclKernelPlan* plan, int* nWorkBudget, struct ncclComm* comm, struct ncclKernelPlan* plan, int* nWorkBudget,
bool isSendNotRecv, int peer, int chunk, void *addr, size_t bytes bool isSendNotRecv, int peer, int chunk, void *addr, size_t bytes, bool fuseOk
) { ) {
struct ncclInfo info = { struct ncclInfo info = {
isSendNotRecv ? ncclFuncSend : ncclFuncRecv, isSendNotRecv ? ncclFuncSend : ncclFuncRecv,
@ -364,7 +366,7 @@ static ncclResult_t addP2pToPlan(
// 1 is connIndex // 1 is connIndex
struct ncclConnInfo* conn = isSendNotRecv ? struct ncclConnInfo* conn = isSendNotRecv ?
&comm->channels[channelId].peers[peer].send[1].conn : &comm->channels[channelId].peers[peer].recv[1].conn; &comm->channels[channelId].peers[peer]->send[1].conn : &comm->channels[channelId].peers[peer]->recv[1].conn;
info.protocol = ((conn->buffs[NCCL_PROTO_LL] != nullptr) && bytes <= ncclParamP2pLLThreshold()) ? NCCL_PROTO_LL : NCCL_PROTO_SIMPLE; info.protocol = ((conn->buffs[NCCL_PROTO_LL] != nullptr) && bytes <= ncclParamP2pLLThreshold()) ? NCCL_PROTO_LL : NCCL_PROTO_SIMPLE;
struct ncclProxyOp proxyOp = {}; struct ncclProxyOp proxyOp = {};
@ -382,7 +384,7 @@ static ncclResult_t addP2pToPlan(
elem.chunkSize = info.chunkSize; // computed by ncclProxyComputeP2p elem.chunkSize = info.chunkSize; // computed by ncclProxyComputeP2p
*nWorkBudget += plan->channels[channelId].nWork; *nWorkBudget += plan->channels[channelId].nWork;
appendWorkElemP2p(comm, plan, channelId, &elem); appendWorkElemP2p(comm, plan, channelId, &elem, fuseOk);
*nWorkBudget -= plan->channels[channelId].nWork; *nWorkBudget -= plan->channels[channelId].nWork;
// Calculate the opCount after appendWorkElemP2p since it will always return // Calculate the opCount after appendWorkElemP2p since it will always return
@ -553,7 +555,7 @@ static ncclResult_t scheduleCollTasksToPlan(
info.sliceSteps = head->sliceSteps; info.sliceSteps = head->sliceSteps;
NCCLCHECK(ncclInfoSetDerived(&info, comm->nRanks)); NCCLCHECK(ncclInfoSetDerived(&info, comm->nRanks));
if (nAggOps > 1) { if (nAggOps > 1) {
int maxChannels = aggInfo.algorithm == NCCL_ALGO_NVLS ? comm->nvlsChannels : comm->nChannels; int maxChannels = aggInfo.algorithm == NCCL_ALGO_NVLS || aggInfo.algorithm == NCCL_ALGO_NVLS_TREE ? comm->nvlsChannels : comm->nChannels;
info.nChannels = DIVUP(info.nBytes, bytePerChannel[collNetSupport]); info.nChannels = DIVUP(info.nBytes, bytePerChannel[collNetSupport]);
info.nChannels = std::max(1, std::min(info.nChannels, maxChannels)); info.nChannels = std::max(1, std::min(info.nChannels, maxChannels));
info.algorithm = aggInfo.algorithm; info.algorithm = aggInfo.algorithm;
@ -578,7 +580,7 @@ static ncclResult_t scheduleCollTasksToPlan(
NCCLCHECK(registerIntraNodeBuffers(comm, plan, &info, &regBufUsed, regBufSend, regBufRecv)); NCCLCHECK(registerIntraNodeBuffers(comm, plan, &info, &regBufUsed, regBufSend, regBufRecv));
} }
int maxChannels = info.algorithm == NCCL_ALGO_NVLS ? comm->nvlsChannels : comm->nChannels; int maxChannels = info.algorithm == NCCL_ALGO_NVLS || aggInfo.algorithm == NCCL_ALGO_NVLS_TREE ? comm->nvlsChannels : comm->nChannels;
NCCLCHECK(addCollToPlan(comm, plan, nWorkBudget, workFuncIndex, &workElem, &proxyOp, NCCLCHECK(addCollToPlan(comm, plan, nWorkBudget, workFuncIndex, &workElem, &proxyOp,
maxChannels, info.nChannels, info.nBytes, regBufUsed, regBufSend, regBufRecv)); maxChannels, info.nChannels, info.nBytes, regBufUsed, regBufSend, regBufRecv));
tasks->nTasksColl -= 1; tasks->nTasksColl -= 1;
@ -632,12 +634,15 @@ static ncclResult_t scheduleP2pTasksToPlan(
// Avoid overloading channels with 8+ operations as we lose the sync warp, hence a bit of bandwidth. // Avoid overloading channels with 8+ operations as we lose the sync warp, hence a bit of bandwidth.
while (nChannelsMax*nRanks > comm->p2pnChannels*4 && nChannelsMax > 1) nChannelsMax /= 2; while (nChannelsMax*nRanks > comm->p2pnChannels*4 && nChannelsMax > 1) nChannelsMax /= 2;
bool fuseOk;
// We can perform 8 send/recv per round per CTA. Make sure we jump between fused blocks at node boundaries.
while (tasks->nTasksP2p != 0) { while (tasks->nTasksP2p != 0) {
for (int i=0; i < nRanks; i++) { for (int i=0; i < tasks->p2pOrderSteps; i++) {
int sendPeer = sendOrder[i]; int sendPeer = sendOrder[i];
int recvPeer = recvOrder[i]; int recvPeer = recvOrder[i];
struct ncclTaskP2p* send = ncclIntruQueueHead(&peers[sendPeer].sendQueue); if ((i % (NCCL_MAX_WORK_ELEMENTS_P2P/2)) == 0) fuseOk = false;
struct ncclTaskP2p* recv = ncclIntruQueueHead(&peers[recvPeer].recvQueue); struct ncclTaskP2p* send = sendPeer != -1 ? ncclIntruQueueHead(&peers[sendPeer].sendQueue) : NULL;
struct ncclTaskP2p* recv = recvPeer != -1 ? ncclIntruQueueHead(&peers[recvPeer].recvQueue) : NULL;
if (sendPeer == comm->rank) { if (sendPeer == comm->rank) {
if (recvPeer != comm->rank) { if (recvPeer != comm->rank) {
WARN("Sendrecv plan not aligned for self"); WARN("Sendrecv plan not aligned for self");
@ -676,7 +681,8 @@ static ncclResult_t scheduleP2pTasksToPlan(
if (recvChunkBytes != 0) { if (recvChunkBytes != 0) {
if (recvChunkBytes == -1) recvChunkBytes = 0; if (recvChunkBytes == -1) recvChunkBytes = 0;
if (*nWorkBudget < 1) return ncclSuccess; // ensure room in budget if (*nWorkBudget < 1) return ncclSuccess; // ensure room in budget
NCCLCHECK(addP2pToPlan(comm, plan, nWorkBudget, /*isSendNotRecv=*/false, recvPeer, recv->chunk, recvPtr, recvChunkBytes)); NCCLCHECK(addP2pToPlan(comm, plan, nWorkBudget, /*isSendNotRecv=*/false, recvPeer, recv->chunk, recvPtr, recvChunkBytes, fuseOk));
fuseOk = true;
recvPtr += recvChunkBytes; recvPtr += recvChunkBytes;
recvBytes -= recvChunkBytes; recvBytes -= recvChunkBytes;
recv->chunk += 1; recv->chunk += 1;
@ -689,7 +695,8 @@ static ncclResult_t scheduleP2pTasksToPlan(
if (sendChunkBytes != 0) { if (sendChunkBytes != 0) {
if (sendChunkBytes == -1) sendChunkBytes = 0; if (sendChunkBytes == -1) sendChunkBytes = 0;
if (*nWorkBudget < 1) return ncclSuccess; // ensure room in budget if (*nWorkBudget < 1) return ncclSuccess; // ensure room in budget
NCCLCHECK(addP2pToPlan(comm, plan, nWorkBudget, /*isSendNotRecv=*/true, sendPeer, send->chunk, sendPtr, sendChunkBytes)); NCCLCHECK(addP2pToPlan(comm, plan, nWorkBudget, /*isSendNotRecv=*/true, sendPeer, send->chunk, sendPtr, sendChunkBytes, fuseOk));
fuseOk = true;
sendPtr += sendChunkBytes; sendPtr += sendChunkBytes;
sendBytes -= sendChunkBytes; sendBytes -= sendChunkBytes;
send->chunk += 1; send->chunk += 1;
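fuseOk gates whether a send/recv may be appended to the current ncclWork: it is cleared at every block of NCCL_MAX_WORK_ELEMENTS_P2P/2 order steps so that a fused work never straddles those boundaries, and set again once the first operation of the block has been added. A schematic of that loop; the constant value and helper name are assumptions:

    constexpr int MAX_WORK_ELEMENTS_P2P = 8;  // assumed value, stands in for NCCL's constant

    // Sketch: walk the p2p order and decide per step whether fusion is allowed.
    void schedulePairsSketch(int orderSteps) {
      bool fuseOk = false;
      for (int i = 0; i < orderSteps; i++) {
        if (i % (MAX_WORK_ELEMENTS_P2P / 2) == 0) fuseOk = false;  // new block: force a fresh ncclWork
        // addP2pToPlan(..., fuseOk);  first call after a boundary opens a new work struct
        fuseOk = true;                 // later calls in the same block may fuse into it
      }
    }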
@ -822,12 +829,12 @@ static ncclResult_t uploadWork(struct ncclComm* comm, struct ncclKernelPlan* pla
} }
static ncclResult_t uploadProxyOps(struct ncclComm* comm, struct ncclKernelPlan* plan) { static ncclResult_t uploadProxyOps(struct ncclComm* comm, struct ncclKernelPlan* plan) {
uint64_t collOpCount = comm->collOpCount; uint64_t collOpCount = comm->sharedRes->collOpCount;
// Advance comm's collOpCount by number of colls in this plan. // Advance comm's collOpCount by number of colls in this plan.
comm->collOpCount = collOpCount + plan->collOpCount; comm->sharedRes->collOpCount += plan->collOpCount;
for (int c=0; c < plan->channelUbound; c++) { for (int c=0; c < plan->channelUbound; c++) {
struct ncclProxyOp* q = ncclIntruQueueHead(&plan->channels[c].proxyOpQueue); struct ncclProxyOp* q = ncclIntruQueueHead(&plan->channels[c].proxyOpQueue);
uint64_t p2pOpCount = comm->channels[c].p2pOpCount; uint64_t p2pOpCount = comm->sharedRes->p2pOpCount[c];
uint64_t nextP2pOpCount = p2pOpCount; uint64_t nextP2pOpCount = p2pOpCount;
while (q != nullptr) { while (q != nullptr) {
struct ncclProxyOp* qNext = q->enqNext; struct ncclProxyOp* qNext = q->enqNext;
@ -850,7 +857,7 @@ static ncclResult_t uploadProxyOps(struct ncclComm* comm, struct ncclKernelPlan*
q = qNext; q = qNext;
} }
// Advance channel's p2pOpCount by number of p2p's in this plan channel. // Advance channel's p2pOpCount by number of p2p's in this plan channel.
comm->channels[c].p2pOpCount = nextP2pOpCount; comm->sharedRes->p2pOpCount[c] = nextP2pOpCount;
} }
return ncclSuccess; return ncclSuccess;
} }
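With ncclCommSplit, child communicators can share proxy resources with their parent, so the collective and per-channel p2p op counters now live in the shared-resource structure instead of each ncclComm. A minimal sketch of that bookkeeping (type and field names are abbreviations, MAXCHANNELS is an assumed bound):

    #include <cstdint>

    constexpr int MAXCHANNELS = 32;  // assumed bound, stands in for NCCL's constant

    // Sketch: op counters shared by all communicators split from one parent.
    struct SharedResSketch {
      uint64_t collOpCount;              // advanced by the number of collectives in each uploaded plan
      uint64_t p2pOpCount[MAXCHANNELS];  // advanced per channel by the p2p ops in the plan
    };

    void advanceCounters(SharedResSketch* shared, uint64_t planColls, int channel, uint64_t planP2ps) {
      shared->collOpCount += planColls;        // every comm sharing the proxy sees one consistent sequence
      shared->p2pOpCount[channel] += planP2ps;
    }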
@ -969,14 +976,14 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) {
// The two-level fan-in fan-out is because ncclStrongStreamWaitStream() requires // The two-level fan-in fan-out is because ncclStrongStreamWaitStream() requires
// at least one of the two streams to be strong-stream. // at least one of the two streams to be strong-stream.
cudaStream_t launchStream = tasks->streams->stream; cudaStream_t launchStream = tasks->streams->stream;
NCCLCHECKGOTO(ncclStrongStreamAcquire(tasks->capturingGraph, &comm->deviceStream), result, failure); NCCLCHECKGOTO(ncclStrongStreamAcquire(tasks->capturingGraph, &comm->sharedRes->deviceStream), result, failure);
// Create dependency for device stream on user streams. First from extra user // Create dependency for device stream on user streams. First from extra user
// streams to deviceStream. Then deviceStream to first user stream. // streams to deviceStream. Then deviceStream to first user stream.
for (struct ncclCudaStreamList* l=tasks->streams->next; l != nullptr; l = l->next) { for (struct ncclCudaStreamList* l=tasks->streams->next; l != nullptr; l = l->next) {
NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, &comm->deviceStream, l->stream), result, failure); NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, &comm->sharedRes->deviceStream, l->stream), result, failure);
} }
NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, launchStream, &comm->deviceStream), result, failure); NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, launchStream, &comm->sharedRes->deviceStream), result, failure);
if (persistent || comm->persistentRefs != 0 || ncclCudaLaunchBlocking) { if (persistent || comm->persistentRefs != 0 || ncclCudaLaunchBlocking) {
// We have to launch host tasks to push proxy args. We are careful to only // We have to launch host tasks to push proxy args. We are careful to only
@ -986,15 +993,15 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) {
if (plan->hasProxyOps) { if (plan->hasProxyOps) {
if (!acquired) { if (!acquired) {
acquired = true; acquired = true;
NCCLCHECKGOTO(ncclStrongStreamAcquire(tasks->capturingGraph, &comm->hostStream), result, failure); NCCLCHECKGOTO(ncclStrongStreamAcquire(tasks->capturingGraph, &comm->sharedRes->hostStream), result, failure);
} }
NCCLCHECKGOTO(ncclStrongStreamLaunchHost(tasks->capturingGraph, &comm->hostStream, hostStreamPlanCallback, plan), result, failure); NCCLCHECKGOTO(ncclStrongStreamLaunchHost(tasks->capturingGraph, &comm->sharedRes->hostStream, hostStreamPlanCallback, plan), result, failure);
} }
} }
if (acquired) { if (acquired) {
// Make to-be-launched kernels dependent on just-launched host stream tasks. // Make to-be-launched kernels dependent on just-launched host stream tasks.
NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, launchStream, &comm->hostStream), result, failure); NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, launchStream, &comm->sharedRes->hostStream), result, failure);
NCCLCHECKGOTO(ncclStrongStreamRelease(tasks->capturingGraph, &comm->hostStream), result, failure); NCCLCHECKGOTO(ncclStrongStreamRelease(tasks->capturingGraph, &comm->sharedRes->hostStream), result, failure);
} }
} }
@ -1038,7 +1045,7 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan
NCCLCHECK(ncclCudaDriverVersion(&driverVersion)); NCCLCHECK(ncclCudaDriverVersion(&driverVersion));
if (driverVersion >= 11080) { if (driverVersion >= 11080) {
int compCap = comm->compCap; int compCap = comm->compCap;
unsigned int clusterSize = (compCap == 90) ? comm->cgaClusterSize : 0; unsigned int clusterSize = (compCap == 90) ? comm->config.cgaClusterSize : 0;
cudaLaunchConfig_t launchConfig = {0}; cudaLaunchConfig_t launchConfig = {0};
cudaLaunchAttribute launchAttrs[3]; cudaLaunchAttribute launchAttrs[3];
@ -1110,7 +1117,7 @@ ncclResult_t ncclLaunchFinish(struct ncclComm* comm) {
// Create dependency for deviceStream on launchStream. We know that deviceStream // Create dependency for deviceStream on launchStream. We know that deviceStream
// hasn't been modified since launchStream waited on it (in ncclLaunchPrepare), // hasn't been modified since launchStream waited on it (in ncclLaunchPrepare),
// so we can say that launchStream subsumes it. // so we can say that launchStream subsumes it.
NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, &comm->deviceStream, launchStream, /*b_subsumes_a=*/true), result, resume1); NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, &comm->sharedRes->deviceStream, launchStream, /*b_subsumes_a=*/true), result, resume1);
resume1: resume1:
// Create dependency for other user streams (skip launch stream) on deviceStream. // Create dependency for other user streams (skip launch stream) on deviceStream.
// Again, the user streams haven't been touched since deviceStream waited on them // Again, the user streams haven't been touched since deviceStream waited on them
@ -1118,12 +1125,12 @@ ncclResult_t ncclLaunchFinish(struct ncclComm* comm) {
struct ncclCudaStreamList* sl = tasks->streams->next; struct ncclCudaStreamList* sl = tasks->streams->next;
tasks->streams = nullptr; // Reset comm->tasks.streams to empty. tasks->streams = nullptr; // Reset comm->tasks.streams to empty.
while (sl != nullptr) { while (sl != nullptr) {
NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, sl->stream, &comm->deviceStream, /*b_subsumes_a=*/true), result, resume2); NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, sl->stream, &comm->sharedRes->deviceStream, /*b_subsumes_a=*/true), result, resume2);
resume2: resume2:
sl = sl->next; sl = sl->next;
} }
// Release device stream as acquired in ncclLaunchPrepare() // Release device stream as acquired in ncclLaunchPrepare()
NCCLCHECKGOTO(ncclStrongStreamRelease(tasks->capturingGraph, &comm->deviceStream), result, resume3); NCCLCHECKGOTO(ncclStrongStreamRelease(tasks->capturingGraph, &comm->sharedRes->deviceStream), result, resume3);
resume3:; resume3:;
} }
return result; return result;
@ -1160,6 +1167,8 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i
for (int a=0; a<nAlgos; a++) { for (int a=0; a<nAlgos; a++) {
if ((a == NCCL_ALGO_COLLNET_DIRECT || a == NCCL_ALGO_COLLNET_CHAIN) && collNetTypeSupport != 1) continue; if ((a == NCCL_ALGO_COLLNET_DIRECT || a == NCCL_ALGO_COLLNET_CHAIN) && collNetTypeSupport != 1) continue;
if (a == NCCL_ALGO_NVLS && !NCCL_NVLS_SUPPORTS(info->datatype, info->opFull.op)) continue; if (a == NCCL_ALGO_NVLS && !NCCL_NVLS_SUPPORTS(info->datatype, info->opFull.op)) continue;
if (a == NCCL_ALGO_NVLS && collNetTypeSupport != 1 && comm->nNodes > 1) continue;
if (a == NCCL_ALGO_NVLS_TREE && !NCCL_NVLS_SUPPORTS(info->datatype, info->opFull.op)) continue;
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
float time; float time;
@ -1193,7 +1202,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i
} }
ncSwitch /= 2; ncSwitch /= 2;
} }
} else if (info->algorithm == NCCL_ALGO_NVLS) { } else if (info->algorithm == NCCL_ALGO_NVLS || info->algorithm == NCCL_ALGO_NVLS_TREE) {
// NVLS should not need more than 16 channels to get peak BW. // NVLS should not need more than 16 channels to get peak BW.
nc = comm->nvlsChannels; nc = comm->nvlsChannels;
} else { } else {
@ -1205,12 +1214,9 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i
} }
} }
if (info->protocol == NCCL_PROTO_SIMPLE) { if (info->protocol == NCCL_PROTO_SIMPLE) {
nt += WARP_SIZE; // Extra warp for sync if (info->algorithm == NCCL_ALGO_RING) nt += WARP_SIZE; // Extra warp for sync
// More threads or sync warps needed due to split thread model // More threads or sync warps needed due to split thread model
if (info->algorithm == NCCL_ALGO_TREE) nt += 3*WARP_SIZE; if (info->algorithm == NCCL_ALGO_TREE) nt += 4*WARP_SIZE;
if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT) nt += 3*WARP_SIZE;
if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN) nt += 3*WARP_SIZE;
if (info->algorithm == NCCL_ALGO_NVLS) nt = NCCL_MAX_NTHREADS;
} }
nt = nt/WARP_SIZE < 3 ? 3*WARP_SIZE : nt; nt = nt/WARP_SIZE < 3 ? 3*WARP_SIZE : nt;
info->nChannels = nc; info->nChannels = nc;
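For NCCL_PROTO_SIMPLE the thread count is now padded differently per algorithm: one extra warp for Ring, four extra warps for Tree (the split thread model), and the result is clamped to at least three warps. A compact restatement with assumed names:

    constexpr int WARP_SIZE = 32;

    // Sketch of the thread-count padding applied for the SIMPLE protocol.
    int adjustThreads(int nt, bool isRing, bool isTree) {
      if (isRing) nt += WARP_SIZE;      // extra warp for sync
      if (isTree) nt += 4 * WARP_SIZE;  // split thread model needs more sync warps
      if (nt / WARP_SIZE < 3) nt = 3 * WARP_SIZE;  // never fewer than 3 warps
      return nt;
    }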
@ -1226,10 +1232,13 @@ static ncclResult_t getPatternInfo(struct ncclInfo* info) {
info->pattern = info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUp : ncclPatternPipelineTo; break; info->pattern = info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUp : ncclPatternPipelineTo; break;
case ncclFuncReduceScatter: case ncclFuncReduceScatter:
case ncclFuncAllGather: case ncclFuncAllGather:
info->pattern = ncclPatternRing; break; info->pattern =
info->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls :
ncclPatternRing; break;
case ncclFuncAllReduce: case ncclFuncAllReduce:
info->pattern = info->pattern =
info->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls : info->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls :
info->algorithm == NCCL_ALGO_NVLS_TREE ? ncclPatternNvlsTree :
info->algorithm == NCCL_ALGO_COLLNET_DIRECT ? ncclPatternCollnetDirect : info->algorithm == NCCL_ALGO_COLLNET_DIRECT ? ncclPatternCollnetDirect :
info->algorithm == NCCL_ALGO_COLLNET_CHAIN ? ncclPatternCollnetChain : info->algorithm == NCCL_ALGO_COLLNET_CHAIN ? ncclPatternCollnetChain :
info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUpDown : info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUpDown :
@ -1249,14 +1258,17 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) {
case ncclPatternPipelineFrom: case ncclPatternPipelineFrom:
case ncclPatternPipelineTo: case ncclPatternPipelineTo:
case ncclPatternCollnetChain: case ncclPatternCollnetChain:
info->nstepsPerLoop = info->nchunksPerLoop = 1; break;
case ncclPatternNvls: case ncclPatternNvls:
info->nstepsPerLoop = info-> nchunksPerLoop = 1; break; info->nstepsPerLoop = 1; info->nchunksPerLoop = info->comm->channels[0].nvls.nHeads; break;
case ncclPatternCollnetDirect: case ncclPatternCollnetDirect:
info->nstepsPerLoop = 1; info->nchunksPerLoop = info->comm->channels[0].collnetDirect.nHeads; break; info->nstepsPerLoop = 1; info->nchunksPerLoop = info->comm->channels[0].collnetDirect.nHeads; break;
case ncclPatternRing: case ncclPatternRing:
info->nstepsPerLoop = info->comm->nRanks-1; info->nchunksPerLoop = info->comm->nRanks; break; info->nstepsPerLoop = info->comm->nRanks-1; info->nchunksPerLoop = info->comm->nRanks; break;
case ncclPatternRingTwice: case ncclPatternRingTwice:
info->nstepsPerLoop = 2*(info->comm->nRanks-1); info->nchunksPerLoop = info->comm->nRanks; break; info->nstepsPerLoop = 2*(info->comm->nRanks-1); info->nchunksPerLoop = info->comm->nRanks; break;
case ncclPatternNvlsTree:
info->nstepsPerLoop = 1; info->nchunksPerLoop = info->comm->channels[0].nvls.nHeads; break;
default: default:
WARN("Unknown pattern %d", info->pattern); WARN("Unknown pattern %d", info->pattern);
return ncclInternalError; return ncclInternalError;
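Each pattern fixes how many steps and chunks make up one loop over the buffer; the NVLS and NVLS-tree patterns now use one step per loop with one chunk per NVLS head, mirroring CollNet-direct. A condensed sketch (enum and helper names are illustrative, not NCCL's):

    enum Pattern { PatRing, PatRingTwice, PatNvls, PatNvlsTree, PatCollnetDirect, PatPipeline };

    struct LoopInfo { int nstepsPerLoop, nchunksPerLoop; };

    // Sketch of the mapping performed by getLoopInfo above.
    LoopInfo loopInfoFor(Pattern p, int nRanks, int nvlsHeads, int collnetHeads) {
      switch (p) {
        case PatNvls:
        case PatNvlsTree:      return {1, nvlsHeads};            // one chunk per NVLS head
        case PatCollnetDirect: return {1, collnetHeads};
        case PatRing:          return {nRanks - 1, nRanks};
        case PatRingTwice:     return {2 * (nRanks - 1), nRanks};
        default:               return {1, 1};                    // pipeline/chain-style patterns
      }
    }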
@ -1326,13 +1338,22 @@ comp_next:
while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2; while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2;
work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype); work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
} else if (info->algorithm == NCCL_ALGO_NVLS) { } else if (info->algorithm == NCCL_ALGO_NVLS) {
if (chunkSize > 131072) chunkSize = 131072; int maxChunkSize = 131072;
if (chunkSize > maxChunkSize) chunkSize = maxChunkSize;
// Use uint64_t so that concurrentOps*chunkSize*X does not overflow // Use uint64_t so that concurrentOps*chunkSize*X does not overflow
uint64_t concurrentOps = info->nChannels*info->comm->channels[0].nvls.nHeads; uint64_t concurrentOps = info->nChannels*info->comm->channels[0].nvls.nHeads;
if ((info->nBytes < (32 * (concurrentOps*chunkSize))) && (chunkSize > 65536)) chunkSize = 65536; if ((info->nBytes < (64 * (concurrentOps*chunkSize))) && (chunkSize > 65536)) chunkSize = 65536;
if ((info->nBytes < (8 * (concurrentOps*chunkSize))) && (chunkSize > 32768)) chunkSize = 32768; if ((info->nBytes < (8 * (concurrentOps*chunkSize))) && (chunkSize > 32768)) chunkSize = 32768;
if ((info->nBytes < (2 * (concurrentOps*chunkSize))) && (chunkSize > 16384)) chunkSize = 16384; if ((info->nBytes < (2 * (concurrentOps*chunkSize))) && (chunkSize > 16384)) chunkSize = 16384;
work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype); work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
} else if (info->algorithm == NCCL_ALGO_NVLS_TREE) {
// Use uint64_t so that concurrentOps*chunkSize*X does not overflow
uint64_t concurrentOps = info->nChannels*info->comm->channels[0].nvls.nHeads;
if ((info->nBytes < (32 * (concurrentOps*chunkSize))) && (chunkSize > 262144)) chunkSize = 262144;
if ((info->nBytes < (16 * (concurrentOps*chunkSize))) && (chunkSize > 131072)) chunkSize = 131072;
if ((info->nBytes < (4 * (concurrentOps*chunkSize))) && (chunkSize > 65536)) chunkSize = 65536;
if ((info->nBytes < (1 * (concurrentOps*chunkSize))) && (chunkSize > 32768)) chunkSize = 32768;
work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
} else if (info->protocol == NCCL_PROTO_LL) { } else if (info->protocol == NCCL_PROTO_LL) {
const ssize_t sliceSize = stepSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine); const ssize_t sliceSize = stepSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine);
const ssize_t loopSize = info->nChannels*info->nchunksPerLoop*(ssize_t)sliceSize; const ssize_t loopSize = info->nChannels*info->nchunksPerLoop*(ssize_t)sliceSize;
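Both NVLS branches above shrink the chunk size for small messages so that all concurrent operations (nChannels x nHeads) still receive work: the chunk steps down through fixed sizes whenever the message is small relative to the aggregate bytes in flight. A sketch of the NVLS-tree thresholds:

    #include <cstdint>

    // Sketch: halve the chunk size through fixed steps while the total message is
    // small compared to the bytes all concurrent ops would have in flight.
    // concurrentOps = nChannels * nvls.nHeads, as in the code above.
    uint64_t nvlsTreeChunkSize(uint64_t nBytes, uint64_t concurrentOps, uint64_t chunkSize) {
      if (nBytes < 32 * concurrentOps * chunkSize && chunkSize > 262144) chunkSize = 262144;
      if (nBytes < 16 * concurrentOps * chunkSize && chunkSize > 131072) chunkSize = 131072;
      if (nBytes <  4 * concurrentOps * chunkSize && chunkSize >  65536) chunkSize =  65536;
      if (nBytes <  1 * concurrentOps * chunkSize && chunkSize >  32768) chunkSize =  32768;
      return chunkSize;  // lastChunkSize = chunkSize / size of the datatype
    }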
@ -1361,8 +1382,7 @@ comp_next:
proxyOp->chunkSize = chunkSize; proxyOp->chunkSize = chunkSize;
proxyOp->protocol = info->protocol; proxyOp->protocol = info->protocol;
proxyOp->dtype = info->datatype; proxyOp->dtype = info->datatype;
proxyOp->redOp = (info->algorithm != NCCL_ALGO_COLLNET_DIRECT && info->algorithm != NCCL_ALGO_COLLNET_CHAIN) ? ncclNumOps : // Only set redOp when using CollNet proxyOp->redOp = info->opFull.op==ncclDevPreMulSum || info->opFull.op==ncclDevSumPostDiv ? ncclSum : // Network sees avg as sum
info->opFull.op==ncclDevPreMulSum || info->opFull.op==ncclDevSumPostDiv ? ncclSum : // Network sees avg as sum
info->op; info->op;
proxyOp->pattern = info->pattern; proxyOp->pattern = info->pattern;
proxyOp->root = info->root; proxyOp->root = info->root;
@ -1476,12 +1496,12 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo const* inf
int channelId; int channelId;
NCCLCHECK(ncclChannelComputeFromBase(comm, channelBaseId, c, &channelId)); NCCLCHECK(ncclChannelComputeFromBase(comm, channelBaseId, c, &channelId));
if (isSendNotRecv) { if (isSendNotRecv) {
if (comm->channels[channelId].peers[peer].send[1].connected == 0) { // P2P uses only 1 connector if (comm->channels[channelId].peers[peer]->send[1].connected == 0) { // P2P uses only 1 connector
comm->connectSend[peer] |= (1UL<<channelId); comm->connectSend[peer] |= (1UL<<channelId);
ncclGroupCommPreconnect(comm); ncclGroupCommPreconnect(comm);
} }
} else { } else {
if (comm->channels[channelId].peers[peer].recv[1].connected == 0) { // P2P uses only 1 connector if (comm->channels[channelId].peers[peer]->recv[1].connected == 0) { // P2P uses only 1 connector
comm->connectRecv[peer] |= (1UL<<channelId); comm->connectRecv[peer] |= (1UL<<channelId);
ncclGroupCommPreconnect(comm); ncclGroupCommPreconnect(comm);
} }
@ -1576,10 +1596,10 @@ exit:
NCCLCHECK(ncclGroupEndInternal()); NCCLCHECK(ncclGroupEndInternal());
/* if depth is 1, ncclGroupEndInternal() will trigger group ops. The state can change /* if depth is 1, ncclGroupEndInternal() will trigger group ops. The state can change
* so we have to check state here. */ * so we have to check state here. */
if (info->comm && !info->comm->blocking) { NCCLCHECK(ncclCommGetAsyncError(info->comm, &ret)) }; if (info->comm && !info->comm->config.blocking) { NCCLCHECK(ncclCommGetAsyncError(info->comm, &ret)) };
return ret; return ret;
fail: fail:
if (info->comm && !info->comm->blocking) (void) ncclCommSetAsyncError(info->comm, ret); if (info->comm && !info->comm->config.blocking) (void) ncclCommSetAsyncError(info->comm, ret);
goto exit; goto exit;
} }

@ -14,9 +14,7 @@
/********************* Internode connection ***********************/ /********************* Internode connection ***********************/
/******************************************************************/ /******************************************************************/
ncclResult_t ncclTopoPreset(struct ncclComm* comm, ncclResult_t ncclTopoPreset(struct ncclComm* comm, struct ncclTopoGraph** graphs, struct ncclTopoRanks* topoRanks) {
struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph,
struct ncclTopoRanks* topoRanks) {
int rank = comm->rank; int rank = comm->rank;
int localRanks = comm->topo->nodes[GPU].count; int localRanks = comm->topo->nodes[GPU].count;
int nChannels = comm->nChannels; int nChannels = comm->nChannels;
@ -35,9 +33,10 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
for (int i=0; i<NCCL_MAX_DIRECT_ARITY; i++) channel->collnetDirect.up[i] = -1; for (int i=0; i<NCCL_MAX_DIRECT_ARITY; i++) channel->collnetDirect.up[i] = -1;
for (int i=0; i<NCCL_MAX_DIRECT_ARITY; i++) channel->collnetDirect.down[i] = -1; for (int i=0; i<NCCL_MAX_DIRECT_ARITY; i++) channel->collnetDirect.down[i] = -1;
int* ringIntra = ringGraph->intra+c*localRanks; int* ringIntra = graphs[NCCL_ALGO_RING]->intra+c*localRanks;
int* treeIntra = treeGraph->intra+c*localRanks; int* treeIntra = graphs[NCCL_ALGO_TREE]->intra+c*localRanks;
int* collNetIntra = collNetGraph->intra+c*localRanks; int* collNetIntra = graphs[NCCL_ALGO_COLLNET_CHAIN]->intra+c*localRanks;
int* nvlsIntra = graphs[NCCL_ALGO_NVLS]->intra+c*localRanks;
for (int i=0; i<localRanks; i++) { for (int i=0; i<localRanks; i++) {
if (ringIntra[i] == rank) { if (ringIntra[i] == rank) {
@ -48,8 +47,8 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
} }
if (treeIntra[i] == rank) { if (treeIntra[i] == rank) {
int parentIndex = 0; int parentIndex = 0;
int child0Index = treeGraph->pattern == NCCL_TOPO_PATTERN_TREE ? 0 : 1; int child0Index = graphs[NCCL_ALGO_TREE]->pattern == NCCL_TOPO_PATTERN_TREE ? 0 : 1;
int child1Index = treeGraph->pattern == NCCL_TOPO_PATTERN_SPLIT_TREE ? 1 : 0; int child1Index = graphs[NCCL_ALGO_TREE]->pattern == NCCL_TOPO_PATTERN_SPLIT_TREE ? 1 : 0;
topoRanks->treeToParent[c] = treeIntra[parentIndex]; topoRanks->treeToParent[c] = treeIntra[parentIndex];
topoRanks->treeToChild0[c] = treeIntra[child0Index]; topoRanks->treeToChild0[c] = treeIntra[child0Index];
@ -64,6 +63,7 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
} }
topoRanks->ringPrev[c] = channel->ring.prev; topoRanks->ringPrev[c] = channel->ring.prev;
topoRanks->ringNext[c] = channel->ring.next; topoRanks->ringNext[c] = channel->ring.next;
topoRanks->nvlsHeads[c] = nvlsIntra[0];
} }
// Duplicate channels rings/trees // Duplicate channels rings/trees
struct ncclChannel* channel0 = comm->channels; struct ncclChannel* channel0 = comm->channels;
@ -72,26 +72,26 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t connectRings(struct ncclComm* comm, int* ringRecv, int* ringSend, int* ringPrev, int* ringNext, int* firstRanks) { static ncclResult_t connectRings(struct ncclComm* comm, int* ringRecv, int* ringSend, int* ringPrev, int* ringNext) {
int nChannels = comm->nChannels; int nChannels = comm->nChannels;
int nNodes = comm->nNodes; int nNodes = comm->nNodes;
for (int c=0; c<nChannels; c++) { for (int c=0; c<nChannels; c++) {
int* recv = ringRecv+c*comm->nRanks; int* recv = ringRecv+c*comm->nNodes;
int* send = ringSend+c*comm->nRanks; int* send = ringSend+c*comm->nNodes;
int* prev = ringPrev+c*comm->nRanks; int* prev = ringPrev+c*comm->nRanks;
int* next = ringNext+c*comm->nRanks; int* next = ringNext+c*comm->nRanks;
struct ncclChannel* channel0 = comm->channels+c; struct ncclChannel* channel0 = comm->channels+c;
struct ncclChannel* channel1 = channel0+nChannels; struct ncclChannel* channel1 = channel0+nChannels;
for (int n=0; n<nNodes; n++) { for (int n=0; n<nNodes; n++) {
int recvRank = recv[firstRanks[n]]; int recvRank = recv[n];
int prevSendRank = send[firstRanks[(n-1+nNodes)%nNodes]]; int prevSendRank = send[(n-1+nNodes)%nNodes];
prev[recvRank] = prevSendRank; prev[recvRank] = prevSendRank;
if (comm->rank == recvRank) { if (comm->rank == recvRank) {
channel0->ring.prev = prevSendRank; channel0->ring.prev = prevSendRank;
channel1->ring.prev = prevSendRank; channel1->ring.prev = prevSendRank;
} }
int sendRank = send[firstRanks[n]]; int sendRank = send[n];
int nextRecvRank = recv[firstRanks[(n+1)%nNodes]]; int nextRecvRank = recv[(n+1)%nNodes];
next[sendRank] = nextRecvRank; next[sendRank] = nextRecvRank;
if (comm->rank == sendRank) { if (comm->rank == sendRank) {
channel0->ring.next = nextRecvRank; channel0->ring.next = nextRecvRank;
@ -104,8 +104,8 @@ static ncclResult_t connectRings(struct ncclComm* comm, int* ringRecv, int* ring
return ncclSuccess; return ncclSuccess;
} }
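Ring stitching now operates on per-node arrays: entry n of ringRecv/ringSend is the rank where the ring enters/leaves node n on this channel, so the previous node's exit rank becomes my entry rank's prev, and the next node's entry rank becomes my exit rank's next. A sketch of that wiring with plain arrays instead of NCCL's structures:

    // Sketch: wire up prev/next across nodes for one channel.
    // recv[n]/send[n] are the ranks where the ring enters/leaves node n.
    void connectRingAcrossNodes(int nNodes, const int* recv, const int* send, int* prev, int* next) {
      for (int n = 0; n < nNodes; n++) {
        int recvRank = recv[n];
        int prevSendRank = send[(n - 1 + nNodes) % nNodes];
        prev[recvRank] = prevSendRank;   // ring enters node n coming from the previous node's exit

        int sendRank = send[n];
        int nextRecvRank = recv[(n + 1) % nNodes];
        next[sendRank] = nextRecvRank;   // ring leaves node n toward the next node's entry
      }
    }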
static ncclResult_t getIndexes(int* ranks, int* indexes, int nNodes, int* firstRanks) { static ncclResult_t getIndexes(int* ranks, int* indexes, int nNodes) {
for (int n=0; n<nNodes; n++) indexes[n] = ranks[firstRanks[n]]; for (int n=0; n<nNodes; n++) indexes[n] = ranks[n];
return ncclSuccess; return ncclSuccess;
} }
@ -127,48 +127,42 @@ static ncclResult_t setTreeDown(struct ncclTree* tree, int* indexes, int d) {
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t connectTrees(struct ncclComm* comm, int* treeToParent, int* treeToChild0, int* treeToChild1, int* firstRanks, int* treePatterns) { static ncclResult_t connectTrees(struct ncclComm* comm, int* treeToParent, int* treeToChild0, int* treeToChild1, int* treePatterns) {
const int nChannels = comm->nChannels, nNodes = comm->nNodes, node = comm->node; const int nChannels = comm->nChannels, nNodes = comm->nNodes, node = comm->node;
int* ranksToParent, *ranksToChild0, *ranksToChild1;
NCCLCHECK(ncclCalloc(&ranksToParent, nNodes));
NCCLCHECK(ncclCalloc(&ranksToChild0, nNodes));
NCCLCHECK(ncclCalloc(&ranksToChild1, nNodes));
// Compute tree depth. Not an exact value but a good approximation in most // Compute tree depth. Not an exact value but a good approximation in most
// cases // cases
int depth = comm->nRanks/nNodes - 1 + log2i(nNodes); int depth = comm->nRanks/nNodes - 1 + log2i(nNodes);
int t0u, t0d0, t0d1, t0ChildType, t1u, t1d0, t1d1, t1ChildType; int t0u, t0d0, t0d1, t0ChildType, t1u, t1d0, t1d1, t1ChildType;
int* ttp, *ttc0, *ttc1;
NCCLCHECK(ncclGetDtree(nNodes, node, &t0u, &t0d0, &t0d1, &t0ChildType, &t1u, &t1d0, &t1d1, &t1ChildType)); NCCLCHECK(ncclGetDtree(nNodes, node, &t0u, &t0d0, &t0d1, &t0ChildType, &t1u, &t1d0, &t1d1, &t1ChildType));
for (int c=0; c<nChannels; c++) { for (int c=0; c<nChannels; c++) {
struct ncclChannel* channel0 = comm->channels+c; struct ncclChannel* channel0 = comm->channels+c;
struct ncclChannel* channel1 = channel0+nChannels; struct ncclChannel* channel1 = channel0+nChannels;
NCCLCHECK(getIndexes(treeToParent+c*comm->nRanks, ranksToParent, nNodes, firstRanks)); ttp = treeToParent+c*comm->nNodes;
NCCLCHECK(getIndexes(treeToChild0+c*comm->nRanks, ranksToChild0, nNodes, firstRanks)); ttc0 = treeToChild0+c*comm->nNodes;
NCCLCHECK(getIndexes(treeToChild1+c*comm->nRanks, ranksToChild1, nNodes, firstRanks)); ttc1 = treeToChild1+c*comm->nNodes;
if (comm->rank == ranksToParent[node]) { if (comm->rank == ttp[node]) {
NCCLCHECK(setTreeUp(&channel0->tree, t0ChildType == 0 ? ranksToChild0 : ranksToChild1, t0u)); NCCLCHECK(setTreeUp(&channel0->tree, t0ChildType == 0 ? ttc0 : ttc1, t0u));
NCCLCHECK(setTreeUp(&channel1->tree, t1ChildType == 0 ? ranksToChild0 : ranksToChild1, t1u)); NCCLCHECK(setTreeUp(&channel1->tree, t1ChildType == 0 ? ttc0 : ttc1, t1u));
} }
if (comm->rank == ranksToChild0[node]) { if (comm->rank == ttc0[node]) {
NCCLCHECK(setTreeDown(&channel0->tree, ranksToParent, t0d0)); NCCLCHECK(setTreeDown(&channel0->tree, ttp, t0d0));
NCCLCHECK(setTreeDown(&channel1->tree, ranksToParent, t1d0)); NCCLCHECK(setTreeDown(&channel1->tree, ttp, t1d0));
} }
if (comm->rank == ranksToChild1[node]) { if (comm->rank == ttc1[node]) {
NCCLCHECK(setTreeDown(&channel0->tree, ranksToParent, t0d1)); NCCLCHECK(setTreeDown(&channel0->tree, ttp, t0d1));
NCCLCHECK(setTreeDown(&channel1->tree, ranksToParent, t1d1)); NCCLCHECK(setTreeDown(&channel1->tree, ttp, t1d1));
} }
if (comm->rank == ranksToParent[node] || if (comm->rank == ttp[node] ||
comm->rank == ranksToChild0[node] || comm->rank == ttc0[node] ||
comm->rank == ranksToChild1[node]) { comm->rank == ttc1[node]) {
INFO(NCCL_GRAPH, "Tree %d : %d -> %d -> %d/%d/%d", c, channel0->tree.up, comm->rank, channel0->tree.down[0], channel0->tree.down[1], channel0->tree.down[2]); INFO(NCCL_GRAPH, "Tree %d : %d -> %d -> %d/%d/%d", c, channel0->tree.up, comm->rank, channel0->tree.down[0], channel0->tree.down[1], channel0->tree.down[2]);
INFO(NCCL_GRAPH, "Tree %d : %d -> %d -> %d/%d/%d", c+nChannels, channel1->tree.up, comm->rank, channel1->tree.down[0], channel1->tree.down[1], channel1->tree.down[2]); INFO(NCCL_GRAPH, "Tree %d : %d -> %d -> %d/%d/%d", c+nChannels, channel1->tree.up, comm->rank, channel1->tree.down[0], channel1->tree.down[1], channel1->tree.down[2]);
} }
channel0->tree.depth = channel1->tree.depth = depth; channel0->tree.depth = channel1->tree.depth = depth;
} }
free(ranksToParent);
free(ranksToChild0);
free(ranksToChild1);
return ncclSuccess; return ncclSuccess;
} }
@ -221,10 +215,96 @@ static ncclResult_t connectCollNet(struct ncclComm* comm, struct ncclTopoGraph*
INFO(NCCL_GRAPH, "%s", line); INFO(NCCL_GRAPH, "%s", line);
channel->collnetChain.depth = comm->nRanks/comm->nNodes; channel->collnetChain.depth = comm->nRanks/comm->nNodes;
} }
for (int c=0; c<comm->nvlsChannels; c++) {
struct ncclChannel* channel = comm->channels+c;
if (channel->nvls.headRank != -1) channel->nvls.out = comm->nRanks;
}
free(heads); free(heads);
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t connectNvls(struct ncclComm* comm, int* nvlsHeads, struct ncclTopoGraph* nvlsGraph) {
int nHeads = nvlsGraph->nChannels;
int headRank = -1;
for (int h=0; h<nHeads; h++) {
if (nvlsGraph->intra[h*comm->localRanks] == comm->rank) headRank = h;
}
if (nHeads == 0) {
comm->nvlsChannels = 0;
return ncclSuccess;
}
for (int c=0; c<comm->nvlsChannels; c++) {
struct ncclChannel* channel = comm->channels+c;
channel->nvls.nHeads = nHeads;
for (int h=0; h<nHeads; h++) channel->nvls.up[h] = comm->nRanks+1+h;
for (int h=nHeads; h<NCCL_MAX_NVLS_ARITY; h++) channel->nvls.up[h] = -1;
channel->nvls.down = comm->nRanks+1+headRank;
channel->nvls.out = -1; // NVLS+SHARP not yet implemented.
channel->nvls.headRank = headRank;
channel->nvls.treeUp = channel->nvls.treeDown[0] = channel->nvls.treeDown[1] = channel->nvls.treeDown[2] = -1;
channel->nvls.node = comm->node;
channel->nvls.nNodes = comm->nNodes;
}
if (comm->nNodes == 1) return ncclSuccess;
// Connect Trees
int tree0Parent, tree0Child0, tree0Child1, tree1Parent, tree1Child0, tree1Child1;
int pc0, pc1; // ignored
NCCLCHECK(ncclGetDtree(comm->nNodes, comm->node,
&tree0Parent, &tree0Child0, &tree0Child1, &pc0,
&tree1Parent, &tree1Child0, &tree1Child1, &pc1));
int* heads = NULL;
int treeUp[2] = { -1, -1 };
int treeDown0[2] = { -1, -1 };
int treeDown1[2] = { -1, -1 };
if (comm->node == 0) {
for (int h=0; h<nHeads; h++) {
char line[1024];
sprintf(line, "NVLS Head %2d:", h);
heads = nvlsHeads+h*comm->nNodes;
for (int n=0; n<comm->nNodes && n<20; n++) {
sprintf(line+strlen(line), " %2d", heads[n]);
}
INFO(NCCL_INIT, "%s", line);
}
}
// Find the heads where I'm the head rank and retain tree up/down
for (int h=0; h<nHeads; h++) {
heads = nvlsHeads+h*comm->nNodes;
if (heads[comm->node] == comm->rank) {
treeUp[0] = tree0Parent == -1 ? -1: heads[tree0Parent];
treeDown0[0] = tree0Child0 == -1 ? -1 : heads[tree0Child0];
treeDown1[0] = tree0Child1 == -1 ? -1 : heads[tree0Child1];
treeUp[1] = tree1Parent == -1 ? -1 : heads[tree1Parent];
treeDown0[1] = tree1Child0 == -1 ? -1 : heads[tree1Child0];
treeDown1[1] = tree1Child1 == -1 ? -1 : heads[tree1Child1];
break;
}
}
// Set prev/next in all channels (NVLS compute channels work
// orthogonally to NVLS search channels).
for (int c=0; c<comm->nvlsChannels; c++) {
struct ncclChannel* channel = comm->channels+c;
channel->nvls.treeUp = treeUp[c%2];
channel->nvls.treeDown[0] = channel->nvls.down;
int ix = 1;
if (treeDown0[c%2] != -1) channel->nvls.treeDown[ix++] = treeDown0[c%2];
if (treeDown1[c%2] != -1) channel->nvls.treeDown[ix] = treeDown1[c%2];
}
struct ncclNvls* nvls0 = &comm->channels[0].nvls;
struct ncclNvls* nvls1 = &comm->channels[1].nvls;
INFO(NCCL_GRAPH, "NVLS Trees : %d/%d->%d->%d %d/%d->%d->%d",
nvls0->treeDown[0], nvls0->treeDown[1], comm->rank, nvls0->treeUp,
nvls1->treeDown[0], nvls1->treeDown[1], comm->rank, nvls1->treeUp);
return ncclSuccess;
}
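Inside a node, each NVLS channel gets nHeads virtual "up" peers (indices nRanks+1+h stand for the NVLink SHARP multicast groups) and one "down" peer for the head this rank serves; across nodes the head ranks are then joined into a double binary tree. A sketch of the intra-node part; the struct layout and the arity constant are simplifications, not NCCL's exact definitions:

    constexpr int NVLS_MAX_ARITY = 6;  // assumed bound on heads

    struct NvlsChannelSketch { int nHeads, headRank, up[NVLS_MAX_ARITY], down, out; };

    // Sketch: intra-node NVLS wiring for one channel.
    void setupNvlsChannel(NvlsChannelSketch* ch, int nRanks, int nHeads, int headRank) {
      ch->nHeads = nHeads;
      ch->headRank = headRank;                                      // head index served by this rank, or -1
      for (int h = 0; h < nHeads; h++) ch->up[h] = nRanks + 1 + h;  // virtual peers for the multicast groups
      for (int h = nHeads; h < NVLS_MAX_ARITY; h++) ch->up[h] = -1;
      ch->down = nRanks + 1 + headRank;  // my own head's multicast peer (meaningful when headRank != -1)
      ch->out = -1;                      // network rank, set later when NVLS is coupled with SHARP
    }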
// Legacy naming // Legacy naming
NCCL_PARAM(MinNrings, "MIN_NRINGS", -2); NCCL_PARAM(MinNrings, "MIN_NRINGS", -2);
NCCL_PARAM(MaxNrings, "MAX_NRINGS", -2); NCCL_PARAM(MaxNrings, "MAX_NRINGS", -2);
@ -266,33 +346,40 @@ static int copyChannels(struct ncclComm* comm, int start, int end, int* ringPrev
return c; return c;
} }
ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns, struct ncclTopoRanks** allTopoRanks, int* rings, struct ncclTopoGraph* collNetGraph) { ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns, struct ncclTopoRanks** allTopoRanks, int* rings, struct ncclTopoGraph** graphs) {
// Gather data from all ranks // Gather data from all ranks
int *ringRecv, *ringSend, *ringPrev, *ringNext, *treeToParent, *treeToChild0, *treeToChild1; int *ringRecv, *ringSend, *ringPrev, *ringNext, *treeToParent, *treeToChild0, *treeToChild1, *nvlsHeads;
int nranks = comm->nRanks; int nranks = comm->nRanks;
int nNodes = comm->nNodes;
int nChannels = comm->nChannels; int nChannels = comm->nChannels;
NCCLCHECK(ncclCalloc(&ringRecv, nranks*MAXCHANNELS)); NCCLCHECK(ncclCalloc(&ringRecv, nNodes*MAXCHANNELS));
NCCLCHECK(ncclCalloc(&ringSend, nranks*MAXCHANNELS)); NCCLCHECK(ncclCalloc(&ringSend, nNodes*MAXCHANNELS));
NCCLCHECK(ncclCalloc(&ringPrev, nranks*MAXCHANNELS)); NCCLCHECK(ncclCalloc(&ringPrev, nranks*MAXCHANNELS));
NCCLCHECK(ncclCalloc(&ringNext, nranks*MAXCHANNELS)); NCCLCHECK(ncclCalloc(&ringNext, nranks*MAXCHANNELS));
NCCLCHECK(ncclCalloc(&treeToParent, nranks*MAXCHANNELS)); NCCLCHECK(ncclCalloc(&treeToParent, nNodes*MAXCHANNELS));
NCCLCHECK(ncclCalloc(&treeToChild0, nranks*MAXCHANNELS)); NCCLCHECK(ncclCalloc(&treeToChild0, nNodes*MAXCHANNELS));
NCCLCHECK(ncclCalloc(&treeToChild1, nranks*MAXCHANNELS)); NCCLCHECK(ncclCalloc(&treeToChild1, nNodes*MAXCHANNELS));
for (int i=0; i<nranks; i++) { NCCLCHECK(ncclCalloc(&nvlsHeads, nNodes*MAXCHANNELS));
for (int c=0; c<nChannels;c++) { for (int c=0; c<nChannels;c++) {
ringRecv[c*nranks+i] = allTopoRanks[i]->ringRecv[c]; for (int n=0; n<nNodes; n++) {
ringSend[c*nranks+i] = allTopoRanks[i]->ringSend[c]; int r = firstRanks[n];
ringPrev[c*nranks+i] = allTopoRanks[i]->ringPrev[c]; ringRecv[c*nNodes+n] = allTopoRanks[r]->ringRecv[c];
ringNext[c*nranks+i] = allTopoRanks[i]->ringNext[c]; ringSend[c*nNodes+n] = allTopoRanks[r]->ringSend[c];
treeToParent[c*nranks+i] = allTopoRanks[i]->treeToParent[c]; treeToParent[c*nNodes+n] = allTopoRanks[r]->treeToParent[c];
treeToChild0[c*nranks+i] = allTopoRanks[i]->treeToChild0[c]; treeToChild0[c*nNodes+n] = allTopoRanks[r]->treeToChild0[c];
treeToChild1[c*nranks+i] = allTopoRanks[i]->treeToChild1[c]; treeToChild1[c*nNodes+n] = allTopoRanks[r]->treeToChild1[c];
nvlsHeads[c*nNodes+n] = allTopoRanks[r]->nvlsHeads[c];
}
for (int r=0; r<nranks; r++) {
ringPrev[c*nranks+r] = allTopoRanks[r]->ringPrev[c];
ringNext[c*nranks+r] = allTopoRanks[r]->ringNext[c];
} }
} }
// Connect rings and trees. This should also duplicate the channels. // Connect rings and trees. This should also duplicate the channels.
NCCLCHECK(connectRings(comm, ringRecv, ringSend, ringPrev, ringNext, firstRanks)); NCCLCHECK(connectRings(comm, ringRecv, ringSend, ringPrev, ringNext));
NCCLCHECK(connectTrees(comm, treeToParent, treeToChild0, treeToChild1, firstRanks, treePatterns)); NCCLCHECK(connectTrees(comm, treeToParent, treeToChild0, treeToChild1, treePatterns));
NCCLCHECK(connectNvls(comm, nvlsHeads, graphs[NCCL_ALGO_NVLS]));
// Duplicate ringPrev/ringNext for ncclBuildRing // Duplicate ringPrev/ringNext for ncclBuildRing
memcpy(ringPrev+nChannels*nranks, ringPrev, nChannels*nranks*sizeof(int)); memcpy(ringPrev+nChannels*nranks, ringPrev, nChannels*nranks*sizeof(int));
@ -303,6 +390,7 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa
// Setup CollNet // Setup CollNet
if (comm->collNetSupport == 1) { if (comm->collNetSupport == 1) {
struct ncclTopoGraph* collNetGraph = graphs[NCCL_ALGO_COLLNET_DIRECT];
// Add more channels to saturate intra-node bandwidth, except the 1 PPN case // Add more channels to saturate intra-node bandwidth, except the 1 PPN case
if (collNetGraph->bwIntra > collNetGraph->bwInter && comm->nRanks > comm->nNodes) { if (collNetGraph->bwIntra > collNetGraph->bwInter && comm->nRanks > comm->nNodes) {
int collNetNchannels = std::min(MAXCHANNELS, nChannels+nChannels/2); int collNetNchannels = std::min(MAXCHANNELS, nChannels+nChannels/2);
@ -311,10 +399,21 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa
NCCLCHECK(connectCollNet(comm, collNetGraph)); NCCLCHECK(connectCollNet(comm, collNetGraph));
} }
// Use 4 compute channels per search channel to reach peak BW on <8 PPN
if (comm->minCompCap == 90 && comm->nNodes > 1 && graphs[NCCL_ALGO_RING]->bwIntra > 45.0 && 2*nChannels <= MAXCHANNELS) {
nChannels = comm->nChannels = copyChannels(comm, nChannels, 2*nChannels, ringPrev, ringNext);
}
// Honor NCCL_MIN_NRINGS/NCCL_MAX_NRINGS. // Honor NCCL_MIN_NRINGS/NCCL_MAX_NRINGS.
// We permit combining max, then min, to only use the first channels, then duplicate them. // We permit combining max, then min, to only use the first channels, then duplicate them.
nChannels = comm->nChannels = std::min(std::min(ncclMaxNchannels(), nChannels), comm->maxCTAs); if (comm->sharedRes->owner != comm) {
nChannels = comm->nChannels = copyChannels(comm, nChannels, std::max(ncclMinNchannels(), comm->minCTAs), ringPrev, ringNext); /* child comm #channels cannot exceed top parent #channels. */
nChannels = comm->nChannels = std::min(std::min(std::min(ncclMaxNchannels(), nChannels), comm->config.maxCTAs), comm->sharedRes->tpNChannels);
nChannels = comm->nChannels = copyChannels(comm, nChannels, std::min(std::max(ncclMinNchannels(), comm->config.minCTAs), comm->sharedRes->tpNChannels), ringPrev, ringNext);
} else {
nChannels = comm->nChannels = std::min(std::min(ncclMaxNchannels(), nChannels), comm->config.maxCTAs);
nChannels = comm->nChannels = copyChannels(comm, nChannels, std::max(ncclMinNchannels(), comm->config.minCTAs), ringPrev, ringNext);
}
// Create rings array and check all is fine // Create rings array and check all is fine
NCCLCHECK(ncclBuildRings(nChannels, rings, comm->rank, comm->nRanks, ringPrev, ringNext)); NCCLCHECK(ncclBuildRings(nChannels, rings, comm->rank, comm->nRanks, ringPrev, ringNext));
@ -326,6 +425,7 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa
free(treeToParent); free(treeToParent);
free(treeToChild0); free(treeToChild0);
free(treeToChild1); free(treeToChild1);
free(nvlsHeads);
return ncclSuccess; return ncclSuccess;
} }
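The post-set gather is now per node rather than per rank: for each channel, the value recorded by a node's first rank represents the whole node, which shrinks the ring/tree endpoint arrays from nRanks to nNodes entries per channel. A sketch of the indexing, with a 2-D stand-in for allTopoRanks:

    // Sketch: flatten per-(channel,node) data into arrays of size nChannels*nNodes.
    // perRank[r][c] stands for the value rank r reported for channel c.
    void gatherPerNode(int nChannels, int nNodes, const int* firstRanks,
                       int** perRank, int* perNode /* size nChannels*nNodes */) {
      for (int c = 0; c < nChannels; c++)
        for (int n = 0; n < nNodes; n++)
          perNode[c*nNodes + n] = perRank[firstRanks[n]][c];  // node n is represented by its first rank
    }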

@ -538,6 +538,11 @@ ncclResult_t ncclTopoComputePaths(struct ncclTopoSystem* system, struct ncclComm
NCCLCHECK(ncclTopoSetPaths(system->nodes[NET].nodes+n, system)); NCCLCHECK(ncclTopoSetPaths(system->nodes[NET].nodes+n, system));
} }
// Set direct paths to NVSwitches.
for (int n=0; n<system->nodes[NVS].count; n++) {
NCCLCHECK(ncclTopoSetPaths(system->nodes[NVS].nodes+n, system));
}
// Update path for GPUs when we don't want to / can't use GPU Direct P2P // Update path for GPUs when we don't want to / can't use GPU Direct P2P
for (int g=0; g<system->nodes[GPU].count; g++) { for (int g=0; g<system->nodes[GPU].count; g++) {
for (int p=0; p<system->nodes[GPU].count; p++) { for (int p=0; p<system->nodes[GPU].count; p++) {
@ -564,7 +569,7 @@ ncclResult_t ncclTopoComputePaths(struct ncclTopoSystem* system, struct ncclComm
NCCLCHECK(ncclTransports[TRANSPORT_SHM]->canConnect(&shm, system, NULL, srcInfo, dstInfo)); NCCLCHECK(ncclTransports[TRANSPORT_SHM]->canConnect(&shm, system, NULL, srcInfo, dstInfo));
if (shm == 0) { if (shm == 0) {
// Mark this peer as inaccessible. We'll trim it later. // Mark this peer as inaccessible. We'll trim it later.
system->nodes[GPU].nodes[p].paths[GPU][g].count = 0; system->nodes[GPU].nodes[p].paths[GPU][g].type = PATH_NET;
} }
} }
} }
@ -578,32 +583,20 @@ ncclResult_t ncclTopoComputePaths(struct ncclTopoSystem* system, struct ncclComm
// Check whether we can access the NIC through another NVLink-connected GPU (PXN) // Check whether we can access the NIC through another NVLink-connected GPU (PXN)
struct ncclTopoNode* gpu = system->nodes[GPU].nodes+g; struct ncclTopoNode* gpu = system->nodes[GPU].nodes+g;
if (ncclPxnDisable(comm) != 1) { if (ncclPxnDisable(comm) != 1) {
int pxnGpu = -1; int localGpuIndex;
NCCLCHECK(ncclTopoGetLocalGpu(system, system->nodes[NET].nodes[n].id, &localGpuIndex));
for (int p=0; p<system->nodes[GPU].count; p++) { if (localGpuIndex != g && localGpuIndex != -1) {
if (p == g) continue;
// PXN = PCI + NVLink. // PXN = PCI + NVLink.
struct ncclTopoNode* peerNode = system->nodes[GPU].nodes+p; struct ncclTopoNode* peerNode = system->nodes[GPU].nodes+localGpuIndex;
// Only use PXN for NIC n if remote GPU p ... // Only use PXN for NIC n if remote GPU p ...
if (peerNode->paths[NET][n].type > PATH_PXB || // Is connected to the NIC through PCI if (peerNode->paths[NET][n].type <= PATH_PXB && // Is connected to the NIC through PCI
peerNode->paths[GPU][g].type > PATH_NVL || // Is connected to us through NVLink peerNode->paths[GPU][g].type <= PATH_NVL && // Is connected to us through NVLink
(peerNode->paths[NET][n].bw <= gpu->paths[NET][n].bw && // Has either higher BW to that NIC (peerNode->paths[NET][n].bw > gpu->paths[NET][n].bw || // Has either higher BW to that NIC
gpu->paths[NET][n].type <= PATH_PXB)) // or avoids going through a CPU gpu->paths[NET][n].type > PATH_PXB)) // or avoids going through a CPU
continue;
pxnGpu = p;
int netDev;
NCCLCHECK(ncclTopoGetLocalNet(system, peerNode->gpu.rank, &netDev));
// To ensure proper balancing, use preferably a local GPU which advertised that NIC as its preferred one.
if (netDev == netNode->id) break;
}
if (pxnGpu != -1) {
// We can use that GPU as relay to communicate with that NIC. // We can use that GPU as relay to communicate with that NIC.
// Only enabling it in the GPU->NIC direction for now to favor // Only enabling it in the GPU->NIC direction for now to favor
// receiving locally and sending remotely (consistent with net.cc) // receiving locally and sending remotely (consistent with net.cc)
NCCLCHECK(addInterStep(system, GPU, pxnGpu, GPU, g, NET, n)); NCCLCHECK(addInterStep(system, GPU, localGpuIndex, GPU, g, NET, n));
} }
} }
// Update path when we don't want to / can't use GPU Direct RDMA. // Update path when we don't want to / can't use GPU Direct RDMA.
@ -632,7 +625,7 @@ ncclResult_t ncclTopoTrimSystem(struct ncclTopoSystem* system, struct ncclComm*
domains[g] = g; domains[g] = g;
ids[g] = gpu->id; ids[g] = gpu->id;
for (int p=0; p<g; p++) { for (int p=0; p<g; p++) {
if (gpu->paths[GPU][p].count > 0) { if (gpu->paths[GPU][p].type < PATH_NET) {
domains[g] = std::min(domains[g], domains[p]); domains[g] = std::min(domains[g], domains[p]);
} }
} }
@ -708,8 +701,14 @@ static int nextPow2(int v) {
ncclResult_t ncclTopoComputeP2pChannels(struct ncclComm* comm) { ncclResult_t ncclTopoComputeP2pChannels(struct ncclComm* comm) {
/* here we already honor comm->max/minCTAs for p2pnChannels. */ /* here we already honor comm->max/minCTAs for p2pnChannels. */
comm->p2pnChannels = std::min(comm->nChannels, (int)ncclParamMaxP2pNChannels()); if (comm->sharedRes->owner != comm) {
comm->p2pnChannels = std::max(comm->p2pnChannels, (int)ncclParamMinP2pNChannels()); comm->p2pnChannels = std::min(comm->nChannels, (int)ncclParamMaxP2pNChannels());
comm->p2pnChannels = std::min(std::max(comm->p2pnChannels, (int)ncclParamMinP2pNChannels()), comm->sharedRes->tpP2pNChannels);
} else {
comm->p2pnChannels = std::min(comm->nChannels, (int)ncclParamMaxP2pNChannels());
comm->p2pnChannels = std::max(comm->p2pnChannels, (int)ncclParamMinP2pNChannels());
}
int minChannels = comm->p2pnChannels; int minChannels = comm->p2pnChannels;
// We need to loop through all local GPUs to have a global picture // We need to loop through all local GPUs to have a global picture
for (int g=0; g<comm->topo->nodes[GPU].count; g++) { for (int g=0; g<comm->topo->nodes[GPU].count; g++) {

@ -10,6 +10,8 @@
#include "xml.h" #include "xml.h"
#include <math.h> #include <math.h>
NCCL_PARAM(CrossNic, "CROSS_NIC", 2);
// Initialize system->maxBw. This is the per-channel (i.e. per-SM) // Initialize system->maxBw. This is the per-channel (i.e. per-SM)
// max bw. // max bw.
static float getMaxBw(struct ncclTopoSystem* system, struct ncclTopoNode* gpu, int type) { static float getMaxBw(struct ncclTopoSystem* system, struct ncclTopoNode* gpu, int type) {
@ -106,11 +108,15 @@ static ncclResult_t ncclTopoFollowPath(struct ncclTopoSystem* system, struct ncc
if (type1 == -1) return ncclSuccess; if (type1 == -1) return ncclSuccess;
struct ncclTopoNode* node1 = system->nodes[type1].nodes+index1; struct ncclTopoNode* node1 = system->nodes[type1].nodes+index1;
struct ncclTopoLinkList* path = node1->paths[type2]+index2; struct ncclTopoLinkList* path = node1->paths[type2]+index2;
if (path == NULL) {
WARN("No path computed to go from %s/%d to %s/%d", topoNodeTypeStr[type1], index1, topoNodeTypeStr[type2], index2);
return ncclInternalError;
}
if (path->count == 0 ) return ncclSuccess; if (path->count == 0 ) return ncclSuccess;
// Now check link type // Now check link type
*node = NULL; *node = NULL;
int intra = type1 == GPU && type2 == GPU; int intra = (type1 == GPU || type1 == NVS) && (type2 == GPU || type2 == NVS);
float bw = intra ? graph->bwIntra : graph->bwInter; float bw = intra ? graph->bwIntra : graph->bwInter;
int type = intra ? graph->typeIntra : graph->typeInter; int type = intra ? graph->typeIntra : graph->typeInter;
@ -290,17 +296,53 @@ ncclResult_t ncclTopoSearchTryGpu(struct ncclTopoSystem* system, struct ncclTopo
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t ncclTopoCompareGraphs(struct ncclTopoGraph* graph, struct ncclTopoGraph* refGraph, int* copy) { ncclResult_t ncclTopoSearchTryNvls(struct ncclTopoSystem* system, struct ncclTopoGraph* graph, struct ncclTopoGraph* saveGraph, int g, int ngpus, int *time) {
// 1. Constraint to get the same nChannels between Rings and Trees struct ncclTopoNode* nvs;
struct ncclTopoNode* gpu;
int d0=0; // See if there is enough bandwidth for NVS->GPU traffic
do {
NCCLCHECK(ncclTopoFollowPath(system, graph, NVS, 0, GPU, d0, d0 == g ? 2 : 1, &gpu));
d0++;
} while (gpu && d0 < system->nodes[GPU].count);
if (gpu == NULL) {
d0--;
} else {
int d1=0; // See if there is enough bandwidth for GPU->NVS traffic
do {
NCCLCHECK(ncclTopoFollowPath(system, graph, GPU, d1, NVS, 0, d1 == g ? 2 : 1, &nvs));
d1++;
} while (nvs && d1 < system->nodes[GPU].count);
if (nvs == NULL) {
d1--;
} else { // Both directions worked. Move on to the next path.
NCCLCHECK(ncclTopoSearchRecGpu(system, graph, saveGraph, NULL, ngpus, -1, -1, 0, time));
}
while (d1) {
d1--;
NCCLCHECK(ncclTopoFollowPath(system, graph, GPU, d1, NVS, 0, d1 == g ? -2 : -1, &nvs));
}
}
while (d0) {
d0--;
NCCLCHECK(ncclTopoFollowPath(system, graph, NVS, 0, GPU, d0, d0 == g ? -2 : -1, &gpu));
}
return ncclSuccess;
}
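The NVLS search step reserves bandwidth on the NVSwitch paths in both directions for every GPU (twice for the GPU being placed), recurses only if both directions fit, and then releases exactly what it took by replaying the walk with negated multipliers. A schematic of that reserve/rollback pattern; tryReserve is a hypothetical callback standing in for ncclTopoFollowPath:

    // Sketch: reserve path bandwidth for d = 0..count-1, undo on failure.
    // tryReserve(d, mult) returns false when the path cannot carry mult*bw more.
    bool reserveAll(int count, int g, bool (*tryReserve)(int d, int mult)) {
      int d = 0;
      for (; d < count; d++) {
        if (!tryReserve(d, d == g ? 2 : 1)) break;   // the placed GPU needs double bandwidth
      }
      if (d == count) return true;                   // everything fit
      while (d--) tryReserve(d, d == g ? -2 : -1);   // roll back exactly what was taken
      return false;
    }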
ncclResult_t ncclTopoCompareGraphs(struct ncclTopoSystem* system, struct ncclTopoGraph* graph, struct ncclTopoGraph* refGraph, int* copy) {
// 1. Try to get the same nChannels between Rings and Trees
if (graph->nChannels < graph->minChannels) return ncclSuccess; if (graph->nChannels < graph->minChannels) return ncclSuccess;
// 2. Try to get better bandwidth // 2. Try to get better bandwidth
if (graph->nChannels*graph->bwIntra < refGraph->nChannels*refGraph->bwIntra) return ncclSuccess; // Give a 15% perf bonus to paths not crossing nics
if (graph->nChannels*graph->bwIntra > refGraph->nChannels*refGraph->bwIntra) { float target = 1.0 - (refGraph->crossNic - graph->crossNic) * .15;
if (graph->nChannels*graph->bwIntra > refGraph->nChannels*refGraph->bwIntra*target) {
*copy = 1; *copy = 1;
return ncclSuccess; return ncclSuccess;
} }
// 3. Less hops (but not at the price of going cross NICs) if (graph->nChannels*graph->bwIntra < refGraph->nChannels*refGraph->bwIntra*target) return ncclSuccess;
// 3. Less hops
if (graph->pattern == refGraph->pattern && graph->crossNic == refGraph->crossNic && graph->nHops < refGraph->nHops) *copy = 1; if (graph->pattern == refGraph->pattern && graph->crossNic == refGraph->crossNic && graph->nHops < refGraph->nHops) *copy = 1;
return ncclSuccess; return ncclSuccess;
} }
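Graph comparison now trades bandwidth against NIC crossing: a candidate that avoids crossing NICs only needs to reach 85% of the reference's nChannels x bwIntra score, while one that crosses NICs when the reference does not must beat it by roughly 15%. A sketch of that scoring:

    // Sketch: decide whether 'graph' should replace 'ref', applying a 15% bonus
    // per avoided NIC crossing (crossNic is 0 or 1 in this comparison).
    bool isBetter(float bw, int nCh, int crossNic, float refBw, int refNCh, int refCrossNic) {
      float target = 1.0f - (refCrossNic - crossNic) * 0.15f;
      return nCh * bw > refNCh * refBw * target;
    }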
@ -365,7 +407,7 @@ ncclResult_t ncclTopoSearchRecGpu(struct ncclTopoSystem* system, struct ncclTopo
// Determine whether we found a better solution or not // Determine whether we found a better solution or not
int copy = 0; int copy = 0;
graph->nChannels++; graph->nChannels++;
NCCLCHECK(ncclTopoCompareGraphs(graph, saveGraph, &copy)); NCCLCHECK(ncclTopoCompareGraphs(system, graph, saveGraph, &copy));
if (copy) { if (copy) {
memcpy(saveGraph, graph, sizeof(struct ncclTopoGraph)); memcpy(saveGraph, graph, sizeof(struct ncclTopoGraph));
if (graph->nChannels == graph->maxChannels) *time = -1; if (graph->nChannels == graph->maxChannels) *time = -1;
@ -417,6 +459,8 @@ ncclResult_t ncclTopoSearchRecGpu(struct ncclTopoSystem* system, struct ncclTopo
} }
free(nets); free(nets);
} }
} else if (graph->pattern == NCCL_TOPO_PATTERN_NVLS) {
NCCLCHECK(ncclTopoSearchTryNvls(system, graph, saveGraph, g, ngpus, time));
} else if (step < system->nodes[GPU].count-1) { } else if (step < system->nodes[GPU].count-1) {
// Go to next GPU // Go to next GPU
int next[NCCL_TOPO_MAX_NODES]; int next[NCCL_TOPO_MAX_NODES];
@ -570,7 +614,10 @@ ncclResult_t ncclTopoSearchRec(struct ncclTopoSystem* system, struct ncclTopoGra
ncclTopoSearchRecNet(system, graph, saveGraph, backToNet, backToFirstRank, time); ncclTopoSearchRecNet(system, graph, saveGraph, backToNet, backToFirstRank, time);
} else { } else {
// Intra-node only. // Intra-node only.
if (graph->nChannels == 0) { if (graph->pattern == NCCL_TOPO_PATTERN_NVLS) {
NCCLCHECK(ncclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, 0, time, -1, -1, graph->nChannels));
return ncclSuccess;
} else if (graph->nChannels == 0) {
// Try PCI order first // Try PCI order first
NCCLCHECK(ncclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, FORCED_ORDER_PCI, time, -1, -1, 0)); NCCLCHECK(ncclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, FORCED_ORDER_PCI, time, -1, -1, 0));
} else { } else {
@ -637,7 +684,7 @@ ncclResult_t ncclTopoGetGraphFromXmlSub(struct ncclXmlNode *xmlGraph, struct ncc
int crossNic; int crossNic;
NCCLCHECK(xmlGetAttrInt(xmlGraph, "crossnic", &crossNic)); NCCLCHECK(xmlGetAttrInt(xmlGraph, "crossnic", &crossNic));
if (graph->crossNic == 0 && crossNic == 1) return ncclSuccess; if (ncclParamCrossNic() == 0 && crossNic == 1) return ncclSuccess;
graph->crossNic = crossNic; graph->crossNic = crossNic;
NCCLCHECK(xmlGetAttrInt(xmlGraph, "pattern", &graph->pattern)); NCCLCHECK(xmlGetAttrInt(xmlGraph, "pattern", &graph->pattern));
@ -726,29 +773,31 @@ ncclResult_t ncclTopoGetXmlFromGraphs(int ngraphs, struct ncclTopoGraph** graphs
return ncclSuccess; return ncclSuccess;
} }
float speedArrayIntra[] = { 44.0, 30.0, 22.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0 }; float speedArrayIntra[] = { 40.0, 30.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0 };
float speedArrayInter[] = { 48.0, 30.0, 28.0, 24.0, 22.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 }; float speedArrayInter[] = { 48.0, 30.0, 28.0, 24.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
#define NSPEEDSINTRA (sizeof(speedArrayIntra)/sizeof(float)) #define NSPEEDSINTRA (sizeof(speedArrayIntra)/sizeof(float))
#define NSPEEDSINTER (sizeof(speedArrayInter)/sizeof(float)) #define NSPEEDSINTER (sizeof(speedArrayInter)/sizeof(float))
float sm90SpeedArrayIntra[] = { 66.0, 33.0, 24.0, 20.0, 15.0, 12.0, 6.0, 3.0 }; float sm90SpeedArrayIntra[] = { 60.0, 40.0, 30.0, 24.0, 20.0, 15.0, 12.0, 6.0, 3.0 };
float sm90SpeedArrayInter[] = { 48.0, 45.0, 30.0, 24.0, 15.0, 12.0, 6.0, 3.0, 2.4, 1.2, 0.24, 0.12 }; float sm90SpeedArrayInter[] = { 48.0, 45.0, 42.0, 40.0, 30.0, 24.0, 15.0, 12.0, 6.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
#define NSPEEDSINTRA_SM90 (sizeof(sm90SpeedArrayIntra)/sizeof(float)) #define NSPEEDSINTRA_SM90 (sizeof(sm90SpeedArrayIntra)/sizeof(float))
#define NSPEEDSINTER_SM90 (sizeof(sm90SpeedArrayInter)/sizeof(float)) #define NSPEEDSINTER_SM90 (sizeof(sm90SpeedArrayInter)/sizeof(float))
NCCL_PARAM(CrossNic, "CROSS_NIC", 2);
ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph) { ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph) {
int ngpus = system->nodes[GPU].count; int ngpus = system->nodes[GPU].count;
graph->crossNic = ncclParamCrossNic(); graph->crossNic = ncclParamCrossNic();
int crossNic = (system->nodes[NET].count > 1) && graph->crossNic ? 1 : 0; int crossNic = (system->nodes[NET].count > 1) && graph->crossNic &&
(graph->pattern == NCCL_TOPO_PATTERN_RING ||
graph->pattern == NCCL_TOPO_PATTERN_BALANCED_TREE ||
graph->pattern == NCCL_TOPO_PATTERN_SPLIT_TREE) ? 1 : 0;
graph->bwIntra = graph->bwInter = 0; graph->bwIntra = graph->bwInter = 0;
graph->latencyInter = 0; graph->latencyInter = 0;
if (graph->crossNic == 2) graph->crossNic = 0; if (graph->crossNic == 2) graph->crossNic = 0;
graph->typeIntra = ngpus == 1 ? PATH_LOC : PATH_NVL; graph->typeIntra = ngpus == 1 ? PATH_LOC : PATH_NVL;
graph->typeInter = PATH_PIX; graph->typeInter = PATH_PIX;
graph->nChannels = 0; graph->nChannels = 0;
graph->sameChannels = 1; int trySameChannels = graph->pattern == NCCL_TOPO_PATTERN_NVLS ? 0 : 1;
graph->sameChannels = trySameChannels;
char* str = getenv("NCCL_GRAPH_FILE"); char* str = getenv("NCCL_GRAPH_FILE");
if (str) { if (str) {
@ -763,10 +812,16 @@ ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph
if (graph->nChannels > 0) return ncclSuccess; if (graph->nChannels > 0) return ncclSuccess;
} }
if (ngpus == 1) if (graph->pattern != NCCL_TOPO_PATTERN_RING) graph->pattern = NCCL_TOPO_PATTERN_TREE;
int ccMin; int ccMin;
NCCLCHECK(ncclTopoGetCompCap(system, &ccMin, NULL)); NCCLCHECK(ncclTopoGetCompCap(system, &ccMin, NULL));
if (graph->pattern == NCCL_TOPO_PATTERN_NVLS && (system->nodes[NVS].count == 0 || ccMin < 90)) return ncclSuccess;
if (ngpus == 1) if (graph->pattern != NCCL_TOPO_PATTERN_RING) graph->pattern = NCCL_TOPO_PATTERN_TREE;
if (system->nodes[NET].count == 0 && graph->pattern == NCCL_TOPO_PATTERN_NVLS) {
// Force intra-node NVLS algorithm to pull evenly from all GPUs.
graph->minChannels = graph->maxChannels = system->nodes[GPU].count;
}
struct ncclTopoGraph tmpGraph; struct ncclTopoGraph tmpGraph;
memcpy(&tmpGraph, graph, sizeof(struct ncclTopoGraph)); memcpy(&tmpGraph, graph, sizeof(struct ncclTopoGraph));
@ -783,7 +838,9 @@ ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph
} }
int pass = 1; int pass = 1;
int speedIndex = 0; int speedIndex = 0;
while (speedArray[speedIndex] > system->maxBw && speedIndex < nspeeds-1) speedIndex++; float maxBw = system->maxBw;
if (system->nodes[NET].count == 0 && graph->pattern == NCCL_TOPO_PATTERN_NVLS) maxBw /= ngpus; // We want all GPUs to pull the same BW
while (speedArray[speedIndex] > maxBw && speedIndex < nspeeds-1) speedIndex++;
tmpGraph.bwIntra = tmpGraph.bwInter = speedArray[speedIndex]; tmpGraph.bwIntra = tmpGraph.bwInter = speedArray[speedIndex];
int64_t globalTimeout = NCCL_SEARCH_GLOBAL_TIMEOUT; int64_t globalTimeout = NCCL_SEARCH_GLOBAL_TIMEOUT;
@ -817,7 +874,7 @@ search:
tmpGraph.sameChannels = 0; tmpGraph.sameChannels = 0;
goto search; goto search;
} }
tmpGraph.sameChannels = 1; tmpGraph.sameChannels = trySameChannels;
if (time != -1) globalTimeout += time; if (time != -1) globalTimeout += time;
else globalTimeout = NCCL_SEARCH_GLOBAL_TIMEOUT; else globalTimeout = NCCL_SEARCH_GLOBAL_TIMEOUT;
@ -856,7 +913,7 @@ search:
goto search; goto search;
} }
speedIndex = 0; speedIndex = 0;
while (speedArray[speedIndex] > system->maxBw && speedIndex < nspeeds-1) speedIndex++; while (speedArray[speedIndex] > maxBw && speedIndex < nspeeds-1) speedIndex++;
tmpGraph.bwIntra = tmpGraph.bwInter = speedArray[speedIndex]; tmpGraph.bwIntra = tmpGraph.bwInter = speedArray[speedIndex];
} }
@ -885,7 +942,7 @@ done:
memcpy(&tmpGraph, graph, sizeof(tmpGraph)); memcpy(&tmpGraph, graph, sizeof(tmpGraph));
} }
if (graph->nChannels == 0 && graph->collNet == 0) { if (graph->nChannels == 0 && graph->collNet == 0 && graph->pattern != NCCL_TOPO_PATTERN_NVLS) {
WARN("Could not find a path for pattern %d, falling back to simple order", graph->pattern); WARN("Could not find a path for pattern %d, falling back to simple order", graph->pattern);
for (int i=0; i<ngpus; i++) graph->intra[i] = system->nodes[GPU].nodes[i].gpu.rank; for (int i=0; i<ngpus; i++) graph->intra[i] = system->nodes[GPU].nodes[i].gpu.rank;
graph->inter[0] = graph->inter[1] = 0; graph->inter[0] = graph->inter[1] = 0;
@ -894,7 +951,7 @@ done:
graph->nChannels = 1; graph->nChannels = 1;
} }
if ((ccMin <= 80 && graph->bwIntra >= 25.0) || (ccMin <= 90 && graph->bwIntra >= 50.0)) { if (graph->pattern != NCCL_TOPO_PATTERN_NVLS && ((ccMin <= 80 && graph->bwIntra >= 25.0) || (ccMin <= 90 && graph->bwIntra >= 50.0))) {
int dupChannels = std::min(graph->nChannels*2, graph->maxChannels); int dupChannels = std::min(graph->nChannels*2, graph->maxChannels);
memcpy(graph->intra+graph->nChannels*ngpus, graph->intra, (dupChannels-graph->nChannels)*ngpus*sizeof(int)); memcpy(graph->intra+graph->nChannels*ngpus, graph->intra, (dupChannels-graph->nChannels)*ngpus*sizeof(int));
memcpy(graph->inter+graph->nChannels*2,graph->inter, (dupChannels-graph->nChannels)*2*sizeof(int)); memcpy(graph->inter+graph->nChannels*2,graph->inter, (dupChannels-graph->nChannels)*2*sizeof(int));
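Worked example of the duplication rule above: a ring graph found with 8 channels at bwIntra = 60 GB/s on Hopper (ccMin 90) passes the second clause, so its intra/inter arrays are copied out to min(16, maxChannels) channels; NVLS graphs are now explicitly excluded from this doubling.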
@ -943,23 +1000,40 @@ ncclResult_t ncclTopoDumpGraphs(struct ncclTopoSystem* system, int ngraphs, stru
return ncclSuccess; return ncclSuccess;
} }
#include "comm.h"
// NVLS channels aren't compute channels. Find which NIC corresponds to our rank being the head
ncclResult_t getNvlsNetDev(struct ncclComm* comm, struct ncclTopoGraph* graph, int* dev) {
int localRanks = comm->topo->nodes[GPU].count;
for (int c=0; c<graph->nChannels; c++) {
if (graph->intra[c*localRanks] == comm->rank) {
*dev = graph->inter[c*2];
return ncclSuccess;
}
}
WARN("Could not find NIC for rank %d in NVLS graph\n", comm->rank);
return ncclInternalError;
}
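For the NVLS graph, graph->intra stores one head GPU per channel (at intra[c*localRanks]) and graph->inter stores two NIC ids per channel; the lookup above can be read as the following standalone sketch (hypothetical helper, not part of the diff):
static int findNvlsNet(const int* intra, const int* inter, int nChannels,
                       int localRanks, int myRank) {
  for (int c = 0; c < nChannels; c++) {
    if (intra[c*localRanks] == myRank)   // this rank is the head of channel c
      return inter[c*2];                 // use the first NIC recorded for that channel
  }
  return -1;                             // not a head on any channel
}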
// 0: don't use PXN for P2P, 1: use PXN if needed, 2: use PXN as much as possible to maximize aggregation // 0: don't use PXN for P2P, 1: use PXN if needed, 2: use PXN as much as possible to maximize aggregation
NCCL_PARAM(P2pPxnLevel, "P2P_PXN_LEVEL", 2); NCCL_PARAM(P2pPxnLevel, "P2P_PXN_LEVEL", 2);
#include "comm.h"
ncclResult_t ncclTopoGetNetDev(struct ncclComm* comm, int rank, struct ncclTopoGraph* graph, int channelId, int peerRank, int* dev, int* proxyRank) { ncclResult_t ncclTopoGetNetDev(struct ncclComm* comm, int rank, struct ncclTopoGraph* graph, int channelId, int peerRank, int* dev, int* proxyRank) {
if (graph) { if (graph) {
// Honor the net device in the graph // Honor the net device in the graph
int channel = channelId%graph->nChannels; int channel = channelId%graph->nChannels;
int ngpus = comm->topo->nodes[GPU].count; int ngpus = comm->topo->nodes[GPU].count;
int index = graph->intra[channel*ngpus] == rank ? 0 : 1; int index = graph->intra[channel*ngpus] == rank ? 0 : 1;
*dev = graph->inter[channel*2+index]; if (graph->pattern != NCCL_TOPO_PATTERN_NVLS) {
*dev = graph->inter[channel*2+index];
} else {
NCCLCHECK(getNvlsNetDev(comm, graph, dev));
}
NCCLCHECK(ncclTopoGetIntermediateRank(comm->topo, rank, *dev, proxyRank)); NCCLCHECK(ncclTopoGetIntermediateRank(comm->topo, rank, *dev, proxyRank));
} else if (peerRank == -1) { } else if (peerRank == -1) {
return ncclInternalError; return ncclInternalError;
} else { } else {
// Start with our local NIC and local Rank // Start with our local NIC and local Rank
NCCLCHECK(ncclTopoGetLocalNet(comm->topo, rank, dev)); NCCLCHECK(ncclTopoGetLocalNet(comm->topo, rank, channelId, dev));
*proxyRank = rank; *proxyRank = rank;
int pxnLevel = ncclPxnDisable(comm) == 1 ? 0 : ncclParamP2pPxnLevel(); int pxnLevel = ncclPxnDisable(comm) == 1 ? 0 : ncclParamP2pPxnLevel();
@ -969,7 +1043,9 @@ ncclResult_t ncclTopoGetNetDev(struct ncclComm* comm, int rank, struct ncclTopoG
int cudaDev = comm->peerInfo[peerRank].cudaDev; int cudaDev = comm->peerInfo[peerRank].cudaDev;
int localRank; int localRank;
if (ncclTopoDevToRank(comm->topo, cudaDev, &localRank) != ncclSuccess) return ncclSuccess; if (ncclTopoDevToRank(comm->topo, cudaDev, &localRank) != ncclSuccess) return ncclSuccess;
int netDev = comm->peerInfo[localRank].netDev; int netDev;
NCCLCHECK(ncclTopoGetLocalNet(comm->topo, localRank, channelId, &netDev));
int n; int n;
// Check that device exists on our node // Check that device exists on our node
if (ncclParamCrossNic() == 0) { if (ncclParamCrossNic() == 0) {
@ -989,20 +1065,17 @@ ncclResult_t ncclTopoGetNetDev(struct ncclComm* comm, int rank, struct ncclTopoG
NCCLCHECK(ncclTopoGetIntermediateRank(comm->topo, rank, *dev, proxyRank)); NCCLCHECK(ncclTopoGetIntermediateRank(comm->topo, rank, *dev, proxyRank));
} }
} else if (pxnLevel == 2) { } else if (pxnLevel == 2) {
// Check whether we can access it through our node-local GPU for that NIC. // Check which local GPU corresponds to that NIC and see if we can use PXN.
for (int r=0; r<comm->localRanks; r++) { int n, g1, g2;
int peerRank = comm->localRankToRank[r]; NCCLCHECK(ncclTopoIdToIndex(comm->topo, NET, netDev, &n));
if (comm->peerInfo[peerRank].netDev == netDev) { NCCLCHECK(ncclTopoRankToIndex(comm->topo, rank, &g1));
int g1, g2, n; NCCLCHECK(ncclTopoGetLocalGpu(comm->topo, netDev, &g2));
NCCLCHECK(ncclTopoRankToIndex(comm->topo, rank, &g1)); if (g2 != -1) {
NCCLCHECK(ncclTopoRankToIndex(comm->topo, peerRank, &g2)); struct ncclTopoNode* peerGpu = comm->topo->nodes[GPU].nodes+g2;
NCCLCHECK(ncclTopoIdToIndex(comm->topo, NET, netDev, &n)); if (peerGpu->paths[GPU][g1].type <= PATH_NVL && peerGpu->paths[NET][n].type <= PATH_PXB) {
struct ncclTopoNode* peerGpu = comm->topo->nodes[GPU].nodes+g2; *proxyRank = peerGpu->gpu.rank;
if (peerGpu->paths[GPU][g1].type <= PATH_NVL && peerGpu->paths[NET][n].type <= PATH_PXB) { *dev = netDev;
*proxyRank = peerRank; return ncclSuccess;
*dev = netDev;
return ncclSuccess;
}
} }
} }
} }


@ -646,11 +646,11 @@ ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** sy
} }
} }
if (netDevCount == 0) { if (netDevCount == 0) {
NCCLCHECK(ncclNetDevices(comm, &netDevCount)); NCCLCHECK(comm->ncclNet->devices(&netDevCount));
} }
for (int n=0; n<netDevCount; n++) { for (int n=0; n<netDevCount; n++) {
ncclNetProperties_t props; ncclNetProperties_t props;
NCCLCHECK(ncclNetGetProperties(comm, n, &props)); NCCLCHECK(comm->ncclNet->getProperties(n, &props));
struct ncclXmlNode* netNode; struct ncclXmlNode* netNode;
NCCLCHECK(ncclTopoFillNet(xml, props.pciPath, props.name, &netNode)); NCCLCHECK(ncclTopoFillNet(xml, props.pciPath, props.name, &netNode));
NCCLCHECK(xmlSetAttrInt(netNode, "keep", 1)); NCCLCHECK(xmlSetAttrInt(netNode, "keep", 1));
@ -679,10 +679,8 @@ ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** sy
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t ncclTopoGetLocalNet(struct ncclTopoSystem* system, int rank, int* id) { static ncclResult_t getLocalNetMask(struct ncclTopoSystem* system, int g, uint64_t* localNetMask, int* type) {
int g; int minType = PATH_DIS;
NCCLCHECK(ncclTopoRankToIndex(system, rank, &g));
int minType = PATH_SYS;
float maxBw = 0; float maxBw = 0;
int count = 0; int count = 0;
int* nets; int* nets;
@ -692,20 +690,115 @@ ncclResult_t ncclTopoGetLocalNet(struct ncclTopoSystem* system, int rank, int* i
if (path->bw > maxBw || (path->bw == maxBw && path->type < minType)) { if (path->bw > maxBw || (path->bw == maxBw && path->type < minType)) {
maxBw = path->bw; maxBw = path->bw;
minType = path->type; minType = path->type;
if (type) *type = minType;
count = 0; count = 0;
} }
if (path->bw == maxBw && path->type == minType) nets[count++] = system->nodes[NET].nodes[n].id; if (path->bw == maxBw && path->type == minType) nets[count++] = system->nodes[NET].nodes[n].id;
} }
if (count == 0) {
*id = -1; *localNetMask = 0ULL;
free(nets); for (int n=0; n<count; n++) {
if (nets[n] >= 64) return ncclInternalError;
*localNetMask |= 1ULL<<nets[n];
}
free(nets);
return ncclSuccess;
}
ncclResult_t ncclTopoGetLocalNet(struct ncclTopoSystem* system, int rank, int channelId, int* id) {
uint64_t* localNetMasks;
int ngpus = system->nodes[GPU].count;
NCCLCHECK(ncclCalloc(&localNetMasks, ngpus));
// Fill localNetMasks for all GPUs.
for (int g=0; g<ngpus; g++) {
NCCLCHECK(getLocalNetMask(system, g, localNetMasks+g, NULL));
}
// Find GPUs which have the same mask as rank, i.e. share the same local Nets.
int gpu;
NCCLCHECK(ncclTopoRankToIndex(system, rank, &gpu));
int netLocalGpus = 0, netLocalGpu = 0;
for (int g=0; g<ngpus; g++) {
if (localNetMasks[g] == localNetMasks[gpu]) {
if (g == gpu) netLocalGpu = netLocalGpus;
netLocalGpus++;
}
}
uint64_t localNetMask = localNetMasks[gpu];
free(localNetMasks);
if (localNetMask == 0) return ncclInternalError;
// Round robin on GPUs and channels
int gIndex = 0, cId = 0, n = 0;
while (1) {
if (1ULL << n & localNetMask) {
if (gIndex == netLocalGpu && cId == channelId) {
*id = n;
return ncclSuccess;
}
gIndex++;
if (gIndex == netLocalGpus) {
gIndex = 0;
cId++;
}
}
n = (n+1) % 64;
}
}
ncclResult_t ncclTopoGetLocalGpu(struct ncclTopoSystem* system, int net, int* gpuIndex) {
int ngpus = system->nodes[GPU].count;
int* gpus;
NCCLCHECK(ncclCalloc(&gpus, ngpus));
// Find localNetMask which includes net with the most local GPUs.
int netLocalGpus = 0, minType = PATH_DIS;
uint64_t localNetMask = 0ULL;
for (int g=0; g<ngpus; g++) {
int type = PATH_DIS;
uint64_t mask;
NCCLCHECK(getLocalNetMask(system, g, &mask, &type));
if ((1ULL<<net) & mask) {
if (type < minType) {
localNetMask = mask;
netLocalGpus = 0;
minType = type;
}
if (type == minType) {
if (localNetMask && mask != localNetMask) {
WARN("Gpus %d and %d both have a type of %d with net %d yet have different netMasks of %lx and %lx\n", g, gpus[netLocalGpus-1], minType, net, mask, localNetMask);
free(gpus);
return ncclInternalError;
}
gpus[netLocalGpus] = g;
netLocalGpus++;
}
}
}
if (localNetMask == 0ULL) {
*gpuIndex = -1;
free(gpus);
return ncclSuccess; return ncclSuccess;
} }
int rr = system->nodes[GPU].nodes[g].gpu.dev; // Round robin on GPUs and channels
*id = nets[rr%count]; int gIndex = 0, cId = 0, n = 0;
free(nets); while (1) {
return ncclSuccess; if (1ULL << n & localNetMask) {
if (n == net) {
*gpuIndex = gpus[gIndex];
free(gpus);
return ncclSuccess;
}
gIndex++;
if (gIndex == netLocalGpus) {
gIndex = 0;
cId++;
}
}
n = (n+1) % 64;
}
} }
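The two functions above share the same walk: the set bits of a 64-bit mask list the NICs local to a group of GPUs, and (GPU, channel) pairs are dealt across those bits round-robin so every NIC gets used. A standalone sketch of that walk (hypothetical helper, same logic as above):
#include <stdint.h>
static int pickNet(uint64_t netMask, int netLocalGpus, int netLocalGpu, int channelId) {
  int gIndex = 0, cId = 0;
  for (int n = 0; ; n = (n+1) % 64) {
    if ((1ULL << n) & netMask) {
      if (gIndex == netLocalGpu && cId == channelId) return n;
      if (++gIndex == netLocalGpus) { gIndex = 0; cId++; }
    }
  }
}
// With one GPU seeing two NICs (netMask = 0x3, netLocalGpus = 1):
//   channel 0 -> NIC0, channel 1 -> NIC1, channel 2 -> NIC0, ...
// With two GPUs sharing those NICs, GPU0 keeps NIC0 and GPU1 keeps NIC1 on every channel.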
/****************************/ /****************************/
@ -785,6 +878,11 @@ ncclResult_t ncclTopoGetCpuAffinity(struct ncclTopoSystem* system, int rank, cpu
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t ncclTopoGetGpuCount(struct ncclTopoSystem* system, int* count) {
*count = system->nodes[GPU].count;
return ncclSuccess;
}
ncclResult_t ncclTopoGetNetCount(struct ncclTopoSystem* system, int* count) { ncclResult_t ncclTopoGetNetCount(struct ncclTopoSystem* system, int* count) {
*count = system->nodes[NET].count; *count = system->nodes[NET].count;
return ncclSuccess; return ncclSuccess;


@ -12,12 +12,13 @@
#define LOC_BW 5000.0 #define LOC_BW 5000.0
#define SM60_NVLINK_BW 18.0 #define SM60_NVLINK_BW 18.0
#define SM70_NVLINK_BW 22.0 #define SM70_NVLINK_BW 20.0
#define SM80_NVLINK_BW 22.0 #define SM80_NVLINK_BW 20.0
#define SM90_NVLINK_BW 20.0
#define SM86_NVLINK_BW 12.0 #define SM86_NVLINK_BW 12.0
#define PCI_BW 12.0 // PCI Gen3 x16 #define PCI_BW 12.0 // PCI Gen3 x16
#define QPI_BW 6.0 #define QPI_BW 6.0
#define SKL_QPI_BW 9.0 #define SKL_QPI_BW 10.0
#define ZPI_BW 6.0 #define ZPI_BW 6.0
#define YONGFENG_ZPI_BW 9.0 #define YONGFENG_ZPI_BW 9.0
#define P9_BW 32.0 #define P9_BW 32.0
@ -72,7 +73,12 @@ extern const char* topoLinkTypeStr[];
// Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI) // Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
#define PATH_SYS 7 #define PATH_SYS 7
#define PATH_DIS 7
// Connection through the network
#define PATH_NET 8
// Disconnected
#define PATH_DIS 9
extern const char* topoPathTypeStr[]; extern const char* topoPathTypeStr[];
struct ncclTopoNode; struct ncclTopoNode;
@ -195,6 +201,7 @@ static ncclResult_t ncclTopoDevToRank(struct ncclTopoSystem* system, int dev, in
// Returns NVLink bw in GB/s // Returns NVLink bw in GB/s
static float ncclTopoNVLinkBw(int cudaCompCap) { static float ncclTopoNVLinkBw(int cudaCompCap) {
return return
cudaCompCap >= 90 ? SM90_NVLINK_BW :
cudaCompCap == 86 ? SM86_NVLINK_BW : cudaCompCap == 86 ? SM86_NVLINK_BW :
cudaCompCap >= 80 ? SM80_NVLINK_BW : cudaCompCap >= 80 ? SM80_NVLINK_BW :
cudaCompCap >= 70 ? SM70_NVLINK_BW : cudaCompCap >= 70 ? SM70_NVLINK_BW :


@ -53,26 +53,30 @@ ncclResult_t parseList(const char* str, const char* elems[], int nelems, int* li
// Latencies in us, Bandwidths in GB/s // Latencies in us, Bandwidths in GB/s
// Tree { LL, LL128, Simple } , Ring { LL, LL128, Simple } // Tree { LL, LL128, Simple } , Ring { LL, LL128, Simple }
static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { { 4.4, 4.4, 0 }, { 3.6, 10.0, 8.4 }, { 4.4, 4.4, 0 }, { 4.4, 4.4, 0 }, { 0, 0, 40.0 }}; static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = {
{ 6.8, 14.0, 0 }, { 6.6, 14.0, 8.4 }, // Tree, Ring
{ 6.8, 14.0, 0 }, { 6.8, 14.0, 0 }, // Collnet Direct, Chain
{ 0, 0, 23.0 }, { 0, 0, 23.0 }}; // NVLS, NVLS Tree
// NVLink, PCI, Network // NVLink, PCI, Network
#define NCCL_HW_NVLINK 0 #define NCCL_HW_NVLINK 0
#define NCCL_HW_PCI 1 #define NCCL_HW_PCI 1
#define NCCL_HW_NET 2 #define NCCL_HW_NET 2
// Tree/Simple is the latency of a 256kB chunk, which is ~ base lat + 256k/12GB/s (+ 256k/12GB/s for the network). // Tree/Simple is the latency of a 256kB chunk, which is ~ base lat + 256k/12GB/s (+ 256k/12GB/s for the network).
// Ring/LL128 reflects the latency for the second plateau, not the base latency.
static float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = static float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] =
{ /* NVLINK */ { /* NVLINK */
{ /* Tree (LL/LL128/Simple)*/ { .52, 1.25, 28 }, /* Ring (LL/LL128/Simple)*/ { .47, 1.9, 3.4 }, { /* Tree (LL/LL128/Simple)*/ { .6, 1.25, 28 }, /* Ring (LL/LL128/Simple)*/ { .6, 1.9, 3.4 },
/* CollNetDirect (Simple)*/ { 0, 0, 8.0 }, /* CollNetChain (Simple)*/ { 0, 0, 8.0 }, /* CollNetDirect (Simple)*/ { 0, 0, 8.0 }, /* CollNetChain (Simple)*/ { 0, 0, 4.75 },
/* NVLS */ { 0, 0, 0 } }, /* NVLS */ { 0, 0, 0 }, /* NVLSTree */ { 0, 0, 0 } },
/* PCI */ /* PCI */
{ /* Tree (LL/LL128/Simple)*/ { 1.0, 1.9, 28 }, /* Ring (LL/LL128/Simple)*/ { 1.0, 2.5, 5.7 }, { /* Tree (LL/LL128/Simple)*/ { 1.0, 1.9, 28 }, /* Ring (LL/LL128/Simple)*/ { 1.0, 2.5, 5.7 },
/* CollNetDirect (Simple)*/ { 0, 0, 8.0 }, /* CollNetChain (Simple)*/ { 0, 0, 8.0 }, /* CollNetDirect (Simple)*/ { 0, 0, 8.0 }, /* CollNetChain (Simple)*/ { 0, 0, 8.0 },
/* NVLS */ { 0, 0, 0 } }, /* NVLS */ { 0, 0, 0 }, /* NVLSTree */ { 0, 0, 0 } },
/* NET */ /* NET */
{ /* Tree (LL/LL128/Simple)*/ { 5.0, 8.5, 28 }, /* Ring (LL/LL128/Simple)*/ { 2.7, 4.0, 9.6 }, { /* Tree (LL/LL128/Simple)*/ { 5.0, 8.5, 28 }, /* Ring (LL/LL128/Simple)*/ { 2.7, 4.0, 14.0 },
/* CollNetDirect (Simple)*/ { 0, 0, 10.7 }, /* CollNetChain (Simple)*/ { 0, 0, 10.7 }, /* CollNetDirect (Simple)*/ { 0, 0, 10.7 }, /* CollNetChain (Simple)*/ { 0, 0, 14 },
/* NVLS */ { 0, 0, 0 } } /* NVLS */ { 0, 0, 18 }, /* NVLSTree */ { 0, 0, 19 } }
}; };
/* Array indexes used below */ /* Array indexes used below */
@ -94,15 +98,28 @@ static const double perChMaxTreeBws[3][3] = {
/* Hopper (N1/N2/N4) */ {38.7, 41.4, 33.0}, /* Hopper (N1/N2/N4) */ {38.7, 41.4, 33.0},
}; };
ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph) { // Network post overhead in ns (1000 = 1 us)
int simpleDefaultThreads = (ringGraph->bwIntra*ringGraph->nChannels <= PCI_BW) ? 256 : NCCL_SIMPLE_MAX_NTHREADS; NCCL_PARAM(NetOverhead, "NET_OVERHEAD", -2);
static float getNetOverhead(struct ncclComm* comm) {
if (ncclParamNetOverhead() != -2) return ncclParamNetOverhead() * .001;
int cpuArch, cpuVendor, cpuModel;
NCCLCHECK(ncclTopoCpuType(comm->topo, &cpuArch, &cpuVendor, &cpuModel));
if (cpuArch == NCCL_TOPO_CPU_ARCH_X86 && cpuVendor == NCCL_TOPO_CPU_VENDOR_INTEL) return 1.0;
if (cpuArch == NCCL_TOPO_CPU_ARCH_X86 && cpuVendor == NCCL_TOPO_CPU_VENDOR_AMD) return 2.0;
else return 1.0;
}
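For reference, the override above is read in nanoseconds: NCCL_NET_OVERHEAD=3000 makes getNetOverhead() return 3.0 (us), while the default (-2) resolves to 1.0 us on Intel x86, 2.0 us on AMD x86, and 1.0 us otherwise.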
ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph** graphs) {
int simpleDefaultThreads = (graphs[NCCL_ALGO_RING]->bwIntra*graphs[NCCL_ALGO_RING]->nChannels <= PCI_BW) ? 256 : NCCL_SIMPLE_MAX_NTHREADS;
comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_SIMPLE] = comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_SIMPLE] =
getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_SIMPLE_MAX_NTHREADS, simpleDefaultThreads); getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_SIMPLE_MAX_NTHREADS, simpleDefaultThreads);
comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] =
getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_SIMPLE_MAX_NTHREADS, NCCL_SIMPLE_MAX_NTHREADS); getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_SIMPLE_MAX_NTHREADS, NCCL_SIMPLE_MAX_NTHREADS);
comm->maxThreads[NCCL_ALGO_COLLNET_DIRECT][NCCL_PROTO_SIMPLE] = comm->maxThreads[NCCL_ALGO_COLLNET_DIRECT][NCCL_PROTO_SIMPLE] =
comm->maxThreads[NCCL_ALGO_COLLNET_CHAIN][NCCL_PROTO_SIMPLE] = comm->maxThreads[NCCL_ALGO_COLLNET_CHAIN][NCCL_PROTO_SIMPLE] =
comm->maxThreads[NCCL_ALGO_NVLS][NCCL_PROTO_SIMPLE] = NCCL_SIMPLE_MAX_NTHREADS; comm->maxThreads[NCCL_ALGO_NVLS][NCCL_PROTO_SIMPLE] =
comm->maxThreads[NCCL_ALGO_NVLS_TREE][NCCL_PROTO_SIMPLE] = NCCL_MAX_NTHREADS;
comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_LL] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_LL] = comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_LL] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_LL] =
getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_LL_MAX_NTHREADS, NCCL_LL_MAX_NTHREADS); getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_LL_MAX_NTHREADS, NCCL_LL_MAX_NTHREADS);
comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_LL128] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_LL128] = comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_LL128] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_LL128] =
@ -124,7 +141,6 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
if (cpuArch == NCCL_TOPO_CPU_ARCH_POWER) hwLat[NCCL_HW_PCI][NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = hwLat[NCCL_HW_PCI][NCCL_ALGO_RING][NCCL_PROTO_SIMPLE]; if (cpuArch == NCCL_TOPO_CPU_ARCH_POWER) hwLat[NCCL_HW_PCI][NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = hwLat[NCCL_HW_PCI][NCCL_ALGO_RING][NCCL_PROTO_SIMPLE];
float ppn = (float)nRanks / nNodes; // if ppn < 2, then we are sending/receiving at the same GPU through the NIC, apply some bw discount float ppn = (float)nRanks / nNodes; // if ppn < 2, then we are sending/receiving at the same GPU through the NIC, apply some bw discount
struct ncclTopoGraph* graphs[NCCL_NUM_ALGORITHMS] = { treeGraph, ringGraph, collNetGraph, collNetGraph, ringGraph/* we only need the NVSwitch speed for NVLS*/ };
int intraHw[NCCL_NUM_ALGORITHMS], hw[NCCL_NUM_ALGORITHMS]; int intraHw[NCCL_NUM_ALGORITHMS], hw[NCCL_NUM_ALGORITHMS];
for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) intraHw[a] = graphs[a]->typeIntra == LINK_NVL ? NCCL_HW_NVLINK : NCCL_HW_PCI; for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) intraHw[a] = graphs[a]->typeIntra == LINK_NVL ? NCCL_HW_NVLINK : NCCL_HW_PCI;
for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) hw[a] = nNodes == 1 ? intraHw[a] : NCCL_HW_NET; for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) hw[a] = nNodes == 1 ? intraHw[a] : NCCL_HW_NET;
@ -140,18 +156,16 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) { for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) {
if (coll == ncclFuncBroadcast && a != NCCL_ALGO_RING) continue; if (coll == ncclFuncBroadcast && a != NCCL_ALGO_RING) continue;
if (coll == ncclFuncReduce && a != NCCL_ALGO_RING) continue; if (coll == ncclFuncReduce && a != NCCL_ALGO_RING) continue;
if (coll == ncclFuncReduceScatter && a != NCCL_ALGO_RING && a != NCCL_ALGO_NVLS) continue; if (coll == ncclFuncReduceScatter && a != NCCL_ALGO_RING) continue;
if (coll == ncclFuncAllGather && a != NCCL_ALGO_RING && a != NCCL_ALGO_NVLS) continue; if (coll == ncclFuncAllGather && a != NCCL_ALGO_RING) continue;
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
if (a == NCCL_ALGO_NVLS && p != NCCL_PROTO_SIMPLE) continue; if ((a == NCCL_ALGO_NVLS || a == NCCL_ALGO_NVLS_TREE) && p != NCCL_PROTO_SIMPLE) continue;
int collnet = (a == NCCL_ALGO_COLLNET_DIRECT || a == NCCL_ALGO_COLLNET_CHAIN) ? 1 : 0; int collnet = (a == NCCL_ALGO_COLLNET_DIRECT || a == NCCL_ALGO_COLLNET_CHAIN) ? 1 : 0;
float bw = nNodes <= 2 || collnet ? graphs[a]->bwIntra : graphs[a]->bwInter; float bw = nNodes <= 2 || collnet ? graphs[a]->bwIntra : graphs[a]->bwInter;
float busBw = graphs[a]->nChannels * bw; float busBw = graphs[a]->nChannels * bw;
// Various model refinements // Various model refinements
if (compCapIndex == AMPERE_COMPCAP_IDX) busBw = std::min(busBw, 235.0f);
if (compCapIndex == HOPPER_COMPCAP_IDX) busBw = std::min(busBw, 370.0f);
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL) { busBw = std::min(llMaxBw, busBw * ((nNodes > 1 || coll == ncclFuncAllReduce || coll == ncclFuncReduce) ? 1.0/4.0 : 1.0/3.0)); } if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL) { busBw = std::min(llMaxBw, busBw * ((nNodes > 1 || coll == ncclFuncAllReduce || coll == ncclFuncReduce) ? 1.0/4.0 : 1.0/3.0)); }
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh[compCapIndex]*graphs[a]->nChannels); if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh[compCapIndex]*graphs[a]->nChannels);
if (a == NCCL_ALGO_TREE) busBw = std::min(busBw*.92, graphs[a]->nChannels*perChMaxTreeBw); if (a == NCCL_ALGO_TREE) busBw = std::min(busBw*.92, graphs[a]->nChannels*perChMaxTreeBw);
@ -165,30 +179,39 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
factor -= (factor-1)/2; factor -= (factor-1)/2;
busBw /= factor; busBw /= factor;
} }
if (a == NCCL_ALGO_COLLNET_CHAIN && p == NCCL_PROTO_SIMPLE) busBw *= .75; if (a == NCCL_ALGO_COLLNET_DIRECT && p == NCCL_PROTO_SIMPLE && minCompCap >= 90) busBw *= .85;
// Convert bus BW to algorithm BW // Convert bus BW to algorithm BW
float ratio; float ratio;
if (a == NCCL_ALGO_RING) ratio = (1.0 * nRanks) / nsteps; if (a == NCCL_ALGO_RING) ratio = (1.0 * nRanks) / nsteps;
else if (a == NCCL_ALGO_NVLS) ratio = .75; else if (a == NCCL_ALGO_NVLS) ratio = .75;
else if (a == NCCL_ALGO_NVLS_TREE) ratio = .70 * nNodes / (2*(nNodes-1));
else ratio = .5; else ratio = .5;
comm->bandwidths[coll][a][p] = busBw * ratio; comm->bandwidths[coll][a][p] = busBw * ratio;
comm->latencies[coll][a][p] = baseLat[a][p]; comm->latencies[coll][a][p] = baseLat[a][p];
float intraLat = hwLat[intraHw[a]][a][p]; float intraLat = hwLat[intraHw[a]][a][p];
float interLat = graphs[a]->latencyInter ? graphs[a]->latencyInter : hwLat[NCCL_HW_NET][a][p]; float interLat = hwLat[NCCL_HW_NET][a][p] + graphs[a]->latencyInter;
// Also add the flush extra latency
if (p == NCCL_PROTO_SIMPLE) interLat += graphs[a]->latencyInter;
if (nNodes > 1 && p == NCCL_PROTO_LL) intraLat *= 1.8;
if (a == NCCL_ALGO_RING) { if (a == NCCL_ALGO_RING) {
float lat = hwLat[hw[a]][a][p]; float lat = hwLat[hw[a]][a][p];
if ((coll == ncclFuncReduce || coll == ncclFuncBroadcast)) { if ((coll == ncclFuncReduce || coll == ncclFuncBroadcast)) {
if (ringGraph->sameChannels) { if (graphs[a]->sameChannels) {
comm->latencies[coll][a][p] += lat; comm->latencies[coll][a][p] += lat;
} else { } else {
if (p == NCCL_PROTO_SIMPLE) lat = hwLat[hw[a]][NCCL_ALGO_TREE][p]; // Add some chunk latency, waiting for proper chunk modeling if (p == NCCL_PROTO_SIMPLE) lat = hwLat[hw[a]][NCCL_ALGO_TREE][p]; // Add some chunk latency, waiting for proper chunk modeling
comm->latencies[coll][a][p] += nsteps*lat; comm->latencies[coll][a][p] += nsteps*lat;
} }
} else { } else {
// Inter-node rings still have to launch nsteps * net overhead.
float netOverhead = 0.0;
if (nNodes > 1) {
netOverhead = getNetOverhead(comm);
if (p == NCCL_PROTO_SIMPLE) netOverhead *= 3;
}
intraLat = std::max(intraLat, netOverhead);
comm->latencies[coll][a][p] += (nsteps-nInterSteps)*intraLat + nInterSteps*interLat; comm->latencies[coll][a][p] += (nsteps-nInterSteps)*intraLat + nInterSteps*interLat;
} }
} else if (a == NCCL_ALGO_TREE) { } else if (a == NCCL_ALGO_TREE) {
@ -198,7 +221,11 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
comm->latencies[coll][a][p] += comm->latencies[coll][a][p] +=
2 * (std::min(1, (nRanks/nNodes-1)) * intraLat + (nRanks/nNodes-1) * 0.5) + interLat; // Add 0.5 arity serialization latency 2 * (std::min(1, (nRanks/nNodes-1)) * intraLat + (nRanks/nNodes-1) * 0.5) + interLat; // Add 0.5 arity serialization latency
} else if (a == NCCL_ALGO_COLLNET_CHAIN) { } else if (a == NCCL_ALGO_COLLNET_CHAIN) {
comm->latencies[coll][a][p] += 2 * (nRanks/nNodes-1) * intraLat; comm->latencies[coll][a][p] += 2 * (nRanks/nNodes-1) * intraLat + interLat;
} else if (a == NCCL_ALGO_NVLS) {
if (nNodes > 1) comm->latencies[coll][a][p] += hwLat[NCCL_HW_NET][a][p];
} else if (a == NCCL_ALGO_NVLS_TREE) {
comm->latencies[coll][a][p] += 2*(nNodes-1)*hwLat[NCCL_HW_NET][a][p];
} }
} }
} }
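Worked reading of the two new inter-node terms, using the NET row of hwLat above (Simple protocol: 18 us for NVLS, 19 us for NVLS_TREE): on 4 nodes, NVLS adds a single 18 us network term, while NVLS_TREE adds 2*(4-1)*19 = 114 us for the walk up and down the inter-node tree, both on top of the 23 us base latency.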
@ -207,7 +234,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
// Protocols/Algorithms enable/disable, and user overrides. // Protocols/Algorithms enable/disable, and user overrides.
// All are enabled except ll128 which is enabled by default only in certain cases. // All are enabled except ll128 which is enabled by default only in certain cases.
int protoEnable[NCCL_NUM_PROTOCOLS] = { 1, 2, 1 }; int protoEnable[NCCL_NUM_PROTOCOLS] = { 1, 2, 1 };
int algoEnable[NCCL_NUM_ALGORITHMS] = { 1, 1, 1, 1, 1 }; int algoEnable[NCCL_NUM_ALGORITHMS] = { 1, 1, 1, 1, 1, 1 };
const char *protoStr = getenv("NCCL_PROTO"); const char *protoStr = getenv("NCCL_PROTO");
if (protoStr) { if (protoStr) {
@ -220,15 +247,16 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
NCCLCHECK(parseList(algoStr, ncclAlgoStr, NCCL_NUM_ALGORITHMS, algoEnable)); NCCLCHECK(parseList(algoStr, ncclAlgoStr, NCCL_NUM_ALGORITHMS, algoEnable));
} }
// Disable NVLink SHARP if not supported if (comm->nNodes == 1) algoEnable[NCCL_ALGO_NVLS_TREE] = 0;
if (comm->nvlsSupport == 0 /* || comm->localRanks <= 2*/) algoEnable[NCCL_ALGO_NVLS] = 0;
// Disable CollNet if it is not supported // Disable CollNet if it is not supported
if (comm->collNetSupport == 0) { if (comm->collNetSupport == 0) {
algoEnable[NCCL_ALGO_COLLNET_DIRECT] = 0; algoEnable[NCCL_ALGO_COLLNET_DIRECT] = 0;
algoEnable[NCCL_ALGO_COLLNET_CHAIN] = 0; algoEnable[NCCL_ALGO_COLLNET_CHAIN] = 0;
if (comm->nNodes > 1) algoEnable[NCCL_ALGO_NVLS] = 0;
// If user has hard set NCCL_ALGO=COLLNET, ignore it // If user has hard set NCCL_ALGO=COLLNET, ignore it
if (algoEnable[NCCL_ALGO_RING] == 0 && algoEnable[NCCL_ALGO_TREE] == 0) { if (algoEnable[NCCL_ALGO_RING] == 0 && algoEnable[NCCL_ALGO_TREE] == 0 &&
algoEnable[NCCL_ALGO_NVLS] == 0 && algoEnable[NCCL_ALGO_NVLS_TREE] == 0) {
algoEnable[NCCL_ALGO_RING] = algoEnable[NCCL_ALGO_TREE] = 1; algoEnable[NCCL_ALGO_RING] = algoEnable[NCCL_ALGO_TREE] = 1;
if (comm->rank == 0) WARN("CollNet is not supported or fails to initialize, ignoring NCCL_ALGO=COLLNET"); if (comm->rank == 0) WARN("CollNet is not supported or fails to initialize, ignoring NCCL_ALGO=COLLNET");
} }
@ -262,28 +290,38 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
if (comm->rank == 0) { if (comm->rank == 0) {
char line[1024]; char line[1024];
sprintf(line, "Latency/AlgBw |"); for (int block=0; block<2; block++) {
for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) { sprintf(line, " Algorithm |");
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int ba=0; ba<NCCL_NUM_ALGORITHMS/2; ba++) {
sprintf(line+strlen(line), " %7s/%6s |", ncclAlgoStr[a], ncclProtoStr[p]); int a = block*NCCL_NUM_ALGORITHMS/2+ba;
sprintf(line+strlen(line), " %14s %14s %14s |", "", ncclAlgoStr[a], "");
} }
} INFO(NCCL_TUNING, "%s", line);
INFO(NCCL_TUNING, "%s", line); sprintf(line, " Protocol |");
sprintf(line, " Max NThreads |"); for (int ba=0; ba<NCCL_NUM_ALGORITHMS/2; ba++) {
for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) {
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
sprintf(line+strlen(line), " %14d |", comm->maxThreads[a][p]);
}
}
INFO(NCCL_TUNING, "%s", line);
for (int c=0; c<NCCL_NUM_FUNCTIONS; c++) {
sprintf(line, "%13s |", ncclFuncStr[c]);
for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) {
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
sprintf(line+strlen(line), "%8.1f/%6.1f |", comm->latencies[c][a][p], comm->bandwidths[c][a][p]); sprintf(line+strlen(line), " %14s |", ncclProtoStr[p]);
} }
} }
INFO(NCCL_TUNING, "%s", line); INFO(NCCL_TUNING, "%s", line);
sprintf(line, " Max NThreads |");
for (int ba=0; ba<NCCL_NUM_ALGORITHMS/2; ba++) {
int a = block*NCCL_NUM_ALGORITHMS/2+ba;
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
sprintf(line+strlen(line), " %14d |", comm->maxThreads[a][p]);
}
}
INFO(NCCL_TUNING, "%s", line);
for (int c=0; c<NCCL_NUM_FUNCTIONS; c++) {
sprintf(line, "%13s |", ncclFuncStr[c]);
for (int ba=0; ba<NCCL_NUM_ALGORITHMS/2; ba++) {
int a = block*NCCL_NUM_ALGORITHMS/2+ba;
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
sprintf(line+strlen(line), "%8.1f/%6.1f |", comm->latencies[c][a][p], comm->bandwidths[c][a][p]);
}
}
INFO(NCCL_TUNING, "%s", line);
}
} }
} }
@ -340,8 +378,8 @@ ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int proto
if (algorithm == NCCL_ALGO_TREE && logSize < 23) bw *= treeCorrectionFactor[protocol][logSize]; if (algorithm == NCCL_ALGO_TREE && logSize < 23) bw *= treeCorrectionFactor[protocol][logSize];
if (info->nChannels != 0) bw = bw / info->comm->nChannels * info->nChannels; if (info->nChannels != 0) bw = bw / info->comm->nChannels * info->nChannels;
if (algorithm == NCCL_ALGO_RING && protocol == NCCL_PROTO_SIMPLE && info->comm->nNodes > 1 if (algorithm == NCCL_ALGO_RING && protocol == NCCL_PROTO_SIMPLE && info->comm->nNodes > 1
&& info->coll == ncclFuncAllReduce && info->nBytes >= info->comm->nRanks/16.0*65536) { && info->coll == ncclFuncAllReduce && info->nBytes/(info->comm->nChannels*info->comm->nRanks) >= 64) {
lat *= info->comm->minCompCap < 90 ? 1.9 : 1.5; // Plateau effect of ring lat *= info->comm->minCompCap < 80 ? 1.9 : 1.4; // Plateau effect of ring
} }
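Worked example of the new threshold: the plateau factor now triggers once each rank's per-channel share reaches 64 bytes, e.g. 64 x 16 channels x 64 ranks = 64 KiB total for a 64-rank, 16-channel allreduce, where the old cutoff was nRanks/16 x 64 KiB = 256 KiB.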
// Tree pipelining saves latency in aggregation cases // Tree pipelining saves latency in aggregation cases
int latCount = algorithm == NCCL_ALGO_RING ? numPipeOps : DIVUP(numPipeOps, NCCL_MAX_WORK_ELEMENTS); int latCount = algorithm == NCCL_ALGO_RING ? numPipeOps : DIVUP(numPipeOps, NCCL_MAX_WORK_ELEMENTS);


@ -46,8 +46,8 @@ ncclResult_t ncclAsyncLaunch(
/* check if there are blocking and nonblocking comms at the same time in group. */ /* check if there are blocking and nonblocking comms at the same time in group. */
if (ncclGroupBlocking == -1) { if (ncclGroupBlocking == -1) {
/* first met communicator */ /* first met communicator */
ncclGroupBlocking = comm->blocking; ncclGroupBlocking = comm->config.blocking;
} else if (ncclGroupBlocking != comm->blocking) { } else if (ncclGroupBlocking != comm->config.blocking) {
WARN("Blocking and nonblocking communicators are not allowed in the same group."); WARN("Blocking and nonblocking communicators are not allowed in the same group.");
ret = ncclInvalidArgument; ret = ncclInvalidArgument;
} }
@ -242,7 +242,7 @@ static void groupCleanup(struct ncclComm** groupCommHeadPtr, struct ncclComm** g
ncclIntruQueueConstruct(&comm->tasks.peers[i].recvQueue); ncclIntruQueueConstruct(&comm->tasks.peers[i].recvQueue);
} }
if (!comm->blocking) if (!comm->config.blocking)
(void) ncclCommSetAsyncError(comm, error); (void) ncclCommSetAsyncError(comm, error);
comm = next; comm = next;
} }
@ -251,7 +251,7 @@ static void groupCleanup(struct ncclComm** groupCommHeadPtr, struct ncclComm** g
while (!ncclIntruQueueEmpty(asyncJobsPtr)) { while (!ncclIntruQueueEmpty(asyncJobsPtr)) {
struct ncclAsyncJob* job = ncclIntruQueueDequeue(asyncJobsPtr); struct ncclAsyncJob* job = ncclIntruQueueDequeue(asyncJobsPtr);
*job->abortFlag = 1; *job->abortFlag = 1;
if (job->comm && !job->comm->blocking) if (job->comm && !job->comm->config.blocking)
(void) ncclCommSetAsyncError(job->comm, error); (void) ncclCommSetAsyncError(job->comm, error);
if (job->undo) job->undo(job); if (job->undo) job->undo(job);
if (job->destructor) job->destructor((void*)job); if (job->destructor) job->destructor((void*)job);
@ -346,7 +346,7 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_) {
while (!ncclIntruQueueEmpty(asyncJobsMain)) { while (!ncclIntruQueueEmpty(asyncJobsMain)) {
struct ncclAsyncJob* job = ncclIntruQueueDequeue(asyncJobsMain); struct ncclAsyncJob* job = ncclIntruQueueDequeue(asyncJobsMain);
if (job->comm && !job->comm->blocking) if (job->comm && !job->comm->config.blocking)
(void) ncclCommSetAsyncError(job->comm, ret); (void) ncclCommSetAsyncError(job->comm, ret);
if (job->destructor) job->destructor((void*)job); if (job->destructor) job->destructor((void*)job);
} }
@ -355,7 +355,7 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_) {
struct ncclComm* comm = groupCommHeadMain; struct ncclComm* comm = groupCommHeadMain;
struct ncclComm* next = comm->groupNext; struct ncclComm* next = comm->groupNext;
(void) ncclGroupCommLeave(comm); (void) ncclGroupCommLeave(comm);
if (!comm->blocking) { if (!comm->config.blocking) {
(void) ncclCommSetAsyncError(comm, ret); (void) ncclCommSetAsyncError(comm, ret);
} }
groupCommHeadMain = next; groupCommHeadMain = next;


@ -13,6 +13,9 @@
#define ROUNDUP(x, y) \ #define ROUNDUP(x, y) \
(DIVUP((x), (y))*(y)) (DIVUP((x), (y))*(y))
#define ALIGN_POWER(x, y) \
((x) > (y) ? ROUNDUP(x, y) : ((y)/((y)/(x))))
#define ALIGN_SIZE(size, align) \ #define ALIGN_SIZE(size, align) \
size = ((size + (align) - 1) / (align)) * (align); size = ((size + (align) - 1) / (align)) * (align);
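Worked evaluations of the new ALIGN_POWER macro (its second branch is integer division, so it is mainly meaningful for power-of-two operands):
ALIGN_POWER(100, 64) == ROUNDUP(100, 64) == 128   // x > y: round up to a multiple of y
ALIGN_POWER(64, 64)  == 64 / (64 / 64)   == 64    // x <= y: powers of two pass through unchanged
ALIGN_POWER(32, 64)  == 64 / (64 / 32)   == 32
ALIGN_POWER(40, 64)  == 64 / (64 / 40)   == 64    // x <= y: other values are bumped up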


@ -11,6 +11,7 @@
#include "checks.h" #include "checks.h"
#include "align.h" #include "align.h"
#include "utils.h" #include "utils.h"
#include "p2p.h"
#include <sys/mman.h> #include <sys/mman.h>
#include <unistd.h> #include <unistd.h>
#include <stdlib.h> #include <stdlib.h>
@ -72,13 +73,88 @@ ncclResult_t ncclRealloc(T** ptr, size_t oldNelem, size_t nelem) {
return ncclSuccess; return ncclSuccess;
} }
#if CUDART_VERSION >= 11030
#include <cuda.h>
#include "cudawrap.h"
static inline ncclResult_t ncclCuMemAlloc(void **ptr, CUmemGenericAllocationHandle *handlep, size_t size) {
ncclResult_t result = ncclSuccess;
size_t granularity = 0;
CUdevice currentDev;
CUmemAllocationProp prop = {};
CUmemAccessDesc accessDesc = {};
CUmemGenericAllocationHandle handle;
int cudaDev;
int flag = 0;
CUDACHECK(cudaGetDevice(&cudaDev));
CUCHECK(cuDeviceGet(&currentDev, cudaDev));
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.requestedHandleTypes = NCCL_P2P_HANDLE_TYPE; // So it can be exported
prop.location.id = currentDev;
// Query device to see if RDMA support is available
CUCHECK(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED, currentDev));
if (flag) prop.allocFlags.gpuDirectRDMACapable = 1;
CUCHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
ALIGN_SIZE(size, granularity);
/* Allocate the physical memory on the device */
CUCHECK(cuMemCreate(&handle, size, &prop, 0));
/* Reserve a virtual address range */
CUCHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, 0, 0, 0));
/* Map the virtual address range to the physical allocation */
CUCHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
/* Now allow RW access to the newly mapped memory */
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = currentDev;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CUCHECK(cuMemSetAccess((CUdeviceptr)*ptr, size, &accessDesc, 1));
if (handlep) *handlep = handle;
TRACE(NCCL_ALLOC, "CuMem Alloc Size %zi pointer %p handle %llx", size, *ptr, handle);
return result;
}
static inline ncclResult_t ncclCuMemFree(void *ptr) {
if (ptr == NULL) return ncclSuccess;
ncclResult_t result = ncclSuccess;
CUmemGenericAllocationHandle handle;
size_t size = 0;
CUCHECK(cuMemRetainAllocationHandle(&handle, ptr));
CUCHECK(cuMemRelease(handle));
CUCHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
TRACE(NCCL_ALLOC, "CuMem Free Size %zi pointer %p handle 0x%llx", size, ptr, handle);
CUCHECK(cuMemUnmap((CUdeviceptr)ptr, size));
CUCHECK(cuMemRelease(handle));
CUCHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
return result;
}
#else
extern int ncclCuMemEnable();
static inline ncclResult_t ncclCuMemAlloc(void **ptr, void *handlep, size_t size) {
WARN("CUMEM not supported prior to CUDA 11.3");
return ncclInternalError;
}
static inline ncclResult_t ncclCuMemFree(void *ptr) {
WARN("CUMEM not supported prior to CUDA 11.3");
return ncclInternalError;
}
#endif
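A minimal usage sketch of the two helpers above (assumes a current CUDA context and that cuMem support was detected at runtime):
static ncclResult_t cuMemDemo(void) {
  void* buf = NULL;
  CUmemGenericAllocationHandle handle;
  NCCLCHECK(ncclCuMemAlloc(&buf, &handle, 1 << 20)); // 1 MiB, rounded up to the allocation granularity
  // ... use buf on the device, or export the handle to a peer process ...
  NCCLCHECK(ncclCuMemFree(buf));                     // unmap, release the handle and free the VA range
  return ncclSuccess;
}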
template <typename T> template <typename T>
ncclResult_t ncclCudaMallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) { ncclResult_t ncclCudaMallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
ncclResult_t result = ncclSuccess; ncclResult_t result = ncclSuccess;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
*ptr = nullptr; *ptr = nullptr;
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
CUDACHECKGOTO(cudaMalloc(ptr, nelem*sizeof(T)), result, finish); if (ncclCuMemEnable()) {
NCCLCHECKGOTO(ncclCuMemAlloc((void **)ptr, NULL, nelem*sizeof(T)), result, finish);
} else {
CUDACHECKGOTO(cudaMalloc(ptr, nelem*sizeof(T)), result, finish);
}
finish: finish:
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
if (*ptr == nullptr) WARN("Failed to CUDA malloc %ld bytes", nelem*sizeof(T)); if (*ptr == nullptr) WARN("Failed to CUDA malloc %ld bytes", nelem*sizeof(T));
@ -96,7 +172,11 @@ ncclResult_t ncclCudaCallocDebug(T** ptr, size_t nelem, const char *filefunc, in
// Need a side stream so as not to interfere with graph capture. // Need a side stream so as not to interfere with graph capture.
cudaStream_t stream; cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECKGOTO(cudaMalloc(ptr, nelem*sizeof(T)), result, finish); if (ncclCuMemEnable()) {
NCCLCHECKGOTO(ncclCuMemAlloc((void **)ptr, NULL, nelem*sizeof(T)), result, finish);
} else {
CUDACHECKGOTO(cudaMalloc(ptr, nelem*sizeof(T)), result, finish);
}
CUDACHECKGOTO(cudaMemsetAsync(*ptr, 0, nelem*sizeof(T), stream), result, finish); CUDACHECKGOTO(cudaMemsetAsync(*ptr, 0, nelem*sizeof(T), stream), result, finish);
CUDACHECKGOTO(cudaStreamSynchronize(stream), result, finish); CUDACHECKGOTO(cudaStreamSynchronize(stream), result, finish);
CUDACHECKGOTO(cudaStreamDestroy(stream), result, finish); CUDACHECKGOTO(cudaStreamDestroy(stream), result, finish);
@ -114,7 +194,11 @@ ncclResult_t ncclCudaCallocAsyncDebug(T** ptr, size_t nelem, cudaStream_t stream
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
*ptr = nullptr; *ptr = nullptr;
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
CUDACHECKGOTO(cudaMalloc(ptr, nelem*sizeof(T)), result, finish); if (ncclCuMemEnable()) {
NCCLCHECKGOTO(ncclCuMemAlloc((void **)ptr, NULL, nelem*sizeof(T)), result, finish);
} else {
CUDACHECKGOTO(cudaMalloc(ptr, nelem*sizeof(T)), result, finish);
}
CUDACHECKGOTO(cudaMemsetAsync(*ptr, 0, nelem*sizeof(T), stream), result, finish); CUDACHECKGOTO(cudaMemsetAsync(*ptr, 0, nelem*sizeof(T), stream), result, finish);
finish: finish:
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
@ -155,8 +239,13 @@ template <typename T>
ncclResult_t ncclCudaFree(T* ptr) { ncclResult_t ncclCudaFree(T* ptr) {
ncclResult_t result = ncclSuccess; ncclResult_t result = ncclSuccess;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
TRACE(NCCL_ALLOC, "Cuda Free pointer %p", ptr);
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
CUDACHECKGOTO(cudaFree(ptr), result, finish); if (ncclCuMemEnable()) {
NCCLCHECKGOTO(ncclCuMemFree((void *)ptr), result, finish);
} else {
CUDACHECKGOTO(cudaFree(ptr), result, finish);
}
finish: finish:
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
return result; return result;


@ -20,6 +20,7 @@ ncclResult_t bootstrapNetInit();
ncclResult_t bootstrapCreateRoot(struct ncclBootstrapHandle* handle, bool idFromEnv); ncclResult_t bootstrapCreateRoot(struct ncclBootstrapHandle* handle, bool idFromEnv);
ncclResult_t bootstrapGetUniqueId(struct ncclBootstrapHandle* handle); ncclResult_t bootstrapGetUniqueId(struct ncclBootstrapHandle* handle);
ncclResult_t bootstrapInit(struct ncclBootstrapHandle* handle, struct ncclComm* comm); ncclResult_t bootstrapInit(struct ncclBootstrapHandle* handle, struct ncclComm* comm);
ncclResult_t bootstrapSplit(struct ncclBootstrapHandle* handle, struct ncclComm* comm, struct ncclComm* parent, int color, int key, int* parentRanks);
ncclResult_t bootstrapAllGather(void* commState, void* allData, int size); ncclResult_t bootstrapAllGather(void* commState, void* allData, int size);
ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size); ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size);
ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size); ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size);


@ -9,7 +9,9 @@
#include "comm.h" #include "comm.h"
ncclResult_t initChannel(struct ncclComm* comm, int channelid); ncclResult_t initChannel(struct ncclComm* comm, int channelid);
ncclResult_t freeChannel(struct ncclChannel* channel, int nRanks); ncclResult_t initNvlsChannel(struct ncclComm* comm, int channelId, struct ncclComm* parent, bool share);
ncclResult_t initCollnetChannel(struct ncclComm* comm, int channelId, struct ncclComm* parent, bool share);
ncclResult_t freeChannel(struct ncclChannel* channel, int nRanks, int collnetNRanks, int nvlsNRanks);
static ncclResult_t ncclChannelComputeBase(struct ncclComm* comm, int peer, int coll, int*channelBase) { static ncclResult_t ncclChannelComputeBase(struct ncclComm* comm, int peer, int coll, int*channelBase) {
int p2pGroupSize = NCCL_MAX_WORK_ELEMENTS_P2P/2; int p2pGroupSize = NCCL_MAX_WORK_ELEMENTS_P2P/2;
int peerNode = comm->rankToNode[peer]; int peerNode = comm->rankToNode[peer];


@ -50,11 +50,12 @@ struct ncclDevRedOpFull {
MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL128, devredop, type)) MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL128, devredop, type))
#define DECL3(func, devredop, type, undef) \ #define DECL3(func, devredop, type, undef) \
DECL4(func, RING, devredop, type, undef) \ DECL4(func, RING, devredop, type, undef) \
DECL4(func, TREE, devredop, type, undef) \ DECL4(func, TREE, devredop, type, undef) \
DECL4(func, COLLNET_DIRECT, devredop, type, undef) \ DECL4(func, COLLNET_DIRECT, devredop, type, undef) \
DECL4(func, COLLNET_CHAIN, devredop, type, undef) \ DECL4(func, COLLNET_CHAIN, devredop, type, undef) \
DECL4(func, NVLS, devredop, type, undef) DECL4(func, NVLS, devredop, type, undef) \
DECL4(func, NVLS_TREE, devredop, type, undef)
#if defined(__CUDA_BF16_TYPES_EXIST__) #if defined(__CUDA_BF16_TYPES_EXIST__)
#define DECL2(func, devredop, undefForFloat) \ #define DECL2(func, devredop, undefForFloat) \

View File

@ -96,18 +96,51 @@ struct ncclCommCallback {
ncclResult_t(*fn)(struct ncclComm* comm, struct ncclCommCallback* cb); ncclResult_t(*fn)(struct ncclComm* comm, struct ncclCommCallback* cb);
}; };
struct ncclSharedResources {
int refCount;
struct ncclComm* owner; /* comm which creates this shared res. */
struct ncclChannelPeer* peers[MAXCHANNELS];
struct ncclDevChannelPeer* devPeers[MAXCHANNELS];
/* P2P operation counter, one per channel */
uint64_t p2pOpCount[MAXCHANNELS];
/* Collective operation counter */
uint64_t collOpCount;
int tpNRanks;
int tpNLocalRanks;
int tpNChannels;
int tpP2pNChannels;
int tpP2pChunkSize;
uint64_t magic;
// top parent rank to localRank translation table
int* tpRankToLocalRank;
// Internal streams
struct ncclStrongStream deviceStream, hostStream;
/* proxy related shared res */
struct ncclProxyState* proxyState;
};
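This struct is what a child communicator borrows from its parent when ncclCommSplit is asked to share resources. A rough caller-side sketch, where parentComm and myRank are caller-provided and the splitShare field name is an assumption based on this release's ncclConfig_t:
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
config.splitShare = 1;   // let children reuse the parent's proxy/channel resources
ncclComm_t newcomm;
ncclResult_t res = ncclCommSplit(parentComm, myRank % 2 /*color*/, myRank /*key*/, &newcomm, &config);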
struct ncclChannel { struct ncclChannel {
struct ncclChannelPeer* peers; struct ncclChannelPeer** peers;
struct ncclDevChannelPeer* devPeers; struct ncclDevChannelPeer** devPeers;
struct ncclRing ring; struct ncclRing ring;
int* devRingUserRanks; int* devRingUserRanks;
struct ncclTree tree; struct ncclTree tree;
struct ncclTree collnetChain; struct ncclTree collnetChain;
struct ncclDirect collnetDirect; struct ncclDirect collnetDirect;
struct ncclNvls nvls; struct ncclNvls nvls;
int id; // index of this channel int id; // index of this channel
uint32_t workFifoSent; // last used work index+1 uint32_t workFifoSent; // last used work index+1
uint64_t p2pOpCount;
/* comm split sharable resources */
struct ncclChannelPeer* collnetPeers;
struct ncclDevChannelPeer* collnetDevPeers;
struct ncclChannelPeer* nvlsPeers;
struct ncclDevChannelPeer* nvlsDevPeers;
}; };
struct ncclWorkList { struct ncclWorkList {
@ -161,6 +194,10 @@ struct ncclComm {
// List of destructors to run when comm is destructed // List of destructors to run when comm is destructed
struct ncclDestructor* destructorHead; struct ncclDestructor* destructorHead;
struct ncclSharedResources* sharedRes;
/* map to top parent ranks. */
int* topParentRanks;
int* topParentLocalRanks;
struct ncclChannel channels[MAXCHANNELS]; struct ncclChannel channels[MAXCHANNELS];
struct ncclPeerInfo* peerInfo; struct ncclPeerInfo* peerInfo;
struct ncclTopoSystem* topo; struct ncclTopoSystem* topo;
@ -174,11 +211,12 @@ struct ncclComm {
uint64_t magic; // Magic number for all network communication. Not a security key -- only goal is to detect mismatches. uint64_t magic; // Magic number for all network communication. Not a security key -- only goal is to detect mismatches.
uint64_t commHash;
int rank; // my rank in the communicator int rank; // my rank in the communicator
int nRanks; // number of GPUs in communicator int nRanks; // number of GPUs in communicator
int cudaDev; // my cuda device index int cudaDev; // my cuda device index
int compCap; // compute capability of the GPU int compCap; // compute capability of the GPU
int minCompCap; // min compute capability in the communicator int minCompCap, maxCompCap; // min/max compute capability in the communicator
int64_t busId; // my PCI bus ID in int format int64_t busId; // my PCI bus ID in int format
cpu_set_t cpuAffinity; // CPU affinity of the GPU cpu_set_t cpuAffinity; // CPU affinity of the GPU
int cudaArch; // matches __CUDA_ARCH__ of device int cudaArch; // matches __CUDA_ARCH__ of device
@ -199,12 +237,11 @@ struct ncclComm {
// Counter for tracking CUDA launches (P2P and collectives included) // Counter for tracking CUDA launches (P2P and collectives included)
uint64_t opCount; uint64_t opCount;
// Collective operation counter
uint64_t collOpCount;
// Channels for collectives // Channels for collectives
int nChannels; int nChannels;
int nvlsChannels; int nvlsChannels;
int collNetChannels;
// Channels (per peer) for p2p // Channels (per peer) for p2p
int p2pnChannels; int p2pnChannels;
int p2pnChannelsPerPeer; int p2pnChannelsPerPeer;
@ -229,6 +266,8 @@ struct ncclComm {
// Flag to ask NCCL kernels to abort // Flag to ask NCCL kernels to abort
volatile uint32_t *abortFlag; volatile uint32_t *abortFlag;
volatile uint32_t *childAbortFlag;
uint32_t *abortFlagRefCount;
// Device side of the communicator (for cudaFree's) // Device side of the communicator (for cudaFree's)
struct ncclDevComm* devComm; // actually = &ncclDevCommAndChannels::comm struct ncclDevComm* devComm; // actually = &ncclDevCommAndChannels::comm
@ -255,21 +294,23 @@ struct ncclComm {
char intraPad2[64 - sizeof(uint64_t)]; char intraPad2[64 - sizeof(uint64_t)];
uint64_t intraBarrierGate; // only used if this is intraComm0 uint64_t intraBarrierGate; // only used if this is intraComm0
struct ncclProxyState proxyState; struct ncclProxyState* proxyState;
int proxyRefCountOld; /* store proxy post-atomic-sub refcount */
// Whether this communicator uses collNet // Whether this communicator uses collNet
int collNetSupport; int collNetSupport;
int intraHighestTransportType; int intraHighestTransportType;
int* collNetHeads;
int collNetHeadsNum;
/* sharable collNet proxy progress resource. */
struct ncclCollNetSharedRes* collNetSharedRes;
// NVLink SHARP (NVLS) support // NVLink SHARP (NVLS) support
int nvlsSupport; int nvlsSupport;
void* nvlsResources; /* sharable NVLS resource. */
struct ncclNvlsSharedRes* nvlsResources;
size_t channelSize; // User requested work size (bytes) for channel partitions size_t channelSize; // User requested work size (bytes) for channel partitions
// Internal streams
struct ncclStrongStream deviceStream, hostStream;
// pools backed by comm->memPermanent // pools backed by comm->memPermanent
struct ncclMemoryPool memPool_ncclProxyOp; struct ncclMemoryPool memPool_ncclProxyOp;
struct ncclMemoryPool memPool_ncclKernelPlan; struct ncclMemoryPool memPool_ncclKernelPlan;
@ -294,13 +335,7 @@ struct ncclComm {
// First of the unlaunched kernels in `planQueue` // First of the unlaunched kernels in `planQueue`
struct ncclKernelPlan* unlaunchedPlansHead; struct ncclKernelPlan* unlaunchedPlansHead;
// communicator mode ncclConfig_t config;
int blocking;
// CGA cluster size
int cgaClusterSize;
int minCTAs, maxCTAs;
// network interface name
char *netName;
// initState is to more conveniently reclaim resources when errors happen. // initState is to more conveniently reclaim resources when errors happen.
ncclResult_t initState; ncclResult_t initState;
// flag to indicate if ncclCommFinalize() is called // flag to indicate if ncclCommFinalize() is called


@ -11,6 +11,9 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "checks.h" #include "checks.h"
// Is cuMem API usage enabled
extern int ncclCuMemEnable();
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#else #else
@ -85,6 +88,7 @@ DECLARE_CUDA_PFN_EXTERN(cuMemExportToShareableHandle, 10020);
DECLARE_CUDA_PFN_EXTERN(cuMemImportFromShareableHandle, 10020); DECLARE_CUDA_PFN_EXTERN(cuMemImportFromShareableHandle, 10020);
DECLARE_CUDA_PFN_EXTERN(cuMemMap, 10020); DECLARE_CUDA_PFN_EXTERN(cuMemMap, 10020);
DECLARE_CUDA_PFN_EXTERN(cuMemRelease, 10020); DECLARE_CUDA_PFN_EXTERN(cuMemRelease, 10020);
DECLARE_CUDA_PFN_EXTERN(cuMemRetainAllocationHandle, 11000);
DECLARE_CUDA_PFN_EXTERN(cuMemSetAccess, 10020); DECLARE_CUDA_PFN_EXTERN(cuMemSetAccess, 10020);
DECLARE_CUDA_PFN_EXTERN(cuMemUnmap, 10020); DECLARE_CUDA_PFN_EXTERN(cuMemUnmap, 10020);
#if CUDA_VERSION >= 11070 #if CUDA_VERSION >= 11070


@ -15,12 +15,13 @@
typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclFuncSend, ncclFuncRecv, ncclNumFuncs} ncclFunc_t; typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclFuncSend, ncclFuncRecv, ncclNumFuncs} ncclFunc_t;
extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS]; extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS];
#define NCCL_NUM_ALGORITHMS 5 // Tree/Ring/CollNet* #define NCCL_NUM_ALGORITHMS 6 // Tree/Ring/CollNet*
#define NCCL_ALGO_TREE 0 #define NCCL_ALGO_TREE 0
#define NCCL_ALGO_RING 1 #define NCCL_ALGO_RING 1
#define NCCL_ALGO_COLLNET_DIRECT 2 #define NCCL_ALGO_COLLNET_DIRECT 2
#define NCCL_ALGO_COLLNET_CHAIN 3 #define NCCL_ALGO_COLLNET_CHAIN 3
#define NCCL_ALGO_NVLS 4 #define NCCL_ALGO_NVLS 4
#define NCCL_ALGO_NVLS_TREE 5
extern const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS]; extern const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS];
#define NCCL_NUM_PROTOCOLS 3 // Simple/LL/LL128 #define NCCL_NUM_PROTOCOLS 3 // Simple/LL/LL128
@ -100,10 +101,10 @@ struct ncclConnInfo {
}; };
struct ncclProxyConnector { struct ncclProxyConnector {
int rank; int tpRank;
int localRank; int tpLocalRank;
int sameProcess;
struct ncclProxyConnection* connection; struct ncclProxyConnection* connection;
struct ncclComm* comm;
}; };
struct ncclConnector { struct ncclConnector {
@ -112,7 +113,6 @@ struct ncclConnector {
struct ncclTransportComm* transportComm; struct ncclTransportComm* transportComm;
void* transportResources; void* transportResources;
struct ncclConnInfo conn; struct ncclConnInfo conn;
struct ncclComm *comm;
}; };
struct ncclRing { struct ncclRing {
@ -148,18 +148,24 @@ struct ncclDirect {
}; };
#define NCCL_MAX_NVLS_ARITY 8 #define NCCL_MAX_NVLS_ARITY 8
#define NCCL_MAX_NVLS_TREE_ARITY 3
struct ncclNvls { struct ncclNvls {
int out; int out;
int nHeads; // Number of parallel N<->1<->net operations we'll do in parallel; size of up/down int nHeads; // Number of parallel N<->1<->net operations we'll do in parallel; size of up/down
int headRank; // Index in 0..nHeads-1 I am the head rank of. -1 if I'm not a head rank (no local NIC) int headRank; // Index in 0..nHeads-1 I am the head rank of. -1 if I'm not a head rank (no local NIC)
int up[NCCL_MAX_NVLS_ARITY]; int up[NCCL_MAX_NVLS_ARITY];
int down; int down;
int treeUp;
int treeDown[NCCL_MAX_NVLS_TREE_ARITY];
int node;
int nNodes;
}; };
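The new treeUp/treeDown fields describe the small inter-node tree that NVLS+Tree layers on top of the intra-node NVLS heads, with -1 marking absent links and the fan-out bounded by NCCL_MAX_NVLS_TREE_ARITY. Below is a conceptual sketch only, using a simplified copy of the relevant fields rather than the real ncclNvls struct, and a traversal that merely prints the links instead of NCCL's actual reduce/broadcast data path:

// Conceptual sketch of the NVLS tree links added above (not NCCL's real code path).
#include <cstdio>
#define NCCL_MAX_NVLS_TREE_ARITY 3
struct NvlsTreeLinks {
  int treeUp;                              // parent rank in the inter-node tree, -1 at the root
  int treeDown[NCCL_MAX_NVLS_TREE_ARITY];  // child ranks, -1 where unused
};
static void describe(const NvlsTreeLinks& t, int rank) {
  if (t.treeUp == -1) printf("rank %d is the tree root\n", rank);
  else                printf("rank %d sends its partial result up to rank %d\n", rank, t.treeUp);
  for (int i = 0; i < NCCL_MAX_NVLS_TREE_ARITY; i++)
    if (t.treeDown[i] != -1) printf("rank %d forwards results down to rank %d\n", rank, t.treeDown[i]);
}
int main() {
  NvlsTreeLinks root = { -1, { 8, 16, -1 } };  // hypothetical ranks
  describe(root, 0);
  return 0;
}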
#define NCCL_MAX_CONNS 2 #define NCCL_MAX_CONNS 2
struct ncclChannelPeer { struct ncclChannelPeer {
struct ncclConnector send[NCCL_MAX_CONNS]; struct ncclConnector send[NCCL_MAX_CONNS];
struct ncclConnector recv[NCCL_MAX_CONNS]; struct ncclConnector recv[NCCL_MAX_CONNS];
int refCount;
}; };
struct ncclDevComm; struct ncclDevComm;
@ -270,7 +276,7 @@ struct ncclDevChannelPeer {
}; };
struct alignas(16) ncclDevChannel { struct alignas(16) ncclDevChannel {
struct ncclDevChannelPeer *peers; struct ncclDevChannelPeer** peers;
struct ncclRing ring; struct ncclRing ring;
struct ncclTree tree; struct ncclTree tree;
struct ncclTree collnetChain; struct ncclTree collnetChain;

View File

@ -243,7 +243,7 @@ static ncclResult_t ncclGdrCudaFree(void* gdrHandle) {
gdr_mem_desc_t *md = (gdr_mem_desc_t*)gdrHandle; gdr_mem_desc_t *md = (gdr_mem_desc_t*)gdrHandle;
NCCLCHECK(wrap_gdr_unmap(ncclGdrCopy, md->gdrMh, md->gdrMap, md->gdrMapSize)); NCCLCHECK(wrap_gdr_unmap(ncclGdrCopy, md->gdrMh, md->gdrMap, md->gdrMapSize));
NCCLCHECK(wrap_gdr_unpin_buffer(ncclGdrCopy, md->gdrMh)); NCCLCHECK(wrap_gdr_unpin_buffer(ncclGdrCopy, md->gdrMh));
CUDACHECK(cudaFree(md->gdrDevMem)); NCCLCHECK(ncclCudaFree(md->gdrDevMem));
free(md); free(md);
return ncclSuccess; return ncclSuccess;

View File

@ -53,9 +53,11 @@ ncclResult_t ncclTopoGetCpuAffinity(struct ncclTopoSystem* system, int rank, cpu
#define NCCL_TOPO_CPU_TYPE_SKL 2 #define NCCL_TOPO_CPU_TYPE_SKL 2
#define NCCL_TOPO_CPU_TYPE_YONGFENG 1 #define NCCL_TOPO_CPU_TYPE_YONGFENG 1
ncclResult_t ncclTopoCpuType(struct ncclTopoSystem* system, int* arch, int* vendor, int* model); ncclResult_t ncclTopoCpuType(struct ncclTopoSystem* system, int* arch, int* vendor, int* model);
ncclResult_t ncclTopoGetNetCount(struct ncclTopoSystem* system, int* count); ncclResult_t ncclTopoGetGpuCount(struct ncclTopoSystem* system, int* count);
ncclResult_t ncclTopoGetNvsCount(struct ncclTopoSystem* system, int* count); ncclResult_t ncclTopoGetNvsCount(struct ncclTopoSystem* system, int* count);
ncclResult_t ncclTopoGetLocalNet(struct ncclTopoSystem* system, int rank, int* id); ncclResult_t ncclTopoGetNvsCount(struct ncclTopoSystem* system, int* count);
ncclResult_t ncclTopoGetLocalNet(struct ncclTopoSystem* system, int rank, int channelId, int* id);
ncclResult_t ncclTopoGetLocalGpu(struct ncclTopoSystem* system, int net, int* gpuIndex);
#define NCCL_TOPO_MAX_NODES 256 #define NCCL_TOPO_MAX_NODES 256
@ -66,6 +68,7 @@ ncclResult_t ncclTopoSearchInit(struct ncclTopoSystem* system);
#define NCCL_TOPO_PATTERN_SPLIT_TREE 2 // Spread NIC traffic between two GPUs (Tree parent on first GPU, tree children on the second GPU) #define NCCL_TOPO_PATTERN_SPLIT_TREE 2 // Spread NIC traffic between two GPUs (Tree parent on first GPU, tree children on the second GPU)
#define NCCL_TOPO_PATTERN_TREE 3 // All NIC traffic going to/from the same GPU #define NCCL_TOPO_PATTERN_TREE 3 // All NIC traffic going to/from the same GPU
#define NCCL_TOPO_PATTERN_RING 4 // Ring #define NCCL_TOPO_PATTERN_RING 4 // Ring
#define NCCL_TOPO_PATTERN_NVLS 5 // NVLS+SHARP and NVLS+Tree
struct ncclTopoGraph { struct ncclTopoGraph {
// Input / output // Input / output
int id; // ring : 0, tree : 1, collnet : 2 int id; // ring : 0, tree : 1, collnet : 2
@ -99,16 +102,15 @@ struct ncclTopoRanks {
int treeToParent[MAXCHANNELS]; int treeToParent[MAXCHANNELS];
int treeToChild0[MAXCHANNELS]; int treeToChild0[MAXCHANNELS];
int treeToChild1[MAXCHANNELS]; int treeToChild1[MAXCHANNELS];
int nvlsHeads[MAXCHANNELS];
}; };
ncclResult_t ncclTopoPreset(struct ncclComm* comm, ncclResult_t ncclTopoPreset(struct ncclComm* comm, struct ncclTopoGraph** graphs, struct ncclTopoRanks* topoRanks);
struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph,
struct ncclTopoRanks* topoRanks);
ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns, ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns,
struct ncclTopoRanks** allTopoRanks, int* rings, struct ncclTopoGraph* collNetGraph); struct ncclTopoRanks** allTopoRanks, int* rings, struct ncclTopoGraph** graphs);
ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph); ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph** graphs);
#include "info.h" #include "info.h"
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time); ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time);

View File

@ -95,7 +95,7 @@ inline void ncclGroupCommJoin(struct ncclComm* comm) {
ncclMemoryStackPush(&comm->memScoped); ncclMemoryStackPush(&comm->memScoped);
} }
ncclGroupBlocking = comm->blocking; ncclGroupBlocking = comm->config.blocking;
} }
// Add comm to this thread's group needing preconnect // Add comm to this thread's group needing preconnect

src/include/ibvcore.h (new file, 1043 lines)

File diff suppressed because it is too large

src/include/ibvsymbols.h (new file, 44 lines)
View File

@ -0,0 +1,44 @@
#ifndef NCCL_IBV_SYMBOLS_H_
#define NCCL_IBV_SYMBOLS_H_
#ifdef NCCL_BUILD_RDMA_CORE
#include <infiniband/verbs.h>
#else
#include "ibvcore.h"
#endif
#include "nccl.h"
/* IB Verbs Function Pointers*/
struct ncclIbvSymbols {
int (*ibv_internal_fork_init)(void);
struct ibv_device** (*ibv_internal_get_device_list)(int *num_devices);
void (*ibv_internal_free_device_list)(struct ibv_device **list);
const char * (*ibv_internal_get_device_name)(struct ibv_device *device);
struct ibv_context* (*ibv_internal_open_device)(struct ibv_device* device);
int (*ibv_internal_close_device)(struct ibv_context *context);
int (*ibv_internal_get_async_event)(struct ibv_context *context, struct ibv_async_event *event);
void (*ibv_internal_ack_async_event)(struct ibv_async_event *event);
int (*ibv_internal_query_device)(struct ibv_context *context, struct ibv_device_attr *device_attr);
int (*ibv_internal_query_port)(struct ibv_context *context, uint8_t port_num, struct ibv_port_attr *port_attr);
int (*ibv_internal_query_gid)(struct ibv_context *context, uint8_t port_num, int index, union ibv_gid *gid);
int (*ibv_internal_query_qp)(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask, struct ibv_qp_init_attr *init_attr);
struct ibv_pd * (*ibv_internal_alloc_pd)(struct ibv_context *context);
int (*ibv_internal_dealloc_pd)(struct ibv_pd *pd);
struct ibv_mr * (*ibv_internal_reg_mr)(struct ibv_pd *pd, void *addr, size_t length, int access);
struct ibv_mr * (*ibv_internal_reg_mr_iova2)(struct ibv_pd *pd, void *addr, size_t length, uint64_t iova, unsigned int access);
/* DMA-BUF support */
struct ibv_mr * (*ibv_internal_reg_dmabuf_mr)(struct ibv_pd *pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access);
int (*ibv_internal_dereg_mr)(struct ibv_mr *mr);
struct ibv_cq * (*ibv_internal_create_cq)(struct ibv_context *context, int cqe, void *cq_context, struct ibv_comp_channel *channel, int comp_vector);
int (*ibv_internal_destroy_cq)(struct ibv_cq *cq);
struct ibv_qp * (*ibv_internal_create_qp)(struct ibv_pd *pd, struct ibv_qp_init_attr *qp_init_attr);
int (*ibv_internal_modify_qp)(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask);
int (*ibv_internal_destroy_qp)(struct ibv_qp *qp);
const char * (*ibv_internal_event_type_str)(enum ibv_event_type event);
};
/* Constructs IB verbs symbols per rdma-core linking or dynamic loading mode */
ncclResult_t buildIbvSymbols(struct ncclIbvSymbols* ibvSymbols);
#endif // NCCL_IBV_SYMBOLS_H_
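Once buildIbvSymbols() has filled the table, either from directly linked rdma-core entry points or from dlvsym() lookups depending on NCCL_BUILD_RDMA_CORE, callers go through the function pointers in the struct. A minimal usage sketch, assuming this header (and ibvcore.h) is on the include path:

// Minimal sketch: resolve the IB verbs symbol table and enumerate devices through it.
#include <cstdio>
#include "ibvsymbols.h"
int listIbDevices() {
  struct ncclIbvSymbols syms;
  if (buildIbvSymbols(&syms) != ncclSuccess) return -1;  // direct link or dlopen/dlvsym failed
  int ndev = 0;
  struct ibv_device** devs = syms.ibv_internal_get_device_list(&ndev);
  if (devs == NULL) return 0;
  for (int i = 0; i < ndev; i++)
    printf("IB device %d: %s\n", i, syms.ibv_internal_get_device_name(devs[i]));
  syms.ibv_internal_free_device_list(devs);
  return ndev;
}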

File diff suppressed because it is too large

View File

@ -25,6 +25,7 @@ typedef enum : uint8_t {
ncclPatternCollnetChain, ncclPatternCollnetChain,
ncclPatternCollnetDirect, ncclPatternCollnetDirect,
ncclPatternNvls, ncclPatternNvls,
ncclPatternNvlsTree,
ncclPatternSend, ncclPatternSend,
ncclPatternRecv ncclPatternRecv
} ncclPattern_t; } ncclPattern_t;
@ -93,7 +94,6 @@ struct ncclCudaStreamList {
struct ncclCudaStreamList *next; struct ncclCudaStreamList *next;
cudaStream_t stream; cudaStream_t stream;
}; };
struct ncclTasks { struct ncclTasks {
struct Peer { struct Peer {
bool sendSeen, recvSeen; bool sendSeen, recvSeen;
@ -103,7 +103,8 @@ struct ncclTasks {
struct ncclIntruQueue<ncclTaskColl, &ncclTaskColl::next> collQueue; struct ncclIntruQueue<ncclTaskColl, &ncclTaskColl::next> collQueue;
size_t collBytesTotal; size_t collBytesTotal;
struct Peer* peers/*[nRanks]*/; struct Peer* peers/*[nRanks]*/;
int *p2pSendOrder/*[nRanks]*/, *p2pRecvOrder/*[nRanks]*/; int *p2pSendOrder, *p2pRecvOrder;
int p2pOrderSteps;
int nTasksColl, nTasksP2p; int nTasksColl, nTasksP2p;
// The list of user streams aggregated over all tasks present. // The list of user streams aggregated over all tasks present.

View File

@ -18,25 +18,6 @@ ncclResult_t ncclNetPluginInit();
ncclResult_t ncclNetInit(struct ncclComm* comm); ncclResult_t ncclNetInit(struct ncclComm* comm);
int ncclNetVersion(struct ncclComm* comm); int ncclNetVersion(struct ncclComm* comm);
// Translation to external API
static const char* ncclNetName(struct ncclComm* comm) { return comm->ncclNet->name; }
static ncclResult_t ncclNetDevices(struct ncclComm* comm, int* ndev) { NCCLCHECK(comm->ncclNet->devices(ndev)); return ncclSuccess; }
static ncclResult_t ncclNetGetProperties(struct ncclComm* comm, int dev, ncclNetProperties_t* props) { NCCLCHECK(comm->ncclNet->getProperties(dev, props)); return ncclSuccess; }
static ncclResult_t ncclNetListen(struct ncclComm* comm, int dev, void* handle, void** listenComm) { NCCLCHECK(comm->ncclNet->listen(dev, handle, listenComm)); return ncclSuccess; }
static ncclResult_t ncclNetConnect(struct ncclComm* comm, int dev, void* handle, void** sendComm) { NCCLCHECK(comm->ncclNet->connect(dev, handle, sendComm)); return ncclSuccess; }
static ncclResult_t ncclNetAccept(struct ncclComm* comm, void* listenComm, void** recvComm) { NCCLCHECK(comm->ncclNet->accept(listenComm, recvComm)); return ncclSuccess; }
static ncclResult_t ncclNetRegMr(struct ncclComm* comm, void* netComm, void* data, int size, int type, void** mhandle) { NCCLCHECK(comm->ncclNet->regMr(netComm, data, size, type, mhandle)); return ncclSuccess; }
/* DMA-BUF support */
static ncclResult_t ncclNetRegMrDmaBuf(struct ncclComm* comm, void* netComm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) { NCCLCHECK(comm->ncclNet->regMrDmaBuf(netComm, data, size, type, offset, fd, mhandle)); return ncclSuccess; }
static ncclResult_t ncclNetDeregMr(struct ncclComm* comm, void* netComm, void* mhandle) { NCCLCHECK(comm->ncclNet->deregMr(netComm, mhandle)); return ncclSuccess; }
static ncclResult_t ncclNetIsend(struct ncclComm* comm, void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { NCCLCHECK(comm->ncclNet->isend(sendComm, data, size, tag, mhandle, request)); return ncclSuccess; }
static ncclResult_t ncclNetIrecv(struct ncclComm* comm, void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { NCCLCHECK(comm->ncclNet->irecv(recvComm, n, data, sizes, tags, mhandles, request)); return ncclSuccess; }
static ncclResult_t ncclNetIflush(struct ncclComm* comm, void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) { NCCLCHECK(comm->ncclNet->iflush(recvComm, n, data, sizes, mhandles, request)); return ncclSuccess; }
static ncclResult_t ncclNetTest(struct ncclComm* comm, void* request, int* done, int* sizes) { NCCLCHECK(comm->ncclNet->test(request, done, sizes)); return ncclSuccess; }
static ncclResult_t ncclNetCloseSend(struct ncclComm* comm, void* sendComm) { NCCLCHECK(comm->ncclNet->closeSend(sendComm)); return ncclSuccess; }
static ncclResult_t ncclNetCloseRecv(struct ncclComm* comm, void* recvComm) { NCCLCHECK(comm->ncclNet->closeRecv(recvComm)); return ncclSuccess; }
static ncclResult_t ncclNetCloseListen(struct ncclComm* comm, void* listenComm) { NCCLCHECK(comm->ncclNet->closeListen(listenComm)); return ncclSuccess; }
// Test whether the current GPU support GPU Direct RDMA. // Test whether the current GPU support GPU Direct RDMA.
ncclResult_t ncclGpuGdrSupport(struct ncclComm* comm, int* gdrSupport); ncclResult_t ncclGpuGdrSupport(struct ncclComm* comm, int* gdrSupport);

View File

@ -1,30 +1,33 @@
/*
* Copyright 2021-2023 NVIDIA Corporation. All rights reserved.
*
* Licensed under the Apache License v2.0 with LLVM Exceptions.
* See https://llvm.org/LICENSE.txt for license information.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/
#ifndef NVTX_EXT_IMPL_PAYLOAD_GUARD #ifndef NVTX_EXT_IMPL_PAYLOAD_GUARD
#error Never include this file directly -- it is automatically included by nvToolsExtPayload.h (except when NVTX_NO_IMPL is defined). #error Never include this file directly -- it is automatically included by nvToolsExtPayload.h (except when NVTX_NO_IMPL is defined).
#endif #endif
/*
* Helper array to get the alignment for each predefined C language type.
*/
typedef void* pointer_type; typedef void* pointer_type;
#if __STDC_VERSION__ >= 201112L /* or CPP11 */ #if (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L)
#include <uchar.h>
#include <stdalign.h> #include <stdalign.h>
#endif
/* `alignof` is available as of C11 or C++11 */
#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) || (defined(__cplusplus) && __cplusplus >= 201103L)
#define nvtx_alignof(type) alignof(type) #define nvtx_alignof(type) alignof(type)
#define nvtx_alignof2(type,tname) alignof(type) #define nvtx_alignof2(type,tname) alignof(type)
#else /* __STDC_VERSION__ >= 201112L */
#ifndef __cplusplus
#include <stddef.h> #else /* (__STDC_VERSION__ >= 201112L) || (__cplusplus >= 201103L) */
#define nvtx_alignof(type) offsetof(struct {char c; type d;}, d)
#define nvtx_alignof2(type,tname) nvtx_alignof(type)
#else /* __cplusplus */ /* Create helper structs to determine type alignment. */
#define MKTYPEDEF(type) typedef struct {char c; type d;} _nvtx_##type
#define MKTYPEDEF(TYPE) typedef struct {char c; TYPE d;} _nvtx_##TYPE #define MKTYPEDEF2(type,tname) typedef struct {char c; type d;} _nvtx_##tname
#define MKTYPEDEF2(TYPE,TNAME) typedef struct {char c; TYPE d;} _nvtx_##TNAME
#define nvtx_alignof(TNAME) offsetof(_nvtx_##TNAME, d)
#define nvtx_alignof2(type,tname) offsetof(_nvtx_##tname, d)
MKTYPEDEF(char); MKTYPEDEF(char);
MKTYPEDEF2(unsigned char, uchar); MKTYPEDEF2(unsigned char, uchar);
@ -54,22 +57,33 @@ MKTYPEDEF(size_t);
MKTYPEDEF(pointer_type); MKTYPEDEF(pointer_type);
MKTYPEDEF(wchar_t); MKTYPEDEF(wchar_t);
#if (__STDC_VERSION__ > 201710L) || (defined(__cplusplus) && __cplusplus > 201703L)
{sizeof(char8_t), nvtx_alignof(char8_t)}, /* `char8_t` is available as of C++20 or C23 */
#if (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 202311L) || (defined(__cplusplus) && __cplusplus >= 201811L)
MKTYPEDEF(char8_t); MKTYPEDEF(char8_t);
#endif #endif
#if (__STDC_VERSION__ >= 201112L) || (defined(__cplusplus) && __cplusplus >= 201103L)
/* `char16_t` and `char32_t` are available as of C++11 or C11 */
#if (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L) || (defined(__cplusplus) && __cplusplus >= 200704L)
MKTYPEDEF(char16_t); MKTYPEDEF(char16_t);
MKTYPEDEF(char32_t); MKTYPEDEF(char32_t);
#endif #endif
/* C requires to include stddef.h to use `offsetof` */
#ifndef __cplusplus
#include <stddef.h>
#endif
#define nvtx_alignof(tname) offsetof(_nvtx_##tname, d)
#define nvtx_alignof2(type, tname) offsetof(_nvtx_##tname, d)
#endif /* __STDC_VERSION__ >= 201112L */
#undef MKTYPEDEF #undef MKTYPEDEF
#undef MKTYPEDEF2 #undef MKTYPEDEF2
#endif /* __cplusplus */
#endif /* __STDC_VERSION__ >= 201112L */
/* /*
* Helper array to get the alignment for each predefined C/C++ language type.
* The order of entries must match the values in`enum nvtxPayloadSchemaEntryType`. * The order of entries must match the values in`enum nvtxPayloadSchemaEntryType`.
*/ */
const nvtxPayloadEntryTypeInfo_t nvtxExtPayloadTypeInfo[NVTX_PAYLOAD_ENTRY_TYPE_INFO_ARRAY_SIZE] = const nvtxPayloadEntryTypeInfo_t nvtxExtPayloadTypeInfo[NVTX_PAYLOAD_ENTRY_TYPE_INFO_ARRAY_SIZE] =
@ -109,13 +123,14 @@ const nvtxPayloadEntryTypeInfo_t nvtxExtPayloadTypeInfo[NVTX_PAYLOAD_ENTRY_TYPE_
/*** Special character types ***/ /*** Special character types ***/
/* NVTX_PAYLOAD_ENTRY_TYPE_WCHAR */ {sizeof(wchar_t), nvtx_alignof(wchar_t)}, /* NVTX_PAYLOAD_ENTRY_TYPE_WCHAR */ {sizeof(wchar_t), nvtx_alignof(wchar_t)},
/* NVTX_PAYLOAD_ENTRY_TYPE_CHAR8 */
#if (__STDC_VERSION__ > 201710L) || (defined(__cplusplus) && __cplusplus > 201703L) #if (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 202311L) || (defined(__cplusplus) && __cplusplus >= 201811L)
{sizeof(char8_t), nvtx_alignof(char8_t)}, /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR8 */ {sizeof(char8_t), nvtx_alignof(char8_t)},
#else #else
{0, 0}, /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR8 */ {0, 0},
#endif #endif
#if (__STDC_VERSION__ >= 201112L) || (defined(__cplusplus) && __cplusplus >= 201103L)
#if (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L) || (defined(__cplusplus) && __cplusplus >= 200704L)
/* NVTX_PAYLOAD_ENTRY_TYPE_CHAR16 */ {sizeof(char16_t), nvtx_alignof(char16_t)}, /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR16 */ {sizeof(char16_t), nvtx_alignof(char16_t)},
/* NVTX_PAYLOAD_ENTRY_TYPE_CHAR32 */ {sizeof(char32_t), nvtx_alignof(char32_t)} /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR32 */ {sizeof(char32_t), nvtx_alignof(char32_t)}
#else #else
@ -125,4 +140,4 @@ const nvtxPayloadEntryTypeInfo_t nvtxExtPayloadTypeInfo[NVTX_PAYLOAD_ENTRY_TYPE_
}; };
#undef nvtx_alignof #undef nvtx_alignof
#undef nvtx_alignof2 #undef nvtx_alignof2
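The rewritten fallback above computes alignment without alignof by measuring where a member lands when placed after a single char. A quick check of the trick (shown in C++ for compactness; the production header also has to cover pre-C11 C, which is why the offsetof form exists):

// Compare the offsetof-based alignment trick with the language's alignof on this ABI.
#include <cstddef>
#include <cstdio>
typedef struct { char c; double d; } _nvtx_like_double;
int main() {
  printf("offsetof trick: %zu, alignof(double): %zu\n",
         offsetof(_nvtx_like_double, d), (size_t)alignof(double));  // identical on common ABIs
  return 0;
}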

View File

@ -9,4 +9,21 @@
#ifndef NCCL_P2P_H_ #ifndef NCCL_P2P_H_
#define NCCL_P2P_H_ #define NCCL_P2P_H_
#define NCCL_P2P_HANDLE_TYPE CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
typedef struct {
int data; // Currently only support an fd based descriptor
} ncclCuDesc;
typedef union {
// Legacy CUDA IPC
cudaIpcMemHandle_t devIpc;
// cuMem API support
ncclCuDesc cuDesc;
} ncclIpcDesc;
ncclResult_t ncclP2pAllocateShareableBuffer(size_t size, ncclIpcDesc *ipcDesc, void **ptr);
ncclResult_t ncclP2pFreeShareableBuffer(ncclIpcDesc *ipcDesc);
ncclResult_t ncclP2pImportShareableBuffer(struct ncclComm *comm, int tpPeer, size_t size, ncclIpcDesc *ipcDesc, void **devMemPtr);
#endif #endif
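The ncclIpcDesc union above carries either a legacy cudaIpcMemHandle_t or a cuMem descriptor whose underlying handle type is the POSIX file descriptor named by NCCL_P2P_HANDLE_TYPE. A self-contained sketch of the cuMem side of that scheme, using only CUDA driver API calls: allocate a mappable buffer and export it as an fd that another process could import with cuMemImportFromShareableHandle(). This illustrates the handle flow, not the body of ncclP2pAllocateShareableBuffer(), and most error handling is omitted:

// Sketch: allocate GPU memory via cuMem* and export it as a POSIX fd. Illustrative only.
#include <cuda.h>
CUresult allocShareable(int dev, size_t size, CUdeviceptr* ptr, int* fd) {
  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = dev;
  prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

  size_t gran = 0;
  cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
  size = ((size + gran - 1) / gran) * gran;          // round the request up to the granularity

  CUmemGenericAllocationHandle handle;
  CUresult res = cuMemCreate(&handle, size, &prop, 0);
  if (res != CUDA_SUCCESS) return res;

  cuMemAddressReserve(ptr, size, gran, 0, 0);        // reserve VA space and map the allocation
  cuMemMap(*ptr, size, 0, handle, 0);

  CUmemAccessDesc access = {};
  access.location = prop.location;
  access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  cuMemSetAccess(*ptr, size, &access, 1);

  // Export the allocation as an fd a peer process can import with cuMemImportFromShareableHandle().
  return cuMemExportToShareableHandle(fd, handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0);
}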

View File

@ -13,11 +13,12 @@
#include "ipcsocket.h" #include "ipcsocket.h"
#include <pthread.h> #include <pthread.h>
#include "shm.h" #include "shm.h"
#include "p2p.h"
enum ncclProxyOpState { ncclProxyOpNone, ncclProxyOpReady, ncclProxyOpProgress }; enum ncclProxyOpState { ncclProxyOpNone, ncclProxyOpReady, ncclProxyOpProgress };
struct ncclProxyArgs; struct ncclProxyArgs;
typedef ncclResult_t (*proxyProgressFunc_t)(struct ncclComm*, struct ncclProxyArgs*); typedef ncclResult_t (*proxyProgressFunc_t)(struct ncclProxyState*, struct ncclProxyArgs*);
#define NCCL_PROXY_MAX_SUBS MAXCHANNELS #define NCCL_PROXY_MAX_SUBS MAXCHANNELS
static_assert(NCCL_MAX_WORK_ELEMENTS <= MAXCHANNELS, "Not enough sub space for max work elements"); static_assert(NCCL_MAX_WORK_ELEMENTS <= MAXCHANNELS, "Not enough sub space for max work elements");
@ -120,18 +121,11 @@ struct ncclProxySharedP2p {
int size; int size;
char* cudaBuff; char* cudaBuff;
char* hostBuff; char* hostBuff;
cudaIpcMemHandle_t ipc; // CUDA IPC
ncclIpcDesc ipcDesc;
struct ncclProxyArgs* proxyAppend[MAXCHANNELS]; // Separate send and recv struct ncclProxyArgs* proxyAppend[MAXCHANNELS]; // Separate send and recv
}; };
struct ncclProxySharedCollNet {
int size;
char* cudaBuff;
char* hostBuff;
struct ncclProxyArgs* proxyAppend[2*NCCL_MAX_NETDEVS];
void* resources;
};
struct ncclProxyPeer { struct ncclProxyPeer {
struct ncclProxySharedP2p send; struct ncclProxySharedP2p send;
struct ncclProxySharedP2p recv; struct ncclProxySharedP2p recv;
@ -155,7 +149,6 @@ struct ncclProxyProgressState {
bool stop; bool stop;
struct ncclProxyPeer** localPeers; struct ncclProxyPeer** localPeers;
struct ncclSharedNetComms* netComms[NCCL_MAX_NETDEVS]; struct ncclSharedNetComms* netComms[NCCL_MAX_NETDEVS];
struct ncclProxySharedCollNet collNet;
struct ncclProxyArgs* active; struct ncclProxyArgs* active;
struct ncclProxyArgs* pool; struct ncclProxyArgs* pool;
struct ncclProxyPool* pools; struct ncclProxyPool* pools;
@ -182,12 +175,27 @@ struct ncclProxyAsyncOp {
struct ncclProxyLocalPeer { struct ncclProxyLocalPeer {
struct ncclSocket sock; struct ncclSocket sock;
int localRank; int tpRank;
int tpLocalRank;
ncclProxyAsyncOp* asyncOps; ncclProxyAsyncOp* asyncOps;
int asyncOpCounter; int asyncOpCounter;
}; };
struct ncclProxyState { struct ncclProxyState {
int refCount;
int tpRank;
int tpnRanks;
int tpLocalnRanks;
int cudaDev;
int p2pnChannels;
int p2pChunkSize;
int nChannels;
int buffSizes[NCCL_NUM_PROTOCOLS];
bool allocP2pNetLLBuffers;
bool dmaBufSupport;
ncclNet_t* ncclNet;
ncclCollNet_t* ncclCollNet;
volatile uint32_t* abortFlag;
// Service thread // Service thread
pthread_t thread; pthread_t thread;
struct ncclSocket* listenSock; struct ncclSocket* listenSock;
@ -199,6 +207,7 @@ struct ncclProxyState {
struct ncclSocket* peerSocks; struct ncclSocket* peerSocks;
struct ncclProxyOps* proxyOps; struct ncclProxyOps* proxyOps;
void** sharedDevMems; void** sharedDevMems;
struct ncclIpcSocket peerIpcSock; // cuMEM API support (UDS)
// Progress thread // Progress thread
struct ncclProxyProgressState progressState; struct ncclProxyProgressState progressState;
@ -218,13 +227,14 @@ enum proxyConnectState {
struct ncclProxyConnection { struct ncclProxyConnection {
int send, transport, shared; int send, transport, shared;
int localRank; int tpLocalRank, sameProcess;
struct ncclSocket* sock; struct ncclSocket* sock;
struct ncclTransportComm* tcomm; struct ncclTransportComm* tcomm;
struct ncclProxyArgs *proxyAppend; struct ncclProxyArgs *proxyAppend;
struct ncclProxyArgs **proxyAppendPtr; struct ncclProxyArgs **proxyAppendPtr;
void* transportResources; void* transportResources;
proxyConnectState state; proxyConnectState state;
struct ncclCollNetSharedRes* collNet;
}; };
typedef ncclResult_t (*threadFunc_t)(struct ncclProxyArgs*); typedef ncclResult_t (*threadFunc_t)(struct ncclProxyArgs*);
@ -240,7 +250,7 @@ ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* prox
ncclResult_t ncclProxyStart(struct ncclComm* comm); ncclResult_t ncclProxyStart(struct ncclComm* comm);
ncclResult_t ncclProxyInit(struct ncclComm* comm, struct ncclSocket* sock, union ncclSocketAddress* peerAddresses); ncclResult_t ncclProxyInit(struct ncclComm* comm, struct ncclSocket* sock, union ncclSocketAddress* peerAddresses);
ncclResult_t ncclProxyCreate(struct ncclComm* comm); ncclResult_t ncclProxyCreate(struct ncclComm* comm);
ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, int rank, struct ncclProxyConnector* proxyConn); ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, int proxyRank, struct ncclProxyConnector* proxyConn);
enum ncclProxyMsgType { enum ncclProxyMsgType {
ncclProxyMsgInit = 1, ncclProxyMsgInit = 1,
ncclProxyMsgSharedInit = 2, ncclProxyMsgSharedInit = 2,
@ -250,18 +260,21 @@ enum ncclProxyMsgType {
ncclProxyMsgClose = 6, ncclProxyMsgClose = 6,
ncclProxyMsgAbort = 7, ncclProxyMsgAbort = 7,
ncclProxyMsgStop = 8, ncclProxyMsgStop = 8,
ncclProxyMsgConvertFd = 9 // cuMem API support ncclProxyMsgConvertFd = 9, // cuMem API support (UDS)
}; };
// This function is called by a client of the proxy that needs to invoke any of the non-progress proxyOp types // This function is called by a client of the proxy that needs to invoke any of the non-progress proxyOp types
// Call this function on the client, supplying a locally unique opId. Then, poll on the return value of // Call this function on the client, supplying a locally unique opId. Then, poll on the return value of
// ncclPollProxyResponse(), supplying the same opId to confirm the operation has completed // ncclPollProxyResponse(), supplying the same opId to confirm the operation has completed
ncclResult_t ncclProxyCallAsync(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, int respSize, void* opId); ncclResult_t ncclProxyCallAsync(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, int respSize, void* opId);
// This function will internally call ncclProxyCallAsync() and spin until ncclPollProxyResponse() confirms the result is received // This function will internally call ncclProxyCallAsync() and spin until ncclPollProxyResponse() confirms the result is received
ncclResult_t ncclProxyCallBlocking(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize); ncclResult_t ncclProxyCallBlocking(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize);
ncclResult_t ncclPollProxyResponse(struct ncclProxyConnector* proxyConn, void* respBuff, void* opId); ncclResult_t ncclPollProxyResponse(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, void* respBuff, void* opId);
ncclResult_t ncclProxyDestroy(struct ncclComm* comm); ncclResult_t ncclProxyClientConvertFdBlocking(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, int fd, int* convertedFd);
ncclResult_t ncclProxyStop(struct ncclComm* comm);
ncclResult_t ncclProxyShmUnlink(struct ncclComm* comm); ncclResult_t ncclProxyShmUnlink(struct ncclComm* comm);
ncclResult_t ncclProxyDestroy(struct ncclComm* comm);
#endif #endif
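The comment block above spells out the client-side convention for the reworked proxy RPC path: issue the request with ncclProxyCallAsync(), then poll ncclPollProxyResponse() with the same opId until the proxy thread has produced the response. A minimal sketch of that pattern using the prototypes from this header; it assumes an already initialized comm and connected proxyConn, and that ncclInProgress is the result returned while the response is still pending:

// Sketch of the async proxy call + poll pattern described above.
#include "checks.h"
#include "proxy.h"
ncclResult_t setupOverProxy(struct ncclComm* comm, struct ncclProxyConnector* proxyConn,
                            void* req, int reqSize, void* resp, int respSize) {
  void* opId = req;  // any locally unique pointer serves as the operation id
  NCCLCHECK(ncclProxyCallAsync(comm, proxyConn, ncclProxyMsgSetup, req, reqSize, respSize, opId));
  ncclResult_t res;
  do {
    res = ncclPollProxyResponse(comm, proxyConn, resp, opId);
  } while (res == ncclInProgress);  // assumed in-progress code while the proxy is still working
  return res;
}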

View File

@ -35,7 +35,6 @@ struct ncclComm;
struct ncclPeerInfo { struct ncclPeerInfo {
int rank; int rank;
int cudaDev; int cudaDev;
int netDev;
int gdrSupport; int gdrSupport;
uint64_t hostHash; uint64_t hostHash;
uint64_t pidHash; uint64_t pidHash;
@ -50,15 +49,46 @@ struct ncclConnect {
char data[CONNECT_SIZE]; char data[CONNECT_SIZE];
}; };
#if CUDART_VERSION >= 12010
#define NVLS_HANDLE_SIZE 64
struct ncclNvlsSharedRes {
int refCount;
CUmulticastObjectProp properties;
CUmemAccessDesc accessDesc;
int dev;
size_t size;
size_t granularity;
CUmemGenericAllocationHandle mcHandle; // Multicast handle for NVLS buffer
char* mcBuff; // Multicast NVLS buffer address
CUmemGenericAllocationHandle ucHandle; // Unicast Handle for NVLS buffer
char* ucBuff; // Unicast NVLS buffer address
char shareableHandle[NVLS_HANDLE_SIZE];
int nChannels;
};
#endif /* CUDART_VERSION >= 12010 */
struct ncclCollNetSharedRes {
int refCount;
int size;
char* cudaBuff;
char* hostBuff;
struct ncclProxyArgs* proxyAppend[2*NCCL_MAX_NETDEVS];
void* resources;
int nChannels;
size_t buffSize;
};
struct ncclTransportComm { struct ncclTransportComm {
ncclResult_t (*setup)(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo*, struct ncclPeerInfo*, struct ncclConnect*, struct ncclConnector*, int channelId, int connIndex); ncclResult_t (*setup)(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo*, struct ncclPeerInfo*, struct ncclConnect*, struct ncclConnector*, int channelId, int connIndex);
ncclResult_t (*connect)(struct ncclComm* comm, struct ncclConnect*, int nranks, int rank, struct ncclConnector*); ncclResult_t (*connect)(struct ncclComm* comm, struct ncclConnect*, int nranks, int rank, struct ncclConnector*);
ncclResult_t (*free)(struct ncclConnector*); ncclResult_t (*free)(struct ncclConnector*);
ncclResult_t (*proxySharedInit)(struct ncclProxyConnection* connection, struct ncclComm* comm, int nChannels); ncclResult_t (*proxySharedInit)(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, int nChannels);
ncclResult_t (*proxySetup)(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done); ncclResult_t (*proxySetup)(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done);
ncclResult_t (*proxyConnect)(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done); ncclResult_t (*proxyConnect)(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done);
ncclResult_t (*proxyFree)(struct ncclProxyConnection* connection, struct ncclComm* comm); ncclResult_t (*proxyFree)(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState);
ncclResult_t (*proxyProgress)(struct ncclComm* comm, struct ncclProxyArgs*); ncclResult_t (*proxyProgress)(struct ncclProxyState* proxyState, struct ncclProxyArgs*);
}; };
struct ncclTransport { struct ncclTransport {
@ -71,7 +101,8 @@ struct ncclTransport {
ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, int channelId, int nrecv, int* peerRecv, int nsend, int* peerSend, int connIndex); ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, int channelId, int nrecv, int* peerRecv, int nsend, int* peerSend, int connIndex);
ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType=NULL); ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType=NULL);
ncclResult_t ncclNvlsSetup(struct ncclComm* comm); ncclResult_t ncclNvlsInit(struct ncclComm* comm);
ncclResult_t ncclNvlsSetup(struct ncclComm* comm, struct ncclComm* parent);
ncclResult_t ncclNvlsFree(struct ncclComm* comm); ncclResult_t ncclNvlsFree(struct ncclComm* comm);
enum { collNetRecv=0, collNetSend=1 }; enum { collNetRecv=0, collNetSend=1 };

File diff suppressed because it is too large

View File

@ -6,10 +6,46 @@
#include "nccl.h" #include "nccl.h"
#include "debug.h" #include "debug.h"
#include "param.h"
#include "cudawrap.h" #include "cudawrap.h"
#include <dlfcn.h> #include <dlfcn.h>
// This env var (NCCL_CUMEM_ENABLE) toggles cuMem API usage
NCCL_PARAM(CuMemEnable, "CUMEM_ENABLE", 0);
static int ncclCuMemSupported = 0;
// Determine whether CUMEM & VMM RDMA is supported on this platform
int ncclIsCuMemSupported() {
#if CUDART_VERSION < 11030
return 0;
#else
CUdevice currentDev;
int cudaDev;
int cudaDriverVersion;
int flag = 0;
ncclResult_t ret = ncclSuccess;
CUDACHECKGOTO(cudaDriverGetVersion(&cudaDriverVersion), ret, error);
if (cudaDriverVersion < 12000) return 0; // Need CUDA_VISIBLE_DEVICES support
CUDACHECKGOTO(cudaGetDevice(&cudaDev), ret, error);
if (CUPFN(cuMemCreate) == NULL) return 0;
CUCHECKGOTO(cuDeviceGet(&currentDev, cudaDev), ret, error);
// Query device to see if CUMEM VMM support is available
CUCHECKGOTO(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, currentDev), ret, error);
if (!flag) return 0;
// Query device to see if CUMEM RDMA support is available
CUCHECKGOTO(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED, currentDev), ret, error);
if (!flag) return 0;
error:
return (ret == ncclSuccess);
#endif
}
int ncclCuMemEnable() {
return ((ncclParamCuMemEnable() == -2 && ncclCuMemSupported) || ncclParamCuMemEnable());
}
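With the default of 0 above, the cuMem path stays disabled unless NCCL_CUMEM_ENABLE is set; the expression in ncclCuMemEnable() only auto-enables it from the ncclCuMemSupported probe when the parameter carries the -2 sentinel. A sketch of how allocation code can branch on it; allocViaCuMem() and allocViaCudaMalloc() are hypothetical helper names used only for illustration:

// Sketch: gate the allocation strategy on ncclCuMemEnable().
#include <cstddef>
#include "nccl.h"
extern int ncclCuMemEnable();
extern ncclResult_t allocViaCuMem(void** ptr, size_t size);      // hypothetical helper
extern ncclResult_t allocViaCudaMalloc(void** ptr, size_t size); // hypothetical helper
ncclResult_t allocBuffer(void** ptr, size_t size) {
  if (ncclCuMemEnable()) {
    return allocViaCuMem(ptr, size);       // cuMem*: mappable, fd-shareable allocations
  }
  return allocViaCudaMalloc(ptr, size);    // legacy cudaMalloc + cudaIpcGetMemHandle path
}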
#define DECLARE_CUDA_PFN(symbol,version) PFN_##symbol##_v##version pfn_##symbol = nullptr #define DECLARE_CUDA_PFN(symbol,version) PFN_##symbol##_v##version pfn_##symbol = nullptr
#if CUDART_VERSION >= 11030 #if CUDART_VERSION >= 11030
@ -35,6 +71,7 @@ DECLARE_CUDA_PFN(cuMemExportToShareableHandle, 10020);
DECLARE_CUDA_PFN(cuMemImportFromShareableHandle, 10020); DECLARE_CUDA_PFN(cuMemImportFromShareableHandle, 10020);
DECLARE_CUDA_PFN(cuMemMap, 10020); DECLARE_CUDA_PFN(cuMemMap, 10020);
DECLARE_CUDA_PFN(cuMemRelease, 10020); DECLARE_CUDA_PFN(cuMemRelease, 10020);
DECLARE_CUDA_PFN(cuMemRetainAllocationHandle, 11000);
DECLARE_CUDA_PFN(cuMemSetAccess, 10020); DECLARE_CUDA_PFN(cuMemSetAccess, 10020);
DECLARE_CUDA_PFN(cuMemUnmap, 10020); DECLARE_CUDA_PFN(cuMemUnmap, 10020);
#if CUDA_VERSION >= 11070 #if CUDA_VERSION >= 11070
@ -89,7 +126,6 @@ static ncclResult_t cudaPfnFuncLoader(void) {
LOAD_SYM(cuCtxSetCurrent, 4000, 1); LOAD_SYM(cuCtxSetCurrent, 4000, 1);
LOAD_SYM(cuCtxGetDevice, 2000, 1); LOAD_SYM(cuCtxGetDevice, 2000, 1);
/* cuMem API support */ /* cuMem API support */
#if CUDA_VERSION >= 11030
LOAD_SYM(cuMemAddressReserve, 10020, 1); LOAD_SYM(cuMemAddressReserve, 10020, 1);
LOAD_SYM(cuMemAddressFree, 10020, 1); LOAD_SYM(cuMemAddressFree, 10020, 1);
LOAD_SYM(cuMemCreate, 10020, 1); LOAD_SYM(cuMemCreate, 10020, 1);
@ -98,9 +134,9 @@ static ncclResult_t cudaPfnFuncLoader(void) {
LOAD_SYM(cuMemImportFromShareableHandle, 10020, 1); LOAD_SYM(cuMemImportFromShareableHandle, 10020, 1);
LOAD_SYM(cuMemMap, 10020, 1); LOAD_SYM(cuMemMap, 10020, 1);
LOAD_SYM(cuMemRelease, 10020, 1); LOAD_SYM(cuMemRelease, 10020, 1);
LOAD_SYM(cuMemRetainAllocationHandle, 11000, 1);
LOAD_SYM(cuMemSetAccess, 10020, 1); LOAD_SYM(cuMemSetAccess, 10020, 1);
LOAD_SYM(cuMemUnmap, 10020, 1); LOAD_SYM(cuMemUnmap, 10020, 1);
#endif
#if CUDA_VERSION >= 11070 #if CUDA_VERSION >= 11070
LOAD_SYM(cuMemGetHandleForAddressRange, 11070, 1); // DMA-BUF support LOAD_SYM(cuMemGetHandleForAddressRange, 11070, 1); // DMA-BUF support
#endif #endif
@ -135,7 +171,7 @@ static void initOnceFunc() {
if (ncclCudaPath == NULL) if (ncclCudaPath == NULL)
snprintf(path, 1024, "%s", "libcuda.so"); snprintf(path, 1024, "%s", "libcuda.so");
else else
snprintf(path, 1024, "%s%s", ncclCudaPath, "libcuda.so"); snprintf(path, 1024, "%s/%s", ncclCudaPath, "libcuda.so");
(void) dlerror(); // Clear any previous errors (void) dlerror(); // Clear any previous errors
cudaLib = dlopen(path, RTLD_LAZY); cudaLib = dlopen(path, RTLD_LAZY);
@ -195,6 +231,9 @@ static void initOnceFunc() {
} }
#endif #endif
// Determine whether we support the cuMem APIs or not
ncclCuMemSupported = ncclIsCuMemSupported();
initResult = ncclSuccess; initResult = ncclSuccess;
return; return;
error: error:

src/misc/ibvsymbols.cc (new file, 158 lines)
View File

@ -0,0 +1,158 @@
#include <sys/types.h>
#include <unistd.h>
#include "ibvsymbols.h"
#ifdef NCCL_BUILD_RDMA_CORE
/* RDMA-core linking mode. Symbols are pointers to linked IB Verbs */
#define ASSIGN_SYM(container, symbol, name) container->name= &symbol;
// Passthrough function for ibv_reg_mr macro in verbs.h
struct ibv_mr* ibv_internal_reg_mr(
struct ibv_pd* pd,
void* addr,
size_t length,
int access) {
return ibv_reg_mr(pd, addr, length, access);
}
// Passthrough function for the ibv_query_port macro in verbs.h
int ibv_internal_query_port(
struct ibv_context* context,
uint8_t port_num,
struct ibv_port_attr* port_attr) {
return ibv_query_port(context, port_num, port_attr);
}
ncclResult_t buildIbvSymbols(struct ncclIbvSymbols* ibvSymbols) {
ASSIGN_SYM(ibvSymbols, ibv_get_device_list, ibv_internal_get_device_list);
ASSIGN_SYM(ibvSymbols, ibv_free_device_list, ibv_internal_free_device_list);
ASSIGN_SYM(ibvSymbols, ibv_get_device_name, ibv_internal_get_device_name);
ASSIGN_SYM(ibvSymbols, ibv_open_device, ibv_internal_open_device);
ASSIGN_SYM(ibvSymbols, ibv_close_device, ibv_internal_close_device);
ASSIGN_SYM(ibvSymbols, ibv_get_async_event, ibv_internal_get_async_event);
ASSIGN_SYM(ibvSymbols, ibv_ack_async_event, ibv_internal_ack_async_event);
ASSIGN_SYM(ibvSymbols, ibv_query_device, ibv_internal_query_device);
ASSIGN_SYM(ibvSymbols, ibv_query_gid, ibv_internal_query_gid);
ASSIGN_SYM(ibvSymbols, ibv_query_qp, ibv_internal_query_qp);
ASSIGN_SYM(ibvSymbols, ibv_alloc_pd, ibv_internal_alloc_pd);
ASSIGN_SYM(ibvSymbols, ibv_dealloc_pd, ibv_internal_dealloc_pd);
ASSIGN_SYM(ibvSymbols, ibv_reg_mr_iova2, ibv_internal_reg_mr_iova2);
ASSIGN_SYM(ibvSymbols, ibv_reg_dmabuf_mr, ibv_internal_reg_dmabuf_mr);
ASSIGN_SYM(ibvSymbols, ibv_dereg_mr, ibv_internal_dereg_mr);
ASSIGN_SYM(ibvSymbols, ibv_create_cq, ibv_internal_create_cq);
ASSIGN_SYM(ibvSymbols, ibv_destroy_cq, ibv_internal_destroy_cq);
ASSIGN_SYM(ibvSymbols, ibv_create_qp, ibv_internal_create_qp);
ASSIGN_SYM(ibvSymbols, ibv_modify_qp, ibv_internal_modify_qp);
ASSIGN_SYM(ibvSymbols, ibv_destroy_qp, ibv_internal_destroy_qp);
ASSIGN_SYM(ibvSymbols, ibv_fork_init, ibv_internal_fork_init);
ASSIGN_SYM(ibvSymbols, ibv_event_type_str, ibv_internal_event_type_str);
ibvSymbols->ibv_internal_reg_mr = &ibv_internal_reg_mr;
ibvSymbols->ibv_internal_query_port = &ibv_internal_query_port;
return ncclSuccess;
}
#else
/* RDMA-core dynamic loading mode. Symbols are loaded from shared objects. */
#include <dlfcn.h>
#include "core.h"
// IBVERBS Library versioning
#define IBVERBS_VERSION "IBVERBS_1.1"
ncclResult_t buildIbvSymbols(struct ncclIbvSymbols* ibvSymbols) {
static void* ibvhandle = NULL;
void* tmp;
void** cast;
ibvhandle=dlopen("libibverbs.so", RTLD_NOW);
if (!ibvhandle) {
ibvhandle=dlopen("libibverbs.so.1", RTLD_NOW);
if (!ibvhandle) {
INFO(NCCL_INIT, "Failed to open libibverbs.so[.1]");
goto teardown;
}
}
#define LOAD_SYM(handle, symbol, funcptr) do { \
cast = (void**)&funcptr; \
tmp = dlvsym(handle, symbol, IBVERBS_VERSION); \
if (tmp == NULL) { \
WARN("dlvsym failed on %s - %s version %s", symbol, dlerror(), IBVERBS_VERSION); \
goto teardown; \
} \
*cast = tmp; \
} while (0)
// Attempt to load a specific symbol version - fail silently
#define LOAD_SYM_VERSION(handle, symbol, funcptr, version) do { \
cast = (void**)&funcptr; \
*cast = dlvsym(handle, symbol, version); \
} while (0)
LOAD_SYM(ibvhandle, "ibv_get_device_list", ibvSymbols->ibv_internal_get_device_list);
LOAD_SYM(ibvhandle, "ibv_free_device_list", ibvSymbols->ibv_internal_free_device_list);
LOAD_SYM(ibvhandle, "ibv_get_device_name", ibvSymbols->ibv_internal_get_device_name);
LOAD_SYM(ibvhandle, "ibv_open_device", ibvSymbols->ibv_internal_open_device);
LOAD_SYM(ibvhandle, "ibv_close_device", ibvSymbols->ibv_internal_close_device);
LOAD_SYM(ibvhandle, "ibv_get_async_event", ibvSymbols->ibv_internal_get_async_event);
LOAD_SYM(ibvhandle, "ibv_ack_async_event", ibvSymbols->ibv_internal_ack_async_event);
LOAD_SYM(ibvhandle, "ibv_query_device", ibvSymbols->ibv_internal_query_device);
LOAD_SYM(ibvhandle, "ibv_query_port", ibvSymbols->ibv_internal_query_port);
LOAD_SYM(ibvhandle, "ibv_query_gid", ibvSymbols->ibv_internal_query_gid);
LOAD_SYM(ibvhandle, "ibv_query_qp", ibvSymbols->ibv_internal_query_qp);
LOAD_SYM(ibvhandle, "ibv_alloc_pd", ibvSymbols->ibv_internal_alloc_pd);
LOAD_SYM(ibvhandle, "ibv_dealloc_pd", ibvSymbols->ibv_internal_dealloc_pd);
LOAD_SYM(ibvhandle, "ibv_reg_mr", ibvSymbols->ibv_internal_reg_mr);
// Cherry-pick the ibv_reg_mr_iova2 API from IBVERBS 1.8
LOAD_SYM_VERSION(ibvhandle, "ibv_reg_mr_iova2", ibvSymbols->ibv_internal_reg_mr_iova2, "IBVERBS_1.8");
// Cherry-pick the ibv_reg_dmabuf_mr API from IBVERBS 1.12
LOAD_SYM_VERSION(ibvhandle, "ibv_reg_dmabuf_mr", ibvSymbols->ibv_internal_reg_dmabuf_mr, "IBVERBS_1.12");
LOAD_SYM(ibvhandle, "ibv_dereg_mr", ibvSymbols->ibv_internal_dereg_mr);
LOAD_SYM(ibvhandle, "ibv_create_cq", ibvSymbols->ibv_internal_create_cq);
LOAD_SYM(ibvhandle, "ibv_destroy_cq", ibvSymbols->ibv_internal_destroy_cq);
LOAD_SYM(ibvhandle, "ibv_create_qp", ibvSymbols->ibv_internal_create_qp);
LOAD_SYM(ibvhandle, "ibv_modify_qp", ibvSymbols->ibv_internal_modify_qp);
LOAD_SYM(ibvhandle, "ibv_destroy_qp", ibvSymbols->ibv_internal_destroy_qp);
LOAD_SYM(ibvhandle, "ibv_fork_init", ibvSymbols->ibv_internal_fork_init);
LOAD_SYM(ibvhandle, "ibv_event_type_str", ibvSymbols->ibv_internal_event_type_str);
return ncclSuccess;
teardown:
ibvSymbols->ibv_internal_get_device_list = NULL;
ibvSymbols->ibv_internal_free_device_list = NULL;
ibvSymbols->ibv_internal_get_device_name = NULL;
ibvSymbols->ibv_internal_open_device = NULL;
ibvSymbols->ibv_internal_close_device = NULL;
ibvSymbols->ibv_internal_get_async_event = NULL;
ibvSymbols->ibv_internal_ack_async_event = NULL;
ibvSymbols->ibv_internal_query_device = NULL;
ibvSymbols->ibv_internal_query_port = NULL;
ibvSymbols->ibv_internal_query_gid = NULL;
ibvSymbols->ibv_internal_query_qp = NULL;
ibvSymbols->ibv_internal_alloc_pd = NULL;
ibvSymbols->ibv_internal_dealloc_pd = NULL;
ibvSymbols->ibv_internal_reg_mr = NULL;
ibvSymbols->ibv_internal_reg_mr_iova2 = NULL;
ibvSymbols->ibv_internal_reg_dmabuf_mr = NULL;
ibvSymbols->ibv_internal_dereg_mr = NULL;
ibvSymbols->ibv_internal_create_cq = NULL;
ibvSymbols->ibv_internal_destroy_cq = NULL;
ibvSymbols->ibv_internal_create_qp = NULL;
ibvSymbols->ibv_internal_modify_qp = NULL;
ibvSymbols->ibv_internal_destroy_qp = NULL;
ibvSymbols->ibv_internal_fork_init = NULL;
ibvSymbols->ibv_internal_event_type_str = NULL;
if (ibvhandle != NULL) dlclose(ibvhandle);
return ncclSystemError;
}
#endif
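In the dynamic-loading branch above, mandatory entry points go through LOAD_SYM, which fails hard if the versioned dlvsym() lookup misses, while newer symbols such as ibv_reg_mr_iova2 and ibv_reg_dmabuf_mr use LOAD_SYM_VERSION and are simply left NULL on older rdma-core installs; the wrappers check for NULL before using them. A standalone sketch of the versioned lookup itself:

// Standalone sketch of the dlvsym() pattern: open libibverbs and resolve symbols
// under specific ABI versions, tolerating the absence of the newer one.
#define _GNU_SOURCE
#include <dlfcn.h>
#include <stdio.h>
int main() {
  void* h = dlopen("libibverbs.so.1", RTLD_NOW);
  if (!h) { fprintf(stderr, "libibverbs not found\n"); return 1; }
  void* required = dlvsym(h, "ibv_reg_mr", "IBVERBS_1.1");          // expected on any install
  void* optional = dlvsym(h, "ibv_reg_dmabuf_mr", "IBVERBS_1.12");  // may be NULL on old rdma-core
  printf("ibv_reg_mr=%p ibv_reg_dmabuf_mr=%p\n", required, optional);
  dlclose(h);
  return required ? 0 : 1;
}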

View File

@ -8,314 +8,186 @@
#include <sys/types.h> #include <sys/types.h>
#include <unistd.h> #include <unistd.h>
#include <dlfcn.h> #include "ibvsymbols.h"
#include "core.h"
/*Function Pointers*/
int (*ibv_internal_fork_init)(void);
struct ibv_device** (*ibv_internal_get_device_list)(int *num_devices);
void (*ibv_internal_free_device_list)(struct ibv_device **list);
const char * (*ibv_internal_get_device_name)(struct ibv_device *device);
struct ibv_context* (*ibv_internal_open_device)(struct ibv_device* device);
int (*ibv_internal_close_device)(struct ibv_context *context);
int (*ibv_internal_get_async_event)(struct ibv_context *context, struct ibv_async_event *event);
void (*ibv_internal_ack_async_event)(struct ibv_async_event *event);
int (*ibv_internal_query_device)(struct ibv_context *context, struct ibv_device_attr *device_attr);
int (*ibv_internal_query_port)(struct ibv_context *context, uint8_t port_num, struct ibv_port_attr *port_attr);
int (*ibv_internal_query_gid)(struct ibv_context *context, uint8_t port_num, int index, union ibv_gid *gid);
int (*ibv_internal_query_qp)(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask, struct ibv_qp_init_attr *init_attr);
struct ibv_pd * (*ibv_internal_alloc_pd)(struct ibv_context *context);
int (*ibv_internal_dealloc_pd)(struct ibv_pd *pd);
struct ibv_mr * (*ibv_internal_reg_mr)(struct ibv_pd *pd, void *addr, size_t length, int access);
struct ibv_mr * (*ibv_internal_reg_mr_iova2)(struct ibv_pd *pd, void *addr, size_t length, uint64_t iova, int access);
/* DMA-BUF support */
struct ibv_mr * (*ibv_internal_reg_dmabuf_mr)(struct ibv_pd *pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access);
int (*ibv_internal_dereg_mr)(struct ibv_mr *mr);
struct ibv_cq * (*ibv_internal_create_cq)(struct ibv_context *context, int cqe, void *cq_context, struct ibv_comp_channel *channel, int comp_vector);
int (*ibv_internal_destroy_cq)(struct ibv_cq *cq);
struct ibv_qp * (*ibv_internal_create_qp)(struct ibv_pd *pd, struct ibv_qp_init_attr *qp_init_attr);
int (*ibv_internal_modify_qp)(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask);
int (*ibv_internal_destroy_qp)(struct ibv_qp *qp);
const char * (*ibv_internal_event_type_str)(enum ibv_event_type event);
// IBVERBS Library versioning
#define IBVERBS_VERSION "IBVERBS_1.1"
static pthread_once_t initOnceControl = PTHREAD_ONCE_INIT; static pthread_once_t initOnceControl = PTHREAD_ONCE_INIT;
static ncclResult_t initResult; static ncclResult_t initResult;
struct ncclIbvSymbols ibvSymbols;
static void initOnceFunc(void) {
static void* ibvhandle = NULL;
void* tmp;
void** cast;
ibvhandle=dlopen("libibverbs.so", RTLD_NOW);
if (!ibvhandle) {
ibvhandle=dlopen("libibverbs.so.1", RTLD_NOW);
if (!ibvhandle) {
INFO(NCCL_INIT, "Failed to open libibverbs.so[.1]");
goto teardown;
}
}
#define LOAD_SYM(handle, symbol, funcptr) do { \
cast = (void**)&funcptr; \
tmp = dlvsym(handle, symbol, IBVERBS_VERSION); \
if (tmp == NULL) { \
WARN("dlvsym failed on %s - %s version %s", symbol, dlerror(), IBVERBS_VERSION); \
goto teardown; \
} \
*cast = tmp; \
} while (0)
// Attempt to load a specific symbol version - fail silently
#define LOAD_SYM_VERSION(handle, symbol, funcptr, version) do { \
cast = (void**)&funcptr; \
*cast = dlvsym(handle, symbol, version); \
} while (0)
LOAD_SYM(ibvhandle, "ibv_get_device_list", ibv_internal_get_device_list);
LOAD_SYM(ibvhandle, "ibv_free_device_list", ibv_internal_free_device_list);
LOAD_SYM(ibvhandle, "ibv_get_device_name", ibv_internal_get_device_name);
LOAD_SYM(ibvhandle, "ibv_open_device", ibv_internal_open_device);
LOAD_SYM(ibvhandle, "ibv_close_device", ibv_internal_close_device);
LOAD_SYM(ibvhandle, "ibv_get_async_event", ibv_internal_get_async_event);
LOAD_SYM(ibvhandle, "ibv_ack_async_event", ibv_internal_ack_async_event);
LOAD_SYM(ibvhandle, "ibv_query_device", ibv_internal_query_device);
LOAD_SYM(ibvhandle, "ibv_query_port", ibv_internal_query_port);
LOAD_SYM(ibvhandle, "ibv_query_gid", ibv_internal_query_gid);
LOAD_SYM(ibvhandle, "ibv_query_qp", ibv_internal_query_qp);
LOAD_SYM(ibvhandle, "ibv_alloc_pd", ibv_internal_alloc_pd);
LOAD_SYM(ibvhandle, "ibv_dealloc_pd", ibv_internal_dealloc_pd);
LOAD_SYM(ibvhandle, "ibv_reg_mr", ibv_internal_reg_mr);
// Cherry-pick the ibv_reg_mr_iova2 API from IBVERBS 1.8
LOAD_SYM_VERSION(ibvhandle, "ibv_reg_mr_iova2", ibv_internal_reg_mr_iova2, "IBVERBS_1.8");
// Cherry-pick the ibv_reg_dmabuf_mr API from IBVERBS 1.12
LOAD_SYM_VERSION(ibvhandle, "ibv_reg_dmabuf_mr", ibv_internal_reg_dmabuf_mr, "IBVERBS_1.12");
LOAD_SYM(ibvhandle, "ibv_dereg_mr", ibv_internal_dereg_mr);
LOAD_SYM(ibvhandle, "ibv_create_cq", ibv_internal_create_cq);
LOAD_SYM(ibvhandle, "ibv_destroy_cq", ibv_internal_destroy_cq);
LOAD_SYM(ibvhandle, "ibv_create_qp", ibv_internal_create_qp);
LOAD_SYM(ibvhandle, "ibv_modify_qp", ibv_internal_modify_qp);
LOAD_SYM(ibvhandle, "ibv_destroy_qp", ibv_internal_destroy_qp);
LOAD_SYM(ibvhandle, "ibv_fork_init", ibv_internal_fork_init);
LOAD_SYM(ibvhandle, "ibv_event_type_str", ibv_internal_event_type_str);
initResult = ncclSuccess;
return;
teardown:
ibv_internal_get_device_list = NULL;
ibv_internal_free_device_list = NULL;
ibv_internal_get_device_name = NULL;
ibv_internal_open_device = NULL;
ibv_internal_close_device = NULL;
ibv_internal_get_async_event = NULL;
ibv_internal_ack_async_event = NULL;
ibv_internal_query_device = NULL;
ibv_internal_query_port = NULL;
ibv_internal_query_gid = NULL;
ibv_internal_query_qp = NULL;
ibv_internal_alloc_pd = NULL;
ibv_internal_dealloc_pd = NULL;
ibv_internal_reg_mr = NULL;
ibv_internal_reg_mr_iova2 = NULL;
ibv_internal_reg_dmabuf_mr = NULL;
ibv_internal_dereg_mr = NULL;
ibv_internal_create_cq = NULL;
ibv_internal_destroy_cq = NULL;
ibv_internal_create_qp = NULL;
ibv_internal_modify_qp = NULL;
ibv_internal_destroy_qp = NULL;
ibv_internal_fork_init = NULL;
ibv_internal_event_type_str = NULL;
if (ibvhandle != NULL) dlclose(ibvhandle);
initResult = ncclSystemError;
return;
}
ncclResult_t wrap_ibv_symbols(void) { ncclResult_t wrap_ibv_symbols(void) {
pthread_once(&initOnceControl, initOnceFunc); pthread_once(&initOnceControl,
[](){ initResult = buildIbvSymbols(&ibvSymbols); });
return initResult; return initResult;
} }
#define IBV_PTR_CHECK_ERRNO(name_internal, call, retval, error_retval, name) \ /* CHECK_NOT_NULL: helper macro to check for NULL symbol */
if (name_internal == NULL) { \ #define CHECK_NOT_NULL(container, internal_name) \
if (container.internal_name == NULL) { \
WARN("lib wrapper not initialized."); \ WARN("lib wrapper not initialized."); \
return ncclInternalError; \ return ncclInternalError; \
} \ }
retval = call; \
#define IBV_PTR_CHECK_ERRNO(container, internal_name, call, retval, error_retval, name) \
CHECK_NOT_NULL(container, internal_name); \
retval = container.call; \
if (retval == error_retval) { \ if (retval == error_retval) { \
WARN("Call to " name " failed with error %s", strerror(errno)); \ WARN("Call to " name " failed with error %s", strerror(errno)); \
return ncclSystemError; \ return ncclSystemError; \
} \ } \
return ncclSuccess; return ncclSuccess;
#define IBV_PTR_CHECK(name_internal, call, retval, error_retval, name) \ #define IBV_PTR_CHECK(container, internal_name, call, retval, error_retval, name) \
if (name_internal == NULL) { \ CHECK_NOT_NULL(container, internal_name); \
WARN("lib wrapper not initialized."); \ retval = container.call; \
return ncclInternalError; \
} \
retval = call; \
if (retval == error_retval) { \ if (retval == error_retval) { \
WARN("Call to " name " failed"); \ WARN("Call to " name " failed"); \
return ncclSystemError; \ return ncclSystemError; \
} \ } \
return ncclSuccess; return ncclSuccess;
#define IBV_INT_CHECK_RET_ERRNO(name_internal, call, success_retval, name) \ #define IBV_INT_CHECK_RET_ERRNO(container, internal_name, call, success_retval, name) \
if (name_internal == NULL) { \ CHECK_NOT_NULL(container, internal_name); \
WARN("lib wrapper not initialized."); \ int ret = container.call; \
return ncclInternalError; \
} \
int ret = call; \
if (ret != success_retval) { \ if (ret != success_retval) { \
WARN("Call to " name " failed with error %s", strerror(ret)); \ WARN("Call to " name " failed with error %s", strerror(ret)); \
return ncclSystemError; \ return ncclSystemError; \
} \ } \
return ncclSuccess; return ncclSuccess;
#define IBV_INT_CHECK(name_internal, call, error_retval, name) \ #define IBV_INT_CHECK(container, internal_name, call, error_retval, name) \
if (name_internal == NULL) { \ CHECK_NOT_NULL(container, internal_name); \
WARN("lib wrapper not initialized."); \ int ret = container.call; \
return ncclInternalError; \
} \
int ret = call; \
if (ret == error_retval) { \ if (ret == error_retval) { \
WARN("Call to " name " failed"); \ WARN("Call to " name " failed"); \
return ncclSystemError; \ return ncclSystemError; \
} \ } \
return ncclSuccess; return ncclSuccess;
#define IBV_PASSTHRU(name_internal, call) \ #define IBV_PASSTHRU(container, internal_name, call) \
if (name_internal == NULL) { \ CHECK_NOT_NULL(container, internal_name); \
WARN("lib wrapper not initialized."); \ container.call; \
return ncclInternalError; \
} \
call; \
return ncclSuccess; return ncclSuccess;
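The refactor above funnels every wrapper through the table: CHECK_NOT_NULL validates the entry in ibvSymbols, then the call is forwarded through the stored function pointer. Expanded by hand within ibvwrap.cc's context (illustrative only), the dealloc_pd wrapper now behaves like:

// Hand expansion of IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_dealloc_pd, ...), for illustration.
ncclResult_t wrap_ibv_dealloc_pd_expanded(struct ibv_pd* pd) {
  if (ibvSymbols.ibv_internal_dealloc_pd == NULL) {
    WARN("lib wrapper not initialized.");
    return ncclInternalError;
  }
  int ret = ibvSymbols.ibv_internal_dealloc_pd(pd);
  if (ret != 0) {  // ibv_dealloc_pd returns 0 on success, an errno value on failure
    WARN("Call to ibv_dealloc_pd failed with error %s", strerror(ret));
    return ncclSystemError;
  }
  return ncclSuccess;
}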
ncclResult_t wrap_ibv_fork_init() { ncclResult_t wrap_ibv_fork_init() {
IBV_INT_CHECK(ibv_internal_fork_init, ibv_internal_fork_init(), -1, "ibv_fork_init"); IBV_INT_CHECK(ibvSymbols, ibv_internal_fork_init, ibv_internal_fork_init(), -1, "ibv_fork_init");
} }
ncclResult_t wrap_ibv_get_device_list(struct ibv_device ***ret, int *num_devices) { ncclResult_t wrap_ibv_get_device_list(struct ibv_device ***ret, int *num_devices) {
*ret = ibv_internal_get_device_list(num_devices); *ret = ibvSymbols.ibv_internal_get_device_list(num_devices);
if (*ret == NULL) *num_devices = 0; if (*ret == NULL) *num_devices = 0;
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t wrap_ibv_free_device_list(struct ibv_device **list) { ncclResult_t wrap_ibv_free_device_list(struct ibv_device **list) {
IBV_PASSTHRU(ibv_internal_free_device_list, ibv_internal_free_device_list(list)); IBV_PASSTHRU(ibvSymbols, ibv_internal_free_device_list, ibv_internal_free_device_list(list));
} }
const char *wrap_ibv_get_device_name(struct ibv_device *device) { const char *wrap_ibv_get_device_name(struct ibv_device *device) {
if (ibv_internal_get_device_name == NULL) { if (ibvSymbols.ibv_internal_get_device_name == NULL) {
WARN("lib wrapper not initialized."); WARN("lib wrapper not initialized.");
exit(-1); exit(-1);
} }
return ibv_internal_get_device_name(device); return ibvSymbols.ibv_internal_get_device_name(device);
} }
ncclResult_t wrap_ibv_open_device(struct ibv_context **ret, struct ibv_device *device) { /*returns 0 on success, -1 on failure*/ ncclResult_t wrap_ibv_open_device(struct ibv_context **ret, struct ibv_device *device) { /*returns 0 on success, -1 on failure*/
IBV_PTR_CHECK(ibv_internal_open_device, ibv_internal_open_device(device), *ret, NULL, "ibv_open_device"); IBV_PTR_CHECK(ibvSymbols, ibv_internal_open_device, ibv_internal_open_device(device), *ret, NULL, "ibv_open_device");
} }
ncclResult_t wrap_ibv_close_device(struct ibv_context *context) { /*returns 0 on success, -1 on failure*/ ncclResult_t wrap_ibv_close_device(struct ibv_context *context) { /*returns 0 on success, -1 on failure*/
IBV_INT_CHECK(ibv_internal_close_device, ibv_internal_close_device(context), -1, "ibv_close_device"); IBV_INT_CHECK(ibvSymbols, ibv_internal_close_device, ibv_internal_close_device(context), -1, "ibv_close_device");
} }
ncclResult_t wrap_ibv_get_async_event(struct ibv_context *context, struct ibv_async_event *event) { /*returns 0 on success, and -1 on error*/ ncclResult_t wrap_ibv_get_async_event(struct ibv_context *context, struct ibv_async_event *event) { /*returns 0 on success, and -1 on error*/
IBV_INT_CHECK(ibv_internal_get_async_event, ibv_internal_get_async_event(context, event), -1, "ibv_get_async_event"); IBV_INT_CHECK(ibvSymbols, ibv_internal_get_async_event, ibv_internal_get_async_event(context, event), -1, "ibv_get_async_event");
} }
ncclResult_t wrap_ibv_ack_async_event(struct ibv_async_event *event) { ncclResult_t wrap_ibv_ack_async_event(struct ibv_async_event *event) {
IBV_PASSTHRU(ibv_internal_ack_async_event, ibv_internal_ack_async_event(event)); IBV_PASSTHRU(ibvSymbols, ibv_internal_ack_async_event, ibv_internal_ack_async_event(event));
} }
ncclResult_t wrap_ibv_query_device(struct ibv_context *context, struct ibv_device_attr *device_attr) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/ ncclResult_t wrap_ibv_query_device(struct ibv_context *context, struct ibv_device_attr *device_attr) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
IBV_INT_CHECK_RET_ERRNO(ibv_internal_query_device, ibv_internal_query_device(context, device_attr), 0, "ibv_query_device"); IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_query_device, ibv_internal_query_device(context, device_attr), 0, "ibv_query_device");
} }
ncclResult_t wrap_ibv_query_port(struct ibv_context *context, uint8_t port_num, struct ibv_port_attr *port_attr) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/ ncclResult_t wrap_ibv_query_port(struct ibv_context *context, uint8_t port_num, struct ibv_port_attr *port_attr) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
IBV_INT_CHECK_RET_ERRNO(ibv_internal_query_port, ibv_internal_query_port(context, port_num, port_attr), 0, "ibv_query_port"); IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_query_port, ibv_internal_query_port(context, port_num, port_attr), 0, "ibv_query_port");
} }
ncclResult_t wrap_ibv_query_gid(struct ibv_context *context, uint8_t port_num, int index, union ibv_gid *gid) { ncclResult_t wrap_ibv_query_gid(struct ibv_context *context, uint8_t port_num, int index, union ibv_gid *gid) {
IBV_INT_CHECK_RET_ERRNO(ibv_internal_query_gid, ibv_internal_query_gid(context, port_num, index, gid), 0, "ibv_query_gid"); IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_query_gid, ibv_internal_query_gid(context, port_num, index, gid), 0, "ibv_query_gid");
} }
ncclResult_t wrap_ibv_query_qp(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask, struct ibv_qp_init_attr *init_attr) { ncclResult_t wrap_ibv_query_qp(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask, struct ibv_qp_init_attr *init_attr) {
IBV_INT_CHECK_RET_ERRNO(ibv_internal_query_qp, ibv_internal_query_qp(qp, attr, attr_mask, init_attr), 0, "ibv_query_qp"); IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_query_qp, ibv_internal_query_qp(qp, attr, attr_mask, init_attr), 0, "ibv_query_qp");
} }
ncclResult_t wrap_ibv_alloc_pd(struct ibv_pd **ret, struct ibv_context *context) { ncclResult_t wrap_ibv_alloc_pd(struct ibv_pd **ret, struct ibv_context *context) {
IBV_PTR_CHECK_ERRNO(ibv_internal_alloc_pd, ibv_internal_alloc_pd(context), *ret, NULL, "ibv_alloc_pd"); IBV_PTR_CHECK_ERRNO(ibvSymbols, ibv_internal_alloc_pd, ibv_internal_alloc_pd(context), *ret, NULL, "ibv_alloc_pd");
} }
ncclResult_t wrap_ibv_dealloc_pd(struct ibv_pd *pd) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/ ncclResult_t wrap_ibv_dealloc_pd(struct ibv_pd *pd) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
IBV_INT_CHECK_RET_ERRNO(ibv_internal_dealloc_pd, ibv_internal_dealloc_pd(pd), 0, "ibv_dealloc_pd"); IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_dealloc_pd, ibv_internal_dealloc_pd(pd), 0, "ibv_dealloc_pd");
} }
ncclResult_t wrap_ibv_reg_mr(struct ibv_mr **ret, struct ibv_pd *pd, void *addr, size_t length, int access) { ncclResult_t wrap_ibv_reg_mr(struct ibv_mr **ret, struct ibv_pd *pd, void *addr, size_t length, int access) {
IBV_PTR_CHECK_ERRNO(ibv_internal_reg_mr, ibv_internal_reg_mr(pd, addr, length, access), *ret, NULL, "ibv_reg_mr"); IBV_PTR_CHECK_ERRNO(ibvSymbols, ibv_internal_reg_mr, ibv_internal_reg_mr(pd, addr, length, access), *ret, NULL, "ibv_reg_mr");
} }
struct ibv_mr * wrap_direct_ibv_reg_mr(struct ibv_pd *pd, void *addr, size_t length, int access) { struct ibv_mr * wrap_direct_ibv_reg_mr(struct ibv_pd *pd, void *addr, size_t length, int access) {
if (ibv_internal_reg_mr == NULL) { if (ibvSymbols.ibv_internal_reg_mr == NULL) {
WARN("lib wrapper not initialized."); WARN("lib wrapper not initialized.");
return NULL; return NULL;
} }
return ibv_internal_reg_mr(pd, addr, length, access); return ibvSymbols.ibv_internal_reg_mr(pd, addr, length, access);
} }
ncclResult_t wrap_ibv_reg_mr_iova2(struct ibv_mr **ret, struct ibv_pd *pd, void *addr, size_t length, uint64_t iova, int access) { ncclResult_t wrap_ibv_reg_mr_iova2(struct ibv_mr **ret, struct ibv_pd *pd, void *addr, size_t length, uint64_t iova, int access) {
if (ibv_internal_reg_mr_iova2 == NULL) { if (ibvSymbols.ibv_internal_reg_mr_iova2 == NULL) {
return ncclInternalError; return ncclInternalError;
} }
if (ret == NULL) { return ncclSuccess; } // Assume dummy call if (ret == NULL) { return ncclSuccess; } // Assume dummy call
IBV_PTR_CHECK_ERRNO(ibv_internal_reg_mr_iova2, ibv_internal_reg_mr_iova2(pd, addr, length, iova, access), *ret, NULL, "ibv_reg_mr_iova2"); IBV_PTR_CHECK_ERRNO(ibvSymbols, ibv_internal_reg_mr_iova2, ibv_internal_reg_mr_iova2(pd, addr, length, iova, access), *ret, NULL, "ibv_reg_mr_iova2");
} }
/* DMA-BUF support */ /* DMA-BUF support */
ncclResult_t wrap_ibv_reg_dmabuf_mr(struct ibv_mr **ret, struct ibv_pd *pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access) { ncclResult_t wrap_ibv_reg_dmabuf_mr(struct ibv_mr **ret, struct ibv_pd *pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access) {
IBV_PTR_CHECK_ERRNO(ibv_internal_reg_dmabuf_mr, ibv_internal_reg_dmabuf_mr(pd, offset, length, iova, fd, access), *ret, NULL, "ibv_reg_dmabuf_mr"); IBV_PTR_CHECK_ERRNO(ibvSymbols, ibv_internal_reg_dmabuf_mr, ibv_internal_reg_dmabuf_mr(pd, offset, length, iova, fd, access), *ret, NULL, "ibv_reg_dmabuf_mr");
} }
struct ibv_mr * wrap_direct_ibv_reg_dmabuf_mr(struct ibv_pd *pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access) { struct ibv_mr * wrap_direct_ibv_reg_dmabuf_mr(struct ibv_pd *pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access) {
if (ibv_internal_reg_dmabuf_mr == NULL) { if (ibvSymbols.ibv_internal_reg_dmabuf_mr == NULL) {
errno = EOPNOTSUPP; // ncclIbDmaBufSupport() requires this errno being set errno = EOPNOTSUPP; // ncclIbDmaBufSupport() requires this errno being set
return NULL; return NULL;
} }
return ibv_internal_reg_dmabuf_mr(pd, offset, length, iova, fd, access); return ibvSymbols.ibv_internal_reg_dmabuf_mr(pd, offset, length, iova, fd, access);
} }
ncclResult_t wrap_ibv_dereg_mr(struct ibv_mr *mr) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/ ncclResult_t wrap_ibv_dereg_mr(struct ibv_mr *mr) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
IBV_INT_CHECK_RET_ERRNO(ibv_internal_dereg_mr, ibv_internal_dereg_mr(mr), 0, "ibv_dereg_mr"); IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_dereg_mr, ibv_internal_dereg_mr(mr), 0, "ibv_dereg_mr");
} }
ncclResult_t wrap_ibv_create_cq(struct ibv_cq **ret, struct ibv_context *context, int cqe, void *cq_context, struct ibv_comp_channel *channel, int comp_vector) { ncclResult_t wrap_ibv_create_cq(struct ibv_cq **ret, struct ibv_context *context, int cqe, void *cq_context, struct ibv_comp_channel *channel, int comp_vector) {
IBV_PTR_CHECK_ERRNO(ibv_internal_create_cq, ibv_internal_create_cq(context, cqe, cq_context, channel, comp_vector), *ret, NULL, "ibv_create_cq"); IBV_PTR_CHECK_ERRNO(ibvSymbols, ibv_internal_create_cq, ibv_internal_create_cq(context, cqe, cq_context, channel, comp_vector), *ret, NULL, "ibv_create_cq");
} }
ncclResult_t wrap_ibv_destroy_cq(struct ibv_cq *cq) { ncclResult_t wrap_ibv_destroy_cq(struct ibv_cq *cq) {
IBV_INT_CHECK_RET_ERRNO(ibv_internal_destroy_cq, ibv_internal_destroy_cq(cq), 0, "ibv_destroy_cq"); IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_destroy_cq, ibv_internal_destroy_cq(cq), 0, "ibv_destroy_cq");
} }
ncclResult_t wrap_ibv_destroy_qp(struct ibv_qp *qp) { ncclResult_t wrap_ibv_destroy_qp(struct ibv_qp *qp) {
IBV_INT_CHECK_RET_ERRNO(ibv_internal_destroy_qp, ibv_internal_destroy_qp(qp), 0, "ibv_destroy_qp"); IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_destroy_qp, ibv_internal_destroy_qp(qp), 0, "ibv_destroy_qp");
} }
ncclResult_t wrap_ibv_create_qp(struct ibv_qp **ret, struct ibv_pd *pd, struct ibv_qp_init_attr *qp_init_attr) { ncclResult_t wrap_ibv_create_qp(struct ibv_qp **ret, struct ibv_pd *pd, struct ibv_qp_init_attr *qp_init_attr) {
IBV_PTR_CHECK_ERRNO(ibv_internal_create_qp, ibv_internal_create_qp(pd, qp_init_attr), *ret, NULL, "ibv_create_qp"); IBV_PTR_CHECK_ERRNO(ibvSymbols, ibv_internal_create_qp, ibv_internal_create_qp(pd, qp_init_attr), *ret, NULL, "ibv_create_qp");
} }
ncclResult_t wrap_ibv_modify_qp(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/ ncclResult_t wrap_ibv_modify_qp(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
IBV_INT_CHECK_RET_ERRNO(ibv_internal_modify_qp, ibv_internal_modify_qp(qp, attr, attr_mask), 0, "ibv_modify_qp"); IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_modify_qp, ibv_internal_modify_qp(qp, attr, attr_mask), 0, "ibv_modify_qp");
} }
ncclResult_t wrap_ibv_event_type_str(char **ret, enum ibv_event_type event) { ncclResult_t wrap_ibv_event_type_str(char **ret, enum ibv_event_type event) {
*ret = (char *) ibv_internal_event_type_str(event); *ret = (char *) ibvSymbols.ibv_internal_event_type_str(event);
return ncclSuccess; return ncclSuccess;
} }
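The wrappers above no longer dereference per-symbol globals; every call now goes through a single ibvSymbols table that is filled at init time, either by dlopen'ing libibverbs or, with RDMA_CORE=1, by binding directly to the linked rdma-core functions. The actual table and macros live in ibvsymbols.cc/ibvwrap.h and are not part of this hunk; the following is only a minimal sketch, with hypothetical field layout and macro body, of how such a dispatch table and errno-style check could look.

struct ncclIbvSymbols {
  /* One function pointer per wrapped verb; names mirror the ibv_internal_* fields used above. */
  int (*ibv_internal_query_device)(struct ibv_context* context, struct ibv_device_attr* device_attr);
  struct ibv_mr* (*ibv_internal_reg_mr)(struct ibv_pd* pd, void* addr, size_t length, int access);
  /* ... remaining verbs ... */
};

/* Hypothetical check macro: fail cleanly if the symbol was never resolved, then
 * treat any return value other than success_retval as an errno and map it to
 * ncclSystemError, otherwise return ncclSuccess. */
#define IBV_INT_CHECK_RET_ERRNO(symbols, symbol, call, success_retval, name) \
  do { \
    if ((symbols).symbol == NULL) { \
      WARN("lib wrapper not initialized."); \
      return ncclInternalError; \
    } \
    int ret = (symbols).call; \
    if (ret != (success_retval)) { \
      WARN("Call to " name " failed with error %s", strerror(ret)); \
      return ncclSystemError; \
    } \
    return ncclSuccess; \
  } while (0)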

View File

@ -14,6 +14,7 @@
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <unistd.h> #include <unistd.h>
#include <utils.h>
struct shmHandleInternal { struct shmHandleInternal {
int fd; int fd;
@ -86,17 +87,13 @@ ncclResult_t ncclShmOpen(char* shmPath, size_t shmSize, void** shmPtr, void** de
if (create) { if (create) {
*(int*)(hptr + shmSize) = refcount; *(int*)(hptr + shmSize) = refcount;
} else { } else {
int remref = __atomic_sub_fetch((int*)(hptr + shmSize), 1, __ATOMIC_RELAXED); int remref = ncclAtomicRefCountDecrement((int*)(hptr + shmSize));
if (remref == 0) { if (remref == 0) {
/* the last peer has completed attachment, it should unlink the shm mem file. */ /* the last peer has completed attachment, it should unlink the shm mem file. */
if (unlink(shmPath) != 0) { if (unlink(shmPath) != 0) {
WARN("unlink shared memory %s failed, error: %s", shmPath, strerror(errno)); WARN("unlink shared memory %s failed, error: %s", shmPath, strerror(errno));
} }
} }
if (refcount != -1) {
WARN("attaching memory should only reduce refcount by 1 but %d is passed", refcount);
}
} }
if (devShmPtr) { if (devShmPtr) {
@ -133,8 +130,8 @@ ncclResult_t ncclShmClose(ncclShmHandle_t handle) {
WARN("unlink shared memory %s failed, error: %s", tmphandle->shmPath, strerror(errno)); WARN("unlink shared memory %s failed, error: %s", tmphandle->shmPath, strerror(errno));
ret = ncclSystemError; ret = ncclSystemError;
} }
free(tmphandle->shmPath);
} }
free(tmphandle->shmPath);
} }
if (tmphandle->shmPtr) { if (tmphandle->shmPtr) {

View File

@ -411,7 +411,7 @@ static ncclResult_t socketTryAccept(struct ncclSocket* sock) {
if (sock->fd != -1) { if (sock->fd != -1) {
sock->state = ncclSocketStateAccepted; sock->state = ncclSocketStateAccepted;
} else if (errno != EAGAIN && errno != EWOULDBLOCK) { } else if (errno != EAGAIN && errno != EWOULDBLOCK) {
WARN("socketTryAccept: get errno %d that is not EAGAIN or EWOULDBLOCK", errno); WARN("socketTryAccept: Accept failed: %s", strerror(errno));
return ncclSystemError; return ncclSystemError;
} }
return ncclSuccess; return ncclSuccess;

View File

@ -46,6 +46,7 @@ typedef enum { ncclSuccess = 0,
#define NCCL_CONFIG_UNDEF_INT INT_MIN #define NCCL_CONFIG_UNDEF_INT INT_MIN
#define NCCL_CONFIG_UNDEF_PTR NULL #define NCCL_CONFIG_UNDEF_PTR NULL
#define NCCL_SPLIT_NOCOLOR -1
/* Communicator configuration. Users can assign value to attributes to specify the /* Communicator configuration. Users can assign value to attributes to specify the
* behavior of a communicator. */ * behavior of a communicator. */
@ -60,6 +61,7 @@ typedef struct ncclConfig_v21700 {
int minCTAs; int minCTAs;
int maxCTAs; int maxCTAs;
const char *netName; const char *netName;
int splitShare;
} ncclConfig_t; } ncclConfig_t;
/* Config initializer must be assigned to initialize config structure when it is created. /* Config initializer must be assigned to initialize config structure when it is created.
@ -72,7 +74,8 @@ typedef struct ncclConfig_v21700 {
NCCL_CONFIG_UNDEF_INT, /* cgaClusterSize */ \ NCCL_CONFIG_UNDEF_INT, /* cgaClusterSize */ \
NCCL_CONFIG_UNDEF_INT, /* minCTAs */ \ NCCL_CONFIG_UNDEF_INT, /* minCTAs */ \
NCCL_CONFIG_UNDEF_INT, /* maxCTAs */ \ NCCL_CONFIG_UNDEF_INT, /* maxCTAs */ \
NCCL_CONFIG_UNDEF_PTR /* netName */ \ NCCL_CONFIG_UNDEF_PTR, /* netName */ \
NCCL_CONFIG_UNDEF_INT /* splitShare */ \
} }
/* Return the NCCL_VERSION_CODE of the NCCL library in the supplied integer. /* Return the NCCL_VERSION_CODE of the NCCL library in the supplied integer.
@ -128,6 +131,16 @@ ncclResult_t pncclCommDestroy(ncclComm_t comm);
ncclResult_t ncclCommAbort(ncclComm_t comm); ncclResult_t ncclCommAbort(ncclComm_t comm);
ncclResult_t pncclCommAbort(ncclComm_t comm); ncclResult_t pncclCommAbort(ncclComm_t comm);
/* Creates one or more communicators from an existing one.
* Ranks with the same color will end up in the same communicator.
* Within the new communicator, key will be used to order ranks.
* NCCL_SPLIT_NOCOLOR as color will indicate the rank will not be part of any group
* and will therefore return a NULL communicator.
* If config is NULL, the new communicator will inherit the original communicator's
* configuration*/
ncclResult_t ncclCommSplit(ncclComm_t comm, int color, int key, ncclComm_t *newcomm, ncclConfig_t* config);
ncclResult_t pncclCommSplit(ncclComm_t comm, int color, int key, ncclComm_t *newcomm, ncclConfig_t* config);
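Together with the new splitShare config field above, these declarations are enough to sketch a call site. This is only an illustrative usage example: the color/key choice, the parentComm handle and the NCCLCHECK error-checking macro are placeholders, not part of this commit.

/* Split parentComm into two halves; ranks in the same half share a new communicator. */
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
config.splitShare = 1;                      /* request resource sharing with the parent */

int nRanks, rank;
NCCLCHECK(ncclCommCount(parentComm, &nRanks));
NCCLCHECK(ncclCommUserRank(parentComm, &rank));

int color = (rank < nRanks / 2) ? 0 : 1;    /* two groups of ranks */
ncclComm_t subComm;
NCCLCHECK(ncclCommSplit(parentComm, color, /*key=*/rank, &subComm, &config));
/* A rank passing NCCL_SPLIT_NOCOLOR as color would get subComm == NULL instead. */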
/* Returns a string for each error code. */ /* Returns a string for each error code. */
const char* ncclGetErrorString(ncclResult_t result); const char* ncclGetErrorString(ncclResult_t result);
const char* pncclGetErrorString(ncclResult_t result); const char* pncclGetErrorString(ncclResult_t result);

View File

@ -258,10 +258,10 @@ static ncclResult_t collNetGetState(int i, enum ncclNetState* state) {
ncclResult_t ncclNetInit(struct ncclComm* comm) { ncclResult_t ncclNetInit(struct ncclComm* comm) {
// Initialize main communication network // Initialize main communication network
char* netName; const char* netName;
bool ok = false; bool ok = false;
netName = comm->netName; netName = comm->config.netName;
for (int i=0; i<3; i++) { for (int i=0; i<3; i++) {
if (ncclNets[i] == nullptr) continue; if (ncclNets[i] == nullptr) continue;
enum ncclNetState state; enum ncclNetState state;
@ -302,23 +302,27 @@ ncclResult_t ncclGpuGdrSupport(struct ncclComm* comm, int* gdrSupport) {
return ncclSuccess; return ncclSuccess;
} }
#endif #endif
int netDevs; static int gdrSupportMatrix[32] = {
NCCLCHECK(ncclNetDevices(comm, &netDevs)); -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
*gdrSupport = 0; -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 };
for (int dev=0; dev<netDevs; dev++) { if (gdrSupportMatrix[comm->cudaDev] == -1) {
// Find a net device which is GDR-capable int netDevs;
ncclNetProperties_t props; NCCLCHECK(comm->ncclNet->devices(&netDevs));
NCCLCHECK(ncclNetGetProperties(comm, dev, &props)); gdrSupportMatrix[comm->cudaDev] = 0;
if ((props.ptrSupport & NCCL_PTR_CUDA) == 0) continue; for (int dev=0; dev<netDevs; dev++) {
// Find a net device which is GDR-capable
ncclNetProperties_t props;
NCCLCHECK(comm->ncclNet->getProperties(dev, &props));
if ((props.ptrSupport & NCCL_PTR_CUDA) == 0) continue;
// Allocate memory on the GPU and try to register it on the NIC. // Allocate memory on the GPU and try to register it on the NIC.
void *lComm = NULL, *sComm = NULL, *rComm = NULL; void *lComm = NULL, *sComm = NULL, *rComm = NULL;
ncclNetHandle_t handle; ncclNetHandle_t handle;
void* gpuPtr = NULL; char* gpuPtr = NULL;
void* mHandle = NULL; void* mHandle = NULL;
ncclResult_t ret; ncclResult_t ret;
ncclDebugNoWarn = NCCL_NET; ncclDebugNoWarn = NCCL_NET;
NCCLCHECKGOTO(ncclNetListen(comm, dev, &handle, &lComm), ret, cleanup1); NCCLCHECKGOTO(comm->ncclNet->listen(dev, &handle, &lComm), ret, cleanup1);
bool connected; bool connected;
connected = false; connected = false;
@ -330,32 +334,34 @@ ncclResult_t ncclGpuGdrSupport(struct ncclComm* comm, int* gdrSupport) {
} }
if (sComm == NULL) if (sComm == NULL)
NCCLCHECKGOTO(ncclNetConnect(comm, dev, &handle, &sComm), ret, cleanup2); NCCLCHECKGOTO(comm->ncclNet->connect(dev, &handle, &sComm), ret, cleanup2);
if (rComm == NULL) if (rComm == NULL)
NCCLCHECKGOTO(ncclNetAccept(comm, lComm, &rComm), ret, cleanup2); NCCLCHECKGOTO(comm->ncclNet->accept(lComm, &rComm), ret, cleanup2);
connected = (rComm != NULL) && (sComm != NULL); connected = (rComm != NULL) && (sComm != NULL);
} }
CUDACHECKGOTO(cudaMalloc(&gpuPtr, GPU_BUF_SIZE), ret, cleanup2); NCCLCHECKGOTO(ncclCudaMalloc(&gpuPtr, GPU_BUF_SIZE), ret, cleanup2);
if (ncclNetRegMr(comm, sComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle) == ncclSuccess) { if (comm->ncclNet->regMr(sComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle) == ncclSuccess) {
NCCLCHECK(ncclNetDeregMr(comm, sComm, mHandle)); NCCLCHECK(comm->ncclNet->deregMr(sComm, mHandle));
NCCLCHECK(ncclNetRegMr(comm, rComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle)); NCCLCHECK(comm->ncclNet->regMr(rComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle));
NCCLCHECK(ncclNetDeregMr(comm, rComm, mHandle)); NCCLCHECK(comm->ncclNet->deregMr(rComm, mHandle));
*gdrSupport = 1; gdrSupportMatrix[comm->cudaDev] = 1;
} }
ncclDebugNoWarn = 0; ncclDebugNoWarn = 0;
CUDACHECK(cudaFree(gpuPtr)); NCCLCHECK(ncclCudaFree(gpuPtr));
cleanup2: cleanup2:
if (rComm != NULL) if (rComm != NULL)
NCCLCHECK(ncclNetCloseRecv(comm, rComm)); NCCLCHECK(comm->ncclNet->closeRecv(rComm));
if (sComm != NULL) if (sComm != NULL)
NCCLCHECK(ncclNetCloseSend(comm, sComm)); NCCLCHECK(comm->ncclNet->closeSend(sComm));
NCCLCHECK(ncclNetCloseListen(comm, lComm)); NCCLCHECK(comm->ncclNet->closeListen(lComm));
cleanup1: cleanup1:
break; break;
}
} }
*gdrSupport = gdrSupportMatrix[comm->cudaDev];
return ncclSuccess; return ncclSuccess;
} }

File diff suppressed because it is too large

View File

@ -21,8 +21,8 @@ template <int type>
static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclConnect* connect, int channelId, int peer, int connIndex, int* transportType) { static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclConnect* connect, int channelId, int peer, int connIndex, int* transportType) {
struct ncclPeerInfo* myInfo = comm->peerInfo+comm->rank; struct ncclPeerInfo* myInfo = comm->peerInfo+comm->rank;
struct ncclPeerInfo* peerInfo = comm->peerInfo+peer; struct ncclPeerInfo* peerInfo = comm->peerInfo+peer;
struct ncclConnector* connector = (type == 1) ? comm->channels[channelId].peers[peer].send + connIndex : struct ncclConnector* connector = (type == 1) ? comm->channels[channelId].peers[peer]->send + connIndex :
comm->channels[channelId].peers[peer].recv + connIndex; comm->channels[channelId].peers[peer]->recv + connIndex;
for (int t=0; t<NTRANSPORTS; t++) { for (int t=0; t<NTRANSPORTS; t++) {
struct ncclTransport *transport = ncclTransports[t]; struct ncclTransport *transport = ncclTransports[t];
struct ncclTransportComm* transportComm = type == 1 ? &transport->send : &transport->recv; struct ncclTransportComm* transportComm = type == 1 ? &transport->send : &transport->recv;
@ -45,12 +45,12 @@ ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, int channelId, int n
uint64_t mask = 1UL << channel->id; uint64_t mask = 1UL << channel->id;
for (int i=0; i<nrecv; i++) { for (int i=0; i<nrecv; i++) {
int peer = peerRecv[i]; int peer = peerRecv[i];
if (peer == -1 || peer >= comm->nRanks || peer == comm->rank || channel->peers[peer].recv[connIndex].connected) continue; if (peer == -1 || peer >= comm->nRanks || peer == comm->rank || channel->peers[peer]->recv[connIndex].connected) continue;
comm->connectRecv[peer] |= mask; comm->connectRecv[peer] |= mask;
} }
for (int i=0; i<nsend; i++) { for (int i=0; i<nsend; i++) {
int peer = peerSend[i]; int peer = peerSend[i];
if (peer == -1 || peer >= comm->nRanks || peer == comm->rank || channel->peers[peer].send[connIndex].connected) continue; if (peer == -1 || peer >= comm->nRanks || peer == comm->rank || channel->peers[peer]->send[connIndex].connected) continue;
comm->connectSend[peer] |= mask; comm->connectSend[peer] |= mask;
} }
return ncclSuccess; return ncclSuccess;
@ -73,7 +73,7 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
struct ncclConnect** recvData = (ncclConnect**) malloc(sizeof(ncclConnect*) * comm->nRanks); // Points to entries inside data for given recv connection within a channel struct ncclConnect** recvData = (ncclConnect**) malloc(sizeof(ncclConnect*) * comm->nRanks); // Points to entries inside data for given recv connection within a channel
struct ncclConnect** sendData = (ncclConnect**) malloc(sizeof(ncclConnect*) * comm->nRanks); // Points to entries inside data for given send connection within a channel struct ncclConnect** sendData = (ncclConnect**) malloc(sizeof(ncclConnect*) * comm->nRanks); // Points to entries inside data for given send connection within a channel
NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->hostStream), ret, fail); NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->sharedRes->hostStream), ret, fail);
// First time initialization // First time initialization
for (int i=1; i<comm->nRanks; i++) { for (int i=1; i<comm->nRanks; i++) {
int bootstrapTag = (i<<8) + (graph ? graph->id+1 : 0); int bootstrapTag = (i<<8) + (graph ? graph->id+1 : 0);
@ -142,13 +142,16 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
for (int c=0; c<MAXCHANNELS; c++) { for (int c=0; c<MAXCHANNELS; c++) {
TIME_START(3); TIME_START(3);
if (sendMask & (1UL<<c)) { if (sendMask & (1UL<<c)) {
struct ncclConnector* conn = comm->channels[c].peers[sendPeer].send + connIndex; struct ncclConnector* conn = comm->channels[c].peers[sendPeer]->send + connIndex;
// This connector hasn't completed connection yet // This connector hasn't completed connection yet
if (conn->connected == 0) { if (conn->connected == 0) {
NCCLCHECKGOTO(conn->transportComm->connect(comm, sendData[i] + sendDataOffset++, 1, comm->rank, conn), ret, fail); NCCLCHECKGOTO(conn->transportComm->connect(comm, sendData[i] + sendDataOffset++, 1, comm->rank, conn), ret, fail);
if (ret == ncclSuccess) { if (ret == ncclSuccess) {
struct ncclDevChannelPeer* addr;
conn->connected = 1; conn->connected = 1;
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); /* comm->channels[c].devPeers[sendPeer]->send[connIndex] is a device memory access. */
CUDACHECKGOTO(cudaMemcpyAsync(&addr, &comm->channels[c].devPeers[sendPeer], sizeof(struct ncclDevChannelPeer*), cudaMemcpyDeviceToHost, comm->sharedRes->hostStream.cudaStream), ret, fail);
CUDACHECKGOTO(cudaMemcpyAsync(&addr->send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), ret, fail);
} else if (ret == ncclInProgress) { } else if (ret == ncclInProgress) {
allChannelsConnected = false; allChannelsConnected = false;
} }
@ -159,13 +162,16 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
// Start with recv channels // Start with recv channels
TIME_START(4); TIME_START(4);
if (recvMask & (1UL<<c)) { if (recvMask & (1UL<<c)) {
struct ncclConnector* conn = comm->channels[c].peers[recvPeer].recv + connIndex; struct ncclConnector* conn = comm->channels[c].peers[recvPeer]->recv + connIndex;
// This connector hasn't completed connection yet // This connector hasn't completed connection yet
if (conn->connected == 0) { if (conn->connected == 0) {
NCCLCHECKGOTO(conn->transportComm->connect(comm, recvData[i] + recvDataOffset++, 1, comm->rank, conn), ret, fail); NCCLCHECKGOTO(conn->transportComm->connect(comm, recvData[i] + recvDataOffset++, 1, comm->rank, conn), ret, fail);
if (ret == ncclSuccess) { if (ret == ncclSuccess) {
struct ncclDevChannelPeer* addr;
conn->connected = 1; conn->connected = 1;
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); /* comm->channels[c].devPeers[recvPeer]->recv[connIndex] is a device memory access. */
CUDACHECKGOTO(cudaMemcpyAsync(&addr, &comm->channels[c].devPeers[recvPeer], sizeof(struct ncclDevChannelPeer*), cudaMemcpyDeviceToHost, comm->sharedRes->hostStream.cudaStream), ret, fail);
CUDACHECKGOTO(cudaMemcpyAsync(&addr->recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), ret, fail);
} else if (ret == ncclInProgress) { } else if (ret == ncclInProgress) {
allChannelsConnected = false; allChannelsConnected = false;
} }
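channel->devPeers now holds pointers that are themselves stored in device memory, which is why both the send and recv paths above first read the device-side pointer back to the host and only then copy the connection info through it. Stripped of the surrounding state machine, the pattern is roughly the following sketch (stream stands for comm->sharedRes->hostStream.cudaStream):

/* Sketch of the two-step update used above: devPeers[peer] is stored in device
 * memory, so it cannot be dereferenced directly from the host. */
struct ncclDevChannelPeer* addr;

/* 1) Fetch the device-resident pointer value into host memory. */
CUDACHECK(cudaMemcpyAsync(&addr, &comm->channels[c].devPeers[peer],
                          sizeof(struct ncclDevChannelPeer*),
                          cudaMemcpyDeviceToHost, stream));

/* 2) addr is now a plain device address; copy the host-side ncclConnInfo into it.
 *    (&addr->send[connIndex] is only pointer arithmetic, no host dereference.) */
CUDACHECK(cudaMemcpyAsync(&addr->send[connIndex], &conn->conn,
                          sizeof(struct ncclConnInfo),
                          cudaMemcpyHostToDevice, stream));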
@ -191,8 +197,8 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
if (highestTransportType != NULL) *highestTransportType = highestType; if (highestTransportType != NULL) *highestTransportType = highestType;
TIME_PRINT("P2P Setup/Connect"); TIME_PRINT("P2P Setup/Connect");
exit: exit:
NCCLCHECK(ncclStrongStreamWaitStream(ncclCudaGraphNone(), &comm->deviceStream, &comm->hostStream)); NCCLCHECK(ncclStrongStreamWaitStream(ncclCudaGraphNone(), &comm->sharedRes->deviceStream, &comm->sharedRes->hostStream));
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->hostStream)); NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->sharedRes->hostStream));
return ret; return ret;
fail: fail:
goto exit; goto exit;
@ -226,7 +232,7 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
} }
// select // select
struct ncclChannelPeer* root = channel->peers+nranks; struct ncclChannelPeer* root = channel->peers[nranks];
// connector index: 0 for recv, 1 for send // connector index: 0 for recv, 1 for send
struct ncclConnector* conn = (type == collNetRecv) ? root->recv+type : root->send+type; struct ncclConnector* conn = (type == collNetRecv) ? root->recv+type : root->send+type;
struct ncclTransportComm* transportComm = (type == collNetRecv) ? &(collNetTransport.recv) : &(collNetTransport.send); struct ncclTransportComm* transportComm = (type == collNetRecv) ? &(collNetTransport.recv) : &(collNetTransport.send);
@ -265,8 +271,9 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
// connect // connect
if (isMaster) { if (isMaster) {
NCCLCHECKGOTO(transportComm->connect(comm, masterConnects, nMasters, rankInCollNet, conn), res, cleanup); NCCLCHECKGOTO(transportComm->connect(comm, masterConnects, nMasters, rankInCollNet, conn), res, cleanup);
struct ncclDevChannelPeer* devRoot = channel->devPeers+nranks; struct ncclDevChannelPeer* devRoot;
struct ncclConnInfo* devConnInfo = (type == collNetRecv) ? devRoot->recv+type : devRoot->send+type; CUDACHECKGOTO(cudaMemcpy(&devRoot, channel->devPeers + nranks, sizeof(struct ncclDevChannelPeer*), cudaMemcpyDeviceToHost), res, cleanup);
struct ncclConnInfo* devConnInfo = (type == collNetRecv) ? devRoot->recv + type : devRoot->send + type;
CUDACHECKGOTO(cudaMemcpy(devConnInfo, &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice), res, cleanup); CUDACHECKGOTO(cudaMemcpy(devConnInfo, &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice), res, cleanup);
} }
// recv side sends connect info to send side // recv side sends connect info to send side
@ -305,16 +312,20 @@ ncclResult_t ncclTransportCollNetFree(struct ncclComm* comm) {
// Free collNet resources // Free collNet resources
for (int r=0; r<comm->nChannels; r++) { for (int r=0; r<comm->nChannels; r++) {
struct ncclChannel* channel = comm->channels+r; struct ncclChannel* channel = comm->channels+r;
struct ncclChannelPeer* peer = channel->peers+comm->nRanks; struct ncclChannelPeer* peer = channel->peers[comm->nRanks];
for (int b=0; b<NCCL_MAX_CONNS; b++) { if (peer) {
struct ncclConnector* send = peer->send + b; if (ncclAtomicRefCountDecrement(&peer->refCount) == 0) {
if (send->transportResources && send->transportComm) NCCLCHECK(send->transportComm->free(send)); for (int b=0; b<NCCL_MAX_CONNS; b++) {
send->transportResources = NULL; // avoid double free struct ncclConnector* send = peer->send + b;
} if (send->transportResources && send->transportComm) NCCLCHECK(send->transportComm->free(send));
for (int b=0; b<NCCL_MAX_CONNS; b++) { send->transportResources = NULL; // avoid double free
struct ncclConnector* recv = peer->recv + b; }
if (recv->transportResources && recv->transportComm) NCCLCHECK(recv->transportComm->free(recv)); for (int b=0; b<NCCL_MAX_CONNS; b++) {
recv->transportResources = NULL; // avoid double free struct ncclConnector* recv = peer->recv + b;
if (recv->transportResources && recv->transportComm) NCCLCHECK(recv->transportComm->free(recv));
recv->transportResources = NULL; // avoid double free
}
}
} }
} }
return ncclSuccess; return ncclSuccess;

View File

@ -141,6 +141,7 @@ struct setupReq {
int netDev; int netDev;
int useGdr; int useGdr;
int needFlush; int needFlush;
struct ncclCollNetSharedRes* collNet;
}; };
@ -149,16 +150,19 @@ struct setupReq {
static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId, int connIndex) { static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId, int connIndex) {
struct setupReq req; struct setupReq req;
int proxyRank; int proxyRank, tpProxyRank;
NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &req.netDev, &proxyRank)); NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &req.netDev, &proxyRank));
NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 1, &req.useGdr)); NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 1, &req.useGdr));
send->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0; send->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0;
// Determine whether we need to flush the GDR buffer on recv or not // Determine whether we need to flush the GDR buffer on recv or not
if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm->topo, myInfo->busId, &req.needFlush)); if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm->topo, myInfo->busId, &req.needFlush));
NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &send->proxyConn.localRank)); NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &send->proxyConn.tpLocalRank));
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_COLLNET, 1, myInfo->rank, &send->proxyConn)); tpProxyRank = comm->topParentRanks[myInfo->rank];
NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0)); NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_COLLNET, 1, tpProxyRank, &send->proxyConn));
ncclAtomicRefCountIncrement(&comm->collNetSharedRes->refCount);
req.collNet = comm->collNetSharedRes;
NCCLCHECK(ncclProxyCallBlocking(comm, &send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0));
INFO(NCCL_INIT|NCCL_NET,"CollNet %02d/%1d : %d [send] via COLLNET/%s/%d%s", channelId, connIndex, myInfo->rank, collNetName(comm), req.netDev, INFO(NCCL_INIT|NCCL_NET,"CollNet %02d/%1d : %d [send] via COLLNET/%s/%d%s", channelId, connIndex, myInfo->rank, collNetName(comm), req.netDev,
req.useGdr ? "/GDRDMA" : ""); req.useGdr ? "/GDRDMA" : "");
@ -168,15 +172,18 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph
static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId, int connIndex) { static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId, int connIndex) {
struct setupReq req; struct setupReq req;
int proxyRank; int proxyRank, tpProxyRank;
NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &req.netDev, &proxyRank)); NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &req.netDev, &proxyRank));
NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 0, &req.useGdr)); NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 0, &req.useGdr));
recv->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0; recv->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0;
NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &recv->proxyConn.localRank)); NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &recv->proxyConn.tpLocalRank));
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_COLLNET, 0, myInfo->rank, &recv->proxyConn)); tpProxyRank = comm->topParentRanks[myInfo->rank];
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_COLLNET, 0, tpProxyRank, &recv->proxyConn));
struct collNetRecvConnectInfo* info = (struct collNetRecvConnectInfo*) connectInfo; struct collNetRecvConnectInfo* info = (struct collNetRecvConnectInfo*) connectInfo;
NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), &info->collNetHandle, sizeof(collNetHandle_t))); ncclAtomicRefCountIncrement(&comm->collNetSharedRes->refCount);
req.collNet = comm->collNetSharedRes;
NCCLCHECK(ncclProxyCallBlocking(comm, &recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), &info->collNetHandle, sizeof(collNetHandle_t)));
INFO(NCCL_INIT|NCCL_NET,"CollNet %02d/%1d : %d [receive] via COLLNET/%s/%d%s", channelId, connIndex, myInfo->rank, collNetName(comm), req.netDev, INFO(NCCL_INIT|NCCL_NET,"CollNet %02d/%1d : %d [receive] via COLLNET/%s/%d%s", channelId, connIndex, myInfo->rank, collNetName(comm), req.netDev,
req.useGdr ? "/GDRDMA" : ""); req.useGdr ? "/GDRDMA" : "");
@ -221,7 +228,7 @@ static ncclResult_t sendConnect(struct ncclComm* comm, struct ncclConnect* conne
// We're on the same process as the proxy. We can pass a pointer to a struct. // We're on the same process as the proxy. We can pass a pointer to a struct.
struct collNetConnectArgs args = { rank, nranks, connectInfos }; struct collNetConnectArgs args = { rank, nranks, connectInfos };
struct connectMap* map; struct connectMap* map;
NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*))); NCCLCHECK(ncclProxyCallBlocking(comm, &send->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*)));
// If collnet connect failed, propagate error to fallback on regular p2p // If collnet connect failed, propagate error to fallback on regular p2p
if (map == NULL) return ncclSystemError; if (map == NULL) return ncclSystemError;
@ -247,7 +254,7 @@ static ncclResult_t recvConnect(struct ncclComm* comm, struct ncclConnect* conne
// We're on the same process as the proxy. We can pass a pointer to a struct. // We're on the same process as the proxy. We can pass a pointer to a struct.
struct collNetConnectArgs args = { rank, nranks, connectInfos }; struct collNetConnectArgs args = { rank, nranks, connectInfos };
struct connectMap* map; struct connectMap* map;
NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*))); NCCLCHECK(ncclProxyCallBlocking(comm, &recv->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*)));
// If collnet connect failed, propagate error to fallback on regular p2p // If collnet connect failed, propagate error to fallback on regular p2p
if (map == NULL) return ncclSystemError; if (map == NULL) return ncclSystemError;
@ -276,7 +283,7 @@ static ncclResult_t recvFree(struct ncclConnector* recv) {
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t sendProxySetup(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t sendProxySetup(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
struct setupReq* req = (struct setupReq*)reqBuff; struct setupReq* req = (struct setupReq*)reqBuff;
if (reqSize != sizeof(struct setupReq)) return ncclInternalError; if (reqSize != sizeof(struct setupReq)) return ncclInternalError;
@ -288,9 +295,10 @@ static ncclResult_t sendProxySetup(struct ncclProxyConnection* connection, struc
resources->netDev = req->netDev; resources->netDev = req->netDev;
resources->useGdr = req->useGdr; resources->useGdr = req->useGdr;
ncclNetProperties_t props; ncclNetProperties_t props;
NCCLCHECK(collNetGetProperties(comm, req->netDev, &props)); NCCLCHECK(proxyState->ncclCollNet->getProperties(req->netDev, &props));
connection->collNet = req->collNet;
/* DMA-BUF support */ /* DMA-BUF support */
resources->useDmaBuf = resources->useGdr && comm->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF); resources->useDmaBuf = resources->useGdr && proxyState->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF);
return ncclSuccess; return ncclSuccess;
} }
@ -300,19 +308,19 @@ struct sharedResources {
int commRefCount[NCCL_MAX_NETDEVS]; int commRefCount[NCCL_MAX_NETDEVS];
}; };
ncclResult_t sharedListen(struct ncclComm* comm, int netDev, void* collNetHandle) { static ncclResult_t sharedListen(struct ncclProxyState* proxyState, int netDev, struct ncclCollNetSharedRes* collNet, void* collNetHandle) {
struct sharedResources* resources = (struct sharedResources*)comm->proxyState.progressState.collNet.resources; struct sharedResources* resources = (struct sharedResources*)collNet->resources;
if (resources == NULL) { if (resources == NULL) {
NCCLCHECK(ncclCalloc(&resources, 1)); NCCLCHECK(ncclCalloc(&resources, 1));
comm->proxyState.progressState.collNet.resources = resources; collNet->resources = resources;
} }
if (resources->collNetComms[netDev] == NULL) if (resources->collNetComms[netDev] == NULL)
NCCLCHECK(collNetListen(comm, netDev, collNetHandle, resources->collNetListenComms+netDev)); NCCLCHECK(proxyState->ncclCollNet->listen(netDev, collNetHandle, resources->collNetListenComms + netDev));
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t sharedConnect(struct ncclComm* comm, int netDev, struct ncclConnect* connectInfos, int nranks, int rank, void** collNetComm) { static ncclResult_t sharedConnect(struct ncclProxyState* proxyState, int netDev, struct ncclConnect* connectInfos, int nranks, int rank, struct ncclCollNetSharedRes* collNet, void** collNetComm) {
struct sharedResources* resources = (struct sharedResources*)comm->proxyState.progressState.collNet.resources; struct sharedResources* resources = (struct sharedResources*)collNet->resources;
if (resources->collNetComms[netDev] == NULL) { if (resources->collNetComms[netDev] == NULL) {
// Connect to coll comm // Connect to coll comm
collNetHandle_t** handlePtrs = NULL; collNetHandle_t** handlePtrs = NULL;
@ -321,13 +329,13 @@ static ncclResult_t sharedConnect(struct ncclComm* comm, int netDev, struct nccl
struct collNetRecvConnectInfo* info = (struct collNetRecvConnectInfo*)(connectInfos+i); struct collNetRecvConnectInfo* info = (struct collNetRecvConnectInfo*)(connectInfos+i);
handlePtrs[i] = &(info->collNetHandle); handlePtrs[i] = &(info->collNetHandle);
} }
ncclResult_t ret = collNetConnect(comm, (void**)handlePtrs, nranks, rank, ncclResult_t ret = proxyState->ncclCollNet->connect((void**)handlePtrs, nranks, rank,
resources->collNetListenComms[netDev], resources->collNetListenComms[netDev],
resources->collNetComms+netDev); resources->collNetComms+netDev);
free(handlePtrs); free(handlePtrs);
if (ret == ncclSuccess) { if (ret == ncclSuccess) {
// Close listen comm // Close listen comm
NCCLCHECK(collNetCloseListen(comm, resources->collNetListenComms[netDev])); NCCLCHECK(proxyState->ncclCollNet->closeListen(resources->collNetListenComms[netDev]));
} else { } else {
resources->collNetListenComms[netDev] = NULL; resources->collNetListenComms[netDev] = NULL;
} }
@ -337,55 +345,53 @@ static ncclResult_t sharedConnect(struct ncclComm* comm, int netDev, struct nccl
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t sharedFree(struct ncclComm* comm, int netDev) { static ncclResult_t sharedFree(struct ncclProxyState* proxyState, struct ncclCollNetSharedRes* collNet, int netDev) {
struct sharedResources* resources = (struct sharedResources*)comm->proxyState.progressState.collNet.resources; struct sharedResources* resources = (struct sharedResources*)collNet->resources;
resources->commRefCount[netDev]--; resources->commRefCount[netDev]--;
if (resources->commRefCount[netDev] == 0) { if (resources->commRefCount[netDev] == 0) {
NCCLCHECK(collNetCloseColl(comm, resources->collNetComms[netDev])); NCCLCHECK(proxyState->ncclCollNet->closeColl(resources->collNetComms[netDev]));
} }
for (int n=0; n<NCCL_MAX_NETDEVS; n++) if (resources->commRefCount[n]) return ncclSuccess; for (int n=0; n<NCCL_MAX_NETDEVS; n++) if (resources->commRefCount[n]) return ncclSuccess;
comm->proxyState.progressState.collNet.resources = NULL; collNet->resources = NULL;
free(resources); free(resources);
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t sharedBuffersInit(struct ncclComm* comm, int cuda, char** gpuPtr, char** cpuPtr, int* size) { static ncclResult_t sharedBuffersInit(struct ncclCollNetSharedRes* collNet, int cuda, char** gpuPtr, char** cpuPtr, int* size) {
struct ncclProxySharedCollNet* state = &comm->proxyState.progressState.collNet; if (collNet->size == 0) {
if (state->size == 0) { collNet->size = 2 * collNet->nChannels * collNet->buffSize;
state->size = 2*comm->nChannels*comm->buffSizes[NCCL_PROTO_SIMPLE];
} }
*size = state->size; *size = collNet->size;
if (cuda && state->cudaBuff == NULL) { if (cuda && collNet->cudaBuff == NULL) {
NCCLCHECK(ncclCudaCalloc(&state->cudaBuff, *size)); NCCLCHECK(ncclCudaCalloc(&collNet->cudaBuff, *size));
} }
if (!cuda && state->hostBuff == NULL) { if (!cuda && collNet->hostBuff == NULL) {
NCCLCHECK(ncclCudaHostCalloc(&state->hostBuff, *size)); NCCLCHECK(ncclCudaHostCalloc(&collNet->hostBuff, *size));
} }
*gpuPtr = *cpuPtr = cuda ? state->cudaBuff : state->hostBuff; *gpuPtr = *cpuPtr = cuda ? collNet->cudaBuff : collNet->hostBuff;
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t sharedBuffersGet(struct ncclComm* comm, int type, int slot, int channel, int* offset) { static ncclResult_t sharedBuffersGet(struct ncclCollNetSharedRes* collNet, int type, int slot, int channel, int* offset) {
// Use different pools for different channels and also separate send/recv. // Use different pools for different channels and also separate send/recv.
int slotSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS; int slotSize = collNet->buffSize / NCCL_STEPS;
int globalSlot = (type*NCCL_STEPS+slot)*comm->nChannels+channel; int globalSlot = (type * NCCL_STEPS + slot) * collNet->nChannels + channel;
*offset = slotSize * globalSlot; *offset = slotSize * globalSlot;
return ncclSuccess; return ncclSuccess;
} }
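The shared pool is carved into per-step slots, separated by direction (type 0/1) and by channel, which is why sharedBuffersInit sizes it as 2 * nChannels * buffSize. A concrete illustration of the indexing in sharedBuffersGet, using assumed values for the constants:

/* Illustrative numbers only: NCCL_STEPS = 8, buffSize = 4 MB, nChannels = 16.
 *   slotSize   = buffSize / NCCL_STEPS              = 512 KB
 *   globalSlot = (type*NCCL_STEPS + slot)*nChannels + channel
 * For type = 1, slot = 3, channel = 5:
 *   globalSlot = (1*8 + 3)*16 + 5 = 181, offset = 181 * 512 KB
 * Total slots  = 2 * NCCL_STEPS * nChannels, so the pool size is
 * 2 * nChannels * buffSize, matching sharedBuffersInit above. */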
static ncclResult_t sharedBuffersDestroy(struct ncclComm* comm) { static ncclResult_t sharedBuffersDestroy(struct ncclCollNetSharedRes* collNet) {
struct ncclProxySharedCollNet* state = &comm->proxyState.progressState.collNet; if (collNet->size == 0) return ncclSuccess;
if (state->size == 0) return ncclSuccess; NCCLCHECK(ncclCudaFree(collNet->cudaBuff));
CUDACHECK(cudaFree(state->cudaBuff)); NCCLCHECK(ncclCudaHostFree(collNet->hostBuff));
NCCLCHECK(ncclCudaHostFree(state->hostBuff));
// This will be called multiple times, with multiple channels and send/recv. Make sure we only do it once. // This will be called multiple times, with multiple channels and send/recv. Make sure we only do it once.
state->size = 0; collNet->size = 0;
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
struct setupReq* req = (struct setupReq*)reqBuff; struct setupReq* req = (struct setupReq*)reqBuff;
if (reqSize != sizeof (struct setupReq)) return ncclInternalError; if (reqSize != sizeof (struct setupReq)) return ncclInternalError;
@ -398,18 +404,19 @@ static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struc
resources->useGdr = req->useGdr; resources->useGdr = req->useGdr;
resources->needFlush = req->needFlush; resources->needFlush = req->needFlush;
ncclNetProperties_t props; ncclNetProperties_t props;
NCCLCHECK(collNetGetProperties(comm, req->netDev, &props)); NCCLCHECK(proxyState->ncclCollNet->getProperties(req->netDev, &props));
connection->collNet = req->collNet;
/* DMA-BUF support */ /* DMA-BUF support */
resources->useDmaBuf = resources->useGdr && comm->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF); resources->useDmaBuf = resources->useGdr && proxyState->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF);
collNetHandle_t* netHandle = (collNetHandle_t*) respBuff; collNetHandle_t* netHandle = (collNetHandle_t*) respBuff;
if (respSize != sizeof(collNetHandle_t)) return ncclInternalError; if (respSize != sizeof(collNetHandle_t)) return ncclInternalError;
NCCLCHECK(sharedListen(comm, req->netDev, netHandle)); NCCLCHECK(sharedListen(proxyState, req->netDev, req->collNet, netHandle));
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("sendProxyConnect: reqSize is %d != %ld", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; } if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("sendProxyConnect: reqSize is %d != %ld", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; }
struct collNetConnectArgs* args = (struct collNetConnectArgs*)reqBuff; struct collNetConnectArgs* args = (struct collNetConnectArgs*)reqBuff;
struct collNetSendConnectInfo* info = (struct collNetSendConnectInfo*)(args->connectInfos+args->rank); struct collNetSendConnectInfo* info = (struct collNetSendConnectInfo*)(args->connectInfos+args->rank);
@ -423,7 +430,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) for (int p=0; p<NCCL_NUM_PROTOCOLS; p++)
resources->recvMhandles[p] = info->mhandles[p]; resources->recvMhandles[p] = info->mhandles[p];
NCCLCHECK(sharedConnect(comm, resources->netDev, args->connectInfos, args->nranks, args->rank, &resources->collNetComm)); NCCLCHECK(sharedConnect(proxyState, resources->netDev, args->connectInfos, args->nranks, args->rank, connection->collNet, &resources->collNetComm));
// Collnet connect is allowed to fail. Gracefully handle that case by returning NULL to the caller. // Collnet connect is allowed to fail. Gracefully handle that case by returning NULL to the caller.
if (respSize != sizeof(struct connectMap*)) { WARN("sendProxyConnect: respSize is %d != %ld", respSize, sizeof(void*)); return ncclInternalError; } if (respSize != sizeof(struct connectMap*)) { WARN("sendProxyConnect: respSize is %d != %ld", respSize, sizeof(void*)); return ncclInternalError; }
@ -431,7 +438,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
*((struct connectMap**)respBuff) = NULL; *((struct connectMap**)respBuff) = NULL;
return ncclSuccess; return ncclSuccess;
} }
connection->proxyAppendPtr = comm->proxyState.progressState.collNet.proxyAppend+2*resources->netDev; connection->proxyAppendPtr = connection->collNet->proxyAppend + 2 * resources->netDev;
struct connectMap* map = &resources->map; struct connectMap* map = &resources->map;
@ -459,7 +466,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
// Allocate & Register shared buffers for the Simple protocol // Allocate & Register shared buffers for the Simple protocol
int bank = resources->useGdr ? NCCL_NET_MAP_SHARED_DEVMEM : NCCL_NET_MAP_SHARED_HOSTMEM; int bank = resources->useGdr ? NCCL_NET_MAP_SHARED_DEVMEM : NCCL_NET_MAP_SHARED_HOSTMEM;
struct connectMapMem* mapMem = map->mems+bank; struct connectMapMem* mapMem = map->mems+bank;
NCCLCHECK(sharedBuffersInit(comm, resources->useGdr, &mapMem->gpuPtr, &mapMem->cpuPtr, &mapMem->size)); NCCLCHECK(sharedBuffersInit(connection->collNet, resources->useGdr, &mapMem->gpuPtr, &mapMem->cpuPtr, &mapMem->size));
NCCL_NET_MAP_ADD_POINTER(map, 1, resources->useGdr, mapMem->size, buffs[NCCL_PROTO_SIMPLE]); NCCL_NET_MAP_ADD_POINTER(map, 1, resources->useGdr, mapMem->size, buffs[NCCL_PROTO_SIMPLE]);
#if CUDA_VERSION >= 11070 #if CUDA_VERSION >= 11070
@ -467,23 +474,23 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
if (resources->useGdr && resources->useDmaBuf) { if (resources->useGdr && resources->useDmaBuf) {
int dmabuf_fd; int dmabuf_fd;
CUCHECK(cuMemGetHandleForAddressRange((void *)&dmabuf_fd, (CUdeviceptr)mapMem->cpuPtr, mapMem->size, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0)); CUCHECK(cuMemGetHandleForAddressRange((void *)&dmabuf_fd, (CUdeviceptr)mapMem->cpuPtr, mapMem->size, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0));
NCCLCHECK(collNetRegMrDmaBuf(comm, resources->collNetComm, mapMem->cpuPtr, mapMem->size, NCCLCHECK(proxyState->ncclCollNet->regMrDmaBuf(resources->collNetComm, mapMem->cpuPtr, mapMem->size,
NCCL_PTR_CUDA, 0ULL, dmabuf_fd, NCCL_PTR_CUDA, 0ULL, dmabuf_fd,
&resources->sendMhandles[NCCL_PROTO_SIMPLE])); &resources->sendMhandles[NCCL_PROTO_SIMPLE]));
(void)close(dmabuf_fd); (void)close(dmabuf_fd);
} else // FALL-THROUGH to nv_peermem GDR path } else // FALL-THROUGH to nv_peermem GDR path
#endif #endif
{ {
NCCLCHECK(collNetRegMr(comm, resources->collNetComm, mapMem->cpuPtr, mapMem->size, NCCLCHECK(proxyState->ncclCollNet->regMr(resources->collNetComm, mapMem->cpuPtr, mapMem->size,
resources->useGdr ? NCCL_PTR_CUDA : NCCL_PTR_HOST, resources->useGdr ? NCCL_PTR_CUDA : NCCL_PTR_HOST,
&resources->sendMhandles[NCCL_PROTO_SIMPLE])); &resources->sendMhandles[NCCL_PROTO_SIMPLE]));
} }
*((struct connectMap**)respBuff) = &resources->map; *((struct connectMap**)respBuff) = &resources->map;
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("recvProxyConnect: reqSize is %d != %ld", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; } if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("recvProxyConnect: reqSize is %d != %ld", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; }
struct collNetConnectArgs* args = (struct collNetConnectArgs*)reqBuff; struct collNetConnectArgs* args = (struct collNetConnectArgs*)reqBuff;
@ -491,7 +498,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
struct collNetSendConnectInfo* info = (struct collNetSendConnectInfo*)(args->connectInfos+args->rank); struct collNetSendConnectInfo* info = (struct collNetSendConnectInfo*)(args->connectInfos+args->rank);
resources->collNetRank = args->rank; resources->collNetRank = args->rank;
NCCLCHECK(sharedConnect(comm, resources->netDev, args->connectInfos, args->nranks, args->rank, &resources->collNetComm)); NCCLCHECK(sharedConnect(proxyState, resources->netDev, args->connectInfos, args->nranks, args->rank, connection->collNet, &resources->collNetComm));
// Collnet connect is allowed to fail. Gracefully handle that case by returning NULL to the caller. // Collnet connect is allowed to fail. Gracefully handle that case by returning NULL to the caller.
if (respSize != sizeof(struct connectMap*)) { WARN("recvProxyConnect: respSize is %d != %ld", respSize, sizeof(void*)); return ncclInternalError; } if (respSize != sizeof(struct connectMap*)) { WARN("recvProxyConnect: respSize is %d != %ld", respSize, sizeof(void*)); return ncclInternalError; }
@ -499,7 +506,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
*((struct connectMap**)respBuff) = NULL; *((struct connectMap**)respBuff) = NULL;
return ncclSuccess; return ncclSuccess;
} }
connection->proxyAppendPtr = comm->proxyState.progressState.collNet.proxyAppend+2*resources->netDev+1; connection->proxyAppendPtr = connection->collNet->proxyAppend + 2 * resources->netDev + 1;
struct connectMap* map = &resources->map; struct connectMap* map = &resources->map;
@ -528,7 +535,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
// Allocate & Register shared buffers for the Simple protocol // Allocate & Register shared buffers for the Simple protocol
int bank = resources->useGdr ? NCCL_NET_MAP_SHARED_DEVMEM : NCCL_NET_MAP_SHARED_HOSTMEM; int bank = resources->useGdr ? NCCL_NET_MAP_SHARED_DEVMEM : NCCL_NET_MAP_SHARED_HOSTMEM;
struct connectMapMem* mapMem = map->mems+bank; struct connectMapMem* mapMem = map->mems+bank;
NCCLCHECK(sharedBuffersInit(comm, resources->useGdr, &mapMem->gpuPtr, &mapMem->cpuPtr, &mapMem->size)); NCCLCHECK(sharedBuffersInit(connection->collNet, resources->useGdr, &mapMem->gpuPtr, &mapMem->cpuPtr, &mapMem->size));
NCCL_NET_MAP_ADD_POINTER(map, 1, resources->useGdr, mapMem->size, buffs[NCCL_PROTO_SIMPLE]); NCCL_NET_MAP_ADD_POINTER(map, 1, resources->useGdr, mapMem->size, buffs[NCCL_PROTO_SIMPLE]);
#if CUDA_VERSION >= 11070 #if CUDA_VERSION >= 11070
@ -536,16 +543,16 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
if (resources->useGdr && resources->useDmaBuf) { if (resources->useGdr && resources->useDmaBuf) {
int dmabuf_fd; int dmabuf_fd;
CUCHECK(cuMemGetHandleForAddressRange((void *)&dmabuf_fd, (CUdeviceptr)mapMem->cpuPtr, mapMem->size, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0)); CUCHECK(cuMemGetHandleForAddressRange((void *)&dmabuf_fd, (CUdeviceptr)mapMem->cpuPtr, mapMem->size, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0));
NCCLCHECK(collNetRegMrDmaBuf(comm, resources->collNetComm, mapMem->cpuPtr, mapMem->size, NCCLCHECK(proxyState->ncclCollNet->regMrDmaBuf(resources->collNetComm, mapMem->cpuPtr, mapMem->size,
NCCL_PTR_CUDA, 0ULL, dmabuf_fd, NCCL_PTR_CUDA, 0ULL, dmabuf_fd,
&resources->mhandles[NCCL_PROTO_SIMPLE])); &resources->mhandles[NCCL_PROTO_SIMPLE]));
(void)close(dmabuf_fd); (void)close(dmabuf_fd);
} else // FALL-THROUGH to nv_peermem GDR path } else // FALL-THROUGH to nv_peermem GDR path
#endif #endif
{ {
NCCLCHECK(collNetRegMr(comm, resources->collNetComm, mapMem->cpuPtr, mapMem->size, NCCLCHECK(proxyState->ncclCollNet->regMr(resources->collNetComm, mapMem->cpuPtr, mapMem->size,
resources->useGdr ? NCCL_PTR_CUDA : NCCL_PTR_HOST, resources->useGdr ? NCCL_PTR_CUDA : NCCL_PTR_HOST,
&resources->mhandles[NCCL_PROTO_SIMPLE])); &resources->mhandles[NCCL_PROTO_SIMPLE]));
} }
// Pass info to send side // Pass info to send side
@ -558,41 +565,43 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t sendProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { static ncclResult_t sendProxyFree(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState) {
struct sendResources* resources = (struct sendResources*)(connection->transportResources); struct sendResources* resources = (struct sendResources*)(connection->transportResources);
if (resources) { if (resources) {
for (int p = 0; p < NCCL_NUM_PROTOCOLS; p++) { for (int p = 0; p < NCCL_NUM_PROTOCOLS; p++) {
if (resources->sendMhandles[p]) { if (resources->sendMhandles[p]) {
NCCLCHECK(collNetDeregMr(comm, resources->collNetComm, resources->sendMhandles[p])); NCCLCHECK(proxyState->ncclCollNet->deregMr(resources->collNetComm, resources->sendMhandles[p]));
} }
} }
struct connectMapMem* mems = resources->map.mems; struct connectMapMem* mems = resources->map.mems;
NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr)); NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr));
CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); NCCLCHECK(ncclCudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr));
if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc));
NCCLCHECK(sharedBuffersDestroy(comm)); NCCLCHECK(sharedBuffersDestroy(connection->collNet));
NCCLCHECK(sharedFree(comm, resources->netDev)); NCCLCHECK(sharedFree(proxyState, connection->collNet, resources->netDev));
if (ncclAtomicRefCountDecrement(&connection->collNet->refCount) == 0) free(connection->collNet);
free(connection->transportResources); free(connection->transportResources);
} }
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState) {
struct recvResources* resources = (struct recvResources*)(connection->transportResources); struct recvResources* resources = (struct recvResources*)(connection->transportResources);
if (resources) { if (resources) {
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
if (resources->mhandles[p]) { if (resources->mhandles[p]) {
NCCLCHECK(collNetDeregMr(comm, resources->collNetComm, resources->mhandles[p])); NCCLCHECK(proxyState->ncclCollNet->deregMr(resources->collNetComm, resources->mhandles[p]));
} }
} }
struct connectMapMem* mems = resources->map.mems; struct connectMapMem* mems = resources->map.mems;
NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr)); NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr));
CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); NCCLCHECK(ncclCudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr));
if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc));
NCCLCHECK(sharedBuffersDestroy(comm)); NCCLCHECK(sharedBuffersDestroy(connection->collNet));
NCCLCHECK(sharedFree(comm, resources->netDev)); NCCLCHECK(sharedFree(proxyState, connection->collNet, resources->netDev));
if (ncclAtomicRefCountDecrement(&connection->collNet->refCount) == 0) free(connection->collNet);
free(connection->transportResources); free(connection->transportResources);
} }
return ncclSuccess; return ncclSuccess;
@ -602,7 +611,7 @@ static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct
#define LAST_OF_GROUP(s) \ #define LAST_OF_GROUP(s) \
(s % COLLNET_GROUP_NSUBS == COLLNET_GROUP_NSUBS-1 || s == args->nsubs-1) (s % COLLNET_GROUP_NSUBS == COLLNET_GROUP_NSUBS-1 || s == args->nsubs-1)
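LAST_OF_GROUP marks the sub-operation that closes a group of COLLNET_GROUP_NSUBS subs, or the final sub overall, which is where the proxy issues the aggregated network operation. A quick illustration with an assumed group size:

/* Illustration only, assuming COLLNET_GROUP_NSUBS == 4 and args->nsubs == 10:
 *   s = 0..3  -> group 0, LAST_OF_GROUP(s) true at s = 3
 *   s = 4..7  -> group 1, LAST_OF_GROUP(s) true at s = 7
 *   s = 8..9  -> group 2, LAST_OF_GROUP(s) true at s = 9 (nsubs-1) */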
static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArgs* args) { static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args) {
if (args->state == ncclProxyOpReady) { if (args->state == ncclProxyOpReady) {
for (int s=0; s<args->nsubs; s++) { for (int s=0; s<args->nsubs; s++) {
struct ncclProxySubArgs* sub = args->subs+s; struct ncclProxySubArgs* sub = args->subs+s;
@ -629,7 +638,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int buffSlot = (sub->base+sub->posted)%NCCL_STEPS; int buffSlot = (sub->base+sub->posted)%NCCL_STEPS;
int sharedBuffSlot = sub->posted%NCCL_STEPS; int sharedBuffSlot = sub->posted%NCCL_STEPS;
int offset; int offset;
NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset)); NCCLCHECK(sharedBuffersGet(sub->connection->collNet, 0, sharedBuffSlot, 0, &offset));
resources->recvMem->offsFifo[buffSlot] = offset + s*args->chunkSize; resources->recvMem->offsFifo[buffSlot] = offset + s*args->chunkSize;
__sync_synchronize(); __sync_synchronize();
volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head; volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head;
@ -650,7 +659,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int ready = 1; int ready = 1;
if (s == 0) { if (s == 0) {
int offset; int offset;
NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset)); NCCLCHECK(sharedBuffersGet(sub->connection->collNet, 0, sharedBuffSlot, 0, &offset));
args->sharedBuff[sharedBuffSlot] = localBuff + offset; args->sharedBuff[sharedBuffSlot] = localBuff + offset;
args->sharedSize[sharedBuffSlot] = args->chunkSize; args->sharedSize[sharedBuffSlot] = args->chunkSize;
} }
@ -671,7 +680,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int count = totalSize / ncclTypeSize((ncclDataType_t)args->dtype); int count = totalSize / ncclTypeSize((ncclDataType_t)args->dtype);
reqFifo[group][buffSlot].size = args->sharedSize[sharedBuffSlot]; reqFifo[group][buffSlot].size = args->sharedSize[sharedBuffSlot];
char* sendAddress = (char*)args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*args->sharedSize[sharedBuffSlot]; char* sendAddress = (char*)args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*args->sharedSize[sharedBuffSlot];
NCCLCHECK(collNetIallreduce(comm, resources->collNetComm, sendAddress, (void*)(reqFifo[group][buffSlot].recvBuff), count, (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, sendMhandle, recvMhandle, sub->requests+buffSlot)); NCCLCHECK(proxyState->ncclCollNet->iallreduce(resources->collNetComm, sendAddress, (void*)(reqFifo[group][buffSlot].recvBuff), count, (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, sendMhandle, recvMhandle, sub->requests+buffSlot));
if (sub->requests[buffSlot] == NULL) continue; if (sub->requests[buffSlot] == NULL) continue;
TRACE(NCCL_NET, "sendProxy [%d/%d/%d] Iallreduce posted, size %d req %p", sub->transmitted, group, buffSlot, totalSize, sub->requests[buffSlot]); TRACE(NCCL_NET, "sendProxy [%d/%d/%d] Iallreduce posted, size %d req %p", sub->transmitted, group, buffSlot, totalSize, sub->requests[buffSlot]);
@ -687,7 +696,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int done, size; int done, size;
int group = s / COLLNET_GROUP_NSUBS; int group = s / COLLNET_GROUP_NSUBS;
int buffSlot = (sub->base+sub->done)%NCCL_STEPS; int buffSlot = (sub->base+sub->done)%NCCL_STEPS;
NCCLCHECK(collNetTest(comm, (void*)(sub->requests[buffSlot]), &done, &size)); NCCLCHECK(proxyState->ncclCollNet->test((void*)(sub->requests[buffSlot]), &done, &size));
if (done) { if (done) {
TRACE(NCCL_NET, "sendProxy [%d/%d/%d] request %p done, size %d", sub->done, group, buffSlot, sub->requests[buffSlot], size); TRACE(NCCL_NET, "sendProxy [%d/%d/%d] request %p done, size %d", sub->done, group, buffSlot, sub->requests[buffSlot], size);
// Make sure size is updated before we set recvBuff to NULL (from the view of recv proxy, concerning the flush) // Make sure size is updated before we set recvBuff to NULL (from the view of recv proxy, concerning the flush)
@ -711,7 +720,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArgs* args) { static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args) {
if (args->state == ncclProxyOpReady) { if (args->state == ncclProxyOpReady) {
for (int s=0; s<args->nsubs; s++) { for (int s=0; s<args->nsubs; s++) {
struct ncclProxySubArgs* sub = args->subs+s; struct ncclProxySubArgs* sub = args->subs+s;
@ -742,7 +751,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int sharedBuffSlot = sub->posted%NCCL_STEPS; int sharedBuffSlot = sub->posted%NCCL_STEPS;
int startChannel = group*COLLNET_GROUP_NSUBS; int startChannel = group*COLLNET_GROUP_NSUBS;
int offset; int offset;
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset)); NCCLCHECK(sharedBuffersGet(sub->connection->collNet, 1, sharedBuffSlot, startChannel, &offset));
reqFifo[group][buffSlot].recvBuff = localBuff + offset; reqFifo[group][buffSlot].recvBuff = localBuff + offset;
TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff); TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff);
sub->posted += args->sliceSteps; sub->posted += args->sliceSteps;
@ -773,8 +782,8 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
} else { } else {
int startChannel = group*COLLNET_GROUP_NSUBS; int startChannel = group*COLLNET_GROUP_NSUBS;
int offset; int offset;
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset)); NCCLCHECK(sharedBuffersGet(sub->connection->collNet, 1, sharedBuffSlot, startChannel, &offset));
NCCLCHECK(collNetIflush(comm, resources->collNetComm, localBuff + offset, totalSize, mhandle, sub->requests+buffSlot)); NCCLCHECK(proxyState->ncclCollNet->iflush(resources->collNetComm, localBuff + offset, totalSize, mhandle, sub->requests+buffSlot));
} }
} else { } else {
for (int i=group*COLLNET_GROUP_NSUBS; i<=s; i++) args->subs[i].flushed += args->sliceSteps; for (int i=group*COLLNET_GROUP_NSUBS; i<=s; i++) args->subs[i].flushed += args->sliceSteps;
@ -788,7 +797,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int group = s / COLLNET_GROUP_NSUBS; int group = s / COLLNET_GROUP_NSUBS;
int buffSlot = (sub->base + sub->flushed)%NCCL_STEPS; int buffSlot = (sub->base + sub->flushed)%NCCL_STEPS;
int done = 1; int done = 1;
if (sub->requests[buffSlot]) NCCLCHECK(collNetTest(comm, sub->requests[buffSlot], &done, NULL)); if (sub->requests[buffSlot]) NCCLCHECK(proxyState->ncclCollNet->test(sub->requests[buffSlot], &done, NULL));
if (done) { if (done) {
TRACE(NCCL_NET, "recvProxy [%d/%d/%d] flushed", sub->flushed, group, buffSlot); TRACE(NCCL_NET, "recvProxy [%d/%d/%d] flushed", sub->flushed, group, buffSlot);
for (int i=group*COLLNET_GROUP_NSUBS; i<=s; i++) args->subs[i].flushed += args->sliceSteps; for (int i=group*COLLNET_GROUP_NSUBS; i<=s; i++) args->subs[i].flushed += args->sliceSteps;
@ -802,7 +811,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int sharedBuffSlot = sub->transmitted%NCCL_STEPS; int sharedBuffSlot = sub->transmitted%NCCL_STEPS;
int startChannel = group*COLLNET_GROUP_NSUBS; int startChannel = group*COLLNET_GROUP_NSUBS;
int offset; int offset;
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset)); NCCLCHECK(sharedBuffersGet(sub->connection->collNet, 1, sharedBuffSlot, startChannel, &offset));
volatile int* offsFifo = (volatile int*)resources->recvMem->offsFifo; volatile int* offsFifo = (volatile int*)resources->recvMem->offsFifo;
offsFifo[buffSlot] = offset + (s%COLLNET_GROUP_NSUBS)*args->chunkSize; offsFifo[buffSlot] = offset + (s%COLLNET_GROUP_NSUBS)*args->chunkSize;
__sync_synchronize(); __sync_synchronize();
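Throughout the two progress loops above, the proxy publishes work in a fixed order: write the chunk offset into offsFifo[buffSlot], issue __sync_synchronize(), and only then advance the head counter the other side polls. A stripped-down sketch of that publish order, with invented names (STEPS, publish) rather than NCCL's:

// Publish-then-advance: the consumer only trusts slots below 'head',
// so the full barrier keeps it from seeing a new head with a stale offset.
#define STEPS 8

struct OffsetFifo {
  volatile int offs[STEPS];
  volatile unsigned long long head;
};

static void publish(OffsetFifo* f, unsigned long long step, int offset) {
  f->offs[step % STEPS] = offset;   // 1. payload
  __sync_synchronize();             // 2. full barrier, as in the code above
  f->head = step + 1;               // 3. make the step visible
}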
View File
@ -11,6 +11,7 @@
#include "collectives.h" #include "collectives.h"
#include "gdrwrap.h" #include "gdrwrap.h"
#include "shm.h" #include "shm.h"
#include "p2p.h"
#include "profiler.h" #include "profiler.h"
static_assert(sizeof(ncclNetHandle_t) <= CONNECT_SIZE, "NET Connect info is too large"); static_assert(sizeof(ncclNetHandle_t) <= CONNECT_SIZE, "NET Connect info is too large");
@ -59,10 +60,8 @@ struct connectMapMem{
char* gpuPtr; char* gpuPtr;
char* cpuPtr; char* cpuPtr;
int size; int size;
union { ncclIpcDesc ipcDesc;
char shmPath[PATH_MAX]; char shmPath[PATH_MAX];
cudaIpcMemHandle_t ipc;
};
ncclShmHandle_t attachHandle; ncclShmHandle_t attachHandle;
ncclShmHandle_t createHandle; ncclShmHandle_t createHandle;
}; };
@ -87,9 +86,9 @@ struct sendResources {
struct ncclSendMem* sendMem; struct ncclSendMem* sendMem;
struct ncclRecvMem* recvMem; struct ncclRecvMem* recvMem;
int rank; int tpRank;
int localRank; int tpLocalRank;
int remoteRank; int tpRemoteRank;
int netDev; int netDev;
int useGdr; int useGdr;
int useDmaBuf; int useDmaBuf;
@ -113,10 +112,10 @@ struct recvResources {
struct ncclSendMem* sendMem; struct ncclSendMem* sendMem;
struct ncclRecvMem* recvMem; struct ncclRecvMem* recvMem;
int rank; int tpRank;
int localRank; int tpLocalRank;
int remoteRank; int tpRemoteRank;
int proxyRank; int tpRemoteProxyRank;
int netDev; int netDev;
int useGdr; int useGdr;
int useDmaBuf; int useDmaBuf;
@ -149,9 +148,9 @@ NCCL_PARAM(NetSharedBuffers, "NET_SHARED_BUFFERS", -2);
NCCL_PARAM(NetSharedComms, "NET_SHARED_COMMS", 1); NCCL_PARAM(NetSharedComms, "NET_SHARED_COMMS", 1);
struct setupReq { struct setupReq {
int rank; int tpRank;
int localRank; int tpLocalRank;
int remoteRank; int tpRemoteRank;
int shared; int shared;
int netDev; int netDev;
int useGdr; int useGdr;
@ -164,6 +163,7 @@ struct setupReq {
* information for this peer */ * information for this peer */
static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId, int connIndex) { static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId, int connIndex) {
struct setupReq req; struct setupReq req;
int localRank, tpProxyRank;
send->conn.shared = req.shared = graph ? 0 : ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : 1; send->conn.shared = req.shared = graph ? 0 : ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : 1;
req.channelId = channelId; req.channelId = channelId;
@ -174,20 +174,22 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph
NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 1, &req.useGdr)); NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 1, &req.useGdr));
send->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0; send->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0;
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_NET, 1, proxyRank, &send->proxyConn)); tpProxyRank = comm->topParentRanks[proxyRank];
req.rank = myInfo->rank; NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_NET, 1, tpProxyRank, &send->proxyConn));
NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &req.localRank)); NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &localRank));
req.remoteRank = peerInfo->rank; req.tpLocalRank = comm->topParentLocalRanks[localRank];
NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0)); req.tpRank = comm->topParentRanks[myInfo->rank];
req.tpRemoteRank = comm->topParentRanks[peerInfo->rank];
NCCLCHECK(ncclProxyCallBlocking(comm, &send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0));
if (proxyRank == myInfo->rank) { if (proxyRank == myInfo->rank) {
INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [send] via NET/%s/%d%s%s", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, ncclNetName(comm), req.netDev, INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [send] via NET/%s/%d%s%s", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, comm->ncclNet->name, req.netDev,
req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : ""); req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : "");
} else { } else {
INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [send] via NET/%s/%d(%d)%s%s", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, ncclNetName(comm), req.netDev, INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [send] via NET/%s/%d(%d)%s%s", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, comm->ncclNet->name, req.netDev,
proxyRank, req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : ""); proxyRank, req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : "");
} }
*((int*)connectInfo) = proxyRank; *((int*)connectInfo) = tpProxyRank;
return ncclSuccess; return ncclSuccess;
} }
@ -199,13 +201,14 @@ NCCL_PARAM(GdrCopyFlushEnable, "GDRCOPY_FLUSH_ENABLE", 0);
/* Setup recv connector */ /* Setup recv connector */
static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId, int connIndex) { static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId, int connIndex) {
struct setupReq req; struct setupReq req;
int localRank;
recv->conn.shared = req.shared = graph ? 0 : ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : 1; recv->conn.shared = req.shared = graph ? 0 : ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : 1;
req.channelId = channelId; req.channelId = channelId;
req.connIndex = connIndex; req.connIndex = connIndex;
// Use myInfo->rank as the receiver uses its own NIC // Use myInfo->rank as the receiver uses its own NIC
int proxyRank; int proxyRank, tpProxyRank;
NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, myInfo->rank, &req.netDev, &proxyRank)); NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, myInfo->rank, &req.netDev, &proxyRank));
NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 0, &req.useGdr)); NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 0, &req.useGdr));
@ -213,13 +216,15 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph
if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm->topo, myInfo->busId, &req.needFlush)); if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm->topo, myInfo->busId, &req.needFlush));
// We don't support PXN on receive yet // We don't support PXN on receive yet
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_NET, 0, myInfo->rank, &recv->proxyConn)); tpProxyRank = comm->topParentRanks[myInfo->rank];
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_NET, 0, tpProxyRank, &recv->proxyConn));
req.rank = myInfo->rank; NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &localRank));
NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &req.localRank)); req.tpLocalRank = comm->topParentLocalRanks[localRank];
req.remoteRank = peerInfo->rank; req.tpRank = comm->topParentRanks[myInfo->rank];
NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), connectInfo, sizeof(ncclNetHandle_t))); req.tpRemoteRank = comm->topParentRanks[peerInfo->rank];
INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [receive] via NET/%s/%d%s%s", channelId, connIndex, peerInfo->rank, peerInfo->busId, myInfo->rank, myInfo->busId, ncclNetName(comm), req.netDev, NCCLCHECK(ncclProxyCallBlocking(comm, &recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), connectInfo, sizeof(ncclNetHandle_t)));
INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [receive] via NET/%s/%d%s%s", channelId, connIndex, peerInfo->rank, peerInfo->busId, myInfo->rank, myInfo->busId, comm->ncclNet->name, req.netDev,
req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : ""); req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : "");
return ncclSuccess; return ncclSuccess;
} }
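recvSetup (and sendSetup above it) now translate communicator ranks into "tp" (top parent) ranks via comm->topParentRanks / topParentLocalRanks before talking to the proxy, so a communicator created by ncclCommSplit can share its parent's proxy resources. A toy illustration of that translation, with made-up values:

// Hypothetical child communicator of 4 ranks carved out of an 8-rank parent.
// The array name mirrors topParentRanks in the calls above; values are made up.
#include <cstdio>

int main() {
  int topParentRanks[4] = {1, 3, 5, 7};
  int myRank = 2;                        // rank within the child communicator
  int tpRank = topParentRanks[myRank];   // rank the shared proxy is keyed on
  std::printf("child rank %d -> parent (tp) rank %d\n", myRank, tpRank);
  return 0;
}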
@ -274,39 +279,47 @@ static ncclResult_t sendConnect(struct ncclComm* comm, struct ncclConnect* conne
send->transportResources = map; send->transportResources = map;
opId = send; opId = send;
INFO(NCCL_PROXY, "sendConnect ncclProxyCallAsync opId=%p", opId); INFO(NCCL_PROXY, "sendConnect ncclProxyCallAsync opId=%p", opId);
NCCLCHECK(ncclProxyCallAsync(&send->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(ncclNetHandle_t), sizeof(struct connectMap), opId)); NCCLCHECK(ncclProxyCallAsync(comm, &send->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(ncclNetHandle_t), sizeof(struct connectMap), opId));
} else { } else {
opId = send; opId = send;
} }
ncclResult_t ret; ncclResult_t ret;
NCCLCHECK(ret = ncclPollProxyResponse(&send->proxyConn, map, opId)); NCCLCHECK(ret = ncclPollProxyResponse(comm, &send->proxyConn, map, opId));
if (ret == ncclInProgress) { if (ret == ncclInProgress) {
return ret; return ret;
} }
INFO(NCCL_PROXY, "sendConnect ncclPollProxyResponse opId=%p", opId); INFO(NCCL_PROXY, "sendConnect ncclPollProxyResponse opId=%p", opId);
if (map->sameProcess) { if (map->sameProcess && !ncclCuMemEnable()) {
if (map->cudaDev != comm->cudaDev) { if (map->cudaDev != comm->cudaDev) {
// Enable P2P access if (!ncclCuMemEnable()) {
cudaError_t err = cudaDeviceEnablePeerAccess(map->cudaDev, 0); // Enable P2P access for Legacy IPC
if (err == cudaErrorPeerAccessAlreadyEnabled) { cudaError_t err = cudaDeviceEnablePeerAccess(map->cudaDev, 0);
cudaGetLastError(); if (err == cudaErrorPeerAccessAlreadyEnabled) {
} else if (err != cudaSuccess) { cudaGetLastError();
WARN("failed to peer with device %d: %d %s", map->cudaDev, err, cudaGetErrorString(err)); } else if (err != cudaSuccess) {
return ncclInternalError; WARN("failed to peer with device %d: %d %s", map->cudaDev, err, cudaGetErrorString(err));
return ncclInternalError;
}
} }
} }
} else { } else if (!(map->sameProcess && map->cudaDev == comm->cudaDev)) {
NCCLCHECK(netMapShm(map->mems+NCCL_NET_MAP_HOSTMEM)); if (!map->sameProcess) NCCLCHECK(netMapShm(map->mems+NCCL_NET_MAP_HOSTMEM));
if (map->mems[NCCL_NET_MAP_DEVMEM].size) { if (map->mems[NCCL_NET_MAP_DEVMEM].size) {
CUDACHECK(cudaIpcOpenMemHandle((void**)&map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr, map->mems[NCCL_NET_MAP_DEVMEM].ipc, cudaIpcMemLazyEnablePeerAccess)); NCCLCHECK(ncclP2pImportShareableBuffer(comm, send->proxyConn.tpRank,
map->mems[NCCL_NET_MAP_DEVMEM].size,
&map->mems[NCCL_NET_MAP_DEVMEM].ipcDesc,
(void**)&map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr));
map->mems[NCCL_NET_MAP_DEVMEM].cpuPtr = NULL; map->mems[NCCL_NET_MAP_DEVMEM].cpuPtr = NULL;
} }
if (map->mems[NCCL_NET_MAP_SHARED_DEVMEM].size) { if (map->mems[NCCL_NET_MAP_SHARED_DEVMEM].size) {
void** sharedDevMemPtr = comm->proxyState.sharedDevMems+send->proxyConn.localRank; void** sharedDevMemPtr = comm->proxyState->sharedDevMems + send->proxyConn.tpLocalRank;
if (*sharedDevMemPtr == NULL) { if (*sharedDevMemPtr == NULL) {
CUDACHECK(cudaIpcOpenMemHandle(sharedDevMemPtr, map->mems[NCCL_NET_MAP_SHARED_DEVMEM].ipc, cudaIpcMemLazyEnablePeerAccess)); NCCLCHECK(ncclP2pImportShareableBuffer(comm, send->proxyConn.tpRank,
map->mems[NCCL_NET_MAP_SHARED_DEVMEM].size,
&map->mems[NCCL_NET_MAP_SHARED_DEVMEM].ipcDesc,
sharedDevMemPtr));
} }
map->mems[NCCL_NET_MAP_SHARED_DEVMEM].gpuPtr = (char*)(*sharedDevMemPtr); map->mems[NCCL_NET_MAP_SHARED_DEVMEM].gpuPtr = (char*)(*sharedDevMemPtr);
map->mems[NCCL_NET_MAP_SHARED_DEVMEM].cpuPtr = NULL; map->mems[NCCL_NET_MAP_SHARED_DEVMEM].cpuPtr = NULL;
@ -340,13 +353,13 @@ static ncclResult_t recvConnect(struct ncclComm* comm, struct ncclConnect* conne
opId = recv; opId = recv;
INFO(NCCL_PROXY, "recvConnect ncclProxyCallAsync opId=%p &recv->proxyConn=%p connectInfo=%p", INFO(NCCL_PROXY, "recvConnect ncclProxyCallAsync opId=%p &recv->proxyConn=%p connectInfo=%p",
opId, &recv->proxyConn, connectInfo); opId, &recv->proxyConn, connectInfo);
NCCLCHECK(ncclProxyCallAsync(&recv->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(int), sizeof(struct connectMap), opId)); NCCLCHECK(ncclProxyCallAsync(comm, &recv->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(int), sizeof(struct connectMap), opId));
} else { } else {
opId = recv; opId = recv;
} }
ncclResult_t ret; ncclResult_t ret;
NCCLCHECK(ret = ncclPollProxyResponse(&recv->proxyConn, map, opId)); NCCLCHECK(ret = ncclPollProxyResponse(comm, &recv->proxyConn, map, opId));
if (ret == ncclInProgress) { if (ret == ncclInProgress) {
return ret; return ret;
} }
@ -371,10 +384,24 @@ static ncclResult_t recvConnect(struct ncclComm* comm, struct ncclConnect* conne
static ncclResult_t sendFree(struct ncclConnector* send) { static ncclResult_t sendFree(struct ncclConnector* send) {
struct connectMap* map = (struct connectMap*)(send->transportResources); struct connectMap* map = (struct connectMap*)(send->transportResources);
if (map) { if (map) {
if (map->sameProcess == 0) { int cudaDev;
NCCLCHECK(ncclShmClose(map->mems[NCCL_NET_MAP_HOSTMEM].attachHandle)); CUDACHECK(cudaGetDevice(&cudaDev));
if (map->sameProcess && map->cudaDev == cudaDev) {
// Our own GPU, so it wasn't mapped in
free(map);
return ncclSuccess;
}
if (!map->sameProcess || ncclCuMemEnable()) {
if (!map->sameProcess) NCCLCHECK(ncclShmClose(map->mems[NCCL_NET_MAP_HOSTMEM].attachHandle));
if (map->mems[NCCL_NET_MAP_DEVMEM].size) { if (map->mems[NCCL_NET_MAP_DEVMEM].size) {
CUDACHECK(cudaIpcCloseMemHandle(map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr)); if (ncclCuMemEnable()) {
// cuMem API support
NCCLCHECK(ncclP2pFreeShareableBuffer(&map->mems[NCCL_NET_MAP_DEVMEM].ipcDesc));
NCCLCHECK(ncclCuMemFree(map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr));
} else {
// Legacy CUDA IPC support
CUDACHECK(cudaIpcCloseMemHandle(map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr));
}
} }
} }
free(map); free(map);
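The branch above tears down either a cuMem-based mapping (ncclP2pFreeShareableBuffer + ncclCuMemFree) or a legacy cudaIpc handle. For context, the driver-API sequence such a shareable cuMem allocation typically follows is sketched below; this is the generic pattern, not NCCL's ncclP2pAllocateShareableBuffer implementation, and it assumes a current CUDA context:

// Generic cuMem (VMM) allocation that can later be exported with
// cuMemExportToShareableHandle and imported by a peer process.
#include <cuda.h>

static CUresult allocShareable(int dev, size_t size, CUdeviceptr* ptr,
                               CUmemGenericAllocationHandle* handle) {
  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = dev;
  prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

  size_t gran = 0;
  CUresult res = cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
  if (res != CUDA_SUCCESS) return res;
  size = ((size + gran - 1) / gran) * gran;          // round up to granularity

  if ((res = cuMemCreate(handle, size, &prop, 0)) != CUDA_SUCCESS) return res;
  if ((res = cuMemAddressReserve(ptr, size, gran, 0, 0)) != CUDA_SUCCESS) return res;
  if ((res = cuMemMap(*ptr, size, 0, *handle, 0)) != CUDA_SUCCESS) return res;

  CUmemAccessDesc access = {};
  access.location = prop.location;
  access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  return cuMemSetAccess(*ptr, size, &access, 1);     // make it device-accessible
}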
@ -389,86 +416,87 @@ static ncclResult_t recvFree(struct ncclConnector* recv) {
} }
#define NCCL_SHARED_STEPS 16 #define NCCL_SHARED_STEPS 16
static ncclResult_t sharedBuffersInit(struct ncclComm* comm, int cuda, int localRank, int type, int sameProcess, static ncclResult_t sharedBuffersInit(struct ncclProxyState* proxyState, int cuda, int tpLocalRank, int type, int sameProcess,
int nChannels, char** gpuPtr, char** cpuPtr, int* size, cudaIpcMemHandle_t* ipc) { int nChannels, char** gpuPtr, char** cpuPtr, int* size, ncclIpcDesc *ipcDesc) {
if (cuda == 0 && sameProcess == 0) { if (cuda == 0 && sameProcess == 0) {
WARN("PXN should not use host buffers for data"); WARN("PXN should not use host buffers for data");
return ncclInternalError; return ncclInternalError;
} }
struct ncclProxyProgressState* progressState = &comm->proxyState.progressState; struct ncclProxyProgressState* progressState = &proxyState->progressState;
if (progressState->localPeers == NULL) { if (progressState->localPeers == NULL) {
NCCLCHECK(ncclCalloc(&progressState->localPeers, comm->localRanks)); NCCLCHECK(ncclCalloc(&progressState->localPeers, proxyState->tpLocalnRanks));
} }
struct ncclProxyPeer** localPeers = progressState->localPeers; struct ncclProxyPeer** localPeers = progressState->localPeers;
if (localPeers[localRank] == NULL) { if (localPeers[tpLocalRank] == NULL) {
NCCLCHECK(ncclCalloc(localPeers+localRank, 1)); NCCLCHECK(ncclCalloc(localPeers + tpLocalRank, 1));
} }
struct ncclProxyPeer* peer = localPeers[localRank]; struct ncclProxyPeer* peer = localPeers[tpLocalRank];
struct ncclProxySharedP2p* state = type == 0 ? &peer->send : &peer->recv; struct ncclProxySharedP2p* state = type == 0 ? &peer->send : &peer->recv;
state->refcount++; state->refcount++;
if (state->size == 0) { if (state->size == 0) {
state->size = nChannels*NCCL_SHARED_STEPS*comm->p2pChunkSize; state->size = nChannels * NCCL_SHARED_STEPS * proxyState->p2pChunkSize;
} }
if (size) *size = state->size; if (size) *size = state->size;
if (cuda && state->cudaBuff == NULL) { if (cuda && state->cudaBuff == NULL) {
NCCLCHECK(ncclCudaCalloc(&state->cudaBuff, state->size)); if (sameProcess == 0 || ncclCuMemEnable()) {
if (sameProcess == 0) { NCCLCHECK(ncclP2pAllocateShareableBuffer(state->size, &state->ipcDesc, (void**)&state->cudaBuff));
CUDACHECK(cudaIpcGetMemHandle(&state->ipc, state->cudaBuff)); } else {
NCCLCHECK(ncclCudaCalloc(&state->cudaBuff, state->size));
} }
} }
if (!cuda && state->hostBuff == NULL) { if (!cuda && state->hostBuff == NULL) {
NCCLCHECK(ncclCudaHostCalloc(&state->hostBuff, state->size)); NCCLCHECK(ncclCudaHostCalloc(&state->hostBuff, state->size));
} }
if (cpuPtr) *cpuPtr = cuda ? state->cudaBuff : state->hostBuff; if (cpuPtr) *cpuPtr = cuda ? state->cudaBuff : state->hostBuff;
if (sameProcess) { if (gpuPtr) *gpuPtr = sameProcess ? *cpuPtr : NULL;
if (gpuPtr) *gpuPtr = *cpuPtr; if (ipcDesc) memcpy(ipcDesc, &state->ipcDesc, sizeof(state->ipcDesc));
} else {
if (gpuPtr) *gpuPtr = NULL;
if (ipc) memcpy(ipc, &state->ipc, sizeof(cudaIpcMemHandle_t));
}
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t sharedBuffersGet(struct ncclComm* comm, int channel, int slot, int* offset) { static ncclResult_t sharedBuffersGet(struct ncclProxyState* proxyState, int channel, int slot, int* offset) {
// Use different pools for different channels and also separate send/recv. // Use different pools for different channels and also separate send/recv.
int globalSlot = (channel*NCCL_SHARED_STEPS)+slot; int globalSlot = (channel*NCCL_SHARED_STEPS)+slot;
*offset = comm->p2pChunkSize * globalSlot; *offset = proxyState->p2pChunkSize * globalSlot;
return ncclSuccess; return ncclSuccess;
} }
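The offset computation above carves one flat shared pool into per-channel regions of NCCL_SHARED_STEPS chunks. A standalone illustration of the same arithmetic (the 1 MB chunk size is an arbitrary example, not an NCCL default):

#include <cstdio>

#define NCCL_SHARED_STEPS 16

int main() {
  int p2pChunkSize = 1 << 20;                    // example chunk size
  for (int channel = 0; channel < 2; channel++) {
    for (int slot = 0; slot < 3; slot++) {
      int globalSlot = channel * NCCL_SHARED_STEPS + slot;
      int offset = p2pChunkSize * globalSlot;    // same formula as sharedBuffersGet
      std::printf("channel %d slot %d -> offset %d\n", channel, slot, offset);
    }
  }
  return 0;
}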
static ncclResult_t sharedBuffersDestroy(struct ncclComm* comm, int localRank, int type) { static ncclResult_t sharedBuffersDestroy(struct ncclProxyState* proxyState, int tpLocalRank, int type, struct ncclProxyConnection* connection) {
if (comm->proxyState.progressState.localPeers == NULL) NCCLCHECK(ncclInternalError); if (proxyState->progressState.localPeers == NULL) NCCLCHECK(ncclInternalError);
struct ncclProxyPeer* peer = comm->proxyState.progressState.localPeers[localRank]; struct ncclProxyPeer* peer = proxyState->progressState.localPeers[tpLocalRank];
if (peer == NULL) NCCLCHECK(ncclInternalError;) if (peer == NULL) NCCLCHECK(ncclInternalError;)
struct ncclProxySharedP2p* state = type == 0 ? &peer->send : &peer->recv; struct ncclProxySharedP2p* state = type == 0 ? &peer->send : &peer->recv;
if (state->size == 0) NCCLCHECK(ncclInternalError); if (state->size == 0) NCCLCHECK(ncclInternalError);
state->refcount--; if (ncclAtomicRefCountDecrement(&state->refcount) == 0) {
if (state->refcount == 0) { if (state->cudaBuff) {
if (state->cudaBuff) CUDACHECK(cudaFree(state->cudaBuff)); if (!connection->sameProcess || ncclCuMemEnable()) {
NCCLCHECK(ncclP2pFreeShareableBuffer(&state->ipcDesc));
}
NCCLCHECK(ncclCudaFree(state->cudaBuff));
}
if (state->hostBuff) NCCLCHECK(ncclCudaHostFree(state->hostBuff)); if (state->hostBuff) NCCLCHECK(ncclCudaHostFree(state->hostBuff));
} }
if (peer->send.refcount || peer->recv.refcount) return ncclSuccess; if (peer->send.refcount || peer->recv.refcount) return ncclSuccess;
free(peer); free(peer);
comm->proxyState.progressState.localPeers[localRank] = NULL; proxyState->progressState.localPeers[tpLocalRank] = NULL;
for (int r=0; r<comm->localRanks; r++) { for (int r = 0; r < proxyState->tpLocalnRanks; r++) {
if (comm->proxyState.progressState.localPeers[r]) return ncclSuccess; if (proxyState->progressState.localPeers[r]) return ncclSuccess;
} }
// All peers are freed, free array // All peers are freed, free array
free(comm->proxyState.progressState.localPeers); free(proxyState->progressState.localPeers);
comm->proxyState.progressState.localPeers = NULL; proxyState->progressState.localPeers = NULL;
return ncclSuccess; return ncclSuccess;
} }
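sharedBuffersDestroy above only frees the pool once the last user has dropped its reference (ncclAtomicRefCountDecrement returning zero). A generic sketch of that decrement-to-zero idiom; the atomic helper here is a stand-in, not NCCL's:

#include <atomic>
#include <cstdlib>

struct SharedPool {
  std::atomic<int> refcount;
  void* buff;
};

// Stand-in for ncclAtomicRefCountDecrement: returns the post-decrement value.
static int refCountDec(std::atomic<int>* r) { return r->fetch_sub(1) - 1; }

static void releasePool(SharedPool* p) {
  if (refCountDec(&p->refcount) == 0) {   // last user performs the actual free
    std::free(p->buff);
    p->buff = nullptr;
  }
}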
static ncclResult_t proxySharedInit(struct ncclProxyConnection* connection, struct ncclComm* comm, int nChannels) { static ncclResult_t proxySharedInit(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, int nChannels) {
int rank = comm->localRankToRank[connection->localRank]; NCCLCHECK(sharedBuffersInit(proxyState, 1, connection->tpLocalRank, 0, connection->sameProcess, nChannels, NULL, NULL, NULL, NULL));
int sameProcess = comm->peerInfo[rank].pidHash == comm->peerInfo[comm->rank].pidHash ? 1 : 0;
NCCLCHECK(sharedBuffersInit(comm, 1, connection->localRank, 0, sameProcess, nChannels, NULL, NULL, NULL, NULL));
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t sendProxySetup(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t sendProxySetup(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
struct setupReq* req = (struct setupReq*) reqBuff; struct setupReq* req = (struct setupReq*) reqBuff;
if (reqSize != sizeof(struct setupReq)) return ncclInternalError; if (reqSize != sizeof(struct setupReq)) return ncclInternalError;
@ -476,18 +504,18 @@ static ncclResult_t sendProxySetup(struct ncclProxyConnection* connection, struc
NCCLCHECK(ncclCalloc(&resources, 1)); NCCLCHECK(ncclCalloc(&resources, 1));
connection->transportResources = resources; connection->transportResources = resources;
resources->rank = req->rank; resources->tpRank = req->tpRank;
resources->localRank = req->localRank; resources->tpLocalRank = req->tpLocalRank;
resources->remoteRank = req->remoteRank; resources->tpRemoteRank = req->tpRemoteRank;
resources->netDev = req->netDev; resources->netDev = req->netDev;
resources->shared = connection->shared = req->shared; resources->shared = connection->shared = req->shared;
resources->useGdr = req->useGdr; resources->useGdr = req->useGdr;
resources->channelId = req->channelId; resources->channelId = req->channelId;
resources->connIndex = req->connIndex; resources->connIndex = req->connIndex;
ncclNetProperties_t props; ncclNetProperties_t props;
NCCLCHECK(ncclNetGetProperties(comm, req->netDev, &props)); NCCLCHECK(proxyState->ncclNet->getProperties(req->netDev, &props));
/* DMA-BUF support */ /* DMA-BUF support */
resources->useDmaBuf = resources->useGdr && comm->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF); resources->useDmaBuf = resources->useGdr && proxyState->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF);
resources->maxRecvs = props.maxRecvs; resources->maxRecvs = props.maxRecvs;
// We don't return any data // We don't return any data
@ -496,7 +524,7 @@ static ncclResult_t sendProxySetup(struct ncclProxyConnection* connection, struc
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
struct setupReq* req = (struct setupReq*) reqBuff; struct setupReq* req = (struct setupReq*) reqBuff;
if (reqSize != sizeof(struct setupReq)) return ncclInternalError; if (reqSize != sizeof(struct setupReq)) return ncclInternalError;
@ -504,9 +532,9 @@ static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struc
NCCLCHECK(ncclCalloc(&resources, 1)); NCCLCHECK(ncclCalloc(&resources, 1));
connection->transportResources = resources; connection->transportResources = resources;
resources->rank = req->rank; resources->tpRank = req->tpRank;
resources->localRank = req->localRank; resources->tpLocalRank = req->tpLocalRank;
resources->remoteRank = req->remoteRank; resources->tpRemoteRank = req->tpRemoteRank;
resources->netDev = req->netDev; resources->netDev = req->netDev;
resources->shared = connection->shared = req->shared; resources->shared = connection->shared = req->shared;
resources->useGdr = req->useGdr; resources->useGdr = req->useGdr;
@ -514,50 +542,50 @@ static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struc
resources->channelId = req->channelId; resources->channelId = req->channelId;
resources->connIndex = req->connIndex; resources->connIndex = req->connIndex;
ncclNetProperties_t props; ncclNetProperties_t props;
NCCLCHECK(ncclNetGetProperties(comm, req->netDev, &props)); NCCLCHECK(proxyState->ncclNet->getProperties(req->netDev, &props));
/* DMA-BUF support */ /* DMA-BUF support */
resources->useDmaBuf = resources->useGdr && comm->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF); resources->useDmaBuf = resources->useGdr && proxyState->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF);
resources->maxRecvs = props.maxRecvs; resources->maxRecvs = props.maxRecvs;
if (respSize != sizeof(ncclNetHandle_t)) return ncclInternalError; if (respSize != sizeof(ncclNetHandle_t)) return ncclInternalError;
NCCLCHECK(ncclNetListen(comm, req->netDev, respBuff, &resources->netListenComm)); NCCLCHECK(proxyState->ncclNet->listen(req->netDev, respBuff, &resources->netListenComm));
*done = 1; *done = 1;
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
struct sendResources* resources = (struct sendResources*)(connection->transportResources); struct sendResources* resources = (struct sendResources*)(connection->transportResources);
if (reqSize != sizeof(ncclNetHandle_t)) return ncclInternalError; if (reqSize != sizeof(ncclNetHandle_t)) return ncclInternalError;
ncclResult_t ret = ncclSuccess; ncclResult_t ret = ncclSuccess;
if (resources->shared) { if (resources->shared) {
// Shared buffers // Shared buffers
struct ncclProxyProgressState* progressState = &comm->proxyState.progressState; struct ncclProxyProgressState* progressState = &proxyState->progressState;
if (progressState->localPeers == NULL) { if (progressState->localPeers == NULL) {
NCCLCHECK(ncclCalloc(&progressState->localPeers, comm->localRanks)); NCCLCHECK(ncclCalloc(&progressState->localPeers, proxyState->tpLocalnRanks));
} }
struct ncclProxyPeer** localPeers = progressState->localPeers; struct ncclProxyPeer** localPeers = progressState->localPeers;
if (localPeers[resources->localRank] == NULL) { if (localPeers[resources->tpLocalRank] == NULL) {
NCCLCHECK(ncclCalloc(localPeers+resources->localRank, 1)); NCCLCHECK(ncclCalloc(localPeers + resources->tpLocalRank, 1));
} }
connection->proxyAppendPtr = localPeers[resources->localRank]->send.proxyAppend+resources->channelId; connection->proxyAppendPtr = localPeers[resources->tpLocalRank]->send.proxyAppend + resources->channelId;
if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) {
// Connect or reuse connection for a netdev/remote rank. // Connect or reuse connection for a netdev/remote rank.
if (progressState->netComms[resources->netDev] == NULL) { if (progressState->netComms[resources->netDev] == NULL) {
NCCLCHECK(ncclCalloc(progressState->netComms+resources->netDev, comm->nRanks)); NCCLCHECK(ncclCalloc(progressState->netComms + resources->netDev, proxyState->tpnRanks));
} }
struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev]+resources->remoteRank; struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteRank;
if (comms->sendComm[resources->channelId] == NULL) ret = ncclNetConnect(comm, resources->netDev, reqBuff, comms->sendComm+resources->channelId); if (comms->sendComm[resources->channelId] == NULL) ret = proxyState->ncclNet->connect(resources->netDev, reqBuff, comms->sendComm + resources->channelId);
resources->netSendComm = comms->sendComm[resources->channelId]; resources->netSendComm = comms->sendComm[resources->channelId];
if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++; if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++;
} else { } else {
ret = ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm); ret = proxyState->ncclNet->connect(resources->netDev, reqBuff, &resources->netSendComm);
} }
} else { } else {
// Connect to remote peer // Connect to remote peer
ret = ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm); ret = proxyState->ncclNet->connect(resources->netDev, reqBuff, &resources->netSendComm);
connection->proxyAppendPtr = &connection->proxyAppend; connection->proxyAppendPtr = &connection->proxyAppend;
} }
@ -570,28 +598,27 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
// Create structures // Create structures
struct connectMap* map = &resources->map; struct connectMap* map = &resources->map;
map->sameProcess = map->sameProcess = connection->sameProcess;
comm->peerInfo[resources->rank].pidHash == comm->peerInfo[comm->rank].pidHash ? 1 : 0;
map->shared = resources->shared; map->shared = resources->shared;
CUDACHECK(cudaGetDevice(&map->cudaDev)); CUDACHECK(cudaGetDevice(&map->cudaDev));
if (resources->shared == 0) { // Only allocate dedicated buffers for ring/tree, not for p2p if (resources->shared == 0) { // Only allocate dedicated buffers for ring/tree, not for p2p
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
NCCL_NET_MAP_ADD_POINTER(map, 0, p!= NCCL_PROTO_LL && resources->useGdr, comm->buffSizes[p], buffs[p]); NCCL_NET_MAP_ADD_POINTER(map, 0, p!= NCCL_PROTO_LL && resources->useGdr, proxyState->buffSizes[p], buffs[p]);
resources->buffSizes[p] = comm->buffSizes[p]; resources->buffSizes[p] = proxyState->buffSizes[p];
} }
} else { } else {
// Get shared buffers // Get shared buffers
int bank = resources->useGdr ? NCCL_NET_MAP_SHARED_DEVMEM : NCCL_NET_MAP_SHARED_HOSTMEM; int bank = resources->useGdr ? NCCL_NET_MAP_SHARED_DEVMEM : NCCL_NET_MAP_SHARED_HOSTMEM;
struct connectMapMem* mapMem = map->mems+bank; struct connectMapMem* mapMem = map->mems+bank;
NCCLCHECK(sharedBuffersInit( NCCLCHECK(sharedBuffersInit(
comm, resources->useGdr, resources->localRank, 0, map->sameProcess, comm->p2pnChannels, proxyState, resources->useGdr, resources->tpLocalRank, 0, map->sameProcess, proxyState->p2pnChannels,
&mapMem->gpuPtr, &mapMem->cpuPtr, &mapMem->size, &mapMem->ipc)); &mapMem->gpuPtr, &mapMem->cpuPtr, &mapMem->size, &mapMem->ipcDesc));
resources->buffSizes[NCCL_PROTO_SIMPLE] = mapMem->size; resources->buffSizes[NCCL_PROTO_SIMPLE] = mapMem->size;
if (comm->allocP2pNetLLBuffers) { if (proxyState->allocP2pNetLLBuffers) {
NCCL_NET_MAP_ADD_POINTER(map, 0, 0 /*p == NCCL_PROTO_LL*/, comm->buffSizes[NCCL_PROTO_LL], buffs[NCCL_PROTO_LL]); NCCL_NET_MAP_ADD_POINTER(map, 0, 0 /*p == NCCL_PROTO_LL*/, proxyState->buffSizes[NCCL_PROTO_LL], buffs[NCCL_PROTO_LL]);
resources->buffSizes[NCCL_PROTO_LL] = comm->buffSizes[NCCL_PROTO_LL]; resources->buffSizes[NCCL_PROTO_LL] = proxyState->buffSizes[NCCL_PROTO_LL];
} }
NCCL_NET_MAP_ADD_POINTER(map, 1, resources->useGdr, mapMem->size, buffs[NCCL_PROTO_SIMPLE]); NCCL_NET_MAP_ADD_POINTER(map, 1, resources->useGdr, mapMem->size, buffs[NCCL_PROTO_SIMPLE]);
@ -602,15 +629,15 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
if (map->mems[NCCL_NET_MAP_DEVMEM].size) { if (map->mems[NCCL_NET_MAP_DEVMEM].size) {
if (resources->shared == 0) { if (resources->shared == 0) {
if (!map->sameProcess) { if (!map->sameProcess || ncclCuMemEnable()) {
ALIGN_SIZE(map->mems[NCCL_NET_MAP_DEVMEM].size, CUDA_IPC_MIN); ALIGN_SIZE(map->mems[NCCL_NET_MAP_DEVMEM].size, CUDA_IPC_MIN);
NCCLCHECK(ncclP2pAllocateShareableBuffer(map->mems[NCCL_NET_MAP_DEVMEM].size, &map->mems[NCCL_NET_MAP_DEVMEM].ipcDesc,
(void**)&map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr));
} else {
NCCLCHECK(ncclCudaCalloc(&map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr, map->mems[NCCL_NET_MAP_DEVMEM].size));
} }
NCCLCHECK(ncclCudaCalloc(&map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr, map->mems[NCCL_NET_MAP_DEVMEM].size));
map->mems[NCCL_NET_MAP_DEVMEM].cpuPtr = map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr; map->mems[NCCL_NET_MAP_DEVMEM].cpuPtr = map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr;
} }
if (!map->sameProcess) {
CUDACHECK(cudaIpcGetMemHandle(&map->mems[NCCL_NET_MAP_DEVMEM].ipc, map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr));
}
} }
if (map->sameProcess) { if (map->sameProcess) {
NCCLCHECK(ncclCudaHostCalloc(&map->mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, map->mems[NCCL_NET_MAP_HOSTMEM].size)); NCCLCHECK(ncclCudaHostCalloc(&map->mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, map->mems[NCCL_NET_MAP_HOSTMEM].size));
@ -645,12 +672,12 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
if (type == NCCL_PTR_CUDA && resources->useDmaBuf) { if (type == NCCL_PTR_CUDA && resources->useDmaBuf) {
int dmabuf_fd; int dmabuf_fd;
CUCHECK(cuMemGetHandleForAddressRange((void *)&dmabuf_fd, (CUdeviceptr)resources->buffers[p], resources->buffSizes[p], CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0)); CUCHECK(cuMemGetHandleForAddressRange((void *)&dmabuf_fd, (CUdeviceptr)resources->buffers[p], resources->buffSizes[p], CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0));
NCCLCHECK(ncclNetRegMrDmaBuf(comm, resources->netSendComm, resources->buffers[p], resources->buffSizes[p], type, 0ULL, dmabuf_fd, &resources->mhandles[p])); NCCLCHECK(proxyState->ncclNet->regMrDmaBuf(resources->netSendComm, resources->buffers[p], resources->buffSizes[p], type, 0ULL, dmabuf_fd, &resources->mhandles[p]));
(void)close(dmabuf_fd); (void)close(dmabuf_fd);
} else // FALL-THROUGH to nv_peermem GDR path } else // FALL-THROUGH to nv_peermem GDR path
#endif #endif
{ {
NCCLCHECK(ncclNetRegMr(comm, resources->netSendComm, resources->buffers[p], resources->buffSizes[p], NCCL_NET_MAP_DEV_MEM(map, buffs[p]) ? NCCL_PTR_CUDA : NCCL_PTR_HOST, &resources->mhandles[p])); NCCLCHECK(proxyState->ncclNet->regMr(resources->netSendComm, resources->buffers[p], resources->buffSizes[p], NCCL_NET_MAP_DEV_MEM(map, buffs[p]) ? NCCL_PTR_CUDA : NCCL_PTR_HOST, &resources->mhandles[p]));
} }
} }
} }
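The DMA-BUF branch above (repeated in recvProxyConnect below) asks the CUDA driver for a dmabuf file descriptor covering the buffer and hands it to the network plugin's regMrDmaBuf. A small sketch of just the fd retrieval step, assuming a current context and a VMM/cuMem-backed allocation:

#include <cstddef>
#include <cuda.h>

// Returns a dmabuf fd for [buf, buf+size), or -1 on failure.
// The caller registers the fd with the NIC and then close()s it,
// as the code above does after regMrDmaBuf.
static int getDmaBufFd(CUdeviceptr buf, size_t size) {
  int fd = -1;
  if (cuMemGetHandleForAddressRange((void*)&fd, buf, size,
                                    CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0) != CUDA_SUCCESS)
    return -1;
  return fd;
}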
@ -661,40 +688,40 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
if (reqSize != sizeof(int)) return ncclInternalError; if (reqSize != sizeof(int)) return ncclInternalError;
struct recvResources* resources = (struct recvResources*)(connection->transportResources); struct recvResources* resources = (struct recvResources*)(connection->transportResources);
resources->proxyRank = *(int*)reqBuff; resources->tpRemoteProxyRank = *(int*)reqBuff;
ncclResult_t ret = ncclSuccess; ncclResult_t ret = ncclSuccess;
// Finish connection establishment from remote peer // Finish connection establishment from remote peer
if (resources->shared) { if (resources->shared) {
// Shared buffers // Shared buffers
struct ncclProxyProgressState* progressState = &comm->proxyState.progressState; struct ncclProxyProgressState* progressState = &proxyState->progressState;
if (progressState->localPeers == NULL) { if (progressState->localPeers == NULL) {
NCCLCHECK(ncclCalloc(&progressState->localPeers, comm->localRanks)); NCCLCHECK(ncclCalloc(&progressState->localPeers, proxyState->tpLocalnRanks));
} }
struct ncclProxyPeer** localPeers = progressState->localPeers; struct ncclProxyPeer** localPeers = progressState->localPeers;
if (localPeers[resources->localRank] == NULL) { if (localPeers[resources->tpLocalRank] == NULL) {
NCCLCHECK(ncclCalloc(localPeers+resources->localRank, 1)); NCCLCHECK(ncclCalloc(localPeers + resources->tpLocalRank, 1));
} }
connection->proxyAppendPtr = localPeers[resources->localRank]->recv.proxyAppend+resources->channelId; connection->proxyAppendPtr = localPeers[resources->tpLocalRank]->recv.proxyAppend + resources->channelId;
if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) {
// Connect or reuse connection for a netdev/remote rank. // Connect or reuse connection for a netdev/remote rank.
if (progressState->netComms[resources->netDev] == NULL) { if (progressState->netComms[resources->netDev] == NULL) {
NCCLCHECK(ncclCalloc(progressState->netComms+resources->netDev, comm->nRanks)); NCCLCHECK(ncclCalloc(progressState->netComms + resources->netDev, proxyState->tpnRanks));
} }
struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev]+resources->proxyRank; struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteProxyRank;
if (comms->recvComm[resources->channelId] == NULL) ret = ncclNetAccept(comm, resources->netListenComm, comms->recvComm+resources->channelId); if (comms->recvComm[resources->channelId] == NULL) ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId);
resources->netRecvComm = comms->recvComm[resources->channelId]; resources->netRecvComm = comms->recvComm[resources->channelId];
if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++; if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++;
} else { } else {
ret = ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm); ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm);
} }
} else { } else {
// Connect to remote peer // Connect to remote peer
ret = ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm); ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm);
connection->proxyAppendPtr = &connection->proxyAppend; connection->proxyAppendPtr = &connection->proxyAppend;
} }
@ -705,26 +732,25 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
} }
*done = 1; *done = 1;
NCCLCHECK(ncclNetCloseListen(comm, resources->netListenComm)); NCCLCHECK(proxyState->ncclNet->closeListen(resources->netListenComm));
// Create structures // Create structures
struct connectMap* map = &resources->map; struct connectMap* map = &resources->map;
map->sameProcess = map->sameProcess = connection->sameProcess;
comm->peerInfo[resources->rank].pidHash == comm->peerInfo[comm->rank].pidHash ? 1 : 0;
if (map->sameProcess == 0) return ncclInternalError; // We don't support remote proxy for recv if (map->sameProcess == 0) return ncclInternalError; // We don't support remote proxy for recv
map->shared = resources->shared; map->shared = resources->shared;
if (resources->shared == 0) { // Only allocate dedicated buffers for ring/tree, not for p2p if (resources->shared == 0) { // Only allocate dedicated buffers for ring/tree, not for p2p
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
NCCL_NET_MAP_ADD_POINTER(map, 0, resources->useGdr, comm->buffSizes[p], buffs[p]); NCCL_NET_MAP_ADD_POINTER(map, 0, resources->useGdr, proxyState->buffSizes[p], buffs[p]);
resources->buffSizes[p] = comm->buffSizes[p]; resources->buffSizes[p] = proxyState->buffSizes[p];
} }
} else { } else {
// Get shared buffers // Get shared buffers
int bank = resources->useGdr ? NCCL_NET_MAP_SHARED_DEVMEM : NCCL_NET_MAP_SHARED_HOSTMEM; int bank = resources->useGdr ? NCCL_NET_MAP_SHARED_DEVMEM : NCCL_NET_MAP_SHARED_HOSTMEM;
struct connectMapMem* mapMem = map->mems+bank; struct connectMapMem* mapMem = map->mems+bank;
NCCLCHECK(sharedBuffersInit( NCCLCHECK(sharedBuffersInit(
comm, resources->useGdr, resources->localRank, 1, 1, comm->p2pnChannels, proxyState, resources->useGdr, resources->tpLocalRank, 1, 1, proxyState->p2pnChannels,
&mapMem->gpuPtr, &mapMem->cpuPtr, &mapMem->size, NULL)); &mapMem->gpuPtr, &mapMem->cpuPtr, &mapMem->size, NULL));
resources->buffSizes[NCCL_PROTO_SIMPLE] = mapMem->size; resources->buffSizes[NCCL_PROTO_SIMPLE] = mapMem->size;
NCCL_NET_MAP_ADD_POINTER(map, 1, resources->useGdr, mapMem->size, buffs[NCCL_PROTO_SIMPLE]); NCCL_NET_MAP_ADD_POINTER(map, 1, resources->useGdr, mapMem->size, buffs[NCCL_PROTO_SIMPLE]);
@ -733,14 +759,19 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
NCCL_NET_MAP_ADD_POINTER(map, 0, 0, sizeof(struct ncclSendMem), sendMem); NCCL_NET_MAP_ADD_POINTER(map, 0, 0, sizeof(struct ncclSendMem), sendMem);
NCCL_NET_MAP_ADD_POINTER(map, 0, 0, sizeof(struct ncclRecvMem), recvMem); NCCL_NET_MAP_ADD_POINTER(map, 0, 0, sizeof(struct ncclRecvMem), recvMem);
if (comm->allocP2pNetLLBuffers) { if (proxyState->allocP2pNetLLBuffers) {
NCCL_NET_MAP_ADD_POINTER(map, 0, 0 /*resources->useGdr*/, comm->buffSizes[NCCL_PROTO_LL], buffs[NCCL_PROTO_LL]); NCCL_NET_MAP_ADD_POINTER(map, 0, 0 /*resources->useGdr*/, proxyState->buffSizes[NCCL_PROTO_LL], buffs[NCCL_PROTO_LL]);
resources->buffSizes[NCCL_PROTO_LL] = comm->buffSizes[NCCL_PROTO_LL]; resources->buffSizes[NCCL_PROTO_LL] = proxyState->buffSizes[NCCL_PROTO_LL];
} }
if (map->mems[NCCL_NET_MAP_DEVMEM].size) { if (map->mems[NCCL_NET_MAP_DEVMEM].size) {
if (resources->shared == 0) { if (resources->shared == 0) {
NCCLCHECK(ncclCudaCalloc(&map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr, map->mems[NCCL_NET_MAP_DEVMEM].size)); if (ncclCuMemEnable()) {
NCCLCHECK(ncclP2pAllocateShareableBuffer(map->mems[NCCL_NET_MAP_DEVMEM].size, &map->mems[NCCL_NET_MAP_DEVMEM].ipcDesc,
(void**)&map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr));
} else {
NCCLCHECK(ncclCudaCalloc(&map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr, map->mems[NCCL_NET_MAP_DEVMEM].size));
}
map->mems[NCCL_NET_MAP_DEVMEM].cpuPtr = map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr; map->mems[NCCL_NET_MAP_DEVMEM].cpuPtr = map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr;
} }
} }
@ -771,12 +802,12 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
if (type == NCCL_PTR_CUDA && resources->useDmaBuf) { if (type == NCCL_PTR_CUDA && resources->useDmaBuf) {
int dmabuf_fd; int dmabuf_fd;
CUCHECK(cuMemGetHandleForAddressRange((void *)&dmabuf_fd, (CUdeviceptr)resources->buffers[p], resources->buffSizes[p], CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0)); CUCHECK(cuMemGetHandleForAddressRange((void *)&dmabuf_fd, (CUdeviceptr)resources->buffers[p], resources->buffSizes[p], CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0));
NCCLCHECK(ncclNetRegMrDmaBuf(comm, resources->netRecvComm, resources->buffers[p], resources->buffSizes[p], type, 0ULL, dmabuf_fd, &resources->mhandles[p])); NCCLCHECK(proxyState->ncclNet->regMrDmaBuf(resources->netRecvComm, resources->buffers[p], resources->buffSizes[p], type, 0ULL, dmabuf_fd, &resources->mhandles[p]));
(void)close(dmabuf_fd); (void)close(dmabuf_fd);
} else // FALL-THROUGH to nv_peermem GDR path } else // FALL-THROUGH to nv_peermem GDR path
#endif #endif
{ {
NCCLCHECK(ncclNetRegMr(comm, resources->netRecvComm, resources->buffers[p], resources->buffSizes[p], NCCL_NET_MAP_DEV_MEM(map, buffs[p]) ? NCCL_PTR_CUDA : NCCL_PTR_HOST, &resources->mhandles[p])); NCCLCHECK(proxyState->ncclNet->regMr(resources->netRecvComm, resources->buffers[p], resources->buffSizes[p], NCCL_NET_MAP_DEV_MEM(map, buffs[p]) ? NCCL_PTR_CUDA : NCCL_PTR_HOST, &resources->mhandles[p]));
} }
} }
} }
@ -787,17 +818,17 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t sendProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { static ncclResult_t sendProxyFree(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState) {
struct sendResources* resources = (struct sendResources*)(connection->transportResources); struct sendResources* resources = (struct sendResources*)(connection->transportResources);
if (connection->state == connSharedInitialized) { // NVB Preconnect if (connection->state == connSharedInitialized) { // NVB Preconnect
NCCLCHECK(sharedBuffersDestroy(comm, connection->localRank, 0)); NCCLCHECK(sharedBuffersDestroy(proxyState, connection->tpLocalRank, 0, connection));
return ncclSuccess; return ncclSuccess;
} }
if (connection->state == connConnected) { if (connection->state == connConnected) {
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
if (resources->buffers[p]) { if (resources->buffers[p]) {
NCCLCHECK(ncclNetDeregMr(comm, resources->netSendComm, resources->mhandles[p])); NCCLCHECK(proxyState->ncclNet->deregMr(resources->netSendComm, resources->mhandles[p]));
} }
} }
struct connectMapMem* mems = resources->map.mems; struct connectMapMem* mems = resources->map.mems;
@ -806,19 +837,25 @@ static ncclResult_t sendProxyFree(struct ncclProxyConnection* connection, struct
} else { } else {
NCCLCHECK(ncclShmClose(mems[NCCL_NET_MAP_HOSTMEM].createHandle)); NCCLCHECK(ncclShmClose(mems[NCCL_NET_MAP_HOSTMEM].createHandle));
} }
CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); NCCLCHECK(ncclCudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr));
if (!resources->map.sameProcess || ncclCuMemEnable()) {
// cuMem API support
if (mems[NCCL_NET_MAP_DEVMEM].size) {
NCCLCHECK(ncclP2pFreeShareableBuffer(&mems[NCCL_NET_MAP_DEVMEM].ipcDesc));
}
}
if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc));
if (resources->shared) { if (resources->shared) {
NCCLCHECK(sharedBuffersDestroy(comm, resources->localRank, 0)); NCCLCHECK(sharedBuffersDestroy(proxyState, resources->tpLocalRank, 0, connection));
if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) {
struct ncclSharedNetComms* comms = comm->proxyState.progressState.netComms[resources->netDev]+resources->remoteRank; struct ncclSharedNetComms* comms = proxyState->progressState.netComms[resources->netDev]+resources->tpRemoteRank;
comms->sendRefCount[resources->channelId]--; comms->sendRefCount[resources->channelId]--;
if (comms->sendRefCount[resources->channelId] == 0) NCCLCHECK(ncclNetCloseSend(comm, comms->sendComm[resources->channelId])); if (comms->sendRefCount[resources->channelId] == 0) NCCLCHECK(proxyState->ncclNet->closeSend(comms->sendComm[resources->channelId]));
} else { } else {
NCCLCHECK(ncclNetCloseSend(comm, resources->netSendComm)); NCCLCHECK(proxyState->ncclNet->closeSend(resources->netSendComm));
} }
} else { } else {
NCCLCHECK(ncclNetCloseSend(comm, resources->netSendComm)); NCCLCHECK(proxyState->ncclNet->closeSend(resources->netSendComm));
} }
} }
@ -826,44 +863,50 @@ static ncclResult_t sendProxyFree(struct ncclProxyConnection* connection, struct
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState) {
struct recvResources* resources = (struct recvResources*)(connection->transportResources); struct recvResources* resources = (struct recvResources*)(connection->transportResources);
if (connection->state == connSharedInitialized) { // NVB Preconnect if (connection->state == connSharedInitialized) { // NVB Preconnect
NCCLCHECK(sharedBuffersDestroy(comm, connection->localRank, 1)); NCCLCHECK(sharedBuffersDestroy(proxyState, connection->tpLocalRank, 1, connection));
return ncclSuccess; return ncclSuccess;
} }
if (connection->state == connConnected) { if (connection->state == connConnected) {
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
if (resources->buffers[p]) { if (resources->buffers[p]) {
NCCLCHECK(ncclNetDeregMr(comm, resources->netRecvComm, resources->mhandles[p])); NCCLCHECK(proxyState->ncclNet->deregMr(resources->netRecvComm, resources->mhandles[p]));
} }
} }
struct connectMapMem* mems = resources->map.mems; struct connectMapMem* mems = resources->map.mems;
NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr)); NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr));
CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); NCCLCHECK(ncclCudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr));
if (!resources->map.sameProcess || ncclCuMemEnable()) {
// cuMem API support
if (mems[NCCL_NET_MAP_DEVMEM].size) {
NCCLCHECK(ncclP2pFreeShareableBuffer(&mems[NCCL_NET_MAP_DEVMEM].ipcDesc));
}
}
if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc));
if (resources->shared) { if (resources->shared) {
NCCLCHECK(sharedBuffersDestroy(comm, resources->localRank, 1)); NCCLCHECK(sharedBuffersDestroy(proxyState, resources->tpLocalRank, 1, connection));
if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) {
struct ncclSharedNetComms* comms = comm->proxyState.progressState.netComms[resources->netDev]+resources->proxyRank; struct ncclSharedNetComms* comms = proxyState->progressState.netComms[resources->netDev] + resources->tpRemoteProxyRank;
comms->recvRefCount[resources->channelId]--; comms->recvRefCount[resources->channelId]--;
if (comms->recvRefCount[resources->channelId] == 0) NCCLCHECK(ncclNetCloseRecv(comm, comms->recvComm[resources->channelId])); if (comms->recvRefCount[resources->channelId] == 0) NCCLCHECK(proxyState->ncclNet->closeRecv(comms->recvComm[resources->channelId]));
} else { } else {
NCCLCHECK(ncclNetCloseRecv(comm, resources->netRecvComm)); NCCLCHECK(proxyState->ncclNet->closeRecv(resources->netRecvComm));
} }
} else { } else {
NCCLCHECK(ncclNetCloseRecv(comm, resources->netRecvComm)); NCCLCHECK(proxyState->ncclNet->closeRecv(resources->netRecvComm));
} }
} }
if (resources) free(resources); if (resources) free(resources);
return ncclSuccess; return ncclSuccess;
} }
static_assert(NCCL_STEPS <= NCCL_NET_MAX_REQUESTS, "Not enough net requests to cover for steps"); static_assert(NCCL_STEPS <= NCCL_NET_MAX_REQUESTS, "Not enough net requests to cover for steps");
static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArgs* args) { static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args) {
if (args->state == ncclProxyOpReady) { if (args->state == ncclProxyOpReady) {
for (int s=0; s<args->nsubs; s++) { for (int s=0; s<args->nsubs; s++) {
struct ncclProxySubArgs* sub = args->subs+s; struct ncclProxySubArgs* sub = args->subs+s;
@ -894,7 +937,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
if (resources->shared) { if (resources->shared) {
int sharedBuffSlot = sub->posted%maxDepth; int sharedBuffSlot = sub->posted%maxDepth;
int offset; int offset;
NCCLCHECK(sharedBuffersGet(comm, sub->channelId, sharedBuffSlot*args->nsubs+s, &offset)); NCCLCHECK(sharedBuffersGet(proxyState, sub->channelId, sharedBuffSlot*args->nsubs+s, &offset));
resources->recvMem->offsFifo[buffSlot] = offset; resources->recvMem->offsFifo[buffSlot] = offset;
__sync_synchronize(); __sync_synchronize();
volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head; volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head;
@ -944,7 +987,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
} }
if (ready) { if (ready) {
// Data is ready, try to send. // Data is ready, try to send.
NCCLCHECK(ncclNetIsend(comm, resources->netSendComm, buff, size, resources->rank, mhandle, sub->requests+buffSlot)); NCCLCHECK(proxyState->ncclNet->isend(resources->netSendComm, buff, size, resources->tpRank, mhandle, sub->requests+buffSlot));
if (sub->requests[buffSlot] != NULL) { if (sub->requests[buffSlot] != NULL) {
TRACE(NCCL_NET, "sendProxy [%ld/%d] Isend posted, req %p", sub->transmitted, buffSlot, sub->requests[buffSlot]); TRACE(NCCL_NET, "sendProxy [%ld/%d] Isend posted, req %p", sub->transmitted, buffSlot, sub->requests[buffSlot]);
sizesFifo[buffSlot] = -1; sizesFifo[buffSlot] = -1;
@ -962,7 +1005,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
if (sub->done < sub->transmitted) { if (sub->done < sub->transmitted) {
int done; int done;
int buffSlot = (sub->base+sub->done)%NCCL_STEPS; int buffSlot = (sub->base+sub->done)%NCCL_STEPS;
NCCLCHECK(ncclNetTest(comm, sub->requests[buffSlot], &done, NULL)); NCCLCHECK(proxyState->ncclNet->test(sub->requests[buffSlot], &done, NULL));
if (done) { if (done) {
TRACE(NCCL_NET, "sendProxy [%ld/%d] request %p done", sub->done, buffSlot, sub->requests[buffSlot]); TRACE(NCCL_NET, "sendProxy [%ld/%d] request %p done", sub->done, buffSlot, sub->requests[buffSlot]);
sub->done += args->sliceSteps; sub->done += args->sliceSteps;
@ -988,7 +1031,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArgs* args) { static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args) {
if (args->state == ncclProxyOpReady) { if (args->state == ncclProxyOpReady) {
// Initialize subs and group them by same recvComm. // Initialize subs and group them by same recvComm.
void* recvComm; void* recvComm;
@ -1048,7 +1091,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
if (p == NCCL_PROTO_SIMPLE && resources->shared) { if (p == NCCL_PROTO_SIMPLE && resources->shared) {
int sharedBuffSlot = sub->posted%maxDepth; int sharedBuffSlot = sub->posted%maxDepth;
int offset; int offset;
NCCLCHECK(sharedBuffersGet(comm, sub->channelId, sharedBuffSlot*args->nsubs+s+i, &offset)); NCCLCHECK(sharedBuffersGet(proxyState, sub->channelId, sharedBuffSlot*args->nsubs+s+i, &offset));
volatile int* offsFifo = (volatile int*)resources->recvMem->offsFifo; volatile int* offsFifo = (volatile int*)resources->recvMem->offsFifo;
offsFifo[buffSlot] = offset; offsFifo[buffSlot] = offset;
ptrs[subCount] = localBuff+offset; ptrs[subCount] = localBuff+offset;
@ -1057,7 +1100,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
} }
sizes[subCount] = stepSize*args->sliceSteps; sizes[subCount] = stepSize*args->sliceSteps;
if (sub->nbytes < sizes[subCount]) sizes[subCount] = sub->nbytes; if (sub->nbytes < sizes[subCount]) sizes[subCount] = sub->nbytes;
tags[subCount] = resources->remoteRank; tags[subCount] = resources->tpRemoteRank;
mhandles[subCount] = resources->mhandles[p]; mhandles[subCount] = resources->mhandles[p];
subCount++; subCount++;
} }
@ -1066,7 +1109,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
uint64_t step = subGroup->posted; uint64_t step = subGroup->posted;
struct recvResources* resources = (struct recvResources*) (subGroup->connection->transportResources); struct recvResources* resources = (struct recvResources*) (subGroup->connection->transportResources);
void** requestPtr = subGroup->requests+(step%NCCL_STEPS); void** requestPtr = subGroup->requests+(step%NCCL_STEPS);
NCCLCHECK(ncclNetIrecv(comm, resources->netRecvComm, subCount, ptrs, sizes, tags, mhandles, requestPtr)); NCCLCHECK(proxyState->ncclNet->irecv(resources->netRecvComm, subCount, ptrs, sizes, tags, mhandles, requestPtr));
if (*requestPtr) { if (*requestPtr) {
for (int i=0; i<subGroup->groupSize; i++) { for (int i=0; i<subGroup->groupSize; i++) {
struct ncclProxySubArgs* sub = subGroup+i; struct ncclProxySubArgs* sub = subGroup+i;
@ -1088,7 +1131,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int sizes[NCCL_PROXY_MAX_SUBS]; int sizes[NCCL_PROXY_MAX_SUBS];
void* mhandles[NCCL_PROXY_MAX_SUBS]; void* mhandles[NCCL_PROXY_MAX_SUBS];
for (int i=0; i<NCCL_PROXY_MAX_SUBS; i++) sizes[i] = 0; for (int i=0; i<NCCL_PROXY_MAX_SUBS; i++) sizes[i] = 0;
NCCLCHECK(ncclNetTest(comm, subGroup->requests[step%NCCL_STEPS], &done, sizes)); NCCLCHECK(proxyState->ncclNet->test(subGroup->requests[step%NCCL_STEPS], &done, sizes));
if (done) { if (done) {
int needFlush = 0; int needFlush = 0;
int totalSize = 0; int totalSize = 0;
@ -1129,7 +1172,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
} }
} }
struct recvResources* resources = (struct recvResources*) (subGroup->connection->transportResources); struct recvResources* resources = (struct recvResources*) (subGroup->connection->transportResources);
NCCLCHECK(ncclNetIflush(comm, resources->netRecvComm, subCount, ptrs, sizes, mhandles, subGroup->requests+(step%NCCL_STEPS))); NCCLCHECK(proxyState->ncclNet->iflush(resources->netRecvComm, subCount, ptrs, sizes, mhandles, subGroup->requests+(step%NCCL_STEPS)));
} }
} }
args->idle = 0; args->idle = 0;
@ -1144,7 +1187,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
uint64_t step = subGroup->transmitted; uint64_t step = subGroup->transmitted;
int done = 1; int done = 1;
void* request = subGroup->requests[step%NCCL_STEPS]; void* request = subGroup->requests[step%NCCL_STEPS];
if (request) NCCLCHECK(ncclNetTest(comm, request, &done, NULL)); if (request) NCCLCHECK(proxyState->ncclNet->test(request, &done, NULL));
if (done) { if (done) {
for (int i=0; i<subGroup->groupSize; i++) { for (int i=0; i<subGroup->groupSize; i++) {
struct ncclProxySubArgs* sub = subGroup + i; struct ncclProxySubArgs* sub = subGroup + i;
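
The progress and free hunks above replace the comm-based net wrappers (ncclNetIsend, ncclNetTest, ncclNetIflush, ncclNetCloseSend/Recv) with direct calls through the plugin's function table held on the proxy state. A minimal sketch of that call style; the struct and field names here are assumptions for illustration only, not NCCL's real types:

```cpp
// Illustrative only: a function-table call through proxy-owned state,
// so no ncclComm is needed at the call site.
struct NetTable {
  int (*isend)(void* sendComm, void* data, int size, int tag, void* mhandle, void** request);
  int (*test)(void* request, int* done, int* sizes);
};

struct ProxyState { NetTable* net; };

int postSend(ProxyState* proxy, void* sendComm, void* buf, int size, int tag,
             void* mhandle, void** request) {
  return proxy->net->isend(sendComm, buf, size, tag, mhandle, request);
}
```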

View File

@ -99,6 +99,7 @@ static void* ncclIbAsyncThreadMain(void* args) {
} }
NCCL_PARAM(IbDisable, "IB_DISABLE", 0); NCCL_PARAM(IbDisable, "IB_DISABLE", 0);
NCCL_PARAM(IbMergeVfs, "IB_MERGE_VFS", 1);
static ncclResult_t ncclIbGetPciPath(char* devName, char** path, int* realPort) { static ncclResult_t ncclIbGetPciPath(char* devName, char** path, int* realPort) {
char devicePath[PATH_MAX]; char devicePath[PATH_MAX];
@ -110,7 +111,7 @@ static ncclResult_t ncclIbGetPciPath(char* devName, char** path, int* realPort)
// Merge multi-port NICs into the same PCI device // Merge multi-port NICs into the same PCI device
p[strlen(p)-1] = '0'; p[strlen(p)-1] = '0';
// Also merge virtual functions (VF) into the same device // Also merge virtual functions (VF) into the same device
p[strlen(p)-3] = '0'; if (ncclParamIbMergeVfs()) p[strlen(p)-3] = '0';
// And keep the real port aside (the ibv port is always 1 on recent cards) // And keep the real port aside (the ibv port is always 1 on recent cards)
*realPort = 0; *realPort = 0;
for (int d=0; d<ncclNIbDevs; d++) { for (int d=0; d<ncclNIbDevs; d++) {
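
The new NCCL_IB_MERGE_VFS knob controls whether virtual functions are folded into the same PCI device as their parent when IB devices are de-duplicated. A small, self-contained illustration of the in-place path edit (the path value is made up; the real code edits the realpath() of the IB device):

```cpp
#include <cstdio>
#include <cstring>

int main() {
  char p[] = "/sys/devices/pci0000:5a/0000:5e:01.3";   // assumed example path
  p[strlen(p)-1] = '0';                  // merge multi-port NICs: ...01.3 -> ...01.0
  bool mergeVfs = true;                  // stands in for NCCL_IB_MERGE_VFS
  if (mergeVfs) p[strlen(p)-3] = '0';    // also merge VFs:        ...01.0 -> ...00.0
  printf("%s\n", p);                     // /sys/devices/pci0000:5a/0000:5e:00.0
  return 0;
}
```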
@ -381,16 +382,25 @@ struct ncclIbHandle {
struct ncclIbCommStage stage; // Used by the other side when connecting struct ncclIbCommStage stage; // Used by the other side when connecting
}; };
// Retain local and remote RoCE addresses for error logging
struct ncclIbGidInfo {
uint8_t link_layer;
union ibv_gid localGid;
union ibv_gid remoteGid;
};
#define NCCL_NET_IB_REQ_UNUSED 0 #define NCCL_NET_IB_REQ_UNUSED 0
#define NCCL_NET_IB_REQ_SEND 1 #define NCCL_NET_IB_REQ_SEND 1
#define NCCL_NET_IB_REQ_RECV 2 #define NCCL_NET_IB_REQ_RECV 2
#define NCCL_NET_IB_REQ_FLUSH 3 #define NCCL_NET_IB_REQ_FLUSH 3
const char* reqTypeStr[] = { "Unused", "Send", "Recv", "Flush" };
struct ncclIbRequest { struct ncclIbRequest {
struct ncclIbVerbs* verbs; struct ncclIbVerbs* verbs;
int type; int type;
int events; int events;
struct ncclSocket* sock; struct ncclSocket* sock;
struct ncclIbGidInfo* gidInfo;
int nreqs; int nreqs;
union { union {
struct { struct {
@ -440,8 +450,10 @@ struct ncclIbSendComm {
int ready; int ready;
struct ibv_qp* qps[NCCL_IB_MAX_QPS]; struct ibv_qp* qps[NCCL_IB_MAX_QPS];
int nqps; int nqps;
int qpIndex;
struct ibv_mr* fifoMr; struct ibv_mr* fifoMr;
int ar; int ar;
struct ncclIbGidInfo gidInfo;
}; };
// The SendFifo needs to be 32-byte aligned and each element needs // The SendFifo needs to be 32-byte aligned and each element needs
// to be a 32-byte multiple, so that an entry does not get split and // to be a 32-byte multiple, so that an entry does not get split and
@ -474,7 +486,9 @@ struct ncclIbRecvComm {
int ready; int ready;
struct ibv_qp* qps[NCCL_IB_MAX_QPS]; struct ibv_qp* qps[NCCL_IB_MAX_QPS];
int nqps; int nqps;
int qpIndex;
struct ncclIbGpuFlush gpuFlush; struct ncclIbGpuFlush gpuFlush;
struct ncclIbGidInfo gidInfo;
}; };
static_assert((offsetof(struct ncclIbRecvComm, remFifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned"); static_assert((offsetof(struct ncclIbRecvComm, remFifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned");
@ -648,15 +662,14 @@ ib_connect_check:
// RoCE support // RoCE support
qpInfo.lid = portAttr.lid; qpInfo.lid = portAttr.lid;
qpInfo.link_layer = portAttr.link_layer; qpInfo.link_layer = comm->gidInfo.link_layer = portAttr.link_layer;
if (qpInfo.link_layer == IBV_LINK_LAYER_INFINIBAND) { // IB if (qpInfo.link_layer == IBV_LINK_LAYER_INFINIBAND) { // IB
for (int q=0; q<comm->nqps; q++) for (int q=0; q<comm->nqps; q++)
INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d LID %d", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, qpInfo.lid); INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d LID %d", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, qpInfo.lid);
} else { // RoCE } else { // RoCE
union ibv_gid gid; NCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, ncclParamIbGidIndex(), &comm->gidInfo.localGid));
NCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, ncclParamIbGidIndex(), &gid)); qpInfo.spn = comm->gidInfo.localGid.global.subnet_prefix;
qpInfo.spn = gid.global.subnet_prefix; qpInfo.iid = comm->gidInfo.localGid.global.interface_id;
qpInfo.iid = gid.global.interface_id;
for (int q=0; q<comm->nqps; q++) for (int q=0; q<comm->nqps; q++)
INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d GID %ld (%lX/%lX)", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, ncclParamIbGidIndex(), qpInfo.spn, qpInfo.iid); INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d GID %ld (%lX/%lX)", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, ncclParamIbGidIndex(), qpInfo.spn, qpInfo.iid);
} }
@ -682,6 +695,8 @@ ib_connect:
memcpy(&remQpInfo, stage->buffer, sizeof(ncclIbQpInfo)); memcpy(&remQpInfo, stage->buffer, sizeof(ncclIbQpInfo));
comm->gidInfo.remoteGid.global.subnet_prefix = remQpInfo.spn;
comm->gidInfo.remoteGid.global.interface_id = remQpInfo.iid;
for (int q=0; q<comm->nqps; q++) { for (int q=0; q<comm->nqps; q++) {
struct ibv_qp* qp = comm->qps[q]; struct ibv_qp* qp = comm->qps[q];
NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo)); NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo));
@ -743,6 +758,9 @@ ib_recv:
/* copy back the received info */ /* copy back the received info */
memcpy(&remQpInfo, stage->buffer, sizeof(struct ncclIbQpInfo)); memcpy(&remQpInfo, stage->buffer, sizeof(struct ncclIbQpInfo));
rComm->gidInfo.remoteGid.global.subnet_prefix = remQpInfo.spn;
rComm->gidInfo.remoteGid.global.interface_id = remQpInfo.iid;
// IB setup // IB setup
struct ibv_context* ctx; struct ibv_context* ctx;
uint8_t ib_port; uint8_t ib_port;
@ -750,8 +768,7 @@ ib_recv:
ib_port = ncclIbDevs[lComm->dev].port; ib_port = ncclIbDevs[lComm->dev].port;
struct ibv_port_attr portAttr; struct ibv_port_attr portAttr;
NCCLCHECK(wrap_ibv_query_port(ctx, ib_port, &portAttr)); NCCLCHECK(wrap_ibv_query_port(ctx, ib_port, &portAttr));
union ibv_gid gid; NCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, ncclParamIbGidIndex(), &rComm->gidInfo.localGid));
NCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, ncclParamIbGidIndex(), &gid));
// QP Creation // QP Creation
NCCLCHECK(ncclIbInitVerbs(lComm->dev, ctx, &rComm->verbs)); NCCLCHECK(ncclIbInitVerbs(lComm->dev, ctx, &rComm->verbs));
@ -789,8 +806,8 @@ ib_recv:
localQpInfo.lid=portAttr.lid; localQpInfo.lid=portAttr.lid;
localQpInfo.link_layer=portAttr.link_layer; localQpInfo.link_layer=portAttr.link_layer;
localQpInfo.ib_port=ib_port; localQpInfo.ib_port=ib_port;
localQpInfo.spn=gid.global.subnet_prefix; localQpInfo.spn=rComm->gidInfo.localGid.global.subnet_prefix;
localQpInfo.iid=gid.global.interface_id; localQpInfo.iid=rComm->gidInfo.localGid.global.interface_id;
localQpInfo.mtu=portAttr.active_mtu; localQpInfo.mtu=portAttr.active_mtu;
NCCLCHECK(ncclIbRtrQp(rComm->gpuFlush.qp, rComm->gpuFlush.qp->qp_num, &localQpInfo)); NCCLCHECK(ncclIbRtrQp(rComm->gpuFlush.qp, rComm->gpuFlush.qp->qp_num, &localQpInfo));
NCCLCHECK(ncclIbRtsQp(rComm->gpuFlush.qp)); NCCLCHECK(ncclIbRtsQp(rComm->gpuFlush.qp));
@ -799,11 +816,11 @@ ib_recv:
// Fill Handle // Fill Handle
struct ncclIbQpInfo qpInfo; struct ncclIbQpInfo qpInfo;
qpInfo.lid=portAttr.lid; qpInfo.lid=portAttr.lid;
qpInfo.link_layer=portAttr.link_layer; qpInfo.link_layer= rComm->gidInfo.link_layer = portAttr.link_layer;
qpInfo.ib_port=ib_port; qpInfo.ib_port=ib_port;
for (int q=0; q<rComm->nqps; q++) qpInfo.qpn[q]=rComm->qps[q]->qp_num; for (int q=0; q<rComm->nqps; q++) qpInfo.qpn[q]=rComm->qps[q]->qp_num;
qpInfo.spn=gid.global.subnet_prefix; qpInfo.spn=rComm->gidInfo.localGid.global.subnet_prefix;
qpInfo.iid=gid.global.interface_id; qpInfo.iid=rComm->gidInfo.localGid.global.interface_id;
qpInfo.mtu=remQpInfo.mtu; qpInfo.mtu=remQpInfo.mtu;
stage->state = ncclIbCommStateSend; stage->state = ncclIbCommStateSend;
@ -841,6 +858,7 @@ ncclResult_t ncclIbGetRequest(struct ncclIbVerbs* verbs, struct ncclIbRequest**
r->verbs = verbs; r->verbs = verbs;
r->events = 1; r->events = 1;
r->sock = NULL; r->sock = NULL;
r->gidInfo = NULL;
*req = r; *req = r;
return ncclSuccess; return ncclSuccess;
} }
@ -945,6 +963,8 @@ returning:
return res; return res;
} }
NCCL_PARAM(IbSplitDataOnQps, "IB_SPLIT_DATA_ON_QPS", 1);
ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
struct ncclIbRequest** reqs = comm->fifoReqs[slot]; struct ncclIbRequest** reqs = comm->fifoReqs[slot];
volatile struct ncclIbSendFifo* slots = comm->fifo[slot]; volatile struct ncclIbSendFifo* slots = comm->fifo[slot];
@ -1000,9 +1020,10 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
// Multi-QP: make sure IB writes are multiples of 128B so that LL and LL128 protocols still work // Multi-QP: make sure IB writes are multiples of 128B so that LL and LL128 protocols still work
const int align = 128; const int align = 128;
for (int q=0; q<comm->nqps; q++) { const int nqps = ncclParamIbSplitDataOnQps() ? comm->nqps : 1;
for (int q=0; q<nqps; q++) {
for (int r=0; r<nreqs; r++) { for (int r=0; r<nreqs; r++) {
int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, comm->nqps), align) * align; int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, nqps), align) * align;
int length = std::min(reqs[r]->send.size-reqs[r]->send.offset, chunkSize); int length = std::min(reqs[r]->send.size-reqs[r]->send.offset, chunkSize);
if (length <= 0) { if (length <= 0) {
comm->wrs[r].sg_list = NULL; comm->wrs[r].sg_list = NULL;
@ -1014,10 +1035,11 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
} }
} }
struct ibv_send_wr* bad_wr; struct ibv_send_wr* bad_wr;
NCCLCHECK(wrap_ibv_post_send(comm->qps[q], comm->wrs, &bad_wr)); NCCLCHECK(wrap_ibv_post_send(comm->qps[comm->qpIndex], comm->wrs, &bad_wr));
comm->qpIndex = (comm->qpIndex+1)%comm->nqps;
for (int r=0; r<nreqs; r++) { for (int r=0; r<nreqs; r++) {
int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, comm->nqps), align) * align; int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, nqps), align) * align;
reqs[r]->send.offset += chunkSize; reqs[r]->send.offset += chunkSize;
comm->sges[r].addr += chunkSize; comm->sges[r].addr += chunkSize;
comm->wrs[r].wr.rdma.remote_addr += chunkSize; comm->wrs[r].wr.rdma.remote_addr += chunkSize;
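
With the new NCCL_IB_SPLIT_DATA_ON_QPS parameter at its default (1), a message is still striped across all QPs in 128-byte-aligned chunks and the request waits for one completion event per QP; setting it to 0 posts the whole message on a single QP and rotates qpIndex round-robin between messages. A sketch of the chunking arithmetic (sizes are made up):

```cpp
#include <algorithm>
#include <cstdio>
#define DIVUP(x, y) (((x) + (y) - 1) / (y))

int main() {
  const int align = 128, nqpsTotal = 4;
  const int size = 1000000;                       // example message size
  bool splitDataOnQps = true;                     // NCCL_IB_SPLIT_DATA_ON_QPS
  int nqps = splitDataOnQps ? nqpsTotal : 1;      // QPs used for this message
  int chunkSize = DIVUP(DIVUP(size, nqps), align) * align;
  int qpIndex = 0;
  for (int q = 0, offset = 0; q < nqps; q++, offset += chunkSize) {
    int length = std::min(size - offset, chunkSize);
    printf("post %d bytes on qp[%d]\n", length, qpIndex);
    qpIndex = (qpIndex + 1) % nqpsTotal;          // rotate across all QPs
  }
  return 0;
}
```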
@ -1077,7 +1099,8 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh
req->send.data = data; req->send.data = data;
req->send.lkey = mr->lkey; req->send.lkey = mr->lkey;
req->send.offset = 0; req->send.offset = 0;
req->events = comm->nqps; req->events = ncclParamIbSplitDataOnQps() ? comm->nqps : 1;
if (comm->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) req->gidInfo = &comm->gidInfo;
*request = reqs[r] = req; *request = reqs[r] = req;
// If this is a multi-recv, send only when all requests have matched. // If this is a multi-recv, send only when all requests have matched.
@ -1171,6 +1194,7 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* ta
req->type = NCCL_NET_IB_REQ_RECV; req->type = NCCL_NET_IB_REQ_RECV;
req->sock = &comm->sock; req->sock = &comm->sock;
req->nreqs = n; req->nreqs = n;
if (comm->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) req->gidInfo = &comm->gidInfo;
for (int i=0; i<n; i++) req->recv.sizes[i] = 0; for (int i=0; i<n; i++) req->recv.sizes[i] = 0;
struct ibv_recv_wr wr; struct ibv_recv_wr wr;
@ -1181,13 +1205,15 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* ta
wr.num_sge = 0; wr.num_sge = 0;
TIME_START(1); TIME_START(1);
for (int q=0; q<comm->nqps; q++) { const int nqps = ncclParamIbSplitDataOnQps() ? comm->nqps : 1;
struct ibv_qp* qp = comm->qps[q]; for (int q=0; q<nqps; q++) {
struct ibv_qp* qp = comm->qps[comm->qpIndex];
struct ibv_recv_wr* bad_wr; struct ibv_recv_wr* bad_wr;
NCCLCHECK(wrap_ibv_post_recv(qp, &wr, &bad_wr)); NCCLCHECK(wrap_ibv_post_recv(qp, &wr, &bad_wr));
comm->qpIndex = (comm->qpIndex+1)%comm->nqps;
} }
TIME_STOP(1); TIME_STOP(1);
req->events = comm->nqps; req->events = nqps;
*request = req; *request = req;
@ -1258,8 +1284,16 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
char line[SOCKET_NAME_MAXLEN+1]; char line[SOCKET_NAME_MAXLEN+1];
union ncclSocketAddress addr; union ncclSocketAddress addr;
ncclSocketGetAddr(r->sock, &addr); ncclSocketGetAddr(r->sock, &addr);
WARN("NET/IB : Got completion from peer %s with error %d, opcode %d, len %d, vendor err %d", char localGidString[INET6_ADDRSTRLEN] = "";
ncclSocketToString(&addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err); char remoteGidString[INET6_ADDRSTRLEN] = "";
const char* localGidStr = NULL, *remoteGidStr = NULL;
if (r->gidInfo) {
localGidStr = inet_ntop(AF_INET6, &r->gidInfo->localGid, localGidString, sizeof(localGidString));
remoteGidStr = inet_ntop(AF_INET6, &r->gidInfo->remoteGid, remoteGidString, sizeof(remoteGidString));
}
WARN("NET/IB : Got completion from peer %s with error %d, opcode %d, len %d, vendor err %d (%s)%s%s%s%s",
ncclSocketToString(&addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err, reqTypeStr[r->type],
localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGid ":"", remoteGidString);
return ncclRemoteError; return ncclRemoteError;
} }
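
Since an ibv_gid is 16 bytes, the same width as an IPv6 address, inet_ntop(AF_INET6, ...) yields a readable colon-hex form of the local and remote RoCE GIDs in the completion-error message. A minimal standalone example (the GID bytes are made up):

```cpp
#include <arpa/inet.h>
#include <cstdio>

int main() {
  unsigned char gid[16] = {0xfe,0x80,0,0,0,0,0,0,0x02,0x1c,0x34,0xff,0xfe,0xa1,0xb2,0xc3};
  char str[INET6_ADDRSTRLEN];
  if (inet_ntop(AF_INET6, gid, str, sizeof(str)) != nullptr)
    printf("GID %s\n", str);   // e.g. fe80::21c:34ff:fea1:b2c3
  return 0;
}
```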

View File

@ -43,22 +43,7 @@ struct ncclTransport nvlsTransport = {
{ NULL, NULL, nvlsRecvFree, NULL, NULL, NULL, NULL, NULL } { NULL, NULL, nvlsRecvFree, NULL, NULL, NULL, NULL, NULL }
}; };
#define NVLS_HANDLE_SIZE 64 ncclResult_t nvlsGetProperties(struct ncclComm *comm, struct ncclNvlsSharedRes* resources, int dev, int nranks, size_t size) {
struct nvlsResources {
CUmulticastObjectProp properties;
CUmemAccessDesc accessDesc;
int dev;
size_t size;
size_t granularity;
CUmemGenericAllocationHandle mcHandle; // Multicast handle for NVLS buffer
char* mcBuff; // Multicast NVLS buffer address
CUmemGenericAllocationHandle ucHandle; // Unicast Handle for NVLS buffer
char* ucBuff; // Unicast NVLS buffer address
};
ncclResult_t nvlsGetProperties(struct ncclComm *comm, struct nvlsResources* resources, int dev, int nranks, size_t size) {
CUmulticastObjectProp* prop = &resources->properties; CUmulticastObjectProp* prop = &resources->properties;
memset(prop, 0, sizeof(*prop)); memset(prop, 0, sizeof(*prop));
prop->size = size; prop->size = size;
@ -81,7 +66,7 @@ ncclResult_t nvlsGetProperties(struct ncclComm *comm, struct nvlsResources* reso
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t nvlsGroupCreate(struct ncclComm *comm, struct nvlsResources* resources, int rank, unsigned int nranks, char* shareableHandle) { ncclResult_t nvlsGroupCreate(struct ncclComm *comm, struct ncclNvlsSharedRes* resources, int rank, unsigned int nranks, char* shareableHandle) {
size_t size = resources->size; size_t size = resources->size;
// Create a Multicast group // Create a Multicast group
@ -103,24 +88,13 @@ ncclResult_t nvlsGroupCreate(struct ncclComm *comm, struct nvlsResources* resour
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t nvlsGroupAddDevice(struct ncclComm *comm, struct nvlsResources* resources) { ncclResult_t nvlsGroupAddDevice(struct ncclComm *comm, struct ncclNvlsSharedRes* resources) {
INFO(NCCL_NVLS, "NVLS group %llx adding dev %d", resources->mcHandle, resources->dev); INFO(NCCL_NVLS, "NVLS group %llx adding dev %d", resources->mcHandle, resources->dev);
CUCHECK(cuMulticastAddDevice(resources->mcHandle, resources->dev)); CUCHECK(cuMulticastAddDevice(resources->mcHandle, resources->dev));
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t nvlsGroupUnbind(struct ncclComm *comm, struct nvlsResources* resources) { ncclResult_t nvlsGroupConnect(struct ncclComm *comm, struct ncclNvlsSharedRes* resources, int rank, char* shareableHandle) {
int dev = resources->dev;
size_t size = resources->size;
INFO(NCCL_NVLS, "NVLS Unbind MC handle %llx size %zi dev %d", resources->mcHandle, size, dev);
// Unbind physical memory from group for the given device
CUCHECK(cuMulticastUnbind(resources->mcHandle, dev, 0/*mcOffset*/, size));
return ncclSuccess;
}
ncclResult_t nvlsGroupConnect(struct ncclComm *comm, struct nvlsResources* resources, int rank, char* shareableHandle) {
CUmemAllocationHandleType type = NVLS_CU_MEM_HANDLE_TYPE; CUmemAllocationHandleType type = NVLS_CU_MEM_HANDLE_TYPE;
INFO(NCCL_NVLS, "NVLS importing shareableHandle %p from rank %d", shareableHandle, rank); INFO(NCCL_NVLS, "NVLS importing shareableHandle %p from rank %d", shareableHandle, rank);
@ -131,9 +105,11 @@ ncclResult_t nvlsGroupConnect(struct ncclComm *comm, struct nvlsResources* resou
int fd = *(int *)shareableHandle; int fd = *(int *)shareableHandle;
TRACE(NCCL_NVLS, "NVLS rank %d Importing shareable handle from rank %d fd %d", comm->localRank, rank, fd); TRACE(NCCL_NVLS, "NVLS rank %d Importing shareable handle from rank %d fd %d", comm->localRank, rank, fd);
struct ncclProxyConnector proxyConn; struct ncclProxyConnector proxyConn;
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, rank, &proxyConn)); int tpProxyRank = comm->topParentRanks[rank];
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, tpProxyRank, &proxyConn));
TRACE(NCCL_NVLS, "NVLS rank %d request conversion of fd %d from rank %d", comm->localRank, fd, rank); TRACE(NCCL_NVLS, "NVLS rank %d request conversion of fd %d from rank %d", comm->localRank, fd, rank);
NCCLCHECK(ncclProxyCallBlocking(&proxyConn, ncclProxyMsgConvertFd, shareableHandle, sizeof(int), &fd, sizeof(int))); NCCLCHECK(ncclProxyClientConvertFdBlocking(comm, &proxyConn, fd, (int *)shareableHandle));
fd = *(int *)shareableHandle;
TRACE(NCCL_NVLS, "NVLS rank %d received converted fd %d from rank %d", comm->localRank, fd, rank); TRACE(NCCL_NVLS, "NVLS rank %d received converted fd %d from rank %d", comm->localRank, fd, rank);
CUCHECK(cuMemImportFromShareableHandle(&resources->mcHandle, (void *)(uintptr_t)fd, type)); CUCHECK(cuMemImportFromShareableHandle(&resources->mcHandle, (void *)(uintptr_t)fd, type));
} else { } else {
@ -146,7 +122,20 @@ ncclResult_t nvlsGroupConnect(struct ncclComm *comm, struct nvlsResources* resou
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t nvlsGroupBindMem(struct ncclComm *comm, struct nvlsResources* resources) { ncclResult_t nvlsGroupDisconnect(struct ncclComm *comm, struct ncclNvlsSharedRes* resources) {
CUmemAllocationHandleType type = NVLS_CU_MEM_HANDLE_TYPE;
// Import and map the remote memory descriptor to the local GPU
if (type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR) {
// cuMem UDS support
int fd = *(int *)resources->shareableHandle;
(void) close(fd);
}
return ncclSuccess;
}
ncclResult_t nvlsGroupBindMem(struct ncclComm *comm, struct ncclNvlsSharedRes* resources) {
size_t size = resources->size; size_t size = resources->size;
size_t granularity; size_t granularity;
CUdeviceptr ptr = 0; CUdeviceptr ptr = 0;
@ -178,7 +167,21 @@ ncclResult_t nvlsGroupBindMem(struct ncclComm *comm, struct nvlsResources* resou
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t nvlsGroupMapMem(struct ncclComm *comm, struct nvlsResources* resources) { ncclResult_t nvlsGroupUnbind(struct ncclComm *comm, struct ncclNvlsSharedRes* resources) {
int dev = resources->dev;
size_t size = resources->size;
INFO(NCCL_NVLS, "NVLS Unbind MC handle %llx size %zi dev %d", resources->mcHandle, size, dev);
// Unbind physical memory from group for the given device
CUCHECK(cuMulticastUnbind(resources->mcHandle, dev, 0/*mcOffset*/, size));
// Release the MC group resources
NCCLCHECK(nvlsGroupDisconnect(comm, resources));
return ncclSuccess;
}
ncclResult_t nvlsGroupMapMem(struct ncclComm *comm, struct ncclNvlsSharedRes* resources) {
size_t size = resources->size; size_t size = resources->size;
CUdeviceptr ptr = 0; CUdeviceptr ptr = 0;
@ -196,7 +199,7 @@ ncclResult_t nvlsGroupMapMem(struct ncclComm *comm, struct nvlsResources* resour
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t nvlsGroupUnmapMem(struct ncclComm *comm, struct nvlsResources* resources) { ncclResult_t nvlsGroupUnmapMem(struct ncclComm *comm, struct ncclNvlsSharedRes* resources) {
size_t size; size_t size;
CUdeviceptr ptr; CUdeviceptr ptr;
INFO(NCCL_NVLS, "NVLS Unmap mem UC handle 0x%llx(%p) MC handle 0x%llx(%p)", INFO(NCCL_NVLS, "NVLS Unmap mem UC handle 0x%llx(%p) MC handle 0x%llx(%p)",
@ -224,135 +227,173 @@ ncclResult_t nvlsGroupUnmapMem(struct ncclComm *comm, struct nvlsResources* reso
#define NVLS_MEM_ALIGN_SIZE (1 << 21) #define NVLS_MEM_ALIGN_SIZE (1 << 21)
NCCL_PARAM(NvlsEnable, "NVLS_ENABLE", 2);
NCCL_PARAM(NvlsChannels, "NVLS_NCHANNELS", 16); NCCL_PARAM(NvlsChannels, "NVLS_NCHANNELS", 16);
NCCL_PARAM(NvlsEnable, "NVLS_ENABLE", 1); ncclResult_t ncclNvlsInit(struct ncclComm* comm) {
comm->nvlsSupport = 0;
comm->nvlsChannels = 0;
int gpuCount;
NCCLCHECK(ncclTopoGetGpuCount(comm->topo, &gpuCount));
if (!ncclParamNvlsEnable() || gpuCount <= 2) return ncclSuccess;
ncclResult_t ncclNvlsSetup(struct ncclComm* comm) {
if (!ncclParamNvlsEnable() || comm->localRanks <= 1 || comm->nNodes>1) return ncclSuccess;
CUdevice dev; CUdevice dev;
int driverVersion; int driverVersion;
if (CUPFN(cuDeviceGet) == NULL) return ncclSuccess; if (CUPFN(cuDeviceGet) == NULL) return ncclSuccess;
CUCHECK(cuDeviceGet(&dev, comm->cudaDev)); CUCHECK(cuCtxGetDevice(&dev));
CUDACHECK(cudaDriverGetVersion(&driverVersion)); CUDACHECK(cudaDriverGetVersion(&driverVersion));
comm->nvlsSupport = 0; if (ncclParamNvlsEnable() == 2) {
// NVLS Multicast support requires CUDA12.1 UMD + KMD // NVLS Multicast support requires CUDA12.1 UMD + KMD
if (CUPFN(cuMulticastCreate) != NULL && driverVersion >= 12010) { if (CUPFN(cuMulticastCreate) != NULL /*&& driverVersion >= 12010 */) {
CUCHECK(cuDeviceGetAttribute(&comm->nvlsSupport, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev)); CUCHECK(cuDeviceGetAttribute(&comm->nvlsSupport, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev));
} }
INFO(NCCL_INIT, "NVLS multicast support is %savailable on dev %d", comm->nvlsSupport ? "" : "not ", dev);
if (comm->nvlsSupport == 0) return ncclSuccess;
int nChannels = comm->nvlsChannels = std::max(comm->minCTAs, std::min(comm->maxCTAs, (int)ncclParamNvlsChannels()));
int rank = comm->localRank, nranks = comm->localRanks;
for (int c=0; c<nChannels; c++) {
NCCLCHECK(initChannel(comm, c));
}
ncclResult_t res = ncclSuccess;
struct nvlsResources* resources;
NCCLCHECK(ncclCalloc(&resources, 1));
comm->nvlsResources = resources;
size_t buffSize = comm->buffSizes[NCCL_PROTO_SIMPLE];
size_t memSize = NVLS_MEM_ALIGN_SIZE;
size_t nvlsPerRankSize = nChannels*2*(buffSize+memSize);
size_t nvlsTotalSize = nvlsPerRankSize*nranks;
INFO(NCCL_INIT|NCCL_NVLS, "NVLS comm %p rank %d nranks %d buffSize %zi memSize %zi nvlsPerRankSize %zi nvlsTotalSize %zi",
comm, rank, nranks, buffSize, memSize, nvlsPerRankSize, nvlsTotalSize);
char* nvlsShareableHandle = NULL;
NCCLCHECKGOTO(ncclCalloc(&nvlsShareableHandle, NVLS_HANDLE_SIZE), res, cleanup);
NCCLCHECKGOTO(nvlsGetProperties(comm, resources, dev, nranks, nvlsTotalSize), res, cleanup);
if (rank == 0) {
NCCLCHECKGOTO(nvlsGroupCreate(comm, resources, rank, nranks, nvlsShareableHandle), res, cleanup);
NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, rank, nranks, 0, nvlsShareableHandle, NVLS_HANDLE_SIZE), res, cleanup);
} else { } else {
NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, rank, nranks, 0, nvlsShareableHandle, NVLS_HANDLE_SIZE), res, cleanup); comm->nvlsSupport = 1;
NCCLCHECKGOTO(nvlsGroupConnect(comm, resources, 0, nvlsShareableHandle), res, cleanup);
} }
NCCLCHECKGOTO(nvlsGroupAddDevice(comm, resources), res, cleanup); INFO(NCCL_INIT, "NVLS multicast support is %savailable on dev %d", comm->nvlsSupport ? "" : "not ", dev);
NCCLCHECKGOTO(nvlsGroupBindMem(comm, resources), res, cleanup); if (comm->nvlsSupport == 1) comm->nvlsChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, (int)ncclParamNvlsChannels()));
// Local intra-node barrier to ensure everyone has bound their memory to the group return ncclSuccess;
NCCLCHECKGOTO(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0]), res, cleanup); }
NCCLCHECKGOTO(nvlsGroupMapMem(comm, resources), res, cleanup);
for (int c=0; c<nChannels; c++) { ncclResult_t ncclNvlsSetup(struct ncclComm* comm, struct ncclComm* parent) {
struct ncclChannel* channel = comm->channels+c; if (comm->nvlsSupport == 0 || comm->nvlsChannels == 0) return ncclSuccess;
channel->nvls.nHeads = nranks;
for (int i=0; i<NCCL_MAX_NVLS_ARITY; i++) channel->nvls.up[i] = -1;
channel->nvls.down = comm->nRanks+1+comm->localRank;
channel->nvls.out = -1; // Network not yet implemented.
channel->nvls.headRank = comm->localRank; // Network not yet implemented.
}
for (int r=0; r<nranks; r++) { int nHeads = comm->channels[0].nvls.nHeads;
int nvlsPeer = comm->nRanks+1+r; int headRank = comm->channels[0].nvls.headRank;
for (int c=0; c<nChannels; c++) {
struct ncclChannel* channel = comm->channels+c;
channel->nvls.up[r] = nvlsPeer;
char* mem = NULL; CUdevice dev;
struct ncclChannelPeer* peer = channel->peers+nvlsPeer; CUCHECK(cuCtxGetDevice(&dev));
// Reduce UC -> MC ncclResult_t res = ncclSuccess;
mem = resources->ucBuff + (r*2*nChannels+c)*(buffSize+memSize); bool nvlsShare = true;
peer->send[0].transportComm = &nvlsTransport.send; if (parent && parent->nvlsSupport && parent->config.splitShare && parent->localRanks == comm->localRanks)
peer->send[0].conn.buffs[NCCL_PROTO_SIMPLE] = mem; nvlsShare = true;
peer->send[0].conn.head = (uint64_t*)(mem+buffSize); else
peer->send[0].conn.tail = (uint64_t*)(mem+buffSize+memSize/2); nvlsShare = false;
mem = resources->mcBuff + (r*2*nChannels+c)*(buffSize+memSize);
peer->recv[1].transportComm = &nvlsTransport.recv;
peer->recv[1].conn.buffs[NCCL_PROTO_SIMPLE] = mem;
peer->recv[1].conn.head = (uint64_t*)(mem+buffSize);
peer->recv[1].conn.tail = (uint64_t*)(mem+buffSize+memSize/2);
peer->recv[1].conn.flags |= NCCL_NVLS_MIN_POLL;
// Broadcast MC -> UC if (nvlsShare) {
mem = resources->ucBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize); /* reuse NVLS resources */
peer->recv[0].transportComm = &nvlsTransport.recv; comm->nvlsChannels = std::min(comm->nvlsChannels, parent->nvlsResources->nChannels);
peer->recv[0].conn.buffs[NCCL_PROTO_SIMPLE] = mem; for (int c = 0; c < comm->nvlsChannels; c++) {
peer->recv[0].conn.head = (uint64_t*)(mem+buffSize); NCCLCHECKGOTO(initNvlsChannel(comm, c, parent, true), res, cleanup);
peer->recv[0].conn.tail = (uint64_t*)(mem+buffSize+memSize/2); }
mem = resources->mcBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize);
peer->send[1].transportComm = &nvlsTransport.send;
peer->send[1].conn.buffs[NCCL_PROTO_SIMPLE] = mem;
peer->send[1].conn.head = (uint64_t*)(mem+buffSize);
peer->send[1].conn.tail = (uint64_t*)(mem+buffSize+memSize/2);
peer->send[1].conn.flags |= NCCL_NVLS_MIN_POLL;
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].send[0], &peer->send[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); comm->nvlsResources = parent->nvlsResources;
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].recv[0], &peer->recv[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); ncclAtomicRefCountIncrement(&parent->nvlsResources->refCount);
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].send[1], &peer->send[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); } else {
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].recv[1], &peer->recv[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); int nChannels;
ncclResult_t res = ncclSuccess;
struct ncclNvlsSharedRes* resources;
/*INFO(NCCL_INIT|NCCL_NVLS, "Peer %d Channel %d MC buff %p/%p UC Buff %p/%p", NCCLCHECK(ncclCalloc(&resources, 1));
nvlsPeer, c, comm->nvlsResources = resources;
resources->mcBuff + (r*2*nChannels+c)*(buffSize+memSize), resources->refCount = 1;
resources->mcBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize),
resources->ucBuff + (r*2*nChannels+c)*(buffSize+memSize), if (parent && parent->config.splitShare) {
resources->ucBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize));*/ /* ranks on other nodes might share the NVLS resources, we need to cap nvlsChannels
* to make sure nvlsChannels match for each rank. */
comm->nvlsChannels = std::min(comm->nvlsChannels, parent->nvlsResources->nChannels);
}
nChannels = resources->nChannels = comm->nvlsChannels;
for (int c = 0; c < nChannels; c++) {
NCCLCHECK(initNvlsChannel(comm, c, parent, false));
}
size_t buffSize = comm->buffSizes[NCCL_PROTO_SIMPLE];
size_t memSize = NVLS_MEM_ALIGN_SIZE;
size_t nvlsPerRankSize = nChannels * 2 * (buffSize + memSize);
size_t nvlsTotalSize = nvlsPerRankSize * nHeads;
INFO(NCCL_INIT | NCCL_NVLS, "NVLS comm %p headRank %d nHeads %d buffSize %zi memSize %zi nvlsPerRankSize %zi nvlsTotalSize %zi",
comm, headRank, nHeads, buffSize, memSize, nvlsPerRankSize, nvlsTotalSize);
char* shareableHandle = resources->shareableHandle;
NCCLCHECKGOTO(nvlsGetProperties(comm, resources, dev, comm->localRanks, nvlsTotalSize), res, cleanup);
if (comm->localRank == 0) {
NCCLCHECKGOTO(nvlsGroupCreate(comm, resources, comm->localRank, comm->localRanks, shareableHandle), res, cleanup);
NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, shareableHandle, NVLS_HANDLE_SIZE), res, cleanup);
} else {
NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, shareableHandle, NVLS_HANDLE_SIZE), res, cleanup);
NCCLCHECKGOTO(nvlsGroupConnect(comm, resources, comm->localRankToRank[0], shareableHandle), res, cleanup);
}
NCCLCHECKGOTO(nvlsGroupAddDevice(comm, resources), res, cleanup);
NCCLCHECKGOTO(nvlsGroupBindMem(comm, resources), res, cleanup);
// Local intra-node barrier to ensure everyone has bound their memory to the group
NCCLCHECKGOTO(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0]), res, cleanup);
NCCLCHECKGOTO(nvlsGroupMapMem(comm, resources), res, cleanup);
for (int h = 0; h < nHeads; h++) {
int nvlsPeer = comm->nRanks + 1 + h;
for (int c = 0; c < nChannels; c++) {
struct ncclChannel* channel = comm->channels + c;
char* mem = NULL;
struct ncclChannelPeer* peer = channel->peers[nvlsPeer];
// Reduce UC -> MC
mem = resources->ucBuff + (h * 2 * nChannels + c) * (buffSize + memSize);
peer->send[1].transportComm = &nvlsTransport.send;
peer->send[1].conn.buffs[NCCL_PROTO_SIMPLE] = mem;
peer->send[1].conn.head = (uint64_t*)(mem + buffSize);
peer->send[1].conn.tail = (uint64_t*)(mem + buffSize + memSize / 2);
mem = resources->mcBuff + (h * 2 * nChannels + c) * (buffSize + memSize);
peer->recv[0].transportComm = &nvlsTransport.recv;
peer->recv[0].conn.buffs[NCCL_PROTO_SIMPLE] = mem;
peer->recv[0].conn.head = (uint64_t*)(mem + buffSize);
peer->recv[0].conn.tail = (uint64_t*)(mem + buffSize + memSize / 2);
peer->recv[0].conn.flags |= NCCL_NVLS_MIN_POLL;
// Broadcast MC -> UC
mem = resources->ucBuff + ((h * 2 + 1) * nChannels + c) * (buffSize + memSize);
peer->recv[1].transportComm = &nvlsTransport.recv;
peer->recv[1].conn.buffs[NCCL_PROTO_SIMPLE] = mem;
peer->recv[1].conn.head = (uint64_t*)(mem + buffSize);
peer->recv[1].conn.tail = (uint64_t*)(mem + buffSize + memSize / 2);
mem = resources->mcBuff + ((h * 2 + 1) * nChannels + c) * (buffSize + memSize);
peer->send[0].transportComm = &nvlsTransport.send;
peer->send[0].conn.buffs[NCCL_PROTO_SIMPLE] = mem;
peer->send[0].conn.head = (uint64_t*)(mem + buffSize);
peer->send[0].conn.tail = (uint64_t*)(mem + buffSize + memSize / 2);
peer->send[0].conn.flags |= NCCL_NVLS_MIN_POLL;
struct ncclDevChannelPeer* addr;
CUDACHECKGOTO(cudaMemcpyAsync(&addr, comm->channels[c].devPeers + nvlsPeer, sizeof(struct ncclDevChannelPeer*), cudaMemcpyDeviceToHost, comm->sharedRes->hostStream.cudaStream), res, cleanup);
CUDACHECKGOTO(cudaMemcpyAsync(&addr->send[0], &peer->send[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
CUDACHECKGOTO(cudaMemcpyAsync(&addr->recv[0], &peer->recv[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
CUDACHECKGOTO(cudaMemcpyAsync(&addr->send[1], &peer->send[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
CUDACHECKGOTO(cudaMemcpyAsync(&addr->recv[1], &peer->recv[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
/*INFO(NCCL_INIT|NCCL_NVLS, "Peer %d Channel %d MC buff %p/%p UC Buff %p/%p",
nvlsPeer, c,
resources->mcBuff + (h*2*nChannels+c)*(buffSize+memSize),
resources->mcBuff + ((h*2+1)*nChannels+c)*(buffSize+memSize),
resources->ucBuff + (h*2*nChannels+c)*(buffSize+memSize),
resources->ucBuff + ((h*2+1)*nChannels+c)*(buffSize+memSize));*/
}
} }
} }
free(nvlsShareableHandle);
return res; return res;
cleanup: cleanup:
comm->nvlsSupport = 0; comm->nvlsSupport = 0;
free(nvlsShareableHandle);
return res; return res;
} }
ncclResult_t ncclNvlsFree(struct ncclComm* comm) { ncclResult_t ncclNvlsFree(struct ncclComm* comm) {
struct nvlsResources* resources = (struct nvlsResources*)comm->nvlsResources; struct ncclNvlsSharedRes* resources = (struct ncclNvlsSharedRes*)comm->nvlsResources;
if (resources == NULL) return ncclSuccess; if (resources == NULL) return ncclSuccess;
NCCLCHECK(nvlsGroupUnbind(comm, resources));
NCCLCHECK(nvlsGroupUnmapMem(comm, resources)); if (ncclAtomicRefCountDecrement(&resources->refCount) == 0) {
free(resources); NCCLCHECK(nvlsGroupUnbind(comm, resources));
comm->nvlsResources = NULL; NCCLCHECK(nvlsGroupUnmapMem(comm, resources));
free(resources);
comm->nvlsResources = NULL;
}
return ncclSuccess; return ncclSuccess;
} }
@ -362,7 +403,12 @@ ncclResult_t ncclNvlsFree(struct ncclComm* comm) {
* Pre CUDA 12.1 stubs * Pre CUDA 12.1 stubs
*/ */
ncclResult_t ncclNvlsSetup(struct ncclComm* comm) { ncclResult_t ncclNvlsInit(struct ncclComm* comm) {
comm->nvlsChannels = 0;
return ncclSuccess;
}
ncclResult_t ncclNvlsSetup(struct ncclComm* comm, struct ncclComm* parent) {
return ncclSuccess; return ncclSuccess;
} }

View File

@ -8,17 +8,21 @@
#include "graph.h" #include "graph.h"
#include "utils.h" #include "utils.h"
#include "shm.h" #include "shm.h"
#include "p2p.h"
enum p2pType { P2P_DIRECT, P2P_INTERMEDIATE, P2P_IPC, P2P_CUMEM };
struct ncclP2pBuff { struct ncclP2pBuff {
void* directPtr; void* directPtr;
cudaIpcMemHandle_t devIpc; size_t size;
ncclIpcDesc ipcDesc;
}; };
struct p2pConnectInfo { struct p2pConnectInfo {
int rank; int rank;
int read; int read;
struct ncclP2pBuff p2pBuff; struct ncclP2pBuff p2pBuff;
// Use by CE memcpy // Used by CE memcpy
char shmName[7]; char shmName[7];
int shmSize; int shmSize;
}; };
@ -28,7 +32,7 @@ struct p2pShm {
struct ncclSendMem sendMem; struct ncclSendMem sendMem;
struct ncclRecvMem recvMem; struct ncclRecvMem recvMem;
}; };
struct p2pProxyInfo { struct p2pShmProxyInfo {
// Shared memory between proxy and receiving GPU // Shared memory between proxy and receiving GPU
struct p2pShm* shm; struct p2pShm* shm;
struct p2pShm* devShm; struct p2pShm* devShm;
@ -43,30 +47,34 @@ struct p2pProxyInfo {
// Receiver buffer // Receiver buffer
char* recvFifo; char* recvFifo;
// Used by progress only // Used by CE memcpy progress only
uint64_t step; uint64_t step;
cudaStream_t stream; cudaStream_t stream;
cudaEvent_t events[NCCL_STEPS]; cudaEvent_t events[NCCL_STEPS];
}; };
static_assert(sizeof(p2pConnectInfo) <= CONNECT_SIZE, "P2P Connect info is too large"); static_assert(sizeof(p2pConnectInfo) <= CONNECT_SIZE, "P2P Connect info is too large");
struct p2pSendResources { struct p2pResources {
struct ncclSendMem* devMem; enum p2pType type;
void* sendMemIpc; union {
void* recvMemIpc; struct ncclSendMem* sendDevMem;
struct p2pProxyInfo proxyInfo; struct ncclRecvMem* recvDevMem;
}; };
struct p2pRecvResources {
struct ncclRecvMem* devMem;
void* sendMemIpc; void* sendMemIpc;
void* recvMemIpc; void* recvMemIpc;
// CE memcpy support
struct p2pShmProxyInfo proxyInfo;
struct p2pShm* shm; struct p2pShm* shm;
struct p2pShm* devShm; struct p2pShm* devShm;
int shmSize; int shmSize;
ncclShmHandle_t handle; ncclShmHandle_t handle;
}; };
// cuMem API support
struct p2pCuMemProxyInfo {
struct ncclP2pBuff p2pBuff;
};
#include <sys/types.h> #include <sys/types.h>
/* Convert a PCI busId string into a local cudaDev device index (cf. CUDA_VISIBLE_DEVICES) */ /* Convert a PCI busId string into a local cudaDev device index (cf. CUDA_VISIBLE_DEVICES) */
@ -86,6 +94,7 @@ static int busIdToCudaDev(int64_t busId) {
return -1; return -1;
} }
// CE memcpy support
NCCL_PARAM(P2pUseCudaMemcpy, "P2P_USE_CUDA_MEMCPY", 0); NCCL_PARAM(P2pUseCudaMemcpy, "P2P_USE_CUDA_MEMCPY", 0);
static int useMemcpy = 0; static int useMemcpy = 0;
static void initCeOperation(); static void initCeOperation();
@ -140,7 +149,8 @@ ncclResult_t p2pCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTop
return ncclSuccess; return ncclSuccess;
} }
if (p2p != 0) { // This will always fail when using NCCL_CUMEM_ENABLE=1
if (p2p != 0 && !ncclCuMemEnable()) {
// Cached result of the legacyIPC detection // Cached result of the legacyIPC detection
static int legacyIPC = -1; static int legacyIPC = -1;
if (legacyIPC >= 0) { if (legacyIPC >= 0) {
@ -150,12 +160,12 @@ ncclResult_t p2pCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTop
// Check that legacy IPC support is available (WSL WAR) // Check that legacy IPC support is available (WSL WAR)
char *dummy; char *dummy;
cudaIpcMemHandle_t ipc; cudaIpcMemHandle_t ipc;
NCCLCHECK(ncclCudaCalloc(&dummy, CUDA_IPC_MIN)); NCCLCHECK(ncclCudaMalloc(&dummy, CUDA_IPC_MIN));
if (cudaIpcGetMemHandle(&ipc, dummy) != cudaSuccess) { if (cudaIpcGetMemHandle(&ipc, dummy) != cudaSuccess) {
INFO(NCCL_INIT|NCCL_P2P,"Legacy IPC not supported"); INFO(NCCL_INIT|NCCL_P2P,"Legacy IPC not supported");
*ret = 0; *ret = 0;
} }
CUDACHECK(cudaFree(dummy)); NCCLCHECK(ncclCudaFree(dummy));
legacyIPC = *ret; legacyIPC = *ret;
return ncclSuccess; return ncclSuccess;
} }
@ -176,6 +186,98 @@ ncclResult_t p2pCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTop
TRACE(P2P,"IPC: %016lx %016lx %016lx %016lx", devIpc[4], devIpc[5], devIpc[6], devIpc[7]); \ TRACE(P2P,"IPC: %016lx %016lx %016lx %016lx", devIpc[4], devIpc[5], devIpc[6], devIpc[7]); \
} while (0) } while (0)
// cuMem API support
ncclResult_t ncclP2pAllocateShareableBuffer(size_t size, ncclIpcDesc *ipcDesc, void **ptr) {
if (ncclCuMemEnable()) {
#if CUDART_VERSION >= 11030
// cuMem API support
CUmemAllocationHandleType type = NCCL_P2P_HANDLE_TYPE;
CUmemGenericAllocationHandle handle;
NCCLCHECK(ncclCuMemAlloc(ptr, &handle, size));
CUCHECK(cuMemExportToShareableHandle(&ipcDesc->cuDesc, handle, type, 0));
#else
return ncclInternalError;
#endif
} else {
// Allocate a CUDA buffer and generate an IPC handle for it
NCCLCHECK(ncclCudaCalloc((char **)ptr, size));
cudaError_t res = cudaIpcGetMemHandle(&ipcDesc->devIpc, *ptr);
if (res != cudaSuccess) {
WARN("cudaIpcGetMemHandle failed : %s", cudaGetErrorString(res));
ncclCudaFree(*ptr);
CUDACHECK(res);
}
}
INFO(NCCL_P2P|NCCL_ALLOC, "Allocated shareable buffer %p size %zi ipcDesc %p", *ptr, size, ipcDesc);
return ncclSuccess;
}
ncclResult_t ncclP2pFreeShareableBuffer(ncclIpcDesc *ipcDesc) {
if (ncclCuMemEnable()) {
// cuMem API support
CUmemAllocationHandleType type = NCCL_P2P_HANDLE_TYPE;
if (type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR) {
int fd = *(int *) &ipcDesc->cuDesc.data;
if (fd <= 0) return ncclInternalError;
(void) close(fd);
}
}
return ncclSuccess;
}
ncclResult_t ncclP2pImportShareableBuffer(struct ncclComm *comm, int tpPeer, size_t size, ncclIpcDesc *ipcDesc, void **devMemPtr) {
if (ncclCuMemEnable()) {
#if CUDART_VERSION >= 11030
// cuMem API support
CUdeviceptr dptr = 0;
CUmemAllocationHandleType type = NCCL_P2P_HANDLE_TYPE;
CUmemGenericAllocationHandle handle;
ncclCuDesc *cuDesc = &ipcDesc->cuDesc;
// Import and map the remote memory descriptor to the local GPU
if (type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR) {
// UDS fd support
struct ncclProxyConnector proxyConn;
int fd = *(int *)(&cuDesc->data);
int newFd = -1;
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, tpPeer, &proxyConn));
NCCLCHECK(ncclProxyClientConvertFdBlocking(comm, &proxyConn, fd, &newFd));
INFO(NCCL_P2P, "UDS converted fd %d -> %d on peer %d", fd, newFd, tpPeer);
CUCHECK(cuMemImportFromShareableHandle(&handle, (void *)(uintptr_t)newFd, type));
close(newFd);
} else {
CUCHECK(cuMemImportFromShareableHandle(&handle, cuDesc, type));
}
CUCHECK(cuMemAddressReserve(&dptr, size, /* alignment */ 0, /* addr */ 0, /* flags */ 0));
CUCHECK(cuMemMap(dptr, size, /* offset */ 0, handle, /* flags */ 0));
TRACE(NCCL_P2P, "Imported shareable buffer size %zi handle 0x%lx dptr %p", size, (long)handle, (void*)dptr);
// Allow access by the local GPU
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = comm->cudaDev;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CUCHECK(cuMemSetAccess(dptr, size, &accessDesc, 1));
TRACE(NCCL_P2P, "Set Access for %p size %zi dev %d", (void*)dptr, size, accessDesc.location.id);
*devMemPtr = (void *)dptr;
#else
return ncclInternalError;
#endif
} else {
// Legacy CUDA IPC
CUDACHECK(cudaIpcOpenMemHandle(devMemPtr, ipcDesc->devIpc, cudaIpcMemLazyEnablePeerAccess));
}
INFO(NCCL_P2P, "Imported shareable buffer device %d size %zi ptr %p", comm->cudaDev, size, *devMemPtr);
return ncclSuccess;
}
// Setting this to non zero causes P2P to use Reads rather than Writes // Setting this to non zero causes P2P to use Reads rather than Writes
NCCL_PARAM(P2pReadEnable, "P2P_READ_ENABLE", -2); NCCL_PARAM(P2pReadEnable, "P2P_READ_ENABLE", -2);
@ -192,10 +294,11 @@ static ncclResult_t p2pGetInfo(struct ncclTopoSystem* topo, struct ncclPeerInfo*
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t p2pMap(struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclP2pBuff* p2pBuff, void** devMem, void** ipcPtr) { static ncclResult_t p2pMap(struct ncclComm *comm, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclP2pBuff* p2pBuff, void** devMem, void** ipcPtr) {
if (myInfo->pidHash == peerInfo->pidHash) { if (!ncclCuMemEnable() && myInfo->pidHash == peerInfo->pidHash) {
if (peerInfo->cudaDev != myInfo->cudaDev) { if (peerInfo->cudaDev != myInfo->cudaDev) {
// Enable P2P access // Same PID different GPUs, enable P2P access
// Legacy CUDA IPC
cudaError_t err = cudaDeviceEnablePeerAccess(peerInfo->cudaDev, 0); cudaError_t err = cudaDeviceEnablePeerAccess(peerInfo->cudaDev, 0);
if (err == cudaErrorPeerAccessAlreadyEnabled) { if (err == cudaErrorPeerAccessAlreadyEnabled) {
cudaGetLastError(); cudaGetLastError();
@ -208,8 +311,15 @@ static ncclResult_t p2pMap(struct ncclPeerInfo* myInfo, struct ncclPeerInfo* pee
*devMem = p2pBuff->directPtr; *devMem = p2pBuff->directPtr;
*ipcPtr = NULL; *ipcPtr = NULL;
} else { } else {
CUDACHECK(cudaIpcOpenMemHandle(devMem, p2pBuff->devIpc, cudaIpcMemLazyEnablePeerAccess)); if ((myInfo->pidHash == peerInfo->pidHash) && (peerInfo->cudaDev == myInfo->cudaDev)) {
*ipcPtr = *devMem; // Same PID and GPU
*devMem = p2pBuff->directPtr;
*ipcPtr = NULL;
} else {
// Different PID or different GPU
NCCLCHECK(ncclP2pImportShareableBuffer(comm, comm->topParentRanks[peerInfo->rank], p2pBuff->size, &p2pBuff->ipcDesc, devMem));
*ipcPtr = *devMem;
}
} }
return ncclSuccess; return ncclSuccess;
} }
@ -217,7 +327,8 @@ static ncclResult_t p2pMap(struct ncclPeerInfo* myInfo, struct ncclPeerInfo* pee
/* Send: Create and return connect structures for this peer to connect to me */ /* Send: Create and return connect structures for this peer to connect to me */
ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo,
struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId, int connIndex) { struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId, int connIndex) {
struct p2pSendResources* resources; struct p2pResources* resources;
int tpProxyRank;
NCCLCHECK(ncclCalloc(&resources, 1)); NCCLCHECK(ncclCalloc(&resources, 1));
send->transportResources = resources; send->transportResources = resources;
int useRead, intermediateRank; int useRead, intermediateRank;
@ -233,35 +344,47 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st
int sendSize = sizeof(struct ncclSendMem); int sendSize = sizeof(struct ncclSendMem);
// For P2P Read the SIMPLE buffer is tagged on the end of the ncclSendMem structure // For P2P Read the SIMPLE buffer is tagged on the end of the ncclSendMem structure
if (info->read) sendSize += send->comm->buffSizes[NCCL_PROTO_SIMPLE]; if (info->read) sendSize += comm->buffSizes[NCCL_PROTO_SIMPLE];
ALIGN_SIZE(sendSize, CUDA_IPC_MIN); ALIGN_SIZE(sendSize, CUDA_IPC_MIN);
if (intermediateRank == -1) { if (intermediateRank == -1) {
info->rank = myInfo->rank; info->rank = myInfo->rank;
if (myInfo->pidHash == peerInfo->pidHash && useMemcpy == 0) { if (myInfo->pidHash == peerInfo->pidHash && ncclParamP2pDirectDisable() == 0 && useMemcpy == 0 && !ncclCuMemEnable()) {
if (ncclParamP2pDirectDisable() == 0) send->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; resources->type = P2P_DIRECT;
send->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE;
INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s", INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s",
channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr); channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr);
} else { } else {
// cuMem API support
if (ncclCuMemEnable()) {
resources->type = P2P_CUMEM;
INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%x] -> %d[%x] via P2P/CUMEM%s%s",
channelId, connIndex, myInfo->rank, myInfo->cudaDev, peerInfo->rank, peerInfo->cudaDev, useReadStr, useMemcpy ? "/CE" : "");;
} else {
// Legacy CUDA IPC
resources->type = P2P_IPC;
INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/IPC%s%s",
channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, useMemcpy ? "/CE" : "");
}
send->conn.flags |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE; send->conn.flags |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE;
INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/IPC%s%s",
channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, useMemcpy ? "/CE" : "");
} }
} else { } else {
resources->type = P2P_INTERMEDIATE;
info->rank = intermediateRank; info->rank = intermediateRank;
INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/indirect/%d[%lx]%s", INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/indirect/%d[%lx]%s",
channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, intermediateRank, channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, intermediateRank,
comm->peerInfo[intermediateRank].busId, useReadStr); comm->peerInfo[intermediateRank].busId, useReadStr);
} }
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, info->rank, &send->proxyConn)); tpProxyRank = comm->topParentRanks[info->rank];
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, tpProxyRank, &send->proxyConn));
if (useMemcpy) { if (useMemcpy) {
NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, NULL, 0, &resources->proxyInfo, sizeof(struct p2pProxyInfo))); NCCLCHECK(ncclProxyCallBlocking(comm, &send->proxyConn, ncclProxyMsgSetup, NULL, 0, &resources->proxyInfo, sizeof(struct p2pShmProxyInfo)));
info->shmSize = resources->proxyInfo.shmSize; info->shmSize = resources->proxyInfo.shmSize;
memcpy(info->shmName, resources->proxyInfo.shmName, sizeof(info->shmName)); memcpy(info->shmName, resources->proxyInfo.shmName, sizeof(info->shmName));
} else { } else {
NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, &sendSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff))); NCCLCHECK(ncclProxyCallBlocking(comm, &send->proxyConn, ncclProxyMsgSetup, &sendSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff)));
NCCLCHECK(p2pMap(myInfo, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&resources->devMem, &resources->sendMemIpc)); NCCLCHECK(p2pMap(comm, myInfo, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&resources->sendDevMem, &resources->sendMemIpc));
} }
return ncclSuccess; return ncclSuccess;
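For readers tracking the new resources->type field, here is a minimal sketch of the selection logic this hunk implements. The names P2P_DIRECT, P2P_CUMEM, P2P_IPC, P2P_INTERMEDIATE, ncclCuMemEnable() and ncclParamP2pDirectDisable() come from the diff above; the helper itself is illustrative, not NCCL code.

// Illustrative decision tree for picking the intra-node P2P path.
static int p2pPickType(bool samePid, bool directDisabled, bool useMemcpy,
                       bool cuMemEnabled, int intermediateRank) {
  if (intermediateRank != -1) return P2P_INTERMEDIATE;  // route through another GPU
  if (samePid && !directDisabled && !useMemcpy && !cuMemEnabled)
    return P2P_DIRECT;  // same process: exchange raw device pointers
  if (cuMemEnabled)
    return P2P_CUMEM;   // cuMem* allocations shared through exportable handles
  return P2P_IPC;       // legacy cudaIpcGetMemHandle/cudaIpcOpenMemHandle path
}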
@ -270,7 +393,8 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st
/* Create and return connect structures for this peer to connect to me */ /* Create and return connect structures for this peer to connect to me */
ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo,
struct ncclConnect* connectInfo, struct ncclConnector * recv, int channelId, int connIndex) { struct ncclConnect* connectInfo, struct ncclConnector * recv, int channelId, int connIndex) {
struct p2pRecvResources* resources; struct p2pResources* resources;
int tpProxyRank;
NCCLCHECK(ncclCalloc(&resources, 1)); NCCLCHECK(ncclCalloc(&resources, 1));
recv->transportResources = resources; recv->transportResources = resources;
int useRead, intermediateRank; int useRead, intermediateRank;
@ -284,44 +408,56 @@ ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st
int recvSize = sizeof(struct ncclRecvMem); int recvSize = sizeof(struct ncclRecvMem);
// For P2P Read the SIMPLE buffer is tagged on the end of the ncclSendMem structure // For P2P Read the SIMPLE buffer is tagged on the end of the ncclSendMem structure
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) if (!(info->read && p == NCCL_PROTO_SIMPLE)) recvSize += recv->comm->buffSizes[p]; for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) if (!(info->read && p == NCCL_PROTO_SIMPLE)) recvSize += comm->buffSizes[p];
ALIGN_SIZE(recvSize, CUDA_IPC_MIN); ALIGN_SIZE(recvSize, CUDA_IPC_MIN);
if (intermediateRank == -1) { if (intermediateRank == -1) {
info->rank = myInfo->rank; info->rank = myInfo->rank;
if (myInfo->pidHash == peerInfo->pidHash && useMemcpy == 0) { if (myInfo->pidHash == peerInfo->pidHash && ncclParamP2pDirectDisable() == 0 && useMemcpy == 0 && !ncclCuMemEnable()) {
if (ncclParamP2pDirectDisable() == 0) recv->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; resources->type = P2P_DIRECT;
recv->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE;
} else { } else {
if (ncclCuMemEnable()) {
// cuMem API support
resources->type = P2P_CUMEM;
TRACE(NCCL_INIT|NCCL_P2P,"Ring %02d : %d[%d] <- %d[%d] via P2P/CUMEM",
channelId, myInfo->rank, myInfo->cudaDev, peerInfo->rank, peerInfo->cudaDev);
} else {
// Legacy CUDA IPC
resources->type = P2P_IPC;
}
recv->conn.flags |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE; recv->conn.flags |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE;
} }
} else { } else {
resources->type = P2P_INTERMEDIATE;
info->rank = intermediateRank; info->rank = intermediateRank;
} }
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 0, info->rank, &recv->proxyConn)); tpProxyRank = comm->topParentRanks[info->rank];
NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgSetup, &recvSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff))); NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 0, tpProxyRank, &recv->proxyConn));
NCCLCHECK(ncclProxyCallBlocking(comm, &recv->proxyConn, ncclProxyMsgSetup, &recvSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff)));
NCCLCHECK(p2pMap(myInfo, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&resources->devMem, &resources->recvMemIpc)); NCCLCHECK(p2pMap(comm, myInfo, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&resources->recvDevMem, &resources->recvMemIpc));
return ncclSuccess; return ncclSuccess;
} }
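Both setup functions now translate the peer rank through comm->topParentRanks before calling ncclProxyConnect. A hedged sketch of that mapping, assuming (as the ncclCommSplit resource-sharing feature suggests) that proxy threads belong to the top-level parent communicator and therefore have to be addressed with parent ranks; the wrapper below is illustrative:

// For a communicator derived via ncclCommSplit, rank r of the child maps to
// comm->topParentRanks[r] in the top parent communicator that owns the proxies.
static inline int toTopParentRank(const struct ncclComm* comm, int rank) {
  return comm->topParentRanks[rank];
}

// Hypothetical usage mirroring the setup code above:
//   int tpProxyRank = toTopParentRank(comm, info->rank);
//   NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 0, tpProxyRank, &recv->proxyConn));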
/* Connect/Send to this peer */ /* Connect/Send to this peer */
static ncclResult_t p2pSendConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* send) { static ncclResult_t p2pSendConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* send) {
struct p2pSendResources* resources = (struct p2pSendResources*)send->transportResources; struct p2pResources* resources = (struct p2pResources*)send->transportResources;
struct ncclRecvMem* remDevMem; struct ncclRecvMem* remDevMem = NULL;
struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo; struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo;
NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&remDevMem, &resources->recvMemIpc)); NCCLCHECK(p2pMap(comm, comm->peerInfo+rank, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&remDevMem, &resources->recvMemIpc));
char* buff = (char*)(remDevMem+1); char* buff = (char*)(remDevMem+1);
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
if (info->read && p == NCCL_PROTO_SIMPLE) { if (info->read && p == NCCL_PROTO_SIMPLE) {
/* For P2P Read the SIMPLE buffer is local (ncclSendMem) */ /* For P2P Read the SIMPLE buffer is local (ncclSendMem) */
if (resources->devMem == NULL) return ncclInternalError; // We should not use read + memcpy if (resources->sendDevMem == NULL) return ncclInternalError; // We should not use read + memcpy
send->conn.buffs[p] = (char*)(resources->devMem+1); send->conn.buffs[p] = (char*)(resources->sendDevMem+1);
} else { } else {
send->conn.buffs[p] = buff; send->conn.buffs[p] = buff;
buff += send->comm->buffSizes[p]; buff += comm->buffSizes[p];
} }
} }
@ -330,20 +466,20 @@ static ncclResult_t p2pSendConnect(struct ncclComm* comm, struct ncclConnect* co
send->conn.sizesFifo = resources->proxyInfo.ceRecvMem->sizesFifo; send->conn.sizesFifo = resources->proxyInfo.ceRecvMem->sizesFifo;
send->conn.head = &resources->proxyInfo.devShm->sendMem.head; send->conn.head = &resources->proxyInfo.devShm->sendMem.head;
// Send SIMPLE buff to proxy, and replace it by local buffer // Send SIMPLE buff to proxy, and replace it by local buffer
NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgConnect, &send->conn.buffs[NCCL_PROTO_SIMPLE], sizeof(void*), NULL, 0)); NCCLCHECK(ncclProxyCallBlocking(comm, &send->proxyConn, ncclProxyMsgConnect, &send->conn.buffs[NCCL_PROTO_SIMPLE], sizeof(void*), NULL, 0));
send->conn.buffs[NCCL_PROTO_SIMPLE] = resources->proxyInfo.ceDevBuff; send->conn.buffs[NCCL_PROTO_SIMPLE] = resources->proxyInfo.ceDevBuff;
} else { } else {
send->conn.tail = &remDevMem->tail; send->conn.tail = &remDevMem->tail;
send->conn.head = &resources->devMem->head; send->conn.head = &resources->sendDevMem->head;
send->conn.ptrExchange = &resources->devMem->ptrExchange; send->conn.ptrExchange = &resources->sendDevMem->ptrExchange;
send->conn.redOpArgExchange = resources->devMem->redOpArgExchange; send->conn.redOpArgExchange = resources->sendDevMem->redOpArgExchange;
} }
return ncclSuccess; return ncclSuccess;
} }
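The buffer carving in p2pSendConnect is terse; as a rough paraphrase, the per-protocol buffers sit back-to-back right after the ncclRecvMem header of the mapped peer allocation, except that with P2P read the SIMPLE buffer stays on the local ncclSendMem allocation. A small sketch (hypothetical helper, other names taken from the code above):

// Layout: [ncclRecvMem header][LL][LL128][SIMPLE], with SIMPLE living after the
// local ncclSendMem header instead when P2P read is in use.
static void carveBuffers(char** buffs, const int* buffSizes, bool p2pRead,
                         struct ncclRecvMem* remote, struct ncclSendMem* localSend) {
  char* cur = (char*)(remote + 1);            // first byte after the header
  for (int p = 0; p < NCCL_NUM_PROTOCOLS; p++) {
    if (p2pRead && p == NCCL_PROTO_SIMPLE) {
      buffs[p] = (char*)(localSend + 1);      // reader pulls from the sender's memory
    } else {
      buffs[p] = cur;
      cur += buffSizes[p];
    }
  }
}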
/* Connect/Recv from this peer */ /* Connect/Recv from this peer */
ncclResult_t p2pRecvConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* recv) { ncclResult_t p2pRecvConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* recv) {
struct p2pRecvResources* resources = (struct p2pRecvResources*)recv->transportResources; struct p2pResources* resources = (struct p2pResources*)recv->transportResources;
struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo; struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo;
struct ncclSendMem* remDevMem = NULL; struct ncclSendMem* remDevMem = NULL;
@ -353,20 +489,22 @@ ncclResult_t p2pRecvConnect(struct ncclComm* comm, struct ncclConnect* connectIn
sprintf(shmPath, "/dev/shm/nccl-%s", info->shmName); sprintf(shmPath, "/dev/shm/nccl-%s", info->shmName);
TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmPath, info->shmSize); TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmPath, info->shmSize);
resources->shmSize = info->shmSize; resources->shmSize = info->shmSize;
// Attach to peer's SHM segment
NCCLCHECK(ncclShmOpen(shmPath, info->shmSize, (void**)&resources->shm, (void**)&resources->devShm, -1, &resources->handle)); NCCLCHECK(ncclShmOpen(shmPath, info->shmSize, (void**)&resources->shm, (void**)&resources->devShm, -1, &resources->handle));
recv->conn.tail = &resources->devShm->recvMem.tail; recv->conn.tail = &resources->devShm->recvMem.tail;
recv->conn.head = &resources->devShm->sendMem.head; recv->conn.head = &resources->devShm->sendMem.head;
} else { } else {
NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&remDevMem, &resources->sendMemIpc)); NCCLCHECK(p2pMap(comm, comm->peerInfo+rank, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&remDevMem, &resources->sendMemIpc));
recv->conn.tail = &resources->devMem->tail; struct ncclRecvMem* devMem = resources->recvDevMem;
recv->conn.tail = &devMem->tail;
recv->conn.head = &remDevMem->head; recv->conn.head = &remDevMem->head;
recv->conn.ptrExchange = &remDevMem->ptrExchange; recv->conn.ptrExchange = &remDevMem->ptrExchange;
recv->conn.redOpArgExchange = remDevMem->redOpArgExchange; recv->conn.redOpArgExchange = remDevMem->redOpArgExchange;
} }
char* buff = (char*)(resources->devMem+1); char* buff = (char*)(resources->recvDevMem+1);
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
if (info->read && p == NCCL_PROTO_SIMPLE) { if (info->read && p == NCCL_PROTO_SIMPLE) {
if (remDevMem == NULL) return ncclInternalError; // We should not use read + memcpy if (remDevMem == NULL) return ncclInternalError; // We should not use read + memcpy
@ -374,93 +512,113 @@ ncclResult_t p2pRecvConnect(struct ncclComm* comm, struct ncclConnect* connectIn
recv->conn.buffs[p] = (char*)(remDevMem+1); recv->conn.buffs[p] = (char*)(remDevMem+1);
} else { } else {
recv->conn.buffs[p] = buff; recv->conn.buffs[p] = buff;
buff += recv->comm->buffSizes[p]; buff += comm->buffSizes[p];
} }
} }
return ncclSuccess; return ncclSuccess;
} }
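p2pMap() itself is outside this hunk. On the legacy (non-cuMem) path it presumably boils down to using the exporter's pointer directly within the same process, or opening its CUDA IPC handle across processes, which is what the *MemIpc pointers freed below are tracking. A hedged sketch; the ncclP2pBuff field layout is assumed:

// Import a peer buffer on the legacy CUDA IPC path. 'ipcHandle' is the
// cudaIpcMemHandle_t the exporter obtained with cudaIpcGetMemHandle().
static ncclResult_t importPeerBuffLegacy(bool sameProcess, void* exporterPtr,
                                         cudaIpcMemHandle_t ipcHandle,
                                         void** devMem, void** ipcPtr) {
  if (sameProcess) {
    *devMem = exporterPtr;   // same address space: the pointer is already valid
    *ipcPtr = NULL;          // nothing to close in p2pSendFree/p2pRecvFree
  } else {
    CUDACHECK(cudaIpcOpenMemHandle(devMem, ipcHandle, cudaIpcMemLazyEnablePeerAccess));
    *ipcPtr = *devMem;       // remember the mapping so it can be closed later
  }
  return ncclSuccess;
}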
ncclResult_t p2pSendFree(struct ncclConnector* send) { ncclResult_t p2pSendFree(struct ncclConnector* send) {
struct p2pSendResources* resources = (struct p2pSendResources*)send->transportResources; struct p2pResources* resources = (struct p2pResources*)send->transportResources;
if (resources) { if (resources) {
if (resources->sendMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->sendMemIpc)); if (ncclCuMemEnable()) {
if (resources->recvMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->recvMemIpc)); // cuMem API support
if (resources->sendMemIpc) NCCLCHECK(ncclCudaFree(resources->sendMemIpc));
if (resources->recvMemIpc) NCCLCHECK(ncclCudaFree(resources->recvMemIpc));
}
else {
if (resources->sendMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->sendMemIpc));
if (resources->recvMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->recvMemIpc));
}
free(resources); free(resources);
} }
return ncclSuccess; return ncclSuccess;
} }
ncclResult_t p2pRecvFree(struct ncclConnector* recv) { ncclResult_t p2pRecvFree(struct ncclConnector* recv) {
struct p2pRecvResources* resources = (struct p2pRecvResources*)recv->transportResources; struct p2pResources* resources = (struct p2pResources*)recv->transportResources;
if (resources) { if (resources) {
if (resources->sendMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->sendMemIpc)); if (ncclCuMemEnable()) {
if (resources->recvMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->recvMemIpc)); // cuMem API support
if (useMemcpy) { if (resources->sendMemIpc) NCCLCHECK(ncclCudaFree(resources->sendMemIpc));
NCCLCHECK(ncclShmClose(resources->handle)); if (resources->recvMemIpc) NCCLCHECK(ncclCudaFree(resources->recvMemIpc));
}
else {
if (resources->sendMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->sendMemIpc));
if (resources->recvMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->recvMemIpc));
if (useMemcpy) {
NCCLCHECK(ncclShmClose(resources->handle));
}
} }
free(resources); free(resources);
} }
return ncclSuccess; return ncclSuccess;
} }
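The free paths now branch on ncclCuMemEnable(): buffers obtained through the cuMem API cannot be released with cudaIpcCloseMemHandle or cudaFree, so ncclCudaFree presumably unwinds the virtual-memory-management sequence instead. A sketch of that teardown under this assumption:

#include <cuda.h>

// Assumed teardown for a buffer mapped via cuMemAddressReserve/cuMemMap; the
// allocation handle and size would have been recorded at allocation time.
static CUresult freeCuMemMapping(CUdeviceptr ptr, size_t size, CUmemGenericAllocationHandle handle) {
  CUresult err;
  if ((err = cuMemUnmap(ptr, size)) != CUDA_SUCCESS) return err;   // drop the mapping
  if ((err = cuMemRelease(handle)) != CUDA_SUCCESS) return err;    // release physical memory
  return cuMemAddressFree(ptr, size);                              // return the VA range
}

// Legacy path: a buffer imported with cudaIpcOpenMemHandle is released with
// cudaIpcCloseMemHandle(ptr), exactly as the pre-cuMem code above did.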
static ncclResult_t p2pSendProxySetup(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t p2pSendProxySetup(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
if (useMemcpy) { if (useMemcpy) {
struct p2pProxyInfo* proxyInfo; // CE memcpy support
struct p2pShmProxyInfo* proxyInfo;
NCCLCHECK(ncclCalloc(&proxyInfo, 1)); NCCLCHECK(ncclCalloc(&proxyInfo, 1));
connection->transportResources = proxyInfo; connection->transportResources = proxyInfo;
NCCLCHECK(ncclCudaCalloc(&proxyInfo->ceDevBuff, comm->buffSizes[NCCL_PROTO_SIMPLE])); NCCLCHECK(ncclCudaCalloc(&proxyInfo->ceDevBuff, proxyState->buffSizes[NCCL_PROTO_SIMPLE]));
char shmPath[PATH_MAX]; char shmPath[PATH_MAX];
shmPath[0] = '\0'; shmPath[0] = '\0';
proxyInfo->shmSize = sizeof(struct ncclSendMem) + sizeof(struct ncclRecvMem); proxyInfo->shmSize = sizeof(struct ncclSendMem) + sizeof(struct ncclRecvMem);
// Create a SHM segment for the peer to attach to
NCCLCHECK(ncclShmOpen(shmPath, proxyInfo->shmSize, (void**)&proxyInfo->shm, (void**)&proxyInfo->devShm, 1, &proxyInfo->handle)); NCCLCHECK(ncclShmOpen(shmPath, proxyInfo->shmSize, (void**)&proxyInfo->shm, (void**)&proxyInfo->devShm, 1, &proxyInfo->handle));
TRACE(NCCL_SHM,"Opened shmName %s shmSize %d", shmPath, proxyInfo->shmSize); TRACE(NCCL_SHM,"Opened shmName %s shmSize %d", shmPath, proxyInfo->shmSize);
memcpy(proxyInfo->shmName, shmPath+sizeof("/dev/shm/nccl-")-1, sizeof(proxyInfo->shmName)); memcpy(proxyInfo->shmName, shmPath+sizeof("/dev/shm/nccl-")-1, sizeof(proxyInfo->shmName));
NCCLCHECK(ncclCudaHostCalloc(&proxyInfo->ceRecvMem, 1)); NCCLCHECK(ncclCudaHostCalloc(&proxyInfo->ceRecvMem, 1));
if (respSize != sizeof(struct p2pProxyInfo)) return ncclInternalError; if (respSize != sizeof(struct p2pShmProxyInfo)) return ncclInternalError;
memcpy(respBuff, proxyInfo, sizeof(struct p2pProxyInfo)); memcpy(respBuff, proxyInfo, sizeof(struct p2pShmProxyInfo));
} else { } else {
if (reqSize != sizeof(int)) return ncclInternalError; if (reqSize != sizeof(int)) return ncclInternalError;
int size = *((int*)reqBuff); int size = *((int*)reqBuff);
if (respSize != sizeof(struct ncclP2pBuff)) return ncclInternalError; if (respSize != sizeof(struct ncclP2pBuff)) return ncclInternalError;
struct ncclP2pBuff* p2pBuff = (struct ncclP2pBuff*)respBuff; struct ncclP2pBuff* p2pBuff = (struct ncclP2pBuff*)respBuff;
NCCLCHECK(ncclCudaCalloc((char**)&p2pBuff->directPtr, size)); NCCLCHECK(ncclP2pAllocateShareableBuffer(size, &p2pBuff->ipcDesc, &p2pBuff->directPtr));
connection->transportResources = p2pBuff->directPtr; p2pBuff->size = size;
cudaError_t res = cudaIpcGetMemHandle(&p2pBuff->devIpc, p2pBuff->directPtr); if (ncclCuMemEnable()) {
if (res != cudaSuccess) { // cuMem API support
WARN("cudaIpcGetMemHandle failed : %s", cudaGetErrorString(res)); struct p2pCuMemProxyInfo* proxyInfo;
cudaFree(p2pBuff->directPtr); NCCLCHECK(ncclCalloc(&proxyInfo, 1));
free(p2pBuff); memcpy(&proxyInfo->p2pBuff, p2pBuff, sizeof(*p2pBuff));
CUDACHECK(res); connection->transportResources = proxyInfo;
} else {
connection->transportResources = p2pBuff->directPtr;
} }
} }
*done = 1; *done = 1;
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t p2pRecvProxySetup(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t p2pRecvProxySetup(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
if (reqSize != sizeof(int)) return ncclInternalError; if (reqSize != sizeof(int)) return ncclInternalError;
int size = *((int*)reqBuff); int size = *((int*)reqBuff);
if (respSize != sizeof(struct ncclP2pBuff)) return ncclInternalError; if (respSize != sizeof(struct ncclP2pBuff)) return ncclInternalError;
struct ncclP2pBuff* p2pBuff = (struct ncclP2pBuff*)respBuff; struct ncclP2pBuff* p2pBuff = (struct ncclP2pBuff*)respBuff;
NCCLCHECK(ncclCudaCalloc((char**)&p2pBuff->directPtr, size)); NCCLCHECK(ncclP2pAllocateShareableBuffer(size, &p2pBuff->ipcDesc, &p2pBuff->directPtr));
connection->transportResources = p2pBuff->directPtr; p2pBuff->size = size;
cudaError_t res = cudaIpcGetMemHandle(&p2pBuff->devIpc, p2pBuff->directPtr); if (ncclCuMemEnable()) {
if (res != cudaSuccess) { // cuMem API support
WARN("cudaIpcGetMemHandle failed : %s", cudaGetErrorString(res)); struct p2pCuMemProxyInfo* proxyInfo;
cudaFree(p2pBuff->directPtr); NCCLCHECK(ncclCalloc(&proxyInfo, 1));
free(p2pBuff); memcpy(&proxyInfo->p2pBuff, p2pBuff, sizeof(*p2pBuff));
CUDACHECK(res); connection->transportResources = proxyInfo;
} else {
connection->transportResources = p2pBuff->directPtr;
} }
*done = 1; *done = 1;
return ncclSuccess; return ncclSuccess;
} }
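ncclP2pAllocateShareableBuffer() replaces the cudaMalloc+cudaIpcGetMemHandle pair in both proxy setup functions but is not shown in this diff. A hedged sketch of what such an allocation typically looks like with the cuMem API, exporting a POSIX file descriptor so a peer process can import and map the same physical memory; NCCL's real helper may differ in details:

#include <cuda.h>

// Allocate 'size' bytes of device memory on 'cudaDev' that can be shared across
// processes. Returns the locally mapped pointer and an exportable POSIX fd.
static CUresult allocShareable(int cudaDev, size_t size, CUdeviceptr* ptr, int* shareableFd) {
  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = cudaDev;
  prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

  size_t gran = 0;
  CUresult err = cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
  if (err != CUDA_SUCCESS) return err;
  size = ((size + gran - 1) / gran) * gran;                // round up to the granularity

  CUmemGenericAllocationHandle handle;
  if ((err = cuMemCreate(&handle, size, &prop, 0)) != CUDA_SUCCESS) return err;
  if ((err = cuMemExportToShareableHandle(shareableFd, handle,
        CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)) != CUDA_SUCCESS) return err;

  if ((err = cuMemAddressReserve(ptr, size, gran, 0, 0)) != CUDA_SUCCESS) return err;
  if ((err = cuMemMap(*ptr, size, 0, handle, 0)) != CUDA_SUCCESS) return err;

  CUmemAccessDesc access = {};
  access.location = prop.location;
  access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  return cuMemSetAccess(*ptr, size, &access, 1);           // enable local read/write access
}

The fd cannot be copied between processes as plain bytes; it has to be duplicated through a Unix-domain socket (SCM_RIGHTS), after which the importer calls cuMemImportFromShareableHandle() and repeats the reserve/map/set-access steps.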
static ncclResult_t p2pSendProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t p2pSendProxyConnect(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
struct p2pProxyInfo* proxyInfo = (struct p2pProxyInfo*)connection->transportResources; struct p2pShmProxyInfo* proxyInfo = (struct p2pShmProxyInfo*)connection->transportResources;
if (reqSize != sizeof(void*)) return ncclInternalError; if (reqSize != sizeof(void*)) return ncclInternalError;
proxyInfo->recvFifo = *((char**)reqBuff); proxyInfo->recvFifo = *((char**)reqBuff);
@ -473,13 +631,14 @@ static ncclResult_t p2pSendProxyConnect(struct ncclProxyConnection* connection,
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t p2pSendProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { static ncclResult_t p2pSendProxyFree(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState) {
// CE memcpy support
if (useMemcpy) { if (useMemcpy) {
struct p2pProxyInfo* proxyInfo = (struct p2pProxyInfo*)connection->transportResources; struct p2pShmProxyInfo* proxyInfo = (struct p2pShmProxyInfo*)connection->transportResources;
if (proxyInfo) { if (proxyInfo) {
NCCLCHECK(ncclShmClose(proxyInfo->handle)); NCCLCHECK(ncclShmClose(proxyInfo->handle));
NCCLCHECK(ncclCudaHostFree(proxyInfo->ceRecvMem)); NCCLCHECK(ncclCudaHostFree(proxyInfo->ceRecvMem));
CUDACHECK(cudaFree(proxyInfo->ceDevBuff)); NCCLCHECK(ncclCudaFree(proxyInfo->ceDevBuff));
CUDACHECK(cudaStreamDestroy(proxyInfo->stream)); CUDACHECK(cudaStreamDestroy(proxyInfo->stream));
for (int i=0; i<NCCL_STEPS; i++) { for (int i=0; i<NCCL_STEPS; i++) {
CUDACHECK(cudaEventDestroy(proxyInfo->events[i])); CUDACHECK(cudaEventDestroy(proxyInfo->events[i]));
@ -487,23 +646,45 @@ static ncclResult_t p2pSendProxyFree(struct ncclProxyConnection* connection, str
free(proxyInfo); free(proxyInfo);
} }
} else { } else {
// Do not check return code as CUDA may have already shut down if (ncclCuMemEnable()) {
cudaFree(connection->transportResources); // cuMem API support
struct p2pCuMemProxyInfo *proxyInfo = (struct p2pCuMemProxyInfo *) connection->transportResources;
if (proxyInfo) {
struct ncclP2pBuff *p2pBuff = &proxyInfo->p2pBuff;
ncclP2pFreeShareableBuffer(&p2pBuff->ipcDesc);
ncclCudaFree(p2pBuff->directPtr);
free(proxyInfo);
}
} else {
// Do not check return code as CUDA may have already shut down
ncclCudaFree(connection->transportResources);
}
} }
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t p2pRecvProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { static ncclResult_t p2pRecvProxyFree(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState) {
// Do not check return code as CUDA may have already shut down if (ncclCuMemEnable()) {
cudaFree(connection->transportResources); struct p2pCuMemProxyInfo *proxyInfo = (struct p2pCuMemProxyInfo *) connection->transportResources;
if (proxyInfo) {
struct ncclP2pBuff *p2pBuff = &proxyInfo->p2pBuff;
ncclP2pFreeShareableBuffer(&p2pBuff->ipcDesc);
ncclCudaFree(p2pBuff->directPtr);
free(proxyInfo);
}
} else {
// Do not check return code as CUDA may have already shut down
ncclCudaFree(connection->transportResources);
}
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t p2pSendProxyProgress(struct ncclComm* comm, struct ncclProxyArgs* args) { // CE memcpy support
static ncclResult_t p2pSendProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args) {
if (args->state == ncclProxyOpReady) { if (args->state == ncclProxyOpReady) {
for (int s=0; s<args->nsubs; s++) { for (int s=0; s<args->nsubs; s++) {
struct ncclProxySubArgs* sub = args->subs+s; struct ncclProxySubArgs* sub = args->subs+s;
struct p2pProxyInfo* resources = (struct p2pProxyInfo*) (sub->connection->transportResources); struct p2pShmProxyInfo* resources = (struct p2pShmProxyInfo*) (sub->connection->transportResources);
// Round to next multiple of sliceSteps // Round to next multiple of sliceSteps
sub->base = ROUNDUP(resources->step, args->chunkSteps); sub->base = ROUNDUP(resources->step, args->chunkSteps);
sub->posted = sub->transmitted = sub->done = 0; sub->posted = sub->transmitted = sub->done = 0;
@ -513,10 +694,10 @@ static ncclResult_t p2pSendProxyProgress(struct ncclComm* comm, struct ncclProxy
args->idle = 1; args->idle = 1;
if (args->state == ncclProxyOpProgress) { if (args->state == ncclProxyOpProgress) {
int p = args->protocol; int p = args->protocol;
int stepSize = comm->buffSizes[p] / NCCL_STEPS; int stepSize = proxyState->buffSizes[p] / NCCL_STEPS;
for (int s=0; s<args->nsubs; s++) { for (int s=0; s<args->nsubs; s++) {
struct ncclProxySubArgs* sub = args->subs+s; struct ncclProxySubArgs* sub = args->subs+s;
struct p2pProxyInfo* resources = (struct p2pProxyInfo*) (sub->connection->transportResources); struct p2pShmProxyInfo* resources = (struct p2pShmProxyInfo*) (sub->connection->transportResources);
if (p != NCCL_PROTO_SIMPLE) { // Only Simple uses cudaMemcpy if (p != NCCL_PROTO_SIMPLE) { // Only Simple uses cudaMemcpy
resources->step = sub->base + sub->nsteps; resources->step = sub->base + sub->nsteps;
args->done++; args->done++;
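For the progress loop above, the step arithmetic is easier to see with concrete numbers; the values below are typical defaults used only for illustration, not fixed by this diff:

// The SIMPLE buffer is split into NCCL_STEPS equal slots; each proxy step moves one slot.
int buffSize = 4 << 20;                 // e.g. a 4 MiB SIMPLE buffer
int steps    = 8;                       // e.g. NCCL_STEPS
int stepSize = buffSize / steps;        // 524288 bytes copied per step
// ROUNDUP(step, chunkSteps): with step = 13 and chunkSteps = 4, the next op bases at 16.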
View File

@ -85,7 +85,7 @@ static ncclResult_t shmSendSetup(struct ncclComm* comm, struct ncclTopoGraph* gr
shmPath[0] = '\0'; shmPath[0] = '\0';
int shmSize = sizeof(struct ncclSendMem); int shmSize = sizeof(struct ncclSendMem);
if (shmLocality == SHM_SEND_SIDE) { if (shmLocality == SHM_SEND_SIDE) {
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) shmSize += send->comm->buffSizes[p]; for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) shmSize += comm->buffSizes[p];
} }
info->shmSize = resources->shmSize = shmSize; info->shmSize = resources->shmSize = shmSize;
NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1, &resources->hostHandle)); NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1, &resources->hostHandle));
@ -108,7 +108,7 @@ static ncclResult_t shmRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* gr
shmPath[0] = '\0'; shmPath[0] = '\0';
int shmSize = sizeof(struct ncclRecvMem); int shmSize = sizeof(struct ncclRecvMem);
if (shmLocality == SHM_RECV_SIDE) { if (shmLocality == SHM_RECV_SIDE) {
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) shmSize += recv->comm->buffSizes[p]; for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) shmSize += comm->buffSizes[p];
} }
info->shmSize = resources->shmSize = shmSize; info->shmSize = resources->shmSize = shmSize;
NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1, &resources->hostHandle)); NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1, &resources->hostHandle));
@ -146,7 +146,7 @@ static ncclResult_t shmSendConnect(struct ncclComm* comm, struct ncclConnect* co
char* buff = shmLocality == SHM_SEND_SIDE ? (char*)(resources->devHostMem+1) : (char*)(resources->devRemHostMem+1); char* buff = shmLocality == SHM_SEND_SIDE ? (char*)(resources->devHostMem+1) : (char*)(resources->devRemHostMem+1);
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
send->conn.buffs[p] = buff; send->conn.buffs[p] = buff;
buff += send->comm->buffSizes[p]; buff += comm->buffSizes[p];
} }
send->conn.tail = &resources->devRemHostMem->tail; send->conn.tail = &resources->devRemHostMem->tail;
send->conn.head = &resources->devHostMem->head; send->conn.head = &resources->devHostMem->head;
@ -155,9 +155,11 @@ static ncclResult_t shmSendConnect(struct ncclComm* comm, struct ncclConnect* co
send->conn.sizesFifo = resources->devRemHostMem->sizesFifo; send->conn.sizesFifo = resources->devRemHostMem->sizesFifo;
} }
if (useMemcpySend) { if (useMemcpySend) {
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_SHM, 1, comm->rank, &send->proxyConn)); int tpProxyRank;
tpProxyRank = comm->topParentRanks[comm->rank];
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_SHM, 1, tpProxyRank, &send->proxyConn));
struct shmProxyInfo proxyInfo = { NULL, NULL, send->conn.buffs[NCCL_PROTO_SIMPLE], resources->hostMem, resources->remHostMem }; struct shmProxyInfo proxyInfo = { NULL, NULL, send->conn.buffs[NCCL_PROTO_SIMPLE], resources->hostMem, resources->remHostMem };
NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo))); NCCLCHECK(ncclProxyCallBlocking(comm, &send->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo)));
send->conn.buffs[NCCL_PROTO_SIMPLE] = proxyInfo.devFifo; send->conn.buffs[NCCL_PROTO_SIMPLE] = proxyInfo.devFifo;
send->conn.tail = &proxyInfo.ceRecvMem->tail; send->conn.tail = &proxyInfo.ceRecvMem->tail;
send->conn.sizesFifo = proxyInfo.ceRecvMem->sizesFifo; send->conn.sizesFifo = proxyInfo.ceRecvMem->sizesFifo;
@ -179,7 +181,7 @@ static ncclResult_t shmRecvConnect(struct ncclComm* comm, struct ncclConnect* co
char* buff = shmLocality == SHM_RECV_SIDE ? (char*)(resources->devHostMem+1) : (char*)(resources->devRemHostMem+1); char* buff = shmLocality == SHM_RECV_SIDE ? (char*)(resources->devHostMem+1) : (char*)(resources->devRemHostMem+1);
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) { for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
recv->conn.buffs[p] = buff; recv->conn.buffs[p] = buff;
buff += recv->comm->buffSizes[p]; buff += comm->buffSizes[p];
} }
recv->conn.head = &resources->devRemHostMem->head; recv->conn.head = &resources->devRemHostMem->head;
recv->conn.tail = &resources->devHostMem->tail; recv->conn.tail = &resources->devHostMem->tail;
@ -187,7 +189,7 @@ static ncclResult_t shmRecvConnect(struct ncclComm* comm, struct ncclConnect* co
if (useMemcpyRecv) { if (useMemcpyRecv) {
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_SHM, 0, comm->rank, &recv->proxyConn)); NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_SHM, 0, comm->rank, &recv->proxyConn));
struct shmProxyInfo proxyInfo = { NULL, NULL, recv->conn.buffs[NCCL_PROTO_SIMPLE], resources->remHostMem, resources->hostMem }; struct shmProxyInfo proxyInfo = { NULL, NULL, recv->conn.buffs[NCCL_PROTO_SIMPLE], resources->remHostMem, resources->hostMem };
NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo))); NCCLCHECK(ncclProxyCallBlocking(comm, &recv->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo)));
recv->conn.buffs[NCCL_PROTO_SIMPLE] = proxyInfo.devFifo; recv->conn.buffs[NCCL_PROTO_SIMPLE] = proxyInfo.devFifo;
recv->conn.tail = &proxyInfo.ceRecvMem->tail; recv->conn.tail = &proxyInfo.ceRecvMem->tail;
} }
@ -214,12 +216,12 @@ static ncclResult_t shmRecvFree(struct ncclConnector* recv) {
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t shmSendProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t shmSendProxyConnect(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
struct shmProxyInfo* proxyInfo; struct shmProxyInfo* proxyInfo;
NCCLCHECK(ncclCalloc(&proxyInfo, 1)); NCCLCHECK(ncclCalloc(&proxyInfo, 1));
if (reqSize != sizeof(struct shmProxyInfo)) return ncclInternalError; if (reqSize != sizeof(struct shmProxyInfo)) return ncclInternalError;
memcpy(proxyInfo, reqBuff, reqSize); memcpy(proxyInfo, reqBuff, reqSize);
NCCLCHECK(ncclCudaCalloc(&proxyInfo->devFifo, comm->buffSizes[NCCL_PROTO_SIMPLE])); NCCLCHECK(ncclCudaCalloc(&proxyInfo->devFifo, proxyState->buffSizes[NCCL_PROTO_SIMPLE]));
NCCLCHECK(ncclCudaHostCalloc(&proxyInfo->ceRecvMem, 1)); NCCLCHECK(ncclCudaHostCalloc(&proxyInfo->ceRecvMem, 1));
CUDACHECK(cudaStreamCreateWithFlags(&proxyInfo->stream, cudaStreamNonBlocking)); CUDACHECK(cudaStreamCreateWithFlags(&proxyInfo->stream, cudaStreamNonBlocking));
for (int i=0; i<NCCL_STEPS; i++) { for (int i=0; i<NCCL_STEPS; i++) {
@ -232,12 +234,12 @@ static ncclResult_t shmSendProxyConnect(struct ncclProxyConnection* connection,
return ncclSuccess; return ncclSuccess;
} }
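shmSendProxyConnect now sizes its staging buffer from proxyState->buffSizes and sets up the copy-engine ("CE") machinery: a device FIFO, pinned ceRecvMem, a non-blocking stream and one event per in-flight step. A hedged sketch of how a progress step would use them on the send side; the host-side slot pointer is passed in because its field name is not shown above:

// One CE step: stage 'size' bytes from the device FIFO slot into the host
// shared-memory slot, then record an event so a later cudaEventQuery() can
// tell when the copy engine finished this step.
static ncclResult_t ceCopyStep(struct shmProxyInfo* res, char* hostSlot,
                               int step, int stepSize, int size) {
  char* devSlot = res->devFifo + (step % NCCL_STEPS) * stepSize;  // device staging slot
  CUDACHECK(cudaMemcpyAsync(hostSlot, devSlot, size, cudaMemcpyDeviceToHost, res->stream));
  CUDACHECK(cudaEventRecord(res->events[step % NCCL_STEPS], res->stream));
  return ncclSuccess;
}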
static ncclResult_t shmRecvProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { static ncclResult_t shmRecvProxyConnect(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
struct shmProxyInfo* proxyInfo; struct shmProxyInfo* proxyInfo;
NCCLCHECK(ncclCalloc(&proxyInfo, 1)); NCCLCHECK(ncclCalloc(&proxyInfo, 1));
if (reqSize != sizeof(struct shmProxyInfo)) return ncclInternalError; if (reqSize != sizeof(struct shmProxyInfo)) return ncclInternalError;
memcpy(proxyInfo, reqBuff, reqSize); memcpy(proxyInfo, reqBuff, reqSize);
NCCLCHECK(ncclCudaCalloc(&proxyInfo->devFifo, comm->buffSizes[NCCL_PROTO_SIMPLE])); NCCLCHECK(ncclCudaCalloc(&proxyInfo->devFifo, proxyState->buffSizes[NCCL_PROTO_SIMPLE]));
NCCLCHECK(ncclCudaHostCalloc(&proxyInfo->ceRecvMem, 1)); NCCLCHECK(ncclCudaHostCalloc(&proxyInfo->ceRecvMem, 1));
CUDACHECK(cudaStreamCreateWithFlags(&proxyInfo->stream, cudaStreamNonBlocking)); CUDACHECK(cudaStreamCreateWithFlags(&proxyInfo->stream, cudaStreamNonBlocking));
for (int i=0; i<NCCL_STEPS; i++) { for (int i=0; i<NCCL_STEPS; i++) {
@ -250,12 +252,12 @@ static ncclResult_t shmRecvProxyConnect(struct ncclProxyConnection* connection,
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t shmSendProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { static ncclResult_t shmSendProxyFree(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState) {
struct shmProxyInfo* resources = (struct shmProxyInfo*)connection->transportResources; struct shmProxyInfo* resources = (struct shmProxyInfo*)connection->transportResources;
if (resources) { if (resources) {
CUDACHECK(cudaStreamDestroy(resources->stream)); CUDACHECK(cudaStreamDestroy(resources->stream));
CUDACHECK(cudaFree(resources->devFifo)); NCCLCHECK(ncclCudaFree(resources->devFifo));
NCCLCHECK(ncclCudaHostFree(resources->ceRecvMem)); NCCLCHECK(ncclCudaHostFree(resources->ceRecvMem));
for (int i=0; i<NCCL_STEPS; i++) { for (int i=0; i<NCCL_STEPS; i++) {
CUDACHECK(cudaEventDestroy(resources->events[i])); CUDACHECK(cudaEventDestroy(resources->events[i]));
@ -265,12 +267,12 @@ static ncclResult_t shmSendProxyFree(struct ncclProxyConnection* connection, str
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t shmRecvProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { static ncclResult_t shmRecvProxyFree(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState) {
struct shmProxyInfo* resources = (struct shmProxyInfo*)connection->transportResources; struct shmProxyInfo* resources = (struct shmProxyInfo*)connection->transportResources;
if (resources) { if (resources) {
CUDACHECK(cudaStreamDestroy(resources->stream)); CUDACHECK(cudaStreamDestroy(resources->stream));
CUDACHECK(cudaFree(resources->devFifo)); NCCLCHECK(ncclCudaFree(resources->devFifo));
NCCLCHECK(ncclCudaHostFree(resources->ceRecvMem)); NCCLCHECK(ncclCudaHostFree(resources->ceRecvMem));
for (int i=0; i<NCCL_STEPS; i++) { for (int i=0; i<NCCL_STEPS; i++) {
CUDACHECK(cudaEventDestroy(resources->events[i])); CUDACHECK(cudaEventDestroy(resources->events[i]));
@ -280,7 +282,7 @@ static ncclResult_t shmRecvProxyFree(struct ncclProxyConnection* connection, str
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t shmSendProxyProgress(struct ncclComm* comm, struct ncclProxyArgs* args) { static ncclResult_t shmSendProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args) {
if (args->state == ncclProxyOpReady) { if (args->state == ncclProxyOpReady) {
for (int s=0; s<args->nsubs; s++) { for (int s=0; s<args->nsubs; s++) {
struct ncclProxySubArgs* sub = args->subs+s; struct ncclProxySubArgs* sub = args->subs+s;
@ -294,7 +296,7 @@ static ncclResult_t shmSendProxyProgress(struct ncclComm* comm, struct ncclProxy
args->idle = 1; args->idle = 1;
if (args->state == ncclProxyOpProgress) { if (args->state == ncclProxyOpProgress) {
int p = args->protocol; int p = args->protocol;
int stepSize = comm->buffSizes[p] / NCCL_STEPS; int stepSize = proxyState->buffSizes[p] / NCCL_STEPS;
for (int s=0; s<args->nsubs; s++) { for (int s=0; s<args->nsubs; s++) {
struct ncclProxySubArgs* sub = args->subs+s; struct ncclProxySubArgs* sub = args->subs+s;
struct shmProxyInfo* resources = (struct shmProxyInfo*) (sub->connection->transportResources); struct shmProxyInfo* resources = (struct shmProxyInfo*) (sub->connection->transportResources);
@ -339,7 +341,7 @@ static ncclResult_t shmSendProxyProgress(struct ncclComm* comm, struct ncclProxy
return ncclSuccess; return ncclSuccess;
} }
static ncclResult_t shmRecvProxyProgress(struct ncclComm* comm, struct ncclProxyArgs* args) { static ncclResult_t shmRecvProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args) {
if (args->state == ncclProxyOpReady) { if (args->state == ncclProxyOpReady) {
for (int s=0; s<args->nsubs; s++) { for (int s=0; s<args->nsubs; s++) {
struct ncclProxySubArgs* sub = args->subs+s; struct ncclProxySubArgs* sub = args->subs+s;
@ -353,7 +355,7 @@ static ncclResult_t shmRecvProxyProgress(struct ncclComm* comm, struct ncclProxy
args->idle = 1; args->idle = 1;
if (args->state == ncclProxyOpProgress) { if (args->state == ncclProxyOpProgress) {
int p = args->protocol; int p = args->protocol;
int stepSize = comm->buffSizes[p] / NCCL_STEPS; int stepSize = proxyState->buffSizes[p] / NCCL_STEPS;
for (int s=0; s<args->nsubs; s++) { for (int s=0; s<args->nsubs; s++) {
struct ncclProxySubArgs* sub = args->subs+s; struct ncclProxySubArgs* sub = args->subs+s;
struct shmProxyInfo* resources = (struct shmProxyInfo*) (sub->connection->transportResources); struct shmProxyInfo* resources = (struct shmProxyInfo*) (sub->connection->transportResources);