Add int64 and uint64 types for all algorithms and tests
This commit is contained in:
parent
27d32ac5d9
commit
41ce4ca9fc
@ -477,6 +477,12 @@ public:
|
||||
case ncclDouble:
|
||||
return ncclAllGatherWithType<double>(sendbuff, recvbuff, count, comm,
|
||||
numUnroll, stream);
|
||||
case ncclInt64:
|
||||
return ncclAllGatherWithType<long long>(sendbuff, recvbuff, count, comm,
|
||||
numUnroll, stream);
|
||||
case ncclUint64:
|
||||
return ncclAllGatherWithType<unsigned long long>(sendbuff, recvbuff, count, comm,
|
||||
numUnroll, stream);
|
||||
}
|
||||
return ncclInvalidType;
|
||||
}
|
||||
|
@ -224,6 +224,8 @@ int main(int argc, char* argv[]) {
|
||||
#endif
|
||||
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
||||
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
||||
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
|
||||
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);
|
||||
|
||||
printf("\n");
|
||||
|
||||
|
@ -489,6 +489,12 @@ public:
|
||||
case ncclDouble:
|
||||
return ncclAllReduceWithType<double>(sendbuff, recvbuff, count, op,
|
||||
comm, stream);
|
||||
case ncclInt64:
|
||||
return ncclAllReduceWithType<long long>(sendbuff, recvbuff, count, op,
|
||||
comm, stream);
|
||||
case ncclUint64:
|
||||
return ncclAllReduceWithType<unsigned long long int>(sendbuff, recvbuff, count, op,
|
||||
comm, stream);
|
||||
}
|
||||
|
||||
return ncclInvalidType;
|
||||
|
@ -287,6 +287,8 @@ int main(int argc, char* argv[]) {
|
||||
#endif
|
||||
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
||||
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
||||
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
|
||||
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);
|
||||
|
||||
printf("\n");
|
||||
|
||||
|
@ -396,6 +396,10 @@ public:
|
||||
return ncclBcastWithType<float>(buff, count, root, comm, numUnroll, stream);
|
||||
case ncclDouble:
|
||||
return ncclBcastWithType<double>(buff, count, root, comm, numUnroll, stream);
|
||||
case ncclInt64:
|
||||
return ncclBcastWithType<long long>(buff, count, root, comm, numUnroll, stream);
|
||||
case ncclUint64:
|
||||
return ncclBcastWithType<unsigned long long>(buff, count, root, comm, numUnroll, stream);
|
||||
}
|
||||
return ncclInvalidType;
|
||||
}
|
||||
|
@ -224,6 +224,8 @@ int main(int argc, char* argv[]) {
|
||||
#endif
|
||||
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
||||
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
||||
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
|
||||
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);
|
||||
|
||||
printf("\n");
|
||||
|
||||
|
@ -174,6 +174,26 @@ struct MULTI<FUNC, double> {
|
||||
}
|
||||
};
|
||||
|
||||
// MULTI specialization for unsigned long long elements.
// The static_assert guarantees PackType and unsigned long long have the same
// size; presumably a PackType therefore carries exactly one element here
// (confirm against the PackType definition).
template<class FUNC>
struct MULTI<FUNC, unsigned long long> {
  static_assert(sizeof(PackType) == sizeof(unsigned long long),
      "PackType must be the same size as unsigned long long.");

  // Applies FUNC to the two operands unchanged and returns the result,
  // routed through an unsigned long long temporary, as a PackType.
  __device__ PackType operator()(const PackType x, const PackType y) const {
    const unsigned long long combined = FUNC()(x, y);
    return combined;
  }
};
|
||||
|
||||
// MULTI specialization for (signed) long long elements.
// The static_assert guarantees PackType and long long have the same size.
template<class FUNC>
struct MULTI<FUNC, long long> {
  static_assert(sizeof(PackType) == sizeof(long long),
      "PackType must be the same size as long long.");

