diff --git a/makefiles/common.mk b/makefiles/common.mk index 0c0d04a..35d1826 100644 --- a/makefiles/common.mk +++ b/makefiles/common.mk @@ -10,7 +10,7 @@ VERBOSE ?= 0 KEEP ?= 0 DEBUG ?= 0 TRACE ?= 0 -PROFAPI ?= 0 +PROFAPI ?= 1 NVTX ?= 1 NVCC = $(CUDA_HOME)/bin/nvcc @@ -25,22 +25,26 @@ CUDA_MINOR = $(shell echo $(CUDA_VERSION) | cut -d "." -f 2) # You should define NVCC_GENCODE in your environment to the minimal set # of archs to reduce compile time. -CUDA8_GENCODE = -gencode=arch=compute_35,code=sm_35 \ - -gencode=arch=compute_50,code=sm_50 \ +CUDA8_GENCODE = -gencode=arch=compute_50,code=sm_50 \ -gencode=arch=compute_60,code=sm_60 \ -gencode=arch=compute_61,code=sm_61 +ifeq ($(shell test "0$(CUDA_MAJOR)" -lt 12; echo $$?),0) +# SM35 is deprecated from CUDA12.0 onwards +CUDA8_GENCODE += -gencode=arch=compute_35,code=sm_35 +endif CUDA9_GENCODE = -gencode=arch=compute_70,code=sm_70 CUDA11_GENCODE = -gencode=arch=compute_80,code=sm_80 -CUDA11_8_GENCODE = -gencode=arch=compute_90,code=sm_90 +CUDA12_GENCODE = -gencode=arch=compute_90,code=sm_90 CUDA8_PTX = -gencode=arch=compute_61,code=compute_61 CUDA9_PTX = -gencode=arch=compute_70,code=compute_70 CUDA11_PTX = -gencode=arch=compute_80,code=compute_80 -CUDA11_8_PTX = -gencode=arch=compute_90,code=compute_90 +CUDA12_PTX = -gencode=arch=compute_90,code=compute_90 + ifeq ($(shell test "0$(CUDA_MAJOR)" -eq 11 -a "0$(CUDA_MINOR)" -ge 8 -o "0$(CUDA_MAJOR)" -gt 11; echo $$?),0) # Include Hopper support if we're using CUDA11.8 or above - NVCC_GENCODE ?= $(CUDA8_GENCODE) $(CUDA9_GENCODE) $(CUDA11_GENCODE) $(CUDA11_8_GENCODE) $(CUDA11_8_PTX) + NVCC_GENCODE ?= $(CUDA8_GENCODE) $(CUDA9_GENCODE) $(CUDA11_GENCODE) $(CUDA12_GENCODE) $(CUDA12_PTX) else ifeq ($(shell test "0$(CUDA_MAJOR)" -ge 11; echo $$?),0) NVCC_GENCODE ?= $(CUDA8_GENCODE) $(CUDA9_GENCODE) $(CUDA11_GENCODE) $(CUDA11_PTX) # Include Volta support if we're using CUDA9 or above diff --git a/makefiles/version.mk b/makefiles/version.mk index be64e9a..f4a0d8d 100644 --- 
a/makefiles/version.mk +++ b/makefiles/version.mk @@ -1,6 +1,6 @@ ##### version NCCL_MAJOR := 2 -NCCL_MINOR := 15 -NCCL_PATCH := 5 +NCCL_MINOR := 16 +NCCL_PATCH := 2 NCCL_SUFFIX := PKG_REVISION := 1 diff --git a/src/Makefile b/src/Makefile index 1539e14..4753018 100644 --- a/src/Makefile +++ b/src/Makefile @@ -9,7 +9,7 @@ include ../makefiles/version.mk ##### src files INCEXPORTS := nccl.h nccl_net.h -LIBSRCFILES := init.cc channel.cc bootstrap.cc transport.cc enqueue.cc group.cc debug.cc proxy.cc net.cc \ +LIBSRCFILES := init.cc init_nvtx.cc channel.cc bootstrap.cc transport.cc enqueue.cc group.cc debug.cc proxy.cc net.cc \ misc/cudawrap.cc misc/nvmlwrap.cc misc/ibvwrap.cc misc/gdrwrap.cc \ misc/utils.cc misc/argcheck.cc misc/socket.cc misc/shmutils.cc misc/profiler.cc misc/param.cc misc/strongstream.cc \ transport/p2p.cc transport/shm.cc transport/net.cc transport/net_socket.cc transport/net_ib.cc transport/coll_net.cc \ diff --git a/src/bootstrap.cc b/src/bootstrap.cc index b7e0576..c348b3e 100644 --- a/src/bootstrap.cc +++ b/src/bootstrap.cc @@ -13,6 +13,11 @@ #include #include "proxy.h" +struct bootstrapRootArgs { + struct ncclSocket* listenSock; + uint64_t magic; +}; + /* Init functions */ static char bootstrapNetIfName[MAX_IF_NAME_SIZE+1]; static union ncclSocketAddress bootstrapNetIfAddr; @@ -26,7 +31,7 @@ ncclResult_t bootstrapNetInit() { char* env = getenv("NCCL_COMM_ID"); if (env) { union ncclSocketAddress remoteAddr; - if (ncclGetSocketAddrFromString(&remoteAddr, env) != ncclSuccess) { + if (ncclSocketGetAddrFromString(&remoteAddr, env) != ncclSuccess) { WARN("Invalid NCCL_COMM_ID, please use format: : or []: or :"); return ncclInvalidArgument; } @@ -89,8 +94,10 @@ static ncclResult_t setFilesLimit() { return ncclSuccess; } -static void *bootstrapRoot(void* args) { - struct ncclSocket* listenSock = (struct ncclSocket*)args; +static void *bootstrapRoot(void* rargs) { + struct bootstrapRootArgs* args = (struct bootstrapRootArgs*)rargs; + struct 
ncclSocket* listenSock = args->listenSock; + uint64_t magic = args->magic; ncclResult_t res = ncclSuccess; int nranks = 0, c = 0; struct extInfo info; @@ -104,11 +111,10 @@ static void *bootstrapRoot(void* args) { /* Receive addresses from all ranks */ do { struct ncclSocket sock; - /* bootstrap root thread always uses blocking ncclSocketAccept. */ - NCCLCHECKGOTO(ncclSocketInit(&sock, NULL, NULL, 0), res, out); + NCCLCHECKGOTO(ncclSocketInit(&sock), res, out); NCCLCHECKGOTO(ncclSocketAccept(&sock, listenSock), res, out); NCCLCHECKGOTO(bootstrapNetRecv(&sock, &info, sizeof(info)), res, out); - close(sock.fd); + NCCLCHECKGOTO(ncclSocketClose(&sock), res, out); if (c == 0) { nranks = info.nranks; @@ -139,54 +145,60 @@ static void *bootstrapRoot(void* args) { for (int r=0; rfd); - free(listenSock); + if (listenSock != NULL) { + ncclSocketClose(listenSock); + free(listenSock); + } if (rankAddresses) free(rankAddresses); if (rankAddressesRoot) free(rankAddressesRoot); if (zero) free(zero); + free(rargs); TRACE(NCCL_INIT, "DONE"); return NULL; } -ncclResult_t bootstrapCreateRoot(ncclUniqueId* id, bool idFromEnv) { +ncclResult_t bootstrapCreateRoot(struct ncclBootstrapHandle* handle, bool idFromEnv) { struct ncclSocket* listenSock; - NCCLCHECK(ncclCalloc(&listenSock, 1)); - memcpy(&listenSock->addr, id, sizeof(union ncclSocketAddress)); - NCCLCHECK(ncclSocketListen(listenSock)); - memcpy(id, &listenSock->addr, sizeof(union ncclSocketAddress)); + struct bootstrapRootArgs* args; pthread_t thread; - pthread_create(&thread, NULL, bootstrapRoot, (void*)listenSock); + + NCCLCHECK(ncclCalloc(&listenSock, 1)); + NCCLCHECK(ncclSocketInit(listenSock, &handle->addr, handle->magic, ncclSocketTypeBootstrap, NULL, 0)); + NCCLCHECK(ncclSocketListen(listenSock)); + NCCLCHECK(ncclSocketGetAddr(listenSock, &handle->addr)); + + NCCLCHECK(ncclCalloc(&args, 1)); + args->listenSock = listenSock; + args->magic = handle->magic; + NEQCHECK(pthread_create(&thread, NULL, bootstrapRoot, 
(void*)args), 0); ncclSetThreadName(thread, "NCCL BootstrapR"); - pthread_detach(thread); // will not be pthread_join()'d + NEQCHECK(pthread_detach(thread), 0); // will not be pthread_join()'d return ncclSuccess; } -ncclResult_t bootstrapGetUniqueId(ncclUniqueId* id) { - static_assert(sizeof(union ncclSocketAddress) < sizeof(ncclUniqueId), "NetId does not fit inside ncclUniqueId"); - memset(id, 0, sizeof(ncclUniqueId)); - union ncclSocketAddress* connectAddr = (union ncclSocketAddress*) id; +ncclResult_t bootstrapGetUniqueId(struct ncclBootstrapHandle* handle) { + memset(handle, 0, sizeof(ncclBootstrapHandle)); + NCCLCHECK(getRandomData(&handle->magic, sizeof(handle->magic))); char* env = getenv("NCCL_COMM_ID"); if (env) { INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", env); - if (ncclGetSocketAddrFromString(connectAddr, env) != ncclSuccess) { + if (ncclSocketGetAddrFromString(&handle->addr, env) != ncclSuccess) { WARN("Invalid NCCL_COMM_ID, please use format: : or []: or :"); return ncclInvalidArgument; } } else { - memcpy(id, &bootstrapNetIfAddr, sizeof(union ncclSocketAddress)); - NCCLCHECK(bootstrapCreateRoot(id, false)); + memcpy(&handle->addr, &bootstrapNetIfAddr, sizeof(union ncclSocketAddress)); + NCCLCHECK(bootstrapCreateRoot(handle, false)); } return ncclSuccess; @@ -209,38 +221,39 @@ struct bootstrapState { int cudaDev; int rank; int nranks; + uint64_t magic; volatile uint32_t *abortFlag; }; -ncclResult_t bootstrapInit(ncclUniqueId * id, struct ncclComm* comm) { +ncclResult_t bootstrapInit(struct ncclBootstrapHandle* handle, struct ncclComm* comm) { int rank = comm->rank; int nranks = comm->nRanks; struct bootstrapState* state; + struct ncclSocket* proxySocket; + ncclSocketAddress nextAddr; + struct ncclSocket sock, listenSockRoot; + struct extInfo info = { 0 }; + NCCLCHECK(ncclCalloc(&state, 1)); state->rank = rank; state->nranks = nranks; state->abortFlag = comm->abortFlag; comm->bootstrap = state; + comm->magic = state->magic = handle->magic; 
TRACE(NCCL_INIT, "rank %d nranks %d", rank, nranks); - struct extInfo info = { 0 }; info.rank = rank; info.nranks = nranks; - struct ncclSocket sock, listenSockRoot; - - NCCLCHECK(ncclSocketInit(&sock, (union ncclSocketAddress*) id, comm->abortFlag, 0)); - NCCLCHECK(ncclSocketInit(&listenSockRoot, &bootstrapNetIfAddr, comm->abortFlag, 0)); - NCCLCHECK(ncclSocketInit(&state->listenSock, &bootstrapNetIfAddr, comm->abortFlag, 0)); - NCCLCHECK(ncclSocketInit(&state->ringSendSocket, NULL, comm->abortFlag, 0)); - NCCLCHECK(ncclSocketInit(&state->ringRecvSocket, NULL, comm->abortFlag, 0)); // Create socket for other ranks to contact me + NCCLCHECK(ncclSocketInit(&state->listenSock, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); NCCLCHECK(ncclSocketListen(&state->listenSock)); - memcpy(&info.extAddressListen, &state->listenSock.addr, sizeof(union ncclSocketAddress)); + NCCLCHECK(ncclSocketGetAddr(&state->listenSock, &info.extAddressListen)); // Create socket for root to contact me + NCCLCHECK(ncclSocketInit(&listenSockRoot, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); NCCLCHECK(ncclSocketListen(&listenSockRoot)); - memcpy(&info.extAddressListenRoot, &listenSockRoot.addr, sizeof(union ncclSocketAddress)); + NCCLCHECK(ncclSocketGetAddr(&listenSockRoot, &info.extAddressListenRoot)); // stagger connection times to avoid an overload of the root if (nranks > 128) { @@ -253,32 +266,37 @@ ncclResult_t bootstrapInit(ncclUniqueId * id, struct ncclComm* comm) { } // send info on my listening socket to root + NCCLCHECK(ncclSocketInit(&sock, &handle->addr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); NCCLCHECK(ncclSocketConnect(&sock)); NCCLCHECK(bootstrapNetSend(&sock, &info, sizeof(info))); - close(sock.fd); + NCCLCHECK(ncclSocketClose(&sock)); // get info on my "next" rank in the bootstrap ring from root + NCCLCHECK(ncclSocketInit(&sock)); NCCLCHECK(ncclSocketAccept(&sock, &listenSockRoot)); - 
NCCLCHECK(bootstrapNetRecv(&sock, &state->ringSendSocket.addr, sizeof(union ncclSocketAddress))); - close(sock.fd); - close(listenSockRoot.fd); + NCCLCHECK(bootstrapNetRecv(&sock, &nextAddr, sizeof(union ncclSocketAddress))); + NCCLCHECK(ncclSocketClose(&sock)); + NCCLCHECK(ncclSocketClose(&listenSockRoot)); + NCCLCHECK(ncclSocketInit(&state->ringSendSocket, &nextAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); NCCLCHECK(ncclSocketConnect(&state->ringSendSocket)); // Accept the connect request from the previous rank in the AllGather ring + NCCLCHECK(ncclSocketInit(&state->ringRecvSocket)); NCCLCHECK(ncclSocketAccept(&state->ringRecvSocket, &state->listenSock)); // AllGather all listen handlers NCCLCHECK(ncclCalloc(&state->peerCommAddresses, nranks)); - memcpy(state->peerCommAddresses+rank, &state->listenSock.addr, sizeof(union ncclSocketAddress)); + NCCLCHECK(ncclSocketGetAddr(&state->listenSock, state->peerCommAddresses+rank)); NCCLCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union ncclSocketAddress))); // Create the service proxy NCCLCHECK(ncclCalloc(&state->peerProxyAddresses, nranks)); - struct ncclSocket* proxySocket; + + // proxy is aborted through a message; don't set abortFlag NCCLCHECK(ncclCalloc(&proxySocket, 1)); - NCCLCHECK(ncclSocketInit(proxySocket, &bootstrapNetIfAddr, NULL, 0)); + NCCLCHECK(ncclSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeProxy, comm->abortFlag)); NCCLCHECK(ncclSocketListen(proxySocket)); - memcpy(state->peerProxyAddresses+rank, &proxySocket->addr, sizeof(union ncclSocketAddress)); + NCCLCHECK(ncclSocketGetAddr(proxySocket, state->peerProxyAddresses+rank)); NCCLCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union ncclSocketAddress))); NCCLCHECK(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses)); @@ -314,16 +332,21 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) { } ncclResult_t bootstrapSend(void* commState, int 
peer, int tag, void* data, int size) { + ncclResult_t ret = ncclSuccess; struct bootstrapState* state = (struct bootstrapState*)commState; struct ncclSocket sock; - NCCLCHECK(ncclSocketInit(&sock, state->peerCommAddresses+peer, state->abortFlag, 1)); - NCCLCHECK(ncclSocketConnect(&sock)); - NCCLCHECK(bootstrapNetSend(&sock, &state->rank, sizeof(int))); - NCCLCHECK(bootstrapNetSend(&sock, &tag, sizeof(int))); - NCCLCHECK(bootstrapNetSend(&sock, data, size)); - close(sock.fd); - return ncclSuccess; + NCCLCHECKGOTO(ncclSocketInit(&sock, state->peerCommAddresses+peer, state->magic, ncclSocketTypeBootstrap, state->abortFlag), ret, fail); + NCCLCHECKGOTO(ncclSocketConnect(&sock), ret, fail); + NCCLCHECKGOTO(bootstrapNetSend(&sock, &state->rank, sizeof(int)), ret, fail); + NCCLCHECKGOTO(bootstrapNetSend(&sock, &tag, sizeof(int)), ret, fail); + NCCLCHECKGOTO(bootstrapNetSend(&sock, data, size), ret, fail); + +exit: + NCCLCHECK(ncclSocketClose(&sock)); + return ret; +fail: + goto exit; } ncclResult_t bootstrapBarrier(void* commState, int *ranks, int rank, int nranks, int tag) { @@ -382,9 +405,10 @@ ncclResult_t unexpectedEnqueue(struct bootstrapState* state, int peer, int tag, return ncclSuccess; } -ncclResult_t unexpectedDequeue(struct bootstrapState* state, int peer, int tag, struct ncclSocket* sock) { +ncclResult_t unexpectedDequeue(struct bootstrapState* state, int peer, int tag, struct ncclSocket* sock, int* found) { struct unexConn* elem = state->unexpectedConnections; struct unexConn* prev = NULL; + *found = 0; while (elem) { if (elem->peer == peer && elem->tag == tag) { if (prev == NULL) { @@ -394,54 +418,75 @@ ncclResult_t unexpectedDequeue(struct bootstrapState* state, int peer, int tag, } memcpy(sock, &elem->sock, sizeof(struct ncclSocket)); free(elem); + *found = 1; return ncclSuccess; } prev = elem; elem = elem->next; } - sock->fd = -1; return ncclSuccess; } +static void unexpectedFree(struct bootstrapState* state) { + struct unexConn* elem = 
state->unexpectedConnections; + struct unexConn* prev = NULL; + + while (elem) { + prev = elem; + elem = elem->next; + free(prev); + } + return; +} + // We can't know who we'll receive from, so we need to receive everything at once ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size) { + ncclResult_t ret = ncclSuccess; struct bootstrapState* state = (struct bootstrapState*)commState; struct ncclSocket sock; + int newPeer, newTag; // Search unexpected connections first - NCCLCHECK(unexpectedDequeue(state, peer, tag, &sock)); - if (sock.fd != -1) { - NCCLCHECK(bootstrapNetRecv(&sock, ((char*)data), size)); - close(sock.fd); - return ncclSuccess; + int found; + NCCLCHECK(unexpectedDequeue(state, peer, tag, &sock, &found)); + if (found) { + NCCLCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, fail); + goto exit; } // Then look for new connections - NCCLCHECK(ncclSocketInit(&sock, NULL, state->listenSock.abortFlag, 0)); while (1) { - NCCLCHECK(ncclSocketAccept(&sock, &state->listenSock)); - int newPeer, newTag; - NCCLCHECK(bootstrapNetRecv(&sock, &newPeer, sizeof(int))); - NCCLCHECK(bootstrapNetRecv(&sock, &newTag, sizeof(int))); + NCCLCHECKGOTO(ncclSocketInit(&sock), ret, fail); + NCCLCHECKGOTO(ncclSocketAccept(&sock, &state->listenSock), ret, fail); + NCCLCHECKGOTO(bootstrapNetRecv(&sock, &newPeer, sizeof(int)), ret, fail); + NCCLCHECKGOTO(bootstrapNetRecv(&sock, &newTag, sizeof(int)), ret, fail); if (newPeer == peer && newTag == tag) { - NCCLCHECK(bootstrapNetRecv(&sock, ((char*)data), size)); - close(sock.fd); - return ncclSuccess; + NCCLCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, fail); + goto exit; } // Unexpected connection. Save for later. 
- NCCLCHECK(unexpectedEnqueue(state, newPeer, newTag, &sock)); + NCCLCHECKGOTO(unexpectedEnqueue(state, newPeer, newTag, &sock), ret, fail); } +exit: + NCCLCHECK(ncclSocketClose(&sock)); + return ret; +fail: + goto exit; } ncclResult_t bootstrapClose(void* commState) { struct bootstrapState* state = (struct bootstrapState*)commState; if (state->unexpectedConnections != NULL) { - WARN("Unexpected connections are not empty"); - return ncclInternalError; + unexpectedFree(state); + if (*state->abortFlag == 0) { + WARN("Unexpected connections are not empty"); + return ncclInternalError; + } } - if (state->listenSock.fd >= 0) close(state->listenSock.fd); - if (state->ringSendSocket.fd >= 0) close(state->ringSendSocket.fd); - if (state->ringRecvSocket.fd >= 0) close(state->ringRecvSocket.fd); + + NCCLCHECK(ncclSocketClose(&state->listenSock)); + NCCLCHECK(ncclSocketClose(&state->ringSendSocket)); + NCCLCHECK(ncclSocketClose(&state->ringRecvSocket)); free(state->peerCommAddresses); free(state); @@ -452,9 +497,9 @@ ncclResult_t bootstrapClose(void* commState) { ncclResult_t bootstrapAbort(void* commState) { struct bootstrapState* state = (struct bootstrapState*)commState; if (commState == NULL) return ncclSuccess; - if (state->listenSock.fd) close(state->listenSock.fd); - if (state->ringSendSocket.fd) close(state->ringSendSocket.fd); - if (state->ringRecvSocket.fd) close(state->ringRecvSocket.fd); + NCCLCHECK(ncclSocketClose(&state->listenSock)); + NCCLCHECK(ncclSocketClose(&state->ringSendSocket)); + NCCLCHECK(ncclSocketClose(&state->ringRecvSocket)); free(state->peerCommAddresses); free(state->peerProxyAddresses); free(state); diff --git a/src/collectives/all_gather.cc b/src/collectives/all_gather.cc index 266fd5a..97ec981 100644 --- a/src/collectives/all_gather.cc +++ b/src/collectives/all_gather.cc @@ -11,7 +11,13 @@ NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, size ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream); 
ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + // Just pass the size of one message and not the total bytes sent/received. + constexpr nvtxPayloadSchemaEntry_t AllGatherSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes]"} + }; + size_t msgsize = sendcount * ncclTypeSize(datatype); + NVTX3_FUNC_WITH_PARAMS(AllGather, AllGatherSchema, msgsize) + struct ncclInfo info = { ncclFuncAllGather, "AllGather", sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */ ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS }; diff --git a/src/collectives/all_reduce.cc b/src/collectives/all_reduce.cc index b67f3be..8ac61a2 100644 --- a/src/collectives/all_reduce.cc +++ b/src/collectives/all_reduce.cc @@ -5,12 +5,25 @@ ************************************************************************/ #include "enqueue.h" +#include "nccl.h" NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream); ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + struct NvtxParamsAllReduce { + size_t bytes; + ncclRedOp_t op; + }; + // Just pass the size of one message and not the total bytes sent/received. 
+ static constexpr nvtxPayloadSchemaEntry_t AllReduceSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes]"}, + {0, NVTX_PAYLOAD_ENTRY_NCCL_REDOP, "Reduction operation", nullptr, 0, + offsetof(NvtxParamsAllReduce, op)} + }; + NvtxParamsAllReduce payload{count * ncclTypeSize(datatype), op}; + NVTX3_FUNC_WITH_PARAMS(AllReduce, AllReduceSchema, payload) + struct ncclInfo info = { ncclFuncAllReduce, "AllReduce", sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */ ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS }; diff --git a/src/collectives/broadcast.cc b/src/collectives/broadcast.cc index db0fb49..c73502e 100644 --- a/src/collectives/broadcast.cc +++ b/src/collectives/broadcast.cc @@ -11,7 +11,17 @@ NCCL_API(ncclResult_t, ncclBroadcast, const void* sendbuff, void* recvbuff, size ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root, ncclComm_t comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + struct NvtxParamsBroadcast { + size_t bytes; + int root; + }; + constexpr nvtxPayloadSchemaEntry_t BroadcastSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Bytes"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Root", nullptr, 0, offsetof(NvtxParamsBroadcast, root)} + }; + NvtxParamsBroadcast payload{count * ncclTypeSize(datatype), root}; + NVTX3_FUNC_WITH_PARAMS(Broadcast, BroadcastSchema, payload) + struct ncclInfo info = { ncclFuncBroadcast, "Broadcast", sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */ BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS }; diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 310938f..95cc990 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -37,6 +37,7 @@ struct ncclShmemData { }; uint64_t redOpArgs[NCCL_MAX_DIRECT_ARITY+1]; int channelId; + int aborted; alignas(16) struct ncclDevComm comm; alignas(16) struct ncclDevChannel 
channel; alignas(16) struct ncclWork work; @@ -135,6 +136,8 @@ __device__ void ncclKernel( } __syncthreads(); // publish ncclShmem.channelId int channelId = ncclShmem.channelId; + /* set abort flag to 0 */ + if (tid == 0) ncclShmem.aborted = 0; if (true) { void *dst, *src; diff --git a/src/collectives/device/prims_simple.h b/src/collectives/device/prims_simple.h index a727849..9d2d19a 100644 --- a/src/collectives/device/prims_simple.h +++ b/src/collectives/device/prims_simple.h @@ -62,7 +62,10 @@ class Primitives< inline __device__ bool checkAbort(int &spins) { spins++; if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) { - flags |= *ncclShmem.comm.abortFlag ? Aborted : 0; + if (*ncclShmem.comm.abortFlag) { + flags |= Aborted; + ncclShmem.aborted = 1; + } spins = 0; } return flags & Aborted; @@ -176,6 +179,9 @@ class Primitives< ncclShmem.groups[group].dsts[0] = userBuff + dstIx + offset; waitPeer(dstIx, remoteIx, offset, sliceSize); subBarrier(); + /* if user abort the kernel, we don't need to actually perform copy/reduce; just set size + * to 0 to avoid unnecessary workload. */ + size_t workSize = ncclShmem.aborted ? 0 : sliceSize; if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) { // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy if (Send) { @@ -184,7 +190,7 @@ class Primitives< (tid, nworkers, nullptr, false, 1, (T const**)ncclShmem.groups[group].srcs, fan.nsend(), (T**)ncclShmem.groups[group].dsts+1, - sliceSize); + workSize); } } else if (DirectSend && !DirectRecv && SrcBuf != Input && ncclShmem.groups[group].dsts[Dst] == nullptr) { // For broadcast in CollNet to do empty send @@ -192,7 +198,7 @@ class Primitives< (tid, nworkers, ncclShmem.redOpArgs, postOp, Recv, (T const**)ncclShmem.groups[group].srcs, Dst, (T**)ncclShmem.groups[group].dsts, - sliceSize); + workSize); } else { constexpr int PreOpN = SrcBuf != Input ? 0 : DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? 
(1+NCCL_MAX_DIRECT_ARITY) : 1; @@ -200,7 +206,7 @@ class Primitives< (tid, nworkers, ncclShmem.redOpArgs, postOp, Recv*fan.nrecv()+Src, (T const**)ncclShmem.groups[group].srcs, Send*fan.nsend()+Dst, (T**)ncclShmem.groups[group].dsts, - sliceSize); + workSize); } barrier(); // This barrier has a counterpart in following loop if (Send && (flags & RolePostSend) && index == 0) __threadfence_system(); diff --git a/src/collectives/reduce.cc b/src/collectives/reduce.cc index 86388df..6335516 100644 --- a/src/collectives/reduce.cc +++ b/src/collectives/reduce.cc @@ -6,12 +6,26 @@ #include "enqueue.h" #include "collectives.h" +#include "nccl.h" NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + struct NvtxParamsReduce { + size_t bytes; + int root; + ncclRedOp_t op; + }; + constexpr nvtxPayloadSchemaEntry_t ReduceSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes]"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Root", nullptr, 0, offsetof(NvtxParamsReduce, root)}, + {0, NVTX_PAYLOAD_ENTRY_NCCL_REDOP, "Reduction operation", nullptr, 0, + offsetof(NvtxParamsReduce, op)} + }; + NvtxParamsReduce payload{count * ncclTypeSize(datatype), root, op}; + NVTX3_FUNC_WITH_PARAMS(Reduce, ReduceSchema, payload) + struct ncclInfo info = { ncclFuncReduce, "Reduce", sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */ REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS }; diff --git a/src/collectives/reduce_scatter.cc b/src/collectives/reduce_scatter.cc index 57c67bf..5242545 100644 --- a/src/collectives/reduce_scatter.cc +++ b/src/collectives/reduce_scatter.cc @@ -6,12 +6,24 @@ #include "enqueue.h" #include "collectives.h" +#include "nccl.h" 
NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream); ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + struct NvtxParamsReduceScatter { + size_t bytes; + ncclRedOp_t op; + }; + constexpr nvtxPayloadSchemaEntry_t ReduceScatterSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes]"}, + {0, NVTX_PAYLOAD_ENTRY_NCCL_REDOP, "Reduction operation", nullptr, 0, + offsetof(NvtxParamsReduceScatter, op)} + }; + NvtxParamsReduceScatter payload{recvcount * ncclTypeSize(datatype), op}; + NVTX3_FUNC_WITH_PARAMS(ReduceScatter, ReduceScatterSchema, payload) + struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter", sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */ REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS }; diff --git a/src/collectives/sendrecv.cc b/src/collectives/sendrecv.cc index 0e9ca4f..9a81b0a 100644 --- a/src/collectives/sendrecv.cc +++ b/src/collectives/sendrecv.cc @@ -8,11 +8,22 @@ #include "collectives.h" #include "argcheck.h" // Need some checks here since we access comm +struct NvtxParamsSendRecv { + size_t bytes; + int peer; +}; +constexpr const nvtxPayloadSchemaEntry_t SendRecvSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Bytes"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Peer rank", nullptr, 0, offsetof(NvtxParamsSendRecv, peer)} +}; + NCCL_API(ncclResult_t, ncclSend, const void* sendbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer}; + NVTX3_FUNC_WITH_PARAMS(Send, 
SendRecvSchema, payload) + struct ncclInfo info = { ncclFuncSend, "Send", NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */ 1, 1 }; @@ -27,7 +38,9 @@ NCCL_API(ncclResult_t, ncclRecv, void* recvbuff, size_t count, ncclDataType_t da ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer}; + NVTX3_FUNC_WITH_PARAMS(Recv, SendRecvSchema, payload) + struct ncclInfo info = { ncclFuncRecv, "Recv", NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */ 1, 1 }; diff --git a/src/enqueue.cc b/src/enqueue.cc index 8bac73f..0744e09 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -609,8 +609,7 @@ static ncclResult_t scheduleP2pTasksToPlan( // Compute how much to split operations // Natural step size matching buffer steps. - ssize_t stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS; - if (comm->nNodes > 1) stepSize = comm->p2pNetChunkSize; + ssize_t stepSize = comm->p2pChunkSize; // Try to use all channels int nChannelsMax = comm->p2pnChannelsPerPeer; int nChannelsMin = nChannelsMax; @@ -1008,7 +1007,13 @@ ncclResult_t ncclLaunchKernelBefore_NoUncapturedCuda(struct ncclComm* comm, stru #if CUDART_VERSION >= 11080 #define NCCL_MAX_CGA_CLUSTER_SIZE 8 -NCCL_PARAM(CGAClusterSize, "CGA_CLUSTER_SIZE", 0); +#define NCCL_CGA_CLUSTER_SIZE_SM90 4 +NCCL_PARAM(CGAClusterSize, "CGA_CLUSTER_SIZE", -2); +#endif + +#if CUDART_VERSION >= 12000 +// NCCL uses the "Remote" Mem Sync domain by default +NCCL_PARAM(MemSyncDomain, "MEM_SYNC_DOMAIN", cudaLaunchMemSyncDomainRemote); #endif ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan) { @@ -1022,22 +1027,25 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan #if CUDART_VERSION >= 11080 int driverVersion; 
NCCLCHECK(ncclCudaDriverVersion(&driverVersion)); - - unsigned int clusterSize = 0; - clusterSize = ncclParamCGAClusterSize(); - if (clusterSize > NCCL_MAX_CGA_CLUSTER_SIZE) { - static bool warned = false; - if (warned == false) { - WARN("NCCL_CGA_CLUSTER_SIZE value %d is too big. Limiting value to %d.", - clusterSize, NCCL_MAX_CGA_CLUSTER_SIZE); - warned = true; + if (driverVersion >= 11080) { + int compCap = comm->compCap; + unsigned int clusterSize = (compCap == 90) ? NCCL_CGA_CLUSTER_SIZE_SM90 : 0; + if (ncclParamCGAClusterSize() != -2) { + clusterSize = ncclParamCGAClusterSize(); + if (clusterSize > NCCL_MAX_CGA_CLUSTER_SIZE) { + static bool warned = false; + if (warned == false) { + WARN("NCCL_CGA_CLUSTER_SIZE value %d is too big. Limiting value to %d.", + clusterSize, NCCL_MAX_CGA_CLUSTER_SIZE); + warned = true; + } + clusterSize = NCCL_MAX_CGA_CLUSTER_SIZE; + } } - clusterSize = NCCL_MAX_CGA_CLUSTER_SIZE; - } - if (clusterSize && driverVersion >= 11080) { cudaLaunchConfig_t launchConfig = {0}; - cudaLaunchAttribute launchAttrs[2]; + cudaLaunchAttribute launchAttrs[3]; + int attrs = 0; /* Cooperative Group Array (CGA) * On sm90 and later we have an extra level of hierarchy where we * can group together several blocks within the Grid, called @@ -1048,17 +1056,25 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan * concurrently scheduled onto a group of SMs. 
* The maximum value is 8 and it must be divisible into the grid dimensions */ - // Grid dimension must be divisible by clusterSize - if (grid.x % clusterSize) clusterSize = 1; - launchAttrs[0].id = cudaLaunchAttributeClusterDimension; - launchAttrs[0].val.clusterDim = {clusterSize, 1, 1}; - launchAttrs[1].id = cudaLaunchAttributeClusterSchedulingPolicyPreference; - launchAttrs[1].val.clusterSchedulingPolicyPreference = cudaClusterSchedulingPolicySpread; - + if (clusterSize) { + // Grid dimension must be divisible by clusterSize + if (grid.x % clusterSize) clusterSize = 1; + launchAttrs[attrs].id = cudaLaunchAttributeClusterDimension; + launchAttrs[attrs++].val.clusterDim = {clusterSize, 1, 1}; + launchAttrs[attrs].id = cudaLaunchAttributeClusterSchedulingPolicyPreference; + launchAttrs[attrs++].val.clusterSchedulingPolicyPreference = cudaClusterSchedulingPolicySpread; + } + #if CUDART_VERSION >= 12000 + if (compCap >= 90 && driverVersion >= 12000) { + // Set the NCCL Mem Sync domain on CUDA 12.0 and later (sm90) + launchAttrs[attrs].id = cudaLaunchAttributeMemSyncDomain; + launchAttrs[attrs++].val.memSyncDomain = (cudaLaunchMemSyncDomain) ncclParamMemSyncDomain(); + } + #endif launchConfig.gridDim = grid; launchConfig.blockDim = block; launchConfig.attrs = launchAttrs; - launchConfig.numAttrs = sizeof(launchAttrs)/sizeof(launchAttrs[0]); + launchConfig.numAttrs = attrs; launchConfig.stream = launchStream; CUDACHECK(cudaLaunchKernelExC(&launchConfig, fn, args)); @@ -1093,14 +1109,18 @@ ncclResult_t ncclLaunchFinish(struct ncclComm* comm) { // back to us for reclaiming via callbackQueue. ncclIntruQueueConstruct(&comm->planQueue); cudaStream_t launchStream = tasks->streams->stream; // First user stream gets launch - // Create dependency for deviceStream on launchStream. - NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, &comm->deviceStream, launchStream), result, resume1); + // Create dependency for deviceStream on launchStream. 
We know that deviceStream + // hasn't been modified since launchStream waited on it (in ncclLaunchPrepare), + // so we can say that launchStream subsumes it. + NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, &comm->deviceStream, launchStream, /*b_subsumes_a=*/true), result, resume1); resume1: - // Create dependency for other user streams (skip launch stream). + // Create dependency for other user streams (skip launch stream) on deviceStream. + // Again, the user streams haven't been touched since deviceStream waited on them + // so we can say they are subsumed by deviceStream. struct ncclCudaStreamList* sl = tasks->streams->next; tasks->streams = nullptr; // Reset comm->tasks.streams to empty. while (sl != nullptr) { - NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, sl->stream, &comm->deviceStream), result, resume2); + NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, sl->stream, &comm->deviceStream, /*b_subsumes_a=*/true), result, resume2); resume2: sl = sl->next; } diff --git a/src/graph/paths.cc b/src/graph/paths.cc index 01f1582..7134b90 100644 --- a/src/graph/paths.cc +++ b/src/graph/paths.cc @@ -754,3 +754,15 @@ ncclResult_t ncclTopoGetNvbGpus(struct ncclTopoSystem* system, int rank, int* nr *nranks = nvbGpus; return ncclSuccess; } + +int ncclTopoPathAllNVLink(struct ncclTopoSystem* system) { + int minPath = PATH_DIS; + for (int i=0; inodes[GPU].count; i++) { + struct ncclTopoLinkList* paths = system->nodes[GPU].nodes[i].paths[GPU]; + for (int j=0; jnodes[GPU].count; j++) { + if (i == j) continue; + minPath = std::min(minPath, paths[j].type); + } + } + return minPath >= PATH_PIX ? 
0 : 1; +} diff --git a/src/graph/topo.cc b/src/graph/topo.cc index 9e4c978..d91aa63 100644 --- a/src/graph/topo.cc +++ b/src/graph/topo.cc @@ -72,6 +72,9 @@ static ncclResult_t ncclTopoGetInterCpuBw(struct ncclTopoNode* cpu, float* bw) { if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_X86 && cpu->cpu.vendor == NCCL_TOPO_CPU_VENDOR_INTEL) { *bw = cpu->cpu.model == NCCL_TOPO_CPU_TYPE_SKL ? SKL_QPI_BW : QPI_BW; } + if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_X86 && cpu->cpu.vendor == NCCL_TOPO_CPU_VENDOR_AMD) { + *bw = AMD_BW; + } if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_X86 && cpu->cpu.vendor == NCCL_TOPO_CPU_VENDOR_ZHAOXIN) { *bw = cpu->cpu.model == NCCL_TOPO_CPU_TYPE_YONGFENG ? YONGFENG_ZPI_BW : ZPI_BW; } diff --git a/src/graph/topo.h b/src/graph/topo.h index 20a3e9d..1a1a04c 100644 --- a/src/graph/topo.h +++ b/src/graph/topo.h @@ -18,6 +18,7 @@ #define PCI_BW 12.0 // PCI Gen3 x16 #define QPI_BW 6.0 #define SKL_QPI_BW 9.0 +#define AMD_BW 16.0 #define ZPI_BW 6.0 #define YONGFENG_ZPI_BW 9.0 #define P9_BW 32.0 diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h index a787c0b..e70db04 100644 --- a/src/include/bootstrap.h +++ b/src/include/bootstrap.h @@ -10,10 +10,16 @@ #include "nccl.h" #include "comm.h" +struct ncclBootstrapHandle { + uint64_t magic; + union ncclSocketAddress addr; +}; +static_assert(sizeof(struct ncclBootstrapHandle) <= sizeof(ncclUniqueId), "Bootstrap handle is too large to fit inside NCCL unique ID"); + ncclResult_t bootstrapNetInit(); -ncclResult_t bootstrapCreateRoot(ncclUniqueId* commId, bool idFromEnv); -ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out); -ncclResult_t bootstrapInit(ncclUniqueId* id, struct ncclComm* comm); +ncclResult_t bootstrapCreateRoot(struct ncclBootstrapHandle* handle, bool idFromEnv); +ncclResult_t bootstrapGetUniqueId(struct ncclBootstrapHandle* handle); +ncclResult_t bootstrapInit(struct ncclBootstrapHandle* handle, struct ncclComm* comm); ncclResult_t bootstrapAllGather(void* commState, void* allData, int size); 
ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size); ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size); diff --git a/src/include/comm.h b/src/include/comm.h index 16e95b3..655292a 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -171,9 +171,12 @@ struct ncclComm { uint64_t* connectSend; uint64_t* connectRecv; + uint64_t magic; // Magic number for all network communication. Not a security key -- only goal is to detect mismatches. + int rank; // my rank in the communicator int nRanks; // number of GPUs in communicator int cudaDev; // my cuda device index + int compCap; // compute capability of the GPU int64_t busId; // my PCI bus ID in int format cpu_set_t cpuAffinity; // CPU affinity of the GPU @@ -208,7 +211,7 @@ struct ncclComm { // Buffer sizes int buffSizes[NCCL_NUM_PROTOCOLS]; - int p2pNetChunkSize; + int p2pChunkSize; // Algorithm/Protocols thresholds ssize_t threadThresholds[NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS]; @@ -240,7 +243,6 @@ struct ncclComm { // Intra-process sync struct ncclComm* intraComm0; // leader of intra-process comms (self possible) struct ncclComm* intraNext; // next of intra-process comms, intraComm0 is head - int intraRefs; // reference count from intra-process comms (zero if not leader else intraRanks) int intraRank; int intraRanks; uint32_t intraBarrierPhase; diff --git a/src/include/graph.h b/src/include/graph.h index 91e85e7..62323a2 100644 --- a/src/include/graph.h +++ b/src/include/graph.h @@ -28,6 +28,7 @@ void ncclTopoFree(struct ncclTopoSystem* system); ncclResult_t ncclTopoTrimSystem(struct ncclTopoSystem* system, struct ncclComm* comm); ncclResult_t ncclTopoComputeP2pChannels(struct ncclComm* comm); ncclResult_t ncclTopoGetNvbGpus(struct ncclTopoSystem* system, int rank, int* nranks, int** ranks); +int ncclTopoPathAllNVLink(struct ncclTopoSystem* system); // Query topology ncclResult_t ncclTopoGetNetDev(struct ncclComm* comm, int rank, struct 
ncclTopoGraph* graph, int channelId, int peerRank, int* net, int* proxyRank); diff --git a/src/include/nvtx.h b/src/include/nvtx.h index 7796126..2aeb932 100644 --- a/src/include/nvtx.h +++ b/src/include/nvtx.h @@ -9,6 +9,77 @@ #include "nvtx3.hpp" +#if __cpp_constexpr >= 201304L && !defined(NVTX3_RELAXED_CONSTEXPR) +#define NVTX3_RELAXED_CONSTEXPR constexpr +#else +#define NVTX3_RELAXED_CONSTEXPR +#endif + +// Define all NCCL-provided static schema IDs here (avoid duplicates). +#define NVTX_SID_CommInitRank 0 +#define NVTX_SID_CommInitAll 1 +#define NVTX_SID_CommDestroy 2 // same schema as NVTX_SID_CommInitRank +#define NVTX_SID_CommAbort 3 // same schema as NVTX_SID_CommInitRank +#define NVTX_SID_AllGather 4 +#define NVTX_SID_AllReduce 5 +#define NVTX_SID_Broadcast 6 +#define NVTX_SID_ReduceScatter 7 +#define NVTX_SID_Reduce 8 +#define NVTX_SID_Send 9 +#define NVTX_SID_Recv 10 + +// Define static schema ID for the reduction operation. +#define NVTX_PAYLOAD_ENTRY_NCCL_REDOP 11 + NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + +extern const nvtxDomainHandle_t ncclNvtxDomainHandle; + struct nccl_domain{static constexpr char const* name{"NCCL"};}; +class payload_schema { + public: + NVTX3_RELAXED_CONSTEXPR explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept + { + schema_attr.name = schemaName; + schema_attr.entries = entries; + schema_attr.numEntries = numEntries; + schema_attr.schemaId = schemaId; + nvtxPayloadSchemaRegister(nvtx3::domain::get(), &schema_attr); + } + + payload_schema() = delete; + ~payload_schema() = default; + payload_schema(payload_schema const&) = default; + payload_schema& operator=(payload_schema const&) = default; + payload_schema(payload_schema&&) = default; + payload_schema& operator=(payload_schema&&) = default; + + private: + nvtxPayloadSchemaAttr_t schema_attr{ + NVTX_PAYLOAD_SCHEMA_ATTR_TYPE | + NVTX_PAYLOAD_SCHEMA_ATTR_ENTRIES | + 
NVTX_PAYLOAD_SCHEMA_ATTR_NUM_ENTRIES | + NVTX_PAYLOAD_SCHEMA_ATTR_STATIC_SIZE | + NVTX_PAYLOAD_SCHEMA_ATTR_SCHEMA_ID, + nullptr, + NVTX_PAYLOAD_SCHEMA_TYPE_STATIC, + NVTX_PAYLOAD_SCHEMA_FLAG_NONE, + nullptr, 0, 0, 0}; +}; + +// Create NVTX push/pop range with parameters +// @param name of the operation (see `NVTX_SID_*`) +// @param N schema name +// @param S schema (entries) +// @param P payload (struct) +#define NVTX3_FUNC_WITH_PARAMS(ID, S, P) \ + static const payload_schema schema{S, std::extent::value, \ + NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, #ID}; \ + static ::nvtx3::v1::registered_string const nvtx3_func_name__{__func__}; \ + nvtxPayloadData_t nvtx3_bpl__[] = { \ + {NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, sizeof(P), &(P)}}; \ + ::nvtx3::v1::event_attributes nvtx3_func_attr__{nvtx3_func_name__, nvtx3_bpl__}; \ + ::nvtx3::v1::domain_thread_range const nvtx3_range__{nvtx3_func_attr__}; + +extern void initNvtxRegisteredEnums(); + #endif diff --git a/src/include/nvtx3.hpp b/src/include/nvtx3.hpp index 1e99373..353fddf 100644 --- a/src/include/nvtx3.hpp +++ b/src/include/nvtx3.hpp @@ -92,6 +92,7 @@ /* clang-format on */ #include +#include #include #include @@ -1732,6 +1733,22 @@ class event_attributes { attributes_.messageType = m.get_type(); } + /** + * @brief Variadic constructor where the first argument is a binary payload. + * + * Sets the value of the `EventAttribute`s message based on `m` and forwards + * the remaining variadic parameter pack to the next constructor. + * + */ + template + NVTX3_RELAXED_CONSTEXPR explicit event_attributes(nvtxPayloadData_t const* bpl, Args const&... args) noexcept + : event_attributes(args...) + { + attributes_.payloadType = NVTX_PAYLOAD_TYPE_BINARY; + attributes_.reserved0 = 1; // NCCL uses only a single binary payload per event. 
+ attributes_.payload.ullValue = NVTX_POINTER_AS_PAYLOAD_ULLVALUE(bpl); + } + ~event_attributes() = default; event_attributes(event_attributes const&) = default; event_attributes& operator=(event_attributes const&) = default; diff --git a/src/include/nvtx3/nvToolsExtPayload.h b/src/include/nvtx3/nvToolsExtPayload.h new file mode 100644 index 0000000..1683f92 --- /dev/null +++ b/src/include/nvtx3/nvToolsExtPayload.h @@ -0,0 +1,776 @@ +/* +* Copyright 2021 NVIDIA Corporation. All rights reserved. +* +* Licensed under the Apache License v2.0 with LLVM Exceptions. +* See https://llvm.org/LICENSE.txt for license information. +* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +*/ + +#include "nvtx3/nvToolsExt.h" + +#ifndef NVTOOLSEXT_PAYLOAD_H +#define NVTOOLSEXT_PAYLOAD_H + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/** + * \brief A compatibility ID value used in initialization to identify version + * differences. + */ +#define NVTX_EXT_COMPATID_PAYLOAD 0x0103 + +/** + * \brief This module ID identifies the payload extension. It has to be unique + * among the extension modules. + */ +#define NVTX_EXT_MODULEID_PAYLOAD 2 + +/** + * \brief Additional values for the enum @ref nvtxPayloadType_t + */ +#define NVTX_PAYLOAD_TYPE_BINARY ((int32_t)0xDFBD0009) + + +/** --------------------------------------------------------------------------- + * Payload schema entry flags. + * ------------------------------------------------------------------------- */ +#define NVTX_PAYLOAD_ENTRY_FLAG_UNUSED 0 + +/** + * Absolute pointer into a payload (entry) of the same event. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_POINTER (1 << 1) + +/** + * Offset from base address of the payload. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_OFFSET_FROM_BASE (1 << 2) + +/** + * Offset from the end of this payload entry. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_OFFSET_FROM_HERE (1 << 3) + +/** + * The value is an array with fixed length, set with the field `arrayLength`. 
+ */ +#define NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_FIXED_SIZE (1 << 4) + +/** + * The value is a zero-/null-terminated array. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_ZERO_TERMINATED (2 << 4) + +/** + * \brief A single or multi-dimensional array of variable length. + * + * The field `arrayLength` contains the index of the schema entry that holds the + * length(s). If the other field points to a scalar entry then this will be the + * 1D array. If the other field points to a FIXED_SIZE array, then the number of + * dimensions is defined with the registration of the scheme. If the other field + * is ZERO_TERMINATED, the array the dimensions can be determined at runtime. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_LENGTH_INDEX (3 << 4) + +/** + * A tool may not support deep copy and just ignore this flag. + * See @ref NVTX_PAYLOAD_SCHEMA_FLAG_DEEP_COPY for more details. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_DEEP_COPY (1 << 9) + +/** + * The entry specifies the message in a deferred event. The entry type can be + * any string type. The flag is ignored for schemas that are not flagged with + * `NVTX_PAYLOAD_SCHEMA_FLAG_RANGE*` or `NVTX_PAYLOAD_SCHEMA_FLAG_MARK`. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_EVENT_MESSAGE (1 << 10) + +/** + * @note The ‘array’ flags assume that the array is embedded. Otherwise, + * @ref NVTX_PAYLOAD_ENTRY_FLAG_POINTER has to be additionally specified. Some + * combinations may be invalid based on the `NVTX_PAYLOAD_SCHEMA_TYPE_*` this + * entry is enclosed. For instance, variable length embedded arrays are valid + * within @ref NVTX_PAYLOAD_SCHEMA_TYPE_DYNAMIC but invalid with + * @ref NVTX_PAYLOAD_SCHEMA_TYPE_STATIC. See `NVTX_PAYLOAD_SCHEMA_TYPE_*` for + * additional details. + */ + +/* Helper macro to check if an entry represents an array. 
*/ +#define NVTX_PAYLOAD_ENTRY_FLAG_IS_ARRAY (\ + NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_FIXED_SIZE | \ + NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_ZERO_TERMINATED | \ + NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_LENGTH_INDEX) + +/** --------------------------------------------------------------------------- + * Types of entries in a payload schema. + * ------------------------------------------------------------------------- */ + +/** + * @note Several of the predefined types contain the size (in bits) in their + * names. For some data types the size (in bytes) is not fixed and may differ + * for different platforms/operating systems/compilers. To provide portability, + * an array of sizes (in bytes) for type 1 to 28 ( @ref + * NVTX_PAYLOAD_ENTRY_TYPE_CHAR to @ref NVTX_PAYLOAD_ENTRY_TYPE_INFO_ARRAY_SIZE) + * is passed to the NVTX extension initialization function + * @ref InitializeInjectionNvtxExtension via the `extInfo` field of + * @ref nvtxExtModuleInfo_t. + */ + +#define NVTX_PAYLOAD_ENTRY_TYPE_INVALID 0 + +/** + * Basic integer types. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_CHAR 1 +#define NVTX_PAYLOAD_ENTRY_TYPE_UCHAR 2 +#define NVTX_PAYLOAD_ENTRY_TYPE_SHORT 3 +#define NVTX_PAYLOAD_ENTRY_TYPE_USHORT 4 +#define NVTX_PAYLOAD_ENTRY_TYPE_INT 5 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT 6 +#define NVTX_PAYLOAD_ENTRY_TYPE_LONG 7 +#define NVTX_PAYLOAD_ENTRY_TYPE_ULONG 8 +#define NVTX_PAYLOAD_ENTRY_TYPE_LONGLONG 9 +#define NVTX_PAYLOAD_ENTRY_TYPE_ULONGLONG 10 + +/** + * Integer types with explicit size. 
+ */ +#define NVTX_PAYLOAD_ENTRY_TYPE_INT8 11 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT8 12 +#define NVTX_PAYLOAD_ENTRY_TYPE_INT16 13 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT16 14 +#define NVTX_PAYLOAD_ENTRY_TYPE_INT32 15 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT32 16 +#define NVTX_PAYLOAD_ENTRY_TYPE_INT64 17 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT64 18 + +/** + * C floating point types + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_FLOAT 19 +#define NVTX_PAYLOAD_ENTRY_TYPE_DOUBLE 20 +#define NVTX_PAYLOAD_ENTRY_TYPE_LONGDOUBLE 21 + +/** + * Size type (`size_t`) + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_SIZE 22 + +/** + * Any address, e.g. `void*`. If the pointer type matters, use the flag @ref + * NVTX_PAYLOAD_ENTRY_FLAG_POINTER and the respective type instead. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_ADDRESS 23 + +/** + * Special character types. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_WCHAR 24 /* wide character (since C90) */ +#define NVTX_PAYLOAD_ENTRY_TYPE_CHAR8 25 /* since C2x and C++20 */ +#define NVTX_PAYLOAD_ENTRY_TYPE_CHAR16 26 +#define NVTX_PAYLOAD_ENTRY_TYPE_CHAR32 27 + +/** + * There is type size and alignment information for all previous types. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_INFO_ARRAY_SIZE (NVTX_PAYLOAD_ENTRY_TYPE_CHAR32 + 1) + +/** + * Store raw 8-bit binary data. As with `char`, 1-byte alignment is assumed. + * Typically a tool will display this as hex or binary. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_BYTE 32 + +/** + * These types do not have standardized equivalents. It is assumed that the + * number at the end corresponds to the bits used to store the value and that + * the alignment corresponds to standardized types of the same size. + * A tool may not support these types. 
+ */ +#define NVTX_PAYLOAD_ENTRY_TYPE_INT128 33 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT128 34 + +#define NVTX_PAYLOAD_ENTRY_TYPE_FLOAT16 42 +#define NVTX_PAYLOAD_ENTRY_TYPE_FLOAT32 43 +#define NVTX_PAYLOAD_ENTRY_TYPE_FLOAT64 44 +#define NVTX_PAYLOAD_ENTRY_TYPE_FLOAT128 45 + +#define NVTX_PAYLOAD_ENTRY_TYPE_BF16 50 +#define NVTX_PAYLOAD_ENTRY_TYPE_TF32 52 + +/** + * These types are normalized numbers stored in integers. UNORMs represent 0.0 + * to 1.0 and SNORMs represent -1.0 to 1.0. The number after represents the + * number of integer bits. Alignment is take from equivalent types INT# matching + * to SNORM# and UINT# matching to UNORM#. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_SNORM8 61 +#define NVTX_PAYLOAD_ENTRY_TYPE_UNORM8 62 +#define NVTX_PAYLOAD_ENTRY_TYPE_SNORM16 63 +#define NVTX_PAYLOAD_ENTRY_TYPE_UNORM16 64 +#define NVTX_PAYLOAD_ENTRY_TYPE_SNORM32 65 +#define NVTX_PAYLOAD_ENTRY_TYPE_UNORM32 66 +#define NVTX_PAYLOAD_ENTRY_TYPE_SNORM64 67 +#define NVTX_PAYLOAD_ENTRY_TYPE_UNORM64 68 + +/** + * String types. + * + * If `arrayOrUnionDetail` is greater than `0`, the entry is a fixed-size string + * with the provided length. + * + * `NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_FIXED_SIZE` is ignored for string types. It + * just specifies once more that the entry is a fixed-size string. + * + * Setting the flag `NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_ZERO_TERMINATED` indicates a + * zero-terminated string. If `arrayOrUnionDetail` is greater than `0`, a zero- + * terminated array of fixed-size strings is assumed. + * + * Setting the flag `NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_LENGTH_INDEX` specifies the + * entry index of the entry which contains the string length. It is not possible + * to describe a variable length array of strings. 
+ */ +#define NVTX_PAYLOAD_ENTRY_TYPE_CSTRING 75 /* `char*`, system LOCALE */ +#define NVTX_PAYLOAD_ENTRY_TYPE_CSTRING_UTF8 76 +#define NVTX_PAYLOAD_ENTRY_TYPE_CSTRING_UTF16 77 +#define NVTX_PAYLOAD_ENTRY_TYPE_CSTRING_UTF32 78 + +/** + * @ref nvtxStringHandle_t returned by @ref nvtxDomainRegisterString + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE 80 + +/** + * Entry types to be used in deferred events. Data types are as defined by + * NVTXv3 core: category -> uint32_t, color -> uint32_t, color type -> int32_t. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_NVTX_CATEGORY 90 +#define NVTX_PAYLOAD_ENTRY_TYPE_NVTX_COLORTYPE 91 +#define NVTX_PAYLOAD_ENTRY_TYPE_NVTX_COLOR 92 + +/** + * This type marks the union selector member (entry index) in schemas used by + * a union with internal internal selector. + * See @ref NVTX_PAYLOAD_SCHEMA_TYPE_UNION_WITH_INTERNAL_SELECTOR. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_UNION_SELECTOR 100 + +/** + * Timestamp types occupy the range from 128 to 255 + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP64 128 /* data type is uint64_t */ + +/** + * CPU timestamp sources. + * \todo All 64 bits? 
+ */ +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_TSC 129 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_TSC_NONVIRTUALIZED 130 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_REALTIME 131 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_REALTIME_COARSE 132 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_MONOTONIC 133 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_MONOTONIC_RAW 134 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_MONOTONIC_COARSE 135 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_BOOTTIME 136 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_PROCESS_CPUTIME_ID 137 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_THREAD_CPUTIME_ID 138 + +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_WIN_QPC 160 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_WIN_GSTAFT 161 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_WIN_GSTAFTP 162 + +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_C_TIME 163 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_C_CLOCK 164 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_C_TIMESPEC_GET 165 + +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_STEADY_CLOCK 166 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_HIGH_RESOLUTION_CLOCK 167 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_SYSTEM_CLOCK 168 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_UTC_CLOCK 169 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_TAI_CLOCK 170 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_GPS_CLOCK 171 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_FILE_CLOCK 172 + +/** + * \brief GPU timestamp sources. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_GPU_GLOBALTIMER 192 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_GPU_SM_CLOCK 193 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_GPU_SM_CLOCK64 194 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_GPU_CUPTI 195 + +/** + * The timestamp was provided by the NVTX handler’s timestamp routine. 
+ */ +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_TOOL_PROVIDED 224 + +/** + * This predefined schema ID can be used in `nvtxPayloadData_t` to indicate that + * the payload is a blob of memory which other payload entries may point into. + * A tool will not expose this payload directly. + */ +#define NVTX_TYPE_PAYLOAD_SCHEMA_REFERENCED 1022 + +/** + * This predefined schema ID can be used in `nvtxPayloadData_t` to indicate that + * the payload is a blob which can be shown with an arbitrary data viewer. + */ +#define NVTX_TYPE_PAYLOAD_SCHEMA_RAW 1023 + +/* Custom (static) schema IDs. */ +#define NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START (1 << 24) + +/* Dynamic schema IDs (generated by the tool) start here. */ +#define NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_DYNAMIC_START 4294967296 // 1 << 32 + + +/** + * \brief Size and alignment information for predefined payload entry types. + * + * The struct contains the size and the alignment size in bytes. A respective + * array for the predefined types is passed via nvtxExtModuleInfo_t to the NVTX + * client/handler. The type (ID) is used as index into this array. + */ +typedef struct nvtxPayloadEntryTypeInfo_t +{ + uint16_t size; + uint16_t align; +} nvtxPayloadEntryTypeInfo_t; + +/** + * \brief Entry in a schema. + * + * A payload schema consists of an array of payload schema entries. It is + * registered with @ref nvtxPayloadSchemaRegister. `flag` can be set to `0` for + * simple values, 'type' is the only "required" field. If not set explicitly, + * all other fields are zero-initialized, which means that the entry has no name + * and the offset is determined based on self-alignment rules. + * + * Example schema: + * nvtxPayloadSchemaEntry_t desc[] = { + * {0, NVTX_EXT_PAYLOAD_TYPE_UINT8, "one byte"}, + * {0, NVTX_EXT_PAYLOAD_TYPE_INT32, "four bytes"} + * }; + */ +typedef struct nvtxPayloadSchemaEntry_t +{ + /** + * \brief Flags to augment the basic type. 
+ * + * This field allows additional properties of the payload entry to be + * specified. Valid values are `NVTX_PAYLOAD_ENTRY_FLAG_*`. + */ + uint64_t flags; + + /** + * \brief Predefined payload schema entry type or ID of a registered payload + * schema. + */ + uint64_t type; + + /** + * \brief Name of the payload entry. (Optional) + * + * Providing a name is useful to give a meaning to the associated value. + */ + const char* name; + + /** + * \brief Description of the payload entry. (Optional) + */ + const char* description; + + /** + * \brief String or array length or union selector for union types. + * + * If @ref type is a C string type, this defines the length of the string. + * + * If @ref flags specify that the entry is an array, this field defines the + * length of the array. See `NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_*` for more + * details. + * + * If @ref type implies that the entry is a union with schema type + * @ref NVTX_PAYLOAD_SCHEMA_TYPE_UNION (external selection of the union + * member), this field contains the index (starting with 0) to an entry of + * integer type in the same schema. The associated field contains the + * selected union member. + * + * @note An array of schema type @ref NVTX_PAYLOAD_SCHEMA_TYPE_UNION is not + * supported. @ref NVTX_PAYLOAD_SCHEMA_TYPE_UNION_WITH_INTERNAL_SELECTOR can + * be used instead. + */ + uint64_t arrayOrUnionDetail; + + /** + * \brief Offset in the binary payload data (in bytes). + * + * This field specifies the byte offset from the base address of the actual + * binary data (blob) to the data of this entry. + * + * This is an optional field, but it is recommended to specify this field to + * avoid issues in the automatic detection of the offset by a tool/handler. + */ + uint64_t offset; + + /** + * Semantics are not yet defined. + */ + void* semantics; + + /** + * Reserved for future use. Do not use it! 
+ */ + void* reserved; +} nvtxPayloadSchemaEntry_t; + +/** + * \brief Binary payload data, size and decoding information. + * + * An array of nvtxPayloadData_t is passed to the NVTX event attribute payload + * member. To attach a single payload the macro @ref NVTX_EXT_PAYLOAD_SET_ATTR + * can be used. + */ +typedef struct nvtxPayloadData_t +{ + /** + * The schema ID, which defines the layout of the binary data. + */ + uint64_t schemaId; + + /** + * Size of the binary payload (blob) in bytes. + */ + size_t size; + + /** + * Pointer to the binary payload data. + */ + const void* payload; +} nvtxPayloadData_t; + +/* Helper macros for safe double-cast of pointer to uint64_t value */ +#ifndef NVTX_POINTER_AS_PAYLOAD_ULLVALUE +# ifdef __cplusplus +# define NVTX_POINTER_AS_PAYLOAD_ULLVALUE(p) \ + static_cast(reinterpret_cast(p)) +# else +#define NVTX_POINTER_AS_PAYLOAD_ULLVALUE(p) ((uint64_t)(uintptr_t)p) +# endif +#endif + + +#define NVTX_PAYLOAD_CONCAT2(a,b) a##b +#define NVTX_PAYLOAD_CONCAT(a,b) NVTX_PAYLOAD_CONCAT2(a,b) +#define NVTX_DATA_VAR NVTX_PAYLOAD_CONCAT(nvtxDFDB,__LINE__) + +/** + * \brief Helper macro to attach a single payload to an NVTX event attribute. + * + * @note The NVTX push, start or mark operation must not be in the same or a + * nested scope. + */ +#define NVTX_PAYLOAD_EVTATTR_SET(EVTATTR, SCHEMA_ID, PAYLOAD_ADDR, SIZE) \ + nvtxPayloadData_t NVTX_DATA_VAR[] = {{SCHEMA_ID, SIZE, PAYLOAD_ADDR}}; \ + (EVTATTR).payload.ullValue = \ + NVTX_POINTER_AS_PAYLOAD_ULLVALUE(NVTX_DATA_VAR); \ + (EVTATTR).payloadType = NVTX_PAYLOAD_TYPE_BINARY; \ + (EVTATTR).reserved0 = 1; + +/** + * \brief Helper macro to attach multiple payloads to an NVTX event attribute. + * + * The payload data array (`nvtxPayloadData_t`) is passed as first argument to + * this macro. 
+ */ +#define NVTX_PAYLOAD_EVTATTR_SET_MULTIPLE(EVTATTR, PAYLOADS) \ + (EVTATTR).payloadType = NVTX_PAYLOAD_TYPE_BINARY; \ + (EVTATTR).reserved0 = sizeof(PAYLOADS)/sizeof(nvtxPayloadData_t); \ + (EVTATTR).payload.ullValue = NVTX_POINTER_AS_PAYLOAD_ULLVALUE(PAYLOADS); + + +/** + * \brief The payload schema type. + * + * A schema can be either of these types. + */ +enum nvtxPayloadSchemaType +{ + NVTX_PAYLOAD_SCHEMA_TYPE_INVALID = 0, + + NVTX_PAYLOAD_SCHEMA_TYPE_STATIC = 1, + NVTX_PAYLOAD_SCHEMA_TYPE_DYNAMIC = 2, + + NVTX_PAYLOAD_SCHEMA_TYPE_UNION = 3, + NVTX_PAYLOAD_SCHEMA_TYPE_UNION_WITH_INTERNAL_SELECTOR = 4 +}; + +/** + * \brief Flags for static and dynamic schemas. + */ +enum nvtxPayloadSchemaFlags +{ + NVTX_PAYLOAD_SCHEMA_FLAG_NONE = 0, + + /** + * This flag indicates that a schema and the corresponding payloads can + * contain fields which require a deep copy. + */ + NVTX_PAYLOAD_SCHEMA_FLAG_DEEP_COPY = (1 << 1), + + /** + * This flag indicates that a schema and the corresponding payloads can + * be referenced by another payload of the same event. + */ + NVTX_PAYLOAD_SCHEMA_FLAG_REFERENCED = (1 << 2), + + /** + * The schema describes a deferred event/marker. Such a schema requires one + * timestamp entry and one string entry with the flag + * `NVTX_PAYLOAD_ENTRY_FLAG_EVENT_MESSAGE`. Category and color can be + * optionally specified with the respective entry types. The deferred event + * can contain a binary payload itself by using a custom schema ID as type + * its schema description. Multiple occurrences of the same event can be + * described by specifying an array timestamps. + */ + NVTX_PAYLOAD_SCHEMA_FLAG_DEFERRED_EVENT = (1 << 3), + /** + * The schema describes a deferred event/marker. Such a schema requires + * one start timestamp, one end timestamp and one string entry with the flag + * `NVTX_PAYLOAD_ENTRY_FLAG_EVENT_MESSAGE`. Category and color can be + * optionally specified with the respective entry types. 
The deferred range + * can contain a binary payload itself by using a custom schema ID as type + * its schema description. + * + * Timestamps can be provided in different ways: + * - A single range has two timestamp entries with the first (smaller entry + * index) being used as the start/push timestamp. + * - If the range schema contains one array of timestamps, the tool assumes + * that the array contains alternating start and end timestamps. + * - If two timestamp arrays are specified the first entry (with the + * smaller entry index) is assumed to contain the start timestamps. Both + * arrays have to be of the same size. + */ + NVTX_PAYLOAD_SCHEMA_FLAG_DEFERRED_RANGE = (2 << 3) +}; + +/** + * The values allow the valid fields in @ref nvtxPayloadSchemaAttr_t to be + * specified via setting the field `fieldMask`. + */ +#define NVTX_PAYLOAD_SCHEMA_ATTR_NAME (1 << 1) +#define NVTX_PAYLOAD_SCHEMA_ATTR_TYPE (1 << 2) +#define NVTX_PAYLOAD_SCHEMA_ATTR_FLAGS (1 << 3) +#define NVTX_PAYLOAD_SCHEMA_ATTR_ENTRIES (1 << 4) +#define NVTX_PAYLOAD_SCHEMA_ATTR_NUM_ENTRIES (1 << 5) +#define NVTX_PAYLOAD_SCHEMA_ATTR_STATIC_SIZE (1 << 6) +#define NVTX_PAYLOAD_SCHEMA_ATTR_ALIGNMENT (1 << 7) +#define NVTX_PAYLOAD_SCHEMA_ATTR_SCHEMA_ID (1 << 8) + +/** + * NVTX payload schema attributes. + */ +typedef struct nvtxPayloadSchemaAttr_t +{ + /** + * \brief Mask of valid fields in this structure. + * + * The values from `enum nvtxPayloadSchemaAttributes` have to be used. + */ + uint64_t fieldMask; + + /** + * \brief Name of the payload schema. (Optional) + */ + const char* name; + + /** + * \brief Payload schema type. (Mandatory) \anchor PAYLOAD_TYPE_FIELD + * + * A value from `enum nvtxPayloadSchemaType` has to be used. + */ + uint64_t type; + + /** + * \brief Payload schema flags. (Optional) + * + * Flags defined in `enum nvtxPayloadSchemaFlags` can be used to set + * additional properties of the schema. + */ + uint64_t flags; + + /** + * \brief Entries of a payload schema. 
(Mandatory) \anchor ENTRIES_FIELD + * + * This field is a pointer to an array of schema entries, each describing a + * field in a data structure, e.g. in a C struct or union. + */ + const nvtxPayloadSchemaEntry_t* entries; + + /** + * \brief Number of entries in the payload schema. (Mandatory) + * + * Number of entries in the array of payload entries \ref ENTRIES_FIELD. + */ + size_t numEntries; + + /** + * \brief The binary payload size in bytes for static payload schemas. + * + * If \ref PAYLOAD_TYPE_FIELD is @ref NVTX_PAYLOAD_SCHEMA_TYPE_DYNAMIC this + * value is ignored. If this field is not specified for a schema of type + * @ref NVTX_PAYLOAD_SCHEMA_TYPE_STATIC, the size can be automatically + * determined by a tool. + */ + size_t payloadStaticSize; + + /** + * \brief The byte alignment for packed structures. + * + * If not specified, this field defaults to `0`, which means that the fields + * in the data structure are not packed and natural alignment rules can be + * applied. + */ + size_t packAlign; + + /* Static/custom schema ID must be + >= NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START and + < NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_DYNAMIC_START */ + uint64_t schemaId; +} nvtxPayloadSchemaAttr_t; + +/** + * \brief Register a payload schema. + * + * @param domain NVTX domain handle. + * @param attr NVTX payload schema attributes. + */ +NVTX_DECLSPEC uint64_t NVTX_API nvtxPayloadSchemaRegister( + nvtxDomainHandle_t domain, const nvtxPayloadSchemaAttr_t* attr); + +/** + * \brief Enumeration entry. + * + * Since the value of an enum entry might not be meaningful for the analysis, + * a tool can show the name of enum entry instead. + * + * @note EXPERIMENTAL + */ +typedef struct nvtxPayloadEnum_t +{ + /** + * Name of the enum value. + */ + const char* name; + + /** + * Value of the enum entry. + */ + uint64_t value; + + /** + * Indicates that this entry sets a specific set of bits, which can be used + * to easily define bitsets. 
+ */ + int8_t isFlag; +} nvtxPayloadEnum_t; + +/** + * The values are used to set the field `fieldMask` and specify which fields in + * `nvtxPayloadEnumAttr_t` are set. + */ +#define NVTX_PAYLOAD_ENUM_ATTR_NAME (1 << 1) +#define NVTX_PAYLOAD_ENUM_ATTR_ENTRIES (1 << 2) +#define NVTX_PAYLOAD_ENUM_ATTR_NUM_ENTRIES (1 << 3) +#define NVTX_PAYLOAD_ENUM_ATTR_SIZE (1 << 4) +#define NVTX_PAYLOAD_ENUM_ATTR_SCHEMA_ID (1 << 5) + +/** + * NVTX payload enumeration type attributes. + */ +typedef struct nvtxPayloadEnumAttr_t { + /** + * Mask of valid fields in this struct. + * The values from `enum nvtxPayloadSchemaAttributes` have to be used. + */ + uint64_t fieldMask; + + /** + * Name of the enum. (Optional) + */ + const char* name; + + /** + * Entries of the enum. (Mandatory) + */ + const nvtxPayloadEnum_t* entries; + + /** + * Number of entries in the enum. (Mandatory) + */ + size_t numEntries; + + /** + * Size of enumeration type in bytes + */ + size_t sizeOfEnum; + + /** + * Static/custom schema ID must be + * >= NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START and + * < NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_DYNAMIC_START + */ + uint64_t schemaId; +} nvtxPayloadEnumAttr_t; + +/** + * \brief Register an enumeration type with the payload extension. + * + * @param domain NVTX domain handle + * @param attr NVTX payload enumeration type attributes. + */ +NVTX_DECLSPEC uint64_t nvtxPayloadEnumRegister(nvtxDomainHandle_t domain, + const nvtxPayloadEnumAttr_t* attr); + +/** + * \brief Callback Ids of API functions in the payload extension. + * + * The NVTX handler can use these values to register a handler function. 
When + * InitializeInjectionNvtxExtension(nvtxExtModuleInfo_t* moduleInfo) is + * executed, a handler routine 'handlenvtxPayloadRegisterSchema' can be + * registered as follows: + * moduleInfo->segments->slots[NVTX3EXT_CBID_nvtxPayloadSchemaRegister] = + * (intptr_t)handlenvtxPayloadRegisterSchema; + */ +typedef enum NvtxExtPayloadCallbackId +{ + NVTX3EXT_CBID_nvtxPayloadSchemaRegister = 0, + NVTX3EXT_CBID_nvtxPayloadEnumRegister = 1, + NVTX3EXT_CBID_PAYLOAD_FN_NUM = 2 +} NvtxExtPayloadCallbackId; + +#ifdef __GNUC__ +#pragma GCC visibility push(internal) +#endif + +#define NVTX_EXT_TYPES_GUARD /* Ensure other headers cannot be included directly */ +#include "nvtxExtDetail/nvtxExtTypes.h" +#undef NVTX_EXT_TYPES_GUARD + +#ifndef NVTX_NO_IMPL +#define NVTX_EXT_IMPL_PAYLOAD_GUARD /* Ensure other headers cannot be included directly */ +#include "nvtxExtDetail/nvtxExtPayloadTypeInfo.h" +#include "nvtxExtDetail/nvtxExtImplPayload_v1.h" +#undef NVTX_EXT_IMPL_PAYLOAD_GUARD +#endif /*NVTX_NO_IMPL*/ + +#ifdef __GNUC__ +#pragma GCC visibility pop +#endif + +#ifdef __cplusplus +} +#endif /* __cplusplus */ + +#endif /* NVTOOLSEXT_PAYLOAD_H */ diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtImpl.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtImpl.h new file mode 100644 index 0000000..5e42778 --- /dev/null +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtImpl.h @@ -0,0 +1,93 @@ +/* +* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* +* Licensed under the Apache License v2.0 with LLVM Exceptions. +* See https://llvm.org/LICENSE.txt for license information. +* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +*/ + +#ifndef NVTX_EXT_IMPL_GUARD +#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined).
+#endif + +#ifndef NVTX_EXT_IMPL_H +#define NVTX_EXT_IMPL_H +/* ---- Include required platform headers ---- */ + +#if defined(_WIN32) + +#include + +#else +#include + +#if defined(__ANDROID__) +#include +#endif + +#if defined(__linux__) || defined(__CYGWIN__) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#endif + +/* ---- Define macros used in this file ---- */ + +#ifdef NVTX_DEBUG_PRINT +#ifdef __ANDROID__ +#include +#define NVTX_ERR(...) __android_log_print(ANDROID_LOG_ERROR, "NVTOOLSEXT", __VA_ARGS__); +#define NVTX_INFO(...) __android_log_print(ANDROID_LOG_INFO, "NVTOOLSEXT", __VA_ARGS__); +#else +#include +#define NVTX_ERR(...) fprintf(stderr, "NVTX_ERROR: " __VA_ARGS__) +#define NVTX_INFO(...) fprintf(stderr, "NVTX_INFO: " __VA_ARGS__) +#endif +#else /* !defined(NVTX_DEBUG_PRINT) */ +#define NVTX_ERR(...) +#define NVTX_INFO(...) +#endif + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +// #ifdef __GNUC__ +// #pragma GCC visibility push(hidden) +// #endif + +#define NVTX_EXTENSION_FRESH 0 +#define NVTX_EXTENSION_DISABLED 1 +#define NVTX_EXTENSION_STARTING 2 +#define NVTX_EXTENSION_LOADED 3 + +NVTX_LINKONCE_DEFINE_GLOBAL NvtxExtInitializeInjectionFunc_t NVTX_VERSIONED_IDENTIFIER(injectionFnPtr) = (NvtxExtInitializeInjectionFunc_t)0; + +#define NVTX_EXT_INIT_GUARD +#include "nvtxExtInit.h" +#undef NVTX_EXT_INIT_GUARD + +// #ifdef __GNUC__ +// #pragma GCC visibility pop +// #endif + +#ifdef __cplusplus +} /* extern "C" */ +#endif /* __cplusplus */ + +#endif /* NVTX_EXT_IMPL_H */ \ No newline at end of file diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h new file mode 100644 index 0000000..d589f63 --- /dev/null +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h @@ -0,0 +1,85 @@ +/* +* Copyright 2021 NVIDIA Corporation. All rights reserved. 
+* +* Licensed under the Apache License v2.0 with LLVM Exceptions. +* See https://llvm.org/LICENSE.txt for license information. +* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +*/ + +#ifndef NVTX_EXT_IMPL_PAYLOAD_GUARD +#error Never include this file directly -- it is automatically included by nvToolsExtPayload.h (except when NVTX_NO_IMPL is defined). +#endif + +#define NVTX_EXT_IMPL_GUARD +#include "nvtxExtImpl.h" +#undef NVTX_EXT_IMPL_GUARD + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#define NVTX_EXT_PAYLOAD_VERSIONED_IDENTIFIER_L3(NAME, VERSION, COMPATID) \ + NAME##_v##VERSION##_mem##COMPATID +#define NVTX_EXT_PAYLOAD_VERSIONED_IDENTIFIER_L2(NAME, VERSION, COMPATID) \ + NVTX_EXT_PAYLOAD_VERSIONED_IDENTIFIER_L3(NAME, VERSION, COMPATID) +#define NVTX_EXT_PAYLOAD_VERSIONED_ID(NAME) \ + NVTX_EXT_PAYLOAD_VERSIONED_IDENTIFIER_L2(NAME, NVTX_VERSION, NVTX_EXT_COMPATID_PAYLOAD) + +/* + * Function slots for the binary payload extension. First entry is the module + * state, initialized to `0` (`NVTX_EXTENSION_FRESH`). 
+ */ +NVTX_LINKONCE_DEFINE_GLOBAL intptr_t +NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots)[NVTX3EXT_CBID_PAYLOAD_FN_NUM + 1] + = {0}; + +NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadInitOnce)() +{ + nvtxExtModuleSegment_t segment = { + 0, // unused (only one segment) + NVTX3EXT_CBID_PAYLOAD_FN_NUM, + NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots) + 1 + }; + + nvtxExtModuleInfo_t module = { + NVTX_VERSION, sizeof(nvtxExtModuleInfo_t), + NVTX_EXT_MODULEID_PAYLOAD, NVTX_EXT_COMPATID_PAYLOAD, + 1, &segment, // number of segments, segments + NULL, // no export function needed + // bake type sizes and alignment information into program binary + &nvtxExtPayloadTypeInfo + }; + + NVTX_INFO( "%s\n", __FUNCTION__ ); + + NVTX_VERSIONED_IDENTIFIER(nvtxExtInitOnce)(&module, + NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots)); +} + +#define NVTX_EXT_FN_IMPL(ret_val, fn_name, signature, arg_names) \ +typedef ret_val ( * fn_name##_impl_fntype )signature; \ +NVTX_LINKONCE_DEFINE_FUNCTION ret_val fn_name signature { \ + intptr_t slot = NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots)[NVTX3EXT_CBID_##fn_name + 1]; \ + if (slot != NVTX_EXTENSION_DISABLED) { \ + if (slot) { \ + return (*(fn_name##_impl_fntype)slot) arg_names; \ + } else { \ + NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadInitOnce)(); \ + slot = NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots)[NVTX3EXT_CBID_##fn_name + 1]; \ + if (slot != NVTX_EXTENSION_DISABLED && slot) { \ + return (*(fn_name##_impl_fntype)slot) arg_names; \ + } \ + } \ + } \ + return ((ret_val)(intptr_t)-1); \ +} + +NVTX_EXT_FN_IMPL(uint64_t, nvtxPayloadSchemaRegister, (nvtxDomainHandle_t domain, const nvtxPayloadSchemaAttr_t* attr), (domain, attr)) + +NVTX_EXT_FN_IMPL(uint64_t, nvtxPayloadEnumRegister, (nvtxDomainHandle_t domain, const nvtxPayloadEnumAttr_t* attr), (domain, attr)) + +#undef NVTX_EXT_FN_IMPL + +#ifdef __cplusplus +} /* extern "C" */ +#endif /* __cplusplus */ \ No newline at end of file 
diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtInit.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtInit.h new file mode 100644 index 0000000..724c217 --- /dev/null +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtInit.h @@ -0,0 +1,363 @@ +/* +* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* +* Licensed under the Apache License v2.0 with LLVM Exceptions. +* See https://llvm.org/LICENSE.txt for license information. +* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +*/ + +#ifndef NVTX_EXT_INIT_GUARD +#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined). +#endif + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/* ---- Platform-independent helper definitions and functions ---- */ + +/* Prefer macros over inline functions to reduce symbol resolution at link time */ + +#if defined(_WIN32) +#define NVTX_PATHCHAR wchar_t +#define NVTX_STR(x) L##x +#define NVTX_GETENV _wgetenv +#define NVTX_BUFSIZE MAX_PATH +#define NVTX_DLLHANDLE HMODULE +#define NVTX_DLLOPEN(x) LoadLibraryW(x) +#define NVTX_DLLFUNC GetProcAddress +#define NVTX_DLLCLOSE FreeLibrary +#define NVTX_YIELD() SwitchToThread() +#define NVTX_MEMBAR() MemoryBarrier() +#define NVTX_ATOMIC_WRITE_32(address, value) InterlockedExchange((volatile LONG*)address, value) +#define NVTX_ATOMIC_CAS_32(old, address, exchange, comparand) old = InterlockedCompareExchange((volatile LONG*)address, exchange, comparand) +#define NVTX_ATOMIC_WRITE_PTR(address, value) InterlockedExchangePointer((volatile PVOID*)address, (PVOID)value) +#define NVTX_ATOMIC_CAS_PTR(old, address, exchange, comparand) old = (intptr_t)InterlockedCompareExchangePointer((volatile PVOID*)address, (PVOID)exchange, (PVOID)comparand) + + +#elif defined(__GNUC__) +#define NVTX_PATHCHAR char +#define NVTX_STR(x) x +#define NVTX_GETENV getenv +#define NVTX_BUFSIZE PATH_MAX +#define NVTX_DLLHANDLE void* +#define NVTX_DLLOPEN(x) dlopen(x, RTLD_LAZY) +#define 
NVTX_DLLFUNC dlsym +#define NVTX_DLLCLOSE dlclose +#define NVTX_YIELD() sched_yield() +#define NVTX_MEMBAR() __sync_synchronize() +/* Ensure full memory barrier for atomics, to match Windows functions */ +#define NVTX_ATOMIC_WRITE_32(address, value) __sync_synchronize(); __sync_lock_test_and_set(address, value) +#define NVTX_ATOMIC_CAS_32(old, address, exchange, comparand) __sync_synchronize(); old = __sync_val_compare_and_swap(address, exchange, comparand) +#define NVTX_ATOMIC_WRITE_PTR(address, value) __sync_synchronize(); __sync_lock_test_and_set(address, value) +#define NVTX_ATOMIC_CAS_PTR(old, address, exchange, comparand) __sync_synchronize(); old = __sync_val_compare_and_swap(address, exchange, comparand) +#else +#error The library does not support your configuration! +#endif + +/* Define this to 1 for platforms that where pre-injected libraries can be discovered. */ +#if defined(_WIN32) +/* TODO */ +#define NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY 0 +#else +#define NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY 0 +#endif + +/* Define this to 1 for platforms that support environment variables */ +/* TODO: Detect UWP, a.k.a. Windows Store app, and set this to 0. */ +/* Try: #if defined(WINAPI_FAMILY_PARTITION) && WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP) */ +#define NVTX_SUPPORT_ENV_VARS 1 + +/* Define this to 1 for platforms that support dynamic/shared libraries */ +#define NVTX_SUPPORT_DYNAMIC_INJECTION_LIBRARY 1 + +/* Injection libraries implementing InitializeInjectionNvtxExtension may be statically linked, +* and this will override any dynamic injection. Useful for platforms where dynamic +* injection is not available. Since weak symbols not explicitly marked extern are +* guaranteed to be initialized to zero if no definitions are found by the linker, the +* dynamic injection process proceeds normally if pfnInitializeInjectionNvtx2 is 0. 
*/ +#if defined(__GNUC__) && !defined(_WIN32) && !defined(__CYGWIN__) +#define NVTX_SUPPORT_STATIC_INJECTION_LIBRARY 1 +/* To statically inject an NVTX library, define InitializeInjectionNvtxExtension_fnptr as a normal +* symbol (not weak) pointing to the implementation of InitializeInjectionNvtxExtension (which +* does not need to be named "InitializeInjectionNvtxExtension" as is necessary in a dynamic +* injection library. */ +__attribute__((weak)) NvtxExtInitializeInjectionFunc_t InitializeInjectionNvtxExtension_fnptr; +#else +#define NVTX_SUPPORT_STATIC_INJECTION_LIBRARY 0 +#endif + + + +/* This function tries to find or load an NVTX injection library and get the +* address of its InitializeInjectionExtension function. If such a function pointer +* is found, it is called, and passed the address of this NVTX instance's +* nvtxGetExportTable function, so the injection can attach to this instance. +* If the initialization fails for any reason, any dynamic library loaded will +* be freed, and all NVTX implementation functions will be set to no-ops. If +* initialization succeeds, NVTX functions not attached to the tool will be set +* to no-ops. This is implemented as one function instead of several small +* functions to minimize the number of weak symbols the linker must resolve. +* Order of search is: +* - Pre-injected library exporting InitializeInjectionNvtxExtension +* - Loadable library exporting InitializeInjectionNvtxExtension +* - Path specified by env var NVTX_INJECTION??_PATH (?? is 32 or 64) +* - On Android, libNvtxInjection??.so within the package (?? 
is 32 or 64) +* - Statically-linked injection library defining InitializeInjectionNvtxExtension_fnptr +*/ +NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_VERSIONED_IDENTIFIER(nvtxExtLoadInjectionLibrary)(NvtxExtInitializeInjectionFunc_t* out_init_fnptr); +NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_VERSIONED_IDENTIFIER(nvtxExtLoadInjectionLibrary)(NvtxExtInitializeInjectionFunc_t* out_init_fnptr) +{ + const char* const initFuncName = "InitializeInjectionNvtxExtension"; + NvtxExtInitializeInjectionFunc_t init_fnptr = (NvtxExtInitializeInjectionFunc_t)0; + NVTX_DLLHANDLE injectionLibraryHandle = (NVTX_DLLHANDLE)0; + + if(out_init_fnptr){ + *out_init_fnptr = (NvtxExtInitializeInjectionFunc_t)0; + } + +#if NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY + /* Use POSIX global symbol chain to query for init function from any module */ + init_fnptr = (NvtxExtInitializeInjectionFunc_t)NVTX_DLLFUNC(0, initFuncName); +#endif + +#if NVTX_SUPPORT_DYNAMIC_INJECTION_LIBRARY + /* Try discovering dynamic injection library to load */ + if (!init_fnptr) + { +#if NVTX_SUPPORT_ENV_VARS + /* If env var NVTX_INJECTION64_PATH is set, it should contain the path + * to a 64-bit dynamic NVTX injection library (and similar for 32-bit). */ + const NVTX_PATHCHAR* const nvtxEnvVarName = (sizeof(void*) == 4) + ? NVTX_STR("NVTX_INJECTION32_PATH") + : NVTX_STR("NVTX_INJECTION64_PATH"); +#endif /* NVTX_SUPPORT_ENV_VARS */ + NVTX_PATHCHAR injectionLibraryPathBuf[NVTX_BUFSIZE]; + const NVTX_PATHCHAR* injectionLibraryPath = (const NVTX_PATHCHAR*)0; + + /* Refer to this variable explicitly in case all references to it are #if'ed out */ + (void)injectionLibraryPathBuf; + +#if NVTX_SUPPORT_ENV_VARS + /* Disable the warning for getenv & _wgetenv -- this usage is safe because + * these functions are not called again before using the returned value.
*/ +#if defined(_MSC_VER) +#pragma warning( push ) +#pragma warning( disable : 4996 ) +#endif + injectionLibraryPath = NVTX_GETENV(nvtxEnvVarName); +#if defined(_MSC_VER) +#pragma warning( pop ) +#endif +#endif + +#if defined(__ANDROID__) + if (!injectionLibraryPath) + { + const char *bits = (sizeof(void*) == 4) ? "32" : "64"; + char cmdlineBuf[32]; + char pkgName[PATH_MAX]; + int count; + int pid; + FILE *fp; + size_t bytesRead; + size_t pos; + + pid = (int)getpid(); + count = snprintf(cmdlineBuf, sizeof(cmdlineBuf), "/proc/%d/cmdline", pid); + if (count <= 0 || count >= (int)sizeof(cmdlineBuf)) + { + NVTX_ERR("Path buffer too small for: /proc/%d/cmdline\n", pid); + return NVTX_ERR_INIT_ACCESS_LIBRARY; + } + + fp = fopen(cmdlineBuf, "r"); + if (!fp) + { + NVTX_ERR("File couldn't be opened: %s\n", cmdlineBuf); + return NVTX_ERR_INIT_ACCESS_LIBRARY; + } + + bytesRead = fread(pkgName, 1, sizeof(pkgName) - 1, fp); + fclose(fp); + if (bytesRead == 0) + { + NVTX_ERR("Package name couldn't be read from file: %s\n", cmdlineBuf); + return NVTX_ERR_INIT_ACCESS_LIBRARY; + } + + pkgName[bytesRead] = 0; + + /* String can contain colon as a process separator. In this case the package name is before the colon. */ + pos = 0; + while (pos < bytesRead && pkgName[pos] != ':' && pkgName[pos] != '\0') + { + ++pos; + } + pkgName[pos] = 0; + + count = snprintf(injectionLibraryPathBuf, NVTX_BUFSIZE, "/data/data/%s/files/libNvtxInjection%s.so", pkgName, bits); + if (count <= 0 || count >= NVTX_BUFSIZE) + { + NVTX_ERR("Path buffer too small for: /data/data/%s/files/libNvtxInjection%s.so\n", pkgName, bits); + return NVTX_ERR_INIT_ACCESS_LIBRARY; + } + + /* On Android, verify path is accessible due to aggressive file access restrictions. */ + /* For dlopen, if the filename contains a leading slash, then it is interpreted as a */ + /* relative or absolute pathname; otherwise it will follow the rules in ld.so. 
*/ + if (injectionLibraryPathBuf[0] == '/') + { +#if (__ANDROID_API__ < 21) + int access_err = access(injectionLibraryPathBuf, F_OK | R_OK); +#else + int access_err = faccessat(AT_FDCWD, injectionLibraryPathBuf, F_OK | R_OK, 0); +#endif + if (access_err != 0) + { + NVTX_ERR("Injection library path wasn't accessible [code=%s] [path=%s]\n", strerror(errno), injectionLibraryPathBuf); + return NVTX_ERR_INIT_ACCESS_LIBRARY; + } + } + injectionLibraryPath = injectionLibraryPathBuf; + } +#endif + + /* At this point, injectionLibraryPath is specified if a dynamic + * injection library was specified by a tool. */ + if (injectionLibraryPath) + { + /* Load the injection library */ + injectionLibraryHandle = NVTX_DLLOPEN(injectionLibraryPath); + if (!injectionLibraryHandle) + { + NVTX_ERR("Failed to load injection library\n"); + return NVTX_ERR_INIT_LOAD_LIBRARY; + } + else + { + /* Attempt to get the injection library's entry-point */ + init_fnptr = (NvtxExtInitializeInjectionFunc_t)NVTX_DLLFUNC(injectionLibraryHandle, initFuncName); + if (!init_fnptr) + { + NVTX_DLLCLOSE(injectionLibraryHandle); + NVTX_ERR("Failed to get address of function %s from injection library\n", initFuncName); + return NVTX_ERR_INIT_MISSING_LIBRARY_ENTRY_POINT; + } + } + } + } +#endif + +#if NVTX_SUPPORT_STATIC_INJECTION_LIBRARY + if (!init_fnptr) + { + /* Check weakly-defined function pointer. A statically-linked injection can define this as + * a normal symbol and it will take precedence over a dynamic injection. */ + if (InitializeInjectionNvtxExtension_fnptr) + { + init_fnptr = InitializeInjectionNvtxExtension_fnptr; + } + } +#endif + + if(out_init_fnptr){ + *out_init_fnptr = init_fnptr; + } + + /* At this point, if init_fnptr is not set, then no tool has specified + * an NVTX injection library -- return non-success result so all NVTX + * API functions will be set to no-ops. 
*/ + if (!init_fnptr) + { + return NVTX_ERR_NO_INJECTION_LIBRARY_AVAILABLE; + } + + return NVTX_SUCCESS; +} + +NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_VERSIONED_IDENTIFIER(nvtxExtInitOnce) ( + nvtxExtModuleInfo_t* moduleInfo, + intptr_t* moduleState + ) +{ + intptr_t old; + + NVTX_INFO( "%s\n", __FUNCTION__ ); + + if( *moduleState == NVTX_EXTENSION_LOADED) { + return; + } + + NVTX_ATOMIC_CAS_PTR( + old, + moduleState, + NVTX_EXTENSION_STARTING, + NVTX_EXTENSION_FRESH); + if (old == NVTX_EXTENSION_FRESH) + { + NvtxExtInitializeInjectionFunc_t init_fnptr = NVTX_VERSIONED_IDENTIFIER(injectionFnPtr); + int entryPointStatus = 0; + int forceAllToNoops = 0; + + /* Load & initialize injection library -- it will assign the function pointers */ + if(init_fnptr == 0){ + int result = 0; + + /* try to load vanilla NVTX first*/ + nvtxInitialize(0); + + result = NVTX_VERSIONED_IDENTIFIER(nvtxExtLoadInjectionLibrary)(&init_fnptr); + /*at this point init_fnptr will be either 0 or a real function*/ + + if(result == NVTX_SUCCESS) { + NVTX_VERSIONED_IDENTIFIER(injectionFnPtr) = init_fnptr; + } + else { + NVTX_ERR("Failed to load injection library\n"); + } + } + + if(init_fnptr != 0) { + /* Invoke injection library's initialization function. If it returns + * 0 (failure) and a dynamic injection was loaded, unload it. */ + entryPointStatus = init_fnptr(moduleInfo); + if (entryPointStatus == 0) { + NVTX_ERR("Failed to initialize injection library -- initialization function returned 0\n"); + } + } + + /* Clean up any functions that are still uninitialized so that they are skipped. + * Set all to null if injection init function failed as well. 
+ */ + forceAllToNoops = (init_fnptr == 0) || (entryPointStatus == 0); + for(size_t s = 0; s < moduleInfo->segmentsCount; ++s){ + nvtxExtModuleSegment_t* segment = moduleInfo->segments+s; + for(size_t i = 0; i < segment->slotCount; ++i){ + if(forceAllToNoops || (segment->functionSlots[i] == NVTX_EXTENSION_FRESH)){ + segment->functionSlots[i] = NVTX_EXTENSION_DISABLED; + } + } + } + + NVTX_MEMBAR(); + + /* Signal that initialization has finished, so now the assigned function pointers will be used */ + NVTX_ATOMIC_WRITE_PTR( + moduleState, + NVTX_EXTENSION_LOADED); + } + else /* Spin-wait until initialization has finished */ + { + NVTX_MEMBAR(); + while (*moduleState != NVTX_EXTENSION_LOADED) + { + NVTX_YIELD(); + NVTX_MEMBAR(); + } + } +} + +#ifdef __cplusplus +} +#endif /* __cplusplus */ diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtPayloadTypeInfo.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtPayloadTypeInfo.h new file mode 100644 index 0000000..c2c1ac5 --- /dev/null +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtPayloadTypeInfo.h @@ -0,0 +1,128 @@ +#ifndef NVTX_EXT_IMPL_PAYLOAD_GUARD +#error Never include this file directly -- it is automatically included by nvToolsExtPayload.h (except when NVTX_NO_IMPL is defined). +#endif + +/* + * Helper array to get the alignment for each predefined C language type. 
+ */ + +typedef void* pointer_type; + +#if __STDC_VERSION__ >= 201112L /* or CPP11 */ +#include +#define nvtx_alignof(type) alignof(type) +#define nvtx_alignof2(type,tname) alignof(type) +#else /* __STDC_VERSION__ >= 201112L */ +#ifndef __cplusplus + +#include +#define nvtx_alignof(type) offsetof(struct {char c; type d;}, d) +#define nvtx_alignof2(type,tname) nvtx_alignof(type) + +#else /* __cplusplus */ + +#define MKTYPEDEF(TYPE) typedef struct {char c; TYPE d;} _nvtx_##TYPE +#define MKTYPEDEF2(TYPE,TNAME) typedef struct {char c; TYPE d;} _nvtx_##TNAME +#define nvtx_alignof(TNAME) offsetof(_nvtx_##TNAME, d) +#define nvtx_alignof2(type,tname) offsetof(_nvtx_##tname, d) + +MKTYPEDEF(char); +MKTYPEDEF2(unsigned char, uchar); +MKTYPEDEF(short); +MKTYPEDEF2(unsigned short, ushort); +MKTYPEDEF(int); +MKTYPEDEF2(unsigned int, uint); +MKTYPEDEF(long); +MKTYPEDEF2(unsigned long, ulong); +MKTYPEDEF2(long long, longlong); +MKTYPEDEF2(unsigned long long, ulonglong); + +MKTYPEDEF(int8_t); +MKTYPEDEF(uint8_t); +MKTYPEDEF(int16_t); +MKTYPEDEF(uint16_t); +MKTYPEDEF(int32_t); +MKTYPEDEF(uint32_t); +MKTYPEDEF(int64_t); +MKTYPEDEF(uint64_t); + +MKTYPEDEF(float); +MKTYPEDEF(double); +MKTYPEDEF2(long double, longdouble); + +MKTYPEDEF(size_t); +MKTYPEDEF(pointer_type); + +MKTYPEDEF(wchar_t); +#if (__STDC_VERSION__ > 201710L) || (defined(__cplusplus) && __cplusplus > 201703L) + /* char8_t is only available since C23/C++20 */ + MKTYPEDEF(char8_t); +#endif +#if (__STDC_VERSION__ >= 201112L) || (defined(__cplusplus) && __cplusplus >= 201103L) + MKTYPEDEF(char16_t); + MKTYPEDEF(char32_t); +#endif + +#undef MKTYPEDEF +#undef MKTYPEDEF2 + +#endif /* __cplusplus */ +#endif /* __STDC_VERSION__ >= 201112L */ + +/* + * The order of entries must match the values in `enum nvtxPayloadSchemaEntryType`. + */ +const nvtxPayloadEntryTypeInfo_t nvtxExtPayloadTypeInfo[NVTX_PAYLOAD_ENTRY_TYPE_INFO_ARRAY_SIZE] = +{ + /* The first entry contains this array's length and the size of each entry in this array.
*/ + {NVTX_PAYLOAD_ENTRY_TYPE_INFO_ARRAY_SIZE, sizeof(nvtxPayloadEntryTypeInfo_t)}, + + /*** C integer types ***/ + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR */ {sizeof(char), nvtx_alignof(char)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UCHAR */ {sizeof(unsigned char), nvtx_alignof2(unsigned char, uchar)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_SHORT */ {sizeof(short), nvtx_alignof(short)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_USHORT */ {sizeof(unsigned short), nvtx_alignof2(unsigned short, ushort)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_INT */ {sizeof(int), nvtx_alignof(int)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UINT */ {sizeof(unsigned int), nvtx_alignof2(unsigned int, uint)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_LONG */ {sizeof(long), nvtx_alignof(long)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_ULONG */ {sizeof(unsigned long), nvtx_alignof2(unsigned long, ulong)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_LONGLONG */ {sizeof(long long), nvtx_alignof2(long long, longlong)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_ULONGLONG */ {sizeof(unsigned long long), nvtx_alignof2(unsigned long long,ulonglong)}, + + /*** Integer types with explicit size ***/ + /* NVTX_PAYLOAD_ENTRY_TYPE_INT8 */ {sizeof(int8_t), nvtx_alignof(int8_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UINT8 */ {sizeof(uint8_t), nvtx_alignof(uint8_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_INT16 */ {sizeof(int16_t), nvtx_alignof(int16_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UINT16 */ {sizeof(uint16_t), nvtx_alignof(uint16_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_INT32 */ {sizeof(int32_t), nvtx_alignof(int32_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UINT32 */ {sizeof(uint32_t), nvtx_alignof(uint32_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_INT64 */ {sizeof(int64_t), nvtx_alignof(int64_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UINT64 */ {sizeof(uint64_t), nvtx_alignof(uint64_t)}, + + /*** C floating point types ***/ + /* NVTX_PAYLOAD_ENTRY_TYPE_FLOAT */ {sizeof(float), nvtx_alignof(float)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_DOUBLE */ {sizeof(double), nvtx_alignof(double)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_LONGDOUBLE */ {sizeof(long double), nvtx_alignof2(long double, 
longdouble)}, + + /* NVTX_PAYLOAD_ENTRY_TYPE_SIZE */ {sizeof(size_t), nvtx_alignof(size_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_ADDRESS */ {sizeof(pointer_type), nvtx_alignof(pointer_type)}, + + /*** Special character types ***/ + /* NVTX_PAYLOAD_ENTRY_TYPE_WCHAR */ {sizeof(wchar_t), nvtx_alignof(wchar_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR8 */ +#if (__STDC_VERSION__ > 201710L) || (defined(__cplusplus) && __cplusplus > 201703L) + {sizeof(char8_t), nvtx_alignof(char8_t)}, +#else + {0, 0}, +#endif +#if (__STDC_VERSION__ >= 201112L) || (defined(__cplusplus) && __cplusplus >= 201103L) + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR16 */ {sizeof(char16_t), nvtx_alignof(char16_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR32 */ {sizeof(char32_t), nvtx_alignof(char32_t)} +#else + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR16 */ {0, 0}, + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR32 */ {0, 0} +#endif +}; + +#undef nvtx_alignof +#undef nvtx_alignof2 \ No newline at end of file diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtTypes.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtTypes.h new file mode 100644 index 0000000..bcad095 --- /dev/null +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtTypes.h @@ -0,0 +1,44 @@ +/* +* Copyright 2021 NVIDIA Corporation. All rights reserved. +* +* Licensed under the Apache License v2.0 with LLVM Exceptions. +* See https://llvm.org/LICENSE.txt for license information. +* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +*/ + +/* This header defines types which are used by the internal implementation +* of NVTX and callback subscribers. API clients do not use these types, +* so they are defined here instead of in nvToolsExt.h to clarify they are +* not part of the NVTX client API. */ + +#ifndef NVTXEXTTYPES_H +#define NVTXEXTTYPES_H + +#ifndef NVTX_EXT_TYPES_GUARD +#error Never include this file directly -- it is automatically included by nvToolsExt[EXTENSION].h. 
+#endif + +typedef intptr_t (NVTX_API * NvtxExtGetExportFunction_t)(uint32_t exportFunctionId); + +typedef struct nvtxExtModuleSegment_t +{ + size_t segmentId; + size_t slotCount; + intptr_t* functionSlots; +} nvtxExtModuleSegment_t; + +typedef struct nvtxExtModuleInfo_t +{ + uint16_t nvtxVer; + uint16_t structSize; + uint16_t moduleId; + uint16_t compatId; + size_t segmentsCount; + nvtxExtModuleSegment_t* segments; + NvtxExtGetExportFunction_t getExportFunction; + const void* extInfo; +} nvtxExtModuleInfo_t; + +typedef int (NVTX_API * NvtxExtInitializeInjectionFunc_t)(nvtxExtModuleInfo_t* moduleInfo); + +#endif /* NVTXEXTTYPES_H */ \ No newline at end of file diff --git a/src/include/proxy.h b/src/include/proxy.h index fa8f388..4c75e21 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -11,6 +11,7 @@ #include "info.h" #include "socket.h" #include +#include "shm.h" enum ncclProxyOpState { ncclProxyOpNone, ncclProxyOpReady, ncclProxyOpProgress }; @@ -106,6 +107,7 @@ struct ncclProxyOpsPool { struct ncclProxyOps { ncclProxyOpsPool* pool; + ncclShmHandle_t handle; int count; int freeOp; int nextOps; @@ -145,6 +147,7 @@ struct ncclProxyPool; struct ncclProxyProgressState { // Used by main threads to send work to progress thread struct ncclProxyOpsPool* opsPool; + ncclShmHandle_t handle; char opsPoolShmSuffix[6]; pthread_t thread; @@ -164,7 +167,6 @@ struct ncclProxyState { struct ncclSocket* listenSock; int stop; CUcontext cudaCtx; - int safeAbortFlag; // Used by main thread union ncclSocketAddress* peerAddresses; @@ -176,6 +178,15 @@ struct ncclProxyState { struct ncclProxyProgressState progressState; }; +enum proxyConnectState { + connUninitialized = 0, + connInitialized = 1, + connSharedInitialized = 2, + connSetupDone = 3, + connConnected = 4, + numConnStates = 5 +}; + struct ncclProxyConnection { int send, transport, shared; int localRank; @@ -184,7 +195,7 @@ struct ncclProxyConnection { struct ncclProxyArgs *proxyAppend; struct ncclProxyArgs 
**proxyAppendPtr; void* transportResources; - bool initFlag; + proxyConnectState state; }; typedef ncclResult_t (*threadFunc_t)(struct ncclProxyArgs*); diff --git a/src/include/shm.h b/src/include/shm.h index 08dc849..61b0b4d 100644 --- a/src/include/shm.h +++ b/src/include/shm.h @@ -9,7 +9,9 @@ #include "nccl.h" -ncclResult_t ncclShmOpen(char* shmPath, const int shmSize, void** shmPtr, void** devShmPtr, int create); -ncclResult_t ncclShmUnlink(const char* shmname); -ncclResult_t ncclShmClose(void* shmPtr, void* devShmPtr, const int shmSize); +typedef void* ncclShmHandle_t; +ncclResult_t ncclShmOpen(char* shmPath, size_t shmSize, void** shmPtr, void** devShmPtr, int refcount, ncclShmHandle_t* handle); +ncclResult_t ncclShmClose(ncclShmHandle_t handle); +ncclResult_t ncclShmUnlink(ncclShmHandle_t handle); + #endif diff --git a/src/include/socket.h b/src/include/socket.h index d72480b..a0c7a4d 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -21,6 +21,7 @@ #define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec) #define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s) #define SOCKET_NAME_MAXLEN (NI_MAXHOST+NI_MAXSERV) +#define NCCL_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL /* Common socket address storage structure for IPv4/IPv6 */ union ncclSocketAddress { @@ -30,32 +31,59 @@ union ncclSocketAddress { }; enum ncclSocketState { - ncclSocketConnecting = 0, - ncclSocketConnected = 1, - ncclSocketError = 2, - ncclSocketStateNum = 3 -} ; + ncclSocketStateNone = 0, + ncclSocketStateInitialized = 1, + ncclSocketStateAccepting = 2, + ncclSocketStateAccepted = 3, + ncclSocketStateConnecting = 4, + ncclSocketStateConnectPolling = 5, + ncclSocketStateConnected = 6, + ncclSocketStateReady = 7, + ncclSocketStateClosed = 8, + ncclSocketStateError = 9, + ncclSocketStateNum = 10 +}; + +enum ncclSocketType { + ncclSocketTypeUnknown = 0, + ncclSocketTypeBootstrap = 1, + ncclSocketTypeProxy = 2, + 
ncclSocketTypeNetSocket = 3, + ncclSocketTypeNetIb = 4 +}; struct ncclSocket { int fd; + int acceptFd; + int timedOutRetries; + int refusedRetries; union ncclSocketAddress addr; volatile uint32_t* abortFlag; int asyncFlag; enum ncclSocketState state; + int salen; + uint64_t magic; + enum ncclSocketType type; }; const char *ncclSocketToString(union ncclSocketAddress *addr, char *buf, const int numericHostForm = 1); -ncclResult_t ncclGetSocketAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair); +ncclResult_t ncclSocketGetAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair); int ncclFindInterfaceMatchSubnet(char* ifNames, union ncclSocketAddress* localAddrs, union ncclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs); int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNameMaxSize, int maxIfs); + +// Initialize a socket +ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr = NULL, uint64_t magic = NCCL_SOCKET_MAGIC, enum ncclSocketType type = ncclSocketTypeUnknown, volatile uint32_t* abortFlag = NULL, int asyncFlag = 0); // Create a listening socket. sock->addr can be pre-filled with IP & port info. sock->fd is set after a successful call ncclResult_t ncclSocketListen(struct ncclSocket* sock); +ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr); // Connect to sock->addr. sock->fd is set after a successful call. ncclResult_t ncclSocketConnect(struct ncclSocket* sock); // Return socket connection state. -ncclResult_t ncclGetSocketState(struct ncclSocket* sock, enum ncclSocketState* state); -// Accept an incoming connection from listenSocket->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->addr. 
-ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listenSocket); +ncclResult_t ncclSocketReady(struct ncclSocket* sock, int *running); +// Accept an incoming connection from listenSock->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->addr. +ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* ulistenSock); +ncclResult_t ncclSocketGetFd(struct ncclSocket* sock, int* fd); +ncclResult_t ncclSocketSetFd(int fd, struct ncclSocket* sock); #define NCCL_SOCKET_SEND 0 #define NCCL_SOCKET_RECV 1 @@ -65,6 +93,5 @@ ncclResult_t ncclSocketWait(int op, struct ncclSocket* sock, void* ptr, int size ncclResult_t ncclSocketSend(struct ncclSocket* sock, void* ptr, int size); ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size); ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed); -/* initialize a socket. */ -ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr = NULL, volatile uint32_t* abortFlag = NULL, int asyncFlag = 0); +ncclResult_t ncclSocketClose(struct ncclSocket* sock); #endif diff --git a/src/include/strongstream.h b/src/include/strongstream.h index 16b6e07..0984dfe 100644 --- a/src/include/strongstream.h +++ b/src/include/strongstream.h @@ -98,16 +98,19 @@ ncclResult_t ncclStrongStreamLaunchKernel( ); // Cause `a` to wait for the current state `b`. Both `a` and `b` must be acquired. +// `b_subsumes_a` indicates that all work in `a` is already present in `b`, thus +// we want to fast-forward `a` to be a clone of `b`. Knowing this permits the +// implementation to induce few graph dependencies. ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b + struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b, bool b_subsumes_a=false ); // `b` must be capturing within `graph`. 
ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, struct ncclStrongStream* a, cudaStream_t b + struct ncclCudaGraph graph, struct ncclStrongStream* a, cudaStream_t b, bool b_subsumes_a=false ); // `a` must be capturing within `graph`. ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, cudaStream_t a, struct ncclStrongStream* b + struct ncclCudaGraph graph, cudaStream_t a, struct ncclStrongStream* b, bool b_subsumes_a=false ); // Synchrnoization does not need the strong stream to be acquired. diff --git a/src/include/utils.h b/src/include/utils.h index 0604d15..1c300b0 100644 --- a/src/include/utils.h +++ b/src/include/utils.h @@ -27,6 +27,7 @@ ncclResult_t getHostName(char* hostname, int maxlen, const char delim); uint64_t getHash(const char* string, int n); uint64_t getHostHash(); uint64_t getPidHash(); +ncclResult_t getRandomData(void* buffer, size_t bytes); struct netIf { char prefix[64]; @@ -48,6 +49,19 @@ inline uint64_t clockNano() { return uint64_t(ts.tv_sec)*1000*1000*1000 + ts.tv_nsec; } +/* get any bytes of random data from /dev/urandom, return 0 if it succeeds; else + * return -1 */ +inline ncclResult_t getRandomData(void* buffer, size_t bytes) { + ncclResult_t ret = ncclSuccess; + if (bytes > 0) { + const size_t one = 1UL; + FILE* fp = fopen("/dev/urandom", "r"); + if (buffer == NULL || fp == NULL || fread(buffer, bytes, one, fp) != one) ret = ncclSystemError; + if (fp) fclose(fp); + } + return ret; +} + //////////////////////////////////////////////////////////////////////////////// template diff --git a/src/init.cc b/src/init.cc index ab0a064..91a8793 100644 --- a/src/init.cc +++ b/src/init.cc @@ -87,6 +87,7 @@ static ncclResult_t ncclInit() { NCCLCHECK(bootstrapNetInit()); NCCLCHECK(ncclNetPluginInit()); + initNvtxRegisteredEnums(); __atomic_store_n(&initialized, true, __ATOMIC_RELEASE); } pthread_mutex_unlock(&initLock); @@ -104,7 +105,7 @@ NCCL_API(ncclResult_t, ncclGetUniqueId, ncclUniqueId* out); 
ncclResult_t ncclGetUniqueId(ncclUniqueId* out) { NCCLCHECK(ncclInit()); NCCLCHECK(PtrCheck(out, "GetUniqueId", "out")); - ncclResult_t res = bootstrapGetUniqueId(out); + ncclResult_t res = bootstrapGetUniqueId((struct ncclBootstrapHandle*)out); TRACE_CALL("ncclGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(*out)); return res; } @@ -117,7 +118,7 @@ ncclResult_t ncclGetUniqueId(ncclUniqueId* out) { #endif void NCCL_NO_OPTIMIZE commPoison(ncclComm_t comm) { - // Important that this does not trash intraComm0 & intraRefs. + // Important that this does not trash intraComm0. comm->rank = comm->cudaDev = comm->busId = comm->nRanks = -1; } @@ -173,11 +174,15 @@ void ncclCommPushCudaGdrFree(struct ncclComm* comm, void* handle) { } static ncclResult_t commFree(ncclComm_t comm) { + /* commFree() should not involve any sync among ranks. */ if (comm == NULL) return ncclSuccess; - // Stop all threads before we free anything. - NCCLCHECK(ncclProxyDestroy(comm)); + /* in commReclaim, we have guaranteed only last rank which calls ncclCommDestroy() will + * free all intra-process communicators; therefore, we only need to focus on local + * resource cleanup in commFree(). */ + if (comm->proxyState.thread) + pthread_join(comm->proxyState.thread, nullptr); delete[] comm->userRedOps; @@ -214,30 +219,10 @@ static ncclResult_t commFree(ncclComm_t comm) { ncclMemoryStackDestruct(&comm->memScoped); ncclMemoryStackDestruct(&comm->memPermanent); - commPoison(comm); // Important that this does not interfere with anything used below. + ncclCudaHostFree((void *)comm->abortFlag); - if (comm->initState == ncclSuccess) { - struct ncclComm* intraComm0 = comm->intraComm0; - if (0 == ncclAtomicRefCountDecrement(&intraComm0->intraRefs)) { - // Wait for all service threads to be done. 
We could not - // do it earlier because it could have blocked and prevented - // other ranks in the process to call ncclCommDestroy - comm = intraComm0; - while (comm != nullptr) { - if (comm->proxyState.thread) pthread_join(comm->proxyState.thread, nullptr); - struct ncclComm* next = comm->intraNext; - free(comm); - comm = next; - } - } - } else if (comm->proxyState.thread) { - pthread_join(comm->proxyState.thread, nullptr); - ncclCudaHostFree((void *)comm->abortFlag); - free(comm); - } else { - ncclCudaHostFree((void *)comm->abortFlag); - free(comm); - } + commPoison(comm); // poison comm before free to avoid comm reuse. + free(comm); return ncclSuccess; } @@ -253,7 +238,7 @@ NCCL_PARAM(DmaBufEnable, "DMABUF_ENABLE", 1); // Detect DMA-BUF support static ncclResult_t dmaBufSupported(struct ncclComm* comm) { - if (ncclParamDmaBufEnable() == 0 || comm->ncclNet->regMrDmaBuf == NULL) return ncclInternalError; + if (ncclParamDmaBufEnable() == 0 || comm->ncclNet->regMrDmaBuf == NULL || ncclCudaLibraryInit() != ncclSuccess) return ncclInternalError; #if CUDA_VERSION >= 11070 int flag = 0; CUdevice dev; @@ -330,7 +315,8 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { cudaGetDevice(&comm->cudaDev); NCCLCHECK(getBusId(comm->cudaDev, &comm->busId)); - TRACE(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx", comm, rank, ndev, comm->cudaDev, comm->busId); + comm->compCap = ncclCudaCompCap(); + TRACE(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx compCap %d", comm, rank, ndev, comm->cudaDev, comm->busId, comm->compCap); comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false; comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? 
true : false; @@ -360,11 +346,13 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { } static ncclResult_t devCommSetup(ncclComm_t comm) { - NCCLCHECK(ncclStrongStreamAcquireUncaptured(&comm->deviceStream)); - + ncclResult_t ret = ncclSuccess; int nRanks = comm->nRanks; - struct ncclDevCommAndChannels *devCommAndChans, tmpCommAndChans; - NCCLCHECK(ncclCudaCallocAsync(&devCommAndChans, 1, comm->deviceStream.cudaStream)); + struct ncclDevCommAndChannels tmpCommAndChans; + struct ncclDevCommAndChannels *devCommAndChans = NULL; + + NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->deviceStream), ret, fail); + NCCLCHECKGOTO(ncclCudaCallocAsync(&devCommAndChans, 1, comm->deviceStream.cudaStream), ret, fail); ncclCommPushCudaFree(comm, devCommAndChans); comm->devComm = &devCommAndChans->comm; tmpCommAndChans.comm.rank = comm->rank; @@ -384,18 +372,18 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { if (ncclGdrCopy != NULL && ncclParamGdrCopyFifoEnable() == 1) { // The workFifoHeap lives in GDR mapped CUDA memory. - NCCLCHECK(ncclGdrCudaCalloc(&comm->workFifoHeap, &comm->devWorkFifoHeap, comm->workFifoDepth, &comm->workFifoHeapGdrHandle)); + NCCLCHECKGOTO(ncclGdrCudaCalloc(&comm->workFifoHeap, &comm->devWorkFifoHeap, comm->workFifoDepth, &comm->workFifoHeapGdrHandle), ret, fail); ncclCommPushCudaGdrFree(comm, comm->workFifoHeapGdrHandle); } else { // The workFifoHeap lives in cudaHost memory. 
comm->workFifoHeapGdrHandle = nullptr; - NCCLCHECK(ncclCudaHostCalloc(&comm->workFifoHeap, comm->workFifoDepth)); + NCCLCHECKGOTO(ncclCudaHostCalloc(&comm->workFifoHeap, comm->workFifoDepth), ret, fail); ncclCommPushCudaHostFree(comm, comm->workFifoHeap); comm->devWorkFifoHeap = comm->workFifoHeap; } tmpCommAndChans.comm.workFifoHeap = comm->devWorkFifoHeap; - NCCLCHECK(ncclCudaHostCalloc(&comm->workFifoDone, MAXCHANNELS)); + NCCLCHECKGOTO(ncclCudaHostCalloc(&comm->workFifoDone, MAXCHANNELS), ret, fail); ncclCommPushCudaHostFree(comm, comm->workFifoDone); comm->workFifoSent = 0; comm->workFifoAckdMin = 0; @@ -410,14 +398,17 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { tmpCommAndChans.channels[c].workFifoDone = &comm->workFifoDone[c]; if (comm->channels[c].ring.userRanks != nullptr) { - NCCLCHECK(ncclCudaMemcpyAsync(tmpCommAndChans.channels[c].ring.userRanks, comm->channels[c].ring.userRanks, nRanks, comm->deviceStream.cudaStream)); + NCCLCHECKGOTO(ncclCudaMemcpyAsync(tmpCommAndChans.channels[c].ring.userRanks, comm->channels[c].ring.userRanks, nRanks, comm->deviceStream.cudaStream), ret, fail); } } - NCCLCHECK(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->deviceStream.cudaStream)); + NCCLCHECKGOTO(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->deviceStream.cudaStream), ret, fail); +exit: CUDACHECK(cudaStreamSynchronize(comm->deviceStream.cudaStream)); NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream)); - return ncclSuccess; + return ret; +fail: + goto exit; } // Pre-process the string so that running "strings" on the lib can quickly reveal the version. 
@@ -481,6 +472,8 @@ NCCL_PARAM(LlBuffSize, "LL_BUFFSIZE", -2); NCCL_PARAM(Ll128BuffSize, "LL128_BUFFSIZE", -2); NCCL_PARAM(P2pNetChunkSize, "P2P_NET_CHUNKSIZE", (1 << 17)); /* 128 kB */ +NCCL_PARAM(P2pPciChunkSize, "P2P_PCI_CHUNKSIZE", (1 << 17)); /* 128 kB */ +NCCL_PARAM(P2pNvlChunkSize, "P2P_NVL_CHUNKSIZE", (1 << 19)); /* 512 kB */ static ncclResult_t computeBuffSizes(struct ncclComm* comm) { int cpuArch, cpuVendor, cpuModel; @@ -495,7 +488,10 @@ static ncclResult_t computeBuffSizes(struct ncclComm* comm) { comm->buffSizes[p] = envs[p] != -2 ? envs[p] : defaults[p]; } - comm->p2pNetChunkSize = ncclParamP2pNetChunkSize(); + if (comm->nNodes > 1) comm->p2pChunkSize = ncclParamP2pNetChunkSize(); + else if (ncclTopoPathAllNVLink(comm->topo)) comm->p2pChunkSize = ncclParamP2pNvlChunkSize(); + else comm->p2pChunkSize = ncclParamP2pPciChunkSize(); + INFO(NCCL_INIT, "P2P Chunksize set to %d", comm->p2pChunkSize); return ncclSuccess; } @@ -504,90 +500,241 @@ NCCL_PARAM(CollNetNodeThreshold, "COLLNET_NODE_THRESHOLD", 2); NCCL_PARAM(NvbPreconnect, "NVB_PRECONNECT", 1); NCCL_PARAM(AllocP2pNetLLBuffers, "NCCL_ALLOC_P2P_NET_LL_BUFFERS", 0); +static ncclResult_t collNetTrySetup(ncclComm_t comm, struct ncclTopoGraph* collNetGraph) { + ncclResult_t ret = ncclSuccess; + int* heads = NULL; + int rank = comm->rank; + int collNetSetupFail = 0; + int highestTypes[NCCL_MAX_LOCAL_RANKS] = { TRANSPORT_P2P }; + // Find all head ranks + int nHeads = collNetGraph->nChannels; + int highestTransportType0, highestTransportType1; + char line[1024]; + + NCCLCHECKGOTO(ncclCalloc(&heads, nHeads), ret, fail); + // Head GPU index is always 0 + for (int c = 0; c < nHeads; c++) { + heads[c] = collNetGraph->intra[c * comm->localRanks + 0]; + } + + for (int c = 0; c < comm->nChannels; c++) { + struct ncclChannel* channel = comm->channels + c; + for (int h = 0; h < nHeads; h++) { + const int head = heads[h]; + collNetSetupFail = ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, 
collNetRecv); + if (!collNetSetupFail) collNetSetupFail = ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, collNetSend); + } + // Verify CollNet setup across ranks after trying the first channel + if (c == 0) { + NCCLCHECKGOTO(ncclTransportCollNetCheck(comm, collNetSetupFail), ret, fail); + } + } + // Verify CollNet setup across ranks after trying all channels + NCCLCHECKGOTO(ncclTransportCollNetCheck(comm, collNetSetupFail), ret, fail); + TRACE(NCCL_INIT, "rank %d Connected inter-node CollNet", rank); + + line[0] = '\0'; + for (int c = 0; c < comm->nChannels; c++) { + struct ncclTree* chain = &comm->channels[c].collnetChain; + snprintf(line + strlen(line), 1023 - strlen(line), " [%d] %d->%d->%d", + c, chain->down[0], rank, chain->up); + } + line[1023] = '\0'; + + INFO(NCCL_INIT, "Collnet Chains %s", line); + // Connect Collnet + chain + for (int c = 0; c < comm->nChannels; c++) { + struct ncclChannel* channel = comm->channels + c; + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->collnetChain.up, 1, channel->collnetChain.down, 0), ret, fail); + } + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, collNetGraph, 0), ret, fail); + for (int c = 0; c < comm->nChannels; c++) { + struct ncclChannel* channel = comm->channels + c; + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, channel->collnetChain.down, 1, &channel->collnetChain.up, 1), ret, fail); + } + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, collNetGraph, 1), ret, fail); + INFO(NCCL_INIT, "Connected collnet + chain"); + + // Connect intra-node CollNet + Direct + for (int c = 0; c < comm->nChannels; c++) { + struct ncclChannel* channelRecv = comm->channels + c; + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_DIRECT_ARITY, channelRecv->collnetDirect.up, NCCL_MAX_DIRECT_ARITY, channelRecv->collnetDirect.down, 0), ret, fail); + } + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, collNetGraph, 0, &highestTransportType0), ret, fail); + + for (int c = 0; c < comm->nChannels; c++) { + struct 
ncclChannel* channelSend = comm->channels + c; + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_DIRECT_ARITY, channelSend->collnetDirect.down, NCCL_MAX_DIRECT_ARITY, channelSend->collnetDirect.up, 1), ret, fail); + } + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, collNetGraph, 1, &highestTransportType1), ret, fail); + + // Exchange highest intra-node transport type among ranks + // because we need to know whether all ranks can p2p each other to determine whether we can directly read/write registered user buffer + comm->intraHighestTransportType = highestTypes[comm->localRank] = highestTransportType0 > highestTransportType1 ? highestTransportType0 : highestTransportType1; + NCCLCHECKGOTO(bootstrapIntraNodeAllGather(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, highestTypes, sizeof(int)), ret, fail); + for (int i = 0; i < comm->localRanks; i++) { + if (highestTypes[i] > comm->intraHighestTransportType) + comm->intraHighestTransportType = highestTypes[i]; + } + + INFO(NCCL_INIT, "rank %d Connected CollNet", rank); + +exit: + free(heads); + return ret; +fail: + ncclTransportCollNetFree(comm); + comm->collNetSupport = 0; + goto exit; +} + static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* commId) { // We use 2 AllGathers // 1. { peerInfo, comm, compCap} // 2. 
{ nChannels, graphInfo, topoRanks } - + ncclResult_t ret = ncclSuccess; int rank = comm->rank; int nranks = comm->nRanks; uint64_t commHash = getHash(commId->internal, NCCL_UNIQUE_ID_BYTES); + cpu_set_t affinitySave; + struct ncclTopoGraph ringGraph; + struct ncclTopoGraph treeGraph; + struct ncclTopoGraph collNetGraph; + + struct graphInfo { + int pattern; + int nChannels; + int sameChannels; + float bwIntra; + float bwInter; + int typeIntra; + int typeInter; + }; + + struct allGatherInfo { + int netDev; + int collNetSupport; + struct graphInfo tree; + struct graphInfo ring; + struct graphInfo collNet; + struct ncclTopoRanks topoRanks; + }; + + int nChannelsOrig; + struct allGatherInfo *allGather3Data = NULL; + struct ncclTopoRanks** allTopoRanks = NULL; + int *nodesFirstRank = NULL, *nodesTreePatterns = NULL; + int *rings = NULL; + int* nvbPeers = NULL; + struct ncclProxyConnector proxyConn; + int* pxnPeers = NULL; + TRACE(NCCL_INIT, "comm %p, commHash %lx, rank %d nranks %d - BEGIN", comm, commHash, rank, nranks); - NCCLCHECK(bootstrapInit(commId, comm)); + NCCLCHECKGOTO(bootstrapInit((struct ncclBootstrapHandle*)commId, comm), ret, fail); // AllGather1 - begin - NCCLCHECK(ncclCalloc(&comm->peerInfo, nranks+1)); // Extra rank to represent CollNet root - NCCLCHECK(fillInfo(comm, comm->peerInfo+rank, commHash)); - NCCLCHECK(bootstrapAllGather(comm->bootstrap, comm->peerInfo, sizeof(struct ncclPeerInfo))); + NCCLCHECKGOTO(ncclCalloc(&comm->peerInfo, nranks+1), ret, fail); // Extra rank to represent CollNet root + NCCLCHECKGOTO(fillInfo(comm, comm->peerInfo+rank, commHash), ret, fail); + NCCLCHECKGOTO(bootstrapAllGather(comm->bootstrap, comm->peerInfo, sizeof(struct ncclPeerInfo)), ret, fail); for (int i = 0; i < nranks; i++) { if ((i != rank) && (comm->peerInfo[i].hostHash == comm->peerInfo[rank].hostHash) && (comm->peerInfo[i].busId == comm->peerInfo[rank].busId)) { WARN("Duplicate GPU detected : rank %d and rank %d both on CUDA device %lx", rank, i, 
comm->peerInfo[rank].busId); - return ncclInvalidUsage; + ret = ncclInvalidUsage; + goto fail; } } - // AllGather1 - end + do { + // Compute intra-process ranks + int intraProcRank0 = -1, intraProcRank = -1, intraProcRanks = 0; + for (int i = 0; i < nranks; i++) { + if ((comm->peerInfo[i].hostHash == comm->peerInfo[rank].hostHash) + && (comm->peerInfo[i].pidHash == comm->peerInfo[rank].pidHash)) { + // Rank is in same process + if (intraProcRanks == 0) intraProcRank0 = i; + if (i == rank) intraProcRank = intraProcRanks; + intraProcRanks++; + if (intraProcRank0 == rank && rank != i) { + comm->peerInfo[i].comm->intraNext = comm->intraNext; + comm->intraNext = comm->peerInfo[i].comm; + } + } + } + TRACE(NCCL_INIT,"pidHash[%d] %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d", + rank, comm->peerInfo[rank].pidHash, intraProcRank, intraProcRanks, intraProcRank0); + if (intraProcRank == -1 || intraProcRank0 == -1 || comm->peerInfo[intraProcRank0].comm == NULL) { + WARN("Failed to determine intra proc ranks rank %d hostHash %lx pidHash %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d", + rank, comm->peerInfo[rank].hostHash, comm->peerInfo[rank].pidHash, + intraProcRank, intraProcRanks, intraProcRank0); + ret = ncclInternalError; + goto fail; + } + struct ncclComm* comm0 = comm->peerInfo[intraProcRank0].comm; + assert(intraProcRank==0 ? 
comm==comm0 : true); + comm->intraComm0 = comm0; + comm->intraRank = intraProcRank; + comm->intraRanks = intraProcRanks; + comm->intraBarrierPhase = 0; + comm->intraBarrierCounter = 0; + comm->intraBarrierGate = 0; + } while(0); + // Topo detection / System graph creation - NCCLCHECK(ncclTopoGetSystem(comm, &comm->topo)); + NCCLCHECKGOTO(ncclTopoGetSystem(comm, &comm->topo), ret, fail); // Compute paths between GPUs and NICs - NCCLCHECK(ncclTopoComputePaths(comm->topo, comm)); + NCCLCHECKGOTO(ncclTopoComputePaths(comm->topo, comm), ret, fail); // Remove inaccessible GPUs and unused NICs - NCCLCHECK(ncclTopoTrimSystem(comm->topo, comm)); + NCCLCHECKGOTO(ncclTopoTrimSystem(comm->topo, comm), ret, fail); // Recompute paths after trimming - NCCLCHECK(ncclTopoComputePaths(comm->topo, comm)); + NCCLCHECKGOTO(ncclTopoComputePaths(comm->topo, comm), ret, fail); // Init search - NCCLCHECK(ncclTopoSearchInit(comm->topo)); + NCCLCHECKGOTO(ncclTopoSearchInit(comm->topo), ret, fail); // Print final topology - NCCLCHECK(ncclTopoPrint(comm->topo)); + NCCLCHECKGOTO(ncclTopoPrint(comm->topo), ret, fail); // Set Affinity to a CPU local the our GPU, so that all memory we allocate // on the host is local. 
- NCCLCHECK(ncclTopoGetCpuAffinity(comm->topo, comm->rank, &comm->cpuAffinity)); - cpu_set_t affinitySave; + NCCLCHECKGOTO(ncclTopoGetCpuAffinity(comm->topo, comm->rank, &comm->cpuAffinity), ret, fail); if (CPU_COUNT(&comm->cpuAffinity)) { sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave); sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity); } - ncclResult_t ret; // Launch proxy service thread - NCCLCHECK(ncclProxyCreate(comm)); + NCCLCHECKGOTO(ncclProxyCreate(comm), ret, fail); // Get rings and trees - struct ncclTopoGraph ringGraph; ringGraph.id = 0; ringGraph.pattern = NCCL_TOPO_PATTERN_RING; ringGraph.collNet = 0; ringGraph.minChannels = 1; ringGraph.maxChannels = MAXCHANNELS/2; - NCCLCHECK(ncclTopoCompute(comm->topo, &ringGraph)); - NCCLCHECK(ncclTopoPrintGraph(comm->topo, &ringGraph)); + NCCLCHECKGOTO(ncclTopoCompute(comm->topo, &ringGraph), ret, fail); + NCCLCHECKGOTO(ncclTopoPrintGraph(comm->topo, &ringGraph), ret, fail); - struct ncclTopoGraph treeGraph; treeGraph.id = 1; treeGraph.pattern = NCCL_TOPO_PATTERN_BALANCED_TREE; treeGraph.collNet = 0; treeGraph.minChannels = 1; treeGraph.maxChannels = ringGraph.nChannels; - NCCLCHECK(ncclTopoCompute(comm->topo, &treeGraph)); - NCCLCHECK(ncclTopoPrintGraph(comm->topo, &treeGraph)); + NCCLCHECKGOTO(ncclTopoCompute(comm->topo, &treeGraph), ret, fail); + NCCLCHECKGOTO(ncclTopoPrintGraph(comm->topo, &treeGraph), ret, fail); - struct ncclTopoGraph collNetGraph; collNetGraph.id = 2; collNetGraph.pattern = NCCL_TOPO_PATTERN_TREE; collNetGraph.collNet = 1; collNetGraph.minChannels = collNetGraph.maxChannels = ringGraph.nChannels; - NCCLCHECK(ncclTopoCompute(comm->topo, &collNetGraph)); - NCCLCHECK(ncclTopoPrintGraph(comm->topo, &collNetGraph)); + NCCLCHECKGOTO(ncclTopoCompute(comm->topo, &collNetGraph), ret, fail); + NCCLCHECKGOTO(ncclTopoPrintGraph(comm->topo, &collNetGraph), ret, fail); // Initialize num P2P LL buffers for this communicator comm->allocP2pNetLLBuffers = ncclParamAllocP2pNetLLBuffers() == 
1; if (comm->rank == ncclParamGraphDumpFileRank()) { struct ncclTopoGraph* graphs[3] = { &ringGraph, &treeGraph, &collNetGraph }; - NCCLCHECK(ncclTopoDumpGraphs(comm->topo, 3, graphs)); + NCCLCHECKGOTO(ncclTopoDumpGraphs(comm->topo, 3, graphs), ret, fail); } // Determine local CollNet support before all-gather @@ -603,28 +750,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm if (comm->collNetSupport == 1 && collNetGraph.nChannels <= 0) comm->collNetSupport = 0; // AllGather3 - begin - struct ncclGraphInfo { - int pattern; - int nChannels; - int sameChannels; - float bwIntra; - float bwInter; - int typeIntra; - int typeInter; - }; - - struct { - int netDev; - int collNetSupport; - struct ncclGraphInfo tree; - struct ncclGraphInfo ring; - struct ncclGraphInfo collNet; - struct ncclTopoRanks topoRanks; - } *allGather3Data; - - NCCLCHECK(ncclCalloc(&allGather3Data, nranks)); - - NCCLCHECK(ncclTopoGetLocalNet(comm->topo, rank, &allGather3Data[rank].netDev)); + NCCLCHECKGOTO(ncclCalloc(&allGather3Data, nranks), ret, fail); + NCCLCHECKGOTO(ncclTopoGetLocalNet(comm->topo, rank, &allGather3Data[rank].netDev), ret, fail); allGather3Data[rank].tree.pattern = treeGraph.pattern; allGather3Data[rank].tree.nChannels = treeGraph.nChannels; allGather3Data[rank].tree.sameChannels = treeGraph.sameChannels; @@ -649,15 +776,14 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm allGather3Data[rank].collNetSupport = comm->collNetSupport; comm->nChannels = std::min(treeGraph.nChannels, ringGraph.nChannels); - NCCLCHECK(ncclTopoPreset(comm, &treeGraph, &ringGraph, &collNetGraph, &allGather3Data[rank].topoRanks)); + NCCLCHECKGOTO(ncclTopoPreset(comm, &treeGraph, &ringGraph, &collNetGraph, &allGather3Data[rank].topoRanks), ret, fail); - NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGather3Data, sizeof(*allGather3Data))); + NCCLCHECKGOTO(bootstrapAllGather(comm->bootstrap, allGather3Data, sizeof(*allGather3Data)), ret, 
fail); // Determine nNodes, firstRanks, ... - int *nodesFirstRank, *nodesTreePatterns; - NCCLCHECK(ncclCalloc(&nodesFirstRank, nranks)); - NCCLCHECK(ncclCalloc(&nodesTreePatterns, nranks)); - NCCLCHECK(ncclCalloc(&comm->rankToNode, comm->nRanks)); + NCCLCHECKGOTO(ncclCalloc(&nodesFirstRank, nranks), ret, fail); + NCCLCHECKGOTO(ncclCalloc(&nodesTreePatterns, nranks), ret, fail); + NCCLCHECKGOTO(ncclCalloc(&comm->rankToNode, comm->nRanks), ret, fail); for (int r=0; rrankToNode[r] = node; } // Now that we know nNodes, alloc nodeRanks and compute localRanks for each node - NCCLCHECK(ncclCalloc(&comm->nodeRanks, comm->nNodes)); - NCCLCHECK(ncclCalloc(&comm->rankToLocalRank, comm->nRanks)); + NCCLCHECKGOTO(ncclCalloc(&comm->nodeRanks, comm->nNodes), ret, fail); + NCCLCHECKGOTO(ncclCalloc(&comm->rankToLocalRank, comm->nRanks), ret, fail); for (int r=0; rnRanks; r++) { int node = comm->rankToNode[r]; comm->rankToLocalRank[r] = comm->nodeRanks[node].localRanks; @@ -680,7 +806,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm } // Allocate ranks arrays for each node for (int n=0; nnNodes; n++) { - NCCLCHECK(ncclCalloc(&comm->nodeRanks[n].localRankToRank, comm->nodeRanks[n].localRanks)); + NCCLCHECKGOTO(ncclCalloc(&comm->nodeRanks[n].localRankToRank, comm->nodeRanks[n].localRanks), ret, fail); comm->maxLocalRanks = std::max(comm->maxLocalRanks, comm->nodeRanks[n].localRanks); comm->nodeRanks[n].localRanks = 0; } @@ -700,12 +826,12 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm WARN("Failed to determine local ranks rank %d hostHash %lx pidHash %lx localRank %d localRanks %d localRank0 %d", rank, comm->peerInfo[rank].hostHash, comm->peerInfo[rank].pidHash, comm->localRank, comm->localRanks, comm->localRankToRank[0]); - return ncclInternalError; + ret = ncclInternalError; + goto fail; } - int nChannelsOrig = comm->nChannels; - struct ncclTopoRanks** allTopoRanks; - NCCLCHECK(ncclCalloc(&allTopoRanks, 
comm->nRanks)); + nChannelsOrig = comm->nChannels; + NCCLCHECKGOTO(ncclCalloc(&allTopoRanks, comm->nRanks), ret, fail); for (int i=0; ipeerInfo[i].netDev = allGather3Data[i].netDev; allTopoRanks[i] = &allGather3Data[i].topoRanks; @@ -754,15 +880,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm } } - int *rings; - NCCLCHECK(ncclCalloc(&rings, nranks*MAXCHANNELS)); - NCCLCHECK(ncclTopoPostset(comm, nodesFirstRank, nodesTreePatterns, allTopoRanks, rings, &collNetGraph)); - - free(allTopoRanks); - free(nodesTreePatterns); - free(nodesFirstRank); - free(allGather3Data); - + NCCLCHECKGOTO(ncclCalloc(&rings, nranks*MAXCHANNELS), ret, fail); + NCCLCHECKGOTO(ncclTopoPostset(comm, nodesFirstRank, nodesTreePatterns, allTopoRanks, rings, &collNetGraph), ret, fail); // AllGather3 - end TRACE(NCCL_INIT, "rank %d nranks %d - BUILT %d TREES/RINGS", rank, nranks, comm->nChannels); @@ -778,110 +897,31 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm line[1023] = '\0'; INFO(NCCL_INIT, "Trees%s", line); - NCCLCHECK(computeBuffSizes(comm)); + NCCLCHECKGOTO(computeBuffSizes(comm), ret, fail); // Connect with prev/next for each ring for (int c=0; cnChannels; c++) { struct ncclChannel* channel = comm->channels+c; - NCCLCHECKGOTO(setupChannel(comm, c, rank, nranks, rings+c*nranks), ret, affinity_restore); + NCCLCHECKGOTO(setupChannel(comm, c, rank, nranks, rings+c*nranks), ret, fail); if (comm->nRanks == 1) continue; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->ring.prev, 1, &channel->ring.next, 0), ret, affinity_restore); + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->ring.prev, 1, &channel->ring.next, 0), ret, fail); } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &ringGraph, 0), ret, affinity_restore); - free(rings); + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &ringGraph, 0), ret, fail); INFO(NCCL_INIT, "Connected all rings"); // Connect Trees for (int c=0; cnChannels; c++) { struct 
ncclChannel* channel = comm->channels+c; if (comm->nRanks == 1) continue; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_TREE_ARITY, channel->tree.down, 1, &channel->tree.up, 0), ret, affinity_restore); - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->tree.up, NCCL_MAX_TREE_ARITY, channel->tree.down, 0), ret, affinity_restore); + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_TREE_ARITY, channel->tree.down, 1, &channel->tree.up, 0), ret, fail); + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->tree.up, NCCL_MAX_TREE_ARITY, channel->tree.down, 0), ret, fail); } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &treeGraph, 0), ret, affinity_restore); + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &treeGraph, 0), ret, fail); INFO(NCCL_INIT, "Connected all trees"); // Check if we can setup CollNet - if (comm->collNetSupport > 0) { - int collNetSetupFail = 0; - int highestTypes[NCCL_MAX_LOCAL_RANKS] = {TRANSPORT_P2P}; - // Find all head ranks - int nHeads = collNetGraph.nChannels; - int *heads; - NCCLCHECK(ncclCalloc(&heads, nHeads)); - // Head GPU index is always 0 - for (int c=0; clocalRanks+0]; - } - for (int c=0; cnChannels; c++) { - struct ncclChannel* channel = comm->channels+c; - for (int h=0; hcollNetSupport > 0) collNetTrySetup(comm, &collNetGraph); - char line[1024]; - line[0]='\0'; - for (int c=0; cnChannels; c++) { - struct ncclTree* chain = &comm->channels[c].collnetChain; - snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d->%d->%d", - c, chain->down[0], rank, chain->up); - } - line[1023] = '\0'; - INFO(NCCL_INIT, "Collnet Chains %s", line); - // Connect Collnet + chain - for (int c=0; cnChannels; c++) { - struct ncclChannel* channel = comm->channels+c; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->collnetChain.up, 1, channel->collnetChain.down, 0), ret, collnet_cleanup); - } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 0), ret, collnet_cleanup); - for (int c=0; cnChannels; c++) { - 
struct ncclChannel* channel = comm->channels+c; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, channel->collnetChain.down, 1, &channel->collnetChain.up, 1), ret, collnet_cleanup); - } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 1), ret, collnet_cleanup); - INFO(NCCL_INIT, "Connected collnet + chain"); - - // Connect intra-node CollNet + Direct - int highestTransportType0, highestTransportType1; - for (int c=0; cnChannels; c++) { - struct ncclChannel* channelRecv = comm->channels+c; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_DIRECT_ARITY, channelRecv->collnetDirect.up, NCCL_MAX_DIRECT_ARITY, channelRecv->collnetDirect.down, 0), ret, collnet_cleanup); - } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 0, &highestTransportType0), ret, collnet_cleanup); - for (int c=0; cnChannels; c++) { - struct ncclChannel* channelSend = comm->channels+c; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_DIRECT_ARITY, channelSend->collnetDirect.down, NCCL_MAX_DIRECT_ARITY, channelSend->collnetDirect.up, 1), ret, collnet_cleanup); - } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 1, &highestTransportType1), ret, collnet_cleanup); - - // Exchange highest intra-node transport type among ranks - // because we need to know whether all ranks can p2p each other to determine whether we can directly read/write registered user buffer - comm->intraHighestTransportType = highestTypes[comm->localRank] = highestTransportType0 > highestTransportType1 ? 
highestTransportType0 : highestTransportType1; - NCCLCHECK(bootstrapIntraNodeAllGather(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, highestTypes, sizeof(int))); - for (int i=0; ilocalRanks; i++) { - if (highestTypes[i] > comm->intraHighestTransportType) - comm->intraHighestTransportType = highestTypes[i]; - } - INFO(NCCL_INIT, "rank %d Connected CollNet", rank); - -collnet_cleanup: - free(heads); - if (ret != ncclSuccess) { - NCCLCHECK(ncclTransportCollNetFree(comm)); - comm->collNetSupport = 0; - ret = ncclSuccess; - } - } TRACE(NCCL_INIT, "rank %d nranks %d - CONNECTED %d RINGS AND TREES", rank, nranks, comm->nChannels); // Compute time models for algorithm and protocol combinations @@ -892,11 +932,11 @@ collnet_cleanup: minCompCap = std::min(comm->peerInfo[i].cudaCompCap, minCompCap); maxCompCap = std::max(comm->peerInfo[i].cudaCompCap, maxCompCap); } - NCCLCHECK(ncclTopoTuneModel(comm, minCompCap, maxCompCap, &treeGraph, &ringGraph, &collNetGraph)); + NCCLCHECKGOTO(ncclTopoTuneModel(comm, minCompCap, maxCompCap, &treeGraph, &ringGraph, &collNetGraph), ret, fail); } while(0); // Compute nChannels per peer for p2p - NCCLCHECK(ncclTopoComputeP2pChannels(comm)); + NCCLCHECKGOTO(ncclTopoComputeP2pChannels(comm), ret, fail); do { // Setup p2p structures in comm->tasks struct ncclTasks* tasks = &comm->tasks; @@ -947,80 +987,40 @@ collnet_cleanup: if (ncclParamNvbPreconnect()) { // Connect p2p when using NVB path int nvbNpeers; - int* nvbPeers; - NCCLCHECK(ncclTopoGetNvbGpus(comm->topo, comm->rank, &nvbNpeers, &nvbPeers)); + NCCLCHECKGOTO(ncclTopoGetNvbGpus(comm->topo, comm->rank, &nvbNpeers, &nvbPeers), ret, fail); for (int r=0; rp2pnChannelsPerPeer; c++) { - NCCLCHECK(ncclChannelCompute(comm, peer, c, ncclFuncSend, &channelId)); + NCCLCHECKGOTO(ncclChannelCompute(comm, peer, c, ncclFuncSend, &channelId), ret, fail); if (comm->channels[channelId].peers[peer].send[1].connected == 0) { comm->connectSend[peer] |= (1UL<p2pnChannelsPerPeer; 
c++) { - NCCLCHECK(ncclChannelCompute(comm, peer, c, ncclFuncRecv, &channelId)); + NCCLCHECKGOTO(ncclChannelCompute(comm, peer, c, ncclFuncRecv, &channelId), ret, fail); if (comm->channels[channelId].peers[peer].recv[1].connected == 0) { comm->connectRecv[peer] |= (1UL<rank, &proxyConn)); - NCCLCHECK(ncclProxyCall(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0)); + NCCLCHECKGOTO(ncclProxyConnect(comm, TRANSPORT_NET, 1, comm->rank, &proxyConn), ret, fail); + NCCLCHECKGOTO(ncclProxyCall(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0), ret, fail); // Then to remote ones when using PXN if (ncclPxnDisable(comm) == 0) { int nranks; - int* pxnPeers; - NCCLCHECK(ncclTopoGetPxnRanks(comm, &pxnPeers, &nranks)); + NCCLCHECKGOTO(ncclTopoGetPxnRanks(comm, &pxnPeers, &nranks), ret, fail); for (int r=0; rp2pnChannels, sizeof(int), NULL, 0)); + NCCLCHECKGOTO(ncclProxyConnect(comm, TRANSPORT_NET, 1, pxnPeers[r], &proxyConn), ret, fail); + NCCLCHECKGOTO(ncclProxyCall(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0), ret, fail); } - free(pxnPeers); } - do { - // Compute intra-process ranks - int intraProcRank0 = -1, intraProcRank = -1, intraProcRanks = 0; - for (int i = 0; i < nranks; i++) { - if ((comm->peerInfo[i].hostHash == comm->peerInfo[rank].hostHash) - && (comm->peerInfo[i].pidHash == comm->peerInfo[rank].pidHash)) { - // Rank is in same process - if (intraProcRanks == 0) intraProcRank0 = i; - if (i == rank) intraProcRank = intraProcRanks; - intraProcRanks++; - if (intraProcRank0 == rank && rank != i) { - comm->peerInfo[i].comm->intraNext = comm->intraNext; - comm->intraNext = comm->peerInfo[i].comm; - } - } - } - TRACE(NCCL_INIT,"pidHash[%d] %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d", - rank, comm->peerInfo[rank].pidHash, intraProcRank, intraProcRanks, intraProcRank0); - if (intraProcRank == -1 || intraProcRank0 == -1 || comm->peerInfo[intraProcRank0].comm == NULL) { - 
WARN("Failed to determine intra proc ranks rank %d hostHash %lx pidHash %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d", - rank, comm->peerInfo[rank].hostHash, comm->peerInfo[rank].pidHash, - intraProcRank, intraProcRanks, intraProcRank0); - return ncclInternalError; - } - struct ncclComm* comm0 = comm->peerInfo[intraProcRank0].comm; - assert(intraProcRank==0 ? comm==comm0 : true); - comm->intraComm0 = comm0; - comm->intraRefs = intraProcRank==0 ? intraProcRanks : 0; - comm->intraRank = intraProcRank; - comm->intraRanks = intraProcRanks; - comm->intraBarrierPhase = 0; - comm->intraBarrierCounter = 0; - comm->intraBarrierGate = 0; - } while(0); - if (comm->intraRank == 0) { // Load ncclParamLaunchMode char* str = getenv("NCCL_LAUNCH_MODE"); enum ncclLaunchMode mode, modeOld; @@ -1037,22 +1037,31 @@ collnet_cleanup: } } - NCCLCHECKGOTO(devCommSetup(comm), ret, affinity_restore); + // Call devCommSetup before the last barrier, making sure we don't have a thread running in front and starting to + // launch NCCL kernels before all cuda mem allocation is complete. That could cause a deadlock. + NCCLCHECKGOTO(devCommSetup(comm), ret, fail); /* Local intra-node barrier */ - NCCLCHECK(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0])); - - // Unlink proxy shm to make sure it will be properly cleaned up. - NCCLCHECK(ncclProxyShmUnlink(comm)); + NCCLCHECKGOTO(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0]), ret, fail); // We should have allocated all buffers, collective fifos, ... we can // restore the affinity. 
-affinity_restore: - if (CPU_COUNT(&comm->cpuAffinity)) sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave); - if (ret != ncclSuccess) return ret; - TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks); - return ncclSuccess; + +exit: + if (CPU_COUNT(&comm->cpuAffinity)) sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave); + // Unlink proxy shm to make sure it will be properly cleaned up. + ncclProxyShmUnlink(comm); + free(allTopoRanks); + free(nodesTreePatterns); + free(nodesFirstRank); + free(allGather3Data); + free(rings); + free(nvbPeers); + free(pxnPeers); + return ret; +fail: + goto exit; } NCCL_PARAM(SetStackSize, "SET_STACK_SIZE", 0); @@ -1080,25 +1089,25 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { int cudaDev = job->cudaDev; ncclResult_t res = ncclSuccess; - CUDACHECK(cudaSetDevice(cudaDev)); + CUDACHECKGOTO(cudaSetDevice(cudaDev), res, fail); // Set the maximum kernel stack size of all kernels to avoid // a CUDA memory reconfig on load (c.f. 
NVSHMEM issue) if (maxLocalSizeBytes > 0 && ncclParamSetStackSize() == 1) { TRACE(NCCL_INIT, "Setting cudaLimitStackSize to %zi", maxLocalSizeBytes); CUDACHECKIGNORE(cudaDeviceSetLimit(cudaLimitStackSize, maxLocalSizeBytes)); } - NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank), res, cleanup); - NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup); + NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank), res, fail); + NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, fail); // update communicator state comm->initState = ncclSuccess; - INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId); - TRACE_CALL("ncclCommInitRank(%p,%d,0x%llx,%d,%d)", *newcomm, nranks, (unsigned long long)hashUniqueId(commId), myrank, (*newcomm)->cudaDev); - return ncclSuccess; -cleanup: - comm->initState = res; + INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx commId 0x%llx - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId, (unsigned long long)hashUniqueId(commId)); +exit: return res; +fail: + comm->initState = res; + goto exit; } static ncclResult_t parseCommConfig(ncclComm_t comm, ncclConfig_t *config) { @@ -1122,13 +1131,13 @@ static void ncclCommInitRankUndo(struct ncclAsyncJob* job_) { } static ncclResult_t ncclCommInitRankDev(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank, int cudaDev, ncclConfig_t *config) { - ncclResult_t res; + ncclResult_t res = ncclSuccess; ncclComm_t comm = NULL; struct ncclCommInitRankAsyncJob *job = NULL; char* env = getenv("NCCL_COMM_ID"); if (env && myrank == 0) { INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", env); - NCCLCHECKGOTO(bootstrapCreateRoot(&commId, true), res, fail); + NCCLCHECKGOTO(bootstrapCreateRoot((struct ncclBootstrapHandle*)&commId, true), res, fail); } NCCLCHECKGOTO(ncclInit(), res, fail); @@ -1146,8 +1155,6 @@ static ncclResult_t 
ncclCommInitRankDev(ncclComm_t* newcomm, int nranks, ncclUni NCCLCHECKGOTO(ncclCalloc(&comm, 1), res, fail); NCCLCHECKGOTO(ncclCudaHostCalloc((uint32_t**)&comm->abortFlag, 1), res, fail); - // set up comm state and abortFlag only - *comm->abortFlag = 0; NCCLCHECKGOTO(parseCommConfig(comm, config), res, fail); /* start with ncclInternalError and will be changed to ncclSuccess if init succeeds. */ comm->initState = ncclInternalError; @@ -1164,28 +1171,53 @@ static ncclResult_t ncclCommInitRankDev(ncclComm_t* newcomm, int nranks, ncclUni exit: return ncclGroupErrCheck(res); fail: + if (job) free(job); + if (comm) { + if (comm->abortFlag) ncclCudaHostFree((void *)comm->abortFlag); + free(comm); + } + if (newcomm) *newcomm = NULL; goto exit; } +struct NvtxParamsCommInitRank +{ + int rank; + int nranks; + int cudaDev; +}; +constexpr nvtxPayloadSchemaEntry_t CommInitRankSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Rank"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "No. of ranks", nullptr, 0, offsetof(NvtxParamsCommInitRank, nranks)}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "CUDA device", nullptr, 0, offsetof(NvtxParamsCommInitRank, cudaDev)}, +}; + NCCL_API(ncclResult_t, ncclCommInitRank, ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank); ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank) { - NVTX3_FUNC_RANGE_IN(nccl_domain); - // Load the CUDA driver and dlsym hooks (can fail on old drivers) (void)ncclCudaLibraryInit(); int cudaDev; CUDACHECK(cudaGetDevice(&cudaDev)); + + NvtxParamsCommInitRank payload{myrank, nranks, cudaDev}; + NVTX3_FUNC_WITH_PARAMS(CommInitRank, CommInitRankSchema, payload) + NCCLCHECK(ncclCommInitRankDev(newcomm, nranks, commId, myrank, cudaDev, NULL)); return ncclSuccess; } NCCL_API(ncclResult_t, ncclCommInitAll, ncclComm_t* comms, int ndev, const int* devlist); ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) { - NVTX3_FUNC_RANGE_IN(nccl_domain); ncclResult_t ret = 
ncclSuccess; int totalnDev; int *gpuFlags = NULL; + + constexpr nvtxPayloadSchemaEntry_t CommInitAllSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "No. of devices"} + }; + NVTX3_FUNC_WITH_PARAMS(CommInitAll, CommInitAllSchema, ndev) + // Load the CUDA driver and dlsym hooks (can fail on old drivers) (void)ncclCudaLibraryInit(); @@ -1297,9 +1329,8 @@ static ncclResult_t commDestroySync(struct ncclAsyncJob* job_) { struct ncclCommFinalizeAsyncJob* job = (struct ncclCommFinalizeAsyncJob*) job_; ncclComm_t comm = job->comm; int savedDevice; - CUDACHECK(cudaGetDevice(&savedDevice)); int commDevice = comm->cudaDev; - ncclResult_t ret; + ncclResult_t ret = ncclSuccess; CUDACHECKGOTO(cudaGetDevice(&savedDevice), ret, fail); if (savedDevice != commDevice) { @@ -1322,6 +1353,7 @@ static ncclResult_t commDestroySync(struct ncclAsyncJob* job_) { CUDACHECKGOTO(cudaSetDevice(savedDevice), ret, fail); } + comm->finalizeCalled = true; exit: return ret; fail: @@ -1350,7 +1382,6 @@ static ncclResult_t commFinalize(ncclComm_t comm, bool userCalled) { ncclResult_t ret = ncclSuccess; struct ncclCommFinalizeAsyncJob *job = NULL; - comm->finalizeCalled = true; /* launch async thread to finalize comm. */ NCCLCHECKGOTO(ncclCalloc(&job, 1), ret, fail); job->comm = comm; @@ -1412,9 +1443,9 @@ static ncclResult_t commReclaim(ncclComm_t comm) { NCCLCHECKGOTO(commFinalize(comm, false), ret, fail); } - if (comm->initState != ncclSuccess) { - /* if init errors happen, no finalize thread should have been launched. Main thread can reclaim - * everything since no NCCL kernel was issued. */ + if (comm->intraComm0 == NULL) { + /* if init errors happen and comm->intraComm0 == NULL, no proxy connection is built up, and no finalize thread + * have been launched. Main thread can reclaim everything since no NCCL kernel was issued. 
*/ struct ncclCommFinalizeAsyncJob job; job.comm = comm; @@ -1424,6 +1455,10 @@ static ncclResult_t commReclaim(ncclComm_t comm) { WARN("commReclaim: comm %p (rank = %d) in abort, error %d", comm, curRank, ret); } + if ((ret = ncclProxyDestroy(comm)) != ncclSuccess) { + WARN("commReclaim: comm %p (rank = %d) destroys proxy resource error %d", comm, curRank, ret); + } + if ((ret = commCleanup(comm)) != ncclSuccess) { WARN("commReclaim: cleanup comm %p rank %d failed in destroy/abort, error %d", comm, curRank, ret); } @@ -1439,18 +1474,51 @@ static ncclResult_t commReclaim(ncclComm_t comm) { ncclComm_t curIntraComm; ncclComm_t nextIntraComm = intracomm0; + /* this is the last call to ncclCommDestroy/Abort, we need to make sure all comms + * in the process have been finalized before we free local resources. */ while (nextIntraComm) { curIntraComm = nextIntraComm; curRank = curIntraComm->rank; nextIntraComm = nextIntraComm->intraNext; - if (comm->finalizeCalled == false) { + if (curIntraComm->finalizeCalled == false) { struct ncclCommFinalizeAsyncJob job; job.comm = curIntraComm; /* every comm aborts, commDestroySync should not be blocked. */ if ((ret = commDestroySync((struct ncclAsyncJob*) &job)) != ncclSuccess) WARN("commReclaim: comm %p (rank = %d) in abort, error %d", curIntraComm, curRank, ret); } + } + + /* ncclProxyDestroy() loop must be put after commDestroySync() loop. Namely, you cannot do: + * while(...) { + * commDestroySync(...); + * ncclProxyDestroy(...); + * } + * Considering one process multi-gpu case, we must guarantee all kernels are complete before + * we free proxy resources; otherwise, we will face invalid memory issues where proxy connection + * and related intermediate memory from one rank are freed but other ranks are still using it. 
+ * This is not a problem for multi-process case, since intermediate memory is opened by CUDA IPC + * or mmap where memory free is guarded by CUDA driver and operating system, so we will not have + * invalid memory access issue. */ + nextIntraComm = intracomm0; + while (nextIntraComm) { + curIntraComm = nextIntraComm; + curRank = curIntraComm->rank; + nextIntraComm = nextIntraComm->intraNext; + + /* free intraprocess proxy resources. */ + if ((ret = ncclProxyDestroy(curIntraComm)) != ncclSuccess) { + WARN("commReclaim: comm %p (rank = %d) destroys proxy resource error %d", curIntraComm, curRank, ret); + } + } + + /* free local resources. */ + nextIntraComm = intracomm0; + while (nextIntraComm) { + curIntraComm = nextIntraComm; + curRank = curIntraComm->rank; + nextIntraComm = nextIntraComm->intraNext; if ((ret = commCleanup(curIntraComm)) != ncclSuccess) { WARN("commReclaim: cleanup comm %p rank %d failed in destroy/abort, error %d", curIntraComm, curRank, ret); @@ -1467,11 +1535,16 @@ fail: NCCL_API(ncclResult_t, ncclCommDestroy, ncclComm_t comm); ncclResult_t ncclCommDestroy(ncclComm_t comm) { - NVTX3_FUNC_RANGE_IN(nccl_domain); - if (comm == NULL) + if (comm == NULL) { + NVTX3_FUNC_RANGE_IN(nccl_domain); return ncclSuccess; + } int rank = comm->rank, nranks = comm->nRanks, cudaDev = comm->cudaDev; + + NvtxParamsCommInitRank payload{rank, nranks, cudaDev}; + NVTX3_FUNC_WITH_PARAMS(CommDestroy, CommInitRankSchema, payload) + int64_t busId = comm->busId; TRACE(NCCL_INIT, "comm %p rank %d nRanks %d cudaDev %d busId %lx", comm, rank, nranks, cudaDev, busId); // Try and prevent a double free of the comm struct (user error) @@ -1491,11 +1564,16 @@ ncclResult_t ncclCommDestroy(ncclComm_t comm) { NCCL_API(ncclResult_t, ncclCommAbort, ncclComm_t comm); ncclResult_t ncclCommAbort(ncclComm_t comm) { - NVTX3_FUNC_RANGE_IN(nccl_domain); - if (comm == NULL) + if (comm == NULL) { + NVTX3_FUNC_RANGE_IN(nccl_domain); return ncclSuccess; + } int rank = comm->rank, nranks = 
comm->nRanks, cudaDev = comm->cudaDev; + + NvtxParamsCommInitRank payload{rank, nranks, cudaDev}; + NVTX3_FUNC_WITH_PARAMS(CommAbort, CommInitRankSchema, payload) + int64_t busId = comm->busId; TRACE(NCCL_INIT, "comm %p rank %d nRanks %d cudaDev %d busId %lx", comm, rank, nranks, cudaDev, busId); diff --git a/src/init_nvtx.cc b/src/init_nvtx.cc new file mode 100644 index 0000000..44face6 --- /dev/null +++ b/src/init_nvtx.cc @@ -0,0 +1,26 @@ +#include "nccl.h" +#include "nvtx.h" + +static constexpr const nvtxPayloadEnum_t NvtxEnumRedSchema[] = { + {"Sum", ncclSum}, + {"Product", ncclProd}, + {"Max", ncclMax}, + {"Min", ncclMin}, + {"Avg", ncclAvg} +}; + +// Must be called before the first call to any reduction operation. +void initNvtxRegisteredEnums() { + // Register schemas and strings + constexpr const nvtxPayloadEnumAttr_t eAttr { + .fieldMask = NVTX_PAYLOAD_ENUM_ATTR_ENTRIES | NVTX_PAYLOAD_ENUM_ATTR_NUM_ENTRIES | + NVTX_PAYLOAD_ENUM_ATTR_SIZE | NVTX_PAYLOAD_ENUM_ATTR_SCHEMA_ID, + .name = NULL, + .entries = NvtxEnumRedSchema, + .numEntries = std::extent::value, + .sizeOfEnum = sizeof(ncclRedOp_t), + .schemaId = NVTX_PAYLOAD_ENTRY_NCCL_REDOP + }; + + nvtxPayloadEnumRegister(nvtx3::domain::get(), &eAttr); +} diff --git a/src/misc/cudawrap.cc b/src/misc/cudawrap.cc index b1786f4..e2c1a6f 100644 --- a/src/misc/cudawrap.cc +++ b/src/misc/cudawrap.cc @@ -87,7 +87,7 @@ static void initOnceFunc() { cudaLib = dlopen(path, RTLD_LAZY); if (cudaLib == NULL) { - WARN("Failed to find CUDA library in %s (NCCL_CUDA_PATH=%s)", ncclCudaPath, ncclCudaPath); + WARN("Failed to find CUDA library (NCCL_CUDA_PATH='%s') : %s", ncclCudaPath ? 
ncclCudaPath : "", dlerror()); goto error; } diff --git a/src/misc/shmutils.cc b/src/misc/shmutils.cc index a432ff6..9f17903 100644 --- a/src/misc/shmutils.cc +++ b/src/misc/shmutils.cc @@ -15,79 +15,152 @@ #include #include -// Change functions behavior to match other SYS functions -static int shm_allocate(int fd, const int shmSize) { - int err = posix_fallocate(fd, 0, shmSize); - if (err) { errno = err; return -1; } - return 0; -} -static int shm_map(int fd, const int shmSize, void** ptr) { - *ptr = mmap(NULL, shmSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - return (*ptr == MAP_FAILED) ? -1 : 0; +struct shmHandleInternal { + int fd; + char* shmPath; + char* shmPtr; + void* devShmPtr; + size_t shmSize; + size_t realShmSize; + int* refcount; +}; + +static void shmHandleInit(int fd, char* shmPath, size_t shmSize, size_t realShmSize, char* hptr, void* dptr, bool create, struct shmHandleInternal* handle) { + handle->fd = fd; + handle->shmPtr = hptr; + handle->devShmPtr = dptr; + handle->shmSize = shmSize; + handle->realShmSize = realShmSize; + handle->refcount = (int*)(hptr + shmSize); + if (create) { + int slen = strlen(shmPath); + handle->shmPath = (char*)malloc(slen + 1); + memcpy(handle->shmPath, shmPath, slen + 1); + if (hptr) memset(hptr, 0, shmSize); + } else { + handle->shmPath = NULL; + } + return; } -static ncclResult_t ncclShmSetup(char* shmPath, const int shmSize, int* fd, void** ptr, int create) { +ncclResult_t ncclShmOpen(char* shmPath, size_t shmSize, void** shmPtr, void** devShmPtr, int refcount, ncclShmHandle_t* handle) { + int fd = -1; + char* hptr = NULL; + void* dptr = NULL; + ncclResult_t ret = ncclSuccess; + struct shmHandleInternal* tmphandle; + bool create = refcount > 0 ? true : false; + const size_t refSize = sizeof(int); /* extra sizeof(int) bytes for reference count */ + const size_t realShmSize = shmSize + refSize; + + *handle = *shmPtr = NULL; /* assume shmPtr and handle always set correctly by users. 
*/ + EQCHECKGOTO(tmphandle = (struct shmHandleInternal*)malloc(sizeof(struct shmHandleInternal)), NULL, ret, fail); if (create) { + /* refcount > 0 means the caller tries to allocate a shared memory. This shared memory segment will have + * refcount references; when the peer attaches, it should pass -1 to reduce one reference count. When it + * goes down to 0, unlink should be called in order to delete shared memory file. */ if (shmPath[0] == '\0') { sprintf(shmPath, "/dev/shm/nccl-XXXXXX"); - *fd = mkstemp(shmPath); + fd = mkstemp(shmPath); } else { - SYSCHECKVAL(open(shmPath, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR), "open", *fd); + SYSCHECKGOTO(fd = open(shmPath, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR), ret, fail); } - if (ftruncate(*fd, shmSize) != 0) { - WARN("Error: failed to extend %s to %d bytes", shmPath, shmSize); - return ncclSystemError; + + if (ftruncate(fd, realShmSize) != 0) { + WARN("Error: failed to extend %s to %ld bytes", shmPath, realShmSize); + ret = ncclSystemError; + goto fail; } - INFO(NCCL_ALLOC, "Allocated %d bytes of shared memory in %s\n", shmSize, shmPath); + INFO(NCCL_ALLOC, "Allocated %ld bytes of shared memory in %s", realShmSize, shmPath); } else { - SYSCHECKVAL(open(shmPath, O_RDWR, S_IRUSR | S_IWUSR), "open", *fd); - } - *ptr = (char*)mmap(NULL, shmSize, PROT_READ|PROT_WRITE, MAP_SHARED, *fd, 0); - if (*ptr == NULL) { - WARN("Could not map %s\n", shmPath); - return ncclSystemError; - } - close(*fd); - *fd = -1; - if (create) memset(*ptr, 0, shmSize); - return ncclSuccess; -} - -ncclResult_t ncclShmOpen(char* shmPath, const int shmSize, void** shmPtr, void** devShmPtr, int create) { - int fd = -1; - void* ptr = MAP_FAILED; - ncclResult_t res = ncclSuccess; - - NCCLCHECKGOTO(ncclShmSetup(shmPath, shmSize, &fd, &ptr, create), res, sysError); - if (devShmPtr) { - CUDACHECKGOTO(cudaHostRegister(ptr, shmSize, cudaHostRegisterMapped), res, cudaError); - CUDACHECKGOTO(cudaHostGetDevicePointer(devShmPtr, ptr, 0), res, cudaError); + 
SYSCHECKGOTO(fd = open(shmPath, O_RDWR, S_IRUSR | S_IWUSR), ret, fail); } - *shmPtr = ptr; - return ncclSuccess; -sysError: - WARN("Error while %s shared memory segment %s (size %d)", create ? "creating" : "attaching to", shmPath, shmSize); -cudaError: - if (fd != -1) close(fd); - if (create) shm_unlink(shmPath); - if (ptr != MAP_FAILED) munmap(ptr, shmSize); - *shmPtr = NULL; - return res; -} + hptr = (char*)mmap(NULL, realShmSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (hptr == MAP_FAILED) { + WARN("Could not map %s size %zi, error: %s", shmPath, realShmSize, strerror(errno)); + ret = ncclSystemError; + goto fail; + } -ncclResult_t ncclShmUnlink(const char* shmPath) { - if (shmPath != NULL) SYSCHECK(unlink(shmPath), "unlink"); - return ncclSuccess; -} + if (create) { + *(int*)(hptr + shmSize) = refcount; + } else { + int remref = __atomic_sub_fetch((int*)(hptr + shmSize), 1, __ATOMIC_RELAXED); + if (remref == 0) { + /* the last peer has completed attachment, it should unlink the shm mem file. 
*/ + if (unlink(shmPath) != 0) { + WARN("unlink shared memory %s failed, error: %s", shmPath, strerror(errno)); + } + } -ncclResult_t ncclShmClose(void* shmPtr, void* devShmPtr, const int shmSize) { - if (shmPtr) { - if (devShmPtr) CUDACHECK(cudaHostUnregister(shmPtr)); - if (munmap(shmPtr, shmSize) != 0) { - WARN("munmap of shared memory failed"); - return ncclSystemError; + if (refcount != -1) { + WARN("attaching memory should only reduce refcount by 1 but %d is passed", refcount); } } - return ncclSuccess; + + if (devShmPtr) { + CUDACHECKGOTO(cudaHostRegister((void*)hptr, realShmSize, cudaHostRegisterMapped), ret, fail); + CUDACHECKGOTO(cudaHostGetDevicePointer(&dptr, (void*)hptr, 0), ret, fail); + } + + shmHandleInit(fd, shmPath, shmSize, realShmSize, hptr, dptr, create, tmphandle); +exit: + *shmPtr = hptr; + if (devShmPtr) *devShmPtr = dptr; + *handle = (ncclShmHandle_t)tmphandle; + return ret; +fail: + WARN("Error while %s shared memory segment %s (size %ld)", create ? "creating" : "attaching to", shmPath, shmSize); + if (tmphandle) { + shmHandleInit(fd, shmPath, shmSize, realShmSize, hptr, dptr, create, tmphandle); + ncclShmClose((ncclShmHandle_t)tmphandle); + tmphandle = NULL; + } + hptr = NULL; + dptr = NULL; + goto exit; +} + +ncclResult_t ncclShmClose(ncclShmHandle_t handle) { + ncclResult_t ret = ncclSuccess; + struct shmHandleInternal* tmphandle = (struct shmHandleInternal*)handle; + if (tmphandle) { + if (tmphandle->fd >= 0) { + close(tmphandle->fd); + if (tmphandle->shmPath != NULL && *tmphandle->refcount > 0) { + if (unlink(tmphandle->shmPath) != 0) { + WARN("unlink shared memory %s failed, error: %s", tmphandle->shmPath, strerror(errno)); + ret = ncclSystemError; + } + free(tmphandle->shmPath); + } + } + + if (tmphandle->shmPtr) { + if (tmphandle->devShmPtr) CUDACHECK(cudaHostUnregister(tmphandle->shmPtr)); + if (munmap(tmphandle->shmPtr, tmphandle->realShmSize) != 0) { + WARN("munmap of shared memory %p size %ld failed, error: %s", 
tmphandle->shmPtr, tmphandle->realShmSize, strerror(errno)); + ret = ncclSystemError; + } + } + free(tmphandle); + } + return ret; +} + +ncclResult_t ncclShmUnlink(ncclShmHandle_t handle) { + ncclResult_t ret = ncclSuccess; + struct shmHandleInternal* tmphandle = (struct shmHandleInternal*)handle; + if (tmphandle) { + if (tmphandle->shmPath != NULL) { + if (unlink(tmphandle->shmPath) != 0) { + WARN("unlink shared memory %s failed, error: %s", tmphandle->shmPath, strerror(errno)); + ret = ncclSystemError; + } + free(tmphandle->shmPath); + tmphandle->shmPath = NULL; + } + } + return ret; } diff --git a/src/misc/socket.cc b/src/misc/socket.cc index 7161aee..e861480 100644 --- a/src/misc/socket.cc +++ b/src/misc/socket.cc @@ -12,6 +12,52 @@ #include #include +static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) { + int bytes = 0; + *closed = 0; + char* data = (char*)ptr; + char line[SOCKET_NAME_MAXLEN+1]; + do { + if (op == NCCL_SOCKET_RECV) bytes = recv(sock->fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT); + if (op == NCCL_SOCKET_SEND) bytes = send(sock->fd, data+(*offset), size-(*offset), block ? 
MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL); + if (op == NCCL_SOCKET_RECV && bytes == 0) { + *closed = 1; + return ncclSuccess; + } + if (bytes == -1) { + if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { + WARN("socketProgressOpt: Call to recv from %s failed : %s", ncclSocketToString(&sock->addr, line), strerror(errno)); + return ncclRemoteError; + } else { + bytes = 0; + } + } + (*offset) += bytes; + if (sock->abortFlag && *sock->abortFlag != 0) { + INFO(NCCL_NET, "socketProgressOpt: abort called"); + return ncclInternalError; + } + } while (bytes > 0 && (*offset) < size); + return ncclSuccess; +} + +static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { + int closed; + NCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0, &closed)); + if (closed) { + char line[SOCKET_NAME_MAXLEN+1]; + WARN("socketProgress: Connection closed by remote peer %s", ncclSocketToString(&sock->addr, line, 0)); + return ncclRemoteError; + } + return ncclSuccess; +} + +static ncclResult_t socketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { + while (*offset < size) + NCCLCHECK(socketProgress(op, sock, ptr, size, offset)); + return ncclSuccess; +} + /* Format a string representation of a (union ncclSocketAddress *) socket address using getnameinfo() * * Output: "IPv4/IPv6 address" @@ -194,7 +240,7 @@ int ncclFindInterfaceMatchSubnet(char* ifNames, union ncclSocketAddress* localAd return found; } -ncclResult_t ncclGetSocketAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair) { +ncclResult_t ncclSocketGetAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair) { if (!(ip_port_pair && strlen(ip_port_pair) > 1)) { WARN("Net : string is null"); return ncclInvalidArgument; @@ -296,7 +342,7 @@ int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNa INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", commId); // Try to find interface that is in 
the same subnet as the IP in comm id union ncclSocketAddress idAddr; - ncclGetSocketAddrFromString(&idAddr, commId); + ncclSocketGetAddrFromString(&idAddr, commId); nIfs = ncclFindInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs); } } @@ -310,39 +356,31 @@ int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNa } ncclResult_t ncclSocketListen(struct ncclSocket* sock) { - /* IPv4/IPv6 support */ - int family = sock->addr.sa.sa_family; - int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); - int flags; - - /* Create socket and bind it to a port */ - int fd = socket(family, SOCK_STREAM, 0); - if (fd == -1) { - WARN("Net : Socket creation failed : %s", strerror(errno)); - return ncclSystemError; + if (sock == NULL) { + WARN("ncclSocketListen: pass NULL socket"); + return ncclInvalidArgument; + } + if (sock->fd == -1) { + WARN("ncclSocketListen: file descriptor is -1"); + return ncclInvalidArgument; } if (socketToPort(&sock->addr)) { // Port is forced by env. Make sure we get the port. int opt = 1; #if defined(SO_REUSEPORT) - SYSCHECK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt"); + SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt"); #else - SYSCHECK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt"); + SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt"); #endif } - /* The socket is set non-blocking for OS level, but asyncFlag is used to control - * blocking and non-blocking behavior in user level. 
*/ - EQCHECK(flags = fcntl(fd, F_GETFL), -1); - SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl"); - // addr port should be 0 (Any port) - SYSCHECK(bind(fd, &sock->addr.sa, salen), "bind"); + SYSCHECK(bind(sock->fd, &sock->addr.sa, sock->salen), "bind"); /* Get the assigned Port */ - socklen_t size = salen; - SYSCHECK(getsockname(fd, &sock->addr.sa, &size), "getsockname"); + socklen_t size = sock->salen; + SYSCHECK(getsockname(sock->fd, &sock->addr.sa, &size), "getsockname"); #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN+1]; @@ -352,220 +390,431 @@ ncclResult_t ncclSocketListen(struct ncclSocket* sock) { /* Put the socket in listen mode * NB: The backlog will be silently truncated to the value in /proc/sys/net/core/somaxconn */ - SYSCHECK(listen(fd, 16384), "listen"); - sock->fd = fd; + SYSCHECK(listen(sock->fd, 16384), "listen"); + sock->state = ncclSocketStateReady; return ncclSuccess; } -static ncclResult_t getFdState(int fd, enum ncclSocketState* state) { - struct pollfd pfd; - int timeout = 1, ret; - socklen_t rlen = sizeof(int); - - memset(&pfd, 0, sizeof(struct pollfd)); - pfd.fd = fd; - pfd.events = POLLOUT; - SYSCHECK(ret = poll(&pfd, 1, timeout), "poll"); - if (ret == 0) { - ret = EINPROGRESS; - } else { - /* check socket status */ - EQCHECK(ret == 1 && (pfd.revents & POLLOUT), 0); - SYSCHECK(getsockopt(fd, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen), "getsockopt"); - } - - if (ret == EINPROGRESS || ret == ECONNREFUSED) - *state = ncclSocketConnecting; - else if (ret == 0) - *state = ncclSocketConnected; - else - *state = ncclSocketError; - return ncclSuccess; +ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr) { + if (sock == NULL) { + WARN("ncclSocketGetAddr: pass NULL socket"); + return ncclInvalidArgument; + } + if (sock->state != ncclSocketStateReady) return ncclInternalError; + memcpy(addr, &sock->addr, sizeof(union ncclSocketAddress)); + return ncclSuccess; } -ncclResult_t ncclGetSocketState(struct 
ncclSocket* sock, enum ncclSocketState* state) { - NCCLCHECK(getFdState(sock->fd, state)); - sock->state = *state; +static ncclResult_t socketTryAccept(struct ncclSocket* sock) { + socklen_t socklen = sizeof(union ncclSocketAddress); + sock->fd = accept(sock->acceptFd, &sock->addr.sa, &socklen); + if (sock->fd != -1) { + sock->state = ncclSocketStateAccepted; + } else if (errno != EAGAIN && errno != EWOULDBLOCK) { + WARN("socketTryAccept: get errno %d that is not EAGAIN or EWOULDBLOCK", errno); + return ncclSystemError; + } + return ncclSuccess; +} + +static ncclResult_t socketFinalizeAccept(struct ncclSocket* sock) { + uint64_t magic; + enum ncclSocketType type; + int received = 0; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); + if (received == 0) return ncclSuccess; + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); + if (magic != sock->magic) { + WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic); + close(sock->fd); + sock->fd = -1; + // Ignore spurious connection and accept again + sock->state = ncclSocketStateAccepting; return ncclSuccess; + } else { + received = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &type, sizeof(type), &received)); + if (type != sock->type) { + WARN("socketFinalizeAccept: wrong type %d != %d", type, sock->type); + sock->state = ncclSocketStateError; + close(sock->fd); + sock->fd = -1; + return ncclInternalError; + } else { + sock->state = ncclSocketStateReady; + } + } + return ncclSuccess; +} + +static ncclResult_t socketStartConnect(struct ncclSocket* sock) { + /* blocking/non-blocking connect() is determined by asyncFlag. 
*/ + int ret = connect(sock->fd, &sock->addr.sa, sock->salen); + + if (ret == 0) { + sock->state = ncclSocketStateConnected; + return ncclSuccess; + } else if (errno == EINPROGRESS) { + sock->state = ncclSocketStateConnectPolling; + return ncclSuccess; + } else if (errno == ECONNREFUSED) { + if (++sock->refusedRetries == RETRY_REFUSED_TIMES) { + sock->state = ncclSocketStateError; + WARN("socketStartConnect: exceeded retries (%d)", sock->refusedRetries); + return ncclRemoteError; + } + usleep(SLEEP_INT); + if (sock->refusedRetries % 1000 == 0) INFO(NCCL_ALL, "Call to connect returned %s, retrying", strerror(errno)); + return ncclSuccess; + } else if (errno == ETIMEDOUT) { + if (++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) { + sock->state = ncclSocketStateError; + WARN("socketStartConnect: exceeded timeouts (%d)", sock->timedOutRetries); + return ncclRemoteError; + } + usleep(SLEEP_INT); + return ncclSuccess; + } else { + char line[SOCKET_NAME_MAXLEN+1]; + sock->state = ncclSocketStateError; + WARN("socketStartConnect: Connect to %s failed : %s", ncclSocketToString(&sock->addr, line), strerror(errno)); + return ncclSystemError; + } +} + +static ncclResult_t socketPollConnect(struct ncclSocket* sock) { + struct pollfd pfd; + int timeout = 1, ret; + socklen_t rlen = sizeof(int); + + memset(&pfd, 0, sizeof(struct pollfd)); + pfd.fd = sock->fd; + pfd.events = POLLOUT; + SYSCHECK(ret = poll(&pfd, 1, timeout), "poll"); + if (ret == 0) return ncclSuccess; + + /* check socket status */ + EQCHECK(ret == 1 && (pfd.revents & POLLOUT), 0); + SYSCHECK(getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen), "getsockopt"); + + if (ret == 0) { + sock->state = ncclSocketStateConnected; + } else if (ret == ECONNREFUSED) { + if (++sock->refusedRetries == RETRY_REFUSED_TIMES) { + sock->state = ncclSocketStateError; + WARN("socketPollConnect: exceeded retries (%d)", sock->refusedRetries); + return ncclRemoteError; + } + if (sock->refusedRetries % 1000 == 0) INFO(NCCL_ALL, 
"Call to connect returned %s, retrying", strerror(errno)); + usleep(SLEEP_INT); + sock->state = ncclSocketStateConnecting; + } else if (ret == ETIMEDOUT) { + if (++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) { + sock->state = ncclSocketStateError; + WARN("socketPollConnect: exceeded timeouts (%d)", sock->timedOutRetries); + return ncclRemoteError; + } + usleep(SLEEP_INT); + sock->state = ncclSocketStateConnecting; + } else if (ret != EINPROGRESS) { + sock->state = ncclSocketStateError; + return ncclSystemError; + } + return ncclSuccess; +} + +ncclResult_t ncclSocketPollConnect(struct ncclSocket* sock) { + if (sock == NULL) { + WARN("ncclSocketPollConnect: pass NULL socket"); + return ncclInvalidArgument; + } + NCCLCHECK(socketPollConnect(sock)); + return ncclSuccess; +} + +static ncclResult_t socketFinalizeConnect(struct ncclSocket* sock) { + int sent = 0; + NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); + if (sent == 0) return ncclSuccess; + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); + sent = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent)); + sock->state = ncclSocketStateReady; + return ncclSuccess; +} + +static ncclResult_t socketProgressState(struct ncclSocket* sock) { + if (sock->state == ncclSocketStateAccepting) { + NCCLCHECK(socketTryAccept(sock)); + } + if (sock->state == ncclSocketStateAccepted) { + NCCLCHECK(socketFinalizeAccept(sock)); + } + if (sock->state == ncclSocketStateConnecting) { + NCCLCHECK(socketStartConnect(sock)); + } + if (sock->state == ncclSocketStateConnectPolling) { + NCCLCHECK(socketPollConnect(sock)); + } + if (sock->state == ncclSocketStateConnected) { + NCCLCHECK(socketFinalizeConnect(sock)); + } + return ncclSuccess; +} + +ncclResult_t ncclSocketReady(struct ncclSocket* sock, int *running) { + if (sock == NULL) { + *running = 0; + return ncclSuccess; + } + if (sock->state == ncclSocketStateError 
|| sock->state == ncclSocketStateClosed) { + WARN("ncclSocketReady: unexpected socket state %d", sock->state); + return ncclRemoteError; + } + *running = (sock->state == ncclSocketStateReady) ? 1 : 0; + if (*running == 0) { + NCCLCHECK(socketProgressState(sock)); + *running = (sock->state == ncclSocketStateReady) ? 1 : 0; + } + return ncclSuccess; } ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { +#ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN+1]; - /* IPv4/IPv6 support */ - int family = sock->addr.sa.sa_family; - if (family != AF_INET && family != AF_INET6) { - WARN("Net : connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)", - ncclSocketToString(&sock->addr, line), family, AF_INET, AF_INET6); +#endif + const int one = 1; + + if (sock == NULL) { + WARN("ncclSocketConnect: pass NULL socket"); + return ncclInvalidArgument; + } + if (sock->fd == -1) { + WARN("ncclSocketConnect: file descriptor is -1"); + return ncclInvalidArgument; + } + + if (sock->state != ncclSocketStateInitialized) { + WARN("ncclSocketConnect: wrong socket state %d", sock->state); + if (sock->state == ncclSocketStateError) return ncclRemoteError; return ncclInternalError; } - int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); - int flags; - - /* Connect to a hostname / port */ - int fd = socket(family, SOCK_STREAM, 0); - if (fd == -1) { - WARN("Net : Socket creation failed : %s", strerror(errno)); - return ncclSystemError; - } - - const int one = 1; - SYSCHECK(setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt"); - - /* The socket is set non-blocking for OS level, but asyncFlag is used to control - * blocking and non-blocking behavior in user level. 
*/ - EQCHECK(flags = fcntl(fd, F_GETFL), -1); - SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl"); - - /* const int bufsize = 128*1024; - SYSCHECK(setsockopt(fd, SOL_SOCKET, SO_SNDBUF, (char*)&bufsize, sizeof(int)), "setsockopt"); - SYSCHECK(setsockopt(fd, SOL_SOCKET, SO_RCVBUF, (char*)&bufsize, sizeof(int)), "setsockopt");*/ - TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", ncclSocketToString(&sock->addr, line)); - int ret; - int timedout_retries = 0; - int refused_retries = 0; -retry: - /* blocking/non-blocking connect() is determined by asyncFlag. */ - ret = connect(fd, &sock->addr.sa, salen); + SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt"); - if (!sock->asyncFlag) { - /* blocking socket, need retry if connect fails. */ - if (errno == EINPROGRESS || errno == EAGAIN || errno == EALREADY || - (errno == ECONNREFUSED && ++refused_retries < RETRY_REFUSED_TIMES) || - (errno == ETIMEDOUT && ++timedout_retries < RETRY_TIMEDOUT_TIMES)) { - /* check abortFlag as long as we have chance to retry. */ - if (sock->abortFlag && *sock->abortFlag != 0) return ncclInternalError; - if (errno == ECONNREFUSED && refused_retries % 1000 == 0) INFO(NCCL_ALL, "Call to connect returned %s, retrying", strerror(errno)); - usleep(SLEEP_INT); - goto retry; - } + sock->state = ncclSocketStateConnecting; + do { + NCCLCHECK(socketProgressState(sock)); + } while (sock->asyncFlag == 0 && + (sock->abortFlag == NULL || *sock->abortFlag == 0) && + (sock->state == ncclSocketStateConnecting || + sock->state == ncclSocketStateConnectPolling || + sock->state == ncclSocketStateConnected)); - /* If connect() fails with errno == EAGAIN/EINPROGRESS/ETIMEDOUT, we may want to try connect again. - * However, it can return EISCONN instead of success which indicates connection is built up in - * background already. No need to call connect() again. 
*/ - if (ret == 0 || errno == EISCONN) { - sock->fd = fd; + if (sock->abortFlag && *sock->abortFlag != 0) return ncclInternalError; + + switch (sock->state) { + case ncclSocketStateConnecting: + case ncclSocketStateConnectPolling: + case ncclSocketStateConnected: + case ncclSocketStateReady: return ncclSuccess; - } - } else { - sock->fd = fd; - return ncclSuccess; + case ncclSocketStateError: + return ncclSystemError; + default: + WARN("ncclSocketConnect: wrong socket state %d", sock->state); + return ncclInternalError; } - - WARN("Net : Connect to %s failed : %s", ncclSocketToString(&sock->addr, line), strerror(errno)); - return ncclRemoteError; } -ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listenSocket) { - socklen_t socklen = sizeof(union ncclSocketAddress); - struct pollfd pollfd; - int tmpFd = sock->fd = -1; - int pollret; +ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listenSock) { + ncclResult_t ret = ncclSuccess; - pollfd.fd = listenSocket->fd; - pollfd.events = POLLIN; -retry: - if ((pollret = poll(&pollfd, 1, listenSocket->asyncFlag ? 
0 : 100)) < 0) { - return ncclSystemError; - } else { - tmpFd = accept(listenSocket->fd, &sock->addr.sa, &socklen); + if (listenSock == NULL || sock == NULL) { + WARN("ncclSocketAccept: pass NULL socket"); + ret = ncclInvalidArgument; + goto exit; + } + if (listenSock->state != ncclSocketStateReady) { + WARN("ncclSocketAccept: wrong socket state %d", listenSock->state); + if (listenSock->state == ncclSocketStateError) + ret = ncclSystemError; + else + ret = ncclInternalError; + goto exit; } - if (!listenSocket->asyncFlag) { - /* blocking socket, if tmpFd is still -1, we need to retry */ - if (tmpFd == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { - if (listenSocket->abortFlag && *listenSocket->abortFlag != 0) return ncclInternalError; - goto retry; - } - EQCHECK(tmpFd, -1); + if (sock->acceptFd == -1) { + memcpy(sock, listenSock, sizeof(struct ncclSocket)); + sock->acceptFd = listenSock->fd; + sock->state = ncclSocketStateAccepting; } - sock->fd = tmpFd; - return ncclSuccess; + do { + NCCLCHECKGOTO(socketProgressState(sock), ret, exit); + } while (sock->asyncFlag == 0 && + (sock->abortFlag == NULL || *sock->abortFlag == 0) && + (sock->state == ncclSocketStateAccepting || + sock->state == ncclSocketStateAccepted)); + + if (sock->abortFlag && *sock->abortFlag != 0) return ncclInternalError; + + switch (sock->state) { + case ncclSocketStateAccepting: + case ncclSocketStateAccepted: + case ncclSocketStateReady: + ret = ncclSuccess; + break; + case ncclSocketStateError: + ret = ncclSystemError; + break; + default: + WARN("ncclSocketAccept: wrong socket state %d", sock->state); + ret = ncclInternalError; + break; + } + +exit: + return ret; } -ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr, volatile uint32_t* abortFlag, int asyncFlag) { - if (sock == NULL) - return ncclSuccess; +ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, 
int asyncFlag) { + ncclResult_t ret = ncclSuccess; + if (sock == NULL) goto exit; + sock->timedOutRetries = 0; + sock->refusedRetries = 0; + sock->abortFlag = abortFlag; + sock->asyncFlag = asyncFlag; + sock->state = ncclSocketStateInitialized; + sock->magic = magic; + sock->type = type; sock->fd = -1; + sock->acceptFd = -1; + if (addr) { + /* IPv4/IPv6 support */ + int family; memcpy(&sock->addr, addr, sizeof(union ncclSocketAddress)); + family = sock->addr.sa.sa_family; + if (family != AF_INET && family != AF_INET6) { + char line[SOCKET_NAME_MAXLEN+1]; + WARN("ncclSocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)", + ncclSocketToString(&sock->addr, line), family, AF_INET, AF_INET6); + ret = ncclInternalError; + goto fail; + } + sock->salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); + + /* Connect to a hostname / port */ + sock->fd = socket(family, SOCK_STREAM, 0); + if (sock->fd == -1) { + WARN("ncclSocketInit: Socket creation failed : %s", strerror(errno)); + ret = ncclSystemError; + goto fail; + } } else { memset(&sock->addr, 0, sizeof(union ncclSocketAddress)); } - sock->abortFlag = abortFlag; - sock->asyncFlag = asyncFlag; - sock->state = ncclSocketStateNum; - return ncclSuccess; -} -static ncclResult_t ncclSocketProgressOpt(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) { - int bytes = 0; - *closed = 0; - char* data = (char*)ptr; - char line[SOCKET_NAME_MAXLEN+1]; - do { - if (op == NCCL_SOCKET_RECV) bytes = recv(sock->fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT); - if (op == NCCL_SOCKET_SEND) bytes = send(sock->fd, data+(*offset), size-(*offset), block ? 
MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL); - if (op == NCCL_SOCKET_RECV && bytes == 0) { - *closed = 1; - return ncclSuccess; - } - if (bytes == -1) { - if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { - WARN("Net : Call to recv from %s failed : %s", ncclSocketToString(&sock->addr, line), strerror(errno)); - return ncclRemoteError; - } else { - bytes = 0; - } - } - (*offset) += bytes; - if (sock->abortFlag && *sock->abortFlag != 0) { - INFO(NCCL_NET, "Socket progress: abort called"); - return ncclInternalError; - } - } while (bytes > 0 && (*offset) < size); - return ncclSuccess; + /* Set socket as non-blocking if async or if we need to be able to abort */ + if ((sock->asyncFlag || sock->abortFlag) && sock->fd >= 0) { + int flags; + EQCHECKGOTO(flags = fcntl(sock->fd, F_GETFL), -1, ret, fail); + SYSCHECKGOTO(fcntl(sock->fd, F_SETFL, flags | O_NONBLOCK), ret, fail); + } + +exit: + return ret; +fail: + goto exit; } ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { - int closed; - NCCLCHECK(ncclSocketProgressOpt(op, sock, ptr, size, offset, 0, &closed)); - if (closed) { - char line[SOCKET_NAME_MAXLEN+1]; - WARN("Net : Connection closed by remote peer %s", ncclSocketToString(&sock->addr, line, 0)); - return ncclRemoteError; + if (sock == NULL) { + WARN("ncclSocketProgress: pass NULL socket"); + return ncclInvalidArgument; } + NCCLCHECK(socketProgress(op, sock, ptr, size, offset)); return ncclSuccess; } ncclResult_t ncclSocketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { - while (*offset < size) - NCCLCHECK(ncclSocketProgress(op, sock, ptr, size, offset)); + if (sock == NULL) { + WARN("ncclSocketWait: pass NULL socket"); + return ncclInvalidArgument; + } + NCCLCHECK(socketWait(op, sock, ptr, size, offset)); return ncclSuccess; } ncclResult_t ncclSocketSend(struct ncclSocket* sock, void* ptr, int size) { int offset = 0; - NCCLCHECK(ncclSocketWait(NCCL_SOCKET_SEND, sock, ptr, 
size, &offset)); + if (sock == NULL) { + WARN("ncclSocketSend: pass NULL socket"); + return ncclInvalidArgument; + } + if (sock->state != ncclSocketStateReady) { + WARN("ncclSocketSend: socket state (%d) is not ready", sock->state); + return ncclInternalError; + } + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, ptr, size, &offset)); return ncclSuccess; } ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size) { int offset = 0; - NCCLCHECK(ncclSocketWait(NCCL_SOCKET_RECV, sock, ptr, size, &offset)); + if (sock == NULL) { + WARN("ncclSocketRecv: pass NULL socket"); + return ncclInvalidArgument; + } + if (sock->state != ncclSocketStateReady) { + WARN("ncclSocketRecv: socket state (%d) is not ready", sock->state); + return ncclInternalError; + } + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, ptr, size, &offset)); return ncclSuccess; } // Receive or detect connection closed ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed) { int offset = 0; + if (sock == NULL) { + WARN("ncclSocketTryRecv: pass NULL socket"); + return ncclInvalidArgument; + } *closed = 0; while (offset < size) { - NCCLCHECK(ncclSocketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed)); + NCCLCHECK(socketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed)); if (*closed) return ncclSuccess; } return ncclSuccess; } + +ncclResult_t ncclSocketClose(struct ncclSocket* sock) { + if (sock != NULL) { + if (sock->fd >= 0) close(sock->fd); + sock->state = ncclSocketStateClosed; + sock->fd = -1; + } + return ncclSuccess; +} + +ncclResult_t ncclSocketGetFd(struct ncclSocket* sock, int* fd) { + if (sock == NULL) { + WARN("ncclSocketGetFd: pass NULL socket"); + return ncclInvalidArgument; + } + if (fd) *fd = sock->fd; + return ncclSuccess; +} + +ncclResult_t ncclSocketSetFd(int fd, struct ncclSocket* sock) { + if (sock == NULL) { + WARN("ncclSocketSetFd: pass NULL socket"); + return ncclInvalidArgument; + } + sock->fd = fd; + return
ncclSuccess; +} diff --git a/src/misc/strongstream.cc b/src/misc/strongstream.cc index d07698b..61b0e4b 100644 --- a/src/misc/strongstream.cc +++ b/src/misc/strongstream.cc @@ -305,7 +305,8 @@ static void mergeTips(struct ncclStrongStreamGraph* a, cudaGraphNode_t const* bN } ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b + struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b, + bool b_subsumes_a ) { #if CUDART_VERSION >= 11030 if (graph.graph == nullptr) { @@ -319,6 +320,7 @@ ncclResult_t ncclStrongStreamWaitStream( NCCLCHECK(checkGraphId(ag, graph.graphId)); struct ncclStrongStreamGraph* bg = b->graphHead; NCCLCHECK(checkGraphId(bg, graph.graphId)); + if (b_subsumes_a) ag->tipCount = 0; mergeTips(ag, bg->tipNodes, bg->tipCount); } a->serialEventNeedsRecord = true; @@ -330,7 +332,8 @@ ncclResult_t ncclStrongStreamWaitStream( } ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, struct ncclStrongStream* a, cudaStream_t b + struct ncclCudaGraph graph, struct ncclStrongStream* a, cudaStream_t b, + bool b_subsumes_a ) { #if CUDART_VERSION >= 11030 if (graph.graph == nullptr) { @@ -351,6 +354,7 @@ ncclResult_t ncclStrongStreamWaitStream( } struct ncclStrongStreamGraph* ag = a->graphHead; NCCLCHECK(checkGraphId(ag, graph.graphId)); + if (b_subsumes_a) ag->tipCount = 0; mergeTips(ag, bNodes, bCount); } a->serialEventNeedsRecord = true; @@ -362,7 +366,8 @@ ncclResult_t ncclStrongStreamWaitStream( } ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, cudaStream_t a, struct ncclStrongStream* b + struct ncclCudaGraph graph, cudaStream_t a, struct ncclStrongStream* b, + bool b_subsumes_a ) { #if CUDART_VERSION >= 11030 if (graph.graph == nullptr) { @@ -374,7 +379,9 @@ ncclResult_t ncclStrongStreamWaitStream( } else { struct ncclStrongStreamGraph* bg = b->graphHead; NCCLCHECK(checkGraphId(bg, graph.graphId)); - 
CUDACHECK(cudaStreamUpdateCaptureDependencies(a, bg->tipNodes, bg->tipCount, cudaStreamAddCaptureDependencies)); + CUDACHECK(cudaStreamUpdateCaptureDependencies(a, bg->tipNodes, bg->tipCount, + b_subsumes_a ? cudaStreamSetCaptureDependencies : cudaStreamAddCaptureDependencies + )); } #else CUDACHECK(cudaEventRecord(b->scratchEvent, b->cudaStream)); diff --git a/src/nccl.h.in b/src/nccl.h.in index ccb8f57..44a68e9 100644 --- a/src/nccl.h.in +++ b/src/nccl.h.in @@ -123,7 +123,7 @@ const char* pncclGetErrorString(ncclResult_t result); * comm is currently unused and can be set to NULL */ const char* ncclGetLastError(ncclComm_t comm); -const char* pncclGetError(ncclComm_t comm); +const char* pncclGetLastError(ncclComm_t comm); /* Checks whether the comm has encountered any asynchronous errors */ ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); diff --git a/src/proxy.cc b/src/proxy.cc index 696f57f..2103b7a 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -437,8 +437,7 @@ ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op) int stepSize = info->comm->buffSizes[op->protocol]/NCCL_STEPS; - // If nNodes > 1 and we're using Simple, reduce the stepSize to increase shared buffer utilization - if (info->comm->nNodes > 1 && op->protocol == NCCL_PROTO_SIMPLE) stepSize = info->comm->p2pNetChunkSize; + if (op->protocol == NCCL_PROTO_SIMPLE) stepSize = info->comm->p2pChunkSize; info->chunkSize = stepSize; op->root = info->root; @@ -532,6 +531,8 @@ static ncclResult_t progressOps(struct ncclComm* comm, struct ncclProxyProgressS return ncclSuccess; } +NCCL_PARAM(ProxyAppendBatchSize, "PROXY_APPEND_BATCH_SIZE", 16); + static ncclResult_t ncclProxyGetPostedOps(struct ncclComm* comm, int* added) { struct ncclProxyProgressState* state = &comm->proxyState.progressState; if (state->opsPool == NULL) return ncclInternalError; @@ -570,9 +571,16 @@ process_nextops: int freeOpEnd[NCCL_MAX_LOCAL_RANKS]; for (int i=0; ilocalRanks; i++) 
freeOp[i] = -1; + uint64_t lastOpCount = 0; + int lastPeer = -1; + int count = 0; for (int opIndex = state->nextOps; opIndex != -1;) { struct ncclProxyOp* peerOp = pool->ops+opIndex; int peer = opIndex / MAX_OPS_PER_PEER; + if ((lastOpCount && peerOp->opCount != lastOpCount) || ((lastPeer != -1) && peer != lastPeer)) count++; + if (count == ncclParamProxyAppendBatchSize()+1) break; + lastOpCount = peerOp->opCount; + lastPeer = peer; if (peerOp->connection == NULL) return ncclInternalError; if (peerOp->next != -1) __builtin_prefetch(pool->ops+peerOp->next); NCCLCHECK(ProxyAppend(state, peerOp)); @@ -676,7 +684,7 @@ void* ncclProxyProgress(void *comm_) { int lastIdle = 0; struct ncclProxyArgs profArgs; // Only used for profiling purposes - while (state->stop == 0 && *comm->abortFlag == 0) { + while ((state->stop == false || (state->stop == true && state->active)) && *comm->abortFlag == 0) { int idle = 1; ncclResult_t ret = progressOps(comm, state, state->active, &idle); if (ret != ncclSuccess) { @@ -686,18 +694,17 @@ void* ncclProxyProgress(void *comm_) { } if (lastIdle == 0 && idle == 1) ncclProfilingRecord(&profArgs, 0, 0, ncclProxyProfileIdle); if (lastIdle == 1 && idle == 0) ncclProfilingRecord(&profArgs, 0, 0, ncclProxyProfileActive); - if (idle) { - int added = 0; - TIME_START(3); + int added = 0; + TIME_START(3); + if (state->stop == false) ret = ncclProxyGetPostedOps(comm, &added); - if (added) { TIME_STOP(3); } else { TIME_CANCEL(3); } - if (ret != ncclSuccess) { - (void) ncclCommSetAsyncError(comm, ret); - INFO(NCCL_ALL,"%s:%d -> %d [Proxy Thread]", __FILE__, __LINE__, ret); - } - if (added == 0) { - sched_yield(); // No request progressed. Let others run. - } + if (added) { TIME_STOP(3); } else { TIME_CANCEL(3); } + if (ret != ncclSuccess) { + (void) ncclCommSetAsyncError(comm, ret); + INFO(NCCL_ALL,"%s:%d -> %d [Proxy Thread]", __FILE__, __LINE__, ret); + } + if (added == 0) { + sched_yield(); // No request progressed. Let others run. 
} lastIdle = idle; } @@ -814,7 +821,7 @@ static ncclResult_t ncclProxyFreeConnections(struct ncclProxyConnectionPool* poo int max = b == pool->banks-1 ? pool->offset : NCCL_PROXY_CONN_POOL_SIZE; for (int i=0; ipools[b]+i; - if (connection->initFlag == true) { + if (connection->state != connUninitialized) { NCCLCHECK(proxyFree(connection, comm)); } } @@ -827,6 +834,10 @@ static ncclResult_t ncclProxyFreeConnections(struct ncclProxyConnectionPool* poo #include "transport.h" ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, int rank, struct ncclProxyConnector* proxyConn) { + struct ncclSocket* sock; + int ready; + int type = ncclProxyMsgInit; + // Keep one connection per mlocal rank proxyConn->connection = NULL; proxyConn->rank = rank; @@ -834,17 +845,18 @@ ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, in NCCLCHECK(ncclCalloc(&comm->proxyState.peerSocks, comm->localRanks)); NCCLCHECK(ncclCalloc(&comm->proxyState.proxyOps, comm->localRanks)); NCCLCHECK(ncclCalloc(&comm->proxyState.sharedDevMems, comm->localRanks)); - for (int r=0; rlocalRanks; r++) { - NCCLCHECK(ncclSocketInit(&comm->proxyState.peerSocks[r], NULL, comm->abortFlag, 0)); + for (int i = 0; i < comm->localRanks; ++i) { + NCCLCHECK(ncclSocketSetFd(-1, &comm->proxyState.peerSocks[i])); } } + NCCLCHECK(ncclTopoGetLocalRank(comm->topo, rank, &proxyConn->localRank)); - struct ncclSocket* sock = comm->proxyState.peerSocks+proxyConn->localRank; - if (sock->fd == -1) { - memcpy(&sock->addr, comm->proxyState.peerAddresses+rank, sizeof(union ncclSocketAddress)); + sock = comm->proxyState.peerSocks + proxyConn->localRank; + NCCLCHECK(ncclSocketReady(sock, &ready)); + if (!ready) { + NCCLCHECK(ncclSocketInit(sock, comm->proxyState.peerAddresses+rank, comm->magic, ncclSocketTypeProxy, comm->abortFlag)); NCCLCHECK(ncclSocketConnect(sock)); } - int type = ncclProxyMsgInit; NCCLCHECK(ncclSocketSend(sock, &type, sizeof(int))); NCCLCHECK(ncclSocketSend(sock, 
&transport, sizeof(int))); NCCLCHECK(ncclSocketSend(sock, &send, sizeof(int))); @@ -857,7 +869,7 @@ ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, in NCCLCHECK(ncclSocketRecv(sock, poolPath+sizeof("/dev/shm/nccl-")-1, sizeof("XXXXXX")-1)); struct ncclProxyOps* proxyOps = comm->proxyState.proxyOps+proxyConn->localRank; if (proxyOps->pool == NULL) { - NCCLCHECK(ncclShmOpen(poolPath, sizeof(struct ncclProxyOpsPool), (void**)(&proxyOps->pool), NULL, 0)); + NCCLCHECK(ncclShmOpen(poolPath, sizeof(struct ncclProxyOpsPool), (void**)(&proxyOps->pool), NULL, -1, &proxyOps->handle)); proxyOps->nextOps = proxyOps->nextOpsEnd = proxyOps->freeOp = -1; } } @@ -868,11 +880,12 @@ ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, in const char* ncclProxyMsgTypeStr[] = { "Unknown", "Init", "SharedInit", "Setup", "Connect", "Start", "Close", "Abort", "Stop" }; ncclResult_t ncclProxyCall(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize) { - if (proxyConn->comm->proxyState.peerSocks == NULL) return ncclInternalError; - struct ncclSocket* sock = proxyConn->comm->proxyState.peerSocks+proxyConn->localRank; - if (sock->fd == -1) return ncclInternalError; - ncclResult_t ret; + struct ncclSocket* sock; + ncclResult_t ret = ncclSuccess; + if (proxyConn->comm->proxyState.peerSocks == NULL) return ncclInternalError; + sock = proxyConn->comm->proxyState.peerSocks + proxyConn->localRank; + if (sock == NULL) return ncclInternalError; NCCLCHECKGOTO(ncclSocketSend(sock, &type, sizeof(int)), ret, error); NCCLCHECKGOTO(ncclSocketSend(sock, &proxyConn->connection, sizeof(void*)), ret, error); NCCLCHECKGOTO(ncclSocketSend(sock, &reqSize, sizeof(int)), ret, error); @@ -882,7 +895,6 @@ ncclResult_t ncclProxyCall(struct ncclProxyConnector* proxyConn, int type, void* return ncclSuccess; error: WARN("Proxy Call to rank %d failed (%s)", proxyConn->comm->localRankToRank[proxyConn->localRank], 
ncclProxyMsgTypeStr[type]); - sock->fd = -1; return ret; } @@ -892,16 +904,15 @@ static ncclResult_t proxyProgressInit(struct ncclComm* comm) { int size = sizeof(struct ncclProxyOpsPool); struct ncclProxyOpsPool* pool = NULL; - char shmPath[sizeof("/dev/shm/nccl-XXXXXX")]; - shmPath[0] = '\0'; - NCCLCHECK(ncclShmOpen(shmPath, size, (void**)&pool, NULL, 1)); - - // Init pool - pool->nextOps = -1; - // The service thread may be launched already but localRanks may not be set yet. while (comm->localRanks == 0) sched_yield(); + char shmPath[sizeof("/dev/shm/nccl-XXXXXX")]; + shmPath[0] = '\0'; + NCCLCHECK(ncclShmOpen(shmPath, size, (void**)&pool, NULL, comm->localRanks + 1, &state->handle)); + // Init pool + pool->nextOps = -1; + for (int r=0; rlocalRanks; r++) { pool->freeOps[r] = r*MAX_OPS_PER_PEER; for (int i=0; iops[r*MAX_OPS_PER_PEER+i].next = r*MAX_OPS_PER_PEER+i+1; @@ -928,7 +939,7 @@ static ncclResult_t proxyProgressInit(struct ncclComm* comm) { static void proxyOpsFree(struct ncclComm* comm) { struct ncclProxyProgressState* state = &comm->proxyState.progressState; - if (ncclShmClose(state->opsPool, NULL, sizeof(struct ncclProxyOpsPool)) != ncclSuccess) { + if (ncclShmClose(state->handle) != ncclSuccess) { WARN("[Service thread] shm close failed"); } } @@ -937,10 +948,8 @@ ncclResult_t ncclProxyShmUnlink(struct ncclComm* comm) { struct ncclProxyProgressState* state = &comm->proxyState.progressState; if (state->opsPool == NULL) return ncclSuccess; - char shmPath[] = "/dev/shm/nccl-XXXXXX"; - memcpy(shmPath+sizeof("/dev/shm/nccl-")-1, state->opsPoolShmSuffix, sizeof("XXXXXX")-1); - if (ncclShmUnlink(shmPath) != ncclSuccess) { - WARN("[Service thread] shm unlink failed"); + if (ncclShmUnlink(state->handle) != ncclSuccess) { + WARN("[Service thread] proxy ops shm unlink failed"); } return ncclSuccess; } @@ -965,7 +974,7 @@ static ncclResult_t proxyConnInit(struct ncclProxyLocalPeer* peer, struct ncclPr NCCLCHECK(ncclSocketSend(sock, state->opsPoolShmSuffix, 
sizeof("XXXXXX")-1)); } INFO(NCCL_NET, "New proxy %s connection %d from local rank %d, transport %d", connection->send ? "send":"recv", id, connection->localRank, connection->transport); - __atomic_store_n(&connection->initFlag, true, __ATOMIC_RELEASE); + __atomic_store_n(&connection->state, connInitialized, __ATOMIC_RELEASE); return ncclSuccess; } @@ -980,6 +989,7 @@ static ncclResult_t proxyConnSharedInit(struct ncclProxyLocalPeer* peer, struct int nChannels; NCCLCHECK(ncclSocketRecv(sock, &nChannels, sizeof(int))); if (connection->tcomm->proxySharedInit) NCCLCHECK(connection->tcomm->proxySharedInit(connection, comm, nChannels)); + __atomic_store_n(&connection->state, connSharedInitialized, __ATOMIC_RELEASE); return ncclSuccess; } @@ -991,14 +1001,29 @@ static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclC NCCLCHECK(op->connection->tcomm->proxyConnect(op->connection, comm, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done)); } else return ncclInternalError; if (done) { - if (op->respSize) NCCLCHECK(ncclSocketSend(op->connection->sock, op->respBuff, op->respSize)); - if (op->reqBuff) free(op->reqBuff); - if (op->respBuff) free(op->respBuff); - op->reqBuff = NULL; - op->respBuff = NULL; + if (op->type == ncclProxyMsgSetup) + __atomic_store_n(&op->connection->state, connSetupDone, __ATOMIC_RELEASE); + else if (op->type == ncclProxyMsgConnect) + __atomic_store_n(&op->connection->state, connConnected, __ATOMIC_RELEASE); + /* if setup or connect is done, we should not return any error at this point since + * ncclSocketSend might already send the respBuff to the requester. If we still choose + * to abort and close the connection, it can cause segfault if the requester is using + * the respBuff. 
*/ + if (op->respSize) ncclSocketSend(op->connection->sock, op->respBuff, op->respSize); + if (op->reqBuff) { + free(op->reqBuff); + op->reqBuff = NULL; + } + if (op->respBuff) { + free(op->respBuff); + op->respBuff = NULL; + } op->type = 0; (*asyncOpCount)--; + } else if (*comm->abortFlag != 0) { + return ncclInternalError; } + return ncclSuccess; } @@ -1042,36 +1067,52 @@ void* ncclProxyService(void* _args) { struct ncclProxyLocalPeer peers[NCCL_MAX_LOCAL_RANKS]; memset(&peers, 0, sizeof(struct ncclProxyLocalPeer)*NCCL_MAX_LOCAL_RANKS); for (int s=0; sabortFlag, 0); pollfds[s].fd = -1; pollfds[s].events = POLLHUP|POLLIN; } - pollfds[NCCL_MAX_LOCAL_RANKS].fd = comm->proxyState.listenSock->fd; + if (ncclSocketGetFd(comm->proxyState.listenSock, &pollfds[NCCL_MAX_LOCAL_RANKS].fd) != ncclSuccess) { + WARN("[Proxy Service] Get listenSock fd fails\n"); + return NULL; + }; pollfds[NCCL_MAX_LOCAL_RANKS].events = POLLIN; int maxnpeers = 0; int npeers = 0; int stop = 0; int asyncOpCount = 0; - while ((stop == 0 || (stop == 1 && npeers > 0)) && *comm->abortFlag == 0) { + while (stop == 0 || (stop == 1 && npeers > 0)) { + /* Even if local comm aborts, we cannot let proxy thread exit if we still have peer + * connections. Need to wait until all other related comms call abort and safely exit + * together, or we could face segmentation fault. */ + if (*comm->abortFlag != 0) stop = 1; /* never let proxy service thread blocks in poll, or it cannot receive abortFlag. */ - if (poll(pollfds, NCCL_MAX_LOCAL_RANKS+1, asyncOpCount ? 0 : 500) < 0) { - WARN("[Proxy Service] Poll failed: %s\n", strerror(errno)); + int ret; + do { + ret = poll(pollfds, NCCL_MAX_LOCAL_RANKS+1, asyncOpCount ? 
0 : 500); + } while (ret < 0 && errno == EINTR); + if (ret < 0) { + WARN("[Proxy Service] Poll failed: %s", strerror(errno)); return NULL; } if (pollfds[NCCL_MAX_LOCAL_RANKS].revents) { int s = 0; - while (s < NCCL_MAX_LOCAL_RANKS && peers[s].sock.fd != -1) s++; + while (s < NCCL_MAX_LOCAL_RANKS && pollfds[s].fd >= 0) s++; if (s == NCCL_MAX_LOCAL_RANKS) { WARN("[Proxy service] Too many connections (%d max)", NCCL_MAX_LOCAL_RANKS); return NULL; } if (maxnpeers < s+1) maxnpeers = s+1; - struct ncclSocket* sock = &peers[s].sock; - if (ncclSocketAccept(sock, comm->proxyState.listenSock) != ncclSuccess) { + if (ncclSocketInit(&peers[s].sock) != ncclSuccess) { + WARN("[Service thread] Initialize peers[%d].sock fails\n", s); + return NULL; + } + if (ncclSocketAccept(&peers[s].sock, comm->proxyState.listenSock) != ncclSuccess) { WARN("[Service thread] Accept failed %s", strerror(errno)); } else { - pollfds[s].fd = sock->fd; + if (ncclSocketGetFd(&peers[s].sock, &pollfds[s].fd) != ncclSuccess) { + WARN("[Service thread] Get peers[%d].sock fd fails\n", s); + return NULL; + } npeers++; peers[s].localRank = -1; } @@ -1083,10 +1124,12 @@ void* ncclProxyService(void* _args) { int closeConn = 0; int type = 0; ncclResult_t res = ncclSuccess; + + if (pollfds[s].fd == -1) continue; if (op->type != 0) { res = proxyProgressAsync(op, comm, &asyncOpCount); type = op->type; - if (res != ncclSuccess) op->type = 0; + if (res != ncclSuccess) closeConn = 1; } else if (pollfds[s].revents & POLLIN) { int closed; if (ncclSocketTryRecv(sock, &type, sizeof(int), &closed) != ncclSuccess) { @@ -1120,26 +1163,31 @@ void* ncclProxyService(void* _args) { closeConn = 1; } if (closeConn) { - close(sock->fd); - sock->fd = pollfds[s].fd = -1; + ncclSocketClose(sock); + if (op->reqBuff) { + free(op->reqBuff); + op->reqBuff = NULL; + } + if (op->respBuff) { + free(op->respBuff); + op->respBuff = NULL; + } + op->type = 0; + pollfds[s].fd = -1; npeers--; } } } - /* wait until main thread flush all NCCL 
operations. */ - while (*comm->abortFlag != 0 && __atomic_load_n(&comm->proxyState.safeAbortFlag, __ATOMIC_ACQUIRE) == 0) - usleep(1000); // Wait for all operations to complete and stop progress thread before freeing any resource if (ncclProxyProgressDestroy(comm) != ncclSuccess) { WARN("[Proxy Service] proxyDestroy failed"); } for (int s=0; sproxyState.listenSock->fd); - free(comm->proxyState.listenSock); + ncclSocketClose(comm->proxyState.listenSock); proxyOpsFree(comm); return NULL; } @@ -1164,32 +1212,29 @@ ncclResult_t ncclProxyDestroy(struct ncclComm* comm) { if (state->peerAddresses) { if (*comm->abortFlag == 0) { struct ncclSocket sock; - sock.abortFlag = NULL; - sock.asyncFlag = 0; - memcpy(&sock.addr, comm->proxyState.peerAddresses+comm->rank, sizeof(union ncclSocketAddress)); - NCCLCHECK(ncclSocketConnect(&sock)); int type = ncclProxyMsgStop; + NCCLCHECK(ncclSocketInit(&sock, comm->proxyState.peerAddresses + comm->rank, comm->magic, ncclSocketTypeProxy, comm->abortFlag)); + NCCLCHECK(ncclSocketConnect(&sock)); NCCLCHECK(ncclSocketSend(&sock, &type, sizeof(int))); - close(sock.fd); - } else { - /* when abortFlag is set, all socket related communications are no longer reliable. We need to - * set a flag to let proxy thread exit. 
*/ - __atomic_store_n(&state->safeAbortFlag, 1, __ATOMIC_RELEASE); + NCCLCHECK(ncclSocketClose(&sock)); } free(state->peerAddresses); } + if (state->peerSocks) { for (int i=0; ilocalRanks; i++) { - if (state->peerSocks[i].fd != -1) { + int fd; + NCCLCHECK(ncclSocketGetFd(state->peerSocks + i, &fd)); + if (fd >= 0) { if (state->proxyOps[i].pool) { - NCCLCHECK(ncclShmClose(state->proxyOps[i].pool, NULL, sizeof(struct ncclProxyOpsPool))); + NCCLCHECK(ncclShmClose(state->proxyOps[i].handle)); } if (state->sharedDevMems[i]) { CUDACHECK(cudaIpcCloseMemHandle(state->sharedDevMems[i])); } int type = ncclProxyMsgClose; - if (*comm->abortFlag == 0) NCCLCHECK(ncclSocketSend(state->peerSocks+i, &type, sizeof(int))); - close(state->peerSocks[i].fd); + if (*comm->abortFlag == 0) NCCLCHECK(ncclSocketSend(state->peerSocks + i, &type, sizeof(int))); + NCCLCHECK(ncclSocketClose(state->peerSocks + i)); } } free(state->peerSocks); diff --git a/src/transport.cc b/src/transport.cc index b3ca90d..66d8b51 100644 --- a/src/transport.cc +++ b/src/transport.cc @@ -67,12 +67,11 @@ void dumpData(struct ncclConnect* data, int ndata) { ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType/*=NULL*/) { // Stream used during transport setup; need for P2P pre-connect + CUDA Graph + ncclResult_t ret = ncclSuccess; int highestType = TRANSPORT_P2P; // track highest transport type - - cudaStream_t transportSetupStream; - CUDACHECK(cudaStreamCreateWithFlags(&transportSetupStream, cudaStreamNonBlocking)); - struct ncclConnect data[2*MAXCHANNELS]; + + NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->hostStream), ret, fail); for (int i=1; inRanks; i++) { int bootstrapTag = (i<<8) + (graph ? 
graph->id+1 : 0); int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks; @@ -86,7 +85,7 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* TIME_START(0); for (int c=0; c(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex, &type)); + NCCLCHECKGOTO(selectTransport<0>(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex, &type), ret, fail); if (type > highestType) highestType = type; } } @@ -95,7 +94,7 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* struct ncclConnect* sendData = recvData+recvChannels; for (int c=0; c(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex, &type)); + NCCLCHECKGOTO(selectTransport<1>(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex, &type), ret, fail); if (type > highestType) highestType = type; } } @@ -104,16 +103,16 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* TIME_START(2); if (sendPeer == recvPeer) { if (recvChannels+sendChannels) { - NCCLCHECK(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels))); - NCCLCHECK(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels))); + NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); + NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); sendData = data; recvData = data+sendChannels; } } else { - if (recvChannels) NCCLCHECK(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels)); - if (sendChannels) NCCLCHECK(bootstrapSend(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels)); - if (sendChannels) NCCLCHECK(bootstrapRecv(comm->bootstrap, sendPeer, 
bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels)); - if (recvChannels) NCCLCHECK(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels)); + if (recvChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels), ret, fail); + if (sendChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels), ret, fail); + if (sendChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels), ret, fail); + if (recvChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels), ret, fail); } TIME_STOP(2); @@ -121,9 +120,9 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* for (int c=0; cchannels[c].peers[sendPeer].send + connIndex; - NCCLCHECK(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn)); + NCCLCHECKGOTO(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn), ret, fail); conn->connected = 1; - CUDACHECK(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, transportSetupStream)); + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); } } TIME_STOP(3); @@ -131,19 +130,23 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* for (int c=0; cchannels[c].peers[recvPeer].recv + connIndex; - NCCLCHECK(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn)); + NCCLCHECKGOTO(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn), ret, fail); conn->connected = 1; - 
CUDACHECK(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, transportSetupStream)); + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); } } TIME_STOP(4); comm->connectRecv[recvPeer] = comm->connectSend[sendPeer] = 0UL; } - CUDACHECK(cudaStreamSynchronize(transportSetupStream)); - CUDACHECK(cudaStreamDestroy(transportSetupStream)); + if (highestTransportType != NULL) *highestTransportType = highestType; TIME_PRINT("P2P Setup/Connect"); - return ncclSuccess; +exit: + NCCLCHECK(ncclStrongStreamWaitStream(ncclCudaGraphNone(), &comm->deviceStream, &comm->hostStream)); + NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->hostStream)); + return ret; +fail: + goto exit; } extern struct ncclTransport collNetTransport; diff --git a/src/transport/net.cc b/src/transport/net.cc index a3a1579..bdb2e2d 100644 --- a/src/transport/net.cc +++ b/src/transport/net.cc @@ -63,6 +63,7 @@ struct connectMapMem{ char shmPath[PATH_MAX]; cudaIpcMemHandle_t ipc; }; + ncclShmHandle_t handle; }; struct connectMap { @@ -224,13 +225,12 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph } static ncclResult_t netMapShm(struct connectMapMem* mem) { - NCCLCHECK(ncclShmOpen(mem->shmPath, mem->size, (void**)&mem->cpuPtr, (void**)&mem->gpuPtr, 0)); - NCCLCHECK(ncclShmUnlink(mem->shmPath)); + NCCLCHECK(ncclShmOpen(mem->shmPath, mem->size, (void**)&mem->cpuPtr, (void**)&mem->gpuPtr, -1, &mem->handle)); return ncclSuccess; } static ncclResult_t netCreateShm(struct connectMapMem* mem) { mem->shmPath[0] = '\0'; // Let ncclShmOpen create a tmp file - NCCLCHECK(ncclShmOpen(mem->shmPath, mem->size, (void**)&mem->cpuPtr, NULL, 1)); + NCCLCHECK(ncclShmOpen(mem->shmPath, mem->size, (void**)&mem->cpuPtr, NULL, 1, &mem->handle)); return 
ncclSuccess; } @@ -339,7 +339,7 @@ static ncclResult_t sendFree(struct ncclConnector* send) { struct connectMap* map = (struct connectMap*)(send->transportResources); if (map) { if (map->sameProcess == 0) { - NCCLCHECK(ncclShmClose(map->mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, map->mems[NCCL_NET_MAP_HOSTMEM].gpuPtr, map->mems[NCCL_NET_MAP_HOSTMEM].size)); + NCCLCHECK(ncclShmClose(map->mems[NCCL_NET_MAP_HOSTMEM].handle)); if (map->mems[NCCL_NET_MAP_DEVMEM].size) { CUDACHECK(cudaIpcCloseMemHandle(map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr)); } @@ -372,7 +372,7 @@ static ncclResult_t sharedBuffersInit(struct ncclComm* comm, int cuda, int local struct ncclProxySharedP2p* state = type == 0 ? &peer->send : &peer->recv; state->refcount++; if (state->size == 0) { - state->size = nChannels*NCCL_SHARED_STEPS*comm->p2pNetChunkSize; + state->size = nChannels*NCCL_SHARED_STEPS*comm->p2pChunkSize; } if (size) *size = state->size; @@ -399,7 +399,7 @@ static ncclResult_t sharedBuffersInit(struct ncclComm* comm, int cuda, int local static ncclResult_t sharedBuffersGet(struct ncclComm* comm, int channel, int slot, int* offset) { // Use different pools for different channels and also separate send/recv. 
int globalSlot = (channel*NCCL_SHARED_STEPS)+slot; - *offset = comm->p2pNetChunkSize * globalSlot; + *offset = comm->p2pChunkSize * globalSlot; return ncclSuccess; } @@ -523,6 +523,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm)); connection->proxyAppendPtr = &connection->proxyAppend; } + if (resources->netSendComm == NULL) { *done = 0; return ncclSuccess; @@ -657,6 +658,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm)); connection->proxyAppendPtr = &connection->proxyAppend; } + if (resources->netRecvComm == NULL) { *done = 0; return ncclSuccess; @@ -746,67 +748,75 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str static ncclResult_t sendProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { struct sendResources* resources = (struct sendResources*)(connection->transportResources); - if (resources == NULL) { // NVB Preconnect + if (connection->state == connSharedInitialized) { // NVB Preconnect NCCLCHECK(sharedBuffersDestroy(comm, connection->localRank, 0)); return ncclSuccess; } - for (int p=0; pbuffers[p]) { - NCCLCHECK(ncclNetDeregMr(comm, resources->netSendComm, resources->mhandles[p])); + + if (connection->state == connConnected) { + for (int p=0; pbuffers[p]) { + NCCLCHECK(ncclNetDeregMr(comm, resources->netSendComm, resources->mhandles[p])); + } } - } - struct connectMapMem* mems = resources->map.mems; - if (resources->map.sameProcess) { - NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr)); - } else { - NCCLCHECK(ncclShmClose(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, NULL, mems[NCCL_NET_MAP_HOSTMEM].size)); - } - CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); - if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); - if 
(resources->shared) { - NCCLCHECK(sharedBuffersDestroy(comm, resources->localRank, 0)); - if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { - struct ncclSharedNetComms* comms = comm->proxyState.progressState.netComms[resources->netDev]+resources->remoteRank; - comms->sendRefCount[resources->channelId]--; - if (comms->sendRefCount[resources->channelId] == 0) NCCLCHECK(ncclNetCloseSend(comm, comms->sendComm[resources->channelId])); + struct connectMapMem* mems = resources->map.mems; + if (resources->map.sameProcess) { + NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr)); + } else { + NCCLCHECK(ncclShmClose(mems[NCCL_NET_MAP_HOSTMEM].handle)); + } + CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); + if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); + if (resources->shared) { + NCCLCHECK(sharedBuffersDestroy(comm, resources->localRank, 0)); + if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { + struct ncclSharedNetComms* comms = comm->proxyState.progressState.netComms[resources->netDev]+resources->remoteRank; + comms->sendRefCount[resources->channelId]--; + if (comms->sendRefCount[resources->channelId] == 0) NCCLCHECK(ncclNetCloseSend(comm, comms->sendComm[resources->channelId])); + } else { + NCCLCHECK(ncclNetCloseSend(comm, resources->netSendComm)); + } } else { NCCLCHECK(ncclNetCloseSend(comm, resources->netSendComm)); } - } else { - NCCLCHECK(ncclNetCloseSend(comm, resources->netSendComm)); } - free(resources); + + if (connection->state == connSetupDone) free(resources); return ncclSuccess; } static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { struct recvResources* resources = (struct recvResources*)(connection->transportResources); - if (resources == NULL) { // NVB Preconnect + if (connection->state == connSharedInitialized) { // NVB Preconnect NCCLCHECK(sharedBuffersDestroy(comm, connection->localRank, 1)); return ncclSuccess; } - for (int p=0; 
pbuffers[p]) { - NCCLCHECK(ncclNetDeregMr(comm, resources->netRecvComm, resources->mhandles[p])); + + if (connection->state == connConnected) { + for (int p=0; pbuffers[p]) { + NCCLCHECK(ncclNetDeregMr(comm, resources->netRecvComm, resources->mhandles[p])); + } } - } - struct connectMapMem* mems = resources->map.mems; - NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr)); - CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); - if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); - if (resources->shared) { - NCCLCHECK(sharedBuffersDestroy(comm, resources->localRank, 1)); - if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { - struct ncclSharedNetComms* comms = comm->proxyState.progressState.netComms[resources->netDev]+resources->proxyRank; - comms->recvRefCount[resources->channelId]--; - if (comms->recvRefCount[resources->channelId] == 0) NCCLCHECK(ncclNetCloseRecv(comm, comms->recvComm[resources->channelId])); + struct connectMapMem* mems = resources->map.mems; + NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr)); + CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); + if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); + if (resources->shared) { + NCCLCHECK(sharedBuffersDestroy(comm, resources->localRank, 1)); + if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { + struct ncclSharedNetComms* comms = comm->proxyState.progressState.netComms[resources->netDev]+resources->proxyRank; + comms->recvRefCount[resources->channelId]--; + if (comms->recvRefCount[resources->channelId] == 0) NCCLCHECK(ncclNetCloseRecv(comm, comms->recvComm[resources->channelId])); + } else { + NCCLCHECK(ncclNetCloseRecv(comm, resources->netRecvComm)); + } } else { NCCLCHECK(ncclNetCloseRecv(comm, resources->netRecvComm)); } - } else { - NCCLCHECK(ncclNetCloseRecv(comm, resources->netRecvComm)); } - free(resources); + + if (connection->state == connSetupDone) free(resources); return 
ncclSuccess; } diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc index a1f8897..8818554 100644 --- a/src/transport/net_ib.cc +++ b/src/transport/net_ib.cc @@ -57,6 +57,7 @@ struct alignas(64) ncclIbDev { int realPort; int maxQp; struct ncclIbMrCache mrCache; + int ar; // ADAPTIVE_ROUTING }; #define MAX_IB_PORT 15 @@ -80,6 +81,7 @@ NCCL_PARAM(IbSl, "IB_SL", 0); NCCL_PARAM(IbTc, "IB_TC", 0); NCCL_PARAM(IbArThreshold, "IB_AR_THRESHOLD", 8192); NCCL_PARAM(IbPciRelaxedOrdering, "IB_PCI_RELAXED_ORDERING", 2); +NCCL_PARAM(IbAdaptiveRouting, "IB_ADAPTIVE_ROUTING", -2); pthread_t ncclIbAsyncThread; static void* ncclIbAsyncThreadMain(void* args) { @@ -221,6 +223,11 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { ncclIbDevs[ncclNIbDevs].mrCache.population = 0; ncclIbDevs[ncclNIbDevs].mrCache.slots = NULL; + // Enable ADAPTIVE_ROUTING by default on IB networks + // But allow it to be overloaded by an env parameter + ncclIbDevs[ncclNIbDevs].ar = (portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND) ? 
1 : 0; + if (ncclParamIbAdaptiveRouting() != -2) ncclIbDevs[ncclNIbDevs].ar = ncclParamIbAdaptiveRouting(); + pthread_create(&ncclIbAsyncThread, NULL, ncclIbAsyncThreadMain, context); ncclSetThreadName(ncclIbAsyncThread, "NCCL IbAsync %2d", ncclNIbDevs); pthread_detach(ncclIbAsyncThread); // will not be pthread_join()'d @@ -298,11 +305,6 @@ failure: return ncclSystemError; } -static ncclResult_t GetSocketAddr(union ncclSocketAddress* addr) { - memcpy(addr, &ncclIbIfAddr, sizeof(*addr)); - return ncclSuccess; -} - #define NCCL_NET_IB_MAX_RECVS 8 ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) { @@ -364,6 +366,7 @@ struct ncclIbCommStage { struct ncclIbHandle { union ncclSocketAddress connectAddr; // Filled by the target + uint64_t magic; // random number to help debugging struct ncclIbCommStage stage; // Used by the other side when connecting }; @@ -376,7 +379,7 @@ struct ncclIbRequest { struct ncclIbVerbs* verbs; int type; int events; - union ncclSocketAddress *addr; + struct ncclSocket* sock; int nreqs; union { struct { @@ -427,6 +430,7 @@ struct ncclIbSendComm { struct ibv_qp* qps[NCCL_IB_MAX_QPS]; int nqps; struct ibv_mr* fifoMr; + int ar; }; // The SendFifo needs to be 32-byte aligned and each element needs // to be a 32-byte multiple, so that an entry does not get split and @@ -571,19 +575,19 @@ ncclResult_t ncclIbListen(int dev, void* opaqueHandle, void** listenComm) { static_assert(sizeof(struct ncclIbHandle) < NCCL_NET_HANDLE_MAXSIZE, "ncclIbHandle size too large"); memset(handle, 0, sizeof(struct ncclIbHandle)); comm->dev = dev; - comm->sock.asyncFlag = 1; /* nonblocking socket is required by network communication. 
*/ - NCCLCHECK(GetSocketAddr(&comm->sock.addr)); + handle->magic = NCCL_SOCKET_MAGIC; + NCCLCHECK(ncclSocketInit(&comm->sock, &ncclIbIfAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1)); NCCLCHECK(ncclSocketListen(&comm->sock)); - memcpy(&handle->connectAddr, &comm->sock.addr, sizeof(union ncclSocketAddress)); + NCCLCHECK(ncclSocketGetAddr(&comm->sock, &handle->connectAddr)); *listenComm = comm; return ncclSuccess; } ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) { struct ncclIbHandle* handle = (struct ncclIbHandle*) opaqueHandle; - enum ncclSocketState conState; struct ncclIbCommStage* stage = &handle->stage; struct ncclIbSendComm* comm = (struct ncclIbSendComm*)stage->comm; + int ready; *sendComm = NULL; if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; @@ -594,20 +598,15 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) { } NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(struct ncclIbSendComm))); - NCCLCHECK(ncclSocketInit(&comm->sock, &handle->connectAddr, NULL, 1)); + NCCLCHECK(ncclSocketInit(&comm->sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1)); stage->comm = comm; stage->state = ncclIbCommStateConnect; NCCLCHECK(ncclSocketConnect(&comm->sock)); ib_connect_check: /* since ncclSocketConnect is async, we must check if connection is complete */ - NCCLCHECK(ncclGetSocketState(&comm->sock, &conState)); - if (conState == ncclSocketConnecting) { - /* expect user to call again */ - return ncclSuccess; - } else if (conState == ncclSocketError) { - return ncclRemoteError; - } + NCCLCHECK(ncclSocketReady(&comm->sock, &ready)); + if (!ready) return ncclSuccess; // IB Setup struct ibv_context* ctx; @@ -619,6 +618,7 @@ ib_connect_check: for (int q=0; qnqps; q++) { NCCLCHECK(ncclIbCreateQp(ib_port, &comm->verbs, IBV_ACCESS_REMOTE_WRITE, comm->qps+q)); } + comm->ar = ncclIbDevs[dev].ar; // ADAPTIVE_ROUTING // Send my QP Info to receiver through the socket. Hope this won't block. 
struct ibv_port_attr portAttr; @@ -670,9 +670,10 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) { struct ncclIbListenComm* lComm = (struct ncclIbListenComm*)listenComm; struct ncclIbCommStage* stage = &lComm->stage; struct ncclIbRecvComm* rComm = (struct ncclIbRecvComm*)stage->comm; + int ready; *recvComm = NULL; - if (stage->state == ncclIbCommStateAccept) goto ib_accept; + if (stage->state == ncclIbCommStateAccept) goto ib_accept_check; if (stage->state == ncclIbCommStateRecv) goto ib_recv; if (stage->state == ncclIbCommStateSend) goto ib_send; if (stage->state != ncclIbCommStateStart) { @@ -683,12 +684,12 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) { NCCLCHECK(ncclIbMalloc((void**)&rComm, sizeof(struct ncclIbRecvComm))); stage->comm = rComm; stage->state = ncclIbCommStateAccept; - NCCLCHECK(ncclSocketInit(&rComm->sock, NULL, lComm->sock.abortFlag, 1)); - -ib_accept: + NCCLCHECK(ncclSocketInit(&rComm->sock)); NCCLCHECK(ncclSocketAccept(&rComm->sock, &lComm->sock)); - if (rComm->sock.fd == -1) - return ncclSuccess; + +ib_accept_check: + NCCLCHECK(ncclSocketReady(&rComm->sock, &ready)); + if (!ready) return ncclSuccess; struct ncclIbQpInfo remQpInfo; stage->state = ncclIbCommStateRecv; @@ -791,7 +792,7 @@ ncclResult_t ncclIbGetRequest(struct ncclIbVerbs* verbs, struct ncclIbRequest** if (r->type == NCCL_NET_IB_REQ_UNUSED) { r->verbs = verbs; r->events = 1; - r->addr = NULL; + r->sock = NULL; *req = r; return ncclSuccess; } @@ -966,8 +967,8 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { } struct ibv_send_wr* lastWr = comm->wrs+nreqs-1; - if (nreqs > 1 || reqs[0]->send.size > ncclParamIbArThreshold()) { - // When using adaptive routing, send the bulk of the data first as an + if (nreqs > 1 || (comm->ar && reqs[0]->send.size > ncclParamIbArThreshold())) { + // When using ADAPTIVE_ROUTING, send the bulk of the data first as an // RDMA_WRITE, then a 0-byte RDMA_WRITE_WITH_IMM to trigger a remote // 
completion. lastWr++; @@ -1033,28 +1034,31 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh // Sanity checks to catch user collective call count/size mismatches if (size > slots[r].size) { - char line[SOCKET_NAME_MAXLEN+1]; + char line[SOCKET_NAME_MAXLEN + 1]; + union ncclSocketAddress addr; + ncclSocketGetAddr(&comm->sock, &addr); WARN("NET/IB : req %d/%d tag %x peer %s collective mismatch error, local size %d remote size %d", - r, nreqs, tag, ncclSocketToString(&comm->sock.addr, line), size, slots[r].size); + r, nreqs, tag, ncclSocketToString(&addr, line), size, slots[r].size); return ncclInvalidUsage; } // plus any potential programming errors else if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkey == 0) { - char line[SOCKET_NAME_MAXLEN+1]; - WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %d addr %lx rkey %x", - r, nreqs, tag, ncclSocketToString(&comm->sock.addr, line), slots[r].size, slots[r].addr, slots[r].rkey); + char line[SOCKET_NAME_MAXLEN + 1]; + union ncclSocketAddress addr; + ncclSocketGetAddr(&comm->sock, &addr); + WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %d addr %lx rkey %x", + r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkey); return ncclInternalError; } struct ncclIbRequest* req; NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); req->type = NCCL_NET_IB_REQ_SEND; - req->addr = &comm->sock.addr; + req->sock = &comm->sock; req->verbs = &comm->verbs; req->nreqs = nreqs; req->send.size = size; req->send.data = data; req->send.lkey = mr->lkey; req->send.offset = 0; - req->addr = &comm->sock.addr; req->events = comm->nqps; *request = reqs[r] = req; @@ -1147,7 +1151,7 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* ta struct ncclIbRequest* req; NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); req->type = NCCL_NET_IB_REQ_RECV; - req->addr = &comm->sock.addr; + req->sock = 
&comm->sock; req->nreqs = n; for (int i=0; irecv.sizes[i] = 0; @@ -1186,7 +1190,7 @@ ncclResult_t ncclIbIflush(void* recvComm, int n, void** data, int* sizes, void** struct ncclIbRequest* req; NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); req->type = NCCL_NET_IB_REQ_FLUSH; - req->addr = &comm->sock.addr; + req->sock = &comm->sock; struct ibv_mr* mr = (struct ibv_mr*)mhandles[last]; struct ibv_send_wr wr; @@ -1234,8 +1238,10 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { struct ibv_wc *wc = wcs+w; if (wc->status != IBV_WC_SUCCESS) { char line[SOCKET_NAME_MAXLEN+1]; + union ncclSocketAddress addr; + ncclSocketGetAddr(r->sock, &addr); WARN("NET/IB : Got completion from peer %s with error %d, opcode %d, len %d, vendor err %d", - ncclSocketToString(r->addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err); + ncclSocketToString(&addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err); return ncclRemoteError; } @@ -1267,7 +1273,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { ncclResult_t ncclIbCloseSend(void* sendComm) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; if (comm) { - close(comm->sock.fd); + NCCLCHECK(ncclSocketClose(&comm->sock)); for (int q=0; qnqps; q++) if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q])); if (comm->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->fifoMr)); @@ -1281,7 +1287,7 @@ ncclResult_t ncclIbCloseSend(void* sendComm) { ncclResult_t ncclIbCloseRecv(void* recvComm) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; if (comm) { - close(comm->sock.fd); + NCCLCHECK(ncclSocketClose(&comm->sock)); for (int q=0; qnqps; q++) if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q])); if (comm->gpuFlush.enabled) { @@ -1298,7 +1304,7 @@ ncclResult_t ncclIbCloseRecv(void* recvComm) { ncclResult_t ncclIbCloseListen(void* listenComm) { struct ncclIbListenComm* comm = (struct ncclIbListenComm*)listenComm; if (comm) { - 
close(comm->sock.fd); + NCCLCHECK(ncclSocketClose(&comm->sock)); free(comm); } return ncclSuccess; diff --git a/src/transport/net_socket.cc b/src/transport/net_socket.cc index 678aab8..08a8c3a 100644 --- a/src/transport/net_socket.cc +++ b/src/transport/net_socket.cc @@ -18,16 +18,16 @@ /* Init functions */ static int ncclNetIfs = -1; -struct ncclSocketDev { +struct ncclNetSocketDev { union ncclSocketAddress addr; char devName[MAX_IF_NAME_SIZE]; char* pciPath; }; -static struct ncclSocketDev ncclSocketDevs[MAX_IFS]; +static struct ncclNetSocketDev ncclNetSocketDevs[MAX_IFS]; -pthread_mutex_t ncclSocketLock = PTHREAD_MUTEX_INITIALIZER; +pthread_mutex_t ncclNetSocketLock = PTHREAD_MUTEX_INITIALIZER; -static ncclResult_t ncclSocketGetPciPath(char* devName, char** pciPath) { +static ncclResult_t ncclNetSocketGetPciPath(char* devName, char** pciPath) { char devicePath[PATH_MAX]; snprintf(devicePath, PATH_MAX, "/sys/class/net/%s/device", devName); // May return NULL if the file doesn't exist. 
@@ -35,9 +35,9 @@ static ncclResult_t ncclSocketGetPciPath(char* devName, char** pciPath) { return ncclSuccess; } -ncclResult_t ncclSocketInit(ncclDebugLogger_t logFunction) { +ncclResult_t ncclNetSocketInit(ncclDebugLogger_t logFunction) { if (ncclNetIfs == -1) { - pthread_mutex_lock(&ncclSocketLock); + pthread_mutex_lock(&ncclNetSocketLock); if (ncclNetIfs == -1) { char names[MAX_IF_NAME_SIZE*MAX_IFS]; union ncclSocketAddress addrs[MAX_IFS]; @@ -52,9 +52,9 @@ ncclResult_t ncclSocketInit(ncclDebugLogger_t logFunction) { line[0] = '\0'; addrline[SOCKET_NAME_MAXLEN] = '\0'; for (int i=0; iname = ncclSocketDevs[dev].devName; - props->pciPath = ncclSocketDevs[dev].pciPath; +ncclResult_t ncclNetSocketGetProperties(int dev, ncclNetProperties_t* props) { + props->name = ncclNetSocketDevs[dev].devName; + props->pciPath = ncclNetSocketDevs[dev].pciPath; props->guid = dev; props->ptrSupport = NCCL_PTR_HOST; - NCCLCHECK(ncclSocketGetSpeed(props->name, &props->speed)); + NCCLCHECK(ncclNetSocketGetSpeed(props->name, &props->speed)); props->latency = 0; // Not set props->port = 0; props->maxComms = 65536; @@ -104,12 +104,6 @@ ncclResult_t ncclSocketGetProperties(int dev, ncclNetProperties_t* props) { return ncclSuccess; } -ncclResult_t GetSocketAddr(int dev, union ncclSocketAddress* addr) { - if (dev >= ncclNetIfs) return ncclInternalError; - memcpy(addr, &ncclSocketDevs[dev].addr, sizeof(*addr)); - return ncclSuccess; -} - /* Communication functions */ #define MAX_SOCKETS 64 @@ -120,29 +114,30 @@ ncclResult_t GetSocketAddr(int dev, union ncclSocketAddress* addr) { NCCL_PARAM(SocketNsocksPerThread, "NSOCKS_PERTHREAD", -2); NCCL_PARAM(SocketNthreads, "SOCKET_NTHREADS", -2); -enum ncclSocketCommState { - ncclSocketCommStateStart = 0, - ncclSocketCommStateConnect = 1, - ncclSocketCommStateAccept = 3, - ncclSocketCommStateSend = 4, - ncclSocketCommStateRecv = 5, +enum ncclNetSocketCommState { + ncclNetSocketCommStateStart = 0, + ncclNetSocketCommStateConnect = 1, + 
ncclNetSocketCommStateAccept = 3, + ncclNetSocketCommStateSend = 4, + ncclNetSocketCommStateRecv = 5, }; -struct ncclSocketCommStage { - enum ncclSocketCommState state; +struct ncclNetSocketCommStage { + enum ncclNetSocketCommState state; uint8_t iteration; struct ncclSocket* sock; - struct ncclSocketComm* comm; + struct ncclNetSocketComm* comm; }; -struct ncclSocketHandle { +struct ncclNetSocketHandle { union ncclSocketAddress connectAddr; + uint64_t magic; // random number to help debugging int nSocks; int nThreads; - struct ncclSocketCommStage stage; + struct ncclNetSocketCommStage stage; }; -struct ncclSocketTask { +struct ncclNetSocketTask { int op; void* data; int size; @@ -152,41 +147,41 @@ struct ncclSocketTask { ncclResult_t result; }; -struct ncclSocketRequest { +struct ncclNetSocketRequest { int op; void* data; int size; struct ncclSocket* ctrlSock; int offset; int used; - struct ncclSocketComm* comm; - struct ncclSocketTask* tasks[MAX_SOCKETS]; + struct ncclNetSocketComm* comm; + struct ncclNetSocketTask* tasks[MAX_SOCKETS]; int nSubs; }; -struct ncclSocketTaskQueue { +struct ncclNetSocketTaskQueue { int next; int len; - struct ncclSocketTask* tasks; + struct ncclNetSocketTask* tasks; }; -struct ncclSocketThreadResources { - struct ncclSocketTaskQueue threadTaskQueue; +struct ncclNetSocketThreadResources { + struct ncclNetSocketTaskQueue threadTaskQueue; int stop; - struct ncclSocketComm* comm; + struct ncclNetSocketComm* comm; pthread_mutex_t threadLock; pthread_cond_t threadCond; }; -struct ncclSocketListenComm { +struct ncclNetSocketListenComm { struct ncclSocket sock; - struct ncclSocketCommStage stage; + struct ncclNetSocketCommStage stage; int nSocks; int nThreads; int dev; }; -struct ncclSocketComm { +struct ncclNetSocketComm { struct ncclSocket ctrlSock; struct ncclSocket socks[MAX_SOCKETS]; int dev; @@ -194,15 +189,15 @@ struct ncclSocketComm { int nSocks; int nThreads; int nextSock; - struct ncclSocketRequest requests[MAX_REQUESTS]; + struct 
ncclNetSocketRequest requests[MAX_REQUESTS]; pthread_t helperThread[MAX_THREADS]; - struct ncclSocketThreadResources threadResources[MAX_THREADS]; + struct ncclNetSocketThreadResources threadResources[MAX_THREADS]; }; void* persistentSocketThread(void *args_) { - struct ncclSocketThreadResources* resource = (struct ncclSocketThreadResources*)args_; - struct ncclSocketComm* comm = resource->comm; - struct ncclSocketTaskQueue* myQueue = &resource->threadTaskQueue; + struct ncclNetSocketThreadResources* resource = (struct ncclNetSocketThreadResources*)args_; + struct ncclNetSocketComm* comm = resource->comm; + struct ncclNetSocketTaskQueue* myQueue = &resource->threadTaskQueue; int nSocksPerThread = comm->nSocks / comm->nThreads; while (1) { int idle = 1; @@ -212,7 +207,7 @@ void* persistentSocketThread(void *args_) { do { repeat = 0; for (int j=0; jtasks+i+j; + struct ncclNetSocketTask* r = myQueue->tasks+i+j; if (r != NULL && r->used == 1 && r->offset < r->size) { r->result = ncclSocketProgress(r->op, r->sock, r->data, r->size, &r->offset); if (r->result != ncclSuccess) { @@ -236,7 +231,7 @@ void* persistentSocketThread(void *args_) { } } -ncclResult_t ncclSocketGetNsockNthread(int dev, int* ns, int* nt) { +ncclResult_t ncclNetSocketGetNsockNthread(int dev, int* ns, int* nt) { int nSocksPerThread = ncclParamSocketNsocksPerThread(); int nThreads = ncclParamSocketNthreads(); if (nThreads > MAX_THREADS) { @@ -247,7 +242,7 @@ ncclResult_t ncclSocketGetNsockNthread(int dev, int* ns, int* nt) { // Auto-detection int autoNt=0, autoNs=1; // By default, we only use the main thread and do not spawn extra threads char vendorPath[PATH_MAX]; - snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", ncclSocketDevs[dev].devName); + snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", ncclNetSocketDevs[dev].devName); char* rPath = realpath(vendorPath, NULL); int fd = open(rPath, O_RDONLY); free(rPath); @@ -285,36 +280,20 @@ end: return ncclSuccess; } 
-ncclResult_t ncclSocketNewListenComm(struct ncclSocketListenComm** comm) { - NCCLCHECK(ncclCalloc(comm, 1)); - (*comm)->sock.fd = -1; - return ncclSuccess; -} - -ncclResult_t ncclSocketNewComm(struct ncclSocketComm** comm) { - NCCLCHECK(ncclCalloc(comm, 1)); - (*comm)->ctrlSock.fd = -1; - for (int i=0; i < MAX_SOCKETS; i++) { - (*comm)->socks[i].fd = -1; - } - (*comm)->nextSock = 0; - return ncclSuccess; -} - -ncclResult_t ncclSocketListen(int dev, void* opaqueHandle, void** listenComm) { - if (dev < 0) { // data transfer socket is based on specified dev +ncclResult_t ncclNetSocketListen(int dev, void* opaqueHandle, void** listenComm) { + if (dev < 0 || dev >= ncclNetIfs) { // data transfer socket is based on specified dev return ncclInternalError; } - struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle; - memset(handle, 0, sizeof(struct ncclSocketHandle)); - static_assert(sizeof(struct ncclSocketHandle) <= NCCL_NET_HANDLE_MAXSIZE, "ncclSocketHandle size too large"); - struct ncclSocketListenComm* comm; - NCCLCHECK(ncclSocketNewListenComm(&comm)); - NCCLCHECK(GetSocketAddr(dev, &comm->sock.addr)); - comm->sock.asyncFlag = 1; + struct ncclNetSocketHandle* handle = (struct ncclNetSocketHandle*) opaqueHandle; + memset(handle, 0, sizeof(struct ncclNetSocketHandle)); + static_assert(sizeof(struct ncclNetSocketHandle) <= NCCL_NET_HANDLE_MAXSIZE, "ncclNetSocketHandle size too large"); + struct ncclNetSocketListenComm* comm; + NCCLCHECK(ncclCalloc(&comm, 1)); + handle->magic = NCCL_SOCKET_MAGIC; + NCCLCHECK(ncclSocketInit(&comm->sock, &ncclNetSocketDevs[dev].addr, handle->magic, ncclSocketTypeNetSocket, NULL, 1)); NCCLCHECK(ncclSocketListen(&comm->sock)); - memcpy(&handle->connectAddr, &comm->sock.addr, sizeof(union ncclSocketAddress)); - NCCLCHECK(ncclSocketGetNsockNthread(dev, &comm->nSocks, &comm->nThreads)); + NCCLCHECK(ncclSocketGetAddr(&comm->sock, &handle->connectAddr)); + NCCLCHECK(ncclNetSocketGetNsockNthread(dev, &comm->nSocks, 
&comm->nThreads)); handle->nSocks = comm->nSocks; handle->nThreads = comm->nThreads; comm->dev = dev; @@ -322,46 +301,41 @@ ncclResult_t ncclSocketListen(int dev, void* opaqueHandle, void** listenComm) { return ncclSuccess; } -ncclResult_t ncclSocketConnect(int dev, void* opaqueHandle, void** sendComm) { - if (dev < 0) { // data transfer socket is based on specified dev +ncclResult_t ncclNetSocketConnect(int dev, void* opaqueHandle, void** sendComm) { + if (dev < 0 || dev >= ncclNetIfs) { // data transfer socket is based on specified dev return ncclInternalError; } - enum ncclSocketState conState; - struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle; - struct ncclSocketCommStage* stage = &handle->stage; - struct ncclSocketComm* comm = stage->comm; + int ready; + struct ncclNetSocketHandle* handle = (struct ncclNetSocketHandle*) opaqueHandle; + struct ncclNetSocketCommStage* stage = &handle->stage; + struct ncclNetSocketComm* comm = stage->comm; uint8_t i = stage->iteration; struct ncclSocket* sock = stage->sock; *sendComm = NULL; - if (stage->state == ncclSocketCommStateConnect) goto socket_connect_check; - if (stage->state == ncclSocketCommStateSend) goto socket_send; + if (stage->state == ncclNetSocketCommStateConnect) goto socket_connect_check; + if (stage->state == ncclNetSocketCommStateSend) goto socket_send; - NCCLCHECK(ncclSocketNewComm(&comm)); + NCCLCHECK(ncclCalloc(&comm, 1)); stage->comm = comm; comm->nSocks = handle->nSocks; comm->nThreads = handle->nThreads; comm->dev = dev; CUDACHECK(cudaGetDevice(&comm->cudaDev)); for (; inSocks+1; i++) { - sock = i == comm->nSocks ? &comm->ctrlSock : comm->socks+i; - NCCLCHECK(ncclSocketInit(sock, &handle->connectAddr, NULL, 1)); + sock = (i == comm->nSocks) ? 
&comm->ctrlSock : comm->socks+i; + NCCLCHECK(ncclSocketInit(sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetSocket, NULL, 1)); stage->sock = sock; - stage->state = ncclSocketCommStateConnect; + stage->state = ncclNetSocketCommStateConnect; stage->iteration = i; NCCLCHECK(ncclSocketConnect(sock)); socket_connect_check: - NCCLCHECK(ncclGetSocketState(sock, &conState)); - if (conState == ncclSocketConnecting) { - /* expect user to call again */ - return ncclSuccess; - } else if (conState == ncclSocketError) { - return ncclRemoteError; - } - stage->state = ncclSocketCommStateSend; + NCCLCHECK(ncclSocketReady(sock, &ready)); + if (! ready) return ncclSuccess; + stage->state = ncclNetSocketCommStateSend; socket_send: int done = 0; @@ -372,59 +346,63 @@ socket_send: return ncclSuccess; } -ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) { - struct ncclSocketListenComm* lComm = (struct ncclSocketListenComm*)listenComm; - struct ncclSocketCommStage* stage = &lComm->stage; - struct ncclSocketComm* rComm = stage->comm; +ncclResult_t ncclNetSocketAccept(void* listenComm, void** recvComm) { + struct ncclNetSocketListenComm* lComm = (struct ncclNetSocketListenComm*)listenComm; + struct ncclNetSocketCommStage* stage = &lComm->stage; + struct ncclNetSocketComm* rComm = stage->comm; uint8_t i = stage->iteration; struct ncclSocket* sock = stage->sock; + int ready; *recvComm = NULL; - if (stage->state == ncclSocketCommStateAccept) goto socket_accept; - if (stage->state == ncclSocketCommStateRecv) goto socket_recv; + if (stage->state == ncclNetSocketCommStateAccept) goto socket_accept_check; + if (stage->state == ncclNetSocketCommStateRecv) goto socket_recv; - NCCLCHECK(ncclSocketNewComm(&rComm)); + NCCLCHECK(ncclCalloc(&rComm, 1)); stage->comm = rComm; rComm->nSocks = lComm->nSocks; rComm->nThreads = lComm->nThreads; rComm->dev = lComm->dev; CUDACHECK(cudaGetDevice(&rComm->cudaDev)); - lComm->sock.asyncFlag = 1; for (; inSocks+1; i++) { uint8_t 
sendSockIdx; - ncclCalloc(&sock, 1); - NCCLCHECK(ncclSocketInit(sock, NULL, lComm->sock.abortFlag, 1)); - stage->sock = sock; - stage->state = ncclSocketCommStateAccept; - stage->iteration = i; -socket_accept: - NCCLCHECK(ncclSocketAccept(sock, &lComm->sock)); - if (sock->fd == -1) return ncclSuccess; - stage->state = ncclSocketCommStateRecv; + NCCLCHECK(ncclCalloc(&sock, 1)); + NCCLCHECK(ncclSocketInit(sock)); + stage->sock = sock; + stage->state = ncclNetSocketCommStateAccept; + stage->iteration = i; + NCCLCHECK(ncclSocketAccept(sock, &lComm->sock)); + +socket_accept_check: + NCCLCHECK(ncclSocketReady(sock, &ready)); + if (!ready) return ncclSuccess; + + stage->state = ncclNetSocketCommStateRecv; socket_recv: int done = 0; NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &sendSockIdx, sizeof(uint8_t), &done)); if (done == 0) return ncclSuccess; - if (sendSockIdx == rComm->nSocks) memcpy(&rComm->ctrlSock, sock, sizeof(struct ncclSocket)); - else memcpy(rComm->socks+sendSockIdx, sock, sizeof(struct ncclSocket)); - + if (sendSockIdx == rComm->nSocks) + memcpy(&rComm->ctrlSock, sock, sizeof(struct ncclSocket)); + else + memcpy(rComm->socks+sendSockIdx, sock, sizeof(struct ncclSocket)); free(sock); } *recvComm = rComm; /* reset lComm state */ - stage->state = ncclSocketCommStateStart; + stage->state = ncclNetSocketCommStateStart; stage->iteration = 0; stage->sock = NULL; stage->comm = NULL; return ncclSuccess; } -ncclResult_t ncclSocketGetRequest(struct ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketRequest** req) { +ncclResult_t ncclNetSocketGetRequest(struct ncclNetSocketComm* comm, int op, void* data, int size, struct ncclNetSocketRequest** req) { for (int i=0; irequests+i; + struct ncclNetSocketRequest* r = comm->requests+i; if (r->used == 0) { r->op = op; r->data = data; @@ -441,10 +419,10 @@ ncclResult_t ncclSocketGetRequest(struct ncclSocketComm* comm, int op, void* dat return ncclInternalError; } -ncclResult_t ncclSocketGetTask(struct 
ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketTask** req) { +ncclResult_t ncclNetSocketGetTask(struct ncclNetSocketComm* comm, int op, void* data, int size, struct ncclNetSocketTask** req) { int tid = comm->nextSock % comm->nThreads; - struct ncclSocketThreadResources* res = comm->threadResources+tid; - struct ncclSocketTaskQueue* queue = &res->threadTaskQueue; + struct ncclNetSocketThreadResources* res = comm->threadResources+tid; + struct ncclNetSocketTaskQueue* queue = &res->threadTaskQueue; // create helper threads and prepare per-thread task queue if (queue->tasks == NULL) { // each request can be divided up to nSocks tasks, and @@ -459,12 +437,12 @@ ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data, pthread_create(comm->helperThread+tid, NULL, persistentSocketThread, res); ncclSetThreadName(comm->helperThread[tid], "NCCL Sock%c%1u%2u%2u", op == NCCL_SOCKET_SEND ? 'S' : 'R', comm->dev, tid, comm->cudaDev); } - struct ncclSocketTask* r = queue->tasks+queue->next; + struct ncclNetSocketTask* r = queue->tasks+queue->next; if (r->used == 0) { r->op = op; r->data = data; r->size = size; - r->sock = comm->socks+comm->nextSock; + r->sock = comm->socks + comm->nextSock; r->offset = 0; r->result = ncclSuccess; comm->nextSock = (comm->nextSock + 1) % comm->nSocks; @@ -480,9 +458,9 @@ ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data, return ncclInternalError; } -ncclResult_t ncclSocketTest(void* request, int* done, int* size) { +ncclResult_t ncclNetSocketTest(void* request, int* done, int* size) { *done = 0; - struct ncclSocketRequest *r = (struct ncclSocketRequest*)request; + struct ncclNetSocketRequest *r = (struct ncclNetSocketRequest*)request; if (r == NULL) { WARN("NET/Socket : test called with NULL request"); return ncclInternalError; @@ -500,9 +478,11 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { // Check size is less or equal to the size provided by the user 
if (r->op == NCCL_SOCKET_RECV && data > r->size) { char line[SOCKET_NAME_MAXLEN+1]; + union ncclSocketAddress addr; + ncclSocketGetAddr(r->ctrlSock, &addr); WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d. If you believe your socket network is in healthy state, \ there may be a mismatch in collective sizes or environment settings (e.g. NCCL_PROTO, NCCL_ALGO) between ranks", - ncclSocketToString(&r->ctrlSock->addr, line), data, r->size); + ncclSocketToString(&addr, line), data, r->size); return ncclInvalidUsage; } r->size = data; @@ -515,7 +495,7 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks)); while (chunkOffset < r->size) { int chunkSize = std::min(taskSize, r->size-chunkOffset); - NCCLCHECK(ncclSocketGetTask(r->comm, r->op, (char*)(r->data)+chunkOffset, chunkSize, r->tasks+i++)); + NCCLCHECK(ncclNetSocketGetTask(r->comm, r->op, (char*)(r->data)+chunkOffset, chunkSize, r->tasks+i++)); chunkOffset += chunkSize; } } @@ -525,7 +505,7 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { if (r->nSubs > 0) { int nCompleted = 0; for (int i=0; inSubs; i++) { - struct ncclSocketTask* sub = r->tasks[i]; + struct ncclNetSocketTask* sub = r->tasks[i]; if (sub->result != ncclSuccess) return sub->result; if (sub->offset == sub->size) nCompleted++; } @@ -534,7 +514,7 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { *done = 1; r->used = 0; for (int i=0; inSubs; i++) { - struct ncclSocketTask* sub = r->tasks[i]; + struct ncclNetSocketTask* sub = r->tasks[i]; sub->used = 0; } } @@ -552,43 +532,45 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { return ncclSuccess; } -ncclResult_t ncclSocketRegMr(void* comm, void* data, int size, int type, void** mhandle) { +ncclResult_t ncclNetSocketRegMr(void* comm, void* data, int size, int type, void** mhandle) { return (type != NCCL_PTR_HOST) ? 
ncclInternalError : ncclSuccess; } -ncclResult_t ncclSocketDeregMr(void* comm, void* mhandle) { return ncclSuccess; } +ncclResult_t ncclNetSocketDeregMr(void* comm, void* mhandle) { return ncclSuccess; } -ncclResult_t ncclSocketIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { - struct ncclSocketComm* comm = (struct ncclSocketComm*)sendComm; - NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_SEND, data, size, (struct ncclSocketRequest**)request)); +ncclResult_t ncclNetSocketIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { + struct ncclNetSocketComm* comm = (struct ncclNetSocketComm*)sendComm; + NCCLCHECK(ncclNetSocketGetRequest(comm, NCCL_SOCKET_SEND, data, size, (struct ncclNetSocketRequest**)request)); return ncclSuccess; } -ncclResult_t ncclSocketIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { - struct ncclSocketComm* comm = (struct ncclSocketComm*)recvComm; +ncclResult_t ncclNetSocketIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { + struct ncclNetSocketComm* comm = (struct ncclNetSocketComm*)recvComm; if (n != 1) return ncclInternalError; - NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_RECV, data[0], sizes[0], (struct ncclSocketRequest**)request)); + NCCLCHECK(ncclNetSocketGetRequest(comm, NCCL_SOCKET_RECV, data[0], sizes[0], (struct ncclNetSocketRequest**)request)); return ncclSuccess; } -ncclResult_t ncclSocketIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) { +ncclResult_t ncclNetSocketIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) { // We don't support CUDA pointers, so we don't need a flush operation return ncclInternalError; } -ncclResult_t ncclSocketCloseListen(void* opaqueComm) { - struct ncclSocketListenComm* comm = (struct ncclSocketListenComm*)opaqueComm; +ncclResult_t 
ncclNetSocketCloseListen(void* opaqueComm) { + struct ncclNetSocketListenComm* comm = (struct ncclNetSocketListenComm*)opaqueComm; if (comm) { - if (comm->sock.fd != -1) close(comm->sock.fd); + int ready; + NCCLCHECK(ncclSocketReady(&comm->sock, &ready)); + if (ready) NCCLCHECK(ncclSocketClose(&comm->sock)); free(comm); } return ncclSuccess; } -ncclResult_t ncclSocketClose(void* opaqueComm) { - struct ncclSocketComm* comm = (struct ncclSocketComm*)opaqueComm; +ncclResult_t ncclNetSocketClose(void* opaqueComm) { + struct ncclNetSocketComm* comm = (struct ncclNetSocketComm*)opaqueComm; if (comm) { for (int i=0; inThreads; i++) { - struct ncclSocketThreadResources* res = comm->threadResources+i; + struct ncclNetSocketThreadResources* res = comm->threadResources+i; if (comm->helperThread[i]) { pthread_mutex_lock(&res->threadLock); res->stop = 1; @@ -598,9 +580,12 @@ ncclResult_t ncclSocketClose(void* opaqueComm) { } free(res->threadTaskQueue.tasks); } - if (comm->ctrlSock.fd != -1) close(comm->ctrlSock.fd); + int ready; + NCCLCHECK(ncclSocketReady(&comm->ctrlSock, &ready)); + if (ready) NCCLCHECK(ncclSocketClose(&comm->ctrlSock)); for (int i=0; inSocks; i++) { - if (comm->socks[i].fd != -1) close(comm->socks[i].fd); + NCCLCHECK(ncclSocketReady(&comm->socks[i], &ready)); + if (ready) NCCLCHECK(ncclSocketClose(&comm->socks[i])); } free(comm); } @@ -609,20 +594,20 @@ ncclResult_t ncclSocketClose(void* opaqueComm) { ncclNet_t ncclNetSocket = { "Socket", - ncclSocketInit, - ncclSocketDevices, - ncclSocketGetProperties, - ncclSocketListen, - ncclSocketConnect, - ncclSocketAccept, - ncclSocketRegMr, + ncclNetSocketInit, + ncclNetSocketDevices, + ncclNetSocketGetProperties, + ncclNetSocketListen, + ncclNetSocketConnect, + ncclNetSocketAccept, + ncclNetSocketRegMr, NULL, // No DMA-BUF support - ncclSocketDeregMr, - ncclSocketIsend, - ncclSocketIrecv, - ncclSocketIflush, - ncclSocketTest, - ncclSocketClose, - ncclSocketClose, - ncclSocketCloseListen + ncclNetSocketDeregMr, + 
ncclNetSocketIsend, + ncclNetSocketIrecv, + ncclNetSocketIflush, + ncclNetSocketTest, + ncclNetSocketClose, + ncclNetSocketClose, + ncclNetSocketCloseListen }; diff --git a/src/transport/p2p.cc b/src/transport/p2p.cc index b0bad4a..e7a4fd0 100644 --- a/src/transport/p2p.cc +++ b/src/transport/p2p.cc @@ -34,6 +34,7 @@ struct p2pProxyInfo { struct p2pShm* devShm; char shmName[7]; int shmSize; + ncclShmHandle_t handle; // Intermediate step for sender struct ncclRecvMem* ceRecvMem; @@ -63,6 +64,7 @@ struct p2pRecvResources { struct p2pShm* shm; struct p2pShm* devShm; int shmSize; + ncclShmHandle_t handle; }; #include @@ -351,9 +353,7 @@ ncclResult_t p2pRecvConnect(struct ncclComm* comm, struct ncclConnect* connectIn sprintf(shmPath, "/dev/shm/nccl-%s", info->shmName); TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmPath, info->shmSize); resources->shmSize = info->shmSize; - NCCLCHECK(ncclShmOpen(shmPath, info->shmSize, (void**)&resources->shm, (void**)&resources->devShm, 0)); - // Remove the file to ensure proper clean-up - NCCLCHECK(ncclShmUnlink(shmPath)); + NCCLCHECK(ncclShmOpen(shmPath, info->shmSize, (void**)&resources->shm, (void**)&resources->devShm, -1, &resources->handle)); recv->conn.tail = &resources->devShm->recvMem.tail; recv->conn.head = &resources->devShm->sendMem.head; @@ -396,7 +396,7 @@ ncclResult_t p2pRecvFree(struct ncclConnector* recv) { if (resources->sendMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->sendMemIpc)); if (resources->recvMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->recvMemIpc)); if (useMemcpy) { - NCCLCHECK(ncclShmClose(resources->shm, resources->devShm, resources->shmSize)); + NCCLCHECK(ncclShmClose(resources->handle)); } free(resources); } @@ -414,7 +414,7 @@ static ncclResult_t p2pSendProxySetup(struct ncclProxyConnection* connection, st char shmPath[PATH_MAX]; shmPath[0] = '\0'; proxyInfo->shmSize = sizeof(struct ncclSendMem) + sizeof(struct ncclRecvMem); - NCCLCHECK(ncclShmOpen(shmPath, proxyInfo->shmSize, 
(void**)&proxyInfo->shm, (void**)&proxyInfo->devShm, 1)); + NCCLCHECK(ncclShmOpen(shmPath, proxyInfo->shmSize, (void**)&proxyInfo->shm, (void**)&proxyInfo->devShm, 1, &proxyInfo->handle)); TRACE(NCCL_SHM,"Opened shmName %s shmSize %d", shmPath, proxyInfo->shmSize); memcpy(proxyInfo->shmName, shmPath+sizeof("/dev/shm/nccl-")-1, sizeof(proxyInfo->shmName)); @@ -477,7 +477,7 @@ static ncclResult_t p2pSendProxyFree(struct ncclProxyConnection* connection, str if (useMemcpy) { struct p2pProxyInfo* proxyInfo = (struct p2pProxyInfo*)connection->transportResources; if (proxyInfo) { - NCCLCHECK(ncclShmClose(proxyInfo->shm, proxyInfo->devShm, proxyInfo->shmSize)); + NCCLCHECK(ncclShmClose(proxyInfo->handle)); NCCLCHECK(ncclCudaHostFree(proxyInfo->ceRecvMem)); CUDACHECK(cudaFree(proxyInfo->ceDevBuff)); CUDACHECK(cudaStreamDestroy(proxyInfo->stream)); diff --git a/src/transport/shm.cc b/src/transport/shm.cc index 740bd2a..4bce480 100644 --- a/src/transport/shm.cc +++ b/src/transport/shm.cc @@ -17,18 +17,22 @@ struct shmSendResources { int remShmSize; struct ncclRecvMem* remHostMem; struct ncclRecvMem* devRemHostMem; + ncclShmHandle_t remHandle; int shmSize; struct ncclSendMem* hostMem; struct ncclSendMem* devHostMem; + ncclShmHandle_t hostHandle; }; struct shmRecvResources { int remShmSize; struct ncclSendMem* remHostMem; struct ncclSendMem* devRemHostMem; + ncclShmHandle_t remHandle; int shmSize; struct ncclRecvMem* hostMem; struct ncclRecvMem* devHostMem; + ncclShmHandle_t hostHandle; }; #define SHM_SEND_SIDE 1 @@ -84,7 +88,7 @@ static ncclResult_t shmSendSetup(struct ncclComm* comm, struct ncclTopoGraph* gr for (int p=0; pcomm->buffSizes[p]; } info->shmSize = resources->shmSize = shmSize; - NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1)); + NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1, &resources->hostHandle)); TRACE(NCCL_SHM,"Opened 
shmName %s shmSize %d", shmPath, info->shmSize); memcpy(info->shmName, shmPath+sizeof("/dev/shm/nccl-")-1, sizeof(info->shmName)); @@ -107,7 +111,7 @@ static ncclResult_t shmRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* gr for (int p=0; pcomm->buffSizes[p]; } info->shmSize = resources->shmSize = shmSize; - NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1)); + NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1, &resources->hostHandle)); TRACE(NCCL_SHM,"Opened shmName %s shmSize %d", shmPath, info->shmSize); memcpy(info->shmName, shmPath+sizeof("/dev/shm/nccl-")-1, sizeof(info->shmName)); @@ -137,9 +141,7 @@ static ncclResult_t shmSendConnect(struct ncclComm* comm, struct ncclConnect* co sprintf(shmPath, "/dev/shm/nccl-%s", info->shmName); resources->remShmSize = info->shmSize; TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmPath, info->shmSize); - NCCLCHECK(ncclShmOpen(shmPath, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, 0)); - // Remove the file to ensure proper clean-up - NCCLCHECK(ncclShmUnlink(shmPath)); + NCCLCHECK(ncclShmOpen(shmPath, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, -1, &resources->remHandle)); char* buff = shmLocality == SHM_SEND_SIDE ? (char*)(resources->devHostMem+1) : (char*)(resources->devRemHostMem+1); for (int p=0; pshmName); resources->remShmSize = info->shmSize; TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmPath, info->shmSize); - NCCLCHECK(ncclShmOpen(shmPath, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, 0)); - NCCLCHECK(ncclShmUnlink(shmPath)); + NCCLCHECK(ncclShmOpen(shmPath, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, -1, &resources->remHandle)); char* buff = shmLocality == SHM_RECV_SIDE ? 
(char*)(resources->devHostMem+1) : (char*)(resources->devRemHostMem+1); for (int p=0; ptransportResources; if (resources) { - NCCLCHECK(ncclShmClose(resources->hostMem, resources->devHostMem, resources->shmSize)); - NCCLCHECK(ncclShmClose(resources->remHostMem, resources->devRemHostMem, resources->remShmSize)); + NCCLCHECK(ncclShmClose(resources->hostHandle)); + NCCLCHECK(ncclShmClose(resources->remHandle)); free(resources); } return ncclSuccess; @@ -206,8 +207,8 @@ static ncclResult_t shmSendFree(struct ncclConnector* send) { static ncclResult_t shmRecvFree(struct ncclConnector* recv) { struct shmRecvResources* resources = (struct shmRecvResources*)recv->transportResources; if (resources) { - NCCLCHECK(ncclShmClose(resources->hostMem, resources->devHostMem, resources->shmSize)); - NCCLCHECK(ncclShmClose(resources->remHostMem, resources->devRemHostMem, resources->remShmSize)); + NCCLCHECK(ncclShmClose(resources->hostHandle)); + NCCLCHECK(ncclShmClose(resources->remHandle)); free(resources); } return ncclSuccess;