diff --git a/src/enqueue.cc b/src/enqueue.cc index 0cb4fc6..413b337 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -1268,6 +1268,11 @@ NCCL_PARAM(PipelineChunkSteps, "PIPELINE_CHUNKSTEPS", (NCCL_STEPS/8)); NCCL_PARAM(RingChunkSteps, "RING_CHUNKSTEPS", (NCCL_STEPS/2)); NCCL_PARAM(RingSliceSteps, "RING_SLICESTEPS", (NCCL_STEPS/4)); +NCCL_PARAM(TreeLongNetLatChunkSteps, "TREE_LONG_NET_LAT_CHUNKSTEPS", (4)); +NCCL_PARAM(TreeLongNetLatSliceSteps, "TREE_LONG_NET_LAT_SLICESTEPS", (2)); +NCCL_PARAM(RingLongNetLatChunkSteps, "RING_LONG_NET_LAT_CHUNKSTEPS", (4)); +NCCL_PARAM(RingLongNetLatSliceSteps, "RING_LONG_NET_LAT_SLICESTEPS", (2)); + static ncclResult_t getStepInfo(struct ncclInfo* info) { if (info->protocol == NCCL_PROTO_LL) { info->chunkSteps = info->sliceSteps = ncclParamLLChunkSteps(); @@ -1282,12 +1287,19 @@ static ncclResult_t getStepInfo(struct ncclInfo* info) { info->sliceSteps = ncclParamRingSliceSteps(); } } - // Make buffer deeper for longer latency network segment + + // Make buffer deeper for long latency network segment if (info->comm->nNodes > 1 && info->comm->netLatency > 100 && (info->coll == ncclFuncReduceScatter || info->coll == ncclFuncAllGather || info->coll == ncclFuncAllReduce)) { - info->sliceSteps = 1; - info->chunkSteps = 2; + if (info->algorithm == NCCL_ALGO_TREE) { + info->chunkSteps = ncclParamTreeLongNetLatChunkSteps(); + info->sliceSteps = ncclParamTreeLongNetLatSliceSteps(); + } else { + info->chunkSteps = ncclParamRingLongNetLatChunkSteps(); + info->sliceSteps = ncclParamRingLongNetLatSliceSteps(); + } } + if (info->chunkSteps > NCCL_STEPS/2 || info->sliceSteps > NCCL_STEPS/2) { WARN("Invalid chunkSteps=%d/sliceSteps=%d, must be at most NCCL_STEPS/2=%d\n", info->chunkSteps, info->sliceSteps, NCCL_STEPS/2); return ncclInvalidUsage; diff --git a/src/init.cc b/src/init.cc index 2b39afa..8521008 100644 --- a/src/init.cc +++ b/src/init.cc @@ -470,6 +470,7 @@ static ncclResult_t setupChannel(struct ncclComm* comm, int channelId, int rank, NCCL_PARAM(BuffSize, "BUFFSIZE", -2); NCCL_PARAM(LlBuffSize, "LL_BUFFSIZE", -2); NCCL_PARAM(Ll128BuffSize, "LL128_BUFFSIZE", -2); +NCCL_PARAM(LongNetLatBuffSizeScaling, "LONG_NET_LAT_BUFFSIZE_SCALING", 4); NCCL_PARAM(P2pNetChunkSize, "P2P_NET_CHUNKSIZE", (1 << 17)); /* 128 kB */ NCCL_PARAM(P2pPciChunkSize, "P2P_PCI_CHUNKSIZE", (1 << 17)); /* 128 kB */ @@ -484,8 +485,14 @@ static ncclResult_t computeBuffSizes(struct ncclComm* comm) { if (cpuArch == NCCL_TOPO_CPU_ARCH_ARM) defaults[NCCL_PROTO_SIMPLE] = DEFAULT_BUFFSIZE_ARM; + // Make buffer deeper for longer network latency segment + int scaling = 1; + if (comm->nNodes > 1 && comm->netLatency > 100) { + scaling = ncclParamLongNetLatBuffSizeScaling(); + } + for (int p=0; pbuffSizes[p] = envs[p] != -2 ? envs[p] : defaults[p]; + comm->buffSizes[p] = envs[p] != -2 ? envs[p] : scaling * defaults[p]; } if (comm->nNodes > 1) comm->p2pChunkSize = ncclParamP2pNetChunkSize();