Adding missing file
This commit is contained in:
parent
34d27771c6
commit
648e9fbb58
118
src/common_coll.h
Normal file
118
src/common_coll.h
Normal file
@ -0,0 +1,118 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#ifndef COMMON_COLL_H_
|
||||
#define COMMON_COLL_H_
|
||||
|
||||
#include "core.h"
|
||||
|
||||
/* Checks that a user buffer pointer is usable by the given communicator.
 *
 * pointer : buffer address to validate
 * comm    : communicator whose cudaDev the buffer is compared against
 * ptrname : argument name used in the warning text (e.g. "sendbuff")
 * opname  : name of the operation, used in the warning text
 *
 * Returns ncclInvalidDevicePointer when the CUDA runtime cannot describe the
 * pointer, when it has no device-side address, or when it is device memory
 * that lives on a different device than comm->cudaDev; ncclSuccess otherwise.
 */
static ncclResult_t PointerCheck(const void* pointer, struct ncclComm* comm, const char* ptrname, const char* opname) {
  cudaPointerAttributes attr;
  cudaError_t err = cudaPointerGetAttributes(&attr, pointer);
  if (err != cudaSuccess) {
    // The failed query is recorded as the runtime's last error. Reset it so
    // that it is not picked up later by an unrelated cudaGetLastError() check.
    // Only done on failure, so a pre-existing error is never cleared here.
    (void)cudaGetLastError();
  }
  // attr is only read when the query succeeded (short-circuit evaluation).
  if (err != cudaSuccess || attr.devicePointer == NULL) {
    WARN("%s : %s is not a valid pointer\n", opname, ptrname);
    return ncclInvalidDevicePointer;
  }
  // Memory that is not of type cudaMemoryTypeDevice is accepted regardless of
  // which device it is associated with.
  if (attr.memoryType == cudaMemoryTypeDevice && attr.device != comm->cudaDev) {
    WARN("%s : %s allocated on device %d mismatchs with NCCL device %d \n", opname, ptrname, attr.device, comm->cudaDev);
    return ncclInvalidDevicePointer;
  }
  return ncclSuccess;
}
|
||||
|
||||
/* Rejects NULL arguments.
 *
 * ptr     : pointer to test
 * opname  : name of the operation, used in the warning text
 * ptrname : name of the argument, used in the warning text
 *
 * Returns ncclSuccess for a non-NULL pointer; otherwise emits a warning and
 * returns ncclInvalidArgument.
 */
static ncclResult_t PtrCheck(void* ptr, const char* opname, const char* ptrname) {
  if (ptr != NULL) return ncclSuccess;

  WARN("%s : %s argument is NULL", opname, ptrname);
  return ncclInvalidArgument;
}
|
||||
|
||||
/* Validates the arguments of a collective call before any work is done.
 *
 * sendbuff / recvbuff : user buffers, validated with PointerCheck
 * count               : element count, must be >= 0
 * type                : must lie in [0, nccl_NUM_TYPES)
 * op                  : must lie in [0, nccl_NUM_OPS)
 * root                : must lie in [0, comm->nRanks)
 * comm                : communicator, must be non-NULL
 * opname              : operation name; used in warnings and to recognise
 *                       "Reduce", whose non-root ranks skip the recvbuff check
 *
 * Returns ncclSuccess when every check passes; otherwise ncclInvalidRank,
 * ncclInvalidType, ncclInvalidOperation, ncclInvalidArgument, or the result of
 * the failing PtrCheck / PointerCheck call.
 * NOTE(review): NCCLCHECK is defined outside this file; it is assumed to
 * return its argument from this function when it is not ncclSuccess.
 */
