Fix CollNet when sliceSteps > 1.
Fix the chunkSteps/sliceSteps setting for CollNet. Fix the shared buffer organization so that each slot has the right buffer width (sliceSteps steps), while still providing contiguous buffers for a set of consecutive channels.
This commit is contained in:
parent
a6c8f5e0c2
commit
d42cdb72c7
@ -1274,7 +1274,8 @@ static ncclResult_t getStepInfo(struct ncclInfo* info) {
|
|||||||
} else if (info->protocol == NCCL_PROTO_LL128) {
|
} else if (info->protocol == NCCL_PROTO_LL128) {
|
||||||
info->chunkSteps = info->sliceSteps = ncclParamLL128ChunkSteps();
|
info->chunkSteps = info->sliceSteps = ncclParamLL128ChunkSteps();
|
||||||
} else { /* SIMPLE */
|
} else { /* SIMPLE */
|
||||||
if (info->algorithm == NCCL_ALGO_TREE || info->coll == ncclFuncBroadcast || info->coll == ncclFuncReduce) {
|
if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN || info->algorithm == NCCL_ALGO_COLLNET_DIRECT ||
|
||||||
|
info->algorithm == NCCL_ALGO_TREE || info->coll == ncclFuncBroadcast || info->coll == ncclFuncReduce) {
|
||||||
info->chunkSteps = info->sliceSteps = ncclParamPipelineChunkSteps();
|
info->chunkSteps = info->sliceSteps = ncclParamPipelineChunkSteps();
|
||||||
} else {
|
} else {
|
||||||
info->chunkSteps = ncclParamRingChunkSteps();
|
info->chunkSteps = ncclParamRingChunkSteps();
|
||||||
@ -1360,8 +1361,7 @@ comp_next:
|
|||||||
// Set direct direction for broadcast-gather (read or write)
|
// Set direct direction for broadcast-gather (read or write)
|
||||||
work->direct = (info->nBytes / info->nChannels <= 1024*1024) ? NCCL_DIRECT_WRITE : NCCL_DIRECT_READ;
|
work->direct = (info->nBytes / info->nChannels <= 1024*1024) ? NCCL_DIRECT_WRITE : NCCL_DIRECT_READ;
|
||||||
} else if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN) {
|
} else if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN) {
|
||||||
stepSize = info->comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS;
|
chunkSize = std::min(256*1024, chunkSize);
|
||||||
chunkSize = std::min(256*1024, stepSize*chunkSteps);
|
|
||||||
while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*64 && chunkSize > 131072) chunkSize /= 2;
|
while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*64 && chunkSize > 131072) chunkSize /= 2;
|
||||||
while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*8 && chunkSize > 65536) chunkSize /= 2;
|
while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*8 && chunkSize > 65536) chunkSize /= 2;
|
||||||
while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2;
|
while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2;
|
||||||
|
@ -367,10 +367,14 @@ static ncclResult_t sharedBuffersInit(struct ncclComm* comm, int cuda, char** gp
|
|||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ncclResult_t sharedBuffersGet(struct ncclComm* comm, int type, int slot, int channel, int* offset) {
|
// Allocate buffers between channels, so that consecutive channels have contiguous buffers.
|
||||||
|
// slot is going to be a multiple of sliceSteps and the buffer per channel needs to be
|
||||||
|
// large enough for sliceSteps.
|
||||||
|
static ncclResult_t sharedBuffersGet(struct ncclComm* comm, int type, int slot, int channel, int* offset, int sliceSteps) {
|
||||||
// Use different pools for different channels and also separate send/recv.
|
// Use different pools for different channels and also separate send/recv.
|
||||||
int slotSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS;
|
int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS;
|
||||||
int globalSlot = (type*NCCL_STEPS+slot)*comm->nChannels+channel;
|
int slotSize = stepSize*sliceSteps;
|
||||||
|
int globalSlot = ((type*NCCL_STEPS+slot)/sliceSteps)*comm->nChannels+channel;
|
||||||
*offset = slotSize * globalSlot;
|
*offset = slotSize * globalSlot;
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
@ -629,7 +633,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
|
|||||||
int buffSlot = (sub->base+sub->posted)%NCCL_STEPS;
|
int buffSlot = (sub->base+sub->posted)%NCCL_STEPS;
|
||||||
int sharedBuffSlot = sub->posted%NCCL_STEPS;
|
int sharedBuffSlot = sub->posted%NCCL_STEPS;
|
||||||
int offset;
|
int offset;
|
||||||
NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset));
|
NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset, args->sliceSteps));
|
||||||
resources->recvMem->offsFifo[buffSlot] = offset + s*args->chunkSize;
|
resources->recvMem->offsFifo[buffSlot] = offset + s*args->chunkSize;
|
||||||
__sync_synchronize();
|
__sync_synchronize();
|
||||||
volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head;
|
volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head;
|
||||||
@ -650,7 +654,7 @@ static ncclResult_t sendProxyProgress(struct ncclComm* comm, struct ncclProxyArg
|
|||||||
int ready = 1;
|
int ready = 1;
|
||||||
if (s == 0) {
|
if (s == 0) {
|
||||||
int offset;
|
int offset;
|
||||||
NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset));
|
NCCLCHECK(sharedBuffersGet(comm, 0, sharedBuffSlot, 0, &offset, args->sliceSteps));
|
||||||
args->sharedBuff[sharedBuffSlot] = localBuff + offset;
|
args->sharedBuff[sharedBuffSlot] = localBuff + offset;
|
||||||
args->sharedSize[sharedBuffSlot] = args->chunkSize;
|
args->sharedSize[sharedBuffSlot] = args->chunkSize;
|
||||||
}
|
}
|
||||||
@ -742,7 +746,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
|
|||||||
int sharedBuffSlot = sub->posted%NCCL_STEPS;
|
int sharedBuffSlot = sub->posted%NCCL_STEPS;
|
||||||
int startChannel = group*COLLNET_GROUP_NSUBS;
|
int startChannel = group*COLLNET_GROUP_NSUBS;
|
||||||
int offset;
|
int offset;
|
||||||
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset));
|
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset, args->sliceSteps));
|
||||||
reqFifo[group][buffSlot].recvBuff = localBuff + offset;
|
reqFifo[group][buffSlot].recvBuff = localBuff + offset;
|
||||||
TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff);
|
TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff);
|
||||||
sub->posted += args->sliceSteps;
|
sub->posted += args->sliceSteps;
|
||||||
@ -773,7 +777,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
|
|||||||
} else {
|
} else {
|
||||||
int startChannel = group*COLLNET_GROUP_NSUBS;
|
int startChannel = group*COLLNET_GROUP_NSUBS;
|
||||||
int offset;
|
int offset;
|
||||||
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset));
|
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset, args->sliceSteps));
|
||||||
NCCLCHECK(collNetIflush(comm, resources->collNetComm, localBuff + offset, totalSize, mhandle, sub->requests+buffSlot));
|
NCCLCHECK(collNetIflush(comm, resources->collNetComm, localBuff + offset, totalSize, mhandle, sub->requests+buffSlot));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -802,7 +806,7 @@ static ncclResult_t recvProxyProgress(struct ncclComm* comm, struct ncclProxyArg
|
|||||||
int sharedBuffSlot = sub->transmitted%NCCL_STEPS;
|
int sharedBuffSlot = sub->transmitted%NCCL_STEPS;
|
||||||
int startChannel = group*COLLNET_GROUP_NSUBS;
|
int startChannel = group*COLLNET_GROUP_NSUBS;
|
||||||
int offset;
|
int offset;
|
||||||
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset));
|
NCCLCHECK(sharedBuffersGet(comm, 1, sharedBuffSlot, startChannel, &offset, args->sliceSteps));
|
||||||
volatile int* offsFifo = (volatile int*)resources->recvMem->offsFifo;
|
volatile int* offsFifo = (volatile int*)resources->recvMem->offsFifo;
|
||||||
offsFifo[buffSlot] = offset + (s%COLLNET_GROUP_NSUBS)*args->chunkSize;
|
offsFifo[buffSlot] = offset + (s%COLLNET_GROUP_NSUBS)*args->chunkSize;
|
||||||
__sync_synchronize();
|
__sync_synchronize();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user