diff --git a/ext-net/dummy/plugin.c b/ext-net/dummy/plugin.c index 67d7d88..dcf0a23 100644 --- a/ext-net/dummy/plugin.c +++ b/ext-net/dummy/plugin.c @@ -16,9 +16,11 @@ __hidden ncclResult_t pluginPtrSupport(int dev, int* supportedTypes) { return nc __hidden ncclResult_t pluginListen(int dev, void* handle, void** listenComm) { return ncclInternalError; } __hidden ncclResult_t pluginConnect(int dev, void* handle, void** sendComm) { return ncclInternalError; } __hidden ncclResult_t pluginAccept(void* listenComm, void** recvComm) { return ncclInternalError; } -__hidden ncclResult_t pluginIsend(void* sendComm, void* data, int size, int type, void** request) { return ncclInternalError; } -__hidden ncclResult_t pluginIrecv(void* recvComm, void* data, int size, int type, void** request) { return ncclInternalError; } -__hidden ncclResult_t pluginFlush(void* recvComm, void* data, int size) { return ncclInternalError; } +__hidden ncclResult_t pluginRegMr(void* collComm, void* data, int size, int type, void** mhandle) { return ncclInternalError; } +__hidden ncclResult_t pluginDeregMr(void* collComm, void* mhandle) { return ncclInternalError;} +__hidden ncclResult_t pluginIsend(void* sendComm, void* data, int size, void* mhandle, void** request) { return ncclInternalError; } +__hidden ncclResult_t pluginIrecv(void* recvComm, void* data, int size, void* mhandle, void** request) { return ncclInternalError; } +__hidden ncclResult_t pluginFlush(void* recvComm, void* data, int size, void* mhandle) { return ncclInternalError; } __hidden ncclResult_t pluginTest(void* request, int* done, int* size) { return ncclInternalError; } __hidden ncclResult_t pluginCloseSend(void* sendComm) { return ncclInternalError; } __hidden ncclResult_t pluginCloseRecv(void* recvComm) { return ncclInternalError; } @@ -33,6 +35,8 @@ ncclNet_t NCCL_PLUGIN_SYMBOL = { pluginListen, pluginConnect, pluginAccept, + pluginRegMr, + pluginDeregMr, pluginIsend, pluginIrecv, pluginFlush, @@ -41,3 +45,36 @@ ncclNet_t NCCL_PLUGIN_SYMBOL = { pluginCloseRecv, pluginCloseListen }; + +__hidden ncclResult_t pluginCollNetInit(ncclDebugLogger_t logFunction) { return ncclSuccess; } +__hidden ncclResult_t pluginCollNetDevices(int* ndev) { *ndev = 0; return ncclSuccess; } +__hidden ncclResult_t pluginCollNetPciPath(int dev, char** path) { return ncclInternalError; } +__hidden ncclResult_t pluginCollNetPtrSupport(int dev, int* supportedTypes) { return ncclInternalError; } +__hidden ncclResult_t pluginCollNetListen(int dev, void* handle, void** listenComm) { return ncclInternalError; } +__hidden ncclResult_t pluginCollNetConnect(void* handles[], int nranks, int rank, void* listenComm, void** collComm) { return ncclInternalError; } +__hidden ncclResult_t pluginCollNetReduceSupport(ncclDataType_t dataType, ncclRedOp_t redOp, int* supported) { return ncclInternalError; } +__hidden ncclResult_t pluginCollNetRegMr(void* collComm, void* data, int size, int type, void** mhandle) { return ncclInternalError; } +__hidden ncclResult_t pluginCollNetDeregMr(void* collComm, void* mhandle) { return ncclInternalError;} +__hidden ncclResult_t pluginCollNetIallreduce(void* collComm, void* sendData, void* recvData, int count, ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request) { return ncclInternalError; } +__hidden ncclResult_t pluginCollNetFlush(void* collComm, void* data, int size, void* mhandle) { return ncclInternalError; } +__hidden ncclResult_t pluginCollNetTest(void* request, int* done, int* size) { return 
ncclInternalError; } +__hidden ncclResult_t pluginCollNetCloseColl(void* collComm) { return ncclInternalError; } +__hidden ncclResult_t pluginCollNetCloseListen(void* listenComm) { return ncclInternalError; } + +ncclCollNet_t NCCL_COLLNET_PLUGIN_SYMBOL = { + "Dummy", + pluginCollNetInit, + pluginCollNetDevices, + pluginCollNetPciPath, + pluginCollNetPtrSupport, + pluginCollNetListen, + pluginCollNetConnect, + pluginCollNetReduceSupport, + pluginCollNetRegMr, + pluginCollNetDeregMr, + pluginCollNetIallreduce, + pluginCollNetFlush, + pluginCollNetTest, + pluginCollNetCloseColl, + pluginCollNetCloseListen +}; diff --git a/makefiles/common.mk b/makefiles/common.mk index d4c353b..64f8d2d 100644 --- a/makefiles/common.mk +++ b/makefiles/common.mk @@ -55,7 +55,7 @@ CXXFLAGS := -DCUDA_MAJOR=$(CUDA_MAJOR) -DCUDA_MINOR=$(CUDA_MINOR) -fPIC -fvisi # Maxrregcount needs to be set accordingly to NCCL_MAX_NTHREADS (otherwise it will cause kernel launch errors) # 512 : 120, 640 : 96, 768 : 80, 1024 : 60 # We would not have to set this if we used __launch_bounds__, but this only works on kernels, not on functions. -NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 -Xptxas -maxrregcount=96 -Xfatbin -compress-all +NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 --expt-extended-lambda -Xptxas -maxrregcount=96 -Xfatbin -compress-all # Use addprefix so that we can specify more than one path NVLDFLAGS := -L${CUDA_LIB} -lcudart -lrt diff --git a/makefiles/version.mk b/makefiles/version.mk index 87c2017..833ab99 100644 --- a/makefiles/version.mk +++ b/makefiles/version.mk @@ -1,6 +1,6 @@ ##### version NCCL_MAJOR := 2 -NCCL_MINOR := 9 -NCCL_PATCH := 9 +NCCL_MINOR := 10 +NCCL_PATCH := 3 NCCL_SUFFIX := PKG_REVISION := 1 diff --git a/pkg/debian/libnccl-dev.install.in b/pkg/debian/libnccl-dev.install.in index 9cedf3e..13eca26 100644 --- a/pkg/debian/libnccl-dev.install.in +++ b/pkg/debian/libnccl-dev.install.in @@ -1,3 +1,4 @@ include/nccl.h /usr/include +include/nccl_net.h /usr/include lib/libnccl.so /usr/lib/${pkg:MultiArch} lib/libnccl_static.a /usr/lib/${pkg:MultiArch} diff --git a/pkg/redhat/nccl.spec.in b/pkg/redhat/nccl.spec.in index f1cce5c..8e5aed6 100644 --- a/pkg/redhat/nccl.spec.in +++ b/pkg/redhat/nccl.spec.in @@ -7,7 +7,7 @@ Group: Development/Libraries License: BSD URL: http://developer.nvidia.com/nccl Source0: nccl_${nccl:Major}.${nccl:Minor}.${nccl:Patch}${nccl:Suffix}-${pkg:Revision}+cuda${cuda:Major}.${cuda:Minor}_${pkg:Arch}.txz -Prereq: /sbin/ldconfig +Requires(pre,preun): /sbin/ldconfig %description NCCL (pronounced "Nickel") is a stand-alone library of standard collective @@ -46,6 +46,7 @@ ln -s libnccl.so.${nccl:Major}.${nccl:Minor}.${nccl:Patch} $RPM_BUILD_ROOT/%{_li # devel install -m 755 -d $RPM_BUILD_ROOT/%{_includedir} install -m 644 include/nccl.h $RPM_BUILD_ROOT/%{_includedir} +install -m 644 include/nccl_net.h $RPM_BUILD_ROOT/%{_includedir} ln -s libnccl.so.${nccl:Major} $RPM_BUILD_ROOT/%{_libdir}/libnccl.so # static @@ -64,6 +65,7 @@ rm -rf $RPM_BUILD_ROOT %doc LICENSE.txt %defattr(-,root,root,-) %{_includedir}/nccl.h +%{_includedir}/nccl_net.h %{_libdir}/libnccl.so %files static diff --git a/src/bootstrap.cc b/src/bootstrap.cc index ff58c42..021a49a 100644 --- a/src/bootstrap.cc +++ b/src/bootstrap.cc @@ -43,7 +43,7 @@ ncclResult_t bootstrapNetInit() { } char line[SOCKET_NAME_MAXLEN+MAX_IF_NAME_SIZE+2]; sprintf(line, " %s:", bootstrapNetIfName); - socketToString(&bootstrapNetIfAddr.sa, line+strlen(line)); + socketToString(&bootstrapNetIfAddr, line+strlen(line)); 
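/*
 * Editor's illustrative sketch -- not part of the upstream diff. The
 * bootstrap.cc changes in this patch thread a `union socketAddress*` through
 * the socket helpers (bootstrapNetSend/Recv, socketSend/Recv, remoteAlloc, ...)
 * alongside the file descriptor, so each helper knows which peer a socket
 * belongs to (useful, for example, when formatting diagnostics). Below is a
 * minimal stand-alone version of that pattern; `peerAddr` and `sendWithPeer()`
 * are invented names, not NCCL APIs.
 */
#include <stdio.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

union peerAddr {
  struct sockaddr sa;
  struct sockaddr_in sin;
  struct sockaddr_in6 sin6;
};

/* Send a full buffer; on failure, report which peer the socket points at. */
static int sendWithPeer(int fd, const union peerAddr* addr, const void* data, size_t size) {
  const char* bytes = (const char*)data;
  size_t off = 0;
  while (off < size) {
    ssize_t n = send(fd, bytes + off, size - off, 0);
    if (n <= 0) {
      char host[INET6_ADDRSTRLEN] = "?";
      if (addr->sa.sa_family == AF_INET)
        inet_ntop(AF_INET, &addr->sin.sin_addr, host, sizeof host);
      else if (addr->sa.sa_family == AF_INET6)
        inet_ntop(AF_INET6, &addr->sin6.sin6_addr, host, sizeof host);
      fprintf(stderr, "send to %s failed after %zu/%zu bytes\n", host, off, size);
      return -1;
    }
    off += (size_t)n;
  }
  return 0;
}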
INFO(NCCL_INIT, "Bootstrap : Using%s", line); bootstrapNetInitDone = 1; } @@ -55,27 +55,27 @@ ncclResult_t bootstrapNetInit() { /* Socket Interface Selection type */ enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 }; -static ncclResult_t bootstrapNetAccept(int listenFd, int* recvFd) { - struct sockaddr_in sockaddr; - socklen_t socklen = sizeof(struct sockaddr_in); - SYSCHECKVAL(accept(listenFd, (struct sockaddr*)&sockaddr, &socklen), "accept", *recvFd); +static ncclResult_t bootstrapNetAccept(int listenFd, int* recvFd, union socketAddress *addr) { + struct sockaddr *saddr = &addr->sa; + socklen_t socklen = sizeof(union socketAddress); + SYSCHECKVAL(accept(listenFd, saddr, &socklen), "accept", *recvFd); return ncclSuccess; } // Additional sync functions -static ncclResult_t bootstrapNetSend(int fd, void* data, int size) { - NCCLCHECK(socketSend(fd, &size, sizeof(int))); - NCCLCHECK(socketSend(fd, data, size)); +static ncclResult_t bootstrapNetSend(int fd, union socketAddress *addr, void* data, int size) { + NCCLCHECK(socketSend(fd, addr, &size, sizeof(int))); + NCCLCHECK(socketSend(fd, addr, data, size)); return ncclSuccess; } -static ncclResult_t bootstrapNetRecv(int fd, void* data, int size) { +static ncclResult_t bootstrapNetRecv(int fd, union socketAddress *addr, void* data, int size) { int recvSize; - NCCLCHECK(socketRecv(fd, &recvSize, sizeof(int))); + NCCLCHECK(socketRecv(fd, addr, &recvSize, sizeof(int))); if (recvSize > size) { WARN("Message truncated : received %d bytes instead of %d", recvSize, size); return ncclInternalError; } - NCCLCHECK(socketRecv(fd, data, std::min(recvSize, size))); + NCCLCHECK(socketRecv(fd, addr, data, std::min(recvSize, size))); return ncclSuccess; } @@ -111,8 +111,9 @@ static void *bootstrapRoot(void* args) { /* Receive addresses from all ranks */ do { int tmpFd; - NCCLCHECKGOTO(bootstrapNetAccept(listenFd, &tmpFd), res, out); - NCCLCHECKGOTO(bootstrapNetRecv(tmpFd, &info, sizeof(info)), res, out); + union socketAddress addr; + NCCLCHECKGOTO(bootstrapNetAccept(listenFd, &tmpFd, &addr), res, out); + NCCLCHECKGOTO(bootstrapNetRecv(tmpFd, &addr, &info, sizeof(info)), res, out); close(tmpFd); if (c == 0) { @@ -145,7 +146,7 @@ static void *bootstrapRoot(void* args) { int next = (r+1) % nranks; int tmpSendFd; NCCLCHECKGOTO(connectAddress(&tmpSendFd, rankAddressesRoot+r), res, out); - NCCLCHECKGOTO(bootstrapNetSend(tmpSendFd, rankAddresses+next, sizeof(union socketAddress)), res, out); + NCCLCHECKGOTO(bootstrapNetSend(tmpSendFd, rankAddressesRoot+r, rankAddresses+next, sizeof(union socketAddress)), res, out); close(tmpSendFd); } TRACE(NCCL_INIT, "SENT OUT ALL %d HANDLES", nranks); @@ -193,6 +194,7 @@ struct unexConn { int peer; int tag; int fd; + union socketAddress addr; struct unexConn* next; }; @@ -207,6 +209,7 @@ struct extState { int extListenFd; int extRingRecvFd; int extRingSendFd; + union socketAddress extRingRecvAddr, extRingSendAddr; union socketAddress* peerCommAddresses; union socketAddress* peerAllocAddresses; struct unexConn* unexpectedConnections; @@ -221,9 +224,9 @@ struct extState { #define MAX_SEGMENTS 128 -static ncclResult_t remoteAlloc(void** ptr, int fd) { +static ncclResult_t remoteAlloc(void** ptr, int fd, union socketAddress *addr) { size_t size; - NCCLCHECK(socketRecv(fd, &size, sizeof(size_t))); + NCCLCHECK(socketRecv(fd, addr, &size, sizeof(size_t))); cudaIpcMemHandle_t devIpc; NCCLCHECK(ncclCudaCalloc((char**)ptr, size)); cudaError_t res = cudaIpcGetMemHandle(&devIpc, *ptr); @@ -233,9 +236,9 @@ static ncclResult_t 
remoteAlloc(void** ptr, int fd) { CUDACHECK(res); } // The CUDA IPC - NCCLCHECK(socketSend(fd, &devIpc, sizeof(cudaIpcMemHandle_t))); + NCCLCHECK(socketSend(fd, addr, &devIpc, sizeof(cudaIpcMemHandle_t))); // And the direct pointer - NCCLCHECK(socketSend(fd, ptr, sizeof(void*))); + NCCLCHECK(socketSend(fd, addr, ptr, sizeof(void*))); return ncclSuccess; } @@ -267,11 +270,12 @@ void* ncclRemoteMemAllocationService(void* args) { } if (pollfds[MAX_SEGMENTS].revents) { int s = 0; + union socketAddress addr; while (segments[s] != NULL && s < MAX_SEGMENTS) s++; - if (bootstrapNetAccept(pollfds[MAX_SEGMENTS].fd, &pollfds[s].fd) != ncclSuccess) { + if (bootstrapNetAccept(pollfds[MAX_SEGMENTS].fd, &pollfds[s].fd, &addr) != ncclSuccess) { pollfds[s].fd = -1; } else { - if (s == MAX_SEGMENTS || (remoteAlloc(segments+s, pollfds[s].fd) != ncclSuccess)) { + if (s == MAX_SEGMENTS || (remoteAlloc(segments+s, pollfds[s].fd, &addr) != ncclSuccess)) { WARN("[Rem Allocator] Allocation failed (segment %d, fd %d)", s, pollfds[s].fd); close(pollfds[s].fd); pollfds[s].fd = -1; @@ -306,10 +310,11 @@ ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id, int fd; ncclResult_t res; *id = -1; - NCCLCHECK(connectAddress(&fd, state->peerAllocAddresses+rank)); - NCCLCHECKGOTO(socketSend(fd, &size, sizeof(size_t)), res, end); - NCCLCHECKGOTO(socketRecv(fd, ipc, sizeof(cudaIpcMemHandle_t)), res, end); - NCCLCHECKGOTO(socketRecv(fd, ptr, sizeof(void*)), res, end); + union socketAddress *addr = state->peerAllocAddresses+rank; + NCCLCHECK(connectAddress(&fd, addr)); + NCCLCHECKGOTO(socketSend(fd, addr, &size, sizeof(size_t)), res, end); + NCCLCHECKGOTO(socketRecv(fd, addr, ipc, sizeof(cudaIpcMemHandle_t)), res, end); + NCCLCHECKGOTO(socketRecv(fd, addr, ptr, sizeof(void*)), res, end); *id = fd; end: return res; @@ -353,19 +358,19 @@ ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commS // send info on my listening socket to root union socketAddress* rootAddr = (union socketAddress*)id; NCCLCHECK(connectAddress(&tmpSendFd, rootAddr)); - NCCLCHECK(bootstrapNetSend(tmpSendFd, &info, sizeof(info))); + NCCLCHECK(bootstrapNetSend(tmpSendFd, rootAddr, &info, sizeof(info))); close(tmpSendFd); // get info on my "next" rank in the bootstrap ring from root - union socketAddress extAddressNext; - NCCLCHECK(bootstrapNetAccept(extListenFdRoot, &tmpRecvFd)); - NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &extAddressNext, sizeof(extAddressNext))); + union socketAddress addr; + NCCLCHECK(bootstrapNetAccept(extListenFdRoot, &tmpRecvFd, &addr)); + NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, &state->extRingSendAddr, sizeof(state->extRingSendAddr))); close(tmpRecvFd); close(extListenFdRoot); - NCCLCHECK(connectAddress(&state->extRingSendFd, &extAddressNext)); + NCCLCHECK(connectAddress(&state->extRingSendFd, &state->extRingSendAddr)); // Accept the connect request from the previous rank in the AllGather ring - NCCLCHECK(bootstrapNetAccept(state->extListenFd, &state->extRingRecvFd)); + NCCLCHECK(bootstrapNetAccept(state->extListenFd, &state->extRingRecvFd, &state->extRingRecvAddr)); // AllGather all listen handlers NCCLCHECK(ncclCalloc(&state->peerCommAddresses, nranks)); @@ -403,9 +408,9 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) { size_t sslice = (rank - i + nranks) % nranks; // Send slice to the right - NCCLCHECK(bootstrapNetSend(state->extRingSendFd, data+sslice*size, size)); + NCCLCHECK(bootstrapNetSend(state->extRingSendFd, &state->extRingSendAddr, 
data+sslice*size, size)); // Recv slice from the left - NCCLCHECK(bootstrapNetRecv(state->extRingRecvFd, data+rslice*size, size)); + NCCLCHECK(bootstrapNetRecv(state->extRingRecvFd, &state->extRingRecvAddr, data+rslice*size, size)); } TRACE(NCCL_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size); @@ -415,21 +420,44 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) { ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size) { struct extState* state = (struct extState*)commState; int tmpSendFd; - NCCLCHECK(connectAddress(&tmpSendFd, state->peerCommAddresses+peer)); - NCCLCHECK(bootstrapNetSend(tmpSendFd, &state->rank, sizeof(int))); - NCCLCHECK(bootstrapNetSend(tmpSendFd, &tag, sizeof(int))); - NCCLCHECK(bootstrapNetSend(tmpSendFd, data, size)); + union socketAddress *addr = state->peerCommAddresses+peer; + NCCLCHECK(connectAddress(&tmpSendFd, addr)); + NCCLCHECK(bootstrapNetSend(tmpSendFd, addr, &state->rank, sizeof(int))); + NCCLCHECK(bootstrapNetSend(tmpSendFd, addr, &tag, sizeof(int))); + NCCLCHECK(bootstrapNetSend(tmpSendFd, addr, data, size)); close(tmpSendFd); return ncclSuccess; } -ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int tag, int fd) { +ncclResult_t bootstrapBarrier(void* commState, int *ranks, int tag, int rank, int nranks) { + if (nranks == 1) return ncclSuccess; + TRACE(NCCL_INIT, "rank %d nranks %d tag %x - ENTER", rank, nranks, tag); + + /* Simple intra process barrier + * + * Based on the dissemination algorithm by Debra Hensgen, Raphael Finkel, and Udi Manber, + * "Two Algorithms for Barrier Synchronization," International Journal of Parallel Programming, 17(1):1-17, 1988" + */ + int data[1]; + for (int mask=1; mask<nranks; mask<<=1) { + int src = (rank - mask + nranks) % nranks; + int dst = (rank + mask) % nranks; + NCCLCHECK(bootstrapSend(commState, ranks[dst], tag, data, sizeof(data))); + NCCLCHECK(bootstrapRecv(commState, ranks[src], tag, data, sizeof(data))); + } + + TRACE(NCCL_INIT, "rank %d nranks %d tag %x - DONE", rank, nranks, tag); + return ncclSuccess; +} + +ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int tag, int fd, union socketAddress *addr) { // Keep track of unexpected connections to be searched later struct unexConn* unex; NCCLCHECK(ncclCalloc(&unex, 1)); unex->peer = peer; unex->tag = tag; unex->fd = fd; + unex->addr = *addr; // Enqueue struct unexConn* list = state->unexpectedConnections; @@ -442,7 +470,7 @@ ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int tag, int fd return ncclSuccess; } -int unexpectedDequeue(struct extState* state, int peer, int tag) { +int unexpectedDequeue(struct extState* state, int peer, int tag, union socketAddress *addr) { struct unexConn* elem = state->unexpectedConnections; struct unexConn* prev = NULL; while (elem) { @@ -453,6 +481,7 @@ int unexpectedDequeue(struct extState* state, int peer, int tag) { prev->next = elem->next; } int fd = elem->fd; + *addr = elem->addr; free(elem); return fd; } @@ -467,27 +496,29 @@ ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int s struct extState* state = (struct extState*)commState; int tmpRecvFd; + union socketAddress addr; // Search unexpected connections first - if ((tmpRecvFd = unexpectedDequeue(state, peer, tag)) != -1) { - NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size)); + if ((tmpRecvFd = unexpectedDequeue(state, peer, tag, &addr)) != -1) { + NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, ((char*)data), size)); close(tmpRecvFd); return ncclSuccess; } // Then look for new connections while (1) { - NCCLCHECK(bootstrapNetAccept(state->extListenFd, &tmpRecvFd)); + union socketAddress addr; + NCCLCHECK(bootstrapNetAccept(state->extListenFd, &tmpRecvFd, &addr)); int newPeer, newTag; - NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &newPeer, sizeof(int))); - NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &newTag, sizeof(int))); + NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, &newPeer, sizeof(int))); + NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, &newTag, sizeof(int))); if (newPeer == peer && newTag == tag) { -
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size)); + NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, ((char*)data), size)); close(tmpRecvFd); return ncclSuccess; } // Unexpected connection. Save for later. - NCCLCHECK(unexpectedEnqueue(state, newPeer, newTag, tmpRecvFd)); + NCCLCHECK(unexpectedEnqueue(state, newPeer, newTag, tmpRecvFd, &addr)); } } diff --git a/src/collectives/device/Makefile b/src/collectives/device/Makefile index 3796fb1..ead98ec 100644 --- a/src/collectives/device/Makefile +++ b/src/collectives/device/Makefile @@ -1,5 +1,5 @@ # -# Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. # # See LICENSE.txt for license information # @@ -32,7 +32,7 @@ all_deps: $(DEPENDFILES) $(RULESFILE) : @printf "Generating %-35s > %s\n" rules $@ @mkdir -p $(OBJDIR) - @./gen_rules.sh $(OBJDIR) > $@ + @CUDA_MAJOR=${CUDA_MAJOR} CUDA_MINOR=${CUDA_MINOR} ./gen_rules.sh $(OBJDIR) > $@ -include $(RULESFILE) diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h index e057dc8..3d781af 100644 --- a/src/collectives/device/all_gather.h +++ b/src/collectives/device/all_gather.h @@ -5,204 +5,95 @@ ************************************************************************/ #include "devcomm.h" -#include "primitives.h" #include "collectives.h" +#include "primitives.h" -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads-WARP_SIZE; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS); - const int chunkSize = stepSize * ALLGATHER_CHUNKSTEPS; - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*(ssize_t)chunkSize; - const ssize_t size = args->coll.count; +namespace { + template + __device__ void runRing(ncclWorkElem *args) { + const int tid = threadIdx.x; + const int nthreads = args->nThreads; + const int bid = args->coll.bid; + const int nChannels = args->coll.nChannels; + ncclRing *ring = &ncclShmem.channel.ring; + const int *ringRanks = ring->devUserRanks; + const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLGATHER_CHUNKSTEPS : 1)); + // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. 
+ const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2); + const int nranks = ncclShmem.comm.nRanks; + const ssize_t loopSize = nChannels*int(chunkSize); + const ssize_t size = args->coll.count; - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; + T *inputBuf = (T*)args->sendbuff; + T *outputBuf = (T*)args->recvbuff; + Primitives, 1, Proto> + prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf); - ncclPrimitives - prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0); - - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels)); - ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T)); - ssize_t chunkOffset = gridOffset + bid*realChunkSize; - - /////////////// begin AllGather steps /////////////// - ssize_t offset; - int nelem = min(realChunkSize, size-chunkOffset); - int rankDest; - - // step 0: push data to next GPU - rankDest = ring->devUserRanks[0]; - offset = chunkOffset + rankDest * size; - - if (thisInput + chunkOffset == thisOutput + offset) { // In place - prims.directSend(thisInput+chunkOffset, offset, nelem); - } else { - prims.directCopySend(thisInput+chunkOffset, thisOutput+offset, offset, nelem); - } - - // k-2 steps: copy to next GPU - for (int j=1; jdevUserRanks[nranks-j]; - offset = chunkOffset + rankDest * size; - - prims.directRecvCopySend(thisOutput+offset, offset, nelem); - } - - // Make final copy from buffer to dest. - rankDest = ring->devUserRanks[1]; - offset = chunkOffset + rankDest * size; - - // Final wait/copy. - prims.directRecv(thisOutput+offset, offset, nelem); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t realChunkSize; + if (Proto::Id == NCCL_PROTO_SIMPLE) { + realChunkSize = min(chunkSize, divUp(size-gridOffset,nChannels)); + realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T)); } - } -}; + else if (Proto::Id == NCCL_PROTO_LL) + realChunkSize = size-gridOffset < loopSize ? 
args->coll.lastChunkSize : chunkSize; + else if (Proto::Id == NCCL_PROTO_LL128) + realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128); + realChunkSize = int(realChunkSize); -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS); - ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T); - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->coll.count; + ssize_t chunkOffset = gridOffset + int(bid*realChunkSize); - ncclLLPrimitives LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm); + /////////////// begin AllGather steps /////////////// + ssize_t offset; + int nelem = min(realChunkSize, size-chunkOffset); + int rankDest; - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; + // step 0: push data to next GPU + rankDest = ringRanks[0]; + offset = chunkOffset + rankDest * size; - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - if (size-gridOffset < loopSize) { - chunkSize = args->coll.lastChunkSize; - } - ssize_t chunkOffset = gridOffset + bid*chunkSize; - - /////////////// begin AllGather steps /////////////// - ssize_t offset; - int nelem = min(chunkSize, size-chunkOffset); - int rankDest; - - // step 0: push data to next GPU - rankDest = ring->devUserRanks[0]; - offset = chunkOffset + rankDest * size; - - if (thisInput + chunkOffset == thisOutput + offset) { // In place - LLprims.send(thisInput+chunkOffset, nelem); - } else { - LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem); - } - - // k-2 steps: copy to next GPU - for (int j=1; jdevUserRanks[nranks-j]; - offset = chunkOffset + rankDest * size; - - LLprims.recvCopySend(thisOutput+offset, nelem); - } - - // step k-1: final store - rankDest = ring->devUserRanks[1]; - offset = chunkOffset + rankDest * size; - - LLprims.recv(thisOutput+offset, nelem); + if (inputBuf + chunkOffset == outputBuf + offset) { // In place + prims.directSend(chunkOffset, offset, nelem); + } else { + prims.directCopySend(chunkOffset, offset, offset, nelem); } - } -}; -#include "prims_ll128.h" -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS); - ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T)); - // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. 
- const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2; - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->coll.count; - - ncclLL128Primitives LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm); - - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize); - - ssize_t chunkOffset = gridOffset + bid*chunkSize; - - /////////////// begin AllGather steps /////////////// - ssize_t offset; - int nelem = min(chunkSize, size-chunkOffset); - int rankDest; - - // step 0: push data to next GPU - rankDest = ring->devUserRanks[0]; + // k-2 steps: copy to next GPU + for (int j=1; jdevUserRanks[nranks-j]; - offset = chunkOffset + rankDest * size; - - LLprims.recvCopySend(thisOutput+offset, nelem); - } - - // step k-1: final store - rankDest = ring->devUserRanks[1]; - offset = chunkOffset + rankDest * size; - - LLprims.recv(thisOutput+offset, nelem); + prims.directRecvCopySend(offset, offset, nelem); } + + // Make final copy from buffer to dest. + rankDest = ringRanks[1]; + offset = chunkOffset + rankDest * size; + + // Final wait/copy. + prims.directRecv(offset, nelem); } + } +} + +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + using Proto = ProtoSimple; + runRing(args); + } }; -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) {} +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + runRing(args); + } }; -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) {} +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + runRing(args); + } }; - diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index 88acb13..bd088e9 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -5,566 +5,384 @@ ************************************************************************/ #include "devcomm.h" -#include "primitives.h" #include "collectives.h" +#include "primitives.h" -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { +namespace { + template + __device__ void runRing(ncclWorkElem *args) { const int tid = threadIdx.x; - const int nthreads = args->nThreads-WARP_SIZE; + const int nthreads = args->nThreads; const int bid = args->coll.bid; const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS); - const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS; - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*(ssize_t)chunkSize; + ncclRing *ring = &ncclShmem.channel.ring; + int ringIx = ring->index; + const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? 
ALLREDUCE_CHUNKSTEPS : 1)); + const int nranks = ncclShmem.comm.nRanks; + const ssize_t loopSize = nChannels*nranks*chunkSize; const ssize_t size = args->coll.count; - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; + int minChunkSize; + if (Proto::Id == NCCL_PROTO_LL) + minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T)); + if (Proto::Id == NCCL_PROTO_LL128) { + // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. + minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2; + } - ncclPrimitives - prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0); + Primitives, 1, Proto> prims + (tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) { - ssize_t realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*nChannels)); - ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T)); - ssize_t chunkOffset = gridOffset + bid*nranks*realChunkSize; + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t realChunkSize; + if (Proto::Id == NCCL_PROTO_SIMPLE) { + realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*nranks)); + realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T)); + } + else + realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize); + realChunkSize = int(realChunkSize); + + auto calcOffset = [&]__device__(int chunk)->ssize_t { + if (Proto::Id == NCCL_PROTO_SIMPLE) + return gridOffset + bid*nranks*realChunkSize + chunk*realChunkSize; + else + return gridOffset + (chunk*nChannels + bid)*realChunkSize; + }; + auto modRanks = [&]__device__(int r)->int { + return r - (r >= nranks ? nranks : 0); + }; - /////////////// begin AllReduce steps /////////////// ssize_t offset; int nelem; int chunk; // step 0: push data to next GPU - chunk = ring->devUserRanks[nranks-1]; - offset = chunkOffset + chunk * realChunkSize; + chunk = modRanks(ringIx + nranks-1); + offset = calcOffset(chunk); nelem = min(realChunkSize, size-offset); - - prims.send(thisInput+offset, nelem); + prims.send(offset, nelem); // k-2 steps: reduce and copy to next GPU for (int j=2; jdevUserRanks[nranks-j]; - offset = chunkOffset + chunk * realChunkSize; + chunk = modRanks(ringIx + nranks-j); + offset = calcOffset(chunk); nelem = min(realChunkSize, size-offset); - - prims.recvReduceSend(thisInput+offset, nelem); + prims.recvReduceSend(offset, nelem); } // step k-1: reduce this buffer and data, which will produce the final // result that we store in this data and push to the next GPU - chunk = ring->devUserRanks[0]; - offset = chunkOffset + chunk * realChunkSize; + chunk = ringIx + 0; + offset = calcOffset(chunk); nelem = min(realChunkSize, size-offset); - - prims.directRecvReduceCopySend(thisInput+offset, thisOutput+offset, offset, nelem); + prims.directRecvReduceCopySend(offset, offset, offset, nelem, /*postOp=*/true); // k-2 steps: copy to next GPU for (int j=1; jdevUserRanks[nranks-j]; - offset = chunkOffset + chunk * realChunkSize; + chunk = modRanks(ringIx + nranks-j); + offset = calcOffset(chunk); nelem = min(realChunkSize, size-offset); - - prims.directRecvCopySend(thisOutput+offset, offset, nelem); + prims.directRecvCopySend(offset, offset, nelem); } // Make final copy from buffer to dest. 
- chunk = ring->devUserRanks[1]; - offset = chunkOffset + chunk * realChunkSize; + chunk = modRanks(ringIx + 1); + offset = calcOffset(chunk); nelem = min(realChunkSize, size-offset); - - // Final wait/copy. - prims.directRecv(thisOutput+offset, offset, nelem); + prims.directRecv(offset, nelem); } } -}; -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { + template + __device__ void runTreeUpDown(ncclWorkElem *args) { const int tid = threadIdx.x; - const int nthreads = args->nThreads-2*WARP_SIZE; + const int nthreads = args->nThreads; const int bid = args->coll.bid; const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclTree* tree = &channel->tree; - const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS); - int chunkSize = args->coll.lastChunkSize; - const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T); - const ssize_t loopSize = nChannels*chunkSize; + ncclTree *tree = &ncclShmem.channel.tree; + ssize_t chunkSize = int( + Proto::Id == NCCL_PROTO_SIMPLE ? args->coll.lastChunkSize + /* LL & LL128 */ : Proto::calcBytePerStep()/sizeof(T)); + const ssize_t minChunkSize = int( + Proto::Id == NCCL_PROTO_SIMPLE ? (nthreads-2*WARP_SIZE)*8*(sizeof(uint64_t)/sizeof(T)) + /* LL & LL128 */ : nthreads*(Proto::calcBytePerGrain()/sizeof(T))); + const ssize_t loopSize = int(nChannels*chunkSize); const ssize_t size = args->coll.count; - if (loopSize > size) { - chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize; - } + if (loopSize > size) + chunkSize = divUp((int)size, int(nChannels*minChunkSize))*int(minChunkSize); - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - -#if 1 - if (tid < nthreads+WARP_SIZE) { - // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) - ncclPrimitives - prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - // Up - ssize_t offset = gridOffset + bid*chunkSize; - int nelem = min(chunkSize, size-offset); - if (tree->up == -1) { - prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem); - } else if (tree->down[0] == -1) { - prims.send(thisInput+offset, nelem); - } else { - prims.recvReduceSend(thisInput+offset, nelem); + { // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) + Primitives, /*Direct=*/0, Proto> prims + (tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff); + if (tree->up == -1) { + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*int(chunkSize); + int nelem = min(chunkSize, size-offset); + prims.recvReduceCopy(offset, offset, nelem, /*postOp=*/true); + } + } + else if (tree->down[0] == -1) { + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*int(chunkSize); + int nelem = min(chunkSize, size-offset); + prims.send(offset, nelem); + } + } + else { + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*int(chunkSize); + int nelem = min(chunkSize, size-offset); + prims.recvReduceSend(offset, nelem); } } } - if (tid < nthreads+WARP_SIZE) { - // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + 
local) - ncclPrimitives - prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - // Down - ssize_t offset = gridOffset + bid*chunkSize; - int nelem = min(chunkSize, size-offset); - if (tree->up == -1) { - prims.directSend(thisOutput+offset, offset, nelem); - } else if (tree->down[0] == -1) { - prims.directRecv(thisOutput+offset, offset, nelem); - } else { - prims.directRecvCopySend(thisOutput+offset, offset, nelem); + { // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) + Primitives, /*Direct=*/1, Proto> prims + (tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff); + if (tree->up == -1) { + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*int(chunkSize); + int nelem = min(chunkSize, size-offset); + prims.directSendFromOutput(offset, offset, nelem); + } + } + else if (tree->down[0] == -1) { + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*int(chunkSize); + int nelem = min(chunkSize, size-offset); + prims.directRecv(offset, nelem); + } + } + else { + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*int(chunkSize); + int nelem = min(chunkSize, size-offset); + prims.directRecvCopySend(offset, offset, nelem); } } } -#else - int nthreadsSplit = nthreads/2; - if (nthreadsSplit >= 256) nthreadsSplit += 64; + } + + template + __device__ void runTreeSplit(ncclWorkElem *args) { + const int tid = threadIdx.x; + const int nthreads = args->nThreads; + const int bid = args->coll.bid; + const int nChannels = args->coll.nChannels; + ncclTree *tree = &ncclShmem.channel.tree; + ssize_t chunkSize = int( + Proto::Id != NCCL_PROTO_LL ? args->coll.lastChunkSize + : Proto::calcBytePerStep()/sizeof(T)); + const ssize_t minChunkSize = int( + Proto::Id == NCCL_PROTO_SIMPLE ? (nthreads - 2*WARP_SIZE)*8*(sizeof(uint64_t)/sizeof(T)) : + Proto::Id == NCCL_PROTO_LL ? nthreads*(Proto::calcBytePerGrain()/sizeof(T)) + /* LL128 */ : nthreads*(Proto::calcBytePerGrain()/sizeof(T))/8); + const ssize_t loopSize = int(nChannels*chunkSize); + const ssize_t size = args->coll.count; + + int nthreadsSplit; + if (Proto::Id == NCCL_PROTO_SIMPLE) { + nthreadsSplit = nthreads/2; + if (nthreadsSplit >= 256) nthreadsSplit += 64; + } else { // LL & LL128 + // Receiving from up to 3 sources is more compute intensive than sending + // to 3 dests. Use 70% for reduce and 30% for bcast. + nthreadsSplit = (nthreads*7/(10*WARP_SIZE))*WARP_SIZE; + } + + if (loopSize > size) + chunkSize = divUp((int)size, nChannels*int(minChunkSize))*int(minChunkSize); + if (tree->up == -1) { - if (tid < nthreads+WARP_SIZE) { - // ReduceAndBroadcast : max number of recv is 3, max number of send is 3 - ncclPrimitives - prims(tid, nthreads, tree->down, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0); + // Reduce and broadcast. 
Max number of recv is 3, max number of send is 3 + Primitives, /*Direct=*/1, Proto> + prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*int(chunkSize); + int nelem = min(chunkSize, size-offset); + prims.directRecvReduceCopySend(offset, offset, offset, nelem, /*doPost=*/true); + } + } + else if (tid < nthreadsSplit) { + /* Reduce up. Max number of recv is 3, max number of send is 1 (binary tree + local). + * Why Direct=1???? + * Answer: Because despite not performing any direct operations, the ctor + * must assume Direct so that it can exchange direct pointers with remote ctors + * that are Direct, otherwise it hangs. A cleaner solution would be to seperate + * into DirectRecv and DirectSend capabilities, this ctor would have both=0, + * but the ctor above for tree roots would be DirectRecv=0 DirectSend=1. + */ + Primitives, /*Direct=*/1, Proto> + prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, 0*Proto::MaxGroupWidth); + if (tree->down[0] == -1) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*chunkSize; + ssize_t offset = gridOffset + bid*int(chunkSize); int nelem = min(chunkSize, size-offset); - prims.directRecvReduceCopySend(thisInput+offset, thisOutput+offset, offset, nelem); + prims.send(offset, nelem); } } - } else { - if (tid < nthreadsSplit + WARP_SIZE) { - // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) - ncclPrimitives - prims(tid, nthreadsSplit, tree->down, &tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0); + else { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - // Up - ssize_t offset = gridOffset + bid*chunkSize; + ssize_t offset = gridOffset + bid*int(chunkSize); int nelem = min(chunkSize, size-offset); - if (tree->down[0] == -1) { - prims.send(thisInput+offset, nelem); - } else { - prims.recvReduceSend(thisInput+offset, nelem); - } - } - } else { - // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) - ncclPrimitives - prims(tid-nthreadsSplit-WARP_SIZE, nthreads-nthreadsSplit, &tree->up, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 2); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - // Down - ssize_t offset = gridOffset + bid*chunkSize; - int nelem = min(chunkSize, size-offset); - if (tree->down[0] == -1) { - prims.directRecv(thisOutput+offset, offset, nelem); - } else { - prims.directRecvCopySend(thisOutput+offset, offset, nelem); - } + prims.recvReduceSend(offset, nelem); } } } -#endif + else { + // Broadcast down. 
Max number of recv is 1, max number of send is 3 (binary tree + local) + Primitives, /*Direct=*/1, Proto> + prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth); + if (tree->down[0] == -1) { + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*int(chunkSize); + int nelem = min(chunkSize, size-offset); + prims.directRecv(offset, nelem); + } + } + else { + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*int(chunkSize); + int nelem = min(chunkSize, size-offset); + prims.directRecvCopySend(offset, offset, nelem); + } + } + } + } +} + +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + using Proto = ProtoSimple; + runRing(args); } }; -template -class ncclFunction { -#define COLLNET_COPY_THREADS 96 - public: - __device__ void run(struct ncclWorkElem* args) { +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + #if CUDART_VERSION >= 11020 && CUDART_VERSION < 11040 && __CUDA_ARCH__ >= 800 + runTreeUpDown>(args); + #else + runTreeSplit>(args); + #endif + } +}; + +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + static constexpr int COLLNET_COPY_THREADS = 96; const int tid = threadIdx.x; - //const int nthreads = args->nThreads-3*WARP_SIZE; const int bid = args->coll.bid; const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclDirect* tree = &channel->collTree; - const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS); - int chunkSize = args->coll.lastChunkSize; + struct ncclDirect* tree = &ncclShmem.channel.collTree; + const ssize_t chunkSize = int(args->coll.lastChunkSize); const ssize_t size = args->coll.count; const ssize_t loopSize = nChannels*tree->nHeads*chunkSize; - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - const int hasUp = (tree->up[0] >= 0) ? 1 : 0; const int hasDn = (tree->down[0] >= 0) ? 1 : 0; - const int nThreadsScatter = (hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 3*COLLNET_COPY_THREADS : 0; - const int nThreadsGather = (hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 2*COLLNET_COPY_THREADS : 0; - const int nThreadsBcast = (hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 0 : 2*COLLNET_COPY_THREADS; - // Gather does not need sync threads, sparing one more warp for reduce - const int nThreadsReduce = NCCL_SIMPLE_MAX_NTHREADS + WARP_SIZE - nThreadsScatter - nThreadsGather - nThreadsBcast; + const int nThreadsScatter = WARP_SIZE + ((hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 3*COLLNET_COPY_THREADS : 0); + const int nThreadsGather = ((hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 2*COLLNET_COPY_THREADS : 0); + const int nThreadsBcast = WARP_SIZE + ((hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 
0 : 2*COLLNET_COPY_THREADS); + const int nThreadsReduce = args->nThreads - nThreadsScatter - nThreadsGather - nThreadsBcast; const int tidStartBcast = nThreadsGather; - const int tidStartScatter = tidStartBcast + nThreadsBcast + WARP_SIZE; - const int tidStartReduce = tidStartScatter + nThreadsScatter + WARP_SIZE; + const int tidStartScatter = tidStartBcast + nThreadsBcast; + const int tidStartReduce = tidStartScatter + nThreadsScatter; + + using Proto = ProtoSimple<1, 1>; if (tid >= tidStartScatter && tid < tidStartReduce && hasUp) { // Scatter - ncclPrimitives - prims(tid-tidStartScatter, nThreadsScatter, NULL, tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 4); + Primitives, /*Direct=*/0, Proto> + prims(tid-tidStartScatter, nThreadsScatter, NULL, tree->up, args->sendbuff, args->recvbuff, 2*Proto::MaxGroupWidth); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*tree->nHeads*chunkSize; int nelem = min(tree->nHeads*chunkSize, size-offset); - prims.scatter(thisInput+offset, nelem, chunkSize, tree->headRank, tree->shift); + prims.scatter(offset, nelem, chunkSize, tree->headRank, tree->shift); } } else if (tid >= tidStartReduce && tree->out != -1) { - // Reduce, send to network - ncclPrimitives - prims(tid-tidStartReduce, nThreadsReduce, tree->down, &tree->out, NULL, stepSize, channel, comm, ncclShmem->ptrs, 6); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize; - int nelem = min(chunkSize, size-offset); - if (hasDn) { - prims.recvReduceSend(thisInput+offset, nelem); - } else { - prims.send(thisInput+offset, nelem); + if (hasDn) { + // Reduce, send to network + Primitives, /*Direct=*/0, Proto> + prims(tid-tidStartReduce, nThreadsReduce, tree->down, &tree->out, args->sendbuff, args->recvbuff, 3*Proto::MaxGroupWidth); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.recvReduceSend(offset, nelem); + } + } else { + // Directly send to network + Primitives, /*Direct=*/0, Proto> + prims(tid-tidStartReduce, nThreadsReduce, nullptr, &tree->out, args->sendbuff, args->recvbuff, 3*Proto::MaxGroupWidth); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.send(offset, nelem); } } } else if (tid < tidStartBcast && hasUp) { // Gather - ncclPrimitives - prims(tid, nThreadsGather, tree->up, NULL, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0); + Primitives, /*Direct=*/0, Proto> + prims(tid, nThreadsGather, tree->up, NULL, args->sendbuff, args->recvbuff, 0*Proto::MaxGroupWidth); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*tree->nHeads*chunkSize; int nelem = min(tree->nHeads*chunkSize, size-offset); - prims.gather(thisOutput+offset, nelem, chunkSize, tree->headRank, tree->shift); + prims.gather(offset, nelem, chunkSize, tree->headRank, tree->shift); } } else if (tid >= tidStartBcast && tid < tidStartScatter && tree->out != -1) { - // Recv from network, broadcast - ncclPrimitives - prims(tid-tidStartBcast, nThreadsBcast, &tree->out, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 2); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += 
loopSize) { - ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize; - int nelem = min(chunkSize, size-offset); - if (hasDn) { - prims.recvCopySend(thisOutput+offset, nelem); - } else { - prims.recv(thisOutput+offset, nelem); - } - } - } - } -}; - -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS); - ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T); - const ssize_t minChunkSize = nthreads * (sizeof(uint64_t)) / sizeof(T); - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*nranks*chunkSize; - const ssize_t size = args->coll.count; - - ncclLLPrimitives LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm); - - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize); - - /////////////// begin AllReduce steps /////////////// - ssize_t offset; - int nelem; - int chunk; - - // step 0: push data to next GPU - chunk = ring->devUserRanks[nranks-1]; - offset = gridOffset + (chunk*nChannels+bid) * chunkSize; - nelem = min(chunkSize, size-offset); - - LLprims.send(thisInput+offset, nelem); - - // k-2 steps: reduce and copy to next GPU - for (int j=2; jdevUserRanks[nranks-j]; - offset = gridOffset + (chunk*nChannels+bid) * chunkSize; - nelem = min(chunkSize, size-offset); - - LLprims.recvReduceSend(thisInput+offset, nelem); - } - - // step k-1: reduce this buffer and data, which will produce the final - // result that we store in this data and push to the next GPU - chunk = ring->devUserRanks[0]; - offset = gridOffset + (chunk*nChannels+bid) * chunkSize; - nelem = min(chunkSize, size-offset); - - LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem); - - // k-2 steps: copy to next GPU - for (int j=1; jdevUserRanks[nranks-j]; - offset = gridOffset + (chunk*nChannels+bid) * chunkSize; - nelem = min(chunkSize, size-offset); - - LLprims.recvCopySend(thisOutput+offset, nelem); - } - - // Make final copy from buffer to dest. - chunk = ring->devUserRanks[1]; - offset = gridOffset + (chunk*nChannels+bid) * chunkSize; - nelem = min(chunkSize, size-offset); - - // Here we need to copy from buffer to this output. 
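/*
 * Editor's illustrative sketch -- not part of the upstream diff. The LL ring
 * all-reduce being deleted here and the new runRing() above follow the same
 * per-chunk ring schedule: one send, nranks-2 recvReduceSend steps, one
 * recvReduceCopySend, nranks-2 recvCopySend steps, and a final recv. The
 * stand-alone helper below only prints that schedule, mirroring the
 * `chunk = modRanks(ringIx + nranks - j)` indexing; printRingSchedule is an
 * invented name, not an NCCL API.
 */
#include <stdio.h>

static void printRingSchedule(int rank, int nranks) {
  int k = nranks;
  printf("rank %d: send               chunk %d\n", rank, (rank + k - 1) % k);
  for (int j = 2; j <= k - 1; j++)
    printf("rank %d: recvReduceSend     chunk %d\n", rank, (rank + k - j) % k);
  printf("rank %d: recvReduceCopySend chunk %d\n", rank, rank % k);
  for (int j = 1; j <= k - 2; j++)
    printf("rank %d: recvCopySend       chunk %d\n", rank, (rank + k - j) % k);
  printf("rank %d: recv               chunk %d\n", rank, (rank + 1) % k);
}

int main(void) {
  for (int r = 0; r < 4; r++) printRingSchedule(r, 4);
  return 0;
}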
- LLprims.recv(thisOutput+offset, nelem); - } - } -}; - -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclTree* tree = &channel->tree; - const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS); - ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T); - const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T); - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->coll.count; - - if (loopSize > size) { - chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize; - } - - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - - do { - // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) - ncclLLPrimitives LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - // Up - ssize_t offset = gridOffset + bid*chunkSize; - int nelem = min(chunkSize, size-offset); - if (tree->up == -1) { - LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem); - } else if (tree->down[0] == -1) { - LLprims.send(thisInput+offset, nelem); - } else { - LLprims.recvReduceSend(thisInput+offset, nelem); - } - } - } while(0); - - do { - // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) - ncclLLPrimitives LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - // Down - ssize_t offset = gridOffset + bid*chunkSize; - int nelem = min(chunkSize, size-offset); - if (tree->up == -1) { - LLprims.send(thisOutput+offset, nelem); - } else if (tree->down[0] == -1) { - LLprims.recv(thisOutput+offset, nelem); - } else { - LLprims.recvCopySend(thisOutput+offset, nelem); - } - } - } while(0); - } -}; - -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { } -}; - -#include "prims_ll128.h" -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS); - ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T)); - // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. 
- const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2; - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*nranks*chunkSize; - const ssize_t size = args->coll.count; - - ncclLL128Primitives LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm); - - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize); - - /////////////// begin AllReduce steps /////////////// - ssize_t offset; - int nelem; - int chunk; - - // step 0: push data to next GPU - chunk = ring->devUserRanks[nranks-1]; - offset = gridOffset + (chunk*nChannels+bid) * chunkSize; - nelem = min(chunkSize, size-offset); - - LLprims.send(thisInput+offset, nelem); - - // k-2 steps: reduce and copy to next GPU - for (int j=2; jdevUserRanks[nranks-j]; - offset = gridOffset + (chunk*nChannels+bid) * chunkSize; - nelem = min(chunkSize, size-offset); - - LLprims.recvReduceSend(thisInput+offset, nelem); - } - - // step k-1: reduce this buffer and data, which will produce the final - // result that we store in this data and push to the next GPU - chunk = ring->devUserRanks[0]; - offset = gridOffset + (chunk*nChannels+bid) * chunkSize; - nelem = min(chunkSize, size-offset); - - LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem); - - // k-2 steps: copy to next GPU - for (int j=1; jdevUserRanks[nranks-j]; - offset = gridOffset + (chunk*nChannels+bid) * chunkSize; - nelem = min(chunkSize, size-offset); - - LLprims.recvCopySend(thisOutput+offset, nelem); - } - - // Make final copy from buffer to dest. - chunk = ring->devUserRanks[1]; - offset = gridOffset + (chunk*nChannels+bid) * chunkSize; - nelem = min(chunkSize, size-offset); - - // Here we need to copy from buffer to this output. 
- LLprims.recv(thisOutput+offset, nelem); - } - } -}; - -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclTree* tree = &channel->tree; - const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS); - ssize_t chunkSize = args->coll.lastChunkSize; - const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/8; - const ssize_t loopSize = nChannels*chunkSize; - int nthreadsSplit = NCCL_LL128_SPLIT(nthreads); - const ssize_t size = args->coll.count; - - if (loopSize > size) { - chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize; - } - - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - - if (tree->up == -1) { - // ReduceAndBroadcast : max number of recv is 3, max number of send is 3 - ncclLL128Primitives LLprims(tid, nthreads, tree->down, tree->down, stepSize, channel, comm); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*chunkSize; - int nelem = min(chunkSize, size-offset); - LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem); - } - } else { - if (tid < nthreadsSplit) { - // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) - ncclLL128Primitives LLprims(tid, nthreadsSplit, tree->down, &tree->up, stepSize, channel, comm); + if (hasDn) { + // Recv from network, broadcast + Primitives, /*Direct=*/0, Proto> + prims(tid-tidStartBcast, nThreadsBcast, &tree->out, tree->down, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - // Up - ssize_t offset = gridOffset + bid*chunkSize; + ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize; int nelem = min(chunkSize, size-offset); - if (tree->down[0] == -1) { - LLprims.send(thisInput+offset, nelem); - } else { - LLprims.recvReduceSend(thisInput+offset, nelem); - } + prims.recvCopySend(offset, nelem, /*postOp=*/true); } } else { - // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) - ncclLL128Primitives LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, stepSize, channel, comm); + // Recv from network (no post thread needed) + Primitives, /*Direct=*/0, Proto> + prims(tid-tidStartBcast, nThreadsBcast, &tree->out, nullptr, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - // Down - ssize_t offset = gridOffset + bid*chunkSize; + ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize; int nelem = min(chunkSize, size-offset); - if (tree->down[0] == -1) { - LLprims.recv(thisOutput+offset, nelem); - } else { - LLprims.recvCopySend(thisOutput+offset, nelem); - } + prims.recv(offset, nelem, /*postOp=*/true); } } } } }; -template -class ncclFunction { - public: -__device__ void run(struct ncclWorkElem* args) { } +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + runRing(args); + } +}; + +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { 
+ runTreeSplit(args); + } +}; + +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + runRing(args); + } +}; + +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + runTreeSplit(args); + } }; diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h index 72216ac..f867315 100644 --- a/src/collectives/device/broadcast.h +++ b/src/collectives/device/broadcast.h @@ -5,158 +5,78 @@ ************************************************************************/ #include "devcomm.h" -#include "primitives.h" #include "collectives.h" +#include "primitives.h" -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads-WARP_SIZE; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS); - const int chunkSize = stepSize * BROADCAST_CHUNKSTEPS; - const ssize_t loopSize = nChannels*(ssize_t)chunkSize; - const ssize_t size = args->coll.count; - const int rank = ring->devUserRanks[0]; - const int nextRank = ring->devUserRanks[1]; - const int root = args->coll.root; +namespace { + template + __device__ void runRing(ncclWorkElem *args) { + const int tid = threadIdx.x; + const int nthreads = args->nThreads; + const int bid = args->coll.bid; + const int nChannels = args->coll.nChannels; + ncclRing *ring = &ncclShmem.channel.ring; + const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? BROADCAST_CHUNKSTEPS : 1)); + const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))); + const ssize_t loopSize = nChannels*chunkSize; + const ssize_t size = args->coll.count; + const int rank = ring->devUserRanks[0]; + const int nextRank = ring->devUserRanks[1]; + const int root = args->coll.root; - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; + T *inputBuf = (T*)args->sendbuff; + T *outputBuf = (T*)args->recvbuff; + Primitives, 0, Proto> + prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf); - ncclPrimitives - prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t realChunkSize; + if (Proto::Id == NCCL_PROTO_SIMPLE) { + realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels)); + realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T)); + } + else if (Proto::Id == NCCL_PROTO_LL) + realChunkSize = size-gridOffset < loopSize ? 
args->coll.lastChunkSize : chunkSize; + else if (Proto::Id == NCCL_PROTO_LL128) + realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128); + realChunkSize = int(realChunkSize); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels)); - ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T)); - ssize_t offset = gridOffset + bid*realChunkSize; - int nelem = min(realChunkSize, size-offset); + ssize_t offset = gridOffset + int(bid*realChunkSize); + int nelem = min(realChunkSize, size-offset); - if (rank == root) { - if (thisInput == thisOutput) { - prims.send(thisInput+offset, nelem); - } else { - prims.copySend(thisInput+offset, thisOutput+offset, nelem); - } - } else if (nextRank == root) { - prims.recv(thisOutput+offset, nelem); + if (rank == root) { + if (inputBuf == outputBuf) { + prims.send(offset, nelem); } else { - prims.recvCopySend(thisOutput+offset, nelem); + prims.copySend(offset, offset, nelem); } + } else if (nextRank == root) { + prims.recv(offset, nelem); + } else { + prims.recvCopySend(offset, nelem); } } + } +} + +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + using Proto = ProtoSimple; + runRing(args); + } }; -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS); - ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T); - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->coll.count; - const int rank = ring->devUserRanks[0]; - const int nextRank = ring->devUserRanks[1]; - const int root = args->coll.root; - - ncclLLPrimitives LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm); - - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - if (size-gridOffset < loopSize) { - chunkSize = args->coll.lastChunkSize; - } - ssize_t offset = gridOffset + bid*chunkSize; - - int nelem = min(chunkSize, size-offset); - if (rank == root) { - if (thisInput == thisOutput) { - LLprims.send(thisInput+offset, nelem); - } else { - LLprims.copySend(thisInput + offset, thisOutput + offset, nelem); - } - } else if (nextRank == root) { - LLprims.recv(thisOutput + offset, nelem); - } else { - LLprims.recvCopySend(thisOutput + offset, nelem); - } - } - } +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + runRing(args); + } }; -#include "prims_ll128.h" -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS); - ssize_t chunkSize = 
stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T)); - const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T)); - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->coll.count; - const int rank = ring->devUserRanks[0]; - const int nextRank = ring->devUserRanks[1]; - const int root = args->coll.root; - - ncclLL128Primitives LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm); - - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize); - ssize_t offset = gridOffset + bid*chunkSize; - - int nelem = min(chunkSize, size-offset); - if (rank == root) { - if (thisInput == thisOutput) { - LLprims.send(thisInput+offset, nelem); - } else { - LLprims.copySend(thisInput + offset, thisOutput + offset, nelem); - } - } else if (nextRank == root) { - LLprims.recv(thisOutput + offset, nelem); - } else { - LLprims.recvCopySend(thisOutput + offset, nelem); - } - } - } -}; - -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) {} -}; - -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) {} +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + runRing(args); + } }; diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 2673a0a..f37995d 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -10,108 +10,152 @@ #include "collectives.h" #include "devcomm.h" - #if __CUDA_ARCH__ >= 800 #define COLL_UNROLL 8 -#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree #else #define COLL_UNROLL 4 -#define NCCL_MAX_DEV_ARITY NCCL_MAX_TREE_ARITY #endif -// Exit If Abort Barrier across CTA: make sure all threads exit consistently -// Each thread sets a predicate to true if abort == 1 -// all CTA's threads enter the barrier and do a popc on their predicates being True -// If any of the thread's predicate was True, all the threads call exit() -static inline __device__ void exitIfAbortBarrier(int abort) { +#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree + +__device__ inline bool barrierReduceAny(int bit) { uint32_t popc; - asm ("{"); - asm volatile (" .reg .pred barr_pred;"); - asm volatile (" setp.eq.u32 barr_pred,%0,1;" :: "r"(abort)); - asm volatile (" bar.red.popc.u32 %0, 0, barr_pred;" : "=r"(popc)); - asm ("}"); - if (popc) { asm volatile ("exit;"); } + asm ("{" + ".reg .pred barr_pred;" + "setp.eq.u32 barr_pred, %1, 1;" + "bar.red.popc.u32 %0, 0, barr_pred;" + "}" : "=r"(popc) : "r"(bit)); + return popc != 0; } -typedef void(*ncclKern_t)(struct ncclWorkElem* args); -extern __device__ ncclKern_t ncclFuncs[]; +template +__device__ int copyToShmem(T *dst, T const *src, int turn=0) { + static_assert(sizeof(uint64_t) <= alignof(T), "Uhoh"); + uint64_t *d = reinterpret_cast(dst); + uint64_t const *s = reinterpret_cast(src); + int t = threadIdx.x - turn; + if (t < 0) t += blockDim.x; + int n = sizeof(T)/sizeof(uint64_t); -static __device__ void load_parallel(void* dst, void* src, size_t size, int tid) { - int* d = (int*)dst; - int* s = (int*)src; - for (int o = tid; o < 
(size/sizeof(int)); o += blockDim.x) d[o] = s[o]; + int delta = (n + WARP_SIZE-1) & -WARP_SIZE; // round up to warp lane 0 + if (delta < blockDim.x) { + turn += delta; + if (turn >= blockDim.x) turn -= blockDim.x; + } + else + turn = 0; + + n -= t; + d += t; + s += t; + #pragma unroll + for (int i=0; i < divUp(sizeof(T), WARP_SIZE*sizeof(uint64_t)); i++) { + if (n > 0) { + *d = *s; + d += blockDim.x; + s += blockDim.x; + n -= blockDim.x; + } + } + return turn; } -static __device__ void load_coll(struct ncclWork* localWork, struct ncclWork *hostWork, struct ncclWork* workFifo, int tid, struct ncclDevComm* comm) { - load_parallel(localWork, workFifo, sizeof(struct ncclWork), tid); - // Check whether the last operation was aborted and make sure all threads exit - int abort = tid == 0 ? *(comm->abortFlag) : 0; - exitIfAbortBarrier(abort); - if (tid == 0) hostWork->elems[0].active = 0; -} - -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) {} +template +struct RunWorkElement { + __device__ void run(ncclWorkElem*) { + // Put NOT IMPLEMENTED behavior here. + } }; -struct ncclShmemPtrs { +template +struct RunWork { + __device__ void run(ncclWork *w) { + int tid = threadIdx.x; + #pragma unroll 1 + for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++) { + if (tid < w->elems[e].nThreads) + RunWorkElement().run(&w->elems[e]); + } + } +}; + +typedef void(*ncclKern_t)(); +extern __device__ ncclKern_t ncclFuncs[]; + +struct ncclShmemGroup { + ncclConnInfo *recvConns[NCCL_MAX_DIRECT_ARITY]; + ncclConnInfo *sendConns[NCCL_MAX_DIRECT_ARITY]; void* srcs[NCCL_MAX_DIRECT_ARITY+1]; void* dsts[NCCL_MAX_DIRECT_ARITY+1]; }; struct ncclShmemData { union { - volatile uint64_t data[NCCL_LL128_SHMEM_SIZE]; - struct ncclShmemPtrs ptrs[NCCL_MAX_GROUPS]; + uint64_t ll128warp[NCCL_LL128_MAX_NTHREADS/WARP_SIZE][NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE]; + struct ncclShmemGroup groups[NCCL_MAX_GROUPS]; }; - struct ncclWork localWork; + ncclDevComm comm; + ncclChannel channel; + ncclWork work; }; -extern __device__ struct ncclShmemData *ncclShmem; -template -__device__ void ncclKernel(struct ncclWorkElem first) { +extern __shared__ ncclShmemData ncclShmem; + +template +__device__ void ncclKernel(ncclWorkElem first) { int tid = threadIdx.x; int bid = blockIdx.x; - __shared__ struct ncclShmemData shmem; - ncclShmem = &shmem; - auto f = ncclFunction(); + int turn = copyToShmem(&ncclShmem.comm, first.comm); + // get address of channel without incurring indirect load from ncclDevCom::channels + ncclChannel *channel = &((ncclDevCommAndChannels*)first.comm)->channels[bid]; + turn = copyToShmem(&ncclShmem.channel, channel, turn); - struct ncclDevComm* comm = first.comm; - struct ncclChannel* channel = comm->channels+bid; - struct ncclWorkElem* w = NULL; + // To optimize for latency, (only) the first operation is passed as argument. 
+ if (bid == 0 && first.active != 0) { + turn = copyToShmem(&ncclShmem.work.elems[0], &first, turn); + if (tid == 0) ncclShmem.work.elems[1].active = 0; + } + __syncthreads(); // publish ncclShmem - /* To optimize for latency, (only) the first operation is passed as argument.*/ - if (bid == 0 && first.funcIndex != FUNC_INDEX_P2P) w = &first; + ncclWork *workFifoHost = ncclShmem.channel.workFifo; + ncclWork *workFifoDev = ncclShmem.channel.workFifoDev; + int workFifoIx = ncclShmem.channel.index; - while (1) { - if (w == NULL) { - w = shmem.localWork.elems; - __syncthreads(); - load_coll(&shmem.localWork, channel->workFifo+channel->index, channel->workFifoDev+channel->index, tid, comm); + if (bid == 0 && first.active != 0) + goto SkipLoadWork; + + while (true) { + copyToShmem(&ncclShmem.work, &workFifoDev[workFifoIx]); // turn no longer helps + { // Check whether the last operation was aborted and make sure all threads exit + int aborted = tid == 0 ? *ncclShmem.comm.abortFlag : 0; + if (barrierReduceAny(aborted)) // publish ncclShmem.work + break; + if (tid == 0) + workFifoHost[workFifoIx].elems[0].active = 0; } - if (tid < w->nThreads) { - if (w->funcIndex == FINDEX) { - f.run(w); - } else { - ncclFuncs[w->funcIndex](w); - } - } - if (tid == 0) channel->index = (channel->index+1) % NCCL_MAX_OPS; - if (w->active == 2) { - return; - } - w = NULL; + + SkipLoadWork: + workFifoIx = (workFifoIx + 1)%NCCL_MAX_OPS; + if (tid == 0) + channel->index = workFifoIx; // write back to real channel, not shmem shadow + + if (ncclShmem.work.elems[0].funcIndex == FnIndex) + RunWork().run(&ncclShmem.work); + else + ncclFuncs[ncclShmem.work.elems[0].funcIndex](); + + if (ncclShmem.work.elems[0].active == 2) + break; + __syncthreads(); } } // Only generate kernels for SUM #if NCCL_OP == 0 #define IMPL_COLL_KERN(func, algo, proto, redop, type, fIndex) \ -__global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem first) { \ - ncclKernel, type, COLL_UNROLL, fIndex>(first); \ +__global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(ncclWorkElem first) { \ + ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex>(first); \ } #else #define IMPL_COLL_KERN(func, algo, proto, redop, type, fInded) @@ -119,9 +163,8 @@ __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkEl // Examples : AllReduce, RING, LL, Sum, uint8 #define IMPL_COLL_FUNC(func, algo, proto, redop, type) \ -__device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args) { \ - auto f = ncclFunction, type, COLL_UNROLL>(); \ - f.run(args); \ +__device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)() { \ + RunWork, NCCL_ALGO_##algo, NCCL_PROTO_##proto>().run(&ncclShmem.work); \ } // Only generate inline kernels for LL @@ -154,6 +197,8 @@ __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkEl #define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, float, ncclFloat32) #elif NCCL_TYPE == 8 #define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, double, ncclFloat64) +#elif NCCL_TYPE == 9 && defined(__CUDA_BF16_TYPES_EXIST__) +#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, __nv_bfloat16, ncclBfloat16) #endif // Reduction define all functions @@ -165,6 +210,8 @@ __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkEl #define IMPL_COLL_R(func) IMPL_COLL2(func, Min); #elif NCCL_OP == 3 #define IMPL_COLL_R(func) IMPL_COLL2(func, Max); +#elif NCCL_OP == 4 +#define IMPL_COLL_R(func) IMPL_COLL2(func, Avg); #endif #if 
NCCL_OP == 0 && NCCL_TYPE == 0 diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h index ff466a0..c90988c 100644 --- a/src/collectives/device/common_kernel.h +++ b/src/collectives/device/common_kernel.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ @@ -24,10 +24,19 @@ inline __device__ void loadPtr(void** ptr, T* &v) { typedef uint64_t PackType; +template +struct FuncTraits /*{ + __device__ static Fn make(); + __device__ static T preOp(Fn, T); + __device__ static T postOp(Fn, T); +}*/; + // unpack x and y to elements of type T and apply FUNC to each element template struct MULTI { - __device__ PackType operator()(const PackType x, const PackType y) const; + __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const; + __device__ PackType preOp(FUNC fn, PackType x) const; + __device__ PackType postOp(FUNC fn, PackType x) const; }; template @@ -41,17 +50,39 @@ struct MULTI { }; }; - __device__ PackType operator()(const PackType x, const PackType y) const { + __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { converter cx, cy, cr; cx.storage = x; cy.storage = y; // for char, we do these as vector ops - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); + cr.a = fn(cx.a, cy.a); + cr.b = fn(cx.b, cy.b); return cr.storage; } + __device__ PackType preOp(FUNC fn, PackType x) const { + union { + PackType pack; + int8_t elt[8]; + } u; + u.pack = x; + #pragma unroll + for (int i=0; i < 8; i++) + u.elt[i] = FuncTraits().preOp(fn, u.elt[i]); + return u.pack; + } + __device__ PackType postOp(FUNC fn, PackType x) const { + union { + PackType pack; + int8_t elt[8]; + } u; + u.pack = x; + #pragma unroll + for (int i=0; i < 8; i++) + u.elt[i] = FuncTraits().postOp(fn, u.elt[i]); + return u.pack; + } }; template @@ -65,17 +96,39 @@ struct MULTI { }; }; - __device__ PackType operator()(const PackType x, const PackType y) const { + __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { converter cx, cy, cr; cx.storage = x; cy.storage = y; // for char, we do these as vector ops - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); + cr.a = fn(cx.a, cy.a); + cr.b = fn(cx.b, cy.b); return cr.storage; } + __device__ PackType preOp(FUNC fn, PackType x) const { + union { + PackType pack; + uint8_t elt[8]; + } u; + u.pack = x; + #pragma unroll + for (int i=0; i < 8; i++) + u.elt[i] = FuncTraits().preOp(fn, u.elt[i]); + return u.pack; + } + __device__ PackType postOp(FUNC fn, PackType x) const { + union { + PackType pack; + uint8_t elt[8]; + } u; + u.pack = x; + #pragma unroll + for (int i=0; i < 8; i++) + u.elt[i] = FuncTraits().postOp(fn, u.elt[i]); + return u.pack; + } }; template @@ -89,16 +142,36 @@ struct MULTI { }; }; - __device__ PackType operator()(const PackType x, const PackType y) const { + __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { converter cx, cy, cr; cx.storage = x; cy.storage = y; - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); + cr.a = fn(cx.a, cy.a); + cr.b = fn(cx.b, cy.b); return cr.storage; } + __device__ PackType preOp(FUNC fn, PackType x) const { + union { + PackType pack; + int32_t elt[2]; + } u; + u.pack = x; + 
u.elt[0] = FuncTraits().preOp(fn, u.elt[0]); + u.elt[1] = FuncTraits().preOp(fn, u.elt[1]); + return u.pack; + } + __device__ PackType postOp(FUNC fn, PackType x) const { + union { + PackType pack; + int32_t elt[2]; + } u; + u.pack = x; + u.elt[0] = FuncTraits().postOp(fn, u.elt[0]); + u.elt[1] = FuncTraits().postOp(fn, u.elt[1]); + return u.pack; + } }; template @@ -112,16 +185,36 @@ struct MULTI { }; }; - __device__ PackType operator()(const PackType x, const PackType y) const { + __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { converter cx, cy, cr; cx.storage = x; cy.storage = y; - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); + cr.a = fn(cx.a, cy.a); + cr.b = fn(cx.b, cy.b); return cr.storage; } + __device__ PackType preOp(FUNC fn, PackType x) const { + union { + PackType pack; + uint32_t elt[2]; + } u; + u.pack = x; + u.elt[0] = FuncTraits().preOp(fn, u.elt[0]); + u.elt[1] = FuncTraits().preOp(fn, u.elt[1]); + return u.pack; + } + __device__ PackType postOp(FUNC fn, PackType x) const { + union { + PackType pack; + uint32_t elt[2]; + } u; + u.pack = x; + u.elt[0] = FuncTraits().postOp(fn, u.elt[0]); + u.elt[1] = FuncTraits().postOp(fn, u.elt[1]); + return u.pack; + } }; template @@ -129,22 +222,69 @@ struct MULTI { static_assert(sizeof(PackType) == 4 * sizeof(half), "PackType must be four times the size of half."); - struct PackHalf2 { - half2 a, b; + union Converter { + PackType pack; + half2 h2[2]; }; - - __device__ PackType operator()(const PackType x, const PackType y) const { - struct PackHalf2 cx, cy, cr; - cx = *(reinterpret_cast(&x)); - cy = *(reinterpret_cast(&y)); - - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); - - return *(reinterpret_cast(&cr)); + __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { + Converter cx, cy, cr; + cx.pack = x; + cy.pack = y; + cr.h2[0] = fn(cx.h2[0], cy.h2[0]); + cr.h2[1] = fn(cx.h2[1], cy.h2[1]); + return cr.pack; + } + __device__ PackType preOp(FUNC fn, PackType x) const { + Converter c; + c.pack = x; + c.h2[0] = FuncTraits().preOp(fn, c.h2[0]); + c.h2[1] = FuncTraits().preOp(fn, c.h2[1]); + return c.pack; + } + __device__ PackType postOp(FUNC fn, PackType x) const { + Converter c; + c.pack = x; + c.h2[0] = FuncTraits().postOp(fn, c.h2[0]); + c.h2[1] = FuncTraits().postOp(fn, c.h2[1]); + return c.pack; } }; +#if defined(__CUDA_BF16_TYPES_EXIST__) +template +struct MULTI { + static_assert(sizeof(PackType) == 4 * sizeof(__nv_bfloat16), + "PackType must be four times the size of __nv_bfloat16."); + + union Converter { + PackType pack; + __nv_bfloat162 h2[2]; + }; + __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { + Converter cx, cy, cr; + cx.pack = x; + cy.pack = y; + cr.h2[0] = fn(cx.h2[0], cy.h2[0]); + cr.h2[1] = fn(cx.h2[1], cy.h2[1]); + return cr.pack; + } + __device__ PackType preOp(FUNC fn, PackType x) const { + Converter c; + c.pack = x; + c.h2[0] = FuncTraits().preOp(fn, c.h2[0]); + c.h2[1] = FuncTraits().preOp(fn, c.h2[1]); + return c.pack; + } + __device__ PackType postOp(FUNC fn, PackType x) const { + Converter c; + c.pack = x; + c.h2[0] = FuncTraits().postOp(fn, c.h2[0]); + c.h2[1] = FuncTraits().postOp(fn, c.h2[1]); + return c.pack; + } +}; +#endif + template struct MULTI { static_assert(sizeof(PackType) == 2 * sizeof(float), @@ -156,46 +296,120 @@ struct MULTI { }; }; - __device__ PackType operator()(const PackType x, const PackType y) const { + __device__ PackType operator()(FUNC fn, const PackType x, const 
PackType y) const { converter cx, cy, cr; cx.storage = x; cy.storage = y; - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); + cr.a = fn(cx.a, cy.a); + cr.b = fn(cx.b, cy.b); return cr.storage; } + __device__ PackType preOp(FUNC fn, PackType x) const { + union { + PackType pack; + float elt[2]; + } u; + u.pack = x; + u.elt[0] = FuncTraits().preOp(fn, u.elt[0]); + u.elt[1] = FuncTraits().preOp(fn, u.elt[1]); + return u.pack; + } + __device__ PackType postOp(FUNC fn, PackType x) const { + union { + PackType pack; + float elt[2]; + } u; + u.pack = x; + u.elt[0] = FuncTraits().postOp(fn, u.elt[0]); + u.elt[1] = FuncTraits().postOp(fn, u.elt[1]); + return u.pack; + } }; template struct MULTI { static_assert(sizeof(PackType) == sizeof(double), "PackType must be the same size as double."); - __device__ PackType operator()(const PackType x, const PackType y) const { - double rv = FUNC()(__longlong_as_double(x), __longlong_as_double(y)); + __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { + double rv = fn(__longlong_as_double(x), __longlong_as_double(y)); return __double_as_longlong(rv); } + __device__ PackType preOp(FUNC fn, PackType x) const { + union { + PackType pack; + double elt; + } u; + u.pack = x; + u.elt = FuncTraits().preOp(fn, u.elt); + return u.pack; + } + __device__ PackType postOp(FUNC fn, PackType x) const { + union { + PackType pack; + double elt; + } u; + u.pack = x; + u.elt = FuncTraits().postOp(fn, u.elt); + return u.pack; + } }; template struct MULTI { static_assert(sizeof(PackType) == sizeof(uint64_t), "PackType must be the same size as uint64_t."); - __device__ PackType operator()(const PackType x, const PackType y) const { - uint64_t rv = FUNC()(x, y); + __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { + uint64_t rv = fn(x, y); return rv; } + __device__ PackType preOp(FUNC fn, PackType x) const { + union { + PackType pack; + uint64_t elt; + } u; + u.pack = x; + u.elt = FuncTraits().preOp(fn, u.elt); + return u.pack; + } + __device__ PackType postOp(FUNC fn, PackType x) const { + union { + PackType pack; + uint64_t elt; + } u; + u.pack = x; + u.elt = FuncTraits().postOp(fn, u.elt); + return u.pack; + } }; template struct MULTI { static_assert(sizeof(PackType) == sizeof(int64_t), "PackType must be the same size as int64_t."); - __device__ PackType operator()(const PackType x, const PackType y) const { - int64_t rv = FUNC()((int64_t)x, (int64_t)y); + __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { + int64_t rv = fn((int64_t)x, (int64_t)y); return rv; } + __device__ PackType preOp(FUNC fn, PackType x) const { + union { + PackType pack; + int64_t elt; + } u; + u.pack = x; + u.elt = FuncTraits().preOp(fn, u.elt); + return u.pack; + } + __device__ PackType postOp(FUNC fn, PackType x) const { + union { + PackType pack; + int64_t elt; + } u; + u.pack = x; + u.elt = FuncTraits().postOp(fn, u.elt); + return u.pack; + } }; template inline __device__ @@ -234,13 +448,35 @@ void vStore(volatile half* ptr, const half val) { } #endif +#if defined(__CUDA_BF16_TYPES_EXIST__) +template<> inline __device__ +__nv_bfloat16 vFetch<__nv_bfloat16>(const volatile __nv_bfloat16* ptr) { + __nv_bfloat16 r; + r = ((__nv_bfloat16*)ptr)[0]; + return r; +} + +template<> inline __device__ +void vStore<__nv_bfloat16>(volatile __nv_bfloat16* ptr, const __nv_bfloat16 val) { + ((__nv_bfloat16*)ptr)[0] = val; +} +#endif + typedef ulong2 Pack128; template struct MULTI128 { - __device__ void 
operator()(Pack128& x, Pack128& y) { - x.x = MULTI()(x.x, y.x); - x.y = MULTI()(x.y, y.y); + __device__ void operator()(FUNC fn, Pack128& x, Pack128 const& y) const { + x.x = MULTI()(fn, x.x, y.x); + x.y = MULTI()(fn, x.y, y.y); + } + __device__ void preOp(FUNC fn, Pack128 &x) const { + x.x = MULTI().preOp(fn, x.x); + x.y = MULTI().preOp(fn, x.y); + } + __device__ void postOp(FUNC fn, Pack128 &x) const { + x.x = MULTI().postOp(fn, x.x); + x.y = MULTI().postOp(fn, x.y); } }; @@ -253,7 +489,8 @@ inline __device__ void Store128(Pack128* p, Pack128& v) { template __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t, - int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem) { + FUNC fn, bool preOpSrc0, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem + ) { const int inc = nw * UNROLL * WARP_SIZE; int offset = w * UNROLL * WARP_SIZE + t; @@ -266,22 +503,30 @@ __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const T vals[UNROLL]; // Load and reduce for (int u = 0; u < UNROLL; ++u) vals[u] = vFetch(srcs[0]+u*WARP_SIZE); + if (preOpSrc0) { + for (int u = 0; u < UNROLL; ++u) vals[u] = FuncTraits().preOp(fn, vals[u]); + } #pragma unroll for (int i=1; i().postOp(fn, vals[u]); + } + // Store #pragma unroll for (int i = 0; i < MINDSTS; i++) { @@ -301,7 +546,8 @@ __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const template __device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, const int t, - int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack) { + FUNC fn, bool preOpSrc0, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack + ) { const int inc = nw * UNROLL * WARP_SIZE; int offset = w * UNROLL * WARP_SIZE + t; @@ -314,22 +560,30 @@ __device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, c Pack128 vals[UNROLL]; // Load and reduce for (int u = 0; u < UNROLL; ++u) Fetch128(vals[u], srcs[0]+u*WARP_SIZE); + if (preOpSrc0) { + for (int u = 0; u < UNROLL; ++u) MULTI128().preOp(fn, vals[u]); + } #pragma unroll for (int i=1; i()(vals[u], vals2[u]); + for (int u = 0; u < UNROLL; ++u) MULTI128()(fn, vals[u], vals2[u]); } #pragma unroll for (int i=MINSRCS; i()(vals[u], vals2[u]); + for (int u = 0; u < UNROLL; ++u) MULTI128()(fn, vals[u], vals2[u]); } } + if (postOp) { + #pragma unroll + for (int u = 0; u < UNROLL; ++u) MULTI128().postOp(fn, vals[u]); + } + // Store #pragma unroll for (int i = 0; i < MINDSTS; i++) { @@ -353,9 +607,9 @@ __device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(Pack128); } #define PACKELEMS (sizeof(Pack128) / sizeof(T)) template -__device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads, - int nsrcs, const T** srcs, int ndsts, T** dsts, - int N) { +__device__ __forceinline__ void ReduceOrCopyMulti( + const int tid, const int nthreads, FUNC fn, bool preOpSrc0, bool postOp, int nsrcs, const T** srcs, int ndsts, T** dsts, int N + ) { int Nrem = N; if (Nrem <= 0) return; @@ -381,7 +635,8 @@ __device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthre int Npack = (Nrem / (PACKELEMS*UNROLL*WARP_SIZE)) * (UNROLL*WARP_SIZE); // round down int Nelem = Npack * PACKELEMS; - ReduceCopy128bMulti(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack); + ReduceCopy128bMulti + (w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack); Nrem -= Nelem; if (Nrem == 0) 
return; @@ -391,7 +646,8 @@ __device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthre Npack = Nrem / PACKELEMS; Nelem = Npack * PACKELEMS; - ReduceCopy128bMulti(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack); + ReduceCopy128bMulti + (w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack); Nrem -= Nelem; if (Nrem == 0) return; @@ -401,14 +657,16 @@ __device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthre // unrolled, by-type (mostly for unaligned buffers) int Nelem = (Nrem / (UNROLL*PACKELEMS/2*WARP_SIZE)) * (UNROLL*PACKELEMS/2*WARP_SIZE); // round down - ReduceCopyMulti(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nelem); + ReduceCopyMulti + (w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Nelem); Nrem -= Nelem; if (Nrem == 0) return; offset += Nelem; // no unroll, by type. Should finish what's remaining. - ReduceCopyMulti(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nrem); + ReduceCopyMulti + (w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Nrem); } #endif // COMMON_KERNEL_H_ diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu index 553a882..15d7a6e 100644 --- a/src/collectives/device/functions.cu +++ b/src/collectives/device/functions.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ @@ -8,7 +8,7 @@ #include "collectives.h" #include "common.h" -__device__ struct ncclShmemData* ncclShmem; +__shared__ ncclShmemData ncclShmem; #define NCCL_FUNC5(func, algo, redop, type) \ NCCL_FUNC_NAME(func, algo, LL, redop, type), \ @@ -20,6 +20,31 @@ __device__ struct ncclShmemData* ncclShmem; NCCL_FUNC5(func, RING, redop, type), \ NCCL_FUNC5(func, COLLNET, redop, type) +#if defined(__CUDA_BF16_TYPES_EXIST__) +// Must be consistent with ncclDataType_t +#define NCCL_FUNCS3A(func, redop) \ + NCCL_FUNC4(func, redop, int8_t), \ + NCCL_FUNC4(func, redop, uint8_t), \ + NCCL_FUNC4(func, redop, int32_t), \ + NCCL_FUNC4(func, redop, uint32_t), \ + NCCL_FUNC4(func, redop, int64_t), \ + NCCL_FUNC4(func, redop, uint64_t), \ + NCCL_FUNC4(func, redop, half), \ + NCCL_FUNC4(func, redop, float), \ + NCCL_FUNC4(func, redop, double), \ + NCCL_FUNC4(func, redop, __nv_bfloat16) +#define NCCL_FUNCS3B(func, redop) \ + NCCL_FUNC4(func, redop, int8_t), \ + NCCL_FUNC4(func, redop, int8_t), \ + NCCL_FUNC4(func, redop, int8_t), \ + NCCL_FUNC4(func, redop, int8_t), \ + NCCL_FUNC4(func, redop, int8_t), \ + NCCL_FUNC4(func, redop, int8_t), \ + NCCL_FUNC4(func, redop, int8_t), \ + NCCL_FUNC4(func, redop, int8_t), \ + NCCL_FUNC4(func, redop, int8_t), \ + NCCL_FUNC4(func, redop, int8_t) +#else // Must be consistent with ncclDataType_t #define NCCL_FUNCS3A(func, redop) \ NCCL_FUNC4(func, redop, int8_t), \ @@ -41,17 +66,21 @@ __device__ struct ncclShmemData* ncclShmem; NCCL_FUNC4(func, redop, int8_t), \ NCCL_FUNC4(func, redop, int8_t), \ NCCL_FUNC4(func, redop, int8_t) +#endif // Must be consistent with ncclRedOp_t #define NCCL_FUNCS2A(func) \ NCCL_FUNCS3A(func, Sum ), \ NCCL_FUNCS3A(func, Prod), \ NCCL_FUNCS3A(func, Max ), \ - NCCL_FUNCS3A(func, Min ) + NCCL_FUNCS3A(func, Min ), \ + NCCL_FUNCS3A(func, Avg) + #define NCCL_FUNCS2B(func) \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ 
NCCL_FUNCS3B(func, Sum), \ + NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum) // Must be consistent with ncclFunc_t diff --git a/src/collectives/device/gen_rules.sh b/src/collectives/device/gen_rules.sh index 97dc0ae..e99dc61 100755 --- a/src/collectives/device/gen_rules.sh +++ b/src/collectives/device/gen_rules.sh @@ -1,19 +1,26 @@ #!/bin/bash # -# Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. # # See LICENSE.txt for license information # dir=$1 +datatypes="i8 u8 i32 u32 i64 u64 f16 f32 f64" +if [ "$CUDA_MAJOR" -ge 11 ] +then + datatypes+=" bf16" +fi + targets="GENOBJS := \\\\\n" for base in sendrecv all_reduce all_gather broadcast reduce reduce_scatter; do opn=0 - for op in sum prod min max; do + for op in sum prod min max avg; do dtn=0 - for dt in i8 u8 i32 u32 i64 u64 f16 f32 f64; do + # Order must match that of the ncclDataType_t enum + for dt in ${datatypes}; do echo "${dir}/${base}_${op}_${dt}.o : ${base}.cu ${dir}/${base}.dep" echo " @printf \"Compiling %-35s > %s\\\\n\" ${base}.cu ${dir}/${base}_${op}_${dt}.o" echo " mkdir -p ${dir}" diff --git a/src/collectives/device/op128.h b/src/collectives/device/op128.h index 9405dc2..46fc8df 100644 --- a/src/collectives/device/op128.h +++ b/src/collectives/device/op128.h @@ -33,4 +33,36 @@ inline __device__ void storeShmem128(uint64_t* shmemAsmPtr, uint64_t v0, uint64_ :: "l"(v0), "l"(v1), "l"(shmemAsmPtr)); } +template +inline __device__ void loadShmemMisaligned128(T *ptr, uint64_t &v0, uint64_t &v1) { + union { + uint32_t tmp4[4]; + uint64_t tmp8[2]; + }; + if(sizeof(T) < 4) { + uint32_t *ptr4 = reinterpret_cast(reinterpret_cast(ptr) & -uintptr_t(4)); + #pragma unroll + for(int e=0; e < 4; e++) { + // Produce 4 bytes of sub-register type by reading 2 4-byte + // aligned values and shifting. + uint32_t lo, hi; + asm("ld.shared.b32 %0,[%1];" : "=r"(lo) : "l"(ptr4+e+0)); + asm("ld.shared.b32 %0,[%1];" : "=r"(hi) : "l"(ptr4+e+1)); + tmp4[e] = __funnelshift_r(lo, hi, 8*(int(reinterpret_cast(ptr))%4)); + } + } + else if(sizeof(T) == 4) { + #pragma unroll + for(int e=0; e < 4; e++) + asm("ld.shared.b32 %0,[%1];" : "=r"(tmp4[e]) : "l"(ptr+e)); + } + else /*sizeof(T)==8*/ { + #pragma unroll + for(int e=0; e < 2; e++) + asm("ld.shared.b64 %0,[%1];" : "=l"(tmp8[e]) : "l"(ptr+e)); + } + v0 = tmp8[0]; + v1 = tmp8[1]; +} + #endif diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index 605a491..8f63447 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -11,381 +11,132 @@ #include "reduce_kernel.h" // for reduction funcs #include "common.h" -#define SPINS_BEFORE_CHECK_ABORT 1000000 +#define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000 -// Unroll unconditionally the first send/recv since nsend/nrecv should be at -// least 1 if SEND/RECV is set. -#define FOR_SEND(func, ...) do { \ - if (SEND) { \ - /* Send to far first, then close */ \ - for (int i=1; i +struct ProtoSimple { + static constexpr int Id = NCCL_PROTO_SIMPLE; + static constexpr int SlicePerChunk = SlicePerChunk_1; + static constexpr int StepPerSlice = StepPerSlice_1; + static constexpr int Unroll = Unroll_1; -#define FOR_RECV(func, ...) 
do { \ - if (RECV) { \ - /* Recv from close first, then far */ \ - func(0, ##__VA_ARGS__); \ - for (int i=1; i -class ncclPrimitives { - private: - const int tid; - int nthreads; - int nworkers; - const int stepSize; - int nrecv = 0; - int nsend = 0; - struct ncclConnInfo* conn = NULL; - volatile int* connSizesFifoPtr = NULL; - void** connPtrsFifoPtr = NULL; - volatile uint64_t* connHeadPtr = NULL; - volatile uint64_t* connTailPtr = NULL; - uint64_t connTailCache; // Cache last seen value - uint64_t connHeadCache; // Cache last seen value - - int index; // Peer index I'm responsible for - int peer = -1; - int role = 0; - int group; - uint64_t step; - T* direct = NULL; - T* buff; - struct ncclDevComm* comm; - - const T** srcs; - T** dsts; - - // Don't use barrier 0 as it's used by the final sync - inline __device__ void barrier() { - if (nthreads == WARP_SIZE) __syncwarp(); - else asm volatile ("bar.sync %0, %1;" :: "r"(group+1), "r"(nthreads)); + // Data bytes (no flags etc) in one step of the fifo queue. + __device__ static int calcBytePerStep() { + return ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS; } - inline __device__ void subBarrier() { - if (nworkers == nthreads) barrier(); - else asm volatile ("bar.sync %0, %1;" :: "r"(group+2), "r"(nworkers)); + // Granularity of data bytes transferred per thread. + __device__ static int calcBytePerGrain() { + return sizeof(uint64_t); // Bogus value? Nobody queries this metric for simple. } - - uint32_t spins = 0; - uint32_t abort = 0; - - inline __device__ int checkAbort() { - spins++; - if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) { - abort = *(comm->abortFlag); - spins = 0; - } - return abort; - } - - template - inline __device__ T* directPtr(ssize_t directOffset) { - return DIRECTPTR && direct ? direct+directOffset : buff+(step%NCCL_STEPS)*stepSize; - } - - template - inline __device__ void waitSend(ssize_t directOffset, int nbytes) { - spins = 0; - while (connHeadCache + NCCL_STEPS < step + SLICESTEPS) { - connHeadCache = *connHeadPtr; - if (checkAbort()) break; - } - if (connSizesFifoPtr) { - connSizesFifoPtr[step%NCCL_STEPS] = nbytes; - } - - if (connPtrsFifoPtr) loadPtr(connPtrsFifoPtr+step%NCCL_STEPS, dsts[DST+index]); - else dsts[DST+index] = directPtr(directOffset); - step += SLICESTEPS; - } - - template - inline __device__ void waitRecv(ssize_t directOffset) { - spins = 0; - while (connTailCache < step + SLICESTEPS) { - connTailCache = *connTailPtr; - if (checkAbort()) break; - } - if (connPtrsFifoPtr) loadPtr(connPtrsFifoPtr+step%NCCL_STEPS, srcs[SRC+index]); - else srcs[SRC+index] = directPtr(directOffset); - step += SLICESTEPS; - } - - inline __device__ void postRecv() { - *connHeadPtr = step += SLICESTEPS; - } - - inline __device__ void postSend() { - *connTailPtr = step += SLICESTEPS; - } - - template - inline __device__ void - GenericOp(const T* srcPtr, T* dstPtr, int nelem, ssize_t directOffset) { - int offset = 0; - int sliceSize = stepSize*SLICESTEPS; - int dataSize = max(DIVUP(nelem, 16*SLICESPERCHUNK)*16, sliceSize/32); - - #pragma unroll - for (int slice=0; slice(directOffset+offset); - if (DST && (role & ROLE_DST)) dsts[0] = dstPtr+offset; - if (SEND && (role & ROLE_WAIT_SEND)) waitSend(directOffset+offset, realSize*sizeof(T)); - if (realSize > 0) { - subBarrier(); - if (DIRECTRECV && srcs[0] == dsts[0]) { - // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy - if (SEND) { - // (1-SEND) is only there to avoid compilation errors in case NSEND=0 (and SEND=0). 
- ReduceOrCopyMulti(tid, nworkers, 1, srcs, nsend, dsts+1, realSize); - } - } else { - ReduceOrCopyMulti(tid, nworkers, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize); - } - } - } - barrier(); - if (SEND && (role & ROLE_POST_SEND) && realSize > 0 && index == 0) __threadfence_system(); - __syncwarp(); - if (SEND && (role & ROLE_POST_SEND)) postSend(); - if (RECV && (role & ROLE_POST_RECV)) postRecv(); - offset += realSize; - } - } - - // Scatter and gather do not support DIRECT - template - inline __device__ void - ScatterGatherOp(const T* srcPtr, T* dstPtr, int totalElem, int peerElem, int skip, int shift) { - int offset = 0; // slice offset - int sliceSize = stepSize*SLICESTEPS; - int dataSize = max(DIVUP(peerElem, 16*SLICESPERCHUNK)*16, sliceSize/32); // per-peer slice size - - #pragma unroll - for (int slice=0; slice(0); - // realSize is not accurate here; but intra-node does not rely on sizes FIFO - if (SEND && (role & ROLE_WAIT_SEND)) waitSend<0, 0>(0, realSize*sizeof(T)); - subBarrier(); - if (SEND) { - #pragma unroll - for (int j=0; j=0 && i >= skip) peerOffset += peerElem; - const T* src0 = srcPtr + peerOffset; - int realPeerSize = min(realSize, totalElem-peerOffset); - if (realPeerSize > 0) ReduceOrCopyMulti(tid, nworkers, 1, &src0, 1, dsts+i, realPeerSize); - } - } else if (RECV) { - #pragma unroll - for (int j=0; j= 0 && i >= skip) peerOffset += peerElem; - T* dst0 = dstPtr + peerOffset; - int realPeerSize = min(realSize, totalElem-peerOffset); - if (realPeerSize > 0) ReduceOrCopyMulti(tid, nworkers, 1, srcs+i, 1, &dst0, realPeerSize); - } - } - } - barrier(); - if (SEND && (role & ROLE_POST_SEND) && realSize > 0 && index == 0) __threadfence_system(); - __syncwarp(); - if (SEND && (role & ROLE_POST_SEND)) postSend(); - if (RECV && (role & ROLE_POST_RECV)) postRecv(); - offset += realSize; - } - } - - __device__ __forceinline__ void loadRecvConn(struct ncclChannel* channel, T* directBuff) { - if (role & (ROLE_WAIT_RECV|ROLE_POST_RECV)) { - // For oneshot: groups 0,2 use conn 0, groups 4,6 use conn 1 - const int connIndex = (NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? group/4 : 0; - conn = &channel->devPeers[peer].recv[connIndex].conn; - step = conn->step; - step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS); - if (role & ROLE_POST_RECV) { - connHeadPtr = conn->head; - // Return credits in case we rounded up. - *connHeadPtr = step; - } - if (role & ROLE_WAIT_RECV) { - buff = (T*)conn->buffs[NCCL_PROTO_SIMPLE]; - if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) { - direct = directBuff; - *conn->ptrExchange = directBuff; - } - connTailPtr = conn->tail; - connTailCache = *connTailPtr; - connPtrsFifoPtr = conn->ptrsFifo; - } - } - } - - __device__ __forceinline__ void loadSendConn(struct ncclChannel* channel) { - if (role & (ROLE_WAIT_SEND|ROLE_POST_SEND)) { - // For oneshot: groups 0,2 use conn 0, groups 4,6 use conn 1 - const int connIndex = (NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? 
group/4 : 0; - conn = &channel->devPeers[peer].send[connIndex].conn; - step = conn->step; - step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS); - if (role & ROLE_POST_SEND) { - connTailPtr = conn->tail; - } - if (role & ROLE_WAIT_SEND) { - buff = (T*)conn->buffs[NCCL_PROTO_SIMPLE]; - if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) { - void* volatile* ptr = conn->ptrExchange; - while ((direct = (T*)(*ptr)) == NULL) { if (checkAbort()) break; } - *ptr = NULL; - } - connHeadPtr = conn->head; - connHeadCache = *connHeadPtr; - connSizesFifoPtr = conn->sizesFifo; - connPtrsFifoPtr = conn->ptrsFifo; - } - } - } - - __device__ __forceinline__ void saveSync() { - if (role & (ROLE_POST_SEND|ROLE_POST_RECV)) { - conn->step = step; - __threadfence_system(); - } - } - - public: - __device__ __forceinline__ - ncclPrimitives(const int tid, const int nworkers, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, struct ncclShmemPtrs* ptrs, int group) - : comm(comm), tid(tid), nworkers(nworkers), stepSize(stepSize), srcs((const T**)ptrs[group].srcs), dsts((T**)ptrs[group].dsts), group(group) { - nthreads = nworkers; - // For send operations, we need an extra warp to overlap the threadfence and the copy - int postThreads = NSEND && nworkers >= 64 ? WARP_SIZE : 0; - nthreads += postThreads; - - // Make sure step is updated before we read it. - barrier(); - - for (int i=0; i(src, NULL, nelem, 0); - } - __device__ __forceinline__ void - directSend(const T* src, ssize_t directOffset, int nelem) { - GenericOp<0, 1, 0, 1, 1, 0>(src, NULL, nelem, directOffset); - } - - __device__ __forceinline__ void - recv(T* dst, int nelem) { - GenericOp<0, 0, 1, 0, 0, 1>(NULL, dst, nelem, 0); - } - __device__ __forceinline__ void - directRecv(T* dst, ssize_t directOffset, int nelem) { - GenericOp<1, 0, 1, 0, 0, 1>(NULL, dst, nelem, directOffset); - } - - __device__ __forceinline__ void - copySend(const T* src, T* dst, int nelem) { - GenericOp<0, 0, 0, 1, 1, 1>(src, dst, nelem, 0); - } - __device__ __forceinline__ void - directCopySend(const T* src, T* dst, ssize_t directOffset, int nelem) { - GenericOp<0, 1, 0, 1, 1, 1>(src, dst, nelem, directOffset); - } - - __device__ __forceinline__ void - recvCopySend(T* dst, int nelem) { - GenericOp<0, 0, 1, 1, 0, 1>(NULL, dst, nelem, 0); - } - __device__ __forceinline__ void - directRecvCopySend(T* dst, ssize_t directOffset, int nelem) { - GenericOp<1, 1, 1, 1, 0, 1>(NULL, dst, nelem, directOffset); - } - - __device__ __forceinline__ void - recvReduceCopy(const T* src, T* dst, int nelem) { - GenericOp<0, 0, 1, 0, 1, 1>(src, dst, nelem, 0); - } - - __device__ __forceinline__ void - recvReduceSend(const T* src, int nelem) { - GenericOp<0, 0, 1, 1, 1, 0>(src, NULL, nelem, 0); - } - - __device__ __forceinline__ void - recvReduceCopySend(const T* src, T* dst, int nelem) { - GenericOp<0, 0, 1, 1, 1, 1>(src, dst, nelem, 0); - } - __device__ __forceinline__ void - directRecvReduceCopySend(const T* src, T* dst, ssize_t directOffset, int nelem) { - // Direct is only for the send part - GenericOp<0, 1, 1, 1, 1, 1>(src, dst, nelem, directOffset); - } - - __device__ __forceinline__ void - scatter(const T* src, int totalElem, int peerElem, int skip, int shift) { - ScatterGatherOp<0, 1>(src, NULL, totalElem, peerElem, skip, shift); - } - - __device__ __forceinline__ void - gather(T* dst, int totalElem, int peerElem, int skip, int shift) { - ScatterGatherOp<1, 0>(NULL, dst, totalElem, peerElem, skip, shift); - } - - __device__ __forceinline__ 
~ncclPrimitives() { - // Save steps for the next operation - saveSync(); + // Group width is how many consecutive group values a subchannel occupies. + static constexpr int MaxGroupWidth = 2; + __device__ static int calcGroupWidth(bool send, int nthreads) { + return send && nthreads-WARP_SIZE >= 64 ? 2 : 1; } }; -#include "prims_ll.h" -//#include "prims_ll128.h" +struct ProtoLL { + static constexpr int Id = NCCL_PROTO_LL; + // Data bytes (no flags etc) in one step of the fifo queue. + __device__ static int calcBytePerStep() { + return ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/2; // Half is data + } + // Granularity of data bytes transferred per thread. + __device__ static int calcBytePerGrain() { + return sizeof(uint64_t); // One 16-byte line has 8-bytes of data + } + // Group width is how many consecutive group values a subchannel occupies. + static constexpr int MaxGroupWidth = 1; + __device__ static int calcGroupWidth(bool send, int nthreads) { + return 1; + } +}; + +struct ProtoLL128 { + static constexpr int Id = NCCL_PROTO_LL128; + + // Data bytes (no flags etc) in one step of the fifo queue. + __device__ static int calcBytePerStep() { + return (ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS)*NCCL_LL128_DATAELEMS/NCCL_LL128_LINEELEMS; + } + // Granularity of data bytes transferred per thread. + __device__ static int calcBytePerGrain() { + return NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_DATAELEMS*sizeof(uint64_t)/NCCL_LL128_LINEELEMS; + } + // Group width is how many consecutive group values a subchannel occupies. + static constexpr int MaxGroupWidth = 1; + __device__ static int calcGroupWidth(bool send, int nthreads) { + return 1; + } +}; + +/* Fan (as in fan-in & fan-out) classes hold recv and send counts. The template + * arguments are static bounds on the maximum values. Asymmetric counts are + * independent. Symmetric is a static guarantee that nrecv==nsend, so it only + * stores one value at runtime. This optimization save 32-bit register, but more + * importantly uses fewer predicate registers when unrolling loops. + */ +template +struct FanAsymmetric { + static constexpr int MaxRecv = MaxRecv_, MaxSend = MaxSend_; + int nr, ns; + FanAsymmetric() = default; + __device__ FanAsymmetric(int nrecv, int nsend): nr(nrecv), ns(nsend) { + // assert(nrecv <= MaxRecv && nsend <= MaxSend); + } + __device__ int nrecv() const { return MaxRecv ? nr : 0; } + __device__ int nsend() const { return MaxSend ? ns : 0; } +}; + +template +struct FanSymmetric { + static constexpr int MaxRecv = MaxArity, MaxSend = MaxArity; + int n; + FanSymmetric() = default; + __device__ FanSymmetric(int nrecv, int nsend): n(nrecv) { + // assert(nrecv == nsend && nrecv <= MaxArity); + } + __device__ int nrecv() const { return n; } + __device__ int nsend() const { return n; } +}; + +// The primitives class. Specialized per protocol in the other headers. +template +class Primitives; + +// Used by LL & LL128 to implement direct members in the naive way. 
+template +struct PrimitivesWithoutDirect { + __device__ void directSend(intptr_t inpIx, intptr_t remoteOutIx, int eltN) { + static_cast(this)->send(inpIx, eltN); + } + __device__ void directSendFromOutput(intptr_t outIx, intptr_t remoteOutIx, int eltN) { + static_cast(this)->sendFromOutput(outIx, eltN); + } + __device__ void directRecv(intptr_t outIx, int eltN) { + static_cast(this)->recv(outIx, eltN, /*postOp=*/false); + } + __device__ void directCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) { + static_cast(this)->copySend(inpIx, outIx, eltN, postOp); + } + __device__ void directRecvCopySend(intptr_t outIx, intptr_t remoteOutIx, int eltN) { + static_cast(this)->recvCopySend(outIx, eltN, /*postOp=*/false); + } + __device__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) { + // Direct is only for the send part + static_cast(this)->recvReduceCopySend(inpIx, outIx, eltN, postOp); + } +}; + +#include "prims_simple.h" +#include "prims_ll.h" +#include "prims_ll128.h" #endif diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h index 48972a9..507cfba 100644 --- a/src/collectives/device/prims_ll.h +++ b/src/collectives/device/prims_ll.h @@ -4,15 +4,20 @@ * See LICENSE.txt for license information ************************************************************************/ -template -class ncclLLPrimitives { - private: +template +class Primitives: + public PrimitivesWithoutDirect> { + + static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; + static constexpr int Input=0, Output=1; + RedOp redOp; const int tid; const int nthreads; const int wid; + const int group; const int stepLines; - int nrecv = 0; - int nsend = 0; + Fan fan; + T *userBufs[2]; struct ncclConnInfo* recvConn = NULL; volatile uint64_t* recvConnHeadPtr = NULL; uint64_t recvConnHead; @@ -23,11 +28,10 @@ class ncclLLPrimitives { uint64_t sendConnHead; uint64_t sendConnHeadCache; // Cache last seen value - uint64_t recvStep[NRECV]; - uint64_t sendStep[NSEND]; - union ncclLLFifoLine* recvBuff[NRECV]; - union ncclLLFifoLine* sendBuff[NSEND]; - struct ncclDevComm* comm; + uint64_t recvStep[MaxRecv]; + uint64_t sendStep[MaxSend]; + union ncclLLFifoLine* recvBuff[MaxRecv]; + union ncclLLFifoLine* sendBuff[MaxSend]; inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepLines; } inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepLines; } @@ -37,27 +41,26 @@ class ncclLLPrimitives { inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); } inline __device__ void barrier() { - asm volatile ("bar.sync 1, %0;" :: "r"(nthreads)); + asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(1+group)); } - uint32_t spins = 0; uint32_t abort = 0; - inline __device__ int checkAbort(int i, int send) { + inline __device__ int checkAbort(int &spins, int send) { spins++; - if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) { - abort = *(comm->abortFlag); + if (abort == 0 && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) { + abort = *ncclShmem.comm.abortFlag; spins = 0; } return abort; } inline __device__ void waitSend(int nbytes) { - spins = 0; if (sendConnHeadPtr) { + int spins = 0; while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) { sendConnHeadCache = *sendConnHeadPtr; - if (checkAbort(wid, 1)) break; + if (checkAbort(spins, 1)) break; } if (sendConnFifoPtr) { int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) 
? stepLines*sizeof(union ncclLLFifoLine) : nbytes; @@ -85,83 +88,212 @@ class ncclLLPrimitives { sendStep[i]++; } - __device__ uint64_t readLL(int i, int offset) { + __device__ uint64_t readLL(int offset, int i) { union ncclLLFifoLine* src = recvPtr(i) + offset; uint32_t flag = recvFlag(i); uint32_t data1, flag1, data2, flag2; - spins = 0; + int spins = 0; do { - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4)); - if (checkAbort(i, 0)) break; + asm("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4)); + if (checkAbort(spins, 0)) break; } while ((flag1 != flag) || (flag2 != flag)); uint64_t val64 = data1 + (((uint64_t)data2) << 32); return val64; } + template + __device__ void readLLBeginAll(int offset, ncclLLFifoLine(&line)[MaxRecv]) { + #pragma unroll + for (int i=BeginIx; i < MaxRecv; i++) { + if (i < fan.nrecv()) { + union ncclLLFifoLine* src = recvPtr(i) + offset; + asm("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(line[i].data1), "=r"(line[i].flag1), "=r"(line[i].data2), "=r"(line[i].flag2) : "l"(&src->i4)); + } + } + } + __device__ uint64_t readLLFinish(int offset, ncclLLFifoLine(&line)[MaxRecv], int i) { + union ncclLLFifoLine* src = recvPtr(i) + offset; + uint32_t flag = recvFlag(i); + int spins = 0; + while (line[i].flag1 != flag || line[i].flag2 != flag) { + asm("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(line[i].data1), "=r"(line[i].flag1), "=r"(line[i].data2), "=r"(line[i].flag2) : "l"(&src->i4)); + if (checkAbort(spins, 0)) break; + } + uint64_t val64 = line[i].data1 + (((uint64_t)line[i].data2) << 32); + return val64; + } + __device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) { asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag)); } - // Using memcpy handles misaligned pointers. - __device__ uint64_t readAL(uint64_t* src) { - uint64_t val; - memcpy((char*)&val, (char*)src, sizeof(uint64_t)); - return val; + static constexpr int EltPerLine = sizeof(uint64_t)/sizeof(T); + + template + __device__ static U load(U *src) { + union { + U elt; + uint16_t u2; + uint32_t u4; + uint64_t u8; + }; + if(sizeof(U) == 1) + asm("ld.volatile.global.b8 %0,[%1];" : "=r"(u4) : "l"(src)); + else if(sizeof(U) == 2) + asm("ld.volatile.global.b16 %0,[%1];" : "=h"(u2) : "l"(src)); + else if(sizeof(U) == 4) + asm("ld.volatile.global.b32 %0,[%1];" : "=r"(u4) : "l"(src)); + else + asm("ld.volatile.global.b64 %0,[%1];" : "=l"(u8) : "l"(src)); + return elt; } - __device__ void storeAL(uint64_t* dst, uint64_t val, uint32_t nbytes) { - memcpy((char*)dst, (char*)&val, nbytes); + template + __device__ static void store(U *dst, U val) { + union { + U elt; + uint16_t u2; + uint32_t u4; + uint64_t u8; + }; + elt = val; + if(sizeof(U) == 1) + asm("st.volatile.global.b8 [%0],%1;" :: "l"(dst), "r"(u4)); + else if(sizeof(U) == 2) + asm("st.volatile.global.b16 [%0],%1;" :: "l"(dst), "h"(u2)); + else if(sizeof(U) == 4) + asm("st.volatile.global.b32 [%0],%1;" :: "l"(dst), "r"(u4)); + else + asm("st.volatile.global.b64 [%0],%1;" :: "l"(dst), "l"(u8)); } - template - __device__ void LLGenericOp(const T* srcPtr, T* dstPtr, int nelem) { - uint32_t nbytes = nelem < 0 ? 
0 : nelem*sizeof(T); - uint32_t npack = DIVUP(nbytes, sizeof(uint64_t)); - uint64_t* srcPack = (uint64_t*)srcPtr; - uint64_t* dstPack = (uint64_t*)dstPtr; - int offset = tid; + struct DataLoader { + int misalign; + union { + uint32_t u4[sizeof(T) <= 2 ? 3 : 2]; + uint64_t u8; + T elt[EltPerLine]; + }; - // Always waitSend in case of cleanup - if (SEND) waitSend(npack*sizeof(union ncclLLFifoLine)); - - // Do multiples of 64 bits - #pragma unroll 2 - for (; offset()(readLL(0, offset), val); - for (int i=1; i()(readLL(i, offset), val); - } + __device__ void loadBegin(T *src, int eltN) { + if (sizeof(T) <= 2) { + misalign = reinterpret_cast(src)%4; + uint32_t *p = reinterpret_cast(reinterpret_cast(src) & -uintptr_t(4)); + u4[0] = load(p+0); + u4[1] = misalign + eltN*sizeof(T) > 4 ? load(p+1) : 0; + // u4[2] would be simpler, but that throws warnings on some compilers + u4[sizeof(T) <= 2 ? 2 : 0] = misalign + eltN*sizeof(T) > 8 ? load(p+2) : 0; } - - // Send : inter-node, then intra-node, then local - if (SEND) { - for (int i=1; i + __device__ void LLGenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) { + constexpr int SRC = SrcBuf != -1 ? 1 : 0; + constexpr int DST = DstBuf != -1 ? 1 : 0; + T *srcElts = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx; + T *dstElts = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx; + + // Always waitSend in case of cleanup + nelem = nelem < 0 ? 0 : nelem; + if (SEND) waitSend(divUp(nelem, EltPerLine)*sizeof(ncclLLFifoLine)); + + nelem -= tid*EltPerLine; + srcElts += tid*EltPerLine; + dstElts += tid*EltPerLine; + int offset = tid; + int eltPerTrip = nthreads*EltPerLine; + while (nelem > 0) { + int eltInLine = EltPerLine < nelem ? EltPerLine : nelem; + + DataLoader dl; + ncclLLFifoLine line[MaxRecv]; + uint64_t data, peerData; + if (SRC) { + dl.loadBegin(srcElts, eltInLine); + srcElts += eltPerTrip; + } + if (RECV) { + readLLBeginAll<1>(offset, line); + peerData = readLL(offset, 0); + } + if (SRC) { + data = dl.loadFinish(); + if (SrcBuf == Input) data = MULTI().preOp(redOp, data); + } + if (RECV) { + data = !SRC ? 
peerData : MULTI()(redOp, peerData, data); + #pragma unroll MaxRecv + for (int i=1; i < MaxRecv && i < fan.nrecv(); i++) { + peerData = readLLFinish(offset, line, i); + data = MULTI()(redOp, peerData, data); + } + } + + if (postOp) data = MULTI().postOp(redOp, data); + + // Send : inter-node, then intra-node, then local + if (SEND) { + for (int i=1; i < MaxSend && i < fan.nsend(); i++) + storeLL(sendPtr(i)+offset, data, sendFlag(i)); + storeLL(sendPtr(0)+offset, data, sendFlag(0)); + } + if (DST) { + storeData(dstElts, data, eltInLine); + dstElts += eltPerTrip; + } + nelem -= eltPerTrip; + offset += nthreads; + } + + if (RECV) { + for (int i=0; i < MaxRecv; i++) incRecv(i); + postRecv(); + } + if (SEND) { + for (int i=1; i < MaxSend && i < fan.nsend(); i++) + incSend(i, offset); + incSend(0, offset); + } } __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) { recvBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL]; recvStep[i] = conn->step; if (wid == i) recvConn = conn; - nrecv++; } __device__ __forceinline__ void loadRecvSync() { - if (tid >= nthreads-WARP_SIZE && wid < nrecv) { + if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv()) { recvConnHeadPtr = recvConn->head; recvConnHead = recvConn->step; } @@ -171,10 +303,9 @@ class ncclLLPrimitives { sendBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL]; sendStep[i] = conn->step; if (wid == i) sendConn = conn; - nsend++; } __device__ __forceinline__ void loadSendSync() { - if (tid < nsend) { + if (tid < fan.nsend()) { sendConnHeadPtr = sendConn->head; sendConnHeadCache = *sendConnHeadPtr; sendConnHead = sendConn->step; @@ -182,65 +313,74 @@ class ncclLLPrimitives { } } - __device__ __forceinline__ void saveRecvSync() { - if (tid >= nthreads-WARP_SIZE && wid < nrecv) { - recvConn->step = recvConnHead; - __threadfence_block(); - } - } - - __device__ __forceinline__ void saveSendSync() { - if (tid < nsend) { - sendConn->step = sendConnHead; - __threadfence_block(); - } - } - public: - __device__ __forceinline__ - ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm) - : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines) { - // Make sure step is updated before we read it. 
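// --- Editorial sketch, not part of the patch -----------------------------
// Host-side model of the LL ("low latency") line protocol used by readLL()
// and storeLL() above: each 16-byte fifo line interleaves two 32-bit data
// words with two copies of the step flag, so a reader knows all 8 payload
// bytes are valid once both flags match the expected step. The helper names
// below are hypothetical; only the data1/flag1/data2/flag2 layout mirrors
// ncclLLFifoLine.
#include <cstdint>
#include <cstdio>

struct LLLineModel { uint32_t data1, flag1, data2, flag2; };

// Writer side: publish 8 bytes of payload together with the current flag.
static void storeLLModel(LLLineModel* dst, uint64_t val, uint32_t flag) {
  dst->data1 = (uint32_t)val;
  dst->flag1 = flag;
  dst->data2 = (uint32_t)(val >> 32);
  dst->flag2 = flag;
}

// Reader side: the payload is usable only when both flags equal the
// expected value (the device code spins on this, with an abort check).
static bool readLLModel(const LLLineModel* src, uint32_t flag, uint64_t* out) {
  if (src->flag1 != flag || src->flag2 != flag) return false; // not ready yet
  *out = src->data1 + ((uint64_t)src->data2 << 32);
  return true;
}

int main() {
  LLLineModel line = {0, 0, 0, 0};
  uint64_t v = 0;
  printf("before store: ready=%d\n", readLLModel(&line, /*flag=*/1, &v)); // 0
  storeLLModel(&line, 0x1122334455667788ull, /*flag=*/1);
  printf("after store:  ready=%d val=%llx\n", readLLModel(&line, 1, &v),
         (unsigned long long)v);                                          // 1, 1122334455667788
  return 0;
}
// --------------------------------------------------------------------------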
- barrier(); + __device__ Primitives( + const int tid, const int nthreads, int const *recvPeers, int const *sendPeers, + void const *inputBuf, void *outputBuf, int group=0 + ): + redOp(FuncTraits().make(ncclShmem.comm.nRanks)), + tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group), + stepLines(ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)) { + auto *channel = &ncclShmem.channel; // If we are going to support oneshot collNet + LL, then we would need to add connector index here - for (int i=0; i= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv->conn, i); - for (int i=0; i= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send->conn, i); + int nrecv=0, nsend=0; + while (nrecv < MaxRecv && recvPeers[nrecv] >= 0) { + loadRecvConn(&channel->devPeers[recvPeers[nrecv]].recv->conn, nrecv); + nrecv++; + } + while (nsend < MaxSend && sendPeers[nsend] >= 0) { + loadSendConn(&channel->devPeers[sendPeers[nsend]].send->conn, nsend); + nsend++; + } + this->fan = Fan(nrecv, nsend); loadRecvSync(); loadSendSync(); + setDataPtrs(inputBuf, outputBuf); } - __device__ void send(const T* src, int nelem) { - return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem); - } - - __device__ void recv(T* dst, int nelem) { - return LLGenericOp<1, 0, 0, 1>(NULL, dst, nelem); - } - - __device__ void recvReduceSend(const T* src, int nelem) { - return LLGenericOp<1, 1, 1, 0>(src, NULL, nelem); - } - - __device__ void recvReduceCopy(const T* src, T* dst, int nelem) { - return LLGenericOp<1, 0, 1, 1>(src, dst, nelem); - } - - __device__ void copySend(const T* src, T* dst, int nelem) { - return LLGenericOp<0, 1, 1, 1>(src, dst, nelem); - } - - __device__ void recvCopySend(T* dst, int nelem) { - return LLGenericOp<1, 1, 0, 1>(NULL, dst, nelem); - } - - __device__ void recvReduceCopySend(const T* src, T* dst, int nelem) { - return LLGenericOp<1, 1, 1, 1>(src, dst, nelem); - } - - __device__ __forceinline__ ~ncclLLPrimitives() { + __device__ ~Primitives() { // Save steps for the next operation - saveRecvSync(); - saveSendSync(); + if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv()) + recvConn->step = recvConnHead; + if (tid < fan.nsend()) + sendConn->step = sendConnHead; + // Ensure all steps written back + barrier(); + } + + __device__ void setDataPtrs(void const *inputBuf, void *outputBuf) { + userBufs[Input] = (T*)inputBuf; + userBufs[Output] = (T*)outputBuf; + } + + __device__ void moveDataPtrs(intptr_t delta) { + userBufs[Input] += delta; + userBufs[Output] += delta; + } + + __device__ void send(intptr_t inpIx, int eltN) { + return LLGenericOp<0, 1, Input, -1>(inpIx, -1, eltN, false); + } + __device__ void sendFromOutput(intptr_t outIx, int eltN) { + return LLGenericOp<0, 1, Output, -1>(outIx, -1, eltN, false); + } + __device__ void recv(intptr_t outIx, int eltN, bool postOp=false) { + return LLGenericOp<1, 0, -1, Output>(-1, outIx, eltN, postOp); + } + __device__ void recvReduceSend(intptr_t inpIx, int eltN) { + return LLGenericOp<1, 1, Input, -1>(inpIx, -1, eltN, false); + } + __device__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + return LLGenericOp<1, 0, Input, Output>(inpIx, outIx, eltN, postOp); + } + __device__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + return LLGenericOp<0, 1, Input, Output>(inpIx, outIx, eltN, postOp); + } + __device__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) { + return LLGenericOp<1, 1, -1, Output>(-1, outIx, eltN, postOp); + } + __device__ void 
recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + return LLGenericOp<1, 1, Input, Output>(inpIx, outIx, eltN, postOp); } }; diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h index 9ce1195..439072e 100644 --- a/src/collectives/device/prims_ll128.h +++ b/src/collectives/device/prims_ll128.h @@ -8,17 +8,22 @@ #define NCCL_LL128_FLAGTHREAD (NCCL_LL128_LINEELEMS-1) -template -class ncclLL128Primitives { - private: +template +class Primitives: + public PrimitivesWithoutDirect> { + + static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; + static constexpr int Input=0, Output=1; + RedOp redOp; const int tid; const int nthreads; const int wid; const int stepSize; const int warp; const bool flagThread; - int nrecv = 0; - int nsend = 0; + const int group; + Fan fan; + T *userBufs[2]; struct ncclConnInfo* recvConn = NULL; volatile uint64_t* recvConnHeadPtr = NULL; uint64_t recvConnHead; @@ -31,13 +36,10 @@ class ncclLL128Primitives { uint64_t sendConnHead; uint64_t sendConnHeadCache; // Cache last seen value - uint64_t recvStep[NRECV]; - uint64_t sendStep[NSEND]; - uint64_t* recvBuff[NRECV]; - uint64_t* sendBuff[NSEND]; - struct ncclDevComm* comm; - - volatile uint64_t* shmem; + uint64_t recvStep[MaxRecv]; + uint64_t sendStep[MaxSend]; + uint64_t* recvBuff[MaxRecv]; + uint64_t* sendBuff[MaxSend]; inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepSize; } inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepSize; } @@ -47,31 +49,26 @@ class ncclLL128Primitives { inline __device__ uint64_t sendFlag(int i) { return sendStep[i]+1; } inline __device__ void barrier() { - if (NSEND>NRECV) { - asm volatile ("bar.sync 1, %0;" :: "r"(nthreads)); - } else { - asm volatile ("bar.sync 2, %0;" :: "r"(nthreads)); - } + asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(1+group)); } - uint32_t spins = 0; uint32_t abort = 0; - inline __device__ int checkAbort(int i, int send) { + inline __device__ int checkAbort(int &spins, int i, int send) { spins++; - if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) { - abort = *(comm->abortFlag); + if (abort == 0 && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) { + abort = *ncclShmem.comm.abortFlag; spins = 0; } return abort; } inline __device__ void waitSend(int nbytes) { - spins = 0; if (sendConnHeadPtr) { + int spins = 0; while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) { sendConnHeadCache = *sendConnHeadPtr; - if (checkAbort(wid, 1)) break; + if (checkAbort(spins, wid, 1)) break; } if (sendConnFifoPtr) { sendConnFifoPtr[sendStep[wid]%NCCL_STEPS] = nbytes; @@ -80,137 +77,185 @@ class ncclLL128Primitives { } } - inline __device__ void incRecv(int i) { - recvStep[i] += 1; - } inline __device__ void postRecv() { if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1; } - - inline __device__ void incSend(int i) { - sendStep[i] += 1; - } inline __device__ void postSend() { if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += 1; } } - template - inline __device__ void loadSrcToShmem128(int maxOffset, const uint64_t* src64Ptr) { -#if 0 - uint64_t v[ELEMS_PER_THREAD]; - #pragma unroll - for (int u=0; u + __device__ __forceinline__ void loadRegsBegin(uint64_t(®s)[WordPerThread], T const *src, int eltN) { + constexpr int EltPer16B = 16/sizeof(T); + if(reinterpret_cast(src)%16 == 0) { + /* We are aligned to 16 bytes, so load directly to registers no shmem. 
+ * Flag threads load half as much data which gets shuffled to the even + * registers during Finish. The point of splitting into two phases is to + * defer that shuffle, which incurs a dependency stall, until after other + * memops are launched by the caller. + */ + #pragma unroll + for(int g=0; g < WordPerThread/2; g++) { + int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8); + if(!flagThread || g%2==0) { + if(ix*EltPer16B < eltN) + load128((uint64_t*)(src + ix*EltPer16B), regs[2*g+0], regs[2*g+1]); + } } } -#endif - } + else { + // Not aligned. Stage the smallest 16 byte aligned region subsuming the + // buffer into shmem. + int misalignment = reinterpret_cast(src) % 16; + uint64_t *src8 = reinterpret_cast(reinterpret_cast(src) & -uintptr_t(16)); + uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warp]); + #pragma unroll + for(int g=0; g < WordPerThread/2; g++) + if((g*WARP_SIZE + wid)*16 < misalignment + eltN*sizeof(T)) + load128(src8 + 2*(g*WARP_SIZE + wid), regs[2*g+0], regs[2*g+1]); + #pragma unroll + for(int g=0; g < WordPerThread/2; g++) + storeShmem128(shm8 + 2*(g*WARP_SIZE + wid), regs[2*g+0], regs[2*g+1]); - inline __device__ void loadSrcToShmem(int start, int end, const T* srcPtr) { - T* shmemPtr = (T*)(shmem-2*wid); - for (int offset = start+wid; offset < end; offset += WARP_SIZE) { - shmemPtr[offset] = srcPtr[offset]; + __syncwarp(); + + // Now load from shmem stage to regs. Preserve the same pre-shuffled layout + // as the aligned case since Finish() will be applied regardless. + T *shm = (T*)shm8 + misalignment/sizeof(T); + #pragma unroll + for(int g=0; g < WordPerThread/2; g++) { + int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8); + if(!flagThread || g%2==0) { + if(ix*EltPer16B < eltN) + loadShmemMisaligned128(shm + ix*EltPer16B, regs[2*g+0], regs[2*g+1]); + } + } } } - template - inline __device__ void storeShmemToDst128(int maxOffset, uint64_t* dst64Ptr) { - uint64_t v[ELEMS_PER_THREAD]; - uint64_t* shmemAsmPtr = shmemCvtPtr(shmem); + template + __device__ __forceinline__ void loadRegsFinish(uint64_t(®s)[WordPerThread]) { + // Move data out of flag registers into the vacant registers. #pragma unroll - for (int u=0; u + __device__ __forceinline__ void storeRegs(T *dst, uint64_t(®s)[WordPerThread], int eltN) { + constexpr int EltPer16B = 16/sizeof(T); + // Reverse Finish() register permuatation. + #pragma unroll + for (int g=1; g < WordPerThread/2; g+=2) { + if (flagThread) regs[2*g-1] = regs[2*g]; } + // Write to dst if 16-byte aligned, shmem otherwise. + int misalignment = reinterpret_cast(dst)%16; + uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warp]); + #pragma unroll + for(int g=0; g < WordPerThread/2; g++) { + int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8); + if (!flagThread || g%2==0) { + if(misalignment == 0 && (ix+1)*EltPer16B <= eltN) + store128((uint64_t*)(dst + ix*EltPer16B), regs[2*g+0], regs[2*g+1]); + else + storeShmem128(shm8+2*ix, regs[2*g+0], regs[2*g+1]); + } + } + __syncwarp(); + // Write rest from shmem to dst. No need to coalesce stores to 16-bytes, + // the hardware keeps up fine. + T *shm = (T*)ncclShmem.ll128warp[warp]; + int skip = misalignment == 0 ? 
eltN & -EltPer16B : 0; + for(int i=skip+wid; i < eltN; i += WARP_SIZE) + dst[i] = shm[i]; } #define WARP_MASK 0xffffffff - template - __device__ __forceinline__ void recvReduceSendCopy(int ll128Offset) { - uint64_t v[ELEMS_PER_THREAD]; + template + __device__ __forceinline__ void recvReduceSendCopy(uint64_t(&v)[ELEMS_PER_THREAD], int ll128Offset, bool postOp) { + constexpr int SRC = SrcBuf != -1 ? 1 : 0; + uint64_t vr[ELEMS_PER_THREAD]; - /************* Data Loading : SHMEM -> REG **************/ - if (SRC) { - volatile uint64_t* shmem64Ptr = shmem - (2*wid)/NCCL_LL128_LINEELEMS; - #pragma unroll - for (int u=0; u REG ************/ - - /************************ Recv **************************/ + __syncwarp(); + /************************ Wait first recv ********************/ if (RECV) { - uint64_t flag = recvFlag(0); uint64_t* ptr = recvPtr(0)+ll128Offset; + uint64_t flag = recvFlag(0); bool needReload; - uint64_t v0, v1; + int spins = 0; do { needReload = false; #pragma unroll for (int u=0; u().preOp(redOp, v[u]); + if (!flagThread) + v[u+1] = MULTI().preOp(redOp, v[u+1]); + } + } + } + + /************************ Recv rest *********************/ + if (RECV) { + { // Consume data from first recv + uint64_t* ptr = recvPtr(0)+ll128Offset; + #pragma unroll + for (int u=0; u()(redOp, vr[u], v[u]) : vr[u]; + v[u+1] = SRC ? MULTI()(redOp, vr[u+1], v[u+1]) : vr[u+1]; } - } while (__any_sync(WARP_MASK, needReload) && checkAbort(0, 0) == 0); - #pragma unroll - for (int u=0; u()(v0, v[u]) : v0; - v[u+1] = SRC ? MULTI()(v1, v[u+1]) : v1; } - for (int i=1; i()(v0, v[u]); - v[u+1] = MULTI()(v1, v[u+1]); + v[u] = MULTI()(redOp, vr[u], v[u]); + v[u+1] = MULTI()(redOp, vr[u+1], v[u+1]); } } } /********************** End Recv ************************/ + if (postOp && !FuncTraits::IsPostOpIdentity) { + #pragma unroll + for (int u=0; u().postOp(redOp, v[u]); + v[u+1] = MULTI().postOp(redOp, v[u+1]); + } + } + /************************ Send **************************/ if (SEND) { - for (int i=1; i SHMEM **************/ - if (DST) { - volatile uint64_t* shmem64Ptr = shmem - (2*wid)/NCCL_LL128_LINEELEMS; - #pragma unroll - for (int u=0; u SHMEM ************/ } - #define LL128INC (WARP_SIZE*NCCL_LL128_SHMEM_ELEMS_PER_THREAD) - #define ELEMINC (LL128INC-(LL128INC/NCCL_LL128_LINEELEMS)) + static constexpr int WireWordPerSlice = WARP_SIZE*NCCL_LL128_SHMEM_ELEMS_PER_THREAD; + static constexpr int DataEltPerSlice = (WireWordPerSlice - WireWordPerSlice/NCCL_LL128_LINEELEMS)*(sizeof(uint64_t)/sizeof(T)); - template - __device__ void GenericOp(const T* srcPtr, T* dstPtr, int nelem) { - if (nelem <= 0) { - // Don't move any data but still increase steps and sync with prev/next - if (SEND) waitSend(0); - FOR_SEND(incSend); if (SEND) postSend(); - FOR_RECV(incRecv); if (RECV) postRecv(); - return; - } - const int nelem64 = ((nelem*sizeof(T))/(2*sizeof(uint64_t)))*2; - const uint64_t* src64Ptr = ((uint64_t*)srcPtr); - uint64_t* dst64Ptr = ((uint64_t*)dstPtr); + template + __device__ void GenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) { + constexpr int SRC = SrcBuf != -1 ? 1 : 0; + constexpr int DST = DstBuf != -1 ? 1 : 0; + static_assert(-1<=SrcBuf && SrcBuf < 2, "Uhoh"); + static_assert(-1<=DstBuf && DstBuf < 2, "Uhoh"); + static_assert(DstBuf!=Input, "Mistake?"); + #if 0 + assert((SrcBuf==-1) == (srcIx==-1)); + assert((DstBuf==-1) == (dstIx==-1)); + #endif - int ll128Offset = LL128INC*warp+2*wid; - int elemOffset = ELEMINC*warp; + T const *srcPtr = SrcBuf == -1 ? 
nullptr : userBufs[SrcBuf] + srcIx; + T *dstPtr = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx; + int wireOffset = WireWordPerSlice*warp + 2*wid; const int nwarps = nthreads/WARP_SIZE; + nelem = nelem < 0 ? 0 : nelem; - if (SEND) waitSend(DIVUP(nelem*sizeof(T), ELEMINC*sizeof(uint64_t))*LL128INC*sizeof(uint64_t)); + if (SEND) waitSend(divUp(nelem, DataEltPerSlice)*WireWordPerSlice*sizeof(uint64_t)); barrier(); + nelem -= DataEltPerSlice*warp; + srcPtr += DataEltPerSlice*warp; + dstPtr += DataEltPerSlice*warp; + while (nelem > 0) { + const int eltInSlice = min(nelem, DataEltPerSlice); + uint64_t regs[NCCL_LL128_SHMEM_ELEMS_PER_THREAD]; + if (SRC) loadRegsBegin(regs, srcPtr, eltInSlice); + recvReduceSendCopy(regs, wireOffset, postOp); + if (DST) storeRegs(dstPtr, regs, eltInSlice); - while (elemOffset*(sizeof(uint64_t)/sizeof(T)) < nelem) { - const int maxOffset128 = min(nelem64-elemOffset, (int)ELEMINC); - const int maxOffset = min(nelem-(elemOffset*((int)(sizeof(uint64_t)/sizeof(T)))), (int)(ELEMINC*(sizeof(uint64_t)/sizeof(T)))); - if (SRC) { - int done = 0; - if ((((uint64_t)srcPtr)&0xf) == 0) { - loadSrcToShmem128(maxOffset128-2*wid, src64Ptr+elemOffset+2*wid); - done = maxOffset128*(sizeof(uint64_t)/sizeof(T)); - } - loadSrcToShmem(done, maxOffset, (T*)(src64Ptr+elemOffset)); - } - __syncwarp(); - recvReduceSendCopy(ll128Offset); - __syncwarp(); - if (DST) { - int done = 0; - if ((((uint64_t)dstPtr)&0xf) == 0) { - storeShmemToDst128(maxOffset128-2*wid, dst64Ptr+elemOffset+2*wid); - done = maxOffset128*(sizeof(uint64_t)/sizeof(T)); - } - storeShmemToDst(done, maxOffset, (T*)(dst64Ptr+elemOffset)); - } - __syncwarp(); - ll128Offset += LL128INC*nwarps; - elemOffset += ELEMINC*nwarps; + wireOffset += WireWordPerSlice*nwarps; + srcPtr += DataEltPerSlice*nwarps; + dstPtr += DataEltPerSlice*nwarps; + nelem -= DataEltPerSlice*nwarps; } barrier(); - FOR_SEND(incSend); if (SEND) postSend(); - FOR_RECV(incRecv); if (RECV) postRecv(); + if (SEND) for (int i=0; i < MaxSend; i++) sendStep[i] += 1; + if (SEND) postSend(); + if (RECV) for (int i=0; i < MaxRecv; i++) recvStep[i] += 1; + if (RECV) postRecv(); } __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) { recvBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128]; recvStep[i] = conn->step; if (wid == i) recvConn = conn; - nrecv++; } __device__ __forceinline__ void loadRecvSync() { - if (tid >= nthreads-WARP_SIZE && wid < nrecv) { + if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv()) { recvConnHeadPtr = recvConn->head; recvConnHead = recvConn->step; } @@ -311,16 +335,15 @@ class ncclLL128Primitives { sendBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128]; sendStep[i] = conn->step; if (wid == i) sendConn = conn; - nsend++; } __device__ __forceinline__ void loadSendSync() { - if (tid < nsend) { + if (tid < fan.nsend()) { sendConnHeadPtr = sendConn->head; sendConnHeadCache = *sendConnHeadPtr; sendConnHead = sendConn->step; sendConnFifoPtr = sendConn->sizesFifo; } - if (tid >= nthreads-WARP_SIZE && wid= nthreads-WARP_SIZE && widsizesFifo) { sendConnTailPtr = sendConn->tail; sendConnTail = sendConn->step; @@ -328,64 +351,74 @@ class ncclLL128Primitives { } } - __device__ __forceinline__ void saveRecvSync() { - if (tid >= nthreads-WARP_SIZE && wid < nrecv) { - recvConn->step = recvConnHead; - __threadfence_block(); +public: + __device__ Primitives( + const int tid, const int nthreads, int const *recvPeers, int const *sendPeers, + void const *inputBuf, void *outputBuf, int group=0 + ): + 
redOp(FuncTraits().make(ncclShmem.comm.nRanks)), + tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), + flagThread((tid%8)==7), group(group), + stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) { + + auto *channel = &ncclShmem.channel; + int nrecv=0, nsend=0; + while (nrecv < MaxRecv && recvPeers[nrecv] >= 0) { + loadRecvConn(&channel->devPeers[recvPeers[nrecv]].recv->conn, nrecv); + nrecv++; } - } - - __device__ __forceinline__ void saveSendSync() { - if (tid < nsend) { - sendConn->step = sendConnHead; - __threadfence_block(); + while (nsend < MaxSend && sendPeers[nsend] >= 0) { + loadSendConn(&channel->devPeers[sendPeers[nsend]].send->conn, nsend); + nsend++; } - } - - public: - __device__ __forceinline__ - ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm) - : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), shmem(ncclShmem->data+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) { - // Make sure step is updated before we read it. - barrier(); - - for (int i=0; i= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv->conn, i); - for (int i=0; i= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send->conn, i); + this->fan = Fan(nrecv, nsend); loadRecvSync(); loadSendSync(); + setDataPtrs(inputBuf, outputBuf); } - __device__ void send(const T* src, int nelem) { - return GenericOp<0, 1, 1, 0>(src, NULL, nelem); - } - - __device__ void recv(T* dst, int nelem) { - return GenericOp<1, 0, 0, 1>(NULL, dst, nelem); - } - - __device__ void recvReduceSend(const T* src, int nelem) { - return GenericOp<1, 1, 1, 0>(src, NULL, nelem); - } - - __device__ void recvReduceCopy(const T* src, T* dst, int nelem) { - return GenericOp<1, 0, 1, 1>(src, dst, nelem); - } - - __device__ void copySend(const T* src, T* dst, int nelem) { - return GenericOp<0, 1, 1, 1>(src, dst, nelem); - } - - __device__ void recvCopySend(T* dst, int nelem) { - return GenericOp<1, 1, 0, 1>(NULL, dst, nelem); - } - - __device__ void recvReduceCopySend(const T* src, T* dst, int nelem) { - return GenericOp<1, 1, 1, 1>(src, dst, nelem); - } - - __device__ __forceinline__ ~ncclLL128Primitives() { + __device__ ~Primitives() { // Save steps for the next operation - saveRecvSync(); - saveSendSync(); + if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv()) + recvConn->step = recvConnHead; + if (tid < fan.nsend()) + sendConn->step = sendConnHead; + // Ensure all steps written back + barrier(); + } + + __device__ void setDataPtrs(void const *inputBuf, void *outputBuf) { + userBufs[Input] = (T*)inputBuf; + userBufs[Output] = (T*)outputBuf; + } + + __device__ void moveDataPtrs(intptr_t delta) { + userBufs[Input] += delta; + userBufs[Output] += delta; + } + + __device__ void send(intptr_t inpIx, int eltN) { + return GenericOp<0, 1, Input, -1>(inpIx, -1, eltN, false); + } + __device__ void sendFromOutput(intptr_t outIx, int eltN) { + return GenericOp<0, 1, Output, -1>(outIx, -1, eltN, false); + } + __device__ void recv(intptr_t outIx, int eltN, bool postOp=false) { + return GenericOp<1, 0, -1, Output>(-1, outIx, eltN, postOp); + } + __device__ void recvReduceSend(intptr_t inpIx, int eltN) { + return GenericOp<1, 1, Input, -1>(inpIx, -1, eltN, false); + } + __device__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + return GenericOp<1, 0, Input, 
Output>(inpIx, outIx, eltN, postOp); + } + __device__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + return GenericOp<0, 1, Input, Output>(inpIx, outIx, eltN, postOp); + } + __device__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) { + return GenericOp<1, 1, -1, Output>(-1, outIx, eltN, postOp); + } + __device__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + return GenericOp<1, 1, Input, Output>(inpIx, outIx, eltN, postOp); } }; diff --git a/src/collectives/device/prims_simple.h b/src/collectives/device/prims_simple.h new file mode 100644 index 0000000..9238d63 --- /dev/null +++ b/src/collectives/device/prims_simple.h @@ -0,0 +1,463 @@ +/************************************************************************* + * Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +template +class Primitives< + T, RedOp, Fan, Direct, ProtoSimple + > { + static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; + static constexpr int Input=0, Output=1; + static constexpr int RoleInput = 0x01, + RoleOutput = 0x02, + RoleWaitRecv = 0x04, + RoleWaitSend = 0x08, + RolePostSend = 0x10, + RolePostRecv = 0x20, + Aborted = 0x40, + PtrsFifoEnabled = 0x80, + SizesFifoEnabled = 0x100, + DirectEnabled = 0x200, + ThreadsSynced = 0x400; + const int tid; + int nthreads; + int nworkers; + const int stepSize; + Fan fan; + RedOp const redOp; + int index; // Peer index I'm responsible for + int flags; + int group; + uint64_t step; + union { + void **connPtrsFifoPtr; // (flags & PtrsFifoEnabled) + T *userBuff; // (flags & (RoleInput|RoleOutput)) + T *connEltsFifo; // !(flags & (PtrsFifoEnabled|RoleInput|RoleOutput)) + }; + union { + int volatile *connSizesFifoPtr; // (flags & SizesFifoEnabled) + T *directBuff; // !(flags & SizesFifoEnabled) + }; + uint64_t volatile *connStepPtr; + uint64_t connStepCache; // Cache last seen value of (*connStepPtr) + + // Don't use barrier 0 as it's used by the final sync + inline __device__ void barrier() { + if (nthreads == WARP_SIZE) + __syncwarp(); + else + asm volatile("bar.sync %0, %1;" :: "r"(group+1), "r"(nthreads)); + flags |= ThreadsSynced; + } + inline __device__ void subBarrier() { + if (nworkers == nthreads) + barrier(); + else + asm volatile("bar.sync %0, %1;" :: "r"(group+2), "r"(nworkers)); + } + + inline __device__ bool checkAbort(int &spins) { + spins++; + if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) { + flags |= *ncclShmem.comm.abortFlag ? Aborted : 0; + spins = 0; + } + return flags & Aborted; + } + + template + inline __device__ void waitPeer(intptr_t dstIx, intptr_t remoteOutIx, int offset, int nelts) { + if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) { + bool const isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send; + int spins = 0; + while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) { + connStepCache = *connStepPtr; + if (checkAbort(spins)) break; + //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice)); + } + + if (isSendNotRecv && (flags & SizesFifoEnabled)) + connSizesFifoPtr[step%NCCL_STEPS] = nelts*sizeof(T); + + void **ptrs = isSendNotRecv ? 
(ncclShmem.groups[group].dsts + Dst) + : (ncclShmem.groups[group].srcs + Src); + if (flags & PtrsFifoEnabled) + loadPtr(connPtrsFifoPtr + step%NCCL_STEPS, ptrs[index]); + else if ((isSendNotRecv ? DirectSend : DirectRecv) && (flags & DirectEnabled)) + ptrs[index] = directBuff + (isSendNotRecv ? remoteOutIx : dstIx) + offset; + else + ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize; + step += StepPerSlice; + } + } + + template + inline __device__ void postPeer() { + if (flags & (Recv*RolePostRecv | Send*RolePostSend)) { + step += StepPerSlice; + *connStepPtr = step; + } + } + + template + inline __device__ void genericOp( + intptr_t srcIx, intptr_t dstIx, intptr_t remoteOutIx, int nelem, bool postOp + ) { + constexpr int DirectRecv = 1 && Direct && DirectRecv1; + constexpr int DirectSend = 1 && Direct && DirectSend1; + constexpr int Src = SrcBuf != -1; + constexpr int Dst = DstBuf != -1; + + nelem = nelem < 0 ? 0 : nelem; + int sliceSize = stepSize*StepPerSlice; + sliceSize = max(divUp(nelem, 16*SlicePerChunk)*16, sliceSize/32); + int slice = 0; + int offset = 0; + + if (tid < nworkers && offset < nelem) { + // Worker-only loop for non-empty slices. Non-workers and empty slices are + // processed in the loop following this if block. The benefit of splitting + // the loop like this is we pull two branches out of the critical path. + // Using "number of branch insns (taken or not) encountered dynamically" + // as the performance metric, then: + // perf_orig = 2*numslices + // perf_new = 2+numslices + // So the new code and old code behave the same for numslices=2, and for + // numslices>2 the new code is superior. And note that in the case + // numslices=1, the loop is trivially unrollable (single iteration) so we + // don't incur that that tail branch and we still have perf_new=2. + // + // ORIGINAL CODE: + // unrolled for(slices) { + // if(worker) { // This branch removed + // wait(); + // subBarrier(); + // if(slice not empty) // This branch removed + // ReduceCopyMulti(); + // } + // barrier(); + // post(); + // } // Since we no longer unroll, new branch added here + #if __CUDA_ARCH__ < 700 + // Yeah, so all that above don't matter a lick on older hardware. + #pragma unroll SlicePerChunk + #else + #pragma unroll 1 + #endif + do { + sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset; + if (Src && (flags & (SrcBuf==Input ? RoleInput : RoleOutput))) + ncclShmem.groups[group].srcs[0] = userBuff + srcIx + offset; + if (Dst && (flags & (DstBuf==Input ? RoleInput : RoleOutput))) + ncclShmem.groups[group].dsts[0] = userBuff + dstIx + offset; + waitPeer(dstIx, remoteOutIx, offset, sliceSize); + subBarrier(); + if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) { + // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy + if (Send) { + // (1-Send) is only there to avoid compilation errors in case MaxSend=0 (and Send=0). 
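// --- Editorial sketch, not part of the patch -----------------------------
// Host-side illustration of the loop restructuring described in the comment
// above: hoisting the "am I a worker?" and "is there data?" tests out of the
// per-slice loop replaces roughly two branch tests per slice with two tests
// total (plus the loop-back branch). work()/post() and the counter are
// stand-ins, not NCCL code; they only model the shape of the control flow.
#include <cstdio>

static int tests = 0;                        // counts the two hoistable tests
static void work(int) {}                     // stand-in for ReduceOrCopyMulti
static void post(int) {}                     // stand-in for postPeer()

static void originalShape(bool isWorker, int numSlices, int nelem) {
  for (int s = 0; s < numSlices; s++) {
    if (++tests, isWorker) {                 // tested on every slice
      if (++tests, nelem > 0) work(s);       // tested on every slice
    }
    post(s);
  }
}

static void restructuredShape(bool isWorker, int numSlices, int nelem) {
  int s = 0;
  if (++tests, (isWorker && (++tests, nelem > 0))) {
    do { work(s); post(s); s++; } while (s < numSlices);   // worker fast path
  }
  while (s < numSlices) { post(s); s++; }    // non-workers / empty tail slices
}

int main() {
  tests = 0; originalShape(true, 4, 1024);
  printf("original:     %d tests\n", tests); // 2*numSlices = 8
  tests = 0; restructuredShape(true, 4, 1024);
  printf("restructured: %d tests\n", tests); // 2
  return 0;
}
// --------------------------------------------------------------------------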
+ ReduceOrCopyMulti + (tid, nworkers, redOp, false, false, + 1, (T const**)ncclShmem.groups[group].srcs, + fan.nsend(), (T**)ncclShmem.groups[group].dsts+1, + sliceSize); + } + } else { + ReduceOrCopyMulti + (tid, nworkers, redOp, SrcBuf==Input, postOp, + Recv*fan.nrecv()+Src, (T const**)ncclShmem.groups[group].srcs, + Send*fan.nsend()+Dst, (T**)ncclShmem.groups[group].dsts, + sliceSize); + } + barrier(); // This barrier has a counterpart in following loop + if (Send && (flags & RolePostSend) && index == 0) __threadfence_system(); + __syncwarp(); + postPeer(); + offset += sliceSize; + slice += 1; + } while (slice < SlicePerChunk && offset < nelem); + } + + // Non-workers come straight here. Workers too but only once the remaining + // slices are all empty. Since empty slices are the uncommon case, and + // worker perf is the limiter, perf-wise this loop is effectively unentered, + // hence just a single branch insn. + #pragma unroll 1 + while (slice < SlicePerChunk) { + sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset; + { // Only workers could have Wait roles so we know the slice must be empty + // since we've exited the loop above. + waitPeer(0, 0, 0, 0); + } + barrier(); // Has couterpart in preceding worker-only loop. + if (Send && (flags & RolePostSend) && sliceSize > 0 && index == 0) __threadfence_system(); + __syncwarp(); + postPeer(); + offset += sliceSize; + slice += 1; + } + } + + // Scatter and gather do not support Direct + template + inline __device__ void + ScatterGatherOp(intptr_t inpIx, intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp) { + int offset = 0; // slice offset + int sliceSize = stepSize*StepPerSlice; + int dataSize = max(DIVUP(peerElem, 16*SlicePerChunk)*16, sliceSize/32); // per-peer slice size + + #pragma unroll + for (int slice=0; slice(0, 0, 0, realSize); + subBarrier(); + if (Send) { + #pragma unroll + for (int j=0; j= 0 && i >= skip) peerOffset += peerElem; + const T* src0 = (T*)ncclShmem.groups[group].srcs[0] + peerOffset; + int realPeerSize = min(realSize, totalElem-peerOffset); + if (realPeerSize > 0) ReduceOrCopyMulti(tid, nworkers, redOp, true, false, 1, &src0, 1, (T**)ncclShmem.groups[group].dsts+i, realPeerSize); + } + } else if (Recv) { + #pragma unroll + for (int j=0; j= 0 && i >= skip) peerOffset += peerElem; + T* dst0 = (T*)ncclShmem.groups[group].dsts[0] + peerOffset; + int realPeerSize = min(realSize, totalElem-peerOffset); + if (realPeerSize > 0) ReduceOrCopyMulti(tid, nworkers, redOp, false, postOp, 1, (T const**)ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize); + } + } + } + barrier(); + if (Send && (flags & RolePostSend) && realSize > 0 && index == 0) __threadfence_system(); + __syncwarp(); + postPeer(); + offset += realSize; + } + } + + __device__ __forceinline__ void loadRecvConn(ncclPeer *peer) { + if (flags & (RoleWaitRecv|RolePostRecv)) { + // For other colls: group <= 2, hence always use conn 0 + // For P2P: Direct is set to 1, hence always use conn 0 + // Ideally we should be accepting connIndex from the constructor! + const int connIndex = Direct ? 0 : group/4; + auto *conn = &peer->recv[connIndex].conn; + step = conn->step; + step = roundUp(step, SlicePerChunk*StepPerSlice); + if (flags & RolePostRecv) { + connStepPtr = conn->head; + *connStepPtr = step; // Return credits in case we rounded up. 
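// --- Editorial sketch, not part of the patch -----------------------------
// The connection-loading logic above rounds the step counter up to a whole
// chunk (SlicePerChunk*StepPerSlice steps) and, for the RolePostRecv thread,
// immediately publishes the rounded value through conn->head so the peer is
// credited for the skipped steps. A minimal model of that arithmetic,
// assuming the usual divUp/roundUp definitions (the patch calls helpers of
// these names that are defined elsewhere in the tree):
#include <cstdint>
#include <cstdio>

static uint64_t divUpU64(uint64_t a, uint64_t b)   { return (a + b - 1) / b; }
static uint64_t roundUpU64(uint64_t a, uint64_t b) { return divUpU64(a, b) * b; }

int main() {
  const uint64_t SlicePerChunk = 2, StepPerSlice = 4;   // illustrative values
  uint64_t step = 13;                                   // left over from the previous op
  uint64_t rounded = roundUpU64(step, SlicePerChunk*StepPerSlice);
  printf("step %llu -> %llu (credits returned: %llu)\n",
         (unsigned long long)step, (unsigned long long)rounded,
         (unsigned long long)(rounded - step));         // 13 -> 16, 3 steps credited
  return 0;
}
// --------------------------------------------------------------------------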
+ } + if (flags & RoleWaitRecv) { + ncclShmem.groups[group].recvConns[index] = conn; // WaitRecv role saves since that's who needs it in setDataPtrs() + connStepPtr = conn->tail; + connStepCache = *connStepPtr; + flags |= (conn->ptrsFifo != nullptr) ? PtrsFifoEnabled : 0; + flags |= (Direct && (conn->direct & NCCL_DIRECT_GPU)) ? DirectEnabled : 0; + if (flags & PtrsFifoEnabled) + connPtrsFifoPtr = conn->ptrsFifo; + else + connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE]; + } + } + } + + __device__ __forceinline__ void loadSendConn(ncclPeer *peer) { + if (flags & (RoleWaitSend|RolePostSend)) { + // For other colls: group <= 2, hence always use conn 0 + // For P2P: Direct is set to 1, hence always use conn 0 + // Ideally we should be accepting connIndex from the constructor! + const int connIndex = Direct ? 0 : group/4; + auto *conn = &peer->send[connIndex].conn; + step = conn->step; + step = roundUp(step, SlicePerChunk*StepPerSlice); + if (flags & RolePostSend) { + connStepPtr = conn->tail; + } + if (flags & RoleWaitSend) { + ncclShmem.groups[group].sendConns[index] = conn; // WaitSend role saves since that's who needs it in setDataPtrs() + connStepPtr = conn->head; + connStepCache = *connStepPtr; + flags |= (conn->ptrsFifo != nullptr) ? PtrsFifoEnabled : 0; + if (flags & PtrsFifoEnabled) + connPtrsFifoPtr = conn->ptrsFifo; + else + connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE]; + + if (conn->sizesFifo != nullptr) { + flags |= SizesFifoEnabled; + connSizesFifoPtr = conn->sizesFifo; + } + else if (Direct && (conn->direct & NCCL_DIRECT_GPU)) + flags |= DirectEnabled; + } + } + } + + public: + __device__ Primitives( + int tid, int nthreads, int const *recvPeers, int const *sendPeers, + void const *inputBuf, void *outputBuf, int group=0 + ): + tid(tid), + stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)), + redOp(FuncTraits::make(ncclShmem.comm.nRanks)) { + + // For send operations, we need an extra warp to overlap the threadfence and the copy + this->nthreads = nthreads; + this->nworkers = nthreads - (MaxSend > 0 && nthreads-WARP_SIZE >= 64 ? WARP_SIZE : 0); + this->group = group; + + int nrecv=0, nsend=0; + while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++; + while (nsend < MaxSend && sendPeers[nsend] != -1) nsend++; + this->fan = Fan(nrecv, nsend); + + constexpr int ThreadPerSync = 8; + static_assert(MaxSend < ThreadPerSync && MaxRecv < ThreadPerSync, "Not enough threads to cover all peers"); + + int g = tid / ThreadPerSync; + int ng = nthreads / ThreadPerSync; + index = tid % ThreadPerSync; + flags = 0; + if (g == 0) { + if (index < nrecv) flags |= RoleWaitRecv; + if (index == nrecv) flags |= RoleInput; + } else if (g == 1) { + if (index < nsend) flags |= RoleWaitSend; + if (index == nsend) flags |= RoleOutput; + } else if (g == ng - 2) { + if (index < nrecv) flags |= RolePostRecv; + } else if (g == ng - 1) { + if (index < nsend) flags |= RolePostSend; + } + + int peer = 0; + if (flags & (RoleWaitRecv|RolePostRecv)) peer = recvPeers[index]; + if (flags & (RoleWaitSend|RolePostSend)) peer = sendPeers[index]; + + loadRecvConn(&ncclShmem.channel.devPeers[peer]); + loadSendConn(&ncclShmem.channel.devPeers[peer]); + + setDataPtrs(inputBuf, outputBuf); + } + + __device__ ~Primitives() { + // Ensure ncclShmem.groups[].send/recvConns are available + if (!(flags & ThreadsSynced)) + barrier(); + // Save steps for the next operation + if (flags & (RolePostSend|RolePostRecv)) { + auto *conns = (flags & RolePostSend) ? 
ncclShmem.groups[group].sendConns : ncclShmem.groups[group].recvConns; + conns[index]->step = step; + } + // Make sure all threads are done writing back conn->step and done using + // ncclShmem.groups[group] + barrier(); + } + + __device__ void setDataPtrs(void const *inputBuf, void *outputBuf) { + if (flags & RoleInput) userBuff = (T*)inputBuf; + if (flags & RoleOutput) userBuff = (T*)outputBuf; + if (Direct && flags == (flags|RoleWaitRecv|DirectEnabled)) { + int spins = 0; + void *volatile *slot = ncclShmem.groups[group].recvConns[index]->ptrExchange; + // Wait for consumer to consume previous value before trampling it. + while (*slot != nullptr && !checkAbort(spins)); + directBuff = (T*)outputBuf; + // Encode pointer by XOR'ing against some address they definitely wouldn't send + // since we want to allow them sending us nullptr while not colliding with + // the empty slot value. + *slot = reinterpret_cast(reinterpret_cast(outputBuf) ^ reinterpret_cast(slot)); + } + if (Direct && flags == (flags|RoleWaitSend|DirectEnabled)) { + int spins = 0; + void *volatile *slot = ncclShmem.groups[group].sendConns[index]->ptrExchange; + void *ptr; + while (true) { + ptr = *slot; + if (ptr != nullptr || checkAbort(spins)) break; + } + directBuff = reinterpret_cast(reinterpret_cast(ptr) ^ reinterpret_cast(slot)); + *slot = nullptr; + } + } + + __device__ void moveDataPtrs(intptr_t delta) { + if (flags & (RoleInput|RoleOutput)) + userBuff += delta; + } + + __device__ __forceinline__ void send(intptr_t inpIx, int eltN) { + genericOp<0, 0, 0, 1, Input, -1>(inpIx, -1, -1, eltN, false); + } + __device__ __forceinline__ void sendFromOutput(intptr_t outIx, int eltN) { + genericOp<0, 0, 0, 1, Output, -1>(outIx, -1, -1, eltN, false); + } + __device__ __forceinline__ void directSend(intptr_t inpIx, intptr_t remoteOutIx, int eltN) { + genericOp<0, 1, 0, 1, Input, -1>(inpIx, -1, remoteOutIx, eltN, false); + } + __device__ __forceinline__ void directSendFromOutput(intptr_t outIx, intptr_t remoteOutIx, int eltN) { + genericOp<0, 1, 0, 1, Output, -1>(outIx, -1, remoteOutIx, eltN, false); + } + + __device__ __forceinline__ void recv(intptr_t outIx, int eltN, bool postOp=false) { + genericOp<0, 0, 1, 0, -1, Output>(-1, outIx, -1, eltN, postOp); + } + __device__ __forceinline__ void directRecv(intptr_t outIx, int eltN) { + genericOp<1, 0, 1, 0, -1, Output>(-1, outIx, -1, eltN, /*postOp=*/false); + } + + __device__ __forceinline__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + genericOp<0, 0, 0, 1, Input, Output>(inpIx, outIx, -1, eltN, postOp); + } + __device__ __forceinline__ void directCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) { + genericOp<0, 1, 0, 1, Input, Output>(inpIx, outIx, remoteOutIx, eltN, postOp); + } + + __device__ __forceinline__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) { + genericOp<0, 0, 1, 1, -1, Output>(-1, outIx, -1, eltN, postOp); + } + __device__ __forceinline__ void directRecvCopySend(intptr_t outIx, intptr_t remoteOutIx, int eltN) { + genericOp<1, 1, 1, 1, -1, Output>(-1, outIx, remoteOutIx, eltN, false); + } + + __device__ __forceinline__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + genericOp<0, 0, 1, 0, Input, Output>(inpIx, outIx, -1, eltN, postOp); + } + + __device__ __forceinline__ void recvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) { + genericOp<0, 0, 1, 1, Input, -1>(inpIx, -1, -1, eltN, postOp); + } + + __device__ __forceinline__ void 
recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + genericOp<0, 0, 1, 1, Input, Output>(inpIx, outIx, -1, eltN, postOp); + } + __device__ __forceinline__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) { + // Direct is only for the send part + genericOp<0, 1, 1, 1, Input, Output>(inpIx, outIx, remoteOutIx, eltN, postOp); + } + + __device__ __forceinline__ void + scatter(intptr_t inpIx, int totalElem, int peerElem, int skip, int shift) { + ScatterGatherOp<0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false); + } + + __device__ __forceinline__ void + gather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp=false) { + ScatterGatherOp<1, 0>(-1, outIx, totalElem, peerElem, skip, shift, postOp); + } +}; diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h index 313209d..1ce4c2e 100644 --- a/src/collectives/device/reduce.h +++ b/src/collectives/device/reduce.h @@ -5,148 +5,87 @@ ************************************************************************/ #include "devcomm.h" -#include "primitives.h" #include "collectives.h" +#include "primitives.h" -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads-WARP_SIZE; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS); - const int chunkSize = stepSize * REDUCE_CHUNKSTEPS; - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*(ssize_t)chunkSize; - const ssize_t size = args->coll.count; - const int rank = ring->devUserRanks[0]; - const int prevRank = ring->devUserRanks[nranks-1]; - const int root = args->coll.root; +namespace { + template + __device__ void runRing(ncclWorkElem *args) { + const int tid = threadIdx.x; + const int nthreads = args->nThreads; + const int bid = args->coll.bid; + const int nChannels = args->coll.nChannels; + ncclRing *ring = &ncclShmem.channel.ring; + const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCE_CHUNKSTEPS : 1)); + const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))); + const int nranks = ncclShmem.comm.nRanks; + const ssize_t loopSize = nChannels*chunkSize; + const ssize_t size = args->coll.count; + const int rank = ncclShmem.comm.rank; + const int prevRank = ring->devUserRanks[nranks-1]; + const int root = args->coll.root; - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; + Primitives, 0, Proto> + prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff); - ncclPrimitives - prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0); + auto calcChunkSize = [&]__device__(ssize_t gridOffset)->int { + int realChunkSize; + if (Proto::Id == NCCL_PROTO_SIMPLE) { + realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels)); + realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T)); + } + else if (Proto::Id == NCCL_PROTO_LL) + realChunkSize = size-gridOffset < loopSize ? 
args->coll.lastChunkSize : chunkSize; + else if (Proto::Id == NCCL_PROTO_LL128) + realChunkSize = min(divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128, chunkSize); + return realChunkSize; + }; + if (prevRank == root) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels)); - ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T)); + int realChunkSize = calcChunkSize(gridOffset); ssize_t offset = gridOffset + bid*realChunkSize; int nelem = min(realChunkSize, size-offset); - if (prevRank == root) { - prims.send(thisInput+offset, nelem); - } else if (rank == root) { - prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem); - } else { - prims.recvReduceSend(thisInput+offset, nelem); - } + prims.send(offset, nelem); } } -}; - -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS); - ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T); - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->coll.count; - const int rank = comm->rank; - const int prevRank = ring->devUserRanks[nranks-1]; - const int root = args->coll.root; - - ncclLLPrimitives LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm); - - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - + else if (rank == root) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - if (size-gridOffset < loopSize) { - chunkSize = args->coll.lastChunkSize; - } - ssize_t offset = gridOffset + bid*chunkSize; - - int nelem = min(chunkSize, size-offset); - if (prevRank == root) { - LLprims.send(thisInput+offset, nelem); - } else if (rank == root) { - LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem); - } else { - LLprims.recvReduceSend(thisInput+offset, nelem); - } + int realChunkSize = calcChunkSize(gridOffset); + ssize_t offset = gridOffset + bid*realChunkSize; + int nelem = min(realChunkSize, size-offset); + prims.recvReduceCopy(offset, offset, nelem, /*postOp=*/true); } } -}; - -#include "prims_ll128.h" -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS); - ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T)); - const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T)); - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->coll.count; - const int rank = comm->rank; - const int prevRank = 
ring->devUserRanks[nranks-1]; - const int root = args->coll.root; - - ncclLL128Primitives LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm); - - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - + else { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize); - ssize_t offset = gridOffset + bid*chunkSize; - - int nelem = min(chunkSize, size-offset); - if (prevRank == root) { - LLprims.send(thisInput+offset, nelem); - } else if (rank == root) { - LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem); - } else { - LLprims.recvReduceSend(thisInput+offset, nelem); - } + int realChunkSize = calcChunkSize(gridOffset); + ssize_t offset = gridOffset + bid*realChunkSize; + int nelem = min(realChunkSize, size-offset); + prims.recvReduceSend(offset, nelem); } } + } +} + +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + using Proto = ProtoSimple; + runRing(args); + } }; -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) {} +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + runRing(args); + } }; -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) {} +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + runRing(args); + } }; diff --git a/src/collectives/device/reduce_kernel.h b/src/collectives/device/reduce_kernel.h index 0e90793..87a6823 100644 --- a/src/collectives/device/reduce_kernel.h +++ b/src/collectives/device/reduce_kernel.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. 
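// --- Editorial sketch, not part of the patch -----------------------------
// Host-side walk-through of the ring Reduce roles in runRing() above: the
// rank whose ring predecessor is the root only send()s its input, every other
// non-root rank recvReduceSend()s (reduce the incoming partial sum with the
// local input and forward it), and the root recvReduceCopy()s into its output
// with postOp applied exactly once (e.g. an averaging op's divide). Plain C++
// stand-in with made-up values.
#include <cstdio>

int main() {
  const int nranks = 4, root = 0;
  int input[nranks] = {1, 2, 3, 4};              // one contribution per rank
  int carried = 0;
  for (int k = 1; k < nranks; k++) {
    int rank = (root + k) % nranks;
    if (k == 1) carried  = input[rank];          // prevRank == root: send()
    else        carried += input[rank];          // middle ranks: recvReduceSend()
  }
  int output = carried + input[root];            // root: recvReduceCopy()
  const bool postOp = true;                      // average-style op divides here
  if (postOp) output /= nranks;
  printf("root output = %d (sum 10, averaged to 2 with integer division)\n", output);
  return 0;
}
// --------------------------------------------------------------------------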
* * See LICENSE.txt for license information ************************************************************************/ @@ -10,6 +10,7 @@ #include "common_kernel.h" #include +#include template struct FuncNull { @@ -46,6 +47,18 @@ struct FuncMin { } }; +template +struct FuncTraits { // generic implementation for FuncSum,Prod,Min,Max + static constexpr bool IsPreOpIdentity = true; + static constexpr bool IsPostOpIdentity = true; + + __device__ static Fn make(int rankN) { return Fn(); } + template + __device__ static T preOp(Fn, T x) { return x; } + template + __device__ static T postOp(Fn, T x) { return x; } +}; + #define MASK0 0x00ff00ff #define MASK1 0xff00ff00 static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) { @@ -239,6 +252,31 @@ struct FuncSum { } }; +#if defined(__CUDA_BF16_TYPES_EXIST__) +template<> +struct FuncSum<__nv_bfloat16> { + __device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const { +#if __CUDA_ARCH__ >= 800 + return __hadd2(x, y); +#else + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); +#endif + } + __device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const { +#if __CUDA_ARCH__ >= 800 + return __hadd(x, y); +#else + return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); +#endif + } +}; +#endif + template<> struct FuncProd { __device__ half2 operator()(const half2 x, const half2 y) const { @@ -262,6 +300,31 @@ struct FuncProd { } }; +#if defined(__CUDA_BF16_TYPES_EXIST__) +template<> +struct FuncProd<__nv_bfloat16> { + __device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const { +#if __CUDA_ARCH__ >= 800 + return __hmul2(x, y); +#else + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); +#endif + } + __device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const { +#if __CUDA_ARCH__ >= 800 + return __hmul(x, y); +#else + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); +#endif + } +}; +#endif + template<> struct FuncMax { __device__ half2 operator()(const half2 x, const half2 y) const { @@ -281,6 +344,34 @@ struct FuncMax { } }; +#if defined(__CUDA_BF16_TYPES_EXIST__) +template<> +struct FuncMax<__nv_bfloat16> { + __device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const { +#if __CUDA_ARCH__ >= 800 + return __hmax2(x, y); +#else + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fmaxf(fxl, fyl), fmaxf(fxh, fyh)); +#endif + } + __device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const { +#if __CUDA_ARCH__ >= 800 + return __hmax(x, y); +#else + float fx, fy; + fx = __bfloat162float(x); + fy = __bfloat162float(y); + return __float2bfloat16(fmaxf(fx, fy)); +#endif + } +}; +#endif + template<> struct FuncMin { __device__ half2 operator()(const half2 x, const half2 y) const { @@ -299,4 +390,269 @@ struct FuncMin { return __float2half(fm); } }; + +#if defined(__CUDA_BF16_TYPES_EXIST__) +template<> +struct FuncMin<__nv_bfloat16> { + __device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const { +#if __CUDA_ARCH__ >= 800 + return __hmin2(x, y); +#else + float 
fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fminf(fxl, fyl), fminf(fxh, fyh)); +#endif + } + __device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const { +#if __CUDA_ARCH__ >= 800 + return __hmin(x, y); +#else + float fx, fy; + fx = __bfloat162float(x); + fy = __bfloat162float(y); + return __float2bfloat16(fminf(fx, fy)); +#endif + } +}; +#endif + +template<> +struct FuncMax { + __device__ float operator()(float x, float y) const { + return fmaxf(x, y); + } +}; +template<> +struct FuncMin { + __device__ float operator()(float x, float y) const { + return fminf(x, y); + } +}; + +template<> +struct FuncMax { + __device__ double operator()(double x, double y) const { + return fmax(x, y); + } +}; +template<> +struct FuncMin { + __device__ double operator()(double x, double y) const { + return fmin(x, y); + } +}; + +template +struct FuncAvg: FuncSum { + static_assert(!std::is_floating_point::value, "Uhoh"); + static constexpr bool IsPreOpIdentity = true; + static constexpr bool IsPostOpIdentity = false; + int n; + + template + __device__ FuncAvg(int n): n(n) {} + + __device__ T preOp(T x) const { + return x; + } + __device__ T postOp(T x) const { + return T(x/n); + } +}; + +template<> +struct FuncAvg: FuncSum { + static constexpr bool IsPreOpIdentity = false; + static constexpr bool IsPostOpIdentity = true; + double rcp; + __device__ FuncAvg(int n) { + rcp = __drcp_rn(double(n)); + } + // inherits FuncSum::operator() + __device__ double preOp(double x) const { + return IsPreOpIdentity ? x : x*rcp; + } + __device__ double postOp(double x) const { + return IsPostOpIdentity ? x : x*rcp; + } +}; + +template<> +struct FuncAvg: FuncSum { + static constexpr bool IsPreOpIdentity = false; + static constexpr bool IsPostOpIdentity = true; + float rcp; + __device__ FuncAvg(int n) { + rcp = __frcp_rn(float(n)); + } + // inherits FuncSum::operator() + __device__ float preOp(float x) const { + return IsPreOpIdentity ? x : x*rcp; + } + __device__ float postOp(float x) const { + return IsPostOpIdentity ? x : x*rcp; + } +}; + +template<> +struct FuncAvg: FuncSum { + // Change these to switch between all prescale, all postscale, or both by sqrt(N). + // Obviously, the only invalid combination is both true. An improvement would be + // make this parameterized as a build time setting and passed here through + // preprocessor definitions. + static constexpr bool IsPreOpIdentity = false; + static constexpr bool IsPostOpIdentity = true; + +#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 + half2 scale; + __device__ FuncAvg(int n) { + if (!IsPreOpIdentity && !IsPostOpIdentity) + scale.x = __float2half(__frsqrt_rn(float(n))); + else + scale.x = __float2half(__frcp_rn(float(n))); + scale.y = scale.x; + } + // inherits FuncSum::operator() + __device__ half preOp(half x) const { + return IsPreOpIdentity ? x : __hmul(x, scale.x); + } + __device__ half2 preOp(half2 x) const { + return IsPreOpIdentity ? x : __hmul2(x, scale); + } + __device__ half postOp(half x) const { + return IsPostOpIdentity ? x : __hmul(x, scale.x); + } + __device__ half2 postOp(half2 x) const { + return IsPostOpIdentity ? x : __hmul2(x, scale); + } +#else + float scale; + __device__ FuncAvg(int n) { + if (!IsPreOpIdentity && !IsPostOpIdentity) + scale = __frsqrt_rn(float(n)); + else + scale = __frcp_rn(float(n)); + } + // inherits FuncSum::operator() + __device__ half preOp(half x) const { + return IsPreOpIdentity ? 
x : __float2half(__half2float(x)*scale); + } + __device__ half2 preOp(half2 x) const { + if (IsPreOpIdentity) + return x; + else { + float2 a = __half22float2(x); + a.x *= scale; + a.y *= scale; + return __float22half2_rn(a); + } + } + __device__ half postOp(half x) const { + return IsPostOpIdentity ? x : __float2half(__half2float(x)*scale); + } + __device__ half2 postOp(half2 x) const { + if (IsPostOpIdentity) + return x; + else { + float2 a = __half22float2(x); + a.x *= scale; + a.y *= scale; + return __float22half2_rn(a); + } + } +#endif +}; + +#if defined(__CUDA_BF16_TYPES_EXIST__) +template<> +struct FuncAvg<__nv_bfloat16>: FuncSum<__nv_bfloat16> { + // Change these to switch between all prescale, all postscale, or both by sqrt(N). + // Obviously, the only invalid combination is both true. An improvement would be + // make this parameterized as a build time setting and passed here through + // preprocessor definitions. + static constexpr bool IsPreOpIdentity = false; + static constexpr bool IsPostOpIdentity = true; + +#if __CUDA_ARCH__ >= 800 + __nv_bfloat162 scale; + __device__ FuncAvg(int n) { + if (!IsPreOpIdentity && !IsPostOpIdentity) + scale.x = __float2bfloat16(__frsqrt_rn(float(n))); + else + scale.x = __float2bfloat16(__frcp_rn(float(n))); + scale.y = scale.x; + } + // inherits FuncSum::operator() + __device__ __nv_bfloat16 preOp(__nv_bfloat16 x) const { + return IsPreOpIdentity ? x : __hmul(x, scale.x); + } + __device__ __nv_bfloat162 preOp(__nv_bfloat162 x) const { + return IsPreOpIdentity ? x : __hmul2(x, scale); + } + __device__ __nv_bfloat16 postOp(__nv_bfloat16 x) const { + return IsPostOpIdentity ? x : __hmul(x, scale.x); + } + __device__ __nv_bfloat162 postOp(__nv_bfloat162 x) const { + return IsPostOpIdentity ? x : __hmul2(x, scale); + } +#else + float scale; + __device__ FuncAvg(int n) { + if (!IsPreOpIdentity && !IsPostOpIdentity) + scale = __frsqrt_rn(float(n)); + else + scale = __frcp_rn(float(n)); + } + // inherits FuncSum::operator() + __device__ __nv_bfloat16 preOp(__nv_bfloat16 x) const { + return IsPreOpIdentity ? x : __float2bfloat16(__bfloat162float(x)*scale); + } + __device__ __nv_bfloat162 preOp(__nv_bfloat162 x) const { + if (IsPreOpIdentity) + return x; + else { + float fxl, fxh; + fxl = __low2float(x); + fxh = __high2float(x); + return __floats2bfloat162_rn(fxl * scale, fxh * scale); + } + } + __device__ __nv_bfloat16 postOp(__nv_bfloat16 x) const { + return IsPostOpIdentity ? 
x : __float2bfloat16(__bfloat162float(x)*scale); + } + __device__ __nv_bfloat162 postOp(__nv_bfloat162 x) const { + if (IsPostOpIdentity) + return x; + else { + float fxl, fxh; + fxl = __low2float(x); + fxh = __high2float(x); + return __floats2bfloat162_rn(fxl * scale, fxh * scale); + } + } +#endif +}; +#endif + +template +struct FuncTraits> { + static constexpr bool IsPreOpIdentity = FuncAvg::IsPreOpIdentity; + static constexpr bool IsPostOpIdentity = FuncAvg::IsPostOpIdentity; + + __device__ static FuncAvg make(int rankN) { + return FuncAvg(rankN); + } + template + __device__ static U preOp(FuncAvg fn, U x) { + return fn.preOp(x); + } + template + __device__ static U postOp(FuncAvg fn, U x) { + return fn.postOp(x); + } +}; + #endif // REDUCE_KERNEL_H_ diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h index a0d45dc..c61a028 100644 --- a/src/collectives/device/reduce_scatter.h +++ b/src/collectives/device/reduce_scatter.h @@ -5,192 +5,85 @@ ************************************************************************/ #include "devcomm.h" -#include "primitives.h" #include "collectives.h" +#include "primitives.h" -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads-WARP_SIZE; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS); - const int chunkSize = stepSize * REDUCESCATTER_CHUNKSTEPS; - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*(ssize_t)chunkSize; - const ssize_t size = args->coll.count; +namespace { + template + __device__ void runRing(ncclWorkElem *args) { + const int tid = threadIdx.x; + const int nthreads = args->nThreads; + const int bid = args->coll.bid; + const int nChannels = args->coll.nChannels; + ncclRing *ring = &ncclShmem.channel.ring; + int const *ringRanks = ring->devUserRanks; + const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCESCATTER_CHUNKSTEPS : 1)); + // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. 
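// Illustrative sketch (host-only mock, not NCCL code): the FuncTraits/FuncAvg hooks defined
// above let the generic reduce/copy loops wrap each element-wise reduction with an optional
// preOp on loaded values and a postOp on the final result. Sum/Prod/Min/Max keep both hooks
// as the identity; ncclAvg is FuncSum plus a single scale by 1/nranks applied in postOp.
#include <cstdio>

struct Sum { float operator()(float a, float b) const { return a + b; } };
struct Avg : Sum {
  float rcp;
  explicit Avg(int nranks) : rcp(1.0f/float(nranks)) {}
};

template<typename Fn> struct Traits {                    // generic: identity hooks
  static float preOp(Fn, float x)  { return x; }
  static float postOp(Fn, float x) { return x; }
};
template<> struct Traits<Avg> {                          // Avg: scale once, after the last step
  static float preOp(Avg, float x)     { return x; }
  static float postOp(Avg fn, float x) { return x * fn.rcp; }
};

template<typename Fn>
void recvReduceCopy(Fn fn, const float* recvd, const float* local, float* out, int n, bool post) {
  for (int i = 0; i < n; i++) {
    float v = fn(Traits<Fn>::preOp(fn, recvd[i]), Traits<Fn>::preOp(fn, local[i]));
    out[i] = post ? Traits<Fn>::postOp(fn, v) : v;       // postOp only on the final step
  }
}

int main() {
  float a[4] = {1, 2, 3, 4}, b[4] = {3, 2, 1, 0}, out[4];
  recvReduceCopy(Avg(2), a, b, out, 4, /*post=*/true);   // averages across 2 ranks
  for (float v : out) printf("%g ", v);                  // prints: 2 2 2 2
  printf("\n");
}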
+ const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2); + const int nranks = ncclShmem.comm.nRanks; + const ssize_t loopSize = nChannels*chunkSize; + const ssize_t size = args->coll.count; - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; + Primitives, 0, Proto> + prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff); - ncclPrimitives - prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0); - - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels)); - ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T)); - ssize_t chunkOffset = gridOffset + bid*realChunkSize; - - /////////////// begin ReduceScatter steps /////////////// - ssize_t offset; - int nelem = min(realChunkSize, size-chunkOffset); - int rankDest; - - // step 0: push data to next GPU - rankDest = ring->devUserRanks[nranks-1]; - offset = chunkOffset + rankDest * size; - - prims.send(thisInput+offset, nelem); - - // k-2 steps: reduce and copy to next GPU - for (int j=2; jdevUserRanks[nranks-j]; - offset = chunkOffset + rankDest * size; - - prims.recvReduceSend(thisInput+offset, nelem); - } - - // step k-1: reduce this buffer and data, which will produce the final result - rankDest = ring->devUserRanks[0]; - offset = chunkOffset + rankDest * size; - - prims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t realChunkSize; + if (Proto::Id == NCCL_PROTO_SIMPLE) { + realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels)); + realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T)); } - } -}; + else if (Proto::Id == NCCL_PROTO_LL) + realChunkSize = size-gridOffset < loopSize ? 
args->coll.lastChunkSize : chunkSize; + else if (Proto::Id == NCCL_PROTO_LL128) + realChunkSize = min(divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128, chunkSize); + realChunkSize = int(realChunkSize); -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS); - ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T); - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->coll.count; + ssize_t chunkOffset = gridOffset + bid*int(realChunkSize); - ncclLLPrimitives LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm); + /////////////// begin ReduceScatter steps /////////////// + ssize_t offset; + int nelem = min(realChunkSize, size-chunkOffset); + int rankDest; - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; + // step 0: push data to next GPU + rankDest = ringRanks[nranks-1]; + offset = chunkOffset + rankDest * size; + prims.send(offset, nelem); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - if (size-gridOffset < loopSize) { - chunkSize = args->coll.lastChunkSize; - } - ssize_t chunkOffset = gridOffset + bid*chunkSize; - - /////////////// begin ReduceScatter steps /////////////// - ssize_t offset; - int nelem = min(chunkSize, size-chunkOffset); - int rankDest; - - // step 0: push data to next GPU - rankDest = ring->devUserRanks[nranks-1]; + // k-2 steps: reduce and copy to next GPU + for (int j=2; jdevUserRanks[nranks-j]; - offset = chunkOffset + rankDest * size; - - LLprims.recvReduceSend(thisInput+offset, nelem); - } - - // step k-1: reduce this buffer and data, which will produce the final - // result that we store in this data - rankDest = ring->devUserRanks[0]; - offset = chunkOffset + rankDest * size; - - LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem); + prims.recvReduceSend(offset, nelem); } + + // step k-1: reduce this buffer and data, which will produce the final result + rankDest = ringRanks[0]; + offset = chunkOffset + rankDest * size; + prims.recvReduceCopy(offset, chunkOffset, nelem, /*postOp=*/true); } + } +} + +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + using Proto = ProtoSimple; + runRing(args); + } }; -#include "prims_ll128.h" -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) { - const int tid = threadIdx.x; - const int nthreads = args->nThreads; - const int bid = args->coll.bid; - const int nChannels = args->coll.nChannels; - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclRing* ring = &channel->ring; - const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS); - ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T)); - // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. 
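// Illustrative sketch (plain host code, not the device primitives): both the new runRing above
// and the per-protocol classes it replaces issue the same ring schedule per chunk -- step 0
// sends the chunk owned by ringRanks[nranks-1], the next nranks-2 steps recvReduceSend the
// chunk owned by ringRanks[nranks-j], and the last step recvReduceCopy's ringRanks[0]'s chunk
// into the output, with the postOp (e.g. the 1/nranks scale for ncclAvg) applied only there.
#include <cstdio>
#include <vector>

void printReduceScatterSchedule(const std::vector<int>& ringRanks) {
  int nranks = (int)ringRanks.size();
  printf("step 0: send           chunk of rank %d\n", ringRanks[nranks-1]);
  for (int j = 2; j < nranks; j++)
    printf("step %d: recvReduceSend chunk of rank %d\n", j-1, ringRanks[nranks-j]);
  printf("step %d: recvReduceCopy chunk of rank %d (final result, postOp applied)\n",
         nranks-1, ringRanks[0]);
}

int main() {
  printReduceScatterSchedule({2, 3, 0, 1});   // example devUserRanks as seen from one rank
}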
- const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2; - const int nranks = comm->nRanks; - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->coll.count; - - ncclLL128Primitives LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm); - - // Compute pointers - const T * __restrict__ thisInput = (const T*)args->sendbuff; - T * __restrict__ thisOutput = (T*)args->recvbuff; - - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize); - - ssize_t chunkOffset = gridOffset + bid*chunkSize; - - /////////////// begin ReduceScatter steps /////////////// - ssize_t offset; - int nelem = min(chunkSize, size-chunkOffset); - int rankDest; - - // step 0: push data to next GPU - rankDest = ring->devUserRanks[nranks-1]; - offset = chunkOffset + rankDest * size; - - LLprims.send(thisInput+offset, nelem); - - // k-2 steps: reduce and copy to next GPU - for (int j=2; jdevUserRanks[nranks-j]; - offset = chunkOffset + rankDest * size; - - LLprims.recvReduceSend(thisInput+offset, nelem); - } - - // step k-1: reduce this buffer and data, which will produce the final - // result that we store in this data - rankDest = ring->devUserRanks[0]; - offset = chunkOffset + rankDest * size; - - LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem); - } - } +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + runRing(args); + } }; -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) {} -}; - -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* args) {} +template +struct RunWorkElement { + __device__ void run(ncclWorkElem *args) { + runRing(args); + } }; diff --git a/src/collectives/device/sendrecv.h b/src/collectives/device/sendrecv.h index b489f42..e5e948f 100644 --- a/src/collectives/device/sendrecv.h +++ b/src/collectives/device/sendrecv.h @@ -5,89 +5,87 @@ ************************************************************************/ #include "devcomm.h" -#include "primitives.h" #include "collectives.h" +#include "primitives.h" -template -class ncclFunction { - public: - __device__ void run(struct ncclWorkElem* firstArgs) { - struct ncclWorkElem* args = firstArgs; - int tid = threadIdx.x; - int group = 0; - for (int s=0; sp2p.nThreads; - if (nThreadsSegment == 0) return; // Nothing else to do - int groupRecv = group; - group += 1; - int groupSend = group; - group += nThreadsSegment > 128 ? 2 : 1; - if (tid < nThreadsSegment) { - const int nThreads = nThreadsSegment > 128 ? 
nThreadsSegment-WARP_SIZE : nThreadsSegment; +template +struct RunWork { + __device__ void run(ncclWork *work) { + int tid = threadIdx.x; + int group = 0; + const int rank = ncclShmem.comm.rank; + const int nRanks = ncclShmem.comm.nRanks; + using Proto = ProtoSimple<1, 1>; - // Compute pointers - const T* sendbuff = (const T*)args->sendbuff; - T* recvbuff = (T*)args->recvbuff; - const ssize_t sendCount = args->p2p.sendCount; - const ssize_t recvCount = args->p2p.recvCount; + for (int s=0; selems[s]; + int nThreadsSegment = args->p2p.nThreads; + if (args->active == 0 || nThreadsSegment == 0) break; - const int delta = args->p2p.delta; - if (delta == 0) { - if (tid < nThreads && sendbuff != recvbuff) { - // local copy : ReduceOrCopyMulti takes an int as number of elements, - // so we split it in blocks of 1G elements. - int blockSize = 1<<30; - for (size_t offset=0; offset(tid, nThreads, 1, &sendbuff, 1, &recvbuff, blockSize); - sendbuff += blockSize; recvbuff += blockSize; - } - } - } else { - struct ncclDevComm* comm = args->comm; - struct ncclChannel* channel = comm->channels+blockIdx.x; + int nThreadsSplit = (nThreadsSegment - (nThreadsSegment > 128 ? WARP_SIZE : 0))/2; + int groupRecv = group; + group += Proto::calcGroupWidth(/*send=*/false, nThreadsSplit); + int groupSend = group; + group += Proto::calcGroupWidth(/*send=*/true, nThreadsSegment - nThreadsSplit); - const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/(sizeof(T)*NCCL_STEPS); + if (tid < nThreadsSegment) { + // Compute pointers + T const* sendbuff = (const T*)args->sendbuff; + T* recvbuff = (T*)args->recvbuff; + ssize_t const sendCount = args->p2p.sendCount; + ssize_t const recvCount = args->p2p.recvCount; + int const delta = args->p2p.delta; - int nThreadsSplit = nThreads/2; - if ((tid < nThreadsSplit) && recvCount >= 0) { - const int chunkSize = args->p2p.recvChunkSize/sizeof(T); - int peer = (comm->rank-delta+comm->nRanks)%comm->nRanks; - int nt = nThreadsSplit; - ncclPrimitives - prims(tid, nt, &peer, NULL, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, groupRecv); - - if (recvCount == 0) { - prims.recv(recvbuff, 0); - } else for (ssize_t offset = 0; offset < recvCount; offset += chunkSize) { - int realChunkSize = min(chunkSize, recvCount-offset); - ALIGN_SIZE(realChunkSize, nt*sizeof(uint64_t)/sizeof(T)); - int nelem = min(realChunkSize, recvCount-offset); - prims.directRecv(recvbuff+offset, offset, nelem); - } - } - if ((tid >= nThreadsSplit) && sendCount >= 0) { - const int chunkSize = args->p2p.sendChunkSize/sizeof(T); - int peer = (comm->rank+delta)%comm->nRanks; - int nt = nThreads-nThreadsSplit; - ncclPrimitives - prims(tid-nThreadsSplit, nt, NULL, &peer, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, groupSend); - - if (sendCount == 0) { - prims.send(sendbuff, 0); - } else for (ssize_t offset = 0; offset < sendCount; offset += chunkSize) { - int realChunkSize = min(chunkSize, sendCount-offset); - ALIGN_SIZE(realChunkSize, nt*sizeof(uint64_t)/sizeof(T)); - int nelem = min(realChunkSize, sendCount-offset); - prims.directSend(sendbuff+offset, offset, nelem); - } + if (delta == 0) { + if (sendbuff != recvbuff) { + // local copy : ReduceOrCopyMulti takes an int as number of elements, + // so we split it in blocks of 1G elements. 
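// Illustrative sketch (host-only, std::copy standing in for ReduceOrCopyMulti): because the
// copy primitive takes an int element count, the local-copy path below walks a size_t count
// in blocks of at most 1<<30 elements, advancing both pointers between calls.
#include <algorithm>
#include <cstddef>

void copyBlocked(const float* src, float* dst, size_t count) {
  const size_t blockSize = size_t(1) << 30;            // 1G elements, always fits in an int
  for (size_t offset = 0; offset < count; offset += blockSize) {
    int n = int(std::min(blockSize, count - offset));
    std::copy(src + offset, src + offset + n, dst + offset);
  }
}

int main() {
  static float src[1000], dst[1000];                   // small buffers; the loop logic is the same
  copyBlocked(src, dst, 1000);
}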
+ int blockSize = 1<<30; + for (size_t offset=0; offset(tid, nThreadsSegment, RedOp(), false, false, 1, &sendbuff, 1, &recvbuff, blockSize); + sendbuff += blockSize; + recvbuff += blockSize; } } } - tid -= nThreadsSegment; - if (tid < 0) return; - args++; + else { + if ((tid < nThreadsSplit) && recvCount >= 0) { + int const peer = (rank - delta + nRanks)%nRanks; + int const t0 = 0; + int const nt = nThreadsSplit; + int const chunkSize = args->p2p.recvChunkSize/sizeof(T); + Primitives, 1, Proto> prims + (tid-t0, nt, &peer, nullptr, nullptr, recvbuff, groupRecv); + ssize_t offset = 0; + do { + int nelem = roundUp(chunkSize, nt*(sizeof(uint64_t)/sizeof(T))); + nelem = min(chunkSize, recvCount-offset); + prims.directRecv(offset, nelem); + offset += nelem; + } while(offset < recvCount); + } + + if ((tid >= nThreadsSplit) && sendCount >= 0) { + int const peer = (rank + delta)%nRanks; + int const t0 = nThreadsSplit; + int const nt = nThreadsSegment - nThreadsSplit; + int const chunkSize = args->p2p.sendChunkSize/sizeof(T); + Primitives, 1, Proto> prims + (tid-t0, nt, nullptr, &peer, sendbuff, nullptr, groupSend); + ssize_t offset = 0; + do { + int nelem = roundUp(chunkSize, nt*(sizeof(uint64_t)/sizeof(T))); + nelem = min(chunkSize, sendCount-offset); + prims.directSend(offset, offset, nelem); + offset += nelem; + } while(offset < sendCount); + } + } + break; } + tid -= nThreadsSegment; } + } }; diff --git a/src/enqueue.cc b/src/enqueue.cc index 00920da..5f8c6ab 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -20,6 +20,31 @@ (void*)NCCL_FUNC5(func, RING, redop, type), \ (void*)NCCL_FUNC5(func, COLLNET, redop, type) +#if defined(__CUDA_BF16_TYPES_EXIST__) +// Must be consistent with ncclDataType_t +#define NCCL_FUNCS3A(func, redop) \ + (void*)NCCL_FUNC4(func, redop, int8_t), \ + (void*)NCCL_FUNC4(func, redop, uint8_t), \ + (void*)NCCL_FUNC4(func, redop, int32_t), \ + (void*)NCCL_FUNC4(func, redop, uint32_t), \ + (void*)NCCL_FUNC4(func, redop, int64_t), \ + (void*)NCCL_FUNC4(func, redop, uint64_t), \ + (void*)NCCL_FUNC4(func, redop, half), \ + (void*)NCCL_FUNC4(func, redop, float), \ + (void*)NCCL_FUNC4(func, redop, double), \ + (void*)NCCL_FUNC4(func, redop, __nv_bfloat16) +#define NCCL_FUNCS3B(func, redop) \ + (void*)NCCL_FUNC4(func, redop, int8_t), \ + (void*)NCCL_FUNC4(func, redop, int8_t), \ + (void*)NCCL_FUNC4(func, redop, int8_t), \ + (void*)NCCL_FUNC4(func, redop, int8_t), \ + (void*)NCCL_FUNC4(func, redop, int8_t), \ + (void*)NCCL_FUNC4(func, redop, int8_t), \ + (void*)NCCL_FUNC4(func, redop, int8_t), \ + (void*)NCCL_FUNC4(func, redop, int8_t), \ + (void*)NCCL_FUNC4(func, redop, int8_t), \ + (void*)NCCL_FUNC4(func, redop, int8_t) +#else // Must be consistent with ncclDataType_t #define NCCL_FUNCS3A(func, redop) \ (void*)NCCL_FUNC4(func, redop, int8_t), \ @@ -41,17 +66,20 @@ (void*)NCCL_FUNC4(func, redop, int8_t), \ (void*)NCCL_FUNC4(func, redop, int8_t), \ (void*)NCCL_FUNC4(func, redop, int8_t) +#endif // Must be consistent with ncclRedOp_t -- but we only generate kernel for sums. 
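// Illustrative sketch (hypothetical two-row table, not NCCL's real function-table layout):
// NCCL_FUNCS3A above lists one kernel per ncclDataType_t entry because reducing collectives
// must interpret element values, while NCCL_FUNCS3B reuses the int8_t instantiation for every
// type because copy-only collectives (e.g. Broadcast, AllGather) just move bytes.
#include <cstdio>

using Kern = const char*;                               // stand-in for a device kernel pointer
static const Kern reduceRow[] = {                       // "3A": per-type kernels
  "Sum_i8", "Sum_u8", "Sum_i32", "Sum_u32", "Sum_i64",
  "Sum_u64", "Sum_f16", "Sum_f32", "Sum_f64", "Sum_bf16" };
static const Kern copyRow[] = {                         // "3B": one byte-wise kernel for all
  "Copy_i8", "Copy_i8", "Copy_i8", "Copy_i8", "Copy_i8",
  "Copy_i8", "Copy_i8", "Copy_i8", "Copy_i8", "Copy_i8" };

int main() {
  const int bf16 = 9;                                   // position of bfloat16 in ncclDataType_t
  printf("reduce bf16 -> %s, copy bf16 -> %s\n", reduceRow[bf16], copyRow[bf16]);
}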
#define NCCL_FUNCS2A(func) \ NCCL_FUNCS3A(func, Sum), \ NCCL_FUNCS3A(func, Sum), \ NCCL_FUNCS3A(func, Sum), \ + NCCL_FUNCS3A(func, Sum), \ NCCL_FUNCS3A(func, Sum) #define NCCL_FUNCS2B(func) \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ + NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum) // Must be consistent with the ncclFuncSet enum @@ -154,16 +182,11 @@ static ncclResult_t setupLaunch(struct ncclQueueInfo* eqInfo, int usingCudaGraph channel->workFifo[(channel->workFifoTail-1)%NCCL_MAX_OPS].elems[0].active = 2; if (c == 0) { - // Find the first operation, choose the kernel accordingly and pass it as the first argument. - // Note that changing cuda launch argument after capture is not supported by cudaGraph + // As we inline the first coll directly, we can free it immediately. + // Except P2P or aggregation cases struct ncclWork* work = channel->workFifo+((channel->workFifoTail-channel->workCount)%NCCL_MAX_OPS); struct ncclWorkElem* elem = work->elems; - if (!usingCudaGraph) { - params->func = ncclKerns[elem->funcIndex]; - memcpy(&comm->args, elem, sizeof(struct ncclWorkElem)); - } - // As we inline the first coll directly, we can free it immediately. - if (elem->funcIndex != FUNC_INDEX_P2P) elem->active = 0; + if (elem->funcIndex != FUNC_INDEX_P2P && eqInfo->elemList->count() == 1) elem->active = 0; } if (channel->gdrMemDesc) { @@ -292,6 +315,7 @@ static ncclResult_t ncclLaunchProxy(struct ncclQueueInfo* eqInfo) { for (int r=0; rmaxChannels; r++) { struct ncclChannel* channel = comm->channels+r; channel->workCount = 0; + channel->totalSize = 0; } comm->lastChannel = 0; NCCLCHECK(ncclProxyStart(comm)); @@ -323,8 +347,7 @@ ncclResult_t ncclLaunchReset(ncclComm_t comm) { // But we need to keep the current enqueue info for CUDA graph // Thus we need to creating a new enqueue info for the next run if (comm->usingCudaGraph) { - NCCLCHECK(ncclCalloc(&comm->enqueueInfo, 1)); - comm->enqueueInfo->comm = comm; + NCCLCHECK(ncclCreateQueueInfo(&comm->enqueueInfo, comm)); } else { // If not in CUDA graph mode, we reuse the same info space NCCLCHECK(ncclResetQueueInfo(comm->enqueueInfo)); @@ -345,22 +368,29 @@ ncclResult_t ncclLaunchReset(ncclComm_t comm) { /* Enqueueing system : computation of kernel and proxy operations parameters */ /*****************************************************************************/ -static ncclResult_t getAlgoInfo(struct ncclInfo* info) { +static inline ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetTypeSupport) { + if (info->comm->collNetSupport > 0) { + ncclRedOp_t netOp = info->op == ncclAvg ? ncclSum : info->op; + NCCLCHECK(collNetReduceSupport(info->datatype, netOp, collNetTypeSupport)); + } else { + *collNetTypeSupport = 0; + } + return ncclSuccess; +} + +static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, int numPipeOps) { struct ncclComm* comm = info->comm; float minTime = 3600000000.0; // Hopefully no operation will take an hour to complete. // Find algorithm / protocol. 
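// Illustrative sketch (hypothetical numbers, simplified from the real tuning model): the loop
// below evaluates every (algorithm, protocol) pair with a predicted time of roughly
// latency + nBytes/(1000*busBw) microseconds and keeps the cheapest supported pair.
#include <cstdio>
#include <cstddef>

struct Choice { int algo, proto; float timeUs; };

Choice pickAlgoProto(const float latUs[2][3], const float bwGBs[2][3], size_t nBytes) {
  Choice best{-1, -1, 3600000000.0f};                  // "hopefully no op takes an hour"
  for (int a = 0; a < 2; a++)
    for (int p = 0; p < 3; p++) {
      if (bwGBs[a][p] <= 0) continue;                  // unsupported combination
      float t = latUs[a][p] + nBytes / (1000.0f * bwGBs[a][p]);
      if (t < best.timeUs) best = {a, p, t};
    }
  return best;
}

int main() {
  const float lat[2][3] = {{28.f, 8.5f, 5.f}, {3.4f, 4.f, 2.7f}};   // made-up latencies (us)
  const float bw [2][3] = {{20.f, 18.f, 10.f}, {22.f, 20.f, 12.f}}; // made-up bus bandwidths (GB/s)
  Choice c = pickAlgoProto(lat, bw, size_t(1) << 20);
  printf("algo=%d proto=%d predicted=%.1f us\n", c.algo, c.proto, c.timeUs);
}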
info->algorithm = -1; info->protocol = -1; + if (comm->nRanks == 1) return ncclSuccess; int nAlgos = NCCL_NUM_ALGORITHMS; - // Check collNet support - int collNetTypeSupport = 0; - if (info->comm->collNetSupport > 0) - NCCLCHECK(collNetReduceSupport(info->datatype, info->op, &collNetTypeSupport)); for (int a=0; a= 0 && time < minTime) { info->algorithm = a; info->protocol = p; @@ -397,7 +427,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info) { } if (info->protocol == NCCL_PROTO_SIMPLE) { nt += WARP_SIZE; // Extra warp for sync - if (info->algorithm == NCCL_ALGO_TREE) nt += WARP_SIZE; + if (info->algorithm == NCCL_ALGO_TREE) nt += 3*WARP_SIZE; if (info->algorithm == NCCL_ALGO_COLLNET) nt += 3*WARP_SIZE; } info->nChannels = nc; @@ -447,8 +477,14 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) { static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWorkElem* work, struct ncclProxyArgs* proxyArgs /* output */) { work->comm = info->comm->devComm; + int collNetTypeSupport = 0; + // Check whether algo and proto have been preset + if (info->nChannels > 0 && info->nThreads > 0) goto comp_next; + NCCLCHECK(getCollNetSupport(info, &collNetTypeSupport)); + NCCLCHECK(getAlgoInfo(info, collNetTypeSupport, 1)); + +comp_next: // Set nstepsPerLoop and nchunksPerLoop - NCCLCHECK(getAlgoInfo(info)); NCCLCHECK(getPatternInfo(info)); NCCLCHECK(getLoopInfo(info)); @@ -478,10 +514,9 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype); } else if (info->algorithm == NCCL_ALGO_COLLNET && info->protocol == NCCL_PROTO_SIMPLE) { // Optimize chunkSize / nSteps - while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*32 && chunkSize > 262144) chunkSize /= 2; - while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*16 && chunkSize > 131072) chunkSize /= 2; + while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*64 && chunkSize > 131072) chunkSize /= 2; + while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*8 && chunkSize > 65536) chunkSize /= 2; while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*8 && chunkSize > 32768) chunkSize /= 2; - while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth/2 && chunkSize > 16384) chunkSize /= 2; // Use lastChunkSize as chunkSize work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype); } else if (info->protocol == NCCL_PROTO_LL) { @@ -512,7 +547,9 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo proxyArgs->chunkSize = chunkSize; proxyArgs->protocol = info->protocol; proxyArgs->dtype = info->datatype; - proxyArgs->redOp = (info->algorithm == NCCL_ALGO_COLLNET) ? info->op : ncclNumOps; // Only set redOp when using CollNet + proxyArgs->redOp = info->algorithm != NCCL_ALGO_COLLNET ? ncclNumOps : // Only set redOp when using CollNet + info->op == ncclAvg ? ncclSum : // Network sees avg as sum + info->op; proxyArgs->pattern = info->pattern; proxyArgs->root = info->root; // This is used by P2P to reduce the receive buffer size. 
We don't use it in collectives @@ -550,7 +587,7 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) { // Compute cuda kernel arg and proxy arg templates struct ncclQueueElem* eqElem; - NCCLCHECK(ncclAddQueueElem(comm->enqueueInfo, &eqElem)); + NCCLCHECK(comm->enqueueInfo->elemList->getNewElem(&eqElem)); struct ncclWorkElem* work = &eqElem->work; eqElem->proxyArgs.nsubs = 1; NCCLCHECK(computeColl(info, work, &eqElem->proxyArgs)); @@ -573,6 +610,29 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) { return ncclSuccess; } +static inline int findShortestChannel(ncclComm_t comm) { + size_t minSize = SIZE_MAX; + int minC = 0; + for (int c=0; cnChannels; c++) { + struct ncclChannel* channel = comm->channels+c; + if (channel->totalSize < minSize) { + minSize = channel->totalSize; + minC = c; + } + } + return minC; +} + +static inline ncclResult_t getNextChannel(ncclComm_t comm, int* nextChannel) { + if (comm->asyncAllocMode == ncclComm::SHORTEST_QUEUE) { + *nextChannel = findShortestChannel(comm); + } else { + *nextChannel = comm->lastChannel % comm->nChannels; + comm->lastChannel++; + } + return ncclSuccess; +} + // Dynamic enqueue code static ncclResult_t ncclEnqueueCollKernel(ncclComm_t comm, struct ncclQueueElem* eqElem) { struct ncclWorkElem* work = &eqElem->work; @@ -600,9 +660,6 @@ static ncclResult_t ncclEnqueueCollKernel(ncclComm_t comm, struct ncclQueueElem* return ncclSuccess; } -#define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64) -#define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */ - ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) { if (comm->asyncOpCount == 0) { return ncclSuccess; @@ -613,19 +670,47 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) { NCCLCHECK(ncclSetupCollKernel(info)); } else { // Aggregation - size_t channelSize = NCCL_AGG_CHANNEL_SIZE * comm->nRanks; // scale channel size based on nranks as latency increases + size_t channelSize; + if (comm->channelSize > 0) { + channelSize = comm->channelSize; + } else if (comm->collNetSupport && comm->asyncOps[0].coll == ncclFuncAllReduce) { + channelSize = 256 * 1024; + } else { + channelSize = NCCL_AGG_CHANNEL_SIZE * std::min(16, comm->nRanks); // scale channel size based on nranks as latency increases + } // Reduce the per-channel size if we cannot fully utilize the channels while (comm->asyncTotalSize < channelSize * comm->nChannels && channelSize > NCCL_MIN_CHANNEL_SIZE) channelSize /= 2; int channelUsed = 0; + ncclFunc_t commonColl = ncclNumFuncs; + int fastPath = 1; + int allCollNetSupport = comm->collNetSupport; for (int c = 0; c < comm->asyncOpCount; c++) { struct ncclInfo* info = comm->asyncOps+c; - info->nChannels = std::min((int)DIVUP(info->nBytes, channelSize), comm->nChannels); // assign number of channels + info->nChannels = std::min(std::max(1, (int)DIVUP(info->nBytes, channelSize)), comm->nChannels); // assign number of channels channelUsed += info->nChannels; + // We can use fast path if all collectives are the same + if (commonColl == ncclNumFuncs) commonColl = info->coll; + else if (commonColl != info->coll) fastPath = 0; + else if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport)); + } + // Compute algo, proto, nthreads for the entire kernel + struct ncclInfo total; + total.comm = comm; + total.coll = commonColl; + total.nBytes = comm->asyncTotalSize; + total.nChannels = std::min(channelUsed, comm->nChannels); + int perChannelOps = DIVUP(channelUsed, total.nChannels); + if 
(fastPath) NCCLCHECK(getAlgoInfo(&total, allCollNetSupport, perChannelOps)); + for (int c = 0; c < comm->asyncOpCount; c++) { + struct ncclInfo* info = comm->asyncOps+c; + if (fastPath) { + info->algorithm = total.algorithm; + info->protocol = total.protocol; + info->nThreads = total.nThreads; + } NCCLCHECK(ncclSetupCollKernel(info)); } - // If we wrap around on channels, then the inlined op on channel 0 is not the last one on this channel - // Then we need to change active from 2 to 1 - if (channelUsed > comm->nChannels) comm->args.active = 1; + comm->args.active = 0; // disable inline argument } // Reset counters comm->asyncOpCount = 0; @@ -662,7 +747,7 @@ static ncclResult_t ncclSaveP2p(struct ncclInfo* info) { } } } - NCCLCHECK(enqueueP2pInfo(comm->p2pSends+info->root, (void*)info->sendbuff, nBytes)); + NCCLCHECK(ncclSaveP2pInfo(comm->p2pSends[info->root], (void*)info->sendbuff, nBytes)); comm->p2pSendCount++; } else { if (peer != comm->rank) { @@ -675,15 +760,22 @@ static ncclResult_t ncclSaveP2p(struct ncclInfo* info) { } } } - NCCLCHECK(enqueueP2pInfo(comm->p2pRecvs+info->root, info->recvbuff, nBytes)); + NCCLCHECK(ncclSaveP2pInfo(comm->p2pRecvs[info->root], info->recvbuff, nBytes)); comm->p2pRecvCount++; } return ncclSuccess; } -static int getSegment(int delta, struct ncclWork* work) { - for (int s=0; selems[s].p2p.delta != delta; s++) { - if (work->elems[s].p2p.nThreads == 0) return s; +enum { COLL_SEGMENT=0, P2P_SEGMENT=1 }; +static int getSegment(int type, int delta, struct ncclWork* work) { + if (type == P2P_SEGMENT) { // P2P + for (int s=0; selems[s].p2p.delta != delta; s++) { + if (work->elems[s].active == 0) return s; + } + } else { // aggregation + for (int s=0; selems[s].active == 0) return s; + } } return -1; } @@ -702,16 +794,19 @@ static ncclResult_t computeP2pWorkElem(struct ncclInfo* info /* input */, struct return ncclSuccess; } -static ncclResult_t enqueueP2pOp(struct ncclWorkElem* elem /* input */, struct ncclWork* work, int s) { +static ncclResult_t enqueueSegOp(int type, struct ncclWorkElem* elem /* input */, struct ncclWork* work, int s) { // Copy element into corresponding segment of ncclWork memcpy(work->elems+s, elem, sizeof(struct ncclWorkElem)); + work->elems[s].active = 1; // Determine nThreads at dynamic time - const int nsegments = s+1; - int nThreads = 512; - while (nsegments*nThreads > 512) nThreads /= 2; - if (nThreads >= 128) nThreads += WARP_SIZE; - for (int i=0; ielems[i].p2p.nThreads = nThreads; + if (type == P2P_SEGMENT) { + const int nsegments = s+1; + int nThreads = 512; + while (nsegments*nThreads > 512) nThreads /= 2; + if (nThreads >= 128) nThreads += WARP_SIZE; + for (int i=0; ielems[i].p2p.nThreads = nThreads; + } return ncclSuccess; } @@ -725,9 +820,9 @@ ncclResult_t ncclEnqueueP2pKernel(struct ncclComm* comm, struct ncclQueueElem* e int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS; struct ncclWork* w = channel->workFifo+opIndex; int segment = -1; - if (channel->workCount && w->elems[0].funcIndex == FUNC_INDEX_P2P && w->elems[NCCL_MAX_WORK_ELEMENTS-1].p2p.nThreads == 0) { + if (channel->workCount && w->elems[0].funcIndex == FUNC_INDEX_P2P && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) { // Try to pack more segments into a single operation - segment = getSegment(workElem->p2p.delta, w); + segment = getSegment(P2P_SEGMENT, workElem->p2p.delta, w); } if (segment == -1) { NCCLCHECK(getNextOp(channel, &w, NULL)); @@ -736,7 +831,7 @@ ncclResult_t ncclEnqueueP2pKernel(struct ncclComm* comm, struct ncclQueueElem* e // store 
work element into FIFO NCCLCHECK(ncclProxySaveP2p(comm, proxyArgs)); - NCCLCHECK(enqueueP2pOp(workElem, w, segment)); + NCCLCHECK(enqueueSegOp(P2P_SEGMENT, workElem, w, segment)); return ncclSuccess; } @@ -744,7 +839,7 @@ ncclResult_t ncclSetupP2pKernel(struct ncclInfo* info) { ncclComm* comm = info->comm; // Compute cuda kernel arg and proxy arg templates struct ncclQueueElem* eqElem; - NCCLCHECK(ncclAddQueueElem(comm->enqueueInfo, &eqElem)); + NCCLCHECK(comm->enqueueInfo->elemList->getNewElem(&eqElem)); // The proxy code will set and tune the send/recv chunk size, make sure to run it first. NCCLCHECK(ncclProxyComputeP2p(info, &eqElem->proxyArgs)); NCCLCHECK(computeP2pWorkElem(info, &eqElem->work)); @@ -760,11 +855,51 @@ ncclResult_t ncclSetupP2pKernel(struct ncclInfo* info) { // The CUDA kernel does not use the inlined first work element as fastpath argument if (params->func == NULL) { params->func = ncclKerns[eqElem->work.funcIndex]; - memcpy(&comm->args, &eqElem->work, sizeof(struct ncclWorkElem)); + comm->args.comm = eqElem->work.comm; + comm->args.active = 0; } return ncclSuccess; } +ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem* eqElem) { + struct ncclWorkElem* work = &eqElem->work; + struct ncclProxyArgs* proxyArgs = &eqElem->proxyArgs; + + int nChannels = work->coll.nChannels; + size_t channelSize = work->coll.count*ncclTypeSize(proxyArgs->dtype)/work->coll.nChannels; + for (int bid=0; bidchannels+channelId; + + // Proxy + proxyArgs->subs[0].channel = channel; + proxyArgs->opCount = comm->collOpCount; + proxyArgs->commOpCount = comm->opCount; + if (proxyArgs->subs[0].nsteps) NCCLCHECK(ncclProxySaveColl(proxyArgs, comm->nRanks)); + + // Try to reuse last work if not full yet + work->coll.bid = bid % nChannels; + int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS; + struct ncclWork* w = channel->workFifo+opIndex; + int segment = -1; + if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) { + // Try to pack more segments into a single operation + segment = getSegment(COLL_SEGMENT, 0, w); + } + if (segment == -1) { + NCCLCHECK(getNextOp(channel, &w, NULL)); + segment = 0; + } + + // store work element into FIFO + NCCLCHECK(enqueueSegOp(COLL_SEGMENT, work, w, segment)); + channel->totalSize += channelSize; + } + comm->collOpCount++; + return ncclSuccess; +} + template void CUDART_CB ncclEnqueueHostSetup(void* arg) { ncclResult_t ret; @@ -772,14 +907,17 @@ void CUDART_CB ncclEnqueueHostSetup(void* arg) { ncclComm_t comm = eqInfo->comm; // Iterate through the element list - struct ncclQueueElem* eqElem = eqInfo->elemList.head; - while (eqElem != eqInfo->elemList.tail) { // The queue always has one extra element + struct ncclQueueElem* eqElem = eqInfo->elemList->begin(); + while (eqElem != NULL) { if (eqElem->work.funcIndex == FUNC_INDEX_P2P) { NCCLCHECKGOTO(ncclEnqueueP2pKernel(comm, eqElem), ret, cb_end); + } else if (eqInfo->elemList->count() > 1) { + // We have more than one operation, hence aggregating + NCCLCHECKGOTO(ncclEnqueueAsyncKernel(comm, eqElem), ret, cb_end); } else { NCCLCHECKGOTO(ncclEnqueueCollKernel(comm, eqElem), ret, cb_end); } - eqElem = eqElem->next; + eqElem = eqInfo->elemList->getNext(); } NCCLCHECKGOTO(setupLaunch(eqInfo, USING_CUDA_GRAPH), ret, cb_end); diff --git a/src/graph/paths.cc b/src/graph/paths.cc index fae5afa..64c54df 100644 --- a/src/graph/paths.cc +++ b/src/graph/paths.cc @@ -388,7 +388,9 @@ ncclResult_t ncclTopoComputePaths(struct ncclTopoSystem* system, struct ncclPeer struct 
ncclPeerInfo* srcInfo = peerInfos+system->nodes[GPU].nodes[p].gpu.rank; int shm; NCCLCHECK(ncclTransports[TRANSPORT_SHM].canConnect(&shm, system, NULL, srcInfo, dstInfo)); - if (shm == 0) { + int p2p; + NCCLCHECK(ncclTransports[TRANSPORT_P2P].canConnect(&p2p, system, NULL, srcInfo, dstInfo)); + if (shm == 0 && p2p == 0) { // Mark this peer as inaccessible. We'll trim it later. system->nodes[GPU].nodes[p].paths[GPU][g].count = 0; } diff --git a/src/graph/search.cc b/src/graph/search.cc index 7ced017..8894bd1 100644 --- a/src/graph/search.cc +++ b/src/graph/search.cc @@ -707,8 +707,10 @@ ncclResult_t ncclTopoGetXmlFromGraphs(int ngraphs, struct ncclTopoGraph** graphs return ncclSuccess; } -float speedArray[] = { 42.0, 30.0, 24.0, 21.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 }; -#define NSPEEDS (sizeof(speedArray)/sizeof(float)) +float speedArrayIntra[] = { 44.0, 30.0, 22.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0 }; +float speedArrayInter[] = { 48.0, 30.0, 24.0, 22.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 }; +#define NSPEEDSINTRA (sizeof(speedArrayIntra)/sizeof(float)) +#define NSPEEDSINTER (sizeof(speedArrayInter)/sizeof(float)) ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph) { int ngpus = system->nodes[GPU].count; @@ -738,15 +740,23 @@ ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph // SPLIT_TREE works better on older archs. int ccMin; NCCLCHECK(ncclTopoGetCompCap(system, &ccMin, NULL)); - if (ccMin < 80 && graph->pattern == NCCL_TOPO_PATTERN_BALANCED_TREE) graph->pattern = NCCL_TOPO_PATTERN_SPLIT_TREE; struct ncclTopoGraph tmpGraph; memcpy(&tmpGraph, graph, sizeof(struct ncclTopoGraph)); // First try crossnic, then decrease speed and finally increase speedIntra. 
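// Illustrative sketch (standalone, made-up maxWidth; the array literal mirrors speedArrayInter
// above): the search starts at the largest speed that does not exceed the measured maxWidth
// and, when no graph is found, falls back one entry at a time as long as the next speed is
// still more than ~49% of the best solution found so far.
#include <cstdio>

int main() {
  const float speedArrayInter[] = { 48.0f, 30.0f, 24.0f, 22.0f, 18.0f, 15.0f, 12.0f, 10.0f,
                                    9.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.4f, 1.2f, 0.24f, 0.12f };
  const int nspeeds = sizeof(speedArrayInter)/sizeof(float);
  const float maxWidth = 23.0f;                        // assumed system width for the example

  int speedIndex = 0;
  while (speedArrayInter[speedIndex] > maxWidth && speedIndex < nspeeds-1) speedIndex++;
  printf("start search at %.1f GB/s (index %d)\n", speedArrayInter[speedIndex], speedIndex);
  if (speedIndex < nspeeds-1)
    printf("first fallback would be %.1f GB/s\n", speedArrayInter[speedIndex+1]);
}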
+ int nspeeds = 0; + float* speedArray = NULL; + if (system->nodes[NET].count == 0) { + nspeeds = NSPEEDSINTRA; + speedArray = speedArrayIntra; + } else { + nspeeds = NSPEEDSINTER; + speedArray = speedArrayInter; + } int pass = 1; int speedIndex = 0; - while (speedArray[speedIndex] > system->maxWidth && speedIndex < NSPEEDS-1) speedIndex++; + while (speedArray[speedIndex] > system->maxWidth && speedIndex < nspeeds-1) speedIndex++; tmpGraph.speedIntra = tmpGraph.speedInter = speedArray[speedIndex]; int64_t globalTimeout = NCCL_SEARCH_GLOBAL_TIMEOUT; @@ -813,12 +823,12 @@ search: tmpGraph.crossNic = 0; // Decrease speed until we find a solution - if ((speedIndex < NSPEEDS-1) && (graph->nChannels == 0 || (speedArray[speedIndex+1]/graph->speedInter > .49))) { + if ((speedIndex < nspeeds-1) && (graph->nChannels == 0 || (speedArray[speedIndex+1]/graph->speedInter > .49))) { tmpGraph.speedInter = tmpGraph.speedIntra = speedArray[++speedIndex]; goto search; } speedIndex = 0; - while (speedArray[speedIndex] > system->maxWidth && speedIndex < NSPEEDS-1) speedIndex++; + while (speedArray[speedIndex] > system->maxWidth && speedIndex < nspeeds-1) speedIndex++; tmpGraph.speedIntra = tmpGraph.speedInter = speedArray[speedIndex]; } @@ -829,7 +839,7 @@ done: time = -1; memcpy(&tmpGraph, graph, sizeof(tmpGraph)); speedIndex = 0; - while (speedArray[speedIndex] > graph->speedInter && speedIndex < NSPEEDS-1) speedIndex++; + while (speedArray[speedIndex] > graph->speedInter && speedIndex < nspeeds-1) speedIndex++; tmpGraph.speedIntra = tmpGraph.speedInter = speedArray[speedIndex]; tmpGraph.minChannels = graph->nChannels; pass = 2; diff --git a/src/graph/topo.cc b/src/graph/topo.cc index 52d5406..135f569 100644 --- a/src/graph/topo.cc +++ b/src/graph/topo.cc @@ -583,7 +583,10 @@ ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** sy char* xmlTopoFile = getenv("NCCL_TOPO_FILE"); if (xmlTopoFile) { INFO(NCCL_ENV, "NCCL_TOPO_FILE set by environment to %s", xmlTopoFile); - NCCLCHECK(ncclTopoGetXmlFromFile(xmlTopoFile, xml)); + NCCLCHECK(ncclTopoGetXmlFromFile(xmlTopoFile, xml, 1)); + } else { + // Try default XML topology location + NCCLCHECK(ncclTopoGetXmlFromFile("/var/run/nvidia-topologyd/virtualTopology.xml", xml, 0)); } if (xml->maxIndex == 0) { // Create top tag @@ -691,7 +694,7 @@ ncclResult_t ncclTopoCpuType(struct ncclTopoSystem* system, int* arch, int* vend NCCL_PARAM(IgnoreCpuAffinity, "IGNORE_CPU_AFFINITY", 0); -ncclResult_t ncclTopoSetAffinity(struct ncclTopoSystem* system, int rank) { +ncclResult_t ncclTopoGetCpuAffinity(struct ncclTopoSystem* system, int rank, cpu_set_t* affinity) { struct ncclTopoNode* cpu = NULL, *gpu = NULL; for (int g=0; gnodes[GPU].count; g++) { if (system->nodes[GPU].nodes[g].gpu.rank == rank) { @@ -744,12 +747,13 @@ ncclResult_t ncclTopoSetAffinity(struct ncclTopoSystem* system, int rank) { // Use a subset of the GPU affinity set CPU_AND(&finalMask, &mask, &cpuMask); + memcpy(affinity, &finalMask, sizeof(cpu_set_t)); + // If there is a non empty set, use it to set affinity if (CPU_COUNT(&finalMask)) { char affinityStr[sizeof(cpu_set_t)*2]; NCCLCHECK(ncclCpusetToStr(&finalMask, affinityStr)); INFO(NCCL_INIT, "Setting affinity for GPU %d to %s", gpu->gpu.dev, affinityStr); - SYSCHECK(sched_setaffinity(0, sizeof(cpu_set_t), &finalMask), "sched_setaffinity"); } return ncclSuccess; } diff --git a/src/graph/topo.h b/src/graph/topo.h index 1e10bb2..304b496 100644 --- a/src/graph/topo.h +++ b/src/graph/topo.h @@ -9,12 +9,11 @@ #include "graph.h" #include 
"core.h" -#include #define LOC_WIDTH 5000.0 #define SM60_NVLINK_WIDTH 18.0 -#define SM70_NVLINK_WIDTH 21.0 -#define SM80_NVLINK_WIDTH 21.0 +#define SM70_NVLINK_WIDTH 22.0 +#define SM80_NVLINK_WIDTH 22.0 #define SM86_NVLINK_WIDTH 12.0 #define PCI_WIDTH 12.0 // PCI Gen3 x16 #define QPI_WIDTH 6.0 diff --git a/src/graph/tuning.cc b/src/graph/tuning.cc index db085cb..e30a927 100644 --- a/src/graph/tuning.cc +++ b/src/graph/tuning.cc @@ -60,20 +60,19 @@ static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { { 4.4, #define NCCL_HW_PCI 1 #define NCCL_HW_NET 2 // Tree/Simple is the latency a 256kB chunk, which is ~ base lat + 256k/12GB/s (+ 256k/12GB/s for the network). -static const float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = +static float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { /* NVLINK */ - { /* Tree (LL/LL128/Simple)*/ { .52, 1.25, 28 }, /* Ring (LL/LL128/Simple)*/ { .47, 1.9, 3.4 }, /* CollNet (LL/LL128/Simple)*/ { .5, 1.2, 4.0 } }, + { /* Tree (LL/LL128/Simple)*/ { .52, 1.25, 28 }, /* Ring (LL/LL128/Simple)*/ { .47, 1.9, 3.4 }, /* CollNet (LL/LL128/Simple)*/ { .5, 1.2, 8.0 } }, /* PCI */ - { /* Tree (LL/LL128/Simple)*/ { 1.0, 1.9, 28 }, /* Ring (LL/LL128/Simple)*/ { 1.0, 2.5, 5.7 }, /* CollNet (LL/LL128/Simple)*/ { 1.0, 1.9, 5.5 } }, + { /* Tree (LL/LL128/Simple)*/ { 1.0, 1.9, 28 }, /* Ring (LL/LL128/Simple)*/ { 1.0, 2.5, 5.7 }, /* CollNet (LL/LL128/Simple)*/ { 1.0, 1.9, 8.0 } }, /* NET */ { /* Tree (LL/LL128/Simple)*/ { 5.0, 8.5, 28 }, /* Ring (LL/LL128/Simple)*/ { 2.7, 4.0, 9.6 }, /* CollNet (LL/LL128/Simple)*/ { 5.0, 5.0, 10.7 } } }; -// LL128 max BW (per channel) for the different collectives -// ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce -static const double ll128MaxBwPerCh[NCCL_NUM_FUNCTIONS] = { 18.8, 12.0, 18.3, 15.2, 16.9 }; +// LL128 max BW per channel +static const double ll128MaxBwPerCh = 20.0; static const double llMaxBws[2][3] = { /* Volta-N1/Intel-N2/Intel-N4) */ {39.0, 39.0, 20.4}, /* Ampere-N1/AMD-N2/AMD-N4) */ {87.7, 22.5 /*avg of ring & tree*/, 19.0} }; -static const double perChMaxTreeBws[2][3] = { /* Volta (N1/N2/N4) */ {26.5, 18.5, 10.0}, /* Ampere (N1/N2/N4) */ {24.0, 22.5, 16.0} }; +static const double perChMaxTreeBws[2][3] = { /* Volta (N1/N2/N4) */ {26.5, 18.5, 10.0}, /* Ampere (N1/N2/N4) */ {24.0, 23.6, 17.8} }; ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph) { int simpleDefaultThreads = (ringGraph->speedIntra*ringGraph->nChannels <= PCI_WIDTH) ? 256 : NCCL_SIMPLE_MAX_NTHREADS; @@ -100,6 +99,8 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom int index1 = nNodes == 1 ? compCap80 : cpuVendor == NCCL_TOPO_CPU_VENDOR_AMD ? 
1 : 0; double llMaxBw = llMaxBws[index1][index2]; double perChMaxTreeBw = perChMaxTreeBws[compCap80][index2]; + // De-penalize Tree/Simple latency on Power systems to favor Tree than Ring + if (cpuArch == NCCL_TOPO_CPU_ARCH_POWER) hwLat[NCCL_HW_PCI][NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = hwLat[NCCL_HW_PCI][NCCL_ALGO_RING][NCCL_PROTO_SIMPLE]; float ppn = (float)nRanks / nNodes; // if ppn < 2, then we are sending/receiving at the same GPU through the NIC, apply some bw discount struct ncclTopoGraph* graphs[NCCL_NUM_ALGORITHMS] = { treeGraph, ringGraph, collNetGraph }; @@ -125,11 +126,10 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom // Various model refinements if (compCap80) busBw = std::min(busBw, 235.0f); if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL) { busBw = std::min(llMaxBw, busBw * ((nNodes > 1 || coll == ncclFuncAllReduce || coll == ncclFuncReduce) ? 1.0/4.0 : 1.0/3.0)); } - if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh[coll]*graphs[a]->nChannels); + if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh*graphs[a]->nChannels); if (a == NCCL_ALGO_TREE) busBw = std::min(busBw*.92, graphs[a]->nChannels*perChMaxTreeBw); if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL) busBw = std::min(busBw*1.0/3.8, llMaxBw); - if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 0.915 /*120.0/128.0*/), ll128MaxBwPerCh[coll]*graphs[a]->nChannels); - if (a == NCCL_ALGO_COLLNET) busBw *= .9; + if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 120.0/128.0), ll128MaxBwPerCh*graphs[a]->nChannels); if (a == NCCL_ALGO_COLLNET && p != NCCL_PROTO_SIMPLE) busBw = 0; // Oneshot CollNet only supports Simple // Convert bus BW to algorithm BW @@ -157,7 +157,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom 2 * ((nRanks/nNodes-1) * intraLat + log2i(nNodes) * interLat); } else { comm->latencies[coll][a][p] += - 2 * (nRanks/nNodes-1) * intraLat + interLat; + 2 * (std::min(1, (nRanks/nNodes-1)) * intraLat + (nRanks/nNodes-1) * 0.5) + interLat; // Add 0.5 arity serialization latency } } } @@ -266,11 +266,11 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom // factor is not ideal but works quite well. Powers of two, 64 B to 256MB. 
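// Illustrative sketch (the index math here is an assumption for illustration, not copied from
// tuning.cc; the factors row copies the third row of the table below, which should be the
// Simple protocol): the 23-entry treeCorrectionFactor table covers message sizes from 64 B
// (entry 0) to 256 MB (entry 22) in powers of two, so a size maps to roughly log2(nBytes/64),
// clamped to the table bounds.
#include <cstdio>

static int log2i(long long n) { int l = 0; while (n >>= 1) l++; return l; }

float treeCorrection(const float factors[23], long long nBytes) {
  int idx = log2i(nBytes >> 6);                        // 64 B -> 0, 128 B -> 1, ..., 256 MB -> 22
  if (idx < 0) idx = 0;
  if (idx > 22) idx = 22;
  return factors[idx];
}

int main() {
  const float simple[23] = { .9f, .9f, .9f, .9f, .9f, .9f, .9f, .8f, .7f, .6f, .6f, .5f,
                             .5f, .5f, .5f, .6f, .7f, .8f, .7f, .7f, .8f, .9f, .9f };
  printf("tree correction at 1 MiB = %.2f\n", treeCorrection(simple, 1 << 20));   // 0.50
}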
static float treeCorrectionFactor[NCCL_NUM_PROTOCOLS][23] = { { 1.0, 1.0, 1.0, 1.0, .9, .8, .7, .7, .7, .7, .6, .5, .4, .4, .5, .6, .7, .8, .9, 1.0, 1.0, 1.0, 1.0 }, - { 1.0, 1.0, 1.0, 1.0, 1.0, .9, .8, .8, .8, .7, .6, .6, .6, .5, .6, .6, .7, .7, .8, .9, .9, .92, .92 }, + { 1.0, 1.0, 1.0, 1.0, 1.0, .9, .8, .8, .8, .7, .6, .6, .6, .6, .6, .6, .8, .9, .9, .9, .9, 1.0, 1.0 }, { .9, .9, .9, .9, .9, .9, .9, .8, .7, .6, .6, .5, .5, .5, .5, .6, .7, .8, .7, .7, .8, .9, .9 } }; -ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, float* time) { +ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time) { float bw = info->comm->bandwidths[info->coll][algorithm][protocol]; float lat = info->comm->latencies[info->coll][algorithm][protocol]; if (bw == 0) { @@ -281,6 +281,8 @@ ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int proto if (info->nChannels != 0) bw = bw / info->comm->nChannels * info->nChannels; if (algorithm == NCCL_ALGO_RING && protocol == NCCL_PROTO_SIMPLE && info->comm->nNodes > 1 && info->coll == ncclFuncAllReduce && info->nBytes >= info->comm->nRanks/16.0*65536) lat *= 1.9; // Plateau effect of ring - *time = lat + (info->nBytes) / (1000 * bw); + // Tree pipelining saves latency in aggregation cases + int latCount = algorithm == NCCL_ALGO_RING ? numPipeOps : DIVUP(numPipeOps, NCCL_MAX_WORK_ELEMENTS); + *time = lat * latCount + (info->nBytes) / (1000 * bw); return ncclSuccess; } diff --git a/src/graph/xml.cc b/src/graph/xml.cc index 05a77bf..29e8f00 100644 --- a/src/graph/xml.cc +++ b/src/graph/xml.cc @@ -300,12 +300,15 @@ ncclResult_t ncclTopoXmlLoadSystem(FILE* file, struct ncclXml* xml, struct ncclX return ncclSuccess; } -ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml) { +ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml, int warn) { FILE* file = fopen(xmlTopoFile, "r"); if (file == NULL) { - WARN("Could not open XML topology file %s : %s", xmlTopoFile, strerror(errno)); + if (warn) { + WARN("Could not open XML topology file %s : %s", xmlTopoFile, strerror(errno)); + } return ncclSuccess; } + INFO(NCCL_GRAPH, "Loading topology file %s", xmlTopoFile); struct xmlHandler handlers[] = { { "system", ncclTopoXmlLoadSystem } }; xml->maxIndex = 0; NCCLCHECK(xmlLoadSub(file, xml, NULL, handlers, 1)); @@ -441,8 +444,8 @@ ncclResult_t ncclTopoGetPciNode(struct ncclXml* xml, const char* busId, struct n NCCLCHECK(xmlFindTagKv(xml, "pci", pciNode, "busid", busId)); if (*pciNode == NULL) { NCCLCHECK(xmlAddNode(xml, NULL, "pci", pciNode)); + NCCLCHECK(xmlSetAttr(*pciNode, "busid", busId)); } - NCCLCHECK(xmlSetAttr(*pciNode, "busid", busId)); return ncclSuccess; } @@ -463,100 +466,114 @@ ncclResult_t ncclTopoGetXmlFromSys(struct ncclXmlNode* pciNode, struct ncclXml* const char* busId; NCCLCHECK(xmlGetAttr(pciNode, "busid", &busId)); char* path = NULL; - int index; - NCCLCHECK(xmlGetAttrIndex(pciNode, "class", &index)); - if (index == -1) { - if (path == NULL) NCCLCHECK(getPciPath(busId, &path)); + ncclDebugNoWarn = NCCL_GRAPH; + getPciPath(busId, &path); + ncclDebugNoWarn = 0; + + if (path) { NCCLCHECK(ncclTopoSetAttrFromSys(pciNode, path, "class", "class")); } + int index; ncclDebugNoWarn = NCCL_GRAPH; NCCLCHECK(xmlGetAttrIndex(pciNode, "vendor", &index)); if (index == -1) { - if (path == NULL) getPciPath(busId, &path); if (path) ncclTopoSetAttrFromSys(pciNode, path, "vendor", "vendor"); } NCCLCHECK(xmlGetAttrIndex(pciNode, 
"device", &index)); if (index == -1) { - if (path == NULL) getPciPath(busId, &path); if (path) ncclTopoSetAttrFromSys(pciNode, path, "device", "device"); } NCCLCHECK(xmlGetAttrIndex(pciNode, "subsystem_vendor", &index)); if (index == -1) { - if (path == NULL) getPciPath(busId, &path); if (path) ncclTopoSetAttrFromSys(pciNode, path, "subsystem_vendor", "subsystem_vendor"); } NCCLCHECK(xmlGetAttrIndex(pciNode, "subsystem_device", &index)); if (index == -1) { - if (path == NULL) getPciPath(busId, &path); if (path) ncclTopoSetAttrFromSys(pciNode, path, "subsystem_device", "subsystem_device"); } ncclDebugNoWarn = 0; NCCLCHECK(xmlGetAttrIndex(pciNode, "link_speed", &index)); if (index == -1) { - if (path == NULL) NCCLCHECK(getPciPath(busId, &path)); - char deviceSpeedStr[MAX_STR_LEN]; - float deviceSpeed; - NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_speed", deviceSpeedStr)); - sscanf(deviceSpeedStr, "%f GT/s", &deviceSpeed); - char portSpeedStr[MAX_STR_LEN]; - float portSpeed; - NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_speed", portSpeedStr)); - sscanf(portSpeedStr, "%f GT/s", &portSpeed); - NCCLCHECK(xmlSetAttr(pciNode, "link_speed", portSpeed < deviceSpeed ? portSpeedStr : deviceSpeedStr)); + if (path) { + char deviceSpeedStr[MAX_STR_LEN]; + float deviceSpeed; + NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_speed", deviceSpeedStr)); + sscanf(deviceSpeedStr, "%f GT/s", &deviceSpeed); + char portSpeedStr[MAX_STR_LEN]; + float portSpeed; + NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_speed", portSpeedStr)); + sscanf(portSpeedStr, "%f GT/s", &portSpeed); + NCCLCHECK(xmlSetAttr(pciNode, "link_speed", portSpeed < deviceSpeed ? portSpeedStr : deviceSpeedStr)); + } else { + NCCLCHECK(xmlSetAttr(pciNode, "link_speed", "")); + } } NCCLCHECK(xmlGetAttrIndex(pciNode, "link_width", &index)); if (index == -1) { - if (path == NULL) NCCLCHECK(getPciPath(busId, &path)); - char strValue[MAX_STR_LEN]; - NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_width", strValue)); - int deviceWidth = strtol(strValue, NULL, 0); - NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_width", strValue)); - int portWidth = strtol(strValue, NULL, 0); - NCCLCHECK(xmlSetAttrInt(pciNode, "link_width", std::min(deviceWidth,portWidth))); + if (path) { + char strValue[MAX_STR_LEN]; + NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_width", strValue)); + int deviceWidth = strtol(strValue, NULL, 0); + NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_width", strValue)); + int portWidth = strtol(strValue, NULL, 0); + NCCLCHECK(xmlSetAttrInt(pciNode, "link_width", std::min(deviceWidth,portWidth))); + } else { + NCCLCHECK(xmlSetAttr(pciNode, "link_width", "")); + } } struct ncclXmlNode* parent = pciNode->parent; if (parent == NULL) { - if (path == NULL) NCCLCHECK(getPciPath(busId, &path)); + if (path) { + // Save that for later in case next step is a CPU + char numaIdStr[MAX_STR_LEN]; + NCCLCHECK(ncclTopoGetStrFromSys(path, "numa_node", numaIdStr)); - // Save that for later in case next step is a CPU - char numaIdStr[MAX_STR_LEN]; - NCCLCHECK(ncclTopoGetStrFromSys(path, "numa_node", numaIdStr)); - - // Go up one level in the PCI tree. Rewind two "/" and follow the upper PCI - // switch, or stop if we reach a CPU root complex. 
- int slashCount = 0; - int parentOffset; - for (parentOffset = strlen(path)-1; parentOffset>0; parentOffset--) { - if (path[parentOffset] == '/') { - slashCount++; - path[parentOffset] = '\0'; - int start = parentOffset - 1; - while (start>0 && path[start] != '/') start--; - // Check whether the parent path looks like "BBBB:BB:DD.F" or not. - if (checkBDFFormat(path+start+1) == 0) { - // This a CPU root complex. Create a CPU tag and stop there. - struct ncclXmlNode* topNode; - NCCLCHECK(xmlFindTag(xml, "system", &topNode)); - NCCLCHECK(xmlGetSubKv(topNode, "cpu", &parent, "numaid", numaIdStr)); - if (parent == NULL) { - NCCLCHECK(xmlAddNode(xml, topNode, "cpu", &parent)); - NCCLCHECK(xmlSetAttr(parent, "numaid", numaIdStr)); - } - } else if (slashCount == 2) { - // Continue on the upper PCI switch - for (int i = strlen(path)-1; i>0; i--) { - if (path[i] == '/') { - NCCLCHECK(xmlFindTagKv(xml, "pci", &parent, "busid", path+i+1)); - if (parent == NULL) { - NCCLCHECK(xmlAddNode(xml, NULL, "pci", &parent)); - NCCLCHECK(xmlSetAttr(parent, "busid", path+i+1)); + // Go up one level in the PCI tree. Rewind two "/" and follow the upper PCI + // switch, or stop if we reach a CPU root complex. + int slashCount = 0; + int parentOffset; + for (parentOffset = strlen(path)-1; parentOffset>0; parentOffset--) { + if (path[parentOffset] == '/') { + slashCount++; + path[parentOffset] = '\0'; + int start = parentOffset - 1; + while (start>0 && path[start] != '/') start--; + // Check whether the parent path looks like "BBBB:BB:DD.F" or not. + if (checkBDFFormat(path+start+1) == 0) { + // This a CPU root complex. Create a CPU tag and stop there. + struct ncclXmlNode* topNode; + NCCLCHECK(xmlFindTag(xml, "system", &topNode)); + NCCLCHECK(xmlGetSubKv(topNode, "cpu", &parent, "numaid", numaIdStr)); + if (parent == NULL) { + NCCLCHECK(xmlAddNode(xml, topNode, "cpu", &parent)); + NCCLCHECK(xmlSetAttr(parent, "numaid", numaIdStr)); + } + } else if (slashCount == 2) { + // Continue on the upper PCI switch + for (int i = strlen(path)-1; i>0; i--) { + if (path[i] == '/') { + NCCLCHECK(xmlFindTagKv(xml, "pci", &parent, "busid", path+i+1)); + if (parent == NULL) { + NCCLCHECK(xmlAddNode(xml, NULL, "pci", &parent)); + NCCLCHECK(xmlSetAttr(parent, "busid", path+i+1)); + } + break; } - break; } } } + if (parent) break; + } + } else { + // No information on /sys, attach GPU to unknown CPU + NCCLCHECK(xmlFindTagKv(xml, "cpu", &parent, "numaid", "-1")); + if (parent == NULL) { + struct ncclXmlNode* topNode; + NCCLCHECK(xmlFindTag(xml, "system", &topNode)); + NCCLCHECK(xmlAddNode(xml, topNode, "cpu", &parent)); + NCCLCHECK(xmlSetAttr(parent, "numaid", "-1")); + NCCLCHECK(ncclTopoGetXmlFromCpu(parent, xml)); } - if (parent) break; } pciNode->parent = parent; parent->subs[parent->nSubs++] = pciNode; @@ -661,12 +678,14 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm if (index == -1) { const char* busId; NCCLCHECK(xmlGetAttr(sub, "target", &busId)); - if (strcmp(busId, "fffffff:ffff:ff") == 0) { + char* path; + ncclDebugNoWarn = NCCL_GRAPH; + getPciPath(busId, &path); + ncclDebugNoWarn = 0; + if (path == NULL || strcmp(busId, "fffffff:ffff:ff") == 0) { // Remote NVLink device is not visible inside this VM. Assume NVSwitch. 
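// Illustrative sketch (decoder for illustration only, not NCCL code): the class attributes used
// in this part of the patch identify device kinds when /sys cannot be read -- base class 0x03
// is a display controller (GPU), 0x02 a network controller (NIC), and 0x068000 an "other
// bridge", which is what an NVSwitch reports, hence the fffffff:ffff:ff fallback above.
#include <cstdio>

static const char* pciBaseClassKind(unsigned cls) {
  switch (cls >> 16) {                 // top byte of the 24-bit class code is the base class
    case 0x03: return "GPU/display";
    case 0x02: return "NIC";
    case 0x06: return "bridge (e.g. NVSwitch)";
    default:   return "other";
  }
}

int main() {
  printf("%s %s %s\n", pciBaseClassKind(0x030000), pciBaseClassKind(0x020000),
         pciBaseClassKind(0x068000));
}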
      NCCLCHECK(xmlSetAttr(sub, "tclass", "0x068000"));
    } else {
-      char* path;
-      NCCLCHECK(getPciPath(busId, &path));
      NCCLCHECK(ncclTopoSetAttrFromSys(sub, path, "class", "tclass"));
      free(path);
    }
@@ -679,6 +698,7 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm
 ncclResult_t ncclTopoFillGpu(struct ncclXml* xml, const char* busId, struct ncclXmlNode** gpuNode) {
   struct ncclXmlNode* node;
   NCCLCHECK(ncclTopoGetPciNode(xml, busId, &node));
+  NCCLCHECK(xmlSetAttrIfUnset(node, "class", "0x03"));
   NCCLCHECK(ncclTopoGetXmlFromSys(node, xml));
   nvmlDevice_t nvmlDev = NULL;
   static int nvmlInit = 0;
@@ -731,6 +751,7 @@ ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const cha
     char busId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE];
     strcpy(busId, pciSysPath+offset+1);
     NCCLCHECK(ncclTopoGetPciNode(xml, busId, &parent));
+    NCCLCHECK(xmlSetAttrIfUnset(parent, "class", "0x02"));
    NCCLCHECK(ncclTopoGetXmlFromSys(parent, xml));
  } else {
    // Virtual NIC, no PCI device, attach to first CPU
diff --git a/src/graph/xml.h b/src/graph/xml.h
index 6f1ecfb..76f29b2 100644
--- a/src/graph/xml.h
+++ b/src/graph/xml.h
@@ -38,7 +38,7 @@ struct ncclXml {
 /* File functions */
 #define NCCL_TOPO_XML_VERSION 1
-ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml);
+ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml, int warn);
 ncclResult_t ncclTopoDumpXmlToFile(const char* xmlTopoFile, struct ncclXml* xml);
 #define NCCL_GRAPH_XML_VERSION 1
 ncclResult_t ncclTopoGetXmlGraphFromFile(const char* xmlGraphFile, struct ncclXml* xml);
@@ -137,6 +137,18 @@ static ncclResult_t xmlSetAttr(struct ncclXmlNode* node, const char* attrName, c
   return ncclSuccess;
 }
 
+static ncclResult_t xmlSetAttrIfUnset(struct ncclXmlNode* node, const char* attrName, const char* value) {
+  int index;
+  NCCLCHECK(xmlGetAttrIndex(node, attrName, &index));
+  if (index != -1) return ncclSuccess;
+  index = node->nAttrs++;
+  strncpy(node->attrs[index].key, attrName, MAX_STR_LEN);
+  node->attrs[index].key[MAX_STR_LEN] = '\0';
+  strncpy(node->attrs[index].value, value, MAX_STR_LEN);
+  node->attrs[index].value[MAX_STR_LEN] = '\0';
+  return ncclSuccess;
+}
+
 static ncclResult_t xmlSetAttrInt(struct ncclXmlNode* node, const char* attrName, const int value) {
   int index;
   NCCLCHECK(xmlGetAttrIndex(node, attrName, &index));
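// --- Illustration (not part of the patch) ----------------------------------
// A minimal sketch of the "set only if absent" semantics that the new
// xmlSetAttrIfUnset() above provides, written against std::map instead of
// NCCL's internal ncclXmlNode so it stays self-contained. The helper name is
// illustrative only.
#include <map>
#include <string>

static void setAttrIfUnset(std::map<std::string, std::string>& attrs,
                           const std::string& key, const std::string& value) {
  // emplace inserts only when the key is missing; an existing value is kept,
  // which is how ncclTopoFillGpu/ncclTopoFillNet default "class" to 0x03/0x02
  // without overwriting a class that was already present in the XML.
  attrs.emplace(key, value);
}
// ----------------------------------------------------------------------------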
diff --git a/src/group.cc b/src/group.cc
index 382b61e..217e76d 100644
--- a/src/group.cc
+++ b/src/group.cc
@@ -133,6 +133,7 @@ void* ncclAsyncThreadPreconnect(void* args_) {
   struct ncclAsyncArgs* args = (struct ncclAsyncArgs*)args_;
   struct ncclComm* comm = args->coll.comm;
   CUDACHECKTHREAD(cudaSetDevice(comm->cudaDev));
+  if (CPU_COUNT(&comm->cpuAffinity)) sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity);
   NCCLCHECKTHREAD(ncclTransportP2pSetup(comm, NULL, 0));
   return args;
 }
@@ -217,8 +218,6 @@ ncclResult_t ncclGroupEnd() {
       struct ncclComm* comm = args->coll.comm;
       int rank = comm->rank;
       int nRanks = comm->nRanks;
-      struct ncclP2Plist* p2pSends = comm->p2pSends;
-      struct ncclP2Plist* p2pRecvs = comm->p2pRecvs;
 
       // Compute how much to split operations
       // Natural step size matching buffer steps.
@@ -241,8 +240,8 @@ ncclResult_t ncclGroupEnd() {
 sched_delta:
       uint32_t from = (rank+nRanks-delta)%nRanks;
       uint32_t to = (rank+delta)%nRanks;
-      struct ncclP2Pinfo* recv = p2pRecvs[from].head;
-      struct ncclP2Pinfo* send = p2pSends[to].head;
+      struct ncclP2Pinfo* recv = comm->p2pRecvs[from] ? comm->p2pRecvs[from]->getNext() : NULL;
+      struct ncclP2Pinfo* send = comm->p2pSends[to] ? comm->p2pSends[to]->getNext() : NULL;
       if (recv != NULL || send != NULL) {
         ssize_t totRecvBytes = -1, totSendBytes = -1;
         if (recv != NULL) totRecvBytes = recv->nbytes;
@@ -273,15 +272,11 @@ sched_delta:
           sendOffset += sendChunkSize;
           chunk++;
         } while (sendRemaining || recvRemaining);
-        if (recv) {
-          NCCLCHECKGOTO(dequeueP2pInfo(p2pRecvs+from), ret, group_cleanup);
-          comm->p2pRecvCount--;
-        }
-        if (send) {
-          NCCLCHECKGOTO(dequeueP2pInfo(p2pSends+to), ret, group_cleanup);
-          comm->p2pSendCount--;
-        }
+        if (recv) comm->p2pRecvCount--;
+        if (send) comm->p2pSendCount--;
       }
+      if (recv == NULL && comm->p2pRecvs[from]) comm->p2pRecvs[from]->recycle();
+      if (send == NULL && comm->p2pSends[to]) comm->p2pSends[to]->recycle();
       index++;
       if (index == 1 && deltas[1] == deltas[0]) index++;
       if (index == 2 && deltas[2] == deltas[0]) index++;
@@ -381,11 +376,9 @@ group_cleanup:
       comm->asyncTotalSize = 0;
       // Dequeue p2p lists
       if (comm->p2pSendCount > 0 || comm->p2pRecvCount > 0) {
-        struct ncclP2Plist* p2pSends = comm->p2pSends;
-        struct ncclP2Plist* p2pRecvs = comm->p2pRecvs;
        for (int peer=0; peer<comm->nRanks; peer++) {
-          while (p2pSends[peer].head != NULL) dequeueP2pInfo(p2pSends+peer);
-          while (p2pRecvs[peer].head != NULL) dequeueP2pInfo(p2pRecvs+peer);
+          if (comm->p2pSends[peer]) comm->p2pSends[peer]->recycle();
+          if (comm->p2pRecvs[peer]) comm->p2pRecvs[peer]->recycle();
        }
        comm->p2pSendCount = comm->p2pRecvCount = 0;
      }
diff --git a/src/include/align.h b/src/include/align.h
index 1c9e7aa..e3780fe 100644
--- a/src/include/align.h
+++ b/src/include/align.h
@@ -16,4 +16,29 @@
 #define ALIGN_SIZE(size, align) \
   size = ((size + (align) - 1) / (align)) * (align);
 
+#if !__CUDA_ARCH__
+  #ifndef __host__
+    #define __host__
+  #endif
+  #ifndef __device__
+    #define __device__
+  #endif
+#endif
+
+template<typename X, typename Y, typename Z = decltype(X()+Y())>
+__host__ __device__ constexpr Z divUp(X x, Y y) {
+  return (x+y-1)/y;
+}
+
+template<typename X, typename Y, typename Z = decltype(X()+Y())>
+__host__ __device__ constexpr Z roundUp(X x, Y y) {
+  return (x+y-1) - (x+y-1)%y;
+}
+
+// assumes second argument is a power of 2
+template<typename X, typename Z = decltype(X()+int())>
+__host__ __device__ constexpr Z alignUp(X x, int a) {
+  return (x+a-1) & Z(-a);
+}
+
 #endif
diff --git a/src/include/alloc.h b/src/include/alloc.h
index e898d37..9488c90 100644
--- a/src/include/alloc.h
+++ b/src/include/alloc.h
@@ -13,12 +13,13 @@
 #include
 
 template <typename T>
-static ncclResult_t ncclCudaHostCalloc(T** ptr, size_t nelem) {
+static ncclResult_t ncclCudaHostCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
   CUDACHECK(cudaHostAlloc(ptr, nelem*sizeof(T), cudaHostAllocMapped));
   memset(*ptr, 0, nelem*sizeof(T));
-  INFO(NCCL_ALLOC, "Cuda Host Alloc Size %ld pointer %p", nelem*sizeof(T), *ptr);
+  INFO(NCCL_ALLOC, "%s:%d Cuda Host Alloc Size %ld pointer %p", filefunc, line, nelem*sizeof(T), *ptr);
   return ncclSuccess;
 }
+#define ncclCudaHostCalloc(...) ncclCudaHostCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
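// --- Illustration (not part of the patch) ----------------------------------
// Quick sanity check of the helpers added to align.h above; the values follow
// directly from their definitions: divUp rounds the quotient up, roundUp
// rounds a value up to the next multiple, and alignUp does the same for
// power-of-two alignments. It assumes the templates introduced above.
#include <cassert>

static void alignHelpersExample() {
  assert(divUp(10, 4) == 3);     // (10+4-1)/4
  assert(roundUp(10, 4) == 12);  // next multiple of 4
  assert(alignUp(10, 8) == 16);  // next multiple of 8 (power of two only)
}
// ----------------------------------------------------------------------------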
 
 static inline ncclResult_t ncclCudaHostFree(void* ptr) {
   CUDACHECK(cudaFreeHost(ptr));
@@ -26,7 +27,7 @@ static inline ncclResult_t ncclCudaHostFree(void* ptr) {
 }
 
 template <typename T>
-static ncclResult_t ncclCalloc(T** ptr, size_t nelem) {
+static ncclResult_t ncclCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
   void* p = malloc(nelem*sizeof(T));
   if (p == NULL) {
     WARN("Failed to malloc %ld bytes", nelem*sizeof(T));
@@ -34,12 +35,13 @@ static ncclResult_t ncclCalloc(T** ptr, size_t nelem) {
   }
   memset(p, 0, nelem*sizeof(T));
   *ptr = (T*)p;
-  INFO(NCCL_ALLOC, "Mem Alloc Size %ld pointer %p", nelem*sizeof(T), *ptr);
+  INFO(NCCL_ALLOC, "%s:%d Mem Alloc Size %ld pointer %p", filefunc, line, nelem*sizeof(T), *ptr);
   return ncclSuccess;
 }
+#define ncclCalloc(...) ncclCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
 
 template <typename T>
-static ncclResult_t ncclCudaCalloc(T** ptr, size_t nelem) {
+static ncclResult_t ncclCudaCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
   // Need async stream for P2P pre-connect + CUDA Graph
   cudaStream_t stream;
   CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
@@ -47,9 +49,10 @@ static ncclResult_t ncclCudaCalloc(T** ptr, size_t nelem) {
   CUDACHECK(cudaMemsetAsync(*ptr, 0, nelem*sizeof(T), stream));
   CUDACHECK(cudaStreamSynchronize(stream));
   CUDACHECK(cudaStreamDestroy(stream));
-  INFO(NCCL_ALLOC, "Cuda Alloc Size %ld pointer %p", nelem*sizeof(T), *ptr);
+  INFO(NCCL_ALLOC, "%s:%d Cuda Alloc Size %ld pointer %p", filefunc, line, nelem*sizeof(T), *ptr);
   return ncclSuccess;
 }
+#define ncclCudaCalloc(...) ncclCudaCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
 
 template <typename T>
 static ncclResult_t ncclCudaMemcpy(T* dst, T* src, size_t nelem) {
@@ -60,7 +63,7 @@ static ncclResult_t ncclCudaMemcpy(T* dst, T* src, size_t nelem) {
 // Allocate memory to be potentially ibv_reg_mr'd. This needs to be
 // allocated on separate pages as those pages will be marked DONTFORK
 // and if they are shared, that could cause a crash in a child process
-static ncclResult_t ncclIbMalloc(void** ptr, size_t size) {
+static ncclResult_t ncclIbMallocDebug(void** ptr, size_t size, const char *filefunc, int line) {
   size_t page_size = sysconf(_SC_PAGESIZE);
   void* p;
   int size_aligned = ROUNDUP(size, page_size);
@@ -68,8 +71,9 @@ static ncclResult_t ncclIbMalloc(void** ptr, size_t size) {
   if (ret != 0) return ncclSystemError;
   memset(p, 0, size);
   *ptr = p;
-  INFO(NCCL_ALLOC, "Ib Alloc Size %ld pointer %p", size, *ptr);
+  INFO(NCCL_ALLOC, "%s:%d Ib Alloc Size %ld pointer %p", filefunc, line, size, *ptr);
   return ncclSuccess;
 }
+#define ncclIbMalloc(...) ncclIbMallocDebug(__VA_ARGS__, __FILE__, __LINE__)
 
 #endif
diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h
index fff8e26..9c2e4f6 100644
--- a/src/include/bootstrap.h
+++ b/src/include/bootstrap.h
@@ -16,6 +16,7 @@ ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commSt
 ncclResult_t bootstrapAllGather(void* commState, void* allData, int size);
 ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size);
 ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size);
+ncclResult_t bootstrapBarrier(void* commState, int *ranks, int tag, int rank, int nranks);
 ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id, cudaIpcMemHandle_t* ipc, void** ptr);
 ncclResult_t bootstrapRemFree(int id, int rank, void* commState);
 ncclResult_t bootstrapClose(void* commState);
diff --git a/src/include/collectives.h b/src/include/collectives.h
index 9b9022e..db073f0 100644
--- a/src/include/collectives.h
+++ b/src/include/collectives.h
@@ -1,5 +1,5 @@
 /*************************************************************************
- * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
  *
  * See LICENSE.txt for license information
  ************************************************************************/
@@ -21,8 +21,8 @@
 /* Declare all collective operations */
 #define DECL5(func, algo, proto, redop, type) \
-  extern __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args); \
-  extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem c); \
+  extern __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(); \
+  extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(ncclWorkElem c); \
 
 #define DECL4(func, algo, redop, type) \
   DECL5(func, algo, SIMPLE, redop, type) \
@@ -34,6 +34,19 @@
   DECL4(func, TREE, redop, type) \
   DECL4(func, COLLNET, redop, type)
 
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+#define DECL2(func, redop) \
+  DECL3(func, redop, int8_t) \
+  DECL3(func, redop, uint8_t) \
+  DECL3(func, redop, int32_t) \
+  DECL3(func, redop, uint32_t) \
+  DECL3(func, redop, int64_t) \
+  DECL3(func, redop, uint64_t) \
+  DECL3(func, redop, half) \
+  DECL3(func, redop, float) \
+  DECL3(func, redop, double) \
+  DECL3(func, redop, __nv_bfloat16)
+#else
 #define DECL2(func, redop) \
   DECL3(func, redop, int8_t) \
   DECL3(func, redop, uint8_t) \
@@ -44,12 +57,14 @@
   DECL3(func, redop, half) \
   DECL3(func, redop, float) \
   DECL3(func, redop, double)
+#endif
 
 #define DECL(func) \
   DECL2(func, Sum) \
   DECL2(func, Prod) \
   DECL2(func, Min) \
-  DECL2(func, Max)
+  DECL2(func, Max) \
+  DECL2(func, Avg)
 
 #define DECL_ALL \
   DECL2(Broadcast, Sum) \
diff --git a/src/include/comm.h b/src/include/comm.h
index ee8ac46..214d988 100644
--- a/src/include/comm.h
+++ b/src/include/comm.h
@@ -72,6 +72,7 @@ struct ncclComm {
   int nRanks;    // number of GPUs in communicator
   int cudaDev;   // my cuda device index
   int64_t busId; // my PCI bus ID in int format
+  cpu_set_t cpuAffinity; // CPU affinity of the GPU
 
   int node;
   int nNodes;
@@ -146,11 +147,13 @@ struct ncclComm {
   struct ncclInfo* asyncOps;
   int asyncOpCount;
   size_t asyncTotalSize;
+  ssize_t channelSize;
   int lastChannel;
+  enum { ROUND_ROBIN, SHORTEST_QUEUE } asyncAllocMode;
 
   //list of async p2p operation queued in a group semantics
-  struct ncclP2Plist* p2pSends;
-  struct ncclP2Plist* p2pRecvs;
+  ncclP2Plist** p2pSends;
+  ncclP2Plist** p2pRecvs;
   int p2pSendCount;
   int p2pRecvCount;
diff --git a/src/include/core.h b/src/include/core.h
index 2283134..823a016 100644
--- a/src/include/core.h
+++ b/src/include/core.h
@@ -1,5 +1,5 @@
 /*************************************************************************
- * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
  *
  * See LICENSE.txt for license information
  ************************************************************************/
@@ -36,6 +36,9 @@ static __inline__ int ncclTypeSize(ncclDataType_t type) {
     case ncclUint8:
       return 1;
     case ncclFloat16:
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+    case ncclBfloat16:
+#endif
       return 2;
     case ncclInt32:
     case ncclUint32:
diff --git a/src/include/devcomm.h b/src/include/devcomm.h
index 9071dd1..f172f38 100644
--- a/src/include/devcomm.h
+++ b/src/include/devcomm.h
@@ -12,7 +12,7 @@
 #include
 
 #define NCCL_NUM_FUNCTIONS 5 // SendRecv not included for now
-typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv} ncclFunc_t;
+typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclNumFuncs} ncclFunc_t;
 extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS];
 
 #define NCCL_NUM_ALGORITHMS 3 // Tree/Ring/CollNet
@@ -69,10 +69,6 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK
 #define NCCL_LL128_MAX_NTHREADS 640
 #define NCCL_LL128_ELEMS_PER_THREAD 120
 
-// Receiving from up to 3 sources is more compute intensive than sending
-// to 3 dests. Use 70% for reduce and 30% for bcast.
-#define NCCL_LL128_SPLIT(nt) ((nt*7/(10*32))*32)
-
 #define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 8
 #define NCCL_LL128_SHMEM_SIZE (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS)
 
@@ -116,6 +112,8 @@ struct ncclRing {
   // devices. Ordered from current device.
   int* userRanks;
   int* devUserRanks;
+
+  int index; // This rank's index in the ring
 };
 
@@ -203,6 +201,7 @@ struct ncclChannel {
     // Operation list for aggregation
     struct ncclWork* workFifo;
     int workCount;
+    size_t totalSize;
     uint64_t workFifoTail; // Only used by CPU
     uint16_t index;        // Only used by GPU
 
@@ -228,4 +227,9 @@ struct ncclDevComm {
   struct ncclChannel* channels;
 };
 
+struct ncclDevCommAndChannels {
+  ncclDevComm comm;
+  ncclChannel channels[MAXCHANNELS];
+};
+
 #endif
diff --git a/src/include/enqueue.h b/src/include/enqueue.h
index 6081f85..4632c9b 100644
--- a/src/include/enqueue.h
+++ b/src/include/enqueue.h
@@ -11,6 +11,9 @@
 #include "group.h"
 #include "collectives.h"
 
+#define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64)
+#define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */
+
 size_t ncclKernMaxLocalSize();
 ncclResult_t ncclEnqueueCheck(struct ncclInfo* info);
 ncclResult_t ncclCpuBarrierIn(struct ncclComm* comm, int* isLast);
@@ -31,39 +34,22 @@ ncclResult_t ncclCudaGraphHostSetup(ncclComm_t comm, cudaGraph_t graph);
 struct ncclQueueElem {
   struct ncclWorkElem work;
   struct ncclProxyArgs proxyArgs;
-  struct ncclQueueElem* next;
 };
 
-// Store enqueue elements in a list
-struct ncclQueueElemList {
-  struct ncclQueueElem* head;
-  struct ncclQueueElem* tail;
-};
+typedef ncclRecyclableList<struct ncclQueueElem> ncclQueueElemList;
 
 // Structure passed to CUDA graph
 struct ncclQueueInfo {
   ncclComm_t comm;
   int maxChannels;    // Dynamic version of gridDim
   ncclResult_t ret;   // Return value of host setup call
-  struct ncclQueueElemList elemList;
+  ncclQueueElemList* elemList;
 };
 
-// Get next element from enqueue list
-static ncclResult_t ncclAddQueueElem(struct ncclQueueInfo* eqInfo, struct ncclQueueElem** elemOut) {
-  if (eqInfo == NULL) return ncclInternalError;
-  struct ncclQueueElemList* list = &eqInfo->elemList;
-  if (list->tail != NULL) {
-    *elemOut = list->tail;
-    memset(*elemOut, 0, sizeof(struct ncclWorkElem) + sizeof(struct ncclProxyArgs));
-  } else {
-    NCCLCHECK(ncclCalloc(&list->tail, 1));
-    *elemOut = list->tail;
-    list->head = list->tail;
-  }
-  if (list->tail->next == NULL) {
-    NCCLCHECK(ncclCalloc(&list->tail->next, 1));
-  }
-  list->tail = list->tail->next;
+static ncclResult_t ncclCreateQueueInfo(struct ncclQueueInfo** eqInfo, ncclComm_t comm) {
+  NCCLCHECK(ncclCalloc(eqInfo, 1));
+  (*eqInfo)->comm = comm;
+  (*eqInfo)->elemList = new ncclQueueElemList();
   return ncclSuccess;
 }
 
@@ -72,7 +58,7 @@ static ncclResult_t ncclResetQueueInfo(struct ncclQueueInfo* eqInfo) {
   if (eqInfo == NULL) return ncclInternalError;
   eqInfo->maxChannels = 0;
   eqInfo->ret = ncclSuccess;
-  eqInfo->elemList.tail = eqInfo->elemList.head;
+  eqInfo->elemList->recycle();
   return ncclSuccess;
 }
 
@@ -81,12 +67,7 @@ static ncclResult_t ncclResetQueueInfo(struct ncclQueueInfo* eqInfo) {
 static void ncclDestroyQueueInfo(void* ptr) {
   if (ptr == NULL) return;
   struct ncclQueueInfo* eqInfo = (struct ncclQueueInfo*)ptr;
-  struct ncclQueueElem* head = eqInfo->elemList.head;
-  while (head != NULL) {
-    struct ncclQueueElem* temp = head;
-    head = head->next;
-    free(temp);
-  }
+  delete eqInfo->elemList;
   free(eqInfo);
 }
 #endif // End include guard
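// --- Illustration (not part of the patch) ----------------------------------
// ncclDevCommAndChannels (added above) packs the device copy of the comm and
// its channel array into one struct, so devCommSetup() needs a single
// ncclCudaCalloc and commFree() a single cudaFree. A minimal sketch of the
// idea with plain calloc/free; member names mirror the struct above, the rest
// is illustrative.
#include <cstdlib>

struct DevComm    { int nRanks; };
struct DevChannel { int id; };
struct DevCommAndChannels {
  DevComm comm;              // comm is the first member ...
  DevChannel channels[32];   // ... followed by the channel array (MAXCHANNELS in NCCL)
};

static DevComm* allocDevComm() {
  DevCommAndChannels* both = (DevCommAndChannels*)calloc(1, sizeof(*both));
  // Because comm is the first member, the comm pointer can later be cast back
  // to DevCommAndChannels* and freed in one call, which is exactly what
  // commFree() does with cudaFree((ncclDevCommAndChannels*)comm->devComm).
  return &both->comm;
}
// ----------------------------------------------------------------------------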
diff --git a/src/include/graph.h b/src/include/graph.h
index 1429b3a..4b7a836 100644
--- a/src/include/graph.h
+++ b/src/include/graph.h
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include
 
 ncclResult_t ncclTopoCudaPath(int cudaDev, char** path);
 
@@ -33,8 +34,8 @@ ncclResult_t ncclTopoGetNetDev(struct ncclTopoSystem* system, int rank, struct n
 ncclResult_t ncclTopoCheckP2p(struct ncclTopoSystem* system, int64_t id1, int64_t id2, int* p2p, int *read, int* intermediateRank);
 ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* topo, int64_t busId, int netDev, int read, int* useGdr);
 
-// Set CPU affinity
-ncclResult_t ncclTopoSetAffinity(struct ncclTopoSystem* system, int rank);
+// Find CPU affinity
+ncclResult_t ncclTopoGetCpuAffinity(struct ncclTopoSystem* system, int rank, cpu_set_t* affinity);
 
 #define NCCL_TOPO_CPU_ARCH_X86 1
 #define NCCL_TOPO_CPU_ARCH_POWER 2
@@ -100,6 +101,6 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa
 ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph);
 
 #include "info.h"
-ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, float* time);
+ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time);
 
 #endif
diff --git a/src/include/p2p.h b/src/include/p2p.h
index 756c8d2..2519873 100644
--- a/src/include/p2p.h
+++ b/src/include/p2p.h
@@ -12,32 +12,16 @@
 struct ncclP2Pinfo {
   void* buff;
   ssize_t nbytes;
-  struct ncclP2Pinfo* next;
 };
 
-struct ncclP2Plist {
-  struct ncclP2Pinfo *head;
-  struct ncclP2Pinfo *tail;
-};
+typedef ncclRecyclableList<struct ncclP2Pinfo> ncclP2Plist;
 
-static ncclResult_t enqueueP2pInfo(ncclP2Plist* p2p, void* buff, ssize_t nBytes) {
-  if (p2p == NULL) return ncclInternalError;
+static ncclResult_t ncclSaveP2pInfo(ncclP2Plist* &p2p, void* buff, ssize_t nBytes) {
+  if (p2p == NULL) p2p = new ncclP2Plist();
   struct ncclP2Pinfo* next;
-  NCCLCHECK(ncclCalloc(&next, 1));
+  NCCLCHECK(p2p->getNewElem(&next));
   next->buff = buff;
   next->nbytes = nBytes;
-  if (p2p->tail != NULL) p2p->tail->next = next;
-  p2p->tail = next;
-  if (p2p->head == NULL) p2p->head = next;
-  return ncclSuccess;
-}
-
-static ncclResult_t dequeueP2pInfo(ncclP2Plist* p2p) {
-  if (p2p == NULL) return ncclInternalError;
-  struct ncclP2Pinfo* temp = p2p->head;
-  p2p->head = p2p->head->next;
-  if (p2p->tail == temp) p2p->tail = NULL;
-  free(temp);
   return ncclSuccess;
 }
 #endif
diff --git a/src/include/socket.h b/src/include/socket.h
index 8b59f72..6ca5f7d 100644
--- a/src/include/socket.h
+++ b/src/include/socket.h
@@ -30,12 +30,13 @@ union socketAddress {
   struct sockaddr_in6 sin6;
 };
 
-/* Format a string representation of a (struct sockaddr *) socket address using getnameinfo()
+/* Format a string representation of a (union socketAddress *) socket address using getnameinfo()
  *
  * Output: "IPv4/IPv6 address"
  */
-static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
-  if (buf == NULL || saddr == NULL) return NULL;
+static inline const char *socketToString(union socketAddress *addr, char *buf) {
+  if (buf == NULL || addr == NULL) return NULL;
+  struct sockaddr *saddr = &addr->sa;
   if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) { buf[0]='\0'; return buf; }
   char host[NI_MAXHOST], service[NI_MAXSERV];
   (void) getnameinfo(saddr, sizeof(union socketAddress), host, NI_MAXHOST, service, NI_MAXSERV, NI_NUMERICHOST|NI_NUMERICSERV);
@@ -43,8 +44,9 @@ static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
   return buf;
 }
 
-static inline uint16_t socketToPort(struct sockaddr *saddr) {
-  return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port);
+static inline uint16_t socketToPort(union socketAddress *addr) {
+  struct sockaddr *saddr = &addr->sa;
+  return ntohs(saddr->sa_family == AF_INET ? addr->sin.sin_port : addr->sin6.sin6_port);
 }
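// --- Illustration (not part of the patch) ----------------------------------
// socketToString()/socketToPort() now take the union socketAddress directly,
// so callers stop reaching into ->sa themselves. Self-contained sketch of the
// same dispatch; the union mirrors socket.h, the helper name is illustrative.
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>

union SockAddr {
  struct sockaddr     sa;
  struct sockaddr_in  sin;
  struct sockaddr_in6 sin6;
};

static unsigned short portOf(union SockAddr* addr) {
  // Pick the right union member instead of casting a bare struct sockaddr*,
  // which is what the new socketToPort() does.
  return ntohs(addr->sa.sa_family == AF_INET ? addr->sin.sin_port
                                             : addr->sin6.sin6_port);
}
// ----------------------------------------------------------------------------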
 
 /* Allow the user to force the IPv4/IPv6 interface selection */
@@ -85,7 +87,7 @@ static int findInterfaces(const char* prefixList, char* names, union socketAddre
     if (family != AF_INET && family != AF_INET6)
       continue;
 
-    TRACE(NCCL_INIT|NCCL_NET,"Found interface %s:%s", interface->ifa_name, socketToString(interface->ifa_addr, line));
+    TRACE(NCCL_INIT|NCCL_NET,"Found interface %s:%s", interface->ifa_name, socketToString((union socketAddress *)interface->ifa_addr, line));
 
     /* Allow the caller to force the socket family type */
     if (sock_family != -1 && family != sock_family)
@@ -194,13 +196,13 @@ static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAdd
     // Store the interface name
     strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
 
-    TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr->sa), line_a));
+    TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(localAddrs+found, line), socketToString(remoteAddr, line_a));
     found++;
     if (found == maxIfs) break;
   }
 
   if (found == 0) {
-    WARN("Net : No interface found in the same subnet as remote address %s", socketToString(&(remoteAddr->sa), line_a));
+    WARN("Net : No interface found in the same subnet as remote address %s", socketToString(remoteAddr, line_a));
   }
   freeifaddrs(interfaces);
   return found;
@@ -333,7 +335,7 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
     return ncclSystemError;
   }
 
-  if (socketToPort(&localAddr->sa)) {
+  if (socketToPort(localAddr)) {
     // Port is forced by env. Make sure we get the port.
     int opt = 1;
 #if defined(SO_REUSEPORT)
@@ -352,7 +354,7 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
 #ifdef ENABLE_TRACE
   char line[SOCKET_NAME_MAXLEN+1];
-  TRACE(NCCL_INIT|NCCL_NET,"Listening on socket %s", socketToString(&localAddr->sa, line));
+  TRACE(NCCL_INIT|NCCL_NET,"Listening on socket %s", socketToString(localAddr, line));
 #endif
 
   /* Put the socket in listen mode
@@ -364,10 +366,12 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
 }
 
 static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
+  char line[SOCKET_NAME_MAXLEN+1];
   /* IPv4/IPv6 support */
   int family = remoteAddr->sa.sa_family;
   if (family != AF_INET && family != AF_INET6) {
-    WARN("Error : connecting to address with family %d is neither AF_INET(%d) nor AF_INET6(%d)", family, AF_INET, AF_INET6);
+    WARN("Net : connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
+        socketToString(remoteAddr, line), family, AF_INET, AF_INET6);
     return ncclInternalError;
   }
   int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
@@ -386,8 +390,7 @@ static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
   SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_SNDBUF, (char*)&bufsize, sizeof(int)), "setsockopt");
   SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_RCVBUF, (char*)&bufsize, sizeof(int)), "setsockopt");*/
 
-  char line[SOCKET_NAME_MAXLEN+1];
-  TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", socketToString(&remoteAddr->sa, line));
+  TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", socketToString(remoteAddr, line));
 
   int ret;
   int timedout_retries = 0;
@@ -403,25 +406,26 @@ retry:
       goto retry;
     }
   }
-  WARN("Connect to %s failed : %s", socketToString(&remoteAddr->sa, line), strerror(errno));
+  WARN("Net : Connect to %s failed : %s", socketToString(remoteAddr, line), strerror(errno));
   return ncclSystemError;
 }
 
 #define NCCL_SOCKET_SEND 0
 #define NCCL_SOCKET_RECV 1
 
-static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) {
+static ncclResult_t socketProgressOpt(int op, int fd, union socketAddress *addr, void* ptr, int size, int* offset, int block) {
   int bytes = 0;
   char* data = (char*)ptr;
+  char line[SOCKET_NAME_MAXLEN+1];
   do {
     if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
     if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
     if (op == NCCL_SOCKET_RECV && bytes == 0) {
-      WARN("Net : Connection closed by remote peer");
+      WARN("Net : Connection closed by remote peer %s", socketToString(addr, line));
       return ncclSystemError;
     }
     if (bytes == -1) {
       if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
-        WARN("Call to recv failed : %s", strerror(errno));
+        WARN("Net : Call to recv from %s failed : %s", socketToString(addr, line), strerror(errno));
         return ncclSystemError;
       } else {
         bytes = 0;
@@ -432,25 +436,25 @@ static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int*
   return ncclSuccess;
 }
 
-static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
-  return socketProgressOpt(op, fd, ptr, size, offset, 0);
+static ncclResult_t socketProgress(int op, int fd, union socketAddress *addr, void* ptr, int size, int* offset) {
+  return socketProgressOpt(op, fd, addr, ptr, size, offset, 0);
 }
 
-static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) {
+static ncclResult_t socketWait(int op, int fd, union socketAddress *addr, void* ptr, int size, int* offset) {
   while (*offset < size)
-    NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1));
+    NCCLCHECK(socketProgressOpt(op, fd, addr, ptr, size, offset, 1));
   return ncclSuccess;
 }
 
-static ncclResult_t socketSend(int fd, void* ptr, int size) {
+static ncclResult_t socketSend(int fd, union socketAddress *addr, void* ptr, int size) {
   int offset = 0;
-  NCCLCHECK(socketWait(NCCL_SOCKET_SEND, fd, ptr, size, &offset));
+  NCCLCHECK(socketWait(NCCL_SOCKET_SEND, fd, addr, ptr, size, &offset));
   return ncclSuccess;
 }
 
-static ncclResult_t socketRecv(int fd, void* ptr, int size) {
+static ncclResult_t socketRecv(int fd, union socketAddress *addr, void* ptr, int size) {
   int offset = 0;
-  NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, ptr, size, &offset));
+  NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, addr, ptr, size, &offset));
   return ncclSuccess;
 }
diff --git a/src/include/utils.h b/src/include/utils.h
index 86ab3a2..739a774 100644
--- a/src/include/utils.h
+++ b/src/include/utils.h
@@ -37,4 +37,76 @@ static long log2i(long n) {
   return l;
 }
 
+// Recyclable list that avoids frequent malloc/free
+template <typename T>
+struct ncclListElem {
+  T data;
+  struct ncclListElem* next;
+};
+
+template <typename T>
+class ncclRecyclableList {
+ private:
+  struct ncclListElem<T>* head;
+  struct ncclListElem<T>* tail;
+  struct ncclListElem<T>* cursor;
+  int n;
+
+ public:
+  ncclRecyclableList() {
+    tail = cursor = head = NULL;
+    n = 0;
+  }
+
+  int count() const { return n; }
+
+  // Get a new element from the list and return pointer
+  ncclResult_t getNewElem(T** dataOut) {
+    if (tail != NULL) {
+      *dataOut = &tail->data;
+      memset(*dataOut, 0, sizeof(T));
+    } else {
+      NCCLCHECK(ncclCalloc(&tail, 1));
+      *dataOut = &tail->data;
+      cursor = head = tail;
+    }
+    if (tail->next == NULL) {
+      NCCLCHECK(ncclCalloc(&tail->next, 1));
+    }
+    tail = tail->next;
+    n += 1;
+    return ncclSuccess;
+  }
+
+  T* begin() {
+    if (head == NULL || head == tail) return NULL;
+    cursor = head->next;
+    return &head->data;
+  }
+
+  // Get next element from the list during an iteration
+  T* getNext() {
+    // tail always points to the next element to be enqueued
+    // hence does not contain valid data
+    if (cursor == NULL || cursor == tail) return NULL;
+    T* rv = &cursor->data;
+    cursor = cursor->next;
+    return rv;
+  }
+
+  // Recycle the list without freeing the space
+  void recycle() {
+    tail = cursor = head;
+    n = 0;
+  }
+
+  ~ncclRecyclableList() {
+    while (head != NULL) {
+      struct ncclListElem<T>* temp = head;
+      head = head->next;
+      free(temp);
+    }
+  }
+};
+
 #endif
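// --- Illustration (not part of the patch) ----------------------------------
// Typical lifecycle of the ncclRecyclableList<T> added above, as the p2p and
// enqueue code now uses it: getNewElem() hands out zeroed slots (allocating
// only when the pool is exhausted), begin()/getNext() iterate, and recycle()
// rewinds the list without freeing, so the next group call reuses the memory.
// Sketch only; it assumes the class above plus NCCLCHECK/ncclResult_t from
// NCCL's headers. The Item type is illustrative.
struct Item { int value; };

static ncclResult_t recyclableListExample() {
  ncclRecyclableList<Item>* list = new ncclRecyclableList<Item>();
  Item* it;
  NCCLCHECK(list->getNewElem(&it));               // zeroed slot, allocates lazily
  it->value = 42;
  for (Item* p = list->begin(); p != NULL; p = list->getNext()) {
    (void)p;                                      // iterate over queued elements
  }
  list->recycle();                                // rewind, keep the storage
  delete list;                                    // frees the pooled elements
  return ncclSuccess;
}
// ----------------------------------------------------------------------------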
diff --git a/src/init.cc b/src/init.cc
index 474218b..6fb251f 100644
--- a/src/init.cc
+++ b/src/init.cc
@@ -79,21 +79,17 @@ ncclResult_t initNetPlugin(ncclNet_t** net, ncclCollNet_t** collnet) {
     }
     return ncclSuccess;
   }
-  ncclNet_t* extNet = (ncclNet_t*) dlsym(netPluginLib, STR(NCCL_PLUGIN_SYMBOL));
-  if (extNet == NULL) {
+  *net = (ncclNet_t*) dlsym(netPluginLib, STR(NCCL_PLUGIN_SYMBOL));
+  if (*net == NULL) {
     INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_PLUGIN_SYMBOL) " symbol.");
-  } else if (initNet(extNet) == ncclSuccess) {
-    *net = extNet;
-    // Check for CollNet
-    ncclCollNet_t* extCollNet = (ncclCollNet_t*) dlsym(netPluginLib, STR(NCCL_COLLNET_PLUGIN_SYMBOL));
-    if (extCollNet == NULL) {
-      INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_COLLNET_PLUGIN_SYMBOL) " symbol.");
-    } else if (initCollNet(extCollNet) == ncclSuccess) {
-      *collnet = extCollNet;
-    }
+    if (netPluginLib != NULL) dlclose(netPluginLib);
     return ncclSuccess;
   }
-  if (netPluginLib != NULL) dlclose(netPluginLib);
+  // Check for CollNet
+  *collnet = (ncclCollNet_t*) dlsym(netPluginLib, STR(NCCL_COLLNET_PLUGIN_SYMBOL));
+  if (*collnet == NULL) {
+    INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_COLLNET_PLUGIN_SYMBOL) " symbol.");
+  }
   return ncclSuccess;
 }
 
@@ -101,13 +97,27 @@ ncclResult_t initNet() {
   // Always initialize bootstrap network
   NCCLCHECK(bootstrapNetInit());
 
-  NCCLCHECK(initNetPlugin(&ncclNet, &ncclCollNet));
-  if (ncclNet != NULL) return ncclSuccess;
-  if (initNet(&ncclNetIb) == ncclSuccess) {
-    ncclNet = &ncclNetIb;
-  } else {
-    NCCLCHECK(initNet(&ncclNetSocket));
-    ncclNet = &ncclNetSocket;
+  // Initialize main communication network
+  ncclNet_t* nets[3] = { NULL, &ncclNetIb, &ncclNetSocket };
+  ncclCollNet_t* collNets[3] = { NULL, NULL, NULL };
+  NCCLCHECK(initNetPlugin(nets+0, collNets+0));
+  char* netName = getenv("NCCL_NET");
+
+  for (int i=0; i<3; i++) {
+    if (nets[i] == NULL) continue;
+    if (netName && strcmp(netName, nets[i]->name) != 0) continue;
+    // net plugin is already initialized
+    if (initNet(nets[i]) != ncclSuccess) continue;
+    ncclNet = nets[i];
+    if (collNets[i] && initCollNet(collNets[i]) == ncclSuccess) {
+      ncclCollNet = collNets[i];
+    }
+    break;
+  }
+
+  if (ncclNet == NULL) {
+    WARN("Error: network %s not found.", netName ? netName : "");
+    return ncclInvalidUsage;
   }
   return ncclSuccess;
 }
@@ -177,6 +187,10 @@ static ncclResult_t commFree(ncclComm_t comm) {
     return ncclSuccess;
 
   free(comm->connectSend);
   free(comm->connectRecv);
+  for (int peer=0; peer<comm->nRanks; peer++) {
+    delete comm->p2pSends[peer];
+    delete comm->p2pRecvs[peer];
+  }
   free(comm->p2pSends);
   free(comm->p2pRecvs);
   free(comm->asyncOps);
@@ -187,8 +201,7 @@ static ncclResult_t commFree(ncclComm_t comm) {
 
   if (comm->bootstrap)
     NCCLCHECK(bootstrapClose(comm->bootstrap));
 
-  CUDACHECK(cudaFree(comm->hostDevComm.channels));
-  CUDACHECK(cudaFree(comm->devComm));
+  CUDACHECK(cudaFree((ncclDevCommAndChannels*)comm->devComm));
 
  for (int channel=0; channel<MAXCHANNELS; channel++)
    NCCLCHECK(freeChannel(comm->channels+channel, comm->nRanks));
@@ -224,6 +237,8 @@ static ncclResult_t commFree(ncclComm_t comm) {
   return ncclSuccess;
 }
 
+NCCL_PARAM(AggChannelSize, "AGG_CHANNEL_SIZE", -2);
+
 static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
   if (ndev < 1) {
     WARN("invalid device count (%d) requested", ndev);
@@ -271,9 +286,15 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
   NCCLCHECK(ncclCalloc(&comm->asyncOps, NCCL_MAX_OPS));
   comm->asyncOpCount = 0;
   comm->asyncTotalSize = 0;
+  comm->channelSize = ncclParamAggChannelSize();
+  comm->asyncAllocMode = ncclComm::SHORTEST_QUEUE;
+  char* str = getenv("NCCL_AGG_ALLOC_MODE");
+  if (str) INFO(NCCL_ENV, "NCCL_AGG_ALLOC_MODE set by environment to %s", str);
+  if (str && strcmp(str, "ROUND_ROBIN") == 0) {
+    comm->asyncAllocMode = ncclComm::ROUND_ROBIN;
+  }
 
-  NCCLCHECK(ncclCalloc(&comm->enqueueInfo, 1));
-  comm->enqueueInfo->comm = comm;
+  NCCLCHECK(ncclCreateQueueInfo(&comm->enqueueInfo, comm));
   comm->lastSetupNode = NULL;
   comm->lastCudaGraphId = -1;
 
@@ -296,9 +317,13 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
 }
 
 static ncclResult_t devCommSetup(ncclComm_t comm) {
+  ncclDevCommAndChannels *devCommAndChans;
+  NCCLCHECK(ncclCudaCalloc(&devCommAndChans, 1));
+  comm->devComm = &devCommAndChans->comm;
+  comm->hostDevComm.channels = devCommAndChans->channels;
+
   // Duplicate the channels on the device
   int nChannels = std::max(comm->nChannels, comm->p2pnChannels);
-  NCCLCHECK(ncclCudaCalloc(&comm->hostDevComm.channels, nChannels));
   NCCLCHECK(ncclCudaMemcpy(comm->hostDevComm.channels, comm->channels, nChannels));
 
   // Copy userRanks and peers
@@ -307,7 +332,6 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
   }
 
   // Duplicate the dev comm on the device
-  NCCLCHECK(ncclCudaCalloc(&comm->devComm, 1));
   NCCLCHECK(ncclCudaMemcpy(comm->devComm, &comm->hostDevComm, 1));
   return ncclSuccess;
 }
@@ -349,15 +373,15 @@ static ncclResult_t setupChannel(struct ncclComm* comm, int channelId, int rank,
   NCCLCHECK(initChannel(comm, channelId));
 
   struct ncclRing* ring = &comm->channels[channelId].ring;
-  // Reorganize ranks to start with rank.
-  int shift;
-  for (shift = 0; shift<nranks; shift++) {
-    if (ringRanks[shift] == rank) {
-      break;
-    }
-  }
+  // Find our ring-distance from rank zero and reorganize ranks to start with rank
+  int ixZero=0, ixRank=0;
+  for (int i=0; i<nranks; i++) {
+    if (ringRanks[i] == 0) ixZero = i;
+    if (ringRanks[i] == rank) ixRank = i;
+  }
+  ring->index = (ixRank-ixZero + nranks)%nranks;
   for (int i=0; i<nranks; i++) {
-    ring->userRanks[i] = ringRanks[(i+shift)%nranks];
+    ring->userRanks[i] = ringRanks[(i+ixRank)%nranks];
   }
   return ncclSuccess;
 }
@@ -379,7 +403,7 @@ ncclResult_t initParams(struct ncclComm* comm) {
 }
 
 // Allocate/Set Intra Process Structures and set CG options
-ncclResult_t ncclCommSetIntra(struct ncclComm* comm, int rank, int ranks, struct ncclComm* comm0) {
+ncclResult_t ncclCommSetIntraProc(struct ncclComm* comm, int rank, int ranks, struct ncclComm* comm0) {
   comm->intraRank = rank;
   comm->intraRanks = ranks;
   comm->intraPhase = 0;
@@ -500,37 +524,45 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
   }
 
   // Compute intra ranks and minimum CUDA Compute capabilities of intra-node GPUs and all GPUs
-  int intraRank0 = -1, intraRank = -1, intraRanks = 0;
+  int intraProcRank0 = -1, intraProcRank = -1, intraProcRanks = 0;
+  int intraNodeRank0 = -1, intraNodeRank = -1, intraNodeRanks = 0;
   int myCompCap = allGather1Data[rank].cudaCompCap;
   int minCompCap = myCompCap, maxCompCap = myCompCap;
-  uint64_t otherHostHash;
-  int tmpNnodes = 1;
+  int intraNodeGlobalRanks[256];
   for (int i = 0; i < nranks; i++) {
     if (allGather1Data[i].peerInfo.hostHash == allGather1Data[rank].peerInfo.hostHash) {
+      // Rank is on same node
+      if (intraNodeRanks == 0) intraNodeRank0 = i;
+      if (i == rank) intraNodeRank = intraNodeRanks;
+      intraNodeGlobalRanks[intraNodeRanks++] = i;
       if (allGather1Data[i].peerInfo.pidHash == allGather1Data[rank].peerInfo.pidHash) {
-        if (intraRanks == 0) intraRank0 = i;
-        if (i == rank) intraRank = intraRanks;
-        intraRanks++;
-      }
-    } else { // Determine whether number of nodes is 2 (for use in tree pattern determination)
-      if (tmpNnodes == 1) {
-        otherHostHash = allGather1Data[i].peerInfo.hostHash;
-        tmpNnodes = 2;
-      } else if (tmpNnodes == 2 && otherHostHash != allGather1Data[i].peerInfo.hostHash) {
-        tmpNnodes = 3;
+        // Rank is in same process
+        if (intraProcRanks == 0) intraProcRank0 = i;
+        if (i == rank) intraProcRank = intraProcRanks;
+        intraProcRanks++;
       }
     }
     minCompCap = std::min(allGather1Data[i].cudaCompCap, minCompCap);
     maxCompCap = std::max(allGather1Data[i].cudaCompCap, maxCompCap);
   }
-  TRACE(NCCL_INIT,"hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
-      rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
-  if (intraRank == -1 || intraRank0 == -1 || allGather1Data[intraRank0].comm == NULL) {
-    WARN("Failed to determine intra ranks hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
-        rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
+  TRACE(NCCL_INIT,"hostHash[%d] %lx intraNodeRank %d intraNodeRanks %d intraNodeRank0 %d",
+      rank, allGather1Data[rank].peerInfo.hostHash, intraNodeRank, intraNodeRanks, intraNodeRank0);
+  TRACE(NCCL_INIT,"pidHash[%d] %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d",
+      rank, allGather1Data[rank].peerInfo.pidHash, intraProcRank, intraProcRanks, intraProcRank0);
+  if (intraProcRank == -1 || intraProcRank0 == -1 || allGather1Data[intraProcRank0].comm == NULL) {
+    WARN("Failed to determine intra proc ranks rank %d hostHash %lx pidHash %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d",
+        rank, allGather1Data[rank].peerInfo.hostHash, allGather1Data[rank].peerInfo.pidHash,
+        intraProcRank, intraProcRanks, intraProcRank0);
     return ncclInternalError;
   }
-  struct ncclComm* intraRank0Comm = allGather1Data[intraRank0].comm;
+  if (intraNodeRank == -1 || intraNodeRank0 == -1 || intraNodeRanks == 0) {
+    WARN("Failed to determine intra node ranks rank %d hostHash %lx pidHash %lx intraNodeRank %d intraNodeRanks %d intraNodeRank0 %d",
+        rank, allGather1Data[rank].peerInfo.hostHash, allGather1Data[rank].peerInfo.pidHash,
+        intraNodeRank, intraNodeRanks, intraNodeRank0);
+    return ncclInternalError;
+  }
+  struct ncclComm* intraProcRank0Comm = allGather1Data[intraProcRank0].comm;
+  uint64_t intraNodeRank0pidHash = allGather1Data[intraNodeRank0].peerInfo.pidHash;
 
   free(allGather1Data);
 
@@ -562,7 +594,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
   struct ncclTopoGraph treeGraph;
   treeGraph.id = 1;
-  treeGraph.pattern = tmpNnodes <= 2 ? NCCL_TOPO_PATTERN_TREE : NCCL_TOPO_PATTERN_BALANCED_TREE;
+  treeGraph.pattern = NCCL_TOPO_PATTERN_BALANCED_TREE;
   treeGraph.crossNic = ncclParamCrossNic();
   treeGraph.collNet = 0;
   treeGraph.minChannels = 1;
@@ -585,8 +617,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
   }
 
   // Determine local CollNet support before all-gather
-  if (tmpNnodes > 1 && ncclParamCollNetEnable() == 1 && collNetSupport() == 1 && collNetGraph.nChannels > 0) comm->collNetSupport = 1;
-  if (intraRanks > 8) {
+  if (ncclParamCollNetEnable() == 1 && collNetSupport() == 1 && collNetGraph.nChannels > 0) comm->collNetSupport = 1;
+  if (intraNodeRanks > 8) {
     if (comm->collNetSupport == 1) WARN("CollNet currently only supports up to 8 GPUs per node");
     comm->collNetSupport = 0;
   }
@@ -719,15 +751,19 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
     struct ncclTree* tree = &comm->channels[c].tree;
     snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d/%d/%d->%d->%d",
        c, tree->down[0], tree->down[1], tree->down[2], rank, tree->up);
+    INFO(NCCL_GRAPH, "Ring %02d : %d -> %d -> %d", c, comm->channels[c].ring.prev, comm->rank, comm->channels[c].ring.next);
   }
   line[1023] = '\0';
   INFO(NCCL_INIT, "Trees%s", line);
 
   // Set Affinity to a CPU local the our GPU, so that all memory we allocate
   // on the host is local.
+  NCCLCHECK(ncclTopoGetCpuAffinity(comm->topo, comm->rank, &comm->cpuAffinity));
   cpu_set_t affinitySave;
-  sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
-  NCCLCHECK(ncclTopoSetAffinity(comm->topo, comm->rank));
+  if (CPU_COUNT(&comm->cpuAffinity)) {
+    sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
+    sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity);
+  }
   ncclResult_t ret;
 
   NCCLCHECK(computeBuffSizes(comm));
@@ -768,10 +804,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
     struct ncclChannel* channel = comm->channels+c;
    for (int h=0; hbootstrap, intraNodeGlobalRanks, (int)intraNodeRank0pidHash, intraNodeRank, intraNodeRanks));
   if (comm->nNodes) NCCLCHECK(ncclProxyCreate(comm));
 
   // We should have allocated all buffers, collective fifos, ... we can
   // restore the affinity.
affinity_restore:
-  sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave);
+  if (CPU_COUNT(&comm->cpuAffinity)) sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave);
   if (ret != ncclSuccess) return ret;
 
   TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks);
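// --- Illustration (not part of the patch) ----------------------------------
// The init path above now looks up the GPU-local CPU set once
// (ncclTopoGetCpuAffinity), applies it only when it is non-empty, and restores
// the caller's mask afterwards so host allocations land on the local NUMA
// node. Standalone sketch of that save/set/restore pattern with the plain
// sched_* calls; the allocation step is a placeholder. Compile with
// -D_GNU_SOURCE.
#include <sched.h>

static void withAffinity(cpu_set_t* preferred) {
  cpu_set_t save;
  if (CPU_COUNT(preferred)) {                        // only if a local CPU set was found
    sched_getaffinity(0, sizeof(cpu_set_t), &save);
    sched_setaffinity(0, sizeof(cpu_set_t), preferred);
  }
  /* ... allocate host memory here so it is NUMA-local to the GPU/NIC ... */
  if (CPU_COUNT(preferred)) {
    sched_setaffinity(0, sizeof(cpu_set_t), &save);  // restore the caller's mask
  }
}
// ----------------------------------------------------------------------------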
diff --git a/src/nccl.h.in b/src/nccl.h.in
index 6cb046a..a793cac 100644
--- a/src/nccl.h.in
+++ b/src/nccl.h.in
@@ -9,6 +9,9 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#if CUDART_VERSION >= 11000
+#include <cuda_bf16.h>
+#endif
 
 #define NCCL_MAJOR ${nccl:Major}
 #define NCCL_MINOR ${nccl:Minor}
@@ -103,7 +106,8 @@ typedef enum { ncclSum        = 0,
                ncclProd       = 1,
                ncclMax        = 2,
                ncclMin        = 3,
-               ncclNumOps     = 4 } ncclRedOp_t;
+               ncclAvg        = 4,
+               ncclNumOps     = 5 } ncclRedOp_t;
 
 /* Data types */
 typedef enum { ncclInt8       = 0, ncclChar       = 0,
@@ -115,7 +119,13 @@ typedef enum { ncclInt8       = 0, ncclChar       = 0,
                ncclFloat16    = 6, ncclHalf       = 6,
                ncclFloat32    = 7, ncclFloat      = 7,
                ncclFloat64    = 8, ncclDouble     = 8,
-               ncclNumTypes   = 9 } ncclDataType_t;
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+               ncclBfloat16   = 9,
+               ncclNumTypes   = 10
+#else
+               ncclNumTypes   = 9
+#endif
+} ncclDataType_t;
 
 /*
  * Collective communication operations
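// --- Illustration (not part of the patch) ----------------------------------
// User-visible effect of the nccl.h changes above: ncclAvg is a new reduction
// operator and ncclBfloat16 a new datatype (the latter only when the CUDA
// toolkit defines __CUDA_BF16_TYPES_EXIST__, i.e. CUDA 11+). A hedged sketch
// of an allreduce using them; comm, stream and the device buffers are assumed
// to be set up by the caller as usual.
#include <nccl.h>

static ncclResult_t averageExample(const void* sendbuff, void* recvbuff,
                                   size_t count, ncclComm_t comm,
                                   cudaStream_t stream) {
#if defined(__CUDA_BF16_TYPES_EXIST__)
  return ncclAllReduce(sendbuff, recvbuff, count, ncclBfloat16, ncclAvg, comm, stream);
#else
  return ncclAllReduce(sendbuff, recvbuff, count, ncclFloat16, ncclAvg, comm, stream);
#endif
}
// ----------------------------------------------------------------------------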
diff --git a/src/proxy.cc b/src/proxy.cc
index ed503ac..e5d2eab 100644
--- a/src/proxy.cc
+++ b/src/proxy.cc
@@ -41,9 +41,19 @@ static ncclResult_t allocateArgs(struct ncclComm* comm, struct ncclProxyArgs** a
     state->poolReturned = NULL;
     pthread_mutex_unlock(&state->poolMutex);
   } else {
-    // Allocate a new pool of elements
+    // Allocate a new pool of elements. Make sure we allocate the memory close
+    // to the network thread
     struct ncclProxyPool* newPool;
+    cpu_set_t affinitySave;
+    if (CPU_COUNT(&comm->cpuAffinity)) {
+      sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
+      sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity);
+    }
     NCCLCHECK(ncclCalloc(&newPool, 1));
+    if (CPU_COUNT(&comm->cpuAffinity)) {
+      sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave);
+    }
+
     struct ncclProxyArgs* newElems = newPool->elems;
     // Chain newly allocated elements
    for (int i=0; irank;
   int nranks = comm->nRanks;
   int nMasters = comm->nNodes;
   int rankInCollNet = -1;
-  int supported = 0;
   int isMaster = (rank == masterRank) ? 1 : 0;
   struct {
     int collNetRank;
@@ -148,9 +147,9 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
   // check if we can connect to collnet, whose root is the nranks-th rank
   struct ncclPeerInfo *myInfo = comm->peerInfo+rank, *peerInfo = comm->peerInfo+nranks;
   peerInfo->rank = nranks;
-  int ret = 1;
+  int support = 1;
   if (isMaster) {
-    NCCLCHECK(collNetTransport.canConnect(&ret, comm->topo, collNetGraph, myInfo, peerInfo));
+    NCCLCHECK(collNetTransport.canConnect(&support, comm->topo, collNetGraph, myInfo, peerInfo));
   }
 
   // send master receives connect info from peer recv master
@@ -168,7 +167,7 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
   conn->transportComm = transportComm;
   // setup
   struct ncclConnect myConnect;
-  if (isMaster && ret > 0) {
+  if (isMaster && support) {
     NCCLCHECK(transportComm->setup(comm, collNetGraph, myInfo, peerInfo, &myConnect, conn, collNetGraphChannelId, type));
   }
   // prepare connect handles
@@ -198,7 +197,7 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
     if (isMaster) memcpy(masterConnects+rankInCollNet, &(sendrecvExchange.connect), sizeof(struct ncclConnect));
   }
   // connect
-  if (isMaster && ret > 0) {
+  if (isMaster && support) {
     NCCLCHECKGOTO(transportComm->connect(comm, masterConnects, nMasters, rankInCollNet, conn), res, cleanup);
     struct ncclPeer* devRoot = channel->devPeers+nranks;
     struct ncclConnector* devConn = (type == collNetRecv) ? devRoot->recv+type : devRoot->send+type;
@@ -211,13 +210,11 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
     NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, masterPeer, collNetGraph->id, &sendrecvExchange, sizeof(sendrecvExchange)), res, cleanup);
     TRACE(NCCL_INIT, "CollNet [recv] : rank %d collNetRank %d collNetNranks %d sent connect to rank %d", rank, rankInCollNet, nMasters, masterPeer);
   }
-  if (ret > 0) {
-    supported = 1;
-  }
+  if (support) fail = 0;
cleanup:
   if (allConnects != NULL) free(allConnects);
   if (masterConnects != NULL) free(masterConnects);
-  return supported;
+  return fail;
 }
 
 ncclResult_t ncclTransportCollNetCheck(struct ncclComm* comm, int collNetSetupFail) {
diff --git a/src/transport/coll_net.cc b/src/transport/coll_net.cc
index 7b7ec56..49398f9 100644
--- a/src/transport/coll_net.cc
+++ b/src/transport/coll_net.cc
@@ -459,10 +459,9 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
           int buffSlot = (sub->base+sub->posted)%NCCL_STEPS;
           char* ptr;
           int sharedBuffSlot = sub->posted%NCCL_STEPS;
-          NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, p == NCCL_PROTO_SIMPLE ? resources->useGdr : 0, 1, sharedBuffSlot, 0, &ptr));
-          args->sharedBuff[sharedBuffSlot] = ptr;
-          int slotSize = sub->connector->comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
-          reqFifo[group][buffSlot].recvBuff = args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*slotSize;
+          int startChannel = group*COLLNET_GROUP_NSUBS;
+          NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, p == NCCL_PROTO_SIMPLE ? resources->useGdr : 0, 1, sharedBuffSlot, startChannel, &ptr));
+          reqFifo[group][buffSlot].recvBuff = ptr;
           TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff);
           sub->posted += args->sliceSteps;
           args->idle = 0;
@@ -478,9 +477,10 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
             TRACE(NCCL_NET, "recvProxy [%d/%d/%d] received, size %d", sub->received, group, buffSlot, totalSize);
             sub->received += args->sliceSteps;
             if (reqFifo[group][buffSlot].size > 0 && p == NCCL_PROTO_SIMPLE && resources->useGdr) {
-              int slotSize = sub->connector->comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
-              char* recvAddress = (char*)args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*slotSize;
-              NCCLCHECK(collNetIflush(resources->collNetComm, recvAddress, totalSize, mhandle, sub->requests+buffSlot));
+              int startChannel = group*COLLNET_GROUP_NSUBS;
+              char* groupRecvAddress;
+              NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, 1, 1, sharedBuffSlot, startChannel, &groupRecvAddress));
+              NCCLCHECK(collNetIflush(resources->collNetComm, groupRecvAddress, totalSize, mhandle, sub->requests+buffSlot));
             } else {
               for (int i=group*COLLNET_GROUP_NSUBS; i<=s; i++) args->subs[i].flushed += args->sliceSteps;
             }
@@ -505,8 +505,10 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
           int group = s / COLLNET_GROUP_NSUBS;
           int buffSlot = (sub->base + sub->transmitted)%NCCL_STEPS;
           int sharedBuffSlot = sub->transmitted%NCCL_STEPS;
-          int slotSize = sub->connector->comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
-          char* ptr = args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*slotSize + (s%COLLNET_GROUP_NSUBS)*args->sharedSize[sharedBuffSlot];
+          int startChannel = group*COLLNET_GROUP_NSUBS;
+          char* groupRecvAddress;
+          NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, 1, 1, sharedBuffSlot, startChannel, &groupRecvAddress));
+          char* ptr = groupRecvAddress + (s%COLLNET_GROUP_NSUBS)*args->sharedSize[sharedBuffSlot];
           if (p == NCCL_PROTO_SIMPLE) {
             volatile void** ptrsFifo = (volatile void**)resources->recvMem->ptrsFifo;
             ptrsFifo[buffSlot] = ptr;
diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc
index d867e3e..5b9f01e 100644
--- a/src/transport/net_ib.cc
+++ b/src/transport/net_ib.cc
@@ -201,7 +201,7 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) {
       }
       line[1023] = '\0';
       char addrline[SOCKET_NAME_MAXLEN+1];
-      INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s ; OOB %s:%s", line, ncclIbIfName, socketToString(&ncclIbIfAddr.sa, addrline));
+      INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s ; OOB %s:%s", line, ncclIbIfName, socketToString(&ncclIbIfAddr, addrline));
     }
     pthread_mutex_unlock(&ncclIbLock);
   }
@@ -252,10 +252,12 @@ ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) {
 
 #define MAX_REQUESTS NCCL_NET_MAX_REQUESTS
 
+#define NCCL_IB_MAX_QPS 128
+
 struct ncclIbQpInfo {
   uint32_t lid;
   uint8_t ib_port;
-  uint32_t qpn;
+  uint32_t qpn[NCCL_IB_MAX_QPS];
 
   // For RoCE
   uint64_t spn;
@@ -277,6 +279,7 @@ struct ncclIbRequest {
   struct ncclIbVerbs* verbs;
   int events;
   int size;
+  union socketAddress *addr;
 };
 
 struct ncclIbVerbs {
@@ -305,8 +308,10 @@ struct ncclIbSendComm {
   struct ncclIbSendFifo fifo[MAX_REQUESTS];
   uint32_t fifoHead;
   int fd;
+  union socketAddress addr;
   int ready;
-  struct ibv_qp* qp;
+  struct ibv_qp* qps[NCCL_IB_MAX_QPS];
+  int nqps;
   struct ibv_mr* fifoMr;
 };
 // The SendFifo needs to be 32-byte aligned and each element needs
@@ -337,16 +342,20 @@ struct ncclIbRecvComm {
   struct ncclIbVerbs verbs;
   struct ncclIbRemFifo remFifo;
   int fd;
+  union socketAddress addr;
   int ready;
-  struct ibv_qp* qp;
+  struct ibv_qp* qps[NCCL_IB_MAX_QPS];
+  int nqps;
   struct ncclIbGpuFlush gpuFlush;
 };
 static_assert((offsetof(struct ncclIbRecvComm, remFifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned");
 
+NCCL_PARAM(IbQpsPerConn, "IB_QPS_PER_CONNECTION", 1);
+
 ncclResult_t ncclIbInitVerbs(ibv_context* ctx, struct ncclIbVerbs* verbs) {
   NCCLCHECK(wrap_ibv_alloc_pd(&verbs->pd, ctx));
   // Recv requests can generate 2 completions (one for the post FIFO, one for the Recv).
-  NCCLCHECK(wrap_ibv_create_cq(&verbs->cq, ctx, 2*MAX_REQUESTS, NULL, NULL, 0));
+  NCCLCHECK(wrap_ibv_create_cq(&verbs->cq, ctx, 2*MAX_REQUESTS*ncclParamIbQpsPerConn(), NULL, NULL, 0));
   return ncclSuccess;
 }
 
@@ -379,12 +388,12 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbVerbs* verbs, int acce
   return ncclSuccess;
 }
 
-ncclResult_t ncclIbRtrQp(ibv_qp* qp, struct ncclIbQpInfo* info) {
+ncclResult_t ncclIbRtrQp(ibv_qp* qp, uint32_t qpn, struct ncclIbQpInfo* info) {
   struct ibv_qp_attr qpAttr;
   memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
   qpAttr.qp_state = IBV_QPS_RTR;
   qpAttr.path_mtu = info->mtu;
-  qpAttr.dest_qp_num = info->qpn;
+  qpAttr.dest_qp_num = qpn;
   qpAttr.rq_psn = 0;
   qpAttr.max_dest_rd_atomic = 1;
   qpAttr.min_rnr_timer = 12;
@@ -441,18 +450,23 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) {
   NCCLCHECK(connectAddress(&comm->fd, &handle->connectAddr));
   *sendComm = comm;
 
+  comm->addr = handle->connectAddr;
+
   // IB Setup
   ibv_context* ctx = ncclIbDevs[dev].context;
   NCCLCHECK(ncclIbInitVerbs(ctx, &comm->verbs));
   uint8_t ib_port = ncclIbDevs[dev].port;
-  NCCLCHECK(ncclIbCreateQp(ib_port, &comm->verbs, IBV_ACCESS_REMOTE_WRITE, &comm->qp));
+  comm->nqps = ncclParamIbQpsPerConn();
+  for (int q=0; q<comm->nqps; q++) {
+    NCCLCHECK(ncclIbCreateQp(ib_port, &comm->verbs, IBV_ACCESS_REMOTE_WRITE, comm->qps+q));
+  }
 
   // Send my QP Info to receiver through the socket. Hope this won't block.
   struct ibv_port_attr portAttr;
   NCCLCHECK(wrap_ibv_query_port(ctx, ib_port, &portAttr));
   struct ncclIbQpInfo qpInfo;
   qpInfo.ib_port = ib_port;
-  qpInfo.qpn = comm->qp->qp_num;
+  for (int q=0; q<comm->nqps; q++) qpInfo.qpn[q] = comm->qps[q]->qp_num;
   qpInfo.mtu = portAttr.active_mtu;
 
   // Prepare my fifo
@@ -463,16 +477,18 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) {
   // RoCE support
   qpInfo.lid = portAttr.lid;
   if (qpInfo.lid) { // IB
-    INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d LID %d", dev, ib_port, qpInfo.qpn, qpInfo.mtu, qpInfo.lid);
+    for (int q=0; q<comm->nqps; q++)
+      INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d LID %d", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, qpInfo.lid);
   } else { // RoCE
     union ibv_gid gid;
     NCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, ncclParamIbGidIndex(), &gid));
     qpInfo.spn = gid.global.subnet_prefix;
     qpInfo.iid = gid.global.interface_id;
-    INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d GID %ld (%lX/%lX)", dev, ib_port, qpInfo.qpn, qpInfo.mtu, ncclParamIbGidIndex(), qpInfo.spn, qpInfo.iid);
+    for (int q=0; q<comm->nqps; q++)
+      INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d GID %ld (%lX/%lX)", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, ncclParamIbGidIndex(), qpInfo.spn, qpInfo.iid);
   }
 
-  NCCLCHECK(socketSend(comm->fd, &qpInfo, sizeof(qpInfo)));
+  NCCLCHECK(socketSend(comm->fd, &comm->addr, &qpInfo, sizeof(qpInfo)));
   return ncclSuccess;
 }
 
@@ -483,11 +499,10 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {
   struct ncclIbRecvComm* rComm;
   NCCLCHECK(ncclIbMalloc((void**)&rComm, sizeof(struct ncclIbRecvComm)));
 
-  struct sockaddr_in sockaddr;
-  socklen_t socklen = sizeof(struct sockaddr_in);
-  SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd);
+  socklen_t socklen = sizeof(union socketAddress);
+  SYSCHECKVAL(accept(lComm->fd, &rComm->addr.sa, &socklen), "accept", rComm->fd);
   struct ncclIbQpInfo remQpInfo;
-  NCCLCHECK(socketRecv(rComm->fd, &remQpInfo, sizeof(remQpInfo)));
+  NCCLCHECK(socketRecv(rComm->fd, &rComm->addr, &remQpInfo, sizeof(remQpInfo)));
 
   // IB setup
   ibv_context* ctx = ncclIbDevs[lComm->dev].context;
@@ -499,15 +514,20 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {
 
   // QP Creation
   NCCLCHECK(ncclIbInitVerbs(ctx, &rComm->verbs));
-  NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_REMOTE_WRITE, &rComm->qp));
+  rComm->nqps = ncclParamIbQpsPerConn();
+  for (int q=0; q<rComm->nqps; q++) {
+    NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_REMOTE_WRITE, rComm->qps+q));
+  }
 
   // Adjust the MTU
   remQpInfo.mtu = (enum ibv_mtu)std::min(remQpInfo.mtu, portAttr.active_mtu);
 
   // Setup QP
-  struct ibv_qp* qp = rComm->qp;
-  NCCLCHECK(ncclIbRtrQp(qp, &remQpInfo));
-  NCCLCHECK(ncclIbRtsQp(qp));
+  for (int q=0; q<rComm->nqps; q++) {
+    struct ibv_qp* qp = rComm->qps[q];
+    NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo));
+    NCCLCHECK(ncclIbRtsQp(qp));
+  }
 
   // Retain remote fifo info and prepare my RDMA ops
   rComm->remFifo.rkey = remQpInfo.fifoRkey;
@@ -525,29 +545,26 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {
     rComm->gpuFlush.sge.length = 1;
     rComm->gpuFlush.sge.lkey = rComm->gpuFlush.hostMr->lkey;
     NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ, &rComm->gpuFlush.qp));
-    struct ncclIbQpInfo localQpInfo = {
-      .lid=portAttr.lid,
-      .ib_port=ib_port,
-      .qpn=rComm->gpuFlush.qp->qp_num,
-      .spn=gid.global.subnet_prefix,
-      .iid=gid.global.interface_id,
-      .mtu=portAttr.active_mtu
-    };
-    NCCLCHECK(ncclIbRtrQp(rComm->gpuFlush.qp, &localQpInfo));
+    struct ncclIbQpInfo localQpInfo;
+    localQpInfo.lid=portAttr.lid;
+    localQpInfo.ib_port=ib_port;
+    localQpInfo.spn=gid.global.subnet_prefix;
+    localQpInfo.iid=gid.global.interface_id;
+    localQpInfo.mtu=portAttr.active_mtu;
+    NCCLCHECK(ncclIbRtrQp(rComm->gpuFlush.qp, rComm->gpuFlush.qp->qp_num, &localQpInfo));
     NCCLCHECK(ncclIbRtsQp(rComm->gpuFlush.qp));
   }
 
   // Fill Handle
-  struct ncclIbQpInfo qpInfo = {
-    .lid=portAttr.lid,
-    .ib_port=ib_port,
-    .qpn=qp->qp_num,
-    .spn=gid.global.subnet_prefix,
-    .iid=gid.global.interface_id,
-    .mtu=remQpInfo.mtu
-  };
+  struct ncclIbQpInfo qpInfo;
+  qpInfo.lid=portAttr.lid;
+  qpInfo.ib_port=ib_port;
+  for (int q=0; q<rComm->nqps; q++) qpInfo.qpn[q]=rComm->qps[q]->qp_num;
+  qpInfo.spn=gid.global.subnet_prefix;
+  qpInfo.iid=gid.global.interface_id;
+  qpInfo.mtu=remQpInfo.mtu;
 
-  NCCLCHECK(socketSend(rComm->fd, &qpInfo, sizeof(qpInfo)));
+  NCCLCHECK(socketSend(rComm->fd, &rComm->addr, &qpInfo, sizeof(qpInfo)));
   *recvComm = rComm;
   return ncclSuccess;
 }
@@ -561,6 +578,7 @@ ncclResult_t ncclIbGetRequest(struct ncclIbVerbs* verbs, struct ncclIbRequest**
       r->verbs = verbs;
       r->events = 1;
       r->size = -1;
+      r->addr = NULL;
       *req = r;
       return ncclSuccess;
     }
@@ -576,19 +594,21 @@ ncclResult_t ncclIbFreeRequest(struct ncclIbRequest* r) {
 
 ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) {
   struct ncclIbQpInfo remQpInfo;
-  struct ibv_qp* qp = comm->qp;
 
   // Do not block on this receive, return if not ready.
   int bytes = 0;
-  NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &remQpInfo, sizeof(remQpInfo), &bytes));
+  NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &remQpInfo, sizeof(remQpInfo), &bytes));
   if (bytes == 0) return ncclSuccess; // Try again later
-  NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &remQpInfo, sizeof(remQpInfo), &bytes));
+  NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &remQpInfo, sizeof(remQpInfo), &bytes));
 
-  NCCLCHECK(ncclIbRtrQp(qp, &remQpInfo));
-  NCCLCHECK(ncclIbRtsQp(qp));
+  for (int q=0; q<comm->nqps; q++) {
+    struct ibv_qp* qp = comm->qps[q];
+    NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo));
+    NCCLCHECK(ncclIbRtsQp(qp));
+  }
   comm->ready = 1;
 
   // Block until this is done. It *should* not block indefinitely.
-  NCCLCHECK(socketSend(comm->fd, &comm->ready, sizeof(int)));
+  NCCLCHECK(socketSend(comm->fd, &comm->addr, &comm->ready, sizeof(int)));
 
   return ncclSuccess;
 }
@@ -596,9 +616,9 @@ ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) {
 
 ncclResult_t ncclRecvCheck(struct ncclIbRecvComm* comm) {
   // Do not block on this receive, return if not ready.
   int bytes = 0;
-  NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &comm->ready, sizeof(int), &bytes));
+  NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &comm->ready, sizeof(int), &bytes));
   if (bytes == 0) return ncclSuccess; // Try again later
-  NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &comm->ready, sizeof(int), &bytes));
+  NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &comm->ready, sizeof(int), &bytes));
   return ncclSuccess;
 }
 
@@ -643,20 +663,15 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
   struct ncclIbRequest* req;
   NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
   req->size = size;
+  req->addr = &comm->addr;
 
   struct ibv_send_wr wr[2];
   memset(&wr[0], 0, sizeof(wr[0]));
   wr[0].wr_id = (uint64_t)req;
 
   struct ibv_sge sge;
-  if (size == 0) {
-    wr[0].sg_list = NULL;
-    wr[0].num_sge = 0;
-  } else {
-    sge.addr=(uintptr_t)data; sge.length=(unsigned int)size; sge.lkey=mr->lkey;
-    wr[0].sg_list = &sge;
-    wr[0].num_sge = 1;
-  }
+  sge.addr=(uintptr_t)data; sge.lkey=mr->lkey;
+
 #if USE_RDMA_WRITE == 0
   wr[0].opcode = IBV_WR_SEND;
   wr[0].send_flags = IBV_SEND_SIGNALED;
@@ -665,8 +680,9 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
   // Sanity checks to catch user collective call count/size mismatches
   // plus any potential programming errors
   if (size > slot->size || slot->size < 0 || slot->addr == 0 || slot->rkey == 0 || slot->seq != comm->fifoHead) {
-    WARN("NET/IB : collective mismatch error local size %d remote %d addr %lx rkey %x seq %x/%x",
-      size, slot->size, slot->addr, slot->rkey, slot->seq, comm->fifoHead);
+    char line[SOCKET_NAME_MAXLEN+1];
+    WARN("NET/IB : peer %s collective mismatch error local size %d remote %d addr %lx rkey %x seq %x/%x",
+      socketToString(req->addr, line), size, slot->size, slot->addr, slot->rkey, slot->seq, comm->fifoHead);
     return ncclInternalError;
   }
   wr[0].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
@@ -703,8 +719,26 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
   }
 #endif
 
-  struct ibv_send_wr* bad_wr;
-  NCCLCHECK(wrap_ibv_post_send(comm->qp, wr, &bad_wr));
+  int chunkSize = std::max(8, DIVUP(size, comm->nqps));
+
+  int offset = 0;
+  for (int q=0; q<comm->nqps; q++) {
+    int length = std::min(size-offset, chunkSize);
+    if (length <= 0) {
+      wr[0].sg_list = NULL;
+      wr[0].num_sge = 0;
+    } else {
+      sge.length = length;
+      wr[0].sg_list = &sge;
+      wr[0].num_sge = 1;
+    }
+    struct ibv_send_wr* bad_wr;
+    NCCLCHECK(wrap_ibv_post_send(comm->qps[q], wr, &bad_wr));
+    offset += chunkSize;
+    sge.addr += chunkSize;
+    wr[0].wr.rdma.remote_addr += chunkSize;
+  }
+  req->events = comm->nqps;
 
   *request = req;
   return ncclSuccess;
@@ -757,7 +791,7 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, uint32_t rkey, uint64_t
   }
 
   struct ibv_send_wr* bad_wr;
-  NCCLCHECK(wrap_ibv_post_send(comm->qp, &wr, &bad_wr));
+  NCCLCHECK(wrap_ibv_post_send(comm->qps[0], &wr, &bad_wr));
   comm->remFifo.tail++;
 
   return ncclSuccess;
@@ -773,23 +807,22 @@ ncclResult_t ncclIbIrecv(void* recvComm, void* data, int size, void* mhandle, vo
   struct ncclIbRequest* req;
   NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
   req->size = size;
+  req->addr = &comm->addr;
 
   struct ibv_recv_wr wr;
   memset(&wr, 0, sizeof(wr));
   wr.wr_id = (uint64_t)req;
 
-  struct ibv_sge sge;
-  if (size == 0) {
-    wr.sg_list = NULL;
-    wr.num_sge = 0;
-  } else {
-    sge.addr=(uintptr_t)data; sge.length=(unsigned int)size; sge.lkey=mr->lkey;
-    wr.sg_list = &sge;
-    wr.num_sge = 1;
-  }
+  wr.sg_list = NULL;
+  wr.num_sge = 0;
+
+  for (int q=0; q<comm->nqps; q++) {
+    struct ibv_qp* qp = comm->qps[q];
+    struct ibv_recv_wr* bad_wr;
+    NCCLCHECK(wrap_ibv_post_recv(qp, &wr, &bad_wr));
+  }
+  req->events = comm->nqps;
 
-  struct ibv_recv_wr* bad_wr;
-  NCCLCHECK(wrap_ibv_post_recv(comm->qp, &wr, &bad_wr));
   *request = req;
 
   // Post to FIFO to notify sender
@@ -803,6 +836,7 @@ ncclResult_t ncclIbIflush(void* recvComm, void* data, int size, void* mhandle, v
 
   struct ncclIbRequest* req;
   NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
+  req->addr = &comm->addr;
   struct ibv_mr* mr = (struct ibv_mr*)mhandle;
 
   struct ibv_send_wr wr;
@@ -843,7 +877,9 @@ ncclResult_t ncclIbTest(void* request, int* done, int* size) {
     for (int w=0; w<wrDone; w++) {
       struct ibv_wc *wc = wcs+w;
      if (wc->status != IBV_WC_SUCCESS) {
-        WARN("NET/IB : Got completion with error %d, opcode %d, len %d, vendor err %d", wc->status, wc->opcode, wc->byte_len, wc->vendor_err);
+        char line[SOCKET_NAME_MAXLEN+1];
+        WARN("NET/IB : Got completion from peer %s with error %d, opcode %d, len %d, vendor err %d",
+            socketToString(r->addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err);
         return ncclSystemError;
       }
 
@@ -853,7 +889,10 @@ ncclResult_t ncclIbTest(void* request, int* done, int* size) {
          doneReq->size = wc->byte_len;
 #if USE_RDMA_WRITE
        } else if (wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
-          doneReq->size = wc->imm_data;
+          if (doneReq->size == -1)
+            doneReq->size = wc->imm_data;
+          else
+            doneReq->size += wc->imm_data;
 #endif
        }
        doneReq->events--;
@@ -866,7 +905,8 @@ ncclResult_t ncclIbCloseSend(void* sendComm) {
   struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm;
   if (comm) {
     close(comm->fd);
-    if (comm->qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qp));
+    for (int q=0; q<comm->nqps; q++)
+      if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q]));
     if (comm->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->fifoMr));
     NCCLCHECK(ncclIbDestroyVerbs(&comm->verbs));
     free(comm);
@@ -878,7 +918,8 @@ ncclResult_t ncclIbCloseRecv(void* recvComm) {
   struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm;
   if (comm) {
     close(comm->fd);
-    if (comm->qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qp));
+    for (int q=0; q<comm->nqps; q++)
+      if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q]));
     if (comm->gpuFlush.enabled) {
       if (comm->gpuFlush.qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->gpuFlush.qp));
      if (comm->gpuFlush.hostMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->gpuFlush.hostMr));
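
Note on the net_ib.cc hunks above: the single comm->qp becomes an array comm->qps[] whose length is read through ncclParamIbQpsPerConn(), a send is striped across the queue pairs, and a request now completes only after comm->nqps work completions (the receiver accumulates the per-QP imm_data sizes). The standalone sketch below reproduces only the striping arithmetic; the values chosen for nqps and size are made up for illustration, and none of this code is NCCL source.

// Illustrative sketch (not NCCL source): how one message is divided across
// several queue pairs, mirroring the chunkSize/offset loop in the patch.
#include <algorithm>
#include <cstdio>

#define DIVUP(x, y) (((x) + (y) - 1) / (y))

int main() {
  const int nqps = 4;        // hypothetical queue pairs per connection
  const int size = 1000000;  // hypothetical message size in bytes

  // Same math as the patched send path: at least 8 bytes per chunk, so
  // trailing queue pairs may post an empty work request.
  int chunkSize = std::max(8, DIVUP(size, nqps));

  int offset = 0;
  for (int q = 0; q < nqps; q++) {
    int length = std::min(size - offset, chunkSize);
    if (length <= 0)
      printf("QP %d: empty work request (keeps the completion count at nqps)\n", q);
    else
      printf("QP %d: bytes [%d, %d)\n", q, offset, offset + length);
    offset += chunkSize;
  }
  return 0;
}

Because every queue pair posts exactly one work request, the request's events counter can simply be set to nqps and decremented once per completion, regardless of how many chunks carried data.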
diff --git a/src/transport/net_socket.cc b/src/transport/net_socket.cc
index 79d991d..c045a8f 100644
--- a/src/transport/net_socket.cc
+++ b/src/transport/net_socket.cc
@@ -56,7 +56,7 @@ ncclResult_t ncclSocketInit(ncclDebugLogger_t logFunction) {
       memcpy(&ncclSocketDevs[i].addr, addrs+i, sizeof(union socketAddress));
       NCCLCHECK(ncclSocketGetPciPath(ncclSocketDevs[i].devName, &ncclSocketDevs[i].pciPath));
       snprintf(line+strlen(line), MAX_LINE_LEN-strlen(line), " [%d]%s:%s", i, names+i*MAX_IF_NAME_SIZE,
-          socketToString(&addrs[i].sa, addrline));
+          socketToString(&addrs[i], addrline));
     }
     line[MAX_LINE_LEN] = '\0';
     INFO(NCCL_INIT|NCCL_NET,"NET/Socket : Using%s", line);
@@ -129,6 +129,7 @@ struct ncclSocketTask {
   void* data;
   int size;
   int fd;
+  union socketAddress *addr;
   int offset;
   int used;
   ncclResult_t result;
@@ -139,6 +140,7 @@ struct ncclSocketRequest {
   void* data;
   int size;
   int ctrlFd;
+  union socketAddress *addr;
   int offset;
   int used;
   struct ncclSocketComm* comm;
@@ -170,6 +172,7 @@ struct ncclSocketListenComm {
 
 struct ncclSocketComm {
   int ctrlFd;
+  union socketAddress addr;
   int fds[MAX_SOCKETS];
   int nSocks;
   int nThreads;
@@ -195,7 +198,7 @@ void* persistentSocketThread(void *args_) {
       for (int j=0; j<nSocksPerThread; j++) {
         struct ncclSocketTask* r = myQueue->tasks+i+j;
         if (r != NULL && r->used == 1 && r->offset < r->size) {
-          r->result = socketProgress(r->op, r->fd, r->data, r->size, &r->offset);
+          r->result = socketProgress(r->op, r->fd, r->addr, r->data, r->size, &r->offset);
           if (r->result != ncclSuccess) {
             WARN("NET/Socket : socket progress error");
             return NULL;
@@ -311,11 +314,12 @@ ncclResult_t ncclSocketConnect(int dev, void* opaqueHandle, void** sendComm) {
   for (int i=0; i<comm->nSocks+1; i++) {
     int tmpFd, offset=0;
     NCCLCHECK(connectAddress(&tmpFd, &handle->connectAddr));
-    NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &i, sizeof(int), &offset));
+    NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &handle->connectAddr, &i, sizeof(int), &offset));
     if (i == comm->nSocks) comm->ctrlFd = tmpFd;
     else comm->fds[i] = tmpFd;
   }
   *sendComm = comm;
+  comm->addr = handle->connectAddr;
   return ncclSuccess;
 }
 
@@ -327,10 +331,9 @@ ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) {
   rComm->nThreads = lComm->nThreads;
   for (int i=0; i<rComm->nSocks+1; i++) {
     int tmpFd, sendSockIdx, offset=0;
-    struct sockaddr_in sockaddr;
-    socklen_t socklen = sizeof(struct sockaddr_in);
-    SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", tmpFd);
-    NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &sendSockIdx, sizeof(int), &offset));
+    socklen_t socklen = sizeof(union socketAddress);
+    SYSCHECKVAL(accept(lComm->fd, &rComm->addr.sa, &socklen), "accept", tmpFd);
+    NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &rComm->addr, &sendSockIdx, sizeof(int), &offset));
     if (sendSockIdx == rComm->nSocks) rComm->ctrlFd = tmpFd;
     else rComm->fds[sendSockIdx] = tmpFd;
   }
@@ -346,6 +349,7 @@ ncclResult_t ncclSocketGetRequest(struct ncclSocketComm* comm, int op, void* dat
       r->data = data;
       r->size = size;
       r->ctrlFd = comm->ctrlFd;
+      r->addr = &comm->addr;
       r->used = 1;
       r->comm = comm;
       r->nSubs = 0;
@@ -380,6 +384,7 @@ ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data,
       r->data = data;
       r->size = size;
       r->fd = comm->fds[comm->nextFd];
+      r->addr = &comm->addr;
       r->offset = 0;
       r->result = ncclSuccess;
       comm->nextFd = (comm->nextFd + 1) % comm->nSocks;
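
Note on the net_socket.cc hunks above: every ncclSocketComm now records the peer's union socketAddress (copied from the connect handle, or captured at accept() time), each task and request keeps a pointer to it, and socketProgress()/socketWait() take it as an extra argument so that failures can name the offending peer. The sketch below shows the general idea of turning such a stored address into printable text for a warning; the names peerAddress and peerToString are assumptions for this example, not NCCL identifiers (the patch uses NCCL's own socketToString()).

// Illustrative sketch (not NCCL source): render a peer address captured at
// accept() time into a string suitable for a WARN-style log line.
#include <arpa/inet.h>
#include <netinet/in.h>
#include <cstddef>
#include <cstdio>
#include <cstring>

// Stand-in for the idea behind NCCL's `union socketAddress` (name assumed).
union peerAddress {
  struct sockaddr sa;
  struct sockaddr_in sin;
  struct sockaddr_in6 sin6;
};

// Formats "ip<port>"-style text into buf and returns buf, so it can be used
// directly inside a log call, like socketToString() in the patch.
static const char* peerToString(const union peerAddress* addr, char* buf, size_t len) {
  char host[INET6_ADDRSTRLEN] = "?";
  int port = 0;
  if (addr->sa.sa_family == AF_INET) {
    inet_ntop(AF_INET, &addr->sin.sin_addr, host, sizeof(host));
    port = ntohs(addr->sin.sin_port);
  } else if (addr->sa.sa_family == AF_INET6) {
    inet_ntop(AF_INET6, &addr->sin6.sin6_addr, host, sizeof(host));
    port = ntohs(addr->sin6.sin6_port);
  }
  snprintf(buf, len, "%s<%d>", host, port);
  return buf;
}

int main() {
  union peerAddress peer;
  memset(&peer, 0, sizeof(peer));
  peer.sin.sin_family = AF_INET;
  peer.sin.sin_port = htons(50000);      // hypothetical peer port
  inet_pton(AF_INET, "10.0.0.7", &peer.sin.sin_addr);  // hypothetical peer IP

  char line[64];
  // Analogous in spirit to the new
  // WARN("NET/Socket : peer %s message truncated ...", socketToString(r->addr, line), ...)
  printf("NET/Socket : peer %s message truncated\n", peerToString(&peer, line, sizeof(line)));
  return 0;
}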
@@ -406,16 +411,17 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
   if (r->used == 1) { /* try to send/recv size */
     int data = r->size;
     int offset = 0;
-    NCCLCHECK(socketProgress(r->op, r->ctrlFd, &data, sizeof(int), &offset));
+    NCCLCHECK(socketProgress(r->op, r->ctrlFd, r->addr, &data, sizeof(int), &offset));
 
     if (offset == 0) return ncclSuccess; /* Not ready -- retry later */
 
     // Not sure we could ever receive less than 4 bytes, but just in case ...
-    if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->ctrlFd, &data, sizeof(int), &offset));
+    if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->ctrlFd, r->addr, &data, sizeof(int), &offset));
 
     // Check size is less or equal to the size provided by the user
     if (r->op == NCCL_SOCKET_RECV && data > r->size) {
-      WARN("NET/Socket : message truncated : receiving %d bytes instead of %d", data, r->size);
+      char line[SOCKET_NAME_MAXLEN+1];
+      WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d", socketToString(r->addr, line), data, r->size);
       return ncclInternalError;
     }
     r->size = data;
@@ -453,7 +459,7 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
       }
     } else { // progress request using main thread
       if (r->offset < r->size) {
-        NCCLCHECK(socketProgress(r->op, r->ctrlFd, r->data, r->size, &r->offset));
+        NCCLCHECK(socketProgress(r->op, r->ctrlFd, r->addr, r->data, r->size, &r->offset));
       }
       if (r->offset == r->size) {
         if (size) *size = r->size;
diff --git a/src/transport/p2p.cc b/src/transport/p2p.cc
index 75fff87..7764258 100644
--- a/src/transport/p2p.cc
+++ b/src/transport/p2p.cc
@@ -53,8 +53,8 @@ static int busIdToCudaDev(int64_t busId) {
 
 /* Determine if two peers can communicate through p2p */
 ncclResult_t p2pCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) {
-  // Rule out different nodes
-  if (info1->hostHash != info2->hostHash) {
+  // Rule out different nodes / isolated containers
+  if (info1->hostHash != info2->hostHash || info1->shmDev != info2->shmDev) {
     *ret = 0;
     return ncclSuccess;
   }
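
Note on the p2p.cc change above: two processes can share a hostHash yet live in containers with isolated /dev/shm mounts, in which case shared-memory and P2P setup would fail, so p2pCanConnect() now also requires matching shmDev values. The sketch below shows one way such a value can be derived locally with stat(2); how NCCL actually fills ncclPeerInfo::shmDev is not part of this patch, so treat that detail as an assumption.

// Illustrative sketch (not NCCL source): capture the device ID of /dev/shm.
// Two containers with isolated /dev/shm mounts report different values, so
// comparing them is a cheap way to rule out same-host peers that cannot
// actually share memory.
#include <sys/stat.h>
#include <cstdint>
#include <cstdio>

int main() {
  struct stat statbuf;
  if (stat("/dev/shm", &statbuf) != 0) {
    perror("stat(/dev/shm)");
    return 1;
  }
  // Conceptually the value a peer would exchange (assumed to correspond to
  // ncclPeerInfo::shmDev in the real code).
  uint64_t shmDev = (uint64_t)statbuf.st_dev;
  printf("local /dev/shm device id: 0x%lx\n", (unsigned long)shmDev);
  printf("peers with a matching hostHash but a different shmDev are ruled out for P2P\n");
  return 0;
}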