diff --git a/Makefile b/Makefile index 8f34fcb..c37b7f7 100644 --- a/Makefile +++ b/Makefile @@ -54,7 +54,7 @@ endif NCCL_MAJOR := 1 NCCL_MINOR := 3 -NCCL_PATCH := 4 +NCCL_PATCH := 5 CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH) CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev) diff --git a/src/common_kernel.h b/src/common_kernel.h index cc71f8a..b96519f 100644 --- a/src/common_kernel.h +++ b/src/common_kernel.h @@ -35,25 +35,33 @@ T vFetch(const volatile T* ptr) { return *ptr; } -#ifdef CUDA_HAS_HALF -template<> inline __device__ -half vFetch(const volatile half* ptr) { - half r; - r.x = ptr->x; - return r; -} -#endif - template inline __device__ void vStore(volatile T* ptr, const T val) { *ptr = val; } #ifdef CUDA_HAS_HALF +#if CUDART_VERSION < 9000 +template<> inline __device__ +half vFetch(const volatile half* ptr) { + half r; + r.x = ptr->x; + return r; +} template<> inline __device__ void vStore(volatile half* ptr, const half val) { ptr->x = val.x; } +#else +template<> inline __device__ +half vFetch(const volatile half* ptr) { + return *((half*)ptr); +} +template<> inline __device__ +void vStore(volatile half* ptr, const half val) { + *((half*)ptr) = val; +} +#endif #endif __device__ unsigned int spinct; @@ -125,24 +133,22 @@ struct MULTI { #ifdef CUDA_HAS_HALF template struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(float), - "PackType must be twice the size of float."); - union converter { - PackType storage; - struct { - half2 a, b; - }; + static_assert(sizeof(PackType) == 4 * sizeof(half), + "PackType must be four times the size of half."); + + struct PackHalf2 { + half2 a, b; }; __device__ PackType operator()(const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; + struct PackHalf2 cx, cy, cr; + cx = *(reinterpret_cast(&x)); + cy = *(reinterpret_cast(&y)); cr.a = FUNC()(cx.a, cy.a); cr.b = FUNC()(cx.b, cy.b); - return cr.storage; + return *(reinterpret_cast(&cr)); } }; #endif diff --git a/src/copy_kernel.h b/src/copy_kernel.h index 8464699..0f69748 100644 --- a/src/copy_kernel.h +++ b/src/copy_kernel.h @@ -24,9 +24,7 @@ struct FuncPassA { return x; } __device__ half operator()(const half x, const half y) const { - half r; - r.x = x.x; - return r; + return x; } }; #endif