Make NCCL collectives work on communicators with only one rank

This commit is contained in:
Sylvain Jeaugey 2016-06-06 14:35:00 -07:00
parent bd3cf73e6e
commit 7edfc57228
5 changed files with 61 additions and 39 deletions

View File

@ -442,12 +442,17 @@ ncclResult_t ncclAllGatherWithType(const void* sendbuff, void* recvbuff,
args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
if( comm->useRemoteRecv ) { if (comm->nDev == 1) {
AllGatherKernel<NUM_THREADS, UNROLL_COUNT, true, T> if (sendbuff != recvbuff)
<<<1, NUM_THREADS + 1, 0, stream>>>(args); CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else { } else {
AllGatherKernel<NUM_THREADS, UNROLL_COUNT, false, T> if( comm->useRemoteRecv ) {
<<<1, NUM_THREADS + 1, 0, stream>>>(args); AllGatherKernel<NUM_THREADS, UNROLL_COUNT, true, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else {
AllGatherKernel<NUM_THREADS, UNROLL_COUNT, false, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
}
} }
return ncclSuccess; return ncclSuccess;
} }

View File

@ -432,12 +432,17 @@ ncclResult_t ncclAllReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
if( comm->useRemoteRecv ) { if (comm->nDev == 1) {
AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, true, T> if (sendbuff != recvbuff)
<<<1, NUM_THREADS + 1, 0, stream>>>(args); CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else { } else {
AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, false, T> if( comm->useRemoteRecv ) {
<<<1, NUM_THREADS + 1, 0, stream>>>(args); AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, true, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else {
AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, false, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
}
} }
return ncclSuccess; return ncclSuccess;
} }

View File

@ -348,27 +348,29 @@ ncclResult_t ncclBcastWithType(void* buff, const int count, const int root,
args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
if (comm->useRemoteRecv) { if (comm->nDev != 1) {
if (index == (rootId + comm->nDev - 1) % comm->nDev) { if (comm->useRemoteRecv) {
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, END, T> if (index == (rootId + comm->nDev - 1) % comm->nDev) {
<<<1, NUM_THREADS + 1, 0, stream>>>(args); BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, END, T>
} else if (index == rootId) { <<<1, NUM_THREADS + 1, 0, stream>>>(args);
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, ROOT, T> } else if (index == rootId) {
<<<1, NUM_THREADS + 1, 0, stream>>>(args); BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, ROOT, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else {
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, MIDDLE, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
}
} else { } else {
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, MIDDLE, T> if (index == (rootId + comm->nDev - 1) % comm->nDev) {
<<<1, NUM_THREADS + 1, 0, stream>>>(args); BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, END, T>
} <<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else { } else if (index == rootId) {
if (index == (rootId + comm->nDev - 1) % comm->nDev) { BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, ROOT, T>
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, END, T> <<<1, NUM_THREADS + 1, 0, stream>>>(args);
<<<1, NUM_THREADS + 1, 0, stream>>>(args); } else {
} else if (index == rootId) { BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, MIDDLE, T>
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, ROOT, T> <<<1, NUM_THREADS + 1, 0, stream>>>(args);
<<<1, NUM_THREADS + 1, 0, stream>>>(args); }
} else {
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, MIDDLE, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} }
} }
return ncclSuccess; return ncclSuccess;

View File

@ -336,15 +336,20 @@ ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
if (index == (rootId + 1) % comm->nDev) { if (comm->nDev == 1) {
ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, BEGIN, T> if (sendbuff != recvbuff)
<<<1, NUM_THREADS + 1, 0, stream>>>(args); CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else if (index == rootId) {
ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, END, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else { } else {
ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, MIDDLE, T> if (index == (rootId + 1) % comm->nDev) {
<<<1, NUM_THREADS + 1, 0, stream>>>(args); ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, BEGIN, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else if (index == rootId) {
ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, END, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else {
ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, MIDDLE, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
}
} }
return ncclSuccess; return ncclSuccess;
} }

View File

@ -444,8 +444,13 @@ ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff,
args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
ReduceScatterKernel<NUM_THREADS, UNROLL_COUNT, FUNC, T> if (comm->nDev == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, recvcount*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else {
ReduceScatterKernel<NUM_THREADS, UNROLL_COUNT, FUNC, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args); <<<1, NUM_THREADS + 1, 0, stream>>>(args);
}
return ncclSuccess; return ncclSuccess;
} }