Fix #209: improve socket transport performance
  Split transfers over multiple sockets
  Launch multiple threads to drive sockets
  Detect AWS NICs and set nsockets/nthreads accordingly
This commit is contained in:
Ke Wen 2019-06-25 13:22:47 -07:00
parent 0ceaec9cee
commit 7c72dee660
7 changed files with 425 additions and 96 deletions

View File

@ -1,6 +1,6 @@
##### version
NCCL_MAJOR := 2
NCCL_MINOR := 4
NCCL_PATCH := 7
NCCL_PATCH := 8
NCCL_SUFFIX :=
PKG_REVISION := 1

View File

@ -9,37 +9,145 @@
#include "utils.h"
#include "bootstrap.h"
#include "net.h"
#include "socket.h"
#include <unistd.h>
#include <sys/types.h>
// Always use sockets for bootstrap
ncclNet_t* ncclBootstrapNet = &ncclNetSocket;
struct bootstrapNetHandle {
union socketAddress connectAddr;
};
static ncclResult_t bootstrapNetListen(int dev, void* handle, void** listenComm) { NCCLCHECK(ncclBootstrapNet->listen(dev, handle, listenComm)); return ncclSuccess; }
static ncclResult_t bootstrapNetConnect(int dev, void* handle, void** sendComm) { NCCLCHECK(ncclBootstrapNet->connect(dev, handle, sendComm)); return ncclSuccess; }
static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) { NCCLCHECK(ncclBootstrapNet->accept(listenComm, recvComm)); return ncclSuccess; }
static ncclResult_t bootstrapNetTest(void* request, int* done, int* size) { NCCLCHECK(ncclBootstrapNet->test(request, done, size)); return ncclSuccess; }
static ncclResult_t bootstrapNetCloseSend(void* sendComm) { NCCLCHECK(ncclBootstrapNet->closeSend(sendComm)); return ncclSuccess; }
static ncclResult_t bootstrapNetCloseRecv(void* recvComm) { NCCLCHECK(ncclBootstrapNet->closeRecv(recvComm)); return ncclSuccess; }
static ncclResult_t bootstrapNetCloseListen(void* listenComm) { NCCLCHECK(ncclBootstrapNet->closeListen(listenComm)); return ncclSuccess; }
struct bootstrapNetComm {
int fd;
};
// Additional sync functions based on async + test for bootstrap, using host ptrs.
/* Init functions */
static char bootstrapNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
static union socketAddress bootstrapNetIfAddrs[MAX_IFS];
static int bootstrapNetIfs = -1;
pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;
/**
 * Discover the socket interfaces used by the bootstrap network.
 *
 * Discovery (findInterfaces) runs under bootstrapNetLock the first time this
 * is called, i.e. while bootstrapNetIfs still holds its initial value of -1.
 * Calls made after bootstrapNetIfs has been assigned return ncclSuccess
 * without re-running discovery.
 *
 * Returns ncclInternalError from the call that ran discovery when no
 * interface was found, ncclSuccess otherwise.
 */
ncclResult_t bootstrapNetInit() {
  if (bootstrapNetIfs != -1) return ncclSuccess;

  ncclResult_t ret = ncclSuccess;
  pthread_mutex_lock(&bootstrapNetLock);
  // Re-check under the lock: another thread may have completed discovery
  // between the unlocked test above and acquiring the mutex.
  if (bootstrapNetIfs == -1) {
    bootstrapNetIfs = findInterfaces(bootstrapNetIfNames, bootstrapNetIfAddrs, MAX_IF_NAME_SIZE, MAX_IFS);
    if (bootstrapNetIfs <= 0) {
      WARN("Bootstrap : no socket interface found");
      // Record the failure instead of returning here, so that the mutex is
      // always released below.
      ret = ncclInternalError;
    } else {
      // Build a one-line " [idx]name:addr" summary of every interface found.
      char line[1024];
      char addrline[1024];
      line[0] = '\0';
      for (int i=0; i<bootstrapNetIfs; i++) {
        snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%s", i, bootstrapNetIfNames+i*MAX_IF_NAME_SIZE,
            socketToString(&bootstrapNetIfAddrs[i].sa, addrline));
      }
      line[1023] = '\0';
      INFO(NCCL_INIT, "Bootstrap : Using%s", line);
    }
  }
  pthread_mutex_unlock(&bootstrapNetLock);
  return ret;
}
// Allocate one bootstrapNetComm through ncclCalloc and mark its socket as
// not yet opened (fd == -1). The new object is returned through *comm.
static ncclResult_t bootstrapNetNewComm(struct bootstrapNetComm** comm) {
  NCCLCHECK(ncclCalloc(comm, 1));
  struct bootstrapNetComm* fresh = *comm;
  fresh->fd = -1;
  return ncclSuccess;
}
// Copy the address of bootstrap interface `dev` into *addr.
// Returns ncclInternalError when `dev` is not a valid index into the
// discovered interface table (negative, or >= bootstrapNetIfs).
static ncclResult_t bootstrapNetGetSocketAddr(int dev, union socketAddress* addr) {
  // Reject negative indices as well: they would read before the start of
  // bootstrapNetIfAddrs.
  if (dev < 0 || dev >= bootstrapNetIfs) return ncclInternalError;
  memcpy(addr, bootstrapNetIfAddrs+dev, sizeof(*addr));
  return ncclSuccess;
}
/* Socket Interface Selection type
 * Negative sentinel values that bootstrapNetListen() accepts in place of a
 * device index:
 *   findSubnetIf : the handle holds a remote address; listen on a local
 *                  interface that is in the same subnet as that address.
 *   dontCareIf   : presumably the remaining case, where the handle already
 *                  holds the local address to listen on -- confirm against callers.
 */
enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };
/**
 * Create a listening bootstrap socket.
 *
 * dev           >= 0         : listen on the address of bootstrap interface `dev`.
 *               findSubnetIf : the handle holds a remote address on entry; a
 *                              local interface in the same subnet is selected.
 *               otherwise    : the handle already holds the local address.
 * opaqueHandle  a struct bootstrapNetHandle; its connectAddr is passed to
 *               createListenSocket and, in the first two cases, is first
 *               overwritten with the selected local address.
 * listenComm    receives the new bootstrapNetComm on success.
 *
 * Returns ncclSystemError when no interface matches the remote subnet, or the
 * error reported by the failing helper.
 */
static ncclResult_t bootstrapNetListen(int dev, void* opaqueHandle, void** listenComm) {
  struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle;
  static_assert(sizeof(struct bootstrapNetHandle) < NCCL_NET_HANDLE_MAXSIZE, "bootstrapNetHandle size too large");
  // if dev >= 0, listen based on dev
  if (dev >= 0) {
    NCCLCHECK(bootstrapNetGetSocketAddr(dev, &(handle->connectAddr)));
  } else if (dev == findSubnetIf) {
    // handle stores a remote address
    // need to find a local addr that is in the same network as the remote addr
    union socketAddress localAddr;
    char ifName[MAX_IF_NAME_SIZE];
    if (findInterfaceMatchSubnet(ifName, &localAddr, handle->connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
      WARN("NET/Socket : No usable listening interface found");
      return ncclSystemError;
    }
    // pass the local address back
    memcpy(&handle->connectAddr, &localAddr, sizeof(handle->connectAddr));
  } // Otherwise, handle stores a local address
  struct bootstrapNetComm* comm;
  NCCLCHECK(bootstrapNetNewComm(&comm));
  // Release the freshly allocated comm if the socket cannot be created;
  // nothing else holds a reference to it yet.
  ncclResult_t listenResult = createListenSocket(&comm->fd, &handle->connectAddr);
  if (listenResult != ncclSuccess) free(comm);
  NCCLCHECK(listenResult);
  *listenComm = comm;
  return ncclSuccess;
}
/**
 * Connect to the address stored in a bootstrap handle.
 *
 * dev           not used by this function.
 * opaqueHandle  a struct bootstrapNetHandle whose connectAddr is the peer address.
 * sendComm      receives the new bootstrapNetComm on success.
 *
 * Returns the error reported by allocation or by connectAddress on failure.
 */
static ncclResult_t bootstrapNetConnect(int dev, void* opaqueHandle, void** sendComm) {
  struct bootstrapNetComm* comm;
  NCCLCHECK(bootstrapNetNewComm(&comm));
  struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle;
  // Release the freshly allocated comm if the connection cannot be made;
  // nothing else holds a reference to it yet.
  ncclResult_t connectResult = connectAddress(&comm->fd, &handle->connectAddr);
  if (connectResult != ncclSuccess) free(comm);
  NCCLCHECK(connectResult);
  *sendComm = comm;
  return ncclSuccess;
}
// Accept one incoming connection on a bootstrap listen comm.
// A new bootstrapNetComm is allocated first; the accepted descriptor is
// stored in its fd and the comm is handed back through *recvComm.
static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) {
  // Storage for the peer address filled in by accept(); it is not read afterwards.
  struct sockaddr_in sockaddr;
  socklen_t socklen = sizeof(struct sockaddr_in);

  struct bootstrapNetComm* lComm = (struct bootstrapNetComm*)listenComm;
  struct bootstrapNetComm* rComm;
  NCCLCHECK(bootstrapNetNewComm(&rComm));

  SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd);

  *recvComm = rComm;
  return ncclSuccess;
}
// Close the socket of a bootstrap comm and free the comm itself.
// A NULL comm is accepted and ignored. Always returns ncclSuccess.
static ncclResult_t bootstrapNetClose(void* opaqueComm) {
  if (opaqueComm == NULL) return ncclSuccess;

  struct bootstrapNetComm* netComm = (struct bootstrapNetComm*)opaqueComm;
  close(netComm->fd);
  free(netComm);
  return ncclSuccess;
}
// Close a send comm; delegates to bootstrapNetClose.
static ncclResult_t bootstrapNetCloseSend(void* sendComm) {
  NCCLCHECK(bootstrapNetClose(sendComm));
  return ncclSuccess;
}
// Close a receive comm; delegates to bootstrapNetClose.
static ncclResult_t bootstrapNetCloseRecv(void* recvComm) {
  NCCLCHECK(bootstrapNetClose(recvComm));
  return ncclSuccess;
}
// Close a listen comm; delegates to bootstrapNetClose.
static ncclResult_t bootstrapNetCloseListen(void* listenComm) {
  NCCLCHECK(bootstrapNetClose(listenComm));
  return ncclSuccess;
}
// Additional sync functions
static ncclResult_t bootstrapNetSend(void* sendComm, void* data, int size) {
void* request, *mhandle;
NCCLCHECK(ncclBootstrapNet->regMr(sendComm, data, size, NCCL_PTR_HOST, &mhandle));
NCCLCHECK(ncclBootstrapNet->isend(sendComm, data, size, mhandle, &request));
NCCLCHECK(ncclBootstrapNet->deregMr(sendComm, mhandle));
int done = 0;
while (!done) NCCLCHECK(bootstrapNetTest(request, &done, NULL));
struct bootstrapNetComm* comm = (struct bootstrapNetComm*)sendComm;
NCCLCHECK(socketSend(comm->fd, &size, sizeof(int)));
NCCLCHECK(socketSend(comm->fd, data, size));
return ncclSuccess;
}
static ncclResult_t bootstrapNetRecv(void* recvComm, void* data, int size) {
void* request, *mhandle;
NCCLCHECK(ncclBootstrapNet->regMr(recvComm, data, size, NCCL_PTR_HOST, &mhandle));
NCCLCHECK(ncclBootstrapNet->irecv(recvComm, data, size, mhandle, &request));
NCCLCHECK(ncclBootstrapNet->deregMr(recvComm, mhandle));
int done = 0;
while (!done) NCCLCHECK(bootstrapNetTest(request, &done, NULL));
struct bootstrapNetComm* comm = (struct bootstrapNetComm*)recvComm;
int recvSize;
NCCLCHECK(socketReceive(comm->fd, &recvSize, sizeof(int)));
if (recvSize > size) {
WARN("Message truncated : received %d bytes instead of %d\n", recvSize, size);
return ncclInternalError;
}
NCCLCHECK(socketReceive(comm->fd, data, std::min(recvSize, size)));
return ncclSuccess;
}
// Fill the connect address of a bootstrap handle from its textual form `str`,
// as parsed by GetSocketAddrFromString.
ncclResult_t bootstrapNetCreateHandle(void* opaqueHandle, const char* str) {
  union socketAddress* target = &((struct bootstrapNetHandle*) opaqueHandle)->connectAddr;
  NCCLCHECK(GetSocketAddrFromString(target, str));
  return ncclSuccess;
}
@ -148,7 +256,7 @@ ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out) {
char* env = getenv("NCCL_COMM_ID");
if (env) {
if (ncclSocketCreateHandle(&id->extHandleRoot, env) != 0) {
if (bootstrapNetCreateHandle(&id->extHandleRoot, env) != 0) {
WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
return ncclInvalidArgument;
}

View File

@ -9,6 +9,7 @@
#include "nccl.h"
ncclResult_t bootstrapNetInit();
ncclResult_t bootstrapCreateRoot(ncclUniqueId* commId, bool idFromEnv);
ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out);
ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commState);

View File

@ -13,11 +13,6 @@
extern ncclNet_t* ncclNet;
typedef char ncclNetHandle_t[NCCL_NET_HANDLE_MAXSIZE];
/* Socket Interface Selection type */
typedef enum { findSubnetIf = -1,
dontCareIf = -2
} ncclSocketIfSl_t;
// Translation to external API
static const char* ncclNetName() { return ncclNet->name; }
static ncclResult_t ncclNetDevices(int* ndev) { NCCLCHECK(ncclNet->devices(ndev)); return ncclSuccess; }
@ -36,7 +31,6 @@ static ncclResult_t ncclNetCloseSend(void* sendComm) { NCCLCHECK(ncclNet->closeS
static ncclResult_t ncclNetCloseRecv(void* recvComm) { NCCLCHECK(ncclNet->closeRecv(recvComm)); return ncclSuccess; }
static ncclResult_t ncclNetCloseListen(void* listenComm) { NCCLCHECK(ncclNet->closeListen(listenComm)); return ncclSuccess; }
extern ncclResult_t ncclSocketCreateHandle(void* opaqueHandle, const char* str);
extern ncclNet_t ncclNetIb;
extern ncclNet_t ncclNetSocket;

View File

@ -42,7 +42,7 @@ static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
return buf;
}
static inline short socketToPort(struct sockaddr *saddr) {
static inline uint16_t socketToPort(struct sockaddr *saddr) {
return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port);
}
@ -161,7 +161,10 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) {
}
static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress remoteAddr, int ifNameMaxSize, int maxIfs) {
char line[1024], line_a[1024];
#ifdef ENABLE_TRACE
char line[1024];
#endif
char line_a[1024];
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
@ -185,7 +188,7 @@ static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAdd
// Store the interface name
strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
INFO(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a));
TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a));
found++;
if (found == maxIfs) break;
}
@ -390,12 +393,12 @@ retry:
#define NCCL_SOCKET_SEND 0
#define NCCL_SOCKET_RECV 1
static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) {
int bytes = 0;
char* data = (char*)ptr;
do {
if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), MSG_DONTWAIT);
if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), MSG_DONTWAIT);
if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
if (op == NCCL_SOCKET_RECV && bytes == 0) {
WARN("Net : Connection closed by remote peer");
return ncclSystemError;
@ -413,9 +416,13 @@ static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* off
return ncclSuccess;
}
// Non-blocking progress on a send or receive: forwards to socketProgressOpt
// with block == 0, which makes it pass MSG_DONTWAIT to send()/recv().
static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
  const int nonBlocking = 0;
  return socketProgressOpt(op, fd, ptr, size, offset, nonBlocking);
}
static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) {
while (*offset < size)
NCCLCHECK(socketProgress(op, fd, ptr, size, offset));
NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1));
return ncclSuccess;
}

