Fix Collnet when sliceSteps>1.

Fix chunkSteps/sliceSteps setting for collnet.
Fix shared buffer organization to account for the right buffer
width while still providing contiguous buffers for a set of
consecutive channels.
This commit is contained in:
Sylvain Jeaugey 2022-12-05 06:46:08 -08:00
parent a6c8f5e0c2
commit d42cdb72c7
2 changed files with 15 additions and 11 deletions

View File

@ -1274,7 +1274,8 @@ static ncclResult_t getStepInfo(struct ncclInfo* info) {
} else if (info->protocol == NCCL_PROTO_LL128) {
info->chunkSteps = info->sliceSteps = ncclParamLL128ChunkSteps();
} else { /* SIMPLE */
if (info->algorithm == NCCL_ALGO_TREE || info->coll == ncclFuncBroadcast || info->coll == ncclFuncReduce) {
if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN || info->algorithm == NCCL_ALGO_COLLNET_DIRECT ||
info->algorithm == NCCL_ALGO_TREE || info->coll == ncclFuncBroadcast || info->coll == ncclFuncReduce) {
info->chunkSteps = info->sliceSteps = ncclParamPipelineChunkSteps();
} else {
info->chunkSteps = ncclParamRingChunkSteps();
@ -1360,8 +1361,7 @@ comp_next:
// Set direct direction for broadcast-gather (read or write)
work->direct = (info->nBytes / info->nChannels <= 1024*1024) ? NCCL_DIRECT_WRITE : NCCL_DIRECT_READ;
} else if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN) {
stepSize = info->comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS;
chunkSize = std::min(256*1024, stepSize*chunkSteps);
chunkSize = std::min(256*1024, chunkSize);
while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*64 && chunkSize > 131072) chunkSize /= 2;
while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*8 && chunkSize > 65536) chunkSize /= 2;
while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2;

View File

@ -367,10 +367,14 @@ static ncclResult_t sharedBuffersInit(struct ncclComm* comm, int cuda, char** gp
return ncclSuccess;
}
static ncclResult_t sharedBuffersGet(struct ncclComm* comm, int type, int slot, int channel, int* offset) {
// Allocate buffers between channels, so that consecutive channels have contiguous buffers.
// slot is going to be a multiple of sliceSteps and the buffer per channel needs to be
// large enough for sliceSteps.
static ncclResult_t sharedBuffersGet(struct ncclComm* comm, int type, int slot, int channel, int* offset, int sliceSteps) {
// Use different pools for different channels and also separate send/recv.
int slotSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS;
int globalSlot = (type*NCCL_STEPS+slot)*comm->nChannels+channel;
int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS;
int slotSize = stepSize*sliceSteps;
int globalSlot = ((type*NCCL_STEPS+slot)/sliceSteps)*comm->nChannels+channel;
*offset = slotSize * globalSlot;
return ncclSuccess;
}
@ -629,7 +633,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int buffSlot = (sub->base+sub->posted)%NCCL_STEPS;
int sharedBuffSlot = sub->posted%NCCL_STEPS;
int offset;
NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset));
NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset, args->sliceSteps));
resources->recvMem->offsFifo[buffSlot] = offset + s*args->chunkSize;
__sync_synchronize();
volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head;
@ -650,7 +654,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int ready = 1;
if (s == 0) {
int offset;
NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset));
NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset, args->sliceSteps));
args->sharedBuff[sharedBuffSlot] = localBuff + offset;
args->sharedSize[sharedBuffSlot] = args->chunkSize;
}
@ -742,7 +746,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int sharedBuffSlot = sub->posted%NCCL_STEPS;
int startChannel = group*COLLNET_GROUP_NSUBS;
int offset;
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset));
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset, args->sliceSteps));
reqFifo[group][buffSlot].recvBuff = localBuff + offset;
TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff);
sub->posted += args->sliceSteps;
@ -773,7 +777,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
} else {
int startChannel = group*COLLNET_GROUP_NSUBS;
int offset;
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset));
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset, args->sliceSteps));
NCCLCHECK(collNetIflush(comm, resources->collNetComm, localBuff + offset, totalSize, mhandle, sub->requests+buffSlot));
}
} else {
@ -802,7 +806,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
int sharedBuffSlot = sub->transmitted%NCCL_STEPS;
int startChannel = group*COLLNET_GROUP_NSUBS;
int offset;
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset));
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset, args->sliceSteps));
volatile int* offsFifo = (volatile int*)resources->recvMem->offsFifo;
offsFifo[buffSlot] = offset + (s%COLLNET_GROUP_NSUBS)*args->chunkSize;
__sync_synchronize();