From 28189e2df85885b78bf20384cefc3cedab023bf5 Mon Sep 17 00:00:00 2001 From: Sylvain Jeaugey Date: Tue, 29 Nov 2022 04:27:46 -0800 Subject: [PATCH] 2.16.2-1 Add support for CUDA 12.0, drop Kepler (sm_35). Support for H100 features. Make socket code more robust and protected. Solves #555. Improve performance on large CUDA graphs, reducing dependencies. Reduce inter-socket bandwidth on AMD CPUs to favor better paths. Various fixes to ncclCommAbort. Make service thread polling resistant to EINTR. Compile with profiling API by default. Extend NVTX instrumentation with call arguments. --- makefiles/common.mk | 16 +- makefiles/version.mk | 4 +- src/Makefile | 2 +- src/bootstrap.cc | 197 +++-- src/collectives/all_gather.cc | 8 +- src/collectives/all_reduce.cc | 15 +- src/collectives/broadcast.cc | 12 +- src/collectives/device/common.h | 3 + src/collectives/device/prims_simple.h | 14 +- src/collectives/reduce.cc | 16 +- src/collectives/reduce_scatter.cc | 14 +- src/collectives/sendrecv.cc | 17 +- src/enqueue.cc | 76 +- src/graph/paths.cc | 12 + src/graph/topo.cc | 3 + src/graph/topo.h | 1 + src/include/bootstrap.h | 12 +- src/include/comm.h | 6 +- src/include/graph.h | 1 + src/include/nvtx.h | 71 ++ src/include/nvtx3.hpp | 17 + src/include/nvtx3/nvToolsExtPayload.h | 776 ++++++++++++++++++ src/include/nvtx3/nvtxExtDetail/nvtxExtImpl.h | 93 +++ .../nvtxExtDetail/nvtxExtImplPayload_v1.h | 85 ++ src/include/nvtx3/nvtxExtDetail/nvtxExtInit.h | 363 ++++++++ .../nvtxExtDetail/nvtxExtPayloadTypeInfo.h | 128 +++ .../nvtx3/nvtxExtDetail/nvtxExtTypes.h | 44 + src/include/proxy.h | 15 +- src/include/shm.h | 8 +- src/include/socket.h | 49 +- src/include/strongstream.h | 9 +- src/include/utils.h | 14 + src/init.cc | 648 ++++++++------- src/init_nvtx.cc | 26 + src/misc/cudawrap.cc | 2 +- src/misc/shmutils.cc | 193 +++-- src/misc/socket.cc | 597 ++++++++++---- src/misc/strongstream.cc | 15 +- src/nccl.h.in | 2 +- src/proxy.cc | 197 +++-- src/transport.cc | 41 +- src/transport/net.cc | 102 +-- src/transport/net_ib.cc | 86 +- src/transport/net_socket.cc | 317 ++++--- src/transport/p2p.cc | 12 +- src/transport/shm.cc | 23 +- 46 files changed, 3325 insertions(+), 1037 deletions(-) create mode 100644 src/include/nvtx3/nvToolsExtPayload.h create mode 100644 src/include/nvtx3/nvtxExtDetail/nvtxExtImpl.h create mode 100644 src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h create mode 100644 src/include/nvtx3/nvtxExtDetail/nvtxExtInit.h create mode 100644 src/include/nvtx3/nvtxExtDetail/nvtxExtPayloadTypeInfo.h create mode 100644 src/include/nvtx3/nvtxExtDetail/nvtxExtTypes.h create mode 100644 src/init_nvtx.cc diff --git a/makefiles/common.mk b/makefiles/common.mk index 0c0d04a..35d1826 100644 --- a/makefiles/common.mk +++ b/makefiles/common.mk @@ -10,7 +10,7 @@ VERBOSE ?= 0 KEEP ?= 0 DEBUG ?= 0 TRACE ?= 0 -PROFAPI ?= 0 +PROFAPI ?= 1 NVTX ?= 1 NVCC = $(CUDA_HOME)/bin/nvcc @@ -25,22 +25,26 @@ CUDA_MINOR = $(shell echo $(CUDA_VERSION) | cut -d "." -f 2) # You should define NVCC_GENCODE in your environment to the minimal set # of archs to reduce compile time. -CUDA8_GENCODE = -gencode=arch=compute_35,code=sm_35 \ - -gencode=arch=compute_50,code=sm_50 \ +CUDA8_GENCODE = -gencode=arch=compute_50,code=sm_50 \ -gencode=arch=compute_60,code=sm_60 \ -gencode=arch=compute_61,code=sm_61 +ifeq ($(shell test "0$(CUDA_MAJOR)" -lt 12; echo $$?),0) +# SM35 is deprecated from CUDA12.0 onwards +CUDA8_GENCODE += -gencode=arch=compute_35,code=sm_35 +endif CUDA9_GENCODE = -gencode=arch=compute_70,code=sm_70 CUDA11_GENCODE = -gencode=arch=compute_80,code=sm_80 -CUDA11_8_GENCODE = -gencode=arch=compute_90,code=sm_90 +CUDA12_GENCODE = -gencode=arch=compute_90,code=sm_90 CUDA8_PTX = -gencode=arch=compute_61,code=compute_61 CUDA9_PTX = -gencode=arch=compute_70,code=compute_70 CUDA11_PTX = -gencode=arch=compute_80,code=compute_80 -CUDA11_8_PTX = -gencode=arch=compute_90,code=compute_90 +CUDA12_PTX = -gencode=arch=compute_90,code=compute_90 + ifeq ($(shell test "0$(CUDA_MAJOR)" -eq 11 -a "0$(CUDA_MINOR)" -ge 8 -o "0$(CUDA_MAJOR)" -gt 11; echo $$?),0) # Include Hopper support if we're using CUDA11.8 or above - NVCC_GENCODE ?= $(CUDA8_GENCODE) $(CUDA9_GENCODE) $(CUDA11_GENCODE) $(CUDA11_8_GENCODE) $(CUDA11_8_PTX) + NVCC_GENCODE ?= $(CUDA8_GENCODE) $(CUDA9_GENCODE) $(CUDA11_GENCODE) $(CUDA12_GENCODE) $(CUDA12_PTX) else ifeq ($(shell test "0$(CUDA_MAJOR)" -ge 11; echo $$?),0) NVCC_GENCODE ?= $(CUDA8_GENCODE) $(CUDA9_GENCODE) $(CUDA11_GENCODE) $(CUDA11_PTX) # Include Volta support if we're using CUDA9 or above diff --git a/makefiles/version.mk b/makefiles/version.mk index be64e9a..f4a0d8d 100644 --- a/makefiles/version.mk +++ b/makefiles/version.mk @@ -1,6 +1,6 @@ ##### version NCCL_MAJOR := 2 -NCCL_MINOR := 15 -NCCL_PATCH := 5 +NCCL_MINOR := 16 +NCCL_PATCH := 2 NCCL_SUFFIX := PKG_REVISION := 1 diff --git a/src/Makefile b/src/Makefile index 1539e14..4753018 100644 --- a/src/Makefile +++ b/src/Makefile @@ -9,7 +9,7 @@ include ../makefiles/version.mk ##### src files INCEXPORTS := nccl.h nccl_net.h -LIBSRCFILES := init.cc channel.cc bootstrap.cc transport.cc enqueue.cc group.cc debug.cc proxy.cc net.cc \ +LIBSRCFILES := init.cc init_nvtx.cc channel.cc bootstrap.cc transport.cc enqueue.cc group.cc debug.cc proxy.cc net.cc \ misc/cudawrap.cc misc/nvmlwrap.cc misc/ibvwrap.cc misc/gdrwrap.cc \ misc/utils.cc misc/argcheck.cc misc/socket.cc misc/shmutils.cc misc/profiler.cc misc/param.cc misc/strongstream.cc \ transport/p2p.cc transport/shm.cc transport/net.cc transport/net_socket.cc transport/net_ib.cc transport/coll_net.cc \ diff --git a/src/bootstrap.cc b/src/bootstrap.cc index b7e0576..c348b3e 100644 --- a/src/bootstrap.cc +++ b/src/bootstrap.cc @@ -13,6 +13,11 @@ #include #include "proxy.h" +struct bootstrapRootArgs { + struct ncclSocket* listenSock; + uint64_t magic; +}; + /* Init functions */ static char bootstrapNetIfName[MAX_IF_NAME_SIZE+1]; static union ncclSocketAddress bootstrapNetIfAddr; @@ -26,7 +31,7 @@ ncclResult_t bootstrapNetInit() { char* env = getenv("NCCL_COMM_ID"); if (env) { union ncclSocketAddress remoteAddr; - if (ncclGetSocketAddrFromString(&remoteAddr, env) != ncclSuccess) { + if (ncclSocketGetAddrFromString(&remoteAddr, env) != ncclSuccess) { WARN("Invalid NCCL_COMM_ID, please use format: : or []: or :"); return ncclInvalidArgument; } @@ -89,8 +94,10 @@ static ncclResult_t setFilesLimit() { return ncclSuccess; } -static void *bootstrapRoot(void* args) { - struct ncclSocket* listenSock = (struct ncclSocket*)args; +static void *bootstrapRoot(void* rargs) { + struct bootstrapRootArgs* args = (struct bootstrapRootArgs*)rargs; + struct ncclSocket* listenSock = args->listenSock; + uint64_t magic = args->magic; ncclResult_t res = ncclSuccess; int nranks = 0, c = 0; struct extInfo info; @@ -104,11 +111,10 @@ static void *bootstrapRoot(void* args) { /* Receive addresses from all ranks */ do { struct ncclSocket sock; - /* bootstrap root thread always uses blocking ncclSocketAccept. */ - NCCLCHECKGOTO(ncclSocketInit(&sock, NULL, NULL, 0), res, out); + NCCLCHECKGOTO(ncclSocketInit(&sock), res, out); NCCLCHECKGOTO(ncclSocketAccept(&sock, listenSock), res, out); NCCLCHECKGOTO(bootstrapNetRecv(&sock, &info, sizeof(info)), res, out); - close(sock.fd); + NCCLCHECKGOTO(ncclSocketClose(&sock), res, out); if (c == 0) { nranks = info.nranks; @@ -139,54 +145,60 @@ static void *bootstrapRoot(void* args) { for (int r=0; rfd); - free(listenSock); + if (listenSock != NULL) { + ncclSocketClose(listenSock); + free(listenSock); + } if (rankAddresses) free(rankAddresses); if (rankAddressesRoot) free(rankAddressesRoot); if (zero) free(zero); + free(rargs); TRACE(NCCL_INIT, "DONE"); return NULL; } -ncclResult_t bootstrapCreateRoot(ncclUniqueId* id, bool idFromEnv) { +ncclResult_t bootstrapCreateRoot(struct ncclBootstrapHandle* handle, bool idFromEnv) { struct ncclSocket* listenSock; - NCCLCHECK(ncclCalloc(&listenSock, 1)); - memcpy(&listenSock->addr, id, sizeof(union ncclSocketAddress)); - NCCLCHECK(ncclSocketListen(listenSock)); - memcpy(id, &listenSock->addr, sizeof(union ncclSocketAddress)); + struct bootstrapRootArgs* args; pthread_t thread; - pthread_create(&thread, NULL, bootstrapRoot, (void*)listenSock); + + NCCLCHECK(ncclCalloc(&listenSock, 1)); + NCCLCHECK(ncclSocketInit(listenSock, &handle->addr, handle->magic, ncclSocketTypeBootstrap, NULL, 0)); + NCCLCHECK(ncclSocketListen(listenSock)); + NCCLCHECK(ncclSocketGetAddr(listenSock, &handle->addr)); + + NCCLCHECK(ncclCalloc(&args, 1)); + args->listenSock = listenSock; + args->magic = handle->magic; + NEQCHECK(pthread_create(&thread, NULL, bootstrapRoot, (void*)args), 0); ncclSetThreadName(thread, "NCCL BootstrapR"); - pthread_detach(thread); // will not be pthread_join()'d + NEQCHECK(pthread_detach(thread), 0); // will not be pthread_join()'d return ncclSuccess; } -ncclResult_t bootstrapGetUniqueId(ncclUniqueId* id) { - static_assert(sizeof(union ncclSocketAddress) < sizeof(ncclUniqueId), "NetId does not fit inside ncclUniqueId"); - memset(id, 0, sizeof(ncclUniqueId)); - union ncclSocketAddress* connectAddr = (union ncclSocketAddress*) id; +ncclResult_t bootstrapGetUniqueId(struct ncclBootstrapHandle* handle) { + memset(handle, 0, sizeof(ncclBootstrapHandle)); + NCCLCHECK(getRandomData(&handle->magic, sizeof(handle->magic))); char* env = getenv("NCCL_COMM_ID"); if (env) { INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", env); - if (ncclGetSocketAddrFromString(connectAddr, env) != ncclSuccess) { + if (ncclSocketGetAddrFromString(&handle->addr, env) != ncclSuccess) { WARN("Invalid NCCL_COMM_ID, please use format: : or []: or :"); return ncclInvalidArgument; } } else { - memcpy(id, &bootstrapNetIfAddr, sizeof(union ncclSocketAddress)); - NCCLCHECK(bootstrapCreateRoot(id, false)); + memcpy(&handle->addr, &bootstrapNetIfAddr, sizeof(union ncclSocketAddress)); + NCCLCHECK(bootstrapCreateRoot(handle, false)); } return ncclSuccess; @@ -209,38 +221,39 @@ struct bootstrapState { int cudaDev; int rank; int nranks; + uint64_t magic; volatile uint32_t *abortFlag; }; -ncclResult_t bootstrapInit(ncclUniqueId * id, struct ncclComm* comm) { +ncclResult_t bootstrapInit(struct ncclBootstrapHandle* handle, struct ncclComm* comm) { int rank = comm->rank; int nranks = comm->nRanks; struct bootstrapState* state; + struct ncclSocket* proxySocket; + ncclSocketAddress nextAddr; + struct ncclSocket sock, listenSockRoot; + struct extInfo info = { 0 }; + NCCLCHECK(ncclCalloc(&state, 1)); state->rank = rank; state->nranks = nranks; state->abortFlag = comm->abortFlag; comm->bootstrap = state; + comm->magic = state->magic = handle->magic; TRACE(NCCL_INIT, "rank %d nranks %d", rank, nranks); - struct extInfo info = { 0 }; info.rank = rank; info.nranks = nranks; - struct ncclSocket sock, listenSockRoot; - - NCCLCHECK(ncclSocketInit(&sock, (union ncclSocketAddress*) id, comm->abortFlag, 0)); - NCCLCHECK(ncclSocketInit(&listenSockRoot, &bootstrapNetIfAddr, comm->abortFlag, 0)); - NCCLCHECK(ncclSocketInit(&state->listenSock, &bootstrapNetIfAddr, comm->abortFlag, 0)); - NCCLCHECK(ncclSocketInit(&state->ringSendSocket, NULL, comm->abortFlag, 0)); - NCCLCHECK(ncclSocketInit(&state->ringRecvSocket, NULL, comm->abortFlag, 0)); // Create socket for other ranks to contact me + NCCLCHECK(ncclSocketInit(&state->listenSock, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); NCCLCHECK(ncclSocketListen(&state->listenSock)); - memcpy(&info.extAddressListen, &state->listenSock.addr, sizeof(union ncclSocketAddress)); + NCCLCHECK(ncclSocketGetAddr(&state->listenSock, &info.extAddressListen)); // Create socket for root to contact me + NCCLCHECK(ncclSocketInit(&listenSockRoot, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); NCCLCHECK(ncclSocketListen(&listenSockRoot)); - memcpy(&info.extAddressListenRoot, &listenSockRoot.addr, sizeof(union ncclSocketAddress)); + NCCLCHECK(ncclSocketGetAddr(&listenSockRoot, &info.extAddressListenRoot)); // stagger connection times to avoid an overload of the root if (nranks > 128) { @@ -253,32 +266,37 @@ ncclResult_t bootstrapInit(ncclUniqueId * id, struct ncclComm* comm) { } // send info on my listening socket to root + NCCLCHECK(ncclSocketInit(&sock, &handle->addr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); NCCLCHECK(ncclSocketConnect(&sock)); NCCLCHECK(bootstrapNetSend(&sock, &info, sizeof(info))); - close(sock.fd); + NCCLCHECK(ncclSocketClose(&sock)); // get info on my "next" rank in the bootstrap ring from root + NCCLCHECK(ncclSocketInit(&sock)); NCCLCHECK(ncclSocketAccept(&sock, &listenSockRoot)); - NCCLCHECK(bootstrapNetRecv(&sock, &state->ringSendSocket.addr, sizeof(union ncclSocketAddress))); - close(sock.fd); - close(listenSockRoot.fd); + NCCLCHECK(bootstrapNetRecv(&sock, &nextAddr, sizeof(union ncclSocketAddress))); + NCCLCHECK(ncclSocketClose(&sock)); + NCCLCHECK(ncclSocketClose(&listenSockRoot)); + NCCLCHECK(ncclSocketInit(&state->ringSendSocket, &nextAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); NCCLCHECK(ncclSocketConnect(&state->ringSendSocket)); // Accept the connect request from the previous rank in the AllGather ring + NCCLCHECK(ncclSocketInit(&state->ringRecvSocket)); NCCLCHECK(ncclSocketAccept(&state->ringRecvSocket, &state->listenSock)); // AllGather all listen handlers NCCLCHECK(ncclCalloc(&state->peerCommAddresses, nranks)); - memcpy(state->peerCommAddresses+rank, &state->listenSock.addr, sizeof(union ncclSocketAddress)); + NCCLCHECK(ncclSocketGetAddr(&state->listenSock, state->peerCommAddresses+rank)); NCCLCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union ncclSocketAddress))); // Create the service proxy NCCLCHECK(ncclCalloc(&state->peerProxyAddresses, nranks)); - struct ncclSocket* proxySocket; + + // proxy is aborted through a message; don't set abortFlag NCCLCHECK(ncclCalloc(&proxySocket, 1)); - NCCLCHECK(ncclSocketInit(proxySocket, &bootstrapNetIfAddr, NULL, 0)); + NCCLCHECK(ncclSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeProxy, comm->abortFlag)); NCCLCHECK(ncclSocketListen(proxySocket)); - memcpy(state->peerProxyAddresses+rank, &proxySocket->addr, sizeof(union ncclSocketAddress)); + NCCLCHECK(ncclSocketGetAddr(proxySocket, state->peerProxyAddresses+rank)); NCCLCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union ncclSocketAddress))); NCCLCHECK(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses)); @@ -314,16 +332,21 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) { } ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size) { + ncclResult_t ret = ncclSuccess; struct bootstrapState* state = (struct bootstrapState*)commState; struct ncclSocket sock; - NCCLCHECK(ncclSocketInit(&sock, state->peerCommAddresses+peer, state->abortFlag, 1)); - NCCLCHECK(ncclSocketConnect(&sock)); - NCCLCHECK(bootstrapNetSend(&sock, &state->rank, sizeof(int))); - NCCLCHECK(bootstrapNetSend(&sock, &tag, sizeof(int))); - NCCLCHECK(bootstrapNetSend(&sock, data, size)); - close(sock.fd); - return ncclSuccess; + NCCLCHECKGOTO(ncclSocketInit(&sock, state->peerCommAddresses+peer, state->magic, ncclSocketTypeBootstrap, state->abortFlag), ret, fail); + NCCLCHECKGOTO(ncclSocketConnect(&sock), ret, fail); + NCCLCHECKGOTO(bootstrapNetSend(&sock, &state->rank, sizeof(int)), ret, fail); + NCCLCHECKGOTO(bootstrapNetSend(&sock, &tag, sizeof(int)), ret, fail); + NCCLCHECKGOTO(bootstrapNetSend(&sock, data, size), ret, fail); + +exit: + NCCLCHECK(ncclSocketClose(&sock)); + return ret; +fail: + goto exit; } ncclResult_t bootstrapBarrier(void* commState, int *ranks, int rank, int nranks, int tag) { @@ -382,9 +405,10 @@ ncclResult_t unexpectedEnqueue(struct bootstrapState* state, int peer, int tag, return ncclSuccess; } -ncclResult_t unexpectedDequeue(struct bootstrapState* state, int peer, int tag, struct ncclSocket* sock) { +ncclResult_t unexpectedDequeue(struct bootstrapState* state, int peer, int tag, struct ncclSocket* sock, int* found) { struct unexConn* elem = state->unexpectedConnections; struct unexConn* prev = NULL; + *found = 0; while (elem) { if (elem->peer == peer && elem->tag == tag) { if (prev == NULL) { @@ -394,54 +418,75 @@ ncclResult_t unexpectedDequeue(struct bootstrapState* state, int peer, int tag, } memcpy(sock, &elem->sock, sizeof(struct ncclSocket)); free(elem); + *found = 1; return ncclSuccess; } prev = elem; elem = elem->next; } - sock->fd = -1; return ncclSuccess; } +static void unexpectedFree(struct bootstrapState* state) { + struct unexConn* elem = state->unexpectedConnections; + struct unexConn* prev = NULL; + + while (elem) { + prev = elem; + elem = elem->next; + free(prev); + } + return; +} + // We can't know who we'll receive from, so we need to receive everything at once ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size) { + ncclResult_t ret = ncclSuccess; struct bootstrapState* state = (struct bootstrapState*)commState; struct ncclSocket sock; + int newPeer, newTag; // Search unexpected connections first - NCCLCHECK(unexpectedDequeue(state, peer, tag, &sock)); - if (sock.fd != -1) { - NCCLCHECK(bootstrapNetRecv(&sock, ((char*)data), size)); - close(sock.fd); - return ncclSuccess; + int found; + NCCLCHECK(unexpectedDequeue(state, peer, tag, &sock, &found)); + if (found) { + NCCLCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, fail); + goto exit; } // Then look for new connections - NCCLCHECK(ncclSocketInit(&sock, NULL, state->listenSock.abortFlag, 0)); while (1) { - NCCLCHECK(ncclSocketAccept(&sock, &state->listenSock)); - int newPeer, newTag; - NCCLCHECK(bootstrapNetRecv(&sock, &newPeer, sizeof(int))); - NCCLCHECK(bootstrapNetRecv(&sock, &newTag, sizeof(int))); + NCCLCHECKGOTO(ncclSocketInit(&sock), ret, fail); + NCCLCHECKGOTO(ncclSocketAccept(&sock, &state->listenSock), ret, fail); + NCCLCHECKGOTO(bootstrapNetRecv(&sock, &newPeer, sizeof(int)), ret, fail); + NCCLCHECKGOTO(bootstrapNetRecv(&sock, &newTag, sizeof(int)), ret, fail); if (newPeer == peer && newTag == tag) { - NCCLCHECK(bootstrapNetRecv(&sock, ((char*)data), size)); - close(sock.fd); - return ncclSuccess; + NCCLCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, fail); + goto exit; } // Unexpected connection. Save for later. - NCCLCHECK(unexpectedEnqueue(state, newPeer, newTag, &sock)); + NCCLCHECKGOTO(unexpectedEnqueue(state, newPeer, newTag, &sock), ret, fail); } +exit: + NCCLCHECK(ncclSocketClose(&sock)); + return ret; +fail: + goto exit; } ncclResult_t bootstrapClose(void* commState) { struct bootstrapState* state = (struct bootstrapState*)commState; if (state->unexpectedConnections != NULL) { - WARN("Unexpected connections are not empty"); - return ncclInternalError; + unexpectedFree(state); + if (*state->abortFlag == 0) { + WARN("Unexpected connections are not empty"); + return ncclInternalError; + } } - if (state->listenSock.fd >= 0) close(state->listenSock.fd); - if (state->ringSendSocket.fd >= 0) close(state->ringSendSocket.fd); - if (state->ringRecvSocket.fd >= 0) close(state->ringRecvSocket.fd); + + NCCLCHECK(ncclSocketClose(&state->listenSock)); + NCCLCHECK(ncclSocketClose(&state->ringSendSocket)); + NCCLCHECK(ncclSocketClose(&state->ringRecvSocket)); free(state->peerCommAddresses); free(state); @@ -452,9 +497,9 @@ ncclResult_t bootstrapClose(void* commState) { ncclResult_t bootstrapAbort(void* commState) { struct bootstrapState* state = (struct bootstrapState*)commState; if (commState == NULL) return ncclSuccess; - if (state->listenSock.fd) close(state->listenSock.fd); - if (state->ringSendSocket.fd) close(state->ringSendSocket.fd); - if (state->ringRecvSocket.fd) close(state->ringRecvSocket.fd); + NCCLCHECK(ncclSocketClose(&state->listenSock)); + NCCLCHECK(ncclSocketClose(&state->ringSendSocket)); + NCCLCHECK(ncclSocketClose(&state->ringRecvSocket)); free(state->peerCommAddresses); free(state->peerProxyAddresses); free(state); diff --git a/src/collectives/all_gather.cc b/src/collectives/all_gather.cc index 266fd5a..97ec981 100644 --- a/src/collectives/all_gather.cc +++ b/src/collectives/all_gather.cc @@ -11,7 +11,13 @@ NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, size ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + // Just pass the size of one message and not the total bytes sent/received. + constexpr nvtxPayloadSchemaEntry_t AllGatherSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes]"} + }; + size_t msgsize = sendcount * ncclTypeSize(datatype); + NVTX3_FUNC_WITH_PARAMS(AllGather, AllGatherSchema, msgsize) + struct ncclInfo info = { ncclFuncAllGather, "AllGather", sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */ ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS }; diff --git a/src/collectives/all_reduce.cc b/src/collectives/all_reduce.cc index b67f3be..8ac61a2 100644 --- a/src/collectives/all_reduce.cc +++ b/src/collectives/all_reduce.cc @@ -5,12 +5,25 @@ ************************************************************************/ #include "enqueue.h" +#include "nccl.h" NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream); ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + struct NvtxParamsAllReduce { + size_t bytes; + ncclRedOp_t op; + }; + // Just pass the size of one message and not the total bytes sent/received. + static constexpr nvtxPayloadSchemaEntry_t AllReduceSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes]"}, + {0, NVTX_PAYLOAD_ENTRY_NCCL_REDOP, "Reduction operation", nullptr, 0, + offsetof(NvtxParamsAllReduce, op)} + }; + NvtxParamsAllReduce payload{count * ncclTypeSize(datatype), op}; + NVTX3_FUNC_WITH_PARAMS(AllReduce, AllReduceSchema, payload) + struct ncclInfo info = { ncclFuncAllReduce, "AllReduce", sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */ ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS }; diff --git a/src/collectives/broadcast.cc b/src/collectives/broadcast.cc index db0fb49..c73502e 100644 --- a/src/collectives/broadcast.cc +++ b/src/collectives/broadcast.cc @@ -11,7 +11,17 @@ NCCL_API(ncclResult_t, ncclBroadcast, const void* sendbuff, void* recvbuff, size ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root, ncclComm_t comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + struct NvtxParamsBroadcast { + size_t bytes; + int root; + }; + constexpr nvtxPayloadSchemaEntry_t BroadcastSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Bytes"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Root", nullptr, 0, offsetof(NvtxParamsBroadcast, root)} + }; + NvtxParamsBroadcast payload{count * ncclTypeSize(datatype), root}; + NVTX3_FUNC_WITH_PARAMS(Broadcast, BroadcastSchema, payload) + struct ncclInfo info = { ncclFuncBroadcast, "Broadcast", sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */ BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS }; diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 310938f..95cc990 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -37,6 +37,7 @@ struct ncclShmemData { }; uint64_t redOpArgs[NCCL_MAX_DIRECT_ARITY+1]; int channelId; + int aborted; alignas(16) struct ncclDevComm comm; alignas(16) struct ncclDevChannel channel; alignas(16) struct ncclWork work; @@ -135,6 +136,8 @@ __device__ void ncclKernel( } __syncthreads(); // publish ncclShmem.channelId int channelId = ncclShmem.channelId; + /* set abort flag to 0 */ + if (tid == 0) ncclShmem.aborted = 0; if (true) { void *dst, *src; diff --git a/src/collectives/device/prims_simple.h b/src/collectives/device/prims_simple.h index a727849..9d2d19a 100644 --- a/src/collectives/device/prims_simple.h +++ b/src/collectives/device/prims_simple.h @@ -62,7 +62,10 @@ class Primitives< inline __device__ bool checkAbort(int &spins) { spins++; if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) { - flags |= *ncclShmem.comm.abortFlag ? Aborted : 0; + if (*ncclShmem.comm.abortFlag) { + flags |= Aborted; + ncclShmem.aborted = 1; + } spins = 0; } return flags & Aborted; @@ -176,6 +179,9 @@ class Primitives< ncclShmem.groups[group].dsts[0] = userBuff + dstIx + offset; waitPeer(dstIx, remoteIx, offset, sliceSize); subBarrier(); + /* if user abort the kernel, we don't need to actually perform copy/reduce; just set size + * to 0 to avoid unnecessary workload. */ + size_t workSize = ncclShmem.aborted ? 0 : sliceSize; if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) { // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy if (Send) { @@ -184,7 +190,7 @@ class Primitives< (tid, nworkers, nullptr, false, 1, (T const**)ncclShmem.groups[group].srcs, fan.nsend(), (T**)ncclShmem.groups[group].dsts+1, - sliceSize); + workSize); } } else if (DirectSend && !DirectRecv && SrcBuf != Input && ncclShmem.groups[group].dsts[Dst] == nullptr) { // For broadcast in CollNet to do empty send @@ -192,7 +198,7 @@ class Primitives< (tid, nworkers, ncclShmem.redOpArgs, postOp, Recv, (T const**)ncclShmem.groups[group].srcs, Dst, (T**)ncclShmem.groups[group].dsts, - sliceSize); + workSize); } else { constexpr int PreOpN = SrcBuf != Input ? 0 : DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1; @@ -200,7 +206,7 @@ class Primitives< (tid, nworkers, ncclShmem.redOpArgs, postOp, Recv*fan.nrecv()+Src, (T const**)ncclShmem.groups[group].srcs, Send*fan.nsend()+Dst, (T**)ncclShmem.groups[group].dsts, - sliceSize); + workSize); } barrier(); // This barrier has a counterpart in following loop if (Send && (flags & RolePostSend) && index == 0) __threadfence_system(); diff --git a/src/collectives/reduce.cc b/src/collectives/reduce.cc index 86388df..6335516 100644 --- a/src/collectives/reduce.cc +++ b/src/collectives/reduce.cc @@ -6,12 +6,26 @@ #include "enqueue.h" #include "collectives.h" +#include "nccl.h" NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + struct NvtxParamsReduce { + size_t bytes; + int root; + ncclRedOp_t op; + }; + constexpr nvtxPayloadSchemaEntry_t ReduceSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes]"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Root", nullptr, 0, offsetof(NvtxParamsReduce, root)}, + {0, NVTX_PAYLOAD_ENTRY_NCCL_REDOP, "Reduction operation", nullptr, 0, + offsetof(NvtxParamsReduce, op)} + }; + NvtxParamsReduce payload{count * ncclTypeSize(datatype), root, op}; + NVTX3_FUNC_WITH_PARAMS(Reduce, ReduceSchema, payload) + struct ncclInfo info = { ncclFuncReduce, "Reduce", sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */ REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS }; diff --git a/src/collectives/reduce_scatter.cc b/src/collectives/reduce_scatter.cc index 57c67bf..5242545 100644 --- a/src/collectives/reduce_scatter.cc +++ b/src/collectives/reduce_scatter.cc @@ -6,12 +6,24 @@ #include "enqueue.h" #include "collectives.h" +#include "nccl.h" NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream); ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + struct NvtxParamsReduceScatter { + size_t bytes; + ncclRedOp_t op; + }; + constexpr nvtxPayloadSchemaEntry_t ReduceScatterSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes]"}, + {0, NVTX_PAYLOAD_ENTRY_NCCL_REDOP, "Reduction operation", nullptr, 0, + offsetof(NvtxParamsReduceScatter, op)} + }; + NvtxParamsReduceScatter payload{recvcount * ncclTypeSize(datatype), op}; + NVTX3_FUNC_WITH_PARAMS(ReduceScatter, ReduceScatterSchema, payload) + struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter", sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */ REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS }; diff --git a/src/collectives/sendrecv.cc b/src/collectives/sendrecv.cc index 0e9ca4f..9a81b0a 100644 --- a/src/collectives/sendrecv.cc +++ b/src/collectives/sendrecv.cc @@ -8,11 +8,22 @@ #include "collectives.h" #include "argcheck.h" // Need some checks here since we access comm +struct NvtxParamsSendRecv { + size_t bytes; + int peer; +}; +constexpr const nvtxPayloadSchemaEntry_t SendRecvSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Bytes"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Peer rank", nullptr, 0, offsetof(NvtxParamsSendRecv, peer)} +}; + NCCL_API(ncclResult_t, ncclSend, const void* sendbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer}; + NVTX3_FUNC_WITH_PARAMS(Send, SendRecvSchema, payload) + struct ncclInfo info = { ncclFuncSend, "Send", NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */ 1, 1 }; @@ -27,7 +38,9 @@ NCCL_API(ncclResult_t, ncclRecv, void* recvbuff, size_t count, ncclDataType_t da ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm, cudaStream_t stream) { - NVTX3_FUNC_RANGE_IN(nccl_domain); + NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer}; + NVTX3_FUNC_WITH_PARAMS(Recv, SendRecvSchema, payload) + struct ncclInfo info = { ncclFuncRecv, "Recv", NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */ 1, 1 }; diff --git a/src/enqueue.cc b/src/enqueue.cc index 8bac73f..0744e09 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -609,8 +609,7 @@ static ncclResult_t scheduleP2pTasksToPlan( // Compute how much to split operations // Natural step size matching buffer steps. - ssize_t stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS; - if (comm->nNodes > 1) stepSize = comm->p2pNetChunkSize; + ssize_t stepSize = comm->p2pChunkSize; // Try to use all channels int nChannelsMax = comm->p2pnChannelsPerPeer; int nChannelsMin = nChannelsMax; @@ -1008,7 +1007,13 @@ ncclResult_t ncclLaunchKernelBefore_NoUncapturedCuda(struct ncclComm* comm, stru #if CUDART_VERSION >= 11080 #define NCCL_MAX_CGA_CLUSTER_SIZE 8 -NCCL_PARAM(CGAClusterSize, "CGA_CLUSTER_SIZE", 0); +#define NCCL_CGA_CLUSTER_SIZE_SM90 4 +NCCL_PARAM(CGAClusterSize, "CGA_CLUSTER_SIZE", -2); +#endif + +#if CUDART_VERSION >= 12000 +// NCCL uses the "Remote" Mem Sync domain by default +NCCL_PARAM(MemSyncDomain, "MEM_SYNC_DOMAIN", cudaLaunchMemSyncDomainRemote); #endif ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan) { @@ -1022,22 +1027,25 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan #if CUDART_VERSION >= 11080 int driverVersion; NCCLCHECK(ncclCudaDriverVersion(&driverVersion)); - - unsigned int clusterSize = 0; - clusterSize = ncclParamCGAClusterSize(); - if (clusterSize > NCCL_MAX_CGA_CLUSTER_SIZE) { - static bool warned = false; - if (warned == false) { - WARN("NCCL_CGA_CLUSTER_SIZE value %d is too big. Limiting value to %d.", - clusterSize, NCCL_MAX_CGA_CLUSTER_SIZE); - warned = true; + if (driverVersion >= 11080) { + int compCap = comm->compCap; + unsigned int clusterSize = (compCap == 90) ? NCCL_CGA_CLUSTER_SIZE_SM90 : 0; + if (ncclParamCGAClusterSize() != -2) { + clusterSize = ncclParamCGAClusterSize(); + if (clusterSize > NCCL_MAX_CGA_CLUSTER_SIZE) { + static bool warned = false; + if (warned == false) { + WARN("NCCL_CGA_CLUSTER_SIZE value %d is too big. Limiting value to %d.", + clusterSize, NCCL_MAX_CGA_CLUSTER_SIZE); + warned = true; + } + clusterSize = NCCL_MAX_CGA_CLUSTER_SIZE; + } } - clusterSize = NCCL_MAX_CGA_CLUSTER_SIZE; - } - if (clusterSize && driverVersion >= 11080) { cudaLaunchConfig_t launchConfig = {0}; - cudaLaunchAttribute launchAttrs[2]; + cudaLaunchAttribute launchAttrs[3]; + int attrs = 0; /* Cooperative Group Array (CGA) * On sm90 and later we have an extra level of hierarchy where we * can group together several blocks within the Grid, called @@ -1048,17 +1056,25 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan * concurrently scheduled onto a group of SMs. * The maximum value is 8 and it must be divisible into the grid dimensions */ - // Grid dimension must be divisible by clusterSize - if (grid.x % clusterSize) clusterSize = 1; - launchAttrs[0].id = cudaLaunchAttributeClusterDimension; - launchAttrs[0].val.clusterDim = {clusterSize, 1, 1}; - launchAttrs[1].id = cudaLaunchAttributeClusterSchedulingPolicyPreference; - launchAttrs[1].val.clusterSchedulingPolicyPreference = cudaClusterSchedulingPolicySpread; - + if (clusterSize) { + // Grid dimension must be divisible by clusterSize + if (grid.x % clusterSize) clusterSize = 1; + launchAttrs[attrs].id = cudaLaunchAttributeClusterDimension; + launchAttrs[attrs++].val.clusterDim = {clusterSize, 1, 1}; + launchAttrs[attrs].id = cudaLaunchAttributeClusterSchedulingPolicyPreference; + launchAttrs[attrs++].val.clusterSchedulingPolicyPreference = cudaClusterSchedulingPolicySpread; + } + #if CUDART_VERSION >= 12000 + if (compCap >= 90 && driverVersion >= 12000) { + // Set the NCCL Mem Sync domain on CUDA 12.0 and later (sm90) + launchAttrs[attrs].id = cudaLaunchAttributeMemSyncDomain; + launchAttrs[attrs++].val.memSyncDomain = (cudaLaunchMemSyncDomain) ncclParamMemSyncDomain(); + } + #endif launchConfig.gridDim = grid; launchConfig.blockDim = block; launchConfig.attrs = launchAttrs; - launchConfig.numAttrs = sizeof(launchAttrs)/sizeof(launchAttrs[0]); + launchConfig.numAttrs = attrs; launchConfig.stream = launchStream; CUDACHECK(cudaLaunchKernelExC(&launchConfig, fn, args)); @@ -1093,14 +1109,18 @@ ncclResult_t ncclLaunchFinish(struct ncclComm* comm) { // back to us for reclaiming via callbackQueue. ncclIntruQueueConstruct(&comm->planQueue); cudaStream_t launchStream = tasks->streams->stream; // First user stream gets launch - // Create dependency for deviceStream on launchStream. - NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, &comm->deviceStream, launchStream), result, resume1); + // Create dependency for deviceStream on launchStream. We know that deviceStream + // hasn't been modified since launchStream waited on it (in ncclLaunchPrepare), + // so we can say that launchStream subsumes it. + NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, &comm->deviceStream, launchStream, /*b_subsumes_a=*/true), result, resume1); resume1: - // Create dependency for other user streams (skip launch stream). + // Create dependency for other user streams (skip launch stream) on deviceStream. + // Again, the user streams haven't been touched since deviceStream waited on them + // so we can say they are subsumed by deviceStream. struct ncclCudaStreamList* sl = tasks->streams->next; tasks->streams = nullptr; // Reset comm->tasks.streams to empty. while (sl != nullptr) { - NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, sl->stream, &comm->deviceStream), result, resume2); + NCCLCHECKGOTO(ncclStrongStreamWaitStream(tasks->capturingGraph, sl->stream, &comm->deviceStream, /*b_subsumes_a=*/true), result, resume2); resume2: sl = sl->next; } diff --git a/src/graph/paths.cc b/src/graph/paths.cc index 01f1582..7134b90 100644 --- a/src/graph/paths.cc +++ b/src/graph/paths.cc @@ -754,3 +754,15 @@ ncclResult_t ncclTopoGetNvbGpus(struct ncclTopoSystem* system, int rank, int* nr *nranks = nvbGpus; return ncclSuccess; } + +int ncclTopoPathAllNVLink(struct ncclTopoSystem* system) { + int minPath = PATH_DIS; + for (int i=0; inodes[GPU].count; i++) { + struct ncclTopoLinkList* paths = system->nodes[GPU].nodes[i].paths[GPU]; + for (int j=0; jnodes[GPU].count; j++) { + if (i == j) continue; + minPath = std::min(minPath, paths[j].type); + } + } + return minPath >= PATH_PIX ? 0 : 1; +} diff --git a/src/graph/topo.cc b/src/graph/topo.cc index 9e4c978..d91aa63 100644 --- a/src/graph/topo.cc +++ b/src/graph/topo.cc @@ -72,6 +72,9 @@ static ncclResult_t ncclTopoGetInterCpuBw(struct ncclTopoNode* cpu, float* bw) { if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_X86 && cpu->cpu.vendor == NCCL_TOPO_CPU_VENDOR_INTEL) { *bw = cpu->cpu.model == NCCL_TOPO_CPU_TYPE_SKL ? SKL_QPI_BW : QPI_BW; } + if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_X86 && cpu->cpu.vendor == NCCL_TOPO_CPU_VENDOR_AMD) { + *bw = AMD_BW; + } if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_X86 && cpu->cpu.vendor == NCCL_TOPO_CPU_VENDOR_ZHAOXIN) { *bw = cpu->cpu.model == NCCL_TOPO_CPU_TYPE_YONGFENG ? YONGFENG_ZPI_BW : ZPI_BW; } diff --git a/src/graph/topo.h b/src/graph/topo.h index 20a3e9d..1a1a04c 100644 --- a/src/graph/topo.h +++ b/src/graph/topo.h @@ -18,6 +18,7 @@ #define PCI_BW 12.0 // PCI Gen3 x16 #define QPI_BW 6.0 #define SKL_QPI_BW 9.0 +#define AMD_BW 16.0 #define ZPI_BW 6.0 #define YONGFENG_ZPI_BW 9.0 #define P9_BW 32.0 diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h index a787c0b..e70db04 100644 --- a/src/include/bootstrap.h +++ b/src/include/bootstrap.h @@ -10,10 +10,16 @@ #include "nccl.h" #include "comm.h" +struct ncclBootstrapHandle { + uint64_t magic; + union ncclSocketAddress addr; +}; +static_assert(sizeof(struct ncclBootstrapHandle) <= sizeof(ncclUniqueId), "Bootstrap handle is too large to fit inside NCCL unique ID"); + ncclResult_t bootstrapNetInit(); -ncclResult_t bootstrapCreateRoot(ncclUniqueId* commId, bool idFromEnv); -ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out); -ncclResult_t bootstrapInit(ncclUniqueId* id, struct ncclComm* comm); +ncclResult_t bootstrapCreateRoot(struct ncclBootstrapHandle* handle, bool idFromEnv); +ncclResult_t bootstrapGetUniqueId(struct ncclBootstrapHandle* handle); +ncclResult_t bootstrapInit(struct ncclBootstrapHandle* handle, struct ncclComm* comm); ncclResult_t bootstrapAllGather(void* commState, void* allData, int size); ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size); ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size); diff --git a/src/include/comm.h b/src/include/comm.h index 16e95b3..655292a 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -171,9 +171,12 @@ struct ncclComm { uint64_t* connectSend; uint64_t* connectRecv; + uint64_t magic; // Magic number for all network communication. Not a security key -- only goal is to detect mismatches. + int rank; // my rank in the communicator int nRanks; // number of GPUs in communicator int cudaDev; // my cuda device index + int compCap; // compute capability of the GPU int64_t busId; // my PCI bus ID in int format cpu_set_t cpuAffinity; // CPU affinity of the GPU @@ -208,7 +211,7 @@ struct ncclComm { // Buffer sizes int buffSizes[NCCL_NUM_PROTOCOLS]; - int p2pNetChunkSize; + int p2pChunkSize; // Algorithm/Protocols thresholds ssize_t threadThresholds[NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS]; @@ -240,7 +243,6 @@ struct ncclComm { // Intra-process sync struct ncclComm* intraComm0; // leader of intra-process comms (self possible) struct ncclComm* intraNext; // next of intra-process comms, intraComm0 is head - int intraRefs; // reference count from intra-process comms (zero if not leader else intraRanks) int intraRank; int intraRanks; uint32_t intraBarrierPhase; diff --git a/src/include/graph.h b/src/include/graph.h index 91e85e7..62323a2 100644 --- a/src/include/graph.h +++ b/src/include/graph.h @@ -28,6 +28,7 @@ void ncclTopoFree(struct ncclTopoSystem* system); ncclResult_t ncclTopoTrimSystem(struct ncclTopoSystem* system, struct ncclComm* comm); ncclResult_t ncclTopoComputeP2pChannels(struct ncclComm* comm); ncclResult_t ncclTopoGetNvbGpus(struct ncclTopoSystem* system, int rank, int* nranks, int** ranks); +int ncclTopoPathAllNVLink(struct ncclTopoSystem* system); // Query topology ncclResult_t ncclTopoGetNetDev(struct ncclComm* comm, int rank, struct ncclTopoGraph* graph, int channelId, int peerRank, int* net, int* proxyRank); diff --git a/src/include/nvtx.h b/src/include/nvtx.h index 7796126..2aeb932 100644 --- a/src/include/nvtx.h +++ b/src/include/nvtx.h @@ -9,6 +9,77 @@ #include "nvtx3.hpp" +#if __cpp_constexpr >= 201304L && !defined(NVTX3_RELAXED_CONSTEXPR) +#define NVTX3_RELAXED_CONSTEXPR constexpr +#else +#define NVTX3_RELAXED_CONSTEXPR +#endif + +// Define all NCCL-provided static schema IDs here (avoid duplicates). +#define NVTX_SID_CommInitRank 0 +#define NVTX_SID_CommInitAll 1 +#define NVTX_SID_CommDestroy 2 // same schema as NVTX_SID_CommInitRank +#define NVTX_SID_CommAbort 3 // same schema as NVTX_SID_CommInitRank +#define NVTX_SID_AllGather 4 +#define NVTX_SID_AllReduce 5 +#define NVTX_SID_Broadcast 6 +#define NVTX_SID_ReduceScatter 7 +#define NVTX_SID_Reduce 8 +#define NVTX_SID_Send 9 +#define NVTX_SID_Recv 10 + +// Define static schema ID for the reduction operation. +#define NVTX_PAYLOAD_ENTRY_NCCL_REDOP 11 + NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + +extern const nvtxDomainHandle_t ncclNvtxDomainHandle; + struct nccl_domain{static constexpr char const* name{"NCCL"};}; +class payload_schema { + public: + NVTX3_RELAXED_CONSTEXPR explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept + { + schema_attr.name = schemaName; + schema_attr.entries = entries; + schema_attr.numEntries = numEntries; + schema_attr.schemaId = schemaId; + nvtxPayloadSchemaRegister(nvtx3::domain::get(), &schema_attr); + } + + payload_schema() = delete; + ~payload_schema() = default; + payload_schema(payload_schema const&) = default; + payload_schema& operator=(payload_schema const&) = default; + payload_schema(payload_schema&&) = default; + payload_schema& operator=(payload_schema&&) = default; + + private: + nvtxPayloadSchemaAttr_t schema_attr{ + NVTX_PAYLOAD_SCHEMA_ATTR_TYPE | + NVTX_PAYLOAD_SCHEMA_ATTR_ENTRIES | + NVTX_PAYLOAD_SCHEMA_ATTR_NUM_ENTRIES | + NVTX_PAYLOAD_SCHEMA_ATTR_STATIC_SIZE | + NVTX_PAYLOAD_SCHEMA_ATTR_SCHEMA_ID, + nullptr, + NVTX_PAYLOAD_SCHEMA_TYPE_STATIC, + NVTX_PAYLOAD_SCHEMA_FLAG_NONE, + nullptr, 0, 0, 0}; +}; + +// Create NVTX push/pop range with parameters +// @param name of the operation (see `NVTX_SID_*`) +// @param N schema name +// @param S schema (entries) +// @param P payload (struct) +#define NVTX3_FUNC_WITH_PARAMS(ID, S, P) \ + static const payload_schema schema{S, std::extent::value, \ + NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, #ID}; \ + static ::nvtx3::v1::registered_string const nvtx3_func_name__{__func__}; \ + nvtxPayloadData_t nvtx3_bpl__[] = { \ + {NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, sizeof(P), &(P)}}; \ + ::nvtx3::v1::event_attributes nvtx3_func_attr__{nvtx3_func_name__, nvtx3_bpl__}; \ + ::nvtx3::v1::domain_thread_range const nvtx3_range__{nvtx3_func_attr__}; + +extern void initNvtxRegisteredEnums(); + #endif diff --git a/src/include/nvtx3.hpp b/src/include/nvtx3.hpp index 1e99373..353fddf 100644 --- a/src/include/nvtx3.hpp +++ b/src/include/nvtx3.hpp @@ -92,6 +92,7 @@ /* clang-format on */ #include +#include #include #include @@ -1732,6 +1733,22 @@ class event_attributes { attributes_.messageType = m.get_type(); } + /** + * @brief Variadic constructor where the first argument is a binary payload. + * + * Sets the value of the `EventAttribute`s message based on `m` and forwards + * the remaining variadic parameter pack to the next constructor. + * + */ + template + NVTX3_RELAXED_CONSTEXPR explicit event_attributes(nvtxPayloadData_t const* bpl, Args const&... args) noexcept + : event_attributes(args...) + { + attributes_.payloadType = NVTX_PAYLOAD_TYPE_BINARY; + attributes_.reserved0 = 1; // NCCL uses only a single binary payload per event. + attributes_.payload.ullValue = NVTX_POINTER_AS_PAYLOAD_ULLVALUE(bpl); + } + ~event_attributes() = default; event_attributes(event_attributes const&) = default; event_attributes& operator=(event_attributes const&) = default; diff --git a/src/include/nvtx3/nvToolsExtPayload.h b/src/include/nvtx3/nvToolsExtPayload.h new file mode 100644 index 0000000..1683f92 --- /dev/null +++ b/src/include/nvtx3/nvToolsExtPayload.h @@ -0,0 +1,776 @@ +/* +* Copyright 2021 NVIDIA Corporation. All rights reserved. +* +* Licensed under the Apache License v2.0 with LLVM Exceptions. +* See https://llvm.org/LICENSE.txt for license information. +* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +*/ + +#include "nvtx3/nvToolsExt.h" + +#ifndef NVTOOLSEXT_PAYLOAD_H +#define NVTOOLSEXT_PAYLOAD_H + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/** + * \brief A compatibility ID value used in initialization to identify version + * differences. + */ +#define NVTX_EXT_COMPATID_PAYLOAD 0x0103 + +/** + * \brief This module ID identifies the payload extension. It has to be unique + * among the extension modules. + */ +#define NVTX_EXT_MODULEID_PAYLOAD 2 + +/** + * \brief Additional values for the enum @ref nvtxPayloadType_t + */ +#define NVTX_PAYLOAD_TYPE_BINARY ((int32_t)0xDFBD0009) + + +/** --------------------------------------------------------------------------- + * Payload schema entry flags. + * ------------------------------------------------------------------------- */ +#define NVTX_PAYLOAD_ENTRY_FLAG_UNUSED 0 + +/** + * Absolute pointer into a payload (entry) of the same event. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_POINTER (1 << 1) + +/** + * Offset from base address of the payload. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_OFFSET_FROM_BASE (1 << 2) + +/** + * Offset from the end of this payload entry. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_OFFSET_FROM_HERE (1 << 3) + +/** + * The value is an array with fixed length, set with the field `arrayLength`. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_FIXED_SIZE (1 << 4) + +/** + * The value is a zero-/null-terminated array. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_ZERO_TERMINATED (2 << 4) + +/** + * \brief A single or multi-dimensional array of variable length. + * + * The field `arrayLength` contains the index of the schema entry that holds the + * length(s). If the other field points to a scalar entry then this will be the + * 1D array. If the other field points to a FIXED_SIZE array, then the number of + * dimensions is defined with the registration of the scheme. If the other field + * is ZERO_TERMINATED, the array the dimensions can be determined at runtime. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_LENGTH_INDEX (3 << 4) + +/** + * A tool may not support deep copy and just ignore this flag. + * See @ref NVTX_PAYLOAD_SCHEMA_FLAG_DEEP_COPY for more details. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_DEEP_COPY (1 << 9) + +/** + * The entry specifies the message in a deferred event. The entry type can be + * any string type. The flag is ignored for schemas that are not flagged with + * `NVTX_PAYLOAD_SCHEMA_FLAG_RANGE*` or `NVTX_PAYLOAD_SCHEMA_FLAG_MARK`. + */ +#define NVTX_PAYLOAD_ENTRY_FLAG_EVENT_MESSAGE (1 << 10) + +/** + * @note The ‘array’ flags assume that the array is embedded. Otherwise, + * @ref NVTX_PAYLOAD_ENTRY_FLAG_POINTER has to be additionally specified. Some + * combinations may be invalid based on the `NVTX_PAYLOAD_SCHEMA_TYPE_*` this + * entry is enclosed. For instance, variable length embedded arrays are valid + * within @ref NVTX_PAYLOAD_SCHEMA_TYPE_DYNAMIC but invalid with + * @ref NVTX_PAYLOAD_SCHEMA_TYPE_STATIC. See `NVTX_PAYLOAD_SCHEMA_TYPE_*` for + * additional details. + */ + +/* Helper macro to check if an entry represents an array. */ +#define NVTX_PAYLOAD_ENTRY_FLAG_IS_ARRAY (\ + NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_FIXED_SIZE | \ + NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_ZERO_TERMINATED | \ + NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_LENGTH_INDEX) + +/** --------------------------------------------------------------------------- + * Types of entries in a payload schema. + * ------------------------------------------------------------------------- */ + +/** + * @note Several of the predefined types contain the size (in bits) in their + * names. For some data types the size (in bytes) is not fixed and may differ + * for different platforms/operating systems/compilers. To provide portability, + * an array of sizes (in bytes) for type 1 to 28 ( @ref + * NVTX_PAYLOAD_ENTRY_TYPE_CHAR to @ref NVTX_PAYLOAD_ENTRY_TYPE_INFO_ARRAY_SIZE) + * is passed to the NVTX extension initialization function + * @ref InitializeInjectionNvtxExtension via the `extInfo` field of + * @ref nvtxExtModuleInfo_t. + */ + +#define NVTX_PAYLOAD_ENTRY_TYPE_INVALID 0 + +/** + * Basic integer types. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_CHAR 1 +#define NVTX_PAYLOAD_ENTRY_TYPE_UCHAR 2 +#define NVTX_PAYLOAD_ENTRY_TYPE_SHORT 3 +#define NVTX_PAYLOAD_ENTRY_TYPE_USHORT 4 +#define NVTX_PAYLOAD_ENTRY_TYPE_INT 5 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT 6 +#define NVTX_PAYLOAD_ENTRY_TYPE_LONG 7 +#define NVTX_PAYLOAD_ENTRY_TYPE_ULONG 8 +#define NVTX_PAYLOAD_ENTRY_TYPE_LONGLONG 9 +#define NVTX_PAYLOAD_ENTRY_TYPE_ULONGLONG 10 + +/** + * Integer types with explicit size. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_INT8 11 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT8 12 +#define NVTX_PAYLOAD_ENTRY_TYPE_INT16 13 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT16 14 +#define NVTX_PAYLOAD_ENTRY_TYPE_INT32 15 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT32 16 +#define NVTX_PAYLOAD_ENTRY_TYPE_INT64 17 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT64 18 + +/** + * C floating point types + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_FLOAT 19 +#define NVTX_PAYLOAD_ENTRY_TYPE_DOUBLE 20 +#define NVTX_PAYLOAD_ENTRY_TYPE_LONGDOUBLE 21 + +/** + * Size type (`size_t`) + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_SIZE 22 + +/** + * Any address, e.g. `void*`. If the pointer type matters, use the flag @ref + * NVTX_PAYLOAD_ENTRY_FLAG_POINTER and the respective type instead. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_ADDRESS 23 + +/** + * Special character types. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_WCHAR 24 /* wide character (since C90) */ +#define NVTX_PAYLOAD_ENTRY_TYPE_CHAR8 25 /* since C2x and C++20 */ +#define NVTX_PAYLOAD_ENTRY_TYPE_CHAR16 26 +#define NVTX_PAYLOAD_ENTRY_TYPE_CHAR32 27 + +/** + * There is type size and alignment information for all previous types. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_INFO_ARRAY_SIZE (NVTX_PAYLOAD_ENTRY_TYPE_CHAR32 + 1) + +/** + * Store raw 8-bit binary data. As with `char`, 1-byte alignment is assumed. + * Typically a tool will display this as hex or binary. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_BYTE 32 + +/** + * These types do not have standardized equivalents. It is assumed that the + * number at the end corresponds to the bits used to store the value and that + * the alignment corresponds to standardized types of the same size. + * A tool may not support these types. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_INT128 33 +#define NVTX_PAYLOAD_ENTRY_TYPE_UINT128 34 + +#define NVTX_PAYLOAD_ENTRY_TYPE_FLOAT16 42 +#define NVTX_PAYLOAD_ENTRY_TYPE_FLOAT32 43 +#define NVTX_PAYLOAD_ENTRY_TYPE_FLOAT64 44 +#define NVTX_PAYLOAD_ENTRY_TYPE_FLOAT128 45 + +#define NVTX_PAYLOAD_ENTRY_TYPE_BF16 50 +#define NVTX_PAYLOAD_ENTRY_TYPE_TF32 52 + +/** + * These types are normalized numbers stored in integers. UNORMs represent 0.0 + * to 1.0 and SNORMs represent -1.0 to 1.0. The number after represents the + * number of integer bits. Alignment is take from equivalent types INT# matching + * to SNORM# and UINT# matching to UNORM#. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_SNORM8 61 +#define NVTX_PAYLOAD_ENTRY_TYPE_UNORM8 62 +#define NVTX_PAYLOAD_ENTRY_TYPE_SNORM16 63 +#define NVTX_PAYLOAD_ENTRY_TYPE_UNORM16 64 +#define NVTX_PAYLOAD_ENTRY_TYPE_SNORM32 65 +#define NVTX_PAYLOAD_ENTRY_TYPE_UNORM32 66 +#define NVTX_PAYLOAD_ENTRY_TYPE_SNORM64 67 +#define NVTX_PAYLOAD_ENTRY_TYPE_UNORM64 68 + +/** + * String types. + * + * If `arrayOrUnionDetail` is greater than `0`, the entry is a fixed-size string + * with the provided length. + * + * `NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_FIXED_SIZE` is ignored for string types. It + * just specifies once more that the entry is a fixed-size string. + * + * Setting the flag `NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_ZERO_TERMINATED` indicates a + * zero-terminated string. If `arrayOrUnionDetail` is greater than `0`, a zero- + * terminated array of fixed-size strings is assumed. + * + * Setting the flag `NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_LENGTH_INDEX` specifies the + * entry index of the entry which contains the string length. It is not possible + * to describe a variable length array of strings. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_CSTRING 75 /* `char*`, system LOCALE */ +#define NVTX_PAYLOAD_ENTRY_TYPE_CSTRING_UTF8 76 +#define NVTX_PAYLOAD_ENTRY_TYPE_CSTRING_UTF16 77 +#define NVTX_PAYLOAD_ENTRY_TYPE_CSTRING_UTF32 78 + +/** + * @ref nvtxStringHandle_t returned by @ref nvtxDomainRegisterString + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE 80 + +/** + * Entry types to be used in deferred events. Data types are as defined by + * NVTXv3 core: category -> uint32_t, color -> uint32_t, color type -> int32_t. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_NVTX_CATEGORY 90 +#define NVTX_PAYLOAD_ENTRY_TYPE_NVTX_COLORTYPE 91 +#define NVTX_PAYLOAD_ENTRY_TYPE_NVTX_COLOR 92 + +/** + * This type marks the union selector member (entry index) in schemas used by + * a union with internal internal selector. + * See @ref NVTX_PAYLOAD_SCHEMA_TYPE_UNION_WITH_INTERNAL_SELECTOR. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_UNION_SELECTOR 100 + +/** + * Timestamp types occupy the range from 128 to 255 + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP64 128 /* data type is uint64_t */ + +/** + * CPU timestamp sources. + * \todo All 64 bits? + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_TSC 129 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_TSC_NONVIRTUALIZED 130 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_REALTIME 131 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_REALTIME_COARSE 132 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_MONOTONIC 133 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_MONOTONIC_RAW 134 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_MONOTONIC_COARSE 135 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_BOOTTIME 136 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_PROCESS_CPUTIME_ID 137 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPU_CLOCK_GETTIME_THREAD_CPUTIME_ID 138 + +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_WIN_QPC 160 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_WIN_GSTAFT 161 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_WIN_GSTAFTP 162 + +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_C_TIME 163 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_C_CLOCK 164 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_C_TIMESPEC_GET 165 + +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_STEADY_CLOCK 166 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_HIGH_RESOLUTION_CLOCK 167 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_SYSTEM_CLOCK 168 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_UTC_CLOCK 169 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_TAI_CLOCK 170 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_GPS_CLOCK 171 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_CPP_FILE_CLOCK 172 + +/** + * \brief GPU timestamp sources. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_GPU_GLOBALTIMER 192 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_GPU_SM_CLOCK 193 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_GPU_SM_CLOCK64 194 +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_GPU_CUPTI 195 + +/** + * The timestamp was provided by the NVTX handler’s timestamp routine. + */ +#define NVTX_PAYLOAD_ENTRY_TYPE_TIMESTAMP_TOOL_PROVIDED 224 + +/** + * This predefined schema ID can be used in `nvtxPayloadData_t` to indicate that + * the payload is a blob of memory which other payload entries may point into. + * A tool will not expose this payload directly. + */ +#define NVTX_TYPE_PAYLOAD_SCHEMA_REFERENCED 1022 + +/** + * This predefined schema ID can be used in `nvtxPayloadData_t` to indicate that + * the payload is a blob which can be shown with an arbitrary data viewer. + */ +#define NVTX_TYPE_PAYLOAD_SCHEMA_RAW 1023 + +/* Custom (static) schema IDs. */ +#define NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START (1 << 24) + +/* Dynamic schema IDs (generated by the tool) start here. */ +#define NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_DYNAMIC_START 4294967296 // 1 << 32 + + +/** + * \brief Size and alignment information for predefined payload entry types. + * + * The struct contains the size and the alignment size in bytes. A respective + * array for the predefined types is passed via nvtxExtModuleInfo_t to the NVTX + * client/handler. The type (ID) is used as index into this array. + */ +typedef struct nvtxPayloadEntryTypeInfo_t +{ + uint16_t size; + uint16_t align; +} nvtxPayloadEntryTypeInfo_t; + +/** + * \brief Entry in a schema. + * + * A payload schema consists of an array of payload schema entries. It is + * registered with @ref nvtxPayloadSchemaRegister. `flag` can be set to `0` for + * simple values, 'type' is the only "required" field. If not set explicitly, + * all other fields are zero-initialized, which means that the entry has no name + * and the offset is determined based on self-alignment rules. + * + * Example schema: + * nvtxPayloadSchemaEntry_t desc[] = { + * {0, NVTX_EXT_PAYLOAD_TYPE_UINT8, "one byte"}, + * {0, NVTX_EXT_PAYLOAD_TYPE_INT32, "four bytes"} + * }; + */ +typedef struct nvtxPayloadSchemaEntry_t +{ + /** + * \brief Flags to augment the basic type. + * + * This field allows additional properties of the payload entry to be + * specified. Valid values are `NVTX_PAYLOAD_ENTRY_FLAG_*`. + */ + uint64_t flags; + + /** + * \brief Predefined payload schema entry type or ID of a registered payload + * schema. + */ + uint64_t type; + + /** + * \brief Name of the payload entry. (Optional) + * + * Providing a name is useful to give a meaning to the associated value. + */ + const char* name; + + /** + * \brief Description of the payload entry. (Optional) + */ + const char* description; + + /** + * \brief String or array length or union selector for union types. + * + * If @ref type is a C string type, this defines the length of the string. + * + * If @ref flags specify that the entry is an array, this field defines the + * length of the array. See `NVTX_PAYLOAD_ENTRY_FLAG_ARRAY_*` for more + * details. + * + * If @ref type implies that the entry is a union with schema type + * @ref NVTX_PAYLOAD_SCHEMA_TYPE_UNION (external selection of the union + * member), this field contains the index (starting with 0) to an entry of + * integer type in the same schema. The associated field contains the + * selected union member. + * + * @note An array of schema type @ref NVTX_PAYLOAD_SCHEMA_TYPE_UNION is not + * supported. @ref NVTX_PAYLOAD_SCHEMA_TYPE_UNION_WITH_INTERNAL_SELECTOR can + * be used instead. + */ + uint64_t arrayOrUnionDetail; + + /** + * \brief Offset in the binary payload data (in bytes). + * + * This field specifies the byte offset from the base address of the actual + * binary data (blob) to the data of this entry. + * + * This is an optional field, but it is recommended to specify this field to + * avoid issues in the automatic detection of the offset by a tool/handler. + */ + uint64_t offset; + + /** + * Semantics are not yet defined. + */ + void* semantics; + + /** + * Reserved for future use. Do not use it! + */ + void* reserved; +} nvtxPayloadSchemaEntry_t; + +/** + * \brief Binary payload data, size and decoding information. + * + * An array of nvtxPayloadData_t is passed to the NVTX event attribute payload + * member. To attach a single payload the macro @ref NVTX_EXT_PAYLOAD_SET_ATTR + * can be used. + */ +typedef struct nvtxPayloadData_t +{ + /** + * The schema ID, which defines the layout of the binary data. + */ + uint64_t schemaId; + + /** + * Size of the binary payload (blob) in bytes. + */ + size_t size; + + /** + * Pointer to the binary payload data. + */ + const void* payload; +} nvtxPayloadData_t; + +/* Helper macros for safe double-cast of pointer to uint64_t value */ +#ifndef NVTX_POINTER_AS_PAYLOAD_ULLVALUE +# ifdef __cplusplus +# define NVTX_POINTER_AS_PAYLOAD_ULLVALUE(p) \ + static_cast(reinterpret_cast(p)) +# else +#define NVTX_POINTER_AS_PAYLOAD_ULLVALUE(p) ((uint64_t)(uintptr_t)p) +# endif +#endif + + +#define NVTX_PAYLOAD_CONCAT2(a,b) a##b +#define NVTX_PAYLOAD_CONCAT(a,b) NVTX_PAYLOAD_CONCAT2(a,b) +#define NVTX_DATA_VAR NVTX_PAYLOAD_CONCAT(nvtxDFDB,__LINE__) + +/** + * \brief Helper macro to attach a single payload to an NVTX event attribute. + * + * @note The NVTX push, start or mark operation must not be in the same or a + * nested scope. + */ +#define NVTX_PAYLOAD_EVTATTR_SET(EVTATTR, SCHEMA_ID, PAYLOAD_ADDR, SIZE) \ + nvtxPayloadData_t NVTX_DATA_VAR[] = {{SCHEMA_ID, SIZE, PAYLOAD_ADDR}}; \ + (EVTATTR).payload.ullValue = \ + NVTX_POINTER_AS_PAYLOAD_ULLVALUE(NVTX_DATA_VAR); \ + (EVTATTR).payloadType = NVTX_PAYLOAD_TYPE_BINARY; \ + (EVTATTR).reserved0 = 1; + +/** + * \brief Helper macro to attach multiple payloads to an NVTX event attribute. + * + * The payload data array (`nvtxPayloadData_t`) is passed as first argument to + * this macro. + */ +#define NVTX_PAYLOAD_EVTATTR_SET_MULTIPLE(EVTATTR, PAYLOADS) \ + (EVTATTR).payloadType = NVTX_PAYLOAD_TYPE_BINARY; \ + (EVTATTR).reserved0 = sizeof(PAYLOADS)/sizeof(nvtxPayloadData_t); \ + (EVTATTR).payload.ullValue = NVTX_POINTER_AS_PAYLOAD_ULLVALUE(PAYLOADS); + + +/** + * \brief The payload schema type. + * + * A schema can be either of these types. + */ +enum nvtxPayloadSchemaType +{ + NVTX_PAYLOAD_SCHEMA_TYPE_INVALID = 0, + + NVTX_PAYLOAD_SCHEMA_TYPE_STATIC = 1, + NVTX_PAYLOAD_SCHEMA_TYPE_DYNAMIC = 2, + + NVTX_PAYLOAD_SCHEMA_TYPE_UNION = 3, + NVTX_PAYLOAD_SCHEMA_TYPE_UNION_WITH_INTERNAL_SELECTOR = 4 +}; + +/** + * \brief Flags for static and dynamic schemas. + */ +enum nvtxPayloadSchemaFlags +{ + NVTX_PAYLOAD_SCHEMA_FLAG_NONE = 0, + + /** + * This flag indicates that a schema and the corresponding payloads can + * contain fields which require a deep copy. + */ + NVTX_PAYLOAD_SCHEMA_FLAG_DEEP_COPY = (1 << 1), + + /** + * This flag indicates that a schema and the corresponding payloads can + * be referenced by another payload of the same event. + */ + NVTX_PAYLOAD_SCHEMA_FLAG_REFERENCED = (1 << 2), + + /** + * The schema describes a deferred event/marker. Such a schema requires one + * timestamp entry and one string entry with the flag + * `NVTX_PAYLOAD_ENTRY_FLAG_EVENT_MESSAGE`. Category and color can be + * optionally specified with the respective entry types. The deferred event + * can contain a binary payload itself by using a custom schema ID as type + * its schema description. Multiple occurrences of the same event can be + * described by specifying an array timestamps. + */ + NVTX_PAYLOAD_SCHEMA_FLAG_DEFERRED_EVENT = (1 << 3), + /** + * The schema describes a deferred event/marker. Such a schema requires + * one start timestamp, one end timestamp and one string entry with the flag + * `NVTX_PAYLOAD_ENTRY_FLAG_EVENT_MESSAGE`. Category and color can be + * optionally specified with the respective entry types. The deferred range + * can contain a binary payload itself by using a custom schema ID as type + * its schema description. + * + * Timestamps can be provided in different ways: + * - A single range has two timestamp entries with the first (smaller entry + * index) being used as the start/push timestamp. + * - If the range schema contains one array of timestamps, the tool assumes + * that the array contains alternating start and end timestamps. + * - If two timestamp arrays are specified the first entry (with the + * smaller entry index) is assumed to contain the start timestamps. Both + * arrays have to be of the same size. + */ + NVTX_PAYLOAD_SCHEMA_FLAG_DEFERRED_RANGE = (2 << 3) +}; + +/** + * The values allow the valid fields in @ref nvtxPayloadSchemaAttr_t to be + * specified via setting the field `fieldMask`. + */ +#define NVTX_PAYLOAD_SCHEMA_ATTR_NAME (1 << 1) +#define NVTX_PAYLOAD_SCHEMA_ATTR_TYPE (1 << 2) +#define NVTX_PAYLOAD_SCHEMA_ATTR_FLAGS (1 << 3) +#define NVTX_PAYLOAD_SCHEMA_ATTR_ENTRIES (1 << 4) +#define NVTX_PAYLOAD_SCHEMA_ATTR_NUM_ENTRIES (1 << 5) +#define NVTX_PAYLOAD_SCHEMA_ATTR_STATIC_SIZE (1 << 6) +#define NVTX_PAYLOAD_SCHEMA_ATTR_ALIGNMENT (1 << 7) +#define NVTX_PAYLOAD_SCHEMA_ATTR_SCHEMA_ID (1 << 8) + +/** + * NVTX payload schema attributes. + */ +typedef struct nvtxPayloadSchemaAttr_t +{ + /** + * \brief Mask of valid fields in this structure. + * + * The values from `enum nvtxPayloadSchemaAttributes` have to be used. + */ + uint64_t fieldMask; + + /** + * \brief Name of the payload schema. (Optional) + */ + const char* name; + + /** + * \brief Payload schema type. (Mandatory) \anchor PAYLOAD_TYPE_FIELD + * + * A value from `enum nvtxPayloadSchemaType` has to be used. + */ + uint64_t type; + + /** + * \brief Payload schema flags. (Optional) + * + * Flags defined in `enum nvtxPayloadSchemaFlags` can be used to set + * additional properties of the schema. + */ + uint64_t flags; + + /** + * \brief Entries of a payload schema. (Mandatory) \anchor ENTRIES_FIELD + * + * This field is a pointer to an array of schema entries, each describing a + * field in a data structure, e.g. in a C struct or union. + */ + const nvtxPayloadSchemaEntry_t* entries; + + /** + * \brief Number of entries in the payload schema. (Mandatory) + * + * Number of entries in the array of payload entries \ref ENTRIES_FIELD. + */ + size_t numEntries; + + /** + * \brief The binary payload size in bytes for static payload schemas. + * + * If \ref PAYLOAD_TYPE_FIELD is @ref NVTX_PAYLOAD_SCHEMA_TYPE_DYNAMIC this + * value is ignored. If this field is not specified for a schema of type + * @ref NVTX_PAYLOAD_SCHEMA_TYPE_STATIC, the size can be automatically + * determined by a tool. + */ + size_t payloadStaticSize; + + /** + * \brief The byte alignment for packed structures. + * + * If not specified, this field defaults to `0`, which means that the fields + * in the data structure are not packed and natural alignment rules can be + * applied. + */ + size_t packAlign; + + /* Static/custom schema ID must be + >= NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START and + < NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_DYNAMIC_START */ + uint64_t schemaId; +} nvtxPayloadSchemaAttr_t; + +/** + * \brief Register a payload schema. + * + * @param domain NVTX domain handle. + * @param attr NVTX payload schema attributes. + */ +NVTX_DECLSPEC uint64_t NVTX_API nvtxPayloadSchemaRegister( + nvtxDomainHandle_t domain, const nvtxPayloadSchemaAttr_t* attr); + +/** + * \brief Enumeration entry. + * + * Since the value of an enum entry might not be meaningful for the analysis, + * a tool can show the name of enum entry instead. + * + * @note EXPERIMENTAL + */ +typedef struct nvtxPayloadEnum_t +{ + /** + * Name of the enum value. + */ + const char* name; + + /** + * Value of the enum entry. + */ + uint64_t value; + + /** + * Indicates that this entry sets a specific set of bits, which can be used + * to easily define bitsets. + */ + int8_t isFlag; +} nvtxPayloadEnum_t; + +/** + * The values are used to set the field `fieldMask` and specify which fields in + * `nvtxPayloadEnumAttr_t` are set. + */ +#define NVTX_PAYLOAD_ENUM_ATTR_NAME (1 << 1) +#define NVTX_PAYLOAD_ENUM_ATTR_ENTRIES (1 << 2) +#define NVTX_PAYLOAD_ENUM_ATTR_NUM_ENTRIES (1 << 3) +#define NVTX_PAYLOAD_ENUM_ATTR_SIZE (1 << 4) +#define NVTX_PAYLOAD_ENUM_ATTR_SCHEMA_ID (1 << 5) + +/** + * NVTX payload enumeration type attributes. + */ +typedef struct nvtxPayloadEnumAttr_t { + /** + * Mask of valid fields in this struct. + * The values from `enum nvtxPayloadSchemaAttributes` have to be used. + */ + uint64_t fieldMask; + + /** + * Name of the enum. (Optional) + */ + const char* name; + + /** + * Entries of the enum. (Mandatory) + */ + const nvtxPayloadEnum_t* entries; + + /** + * Number of entries in the enum. (Mandatory) + */ + size_t numEntries; + + /** + * Size of enumeration type in bytes + */ + size_t sizeOfEnum; + + /** + * Static/custom schema ID must be + * >= NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START and + * < NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_DYNAMIC_START + */ + uint64_t schemaId; +} nvtxPayloadEnumAttr_t; + +/** + * \brief Register an enumeration type with the payload extension. + * + * @param domain NVTX domain handle + * @param attr NVTX payload enumeration type attributes. + */ +NVTX_DECLSPEC uint64_t nvtxPayloadEnumRegister(nvtxDomainHandle_t domain, + const nvtxPayloadEnumAttr_t* attr); + +/** + * \brief Callback Ids of API functions in the payload extension. + * + * The NVTX handler can use these values to register a handler function. When + * InitializeInjectionNvtxExtension(nvtxExtModuleInfo_t* moduleInfo) is + * executed, a handler routine 'handlenvtxPayloadRegisterSchema' can be + * registered as follows: + * moduleInfo->segments->slots[NVTX3EXT_CBID_nvtxPayloadSchemaRegister] = + * (intptr_t)handlenvtxPayloadRegisterSchema; + */ +typedef enum NvtxExtPayloadCallbackId +{ + NVTX3EXT_CBID_nvtxPayloadSchemaRegister = 0, + NVTX3EXT_CBID_nvtxPayloadEnumRegister = 1, + NVTX3EXT_CBID_PAYLOAD_FN_NUM = 2 +} NvtxExtPayloadCallbackId; + +#ifdef __GNUC__ +#pragma GCC visibility push(internal) +#endif + +#define NVTX_EXT_TYPES_GUARD /* Ensure other headers cannot include directly */ +#include "nvtxExtDetail/nvtxExtTypes.h" +#undef NVTX_EXT_TYPES_GUARD + +#ifndef NVTX_NO_IMPL +#define NVTX_EXT_IMPL_PAYLOAD_GUARD /* Ensure other headers cannot included directly */ +#include "nvtxExtDetail/nvtxExtPayloadTypeInfo.h" +#include "nvtxExtDetail/nvtxExtImplPayload_v1.h" +#undef NVTX_EXT_IMPL_PAYLOAD_GUARD +#endif /*NVTX_NO_IMPL*/ + +#ifdef __GNUC__ +#pragma GCC visibility pop +#endif + +#ifdef __cplusplus +} +#endif /* __cplusplus */ + +#endif /* NVTOOLSEXT_PAYLOAD_H */ diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtImpl.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtImpl.h new file mode 100644 index 0000000..5e42778 --- /dev/null +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtImpl.h @@ -0,0 +1,93 @@ +/* +* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* +* Licensed under the Apache License v2.0 with LLVM Exceptions. +* See https://llvm.org/LICENSE.txt for license information. +* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +*/ + +#ifndef NVTX_EXT_IMPL_GUARD +#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined). +#endif + +#ifndef NVTX_EXT_IMPL_H +#define NVTX_EXT_IMPL_H +/* ---- Include required platform headers ---- */ + +#if defined(_WIN32) + +#include + +#else +#include + +#if defined(__ANDROID__) +#include +#endif + +#if defined(__linux__) || defined(__CYGWIN__) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#endif + +/* ---- Define macros used in this file ---- */ + +#ifdef NVTX_DEBUG_PRINT +#ifdef __ANDROID__ +#include +#define NVTX_ERR(...) __android_log_print(ANDROID_LOG_ERROR, "NVTOOLSEXT", __VA_ARGS__); +#define NVTX_INFO(...) __android_log_print(ANDROID_LOG_INFO, "NVTOOLSEXT", __VA_ARGS__); +#else +#include +#define NVTX_ERR(...) fprintf(stderr, "NVTX_ERROR: " __VA_ARGS__) +#define NVTX_INFO(...) fprintf(stderr, "NVTX_INFO: " __VA_ARGS__) +#endif +#else /* !defined(NVTX_DEBUG_PRINT) */ +#define NVTX_ERR(...) +#define NVTX_INFO(...) +#endif + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +// #ifdef __GNUC__ +// #pragma GCC visibility push(hidden) +// #endif + +#define NVTX_EXTENSION_FRESH 0 +#define NVTX_EXTENSION_DISABLED 1 +#define NVTX_EXTENSION_STARTING 2 +#define NVTX_EXTENSION_LOADED 3 + +NVTX_LINKONCE_DEFINE_GLOBAL NvtxExtInitializeInjectionFunc_t NVTX_VERSIONED_IDENTIFIER(injectionFnPtr) = (NvtxExtInitializeInjectionFunc_t)0; + +#define NVTX_EXT_INIT_GUARD +#include "nvtxExtInit.h" +#undef NVTX_EXT_INIT_GUARD + +// #ifdef __GNUC__ +// #pragma GCC visibility pop +// #endif + +#ifdef __cplusplus +} /* extern "C" */ +#endif /* __cplusplus */ + +#endif /* NVTX_EXT_IMPL_H */ \ No newline at end of file diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h new file mode 100644 index 0000000..d589f63 --- /dev/null +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h @@ -0,0 +1,85 @@ +/* +* Copyright 2021 NVIDIA Corporation. All rights reserved. +* +* Licensed under the Apache License v2.0 with LLVM Exceptions. +* See https://llvm.org/LICENSE.txt for license information. +* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +*/ + +#ifndef NVTX_EXT_IMPL_PAYLOAD_GUARD +#error Never include this file directly -- it is automatically included by nvToolsExtPayload.h (except when NVTX_NO_IMPL is defined). +#endif + +#define NVTX_EXT_IMPL_GUARD +#include "nvtxExtImpl.h" +#undef NVTX_EXT_IMPL_GUARD + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#define NVTX_EXT_PAYLOAD_VERSIONED_IDENTIFIER_L3(NAME, VERSION, COMPATID) \ + NAME##_v##VERSION##_mem##COMPATID +#define NVTX_EXT_PAYLOAD_VERSIONED_IDENTIFIER_L2(NAME, VERSION, COMPATID) \ + NVTX_EXT_PAYLOAD_VERSIONED_IDENTIFIER_L3(NAME, VERSION, COMPATID) +#define NVTX_EXT_PAYLOAD_VERSIONED_ID(NAME) \ + NVTX_EXT_PAYLOAD_VERSIONED_IDENTIFIER_L2(NAME, NVTX_VERSION, NVTX_EXT_COMPATID_PAYLOAD) + +/* + * Function slots for the binary payload extension. First entry is the module + * state, initialized to `0` (`NVTX_EXTENSION_FRESH`). + */ +NVTX_LINKONCE_DEFINE_GLOBAL intptr_t +NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots)[NVTX3EXT_CBID_PAYLOAD_FN_NUM + 1] + = {0}; + +NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadInitOnce)() +{ + nvtxExtModuleSegment_t segment = { + 0, // unused (only one segment) + NVTX3EXT_CBID_PAYLOAD_FN_NUM, + NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots) + 1 + }; + + nvtxExtModuleInfo_t module = { + NVTX_VERSION, sizeof(nvtxExtModuleInfo_t), + NVTX_EXT_MODULEID_PAYLOAD, NVTX_EXT_COMPATID_PAYLOAD, + 1, &segment, // number of segments, segments + NULL, // no export function needed + // bake type sizes and alignment information into program binary + &nvtxExtPayloadTypeInfo + }; + + NVTX_INFO( "%s\n", __FUNCTION__ ); + + NVTX_VERSIONED_IDENTIFIER(nvtxExtInitOnce)(&module, + NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots)); +} + +#define NVTX_EXT_FN_IMPL(ret_val, fn_name, signature, arg_names) \ +typedef ret_val ( * fn_name##_impl_fntype )signature; \ +NVTX_LINKONCE_DEFINE_FUNCTION ret_val fn_name signature { \ + intptr_t slot = NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots)[NVTX3EXT_CBID_##fn_name + 1]; \ + if (slot != NVTX_EXTENSION_DISABLED) { \ + if (slot) { \ + return (*(fn_name##_impl_fntype)slot) arg_names; \ + } else { \ + NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadInitOnce)(); \ + slot = NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots)[NVTX3EXT_CBID_##fn_name + 1]; \ + if (slot != NVTX_EXTENSION_DISABLED && slot) { \ + return (*(fn_name##_impl_fntype)slot) arg_names; \ + } \ + } \ + } \ + return ((ret_val)(intptr_t)-1); \ +} + +NVTX_EXT_FN_IMPL(uint64_t, nvtxPayloadSchemaRegister, (nvtxDomainHandle_t domain, const nvtxPayloadSchemaAttr_t* attr), (domain, attr)) + +NVTX_EXT_FN_IMPL(uint64_t, nvtxPayloadEnumRegister, (nvtxDomainHandle_t domain, const nvtxPayloadEnumAttr_t* attr), (domain, attr)) + +#undef NVTX_EXT_FN_IMPL + +#ifdef __cplusplus +} /* extern "C" */ +#endif /* __cplusplus */ \ No newline at end of file diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtInit.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtInit.h new file mode 100644 index 0000000..724c217 --- /dev/null +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtInit.h @@ -0,0 +1,363 @@ +/* +* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* +* Licensed under the Apache License v2.0 with LLVM Exceptions. +* See https://llvm.org/LICENSE.txt for license information. +* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +*/ + +#ifndef NVTX_EXT_INIT_GUARD +#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined). +#endif + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/* ---- Platform-independent helper definitions and functions ---- */ + +/* Prefer macros over inline functions to reduce symbol resolution at link time */ + +#if defined(_WIN32) +#define NVTX_PATHCHAR wchar_t +#define NVTX_STR(x) L##x +#define NVTX_GETENV _wgetenv +#define NVTX_BUFSIZE MAX_PATH +#define NVTX_DLLHANDLE HMODULE +#define NVTX_DLLOPEN(x) LoadLibraryW(x) +#define NVTX_DLLFUNC GetProcAddress +#define NVTX_DLLCLOSE FreeLibrary +#define NVTX_YIELD() SwitchToThread() +#define NVTX_MEMBAR() MemoryBarrier() +#define NVTX_ATOMIC_WRITE_32(address, value) InterlockedExchange((volatile LONG*)address, value) +#define NVTX_ATOMIC_CAS_32(old, address, exchange, comparand) old = InterlockedCompareExchange((volatile LONG*)address, exchange, comparand) +#define NVTX_ATOMIC_WRITE_PTR(address, value) InterlockedExchangePointer((volatile PVOID*)address, (PVOID)value) +#define NVTX_ATOMIC_CAS_PTR(old, address, exchange, comparand) old = (intptr_t)InterlockedCompareExchangePointer((volatile PVOID*)address, (PVOID)exchange, (PVOID)comparand) + + +#elif defined(__GNUC__) +#define NVTX_PATHCHAR char +#define NVTX_STR(x) x +#define NVTX_GETENV getenv +#define NVTX_BUFSIZE PATH_MAX +#define NVTX_DLLHANDLE void* +#define NVTX_DLLOPEN(x) dlopen(x, RTLD_LAZY) +#define NVTX_DLLFUNC dlsym +#define NVTX_DLLCLOSE dlclose +#define NVTX_YIELD() sched_yield() +#define NVTX_MEMBAR() __sync_synchronize() +/* Ensure full memory barrier for atomics, to match Windows functions */ +#define NVTX_ATOMIC_WRITE_32(address, value) __sync_synchronize(); __sync_lock_test_and_set(address, value) +#define NVTX_ATOMIC_CAS_32(old, address, exchange, comparand) __sync_synchronize(); old = __sync_val_compare_and_swap(address, exchange, comparand) +#define NVTX_ATOMIC_WRITE_PTR(address, value) __sync_synchronize(); __sync_lock_test_and_set(address, value) +#define NVTX_ATOMIC_CAS_PTR(old, address, exchange, comparand) __sync_synchronize(); old = __sync_val_compare_and_swap(address, exchange, comparand) +#else +#error The library does not support your configuration! +#endif + +/* Define this to 1 for platforms that where pre-injected libraries can be discovered. */ +#if defined(_WIN32) +/* TODO */ +#define NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY 0 +#else +#define NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY 0 +#endif + +/* Define this to 1 for platforms that support environment variables */ +/* TODO: Detect UWP, a.k.a. Windows Store app, and set this to 0. */ +/* Try: #if defined(WINAPI_FAMILY_PARTITION) && WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP) */ +#define NVTX_SUPPORT_ENV_VARS 1 + +/* Define this to 1 for platforms that support dynamic/shared libraries */ +#define NVTX_SUPPORT_DYNAMIC_INJECTION_LIBRARY 1 + +/* Injection libraries implementing InitializeInjectionNvtxExtension may be statically linked, +* and this will override any dynamic injection. Useful for platforms where dynamic +* injection is not available. Since weak symbols not explicitly marked extern are +* guaranteed to be initialized to zero if no definitions are found by the linker, the +* dynamic injection process proceeds normally if pfnInitializeInjectionNvtx2 is 0. */ +#if defined(__GNUC__) && !defined(_WIN32) && !defined(__CYGWIN__) +#define NVTX_SUPPORT_STATIC_INJECTION_LIBRARY 1 +/* To statically inject an NVTX library, define InitializeInjectionNvtxExtension_fnptr as a normal +* symbol (not weak) pointing to the implementation of InitializeInjectionNvtxExtension (which +* does not need to be named "InitializeInjectionNvtxExtension" as is necessary in a dynamic +* injection library. */ +__attribute__((weak)) NvtxExtInitializeInjectionFunc_t InitializeInjectionNvtxExtension_fnptr; +#else +#define NVTX_SUPPORT_STATIC_INJECTION_LIBRARY 0 +#endif + + + +/* This function tries to find or load an NVTX injection library and get the +* address of its InitializeInjectionExtension function. If such a function pointer +* is found, it is called, and passed the address of this NVTX instance's +* nvtxGetExportTable function, so the injection can attach to this instance. +* If the initialization fails for any reason, any dynamic library loaded will +* be freed, and all NVTX implementation functions will be set to no-ops. If +* initialization succeeds, NVTX functions not attached to the tool will be set +* to no-ops. This is implemented as one function instead of several small +* functions to minimize the number of weak symbols the linker must resolve. +* Order of search is: +* - Pre-injected library exporting InitializeInjectionNvtxExtension +* - Loadable library exporting InitializeInjectionNvtxExtension +* - Path specified by env var NVTX_INJECTION??_PATH (?? is 32 or 64) +* - On Android, libNvtxInjection??.so within the package (?? is 32 or 64) +* - Statically-linked injection library defining InitializeInjectionNvtx2_fnptr +*/ +NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_VERSIONED_IDENTIFIER(nvtxExtLoadInjectionLibrary)(NvtxExtInitializeInjectionFunc_t* out_init_fnptr); +NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_VERSIONED_IDENTIFIER(nvtxExtLoadInjectionLibrary)(NvtxExtInitializeInjectionFunc_t* out_init_fnptr) +{ + const char* const initFuncName = "InitializeInjectionNvtxExtension"; + NvtxExtInitializeInjectionFunc_t init_fnptr = (NvtxExtInitializeInjectionFunc_t)0; + NVTX_DLLHANDLE injectionLibraryHandle = (NVTX_DLLHANDLE)0; + + if(out_init_fnptr){ + *out_init_fnptr = (NvtxExtInitializeInjectionFunc_t)0; + } + +#if NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY + /* Use POSIX global symbol chain to query for init function from any module */ + init_fnptr = (NvtxExtInitializeInjectionFunc_t)NVTX_DLLFUNC(0, initFuncName); +#endif + +#if NVTX_SUPPORT_DYNAMIC_INJECTION_LIBRARY + /* Try discovering dynamic injection library to load */ + if (!init_fnptr) + { +#if NVTX_SUPPORT_ENV_VARS + /* If env var NVTX_INJECTION64_PATH is set, it should contain the path + * to a 64-bit dynamic NVTX injection library (and similar for 32-bit). */ + const NVTX_PATHCHAR* const nvtxEnvVarName = (sizeof(void*) == 4) + ? NVTX_STR("NVTX_INJECTION32_PATH") + : NVTX_STR("NVTX_INJECTION64_PATH"); +#endif /* NVTX_SUPPORT_ENV_VARS */ + NVTX_PATHCHAR injectionLibraryPathBuf[NVTX_BUFSIZE]; + const NVTX_PATHCHAR* injectionLibraryPath = (const NVTX_PATHCHAR*)0; + + /* Refer to this variable explicitly in case all references to it are #if'ed out */ + (void)injectionLibraryPathBuf; + +#if NVTX_SUPPORT_ENV_VARS + /* Disable the warning for getenv & _wgetenv -- this usage is safe because + * these functions are not called again before using the returned value. */ +#if defined(_MSC_VER) +#pragma warning( push ) +#pragma warning( disable : 4996 ) +#endif + injectionLibraryPath = NVTX_GETENV(nvtxEnvVarName); +#if defined(_MSC_VER) +#pragma warning( pop ) +#endif +#endif + +#if defined(__ANDROID__) + if (!injectionLibraryPath) + { + const char *bits = (sizeof(void*) == 4) ? "32" : "64"; + char cmdlineBuf[32]; + char pkgName[PATH_MAX]; + int count; + int pid; + FILE *fp; + size_t bytesRead; + size_t pos; + + pid = (int)getpid(); + count = snprintf(cmdlineBuf, sizeof(cmdlineBuf), "/proc/%d/cmdline", pid); + if (count <= 0 || count >= (int)sizeof(cmdlineBuf)) + { + NVTX_ERR("Path buffer too small for: /proc/%d/cmdline\n", pid); + return NVTX_ERR_INIT_ACCESS_LIBRARY; + } + + fp = fopen(cmdlineBuf, "r"); + if (!fp) + { + NVTX_ERR("File couldn't be opened: %s\n", cmdlineBuf); + return NVTX_ERR_INIT_ACCESS_LIBRARY; + } + + bytesRead = fread(pkgName, 1, sizeof(pkgName) - 1, fp); + fclose(fp); + if (bytesRead == 0) + { + NVTX_ERR("Package name couldn't be read from file: %s\n", cmdlineBuf); + return NVTX_ERR_INIT_ACCESS_LIBRARY; + } + + pkgName[bytesRead] = 0; + + /* String can contain colon as a process separator. In this case the package name is before the colon. */ + pos = 0; + while (pos < bytesRead && pkgName[pos] != ':' && pkgName[pos] != '\0') + { + ++pos; + } + pkgName[pos] = 0; + + count = snprintf(injectionLibraryPathBuf, NVTX_BUFSIZE, "/data/data/%s/files/libNvtxInjection%s.so", pkgName, bits); + if (count <= 0 || count >= NVTX_BUFSIZE) + { + NVTX_ERR("Path buffer too small for: /data/data/%s/files/libNvtxInjection%s.so\n", pkgName, bits); + return NVTX_ERR_INIT_ACCESS_LIBRARY; + } + + /* On Android, verify path is accessible due to aggressive file access restrictions. */ + /* For dlopen, if the filename contains a leading slash, then it is interpreted as a */ + /* relative or absolute pathname; otherwise it will follow the rules in ld.so. */ + if (injectionLibraryPathBuf[0] == '/') + { +#if (__ANDROID_API__ < 21) + int access_err = access(injectionLibraryPathBuf, F_OK | R_OK); +#else + int access_err = faccessat(AT_FDCWD, injectionLibraryPathBuf, F_OK | R_OK, 0); +#endif + if (access_err != 0) + { + NVTX_ERR("Injection library path wasn't accessible [code=%s] [path=%s]\n", strerror(errno), injectionLibraryPathBuf); + return NVTX_ERR_INIT_ACCESS_LIBRARY; + } + } + injectionLibraryPath = injectionLibraryPathBuf; + } +#endif + + /* At this point, injectionLibraryPath is specified if a dynamic + * injection library was specified by a tool. */ + if (injectionLibraryPath) + { + /* Load the injection library */ + injectionLibraryHandle = NVTX_DLLOPEN(injectionLibraryPath); + if (!injectionLibraryHandle) + { + NVTX_ERR("Failed to load injection library\n"); + return NVTX_ERR_INIT_LOAD_LIBRARY; + } + else + { + /* Attempt to get the injection library's entry-point */ + init_fnptr = (NvtxExtInitializeInjectionFunc_t)NVTX_DLLFUNC(injectionLibraryHandle, initFuncName); + if (!init_fnptr) + { + NVTX_DLLCLOSE(injectionLibraryHandle); + NVTX_ERR("Failed to get address of function %s from injection library\n", initFuncName); + return NVTX_ERR_INIT_MISSING_LIBRARY_ENTRY_POINT; + } + } + } + } +#endif + +#if NVTX_SUPPORT_STATIC_INJECTION_LIBRARY + if (!init_fnptr) + { + /* Check weakly-defined function pointer. A statically-linked injection can define this as + * a normal symbol and it will take precedence over a dynamic injection. */ + if (InitializeInjectionNvtxExtension_fnptr) + { + init_fnptr = InitializeInjectionNvtxExtension_fnptr; + } + } +#endif + + if(out_init_fnptr){ + *out_init_fnptr = init_fnptr; + } + + /* At this point, if init_fnptr is not set, then no tool has specified + * an NVTX injection library -- return non-success result so all NVTX + * API functions will be set to no-ops. */ + if (!init_fnptr) + { + return NVTX_ERR_NO_INJECTION_LIBRARY_AVAILABLE; + } + + return NVTX_SUCCESS; +} + +NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_VERSIONED_IDENTIFIER(nvtxExtInitOnce) ( + nvtxExtModuleInfo_t* moduleInfo, + intptr_t* moduleState + ) +{ + intptr_t old; + + NVTX_INFO( "%s\n", __FUNCTION__ ); + + if( *moduleState == NVTX_EXTENSION_LOADED) { + return; + } + + NVTX_ATOMIC_CAS_PTR( + old, + moduleState, + NVTX_EXTENSION_STARTING, + NVTX_EXTENSION_FRESH); + if (old == NVTX_EXTENSION_FRESH) + { + NvtxExtInitializeInjectionFunc_t init_fnptr = NVTX_VERSIONED_IDENTIFIER(injectionFnPtr); + int entryPointStatus = 0; + int forceAllToNoops = 0; + + /* Load & initialize injection library -- it will assign the function pointers */ + if(init_fnptr == 0){ + int result = 0; + + /* try to load vanilla NVTX first*/ + nvtxInitialize(0); + + result = NVTX_VERSIONED_IDENTIFIER(nvtxExtLoadInjectionLibrary)(&init_fnptr); + /*at this point init_fnptr will be either 0 or a real function*/ + + if(result == NVTX_SUCCESS) { + NVTX_VERSIONED_IDENTIFIER(injectionFnPtr) = init_fnptr; + } + else { + NVTX_ERR("Failed to load injection library\n"); + } + } + + if(init_fnptr != 0) { + /* Invoke injection library's initialization function. If it returns + * 0 (failure) and a dynamic injection was loaded, unload it. */ + entryPointStatus = init_fnptr(moduleInfo); + if (entryPointStatus == 0) { + NVTX_ERR("Failed to initialize injection library -- initialization function returned 0\n"); + } + } + + /* Clean up any functions that are still uninitialized so that they are skipped. + * Set all to null if injection init function failed as well. + */ + forceAllToNoops = (init_fnptr == 0) || (entryPointStatus == 0); + for(size_t s = 0; s < moduleInfo->segmentsCount; ++s){ + nvtxExtModuleSegment_t* segment = moduleInfo->segments+s; + for(size_t i = 0; i < segment->slotCount; ++i){ + if(forceAllToNoops || (segment->functionSlots[i] == NVTX_EXTENSION_FRESH)){ + segment->functionSlots[i] = NVTX_EXTENSION_DISABLED; + } + } + } + + NVTX_MEMBAR(); + + /* Signal that initialization has finished, so now the assigned function pointers will be used */ + NVTX_ATOMIC_WRITE_PTR( + moduleState, + NVTX_EXTENSION_LOADED); + } + else /* Spin-wait until initialization has finished */ + { + NVTX_MEMBAR(); + while (*moduleState != NVTX_EXTENSION_LOADED) + { + NVTX_YIELD(); + NVTX_MEMBAR(); + } + } +} + +#ifdef __cplusplus +} +#endif /* __cplusplus */ diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtPayloadTypeInfo.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtPayloadTypeInfo.h new file mode 100644 index 0000000..c2c1ac5 --- /dev/null +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtPayloadTypeInfo.h @@ -0,0 +1,128 @@ +#ifndef NVTX_EXT_IMPL_PAYLOAD_GUARD +#error Never include this file directly -- it is automatically included by nvToolsExtPayload.h (except when NVTX_NO_IMPL is defined). +#endif + +/* + * Helper array to get the alignment for each predefined C language type. + */ + +typedef void* pointer_type; + +#if __STDC_VERSION__ >= 201112L /* or CPP11 */ +#include +#define nvtx_alignof(type) alignof(type) +#define nvtx_alignof2(type,tname) alignof(type) +#else /* __STDC_VERSION__ >= 201112L */ +#ifndef __cplusplus + +#include +#define nvtx_alignof(type) offsetof(struct {char c; type d;}, d) +#define nvtx_alignof2(type,tname) nvtx_alignof(type) + +#else /* __cplusplus */ + +#define MKTYPEDEF(TYPE) typedef struct {char c; TYPE d;} _nvtx_##TYPE +#define MKTYPEDEF2(TYPE,TNAME) typedef struct {char c; TYPE d;} _nvtx_##TNAME +#define nvtx_alignof(TNAME) offsetof(_nvtx_##TNAME, d) +#define nvtx_alignof2(type,tname) offsetof(_nvtx_##tname, d) + +MKTYPEDEF(char); +MKTYPEDEF2(unsigned char, uchar); +MKTYPEDEF(short); +MKTYPEDEF2(unsigned short, ushort); +MKTYPEDEF(int); +MKTYPEDEF2(unsigned int, uint); +MKTYPEDEF(long); +MKTYPEDEF2(unsigned long, ulong); +MKTYPEDEF2(long long, longlong); +MKTYPEDEF2(unsigned long long, ulonglong); + +MKTYPEDEF(int8_t); +MKTYPEDEF(uint8_t); +MKTYPEDEF(int16_t); +MKTYPEDEF(uint16_t); +MKTYPEDEF(int32_t); +MKTYPEDEF(uint32_t); +MKTYPEDEF(int64_t); +MKTYPEDEF(uint64_t); + +MKTYPEDEF(float); +MKTYPEDEF(double); +MKTYPEDEF2(long double, longdouble); + +MKTYPEDEF(size_t); +MKTYPEDEF(pointer_type); + +MKTYPEDEF(wchar_t); +#if (__STDC_VERSION__ > 201710L) || (defined(__cplusplus) && __cplusplus > 201703L) + {sizeof(char8_t), nvtx_alignof(char8_t)}, + MKTYPEDEF(char8_t); +#endif +#if (__STDC_VERSION__ >= 201112L) || (defined(__cplusplus) && __cplusplus >= 201103L) + MKTYPEDEF(char16_t); + MKTYPEDEF(char32_t); +#endif + +#undef MKTYPEDEF +#undef MKTYPEDEF2 + +#endif /* __cplusplus */ +#endif /* __STDC_VERSION__ >= 201112L */ + +/* + * The order of entries must match the values in`enum nvtxPayloadSchemaEntryType`. + */ +const nvtxPayloadEntryTypeInfo_t nvtxExtPayloadTypeInfo[NVTX_PAYLOAD_ENTRY_TYPE_INFO_ARRAY_SIZE] = +{ + /* The first entry contains this array's length and the size of each entry in this array. */ + {NVTX_PAYLOAD_ENTRY_TYPE_INFO_ARRAY_SIZE, sizeof(nvtxPayloadEntryTypeInfo_t)}, + + /*** C integer types ***/ + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR */ {sizeof(char), nvtx_alignof(char)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UCHAR */ {sizeof(unsigned char), nvtx_alignof2(unsigned char, uchar)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_SHORT */ {sizeof(short), nvtx_alignof(short)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_USHORT */ {sizeof(unsigned short), nvtx_alignof2(unsigned short, ushort)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_INT */ {sizeof(int), nvtx_alignof(int)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UINT */ {sizeof(unsigned int), nvtx_alignof2(unsigned int, uint)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_LONG */ {sizeof(long), nvtx_alignof(long)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_ULONG */ {sizeof(unsigned long), nvtx_alignof2(unsigned long, ulong)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_LONGLONG */ {sizeof(long long), nvtx_alignof2(long long, longlong)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_ULONGLONG */ {sizeof(unsigned long long), nvtx_alignof2(unsigned long long,ulonglong)}, + + /*** Integer types with explicit size ***/ + /* NVTX_PAYLOAD_ENTRY_TYPE_INT8 */ {sizeof(int8_t), nvtx_alignof(int8_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UINT8 */ {sizeof(uint8_t), nvtx_alignof(uint8_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_INT16 */ {sizeof(int16_t), nvtx_alignof(int16_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UINT16 */ {sizeof(uint16_t), nvtx_alignof(uint16_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_INT32 */ {sizeof(int32_t), nvtx_alignof(int32_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UINT32 */ {sizeof(uint32_t), nvtx_alignof(uint32_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_INT64 */ {sizeof(int64_t), nvtx_alignof(int64_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_UINT64 */ {sizeof(uint64_t), nvtx_alignof(uint64_t)}, + + /*** C floating point types ***/ + /* NVTX_PAYLOAD_ENTRY_TYPE_FLOAT */ {sizeof(float), nvtx_alignof(float)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_DOUBLE */ {sizeof(double), nvtx_alignof(double)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_LONGDOUBLE */ {sizeof(long double), nvtx_alignof2(long double, longdouble)}, + + /* NVTX_PAYLOAD_ENTRY_TYPE_SIZE */ {sizeof(size_t), nvtx_alignof(size_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_ADDRESS */ {sizeof(pointer_type), nvtx_alignof(pointer_type)}, + + /*** Special character types ***/ + /* NVTX_PAYLOAD_ENTRY_TYPE_WCHAR */ {sizeof(wchar_t), nvtx_alignof(wchar_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR8 */ +#if (__STDC_VERSION__ > 201710L) || (defined(__cplusplus) && __cplusplus > 201703L) + {sizeof(char8_t), nvtx_alignof(char8_t)}, +#else + {0, 0}, +#endif +#if (__STDC_VERSION__ >= 201112L) || (defined(__cplusplus) && __cplusplus >= 201103L) + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR16 */ {sizeof(char16_t), nvtx_alignof(char16_t)}, + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR32 */ {sizeof(char32_t), nvtx_alignof(char32_t)} +#else + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR16 */ {0, 0}, + /* NVTX_PAYLOAD_ENTRY_TYPE_CHAR32 */ {0, 0} +#endif +}; + +#undef nvtx_alignof +#undef nvtx_alignof2 \ No newline at end of file diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtTypes.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtTypes.h new file mode 100644 index 0000000..bcad095 --- /dev/null +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtTypes.h @@ -0,0 +1,44 @@ +/* +* Copyright 2021 NVIDIA Corporation. All rights reserved. +* +* Licensed under the Apache License v2.0 with LLVM Exceptions. +* See https://llvm.org/LICENSE.txt for license information. +* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +*/ + +/* This header defines types which are used by the internal implementation +* of NVTX and callback subscribers. API clients do not use these types, +* so they are defined here instead of in nvToolsExt.h to clarify they are +* not part of the NVTX client API. */ + +#ifndef NVTXEXTTYPES_H +#define NVTXEXTTYPES_H + +#ifndef NVTX_EXT_TYPES_GUARD +#error Never include this file directly -- it is automatically included by nvToolsExt[EXTENSION].h. +#endif + +typedef intptr_t (NVTX_API * NvtxExtGetExportFunction_t)(uint32_t exportFunctionId); + +typedef struct nvtxExtModuleSegment_t +{ + size_t segmentId; + size_t slotCount; + intptr_t* functionSlots; +} nvtxExtModuleSegment_t; + +typedef struct nvtxExtModuleInfo_t +{ + uint16_t nvtxVer; + uint16_t structSize; + uint16_t moduleId; + uint16_t compatId; + size_t segmentsCount; + nvtxExtModuleSegment_t* segments; + NvtxExtGetExportFunction_t getExportFunction; + const void* extInfo; +} nvtxExtModuleInfo_t; + +typedef int (NVTX_API * NvtxExtInitializeInjectionFunc_t)(nvtxExtModuleInfo_t* moduleInfo); + +#endif /* NVTXEXTTYPES_H */ \ No newline at end of file diff --git a/src/include/proxy.h b/src/include/proxy.h index fa8f388..4c75e21 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -11,6 +11,7 @@ #include "info.h" #include "socket.h" #include +#include "shm.h" enum ncclProxyOpState { ncclProxyOpNone, ncclProxyOpReady, ncclProxyOpProgress }; @@ -106,6 +107,7 @@ struct ncclProxyOpsPool { struct ncclProxyOps { ncclProxyOpsPool* pool; + ncclShmHandle_t handle; int count; int freeOp; int nextOps; @@ -145,6 +147,7 @@ struct ncclProxyPool; struct ncclProxyProgressState { // Used by main threads to send work to progress thread struct ncclProxyOpsPool* opsPool; + ncclShmHandle_t handle; char opsPoolShmSuffix[6]; pthread_t thread; @@ -164,7 +167,6 @@ struct ncclProxyState { struct ncclSocket* listenSock; int stop; CUcontext cudaCtx; - int safeAbortFlag; // Used by main thread union ncclSocketAddress* peerAddresses; @@ -176,6 +178,15 @@ struct ncclProxyState { struct ncclProxyProgressState progressState; }; +enum proxyConnectState { + connUninitialized = 0, + connInitialized = 1, + connSharedInitialized = 2, + connSetupDone = 3, + connConnected = 4, + numConnStates = 5 +}; + struct ncclProxyConnection { int send, transport, shared; int localRank; @@ -184,7 +195,7 @@ struct ncclProxyConnection { struct ncclProxyArgs *proxyAppend; struct ncclProxyArgs **proxyAppendPtr; void* transportResources; - bool initFlag; + proxyConnectState state; }; typedef ncclResult_t (*threadFunc_t)(struct ncclProxyArgs*); diff --git a/src/include/shm.h b/src/include/shm.h index 08dc849..61b0b4d 100644 --- a/src/include/shm.h +++ b/src/include/shm.h @@ -9,7 +9,9 @@ #include "nccl.h" -ncclResult_t ncclShmOpen(char* shmPath, const int shmSize, void** shmPtr, void** devShmPtr, int create); -ncclResult_t ncclShmUnlink(const char* shmname); -ncclResult_t ncclShmClose(void* shmPtr, void* devShmPtr, const int shmSize); +typedef void* ncclShmHandle_t; +ncclResult_t ncclShmOpen(char* shmPath, size_t shmSize, void** shmPtr, void** devShmPtr, int refcount, ncclShmHandle_t* handle); +ncclResult_t ncclShmClose(ncclShmHandle_t handle); +ncclResult_t ncclShmUnlink(ncclShmHandle_t handle); + #endif diff --git a/src/include/socket.h b/src/include/socket.h index d72480b..a0c7a4d 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -21,6 +21,7 @@ #define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec) #define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s) #define SOCKET_NAME_MAXLEN (NI_MAXHOST+NI_MAXSERV) +#define NCCL_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL /* Common socket address storage structure for IPv4/IPv6 */ union ncclSocketAddress { @@ -30,32 +31,59 @@ union ncclSocketAddress { }; enum ncclSocketState { - ncclSocketConnecting = 0, - ncclSocketConnected = 1, - ncclSocketError = 2, - ncclSocketStateNum = 3 -} ; + ncclSocketStateNone = 0, + ncclSocketStateInitialized = 1, + ncclSocketStateAccepting = 2, + ncclSocketStateAccepted = 3, + ncclSocketStateConnecting = 4, + ncclSocketStateConnectPolling = 5, + ncclSocketStateConnected = 6, + ncclSocketStateReady = 7, + ncclSocketStateClosed = 8, + ncclSocketStateError = 9, + ncclSocketStateNum = 10 +}; + +enum ncclSocketType { + ncclSocketTypeUnknown = 0, + ncclSocketTypeBootstrap = 1, + ncclSocketTypeProxy = 2, + ncclSocketTypeNetSocket = 3, + ncclSocketTypeNetIb = 4 +}; struct ncclSocket { int fd; + int acceptFd; + int timedOutRetries; + int refusedRetries; union ncclSocketAddress addr; volatile uint32_t* abortFlag; int asyncFlag; enum ncclSocketState state; + int salen; + uint64_t magic; + enum ncclSocketType type; }; const char *ncclSocketToString(union ncclSocketAddress *addr, char *buf, const int numericHostForm = 1); -ncclResult_t ncclGetSocketAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair); +ncclResult_t ncclSocketGetAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair); int ncclFindInterfaceMatchSubnet(char* ifNames, union ncclSocketAddress* localAddrs, union ncclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs); int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNameMaxSize, int maxIfs); + +// Initialize a socket +ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr = NULL, uint64_t magic = NCCL_SOCKET_MAGIC, enum ncclSocketType type = ncclSocketTypeUnknown, volatile uint32_t* abortFlag = NULL, int asyncFlag = 0); // Create a listening socket. sock->addr can be pre-filled with IP & port info. sock->fd is set after a successful call ncclResult_t ncclSocketListen(struct ncclSocket* sock); +ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr); // Connect to sock->addr. sock->fd is set after a successful call. ncclResult_t ncclSocketConnect(struct ncclSocket* sock); // Return socket connection state. -ncclResult_t ncclGetSocketState(struct ncclSocket* sock, enum ncclSocketState* state); -// Accept an incoming connection from listenSocket->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->addr. -ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listenSocket); +ncclResult_t ncclSocketReady(struct ncclSocket* sock, int *running); +// Accept an incoming connection from listenSock->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->addr. +ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* ulistenSock); +ncclResult_t ncclSocketGetFd(struct ncclSocket* sock, int* fd); +ncclResult_t ncclSocketSetFd(int fd, struct ncclSocket* sock); #define NCCL_SOCKET_SEND 0 #define NCCL_SOCKET_RECV 1 @@ -65,6 +93,5 @@ ncclResult_t ncclSocketWait(int op, struct ncclSocket* sock, void* ptr, int size ncclResult_t ncclSocketSend(struct ncclSocket* sock, void* ptr, int size); ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size); ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed); -/* initialize a socket. */ -ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr = NULL, volatile uint32_t* abortFlag = NULL, int asyncFlag = 0); +ncclResult_t ncclSocketClose(struct ncclSocket* sock); #endif diff --git a/src/include/strongstream.h b/src/include/strongstream.h index 16b6e07..0984dfe 100644 --- a/src/include/strongstream.h +++ b/src/include/strongstream.h @@ -98,16 +98,19 @@ ncclResult_t ncclStrongStreamLaunchKernel( ); // Cause `a` to wait for the current state `b`. Both `a` and `b` must be acquired. +// `b_subsumes_a` indicates that all work in `a` is already present in `b`, thus +// we want to fast-forward `a` to be a clone of `b`. Knowing this permits the +// implementation to induce few graph dependencies. ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b + struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b, bool b_subsumes_a=false ); // `b` must be capturing within `graph`. ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, struct ncclStrongStream* a, cudaStream_t b + struct ncclCudaGraph graph, struct ncclStrongStream* a, cudaStream_t b, bool b_subsumes_a=false ); // `a` must be capturing within `graph`. ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, cudaStream_t a, struct ncclStrongStream* b + struct ncclCudaGraph graph, cudaStream_t a, struct ncclStrongStream* b, bool b_subsumes_a=false ); // Synchrnoization does not need the strong stream to be acquired. diff --git a/src/include/utils.h b/src/include/utils.h index 0604d15..1c300b0 100644 --- a/src/include/utils.h +++ b/src/include/utils.h @@ -27,6 +27,7 @@ ncclResult_t getHostName(char* hostname, int maxlen, const char delim); uint64_t getHash(const char* string, int n); uint64_t getHostHash(); uint64_t getPidHash(); +ncclResult_t getRandomData(void* buffer, size_t bytes); struct netIf { char prefix[64]; @@ -48,6 +49,19 @@ inline uint64_t clockNano() { return uint64_t(ts.tv_sec)*1000*1000*1000 + ts.tv_nsec; } +/* get any bytes of random data from /dev/urandom, return 0 if it succeeds; else + * return -1 */ +inline ncclResult_t getRandomData(void* buffer, size_t bytes) { + ncclResult_t ret = ncclSuccess; + if (bytes > 0) { + const size_t one = 1UL; + FILE* fp = fopen("/dev/urandom", "r"); + if (buffer == NULL || fp == NULL || fread(buffer, bytes, one, fp) != one) ret = ncclSystemError; + if (fp) fclose(fp); + } + return ret; +} + //////////////////////////////////////////////////////////////////////////////// template diff --git a/src/init.cc b/src/init.cc index ab0a064..91a8793 100644 --- a/src/init.cc +++ b/src/init.cc @@ -87,6 +87,7 @@ static ncclResult_t ncclInit() { NCCLCHECK(bootstrapNetInit()); NCCLCHECK(ncclNetPluginInit()); + initNvtxRegisteredEnums(); __atomic_store_n(&initialized, true, __ATOMIC_RELEASE); } pthread_mutex_unlock(&initLock); @@ -104,7 +105,7 @@ NCCL_API(ncclResult_t, ncclGetUniqueId, ncclUniqueId* out); ncclResult_t ncclGetUniqueId(ncclUniqueId* out) { NCCLCHECK(ncclInit()); NCCLCHECK(PtrCheck(out, "GetUniqueId", "out")); - ncclResult_t res = bootstrapGetUniqueId(out); + ncclResult_t res = bootstrapGetUniqueId((struct ncclBootstrapHandle*)out); TRACE_CALL("ncclGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(*out)); return res; } @@ -117,7 +118,7 @@ ncclResult_t ncclGetUniqueId(ncclUniqueId* out) { #endif void NCCL_NO_OPTIMIZE commPoison(ncclComm_t comm) { - // Important that this does not trash intraComm0 & intraRefs. + // Important that this does not trash intraComm0. comm->rank = comm->cudaDev = comm->busId = comm->nRanks = -1; } @@ -173,11 +174,15 @@ void ncclCommPushCudaGdrFree(struct ncclComm* comm, void* handle) { } static ncclResult_t commFree(ncclComm_t comm) { + /* commFree() should not involve any sync among ranks. */ if (comm == NULL) return ncclSuccess; - // Stop all threads before we free anything. - NCCLCHECK(ncclProxyDestroy(comm)); + /* in commReclaim, we have guaranteed only last rank which calls ncclCommDestroy() will + * free all intra-process communicators; therefore, we only need to focus on local + * resource cleanup in commFree(). */ + if (comm->proxyState.thread) + pthread_join(comm->proxyState.thread, nullptr); delete[] comm->userRedOps; @@ -214,30 +219,10 @@ static ncclResult_t commFree(ncclComm_t comm) { ncclMemoryStackDestruct(&comm->memScoped); ncclMemoryStackDestruct(&comm->memPermanent); - commPoison(comm); // Important that this does not interfere with anything used below. + ncclCudaHostFree((void *)comm->abortFlag); - if (comm->initState == ncclSuccess) { - struct ncclComm* intraComm0 = comm->intraComm0; - if (0 == ncclAtomicRefCountDecrement(&intraComm0->intraRefs)) { - // Wait for all service threads to be done. We could not - // do it earlier because it could have blocked and prevented - // other ranks in the process to call ncclCommDestroy - comm = intraComm0; - while (comm != nullptr) { - if (comm->proxyState.thread) pthread_join(comm->proxyState.thread, nullptr); - struct ncclComm* next = comm->intraNext; - free(comm); - comm = next; - } - } - } else if (comm->proxyState.thread) { - pthread_join(comm->proxyState.thread, nullptr); - ncclCudaHostFree((void *)comm->abortFlag); - free(comm); - } else { - ncclCudaHostFree((void *)comm->abortFlag); - free(comm); - } + commPoison(comm); // poison comm before free to avoid comm reuse. + free(comm); return ncclSuccess; } @@ -253,7 +238,7 @@ NCCL_PARAM(DmaBufEnable, "DMABUF_ENABLE", 1); // Detect DMA-BUF support static ncclResult_t dmaBufSupported(struct ncclComm* comm) { - if (ncclParamDmaBufEnable() == 0 || comm->ncclNet->regMrDmaBuf == NULL) return ncclInternalError; + if (ncclParamDmaBufEnable() == 0 || comm->ncclNet->regMrDmaBuf == NULL || ncclCudaLibraryInit() != ncclSuccess) return ncclInternalError; #if CUDA_VERSION >= 11070 int flag = 0; CUdevice dev; @@ -330,7 +315,8 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { cudaGetDevice(&comm->cudaDev); NCCLCHECK(getBusId(comm->cudaDev, &comm->busId)); - TRACE(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx", comm, rank, ndev, comm->cudaDev, comm->busId); + comm->compCap = ncclCudaCompCap(); + TRACE(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx compCap %d", comm, rank, ndev, comm->cudaDev, comm->busId, comm->compCap); comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false; comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false; @@ -360,11 +346,13 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { } static ncclResult_t devCommSetup(ncclComm_t comm) { - NCCLCHECK(ncclStrongStreamAcquireUncaptured(&comm->deviceStream)); - + ncclResult_t ret = ncclSuccess; int nRanks = comm->nRanks; - struct ncclDevCommAndChannels *devCommAndChans, tmpCommAndChans; - NCCLCHECK(ncclCudaCallocAsync(&devCommAndChans, 1, comm->deviceStream.cudaStream)); + struct ncclDevCommAndChannels tmpCommAndChans; + struct ncclDevCommAndChannels *devCommAndChans = NULL; + + NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->deviceStream), ret, fail); + NCCLCHECKGOTO(ncclCudaCallocAsync(&devCommAndChans, 1, comm->deviceStream.cudaStream), ret, fail); ncclCommPushCudaFree(comm, devCommAndChans); comm->devComm = &devCommAndChans->comm; tmpCommAndChans.comm.rank = comm->rank; @@ -384,18 +372,18 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { if (ncclGdrCopy != NULL && ncclParamGdrCopyFifoEnable() == 1) { // The workFifoHeap lives in GDR mapped CUDA memory. - NCCLCHECK(ncclGdrCudaCalloc(&comm->workFifoHeap, &comm->devWorkFifoHeap, comm->workFifoDepth, &comm->workFifoHeapGdrHandle)); + NCCLCHECKGOTO(ncclGdrCudaCalloc(&comm->workFifoHeap, &comm->devWorkFifoHeap, comm->workFifoDepth, &comm->workFifoHeapGdrHandle), ret, fail); ncclCommPushCudaGdrFree(comm, comm->workFifoHeapGdrHandle); } else { // The workFifoHeap lives in cudaHost memory. comm->workFifoHeapGdrHandle = nullptr; - NCCLCHECK(ncclCudaHostCalloc(&comm->workFifoHeap, comm->workFifoDepth)); + NCCLCHECKGOTO(ncclCudaHostCalloc(&comm->workFifoHeap, comm->workFifoDepth), ret, fail); ncclCommPushCudaHostFree(comm, comm->workFifoHeap); comm->devWorkFifoHeap = comm->workFifoHeap; } tmpCommAndChans.comm.workFifoHeap = comm->devWorkFifoHeap; - NCCLCHECK(ncclCudaHostCalloc(&comm->workFifoDone, MAXCHANNELS)); + NCCLCHECKGOTO(ncclCudaHostCalloc(&comm->workFifoDone, MAXCHANNELS), ret, fail); ncclCommPushCudaHostFree(comm, comm->workFifoDone); comm->workFifoSent = 0; comm->workFifoAckdMin = 0; @@ -410,14 +398,17 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { tmpCommAndChans.channels[c].workFifoDone = &comm->workFifoDone[c]; if (comm->channels[c].ring.userRanks != nullptr) { - NCCLCHECK(ncclCudaMemcpyAsync(tmpCommAndChans.channels[c].ring.userRanks, comm->channels[c].ring.userRanks, nRanks, comm->deviceStream.cudaStream)); + NCCLCHECKGOTO(ncclCudaMemcpyAsync(tmpCommAndChans.channels[c].ring.userRanks, comm->channels[c].ring.userRanks, nRanks, comm->deviceStream.cudaStream), ret, fail); } } - NCCLCHECK(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->deviceStream.cudaStream)); + NCCLCHECKGOTO(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->deviceStream.cudaStream), ret, fail); +exit: CUDACHECK(cudaStreamSynchronize(comm->deviceStream.cudaStream)); NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream)); - return ncclSuccess; + return ret; +fail: + goto exit; } // Pre-process the string so that running "strings" on the lib can quickly reveal the version. @@ -481,6 +472,8 @@ NCCL_PARAM(LlBuffSize, "LL_BUFFSIZE", -2); NCCL_PARAM(Ll128BuffSize, "LL128_BUFFSIZE", -2); NCCL_PARAM(P2pNetChunkSize, "P2P_NET_CHUNKSIZE", (1 << 17)); /* 128 kB */ +NCCL_PARAM(P2pPciChunkSize, "P2P_PCI_CHUNKSIZE", (1 << 17)); /* 128 kB */ +NCCL_PARAM(P2pNvlChunkSize, "P2P_NVL_CHUNKSIZE", (1 << 19)); /* 512 kB */ static ncclResult_t computeBuffSizes(struct ncclComm* comm) { int cpuArch, cpuVendor, cpuModel; @@ -495,7 +488,10 @@ static ncclResult_t computeBuffSizes(struct ncclComm* comm) { comm->buffSizes[p] = envs[p] != -2 ? envs[p] : defaults[p]; } - comm->p2pNetChunkSize = ncclParamP2pNetChunkSize(); + if (comm->nNodes > 1) comm->p2pChunkSize = ncclParamP2pNetChunkSize(); + else if (ncclTopoPathAllNVLink(comm->topo)) comm->p2pChunkSize = ncclParamP2pNvlChunkSize(); + else comm->p2pChunkSize = ncclParamP2pPciChunkSize(); + INFO(NCCL_INIT, "P2P Chunksize set to %d", comm->p2pChunkSize); return ncclSuccess; } @@ -504,90 +500,241 @@ NCCL_PARAM(CollNetNodeThreshold, "COLLNET_NODE_THRESHOLD", 2); NCCL_PARAM(NvbPreconnect, "NVB_PRECONNECT", 1); NCCL_PARAM(AllocP2pNetLLBuffers, "NCCL_ALLOC_P2P_NET_LL_BUFFERS", 0); +static ncclResult_t collNetTrySetup(ncclComm_t comm, struct ncclTopoGraph* collNetGraph) { + ncclResult_t ret = ncclSuccess; + int* heads = NULL; + int rank = comm->rank; + int collNetSetupFail = 0; + int highestTypes[NCCL_MAX_LOCAL_RANKS] = { TRANSPORT_P2P }; + // Find all head ranks + int nHeads = collNetGraph->nChannels; + int highestTransportType0, highestTransportType1; + char line[1024]; + + NCCLCHECKGOTO(ncclCalloc(&heads, nHeads), ret, fail); + // Head GPU index is always 0 + for (int c = 0; c < nHeads; c++) { + heads[c] = collNetGraph->intra[c * comm->localRanks + 0]; + } + + for (int c = 0; c < comm->nChannels; c++) { + struct ncclChannel* channel = comm->channels + c; + for (int h = 0; h < nHeads; h++) { + const int head = heads[h]; + collNetSetupFail = ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, collNetRecv); + if (!collNetSetupFail) collNetSetupFail = ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, collNetSend); + } + // Verify CollNet setup across ranks after trying the first channel + if (c == 0) { + NCCLCHECKGOTO(ncclTransportCollNetCheck(comm, collNetSetupFail), ret, fail); + } + } + // Verify CollNet setup across ranks after trying all channels + NCCLCHECKGOTO(ncclTransportCollNetCheck(comm, collNetSetupFail), ret, fail); + TRACE(NCCL_INIT, "rank %d Connected inter-node CollNet", rank); + + line[0] = '\0'; + for (int c = 0; c < comm->nChannels; c++) { + struct ncclTree* chain = &comm->channels[c].collnetChain; + snprintf(line + strlen(line), 1023 - strlen(line), " [%d] %d->%d->%d", + c, chain->down[0], rank, chain->up); + } + line[1023] = '\0'; + + INFO(NCCL_INIT, "Collnet Chains %s", line); + // Connect Collnet + chain + for (int c = 0; c < comm->nChannels; c++) { + struct ncclChannel* channel = comm->channels + c; + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->collnetChain.up, 1, channel->collnetChain.down, 0), ret, fail); + } + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, collNetGraph, 0), ret, fail); + for (int c = 0; c < comm->nChannels; c++) { + struct ncclChannel* channel = comm->channels + c; + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, channel->collnetChain.down, 1, &channel->collnetChain.up, 1), ret, fail); + } + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, collNetGraph, 1), ret, fail); + INFO(NCCL_INIT, "Connected collnet + chain"); + + // Connect intra-node CollNet + Direct + for (int c = 0; c < comm->nChannels; c++) { + struct ncclChannel* channelRecv = comm->channels + c; + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_DIRECT_ARITY, channelRecv->collnetDirect.up, NCCL_MAX_DIRECT_ARITY, channelRecv->collnetDirect.down, 0), ret, fail); + } + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, collNetGraph, 0, &highestTransportType0), ret, fail); + + for (int c = 0; c < comm->nChannels; c++) { + struct ncclChannel* channelSend = comm->channels + c; + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_DIRECT_ARITY, channelSend->collnetDirect.down, NCCL_MAX_DIRECT_ARITY, channelSend->collnetDirect.up, 1), ret, fail); + } + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, collNetGraph, 1, &highestTransportType1), ret, fail); + + // Exchange highest intra-node transport type among ranks + // because we need to know whether all ranks can p2p each other to determine whether we can directly read/write registered user buffer + comm->intraHighestTransportType = highestTypes[comm->localRank] = highestTransportType0 > highestTransportType1 ? highestTransportType0 : highestTransportType1; + NCCLCHECKGOTO(bootstrapIntraNodeAllGather(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, highestTypes, sizeof(int)), ret, fail); + for (int i = 0; i < comm->localRanks; i++) { + if (highestTypes[i] > comm->intraHighestTransportType) + comm->intraHighestTransportType = highestTypes[i]; + } + + INFO(NCCL_INIT, "rank %d Connected CollNet", rank); + +exit: + free(heads); + return ret; +fail: + ncclTransportCollNetFree(comm); + comm->collNetSupport = 0; + goto exit; +} + static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* commId) { // We use 2 AllGathers // 1. { peerInfo, comm, compCap} // 2. { nChannels, graphInfo, topoRanks } - + ncclResult_t ret = ncclSuccess; int rank = comm->rank; int nranks = comm->nRanks; uint64_t commHash = getHash(commId->internal, NCCL_UNIQUE_ID_BYTES); + cpu_set_t affinitySave; + struct ncclTopoGraph ringGraph; + struct ncclTopoGraph treeGraph; + struct ncclTopoGraph collNetGraph; + + struct graphInfo { + int pattern; + int nChannels; + int sameChannels; + float bwIntra; + float bwInter; + int typeIntra; + int typeInter; + }; + + struct allGatherInfo { + int netDev; + int collNetSupport; + struct graphInfo tree; + struct graphInfo ring; + struct graphInfo collNet; + struct ncclTopoRanks topoRanks; + }; + + int nChannelsOrig; + struct allGatherInfo *allGather3Data = NULL; + struct ncclTopoRanks** allTopoRanks = NULL; + int *nodesFirstRank = NULL, *nodesTreePatterns = NULL; + int *rings = NULL; + int* nvbPeers = NULL; + struct ncclProxyConnector proxyConn; + int* pxnPeers = NULL; + TRACE(NCCL_INIT, "comm %p, commHash %lx, rank %d nranks %d - BEGIN", comm, commHash, rank, nranks); - NCCLCHECK(bootstrapInit(commId, comm)); + NCCLCHECKGOTO(bootstrapInit((struct ncclBootstrapHandle*)commId, comm), ret, fail); // AllGather1 - begin - NCCLCHECK(ncclCalloc(&comm->peerInfo, nranks+1)); // Extra rank to represent CollNet root - NCCLCHECK(fillInfo(comm, comm->peerInfo+rank, commHash)); - NCCLCHECK(bootstrapAllGather(comm->bootstrap, comm->peerInfo, sizeof(struct ncclPeerInfo))); + NCCLCHECKGOTO(ncclCalloc(&comm->peerInfo, nranks+1), ret, fail); // Extra rank to represent CollNet root + NCCLCHECKGOTO(fillInfo(comm, comm->peerInfo+rank, commHash), ret, fail); + NCCLCHECKGOTO(bootstrapAllGather(comm->bootstrap, comm->peerInfo, sizeof(struct ncclPeerInfo)), ret, fail); for (int i = 0; i < nranks; i++) { if ((i != rank) && (comm->peerInfo[i].hostHash == comm->peerInfo[rank].hostHash) && (comm->peerInfo[i].busId == comm->peerInfo[rank].busId)) { WARN("Duplicate GPU detected : rank %d and rank %d both on CUDA device %lx", rank, i, comm->peerInfo[rank].busId); - return ncclInvalidUsage; + ret = ncclInvalidUsage; + goto fail; } } - // AllGather1 - end + do { + // Compute intra-process ranks + int intraProcRank0 = -1, intraProcRank = -1, intraProcRanks = 0; + for (int i = 0; i < nranks; i++) { + if ((comm->peerInfo[i].hostHash == comm->peerInfo[rank].hostHash) + && (comm->peerInfo[i].pidHash == comm->peerInfo[rank].pidHash)) { + // Rank is in same process + if (intraProcRanks == 0) intraProcRank0 = i; + if (i == rank) intraProcRank = intraProcRanks; + intraProcRanks++; + if (intraProcRank0 == rank && rank != i) { + comm->peerInfo[i].comm->intraNext = comm->intraNext; + comm->intraNext = comm->peerInfo[i].comm; + } + } + } + TRACE(NCCL_INIT,"pidHash[%d] %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d", + rank, comm->peerInfo[rank].pidHash, intraProcRank, intraProcRanks, intraProcRank0); + if (intraProcRank == -1 || intraProcRank0 == -1 || comm->peerInfo[intraProcRank0].comm == NULL) { + WARN("Failed to determine intra proc ranks rank %d hostHash %lx pidHash %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d", + rank, comm->peerInfo[rank].hostHash, comm->peerInfo[rank].pidHash, + intraProcRank, intraProcRanks, intraProcRank0); + ret = ncclInternalError; + goto fail; + } + struct ncclComm* comm0 = comm->peerInfo[intraProcRank0].comm; + assert(intraProcRank==0 ? comm==comm0 : true); + comm->intraComm0 = comm0; + comm->intraRank = intraProcRank; + comm->intraRanks = intraProcRanks; + comm->intraBarrierPhase = 0; + comm->intraBarrierCounter = 0; + comm->intraBarrierGate = 0; + } while(0); + // Topo detection / System graph creation - NCCLCHECK(ncclTopoGetSystem(comm, &comm->topo)); + NCCLCHECKGOTO(ncclTopoGetSystem(comm, &comm->topo), ret, fail); // Compute paths between GPUs and NICs - NCCLCHECK(ncclTopoComputePaths(comm->topo, comm)); + NCCLCHECKGOTO(ncclTopoComputePaths(comm->topo, comm), ret, fail); // Remove inaccessible GPUs and unused NICs - NCCLCHECK(ncclTopoTrimSystem(comm->topo, comm)); + NCCLCHECKGOTO(ncclTopoTrimSystem(comm->topo, comm), ret, fail); // Recompute paths after trimming - NCCLCHECK(ncclTopoComputePaths(comm->topo, comm)); + NCCLCHECKGOTO(ncclTopoComputePaths(comm->topo, comm), ret, fail); // Init search - NCCLCHECK(ncclTopoSearchInit(comm->topo)); + NCCLCHECKGOTO(ncclTopoSearchInit(comm->topo), ret, fail); // Print final topology - NCCLCHECK(ncclTopoPrint(comm->topo)); + NCCLCHECKGOTO(ncclTopoPrint(comm->topo), ret, fail); // Set Affinity to a CPU local the our GPU, so that all memory we allocate // on the host is local. - NCCLCHECK(ncclTopoGetCpuAffinity(comm->topo, comm->rank, &comm->cpuAffinity)); - cpu_set_t affinitySave; + NCCLCHECKGOTO(ncclTopoGetCpuAffinity(comm->topo, comm->rank, &comm->cpuAffinity), ret, fail); if (CPU_COUNT(&comm->cpuAffinity)) { sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave); sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity); } - ncclResult_t ret; // Launch proxy service thread - NCCLCHECK(ncclProxyCreate(comm)); + NCCLCHECKGOTO(ncclProxyCreate(comm), ret, fail); // Get rings and trees - struct ncclTopoGraph ringGraph; ringGraph.id = 0; ringGraph.pattern = NCCL_TOPO_PATTERN_RING; ringGraph.collNet = 0; ringGraph.minChannels = 1; ringGraph.maxChannels = MAXCHANNELS/2; - NCCLCHECK(ncclTopoCompute(comm->topo, &ringGraph)); - NCCLCHECK(ncclTopoPrintGraph(comm->topo, &ringGraph)); + NCCLCHECKGOTO(ncclTopoCompute(comm->topo, &ringGraph), ret, fail); + NCCLCHECKGOTO(ncclTopoPrintGraph(comm->topo, &ringGraph), ret, fail); - struct ncclTopoGraph treeGraph; treeGraph.id = 1; treeGraph.pattern = NCCL_TOPO_PATTERN_BALANCED_TREE; treeGraph.collNet = 0; treeGraph.minChannels = 1; treeGraph.maxChannels = ringGraph.nChannels; - NCCLCHECK(ncclTopoCompute(comm->topo, &treeGraph)); - NCCLCHECK(ncclTopoPrintGraph(comm->topo, &treeGraph)); + NCCLCHECKGOTO(ncclTopoCompute(comm->topo, &treeGraph), ret, fail); + NCCLCHECKGOTO(ncclTopoPrintGraph(comm->topo, &treeGraph), ret, fail); - struct ncclTopoGraph collNetGraph; collNetGraph.id = 2; collNetGraph.pattern = NCCL_TOPO_PATTERN_TREE; collNetGraph.collNet = 1; collNetGraph.minChannels = collNetGraph.maxChannels = ringGraph.nChannels; - NCCLCHECK(ncclTopoCompute(comm->topo, &collNetGraph)); - NCCLCHECK(ncclTopoPrintGraph(comm->topo, &collNetGraph)); + NCCLCHECKGOTO(ncclTopoCompute(comm->topo, &collNetGraph), ret, fail); + NCCLCHECKGOTO(ncclTopoPrintGraph(comm->topo, &collNetGraph), ret, fail); // Initialize num P2P LL buffers for this communicator comm->allocP2pNetLLBuffers = ncclParamAllocP2pNetLLBuffers() == 1; if (comm->rank == ncclParamGraphDumpFileRank()) { struct ncclTopoGraph* graphs[3] = { &ringGraph, &treeGraph, &collNetGraph }; - NCCLCHECK(ncclTopoDumpGraphs(comm->topo, 3, graphs)); + NCCLCHECKGOTO(ncclTopoDumpGraphs(comm->topo, 3, graphs), ret, fail); } // Determine local CollNet support before all-gather @@ -603,28 +750,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm if (comm->collNetSupport == 1 && collNetGraph.nChannels <= 0) comm->collNetSupport = 0; // AllGather3 - begin - struct ncclGraphInfo { - int pattern; - int nChannels; - int sameChannels; - float bwIntra; - float bwInter; - int typeIntra; - int typeInter; - }; - - struct { - int netDev; - int collNetSupport; - struct ncclGraphInfo tree; - struct ncclGraphInfo ring; - struct ncclGraphInfo collNet; - struct ncclTopoRanks topoRanks; - } *allGather3Data; - - NCCLCHECK(ncclCalloc(&allGather3Data, nranks)); - - NCCLCHECK(ncclTopoGetLocalNet(comm->topo, rank, &allGather3Data[rank].netDev)); + NCCLCHECKGOTO(ncclCalloc(&allGather3Data, nranks), ret, fail); + NCCLCHECKGOTO(ncclTopoGetLocalNet(comm->topo, rank, &allGather3Data[rank].netDev), ret, fail); allGather3Data[rank].tree.pattern = treeGraph.pattern; allGather3Data[rank].tree.nChannels = treeGraph.nChannels; allGather3Data[rank].tree.sameChannels = treeGraph.sameChannels; @@ -649,15 +776,14 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm allGather3Data[rank].collNetSupport = comm->collNetSupport; comm->nChannels = std::min(treeGraph.nChannels, ringGraph.nChannels); - NCCLCHECK(ncclTopoPreset(comm, &treeGraph, &ringGraph, &collNetGraph, &allGather3Data[rank].topoRanks)); + NCCLCHECKGOTO(ncclTopoPreset(comm, &treeGraph, &ringGraph, &collNetGraph, &allGather3Data[rank].topoRanks), ret, fail); - NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGather3Data, sizeof(*allGather3Data))); + NCCLCHECKGOTO(bootstrapAllGather(comm->bootstrap, allGather3Data, sizeof(*allGather3Data)), ret, fail); // Determine nNodes, firstRanks, ... - int *nodesFirstRank, *nodesTreePatterns; - NCCLCHECK(ncclCalloc(&nodesFirstRank, nranks)); - NCCLCHECK(ncclCalloc(&nodesTreePatterns, nranks)); - NCCLCHECK(ncclCalloc(&comm->rankToNode, comm->nRanks)); + NCCLCHECKGOTO(ncclCalloc(&nodesFirstRank, nranks), ret, fail); + NCCLCHECKGOTO(ncclCalloc(&nodesTreePatterns, nranks), ret, fail); + NCCLCHECKGOTO(ncclCalloc(&comm->rankToNode, comm->nRanks), ret, fail); for (int r=0; rrankToNode[r] = node; } // Now that we know nNodes, alloc nodeRanks and compute localRanks for each node - NCCLCHECK(ncclCalloc(&comm->nodeRanks, comm->nNodes)); - NCCLCHECK(ncclCalloc(&comm->rankToLocalRank, comm->nRanks)); + NCCLCHECKGOTO(ncclCalloc(&comm->nodeRanks, comm->nNodes), ret, fail); + NCCLCHECKGOTO(ncclCalloc(&comm->rankToLocalRank, comm->nRanks), ret, fail); for (int r=0; rnRanks; r++) { int node = comm->rankToNode[r]; comm->rankToLocalRank[r] = comm->nodeRanks[node].localRanks; @@ -680,7 +806,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm } // Allocate ranks arrays for each node for (int n=0; nnNodes; n++) { - NCCLCHECK(ncclCalloc(&comm->nodeRanks[n].localRankToRank, comm->nodeRanks[n].localRanks)); + NCCLCHECKGOTO(ncclCalloc(&comm->nodeRanks[n].localRankToRank, comm->nodeRanks[n].localRanks), ret, fail); comm->maxLocalRanks = std::max(comm->maxLocalRanks, comm->nodeRanks[n].localRanks); comm->nodeRanks[n].localRanks = 0; } @@ -700,12 +826,12 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm WARN("Failed to determine local ranks rank %d hostHash %lx pidHash %lx localRank %d localRanks %d localRank0 %d", rank, comm->peerInfo[rank].hostHash, comm->peerInfo[rank].pidHash, comm->localRank, comm->localRanks, comm->localRankToRank[0]); - return ncclInternalError; + ret = ncclInternalError; + goto fail; } - int nChannelsOrig = comm->nChannels; - struct ncclTopoRanks** allTopoRanks; - NCCLCHECK(ncclCalloc(&allTopoRanks, comm->nRanks)); + nChannelsOrig = comm->nChannels; + NCCLCHECKGOTO(ncclCalloc(&allTopoRanks, comm->nRanks), ret, fail); for (int i=0; ipeerInfo[i].netDev = allGather3Data[i].netDev; allTopoRanks[i] = &allGather3Data[i].topoRanks; @@ -754,15 +880,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm } } - int *rings; - NCCLCHECK(ncclCalloc(&rings, nranks*MAXCHANNELS)); - NCCLCHECK(ncclTopoPostset(comm, nodesFirstRank, nodesTreePatterns, allTopoRanks, rings, &collNetGraph)); - - free(allTopoRanks); - free(nodesTreePatterns); - free(nodesFirstRank); - free(allGather3Data); - + NCCLCHECKGOTO(ncclCalloc(&rings, nranks*MAXCHANNELS), ret, fail); + NCCLCHECKGOTO(ncclTopoPostset(comm, nodesFirstRank, nodesTreePatterns, allTopoRanks, rings, &collNetGraph), ret, fail); // AllGather3 - end TRACE(NCCL_INIT, "rank %d nranks %d - BUILT %d TREES/RINGS", rank, nranks, comm->nChannels); @@ -778,110 +897,31 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm line[1023] = '\0'; INFO(NCCL_INIT, "Trees%s", line); - NCCLCHECK(computeBuffSizes(comm)); + NCCLCHECKGOTO(computeBuffSizes(comm), ret, fail); // Connect with prev/next for each ring for (int c=0; cnChannels; c++) { struct ncclChannel* channel = comm->channels+c; - NCCLCHECKGOTO(setupChannel(comm, c, rank, nranks, rings+c*nranks), ret, affinity_restore); + NCCLCHECKGOTO(setupChannel(comm, c, rank, nranks, rings+c*nranks), ret, fail); if (comm->nRanks == 1) continue; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->ring.prev, 1, &channel->ring.next, 0), ret, affinity_restore); + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->ring.prev, 1, &channel->ring.next, 0), ret, fail); } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &ringGraph, 0), ret, affinity_restore); - free(rings); + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &ringGraph, 0), ret, fail); INFO(NCCL_INIT, "Connected all rings"); // Connect Trees for (int c=0; cnChannels; c++) { struct ncclChannel* channel = comm->channels+c; if (comm->nRanks == 1) continue; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_TREE_ARITY, channel->tree.down, 1, &channel->tree.up, 0), ret, affinity_restore); - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->tree.up, NCCL_MAX_TREE_ARITY, channel->tree.down, 0), ret, affinity_restore); + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_TREE_ARITY, channel->tree.down, 1, &channel->tree.up, 0), ret, fail); + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->tree.up, NCCL_MAX_TREE_ARITY, channel->tree.down, 0), ret, fail); } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &treeGraph, 0), ret, affinity_restore); + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &treeGraph, 0), ret, fail); INFO(NCCL_INIT, "Connected all trees"); // Check if we can setup CollNet - if (comm->collNetSupport > 0) { - int collNetSetupFail = 0; - int highestTypes[NCCL_MAX_LOCAL_RANKS] = {TRANSPORT_P2P}; - // Find all head ranks - int nHeads = collNetGraph.nChannels; - int *heads; - NCCLCHECK(ncclCalloc(&heads, nHeads)); - // Head GPU index is always 0 - for (int c=0; clocalRanks+0]; - } - for (int c=0; cnChannels; c++) { - struct ncclChannel* channel = comm->channels+c; - for (int h=0; hcollNetSupport > 0) collNetTrySetup(comm, &collNetGraph); - char line[1024]; - line[0]='\0'; - for (int c=0; cnChannels; c++) { - struct ncclTree* chain = &comm->channels[c].collnetChain; - snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d->%d->%d", - c, chain->down[0], rank, chain->up); - } - line[1023] = '\0'; - INFO(NCCL_INIT, "Collnet Chains %s", line); - // Connect Collnet + chain - for (int c=0; cnChannels; c++) { - struct ncclChannel* channel = comm->channels+c; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->collnetChain.up, 1, channel->collnetChain.down, 0), ret, collnet_cleanup); - } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 0), ret, collnet_cleanup); - for (int c=0; cnChannels; c++) { - struct ncclChannel* channel = comm->channels+c; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, channel->collnetChain.down, 1, &channel->collnetChain.up, 1), ret, collnet_cleanup); - } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 1), ret, collnet_cleanup); - INFO(NCCL_INIT, "Connected collnet + chain"); - - // Connect intra-node CollNet + Direct - int highestTransportType0, highestTransportType1; - for (int c=0; cnChannels; c++) { - struct ncclChannel* channelRecv = comm->channels+c; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_DIRECT_ARITY, channelRecv->collnetDirect.up, NCCL_MAX_DIRECT_ARITY, channelRecv->collnetDirect.down, 0), ret, collnet_cleanup); - } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 0, &highestTransportType0), ret, collnet_cleanup); - for (int c=0; cnChannels; c++) { - struct ncclChannel* channelSend = comm->channels+c; - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_DIRECT_ARITY, channelSend->collnetDirect.down, NCCL_MAX_DIRECT_ARITY, channelSend->collnetDirect.up, 1), ret, collnet_cleanup); - } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 1, &highestTransportType1), ret, collnet_cleanup); - - // Exchange highest intra-node transport type among ranks - // because we need to know whether all ranks can p2p each other to determine whether we can directly read/write registered user buffer - comm->intraHighestTransportType = highestTypes[comm->localRank] = highestTransportType0 > highestTransportType1 ? highestTransportType0 : highestTransportType1; - NCCLCHECK(bootstrapIntraNodeAllGather(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, highestTypes, sizeof(int))); - for (int i=0; ilocalRanks; i++) { - if (highestTypes[i] > comm->intraHighestTransportType) - comm->intraHighestTransportType = highestTypes[i]; - } - INFO(NCCL_INIT, "rank %d Connected CollNet", rank); - -collnet_cleanup: - free(heads); - if (ret != ncclSuccess) { - NCCLCHECK(ncclTransportCollNetFree(comm)); - comm->collNetSupport = 0; - ret = ncclSuccess; - } - } TRACE(NCCL_INIT, "rank %d nranks %d - CONNECTED %d RINGS AND TREES", rank, nranks, comm->nChannels); // Compute time models for algorithm and protocol combinations @@ -892,11 +932,11 @@ collnet_cleanup: minCompCap = std::min(comm->peerInfo[i].cudaCompCap, minCompCap); maxCompCap = std::max(comm->peerInfo[i].cudaCompCap, maxCompCap); } - NCCLCHECK(ncclTopoTuneModel(comm, minCompCap, maxCompCap, &treeGraph, &ringGraph, &collNetGraph)); + NCCLCHECKGOTO(ncclTopoTuneModel(comm, minCompCap, maxCompCap, &treeGraph, &ringGraph, &collNetGraph), ret, fail); } while(0); // Compute nChannels per peer for p2p - NCCLCHECK(ncclTopoComputeP2pChannels(comm)); + NCCLCHECKGOTO(ncclTopoComputeP2pChannels(comm), ret, fail); do { // Setup p2p structures in comm->tasks struct ncclTasks* tasks = &comm->tasks; @@ -947,80 +987,40 @@ collnet_cleanup: if (ncclParamNvbPreconnect()) { // Connect p2p when using NVB path int nvbNpeers; - int* nvbPeers; - NCCLCHECK(ncclTopoGetNvbGpus(comm->topo, comm->rank, &nvbNpeers, &nvbPeers)); + NCCLCHECKGOTO(ncclTopoGetNvbGpus(comm->topo, comm->rank, &nvbNpeers, &nvbPeers), ret, fail); for (int r=0; rp2pnChannelsPerPeer; c++) { - NCCLCHECK(ncclChannelCompute(comm, peer, c, ncclFuncSend, &channelId)); + NCCLCHECKGOTO(ncclChannelCompute(comm, peer, c, ncclFuncSend, &channelId), ret, fail); if (comm->channels[channelId].peers[peer].send[1].connected == 0) { comm->connectSend[peer] |= (1UL<p2pnChannelsPerPeer; c++) { - NCCLCHECK(ncclChannelCompute(comm, peer, c, ncclFuncRecv, &channelId)); + NCCLCHECKGOTO(ncclChannelCompute(comm, peer, c, ncclFuncRecv, &channelId), ret, fail); if (comm->channels[channelId].peers[peer].recv[1].connected == 0) { comm->connectRecv[peer] |= (1UL<rank, &proxyConn)); - NCCLCHECK(ncclProxyCall(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0)); + NCCLCHECKGOTO(ncclProxyConnect(comm, TRANSPORT_NET, 1, comm->rank, &proxyConn), ret, fail); + NCCLCHECKGOTO(ncclProxyCall(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0), ret, fail); // Then to remote ones when using PXN if (ncclPxnDisable(comm) == 0) { int nranks; - int* pxnPeers; - NCCLCHECK(ncclTopoGetPxnRanks(comm, &pxnPeers, &nranks)); + NCCLCHECKGOTO(ncclTopoGetPxnRanks(comm, &pxnPeers, &nranks), ret, fail); for (int r=0; rp2pnChannels, sizeof(int), NULL, 0)); + NCCLCHECKGOTO(ncclProxyConnect(comm, TRANSPORT_NET, 1, pxnPeers[r], &proxyConn), ret, fail); + NCCLCHECKGOTO(ncclProxyCall(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0), ret, fail); } - free(pxnPeers); } - do { - // Compute intra-process ranks - int intraProcRank0 = -1, intraProcRank = -1, intraProcRanks = 0; - for (int i = 0; i < nranks; i++) { - if ((comm->peerInfo[i].hostHash == comm->peerInfo[rank].hostHash) - && (comm->peerInfo[i].pidHash == comm->peerInfo[rank].pidHash)) { - // Rank is in same process - if (intraProcRanks == 0) intraProcRank0 = i; - if (i == rank) intraProcRank = intraProcRanks; - intraProcRanks++; - if (intraProcRank0 == rank && rank != i) { - comm->peerInfo[i].comm->intraNext = comm->intraNext; - comm->intraNext = comm->peerInfo[i].comm; - } - } - } - TRACE(NCCL_INIT,"pidHash[%d] %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d", - rank, comm->peerInfo[rank].pidHash, intraProcRank, intraProcRanks, intraProcRank0); - if (intraProcRank == -1 || intraProcRank0 == -1 || comm->peerInfo[intraProcRank0].comm == NULL) { - WARN("Failed to determine intra proc ranks rank %d hostHash %lx pidHash %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d", - rank, comm->peerInfo[rank].hostHash, comm->peerInfo[rank].pidHash, - intraProcRank, intraProcRanks, intraProcRank0); - return ncclInternalError; - } - struct ncclComm* comm0 = comm->peerInfo[intraProcRank0].comm; - assert(intraProcRank==0 ? comm==comm0 : true); - comm->intraComm0 = comm0; - comm->intraRefs = intraProcRank==0 ? intraProcRanks : 0; - comm->intraRank = intraProcRank; - comm->intraRanks = intraProcRanks; - comm->intraBarrierPhase = 0; - comm->intraBarrierCounter = 0; - comm->intraBarrierGate = 0; - } while(0); - if (comm->intraRank == 0) { // Load ncclParamLaunchMode char* str = getenv("NCCL_LAUNCH_MODE"); enum ncclLaunchMode mode, modeOld; @@ -1037,22 +1037,31 @@ collnet_cleanup: } } - NCCLCHECKGOTO(devCommSetup(comm), ret, affinity_restore); + // Call devCommSetup before the last barrier, making sure we don't have a thread running in front and starting to + // launch NCCL kernels before all cuda mem allocation is complete. That could cause a deadlock. + NCCLCHECKGOTO(devCommSetup(comm), ret, fail); /* Local intra-node barrier */ - NCCLCHECK(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0])); - - // Unlink proxy shm to make sure it will be properly cleaned up. - NCCLCHECK(ncclProxyShmUnlink(comm)); + NCCLCHECKGOTO(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0]), ret, fail); // We should have allocated all buffers, collective fifos, ... we can // restore the affinity. -affinity_restore: - if (CPU_COUNT(&comm->cpuAffinity)) sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave); - if (ret != ncclSuccess) return ret; - TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks); - return ncclSuccess; + +exit: + if (CPU_COUNT(&comm->cpuAffinity)) sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave); + // Unlink proxy shm to make sure it will be properly cleaned up. + ncclProxyShmUnlink(comm); + free(allTopoRanks); + free(nodesTreePatterns); + free(nodesFirstRank); + free(allGather3Data); + free(rings); + free(nvbPeers); + free(pxnPeers); + return ret; +fail: + goto exit; } NCCL_PARAM(SetStackSize, "SET_STACK_SIZE", 0); @@ -1080,25 +1089,25 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { int cudaDev = job->cudaDev; ncclResult_t res = ncclSuccess; - CUDACHECK(cudaSetDevice(cudaDev)); + CUDACHECKGOTO(cudaSetDevice(cudaDev), res, fail); // Set the maximum kernel stack size of all kernels to avoid // a CUDA memory reconfig on load (c.f. NVSHMEM issue) if (maxLocalSizeBytes > 0 && ncclParamSetStackSize() == 1) { TRACE(NCCL_INIT, "Setting cudaLimitStackSize to %zi", maxLocalSizeBytes); CUDACHECKIGNORE(cudaDeviceSetLimit(cudaLimitStackSize, maxLocalSizeBytes)); } - NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank), res, cleanup); - NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup); + NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank), res, fail); + NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, fail); // update communicator state comm->initState = ncclSuccess; - INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId); - TRACE_CALL("ncclCommInitRank(%p,%d,0x%llx,%d,%d)", *newcomm, nranks, (unsigned long long)hashUniqueId(commId), myrank, (*newcomm)->cudaDev); - return ncclSuccess; -cleanup: - comm->initState = res; + INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx commId 0x%llx - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId, (unsigned long long)hashUniqueId(commId)); +exit: return res; +fail: + comm->initState = res; + goto exit; } static ncclResult_t parseCommConfig(ncclComm_t comm, ncclConfig_t *config) { @@ -1122,13 +1131,13 @@ static void ncclCommInitRankUndo(struct ncclAsyncJob* job_) { } static ncclResult_t ncclCommInitRankDev(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank, int cudaDev, ncclConfig_t *config) { - ncclResult_t res; + ncclResult_t res = ncclSuccess; ncclComm_t comm = NULL; struct ncclCommInitRankAsyncJob *job = NULL; char* env = getenv("NCCL_COMM_ID"); if (env && myrank == 0) { INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", env); - NCCLCHECKGOTO(bootstrapCreateRoot(&commId, true), res, fail); + NCCLCHECKGOTO(bootstrapCreateRoot((struct ncclBootstrapHandle*)&commId, true), res, fail); } NCCLCHECKGOTO(ncclInit(), res, fail); @@ -1146,8 +1155,6 @@ static ncclResult_t ncclCommInitRankDev(ncclComm_t* newcomm, int nranks, ncclUni NCCLCHECKGOTO(ncclCalloc(&comm, 1), res, fail); NCCLCHECKGOTO(ncclCudaHostCalloc((uint32_t**)&comm->abortFlag, 1), res, fail); - // set up comm state and abortFlag only - *comm->abortFlag = 0; NCCLCHECKGOTO(parseCommConfig(comm, config), res, fail); /* start with ncclInternalError and will be changed to ncclSuccess if init succeeds. */ comm->initState = ncclInternalError; @@ -1164,28 +1171,53 @@ static ncclResult_t ncclCommInitRankDev(ncclComm_t* newcomm, int nranks, ncclUni exit: return ncclGroupErrCheck(res); fail: + if (job) free(job); + if (comm) { + if (comm->abortFlag) ncclCudaHostFree((void *)comm->abortFlag); + free(comm); + } + if (newcomm) *newcomm = NULL; goto exit; } +struct NvtxParamsCommInitRank +{ + int rank; + int nranks; + int cudaDev; +}; +constexpr nvtxPayloadSchemaEntry_t CommInitRankSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Rank"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "No. of ranks", nullptr, 0, offsetof(NvtxParamsCommInitRank, nranks)}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "CUDA device", nullptr, 0, offsetof(NvtxParamsCommInitRank, cudaDev)}, +}; + NCCL_API(ncclResult_t, ncclCommInitRank, ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank); ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank) { - NVTX3_FUNC_RANGE_IN(nccl_domain); - // Load the CUDA driver and dlsym hooks (can fail on old drivers) (void)ncclCudaLibraryInit(); int cudaDev; CUDACHECK(cudaGetDevice(&cudaDev)); + + NvtxParamsCommInitRank payload{myrank, nranks, cudaDev}; + NVTX3_FUNC_WITH_PARAMS(CommInitRank, CommInitRankSchema, payload) + NCCLCHECK(ncclCommInitRankDev(newcomm, nranks, commId, myrank, cudaDev, NULL)); return ncclSuccess; } NCCL_API(ncclResult_t, ncclCommInitAll, ncclComm_t* comms, int ndev, const int* devlist); ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) { - NVTX3_FUNC_RANGE_IN(nccl_domain); ncclResult_t ret = ncclSuccess; int totalnDev; int *gpuFlags = NULL; + + constexpr nvtxPayloadSchemaEntry_t CommInitAllSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "No. of devices"} + }; + NVTX3_FUNC_WITH_PARAMS(CommInitAll, CommInitAllSchema, ndev) + // Load the CUDA driver and dlsym hooks (can fail on old drivers) (void)ncclCudaLibraryInit(); @@ -1297,9 +1329,8 @@ static ncclResult_t commDestroySync(struct ncclAsyncJob* job_) { struct ncclCommFinalizeAsyncJob* job = (struct ncclCommFinalizeAsyncJob*) job_; ncclComm_t comm = job->comm; int savedDevice; - CUDACHECK(cudaGetDevice(&savedDevice)); int commDevice = comm->cudaDev; - ncclResult_t ret; + ncclResult_t ret = ncclSuccess; CUDACHECKGOTO(cudaGetDevice(&savedDevice), ret, fail); if (savedDevice != commDevice) { @@ -1322,6 +1353,7 @@ static ncclResult_t commDestroySync(struct ncclAsyncJob* job_) { CUDACHECKGOTO(cudaSetDevice(savedDevice), ret, fail); } + comm->finalizeCalled = true; exit: return ret; fail: @@ -1350,7 +1382,6 @@ static ncclResult_t commFinalize(ncclComm_t comm, bool userCalled) { ncclResult_t ret = ncclSuccess; struct ncclCommFinalizeAsyncJob *job = NULL; - comm->finalizeCalled = true; /* launch async thread to finalize comm. */ NCCLCHECKGOTO(ncclCalloc(&job, 1), ret, fail); job->comm = comm; @@ -1412,9 +1443,9 @@ static ncclResult_t commReclaim(ncclComm_t comm) { NCCLCHECKGOTO(commFinalize(comm, false), ret, fail); } - if (comm->initState != ncclSuccess) { - /* if init errors happen, no finalize thread should have been launched. Main thread can reclaim - * everything since no NCCL kernel was issued. */ + if (comm->intraComm0 == NULL) { + /* if init errors happen and comm->intraComm0 == NULL, no proxy connection is built up, and no finalize thread + * have been launched. Main thread can reclaim everything since no NCCL kernel was issued. */ struct ncclCommFinalizeAsyncJob job; job.comm = comm; @@ -1424,6 +1455,10 @@ static ncclResult_t commReclaim(ncclComm_t comm) { WARN("commReclaim: comm %p (rank = %d) in abort, error %d", comm, curRank, ret); } + if ((ret = ncclProxyDestroy(comm)) != ncclSuccess) { + WARN("commReclaim: comm %p (rank = %d) destroys proxy resource error %d", comm, curRank, ret); + } + if ((ret = commCleanup(comm)) != ncclSuccess) { WARN("commReclaim: cleanup comm %p rank %d failed in destroy/abort, error %d", comm, curRank, ret); } @@ -1439,18 +1474,51 @@ static ncclResult_t commReclaim(ncclComm_t comm) { ncclComm_t curIntraComm; ncclComm_t nextIntraComm = intracomm0; + /* this is the last call to ncclCommDestroy/Abort, we need to make sure all comms + * in the process have been finalized before we free local resources. */ while (nextIntraComm) { curIntraComm = nextIntraComm; curRank = curIntraComm->rank; nextIntraComm = nextIntraComm->intraNext; - if (comm->finalizeCalled == false) { + if (curIntraComm->finalizeCalled == false) { struct ncclCommFinalizeAsyncJob job; job.comm = curIntraComm; /* every comm aborts, commDestroySync should not be blocked. */ if ((ret = commDestroySync((struct ncclAsyncJob*) &job)) != ncclSuccess) WARN("commReclaim: comm %p (rank = %d) in abort, error %d", curIntraComm, curRank, ret); } + } + + /* ncclProxyDestroy() loop must be put after commDestroySync() loop. Namely, you cannot do: + * while(...) { + * commDestroySync(...); + * ncclProxyDestroy(...); + * } + * Considering one process multi-gpu case, we must guarantee all kernels are complete before + * we free proxy resources; otherwise, we will face invalid memory issues where proxy connection + * and related intermediate memory from one rank are freed but other ranks are still using it. + * This is not a problem for multi-process case, since intermediate memory is opened by CUDA IPC + * or mmap where memory free is guarded by CUDA driver and operating system, so we will not have + * invalid memory access issue. */ + nextIntraComm = intracomm0; + while (nextIntraComm) { + curIntraComm = nextIntraComm; + curRank = curIntraComm->rank; + nextIntraComm = nextIntraComm->intraNext; + + /* free intraprocess proxy resources. */ + if ((ret = ncclProxyDestroy(curIntraComm)) != ncclSuccess) { + WARN("commReclaim: comm %p (rank = %d) destroys proxy resource error %d", curIntraComm, curRank, ret); + } + } + + /* free local resources. */ + nextIntraComm = intracomm0; + while (nextIntraComm) { + curIntraComm = nextIntraComm; + curRank = curIntraComm->rank; + nextIntraComm = nextIntraComm->intraNext; if ((ret = commCleanup(curIntraComm)) != ncclSuccess) { WARN("commReclaim: cleanup comm %p rank %d failed in destroy/abort, error %d", curIntraComm, curRank, ret); @@ -1467,11 +1535,16 @@ fail: NCCL_API(ncclResult_t, ncclCommDestroy, ncclComm_t comm); ncclResult_t ncclCommDestroy(ncclComm_t comm) { - NVTX3_FUNC_RANGE_IN(nccl_domain); - if (comm == NULL) + if (comm == NULL) { + NVTX3_FUNC_RANGE_IN(nccl_domain); return ncclSuccess; + } int rank = comm->rank, nranks = comm->nRanks, cudaDev = comm->cudaDev; + + NvtxParamsCommInitRank payload{rank, nranks, cudaDev}; + NVTX3_FUNC_WITH_PARAMS(CommDestroy, CommInitRankSchema, payload) + int64_t busId = comm->busId; TRACE(NCCL_INIT, "comm %p rank %d nRanks %d cudaDev %d busId %lx", comm, rank, nranks, cudaDev, busId); // Try and prevent a double free of the comm struct (user error) @@ -1491,11 +1564,16 @@ ncclResult_t ncclCommDestroy(ncclComm_t comm) { NCCL_API(ncclResult_t, ncclCommAbort, ncclComm_t comm); ncclResult_t ncclCommAbort(ncclComm_t comm) { - NVTX3_FUNC_RANGE_IN(nccl_domain); - if (comm == NULL) + if (comm == NULL) { + NVTX3_FUNC_RANGE_IN(nccl_domain); return ncclSuccess; + } int rank = comm->rank, nranks = comm->nRanks, cudaDev = comm->cudaDev; + + NvtxParamsCommInitRank payload{rank, nranks, cudaDev}; + NVTX3_FUNC_WITH_PARAMS(CommAbort, CommInitRankSchema, payload) + int64_t busId = comm->busId; TRACE(NCCL_INIT, "comm %p rank %d nRanks %d cudaDev %d busId %lx", comm, rank, nranks, cudaDev, busId); diff --git a/src/init_nvtx.cc b/src/init_nvtx.cc new file mode 100644 index 0000000..44face6 --- /dev/null +++ b/src/init_nvtx.cc @@ -0,0 +1,26 @@ +#include "nccl.h" +#include "nvtx.h" + +static constexpr const nvtxPayloadEnum_t NvtxEnumRedSchema[] = { + {"Sum", ncclSum}, + {"Product", ncclProd}, + {"Max", ncclMax}, + {"Min", ncclMin}, + {"Avg", ncclAvg} +}; + +// Must be called before the first call to any reduction operation. +void initNvtxRegisteredEnums() { + // Register schemas and strings + constexpr const nvtxPayloadEnumAttr_t eAttr { + .fieldMask = NVTX_PAYLOAD_ENUM_ATTR_ENTRIES | NVTX_PAYLOAD_ENUM_ATTR_NUM_ENTRIES | + NVTX_PAYLOAD_ENUM_ATTR_SIZE | NVTX_PAYLOAD_ENUM_ATTR_SCHEMA_ID, + .name = NULL, + .entries = NvtxEnumRedSchema, + .numEntries = std::extent::value, + .sizeOfEnum = sizeof(ncclRedOp_t), + .schemaId = NVTX_PAYLOAD_ENTRY_NCCL_REDOP + }; + + nvtxPayloadEnumRegister(nvtx3::domain::get(), &eAttr); +} diff --git a/src/misc/cudawrap.cc b/src/misc/cudawrap.cc index b1786f4..e2c1a6f 100644 --- a/src/misc/cudawrap.cc +++ b/src/misc/cudawrap.cc @@ -87,7 +87,7 @@ static void initOnceFunc() { cudaLib = dlopen(path, RTLD_LAZY); if (cudaLib == NULL) { - WARN("Failed to find CUDA library in %s (NCCL_CUDA_PATH=%s)", ncclCudaPath, ncclCudaPath); + WARN("Failed to find CUDA library (NCCL_CUDA_PATH='%s') : %s", ncclCudaPath ? ncclCudaPath : "", dlerror()); goto error; } diff --git a/src/misc/shmutils.cc b/src/misc/shmutils.cc index a432ff6..9f17903 100644 --- a/src/misc/shmutils.cc +++ b/src/misc/shmutils.cc @@ -15,79 +15,152 @@ #include #include -// Change functions behavior to match other SYS functions -static int shm_allocate(int fd, const int shmSize) { - int err = posix_fallocate(fd, 0, shmSize); - if (err) { errno = err; return -1; } - return 0; -} -static int shm_map(int fd, const int shmSize, void** ptr) { - *ptr = mmap(NULL, shmSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - return (*ptr == MAP_FAILED) ? -1 : 0; +struct shmHandleInternal { + int fd; + char* shmPath; + char* shmPtr; + void* devShmPtr; + size_t shmSize; + size_t realShmSize; + int* refcount; +}; + +static void shmHandleInit(int fd, char* shmPath, size_t shmSize, size_t realShmSize, char* hptr, void* dptr, bool create, struct shmHandleInternal* handle) { + handle->fd = fd; + handle->shmPtr = hptr; + handle->devShmPtr = dptr; + handle->shmSize = shmSize; + handle->realShmSize = realShmSize; + handle->refcount = (int*)(hptr + shmSize); + if (create) { + int slen = strlen(shmPath); + handle->shmPath = (char*)malloc(slen + 1); + memcpy(handle->shmPath, shmPath, slen + 1); + if (hptr) memset(hptr, 0, shmSize); + } else { + handle->shmPath = NULL; + } + return; } -static ncclResult_t ncclShmSetup(char* shmPath, const int shmSize, int* fd, void** ptr, int create) { +ncclResult_t ncclShmOpen(char* shmPath, size_t shmSize, void** shmPtr, void** devShmPtr, int refcount, ncclShmHandle_t* handle) { + int fd = -1; + char* hptr = NULL; + void* dptr = NULL; + ncclResult_t ret = ncclSuccess; + struct shmHandleInternal* tmphandle; + bool create = refcount > 0 ? true : false; + const size_t refSize = sizeof(int); /* extra sizeof(int) bytes for reference count */ + const size_t realShmSize = shmSize + refSize; + + *handle = *shmPtr = NULL; /* assume shmPtr and handle always set correctly by users. */ + EQCHECKGOTO(tmphandle = (struct shmHandleInternal*)malloc(sizeof(struct shmHandleInternal)), NULL, ret, fail); if (create) { + /* refcount > 0 means the caller tries to allocate a shared memory. This shared memory segment will have + * refcount references; when the peer attaches, it should pass -1 to reduce one reference count. When it + * goes down to 0, unlink should be called in order to delete shared memory file. */ if (shmPath[0] == '\0') { sprintf(shmPath, "/dev/shm/nccl-XXXXXX"); - *fd = mkstemp(shmPath); + fd = mkstemp(shmPath); } else { - SYSCHECKVAL(open(shmPath, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR), "open", *fd); + SYSCHECKGOTO(fd = open(shmPath, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR), ret, fail); } - if (ftruncate(*fd, shmSize) != 0) { - WARN("Error: failed to extend %s to %d bytes", shmPath, shmSize); - return ncclSystemError; + + if (ftruncate(fd, realShmSize) != 0) { + WARN("Error: failed to extend %s to %ld bytes", shmPath, realShmSize); + ret = ncclSystemError; + goto fail; } - INFO(NCCL_ALLOC, "Allocated %d bytes of shared memory in %s\n", shmSize, shmPath); + INFO(NCCL_ALLOC, "Allocated %ld bytes of shared memory in %s", realShmSize, shmPath); } else { - SYSCHECKVAL(open(shmPath, O_RDWR, S_IRUSR | S_IWUSR), "open", *fd); - } - *ptr = (char*)mmap(NULL, shmSize, PROT_READ|PROT_WRITE, MAP_SHARED, *fd, 0); - if (*ptr == NULL) { - WARN("Could not map %s\n", shmPath); - return ncclSystemError; - } - close(*fd); - *fd = -1; - if (create) memset(*ptr, 0, shmSize); - return ncclSuccess; -} - -ncclResult_t ncclShmOpen(char* shmPath, const int shmSize, void** shmPtr, void** devShmPtr, int create) { - int fd = -1; - void* ptr = MAP_FAILED; - ncclResult_t res = ncclSuccess; - - NCCLCHECKGOTO(ncclShmSetup(shmPath, shmSize, &fd, &ptr, create), res, sysError); - if (devShmPtr) { - CUDACHECKGOTO(cudaHostRegister(ptr, shmSize, cudaHostRegisterMapped), res, cudaError); - CUDACHECKGOTO(cudaHostGetDevicePointer(devShmPtr, ptr, 0), res, cudaError); + SYSCHECKGOTO(fd = open(shmPath, O_RDWR, S_IRUSR | S_IWUSR), ret, fail); } - *shmPtr = ptr; - return ncclSuccess; -sysError: - WARN("Error while %s shared memory segment %s (size %d)", create ? "creating" : "attaching to", shmPath, shmSize); -cudaError: - if (fd != -1) close(fd); - if (create) shm_unlink(shmPath); - if (ptr != MAP_FAILED) munmap(ptr, shmSize); - *shmPtr = NULL; - return res; -} + hptr = (char*)mmap(NULL, realShmSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (hptr == MAP_FAILED) { + WARN("Could not map %s size %zi, error: %s", shmPath, realShmSize, strerror(errno)); + ret = ncclSystemError; + goto fail; + } -ncclResult_t ncclShmUnlink(const char* shmPath) { - if (shmPath != NULL) SYSCHECK(unlink(shmPath), "unlink"); - return ncclSuccess; -} + if (create) { + *(int*)(hptr + shmSize) = refcount; + } else { + int remref = __atomic_sub_fetch((int*)(hptr + shmSize), 1, __ATOMIC_RELAXED); + if (remref == 0) { + /* the last peer has completed attachment, it should unlink the shm mem file. */ + if (unlink(shmPath) != 0) { + WARN("unlink shared memory %s failed, error: %s", shmPath, strerror(errno)); + } + } -ncclResult_t ncclShmClose(void* shmPtr, void* devShmPtr, const int shmSize) { - if (shmPtr) { - if (devShmPtr) CUDACHECK(cudaHostUnregister(shmPtr)); - if (munmap(shmPtr, shmSize) != 0) { - WARN("munmap of shared memory failed"); - return ncclSystemError; + if (refcount != -1) { + WARN("attaching memory should only reduce refcount by 1 but %d is passed", refcount); } } - return ncclSuccess; + + if (devShmPtr) { + CUDACHECKGOTO(cudaHostRegister((void*)hptr, realShmSize, cudaHostRegisterMapped), ret, fail); + CUDACHECKGOTO(cudaHostGetDevicePointer(&dptr, (void*)hptr, 0), ret, fail); + } + + shmHandleInit(fd, shmPath, shmSize, realShmSize, hptr, dptr, create, tmphandle); +exit: + *shmPtr = hptr; + if (devShmPtr) *devShmPtr = dptr; + *handle = (ncclShmHandle_t)tmphandle; + return ret; +fail: + WARN("Error while %s shared memory segment %s (size %ld)", create ? "creating" : "attaching to", shmPath, shmSize); + if (tmphandle) { + shmHandleInit(fd, shmPath, shmSize, realShmSize, hptr, dptr, create, tmphandle); + ncclShmClose((ncclShmHandle_t)tmphandle); + tmphandle = NULL; + } + hptr = NULL; + dptr = NULL; + goto exit; +} + +ncclResult_t ncclShmClose(ncclShmHandle_t handle) { + ncclResult_t ret = ncclSuccess; + struct shmHandleInternal* tmphandle = (struct shmHandleInternal*)handle; + if (tmphandle) { + if (tmphandle->fd >= 0) { + close(tmphandle->fd); + if (tmphandle->shmPath != NULL && *tmphandle->refcount > 0) { + if (unlink(tmphandle->shmPath) != 0) { + WARN("unlink shared memory %s failed, error: %s", tmphandle->shmPath, strerror(errno)); + ret = ncclSystemError; + } + free(tmphandle->shmPath); + } + } + + if (tmphandle->shmPtr) { + if (tmphandle->devShmPtr) CUDACHECK(cudaHostUnregister(tmphandle->shmPtr)); + if (munmap(tmphandle->shmPtr, tmphandle->realShmSize) != 0) { + WARN("munmap of shared memory %p size %ld failed, error: %s", tmphandle->shmPtr, tmphandle->realShmSize, strerror(errno)); + ret = ncclSystemError; + } + } + free(tmphandle); + } + return ret; +} + +ncclResult_t ncclShmUnlink(ncclShmHandle_t handle) { + ncclResult_t ret = ncclSuccess; + struct shmHandleInternal* tmphandle = (struct shmHandleInternal*)handle; + if (tmphandle) { + if (tmphandle->shmPath != NULL) { + if (unlink(tmphandle->shmPath) != 0) { + WARN("unlink shared memory %s failed, error: %s", tmphandle->shmPath, strerror(errno)); + ret = ncclSystemError; + } + free(tmphandle->shmPath); + tmphandle->shmPath = NULL; + } + } + return ret; } diff --git a/src/misc/socket.cc b/src/misc/socket.cc index 7161aee..e861480 100644 --- a/src/misc/socket.cc +++ b/src/misc/socket.cc @@ -12,6 +12,52 @@ #include #include +static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) { + int bytes = 0; + *closed = 0; + char* data = (char*)ptr; + char line[SOCKET_NAME_MAXLEN+1]; + do { + if (op == NCCL_SOCKET_RECV) bytes = recv(sock->fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT); + if (op == NCCL_SOCKET_SEND) bytes = send(sock->fd, data+(*offset), size-(*offset), block ? MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL); + if (op == NCCL_SOCKET_RECV && bytes == 0) { + *closed = 1; + return ncclSuccess; + } + if (bytes == -1) { + if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { + WARN("socketProgressOpt: Call to recv from %s failed : %s", ncclSocketToString(&sock->addr, line), strerror(errno)); + return ncclRemoteError; + } else { + bytes = 0; + } + } + (*offset) += bytes; + if (sock->abortFlag && *sock->abortFlag != 0) { + INFO(NCCL_NET, "socketProgressOpt: abort called"); + return ncclInternalError; + } + } while (bytes > 0 && (*offset) < size); + return ncclSuccess; +} + +static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { + int closed; + NCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0, &closed)); + if (closed) { + char line[SOCKET_NAME_MAXLEN+1]; + WARN("socketProgress: Connection closed by remote peer %s", ncclSocketToString(&sock->addr, line, 0)); + return ncclRemoteError; + } + return ncclSuccess; +} + +static ncclResult_t socketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { + while (*offset < size) + NCCLCHECK(socketProgress(op, sock, ptr, size, offset)); + return ncclSuccess; +} + /* Format a string representation of a (union ncclSocketAddress *) socket address using getnameinfo() * * Output: "IPv4/IPv6 address" @@ -194,7 +240,7 @@ int ncclFindInterfaceMatchSubnet(char* ifNames, union ncclSocketAddress* localAd return found; } -ncclResult_t ncclGetSocketAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair) { +ncclResult_t ncclSocketGetAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair) { if (!(ip_port_pair && strlen(ip_port_pair) > 1)) { WARN("Net : string is null"); return ncclInvalidArgument; @@ -296,7 +342,7 @@ int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNa INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", commId); // Try to find interface that is in the same subnet as the IP in comm id union ncclSocketAddress idAddr; - ncclGetSocketAddrFromString(&idAddr, commId); + ncclSocketGetAddrFromString(&idAddr, commId); nIfs = ncclFindInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs); } } @@ -310,39 +356,31 @@ int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNa } ncclResult_t ncclSocketListen(struct ncclSocket* sock) { - /* IPv4/IPv6 support */ - int family = sock->addr.sa.sa_family; - int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); - int flags; - - /* Create socket and bind it to a port */ - int fd = socket(family, SOCK_STREAM, 0); - if (fd == -1) { - WARN("Net : Socket creation failed : %s", strerror(errno)); - return ncclSystemError; + if (sock == NULL) { + WARN("ncclSocketListen: pass NULL socket"); + return ncclInvalidArgument; + } + if (sock->fd == -1) { + WARN("ncclSocketListen: file descriptor is -1"); + return ncclInvalidArgument; } if (socketToPort(&sock->addr)) { // Port is forced by env. Make sure we get the port. int opt = 1; #if defined(SO_REUSEPORT) - SYSCHECK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt"); + SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt"); #else - SYSCHECK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt"); + SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt"); #endif } - /* The socket is set non-blocking for OS level, but asyncFlag is used to control - * blocking and non-blocking behavior in user level. */ - EQCHECK(flags = fcntl(fd, F_GETFL), -1); - SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl"); - // addr port should be 0 (Any port) - SYSCHECK(bind(fd, &sock->addr.sa, salen), "bind"); + SYSCHECK(bind(sock->fd, &sock->addr.sa, sock->salen), "bind"); /* Get the assigned Port */ - socklen_t size = salen; - SYSCHECK(getsockname(fd, &sock->addr.sa, &size), "getsockname"); + socklen_t size = sock->salen; + SYSCHECK(getsockname(sock->fd, &sock->addr.sa, &size), "getsockname"); #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN+1]; @@ -352,220 +390,431 @@ ncclResult_t ncclSocketListen(struct ncclSocket* sock) { /* Put the socket in listen mode * NB: The backlog will be silently truncated to the value in /proc/sys/net/core/somaxconn */ - SYSCHECK(listen(fd, 16384), "listen"); - sock->fd = fd; + SYSCHECK(listen(sock->fd, 16384), "listen"); + sock->state = ncclSocketStateReady; return ncclSuccess; } -static ncclResult_t getFdState(int fd, enum ncclSocketState* state) { - struct pollfd pfd; - int timeout = 1, ret; - socklen_t rlen = sizeof(int); - - memset(&pfd, 0, sizeof(struct pollfd)); - pfd.fd = fd; - pfd.events = POLLOUT; - SYSCHECK(ret = poll(&pfd, 1, timeout), "poll"); - if (ret == 0) { - ret = EINPROGRESS; - } else { - /* check socket status */ - EQCHECK(ret == 1 && (pfd.revents & POLLOUT), 0); - SYSCHECK(getsockopt(fd, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen), "getsockopt"); - } - - if (ret == EINPROGRESS || ret == ECONNREFUSED) - *state = ncclSocketConnecting; - else if (ret == 0) - *state = ncclSocketConnected; - else - *state = ncclSocketError; - return ncclSuccess; +ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr) { + if (sock == NULL) { + WARN("ncclSocketGetAddr: pass NULL socket"); + return ncclInvalidArgument; + } + if (sock->state != ncclSocketStateReady) return ncclInternalError; + memcpy(addr, &sock->addr, sizeof(union ncclSocketAddress)); + return ncclSuccess; } -ncclResult_t ncclGetSocketState(struct ncclSocket* sock, enum ncclSocketState* state) { - NCCLCHECK(getFdState(sock->fd, state)); - sock->state = *state; +static ncclResult_t socketTryAccept(struct ncclSocket* sock) { + socklen_t socklen = sizeof(union ncclSocketAddress); + sock->fd = accept(sock->acceptFd, &sock->addr.sa, &socklen); + if (sock->fd != -1) { + sock->state = ncclSocketStateAccepted; + } else if (errno != EAGAIN && errno != EWOULDBLOCK) { + WARN("socketTryAccept: get errno %d that is not EAGAIN or EWOULDBLOCK", errno); + return ncclSystemError; + } + return ncclSuccess; +} + +static ncclResult_t socketFinalizeAccept(struct ncclSocket* sock) { + uint64_t magic; + enum ncclSocketType type; + int received = 0; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); + if (received == 0) return ncclSuccess; + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); + if (magic != sock->magic) { + WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic); + close(sock->fd); + sock->fd = -1; + // Ignore spurious connection and accept again + sock->state = ncclSocketStateAccepting; return ncclSuccess; + } else { + received = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &type, sizeof(type), &received)); + if (type != sock->type) { + WARN("socketFinalizeAccept: wrong type %d != %d", type, sock->type); + sock->state = ncclSocketStateError; + close(sock->fd); + sock->fd = -1; + return ncclInternalError; + } else { + sock->state = ncclSocketStateReady; + } + } + return ncclSuccess; +} + +static ncclResult_t socketStartConnect(struct ncclSocket* sock) { + /* blocking/non-blocking connect() is determined by asyncFlag. */ + int ret = connect(sock->fd, &sock->addr.sa, sock->salen); + + if (ret == 0) { + sock->state = ncclSocketStateConnected; + return ncclSuccess; + } else if (errno == EINPROGRESS) { + sock->state = ncclSocketStateConnectPolling; + return ncclSuccess; + } else if (errno == ECONNREFUSED) { + if (++sock->refusedRetries == RETRY_REFUSED_TIMES) { + sock->state = ncclSocketStateError; + WARN("socketStartConnect: exceeded retries (%d)", sock->refusedRetries); + return ncclRemoteError; + } + usleep(SLEEP_INT); + if (sock->refusedRetries % 1000 == 0) INFO(NCCL_ALL, "Call to connect returned %s, retrying", strerror(errno)); + return ncclSuccess; + } else if (errno == ETIMEDOUT) { + if (++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) { + sock->state = ncclSocketStateError; + WARN("socketStartConnect: exceeded timeouts (%d)", sock->timedOutRetries); + return ncclRemoteError; + } + usleep(SLEEP_INT); + return ncclSuccess; + } else { + char line[SOCKET_NAME_MAXLEN+1]; + sock->state = ncclSocketStateError; + WARN("socketStartConnect: Connect to %s failed : %s", ncclSocketToString(&sock->addr, line), strerror(errno)); + return ncclSystemError; + } +} + +static ncclResult_t socketPollConnect(struct ncclSocket* sock) { + struct pollfd pfd; + int timeout = 1, ret; + socklen_t rlen = sizeof(int); + + memset(&pfd, 0, sizeof(struct pollfd)); + pfd.fd = sock->fd; + pfd.events = POLLOUT; + SYSCHECK(ret = poll(&pfd, 1, timeout), "poll"); + if (ret == 0) return ncclSuccess; + + /* check socket status */ + EQCHECK(ret == 1 && (pfd.revents & POLLOUT), 0); + SYSCHECK(getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen), "getsockopt"); + + if (ret == 0) { + sock->state = ncclSocketStateConnected; + } else if (ret == ECONNREFUSED) { + if (++sock->refusedRetries == RETRY_REFUSED_TIMES) { + sock->state = ncclSocketStateError; + WARN("socketPollConnect: exceeded retries (%d)", sock->refusedRetries); + return ncclRemoteError; + } + if (sock->refusedRetries % 1000 == 0) INFO(NCCL_ALL, "Call to connect returned %s, retrying", strerror(errno)); + usleep(SLEEP_INT); + sock->state = ncclSocketStateConnecting; + } else if (ret == ETIMEDOUT) { + if (++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) { + sock->state = ncclSocketStateError; + WARN("socketPollConnect: exceeded timeouts (%d)", sock->timedOutRetries); + return ncclRemoteError; + } + usleep(SLEEP_INT); + sock->state = ncclSocketStateConnecting; + } else if (ret != EINPROGRESS) { + sock->state = ncclSocketStateError; + return ncclSystemError; + } + return ncclSuccess; +} + +ncclResult_t ncclSocketPollConnect(struct ncclSocket* sock) { + if (sock == NULL) { + WARN("ncclSocketPollConnect: pass NULL socket"); + return ncclInvalidArgument; + } + NCCLCHECK(socketPollConnect(sock)); + return ncclSuccess; +} + +static ncclResult_t socketFinalizeConnect(struct ncclSocket* sock) { + int sent = 0; + NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); + if (sent == 0) return ncclSuccess; + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); + sent = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent)); + sock->state = ncclSocketStateReady; + return ncclSuccess; +} + +static ncclResult_t socketProgressState(struct ncclSocket* sock) { + if (sock->state == ncclSocketStateAccepting) { + NCCLCHECK(socketTryAccept(sock)); + } + if (sock->state == ncclSocketStateAccepted) { + NCCLCHECK(socketFinalizeAccept(sock)); + } + if (sock->state == ncclSocketStateConnecting) { + NCCLCHECK(socketStartConnect(sock)); + } + if (sock->state == ncclSocketStateConnectPolling) { + NCCLCHECK(socketPollConnect(sock)); + } + if (sock->state == ncclSocketStateConnected) { + NCCLCHECK(socketFinalizeConnect(sock)); + } + return ncclSuccess; +} + +ncclResult_t ncclSocketReady(struct ncclSocket* sock, int *running) { + if (sock == NULL) { + *running = 0; + return ncclSuccess; + } + if (sock->state == ncclSocketStateError || sock->state == ncclSocketStateClosed) { + WARN("ncclSocketReady: unexpected socket state %d", sock->state); + return ncclRemoteError; + } + *running = (sock->state == ncclSocketStateReady) ? 1 : 0; + if (*running == 0) { + NCCLCHECK(socketProgressState(sock)); + *running = (sock->state == ncclSocketStateReady) ? 1 : 0; + } + return ncclSuccess; } ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { +#ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN+1]; - /* IPv4/IPv6 support */ - int family = sock->addr.sa.sa_family; - if (family != AF_INET && family != AF_INET6) { - WARN("Net : connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)", - ncclSocketToString(&sock->addr, line), family, AF_INET, AF_INET6); +#endif + const int one = 1; + + if (sock == NULL) { + WARN("ncclSocketConnect: pass NULL socket"); + return ncclInvalidArgument; + } + if (sock->fd == -1) { + WARN("ncclSocketConnect: file descriptor is -1"); + return ncclInvalidArgument; + } + + if (sock->state != ncclSocketStateInitialized) { + WARN("ncclSocketConnect: wrong socket state %d", sock->state); + if (sock->state == ncclSocketStateError) return ncclRemoteError; return ncclInternalError; } - int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); - int flags; - - /* Connect to a hostname / port */ - int fd = socket(family, SOCK_STREAM, 0); - if (fd == -1) { - WARN("Net : Socket creation failed : %s", strerror(errno)); - return ncclSystemError; - } - - const int one = 1; - SYSCHECK(setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt"); - - /* The socket is set non-blocking for OS level, but asyncFlag is used to control - * blocking and non-blocking behavior in user level. */ - EQCHECK(flags = fcntl(fd, F_GETFL), -1); - SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl"); - - /* const int bufsize = 128*1024; - SYSCHECK(setsockopt(fd, SOL_SOCKET, SO_SNDBUF, (char*)&bufsize, sizeof(int)), "setsockopt"); - SYSCHECK(setsockopt(fd, SOL_SOCKET, SO_RCVBUF, (char*)&bufsize, sizeof(int)), "setsockopt");*/ - TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", ncclSocketToString(&sock->addr, line)); - int ret; - int timedout_retries = 0; - int refused_retries = 0; -retry: - /* blocking/non-blocking connect() is determined by asyncFlag. */ - ret = connect(fd, &sock->addr.sa, salen); + SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt"); - if (!sock->asyncFlag) { - /* blocking socket, need retry if connect fails. */ - if (errno == EINPROGRESS || errno == EAGAIN || errno == EALREADY || - (errno == ECONNREFUSED && ++refused_retries < RETRY_REFUSED_TIMES) || - (errno == ETIMEDOUT && ++timedout_retries < RETRY_TIMEDOUT_TIMES)) { - /* check abortFlag as long as we have chance to retry. */ - if (sock->abortFlag && *sock->abortFlag != 0) return ncclInternalError; - if (errno == ECONNREFUSED && refused_retries % 1000 == 0) INFO(NCCL_ALL, "Call to connect returned %s, retrying", strerror(errno)); - usleep(SLEEP_INT); - goto retry; - } + sock->state = ncclSocketStateConnecting; + do { + NCCLCHECK(socketProgressState(sock)); + } while (sock->asyncFlag == 0 && + (sock->abortFlag == NULL || *sock->abortFlag == 0) && + (sock->state == ncclSocketStateConnecting || + sock->state == ncclSocketStateConnectPolling || + sock->state == ncclSocketStateConnected)); - /* If connect() fails with errno == EAGAIN/EINPROGRESS/ETIMEDOUT, we may want to try connect again. - * However, it can return EISCONN instead of success which indicates connection is built up in - * background already. No need to call connect() again. */ - if (ret == 0 || errno == EISCONN) { - sock->fd = fd; + if (sock->abortFlag && *sock->abortFlag != 0) return ncclInternalError; + + switch (sock->state) { + case ncclSocketStateConnecting: + case ncclSocketStateConnectPolling: + case ncclSocketStateConnected: + case ncclSocketStateReady: return ncclSuccess; - } - } else { - sock->fd = fd; - return ncclSuccess; + case ncclSocketStateError: + return ncclSystemError; + default: + WARN("ncclSocketConnect: wrong socket state %d", sock->state); + return ncclInternalError; } - - WARN("Net : Connect to %s failed : %s", ncclSocketToString(&sock->addr, line), strerror(errno)); - return ncclRemoteError; } -ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listenSocket) { - socklen_t socklen = sizeof(union ncclSocketAddress); - struct pollfd pollfd; - int tmpFd = sock->fd = -1; - int pollret; +ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listenSock) { + ncclResult_t ret = ncclSuccess; - pollfd.fd = listenSocket->fd; - pollfd.events = POLLIN; -retry: - if ((pollret = poll(&pollfd, 1, listenSocket->asyncFlag ? 0 : 100)) < 0) { - return ncclSystemError; - } else { - tmpFd = accept(listenSocket->fd, &sock->addr.sa, &socklen); + if (listenSock == NULL || sock == NULL) { + WARN("ncclSocketAccept: pass NULL socket"); + ret = ncclInvalidArgument; + goto exit; + } + if (listenSock->state != ncclSocketStateReady) { + WARN("ncclSocketAccept: wrong socket state %d", listenSock->state); + if (listenSock->state == ncclSocketStateError) + ret = ncclSystemError; + else + ret = ncclInternalError; + goto exit; } - if (!listenSocket->asyncFlag) { - /* blocking socket, if tmpFd is still -1, we need to retry */ - if (tmpFd == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { - if (listenSocket->abortFlag && *listenSocket->abortFlag != 0) return ncclInternalError; - goto retry; - } - EQCHECK(tmpFd, -1); + if (sock->acceptFd == -1) { + memcpy(sock, listenSock, sizeof(struct ncclSocket)); + sock->acceptFd = listenSock->fd; + sock->state = ncclSocketStateAccepting; } - sock->fd = tmpFd; - return ncclSuccess; + do { + NCCLCHECKGOTO(socketProgressState(sock), ret, exit); + } while (sock->asyncFlag == 0 && + (sock->abortFlag == NULL || *sock->abortFlag == 0) && + (sock->state == ncclSocketStateAccepting || + sock->state == ncclSocketStateAccepted)); + + if (sock->abortFlag && *sock->abortFlag != 0) return ncclInternalError; + + switch (sock->state) { + case ncclSocketStateAccepting: + case ncclSocketStateAccepted: + case ncclSocketStateReady: + ret = ncclSuccess; + break; + case ncclSocketStateError: + ret = ncclSystemError; + break; + default: + WARN("ncclSocketAccept: wrong socket state %d", sock->state); + ret = ncclInternalError; + break; + } + +exit: + return ret; } -ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr, volatile uint32_t* abortFlag, int asyncFlag) { - if (sock == NULL) - return ncclSuccess; +ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, int asyncFlag) { + ncclResult_t ret = ncclSuccess; + if (sock == NULL) goto exit; + sock->timedOutRetries = 0; + sock->refusedRetries = 0; + sock->abortFlag = abortFlag; + sock->asyncFlag = asyncFlag; + sock->state = ncclSocketStateInitialized; + sock->magic = magic; + sock->type = type; sock->fd = -1; + sock->acceptFd = -1; + if (addr) { + /* IPv4/IPv6 support */ + int family; memcpy(&sock->addr, addr, sizeof(union ncclSocketAddress)); + family = sock->addr.sa.sa_family; + if (family != AF_INET && family != AF_INET6) { + char line[SOCKET_NAME_MAXLEN+1]; + WARN("ncclSocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)", + ncclSocketToString(&sock->addr, line), family, AF_INET, AF_INET6); + ret = ncclInternalError; + goto fail; + } + sock->salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); + + /* Connect to a hostname / port */ + sock->fd = socket(family, SOCK_STREAM, 0); + if (sock->fd == -1) { + WARN("ncclSocketInit: Socket creation failed : %s", strerror(errno)); + ret = ncclSystemError; + goto fail; + } } else { memset(&sock->addr, 0, sizeof(union ncclSocketAddress)); } - sock->abortFlag = abortFlag; - sock->asyncFlag = asyncFlag; - sock->state = ncclSocketStateNum; - return ncclSuccess; -} -static ncclResult_t ncclSocketProgressOpt(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) { - int bytes = 0; - *closed = 0; - char* data = (char*)ptr; - char line[SOCKET_NAME_MAXLEN+1]; - do { - if (op == NCCL_SOCKET_RECV) bytes = recv(sock->fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT); - if (op == NCCL_SOCKET_SEND) bytes = send(sock->fd, data+(*offset), size-(*offset), block ? MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL); - if (op == NCCL_SOCKET_RECV && bytes == 0) { - *closed = 1; - return ncclSuccess; - } - if (bytes == -1) { - if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { - WARN("Net : Call to recv from %s failed : %s", ncclSocketToString(&sock->addr, line), strerror(errno)); - return ncclRemoteError; - } else { - bytes = 0; - } - } - (*offset) += bytes; - if (sock->abortFlag && *sock->abortFlag != 0) { - INFO(NCCL_NET, "Socket progress: abort called"); - return ncclInternalError; - } - } while (bytes > 0 && (*offset) < size); - return ncclSuccess; + /* Set socket as non-blocking if async or if we need to be able to abort */ + if ((sock->asyncFlag || sock->abortFlag) && sock->fd >= 0) { + int flags; + EQCHECKGOTO(flags = fcntl(sock->fd, F_GETFL), -1, ret, fail); + SYSCHECKGOTO(fcntl(sock->fd, F_SETFL, flags | O_NONBLOCK), ret, fail); + } + +exit: + return ret; +fail: + goto exit; } ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { - int closed; - NCCLCHECK(ncclSocketProgressOpt(op, sock, ptr, size, offset, 0, &closed)); - if (closed) { - char line[SOCKET_NAME_MAXLEN+1]; - WARN("Net : Connection closed by remote peer %s", ncclSocketToString(&sock->addr, line, 0)); - return ncclRemoteError; + if (sock == NULL) { + WARN("ncclSocketProgress: pass NULL socket"); + return ncclInvalidArgument; } + NCCLCHECK(socketProgress(op, sock, ptr, size, offset)); return ncclSuccess; } ncclResult_t ncclSocketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { - while (*offset < size) - NCCLCHECK(ncclSocketProgress(op, sock, ptr, size, offset)); + if (sock == NULL) { + WARN("ncclSocketWait: pass NULL socket"); + return ncclInvalidArgument; + } + NCCLCHECK(socketWait(op, sock, ptr, size, offset)); return ncclSuccess; } ncclResult_t ncclSocketSend(struct ncclSocket* sock, void* ptr, int size) { int offset = 0; - NCCLCHECK(ncclSocketWait(NCCL_SOCKET_SEND, sock, ptr, size, &offset)); + if (sock == NULL) { + WARN("ncclSocketSend: pass NULL socket"); + return ncclInvalidArgument; + } + if (sock->state != ncclSocketStateReady) { + WARN("ncclSocketSend: socket state (%d) is not ready", sock->state); + return ncclInternalError; + } + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, ptr, size, &offset)); return ncclSuccess; } ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size) { int offset = 0; - NCCLCHECK(ncclSocketWait(NCCL_SOCKET_RECV, sock, ptr, size, &offset)); + if (sock == NULL) { + WARN("ncclSocketRecv: pass NULL socket"); + return ncclInvalidArgument; + } + if (sock->state != ncclSocketStateReady) { + WARN("ncclSocketRecv: socket state (%d) is not ready", sock->state); + return ncclInternalError; + } + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, ptr, size, &offset)); return ncclSuccess; } // Receive or detect connection closed ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed) { int offset = 0; + if (sock == NULL) { + WARN("ncclSocketTryRecv: pass NULL socket"); + return ncclInvalidArgument; + } *closed = 0; while (offset < size) { - NCCLCHECK(ncclSocketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed)); + NCCLCHECK(socketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed)); if (*closed) return ncclSuccess; } return ncclSuccess; } + +ncclResult_t ncclSocketClose(struct ncclSocket* sock) { + if (sock != NULL) { + if (sock->fd >= 0) close(sock->fd); + sock->state = ncclSocketStateClosed; + sock->fd = -1; + } + return ncclSuccess; +} + +ncclResult_t ncclSocketGetFd(struct ncclSocket* sock, int* fd) { + if (sock == NULL) { + WARN("ncclSocketGetFd: pass NULL socket"); + return ncclInvalidArgument; + } + if (fd) *fd = sock->fd; + return ncclSuccess; +} + +ncclResult_t ncclSocketSetFd(int fd, struct ncclSocket* sock) { + if (sock == NULL) { + WARN("ncclSocketGetFd: pass NULL socket"); + return ncclInvalidArgument; + } + sock->fd = fd; + return ncclSuccess; +} diff --git a/src/misc/strongstream.cc b/src/misc/strongstream.cc index d07698b..61b0e4b 100644 --- a/src/misc/strongstream.cc +++ b/src/misc/strongstream.cc @@ -305,7 +305,8 @@ static void mergeTips(struct ncclStrongStreamGraph* a, cudaGraphNode_t const* bN } ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b + struct ncclCudaGraph graph, struct ncclStrongStream* a, struct ncclStrongStream* b, + bool b_subsumes_a ) { #if CUDART_VERSION >= 11030 if (graph.graph == nullptr) { @@ -319,6 +320,7 @@ ncclResult_t ncclStrongStreamWaitStream( NCCLCHECK(checkGraphId(ag, graph.graphId)); struct ncclStrongStreamGraph* bg = b->graphHead; NCCLCHECK(checkGraphId(bg, graph.graphId)); + if (b_subsumes_a) ag->tipCount = 0; mergeTips(ag, bg->tipNodes, bg->tipCount); } a->serialEventNeedsRecord = true; @@ -330,7 +332,8 @@ ncclResult_t ncclStrongStreamWaitStream( } ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, struct ncclStrongStream* a, cudaStream_t b + struct ncclCudaGraph graph, struct ncclStrongStream* a, cudaStream_t b, + bool b_subsumes_a ) { #if CUDART_VERSION >= 11030 if (graph.graph == nullptr) { @@ -351,6 +354,7 @@ ncclResult_t ncclStrongStreamWaitStream( } struct ncclStrongStreamGraph* ag = a->graphHead; NCCLCHECK(checkGraphId(ag, graph.graphId)); + if (b_subsumes_a) ag->tipCount = 0; mergeTips(ag, bNodes, bCount); } a->serialEventNeedsRecord = true; @@ -362,7 +366,8 @@ ncclResult_t ncclStrongStreamWaitStream( } ncclResult_t ncclStrongStreamWaitStream( - struct ncclCudaGraph graph, cudaStream_t a, struct ncclStrongStream* b + struct ncclCudaGraph graph, cudaStream_t a, struct ncclStrongStream* b, + bool b_subsumes_a ) { #if CUDART_VERSION >= 11030 if (graph.graph == nullptr) { @@ -374,7 +379,9 @@ ncclResult_t ncclStrongStreamWaitStream( } else { struct ncclStrongStreamGraph* bg = b->graphHead; NCCLCHECK(checkGraphId(bg, graph.graphId)); - CUDACHECK(cudaStreamUpdateCaptureDependencies(a, bg->tipNodes, bg->tipCount, cudaStreamAddCaptureDependencies)); + CUDACHECK(cudaStreamUpdateCaptureDependencies(a, bg->tipNodes, bg->tipCount, + b_subsumes_a ? cudaStreamSetCaptureDependencies : cudaStreamAddCaptureDependencies + )); } #else CUDACHECK(cudaEventRecord(b->scratchEvent, b->cudaStream)); diff --git a/src/nccl.h.in b/src/nccl.h.in index ccb8f57..44a68e9 100644 --- a/src/nccl.h.in +++ b/src/nccl.h.in @@ -123,7 +123,7 @@ const char* pncclGetErrorString(ncclResult_t result); * comm is currently unused and can be set to NULL */ const char* ncclGetLastError(ncclComm_t comm); -const char* pncclGetError(ncclComm_t comm); +const char* pncclGetLastError(ncclComm_t comm); /* Checks whether the comm has encountered any asynchronous errors */ ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); diff --git a/src/proxy.cc b/src/proxy.cc index 696f57f..2103b7a 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -437,8 +437,7 @@ ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op) int stepSize = info->comm->buffSizes[op->protocol]/NCCL_STEPS; - // If nNodes > 1 and we're using Simple, reduce the stepSize to increase shared buffer utilization - if (info->comm->nNodes > 1 && op->protocol == NCCL_PROTO_SIMPLE) stepSize = info->comm->p2pNetChunkSize; + if (op->protocol == NCCL_PROTO_SIMPLE) stepSize = info->comm->p2pChunkSize; info->chunkSize = stepSize; op->root = info->root; @@ -532,6 +531,8 @@ static ncclResult_t progressOps(struct ncclComm* comm, struct ncclProxyProgressS return ncclSuccess; } +NCCL_PARAM(ProxyAppendBatchSize, "PROXY_APPEND_BATCH_SIZE", 16); + static ncclResult_t ncclProxyGetPostedOps(struct ncclComm* comm, int* added) { struct ncclProxyProgressState* state = &comm->proxyState.progressState; if (state->opsPool == NULL) return ncclInternalError; @@ -570,9 +571,16 @@ process_nextops: int freeOpEnd[NCCL_MAX_LOCAL_RANKS]; for (int i=0; ilocalRanks; i++) freeOp[i] = -1; + uint64_t lastOpCount = 0; + int lastPeer = -1; + int count = 0; for (int opIndex = state->nextOps; opIndex != -1;) { struct ncclProxyOp* peerOp = pool->ops+opIndex; int peer = opIndex / MAX_OPS_PER_PEER; + if ((lastOpCount && peerOp->opCount != lastOpCount) || ((lastPeer != -1) && peer != lastPeer)) count++; + if (count == ncclParamProxyAppendBatchSize()+1) break; + lastOpCount = peerOp->opCount; + lastPeer = peer; if (peerOp->connection == NULL) return ncclInternalError; if (peerOp->next != -1) __builtin_prefetch(pool->ops+peerOp->next); NCCLCHECK(ProxyAppend(state, peerOp)); @@ -676,7 +684,7 @@ void* ncclProxyProgress(void *comm_) { int lastIdle = 0; struct ncclProxyArgs profArgs; // Only used for profiling purposes - while (state->stop == 0 && *comm->abortFlag == 0) { + while ((state->stop == false || (state->stop == true && state->active)) && *comm->abortFlag == 0) { int idle = 1; ncclResult_t ret = progressOps(comm, state, state->active, &idle); if (ret != ncclSuccess) { @@ -686,18 +694,17 @@ void* ncclProxyProgress(void *comm_) { } if (lastIdle == 0 && idle == 1) ncclProfilingRecord(&profArgs, 0, 0, ncclProxyProfileIdle); if (lastIdle == 1 && idle == 0) ncclProfilingRecord(&profArgs, 0, 0, ncclProxyProfileActive); - if (idle) { - int added = 0; - TIME_START(3); + int added = 0; + TIME_START(3); + if (state->stop == false) ret = ncclProxyGetPostedOps(comm, &added); - if (added) { TIME_STOP(3); } else { TIME_CANCEL(3); } - if (ret != ncclSuccess) { - (void) ncclCommSetAsyncError(comm, ret); - INFO(NCCL_ALL,"%s:%d -> %d [Proxy Thread]", __FILE__, __LINE__, ret); - } - if (added == 0) { - sched_yield(); // No request progressed. Let others run. - } + if (added) { TIME_STOP(3); } else { TIME_CANCEL(3); } + if (ret != ncclSuccess) { + (void) ncclCommSetAsyncError(comm, ret); + INFO(NCCL_ALL,"%s:%d -> %d [Proxy Thread]", __FILE__, __LINE__, ret); + } + if (added == 0) { + sched_yield(); // No request progressed. Let others run. } lastIdle = idle; } @@ -814,7 +821,7 @@ static ncclResult_t ncclProxyFreeConnections(struct ncclProxyConnectionPool* poo int max = b == pool->banks-1 ? pool->offset : NCCL_PROXY_CONN_POOL_SIZE; for (int i=0; ipools[b]+i; - if (connection->initFlag == true) { + if (connection->state != connUninitialized) { NCCLCHECK(proxyFree(connection, comm)); } } @@ -827,6 +834,10 @@ static ncclResult_t ncclProxyFreeConnections(struct ncclProxyConnectionPool* poo #include "transport.h" ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, int rank, struct ncclProxyConnector* proxyConn) { + struct ncclSocket* sock; + int ready; + int type = ncclProxyMsgInit; + // Keep one connection per mlocal rank proxyConn->connection = NULL; proxyConn->rank = rank; @@ -834,17 +845,18 @@ ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, in NCCLCHECK(ncclCalloc(&comm->proxyState.peerSocks, comm->localRanks)); NCCLCHECK(ncclCalloc(&comm->proxyState.proxyOps, comm->localRanks)); NCCLCHECK(ncclCalloc(&comm->proxyState.sharedDevMems, comm->localRanks)); - for (int r=0; rlocalRanks; r++) { - NCCLCHECK(ncclSocketInit(&comm->proxyState.peerSocks[r], NULL, comm->abortFlag, 0)); + for (int i = 0; i < comm->localRanks; ++i) { + NCCLCHECK(ncclSocketSetFd(-1, &comm->proxyState.peerSocks[i])); } } + NCCLCHECK(ncclTopoGetLocalRank(comm->topo, rank, &proxyConn->localRank)); - struct ncclSocket* sock = comm->proxyState.peerSocks+proxyConn->localRank; - if (sock->fd == -1) { - memcpy(&sock->addr, comm->proxyState.peerAddresses+rank, sizeof(union ncclSocketAddress)); + sock = comm->proxyState.peerSocks + proxyConn->localRank; + NCCLCHECK(ncclSocketReady(sock, &ready)); + if (!ready) { + NCCLCHECK(ncclSocketInit(sock, comm->proxyState.peerAddresses+rank, comm->magic, ncclSocketTypeProxy, comm->abortFlag)); NCCLCHECK(ncclSocketConnect(sock)); } - int type = ncclProxyMsgInit; NCCLCHECK(ncclSocketSend(sock, &type, sizeof(int))); NCCLCHECK(ncclSocketSend(sock, &transport, sizeof(int))); NCCLCHECK(ncclSocketSend(sock, &send, sizeof(int))); @@ -857,7 +869,7 @@ ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, in NCCLCHECK(ncclSocketRecv(sock, poolPath+sizeof("/dev/shm/nccl-")-1, sizeof("XXXXXX")-1)); struct ncclProxyOps* proxyOps = comm->proxyState.proxyOps+proxyConn->localRank; if (proxyOps->pool == NULL) { - NCCLCHECK(ncclShmOpen(poolPath, sizeof(struct ncclProxyOpsPool), (void**)(&proxyOps->pool), NULL, 0)); + NCCLCHECK(ncclShmOpen(poolPath, sizeof(struct ncclProxyOpsPool), (void**)(&proxyOps->pool), NULL, -1, &proxyOps->handle)); proxyOps->nextOps = proxyOps->nextOpsEnd = proxyOps->freeOp = -1; } } @@ -868,11 +880,12 @@ ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, in const char* ncclProxyMsgTypeStr[] = { "Unknown", "Init", "SharedInit", "Setup", "Connect", "Start", "Close", "Abort", "Stop" }; ncclResult_t ncclProxyCall(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize) { - if (proxyConn->comm->proxyState.peerSocks == NULL) return ncclInternalError; - struct ncclSocket* sock = proxyConn->comm->proxyState.peerSocks+proxyConn->localRank; - if (sock->fd == -1) return ncclInternalError; - ncclResult_t ret; + struct ncclSocket* sock; + ncclResult_t ret = ncclSuccess; + if (proxyConn->comm->proxyState.peerSocks == NULL) return ncclInternalError; + sock = proxyConn->comm->proxyState.peerSocks + proxyConn->localRank; + if (sock == NULL) return ncclInternalError; NCCLCHECKGOTO(ncclSocketSend(sock, &type, sizeof(int)), ret, error); NCCLCHECKGOTO(ncclSocketSend(sock, &proxyConn->connection, sizeof(void*)), ret, error); NCCLCHECKGOTO(ncclSocketSend(sock, &reqSize, sizeof(int)), ret, error); @@ -882,7 +895,6 @@ ncclResult_t ncclProxyCall(struct ncclProxyConnector* proxyConn, int type, void* return ncclSuccess; error: WARN("Proxy Call to rank %d failed (%s)", proxyConn->comm->localRankToRank[proxyConn->localRank], ncclProxyMsgTypeStr[type]); - sock->fd = -1; return ret; } @@ -892,16 +904,15 @@ static ncclResult_t proxyProgressInit(struct ncclComm* comm) { int size = sizeof(struct ncclProxyOpsPool); struct ncclProxyOpsPool* pool = NULL; - char shmPath[sizeof("/dev/shm/nccl-XXXXXX")]; - shmPath[0] = '\0'; - NCCLCHECK(ncclShmOpen(shmPath, size, (void**)&pool, NULL, 1)); - - // Init pool - pool->nextOps = -1; - // The service thread may be launched already but localRanks may not be set yet. while (comm->localRanks == 0) sched_yield(); + char shmPath[sizeof("/dev/shm/nccl-XXXXXX")]; + shmPath[0] = '\0'; + NCCLCHECK(ncclShmOpen(shmPath, size, (void**)&pool, NULL, comm->localRanks + 1, &state->handle)); + // Init pool + pool->nextOps = -1; + for (int r=0; rlocalRanks; r++) { pool->freeOps[r] = r*MAX_OPS_PER_PEER; for (int i=0; iops[r*MAX_OPS_PER_PEER+i].next = r*MAX_OPS_PER_PEER+i+1; @@ -928,7 +939,7 @@ static ncclResult_t proxyProgressInit(struct ncclComm* comm) { static void proxyOpsFree(struct ncclComm* comm) { struct ncclProxyProgressState* state = &comm->proxyState.progressState; - if (ncclShmClose(state->opsPool, NULL, sizeof(struct ncclProxyOpsPool)) != ncclSuccess) { + if (ncclShmClose(state->handle) != ncclSuccess) { WARN("[Service thread] shm close failed"); } } @@ -937,10 +948,8 @@ ncclResult_t ncclProxyShmUnlink(struct ncclComm* comm) { struct ncclProxyProgressState* state = &comm->proxyState.progressState; if (state->opsPool == NULL) return ncclSuccess; - char shmPath[] = "/dev/shm/nccl-XXXXXX"; - memcpy(shmPath+sizeof("/dev/shm/nccl-")-1, state->opsPoolShmSuffix, sizeof("XXXXXX")-1); - if (ncclShmUnlink(shmPath) != ncclSuccess) { - WARN("[Service thread] shm unlink failed"); + if (ncclShmUnlink(state->handle) != ncclSuccess) { + WARN("[Service thread] proxy ops shm unlink failed"); } return ncclSuccess; } @@ -965,7 +974,7 @@ static ncclResult_t proxyConnInit(struct ncclProxyLocalPeer* peer, struct ncclPr NCCLCHECK(ncclSocketSend(sock, state->opsPoolShmSuffix, sizeof("XXXXXX")-1)); } INFO(NCCL_NET, "New proxy %s connection %d from local rank %d, transport %d", connection->send ? "send":"recv", id, connection->localRank, connection->transport); - __atomic_store_n(&connection->initFlag, true, __ATOMIC_RELEASE); + __atomic_store_n(&connection->state, connInitialized, __ATOMIC_RELEASE); return ncclSuccess; } @@ -980,6 +989,7 @@ static ncclResult_t proxyConnSharedInit(struct ncclProxyLocalPeer* peer, struct int nChannels; NCCLCHECK(ncclSocketRecv(sock, &nChannels, sizeof(int))); if (connection->tcomm->proxySharedInit) NCCLCHECK(connection->tcomm->proxySharedInit(connection, comm, nChannels)); + __atomic_store_n(&connection->state, connSharedInitialized, __ATOMIC_RELEASE); return ncclSuccess; } @@ -991,14 +1001,29 @@ static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclC NCCLCHECK(op->connection->tcomm->proxyConnect(op->connection, comm, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done)); } else return ncclInternalError; if (done) { - if (op->respSize) NCCLCHECK(ncclSocketSend(op->connection->sock, op->respBuff, op->respSize)); - if (op->reqBuff) free(op->reqBuff); - if (op->respBuff) free(op->respBuff); - op->reqBuff = NULL; - op->respBuff = NULL; + if (op->type == ncclProxyMsgSetup) + __atomic_store_n(&op->connection->state, connSetupDone, __ATOMIC_RELEASE); + else if (op->type == ncclProxyMsgConnect) + __atomic_store_n(&op->connection->state, connConnected, __ATOMIC_RELEASE); + /* if setup or connect is done, we should not return any error at this point since + * ncclSocketSend might already send the respBuff to the requester. If we still choose + * to abort and close the connection, it can cause segfault if the requester is using + * the respBuff. */ + if (op->respSize) ncclSocketSend(op->connection->sock, op->respBuff, op->respSize); + if (op->reqBuff) { + free(op->reqBuff); + op->reqBuff = NULL; + } + if (op->respBuff) { + free(op->respBuff); + op->respBuff = NULL; + } op->type = 0; (*asyncOpCount)--; + } else if (*comm->abortFlag != 0) { + return ncclInternalError; } + return ncclSuccess; } @@ -1042,36 +1067,52 @@ void* ncclProxyService(void* _args) { struct ncclProxyLocalPeer peers[NCCL_MAX_LOCAL_RANKS]; memset(&peers, 0, sizeof(struct ncclProxyLocalPeer)*NCCL_MAX_LOCAL_RANKS); for (int s=0; sabortFlag, 0); pollfds[s].fd = -1; pollfds[s].events = POLLHUP|POLLIN; } - pollfds[NCCL_MAX_LOCAL_RANKS].fd = comm->proxyState.listenSock->fd; + if (ncclSocketGetFd(comm->proxyState.listenSock, &pollfds[NCCL_MAX_LOCAL_RANKS].fd) != ncclSuccess) { + WARN("[Proxy Service] Get listenSock fd fails\n"); + return NULL; + }; pollfds[NCCL_MAX_LOCAL_RANKS].events = POLLIN; int maxnpeers = 0; int npeers = 0; int stop = 0; int asyncOpCount = 0; - while ((stop == 0 || (stop == 1 && npeers > 0)) && *comm->abortFlag == 0) { + while (stop == 0 || (stop == 1 && npeers > 0)) { + /* Even if local comm aborts, we cannot let proxy thread exit if we still have peer + * connections. Need to wait until all other related comms call abort and safely exit + * together, or we could face segmentation fault. */ + if (*comm->abortFlag != 0) stop = 1; /* never let proxy service thread blocks in poll, or it cannot receive abortFlag. */ - if (poll(pollfds, NCCL_MAX_LOCAL_RANKS+1, asyncOpCount ? 0 : 500) < 0) { - WARN("[Proxy Service] Poll failed: %s\n", strerror(errno)); + int ret; + do { + ret = poll(pollfds, NCCL_MAX_LOCAL_RANKS+1, asyncOpCount ? 0 : 500); + } while (ret < 0 && errno == EINTR); + if (ret < 0) { + WARN("[Proxy Service] Poll failed: %s", strerror(errno)); return NULL; } if (pollfds[NCCL_MAX_LOCAL_RANKS].revents) { int s = 0; - while (s < NCCL_MAX_LOCAL_RANKS && peers[s].sock.fd != -1) s++; + while (s < NCCL_MAX_LOCAL_RANKS && pollfds[s].fd >= 0) s++; if (s == NCCL_MAX_LOCAL_RANKS) { WARN("[Proxy service] Too many connections (%d max)", NCCL_MAX_LOCAL_RANKS); return NULL; } if (maxnpeers < s+1) maxnpeers = s+1; - struct ncclSocket* sock = &peers[s].sock; - if (ncclSocketAccept(sock, comm->proxyState.listenSock) != ncclSuccess) { + if (ncclSocketInit(&peers[s].sock) != ncclSuccess) { + WARN("[Service thread] Initialize peers[%d].sock fails\n", s); + return NULL; + } + if (ncclSocketAccept(&peers[s].sock, comm->proxyState.listenSock) != ncclSuccess) { WARN("[Service thread] Accept failed %s", strerror(errno)); } else { - pollfds[s].fd = sock->fd; + if (ncclSocketGetFd(&peers[s].sock, &pollfds[s].fd) != ncclSuccess) { + WARN("[Service thread] Get peers[%d].sock fd fails\n", s); + return NULL; + } npeers++; peers[s].localRank = -1; } @@ -1083,10 +1124,12 @@ void* ncclProxyService(void* _args) { int closeConn = 0; int type = 0; ncclResult_t res = ncclSuccess; + + if (pollfds[s].fd == -1) continue; if (op->type != 0) { res = proxyProgressAsync(op, comm, &asyncOpCount); type = op->type; - if (res != ncclSuccess) op->type = 0; + if (res != ncclSuccess) closeConn = 1; } else if (pollfds[s].revents & POLLIN) { int closed; if (ncclSocketTryRecv(sock, &type, sizeof(int), &closed) != ncclSuccess) { @@ -1120,26 +1163,31 @@ void* ncclProxyService(void* _args) { closeConn = 1; } if (closeConn) { - close(sock->fd); - sock->fd = pollfds[s].fd = -1; + ncclSocketClose(sock); + if (op->reqBuff) { + free(op->reqBuff); + op->reqBuff = NULL; + } + if (op->respBuff) { + free(op->respBuff); + op->respBuff = NULL; + } + op->type = 0; + pollfds[s].fd = -1; npeers--; } } } - /* wait until main thread flush all NCCL operations. */ - while (*comm->abortFlag != 0 && __atomic_load_n(&comm->proxyState.safeAbortFlag, __ATOMIC_ACQUIRE) == 0) - usleep(1000); // Wait for all operations to complete and stop progress thread before freeing any resource if (ncclProxyProgressDestroy(comm) != ncclSuccess) { WARN("[Proxy Service] proxyDestroy failed"); } for (int s=0; sproxyState.listenSock->fd); - free(comm->proxyState.listenSock); + ncclSocketClose(comm->proxyState.listenSock); proxyOpsFree(comm); return NULL; } @@ -1164,32 +1212,29 @@ ncclResult_t ncclProxyDestroy(struct ncclComm* comm) { if (state->peerAddresses) { if (*comm->abortFlag == 0) { struct ncclSocket sock; - sock.abortFlag = NULL; - sock.asyncFlag = 0; - memcpy(&sock.addr, comm->proxyState.peerAddresses+comm->rank, sizeof(union ncclSocketAddress)); - NCCLCHECK(ncclSocketConnect(&sock)); int type = ncclProxyMsgStop; + NCCLCHECK(ncclSocketInit(&sock, comm->proxyState.peerAddresses + comm->rank, comm->magic, ncclSocketTypeProxy, comm->abortFlag)); + NCCLCHECK(ncclSocketConnect(&sock)); NCCLCHECK(ncclSocketSend(&sock, &type, sizeof(int))); - close(sock.fd); - } else { - /* when abortFlag is set, all socket related communications are no longer reliable. We need to - * set a flag to let proxy thread exit. */ - __atomic_store_n(&state->safeAbortFlag, 1, __ATOMIC_RELEASE); + NCCLCHECK(ncclSocketClose(&sock)); } free(state->peerAddresses); } + if (state->peerSocks) { for (int i=0; ilocalRanks; i++) { - if (state->peerSocks[i].fd != -1) { + int fd; + NCCLCHECK(ncclSocketGetFd(state->peerSocks + i, &fd)); + if (fd >= 0) { if (state->proxyOps[i].pool) { - NCCLCHECK(ncclShmClose(state->proxyOps[i].pool, NULL, sizeof(struct ncclProxyOpsPool))); + NCCLCHECK(ncclShmClose(state->proxyOps[i].handle)); } if (state->sharedDevMems[i]) { CUDACHECK(cudaIpcCloseMemHandle(state->sharedDevMems[i])); } int type = ncclProxyMsgClose; - if (*comm->abortFlag == 0) NCCLCHECK(ncclSocketSend(state->peerSocks+i, &type, sizeof(int))); - close(state->peerSocks[i].fd); + if (*comm->abortFlag == 0) NCCLCHECK(ncclSocketSend(state->peerSocks + i, &type, sizeof(int))); + NCCLCHECK(ncclSocketClose(state->peerSocks + i)); } } free(state->peerSocks); diff --git a/src/transport.cc b/src/transport.cc index b3ca90d..66d8b51 100644 --- a/src/transport.cc +++ b/src/transport.cc @@ -67,12 +67,11 @@ void dumpData(struct ncclConnect* data, int ndata) { ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType/*=NULL*/) { // Stream used during transport setup; need for P2P pre-connect + CUDA Graph + ncclResult_t ret = ncclSuccess; int highestType = TRANSPORT_P2P; // track highest transport type - - cudaStream_t transportSetupStream; - CUDACHECK(cudaStreamCreateWithFlags(&transportSetupStream, cudaStreamNonBlocking)); - struct ncclConnect data[2*MAXCHANNELS]; + + NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->hostStream), ret, fail); for (int i=1; inRanks; i++) { int bootstrapTag = (i<<8) + (graph ? graph->id+1 : 0); int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks; @@ -86,7 +85,7 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* TIME_START(0); for (int c=0; c(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex, &type)); + NCCLCHECKGOTO(selectTransport<0>(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex, &type), ret, fail); if (type > highestType) highestType = type; } } @@ -95,7 +94,7 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* struct ncclConnect* sendData = recvData+recvChannels; for (int c=0; c(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex, &type)); + NCCLCHECKGOTO(selectTransport<1>(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex, &type), ret, fail); if (type > highestType) highestType = type; } } @@ -104,16 +103,16 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* TIME_START(2); if (sendPeer == recvPeer) { if (recvChannels+sendChannels) { - NCCLCHECK(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels))); - NCCLCHECK(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels))); + NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); + NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); sendData = data; recvData = data+sendChannels; } } else { - if (recvChannels) NCCLCHECK(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels)); - if (sendChannels) NCCLCHECK(bootstrapSend(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels)); - if (sendChannels) NCCLCHECK(bootstrapRecv(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels)); - if (recvChannels) NCCLCHECK(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels)); + if (recvChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels), ret, fail); + if (sendChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels), ret, fail); + if (sendChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels), ret, fail); + if (recvChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels), ret, fail); } TIME_STOP(2); @@ -121,9 +120,9 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* for (int c=0; cchannels[c].peers[sendPeer].send + connIndex; - NCCLCHECK(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn)); + NCCLCHECKGOTO(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn), ret, fail); conn->connected = 1; - CUDACHECK(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, transportSetupStream)); + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); } } TIME_STOP(3); @@ -131,19 +130,23 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* for (int c=0; cchannels[c].peers[recvPeer].recv + connIndex; - NCCLCHECK(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn)); + NCCLCHECKGOTO(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn), ret, fail); conn->connected = 1; - CUDACHECK(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, transportSetupStream)); + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); } } TIME_STOP(4); comm->connectRecv[recvPeer] = comm->connectSend[sendPeer] = 0UL; } - CUDACHECK(cudaStreamSynchronize(transportSetupStream)); - CUDACHECK(cudaStreamDestroy(transportSetupStream)); + if (highestTransportType != NULL) *highestTransportType = highestType; TIME_PRINT("P2P Setup/Connect"); - return ncclSuccess; +exit: + NCCLCHECK(ncclStrongStreamWaitStream(ncclCudaGraphNone(), &comm->deviceStream, &comm->hostStream)); + NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->hostStream)); + return ret; +fail: + goto exit; } extern struct ncclTransport collNetTransport; diff --git a/src/transport/net.cc b/src/transport/net.cc index a3a1579..bdb2e2d 100644 --- a/src/transport/net.cc +++ b/src/transport/net.cc @@ -63,6 +63,7 @@ struct connectMapMem{ char shmPath[PATH_MAX]; cudaIpcMemHandle_t ipc; }; + ncclShmHandle_t handle; }; struct connectMap { @@ -224,13 +225,12 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph } static ncclResult_t netMapShm(struct connectMapMem* mem) { - NCCLCHECK(ncclShmOpen(mem->shmPath, mem->size, (void**)&mem->cpuPtr, (void**)&mem->gpuPtr, 0)); - NCCLCHECK(ncclShmUnlink(mem->shmPath)); + NCCLCHECK(ncclShmOpen(mem->shmPath, mem->size, (void**)&mem->cpuPtr, (void**)&mem->gpuPtr, -1, &mem->handle)); return ncclSuccess; } static ncclResult_t netCreateShm(struct connectMapMem* mem) { mem->shmPath[0] = '\0'; // Let ncclShmOpen create a tmp file - NCCLCHECK(ncclShmOpen(mem->shmPath, mem->size, (void**)&mem->cpuPtr, NULL, 1)); + NCCLCHECK(ncclShmOpen(mem->shmPath, mem->size, (void**)&mem->cpuPtr, NULL, 1, &mem->handle)); return ncclSuccess; } @@ -339,7 +339,7 @@ static ncclResult_t sendFree(struct ncclConnector* send) { struct connectMap* map = (struct connectMap*)(send->transportResources); if (map) { if (map->sameProcess == 0) { - NCCLCHECK(ncclShmClose(map->mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, map->mems[NCCL_NET_MAP_HOSTMEM].gpuPtr, map->mems[NCCL_NET_MAP_HOSTMEM].size)); + NCCLCHECK(ncclShmClose(map->mems[NCCL_NET_MAP_HOSTMEM].handle)); if (map->mems[NCCL_NET_MAP_DEVMEM].size) { CUDACHECK(cudaIpcCloseMemHandle(map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr)); } @@ -372,7 +372,7 @@ static ncclResult_t sharedBuffersInit(struct ncclComm* comm, int cuda, int local struct ncclProxySharedP2p* state = type == 0 ? &peer->send : &peer->recv; state->refcount++; if (state->size == 0) { - state->size = nChannels*NCCL_SHARED_STEPS*comm->p2pNetChunkSize; + state->size = nChannels*NCCL_SHARED_STEPS*comm->p2pChunkSize; } if (size) *size = state->size; @@ -399,7 +399,7 @@ static ncclResult_t sharedBuffersInit(struct ncclComm* comm, int cuda, int local static ncclResult_t sharedBuffersGet(struct ncclComm* comm, int channel, int slot, int* offset) { // Use different pools for different channels and also separate send/recv. int globalSlot = (channel*NCCL_SHARED_STEPS)+slot; - *offset = comm->p2pNetChunkSize * globalSlot; + *offset = comm->p2pChunkSize * globalSlot; return ncclSuccess; } @@ -523,6 +523,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm)); connection->proxyAppendPtr = &connection->proxyAppend; } + if (resources->netSendComm == NULL) { *done = 0; return ncclSuccess; @@ -657,6 +658,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm)); connection->proxyAppendPtr = &connection->proxyAppend; } + if (resources->netRecvComm == NULL) { *done = 0; return ncclSuccess; @@ -746,67 +748,75 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str static ncclResult_t sendProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { struct sendResources* resources = (struct sendResources*)(connection->transportResources); - if (resources == NULL) { // NVB Preconnect + if (connection->state == connSharedInitialized) { // NVB Preconnect NCCLCHECK(sharedBuffersDestroy(comm, connection->localRank, 0)); return ncclSuccess; } - for (int p=0; pbuffers[p]) { - NCCLCHECK(ncclNetDeregMr(comm, resources->netSendComm, resources->mhandles[p])); + + if (connection->state == connConnected) { + for (int p=0; pbuffers[p]) { + NCCLCHECK(ncclNetDeregMr(comm, resources->netSendComm, resources->mhandles[p])); + } } - } - struct connectMapMem* mems = resources->map.mems; - if (resources->map.sameProcess) { - NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr)); - } else { - NCCLCHECK(ncclShmClose(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, NULL, mems[NCCL_NET_MAP_HOSTMEM].size)); - } - CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); - if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); - if (resources->shared) { - NCCLCHECK(sharedBuffersDestroy(comm, resources->localRank, 0)); - if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { - struct ncclSharedNetComms* comms = comm->proxyState.progressState.netComms[resources->netDev]+resources->remoteRank; - comms->sendRefCount[resources->channelId]--; - if (comms->sendRefCount[resources->channelId] == 0) NCCLCHECK(ncclNetCloseSend(comm, comms->sendComm[resources->channelId])); + struct connectMapMem* mems = resources->map.mems; + if (resources->map.sameProcess) { + NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr)); + } else { + NCCLCHECK(ncclShmClose(mems[NCCL_NET_MAP_HOSTMEM].handle)); + } + CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); + if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); + if (resources->shared) { + NCCLCHECK(sharedBuffersDestroy(comm, resources->localRank, 0)); + if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { + struct ncclSharedNetComms* comms = comm->proxyState.progressState.netComms[resources->netDev]+resources->remoteRank; + comms->sendRefCount[resources->channelId]--; + if (comms->sendRefCount[resources->channelId] == 0) NCCLCHECK(ncclNetCloseSend(comm, comms->sendComm[resources->channelId])); + } else { + NCCLCHECK(ncclNetCloseSend(comm, resources->netSendComm)); + } } else { NCCLCHECK(ncclNetCloseSend(comm, resources->netSendComm)); } - } else { - NCCLCHECK(ncclNetCloseSend(comm, resources->netSendComm)); } - free(resources); + + if (connection->state == connSetupDone) free(resources); return ncclSuccess; } static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct ncclComm* comm) { struct recvResources* resources = (struct recvResources*)(connection->transportResources); - if (resources == NULL) { // NVB Preconnect + if (connection->state == connSharedInitialized) { // NVB Preconnect NCCLCHECK(sharedBuffersDestroy(comm, connection->localRank, 1)); return ncclSuccess; } - for (int p=0; pbuffers[p]) { - NCCLCHECK(ncclNetDeregMr(comm, resources->netRecvComm, resources->mhandles[p])); + + if (connection->state == connConnected) { + for (int p=0; pbuffers[p]) { + NCCLCHECK(ncclNetDeregMr(comm, resources->netRecvComm, resources->mhandles[p])); + } } - } - struct connectMapMem* mems = resources->map.mems; - NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr)); - CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); - if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); - if (resources->shared) { - NCCLCHECK(sharedBuffersDestroy(comm, resources->localRank, 1)); - if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { - struct ncclSharedNetComms* comms = comm->proxyState.progressState.netComms[resources->netDev]+resources->proxyRank; - comms->recvRefCount[resources->channelId]--; - if (comms->recvRefCount[resources->channelId] == 0) NCCLCHECK(ncclNetCloseRecv(comm, comms->recvComm[resources->channelId])); + struct connectMapMem* mems = resources->map.mems; + NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr)); + CUDACHECK(cudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr)); + if (mems[NCCL_NET_MAP_GDCMEM].cpuPtr) NCCLCHECK(ncclGdrCudaFree(resources->gdrDesc)); + if (resources->shared) { + NCCLCHECK(sharedBuffersDestroy(comm, resources->localRank, 1)); + if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { + struct ncclSharedNetComms* comms = comm->proxyState.progressState.netComms[resources->netDev]+resources->proxyRank; + comms->recvRefCount[resources->channelId]--; + if (comms->recvRefCount[resources->channelId] == 0) NCCLCHECK(ncclNetCloseRecv(comm, comms->recvComm[resources->channelId])); + } else { + NCCLCHECK(ncclNetCloseRecv(comm, resources->netRecvComm)); + } } else { NCCLCHECK(ncclNetCloseRecv(comm, resources->netRecvComm)); } - } else { - NCCLCHECK(ncclNetCloseRecv(comm, resources->netRecvComm)); } - free(resources); + + if (connection->state == connSetupDone) free(resources); return ncclSuccess; } diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc index a1f8897..8818554 100644 --- a/src/transport/net_ib.cc +++ b/src/transport/net_ib.cc @@ -57,6 +57,7 @@ struct alignas(64) ncclIbDev { int realPort; int maxQp; struct ncclIbMrCache mrCache; + int ar; // ADAPTIVE_ROUTING }; #define MAX_IB_PORT 15 @@ -80,6 +81,7 @@ NCCL_PARAM(IbSl, "IB_SL", 0); NCCL_PARAM(IbTc, "IB_TC", 0); NCCL_PARAM(IbArThreshold, "IB_AR_THRESHOLD", 8192); NCCL_PARAM(IbPciRelaxedOrdering, "IB_PCI_RELAXED_ORDERING", 2); +NCCL_PARAM(IbAdaptiveRouting, "IB_ADAPTIVE_ROUTING", -2); pthread_t ncclIbAsyncThread; static void* ncclIbAsyncThreadMain(void* args) { @@ -221,6 +223,11 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { ncclIbDevs[ncclNIbDevs].mrCache.population = 0; ncclIbDevs[ncclNIbDevs].mrCache.slots = NULL; + // Enable ADAPTIVE_ROUTING by default on IB networks + // But allow it to be overloaded by an env parameter + ncclIbDevs[ncclNIbDevs].ar = (portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND) ? 1 : 0; + if (ncclParamIbAdaptiveRouting() != -2) ncclIbDevs[ncclNIbDevs].ar = ncclParamIbAdaptiveRouting(); + pthread_create(&ncclIbAsyncThread, NULL, ncclIbAsyncThreadMain, context); ncclSetThreadName(ncclIbAsyncThread, "NCCL IbAsync %2d", ncclNIbDevs); pthread_detach(ncclIbAsyncThread); // will not be pthread_join()'d @@ -298,11 +305,6 @@ failure: return ncclSystemError; } -static ncclResult_t GetSocketAddr(union ncclSocketAddress* addr) { - memcpy(addr, &ncclIbIfAddr, sizeof(*addr)); - return ncclSuccess; -} - #define NCCL_NET_IB_MAX_RECVS 8 ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) { @@ -364,6 +366,7 @@ struct ncclIbCommStage { struct ncclIbHandle { union ncclSocketAddress connectAddr; // Filled by the target + uint64_t magic; // random number to help debugging struct ncclIbCommStage stage; // Used by the other side when connecting }; @@ -376,7 +379,7 @@ struct ncclIbRequest { struct ncclIbVerbs* verbs; int type; int events; - union ncclSocketAddress *addr; + struct ncclSocket* sock; int nreqs; union { struct { @@ -427,6 +430,7 @@ struct ncclIbSendComm { struct ibv_qp* qps[NCCL_IB_MAX_QPS]; int nqps; struct ibv_mr* fifoMr; + int ar; }; // The SendFifo needs to be 32-byte aligned and each element needs // to be a 32-byte multiple, so that an entry does not get split and @@ -571,19 +575,19 @@ ncclResult_t ncclIbListen(int dev, void* opaqueHandle, void** listenComm) { static_assert(sizeof(struct ncclIbHandle) < NCCL_NET_HANDLE_MAXSIZE, "ncclIbHandle size too large"); memset(handle, 0, sizeof(struct ncclIbHandle)); comm->dev = dev; - comm->sock.asyncFlag = 1; /* nonblocking socket is required by network communication. */ - NCCLCHECK(GetSocketAddr(&comm->sock.addr)); + handle->magic = NCCL_SOCKET_MAGIC; + NCCLCHECK(ncclSocketInit(&comm->sock, &ncclIbIfAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1)); NCCLCHECK(ncclSocketListen(&comm->sock)); - memcpy(&handle->connectAddr, &comm->sock.addr, sizeof(union ncclSocketAddress)); + NCCLCHECK(ncclSocketGetAddr(&comm->sock, &handle->connectAddr)); *listenComm = comm; return ncclSuccess; } ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) { struct ncclIbHandle* handle = (struct ncclIbHandle*) opaqueHandle; - enum ncclSocketState conState; struct ncclIbCommStage* stage = &handle->stage; struct ncclIbSendComm* comm = (struct ncclIbSendComm*)stage->comm; + int ready; *sendComm = NULL; if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; @@ -594,20 +598,15 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) { } NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(struct ncclIbSendComm))); - NCCLCHECK(ncclSocketInit(&comm->sock, &handle->connectAddr, NULL, 1)); + NCCLCHECK(ncclSocketInit(&comm->sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1)); stage->comm = comm; stage->state = ncclIbCommStateConnect; NCCLCHECK(ncclSocketConnect(&comm->sock)); ib_connect_check: /* since ncclSocketConnect is async, we must check if connection is complete */ - NCCLCHECK(ncclGetSocketState(&comm->sock, &conState)); - if (conState == ncclSocketConnecting) { - /* expect user to call again */ - return ncclSuccess; - } else if (conState == ncclSocketError) { - return ncclRemoteError; - } + NCCLCHECK(ncclSocketReady(&comm->sock, &ready)); + if (!ready) return ncclSuccess; // IB Setup struct ibv_context* ctx; @@ -619,6 +618,7 @@ ib_connect_check: for (int q=0; qnqps; q++) { NCCLCHECK(ncclIbCreateQp(ib_port, &comm->verbs, IBV_ACCESS_REMOTE_WRITE, comm->qps+q)); } + comm->ar = ncclIbDevs[dev].ar; // ADAPTIVE_ROUTING // Send my QP Info to receiver through the socket. Hope this won't block. struct ibv_port_attr portAttr; @@ -670,9 +670,10 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) { struct ncclIbListenComm* lComm = (struct ncclIbListenComm*)listenComm; struct ncclIbCommStage* stage = &lComm->stage; struct ncclIbRecvComm* rComm = (struct ncclIbRecvComm*)stage->comm; + int ready; *recvComm = NULL; - if (stage->state == ncclIbCommStateAccept) goto ib_accept; + if (stage->state == ncclIbCommStateAccept) goto ib_accept_check; if (stage->state == ncclIbCommStateRecv) goto ib_recv; if (stage->state == ncclIbCommStateSend) goto ib_send; if (stage->state != ncclIbCommStateStart) { @@ -683,12 +684,12 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) { NCCLCHECK(ncclIbMalloc((void**)&rComm, sizeof(struct ncclIbRecvComm))); stage->comm = rComm; stage->state = ncclIbCommStateAccept; - NCCLCHECK(ncclSocketInit(&rComm->sock, NULL, lComm->sock.abortFlag, 1)); - -ib_accept: + NCCLCHECK(ncclSocketInit(&rComm->sock)); NCCLCHECK(ncclSocketAccept(&rComm->sock, &lComm->sock)); - if (rComm->sock.fd == -1) - return ncclSuccess; + +ib_accept_check: + NCCLCHECK(ncclSocketReady(&rComm->sock, &ready)); + if (!ready) return ncclSuccess; struct ncclIbQpInfo remQpInfo; stage->state = ncclIbCommStateRecv; @@ -791,7 +792,7 @@ ncclResult_t ncclIbGetRequest(struct ncclIbVerbs* verbs, struct ncclIbRequest** if (r->type == NCCL_NET_IB_REQ_UNUSED) { r->verbs = verbs; r->events = 1; - r->addr = NULL; + r->sock = NULL; *req = r; return ncclSuccess; } @@ -966,8 +967,8 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { } struct ibv_send_wr* lastWr = comm->wrs+nreqs-1; - if (nreqs > 1 || reqs[0]->send.size > ncclParamIbArThreshold()) { - // When using adaptive routing, send the bulk of the data first as an + if (nreqs > 1 || (comm->ar && reqs[0]->send.size > ncclParamIbArThreshold())) { + // When using ADAPTIVE_ROUTING, send the bulk of the data first as an // RDMA_WRITE, then a 0-byte RDMA_WRITE_WITH_IMM to trigger a remote // completion. lastWr++; @@ -1033,28 +1034,31 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh // Sanity checks to catch user collective call count/size mismatches if (size > slots[r].size) { - char line[SOCKET_NAME_MAXLEN+1]; + char line[SOCKET_NAME_MAXLEN + 1]; + union ncclSocketAddress addr; + ncclSocketGetAddr(&comm->sock, &addr); WARN("NET/IB : req %d/%d tag %x peer %s collective mismatch error, local size %d remote size %d", - r, nreqs, tag, ncclSocketToString(&comm->sock.addr, line), size, slots[r].size); + r, nreqs, tag, ncclSocketToString(&addr, line), size, slots[r].size); return ncclInvalidUsage; } // plus any potential programming errors else if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkey == 0) { - char line[SOCKET_NAME_MAXLEN+1]; - WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %d addr %lx rkey %x", - r, nreqs, tag, ncclSocketToString(&comm->sock.addr, line), slots[r].size, slots[r].addr, slots[r].rkey); + char line[SOCKET_NAME_MAXLEN + 1]; + union ncclSocketAddress addr; + ncclSocketGetAddr(&comm->sock, &addr); + WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %d addr %lx rkey %x", + r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkey); return ncclInternalError; } struct ncclIbRequest* req; NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); req->type = NCCL_NET_IB_REQ_SEND; - req->addr = &comm->sock.addr; + req->sock = &comm->sock; req->verbs = &comm->verbs; req->nreqs = nreqs; req->send.size = size; req->send.data = data; req->send.lkey = mr->lkey; req->send.offset = 0; - req->addr = &comm->sock.addr; req->events = comm->nqps; *request = reqs[r] = req; @@ -1147,7 +1151,7 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* ta struct ncclIbRequest* req; NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); req->type = NCCL_NET_IB_REQ_RECV; - req->addr = &comm->sock.addr; + req->sock = &comm->sock; req->nreqs = n; for (int i=0; irecv.sizes[i] = 0; @@ -1186,7 +1190,7 @@ ncclResult_t ncclIbIflush(void* recvComm, int n, void** data, int* sizes, void** struct ncclIbRequest* req; NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); req->type = NCCL_NET_IB_REQ_FLUSH; - req->addr = &comm->sock.addr; + req->sock = &comm->sock; struct ibv_mr* mr = (struct ibv_mr*)mhandles[last]; struct ibv_send_wr wr; @@ -1234,8 +1238,10 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { struct ibv_wc *wc = wcs+w; if (wc->status != IBV_WC_SUCCESS) { char line[SOCKET_NAME_MAXLEN+1]; + union ncclSocketAddress addr; + ncclSocketGetAddr(r->sock, &addr); WARN("NET/IB : Got completion from peer %s with error %d, opcode %d, len %d, vendor err %d", - ncclSocketToString(r->addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err); + ncclSocketToString(&addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err); return ncclRemoteError; } @@ -1267,7 +1273,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { ncclResult_t ncclIbCloseSend(void* sendComm) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; if (comm) { - close(comm->sock.fd); + NCCLCHECK(ncclSocketClose(&comm->sock)); for (int q=0; qnqps; q++) if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q])); if (comm->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->fifoMr)); @@ -1281,7 +1287,7 @@ ncclResult_t ncclIbCloseSend(void* sendComm) { ncclResult_t ncclIbCloseRecv(void* recvComm) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; if (comm) { - close(comm->sock.fd); + NCCLCHECK(ncclSocketClose(&comm->sock)); for (int q=0; qnqps; q++) if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q])); if (comm->gpuFlush.enabled) { @@ -1298,7 +1304,7 @@ ncclResult_t ncclIbCloseRecv(void* recvComm) { ncclResult_t ncclIbCloseListen(void* listenComm) { struct ncclIbListenComm* comm = (struct ncclIbListenComm*)listenComm; if (comm) { - close(comm->sock.fd); + NCCLCHECK(ncclSocketClose(&comm->sock)); free(comm); } return ncclSuccess; diff --git a/src/transport/net_socket.cc b/src/transport/net_socket.cc index 678aab8..08a8c3a 100644 --- a/src/transport/net_socket.cc +++ b/src/transport/net_socket.cc @@ -18,16 +18,16 @@ /* Init functions */ static int ncclNetIfs = -1; -struct ncclSocketDev { +struct ncclNetSocketDev { union ncclSocketAddress addr; char devName[MAX_IF_NAME_SIZE]; char* pciPath; }; -static struct ncclSocketDev ncclSocketDevs[MAX_IFS]; +static struct ncclNetSocketDev ncclNetSocketDevs[MAX_IFS]; -pthread_mutex_t ncclSocketLock = PTHREAD_MUTEX_INITIALIZER; +pthread_mutex_t ncclNetSocketLock = PTHREAD_MUTEX_INITIALIZER; -static ncclResult_t ncclSocketGetPciPath(char* devName, char** pciPath) { +static ncclResult_t ncclNetSocketGetPciPath(char* devName, char** pciPath) { char devicePath[PATH_MAX]; snprintf(devicePath, PATH_MAX, "/sys/class/net/%s/device", devName); // May return NULL if the file doesn't exist. @@ -35,9 +35,9 @@ static ncclResult_t ncclSocketGetPciPath(char* devName, char** pciPath) { return ncclSuccess; } -ncclResult_t ncclSocketInit(ncclDebugLogger_t logFunction) { +ncclResult_t ncclNetSocketInit(ncclDebugLogger_t logFunction) { if (ncclNetIfs == -1) { - pthread_mutex_lock(&ncclSocketLock); + pthread_mutex_lock(&ncclNetSocketLock); if (ncclNetIfs == -1) { char names[MAX_IF_NAME_SIZE*MAX_IFS]; union ncclSocketAddress addrs[MAX_IFS]; @@ -52,9 +52,9 @@ ncclResult_t ncclSocketInit(ncclDebugLogger_t logFunction) { line[0] = '\0'; addrline[SOCKET_NAME_MAXLEN] = '\0'; for (int i=0; iname = ncclSocketDevs[dev].devName; - props->pciPath = ncclSocketDevs[dev].pciPath; +ncclResult_t ncclNetSocketGetProperties(int dev, ncclNetProperties_t* props) { + props->name = ncclNetSocketDevs[dev].devName; + props->pciPath = ncclNetSocketDevs[dev].pciPath; props->guid = dev; props->ptrSupport = NCCL_PTR_HOST; - NCCLCHECK(ncclSocketGetSpeed(props->name, &props->speed)); + NCCLCHECK(ncclNetSocketGetSpeed(props->name, &props->speed)); props->latency = 0; // Not set props->port = 0; props->maxComms = 65536; @@ -104,12 +104,6 @@ ncclResult_t ncclSocketGetProperties(int dev, ncclNetProperties_t* props) { return ncclSuccess; } -ncclResult_t GetSocketAddr(int dev, union ncclSocketAddress* addr) { - if (dev >= ncclNetIfs) return ncclInternalError; - memcpy(addr, &ncclSocketDevs[dev].addr, sizeof(*addr)); - return ncclSuccess; -} - /* Communication functions */ #define MAX_SOCKETS 64 @@ -120,29 +114,30 @@ ncclResult_t GetSocketAddr(int dev, union ncclSocketAddress* addr) { NCCL_PARAM(SocketNsocksPerThread, "NSOCKS_PERTHREAD", -2); NCCL_PARAM(SocketNthreads, "SOCKET_NTHREADS", -2); -enum ncclSocketCommState { - ncclSocketCommStateStart = 0, - ncclSocketCommStateConnect = 1, - ncclSocketCommStateAccept = 3, - ncclSocketCommStateSend = 4, - ncclSocketCommStateRecv = 5, +enum ncclNetSocketCommState { + ncclNetSocketCommStateStart = 0, + ncclNetSocketCommStateConnect = 1, + ncclNetSocketCommStateAccept = 3, + ncclNetSocketCommStateSend = 4, + ncclNetSocketCommStateRecv = 5, }; -struct ncclSocketCommStage { - enum ncclSocketCommState state; +struct ncclNetSocketCommStage { + enum ncclNetSocketCommState state; uint8_t iteration; struct ncclSocket* sock; - struct ncclSocketComm* comm; + struct ncclNetSocketComm* comm; }; -struct ncclSocketHandle { +struct ncclNetSocketHandle { union ncclSocketAddress connectAddr; + uint64_t magic; // random number to help debugging int nSocks; int nThreads; - struct ncclSocketCommStage stage; + struct ncclNetSocketCommStage stage; }; -struct ncclSocketTask { +struct ncclNetSocketTask { int op; void* data; int size; @@ -152,41 +147,41 @@ struct ncclSocketTask { ncclResult_t result; }; -struct ncclSocketRequest { +struct ncclNetSocketRequest { int op; void* data; int size; struct ncclSocket* ctrlSock; int offset; int used; - struct ncclSocketComm* comm; - struct ncclSocketTask* tasks[MAX_SOCKETS]; + struct ncclNetSocketComm* comm; + struct ncclNetSocketTask* tasks[MAX_SOCKETS]; int nSubs; }; -struct ncclSocketTaskQueue { +struct ncclNetSocketTaskQueue { int next; int len; - struct ncclSocketTask* tasks; + struct ncclNetSocketTask* tasks; }; -struct ncclSocketThreadResources { - struct ncclSocketTaskQueue threadTaskQueue; +struct ncclNetSocketThreadResources { + struct ncclNetSocketTaskQueue threadTaskQueue; int stop; - struct ncclSocketComm* comm; + struct ncclNetSocketComm* comm; pthread_mutex_t threadLock; pthread_cond_t threadCond; }; -struct ncclSocketListenComm { +struct ncclNetSocketListenComm { struct ncclSocket sock; - struct ncclSocketCommStage stage; + struct ncclNetSocketCommStage stage; int nSocks; int nThreads; int dev; }; -struct ncclSocketComm { +struct ncclNetSocketComm { struct ncclSocket ctrlSock; struct ncclSocket socks[MAX_SOCKETS]; int dev; @@ -194,15 +189,15 @@ struct ncclSocketComm { int nSocks; int nThreads; int nextSock; - struct ncclSocketRequest requests[MAX_REQUESTS]; + struct ncclNetSocketRequest requests[MAX_REQUESTS]; pthread_t helperThread[MAX_THREADS]; - struct ncclSocketThreadResources threadResources[MAX_THREADS]; + struct ncclNetSocketThreadResources threadResources[MAX_THREADS]; }; void* persistentSocketThread(void *args_) { - struct ncclSocketThreadResources* resource = (struct ncclSocketThreadResources*)args_; - struct ncclSocketComm* comm = resource->comm; - struct ncclSocketTaskQueue* myQueue = &resource->threadTaskQueue; + struct ncclNetSocketThreadResources* resource = (struct ncclNetSocketThreadResources*)args_; + struct ncclNetSocketComm* comm = resource->comm; + struct ncclNetSocketTaskQueue* myQueue = &resource->threadTaskQueue; int nSocksPerThread = comm->nSocks / comm->nThreads; while (1) { int idle = 1; @@ -212,7 +207,7 @@ void* persistentSocketThread(void *args_) { do { repeat = 0; for (int j=0; jtasks+i+j; + struct ncclNetSocketTask* r = myQueue->tasks+i+j; if (r != NULL && r->used == 1 && r->offset < r->size) { r->result = ncclSocketProgress(r->op, r->sock, r->data, r->size, &r->offset); if (r->result != ncclSuccess) { @@ -236,7 +231,7 @@ void* persistentSocketThread(void *args_) { } } -ncclResult_t ncclSocketGetNsockNthread(int dev, int* ns, int* nt) { +ncclResult_t ncclNetSocketGetNsockNthread(int dev, int* ns, int* nt) { int nSocksPerThread = ncclParamSocketNsocksPerThread(); int nThreads = ncclParamSocketNthreads(); if (nThreads > MAX_THREADS) { @@ -247,7 +242,7 @@ ncclResult_t ncclSocketGetNsockNthread(int dev, int* ns, int* nt) { // Auto-detection int autoNt=0, autoNs=1; // By default, we only use the main thread and do not spawn extra threads char vendorPath[PATH_MAX]; - snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", ncclSocketDevs[dev].devName); + snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", ncclNetSocketDevs[dev].devName); char* rPath = realpath(vendorPath, NULL); int fd = open(rPath, O_RDONLY); free(rPath); @@ -285,36 +280,20 @@ end: return ncclSuccess; } -ncclResult_t ncclSocketNewListenComm(struct ncclSocketListenComm** comm) { - NCCLCHECK(ncclCalloc(comm, 1)); - (*comm)->sock.fd = -1; - return ncclSuccess; -} - -ncclResult_t ncclSocketNewComm(struct ncclSocketComm** comm) { - NCCLCHECK(ncclCalloc(comm, 1)); - (*comm)->ctrlSock.fd = -1; - for (int i=0; i < MAX_SOCKETS; i++) { - (*comm)->socks[i].fd = -1; - } - (*comm)->nextSock = 0; - return ncclSuccess; -} - -ncclResult_t ncclSocketListen(int dev, void* opaqueHandle, void** listenComm) { - if (dev < 0) { // data transfer socket is based on specified dev +ncclResult_t ncclNetSocketListen(int dev, void* opaqueHandle, void** listenComm) { + if (dev < 0 || dev >= ncclNetIfs) { // data transfer socket is based on specified dev return ncclInternalError; } - struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle; - memset(handle, 0, sizeof(struct ncclSocketHandle)); - static_assert(sizeof(struct ncclSocketHandle) <= NCCL_NET_HANDLE_MAXSIZE, "ncclSocketHandle size too large"); - struct ncclSocketListenComm* comm; - NCCLCHECK(ncclSocketNewListenComm(&comm)); - NCCLCHECK(GetSocketAddr(dev, &comm->sock.addr)); - comm->sock.asyncFlag = 1; + struct ncclNetSocketHandle* handle = (struct ncclNetSocketHandle*) opaqueHandle; + memset(handle, 0, sizeof(struct ncclNetSocketHandle)); + static_assert(sizeof(struct ncclNetSocketHandle) <= NCCL_NET_HANDLE_MAXSIZE, "ncclNetSocketHandle size too large"); + struct ncclNetSocketListenComm* comm; + NCCLCHECK(ncclCalloc(&comm, 1)); + handle->magic = NCCL_SOCKET_MAGIC; + NCCLCHECK(ncclSocketInit(&comm->sock, &ncclNetSocketDevs[dev].addr, handle->magic, ncclSocketTypeNetSocket, NULL, 1)); NCCLCHECK(ncclSocketListen(&comm->sock)); - memcpy(&handle->connectAddr, &comm->sock.addr, sizeof(union ncclSocketAddress)); - NCCLCHECK(ncclSocketGetNsockNthread(dev, &comm->nSocks, &comm->nThreads)); + NCCLCHECK(ncclSocketGetAddr(&comm->sock, &handle->connectAddr)); + NCCLCHECK(ncclNetSocketGetNsockNthread(dev, &comm->nSocks, &comm->nThreads)); handle->nSocks = comm->nSocks; handle->nThreads = comm->nThreads; comm->dev = dev; @@ -322,46 +301,41 @@ ncclResult_t ncclSocketListen(int dev, void* opaqueHandle, void** listenComm) { return ncclSuccess; } -ncclResult_t ncclSocketConnect(int dev, void* opaqueHandle, void** sendComm) { - if (dev < 0) { // data transfer socket is based on specified dev +ncclResult_t ncclNetSocketConnect(int dev, void* opaqueHandle, void** sendComm) { + if (dev < 0 || dev >= ncclNetIfs) { // data transfer socket is based on specified dev return ncclInternalError; } - enum ncclSocketState conState; - struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle; - struct ncclSocketCommStage* stage = &handle->stage; - struct ncclSocketComm* comm = stage->comm; + int ready; + struct ncclNetSocketHandle* handle = (struct ncclNetSocketHandle*) opaqueHandle; + struct ncclNetSocketCommStage* stage = &handle->stage; + struct ncclNetSocketComm* comm = stage->comm; uint8_t i = stage->iteration; struct ncclSocket* sock = stage->sock; *sendComm = NULL; - if (stage->state == ncclSocketCommStateConnect) goto socket_connect_check; - if (stage->state == ncclSocketCommStateSend) goto socket_send; + if (stage->state == ncclNetSocketCommStateConnect) goto socket_connect_check; + if (stage->state == ncclNetSocketCommStateSend) goto socket_send; - NCCLCHECK(ncclSocketNewComm(&comm)); + NCCLCHECK(ncclCalloc(&comm, 1)); stage->comm = comm; comm->nSocks = handle->nSocks; comm->nThreads = handle->nThreads; comm->dev = dev; CUDACHECK(cudaGetDevice(&comm->cudaDev)); for (; inSocks+1; i++) { - sock = i == comm->nSocks ? &comm->ctrlSock : comm->socks+i; - NCCLCHECK(ncclSocketInit(sock, &handle->connectAddr, NULL, 1)); + sock = (i == comm->nSocks) ? &comm->ctrlSock : comm->socks+i; + NCCLCHECK(ncclSocketInit(sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetSocket, NULL, 1)); stage->sock = sock; - stage->state = ncclSocketCommStateConnect; + stage->state = ncclNetSocketCommStateConnect; stage->iteration = i; NCCLCHECK(ncclSocketConnect(sock)); socket_connect_check: - NCCLCHECK(ncclGetSocketState(sock, &conState)); - if (conState == ncclSocketConnecting) { - /* expect user to call again */ - return ncclSuccess; - } else if (conState == ncclSocketError) { - return ncclRemoteError; - } - stage->state = ncclSocketCommStateSend; + NCCLCHECK(ncclSocketReady(sock, &ready)); + if (! ready) return ncclSuccess; + stage->state = ncclNetSocketCommStateSend; socket_send: int done = 0; @@ -372,59 +346,63 @@ socket_send: return ncclSuccess; } -ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) { - struct ncclSocketListenComm* lComm = (struct ncclSocketListenComm*)listenComm; - struct ncclSocketCommStage* stage = &lComm->stage; - struct ncclSocketComm* rComm = stage->comm; +ncclResult_t ncclNetSocketAccept(void* listenComm, void** recvComm) { + struct ncclNetSocketListenComm* lComm = (struct ncclNetSocketListenComm*)listenComm; + struct ncclNetSocketCommStage* stage = &lComm->stage; + struct ncclNetSocketComm* rComm = stage->comm; uint8_t i = stage->iteration; struct ncclSocket* sock = stage->sock; + int ready; *recvComm = NULL; - if (stage->state == ncclSocketCommStateAccept) goto socket_accept; - if (stage->state == ncclSocketCommStateRecv) goto socket_recv; + if (stage->state == ncclNetSocketCommStateAccept) goto socket_accept_check; + if (stage->state == ncclNetSocketCommStateRecv) goto socket_recv; - NCCLCHECK(ncclSocketNewComm(&rComm)); + NCCLCHECK(ncclCalloc(&rComm, 1)); stage->comm = rComm; rComm->nSocks = lComm->nSocks; rComm->nThreads = lComm->nThreads; rComm->dev = lComm->dev; CUDACHECK(cudaGetDevice(&rComm->cudaDev)); - lComm->sock.asyncFlag = 1; for (; inSocks+1; i++) { uint8_t sendSockIdx; - ncclCalloc(&sock, 1); - NCCLCHECK(ncclSocketInit(sock, NULL, lComm->sock.abortFlag, 1)); - stage->sock = sock; - stage->state = ncclSocketCommStateAccept; - stage->iteration = i; -socket_accept: - NCCLCHECK(ncclSocketAccept(sock, &lComm->sock)); - if (sock->fd == -1) return ncclSuccess; - stage->state = ncclSocketCommStateRecv; + NCCLCHECK(ncclCalloc(&sock, 1)); + NCCLCHECK(ncclSocketInit(sock)); + stage->sock = sock; + stage->state = ncclNetSocketCommStateAccept; + stage->iteration = i; + NCCLCHECK(ncclSocketAccept(sock, &lComm->sock)); + +socket_accept_check: + NCCLCHECK(ncclSocketReady(sock, &ready)); + if (!ready) return ncclSuccess; + + stage->state = ncclNetSocketCommStateRecv; socket_recv: int done = 0; NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &sendSockIdx, sizeof(uint8_t), &done)); if (done == 0) return ncclSuccess; - if (sendSockIdx == rComm->nSocks) memcpy(&rComm->ctrlSock, sock, sizeof(struct ncclSocket)); - else memcpy(rComm->socks+sendSockIdx, sock, sizeof(struct ncclSocket)); - + if (sendSockIdx == rComm->nSocks) + memcpy(&rComm->ctrlSock, sock, sizeof(struct ncclSocket)); + else + memcpy(rComm->socks+sendSockIdx, sock, sizeof(struct ncclSocket)); free(sock); } *recvComm = rComm; /* reset lComm state */ - stage->state = ncclSocketCommStateStart; + stage->state = ncclNetSocketCommStateStart; stage->iteration = 0; stage->sock = NULL; stage->comm = NULL; return ncclSuccess; } -ncclResult_t ncclSocketGetRequest(struct ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketRequest** req) { +ncclResult_t ncclNetSocketGetRequest(struct ncclNetSocketComm* comm, int op, void* data, int size, struct ncclNetSocketRequest** req) { for (int i=0; irequests+i; + struct ncclNetSocketRequest* r = comm->requests+i; if (r->used == 0) { r->op = op; r->data = data; @@ -441,10 +419,10 @@ ncclResult_t ncclSocketGetRequest(struct ncclSocketComm* comm, int op, void* dat return ncclInternalError; } -ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketTask** req) { +ncclResult_t ncclNetSocketGetTask(struct ncclNetSocketComm* comm, int op, void* data, int size, struct ncclNetSocketTask** req) { int tid = comm->nextSock % comm->nThreads; - struct ncclSocketThreadResources* res = comm->threadResources+tid; - struct ncclSocketTaskQueue* queue = &res->threadTaskQueue; + struct ncclNetSocketThreadResources* res = comm->threadResources+tid; + struct ncclNetSocketTaskQueue* queue = &res->threadTaskQueue; // create helper threads and prepare per-thread task queue if (queue->tasks == NULL) { // each request can be divided up to nSocks tasks, and @@ -459,12 +437,12 @@ ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data, pthread_create(comm->helperThread+tid, NULL, persistentSocketThread, res); ncclSetThreadName(comm->helperThread[tid], "NCCL Sock%c%1u%2u%2u", op == NCCL_SOCKET_SEND ? 'S' : 'R', comm->dev, tid, comm->cudaDev); } - struct ncclSocketTask* r = queue->tasks+queue->next; + struct ncclNetSocketTask* r = queue->tasks+queue->next; if (r->used == 0) { r->op = op; r->data = data; r->size = size; - r->sock = comm->socks+comm->nextSock; + r->sock = comm->socks + comm->nextSock; r->offset = 0; r->result = ncclSuccess; comm->nextSock = (comm->nextSock + 1) % comm->nSocks; @@ -480,9 +458,9 @@ ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data, return ncclInternalError; } -ncclResult_t ncclSocketTest(void* request, int* done, int* size) { +ncclResult_t ncclNetSocketTest(void* request, int* done, int* size) { *done = 0; - struct ncclSocketRequest *r = (struct ncclSocketRequest*)request; + struct ncclNetSocketRequest *r = (struct ncclNetSocketRequest*)request; if (r == NULL) { WARN("NET/Socket : test called with NULL request"); return ncclInternalError; @@ -500,9 +478,11 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { // Check size is less or equal to the size provided by the user if (r->op == NCCL_SOCKET_RECV && data > r->size) { char line[SOCKET_NAME_MAXLEN+1]; + union ncclSocketAddress addr; + ncclSocketGetAddr(r->ctrlSock, &addr); WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d. If you believe your socket network is in healthy state, \ there may be a mismatch in collective sizes or environment settings (e.g. NCCL_PROTO, NCCL_ALGO) between ranks", - ncclSocketToString(&r->ctrlSock->addr, line), data, r->size); + ncclSocketToString(&addr, line), data, r->size); return ncclInvalidUsage; } r->size = data; @@ -515,7 +495,7 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks)); while (chunkOffset < r->size) { int chunkSize = std::min(taskSize, r->size-chunkOffset); - NCCLCHECK(ncclSocketGetTask(r->comm, r->op, (char*)(r->data)+chunkOffset, chunkSize, r->tasks+i++)); + NCCLCHECK(ncclNetSocketGetTask(r->comm, r->op, (char*)(r->data)+chunkOffset, chunkSize, r->tasks+i++)); chunkOffset += chunkSize; } } @@ -525,7 +505,7 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { if (r->nSubs > 0) { int nCompleted = 0; for (int i=0; inSubs; i++) { - struct ncclSocketTask* sub = r->tasks[i]; + struct ncclNetSocketTask* sub = r->tasks[i]; if (sub->result != ncclSuccess) return sub->result; if (sub->offset == sub->size) nCompleted++; } @@ -534,7 +514,7 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { *done = 1; r->used = 0; for (int i=0; inSubs; i++) { - struct ncclSocketTask* sub = r->tasks[i]; + struct ncclNetSocketTask* sub = r->tasks[i]; sub->used = 0; } } @@ -552,43 +532,45 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { return ncclSuccess; } -ncclResult_t ncclSocketRegMr(void* comm, void* data, int size, int type, void** mhandle) { +ncclResult_t ncclNetSocketRegMr(void* comm, void* data, int size, int type, void** mhandle) { return (type != NCCL_PTR_HOST) ? ncclInternalError : ncclSuccess; } -ncclResult_t ncclSocketDeregMr(void* comm, void* mhandle) { return ncclSuccess; } +ncclResult_t ncclNetSocketDeregMr(void* comm, void* mhandle) { return ncclSuccess; } -ncclResult_t ncclSocketIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { - struct ncclSocketComm* comm = (struct ncclSocketComm*)sendComm; - NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_SEND, data, size, (struct ncclSocketRequest**)request)); +ncclResult_t ncclNetSocketIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { + struct ncclNetSocketComm* comm = (struct ncclNetSocketComm*)sendComm; + NCCLCHECK(ncclNetSocketGetRequest(comm, NCCL_SOCKET_SEND, data, size, (struct ncclNetSocketRequest**)request)); return ncclSuccess; } -ncclResult_t ncclSocketIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { - struct ncclSocketComm* comm = (struct ncclSocketComm*)recvComm; +ncclResult_t ncclNetSocketIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { + struct ncclNetSocketComm* comm = (struct ncclNetSocketComm*)recvComm; if (n != 1) return ncclInternalError; - NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_RECV, data[0], sizes[0], (struct ncclSocketRequest**)request)); + NCCLCHECK(ncclNetSocketGetRequest(comm, NCCL_SOCKET_RECV, data[0], sizes[0], (struct ncclNetSocketRequest**)request)); return ncclSuccess; } -ncclResult_t ncclSocketIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) { +ncclResult_t ncclNetSocketIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) { // We don't support CUDA pointers, so we don't need a flush operation return ncclInternalError; } -ncclResult_t ncclSocketCloseListen(void* opaqueComm) { - struct ncclSocketListenComm* comm = (struct ncclSocketListenComm*)opaqueComm; +ncclResult_t ncclNetSocketCloseListen(void* opaqueComm) { + struct ncclNetSocketListenComm* comm = (struct ncclNetSocketListenComm*)opaqueComm; if (comm) { - if (comm->sock.fd != -1) close(comm->sock.fd); + int ready; + NCCLCHECK(ncclSocketReady(&comm->sock, &ready)); + if (ready) NCCLCHECK(ncclSocketClose(&comm->sock)); free(comm); } return ncclSuccess; } -ncclResult_t ncclSocketClose(void* opaqueComm) { - struct ncclSocketComm* comm = (struct ncclSocketComm*)opaqueComm; +ncclResult_t ncclNetSocketClose(void* opaqueComm) { + struct ncclNetSocketComm* comm = (struct ncclNetSocketComm*)opaqueComm; if (comm) { for (int i=0; inThreads; i++) { - struct ncclSocketThreadResources* res = comm->threadResources+i; + struct ncclNetSocketThreadResources* res = comm->threadResources+i; if (comm->helperThread[i]) { pthread_mutex_lock(&res->threadLock); res->stop = 1; @@ -598,9 +580,12 @@ ncclResult_t ncclSocketClose(void* opaqueComm) { } free(res->threadTaskQueue.tasks); } - if (comm->ctrlSock.fd != -1) close(comm->ctrlSock.fd); + int ready; + NCCLCHECK(ncclSocketReady(&comm->ctrlSock, &ready)); + if (ready) NCCLCHECK(ncclSocketClose(&comm->ctrlSock)); for (int i=0; inSocks; i++) { - if (comm->socks[i].fd != -1) close(comm->socks[i].fd); + NCCLCHECK(ncclSocketReady(&comm->socks[i], &ready)); + if (ready) NCCLCHECK(ncclSocketClose(&comm->socks[i])); } free(comm); } @@ -609,20 +594,20 @@ ncclResult_t ncclSocketClose(void* opaqueComm) { ncclNet_t ncclNetSocket = { "Socket", - ncclSocketInit, - ncclSocketDevices, - ncclSocketGetProperties, - ncclSocketListen, - ncclSocketConnect, - ncclSocketAccept, - ncclSocketRegMr, + ncclNetSocketInit, + ncclNetSocketDevices, + ncclNetSocketGetProperties, + ncclNetSocketListen, + ncclNetSocketConnect, + ncclNetSocketAccept, + ncclNetSocketRegMr, NULL, // No DMA-BUF support - ncclSocketDeregMr, - ncclSocketIsend, - ncclSocketIrecv, - ncclSocketIflush, - ncclSocketTest, - ncclSocketClose, - ncclSocketClose, - ncclSocketCloseListen + ncclNetSocketDeregMr, + ncclNetSocketIsend, + ncclNetSocketIrecv, + ncclNetSocketIflush, + ncclNetSocketTest, + ncclNetSocketClose, + ncclNetSocketClose, + ncclNetSocketCloseListen }; diff --git a/src/transport/p2p.cc b/src/transport/p2p.cc index b0bad4a..e7a4fd0 100644 --- a/src/transport/p2p.cc +++ b/src/transport/p2p.cc @@ -34,6 +34,7 @@ struct p2pProxyInfo { struct p2pShm* devShm; char shmName[7]; int shmSize; + ncclShmHandle_t handle; // Intermediate step for sender struct ncclRecvMem* ceRecvMem; @@ -63,6 +64,7 @@ struct p2pRecvResources { struct p2pShm* shm; struct p2pShm* devShm; int shmSize; + ncclShmHandle_t handle; }; #include @@ -351,9 +353,7 @@ ncclResult_t p2pRecvConnect(struct ncclComm* comm, struct ncclConnect* connectIn sprintf(shmPath, "/dev/shm/nccl-%s", info->shmName); TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmPath, info->shmSize); resources->shmSize = info->shmSize; - NCCLCHECK(ncclShmOpen(shmPath, info->shmSize, (void**)&resources->shm, (void**)&resources->devShm, 0)); - // Remove the file to ensure proper clean-up - NCCLCHECK(ncclShmUnlink(shmPath)); + NCCLCHECK(ncclShmOpen(shmPath, info->shmSize, (void**)&resources->shm, (void**)&resources->devShm, -1, &resources->handle)); recv->conn.tail = &resources->devShm->recvMem.tail; recv->conn.head = &resources->devShm->sendMem.head; @@ -396,7 +396,7 @@ ncclResult_t p2pRecvFree(struct ncclConnector* recv) { if (resources->sendMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->sendMemIpc)); if (resources->recvMemIpc) CUDACHECK(cudaIpcCloseMemHandle(resources->recvMemIpc)); if (useMemcpy) { - NCCLCHECK(ncclShmClose(resources->shm, resources->devShm, resources->shmSize)); + NCCLCHECK(ncclShmClose(resources->handle)); } free(resources); } @@ -414,7 +414,7 @@ static ncclResult_t p2pSendProxySetup(struct ncclProxyConnection* connection, st char shmPath[PATH_MAX]; shmPath[0] = '\0'; proxyInfo->shmSize = sizeof(struct ncclSendMem) + sizeof(struct ncclRecvMem); - NCCLCHECK(ncclShmOpen(shmPath, proxyInfo->shmSize, (void**)&proxyInfo->shm, (void**)&proxyInfo->devShm, 1)); + NCCLCHECK(ncclShmOpen(shmPath, proxyInfo->shmSize, (void**)&proxyInfo->shm, (void**)&proxyInfo->devShm, 1, &proxyInfo->handle)); TRACE(NCCL_SHM,"Opened shmName %s shmSize %d", shmPath, proxyInfo->shmSize); memcpy(proxyInfo->shmName, shmPath+sizeof("/dev/shm/nccl-")-1, sizeof(proxyInfo->shmName)); @@ -477,7 +477,7 @@ static ncclResult_t p2pSendProxyFree(struct ncclProxyConnection* connection, str if (useMemcpy) { struct p2pProxyInfo* proxyInfo = (struct p2pProxyInfo*)connection->transportResources; if (proxyInfo) { - NCCLCHECK(ncclShmClose(proxyInfo->shm, proxyInfo->devShm, proxyInfo->shmSize)); + NCCLCHECK(ncclShmClose(proxyInfo->handle)); NCCLCHECK(ncclCudaHostFree(proxyInfo->ceRecvMem)); CUDACHECK(cudaFree(proxyInfo->ceDevBuff)); CUDACHECK(cudaStreamDestroy(proxyInfo->stream)); diff --git a/src/transport/shm.cc b/src/transport/shm.cc index 740bd2a..4bce480 100644 --- a/src/transport/shm.cc +++ b/src/transport/shm.cc @@ -17,18 +17,22 @@ struct shmSendResources { int remShmSize; struct ncclRecvMem* remHostMem; struct ncclRecvMem* devRemHostMem; + ncclShmHandle_t remHandle; int shmSize; struct ncclSendMem* hostMem; struct ncclSendMem* devHostMem; + ncclShmHandle_t hostHandle; }; struct shmRecvResources { int remShmSize; struct ncclSendMem* remHostMem; struct ncclSendMem* devRemHostMem; + ncclShmHandle_t remHandle; int shmSize; struct ncclRecvMem* hostMem; struct ncclRecvMem* devHostMem; + ncclShmHandle_t hostHandle; }; #define SHM_SEND_SIDE 1 @@ -84,7 +88,7 @@ static ncclResult_t shmSendSetup(struct ncclComm* comm, struct ncclTopoGraph* gr for (int p=0; pcomm->buffSizes[p]; } info->shmSize = resources->shmSize = shmSize; - NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1)); + NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1, &resources->hostHandle)); TRACE(NCCL_SHM,"Opened shmName %s shmSize %d", shmPath, info->shmSize); memcpy(info->shmName, shmPath+sizeof("/dev/shm/nccl-")-1, sizeof(info->shmName)); @@ -107,7 +111,7 @@ static ncclResult_t shmRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* gr for (int p=0; pcomm->buffSizes[p]; } info->shmSize = resources->shmSize = shmSize; - NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1)); + NCCLCHECK(ncclShmOpen(shmPath, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1, &resources->hostHandle)); TRACE(NCCL_SHM,"Opened shmName %s shmSize %d", shmPath, info->shmSize); memcpy(info->shmName, shmPath+sizeof("/dev/shm/nccl-")-1, sizeof(info->shmName)); @@ -137,9 +141,7 @@ static ncclResult_t shmSendConnect(struct ncclComm* comm, struct ncclConnect* co sprintf(shmPath, "/dev/shm/nccl-%s", info->shmName); resources->remShmSize = info->shmSize; TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmPath, info->shmSize); - NCCLCHECK(ncclShmOpen(shmPath, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, 0)); - // Remove the file to ensure proper clean-up - NCCLCHECK(ncclShmUnlink(shmPath)); + NCCLCHECK(ncclShmOpen(shmPath, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, -1, &resources->remHandle)); char* buff = shmLocality == SHM_SEND_SIDE ? (char*)(resources->devHostMem+1) : (char*)(resources->devRemHostMem+1); for (int p=0; pshmName); resources->remShmSize = info->shmSize; TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmPath, info->shmSize); - NCCLCHECK(ncclShmOpen(shmPath, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, 0)); - NCCLCHECK(ncclShmUnlink(shmPath)); + NCCLCHECK(ncclShmOpen(shmPath, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, -1, &resources->remHandle)); char* buff = shmLocality == SHM_RECV_SIDE ? (char*)(resources->devHostMem+1) : (char*)(resources->devRemHostMem+1); for (int p=0; ptransportResources; if (resources) { - NCCLCHECK(ncclShmClose(resources->hostMem, resources->devHostMem, resources->shmSize)); - NCCLCHECK(ncclShmClose(resources->remHostMem, resources->devRemHostMem, resources->remShmSize)); + NCCLCHECK(ncclShmClose(resources->hostHandle)); + NCCLCHECK(ncclShmClose(resources->remHandle)); free(resources); } return ncclSuccess; @@ -206,8 +207,8 @@ static ncclResult_t shmSendFree(struct ncclConnector* send) { static ncclResult_t shmRecvFree(struct ncclConnector* recv) { struct shmRecvResources* resources = (struct shmRecvResources*)recv->transportResources; if (resources) { - NCCLCHECK(ncclShmClose(resources->hostMem, resources->devHostMem, resources->shmSize)); - NCCLCHECK(ncclShmClose(resources->remHostMem, resources->devRemHostMem, resources->remShmSize)); + NCCLCHECK(ncclShmClose(resources->hostHandle)); + NCCLCHECK(ncclShmClose(resources->remHandle)); free(resources); } return ncclSuccess;