2.4.8-1
Fix #209: improve socket transport performance. Split transfers over multiple sockets. Launch multiple threads to drive the sockets. Detect AWS NICs and set nsockets/nthreads accordingly.
This commit is contained in:
parent
0ceaec9cee
commit
7c72dee660
@ -1,6 +1,6 @@
|
|||||||
##### version
|
##### version
|
||||||
NCCL_MAJOR := 2
|
NCCL_MAJOR := 2
|
||||||
NCCL_MINOR := 4
|
NCCL_MINOR := 4
|
||||||
NCCL_PATCH := 7
|
NCCL_PATCH := 8
|
||||||
NCCL_SUFFIX :=
|
NCCL_SUFFIX :=
|
||||||
PKG_REVISION := 1
|
PKG_REVISION := 1
|
||||||
|
152
src/bootstrap.cc
152
src/bootstrap.cc
@ -9,37 +9,145 @@
|
|||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
#include "bootstrap.h"
|
#include "bootstrap.h"
|
||||||
#include "net.h"
|
#include "net.h"
|
||||||
|
#include "socket.h"
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
#include <sys/types.h>
|
#include <sys/types.h>
|
||||||
|
|
||||||
// Always use sockets for bootstrap
|
// Always use sockets for bootstrap
|
||||||
ncclNet_t* ncclBootstrapNet = &ncclNetSocket;
|
struct bootstrapNetHandle {
|
||||||
|
union socketAddress connectAddr;
|
||||||
|
};
|
||||||
|
|
||||||
static ncclResult_t bootstrapNetListen(int dev, void* handle, void** listenComm) { NCCLCHECK(ncclBootstrapNet->listen(dev, handle, listenComm)); return ncclSuccess; }
|
struct bootstrapNetComm {
|
||||||
static ncclResult_t bootstrapNetConnect(int dev, void* handle, void** sendComm) { NCCLCHECK(ncclBootstrapNet->connect(dev, handle, sendComm)); return ncclSuccess; }
|
int fd;
|
||||||
static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) { NCCLCHECK(ncclBootstrapNet->accept(listenComm, recvComm)); return ncclSuccess; }
|
};
|
||||||
static ncclResult_t bootstrapNetTest(void* request, int* done, int* size) { NCCLCHECK(ncclBootstrapNet->test(request, done, size)); return ncclSuccess; }
|
|
||||||
static ncclResult_t bootstrapNetCloseSend(void* sendComm) { NCCLCHECK(ncclBootstrapNet->closeSend(sendComm)); return ncclSuccess; }
|
|
||||||
static ncclResult_t bootstrapNetCloseRecv(void* recvComm) { NCCLCHECK(ncclBootstrapNet->closeRecv(recvComm)); return ncclSuccess; }
|
|
||||||
static ncclResult_t bootstrapNetCloseListen(void* listenComm) { NCCLCHECK(ncclBootstrapNet->closeListen(listenComm)); return ncclSuccess; }
|
|
||||||
|
|
||||||
// Additional sync functions based on async + test for bootstrap, using host ptrs.
|
/* Init functions */
|
||||||
|
static char bootstrapNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
|
||||||
|
static union socketAddress bootstrapNetIfAddrs[MAX_IFS];
|
||||||
|
static int bootstrapNetIfs = -1;
|
||||||
|
pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;
|
||||||
|
|
||||||
|
/* Discover the socket interfaces used for bootstrap.
 *
 * Discovery runs only while bootstrapNetIfs still holds its initial value of
 * -1; the check is repeated under bootstrapNetLock so that concurrent callers
 * do not both run findInterfaces(). On success the interfaces found are
 * logged at INFO level.
 *
 * Returns ncclSuccess, or ncclInternalError when findInterfaces() reports no
 * usable interface. The lock is released on every path, including the error
 * path.
 */
ncclResult_t bootstrapNetInit() {
  ncclResult_t ret = ncclSuccess;
  if (bootstrapNetIfs == -1) {
    pthread_mutex_lock(&bootstrapNetLock);
    // Re-check under the lock: another thread may have finished discovery
    // between the unlocked test above and acquiring the mutex.
    if (bootstrapNetIfs == -1) {
      bootstrapNetIfs = findInterfaces(bootstrapNetIfNames, bootstrapNetIfAddrs, MAX_IF_NAME_SIZE, MAX_IFS);
      if (bootstrapNetIfs <= 0) {
        WARN("Bootstrap : no socket interface found");
        // Do not return here: the mutex must be released below.
        ret = ncclInternalError;
      } else {
        char line[1024];
        char addrline[1024];
        line[0] = '\0';
        // Build " [i]name:addr" for each interface, appending to line.
        for (int i=0; i<bootstrapNetIfs; i++) {
          snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%s", i, bootstrapNetIfNames+i*MAX_IF_NAME_SIZE,
              socketToString(&bootstrapNetIfAddrs[i].sa, addrline));
        }
        line[1023] = '\0';
        INFO(NCCL_INIT, "Bootstrap : Using%s", line);
      }
    }
    pthread_mutex_unlock(&bootstrapNetLock);
  }
  return ret;
}
|
||||||
|
|
||||||
|
/* Allocate a single bootstrapNetComm with ncclCalloc and store it in *comm.
 * The descriptor is set to -1 so the comm starts out without a socket.
 * Propagates the ncclCalloc error through NCCLCHECK. */
static ncclResult_t bootstrapNetNewComm(struct bootstrapNetComm** comm) {
  NCCLCHECK(ncclCalloc(comm, 1));
  struct bootstrapNetComm* fresh = *comm;
  fresh->fd = -1;
  return ncclSuccess;
}
|
||||||
|
|
||||||
|
/* Copy the address of bootstrap interface `dev` into *addr.
 *
 * dev  : index into bootstrapNetIfAddrs; must be in [0, bootstrapNetIfs).
 * addr : destination, overwritten on success.
 *
 * Returns ncclInternalError for an out-of-range index (including negative
 * values, which would otherwise read before the start of the table).
 */
