Make NCCL collectives work on communicators with only one rank
parent bd3cf73e6e · commit 7edfc57228
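
With one rank there is no peer to exchange data with, so each collective degenerates to at most a local device-to-device copy (and broadcast to a no-op). The hunks below add that fast path to every collective. As a quick illustration of what this enables, here is a minimal host-side sketch, not part of this commit, that creates a one-rank communicator and runs an in-place allreduce on it; it assumes the NCCL 1.x-era signatures (int counts, ncclCommInitAll taking an int* device list):

#include <nccl.h>
#include <cuda_runtime.h>

int main() {
  int dev = 0;
  ncclComm_t comm;
  // One-rank communicator: before this commit the collectives below had no
  // peer to hand chunks to; now they fall back to a local copy or a no-op.
  if (ncclCommInitAll(&comm, 1, &dev) != ncclSuccess) return 1;

  cudaSetDevice(dev);
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  float* buf;
  cudaMalloc(&buf, 1024 * sizeof(float));

  // In-place allreduce on a single rank: the result is the input itself,
  // so with this change no ring kernel is launched and no copy is issued.
  ncclAllReduce(buf, buf, 1024, ncclFloat, ncclSum, comm, stream);
  cudaStreamSynchronize(stream);

  cudaFree(buf);
  cudaStreamDestroy(stream);
  ncclCommDestroy(comm);
  return 0;
}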
@@ -442,6 +442,10 @@ ncclResult_t ncclAllGatherWithType(const void* sendbuff, void* recvbuff,
   args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
   args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
 
+  if (comm->nDev == 1) {
+    if (sendbuff != recvbuff)
+      CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  } else {
   if( comm->useRemoteRecv ) {
     AllGatherKernel<NUM_THREADS, UNROLL_COUNT, true, T>
         <<<1, NUM_THREADS + 1, 0, stream>>>(args);
@@ -449,6 +453,7 @@ ncclResult_t ncclAllGatherWithType(const void* sendbuff, void* recvbuff,
     AllGatherKernel<NUM_THREADS, UNROLL_COUNT, false, T>
         <<<1, NUM_THREADS + 1, 0, stream>>>(args);
   }
+  }
   return ncclSuccess;
 }
 
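The branch added above is the pattern repeated for allreduce, reduce, and reduce-scatter below: skip the ring kernel entirely and issue a cudaMemcpyAsync only for out-of-place calls, since with sendbuff == recvbuff the single rank's result already sits in the output buffer. A condensed sketch of that guard as a standalone (hypothetical) helper:

#include <cuda_runtime.h>

// Hypothetical helper mirroring the single-rank branch above; returns true
// when the call was handled and the ring kernel can be skipped. Error
// handling (CUDACHECK in the real code) is elided here.
template <typename T>
bool singleRankFastPath(int nDev, const void* sendbuff, void* recvbuff,
                        int count, cudaStream_t stream) {
  if (nDev != 1)
    return false;                 // multi-rank: launch the ring kernel as before
  if (sendbuff != recvbuff)       // in-place calls need no work at all
    cudaMemcpyAsync(recvbuff, sendbuff, count * sizeof(T),
                    cudaMemcpyDeviceToDevice, stream);
  return true;
}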
@@ -432,6 +432,10 @@ ncclResult_t ncclAllReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
   args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
   args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
 
+  if (comm->nDev == 1) {
+    if (sendbuff != recvbuff)
+      CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  } else {
   if( comm->useRemoteRecv ) {
     AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, true, T>
         <<<1, NUM_THREADS + 1, 0, stream>>>(args);
@@ -439,6 +443,7 @@ ncclResult_t ncclAllReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
     AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, false, T>
         <<<1, NUM_THREADS + 1, 0, stream>>>(args);
   }
+  }
   return ncclSuccess;
 }
 
@@ -348,6 +348,7 @@ ncclResult_t ncclBcastWithType(void* buff, const int count, const int root,
   args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
   args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
 
+  if (comm->nDev != 1) {
   if (comm->useRemoteRecv) {
     if (index == (rootId + comm->nDev - 1) % comm->nDev) {
       BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, END, T>
@@ -371,6 +372,7 @@ ncclResult_t ncclBcastWithType(void* buff, const int count, const int root,
         <<<1, NUM_THREADS + 1, 0, stream>>>(args);
     }
   }
+  }
   return ncclSuccess;
 }
 
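Broadcast is the one collective that needs no copy at all: with a single rank the root is the only rank and buff already holds the data, so the entire dispatch is wrapped in if (comm->nDev != 1) and the call falls straight through to return ncclSuccess. A sketch of the user-visible effect, again assuming the NCCL 1.x-era ncclBcast signature:

#include <nccl.h>
#include <cuda_runtime.h>

// Sketch: on a one-rank communicator this now returns ncclSuccess immediately
// and leaves buf untouched, with no kernel launch and no copy.
ncclResult_t singleRankBcast(ncclComm_t comm, float* buf, int count,
                             cudaStream_t stream) {
  return ncclBcast(buf, count, ncclFloat, /*root=*/0, comm, stream);
}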
@@ -336,6 +336,10 @@ ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
   args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
   args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
 
+  if (comm->nDev == 1) {
+    if (sendbuff != recvbuff)
+      CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  } else {
   if (index == (rootId + 1) % comm->nDev) {
     ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, BEGIN, T>
         <<<1, NUM_THREADS + 1, 0, stream>>>(args);
@@ -346,6 +350,7 @@ ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
     ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, MIDDLE, T>
         <<<1, NUM_THREADS + 1, 0, stream>>>(args);
   }
+  }
   return ncclSuccess;
 }
 
@@ -444,8 +444,13 @@ ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff,
   args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
   args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
 
+  if (comm->nDev == 1) {
+    if (sendbuff != recvbuff)
+      CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, recvcount*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  } else {
   ReduceScatterKernel<NUM_THREADS, UNROLL_COUNT, FUNC, T>
       <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+  }
   return ncclSuccess;
 }
 
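Reduce-scatter is the only hunk where the copy length differs: each rank receives its recvcount-element share of the reduced result, and with one rank that share is the entire (trivially reduced) input, hence recvcount*sizeof(T) in the memcpy above. A sketch of the single-rank call, assuming the NCCL 1.x-era signature in which sendbuff holds nDev * recvcount elements:

#include <nccl.h>
#include <cuda_runtime.h>

// Sketch: with nDev == 1, sendbuff holds nDev * recvcount == recvcount
// elements and rank 0's share is the whole buffer, so the call reduces to the
// device-to-device copy added above (or to nothing when done in place).
ncclResult_t singleRankReduceScatter(ncclComm_t comm, const float* sendbuff,
                                     float* recvbuff, int recvcount,
                                     cudaStream_t stream) {
  return ncclReduceScatter(sendbuff, recvbuff, recvcount, ncclFloat, ncclSum,
                           comm, stream);
}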