Heavy code refactoring to remove a lot of code in collectives (~1000 lines).

Have all collectives use the same args, the same ring, and the same synchronization primitives between threads, following the same pattern.
This commit is contained in:
Sylvain Jeaugey 2016-09-22 11:57:56 -07:00
parent e3dbc6110e
commit cabd6848e4
15 changed files with 1441 additions and 2273 deletions
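To illustrate the unified pattern, here is a small host-side sketch (not part of the commit; the userRank values and the printed wording are made up for illustration) of the per-step block schedule used by the new AllGather and AllReduce kernels in this diff. As in the new kernels, ring.userRank[0] is the calling rank itself, userRank[1] its downstream neighbour and userRank[nranks-1] its upstream neighbour; the index formulas are copied from the AllGatherKernel and AllReduceKernel below.

// Illustrative only: prints the order in which one rank touches the blocks.
#include <cstdio>

int main() {
  const int nranks = 4;
  const int userRank[4] = {0, 1, 2, 3};  // hypothetical ring, seen from rank 0

  // AllGather: push own block, forward k-2 received blocks, keep the last one.
  printf("AllGather: send block %d", userRank[0]);
  for (int j = 1; j < nranks - 1; ++j)
    printf(", forward block %d", userRank[nranks - j]);
  printf(", receive block %d\n", userRank[1]);

  // AllReduce: k-1 reduce steps, then k-1 steps circulating the reduced blocks.
  printf("AllReduce: push slice %d", userRank[nranks - 1]);
  for (int j = 2; j < nranks; ++j)
    printf(", reduce slice %d", userRank[nranks - j]);
  printf(", reduce+copy slice %d", userRank[0]);
  for (int j = 1; j < nranks - 1; ++j)
    printf(", forward slice %d", userRank[nranks - j]);
  printf(", receive slice %d\n", userRank[1]);
  return 0;
}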

View File

@@ -7,7 +7,9 @@
CUDA_HOME ?= /usr/local/cuda
PREFIX ?= /usr/local
VERBOSE ?= 0
KEEP ?= 0
DEBUG ?= 0
PROFAPI ?= 0
BUILDDIR ?= build
CUDA_LIB ?= $(CUDA_HOME)/lib64
@@ -39,10 +41,17 @@ else
.SILENT:
endif
ifneq ($(KEEP), 0)
NVCUFLAGS += -keep
endif
ifneq ($(PROFAPI), 0)
CXXFLAGS += -DPROFAPI
endif
NCCL_MAJOR := 1
NCCL_MINOR := 2
NCCL_PATCH := 3
NCCL_MINOR := 3
NCCL_PATCH := 0
CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH)
CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev)
@@ -50,7 +59,7 @@ CUDA_MAJOR = $(shell echo $(CUDA_VERSION) | cut -d "." -f 1)
CUDA_MINOR = $(shell echo $(CUDA_VERSION) | cut -d "." -f 2)
CXXFLAGS += -DCUDA_MAJOR=$(CUDA_MAJOR) -DCUDA_MINOR=$(CUDA_MINOR)
.PHONY : lib clean debclean test mpitest install
.PHONY : lib clean test mpitest install deb debian debclean
.DEFAULT : lib
INCEXPORTS := nccl.h
@@ -103,6 +112,7 @@ install : lib
cp -P -v $(BUILDDIR)/lib/* $(PREFIX)/lib/
cp -v $(BUILDDIR)/include/* $(PREFIX)/include/
#### TESTS ####
TEST_ONLY ?= 0
@@ -132,7 +142,7 @@ MPITESTBINS:= $(patsubst %, $(MPITSTDIR)/%, $(MPITESTS))
test : $(TESTBINS)
$(TSTDIR)/% : test/single/%.cu $(TSTDEP)
$(TSTDIR)/% : test/single/%.cu test/include/*.h $(TSTDEP)
@printf "Building %-25s > %-24s\n" $< $@
mkdir -p $(TSTDIR)
$(NVCC) $(TSTINC) $(NVCUFLAGS) --compiler-options "$(CXXFLAGS)" -o $@ $< $(TSTLIB) -lcuda -lcurand -lnvToolsExt

View File

@@ -1,479 +1,203 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
#include <algorithm>
#include <cassert>
#include "core.h"
#include "common_kernel.h"
#include "copy_kernel.h"
#include "enqueue.h"
#include "primitives.h"
/* HIERARCHY
*
* The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
* SUBCHUNKS, where each SUBCHUNK is processed independently. A SUBCHUNK is
* split into numUnroll UNROLLS and each thread performs UNROLL_COUNT
* single-data-element operations inside an UNROLL. As the name suggests, the
* UNROLL_COUNT operations within an UNROLL are unrolled.
*/
#define NUM_SUBSTEPS 2
#define NUM_BUFCHUNKS 2
// Increase Step and poffset/noffset for buffer sync
#define NEXT_STEP \
step++; \
poffset = noffset; \
noffset += sliceSize; \
if (noffset == buffSize) noffset = 0;
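// Worked trace of NEXT_STEP (illustrative): with NUM_BUFCHUNKS=2 the
// communication buffer holds buffSize = 2*sliceSize elements, so noffset
// ping-pongs between the two halves while poffset trails it by one step:
//   start    : step=0,                    noffset=0
//   NEXT_STEP: step=1, poffset=0,         noffset=sliceSize
//   NEXT_STEP: step=2, poffset=sliceSize, noffset=0
//   NEXT_STEP: step=3, poffset=0,         noffset=sliceSize
// poffset is the half just received from the previous rank, noffset the half
// about to be sent to the next rank.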
// Number of threads used to perform copies, etc. Must be multiple of 32.
// An additional thread is used to handle threadfences, so the CUDA blocks
// have dimension NUM_THREADS+1.
#define NUM_THREADS 256
#define ALIGN_SIZE(size, align) \
size = ((size + (align) - 1) / (align)) * (align);
// Each thread unrolls the innermost loop of the copy or reduction operations
// to this many single-data-element instructions
#define UNROLL_COUNT 8
template<int THREADS, int UNROLL, class FUNC, typename T>
__launch_bounds__(THREADS+WARP_SIZE, 1)
__global__ void AllGatherKernel(const KernelArgs<T> args) {
const int tid = threadIdx.x;
__shared__ T* sharedNextOutput;
__shared__ DevRing<T> ring;
bool pushrecv = args.pushrecv;
#define UNROLL_SIZE (UNROLL_COUNT * NUM_THREADS)
LoadRing<THREADS>(args.ring, &ring);
__syncthreads();
// To hide the latency associated with the synchronization between different
// subchunks, we interleave the independent subchunks so that more data can be
// transferred while the sync is in progress. This is the number of subchunks
// that are active at the same time
#define NUM_SUBCHUNKS 2
// If this is called with STEP, it means that we just finished processing the
// data for step STEP on this GPU, which is the data required on the next GPU
// for step STEP + 1, so we signal the next GPU that its data for step STEP + 1
// is available. This is called by one particular consumer warp and so we select
// the first thread in the warp to set the flag.
#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk, step) \
do { \
__threadfence_system(); \
args.NextNewDataAvailableFlag[0] = \
NUM_SUBCHUNKS*((chunk) * (args.NumGPUs - 1) + (step)) + subchunk+1; \
} while (0)
// This is called by all producer threads, but only thread 0 spins on the flag,
// all threads synchronize after thread 0 is done spinning.
#define WAIT_FOR_NEW_DATA(chunk, subchunk, step) \
do { \
if (tid == 0) { \
Wait([=] { \
return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
NUM_SUBCHUNKS*((chunk) * (args.NumGPUs - 1) + (step)) \
+ subchunk + 1 - NUM_SUBCHUNKS; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
#define SIGNAL_CHUNK_DONE(chunk, subchunk) \
do { \
__threadfence_system(); \
args.PrevChunkDoneFlag[0] = NUM_SUBCHUNKS*(chunk) + (subchunk) + 1; \
} while (0)
#define WAIT_FOR_PREV_CHUNK(chunk, subchunk) \
do { \
if (tid == 0) { \
Wait([=] { \
return ((volatile int*)args.ThisChunkDoneFlag)[0] >= \
NUM_SUBCHUNKS*(chunk) + subchunk + 1-NUM_SUBCHUNKS; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
__device__ inline void getSliceSizeAndChunkSize(int *sliceSize, int slice,
int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
int smallSliceN, int lastSliceN) {
if (slice < numBigSlices) {
*sliceSize = bigSliceN;
} else {
*sliceSize = (slice < numBigSlices + numSmallSlices) ? smallSliceN
: ((slice == numSlices - 1) ? lastSliceN : 0);
}
}
template<typename T>
struct AllGatherKernelArgs {
// general parameters
int ThisId;
int NumGPUs;
int N;
int * UserFromRing;
// some pre-computed sizes
int SliceSize;
int ChunkSize;
int NumChunks;
int BufferSliceStride;
int BufferMisalignedN;
T ** ThisPtrToNextOutput;
T ** PrevPtrToThisOutput;
// local and remote input, output, and buffer
const T * __restrict__ ThisInput;
volatile T * __restrict__ ThisOutput;
volatile T * __restrict__ ThisBuffer;
volatile T * __restrict__ NextBuffer;
// local and remote flags
volatile int * __restrict__ ThisNewDataAvailableFlag;
volatile int * __restrict__ NextNewDataAvailableFlag;
volatile int * __restrict__ ThisChunkDoneFlag;
volatile int * __restrict__ PrevChunkDoneFlag;
};
__device__ inline int GetBlock(const int index, const int step,
const int * const userFromRing, const int numGPUs) {
return userFromRing[(numGPUs + index - step) % numGPUs];
}
__shared__ volatile void * nextOutput;
template<int THREADS, int UNROLL, bool PUSHRECV, typename T>
__global__ void AllGatherKernel(const AllGatherKernelArgs<T> args) {
if (args.N == 0) return;
int tid = threadIdx.x;
// First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
// the previous GPU is done with a previous collective operation.
if (tid == 0) {
WaitFlag prevCommOp(ring.prevOpCounter, 0);
WaitFlag nextCommOp(ring.nextOpCounter, 0);
prevCommOp.wait(args.opIndex);
nextCommOp.wait(args.opIndex);
if (pushrecv) {
*ring.sendPtrToPrev = (T*)args.ThisOutput;
Wait([=] {
return *((T * volatile *)args.PrevPtrToThisOutput) == nullptr;
return *ring.recvPtrFromNext != nullptr;
});
*((T * volatile *)args.PrevPtrToThisOutput) = (T*)args.ThisOutput;
Wait([=] {
return *((T * volatile *)args.ThisPtrToNextOutput) != nullptr;
});
if(PUSHRECV)
nextOutput = *((volatile void * volatile *)args.ThisPtrToNextOutput);
sharedNextOutput = *ring.recvPtrFromNext;
*ring.recvPtrFromNext = nullptr;
}
}
__syncthreads();
for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
// calculate slice size. for all chunks except (possibly) the last one,
// this will just be args.SliceSize. For the last one, it may be smaller
int bigSliceN = args.SliceSize;
int smallSliceN = 0;
int lastSliceN = 0;
int numSlices = NUM_SUBCHUNKS;
int numBigSlices = numSlices;
int numSmallSlices = 0;
WaitFlag waitDoneFromNext(ring.recvFlagFromNext, -NUM_BUFCHUNKS*NUM_SUBSTEPS);
WaitFlag waitReadyFromPrev(ring.recvFlagFromPrev, -1*NUM_SUBSTEPS);
PostFlag postDoneToPrev(ring.sendFlagToPrev, -1*NUM_SUBSTEPS);
PostFlag postReadyToNext(ring.sendFlagToNext, 0);
// last chunk
if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
&numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
args.ChunkSize);
typedef Primitives<THREADS, UNROLL, NUM_SUBSTEPS, T> Prims;
// this offset is only applied to Data pointers, not to Buffer pointers,
// since we only have one buffer per chunk
int chunkOffset = chunk * args.ChunkSize;
const int size = args.N;
const int nranks = args.nRanks;
const int buffSize = args.buffSize / sizeof(T);
const int sliceSize = buffSize / NUM_BUFCHUNKS;
// step 0: copy the resident block from ThisInput to ThisOutput and also
// to NextOutput
int step = 0;
int block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
int outputOffset = chunkOffset + block * args.N;
int inputOffset = chunkOffset;
int bufferOffset;
int sliceSize;
int poffset, noffset = 0;
if (!PUSHRECV) {
bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
block * args.BufferMisalignedN;
}
// Compute pointers
const T * __restrict__ thisInput = args.ThisInput;
T * __restrict__ thisOutput = args.ThisOutput;
T * __restrict__ prevInput = ring.recvBuffer;
T * __restrict__ nextOutput = ring.sendBuffer;
// Copy from ThisInput
if (tid < THREADS) {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
for (int chunkOffset = 0; chunkOffset < size; chunkOffset += sliceSize) {
/////////////// begin AllGather steps ///////////////
int offset;
int maxOffset = size-chunkOffset;
int rankDest;
if (!PUSHRECV)
WAIT_FOR_PREV_CHUNK(chunk, s);
// step 0: push data to next GPU
rankDest = ring.userRank[0];
offset = chunkOffset + rankDest * size;
if (PUSHRECV) {
DoubleCopy<UNROLL, THREADS>(
args.ThisOutput + outputOffset,
(volatile T *)nextOutput + outputOffset,
args.ThisInput + inputOffset,
sliceSize);
if (thisInput == thisOutput) {
Prims::Copy(
thisInput + offset,
pushrecv ? sharedNextOutput + offset : nextOutput + noffset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
} else {
DoubleCopy<UNROLL, THREADS>(
args.ThisOutput + outputOffset,
args.NextBuffer + bufferOffset,
args.ThisInput + inputOffset,
sliceSize);
Prims::DoubleCopy(
thisInput + chunkOffset,
thisOutput + offset,
pushrecv ? sharedNextOutput + offset : nextOutput + noffset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
}
__syncthreads();
outputOffset += sliceSize;
inputOffset += sliceSize;
if (!PUSHRECV)
bufferOffset += sliceSize;
NEXT_STEP; // Increases step, poffset, noffset
// k-2 steps: copy to next GPU
if (pushrecv) {
for (int j=1; j<nranks-1; ++j) {
rankDest = ring.userRank[nranks-j];
offset = chunkOffset + rankDest * size;
Prims::Copy(
thisOutput + offset,
sharedNextOutput + offset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
NEXT_STEP;
}
} else {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
}
}
for (int j=1; j<nranks-1; ++j) {
rankDest = ring.userRank[nranks-j];
offset = chunkOffset + rankDest * size;
// steps j with 0 < j < k - 1:
// copy a block that was pushed to this GPU to the next GPU
for (step = 1; step < args.NumGPUs - 1; ++step) {
block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
outputOffset = chunkOffset + block * args.N;
if (!PUSHRECV) {
bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
block * args.BufferMisalignedN;
}
Prims::DoubleCopy(
prevInput + poffset,
thisOutput + offset,
nextOutput + noffset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
if (tid < THREADS) {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
WAIT_FOR_NEW_DATA(chunk, s, step);
if (PUSHRECV) {
Copy<UNROLL, THREADS>(
(volatile T *)nextOutput + outputOffset,
args.ThisOutput + outputOffset,
sliceSize);
} else {
DoubleCopy<UNROLL, THREADS>(
args.NextBuffer + bufferOffset,
args.ThisOutput + outputOffset,
args.ThisBuffer + bufferOffset,
sliceSize);
NEXT_STEP;
}
__syncthreads();
outputOffset += sliceSize;
if (!PUSHRECV)
bufferOffset += sliceSize;
}
} else {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
}
}
}
if (!PUSHRECV) {
step = args.NumGPUs - 1;
block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
outputOffset = chunkOffset + block * args.N;
bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
block * args.BufferMisalignedN;
// Make final copy from buffer to dest.
if (tid < THREADS) {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
WAIT_FOR_NEW_DATA(chunk, s, step);
rankDest = ring.userRank[1];
offset = chunkOffset + rankDest * size;
Copy<UNROLL, THREADS>(
args.ThisOutput + outputOffset,
args.ThisBuffer + bufferOffset,
sliceSize);
// Here we need to copy from buffer to this output.
Prims::Copy(
prevInput + poffset,
thisOutput + offset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
__syncthreads();
outputOffset += sliceSize;
bufferOffset += sliceSize;
}
} else {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
SIGNAL_CHUNK_DONE(chunk, s);
}
}
NEXT_STEP;
}
}
// wait for the last data to be pushed to us
if (tid < THREADS) {
if (PUSHRECV)
WAIT_FOR_NEW_DATA(args.NumChunks, NUM_SUBCHUNKS-1, 0);
else
WAIT_FOR_PREV_CHUNK(args.NumChunks, NUM_SUBCHUNKS-1);
if (tid == 0) {
args.ThisNewDataAvailableFlag[0] = 0;
args.ThisChunkDoneFlag[0] = 0;
*args.ThisPtrToNextOutput = nullptr;
}
// Wait for last update from next then reset the flag
waitDoneFromNext.wait(NUM_SUBSTEPS*(step+NUM_BUFCHUNKS-1));
*ring.recvFlagFromNext = 0;
// Wait for last update from prev then reset the flag
waitReadyFromPrev.wait(NUM_SUBSTEPS*(step+1));
*ring.recvFlagFromPrev = 0;
incrementOpCounter(&args);
}
}
template<typename T>
ncclResult_t ncclAllGatherWithType(const void* sendbuff, void* recvbuff,
int count, ncclComm* comm, int numUnroll, cudaStream_t stream) {
#define THREADS 384
#define UNROLL 8
template<class FUNC, typename T>
ncclResult_t RingAllGather(const void* sendbuff, void* recvbuff,
const int count, ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
int index = comm->ncclId;
int blockSizeInBytes = count * sizeof(T);
int misalignedBytes = blockSizeInBytes % alignof(uint64_t);
assert((int)((misalignedBytes / sizeof(T)) * sizeof(T)) == misalignedBytes);
int misalignedN = misalignedBytes / sizeof(T);
assert(misalignedN < (int)(sizeof(uint64_t) / sizeof(T)));
int paddingN = (misalignedN > 0) ? sizeof(uint64_t) / sizeof(T) : 0;
// There is one slice per GPU, so a slice can be at most bufferN / numGPUs,
// where bufferN is the number of elements of type T that fit into the buffer.
int bufferN = comm->buffSize / sizeof(T);
// we only need buffer for k slices and k paddings
int bufferNPerSlice = (bufferN - comm->nDev * NUM_SUBCHUNKS * paddingN)
/ (comm->nDev * NUM_SUBCHUNKS);
// For efficiency, we want the slice size to be a multiple of UNROLL_SIZE
int maxSliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;
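// Worked example with hypothetical numbers (not taken from the commit): for
// T = float with bufferN = 1,000,000 elements, comm->nDev = 4,
// NUM_SUBCHUNKS = 2 and no padding (paddingN = 0):
//   bufferNPerSlice = 1,000,000 / (4 * 2) = 125,000
//   UNROLL_SIZE     = UNROLL_COUNT * NUM_THREADS = 8 * 256 = 2,048
//   maxSliceSize    = (125,000 / 2,048) * 2,048 = 61 * 2,048 = 124,928
// i.e. the slice size is rounded down to a multiple of UNROLL_SIZE.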
int nextId = (index + 1) % comm->nDev;
int prevId = (index + comm->nDev - 1) % comm->nDev;
AllGatherKernelArgs<T> args;
args.ThisId = index;
args.NumGPUs = comm->nDev;
args.N = count;
/* Block j is coming from sendbuff[j], which lives on device with logical
* index comm->ringFromUser[j]. But the block ordering does not necessarily
* follow the ring ordering. Hence the order in which a particular GPU
* processes the different blocks (the correspondence between the step in
* the reduction algorithm and the block on which a GPU operates in that
* particular step) is not the same as the ring order.
*
* Say we have 4 GPUs and comm->userFromRing = { 1, 2, 0, 3 }. Then there are 3
steps in the all-gather algorithm and block 0 comes from device 2, block 1
from device 0, block 2 from device 1, and block 3 comes from device 3. In the
* first step of the algorithm, each GPU must copy its own block from its
* sendbuff to the appropriate location in its recvbuff. The blocks that a
GPU has to process in the next steps are determined by the previous step
* because each GPU only hands off data to the next GPU in the ring.
*
* In the above example, we get the following table of which block is
* processed by each GPU in a given step. The columns correspond to the
* different GPUs while the rows are the steps in the algorithm.
*
* GPU 0 1 2 3
* step
* 0 1 2 0 3
* 1 3 1 2 0
* 2 0 3 1 2
*
* We note that the rows in the above table are just comm->userFromRing in the
* first step and the list is cyclically permuted to the right for each next
* step. The columns, which are what the individual GPUs need to know, are
* comm->userFromRing traversed backwards and starting at index k for GPU k.
These columns are what the GPU reads from args.UserFromRing (via GetBlock) to
* tell it which block it needs to be processing at a particular step. */
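// For the example above (4 GPUs, userFromRing = { 1, 2, 0, 3 }), GetBlock
// reproduces the GPU-0 column of the table:
//   step 0: userFromRing[(4 + 0 - 0) % 4] = userFromRing[0] = 1
//   step 1: userFromRing[(4 + 0 - 1) % 4] = userFromRing[3] = 3
//   step 2: userFromRing[(4 + 0 - 2) % 4] = userFromRing[2] = 0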
args.UserFromRing = comm->devUserFromRing;
args.SliceSize = numUnroll * UNROLL_SIZE * sizeof(PackType) / sizeof(T);
args.SliceSize = std::min(maxSliceSize, args.SliceSize);
args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
// don't reduce this if we cut the slice size in half below, because if that
// happens, the last chunk will be larger than the other chunks, and we will
// need the extra buffer space
args.BufferSliceStride = args.SliceSize + paddingN;
args.BufferMisalignedN = misalignedN;
// avoid a case where we have one or more big chunks and one tiny one
int remainder = args.N % args.ChunkSize;
if ((args.N > args.ChunkSize) && (remainder > 0) &&
(args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
args.SliceSize /= 2;
args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
// round down so we end up with a big last chunk
args.NumChunks = args.N / args.ChunkSize;
} else {
// round up
args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
}
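// Worked example of the adjustment above (hypothetical sizes): if ChunkSize
// were 1,000 elements and N = 1,100, then remainder = 100 and all four
// conditions hold, so SliceSize is halved, ChunkSize drops to 500 and
// NumChunks = 1,100 / 500 = 2; the last chunk then covers 600 elements
// instead of leaving a tiny trailing chunk of 100.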
args.ThisPtrToNextOutput = (T**)&(comm->ptrs[nextId].local->recvPtrs[0]);
args.PrevPtrToThisOutput = (T**)&(comm->ptrs[prevId].remote->recvPtrs[0]);
args.ThisInput = (const T*)sendbuff;
args.ThisOutput = (volatile T*)recvbuff;
args.ThisBuffer = (volatile T*)comm->ptrs[prevId].local->buff;
args.NextBuffer = (volatile T*)comm->ptrs[nextId].remote->buff;
args.ThisNewDataAvailableFlag = comm->ptrs[prevId].local->flags;
args.NextNewDataAvailableFlag = comm->ptrs[nextId].remote->flags;
args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
if (comm->nDev == 1) {
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else {
if( comm->useRemoteRecv ) {
AllGatherKernel<NUM_THREADS, UNROLL_COUNT, true, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else {
AllGatherKernel<NUM_THREADS, UNROLL_COUNT, false, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
}
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
LAUNCH_KERNEL(AllGatherKernel, THREADS, UNROLL, FUNC, T, args, stream);
}
return ncclSuccess;
}
class AllGatherFunctor {
public:
ncclResult_t operator()(const void* sendbuff, void* recvbuff,
int count, ncclDataType_t datatype, ncclRedOp_t /*dummy operation*/,
int /*dummy root*/, ncclComm* comm, cudaStream_t stream) {
int numUnroll = 16; // this is optimal on dt07 with 4 GPUs
switch (datatype) {
case ncclChar:
return ncclAllGatherWithType<char>(sendbuff, recvbuff, count, comm,
numUnroll, stream);
case ncclInt:
return ncclAllGatherWithType<int>(sendbuff, recvbuff, count, comm,
numUnroll, stream);
#if CUDART_VERSION >= 7050
case ncclHalf:
return ncclAllGatherWithType<half>(sendbuff, recvbuff, count, comm,
numUnroll, stream);
#endif
case ncclFloat:
return ncclAllGatherWithType<float>(sendbuff, recvbuff, count, comm,
numUnroll, stream);
case ncclDouble:
return ncclAllGatherWithType<double>(sendbuff, recvbuff, count, comm,
numUnroll, stream);
case ncclInt64:
return ncclAllGatherWithType<long long>(sendbuff, recvbuff, count, comm,
numUnroll, stream);
case ncclUint64:
return ncclAllGatherWithType<unsigned long long>(sendbuff, recvbuff, count, comm,
numUnroll, stream);
}
return ncclInvalidType;
template<typename T, template<typename> class RedOp>
class AllGather {
public:
static ncclResult_t entry(const void* sendbuff, void* recvbuff,
int count, int /*root*/, ncclComm* comm, cudaStream_t stream) {
return RingAllGather<RedOp<T>, T>(sendbuff, recvbuff, count, comm, stream);
}
};
extern "C" DSOGLOBAL
NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, int count, ncclDataType_t datatype,
void* recvbuff, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllGather(const void* sendbuff, int count, ncclDataType_t datatype,
void* recvbuff, ncclComm_t comm, cudaStream_t stream) {
return enqueue(AllGatherFunctor(), sendbuff, recvbuff, count, datatype,
ncclSum, 0, comm, stream);
return enqueue<AllGather, FuncNull>(sendbuff, recvbuff, count, datatype, 0, comm, stream);
}

View File

@@ -1,491 +1,233 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
#include "core.h"
#include "common_kernel.h"
#include "copy_kernel.h"
#include "enqueue.h"
#include "reduce_kernel.h"
#include "primitives.h"
/* HIERARCHY
*
* The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
* SUBCHUNKS, where each SUBCHUNK is an independent, complete reduction. Each
* GPU has a buffer that can fit an entire CHUNK, so that all SUBCHUNKS can be
* processed without checking that the buffer on the receiving GPU is empty. A
* SUBCHUNK is split into NUM_GPUS SLICES and each GPU works on a different
SLICE at the same time. Before moving on to the next SLICE in the reduction
* algorithm, the GPU has to check whether it has received the data from the
* previous GPU it needs for this SLICE. To hide the latency of this
* communication, each GPU processes all the SLICES of all the SUBCHUNKS in
* sequence before moving on to the next SLICE. Each SLICE is split into a
* certain number of UNROLLS (determined by the buffer size) and each thread
* performs UNROLL_COUNT single-data-element operations inside an UNROLL. As the
* name suggests, the UNROLL_COUNT operations within an UNROLL are unrolled.
*/
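// Size relations implied by this hierarchy (illustrative): with nDev = 4 GPUs,
// NUM_SUBCHUNKS = 2 and a slice of S elements,
//   SUBCHUNK = nDev * S = 4*S            (one SLICE per GPU)
//   CHUNK    = NUM_SUBCHUNKS * SUBCHUNK = 8*S
// matching subchunkSize = comm->nDev * args.SliceSize and
// args.ChunkSize = NUM_SUBCHUNKS * subchunkSize in the host code below.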
#define NUM_SUBSTEPS 2
#define NUM_BUFCHUNKS 2
// Number of threads used to perform copies, etc. Must be multiple of 32.
// An additional thread is used to handle threadfences, so the CUDA blocks
// have dimension NUM_THREADS+1.
#define NUM_THREADS 256
// Increase Step and poffset/noffset for buffer sync
#define NEXT_STEP \
step++; \
poffset = noffset; \
noffset += sliceSize; \
if (noffset == buffSize) noffset = 0;
// Each thread unrolls the innermost loop of the copy or reduction operations
// to this many single-data-element instructions
#define UNROLL_COUNT 8
#define ALIGN_SIZE(size, align) \
size = ((size + (align) - 1) / (align)) * (align);
#define UNROLL_SIZE (UNROLL_COUNT * NUM_THREADS)
// To hide the latency associated with the synchronization between different
// subchunks, we interleave the independent subchunks so that more data can be
// transferred while the sync is in progress. This is the number of subchunks
// that are active at the same time
#define NUM_SUBCHUNKS 2
// If this is called with STEP, it means that we just finished processing the
// data for step STEP on this GPU, which is the data required on the next GPU
// for step STEP + 1, so we signal the next GPU that its data for step STEP + 1
// is available. This is called by one particular consumer warp and so we select
// the first thread in the warp to set the flag.
#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk, step) \
do { \
__threadfence_system(); \
args.NextNewDataAvailableFlag[0] = \
NUM_SUBCHUNKS*((chunk) * (2 * args.NumGPUs - 2) + (step) + 1)+subchunk; \
} while (0)
// This is called by all producer threads, but only thread 0 spins on the flag,
// all threads synchronize after thread 0 is done spinning.
#define WAIT_FOR_NEW_DATA(chunk, subchunk, step) \
do { \
if (tid == 0) { \
Wait([=] { \
return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
2*((chunk) * (2 * args.NumGPUs - 2) + (step))+subchunk; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
#define SIGNAL_CHUNK_DONE(chunk, subchunk) \
do { \
args.PrevChunkDoneFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
} while (0)
#define WAIT_FOR_CHUNK(chunk, subchunk) \
do { \
if (tid == 0) { \
Wait([=] { \
return ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
NUM_SUBCHUNKS*(chunk) + subchunk + 1 - NUM_SUBCHUNKS; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
__device__ inline void getSliceSizeAndOffset(int *size, int *offset, int slice,
int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
int smallSliceN, int lastSliceN) {
if (slice < numBigSlices) {
*size = bigSliceN;
*offset = slice * bigSliceN;
} else {
*size = (slice < numBigSlices + numSmallSlices) ? smallSliceN
: ((slice == numSlices - 1) ? lastSliceN : 0);
*offset = numBigSlices * bigSliceN + (slice - numBigSlices) * smallSliceN;
}
}
template<typename T>
struct AllReduceKernelArgs {
// general parameters
int ThisId;
int NumGPUs;
int N;
// some pre-computed sizes
int SliceSize;
int ChunkSize;
int NumChunks;
T ** ThisPtrToNextOutput;
T ** PrevPtrToThisOutput;
// local and remote input, output, and buffer
const T * __restrict__ ThisInput;
volatile T * __restrict__ ThisOutput;
volatile T * __restrict__ ThisBuffer;
volatile T * __restrict__ NextBuffer;
// local and remote flags
volatile int * __restrict__ ThisNewDataAvailableFlag;
volatile int * __restrict__ NextNewDataAvailableFlag;
volatile int * __restrict__ ThisChunkDoneFlag;
volatile int * __restrict__ PrevChunkDoneFlag;
};
__shared__ volatile void * nextOutput;
template<int THREADS, int UNROLL, class FUNC, bool PUSHRECV, typename T>
template<int THREADS, int UNROLL, class FUNC, typename T>
__launch_bounds__(THREADS+WARP_SIZE, 1)
__global__ void AllReduceKernel(const AllReduceKernelArgs<T> args) {
if (args.N == 0) return;
__global__ void AllReduceKernel(const KernelArgs<T> args) {
const int tid = threadIdx.x;
__shared__ T* sharedNextOutput;
__shared__ DevRing<T> ring;
bool pushrecv = args.pushrecv;
LoadRing<THREADS>(args.ring, &ring);
__syncthreads();
// First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
// the previous GPU is done with a previous collective operation.
if (tid == 0) {
WaitFlag prevCommOp(ring.prevOpCounter, 0);
WaitFlag nextCommOp(ring.nextOpCounter, 0);
prevCommOp.wait(args.opIndex);
nextCommOp.wait(args.opIndex);
if (pushrecv) {
*ring.sendPtrToPrev = (T*)args.ThisOutput;
Wait([=] {
return *((T * volatile *)args.PrevPtrToThisOutput) == nullptr;
return *ring.recvPtrFromNext != nullptr;
});
*((T * volatile *)args.PrevPtrToThisOutput) = (T*)args.ThisOutput;
Wait([=] {
return *((T * volatile *)args.ThisPtrToNextOutput) != nullptr;
});
if (PUSHRECV)
nextOutput =
*((volatile void * volatile *)args.ThisPtrToNextOutput);
sharedNextOutput = *ring.recvPtrFromNext;
*ring.recvPtrFromNext = nullptr;
}
}
__syncthreads();
WaitFlag waitDoneFromNext(ring.recvFlagFromNext, -NUM_BUFCHUNKS*NUM_SUBSTEPS);
WaitFlag waitReadyFromPrev(ring.recvFlagFromPrev, -1*NUM_SUBSTEPS);
PostFlag postDoneToPrev(ring.sendFlagToPrev, -1*NUM_SUBSTEPS);
PostFlag postReadyToNext(ring.sendFlagToNext, 0);
for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
// calculate slice size. for all chunks except (possibly) the last one,
// this will just be args.SliceSize. For the last one, it may be smaller
int bigSliceN = args.SliceSize;
int smallSliceN = 0;
int lastSliceN = 0;
int numSlices = args.NumGPUs * NUM_SUBCHUNKS;
int numBigSlices = numSlices;
int numSmallSlices = 0;
typedef Primitives<THREADS, UNROLL, NUM_SUBSTEPS, T, FUNC> Prims;
// last chunk
if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
&numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
args.ChunkSize);
const int size = args.N;
const int nranks = args.nRanks;
const int buffSize = args.buffSize / sizeof(T);
const int sliceSize = buffSize / NUM_BUFCHUNKS;
// this offset is only applied to Data pointers, not to Buffer pointers,
// since we only have one buffer per chunk
int chunkOffset = chunk * args.ChunkSize;
int step = 0;
int poffset, noffset = 0;
// Compute pointers
const T * __restrict__ thisInput = args.ThisInput;
T * __restrict__ thisOutput = args.ThisOutput;
T * __restrict__ prevInput = ring.recvBuffer;
T * __restrict__ nextOutput = ring.sendBuffer;
for (int chunkOffset = 0; chunkOffset < size; chunkOffset += nranks*sliceSize) {
/////////////// begin AllReduce steps ///////////////
int offset;
int maxOffset;
int slice;
// step 0: push data to next GPU
int step = 0;
int slice = args.ThisId;
int offset;
int sliceSize;
slice = ring.userRank[nranks-1];
offset = chunkOffset + slice * sliceSize;
maxOffset = size-offset;
if (tid < THREADS) {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
if (s > 0) { slice += args.NumGPUs; }
getSliceSizeAndOffset(&sliceSize, &offset, slice, numSlices,
numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
Prims::Copy(
thisInput + offset,
nextOutput + noffset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
if (!PUSHRECV && chunk > 0) {
WAIT_FOR_CHUNK(chunk, s);
}
NEXT_STEP; // Increases step, poffset, noffset
Copy<UNROLL, THREADS>(
args.NextBuffer + offset,
args.ThisInput + chunkOffset + offset,
sliceSize);
// k-2 steps: reduce and copy to next GPU
for (int j=2; j<nranks; ++j) {
slice = ring.userRank[nranks-j];
offset = chunkOffset + slice * sliceSize;
maxOffset = size-offset;
__syncthreads();
}
} else { // is consumer thread
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
}
}
Prims::Reduce(
prevInput + poffset,
thisInput + offset,
nextOutput + noffset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
// steps j with 1 <= j < k - 1, where k = number of GPUs:
// reduce and copy to next GPU
for (step = 1; step < args.NumGPUs - 1; ++step) {
if (tid < THREADS) {
slice = (args.NumGPUs + slice - 1) % args.NumGPUs;
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
if (s > 0) { slice += args.NumGPUs; }
getSliceSizeAndOffset(&sliceSize, &offset, slice, numSlices,
numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
WAIT_FOR_NEW_DATA(chunk, s, step);
Reduce<UNROLL, THREADS, FUNC>(
args.NextBuffer + offset,
args.ThisBuffer + offset,
args.ThisInput + chunkOffset + offset,
sliceSize);
__syncthreads();
}
} else {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
}
}
NEXT_STEP;
}
// step k - 1: reduce this buffer and data, which will produce the final
// result that we store in this data and push to the next GPU
step = args.NumGPUs - 1;
slice = ring.userRank[0];
offset = chunkOffset + slice * sliceSize;
maxOffset = size-offset;
if (tid < THREADS) {
slice = (args.NumGPUs + slice - 1) % args.NumGPUs;
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
if (s > 0) { slice += args.NumGPUs; }
getSliceSizeAndOffset(&sliceSize, &offset, slice, numSlices,
numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
Prims::ReduceCopy(
prevInput + poffset,
thisInput + offset,
pushrecv ? (sharedNextOutput + offset) : (nextOutput + noffset),
thisOutput + offset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
WAIT_FOR_NEW_DATA(chunk, s, step);
NEXT_STEP;
if (PUSHRECV) {
ReduceAndCopy<UNROLL, THREADS, FUNC>(
(volatile T *)nextOutput + chunkOffset + offset,
args.ThisOutput + chunkOffset + offset,
args.ThisBuffer + offset,
args.ThisInput + chunkOffset + offset,
sliceSize);
} else {
ReduceAndCopy<UNROLL, THREADS, FUNC>(
args.NextBuffer + offset,
args.ThisOutput + chunkOffset + offset,
args.ThisBuffer + offset,
args.ThisInput + chunkOffset + offset,
sliceSize);
}
if (pushrecv) {
// k-2 steps: copy result to next GPU
for (int j=1; j<nranks-1; ++j) {
slice = ring.userRank[nranks - j];
offset = chunkOffset + slice * sliceSize;
maxOffset = size-offset;
__syncthreads();
Prims::Copy(
thisOutput + offset,
sharedNextOutput + offset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
NEXT_STEP;
}
} else {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
}
// k-2 steps: copy result to next GPU
for (int j=1; j<nranks-1; ++j) {
slice = ring.userRank[nranks - j];
offset = chunkOffset + slice * sliceSize;
maxOffset = size-offset;
Prims::DoubleCopy(
prevInput + poffset,
thisOutput + offset,
nextOutput + noffset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
NEXT_STEP;
}
// steps j with k <= j < 2*k-2: copy result to next GPU
for (step = args.NumGPUs; step < 2 * args.NumGPUs - 2; ++step) {
if (tid < THREADS) {
slice = (args.NumGPUs + slice - 1) % args.NumGPUs;
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
if (s > 0) { slice += args.NumGPUs; }
getSliceSizeAndOffset(&sliceSize, &offset, slice, numSlices,
numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
WAIT_FOR_NEW_DATA(chunk, s, step);
if( PUSHRECV ) {
Copy<UNROLL, THREADS>(
(volatile T *)nextOutput + chunkOffset + offset,
args.ThisOutput + chunkOffset + offset,
sliceSize);
} else {
DoubleCopy<UNROLL, THREADS>(
args.NextBuffer + offset,
args.ThisOutput + chunkOffset + offset,
args.ThisBuffer + offset,
sliceSize);
}
__syncthreads();
}
} else {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
}
}
}
if (!PUSHRECV) {
// Make final copy from buffer to dest.
if (tid < THREADS) {
slice = (args.NumGPUs + slice - 1) % args.NumGPUs;
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
if (s > 0) { slice += args.NumGPUs; }
getSliceSizeAndOffset(&sliceSize, &offset, slice, numSlices,
numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
WAIT_FOR_NEW_DATA(chunk, s, step);
slice = ring.userRank[1];
offset = chunkOffset + slice * sliceSize;
maxOffset = size-offset;
// Here we need to copy from buffer to this output.
Copy<UNROLL, THREADS>(
args.ThisOutput + chunkOffset + offset,
args.ThisBuffer + offset,
sliceSize);
Prims::Copy(
prevInput + poffset,
thisOutput + offset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
__syncthreads();
}
} else {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
if(chunk+1 < args.NumChunks) {
SIGNAL_CHUNK_DONE(chunk, s);
}
}
}
NEXT_STEP;
}
}
// wait for the last data to be pushed to us
if (tid < THREADS) {
if(PUSHRECV) {
WAIT_FOR_NEW_DATA(args.NumChunks, NUM_SUBCHUNKS-1, 0);
}
if (tid == 0) {
args.ThisNewDataAvailableFlag[0] = 0;
if(!PUSHRECV) {
args.ThisChunkDoneFlag[0] = 0;
}
*args.ThisPtrToNextOutput = nullptr;
}
// Wait for last update from next then reset the flag
waitDoneFromNext.wait(NUM_SUBSTEPS*(step+NUM_BUFCHUNKS-1));
*ring.recvFlagFromNext = 0;
// Wait for last update from prev then reset the flag
waitReadyFromPrev.wait(NUM_SUBSTEPS*(step+1));
*ring.recvFlagFromPrev = 0;
incrementOpCounter(&args);
}
}
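// Step-by-step sketch of the kernel above for a 4-rank ring (non-pushrecv
// path, illustrative), writing u0..u3 for ring.userRank[0..3]:
//   step 0: copy own input slice u3 into the next rank's buffer
//   step 1: reduce received slice u2 with own input, send onward
//   step 2: reduce received slice u1 with own input, send onward
//   step 3: reduce slice u0 (now complete), store it to the output, send onward
//   step 4: receive reduced slice u3, store it to the output, send onward
//   step 5: receive reduced slice u2, store it to the output, send onward
//   step 6: receive reduced slice u1 into the output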
#define THREADS 512
#define UNROLL 8
template<class FUNC, typename T>
ncclResult_t ncclAllReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
ncclResult_t RingAllReduce(const void* sendbuff, void* recvbuff,
const int count, ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
int index = comm->ncclId;
// There is one slice per GPU, so a slice can be at most bufferN / numGPUs,
// where bufferN is the number of elements of type T that fit into the buffer.
// For efficiency, we want the slice size to be a multiple of UNROLL_SIZE
int bufferN = comm->buffSize / sizeof(T);
int bufferNPerSlice = bufferN / (NUM_SUBCHUNKS * comm->nDev);
int sliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;
int nextId = (index + 1) % comm->nDev;
int prevId = (index + comm->nDev - 1) % comm->nDev;
AllReduceKernelArgs<T> args;
args.ThisId = index;
args.NumGPUs = comm->nDev;
args.N = count;
args.SliceSize = sliceSize;
int subchunkSize = comm->nDev * args.SliceSize;
args.ChunkSize = NUM_SUBCHUNKS * subchunkSize;
// avoid a case where we have one or more big chunks and one tiny one
int remainder = args.N % args.ChunkSize;
if ((args.N > args.ChunkSize) && (remainder > 0) &&
(args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
args.SliceSize /= 2;
int subchunkSize = comm->nDev * args.SliceSize;
args.ChunkSize = NUM_SUBCHUNKS * subchunkSize;
// round down so we end up with a big last chunk
args.NumChunks = args.N / args.ChunkSize;
} else {
// round up
args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
}
args.ThisPtrToNextOutput = (T**)&(comm->ptrs[nextId].local->recvPtrs[0]);
args.PrevPtrToThisOutput = (T**)&(comm->ptrs[prevId].remote->recvPtrs[0]);
args.ThisInput = (const T*)sendbuff;
args.ThisOutput = (volatile T*)recvbuff;
args.ThisBuffer = (volatile T*)comm->ptrs[prevId].local->buff;
args.NextBuffer = (volatile T*)comm->ptrs[nextId].remote->buff;
args.ThisNewDataAvailableFlag = comm->ptrs[prevId].local->flags;
args.NextNewDataAvailableFlag = comm->ptrs[nextId].remote->flags;
args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
if (comm->nDev == 1) {
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else {
if( comm->useRemoteRecv ) {
AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, true, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else {
AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, false, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
}
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
LAUNCH_KERNEL(AllReduceKernel, THREADS, UNROLL, FUNC, T, args, stream);
}
return ncclSuccess;
}
template<typename T>
ncclResult_t ncclAllReduceWithType(const void* sendbuff,
void* recvbuff, int count, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
switch (op) {
case ncclSum:
return ncclAllReduceWithTypeAndFunc<FuncSum<T>, T>(
sendbuff, recvbuff, count, comm, stream);
case ncclProd:
return ncclAllReduceWithTypeAndFunc<FuncProd<T>, T>(
sendbuff, recvbuff, count, comm, stream);
case ncclMax:
return ncclAllReduceWithTypeAndFunc<FuncMax<T>, T>(
sendbuff, recvbuff, count, comm, stream);
case ncclMin:
return ncclAllReduceWithTypeAndFunc<FuncMin<T>, T>(
sendbuff, recvbuff, count, comm, stream);
}
return ncclInvalidOperation;
}
class AllReduceFunctor {
public:
ncclResult_t operator()(const void* sendbuff, void* recvbuff,
int count, ncclDataType_t datatype, ncclRedOp_t op, int /*root*/,
ncclComm* comm, cudaStream_t stream) {
switch (datatype) {
case ncclChar:
return ncclAllReduceWithType<char>(sendbuff, recvbuff, count, op,
comm, stream);
case ncclInt:
return ncclAllReduceWithType<int>(sendbuff, recvbuff, count, op,
comm, stream);
#if CUDART_VERSION >= 7050
case ncclHalf:
return ncclAllReduceWithType<half>(sendbuff, recvbuff, count, op,
comm, stream);
#endif
case ncclFloat:
return ncclAllReduceWithType<float>(sendbuff, recvbuff, count, op,
comm, stream);
case ncclDouble:
return ncclAllReduceWithType<double>(sendbuff, recvbuff, count, op,
comm, stream);
case ncclInt64:
return ncclAllReduceWithType<long long>(sendbuff, recvbuff, count, op,
comm, stream);
case ncclUint64:
return ncclAllReduceWithType<unsigned long long int>(sendbuff, recvbuff, count, op,
comm, stream);
}
return ncclInvalidType;
template<typename T, template <typename> class RedOp>
class AllReduce {
public:
static ncclResult_t entry(const void* sendbuff, void* recvbuff,
int count, int /*root*/, ncclComm* comm, cudaStream_t stream) {
return RingAllReduce<RedOp<T>, T>(sendbuff, recvbuff, count, comm, stream);
}
};
extern "C" DSOGLOBAL
NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, int count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, int count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
return enqueue(AllReduceFunctor(), sendbuff, recvbuff, count, datatype, op, 0,
comm, stream);
return enqueue<AllReduce>(sendbuff, recvbuff, count, datatype, op, 0, comm, stream);
}

View File

@@ -1,392 +1,165 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
#include <algorithm>
#include "core.h"
#include "common_kernel.h"
#include "copy_kernel.h"
#include "enqueue.h"
#include "primitives.h"
/* HIERARCHY
*
* The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
* SUBCHUNKS, where each SUBCHUNK is processed independently. A SUBCHUNK is
* split into numUnroll UNROLLS and each thread performs UNROLL_COUNT
* single-data-element operations inside an UNROLL. As the name suggests, the
* UNROLL_COUNT operations within an UNROLL are unrolled.
*/
#define NUM_SUBSTEPS 2
#define NUM_BUFCHUNKS 2
// Number of threads used to perform copies, etc. Must be multiple of 32.
// An additional thread is used to handle threadfences, so the CUDA blocks
// have dimension NUM_THREADS+1.
#define NUM_THREADS 256
// Increase Step and boffset for buffer sync
#define NEXT_STEP \
step++; \
boffset += sliceSize; \
if (boffset == buffSize) boffset = 0;
// Each thread unrolls the innermost loop of the copy or reduction operations
// to this many single-data-element instructions
#define UNROLL_COUNT 8
#define ALIGN_SIZE(size, align) \
size = ((size + (align) - 1) / (align)) * (align);
#define UNROLL_SIZE (UNROLL_COUNT * NUM_THREADS)
template<int THREADS, int UNROLL, class FUNC, typename T>
__launch_bounds__(THREADS+WARP_SIZE, 1)
__global__ void BroadcastKernel(const KernelArgs<T> args) {
const int tid = threadIdx.x;
__shared__ T* sharedNextOutput;
__shared__ DevRing<T> ring;
bool pushrecv = args.pushrecv;
// To hide the latency associated with the synchronization between different
// subchunks, we interleave the independent subchunks so that more data can be
// transferred while the sync is in progress. This is the number of subchunks
// that are active at the same time
#define NUM_SUBCHUNKS 4
LoadRing<THREADS>(args.ring, &ring);
__syncthreads();
// if this is called with CHUNK, it means that we just finished pushing the data
// of chunk CHUNK to the next GPU, so it can proceed with CHUNK
// We add 1 to chunk so that the initial flag of 0 doesn't allow the non-root
// GPUs to proceed before the flag is incremented from the upstream GPU. This
// is called by one particular consumer warp and so we select the first thread
// in the warp to set the flag.
#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk) \
do { \
__threadfence_system(); \
args.NextNewDataAvailableFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
} while (0)
if (tid == 0) {
WaitFlag prevCommOp(ring.prevOpCounter, 0);
WaitFlag nextCommOp(ring.nextOpCounter, 0);
prevCommOp.wait(args.opIndex);
nextCommOp.wait(args.opIndex);
if (pushrecv) {
*ring.sendPtrToPrev = (T*)args.ThisOutput;
Wait([=] {
return *ring.recvPtrFromNext != nullptr;
});
sharedNextOutput = *ring.recvPtrFromNext;
*ring.recvPtrFromNext = nullptr;
}
}
__syncthreads();
// This is called by all producer threads, but only thread 0 spins on the flag,
// all threads synchronize after thread 0 is done spinning.
#define WAIT_FOR_NEW_DATA(chunk, subchunk) \
do { \
if (tid == 0) { \
Wait([=] { \
return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
WaitFlag waitDoneFromNext(ring.recvFlagFromNext, (1-NUM_BUFCHUNKS)*NUM_SUBSTEPS);
WaitFlag waitReadyFromPrev(ring.recvFlagFromPrev, 0);
PostFlag postDoneToPrev(ring.sendFlagToPrev, 0);
PostFlag postReadyToNext(ring.sendFlagToNext, 0);
// If this is called with CHUNK, it means that this GPU has just finished
// processing the chunk CHUNK and so the previous GPU can start with CHUNK + 1
#define SIGNAL_CHUNK_DONE(chunk, subchunk) \
do { \
args.PrevChunkDoneFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
} while (0)
typedef Primitives<THREADS, UNROLL, NUM_SUBSTEPS, T> Prims;
// This is called by all producer threads, but only thread 0 spins on the flag,
// all threads synchronize after thread 0 is done spinning.
#define WAIT_FOR_CHUNK(chunk, subchunk) \
do { \
if (tid == 0) { \
Wait([=] { \
return ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
NUM_SUBCHUNKS*(chunk) + subchunk + 1 - NUM_SUBCHUNKS; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
const int size = args.N;
const int rank = ring.userRank[0];
const int nextRank = ring.userRank[1];
const int root = args.root;
const int buffSize = args.buffSize / sizeof(T);
const int sliceSize = buffSize / NUM_BUFCHUNKS;
// This is called by all producer threads, but only thread 0 spins on the flag,
// all threads synchronize after thread 0 is done spinning.
#define WAIT_FOR_NEW_DATA_AND_CHUNK(chunk, subchunk) \
do { \
if (tid == 0) { \
Wait([=] { \
bool newDataAvailable = \
((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
bool chunkDone = \
((volatile int *)args.ThisChunkDoneFlag)[0] >= \
NUM_SUBCHUNKS*(chunk)+subchunk + 1 - NUM_SUBCHUNKS; \
return newDataAvailable && chunkDone; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
int step = 0;
int boffset = 0;
__device__ inline void getSliceSizeAndOffset(int *size, int *offset, int slice,
int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
int smallSliceN, int lastSliceN) {
if (slice < numBigSlices) {
*size = bigSliceN;
*offset = slice * bigSliceN;
// Compute pointers
const T * __restrict__ thisInput = args.ThisInput;
T * __restrict__ thisOutput = args.ThisOutput;
T * __restrict__ prevInput = ring.recvBuffer;
T * __restrict__ nextOutput = ring.sendBuffer;
for (int offset = 0; offset < size; offset += sliceSize) {
int maxOffset = size-offset;
if (rank == root) {
Prims::Copy(
thisInput + offset,
pushrecv ? sharedNextOutput + offset : nextOutput + boffset,
sliceSize, maxOffset,
step,
waitDoneFromNext,
postReadyToNext);
} else if (nextRank == root) {
if (pushrecv) maxOffset = 0; // Only wait for signals
Prims::Copy(
prevInput + boffset,
thisOutput + offset,
sliceSize, maxOffset,
step,
waitReadyFromPrev,
postDoneToPrev);
} else {
*size = (slice < numBigSlices + numSmallSlices) ? smallSliceN
: ((slice == numSlices - 1) ? lastSliceN : 0);
*offset = numBigSlices * bigSliceN + (slice - numBigSlices) * smallSliceN;
if (pushrecv) {
Prims::Copy(
thisOutput + offset,
sharedNextOutput + offset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
} else {
Prims::DoubleCopy(
prevInput + boffset,
thisOutput + offset,
nextOutput + boffset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
}
}
NEXT_STEP; // Increases step, boffset
}
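// Role summary for the loop above (illustrative): in a ring broadcast rooted
// at `root`,
//  - the root only sends: it copies its input into the next rank's buffer
//    (or, with pushrecv, directly into the next rank's output);
//  - the rank just before the root only receives into its output (with
//    pushrecv it merely consumes the synchronization flags);
//  - every other rank receives and forwards in one step: a DoubleCopy into its
//    output and the next rank's buffer, or a plain Copy of its output onward
//    when pushrecv has already delivered the data there.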
// if (threadIdx.x == 0)
// printf("[size=%d] [offset=%d] slice=%d numSlices=%d "
// "numBigSlices=%d numSmallSlices=%d bigSliceN=%d smallSliceN=%d "
// "lastSliceN=%d\n", *size, *offset, slice, numSlices, numBigSlices,
// numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
}
template<typename T>
struct BroadcastKernelArgs {
// general parameters
int ThisId;
int N;
// some pre-computed sizes
int SliceSize;
int ChunkSize;
int NumChunks;
int BufferSliceStride;
T ** ThisPtrToNextData;
T ** PrevPtrToThisData;
// local and remote data
T * __restrict__ ThisData;
volatile T * __restrict__ ThisBuffer;
volatile T * __restrict__ NextBuffer;
// local and remote flags
volatile int * __restrict__ ThisNewDataAvailableFlag;
volatile int * __restrict__ NextNewDataAvailableFlag;
volatile int * __restrict__ ThisChunkDoneFlag;
volatile int * __restrict__ PrevChunkDoneFlag;
};
__shared__ volatile void * nextData;
enum BcastRole {ROOT=0, MIDDLE=1, END=2};
template<int THREADS, int UNROLL, bool PUSHRECV, int ROLE, typename T>
__global__ void BroadcastKernel(const BroadcastKernelArgs<T> args) {
if (args.N == 0) return;
int tid = threadIdx.x;
// First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
// the previous GPU is done with a previous collective operation.
// wait for the last data to be pushed to us
if (tid == 0) {
Wait([=] {
return *((T * volatile *)args.PrevPtrToThisData) == nullptr; // Wait for previous processor to be done
});
*((T * volatile *)args.PrevPtrToThisData) = (T*)args.ThisData; // Tell Previous I'm starting
Wait([=] {
return *((T * volatile *)args.ThisPtrToNextData) != nullptr; // Wait till I've been told next started
});
if (PUSHRECV)
nextData = *((volatile void * volatile *)args.ThisPtrToNextData); // Grab next's pointer if needed.
}
__syncthreads();
for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
// calculate slice size. for all chunks except (possibly) the last one,
// this will just be args.SliceSize. For the last one, it may be smaller
int bigSliceN = args.SliceSize;
int smallSliceN = 0;
int lastSliceN = 0;
int numSlices = NUM_SUBCHUNKS;
int numBigSlices = numSlices;
int numSmallSlices = 0;
// last chunk
if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
&numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
args.ChunkSize);
// this offset is only applied to Data pointers, not to Buffer pointers,
// since we only have one buffer per chunk
int chunkOffset = chunk * args.ChunkSize;
int offset;
int sliceSize;
if (tid < THREADS) {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
getSliceSizeAndOffset(&sliceSize, &offset, s, numSlices,
numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
if (PUSHRECV) {
if (ROLE != ROOT)
WAIT_FOR_NEW_DATA(chunk, s);
if (ROLE != END)
Copy<UNROLL, THREADS>(
(volatile T *)nextData + chunkOffset + offset,
args.ThisData + chunkOffset + offset,
sliceSize);
} else { // PUSH2BUFF
if (ROLE == ROOT) {
WAIT_FOR_CHUNK(chunk, s);
Copy<UNROLL, THREADS>(
args.NextBuffer + (s * args.BufferSliceStride),
args.ThisData + chunkOffset + offset,
sliceSize);
} else if (ROLE == MIDDLE) {
WAIT_FOR_NEW_DATA_AND_CHUNK(chunk, s);
DoubleCopy<UNROLL, THREADS>(
args.NextBuffer + (s * args.BufferSliceStride),
args.ThisData + chunkOffset + offset,
args.ThisBuffer + (s * args.BufferSliceStride),
sliceSize);
} else { // ROLE == END
WAIT_FOR_NEW_DATA(chunk, s);
Copy<UNROLL, THREADS>(
args.ThisData + chunkOffset + offset,
args.ThisBuffer + (s * args.BufferSliceStride),
sliceSize);
}
}
__syncthreads();
}
} else { // Consumer thread
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
if (ROLE != END)
SIGNAL_NEW_DATA_AVAILABLE(chunk, s);
// signal chunk done if we don't push into the receive buffer and this
// is not the last chunk and this is not root
if ((!PUSHRECV) && (ROLE != ROOT) && (chunk + 1 < args.NumChunks)) {
SIGNAL_CHUNK_DONE(chunk, s);
}
}
}
if (nextRank != root) {
// Wait for last update from next then reset the flag
waitDoneFromNext.wait(NUM_SUBSTEPS*(step+NUM_BUFCHUNKS-1));
*ring.recvFlagFromNext = 0;
}
// reset flags
if (tid == 0) {
args.ThisNewDataAvailableFlag[0] = 0;
args.ThisChunkDoneFlag[0] = 0;
*args.ThisPtrToNextData = nullptr;
if (rank != root) {
// reset the flag
*ring.recvFlagFromPrev = 0;
}
incrementOpCounter(&args);
}
}
template<typename T>
ncclResult_t ncclBcastWithType(void* buff, const int count, const int root,
ncclComm* comm, int numUnroll, cudaStream_t stream) {
#define THREADS 256
#define UNROLL 8
template<class FUNC, typename T>
ncclResult_t RingBroadcast(void* buff, const int count, const int root,
ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
int index = comm->ncclId;
int rootId = comm->ringFromUser[root];
int nextId = (index + 1) % comm->nDev;
int prevId = (index + comm->nDev - 1) % comm->nDev;
// There is one slice per GPU, so a slice can be at most bufferN / numGPUs,
// where bufferN is the number of elements of type T that fit into the buffer.
// For efficiency, we want the slice size to be a multiple of UNROLL_SIZE
int bufferN = comm->buffSize / sizeof(T);
// we only need buffer for k slices and k paddings
int bufferNPerSlice = bufferN / NUM_SUBCHUNKS;
int maxSliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;
BroadcastKernelArgs<T> args;
args.ThisId = index;
args.N = count;
args.SliceSize = numUnroll * UNROLL_SIZE * sizeof(PackType) / sizeof(T);
// if we don't directly push into the remote receive buffer, make sure slice
// fits into the temporary buffer
if (!comm->useRemoteRecv) {
// Larger transfers help QPI more than tag updates hurt P2P.
args.SliceSize *= 8;
if (comm->nRanks != 1) {
KernelArgs<T> args;
ArgsSetup(&args, buff, buff, root, count, comm);
LAUNCH_KERNEL(BroadcastKernel, THREADS, UNROLL, FUNC, T, args, stream);
}
args.SliceSize = std::min(maxSliceSize, args.SliceSize);
args.BufferSliceStride = args.SliceSize;
args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
// avoid a case where we have one or more big chunks and one tiny one
int remainder = args.N % args.ChunkSize;
if ((args.N > args.ChunkSize) && (remainder > 0) &&
(args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
args.SliceSize /= 2;
args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
// round down so we end up with a big last chunk
args.NumChunks = args.N / args.ChunkSize;
} else {
// round up
args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
}
// printf("sliceSize = %i, chunkSize = %i, numChunks = %i\n", args.SliceSize, args.ChunkSize, args.NumChunks);
args.ThisPtrToNextData = (T**)&(comm->ptrs[nextId].local->recvPtrs[0]);
args.PrevPtrToThisData = (T**)&(comm->ptrs[prevId].remote->recvPtrs[0]);
args.ThisData = (T*)buff;
args.ThisBuffer = (volatile T*)comm->ptrs[prevId].local->buff;
args.NextBuffer = (volatile T*)comm->ptrs[nextId].remote->buff;
// we need 2 * NUM_SUBCHUNKS flags, so use the first NUM_SUBCHUNKS flags
// to signal the next GPU that new data is available and the following
// NUM_SUBCHUNKS to signal the previous GPU that a chunk is finished
args.ThisNewDataAvailableFlag = comm->ptrs[prevId].local->flags;
args.NextNewDataAvailableFlag = comm->ptrs[nextId].remote->flags;
args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
if (comm->nDev != 1) {
if (comm->useRemoteRecv) {
if (index == (rootId + comm->nDev - 1) % comm->nDev) {
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, END, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else if (index == rootId) {
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, ROOT, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else {
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, MIDDLE, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
}
} else {
if (index == (rootId + comm->nDev - 1) % comm->nDev) {
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, END, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else if (index == rootId) {
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, ROOT, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else {
BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, MIDDLE, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
}
}
}
return ncclSuccess;
}
class BroadcastFunctor {
public:
ncclResult_t operator()(const void* /*dummy sendbuff*/,
void* buff, int count, ncclDataType_t datatype, ncclRedOp_t /*dummy operation*/,
int root, ncclComm* comm, cudaStream_t stream) {
int numUnroll = 4;
switch (datatype) {
case ncclChar:
return ncclBcastWithType<char>(buff, count, root, comm, numUnroll, stream);
case ncclInt:
return ncclBcastWithType<int>(buff, count, root, comm, numUnroll, stream);
#ifdef CUDA_HAS_HALF
case ncclHalf:
return ncclBcastWithType<half>(buff, count, root, comm, numUnroll, stream);
#endif
case ncclFloat:
return ncclBcastWithType<float>(buff, count, root, comm, numUnroll, stream);
case ncclDouble:
return ncclBcastWithType<double>(buff, count, root, comm, numUnroll, stream);
case ncclInt64:
return ncclBcastWithType<long long>(buff, count, root, comm, numUnroll, stream);
case ncclUint64:
return ncclBcastWithType<unsigned long long>(buff, count, root, comm, numUnroll, stream);
}
return ncclInvalidType;
template<typename T, template<typename> class RedOp>
class Broadcast {
public:
static ncclResult_t entry(const void* sendbuff, void* recvbuff,
int count, int root, ncclComm* comm, cudaStream_t stream) {
return RingBroadcast<RedOp<T>, T>(recvbuff, count, root, comm, stream);
}
};
extern "C" DSOGLOBAL
NCCL_API(ncclResult_t, ncclBcast, void* buff, int count, ncclDataType_t datatype, int root,
ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclBcast(void* buff, int count, ncclDataType_t datatype, int root,
ncclComm_t comm, cudaStream_t stream) {
return enqueue(BroadcastFunctor(), nullptr, buff, count, datatype, ncclSum,
root, comm, stream);
return enqueue<Broadcast, FuncNull>(nullptr, buff, count, datatype, root, comm, stream);
}
View File
@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
@ -245,7 +245,7 @@ __device__ inline void ReduceOrCopy(const int tid,
volatile T * __restrict__ dest0, volatile T * __restrict__ dest1,
const volatile T * __restrict__ src0, const volatile T * __restrict__ src1,
int N) {
if (N==0) {
if (N<=0) {
return;
}
@ -455,5 +455,76 @@ __device__ inline void CalcLastChunk(int * const bigSliceN,
*numSmallSlices + 1;
}
// Kernel launch
template<typename T>
struct KernelArgs {
// general parameters
int nRanks;
int root;
int buffSize;
int N;
int opIndex;
volatile int * __restrict__ opCounter;
bool pushrecv;
// some pre-computed sizes
int SliceSize;
int SliceOffset;
int ChunkSize;
int NumChunks;
// local and remote input, output, and buffer
const T * __restrict__ ThisInput;
T * __restrict__ ThisOutput;
DevRing<char>* ring;
};
template<typename T>
void ArgsSetup(KernelArgs<T> *args, const void* sendbuff, void* recvbuff,
const int root, const int count, ncclComm *comm) {
args->nRanks = comm->nRanks;
args->root = root;
args->buffSize = comm->buffSize;
args->N = count;
args->opIndex = comm->opSched;
args->opCounter = comm->opCounter;
args->ThisInput = (const T*)sendbuff;
args->ThisOutput = (T*)recvbuff;
args->ring = comm->devRing;
args->pushrecv = comm->globalMemSpace;
}
#define LAUNCH_KERNEL(K, THREADS, UNROLL, FUNC, T, \
args, stream) do { \
dim3 grid(1, 1, 1); \
dim3 block(THREADS+1, 1, 1); \
void* argptrs[] = {&args}; \
CUDACHECK(cudaLaunchKernel( \
(void*)K<THREADS, UNROLL, FUNC, T>, \
grid, block, argptrs, 0, stream)); \
} while (0)
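For reference, a minimal sketch (not part of this commit) of how a collective is expected to combine ArgsSetup and LAUNCH_KERNEL; RingExampleColl, ExampleKernel and the 256/8 launch parameters are placeholders following the Broadcast wiring shown earlier.
template<class FUNC, typename T>
ncclResult_t RingExampleColl(const void* sendbuff, void* recvbuff, int count,
    int root, ncclComm* comm, cudaStream_t stream) {
  if (count == 0)
    return ncclSuccess;
  KernelArgs<T> args;
  ArgsSetup(&args, sendbuff, recvbuff, root, count, comm);
  // ExampleKernel stands in for the per-collective kernel
  // (e.g. BroadcastKernel, ReduceKernel); THREADS/UNROLL values are illustrative.
  LAUNCH_KERNEL(ExampleKernel, 256, 8, FUNC, T, args, stream);
  return ncclSuccess;
}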
template <typename T>
__device__ inline void incrementOpCounter(const KernelArgs<T> *args) {
// increment comm's operation counts
__threadfence_system(); // Technically need to ensure that cleared flags
// are visible before incrementing op counter.
*args->opCounter = args->opIndex+1;
}
template <int THREADS, typename T> __device__ __forceinline__
void LoadRing(const DevRing<char>* src, DevRing<T>* dst) {
enum { NUM_WORDS = sizeof(DevRing<char>) / sizeof(long long) };
static_assert(sizeof(DevRing<char>) % sizeof(long long) == 0, "Bad alignment");
static_assert(THREADS >= NUM_WORDS, "Not enough threads to load DevRing");
static_assert(sizeof(DevRing<char>) == sizeof(DevRing<T>), "DevRing size mismatch");
long long* lldst = reinterpret_cast<long long*>(dst);
const long long* llsrc = reinterpret_cast<const long long*>(src);
if (threadIdx.x < NUM_WORDS) {
lldst[threadIdx.x] = llsrc[threadIdx.x];
}
}
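A sizing note on the copy above, added here as a comment (assuming a 64-bit build and MAXRANKS == 32):
// DevRing<char> holds 10 pointers (80 bytes) plus userRank[32] (128 bytes),
// i.e. 208 bytes = 26 long longs on a 64-bit build, so the first 26 threads
// each move one 8-byte word and the remaining threads idle.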
#endif // COMMON_KERNEL_H_
View File
@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
View File
@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
#include <stdio.h>
@ -20,7 +20,7 @@
DebugLevel ncclDebugLevel;
extern "C" DSOGLOBAL
NCCL_API(ncclResult_t, ncclGetUniqueId, ncclUniqueId* out);
ncclResult_t ncclGetUniqueId(ncclUniqueId* out) {
pid_t pid = getpid();
static int count = 0;
@ -83,7 +83,7 @@ typedef struct {
int rank;
int ndev;
int cudaDev;
int ncclId;
int sortId;
pid_t pid;
ncclMem* hostptr;
ncclMem* devptr;
@ -94,15 +94,13 @@ typedef struct {
static int compRanks(const void* a, const void* b) {
const RankEntry* A = (const RankEntry*)a;
const RankEntry* B = (const RankEntry*)b;
if (A->ncclId < B->ncclId) return -1;
if (A->ncclId > B->ncclId) return 1;
if (A->sortId < B->sortId) return -1;
if (A->sortId > B->sortId) return 1;
return 0;
}
static void orderRanks(RankEntry* ranks, int count) {
qsort(ranks, count, sizeof(RankEntry), compRanks);
for(int i=0; i<count; ++i)
ranks[i].ncclId = i;
}
@ -110,7 +108,7 @@ typedef struct {
union {
struct {
volatile int bar;
int ringDirectFail;
int globalMemSpaceBroke;
};
char pad[16];
};
@ -156,7 +154,7 @@ static ncclResult_t initGather(RankGather** gather, ncclUniqueId commId,
return ncclSuccess;
}
static void syncRingDirect(RankGather* gather, int* ringDirectOk) {
static void syncRingDirect(RankGather* gather, int* globalMemSpaceOk) {
int bar_tmp = gather->bar - 1;
int ndev = gather->ranks[0].ndev;
bool swapped;
@ -169,7 +167,7 @@ static void syncRingDirect(RankGather* gather, int* ringDirectOk) {
sched_yield();
__sync_synchronize();
*ringDirectOk = gather->ringDirectFail ? 0 : 1;
*globalMemSpaceOk = gather->globalMemSpaceBroke ? 0 : 1;
}
static ncclResult_t closeGather(RankGather* gather, int ndev) {
@ -264,13 +262,13 @@ static ncclResult_t populateRankInfo(RankEntry* info, int rank, ncclComm_t comm)
return ncclUnhandledCudaError;
}
// Order by nvml index
if (wrapNvmlDeviceGetIndex(nvmlHandle, (unsigned*)&info->ncclId) != ncclSuccess) {
if (wrapNvmlDeviceGetIndex(nvmlHandle, (unsigned*)&info->sortId) != ncclSuccess) {
WARN("rank %d failed to get nvml device index for device %d", rank, comm->cudaDev);
return ncclUnhandledCudaError;
}
info->rank = rank;
info->ndev = comm->nDev;
info->ndev = comm->nRanks;
info->cudaDev = comm->cudaDev;
info->pid = getpid();
info->buffSize = comm->buffSize;
@ -285,109 +283,104 @@ static ncclResult_t populateRankInfo(RankEntry* info, int rank, ncclComm_t comm)
}
static const int CLEANUP_NONE = 0;
static const int CLEANUP_CUIPC = 1;
static const int CLEANUP_UNMAP = 2;
static ncclResult_t commClearMaps(ncclComm_t comm) {
ncclResult_t res, retval = ncclSuccess;
cudaError_t cures;
for(int d=0; d<comm->nDev; ++d) {
switch(comm->ptrs[d].remoteCleanup) {
case CLEANUP_NONE:
break;
case CLEANUP_CUIPC:
cures = cudaIpcCloseMemHandle((void*)comm->ptrs[d].cleanupHandle);
for(int d=0; d<comm->nRanks; ++d) {
if (comm->ptrs[d].hostCleanup != NULL) {
cures = cudaHostUnregister(comm->ptrs[d].hostCleanup);
if (cures != cudaSuccess) {
WARN("rank %d failed to close IPC handle to rank %d",
comm->userFromRing[comm->ncclId], comm->userFromRing[d]);
WARN("rank %d failed to unregister handle to device %d",
comm->rank, d);
retval = (retval == ncclSuccess) ? ncclUnhandledCudaError : retval;
}
break;
case CLEANUP_UNMAP:
cures = cudaHostUnregister(comm->ptrs[d].cleanupHandle);
if (cures != cudaSuccess) {
WARN("rank %d failed to unregister handle to rank %d",
comm->userFromRing[comm->ncclId], comm->userFromRing[d]);
retval = (retval == ncclSuccess) ? ncclUnhandledCudaError : retval;
}
res = shmUnmap(comm->ptrs[d].cleanupHandle, offsetof(ncclMem, buff) + comm->buffSize);
res = shmUnmap(comm->ptrs[d].hostCleanup, offsetof(ncclMem, buff) + comm->buffSize);
if (res != ncclSuccess) {
WARN("rank %d failed to unmap handle to rank %d",
comm->userFromRing[comm->ncclId], comm->userFromRing[d]);
WARN("rank %d failed to unmap handle to device %d",
comm->rank, d);
retval = (retval == ncclSuccess) ? res : retval;
}
break;
default:
WARN("Unknown cleanup type %d", comm->ptrs[d].remoteCleanup);
comm->ptrs[d].hostCleanup = NULL;
}
if (comm->ptrs[d].devCleanup != NULL) {
cures = cudaIpcCloseMemHandle((void*)comm->ptrs[d].devCleanup);
if (cures != cudaSuccess) {
WARN("rank %d failed to close IPC handle to device %d: %s",
comm->rank, d, cudaGetErrorString(cures));
retval = (retval == ncclSuccess) ? ncclUnhandledCudaError : retval;
}
}
comm->ptrs[d].remoteCleanup = CLEANUP_NONE;
comm->ptrs[d].cleanupHandle = NULL;
}
if (comm->userFromRing != NULL)
memset(comm->userFromRing, 0, sizeof(int)*comm->nDev);
if (comm->ringFromUser != NULL)
memset(comm->ringFromUser, 0, sizeof(int)*comm->nDev);
memset(comm->userFromRing, 0, sizeof(int)*comm->nRanks);
if (comm->ncclFromRing != NULL)
memset(comm->ncclFromRing, 0, sizeof(int)*comm->nRanks);
if (comm->devUserFromRing != NULL) {
cudaError_t err = cudaMemset(comm->devUserFromRing, 0, sizeof(int)*comm->nDev);
if (err != cudaSuccess) {
WARN("Faild to clear dev map: %s", cudaGetErrorString(err));
cures = cudaMemset(comm->devUserFromRing, 0, sizeof(int)*comm->nRanks);
if (cures != cudaSuccess) {
WARN("Faild to clear dev map: %s", cudaGetErrorString(cures));
retval = (retval == ncclSuccess) ? ncclUnhandledCudaError : retval;
}
}
if (comm->devRing != NULL) {
cures = cudaMemset(comm->devRing, 0, sizeof(DevRing<char>));
if (cures != cudaSuccess) {
WARN("Failed to clear devRing: %s", cudaGetErrorString(cures));
retval = (retval == ncclSuccess) ? ncclUnhandledCudaError : retval;
}
}
return retval;
}
static ncclResult_t commBuildMaps(ncclComm_t comm, ncclUniqueId* commId, int rank, RankEntry* ranks, int* ringDirectFailed) {
int ndev = comm->nDev;
for(int i=0; i<ndev; ++i) {
static ncclResult_t commBuildMaps(ncclComm_t comm, ncclUniqueId* commId, int rank, RankEntry* ranks, int* globalMemSpaceBroke) {
int ndev = comm->nRanks;
comm->rank = rank;
if (ndev > MAXRANKS) {
WARN("%d ranks exceeds MAXRANKS of %d", ndev, MAXRANKS);
return ncclUnsupportedDeviceCount;
}
// Check for inconsistencies between ranks
// If two processes claim the same rank, then one slot of
// ranks[] will be left unset with zero ndev/buffSize.
for(int i=0; i<ndev; ++i) {
if (ranks[i].buffSize != comm->buffSize
|| ranks[i].ndev != comm->nDev) {
|| ranks[i].ndev != comm->nRanks) {
commClearMaps(comm);
return ncclRankMismatch;
}
// Create rank<->nccl maps
int iRank = ranks[i].rank;
comm->userFromRing[i] = iRank;
comm->ringFromUser[iRank] = i;
}
if (cudaMemcpy(comm->devUserFromRing, comm->userFromRing, ndev*sizeof(int),
cudaMemcpyHostToDevice) != cudaSuccess) {
WARN("rank %d failed to copy maps to device", rank);
commClearMaps(comm);
return ncclUnhandledCudaError;
}
int myId = -1;
// Find self among ranks of gather
int myNcclId = -1;
for (int i=0; i<ndev; ++i) {
if(ranks[i].rank == rank) {
myId = i;
myNcclId = i;
break;
}
}
if (myId == -1) {
if (myNcclId == -1) {
WARN("rank %d not found in communicator", rank);
return ncclInvalidRank;
}
comm->ncclId = myId;
int myDev = ranks[myId].cudaDev;
pid_t myPid = ranks[myId].pid;
comm->useRemoteRecv = 1; // Assume we directly write to result ptrs.
for(int ringPos=0; ringPos<ndev; ++ringPos) {
int ncclPos = (ringPos+myNcclId) % ndev; // ring order relative to self
int userRank = ranks[ncclPos].rank;
comm->userFromRing[ringPos] = userRank;
comm->ncclFromRing[ringPos] = ncclPos;
}
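A worked example of the reordering above (values are illustrative, not taken from the source):
// Example: ndev = 4, myNcclId = 2, and ranks[] (already sorted by NVML index)
// carrying user ranks {5, 7, 1, 3}. The loop yields
//   userFromRing = {1, 3, 5, 7}   // ring position 0 is always this rank
//   ncclFromRing = {2, 3, 0, 1}
// i.e. the ring is the NVML-sorted order rotated to start at the local rank.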
// The order that we link with peers must ensure that
// P2P slots are used for high-priority links first.
for (int j=0; j<ndev; ++j) {
int i = (myId - 1 + ndev + j) % ndev;
int myDev = ranks[myNcclId].cudaDev;
pid_t myPid = ranks[myNcclId].pid;
for (int i=0; i<ndev; ++i) {
int iRank = ranks[i].rank;
int iDev = ranks[i].cudaDev;
pid_t iPid = ranks[i].pid;
@ -399,84 +392,127 @@ static ncclResult_t commBuildMaps(ncclComm_t comm, ncclUniqueId* commId, int ran
canpeer = 0;
}
if (iPid == myPid) {
if (myDev == iDev) {
INFO("rank access %d -> %d via common device", rank, iRank);
comm->ptrs[i].local = ranks[myId].devptr;
comm->ptrs[i].remote = ranks[i].devptr;
comm->ptrs[i].remoteCleanup = CLEANUP_NONE;
} else {
int peer_enabled = canpeer;
if (canpeer) {
cudaError_t p2pErr = cudaDeviceEnablePeerAccess(iDev, 0);
if (p2pErr == cudaErrorPeerAccessAlreadyEnabled) {
cudaGetLastError();
} else if (p2pErr != cudaSuccess) {
INFO("peer access failed between rank %d (dev %d) and rank %d (dev %d)\n",
rank, myDev, iRank, iDev);
peer_enabled = 0;
}
}
cudaError_t err;
ncclMem* remoteHostBuff;
if (peer_enabled) {
INFO("rank access %d -> %d via P2P device mem", rank, iRank);
comm->ptrs[i].local = ranks[myId].devptr;
comm->ptrs[i].type = NodeRef::HOST; // Assume host buffer
comm->ptrs[i].devCleanup = NULL;
comm->ptrs[i].hostCleanup = NULL;
if (iPid == myPid) {
remoteHostBuff = ranks[i].hostptr;
if (myDev == iDev) { // shared device
INFO("rank access %d -> %d via common device", rank, iRank);
comm->ptrs[i].type = NodeRef::DEVICE;
comm->ptrs[i].local = ranks[myNcclId].devptr;
comm->ptrs[i].remote = ranks[i].devptr;
comm->ptrs[i].remoteCleanup = CLEANUP_NONE;
} else { // go through hostmem
INFO("rank access %d -> %d via zero-copy host mem", rank, iRank);
if (j <= 2)
*ringDirectFailed = 1;
if (cudaHostGetDevicePointer(&comm->ptrs[i].local, ranks[myId].hostptr, 0) != cudaSuccess) {
WARN("rank %d failed to map zero copy buffer to device", rank);
} else if (canpeer) {
INFO("rank access %d -> %d via P2P device mem", rank, iRank);
err = cudaDeviceEnablePeerAccess(iDev, 0);
if (err == cudaErrorPeerAccessAlreadyEnabled) {
cudaGetLastError();
} else if (err != cudaSuccess) {
WARN("rank %d failed to peer with device %d: %s",
rank, iDev, cudaGetErrorString(err));
commClearMaps(comm);
return ncclUnhandledCudaError;
}
if (cudaHostGetDevicePointer(&comm->ptrs[i].remote, ranks[i].hostptr, 0) != cudaSuccess) {
WARN("rank %d failed to map %d's zero copy buffer to device", rank, iRank);
commClearMaps(comm);
return ncclUnhandledCudaError;
}
comm->ptrs[i].remoteCleanup = CLEANUP_NONE;
}
}
} else { // multi-process!
*ringDirectFailed = 1;
if (canpeer || myDev == iDev) {
INFO("rank access %d -> %d via Ipc P2P device mem", rank, iRank);
comm->ptrs[i].local = ranks[myId].devptr;
if (cudaIpcOpenMemHandle((void**)(&comm->ptrs[i].remote),
ranks[i].devipc, cudaIpcMemLazyEnablePeerAccess) != cudaSuccess) {
WARN("rank %d failed to open Ipc handle to rank %d", rank, iRank);
commClearMaps(comm);
return ncclUnhandledCudaError;
}
comm->ptrs[i].remoteCleanup = CLEANUP_CUIPC;
comm->ptrs[i].cleanupHandle = comm->ptrs[i].remote;
} else { // go through hostmem
INFO("rank access %d -> %d via zero copy host shm", rank, iRank);
if (cudaHostGetDevicePointer(&comm->ptrs[i].local, ranks[myId].hostptr, 0) != cudaSuccess) {
WARN("rank %d failed to obtain dev ptr to sysmem buffer", rank);
commClearMaps(comm);
return ncclUnhandledCudaError;
comm->ptrs[i].type = NodeRef::DEVICE;
comm->ptrs[i].local = ranks[myNcclId].devptr;
comm->ptrs[i].remote = ranks[i].devptr;
}
} else { // Separate processes
*globalMemSpaceBroke = 1;
char rankname[1024];
sprintf(rankname, "%s-%d", commId->internal, ranks[i].rank);
if (openHostMemShm(rankname, (ncclMem**)&comm->ptrs[i].cleanupHandle, ranks[i].buffSize)
if (openHostMemShm(rankname, &remoteHostBuff, ranks[i].buffSize)
!= ncclSuccess) {
WARN("rank %d failed to open sysmem buffer of rank %d", rank, iRank);
commClearMaps(comm);
return ncclUnhandledCudaError;
}
if (cudaHostGetDevicePointer(&comm->ptrs[i].remote, comm->ptrs[i].cleanupHandle, 0) != cudaSuccess) {
WARN("rank %d failed to obtain dev ptr for rank %d", rank, iRank);
comm->ptrs[i].hostCleanup = remoteHostBuff;
// TODO: Extend to same device (MPS) case.
// At present that would go through host mem.
if (canpeer) {
INFO("rank access %d -> %d via IPC device mem", rank, iRank);
comm->ptrs[i].type = NodeRef::DEVICE;
comm->ptrs[i].local = ranks[myNcclId].devptr;
err = cudaIpcOpenMemHandle((void**)(&comm->ptrs[i].remote),
ranks[i].devipc, cudaIpcMemLazyEnablePeerAccess);
if (err != cudaSuccess) {
WARN("rank %d failed to open Ipc handle to rank %d: %s",
rank, iRank, cudaGetErrorString(err));
commClearMaps(comm);
return ncclUnhandledCudaError;
}
comm->ptrs[i].remoteCleanup = CLEANUP_UNMAP;
comm->ptrs[i].devCleanup = comm->ptrs[i].remote;
}
}
err = cudaHostGetDevicePointer(&comm->ptrs[i].opCounter,
&(remoteHostBuff->opCounter), 0);
if (err != cudaSuccess) {
WARN("rank %d failed to obtain %d's zero copy pointer: %s",
rank, iRank, cudaGetErrorString(err));
commClearMaps(comm);
return ncclUnhandledCudaError;
}
if (comm->ptrs[i].type == NodeRef::HOST) {
*globalMemSpaceBroke = 1;
INFO("rank access %d -> %d via zero-copy host mem", rank, iRank);
if (cudaHostGetDevicePointer(&comm->ptrs[i].local, ranks[myNcclId].hostptr, 0) != cudaSuccess) {
WARN("rank %d failed to map zero copy buffer to device", rank);
commClearMaps(comm);
return ncclUnhandledCudaError;
}
if (cudaHostGetDevicePointer(&comm->ptrs[i].remote, remoteHostBuff, 0) != cudaSuccess) {
WARN("rank %d failed to map %d's zero copy buffer to device", rank, iRank);
commClearMaps(comm);
return ncclUnhandledCudaError;
}
}
}
// Setup device-side ring view
if (cudaMemcpy(comm->devUserFromRing, comm->userFromRing, ndev*sizeof(int),
cudaMemcpyHostToDevice) != cudaSuccess) {
WARN("rank %d failed to copy maps to device", rank);
commClearMaps(comm);
return ncclUnhandledCudaError;
}
DevRing<char> ringTemp;
memcpy(ringTemp.userRank, comm->userFromRing, ndev*sizeof(int));
int prevIdx = comm->ncclFromRing[comm->nRanks-1];
int nextIdx = comm->ncclFromRing[1 % comm->nRanks];
NodeRef* prevPtrs = comm->ptrs+prevIdx;
NodeRef* nextPtrs = comm->ptrs+nextIdx;
ringTemp.prevOpCounter = prevPtrs->opCounter;
ringTemp.nextOpCounter = nextPtrs->opCounter;
ringTemp.sendFlagToNext = nextPtrs->remote->flags;
ringTemp.recvFlagFromPrev = prevPtrs->local->flags;
ringTemp.sendFlagToPrev = prevPtrs->remote->flags+1;
ringTemp.recvFlagFromNext = nextPtrs->local->flags+1;
ringTemp.recvPtrFromNext = (char**)&nextPtrs->local->recvPtrs;
ringTemp.sendPtrToPrev = (char**)&prevPtrs->remote->recvPtrs;
ringTemp.recvBuffer = prevPtrs->local->buff;
ringTemp.sendBuffer = nextPtrs->remote->buff;
if (cudaMemcpy(comm->devRing, &ringTemp, sizeof(ringTemp),
cudaMemcpyHostToDevice) != cudaSuccess) {
WARN("rank %d failed to copy ring maps to device", rank);
commClearMaps(comm);
return ncclUnhandledCudaError;
}
return ncclSuccess;
}
@ -495,23 +531,24 @@ static void initDebug() {
ncclDebugLevel = ABORT;
INFO("NCCL debug level set to ABORT");
}
}
static void commFree(ncclComm_t comm) {
if (comm == NULL)
return;
for(int i=0; i<MAXQUEUE; ++i) {
if (comm->events.isDone[i] != NULL)
if (cudaEventDestroy(comm->events.isDone[i]) != cudaSuccess)
INFO("failed to destroy cuda event %d", i);
}
if (comm->doneEvent != NULL)
if (cudaEventDestroy(comm->doneEvent) != cudaSuccess)
INFO("ncclComm failed to destroy doneEvent");
ncclResult_t res = commClearMaps(comm);
if (res != ncclSuccess)
INFO("failed to cleanup comm maps");
if (comm->devRing != NULL)
if (cudaFree(comm->devRing) != cudaSuccess)
INFO("commFree failed to free devRing");
if (comm->userFromRing != NULL)
free(comm->userFromRing);
@ -519,8 +556,8 @@ static void commFree(ncclComm_t comm) {
if (cudaFree(comm->devUserFromRing) != cudaSuccess)
INFO("commFree failed to free dev maps");
if (comm->ringFromUser != NULL)
free(comm->ringFromUser);
if (comm->ncclFromRing != NULL)
free(comm->ncclFromRing);
if (comm->devMem != NULL && cudaFree(comm->devMem) != cudaSuccess)
INFO("Failed to free devMap");
@ -550,7 +587,7 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId*
return ncclInvalidRank;
}
size_t commBytes = offsetof(ncclComm, ptrs) + ndev*sizeof(ncclNodeRef);
size_t commBytes = offsetof(ncclComm, ptrs) + ndev*sizeof(NodeRef);
struct ncclComm* comm = (struct ncclComm*)malloc(commBytes);
if (comm == NULL) {
WARN("comm allocation failed");
@ -558,21 +595,23 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId*
}
memset(comm, 0, commBytes);
comm->nDev = ndev;
comm->nRanks = ndev;
cudaGetDevice(&comm->cudaDev);
const char* str = getenv("NCCL_BUFFSIZE");
int buffsize;
if (str != NULL) {
errno = 0;
comm->buffSize = strtol(str, NULL, 10);
if (errno == ERANGE || comm->buffSize == 0) {
buffsize = strtol(str, NULL, 10);
if (errno == ERANGE || buffsize == 0) {
INFO("rank %d invalid NCCL_BUFFSIZE: %s, using default %lu",
rank, str, DEFAULT_BUFFER_SIZE_BYTES);
comm->buffSize = DEFAULT_BUFFER_SIZE_BYTES;
buffsize = DEFAULT_BUFFER_SIZE_BYTES;
}
} else {
comm->buffSize = DEFAULT_BUFFER_SIZE_BYTES;
buffsize = DEFAULT_BUFFER_SIZE_BYTES;
}
comm->buffSize = buffsize;
INFO("rank %d using buffSize = %lu", rank, comm->buffSize);
@ -583,7 +622,14 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId*
commFree(comm);
return res;
}
if (cudaMalloc(&comm->devUserFromRing, ndev*sizeof(int)) != cudaSuccess) {
if (cudaMalloc(&comm->devRing, sizeof(DevRing<char>)) != cudaSuccess) {
WARN("rank %d failed to allocate device-side ring views", rank);
commFree(comm);
return ncclCudaMallocFailed;
}
if (cudaMalloc(&comm->devUserFromRing, ndev*sizeof(int)) != cudaSuccess ) {
WARN("rank %d failed to allocated device maps", rank);
commFree(comm);
return ncclCudaMallocFailed;
@ -596,21 +642,18 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId*
return ncclSystemError;
}
comm->ringFromUser = (int*)malloc(ndev*sizeof(int));
if (comm->ringFromUser == NULL) {
comm->ncclFromRing = (int*)malloc(ndev*sizeof(int));
if (comm->ncclFromRing == NULL) {
WARN("rank %d failed to allocate host maps", rank);
commFree(comm);
return ncclSystemError;
}
EventQueue* eq = &comm->events;
for(int i=0; i<MAXQUEUE; ++i) {
if (cudaEventCreateWithFlags(eq->isDone+i, cudaEventDisableTiming) != cudaSuccess) {
WARN("rank %d failed to create nccl event %d", rank, i);
if (cudaEventCreateWithFlags(&comm->doneEvent, cudaEventDisableTiming) != cudaSuccess) {
WARN("ncclComm on rank %d failed to create doneEvent", rank);
commFree(comm);
return ncclUnhandledCudaError;
}
}
if(commId == NULL) {
comm->hostMemState = 0;
@ -627,10 +670,46 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId*
comm->hostMemState = ShmMapped | ShmLinked;
}
if (cudaHostGetDevicePointer(&comm->opCounter, &comm->hostMem->opCounter, 0) != cudaSuccess) {
WARN("ncclComm on rank %d failed to map opCounter to device", rank);
commFree(comm);
return ncclUnhandledCudaError;
}
*comret = comm;
return ncclSuccess;
}
static ncclResult_t devCommUpdate(ncclComm_t comm) {
// Copy the comm on the device
size_t commBytes = offsetof(ncclComm, ptrs) + comm->nRanks*sizeof(NodeRef);
if (cudaMemcpy(comm->devComm, comm, commBytes, cudaMemcpyHostToDevice) != cudaSuccess) {
WARN("failed to copy device comm");
return ncclUnhandledCudaError;
}
// Fix the host pointer to be accessible from the device
void* dptr;
if (cudaHostGetDevicePointer(&dptr, comm->hostMem, 0) != cudaSuccess) {
WARN("failed to get device pointer for host mem");
return ncclUnhandledCudaError;
}
if (cudaMemcpy(&comm->devComm->hostMem, &dptr, sizeof(dptr), cudaMemcpyHostToDevice) != cudaSuccess) {
WARN("failed to update host pointer");
return ncclUnhandledCudaError;
}
return ncclSuccess;
}
static ncclResult_t devCommSetup(ncclComm_t comm) {
// Fully duplicate the comm on the device
size_t commBytes = offsetof(ncclComm, ptrs) + comm->nRanks*sizeof(NodeRef);
if (cudaMalloc(&comm->devComm, commBytes) != cudaSuccess) {
WARN("failed to allocated device comm");
return ncclCudaMallocFailed;
}
return devCommUpdate(comm);
}
static ncclResult_t commUnlinkHostMem(ncclComm_t comm, ncclUniqueId commId, int rank) {
char rankname[1024];
sprintf(rankname, "%s-%d", commId.internal, rank);
@ -643,12 +722,12 @@ static void showVersion() {
static int shown = 0;
if (shown == 0 && ncclDebugLevel >= VERSION) {
printf("NCCL version %d.%d.%d compiled with CUDA %d.%d\n", NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH, CUDA_MAJOR, CUDA_MINOR);
fflush(stdout); \
fflush(stdout);
shown = 1;
}
}
extern "C" DSOGLOBAL
NCCL_API(ncclResult_t, ncclCommInitRank, ncclComm_t* newcomm, int ndev, ncclUniqueId commId, int myrank);
ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId, int myrank) {
if (myrank == 0) showVersion();
@ -693,14 +772,14 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId
goto cleanup;
}
res = commBuildMaps(*newcomm, &commId, myrank, gath->ranks, &gath->ringDirectFail);
res = commBuildMaps(*newcomm, &commId, myrank, gath->ranks, &gath->globalMemSpaceBroke);
if (res != ncclSuccess) {
WARN("rank %d failed to build comm maps", myrank);
goto cleanup;
}
syncRingDirect(gath, &((*newcomm)->useRemoteRecv));
INFO("PushToRecv algos are %s\n", (*newcomm)->useRemoteRecv ? "enabled" : "disabled");
syncRingDirect(gath, &((*newcomm)->globalMemSpace));
INFO("Global device memory space is %s", (*newcomm)->globalMemSpace ? "enabled" : "disabled");
res = closeGather(gath, ndev); // includes a barrier
gath = NULL;
@ -709,6 +788,13 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId
goto cleanup;
}
res = devCommSetup(*newcomm);
if (res != ncclSuccess) {
WARN("rank %d failed to copy dcomm", myrank);
goto cleanup;
}
res = ncclSuccess;
goto final;
cleanup:
@ -727,7 +813,7 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId
return res;
}
extern "C" DSOGLOBAL
NCCL_API(ncclResult_t, ncclCommInitAll, ncclComm_t* comms, int ndev, const int* devlist);
ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
initDebug();
@ -741,7 +827,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
char busId[13];
nvmlDevice_t nvmlHandle;
int affinity_set = 0;
int ringDirectFail = 0; // Assume direct access to recv ptr OK
int globalMemSpaceBroke = 0; // Assume direct access to recv ptr OK
res = wrapSymbols();
if (res != ncclSuccess) {
@ -812,16 +898,24 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
for(rank=0; rank<ndev; ++rank) {
comm = comms[rank];
cudaSetDevice(comm->cudaDev);
res = commBuildMaps(comm, NULL, rank, ranks, &ringDirectFail);
res = commBuildMaps(comm, NULL, rank, ranks, &globalMemSpaceBroke);
if (res != ncclSuccess) {
WARN("rank %d failed to build comm maps", rank);
goto cleanup;
}
}
INFO("PushToRecv algos are %s\n", (ringDirectFail) ? "disabled" : "enabled");
INFO("Global device memory space is %s", (globalMemSpaceBroke) ? "disabled" : "enabled");
for(rank=0; rank<ndev; ++rank) {
comms[rank]->useRemoteRecv = ringDirectFail ? 0 : 1;
comms[rank]->globalMemSpace = globalMemSpaceBroke ? 0 : 1;
}
for(rank=0; rank<ndev; ++rank) {
res = devCommSetup(comms[rank]);
if (res != ncclSuccess) {
WARN("rank %d failed to copy dcomm", rank);
goto cleanup;
}
}
free(ranks);
@ -845,8 +939,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
return res;
}
extern "C" DSOGLOBAL
NCCL_API(void, ncclCommDestroy, ncclComm_t comm);
void ncclCommDestroy(ncclComm_t comm) {
if (comm == NULL)
return;
@ -865,7 +958,7 @@ void ncclCommDestroy(ncclComm_t comm) {
cudaSetDevice(savedDevice);
}
extern "C" DSOGLOBAL
NCCL_API(const char*, ncclGetErrorString, ncclResult_t code);
const char* ncclGetErrorString(ncclResult_t code) {
switch (code) {
case ncclSuccess : return "no error";
@ -887,21 +980,21 @@ const char* ncclGetErrorString(ncclResult_t code) {
return "unknown result code";
}
extern "C" DSOGLOBAL
NCCL_API(ncclResult_t, ncclCommCount, const ncclComm_t comm, int* count);
ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) {
*count = comm->nDev;
*count = comm->nRanks;
return ncclSuccess;
}
extern "C" DSOGLOBAL
NCCL_API(ncclResult_t, ncclCommCuDevice, const ncclComm_t comm, int* devid);
ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* devid) {
*devid = comm->cudaDev;
return ncclSuccess;
}
extern "C" DSOGLOBAL
NCCL_API(ncclResult_t, ncclCommUserRank, const ncclComm_t comm, int* rank);
ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank) {
*rank = comm->userFromRing[comm->ncclId];
*rank = comm->rank;
return ncclSuccess;
}
View File
@ -1,19 +1,17 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
#ifndef CORE_H_
#define CORE_H_
#include "nccl.h"
#include <cstdio>
#include <cuda_runtime.h>
#define MAXFLAGS 8
#define MAXQUEUE 4 // Maximum number of queued collectives per communicator.
#define DEFAULT_BUFFER_SIZE_BYTES (1UL << 21)
// DIE on error
#define CUDACHECK(cmd) do { \
@ -25,55 +23,78 @@
} \
} while(false)
#define NCCL_MEM_PAD_ALIGN 4096
typedef struct {
cudaEvent_t isDone[MAXQUEUE];
int back; // Last event used
} EventQueue;
#define MAXRANKS 32
#define DEFAULT_BUFFER_SIZE_BYTES (1UL << 21)
#define NCCL_MEM_PAD_ALIGN 65536
struct ncclMem {
union { // Pad this block so that devBuff is correctly aligned
struct {
int flags[MAXFLAGS];
void* recvPtrs[MAXFLAGS];
int flags[2];
void* recvPtrs;
int opCounter; // Used to determine when remote Communicators are ready.
// Only used in host memory.
};
char pad[NCCL_MEM_PAD_ALIGN];
};
// devBuff will likely be bigger ; we only use its offset/address.
char buff[NCCL_MEM_PAD_ALIGN];
// devBuff will be bigger ; we only use its offset/address.
char buff[1];
};
struct ncclNodeRef {
ncclMem* remote;
ncclMem* local;
int remoteCleanup;
void* cleanupHandle;
template <typename T>
struct alignas(long long) DevRing {
volatile int* __restrict__ prevOpCounter;
volatile int* __restrict__ nextOpCounter;
volatile int* __restrict__ sendFlagToNext;
volatile int* __restrict__ sendFlagToPrev;
volatile int* __restrict__ recvFlagFromNext;
volatile int* __restrict__ recvFlagFromPrev;
T* volatile * __restrict__ recvPtrFromNext;
T* volatile * __restrict__ sendPtrToPrev;
T* __restrict__ recvBuffer;
T* __restrict__ sendBuffer;
int userRank[MAXRANKS];
};
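As an aid to reading the fields above, a sketch of how they pair up between ring neighbors, derived from the ring setup in core.cu (added comment, not in the source):
// this rank's sendFlagToNext  aliases the next rank's recvFlagFromPrev (its flags[0])
// this rank's sendFlagToPrev  aliases the prev rank's recvFlagFromNext (its flags[1])
// this rank's sendBuffer      aliases the next rank's recvBuffer
// this rank's sendPtrToPrev   aliases the prev rank's recvPtrFromNext
// so each ring direction has its own flag slot in the receiver's ncclMem.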
struct NodeRef {
ncclMem* remote; // TODO: Verify if these
ncclMem* local; // are still needed.
enum {DEVICE, HOST} type;
ncclMem* devCleanup; // Used only when remote comm uses same process & GPU
ncclMem* hostCleanup; // Used whenever target is in different process
int* opCounter; // TODO: see if this can be removed too.
};
struct ncclComm {
int nDev; // number of devices in communicator
int cudaDev; // cuda device index
int ncclId; // nccl logical index
int rank; // my rank in the communicator
int nRanks; // number of GPUs in communicator
int cudaDev; // my cuda device index
// Device and Host allocated chunks. Stored here to correctly free() memory.
ncclMem* devMem;
ncclMem* hostMem;
int hostMemState;
int opSched; // Scheduling operation index
int* opCounter; // Counter of completed operations
// Placed between calling and internal device streams.
EventQueue events;
cudaStream_t prevStream; // cache last used stream
cudaEvent_t doneEvent; // orders operations in different streams
// Maps an internal nccl index to user-specified rank order. This is necessary
// since we need to know how the user expects data to be ordered across
// devices.
// devices. Ordered from current device.
int* userFromRing;
// copy of the above stored on each device
int* devUserFromRing;
// Inverse of userFromRing. Maps user specified index to internal nccl index.
int* ringFromUser;
// Ring order
int* ncclFromRing; // TODO: REMOVE IF NOT NEEDED BEYOND CORE.CU
// Size of temp buffer in bytes.
size_t buffSize;
@ -81,13 +102,20 @@ struct ncclComm {
// Whether we have remote access to the recvbuff pointers passed from remote
// GPUs. In single process mode this can be used as long as QPI links are
// not present. In multi-process, we never push to a remote recvbuff.
int useRemoteRecv;
int globalMemSpace;
// Device copy of the communicator
struct ncclComm *devComm; // TODO: Remove this if not useful
// Device-side ring view
DevRing<char>* devRing;
// Device-to-device communication structures to access remote or local device
// memory. Actual allocation larger than 1.
ncclNodeRef ptrs[1];
NodeRef ptrs[1];
};
typedef enum {NONE=0, VERSION=1, WARN=2, INFO=3, ABORT=4} DebugLevel;
extern DebugLevel ncclDebugLevel;
@ -96,6 +124,7 @@ extern DebugLevel ncclDebugLevel;
printf("WARN %s:%d ", __FILE__, __LINE__); \
printf(__VA_ARGS__); \
printf("\n"); \
fflush(stdout); \
if (ncclDebugLevel >= ABORT) abort(); \
} \
} while(0)
@ -103,10 +132,26 @@ extern DebugLevel ncclDebugLevel;
#define INFO(...) do { \
if (ncclDebugLevel >= INFO) { \
printf("INFO "); printf(__VA_ARGS__); printf("\n"); \
fflush(stdout); \
} \
} while(0)
#define DSOGLOBAL __attribute__((visibility("default")))
#ifdef PROFAPI
#define NCCL_API(ret, func, args...) \
__attribute__ ((visibility("default"))) \
__attribute__ ((alias(#func))) \
ret p##func (args); \
extern "C" \
__attribute__ ((visibility("default"))) \
__attribute__ ((weak)) \
ret func(args)
#else
#define NCCL_API(ret, func, args...) \
extern "C" \
__attribute__ ((visibility("default"))) \
ret func(args)
#endif // end PROFAPI
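For clarity, a hand-written sketch of what the PROFAPI branch expands to for one entry point (the non-PROFAPI branch reduces to a plain extern "C" declaration):
// NCCL_API(ncclResult_t, ncclCommCount, const ncclComm_t comm, int* count);
// becomes, roughly:
//   __attribute__ ((visibility("default")))
//   __attribute__ ((alias("ncclCommCount")))
//   ncclResult_t pncclCommCount(const ncclComm_t comm, int* count);
//   extern "C"
//   __attribute__ ((visibility("default")))
//   __attribute__ ((weak))
//   ncclResult_t ncclCommCount(const ncclComm_t comm, int* count)
// so a profiling layer can interpose the weak ncclCommCount while
// pncclCommCount still resolves to the library implementation.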
#endif // end include guard
View File
@ -1,31 +1,90 @@
/*************************************************************************
* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
#ifndef enqueue_h_
#define enqueue_h_
#include "core.h"
#include "reduce_kernel.h"
int getRingIndex(const ncclComm_t comm, int device);
void lockEventQueue(EventQueue* eq);
void releaseEventQueue(EventQueue* eq);
void CUDART_CB freeEvent(cudaStream_t stream, cudaError_t status, void* eq_void);
/* Synchronize previous collective (if in different stream) and enqueue
* collective. Work is performed asynchronously with the host thread.
* The ColFunc class should be templated on the datatype and reduction
* operator (if applicable) and define a static entry() method as
* follows.
* template <typename T, template <typename> class RedOp>
* class CollectiveFunctor {
* public:
* static ncclResult_t entry(const void* sendbuff, void* recvbuff, int count,
* int root, ncclComm* comm, cudaStream_t stream);
* };
* The entry() method can assume that the appropriate cuda device has been set. */
template< template<typename, template<typename> class> class ColFunc,
typename T,
template<typename> class Op >
ncclResult_t enqueue(const void* sendbuff,
void* recvbuff,
int count,
int root,
ncclComm_t comm,
cudaStream_t stream)
{
if (stream != comm->prevStream) { // sync required for calls in different streams
comm->prevStream = stream;
CUDACHECK( cudaStreamWaitEvent(stream, comm->doneEvent, 0) );
}
/* Syncronize with user stream and launch the collective.
* All work is performed asynchronously with the host thread.
* The actual collective should be a functor with the
* folloaing signature.
* ncclResult_t collective(void* sendbuff, void* recvbuff,
* int count, ncclDataType_t type, ncclRedOp_t op,
* int root, ncclComm_t comm);
* Unneeded arguments should be ignored. The collective may
* assume that the appropriate cuda device has been set. */
template<typename ColFunc>
ncclResult_t enqueue(ColFunc colfunc,
const void* sendbuff,
ncclResult_t ret;
ret = ColFunc<T, Op>::entry(sendbuff, recvbuff, count, root, comm, stream);
// Always have to record done event because we don't know what stream next
// collective will be in.
CUDACHECK( cudaEventRecord(comm->doneEvent, stream) );
comm->opSched += 1;
return ret;
}
// This version decodes type
template< template<typename, template<typename> class> class ColFunc,
template<typename> class Op >
ncclResult_t enqueue(const void* sendbuff,
void* recvbuff,
int count,
ncclDataType_t type,
int root,
ncclComm_t comm,
cudaStream_t stream)
{
switch(type) {
case ncclChar:
return enqueue<ColFunc, char, Op>(sendbuff, recvbuff, count, root, comm, stream);
case ncclInt:
return enqueue<ColFunc, int, Op>(sendbuff, recvbuff, count, root, comm, stream);
#ifdef CUDA_HAS_HALF
case ncclHalf:
return enqueue<ColFunc, half, Op>(sendbuff, recvbuff, count, root, comm, stream);
#endif
case ncclFloat:
return enqueue<ColFunc, float, Op>(sendbuff, recvbuff, count, root, comm, stream);
case ncclDouble:
return enqueue<ColFunc, double, Op>(sendbuff, recvbuff, count, root, comm, stream);
case ncclInt64:
return enqueue<ColFunc, long long, Op>(sendbuff, recvbuff, count, root, comm, stream);
case ncclUint64:
return enqueue<ColFunc, unsigned long long, Op>(sendbuff, recvbuff, count, root, comm, stream);
default:
WARN("Invalid ncclType %d", type);
return ncclInvalidType;
}
}
// This version decodes both type and reduction op
template< template<typename, template<typename> class> class ColFunc>
ncclResult_t enqueue(const void* sendbuff,
void* recvbuff,
int count,
ncclDataType_t type,
@ -34,24 +93,19 @@ ncclResult_t enqueue(ColFunc colfunc,
ncclComm_t comm,
cudaStream_t stream)
{
int curDevice;
CUDACHECK( cudaGetDevice(&curDevice) );
// No need for a mutex here because we assume that all enqueue operations happen in a fixed
// order on all devices. Thus, thread race conditions SHOULD be impossible.
EventQueue* eq = &comm->events;
// Ensure that previous collective is complete
cudaError_t flag = cudaEventQuery(eq->isDone[eq->back]);
if( flag == cudaErrorNotReady )
CUDACHECK( cudaStreamWaitEvent(stream, eq->isDone[eq->back], 0) );
// Launch the collective here
ncclResult_t ret = colfunc(sendbuff, recvbuff, count, type, op, root, comm, stream);
eq->back = (eq->back + 1) % MAXQUEUE;
CUDACHECK( cudaEventRecord(eq->isDone[eq->back], stream) );
return ret;
switch(op) {
case ncclSum:
return enqueue<ColFunc, FuncSum>(sendbuff, recvbuff, count, type, root, comm, stream);
case ncclProd:
return enqueue<ColFunc, FuncProd>(sendbuff, recvbuff, count, type, root, comm, stream);
case ncclMax:
return enqueue<ColFunc, FuncMax>(sendbuff, recvbuff, count, type, root, comm, stream);
case ncclMin:
return enqueue<ColFunc, FuncMin>(sendbuff, recvbuff, count, type, root, comm, stream);
default:
WARN("Invalid ncclRedOp: %d", op);
return ncclInvalidOperation;
}
}
#endif // End include guard
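A minimal sketch of how a collective plugs into the templates above, mirroring the Broadcast wiring earlier in this commit; ReduceColl and RingReduce follow that pattern and are illustrative rather than quoted from the diff.
template<typename T, template<typename> class RedOp>
class ReduceColl {
 public:
  static ncclResult_t entry(const void* sendbuff, void* recvbuff,
      int count, int root, ncclComm* comm, cudaStream_t stream) {
    // RingReduce is the per-type ring implementation (cf. reduce.cu below).
    return RingReduce<RedOp<T>, T>(sendbuff, recvbuff, count, root, comm, stream);
  }
};
NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, int count,
    ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, int count,
    ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
  // The enqueue overloads decode op and datatype, then call ReduceColl::entry().
  return enqueue<ReduceColl>(sendbuff, recvbuff, count, datatype, op, root, comm, stream);
}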
View File
@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
#include "libwrap.h"
@ -25,7 +25,6 @@ ncclResult_t wrapSymbols(void) {
return ncclSuccess;
static void* nvmlhandle = NULL;
static void* cuhandle = NULL;
void* tmp;
void** cast;
@ -38,20 +37,11 @@ ncclResult_t wrapSymbols(void) {
}
}
cuhandle = dlopen("libcuda.so", RTLD_NOW);
if (!cuhandle) {
cuhandle = dlopen("libcuda.so.1", RTLD_NOW);
if (!cuhandle) {
WARN("Failed to open libcuda.so[.1]");
goto teardown;
}
}
#define LOAD_SYM(handle, symbol, funcptr) do { \
cast = (void**)&funcptr; \
tmp = dlsym(handle, symbol); \
if (tmp == NULL) { \
WARN("dlsym failed on %s - %s", symbol, dlerror()); \
WARN("dlsym failed on %s - %s", symbol, dlerror());\
goto teardown; \
} \
*cast = tmp; \
@ -76,7 +66,6 @@ ncclResult_t wrapSymbols(void) {
nvmlInternalDeviceSetCpuAffinity = NULL;
nvmlInternalDeviceClearCpuAffinity = NULL;
if (cuhandle != NULL) dlclose(cuhandle);
if (nvmlhandle != NULL) dlclose(nvmlhandle);
return ncclSystemError;
}
@ -84,7 +73,7 @@ ncclResult_t wrapSymbols(void) {
ncclResult_t wrapNvmlInit(void) {
if (nvmlInternalInit == NULL) {
WARN("lib wrapper not initilaized.");
WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalInit();
@ -98,7 +87,7 @@ ncclResult_t wrapNvmlInit(void) {
ncclResult_t wrapNvmlShutdown(void) {
if (nvmlInternalShutdown == NULL) {
WARN("lib wrapper not initilaized.");
WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalShutdown();
@ -112,7 +101,7 @@ ncclResult_t wrapNvmlShutdown(void) {
ncclResult_t wrapNvmlDeviceGetHandleByPciBusId(const char* pciBusId, nvmlDevice_t* device) {
if (nvmlInternalDeviceGetHandleByPciBusId == NULL) {
WARN("lib wrapper not initilaized.");
WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalDeviceGetHandleByPciBusId(pciBusId, device);
@ -126,7 +115,7 @@ ncclResult_t wrapNvmlDeviceGetHandleByPciBusId(const char* pciBusId, nvmlDevice_
ncclResult_t wrapNvmlDeviceGetIndex(nvmlDevice_t device, unsigned* index) {
if (nvmlInternalDeviceGetIndex == NULL) {
WARN("lib wrapper not initilaized.");
WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalDeviceGetIndex(device, index);
@ -140,7 +129,7 @@ ncclResult_t wrapNvmlDeviceGetIndex(nvmlDevice_t device, unsigned* index) {
ncclResult_t wrapNvmlDeviceSetCpuAffinity(nvmlDevice_t device) {
if (nvmlInternalDeviceSetCpuAffinity == NULL) {
WARN("lib wrapper not initilaized.");
WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalDeviceSetCpuAffinity(device);
@ -154,7 +143,7 @@ ncclResult_t wrapNvmlDeviceSetCpuAffinity(nvmlDevice_t device) {
ncclResult_t wrapNvmlDeviceClearCpuAffinity(nvmlDevice_t device) {
if (nvmlInternalInit == NULL) {
WARN("lib wrapper not initilaized.");
WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalDeviceClearCpuAffinity(device);
@ -165,3 +154,4 @@ ncclResult_t wrapNvmlDeviceClearCpuAffinity(nvmlDevice_t device) {
}
return ncclSuccess;
}
View File
@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
@ -14,6 +14,15 @@
typedef struct nvmlDevice_st* nvmlDevice_t;
/**
* Generic enable/disable enum.
*/
typedef enum nvmlEnableState_enum
{
NVML_FEATURE_DISABLED = 0, //!< Feature disabled
NVML_FEATURE_ENABLED = 1 //!< Feature enabled
} nvmlEnableState_t;
ncclResult_t wrapSymbols(void);
ncclResult_t wrapNvmlInit(void);
@ -22,6 +31,7 @@ ncclResult_t wrapNvmlDeviceGetHandleByPciBusId(const char* pciBusId, nvmlDevice_
ncclResult_t wrapNvmlDeviceGetIndex(nvmlDevice_t device, unsigned* index);
ncclResult_t wrapNvmlDeviceSetCpuAffinity(nvmlDevice_t device);
ncclResult_t wrapNvmlDeviceClearCpuAffinity(nvmlDevice_t device);
ncclResult_t wrapNvmlDeviceGetHandleByIndex(unsigned int index, nvmlDevice_t *device);
#endif // End include guard
206
src/primitives.h Normal file
View File
@ -0,0 +1,206 @@
/*************************************************************************
* Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#ifndef PRIMITIVES_H_
#define PRIMITIVES_H_
#include <type_traits>
#include "copy_kernel.h" // for FuncPassA
#include "reduce_kernel.h" // for reduction funcs
/* Defines primitive operations: Copy, Reduce, DoubleCopy, and ReduceCopy.
*
* In order to reduce the repetition of template arguments, the operations
* are bundled as static methods of the Primitives class.
*
* Each primitive operation copies/reduces a contiguous buffer and syncs
* an optional set of flags against a sub-step counter. The sync value is
* based on the step parameter. Sync flags must be of type WaitFlag or
* PostFlag. The primitive routines wait for all WaitFlag args to attain
* at least a value of SUBSTEPS*(step-1)+substep+1 (i.e. completion of
* corresponding substep by previous step) before executing the transfer.
* After each substep is transferred, all PostFlag arguments get updated to
* the value SUBSTEPS*step+substep+1.
*/
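To make the numbers concrete, a small self-contained check of the wait/post values described above (host-side sketch assuming SUBSTEPS == 2; not part of the library):
#include <cassert>
static int waitVal(int step, int substep) { return 2*(step-1) + substep + 1; }
static int postVal(int step, int substep) { return 2*step + substep + 1; }
int main() {
  // What one step posts is exactly what the next step's waiters require.
  for (int step = 1; step < 8; ++step)
    for (int sub = 0; sub < 2; ++sub)
      assert(postVal(step, sub) == waitVal(step + 1, sub));
  return 0;
}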
class WaitFlag {
volatile int * const flag;
const int shift;
public:
__device__ __forceinline__
WaitFlag(volatile int * const flag, const int shift) : flag(flag), shift(shift) { }
__device__ __forceinline__
void wait(int val) { while (*flag < (val + shift)) /*SPIN*/; }
};
class PostFlag {
volatile int * const flag;
const int shift;
public:
__device__ __forceinline__
PostFlag(volatile int* const flag, const int shift) : flag(flag), shift(shift) { }
__device__ __forceinline__
void post(int val) { *flag = (val + shift); }
};
// Helper to check if any argument is of type T.
// e.g. AnyAre<WaitFlag>(Flag1, Flag2, ...)
template<typename T> __device__ __forceinline__
bool AnyAre() { return false; }
template<typename T, typename FIRST_T, typename... TAIL_Ts>
__device__ __forceinline__
bool AnyAre(FIRST_T first, TAIL_Ts... tail) {
return std::is_same<T, FIRST_T>::value || AnyAre<T>(tail...);
}
// Wait on all WaitFlags, ignore PostFlags
__device__ __forceinline__
void WaitOnFlags(int val) { }
template <typename... TAIL_Ts> __device__ __forceinline__
void WaitOnFlags(int val, WaitFlag flag, TAIL_Ts... tail) {
flag.wait(val);
WaitOnFlags(val, tail...);
}
template <typename... TAIL_Ts> __device__ __forceinline__
void WaitOnFlags(int val, PostFlag, TAIL_Ts... tail) {
WaitOnFlags(val, tail...);
}
// Post all PostFlags, ignore WaitFlags
__device__ __forceinline__
void PostToFlags(int val) { }
template <typename... TAIL_Ts> __device__ __forceinline__
void PostToFlags(int val, WaitFlag flag, TAIL_Ts... tail) {
PostToFlags(val, tail...);
}
template <typename... TAIL_Ts> __device__ __forceinline__
void PostToFlags(int val, PostFlag flag, TAIL_Ts... tail) {
flag.post(val);
PostToFlags(val, tail...);
}
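A short note on how the two helpers above dispatch on flag type (illustrative call):
// With (WaitFlag a, PostFlag b, WaitFlag c):
//   WaitOnFlags(v, a, b, c) spins on a and c and skips b;
//   PostToFlags(v, a, b, c) writes v to b and skips a and c.
// The recursion peels one argument per call and the overload chosen for the
// head argument decides whether it is acted on.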
// Create pointer arithmetic syntax that doesn't break for nullptr_t
template <typename Tptr> __device__ __forceinline__
Tptr ptradd(Tptr ptr, int i) {
return ptr + i;
}
__device__ __forceinline__
nullptr_t ptradd(nullptr_t ptr, int i) {
return nullptr;
}
// Implementation of primitive types
template <int THREADS, int UNROLL, int SUBSTEPS, typename T, typename REDOP=FuncSum<T> >
class Primitives {
private:
template <typename SRC2_T, // either T* or nullptr_t
typename DST2_T, // either T* or nullptr_t
typename... SYNC_Ts> // either WaitFunc or PostFunc
static __device__ __forceinline__ void
GenericOp(const T* src1,
const SRC2_T src2,
T* dst1,
DST2_T dst2,
int len, int maxoffset, int step, SYNC_Ts... flags) {
enum { noSrc2 = std::is_same<SRC2_T, nullptr_t>::value };
enum { noDst2 = std::is_same<DST2_T, nullptr_t>::value };
static_assert(noSrc2 || std::is_same<SRC2_T, const T*>::value,
"src2 must be of type T* or nullptr_t");
static_assert(noDst2 || std::is_same<DST2_T, T*>::value,
"dst2 must be of type T* or nullptr_t");
using OpType = typename std::conditional<noSrc2, FuncPassA<T>, REDOP>::type;
if (threadIdx.x < THREADS) {
int sliceSize = len / SUBSTEPS;
int sliceOffset = 0;
#pragma unroll 1
for (int sub=0; sub<SUBSTEPS; ++sub) {
if (AnyAre<WaitFlag>(flags...)) {
if (threadIdx.x == 0) {
WaitOnFlags(SUBSTEPS*step + sub + 1, flags...);
}
asm volatile ("bar.sync 1, %0;" :: "r"(THREADS));
}
ReduceOrCopy
<
UNROLL,
THREADS,
OpType,
T,
!std::is_same<DST2_T, nullptr_t>::value, // HAS_DEST1
!std::is_same<SRC2_T, nullptr_t>::value // HAS_SRC1
>
(
threadIdx.x,
ptradd(dst1, sliceOffset),
ptradd(dst2, sliceOffset),
ptradd(src1, sliceOffset),
ptradd(src2, sliceOffset),
min(sliceSize, maxoffset-sliceOffset)
);
if (AnyAre<PostFlag>(flags...)) {
__syncthreads();
}
sliceOffset += sliceSize;
}
} else {
for(int sub=0; sub<SUBSTEPS; ++sub) {
if (AnyAre<PostFlag>(flags...)) {
__syncthreads();
__threadfence_system();
PostToFlags(SUBSTEPS*step + sub + 1, flags...);
}
}
}
}
public:
template <typename... SYNC_Ts>
static __device__ __forceinline__ void
Copy(const T* src, T* dst,
int len, int maxOffset, int step, SYNC_Ts... flags) {
GenericOp(src, nullptr, dst, nullptr, len, maxOffset, step, flags...);
}
template <typename... SYNC_Ts>
static __device__ __forceinline__ void
DoubleCopy(const T* src, T* dst1, T* dst2,
int len, int maxOffset, int step, SYNC_Ts... flags) {
GenericOp(src, nullptr, dst1, dst2, len, maxOffset, step, flags...);
}
template <typename... SYNC_Ts>
static __device__ __forceinline__ void
Reduce(const T* src1, const T* src2, T* dst,
int len, int maxOffset, int step, SYNC_Ts... flags) {
GenericOp(src1, src2, dst, nullptr, len, maxOffset, step, flags...);
}
template <typename... SYNC_Ts>
static __device__ __forceinline__ void
ReduceCopy(const T* src1, const T* src2, T* dst1, T* dst2,
int len, int maxOffset, int step, SYNC_Ts... flags) {
GenericOp(src1, src2, dst1, dst2, len, maxOffset, step, flags...);
}
};
#endif // end include guard
View File
@ -1,393 +1,150 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
#include <algorithm>
#include "core.h"
#include "common_kernel.h"
#include "copy_kernel.h"
#include "enqueue.h"
#include "reduce_kernel.h"
#include "primitives.h"
/* HIERARCHY
*
* The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
* SUBCHUNKS, where each SUBCHUNK is processed independently. A SUBCHUNK is
* split into numUnroll UNROLLS and each thread performs UNROLL_COUNT
* single-data-element operations inside an UNROLL. As the name suggests, the
* UNROLL_COUNT operations within an UNROLL are unrolled.
*/
#define NUM_SUBSTEPS 2
#define NUM_BUFCHUNKS 2
// Number of threads used to perform copies, etc. Must be multiple of 32.
// An additional thread is used to handle threadfences, so the CUDA blocks
// have dimension NUM_THREADS+1.
#define NUM_THREADS 256
// Increase Step and boffset for buffer sync
#define NEXT_STEP \
step++; \
boffset += sliceSize; \
if (boffset == buffSize) boffset = 0;
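A short trace of the macro above, added as a comment (NUM_BUFCHUNKS == 2 as defined in this file, so sliceSize is half of buffSize):
// Successive steps ping-pong between the two halves of the staging buffer:
//   step:    0 -> 1 -> 2 -> 3 ...
//   boffset: 0 -> sliceSize -> 0 -> sliceSize ...
// The negative shift on waitDoneFromNext (declared below) is what allows the
// sender to run one buffer chunk ahead of its neighbor.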
// Each thread unrolls the innermost loop of the copy or reduction operations
// to this many single-data-element instructions
#define UNROLL_COUNT 8
#define ALIGN_SIZE(size, align) \
size = ((size + (align) - 1) / (align)) * (align);
#define UNROLL_SIZE (UNROLL_COUNT * NUM_THREADS)
template<int THREADS, int UNROLL, class FUNC, typename T>
__launch_bounds__(THREADS+WARP_SIZE, 1)
__global__ void ReduceKernel(const KernelArgs<T> args) {
const int tid = threadIdx.x;
__shared__ DevRing<T> ring;
// To hide the latency associated with the synchronization between different
// subchunks, we interleave the independent subchunks so that more data can be
// transferred while the sync is in progress. This is the number of subchunks
// that are active at the same time
#define NUM_SUBCHUNKS 4
LoadRing<THREADS>(args.ring, &ring);
__syncthreads();
// if this is called with CHUNK, it means that we just finished pushing the data
// of chunk CHUNK to the next GPU, so it can proceed with CHUNK
// We add 1 to chunk so that the initial flag of 0 doesn't allow the non-root
// GPUs to proceed before the flag is incremented from the upstream GPU. This
// is called by one particular consumer warp and so we select the first thread
// in the warp to set the flag.
#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk) \
do { \
__threadfence_system(); \
args.NextNewDataAvailableFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
} while (0)
if (tid == 0) {
WaitFlag prevCommOp(ring.prevOpCounter, 0);
WaitFlag nextCommOp(ring.nextOpCounter, 0);
prevCommOp.wait(args.opIndex);
nextCommOp.wait(args.opIndex);
}
__syncthreads();
// This is called by all producer threads, but only thread 0 spins on the flag,
#define WAIT_FOR_NEW_DATA(chunk, subchunk) \
do { \
if (tid == 0) { \
Wait([=] { \
return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
WaitFlag waitDoneFromNext(ring.recvFlagFromNext, (1-NUM_BUFCHUNKS)*NUM_SUBSTEPS);
WaitFlag waitReadyFromPrev(ring.recvFlagFromPrev, 0);
PostFlag postDoneToPrev(ring.sendFlagToPrev, 0);
PostFlag postReadyToNext(ring.sendFlagToNext, 0);
// If this is called with CHUNK, it means that this GPU has just finished
// processing the chunk CHUNK and so the previous GPU can start with CHUNK + 1
#define SIGNAL_CHUNK_DONE(chunk, subchunk) \
do { \
args.PrevChunkDoneFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
} while (0)
typedef Primitives<THREADS, UNROLL, NUM_SUBSTEPS, T, FUNC> Prims;
// This is called by all producer threads, but only thread 0 spins on the flag,
// all threads synchronize after thread 0 is done spinning.
#define WAIT_FOR_CHUNK(chunk, subchunk) \
do { \
if (tid == 0) { \
Wait([=] { \
return ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
NUM_SUBCHUNKS*(chunk) + subchunk + 1 - NUM_SUBCHUNKS; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
const int size = args.N;
const int nranks = args.nRanks;
const int rank = ring.userRank[0];
const int prevRank = ring.userRank[nranks-1];
const int root = args.root;
const int buffSize = args.buffSize / sizeof(T);
const int sliceSize = buffSize / NUM_BUFCHUNKS;
// This is called by all producer threads, but only thread 0 spins on the flag,
// all threads synchronize after thread 0 is done spinning.
#define WAIT_FOR_NEW_DATA_AND_CHUNK(chunk, subchunk) \
do { \
if (tid == 0) { \
Wait([=] { \
bool newDataAvailable = \
((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
bool chunkDone = \
((volatile int *)args.ThisChunkDoneFlag)[0] >= \
NUM_SUBCHUNKS*(chunk)+subchunk + 1 - NUM_SUBCHUNKS; \
return newDataAvailable && chunkDone; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
int step = 0;
int boffset = 0;
__device__ inline void getSliceSizeAndOffset(int *size, int *offset, int slice,
int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
int smallSliceN, int lastSliceN) {
if (slice < numBigSlices) {
*size = bigSliceN;
*offset = slice * bigSliceN;
// Compute pointers
const T * __restrict__ thisInput = args.ThisInput;
T * __restrict__ thisOutput = args.ThisOutput;
T * __restrict__ prevInput = ring.recvBuffer;
T * __restrict__ nextOutput = ring.sendBuffer;
for (int offset = 0; offset < size; offset += sliceSize) {
int maxOffset = size-offset;
if (prevRank == root) {
Prims::Copy(
thisInput + offset,
nextOutput + boffset,
sliceSize, maxOffset,
step,
waitDoneFromNext,
postReadyToNext);
} else if (rank == root) {
Prims::Reduce(
prevInput + boffset,
thisInput + offset,
thisOutput + offset,
sliceSize, maxOffset,
step,
waitReadyFromPrev,
postDoneToPrev);
} else {
*size = (slice < numBigSlices + numSmallSlices) ? smallSliceN
: ((slice == numSlices - 1) ? lastSliceN : 0);
*offset = numBigSlices * bigSliceN + (slice - numBigSlices) * smallSliceN;
Prims::ReduceCopy(
thisInput + offset,
prevInput + boffset,
thisOutput + offset,
nextOutput + boffset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
}
NEXT_STEP; // Increases step, boffset
}
// if (threadIdx.x == 0)
// printf("[size=%d] [offset=%d] slice=%d numSlices=%d "
// "numBigSlices=%d numSmallSlices=%d bigSliceN=%d smallSliceN=%d "
// "lastSliceN=%d\n", *size, *offset, slice, numSlices, numBigSlices,
// numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
// wait for the last data to be pushed to us
if (tid == 0) {
if (rank != root) {
// Wait for last update from next then reset the flag
waitDoneFromNext.wait(NUM_SUBSTEPS*(step+NUM_BUFCHUNKS-1));
*ring.recvFlagFromNext = 0;
}
if (prevRank != root) {
// reset the flag
*ring.recvFlagFromPrev = 0;
}
incrementOpCounter(&args);
}
}
template<typename T>
struct ReduceKernelArgs {
// general parameters
int ThisId;
int N;
// some pre-computed sizes
int SliceSize;
int ChunkSize;
int NumChunks;
int BufferSliceStride;
T ** ThisPtrToNextData;
T ** PrevPtrToThisData;
// local and remote data
T * __restrict__ Output;
const T * __restrict__ ThisData;
volatile T * __restrict__ ThisBuffer;
volatile T * __restrict__ NextBuffer;
// local and remote flags
volatile int * __restrict__ ThisNewDataAvailableFlag;
volatile int * __restrict__ NextNewDataAvailableFlag;
volatile int * __restrict__ ThisChunkDoneFlag;
volatile int * __restrict__ PrevChunkDoneFlag;
};
__shared__ volatile void * nextData;
enum ReduceRole {BEGIN=0, MIDDLE=1, END=2};
template<int THREADS, int UNROLL, class FUNC, int ROLE, typename T>
__global__ void ReduceKernel(const ReduceKernelArgs<T> args) {
if (args.N == 0) return;
int tid = threadIdx.x;
// First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
// the previous GPU is done with a previous collective operation.
if (tid == 0) {
Wait([=] {
return *((T * volatile *)args.PrevPtrToThisData) == nullptr; // Wait for previous processor to be done
});
*((T * volatile *)args.PrevPtrToThisData) = (T*)args.ThisData; // Tell Previous I'm starting
Wait([=] {
return *((T * volatile *)args.ThisPtrToNextData) != nullptr; // Wait till I've been told next started
});
}
__syncthreads();
for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
// Calculate slice size. For all chunks except (possibly) the last one,
// this will just be args.SliceSize. For the last one, it may be smaller.
int bigSliceN = args.SliceSize;
int smallSliceN = 0;
int lastSliceN = 0;
int numSlices = NUM_SUBCHUNKS;
int numBigSlices = numSlices;
int numSmallSlices = 0;
// last chunk
if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
&numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
args.ChunkSize);
// this offset is only applied to Data pointers, not to Buffer pointers,
// since we only have one buffer per chunk
int chunkOffset = chunk * args.ChunkSize;
int offset;
int sliceSize;
if (tid < THREADS) {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
getSliceSizeAndOffset(&sliceSize, &offset, s, numSlices,
numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
if (ROLE == BEGIN) {
WAIT_FOR_CHUNK(chunk, s);
Copy<UNROLL, THREADS>(
args.NextBuffer + (s * args.BufferSliceStride),
args.ThisData + chunkOffset + offset,
sliceSize);
} else if (ROLE == MIDDLE) {
WAIT_FOR_NEW_DATA_AND_CHUNK(chunk, s);
Reduce<UNROLL, THREADS, FUNC>(
args.NextBuffer + (s * args.BufferSliceStride),
args.ThisData + chunkOffset + offset,
args.ThisBuffer + (s * args.BufferSliceStride),
sliceSize);
} else { // ROLE == END
WAIT_FOR_NEW_DATA(chunk, s);
Reduce<UNROLL, THREADS, FUNC>(
args.Output + chunkOffset + offset,
args.ThisData + chunkOffset + offset,
args.ThisBuffer + (s * args.BufferSliceStride),
sliceSize);
}
__syncthreads();
}
} else { // Consumer thread
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
if (ROLE != END)
SIGNAL_NEW_DATA_AVAILABLE(chunk, s);
// signal chunk done if we don't push into the receive buffer and this
// is not the last chunk and this is not root
if ((ROLE != BEGIN) && (chunk + 1 < args.NumChunks)) {
SIGNAL_CHUNK_DONE(chunk, s);
}
}
}
}
// reset flags
if (tid == 0) {
args.ThisNewDataAvailableFlag[0] = 0;
args.ThisChunkDoneFlag[0] = 0;
*args.ThisPtrToNextData = nullptr;
}
}
#define THREADS 512
#define UNROLL 8
template<class FUNC, typename T>
ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
const int count, const int root, ncclComm* comm, cudaStream_t stream) {
ncclResult_t RingReduce(const void* sendbuff, void* recvbuff, const int count, const int root,
ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
int index = comm->ncclId;
const int numUnroll = 4;
int rootId = comm->ringFromUser[root];
int nextId = (index + 1) % comm->nDev;
int prevId = (index + comm->nDev - 1) % comm->nDev;
// The buffer holds one slice per subchunk, so a slice can be at most
// bufferN / NUM_SUBCHUNKS, where bufferN is the number of elements of type T
// that fit into the buffer.
// For efficiency, we want the slice size to be a multiple of UNROLL_SIZE
int bufferN = comm->buffSize / sizeof(T);
// we only need buffer space for NUM_SUBCHUNKS slices
int bufferNPerSlice = bufferN / NUM_SUBCHUNKS;
int maxSliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;
ReduceKernelArgs<T> args;
args.ThisId = index;
args.N = count;
args.SliceSize = numUnroll * UNROLL_SIZE * sizeof(PackType) / sizeof(T);
if(!comm->useRemoteRecv) {
// Proxy for QPI. Reduce never pushes directly to recv.
// But larger transfers help QPI more than tag updates hurt P2P.
args.SliceSize *= 8;
}
// make sure slice fits into the temporary buffer
args.SliceSize = std::min(maxSliceSize, args.SliceSize);
args.BufferSliceStride = args.SliceSize;
args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
// avoid a case where we have one or more big chunks and one tiny one
int remainder = args.N % args.ChunkSize;
if ((args.N > args.ChunkSize) && (remainder > 0) &&
(args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
args.SliceSize /= 2;
args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
// round down so we end up with a big last chunk
args.NumChunks = args.N / args.ChunkSize;
} else {
// round up
args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
}
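// Worked example with hypothetical sizes: for ChunkSize = 1024 and N = 1100,
// remainder = 76, so the test above passes (1100 > 1024, 1100 < 5120,
// 2*76 < 1024). Halving the slice gives ChunkSize = 512 and
// NumChunks = 1100 / 512 = 2, i.e. a 512-element chunk followed by a
// 588-element last chunk, instead of a 1024-element chunk followed by a tiny
// 76-element one.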
args.ThisPtrToNextData = (T**)&(comm->ptrs[nextId].local->recvPtrs[0]);
args.PrevPtrToThisData = (T**)&(comm->ptrs[prevId].remote->recvPtrs[0]);
args.Output = (T*)recvbuff;
args.ThisData = (const T*) sendbuff;
args.ThisBuffer = (volatile T*)comm->ptrs[prevId].local->buff;
args.NextBuffer = (volatile T*)comm->ptrs[nextId].remote->buff;
args.ThisNewDataAvailableFlag = comm->ptrs[prevId].local->flags;
args.NextNewDataAvailableFlag = comm->ptrs[nextId].remote->flags;
args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
if (comm->nDev == 1) {
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else {
if (index == (rootId + 1) % comm->nDev) {
ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, BEGIN, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else if (index == rootId) {
ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, END, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
} else {
ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, MIDDLE, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
}
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, root, count, comm);
LAUNCH_KERNEL(ReduceKernel, THREADS, UNROLL, FUNC, T, args, stream);
}
return ncclSuccess;
}
template <typename T>
ncclResult_t ncclReduceWithType(const void* sendbuff,
void* recvbuff, int count, ncclRedOp_t op, int root,
ncclComm* comm, cudaStream_t stream) {
switch (op) {
case ncclSum:
return ncclReduceWithTypeAndFunc<FuncSum<T>, T>(
sendbuff, recvbuff, count, root, comm, stream);
case ncclProd:
return ncclReduceWithTypeAndFunc<FuncProd<T>, T>(
sendbuff, recvbuff, count, root, comm, stream);
case ncclMax:
return ncclReduceWithTypeAndFunc<FuncMax<T>, T>(
sendbuff, recvbuff, count, root, comm, stream);
case ncclMin:
return ncclReduceWithTypeAndFunc<FuncMin<T>, T>(
sendbuff, recvbuff, count, root, comm, stream);
}
return ncclInvalidOperation;
}
template<typename T, template<typename> class RedOp>
class ReduceFunctor {
public:
ncclResult_t operator()(const void* sendbuff,
void* recvbuff, int count, ncclDataType_t datatype, ncclRedOp_t op,
int root, ncclComm* comm, cudaStream_t stream) {
switch (datatype) {
case ncclChar:
return ncclReduceWithType<char>(sendbuff, recvbuff, count, op, root, comm, stream);
case ncclInt:
return ncclReduceWithType<int>(sendbuff, recvbuff, count, op, root, comm, stream);
#ifdef CUDA_HAS_HALF
case ncclHalf:
return ncclReduceWithType<half>(sendbuff, recvbuff, count, op, root, comm, stream);
#endif
case ncclFloat:
return ncclReduceWithType<float>(sendbuff, recvbuff, count, op, root, comm, stream);
case ncclDouble:
return ncclReduceWithType<double>(sendbuff, recvbuff, count, op, root, comm, stream);
case ncclInt64:
return ncclReduceWithType<long long>(sendbuff, recvbuff, count, op, root, comm, stream);
case ncclUint64:
return ncclReduceWithType<unsigned long long>(sendbuff, recvbuff, count, op, root, comm, stream);
}
return ncclInvalidType;
public:
static ncclResult_t entry(const void* sendbuff, void* recvbuff,
int count, int root, ncclComm* comm, cudaStream_t stream) {
return RingReduce<RedOp<T>, T>(sendbuff, recvbuff, count, root, comm, stream);
}
};
extern "C" DSOGLOBAL
NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, int count,
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, int count,
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm,
cudaStream_t stream) {
return enqueue(ReduceFunctor(), sendbuff, recvbuff, count, datatype, op,
root, comm, stream);
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
return enqueue<ReduceFunctor>(sendbuff, recvbuff, count, datatype, op, root, comm, stream);
}

View File

@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
@ -11,6 +11,13 @@
#include "common_kernel.h"
#include <limits>
template<typename T>
struct FuncNull {
__device__ T operator()(const T x, const T y) const {
return 0;
}
};
template<typename T>
struct FuncSum {
__device__ T operator()(const T x, const T y) const {
@ -192,30 +199,46 @@ struct FuncMin<char> {
template<>
struct FuncSum<half> {
__device__ half2 operator()(const half2 x, const half2 y) const {
#if __CUDA_ARCH__ >= 530
return __hadd2(x, y);
#else
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
fr.x = fx.x + fy.x;
fr.y = fx.y + fy.y;
return __float22half2_rn(fr);
#endif
}
__device__ half operator()(const half x, const half y) const {
#if __CUDA_ARCH__ >= 530
return __hadd(x, y);
#else
return __float2half( __half2float(x) + __half2float(y) );
#endif
}
};
template<>
struct FuncProd<half> {
__device__ half2 operator()(const half2 x, const half2 y) const {
#if __CUDA_ARCH__ >= 530
return __hmul2(x, y);
#else
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
fr.x = fx.x * fy.x;
fr.y = fx.y * fy.y;
return __float22half2_rn(fr);
#endif
}
__device__ half operator()(const half x, const half y) const {
#if __CUDA_ARCH__ >= 530
return __hmul(x, y);
#else
return __float2half( __half2float(x) * __half2float(y) );
#endif
}
};
@ -225,15 +248,15 @@ struct FuncMax<half> {
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
fr.x = fx.x > fy.x ? fx.x : fy.x;
fr.y = fx.y > fy.y ? fx.y : fy.y;
fr.x = fmaxf(fx.x, fy.x);
fr.y = fmaxf(fx.y, fy.y);
return __float22half2_rn(fr);
}
__device__ half operator()(const half x, const half y) const {
float fx, fy, fm;
fx = __half2float(x);
fy = __half2float(y);
fm = fx > fy ? fx : fy;
fm = fmaxf(fx, fy);
return __float2half(fm);
}
};
@ -244,15 +267,15 @@ struct FuncMin<half> {
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
fr.x = fx.x < fy.x ? fx.x : fy.x;
fr.y = fx.y < fy.y ? fx.y : fy.y;
fr.x = fminf(fx.x, fy.x);
fr.y = fminf(fx.y, fy.y);
return __float22half2_rn(fr);
}
__device__ half operator()(const half x, const half y) const {
float fx, fy, fm;
fx = __half2float(x);
fy = __half2float(y);
fm = fx < fy ? fx : fy;
fm = fminf(fx, fy);
return __float2half(fm);
}
};

View File

@ -1,496 +1,166 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
* See LICENSE.txt for license information
************************************************************************/
#include <cassert>
#include "core.h"
#include "common_kernel.h"
#include "copy_kernel.h"
#include "enqueue.h"
#include "reduce_kernel.h"
#include "primitives.h"
/* HIERARCHY
*
* The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
* SUBCHUNKS, where each SUBCHUNK is an independent, complete reduction. Each
* GPU has a buffer that can fit an entire CHUNK, so that all SUBCHUNKS can be
* processed without checking that the buffer on the receiving GPU is empty. A
* SUBCHUNK is split into NUM_GPUS SLICES and each GPU works on a different
* SLICE at the same time. Before moving on to the next SLICE in the reduction
* algorithm, the GPU has to check whether it has received the data from the
* previous GPU it needs for this SLICE. To hide the latency of this
* communication, each GPU processes all the SLICES of all the SUBCHUNKS in
* sequence before moving on to the next SLICE. Each SLICE is split into a
* certain number of UNROLLS (determined by the buffer size) and each thread
* performs UNROLL_COUNT single-data-element operations inside an UNROLL. As the
* name suggests, the UNROLL_COUNT operations within an UNROLL are unrolled.
*/
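// Concrete size arithmetic for the hierarchy above (the slice size itself is
// hypothetical; the constants come from the defines below): with
// NUM_THREADS = 256 and UNROLL_COUNT = 8, one UNROLL moves
// UNROLL_SIZE = 8 * 256 = 2048 elements per pass of the thread block. A SLICE
// is rounded down to a multiple of UNROLL_SIZE, e.g. SliceSize = 4 * 2048 =
// 8192 elements, and a CHUNK then spans NUM_SUBCHUNKS * SliceSize = 16384
// elements when NUM_SUBCHUNKS = 2.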
#define NUM_SUBSTEPS 2
#define NUM_BUFCHUNKS 2
// Number of threads used to perform copies, etc. Must be multiple of 32.
// An additional thread is used to handle threadfences, so the CUDA blocks
// have dimension NUM_THREADS+1.
#define NUM_THREADS 256
// Increase Step and poffset/noffset for buffer sync
#define NEXT_STEP \
step++; \
poffset = noffset; \
noffset += sliceSize; \
if (noffset == buffSize) noffset = 0;
// Each thread unrolls the innermost loop of the copy or reduction operations
// to this many single-data-element instructions
#define UNROLL_COUNT 8
#define UNROLL_SIZE (UNROLL_COUNT * NUM_THREADS)
// To hide the latency associated with the synchronization between different
// subchunks, we interleave the independent subchunks so that more data can be
// transferred while the sync is in progress. This is the number of subchunks
// that are active at the same time
#define NUM_SUBCHUNKS 2
/*
* numGPUs BLOCKs consisting of recvcount words each
* BLOCK is split up into NumChunks CHUNKs
* CHUNK is split up into NUM_SUBCHUNKS SUBCHUNKs
* SUBCHUNK consists of exactly one SLICE
* SLICE is most efficiently processed in multiples of UNROLL_SIZE
*
* The algorithm has numGPUs steps and each step processes a SLICE (i.e.
* SUBCHUNK) of a different BLOCK. Only data of the BLOCKs not resident on the
* GPU need to be communicated, hence (numGPUs - 1) BLOCKs. So the buffer needs
* to have room for (numGPUs - 1) SLICEs.
*/
// do not encode the subchunk number into the flag, because there is a separate
// flag for each subchunk
// If this is called with STEP, it means that we just finished processing the
// data for step STEP on this GPU, which is the data required on the next GPU
// for step STEP + 1, so we signal the next GPU that its data for step STEP + 1
// is available. This is called by one particular consumer warp and so we select
// the first thread in the warp to set the flag.
#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk, step) \
do { \
args.NextNewDataAvailableFlag[0] = \
2*((chunk) * args.NumGPUs + (step)) + subchunk + 1; \
} while (0)
// This is called by all producer threads, but only thread 0 spins on the flag;
// all threads synchronize after thread 0 is done spinning.
#define WAIT_FOR_NEW_DATA(chunk, subchunk, step) \
do { \
if (tid == 0) { \
Wait([=] { \
return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
2*((chunk) * args.NumGPUs + (step)) + subchunk - 1; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
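// Handshake arithmetic: for step s, the previous GPU executes
// SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk, s-1), writing
// 2*(chunk*NumGPUs + (s-1)) + subchunk + 1 == 2*(chunk*NumGPUs + s) + subchunk - 1,
// which is exactly the threshold WAIT_FOR_NEW_DATA(chunk, subchunk, s) spins
// on, so the wait releases as soon as the producer has published the data for
// step s (or any later step).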
// If this is called with CHUNK, it means that this GPU has just finished
// processing the chunk CHUNK and so the previous GPU can start with CHUNK + 1
#define SIGNAL_CHUNK_DONE(chunk, subchunk) \
do { \
args.PrevChunkDoneFlag[0] = 2*(chunk) + subchunk + 1; \
} while (0)
// This is called by all producer threads, but only thread 0 spins on the flag;
// all threads synchronize after thread 0 is done spinning.
#define WAIT_FOR_CHUNK(chunk, subchunk) \
do { \
if (tid == 0) { \
Wait([=] { \
return ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
2*(chunk) + subchunk - 1; \
}); \
} \
BAR(sync, 1, NUM_THREADS); \
} while (0)
__device__ inline void getSliceSizeAndChunkSize(int *sliceSize, int slice,
int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
int smallSliceN, int lastSliceN) {
if (slice < numBigSlices) {
*sliceSize = bigSliceN;
} else {
*sliceSize = (slice < numBigSlices + numSmallSlices) ? smallSliceN
: ((slice == numSlices - 1) ? lastSliceN : 0);
}
/* if (threadIdx.x == 0)
printf("[sliceSize=%d] slice=%d numSlices=%d "
"numBigSlices=%d numSmallSlices=%d bigSliceN=%d smallSliceN=%d "
"lastSliceN=%d\n", *sliceSize, slice, numSlices, numBigSlices,
numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
*/
}
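// Example of the slice selection above, with hypothetical values as they
// would come out of CalcLastChunk: numSlices = 3, numBigSlices = 1,
// numSmallSlices = 1, bigSliceN = 2048, smallSliceN = 1024, lastSliceN = 512
// yields 2048 elements for slice 0, 1024 for slice 1 and 512 for slice 2.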
template<typename T>
struct ReduceScatterKernelArgs {
// general parameters
int ThisId;
int NumGPUs;
int N;
int * UserFromRing;
// some pre-computed sizes
int SliceSize;
int ChunkSize;
int NumChunks;
int BufferSliceStride;
int BufferMisalignedN;
T ** ThisPtrToNextOutput;
T ** PrevPtrToThisOutput;
// local and remote input, output, and buffer
const T * __restrict__ ThisInput;
volatile T * __restrict__ ThisOutput;
volatile T * __restrict__ ThisBuffer;
volatile T * __restrict__ NextBuffer;
// local and remote flags
volatile int * __restrict__ ThisNewDataAvailableFlag;
volatile int * __restrict__ NextNewDataAvailableFlag;
volatile int * __restrict__ ThisChunkDoneFlag;
volatile int * __restrict__ PrevChunkDoneFlag;
};
__device__ inline int GetBlock(const int index, const int step,
const int * const userFromRing, const int numGPUs) {
return userFromRing[(numGPUs + index - 1 - step) % numGPUs];
}
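// Host-side sketch that mirrors the index math of GetBlock, only to show how
// it reproduces the block-vs-step table in the long comment further below
// (the mapping userFromRing = {1, 2, 0, 3} is the hypothetical example used
// there). Ring index 0, for instance, processes blocks 3, 0, 2, 1 at steps
// 0..3, i.e. the "GPU 0" column of that table.
inline int exampleBlockForStep(int ringIndex, int step) {
  const int numGPUs = 4;
  const int userFromRing[4] = {1, 2, 0, 3};
  return userFromRing[(numGPUs + ringIndex - 1 - step) % numGPUs];
}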
#define ALIGN_SIZE(size, align) \
size = ((size + (align) - 1) / (align)) * (align);
template<int THREADS, int UNROLL, class FUNC, typename T>
__global__ void ReduceScatterKernel(const ReduceScatterKernelArgs<T> args) {
if (args.N == 0) return;
int tid = threadIdx.x;
__launch_bounds__(THREADS+WARP_SIZE, 1)
__global__ void ReduceScatterKernel(const KernelArgs<T> args) {
const int tid = threadIdx.x;
__shared__ DevRing<T> ring;
LoadRing<THREADS>(args.ring, &ring);
__syncthreads();
// First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
// the previous GPU is done with a previous collective operation.
if (tid == 0) {
Wait([=] {
return *((T * volatile *)args.PrevPtrToThisOutput) == nullptr; // Wait for previous processor to be done
});
*((T * volatile *)args.PrevPtrToThisOutput) = (T*)args.ThisOutput; // Tell Previous I'm starting
Wait([=] {
return *((T * volatile *)args.ThisPtrToNextOutput) != nullptr; // Wait till I've been told next started
});
WaitFlag prevCommOp(ring.prevOpCounter, 0);
WaitFlag nextCommOp(ring.nextOpCounter, 0);
prevCommOp.wait(args.opIndex);
nextCommOp.wait(args.opIndex);
}
__syncthreads();
for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
// Calculate slice size. For all chunks except (possibly) the last one,
// this will just be args.SliceSize. For the last one, it may be smaller.
int bigSliceN = args.SliceSize;
int smallSliceN = 0;
int lastSliceN = 0;
int numSlices = NUM_SUBCHUNKS;
int numBigSlices = numSlices;
int numSmallSlices = 0;
WaitFlag waitDoneFromNext(ring.recvFlagFromNext, -NUM_BUFCHUNKS*NUM_SUBSTEPS);
WaitFlag waitReadyFromPrev(ring.recvFlagFromPrev, -1*NUM_SUBSTEPS);
PostFlag postDoneToPrev(ring.sendFlagToPrev, -1*NUM_SUBSTEPS);
PostFlag postReadyToNext(ring.sendFlagToNext, 0);
// last chunk
if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
&numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
args.ChunkSize);
typedef Primitives<THREADS, UNROLL, NUM_SUBSTEPS, T, FUNC> Prims;
const int size = args.N;
const int nranks = args.nRanks;
const int buffSize = args.buffSize / sizeof(T);
const int sliceSize = buffSize / NUM_BUFCHUNKS;
// this offset is only applied to Data pointers, not to Buffer pointers,
// since we only have one buffer per chunk
int chunkOffset = chunk * args.ChunkSize;
int step = 0;
int poffset, noffset = 0;
// Compute pointers
const T * __restrict__ thisInput = args.ThisInput;
T * __restrict__ thisOutput = args.ThisOutput;
T * __restrict__ prevInput = ring.recvBuffer;
T * __restrict__ nextOutput = ring.sendBuffer;
for (int chunkOffset = 0; chunkOffset < size; chunkOffset += sliceSize) {
/////////////// begin ReduceScatter steps ///////////////
int offset;
int maxOffset = size-chunkOffset;
int rankDest;
// step 0: push data to next GPU
int step = 0;
int block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
int blockOffset = chunkOffset + block * args.N;
int bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
((block * args.BufferMisalignedN) % alignof(PackType));
int sliceSize;
rankDest = ring.userRank[nranks-1];
offset = chunkOffset + rankDest * size;
if (tid < NUM_THREADS) {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
Prims::Copy(
thisInput + offset,
nextOutput + noffset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
WAIT_FOR_CHUNK(chunk, s);
Copy<UNROLL, THREADS>(
args.NextBuffer + bufferOffset,
args.ThisInput + blockOffset,
sliceSize);
__syncthreads();
bufferOffset += sliceSize;
blockOffset += sliceSize;
}
} else { // Is consumer
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
}
}
NEXT_STEP; // Increases step, poffset, noffset
// steps j with 0 < j < k - 1, where k = number of GPUs: reduce and copy to
// next GPU
for (step = 1; step < args.NumGPUs - 1; ++step) {
int block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
int blockOffset = chunkOffset + block * args.N;
int bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
((block * args.BufferMisalignedN) % alignof(PackType));
// k-2 steps: reduce and copy to next GPU
for (int j=2; j<nranks; ++j) {
rankDest = ring.userRank[nranks-j];
offset = chunkOffset + rankDest * size;
if (tid < NUM_THREADS) {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
WAIT_FOR_NEW_DATA(chunk, s, step);
Reduce<UNROLL, THREADS, FUNC>(
args.NextBuffer + bufferOffset,
args.ThisBuffer + bufferOffset,
args.ThisInput + blockOffset,
sliceSize);
__syncthreads();
bufferOffset += sliceSize;
blockOffset += sliceSize;
}
} else {
for(int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
}
}
Prims::Reduce(
prevInput + poffset,
thisInput + offset,
nextOutput + noffset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
NEXT_STEP;
}
// step k - 1: reduce this buffer and data, which will produce the final
// result that we store in this data and push to the next GPU
step = args.NumGPUs - 1;
block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
blockOffset = chunkOffset + block * args.N;
bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
((block * args.BufferMisalignedN) % alignof(PackType));
rankDest = ring.userRank[0];
offset = chunkOffset + rankDest * size;
if (tid < NUM_THREADS) {
int outputOffset = 0;
for (int s=0; s<NUM_SUBCHUNKS; ++s) {
getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
WAIT_FOR_NEW_DATA(chunk, s, step);
Reduce<UNROLL, THREADS, FUNC>(
args.ThisOutput + (chunkOffset + outputOffset),
args.ThisBuffer + bufferOffset,
args.ThisInput + blockOffset,
sliceSize);
__syncthreads();
outputOffset += sliceSize;
bufferOffset += sliceSize;
blockOffset += sliceSize;
}
} else {
for (int s=0; s<NUM_SUBCHUNKS; ++s) {
__syncthreads();
SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
Prims::Reduce(
prevInput + poffset,
thisInput + offset,
thisOutput + chunkOffset,
sliceSize, maxOffset,
step,
waitDoneFromNext, waitReadyFromPrev,
postReadyToNext, postDoneToPrev);
// signal that chunk is done if this is not the last chunk
if (chunk + 1 < args.NumChunks) {
SIGNAL_CHUNK_DONE(chunk, s);
}
}
}
NEXT_STEP;
}
// wait for the last data to be pushed to us
if (tid < NUM_THREADS) {
WAIT_FOR_NEW_DATA(args.NumChunks, NUM_SUBCHUNKS-1, 0);
if (tid == 0) {
args.ThisNewDataAvailableFlag[tid] = 0;
args.ThisChunkDoneFlag[tid] = 0;
*args.ThisPtrToNextOutput = nullptr;
}
// Wait for last update from next then reset the flag
waitDoneFromNext.wait(NUM_SUBSTEPS*(step+NUM_BUFCHUNKS-1));
*ring.recvFlagFromNext = 0;
// Wait for last update from prev then reset the flag
waitReadyFromPrev.wait(NUM_SUBSTEPS*(step+1));
*ring.recvFlagFromPrev = 0;
incrementOpCounter(&args);
}
}
#define THREADS 512
#define UNROLL 8
template<class FUNC, typename T>
ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff,
void* recvbuff, const int recvcount, ncclComm* comm, cudaStream_t stream) {
if (recvcount == 0) {
ncclResult_t RingReduceScatter(const void* sendbuff, void* recvbuff,
const int count, ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
}
int index = comm->ncclId;
int blockSizeInBytes = recvcount * sizeof(T);
int misalignedBytes = blockSizeInBytes % alignof(uint64_t);
assert((int)((misalignedBytes / sizeof(T)) * sizeof(T)) == misalignedBytes);
int misalignedN = misalignedBytes / sizeof(T);
assert(misalignedN < (int)(sizeof(uint64_t) / sizeof(T)));
int paddingN = (misalignedN > 0) ? sizeof(uint64_t) / sizeof(T) : 0;
// There is one slice per GPU, so a slice can be at most bufferN / numGPUs,
// where bufferN is the number of elements of type T that fit into the buffer.
// For efficiency, we want the slice size to be a multiple of UNROLL_SIZE
int bufferN = comm->buffSize / sizeof(T);
// we only need buffer for k slices and k*k paddings (we need k paddings per
// block and we have k blocks)
int bufferNPerSlice = (bufferN - NUM_SUBCHUNKS * comm->nDev * paddingN) /
(NUM_SUBCHUNKS * comm->nDev);
int sliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;
int nextId = (index + 1) % comm->nDev;
int prevId = (index + comm->nDev - 1) % comm->nDev;
ReduceScatterKernelArgs<T> args;
args.ThisId = index;
args.NumGPUs = comm->nDev;
args.N = recvcount;
/* Block j must end up in recvbuff[j], which lives on device with logical
* index comm->ringFromUser[j]. But the block ordering does not necessarily
* follow the ring ordering. Hence the order in which a particular GPU
* processes the different blocks (the correspondence between the step in
* the reduction algorithm and the block on which a GPU operates in that
* particular step) is not the same as the ring order.
*
* Say we have 4 GPUs and comm->userFromRing = { 1, 2, 0, 3 }. Then there are 4
* steps in the reduction algorithm and block 0 needs to end up on device 2,
* block 1 on device 0, block 2 on device 1, and block 3 needs to end up on
* device 3. In the last step of the algorithm, each GPU must be processing
* the block that will end up on that GPU. The blocks that a GPU has to
* process in the previous steps are determined by the next step because each
* GPU only hands off data to the next GPU in the ring.
*
* In the above example, we get the following table of which block is
* processed by each GPU in a given step. The columns correspond to the
* different GPUs while the rows are the steps in the algorithm.
*
* GPU 0 1 2 3
* step
* 0 3 1 2 0
* 1 0 3 1 2
* 2 2 0 3 1
* 3 1 2 0 3
*
* We note that the rows in the above table are just comm->userFromRing in the last
* step and the list is cyclically permuted to the left for each previous
* step. The columns, which are what the individual GPUs need to know, are
* comm->userFromRing traversed backwards and starting at index k-1 for GPU k.
* These columns are what we put into args.UserFromRing to tell the GPU which
* block it needs to be processing at a particular step. */
args.UserFromRing = comm->devUserFromRing;
args.SliceSize = sliceSize;
args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
// don't reduce this if we cut the slice size in half below, because if that
// happens, the last chunk will be larger than the other chunks, and we will
// need the extra buffer space
args.BufferSliceStride = args.SliceSize + paddingN;
args.BufferMisalignedN = misalignedN;
// avoid a case where we have one or more big chunks and one tiny one
int remainder = args.N % args.ChunkSize;
if ((args.N > args.ChunkSize) && (remainder > 0) &&
(args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
args.SliceSize /= 2;
args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
// round down so we end up with a big last chunk
args.NumChunks = args.N / args.ChunkSize;
} else {
// round up
args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
}
args.ThisPtrToNextOutput = (T**)&(comm->ptrs[nextId].local->recvPtrs[0]);
args.PrevPtrToThisOutput = (T**)&(comm->ptrs[prevId].remote->recvPtrs[0]);
args.ThisInput = (const T*)sendbuff;
args.ThisOutput = (volatile T*)recvbuff;
args.ThisBuffer = (volatile T*)comm->ptrs[prevId].local->buff;
args.NextBuffer = (volatile T*)comm->ptrs[nextId].remote->buff;
// we need 2 * NUM_SUBCHUNKS flags, so use the first NUM_SUBCHUNKS flags
// to signal the next GPU that new data is available and the following
// NUM_SUBCHUNKS to signal the previous GPU that a chunk is finished
args.ThisNewDataAvailableFlag = comm->ptrs[prevId].local->flags;
args.NextNewDataAvailableFlag = comm->ptrs[nextId].remote->flags;
args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
if (comm->nDev == 1) {
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, recvcount*sizeof(T), cudaMemcpyDeviceToDevice, stream));
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else {
ReduceScatterKernel<NUM_THREADS, UNROLL_COUNT, FUNC, T>
<<<1, NUM_THREADS + 1, 0, stream>>>(args);
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
LAUNCH_KERNEL(ReduceScatterKernel, THREADS, UNROLL, FUNC, T, args, stream);
}
return ncclSuccess;
}
template<typename T>
ncclResult_t ncclReduceScatterWithType(const void* sendbuff, void* recvbuff,
int recvcount, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
switch (op) {
case ncclSum:
return ncclReduceScatterWithTypeAndFunc<FuncSum<T>, T>(
sendbuff, recvbuff, recvcount, comm, stream);
case ncclProd:
return ncclReduceScatterWithTypeAndFunc<FuncProd<T>, T>(
sendbuff, recvbuff, recvcount, comm, stream);
case ncclMax:
return ncclReduceScatterWithTypeAndFunc<FuncMax<T>, T>(
sendbuff, recvbuff, recvcount, comm, stream);
case ncclMin:
return ncclReduceScatterWithTypeAndFunc<FuncMin<T>, T>(
sendbuff, recvbuff, recvcount, comm, stream);
}
return ncclInvalidOperation;
}
class ReduceScatterFunctor {
public:
ncclResult_t operator()(const void* sendbuff, void* recvbuff,
int recvcount, ncclDataType_t datatype, ncclRedOp_t op, int /*root*/,
ncclComm* comm, cudaStream_t stream) {
switch (datatype) {
case ncclChar:
return ncclReduceScatterWithType<char>(sendbuff, recvbuff, recvcount,
op, comm, stream);
case ncclInt:
return ncclReduceScatterWithType<int>(sendbuff, recvbuff, recvcount,
op, comm, stream);
#ifdef CUDA_HAS_HALF
case ncclHalf:
return ncclReduceScatterWithType<half>(sendbuff, recvbuff, recvcount,
op, comm, stream);
#endif
case ncclFloat:
return ncclReduceScatterWithType<float>(sendbuff, recvbuff, recvcount,
op, comm, stream);
case ncclDouble:
return ncclReduceScatterWithType<double>(sendbuff, recvbuff, recvcount,
op, comm, stream);
case ncclInt64:
return ncclReduceScatterWithType<long long>(sendbuff, recvbuff, recvcount,
op, comm, stream);
case ncclUint64:
return ncclReduceScatterWithType<unsigned long long>(sendbuff, recvbuff, recvcount,
op, comm, stream);
}
return ncclInvalidType;
template<typename T, template <typename> class RedOp>
class ReduceScatter {
public:
static ncclResult_t entry(const void* sendbuff, void* recvbuff,
int count, int /*root*/, ncclComm* comm, cudaStream_t stream) {
return RingReduceScatter<RedOp<T>, T>(sendbuff, recvbuff, count, comm, stream);
}
};
extern "C" DSOGLOBAL
ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,
int recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm,
cudaStream_t stream) {
return enqueue(ReduceScatterFunctor(), sendbuff, recvbuff, recvcount,
datatype, op, 0, comm, stream);
NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, int recvcount,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, int recvcount,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
return enqueue<ReduceScatter>(sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream);
}
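// Minimal host-side usage sketch of the entry point above (assumptions: a
// single process with one visible GPU, and only public NCCL/CUDA runtime
// calls: ncclCommInitAll, ncclCommDestroy, cudaMalloc, cudaStreamCreate).
#include <cuda_runtime.h>
#include "nccl.h"

int reduceScatterExample() {
  const int nDev = 1;              // hypothetical device count
  const int recvcount = 1024;      // elements each rank receives
  int devs[1] = {0};
  ncclComm_t comm;
  if (ncclCommInitAll(&comm, nDev, devs) != ncclSuccess) return 1;

  float *sendbuff, *recvbuff;
  cudaMalloc(&sendbuff, nDev * recvcount * sizeof(float));
  cudaMalloc(&recvbuff, recvcount * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Each rank contributes nDev * recvcount elements and receives the reduced
  // recvcount-element block corresponding to its rank.
  ncclReduceScatter(sendbuff, recvbuff, recvcount, ncclFloat, ncclSum, comm, stream);
  cudaStreamSynchronize(stream);

  cudaFree(sendbuff);
  cudaFree(recvbuff);
  cudaStreamDestroy(stream);
  ncclCommDestroy(comm);
  return 0;
}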