From 5d3ab08b69754cb863ab1bf6e3fb6f7612bb725b Mon Sep 17 00:00:00 2001 From: Sylvain Jeaugey Date: Mon, 27 Feb 2023 02:48:21 -0800 Subject: [PATCH] 2.17.1-1 Add new NVLS algorithm for allreduce using NVLink SHARP (intra-node only). Add new config options: cgaClusterSize, minCTAs, maxCTAs, netName. Enable LL128 when we use PXN to close rings. NVTX3 includes update. Fix crash when one CollNet (SHARP) rail fails to initialize. --- makefiles/version.mk | 4 +- src/Makefile | 5 +- src/bootstrap.cc | 18 + src/channel.cc | 7 +- src/collectives/device/all_gather.h | 42 + src/collectives/device/all_reduce.h | 65 +- src/collectives/device/common.h | 38 +- src/collectives/device/common_kernel.h | 963 +++------ src/collectives/device/functions.cu | 6 +- src/collectives/device/onerank_reduce.cu | 8 +- src/collectives/device/op128.h | 286 +++ src/collectives/device/primitives.h | 4 +- src/collectives/device/prims_ll.h | 8 +- src/collectives/device/prims_ll128.h | 49 +- src/collectives/device/prims_simple.h | 222 +- src/collectives/device/reduce_kernel.h | 1114 +++++----- src/collectives/device/reduce_scatter.h | 42 + src/collectives/device/sendrecv.h | 5 +- src/debug.cc | 2 + src/enqueue.cc | 131 +- src/graph/connect.cc | 4 +- src/graph/paths.cc | 4 +- src/graph/search.cc | 1 - src/graph/topo.cc | 2 +- src/graph/tuning.cc | 59 +- src/group.cc | 3 +- src/include/bootstrap.h | 1 + src/include/collectives.h | 12 +- src/include/comm.h | 13 + src/include/cudawrap.h | 24 +- src/include/devcomm.h | 83 +- src/include/enqueue.h | 3 +- src/include/info.h | 1 + src/include/ipcsocket.h | 37 + src/include/nccl_net.h | 2 +- src/include/nvtx.h | 16 +- src/include/nvtx3/nvToolsExt.h | 2 +- src/include/nvtx3/nvToolsExtCuda.h | 2 +- src/include/nvtx3/nvToolsExtCudaRt.h | 2 +- src/include/nvtx3/nvToolsExtOpenCL.h | 2 +- src/include/nvtx3/nvToolsExtPayload.h | 4 +- src/include/nvtx3/nvToolsExtSync.h | 2 +- src/include/{ => nvtx3}/nvtx3.hpp | 1846 +++++++++++------ src/include/nvtx3/nvtxDetail/nvtxImpl.h | 2 +- src/include/nvtx3/nvtxDetail/nvtxImplCore.h | 2 +- .../nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h | 2 +- .../nvtx3/nvtxDetail/nvtxImplCuda_v3.h | 2 +- .../nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h | 2 +- .../nvtx3/nvtxDetail/nvtxImplSync_v3.h | 2 +- src/include/nvtx3/nvtxDetail/nvtxInit.h | 2 +- src/include/nvtx3/nvtxDetail/nvtxInitDecls.h | 2 +- src/include/nvtx3/nvtxDetail/nvtxInitDefs.h | 2 +- src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h | 2 +- src/include/nvtx3/nvtxDetail/nvtxTypes.h | 2 +- .../nvtxExtDetail/nvtxExtImplPayload_v1.h | 3 +- src/include/proxy.h | 42 +- src/include/socket.h | 2 +- src/include/transport.h | 5 +- src/init.cc | 235 ++- src/misc/cudawrap.cc | 55 +- src/misc/ipcsocket.cc | 200 ++ src/misc/socket.cc | 23 +- src/nccl.h.in | 17 +- src/net.cc | 44 +- src/proxy.cc | 435 +++- src/transport.cc | 119 +- src/transport/coll_net.cc | 22 +- src/transport/net.cc | 80 +- src/transport/net_ib.cc | 91 +- src/transport/nvls.cc | 373 ++++ src/transport/p2p.cc | 16 +- src/transport/shm.cc | 4 +- 72 files changed, 4541 insertions(+), 2391 deletions(-) create mode 100644 src/include/ipcsocket.h rename src/include/{ => nvtx3}/nvtx3.hpp (51%) create mode 100644 src/misc/ipcsocket.cc create mode 100644 src/transport/nvls.cc diff --git a/makefiles/version.mk b/makefiles/version.mk index e8e7b7a..6877b63 100644 --- a/makefiles/version.mk +++ b/makefiles/version.mk @@ -1,6 +1,6 @@ ##### version NCCL_MAJOR := 2 -NCCL_MINOR := 16 -NCCL_PATCH := 5 +NCCL_MINOR := 17 +NCCL_PATCH := 1 NCCL_SUFFIX := PKG_REVISION := 1 diff --git 
a/src/Makefile b/src/Makefile index 4753018..ca5ddce 100644 --- a/src/Makefile +++ b/src/Makefile @@ -12,7 +12,8 @@ INCEXPORTS := nccl.h nccl_net.h LIBSRCFILES := init.cc init_nvtx.cc channel.cc bootstrap.cc transport.cc enqueue.cc group.cc debug.cc proxy.cc net.cc \ misc/cudawrap.cc misc/nvmlwrap.cc misc/ibvwrap.cc misc/gdrwrap.cc \ misc/utils.cc misc/argcheck.cc misc/socket.cc misc/shmutils.cc misc/profiler.cc misc/param.cc misc/strongstream.cc \ - transport/p2p.cc transport/shm.cc transport/net.cc transport/net_socket.cc transport/net_ib.cc transport/coll_net.cc \ + misc/ipcsocket.cc \ + transport/p2p.cc transport/shm.cc transport/net.cc transport/net_socket.cc transport/net_ib.cc transport/coll_net.cc transport/nvls.cc \ collectives/sendrecv.cc collectives/all_reduce.cc collectives/all_gather.cc collectives/broadcast.cc collectives/reduce.cc collectives/reduce_scatter.cc \ graph/topo.cc graph/paths.cc graph/search.cc graph/connect.cc graph/rings.cc graph/trees.cc graph/tuning.cc graph/xml.cc @@ -62,7 +63,7 @@ ALWAYS_REBUILD: -include $(DEPFILES) $(LIBDIR)/$(LIBTARGET) $(LIBDIR)/$(STATICLIBTARGET) : $(LIBOBJ) -$(INCDIR)/nccl.h : nccl.h.in +$(INCDIR)/nccl.h : nccl.h.in ../makefiles/version.mk # NCCL_VERSION(X,Y,Z) ((X) * 10000 + (Y) * 100 + (Z)) @$(eval NCCL_VERSION := $(shell printf "%d%02d%02d" $(NCCL_MAJOR) $(NCCL_MINOR) $(NCCL_PATCH))) mkdir -p $(INCDIR) diff --git a/src/bootstrap.cc b/src/bootstrap.cc index c348b3e..a3a4df6 100644 --- a/src/bootstrap.cc +++ b/src/bootstrap.cc @@ -386,6 +386,24 @@ ncclResult_t bootstrapIntraNodeAllGather(void* commState, int *ranks, int rank, return ncclSuccess; } +// IntraNode in-place Broadcast +ncclResult_t bootstrapIntraNodeBroadcast(void* commState, int *ranks, int rank, int nranks, int root, void* bcastData, int size) { + if (nranks == 1) return ncclSuccess; + TRACE(NCCL_INIT, "rank %d nranks %d root %d size %d - ENTER", rank, nranks, root, size); + + if (rank == root) { + for (int i=0; iid != -1) return ncclSuccess; int nRanks = comm->nRanks; + int nPeers = nRanks + 1 /* Collnet */ + comm->localRanks /* NVLS */; channel->id = channelId; channel->workFifoSent = 0; NCCLCHECK(ncclStrongStreamAcquireUncaptured(&comm->deviceStream)); // The extra on nRanks+1 is for collnet root (i.e. 
network) - channel->peers = ncclMemoryStackAlloc(&comm->memPermanent, nRanks+1); - NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nRanks+1, comm->deviceStream.cudaStream)); + channel->peers = ncclMemoryStackAlloc(&comm->memPermanent, nPeers); + NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nPeers, comm->deviceStream.cudaStream)); ncclCommPushCudaFree(comm, channel->devPeers); channel->ring.userRanks = ncclMemoryStackAlloc(&comm->memPermanent, nRanks); @@ -29,7 +30,7 @@ ncclResult_t initChannel(struct ncclComm* comm, int channelId) { NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream)); - for (int r=0; r < nRanks+1; ++r) { + for (int r=0; r < nPeers; ++r) { for (int b=0; b < NCCL_MAX_CONNS; b++) { channel->peers[r].send[b].comm = comm; channel->peers[r].recv[b].comm = comm; diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h index 4e82dd6..a5f3f29 100644 --- a/src/collectives/device/all_gather.h +++ b/src/collectives/device/all_gather.h @@ -97,3 +97,45 @@ struct RunWorkElement(args); } }; + +template +struct RunWorkElement { + __device__ __forceinline__ void run(ncclWorkElem *args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int nChannels = args->nChannels; + struct ncclNvls* nvls = &ncclShmem.channel.nvls; + const ssize_t chunkSize = int(args->lastChunkSize); + const ssize_t size = args->count; + const ssize_t loopSize = nChannels*chunkSize; + + const int nThreadsGather = 128; + const int nThreadsBcast = 384 + WARP_SIZE; + const int tidEndGather = nThreadsGather; + const int tidEndBcast = tidEndGather + nThreadsBcast; + + using Proto = ProtoSimple<1, 1>; + + if (tid < tidEndGather) { + // Gather + int group = (0*Proto::MaxGroupWidth) | (0<<16); + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.gather(offset, nvls->nHeads*size, nelem, size, -1, 0); + } + } else if (tid < tidEndBcast) { + int group = (3*Proto::MaxGroupWidth) | (1<<16); + // Bcast through MC + Primitives, /*Direct=*/0, Proto, 0> + prims(tid-tidEndGather, nThreadsBcast, NULL, &nvls->down, args->sendbuff, NULL, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.send(offset, nelem); + } + } + } +}; diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index 3f12e5e..f51eb43 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -306,9 +306,9 @@ struct RunWorkElementnHeads*chunkSize; int nelem = min(direct->nHeads*chunkSize, size-offset); if (args->regUsed) { - prims.directScatter(offset, nelem, chunkSize, direct->headRank, direct->shift); + prims.directScatter(offset, nelem, chunkSize, chunkSize, direct->headRank, direct->shift); } else { - prims.scatter(offset, nelem, chunkSize, direct->headRank, direct->shift); + prims.scatter(offset, nelem, chunkSize, chunkSize, direct->headRank, direct->shift); } } } else if (tid >= tidStartReduce && direct->out != -1) { @@ -344,7 +344,7 @@ struct RunWorkElementnHeads*chunkSize; int nelem = min(direct->nHeads*chunkSize, size-offset); - prims.directGather(offset, nelem, chunkSize, direct->headRank, direct->shift); + 
prims.directGather(offset, nelem, chunkSize, chunkSize, direct->headRank, direct->shift); } } else if (tid >= tidStartBcast && tid < tidStartScatter && direct->out != -1) { int group = (1*Proto::MaxGroupWidth) | (0<<16); @@ -371,6 +371,65 @@ struct RunWorkElement +struct RunWorkElement { + __device__ __forceinline__ void run(ncclWorkElem *args) { + #if NCCL_NVLS_ENABLED + const int tid = threadIdx.x; + const int bid = args->bid; + const int nChannels = args->nChannels; + struct ncclNvls* nvls = &ncclShmem.channel.nvls; + const ssize_t chunkSize = int(args->lastChunkSize); + const ssize_t size = args->count; + const ssize_t loopSize = nChannels*nvls->nHeads*chunkSize; + const int nranks = ncclShmem.comm.nRanks; + const int reduceWarps = nranks <= 6 ? 6 : 4; + const int copyWarps = ((NCCL_MAX_NTHREADS/WARP_SIZE) - reduceWarps)/2; + + const int nThreadsScatter = copyWarps*WARP_SIZE; + const int nThreadsGather = (copyWarps-1)*WARP_SIZE; + const int nThreadsReduce = (reduceWarps+1)*WARP_SIZE; + const int tidEndScatter = nThreadsScatter; + const int tidEndGather = tidEndScatter + nThreadsGather; + const int tidEndReduce = tidEndGather + nThreadsReduce; + + using Proto = ProtoSimple<1, 1, COLL_UNROLL, /*NVLS=*/true>; + + if (tid < tidEndScatter) { + // Scatter + int group = (0*Proto::MaxGroupWidth) | (0<<16); + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, args->recvbuff, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*nvls->nHeads*chunkSize; + int nelem = min(nvls->nHeads*chunkSize, size-offset); + prims.scatter(offset, nelem, chunkSize, chunkSize, -1, 0); + } + } else if (tid < tidEndGather) { + // Gather + int group = (2*Proto::MaxGroupWidth) | (0<<16); + Primitives, /*Direct=*/0, Proto, 0> + prims(tid-tidEndScatter, nThreadsGather, nvls->up, NULL, args->sendbuff, args->recvbuff, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*nvls->nHeads*chunkSize; + int nelem = min(nvls->nHeads*chunkSize, size-offset); + prims.gather(offset, nelem, chunkSize, chunkSize, -1, 0); + } + } else if (tid < tidEndReduce) { + int group = (3*Proto::MaxGroupWidth) | (1<<16); + // Reduce, broadcast through NVLS + Primitives, /*Direct=*/0, Proto, 0> + prims(tid-tidEndGather, nThreadsReduce, &nvls->down, &nvls->down, args->sendbuff, args->recvbuff, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + (bid*nvls->nHeads+nvls->headRank)*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.recvSend(nelem); + } + } + #endif // NCCL_NVLS_ENABLED + } +}; + template struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 95cc990..ad9ca48 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -11,31 +11,23 @@ #include "devcomm.h" #include "op128.h" -#if __CUDA_ARCH__ >= 800 -#define COLL_UNROLL 8 -#else -#define COLL_UNROLL 4 -#endif - +#define COLL_UNROLL (ncclCollUnroll()) #define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree typedef void(*ncclKern_t)(); extern __device__ ncclKern_t ncclFuncs[]; struct ncclShmemGroup { - ncclConnInfo *recvConns[NCCL_MAX_DIRECT_ARITY]; - ncclConnInfo *sendConns[NCCL_MAX_DIRECT_ARITY]; - 
void* srcs[NCCL_MAX_DIRECT_ARITY+1]; - void* dsts[NCCL_MAX_DIRECT_ARITY+1]; - int totalSendSize[NCCL_MAX_SLICE_PER_CHUNK]; + ncclConnInfo *recvConns[NCCL_MAX_NVLS_ARITY]; + ncclConnInfo *sendConns[NCCL_MAX_NVLS_ARITY]; + void* srcs[NCCL_MAX_NVLS_ARITY+1]; + void* dsts[NCCL_MAX_NVLS_ARITY+1]; + int nvlsRecv; }; struct ncclShmemData { - union { - uint64_t ll128warp[NCCL_LL128_MAX_NTHREADS/WARP_SIZE][NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE]; - struct ncclShmemGroup groups[NCCL_MAX_GROUPS]; - }; - uint64_t redOpArgs[NCCL_MAX_DIRECT_ARITY+1]; + struct ncclShmemGroup groups[NCCL_MAX_GROUPS]; + uint64_t redOpArgs[NCCL_MAX_NVLS_ARITY+1]; int channelId; int aborted; alignas(16) struct ncclDevComm comm; @@ -45,6 +37,15 @@ struct ncclShmemData { static_assert(offsetof(struct ncclShmemData, work)%16 == 0, "shmem.work needs to be 16B aligned"); extern __shared__ ncclShmemData ncclShmem; +#if __CUDA_ARCH__ >= 700 + extern __shared__ ulong2 ncclShmemPerWarp[/*ncclShmemDynamicSize()/sizeof(ulong2)*/]; +#else + extern __shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)]; +#endif + +__device__ inline void* ncclScratchForWarp(int warp) { + return (char*)ncclShmemPerWarp + warp*ncclShmemScratchWarpSize(); +} __device__ inline bool barrierReduceAny(int bit) { uint32_t popc; @@ -235,7 +236,8 @@ __device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)() { \ IMPL_COLL4(func, TREE, devredop, type, ncclType) \ IMPL_COLL4(func, RING, devredop, type, ncclType) \ IMPL_COLL4(func, COLLNET_DIRECT, devredop, type, ncclType) \ - IMPL_COLL4(func, COLLNET_CHAIN, devredop, type, ncclType) + IMPL_COLL4(func, COLLNET_CHAIN, devredop, type, ncclType) \ + IMPL_COLL4(func, NVLS, devredop, type, ncclType) #if NCCL_TYPE == 0 #define IMPL_COLL2(func, devredop) IMPL_COLL3(func, devredop, int8_t, ncclInt8) @@ -291,4 +293,6 @@ __device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)() { \ #define IMPL_COLL_P(func) #endif +#define NCCL_NVLS_ENABLED (__CUDA_ARCH__ >= 900 && NCCL_NVLS_SUPPORTS(NCCL_TYPE, NCCL_OP)) + #endif diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h index c21d373..9a2e004 100644 --- a/src/collectives/device/common_kernel.h +++ b/src/collectives/device/common_kernel.h @@ -8,13 +8,15 @@ #define NCCL_COMMON_KERNEL_H_ #include "devcomm.h" +#include "op128.h" +#include "reduce_kernel.h" #include #include #include // Define min for ssize_t -static __device__ int min(int a, ssize_t b) { return (a < b) ? a : b; } +inline __device__ int min(int a, ssize_t b) { return (a < b) ? 
a : b; } inline __device__ int loadInt(int* ptr) { int v; @@ -23,670 +25,353 @@ inline __device__ int loadInt(int* ptr) { return v; } -typedef uint64_t PackType; - -template -struct FuncTraits /*{ - __device__ static T preOp(Fn, T); - __device__ static T postOp(Fn, T); -}*/; - -// unpack x and y to elements of type T and apply FUNC to each element -template -struct MULTI { - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const; - __device__ PackType preOp(FUNC fn, PackType x) const; - __device__ PackType postOp(FUNC fn, PackType x) const; -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), - "PackType must be twice the size of uint32_t."); - union converter { - PackType storage; - struct { - uint32_t a, b; - }; - }; - - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - // for char, we do these as vector ops - cr.a = fn(cx.a, cy.a); - cr.b = fn(cx.b, cy.b); - - return cr.storage; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - int8_t elt[8]; - } u; - u.pack = x; - #pragma unroll - for (int i=0; i < 8; i++) - u.elt[i] = FuncTraits().preOp(fn, u.elt[i]); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - int8_t elt[8]; - } u; - u.pack = x; - #pragma unroll - for (int i=0; i < 8; i++) - u.elt[i] = FuncTraits().postOp(fn, u.elt[i]); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), - "PackType must be twice the size of uint32_t."); - union converter { - PackType storage; - struct { - uint32_t a, b; - }; - }; - - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - // for char, we do these as vector ops - cr.a = fn(cx.a, cy.a); - cr.b = fn(cx.b, cy.b); - - return cr.storage; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint8_t elt[8]; - } u; - u.pack = x; - #pragma unroll - for (int i=0; i < 8; i++) - u.elt[i] = FuncTraits().preOp(fn, u.elt[i]); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint8_t elt[8]; - } u; - u.pack = x; - #pragma unroll - for (int i=0; i < 8; i++) - u.elt[i] = FuncTraits().postOp(fn, u.elt[i]); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(int32_t), - "PackType must be twice the size of int."); - union converter { - PackType storage; - struct { - int32_t a, b; - }; - }; - - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - cr.a = fn(cx.a, cy.a); - cr.b = fn(cx.b, cy.b); - - return cr.storage; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - int32_t elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().preOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().preOp(fn, u.elt[1]); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - int32_t elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().postOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().postOp(fn, u.elt[1]); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), - "PackType must be twice the size of int."); - 
union converter { - PackType storage; - struct { - uint32_t a, b; - }; - }; - - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - cr.a = fn(cx.a, cy.a); - cr.b = fn(cx.b, cy.b); - - return cr.storage; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint32_t elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().preOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().preOp(fn, u.elt[1]); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint32_t elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().postOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().postOp(fn, u.elt[1]); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 4 * sizeof(half), - "PackType must be four times the size of half."); - - union Converter { - PackType pack; - half2 h2[2]; - }; - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - Converter cx, cy, cr; - cx.pack = x; - cy.pack = y; - cr.h2[0] = fn(cx.h2[0], cy.h2[0]); - cr.h2[1] = fn(cx.h2[1], cy.h2[1]); - return cr.pack; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - Converter c; - c.pack = x; - c.h2[0] = FuncTraits().preOp(fn, c.h2[0]); - c.h2[1] = FuncTraits().preOp(fn, c.h2[1]); - return c.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - Converter c; - c.pack = x; - c.h2[0] = FuncTraits().postOp(fn, c.h2[0]); - c.h2[1] = FuncTraits().postOp(fn, c.h2[1]); - return c.pack; - } -}; - -#if defined(__CUDA_BF16_TYPES_EXIST__) -template -struct MULTI { - static_assert(sizeof(PackType) == 4 * sizeof(__nv_bfloat16), - "PackType must be four times the size of __nv_bfloat16."); - - union Converter { - PackType pack; - __nv_bfloat162 h2[2]; - }; - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - Converter cx, cy, cr; - cx.pack = x; - cy.pack = y; - cr.h2[0] = fn(cx.h2[0], cy.h2[0]); - cr.h2[1] = fn(cx.h2[1], cy.h2[1]); - return cr.pack; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - Converter c; - c.pack = x; - c.h2[0] = FuncTraits().preOp(fn, c.h2[0]); - c.h2[1] = FuncTraits().preOp(fn, c.h2[1]); - return c.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - Converter c; - c.pack = x; - c.h2[0] = FuncTraits().postOp(fn, c.h2[0]); - c.h2[1] = FuncTraits().postOp(fn, c.h2[1]); - return c.pack; - } -}; -#endif - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(float), - "PackType must be twice the size of float."); - union converter { - PackType storage; - struct { - float a, b; - }; - }; - - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - cr.a = fn(cx.a, cy.a); - cr.b = fn(cx.b, cy.b); - - return cr.storage; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - float elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().preOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().preOp(fn, u.elt[1]); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - float elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().postOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().postOp(fn, u.elt[1]); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == sizeof(double), - "PackType must be the same size as 
double."); - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - double rv = fn(__longlong_as_double(x), __longlong_as_double(y)); - return __double_as_longlong(rv); - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - double elt; - } u; - u.pack = x; - u.elt = FuncTraits().preOp(fn, u.elt); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - double elt; - } u; - u.pack = x; - u.elt = FuncTraits().postOp(fn, u.elt); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == sizeof(uint64_t), - "PackType must be the same size as uint64_t."); - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - uint64_t rv = fn(x, y); - return rv; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint64_t elt; - } u; - u.pack = x; - u.elt = FuncTraits().preOp(fn, u.elt); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint64_t elt; - } u; - u.pack = x; - u.elt = FuncTraits().postOp(fn, u.elt); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == sizeof(int64_t), - "PackType must be the same size as int64_t."); - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - int64_t rv = fn((int64_t)x, (int64_t)y); - return rv; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - int64_t elt; - } u; - u.pack = x; - u.elt = FuncTraits().preOp(fn, u.elt); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - int64_t elt; - } u; - u.pack = x; - u.elt = FuncTraits().postOp(fn, u.elt); - return u.pack; - } -}; - -template inline __device__ -T vFetch(const volatile T* ptr) { - return *ptr; -} - -template inline __device__ -void vStore(volatile T* ptr, const T val) { - *ptr = val; -} - -#if CUDART_VERSION < 9000 -template<> inline __device__ -half vFetch(const volatile half* ptr) { - half r; - r.x = ptr->x; - return r; -} - -template<> inline __device__ -void vStore(volatile half* ptr, const half val) { - ptr->x = val.x; -} -#else -template<> inline __device__ -half vFetch(const volatile half* ptr) { - half r; - r = ((half*)ptr)[0]; - return r; -} - -template<> inline __device__ -void vStore(volatile half* ptr, const half val) { - ((half*)ptr)[0] = val; -} -#endif - -#if defined(__CUDA_BF16_TYPES_EXIST__) -template<> inline __device__ -__nv_bfloat16 vFetch<__nv_bfloat16>(const volatile __nv_bfloat16* ptr) { - __nv_bfloat16 r; - r = ((__nv_bfloat16*)ptr)[0]; - return r; -} - -template<> inline __device__ -void vStore<__nv_bfloat16>(volatile __nv_bfloat16* ptr, const __nv_bfloat16 val) { - ((__nv_bfloat16*)ptr)[0] = val; -} -#endif - -typedef ulong2 Pack128; - -template -struct MULTI128 { - __device__ void operator()(FUNC fn, Pack128& x, Pack128 const& y) const { - x.x = MULTI()(fn, x.x, y.x); - x.y = MULTI()(fn, x.y, y.y); - } - __device__ void preOp(FUNC fn, Pack128 &x) const { - x.x = MULTI().preOp(fn, x.x); - x.y = MULTI().preOp(fn, x.y); - } - __device__ void postOp(FUNC fn, Pack128 &x) const { - x.x = MULTI().postOp(fn, x.x); - x.y = MULTI().postOp(fn, x.y); - } -}; - -inline __device__ void Fetch128(Pack128& v, const Pack128* p) { - asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory"); -} -inline __device__ void Store128(Pack128* 
p, Pack128& v) { - asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" :: "l"(p), "l"(v.x), "l"(v.y) : "memory"); -} - -template -__device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t, - uint64_t* redOpArgs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const Int Nelem +template +__device__ __forceinline__ void reduceCopyPacks( + int nThreads, int &thread, + uint64_t redArg, uint64_t *preOpArgs, bool postOp, + int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs, + IntBytes &nBytesBehind, IntBytes &nBytesAhead ) { - const Int inc = nw * UNROLL * WARP_SIZE; - Int offset = w * UNROLL * WARP_SIZE + t; + static_assert(std::is_signed::value, "IntBytes must be a signed integral type."); - const T* srcs[MAXSRCS]; - for (int i=0; i().preOp(fn, vals[u]); - } + // This thread's initial position. + IntBytes threadBytesBehind = nBytesBehind + (warp*BytePerHunk + lane*BytePerPack); + IntBytes threadBytesAhead = nBytesAhead - (warp*BytePerHunk + lane*BytePerPack); + // Number of hunks to be consumed over all warps. + IntBytes nHunksAhead = nBytesAhead/BytePerHunk; + // Advance collective position. + nBytesBehind += nHunksAhead*BytePerHunk; + nBytesAhead -= nHunksAhead*BytePerHunk; + if (Unroll==1 && BytePerPack <= nBytesAhead) { + // Only Unroll=1 can do partial hunks (where not all threads partake). + nHunksAhead += 1; + nBytesBehind += nBytesAhead - (nBytesAhead%BytePerPack); + nBytesAhead = nBytesAhead%BytePerPack; + } + nHunksAhead -= warp; - #pragma unroll - for (int i=1; i().preOp(fn, vals2[u]); + RedFn redFn(redArg); + uintptr_t minSrcs[MinSrcs + !MinSrcs]; + uintptr_t minDsts[MinDsts + !MinDsts]; + #pragma unroll + for (int s=0; s < MinSrcs; s++) + minSrcs[s] = cvta_to_global(srcPtrs[s]) + threadBytesBehind; + #pragma unroll + for (int d=0; d < MinDsts; d++) + minDsts[d] = cvta_to_global(dstPtrs[d]) + threadBytesBehind; + + // We dictate loop termination condition according to whether partial hunks + // can be handled or not. + while (Unroll==1 ? (BytePerPack <= threadBytesAhead) : (0 < nHunksAhead)) { + BytePack acc[Unroll]; + + { RedFn preFn(0 < PreOpSrcs ? preOpArgs[0] : 0); + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + // Use volatile loads in case credits are polled for with volatile (instead of acquire). + acc[u] = ld_volatile_global(minSrcs[0]); + minSrcs[0] += WARP_SIZE*BytePerPack; + if (0 < PreOpSrcs) acc[u] = applyPreOp(preFn, acc[u]); } - for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]); } - #pragma unroll - for (int i=MINSRCS; i().preOp(fn, vals2[u]); - } - for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]); + + #pragma unroll (MinSrcs-1 + !(MinSrcs-1)) + for (int s=1; s < MinSrcs; s++) { + BytePack tmp[Unroll]; + RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0); + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + // Use volatile loads in case credits are polled for with volatile (instead of acquire). + tmp[u] = ld_volatile_global(minSrcs[s]); + minSrcs[s] += WARP_SIZE*BytePerPack; + } + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]); + acc[u] = applyReduce(redFn, acc[u], tmp[u]); + } + } + + for (int s=MinSrcs; (MinSrcs < MaxSrcs) && (s < MaxSrcs) && (s < nSrcs); s++) { + uintptr_t src = cvta_to_global(srcPtrs[s]) + threadBytesBehind; + BytePack tmp[Unroll]; + RedFn preFn(s < PreOpSrcs ? 
preOpArgs[s] : 0); + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + // Use volatile loads in case credits are polled for with volatile (instead of acquire). + tmp[u] = ld_volatile_global(src); + src += WARP_SIZE*BytePerPack; + } + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]); + acc[u] = applyReduce(redFn, acc[u], tmp[u]); } } if (postOp) { - FUNC fn(redOpArgs[0]); - #pragma unroll - for (int u = 0; u < UNROLL; ++u) vals[u] = FuncTraits().postOp(fn, vals[u]); + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) + acc[u] = applyPostOp(redFn, acc[u]); } - // Store - #pragma unroll - for (int i = 0; i < MINDSTS; i++) { - for (int u = 0; u < UNROLL; ++u) vStore(dsts[i]+u*WARP_SIZE, vals[u]); - } - #pragma unroll - for (int i=MINDSTS; i(minDsts[d], acc[u]); + minDsts[d] += WARP_SIZE*BytePerPack; } } - for (int i=0; i(dst, acc[u]); + dst += WARP_SIZE*BytePerPack; + } + } + + nWarps = nThreads/WARP_SIZE; + #pragma unroll + for (int s=0; s < MinSrcs; s++) minSrcs[s] += (nWarps-1)*BytePerHunk; + #pragma unroll + for (int d=0; d < MinDsts; d++) minDsts[d] += (nWarps-1)*BytePerHunk; + threadBytesBehind += nWarps*BytePerHunk; + threadBytesAhead -= nWarps*BytePerHunk; + nHunksAhead -= nWarps; } + + nWarps = nThreads/WARP_SIZE; + warp = thread/WARP_SIZE; + lane = thread%WARP_SIZE; + // The last loop iteration could have been partial, i.e. not taken by all + // threads. The threads that weren't included need an extra subtraction to + // make the value warp uniform. + if (Unroll==1 && nHunksAhead > 0) nHunksAhead -= nWarps; + // Rotate warps so the warp which got the least work here will be warp 0. + // This effectively assigns: warp = (warp-nHunks+nWarps)%nWarps; + warp = -nHunksAhead; + thread = warp*WARP_SIZE + lane; } -template -__device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, const int t, - uint64_t* redOpArgs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const Int Npack - ) { - const Int inc = nw * UNROLL * WARP_SIZE; - Int offset = w * UNROLL * WARP_SIZE + t; - - const Pack128* srcs[MAXSRCS]; - for (int i=0; i().preOp(fn, vals[u]); - } - - #pragma unroll - for (int i=1; i().preOp(fn, vals2[u]); - } - for (int u = 0; u < UNROLL; ++u) MULTI128()(fn, vals[u], vals2[u]); - } - #pragma unroll - for (int i=MINSRCS; i().preOp(fn, vals2[u]); - } - for (int u = 0; u < UNROLL; ++u) MULTI128()(fn, vals[u], vals2[u]); - } - } - - if (postOp) { - FUNC fn(redOpArgs[0]); - #pragma unroll - for (int u = 0; u < UNROLL; ++u) MULTI128().postOp(fn, vals[u]); - } - - // Store - #pragma unroll - for (int i = 0; i < MINDSTS; i++) { - for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]); - } - #pragma unroll - for (int i=MINDSTS; i -__device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(Pack128); } - -#define PACKELEMS (sizeof(Pack128) / sizeof(T)) - -template +template __device__ __forceinline__ void ReduceOrCopyMulti( - const int tid, const int nthreads, uint64_t* redOpArgs, bool postOp, int nsrcs, const T** srcs, int ndsts, T** dsts, Int N + int thread, int nThreads, + uint64_t redArg, uint64_t *preOpArgs, bool postOp, + int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs, + IntBytes nElts ) { - Int Nrem = N; - if (Nrem <= 0) return; - - int w = tid / WARP_SIZE; // Warp number - int nw = nthreads / WARP_SIZE; // Number of warps - int t = tid % WARP_SIZE; // Thread (inside the warp) + //int nWarps = nThreads/WARP_SIZE; + //int warp = 
thread/WARP_SIZE; + int lane = thread%WARP_SIZE; // Check that all is 16B aligned. If not don't use 16B load/stores. - int align = 0; - #pragma unroll - for (int i=0; i + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + if (nBytesAhead == 0) return; - // main loop - Int Npack = (Nrem / (PACKELEMS*UNROLL*WARP_SIZE)) * (UNROLL*WARP_SIZE); // round down - Int Nelem = Npack * PACKELEMS; - - ReduceCopy128bMulti - (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack); - - Nrem -= Nelem; - if (Nrem == 0) return; - offset += Nelem; - - // slightly less optimized for section when we don't have full unrolling - Npack = Nrem / PACKELEMS; - Nelem = Npack * PACKELEMS; - - ReduceCopy128bMulti - (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack); - - Nrem -= Nelem; - if (Nrem == 0) return; - offset += Nelem; + reduceCopyPacks + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + if (nBytesAhead == 0) return; } - // unrolled, by-type (mostly for unaligned buffers) - Int Nelem = (Nrem / (UNROLL*PACKELEMS/2*WARP_SIZE)) * (UNROLL*PACKELEMS/2*WARP_SIZE); // round down + reduceCopyPacks + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + if (nBytesAhead == 0) return; - ReduceCopyMulti - (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nelem); - - Nrem -= Nelem; - if (Nrem == 0) return; - offset += Nelem; - - // no unroll, by type. Should finish what's remaining. - ReduceCopyMulti - (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nrem); + reduceCopyPacks + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); } +// Copies from srcAddr to dstAddr using multimem load/store. The amount copied +// will be at most Unroll*BytePerPack*WARP_SIZE. If Partial=1, then the amount +// will be the min() of that and nBytesAhead. If srcAddr is not BytePerPack +// aligned then the amount copied will be less by (srcAddr%BytePerPack) since +// we begin loads at the first pack containing the first element. +template +__device__ __forceinline__ void copyMultimemMultimem_WarpUnrolled( + int lane, RedFn redFn, bool postOp, uintptr_t srcAddr, uintptr_t dstAddr, + IntBytes nBytesAhead, uint32_t scratchAddr + ) { + int srcMisalign = SrcAligned ? 0 : srcAddr%BytePerPack; + srcAddr -= srcMisalign; + + BytePack reg[Unroll]; + int offset = lane*BytePerPack; + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (!Partial || (offset < srcMisalign + nBytesAhead)) { + reg[u] = applyLoadMultimem(redFn, srcAddr+offset); + if (postOp) reg[u] = applyPostOp(redFn, reg[u]); + } + offset += WARP_SIZE*BytePerPack; + } + + if (SrcAligned && DstAligned) { + offset = lane*BytePerPack; + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (!Partial || offset < nBytesAhead) { + multimem_st_global(dstAddr+offset, reg[u]); + } + offset += WARP_SIZE*BytePerPack; + } + } else { + __syncwarp(); + offset = lane*BytePerPack; + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (!Partial || (offset < srcMisalign + nBytesAhead)) { + st_shared(scratchAddr+offset, reg[u]); + } + offset += WARP_SIZE*BytePerPack; + } + __syncwarp(); + if (!SrcAligned) { + // Ignore the beginning of the first pack corresponding to bytes overread + // due to misalignment. 
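// (Since loads were rounded down to the pack boundary, this hunk produces at
//  most Unroll*WARP_SIZE*BytePerPack - srcMisalign useful bytes, which is what
//  the min() below enforces; the shared-to-global copy afterwards starts at
//  scratchAddr + srcMisalign to skip the over-read prefix.)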
+ nBytesAhead = min(nBytesAhead, Unroll*WARP_SIZE*BytePerPack - srcMisalign); + } + copyGlobalShared_WarpUnrolled + + (lane, dstAddr, scratchAddr+srcMisalign, nBytesAhead); + } +} + +// copyMultimemMultimem_IfEnabled has two overloads: the enabled case whose first arg +// has type `std::true_type` and the disabled case with first arg `std::false_type`. +// This is to guard the template instantiations of Apply_LoadMultimem on types/ops where +// they aren't supported. A nicer approach is to use C++17's "if constexpr". +template +__device__ __forceinline__ void copyMultimemMultimem_IfEnabled( + std::false_type enabled/*=false*/, + int thread, int nThreads, uint64_t redArg, bool postOp, + void *srcPtr, void *dstPtr, IntBytes nElts, uint32_t warpScratchAddr + ) { + // nop +} + +template +__device__ __forceinline__ void copyMultimemMultimem_IfEnabled( + std::true_type enabled/*=true*/, + int thread, int nThreads, uint64_t redArg, bool postOp, + void *srcPtr, void *dstPtr, IntBytes nElts, uint32_t warpScratchAddr + ) { + static_assert(std::is_signed::value, "IntBytes must be a signed integral type."); + + constexpr int BytePerPack = Apply_LoadMultimem::PackSize; + using T = typename RedFn::EltType; + constexpr int Unroll = ncclNvlsUnroll(BytePerPack); + constexpr int BytePerHunk = Unroll*WARP_SIZE*BytePerPack; + int nWarps = nThreads/WARP_SIZE; + int warp = thread/WARP_SIZE; + int lane = thread%WARP_SIZE; + RedFn redFn(redArg); + + uintptr_t srcAddr = cvta_to_global(srcPtr); + uintptr_t dstAddr = cvta_to_global(dstPtr); + IntBytes warpBytesAhead = nElts*sizeof(T); + bool partialHunkIsFront; + + // First handle misalignment of srcAddr. + if ((BytePerPack != sizeof(T)) && (srcAddr%BytePerPack != 0)) { + // If srcAddr isn't pack aligned then the first hunk processed will be short + // the same number of bytes as srcAddr's misalignment. + if (warp == 0) { + partialHunkIsFront = true; + goto PartialHunk; // "call" PartialHunk() + PartialHunkFrontReturn: + warp = nWarps; + } + warp -= 1; // Rotate warp numbers for load balancing + int advanced = BytePerHunk-(srcAddr%BytePerPack); // since copyMultimemMultimem_WarpUnrolled shorts by the misalignment + srcAddr += advanced; // srcAddr is now pack aligned + dstAddr += advanced; + warpBytesAhead -= advanced; + } + + warpBytesAhead -= warp*BytePerHunk; + srcAddr += warp*BytePerHunk; + dstAddr += warp*BytePerHunk; + // Now that srcAddr is pack aligned detect if dstAddr is pack aligned. + if ((BytePerPack == sizeof(T)) || (dstAddr%BytePerPack == 0)) { + while (BytePerHunk <= warpBytesAhead) { + copyMultimemMultimem_WarpUnrolled + + (lane, redFn, postOp, srcAddr, dstAddr, warpBytesAhead, warpScratchAddr); + srcAddr += nWarps*BytePerHunk; + dstAddr += nWarps*BytePerHunk; + warpBytesAhead -= nWarps*BytePerHunk; + } + } else { + while (BytePerHunk <= warpBytesAhead) { + copyMultimemMultimem_WarpUnrolled + + (lane, redFn, postOp, srcAddr, dstAddr, warpBytesAhead, warpScratchAddr); + srcAddr += nWarps*BytePerHunk; + dstAddr += nWarps*BytePerHunk; + warpBytesAhead -= nWarps*BytePerHunk; + } + } + + if (0 < warpBytesAhead) { + partialHunkIsFront = false; + goto PartialHunk; // "call" PartialHunk() + PartialHunkBackReturn:; + } + return; + +PartialHunk: + // We have to handle a partial hunk possibly at the front and back of the + // buffer. We generate the code once here since its a lot of instructions, + // and then simulate function calls with gotos. 
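// (Partial=1 enables the per-pack bounds checks in the unrolled body; this
//  block is reached at most twice, once for the short front hunk when srcAddr
//  is pack-misaligned and once for the short back hunk, with
//  partialHunkIsFront selecting which label to jump back to.)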
+ copyMultimemMultimem_WarpUnrolled + + (lane, redFn, postOp, srcAddr, dstAddr, warpBytesAhead, warpScratchAddr); + if (partialHunkIsFront) goto PartialHunkFrontReturn; + goto PartialHunkBackReturn; +} + +template +__device__ __forceinline__ void copyMultimemMultimem( + int thread, int nThreads, uint64_t redArg, bool postOp, + void *srcPtr, void *dstPtr, IntBytes nElts, uint32_t warpScratchAddr + ) { + constexpr bool Enabled = Apply_LoadMultimem::PackSize != 0; + copyMultimemMultimem_IfEnabled( + /*enabled=*/std::integral_constant(), + thread, nThreads, redArg, postOp, srcPtr, dstPtr, nElts, warpScratchAddr); +} #endif // COMMON_KERNEL_H_ diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu index 7c36064..2b634d8 100644 --- a/src/collectives/device/functions.cu +++ b/src/collectives/device/functions.cu @@ -9,6 +9,9 @@ #include "common.h" __shared__ ncclShmemData ncclShmem; +#if __CUDA_ARCH__ < 700 + __shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)]; +#endif #define NCCL_FUNC5(func, algo, devredop, type, nullify) \ MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ @@ -19,7 +22,8 @@ __shared__ ncclShmemData ncclShmem; NCCL_FUNC5(func, TREE, devredop, type, nullify), \ NCCL_FUNC5(func, RING, devredop, type, nullify), \ NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \ - NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify) + NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \ + NCCL_FUNC5(func, NVLS, devredop, type, nullify) #if defined(__CUDA_BF16_TYPES_EXIST__) // Must be consistent with ncclDataType_t diff --git a/src/collectives/device/onerank_reduce.cu b/src/collectives/device/onerank_reduce.cu index f594e34..85d1255 100644 --- a/src/collectives/device/onerank_reduce.cu +++ b/src/collectives/device/onerank_reduce.cu @@ -6,7 +6,7 @@ #include "devcomm.h" #include "collectives.h" -#include "reduce_kernel.h" +#include "common_kernel.h" #include "common.h" namespace { @@ -35,8 +35,10 @@ namespace { i1 = i1 < eltN ? i1 : eltN; src += i0; dst += i0; - ReduceOrCopyMulti - (tid, tn, &(we->redOpArg), true, 1, &src, 1, &dst, i1-i0); + void *vsrc = (void*)src; + void *vdst = (void*)dst; + ReduceOrCopyMulti + (tid, tn, we->redOpArg, &(we->redOpArg), true, 1, &vsrc, 1, &vdst, i1-i0); } } } diff --git a/src/collectives/device/op128.h b/src/collectives/device/op128.h index 46fc8df..2ea2268 100644 --- a/src/collectives/device/op128.h +++ b/src/collectives/device/op128.h @@ -65,4 +65,290 @@ inline __device__ void loadShmemMisaligned128(T *ptr, uint64_t &v0, uint64_t &v1 v1 = tmp8[1]; } + +template +__device__ __forceinline__ uint32_t cvta_to_shared(T* ptr) { + return (uint32_t)__cvta_generic_to_shared(ptr); +} +template +__device__ __forceinline__ uintptr_t cvta_to_global(T* ptr) { + return (uintptr_t)__cvta_generic_to_global(ptr); +} + +template +__device__ __forceinline__ T* cvta_from_shared(uint32_t shptr) { + T* ans; + asm("cvta.shared.u64 %0, %1;" : "=l"(ans) : "l"(uint64_t(shptr))); + return ans; +} +template +__device__ __forceinline__ T* cvta_from_global(uintptr_t gptr) { + T* ans; + asm("cvta.global.u64 %0, %1;" : "=l"(ans) : "l"(gptr)); + return ans; +} + +//////////////////////////////////////////////////////////////////////////////// +// BytePack: struct of bytes. 
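// (BytePack<Size>, defined below, is an untyped Size-byte register bundle that
//  can be viewed as 1/2/4/8-byte lanes or as its `native` integer; the new
//  reduce/copy path moves data as packs and lets the reduction functors
//  reinterpret the lanes as elements.)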
+ +template +union BytePack; +template<> +union BytePack<1> { + uint8_t u8, native; +}; +template<> +union BytePack<2> { + BytePack<1> half[2]; + uint8_t u8[2]; + uint16_t u16, native; +}; +template<> +union BytePack<4> { + BytePack<2> half[2]; + uint8_t u8[4]; + uint16_t u16[2]; + uint32_t u32, native; +}; +template<> +union BytePack<8> { + BytePack<4> half[2]; + uint8_t u8[8]; + uint16_t u16[4]; + uint32_t u32[2]; + uint64_t u64, native; +}; +template<> +union alignas(16) BytePack<16> { + BytePack<8> half[2]; + uint8_t u8[16]; + uint16_t u16[8]; + uint32_t u32[4]; + uint64_t u64[2]; + ulong2 ul2, native; +}; + +template +__device__ __forceinline__ BytePack toPack(T value) { + union { BytePack p; T v; }; + v = value; + return p; +} +template +__device__ __forceinline__ T fromPack(BytePack pack) { + union { BytePack p; T v; }; + p = pack; + return v; +} + +//////////////////////////////////////////////////////////////////////////////// +// Load/store of BytePack using integral addresses. + +template __device__ BytePack ld_global(uintptr_t addr); +template __device__ BytePack ld_volatile_global(uintptr_t addr); +template __device__ BytePack ld_shared(uint32_t addr); +template __device__ BytePack ld_volatile_shared(uint32_t addr); +template __device__ void st_global(uintptr_t addr, BytePack value); +template __device__ void st_shared(uint32_t addr, BytePack value); + +// Used to define implementations for above prototypes. +#define DEFINE_ld_st(bytes, data_cxx_ty, data_ptx_ty, data_reg_ty, space, addr_cxx_ty, addr_reg_ty) \ + template<> \ + __device__ __forceinline__ BytePack ld_##space(addr_cxx_ty addr) { \ + data_cxx_ty tmp; \ + asm("ld." #space "." #data_ptx_ty " %0, [%1];" : "="#data_reg_ty(tmp) : #addr_reg_ty(addr)); \ + BytePack ans; \ + ans.native = tmp; \ + return ans; \ + } \ + template<> \ + __device__ __forceinline__ BytePack ld_volatile_##space(addr_cxx_ty addr) { \ + data_cxx_ty tmp; \ + asm("ld.volatile." #space "." #data_ptx_ty " %0, [%1];" : "="#data_reg_ty(tmp) : #addr_reg_ty(addr)); \ + BytePack ans; \ + ans.native = tmp; \ + return ans; \ + } \ + template<> \ + __device__ __forceinline__ void st_##space(addr_cxx_ty addr, BytePack value) { \ + data_cxx_ty tmp = value.native; \ + asm volatile("st." #space "." #data_ptx_ty " [%0], %1;" :: #addr_reg_ty(addr), #data_reg_ty(tmp) : "memory"); \ + } +// Single-byte types use 4-byte registers since there is no 1-byte register +// character for asm blocks. See https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints +DEFINE_ld_st(1, uint32_t, b8, r, global, uintptr_t, l) +DEFINE_ld_st(1, uint32_t, b8, r, shared, uint32_t, r) +DEFINE_ld_st(2, uint16_t, b16, h, global, uintptr_t, l) +DEFINE_ld_st(2, uint16_t, b16, h, shared, uint32_t, r) +DEFINE_ld_st(4, uint32_t, b32, r, global, uintptr_t, l) +DEFINE_ld_st(4, uint32_t, b32, r, shared, uint32_t, r) +DEFINE_ld_st(8, uint64_t, b64, l, global, uintptr_t, l) +DEFINE_ld_st(8, uint64_t, b64, l, shared, uint32_t, r) +#undef DEFINE_ld_st + +#define DEFINE_ld_st_16(space, addr_cxx_ty, addr_reg_ty) \ + template<> \ + __device__ __forceinline__ BytePack<16> ld_##space<16>(addr_cxx_ty addr) { \ + BytePack<16> ans; \ + asm("ld." #space ".v2.b64 {%0,%1}, [%2];" : "=l"(ans.u64[0]), "=l"(ans.u64[1]) : #addr_reg_ty(addr)); \ + return ans; \ + } \ + template<> \ + __device__ __forceinline__ BytePack<16> ld_volatile_##space<16>(addr_cxx_ty addr) { \ + BytePack<16> ans; \ + asm("ld.volatile." 
#space ".v2.b64 {%0,%1}, [%2];" : "=l"(ans.u64[0]), "=l"(ans.u64[1]) : #addr_reg_ty(addr)); \ + return ans; \ + } \ + template<> \ + __device__ __forceinline__ void st_##space<16>(addr_cxx_ty addr, BytePack<16> value) { \ + asm("st." #space ".v2.b64 [%0], {%1,%2};" :: #addr_reg_ty(addr), "l"(value.u64[0]), "l"(value.u64[1]) : "memory"); \ + } +DEFINE_ld_st_16(global, uintptr_t, l) +DEFINE_ld_st_16(shared, uint32_t, r) +#undef DEFINE_ld_st_16 + +//////////////////////////////////////////////////////////////////////////////// +// Atomic load/store using c++ pointers. + +__device__ __forceinline__ uint64_t ld_volatile_global(uint64_t *ptr) { + uint64_t ans; + asm("ld.volatile.global.u64 %0, [%1];" : "=l"(ans) : "l"(cvta_to_global(ptr))); + return ans; +} +__device__ __forceinline__ uint64_t ld_relaxed_sys_global(uint64_t *ptr) { + uint64_t ans; + #if __CUDA_ARCH__ >= 700 + asm("ld.relaxed.sys.global.u64 %0, [%1];" : "=l"(ans) : "l"(cvta_to_global(ptr))); + #else + asm("ld.volatile.global.u64 %0, [%1];" : "=l"(ans) : "l"(cvta_to_global(ptr))); + #endif + return ans; +} +__device__ __forceinline__ uint64_t ld_acquire_sys_global(uint64_t *ptr) { + uint64_t ans; + #if __CUDA_ARCH__ >= 700 + asm("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ans) : "l"(cvta_to_global(ptr))); + #else + asm("ld.volatile.sys.global.u64 %0, [%1]; membar.gl;" : "=l"(ans) : "l"(cvta_to_global(ptr))); + #endif + return ans; +} + +__device__ __forceinline__ void st_volatile_global(uint64_t *ptr, uint64_t val) { + asm volatile("st.volatile.global.u64 [%0], %1;" :: "l"(cvta_to_global(ptr)), "l"(val) : "memory"); +} +__device__ __forceinline__ void st_relaxed_sys_global(uint64_t *ptr, uint64_t val) { + #if __CUDA_ARCH__ >= 700 + asm volatile("st.relaxed.sys.global.u64 [%0], %1;" :: "l"(cvta_to_global(ptr)), "l"(val) : "memory"); + #else + asm volatile("st.volatile.global.u64 [%0], %1;" :: "l"(cvta_to_global(ptr)), "l"(val) : "memory"); + #endif +} +__device__ __forceinline__ void st_release_sys_global(uint64_t *ptr, uint64_t val) { + #if __CUDA_ARCH__ >= 700 + asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(cvta_to_global(ptr)), "l"(val) : "memory"); + #else + asm volatile("membar.sys; st.volatile.global.u64 [%0], %1;" :: "l"(cvta_to_global(ptr)), "l"(val) : "memory"); + #endif +} + +__device__ __forceinline__ void fence_acq_rel_sys() { + #if __CUDA_ARCH__ >= 700 + asm volatile("fence.acq_rel.sys;" ::: "memory"); + #else + asm volatile("membar.sys;" ::: "memory"); + #endif +} +__device__ __forceinline__ void fence_acq_rel_gpu() { + #if __CUDA_ARCH__ >= 700 + asm volatile("fence.acq_rel.gpu;" ::: "memory"); + #else + asm volatile("membar.gl;" ::: "memory"); + #endif +} + +//////////////////////////////////////////////////////////////////////////////// +// Multimem stores of BytePack. 
+ +template +__device__ __forceinline__ void multimem_st_global(uintptr_t addr, BytePack val); + +#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 +template<> +__device__ __forceinline__ void multimem_st_global<4>(uintptr_t addr, BytePack<4> val) { + asm volatile("multimem.st.global.b32 [%0], %1;" :: "l"(addr), "r"(val.u32) : "memory"); +} +template<> +__device__ __forceinline__ void multimem_st_global<8>(uintptr_t addr, BytePack<8> val) { + asm volatile("multimem.st.global.b64 [%0], %1;" :: "l"(addr), "l"(val.u64) : "memory"); +} +template<> +__device__ __forceinline__ void multimem_st_global<16>(uintptr_t addr, BytePack<16> val) { + asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" + :: "l"(addr), "r"(val.u32[0]), "r"(val.u32[1]), "r"(val.u32[2]), "r"(val.u32[3]) + : "memory"); +} +#else +template +__device__ __forceinline__ void multimem_st_global(uintptr_t addr, BytePack val) { + // nop +} +#endif + +// Warp-uniform memory copy from shared address (not generic) to global memory. +// The number of bytes copied is `min(MaxBytes, nBytesAhead)`, a negative value +// is interpeted as zero. EltSize is the guaranteed alignment of the addresses and sizes. +template +__device__ __forceinline__ void copyGlobalShared_WarpUnrolled( + int lane, uintptr_t dstAddr, uint32_t srcAddr, IntBytes nBytesAhead + ) { + static_assert(std::is_signed::value, "`IntBytes` must be a signed integral type."); + int nBytes = min(nBytesAhead, (IntBytes)MaxBytes); + int nFrontBytes = min(nBytes, (16 - int(dstAddr%16))%16); + int nMiddleBytes = (nBytes-nFrontBytes) & -16; + int nBackBytes = (nBytes-nFrontBytes) % 16; + + { int backLane = WARP_SIZE-1 - lane; + bool hasFront = lane*EltSize < nFrontBytes; + bool hasBack = backLane*EltSize < nBackBytes; + int offset = hasFront ? lane*EltSize : (nBytes - (backLane+1)*EltSize); + if (hasFront | hasBack) { + BytePack tmp = ld_shared(srcAddr+offset); + // Can't use multimem_st since it doesn't support EltSize==2 + st_global(dstAddr+offset, tmp); + } + } + + srcAddr += nFrontBytes; + int srcMisalign = EltSize < 4 ? 
(srcAddr%4) : 0; + srcAddr += -srcMisalign + lane*16; + dstAddr += nFrontBytes + lane*16; + nMiddleBytes -= lane*16; + #pragma unroll + for (int u=0; u < divUp(MaxBytes, WARP_SIZE*16); u++) { + if (nMiddleBytes <= 0) break; + union { + BytePack<4> b4[4]; + BytePack<16> b16; + }; + b4[0] = ld_shared<4>(srcAddr + 0*4); + b4[1] = ld_shared<4>(srcAddr + 1*4); + b4[2] = ld_shared<4>(srcAddr + 2*4); + b4[3] = ld_shared<4>(srcAddr + 3*4); + if (srcMisalign != 0) { + BytePack<4> b4_4 = ld_shared<4>(srcAddr + 4*4); + b4[0].u32 = __funnelshift_r(b4[0].u32, b4[1].u32, srcMisalign*8); + b4[1].u32 = __funnelshift_r(b4[1].u32, b4[2].u32, srcMisalign*8); + b4[2].u32 = __funnelshift_r(b4[2].u32, b4[3].u32, srcMisalign*8); + b4[3].u32 = __funnelshift_r(b4[3].u32, b4_4.u32, srcMisalign*8); + } + if (Multimem) multimem_st_global<16>(dstAddr, b16); + else st_global<16>(dstAddr, b16); + + srcAddr += WARP_SIZE*16; + dstAddr += WARP_SIZE*16; + nMiddleBytes -= WARP_SIZE*16; + } +} + #endif diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index ccc0d22..050acad 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -9,6 +9,7 @@ #include #include "reduce_kernel.h" // for reduction funcs +#include "common_kernel.h" #include "common.h" #define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000 @@ -20,12 +21,13 @@ * to how that protocol operates with a consistent interface so that our * algorithm code can operate protocol parametrically. */ -template +template struct ProtoSimple { static constexpr int Id = NCCL_PROTO_SIMPLE; static constexpr int SlicePerChunk = SlicePerChunk_1; static constexpr int StepPerSlice = StepPerSlice_1; static constexpr int Unroll = Unroll_1; + static constexpr bool NVLS = NVLS_1; // Data bytes (no flags etc) in one step of the fifo queue. __device__ static int calcBytePerStep() { diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h index 60f64ff..c43f1a5 100644 --- a/src/collectives/device/prims_ll.h +++ b/src/collectives/device/prims_ll.h @@ -255,18 +255,18 @@ class Primitives: } if (SRC) { data = dl.loadFinish(); - if (SrcBuf == Input) data = MULTI().preOp(redOp, data); + if (SrcBuf == Input) data = applyPreOp(redOp, data); } if (RECV) { - data = !SRC ? peerData : MULTI()(redOp, peerData, data); + data = !SRC ? peerData : applyReduce(redOp, peerData, data); #pragma unroll MaxRecv for (int i=1; i < MaxRecv && i < fan.nrecv(); i++) { peerData = readLLFinish(offset, line, i); - data = MULTI()(redOp, peerData, data); + data = applyReduce(redOp, peerData, data); } } - if (postOp) data = MULTI().postOp(redOp, data); + if (postOp) data = applyPostOp(redOp, data); // Send : inter-node, then intra-node, then local if (SEND) { diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h index 773a921..8a4570a 100644 --- a/src/collectives/device/prims_ll128.h +++ b/src/collectives/device/prims_ll128.h @@ -82,7 +82,14 @@ class Primitives: if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1; } inline __device__ void postSend() { - if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += 1; } + if (sendConnTailPtr) { +#if __CUDA_ARCH__ >= 900 + __threadfence_system(); +#else + __threadfence(); +#endif + *sendConnTailPtr = sendConnTail += 1; + } } template @@ -109,7 +116,7 @@ class Primitives: // buffer into shmem. 
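// (This hunk and the two below swap the fixed ncclShmem.ll128warp[] staging
//  buffer for the per-warp dynamic scratch returned by ncclScratchForWarp(),
//  since ncclShmemData no longer unions the LL128 staging area with the group
//  state.)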
int misalignment = reinterpret_cast(src) % 16; uint64_t *src8 = reinterpret_cast(reinterpret_cast(src) & -uintptr_t(16)); - uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warpInBlock]); + uint64_t *shm8 = shmemCvtPtr((uint64_t*)ncclScratchForWarp(warpInBlock)); #pragma unroll for(int g=0; g < WordPerThread/2; g++) if((g*WARP_SIZE + wid)*16 < misalignment + eltN*sizeof(T)) @@ -153,7 +160,7 @@ class Primitives: } // Write to dst if 16-byte aligned, shmem otherwise. int misalignment = reinterpret_cast(dst)%16; - uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warpInBlock]); + uint64_t *shm8 = shmemCvtPtr((uint64_t*)ncclScratchForWarp(warpInBlock)); #pragma unroll for(int g=0; g < WordPerThread/2; g++) { int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8); @@ -167,7 +174,7 @@ class Primitives: __syncwarp(); // Write rest from shmem to dst. No need to coalesce stores to 16-bytes, // the hardware keeps up fine. - T *shm = (T*)ncclShmem.ll128warp[warpInBlock]; + T *shm = (T*)ncclScratchForWarp(warpInBlock); int skip = misalignment == 0 ? eltN & -EltPer16B : 0; for(int i=skip+wid; i < eltN; i += WARP_SIZE) dst[i] = shm[i]; @@ -196,6 +203,10 @@ class Primitives: } needReload &= (0 == checkAbort(spins, 0, 0)); } while (__any_sync(WARP_MASK, needReload)); + + #pragma unroll + for (int u=0; u: if (SrcBuf == Input) { #pragma unroll for (int u=0; u().preOp(redOp, v[u]); + v[u] = applyPreOp(redOp, v[u]); if (!flagThread) - v[u+1] = MULTI().preOp(redOp, v[u+1]); + v[u+1] = applyPreOp(redOp, v[u+1]); } } } @@ -218,8 +229,8 @@ class Primitives: { // Consume data from first recv #pragma unroll for (int u=0; u()(redOp, vr[u], v[u]) : vr[u]; - v[u+1] = SRC ? MULTI()(redOp, vr[u+1], v[u+1]) : vr[u+1]; + v[u] = SRC ? applyReduce(redOp, vr[u], v[u]) : vr[u]; + v[u+1] = SRC ? applyReduce(redOp, vr[u+1], v[u+1]) : vr[u+1]; } } @@ -238,20 +249,24 @@ class Primitives: needReload &= (0 == checkAbort(spins, i, 0)); } while (__any_sync(WARP_MASK, needReload)); + #pragma unroll + for (int u=0; u()(redOp, vr[u], v[u]); - v[u+1] = MULTI()(redOp, vr[u+1], v[u+1]); + v[u] = applyReduce(redOp, vr[u], v[u]); + v[u+1] = applyReduce(redOp, vr[u+1], v[u+1]); } } } /********************** End Recv ************************/ - if (postOp && !FuncTraits::IsPostOpIdentity) { + if (postOp) { #pragma unroll for (int u=0; u().postOp(redOp, v[u]); - v[u+1] = MULTI().postOp(redOp, v[u+1]); + v[u] = applyPostOp(redOp, v[u]); + v[u+1] = applyPostOp(redOp, v[u+1]); } } @@ -282,14 +297,6 @@ class Primitives: __device__ __forceinline__ void GenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) { constexpr int SRC = SrcBuf != -1 ? 1 : 0; constexpr int DST = DstBuf != -1 ? 1 : 0; - static_assert(-1<=SrcBuf && SrcBuf < 2, "Uhoh"); - static_assert(-1<=DstBuf && DstBuf < 2, "Uhoh"); - static_assert(DstBuf!=Input, "Mistake?"); - #if 0 - assert((SrcBuf==-1) == (srcIx==-1)); - assert((DstBuf==-1) == (dstIx==-1)); - #endif - T const *srcPtr = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx; T *dstPtr = DstBuf == -1 ? 
nullptr : userBufs[DstBuf] + dstIx; int wireOffset = WireWordPerSlice*warp + 2*wid; diff --git a/src/collectives/device/prims_simple.h b/src/collectives/device/prims_simple.h index 9d2d19a..2cd3797 100644 --- a/src/collectives/device/prims_simple.h +++ b/src/collectives/device/prims_simple.h @@ -5,9 +5,9 @@ ************************************************************************/ template + int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, bool NVLS> class Primitives< - T, RedOp, Fan, Direct, ProtoSimple, P2p + T, RedOp, Fan, Direct, ProtoSimple, P2p > { static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; static constexpr int Input=0, Output=1; @@ -22,8 +22,10 @@ class Primitives< SizesFifoEnabled = 0x100, DirectWrite = 0x200, DirectRead = 0x400, - ThreadsSynced = 0x800; - const int tid; + ThreadsSynced = 0x800, + NvlsMinPolling = 0x1000, + NvlsRecv = 0x2000; + const int tid, tidInBlock; int nthreads; int nworkers; const int stepSize; @@ -41,22 +43,54 @@ class Primitives< int volatile *connSizesFifoPtr; // (flags & SizesFifoEnabled) T *directBuff; // !(flags & SizesFifoEnabled) }; - uint64_t volatile *connStepPtr; + uint64_t *connStepPtr; uint64_t connStepCache; // Cache last seen value of (*connStepPtr) // Don't use barrier 0 as it's used by the final sync - inline __device__ void barrier() { - if (nthreads == WARP_SIZE) - __syncwarp(); - else - asm volatile("bar.sync %0, %1;" :: "r"(15-group), "r"(nthreads)); + __device__ void barrier() { flags |= ThreadsSynced; + if (nthreads == WARP_SIZE) __syncwarp(); + else { + int bar = 15-group; + asm volatile("bar.sync %0, %1;" :: "r"(bar), "r"(nthreads) : "memory"); + } } - inline __device__ void subBarrier() { - if (nworkers == nthreads) - barrier(); - else - asm volatile("bar.sync %0, %1;" :: "r"(8-group), "r"(nworkers)); + __device__ void subBarrier() { + if (nworkers == WARP_SIZE) __syncwarp(); + else { + int bar = (nworkers==nthreads ? 15 : 8) - group; + asm volatile("bar.sync %0, %1;" :: "r"(bar), "r"(nworkers) : "memory"); + } + } + + __device__ bool barrierAny(int vote) { + flags |= ThreadsSynced; + if (nthreads == WARP_SIZE) { + return __any_sync(~0u, vote); + } else { + int ans, bar = 15-group; + asm volatile( + "{ .reg .pred p;" + " setp.ne.s32 p, %1, 0;" + " bar.red.or.pred p, %2, %3, p; " + " selp.s32 %0, 1, 0, p; }" + : "=r"(ans) : "r"(vote), "r"(bar), "r"(nthreads) : "memory"); + return ans != 0; + } + } + __device__ bool subBarrierAny(int vote) { + if (nworkers == WARP_SIZE) { + return __any_sync(~0u, vote); + } else { + int ans, bar = (nworkers==nthreads ? 15 : 8) - group; + asm volatile( + "{ .reg .pred p;" + " setp.ne.s32 p, %1, 0;" + " bar.red.or.pred p, %2, %3, p; " + " selp.s32 %0, 1, 0, p; }" + : "=r"(ans) : "r"(vote), "r"(bar), "r"(nworkers) : "memory"); + return ans != 0; + } } inline __device__ bool checkAbort(int &spins) { @@ -71,6 +105,19 @@ class Primitives< return flags & Aborted; } + inline __device__ uint64_t loadStepValue(uint64_t* ptr) { + #if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 + if (NVLS && (flags & NvlsMinPolling)) { + uint64_t ans; + asm("multimem.ld_reduce.acquire.sys.global.min.u64 %0, [%1];" : "=l"(ans) : "l"(cvta_to_global(ptr))); + return ans; + } + #endif + // volatile is faster than acquire but not as correct. Make sure ReduceOrCopyMulti + // loads data using volatile so it doesn't see stale data in L1. 
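    // For the NVLS branch above, connStepPtr is a multicast (multimem) address:
    // multimem.ld_reduce...min.u64 returns the smallest counter value across the
    // devices mapped by that address, so the poll only succeeds once the slowest
    // writer has advanced. This fallback path reads the locally visible counter
    // with a plain volatile load.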
+ return ld_volatile_global(ptr); + } + template __device__ __forceinline__ void waitPeer(intptr_t dstIx, intptr_t remoteIx, int offset, int nelts) { const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send; @@ -80,7 +127,7 @@ class Primitives< ((flags & (Send*RoleWaitSend)) && !noSendWait)) { int spins = 0; while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) { - connStepCache = *connStepPtr; + connStepCache = loadStepValue(connStepPtr); if (checkAbort(spins)) break; //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice)); } @@ -119,10 +166,11 @@ class Primitives< } template - inline __device__ void postPeer() { + inline __device__ void postPeer(bool dataStored) { if (flags & (Recv*RolePostRecv | Send*RolePostSend)) { step += StepPerSlice; - *connStepPtr = step; + if (Send && (flags & RolePostSend) && dataStored) fence_acq_rel_sys(); + st_relaxed_sys_global(connStepPtr, step); } } @@ -166,7 +214,7 @@ class Primitives< // post(); // } // Since we no longer unroll, new branch added here #if __CUDA_ARCH__ < 700 - // Yeah, so all that above don't matter a lick on older hardware. + // Above doesn't matter on older hardware. #pragma unroll SlicePerChunk #else #pragma unroll 1 @@ -181,37 +229,39 @@ class Primitives< subBarrier(); /* if user abort the kernel, we don't need to actually perform copy/reduce; just set size * to 0 to avoid unnecessary workload. */ - size_t workSize = ncclShmem.aborted ? 0 : sliceSize; - if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) { + int workSize = ncclShmem.aborted ? 0 : sliceSize; + if (NVLS && ncclShmem.groups[group].nvlsRecv) { + void* src = ncclShmem.groups[group].srcs[0]; + void* dst = ncclShmem.groups[group].dsts[0]; + copyMultimemMultimem(tid, nworkers, ncclShmem.redOpArgs[0], postOp, src, dst, workSize, + cvta_to_shared(ncclScratchForWarp(tidInBlock/WARP_SIZE))); + } else if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) { // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy if (Send) { - // (1-Send) is only there to avoid compilation errors in case MaxSend=0 (and Send=0). - ReduceOrCopyMulti - (tid, nworkers, nullptr, false, - 1, (T const**)ncclShmem.groups[group].srcs, - fan.nsend(), (T**)ncclShmem.groups[group].dsts+1, + ReduceOrCopyMulti + (tid, nworkers, /*redArg*/0, /*preOpArgs*/nullptr, /*postOp*/false, + 1, ncclShmem.groups[group].srcs, + fan.nsend(), ncclShmem.groups[group].dsts+1, workSize); } } else if (DirectSend && !DirectRecv && SrcBuf != Input && ncclShmem.groups[group].dsts[Dst] == nullptr) { // For broadcast in CollNet to do empty send - ReduceOrCopyMulti - (tid, nworkers, ncclShmem.redOpArgs, postOp, - Recv, (T const**)ncclShmem.groups[group].srcs, - Dst, (T**)ncclShmem.groups[group].dsts, + ReduceOrCopyMulti + (tid, nworkers, ncclShmem.redOpArgs[0], nullptr, postOp, + Recv, ncclShmem.groups[group].srcs, + Dst, ncclShmem.groups[group].dsts, workSize); } else { - constexpr int PreOpN = SrcBuf != Input ? 0 : - DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1; - ReduceOrCopyMulti - (tid, nworkers, ncclShmem.redOpArgs, postOp, - Recv*fan.nrecv()+Src, (T const**)ncclShmem.groups[group].srcs, - Send*fan.nsend()+Dst, (T**)ncclShmem.groups[group].dsts, + constexpr int PreOpSrcs = SrcBuf != Input ? 
0 : + DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1; + ReduceOrCopyMulti + (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, + Recv*fan.nrecv()+Src, ncclShmem.groups[group].srcs, + Send*fan.nsend()+Dst, ncclShmem.groups[group].dsts, workSize); } barrier(); // This barrier has a counterpart in following loop - if (Send && (flags & RolePostSend) && index == 0) __threadfence_system(); - __syncwarp(); - postPeer(); + postPeer(0 < sliceSize); offset += sliceSize; slice += 1; } while (slice < SlicePerChunk && offset < nelem); @@ -229,9 +279,7 @@ class Primitives< waitPeer(0, 0, 0, 0); } barrier(); // Has couterpart in preceding worker-only loop. - if (Send && (flags & RolePostSend) && sliceSize > 0 && index == 0) __threadfence_system(); - __syncwarp(); - postPeer(); + postPeer(0 < sliceSize); offset += sliceSize; slice += 1; } @@ -242,7 +290,7 @@ class Primitives< // shift: peer offset to avoid all ranks sending to or receiving from same peer template __device__ __forceinline__ void - ScatterGatherOp(intptr_t inpIx, intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp) { + ScatterGatherOp(intptr_t inpIx, intptr_t outIx, int totalElem, int peerElem, int peerOffset, int skip, int shift, bool postOp) { constexpr int DirectRecv = 1 && Direct && DirectRecv1; constexpr int DirectSend = 1 && Direct && DirectSend1; int offset = 0; // slice offset @@ -252,12 +300,12 @@ class Primitives< #pragma unroll for (int slice=0; slice(0, inpIx, offset, realSize); subBarrier(); @@ -265,23 +313,23 @@ class Primitives< // Loop over peers for (int j=0; j= 0 && i >= skip) peerOffset += peerElem; - const T* src0 = (T*)ncclShmem.groups[group].srcs[0] + peerOffset; - int realPeerSize = min(realSize, totalElem-peerOffset); + if (skip >= 0 && i >= skip) pOffset += peerElem; + void* src0 = (T*)ncclShmem.groups[group].srcs[0] + pOffset; + int realPeerSize = min(realSize, totalElem-pOffset); if (realPeerSize > 0 && ncclShmem.groups[group].dsts[i] != nullptr) { - ReduceOrCopyMulti(tid, nworkers, ncclShmem.redOpArgs, false, 1, &src0, 1, (T**)ncclShmem.groups[group].dsts+i, realPeerSize); + ReduceOrCopyMulti(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize); // Mark for threadfence at the end - if (tid == 0) ncclShmem.groups[group].totalSendSize[slice] += realPeerSize; + fenceNeeded |= true; } } } else if (Recv) { if (flags & RoleOutput) ncclShmem.groups[group].dsts[0] = userBuff + outIx + offset; - int peerOffset = index*peerElem; - if (skip >= 0 && index >= skip) peerOffset += peerElem; + int pOffset = index*peerOffset; + if (skip >= 0 && index >= skip) pOffset += peerElem; // Adjust remote index with peer offset in case we are directly pulling from peer's output buffer - waitPeer(outIx, outIx+peerOffset, offset, realSize); + waitPeer(outIx, outIx+pOffset, offset, realSize); subBarrier(); if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) { // Since waitPeer sets srcs[0] to output buffer + offset, we are doing a direct-write based recv @@ -290,21 +338,17 @@ class Primitives< #pragma unroll for (int j=0; j= 0 && i >= skip) peerOffset += peerElem; - T* dst0 = (T*)ncclShmem.groups[group].dsts[0] + peerOffset; - int realPeerSize = min(realSize, totalElem-peerOffset); - if (realPeerSize > 0) ReduceOrCopyMulti(tid, nworkers, ncclShmem.redOpArgs, postOp, 1, (const T**)ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize); + pOffset = i*peerOffset; + 
if (skip >= 0 && i >= skip) pOffset += peerElem; + void* dst0 = (T*)ncclShmem.groups[group].dsts[0] + pOffset; + int realPeerSize = min(realSize, totalElem-pOffset); + if (realPeerSize > 0) ReduceOrCopyMulti(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize); } } } } - barrier(); - // If we indeed send something, threadfence - if (Send && (flags & RolePostSend) && ncclShmem.groups[group].totalSendSize[slice] > 0 && index == 0) - __threadfence_system(); - __syncwarp(); - postPeer(); + fenceNeeded = barrierAny(fenceNeeded); + postPeer(fenceNeeded); offset += realSize; } } @@ -320,25 +364,33 @@ class Primitives< } if (flags & RoleWaitRecv) { ncclShmem.groups[group].recvConns[index] = conn; // WaitRecv role saves since that's who needs it in setDataPtrs() + if ((index == 0) && (flags & RoleWaitRecv)) { + if (conn->flags & NCCL_NVLS_MIN_POLL) { + flags |= NvlsMinPolling; + ncclShmem.groups[group].nvlsRecv = 1; + } else { + ncclShmem.groups[group].nvlsRecv = 0; + } + } connStepPtr = conn->tail; - connStepCache = *connStepPtr; + connStepCache = loadStepValue(connStepPtr); flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0; if (Direct) { // User buffers have been registered - if ((conn->direct & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { + if ((conn->flags & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { if (connIndex == 1 && P2p == 0) { flags |= DirectRead; // scatter-reduce use direct pull } else { flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite : (e->direct & NCCL_DIRECT_READ) ? DirectRead : 0; } - } else if (conn->direct & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) { + } else if (conn->flags & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) { if (connIndex == 1 && P2p == 0) { flags |= DirectRead; // scatter-reduce use direct pull } else { // direct read not allowed in non-register case // otherwise, in one-to-multi send, we could mix empty send and intermediate send - flags |= (conn->direct & NCCL_DIRECT_WRITE) ? DirectWrite : 0; + flags |= (conn->flags & NCCL_DIRECT_WRITE) ? DirectWrite : 0; } } } @@ -359,8 +411,9 @@ class Primitives< } if (flags & RoleWaitSend) { ncclShmem.groups[group].sendConns[index] = conn; // WaitSend role saves since that's who needs it in setDataPtrs() + flags |= (conn->flags & NCCL_NVLS_MIN_POLL) ? NvlsMinPolling : 0; connStepPtr = conn->head; - connStepCache = *connStepPtr; + connStepCache = loadStepValue(connStepPtr); flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0; if (flags & OffsFifoEnabled) connOffsFifoPtr = conn->offsFifo; @@ -371,20 +424,20 @@ class Primitives< connSizesFifoPtr = conn->sizesFifo; } else if (Direct) { // User buffers have been registered - if ((conn->direct & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { + if ((conn->flags & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { if (connIndex == 1 && P2p == 0) { flags |= DirectRead; // scatter-reduce use direct pull } else { flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite : (e->direct & NCCL_DIRECT_READ) ? DirectRead : 0; } - } else if (conn->direct & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) { + } else if (conn->flags & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) { if (connIndex == 1 && P2p == 0) { flags |= DirectRead; // scatter-reduce use direct pull } else { // direct read not allowed in non-register case // otherwise, in one-to-multi send, we could mix empty send and intermediate send - flags |= (conn->direct & NCCL_DIRECT_WRITE) ? 
DirectWrite : 0; + flags |= (conn->flags & NCCL_DIRECT_WRITE) ? DirectWrite : 0; } } } @@ -397,7 +450,7 @@ class Primitives< int tid, int nthreads, int const *recvPeers, int const *sendPeers, void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint32_t group=0, struct ncclWorkElem* e = nullptr ): - tid(tid), + tid(tid), tidInBlock(threadIdx.x), stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)) { // For send operations, we need an extra warp to overlap the threadfence and the copy @@ -412,7 +465,7 @@ class Primitives< this->fan = Fan(nrecv, nsend); constexpr int ThreadPerSync = 8; - static_assert(MaxSend < ThreadPerSync && MaxRecv < ThreadPerSync, "Not enough threads to cover all peers"); + static_assert(MaxSend <= ThreadPerSync && MaxRecv <= ThreadPerSync, "Not enough threads to cover all peers"); int g = tid / ThreadPerSync; int ng = nthreads / ThreadPerSync; @@ -566,6 +619,9 @@ class Primitives< genericOp<0, 1, 0, 1, Input, Output>(inpIx, outIx, remoteOutIx, eltN, postOp); } + __device__ __forceinline__ void recvSend(int eltN, bool postOp=false) { + genericOp<0, 0, 1, 1, -1, -1>(-1, -1, -1, eltN, postOp); + } __device__ __forceinline__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) { genericOp<0, 0, 1, 1, -1, Output>(-1, outIx, -1, eltN, postOp); } @@ -596,20 +652,20 @@ class Primitives< } __device__ __forceinline__ void - scatter(intptr_t inpIx, int totalElem, int peerElem, int skip, int shift) { - ScatterGatherOp<0, 0, 0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false); + scatter(intptr_t inpIx, int totalElem, int peerElem, int peerOffset, int skip, int shift) { + ScatterGatherOp<0, 0, 0, 1>(inpIx, -1, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false); } __device__ __forceinline__ void - directScatter(intptr_t inpIx, int totalElem, int peerElem, int skip, int shift) { - ScatterGatherOp<0, 1, 0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false); + directScatter(intptr_t inpIx, int totalElem, int peerElem, int peerOffset, int skip, int shift) { + ScatterGatherOp<0, 1, 0, 1>(inpIx, -1, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false); } __device__ __forceinline__ void - gather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp=false) { - ScatterGatherOp<0, 0, 1, 0>(-1, outIx, totalElem, peerElem, skip, shift, postOp); + gather(intptr_t outIx, int totalElem, int peerElem, int peerOffset, int skip, int shift, bool postOp=false) { + ScatterGatherOp<0, 0, 1, 0>(-1, outIx, totalElem, peerElem, peerOffset, skip, shift, postOp); } __device__ __forceinline__ void - directGather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift) { - ScatterGatherOp<1, 0, 1, 0>(-1, outIx, totalElem, peerElem, skip, shift, /*postOp=*/false); + directGather(intptr_t outIx, int totalElem, int peerElem, int peerOffset, int skip, int shift) { + ScatterGatherOp<1, 0, 1, 0>(-1, outIx, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false); } }; diff --git a/src/collectives/device/reduce_kernel.h b/src/collectives/device/reduce_kernel.h index 878ec79..7e1b5eb 100644 --- a/src/collectives/device/reduce_kernel.h +++ b/src/collectives/device/reduce_kernel.h @@ -8,466 +8,447 @@ #ifndef NCCL_REDUCE_KERNEL_H_ #define NCCL_REDUCE_KERNEL_H_ -#include "common_kernel.h" +#include "op128.h" #include #include -template -struct FuncNull { - __device__ FuncNull(uint64_t opArg=0) {} - __device__ T operator()(const T x, const T y) const { - return 0; - } -}; 
+//////////////////////////////////////////////////////////////////////////////// +// The reduction function classes. All classes must: +// 1. Expose the `EltType` typedef. +// 2. Have constructor taking no arguments (default constructible). +// 3. Have constructor taking `uint64_t opArg`. template -struct FuncSum { - __device__ FuncSum(uint64_t opArg=0) {} - __device__ T operator()(const T x, const T y) const { - return x + y; - } -}; - +struct FuncNull { using EltType = T; __device__ FuncNull(uint64_t opArg=0) {}; }; template -struct FuncProd { - __device__ FuncProd(uint64_t opArg=0) {} - __device__ T operator()(const T x, const T y) const { - return x * y; - } -}; - +struct FuncSum { using EltType = T; __device__ FuncSum(uint64_t opArg=0) {}; }; template -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - __device__ T operator()(const T x, const T y) const { - return (x < y) ? y : x; - } -}; - +struct FuncProd { using EltType = T; __device__ FuncProd(uint64_t opArg=0) {}; }; template -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - __device__ T operator()(const T x, const T y) const { - return (x < y) ? x : y; - } -}; +struct FuncMin { using EltType = T; __device__ FuncMin(uint64_t opArg=0) {}; }; +template +struct FuncMax { using EltType = T; __device__ FuncMax(uint64_t opArg=0) {}; }; + +template struct FuncPreMulSum; +template struct FuncSumPostDiv; + +//////////////////////////////////////////////////////////////////////////////// +// Trait classes for reduction functions. Given a function (FuncSum, etc.) +// and a number of elements in a pack, will reduce, preOp, or postOp a pack +// of elements. These classes are intended to be specialized for specific +// combinations of reduction function and pack size. + +template +struct Apply_Reduce /*{ + static BytePack reduce( + Fn fn, BytePack a, BytePack b + ); +}*/; +template +struct Apply_PreOp/*{ + static constexpr bool IsIdentity; + static BytePack preOp(Fn fn, BytePack a); +}*/; +template +struct Apply_PostOp/*{ + static constexpr bool IsIdentity; + static BytePack postOp(Fn fn, BytePack a); +}*/; +template +struct Apply_LoadMultimem/*{ + static constexpr int PackSize; // 0 if not implemented + static BytePack load(Fn fn, uintptr_t addr); +}*/; + +//////////////////////////////////////////////////////////////////////////////// +// Public API for calling the trait classes. These take the data elements as a +// pack of any type, which could be a BytePack or any integral type (uint64_t, +// uint32_t, etc.), and will return a new pack where each element has been +// transformed appropriately. 
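A minimal usage sketch of the pack-based helpers described above (and defined just below), assuming the BytePack/toPack/fromPack utilities from op128.h; the function name sumTwoU32Packs is hypothetical. An 8-byte register pack holding two uint32_t elements is reduced element-wise, with the element count deduced from Fn::EltType:

  __device__ inline BytePack<8> sumTwoU32Packs(BytePack<8> a, BytePack<8> b) {
    FuncSum<uint32_t> fn;           // EltType = uint32_t, so EltPerPack = 8/4 = 2
    return applyReduce(fn, a, b);   // resolves to Apply_Reduce<FuncSum<uint32_t>, 2>
  }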
+ +template +__device__ __forceinline__ Pack applyReduce(Fn fn, Pack a, Pack b) { + return fromPack( + Apply_Reduce + ::reduce(fn, toPack(a), toPack(b)) + ); +} + +template +__device__ __forceinline__ Pack applyPreOp(Fn fn, Pack a) { + return fromPack( + Apply_PreOp + ::preOp(fn, toPack(a)) + ); +} + +template +__device__ __forceinline__ Pack applyPostOp(Fn fn, Pack a) { + return fromPack( + Apply_PostOp + ::postOp(fn, toPack(a)) + ); +} template -struct FuncTraits { // generic implementation for FuncSum,Prod,Min,Max - static constexpr bool IsPreOpIdentity = true; - static constexpr bool IsPostOpIdentity = true; - - template - __device__ static T preOp(Fn, T x) { return x; } - template - __device__ static T postOp(Fn, T x) { return x; } -}; - -#define MASK0 0x00ff00ff -#define MASK1 0xff00ff00 -static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) { - /* This can be used both for signed and unsigned 8-bit addition */ - const uint32_t x0 = x & MASK0; - const uint32_t x1 = x & MASK1; - const uint32_t y0 = y & MASK0; - const uint32_t y1 = y & MASK1; - const uint32_t r0 = (x0+y0); - const uint32_t r1 = (x1+y1); - return (r0 & MASK0) | (r1 & MASK1); +__device__ __forceinline__ BytePack::PackSize> applyLoadMultimem(Fn fn, uintptr_t addr) { + return Apply_LoadMultimem::load(fn, addr); } -template<> -struct FuncSum { - __device__ FuncSum(uint64_t opArg=0) {} - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vadd4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - return addChar4(x, y); -#endif - } - __device__ int8_t operator()(const int8_t x, const int8_t y) const { - return x+y; - } -}; -template<> -struct FuncSum { - __device__ FuncSum(uint64_t opArg=0) {} - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vadd4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - return addChar4(x, y); -#endif - } - __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { - return x+y; +//////////////////////////////////////////////////////////////////////////////// +// Apply_Reduce + +// General recursive definition (EltPerPack > 1). This is how we iterate over +// all elements in a pack of any size, by breaking it into halves. Eventually +// we'll hit a base case (a more specific template specialization which takes +// precedence). 
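The same compile-time halving pattern implemented by Apply_Reduce just below can be sketched in isolation (hypothetical ReduceHalves, power-of-two N assumed): the primary template recurses on the two halves, a one-element specialization terminates the recursion, and a wider specialization, when one exists, takes precedence over the recursive split.

  template<typename T, int N>
  struct ReduceHalves {                  // primary template: split and recurse
    static void run(T* a, const T* b) {
      ReduceHalves<T, N/2>::run(a, b);
      ReduceHalves<T, N/2>::run(a + N/2, b + N/2);
    }
  };
  template<typename T>
  struct ReduceHalves<T, 1> {            // base case: reduce a single element
    static void run(T* a, const T* b) { a[0] += b[0]; }
  };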
+template +struct Apply_Reduce { + template + __device__ static BytePack reduce(Fn fn, BytePack a, BytePack b) { + a.half[0] = Apply_Reduce::reduce(fn, a.half[0], b.half[0]); + a.half[1] = Apply_Reduce::reduce(fn, a.half[1], b.half[1]); + return a; } }; -static __device__ uint32_t mulChar4(const uint32_t x, const uint32_t y) { - /* This can be used both for signed and unsigned 8-bit multiplication */ - union converter { uint32_t storage; char4 a; }; - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - cr.a.x = cx.a.x * cy.a.x; - cr.a.y = cx.a.y * cy.a.y; - cr.a.z = cx.a.z * cy.a.z; - cr.a.w = cx.a.w * cy.a.w; - return cr.storage; -} - -template<> -struct FuncProd { - __device__ FuncProd(uint64_t opArg=0) {} - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { - return mulChar4(x, y); - } - __device__ int8_t operator()(const int8_t x, const int8_t y) const { - return x*y; +// Base case definitions (EltPerPack == 1) +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack reduce(FuncSum fn, BytePack a, BytePack b) { + return a; } }; -template<> -struct FuncProd { - __device__ FuncProd(uint64_t opArg=0) {} - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { - return mulChar4(x, y); +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack reduce(FuncSum fn, BytePack a, BytePack b) { + return toPack(fromPack(a) + fromPack(b)); } - __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { - return x*y; +}; +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack reduce(FuncProd fn, BytePack a, BytePack b) { + return toPack(fromPack(a) * fromPack(b)); + } +}; +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack reduce(FuncMin fn, BytePack a, BytePack b) { + return toPack(min(fromPack(a), fromPack(b))); + } +}; +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack reduce(FuncMax fn, BytePack a, BytePack b) { + return toPack(max(fromPack(a), fromPack(b))); } }; +// Optimizations for specfic types and element count combinations: template<> -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - union converter { uint32_t storage; char4 a; }; - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vmax4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - cr.a.x = max(cx.a.x, cy.a.x); - cr.a.y = max(cx.a.y, cy.a.y); - cr.a.z = max(cx.a.z, cy.a.z); - cr.a.w = max(cx.a.w, cy.a.w); - return cr.storage; -#endif - } - __device__ int8_t operator()(const int8_t x, const int8_t y) const { - return (x>y) ? 
x : y; +struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncSum fn, BytePack<4> a, BytePack<4> b) { + constexpr uint32_t lo = 0x00ff00ff; + constexpr uint32_t hi = ~lo; + uint32_t x = a.u32; + uint32_t y = b.u32; + a.u32 = (((x&lo) + (y&lo))&lo) + (((x&hi) + (y&hi))&hi); + return a; } }; template<> -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - union converter { uint32_t storage; uchar4 a; }; - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vmax4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - cr.a.x = max(cx.a.x, cy.a.x); - cr.a.y = max(cx.a.y, cy.a.y); - cr.a.z = max(cx.a.z, cy.a.z); - cr.a.w = max(cx.a.w, cy.a.w); - return cr.storage; -#endif - } - __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { - return (x>y) ? x : y; +struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncSum fn, BytePack<4> a, BytePack<4> b) { + return Apply_Reduce, 4>::reduce(FuncSum(), a, b); } }; -template<> -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - union converter { uint32_t storage; char4 a; }; - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vmin4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - cr.a.x = min(cx.a.x, cy.a.x); - cr.a.y = min(cx.a.y, cy.a.y); - cr.a.z = min(cx.a.z, cy.a.z); - cr.a.w = min(cx.a.w, cy.a.w); - return cr.storage; +#if 300 <= __CUDA_ARCH__ && __CUDA_ARCH__ < 500 + template<> + struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncMin fn, BytePack<4> a, BytePack<4> b) { + uint32_t z=0; + asm("vmin4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(a.u32) : "r"(a.u32), "r"(b.u32), "r"(z)); + return a; + } + }; + template<> + struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncMin fn, BytePack<4> a, BytePack<4> b) { + int32_t z=0; + asm("vmin4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(a.u32) : "r"(a.u32), "r"(b.u32), "r"(z)); + return a; + } + }; + template<> + struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncMax fn, BytePack<4> a, BytePack<4> b) { + uint32_t z=0; + asm("vmax4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(a.u32) : "r"(a.u32), "r"(b.u32), "r"(z)); + return a; + } + }; + template<> + struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncMax fn, BytePack<4> a, BytePack<4> b) { + int32_t z=0; + asm("vmax4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(a.u32) : "r"(a.u32), "r"(b.u32), "r"(z)); + return a; + } + }; #endif - } - __device__ int8_t operator()(const int8_t x, const int8_t y) const { - return (x -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - union converter { uint32_t storage; uchar4 a; }; - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vmin4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - cr.a.x = min(cx.a.x, cy.a.x); - cr.a.y = min(cx.a.y, cy.a.y); - cr.a.z = min(cx.a.z, cy.a.z); - cr.a.w = min(cx.a.w, cy.a.w); - return cr.storage; 
-#endif - } - __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { - return (x -struct FuncSum { - __device__ FuncSum(uint64_t opArg=0) {} - __device__ half2 operator()(const half2 x, const half2 y) const { +#define SPECIALIZE_REDUCE(Fn, T, EltPerPack, Vec, expr_of_x_y) \ + template<> \ + struct Apply_Reduce, EltPerPack> { \ + __device__ __forceinline__ static BytePack reduce( \ + Fn fn, BytePack a, BytePack b \ + ) { \ + Vec x = fromPack(a); \ + Vec y = fromPack(b); \ + return toPack(expr_of_x_y); \ + } \ + }; + #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 - return __hadd2(x, y); + SPECIALIZE_REDUCE(FuncSum, half, 1, half, __hadd(x, y)) + SPECIALIZE_REDUCE(FuncSum, half, 2, half2, __hadd2(x, y)) + SPECIALIZE_REDUCE(FuncProd, half, 1, half, __hmul(x, y)) + SPECIALIZE_REDUCE(FuncProd, half, 2, half2, __hmul2(x, y)) #else - float2 fx, fy, fr; - fx = __half22float2(x); - fy = __half22float2(y); - fr.x = fx.x + fy.x; - fr.y = fx.y + fy.y; - return __float22half2_rn(fr); + SPECIALIZE_REDUCE(FuncSum, half, 1, half, __float2half(__half2float(x) + __half2float(y))) + SPECIALIZE_REDUCE(FuncProd, half, 1, half, __float2half(__half2float(x) * __half2float(y))) #endif + +#if __CUDA_ARCH__ >= 800 + SPECIALIZE_REDUCE(FuncMin, half, 1, half, __hmin(x, y)) + SPECIALIZE_REDUCE(FuncMin, half, 2, half2, __hmin2(x, y)) + SPECIALIZE_REDUCE(FuncMax, half, 1, half, __hmax(x, y)) + SPECIALIZE_REDUCE(FuncMax, half, 2, half2, __hmax2(x, y)) +#else + SPECIALIZE_REDUCE(FuncMin, half, 1, half, __float2half(fminf(__half2float(x), __half2float(y)))) + SPECIALIZE_REDUCE(FuncMax, half, 1, half, __float2half(fmaxf(__half2float(x), __half2float(y)))) +#endif + +#if defined(__CUDA_BF16_TYPES_EXIST__) +#if __CUDA_ARCH__ >= 800 + SPECIALIZE_REDUCE(FuncSum, __nv_bfloat16, 1, __nv_bfloat16, __hadd(x, y)) + SPECIALIZE_REDUCE(FuncSum, __nv_bfloat16, 2, __nv_bfloat162, __hadd2(x, y)) + SPECIALIZE_REDUCE(FuncProd, __nv_bfloat16, 1, __nv_bfloat16, __hmul(x, y)) + SPECIALIZE_REDUCE(FuncProd, __nv_bfloat16, 2, __nv_bfloat162, __hmul2(x, y)) + SPECIALIZE_REDUCE(FuncMin, __nv_bfloat16, 1, __nv_bfloat16, __hmin(x, y)) + SPECIALIZE_REDUCE(FuncMin, __nv_bfloat16, 2, __nv_bfloat162, __hmin2(x, y)) + SPECIALIZE_REDUCE(FuncMax, __nv_bfloat16, 1, __nv_bfloat16, __hmax(x, y)) + SPECIALIZE_REDUCE(FuncMax, __nv_bfloat16, 2, __nv_bfloat162, __hmax2(x, y)) +#else + SPECIALIZE_REDUCE(FuncSum, __nv_bfloat16, 1, __nv_bfloat16, __float2bfloat16(__bfloat162float(x) + __bfloat162float(y))) + SPECIALIZE_REDUCE(FuncProd, __nv_bfloat16, 1, __nv_bfloat16, __float2bfloat16(__bfloat162float(x) * __bfloat162float(y))) + SPECIALIZE_REDUCE(FuncMin, __nv_bfloat16, 1, __nv_bfloat16, __float2bfloat16(fminf(__bfloat162float(x), __bfloat162float(y)))) + SPECIALIZE_REDUCE(FuncMax, __nv_bfloat16, 1, __nv_bfloat16, __float2bfloat16(fmaxf(__bfloat162float(x), __bfloat162float(y)))) +#endif +#endif + +#undef SPECIALIZE_REDUCE + +//////////////////////////////////////////////////////////////////////////////// +// Apply_PreOp + +// General recursive definition (EltPerPack > 1) +template +struct Apply_PreOp { + static constexpr bool IsIdentity = Apply_PreOp::IsIdentity; + template + __device__ static BytePack preOp(Fn fn, BytePack a) { + #if __cpp_if_constexpr + if constexpr(!IsIdentity) { + #else + if (!IsIdentity) { + #endif + // The `if (!IsIdentity)` condition is not strictly necessary, but it may help + // compiler in that it won't have to tear a register apart for no reason + // just to put it back together again. 
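      // Even without C++17 `if constexpr` (the #else branch above), IsIdentity is a
      // compile-time constant, so for identity pre-ops the branch is statically false
      // and the recursive calls below can be eliminated as dead code.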
+ a.half[0] = Apply_PreOp::preOp(fn, a.half[0]); + a.half[1] = Apply_PreOp::preOp(fn, a.half[1]); + } + return a; } - __device__ half operator()(const half x, const half y) const { +}; +// Base case definition (EltPerPack == 1), by default is identity function. +template +struct Apply_PreOp { + static constexpr bool IsIdentity = true; + template + __device__ static BytePack preOp(Fn fn, BytePack a) { + return a; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Apply_PostOp + +// General recursive definition (EltPerPack > 1) +template +struct Apply_PostOp { + static constexpr bool IsIdentity = Apply_PostOp::IsIdentity; + template + __device__ static BytePack postOp(Fn fn, BytePack a) { + #if __cpp_if_constexpr + if constexpr(!IsIdentity) { + #else + if (!IsIdentity) { + #endif + // The `if (!IsIdentity)` condition is not strictly necessary, but it may help + // compiler in that it won't have to tear a register apart for no reason + // just to put it back together again. + a.half[0] = Apply_PostOp::postOp(fn, a.half[0]); + a.half[1] = Apply_PostOp::postOp(fn, a.half[1]); + } + return a; + } +}; +// Base case definition (EltPerPack == 1), by default is identity function. +template +struct Apply_PostOp { + static constexpr bool IsIdentity = true; + template + __device__ static BytePack postOp(Fn fn, BytePack a) { + return a; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// +// FuncPreMulSum + +// General definition for all integral types, float, and double. +template +struct FuncPreMulSum { + using EltType = T; + T scalar; + __device__ FuncPreMulSum(uint64_t opArg=0) { + union { uint64_t u64; T val; }; + u64 = opArg; + scalar = val; + } +}; + +template<> +struct FuncPreMulSum { + using EltType = half; #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 - return __hadd(x, y); -#else - return __float2half( __half2float(x) + __half2float(y) ); -#endif + half2 scalar; + __device__ FuncPreMulSum(uint64_t opArg=0) { + union { uint64_t u64; half val; }; + u64 = opArg; + scalar.x = val; + scalar.y = val; } +#else + float scalar; + __device__ FuncPreMulSum(uint64_t opArg=0) { + union { uint64_t u64; half val; }; + u64 = opArg; + scalar = __half2float(val); + } +#endif }; #if defined(__CUDA_BF16_TYPES_EXIST__) -template<> -struct FuncSum<__nv_bfloat16> { - __device__ FuncSum(uint64_t opArg=0) {} - __device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const { -#if __CUDA_ARCH__ >= 800 - return __hadd2(x, y); -#else - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); -#endif - } - __device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const { -#if __CUDA_ARCH__ >= 800 - return __hadd(x, y); -#else - return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); -#endif - } -}; + template<> + struct FuncPreMulSum<__nv_bfloat16> { + using EltType = __nv_bfloat16; + #if __CUDA_ARCH__ >= 800 + __nv_bfloat162 scalar; + __device__ FuncPreMulSum(uint64_t opArg=0) { + union { uint64_t u64; __nv_bfloat16 val; }; + u64 = opArg; + scalar.x = val; + scalar.y = val; + } + #else + float scalar; + __device__ FuncPreMulSum(uint64_t opArg=0) { + union { uint64_t u64; __nv_bfloat16 val; }; + u64 = opArg; + scalar = __bfloat162float(val); + } + #endif + }; #endif +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack 
reduce(FuncPreMulSum fn, BytePack a, BytePack b) { + // FuncPreMulSum reduce dispatches to FuncSum. + return Apply_Reduce, 1>::reduce(FuncSum(), a, b); + } +}; + +// PreOp of FuncPreMulSum for integral types, float, and double. +template +struct Apply_PreOp, /*EltPerPack=*/1> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp(FuncPreMulSum fn, BytePack a) { + return toPack(fromPack(a) * fn.scalar); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Apply_PreOp of FuncPreMulSum for float16. + template<> -struct FuncProd { - __device__ FuncProd(uint64_t opArg=0) {} - __device__ half2 operator()(const half2 x, const half2 y) const { +struct Apply_PreOp, /*EltPerPack=*/1> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp(FuncPreMulSum fn, BytePack a) { + #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 + return toPack(__hmul(fromPack(a), fn.scalar.x)); + #else + return toPack(__float2half(__half2float(fromPack(a)) * fn.scalar)); + #endif + } +}; #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 - return __hmul2(x, y); -#else - float2 fx, fy, fr; - fx = __half22float2(x); - fy = __half22float2(y); - fr.x = fx.x * fy.x; - fr.y = fx.y * fy.y; - return __float22half2_rn(fr); + template<> + struct Apply_PreOp, /*EltPerPack=*/2> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp(FuncPreMulSum fn, BytePack a) { + return toPack(__hmul2(fromPack(a), fn.scalar)); + } + }; #endif - } - __device__ half operator()(const half x, const half y) const { -#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 - return __hmul(x, y); -#else - return __float2half( __half2float(x) * __half2float(y) ); -#endif - } -}; + +//////////////////////////////////////////////////////////////////////////////// +// Apply_PreOp of FuncPreMulSum for bfloat16. 
#if defined(__CUDA_BF16_TYPES_EXIST__) -template<> -struct FuncProd<__nv_bfloat16> { - __device__ FuncProd(uint64_t opArg=0) {} - __device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const { -#if __CUDA_ARCH__ >= 800 - return __hmul2(x, y); -#else - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); -#endif - } - __device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const { -#if __CUDA_ARCH__ >= 800 - return __hmul(x, y); -#else - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); -#endif - } -}; + template<> + struct Apply_PreOp, /*EltPerPack=*/1> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp( + FuncPreMulSum<__nv_bfloat16> fn, BytePack a + ) { + #if __CUDA_ARCH__ >= 800 + return toPack<__nv_bfloat16>(__hmul(fromPack<__nv_bfloat16>(a), fn.scalar.x)); + #else + return toPack<__nv_bfloat16>(__float2bfloat16(__bfloat162float(fromPack<__nv_bfloat16>(a)) * fn.scalar)); + #endif + } + }; + #if __CUDA_ARCH__ >= 800 + template<> + struct Apply_PreOp, /*EltPerPack=*/2> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp( + FuncPreMulSum<__nv_bfloat16> fn, BytePack a + ) { + return toPack<__nv_bfloat162>(__hmul2(fromPack<__nv_bfloat162>(a), fn.scalar)); + } + }; + #endif #endif -template<> -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - __device__ half2 operator()(const half2 x, const half2 y) const { - float2 fx, fy, fr; - fx = __half22float2(x); - fy = __half22float2(y); - fr.x = fmaxf(fx.x, fy.x); - fr.y = fmaxf(fx.y, fy.y); - return __float22half2_rn(fr); - } - __device__ half operator()(const half x, const half y) const { - float fx, fy, fm; - fx = __half2float(x); - fy = __half2float(y); - fm = fmaxf(fx, fy); - return __float2half(fm); - } -}; - -#if defined(__CUDA_BF16_TYPES_EXIST__) -template<> -struct FuncMax<__nv_bfloat16> { - __device__ FuncMax(uint64_t opArg=0) {} - __device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const { -#if __CUDA_ARCH__ >= 800 - return __hmax2(x, y); -#else - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fmaxf(fxl, fyl), fmaxf(fxh, fyh)); -#endif - } - __device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const { -#if __CUDA_ARCH__ >= 800 - return __hmax(x, y); -#else - float fx, fy; - fx = __bfloat162float(x); - fy = __bfloat162float(y); - return __float2bfloat16(fmaxf(fx, fy)); -#endif - } -}; -#endif - -template<> -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - __device__ half2 operator()(const half2 x, const half2 y) const { - float2 fx, fy, fr; - fx = __half22float2(x); - fy = __half22float2(y); - fr.x = fminf(fx.x, fy.x); - fr.y = fminf(fx.y, fy.y); - return __float22half2_rn(fr); - } - __device__ half operator()(const half x, const half y) const { - float fx, fy, fm; - fx = __half2float(x); - fy = __half2float(y); - fm = fminf(fx, fy); - return __float2half(fm); - } -}; - -#if defined(__CUDA_BF16_TYPES_EXIST__) -template<> -struct FuncMin<__nv_bfloat16> { - __device__ FuncMin(uint64_t opArg=0) {} - __device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const { -#if __CUDA_ARCH__ >= 800 - return __hmin2(x, y); -#else - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); 
- fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fminf(fxl, fyl), fminf(fxh, fyh)); -#endif - } - __device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const { -#if __CUDA_ARCH__ >= 800 - return __hmin(x, y); -#else - float fx, fy; - fx = __bfloat162float(x); - fy = __bfloat162float(y); - return __float2bfloat16(fminf(fx, fy)); -#endif - } -}; -#endif - -template<> -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - __device__ float operator()(float x, float y) const { - return fmaxf(x, y); - } -}; -template<> -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - __device__ float operator()(float x, float y) const { - return fminf(x, y); - } -}; - -template<> -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - __device__ double operator()(double x, double y) const { - return fmax(x, y); - } -}; -template<> -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - __device__ double operator()(double x, double y) const { - return fmin(x, y); - } -}; +//////////////////////////////////////////////////////////////////////////////// +// FuncSumPostDiv template struct IsFloatingPoint: std::false_type {}; @@ -483,223 +464,128 @@ template<> struct IsFloatingPoint: std::true_type {}; template::value> -struct FuncSumPostDiv; +struct FuncSumPostDiv_IntOnly; template -struct FuncSumPostDiv: FuncSum { - static constexpr bool IsPreOpIdentity = true; - static constexpr bool IsPostOpIdentity = false; - int n; - __device__ FuncSumPostDiv(uint64_t opArg): n(opArg) {} - // inherits FuncSum::operator() - __device__ T preOp(T x) const { return x; } - __device__ T postOp(T x) const { return T(x/n); } +struct FuncSumPostDiv: FuncSumPostDiv_IntOnly { + __device__ FuncSumPostDiv(uint64_t opArg=0): + FuncSumPostDiv_IntOnly(opArg) { + } }; template -struct FuncSumPostDiv { +struct FuncSumPostDiv_IntOnly: FuncSum { + using EltType = T; + int divisor; + __device__ FuncSumPostDiv_IntOnly(uint64_t opArg=0): divisor(opArg) {} +}; + +template +struct FuncSumPostDiv_IntOnly { static_assert(sizeof(T)!=sizeof(T), "FuncSumPostDiv is only for implementing ncclAvg on integral types."); }; template -struct FuncPreMulSum: FuncSum { // integral T since all floats are specialized below - static constexpr bool IsPreOpIdentity = false; - static constexpr bool IsPostOpIdentity = true; - T scale; - __device__ FuncPreMulSum(uint64_t opArg) { scale = *(T*)&opArg; } - // inherits FuncSum::operator() - __device__ T preOp(T x) const { return x*scale; } - __device__ T postOp(T x) const { return x; } -}; - -template<> -struct FuncPreMulSum: FuncSum { - static constexpr bool IsPreOpIdentity = false; - static constexpr bool IsPostOpIdentity = true; - double scale; - __device__ FuncPreMulSum(uint64_t opArg) { - scale = *(double*)&opArg; - } - // inherits FuncSum::operator() - __device__ double preOp(double x) const { - return IsPreOpIdentity ? x : x*scale; - } - __device__ double postOp(double x) const { - return IsPostOpIdentity ? x : x*scale; +struct Apply_Reduce, /*EltPerPack=*/1>: + Apply_Reduce, 1> { + __device__ static BytePack reduce(FuncSumPostDiv fn, BytePack a, BytePack b) { + // FuncSumPostDiv reduce dispatches to FuncSum. 
+ return Apply_Reduce, 1>::reduce(FuncSum(), a, b); } }; -template<> -struct FuncPreMulSum: FuncSum { - static constexpr bool IsPreOpIdentity = false; - static constexpr bool IsPostOpIdentity = true; - float scale; - __device__ FuncPreMulSum(uint64_t opArg) { - scale = *(float*)&opArg; - } - // inherits FuncSum::operator() - __device__ float preOp(float x) const { - return IsPreOpIdentity ? x : x*scale; - } - __device__ float postOp(float x) const { - return IsPostOpIdentity ? x : x*scale; - } -}; - -template<> -struct FuncPreMulSum: FuncSum { - // Change these to switch between all prescale, all postscale, or both by sqrt(N). - // Obviously, the only invalid combination is both true. An improvement would be - // make this parameterized as a build time setting and passed here through - // preprocessor definitions. - static constexpr bool IsPreOpIdentity = false; - static constexpr bool IsPostOpIdentity = true; - -#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 - half2 scale; - __device__ FuncPreMulSum(uint64_t opArg) { - scale.x = *(half*)&opArg; - scale.y = scale.x; - } - // inherits FuncSum::operator() - __device__ half preOp(half x) const { - return IsPreOpIdentity ? x : __hmul(x, scale.x); - } - __device__ half2 preOp(half2 x) const { - return IsPreOpIdentity ? x : __hmul2(x, scale); - } - __device__ half postOp(half x) const { - return IsPostOpIdentity ? x : __hmul(x, scale.x); - } - __device__ half2 postOp(half2 x) const { - return IsPostOpIdentity ? x : __hmul2(x, scale); - } -#else - float scale; - __device__ FuncPreMulSum(uint64_t opArg) { - scale = __half2float(*(half*)&opArg); - } - // inherits FuncSum::operator() - __device__ half preOp(half x) const { - return IsPreOpIdentity ? x : __float2half(__half2float(x)*scale); - } - __device__ half2 preOp(half2 x) const { - if (IsPreOpIdentity) - return x; - else { - float2 a = __half22float2(x); - a.x *= scale; - a.y *= scale; - return __float22half2_rn(a); - } - } - __device__ half postOp(half x) const { - return IsPostOpIdentity ? x : __float2half(__half2float(x)*scale); - } - __device__ half2 postOp(half2 x) const { - if (IsPostOpIdentity) - return x; - else { - float2 a = __half22float2(x); - a.x *= scale; - a.y *= scale; - return __float22half2_rn(a); - } - } -#endif -}; - -#if defined(__CUDA_BF16_TYPES_EXIST__) -template<> -struct FuncPreMulSum<__nv_bfloat16>: FuncSum<__nv_bfloat16> { - // Change these to switch between all prescale, all postscale, or both by sqrt(N). - // Obviously, the only invalid combination is both true. An improvement would be - // make this parameterized as a build time setting and passed here through - // preprocessor definitions. - static constexpr bool IsPreOpIdentity = false; - static constexpr bool IsPostOpIdentity = true; - -#if __CUDA_ARCH__ >= 800 - __nv_bfloat162 scale; - __device__ FuncPreMulSum(uint64_t opArg) { - scale.x = *(__nv_bfloat16*)&opArg; - scale.y = scale.x; - } - // inherits FuncSum::operator() - __device__ __nv_bfloat16 preOp(__nv_bfloat16 x) const { - return IsPreOpIdentity ? x : __hmul(x, scale.x); - } - __device__ __nv_bfloat162 preOp(__nv_bfloat162 x) const { - return IsPreOpIdentity ? x : __hmul2(x, scale); - } - __device__ __nv_bfloat16 postOp(__nv_bfloat16 x) const { - return IsPostOpIdentity ? x : __hmul(x, scale.x); - } - __device__ __nv_bfloat162 postOp(__nv_bfloat162 x) const { - return IsPostOpIdentity ? 
x : __hmul2(x, scale); - } -#else - float scale; - __device__ FuncPreMulSum(uint64_t opArg) { - scale = *(__nv_bfloat16*)&opArg; - } - // inherits FuncSum::operator() - __device__ __nv_bfloat16 preOp(__nv_bfloat16 x) const { - return IsPreOpIdentity ? x : __float2bfloat16(__bfloat162float(x)*scale); - } - __device__ __nv_bfloat162 preOp(__nv_bfloat162 x) const { - if (IsPreOpIdentity) - return x; - else { - float fxl, fxh; - fxl = __low2float(x); - fxh = __high2float(x); - return __floats2bfloat162_rn(fxl * scale, fxh * scale); - } - } - __device__ __nv_bfloat16 postOp(__nv_bfloat16 x) const { - return IsPostOpIdentity ? x : __float2bfloat16(__bfloat162float(x)*scale); - } - __device__ __nv_bfloat162 postOp(__nv_bfloat162 x) const { - if (IsPostOpIdentity) - return x; - else { - float fxl, fxh; - fxl = __low2float(x); - fxh = __high2float(x); - return __floats2bfloat162_rn(fxl * scale, fxh * scale); - } - } -#endif -}; -#endif - template -struct FuncTraits> { - static constexpr bool IsPreOpIdentity = FuncPreMulSum::IsPreOpIdentity; - static constexpr bool IsPostOpIdentity = FuncPreMulSum::IsPostOpIdentity; - - template - __device__ static U preOp(FuncPreMulSum fn, U x) { - return fn.preOp(x); - } - template - __device__ static U postOp(FuncPreMulSum fn, U x) { - return fn.postOp(x); +struct Apply_PostOp, /*EltPerPack=*/1> { + static constexpr bool IsIdentity = false; + __device__ static BytePack postOp(FuncSumPostDiv fn, BytePack a) { + return toPack(fromPack(a) / fn.divisor); } }; -template -struct FuncTraits> { - static constexpr bool IsPreOpIdentity = FuncSumPostDiv::IsPreOpIdentity; - static constexpr bool IsPostOpIdentity = FuncSumPostDiv::IsPostOpIdentity; - template - __device__ static U preOp(FuncSumPostDiv fn, U x) { - return fn.preOp(x); - } - template - __device__ static U postOp(FuncSumPostDiv fn, U x) { - return fn.postOp(x); - } +//////////////////////////////////////////////////////////////////////////////// +// Apply_LoadMultimem + +template +struct Apply_LoadMultimem { + static constexpr int PackSize = 0; // Indicates not implemented }; + +#define SIZEOF_BytePack_field_u16 2 +#define PTX_REG_BytePack_field_u16 "h" + +#define SIZEOF_BytePack_field_u32 4 +#define PTX_REG_BytePack_field_u32 "r" + +#define SIZEOF_BytePack_field_u64 8 +#define PTX_REG_BytePack_field_u64 "l" + +#define DEFINE_Apply_LoadMultimem(Fn, T, op, ptx_ty, pack_field) \ + template<> \ + struct Apply_LoadMultimem> { \ + static constexpr int PackSize = 1*(SIZEOF_BytePack_field_##pack_field); \ + __device__ static BytePack load(Fn fn, uintptr_t addr) { \ + BytePack ans; \ + asm("multimem.ld_reduce.global." #op "." #ptx_ty " %0, [%1];" \ + : "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field) \ + : "l"(addr)); \ + return ans; \ + } \ + }; +#define DEFINE_Apply_LoadMultimem_v4(Fn, T, op, ptx_ty, pack_field) \ + template<> \ + struct Apply_LoadMultimem> { \ + static constexpr int PackSize = 4*(SIZEOF_BytePack_field_##pack_field); \ + __device__ static BytePack load(Fn fn, uintptr_t addr) { \ + BytePack ans; \ + asm("multimem.ld_reduce.global." #op ".v4." 
#ptx_ty " {%0,%1,%2,%3}, [%4];" \ + : "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[0]), \ + "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[1]), \ + "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[2]), \ + "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[3]) \ + : "l"(addr)); \ + return ans; \ + } \ + }; + +#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 + DEFINE_Apply_LoadMultimem(FuncSum, uint32_t, add, u32, u32) + DEFINE_Apply_LoadMultimem(FuncMin, uint32_t, min, u32, u32) + DEFINE_Apply_LoadMultimem(FuncMax, uint32_t, max, u32, u32) + + DEFINE_Apply_LoadMultimem(FuncSum, int32_t, add, s32, u32) + DEFINE_Apply_LoadMultimem(FuncMin, int32_t, min, s32, u32) + DEFINE_Apply_LoadMultimem(FuncMax, int32_t, max, s32, u32) + + DEFINE_Apply_LoadMultimem(FuncSum, uint64_t, add, u64, u64) + DEFINE_Apply_LoadMultimem(FuncMin, uint64_t, min, u64, u64) + DEFINE_Apply_LoadMultimem(FuncMax, uint64_t, max, u64, u64) + + DEFINE_Apply_LoadMultimem(FuncSum, int64_t, add, u64, u64) + DEFINE_Apply_LoadMultimem(FuncMin, int64_t, min, s64, u64) + DEFINE_Apply_LoadMultimem(FuncMax, int64_t, max, s64, u64) + + DEFINE_Apply_LoadMultimem_v4(FuncSum, float, add, f32, u32) + + DEFINE_Apply_LoadMultimem(FuncSum, double, add, f64, u64) + + DEFINE_Apply_LoadMultimem_v4(FuncSum, half, add, f16x2, u32) + DEFINE_Apply_LoadMultimem_v4(FuncMin, half, min, f16x2, u32) + DEFINE_Apply_LoadMultimem_v4(FuncMax, half, max, f16x2, u32) + + #if defined(__CUDA_BF16_TYPES_EXIST__) + DEFINE_Apply_LoadMultimem_v4(FuncSum, __nv_bfloat16, add, bf16x2, u32) + DEFINE_Apply_LoadMultimem_v4(FuncMin, __nv_bfloat16, min, bf16x2, u32) + DEFINE_Apply_LoadMultimem_v4(FuncMax, __nv_bfloat16, max, bf16x2, u32) + #endif +#endif + +#undef DEFINE_Apply_LoadMultimem +#undef DEFINE_Apply_LoadMultimem_v4 +#undef SIZEOF_BytePack_field_u64 +#undef PTX_REG_BytePack_field_u64 +#undef SIZEOF_BytePack_field_u32 +#undef PTX_REG_BytePack_field_u32 +#undef SIZEOF_BytePack_field_u16 +#undef PTX_REG_BytePack_field_u16 + #endif // REDUCE_KERNEL_H_ diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h index 754889a..c448e59 100644 --- a/src/collectives/device/reduce_scatter.h +++ b/src/collectives/device/reduce_scatter.h @@ -87,3 +87,45 @@ struct RunWorkElement(args); } }; + +template +struct RunWorkElement { + __device__ __forceinline__ void run(ncclWorkElem *args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int nChannels = args->nChannels; + struct ncclNvls* nvls = &ncclShmem.channel.nvls; + const ssize_t chunkSize = int(args->lastChunkSize); + const ssize_t size = args->count; + const ssize_t loopSize = nChannels*chunkSize; + + const int nThreadsScatter = 128 + WARP_SIZE; + const int nThreadsReduce = 384; + const int tidEndScatter = nThreadsScatter; + const int tidEndReduce = tidEndScatter + nThreadsReduce; + + using Proto = ProtoSimple<1, 1>; + + if (tid < tidEndScatter) { + // Scatter + int group = (0*Proto::MaxGroupWidth) | (0<<16); + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.scatter(offset, nvls->nHeads*size, nelem, size, -1, 0); + } + } else if (tid < tidEndReduce) { + int group = (3*Proto::MaxGroupWidth) | (1<<16); + // Reduce through MC + Primitives, /*Direct=*/0, Proto, 0> + 
prims(tid-tidEndScatter, nThreadsReduce, &nvls->down, NULL, NULL, args->recvbuff, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.recv(offset, nelem); + } + } + } +}; diff --git a/src/collectives/device/sendrecv.h b/src/collectives/device/sendrecv.h index ec1e20c..41fe0c2 100644 --- a/src/collectives/device/sendrecv.h +++ b/src/collectives/device/sendrecv.h @@ -13,12 +13,13 @@ struct RunWork { template __device__ void runSend(const int tid, const int nthreads, const int group, struct ncclWorkElemP2p* args) { void* buff = reinterpret_cast(uintptr_t(args->buffHi32)<<32 | args->buffLo32); - size_t count = reinterpret_cast(size_t(args->countHi32)<<32 | args->countLo32); + ssize_t count = reinterpret_cast(size_t(args->countHi32)<<32 | args->countLo32); if (args->peer == ncclShmem.comm.rank) { struct ncclWorkElemP2p* recvArgs = args-1; void* recvBuff = reinterpret_cast(uintptr_t(recvArgs->buffHi32)<<32 | recvArgs->buffLo32); if (buff != recvBuff) { - ReduceOrCopyMulti(tid, nthreads, nullptr, false, 1, (const T**)&buff, 1, (T**)&recvBuff, count); + ReduceOrCopyMulti + (tid, nthreads, 0, nullptr, false, 1, &buff, 1, &recvBuff, count); } } else { int chunkSize = args->chunkSize/sizeof(T); diff --git a/src/debug.cc b/src/debug.cc index 5955a6e..560c1d2 100644 --- a/src/debug.cc +++ b/src/debug.cc @@ -74,6 +74,8 @@ void ncclDebugInit() { mask = NCCL_ALLOC; } else if (strcasecmp(subsys, "CALL") == 0) { mask = NCCL_CALL; + } else if (strcasecmp(subsys, "NVLS") == 0) { + mask = NCCL_NVLS; } else if (strcasecmp(subsys, "ALL") == 0) { mask = NCCL_ALL; } diff --git a/src/enqueue.cc b/src/enqueue.cc index 0744e09..85f2ac9 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -32,7 +32,8 @@ struct ncclKernelMatch { NCCL_FUNC5(func, TREE, devredop, type, specialized), \ NCCL_FUNC5(func, RING, devredop, type, specialized), \ NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, specialized), \ - NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, specialized) + NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, specialized), \ + NCCL_FUNC5(func, NVLS, devredop, type, specialized) #ifdef __CUDA_BF16_TYPES_EXIST__ #define HAVE_BFLOAT16 1 @@ -90,34 +91,48 @@ static const ncclKernelMatch ncclKerns[1+ncclNumTypes+NCCL_NUM_FUNCTIONS*ncclNum static ncclResult_t computeColl(struct ncclInfo* info /* input */, int* workFuncIndex, struct ncclWorkElem* work, struct ncclProxyOp* proxyOp /* output */); -// Determine the maximum kernel stack size of all CUDA kernels -size_t ncclKernMaxLocalSize() { - ncclResult_t res = ncclSuccess; - int numNcclKerns = sizeof(ncclKerns)/sizeof(ncclKerns[0]); - cudaFuncAttributes attr = {0}; - size_t max = 0; - for (int i = 0; i < numNcclKerns; i++) { - CUDACHECKGOTO(cudaFuncGetAttributes(&attr, ncclKerns[i].kernelFn), res, error); - if (attr.localSizeBytes > max) max = attr.localSizeBytes; +NCCL_PARAM(L1SharedMemoryCarveout, "L1_SHARED_MEMORY_CARVEOUT", 0); + +// Returns maximum kernel stack size of all CUDA kernels +ncclResult_t ncclInitKernelsForDevice(int cudaArch, size_t* maxStackSize) { + constexpr int KernelCount = sizeof(ncclKerns)/sizeof(ncclKerns[0]); + ncclResult_t result = ncclSuccess; + + if (maxStackSize) *maxStackSize = 0; + int carveout = ncclParamL1SharedMemoryCarveout(); + + // Keep track if we already visited a function pointer. 
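  // A two-entry LRU is enough in practice: ncclKerns tends to repeat the same
  // generic kernel pointer in neighboring entries, and missing an occasional
  // duplicate is harmless since the cudaFunc* attribute calls below are idempotent.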
+ void* lru[2] = {nullptr, nullptr}; + for (int i=0; i < KernelCount; i++) { + void* fn = ncclKerns[i].kernelFn; + if (fn == lru[0] || fn == lru[1]) goto next_kernel; + lru[1] = lru[0]; + lru[0] = fn; + + if (maxStackSize) { + cudaFuncAttributes attr = {0}; + CUDACHECKGOTO(cudaFuncGetAttributes(&attr, fn), result, ignore0); + if (attr.localSizeBytes > *maxStackSize) *maxStackSize = attr.localSizeBytes; + ignore0:; + } + + if (carveout) { + CUDACHECKGOTO(cudaFuncSetAttribute(fn, + cudaFuncAttributePreferredSharedMemoryCarveout, carveout), + result, ignore1); + ignore1:; + } + + if (ncclShmemDynamicSize(cudaArch) != 0) { + CUDACHECKGOTO(cudaFuncSetAttribute(fn, + cudaFuncAttributeMaxDynamicSharedMemorySize, ncclShmemDynamicSize(cudaArch)), + result, next_kernel); + } + next_kernel:; } - -error: - return (res != ncclSuccess) ? 0 : max; + return result; } -// Set shared memory carveout for the nccl kernels -ncclResult_t ncclKernSetSharedMemoryCarveout(int carveOut) { - ncclResult_t res = ncclSuccess; - int numNcclKerns = sizeof(ncclKerns)/sizeof(ncclKerns[0]); - for (int i = 0; i < numNcclKerns; i++) { - CUDACHECKGOTO(cudaFuncSetAttribute(ncclKerns[i].kernelFn, cudaFuncAttributePreferredSharedMemoryCarveout, carveOut), res, error); - } - -error: - return res; -} - - /*****************************************************************************/ /* Launch system : synchronization and CUDA kernel launch */ /*****************************************************************************/ @@ -248,10 +263,9 @@ static ncclResult_t addProxyOpIfNeeded(struct ncclComm* comm, struct ncclKernelP static ncclResult_t addCollToPlan( struct ncclComm* comm, struct ncclKernelPlan* plan, int* nWorkBudget, int funcIndex, struct ncclWorkElem const* workElem, struct ncclProxyOp const* proxyOp, - int nBid, size_t bytes, bool regBufUsed, void* regBufSend[], void* regBufRecv[] + int nCollChannels, int nBid, size_t bytes, bool regBufUsed, void* regBufSend[], void* regBufRecv[] ) { struct ncclKernelPlan::Channel *chans = plan->channels; - int nCollChannels = comm->nChannels; // Choose the `nBid` least loaded channels to do the work. This ensures // all bids go to different channels in case they need to synchronize. @@ -268,9 +282,7 @@ static ncclResult_t addCollToPlan( } } // Sort in the rest of the channels. If a channel has less work than the max - // member of least[], replace that member and compute the new max. The optimal - // algorithm uses a max-heap, but for our small sizes I suspect the better - // asymptotic complexity would be swamped by the increased instruction complexity. + // member of least[], replace that member and compute the new max. for (int c=nBid; c < nCollChannels; c++) { if (chans[c].collBytes < maxBytesInLeast) { least[maxIndexInLeast] = c; @@ -541,8 +553,9 @@ static ncclResult_t scheduleCollTasksToPlan( info.sliceSteps = head->sliceSteps; NCCLCHECK(ncclInfoSetDerived(&info, comm->nRanks)); if (nAggOps > 1) { + int maxChannels = aggInfo.algorithm == NCCL_ALGO_NVLS ? 
comm->nvlsChannels : comm->nChannels; info.nChannels = DIVUP(info.nBytes, bytePerChannel[collNetSupport]); - info.nChannels = std::max(1, std::min(info.nChannels, comm->nChannels)); + info.nChannels = std::max(1, std::min(info.nChannels, maxChannels)); info.algorithm = aggInfo.algorithm; info.protocol = aggInfo.protocol; info.nThreads = aggInfo.nThreads; @@ -565,8 +578,9 @@ static ncclResult_t scheduleCollTasksToPlan( NCCLCHECK(registerIntraNodeBuffers(comm, plan, &info, ®BufUsed, regBufSend, regBufRecv)); } + int maxChannels = info.algorithm == NCCL_ALGO_NVLS ? comm->nvlsChannels : comm->nChannels; NCCLCHECK(addCollToPlan(comm, plan, nWorkBudget, workFuncIndex, &workElem, &proxyOp, - info.nChannels, info.nBytes, regBufUsed, regBufSend, regBufRecv)); + maxChannels, info.nChannels, info.nBytes, regBufUsed, regBufSend, regBufRecv)); tasks->nTasksColl -= 1; tasks->collBytesTotal -= info.nBytes; ncclIntruQueueDequeue(&tasks->collQueue); @@ -856,7 +870,7 @@ static void CUDART_CB hostStreamPlanCallback(void *plan_) { struct ncclKernelPlan* plan = (struct ncclKernelPlan*)plan_; ncclResult_t result = hostStreamPlanTask(plan->comm, plan); if (result != ncclSuccess) { - WARN("hostStreamPlanCallback() failed : %s\n", ncclGetErrorString(result)); + WARN("hostStreamPlanCallback() failed : %s", ncclGetErrorString(result)); } } @@ -964,7 +978,7 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) { } NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, launchStream, &comm->deviceStream), result, failure); - if (persistent || comm->persistentRefs != 0) { + if (persistent || comm->persistentRefs != 0 || ncclCudaLaunchBlocking) { // We have to launch host tasks to push proxy args. We are careful to only // do this if necessary since host tasks impose a high performance cost in CUDA. bool acquired = false; @@ -1005,12 +1019,6 @@ ncclResult_t ncclLaunchKernelBefore_NoUncapturedCuda(struct ncclComm* comm, stru return ncclSuccess; } -#if CUDART_VERSION >= 11080 -#define NCCL_MAX_CGA_CLUSTER_SIZE 8 -#define NCCL_CGA_CLUSTER_SIZE_SM90 4 -NCCL_PARAM(CGAClusterSize, "CGA_CLUSTER_SIZE", -2); -#endif - #if CUDART_VERSION >= 12000 // NCCL uses the "Remote" Mem Sync domain by default NCCL_PARAM(MemSyncDomain, "MEM_SYNC_DOMAIN", cudaLaunchMemSyncDomainRemote); @@ -1022,6 +1030,7 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan cudaStream_t launchStream = tasks->streams->stream; dim3 grid = {(unsigned)plan->channelCount, 1, 1}; dim3 block = {(unsigned)plan->threadPerBlock, 1, 1}; + size_t smem = ncclShmemDynamicSize(comm->cudaArch); void *args[3] = {&comm->devComm, &plan->channelMask, &plan->workHead}; #if CUDART_VERSION >= 11080 @@ -1029,19 +1038,7 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan NCCLCHECK(ncclCudaDriverVersion(&driverVersion)); if (driverVersion >= 11080) { int compCap = comm->compCap; - unsigned int clusterSize = (compCap == 90) ? NCCL_CGA_CLUSTER_SIZE_SM90 : 0; - if (ncclParamCGAClusterSize() != -2) { - clusterSize = ncclParamCGAClusterSize(); - if (clusterSize > NCCL_MAX_CGA_CLUSTER_SIZE) { - static bool warned = false; - if (warned == false) { - WARN("NCCL_CGA_CLUSTER_SIZE value %d is too big. Limiting value to %d.", - clusterSize, NCCL_MAX_CGA_CLUSTER_SIZE); - warned = true; - } - clusterSize = NCCL_MAX_CGA_CLUSTER_SIZE; - } - } + unsigned int clusterSize = (compCap == 90) ? 
comm->cgaClusterSize : 0; cudaLaunchConfig_t launchConfig = {0}; cudaLaunchAttribute launchAttrs[3]; @@ -1073,6 +1070,7 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan #endif launchConfig.gridDim = grid; launchConfig.blockDim = block; + launchConfig.dynamicSmemBytes = smem; launchConfig.attrs = launchAttrs; launchConfig.numAttrs = attrs; launchConfig.stream = launchStream; @@ -1082,12 +1080,12 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan } #endif // Standard kernel launch - CUDACHECK(cudaLaunchKernel(fn, grid, block, args, 0, launchStream)); + CUDACHECK(cudaLaunchKernel(fn, grid, block, args, smem, launchStream)); return ncclSuccess; } ncclResult_t ncclLaunchKernelAfter_NoCuda(struct ncclComm* comm, struct ncclKernelPlan* plan) { - if (comm->persistentRefs == 0) { // implies !plan->persistent + if (!(plan->persistent || comm->persistentRefs != 0 || ncclCudaLaunchBlocking)) { // If this isn't being captured and there aren't any CUDA graphs alive // then we don't need to do our proxyOp pushing on the host stream. NCCLCHECK(hostStreamPlanTask(comm, plan)); @@ -1161,6 +1159,8 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i int nAlgos = NCCL_NUM_ALGORITHMS; for (int a=0; adatatype, info->opFull.op)) continue; + for (int p=0; palgorithm == NCCL_ALGO_NVLS) { + // NVLS should not need more than 16 channels to get peak BW. + nc = comm->nvlsChannels; } else { // Ring/Tree channel tuning while (info->nBytes < nc*nt*threadThreshold) { @@ -1207,6 +1210,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i if (info->algorithm == NCCL_ALGO_TREE) nt += 3*WARP_SIZE; if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT) nt += 3*WARP_SIZE; if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN) nt += 3*WARP_SIZE; + if (info->algorithm == NCCL_ALGO_NVLS) nt = NCCL_MAX_NTHREADS; } nt = nt/WARP_SIZE < 3 ? 3*WARP_SIZE : nt; info->nChannels = nc; @@ -1225,6 +1229,7 @@ static ncclResult_t getPatternInfo(struct ncclInfo* info) { info->pattern = ncclPatternRing; break; case ncclFuncAllReduce: info->pattern = + info->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls : info->algorithm == NCCL_ALGO_COLLNET_DIRECT ? ncclPatternCollnetDirect : info->algorithm == NCCL_ALGO_COLLNET_CHAIN ? ncclPatternCollnetChain : info->algorithm == NCCL_ALGO_TREE ? 
ncclPatternTreeUpDown : @@ -1244,6 +1249,7 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) { case ncclPatternPipelineFrom: case ncclPatternPipelineTo: case ncclPatternCollnetChain: + case ncclPatternNvls: info->nstepsPerLoop = info-> nchunksPerLoop = 1; break; case ncclPatternCollnetDirect: info->nstepsPerLoop = 1; info->nchunksPerLoop = info->comm->channels[0].collnetDirect.nHeads; break; @@ -1319,6 +1325,14 @@ comp_next: while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*8 && chunkSize > 65536) chunkSize /= 2; while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2; work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype); + } else if (info->algorithm == NCCL_ALGO_NVLS) { + if (chunkSize > 131072) chunkSize = 131072; + // Use uint64_t so that concurrentOps*chunkSize*X does not overflow + uint64_t concurrentOps = info->nChannels*info->comm->channels[0].nvls.nHeads; + if ((info->nBytes < (32 * (concurrentOps*chunkSize))) && (chunkSize > 65536)) chunkSize = 65536; + if ((info->nBytes < (8 * (concurrentOps*chunkSize))) && (chunkSize > 32768)) chunkSize = 32768; + if ((info->nBytes < (2 * (concurrentOps*chunkSize))) && (chunkSize > 16384)) chunkSize = 16384; + work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype); } else if (info->protocol == NCCL_PROTO_LL) { const ssize_t sliceSize = stepSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine); const ssize_t loopSize = info->nChannels*info->nchunksPerLoop*(ssize_t)sliceSize; @@ -1618,6 +1632,11 @@ ncclResult_t ncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm) { WARN("ncclRedOpDestroy : operator is garbage."); return ncclInvalidArgument; } + if (comm == NULL) { + WARN("ncclRedOpDestroy : invalid communicator passed."); + return ncclInvalidArgument; + } + int ix = int(ncclUserRedOpMangle(comm, op)) - int(ncclNumOps); if (comm->userRedOpCapacity <= ix || comm->userRedOps[ix].freeNext != -1) { WARN("ncclRedOpDestroy : operator unknown to this communicator."); diff --git a/src/graph/connect.cc b/src/graph/connect.cc index ccf1e04..68f7572 100644 --- a/src/graph/connect.cc +++ b/src/graph/connect.cc @@ -313,8 +313,8 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa // Honor NCCL_MIN_NRINGS/NCCL_MAX_NRINGS. // We permit combining max, then min, to only use the first channels, then duplicate them. 
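/*
 * Illustrative aside (not part of the patch): the NVLS branch added to
 * computeColl a few hunks above caps chunkSize at 128K and then steps it down
 * to 64K/32K/16K when the message is small relative to the number of
 * concurrent NVLS operations (nChannels * nvls.nHeads). A standalone
 * restatement of that clamp with the same thresholds; the function name and
 * parameters are illustrative only.
 */
#include <cstdint>
#include <cstdio>

static uint64_t nvlsChunkSize(uint64_t chunkSize, uint64_t nBytes, uint64_t concurrentOps) {
  // uint64_t keeps concurrentOps*chunkSize*32 from overflowing, as in the patch.
  if (chunkSize > 131072) chunkSize = 131072;
  if (nBytes < 32 * concurrentOps * chunkSize && chunkSize > 65536) chunkSize = 65536;
  if (nBytes <  8 * concurrentOps * chunkSize && chunkSize > 32768) chunkSize = 32768;
  if (nBytes <  2 * concurrentOps * chunkSize && chunkSize > 16384) chunkSize = 16384;
  return chunkSize;
}

int main() {
  // 1 MiB across 4 channels * 4 NVLS heads drops the chunk from 128K to 32K:
  std::printf("%llu\n", (unsigned long long)nvlsChunkSize(131072, 1 << 20, 4 * 4));
  return 0;
}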
- nChannels = comm->nChannels = std::min((int)ncclMaxNchannels(), nChannels); - nChannels = comm->nChannels = copyChannels(comm, nChannels, ncclMinNchannels(), ringPrev, ringNext); + nChannels = comm->nChannels = std::min(std::min(ncclMaxNchannels(), nChannels), comm->maxCTAs); + nChannels = comm->nChannels = copyChannels(comm, nChannels, std::max(ncclMinNchannels(), comm->minCTAs), ringPrev, ringNext); // Create rings array and check all is fine NCCLCHECK(ncclBuildRings(nChannels, rings, comm->rank, comm->nRanks, ringPrev, ringNext)); diff --git a/src/graph/paths.cc b/src/graph/paths.cc index 7134b90..728b55f 100644 --- a/src/graph/paths.cc +++ b/src/graph/paths.cc @@ -461,7 +461,7 @@ ncclResult_t ncclTopoGetIntermediateRank(struct ncclTopoSystem* system, int rank type = node->type; } if (type != GPU) { - WARN("Could not find intermediate GPU between GPU rank %d and NIC %d\n", rank, netDev); + WARN("Could not find intermediate GPU between GPU rank %d and NIC %d", rank, netDev); return ncclInternalError; } *intermediateRank = node->gpu.rank; @@ -707,6 +707,7 @@ static int nextPow2(int v) { } ncclResult_t ncclTopoComputeP2pChannels(struct ncclComm* comm) { + /* here we already honor comm->max/minCTAs for p2pnChannels. */ comm->p2pnChannels = std::min(comm->nChannels, (int)ncclParamMaxP2pNChannels()); comm->p2pnChannels = std::max(comm->p2pnChannels, (int)ncclParamMinP2pNChannels()); int minChannels = comm->p2pnChannels; @@ -734,7 +735,6 @@ ncclResult_t ncclTopoComputeP2pChannels(struct ncclComm* comm) { for (int b=1, mb=(comm->p2pnChannels>>1); bp2pnChannels; b<<=1, mb>>=1) if (c & b) mirror |= mb; comm->p2pChannels[c] = mirror; } - INFO(NCCL_INIT, "%d coll channels, %d p2p channels, %d p2p channels per peer", comm->nChannels, comm->p2pnChannels, comm->p2pnChannelsPerPeer); return ncclSuccess; } diff --git a/src/graph/search.cc b/src/graph/search.cc index 534d401..f9106ed 100644 --- a/src/graph/search.cc +++ b/src/graph/search.cc @@ -765,7 +765,6 @@ ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph if (ngpus == 1) if (graph->pattern != NCCL_TOPO_PATTERN_RING) graph->pattern = NCCL_TOPO_PATTERN_TREE; - // SPLIT_TREE works better on older archs. 
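/*
 * Illustrative aside (not part of the patch): with the connect.cc change above,
 * the collective channel count is bounded by the new per-communicator config
 * options in addition to NCCL_MIN_NRINGS/NCCL_MAX_NRINGS: the ceiling becomes
 * min(ncclMaxNchannels(), maxCTAs) and channels are then duplicated up to
 * max(ncclMinNchannels(), minCTAs). A minimal sketch of that arithmetic; the
 * function and parameter names here are illustrative, not NCCL's.
 */
#include <algorithm>
#include <cstdio>

static int effectiveChannels(int topoChannels, int envMin, int envMax, int minCTAs, int maxCTAs) {
  int n = std::min(std::min(envMax, topoChannels), maxCTAs);  // cap by NCCL_MAX_NRINGS and maxCTAs
  n = std::max(n, std::max(envMin, minCTAs));                 // then duplicate up to the floor
  return n;
}

int main() {
  // e.g. topology found 16 rings, user configured at most 8 CTAs:
  std::printf("%d\n", effectiveChannels(/*topoChannels=*/16, /*envMin=*/1, /*envMax=*/32,
                                        /*minCTAs=*/1, /*maxCTAs=*/8));  // prints 8
  return 0;
}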
int ccMin; NCCLCHECK(ncclTopoGetCompCap(system, &ccMin, NULL)); diff --git a/src/graph/topo.cc b/src/graph/topo.cc index 9e4c978..ea36ac3 100644 --- a/src/graph/topo.cc +++ b/src/graph/topo.cc @@ -815,6 +815,6 @@ ncclResult_t ncclTopoGetLocalRank(struct ncclTopoSystem* system, int rank, int* return ncclSuccess; } } - WARN("Could not find local GPU with rank %d\n", rank); + WARN("Could not find local GPU with rank %d", rank); return ncclInternalError; } diff --git a/src/graph/tuning.cc b/src/graph/tuning.cc index 18afc03..90cf218 100644 --- a/src/graph/tuning.cc +++ b/src/graph/tuning.cc @@ -53,7 +53,7 @@ ncclResult_t parseList(const char* str, const char* elems[], int nelems, int* li // Latencies in us, Bandwidths in GB/s // Tree { LL, LL128, Simple } , Ring { LL, LL128, Simple } -static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { { 4.4, 4.4, 0 }, { 3.6, 10.0, 8.4 }, { 4.4, 4.4, 0 }, { 4.4, 4.4, 0 }}; +static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { { 4.4, 4.4, 0 }, { 3.6, 10.0, 8.4 }, { 4.4, 4.4, 0 }, { 4.4, 4.4, 0 }, { 0, 0, 40.0 }}; // NVLink, PCI, Network #define NCCL_HW_NVLINK 0 @@ -63,13 +63,16 @@ static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { { 4.4, static float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { /* NVLINK */ { /* Tree (LL/LL128/Simple)*/ { .52, 1.25, 28 }, /* Ring (LL/LL128/Simple)*/ { .47, 1.9, 3.4 }, - /* CollNetDirect (Simple)*/ { 0, 0, 8.0 }, /* CollNetChain (Simple)*/ { 0, 0, 8.0 } }, + /* CollNetDirect (Simple)*/ { 0, 0, 8.0 }, /* CollNetChain (Simple)*/ { 0, 0, 8.0 }, + /* NVLS */ { 0, 0, 0 } }, /* PCI */ { /* Tree (LL/LL128/Simple)*/ { 1.0, 1.9, 28 }, /* Ring (LL/LL128/Simple)*/ { 1.0, 2.5, 5.7 }, - /* CollNetDirect (Simple)*/ { 0, 0, 8.0 }, /* CollNetChain (Simple)*/ { 0, 0, 8.0 } }, + /* CollNetDirect (Simple)*/ { 0, 0, 8.0 }, /* CollNetChain (Simple)*/ { 0, 0, 8.0 }, + /* NVLS */ { 0, 0, 0 } }, /* NET */ { /* Tree (LL/LL128/Simple)*/ { 5.0, 8.5, 28 }, /* Ring (LL/LL128/Simple)*/ { 2.7, 4.0, 9.6 }, - /* CollNetDirect (Simple)*/ { 0, 0, 10.7 }, /* CollNetChain (Simple)*/ { 0, 0, 10.7 } } + /* CollNetDirect (Simple)*/ { 0, 0, 10.7 }, /* CollNetChain (Simple)*/ { 0, 0, 10.7 }, + /* NVLS */ { 0, 0, 0 } } }; /* Array indexes used below */ @@ -78,7 +81,7 @@ static float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = #define HOPPER_COMPCAP_IDX 2 // LL128 max BW per channel -static const double ll128MaxBwPerCh = 20.0; +static const double ll128MaxBwPerCh[3] = { 20.0, 20.0, 36.7 }; static const double llMaxBws[3][3] = { /* Volta-N1/Intel-N2/Intel-N4) */ {39.0, 39.0, 20.4}, /* Ampere-N1/AMD-N2/AMD-N4) */ {87.7, 22.5 /*avg of ring & tree*/, 19.0}, @@ -88,7 +91,7 @@ static const double llMaxBws[3][3] = { static const double perChMaxTreeBws[3][3] = { /* Volta (N1/N2/N4) */ {26.5, 18.5, 10.0}, /* Ampere (N1/N2/N4) */ {24.0, 23.6, 17.8}, - /* Hopper (N1/N2/N4) */ {24.0, 23.6, 17.8}, + /* Hopper (N1/N2/N4) */ {38.7, 41.4, 33.0}, }; ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph) { @@ -98,7 +101,8 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_SIMPLE_MAX_NTHREADS, NCCL_SIMPLE_MAX_NTHREADS); comm->maxThreads[NCCL_ALGO_COLLNET_DIRECT][NCCL_PROTO_SIMPLE] = - 
comm->maxThreads[NCCL_ALGO_COLLNET_CHAIN][NCCL_PROTO_SIMPLE] = NCCL_SIMPLE_MAX_NTHREADS; + comm->maxThreads[NCCL_ALGO_COLLNET_CHAIN][NCCL_PROTO_SIMPLE] = + comm->maxThreads[NCCL_ALGO_NVLS][NCCL_PROTO_SIMPLE] = NCCL_SIMPLE_MAX_NTHREADS; comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_LL] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_LL] = getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_LL_MAX_NTHREADS, NCCL_LL_MAX_NTHREADS); comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_LL128] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_LL128] = @@ -108,7 +112,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom int nRanks = comm->nRanks; if (nRanks <= 1) return ncclSuccess; - int compCapIndex = (minCompCap == 80 && maxCompCap == 80) ? AMPERE_COMPCAP_IDX : ((minCompCap == 90 && maxCompCap == 90) ? HOPPER_COMPCAP_IDX : VOLTA_COMPCAP_IDX); + int compCapIndex = minCompCap >= 90 ? HOPPER_COMPCAP_IDX : minCompCap >= 80 ? AMPERE_COMPCAP_IDX : VOLTA_COMPCAP_IDX; int cpuArch, cpuVendor, cpuModel; NCCLCHECK(ncclTopoCpuType(comm->topo, &cpuArch, &cpuVendor, &cpuModel)); int index2 = nNodes <= 2 ? nNodes-1 : 2; @@ -120,7 +124,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom if (cpuArch == NCCL_TOPO_CPU_ARCH_POWER) hwLat[NCCL_HW_PCI][NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = hwLat[NCCL_HW_PCI][NCCL_ALGO_RING][NCCL_PROTO_SIMPLE]; float ppn = (float)nRanks / nNodes; // if ppn < 2, then we are sending/receiving at the same GPU through the NIC, apply some bw discount - struct ncclTopoGraph* graphs[NCCL_NUM_ALGORITHMS] = { treeGraph, ringGraph, collNetGraph, collNetGraph }; + struct ncclTopoGraph* graphs[NCCL_NUM_ALGORITHMS] = { treeGraph, ringGraph, collNetGraph, collNetGraph, ringGraph/* we only need the NVSwitch speed for NVLS*/ }; int intraHw[NCCL_NUM_ALGORITHMS], hw[NCCL_NUM_ALGORITHMS]; for (int a=0; atypeIntra == LINK_NVL ? NCCL_HW_NVLINK : NCCL_HW_PCI; for (int a=0; abwIntra : graphs[a]->bwInter; float busBw = graphs[a]->nChannels * bw; // Various model refinements if (compCapIndex == AMPERE_COMPCAP_IDX) busBw = std::min(busBw, 235.0f); + if (compCapIndex == HOPPER_COMPCAP_IDX) busBw = std::min(busBw, 370.0f); if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL) { busBw = std::min(llMaxBw, busBw * ((nNodes > 1 || coll == ncclFuncAllReduce || coll == ncclFuncReduce) ? 1.0/4.0 : 1.0/3.0)); } - if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh*graphs[a]->nChannels); + if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh[compCapIndex]*graphs[a]->nChannels); if (a == NCCL_ALGO_TREE) busBw = std::min(busBw*.92, graphs[a]->nChannels*perChMaxTreeBw); if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL) busBw = std::min(busBw*1.0/3.8, llMaxBw); - if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 120.0/128.0), ll128MaxBwPerCh*graphs[a]->nChannels); + if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 
7.0/9.0 : 120.0/128.0), ll128MaxBwPerCh[compCapIndex]*graphs[a]->nChannels); if (a == NCCL_ALGO_COLLNET_DIRECT && p != NCCL_PROTO_SIMPLE) busBw = 0; // Not used if (a == NCCL_ALGO_COLLNET_CHAIN && p != NCCL_PROTO_SIMPLE) busBw = 0; // Not used if (a == NCCL_ALGO_COLLNET_DIRECT && p == NCCL_PROTO_SIMPLE) { @@ -159,7 +168,10 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom if (a == NCCL_ALGO_COLLNET_CHAIN && p == NCCL_PROTO_SIMPLE) busBw *= .75; // Convert bus BW to algorithm BW - float ratio = (a != NCCL_ALGO_RING) ? .5 : (1.0 * nRanks) / nsteps; + float ratio; + if (a == NCCL_ALGO_RING) ratio = (1.0 * nRanks) / nsteps; + else if (a == NCCL_ALGO_NVLS) ratio = .75; + else ratio = .5; comm->bandwidths[coll][a][p] = busBw * ratio; comm->latencies[coll][a][p] = baseLat[a][p]; @@ -195,7 +207,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom // Protocols/Algorithms enable/disable, and user overrides. // All are enabled except ll128 which is enabled by default only in certain cases. int protoEnable[NCCL_NUM_PROTOCOLS] = { 1, 2, 1 }; - int algoEnable[NCCL_NUM_ALGORITHMS] = { 1, 1, 1, 1 }; + int algoEnable[NCCL_NUM_ALGORITHMS] = { 1, 1, 1, 1, 1 }; const char *protoStr = getenv("NCCL_PROTO"); if (protoStr) { @@ -207,6 +219,10 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom INFO(NCCL_ENV, "NCCL_ALGO set by environment to %s", algoStr); NCCLCHECK(parseList(algoStr, ncclAlgoStr, NCCL_NUM_ALGORITHMS, algoEnable)); } + + // Disable NVLink SHARP if not supported + if (comm->nvlsSupport == 0 /* || comm->localRanks <= 2*/) algoEnable[NCCL_ALGO_NVLS] = 0; + // Disable CollNet if it is not supported if (comm->collNetSupport == 0) { algoEnable[NCCL_ALGO_COLLNET_DIRECT] = 0; @@ -228,7 +244,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom if (pEnable == 2 && p == NCCL_PROTO_LL128) { // Enable LL128 by default only on Volta/Ampere/Hopper+NVLink. Other cases are not tested and may cause silent data corruption. pEnable = 1; - pEnable &= (graphs[a]->typeInter <= PATH_PXB); + pEnable &= (graphs[a]->typeInter <= PATH_PXB || (minCompCap >= 90 && graphs[a]->typeInter <= PATH_PXN)); pEnable &= (graphs[a]->typeIntra <= PATH_NVL); pEnable &= (minCompCap == maxCompCap); switch (minCompCap) { @@ -239,8 +255,9 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom } } if (pEnable == 0) comm->bandwidths[c][a][p] = 0; - // Only disable algo for Allreduce since others only have one - if (c == ncclFuncAllReduce && algoEnable[a] == 0) comm->bandwidths[c][a][p] = 0; + // Never disable ring for non-allreduce operations. That allows to run real apps with NCCL_ALGO=TREE. 
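/*
 * Illustrative aside (not part of the patch): the tuning change above converts
 * bus bandwidth to algorithm bandwidth with a per-algorithm ratio: nRanks/nsteps
 * for Ring, a flat 0.75 for the new NVLS algorithm, and 0.5 otherwise. A
 * standalone restatement of that conversion; the numeric algorithm IDs mirror
 * the NCCL_ALGO_* defines added in devcomm.h below, but the function itself is
 * only a sketch.
 */
#include <cstdio>

static float algoBandwidth(int algo, float busBw, int nRanks, int nsteps) {
  float ratio;
  if (algo == 1) ratio = (1.0f * nRanks) / nsteps;  // NCCL_ALGO_RING == 1
  else if (algo == 4) ratio = 0.75f;                // NCCL_ALGO_NVLS == 4
  else ratio = 0.5f;                                // Tree / CollNet variants
  return busBw * ratio;
}

int main() {
  // Ring allreduce on 8 ranks, where nsteps is typically 2*(nRanks-1) = 14:
  std::printf("%.1f GB/s\n", algoBandwidth(/*algo=*/1, /*busBw=*/240.0f, /*nRanks=*/8, /*nsteps=*/14));
  return 0;
}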
+ if (a == NCCL_ALGO_RING && c != ncclFuncAllReduce) continue; + if (algoEnable[a] == 0) comm->bandwidths[c][a][p] = 0; } if (comm->rank == 0) { @@ -284,9 +301,9 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom char* str = getenv("NCCL_THREAD_THRESHOLDS"); if (str) { INFO(NCCL_ENV, "NCCL_THREAD_THRESHOLDS set by environment to %s", str); - ssize_t t[NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = {{ -2, -2, -2 }, { -2, -2, -2 }, { -2, -2, -2 }, { -2, -2, -2 }}; + ssize_t t[2][NCCL_NUM_PROTOCOLS] = {{ -2, -2, -2 }, { -2, -2, -2 }}; sscanf(str, "%ld %ld %ld %ld %ld %ld", t[0], t[0]+1, t[0]+2, t[1], t[1]+1, t[1]+2); - for (int a=0; a= 0) comm->threadThresholds[a][p] = t[a][p]; } @@ -323,7 +340,9 @@ ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int proto if (algorithm == NCCL_ALGO_TREE && logSize < 23) bw *= treeCorrectionFactor[protocol][logSize]; if (info->nChannels != 0) bw = bw / info->comm->nChannels * info->nChannels; if (algorithm == NCCL_ALGO_RING && protocol == NCCL_PROTO_SIMPLE && info->comm->nNodes > 1 - && info->coll == ncclFuncAllReduce && info->nBytes >= info->comm->nRanks/16.0*65536) lat *= 1.9; // Plateau effect of ring + && info->coll == ncclFuncAllReduce && info->nBytes >= info->comm->nRanks/16.0*65536) { + lat *= info->comm->minCompCap < 90 ? 1.9 : 1.5; // Plateau effect of ring + } // Tree pipelining saves latency in aggregation cases int latCount = algorithm == NCCL_ALGO_RING ? numPipeOps : DIVUP(numPipeOps, NCCL_MAX_WORK_ELEMENTS); *time = lat * latCount + (info->nBytes) / (1000 * bw); diff --git a/src/group.cc b/src/group.cc index ff416e3..3380778 100644 --- a/src/group.cc +++ b/src/group.cc @@ -315,7 +315,7 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_) { ret = ncclSystemError; } job->state = ncclGroupJobJoined; - if (job->result != ncclSuccess) { + if (job->result != ncclSuccess && ret == ncclSuccess) { ret = job->result; errorJobAbortFlag = true; } @@ -326,7 +326,6 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_) { if (*groupAbortFlag == true || errorJobAbortFlag == true) { *job->abortFlag = 1; - ret = ncclInternalError; } job = job->next; diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h index e70db04..2ecea7a 100644 --- a/src/include/bootstrap.h +++ b/src/include/bootstrap.h @@ -25,6 +25,7 @@ ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int s ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size); ncclResult_t bootstrapBarrier(void* commState, int *ranks, int rank, int nranks, int tag); ncclResult_t bootstrapIntraNodeAllGather(void* commState, int *ranks, int rank, int nranks, void* allData, int size); +ncclResult_t bootstrapIntraNodeBroadcast(void* commState, int *ranks, int rank, int nranks, int root, void* bcastData, int size); ncclResult_t bootstrapClose(void* commState); ncclResult_t bootstrapAbort(void* commState); #endif diff --git a/src/include/collectives.h b/src/include/collectives.h index f50a379..fa8fe47 100644 --- a/src/include/collectives.h +++ b/src/include/collectives.h @@ -53,7 +53,8 @@ struct ncclDevRedOpFull { DECL4(func, RING, devredop, type, undef) \ DECL4(func, TREE, devredop, type, undef) \ DECL4(func, COLLNET_DIRECT, devredop, type, undef) \ - DECL4(func, COLLNET_CHAIN, devredop, type, undef) + DECL4(func, COLLNET_CHAIN, devredop, type, undef) \ + DECL4(func, NVLS, devredop, type, undef) #if defined(__CUDA_BF16_TYPES_EXIST__) #define DECL2(func, devredop, undefForFloat) \ @@ -121,4 +122,13 
@@ extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, double)(); #define REDUCE_CHUNKSTEPS 1 #define NCCL_MAX_SLICE_PER_CHUNK 2 // max value for CHUNKSTEPS/SLICESTEPS, must accord with above +// We can't use the enum identifiers like ncclSum, ncclFloat, etc since this +// macro will be used in preprocessor conditionals where enums have no meaning. +#define NCCL_NVLS_SUPPORTS(/*ncclDataType_t*/ type, /*ncclDevRedOp_t*/ red) \ + (((type==2 || type==3) && (red==0 || red==2 || red==3)) || \ + ((type==4 || type==5) && (red==0 || red==2 || red==3)) || \ + ((type==6 || type==9) && (red==0 || red==2 || red==3)) || \ + (type==7 && red==0) || \ + (type==8 && red==0)) + #endif diff --git a/src/include/comm.h b/src/include/comm.h index 655292a..0be5abd 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -104,6 +104,7 @@ struct ncclChannel { struct ncclTree tree; struct ncclTree collnetChain; struct ncclDirect collnetDirect; + struct ncclNvls nvls; int id; // index of this channel uint32_t workFifoSent; // last used work index+1 uint64_t p2pOpCount; @@ -177,8 +178,10 @@ struct ncclComm { int nRanks; // number of GPUs in communicator int cudaDev; // my cuda device index int compCap; // compute capability of the GPU + int minCompCap; // min compute capability in the communicator int64_t busId; // my PCI bus ID in int format cpu_set_t cpuAffinity; // CPU affinity of the GPU + int cudaArch; // matches __CUDA_ARCH__ of device int node; int nNodes; @@ -201,6 +204,7 @@ struct ncclComm { // Channels for collectives int nChannels; + int nvlsChannels; // Channels (per peer) for p2p int p2pnChannels; int p2pnChannelsPerPeer; @@ -257,6 +261,10 @@ struct ncclComm { int collNetSupport; int intraHighestTransportType; + // NVLink SHARP (NVLS) support + int nvlsSupport; + void* nvlsResources; + size_t channelSize; // User requested work size (bytes) for channel partitions // Internal streams @@ -288,6 +296,11 @@ struct ncclComm { // communicator mode int blocking; + // CGA cluster size + int cgaClusterSize; + int minCTAs, maxCTAs; + // network interface name + char *netName; // initState is to more conveniently reclaim resources when errors happen. 
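/*
 * Illustrative aside (not part of the patch): NCCL_NVLS_SUPPORTS above encodes
 * the (datatype, reduction) pairs handled by the multimem loads in
 * reduce_kernel.h with raw enum values so the macro stays usable in
 * preprocessor conditionals. The sketch below re-evaluates the same expression
 * with names attached, assuming the values follow nccl.h's ncclDataType_t
 * (Int32=2, Uint32=3, Int64=4, Uint64=5, Float16=6, Float32=7, Float64=8,
 * Bfloat16=9) and a device red-op order of sum=0, prod=1, max=2, min=3; those
 * mappings are assumptions here, not restated from the patch.
 */
#include <cstdio>

#define NVLS_SUPPORTS(type, red) \
  (((type==2 || type==3) && (red==0 || red==2 || red==3)) || \
   ((type==4 || type==5) && (red==0 || red==2 || red==3)) || \
   ((type==6 || type==9) && (red==0 || red==2 || red==3)) || \
   (type==7 && red==0) || \
   (type==8 && red==0))

int main() {
  const char* typeName[] = {"int8","uint8","int32","uint32","int64","uint64","half","float","double","bfloat16"};
  const char* redName[]  = {"sum","prod","max","min"};
  // Prints sum/min/max for the integer, half and bfloat16 types, and sum only
  // for float and double, matching the DEFINE_Apply_LoadMultimem list above.
  for (int t = 0; t < 10; t++)
    for (int r = 0; r < 4; r++)
      if (NVLS_SUPPORTS(t, r)) std::printf("%s %s\n", typeName[t], redName[r]);
  return 0;
}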
ncclResult_t initState; // flag to indicate if ncclCommFinalize() is called diff --git a/src/include/cudawrap.h b/src/include/cudawrap.h index 0fd5945..317ca2d 100644 --- a/src/include/cudawrap.h +++ b/src/include/cudawrap.h @@ -73,10 +73,32 @@ DECLARE_CUDA_PFN_EXTERN(cuGetErrorName, 6000); DECLARE_CUDA_PFN_EXTERN(cuMemGetAddressRange, 3020); DECLARE_CUDA_PFN_EXTERN(cuCtxCreate, 3020); DECLARE_CUDA_PFN_EXTERN(cuCtxDestroy, 4000); +DECLARE_CUDA_PFN_EXTERN(cuCtxGetCurrent, 4000); DECLARE_CUDA_PFN_EXTERN(cuCtxSetCurrent, 4000); +DECLARE_CUDA_PFN_EXTERN(cuCtxGetDevice, 2000); +// cuMem API support +DECLARE_CUDA_PFN_EXTERN(cuMemAddressReserve, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemAddressFree, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemCreate, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemGetAllocationGranularity, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemExportToShareableHandle, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemImportFromShareableHandle, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemMap, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemRelease, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemSetAccess, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemUnmap, 10020); #if CUDA_VERSION >= 11070 DECLARE_CUDA_PFN_EXTERN(cuMemGetHandleForAddressRange, 11070); // DMA-BUF support #endif +#if CUDA_VERSION >= 12010 +/* NVSwitch Multicast support */ +DECLARE_CUDA_PFN_EXTERN(cuMulticastAddDevice, 12010); +DECLARE_CUDA_PFN_EXTERN(cuMulticastBindMem, 12010); +DECLARE_CUDA_PFN_EXTERN(cuMulticastBindAddr, 12010); +DECLARE_CUDA_PFN_EXTERN(cuMulticastCreate, 12010); +DECLARE_CUDA_PFN_EXTERN(cuMulticastGetGranularity, 12010); +DECLARE_CUDA_PFN_EXTERN(cuMulticastUnbind, 12010); +#endif #endif /* CUDA Driver functions loaded with dlsym() */ @@ -88,6 +110,7 @@ DECLARE_CUDA_PFN_EXTERN(cuGetProcAddress, 11030); ncclResult_t ncclCudaLibraryInit(void); extern int ncclCudaDriverVersionCache; +extern bool ncclCudaLaunchBlocking; // initialized by ncclCudaLibraryInit() inline ncclResult_t ncclCudaDriverVersion(int* driver) { int version = __atomic_load_n(&ncclCudaDriverVersionCache, __ATOMIC_RELAXED); @@ -98,5 +121,4 @@ inline ncclResult_t ncclCudaDriverVersion(int* driver) { *driver = version; return ncclSuccess; } - #endif diff --git a/src/include/devcomm.h b/src/include/devcomm.h index 53d6838..14ff92e 100644 --- a/src/include/devcomm.h +++ b/src/include/devcomm.h @@ -15,11 +15,12 @@ typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclFuncSend, ncclFuncRecv, ncclNumFuncs} ncclFunc_t; extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS]; -#define NCCL_NUM_ALGORITHMS 4 // Tree/Ring/CollNet* +#define NCCL_NUM_ALGORITHMS 5 // Tree/Ring/CollNet* #define NCCL_ALGO_TREE 0 #define NCCL_ALGO_RING 1 #define NCCL_ALGO_COLLNET_DIRECT 2 #define NCCL_ALGO_COLLNET_CHAIN 3 +#define NCCL_ALGO_NVLS 4 extern const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS]; #define NCCL_NUM_PROTOCOLS 3 // Simple/LL/LL128 @@ -78,6 +79,7 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK #define NCCL_DIRECT_NIC 0x04 #define NCCL_IPC_WRITE 0x08 #define NCCL_IPC_READ 0x10 +#define NCCL_NVLS_MIN_POLL 0x20 struct ncclConnInfo { // Regular comm mechanism @@ -85,7 +87,7 @@ struct ncclConnInfo { uint64_t *tail; // Local for recv, remote for send uint64_t *head; // Local for send, remote for recv - int direct; // Direct communication + int flags; // Direct communication / other flags int shared; // Buffers are shared void **ptrExchange; // Pointer exchange for direct communication uint64_t* redOpArgExchange; // 
PreOp scaler exchange for direct pull case @@ -138,13 +140,22 @@ struct ncclTree { struct ncclDirect { int depth; int out; - int nHeads; - int headRank; - int shift; + int nHeads; // Number of parallel N<->1<->net operations we'll do in parallel; size of up/down + int headRank; // Index in 0..nHeads-1 I am the head rank of. -1 if I'm not a head rank (no local NIC) + int shift; // Shuffling of send/recv for scatter/gather operations, basically localRank%nHeads int up[NCCL_MAX_DIRECT_ARITY]; int down[NCCL_MAX_DIRECT_ARITY]; }; +#define NCCL_MAX_NVLS_ARITY 8 +struct ncclNvls { + int out; + int nHeads; // Number of parallel N<->1<->net operations we'll do in parallel; size of up/down + int headRank; // Index in 0..nHeads-1 I am the head rank of. -1 if I'm not a head rank (no local NIC) + int up[NCCL_MAX_NVLS_ARITY]; + int down; +}; + #define NCCL_MAX_CONNS 2 struct ncclChannelPeer { struct ncclConnector send[NCCL_MAX_CONNS]; @@ -264,6 +275,7 @@ struct alignas(16) ncclDevChannel { struct ncclTree tree; struct ncclTree collnetChain; struct ncclDirect collnetDirect; + struct ncclNvls nvls; uint32_t* workFifoDone; // Location of done counter, device writes index+1 of last work processed }; @@ -288,4 +300,65 @@ struct alignas(16) ncclDevCommAndChannels { struct ncclDevChannel channels[MAXCHANNELS]; }; +#ifdef __CUDA_ARCH__ + #define NCCL_CUDA_ARCH __CUDA_ARCH__ +#else + #define NCCL_CUDA_ARCH 0 +#endif + +template +__host__ __device__ constexpr T min_constexpr(T a) { return a; } +template +__host__ __device__ constexpr T min_constexpr(T a, T b, Ts ...c) { + return min_constexpr((a < b ? a : b), c...); +} + +template +__host__ __device__ constexpr T max_constexpr(T a) { return a; } +template +__host__ __device__ constexpr T max_constexpr(T a, T b, Ts ...c) { + return max_constexpr((a > b ? a : b), c...); +} + +// Calculate the unroll factor given: +// * bytePerPack: number of bytes accessed per instruction +// * insns: max permissible unroll value +// * bytes: desired number of in-flight bytes per iteration ( = unroll*bytePerPack) +__host__ __device__ constexpr int ncclCalcUnroll(int bytePerPack, int insns, int bytes) { + return min_constexpr(insns, (bytes + bytePerPack-1)/bytePerPack); +} + +// Note that all unroll value logic should depend on a given cudaArch argument +// and not __CUDA_ARCH__ since these need to be host-side executable where the +// arch value is strictly runtime only. By defaulting to NCCL_CUDA_ARCH, device +// side code can elide passing the arch for brevity. + +__host__ __device__ constexpr int ncclCollUnroll(int cudaArch = NCCL_CUDA_ARCH) { + // Our collective unroll should move to the same bytes&insns model as NVLS. + return cudaArch >= 800 ? 8 : 4; +} + +__host__ __device__ constexpr int ncclNvlsUnrollBytes(int cudaArch = NCCL_CUDA_ARCH) { return 4*16; } +__host__ __device__ constexpr int ncclNvlsUnrollInsns(int cudaArch = NCCL_CUDA_ARCH) { return 16; } + +__host__ __device__ constexpr int ncclNvlsUnroll(int bytePerPack, int cudaArch = NCCL_CUDA_ARCH) { + return ncclCalcUnroll(bytePerPack, ncclNvlsUnrollInsns(cudaArch), ncclNvlsUnrollBytes(cudaArch)); +} + +// The amount of dynamic shmem per warp +__host__ __device__ constexpr int ncclShmemScratchWarpSize(int cudaArch = NCCL_CUDA_ARCH) { + return (max_constexpr( + /*LL */0, + /*LL128 */(NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE)*sizeof(uint64_t), + /*SIMPLE*/(ncclCollUnroll(cudaArch)*WARP_SIZE + 1)*16, + // NVLS needs an extra 16B to read unaligned data. + /*NVLS */WARP_SIZE*(cudaArch >= 900 ? 
ncclNvlsUnrollBytes(cudaArch) : 0) + 16 + ) + 15) & -16; // pad to 16 bytes +} + +// The amount of dynamic shmem per block +__host__ __device__ constexpr int ncclShmemDynamicSize(int cudaArch = NCCL_CUDA_ARCH) { + return cudaArch < 700 ? 0 : ncclShmemScratchWarpSize(cudaArch)*(NCCL_MAX_NTHREADS/WARP_SIZE); +} + #endif diff --git a/src/include/enqueue.h b/src/include/enqueue.h index 74b7ccd..634f037 100644 --- a/src/include/enqueue.h +++ b/src/include/enqueue.h @@ -15,8 +15,7 @@ #define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64) #define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */ -size_t ncclKernMaxLocalSize(); -ncclResult_t ncclKernSetSharedMemoryCarveout(int carveOut); +ncclResult_t ncclInitKernelsForDevice(int cudaArch, size_t* maxStackSize); ncclResult_t ncclEnqueueCheck(struct ncclInfo* info); ncclResult_t ncclLaunchPrepare(struct ncclComm* comm); ncclResult_t ncclLaunchKernelBefore_NoUncapturedCuda(struct ncclComm* comm, struct ncclKernelPlan* plan); diff --git a/src/include/info.h b/src/include/info.h index a770c32..1ce61f9 100644 --- a/src/include/info.h +++ b/src/include/info.h @@ -24,6 +24,7 @@ typedef enum : uint8_t { ncclPatternTreeUpDown, ncclPatternCollnetChain, ncclPatternCollnetDirect, + ncclPatternNvls, ncclPatternSend, ncclPatternRecv } ncclPattern_t; diff --git a/src/include/ipcsocket.h b/src/include/ipcsocket.h new file mode 100644 index 0000000..700f0bc --- /dev/null +++ b/src/include/ipcsocket.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2016-2023, NVIDIA CORPORATION. All rights reserved. + * + * See COPYRIGHT for license information + */ + +#ifndef NCCL_IPCSOCKET_H +#define NCCL_IPCSOCKET_H + +#include "nccl.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define NCCL_IPC_SOCKNAME_LEN 64 + +struct ncclIpcSocket { + int fd; + char socketName[NCCL_IPC_SOCKNAME_LEN]; + volatile uint32_t* abortFlag; +}; + +ncclResult_t ncclIpcSocketInit(struct ncclIpcSocket *handle, int rank, uint64_t hash, volatile uint32_t* abortFlag); +ncclResult_t ncclIpcSocketClose(struct ncclIpcSocket *handle); + +ncclResult_t ncclIpcSocketRecvFd(struct ncclIpcSocket *handle, int *fd); +ncclResult_t ncclIpcSocketSendFd(struct ncclIpcSocket *handle, const int fd, int rank, uint64_t hash); + +#endif /* NCCL_IPCSOCKET_H */ diff --git a/src/include/nccl_net.h b/src/include/nccl_net.h index 255a44e..a387e66 100644 --- a/src/include/nccl_net.h +++ b/src/include/nccl_net.h @@ -20,7 +20,7 @@ #define NCCL_NET_MAX_REQUESTS 8 typedef enum {NCCL_LOG_NONE=0, NCCL_LOG_VERSION=1, NCCL_LOG_WARN=2, NCCL_LOG_INFO=3, NCCL_LOG_ABORT=4, NCCL_LOG_TRACE=5} ncclDebugLogLevel; -typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCCL_GRAPH=32, NCCL_TUNING=64, NCCL_ENV=128, NCCL_ALLOC=256, NCCL_CALL=512, NCCL_ALL=~0} ncclDebugLogSubSys; +typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCCL_GRAPH=32, NCCL_TUNING=64, NCCL_ENV=128, NCCL_ALLOC=256, NCCL_CALL=512, NCCL_PROXY=1024, NCCL_NVLS=2048, NCCL_ALL=~0} ncclDebugLogSubSys; typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...); diff --git a/src/include/nvtx.h b/src/include/nvtx.h index 2aeb932..ab32ef2 100644 --- a/src/include/nvtx.h +++ b/src/include/nvtx.h @@ -7,12 +7,12 @@ #ifndef NCCL_NVTX_H_ #define NCCL_NVTX_H_ -#include "nvtx3.hpp" +#include "nvtx3/nvtx3.hpp" -#if __cpp_constexpr >= 201304L && 
!defined(NVTX3_RELAXED_CONSTEXPR) -#define NVTX3_RELAXED_CONSTEXPR constexpr +#if __cpp_constexpr >= 201304L && !defined(NVTX3_CONSTEXPR_IF_CPP14) +#define NVTX3_CONSTEXPR_IF_CPP14 constexpr #else -#define NVTX3_RELAXED_CONSTEXPR +#define NVTX3_CONSTEXPR_IF_CPP14 #endif // Define all NCCL-provided static schema IDs here (avoid duplicates). @@ -37,7 +37,7 @@ struct nccl_domain{static constexpr char const* name{"NCCL"};}; class payload_schema { public: - NVTX3_RELAXED_CONSTEXPR explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept + explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept { schema_attr.name = schemaName; schema_attr.entries = entries; @@ -74,11 +74,11 @@ class payload_schema { #define NVTX3_FUNC_WITH_PARAMS(ID, S, P) \ static const payload_schema schema{S, std::extent::value, \ NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, #ID}; \ - static ::nvtx3::v1::registered_string const nvtx3_func_name__{__func__}; \ + static ::nvtx3::v1::registered_string_in const nvtx3_func_name__{__func__}; \ nvtxPayloadData_t nvtx3_bpl__[] = { \ {NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, sizeof(P), &(P)}}; \ - ::nvtx3::v1::event_attributes nvtx3_func_attr__{nvtx3_func_name__, nvtx3_bpl__}; \ - ::nvtx3::v1::domain_thread_range const nvtx3_range__{nvtx3_func_attr__}; + ::nvtx3::v1::event_attributes const nvtx3_func_attr__{nvtx3_func_name__, nvtx3_bpl__}; \ + ::nvtx3::v1::scoped_range_in const nvtx3_range__{nvtx3_func_attr__}; extern void initNvtxRegisteredEnums(); diff --git a/src/include/nvtx3/nvToolsExt.h b/src/include/nvtx3/nvToolsExt.h index ce4b0be..1093838 100644 --- a/src/include/nvtx3/nvToolsExt.h +++ b/src/include/nvtx3/nvToolsExt.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvToolsExtCuda.h b/src/include/nvtx3/nvToolsExtCuda.h index b1e654c..b1b80ad 100644 --- a/src/include/nvtx3/nvToolsExtCuda.h +++ b/src/include/nvtx3/nvToolsExtCuda.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvToolsExtCudaRt.h b/src/include/nvtx3/nvToolsExtCudaRt.h index 002f6e9..1e19958 100644 --- a/src/include/nvtx3/nvToolsExtCudaRt.h +++ b/src/include/nvtx3/nvToolsExtCudaRt.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvToolsExtOpenCL.h b/src/include/nvtx3/nvToolsExtOpenCL.h index 611c0cb..a7b8a19 100644 --- a/src/include/nvtx3/nvToolsExtOpenCL.h +++ b/src/include/nvtx3/nvToolsExtOpenCL.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. 
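/*
 * Illustrative aside (not part of the patch): with the nvtx.h changes above,
 * NVTX3_FUNC_WITH_PARAMS now builds on the renamed NVTX3 C++ wrappers
 * (registered_string_in, event_attributes, scoped_range_in). Roughly, each
 * instrumented entry point does the equivalent of the sketch below; the
 * "my_domain" tag type and message are placeholders, and this assumes
 * nvtx3/nvtx3.hpp is on the include path.
 */
#include "nvtx3/nvtx3.hpp"

struct my_domain { static constexpr char const* name{"NCCL"}; };

void instrumented_function() {
  // Register the function name once (construct on first use) ...
  static ::nvtx3::v1::registered_string_in<my_domain> const name{"instrumented_function"};
  // ... attach it to the event attributes ...
  ::nvtx3::v1::event_attributes const attr{name};
  // ... and open a range tied to this scope; it closes when `range` is destroyed.
  ::nvtx3::v1::scoped_range_in<my_domain> const range{attr};
  /* instrumented work happens here */
}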
* See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvToolsExtPayload.h b/src/include/nvtx3/nvToolsExtPayload.h index 1683f92..a46c833 100644 --- a/src/include/nvtx3/nvToolsExtPayload.h +++ b/src/include/nvtx3/nvToolsExtPayload.h @@ -1,12 +1,12 @@ /* -* Copyright 2021 NVIDIA Corporation. All rights reserved. +* Copyright 2021-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception */ -#include "nvtx3/nvToolsExt.h" +#include "nvToolsExt.h" #ifndef NVTOOLSEXT_PAYLOAD_H #define NVTOOLSEXT_PAYLOAD_H diff --git a/src/include/nvtx3/nvToolsExtSync.h b/src/include/nvtx3/nvToolsExtSync.h index 5d24729..113fcd1 100644 --- a/src/include/nvtx3/nvToolsExtSync.h +++ b/src/include/nvtx3/nvToolsExtSync.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3.hpp b/src/include/nvtx3/nvtx3.hpp similarity index 51% rename from src/include/nvtx3.hpp rename to src/include/nvtx3/nvtx3.hpp index 353fddf..cb0ef68 100644 --- a/src/include/nvtx3.hpp +++ b/src/include/nvtx3/nvtx3.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,15 +20,15 @@ /* This section handles the decision of whether to provide unversioned symbols. * If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is #defined, unversioned symbols are - * not provided, and explicit-version symbols such as nvtx3::v1::thread_range + * not provided, and explicit-version symbols such as nvtx3::v1::scoped_range * and NVTX3_V1_FUNC_RANGE must be used. By default, the first #include of this - * header will define the unversioned symbols such as nvtx3::thread_range and + * header will define the unversioned symbols such as nvtx3::scoped_range and * NVTX3_FUNC_RANGE. Subsequently including a different major version of this * header without #defining NVTX3_CPP_REQUIRE_EXPLICIT_VERSION triggers an error * since the symbols would conflict. Subsequently including of a different * minor version within the same major version is allowed. Functionality of * minor versions is cumulative, regardless of include order. - * + * * Since NVTX3_CPP_REQUIRE_EXPLICIT_VERSION allows all combinations of versions * to coexist without problems within a translation unit, the recommended best * practice for instrumenting header-based libraries with NVTX C++ Wrappers is @@ -39,66 +39,58 @@ */ /* clang-format off */ #if !defined(NVTX3_CPP_REQUIRE_EXPLICIT_VERSION) - /* Define macro used by all definitions in this header to indicate the - * unversioned symbols should be defined in addition to the versioned ones. + /* Define macro used by all definitions in this header to indicate the + * unversioned symbols should be defined in addition to the versioned ones. + */ + #define NVTX3_INLINE_THIS_VERSION + + #if !defined(NVTX3_CPP_INLINED_VERSION_MAJOR) + /* First occurrence of this header in the translation unit. Define macros + * indicating which version shall be used for unversioned symbols. 
*/ - #define NVTX3_INLINE_THIS_VERSION - #if !defined(NVTX3_CPP_INLINED_VERSION_MAJOR) - /* First occurrence of this header in the translation unit. Define macros - * indicating which version shall be used for unversioned symbols. - */ + /** + * @brief Semantic major version number for NVTX C++ wrappers of unversioned symbols + * + * Breaking changes may occur between major versions, and different major versions + * cannot provide unversioned symbols in the same translation unit (.cpp file). + * + * Note: If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is defined, this macro is not defined. + * + * Not to be confused with the version number of the NVTX core library. + */ + #define NVTX3_CPP_INLINED_VERSION_MAJOR 1 // NVTX3_CPP_VERSION_MAJOR - /** - * @brief Semantic major version number for NVTX C++ wrappers of unversioned symbols - * - * Breaking changes may occur between major versions, and different major versions - * cannot provide unversioned symbols in the same translation unit (.cpp file). - * - * Note: If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is defined, this macro is not defined. - * - * Not to be confused with the version number of the NVTX core library. - */ - #define NVTX3_CPP_INLINED_VERSION_MAJOR 1 // NVTX3_CPP_VERSION_MAJOR - - /** - * @brief Semantic minor version number for NVTX C++ wrappers of unversioned symbols - * - * No breaking changes occur between minor versions -- minor version changes within - * a major version are purely additive. - * - * Note: If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is defined, this macro is not defined. - * - * Not to be confused with the version number of the NVTX core library. - */ - #define NVTX3_CPP_INLINED_VERSION_MINOR 0 // NVTX3_CPP_VERSION_MINOR - #elif NVTX3_CPP_INLINED_VERSION_MAJOR != NVTX3_CPP_VERSION_MAJOR - /* Unsupported case -- cannot define unversioned symbols for different major versions - * in the same translation unit. - */ - #error \ - "Two different major versions of the NVTX C++ Wrappers are being included in a single .cpp file, with unversioned symbols enabled in both. Only one major version can enable unversioned symbols in a .cpp file. To disable unversioned symbols, #define NVTX3_CPP_REQUIRE_EXPLICIT_VERSION before #including nvtx3.hpp, and use the explicit-version symbols instead -- this is the preferred way to use nvtx3.hpp from a header file." - #elif (NVTX3_CPP_INLINED_VERSION_MAJOR == NVTX3_CPP_VERSION_MAJOR) && \ - (NVTX3_CPP_INLINED_VERSION_MINOR < NVTX3_CPP_VERSION_MINOR) - /* An older minor version of the same major version already defined unversioned - * symbols. The new features provided in this header will be inlined - * redefine the minor version macro to this header's version. - */ - #undef NVTX3_CPP_INLINED_VERSION_MINOR - #define NVTX3_CPP_INLINED_VERSION_MINOR 0 // NVTX3_CPP_VERSION_MINOR - // else, already have this version or newer, nothing to do - #endif + /** + * @brief Semantic minor version number for NVTX C++ wrappers of unversioned symbols + * + * No breaking changes occur between minor versions -- minor version changes within + * a major version are purely additive. + * + * Note: If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is defined, this macro is not defined. + * + * Not to be confused with the version number of the NVTX core library. + */ + #define NVTX3_CPP_INLINED_VERSION_MINOR 0 // NVTX3_CPP_VERSION_MINOR + #elif NVTX3_CPP_INLINED_VERSION_MAJOR != NVTX3_CPP_VERSION_MAJOR + /* Unsupported case -- cannot define unversioned symbols for different major versions + * in the same translation unit. 
+ */ + #error \ + "Two different major versions of the NVTX C++ Wrappers are being included in a single .cpp file, with unversioned symbols enabled in both. Only one major version can enable unversioned symbols in a .cpp file. To disable unversioned symbols, #define NVTX3_CPP_REQUIRE_EXPLICIT_VERSION before #including nvtx3.hpp, and use the explicit-version symbols instead -- this is the preferred way to use nvtx3.hpp from a header file." + #elif (NVTX3_CPP_INLINED_VERSION_MAJOR == NVTX3_CPP_VERSION_MAJOR) && \ + (NVTX3_CPP_INLINED_VERSION_MINOR < NVTX3_CPP_VERSION_MINOR) + /* An older minor version of the same major version already defined unversioned + * symbols. The new features provided in this header will be inlined + * redefine the minor version macro to this header's version. + */ + #undef NVTX3_CPP_INLINED_VERSION_MINOR + #define NVTX3_CPP_INLINED_VERSION_MINOR 0 // NVTX3_CPP_VERSION_MINOR + // else, already have this version or newer, nothing to do + #endif #endif /* clang-format on */ -#include -#include - -#include -#include -#include -#include - /** * @file nvtx3.hpp * @@ -112,19 +104,19 @@ * * \section QUICK_START Quick Start * - * To add NVTX ranges to your code, use the `nvtx3::thread_range` RAII object. A + * To add NVTX ranges to your code, use the `nvtx3::scoped_range` RAII object. A * range begins when the object is created, and ends when the object is * destroyed. * * \code{.cpp} * #include "nvtx3.hpp" - * void some_function(){ + * void some_function() { * // Begins a NVTX range with the messsage "some_function" * // The range ends when some_function() returns and `r` is destroyed - * nvtx3::thread_range r{"some_function"}; + * nvtx3::scoped_range r{"some_function"}; * - * for(int i = 0; i < 6; ++i){ - * nvtx3::thread_range loop{"loop range"}; + * for(int i = 0; i < 6; ++i) { + * nvtx3::scoped_range loop{"loop range"}; * std::this_thread::sleep_for(std::chrono::seconds{1}); * } * } // Range ends when `r` is destroyed @@ -142,10 +134,9 @@ * * \code{.cpp} * #include "nvtx3.hpp" - * void some_function(){ + * void some_function() { * // Creates a range with a message "some_function" that ends when the - * enclosing - * // function returns + * // enclosing function returns * NVTX3_FUNC_RANGE(); * ... * } @@ -165,66 +156,66 @@ * be accomplished with an NVTX range created on the entry to the function and * terminated on return from `my_function` using the push/pop C APIs: * - * ``` - * void my_function(...){ + * \code{.cpp} + * void my_function(...) { * nvtxRangePushA("my_function"); // Begins NVTX range * // do work * nvtxRangePop(); // Ends NVTX range * } - * ``` + * \endcode * * One of the challenges with using the NVTX C API is that it requires manually * terminating the end of the range with `nvtxRangePop`. This can be challenging * if `my_function()` has multiple returns or can throw exceptions as it * requires calling `nvtxRangePop()` before all possible return points. * - * NVTX++ solves this inconvenience through the "RAII" technique by providing a - * `nvtx3::thread_range` class that begins a range at construction and ends the - * range on destruction. The above example then becomes: + * NVTX C++ solves this inconvenience through the "RAII" technique by providing + * a `nvtx3::scoped_range` class that begins a range at construction and ends + * the range on destruction. The above example then becomes: * - * ``` - * void my_function(...){ - * nvtx3::thread_range r{"my_function"}; // Begins NVTX range + * \code{.cpp} + * void my_function(...) 
{ + * nvtx3::scoped_range r{"my_function"}; // Begins NVTX range * // do work * } // Range ends on exit from `my_function` when `r` is destroyed - * ``` + * \endcode * * The range object `r` is deterministically destroyed whenever `my_function` * returns---ending the NVTX range without manual intervention. For more - * information, see \ref RANGES and `nvtx3::domain_thread_range`. + * information, see \ref RANGES and `nvtx3::scoped_range_in`. * * Another inconvenience of the NVTX C APIs are the several constructs where the * user is expected to initialize an object at the beginning of an application * and reuse that object throughout the lifetime of the application. For example - * Domains, Categories, and Registered messages. + * see domains, categories, and registered messages. * * Example: - * ``` + * \code{.cpp} * nvtxDomainHandle_t D = nvtxDomainCreateA("my domain"); * // Reuse `D` throughout the rest of the application - * ``` + * \endcode * * This can be problematic if the user application or library does not have an * explicit initialization function called before all other functions to * ensure that these long-lived objects are initialized before being used. * - * NVTX++ makes use of the "construct on first use" technique to alleviate this - * inconvenience. In short, a function local static object is constructed upon - * the first invocation of a function and returns a reference to that object on - * all future invocations. See the documentation for - * `nvtx3::registered_string`, `nvtx3::domain`, `nvtx3::named_category`, and + * NVTX C++ makes use of the "construct on first use" technique to alleviate + * this inconvenience. In short, a function local static object is constructed + * upon the first invocation of a function and returns a reference to that + * object on all future invocations. See the documentation for `nvtx3::domain`, + * `nvtx3::named_category`, `nvtx3::registered_string`, and * https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use for more * information. * * Using construct on first use, the above example becomes: - * ``` + * \code{.cpp} * struct my_domain{ static constexpr char const* name{"my domain"}; }; * * // The first invocation of `domain::get` for the type `my_domain` will * // construct a `nvtx3::domain` object and return a reference to it. Future * // invocations simply return a reference. * nvtx3::domain const& D = nvtx3::domain::get(); - * ``` + * \endcode * For more information about NVTX and how it can be used, see * https://docs.nvidia.com/cuda/profiler-users-guide/index.html#nvtx and * https://devblogs.nvidia.com/cuda-pro-tip-generate-custom-application-profile-timelines-nvtx/ @@ -236,106 +227,108 @@ * application. Common examples are using ranges to annotate the time it takes * to execute a function or an iteration of a loop. * - * NVTX++ uses RAII to automate the generation of ranges that are tied to the + * NVTX C++ uses RAII to automate the generation of ranges that are tied to the * lifetime of objects. Similar to `std::lock_guard` in the C++ Standard * Template Library. * - * \subsection THREAD_RANGE Thread Range + * \subsection scoped_range Scoped Range * - * `nvtx3::domain_thread_range` is a class that begins a range upon construction + * `nvtx3::scoped_range_in` is a class that begins a range upon construction * and ends the range at destruction. 
This is one of the most commonly used - * constructs in NVTX++ and is useful for annotating spans of time on a + * constructs in NVTX C++ and is useful for annotating spans of time on a * particular thread. These ranges can be nested to arbitrary depths. * - * `nvtx3::thread_range` is an alias for a `nvtx3::domain_thread_range` in the + * `nvtx3::scoped_range` is an alias for a `nvtx3::scoped_range_in` in the * global NVTX domain. For more information about Domains, see \ref DOMAINS. * * Various attributes of a range can be configured constructing a - * `nvtx3::domain_thread_range` with a `nvtx3::event_attributes` object. For + * `nvtx3::scoped_range_in` with a `nvtx3::event_attributes` object. For * more information, see \ref ATTRIBUTES. * * Example: * * \code{.cpp} - * void some_function(){ + * void some_function() { * // Creates a range for the duration of `some_function` - * nvtx3::thread_range r{}; + * nvtx3::scoped_range r{}; * - * while(true){ + * while(true) { * // Creates a range for every loop iteration * // `loop_range` is nested inside `r` - * nvtx3::thread_range loop_range{}; + * nvtx3::scoped_range loop_range{}; * } * } * \endcode * - * \subsection PROCESS_RANGE Process Range + * \subsection unique_range Unique Range * - * `nvtx3::domain_process_range` is identical to `nvtx3::domain_thread_range` - * with the exception that a `domain_process_range` can be created and destroyed - * on different threads. This is useful to annotate spans of time that can - * bridge multiple threads. + * `nvtx3::unique_range` is similar to `nvtx3::scoped_range`, with a few key differences: + * - `unique_range` objects can be destroyed in any order whereas `scoped_range` objects must be + * destroyed in exact reverse creation order + * - `unique_range` can start and end on different threads + * - `unique_range` is moveable + * - `unique_range` objects can be constructed as heap objects * - * `nvtx3::domain_thread_range`s should be preferred unless one needs the - * ability to begin and end a range on different threads. + * There is extra overhead associated with `unique_range` constructs and therefore use of + * `nvtx3::scoped_range_in` should be preferred. * * \section MARKS Marks * - * `nvtx3::mark` allows annotating an instantaneous event in an application's - * timeline. For example, indicating when a mutex is locked or unlocked. + * `nvtx3::mark` annotates an instantaneous point in time with a "marker". + * + * Unlike a "range" which has a beginning and an end, a marker is a single event + * in an application, such as detecting a problem: * * \code{.cpp} - * std::mutex global_lock; - * void lock_mutex(){ - * global_lock.lock(); - * // Marks an event immediately after the mutex is locked - * nvtx3::mark("lock_mutex"); + * bool success = do_operation(...); + * if (!success) { + * nvtx3::mark("operation failed!"); * } * \endcode * * \section DOMAINS Domains * - * Similar to C++ namespaces, Domains allow for scoping NVTX events. By default, + * Similar to C++ namespaces, domains allow for scoping NVTX events. By default, * all NVTX events belong to the "global" domain. Libraries and applications * should scope their events to use a custom domain to differentiate where the * events originate from. * * It is common for a library or application to have only a single domain and * for the name of that domain to be known at compile time. Therefore, Domains - * in NVTX++ are represented by _tag types_. + * in NVTX C++ are represented by _tag types_. 
* - * For example, to define a custom domain, simply define a new concrete type + * For example, to define a custom domain, simply define a new concrete type * (a `class` or `struct`) with a `static` member called `name` that contains * the desired name of the domain. * - * ``` + * \code{.cpp} * struct my_domain{ static constexpr char const* name{"my domain"}; }; - * ``` + * \endcode * - * For any NVTX++ construct that can be scoped to a domain, the type `my_domain` - * can be passed as an explicit template argument to scope it to the custom - * domain. + * For any NVTX C++ construct that can be scoped to a domain, the type + * `my_domain` can be passed as an explicit template argument to scope it to + * the custom domain. * * The tag type `nvtx3::domain::global` represents the global NVTX domain. * * \code{.cpp} - * // By default, `domain_thread_range` belongs to the global domain - * nvtx3::domain_thread_range<> r0{}; + * // By default, `scoped_range_in` belongs to the global domain + * nvtx3::scoped_range_in<> r0{}; * - * // Alias for a `domain_thread_range` in the global domain - * nvtx3::thread_range r1{}; + * // Alias for a `scoped_range_in` in the global domain + * nvtx3::scoped_range r1{}; * * // `r` belongs to the custom domain - * nvtx3::domain_thread_range r{}; + * nvtx3::scoped_range_in r{}; * \endcode * - * When using a custom domain, it is reccomended to define type aliases for NVTX + * When using a custom domain, it is recommended to define type aliases for NVTX * constructs in the custom domain. - * ``` - * using my_thread_range = nvtx3::domain_thread_range; - * using my_registered_string = nvtx3::registered_string; - * using my_named_category = nvtx3::named_category; - * ``` + * \code{.cpp} + * using my_scoped_range = nvtx3::scoped_range_in; + * using my_registered_string = nvtx3::registered_string_in; + * using my_named_category = nvtx3::named_category_in; + * \endcode * * See `nvtx3::domain` for more information. * @@ -359,35 +352,41 @@ * information. 
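A short sketch of the alias pattern recommended a few lines above, with the domain template arguments written out (assuming the same illustrative my_domain tag type):

#include "nvtx3/nvtx3.hpp"

struct my_domain { static constexpr char const* name{"my domain"}; };

// Domain-scoped aliases keep call sites short.
using my_scoped_range      = nvtx3::scoped_range_in<my_domain>;
using my_registered_string = nvtx3::registered_string_in<my_domain>;
using my_named_category    = nvtx3::named_category_in<my_domain>;

void worker()
{
  my_scoped_range r{"worker"};   // grouped under "my domain" by profiling tools
}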
* * \code{.cpp} - * // Custom color, message - * event_attributes attr{nvtx3::rgb{127, 255, 0}, - * "message"}; + * // Set message, same as passing nvtx3::message{"message"} + * nvtx3::event_attributes attr{"message"}; * - * // Custom color, message, payload, category - * event_attributes attr{nvtx3::rgb{127, 255, 0}, - * nvtx3::payload{42}, - * "message", - * nvtx3::category{1}}; + * // Set message and color + * nvtx3::event_attributes attr{"message", nvtx3::rgb{127, 255, 0}}; * - * // Arguments can be in any order - * event_attributes attr{nvtx3::payload{42}, - * nvtx3::category{1}, - * "message", - * nvtx3::rgb{127, 255, 0}}; + * // Set message, color, payload, category + * nvtx3::event_attributes attr{"message", + * nvtx3::rgb{127, 255, 0}, + * nvtx3::payload{42}, + * nvtx3::category{1}}; * - * // "First wins" with multiple arguments of the same type - * event_attributes attr{ nvtx3::payload{42}, nvtx3::payload{7} }; // payload is - * 42 \endcode + * // Same as above -- can use any order of arguments + * nvtx3::event_attributes attr{nvtx3::payload{42}, + * nvtx3::category{1}, + * "message", + * nvtx3::rgb{127, 255, 0}}; + * + * // Multiple arguments of the same type are allowed, but only the first is + * // used -- in this example, payload is set to 42: + * nvtx3::event_attributes attr{ nvtx3::payload{42}, nvtx3::payload{7} }; + * + * // Using the nvtx3 namespace in a local scope makes the syntax more succinct: + * using namespace nvtx3; + * event_attributes attr{"message", rgb{127, 255, 0}, payload{42}, category{1}}; + * \endcode * * \subsection MESSAGES message * - * A `nvtx3::message` allows associating a custom message string with an NVTX - * event. + * `nvtx3::message` sets the message string for an NVTX event. * * Example: * \code{.cpp} - * // Create an `event_attributes` with the custom message "my message" - * nvtx3::event_attributes attr{nvtx3::Mesage{"my message"}}; + * // Create an `event_attributes` with the message "my message" + * nvtx3::event_attributes attr{nvtx3::message{"my message"}}; * * // strings and string literals implicitly assumed to be a `nvtx3::message` * nvtx3::event_attributes attr{"my message"}; @@ -415,8 +414,8 @@ * * Example: * \code{.cpp} - * // Explicitly constructed, static `registered_string` - * static registered_string static_message{"my message"}; + * // Explicitly constructed, static `registered_string` in my_domain: + * static registered_string_in static_message{"my message"}; * * // Or use construct on first use: * // Define a tag type with a `message` member string to register @@ -424,8 +423,8 @@ * * // Uses construct on first use to register the contents of * // `my_message::message` - * nvtx3::registered_string const& msg = - * nvtx3::registered_string::get(); \endcode + * auto& msg = nvtx3::registered_string_in::get(); + * \endcode * * \subsection COLOR color * @@ -466,34 +465,32 @@ * custom tag type with static `name` and `id` members. 
* * \code{.cpp} - * // Explicitly constructed, static `named_category` - * static nvtx3::named_category static_category{42, "my category"}; + * // Explicitly constructed, static `named_category` in my_domain: + * static nvtx3::named_category_in static_category{42, "my category"}; * - * // OR use construct on first use: + * // Or use construct on first use: * // Define a tag type with `name` and `id` members - * struct my_category{ + * struct my_category { * static constexpr char const* name{"my category"}; // category name - * static constexpr category::id_type id{42}; // category id + * static constexpr uint32_t id{42}; // category id * }; * * // Use construct on first use to name the category id `42` - * // with name "my category" - * nvtx3::named_category const& my_category = - * named_category::get(); + * // with name "my category": + * auto& cat = named_category_in::get(); * * // Range `r` associated with category id `42` - * nvtx3::event_attributes attr{my_category}; + * nvtx3::event_attributes attr{cat}; * \endcode * * \subsection PAYLOAD payload * * Allows associating a user-defined numerical value with an event. * - * ``` - * nvtx3:: event_attributes attr{nvtx3::payload{42}}; // Constructs a payload - * from - * // the `int32_t` value 42 - * ``` + * \code{.cpp} + * // Constructs a payload from the `int32_t` value 42 + * nvtx3:: event_attributes attr{nvtx3::payload{42}}; + * \endcode * * * \section EXAMPLE Example @@ -513,34 +510,33 @@ * struct my_message{ static constexpr char const* message{"my message"}; }; * * // For convenience, use aliases for domain scoped objects - * using my_thread_range = nvtx3::domain_thread_range; - * using my_registered_string = nvtx3::registered_string; - * using my_named_category = nvtx3::named_category; + * using my_scoped_range = nvtx3::scoped_range_in; + * using my_registered_string = nvtx3::registered_string_in; + * using my_named_category = nvtx3::named_category_in; * * // Default values for all attributes * nvtx3::event_attributes attr{}; - * my_thread_range r0{attr}; + * my_scoped_range r0{attr}; * * // Custom (unregistered) message, and unnamed category * nvtx3::event_attributes attr1{"message", nvtx3::category{2}}; - * my_thread_range r1{attr1}; + * my_scoped_range r1{attr1}; * * // Alternatively, pass arguments of `event_attributes` ctor directly to - * // `my_thread_range` - * my_thread_range r2{"message", nvtx3::category{2}}; + * // `my_scoped_range` + * my_scoped_range r2{"message", nvtx3::category{2}}; * * // construct on first use a registered string - * auto msg = my_registered_string::get(); + * auto& msg = my_registered_string::get(); * * // construct on first use a named category - * auto category = my_named_category::get(); + * auto& cat = my_named_category::get(); * - * // Use registered string and named category - * my_thread_range r3{msg, category, nvtx3::rgb{127, 255, 0}, - * nvtx3::payload{42}}; + * // Use registered string and named category with a custom payload + * my_scoped_range r3{msg, cat, nvtx3::payload{42}}; * * // Any number of arguments in any order - * my_thread_range r{nvtx3::rgb{127, 255,0}, msg}; + * my_scoped_range r{nvtx3::rgb{127, 255,0}, msg}; * * \endcode * \section MACROS Convenience Macros @@ -550,11 +546,11 @@ * * A convenient way to do this is to use the \ref NVTX3_FUNC_RANGE and * \ref NVTX3_FUNC_RANGE_IN macros. 
These macros take care of constructing an - * `nvtx3::domain_thread_range` with the name of the enclosing function as the + * `nvtx3::scoped_range_in` with the name of the enclosing function as the * range's message. * * \code{.cpp} - * void some_function(){ + * void some_function() { * // Automatically generates an NVTX range for the duration of the function * // using "some_function" as the event's message. * NVTX3_FUNC_RANGE(); @@ -565,6 +561,25 @@ /* Temporary helper #defines, removed with #undef at end of header */ +#if !defined(NVTX3_USE_CHECKED_OVERLOADS_FOR_GET) +#if defined(_MSC_VER) && _MSC_VER < 1914 +/* Microsoft's compiler prior to VS2017 Update 7 (15.7) uses an older parser + * that does not work with domain::get's specialization for domain::global, + * and would require extra conditions to make SFINAE work for the overloaded + * get() functions. This macro disables use of overloaded get() in order to + * work with VS2015 and versions of VS2017 below 15.7, without penalizing + * users of newer compilers. Building with this flag set to 0 means errors + * when defining tag structs (see documentation for domain, named_category, + * and registered_string) will have more complex compiler error messages + * instead of the clear static_assert messages from the get() overloads. + */ +#define NVTX3_USE_CHECKED_OVERLOADS_FOR_GET 0 +#else +#define NVTX3_USE_CHECKED_OVERLOADS_FOR_GET 1 +#endif +#define NVTX3_USE_CHECKED_OVERLOADS_FOR_GET_DEFINED_HERE +#endif + /* Within this header, nvtx3::NVTX3_VERSION_NAMESPACE resolves to nvtx3::vX, * where "X" is the major version number. */ #define NVTX3_CONCAT(A, B) A##B @@ -580,18 +595,30 @@ #define NVTX3_INLINE_IF_REQUESTED #endif -/* Enables the use of constexpr when support for C++14 relaxed constexpr - * is present. +/* Enables the use of constexpr when support for C++14 constexpr is present. * - * Initializing a legacy-C (i.e., no constructor) union member requires - * initializing in the constructor body. Non-empty constexpr constructors - * require C++14 relaxed constexpr. In strict C++11 compilation, fall back - * to using non-constexpr constructors for classes with union members. + * Initialization of a class member that is a union to a specific union member + * can only be done in the body of a constructor, not in a member initializer + * list. A constexpr constructor must have an empty body until C++14, so there + * is no way to make an initializer of a member union constexpr in C++11. This + * macro allows making functions constexpr in C++14 or newer, but non-constexpr + * in C++11 compilation. It is used here on constructors that initialize their + * member unions. */ #if __cpp_constexpr >= 201304L -#define NVTX3_RELAXED_CONSTEXPR constexpr +#define NVTX3_CONSTEXPR_IF_CPP14 constexpr #else -#define NVTX3_RELAXED_CONSTEXPR +#define NVTX3_CONSTEXPR_IF_CPP14 +#endif + + /* Use a macro for static asserts, which defaults to static_assert, but that + * testing tools can replace with a logging function. 
For example: + * #define NVTX3_STATIC_ASSERT(c, m) \ + * do { if (!(c)) printf("static_assert would fail: %s\n", m); } while (0) + */ +#if !defined(NVTX3_STATIC_ASSERT) +#define NVTX3_STATIC_ASSERT(condition, message) static_assert(condition, message); +#define NVTX3_STATIC_ASSERT_DEFINED_HERE #endif /* Implementation sections, enclosed in guard macros for each minor version */ @@ -599,6 +626,15 @@ #ifndef NVTX3_CPP_DEFINITIONS_V1_0 #define NVTX3_CPP_DEFINITIONS_V1_0 +#include "nvToolsExt.h" +#include "nvToolsExtPayload.h" + +#include +#include +#include +#include +#include + namespace nvtx3 { NVTX3_INLINE_IF_REQUESTED namespace NVTX3_VERSION_NAMESPACE @@ -606,20 +642,35 @@ NVTX3_INLINE_IF_REQUESTED namespace NVTX3_VERSION_NAMESPACE namespace detail { -/** - * @brief Verifies if a type `T` contains a member `T::name` of type `const - * char*` or `const wchar_t*`. - * - * @tparam T The type to verify - * @return True if `T` contains a member `T::name` of type `const char*` or - * `const wchar_t*`. - */ +template +struct always_false : std::false_type {}; + +template +struct has_name : std::false_type {}; template -constexpr auto has_name_member() noexcept -> decltype(T::name, bool()) -{ - return (std::is_same::type>::value || - std::is_same::type>::value); -} +struct has_name : std::true_type {}; + +template +struct has_id : std::false_type {}; +template +struct has_id : std::true_type {}; + +template +struct has_message : std::false_type {}; +template +struct has_message : std::true_type {}; + +template +struct is_c_string : std::false_type {}; +template +struct is_c_string::value || + std::is_convertible::value +>::type> : std::true_type {}; + +template +using is_uint32 = std::is_same::type, uint32_t>; + } // namespace detail /** @@ -634,7 +685,7 @@ constexpr auto has_name_member() noexcept -> decltype(T::name, bool()) * `domain`s are expected to be long-lived and unique to a library or * application. As such, it is assumed a domain's name is known at compile * time. Therefore, all NVTX constructs that can be associated with a domain - * require the domain to be specified via a *type* `DomainName` passed as an + * require the domain to be specified via a *type* `D` passed as an * explicit template parameter. * * The type `domain::global` may be used to indicate that the global NVTX @@ -642,109 +693,46 @@ constexpr auto has_name_member() noexcept -> decltype(T::name, bool()) * * None of the C++ NVTX constructs require the user to manually construct a * `domain` object. Instead, if a custom domain is desired, the user is - * expected to define a type `DomainName` that contains a member - * `DomainName::name` which resolves to either a `char const*` or `wchar_t - * const*`. The value of `DomainName::name` is used to name and uniquely + * expected to define a type `D` that contains a member + * `D::name` which resolves to either a `char const*` or `wchar_t + * const*`. The value of `D::name` is used to name and uniquely * identify the custom domain. * * Upon the first use of an NVTX construct associated with the type - * `DomainName`, the "construct on first use" pattern is used to construct a + * `D`, the "construct on first use" pattern is used to construct a * function local static `domain` object. All future NVTX constructs - * associated with `DomainType` will use a reference to the previously + * associated with `D` will use a reference to the previously * constructed `domain` object. See `domain::get`. 
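The detail traits added in the hunk above are what drive the new checked get() overloads. A rough, self-contained sketch of the detection idiom they rely on (reconstructed for illustration, not copied verbatim from the header):

#include <type_traits>

// Primary template: assume no `name` member.
template <typename T, typename = void>
struct has_name : std::false_type {};

// Partial specialization chosen only when `T::name` is a valid expression.
template <typename T>
struct has_name<T, decltype((void)T::name, void())> : std::true_type {};

struct with_name    { static constexpr char const* name{"x"}; };
struct without_name {};

static_assert( has_name<with_name>::value,    "name member detected");
static_assert(!has_name<without_name>::value, "no name member");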
* * Example: - * ``` + * \code{.cpp} * // The type `my_domain` defines a `name` member used to name and identify - * the - * // `domain` object identified by `my_domain`. + * // the `domain` object identified by `my_domain`. * struct my_domain{ static constexpr char const* name{"my_domain"}; }; * * // The NVTX range `r` will be grouped with all other NVTX constructs * // associated with `my_domain`. - * nvtx3::domain_thread_range r{}; + * nvtx3::scoped_range_in r{}; * - * // An alias can be created for a `domain_thread_range` in the custom domain - * using my_thread_range = nvtx3::domain_thread_range; - * my_thread_range my_range{}; + * // An alias can be created for a `scoped_range_in` in the custom domain + * using my_scoped_range = nvtx3::scoped_range_in; + * my_scoped_range my_range{}; * * // `domain::global` indicates that the global NVTX domain is used - * nvtx3::domain_thread_range r2{}; + * nvtx3::scoped_range_in r2{}; * - * // For convenience, `nvtx3::thread_range` is an alias for a range in the + * // For convenience, `nvtx3::scoped_range` is an alias for a range in the * // global domain - * nvtx3::thread_range r3{}; - * ``` + * nvtx3::scoped_range r3{}; + * \endcode */ class domain { public: domain(domain const&) = delete; domain& operator=(domain const&) = delete; - domain(domain&&) = delete; + domain(domain&&) = delete; domain& operator=(domain&&) = delete; - /** - * @brief Returns reference to an instance of a function local static - * `domain` object. - * - * Uses the "construct on first use" idiom to safely ensure the `domain` - * object is initialized exactly once upon first invocation of - * `domain::get()`. All following invocations will return a - * reference to the previously constructed `domain` object. See - * https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use - * - * None of the constructs in this header require the user to directly invoke - * `domain::get`. It is automatically invoked when constructing objects like - * a `domain_thread_range` or `category`. Advanced users may wish to use - * `domain::get` for the convenience of the "construct on first use" idiom - * when using domains with their own use of the NVTX C API. - * - * This function is threadsafe as of C++11. If two or more threads call - * `domain::get` concurrently, exactly one of them is guaranteed - * to construct the `domain` object and the other(s) will receive a - * reference to the object after it is fully constructed. - * - * The domain's name is specified via the type `DomainName` pass as an - * explicit template parameter. `DomainName` is required to contain a - * member `DomainName::name` that resolves to either a `char const*` or - * `wchar_t const*`. The value of `DomainName::name` is used to name and - * uniquely identify the `domain`. - * - * Example: - * ``` - * // The type `my_domain` defines a `name` member used to name and identify - * // the `domain` object identified by `my_domain`. - * struct my_domain{ static constexpr char const* name{"my domain"}; }; - * - * auto D = domain::get(); // First invocation constructs a - * // `domain` with the name "my domain" - * - * auto D1 = domain::get(); // Simply returns reference to - * // previously constructed `domain`. - * ``` - * - * @tparam DomainName Type that contains a `DomainName::name` member used to - * name the `domain` object. - * @return Reference to the `domain` corresponding to the type `DomainName`. 
- */ - template - static domain const& get() - { - static_assert(detail::has_name_member(), - "Type used to identify a domain must contain a name member of" - "type const char* or const wchar_t*"); - static domain const d{DomainName::name}; - return d; - } - - /** - * @brief Conversion operator to `nvtxDomainHandle_t`. - * - * Allows transparently passing a domain object into an API expecting a - * native `nvtxDomainHandle_t` object. - */ - operator nvtxDomainHandle_t() const noexcept { return _domain; } - /** * @brief Tag type for the "global" NVTX domain. * @@ -759,6 +747,113 @@ class domain { struct global { }; +#if NVTX3_USE_CHECKED_OVERLOADS_FOR_GET + /** + * @brief Returns reference to an instance of a function local static + * `domain` object. + * + * Uses the "construct on first use" idiom to safely ensure the `domain` + * object is initialized exactly once upon first invocation of + * `domain::get()`. All following invocations will return a + * reference to the previously constructed `domain` object. See + * https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use + * + * None of the constructs in this header require the user to directly invoke + * `domain::get`. It is automatically invoked when constructing objects like + * a `scoped_range_in` or `category`. Advanced users may wish to use + * `domain::get` for the convenience of the "construct on first use" idiom + * when using domains with their own use of the NVTX C API. + * + * This function is threadsafe as of C++11. If two or more threads call + * `domain::get` concurrently, exactly one of them is guaranteed + * to construct the `domain` object and the other(s) will receive a + * reference to the object after it is fully constructed. + * + * The domain's name is specified via the type `D` pass as an + * explicit template parameter. `D` is required to contain a + * member `D::name` that resolves to either a `char const*` or + * `wchar_t const*`. The value of `D::name` is used to name and + * uniquely identify the `domain`. + * + * Example: + * \code{.cpp} + * // The type `my_domain` defines a `name` member used to name and identify + * // the `domain` object identified by `my_domain`. + * struct my_domain{ static constexpr char const* name{"my domain"}; }; + * + * auto& D1 = domain::get(); // First invocation constructs a + * // `domain` with the name "my domain" + * + * auto& D2 = domain::get(); // Quickly returns reference to + * // previously constructed `domain`. + * \endcode + * + * @tparam D Type that contains a `D::name` member used to + * name the `domain` object. + * @return Reference to the `domain` corresponding to the type `D`. + */ + template ::value + , int>::type = 0> + static domain const& get() noexcept + { + static domain const d(D::name); + return d; + } + + /** + * @brief Overload of `domain::get` to provide a clear compile error when + * `D` has a `name` member that is not directly convertible to either + * `char const*` or `wchar_t const*`. 
+ */ + template ::value + , int>::type = 0> + static domain const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::always_false::value, + "Type used to identify an NVTX domain must contain a static constexpr member " + "called 'name' of type const char* or const wchar_t* -- 'name' member is not " + "convertible to either of those types"); + static domain const unused; + return unused; // Function must compile for static_assert to be triggered + } + + /** + * @brief Overload of `domain::get` to provide a clear compile error when + * `D` does not have a `name` member. + */ + template ::value + , int>::type = 0> + static domain const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::always_false::value, + "Type used to identify an NVTX domain must contain a static constexpr member " + "called 'name' of type const char* or const wchar_t* -- 'name' member is missing"); + static domain const unused; + return unused; // Function must compile for static_assert to be triggered + } +#else + template + static domain const& get() noexcept + { + static domain const d(D::name); + return d; + } +#endif + + /** + * @brief Conversion operator to `nvtxDomainHandle_t`. + * + * Allows transparently passing a domain object into an API expecting a + * native `nvtxDomainHandle_t` object. + */ + operator nvtxDomainHandle_t() const noexcept { return _domain; } + private: /** * @brief Construct a new domain with the specified `name`. @@ -808,7 +903,7 @@ class domain { * "global" NVTX domain. * */ - domain() = default; + domain() noexcept {} /** * @brief Intentionally avoid calling nvtxDomainDestroy on the `domain` object. @@ -844,15 +939,15 @@ class domain { * */ template <> -inline domain const& domain::get() +inline domain const& domain::get() noexcept { static domain const d{}; return d; } /** - * @brief Indicates the values of the red, green, blue color channels for - * a rgb color code. + * @brief Indicates the values of the red, green, and blue color channels for + * an RGB color to use as an event attribute (assumes no transparency). * */ struct rgb { @@ -869,19 +964,22 @@ struct rgb { * @param green_ Value of the green channel * @param blue_ Value of the blue channel */ - constexpr rgb(component_type red_, component_type green_, component_type blue_) noexcept + constexpr rgb( + component_type red_, + component_type green_, + component_type blue_) noexcept : red{red_}, green{green_}, blue{blue_} { } - component_type const red{}; ///< Red channel value - component_type const green{}; ///< Green channel value - component_type const blue{}; ///< Blue channel value + component_type red{}; ///< Red channel value + component_type green{}; ///< Green channel value + component_type blue{}; ///< Blue channel value }; /** * @brief Indicates the value of the alpha, red, green, and blue color - * channels for an argb color code. + * channels for an ARGB color to use as an event attribute. 
* */ struct argb final : rgb { @@ -897,15 +995,16 @@ struct argb final : rgb { * @param blue_ Value of the blue channel * */ - constexpr argb(component_type alpha_, - component_type red_, - component_type green_, - component_type blue_) noexcept + constexpr argb( + component_type alpha_, + component_type red_, + component_type green_, + component_type blue_) noexcept : rgb{red_, green_, blue_}, alpha{alpha_} { } - component_type const alpha{}; ///< Alpha channel value + component_type alpha{}; ///< Alpha channel value }; /** @@ -947,8 +1046,8 @@ class color { * * @param argb The alpha, red, green, blue components of the desired `color` */ - constexpr color(argb argb) noexcept - : color{from_bytes_msb_to_lsb(argb.alpha, argb.red, argb.green, argb.blue)} + constexpr color(argb argb_) noexcept + : color{from_bytes_msb_to_lsb(argb_.alpha, argb_.red, argb_.green, argb_.blue)} { } @@ -960,8 +1059,8 @@ class color { * * @param rgb The red, green, blue components of the desired `color` */ - constexpr color(rgb rgb) noexcept - : color{from_bytes_msb_to_lsb(0xFF, rgb.red, rgb.green, rgb.blue)} + constexpr color(rgb rgb_) noexcept + : color{from_bytes_msb_to_lsb(0xFF, rgb_.red, rgb_.green, rgb_.blue)} { } @@ -977,11 +1076,11 @@ class color { */ constexpr nvtxColorType_t get_type() const noexcept { return _type; } - color() = delete; - ~color() = default; + color() = delete; + ~color() = default; color(color const&) = default; color& operator=(color const&) = default; - color(color&&) = default; + color(color&&) = default; color& operator=(color&&) = default; private: @@ -990,16 +1089,17 @@ class color { * most to least significant byte order. * */ - constexpr static value_type from_bytes_msb_to_lsb(uint8_t byte3, - uint8_t byte2, - uint8_t byte1, - uint8_t byte0) noexcept + constexpr static value_type from_bytes_msb_to_lsb( + uint8_t byte3, + uint8_t byte2, + uint8_t byte1, + uint8_t byte0) noexcept { return uint32_t{byte3} << 24 | uint32_t{byte2} << 16 | uint32_t{byte1} << 8 | uint32_t{byte0}; } - value_type const _value{}; ///< color's argb color code - nvtxColorType_t const _type{NVTX_COLOR_ARGB}; ///< NVTX color type code + value_type _value{}; ///< color's argb color code + nvtxColorType_t _type{NVTX_COLOR_ARGB}; ///< NVTX color type code }; /** @@ -1014,10 +1114,10 @@ class color { * nvtx3::category cat1{1}; * * // Range `r1` belongs to the category identified by the value `1`. - * nvtx3::thread_range r1{cat1}; + * nvtx3::scoped_range r1{cat1}; * * // Range `r2` belongs to the same category as `r1` - * nvtx3::thread_range r2{nvtx3::category{1}}; + * nvtx3::scoped_range r2{nvtx3::category{1}}; * \endcode * * To associate a name string with a category id, see `named_category`. @@ -1033,7 +1133,7 @@ class category { * * The `category` will be unnamed and identified only by its `id` value. * - * All `category` objects sharing the same `id` are equivalent. + * All `category`s in a domain sharing the same `id` are equivalent. 
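A brief usage sketch for the color types above, together with the category id wrapper being defined here (values are arbitrary, chosen for illustration):

#include "nvtx3/nvtx3.hpp"

void annotate()
{
  // rgb{} yields an opaque color (alpha is fixed at 0xFF on conversion to color).
  nvtx3::event_attributes green{"compute", nvtx3::rgb{127, 255, 0}};
  nvtx3::scoped_range r0{green};

  // argb{} additionally lets the alpha channel be specified.
  nvtx3::event_attributes translucent{"background", nvtx3::argb{128, 127, 255, 0}};
  nvtx3::scoped_range r1{translucent};

  // Ranges sharing the same unnamed category id are grouped together by tools.
  nvtx3::scoped_range r2{nvtx3::category{1}, "categorized work"};
}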
* * @param[in] id The `category`'s identifying value */ @@ -1045,15 +1145,15 @@ class category { */ constexpr id_type get_id() const noexcept { return id_; } - category() = delete; - ~category() = default; + category() = delete; + ~category() = default; category(category const&) = default; category& operator=(category const&) = default; - category(category&&) = default; + category(category&&) = default; category& operator=(category&&) = default; private: - id_type const id_{}; ///< category's unique identifier + id_type id_{}; ///< category's unique identifier }; /** @@ -1075,45 +1175,46 @@ class category { * * Example: * \code{.cpp} - * // Explicitly constructed, static `named_category` + * // Explicitly constructed, static `named_category` in global domain: * static nvtx3::named_category static_category{42, "my category"}; * * // Range `r` associated with category id `42` - * nvtx3::thread_range r{static_category}; + * nvtx3::scoped_range r{static_category}; * * // OR use construct on first use: * * // Define a type with `name` and `id` members - * struct my_category{ + * struct my_category { * static constexpr char const* name{"my category"}; // category name - * static constexpr category::id_type id{42}; // category id + * static constexpr uint32_t id{42}; // category id * }; * * // Use construct on first use to name the category id `42` * // with name "my category" - * auto my_category = named_category::get(); + * auto& cat = named_category_in::get(); * * // Range `r` associated with category id `42` - * nvtx3::thread_range r{my_category}; + * nvtx3::scoped_range r{cat}; * \endcode * - * `named_category`'s association of a name to a category id is local to the - * domain specified by the type `D`. An id may have a different name in + * `named_category_in`'s association of a name to a category id is local to + * the domain specified by the type `D`. An id may have a different name in * another domain. * * @tparam D Type containing `name` member used to identify the `domain` to - * which the `named_category` belongs. Else, `domain::global` to indicate + * which the `named_category_in` belongs. Else, `domain::global` to indicate * that the global NVTX domain should be used. */ template -class named_category final : public category { +class named_category_in final : public category { public: +#if NVTX3_USE_CHECKED_OVERLOADS_FOR_GET /** - * @brief Returns a global instance of a `named_category` as a + * @brief Returns a global instance of a `named_category_in` as a * function-local static. * - * Creates a `named_category` with name and id specified by the contents of - * a type `C`. `C::name` determines the name and `C::id` determines the + * Creates a `named_category_in` with name and id specified by the contents + * of a type `C`. `C::name` determines the name and `C::id` determines the * category id. 
* * This function is useful for constructing a named `category` exactly once @@ -1122,36 +1223,97 @@ class named_category final : public category { * Example: * \code{.cpp} * // Define a type with `name` and `id` members - * struct my_category{ + * struct my_category { * static constexpr char const* name{"my category"}; // category name * static constexpr uint32_t id{42}; // category id * }; * * // Use construct on first use to name the category id `42` * // with name "my category" - * auto cat = named_category::get(); + * auto& cat = named_category_in::get(); * * // Range `r` associated with category id `42` - * nvtx3::thread_range r{cat}; + * nvtx3::scoped_range r{cat}; * \endcode * * Uses the "construct on first use" idiom to safely ensure the `category` * object is initialized exactly once. See * https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use * - * @tparam C Type containing a member `C::name` that resolves to either a + * @tparam C Type containing a member `C::name` that resolves to either a * `char const*` or `wchar_t const*` and `C::id`. */ - template - static named_category const& get() noexcept + template ::value && + detail::is_uint32::value + , int>::type = 0> + static named_category_in const& get() noexcept { - static_assert(detail::has_name_member(), - "Type used to name a category must contain a name member."); - static named_category const category{C::id, C::name}; - return category; + static named_category_in const cat(C::id, C::name); + return cat; } + /** - * @brief Construct a `category` with the specified `id` and `name`. + * @brief Overload of `named_category_in::get` to provide a clear compile error + * when `C` has the required `name` and `id` members, but they are not the + * required types. `name` must be directly convertible to `char const*` or + * `wchar_t const*`, and `id` must be `uint32_t`. + */ + template ::value || + !detail::is_uint32::value + , int>::type = 0> + static named_category_in const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::is_c_string::value, + "Type used to name an NVTX category must contain a static constexpr member " + "called 'name' of type const char* or const wchar_t* -- 'name' member is not " + "convertible to either of those types"); + NVTX3_STATIC_ASSERT(detail::is_uint32::value, + "Type used to name an NVTX category must contain a static constexpr member " + "called 'id' of type uint32_t -- 'id' member is the wrong type"); + static named_category_in const unused; + return unused; // Function must compile for static_assert to be triggered + } + + /** + * @brief Overload of `named_category_in::get` to provide a clear compile error + * when `C` does not have the required `name` and `id` members. 
+ */ + template ::value || + !detail::has_id::value + , int>::type = 0> + static named_category_in const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::has_name::value, + "Type used to name an NVTX category must contain a static constexpr member " + "called 'name' of type const char* or const wchar_t* -- 'name' member is missing"); + NVTX3_STATIC_ASSERT(detail::has_id::value, + "Type used to name an NVTX category must contain a static constexpr member " + "called 'id' of type uint32_t -- 'id' member is missing"); + static named_category_in const unused; + return unused; // Function must compile for static_assert to be triggered + } +#else + template + static named_category_in const& get() noexcept + { + static named_category_in const cat(C::id, C::name); + return cat; + } +#endif + + private: + // Default constructor is only used internally for static_assert(false) cases. + named_category_in() noexcept : category{0} {} + + public: + /** + * @brief Construct a `named_category_in` with the specified `id` and `name`. * * The name `name` will be registered with `id`. * @@ -1160,7 +1322,7 @@ class named_category final : public category { * @param[in] id The category id to name * @param[in] name The name to associated with `id` */ - named_category(id_type id, char const* name) noexcept : category{id} + named_category_in(id_type id, char const* name) noexcept : category{id} { #ifndef NVTX_DISABLE nvtxDomainNameCategoryA(domain::get(), get_id(), name); @@ -1171,7 +1333,7 @@ class named_category final : public category { }; /** - * @brief Construct a `category` with the specified `id` and `name`. + * @brief Construct a `named_category_in` with the specified `id` and `name`. * * The name `name` will be registered with `id`. * @@ -1180,7 +1342,7 @@ class named_category final : public category { * @param[in] id The category id to name * @param[in] name The name to associated with `id` */ - named_category(id_type id, wchar_t const* name) noexcept : category{id} + named_category_in(id_type id, wchar_t const* name) noexcept : category{id} { #ifndef NVTX_DISABLE nvtxDomainNameCategoryW(domain::get(), get_id(), name); @@ -1191,6 +1353,12 @@ class named_category final : public category { }; }; +/** + * @brief Alias for a `named_category_in` in the global NVTX domain. + * + */ +using named_category = named_category_in; + /** * @brief A message registered with NVTX. * @@ -1205,16 +1373,16 @@ class named_category final : public category { * * A particular message should only be registered once and the handle * reused throughout the rest of the application. This can be done by either - * explicitly creating static `registered_string` objects, or using the - * `registered_string::get` construct on first use helper (recommended). + * explicitly creating static `registered_string_in` objects, or using the + * `registered_string_in::get` construct on first use helper (recommended). 
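Ahead of the example that follows, the same registration pattern with the <my_domain> template arguments written out (a sketch; my_domain and my_message are illustrative tag types):

#include "nvtx3/nvtx3.hpp"

struct my_domain  { static constexpr char const* name{"my domain"}; };
struct my_message { static constexpr char const* message{"my message"}; };

void annotate_hot_path()
{
  // Construct on first use: the string is registered with NVTX only once.
  auto& msg = nvtx3::registered_string_in<my_domain>::get<my_message>();

  // Passing the registered handle avoids re-sending the string on every event.
  nvtx3::scoped_range_in<my_domain> r{msg};
}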
* * Example: * \code{.cpp} - * // Explicitly constructed, static `registered_string` - * static registered_string static_message{"message"}; + * // Explicitly constructed, static `registered_string` in my_domain: + * static registered_string_in static_message{"message"}; * * // "message" is associated with the range `r` - * nvtx3::thread_range r{static_message}; + * nvtx3::scoped_range r{static_message}; * * // Or use construct on first use: * @@ -1224,30 +1392,31 @@ class named_category final : public category { * * // Uses construct on first use to register the contents of * // `my_message::message` - * auto msg = registered_string::get(); + * auto& msg = registered_string_in::get(); * * // "my message" is associated with the range `r` - * nvtx3::thread_range r{msg}; + * nvtx3::scoped_range r{msg}; * \endcode * - * `registered_string`s are local to a particular domain specified via + * `registered_string_in`s are local to a particular domain specified via * the type `D`. * * @tparam D Type containing `name` member used to identify the `domain` to - * which the `registered_string` belongs. Else, `domain::global` to indicate + * which the `registered_string_in` belongs. Else, `domain::global` to indicate * that the global NVTX domain should be used. */ template -class registered_string { +class registered_string_in { public: +#if NVTX3_USE_CHECKED_OVERLOADS_FOR_GET /** - * @brief Returns a global instance of a `registered_string` as a function + * @brief Returns a global instance of a `registered_string_in` as a function * local static. * * Provides a convenient way to register a message with NVTX without having * to explicitly register the message. * - * Upon first invocation, constructs a `registered_string` whose contents + * Upon first invocation, constructs a `registered_string_in` whose contents * are specified by `message::message`. * * All future invocations will return a reference to the object constructed @@ -1262,26 +1431,74 @@ class registered_string { * * // Uses construct on first use to register the contents of * // `my_message::message` - * auto msg = registered_string::get(); + * auto& msg = registered_string_in::get(); * * // "my message" is associated with the range `r` - * nvtx3::thread_range r{msg}; + * nvtx3::scoped_range r{msg}; * \endcode * * @tparam M Type required to contain a member `M::message` that * resolves to either a `char const*` or `wchar_t const*` used as the * registered string's contents. - * @return Reference to a `registered_string` associated with the type `M`. + * @return Reference to a `registered_string_in` associated with the type `M`. */ - template - static registered_string const& get() noexcept + template ::value + , int>::type = 0> + static registered_string_in const& get() noexcept { - static registered_string const registered_string{M::message}; - return registered_string; + static registered_string_in const regstr(M::message); + return regstr; } /** - * @brief Constructs a `registered_string` from the specified `msg` string. + * @brief Overload of `registered_string_in::get` to provide a clear compile error + * when `M` has a `message` member that is not directly convertible to either + * `char const*` or `wchar_t const*`. 
+ */ + template ::value + , int>::type = 0> + static registered_string_in const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::always_false::value, + "Type used to register an NVTX string must contain a static constexpr member " + "called 'message' of type const char* or const wchar_t* -- 'message' member is " + "not convertible to either of those types"); + static registered_string_in const unused; + return unused; // Function must compile for static_assert to be triggered + } + + /** + * @brief Overload of `registered_string_in::get` to provide a clear compile error when + * `M` does not have a `message` member. + */ + template ::value + , int>::type = 0> + static registered_string_in const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::always_false::value, + "Type used to register an NVTX string must contain a static constexpr member " + "called 'message' of type const char* or const wchar_t* -- 'message' member " + "is missing"); + static registered_string_in const unused; + return unused; // Function must compile for static_assert to be triggered + } +#else + template + static registered_string_in const& get() noexcept + { + static registered_string_in const regstr(M::message); + return regstr; + } +#endif + + /** + * @brief Constructs a `registered_string_in` from the specified `msg` string. * * Registers `msg` with NVTX and associates a handle with the registered * message. @@ -1291,13 +1508,13 @@ class registered_string { * * @param msg The contents of the message */ - explicit registered_string(char const* msg) noexcept + explicit registered_string_in(char const* msg) noexcept : handle_{nvtxDomainRegisterStringA(domain::get(), msg)} { } /** - * @brief Constructs a `registered_string` from the specified `msg` string. + * @brief Constructs a `registered_string_in` from the specified `msg` string. * * Registers `msg` with NVTX and associates a handle with the registered * message. @@ -1307,10 +1524,11 @@ class registered_string { * * @param msg The contents of the message */ - explicit registered_string(std::string const& msg) noexcept : registered_string{msg.c_str()} {} + explicit registered_string_in(std::string const& msg) noexcept + : registered_string_in{msg.c_str()} {} /** - * @brief Constructs a `registered_string` from the specified `msg` string. + * @brief Constructs a `registered_string_in` from the specified `msg` string. * * Registers `msg` with NVTX and associates a handle with the registered * message. @@ -1320,13 +1538,13 @@ class registered_string { * * @param msg The contents of the message */ - explicit registered_string(wchar_t const* msg) noexcept + explicit registered_string_in(wchar_t const* msg) noexcept : handle_{nvtxDomainRegisterStringW(domain::get(), msg)} { } /** - * @brief Constructs a `registered_string` from the specified `msg` string. + * @brief Constructs a `registered_string_in` from the specified `msg` string. * * Registers `msg` with NVTX and associates a handle with the registered * message. 
@@ -1336,7 +1554,8 @@ class registered_string { * * @param msg The contents of the message */ - explicit registered_string(std::wstring const& msg) noexcept : registered_string{msg.c_str()} {} + explicit registered_string_in(std::wstring const& msg) noexcept + : registered_string_in{msg.c_str()} {} /** * @brief Returns the registered string's handle @@ -1344,18 +1563,27 @@ class registered_string { */ nvtxStringHandle_t get_handle() const noexcept { return handle_; } - registered_string() = delete; - ~registered_string() = default; - registered_string(registered_string const&) = default; - registered_string& operator=(registered_string const&) = default; - registered_string(registered_string&&) = default; - registered_string& operator=(registered_string&&) = default; +private: + // Default constructor is only used internally for static_assert(false) cases. + registered_string_in() noexcept {}; +public: + ~registered_string_in() = default; + registered_string_in(registered_string_in const&) = default; + registered_string_in& operator=(registered_string_in const&) = default; + registered_string_in(registered_string_in&&) = default; + registered_string_in& operator=(registered_string_in&&) = default; private: - nvtxStringHandle_t const handle_{}; ///< The handle returned from - ///< registering the message with NVTX + nvtxStringHandle_t handle_{}; ///< The handle returned from + ///< registering the message with NVTX }; +/** + * @brief Alias for a `registered_string_in` in the global NVTX domain. + * + */ +using registered_string = registered_string_in; + /** * @brief Allows associating a message string with an NVTX event via * its `EventAttribute`s. @@ -1374,7 +1602,7 @@ class registered_string { * nvtx3::event_attributes attr0{nvtx3::message{"message 0"}}; * * // `range0` contains message "message 0" - * nvtx3::thread_range range0{attr0}; + * nvtx3::scoped_range range0{attr0}; * * // `std::string` and string literals are implicitly assumed to be * // the contents of an `nvtx3::message` @@ -1382,15 +1610,15 @@ class registered_string { * nvtx3::event_attributes attr1{"message 1"}; * * // `range1` contains message "message 1" - * nvtx3::thread_range range1{attr1}; + * nvtx3::scoped_range range1{attr1}; * * // `range2` contains message "message 2" - * nvtx3::thread_range range2{nvtx3::Mesage{"message 2"}}; + * nvtx3::scoped_range range2{nvtx3::Mesage{"message 2"}}; * * // `std::string` and string literals are implicitly assumed to be * // the contents of an `nvtx3::message` * // `range3` contains message "message 3" - * nvtx3::thread_range range3{"message 3"}; + * nvtx3::scoped_range range3{"message 3"}; * \endcode */ class message { @@ -1402,7 +1630,7 @@ class message { * * @param msg The contents of the message */ - NVTX3_RELAXED_CONSTEXPR message(char const* msg) noexcept : type_{NVTX_MESSAGE_TYPE_ASCII} + NVTX3_CONSTEXPR_IF_CPP14 message(char const* msg) noexcept : type_{NVTX_MESSAGE_TYPE_ASCII} { value_.ascii = msg; } @@ -1429,7 +1657,7 @@ class message { * * @param msg The contents of the message */ - NVTX3_RELAXED_CONSTEXPR message(wchar_t const* msg) noexcept : type_{NVTX_MESSAGE_TYPE_UNICODE} + NVTX3_CONSTEXPR_IF_CPP14 message(wchar_t const* msg) noexcept : type_{NVTX_MESSAGE_TYPE_UNICODE} { value_.unicode = msg; } @@ -1452,35 +1680,59 @@ class message { message(std::wstring&&) = delete; /** - * @brief Construct a `message` from a `registered_string`. + * @brief Construct a `message` from a `registered_string_in`. 
* * @tparam D Type containing `name` member used to identify the `domain` - * to which the `registered_string` belongs. Else, `domain::global` to + * to which the `registered_string_in` belongs. Else, `domain::global` to * indicate that the global NVTX domain should be used. * @param msg The message that has already been registered with NVTX. */ template - NVTX3_RELAXED_CONSTEXPR message(registered_string const& msg) noexcept + NVTX3_CONSTEXPR_IF_CPP14 message(registered_string_in const& msg) noexcept : type_{NVTX_MESSAGE_TYPE_REGISTERED} { value_.registered = msg.get_handle(); } + /** + * @brief Construct a `message` from NVTX C API type and value. + * + * @param type nvtxMessageType_t enum value indicating type of the payload + * @param value nvtxMessageValue_t union containing message + */ + constexpr message( + nvtxMessageType_t const& type, + nvtxMessageValue_t const& value) noexcept + : type_{type}, value_(value) + { + } + + /** + * @brief Construct a `message` from NVTX C API registered string handle. + * + * @param handle nvtxStringHandle_t value of registered string handle + */ + NVTX3_CONSTEXPR_IF_CPP14 message(nvtxStringHandle_t handle) noexcept + : type_{NVTX_MESSAGE_TYPE_REGISTERED} + { + value_.registered = handle; + } + /** * @brief Return the union holding the value of the message. * */ - NVTX3_RELAXED_CONSTEXPR value_type get_value() const noexcept { return value_; } + constexpr value_type get_value() const noexcept { return value_; } /** * @brief Return the type information about the value the union holds. * */ - NVTX3_RELAXED_CONSTEXPR nvtxMessageType_t get_type() const noexcept { return type_; } + constexpr nvtxMessageType_t get_type() const noexcept { return type_; } private: - nvtxMessageType_t const type_{}; ///< message type - nvtxMessageValue_t value_{}; ///< message contents + nvtxMessageType_t type_{}; ///< message type + nvtxMessageValue_t value_{}; ///< message contents }; /** @@ -1488,17 +1740,16 @@ class message { * its `event_attributes`. 
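Complementing the example that follows, a small sketch of how the payload constructors defined below select the NVTX payload type from the C++ type of the argument (values are arbitrary):

#include <cstdint>
#include "nvtx3/nvtx3.hpp"

void tag_values()
{
  nvtx3::scoped_range r0{nvtx3::payload{int64_t{42}}};   // NVTX_PAYLOAD_TYPE_INT64
  nvtx3::scoped_range r1{nvtx3::payload{3.14}};          // NVTX_PAYLOAD_TYPE_DOUBLE
  nvtx3::scoped_range r2{nvtx3::payload{uint32_t{7}}};   // NVTX_PAYLOAD_TYPE_UNSIGNED_INT32
}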
* * Example: - * ``` - * nvtx3:: event_attributes attr{nvtx3::payload{42}}; // Constructs a payload - * from - * // the `int32_t` value 42 + * \code{.cpp} + * // Constructs a payload from the int32_t value 42 + * nvtx3:: event_attributes attr{nvtx3::payload{42}}; * * // `range0` will have an int32_t payload of 42 - * nvtx3::thread_range range0{attr}; + * nvtx3::scoped_range range0{attr}; * * // range1 has double payload of 3.14 - * nvtx3::thread_range range1{ nvtx3::payload{3.14} }; - * ``` + * nvtx3::scoped_range range1{nvtx3::payload{3.14}}; + * \endcode */ class payload { public: @@ -1509,7 +1760,7 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(int64_t value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(int64_t value) noexcept : type_{NVTX_PAYLOAD_TYPE_INT64}, value_{} { value_.llValue = value; @@ -1520,7 +1771,7 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(int32_t value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(int32_t value) noexcept : type_{NVTX_PAYLOAD_TYPE_INT32}, value_{} { value_.iValue = value; @@ -1531,7 +1782,7 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(uint64_t value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(uint64_t value) noexcept : type_{NVTX_PAYLOAD_TYPE_UNSIGNED_INT64}, value_{} { value_.ullValue = value; @@ -1542,7 +1793,7 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(uint32_t value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(uint32_t value) noexcept : type_{NVTX_PAYLOAD_TYPE_UNSIGNED_INT32}, value_{} { value_.uiValue = value; @@ -1554,7 +1805,7 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(float value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(float value) noexcept : type_{NVTX_PAYLOAD_TYPE_FLOAT}, value_{} { value_.fValue = value; @@ -1566,27 +1817,40 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(double value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(double value) noexcept : type_{NVTX_PAYLOAD_TYPE_DOUBLE}, value_{} { value_.dValue = value; } + /** + * @brief Construct a `payload` from NVTX C API type and value. + * + * @param type nvtxPayloadType_t enum value indicating type of the payload + * @param value nvtxEventAttributes_t::payload_t union containing payload + */ + constexpr payload( + nvtxPayloadType_t const& type, + value_type const& value) noexcept + : type_{type}, value_(value) + { + } + /** * @brief Return the union holding the value of the payload * */ - NVTX3_RELAXED_CONSTEXPR value_type get_value() const noexcept { return value_; } + constexpr value_type get_value() const noexcept { return value_; } /** * @brief Return the information about the type the union holds. 
* */ - NVTX3_RELAXED_CONSTEXPR nvtxPayloadType_t get_type() const noexcept { return type_; } + constexpr nvtxPayloadType_t get_type() const noexcept { return type_; } private: - nvtxPayloadType_t const type_; ///< Type of the payload value - value_type value_; ///< Union holding the payload value + nvtxPayloadType_t type_; ///< Type of the payload value + value_type value_; ///< Union holding the payload value }; /** @@ -1611,42 +1875,39 @@ class payload { * * Example: * \code{.cpp} - * event_attributes attr{}; // No arguments, use defaults for all attributes + * // Set message, same as using nvtx3::message{"message"} + * event_attributes attr{"message"}; * - * event_attributes attr{"message"}; // Custom message, rest defaulted - * - * // Custom color & message + * // Set message and color * event_attributes attr{"message", nvtx3::rgb{127, 255, 0}}; * - * /// Custom color & message, can use any order of arguments - * event_attributes attr{nvtx3::rgb{127, 255, 0}, "message"}; + * // Set message, color, payload, category + * event_attributes attr{"message", + * nvtx3::rgb{127, 255, 0}, + * nvtx3::payload{42}, + * nvtx3::category{1}}; * - * - * // Custom color, message, payload, category - * event_attributes attr{nvtx3::rgb{127, 255, 0}, - * "message", - * nvtx3::payload{42}, - * nvtx3::category{1}}; - * - * // Custom color, message, payload, category, can use any order of arguments + * // Same as above -- can use any order of arguments * event_attributes attr{nvtx3::payload{42}, - * nvtx3::category{1}, - * "message", - * nvtx3::rgb{127, 255, 0}}; + * nvtx3::category{1}, + * "message", + * nvtx3::rgb{127, 255, 0}}; * * // Multiple arguments of the same type are allowed, but only the first is - * // used. All others are ignored - * event_attributes attr{ nvtx3::payload{42}, nvtx3::payload{7} }; // payload - * is 42 + * // used -- in this example, payload is set to 42: + * event_attributes attr{ nvtx3::payload{42}, nvtx3::payload{7} }; * * // Range `r` will be customized according the attributes in `attr` - * nvtx3::thread_range r{attr}; + * nvtx3::scoped_range r{attr}; * - * // For convenience, the arguments that can be passed to the - * `event_attributes` - * // constructor may be passed to the `domain_thread_range` contructor where - * // they will be forwarded to the `EventAttribute`s constructor - * nvtx3::thread_range r{nvtx3::payload{42}, nvtx3::category{1}, "message"}; + * // For convenience, `event_attributes` constructor arguments may be passed + * // to the `scoped_range_in` contructor -- they are forwarded to the + * // `event_attributes` constructor + * nvtx3::scoped_range r{nvtx3::payload{42}, nvtx3::category{1}, "message"}; + * + * // Using the nvtx3 namespace in a local scope makes the syntax more succinct: + * using namespace nvtx3; + * scoped_range r{payload{42}, category{1}, "message"}; * \endcode * */ @@ -1682,7 +1943,7 @@ class event_attributes { * */ template - NVTX3_RELAXED_CONSTEXPR explicit event_attributes(category const& c, Args const&... args) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit event_attributes(category const& c, Args const&... args) noexcept : event_attributes(args...) { attributes_.category = c.get_id(); @@ -1696,7 +1957,7 @@ class event_attributes { * */ template - NVTX3_RELAXED_CONSTEXPR explicit event_attributes(color const& c, Args const&... args) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit event_attributes(color const& c, Args const&... args) noexcept : event_attributes(args...) 
{ attributes_.color = c.get_value(); @@ -1711,7 +1972,7 @@ class event_attributes { * */ template - NVTX3_RELAXED_CONSTEXPR explicit event_attributes(payload const& p, Args const&... args) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit event_attributes(payload const& p, Args const&... args) noexcept : event_attributes(args...) { attributes_.payload = p.get_value(); @@ -1726,14 +1987,14 @@ class event_attributes { * */ template - NVTX3_RELAXED_CONSTEXPR explicit event_attributes(message const& m, Args const&... args) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit event_attributes(message const& m, Args const&... args) noexcept : event_attributes(args...) { attributes_.message = m.get_value(); attributes_.messageType = m.get_type(); } - /** + /** * @brief Variadic constructor where the first argument is a binary payload. * * Sets the value of the `EventAttribute`s message based on `m` and forwards @@ -1741,7 +2002,7 @@ class event_attributes { * */ template - NVTX3_RELAXED_CONSTEXPR explicit event_attributes(nvtxPayloadData_t const* bpl, Args const&... args) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit event_attributes(nvtxPayloadData_t const* bpl, Args const&... args) noexcept : event_attributes(args...) { attributes_.payloadType = NVTX_PAYLOAD_TYPE_BINARY; @@ -1749,10 +2010,10 @@ class event_attributes { attributes_.payload.ullValue = NVTX_POINTER_AS_PAYLOAD_ULLVALUE(bpl); } - ~event_attributes() = default; + ~event_attributes() = default; event_attributes(event_attributes const&) = default; event_attributes& operator=(event_attributes const&) = default; - event_attributes(event_attributes&&) = default; + event_attributes(event_attributes&&) = default; event_attributes& operator=(event_attributes&&) = default; /** @@ -1772,16 +2033,16 @@ class event_attributes { * When constructed, begins a nested NVTX range on the calling thread in the * specified domain. Upon destruction, ends the NVTX range. * - * Behavior is undefined if a `domain_thread_range` object is + * Behavior is undefined if a `scoped_range_in` object is * created/destroyed on different threads. * - * `domain_thread_range` is neither moveable nor copyable. + * `scoped_range_in` is neither moveable nor copyable. * - * `domain_thread_range`s may be nested within other ranges. + * `scoped_range_in`s may be nested within other ranges. * * The domain of the range is specified by the template type parameter `D`. * By default, the `domain::global` is used, which scopes the range to the - * global NVTX domain. The convenience alias `thread_range` is provided for + * global NVTX domain. The convenience alias `scoped_range` is provided for * ranges scoped to the global domain. * * A custom domain can be defined by creating a type, `D`, with a static @@ -1789,48 +2050,47 @@ class event_attributes { * `D`. `D::name` must resolve to either `char const*` or `wchar_t const*` * * Example: - * ``` + * \code{.cpp} * // Define a type `my_domain` with a member `name` used to name the domain * // associated with the type `my_domain`. 
* struct my_domain{ - * static constexpr const char * name{"my domain"}; + * static constexpr char const* name{"my domain"}; * }; - * ``` + * \endcode * * Usage: - * ``` - * nvtx3::domain_thread_range<> r0{"range 0"}; // Range in global domain + * \code{.cpp} + * nvtx3::scoped_range_in r1{"range 1"}; // Range in my domain * - * nvtx3::thread_range r1{"range 1"}; // Alias for range in global domain + * // Three equivalent ways to make a range in the global domain: + * nvtx3::scoped_range_in r2{"range 2"}; + * nvtx3::scoped_range_in<> r3{"range 3"}; + * nvtx3::scoped_range r4{"range 4"}; * - * nvtx3::domain_thread_range r2{"range 2"}; // Range in custom - * domain + * // Create an alias to succinctly make ranges in my domain: + * using my_scoped_range = nvtx3::scoped_range_in; * - * // specify an alias to a range that uses a custom domain - * using my_thread_range = nvtx3::domain_thread_range; - * - * my_thread_range r3{"range 3"}; // Alias for range in custom domain - * ``` + * my_scoped_range r3{"range 3"}; + * \endcode */ template -class domain_thread_range { +class scoped_range_in { public: /** - * @brief Construct a `domain_thread_range` with the specified + * @brief Construct a `scoped_range_in` with the specified * `event_attributes` * * Example: - * ``` + * \code{cpp} * nvtx3::event_attributes attr{"msg", nvtx3::rgb{127,255,0}}; - * nvtx3::domain_thread_range<> range{attr}; // Creates a range with message - * contents - * // "msg" and green color - * ``` + * nvtx3::scoped_range range{attr}; // Creates a range with message contents + * // "msg" and green color + * \endcode * * @param[in] attr `event_attributes` that describes the desired attributes * of the range. */ - explicit domain_thread_range(event_attributes const& attr) noexcept + explicit scoped_range_in(event_attributes const& attr) noexcept { #ifndef NVTX_DISABLE nvtxDomainRangePushEx(domain::get(), attr.get()); @@ -1840,65 +2100,55 @@ class domain_thread_range { } /** - * @brief Constructs a `domain_thread_range` from the constructor arguments + * @brief Constructs a `scoped_range_in` from the constructor arguments * of an `event_attributes`. * - * Forwards the arguments `first, args...` to construct an + * Forwards the arguments `args...` to construct an * `event_attributes` object. The `event_attributes` object is then - * associated with the `domain_thread_range`. + * associated with the `scoped_range_in`. * * For more detail, see `event_attributes` documentation. * * Example: - * ``` + * \code{cpp} * // Creates a range with message "message" and green color - * nvtx3::domain_thread_range<> r{"message", nvtx3::rgb{127,255,0}}; - * ``` + * nvtx3::scoped_range r{"message", nvtx3::rgb{127,255,0}}; + * \endcode * - * @note To prevent making needless copies of `event_attributes` objects, - * this constructor is disabled when the first argument is an - * `event_attributes` object, instead preferring the explicit - * `domain_thread_range(event_attributes const&)` constructor. - * - * @param[in] first First argument to forward to the `event_attributes` - * constructor. - * @param[in] args Variadic parameter pack of additional arguments to - * forward. + * @param[in] args Arguments to used to construct an `event_attributes` associated with this + * range. * */ - template >::value>> - explicit domain_thread_range(First const& first, Args const&... args) noexcept - : domain_thread_range{event_attributes{first, args...}} + template + explicit scoped_range_in(Args const&... 
args) noexcept + : scoped_range_in{event_attributes{args...}} { } /** - * @brief Default constructor creates a `domain_thread_range` with no + * @brief Default constructor creates a `scoped_range_in` with no * message, color, payload, nor category. * */ - domain_thread_range() : domain_thread_range{event_attributes{}} {} + scoped_range_in() noexcept : scoped_range_in{event_attributes{}} {} /** * @brief Delete `operator new` to disallow heap allocated objects. * - * `domain_thread_range` must follow RAII semantics to guarantee proper push/pop semantics. + * `scoped_range_in` must follow RAII semantics to guarantee proper push/pop semantics. * */ void* operator new(std::size_t) = delete; - domain_thread_range(domain_thread_range const&) = delete; - domain_thread_range& operator=(domain_thread_range const&) = delete; - domain_thread_range(domain_thread_range&&) = delete; - domain_thread_range& operator=(domain_thread_range&&) = delete; + scoped_range_in(scoped_range_in const&) = delete; + scoped_range_in& operator=(scoped_range_in const&) = delete; + scoped_range_in(scoped_range_in&&) = delete; + scoped_range_in& operator=(scoped_range_in&&) = delete; /** - * @brief Destroy the domain_thread_range, ending the NVTX range event. + * @brief Destroy the scoped_range_in, ending the NVTX range event. */ - ~domain_thread_range() noexcept + ~scoped_range_in() noexcept { #ifndef NVTX_DISABLE nvtxDomainRangePop(domain::get()); @@ -1907,25 +2157,103 @@ class domain_thread_range { }; /** - * @brief Alias for a `domain_thread_range` in the global NVTX domain. + * @brief Alias for a `scoped_range_in` in the global NVTX domain. * */ -using thread_range = domain_thread_range<>; +using scoped_range = scoped_range_in; + +namespace detail { + +/// @cond internal +template +class optional_scoped_range_in +{ +public: + optional_scoped_range_in() = default; + + void begin(event_attributes const& attr) noexcept + { +#ifndef NVTX_DISABLE + // This class is not meant to be part of the public NVTX C++ API and should + // only be used in the `NVTX3_FUNC_RANGE_IF` and `NVTX3_FUNC_RANGE_IF_IN` + // macros. However, to prevent developers from misusing this class, make + // sure to not start multiple ranges. + if (initialized) { return; } + + nvtxDomainRangePushEx(domain::get(), attr.get()); + initialized = true; +#endif + } + + ~optional_scoped_range_in() noexcept + { +#ifndef NVTX_DISABLE + if (initialized) { nvtxDomainRangePop(domain::get()); } +#endif + } + + void* operator new(std::size_t) = delete; + optional_scoped_range_in(optional_scoped_range_in const&) = delete; + optional_scoped_range_in& operator=(optional_scoped_range_in const&) = delete; + optional_scoped_range_in(optional_scoped_range_in&&) = delete; + optional_scoped_range_in& operator=(optional_scoped_range_in&&) = delete; + +private: +#ifndef NVTX_DISABLE + bool initialized = false; +#endif +}; +/// @endcond + +} // namespace detail /** * @brief Handle used for correlating explicit range start and end events. * + * A handle is "null" if it does not correspond to any range. + * */ struct range_handle { /// Type used for the handle's value using value_type = nvtxRangeId_t; + /** * @brief Construct a `range_handle` from the given id. * */ constexpr explicit range_handle(value_type id) noexcept : _range_id{id} {} + /** + * @brief Constructs a null range handle. + * + * A null range_handle corresponds to no range. Calling `end_range` on a + * null handle is undefined behavior when a tool is active. 
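As a further sketch (not part of the diff), ranges can be scoped to a custom domain; the domain type and function name below are placeholders, and the explicit template argument is assumed since scoped_range_in takes the domain type as its parameter D:

    struct my_domain { static constexpr char const* name{"my domain"}; };

    void worker()
    {
      // Range recorded in "my domain" rather than the global NVTX domain.
      nvtx3::scoped_range_in<my_domain> outer{"worker"};

      // Alias for more succinct declarations in the same domain.
      using my_scoped_range = nvtx3::scoped_range_in<my_domain>;
      my_scoped_range inner{"worker/inner"};
    }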
+ * + */ + constexpr range_handle() noexcept = default; + + /** + * @brief Checks whether this handle is null + * + * Provides contextual conversion to `bool`. + * + * \code{cpp} + * range_handle handle{}; + * if (handle) {...} + * \endcode + * + */ + constexpr explicit operator bool() const noexcept { return get_value() != null_range_id; }; + + /** + * @brief Implicit conversion from `nullptr` constructs a null handle. + * + * Satisfies the "NullablePointer" requirement to make `range_handle` comparable with `nullptr`. + * + */ + constexpr range_handle(std::nullptr_t) noexcept {} + /** * @brief Returns the `range_handle`'s value * @@ -1934,42 +2262,68 @@ struct range_handle { constexpr value_type get_value() const noexcept { return _range_id; } private: - value_type _range_id{}; ///< The underlying NVTX range id + /// Sentinel value for a null handle that corresponds to no range + static constexpr value_type null_range_id = nvtxRangeId_t{0}; + + value_type _range_id{null_range_id}; ///< The underlying NVTX range id }; +/** + * @brief Compares two range_handles for equality + * + * @param lhs The first range_handle to compare + * @param rhs The second range_handle to compare + */ +inline constexpr bool operator==(range_handle lhs, range_handle rhs) noexcept +{ + return lhs.get_value() == rhs.get_value(); +} + +/** + * @brief Compares two range_handles for inequality + * + * @param lhs The first range_handle to compare + * @param rhs The second range_handle to compare + */ +inline constexpr bool operator!=(range_handle lhs, range_handle rhs) noexcept { return !(lhs == rhs); } + /** * @brief Manually begin an NVTX range. * * Explicitly begins an NVTX range and returns a unique handle. To end the - * range, pass the handle to `end_range()`. + * range, pass the handle to `end_range_in()`. * - * `start_range/end_range` are the most explicit and lowest level APIs provided - * for creating ranges. Use of `nvtx3::domain_process_range` should be + * `nvtx3::start_range(...)` is equivalent to `nvtx3::start_range_in<>(...)` and + * `nvtx3::start_range_in(...)`. + * + * `start_range_in/end_range_in` are the most explicit and lowest level APIs + * provided for creating ranges. Use of `nvtx3::unique_range_in` should be * preferred unless one is unable to tie the range to the lifetime of an object. * * Example: - * ``` + * \code{.cpp} * nvtx3::event_attributes attr{"msg", nvtx3::rgb{127,255,0}}; - * nvtx3::range_handle h = nvxt3::start_range(attr); // Manually begins a range + * // Manually begin a range + * nvtx3::range_handle h = nvtx3::start_range_in(attr); * ... - * nvtx3::end_range(h); // Ends the range - * ``` + * nvtx3::end_range_in(h); // End the range + * \endcode * * @tparam D Type containing `name` member used to identify the `domain` * to which the range belongs. Else, `domain::global` to indicate that the * global NVTX domain should be used. * @param[in] attr `event_attributes` that describes the desired attributes * of the range. - * @return Unique handle to be passed to `end_range` to end the range. + * @return Unique handle to be passed to `end_range_in` to end the range. 
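A small sketch (not part of the diff) of the null-handle semantics added to range_handle above; the function name is a placeholder:

    #include <cassert>

    void null_handle_demo()
    {
      nvtx3::range_handle h;            // default-constructed handles are null
      assert(!h && h == nullptr);       // contextual bool + NullablePointer comparison

      h = nvtx3::start_range("phase");  // may still be null if no tool is attached
      if (h) { nvtx3::end_range(h); }   // only end a range that was actually started
    }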
*/ template -range_handle start_range(event_attributes const& attr) noexcept +inline range_handle start_range_in(event_attributes const& attr) noexcept { #ifndef NVTX_DISABLE return range_handle{nvtxDomainRangeStartEx(domain::get(), attr.get())}; #else (void)attr; - return range_handle{}; + return {}; #endif } @@ -1977,60 +2331,157 @@ range_handle start_range(event_attributes const& attr) noexcept * @brief Manually begin an NVTX range. * * Explicitly begins an NVTX range and returns a unique handle. To end the - * range, pass the handle to `end_range()`. + * range, pass the handle to `end_range_in()`. * - * Forwards the arguments `first, args...` to construct an `event_attributes` - * object. The `event_attributes` object is then associated with the range. + * `nvtx3::start_range(...)` is equivalent to `nvtx3::start_range_in<>(...)` and + * `nvtx3::start_range_in(...)`. * - * For more detail, see `event_attributes` documentation. - * - * Example: - * ``` - * nvtx3::range_handle h = nvxt3::start_range("msg", nvtx3::rgb{127,255,0}); // - * Begin range - * ... - * nvtx3::end_range(h); // Ends the range - * ``` - * - * `start_range/end_range` are the most explicit and lowest level APIs provided - * for creating ranges. Use of `nvtx3::domain_process_range` should be + * `start_range_in/end_range_in` are the most explicit and lowest level APIs + * provided for creating ranges. Use of `nvtx3::unique_range_in` should be * preferred unless one is unable to tie the range to the lifetime of an object. * - * @param first[in] First argument to pass to an `event_attributes` - * @param args[in] Variadiac parameter pack of the rest of the arguments for an - * `event_attributes`. + * This overload uses `args...` to construct an `event_attributes` to + * associate with the range. For more detail, see `event_attributes`. + * + * Example: + * \code{cpp} + * // Manually begin a range + * nvtx3::range_handle h = nvtx3::start_range_in("msg", nvtx3::rgb{127,255,0}); + * ... + * nvtx3::end_range_in(h); // Ends the range + * \endcode + * + * @tparam D Type containing `name` member used to identify the `domain` + * to which the range belongs. Else, `domain::global` to indicate that the + * global NVTX domain should be used. + * @param args[in] Variadic parameter pack of the arguments for an `event_attributes`. * @return Unique handle to be passed to `end_range` to end the range. */ -template >::value>> -range_handle start_range(First const& first, Args const&... args) noexcept +template +inline range_handle start_range_in(Args const&... args) noexcept { #ifndef NVTX_DISABLE - return start_range(event_attributes{first, args...}); + return start_range_in(event_attributes{args...}); #else - (void)first; - return range_handle{}; + return {}; #endif } /** - * @brief Manually end the range associated with the handle `r`. + * @brief Manually begin an NVTX range in the global domain. + * + * Explicitly begins an NVTX range and returns a unique handle. To end the + * range, pass the handle to `end_range()`. + * + * `nvtx3::start_range(...)` is equivalent to `nvtx3::start_range_in<>(...)` and + * `nvtx3::start_range_in(...)`. + * + * `start_range/end_range` are the most explicit and lowest level APIs + * provided for creating ranges. Use of `nvtx3::unique_range` should be + * preferred unless one is unable to tie the range to the lifetime of an object. 
+ * + * Example: + * \code{.cpp} + * nvtx3::event_attributes attr{"msg", nvtx3::rgb{127,255,0}}; + * // Manually begin a range + * nvtx3::range_handle h = nvtx3::start_range(attr); + * ... + * nvtx3::end_range(h); // End the range + * \endcode + * + * @param[in] attr `event_attributes` that describes the desired attributes + * of the range. + * @return Unique handle to be passed to `end_range_in` to end the range. + */ +inline range_handle start_range(event_attributes const& attr) noexcept +{ +#ifndef NVTX_DISABLE + return start_range_in(attr); +#else + (void)attr; + return {}; +#endif +} + +/** + * @brief Manually begin an NVTX range in the global domain. + * + * Explicitly begins an NVTX range and returns a unique handle. To end the + * range, pass the handle to `end_range_in()`. + * + * `nvtx3::start_range(...)` is equivalent to `nvtx3::start_range_in<>(...)` and + * `nvtx3::start_range_in(...)`. + * + * `start_range_in/end_range_in` are the most explicit and lowest level APIs + * provided for creating ranges. Use of `nvtx3::unique_range_in` should be + * preferred unless one is unable to tie the range to the lifetime of an object. + * + * This overload uses `args...` to construct an `event_attributes` to + * associate with the range. For more detail, see `event_attributes`. + * + * Example: + * \code{cpp} + * // Manually begin a range + * nvtx3::range_handle h = nvtx3::start_range("msg", nvtx3::rgb{127,255,0}); + * ... + * nvtx3::end_range(h); // Ends the range + * \endcode + * + * @param args[in] Variadic parameter pack of the arguments for an `event_attributes`. + * @return Unique handle to be passed to `end_range` to end the range. + */ +template +inline range_handle start_range(Args const&... args) noexcept +{ +#ifndef NVTX_DISABLE + return start_range_in(args...); +#else + return {}; +#endif +} + +/** + * @brief Manually end the range associated with the handle `r` in domain `D`. + * + * Explicitly ends the NVTX range indicated by the handle `r` returned from a + * prior call to `start_range_in`. The range may end on a different thread + * from where it began. + * + * @tparam D Type containing `name` member used to identify the `domain` to + * which the range belongs. Else, `domain::global` to indicate that the global + * NVTX domain should be used. + * @param r Handle to a range started by a prior call to `start_range_in`. + * + * @warning The domain type specified as template parameter to this function + * must be the same that was specified on the associated `start_range_in` call. + */ +template +inline void end_range_in(range_handle r) noexcept +{ +#ifndef NVTX_DISABLE + nvtxDomainRangeEnd(domain::get(), r.get_value()); +#else + (void)r; +#endif +} + +/** + * @brief Manually end the range associated with the handle `r` in the global + * domain. * * Explicitly ends the NVTX range indicated by the handle `r` returned from a * prior call to `start_range`. The range may end on a different thread from * where it began. * - * This function does not have a Domain tag type template parameter as the - * handle `r` already indicates the domain to which the range belongs. - * * @param r Handle to a range started by a prior call to `start_range`. + * + * @warning The domain type specified as template parameter to this function + * must be the same that was specified on the associated `start_range` call. 
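A sketch (not part of the diff) of the renamed explicit start/end API: the *_in variants take the domain as a template argument (spelled out here, as assumed from the surrounding declarations), and the handle may be ended on a different thread than the one that started it. The domain type and function names are placeholders:

    struct net_domain { static constexpr char const* name{"net"}; };

    nvtx3::range_handle begin_io()
    {
      return nvtx3::start_range_in<net_domain>("io", nvtx3::rgb{255, 0, 0});
    }

    void finish_io(nvtx3::range_handle h)
    {
      nvtx3::end_range_in<net_domain>(h);  // must use the same domain as start_range_in
    }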
*/ -inline void end_range(range_handle r) +inline void end_range(range_handle r) noexcept { #ifndef NVTX_DISABLE - nvtxRangeEnd(r.get_value()); + end_range_in(r); #else (void)r; #endif @@ -2043,120 +2494,145 @@ inline void end_range(range_handle r) * When constructed, begins a NVTX range in the specified domain. Upon * destruction, ends the NVTX range. * - * Similar to `nvtx3::domain_thread_range`, the only difference being that - * `domain_process_range` can start and end on different threads. + * Similar to `nvtx3::scoped_range_in`, with a few key differences: + * - `unique_range` objects can be destroyed in an order whereas `scoped_range` objects must be + * destroyed in exact reverse creation order + * - `unique_range` can start and end on different threads + * - `unique_range` is moveable + * - `unique_range` objects can be constructed as heap objects * - * Use of `nvtx3::domain_thread_range` should be preferred unless one needs - * the ability to start and end a range on different threads. - * - * `domain_process_range` is moveable, but not copyable. + * There is extra overhead associated with `unique_range` constructs and therefore use of + * `nvtx3::scoped_range_in` should be preferred. * * @tparam D Type containing `name` member used to identify the `domain` - * to which the `domain_process_range` belongs. Else, `domain::global` to + * to which the `unique_range_in` belongs. Else, `domain::global` to * indicate that the global NVTX domain should be used. */ template -class domain_process_range { +class unique_range_in { public: /** - * @brief Construct a new domain process range object + * @brief Construct a new unique_range_in object with the specified event attributes * - * @param attr + * Example: + * \code{cpp} + * nvtx3::event_attributes attr{"msg", nvtx3::rgb{127,255,0}}; + * nvtx3::unique_range_in range{attr}; // Creates a range with message contents + * // "msg" and green color + * \endcode + * + * @param[in] attr `event_attributes` that describes the desired attributes + * of the range. */ - explicit domain_process_range(event_attributes const& attr) noexcept - : handle_{new range_handle{start_range(attr)}} + explicit unique_range_in(event_attributes const& attr) noexcept + : handle_{start_range_in(attr)} { } /** - * @brief Construct a new domain process range object + * @brief Constructs a `unique_range_in` from the constructor arguments + * of an `event_attributes`. * - * @param first - * @param args + * Forwards the arguments `args...` to construct an + * `event_attributes` object. The `event_attributes` object is then + * associated with the `unique_range_in`. + * + * For more detail, see `event_attributes` documentation. + * + * Example: + * \code{.cpp} + * // Creates a range with message "message" and green color + * nvtx3::unique_range_in<> r{"message", nvtx3::rgb{127,255,0}}; + * \endcode + * + * @param[in] args Variadic parameter pack of arguments to construct an `event_attributes` + * associated with this range. */ - template >::value>> - explicit domain_process_range(First const& first, Args const&... args) noexcept - : domain_process_range{event_attributes{first, args...}} + template + explicit unique_range_in(Args const&... args) noexcept + : unique_range_in{event_attributes{args...}} { } /** - * @brief Construct a new domain process range object + * @brief Default constructor creates a `unique_range_in` with no + * message, color, payload, nor category. 
* */ - constexpr domain_process_range() noexcept : domain_process_range{event_attributes{}} {} + constexpr unique_range_in() noexcept : unique_range_in{event_attributes{}} {} /** - * @brief Destroy the `domain_process_range` ending the range. + * @brief Destroy the `unique_range_in` ending the range. * */ - ~domain_process_range() - { - if (handle_) { end_range(*handle_); } - } + ~unique_range_in() noexcept = default; /** * @brief Move constructor allows taking ownership of the NVTX range from - * another `domain_process_range`. + * another `unique_range_in`. * - * @param other + * @param other The range to take ownership of */ - domain_process_range(domain_process_range&& other) = default; + unique_range_in(unique_range_in&& other) noexcept = default; /** * @brief Move assignment operator allows taking ownership of an NVTX range - * from another `domain_process_range`. + * from another `unique_range_in`. * - * @param other - * @return domain_process_range& + * @param other The range to take ownership of */ - domain_process_range& operator=(domain_process_range&& other) = default; + unique_range_in& operator=(unique_range_in&& other) noexcept = default; /// Copy construction is not allowed to prevent multiple objects from owning /// the same range handle - domain_process_range(domain_process_range const&) = delete; + unique_range_in(unique_range_in const&) = delete; /// Copy assignment is not allowed to prevent multiple objects from owning the /// same range handle - domain_process_range& operator=(domain_process_range const&) = delete; + unique_range_in& operator=(unique_range_in const&) = delete; private: - std::unique_ptr handle_; ///< Range handle used to correlate - ///< the start/end of the range + + struct end_range_handle { + using pointer = range_handle; /// Override the pointer type of the unique_ptr + void operator()(range_handle h) const noexcept { end_range_in(h); } + }; + + /// Range handle used to correlate the start/end of the range + std::unique_ptr handle_; }; /** - * @brief Alias for a `domain_process_range` in the global NVTX domain. + * @brief Alias for a `unique_range_in` in the global NVTX domain. * */ -using process_range = domain_process_range<>; +using unique_range = unique_range_in; /** - * @brief Annotates an instantaneous point in time with the attributes specified - * by `attr`. + * @brief Annotates an instantaneous point in time with a "marker", using the + * attributes specified by `attr`. * - * Unlike a "range", a mark is an instantaneous event in an application, e.g., - * locking/unlocking a mutex. + * Unlike a "range" which has a beginning and an end, a marker is a single event + * in an application, such as detecting a problem: * * \code{.cpp} - * std::mutex global_lock; - * void lock_mutex(){ - * global_lock.lock(); - * nvtx3::mark("lock_mutex"); + * bool success = do_operation(...); + * if (!success) { + * nvtx3::event_attributes attr{"operation failed!", nvtx3::rgb{255,0,0}}; + * nvtx3::mark_in(attr); * } * \endcode * + * Note that nvtx3::mark_in is a function, not a class like scoped_range_in. + * * @tparam D Type containing `name` member used to identify the `domain` - * to which the `domain_process_range` belongs. Else, `domain::global` to + * to which the `unique_range_in` belongs. Else, `domain::global` to * indicate that the global NVTX domain should be used. * @param[in] attr `event_attributes` that describes the desired attributes * of the mark. 
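A sketch (not part of the diff) of the renamed unique_range: unlike scoped_range it is moveable and may end on a different thread, which the code above implements with a std::unique_ptr whose custom deleter calls end_range_in (the `using pointer = range_handle` member lets the handle itself serve as the pointer type). Function names are placeholders:

    nvtx3::unique_range make_phase_range()
    {
      nvtx3::unique_range r{"phase", nvtx3::rgb{0, 0, 255}};
      return r;                      // moved out; the range stays open
    }

    void run_phase()
    {
      auto r = make_phase_range();   // range ends when r is destroyed here
      // ...
    }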
*/ template -inline void mark(event_attributes const& attr) noexcept +inline void mark_in(event_attributes const& attr) noexcept { #ifndef NVTX_DISABLE nvtxDomainMarkEx(domain::get(), attr.get()); @@ -2165,10 +2641,105 @@ inline void mark(event_attributes const& attr) noexcept #endif } +/** + * @brief Annotates an instantaneous point in time with a "marker", using the + * arguments to construct an `event_attributes`. + * + * Unlike a "range" which has a beginning and an end, a marker is a single event + * in an application, such as detecting a problem: + * + * \code{.cpp} + * bool success = do_operation(...); + * if (!success) { + * nvtx3::mark_in("operation failed!", nvtx3::rgb{255,0,0}); + * } + * \endcode + * + * Note that nvtx3::mark_in is a function, not a class like scoped_range_in. + * + * Forwards the arguments `args...` to construct an `event_attributes` object. + * The attributes are then associated with the marker. For more detail, see + * the `event_attributes` documentation. + * + * @tparam D Type containing `name` member used to identify the `domain` + * to which the `unique_range_in` belongs. Else `domain::global` to + * indicate that the global NVTX domain should be used. + * @param[in] args Variadic parameter pack of arguments to construct an `event_attributes` + * associated with this range. + * + */ +template +inline void mark_in(Args const&... args) noexcept +{ +#ifndef NVTX_DISABLE + mark_in(event_attributes{args...}); +#endif +} + +/** + * @brief Annotates an instantaneous point in time with a "marker", using the + * attributes specified by `attr`, in the global domain. + * + * Unlike a "range" which has a beginning and an end, a marker is a single event + * in an application, such as detecting a problem: + * + * \code{.cpp} + * bool success = do_operation(...); + * if (!success) { + * nvtx3::event_attributes attr{"operation failed!", nvtx3::rgb{255,0,0}}; + * nvtx3::mark(attr); + * } + * \endcode + * + * Note that nvtx3::mark is a function, not a class like scoped_range. + * + * @param[in] attr `event_attributes` that describes the desired attributes + * of the mark. + */ +inline void mark(event_attributes const& attr) noexcept +{ +#ifndef NVTX_DISABLE + mark_in(attr); +#endif +} + +/** + * @brief Annotates an instantaneous point in time with a "marker", using the + * arguments to construct an `event_attributes`, in the global domain. + * + * Unlike a "range" which has a beginning and an end, a marker is a single event + * in an application, such as detecting a problem: + * + * \code{.cpp} + * bool success = do_operation(...); + * if (!success) { + * nvtx3::mark("operation failed!", nvtx3::rgb{255,0,0}); + * } + * \endcode + * + * Note that nvtx3::mark is a function, not a class like scoped_range. + * + * Forwards the arguments `args...` to construct an `event_attributes` object. + * The attributes are then associated with the marker. For more detail, see + * the `event_attributes` documentation. + * + * @param[in] args Variadic parameter pack of arguments to construct an + * `event_attributes` associated with this range. + * + */ +template +inline void mark(Args const&... args) noexcept +{ +#ifndef NVTX_DISABLE + mark_in(args...); +#endif +} + } // namespace NVTX3_VERSION_NAMESPACE } // namespace nvtx3 +#ifndef NVTX_DISABLE /** * @brief Convenience macro for generating a range in the specified `domain` * from the lifetime of a function @@ -2177,34 +2748,58 @@ inline void mark(event_attributes const& attr) noexcept * the entry point of a function to its exit. 
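For completeness, a sketch (not part of the diff) of the marker overloads, including the domain-scoped form with its template argument spelled out (assumed); the domain type and function name are placeholders:

    struct err_domain { static constexpr char const* name{"errors"}; };

    void report(bool ok)
    {
      if (!ok) {
        nvtx3::mark("operation failed!", nvtx3::rgb{255, 0, 0});  // global domain
        nvtx3::mark_in<err_domain>("operation failed!");          // custom domain
      }
    }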
It is intended to be the first * line of the function. * - * Constructs a static `registered_string` using the name of the immediately + * Constructs a static `registered_string_in` using the name of the immediately * enclosing function returned by `__func__` and constructs a - * `nvtx3::thread_range` using the registered function name as the range's + * `nvtx3::scoped_range` using the registered function name as the range's * message. * * Example: - * ``` + * \code{.cpp} * struct my_domain{static constexpr char const* name{"my_domain"};}; * - * void foo(...){ + * void foo(...) { * NVTX3_FUNC_RANGE_IN(my_domain); // Range begins on entry to foo() * // do stuff * ... * } // Range ends on return from foo() - * ``` + * \endcode * * @param[in] D Type containing `name` member used to identify the - * `domain` to which the `registered_string` belongs. Else, + * `domain` to which the `registered_string_in` belongs. Else, * `domain::global` to indicate that the global NVTX domain should be used. */ -#ifndef NVTX_DISABLE #define NVTX3_V1_FUNC_RANGE_IN(D) \ - static ::nvtx3::v1::registered_string const nvtx3_func_name__{__func__}; \ + static ::nvtx3::v1::registered_string_in const nvtx3_func_name__{__func__}; \ static ::nvtx3::v1::event_attributes const nvtx3_func_attr__{nvtx3_func_name__}; \ - ::nvtx3::v1::domain_thread_range const nvtx3_range__{nvtx3_func_attr__}; + ::nvtx3::v1::scoped_range_in const nvtx3_range__{nvtx3_func_attr__}; + +/** + * @brief Convenience macro for generating a range in the specified `domain` + * from the lifetime of a function if the given boolean expression evaluates + * to true. + * + * Similar to `NVTX3_V1_FUNC_RANGE_IN(D)`, the only difference being that + * `NVTX3_V1_FUNC_RANGE_IF_IN(D, C)` only generates a range if the given boolean + * expression evaluates to true. + * + * @param[in] D Type containing `name` member used to identify the + * `domain` to which the `registered_string_in` belongs. Else, + * `domain::global` to indicate that the global NVTX domain should be used. + * + * @param[in] C Boolean expression used to determine if a range should be + * generated. + */ +#define NVTX3_V1_FUNC_RANGE_IF_IN(D, C) \ + ::nvtx3::v1::detail::optional_scoped_range_in optional_nvtx3_range__; \ + if (C) { \ + static ::nvtx3::v1::registered_string_in const nvtx3_func_name__{__func__}; \ + static ::nvtx3::v1::event_attributes const nvtx3_func_attr__{nvtx3_func_name__}; \ + optional_nvtx3_range__.begin(nvtx3_func_attr__); \ + } #else #define NVTX3_V1_FUNC_RANGE_IN(D) -#endif +#define NVTX3_V1_FUNC_RANGE_IF_IN(D, C) +#endif // NVTX_DISABLE /** * @brief Convenience macro for generating a range in the global domain from the @@ -2214,28 +2809,43 @@ inline void mark(event_attributes const& attr) noexcept * the entry point of a function to its exit. It is intended to be the first * line of the function. * - * Constructs a static `registered_string` using the name of the immediately + * Constructs a static `registered_string_in` using the name of the immediately * enclosing function returned by `__func__` and constructs a - * `nvtx3::thread_range` using the registered function name as the range's + * `nvtx3::scoped_range` using the registered function name as the range's * message. * * Example: - * ``` - * void foo(...){ + * \code{.cpp} + * void foo(...) { * NVTX3_FUNC_RANGE(); // Range begins on entry to foo() * // do stuff * ... 
* } // Range ends on return from foo() - * ``` + * \endcode */ #define NVTX3_V1_FUNC_RANGE() NVTX3_V1_FUNC_RANGE_IN(::nvtx3::v1::domain::global) +/** + * @brief Convenience macro for generating a range in the global domain from the + * lifetime of a function if the given boolean expression evaluates to true. + * + * Similar to `NVTX3_V1_FUNC_RANGE()`, the only difference being that + * `NVTX3_V1_FUNC_RANGE_IF(C)` only generates a range if the given boolean + * expression evaluates to true. + * + * @param[in] C Boolean expression used to determine if a range should be + * generated. + */ +#define NVTX3_V1_FUNC_RANGE_IF(C) NVTX3_V1_FUNC_RANGE_IF_IN(::nvtx3::v1::domain::global, C) + /* When inlining this version, versioned macros must have unversioned aliases. * For each NVTX3_Vx_ #define, make an NVTX3_ alias of it here.*/ #if defined(NVTX3_INLINE_THIS_VERSION) /* clang format off */ -#define NVTX3_FUNC_RANGE_IN NVTX3_V1_FUNC_RANGE_IN -#define NVTX3_FUNC_RANGE NVTX3_V1_FUNC_RANGE +#define NVTX3_FUNC_RANGE NVTX3_V1_FUNC_RANGE +#define NVTX3_FUNC_RANGE_IF NVTX3_V1_FUNC_RANGE_IF +#define NVTX3_FUNC_RANGE_IN NVTX3_V1_FUNC_RANGE_IN +#define NVTX3_FUNC_RANGE_IF_IN NVTX3_V1_FUNC_RANGE_IF_IN /* clang format on */ #endif @@ -2278,8 +2888,18 @@ inline void mark(event_attributes const& attr) noexcept #undef NVTX3_NAMESPACE_FOR #undef NVTX3_VERSION_NAMESPACE #undef NVTX3_INLINE_IF_REQUESTED -#undef NVTX3_RELAXED_CONSTEXPR +#undef NVTX3_CONSTEXPR_IF_CPP14 #if defined(NVTX3_INLINE_THIS_VERSION) #undef NVTX3_INLINE_THIS_VERSION #endif + +#if defined(NVTX3_USE_CHECKED_OVERLOADS_FOR_GET_DEFINED_HERE) +#undef NVTX3_USE_CHECKED_OVERLOADS_FOR_GET_DEFINED_HERE +#undef NVTX3_USE_CHECKED_OVERLOADS_FOR_GET +#endif + +#if defined(NVTX3_STATIC_ASSERT_DEFINED_HERE) +#undef NVTX3_STATIC_ASSERT_DEFINED_HERE +#undef NVTX3_STATIC_ASSERT +#endif diff --git a/src/include/nvtx3/nvtxDetail/nvtxImpl.h b/src/include/nvtx3/nvtxDetail/nvtxImpl.h index be27f43..590ce90 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImpl.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImpl.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplCore.h b/src/include/nvtx3/nvtxDetail/nvtxImplCore.h index 9f014ca..7a48aa8 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImplCore.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImplCore.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h index d4c0cdf..156f15a 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. 
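A sketch (not part of the diff) of the function-range convenience macros, including the new conditional variants; it assumes the unversioned aliases are available (the default when this version is inlined), and the domain type is a placeholder:

    struct my_domain { static constexpr char const* name{"my_domain"}; };

    void foo()
    {
      NVTX3_FUNC_RANGE_IN(my_domain);  // range named "foo" for the whole function
      // ...
    }

    void bar(bool enabled)
    {
      NVTX3_FUNC_RANGE_IF(enabled);    // range only when `enabled` evaluates to true
      // ...
    }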
diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h index 4b5d6c7..5a379b1 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h index 4a026f0..bd8d404 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h index 90616da..686686c 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxInit.h b/src/include/nvtx3/nvtxDetail/nvtxInit.h index 44dcc0f..43cad70 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxInit.h +++ b/src/include/nvtx3/nvtxDetail/nvtxInit.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxInitDecls.h b/src/include/nvtx3/nvtxDetail/nvtxInitDecls.h index 261681b..a52e278 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxInitDecls.h +++ b/src/include/nvtx3/nvtxDetail/nvtxInitDecls.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxInitDefs.h b/src/include/nvtx3/nvtxDetail/nvtxInitDefs.h index ded156c..a670d96 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxInitDefs.h +++ b/src/include/nvtx3/nvtxDetail/nvtxInitDefs.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h b/src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h index 908ce88..57661c7 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h +++ b/src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. 
* See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxTypes.h b/src/include/nvtx3/nvtxDetail/nvtxTypes.h index 53c6c00..f646b54 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxTypes.h +++ b/src/include/nvtx3/nvtxDetail/nvtxTypes.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h index d589f63..4663fda 100644 --- a/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h @@ -35,10 +35,11 @@ NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots)[NVTX3EXT_CBID_PAYLOAD_FN_NUM NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadInitOnce)() { + intptr_t* fnSlots = NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots) + 1; nvtxExtModuleSegment_t segment = { 0, // unused (only one segment) NVTX3EXT_CBID_PAYLOAD_FN_NUM, - NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots) + 1 + fnSlots }; nvtxExtModuleInfo_t module = { diff --git a/src/include/proxy.h b/src/include/proxy.h index 4c75e21..5e7f728 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -10,6 +10,7 @@ #include "devcomm.h" #include "info.h" #include "socket.h" +#include "ipcsocket.h" #include #include "shm.h" @@ -161,6 +162,31 @@ struct ncclProxyProgressState { int nextOps; }; +// Expected proxy response fifo +struct ncclExpectedProxyResponse { + void* opId; + int respSize; + bool done; + void* respBuff; + struct ncclExpectedProxyResponse* next; +}; + +struct ncclProxyAsyncOp { + int type; + struct ncclProxyConnection* connection; + int reqSize, respSize; + char *reqBuff, *respBuff; + void* opId; + ncclProxyAsyncOp* next; +}; + +struct ncclProxyLocalPeer { + struct ncclSocket sock; + int localRank; + ncclProxyAsyncOp* asyncOps; + int asyncOpCounter; +}; + struct ncclProxyState { // Service thread pthread_t thread; @@ -176,6 +202,9 @@ struct ncclProxyState { // Progress thread struct ncclProxyProgressState progressState; + + // Queue of expected responses from the proxy + struct ncclExpectedProxyResponse* expectedResponses; }; enum proxyConnectState { @@ -220,10 +249,19 @@ enum ncclProxyMsgType { ncclProxyMsgStart = 5, ncclProxyMsgClose = 6, ncclProxyMsgAbort = 7, - ncclProxyMsgStop = 8 + ncclProxyMsgStop = 8, + ncclProxyMsgConvertFd = 9 // cuMem API support }; -ncclResult_t ncclProxyCall(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize); +// This function is called by a client of the proxy that needs to invoke any of the non-progress proxyOp types +// Call this function on the client, supplying a locally unique opId. 
Then, poll on the return value of +// ncclPollProxyResponse(), supplying the same opId to confirm the operation has completed +ncclResult_t ncclProxyCallAsync(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, int respSize, void* opId); + +// This function will internally call ncclProxyCallAsync() and spin until ncclPollProxyResponse() confirms the result is received +ncclResult_t ncclProxyCallBlocking(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize); +ncclResult_t ncclPollProxyResponse(struct ncclProxyConnector* proxyConn, void* respBuff, void* opId); + ncclResult_t ncclProxyDestroy(struct ncclComm* comm); ncclResult_t ncclProxyShmUnlink(struct ncclComm* comm); #endif diff --git a/src/include/socket.h b/src/include/socket.h index a0c7a4d..9e51372 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -92,6 +92,6 @@ ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int ncclResult_t ncclSocketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset); ncclResult_t ncclSocketSend(struct ncclSocket* sock, void* ptr, int size); ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size); -ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed); +ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed, bool blocking); ncclResult_t ncclSocketClose(struct ncclSocket* sock); #endif diff --git a/src/include/transport.h b/src/include/transport.h index e13c9e8..f212f26 100644 --- a/src/include/transport.h +++ b/src/include/transport.h @@ -62,7 +62,7 @@ struct ncclTransportComm { }; struct ncclTransport { - const char name[4]; + const char name[8]; ncclResult_t (*canConnect)(int*, struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo*, struct ncclPeerInfo*); struct ncclTransportComm send; struct ncclTransportComm recv; @@ -71,6 +71,9 @@ struct ncclTransport { ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, int channelId, int nrecv, int* peerRecv, int nsend, int* peerSend, int connIndex); ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType=NULL); +ncclResult_t ncclNvlsSetup(struct ncclComm* comm); +ncclResult_t ncclNvlsFree(struct ncclComm* comm); + enum { collNetRecv=0, collNetSend=1 }; int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collNetGraph, struct ncclChannel* channel, int masterRank, int masterPeer, int collNetGraphChannelId, int type); ncclResult_t ncclTransportCollNetCheck(struct ncclComm* comm, int collNetSetupFail); diff --git a/src/init.cc b/src/init.cc index 6a5f3c3..40f7872 100644 --- a/src/init.cc +++ b/src/init.cc @@ -35,13 +35,13 @@ #endif const char* ncclFuncStr[NCCL_NUM_FUNCTIONS] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce" }; -const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain" }; +const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain", "NVLS" }; const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" }; NCCL_PARAM(GroupCudaStream, "GROUP_CUDA_STREAM", NCCL_GROUP_CUDA_STREAM); NCCL_PARAM(CheckPointers, "CHECK_POINTERS", 0); -NCCL_PARAM(CommBlocking, "COMM_BLOCKING", 0); +NCCL_PARAM(CommBlocking, "COMM_BLOCKING", NCCL_CONFIG_UNDEF_INT); static uint64_t hashUniqueId(ncclUniqueId const &id) { char const *bytes = 
(char const*)&id; @@ -67,12 +67,8 @@ ncclResult_t initGdrCopy() { return ncclSuccess; } - -NCCL_PARAM(L1SharedMemoryCarveout, "L1_SHARED_MEMORY_CARVEOUT", 0); - pthread_mutex_t initLock = PTHREAD_MUTEX_INITIALIZER; static bool initialized = false; -static size_t maxLocalSizeBytes = 0; static ncclResult_t ncclInit() { if (__atomic_load_n(&initialized, __ATOMIC_ACQUIRE)) return ncclSuccess; @@ -80,9 +76,6 @@ static ncclResult_t ncclInit() { if (!initialized) { initEnv(); initGdrCopy(); - maxLocalSizeBytes = ncclKernMaxLocalSize(); - int carveout = ncclParamL1SharedMemoryCarveout(); - if (carveout) ncclKernSetSharedMemoryCarveout(carveout); // Always initialize bootstrap network NCCLCHECK(bootstrapNetInit()); NCCLCHECK(ncclNetPluginInit()); @@ -210,6 +203,8 @@ static ncclResult_t commFree(ncclComm_t comm) { NCCLCHECK(ncclStrongStreamDestruct(&comm->deviceStream)); } + if (comm->nvlsSupport) NCCLCHECK(ncclNvlsFree(comm)); + struct ncclDestructor* dtor = comm->destructorHead; while (dtor != nullptr) { NCCLCHECK(dtor->fn(dtor)); @@ -220,6 +215,7 @@ static ncclResult_t commFree(ncclComm_t comm) { ncclMemoryStackDestruct(&comm->memPermanent); ncclCudaHostFree((void *)comm->abortFlag); + free(comm->netName); commPoison(comm); // poison comm before free to avoid comm reuse. free(comm); @@ -243,8 +239,8 @@ static ncclResult_t dmaBufSupported(struct ncclComm* comm) { int flag = 0; CUdevice dev; int cudaDriverVersion; - CUCHECK(cuDriverGetVersion(&cudaDriverVersion)); - if (cudaDriverVersion < 11070) return ncclInternalError; + CUDACHECK(cudaDriverGetVersion(&cudaDriverVersion)); + if (CUPFN(cuDeviceGet) == NULL || cudaDriverVersion < 11070) return ncclInternalError; CUCHECK(cuDeviceGet(&dev, comm->cudaDev)); // Query device to see if DMA-BUF support is available (void) CUPFN(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED, dev)); @@ -265,7 +261,7 @@ ncclResult_t ncclCommEnsureReady(ncclComm_t comm) { NCCLCHECK(ncclCommGetAsyncError(comm, &ret)); if (ret != ncclSuccess) { /* if ret is not ncclInProgress, we just keep it. 
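To illustrate the asynchronous proxy protocol declared in proxy.h above (a sketch, not part of the diff; the function name is a placeholder and the polling convention is assumed), a client issues the request with a locally unique opId and then polls for the response:

    #include "proxy.h"

    ncclResult_t exampleSharedInit(struct ncclProxyConnector* proxyConn, int* p2pnChannels) {
      void* opId = (void*)p2pnChannels;  // any locally unique pointer can serve as the opId
      ncclResult_t ret = ncclProxyCallAsync(proxyConn, ncclProxyMsgSharedInit,
                                            p2pnChannels, sizeof(int), 0 /*respSize*/, opId);
      if (ret != ncclSuccess) return ret;
      do {  // assumed to return ncclInProgress until the proxy thread posts the response
        ret = ncclPollProxyResponse(proxyConn, NULL /*respBuff*/, opId);
      } while (ret == ncclInProgress);
      return ret;
    }

The new ncclProxyCallBlocking() wraps this same call-then-poll pattern, as the init.cc hunks below show where the old ncclProxyCall() sites become ncclProxyCallBlocking().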
*/ - WARN("Attempt to use communicator before the previous operation returned ncclSuccess\n"); + WARN("Attempt to use communicator before the previous operation returned ncclSuccess"); if (ret == ncclInProgress) ret = ncclInvalidArgument; goto exit; } @@ -395,6 +391,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { tmpCommAndChans.channels[c].tree = comm->channels[c].tree; tmpCommAndChans.channels[c].collnetChain = comm->channels[c].collnetChain; tmpCommAndChans.channels[c].collnetDirect = comm->channels[c].collnetDirect; + tmpCommAndChans.channels[c].nvls = comm->channels[c].nvls; tmpCommAndChans.channels[c].workFifoDone = &comm->workFifoDone[c]; if (comm->channels[c].ring.userRanks != nullptr) { @@ -521,8 +518,8 @@ static ncclResult_t collNetTrySetup(ncclComm_t comm, struct ncclTopoGraph* collN struct ncclChannel* channel = comm->channels + c; for (int h = 0; h < nHeads; h++) { const int head = heads[h]; - collNetSetupFail = ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, collNetRecv); - if (!collNetSetupFail) collNetSetupFail = ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, collNetSend); + collNetSetupFail |= ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, collNetRecv); + if (!collNetSetupFail) collNetSetupFail |= ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, collNetSend); } // Verify CollNet setup across ranks after trying the first channel if (c == 0) { @@ -922,6 +919,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm // Check if we can setup CollNet if (comm->collNetSupport > 0) collNetTrySetup(comm, &collNetGraph); + NCCLCHECKGOTO(ncclNvlsSetup(comm), ret, fail); + TRACE(NCCL_INIT, "rank %d nranks %d - CONNECTED %d RINGS AND TREES", rank, nranks, comm->nChannels); // Compute time models for algorithm and protocol combinations @@ -929,7 +928,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm int myCompCap = comm->peerInfo[rank].cudaCompCap; int minCompCap = myCompCap, maxCompCap = myCompCap; for (int i = 0; i < nranks; i++) { - minCompCap = std::min(comm->peerInfo[i].cudaCompCap, minCompCap); + comm->minCompCap = minCompCap = std::min(comm->peerInfo[i].cudaCompCap, minCompCap); maxCompCap = std::max(comm->peerInfo[i].cudaCompCap, maxCompCap); } NCCLCHECKGOTO(ncclTopoTuneModel(comm, minCompCap, maxCompCap, &treeGraph, &ringGraph, &collNetGraph), ret, fail); @@ -938,6 +937,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm // Compute nChannels per peer for p2p NCCLCHECKGOTO(ncclTopoComputeP2pChannels(comm), ret, fail); + INFO(NCCL_INIT, "%d coll channels, %d nvls channels, %d p2p channels, %d p2p channels per peer", comm->nChannels, comm->nvlsChannels, comm->p2pnChannels, comm->p2pnChannelsPerPeer); + do { // Setup p2p structures in comm->tasks struct ncclTasks* tasks = &comm->tasks; int nRanks = comm->nRanks; @@ -1004,12 +1005,13 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm } } } + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, NULL, 1), ret, fail); } // Connect to local net proxy NCCLCHECKGOTO(ncclProxyConnect(comm, TRANSPORT_NET, 1, comm->rank, &proxyConn), ret, fail); - NCCLCHECKGOTO(ncclProxyCall(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0), ret, fail); + NCCLCHECKGOTO(ncclProxyCallBlocking(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0), ret, fail); // Then to remote ones when using 
PXN if (ncclPxnDisable(comm) == 0) { @@ -1017,7 +1019,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm NCCLCHECKGOTO(ncclTopoGetPxnRanks(comm, &pxnPeers, &nranks), ret, fail); for (int r=0; rp2pnChannels, sizeof(int), NULL, 0), ret, fail); + NCCLCHECKGOTO(ncclProxyCallBlocking(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0), ret, fail); } } @@ -1065,6 +1067,11 @@ fail: } NCCL_PARAM(SetStackSize, "SET_STACK_SIZE", 0); +NCCL_PARAM(CGAClusterSize, "CGA_CLUSTER_SIZE", NCCL_CONFIG_UNDEF_INT); +// Match config max/minCTAs +NCCL_PARAM(MaxCTAs, "MAX_CTAS", NCCL_CONFIG_UNDEF_INT); +NCCL_PARAM(MinCTAs, "MIN_CTAS", NCCL_CONFIG_UNDEF_INT); +#define NCCL_MAX_CGA_CLUSTER_SIZE 8 struct ncclCommInitRankAsyncJob { struct ncclAsyncJob base; @@ -1087,9 +1094,16 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { ncclUniqueId commId = job->commId; // C++ struct assignment int myrank = job->myrank; int cudaDev = job->cudaDev; + int archMajor, archMinor; + size_t maxLocalSizeBytes = 0; ncclResult_t res = ncclSuccess; CUDACHECKGOTO(cudaSetDevice(cudaDev), res, fail); + CUDACHECK(cudaDeviceGetAttribute(&archMajor, cudaDevAttrComputeCapabilityMajor, cudaDev)); + CUDACHECK(cudaDeviceGetAttribute(&archMinor, cudaDevAttrComputeCapabilityMinor, cudaDev)); + comm->cudaArch = 100*archMajor + 10*archMinor; + + NCCLCHECK(ncclInitKernelsForDevice(comm->cudaArch, &maxLocalSizeBytes)); // Set the maximum kernel stack size of all kernels to avoid // a CUDA memory reconfig on load (c.f. NVSHMEM issue) if (maxLocalSizeBytes > 0 && ncclParamSetStackSize() == 1) { @@ -1114,18 +1128,143 @@ fail: goto exit; } -static ncclResult_t parseCommConfig(ncclComm_t comm, ncclConfig_t *config) { - ncclResult_t ret = ncclSuccess; - - /* first set configuration */ - if (config) { - comm->blocking = config->blocking; - } else { - /* default setting of communicator */ - comm->blocking = 1; +#define NCCL_CONFIG_DEFAULT(config, field, undef, defvalue, fieldStr, format) \ + if (config->field == undef) { \ + config->field = defvalue; \ + } else { \ + INFO(NCCL_ENV, "Comm config " fieldStr " set to " format, config->field); \ } +static ncclResult_t parseCommConfig(ncclComm_t comm, ncclConfig_t *config) { + ncclResult_t ret = ncclSuccess; + /* config must not be NULL in this function */ + int blockingEnv; + int cgaClusterSizeEnv; + int minCTAsEnv; + int maxCTAsEnv; + const char *envNetName, *tmpNetName; + ncclConfig_t defaultConfig = NCCL_CONFIG_INITIALIZER; + ncclConfig_t internalConfig = NCCL_CONFIG_INITIALIZER; + ncclConfig_t *internalConfigPtr; + size_t realSize; + + internalConfigPtr = &internalConfig; + if (config) { + memcpy((void*)&realSize, (void*)config, sizeof(size_t)); + realSize = realSize > sizeof(ncclConfig_t) ? sizeof(ncclConfig_t) : realSize; + memcpy((void*)internalConfigPtr, (void*)config, realSize); + if (internalConfigPtr->magic != 0xcafebeef) { + WARN("ncclConfig_t argument not initialized via NCCL_CONFIG_INITIALIZER"); + ret = ncclInvalidArgument; + goto fail; + } + + /* check version. 
*/ + if (internalConfigPtr->version < NCCL_VERSION(2, 14, 0)) { + internalConfigPtr->blocking = defaultConfig.blocking; + } + + if (internalConfigPtr->version < NCCL_VERSION(2, 17, 0)) { + internalConfigPtr->cgaClusterSize = defaultConfig.cgaClusterSize; + internalConfigPtr->minCTAs = defaultConfig.minCTAs; + internalConfigPtr->maxCTAs = defaultConfig.maxCTAs; + internalConfigPtr->netName = defaultConfig.netName; + } + } + + /* check input config attributes, -1 means user-undefined and we should use default value from NCCL. */ + if (internalConfigPtr->blocking != NCCL_CONFIG_UNDEF_INT && internalConfigPtr->blocking != 0 && internalConfigPtr->blocking != 1) { + WARN("Invalid config blocking attribute value %d", internalConfigPtr->blocking); + ret = ncclInvalidArgument; + goto fail; + } + + if (internalConfigPtr->cgaClusterSize != NCCL_CONFIG_UNDEF_INT && internalConfigPtr->cgaClusterSize < 0) { + WARN("Invalid config cgaClusterSize attribute value %d", internalConfigPtr->cgaClusterSize); + ret = ncclInvalidArgument; + goto fail; + } + + if ((internalConfigPtr->minCTAs != NCCL_CONFIG_UNDEF_INT && + internalConfigPtr->minCTAs <= 0) || + (internalConfigPtr->maxCTAs != NCCL_CONFIG_UNDEF_INT && + internalConfigPtr->maxCTAs <= 0) || + (internalConfigPtr->minCTAs > internalConfigPtr->maxCTAs)) { + WARN("Invalid config min/max channels attribute value %d/%d", internalConfigPtr->minCTAs, internalConfigPtr->maxCTAs); + ret = ncclInvalidArgument; + goto fail; + } + + /* default config value can be tuned on different platform. */ + NCCL_CONFIG_DEFAULT(internalConfigPtr, blocking, NCCL_CONFIG_UNDEF_INT, 1, "Blocking", "%d"); + NCCL_CONFIG_DEFAULT(internalConfigPtr, cgaClusterSize, NCCL_CONFIG_UNDEF_INT, 4, "CGA cluster size", "%d"); + NCCL_CONFIG_DEFAULT(internalConfigPtr, minCTAs, NCCL_CONFIG_UNDEF_INT, 1, "Min CTAs", "%d"); + NCCL_CONFIG_DEFAULT(internalConfigPtr, maxCTAs, NCCL_CONFIG_UNDEF_INT, MAXCHANNELS, "Max CTAs", "%d"); + NCCL_CONFIG_DEFAULT(internalConfigPtr, netName, NCCL_CONFIG_UNDEF_PTR, NULL, "Net name", "%s"); + + tmpNetName = internalConfigPtr->netName; + + /* assign config to communicator */ + comm->blocking = internalConfigPtr->blocking; + comm->cgaClusterSize = internalConfigPtr->cgaClusterSize; + comm->minCTAs = internalConfigPtr->minCTAs; + comm->maxCTAs = internalConfigPtr->maxCTAs; + + /* override configuration from env variable. */ + blockingEnv = ncclParamCommBlocking(); + if (blockingEnv == 0 || blockingEnv == 1) + comm->blocking = blockingEnv; + + cgaClusterSizeEnv = ncclParamCGAClusterSize(); + if (0 <= cgaClusterSizeEnv && cgaClusterSizeEnv <= NCCL_MAX_CGA_CLUSTER_SIZE) { + comm->cgaClusterSize = cgaClusterSizeEnv; + } else if (cgaClusterSizeEnv > NCCL_MAX_CGA_CLUSTER_SIZE) { + WARN("NCCL_CGA_CLUSTER_SIZE value %d is too big. 
Limiting value to %d.", cgaClusterSizeEnv, NCCL_MAX_CGA_CLUSTER_SIZE); + comm->cgaClusterSize = NCCL_MAX_CGA_CLUSTER_SIZE; + } + + minCTAsEnv = ncclParamMinCTAs(); + if (minCTAsEnv != NCCL_CONFIG_UNDEF_INT) { + comm->minCTAs = minCTAsEnv; + } + + maxCTAsEnv = ncclParamMaxCTAs(); + if (maxCTAsEnv != NCCL_CONFIG_UNDEF_INT) { + comm->maxCTAs = maxCTAsEnv; + } + + /* cap channels if needed */ + if (comm->minCTAs > MAXCHANNELS) { + WARN("minCTAs %d is larger than #channels upper limit %d", comm->minCTAs, MAXCHANNELS); + comm->minCTAs = MAXCHANNELS; + } + + if (comm->maxCTAs > MAXCHANNELS) { + WARN("maxCTAs %d is larger than #channels upper limit %d", comm->maxCTAs, MAXCHANNELS); + comm->maxCTAs = MAXCHANNELS; + } + + if (comm->minCTAs > comm->maxCTAs) { + WARN("minCTAs %d is larger than maxCTAs %d", comm->minCTAs, comm->maxCTAs); + ret = ncclInvalidArgument; + goto fail; + } + + envNetName = getenv("NCCL_NET"); + if (envNetName) + tmpNetName = envNetName; + if (tmpNetName != NULL) { + int netNameLen = strlen(tmpNetName) + 1; + comm->netName = (char*)malloc(netNameLen); + memcpy(comm->netName, tmpNetName, netNameLen); + } else { + comm->netName = NULL; + } + +exit: return ret; +fail: + goto exit; } static void ncclCommInitRankUndo(struct ncclAsyncJob* job_) { @@ -1151,6 +1290,7 @@ static ncclResult_t ncclCommInitRankDev(ncclComm_t* newcomm, int nranks, ncclUni CUDACHECKGOTO(cudaFree(NULL), res, fail); NCCLCHECKGOTO(PtrCheck(newcomm, "CommInitRank", "newcomm"), res, fail); + NCCLCHECKGOTO(PtrCheck(config, "CommInitRank", "config"), res, fail); if (nranks < 1 || myrank < 0 || myrank >= nranks) { WARN("Invalid rank requested : %d/%d", myrank, nranks); res = ncclInvalidArgument; @@ -1201,12 +1341,13 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int nranks, ncclUniqueId comm (void)ncclCudaLibraryInit(); int cudaDev; + ncclConfig_t config = NCCL_CONFIG_INITIALIZER; CUDACHECK(cudaGetDevice(&cudaDev)); NvtxParamsCommInitRank payload{myrank, nranks, cudaDev}; NVTX3_FUNC_WITH_PARAMS(CommInitRank, CommInitRankSchema, payload) - NCCLCHECK(ncclCommInitRankDev(newcomm, nranks, commId, myrank, cudaDev, NULL)); + NCCLCHECK(ncclCommInitRankDev(newcomm, nranks, commId, myrank, cudaDev, &config)); return ncclSuccess; } @@ -1215,6 +1356,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) { ncclResult_t ret = ncclSuccess; int totalnDev; int *gpuFlags = NULL; + ncclConfig_t config = NCCL_CONFIG_INITIALIZER; constexpr nvtxPayloadSchemaEntry_t CommInitAllSchema[] = { {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "No. of devices"} @@ -1258,7 +1400,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) { NCCLCHECKGOTO(ncclGroupStart(), ret, fail); for (int i=0; i sizeof(ncclConfig_t) ? sizeof(ncclConfig_t) : realSize; - memcpy((void*)internalConfigPtr, (void*)config, realSize); - if (internalConfigPtr->magic != 0xcafebeef) { - WARN("ncclConfig_t argument not initialized via NCCL_CONFIG_INITIALIZER"); - ret = ncclInvalidArgument; - goto exit; - } - } - - /* check input config attributes */ - if (internalConfigPtr->blocking != 0 && internalConfigPtr->blocking != 1) { - WARN("Invalid config blocking attribute value %d", internalConfigPtr->blocking); - ret = ncclInvalidArgument; - goto exit; - } - - /* overwrite configuration from env variable. 
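To illustrate the new communicator configuration path (a sketch, not part of the diff; the function name and chosen values are placeholders): the options added in this release are set through ncclConfig_t and passed to ncclCommInitRankConfig, and the NCCL_CGA_CLUSTER_SIZE, NCCL_MIN_CTAS, NCCL_MAX_CTAS and NCCL_NET environment variables parsed above take precedence over them:

    #include "nccl.h"

    ncclResult_t initWithConfig(ncclComm_t* comm, int nranks, ncclUniqueId id, int rank) {
      ncclConfig_t config = NCCL_CONFIG_INITIALIZER;  // sets magic/version; fields start undefined
      config.blocking       = 0;         // non-blocking init (call may return ncclInProgress)
      config.cgaClusterSize = 2;         // new in 2.17: CGA cluster size hint
      config.minCTAs        = 4;         // new in 2.17: lower bound on CTAs/channels
      config.maxCTAs        = 16;        // new in 2.17: upper bound on CTAs/channels
      config.netName        = "Socket";  // new in 2.17: select a network module by name
      return ncclCommInitRankConfig(comm, nranks, id, rank, &config);
    }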
*/ - blockingEnv = ncclParamCommBlocking(); - if (blockingEnv != 0 && blockingEnv != 1) { - WARN("Invalid NCCL_COMM_BLOCKING value %d", blockingEnv); - } - if (blockingEnv == 1) internalConfigPtr->blocking = blockingEnv; (void)ncclCudaLibraryInit(); - CUDACHECKGOTO(cudaGetDevice(&cudaDev), ret, exit); + CUDACHECKGOTO(cudaGetDevice(&cudaDev), ret, fail); + + if (config == NULL) + internalConfigPtr = &internalConfig; + else + internalConfigPtr = config; NCCLCHECKGOTO(ncclCommInitRankDev(newcomm, nranks, commId, myrank, cudaDev, internalConfigPtr), ret, fail); exit: diff --git a/src/misc/cudawrap.cc b/src/misc/cudawrap.cc index e2c1a6f..4fe9023 100644 --- a/src/misc/cudawrap.cc +++ b/src/misc/cudawrap.cc @@ -23,11 +23,33 @@ DECLARE_CUDA_PFN(cuMemGetAddressRange, 3020); /* proxy.cc */ DECLARE_CUDA_PFN(cuCtxCreate, 3020); DECLARE_CUDA_PFN(cuCtxDestroy, 4000); +DECLARE_CUDA_PFN(cuCtxGetCurrent, 4000); DECLARE_CUDA_PFN(cuCtxSetCurrent, 4000); +DECLARE_CUDA_PFN(cuCtxGetDevice, 2000); +/* cuMem API support */ +DECLARE_CUDA_PFN(cuMemAddressReserve, 10020); +DECLARE_CUDA_PFN(cuMemAddressFree, 10020); +DECLARE_CUDA_PFN(cuMemCreate, 10020); +DECLARE_CUDA_PFN(cuMemGetAllocationGranularity, 10020); +DECLARE_CUDA_PFN(cuMemExportToShareableHandle, 10020); +DECLARE_CUDA_PFN(cuMemImportFromShareableHandle, 10020); +DECLARE_CUDA_PFN(cuMemMap, 10020); +DECLARE_CUDA_PFN(cuMemRelease, 10020); +DECLARE_CUDA_PFN(cuMemSetAccess, 10020); +DECLARE_CUDA_PFN(cuMemUnmap, 10020); #if CUDA_VERSION >= 11070 /* transport/collNet.cc/net.cc*/ DECLARE_CUDA_PFN(cuMemGetHandleForAddressRange, 11070); // DMA-BUF support #endif +#if CUDA_VERSION >= 12010 +/* NVSwitch Multicast support */ +DECLARE_CUDA_PFN(cuMulticastAddDevice, 12010); +DECLARE_CUDA_PFN(cuMulticastBindMem, 12010); +DECLARE_CUDA_PFN(cuMulticastBindAddr, 12010); +DECLARE_CUDA_PFN(cuMulticastCreate, 12010); +DECLARE_CUDA_PFN(cuMulticastGetGranularity, 12010); +DECLARE_CUDA_PFN(cuMulticastUnbind, 12010); +#endif #endif /* CUDA Driver functions loaded with dlsym() */ @@ -39,6 +61,7 @@ DECLARE_CUDA_PFN(cuGetProcAddress, 11030); static void *cudaLib; int ncclCudaDriverVersionCache = -1; +bool ncclCudaLaunchBlocking = false; #if CUDART_VERSION >= 11030 /* @@ -62,9 +85,33 @@ static ncclResult_t cudaPfnFuncLoader(void) { LOAD_SYM(cuMemGetAddressRange, 3020, 1); LOAD_SYM(cuCtxCreate, 3020, 1); LOAD_SYM(cuCtxDestroy, 4000, 1); + LOAD_SYM(cuCtxGetCurrent, 4000, 1); LOAD_SYM(cuCtxSetCurrent, 4000, 1); + LOAD_SYM(cuCtxGetDevice, 2000, 1); +/* cuMem API support */ +#if CUDA_VERSION >= 11030 + LOAD_SYM(cuMemAddressReserve, 10020, 1); + LOAD_SYM(cuMemAddressFree, 10020, 1); + LOAD_SYM(cuMemCreate, 10020, 1); + LOAD_SYM(cuMemGetAllocationGranularity, 10020, 1); + LOAD_SYM(cuMemExportToShareableHandle, 10020, 1); + LOAD_SYM(cuMemImportFromShareableHandle, 10020, 1); + LOAD_SYM(cuMemMap, 10020, 1); + LOAD_SYM(cuMemRelease, 10020, 1); + LOAD_SYM(cuMemSetAccess, 10020, 1); + LOAD_SYM(cuMemUnmap, 10020, 1); +#endif #if CUDA_VERSION >= 11070 LOAD_SYM(cuMemGetHandleForAddressRange, 11070, 1); // DMA-BUF support +#endif +#if CUDA_VERSION >= 12010 +/* NVSwitch Multicast support */ + LOAD_SYM(cuMulticastAddDevice, 12010, 1); + LOAD_SYM(cuMulticastBindMem, 12010, 1); + LOAD_SYM(cuMulticastBindAddr, 12010, 1); + LOAD_SYM(cuMulticastCreate, 12010, 1); + LOAD_SYM(cuMulticastGetGranularity, 12010, 1); + LOAD_SYM(cuMulticastUnbind, 12010, 1); #endif return ncclSuccess; } @@ -74,6 +121,11 @@ static pthread_once_t initOnceControl = PTHREAD_ONCE_INIT; static ncclResult_t initResult; static void 
initOnceFunc() { + do { + char* val = getenv("CUDA_LAUNCH_BLOCKING"); + ncclCudaLaunchBlocking = val!=nullptr && val[0]!=0 && !(val[0]=='0' && val[1]==0); + } while (0); + CUresult res; /* * Load CUDA driver library @@ -85,9 +137,10 @@ static void initOnceFunc() { else snprintf(path, 1024, "%s%s", ncclCudaPath, "libcuda.so"); + (void) dlerror(); // Clear any previous errors cudaLib = dlopen(path, RTLD_LAZY); if (cudaLib == NULL) { - WARN("Failed to find CUDA library (NCCL_CUDA_PATH='%s') : %s", ncclCudaPath ? ncclCudaPath : "", dlerror()); + WARN("Failed to find CUDA library %s (NCCL_CUDA_PATH='%s') : %s", path, ncclCudaPath ? ncclCudaPath : "", dlerror()); goto error; } diff --git a/src/misc/ipcsocket.cc b/src/misc/ipcsocket.cc new file mode 100644 index 0000000..b2dee48 --- /dev/null +++ b/src/misc/ipcsocket.cc @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2016-2023, NVIDIA CORPORATION. All rights reserved. + * + * See COPYRIGHT for license information + */ + +#include "ipcsocket.h" +#include "utils.h" +#include +#include +#include + +// Enable Linux abstract socket naming +#define USE_ABSTRACT_SOCKET + +#define NCCL_IPC_SOCKNAME_STR "/tmp/nccl-socket-%d-%lx" + +/* + * Create a Unix Domain Socket + */ +ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash, volatile uint32_t* abortFlag) { + int fd = -1; + struct sockaddr_un cliaddr; + char temp[NCCL_IPC_SOCKNAME_LEN] = ""; + + if (handle == NULL) { + return ncclInternalError; + } + + handle->fd = -1; + handle->socketName[0] = '\0'; + if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) { + WARN("UDS: Socket creation error : %d", errno); + return ncclSystemError; + } + + bzero(&cliaddr, sizeof(cliaddr)); + cliaddr.sun_family = AF_UNIX; + + // Create unique name for the socket. + int len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); + if (len > (sizeof(cliaddr.sun_path) - 1)) { + WARN("UDS: Cannot bind provided name to socket. 
Name too large"); + return ncclInternalError; + } +#ifndef USE_ABSTRACT_SOCKET + unlink(temp); +#endif + + TRACE(NCCL_INIT, "UDS: Creating socket %s", temp); + + strncpy(cliaddr.sun_path, temp, len); +#ifdef USE_ABSTRACT_SOCKET + cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick +#endif + if (bind(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) < 0) { + WARN("UDS: Binding to socket %s failed : %d", temp, errno); + close(fd); + return ncclSystemError; + } + + handle->fd = fd; + strcpy(handle->socketName, temp); + + handle->abortFlag = abortFlag; + // Mark socket as non-blocking + if (handle->abortFlag) { + int flags; + EQCHECK(flags = fcntl(fd, F_GETFL), -1); + SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl"); + } + + return ncclSuccess; +} + +ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) { + if (handle == NULL) { + return ncclInternalError; + } + if (handle->fd <= 0) { + return ncclSuccess; + } +#ifndef USE_ABSTRACT_SOCKET + if (handle->socketName[0] != '\0') { + unlink(handle->socketName); + } +#endif + close(handle->fd); + + return ncclSuccess; +} + +ncclResult_t ncclIpcSocketRecvFd(ncclIpcSocket *handle, int *recvFd) { + struct msghdr msg = {0, 0, 0, 0, 0, 0, 0}; + struct iovec iov[1]; + + // Union to guarantee alignment requirements for control array + union { + struct cmsghdr cm; + char control[CMSG_SPACE(sizeof(int))]; + } control_un; + + struct cmsghdr *cmptr; + char dummy_buffer[1]; + int ret; + + msg.msg_control = control_un.control; + msg.msg_controllen = sizeof(control_un.control); + + iov[0].iov_base = (void *)dummy_buffer; + iov[0].iov_len = sizeof(dummy_buffer); + + msg.msg_iov = iov; + msg.msg_iovlen = 1; + + while ((ret = recvmsg(handle->fd, &msg, 0)) <= 0) { + if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { + WARN("UDS: Receiving data over socket failed : %d", errno); + return ncclSystemError; + } + if (handle->abortFlag && *handle->abortFlag) return ncclInternalError; + } + + if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) { + if ((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) { + WARN("UDS: Receiving data over socket failed"); + return ncclSystemError; + } + + memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd)); + } else { + WARN("UDS: Receiving data over socket %s failed", handle->socketName); + return ncclSystemError; + } + + TRACE(NCCL_INIT|NCCL_P2P, "UDS: Got recvFd %d from socket %s", *recvFd, handle->socketName); + + return ncclSuccess; +} + +ncclResult_t ncclIpcSocketSendFd(ncclIpcSocket *handle, const int sendFd, int rank, uint64_t hash) { + struct msghdr msg; + struct iovec iov[1]; + char temp[NCCL_IPC_SOCKNAME_LEN]; + + union { + struct cmsghdr cm; + char control[CMSG_SPACE(sizeof(int))]; + } control_un; + + struct cmsghdr *cmptr; + struct sockaddr_un cliaddr; + + // Construct client address to send this shareable handle to + bzero(&cliaddr, sizeof(cliaddr)); + cliaddr.sun_family = AF_UNIX; + + int len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); + if (len > (sizeof(cliaddr.sun_path) - 1)) { + WARN("UDS: Cannot connect to provided name for socket. 
Name too large"); + return ncclInternalError; + } + (void) strncpy(cliaddr.sun_path, temp, len); + + TRACE(NCCL_INIT, "UDS: Sending fd %d to UDS socket %s", sendFd, temp); + +#ifdef USE_ABSTRACT_SOCKET + cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick +#endif + + msg.msg_control = control_un.control; + msg.msg_controllen = sizeof(control_un.control); + + cmptr = CMSG_FIRSTHDR(&msg); + cmptr->cmsg_len = CMSG_LEN(sizeof(int)); + cmptr->cmsg_level = SOL_SOCKET; + cmptr->cmsg_type = SCM_RIGHTS; + + memmove(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd)); + + msg.msg_name = (void *)&cliaddr; + msg.msg_namelen = sizeof(struct sockaddr_un); + + iov[0].iov_base = (void *)""; + iov[0].iov_len = 1; + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_flags = 0; + + ssize_t sendResult; + while ((sendResult = sendmsg(handle->fd, &msg, 0)) <= 0) { + if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { + WARN("UDS: Sending data over socket %s failed : %d", temp, errno); + return ncclSystemError; + } + if (handle->abortFlag && *handle->abortFlag) return ncclInternalError; + } + + return ncclSuccess; +} diff --git a/src/misc/socket.cc b/src/misc/socket.cc index 9f93e26..56c96c5 100644 --- a/src/misc/socket.cc +++ b/src/misc/socket.cc @@ -43,7 +43,7 @@ static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { int closed; - NCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0, &closed)); + NCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0 /*block*/, &closed)); if (closed) { char line[SOCKET_NAME_MAXLEN+1]; WARN("socketProgress: Connection closed by remote peer %s", ncclSocketToString(&sock->addr, line, 0)); @@ -785,16 +785,33 @@ ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size) { } // Receive or detect connection closed -ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed) { +ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed, bool blocking) { int offset = 0; if (sock == NULL) { WARN("ncclSocketTryRecv: pass NULL socket"); return ncclInvalidArgument; } *closed = 0; - while (offset < size) { + // Block until connection closes or nbytes received + if (blocking) { + while (offset < size) { + NCCLCHECK(socketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed)); + if (*closed) return ncclSuccess; + } + } else { NCCLCHECK(socketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed)); if (*closed) return ncclSuccess; + + // If any bytes were received, block waiting for the rest + if (offset > 0) { + while (offset < size) { + NCCLCHECK(socketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed)); + if (*closed) return ncclSuccess; + } + // No bytes were received, return ncclInProgress + } else { + return ncclInProgress; + } } return ncclSuccess; } diff --git a/src/nccl.h.in b/src/nccl.h.in index 44a68e9..ff981e4 100644 --- a/src/nccl.h.in +++ b/src/nccl.h.in @@ -25,8 +25,10 @@ extern "C" { #endif +#include /* Opaque handle to communicator */ typedef struct ncclComm* ncclComm_t; +#define NCCL_COMM_NULL NULL #define NCCL_UNIQUE_ID_BYTES 128 typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId; @@ -42,15 +44,22 @@ typedef enum { ncclSuccess = 0, ncclInProgress = 7, ncclNumResults = 8 } ncclResult_t; +#define NCCL_CONFIG_UNDEF_INT INT_MIN +#define NCCL_CONFIG_UNDEF_PTR NULL + /* Communicator configuration. 
Users can assign value to attributes to specify the * behavior of a communicator. */ -typedef struct ncclConfig_v21400 { +typedef struct ncclConfig_v21700 { /* attributes that users should never touch. */ size_t size; unsigned int magic; unsigned int version; /* attributes that users are able to customize. */ int blocking; + int cgaClusterSize; + int minCTAs; + int maxCTAs; + const char *netName; } ncclConfig_t; /* Config initializer must be assigned to initialize config structure when it is created. @@ -59,7 +68,11 @@ typedef struct ncclConfig_v21400 { sizeof(ncclConfig_t), /* size */ \ 0xcafebeef, /* magic */ \ NCCL_VERSION(NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH), /* version */ \ - 1 /* blocking */ \ + NCCL_CONFIG_UNDEF_INT, /* blocking */ \ + NCCL_CONFIG_UNDEF_INT, /* cgaClusterSize */ \ + NCCL_CONFIG_UNDEF_INT, /* minCTAs */ \ + NCCL_CONFIG_UNDEF_INT, /* maxCTAs */ \ + NCCL_CONFIG_UNDEF_PTR /* netName */ \ } /* Return the NCCL_VERSION_CODE of the NCCL library in the supplied integer. diff --git a/src/net.cc b/src/net.cc index 1480c76..1315b3d 100644 --- a/src/net.cc +++ b/src/net.cc @@ -176,14 +176,8 @@ ncclResult_t ncclNetPluginInit() { } void* netPluginLib = dlopen(ncclNetPluginName, RTLD_NOW | RTLD_LOCAL); if (netPluginLib == nullptr) { - // dlopen does not guarantee to set errno, but dlerror only gives us a - // string, so checking errno doesn't hurt to try to provide a better - // error message - if (errno == ENOENT) { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : No plugin found (%s), using internal implementation", ncclNetPluginName); - } else { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : Plugin load returned %d : %s.", errno, dlerror()); - } + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : Plugin load (%s) returned %d : %s", ncclNetPluginName, errno, dlerror()); + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : No plugin found, using internal implementation"); return ncclSuccess; } @@ -264,9 +258,10 @@ static ncclResult_t collNetGetState(int i, enum ncclNetState* state) { ncclResult_t ncclNetInit(struct ncclComm* comm) { // Initialize main communication network - char* netName = getenv("NCCL_NET"); + char* netName; bool ok = false; + netName = comm->netName; for (int i=0; i<3; i++) { if (ncclNets[i] == nullptr) continue; enum ncclNetState state; @@ -324,9 +319,26 @@ ncclResult_t ncclGpuGdrSupport(struct ncclComm* comm, int* gdrSupport) { ncclResult_t ret; ncclDebugNoWarn = NCCL_NET; NCCLCHECKGOTO(ncclNetListen(comm, dev, &handle, &lComm), ret, cleanup1); - NCCLWAITGOTO(ncclNetConnect(comm, dev, &handle, &sComm), sComm != NULL, comm->abortFlag, ret, cleanup2); - NCCLWAITGOTO(ncclNetAccept(comm, lComm, &rComm), rComm != NULL, comm->abortFlag, ret, cleanup3); - CUDACHECKGOTO(cudaMalloc(&gpuPtr, GPU_BUF_SIZE), ret, cleanup4); + + bool connected; + connected = false; + while (!connected) { + + // If we're aborting now, skip to cleanup + if (*comm->abortFlag) { + goto cleanup2; + } + + if (sComm == NULL) + NCCLCHECKGOTO(ncclNetConnect(comm, dev, &handle, &sComm), ret, cleanup2); + + if (rComm == NULL) + NCCLCHECKGOTO(ncclNetAccept(comm, lComm, &rComm), ret, cleanup2); + + connected = (rComm != NULL) && (sComm != NULL); + } + + CUDACHECKGOTO(cudaMalloc(&gpuPtr, GPU_BUF_SIZE), ret, cleanup2); if (ncclNetRegMr(comm, sComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle) == ncclSuccess) { NCCLCHECK(ncclNetDeregMr(comm, sComm, mHandle)); NCCLCHECK(ncclNetRegMr(comm, rComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle)); @@ -335,11 +347,11 @@ ncclResult_t ncclGpuGdrSupport(struct ncclComm* comm, int* gdrSupport) { } 
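To show how the extended ncclConfig_t and NCCL_CONFIG_INITIALIZER above are meant to be used, here is a minimal caller-side sketch. It assumes the existing config-aware entry point ncclCommInitRankConfig() declared in nccl.h (not touched by this hunk); the field values are examples, and anything left at NCCL_CONFIG_UNDEF_INT / NCCL_CONFIG_UNDEF_PTR falls back to the defaults and environment overrides handled in init.cc above.

#include <nccl.h>

/* Minimal sketch: create a communicator with the new 2.17 tuning knobs. */
ncclResult_t initCommWithConfig(ncclComm_t* comm, int nRanks, ncclUniqueId id, int myRank) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;  /* required: sets size/magic/version */
  config.blocking       = 1;         /* blocking communicator behavior */
  config.cgaClusterSize = 4;         /* new: CGA cluster size */
  config.minCTAs        = 4;         /* new: lower bound on CTAs (channels) */
  config.maxCTAs        = 16;        /* new: upper bound on CTAs (channels) */
  config.netName        = "Socket";  /* new: request a specific net backend (example value) */
  /* NCCL_* environment variables (e.g. NCCL_CGA_CLUSTER_SIZE, NCCL_NET) still override these. */
  return ncclCommInitRankConfig(comm, nRanks, id, myRank, &config);
}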
ncclDebugNoWarn = 0; CUDACHECK(cudaFree(gpuPtr)); -cleanup4: - NCCLCHECK(ncclNetCloseRecv(comm, rComm)); -cleanup3: - NCCLCHECK(ncclNetCloseSend(comm, sComm)); cleanup2: + if (rComm != NULL) + NCCLCHECK(ncclNetCloseRecv(comm, rComm)); + if (sComm != NULL) + NCCLCHECK(ncclNetCloseSend(comm, sComm)); NCCLCHECK(ncclNetCloseListen(comm, lComm)); cleanup1: break; diff --git a/src/proxy.cc b/src/proxy.cc index cfaa266..298b66f 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -14,6 +14,7 @@ #include "timer.h" #include +#include enum { proxyRecv=0, proxySend=1 }; @@ -37,6 +38,155 @@ struct ncclProxyPool { struct ncclProxyArgs elems[PROXYARGS_ALLOCATE_SIZE]; }; +static void expectedProxyResponseFree(struct ncclProxyState* state) { + struct ncclExpectedProxyResponse* elem = state->expectedResponses; + struct ncclExpectedProxyResponse* prev = NULL; + + while (elem) { + prev = elem; + elem = elem->next; + free(prev->respBuff); + free(prev); + } +} + +static ncclResult_t expectedProxyResponseStore(struct ncclProxyState* state, void* opId, void* respBuff, int respSize) { + struct ncclExpectedProxyResponse* elem = state->expectedResponses; + while (elem) { + if (elem->opId == opId) { + if (respSize != elem->respSize) { + WARN("Mismatched response size for opId=%p", opId); + return ncclInternalError; + } + + if (elem->done) { + WARN("Storing response for already completed opId=%p", opId); + return ncclInternalError; + } + + memcpy(elem->respBuff, respBuff, respSize); + elem->done = true; + return ncclSuccess; + } + elem = elem->next; + } + + WARN("Proxy response for opId=%p doesn't match any expected response", opId); + return ncclInternalError; +} + +static ncclResult_t expectedProxyResponseEnqueue(struct ncclProxyState* state, void* opId, int respSize, void* respData, int respDataSize) { + struct ncclExpectedProxyResponse* ex; + NCCLCHECK(ncclCalloc(&ex, 1)); + ex->opId = opId; + + // Pre-alloc response buffer + ex->respBuff = malloc(respSize); + ex->respSize = respSize; + ex->done = false; + if (respData) { + memcpy(ex->respBuff, respData, respDataSize); + ex->done = true; + } + + // Enqueue + struct ncclExpectedProxyResponse* list = state->expectedResponses; + if (list == NULL) { + state->expectedResponses = ex; + return ncclSuccess; + } + while (list->next) list = list->next; + list->next = ex; + return ncclSuccess; +} + +static ncclResult_t expectedProxyResponseDequeue(struct ncclProxyState* state, void* opId, void* respBuff, int* found) { + struct ncclExpectedProxyResponse* elem = state->expectedResponses; + struct ncclExpectedProxyResponse* prev = NULL; + *found = 0; + while (elem) { + if ((elem->opId == opId) && elem->done) { + if (prev == NULL) { + state->expectedResponses = elem->next; + } else { + prev->next = elem->next; + } + memcpy(respBuff, elem->respBuff, elem->respSize); + free(elem->respBuff); + free(elem); + *found = 1; + return ncclSuccess; + } + prev = elem; + elem = elem->next; + } + return ncclSuccess; +} + +static ncclResult_t expectedProxyResponseRemove(struct ncclProxyState* state, void* opId) { + struct ncclExpectedProxyResponse* elem = state->expectedResponses; + struct ncclExpectedProxyResponse* prev = NULL; + while (elem) { + if (elem->opId == opId) { + if (prev == NULL) { + state->expectedResponses = elem->next; + } else { + prev->next = elem->next; + } + free(elem->respBuff); + free(elem); + return ncclSuccess; + } + prev = elem; + elem = elem->next; + } + WARN("Couldn't find opId=%p", opId); + return ncclInternalError; +} + +static ncclResult_t asyncProxyOpEnqueue(struct 
ncclProxyLocalPeer* peer, ncclProxyAsyncOp* op) { + ncclProxyAsyncOp* list = peer->asyncOps; + if (list == NULL) { + peer->asyncOps = op; + return ncclSuccess; + } + while (list->next) list = list->next; + list->next = op; + return ncclSuccess; +} + +static ncclResult_t asyncProxyOpDequeue(struct ncclProxyLocalPeer* peer, ncclProxyAsyncOp* op) { + struct ncclProxyAsyncOp* elem = peer->asyncOps; + struct ncclProxyAsyncOp* prev = NULL; + while (elem) { + if (elem->opId == op->opId) { + if (prev == NULL) { + peer->asyncOps = elem->next; + } else { + prev->next = elem->next; + } + + if (elem->reqBuff) { + free(elem->reqBuff); + } + if (elem->respBuff) { + free(elem->respBuff); + } + free(elem); + + return ncclSuccess; + } + prev = elem; + elem = elem->next; + } + if (op) { + WARN("Attempting to dequeue nonexistent async opId=%p", op->opId); + } else { + WARN("Attempting to dequeue null operation"); + } + return ncclInternalError; +} + static ncclResult_t allocateArgs(struct ncclProxyProgressState* state, struct ncclProxyArgs** argsptr) { struct ncclProxyArgs* elem; if (state->pool == NULL) { @@ -86,7 +236,7 @@ ncclResult_t getOpIndex(struct ncclProxyArgs* op, struct ncclProxyProgressState* pool = pool->next; p++; } - WARN("Could not find pool of op %p\n", op); + WARN("Could not find pool of op %p", op); return ncclInternalError; } @@ -140,7 +290,7 @@ ncclResult_t dumpProxyState(struct ncclProxyProgressState* state) { nextOp->state |= OP_SEEN; printf("\n"); if (nextOp->next) { - WARN("Inactive op has next set!\n"); + WARN("Inactive op has next set!"); } nextOp = nextOp->nextPeer; } @@ -337,7 +487,7 @@ ncclResult_t ncclLocalOpAppend(struct ncclComm* comm, struct ncclProxyConnector* } } if (lastOp == -1) { - WARN("Unable to post incomplete proxy op chain %d..%d (opCount %ld)\n", proxyOps->nextOps, proxyOps->nextOpsEnd, lastOpCount); + WARN("Unable to post incomplete proxy op chain %d..%d (opCount %ld)", proxyOps->nextOps, proxyOps->nextOpsEnd, lastOpCount); return ncclInternalError; } // Cut chain at lastOp @@ -770,19 +920,6 @@ ncclResult_t ncclProxyProgressDestroy(struct ncclComm* comm) { return ncclSuccess; } -struct ncclProxyAsyncOp { - int type; - struct ncclProxyConnection* connection; - int reqSize, respSize; - char *reqBuff, *respBuff; -}; - -struct ncclProxyLocalPeer { - struct ncclSocket sock; - int localRank; - struct ncclProxyAsyncOp asyncOps; -}; - #define NCCL_PROXY_CONN_POOL_SIZE_POW2 7 #define NCCL_PROXY_CONN_POOL_SIZE (1<<(NCCL_PROXY_CONN_POOL_SIZE_POW2)) #define NCCL_PROXY_CONN_POOL_MASK ((NCCL_PROXY_CONN_POOL_SIZE)-1) @@ -790,7 +927,6 @@ struct ncclProxyConnectionPool { struct ncclProxyConnection** pools; int banks; int offset; - struct ncclProxyAsyncOp* ops; }; static ncclResult_t ncclProxyNewConnection(struct ncclProxyConnectionPool* pool, int* id) { @@ -888,26 +1024,137 @@ ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, in return ncclSuccess; } -const char* ncclProxyMsgTypeStr[] = { "Unknown", "Init", "SharedInit", "Setup", "Connect", "Start", "Close", "Abort", "Stop" }; -ncclResult_t ncclProxyCall(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize) { +const char* ncclProxyMsgTypeStr[] = { "Unknown", "Init", "SharedInit", "Setup", "Connect", "Start", "Close", "Abort", "Stop", "ConvertFd" }; +ncclResult_t ncclProxyCallAsync(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, int respSize, void* opId) { struct ncclSocket* sock; ncclResult_t ret = ncclSuccess; + void* 
respData = NULL; + int respDataSize = 0; + struct ncclComm* comm = proxyConn->comm; + struct ncclIpcSocket ipcSock = { 0 }; - if (proxyConn->comm->proxyState.peerSocks == NULL) return ncclInternalError; - sock = proxyConn->comm->proxyState.peerSocks + proxyConn->localRank; + if (*comm->abortFlag != 0) { + WARN("ncclProxyCallAsync() - Saw abortFlag while waiting for proxyThread response"); + return ncclInternalError; + } + if (comm->proxyState.peerSocks == NULL) return ncclInternalError; + + sock = comm->proxyState.peerSocks + proxyConn->localRank; if (sock == NULL) return ncclInternalError; + + if (type == ncclProxyMsgConvertFd) { + // cuMem API support + // Create a UDS socket to receive the converted fd + NCCLCHECK(ncclIpcSocketInit(&ipcSock, comm->localRank, (uint64_t)proxyConn->connection, comm->abortFlag)); + } + NCCLCHECKGOTO(ncclSocketSend(sock, &type, sizeof(int)), ret, error); NCCLCHECKGOTO(ncclSocketSend(sock, &proxyConn->connection, sizeof(void*)), ret, error); NCCLCHECKGOTO(ncclSocketSend(sock, &reqSize, sizeof(int)), ret, error); NCCLCHECKGOTO(ncclSocketSend(sock, &respSize, sizeof(int)), ret, error); if (reqSize) NCCLCHECKGOTO(ncclSocketSend(sock, reqBuff, reqSize), ret, error); - if (respSize) NCCLCHECKGOTO(ncclSocketRecv(sock, respBuff, respSize), ret, error); + + if (type == ncclProxyMsgConvertFd) { + // cuMem API support + int recvFd = -1; + if (reqSize != sizeof(int) || respSize != sizeof(int)) return ncclInternalError; + // Receive converted fd over UDS + NCCLCHECK(ncclIpcSocketRecvFd(&ipcSock, &recvFd)); + TRACE(NCCL_NET, "UDS: ConvertFd rank %d returned %p %d", proxyConn->localRank, &recvFd, recvFd); + assert(recvFd != -1); + respData = &recvFd; + respDataSize = sizeof(recvFd); + NCCLCHECK(ncclIpcSocketClose(&ipcSock)); + } else { + // Send opId to proxy + NCCLCHECKGOTO(ncclSocketSend(sock, &opId, sizeof(opId)), ret, error); + } + // Add proxyOp to expected response queue + NCCLCHECK(expectedProxyResponseEnqueue(&comm->proxyState, opId, respSize, respData, respDataSize)); + return ncclSuccess; error: - WARN("Proxy Call to rank %d failed (%s)", proxyConn->comm->localRankToRank[proxyConn->localRank], ncclProxyMsgTypeStr[type]); + NCCLCHECK(ncclIpcSocketClose(&ipcSock)); + WARN("Proxy Call to rank %d failed (%s)", comm->localRankToRank[proxyConn->localRank], ncclProxyMsgTypeStr[type]); return ret; } +ncclResult_t ncclPollProxyResponse(struct ncclProxyConnector* proxyConn, void* respBuff, void* opId) { + struct ncclComm* comm = proxyConn->comm; + + // Receive the connection pointer from the Proxy + if (*comm->abortFlag) { + WARN("Comm %p is in abort state", comm); + return ncclInternalError; + } + if (comm->proxyState.peerSocks == NULL) return ncclInternalError; + + // Check response queue + int found = 0; + NCCLCHECK(expectedProxyResponseDequeue(&comm->proxyState, opId, respBuff, &found)); + if (found == 0) { + // Attempt to read in a new response header from the proxy thread + struct ncclSocket* sock = comm->proxyState.peerSocks + proxyConn->localRank; + + void* recvOpId; + int offset = 0; + if (ncclSuccess != ncclSocketProgress(NCCL_SOCKET_RECV, sock, &recvOpId, sizeof(recvOpId), &offset)) { + WARN("Socket recv failed while polling for opId=%p", opId); + return ncclInternalError; + } + + if (offset == 0) { + return ncclInProgress; + // If we've returned a partial response, block to receive the rest of it + } else if (offset < sizeof(recvOpId)) { + while (offset < sizeof(recvOpId)) + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &recvOpId, sizeof(recvOpId), 
&offset)); + } + + INFO(NCCL_PROXY, "ncclPollProxyResponse Received new opId=%p", recvOpId); + + // Now do a blocking recv of the response size + int respSize = 0; + NCCLCHECK(ncclSocketRecv(sock, &respSize, sizeof(respSize))); + + // If there's a respSize to recv + if (respSize > 0) { + NCCLCHECK(ncclSocketRecv(sock, respBuff, respSize)); + } + + if (recvOpId == opId) { + INFO(NCCL_PROXY, "recvOpId=%p matches expected opId=%p", recvOpId, opId); + NCCLCHECK(expectedProxyResponseRemove(&comm->proxyState, recvOpId)); + return ncclSuccess; + } else { + INFO(NCCL_PROXY, "Queuing opId=%p", recvOpId); + // Store the result and mark response as completed + NCCLCHECK(expectedProxyResponseStore(&comm->proxyState, recvOpId, respBuff, respSize)); + return ncclInProgress; + } + } else { + INFO(NCCL_PROXY, "ncclPollProxyResponse Dequeued cached opId=%p", opId); + } + + return ncclSuccess; +} + +ncclResult_t ncclProxyCallBlocking(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize) { + // Alloc some memory to act as a handle + void* opId = malloc(1); + + NCCLCHECK(ncclProxyCallAsync(proxyConn, type, reqBuff, reqSize, respSize, opId)); + ncclResult_t res = ncclInProgress; + + while (res == ncclInProgress) { + res = ncclPollProxyResponse(proxyConn, respBuff, opId); + } + + free(opId); + + return res; +} + + static ncclResult_t proxyProgressInit(struct ncclComm* comm) { struct ncclProxyProgressState* state = &comm->proxyState.progressState; if (state->opsPool == NULL) { @@ -998,16 +1245,55 @@ static ncclResult_t proxyConnSharedInit(struct ncclProxyLocalPeer* peer, struct if (reqSize != sizeof(int) || respSize != 0) return ncclInternalError; int nChannels; NCCLCHECK(ncclSocketRecv(sock, &nChannels, sizeof(int))); + + // Store opId for completion response + void* opId; + NCCLCHECK(ncclSocketRecv(sock, &opId, sizeof(opId))); + INFO(NCCL_PROXY, "proxyConnSharedInit received opId=%p", opId); + if (connection->tcomm->proxySharedInit) NCCLCHECK(connection->tcomm->proxySharedInit(connection, comm, nChannels)); __atomic_store_n(&connection->state, connSharedInitialized, __ATOMIC_RELEASE); + + // Send the opId for referencing async operation + INFO(NCCL_PROXY, "proxyConnSharedInit::ncclSocketSend(opId=%p)", opId); + NCCLCHECK(ncclSocketSend(connection->sock, &opId, sizeof(opId))); + + // Send the response size + INFO(NCCL_PROXY, "proxyConnSharedInit::ncclSocketSend(op.respSize=%d)", respSize); + NCCLCHECK(ncclSocketSend(connection->sock, &respSize, sizeof(respSize))); + return ncclSuccess; } -static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclComm* comm, int* asyncOpCount) { +// cuMem API support +static ncclResult_t proxyConvertFd(struct ncclProxyLocalPeer* peer, struct ncclComm* comm) { + struct ncclSocket* sock = &peer->sock; + uint64_t connection; + NCCLCHECK(ncclSocketRecv(sock, &connection, sizeof(uint64_t))); + int reqSize, respSize; + NCCLCHECK(ncclSocketRecv(sock, &reqSize, sizeof(int))); + NCCLCHECK(ncclSocketRecv(sock, &respSize, sizeof(int))); + if (reqSize != sizeof(int) || respSize != sizeof(int)) return ncclInternalError; + + int fd; + struct ncclIpcSocket ipcSock = { 0 }; + NCCLCHECK(ncclSocketRecv(sock, &fd, sizeof(int))); + + INFO(NCCL_NET, "UDS: proxyConvertFd received fd %d peer %d connection %lx", fd, peer->localRank, connection); + // Send back the converted fd using UDS + NCCLCHECK(ncclIpcSocketInit(&ipcSock, comm->localRank, connection^1, comm->abortFlag)); + NCCLCHECK(ncclIpcSocketSendFd(&ipcSock, fd, 
peer->localRank, connection)); + NCCLCHECK(ncclIpcSocketClose(&ipcSock)); + return ncclSuccess; +} + +static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclComm* comm, int* asyncOpCount, struct ncclProxyLocalPeer* peer) { int done = 1; if (op->type == ncclProxyMsgSetup) { + INFO(NCCL_PROXY, "proxyProgressAsync::proxySetup() opId=%p", op->opId); NCCLCHECK(op->connection->tcomm->proxySetup(op->connection, comm, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done)); } else if (op->type == ncclProxyMsgConnect) { + INFO(NCCL_PROXY, "proxyProgressAsync::proxyConnect() opId=%p op.reqBuff=%p", op->opId, op->reqBuff); NCCLCHECK(op->connection->tcomm->proxyConnect(op->connection, comm, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done)); } else return ncclInternalError; if (done) { @@ -1015,31 +1301,38 @@ static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclC __atomic_store_n(&op->connection->state, connSetupDone, __ATOMIC_RELEASE); else if (op->type == ncclProxyMsgConnect) __atomic_store_n(&op->connection->state, connConnected, __ATOMIC_RELEASE); - /* if setup or connect is done, we should not return any error at this point since + /* if setup or connect is done, we should not return any error at this point since * ncclSocketSend might already send the respBuff to the requester. If we still choose * to abort and close the connection, it can cause segfault if the requester is using * the respBuff. */ - if (op->respSize) ncclSocketSend(op->connection->sock, op->respBuff, op->respSize); - if (op->reqBuff) { - free(op->reqBuff); - op->reqBuff = NULL; + + // Send the opId for referencing async operation + NCCLCHECK(ncclSocketSend(op->connection->sock, &op->opId, sizeof(op->opId))); + + // Send the response size + NCCLCHECK(ncclSocketSend(op->connection->sock, &op->respSize, sizeof(op->respSize))); + + if (op->respSize) { + // Send the response + NCCLCHECK(ncclSocketSend(op->connection->sock, op->respBuff, op->respSize)); } - if (op->respBuff) { - free(op->respBuff); - op->respBuff = NULL; - } - op->type = 0; + + asyncProxyOpDequeue(peer, op); (*asyncOpCount)--; + return ncclSuccess; + } else if (*comm->abortFlag != 0) { return ncclInternalError; } - return ncclSuccess; + return ncclInProgress; } static ncclResult_t proxyConnSetupConnect(int type, struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool, struct ncclComm* comm, int* asyncOpCount) { struct ncclSocket* sock = &peer->sock; - struct ncclProxyAsyncOp* asyncOp = &peer->asyncOps; + struct ncclProxyAsyncOp* asyncOp; + NCCLCHECK(ncclCalloc(&asyncOp, 1)); + asyncOp->type = type; NCCLCHECK(ncclSocketRecv(sock, &asyncOp->connection, sizeof(void*))); @@ -1049,9 +1342,16 @@ static ncclResult_t proxyConnSetupConnect(int type, struct ncclProxyLocalPeer* p NCCLCHECK(ncclCalloc(&asyncOp->reqBuff, asyncOp->reqSize)); NCCLCHECK(ncclSocketRecv(sock, asyncOp->reqBuff, asyncOp->reqSize)); } + + // Store opId for completion response + NCCLCHECK(ncclSocketRecv(sock, &asyncOp->opId, sizeof(asyncOp->opId))); + if (asyncOp->respSize) NCCLCHECK(ncclCalloc(&asyncOp->respBuff, asyncOp->respSize)); + + asyncProxyOpEnqueue(peer, asyncOp); + (*asyncOpCount)++; - NCCLCHECK(proxyProgressAsync(asyncOp, comm, asyncOpCount)); + NCCLCHECK(proxyProgressAsync(asyncOp, comm, asyncOpCount, peer)); return ncclSuccess; } @@ -1081,7 +1381,7 @@ void* ncclProxyService(void* _args) { pollfds[s].events = POLLHUP|POLLIN; } if (ncclSocketGetFd(comm->proxyState.listenSock, 
&pollfds[NCCL_MAX_LOCAL_RANKS].fd) != ncclSuccess) { - WARN("[Proxy Service] Get listenSock fd fails\n"); + WARN("[Proxy Service] Get listenSock fd fails"); return NULL; }; pollfds[NCCL_MAX_LOCAL_RANKS].events = POLLIN; @@ -1113,14 +1413,14 @@ void* ncclProxyService(void* _args) { } if (maxnpeers < s+1) maxnpeers = s+1; if (ncclSocketInit(&peers[s].sock) != ncclSuccess) { - WARN("[Service thread] Initialize peers[%d].sock fails\n", s); + WARN("[Service thread] Initialize peers[%d].sock fails", s); return NULL; } if (ncclSocketAccept(&peers[s].sock, comm->proxyState.listenSock) != ncclSuccess) { WARN("[Service thread] Accept failed %s", strerror(errno)); } else { if (ncclSocketGetFd(&peers[s].sock, &pollfds[s].fd) != ncclSuccess) { - WARN("[Service thread] Get peers[%d].sock fd fails\n", s); + WARN("[Service thread] Get peers[%d].sock fd fails", s); return NULL; } npeers++; @@ -1130,25 +1430,37 @@ void* ncclProxyService(void* _args) { for (int s=0; ssock; - struct ncclProxyAsyncOp* op = &peer->asyncOps; int closeConn = 0; int type = 0; ncclResult_t res = ncclSuccess; - if (pollfds[s].fd == -1) continue; - if (op->type != 0) { - res = proxyProgressAsync(op, comm, &asyncOpCount); + + // Progress all ops for this ncclProxyLocalPeer + ncclProxyAsyncOp* op = peer->asyncOps; + while (op != nullptr) { type = op->type; - if (res != ncclSuccess) closeConn = 1; - } else if (pollfds[s].revents & POLLIN) { + res = proxyProgressAsync(op, comm, &asyncOpCount, peer); + if (res == ncclSuccess || res == ncclInProgress) { + op = op->next; + } else { + // Res is a bad result + closeConn = 1; + WARN("[Service thread] Error encountered progressing operation=%s, res=%d, closing connection", ncclProxyMsgTypeStr[type], res); + break; + } + } + + // Check for additional ops coming in + if (pollfds[s].revents & POLLIN) { int closed; - if (ncclSocketTryRecv(sock, &type, sizeof(int), &closed) != ncclSuccess) { - WARN("[Service thread] Could not receive type from localRank %d", peer->localRank); + res = ncclSocketTryRecv(sock, &type, sizeof(int), &closed, false /*blocking*/); + if (res != ncclSuccess && res != ncclInProgress) { + WARN("[Service thread] Could not receive type from localRank %d, res=%u, closed=%d", peer->localRank, res, closed); closeConn = 1; } else if (closed) { INFO(NCCL_INIT|NCCL_NET, "[Service thread] Connection closed by localRank %d", peer->localRank); closeConn = 1; - } else { + } else if (res == ncclSuccess) { // We received something from the sock if (type == ncclProxyMsgStop) { stop = 1; closeConn = 1; @@ -1159,30 +1471,32 @@ void* ncclProxyService(void* _args) { } else if (type == ncclProxyMsgSharedInit) { res = proxyConnSharedInit(peers+s, &connectionPool, comm); } else if (type == ncclProxyMsgSetup || type == ncclProxyMsgConnect) { + INFO(NCCL_PROXY, "proxyConnSetupConnect for peer->localRank %d,", peer->localRank); res = proxyConnSetupConnect(type, peers+s, &connectionPool, comm, &asyncOpCount); + } else if (type == ncclProxyMsgConvertFd) { + res = proxyConvertFd(peers+s, comm); // cuMem API support } else { - WARN("[Service thread] Unknown command %d from localRank %d\n", type, peer->localRank); + WARN("[Service thread] Unknown command %d from localRank %d", type, peer->localRank); closeConn = 1; } + + INFO(NCCL_PROXY, "Received and initiated operation=%s res=%d", ncclProxyMsgTypeStr[type], res); } } else if (pollfds[s].revents & POLLHUP) { closeConn = 1; - } - if (res != ncclSuccess) { + } + if (res != ncclSuccess && res != ncclInProgress) { WARN("[Proxy Service %d] Failed to execute 
operation %s from rank %d, retcode %d", comm->rank, ncclProxyMsgTypeStr[type], comm->localRankToRank[peer->localRank], res); closeConn = 1; } + if (closeConn) { ncclSocketClose(sock); - if (op->reqBuff) { - free(op->reqBuff); - op->reqBuff = NULL; + + if (op != nullptr) { + asyncProxyOpDequeue(peer, op); + asyncOpCount--; } - if (op->respBuff) { - free(op->respBuff); - op->respBuff = NULL; - } - op->type = 0; pollfds[s].fd = -1; npeers--; } @@ -1250,6 +1564,7 @@ ncclResult_t ncclProxyDestroy(struct ncclComm* comm) { free(state->peerSocks); free(state->proxyOps); free(state->sharedDevMems); + expectedProxyResponseFree(state); } return ncclSuccess; } diff --git a/src/transport.cc b/src/transport.cc index 66d8b51..a50f912 100644 --- a/src/transport.cc +++ b/src/transport.cc @@ -69,9 +69,12 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* // Stream used during transport setup; need for P2P pre-connect + CUDA Graph ncclResult_t ret = ncclSuccess; int highestType = TRANSPORT_P2P; // track highest transport type - struct ncclConnect data[2*MAXCHANNELS]; + struct ncclConnect** data = (ncclConnect**) malloc(sizeof(ncclConnect*) * comm->nRanks); // Store intermediate send/recvData structs for connect + struct ncclConnect** recvData = (ncclConnect**) malloc(sizeof(ncclConnect*) * comm->nRanks); // Points to entries inside data for given recv connection within a channel + struct ncclConnect** sendData = (ncclConnect**) malloc(sizeof(ncclConnect*) * comm->nRanks); // Points to entries inside data for given send connection within a channel NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->hostStream), ret, fail); + // First time initialization for (int i=1; inRanks; i++) { int bootstrapTag = (i<<8) + (graph ? graph->id+1 : 0); int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks; @@ -79,22 +82,28 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* uint64_t recvMask = comm->connectRecv[recvPeer]; uint64_t sendMask = comm->connectSend[sendPeer]; - struct ncclConnect* recvData = data; + // Data[i] contains all ncclConnect information for all send and receive connections with a given send and recv peer + // This data is packed in the array based on the number of sendChannels and recvChannels connected with these peers + // The first N entries contain recvData, connection information for recv connections + // The next M entries contain sendData, connection information for send connections + // It's not guaranteed that each entry of data has the same number of total or send/recv specific connections + data[i] = (ncclConnect*) malloc(sizeof(ncclConnect) * 2*MAXCHANNELS); + recvData[i] = data[i]; int sendChannels = 0, recvChannels = 0; int type; TIME_START(0); for (int c=0; c(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex, &type), ret, fail); + NCCLCHECKGOTO(selectTransport<0>(comm, graph, recvData[i]+recvChannels++, c, recvPeer, connIndex, &type), ret, fail); if (type > highestType) highestType = type; } } TIME_STOP(0); TIME_START(1); - struct ncclConnect* sendData = recvData+recvChannels; + sendData[i] = recvData[i]+recvChannels; for (int c=0; c(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex, &type), ret, fail); + NCCLCHECKGOTO(selectTransport<1>(comm, graph, sendData[i]+sendChannels++, c, sendPeer, connIndex, &type), ret, fail); if (type > highestType) highestType = type; } } @@ -103,42 +112,82 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* TIME_START(2); if (sendPeer 
== recvPeer) { if (recvChannels+sendChannels) { - NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); - NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); - sendData = data; - recvData = data+sendChannels; + NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, data[i], sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); + NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, data[i], sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); + sendData[i] = data[i]; + recvData[i] = data[i]+sendChannels; } } else { - if (recvChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels), ret, fail); - if (sendChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels), ret, fail); - if (sendChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels), ret, fail); - if (recvChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels), ret, fail); + if (recvChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, recvData[i], sizeof(struct ncclConnect)*recvChannels), ret, fail); + if (sendChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, sendPeer, bootstrapTag, sendData[i], sizeof(struct ncclConnect)*sendChannels), ret, fail); + if (sendChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, sendPeer, bootstrapTag, sendData[i], sizeof(struct ncclConnect)*sendChannels), ret, fail); + if (recvChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, recvData[i], sizeof(struct ncclConnect)*recvChannels), ret, fail); } TIME_STOP(2); - - TIME_START(3); - for (int c=0; cchannels[c].peers[sendPeer].send + connIndex; - NCCLCHECKGOTO(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn), ret, fail); - conn->connected = 1; - CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); - } - } - TIME_STOP(3); - TIME_START(4); - for (int c=0; cchannels[c].peers[recvPeer].recv + connIndex; - NCCLCHECKGOTO(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn), ret, fail); - conn->connected = 1; - CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); - } - } - TIME_STOP(4); - comm->connectRecv[recvPeer] = comm->connectSend[sendPeer] = 0UL; } + // Loop until all channels with all ranks have been connected + bool allChannelsConnected; + allChannelsConnected = false; + while (!allChannelsConnected) { + allChannelsConnected = true; + for (int i=1; inRanks; i++) { + int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks; + int sendPeer = (comm->rank + i) % comm->nRanks; + uint64_t recvMask = comm->connectRecv[recvPeer]; + uint64_t sendMask = comm->connectSend[sendPeer]; + + int sendDataOffset = 0; + int recvDataOffset = 0; + for (int c=0; cchannels[c].peers[sendPeer].send + connIndex; + // This connector hasn't completed connection yet + if (conn->connected 
== 0) { + NCCLCHECKGOTO(conn->transportComm->connect(comm, sendData[i] + sendDataOffset++, 1, comm->rank, conn), ret, fail); + if (ret == ncclSuccess) { + conn->connected = 1; + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); + } else if (ret == ncclInProgress) { + allChannelsConnected = false; + } + } + } + TIME_STOP(3); + + // Start with recv channels + TIME_START(4); + if (recvMask & (1UL<channels[c].peers[recvPeer].recv + connIndex; + // This connector hasn't completed connection yet + if (conn->connected == 0) { + NCCLCHECKGOTO(conn->transportComm->connect(comm, recvData[i] + recvDataOffset++, 1, comm->rank, conn), ret, fail); + if (ret == ncclSuccess) { + conn->connected = 1; + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); + } else if (ret == ncclInProgress) { + allChannelsConnected = false; + } + } + } + TIME_STOP(4); + } + } + } + + // Clear all connect masks and free each connectInfo array + for (int i=1; inRanks; i++) { + int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks; + int sendPeer = (comm->rank + i) % comm->nRanks; + comm->connectRecv[recvPeer] = comm->connectSend[sendPeer] = 0UL; + free(data[i]); + } + + free(data); + free(sendData); + free(recvData); + if (highestTransportType != NULL) *highestTransportType = highestType; TIME_PRINT("P2P Setup/Connect"); exit: diff --git a/src/transport/coll_net.cc b/src/transport/coll_net.cc index de10f2f..2273518 100644 --- a/src/transport/coll_net.cc +++ b/src/transport/coll_net.cc @@ -152,13 +152,13 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph int proxyRank; NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &req.netDev, &proxyRank)); NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 1, &req.useGdr)); - send->conn.direct |= req.useGdr ? NCCL_DIRECT_NIC : 0; + send->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0; // Determine whether we need to flush the GDR buffer on recv or not if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm->topo, myInfo->busId, &req.needFlush)); NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &send->proxyConn.localRank)); NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_COLLNET, 1, myInfo->rank, &send->proxyConn)); - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0)); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0)); INFO(NCCL_INIT|NCCL_NET,"CollNet %02d/%1d : %d [send] via COLLNET/%s/%d%s", channelId, connIndex, myInfo->rank, collNetName(comm), req.netDev, req.useGdr ? "/GDRDMA" : ""); @@ -171,12 +171,12 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph int proxyRank; NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &req.netDev, &proxyRank)); NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 0, &req.useGdr)); - recv->conn.direct |= req.useGdr ? NCCL_DIRECT_NIC : 0; + recv->conn.flags |= req.useGdr ? 
NCCL_DIRECT_NIC : 0; NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &recv->proxyConn.localRank)); NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_COLLNET, 0, myInfo->rank, &recv->proxyConn)); struct collNetRecvConnectInfo* info = (struct collNetRecvConnectInfo*) connectInfo; - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), &info->collNetHandle, sizeof(collNetHandle_t))); + NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), &info->collNetHandle, sizeof(collNetHandle_t))); INFO(NCCL_INIT|NCCL_NET,"CollNet %02d/%1d : %d [receive] via COLLNET/%s/%d%s", channelId, connIndex, myInfo->rank, collNetName(comm), req.netDev, req.useGdr ? "/GDRDMA" : ""); @@ -221,7 +221,7 @@ static ncclResult_t sendConnect(struct ncclComm* comm, struct ncclConnect* conne // We're on the same process as the proxy. We can pass a pointer to a struct. struct collNetConnectArgs args = { rank, nranks, connectInfos }; struct connectMap* map; - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*))); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*))); // If collnet connect failed, propagate error to fallback on regular p2p if (map == NULL) return ncclSystemError; @@ -247,7 +247,7 @@ static ncclResult_t recvConnect(struct ncclComm* comm, struct ncclConnect* conne // We're on the same process as the proxy. We can pass a pointer to a struct. struct collNetConnectArgs args = { rank, nranks, connectInfos }; struct connectMap* map; - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*))); + NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*))); // If collnet connect failed, propagate error to fallback on regular p2p if (map == NULL) return ncclSystemError; @@ -410,7 +410,7 @@ static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struc } static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { - if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("sendProxyConnect: reqSize is %d != %ld\n", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; } + if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("sendProxyConnect: reqSize is %d != %ld", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; } struct collNetConnectArgs* args = (struct collNetConnectArgs*)reqBuff; struct collNetSendConnectInfo* info = (struct collNetSendConnectInfo*)(args->connectInfos+args->rank); @@ -426,7 +426,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(sharedConnect(comm, resources->netDev, args->connectInfos, args->nranks, args->rank, &resources->collNetComm)); // Collnet connect is allowed to fail. Gracefully handle that case by returning NULL to the caller. 
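The ncclProxyCall() -> ncclProxyCallBlocking() substitutions in this file sit on top of the opId-based request/response exchange added in src/proxy.cc above; nothing new is introduced here, this is only a condensed summary of that exchange between a rank thread and its local proxy service thread:

/* Async proxy exchange behind ncclProxyCallBlocking()/ncclProxyCallAsync():
 *
 *   rank thread                                   proxy service thread
 *   -----------                                   --------------------
 *   send type, connection, reqSize, respSize
 *   send reqBuff (if reqSize > 0)
 *   send opId (unique handle)             ---->   enqueue ncclProxyAsyncOp and progress it
 *   enqueue expected response for opId
 *   ncclPollProxyResponse(opId)           <----   when done: send opId, respSize, respBuff
 *     - opId matches: consume response and return ncclSuccess
 *     - other opId:   stash it via expectedProxyResponseStore(), return ncclInProgress
 *
 * ncclProxyMsgConvertFd is the exception: its converted fd is returned over the
 * ncclIpcSocket UDS channel rather than through this response path. */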
- if (respSize != sizeof(struct connectMap*)) { WARN("sendProxyConnect: respSize is %d != %ld\n", respSize, sizeof(void*)); return ncclInternalError; } + if (respSize != sizeof(struct connectMap*)) { WARN("sendProxyConnect: respSize is %d != %ld", respSize, sizeof(void*)); return ncclInternalError; } if (resources->collNetComm == NULL) { *((struct connectMap**)respBuff) = NULL; return ncclSuccess; @@ -484,7 +484,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str } static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { - if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("recvProxyConnect: reqSize is %d != %ld\n", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; } + if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("recvProxyConnect: reqSize is %d != %ld", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; } struct collNetConnectArgs* args = (struct collNetConnectArgs*)reqBuff; struct recvResources* resources = (struct recvResources*)(connection->transportResources); @@ -494,7 +494,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(sharedConnect(comm, resources->netDev, args->connectInfos, args->nranks, args->rank, &resources->collNetComm)); // Collnet connect is allowed to fail. Gracefully handle that case by returning NULL to the caller. - if (respSize != sizeof(struct connectMap*)) { WARN("sendProxyConnect: respSize is %d != %ld\n", respSize, sizeof(void*)); return ncclInternalError; } + if (respSize != sizeof(struct connectMap*)) { WARN("sendProxyConnect: respSize is %d != %ld", respSize, sizeof(void*)); return ncclInternalError; } if (resources->collNetComm == NULL) { *((struct connectMap**)respBuff) = NULL; return ncclSuccess; @@ -553,7 +553,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str for (int p=0; pmhandles[p] = resources->mhandles[p]; - if (respSize != sizeof(struct connectMap*)) { WARN("recvProxyConnect: respSize is %d != %ld\n", respSize, sizeof(void*)); return ncclInternalError; } + if (respSize != sizeof(struct connectMap*)) { WARN("recvProxyConnect: respSize is %d != %ld", respSize, sizeof(void*)); return ncclInternalError; } *((struct connectMap**)respBuff) = &resources->map; return ncclSuccess; } diff --git a/src/transport/net.cc b/src/transport/net.cc index b358ad6..fe98a4c 100644 --- a/src/transport/net.cc +++ b/src/transport/net.cc @@ -172,13 +172,13 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph int proxyRank; NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, peerInfo->rank, &req.netDev, &proxyRank)); NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 1, &req.useGdr)); - send->conn.direct |= req.useGdr ? NCCL_DIRECT_NIC : 0; + send->conn.flags |= req.useGdr ? 
NCCL_DIRECT_NIC : 0; NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_NET, 1, proxyRank, &send->proxyConn)); req.rank = myInfo->rank; NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &req.localRank)); req.remoteRank = peerInfo->rank; - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0)); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0)); if (proxyRank == myInfo->rank) { INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [send] via NET/%s/%d%s%s", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, ncclNetName(comm), req.netDev, @@ -218,8 +218,7 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph req.rank = myInfo->rank; NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &req.localRank)); req.remoteRank = peerInfo->rank; - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), connectInfo, sizeof(ncclNetHandle_t))); - + NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), connectInfo, sizeof(ncclNetHandle_t))); INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [receive] via NET/%s/%d%s%s", channelId, connIndex, peerInfo->rank, peerInfo->busId, myInfo->rank, myInfo->busId, ncclNetName(comm), req.netDev, req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : ""); return ncclSuccess; @@ -264,11 +263,28 @@ static ncclResult_t netDumpMap(struct connectMap* map) { } static ncclResult_t sendConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* send) { - // Setup device pointers - struct connectMap* map; - NCCLCHECK(ncclCalloc(&map, 1)); - send->transportResources = map; - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(ncclNetHandle_t), map, sizeof(struct connectMap))); + struct connectMap* map = (connectMap*) send->transportResources; + + void* opId; + + // map isn't allocated thus this op hasn't been submitted yet + if (!map) { + // Setup device pointers + NCCLCHECK(ncclCalloc(&map, 1)); + send->transportResources = map; + opId = send; + INFO(NCCL_PROXY, "sendConnect ncclProxyCallAsync opId=%p", opId); + NCCLCHECK(ncclProxyCallAsync(&send->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(ncclNetHandle_t), sizeof(struct connectMap), opId)); + } else { + opId = send; + } + + ncclResult_t ret; + NCCLCHECK(ret = ncclPollProxyResponse(&send->proxyConn, map, opId)); + if (ret == ncclInProgress) { + return ret; + } + INFO(NCCL_PROXY, "sendConnect ncclPollProxyResponse opId=%p", opId); if (map->sameProcess) { if (map->cudaDev != comm->cudaDev) { @@ -315,10 +331,26 @@ static ncclResult_t sendConnect(struct ncclComm* comm, struct ncclConnect* conne /* Connect to this peer */ static ncclResult_t recvConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* recv) { - struct connectMap* map; - NCCLCHECK(ncclCalloc(&map, 1)); - recv->transportResources = map; - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(int), map, sizeof(struct connectMap))); + struct connectMap* map = (connectMap*) recv->transportResources; + void* opId; + if (!map) { + NCCLCHECK(ncclCalloc(&map, 1)); + recv->transportResources = map; + // Use recv connector as unique identifier + opId = recv; + INFO(NCCL_PROXY, "recvConnect ncclProxyCallAsync opId=%p &recv->proxyConn=%p connectInfo=%p", + opId, &recv->proxyConn, 
connectInfo); + NCCLCHECK(ncclProxyCallAsync(&recv->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(int), sizeof(struct connectMap), opId)); + } else { + opId = recv; + } + + ncclResult_t ret; + NCCLCHECK(ret = ncclPollProxyResponse(&recv->proxyConn, map, opId)); + if (ret == ncclInProgress) { + return ret; + } + INFO(NCCL_PROXY, "recvConnect ncclPollProxyResponse opId=%p", opId); //NCCLCHECK(netDumpMap(map)); struct ncclSendMem *sendMem = (struct ncclSendMem*) NCCL_NET_MAP_GET_POINTER(map, gpu, sendMem); @@ -490,12 +522,14 @@ static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struc if (respSize != sizeof(ncclNetHandle_t)) return ncclInternalError; NCCLCHECK(ncclNetListen(comm, req->netDev, respBuff, &resources->netListenComm)); *done = 1; + return ncclSuccess; } static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { struct sendResources* resources = (struct sendResources*)(connection->transportResources); if (reqSize != sizeof(ncclNetHandle_t)) return ncclInternalError; + ncclResult_t ret = ncclSuccess; if (resources->shared) { // Shared buffers @@ -515,21 +549,22 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(ncclCalloc(progressState->netComms+resources->netDev, comm->nRanks)); } struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev]+resources->remoteRank; - if (comms->sendComm[resources->channelId] == NULL) NCCLCHECK(ncclNetConnect(comm, resources->netDev, reqBuff, comms->sendComm+resources->channelId)); + if (comms->sendComm[resources->channelId] == NULL) ret = ncclNetConnect(comm, resources->netDev, reqBuff, comms->sendComm+resources->channelId); resources->netSendComm = comms->sendComm[resources->channelId]; if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++; } else { - NCCLCHECK(ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm)); + ret = ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm); } } else { // Connect to remote peer - NCCLCHECK(ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm)); + ret = ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm); connection->proxyAppendPtr = &connection->proxyAppend; } + NCCLCHECK(ret); if (resources->netSendComm == NULL) { *done = 0; - return ncclSuccess; + return ncclInProgress; } *done = 1; @@ -630,6 +665,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str if (reqSize != sizeof(int)) return ncclInternalError; struct recvResources* resources = (struct recvResources*)(connection->transportResources); resources->proxyRank = *(int*)reqBuff; + ncclResult_t ret = ncclSuccess; // Finish connection establishment from remote peer if (resources->shared) { @@ -650,23 +686,25 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(ncclCalloc(progressState->netComms+resources->netDev, comm->nRanks)); } struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev]+resources->proxyRank; - if (comms->recvComm[resources->channelId] == NULL) NCCLCHECK(ncclNetAccept(comm, resources->netListenComm, comms->recvComm+resources->channelId)); + if (comms->recvComm[resources->channelId] == NULL) ret = ncclNetAccept(comm, resources->netListenComm, comms->recvComm+resources->channelId); resources->netRecvComm = comms->recvComm[resources->channelId]; 
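// The hunks above make the network connect/accept path resumable: ncclNetConnect()/
// ncclNetAccept() may come back with a NULL comm while the peer is still starting up,
// in which case the proxy op now leaves *done at 0 and returns ncclInProgress rather
// than ncclSuccess, so it is simply retried on a later call. On the caller side,
// sendConnect()/recvConnect() submit the request once with ncclProxyCallAsync() and
// then poll with ncclPollProxyResponse(). A minimal caller-side sketch of that retry
// contract, using a hypothetical helper name (connectOnePeer) and assuming NCCL's
// internal headers ("comm.h", "proxy.h") are in scope:
static ncclResult_t connectOnePeer(struct ncclComm* comm, struct ncclConnect* info,
                                   int nranks, int rank, struct ncclConnector* conn) {
  ncclResult_t ret = conn->transportComm->connect(comm, info, nranks, rank, conn);
  if (ret == ncclInProgress) return ret;   // not ready yet: caller retries on a later pass
  NCCLCHECK(ret);                          // any other non-success code is a real error
  conn->connected = 1;
  return ncclSuccess;
}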
if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++; } else { - NCCLCHECK(ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm)); + ret = ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm); } } else { // Connect to remote peer - NCCLCHECK(ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm)); + ret = ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm); connection->proxyAppendPtr = &connection->proxyAppend; } + NCCLCHECK(ret); if (resources->netRecvComm == NULL) { *done = 0; - return ncclSuccess; + return ncclInProgress; } *done = 1; + NCCLCHECK(ncclNetCloseListen(comm, resources->netListenComm)); // Create structures diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc index 0645005..664d51b 100644 --- a/src/transport/net_ib.cc +++ b/src/transport/net_ib.cc @@ -363,7 +363,9 @@ enum ncclIbCommState { ncclIbCommStateAccept = 3, ncclIbCommStateSend = 4, ncclIbCommStateRecv = 5, - ncclIbCommStateConnected = 6, + ncclIbCommStateConnecting = 6, + ncclIbCommStateConnected = 7, + ncclIbCommStatePendingReady = 8, }; struct ncclIbCommStage { @@ -599,8 +601,10 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) { int ready; *sendComm = NULL; - if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; - if (stage->state == ncclIbCommStateSend) goto ib_send; + if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; + if (stage->state == ncclIbCommStateSend) goto ib_send; + if (stage->state == ncclIbCommStateConnecting) goto ib_connect; + if (stage->state == ncclIbCommStateConnected) goto ib_send_ready; if (stage->state != ncclIbCommStateStart) { WARN("Error: trying to connect already connected sendComm"); return ncclInternalError; @@ -664,11 +668,37 @@ ib_connect_check: ib_send: NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->sock, stage->buffer, sizeof(qpInfo), &stage->offset)); - if (stage->offset != sizeof(qpInfo)) - return ncclSuccess; + if (stage->offset != sizeof(qpInfo)) return ncclSuccess; + + stage->state = ncclIbCommStateConnecting; + stage->offset = 0; + // Clear the staging buffer for re-use + memset(stage->buffer, 0, sizeof(qpInfo)); + +ib_connect: + struct ncclIbQpInfo remQpInfo; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, stage->buffer, sizeof(ncclIbQpInfo), &stage->offset)); + if (stage->offset != sizeof(remQpInfo)) return ncclSuccess; + + memcpy(&remQpInfo, stage->buffer, sizeof(ncclIbQpInfo)); + + for (int q=0; qnqps; q++) { + struct ibv_qp* qp = comm->qps[q]; + NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo)); + NCCLCHECK(ncclIbRtsQp(qp)); + } + + comm->ready = 1; + stage->state = ncclIbCommStateConnected; + stage->offset = 0; + +ib_send_ready: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->sock, &comm->ready, sizeof(int), &stage->offset)); + if (stage->offset != sizeof(int)) return ncclSuccess; free(stage->buffer); - stage->state = ncclIbCommStateConnected; + stage->state = ncclIbCommStateStart; + *sendComm = comm; return ncclSuccess; } @@ -685,8 +715,9 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) { if (stage->state == ncclIbCommStateAccept) goto ib_accept_check; if (stage->state == ncclIbCommStateRecv) goto ib_recv; if (stage->state == ncclIbCommStateSend) goto ib_send; + if (stage->state == ncclIbCommStatePendingReady) goto ib_recv_ready; if (stage->state != ncclIbCommStateStart) { - WARN("Listencomm in unknown state %d\n", stage->state); + WARN("Listencomm in 
unknown state %d", stage->state); return ncclInternalError; } @@ -704,10 +735,10 @@ ib_accept_check: stage->state = ncclIbCommStateRecv; stage->offset = 0; NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(remQpInfo))); + ib_recv: NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->sock, stage->buffer, sizeof(remQpInfo), &stage->offset)); - if (stage->offset != sizeof(remQpInfo)) - return ncclSuccess; + if (stage->offset != sizeof(remQpInfo)) return ncclSuccess; /* copy back the received info */ memcpy(&remQpInfo, stage->buffer, sizeof(struct ncclIbQpInfo)); @@ -780,10 +811,18 @@ ib_recv: if (stage->buffer) free(stage->buffer); NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(struct ncclIbQpInfo))); memcpy(stage->buffer, &qpInfo, sizeof(struct ncclIbQpInfo)); + ib_send: NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->sock, stage->buffer, sizeof(struct ncclIbQpInfo), &stage->offset)); if (stage->offset < sizeof(struct ncclIbQpInfo)) return ncclSuccess; + stage->offset = 0; + stage->state = ncclIbCommStatePendingReady; + +ib_recv_ready: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->sock, &rComm->ready, sizeof(int), &stage->offset)); + if (stage->offset != sizeof(int)) return ncclSuccess; + free(stage->buffer); *recvComm = rComm; @@ -815,36 +854,6 @@ ncclResult_t ncclIbFreeRequest(struct ncclIbRequest* r) { return ncclSuccess; } -ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) { - struct ncclIbQpInfo remQpInfo; - - // Do not block on this receive, return if not ready. - int bytes = 0; - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, &remQpInfo, sizeof(remQpInfo), &bytes)); - if (bytes == 0) return ncclSuccess; // Try again later - NCCLCHECK(ncclSocketWait(NCCL_SOCKET_RECV, &comm->sock, &remQpInfo, sizeof(remQpInfo), &bytes)); - - for (int q=0; qnqps; q++) { - struct ibv_qp* qp = comm->qps[q]; - NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo)); - NCCLCHECK(ncclIbRtsQp(qp)); - } - comm->ready = 1; - // Block until this is done. It *should* not block indefinitely. - NCCLCHECK(ncclSocketSend(&comm->sock, &comm->ready, sizeof(int))); - - return ncclSuccess; -} - -ncclResult_t ncclRecvCheck(struct ncclIbRecvComm* comm) { - // Do not block on this receive, return if not ready. 
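// With ncclSendCheck()/ncclRecvCheck() removed here, the QP handshake is no longer
// finished lazily on the first isend/irecv; it is folded into the resumable
// ncclIbConnect()/ncclIbAccept() state machines, and calling ncclIbIsend()/ncclIbIrecv()
// on a comm that is not ready becomes an internal error (see the hunks below).
// Rough sketch of the sender-side progression as introduced by this patch (state names
// abbreviated; all carry the ncclIbCommState prefix); each call advances at most one
// non-blocking socket step and returns ncclSuccess with *sendComm == NULL until done:
//
//   Start      -> Connect     socket connection to the listener initiated
//   Connect    -> Send        socket ready (ib_connect_check)
//   Send       -> Connecting  local qpInfo sent (ib_send)
//   Connecting -> Connected   remote qpInfo received, QPs moved to RTR/RTS (ib_connect)
//   Connected  -> Start       "ready" flag sent, *sendComm handed back (ib_send_ready)
//
// The receiver mirrors this with Accept/Recv/Send plus the new PendingReady state,
// which waits for the sender's "ready" flag (ib_recv_ready) before returning *recvComm.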
- int bytes = 0; - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, &comm->ready, sizeof(int), &bytes)); - if (bytes == 0) return ncclSuccess; // Try again later - NCCLCHECK(ncclSocketWait(NCCL_SOCKET_RECV, &comm->sock, &comm->ready, sizeof(int), &bytes)); - return ncclSuccess; -} - ncclResult_t ncclIbTest(void* request, int* done, int* size); /* DMA-BUF support */ @@ -1020,7 +1029,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; - if (comm->ready == 0) NCCLCHECK(ncclSendCheck(comm)); + if (comm->ready == 0) { WARN("NET/IB: ncclIbIsend() called when comm->ready == 0"); return ncclInternalError; } if (comm->ready == 0) { *request = NULL; return ncclSuccess; } struct ibv_mr* mr = (struct ibv_mr*)mhandle; @@ -1153,7 +1162,7 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; - if (comm->ready == 0) NCCLCHECK(ncclRecvCheck(comm)); + if (comm->ready == 0) { WARN("NET/IB: ncclIbIrecv() called when comm->ready == 0"); return ncclInternalError; } if (comm->ready == 0) { *request = NULL; return ncclSuccess; } if (n > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; diff --git a/src/transport/nvls.cc b/src/transport/nvls.cc new file mode 100644 index 0000000..336877c --- /dev/null +++ b/src/transport/nvls.cc @@ -0,0 +1,373 @@ +/************************************************************************* + * Copyright (c) 2016-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * See LICENSE.txt for license information + ************************************************************************/ + +// Implementation of the NVLink SHARP (NVLS) transport + +#include "comm.h" +#include "graph.h" +#include "utils.h" +#include "proxy.h" + +#if CUDART_VERSION >= 12010 + +// Currently we only support POSIX_FILE_DESCRIPTOR handle exchange +#define USE_POSIX_FD 1 + +#if USE_POSIX_FD +#define NVLS_CU_MEM_HANDLE_TYPE CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR +#else +#define NVLS_CU_MEM_HANDLE_TYPE CU_MEM_HANDLE_TYPE_NONE +#endif + +ncclResult_t nvlsCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) { + // This transport cannot be used for p2p + *ret = 0; + return ncclSuccess; +} + +ncclResult_t nvlsSendFree(struct ncclConnector* send) { + return ncclSuccess; +} + +ncclResult_t nvlsRecvFree(struct ncclConnector* recv) { + return ncclSuccess; +} + +struct ncclTransport nvlsTransport = { + "NVLS", + nvlsCanConnect, + { NULL, NULL, nvlsSendFree, NULL, NULL, NULL, NULL, NULL }, + { NULL, NULL, nvlsRecvFree, NULL, NULL, NULL, NULL, NULL } +}; + +#define NVLS_HANDLE_SIZE 64 + +struct nvlsResources { + CUmulticastObjectProp properties; + CUmemAccessDesc accessDesc; + int dev; + size_t size; + size_t granularity; + CUmemGenericAllocationHandle mcHandle; // Multicast handle for NVLS buffer + char* mcBuff; // Multicast NVLS buffer address + CUmemGenericAllocationHandle ucHandle; // Unicast Handle for NVLS buffer + char* ucBuff; // Unicast NVLS buffer address +}; + + +ncclResult_t nvlsGetProperties(struct ncclComm *comm, struct nvlsResources* resources, int dev, int nranks, size_t size) { + CUmulticastObjectProp* prop = &resources->properties; + memset(prop, 0, sizeof(*prop)); + prop->size = size; + prop->numDevices = nranks; + prop->handleTypes = NVLS_CU_MEM_HANDLE_TYPE; + prop->flags = 0; + + // Could be changed to CU_MULTICAST_GRANULARITY_MINIMUM when 3418538 resolved + CUCHECK(cuMulticastGetGranularity(&resources->granularity, prop, CU_MULTICAST_GRANULARITY_RECOMMENDED)); + + ALIGN_SIZE(size, resources->granularity); + prop->size = resources->size = size; + + memset(&resources->accessDesc, 0, sizeof(resources->accessDesc)); + resources->accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + resources->accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + resources->accessDesc.location.id = dev; + resources->dev = dev; + + return ncclSuccess; +} + +ncclResult_t nvlsGroupCreate(struct ncclComm *comm, struct nvlsResources* resources, int rank, unsigned int nranks, char* shareableHandle) { + size_t size = resources->size; + + // Create a Multicast group + CUmulticastObjectProp* prop = &resources->properties; + + INFO(NCCL_NVLS, "NVLS Creating Multicast group nranks %d size %zi on rank %d", nranks, size, rank); + CUCHECK(cuMulticastCreate(&resources->mcHandle, prop)); + + if (NVLS_CU_MEM_HANDLE_TYPE != CU_MEM_HANDLE_TYPE_NONE) { + // Get a handle to pass to other ranks + CUCHECK(cuMemExportToShareableHandle(shareableHandle, resources->mcHandle, NVLS_CU_MEM_HANDLE_TYPE, 0)); + } + else { + memcpy(shareableHandle, &resources->mcHandle, sizeof(resources->mcHandle)); + } + + INFO(NCCL_NVLS, "NVLS Created Multicast group %llx nranks %d size %zi on rank %d", resources->mcHandle, nranks, size, rank); + + return ncclSuccess; +} + +ncclResult_t nvlsGroupAddDevice(struct ncclComm *comm, struct nvlsResources* resources) { + INFO(NCCL_NVLS, "NVLS group %llx adding dev %d", resources->mcHandle, 
resources->dev); + CUCHECK(cuMulticastAddDevice(resources->mcHandle, resources->dev)); + return ncclSuccess; +} + +ncclResult_t nvlsGroupUnbind(struct ncclComm *comm, struct nvlsResources* resources) { + int dev = resources->dev; + size_t size = resources->size; + INFO(NCCL_NVLS, "NVLS Unbind MC handle %llx size %zi dev %d", resources->mcHandle, size, dev); + + // Unbind physical memory from group for the given device + CUCHECK(cuMulticastUnbind(resources->mcHandle, dev, 0/*mcOffset*/, size)); + + return ncclSuccess; +} + +ncclResult_t nvlsGroupConnect(struct ncclComm *comm, struct nvlsResources* resources, int rank, char* shareableHandle) { + CUmemAllocationHandleType type = NVLS_CU_MEM_HANDLE_TYPE; + + INFO(NCCL_NVLS, "NVLS importing shareableHandle %p from rank %d", shareableHandle, rank); + + // Import and map the remote memory descriptor to the local GPU + if (type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR) { + // cuMem UDS support + int fd = *(int *)shareableHandle; + TRACE(NCCL_NVLS, "NVLS rank %d Importing shareable handle from rank %d fd %d", comm->localRank, rank, fd); + struct ncclProxyConnector proxyConn; + NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, rank, &proxyConn)); + TRACE(NCCL_NVLS, "NVLS rank %d request conversion of fd %d from rank %d", comm->localRank, fd, rank); + NCCLCHECK(ncclProxyCallBlocking(&proxyConn, ncclProxyMsgConvertFd, shareableHandle, sizeof(int), &fd, sizeof(int))); + TRACE(NCCL_NVLS, "NVLS rank %d received converted fd %d from rank %d", comm->localRank, fd, rank); + CUCHECK(cuMemImportFromShareableHandle(&resources->mcHandle, (void *)(uintptr_t)fd, type)); + } else { + if (NVLS_CU_MEM_HANDLE_TYPE != CU_MEM_HANDLE_TYPE_NONE) { + CUCHECK(cuMemImportFromShareableHandle(&resources->mcHandle, (void *)shareableHandle, type)); + } else { + memcpy(&resources->mcHandle, shareableHandle, sizeof(resources->mcHandle)); + } + } + return ncclSuccess; +} + +ncclResult_t nvlsGroupBindMem(struct ncclComm *comm, struct nvlsResources* resources) { + size_t size = resources->size; + size_t granularity; + CUdeviceptr ptr = 0; + CUmemAllocationProp prop; + + memset(&prop, 0, sizeof(prop)); + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = resources->dev; + prop.requestedHandleTypes = NVLS_CU_MEM_HANDLE_TYPE; + CUCHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); + + // Map a VA for UC memory + CUCHECK(cuMemAddressReserve(&ptr, size, granularity, 0U, 0)); + + // Alloc local physical mem for this NVLS group + CUCHECK(cuMemCreate(&resources->ucHandle, size, &prop, 0)); + CUCHECK(cuMemMap(ptr, size, 0, resources->ucHandle, 0)); + CUCHECK(cuMemSetAccess(ptr, size, &resources->accessDesc, 1)); + CUDACHECK(cudaMemset((void*)ptr, 0, size)); + resources->ucBuff = (char*)ptr; + INFO(NCCL_NVLS, "NVLS Mapped UC at %p size %zi", resources->ucBuff, size); + + // Bind physical memory to the Multicast group + // NB: It will block until all ranks have been added to the Group + INFO(NCCL_NVLS, "NVLS Bind mem %p UC handle 0x%llx MC handle 0x%llx size %zi", (void*)ptr, resources->ucHandle, resources->mcHandle, size); + CUCHECK(cuMulticastBindMem(resources->mcHandle, 0/*mcOffset*/, resources->ucHandle, 0/*memOffset*/, size, 0/*flags*/)); + + return ncclSuccess; +} + +ncclResult_t nvlsGroupMapMem(struct ncclComm *comm, struct nvlsResources* resources) { + size_t size = resources->size; + CUdeviceptr ptr = 0; + + // Create a VA for the NVLS + CUCHECK(cuMemAddressReserve(&ptr, 
size, resources->granularity, 0U, 0)); + // Map the VA locally + CUCHECK(cuMemMap(ptr, size, 0, resources->mcHandle, 0)); + resources->mcBuff = (char*)ptr; + INFO(NCCL_NVLS, "NVLS Mapped MC buffer at %p size %zi", resources->mcBuff, size); + + // Having completed the BindMem we can now call SetAccess + // NB: It will block until all ranks have bound to the Group + CUCHECK(cuMemSetAccess((CUdeviceptr)resources->mcBuff, size, &resources->accessDesc, 1)); + + return ncclSuccess; +} + +ncclResult_t nvlsGroupUnmapMem(struct ncclComm *comm, struct nvlsResources* resources) { + size_t size; + CUdeviceptr ptr; + INFO(NCCL_NVLS, "NVLS Unmap mem UC handle 0x%llx(%p) MC handle 0x%llx(%p)", + resources->ucHandle, resources->ucBuff, resources->mcHandle, resources->mcBuff); + + // Release the UC memory and mapping + ptr = (CUdeviceptr)resources->ucBuff; + size = resources->size; + CUCHECK(cuMemUnmap(ptr, size)); + CUCHECK(cuMemAddressFree(ptr, size)); + CUCHECK(cuMemRelease(resources->ucHandle)); + + // Release the MC memory and mapping + ptr = (CUdeviceptr)resources->mcBuff; + size = resources->size; + CUCHECK(cuMemUnmap(ptr, size)); + CUCHECK(cuMemAddressFree(ptr, size)); + CUCHECK(cuMemRelease(resources->mcHandle)); + + return ncclSuccess; +} + +#include "bootstrap.h" +#include "channel.h" + +#define NVLS_MEM_ALIGN_SIZE (1 << 21) + +NCCL_PARAM(NvlsChannels, "NVLS_NCHANNELS", 16); + +NCCL_PARAM(NvlsEnable, "NVLS_ENABLE", 1); + +ncclResult_t ncclNvlsSetup(struct ncclComm* comm) { + if (!ncclParamNvlsEnable() || comm->localRanks <= 1 || comm->nNodes>1) return ncclSuccess; + CUdevice dev; + int driverVersion; + if (CUPFN(cuDeviceGet) == NULL) return ncclSuccess; + CUCHECK(cuDeviceGet(&dev, comm->cudaDev)); + CUDACHECK(cudaDriverGetVersion(&driverVersion)); + comm->nvlsSupport = 0; + // NVLS Multicast support requires CUDA12.1 UMD + KMD + if (CUPFN(cuMulticastCreate) != NULL && driverVersion >= 12010) { + CUCHECK(cuDeviceGetAttribute(&comm->nvlsSupport, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev)); + } + INFO(NCCL_INIT, "NVLS multicast support is %savailable on dev %d", comm->nvlsSupport ? 
"" : "not ", dev); + if (comm->nvlsSupport == 0) return ncclSuccess; + + int nChannels = comm->nvlsChannels = std::max(comm->minCTAs, std::min(comm->maxCTAs, (int)ncclParamNvlsChannels())); + int rank = comm->localRank, nranks = comm->localRanks; + + for (int c=0; cnvlsResources = resources; + + size_t buffSize = comm->buffSizes[NCCL_PROTO_SIMPLE]; + size_t memSize = NVLS_MEM_ALIGN_SIZE; + size_t nvlsPerRankSize = nChannels*2*(buffSize+memSize); + size_t nvlsTotalSize = nvlsPerRankSize*nranks; + + INFO(NCCL_INIT|NCCL_NVLS, "NVLS comm %p rank %d nranks %d buffSize %zi memSize %zi nvlsPerRankSize %zi nvlsTotalSize %zi", + comm, rank, nranks, buffSize, memSize, nvlsPerRankSize, nvlsTotalSize); + + char* nvlsShareableHandle = NULL; + NCCLCHECKGOTO(ncclCalloc(&nvlsShareableHandle, NVLS_HANDLE_SIZE), res, cleanup); + NCCLCHECKGOTO(nvlsGetProperties(comm, resources, dev, nranks, nvlsTotalSize), res, cleanup); + if (rank == 0) { + NCCLCHECKGOTO(nvlsGroupCreate(comm, resources, rank, nranks, nvlsShareableHandle), res, cleanup); + NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, rank, nranks, 0, nvlsShareableHandle, NVLS_HANDLE_SIZE), res, cleanup); + } else { + NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, rank, nranks, 0, nvlsShareableHandle, NVLS_HANDLE_SIZE), res, cleanup); + NCCLCHECKGOTO(nvlsGroupConnect(comm, resources, 0, nvlsShareableHandle), res, cleanup); + } + + NCCLCHECKGOTO(nvlsGroupAddDevice(comm, resources), res, cleanup); + NCCLCHECKGOTO(nvlsGroupBindMem(comm, resources), res, cleanup); + // Local intra-node barrier to ensure everyone has bound their memory to the group + NCCLCHECKGOTO(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0]), res, cleanup); + NCCLCHECKGOTO(nvlsGroupMapMem(comm, resources), res, cleanup); + + for (int c=0; cchannels+c; + channel->nvls.nHeads = nranks; + for (int i=0; invls.up[i] = -1; + channel->nvls.down = comm->nRanks+1+comm->localRank; + channel->nvls.out = -1; // Network not yet implemented. + channel->nvls.headRank = comm->localRank; // Network not yet implemented. 
+ } + + for (int r=0; rnRanks+1+r; + for (int c=0; cchannels+c; + channel->nvls.up[r] = nvlsPeer; + + char* mem = NULL; + struct ncclChannelPeer* peer = channel->peers+nvlsPeer; + + // Reduce UC -> MC + mem = resources->ucBuff + (r*2*nChannels+c)*(buffSize+memSize); + peer->send[0].transportComm = &nvlsTransport.send; + peer->send[0].conn.buffs[NCCL_PROTO_SIMPLE] = mem; + peer->send[0].conn.head = (uint64_t*)(mem+buffSize); + peer->send[0].conn.tail = (uint64_t*)(mem+buffSize+memSize/2); + mem = resources->mcBuff + (r*2*nChannels+c)*(buffSize+memSize); + peer->recv[1].transportComm = &nvlsTransport.recv; + peer->recv[1].conn.buffs[NCCL_PROTO_SIMPLE] = mem; + peer->recv[1].conn.head = (uint64_t*)(mem+buffSize); + peer->recv[1].conn.tail = (uint64_t*)(mem+buffSize+memSize/2); + peer->recv[1].conn.flags |= NCCL_NVLS_MIN_POLL; + + // Broadcast MC -> UC + mem = resources->ucBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize); + peer->recv[0].transportComm = &nvlsTransport.recv; + peer->recv[0].conn.buffs[NCCL_PROTO_SIMPLE] = mem; + peer->recv[0].conn.head = (uint64_t*)(mem+buffSize); + peer->recv[0].conn.tail = (uint64_t*)(mem+buffSize+memSize/2); + mem = resources->mcBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize); + peer->send[1].transportComm = &nvlsTransport.send; + peer->send[1].conn.buffs[NCCL_PROTO_SIMPLE] = mem; + peer->send[1].conn.head = (uint64_t*)(mem+buffSize); + peer->send[1].conn.tail = (uint64_t*)(mem+buffSize+memSize/2); + peer->send[1].conn.flags |= NCCL_NVLS_MIN_POLL; + + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].send[0], &peer->send[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].recv[0], &peer->recv[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].send[1], &peer->send[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].recv[1], &peer->recv[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); + + /*INFO(NCCL_INIT|NCCL_NVLS, "Peer %d Channel %d MC buff %p/%p UC Buff %p/%p", + nvlsPeer, c, + resources->mcBuff + (r*2*nChannels+c)*(buffSize+memSize), + resources->mcBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize), + resources->ucBuff + (r*2*nChannels+c)*(buffSize+memSize), + resources->ucBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize));*/ + } + } + + free(nvlsShareableHandle); + return res; + +cleanup: + comm->nvlsSupport = 0; + free(nvlsShareableHandle); + return res; +} + +ncclResult_t ncclNvlsFree(struct ncclComm* comm) { + struct nvlsResources* resources = (struct nvlsResources*)comm->nvlsResources; + if (resources == NULL) return ncclSuccess; + NCCLCHECK(nvlsGroupUnbind(comm, resources)); + NCCLCHECK(nvlsGroupUnmapMem(comm, resources)); + free(resources); + comm->nvlsResources = NULL; + return ncclSuccess; +} + +#else + +/* + * Pre CUDA 12.1 stubs + */ + +ncclResult_t ncclNvlsSetup(struct ncclComm* comm) { + return ncclSuccess; +} + +ncclResult_t ncclNvlsFree(struct ncclComm* comm) { + return ncclSuccess; +} + +#endif /* CUDA_VERSION >= 12010 */ diff --git a/src/transport/p2p.cc b/src/transport/p2p.cc index e7a4fd0..fa33d7b 100644 --- a/src/transport/p2p.cc +++ b/src/transport/p2p.cc @@ -239,11 +239,11 @@ ncclResult_t 
p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st if (intermediateRank == -1) { info->rank = myInfo->rank; if (myInfo->pidHash == peerInfo->pidHash && useMemcpy == 0) { - if (ncclParamP2pDirectDisable() == 0) send->conn.direct |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; + if (ncclParamP2pDirectDisable() == 0) send->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr); } else { - send->conn.direct |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE; + send->conn.flags |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE; INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/IPC%s%s", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, useMemcpy ? "/CE" : ""); } @@ -256,11 +256,11 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, info->rank, &send->proxyConn)); if (useMemcpy) { - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgSetup, NULL, 0, &resources->proxyInfo, sizeof(struct p2pProxyInfo))); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, NULL, 0, &resources->proxyInfo, sizeof(struct p2pProxyInfo))); info->shmSize = resources->proxyInfo.shmSize; memcpy(info->shmName, resources->proxyInfo.shmName, sizeof(info->shmName)); } else { - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgSetup, &sendSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff))); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, &sendSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff))); NCCLCHECK(p2pMap(myInfo, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&resources->devMem, &resources->sendMemIpc)); } @@ -290,16 +290,16 @@ ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st if (intermediateRank == -1) { info->rank = myInfo->rank; if (myInfo->pidHash == peerInfo->pidHash && useMemcpy == 0) { - if (ncclParamP2pDirectDisable() == 0) recv->conn.direct |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; + if (ncclParamP2pDirectDisable() == 0) recv->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; } else { - recv->conn.direct |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE; + recv->conn.flags |= info->read ? 
NCCL_IPC_READ : NCCL_IPC_WRITE; } } else { info->rank = intermediateRank; } NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 0, info->rank, &recv->proxyConn)); - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgSetup, &recvSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff))); + NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgSetup, &recvSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff))); NCCLCHECK(p2pMap(myInfo, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&resources->devMem, &resources->recvMemIpc)); return ncclSuccess; @@ -330,7 +330,7 @@ static ncclResult_t p2pSendConnect(struct ncclComm* comm, struct ncclConnect* co send->conn.sizesFifo = resources->proxyInfo.ceRecvMem->sizesFifo; send->conn.head = &resources->proxyInfo.devShm->sendMem.head; // Send SIMPLE buff to proxy, and replace it by local buffer - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgConnect, &send->conn.buffs[NCCL_PROTO_SIMPLE], sizeof(void*), NULL, 0)); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgConnect, &send->conn.buffs[NCCL_PROTO_SIMPLE], sizeof(void*), NULL, 0)); send->conn.buffs[NCCL_PROTO_SIMPLE] = resources->proxyInfo.ceDevBuff; } else { send->conn.tail = &remDevMem->tail; diff --git a/src/transport/shm.cc b/src/transport/shm.cc index 4bce480..30fc992 100644 --- a/src/transport/shm.cc +++ b/src/transport/shm.cc @@ -157,7 +157,7 @@ static ncclResult_t shmSendConnect(struct ncclComm* comm, struct ncclConnect* co if (useMemcpySend) { NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_SHM, 1, comm->rank, &send->proxyConn)); struct shmProxyInfo proxyInfo = { NULL, NULL, send->conn.buffs[NCCL_PROTO_SIMPLE], resources->hostMem, resources->remHostMem }; - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo))); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo))); send->conn.buffs[NCCL_PROTO_SIMPLE] = proxyInfo.devFifo; send->conn.tail = &proxyInfo.ceRecvMem->tail; send->conn.sizesFifo = proxyInfo.ceRecvMem->sizesFifo; @@ -187,7 +187,7 @@ static ncclResult_t shmRecvConnect(struct ncclComm* comm, struct ncclConnect* co if (useMemcpyRecv) { NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_SHM, 0, comm->rank, &recv->proxyConn)); struct shmProxyInfo proxyInfo = { NULL, NULL, recv->conn.buffs[NCCL_PROTO_SIMPLE], resources->remHostMem, resources->hostMem }; - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo))); + NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo))); recv->conn.buffs[NCCL_PROTO_SIMPLE] = proxyInfo.devFifo; recv->conn.tail = &proxyInfo.ceRecvMem->tail; }
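// Throughout the p2p and shm setup/connect paths above, the old synchronous
// ncclProxyCall() is replaced by ncclProxyCallBlocking(), keeping the original
// blocking semantics while the net transport moves to the asynchronous
// ncclProxyCallAsync()/ncclPollProxyResponse() pair. A plausible sketch of how such a
// blocking wrapper can be layered on the async primitives, matching the call
// signatures used in this patch but not necessarily the actual proxy.cc code:
static ncclResult_t proxyCallBlockingSketch(struct ncclProxyConnector* proxyConn, int type,
                                            void* reqBuff, int reqSize, void* respBuff, int respSize) {
  void* opId = reqBuff;  // any pointer that uniquely identifies the operation will do
  NCCLCHECK(ncclProxyCallAsync(proxyConn, type, reqBuff, reqSize, respSize, opId));
  ncclResult_t ret;
  do {
    ret = ncclPollProxyResponse(proxyConn, respBuff, opId);  // fills respBuff once the proxy answers
  } while (ret == ncclInProgress);
  return ret;
}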