From d42cdb72c7c8ad4380ccd0f0f5a6312cbfe83c76 Mon Sep 17 00:00:00 2001 From: Sylvain Jeaugey Date: Mon, 5 Dec 2022 06:46:08 -0800 Subject: [PATCH] Fix Collnet when sliceSteps>1. Fix chunkSteps/sliceSteps setting for collnet. Fix shared buffer organization to account for the right buffer width while still providing contiguous buffers for a set of consecutive channels. --- src/enqueue.cc | 6 +++--- src/transport/coll_net.cc | 20 ++++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/enqueue.cc b/src/enqueue.cc index 0f68614..0cb4fc6 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -1274,7 +1274,8 @@ static ncclResult_t getStepInfo(struct ncclInfo* info) { } else if (info->protocol == NCCL_PROTO_LL128) { info->chunkSteps = info->sliceSteps = ncclParamLL128ChunkSteps(); } else { /* SIMPLE */ - if (info->algorithm == NCCL_ALGO_TREE || info->coll == ncclFuncBroadcast || info->coll == ncclFuncReduce) { + if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN || info->algorithm == NCCL_ALGO_COLLNET_DIRECT || + info->algorithm == NCCL_ALGO_TREE || info->coll == ncclFuncBroadcast || info->coll == ncclFuncReduce) { info->chunkSteps = info->sliceSteps = ncclParamPipelineChunkSteps(); } else { info->chunkSteps = ncclParamRingChunkSteps(); @@ -1360,8 +1361,7 @@ comp_next: // Set direct direction for broadcast-gather (read or write) work->direct = (info->nBytes / info->nChannels <= 1024*1024) ? NCCL_DIRECT_WRITE : NCCL_DIRECT_READ; } else if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN) { - stepSize = info->comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS; - chunkSize = std::min(256*1024, stepSize*chunkSteps); + chunkSize = std::min(256*1024, chunkSize); while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*64 && chunkSize > 131072) chunkSize /= 2; while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*8 && chunkSize > 65536) chunkSize /= 2; while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2; diff --git a/src/transport/coll_net.cc b/src/transport/coll_net.cc index de10f2f..0595432 100644 --- a/src/transport/coll_net.cc +++ b/src/transport/coll_net.cc @@ -367,10 +367,14 @@ static ncclResult_t sharedBuffersInit(struct ncclComm* comm, int cuda, char** gp return ncclSuccess; } -static ncclResult_t sharedBuffersGet(struct ncclComm* comm, int type, int slot, int channel, int* offset) { +// Allocate buffers between channels, so that consecutive channels have contiguous buffers. +// slot is going to be a multiple of sliceSteps and the buffer per channel needs to be +// large enough for sliceSteps. +static ncclResult_t sharedBuffersGet(struct ncclComm* comm, int type, int slot, int channel, int* offset, int sliceSteps) { // Use different pools for different channels and also separate send/recv. - int slotSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS; - int globalSlot = (type*NCCL_STEPS+slot)*comm->nChannels+channel; + int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS; + int slotSize = stepSize*sliceSteps; + int globalSlot = ((type*NCCL_STEPS+slot)/sliceSteps)*comm->nChannels+channel; *offset = slotSize * globalSlot; return ncclSuccess; } @@ -629,7 +633,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg int buffSlot = (sub->base+sub->posted)%NCCL_STEPS; int sharedBuffSlot = sub->posted%NCCL_STEPS; int offset; - NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset)); + NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset, args->sliceSteps)); resources->recvMem->offsFifo[buffSlot] = offset + s*args->chunkSize; __sync_synchronize(); volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head; @@ -650,7 +654,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg int ready = 1; if (s == 0) { int offset; - NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset)); + NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset, args->sliceSteps)); args->sharedBuff[sharedBuffSlot] = localBuff + offset; args->sharedSize[sharedBuffSlot] = args->chunkSize; } @@ -742,7 +746,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg int sharedBuffSlot = sub->posted%NCCL_STEPS; int startChannel = group*COLLNET_GROUP_NSUBS; int offset; - NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset)); + NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset, args->sliceSteps)); reqFifo[group][buffSlot].recvBuff = localBuff + offset; TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff); sub->posted += args->sliceSteps; @@ -773,7 +777,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg } else { int startChannel = group*COLLNET_GROUP_NSUBS; int offset; - NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset)); + NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset, args->sliceSteps)); NCCLCHECK(collNetIflush(comm, resources->collNetComm, localBuff + offset, totalSize, mhandle, sub->requests+buffSlot)); } } else { @@ -802,7 +806,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg int sharedBuffSlot = sub->transmitted%NCCL_STEPS; int startChannel = group*COLLNET_GROUP_NSUBS; int offset; - NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset)); + NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset, args->sliceSteps)); volatile int* offsFifo = (volatile int*)resources->recvMem->offsFifo; offsFifo[buffSlot] = offset + (s%COLLNET_GROUP_NSUBS)*args->chunkSize; __sync_synchronize();