static ncclResult_t bootstrapNetGetSocketAddr(int dev, union socketAddress* addr) {
  if (dev < 0 || dev >= bootstrapNetIfs) return ncclInternalError;
  memcpy(addr, bootstrapNetIfAddrs+dev, sizeof(*addr));
  return ncclSuccess;
}
|
||||||
|
|
||||||
|
/* Socket Interface Selection type */
|
||||||
|
enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };
|
||||||
|
|
||||||
|
/* Create a bootstrap listening socket and return it through *listenComm.
 *
 * dev          : >= 0 selects the bootstrap interface with that index;
 *                findSubnetIf means opaqueHandle holds a remote address and a
 *                local interface in the same subnet must be found; any other
 *                value means opaqueHandle already holds the local address.
 * opaqueHandle : struct bootstrapNetHandle; connectAddr is read and/or
 *                overwritten with the local address and then passed to
 *                createListenSocket().
 * listenComm   : receives the newly allocated bootstrapNetComm on success.
 *
 * On failure of createListenSocket() the freshly allocated comm is freed so
 * that it does not leak.
 */
static ncclResult_t bootstrapNetListen(int dev, void* opaqueHandle, void** listenComm) {
  struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle;
  static_assert(sizeof(struct bootstrapNetHandle) < NCCL_NET_HANDLE_MAXSIZE, "bootstrapNetHandle size too large");
  // if dev >= 0, listen based on dev
  if (dev >= 0) {
    NCCLCHECK(bootstrapNetGetSocketAddr(dev, &(handle->connectAddr)));
  } else if (dev == findSubnetIf) {
    // handle stores a remote address
    // need to find a local addr that is in the same network as the remote addr
    union socketAddress localAddr;
    char ifName[MAX_IF_NAME_SIZE];
    if (findInterfaceMatchSubnet(ifName, &localAddr, handle->connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
      WARN("NET/Socket : No usable listening interface found");
      return ncclSystemError;
    }
    // pass the local address back
    memcpy(&handle->connectAddr, &localAddr, sizeof(handle->connectAddr));
  } // Otherwise, handle stores a local address

  struct bootstrapNetComm* comm;
  NCCLCHECK(bootstrapNetNewComm(&comm));
  ncclResult_t res = createListenSocket(&comm->fd, &handle->connectAddr);
  if (res != ncclSuccess) {
    // Nobody else owns comm yet; release it before reporting the error.
    free(comm);
    return res;
  }
  *listenComm = comm;
  return ncclSuccess;
}
|
||||||
|
|
||||||
|
/* Connect to the address stored in a bootstrap handle.
 *
 * dev          : not used by this function.
 * opaqueHandle : struct bootstrapNetHandle whose connectAddr is the peer.
 * sendComm     : receives the newly allocated bootstrapNetComm on success.
 *
 * If connectAddress() fails, the comm allocated here is freed before the
 * error is returned so that it does not leak.
 */
static ncclResult_t bootstrapNetConnect(int dev, void* opaqueHandle, void** sendComm) {
  struct bootstrapNetComm* comm;
  NCCLCHECK(bootstrapNetNewComm(&comm));
  struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle;
  ncclResult_t res = connectAddress(&comm->fd, &handle->connectAddr);
  if (res != ncclSuccess) {
    free(comm);
    return res;
  }
  *sendComm = comm;
  return ncclSuccess;
}
|
||||||
|
|
||||||
|
/* Accept one connection on a bootstrap listening comm.
 *
 * listenComm : bootstrapNetComm whose fd is the listening socket.
 * recvComm   : receives a newly allocated bootstrapNetComm wrapping the
 *              accepted descriptor.
 *
 * The connection is accepted before the comm is allocated: SYSCHECKVAL
 * returns from this function on failure, so allocating first would leak the
 * comm. If the allocation fails afterwards, the accepted descriptor is
 * closed.
 */
static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) {
  struct bootstrapNetComm* lComm = (struct bootstrapNetComm*)listenComm;
  struct sockaddr_in sockaddr;
  socklen_t socklen = sizeof(struct sockaddr_in);
  int acceptedFd = -1;
  SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", acceptedFd);

  struct bootstrapNetComm* rComm;
  ncclResult_t res = bootstrapNetNewComm(&rComm);
  if (res != ncclSuccess) {
    close(acceptedFd);
    return res;
  }
  rComm->fd = acceptedFd;
  *recvComm = rComm;
  return ncclSuccess;
}
|
||||||
|
|
||||||
|
/* Close the descriptor held by a bootstrap comm and free the comm.
 * A NULL comm is accepted and ignored. Always returns ncclSuccess; the
 * return value of close() is not inspected. */
static ncclResult_t bootstrapNetClose(void* opaqueComm) {
  struct bootstrapNetComm* comm = (struct bootstrapNetComm*)opaqueComm;
  if (comm == NULL) return ncclSuccess;
  close(comm->fd);
  free(comm);
  return ncclSuccess;
}
|
||||||
|
|
||||||
|
/* Send, receive and listen comms all share the same layout, so each of the
 * three close entry points simply forwards to bootstrapNetClose(). */
static ncclResult_t bootstrapNetCloseSend(void* sendComm) {
  NCCLCHECK(bootstrapNetClose(sendComm));
  return ncclSuccess;
}

static ncclResult_t bootstrapNetCloseRecv(void* recvComm) {
  NCCLCHECK(bootstrapNetClose(recvComm));
  return ncclSuccess;
}