View File

@ -124,14 +124,15 @@ cleanup:
}
ncclResult_t initNet() {
// Always initialize sockets as we use it for bootstrap
NCCLCHECK(initNet(&ncclNetSocket));
// Always initialize bootstrap network
NCCLCHECK(bootstrapNetInit());
NCCLCHECK(initNetPlugin(&ncclNet));
if (ncclNet != NULL) return ncclSuccess;
if (initNet(&ncclNetIb) == ncclSuccess) {
ncclNet = &ncclNetIb;
} else {
NCCLCHECK(initNet(&ncclNetSocket));
ncclNet = &ncclNetSocket;
}
return ncclSuccess;

View File

@ -8,6 +8,7 @@
#include "core.h"
#include "socket.h"
#include "net.h"
#include "param.h"
#include <assert.h>
#include <pthread.h>
@ -15,6 +16,7 @@
#include <stdlib.h>
#include <poll.h>
#include <limits.h>
#include <fcntl.h>
/* Init functions */
static char ncclNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
@ -68,7 +70,7 @@ ncclResult_t ncclSocketPciPath(int dev, char** path) {
return ncclSuccess;
}
static ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
if (dev >= ncclNetIfs) return ncclInternalError;
memcpy(addr, ncclNetIfAddrs+dev, sizeof(*addr));
return ncclSuccess;
@ -76,105 +78,281 @@ static ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
/* Communication functions */
#define MAX_SOCKETS 64
#define MAX_THREADS 16
#define MAX_REQUESTS 128
#define MAX_QUEUE_LEN MAX_REQUESTS
#define MIN_CHUNKSIZE (64*1024)
NCCL_PARAM(SocketNsocksPerThread, "NSOCKS_PERTHREAD", -2);
NCCL_PARAM(SocketNthreads, "SOCKET_NTHREADS", -2);
struct ncclSocketHandle {
union socketAddress connectAddr;
int nSocks;
int nThreads;
};
struct ncclSocketRequest {
struct ncclSocketTask {
int op;
void* data;
int size;
int fd;
int offset;
int used;
ncclResult_t result;
};
struct ncclSocketReqs {
struct ncclSocketRequest* requests;
struct ncclSocketRequest {
int op;
void* data;
int size;
int ctrlFd;
int used;
struct ncclSocketComm* comm;
struct ncclSocketTask* tasks[MAX_SOCKETS];
int nSubs;
};
struct ncclSocketTaskQueue {
int next;
struct ncclSocketTask* tasks;
};
enum threadState {start, stop};
struct ncclSocketThreadResources {
struct ncclSocketTaskQueue threadTaskQueue;
enum threadState state;
struct ncclSocketComm* comm;
pthread_mutex_t threadLock;
pthread_cond_t threadCond;
};
struct ncclSocketListenComm {
int fd;
int nSocks;
int nThreads;
};
struct ncclSocketComm {
int fd;
struct ncclSocketReqs reqs;
int ctrlFd;
int fds[MAX_SOCKETS];
int nSocks;
int nThreads;
int nextFd;
struct ncclSocketRequest requests[MAX_REQUESTS];
pthread_t helperThread[MAX_THREADS];
struct ncclSocketThreadResources threadResources[MAX_THREADS];
};
ncclResult_t ncclSocketNewComm(struct ncclSocketComm** comm) {
/**
 * Helper thread body: drives the socket tasks queued for one thread.
 *
 * args_ points to this thread's ncclSocketThreadResources. The thread walks
 * its task queue in groups of nSocksPerThread slots, calling socketProgress()
 * on every in-use, unfinished task and revisiting a group until all of its
 * tasks have completed. When a full pass over the queue finds no work, the
 * thread sleeps on threadCond until the queue's `next` index changes or the
 * state becomes `stop`.
 *
 * Returns NULL when the state is `stop`, or immediately after a socket error
 * (the failing task's `result` field holds the error).
 */
void* persistentSocketThread(void *args_) {
  struct ncclSocketThreadResources* resource = (struct ncclSocketThreadResources*)args_;
  struct ncclSocketComm* comm = resource->comm;
  volatile enum threadState* state = &resource->state;
  struct ncclSocketTaskQueue* myQueue = &resource->threadTaskQueue;
  int nSocksPerThread = comm->nSocks / comm->nThreads;
  // A group size of 0 would keep the queue scan below from ever advancing.
  if (nSocksPerThread < 1) nSocksPerThread = 1;
  while (1) {
    int idle = 1;
    int mark = myQueue->next; // mark newest task seen
    for (int i=0; i<MAX_QUEUE_LEN; i+=nSocksPerThread) {
      // Clamp the last group: when nSocksPerThread does not divide
      // MAX_QUEUE_LEN the group would otherwise run past the end of tasks[].
      int groupEnd = i + nSocksPerThread;
      if (groupEnd > MAX_QUEUE_LEN) groupEnd = MAX_QUEUE_LEN;
      int repeat;
      do {
        repeat = 0;
        for (int k=i; k<groupEnd; k++) {
          struct ncclSocketTask* r = myQueue->tasks+k;
          if (r->used == 1 && r->offset < r->size) {
            r->result = socketProgress(r->op, r->fd, r->data, r->size, &r->offset);
            if (r->result != ncclSuccess) {
              WARN("NET/Socket : socket progress error");
              return NULL;
            }
            idle = 0;
            if (r->offset < r->size) repeat = 1;
          }
        }
      } while (repeat);
    }
    if (idle) {
      pthread_mutex_lock(&resource->threadLock);
      while (mark == myQueue->next && *state != stop) { // no new tasks, wait
        pthread_cond_wait(&resource->threadCond, &resource->threadLock);
      }
      pthread_mutex_unlock(&resource->threadLock);
    }
    if (*state == stop) return NULL;
  }
}
/**
 * Decide how many sockets and helper threads to use for interface `dev`.
 *
 * Values come from the SocketNthreads / SocketNsocksPerThread parameters.
 * A value of -2 means "auto-detect": the PCI vendor id of the interface is
 * read from /sys/class/net/<ifname>/device/vendor; vendor 0x1d0f (AWS)
 * selects 2 threads and 8 sockets per thread, anything else (including an
 * unreadable vendor file) selects 1 and 1.
 *
 * The thread count is capped at MAX_THREADS and the total socket count at
 * MAX_SOCKETS.
 *
 * ns  receives the total number of sockets (sockets per thread * threads).
 * nt  receives the number of threads.
 */
ncclResult_t ncclSocketGetNsockNthread(int dev, int* ns, int* nt) {
  int nSocksPerThread = ncclParamSocketNsocksPerThread();
  int nThreads = ncclParamSocketNthreads();
  if (nThreads > MAX_THREADS) {
    WARN("NET/Socket : NCCL_SOCKET_NTHREADS is greater than the maximum allowed, setting to %d", MAX_THREADS);
    nThreads = MAX_THREADS;
  }
  if (nThreads == -2 || nSocksPerThread == -2) {
    // Auto-detection
    int autoNt=1, autoNs=1;
    char vendorPath[PATH_MAX];
    snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", ncclNetIfNames+dev*MAX_IF_NAME_SIZE);
    char* rPath = realpath(vendorPath, NULL);
    // realpath() returns NULL when the path cannot be resolved; only call
    // open() with a valid string and treat a NULL result like a failed open.
    int fd = -1;
    if (rPath != NULL) {
      fd = open(rPath, O_RDONLY);
      free(rPath);
    }
    if (fd == -1) {
      // Could not find device vendor. This is handled silently so
      // we don't want to print an INFO error.
      TRACE(NCCL_NET, "Open of %s failed : %s\n", vendorPath, strerror(errno));
      goto end;
    }
    char vendor[7];
    // Pre-fill with a non-matching id; copying 7 bytes includes the NUL, and
    // the read below overwrites at most the first 6 bytes.
    strncpy(vendor, "0x0000", 7);
    int len;
    SYSCHECKVAL(read(fd, vendor, 6), "read", len);
    SYSCHECK(close(fd), "close");
    if (strcmp(vendor, "0x1d0f") == 0) { // AWS
      autoNt = 2;
      autoNs = 8;
    }
end:
    if (nThreads == -2) nThreads = autoNt;
    if (nSocksPerThread == -2) nSocksPerThread = autoNs;
  }
  int nSocks = nSocksPerThread * nThreads;
  if (nSocks > MAX_SOCKETS) {
    nSocksPerThread = MAX_SOCKETS/nThreads;
    WARN("NET/Socket : the total number of sockets is greater than the maximum allowed, setting NCCL_NSOCKS_PERTHREAD to %d", nSocksPerThread);
    nSocks = nSocksPerThread * nThreads;
  }
  *ns = nSocks;
  *nt = nThreads;
  INFO(NCCL_INIT, "NET/Socket: Using %d threads and %d sockets per thread", nThreads, nSocksPerThread);
  return ncclSuccess;
}
// Allocate one ncclSocketListenComm through ncclCalloc and mark its listening
// socket as not yet opened (fd == -1). The object is returned through *comm.
ncclResult_t ncclSocketNewListenComm(struct ncclSocketListenComm** comm) {
  NCCLCHECK(ncclCalloc(comm, 1));
  struct ncclSocketListenComm* fresh = *comm;
  fresh->fd = -1;
  return ncclSuccess;
}
ncclResult_t ncclSocketCreateHandle(void* opaqueHandle, const char* str) {
struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle;
NCCLCHECK(GetSocketAddrFromString(&(handle->connectAddr), str));
ncclResult_t ncclSocketNewComm(struct ncclSocketComm** comm) {
NCCLCHECK(ncclCalloc(comm, 1));
(*comm)->ctrlFd = -1;
for (int i=0; i < MAX_SOCKETS; i++) {
(*comm)->fds[i] = -1;
}
(*comm)->nextFd = 0;
return ncclSuccess;
}
ncclResult_t ncclSocketListen(int dev, void* opaqueHandle, void** listenComm) {
if (dev < 0) { // data transfer socket is based on specified dev
return ncclInternalError;
}
struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle;
static_assert(sizeof(struct ncclSocketHandle) < NCCL_NET_HANDLE_MAXSIZE, "ncclSocketHandle size too large");
// if dev >= 0, listen based on dev
if (dev >= 0) {
NCCLCHECK(GetSocketAddr(dev, &(handle->connectAddr)));
} else if (dev == findSubnetIf) {
// handle stores a remote address
// need to find a local addr that is in the same network as the remote addr
union socketAddress localAddr;
char ifName[MAX_IF_NAME_SIZE];
if (findInterfaceMatchSubnet(ifName, &localAddr, handle->connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
WARN("NET/Socket : No usable listening interface found");
return ncclSystemError;
}
// pass the local address back
memcpy(&handle->connectAddr, &localAddr, sizeof(handle->connectAddr));
} // Otherwise, handle stores a local address
struct ncclSocketComm* comm;
NCCLCHECK(ncclSocketNewComm(&comm));
struct ncclSocketListenComm* comm;
NCCLCHECK(ncclSocketNewListenComm(&comm));
NCCLCHECK(GetSocketAddr(dev, &handle->connectAddr));
NCCLCHECK(createListenSocket(&comm->fd, &handle->connectAddr));
NCCLCHECK(ncclSocketGetNsockNthread(dev, &comm->nSocks, &comm->nThreads));
handle->nSocks = comm->nSocks;
handle->nThreads = comm->nThreads;
*listenComm = comm;
return ncclSuccess;
}
ncclResult_t ncclSocketConnect(int dev, void* opaqueHandle, void** sendComm) {
if (dev < 0) { // data transfer socket is based on specified dev
return ncclInternalError;
}
struct ncclSocketComm* comm;
NCCLCHECK(ncclSocketNewComm(&comm));
struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle;
NCCLCHECK(connectAddress(&comm->fd, &handle->connectAddr));
comm->nSocks = handle->nSocks;
comm->nThreads = handle->nThreads;
for (int i=0; i<comm->nSocks+1; i++) {
int tmpFd, offset=0;
NCCLCHECK(connectAddress(&tmpFd, &handle->connectAddr));
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &i, sizeof(int), &offset));
if (i == comm->nSocks) comm->ctrlFd = tmpFd;
else comm->fds[i] = tmpFd;
}
*sendComm = comm;
return ncclSuccess;
}
ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) {
struct ncclSocketComm* lComm = (struct ncclSocketComm*)listenComm;
struct ncclSocketListenComm* lComm = (struct ncclSocketListenComm*)listenComm;
struct ncclSocketComm* rComm;
NCCLCHECK(ncclSocketNewComm(&rComm));
struct sockaddr_in sockaddr;
socklen_t socklen = sizeof(struct sockaddr_in);
SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd);
rComm->nSocks = lComm->nSocks;
rComm->nThreads = lComm->nThreads;
for (int i=0; i<rComm->nSocks+1; i++) {
int tmpFd, sendSockIdx, offset=0;
struct sockaddr_in sockaddr;
socklen_t socklen = sizeof(struct sockaddr_in);
SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", tmpFd);
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &sendSockIdx, sizeof(int), &offset));
if (sendSockIdx == rComm->nSocks) rComm->ctrlFd = tmpFd;
else rComm->fds[sendSockIdx] = tmpFd;
}
*recvComm = rComm;
return ncclSuccess;
}
#define MAX_REQUESTS 128
ncclResult_t ncclSocketGetRequest(struct ncclSocketReqs* reqs, int op, void* data, int size, int fd, struct ncclSocketRequest** req) {
if (reqs->requests == NULL) {
NCCLCHECK(ncclCalloc(&reqs->requests, MAX_REQUESTS));
}
ncclResult_t ncclSocketGetRequest(struct ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketRequest** req) {
for (int i=0; i<MAX_REQUESTS; i++) {
struct ncclSocketRequest* r = reqs->requests+i;
struct ncclSocketRequest* r = comm->requests+i;
if (r->used == 0) {
r->op = op;
r->data = data;
r->size = size;
r->fd = fd;
r->offset = -1;
r->ctrlFd = comm->ctrlFd;
r->used = 1;
r->comm = comm;
r->nSubs = 0;
*req = r;
return ncclSuccess;
}
}
WARN("Socket : unable to allocate requests");
WARN("NET/Socket : unable to allocate requests");
return ncclInternalError;
}
/**
 * Queue one sub-task (a chunk of a send or receive) on the next data socket.
 *
 * The task is placed in the queue of thread (nextFd % nThreads) and bound to
 * socket fds[nextFd]; nextFd then advances round-robin over the comm's
 * sockets. The first time a thread's queue is used, the queue is allocated
 * and the helper thread is started.
 *
 * comm  communicator owning the sockets and helper threads.
 * op    NCCL_SOCKET_SEND or NCCL_SOCKET_RECV.
 * data  start of the chunk; size is its length in bytes.
 * req   receives the queued task on success.
 *
 * Returns ncclSystemError if the helper thread cannot be created and
 * ncclInternalError if the next queue slot is still in use.
 */
ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketTask** req) {
  int tid = comm->nextFd % comm->nThreads;
  struct ncclSocketThreadResources* res = comm->threadResources+tid;
  struct ncclSocketTaskQueue* queue = &res->threadTaskQueue;
  // create helper threads and prepare per-thread task queue
  if (queue->tasks == NULL) {
    NCCLCHECK(ncclCalloc(&queue->tasks, MAX_QUEUE_LEN));
    queue->next = 0;
    res->comm = comm;
    pthread_mutex_init(&res->threadLock, NULL);
    pthread_cond_init(&res->threadCond, NULL);
    // Create into a local first: helperThread[tid] is only written once the
    // thread really exists, so it stays zero on failure.
    pthread_t worker;
    int createErr = pthread_create(&worker, NULL, persistentSocketThread, res);
    if (createErr != 0) {
      // Without a worker nothing would ever progress tasks in this queue;
      // undo the setup so a later call can retry from scratch.
      WARN("NET/Socket : unable to create helper thread : %s", strerror(createErr));
      pthread_cond_destroy(&res->threadCond);
      pthread_mutex_destroy(&res->threadLock);
      free(queue->tasks);
      queue->tasks = NULL;
      return ncclSystemError;
    }
    comm->helperThread[tid] = worker;
  }
  struct ncclSocketTask* r = queue->tasks+queue->next;
  if (r->used == 0) {
    r->op = op;
    r->data = data;
    r->size = size;
    r->fd = comm->fds[comm->nextFd];
    r->offset = 0;
    r->result = ncclSuccess;
    comm->nextFd = (comm->nextFd + 1) % comm->nSocks;
    r->used = 1;
    *req = r;
    // Advance `next` under the lock and signal: the helper thread waits on
    // threadCond until `next` differs from the value it last observed.
    pthread_mutex_lock(&res->threadLock);
    queue->next = (queue->next+1)%MAX_QUEUE_LEN;
    res->state = start;
    pthread_cond_signal(&res->threadCond);
    pthread_mutex_unlock(&res->threadLock);
    return ncclSuccess;
  }
  WARN("NET/Socket : unable to allocate subtasks");
  return ncclInternalError;
}
@ -185,15 +363,15 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
WARN("NET/Socket : test called with NULL request");
return ncclInternalError;
}
if (r->offset == -1) { /* try to send/recv size */
if (r->used == 1) { /* try to send/recv size */
int data = r->size;
int offset = 0;
NCCLCHECK(socketProgress(r->op, r->fd, &data, sizeof(int), &offset));
NCCLCHECK(socketProgress(r->op, r->ctrlFd, &data, sizeof(int), &offset));
if (offset == 0) return ncclSuccess; /* Not ready -- retry later */
// Not sure we could ever receive less than 4 bytes, but just in case ...
if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->fd, &data, sizeof(int), &offset));
if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->ctrlFd, &data, sizeof(int), &offset));
// Check size is less or equal to the size provided by the user
if (r->op == NCCL_SOCKET_RECV && data > r->size) {
@ -201,15 +379,33 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
return ncclInternalError;
}
r->size = data;
r->offset = 0;
r->used = 2; // done exchanging size
// divide into subtasks
int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks));
int chunkOffset = 0, i = 0;
while (chunkOffset < r->size) {
int chunkSize = std::min(taskSize, r->size-chunkOffset);
NCCLCHECK(ncclSocketGetTask(r->comm, r->op, (char*)(r->data)+chunkOffset, chunkSize, r->tasks+i++));
chunkOffset += chunkSize;
}
r->nSubs = i;
}
if (r->offset < r->size) {
NCCLCHECK(socketProgress(r->op, r->fd, r->data, r->size, &r->offset));
}
if (r->offset == r->size) {
if (size) *size = r->size;
*done = 1;
r->used = 0;
if (r->used == 2) { // already exchanged size
int nCompleted = 0;
for (int i=0; i<r->nSubs; i++) {
struct ncclSocketTask* sub = r->tasks[i];
if (sub->result != ncclSuccess) return sub->result;
if (sub->offset == sub->size) nCompleted++;
}
if (nCompleted == r->nSubs) {
if (size) *size = r->size;
*done = 1;
r->used = 0;
for (int i=0; i<r->nSubs; i++) {
struct ncclSocketTask* sub = r->tasks[i];
sub->used = 0;
}
}
}
return ncclSuccess;
}
@ -221,13 +417,13 @@ ncclResult_t ncclSocketDeregMr(void* comm, void* mhandle) { return ncclSuccess;
ncclResult_t ncclSocketIsend(void* sendComm, void* data, int size, void* mhandle, void** request) {
struct ncclSocketComm* comm = (struct ncclSocketComm*)sendComm;
NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_SEND, data, size, comm->fd, (struct ncclSocketRequest**)request));
NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_SEND, data, size, (struct ncclSocketRequest**)request));
return ncclSuccess;
}
ncclResult_t ncclSocketIrecv(void* recvComm, void* data, int size, void* mhandle, void** request) {
struct ncclSocketComm* comm = (struct ncclSocketComm*)recvComm;
NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_RECV, data, size, comm->fd, (struct ncclSocketRequest**)request));
NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_RECV, data, size, (struct ncclSocketRequest**)request));
return ncclSuccess;
}
@ -236,11 +432,33 @@ ncclResult_t ncclSocketFlush(void* recvComm, void* data, int size, void* mhandle
return ncclInternalError;
}
// Close a listen comm: close its listening socket if one is open, then free
// the comm. A NULL comm is accepted and ignored. Always returns ncclSuccess.
ncclResult_t ncclSocketCloseListen(void* opaqueComm) {
  if (opaqueComm == NULL) return ncclSuccess;

  struct ncclSocketListenComm* lComm = (struct ncclSocketListenComm*)opaqueComm;
  if (lComm->fd != -1) close(lComm->fd);
  free(lComm);
  return ncclSuccess;
}
ncclResult_t ncclSocketClose(void* opaqueComm) {
struct ncclSocketComm* comm = (struct ncclSocketComm*)opaqueComm;
if (comm) {
free(comm->reqs.requests);
close(comm->fd);
for (int i=0; i<comm->nThreads; i++) {
struct ncclSocketThreadResources* res = comm->threadResources+i;
if (comm->helperThread[i]) {
pthread_mutex_lock(&res->threadLock);
res->state = stop;
pthread_cond_signal(&res->threadCond);
pthread_mutex_unlock(&res->threadLock);
pthread_join(comm->helperThread[i], NULL);
}
free(res->threadTaskQueue.tasks);
}
if (comm->ctrlFd != -1) close(comm->ctrlFd);
for (int i=0; i<comm->nSocks; i++) {
if (comm->fds[i] != -1) close(comm->fds[i]);
}
free(comm);
}
return ncclSuccess;
@ -263,5 +481,5 @@ ncclNet_t ncclNetSocket = {
ncclSocketTest,
ncclSocketClose,
ncclSocketClose,
ncclSocketClose
ncclSocketCloseListen
};