diff --git a/src/all_gather.cu b/src/all_gather.cu index f0948eb..515059e 100644 --- a/src/all_gather.cu +++ b/src/all_gather.cu @@ -442,12 +442,17 @@ ncclResult_t ncclAllGatherWithType(const void* sendbuff, void* recvbuff, args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; - if( comm->useRemoteRecv ) { - AllGatherKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (comm->nDev == 1) { + if (sendbuff != recvbuff) + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream)); } else { - AllGatherKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if( comm->useRemoteRecv ) { + AllGatherKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else { + AllGatherKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } } return ncclSuccess; } diff --git a/src/all_reduce.cu b/src/all_reduce.cu index 54b046c..eb536c6 100644 --- a/src/all_reduce.cu +++ b/src/all_reduce.cu @@ -432,12 +432,17 @@ ncclResult_t ncclAllReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff, args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; - if( comm->useRemoteRecv ) { - AllReduceKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (comm->nDev == 1) { + if (sendbuff != recvbuff) + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream)); } else { - AllReduceKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if( comm->useRemoteRecv ) { + AllReduceKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else { + AllReduceKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } } return ncclSuccess; } diff --git a/src/broadcast.cu b/src/broadcast.cu index 5053cc3..0b4d152 100644 --- a/src/broadcast.cu +++ b/src/broadcast.cu @@ -348,27 +348,29 @@ ncclResult_t ncclBcastWithType(void* buff, const int count, const int root, args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; - if (comm->useRemoteRecv) { - if (index == (rootId + comm->nDev - 1) % comm->nDev) { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); - } else if (index == rootId) { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (comm->nDev != 1) { + if (comm->useRemoteRecv) { + if (index == (rootId + comm->nDev - 1) % comm->nDev) { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else if (index == rootId) { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } } else { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); - } - } else { - if (index == (rootId + comm->nDev - 1) % comm->nDev) { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); - } else if (index == rootId) { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); - } else { - BroadcastKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (index == (rootId + comm->nDev - 1) % comm->nDev) { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else if (index == rootId) { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else { + BroadcastKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } } } return ncclSuccess; diff --git a/src/reduce.cu b/src/reduce.cu index 6ef38b9..486ba78 100644 --- a/src/reduce.cu +++ b/src/reduce.cu @@ -336,15 +336,20 @@ ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff, args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; - if (index == (rootId + 1) % comm->nDev) { - ReduceKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); - } else if (index == rootId) { - ReduceKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (comm->nDev == 1) { + if (sendbuff != recvbuff) + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream)); } else { - ReduceKernel - <<<1, NUM_THREADS + 1, 0, stream>>>(args); + if (index == (rootId + 1) % comm->nDev) { + ReduceKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else if (index == rootId) { + ReduceKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } else { + ReduceKernel + <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } } return ncclSuccess; } diff --git a/src/reduce_scatter.cu b/src/reduce_scatter.cu index 797cfd8..5b67b8e 100644 --- a/src/reduce_scatter.cu +++ b/src/reduce_scatter.cu @@ -444,8 +444,13 @@ ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff, args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1; args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1; - ReduceScatterKernel + if (comm->nDev == 1) { + if (sendbuff != recvbuff) + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, recvcount*sizeof(T), cudaMemcpyDeviceToDevice, stream)); + } else { + ReduceScatterKernel <<<1, NUM_THREADS + 1, 0, stream>>>(args); + } return ncclSuccess; }