2.3.7-1
Improved LL tuning for multi-node jobs. Improved bootstrap for large job scaling. Fixed a hang during bootstrap due to socket reuse. Added operation name to the COLL INFO logging.
This commit is contained in:
parent
3202d6b393
commit
b56650c7f5
@ -1,6 +1,6 @@
|
|||||||
##### version
|
##### version
|
||||||
NCCL_MAJOR := 2
|
NCCL_MAJOR := 2
|
||||||
NCCL_MINOR := 3
|
NCCL_MINOR := 3
|
||||||
NCCL_PATCH := 5
|
NCCL_PATCH := 7
|
||||||
NCCL_SUFFIX :=
|
NCCL_SUFFIX :=
|
||||||
PKG_REVISION := 5
|
PKG_REVISION := 1
|
||||||
|
213
src/bootstrap.cu
213
src/bootstrap.cu
@ -40,7 +40,7 @@ static ncclResult_t bootstrapRecv(void* recvComm, void* data, int size) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct extId {
|
struct extId {
|
||||||
ncclNetHandle_t extHandle;
|
ncclNetHandle_t extHandleRoot;
|
||||||
void* extListenComm;
|
void* extListenComm;
|
||||||
uint64_t hostHash;
|
uint64_t hostHash;
|
||||||
pid_t pid;
|
pid_t pid;
|
||||||
@ -48,20 +48,11 @@ struct extId {
|
|||||||
pthread_t boostrapThread;
|
pthread_t boostrapThread;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct bootstrapOp {
|
|
||||||
int op;
|
|
||||||
int size;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct extInfo {
|
struct extInfo {
|
||||||
int rank;
|
int rank;
|
||||||
int nranks;
|
int nranks;
|
||||||
ncclNetHandle_t extHandle;
|
ncclNetHandle_t extHandleListenFromRoot;
|
||||||
};
|
ncclNetHandle_t extHandleRing;
|
||||||
|
|
||||||
enum {
|
|
||||||
BOOTSTRAP_ALLGATHER = 1,
|
|
||||||
BOOTSTRAP_RINGEXCHANGE,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <sys/resource.h>
|
#include <sys/resource.h>
|
||||||
@ -77,10 +68,10 @@ static ncclResult_t setFilesLimit() {
|
|||||||
static void *bootstrapRoot(void* commId) {
|
static void *bootstrapRoot(void* commId) {
|
||||||
struct extInfo info;
|
struct extInfo info;
|
||||||
struct extId* id = (struct extId*)commId;
|
struct extId* id = (struct extId*)commId;
|
||||||
struct bootstrapOp bop;
|
ncclNetHandle_t *extHandleBstrap = NULL; // for initial rank <-> root information exchange
|
||||||
void **extSendComm = NULL;
|
ncclNetHandle_t *extHandleRing = NULL; // for bootstrap ring creation
|
||||||
void **extRecvComm = NULL;
|
ncclNetHandle_t zero = { 0 }; // for sanity checking
|
||||||
int size, alloc_size = 0;
|
void* tmpComm;
|
||||||
char* data = NULL;
|
char* data = NULL;
|
||||||
ncclResult_t res;
|
ncclResult_t res;
|
||||||
setFilesLimit();
|
setFilesLimit();
|
||||||
@ -88,13 +79,14 @@ static void *bootstrapRoot(void* commId) {
|
|||||||
/* Receive addresses from all ranks */
|
/* Receive addresses from all ranks */
|
||||||
int nranks = 0, c = 0;
|
int nranks = 0, c = 0;
|
||||||
do {
|
do {
|
||||||
void* tmpRecvComm;
|
NCCLCHECKGOTO(bootstrapAccept(id->extListenComm, &tmpComm), res, out);
|
||||||
NCCLCHECKGOTO(bootstrapAccept(id->extListenComm, &tmpRecvComm), res, out);
|
NCCLCHECKGOTO(bootstrapRecv(tmpComm, &info, sizeof(info)), res, out);
|
||||||
NCCLCHECKGOTO(bootstrapRecv(tmpRecvComm, &info, sizeof(info)), res, out);
|
NCCLCHECKGOTO(bootstrapCloseRecv(tmpComm), res, out);
|
||||||
if (!c) {
|
|
||||||
extSendComm = (void**)calloc(info.nranks, sizeof(void*));
|
if (c == 0) {
|
||||||
extRecvComm = (void**)calloc(info.nranks, sizeof(void*));
|
extHandleBstrap = (ncclNetHandle_t *)calloc(info.nranks, sizeof(ncclNetHandle_t));
|
||||||
if (extSendComm == NULL || extRecvComm == NULL) {
|
extHandleRing = (ncclNetHandle_t *)calloc(info.nranks, sizeof(ncclNetHandle_t));
|
||||||
|
if (extHandleBstrap == NULL || extHandleRing == NULL) {
|
||||||
WARN("Bootstrap thread : failed to allocate memory");
|
WARN("Bootstrap thread : failed to allocate memory");
|
||||||
goto out;
|
goto out;
|
||||||
}
|
}
|
||||||
@ -106,69 +98,39 @@ static void *bootstrapRoot(void* commId) {
|
|||||||
goto out;
|
goto out;
|
||||||
}
|
}
|
||||||
|
|
||||||
extRecvComm[info.rank] = tmpRecvComm;
|
if (memcmp(&zero, &extHandleBstrap[info.rank], sizeof(ncclNetHandle_t)) != 0) {
|
||||||
NCCLCHECKGOTO(bootstrapConnect(0, info.extHandle, extSendComm+info.rank), res, out);
|
WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, nranks);
|
||||||
c++;
|
goto out;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the connection handle for connecting back to the ranks
|
||||||
|
memcpy(&extHandleBstrap[info.rank], info.extHandleListenFromRoot, sizeof(ncclNetHandle_t));
|
||||||
|
// Save the connection handle for the AllGather ring
|
||||||
|
memcpy(&extHandleRing[info.rank], info.extHandleRing, sizeof(ncclNetHandle_t));
|
||||||
|
|
||||||
|
++c;
|
||||||
} while (c < nranks);
|
} while (c < nranks);
|
||||||
|
|
||||||
do {
|
// Send the connect handle for the next rank in the AllGather ring
|
||||||
NCCLCHECKGOTO(bootstrapRecv(extRecvComm[0], &bop, sizeof(struct bootstrapOp)), res, out);
|
for (int r=0; r<nranks; ++r) {
|
||||||
if (bop.size == -1) {
|
int next = (r+1) % nranks;
|
||||||
break;
|
void *tmpSendComm;
|
||||||
} else {
|
NCCLCHECKGOTO(bootstrapConnect(0, extHandleBstrap[r], &tmpSendComm), res, out);
|
||||||
size = bop.size;
|
NCCLCHECKGOTO(bootstrapSend(tmpSendComm, &extHandleRing[next], sizeof(ncclNetHandle_t)), res, out);
|
||||||
if (size*nranks*2 > alloc_size) {
|
NCCLCHECKGOTO(bootstrapCloseSend(tmpSendComm), res, out);
|
||||||
if (data) free(data); data = NULL;
|
}
|
||||||
NCCLCHECKGOTO(ncclCalloc(&data, size*nranks*2), res, out);
|
|
||||||
alloc_size = size*nranks*2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (bop.op == BOOTSTRAP_ALLGATHER) {
|
|
||||||
for (int r=0; r<nranks; r++) {
|
|
||||||
NCCLCHECKGOTO(bootstrapRecv(extRecvComm[r], data+size*r, size), res, out);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int r=0; r<nranks; r++) {
|
|
||||||
NCCLCHECKGOTO(bootstrapSend(extSendComm[r], data, size*nranks), res, out);
|
|
||||||
}
|
|
||||||
} else if (bop.op == BOOTSTRAP_RINGEXCHANGE) {
|
|
||||||
// Receive from all and build total table
|
|
||||||
for (int r=0; r<nranks; r++) {
|
|
||||||
NCCLCHECKGOTO(bootstrapRecv(extRecvComm[r], data+r*2*size, 2*size), res, out);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get prev/next request from everyone and answer.
|
|
||||||
for (int r=0; r<nranks; r++) {
|
|
||||||
int offset;
|
|
||||||
NCCLCHECKGOTO(bootstrapRecv(extRecvComm[r], &offset, sizeof(int)), res, out);
|
|
||||||
NCCLCHECKGOTO(bootstrapSend(extSendComm[r], data+offset, size), res, out);
|
|
||||||
NCCLCHECKGOTO(bootstrapRecv(extRecvComm[r], &offset, sizeof(int)), res, out);
|
|
||||||
NCCLCHECKGOTO(bootstrapSend(extSendComm[r], data+offset, size), res, out);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
WARN("Bootstrap Root : invalid op type received %d", bop.op);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
} while (1);
|
|
||||||
|
|
||||||
out:
|
out:
|
||||||
bootstrapCloseListen(id->extListenComm);
|
bootstrapCloseListen(id->extListenComm);
|
||||||
for (int r=0; r<nranks; r++) {
|
|
||||||
if (extSendComm[r]) bootstrapCloseSend(extSendComm[r]);
|
|
||||||
if (extRecvComm[r]) bootstrapCloseRecv(extRecvComm[r]);
|
|
||||||
}
|
|
||||||
free(commId);
|
free(commId);
|
||||||
if (data) free(data);
|
if (data) free(data);
|
||||||
if (extSendComm) free(extSendComm);
|
|
||||||
if (extRecvComm) free(extRecvComm);
|
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
ncclResult_t bootstrapCreateRoot(ncclUniqueId* commId, bool idFromEnv) {
|
ncclResult_t bootstrapCreateRoot(ncclUniqueId* commId, bool idFromEnv) {
|
||||||
struct extId* id = (struct extId*)commId;
|
struct extId* id = (struct extId*)commId;
|
||||||
id->hostHash = getHostHash();
|
id->hostHash = getHostHash();
|
||||||
NCCLCHECK(bootstrapListen(idFromEnv ? dontCareIf : 0, &id->extHandle, &id->extListenComm));
|
NCCLCHECK(bootstrapListen(idFromEnv ? dontCareIf : 0, &id->extHandleRoot, &id->extListenComm));
|
||||||
ncclUniqueId* threadIdCopy;
|
ncclUniqueId* threadIdCopy;
|
||||||
NCCLCHECK(ncclCalloc(&threadIdCopy, 1));
|
NCCLCHECK(ncclCalloc(&threadIdCopy, 1));
|
||||||
memcpy(threadIdCopy, id, sizeof(ncclUniqueId));
|
memcpy(threadIdCopy, id, sizeof(ncclUniqueId));
|
||||||
@ -182,7 +144,7 @@ ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out) {
|
|||||||
|
|
||||||
char* env = getenv("NCCL_COMM_ID");
|
char* env = getenv("NCCL_COMM_ID");
|
||||||
if (env) {
|
if (env) {
|
||||||
if (ncclSocketCreateHandle(&id->extHandle, env) != 0) {
|
if (ncclSocketCreateHandle(&id->extHandleRoot, env) != 0) {
|
||||||
WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
|
WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
|
||||||
return ncclInvalidArgument;
|
return ncclInvalidArgument;
|
||||||
}
|
}
|
||||||
@ -196,10 +158,12 @@ ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct extState {
|
struct extState {
|
||||||
void* extRecvComm;
|
void* extBstrapRingRecvComm;
|
||||||
void* extSendComm;
|
void* extBstrapRingSendComm;
|
||||||
|
ncclNetHandle_t extBstrapRootHandle;
|
||||||
int rank;
|
int rank;
|
||||||
int nranks;
|
int nranks;
|
||||||
|
int dev;
|
||||||
};
|
};
|
||||||
|
|
||||||
ncclResult_t bootstrapInit(ncclUniqueId* commId, int rank, int nranks, void** commState) {
|
ncclResult_t bootstrapInit(ncclUniqueId* commId, int rank, int nranks, void** commState) {
|
||||||
@ -210,22 +174,39 @@ ncclResult_t bootstrapInit(ncclUniqueId* commId, int rank, int nranks, void** co
|
|||||||
state->rank = rank;
|
state->rank = rank;
|
||||||
state->nranks = nranks;
|
state->nranks = nranks;
|
||||||
*commState = state;
|
*commState = state;
|
||||||
|
void* extBstrapRootListenComm; // comm on which we accept root's connections
|
||||||
|
|
||||||
struct extInfo info;
|
struct extInfo info = { 0 };
|
||||||
info.rank = rank;
|
info.rank = rank;
|
||||||
info.nranks = nranks;
|
info.nranks = nranks;
|
||||||
void* tmpListenComm;
|
void *tmpSendComm, *extBstrapRingListenComm, *tmpRecvComm;
|
||||||
// Pass the remote address to listen via info
|
// Pass the remote address to listen via info
|
||||||
if (idFromEnv) {
|
if (idFromEnv) {
|
||||||
memcpy(&info.extHandle, &id->extHandle, sizeof(ncclNetHandle_t));
|
memcpy(&info.extHandleListenFromRoot, &id->extHandleRoot, sizeof(ncclNetHandle_t));
|
||||||
|
memcpy(&info.extHandleRing, &id->extHandleRoot, sizeof(ncclNetHandle_t));
|
||||||
}
|
}
|
||||||
// listen will return the local address via info ('findSubnetIf' indicates that the net device is unknown)
|
// listen will return the local address via info (specify interface type 'findSubnetIf')
|
||||||
int dev = idFromEnv ? findSubnetIf : 0;
|
state->dev = idFromEnv ? findSubnetIf : 0;
|
||||||
NCCLCHECK(bootstrapListen(dev, &info.extHandle, &tmpListenComm));
|
NCCLCHECK(bootstrapListen(state->dev, &info.extHandleListenFromRoot, &extBstrapRootListenComm));
|
||||||
NCCLCHECK(bootstrapConnect(dev, id->extHandle, &state->extSendComm));
|
NCCLCHECK(bootstrapListen(state->dev, &info.extHandleRing, &extBstrapRingListenComm)); // AllGather Ring
|
||||||
NCCLCHECK(bootstrapSend(state->extSendComm, &info, sizeof(info)));
|
|
||||||
NCCLCHECK(bootstrapAccept(tmpListenComm, &state->extRecvComm));
|
memcpy(&state->extBstrapRootHandle, &id->extHandleRoot, sizeof(ncclNetHandle_t));
|
||||||
NCCLCHECK(bootstrapCloseListen(tmpListenComm));
|
// send info on my listening sockets to root
|
||||||
|
NCCLCHECK(bootstrapConnect(state->dev, id->extHandleRoot, &tmpSendComm));
|
||||||
|
NCCLCHECK(bootstrapSend(tmpSendComm, &info, sizeof(info)));
|
||||||
|
NCCLCHECK(bootstrapCloseSend(tmpSendComm));
|
||||||
|
|
||||||
|
// get info on my "next" rank in the bootstrap ring from root
|
||||||
|
ncclNetHandle_t extHandleNext;
|
||||||
|
NCCLCHECK(bootstrapAccept(extBstrapRootListenComm, &tmpRecvComm));
|
||||||
|
NCCLCHECK(bootstrapRecv(tmpRecvComm, &extHandleNext, sizeof(extHandleNext)));
|
||||||
|
NCCLCHECK(bootstrapCloseRecv(tmpRecvComm));
|
||||||
|
|
||||||
|
NCCLCHECK(bootstrapConnect(state->dev, extHandleNext, &state->extBstrapRingSendComm));
|
||||||
|
// Accept the connect request from the previous rank in the AllGather ring
|
||||||
|
NCCLCHECK(bootstrapAccept(extBstrapRingListenComm, &state->extBstrapRingRecvComm));
|
||||||
|
NCCLCHECK(bootstrapCloseListen(extBstrapRingListenComm));
|
||||||
|
NCCLCHECK(bootstrapCloseListen(extBstrapRootListenComm));
|
||||||
|
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
@ -233,58 +214,34 @@ ncclResult_t bootstrapInit(ncclUniqueId* commId, int rank, int nranks, void** co
|
|||||||
ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
|
ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
|
||||||
struct extState* state = (struct extState*)commState;
|
struct extState* state = (struct extState*)commState;
|
||||||
char* data = (char*)allData;
|
char* data = (char*)allData;
|
||||||
struct bootstrapOp bop;
|
int rank = state->rank;
|
||||||
|
int nranks = state->nranks;
|
||||||
|
|
||||||
bop.op = BOOTSTRAP_ALLGATHER;
|
TRACE(INIT, "rank %d nranks %d size %d", rank, nranks, size);
|
||||||
bop.size = size;
|
|
||||||
|
|
||||||
if (!state->rank) {
|
/* Simple ring based AllGather
|
||||||
NCCLCHECK(bootstrapSend(state->extSendComm, &bop, sizeof(struct bootstrapOp)));
|
* At each step i receive data from (rank-i-1) from left
|
||||||
|
* and send previous step's data from (rank-i) to right
|
||||||
|
*/
|
||||||
|
for (int i=0; i<nranks-1; i++) {
|
||||||
|
int rslice = (rank - i - 1 + nranks) % nranks;
|
||||||
|
int sslice = (rank - i + nranks) % nranks;
|
||||||
|
|
||||||
|
// Send slice to the right
|
||||||
|
NCCLCHECK(bootstrapSend(state->extBstrapRingSendComm, data+sslice*size, size));
|
||||||
|
// Recv slice from the left
|
||||||
|
NCCLCHECK(bootstrapRecv(state->extBstrapRingRecvComm, data+rslice*size, size));
|
||||||
}
|
}
|
||||||
|
|
||||||
NCCLCHECK(bootstrapSend(state->extSendComm, data+state->rank*size, size));
|
TRACE(INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
|
||||||
NCCLCHECK(bootstrapRecv(state->extRecvComm, data, size*state->nranks));
|
|
||||||
|
|
||||||
return ncclSuccess;
|
|
||||||
}
|
|
||||||
|
|
||||||
ncclResult_t bootstrapRingExchange(void* commState, void* prevNextData, int prev, int next, int size) {
|
|
||||||
struct extState* state = (struct extState*)commState;
|
|
||||||
char* mydata = (char*)prevNextData;
|
|
||||||
int prev_offset = prev*2*size+size, next_offset = next*2*size;
|
|
||||||
|
|
||||||
struct bootstrapOp bop;
|
|
||||||
bop.op = BOOTSTRAP_RINGEXCHANGE;
|
|
||||||
bop.size = size;
|
|
||||||
|
|
||||||
if (!state->rank) {
|
|
||||||
NCCLCHECK(bootstrapSend(state->extSendComm, &bop, sizeof(struct bootstrapOp)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send data to root
|
|
||||||
NCCLCHECK(bootstrapSend(state->extSendComm, mydata, 2*size));
|
|
||||||
|
|
||||||
// Receive prev and next data
|
|
||||||
NCCLCHECK(bootstrapSend(state->extSendComm, &prev_offset, sizeof(int)));
|
|
||||||
NCCLCHECK(bootstrapRecv(state->extRecvComm, mydata, size));
|
|
||||||
NCCLCHECK(bootstrapSend(state->extSendComm, &next_offset, sizeof(int)));
|
|
||||||
NCCLCHECK(bootstrapRecv(state->extRecvComm, mydata+size, size));
|
|
||||||
|
|
||||||
|
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
ncclResult_t bootstrapClose(void* commState) {
|
ncclResult_t bootstrapClose(void* commState) {
|
||||||
struct extState* state = (struct extState*)commState;
|
struct extState* state = (struct extState*)commState;
|
||||||
struct bootstrapOp bop;
|
|
||||||
bop.size = -1;
|
|
||||||
|
|
||||||
if (!state->rank) {
|
NCCLCHECK(bootstrapCloseSend(state->extBstrapRingSendComm));
|
||||||
NCCLCHECK(bootstrapSend(state->extSendComm, &bop, sizeof(struct bootstrapOp)));
|
NCCLCHECK(bootstrapCloseRecv(state->extBstrapRingRecvComm));
|
||||||
}
|
|
||||||
|
|
||||||
NCCLCHECK(bootstrapCloseSend(state->extSendComm));
|
|
||||||
NCCLCHECK(bootstrapCloseRecv(state->extRecvComm));
|
|
||||||
|
|
||||||
free(state);
|
free(state);
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
ncclResult_t ncclAllGatherFunc(const void* sendbuff, void* recvbuff, size_t count,
|
ncclResult_t ncclAllGatherFunc(const void* sendbuff, void* recvbuff, size_t count,
|
||||||
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
||||||
size_t nbytes = count*ncclTypeSize(datatype);
|
size_t nbytes = count*ncclTypeSize(datatype);
|
||||||
INFO(COLL,"opCount %lx sendbuff %p recvbuff %p count %zi size %zi datatype %d op %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, nbytes, datatype, op, comm, comm->nRanks, stream);
|
INFO(COLL,"AllGather: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, datatype, op, root, comm, comm->nRanks, stream);
|
||||||
if (comm->nRanks == 1) {
|
if (comm->nRanks == 1) {
|
||||||
if (sendbuff != recvbuff)
|
if (sendbuff != recvbuff)
|
||||||
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
|
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
ncclResult_t ncclAllReduceFunc(const void* sendbuff, void* recvbuff, size_t count,
|
ncclResult_t ncclAllReduceFunc(const void* sendbuff, void* recvbuff, size_t count,
|
||||||
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
||||||
size_t nbytes = count*ncclTypeSize(datatype);
|
size_t nbytes = count*ncclTypeSize(datatype);
|
||||||
INFO(COLL,"opCount %lx sendbuff %p recvbuff %p count %zi size %zi datatype %d op %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, nbytes, datatype, op, comm, comm->nRanks, stream);
|
INFO(COLL,"AllReduce: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, datatype, op, root, comm, comm->nRanks, stream);
|
||||||
if (comm->nRanks == 1) {
|
if (comm->nRanks == 1) {
|
||||||
if (sendbuff != recvbuff)
|
if (sendbuff != recvbuff)
|
||||||
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
|
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
ncclResult_t ncclBroadcastFunc(const void* sendbuff, void* recvbuff, const size_t count,
|
ncclResult_t ncclBroadcastFunc(const void* sendbuff, void* recvbuff, const size_t count,
|
||||||
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
||||||
size_t nbytes = count*ncclTypeSize(datatype);
|
size_t nbytes = count*ncclTypeSize(datatype);
|
||||||
INFO(COLL,"opCount %lx sendbuff %p recvbuff %p count %zi size %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, nbytes, datatype, op, root, comm, comm->nRanks, stream);
|
INFO(COLL,"Broadcast: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, datatype, op, root, comm, comm->nRanks, stream);
|
||||||
if (comm->nRanks == 1) {
|
if (comm->nRanks == 1) {
|
||||||
if (sendbuff != recvbuff)
|
if (sendbuff != recvbuff)
|
||||||
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
|
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
ncclResult_t ncclReduceFunc(const void* sendbuff, void* recvbuff, const size_t count,
|
ncclResult_t ncclReduceFunc(const void* sendbuff, void* recvbuff, const size_t count,
|
||||||
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
||||||
size_t nbytes = count*ncclTypeSize(datatype);
|
size_t nbytes = count*ncclTypeSize(datatype);
|
||||||
INFO(COLL,"opCount %lx sendbuff %p recvbuff %p count %zi size %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, nbytes, datatype, op, root, comm, comm->nRanks, stream);
|
INFO(COLL,"Reduce: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, datatype, op, root, comm, comm->nRanks, stream);
|
||||||
if (comm->nRanks == 1) {
|
if (comm->nRanks == 1) {
|
||||||
if (sendbuff != recvbuff)
|
if (sendbuff != recvbuff)
|
||||||
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
|
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
ncclResult_t ncclReduceScatterFunc(const void* sendbuff, void* recvbuff, size_t count,
|
ncclResult_t ncclReduceScatterFunc(const void* sendbuff, void* recvbuff, size_t count,
|
||||||
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
||||||
size_t nbytes = count*ncclTypeSize(datatype);
|
size_t nbytes = count*ncclTypeSize(datatype);
|
||||||
INFO(COLL,"opCount %lx sendbuff %p recvbuff %p count %zi size %zi datatype %d op %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, nbytes, datatype, op, comm, comm->nRanks, stream);
|
INFO(COLL,"ReduceScatter: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, datatype, op, root, comm, comm->nRanks, stream);
|
||||||
if (comm->nRanks == 1) {
|
if (comm->nRanks == 1) {
|
||||||
if (sendbuff != recvbuff)
|
if (sendbuff != recvbuff)
|
||||||
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
|
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
|
||||||
|
@ -13,6 +13,5 @@ ncclResult_t bootstrapCreateRoot(ncclUniqueId* commId, bool idFromEnv);
|
|||||||
ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out);
|
ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out);
|
||||||
ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commState);
|
ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commState);
|
||||||
ncclResult_t bootstrapAllGather(void* commState, void* allData, int size);
|
ncclResult_t bootstrapAllGather(void* commState, void* allData, int size);
|
||||||
ncclResult_t bootstrapRingExchange(void* commState, void* prevNextData, int prev, int next, int size);
|
|
||||||
ncclResult_t bootstrapClose(void* commState);
|
ncclResult_t bootstrapClose(void* commState);
|
||||||
#endif
|
#endif
|
||||||
|
@ -35,7 +35,8 @@ struct cudaLaunchParams {
|
|||||||
|
|
||||||
// Rings / LL tuning
|
// Rings / LL tuning
|
||||||
#define NCCL_LL_RING_THRESHOLD 8 // Per thread size before we start increasing nrings
|
#define NCCL_LL_RING_THRESHOLD 8 // Per thread size before we start increasing nrings
|
||||||
#define NCCL_THREAD_THRESHOLD 32 // Per thread size before we switch to non-LL
|
#define NCCL_THREAD_THRESHOLD 64 // Per thread size before we switch to non-LL for Volta and above
|
||||||
|
#define NCCL_THREAD_THRESHOLD_PREVOLTA 32 // Per thread size before we switch to non-LL for pre-Volta archs
|
||||||
#define NCCL_LL_MAX_NTHREADS 256
|
#define NCCL_LL_MAX_NTHREADS 256
|
||||||
#define NCCL_LL_MIN_NTHREADS 64
|
#define NCCL_LL_MIN_NTHREADS 64
|
||||||
|
|
||||||
@ -95,8 +96,8 @@ struct ncclConnector {
|
|||||||
#define CUDA_IPC_MIN 2097152UL /* 2MiB - not currently used */
|
#define CUDA_IPC_MIN 2097152UL /* 2MiB - not currently used */
|
||||||
|
|
||||||
#define NCCL_LL_CHUNKS 8
|
#define NCCL_LL_CHUNKS 8
|
||||||
#define NUM_LINES_PER_THREAD 2
|
#define NUM_LINES_PER_THREAD 8
|
||||||
#define NCCL_LL_BUFF_SIZE (NUM_LINES_PER_THREAD*NCCL_LL_MAX_NTHREADS*NCCL_LL_CHUNKS*sizeof(union ncclLLFifoLine)) // 64K
|
#define NCCL_LL_BUFF_SIZE (NUM_LINES_PER_THREAD*NCCL_LL_MAX_NTHREADS*NCCL_LL_CHUNKS*sizeof(union ncclLLFifoLine)) // 256K
|
||||||
#define NCCL_LL_BUFF_LINES (NCCL_LL_BUFF_SIZE / (2*sizeof(uint64_t)))
|
#define NCCL_LL_BUFF_LINES (NCCL_LL_BUFF_SIZE / (2*sizeof(uint64_t)))
|
||||||
#define NCCL_LL_SLICE_LINES (NCCL_LL_BUFF_LINES / NCCL_LL_CHUNKS)
|
#define NCCL_LL_SLICE_LINES (NCCL_LL_BUFF_LINES / NCCL_LL_CHUNKS)
|
||||||
#define NCCL_LL_CLEAN_FREQ 0x10000000
|
#define NCCL_LL_CLEAN_FREQ 0x10000000
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
#include <net/if.h>
|
#include <net/if.h>
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
|
#define MAX_IFS 16
|
||||||
#define MAX_IF_NAME_SIZE 16
|
#define MAX_IF_NAME_SIZE 16
|
||||||
#define SLEEP_INT 1000 // sleep interval in usec
|
#define SLEEP_INT 1000 // sleep interval in usec
|
||||||
#define RETRY_TIMES 2e4 // retry times before reporting a timeout (20 sec)
|
#define RETRY_TIMES 2e4 // retry times before reporting a timeout (20 sec)
|
||||||
@ -40,6 +41,10 @@ static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
|
|||||||
return buf;
|
return buf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline short socketToPort(struct sockaddr *saddr) {
|
||||||
|
return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port);
|
||||||
|
}
|
||||||
|
|
||||||
/* Allow the user to force the IPv4/IPv6 interface selection */
|
/* Allow the user to force the IPv4/IPv6 interface selection */
|
||||||
static inline int envSocketFamily(void) {
|
static inline int envSocketFamily(void) {
|
||||||
int family = -1; // Family selection is not forced, will use first one found
|
int family = -1; // Family selection is not forced, will use first one found
|
||||||
@ -56,9 +61,9 @@ static inline int envSocketFamily(void) {
|
|||||||
|
|
||||||
static int findInterfaces(const char* prefixList, char* names, union socketAddress *addrs, int sock_family, int maxIfNameSize, int maxIfs) {
|
static int findInterfaces(const char* prefixList, char* names, union socketAddress *addrs, int sock_family, int maxIfNameSize, int maxIfs) {
|
||||||
char line[1024];
|
char line[1024];
|
||||||
struct netIf userIfs[maxIfs];
|
struct netIf userIfs[MAX_IFS];
|
||||||
bool searchNot = prefixList && prefixList[0] == '^';
|
bool searchNot = prefixList && prefixList[0] == '^';
|
||||||
int nUserIfs = parseStringList(prefixList, userIfs, maxIfs);
|
int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);
|
||||||
|
|
||||||
int found = 0;
|
int found = 0;
|
||||||
struct ifaddrs *interfaces, *interface;
|
struct ifaddrs *interfaces, *interface;
|
||||||
@ -313,8 +318,11 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
|
|||||||
return ncclSystemError;
|
return ncclSystemError;
|
||||||
}
|
}
|
||||||
|
|
||||||
int opt = 1;
|
if (socketToPort(&localAddr->sa)) {
|
||||||
SYSCHECK(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
|
// Port is forced by env. Make sure we get the port.
|
||||||
|
int opt = 1;
|
||||||
|
SYSCHECK(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
|
||||||
|
}
|
||||||
|
|
||||||
// localAddr port should be 0 (Any port)
|
// localAddr port should be 0 (Any port)
|
||||||
SYSCHECK(bind(sockfd, &localAddr->sa, salen), "bind");
|
SYSCHECK(bind(sockfd, &localAddr->sa, salen), "bind");
|
||||||
|
67
src/init.cu
67
src/init.cu
@ -79,7 +79,19 @@ void initNet() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
NCCL_PARAM(LlThreshold, "LL_THRESHOLD", -2);
|
NCCL_PARAM(LlThreshold, "LL_THRESHOLD", -2);
|
||||||
NCCL_PARAM(ThreadThreshold, "THREAD_THRESHOLD", NCCL_THREAD_THRESHOLD);
|
NCCL_PARAM(ThreadThreshold, "THREAD_THRESHOLD", -2);
|
||||||
|
|
||||||
|
int ncclThreadThreshold(int minCompCap, int multiNode) {
|
||||||
|
int threshold = ncclParamThreadThreshold();
|
||||||
|
if (threshold == -2) { // user has not set this env variable
|
||||||
|
threshold = (minCompCap <= 6) ? NCCL_THREAD_THRESHOLD_PREVOLTA : NCCL_THREAD_THRESHOLD;
|
||||||
|
// multiply by 2 if running on multiple nodes
|
||||||
|
if (multiNode) {
|
||||||
|
threshold *= 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return threshold;
|
||||||
|
}
|
||||||
|
|
||||||
pthread_mutex_t initLock = PTHREAD_MUTEX_INITIALIZER;
|
pthread_mutex_t initLock = PTHREAD_MUTEX_INITIALIZER;
|
||||||
static bool initialized = false;
|
static bool initialized = false;
|
||||||
@ -165,7 +177,6 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
|||||||
cudaGetDevice(&comm->cudaDev);
|
cudaGetDevice(&comm->cudaDev);
|
||||||
comm->doneEvent = doneEvent;
|
comm->doneEvent = doneEvent;
|
||||||
comm->llThreshold = ncclParamLlThreshold();
|
comm->llThreshold = ncclParamLlThreshold();
|
||||||
comm->threadThreshold = ncclParamThreadThreshold();
|
|
||||||
comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false;
|
comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false;
|
||||||
#if __CUDACC_VER_MAJOR__ >= 10 || (__CUDACC_VER_MAJOR__ >= 9 && __CUDACC_VER_MINOR__ >= 2)
|
#if __CUDACC_VER_MAJOR__ >= 10 || (__CUDACC_VER_MAJOR__ >= 9 && __CUDACC_VER_MINOR__ >= 2)
|
||||||
comm->groupCudaStream = ncclParamGroupCudaStream();
|
comm->groupCudaStream = ncclParamGroupCudaStream();
|
||||||
@ -277,7 +288,7 @@ static void swap(void* mem1, void* mem2, int size) {
|
|||||||
|
|
||||||
#define MAXWIDTH 20
|
#define MAXWIDTH 20
|
||||||
#define PREFIXLEN 15
|
#define PREFIXLEN 15
|
||||||
#define STRLENGTH (PREFIXLEN+4*MAXWIDTH)
|
#define STRLENGTH (PREFIXLEN+5*MAXWIDTH)
|
||||||
void dumpMatrix(int* connectMatrix, int nranks) {
|
void dumpMatrix(int* connectMatrix, int nranks) {
|
||||||
char line[STRLENGTH+1];
|
char line[STRLENGTH+1];
|
||||||
line[STRLENGTH] = '\0';
|
line[STRLENGTH] = '\0';
|
||||||
@ -292,6 +303,21 @@ void dumpMatrix(int* connectMatrix, int nranks) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void dumpMatrixTvalue(ncclTvalue_t* connectMatrix, int nranks) {
|
||||||
|
char line[STRLENGTH+1];
|
||||||
|
line[STRLENGTH] = '\0';
|
||||||
|
memset(line, ' ', STRLENGTH);
|
||||||
|
for (int j=0; j<nranks && j<MAXWIDTH; j++) sprintf(4+line+5*j, " %4d", j);
|
||||||
|
INFO(INIT,"%s", line);
|
||||||
|
for (int i=0; i<nranks; i++) {
|
||||||
|
memset(line, ' ', STRLENGTH);
|
||||||
|
sprintf(line, "%3d ", i);
|
||||||
|
for (int j=0; j<nranks && j<MAXWIDTH; j++) sprintf(4+line+5*j, " %4o", (int)connectMatrix[i*nranks+j]);
|
||||||
|
INFO(INIT,"%s", line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void dumpLine(int* values, int nranks, const char* prefix) {
|
void dumpLine(int* values, int nranks, const char* prefix) {
|
||||||
int prefixlen = strlen(prefix);
|
int prefixlen = strlen(prefix);
|
||||||
char line[STRLENGTH+1];
|
char line[STRLENGTH+1];
|
||||||
@ -433,7 +459,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
|||||||
NCCLCHECK(bootstrapAllGather(commState, connectTransport, nranks*(sizeof(int))));
|
NCCLCHECK(bootstrapAllGather(commState, connectTransport, nranks*(sizeof(int))));
|
||||||
NCCLCHECK(bootstrapAllGather(commState, connectValue, nranks*(sizeof(ncclTvalue_t))));
|
NCCLCHECK(bootstrapAllGather(commState, connectValue, nranks*(sizeof(ncclTvalue_t))));
|
||||||
//if (rank == 0) dumpMatrix(connectTransport, nranks);
|
//if (rank == 0) dumpMatrix(connectTransport, nranks);
|
||||||
//if (rank == 0) dumpMatrix(connectValue, nranks);
|
//if (rank == 0) dumpMatrixTvalue(connectValue, nranks);
|
||||||
|
|
||||||
// Get my rings
|
// Get my rings
|
||||||
int nrings;
|
int nrings;
|
||||||
@ -481,15 +507,19 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
|||||||
free(next);
|
free(next);
|
||||||
|
|
||||||
// Connect with prev/next for each ring
|
// Connect with prev/next for each ring
|
||||||
|
struct ncclConnect *connectData;
|
||||||
|
NCCLCHECK(ncclCalloc(&connectData, 2*nranks));
|
||||||
for (int r=0; r<nrings; r++) {
|
for (int r=0; r<nrings; r++) {
|
||||||
int* ringRanks = rings+r*nranks;
|
int* ringRanks = rings+r*nranks;
|
||||||
struct ncclRing *ring = comm->rings+r;
|
struct ncclRing *ring = comm->rings+r;
|
||||||
struct ncclConnect connect[2];
|
NCCLCHECK(setupRing(comm, r, rank, nranks, ringRanks, allInfo, connectData+2*rank));
|
||||||
NCCLCHECK(setupRing(comm, r, rank, nranks, ringRanks, allInfo, connect));
|
int prev_offset = ring->userRanks[nranks-1]*2+1;
|
||||||
NCCLCHECK(bootstrapRingExchange(commState, connect, ring->userRanks[nranks-1], ring->userRanks[1], sizeof(struct ncclConnect)));
|
int next_offset = ring->userRanks[1]*2;
|
||||||
NCCLCHECK(ring->send.transport->send.connect(connect+1, &ring->send));
|
NCCLCHECK(bootstrapAllGather(commState, connectData, sizeof(struct ncclConnect)*2));
|
||||||
NCCLCHECK(ring->recv.transport->recv.connect(connect+0, &ring->recv));
|
NCCLCHECK(ring->send.transport->send.connect(connectData+next_offset, &ring->send));
|
||||||
|
NCCLCHECK(ring->recv.transport->recv.connect(connectData+prev_offset, &ring->recv));
|
||||||
}
|
}
|
||||||
|
free(connectData);
|
||||||
free(rings);
|
free(rings);
|
||||||
free(allInfo);
|
free(allInfo);
|
||||||
|
|
||||||
@ -506,12 +536,15 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
|||||||
|
|
||||||
// Compute intra ranks
|
// Compute intra ranks
|
||||||
int intraRank0 = -1, intraRank = -1, intraRanks = 0;
|
int intraRank0 = -1, intraRank = -1, intraRanks = 0;
|
||||||
|
int multiNode = 0;
|
||||||
for (int r=0; r<nranks; r++) {
|
for (int r=0; r<nranks; r++) {
|
||||||
if ((rankInfos[r].hostHash == rankInfos[rank].hostHash) &&
|
if ((rankInfos[r].hostHash == rankInfos[rank].hostHash) &&
|
||||||
(rankInfos[r].pidHash == rankInfos[rank].pidHash)) {
|
(rankInfos[r].pidHash == rankInfos[rank].pidHash)) {
|
||||||
if (intraRanks == 0) intraRank0 = r;
|
if (intraRanks == 0) intraRank0 = r;
|
||||||
if (r == rank) intraRank = intraRanks;
|
if (r == rank) intraRank = intraRanks;
|
||||||
intraRanks++;
|
intraRanks++;
|
||||||
|
} else if (rankInfos[r].hostHash != rankInfos[rank].hostHash) {
|
||||||
|
multiNode = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TRACE(INIT,"hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
|
TRACE(INIT,"hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
|
||||||
@ -523,6 +556,9 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
|||||||
}
|
}
|
||||||
NCCLCHECK(ncclCommSetIntra(comm, intraRank, intraRanks, rankInfos[intraRank0].comm));
|
NCCLCHECK(ncclCommSetIntra(comm, intraRank, intraRanks, rankInfos[intraRank0].comm));
|
||||||
|
|
||||||
|
// Determine thread threshold across all GPUs
|
||||||
|
comm->threadThreshold = ncclThreadThreshold(minCompCap, multiNode);
|
||||||
|
|
||||||
// Barrier
|
// Barrier
|
||||||
bootstrapClose(commState);
|
bootstrapClose(commState);
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
@ -539,7 +575,7 @@ bool SetCpuAffinity(int cudaDev, nvmlDevice_t* nvmlDevice) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
ncclResult_t ncclCommInitRankSync(ncclComm_t* newcomm, int ndev, ncclUniqueId commId, int myrank) {
|
ncclResult_t ncclCommInitRankSync(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank) {
|
||||||
cpu_set_t affinitySave;
|
cpu_set_t affinitySave;
|
||||||
sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||||
|
|
||||||
@ -553,12 +589,15 @@ ncclResult_t ncclCommInitRankSync(ncclComm_t* newcomm, int ndev, ncclUniqueId co
|
|||||||
SetCpuAffinity(cudaDev, &nvmlDevice);
|
SetCpuAffinity(cudaDev, &nvmlDevice);
|
||||||
ncclResult_t res;
|
ncclResult_t res;
|
||||||
|
|
||||||
NCCLCHECKGOTO(commAlloc(newcomm, ndev, myrank), res, cleanup);
|
NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank), res, cleanup);
|
||||||
NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup);
|
NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup);
|
||||||
NCCLCHECKGOTO(devCommSetup(*newcomm), res, cleanup);
|
NCCLCHECKGOTO(devCommSetup(*newcomm), res, cleanup);
|
||||||
|
|
||||||
sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||||
NCCLCHECKGOTO(wrapNvmlShutdown(), res, cleanup);
|
NCCLCHECKGOTO(wrapNvmlShutdown(), res, cleanup);
|
||||||
|
|
||||||
|
INFO(INIT,"comm %p rank %d nranks %d - COMPLETE", *newcomm, myrank, nranks);
|
||||||
|
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
cleanup:
|
cleanup:
|
||||||
*newcomm = NULL;
|
*newcomm = NULL;
|
||||||
@ -566,7 +605,7 @@ cleanup:
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
NCCL_API(ncclResult_t, ncclCommInitRank, ncclComm_t* newcomm, int ndev, ncclUniqueId commId, int myrank);
|
NCCL_API(ncclResult_t, ncclCommInitRank, ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank);
|
||||||
ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank) {
|
ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank) {
|
||||||
char* env = getenv("NCCL_COMM_ID");
|
char* env = getenv("NCCL_COMM_ID");
|
||||||
if (env && myrank == 0) {
|
if (env && myrank == 0) {
|
||||||
@ -649,9 +688,13 @@ static ncclResult_t initTransportsAll(struct ncclComm** comms, const int* devs,
|
|||||||
free(prevFinal);
|
free(prevFinal);
|
||||||
free(nextFinal);
|
free(nextFinal);
|
||||||
|
|
||||||
|
// Determine thread threshold across all GPUs
|
||||||
|
int threadThreshold = ncclThreadThreshold(minCompCap, 0);
|
||||||
|
|
||||||
for (int rank=0; rank<nranks; rank++) {
|
for (int rank=0; rank<nranks; rank++) {
|
||||||
comms[rank]->nRings = nrings;
|
comms[rank]->nRings = nrings;
|
||||||
comms[rank]->nThreads = nthreads;
|
comms[rank]->nThreads = nthreads;
|
||||||
|
comms[rank]->threadThreshold = threadThreshold;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int r=0; r<nrings; r++) {
|
for (int r=0; r<nrings; r++) {
|
||||||
|
@ -99,7 +99,7 @@ int parseStringList(const char* string, struct netIf* ifList, int maxList) {
|
|||||||
ifC++;
|
ifC++;
|
||||||
}
|
}
|
||||||
ptr++;
|
ptr++;
|
||||||
} while (c);
|
} while (ifNum < maxList && c);
|
||||||
return ifNum;
|
return ifNum;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,7 +23,6 @@ ncclResult_t ncclSocketPtrSupport(int dev, int* supportedTypes) {
|
|||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MAX_IFS 16
|
|
||||||
static char ncclNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
|
static char ncclNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
|
||||||
static union socketAddress ncclNetIfAddrs[MAX_IFS];
|
static union socketAddress ncclNetIfAddrs[MAX_IFS];
|
||||||
static int ncclNetIfs = -1;
|
static int ncclNetIfs = -1;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user