Make NCCL collectives work on communicators with only one rank
parent bd3cf73e6e
commit 7edfc57228
@@ -442,12 +442,17 @@ ncclResult_t ncclAllGatherWithType(const void* sendbuff, void* recvbuff,
   args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
   args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
 
-  if( comm->useRemoteRecv ) {
-    AllGatherKernel<NUM_THREADS, UNROLL_COUNT, true, T>
-        <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-  } else {
-    AllGatherKernel<NUM_THREADS, UNROLL_COUNT, false, T>
-        <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-  }
+  if (comm->nDev == 1) {
+    if (sendbuff != recvbuff)
+      CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  } else {
+    if( comm->useRemoteRecv ) {
+      AllGatherKernel<NUM_THREADS, UNROLL_COUNT, true, T>
+          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+    } else {
+      AllGatherKernel<NUM_THREADS, UNROLL_COUNT, false, T>
+          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+    }
+  }
   return ncclSuccess;
 }
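With a single rank, all-gather is the identity operation: the one contribution already is the gathered result, so the new branch skips the ring kernel and only issues an asynchronous device-to-device copy when the call is out of place. The sketch below mirrors that fallback outside of NCCL; singleRankAllGather is a hypothetical helper written for illustration, not part of the library.

#include <cuda_runtime.h>
#include <stdio.h>

// Hypothetical stand-in for the single-rank path above: when the communicator
// holds one device, "all-gather" reduces to copying the send buffer into the
// receive buffer, and to a no-op for in-place calls.
static cudaError_t singleRankAllGather(const void* sendbuff, void* recvbuff,
                                       size_t bytes, cudaStream_t stream) {
  if (sendbuff == recvbuff)
    return cudaSuccess;                      // in-place: nothing to move
  return cudaMemcpyAsync(recvbuff, sendbuff, bytes,
                         cudaMemcpyDeviceToDevice, stream);
}

int main() {
  const size_t n = 1 << 20;
  float *send, *recv;
  cudaMalloc(&send, n * sizeof(float));
  cudaMalloc(&recv, n * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaError_t err = singleRankAllGather(send, recv, n * sizeof(float), stream);
  cudaStreamSynchronize(stream);
  printf("single-rank all-gather: %s\n", cudaGetErrorString(err));

  cudaStreamDestroy(stream);
  cudaFree(send);
  cudaFree(recv);
  return 0;
}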
@@ -432,12 +432,17 @@ ncclResult_t ncclAllReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
   args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
   args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
 
-  if( comm->useRemoteRecv ) {
-    AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, true, T>
-        <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-  } else {
-    AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, false, T>
-        <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-  }
+  if (comm->nDev == 1) {
+    if (sendbuff != recvbuff)
+      CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  } else {
+    if( comm->useRemoteRecv ) {
+      AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, true, T>
+          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+    } else {
+      AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, false, T>
+          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+    }
+  }
   return ncclSuccess;
 }
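The same single-rank shortcut applies to all-reduce: reducing one contribution with any operator returns that contribution, so a copy (or nothing at all, in place) is the correct result. The usage sketch below shows the kind of call this change enables, assuming the classic ncclCommInitAll / ncclAllReduce entry points with an int element count as in NCCL 1.x; error handling is kept minimal and the buffers are left uninitialized for brevity.

#include <nccl.h>
#include <cuda_runtime.h>
#include <stdio.h>

// Sketch: an all-reduce on a communicator that contains a single GPU.
int main() {
  int dev = 0;
  ncclComm_t comm;
  if (ncclCommInitAll(&comm, 1, &dev) != ncclSuccess) {
    printf("ncclCommInitAll failed\n");
    return 1;
  }

  const int count = 1024;
  float *send, *recv;
  cudaSetDevice(dev);
  cudaMalloc(&send, count * sizeof(float));
  cudaMalloc(&recv, count * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // With nDev == 1 this now degenerates to a device-to-device copy instead of
  // launching the ring kernel.
  ncclResult_t res = ncclAllReduce(send, recv, count, ncclFloat, ncclSum, comm, stream);
  cudaStreamSynchronize(stream);
  printf("single-rank ncclAllReduce returned %d\n", (int)res);

  cudaStreamDestroy(stream);
  cudaFree(send);
  cudaFree(recv);
  ncclCommDestroy(comm);
  return 0;
}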
@@ -348,27 +348,29 @@ ncclResult_t ncclBcastWithType(void* buff, const int count, const int root,
   args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
   args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
 
-  if (comm->useRemoteRecv) {
-    if (index == (rootId + comm->nDev - 1) % comm->nDev) {
-      BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, END, T>
-          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-    } else if (index == rootId) {
-      BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, ROOT, T>
-          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-    } else {
-      BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, MIDDLE, T>
-          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-    }
-  } else {
-    if (index == (rootId + comm->nDev - 1) % comm->nDev) {
-      BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, END, T>
-          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-    } else if (index == rootId) {
-      BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, ROOT, T>
-          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-    } else {
-      BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, MIDDLE, T>
-          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-    }
-  }
+  if (comm->nDev != 1) {
+    if (comm->useRemoteRecv) {
+      if (index == (rootId + comm->nDev - 1) % comm->nDev) {
+        BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, END, T>
+            <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+      } else if (index == rootId) {
+        BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, ROOT, T>
+            <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+      } else {
+        BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, MIDDLE, T>
+            <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+      }
+    } else {
+      if (index == (rootId + comm->nDev - 1) % comm->nDev) {
+        BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, END, T>
+            <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+      } else if (index == rootId) {
+        BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, ROOT, T>
+            <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+      } else {
+        BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, MIDDLE, T>
+            <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+      }
+    }
+  }
   return ncclSuccess;
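Broadcast gets a guard rather than a copy because ncclBcast works in place on buff: with one rank, the root already holds the final data and there is nothing to move. The host-side snippet below is written only for illustration (role is not an NCCL function); it shows how the END / ROOT / MIDDLE selection in the original code behaves, and that with nDev == 1 the lone index satisfies the END test before the ROOT test, which is presumably why the whole block is now skipped instead of special-cased.

#include <stdio.h>

// Which BroadcastKernel variant would each rank index launch, for a ring of
// nDev ranks rooted at rootId? Mirrors the if/else chain in ncclBcastWithType.
static const char* role(int index, int rootId, int nDev) {
  if (index == (rootId + nDev - 1) % nDev) return "END";
  if (index == rootId)                     return "ROOT";
  return "MIDDLE";
}

int main() {
  for (int index = 0; index < 4; ++index)
    printf("nDev=4 root=0 index=%d -> %s\n", index, role(index, 0, 4));
  // With a single rank, the one index is root and end of the ring at once,
  // and the END test matches first.
  printf("nDev=1 root=0 index=0 -> %s\n", role(0, 0, 1));
  return 0;
}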
@@ -336,15 +336,20 @@ ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
   args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
   args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
 
-  if (index == (rootId + 1) % comm->nDev) {
-    ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, BEGIN, T>
-        <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-  } else if (index == rootId) {
-    ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, END, T>
-        <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-  } else {
-    ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, MIDDLE, T>
-        <<<1, NUM_THREADS + 1, 0, stream>>>(args);
-  }
+  if (comm->nDev == 1) {
+    if (sendbuff != recvbuff)
+      CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  } else {
+    if (index == (rootId + 1) % comm->nDev) {
+      ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, BEGIN, T>
+          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+    } else if (index == rootId) {
+      ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, END, T>
+          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+    } else {
+      ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, MIDDLE, T>
+          <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+    }
+  }
   return ncclSuccess;
 }
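For reduce, a single rank is both the beginning and the end of the chain, and the reduced value is simply the send buffer, so the out-of-place case again becomes a copy. The standalone program below is a minimal check of that fallback's semantics (a sketch, not NCCL code): after the device-to-device copy, the receive buffer must equal the input.

#include <cuda_runtime.h>
#include <stdio.h>

int main() {
  const int count = 8;
  float host[count], back[count];
  for (int i = 0; i < count; ++i) host[i] = (float)i;

  float *send, *recv;
  cudaMalloc(&send, count * sizeof(float));
  cudaMalloc(&recv, count * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaMemcpyAsync(send, host, count * sizeof(float), cudaMemcpyHostToDevice, stream);
  // The nDev == 1 path: no ReduceKernel, just a copy when send != recv.
  cudaMemcpyAsync(recv, send, count * sizeof(float), cudaMemcpyDeviceToDevice, stream);
  cudaMemcpyAsync(back, recv, count * sizeof(float), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);

  int ok = 1;
  for (int i = 0; i < count; ++i) ok &= (back[i] == host[i]);
  printf("single-rank reduce fallback %s\n", ok ? "matches input" : "MISMATCH");

  cudaStreamDestroy(stream);
  cudaFree(send);
  cudaFree(recv);
  return ok ? 0 : 1;
}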
@@ -444,8 +444,13 @@ ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff,
   args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
   args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
 
-  ReduceScatterKernel<NUM_THREADS, UNROLL_COUNT, FUNC, T>
-      <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+  if (comm->nDev == 1) {
+    if (sendbuff != recvbuff)
+      CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, recvcount*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  } else {
+    ReduceScatterKernel<NUM_THREADS, UNROLL_COUNT, FUNC, T>
+        <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+  }
   return ncclSuccess;
 }
 
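Reduce-scatter differs from the other fallbacks only in the byte count: each rank keeps just its recvcount-element share of the reduced buffer, so the copy above moves recvcount*sizeof(T) bytes rather than the full input. The short host program below, an illustration using example sizes rather than anything from NCCL, shows that the per-rank share equals the whole buffer exactly when nDev == 1.

#include <stdio.h>

int main() {
  const size_t recvcount = 1 << 20;   // per-rank element count passed to reduce-scatter
  const size_t elem = sizeof(float);  // stand-in for sizeof(T)
  for (int nDev = 1; nDev <= 4; nDev *= 2) {
    size_t sendBytes = recvcount * nDev * elem;  // each rank supplies recvcount*nDev elements
    size_t recvBytes = recvcount * elem;         // each rank receives recvcount elements
    printf("nDev=%d: send %zu bytes, single-rank-style copy would move %zu bytes\n",
           nDev, sendBytes, recvBytes);
  }
  return 0;
}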