diff --git a/src/enqueue.cc b/src/enqueue.cc
index 25d520e..0f68614 100644
--- a/src/enqueue.cc
+++ b/src/enqueue.cc
@@ -1281,6 +1281,12 @@ static ncclResult_t getStepInfo(struct ncclInfo* info) {
       info->sliceSteps = ncclParamRingSliceSteps();
     }
   }
+  // High-latency network path: for multi-node ReduceScatter/AllGather/AllReduce, force sliceSteps=1 and chunkSteps=2 (author's stated intent: deeper buffering), replacing any values selected above.
+  if (info->comm->nNodes > 1 && info->comm->netLatency > 100 &&  // threshold is in whatever unit props.latency uses -- TODO confirm unit
+      (info->coll == ncclFuncReduceScatter || info->coll == ncclFuncAllGather || info->coll == ncclFuncAllReduce)) {
+    info->sliceSteps = 1;
+    info->chunkSteps = 2;  // both values still go through the NCCL_STEPS/2 check just below
+  }
   if (info->chunkSteps > NCCL_STEPS/2 || info->sliceSteps > NCCL_STEPS/2) {
     WARN("Invalid chunkSteps=%d/sliceSteps=%d, must be at most NCCL_STEPS/2=%d\n", info->chunkSteps, info->sliceSteps, NCCL_STEPS/2);
     return ncclInvalidUsage;
diff --git a/src/include/comm.h b/src/include/comm.h
index 655292a..0ec1b61 100644
--- a/src/include/comm.h
+++ b/src/include/comm.h
@@ -218,6 +218,7 @@ struct ncclComm {
   float latencies[NCCL_NUM_FUNCTIONS][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS];
   float bandwidths[NCCL_NUM_FUNCTIONS][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS];
   int maxThreads[NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS];
+  float netLatency;  // copy of ncclNetProperties_t::latency, written by ncclNetInit() and read by getStepInfo()
 
   /* This attribute can indicate the states of communicators and return code of
    * asynchronous NCCL operations. */
diff --git a/src/net.cc b/src/net.cc
index 1480c76..11ad52b 100644
--- a/src/net.cc
+++ b/src/net.cc
@@ -290,6 +290,10 @@ ncclResult_t ncclNetInit(struct ncclComm* comm) {
     WARN("Error: network %s not found.", netName ? netName : "");
     return ncclInvalidUsage;
   }
+
+  ncclNetProperties_t props;  // filled by the query below; only the latency field is used here
+  NCCLCHECK(ncclNetGetProperties(comm, 0, &props));  // NOTE(review): the literal 0 is presumably the net device index, so only that one device is sampled -- confirm
+  comm->netLatency = props.latency;  // cached on the communicator for the step-selection logic in enqueue.cc
   return ncclSuccess;
 }
 