Add int64 and uint64 types for all algorithms and tests
This commit is contained in:
parent
27d32ac5d9
commit
41ce4ca9fc
@ -477,6 +477,12 @@ public:
|
|||||||
case ncclDouble:
|
case ncclDouble:
|
||||||
return ncclAllGatherWithType<double>(sendbuff, recvbuff, count, comm,
|
return ncclAllGatherWithType<double>(sendbuff, recvbuff, count, comm,
|
||||||
numUnroll, stream);
|
numUnroll, stream);
|
||||||
|
case ncclInt64:
|
||||||
|
return ncclAllGatherWithType<long long>(sendbuff, recvbuff, count, comm,
|
||||||
|
numUnroll, stream);
|
||||||
|
case ncclUint64:
|
||||||
|
return ncclAllGatherWithType<unsigned long long>(sendbuff, recvbuff, count, comm,
|
||||||
|
numUnroll, stream);
|
||||||
}
|
}
|
||||||
return ncclInvalidType;
|
return ncclInvalidType;
|
||||||
}
|
}
|
||||||
|
@ -224,6 +224,8 @@ int main(int argc, char* argv[]) {
|
|||||||
#endif
|
#endif
|
||||||
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
||||||
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
||||||
|
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
|
||||||
|
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
|
@ -489,6 +489,12 @@ public:
|
|||||||
case ncclDouble:
|
case ncclDouble:
|
||||||
return ncclAllReduceWithType<double>(sendbuff, recvbuff, count, op,
|
return ncclAllReduceWithType<double>(sendbuff, recvbuff, count, op,
|
||||||
comm, stream);
|
comm, stream);
|
||||||
|
case ncclInt64:
|
||||||
|
return ncclAllReduceWithType<long long>(sendbuff, recvbuff, count, op,
|
||||||
|
comm, stream);
|
||||||
|
case ncclUint64:
|
||||||
|
return ncclAllReduceWithType<unsigned long long int>(sendbuff, recvbuff, count, op,
|
||||||
|
comm, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
return ncclInvalidType;
|
return ncclInvalidType;
|
||||||
|
@ -287,6 +287,8 @@ int main(int argc, char* argv[]) {
|
|||||||
#endif
|
#endif
|
||||||
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
||||||
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
||||||
|
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
|
||||||
|
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
|
@ -396,6 +396,10 @@ public:
|
|||||||
return ncclBcastWithType<float>(buff, count, root, comm, numUnroll, stream);
|
return ncclBcastWithType<float>(buff, count, root, comm, numUnroll, stream);
|
||||||
case ncclDouble:
|
case ncclDouble:
|
||||||
return ncclBcastWithType<double>(buff, count, root, comm, numUnroll, stream);
|
return ncclBcastWithType<double>(buff, count, root, comm, numUnroll, stream);
|
||||||
|
case ncclInt64:
|
||||||
|
return ncclBcastWithType<long long>(buff, count, root, comm, numUnroll, stream);
|
||||||
|
case ncclUint64:
|
||||||
|
return ncclBcastWithType<unsigned long long>(buff, count, root, comm, numUnroll, stream);
|
||||||
}
|
}
|
||||||
return ncclInvalidType;
|
return ncclInvalidType;
|
||||||
}
|
}
|
||||||
|
@ -224,6 +224,8 @@ int main(int argc, char* argv[]) {
|
|||||||
#endif
|
#endif
|
||||||
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
||||||
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
||||||
|
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
|
||||||
|
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
|
@ -174,6 +174,26 @@ struct MULTI<FUNC, double> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// MULTI specialization for unsigned 64-bit elements.
// The static_assert guarantees that a PackType is exactly as wide as one
// unsigned long long, so the functor is applied once to the whole pack.
template<class FUNC>
struct MULTI<FUNC, unsigned long long> {
  static_assert(sizeof(PackType) == sizeof(unsigned long long),
      "PackType must be the same size as unsigned long long.");

  // Combines x and y with FUNC and hands the result back as a PackType.
  __device__ PackType operator()(const PackType x, const PackType y) const {
    // Materialize the functor result as an unsigned 64-bit value first, then
    // let the return statement convert it to PackType.
    const unsigned long long combined = FUNC()(x, y);
    return combined;
  }
};
|
||||||
|
|
||||||
|
// MULTI specialization for signed 64-bit elements.
// The static_assert guarantees that a PackType is exactly as wide as one
// long long, so the functor is applied once to the whole pack.
template<class FUNC>
struct MULTI<FUNC, long long> {
  static_assert(sizeof(PackType) == sizeof(long long),
      "PackType must be the same size as long long.");

  // Combines x and y with FUNC, operating on them as signed 64-bit values,
  // and hands the result back as a PackType.
  __device__ PackType operator()(const PackType x, const PackType y) const {
    // Convert both operands to long long before the functor sees them so the
    // operation is carried out with signed semantics.
    const long long lhs = static_cast<long long>(x);
    const long long rhs = static_cast<long long>(y);
    const long long combined = FUNC()(lhs, rhs);
    return combined;
  }
};
|
||||||
|
|
||||||
template<typename T, bool FETCHTWO>
|
template<typename T, bool FETCHTWO>
|
||||||
__device__ inline void FetchOneOrTwo64b(PackType& s0,
|
__device__ inline void FetchOneOrTwo64b(PackType& s0,
|
||||||
const volatile T * __restrict__ const src0, PackType& s1,
|
const volatile T * __restrict__ const src0, PackType& s1,
|
||||||
|
@ -117,7 +117,9 @@ typedef enum { ncclChar = 0,
|
|||||||
#endif
|
#endif
|
||||||
ncclFloat = 3,
|
ncclFloat = 3,
|
||||||
ncclDouble = 4,
|
ncclDouble = 4,
|
||||||
nccl_NUM_TYPES = 5 } ncclDataType_t;
|
ncclInt64 = 5,
|
||||||
|
ncclUint64 = 6,
|
||||||
|
nccl_NUM_TYPES = 7 } ncclDataType_t;
|
||||||
|
|
||||||
/* Reduces data arrays of length count in sendbuff into recvbuf using op operation.
|
/* Reduces data arrays of length count in sendbuff into recvbuf using op operation.
|
||||||
* recvbuf may be NULL on all calls except for root device.
|
* recvbuf may be NULL on all calls except for root device.
|
||||||
|
@ -393,6 +393,10 @@ public:
|
|||||||
return ncclReduceWithType<float>(sendbuff, recvbuff, count, op, root, comm, stream);
|
return ncclReduceWithType<float>(sendbuff, recvbuff, count, op, root, comm, stream);
|
||||||
case ncclDouble:
|
case ncclDouble:
|
||||||
return ncclReduceWithType<double>(sendbuff, recvbuff, count, op, root, comm, stream);
|
return ncclReduceWithType<double>(sendbuff, recvbuff, count, op, root, comm, stream);
|
||||||
|
case ncclInt64:
|
||||||
|
return ncclReduceWithType<long long>(sendbuff, recvbuff, count, op, root, comm, stream);
|
||||||
|
case ncclUint64:
|
||||||
|
return ncclReduceWithType<unsigned long long>(sendbuff, recvbuff, count, op, root, comm, stream);
|
||||||
}
|
}
|
||||||
return ncclInvalidType;
|
return ncclInvalidType;
|
||||||
}
|
}
|
||||||
|
@ -474,6 +474,12 @@ public:
|
|||||||
case ncclDouble:
|
case ncclDouble:
|
||||||
return ncclReduceScatterWithType<double>(sendbuff, recvbuff, recvcount,
|
return ncclReduceScatterWithType<double>(sendbuff, recvbuff, recvcount,
|
||||||
op, comm, stream);
|
op, comm, stream);
|
||||||
|
case ncclInt64:
|
||||||
|
return ncclReduceScatterWithType<long long>(sendbuff, recvbuff, recvcount,
|
||||||
|
op, comm, stream);
|
||||||
|
case ncclUint64:
|
||||||
|
return ncclReduceScatterWithType<unsigned long long>(sendbuff, recvbuff, recvcount,
|
||||||
|
op, comm, stream);
|
||||||
}
|
}
|
||||||
return ncclInvalidType;
|
return ncclInvalidType;
|
||||||
}
|
}
|
||||||
|
@ -271,6 +271,8 @@ int main(int argc, char* argv[]) {
|
|||||||
#endif
|
#endif
|
||||||
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
||||||
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
||||||
|
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
|
||||||
|
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
|
@ -285,6 +285,8 @@ int main(int argc, char* argv[]) {
|
|||||||
#endif
|
#endif
|
||||||
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
||||||
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
||||||
|
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
|
||||||
|
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
|
@ -89,6 +89,12 @@ void GenerateRandom<double>(curandGenerator_t generator, double * const dest,
|
|||||||
CURAND_CHK(curandGenerateUniformDouble(generator, dest, N));
|
CURAND_CHK(curandGenerateUniformDouble(generator, dest, N));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateRandom specialization for unsigned 64-bit output.
//
//   generator  cuRAND generator handle to draw from
//   dest       buffer passed straight to curandGenerateLongLong
//   N          number of 64-bit values requested
//
// Any failure status from cuRAND is handled by CURAND_CHK.
template<>
void GenerateRandom<unsigned long long>(curandGenerator_t generator,
                                        unsigned long long * const dest,
                                        const int N) {
  CURAND_CHK(curandGenerateLongLong(generator, dest, N));
}
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
void Randomize(T* const dest, const int N, const int randomSeed) {
|
void Randomize(T* const dest, const int N, const int randomSeed) {
|
||||||
@ -100,6 +106,24 @@ void Randomize(T* const dest, const int N, const int randomSeed) {
|
|||||||
CUDACHECK(cudaDeviceSynchronize());
|
CUDACHECK(cudaDeviceSynchronize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Randomize specialization for unsigned 64-bit buffers.
//
// Fills dest with N values from a CURAND_RNG_QUASI_SOBOL64 generator, then
// waits for the device to finish.
//
//   dest        buffer handed to curandGenerateLongLong via GenerateRandom
//   N           number of 64-bit values to generate
//   randomSeed  selects where in the Sobol sequence generation starts, so
//               different seeds yield different buffer contents
//
// Errors from cuRAND / the CUDA runtime are handled by CURAND_CHK / CUDACHECK.
template<>
void Randomize(unsigned long long* const dest, const int N, const int randomSeed) {
  curandGenerator_t gen;
  CURAND_CHK(curandCreateGenerator(&gen, CURAND_RNG_QUASI_SOBOL64));

  // Previously randomSeed was ignored, so every call produced identical data.
  // Quasi-random generators do not take a pseudo-random seed; the absolute
  // sequence offset is used instead. Going through unsigned int keeps a
  // negative seed from turning into a sign-extended 64-bit offset.
  const unsigned long long offset = static_cast<unsigned int>(randomSeed);
  CURAND_CHK(curandSetGeneratorOffset(gen, offset));

  GenerateRandom<unsigned long long>(gen, dest, N);
  CURAND_CHK(curandDestroyGenerator(gen));
  CUDACHECK(cudaDeviceSynchronize());
}
|
||||||
|
|
||||||
|
// Randomize specialization for signed 64-bit buffers.
//
// Fills dest with N values from a CURAND_RNG_QUASI_SOBOL64 generator, then
// waits for the device to finish. The generator emits unsigned 64-bit values;
// they are written into dest unchanged and read back as long long.
//
//   dest        buffer to fill, viewed as unsigned long long for generation
//   N           number of 64-bit values to generate
//   randomSeed  selects where in the Sobol sequence generation starts, so
//               different seeds yield different buffer contents
//
// Errors from cuRAND / the CUDA runtime are handled by CURAND_CHK / CUDACHECK.
template<>
void Randomize(long long* const dest, const int N, const int randomSeed) {
  curandGenerator_t gen;
  CURAND_CHK(curandCreateGenerator(&gen, CURAND_RNG_QUASI_SOBOL64));

  // Previously randomSeed was ignored, so every call produced identical data.
  // Quasi-random generators do not take a pseudo-random seed; the absolute
  // sequence offset is used instead. Going through unsigned int keeps a
  // negative seed from turning into a sign-extended 64-bit offset.
  const unsigned long long offset = static_cast<unsigned int>(randomSeed);
  CURAND_CHK(curandSetGeneratorOffset(gen, offset));

  // Same bit width, different signedness: reuse the unsigned generator path.
  GenerateRandom<unsigned long long>(gen, reinterpret_cast<unsigned long long*>(dest), N);
  CURAND_CHK(curandDestroyGenerator(gen));
  CUDACHECK(cudaDeviceSynchronize());
}
|
||||||
|
|
||||||
#ifdef CUDA_HAS_HALF
|
#ifdef CUDA_HAS_HALF
|
||||||
__global__ void halve(const float * src, half* dest, int N) {
|
__global__ void halve(const float * src, half* dest, int N) {
|
||||||
for(int tid = threadIdx.x + blockIdx.x*blockDim.x;
|
for(int tid = threadIdx.x + blockIdx.x*blockDim.x;
|
||||||
@ -268,6 +292,8 @@ std::string TypeName(const ncclDataType_t type) {
|
|||||||
#endif
|
#endif
|
||||||
case ncclFloat: return "float";
|
case ncclFloat: return "float";
|
||||||
case ncclDouble: return "double";
|
case ncclDouble: return "double";
|
||||||
|
case ncclInt64: return "int64";
|
||||||
|
case ncclUint64: return "uint64";
|
||||||
default: return "unknown";
|
default: return "unknown";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user