Add support for CUDA9 half semantics

This commit is contained in:
Sylvain Jeaugey 2017-06-07 09:57:12 -07:00
parent ccfc4567dc
commit 29a1a916dc
3 changed files with 28 additions and 24 deletions

View File

@ -54,7 +54,7 @@ endif
NCCL_MAJOR := 1 NCCL_MAJOR := 1
NCCL_MINOR := 3 NCCL_MINOR := 3
NCCL_PATCH := 4 NCCL_PATCH := 5
CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH) CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH)
CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev) CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev)

View File

@ -35,25 +35,33 @@ T vFetch(const volatile T* ptr) {
return *ptr; return *ptr;
} }
#ifdef CUDA_HAS_HALF
template<> inline __device__
half vFetch<half>(const volatile half* ptr) {
half r;
r.x = ptr->x;
return r;
}
#endif
template<typename T> inline __device__ template<typename T> inline __device__
void vStore(volatile T* ptr, const T val) { void vStore(volatile T* ptr, const T val) {
*ptr = val; *ptr = val;
} }
#ifdef CUDA_HAS_HALF #ifdef CUDA_HAS_HALF
#if CUDART_VERSION < 9000
template<> inline __device__
half vFetch<half>(const volatile half* ptr) {
half r;
r.x = ptr->x;
return r;
}
template<> inline __device__ template<> inline __device__
void vStore<half>(volatile half* ptr, const half val) { void vStore<half>(volatile half* ptr, const half val) {
ptr->x = val.x; ptr->x = val.x;
} }
#else
template<> inline __device__
half vFetch<half>(const volatile half* ptr) {
return *((half*)ptr);
}
template<> inline __device__
void vStore<half>(volatile half* ptr, const half val) {
*((half*)ptr) = val;
}
#endif
#endif #endif
__device__ unsigned int spinct; __device__ unsigned int spinct;
@ -125,24 +133,22 @@ struct MULTI<FUNC, int> {
#ifdef CUDA_HAS_HALF #ifdef CUDA_HAS_HALF
template<class FUNC> template<class FUNC>
struct MULTI<FUNC, half> { struct MULTI<FUNC, half> {
static_assert(sizeof(PackType) == 2 * sizeof(float), static_assert(sizeof(PackType) == 4 * sizeof(half),
"PackType must be twice the size of float."); "PackType must be four times the size of half.");
union converter {
PackType storage; struct PackHalf2 {
struct {
half2 a, b; half2 a, b;
}; };
};
__device__ PackType operator()(const PackType x, const PackType y) const { __device__ PackType operator()(const PackType x, const PackType y) const {
converter cx, cy, cr; struct PackHalf2 cx, cy, cr;
cx.storage = x; cx = *(reinterpret_cast<const struct PackHalf2*>(&x));
cy.storage = y; cy = *(reinterpret_cast<const struct PackHalf2*>(&y));
cr.a = FUNC()(cx.a, cy.a); cr.a = FUNC()(cx.a, cy.a);
cr.b = FUNC()(cx.b, cy.b); cr.b = FUNC()(cx.b, cy.b);
return cr.storage; return *(reinterpret_cast<PackType*>(&cr));
} }
}; };
#endif #endif

View File

@ -24,9 +24,7 @@ struct FuncPassA<half> {
return x; return x;
} }
__device__ half operator()(const half x, const half y) const { __device__ half operator()(const half x, const half y) const {
half r; return x;
r.x = x.x;
return r;
} }
}; };
#endif #endif