static ncclResult_t ArgsCheck(const void* sendbuff, const void* recvbuff, int count, ncclDataType_t type, ncclRedOp_t op, int root, struct ncclComm* comm, const char* opname) {
  // comm is dereferenced below, so it has to be checked first.
  NCCLCHECK(PtrCheck(comm, opname, "comm"));

  // First, the easy ones
  if (root < 0 || root >= comm->nRanks) {
    // Valid roots are 0..nRanks-1: report the inclusive upper bound.
    WARN("%s : invalid root %d (root should be in the 0..%d range)\n", opname, root, comm->nRanks - 1);
    return ncclInvalidRank;
  }
  if (type < 0 || type >= nccl_NUM_TYPES) {
    WARN("%s : invalid type %d\n", opname, type);
    return ncclInvalidType;
  }
  if (op < 0 || op >= nccl_NUM_OPS) {
    WARN("%s : invalid reduction operation %d\n", opname, op);
    return ncclInvalidOperation;
  }
  if (count < 0) {
    WARN("%s : invalid count %d\n", opname, count);
    return ncclInvalidArgument;
  }

  // Check pointers
  NCCLCHECK(PointerCheck(sendbuff, comm, "sendbuff", opname));
  if (strcmp(opname, "Reduce") == 0 && comm->rank != root) {
    // No need to check recvbuff pointer for non-root reduce
    return ncclSuccess;
  }
  NCCLCHECK(PointerCheck(recvbuff, comm, "recvbuff", opname));
  return ncclSuccess;
}
|
||||
|
||||
// Kernel launch
|
||||
template<typename T>
|
||||
struct KernelArgs {
|
||||
// general parameters
|
||||
int nRanks;
|
||||
int root;
|
||||
int buffSize;
|
||||
int N;
|
||||
int opIndex;
|
||||
volatile int * __restrict__ opCounter;
|
||||
int * __restrict__ doneCount;
|
||||
bool pushrecv;
|
||||
|
||||
// some pre-computed sizes
|
||||
int SliceSize;
|
||||
int SliceOffset;
|
||||
int ChunkSize;
|
||||
int NumChunks;
|
||||
|
||||
// local and remote input, output, and buffer
|
||||
const T * __restrict__ ThisInput;
|
||||
T * __restrict__ ThisOutput;
|
||||
|
||||
DevRing<char>* ring;
|
||||
int nRings;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
void ArgsSetup(KernelArgs<T> *args, const void* sendbuff, void* recvbuff,
|
||||
const int root, const int count, ncclComm *comm) {
|
||||
args->nRanks = comm->nRanks;
|
||||
args->root = root;
|
||||
args->buffSize = comm->buffSizePerRing;
|
||||
args->N = count;
|
||||
args->opIndex = comm->opSched;
|
||||
args->opCounter = comm->opCounter;
|
||||
args->doneCount = &comm->devMem->doneCount;
|
||||
args->ThisInput = (const T*)sendbuff;
|
||||
args->ThisOutput = (T*)recvbuff;
|
||||
args->ring = comm->devRing;
|
||||
args->pushrecv = comm->globalMemSpace;
|
||||
args->nRings = comm->nRings;
|
||||
}
|
||||
|
||||
/* Launches the kernel template instance K<THREADS, UNROLL, FUNC, T>.
 *
 * Launch configuration: grid = (args.nRings, 1, 1), block = (THREADS+1, 1, 1),
 * no dynamic shared memory, on `stream`. `args` must be an lvalue; its address
 * is passed as the kernel's only parameter.
 *
 * The cudaLaunchKernel result is handed to CUDACHECK together with
 * ncclUnhandledCudaError.
 * NOTE(review): CUDACHECK is defined outside this file; it is assumed to
 * return that code from the enclosing function on failure, so the macro should
 * only be used inside functions returning ncclResult_t -- confirm.
 *
 * Macro parameters are parenthesised and the locals carry a trailing
 * underscore so that argument expressions and caller identifiers such as
 * `grid` or `block` cannot be captured by the expansion.
 */
#define LAUNCH_KERNEL(K, THREADS, UNROLL, FUNC, T, \
    args, stream) do { \
  dim3 launchGrid_((args).nRings, 1, 1); \
  dim3 launchBlock_((THREADS) + 1, 1, 1); \
  void* launchArgs_[] = {&(args)}; \
  CUDACHECK(cudaLaunchKernel( \
      (void*)K<THREADS, UNROLL, FUNC, T>, \
      launchGrid_, launchBlock_, launchArgs_, 0, (stream)), ncclUnhandledCudaError); \
} while (0)
|
||||
|
||||
#endif
|
Loading…
x
Reference in New Issue
Block a user