Add support for CUDA9 half semantics
This commit is contained in:
parent
ccfc4567dc
commit
29a1a916dc
2
Makefile
2
Makefile
@ -54,7 +54,7 @@ endif
|
|||||||
|
|
||||||
NCCL_MAJOR := 1
|
NCCL_MAJOR := 1
|
||||||
NCCL_MINOR := 3
|
NCCL_MINOR := 3
|
||||||
NCCL_PATCH := 4
|
NCCL_PATCH := 5
|
||||||
CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH)
|
CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH)
|
||||||
|
|
||||||
CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev)
|
CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev)
|
||||||
|
@ -35,25 +35,33 @@ T vFetch(const volatile T* ptr) {
|
|||||||
return *ptr;
|
return *ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef CUDA_HAS_HALF
|
|
||||||
template<> inline __device__
|
|
||||||
half vFetch<half>(const volatile half* ptr) {
|
|
||||||
half r;
|
|
||||||
r.x = ptr->x;
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template<typename T> inline __device__
|
template<typename T> inline __device__
|
||||||
void vStore(volatile T* ptr, const T val) {
|
void vStore(volatile T* ptr, const T val) {
|
||||||
*ptr = val;
|
*ptr = val;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef CUDA_HAS_HALF
|
#ifdef CUDA_HAS_HALF
|
||||||
|
#if CUDART_VERSION < 9000
|
||||||
|
template<> inline __device__
|
||||||
|
half vFetch<half>(const volatile half* ptr) {
|
||||||
|
half r;
|
||||||
|
r.x = ptr->x;
|
||||||
|
return r;
|
||||||
|
}
|
||||||
template<> inline __device__
|
template<> inline __device__
|
||||||
void vStore<half>(volatile half* ptr, const half val) {
|
void vStore<half>(volatile half* ptr, const half val) {
|
||||||
ptr->x = val.x;
|
ptr->x = val.x;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
template<> inline __device__
|
||||||
|
half vFetch<half>(const volatile half* ptr) {
|
||||||
|
return *((half*)ptr);
|
||||||
|
}
|
||||||
|
template<> inline __device__
|
||||||
|
void vStore<half>(volatile half* ptr, const half val) {
|
||||||
|
*((half*)ptr) = val;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
__device__ unsigned int spinct;
|
__device__ unsigned int spinct;
|
||||||
@ -125,24 +133,22 @@ struct MULTI<FUNC, int> {
|
|||||||
#ifdef CUDA_HAS_HALF
|
#ifdef CUDA_HAS_HALF
|
||||||
template<class FUNC>
|
template<class FUNC>
|
||||||
struct MULTI<FUNC, half> {
|
struct MULTI<FUNC, half> {
|
||||||
static_assert(sizeof(PackType) == 2 * sizeof(float),
|
static_assert(sizeof(PackType) == 4 * sizeof(half),
|
||||||
"PackType must be twice the size of float.");
|
"PackType must be four times the size of half.");
|
||||||
union converter {
|
|
||||||
PackType storage;
|
struct PackHalf2 {
|
||||||
struct {
|
|
||||||
half2 a, b;
|
half2 a, b;
|
||||||
};
|
};
|
||||||
};
|
|
||||||
|
|
||||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||||
converter cx, cy, cr;
|
struct PackHalf2 cx, cy, cr;
|
||||||
cx.storage = x;
|
cx = *(reinterpret_cast<const struct PackHalf2*>(&x));
|
||||||
cy.storage = y;
|
cy = *(reinterpret_cast<const struct PackHalf2*>(&y));
|
||||||
|
|
||||||
cr.a = FUNC()(cx.a, cy.a);
|
cr.a = FUNC()(cx.a, cy.a);
|
||||||
cr.b = FUNC()(cx.b, cy.b);
|
cr.b = FUNC()(cx.b, cy.b);
|
||||||
|
|
||||||
return cr.storage;
|
return *(reinterpret_cast<PackType*>(&cr));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
@ -24,9 +24,7 @@ struct FuncPassA<half> {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
__device__ half operator()(const half x, const half y) const {
|
__device__ half operator()(const half x, const half y) const {
|
||||||
half r;
|
return x;
|
||||||
r.x = x.x;
|
|
||||||
return r;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
Loading…
x
Reference in New Issue
Block a user