2.18.3-1
Fix data corruption with Tree/LL128 on systems with 1GPU:1NIC. Fix hang with Collnet on bfloat16 on systems with less than one NIC per GPU. Fix long initialization time. Fix data corruption with Collnet when mixing multi-process and multi-GPU per process. Fix crash when shared memory creation fails. Fix Avg operation with Collnet/Chain. Fix performance of alltoall at scale with more than one NIC per GPU. Fix performance for DGX H800. Fix race condition in connection progress causing a crash. Fix network flush with Collnet. Fix performance of aggregated allGather/reduceScatter operations. Fix PXN operation when CUDA_VISIBLE_DEVICES is set. Fix NVTX3 compilation issues on Debian 10.
This commit is contained in:
parent
d97a32fac8
commit
ea38312273
@ -1,6 +1,6 @@
|
||||
##### version
|
||||
NCCL_MAJOR := 2
|
||||
NCCL_MINOR := 18
|
||||
NCCL_PATCH := 1
|
||||
NCCL_PATCH := 3
|
||||
NCCL_SUFFIX :=
|
||||
PKG_REVISION := 1
|
||||
|
@ -114,7 +114,7 @@ namespace {
|
||||
chunkSize = divUp((int)size, int(nChannels*minChunkSize))*int(minChunkSize);
|
||||
|
||||
{ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0> prims
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_TREE_ARITY, 1>, /*Direct=*/0, Proto, 0> prims
|
||||
(tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg);
|
||||
if (tree->up == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
@ -140,7 +140,7 @@ namespace {
|
||||
}
|
||||
|
||||
{ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/1, Proto, 0> prims
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_TREE_ARITY>, /*Direct=*/1, Proto, 0> prims
|
||||
(tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg);
|
||||
if (tree->up == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
@ -197,8 +197,8 @@ namespace {
|
||||
chunkSize = divUp((int)size, nChannels*int(minChunkSize))*int(minChunkSize);
|
||||
|
||||
if (tree->up == -1) {
|
||||
// Reduce and broadcast. Max number of recv is 3, max number of send is 3
|
||||
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/1, Proto, 0>
|
||||
// Reduce and broadcast. Max number of recv is 2, max number of send is 2
|
||||
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_TREE_ARITY_TOP>, /*Direct=*/1, Proto, 0>
|
||||
prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff, args->redOpArg);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
@ -215,7 +215,7 @@ namespace {
|
||||
* into DirectRecv and DirectSend capabilities, this ctor would have both=0,
|
||||
* but the ctor above for tree roots would be DirectRecv=0 DirectSend=1.
|
||||
*/
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/1, Proto, 0>
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_TREE_ARITY, 1>, /*Direct=*/1, Proto, 0>
|
||||
prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg, 0*Proto::MaxGroupWidth);
|
||||
if (tree->down[0] == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
@ -234,7 +234,7 @@ namespace {
|
||||
}
|
||||
else {
|
||||
// Broadcast down. Max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/1, Proto, 0>
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_TREE_ARITY>, /*Direct=*/1, Proto, 0>
|
||||
prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff,
|
||||
args->redOpArg, 1*Proto::MaxGroupWidth);
|
||||
if (tree->down[0] == -1) {
|
||||
@ -564,6 +564,7 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_CHAIN, NCCL
|
||||
ncclTree *tree = &ncclShmem.channel.collnetChain;
|
||||
ssize_t chunkSize = int(args->lastChunkSize);
|
||||
const ssize_t loopSize = int(nChannels*chunkSize);
|
||||
const int nranks = ncclShmem.comm.nRanks;
|
||||
const ssize_t size = args->count;
|
||||
|
||||
int nthreadsSplit = nthreads/2;
|
||||
@ -609,6 +610,22 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_CHAIN, NCCL
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (recv == nranks) {
|
||||
// I'm the first in the broadcast chain, I need to perform the division (postOp)
|
||||
if (send == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.recv(offset, nelem, /*postOp*/true);
|
||||
}
|
||||
} else {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.recvCopyDirectSend(offset, nelem, /*postOp*/true);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (send == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
@ -624,6 +641,7 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET_CHAIN, NCCL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
|
@ -12,7 +12,6 @@
|
||||
#include "op128.h"
|
||||
|
||||
#define COLL_UNROLL (ncclCollUnroll())
|
||||
#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree
|
||||
|
||||
typedef void(*ncclKern_t)();
|
||||
extern __device__ ncclKern_t ncclFuncs[];
|
||||
|
@ -327,10 +327,6 @@ class Primitives<
|
||||
// Adjust remote index with peer offset in case we are directly pulling from peer's output buffer
|
||||
waitPeer<DirectRecv, 0, 1, 0, 0, 1>(outIx, outIx+pOffset, offset, realSize);
|
||||
subBarrier();
|
||||
if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) {
|
||||
// Since waitPeer sets srcs[0] to output buffer + offset, we are doing a direct-write based recv
|
||||
// Do nothing
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j=0; j<fan.nrecv(); j++) {
|
||||
int i = (j+shift)%fan.nrecv();
|
||||
@ -338,11 +334,11 @@ class Primitives<
|
||||
if (skip >= 0 && i >= skip) pOffset += peerElem;
|
||||
void* dst0 = (T*)ncclShmem.groups[group].dsts[0] + pOffset;
|
||||
int realPeerSize = min(realSize, totalElem-pOffset);
|
||||
if (DirectRecv && ncclShmem.groups[group].srcs[i] == dst0) realPeerSize = 0;
|
||||
if (realPeerSize > 0) reduceCopy<Unroll, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fenceNeeded = barrierAny(fenceNeeded);
|
||||
postPeer<Recv, Send>(fenceNeeded);
|
||||
offset += realSize;
|
||||
|
@ -629,10 +629,12 @@ static ncclResult_t scheduleP2pTasksToPlan(
|
||||
// Try to use all channels
|
||||
int nChannelsMax = comm->p2pnChannelsPerPeer;
|
||||
int nChannelsMin = nChannelsMax;
|
||||
if (comm->nNodes == 1) {
|
||||
// Try to use all channels, but one channel per operation.
|
||||
while (nChannelsMin*nRanks > comm->p2pnChannels && nChannelsMin > 1) nChannelsMin /= 2;
|
||||
// Avoid overloading channels with 8+ operations as we loose the sync warp, hence a bit of bandwidth.
|
||||
while (nChannelsMax*nRanks > comm->p2pnChannels*4 && nChannelsMax > 1) nChannelsMax /= 2;
|
||||
}
|
||||
|
||||
bool fuseOk;
|
||||
// We can perform 8 send/recv per round per CTA. Make sure we jump between fused blocks at node boundaries.
|
||||
@ -1141,13 +1143,9 @@ ncclResult_t ncclLaunchFinish(struct ncclComm* comm) {
|
||||
/*****************************************************************************/
|
||||
|
||||
static inline ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetTypeSupport) {
|
||||
if (info->comm->collNetSupport > 0) {
|
||||
// Translate ncclAvg and PreMulSum
|
||||
ncclRedOp_t netOp = info->op == ncclAvg || info->op >= ncclNumOps ? ncclSum : info->op;
|
||||
NCCLCHECK(collNetReduceSupport(info->comm, info->datatype, netOp, collNetTypeSupport));
|
||||
} else {
|
||||
*collNetTypeSupport = 0;
|
||||
}
|
||||
*collNetTypeSupport = info->comm->collNetSupportMatrix[netOp][info->datatype];
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@ -1536,7 +1534,7 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo const* inf
|
||||
t->chunkSteps = info->chunkSteps;
|
||||
t->sliceSteps = info->sliceSteps;
|
||||
ncclIntruQueueEnqueue(&tasks->collQueue, t);
|
||||
tasks->collBytesTotal += t->count*ncclTypeSize(t->datatype);
|
||||
tasks->collBytesTotal += info->nBytes;
|
||||
tasks->nTasksColl += 1;
|
||||
}
|
||||
}
|
||||
|
@ -169,14 +169,16 @@ static ncclResult_t connectTrees(struct ncclComm* comm, int* treeToParent, int*
|
||||
static ncclResult_t connectCollNet(struct ncclComm* comm, struct ncclTopoGraph* collNetGraph) {
|
||||
int rank = comm->rank;
|
||||
int localRanks = comm->localRanks;
|
||||
int nHeads = collNetGraph->nChannels;
|
||||
int nHeads = 0;
|
||||
int *heads;
|
||||
NCCLCHECK(ncclCalloc(&heads, nHeads));
|
||||
NCCLCHECK(ncclCalloc(&heads, localRanks));
|
||||
// Find all head ranks
|
||||
// Head index is always 0
|
||||
for (int c=0; c<nHeads; c++) {
|
||||
for (int c=0; c<collNetGraph->nChannels; c++) {
|
||||
int* collNetIntra = collNetGraph->intra+c*localRanks;
|
||||
heads[c] = collNetIntra[0];
|
||||
int head = collNetIntra[0];
|
||||
for (int h=0; h<nHeads; h++) if (heads[h] == head) head = -1;
|
||||
if (head != -1) heads[nHeads++] = collNetIntra[0];
|
||||
}
|
||||
// For all channels
|
||||
for (int c=0; c<comm->nChannels; c++) {
|
||||
|
@ -108,6 +108,9 @@ static ncclResult_t ncclTopoFollowPath(struct ncclTopoSystem* system, struct ncc
|
||||
if (type1 == -1) return ncclSuccess;
|
||||
struct ncclTopoNode* node1 = system->nodes[type1].nodes+index1;
|
||||
struct ncclTopoLinkList* path = node1->paths[type2]+index2;
|
||||
struct ncclTopoNode* node2 = system->nodes[type2].nodes+index2;
|
||||
struct ncclTopoLinkList* revPath = node2->paths[type1]+index1;
|
||||
|
||||
if (path == NULL) {
|
||||
WARN("No path computed to go from %s/%d to %s/%d", topoNodeTypeStr[type1], index1, topoNodeTypeStr[type2], index2);
|
||||
return ncclInternalError;
|
||||
@ -121,6 +124,10 @@ static ncclResult_t ncclTopoFollowPath(struct ncclTopoSystem* system, struct ncc
|
||||
int type = intra ? graph->typeIntra : graph->typeInter;
|
||||
|
||||
if (mult == 1 && (path->type > type)) return ncclSuccess;
|
||||
if (mult == 1 && (graph->pattern == NCCL_TOPO_PATTERN_BALANCED_TREE ||
|
||||
graph->pattern == NCCL_TOPO_PATTERN_TREE ||
|
||||
graph->pattern == NCCL_TOPO_PATTERN_SPLIT_TREE) &&
|
||||
(revPath->type > type)) return ncclSuccess;
|
||||
|
||||
bw *= mult;
|
||||
|
||||
@ -260,7 +267,7 @@ ncclResult_t ncclTopoSearchNextGpuSort(struct ncclTopoSystem* system, struct ncc
|
||||
ncclResult_t ncclTopoSearchRec(struct ncclTopoSystem* system, struct ncclTopoGraph* graph, struct ncclTopoGraph* saveGraph, int* time);
|
||||
|
||||
// Try to keep all searchs within one second
|
||||
#define NCCL_SEARCH_GLOBAL_TIMEOUT (1ULL<<18)
|
||||
#define NCCL_SEARCH_GLOBAL_TIMEOUT (5ULL<<16)
|
||||
#define NCCL_SEARCH_TIMEOUT (1<<14)
|
||||
#define NCCL_SEARCH_TIMEOUT_TREE (1<<14)
|
||||
#define NCCL_SEARCH_TIMEOUT_SAMECHANNELS (1<<8)
|
||||
@ -333,6 +340,10 @@ ncclResult_t ncclTopoCompareGraphs(struct ncclTopoSystem* system, struct ncclTop
|
||||
// 1. Try to get the same nChannels between Rings and Trees
|
||||
if (graph->nChannels < graph->minChannels) return ncclSuccess;
|
||||
|
||||
if (graph->pattern == NCCL_TOPO_PATTERN_NVLS) { // NVLS channels correspond to GPUs pulling from NVLS. So the more the better.
|
||||
if (graph->nChannels > refGraph->nChannels && graph->nChannels <= system->nodes[GPU].count) *copy = 1;
|
||||
return ncclSuccess;
|
||||
}
|
||||
// 2. Try to get better bandwidth
|
||||
// Give a 15% perf bonus to paths not crossing nics
|
||||
float target = 1.0 - (refGraph->crossNic - graph->crossNic) * .15;
|
||||
@ -506,7 +517,6 @@ ncclResult_t ncclTopoSearchRecNet(struct ncclTopoSystem* system, struct ncclTopo
|
||||
struct ncclTopoNode* gpu;
|
||||
if (graph->collNet && net->net.collSupport == 0) continue;
|
||||
if (net->net.bw < bw) continue;
|
||||
if (net->net.maxChannels == 0) continue;
|
||||
|
||||
graph->inter[graph->nChannels*2] = net->id;
|
||||
graph->latencyInter = net->net.latency;
|
||||
@ -517,10 +527,13 @@ ncclResult_t ncclTopoSearchRecNet(struct ncclTopoSystem* system, struct ncclTopo
|
||||
system->nodes[NET].nodes[i].net.bw -= bw;
|
||||
}
|
||||
}
|
||||
net->net.maxChannels--;
|
||||
|
||||
// First try to replay the last channel
|
||||
// NVLS needs to balance on all NICs
|
||||
if (graph->pattern == NCCL_TOPO_PATTERN_NVLS) {
|
||||
NCCLCHECK(ncclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, 0, time, -1, -1, nets[graph->nChannels]));
|
||||
} else {
|
||||
if (graph->nChannels > 0) {
|
||||
// Try to replay the last channel
|
||||
int g;
|
||||
NCCLCHECK(ncclTopoReplayGetGpu(system, graph, -1, &g));
|
||||
NCCLCHECK(ncclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, FORCED_ORDER_REPLAY, time, NET, n, g));
|
||||
@ -562,8 +575,8 @@ ncclResult_t ncclTopoSearchRecNet(struct ncclTopoSystem* system, struct ncclTopo
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
net->net.maxChannels++;
|
||||
for (int i=0; i<system->nodes[NET].count; i++) {
|
||||
if ((system->nodes[NET].nodes[i].net.asic == net->net.asic) &&
|
||||
(system->nodes[NET].nodes[i].net.port == net->net.port)) {
|
||||
@ -779,7 +792,7 @@ float speedArrayInter[] = { 48.0, 30.0, 28.0, 24.0, 20.0, 18.0, 15.0, 12.0, 10.0
|
||||
#define NSPEEDSINTER (sizeof(speedArrayInter)/sizeof(float))
|
||||
|
||||
float sm90SpeedArrayIntra[] = { 60.0, 40.0, 30.0, 24.0, 20.0, 15.0, 12.0, 6.0, 3.0 };
|
||||
float sm90SpeedArrayInter[] = { 48.0, 45.0, 42.0, 40.0, 30.0, 24.0, 15.0, 12.0, 6.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
|
||||
float sm90SpeedArrayInter[] = { 48.0, 45.0, 42.0, 40.0, 30.0, 24.0, 20.0, 17.5, 15.0, 12.0, 6.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
|
||||
#define NSPEEDSINTRA_SM90 (sizeof(sm90SpeedArrayIntra)/sizeof(float))
|
||||
#define NSPEEDSINTER_SM90 (sizeof(sm90SpeedArrayInter)/sizeof(float))
|
||||
|
||||
@ -839,8 +852,9 @@ ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph
|
||||
int pass = 1;
|
||||
int speedIndex = 0;
|
||||
float maxBw = system->maxBw;
|
||||
if (system->nodes[NET].count == 0 && graph->pattern == NCCL_TOPO_PATTERN_NVLS) maxBw /= ngpus; // We want all GPUs to pull the same BW
|
||||
while (speedArray[speedIndex] > maxBw && speedIndex < nspeeds-1) speedIndex++;
|
||||
float totalBw = system->totalBw;
|
||||
if (ngpus == 1 || graph->pattern != NCCL_TOPO_PATTERN_RING) totalBw *= ngpus*1.0/(ngpus-1);
|
||||
while ((speedArray[speedIndex] > maxBw || speedArray[speedIndex]*graph->minChannels > totalBw) && speedIndex < nspeeds-1) speedIndex++;
|
||||
tmpGraph.bwIntra = tmpGraph.bwInter = speedArray[speedIndex];
|
||||
int64_t globalTimeout = NCCL_SEARCH_GLOBAL_TIMEOUT;
|
||||
|
||||
@ -880,6 +894,13 @@ search:
|
||||
else globalTimeout = NCCL_SEARCH_GLOBAL_TIMEOUT;
|
||||
if (globalTimeout < 0 && graph->nChannels) goto done;
|
||||
|
||||
// Try a simpler tree
|
||||
if (ccMin >= 90 && tmpGraph.pattern == NCCL_TOPO_PATTERN_BALANCED_TREE) {
|
||||
tmpGraph.pattern = NCCL_TOPO_PATTERN_TREE;
|
||||
goto search;
|
||||
}
|
||||
tmpGraph.pattern = graph->pattern;
|
||||
|
||||
int maxTypeIntra = system->nodes[NET].count > 0 ? tmpGraph.typeInter : PATH_SYS;
|
||||
if (tmpGraph.typeIntra < maxTypeIntra && (graph->nChannels == 0 || tmpGraph.typeIntra < graph->typeIntra)) {
|
||||
tmpGraph.typeIntra += 1;
|
||||
@ -900,13 +921,6 @@ search:
|
||||
}
|
||||
tmpGraph.crossNic = 0;
|
||||
|
||||
// Try a simpler tree
|
||||
if (tmpGraph.pattern == NCCL_TOPO_PATTERN_SPLIT_TREE) {
|
||||
tmpGraph.pattern = NCCL_TOPO_PATTERN_TREE;
|
||||
goto search;
|
||||
}
|
||||
tmpGraph.pattern = graph->pattern;
|
||||
|
||||
// Decrease bw until we find a solution
|
||||
if ((speedIndex < nspeeds-1) && (graph->nChannels == 0 || (speedArray[speedIndex+1]/graph->bwInter > .49))) {
|
||||
tmpGraph.bwInter = tmpGraph.bwIntra = speedArray[++speedIndex];
|
||||
@ -951,14 +965,17 @@ done:
|
||||
graph->nChannels = 1;
|
||||
}
|
||||
|
||||
if (graph->pattern != NCCL_TOPO_PATTERN_NVLS && ((ccMin <= 80 && graph->bwIntra >= 25.0) || (ccMin <= 90 && graph->bwIntra >= 50.0))) {
|
||||
if (graph->nChannels == 0) return ncclSuccess;
|
||||
if (graph->pattern == NCCL_TOPO_PATTERN_NVLS) return ncclSuccess;
|
||||
if (graph->bwIntra < 25.0) return ncclSuccess;
|
||||
if (ccMin > 80 && graph->bwIntra < 50.0 && graph->nChannels > 4) return ncclSuccess;
|
||||
|
||||
int dupChannels = std::min(graph->nChannels*2, graph->maxChannels);
|
||||
memcpy(graph->intra+graph->nChannels*ngpus, graph->intra, (dupChannels-graph->nChannels)*ngpus*sizeof(int));
|
||||
memcpy(graph->inter+graph->nChannels*2,graph->inter, (dupChannels-graph->nChannels)*2*sizeof(int));
|
||||
graph->bwIntra /= DIVUP(dupChannels, graph->nChannels);
|
||||
graph->bwInter /= DIVUP(dupChannels, graph->nChannels);
|
||||
graph->nChannels = dupChannels;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@ -1039,10 +1056,10 @@ ncclResult_t ncclTopoGetNetDev(struct ncclComm* comm, int rank, struct ncclTopoG
|
||||
int pxnLevel = ncclPxnDisable(comm) == 1 ? 0 : ncclParamP2pPxnLevel();
|
||||
// See whether we can use the remote rank preferred device.
|
||||
if (ncclParamCrossNic() == 0 || (pxnLevel != 0)) {
|
||||
// Find local NIC number close to local cudaDev
|
||||
int cudaDev = comm->peerInfo[peerRank].cudaDev;
|
||||
// Find local NIC number close to local nvmlDev
|
||||
int nvmlDev = comm->peerInfo[peerRank].nvmlDev;
|
||||
int localRank;
|
||||
if (ncclTopoDevToRank(comm->topo, cudaDev, &localRank) != ncclSuccess) return ncclSuccess;
|
||||
if (ncclTopoDevToRank(comm->topo, nvmlDev, &localRank) != ncclSuccess) return ncclSuccess;
|
||||
int netDev;
|
||||
NCCLCHECK(ncclTopoGetLocalNet(comm->topo, localRank, channelId, &netDev));
|
||||
|
||||
|
@ -62,19 +62,17 @@ static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = {
|
||||
#define NCCL_HW_NVLINK 0
|
||||
#define NCCL_HW_PCI 1
|
||||
#define NCCL_HW_NET 2
|
||||
// Tree/Simple is the latency a 256kB chunk, which is ~ base lat + 256k/12GB/s (+ 256k/12GB/s for the network).
|
||||
// Ring/LL128 reflects the latency for the second plateau, not the base latency.
|
||||
static float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] =
|
||||
{ /* NVLINK */
|
||||
{ /* Tree (LL/LL128/Simple)*/ { .6, 1.25, 28 }, /* Ring (LL/LL128/Simple)*/ { .6, 1.9, 3.4 },
|
||||
{ /* Tree (LL/LL128/Simple)*/ { .6, 1.25, 4 }, /* Ring (LL/LL128/Simple)*/ { .6, 1.9, 3.4 },
|
||||
/* CollNetDirect (Simple)*/ { 0, 0, 8.0 }, /* CollNetChain (Simple)*/ { 0, 0, 4.75 },
|
||||
/* NVLS */ { 0, 0, 0 }, /* NVLSTree */ { 0, 0, 0 } },
|
||||
/* PCI */
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 1.0, 1.9, 28 }, /* Ring (LL/LL128/Simple)*/ { 1.0, 2.5, 5.7 },
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 1.0, 1.9, 6 }, /* Ring (LL/LL128/Simple)*/ { 1.0, 2.5, 5.7 },
|
||||
/* CollNetDirect (Simple)*/ { 0, 0, 8.0 }, /* CollNetChain (Simple)*/ { 0, 0, 8.0 },
|
||||
/* NVLS */ { 0, 0, 0 }, /* NVLSTree */ { 0, 0, 0 } },
|
||||
/* NET */
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 5.0, 8.5, 28 }, /* Ring (LL/LL128/Simple)*/ { 2.7, 4.0, 14.0 },
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 5.0, 8.5, 14 }, /* Ring (LL/LL128/Simple)*/ { 2.7, 4.0, 14.0 },
|
||||
/* CollNetDirect (Simple)*/ { 0, 0, 10.7 }, /* CollNetChain (Simple)*/ { 0, 0, 14 },
|
||||
/* NVLS */ { 0, 0, 18 }, /* NVLSTree */ { 0, 0, 19 } }
|
||||
};
|
||||
@ -85,17 +83,26 @@ static float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] =
|
||||
#define HOPPER_COMPCAP_IDX 2
|
||||
|
||||
// LL128 max BW per channel
|
||||
static const double ll128MaxBwPerCh[3] = { 20.0, 20.0, 36.7 };
|
||||
static const double llMaxBws[3][3] = {
|
||||
/* Volta-N1/Intel-N2/Intel-N4) */ {39.0, 39.0, 20.4},
|
||||
/* Ampere-N1/AMD-N2/AMD-N4) */ {87.7, 22.5 /*avg of ring & tree*/, 19.0},
|
||||
/* Hopper-N1/AMD-N2/AMD-N4) */ {87.7, 22.5 /*avg of ring & tree*/, 19.0}
|
||||
};
|
||||
|
||||
static const double perChMaxRingLL128Bws[3][3] = {
|
||||
/* Volta (N1/N2/N4) */ {20.0, 20.0, 20.0},
|
||||
/* Ampere (N1/N2/N4) */ {20.0, 20.0, 20.0},
|
||||
/* Hopper (N1/N2/N4) */ {36.7, 36.7, 36.7},
|
||||
};
|
||||
static const double perChMaxTreeLL128Bws[3][3] = {
|
||||
/* Volta (N1/N2/N4) */ {20.0, 20.0, 20.0},
|
||||
/* Ampere (N1/N2/N4) */ {20.0, 20.0, 20.0},
|
||||
/* Hopper (N1/N2/N4) */ {36.7, 36.7, 29.0},
|
||||
};
|
||||
static const double perChMaxTreeBws[3][3] = {
|
||||
/* Volta (N1/N2/N4) */ {26.5, 18.5, 10.0},
|
||||
/* Ampere (N1/N2/N4) */ {24.0, 23.6, 17.8},
|
||||
/* Hopper (N1/N2/N4) */ {38.7, 41.4, 33.0},
|
||||
/* Hopper (N1/N2/N4) */ {38.7, 41.4, 36.0},
|
||||
};
|
||||
|
||||
// Network post overhead in ns (1000 = 1 us)
|
||||
@ -137,6 +144,8 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
int index1 = nNodes == 1 ? compCapIndex : cpuVendor == NCCL_TOPO_CPU_VENDOR_AMD ? 1 : 0;
|
||||
double llMaxBw = llMaxBws[index1][index2];
|
||||
double perChMaxTreeBw = perChMaxTreeBws[compCapIndex][index2];
|
||||
double perChMaxRingLL128Bw = perChMaxRingLL128Bws[compCapIndex][index2];
|
||||
double perChMaxTreeLL128Bw = perChMaxTreeLL128Bws[compCapIndex][index2];
|
||||
// De-penalize Tree/Simple latency on Power systems to favor Tree than Ring
|
||||
if (cpuArch == NCCL_TOPO_CPU_ARCH_POWER) hwLat[NCCL_HW_PCI][NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = hwLat[NCCL_HW_PCI][NCCL_ALGO_RING][NCCL_PROTO_SIMPLE];
|
||||
float ppn = (float)nRanks / nNodes; // if ppn < 2, then we are sending/receiving at the same GPU through the NIC, apply some bw discount
|
||||
@ -167,10 +176,11 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
|
||||
// Various model refinements
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL) { busBw = std::min(llMaxBw, busBw * ((nNodes > 1 || coll == ncclFuncAllReduce || coll == ncclFuncReduce) ? 1.0/4.0 : 1.0/3.0)); }
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh[compCapIndex]*graphs[a]->nChannels);
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), graphs[a]->nChannels*perChMaxRingLL128Bw);
|
||||
if (a == NCCL_ALGO_TREE) busBw = std::min(busBw*.92, graphs[a]->nChannels*perChMaxTreeBw);
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL) busBw = std::min(busBw*1.0/3.8, llMaxBw);
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 120.0/128.0), ll128MaxBwPerCh[compCapIndex]*graphs[a]->nChannels);
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 120.0/128.0), graphs[a]->nChannels*perChMaxTreeLL128Bw);
|
||||
if (a == NCCL_ALGO_TREE && graphs[a]->pattern == NCCL_TOPO_PATTERN_TREE) busBw *= .85;
|
||||
if (a == NCCL_ALGO_COLLNET_DIRECT && p != NCCL_PROTO_SIMPLE) busBw = 0; // Not used
|
||||
if (a == NCCL_ALGO_COLLNET_CHAIN && p != NCCL_PROTO_SIMPLE) busBw = 0; // Not used
|
||||
if (a == NCCL_ALGO_COLLNET_DIRECT && p == NCCL_PROTO_SIMPLE) {
|
||||
@ -184,7 +194,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
// Convert bus BW to algorithm BW
|
||||
float ratio;
|
||||
if (a == NCCL_ALGO_RING) ratio = (1.0 * nRanks) / nsteps;
|
||||
else if (a == NCCL_ALGO_NVLS) ratio = .75;
|
||||
else if (a == NCCL_ALGO_NVLS) ratio = 5.0/6.0;
|
||||
else if (a == NCCL_ALGO_NVLS_TREE) ratio = .70 * nNodes / (2*(nNodes-1));
|
||||
else ratio = .5;
|
||||
comm->bandwidths[coll][a][p] = busBw * ratio;
|
||||
@ -273,7 +283,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
// Enable LL128 by default only on Volta/Ampere/Hopper+NVLink. Other cases are not tested and may cause silent data corruption.
|
||||
pEnable = 1;
|
||||
pEnable &= (graphs[a]->typeInter <= PATH_PXB || (minCompCap >= 90 && graphs[a]->typeInter <= PATH_PXN));
|
||||
pEnable &= (graphs[a]->typeIntra <= PATH_NVL);
|
||||
pEnable &= (graphs[a]->typeIntra <= PATH_NVB);
|
||||
pEnable &= (minCompCap == maxCompCap);
|
||||
switch (minCompCap) {
|
||||
case 70: pEnable &= 1; break;
|
||||
|
@ -597,13 +597,7 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm
|
||||
int dev = -1;
|
||||
NCCLCHECK(xmlGetAttrIndex(gpuNode, "dev", &index));
|
||||
if (index == -1) {
|
||||
if (nvmlDev == NULL) {
|
||||
const char* busId;
|
||||
NCCLCHECK(xmlGetAttr(pciNode, "busid", &busId));
|
||||
if (busId == NULL || cudaDeviceGetByPCIBusId(&dev, busId) != cudaSuccess) dev = -1;
|
||||
} else {
|
||||
NCCLCHECK(ncclNvmlDeviceGetIndex(nvmlDev, (unsigned int*)&dev));
|
||||
}
|
||||
NCCLCHECK(xmlSetAttrInt(gpuNode, "dev", dev));
|
||||
}
|
||||
NCCLCHECK(xmlGetAttrInt(gpuNode, "dev", &dev));
|
||||
@ -713,8 +707,8 @@ ncclResult_t ncclTopoFillGpu(struct ncclXml* xml, const char* busId, struct nccl
|
||||
NCCLCHECK(ncclTopoGetPciNode(xml, busId, &node));
|
||||
NCCLCHECK(xmlSetAttrIfUnset(node, "class", "0x03"));
|
||||
NCCLCHECK(ncclTopoGetXmlFromSys(node, xml));
|
||||
nvmlDevice_t nvmlDev = NULL;
|
||||
if (ncclNvmlDeviceGetHandleByPciBusId(busId, &nvmlDev) != ncclSuccess) nvmlDev = NULL;
|
||||
nvmlDevice_t nvmlDev;
|
||||
NCCLCHECK(ncclNvmlDeviceGetHandleByPciBusId(busId, &nvmlDev));
|
||||
NCCLCHECK(ncclTopoGetXmlFromGpu(node, nvmlDev, xml, gpuNode));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
29
src/group.cc
29
src/group.cc
@ -41,6 +41,7 @@ ncclResult_t ncclAsyncLaunch(
|
||||
job->undo = undo;
|
||||
job->destructor = destructor;
|
||||
job->abortFlag = comm->abortFlag;
|
||||
job->childAbortFlag = comm->childAbortFlag;
|
||||
job->state = ncclGroupJobRunning;
|
||||
job->comm = comm;
|
||||
/* check if there are blocking and nonblocking comms at the same time in group. */
|
||||
@ -83,19 +84,8 @@ ncclResult_t ncclGroupStart() {
|
||||
ncclResult_t ret = ncclSuccess;
|
||||
NVTX3_FUNC_RANGE_IN(nccl_domain);
|
||||
|
||||
/* if previous group launch does not complete, don't launch this one. */
|
||||
if (ncclGroupJobMainPtr != NULL) {
|
||||
if (__atomic_load_n(&ncclGroupJobMainPtr->doneFlag, __ATOMIC_ACQUIRE) == false) {
|
||||
ret = ncclInvalidUsage;
|
||||
goto exit;
|
||||
} else {
|
||||
NCCLCHECKGOTO(groupJobComplete(ncclGroupJobMainPtr), ret, exit);
|
||||
}
|
||||
}
|
||||
NCCLCHECK(ncclGroupStartInternal());
|
||||
TRACE_CALL("ncclGroupStart()");
|
||||
|
||||
exit:
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -191,13 +181,6 @@ failure:
|
||||
return result;
|
||||
}
|
||||
|
||||
static inline void groupResetJobState() {
|
||||
ncclGroupBlocking = -1;
|
||||
ncclGroupJobMainPtr = NULL;
|
||||
memset(&ncclGroupJobMain, 0, sizeof(struct ncclGroupJob));
|
||||
return;
|
||||
}
|
||||
|
||||
static void groupCleanup(struct ncclComm** groupCommHeadPtr, struct ncclComm** groupCommPreconnectHeadPtr, struct ncclIntruQueue<struct ncclAsyncJob, &ncclAsyncJob::next>* asyncJobsPtr, ncclResult_t* groupErrorPtr, ncclResult_t error) {
|
||||
struct ncclComm* comm = *groupCommHeadPtr;
|
||||
|
||||
@ -326,6 +309,7 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_) {
|
||||
|
||||
if (*groupAbortFlag == true || errorJobAbortFlag == true) {
|
||||
*job->abortFlag = 1;
|
||||
if (job->childAbortFlag) *job->childAbortFlag = 1;
|
||||
}
|
||||
|
||||
job = job->next;
|
||||
@ -432,15 +416,6 @@ fail:
|
||||
goto exit;
|
||||
}
|
||||
|
||||
static ncclResult_t groupJobComplete(struct ncclGroupJob* job) {
|
||||
ncclResult_t ret = ncclSuccess;
|
||||
if (job) {
|
||||
ret = ncclAsyncJobComplete(&job->base);
|
||||
groupResetJobState();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void ncclGroupJobAbort() {
|
||||
ncclGroupJobAbortFlag = true;
|
||||
(void) groupJobComplete(ncclGroupJobMainPtr);
|
||||
|
@ -18,11 +18,11 @@
|
||||
} \
|
||||
} while(false)
|
||||
|
||||
#define CUDACHECKGOTO(cmd, res, label) do { \
|
||||
#define CUDACHECKGOTO(cmd, RES, label) do { \
|
||||
cudaError_t err = cmd; \
|
||||
if( err != cudaSuccess ) { \
|
||||
WARN("Cuda failure '%s'", cudaGetErrorString(err)); \
|
||||
res = ncclUnhandledCudaError; \
|
||||
RES = ncclUnhandledCudaError; \
|
||||
goto label; \
|
||||
} \
|
||||
} while(false)
|
||||
@ -60,11 +60,11 @@
|
||||
} \
|
||||
} while(true)
|
||||
|
||||
#define SYSCHECKGOTO(statement, res, label) do { \
|
||||
#define SYSCHECKGOTO(statement, RES, label) do { \
|
||||
if ((statement) == -1) { \
|
||||
/* Print the back trace*/ \
|
||||
res = ncclSystemError; \
|
||||
INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
RES = ncclSystemError; \
|
||||
INFO(NCCL_ALL,"%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
|
||||
goto label; \
|
||||
} \
|
||||
} while (0);
|
||||
@ -72,16 +72,16 @@
|
||||
#define NEQCHECK(statement, value) do { \
|
||||
if ((statement) != value) { \
|
||||
/* Print the back trace*/ \
|
||||
INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, ncclSystemError); \
|
||||
INFO(NCCL_ALL,"%s:%d -> %d (%s)", __FILE__, __LINE__, ncclSystemError, strerror(errno)); \
|
||||
return ncclSystemError; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#define NEQCHECKGOTO(statement, value, res, label) do { \
|
||||
#define NEQCHECKGOTO(statement, value, RES, label) do { \
|
||||
if ((statement) != value) { \
|
||||
/* Print the back trace*/ \
|
||||
res = ncclSystemError; \
|
||||
INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
RES = ncclSystemError; \
|
||||
INFO(NCCL_ALL,"%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
|
||||
goto label; \
|
||||
} \
|
||||
} while (0);
|
||||
@ -89,57 +89,57 @@
|
||||
#define EQCHECK(statement, value) do { \
|
||||
if ((statement) == value) { \
|
||||
/* Print the back trace*/ \
|
||||
INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, ncclSystemError); \
|
||||
INFO(NCCL_ALL,"%s:%d -> %d (%s)", __FILE__, __LINE__, ncclSystemError, strerror(errno)); \
|
||||
return ncclSystemError; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#define EQCHECKGOTO(statement, value, res, label) do { \
|
||||
#define EQCHECKGOTO(statement, value, RES, label) do { \
|
||||
if ((statement) == value) { \
|
||||
/* Print the back trace*/ \
|
||||
res = ncclSystemError; \
|
||||
INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
RES = ncclSystemError; \
|
||||
INFO(NCCL_ALL,"%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
|
||||
goto label; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
// Propagate errors up
|
||||
#define NCCLCHECK(call) do { \
|
||||
ncclResult_t res = call; \
|
||||
if (res != ncclSuccess && res != ncclInProgress) { \
|
||||
ncclResult_t RES = call; \
|
||||
if (RES != ncclSuccess && RES != ncclInProgress) { \
|
||||
/* Print the back trace*/ \
|
||||
if (ncclDebugNoWarn == 0) INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
return res; \
|
||||
if (ncclDebugNoWarn == 0) INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, RES); \
|
||||
return RES; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#define NCCLCHECKGOTO(call, res, label) do { \
|
||||
res = call; \
|
||||
if (res != ncclSuccess && res != ncclInProgress) { \
|
||||
#define NCCLCHECKGOTO(call, RES, label) do { \
|
||||
RES = call; \
|
||||
if (RES != ncclSuccess && RES != ncclInProgress) { \
|
||||
/* Print the back trace*/ \
|
||||
if (ncclDebugNoWarn == 0) INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
if (ncclDebugNoWarn == 0) INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, RES); \
|
||||
goto label; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#define NCCLWAIT(call, cond, abortFlagPtr) do { \
|
||||
volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \
|
||||
ncclResult_t res = call; \
|
||||
if (res != ncclSuccess && res != ncclInProgress) { \
|
||||
if (ncclDebugNoWarn == 0) INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
ncclResult_t RES = call; \
|
||||
if (RES != ncclSuccess && RES != ncclInProgress) { \
|
||||
if (ncclDebugNoWarn == 0) INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, RES); \
|
||||
return ncclInternalError; \
|
||||
} \
|
||||
if (tmpAbortFlag) NEQCHECK(*tmpAbortFlag, 0); \
|
||||
} while (!(cond));
|
||||
|
||||
#define NCCLWAITGOTO(call, cond, abortFlagPtr, res, label) do { \
|
||||
#define NCCLWAITGOTO(call, cond, abortFlagPtr, RES, label) do { \
|
||||
volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \
|
||||
res = call; \
|
||||
if (res != ncclSuccess && res != ncclInProgress) { \
|
||||
if (ncclDebugNoWarn == 0) INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
RES = call; \
|
||||
if (RES != ncclSuccess && RES != ncclInProgress) { \
|
||||
if (ncclDebugNoWarn == 0) INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, RES); \
|
||||
goto label; \
|
||||
} \
|
||||
if (tmpAbortFlag) NEQCHECKGOTO(*tmpAbortFlag, 0, res, label); \
|
||||
if (tmpAbortFlag) NEQCHECKGOTO(*tmpAbortFlag, 0, RES, label); \
|
||||
} while (!(cond));
|
||||
|
||||
#define NCCLCHECKTHREAD(a, args) do { \
|
||||
|
@ -215,6 +215,7 @@ struct ncclComm {
|
||||
int rank; // my rank in the communicator
|
||||
int nRanks; // number of GPUs in communicator
|
||||
int cudaDev; // my cuda device index
|
||||
int nvmlDev; // my nvml device index
|
||||
int compCap; // compute capability of the GPU
|
||||
int minCompCap, maxCompCap; // min/max compute capability in the communicator
|
||||
int64_t busId; // my PCI bus ID in int format
|
||||
@ -298,6 +299,7 @@ struct ncclComm {
|
||||
int proxyRefCountOld; /* store proxy post-atomic-sub refcount */
|
||||
// Whether this communicator uses collNet
|
||||
int collNetSupport;
|
||||
uint8_t collNetSupportMatrix[4/*sum,prod,min,max*/][ncclNumTypes];
|
||||
int intraHighestTransportType;
|
||||
int* collNetHeads;
|
||||
int collNetHeadsNum;
|
||||
|
@ -129,6 +129,9 @@ struct ncclRing {
|
||||
};
|
||||
|
||||
|
||||
// The root of each tree only has one node down (+1 intra-node).
|
||||
#define NCCL_MAX_TREE_ARITY_TOP 2
|
||||
// Nodes inside the binary tree can have to two nodes down (+1 intra-node).
|
||||
#define NCCL_MAX_TREE_ARITY 3
|
||||
struct ncclTree {
|
||||
int depth;
|
||||
|
@ -35,6 +35,7 @@ struct ncclAsyncJob {
|
||||
void(*destructor)(void*);
|
||||
ncclGroupJobState_t state;
|
||||
volatile uint32_t *abortFlag; /* point to comm abortFlag */
|
||||
volatile uint32_t *childAbortFlag; /* point to child abortFlag */
|
||||
ncclComm_t comm;
|
||||
};
|
||||
|
||||
@ -66,8 +67,34 @@ extern __thread ncclResult_t ncclGroupError;
|
||||
extern __thread struct ncclComm* ncclGroupCommHead;
|
||||
extern __thread struct ncclComm* ncclGroupCommPreconnectHead;
|
||||
extern __thread int ncclGroupBlocking;
|
||||
extern __thread struct ncclGroupJob *ncclGroupJobMainPtr;
|
||||
extern __thread struct ncclGroupJob ncclGroupJobMain;
|
||||
|
||||
static inline void groupResetJobState() {
|
||||
ncclGroupBlocking = -1;
|
||||
ncclGroupJobMainPtr = NULL;
|
||||
memset(&ncclGroupJobMain, 0, sizeof(struct ncclGroupJob));
|
||||
return;
|
||||
}
|
||||
|
||||
static inline ncclResult_t groupJobComplete(struct ncclGroupJob* job) {
|
||||
ncclResult_t ret = ncclSuccess;
|
||||
if (job) {
|
||||
ret = ncclAsyncJobComplete(&job->base);
|
||||
groupResetJobState();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline ncclResult_t ncclGroupStartInternal() {
|
||||
/* if previous group launch does not complete, don't launch this one. */
|
||||
if (ncclGroupJobMainPtr != NULL) {
|
||||
if (__atomic_load_n(&ncclGroupJobMainPtr->doneFlag, __ATOMIC_ACQUIRE) == false) {
|
||||
return ncclInvalidUsage;
|
||||
} else {
|
||||
NCCLCHECK(groupJobComplete(ncclGroupJobMainPtr));
|
||||
}
|
||||
}
|
||||
ncclGroupDepth++;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
@ -126,7 +126,7 @@
|
||||
* Systems:
|
||||
*
|
||||
* \image html
|
||||
* https://raw.githubusercontent.com/jrhemstad/nvtx_wrappers/master/docs/example_range.png
|
||||
* https://raw.githubusercontent.com/NVIDIA/NVTX/release-v3/docs/images/example_range.png
|
||||
*
|
||||
* Alternatively, use the \ref MACROS like `NVTX3_FUNC_RANGE()` to add
|
||||
* ranges to your code that automatically use the name of the enclosing function
|
||||
@ -561,18 +561,27 @@
|
||||
|
||||
/* Temporary helper #defines, removed with #undef at end of header */
|
||||
|
||||
#if !defined(NVTX3_USE_CHECKED_OVERLOADS_FOR_GET)
|
||||
#if defined(_MSC_VER) && _MSC_VER < 1914
|
||||
/* Microsoft's compiler prior to VS2017 Update 7 (15.7) uses an older parser
|
||||
* that does not work with domain::get's specialization for domain::global,
|
||||
* and would require extra conditions to make SFINAE work for the overloaded
|
||||
* get() functions. This macro disables use of overloaded get() in order to
|
||||
* work with VS2015 and versions of VS2017 below 15.7, without penalizing
|
||||
* users of newer compilers. Building with this flag set to 0 means errors
|
||||
* when defining tag structs (see documentation for domain, named_category,
|
||||
* and registered_string) will have more complex compiler error messages
|
||||
* instead of the clear static_assert messages from the get() overloads.
|
||||
/* Some compilers do not correctly support SFINAE, which is used in this API
|
||||
* to detect common usage errors and provide clearer error messages (by using
|
||||
* static_assert) than the compiler would produce otherwise. These compilers
|
||||
* will generate errors while compiling this file such as:
|
||||
*
|
||||
* error: ‘name’ is not a member of ‘nvtx3::v1::domain::global’
|
||||
*
|
||||
* The following compiler versions are known to have this problem, and so are
|
||||
* set by default to disable the SFINAE-based checks:
|
||||
*
|
||||
* - All MSVC versions prior to VS2017 Update 7 (15.7)
|
||||
* - GCC 8.1-8.3 (the problem was fixed in GCC 8.4)
|
||||
*
|
||||
* If you find your compiler hits this problem, you can work around it by
|
||||
* defining NVTX3_USE_CHECKED_OVERLOADS_FOR_GET to 0 before including this
|
||||
* header, or you can add a check for your compiler version to this #if.
|
||||
* Also, please report the issue on the NVTX github page.
|
||||
*/
|
||||
#if !defined(NVTX3_USE_CHECKED_OVERLOADS_FOR_GET)
|
||||
#if defined(_MSC_VER) && _MSC_VER < 1914 \
|
||||
|| defined(__GNUC__) && __GNUC__ == 8 && __GNUC_MINOR__ < 4
|
||||
#define NVTX3_USE_CHECKED_OVERLOADS_FOR_GET 0
|
||||
#else
|
||||
#define NVTX3_USE_CHECKED_OVERLOADS_FOR_GET 1
|
||||
|
@ -35,6 +35,7 @@ struct ncclComm;
|
||||
struct ncclPeerInfo {
|
||||
int rank;
|
||||
int cudaDev;
|
||||
int nvmlDev;
|
||||
int gdrSupport;
|
||||
uint64_t hostHash;
|
||||
uint64_t pidHash;
|
||||
|
54
src/init.cc
54
src/init.cc
@ -320,6 +320,12 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in
|
||||
CUDACHECK(cudaGetDevice(&comm->cudaDev));
|
||||
|
||||
NCCLCHECK(getBusId(comm->cudaDev, &comm->busId));
|
||||
nvmlDevice_t nvmlDev;
|
||||
char busId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE];
|
||||
NCCLCHECK(int64ToBusId(comm->busId, busId));
|
||||
NCCLCHECK(ncclNvmlDeviceGetHandleByPciBusId(busId, &nvmlDev));
|
||||
NCCLCHECK(ncclNvmlDeviceGetIndex(nvmlDev, (unsigned int*)&comm->nvmlDev));
|
||||
|
||||
comm->compCap = ncclCudaCompCap();
|
||||
TRACE(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx compCap %d", comm, rank, ndev, comm->cudaDev, comm->busId, comm->compCap);
|
||||
|
||||
@ -327,6 +333,7 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in
|
||||
comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false;
|
||||
|
||||
comm->collNetSupport = 0;
|
||||
memset(comm->collNetSupportMatrix, 0, sizeof(comm->collNetSupportMatrix));
|
||||
|
||||
ncclMemoryPoolConstruct(&comm->memPool_ncclKernelPlan);
|
||||
ncclMemoryPoolConstruct(&comm->memPool_ncclProxyOp);
|
||||
@ -452,7 +459,8 @@ static void showVersion() {
|
||||
|
||||
static ncclResult_t fillInfo(struct ncclComm* comm, struct ncclPeerInfo* info, uint64_t commHash) {
|
||||
info->rank = comm->rank;
|
||||
CUDACHECK(cudaGetDevice(&info->cudaDev));
|
||||
info->cudaDev = comm->cudaDev;
|
||||
info->nvmlDev = comm->nvmlDev;
|
||||
info->hostHash=getHostHash()+commHash;
|
||||
info->pidHash=getPidHash()+commHash;
|
||||
|
||||
@ -636,6 +644,45 @@ static ncclResult_t collNetTrySetup(ncclComm_t comm, ncclComm_t parent, struct n
|
||||
share = false;
|
||||
}
|
||||
|
||||
if (share) {
|
||||
memcpy(comm->collNetSupportMatrix, parent->collNetSupportMatrix, sizeof(comm->collNetSupportMatrix));
|
||||
} else {
|
||||
do {
|
||||
/* Initialize all entries in collNetSupportMatrix[redop][type]. Since some
|
||||
ranks don't connect to sharp we enable a (redop,type) if any rank claims
|
||||
support. */
|
||||
const ncclRedOp_t redops[] = {ncclSum, ncclProd, ncclMin, ncclMax};
|
||||
uint8_t(*matrix)[4][ncclNumTypes];
|
||||
bool isHead = false;
|
||||
matrix = nullptr;
|
||||
NCCLCHECKGOTO(ncclCalloc(&matrix, comm->nRanks), ret, matrix_end);
|
||||
for (int h = 0; h < nHeads; h++) isHead |= (heads[h] == comm->rank);
|
||||
if (isHead) {
|
||||
for (int ty=0; ty < ncclNumTypes; ty++) {
|
||||
for (int i=0; i < 4; i++) {
|
||||
int support = 0;
|
||||
NCCLCHECKGOTO(collNetReduceSupport(comm, (ncclDataType_t)ty, redops[i], &support), ret, matrix_end);
|
||||
// bit 0 = not supported, bit 1 = supported
|
||||
matrix[rank][redops[i]][ty] = 1<<(support ? 1 : 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
NCCLCHECKGOTO(bootstrapAllGather(comm->bootstrap, matrix, sizeof(*matrix)), ret, matrix_end);
|
||||
for (int ty=0; ty < ncclNumTypes; ty++) {
|
||||
for (int i=0; i < 4; i++) {
|
||||
int op = redops[i];
|
||||
uint8_t accum = 0;
|
||||
for (int r=0; r < comm->nRanks; r++) accum |= matrix[r][op][ty];
|
||||
// We support (redop, type) if some rank supports it and no rank doesn't support it
|
||||
comm->collNetSupportMatrix[op][ty] = (accum == (1<<1));
|
||||
}
|
||||
}
|
||||
matrix_end:
|
||||
free(matrix);
|
||||
if (ret != ncclSuccess) goto fail;
|
||||
} while (0);
|
||||
}
|
||||
|
||||
// Verify CollNet setup across ranks after trying all channels
|
||||
NCCLCHECKGOTO(ncclTransportCollNetCheck(comm, collNetSetupFail), ret, fail);
|
||||
TRACE(NCCL_INIT, "rank %d Connected inter-node CollNet", rank);
|
||||
@ -1306,6 +1353,8 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
|
||||
comm->cudaArch = cudaArch;
|
||||
comm->commHash = getHash(job->commId.internal, NCCL_UNIQUE_ID_BYTES);
|
||||
|
||||
INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d nvmlDev %d busId %lx commId 0x%llx - Init START", comm, comm->rank, comm->nRanks, comm->cudaDev, comm->nvmlDev, comm->busId, (unsigned long long)hashUniqueId(job->commId));
|
||||
|
||||
NCCLCHECKGOTO(initTransportsRank(comm, job->parent), res, fail);
|
||||
|
||||
// update communicator state
|
||||
@ -1323,7 +1372,7 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
|
||||
}
|
||||
|
||||
|
||||
INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx commId 0x%llx - Init COMPLETE", comm, comm->rank, comm->nRanks, comm->cudaDev, comm->busId, (unsigned long long)hashUniqueId(job->commId));
|
||||
INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d nvmlDev %d busId %lx commId 0x%llx - Init COMPLETE", comm, comm->rank, comm->nRanks, comm->cudaDev, comm->nvmlDev, comm->busId, (unsigned long long)hashUniqueId(job->commId));
|
||||
exit:
|
||||
if (job->newcomm) {
|
||||
/* assign it to user pointer. */
|
||||
@ -1952,6 +2001,7 @@ ncclResult_t ncclCommSplit(ncclComm_t comm, int color, int key, ncclComm_t *newc
|
||||
if (comm->config.splitShare) {
|
||||
childComm->abortFlag = comm->abortFlag;
|
||||
childComm->abortFlagRefCount = comm->abortFlagRefCount;
|
||||
comm->childAbortFlag = NULL;
|
||||
ncclAtomicRefCountIncrement(comm->abortFlagRefCount);
|
||||
} else {
|
||||
NCCLCHECKGOTO(ncclCudaHostCalloc((uint32_t**)&childComm->abortFlag, 1), res, fail);
|
||||
|
@ -32,7 +32,7 @@ static void shmHandleInit(int fd, char* shmPath, size_t shmSize, size_t realShmS
|
||||
handle->devShmPtr = dptr;
|
||||
handle->shmSize = shmSize;
|
||||
handle->realShmSize = realShmSize;
|
||||
handle->refcount = (int*)(hptr + shmSize);
|
||||
handle->refcount = (hptr != NULL) ? (int*)(hptr + shmSize) : NULL;
|
||||
if (create) {
|
||||
int slen = strlen(shmPath);
|
||||
handle->shmPath = (char*)malloc(slen + 1);
|
||||
@ -81,6 +81,7 @@ ncclResult_t ncclShmOpen(char* shmPath, size_t shmSize, void** shmPtr, void** de
|
||||
if (hptr == MAP_FAILED) {
|
||||
WARN("Could not map %s size %zi, error: %s", shmPath, realShmSize, strerror(errno));
|
||||
ret = ncclSystemError;
|
||||
hptr = NULL;
|
||||
goto fail;
|
||||
}
|
||||
|
||||
@ -125,7 +126,7 @@ ncclResult_t ncclShmClose(ncclShmHandle_t handle) {
|
||||
if (tmphandle) {
|
||||
if (tmphandle->fd >= 0) {
|
||||
close(tmphandle->fd);
|
||||
if (tmphandle->shmPath != NULL && *tmphandle->refcount > 0) {
|
||||
if (tmphandle->shmPath != NULL && tmphandle->refcount != NULL && *tmphandle->refcount > 0) {
|
||||
if (unlink(tmphandle->shmPath) != 0) {
|
||||
WARN("unlink shared memory %s failed, error: %s", tmphandle->shmPath, strerror(errno));
|
||||
ret = ncclSystemError;
|
||||
|
@ -421,6 +421,9 @@ static ncclResult_t socketFinalizeAccept(struct ncclSocket* sock) {
|
||||
uint64_t magic;
|
||||
enum ncclSocketType type;
|
||||
int received = 0;
|
||||
const int one = 1;
|
||||
SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt");
|
||||
|
||||
NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received));
|
||||
if (received == 0) return ncclSuccess;
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received));
|
||||
|
@ -1474,10 +1474,11 @@ void* ncclProxyService(void* _args) {
|
||||
// Progress all ops for this ncclProxyLocalPeer
|
||||
ncclProxyAsyncOp* op = peer->asyncOps;
|
||||
while (op != nullptr) {
|
||||
ncclProxyAsyncOp* opnext = op->next; /* in case op is freed in proxyProgressAsync */
|
||||
type = op->type;
|
||||
res = proxyProgressAsync(op, proxyState, &asyncOpCount, peer, &connectionPool);
|
||||
if (res == ncclSuccess || res == ncclInProgress) {
|
||||
op = op->next;
|
||||
op = opnext;
|
||||
} else {
|
||||
// Res is a bad result
|
||||
closeConn = 1;
|
||||
|
@ -148,14 +148,12 @@ struct setupReq {
|
||||
/* Setup send connector, and return connect information for others in the coll
|
||||
* communicator to connect to me */
|
||||
static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId, int connIndex) {
|
||||
struct setupReq req;
|
||||
struct setupReq req = { 0 };
|
||||
|
||||
int proxyRank, tpProxyRank;
|
||||
NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &req.netDev, &proxyRank));
|
||||
NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 1, &req.useGdr));
|
||||
send->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0;
|
||||
// Determine whether we need to flush the GDR buffer on recv or not
|
||||
if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm->topo, myInfo->busId, &req.needFlush));
|
||||
|
||||
NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &send->proxyConn.tpLocalRank));
|
||||
tpProxyRank = comm->topParentRanks[myInfo->rank];
|
||||
@ -170,12 +168,14 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph
|
||||
}
|
||||
|
||||
static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId, int connIndex) {
|
||||
struct setupReq req;
|
||||
struct setupReq req = { 0 };
|
||||
|
||||
int proxyRank, tpProxyRank;
|
||||
NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &req.netDev, &proxyRank));
|
||||
NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 0, &req.useGdr));
|
||||
recv->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0;
|
||||
// Determine whether we need to flush the GDR buffer on recv or not
|
||||
if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm->topo, myInfo->busId, &req.needFlush));
|
||||
|
||||
NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &recv->proxyConn.tpLocalRank));
|
||||
tpProxyRank = comm->topParentRanks[myInfo->rank];
|
||||
|
@ -162,7 +162,7 @@ struct setupReq {
|
||||
/* Determine if we will use this transport for this peer and return connect
|
||||
* information for this peer */
|
||||
static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId, int connIndex) {
|
||||
struct setupReq req;
|
||||
struct setupReq req = { 0 };
|
||||
int localRank, tpProxyRank;
|
||||
|
||||
send->conn.shared = req.shared = graph ? 0 : ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : 1;
|
||||
@ -183,10 +183,10 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph
|
||||
NCCLCHECK(ncclProxyCallBlocking(comm, &send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0));
|
||||
|
||||
if (proxyRank == myInfo->rank) {
|
||||
INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [send] via NET/%s/%d%s%s", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, comm->ncclNet->name, req.netDev,
|
||||
INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%d] -> %d[%d] [send] via NET/%s/%d%s%s", channelId, connIndex, myInfo->rank, myInfo->nvmlDev, peerInfo->rank, peerInfo->nvmlDev, comm->ncclNet->name, req.netDev,
|
||||
req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : "");
|
||||
} else {
|
||||
INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [send] via NET/%s/%d(%d)%s%s", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, comm->ncclNet->name, req.netDev,
|
||||
INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%d] -> %d[%d] [send] via NET/%s/%d(%d)%s%s", channelId, connIndex, myInfo->rank, myInfo->nvmlDev, peerInfo->rank, peerInfo->nvmlDev, comm->ncclNet->name, req.netDev,
|
||||
proxyRank, req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : "");
|
||||
}
|
||||
*((int*)connectInfo) = tpProxyRank;
|
||||
@ -200,7 +200,7 @@ NCCL_PARAM(GdrCopyFlushEnable, "GDRCOPY_FLUSH_ENABLE", 0);
|
||||
|
||||
/* Setup recv connector */
|
||||
static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId, int connIndex) {
|
||||
struct setupReq req;
|
||||
struct setupReq req = { 0 };
|
||||
int localRank;
|
||||
|
||||
recv->conn.shared = req.shared = graph ? 0 : ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : 1;
|
||||
@ -224,7 +224,7 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph
|
||||
req.tpRank = comm->topParentRanks[myInfo->rank];
|
||||
req.tpRemoteRank = comm->topParentRanks[peerInfo->rank];
|
||||
NCCLCHECK(ncclProxyCallBlocking(comm, &recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), connectInfo, sizeof(ncclNetHandle_t)));
|
||||
INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [receive] via NET/%s/%d%s%s", channelId, connIndex, peerInfo->rank, peerInfo->busId, myInfo->rank, myInfo->busId, comm->ncclNet->name, req.netDev,
|
||||
INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%d] -> %d[%d] [receive] via NET/%s/%d%s%s", channelId, connIndex, peerInfo->rank, peerInfo->nvmlDev, myInfo->rank, myInfo->nvmlDev, comm->ncclNet->name, req.netDev,
|
||||
req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : "");
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
@ -111,7 +111,7 @@ static ncclResult_t ncclIbGetPciPath(char* devName, char** path, int* realPort)
|
||||
// Merge multi-port NICs into the same PCI device
|
||||
p[strlen(p)-1] = '0';
|
||||
// Also merge virtual functions (VF) into the same device
|
||||
if (ncclParamIbMergeVfs()) p[strlen(p)-3] = '0';
|
||||
if (ncclParamIbMergeVfs()) p[strlen(p)-3] = p[strlen(p)-4] = '0';
|
||||
// And keep the real port aside (the ibv port is always 1 on recent cards)
|
||||
*realPort = 0;
|
||||
for (int d=0; d<ncclNIbDevs; d++) {
|
||||
@ -795,7 +795,8 @@ ib_recv:
|
||||
if (ncclParamIbUseInline()) rComm->remFifo.flags = IBV_SEND_INLINE;
|
||||
|
||||
// Allocate Flush dummy buffer for GPU Direct RDMA
|
||||
rComm->gpuFlush.enabled = (ncclIbGdrSupport(lComm->dev) == 0) && (ncclParamIbGdrFlushDisable() == 0) ? 1 : 0;
|
||||
rComm->gpuFlush.enabled = ((ncclIbGdrSupport(lComm->dev) == ncclSuccess || ncclIbDmaBufSupport(lComm->dev) == ncclSuccess)
|
||||
&& (ncclParamIbGdrFlushDisable() == 0)) ? 1 : 0;
|
||||
if (rComm->gpuFlush.enabled) {
|
||||
NCCLCHECK(wrap_ibv_reg_mr(&rComm->gpuFlush.hostMr, rComm->verbs.pd, &rComm->gpuFlush.hostMem, sizeof(int), IBV_ACCESS_LOCAL_WRITE));
|
||||
rComm->gpuFlush.sge.addr = (uint64_t)&rComm->gpuFlush.hostMem;
|
||||
|
@ -285,7 +285,6 @@ ncclResult_t ncclNvlsSetup(struct ncclComm* comm, struct ncclComm* parent) {
|
||||
ncclAtomicRefCountIncrement(&parent->nvlsResources->refCount);
|
||||
} else {
|
||||
int nChannels;
|
||||
ncclResult_t res = ncclSuccess;
|
||||
struct ncclNvlsSharedRes* resources;
|
||||
|
||||
NCCLCHECK(ncclCalloc(&resources, 1));
|
||||
|
@ -352,28 +352,28 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st
|
||||
if (myInfo->pidHash == peerInfo->pidHash && ncclParamP2pDirectDisable() == 0 && useMemcpy == 0 && !ncclCuMemEnable()) {
|
||||
resources->type = P2P_DIRECT;
|
||||
send->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE;
|
||||
INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s",
|
||||
channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr);
|
||||
INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%d] -> %d[%d] via P2P/direct pointer%s",
|
||||
channelId, connIndex, myInfo->rank, myInfo->nvmlDev, peerInfo->rank, peerInfo->nvmlDev, useReadStr);
|
||||
} else {
|
||||
// cuMem API support
|
||||
if (ncclCuMemEnable()) {
|
||||
resources->type = P2P_CUMEM;
|
||||
INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%x] -> %d[%x] via P2P/CUMEM%s%s",
|
||||
channelId, connIndex, myInfo->rank, myInfo->cudaDev, peerInfo->rank, peerInfo->cudaDev, useReadStr, useMemcpy ? "/CE" : "");;
|
||||
INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%d] -> %d[%d] via P2P/CUMEM%s%s",
|
||||
channelId, connIndex, myInfo->rank, myInfo->nvmlDev, peerInfo->rank, peerInfo->nvmlDev, useReadStr, useMemcpy ? "/CE" : "");;
|
||||
} else {
|
||||
// Legacy CUDA IPC
|
||||
resources->type = P2P_IPC;
|
||||
INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/IPC%s%s",
|
||||
channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, useMemcpy ? "/CE" : "");
|
||||
INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%d] -> %d[%d] via P2P/IPC%s%s",
|
||||
channelId, connIndex, myInfo->rank, myInfo->nvmlDev, peerInfo->rank, peerInfo->nvmlDev, useReadStr, useMemcpy ? "/CE" : "");
|
||||
}
|
||||
send->conn.flags |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE;
|
||||
}
|
||||
} else {
|
||||
resources->type = P2P_INTERMEDIATE;
|
||||
info->rank = intermediateRank;
|
||||
INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/indirect/%d[%lx]%s",
|
||||
channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, intermediateRank,
|
||||
comm->peerInfo[intermediateRank].busId, useReadStr);
|
||||
INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%d] -> %d[%d] via P2P/indirect/%d[%d]%s",
|
||||
channelId, connIndex, myInfo->rank, myInfo->nvmlDev, peerInfo->rank, peerInfo->nvmlDev, intermediateRank,
|
||||
comm->peerInfo[intermediateRank].nvmlDev, useReadStr);
|
||||
}
|
||||
|
||||
tpProxyRank = comm->topParentRanks[info->rank];
|
||||
@ -421,7 +421,7 @@ ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st
|
||||
// cuMem API support
|
||||
resources->type = P2P_CUMEM;
|
||||
TRACE(NCCL_INIT|NCCL_P2P,"Ring %02d : %d[%d] <- %d[%d] via P2P/CUMEM",
|
||||
channelId, myInfo->rank, myInfo->cudaDev, peerInfo->rank, peerInfo->cudaDev);
|
||||
channelId, myInfo->rank, myInfo->nvmlDev, peerInfo->rank, peerInfo->nvmlDev);
|
||||
} else {
|
||||
// Legacy CUDA IPC
|
||||
resources->type = P2P_IPC;
|
||||
|
@ -92,7 +92,7 @@ static ncclResult_t shmSendSetup(struct ncclComm* comm, struct ncclTopoGraph* gr
|
||||
TRACE(NCCL_SHM,"Opened shmName %s shmSize %d", shmPath, info->shmSize);
|
||||
memcpy(info->shmName, shmPath+sizeof("/dev/shm/nccl-")-1, sizeof(info->shmName));
|
||||
|
||||
INFO(NCCL_INIT|NCCL_SHM,"Channel %02d : %d[%lx] -> %d[%lx] via SHM/%s/%s", channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useMemcpySend?"CE":"direct", useMemcpyRecv?"CE":"direct");
|
||||
INFO(NCCL_INIT|NCCL_SHM,"Channel %02d : %d[%d] -> %d[%d] via SHM/%s/%s", channelId, myInfo->rank, myInfo->nvmlDev, peerInfo->rank, peerInfo->nvmlDev, useMemcpySend?"CE":"direct", useMemcpyRecv?"CE":"direct");
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user