From 7edfc57228efbf160d8ed7c187e78b25fbaf8ea6 Mon Sep 17 00:00:00 2001 From: Sylvain Jeaugey Date: Mon, 6 Jun 2016 14:35:00 -0700 Subject: [PATCH] Make NCCL collectives work on communicators with only one rank --- src/all_gather.cu | 15 ++++++++++----- src/all_reduce.cu | 15 ++++++++++----- src/broadcast.cu | 42 ++++++++++++++++++++++-------------------- src/reduce.cu | 21 +++++++++++++-------- src/reduce_scatter.cu | 7 ++++++- 5 files changed, 61 insertions(+), 39 deletions(-) diff --git a/src/all_gather.cu b/src/all_gather.cu index f0948eb..515059e 100644 --- a/src/all_gather.cu +++ b/src/all_gather.cu @@ -442,12 +442,17 @@ ncclResult_t ncclAllGatherWithType(const void* sendbuff, void* recvbuff, args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; - if( comm->useRemoteRecv ) { - AllGatherKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (comm->nDev == 1) { + if (sendbuff != recvbuff) + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream)); } else { - AllGatherKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if( comm->useRemoteRecv ) { + AllGatherKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else { + AllGatherKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } } return ncclSuccess; } diff --git a/src/all_reduce.cu b/src/all_reduce.cu index 54b046c..eb536c6 100644 --- a/src/all_reduce.cu +++ b/src/all_reduce.cu @@ -432,12 +432,17 @@ ncclResult_t ncclAllReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff, args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; - if( comm->useRemoteRecv ) { - AllReduceKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (comm->nDev == 1) { + if (sendbuff != recvbuff) + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream)); } else { - AllReduceKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if( comm->useRemoteRecv ) { + AllReduceKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else { + AllReduceKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } } return ncclSuccess; } diff --git a/src/broadcast.cu b/src/broadcast.cu index 5053cc3..0b4d152 100644 --- a/src/broadcast.cu +++ b/src/broadcast.cu @@ -348,27 +348,29 @@ ncclResult_t ncclBcastWithType(void* buff, const int count, const int root, args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; - if (comm->useRemoteRecv) { - if (index == (rootId + comm->nDev - 1) % comm->nDev) { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); - } else if (index == rootId) { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (comm->nDev != 1) { + if (comm->useRemoteRecv) { + if (index == (rootId + comm->nDev - 1) % comm->nDev) { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else if (index == rootId) { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } } else { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); - } - } else { - if (index == (rootId + comm->nDev - 1) % comm->nDev) { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); - } else if (index == rootId) { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); - } else { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (index == (rootId + comm->nDev - 1) % comm->nDev) { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else if (index == rootId) { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } } } return ncclSuccess; diff --git a/src/reduce.cu b/src/reduce.cu index 6ef38b9..486ba78 100644 --- a/src/reduce.cu +++ b/src/reduce.cu @@ -336,15 +336,20 @@ ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff, args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; - if (index == (rootId + 1) % comm->nDev) { - ReduceKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); - } else if (index == rootId) { - ReduceKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (comm->nDev == 1) { + if (sendbuff != recvbuff) + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream)); } else { - ReduceKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (index == (rootId + 1) % comm->nDev) { + ReduceKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else if (index == rootId) { + ReduceKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else { + ReduceKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } } return ncclSuccess; } diff --git a/src/reduce_scatter.cu b/src/reduce_scatter.cu index 797cfd8..5b67b8e 100644 --- a/src/reduce_scatter.cu +++ b/src/reduce_scatter.cu @@ -444,8 +444,13 @@ ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff, args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; - ReduceScatterKernel + if (comm->nDev == 1) { + if (sendbuff != recvbuff) + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, recvcount*sizeof(T), cudaMemcpyDeviceToDevice, stream)); + } else { + ReduceScatterKernel <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } return ncclSuccess; }