diff --git a/src/all_gather.cu b/src/all_gather.cu index a83385f..0f90efd 100644 --- a/src/all_gather.cu +++ b/src/all_gather.cu @@ -477,6 +477,12 @@ public: case ncclDouble: return ncclAllGatherWithType(sendbuff, recvbuff, count, comm, numUnroll, stream); + case ncclInt64: + return ncclAllGatherWithType(sendbuff, recvbuff, count, comm, + numUnroll, stream); + case ncclUint64: + return ncclAllGatherWithType(sendbuff, recvbuff, count, comm, + numUnroll, stream); } return ncclInvalidType; } diff --git a/src/all_gather_test.cu b/src/all_gather_test.cu index a928806..a9e1c1e 100644 --- a/src/all_gather_test.cu +++ b/src/all_gather_test.cu @@ -224,6 +224,8 @@ int main(int argc, char* argv[]) { #endif RunTests(N / sizeof(float), ncclFloat, comms, dList); RunTests(N / sizeof(double), ncclDouble, comms, dList); + RunTests(N / sizeof(long long), ncclInt64, comms, dList); + RunTests(N / sizeof(unsigned long long), ncclUint64, comms, dList); printf("\n"); diff --git a/src/all_reduce.cu b/src/all_reduce.cu index cf84de0..670d45c 100644 --- a/src/all_reduce.cu +++ b/src/all_reduce.cu @@ -489,6 +489,12 @@ public: case ncclDouble: return ncclAllReduceWithType(sendbuff, recvbuff, count, op, comm, stream); + case ncclInt64: + return ncclAllReduceWithType(sendbuff, recvbuff, count, op, + comm, stream); + case ncclUint64: + return ncclAllReduceWithType(sendbuff, recvbuff, count, op, + comm, stream); } return ncclInvalidType; diff --git a/src/all_reduce_test.cu b/src/all_reduce_test.cu index aa18697..f46bd48 100644 --- a/src/all_reduce_test.cu +++ b/src/all_reduce_test.cu @@ -287,6 +287,8 @@ int main(int argc, char* argv[]) { #endif RunTests(N / sizeof(float), ncclFloat, comms, dList); RunTests(N / sizeof(double), ncclDouble, comms, dList); + RunTests(N / sizeof(long long), ncclInt64, comms, dList); + RunTests(N / sizeof(unsigned long long), ncclUint64, comms, dList); printf("\n"); diff --git a/src/broadcast.cu b/src/broadcast.cu index cde9c9e..c3e4c20 100644 --- a/src/broadcast.cu +++ b/src/broadcast.cu @@ -396,6 +396,10 @@ public: return ncclBcastWithType(buff, count, root, comm, numUnroll, stream); case ncclDouble: return ncclBcastWithType(buff, count, root, comm, numUnroll, stream); + case ncclInt64: + return ncclBcastWithType(buff, count, root, comm, numUnroll, stream); + case ncclUint64: + return ncclBcastWithType(buff, count, root, comm, numUnroll, stream); } return ncclInvalidType; } diff --git a/src/broadcast_test.cu b/src/broadcast_test.cu index 344ca7f..9c85a1f 100644 --- a/src/broadcast_test.cu +++ b/src/broadcast_test.cu @@ -224,6 +224,8 @@ int main(int argc, char* argv[]) { #endif RunTests(N / sizeof(float), ncclFloat, comms, dList); RunTests(N / sizeof(double), ncclDouble, comms, dList); + RunTests(N / sizeof(long long), ncclInt64, comms, dList); + RunTests(N / sizeof(unsigned long long), ncclUint64, comms, dList); printf("\n"); diff --git a/src/common_kernel.h b/src/common_kernel.h index 5b6770a..e30bf5c 100644 --- a/src/common_kernel.h +++ b/src/common_kernel.h @@ -174,6 +174,26 @@ struct MULTI { } }; +template +struct MULTI { + static_assert(sizeof(PackType) == sizeof(unsigned long long), + "PackType must be the same size as unsigned long long."); + __device__ PackType operator()(const PackType x, const PackType y) const { + unsigned long long rv = FUNC()(x, y); + return rv; + } +}; + +template +struct MULTI { + static_assert(sizeof(PackType) == sizeof(long long), + "PackType must be the same size as long long."); + __device__ PackType operator()(const PackType x, const PackType y) const { + long long rv = FUNC()((long long)x, (long long)y); + return rv; + } +}; + template __device__ inline void FetchOneOrTwo64b(PackType& s0, const volatile T * __restrict__ const src0, PackType& s1, diff --git a/src/nccl.h b/src/nccl.h index 94bb556..5173b13 100644 --- a/src/nccl.h +++ b/src/nccl.h @@ -117,7 +117,9 @@ typedef enum { ncclChar = 0, #endif ncclFloat = 3, ncclDouble = 4, - nccl_NUM_TYPES = 5 } ncclDataType_t; + ncclInt64 = 5, + ncclUint64 = 6, + nccl_NUM_TYPES = 7 } ncclDataType_t; /* Reduces data arrays of length count in sendbuff into recvbuf using op operation. * recvbuf may be NULL on all calls except for root device. diff --git a/src/reduce.cu b/src/reduce.cu index 2863e2a..6752d24 100644 --- a/src/reduce.cu +++ b/src/reduce.cu @@ -393,6 +393,10 @@ public: return ncclReduceWithType(sendbuff, recvbuff, count, op, root, comm, stream); case ncclDouble: return ncclReduceWithType(sendbuff, recvbuff, count, op, root, comm, stream); + case ncclInt64: + return ncclReduceWithType(sendbuff, recvbuff, count, op, root, comm, stream); + case ncclUint64: + return ncclReduceWithType(sendbuff, recvbuff, count, op, root, comm, stream); } return ncclInvalidType; } diff --git a/src/reduce_scatter.cu b/src/reduce_scatter.cu index 3419caa..e1860c5 100644 --- a/src/reduce_scatter.cu +++ b/src/reduce_scatter.cu @@ -474,6 +474,12 @@ public: case ncclDouble: return ncclReduceScatterWithType(sendbuff, recvbuff, recvcount, op, comm, stream); + case ncclInt64: + return ncclReduceScatterWithType(sendbuff, recvbuff, recvcount, + op, comm, stream); + case ncclUint64: + return ncclReduceScatterWithType(sendbuff, recvbuff, recvcount, + op, comm, stream); } return ncclInvalidType; } diff --git a/src/reduce_scatter_test.cu b/src/reduce_scatter_test.cu index c1c87be..da205d5 100644 --- a/src/reduce_scatter_test.cu +++ b/src/reduce_scatter_test.cu @@ -271,6 +271,8 @@ int main(int argc, char* argv[]) { #endif RunTests(N / sizeof(float), ncclFloat, comms, dList); RunTests(N / sizeof(double), ncclDouble, comms, dList); + RunTests(N / sizeof(long long), ncclInt64, comms, dList); + RunTests(N / sizeof(unsigned long long), ncclUint64, comms, dList); printf("\n"); diff --git a/src/reduce_test.cu b/src/reduce_test.cu index fc06225..ce17e32 100644 --- a/src/reduce_test.cu +++ b/src/reduce_test.cu @@ -285,6 +285,8 @@ int main(int argc, char* argv[]) { #endif RunTests(N / sizeof(float), ncclFloat, comms, dList); RunTests(N / sizeof(double), ncclDouble, comms, dList); + RunTests(N / sizeof(long long), ncclInt64, comms, dList); + RunTests(N / sizeof(unsigned long long), ncclUint64, comms, dList); printf("\n"); diff --git a/src/test_utilities.h b/src/test_utilities.h index ecf760c..a5d3661 100644 --- a/src/test_utilities.h +++ b/src/test_utilities.h @@ -89,6 +89,12 @@ void GenerateRandom(curandGenerator_t generator, double * const dest, CURAND_CHK(curandGenerateUniformDouble(generator, dest, N)); } +template<> +void GenerateRandom(curandGenerator_t generator, unsigned long long * const dest, + const int N) { + CURAND_CHK(curandGenerateLongLong(generator, dest, N)); +} + template void Randomize(T* const dest, const int N, const int randomSeed) { @@ -100,6 +106,24 @@ void Randomize(T* const dest, const int N, const int randomSeed) { CUDACHECK(cudaDeviceSynchronize()); } +template<> +void Randomize(unsigned long long* const dest, const int N, const int randomSeed) { + curandGenerator_t gen; + CURAND_CHK(curandCreateGenerator(&gen, CURAND_RNG_QUASI_SOBOL64)); + GenerateRandom(gen, dest, N); + CURAND_CHK(curandDestroyGenerator(gen)); + CUDACHECK(cudaDeviceSynchronize()); +} + +template<> +void Randomize(long long* const dest, const int N, const int randomSeed) { + curandGenerator_t gen; + CURAND_CHK(curandCreateGenerator(&gen, CURAND_RNG_QUASI_SOBOL64)); + GenerateRandom(gen, (unsigned long long *)dest, N); + CURAND_CHK(curandDestroyGenerator(gen)); + CUDACHECK(cudaDeviceSynchronize()); +} + #ifdef CUDA_HAS_HALF __global__ void halve(const float * src, half* dest, int N) { for(int tid = threadIdx.x + blockIdx.x*blockDim.x; @@ -268,6 +292,8 @@ std::string TypeName(const ncclDataType_t type) { #endif case ncclFloat: return "float"; case ncclDouble: return "double"; + case ncclInt64: return "int64"; + case ncclUint64: return "uint64"; default: return "unknown"; } }