/*************************************************************************
 * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#include "devcomm.h"
#include "collectives.h"
#include "primitives.h"

namespace {
  // Ring broadcast: the root injects its buffer into the ring; every other
  // rank receives from its predecessor, copies into its output buffer, and
  // forwards to its successor, except the rank immediately before the root,
  // which only receives.
  template<typename T, typename RedOp, typename Proto>
  __device__ __forceinline__ void runRing(ncclWorkElem *args) {
    const int tid = threadIdx.x;
    const int nthreads = args->nThreads;
    const int bid = args->coll.bid;
    const int nChannels = args->coll.nChannels;
    ncclRing *ring = &ncclShmem.channel.ring;
    const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? BROADCAST_CHUNKSTEPS : 1));
    const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T)));
    const ssize_t loopSize = nChannels*chunkSize;
    const ssize_t size = args->coll.count;
    const int rank = ring->devUserRanks[0];
    const int nextRank = ring->devUserRanks[1];
    const int root = args->coll.root;

    T *inputBuf = (T*)args->sendbuff;
    T *outputBuf = (T*)args->recvbuff;
    Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
      prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->coll.redOpArg);

    for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
      ssize_t realChunkSize;
      if (Proto::Id == NCCL_PROTO_SIMPLE) {
        realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels));
        realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T));
      }
      else if (Proto::Id == NCCL_PROTO_LL)
        realChunkSize = size-gridOffset < loopSize ? args->coll.lastChunkSize : chunkSize;
      else if (Proto::Id == NCCL_PROTO_LL128)
        realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128);
      realChunkSize = int(realChunkSize);

      ssize_t offset = gridOffset + int(bid*realChunkSize);
      int nelem = min(realChunkSize, size-offset);

      if (rank == root) {
        if (inputBuf == outputBuf) {
          // In-place broadcast: the root only needs to send.
          prims.send(offset, nelem);
        } else {
          // Out-of-place: copy to the root's own output while sending.
          prims.copySend(offset, offset, nelem);
        }
      } else if (nextRank == root) {
        // Last rank in the ring before the root: receive only, no forwarding.
        prims.recv(offset, nelem);
      } else {
        prims.recvCopySend(offset, nelem);
      }
    }
  }
}

template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    using Proto = ProtoSimple<BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS>;
    runRing<T, RedOp, Proto>(args);
  }
};

template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    runRing<T, RedOp, ProtoLL>(args);
  }
};

template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    runRing<T, RedOp, ProtoLL128>(args);
  }
};