  // Converts both operands to long long so FUNC operates on signed values,
  // then hands the signed result back as a PackType.
  __device__ PackType operator()(const PackType x, const PackType y) const {
    const long long lhs = static_cast<long long>(x);
    const long long rhs = static_cast<long long>(y);
    const long long combined = FUNC()(lhs, rhs);
    return combined;
  }
};
|
||||
|
||||
template<typename T, bool FETCHTWO>
|
||||
__device__ inline void FetchOneOrTwo64b(PackType& s0,
|
||||
const volatile T * __restrict__ const src0, PackType& s1,
|
||||
|
@ -117,7 +117,9 @@ typedef enum { ncclChar = 0,
|
||||
#endif
|
||||
ncclFloat = 3,
|
||||
ncclDouble = 4,
|
||||
nccl_NUM_TYPES = 5 } ncclDataType_t;
|
||||
ncclInt64 = 5,
|
||||
ncclUint64 = 6,
|
||||
nccl_NUM_TYPES = 7 } ncclDataType_t;
|
||||
|
||||
/* Reduces data arrays of length count in sendbuff into recvbuf using op operation.
|
||||
* recvbuf may be NULL on all calls except for root device.
|
||||
|
@ -393,6 +393,10 @@ public:
|
||||
return ncclReduceWithType<float>(sendbuff, recvbuff, count, op, root, comm, stream);
|
||||
case ncclDouble:
|
||||
return ncclReduceWithType<double>(sendbuff, recvbuff, count, op, root, comm, stream);
|
||||
case ncclInt64:
|
||||
return ncclReduceWithType<long long>(sendbuff, recvbuff, count, op, root, comm, stream);
|
||||
case ncclUint64:
|
||||
return ncclReduceWithType<unsigned long long>(sendbuff, recvbuff, count, op, root, comm, stream);
|
||||
}
|
||||
return ncclInvalidType;
|
||||
}
|
||||
|
@ -474,6 +474,12 @@ public:
|
||||
case ncclDouble:
|
||||
return ncclReduceScatterWithType<double>(sendbuff, recvbuff, recvcount,
|
||||
op, comm, stream);
|
||||
case ncclInt64:
|
||||
return ncclReduceScatterWithType<long long>(sendbuff, recvbuff, recvcount,
|
||||
op, comm, stream);
|
||||
case ncclUint64:
|
||||
return ncclReduceScatterWithType<unsigned long long>(sendbuff, recvbuff, recvcount,
|
||||
op, comm, stream);
|
||||
}
|
||||
return ncclInvalidType;
|
||||
}
|
||||
|
@ -271,6 +271,8 @@ int main(int argc, char* argv[]) {
|
||||
#endif
|
||||
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
||||
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
||||
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
|
||||
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);
|
||||
|
||||
printf("\n");
|
||||
|
||||
|
@ -285,6 +285,8 @@ int main(int argc, char* argv[]) {
|
||||
#endif
|
||||
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
|
||||
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
|
||||
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
|
||||
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);
|
||||
|
||||
printf("\n");
|
||||
|
||||
|
@ -89,6 +89,12 @@ void GenerateRandom<double>(curandGenerator_t generator, double * const dest,
|
||||
CURAND_CHK(curandGenerateUniformDouble(generator, dest, N));
|
||||
}
|
||||
|
||||
// GenerateRandom specialization for unsigned long long.
// Asks `generator` to write N 64-bit values into `dest` through
// curandGenerateLongLong; the returned status is checked by CURAND_CHK.
// NOTE(review): per the cuRAND documentation this call is only valid for
// 64-bit quasirandom generators, so callers must create `generator`
// accordingly.
template<>
void GenerateRandom<unsigned long long>(curandGenerator_t generator,
    unsigned long long * const dest, const int N) {
  CURAND_CHK(curandGenerateLongLong(generator, dest, N));
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
void Randomize(T* const dest, const int N, const int randomSeed) {
|
||||
@ -100,6 +106,24 @@ void Randomize(T* const dest, const int N, const int randomSeed) {
|
||||
CUDACHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
// Randomize specialization for unsigned long long buffers.
// Creates a CURAND_RNG_QUASI_SOBOL64 generator, fills `dest` with N values
// via GenerateRandom<unsigned long long>, destroys the generator and then
// synchronizes the device.
// NOTE: `randomSeed` is accepted for signature compatibility but is not used
// by this specialization.
template<>
void Randomize(unsigned long long* const dest, const int N, const int randomSeed) {
  curandGenerator_t sobolGen;
  CURAND_CHK(curandCreateGenerator(&sobolGen, CURAND_RNG_QUASI_SOBOL64));

  GenerateRandom<unsigned long long>(sobolGen, dest, N);

  CURAND_CHK(curandDestroyGenerator(sobolGen));
  CUDACHECK(cudaDeviceSynchronize());
}
|
||||
|
||||
// Randomize specialization for (signed) long long buffers.
// The buffer is viewed as unsigned long long and filled with N values from a
// CURAND_RNG_QUASI_SOBOL64 generator via GenerateRandom<unsigned long long>;
// the generator is then destroyed and the device synchronized.
// NOTE: `randomSeed` is accepted for signature compatibility but is not used
// by this specialization.
template<>
void Randomize(long long* const dest, const int N, const int randomSeed) {
  curandGenerator_t sobolGen;
  CURAND_CHK(curandCreateGenerator(&sobolGen, CURAND_RNG_QUASI_SOBOL64));

  // Same storage, reinterpreted so the unsigned generator path can be reused.
  unsigned long long * const rawDest = (unsigned long long *)dest;
  GenerateRandom<unsigned long long>(sobolGen, rawDest, N);

  CURAND_CHK(curandDestroyGenerator(sobolGen));
  CUDACHECK(cudaDeviceSynchronize());
}
|
||||
|
||||
#ifdef CUDA_HAS_HALF
|
||||
__global__ void halve(const float * src, half* dest, int N) {
|
||||
for(int tid = threadIdx.x + blockIdx.x*blockDim.x;
|
||||
@ -268,6 +292,8 @@ std::string TypeName(const ncclDataType_t type) {
|
||||
#endif
|
||||
case ncclFloat: return "float";
|
||||
case ncclDouble: return "double";
|
||||
case ncclInt64: return "int64";
|
||||
case ncclUint64: return "uint64";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user