1.3.2 release

Broadcast tuning
Better checking of inputs
Copy/reduce code simplification
Sylvain Jeaugey 2016-12-01 15:17:50 -08:00
parent 1093821c33
commit 34d27771c6
10 changed files with 120 additions and 279 deletions

View File

@@ -52,7 +52,7 @@ endif
NCCL_MAJOR := 1
NCCL_MINOR := 3
NCCL_PATCH := 1
NCCL_PATCH := 2
CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH)
CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev)
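
The patch-level bump above is what downstream builds see through NCCL_MAJOR/NCCL_MINOR/NCCL_PATCH. A minimal compile-time sketch, assuming the same -D definitions are visible to the consuming translation unit (the Makefile only guarantees them for the library's own sources):

// Hypothetical feature gate on the 1.3.2 changes; the macro names come from
// the Makefile above, the gate itself is only an illustration.
#if (NCCL_MAJOR > 1) || \
    (NCCL_MAJOR == 1 && NCCL_MINOR > 3) || \
    (NCCL_MAJOR == 1 && NCCL_MINOR == 3 && NCCL_PATCH >= 2)
#define HAVE_NCCL_1_3_2_ERROR_HANDLING 1
#endif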

View File

@@ -5,6 +5,7 @@
************************************************************************/
#include "core.h"
#include "common_coll.h"
#include "enqueue.h"
#include "primitives.h"
@@ -164,18 +165,15 @@ __global__ void AllGatherKernel(const KernelArgs<T> args) {
}
}
#define THREADS 384
#define THREADS 512
#define UNROLL 8
template<class FUNC, typename T>
ncclResult_t RingAllGather(const void* sendbuff, void* recvbuff,
const int count, ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError);
} else {
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
@@ -198,6 +196,7 @@ NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, int count, ncclDataT
void* recvbuff, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllGather(const void* sendbuff, int count, ncclDataType_t datatype,
void* recvbuff, ncclComm_t comm, cudaStream_t stream) {
NCCLCHECK(ArgsCheck(sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, "AllGather"));
return enqueue<AllGather, FuncNull>(sendbuff, recvbuff, count, datatype, 0, comm, stream);
}
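
With the ArgsCheck call and the CUDACHECK return code above, argument problems and CUDA failures now come back to the caller as an ncclResult_t instead of terminating the process. A caller-side sketch (buffers, count, comm and stream are assumed to be set up elsewhere; stdio assumed for the message):

// Signature as declared above: ncclAllGather(sendbuff, count, datatype, recvbuff, comm, stream)
ncclResult_t rc = ncclAllGather(sendbuff, count, ncclFloat, recvbuff, comm, stream);
if (rc != ncclSuccess) {
  fprintf(stderr, "ncclAllGather failed: %s\n", ncclGetErrorString(rc));
  // recover or abort here; the library no longer exits on its own
}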

View File

@@ -5,6 +5,7 @@
************************************************************************/
#include "core.h"
#include "common_coll.h"
#include "enqueue.h"
#include "primitives.h"
@@ -202,12 +203,9 @@ __global__ void AllReduceKernel(const KernelArgs<T> args) {
template<class FUNC, typename T>
ncclResult_t RingAllReduce(const void* sendbuff, void* recvbuff,
const int count, ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError);
} else {
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
@@ -230,6 +228,7 @@ NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, int
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, int count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
NCCLCHECK(ArgsCheck(sendbuff, recvbuff, count, datatype, op, 0, comm, "AllReduce"));
return enqueue<AllReduce>(sendbuff, recvbuff, count, datatype, op, 0, comm, stream);
}
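
The nRanks == 1 branch above degenerates AllReduce to a device-to-device copy, and to no work at all when the call is in place (sendbuff == recvbuff). A sketch of the in-place case, assuming a single-rank communicator and a device buffer dbuf of count floats:

// No copy and no kernel are issued here; enqueue still records comm->doneEvent
// on the stream so later collectives stay ordered.
ncclResult_t rc = ncclAllReduce(dbuf, dbuf, count, ncclFloat, ncclSum, comm, stream);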

View File

@@ -5,10 +5,11 @@
************************************************************************/
#include "core.h"
#include "common_coll.h"
#include "enqueue.h"
#include "primitives.h"
#define NUM_SUBSTEPS 2
#define NUM_SUBSTEPS 4
#define NUM_BUFCHUNKS 2
// Increase Step and boffset for buffer sync
@@ -135,9 +136,6 @@ __global__ void BroadcastKernel(const KernelArgs<T> args) {
template<class FUNC, typename T>
ncclResult_t RingBroadcast(void* buff, const int count, const int root,
ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
if (comm->nRanks != 1) {
KernelArgs<T> args;
ArgsSetup(&args, buff, buff, root, count, comm);
@@ -160,6 +158,7 @@ NCCL_API(ncclResult_t, ncclBcast, void* buff, int count, ncclDataType_t datatype
ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclBcast(void* buff, int count, ncclDataType_t datatype, int root,
ncclComm_t comm, cudaStream_t stream) {
NCCLCHECK(ArgsCheck(buff, buff, count, datatype, ncclSum, root, comm, "Bcast"));
return enqueue<Broadcast, FuncNull>(nullptr, buff, count, datatype, root, comm, stream);
}
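
The "broadcast tuning" is the NUM_SUBSTEPS change above: each staged buffer chunk is now pipelined in four substeps instead of two. Rough, illustrative arithmetic for what the two knobs control (not the library's exact slicing formula; the example buffer size matches DEFAULT_BUFFER_SIZE_BYTES from core.h):

// Finer substeps move smaller pieces per synchronization point, which gives
// more overlap between ranks on the broadcast ring.
#define NUM_SUBSTEPS 4
#define NUM_BUFCHUNKS 2
size_t buffBytes  = 1UL << 21;                    // example staging buffer size
size_t chunkBytes = buffBytes / NUM_BUFCHUNKS;    // bytes staged per chunk
size_t stepBytes  = chunkBytes / NUM_SUBSTEPS;    // bytes moved per pipeline substep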

View File

@@ -174,33 +174,47 @@ struct MULTI<FUNC, long long> {
}
};
template<typename T, bool FETCHTWO>
__device__ inline void FetchOneOrTwo64b(PackType& s0,
const volatile T * __restrict__ const src0, PackType& s1,
const volatile T * __restrict__ const src1, const int idx) {
s0 = (reinterpret_cast<const volatile PackType *>(src0))[idx];
if (FETCHTWO) {
s1 = (reinterpret_cast<const volatile PackType *>(src1))[idx];
template<class FUNC, typename T, bool TWO_INPUTS, bool TWO_OUTPUTS>
__device__ inline void ReduceCopy(
const volatile T * __restrict__ const src0,
const volatile T * __restrict__ const src1,
volatile T * __restrict__ const dest0,
volatile T * __restrict__ const dest1, const int idx) {
T val = vFetch(src0+idx);
if (TWO_INPUTS) {
val = FUNC()(val, vFetch(src1+idx));
}
vStore(dest0+idx, val);
if (TWO_OUTPUTS) {
vStore(dest1+idx, val);
}
}
template<typename T, bool STORETWO>
__device__ inline void StoreOneOrTwo64b(volatile T * __restrict__ const dest0,
volatile T * __restrict__ const dest1, PackType val, const int idx) {
template<class FUNC, typename T, bool TWO_INPUTS, bool TWO_OUTPUTS, int UNROLL, int THREADS>
__device__ inline void ReduceCopy64b(
const volatile T * __restrict__ const src0,
const volatile T * __restrict__ const src1,
volatile T * __restrict__ const dest0,
volatile T * __restrict__ const dest1, const int offset) {
PackType t0[UNROLL];
PackType t1[UNROLL];
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
int idx = offset + u*THREADS;
t0[u] = (reinterpret_cast<const volatile PackType *>(src0))[idx];
if (TWO_INPUTS) {
t1[u] = (reinterpret_cast<const volatile PackType *>(src1))[idx];
}
}
#pragma unroll
for (int u = 0; u < UNROLL; ++u) {
int idx = offset + u*THREADS;
PackType val = TWO_INPUTS ? MULTI<FUNC, T>()(t0[u], t1[u]) : t0[u];
(reinterpret_cast<volatile PackType *>(dest0))[idx] = val;
if (STORETWO) {
if (TWO_OUTPUTS) {
(reinterpret_cast<volatile PackType *>(dest1))[idx] = val;
}
}
template<class FUNC, typename T, bool ISREDUCE>
__device__ inline PackType ReduceOrCopy64b(const PackType s0,
const PackType s1) {
if (ISREDUCE) {
return MULTI<FUNC, T>()(s0, s1);
} else {
return s0;
}
}
#define ALIGNUP(x, a) ((((x)-1) & ~((a)-1)) + (a))
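
The new ReduceCopy/ReduceCopy64b helpers above fold the removed FetchOneOrTwo64b / ReduceOrCopy64b / StoreOneOrTwo64b trio into single routines. A host-side scalar analogue of ReduceCopy64b, for structure only (the device code works on PackType through MULTI<FUNC, T> and uses compile-time TWO_INPUTS/TWO_OUTPUTS flags rather than pointer checks):

// Gather all UNROLL packs first, then combine and store them, so the loads can
// be in flight together instead of serializing load -> store per element.
template<int UNROLL, int THREADS>
void reduceCopy64bLike(const long long* src0, const long long* src1,
                       long long* dest0, long long* dest1, int offset) {
  long long t0[UNROLL], t1[UNROLL];
  for (int u = 0; u < UNROLL; ++u) {
    int idx = offset + u*THREADS;
    t0[u] = src0[idx];
    if (src1) t1[u] = src1[idx];            // TWO_INPUTS case
  }
  for (int u = 0; u < UNROLL; ++u) {
    int idx = offset + u*THREADS;
    long long val = src1 ? t0[u] + t1[u]    // stand-in for MULTI<FUNC, T>()
                         : t0[u];
    dest0[idx] = val;
    if (dest1) dest1[idx] = val;            // TWO_OUTPUTS case
  }
}
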
@@ -251,9 +265,6 @@ __device__ inline void ReduceOrCopy(const int tid,
return;
}
const int UNROLL2 = (UNROLL >= 2) ? (UNROLL / 2) : 1;
const bool NOUNROLL2 = ((UNROLL / 2) == 0);
int Npreamble = (N<alignof(PackType)) ? N : AlignUp(dest0, alignof(PackType)) - dest0;
// stage 0: check if we'll be able to use the fast, 64-bit aligned path.
@@ -266,247 +277,60 @@
Npreamble = N;
}
/*
if (threadIdx.x == 0) {
printf("** alignable: %s", (alignable ? "YES" : " NO"));
printf(", dest0 = 0x%08X", dest0);
printf(", src0 = 0x%08X", src0);
if (HAS_DEST1) printf(", dest1 = 0x%08X", dest1);
if (HAS_SRC1) printf(", src1 = 0x%08X", src1);
printf("\n");
}
*/
// stage 1: preamble: handle any elements up to the point of everything coming
// into alignment
for (int idx = tid; idx < Npreamble; idx += THREADS) {
// ought to be no way this is ever more than one iteration, except when
// alignable is false
T val = vFetch(src0+idx);
if (HAS_SRC1) {
val = FUNC()(val, vFetch(src1+idx));
ReduceCopy<FUNC, T, HAS_SRC1, HAS_DEST1>(src0, src1, dest0, dest1, idx);
}
vStore(dest0+idx, val);
if (HAS_DEST1) {
vStore(dest1+idx, val);
}
}
// reduce N by however many elements we've handled already
int Ndone = Npreamble;
int Nrem = N - Ndone;
// stage 2: fast path: use 64b loads/stores to do the bulk of the work,
// assuming the pointers we have are all 64-bit alignable.
if (alignable) {
if (Ndone > 0) {
// align up pointers
dest0 += Ndone; if (HAS_DEST1) { dest1 += Ndone; }
src0 += Ndone; if (HAS_SRC1) { src1 += Ndone; }
}
const int PackFactor = sizeof(PackType) / sizeof(T);
int Nrem = N - Npreamble;
dest0 += Npreamble; if (HAS_DEST1) { dest1 += Npreamble; }
src0 += Npreamble; if (HAS_SRC1) { src1 += Npreamble; }
// stage 2a: main loop
int Nalign = (Nrem / (sizeof(PackType) / sizeof(T)) / (UNROLL * THREADS))
int Nalign2a = (Nrem / (PackFactor * UNROLL * THREADS))
* (UNROLL * THREADS); // round down
#pragma unroll 1 // don't unroll this loop
for (int idx = tid; idx < Nalign; idx += UNROLL * THREADS) {
PackType t0[UNROLL2];
PackType t1[UNROLL2];
PackType t2[UNROLL2];
#pragma unroll
for (int j = 0; j < UNROLL2; ++j)
FetchOneOrTwo64b<T, HAS_SRC1>(t0[j], src0, t1[j], src1,
idx + j * THREADS);
#pragma unroll
for (int j = 0; j < UNROLL2; ++j)
t2[j] = ReduceOrCopy64b<FUNC, T, HAS_SRC1>(t0[j], t1[j]);
if (!NOUNROLL2) {
#pragma unroll
for (int j = 0; j < UNROLL2; ++j)
FetchOneOrTwo64b<T, HAS_SRC1>(t0[j], src0, t1[j], src1,
idx + (UNROLL2 + j) * THREADS);
for (int idx = tid; idx < Nalign2a; idx += UNROLL * THREADS) {
ReduceCopy64b<FUNC, T, HAS_SRC1, HAS_DEST1, UNROLL, THREADS>(src0, src1, dest0, dest1, idx);
}
#pragma unroll
for (int j = 0; j < UNROLL2; ++j)
StoreOneOrTwo64b<T, HAS_DEST1>(dest0, dest1, t2[j], idx + j * THREADS);
if (!NOUNROLL2) {
#pragma unroll
for (int j = 0; j < UNROLL2; ++j)
t2[j] = ReduceOrCopy64b<FUNC, T, HAS_SRC1>(t0[j], t1[j]);
#pragma unroll
for (int j = 0; j < UNROLL2; ++j)
StoreOneOrTwo64b<T, HAS_DEST1>(dest0, dest1, t2[j],
idx + (UNROLL2 + j) * THREADS);
}
}
int Ndone2a = Nalign2a * PackFactor;
Nrem -= Ndone2a;
// stage 2b: slightly less optimized for section when we don't have full
// UNROLLs
int Ndone2a = Nalign * (sizeof(PackType)/sizeof(T));
Ndone += Ndone2a;
Nrem = N - Ndone;
// TODO: This kind of pointer update arithmetic is expensive. Should
// probably find a better way.
if (Nrem > 0) {
dest0 += Ndone2a; if (HAS_DEST1) { dest1 += Ndone2a; }
src0 += Ndone2a; if (HAS_SRC1) { src1 += Ndone2a; }
}
Nalign = Nrem / (sizeof(PackType)/sizeof(T));
int Nalign2b = Nrem / PackFactor;
#pragma unroll 4
for (int idx = tid; idx < Nalign; idx += THREADS) {
PackType t0, t1, t2;
FetchOneOrTwo64b<T, HAS_SRC1>(t0, src0, t1, src1, idx);
t2 = ReduceOrCopy64b<FUNC, T, HAS_SRC1>(t0, t1);
StoreOneOrTwo64b<T, HAS_DEST1>(dest0, dest1, t2, idx);
for (int idx = Nalign2a + tid; idx < Nalign2a + Nalign2b; idx += THREADS) {
ReduceCopy64b<FUNC, T, HAS_SRC1, HAS_DEST1, 1, 0>(src0, src1, dest0, dest1, idx);
}
int Ndone2b = Nalign2b * PackFactor;
Nrem -= Ndone2b;
int Ndone2 = Ndone2a + Ndone2b;
dest0 += Ndone2; if (HAS_DEST1) { dest1 += Ndone2; }
src0 += Ndone2; if (HAS_SRC1) { src1 += Ndone2; }
// stage 2c: tail
int Ndone2b = Nalign * (sizeof(PackType)/sizeof(T));
Ndone += Nalign * (sizeof(PackType)/sizeof(T));
Nrem = N - Ndone;
if (Nrem > 0) {
dest0 += Ndone2b; if (HAS_DEST1) { dest1 += Ndone2b; }
src0 += Ndone2b; if (HAS_SRC1) { src1 += Ndone2b; }
}
for (int idx = tid; idx < Nrem; idx += THREADS) {
// never ought to make it more than one time through this loop. only a
// few threads should even participate
T val = vFetch(src0+idx);
if (HAS_SRC1) {
val = FUNC()(val, vFetch(src1+idx));
}
vStore(dest0+idx, val);
if (HAS_DEST1) {
vStore(dest1+idx, val);
}
ReduceCopy<FUNC, T, HAS_SRC1, HAS_DEST1>(src0, src1, dest0, dest1, idx);
}
} // done fast path
}
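
The rewritten fast path above splits the N elements into an alignment preamble, an unrolled 64-bit stage (2a), a single-pack 64-bit stage (2b), and a scalar tail. A small host-side check of that split with illustrative parameters (PackType is 8 bytes, float elements; THREADS/UNROLL here are not the kernel's actual values):

#include <cstdio>
int main() {
  const int THREADS = 256, UNROLL = 8;
  const int PackFactor = 8 / (int)sizeof(float); // elements per 64-bit pack = 2
  int N = 100000, Npreamble = 0;                 // pointers assumed aligned
  int Nrem = N - Npreamble;
  int Nalign2a = (Nrem / (PackFactor * UNROLL * THREADS)) * (UNROLL * THREADS);
  Nrem -= Nalign2a * PackFactor;                 // elements covered by stage 2a
  int Nalign2b = Nrem / PackFactor;
  Nrem -= Nalign2b * PackFactor;                 // elements covered by stage 2b
  printf("2a: %d packs, 2b: %d packs, tail: %d elements\n",
         Nalign2a, Nalign2b, Nrem);              // prints 49152, 848, 0
  return 0;
}
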
template<int THREADS, int UNROLL, typename T>
__device__ inline void CalcLastChunk(int * const bigSliceN,
int * const smallSliceN, int * const lastSliceN, int * const numSlices,
int * const numBigSlices, int * const numSmallSlices, const int N,
const int numChunks, const int chunkSize) {
int Nleft = N - ((numChunks - 1) * chunkSize);
// semi-equally split up the remaining work into numslices slices.
// it's "semi"-equal because we want the divisions to land as neatly as we
// can on alignable boundaries
int NperTile = UNROLL * THREADS * (sizeof(PackType)/sizeof(T));
int numTiles = (Nleft + NperTile - 1) / NperTile;
int numTilesPerBigSlice = (numTiles + *numSlices - 1)
/ *numSlices;
int numTilesPerSmallSlice = numTiles / *numSlices;
*bigSliceN = NperTile * numTilesPerBigSlice;
*smallSliceN = NperTile * numTilesPerSmallSlice;
*numBigSlices = numTiles % *numSlices;
*numSmallSlices = (*smallSliceN > 0) ?
*numSlices - *numBigSlices : 0;
// the lastSlice will take the place of one of the small slices unless
// there are no small slices (because this is a very small reduction), in
// which case we replace one of the big slices and leave the small slices
// as 0.
if (*numSmallSlices > 0) {
--*numSmallSlices;
if (*numSmallSlices == 0)
*smallSliceN = 0;
}
else {
--*numBigSlices;
if (*numBigSlices == 0)
*bigSliceN = 0;
}
*lastSliceN = Nleft -
(*numBigSlices * *bigSliceN
+ *numSmallSlices * *smallSliceN);
// in cases where args.N % numSlices is pretty small, we'd rather have one
// slightly big last slice than one big slice, a bunch of small slices,
// and one smaller last slice
if ((*numBigSlices == 1) &&
(*numSmallSlices == *numSlices - 2) &&
(*lastSliceN < *smallSliceN)) {
*numBigSlices += *numSmallSlices;
*numSmallSlices = 0;
*bigSliceN = *smallSliceN;
*smallSliceN = 0;
*lastSliceN = Nleft -
*numBigSlices * *bigSliceN;
}
// done recalculating
*numSlices = *numBigSlices +
*numSmallSlices + 1;
}
// Kernel launch
template<typename T>
struct KernelArgs {
// general parameters
int nRanks;
int root;
int buffSize;
int N;
int opIndex;
volatile int * __restrict__ opCounter;
bool pushrecv;
// some pre-computed sizes
int SliceSize;
int SliceOffset;
int ChunkSize;
int NumChunks;
// local and remote input, output, and buffer
const T * __restrict__ ThisInput;
T * __restrict__ ThisOutput;
DevRing<char>* ring;
};
template<typename T>
void ArgsSetup(KernelArgs<T> *args, const void* sendbuff, void* recvbuff,
const int root, const int count, ncclComm *comm) {
args->nRanks = comm->nRanks;
args->root = root;
args->buffSize = comm->buffSize;
args->N = count;
args->opIndex = comm->opSched;
args->opCounter = comm->opCounter;
args->ThisInput = (const T*)sendbuff;
args->ThisOutput = (T*)recvbuff;
args->ring = comm->devRing;
args->pushrecv = comm->globalMemSpace;
}
#define LAUNCH_KERNEL(K, THREADS, UNROLL, FUNC, T, \
args, stream) do { \
dim3 grid(1, 1, 1); \
dim3 block(THREADS+1, 1, 1); \
void* argptrs[] = {&args}; \
CUDACHECK(cudaLaunchKernel( \
(void*)K<THREADS, UNROLL, FUNC, T>, \
grid, block, argptrs, 0, stream)); \
} while (0)
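
For reference, the collectives hand their KernelArgs to this macro; a sketch of the call shape as used in the .cu files (not all shown in this diff), with the assumption that the extra (+1) thread in the block is reserved for synchronization rather than data movement:

// One block of THREADS+1 threads launched on the given stream.
LAUNCH_KERNEL(AllReduceKernel, THREADS, UNROLL, FUNC, T, args, stream);
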
template <typename T>
__device__ inline void incrementOpCounter(const KernelArgs<T> *args) {
// increment comm's operation counts

View File

@@ -8,6 +8,7 @@
#include <stdlib.h>
#include "core.h"
#include "libwrap.h"
#include "common_coll.h"
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
@@ -22,6 +23,7 @@ DebugLevel ncclDebugLevel;
NCCL_API(ncclResult_t, ncclGetUniqueId, ncclUniqueId* out);
ncclResult_t ncclGetUniqueId(ncclUniqueId* out) {
NCCLCHECK(PtrCheck(out, "GetUniqueId", "out"));
pid_t pid = getpid();
static int count = 0;
int commId = __sync_fetch_and_add(&count, 1);
@@ -578,15 +580,6 @@ static void commFree(ncclComm_t comm) {
}
static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId* commId, int rank) {
if (ndev < 1) {
WARN("invalid device count (%d) requested", ndev);
return ncclUnsupportedDeviceCount;
}
if (rank >= ndev || rank < 0) {
WARN("rank %d exceeds ndev=%d", rank, ndev);
return ncclInvalidRank;
}
size_t commBytes = offsetof(ncclComm, ptrs) + ndev*sizeof(NodeRef);
struct ncclComm* comm = (struct ncclComm*)malloc(commBytes);
if (comm == NULL) {
@@ -731,6 +724,17 @@ NCCL_API(ncclResult_t, ncclCommInitRank, ncclComm_t* newcomm, int ndev, ncclUniq
ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId, int myrank) {
if (myrank == 0) showVersion();
NCCLCHECK(PtrCheck(newcomm, "CommInitRank", "newcomm"));
if (ndev < 1) {
WARN("Invalid device count requested : %d", ndev);
return ncclUnsupportedDeviceCount;
}
if (myrank >= ndev || myrank < 0) {
WARN("Invalid rank %d, should be in the range 0..%d", myrank, ndev-1);
return ncclInvalidRank;
}
if (strlen(commId.internal) < 1 ||
strlen(commId.internal) >= NCCL_UNIQUE_ID_BYTES) {
WARN("rank %d invalid commId", myrank);
@@ -819,6 +823,13 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
showVersion();
NCCLCHECK(PtrCheck(comms, "CommInitRank", "comms"));
if (ndev < 1) {
WARN("Invalid device count requested : %d", ndev);
return ncclUnsupportedDeviceCount;
}
ncclResult_t res;
int savedDevice;
RankEntry* ranks = NULL;
@@ -949,7 +960,7 @@ void ncclCommDestroy(ncclComm_t comm) {
int commDevice = comm->cudaDev;
if (savedDevice != commDevice) {
CUDACHECK(cudaSetDevice(commDevice));
CUDACHECK(cudaSetDevice(commDevice), void());
}
commFree(comm);
@@ -982,18 +993,24 @@ const char* ncclGetErrorString(ncclResult_t code) {
NCCL_API(ncclResult_t, ncclCommCount, const ncclComm_t comm, int* count);
ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) {
NCCLCHECK(PtrCheck(comm, "CommCount", "comm"));
NCCLCHECK(PtrCheck(count, "CommCount", "count"));
*count = comm->nRanks;
return ncclSuccess;
}
NCCL_API(ncclResult_t, ncclCommCuDevice, const ncclComm_t comm, int* devid);
ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* devid) {
NCCLCHECK(PtrCheck(comm, "CommCuDevice", "comm"));
NCCLCHECK(PtrCheck(devid, "CommCuDevice", "devid"));
*devid = comm->cudaDev;
return ncclSuccess;
}
NCCL_API(ncclResult_t, ncclCommUserRank, const ncclComm_t comm, int* rank);
ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank) {
NCCLCHECK(PtrCheck(comm, "CommUserRank", "comm"));
NCCLCHECK(PtrCheck(rank, "CommUserRank", "rank"));
*rank = comm->rank;
return ncclSuccess;
}
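
The added PtrCheck calls are part of the "better checking of inputs": a NULL pointer passed to these query functions now returns an error instead of crashing, and the device-count/rank validation moved from commAlloc into ncclCommInitRank/ncclCommInitAll so bad values are rejected before any allocation. Caller-side sketch:

int nranks;
ncclResult_t rc = ncclCommCount(comm, &nranks);  // NULL comm or count pointer -> error, not a segfault
if (rc != ncclSuccess) {
  fprintf(stderr, "ncclCommCount failed: %s\n", ncclGetErrorString(rc));
}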

View File

@@ -12,18 +12,6 @@
#include <cstdio>
#include <cuda_runtime.h>
// DIE on error
#define CUDACHECK(cmd) do { \
cudaError_t e = cmd; \
if( e != cudaSuccess ) { \
printf("Cuda failure %s:%d '%s'\n", \
__FILE__,__LINE__,cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while(false)
#define MAXRANKS 32
#define DEFAULT_BUFFER_SIZE_BYTES (1UL << 21)
#define NCCL_MEM_PAD_ALIGN 65536
@@ -136,6 +124,23 @@ extern DebugLevel ncclDebugLevel;
} \
} while(0)
// Check CUDA calls
#define CUDACHECK(cmd, retcode) do { \
cudaError_t e = cmd; \
if( e != cudaSuccess ) { \
WARN("Cuda failure '%s'\n", cudaGetErrorString(e)); \
return retcode; \
} \
} while(false)
// Propagate errors up
#define NCCLCHECK(call) do { \
ncclResult_t res = call; \
if (res != ncclSuccess) { \
return res; \
} \
} while (0);
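
Together the two macros give the library a uniform error path: a CUDA failure becomes the ncclResult_t chosen at the call site, and NCCLCHECK propagates any non-success result up the stack. A minimal sketch with a hypothetical internal helper (not part of this diff):

static ncclResult_t stageCopy(void* dst, const void* src, size_t bytes,
                              cudaStream_t stream) {
  CUDACHECK(cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, stream),
            ncclUnhandledCudaError);               // CUDA error -> NCCL error code
  return ncclSuccess;
}

static ncclResult_t someCollectiveStep(void* dst, const void* src, size_t bytes,
                                       cudaStream_t stream) {
  NCCLCHECK(stageCopy(dst, src, bytes, stream));   // bubbles the error upward
  return ncclSuccess;
}
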
#ifdef PROFAPI
#define NCCL_API(ret, func, args...) \
__attribute__ ((visibility("default"))) \

View File

@@ -34,7 +34,7 @@ ncclResult_t enqueue(const void* sendbuff,
{
if (stream != comm->prevStream) { // sync required for calls in different streams
comm->prevStream = stream;
CUDACHECK( cudaStreamWaitEvent(stream, comm->doneEvent, 0) );
CUDACHECK(cudaStreamWaitEvent(stream, comm->doneEvent, 0), ncclUnhandledCudaError);
}
ncclResult_t ret;
@@ -42,7 +42,7 @@ ncclResult_t enqueue(const void* sendbuff,
// Always have to record done event because we don't know what stream next
// collective will be in.
CUDACHECK( cudaEventRecord(comm->doneEvent, stream) );
CUDACHECK(cudaEventRecord(comm->doneEvent, stream), ncclUnhandledCudaError);
comm->opSched += 1;
return ret;
}
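
The prevStream/doneEvent logic above is what keeps collectives ordered when the application switches streams between calls. Caller-side sketch, assuming two streams on the same communicator:

// The second call waits on the first through comm->doneEvent and
// cudaStreamWaitEvent; no device-wide synchronization is needed.
ncclAllReduce(bufA, bufA, n, ncclFloat, ncclSum, comm, stream1);
ncclAllReduce(bufB, bufB, n, ncclFloat, ncclSum, comm, stream2);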

View File

@@ -5,6 +5,7 @@
************************************************************************/
#include "core.h"
#include "common_coll.h"
#include "enqueue.h"
#include "primitives.h"
@@ -117,12 +118,9 @@ __global__ void ReduceKernel(const KernelArgs<T> args) {
template<class FUNC, typename T>
ncclResult_t RingReduce(const void* sendbuff, void* recvbuff, const int count, const int root,
ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError);
} else {
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, root, count, comm);
@@ -145,6 +143,7 @@ NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, int cou
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, int count,
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
NCCLCHECK(ArgsCheck(sendbuff, recvbuff, count, datatype, op, root, comm, "Reduce"));
return enqueue<ReduceFunctor>(sendbuff, recvbuff, count, datatype, op, root, comm, stream);
}

View File

@@ -5,6 +5,7 @@
************************************************************************/
#include "core.h"
#include "common_coll.h"
#include "enqueue.h"
#include "primitives.h"
@@ -133,12 +134,9 @@ __global__ void ReduceScatterKernel(const KernelArgs<T> args) {
template<class FUNC, typename T>
ncclResult_t RingReduceScatter(const void* sendbuff, void* recvbuff,
const int count, ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError);
} else {
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
@@ -161,6 +159,7 @@ NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, int recvcount,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
NCCLCHECK(ArgsCheck(sendbuff, recvbuff, recvcount, datatype, op, 0, comm, "ReduceScatter"));
return enqueue<ReduceScatter>(sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream);
}