static ncclResult_t bootstrapNetCloseListen(void* listenComm) {
  NCCLCHECK(bootstrapNetClose(listenComm));
  return ncclSuccess;
}
|
||||||
|
|
||||||
|
// Additional sync functions
|
||||||
static ncclResult_t bootstrapNetSend(void* sendComm, void* data, int size) {
|
static ncclResult_t bootstrapNetSend(void* sendComm, void* data, int size) {
|
||||||
void* request, *mhandle;
|
struct bootstrapNetComm* comm = (struct bootstrapNetComm*)sendComm;
|
||||||
NCCLCHECK(ncclBootstrapNet->regMr(sendComm, data, size, NCCL_PTR_HOST, &mhandle));
|
NCCLCHECK(socketSend(comm->fd, &size, sizeof(int)));
|
||||||
NCCLCHECK(ncclBootstrapNet->isend(sendComm, data, size, mhandle, &request));
|
NCCLCHECK(socketSend(comm->fd, data, size));
|
||||||
NCCLCHECK(ncclBootstrapNet->deregMr(sendComm, mhandle));
|
|
||||||
int done = 0;
|
|
||||||
while (!done) NCCLCHECK(bootstrapNetTest(request, &done, NULL));
|
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
static ncclResult_t bootstrapNetRecv(void* recvComm, void* data, int size) {
|
static ncclResult_t bootstrapNetRecv(void* recvComm, void* data, int size) {
|
||||||
void* request, *mhandle;
|
struct bootstrapNetComm* comm = (struct bootstrapNetComm*)recvComm;
|
||||||
NCCLCHECK(ncclBootstrapNet->regMr(recvComm, data, size, NCCL_PTR_HOST, &mhandle));
|
int recvSize;
|
||||||
NCCLCHECK(ncclBootstrapNet->irecv(recvComm, data, size, mhandle, &request));
|
NCCLCHECK(socketReceive(comm->fd, &recvSize, sizeof(int)));
|
||||||
NCCLCHECK(ncclBootstrapNet->deregMr(recvComm, mhandle));
|
if (recvSize > size) {
|
||||||
int done = 0;
|
WARN("Message truncated : received %d bytes instead of %d\n", recvSize, size);
|
||||||
while (!done) NCCLCHECK(bootstrapNetTest(request, &done, NULL));
|
return ncclInternalError;
|
||||||
|
}
|
||||||
|
NCCLCHECK(socketReceive(comm->fd, data, std::min(recvSize, size)));
|
||||||
|
return ncclSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
ncclResult_t bootstrapNetCreateHandle(void* opaqueHandle, const char* str) {
|
||||||
|
struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle;
|
||||||
|
NCCLCHECK(GetSocketAddrFromString(&handle->connectAddr, str));
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,7 +256,7 @@ ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out) {
|
|||||||
|
|
||||||
char* env = getenv("NCCL_COMM_ID");
|
char* env = getenv("NCCL_COMM_ID");
|
||||||
if (env) {
|
if (env) {
|
||||||
if (ncclSocketCreateHandle(&id->extHandleRoot, env) != 0) {
|
if (bootstrapNetCreateHandle(&id->extHandleRoot, env) != 0) {
|
||||||
WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
|
WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
|
||||||
return ncclInvalidArgument;
|
return ncclInvalidArgument;
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include "nccl.h"
|
#include "nccl.h"
|
||||||
|
|
||||||
|
ncclResult_t bootstrapNetInit();
|
||||||
ncclResult_t bootstrapCreateRoot(ncclUniqueId* commId, bool idFromEnv);
|
ncclResult_t bootstrapCreateRoot(ncclUniqueId* commId, bool idFromEnv);
|
||||||
ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out);
|
ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out);
|
||||||
ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commState);
|
ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commState);
|
||||||
|
@ -13,11 +13,6 @@
|
|||||||
extern ncclNet_t* ncclNet;
|
extern ncclNet_t* ncclNet;
|
||||||
typedef char ncclNetHandle_t[NCCL_NET_HANDLE_MAXSIZE];
|
typedef char ncclNetHandle_t[NCCL_NET_HANDLE_MAXSIZE];
|
||||||
|
|
||||||
/* Socket Interface Selection type */
|
|
||||||
typedef enum { findSubnetIf = -1,
|
|
||||||
dontCareIf = -2
|
|
||||||
} ncclSocketIfSl_t;
|
|
||||||
|
|
||||||
// Translation to external API
|
// Translation to external API
|
||||||
static const char* ncclNetName() { return ncclNet->name; }
|
static const char* ncclNetName() { return ncclNet->name; }
|
||||||
static ncclResult_t ncclNetDevices(int* ndev) { NCCLCHECK(ncclNet->devices(ndev)); return ncclSuccess; }
|
static ncclResult_t ncclNetDevices(int* ndev) { NCCLCHECK(ncclNet->devices(ndev)); return ncclSuccess; }
|
||||||
@ -36,7 +31,6 @@ static ncclResult_t ncclNetCloseSend(void* sendComm) { NCCLCHECK(ncclNet->closeS
|
|||||||
static ncclResult_t ncclNetCloseRecv(void* recvComm) { NCCLCHECK(ncclNet->closeRecv(recvComm)); return ncclSuccess; }
|
static ncclResult_t ncclNetCloseRecv(void* recvComm) { NCCLCHECK(ncclNet->closeRecv(recvComm)); return ncclSuccess; }
|
||||||
static ncclResult_t ncclNetCloseListen(void* listenComm) { NCCLCHECK(ncclNet->closeListen(listenComm)); return ncclSuccess; }
|
static ncclResult_t ncclNetCloseListen(void* listenComm) { NCCLCHECK(ncclNet->closeListen(listenComm)); return ncclSuccess; }
|
||||||
|
|
||||||
extern ncclResult_t ncclSocketCreateHandle(void* opaqueHandle, const char* str);
|
|
||||||
extern ncclNet_t ncclNetIb;
|
extern ncclNet_t ncclNetIb;
|
||||||
extern ncclNet_t ncclNetSocket;
|
extern ncclNet_t ncclNetSocket;
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
|
|||||||
return buf;
|
return buf;
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline short socketToPort(struct sockaddr *saddr) {
|
static inline uint16_t socketToPort(struct sockaddr *saddr) {
|
||||||
return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port);
|
return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -161,7 +161,10 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress remoteAddr, int ifNameMaxSize, int maxIfs) {
|
static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress remoteAddr, int ifNameMaxSize, int maxIfs) {
|
||||||
char line[1024], line_a[1024];
|
#ifdef ENABLE_TRACE
|
||||||
|
char line[1024];
|
||||||
|
#endif
|
||||||
|
char line_a[1024];
|
||||||
int found = 0;
|
int found = 0;
|
||||||
struct ifaddrs *interfaces, *interface;
|
struct ifaddrs *interfaces, *interface;
|
||||||
getifaddrs(&interfaces);
|
getifaddrs(&interfaces);
|
||||||
@ -185,7 +188,7 @@ static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAdd
|
|||||||
// Store the interface name
|
// Store the interface name
|
||||||
strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
|
strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
|
||||||
|
|
||||||
INFO(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a));
|
TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a));
|
||||||
found++;
|
found++;
|
||||||
if (found == maxIfs) break;
|
if (found == maxIfs) break;
|
||||||
}
|
}
|
||||||
@ -390,12 +393,12 @@ retry:
|
|||||||
|
|
||||||
#define NCCL_SOCKET_SEND 0
|
#define NCCL_SOCKET_SEND 0
|
||||||
#define NCCL_SOCKET_RECV 1
|
#define NCCL_SOCKET_RECV 1
|
||||||
static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
|
static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) {
|
||||||
int bytes = 0;
|
int bytes = 0;
|
||||||
char* data = (char*)ptr;
|
char* data = (char*)ptr;
|
||||||
do {
|
do {
|
||||||
if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), MSG_DONTWAIT);
|
if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
|
||||||
if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), MSG_DONTWAIT);
|
if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
|
||||||
if (op == NCCL_SOCKET_RECV && bytes == 0) {
|
if (op == NCCL_SOCKET_RECV && bytes == 0) {
|
||||||
WARN("Net : Connection closed by remote peer");
|
WARN("Net : Connection closed by remote peer");
|
||||||
return ncclSystemError;
|
return ncclSystemError;
|
||||||
@ -413,9 +416,13 @@ static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* off
|
|||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Non-blocking variant of socketProgressOpt(): forwards all arguments with
 * the `block` flag set to 0, which makes the underlying send/recv use
 * MSG_DONTWAIT. */
static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
  const int nonBlocking = 0;
  return socketProgressOpt(op, fd, ptr, size, offset, nonBlocking);
}
|
||||||
|
|
||||||
static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) {
|
static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) {
|
||||||
while (*offset < size)
|
while (*offset < size)
|
||||||
NCCLCHECK(socketProgress(op, fd, ptr, size, offset));
|
NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1));
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,14 +124,15 @@ cleanup:
|
|||||||
}
|
}
|
||||||
|
|
||||||
ncclResult_t initNet() {
|
ncclResult_t initNet() {
|
||||||
// Always initialize sockets as we use it for bootstrap
|
// Always initialize bootstrap network
|
||||||
NCCLCHECK(initNet(&ncclNetSocket));
|
NCCLCHECK(bootstrapNetInit());
|
||||||
|
|
||||||
NCCLCHECK(initNetPlugin(&ncclNet));
|
NCCLCHECK(initNetPlugin(&ncclNet));
|
||||||
if (ncclNet != NULL) return ncclSuccess;
|
if (ncclNet != NULL) return ncclSuccess;
|
||||||
if (initNet(&ncclNetIb) == ncclSuccess) {
|
if (initNet(&ncclNetIb) == ncclSuccess) {
|
||||||
ncclNet = &ncclNetIb;
|
ncclNet = &ncclNetIb;
|
||||||
} else {
|
} else {
|
||||||
|
NCCLCHECK(initNet(&ncclNetSocket));
|
||||||
ncclNet = &ncclNetSocket;
|
ncclNet = &ncclNetSocket;
|
||||||
}
|
}
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
#include "core.h"
|
#include "core.h"
|
||||||
#include "socket.h"
|
#include "socket.h"
|
||||||
#include "net.h"
|
#include "net.h"
|
||||||
|
#include "param.h"
|
||||||
|
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <pthread.h>
|
#include <pthread.h>
|
||||||
@ -15,6 +16,7 @@
|
|||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <poll.h>
|
#include <poll.h>
|
||||||
#include <limits.h>
|
#include <limits.h>
|
||||||
|
#include <fcntl.h>
|
||||||
|
|
||||||
/* Init functions */
|
/* Init functions */
|
||||||
static char ncclNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
|
static char ncclNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
|
||||||
@ -68,7 +70,7 @@ ncclResult_t ncclSocketPciPath(int dev, char** path) {
|
|||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
|
ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
|
||||||
if (dev >= ncclNetIfs) return ncclInternalError;
|
if (dev >= ncclNetIfs) return ncclInternalError;
|
||||||
memcpy(addr, ncclNetIfAddrs+dev, sizeof(*addr));
|
memcpy(addr, ncclNetIfAddrs+dev, sizeof(*addr));
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
@ -76,105 +78,281 @@ static ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
|
|||||||
|
|
||||||
/* Communication functions */
|
/* Communication functions */
|
||||||
|
|
||||||
|
#define MAX_SOCKETS 64
|
||||||
|
#define MAX_THREADS 16
|
||||||
|
#define MAX_REQUESTS 128
|
||||||
|
#define MAX_QUEUE_LEN MAX_REQUESTS
|
||||||
|
#define MIN_CHUNKSIZE (64*1024)
|
||||||
|
|
||||||
|
NCCL_PARAM(SocketNsocksPerThread, "NSOCKS_PERTHREAD", -2);
|
||||||
|
NCCL_PARAM(SocketNthreads, "SOCKET_NTHREADS", -2);
|
||||||
|
|
||||||
struct ncclSocketHandle {
|
struct ncclSocketHandle {
|
||||||
union socketAddress connectAddr;
|
union socketAddress connectAddr;
|
||||||
|
int nSocks;
|
||||||
|
int nThreads;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ncclSocketRequest {
|
struct ncclSocketTask {
|
||||||
int op;
|
int op;
|
||||||
void* data;
|
void* data;
|
||||||
int size;
|
int size;
|
||||||
int fd;
|
int fd;
|
||||||
int offset;
|
int offset;
|
||||||
int used;
|
int used;
|
||||||
|
ncclResult_t result;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ncclSocketReqs {
|
struct ncclSocketRequest {
|
||||||
struct ncclSocketRequest* requests;
|
int op;
|
||||||
|
void* data;
|
||||||
|
int size;
|
||||||
|
int ctrlFd;
|
||||||
|
int used;
|
||||||
|
struct ncclSocketComm* comm;
|
||||||
|
struct ncclSocketTask* tasks[MAX_SOCKETS];
|
||||||
|
int nSubs;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ncclSocketTaskQueue {
|
||||||
|
int next;
|
||||||
|
struct ncclSocketTask* tasks;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum threadState {start, stop};
|
||||||
|
|
||||||
|
struct ncclSocketThreadResources {
|
||||||
|
struct ncclSocketTaskQueue threadTaskQueue;
|
||||||
|
enum threadState state;
|
||||||
|
struct ncclSocketComm* comm;
|
||||||
|
pthread_mutex_t threadLock;
|
||||||
|
pthread_cond_t threadCond;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ncclSocketListenComm {
|
||||||
|
int fd;
|
||||||
|
int nSocks;
|
||||||
|
int nThreads;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ncclSocketComm {
|
struct ncclSocketComm {
|
||||||
int fd;
|
int ctrlFd;
|
||||||
struct ncclSocketReqs reqs;
|
int fds[MAX_SOCKETS];
|
||||||
|
int nSocks;
|
||||||
|
int nThreads;
|
||||||
|
int nextFd;
|
||||||
|
struct ncclSocketRequest requests[MAX_REQUESTS];
|
||||||
|
pthread_t helperThread[MAX_THREADS];
|
||||||
|
struct ncclSocketThreadResources threadResources[MAX_THREADS];
|
||||||
};
|
};
|
||||||
|
|
||||||
ncclResult_t ncclSocketNewComm(struct ncclSocketComm** comm) {
|
void* persistentSocketThread(void *args_) {
|
||||||
|
struct ncclSocketThreadResources* resource = (struct ncclSocketThreadResources*)args_;
|
||||||
|
struct ncclSocketComm* comm = resource->comm;
|
||||||
|
volatile enum threadState* state = &resource->state;
|
||||||
|
struct ncclSocketTaskQueue* myQueue = &resource->threadTaskQueue;
|
||||||
|
int nSocksPerThread = comm->nSocks / comm->nThreads;
|
||||||
|
while (1) {
|
||||||
|
int idle = 1;
|
||||||
|
int mark = myQueue->next; // mark newest task seen
|
||||||
|
for (int i=0; i<MAX_QUEUE_LEN; i+=nSocksPerThread) {
|
||||||
|
int repeat;
|
||||||
|
do {
|
||||||
|
repeat = 0;
|
||||||
|
for (int j=0; j<nSocksPerThread; j++) {
|
||||||
|
struct ncclSocketTask* r = myQueue->tasks+i+j;
|
||||||
|
if (r != NULL && r->used == 1 && r->offset < r->size) {
|
||||||
|
r->result = socketProgress(r->op, r->fd, r->data, r->size, &r->offset);
|
||||||
|
if (r->result != ncclSuccess) {
|
||||||
|
WARN("NET/Socket : socket progress error");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
idle = 0;
|
||||||
|
if (r->offset < r->size) repeat = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} while (repeat);
|
||||||
|
}
|
||||||
|
if (idle) {
|
||||||
|
pthread_mutex_lock(&resource->threadLock);
|
||||||
|
while (mark == myQueue->next && *state != stop) { // no new tasks, wait
|
||||||
|
pthread_cond_wait(&resource->threadCond, &resource->threadLock);
|
||||||
|
}
|
||||||
|
pthread_mutex_unlock(&resource->threadLock);
|
||||||
|
}
|
||||||
|
if (*state == stop) return NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Decide how many sockets and helper threads to use for device `dev`.
 *
 * dev : index of the interface in ncclNetIfNames.
 * ns  : receives the total number of sockets (sockets per thread * threads).
 * nt  : receives the number of helper threads.
 *
 * The values come from the SocketNsocksPerThread / SocketNthreads parameters.
 * A value of -2 means "auto-detect": the PCI vendor id of the interface is
 * read from sysfs and, for vendor 0x1d0f (AWS), 2 threads with 8 sockets
 * each are used; otherwise 1 thread with 1 socket. Failure to locate or open
 * the vendor file is not an error and silently falls back to the defaults.
 * The thread count is capped at MAX_THREADS and the socket total at
 * MAX_SOCKETS.
 */
ncclResult_t ncclSocketGetNsockNthread(int dev, int* ns, int* nt) {
  int nSocksPerThread = ncclParamSocketNsocksPerThread();
  int nThreads = ncclParamSocketNthreads();
  if (nThreads > MAX_THREADS) {
    WARN("NET/Socket : NCCL_SOCKET_NTHREADS is greater than the maximum allowed, setting to %d", MAX_THREADS);
    nThreads = MAX_THREADS;
  }
  if (nThreads == -2 || nSocksPerThread == -2) {
    // Auto-detection
    int autoNt=1, autoNs=1;
    char vendorPath[PATH_MAX];
    snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", ncclNetIfNames+dev*MAX_IF_NAME_SIZE);
    char* rPath = realpath(vendorPath, NULL);
    // realpath() returns NULL when the path cannot be resolved (e.g. a
    // virtual interface without a device node); never hand NULL to open().
    int fd = (rPath != NULL) ? open(rPath, O_RDONLY) : -1;
    free(rPath);
    if (fd == -1) {
      // Could not find device vendor. This is handled silently so
      // we don't want to print an INFO error.
      TRACE(NCCL_NET, "Open of %s failed : %s\n", vendorPath, strerror(errno));
      goto end;
    }
    char vendor[7];
    // Pre-fill so the buffer stays NUL-terminated after the 6-byte read.
    strncpy(vendor, "0x0000", 7);
    int len;
    SYSCHECKVAL(read(fd, vendor, 6), "read", len);
    SYSCHECK(close(fd), "close");
    if (strcmp(vendor, "0x1d0f") == 0) { // AWS
      autoNt = 2;
      autoNs = 8;
    }
end:
    if (nThreads == -2) nThreads = autoNt;
    if (nSocksPerThread == -2) nSocksPerThread = autoNs;
  }
  int nSocks = nSocksPerThread * nThreads;
  if (nSocks > MAX_SOCKETS) {
    nSocksPerThread = MAX_SOCKETS/nThreads;
    WARN("NET/Socket : the total number of sockets is greater than the maximum allowed, setting NCCL_NSOCKS_PERTHREAD to %d", nSocksPerThread);
    nSocks = nSocksPerThread * nThreads;
  }
  *ns = nSocks;
  *nt = nThreads;
  INFO(NCCL_INIT, "NET/Socket: Using %d threads and %d sockets per thread", nThreads, nSocksPerThread);
  return ncclSuccess;
}
|
||||||
|
|
||||||
|
ncclResult_t ncclSocketNewListenComm(struct ncclSocketListenComm** comm) {
|
||||||
NCCLCHECK(ncclCalloc(comm, 1));
|
NCCLCHECK(ncclCalloc(comm, 1));
|
||||||
(*comm)->fd = -1;
|
(*comm)->fd = -1;
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
ncclResult_t ncclSocketCreateHandle(void* opaqueHandle, const char* str) {
|
ncclResult_t ncclSocketNewComm(struct ncclSocketComm** comm) {
|
||||||
struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle;
|
NCCLCHECK(ncclCalloc(comm, 1));
|
||||||
NCCLCHECK(GetSocketAddrFromString(&(handle->connectAddr), str));
|
(*comm)->ctrlFd = -1;
|
||||||
|
for (int i=0; i < MAX_SOCKETS; i++) {
|
||||||
|
(*comm)->fds[i] = -1;
|
||||||
|
}
|
||||||
|
(*comm)->nextFd = 0;
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
ncclResult_t ncclSocketListen(int dev, void* opaqueHandle, void** listenComm) {
|
ncclResult_t ncclSocketListen(int dev, void* opaqueHandle, void** listenComm) {
|
||||||
|
if (dev < 0) { // data transfer socket is based on specified dev
|
||||||
|
return ncclInternalError;
|
||||||
|
}
|
||||||
struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle;
|
struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle;
|
||||||
static_assert(sizeof(struct ncclSocketHandle) < NCCL_NET_HANDLE_MAXSIZE, "ncclSocketHandle size too large");
|
static_assert(sizeof(struct ncclSocketHandle) < NCCL_NET_HANDLE_MAXSIZE, "ncclSocketHandle size too large");
|
||||||
// if dev >= 0, listen based on dev
|
struct ncclSocketListenComm* comm;
|
||||||
if (dev >= 0) {
|
NCCLCHECK(ncclSocketNewListenComm(&comm));
|
||||||
NCCLCHECK(GetSocketAddr(dev, &(handle->connectAddr)));
|
NCCLCHECK(GetSocketAddr(dev, &handle->connectAddr));
|
||||||
} else if (dev == findSubnetIf) {
|
|
||||||
// handle stores a remote address
|
|
||||||
// need to find a local addr that is in the same network as the remote addr
|
|
||||||
union socketAddress localAddr;
|
|
||||||
char ifName[MAX_IF_NAME_SIZE];
|
|
||||||
if (findInterfaceMatchSubnet(ifName, &localAddr, handle->connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
|
|
||||||
WARN("NET/Socket : No usable listening interface found");
|
|
||||||
return ncclSystemError;
|
|
||||||
}
|
|
||||||
// pass the local address back
|
|
||||||
memcpy(&handle->connectAddr, &localAddr, sizeof(handle->connectAddr));
|
|
||||||
} // Otherwise, handle stores a local address
|
|
||||||
struct ncclSocketComm* comm;
|
|
||||||
NCCLCHECK(ncclSocketNewComm(&comm));
|
|
||||||
NCCLCHECK(createListenSocket(&comm->fd, &handle->connectAddr));
|
NCCLCHECK(createListenSocket(&comm->fd, &handle->connectAddr));
|
||||||
|
NCCLCHECK(ncclSocketGetNsockNthread(dev, &comm->nSocks, &comm->nThreads));
|
||||||
|
handle->nSocks = comm->nSocks;
|
||||||
|
handle->nThreads = comm->nThreads;
|
||||||
*listenComm = comm;
|
*listenComm = comm;
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
ncclResult_t ncclSocketConnect(int dev, void* opaqueHandle, void** sendComm) {
|
ncclResult_t ncclSocketConnect(int dev, void* opaqueHandle, void** sendComm) {
|
||||||
|
if (dev < 0) { // data transfer socket is based on specified dev
|
||||||
|
return ncclInternalError;
|
||||||
|
}
|
||||||
struct ncclSocketComm* comm;
|
struct ncclSocketComm* comm;
|
||||||
NCCLCHECK(ncclSocketNewComm(&comm));
|
NCCLCHECK(ncclSocketNewComm(&comm));
|
||||||
struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle;
|
struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle;
|
||||||
NCCLCHECK(connectAddress(&comm->fd, &handle->connectAddr));
|
comm->nSocks = handle->nSocks;
|
||||||
|
comm->nThreads = handle->nThreads;
|
||||||
|
for (int i=0; i<comm->nSocks+1; i++) {
|
||||||
|
int tmpFd, offset=0;
|
||||||
|
NCCLCHECK(connectAddress(&tmpFd, &handle->connectAddr));
|
||||||
|
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &i, sizeof(int), &offset));
|
||||||
|
if (i == comm->nSocks) comm->ctrlFd = tmpFd;
|
||||||
|
else comm->fds[i] = tmpFd;
|
||||||
|
}
|
||||||
*sendComm = comm;
|
*sendComm = comm;
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) {
|
ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) {
|
||||||
struct ncclSocketComm* lComm = (struct ncclSocketComm*)listenComm;
|
struct ncclSocketListenComm* lComm = (struct ncclSocketListenComm*)listenComm;
|
||||||
struct ncclSocketComm* rComm;
|
struct ncclSocketComm* rComm;
|
||||||
NCCLCHECK(ncclSocketNewComm(&rComm));
|
NCCLCHECK(ncclSocketNewComm(&rComm));
|
||||||
struct sockaddr_in sockaddr;
|
rComm->nSocks = lComm->nSocks;
|
||||||
socklen_t socklen = sizeof(struct sockaddr_in);
|
rComm->nThreads = lComm->nThreads;
|
||||||
SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd);
|
for (int i=0; i<rComm->nSocks+1; i++) {
|
||||||
|
int tmpFd, sendSockIdx, offset=0;
|
||||||
|
struct sockaddr_in sockaddr;
|
||||||
|
socklen_t socklen = sizeof(struct sockaddr_in);
|
||||||
|
SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", tmpFd);
|
||||||
|
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &sendSockIdx, sizeof(int), &offset));
|
||||||
|
if (sendSockIdx == rComm->nSocks) rComm->ctrlFd = tmpFd;
|
||||||
|
else rComm->fds[sendSockIdx] = tmpFd;
|
||||||
|
}
|
||||||
*recvComm = rComm;
|
*recvComm = rComm;
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MAX_REQUESTS 128
|
ncclResult_t ncclSocketGetRequest(struct ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketRequest** req) {
|
||||||
|
|
||||||
ncclResult_t ncclSocketGetRequest(struct ncclSocketReqs* reqs, int op, void* data, int size, int fd, struct ncclSocketRequest** req) {
|
|
||||||
if (reqs->requests == NULL) {
|
|
||||||
NCCLCHECK(ncclCalloc(&reqs->requests, MAX_REQUESTS));
|
|
||||||
}
|
|
||||||
for (int i=0; i<MAX_REQUESTS; i++) {
|
for (int i=0; i<MAX_REQUESTS; i++) {
|
||||||
struct ncclSocketRequest* r = reqs->requests+i;
|
struct ncclSocketRequest* r = comm->requests+i;
|
||||||
if (r->used == 0) {
|
if (r->used == 0) {
|
||||||
r->op = op;
|
r->op = op;
|
||||||
r->data = data;
|
r->data = data;
|
||||||
r->size = size;
|
r->size = size;
|
||||||
r->fd = fd;
|
r->ctrlFd = comm->ctrlFd;
|
||||||
r->offset = -1;
|
|
||||||
r->used = 1;
|
r->used = 1;
|
||||||
|
r->comm = comm;
|
||||||
|
r->nSubs = 0;
|
||||||
*req = r;
|
*req = r;
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
WARN("Socket : unable to allocate requests");
|
WARN("NET/Socket : unable to allocate requests");
|
||||||
|
return ncclInternalError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ncclSocketGetTask (added by this change): queue one chunk of a send/recv as a task for a
// helper thread and return the task through *req.
// The worker index is comm->nextFd % comm->nThreads. The task uses socket
// comm->fds[comm->nextFd], after which nextFd advances round-robin over comm->nSocks.
// On first use of a worker (queue->tasks == NULL) its task array, mutex, condition variable
// and thread (running persistentSocketThread) are created.
// Only the slot at queue->next is examined: if it is free it is filled and the worker is
// signalled; otherwise a warning is logged and ncclInternalError is returned.
// The trailing `return ncclInternalError;` and `}` appear twice because the diff shows
// unchanged context lines in both columns.
// NOTE(review): the return values of pthread_mutex_init, pthread_cond_init and
// pthread_create are discarded here.
ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketTask** req) {
|
||||||
|
int tid = comm->nextFd % comm->nThreads;
|
||||||
|
struct ncclSocketThreadResources* res = comm->threadResources+tid;
|
||||||
|
struct ncclSocketTaskQueue* queue = &res->threadTaskQueue;
|
||||||
|
// create helper threads and prepare per-thread task queue
|
||||||
|
if (queue->tasks == NULL) {
|
||||||
|
NCCLCHECK(ncclCalloc(&queue->tasks, MAX_QUEUE_LEN));
|
||||||
|
queue->next = 0;
|
||||||
|
res->comm = comm;
|
||||||
|
pthread_mutex_init(&res->threadLock, NULL);
|
||||||
|
pthread_cond_init(&res->threadCond, NULL);
|
||||||
|
pthread_create(comm->helperThread+tid, NULL, persistentSocketThread, res);
|
||||||
|
}
|
||||||
|
// only the slot at the queue cursor is considered; a busy slot falls through to the error path
struct ncclSocketTask* r = queue->tasks+queue->next;
|
||||||
|
if (r->used == 0) {
|
||||||
|
r->op = op;
|
||||||
|
r->data = data;
|
||||||
|
r->size = size;
|
||||||
|
r->fd = comm->fds[comm->nextFd];
|
||||||
|
r->offset = 0;
|
||||||
|
r->result = ncclSuccess;
|
||||||
|
comm->nextFd = (comm->nextFd + 1) % comm->nSocks;
|
||||||
|
r->used = 1;
|
||||||
|
*req = r;
|
||||||
|
// under the worker's lock: advance the queue cursor, set state to start, signal threadCond
// NOTE(review): the task fields above are written before the lock is taken — confirm the
// worker cannot observe the slot before this point
pthread_mutex_lock(&res->threadLock);
|
||||||
|
queue->next = (queue->next+1)%MAX_QUEUE_LEN;
|
||||||
|
res->state = start;
|
||||||
|
pthread_cond_signal(&res->threadCond);
|
||||||
|
pthread_mutex_unlock(&res->threadLock);
|
||||||
|
return ncclSuccess;
|
||||||
|
}
|
||||||
|
WARN("NET/Socket : unable to allocate subtasks");
|
||||||
return ncclInternalError;
|
return ncclInternalError;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,15 +363,15 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
|
|||||||
WARN("NET/Socket : test called with NULL request");
|
WARN("NET/Socket : test called with NULL request");
|
||||||
return ncclInternalError;
|
return ncclInternalError;
|
||||||
}
|
}
|
||||||
if (r->offset == -1) { /* try to send/recv size */
|
if (r->used == 1) { /* try to send/recv size */
|
||||||
int data = r->size;
|
int data = r->size;
|
||||||
int offset = 0;
|
int offset = 0;
|
||||||
NCCLCHECK(socketProgress(r->op, r->fd, &data, sizeof(int), &offset));
|
NCCLCHECK(socketProgress(r->op, r->ctrlFd, &data, sizeof(int), &offset));
|
||||||
|
|
||||||
if (offset == 0) return ncclSuccess; /* Not ready -- retry later */
|
if (offset == 0) return ncclSuccess; /* Not ready -- retry later */
|
||||||
|
|
||||||
// Not sure we could ever receive less than 4 bytes, but just in case ...
|
// Not sure we could ever receive less than 4 bytes, but just in case ...
|
||||||
if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->fd, &data, sizeof(int), &offset));
|
if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->ctrlFd, &data, sizeof(int), &offset));
|
||||||
|
|
||||||
// Check size is less or equal to the size provided by the user
|
// Check size is less or equal to the size provided by the user
|
||||||
if (r->op == NCCL_SOCKET_RECV && data > r->size) {
|
if (r->op == NCCL_SOCKET_RECV && data > r->size) {
|
||||||
@ -201,15 +379,33 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
|
|||||||
return ncclInternalError;
|
return ncclInternalError;
|
||||||
}
|
}
|
||||||
r->size = data;
|
r->size = data;
|
||||||
r->offset = 0;
|
r->used = 2; // done exchanging size
|
||||||
|
// divide into subtasks
|
||||||
|
int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks));
|
||||||
|
int chunkOffset = 0, i = 0;
|
||||||
|
while (chunkOffset < r->size) {
|
||||||
|
int chunkSize = std::min(taskSize, r->size-chunkOffset);
|
||||||
|
NCCLCHECK(ncclSocketGetTask(r->comm, r->op, (char*)(r->data)+chunkOffset, chunkSize, r->tasks+i++));
|
||||||
|
chunkOffset += chunkSize;
|
||||||
|
}
|
||||||
|
r->nSubs = i;
|
||||||
}
|
}
|
||||||
if (r->offset < r->size) {
|
if (r->used == 2) { // already exchanged size
|
||||||
NCCLCHECK(socketProgress(r->op, r->fd, r->data, r->size, &r->offset));
|
int nCompleted = 0;
|
||||||
}
|
for (int i=0; i<r->nSubs; i++) {
|
||||||
if (r->offset == r->size) {
|
struct ncclSocketTask* sub = r->tasks[i];
|
||||||
if (size) *size = r->size;
|
if (sub->result != ncclSuccess) return sub->result;
|
||||||
*done = 1;
|
if (sub->offset == sub->size) nCompleted++;
|
||||||
r->used = 0;
|
}
|
||||||
|
if (nCompleted == r->nSubs) {
|
||||||
|
if (size) *size = r->size;
|
||||||
|
*done = 1;
|
||||||
|
r->used = 0;
|
||||||
|
for (int i=0; i<r->nSubs; i++) {
|
||||||
|
struct ncclSocketTask* sub = r->tasks[i];
|
||||||
|
sub->used = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
@ -221,13 +417,13 @@ ncclResult_t ncclSocketDeregMr(void* comm, void* mhandle) { return ncclSuccess;
|
|||||||
|
|
||||||
// ncclSocketIsend: register a send of `size` bytes from `data` on sendComm by claiming a
// request via ncclSocketGetRequest with NCCL_SOCKET_SEND; the request is returned through
// *request. No data is transferred in this function and mhandle is not used.
// Unchanged lines appear twice (one per diff column). The only changed line is the
// ncclSocketGetRequest call: the removed form passes &comm->reqs and comm->fd, the added
// form passes the comm itself.
// Returns ncclSuccess, or the error propagated by NCCLCHECK.
ncclResult_t ncclSocketIsend(void* sendComm, void* data, int size, void* mhandle, void** request) {
|
ncclResult_t ncclSocketIsend(void* sendComm, void* data, int size, void* mhandle, void** request) {
|
||||||
struct ncclSocketComm* comm = (struct ncclSocketComm*)sendComm;
|
struct ncclSocketComm* comm = (struct ncclSocketComm*)sendComm;
|
||||||
NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_SEND, data, size, comm->fd, (struct ncclSocketRequest**)request));
|
NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_SEND, data, size, (struct ncclSocketRequest**)request));
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
// ncclSocketIrecv: register a receive of up to `size` bytes into `data` on recvComm by
// claiming a request via ncclSocketGetRequest with NCCL_SOCKET_RECV; the request is returned
// through *request. No data is transferred in this function and mhandle is not used.
// Unchanged lines appear twice (one per diff column). The only changed line is the
// ncclSocketGetRequest call: the removed form passes &comm->reqs and comm->fd, the added
// form passes the comm itself.
// Returns ncclSuccess, or the error propagated by NCCLCHECK.
ncclResult_t ncclSocketIrecv(void* recvComm, void* data, int size, void* mhandle, void** request) {
|
ncclResult_t ncclSocketIrecv(void* recvComm, void* data, int size, void* mhandle, void** request) {
|
||||||
struct ncclSocketComm* comm = (struct ncclSocketComm*)recvComm;
|
struct ncclSocketComm* comm = (struct ncclSocketComm*)recvComm;
|
||||||
NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_RECV, data, size, comm->fd, (struct ncclSocketRequest**)request));
|
NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_RECV, data, size, (struct ncclSocketRequest**)request));
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,11 +432,33 @@ ncclResult_t ncclSocketFlush(void* recvComm, void* data, int size, void* mhandle
|
|||||||
return ncclInternalError;
|
return ncclInternalError;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ncclSocketCloseListen (added by this change): release a listen comm.
// opaqueComm is cast to struct ncclSocketListenComm*. When it is non-NULL, its fd is closed
// unless it is -1, and the comm is then freed. A NULL comm is accepted and ignored.
// Always returns ncclSuccess; the result of close() is not inspected.
ncclResult_t ncclSocketCloseListen(void* opaqueComm) {
|
||||||
|
struct ncclSocketListenComm* comm = (struct ncclSocketListenComm*)opaqueComm;
|
||||||
|
if (comm) {
|
||||||
|
if (comm->fd != -1) close(comm->fd);
|
||||||
|
free(comm);
|
||||||
|
}
|
||||||
|
return ncclSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
ncclResult_t ncclSocketClose(void* opaqueComm) {
|
ncclResult_t ncclSocketClose(void* opaqueComm) {
|
||||||
struct ncclSocketComm* comm = (struct ncclSocketComm*)opaqueComm;
|
struct ncclSocketComm* comm = (struct ncclSocketComm*)opaqueComm;
|
||||||
if (comm) {
|
if (comm) {
|
||||||
free(comm->reqs.requests);
|
for (int i=0; i<comm->nThreads; i++) {
|
||||||
close(comm->fd);
|
struct ncclSocketThreadResources* res = comm->threadResources+i;
|
||||||
|
if (comm->helperThread[i]) {
|
||||||
|
pthread_mutex_lock(&res->threadLock);
|
||||||
|
res->state = stop;
|
||||||
|
pthread_cond_signal(&res->threadCond);
|
||||||
|
pthread_mutex_unlock(&res->threadLock);
|
||||||
|
pthread_join(comm->helperThread[i], NULL);
|
||||||
|
}
|
||||||
|
free(res->threadTaskQueue.tasks);
|
||||||
|
}
|
||||||
|
if (comm->ctrlFd != -1) close(comm->ctrlFd);
|
||||||
|
for (int i=0; i<comm->nSocks; i++) {
|
||||||
|
if (comm->fds[i] != -1) close(comm->fds[i]);
|
||||||
|
}
|
||||||
free(comm);
|
free(comm);
|
||||||
}
|
}
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
@ -263,5 +481,5 @@ ncclNet_t ncclNetSocket = {
|
|||||||
ncclSocketTest,
|
ncclSocketTest,
|
||||||
ncclSocketClose,
|
ncclSocketClose,
|
||||||
ncclSocketClose,
|
ncclSocketClose,
|
||||||
ncclSocketClose
|
ncclSocketCloseListen
|
||||||
};
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user