nccl/src/collectives/device/reduce_kernel.h
Ke Wen e11238b302 2.11.4-1
Add new API for creating a reduction operation which multiplies the input by a rank-specific scalar before doing an inter-rank summation (see: ncclRedOpCreatePreMulSum).
Improve CollNet (SHARP) performance of ncclAllReduce when captured in a CUDA Graph via user buffer registration.
Add environment variable NCCL_NET_PLUGIN="<suffix>" to allow user to choose among multiple NCCL net plugins by substituting into "libnccl-net-<suffix>.so".
Fix memory leak of NVB connections.
Fix topology detection of IB Virtual Functions (SR-IOV).
2021-09-08 16:06:23 -07:00

706 lines
20 KiB
C++

/*************************************************************************
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#ifndef NCCL_REDUCE_KERNEL_H_
#define NCCL_REDUCE_KERNEL_H_
#include "common_kernel.h"
#include <limits>
#include <type_traits>
template<typename T>
struct FuncNull {
__device__ FuncNull(uint64_t opArg=0) {}
__device__ T operator()(const T x, const T y) const {
return 0;
}
};
template<typename T>
struct FuncSum {
__device__ FuncSum(uint64_t opArg=0) {}
__device__ T operator()(const T x, const T y) const {
return x + y;
}
};
template<typename T>
struct FuncProd {
__device__ FuncProd(uint64_t opArg=0) {}
__device__ T operator()(const T x, const T y) const {
return x * y;
}
};
template<typename T>
struct FuncMax {
__device__ FuncMax(uint64_t opArg=0) {}
__device__ T operator()(const T x, const T y) const {
return (x < y) ? y : x;
}
};
template<typename T>
struct FuncMin {
__device__ FuncMin(uint64_t opArg=0) {}
__device__ T operator()(const T x, const T y) const {
return (x < y) ? x : y;
}
};
template<typename Fn>
struct FuncTraits { // generic implementation for FuncSum,Prod,Min,Max
static constexpr bool IsPreOpIdentity = true;
static constexpr bool IsPostOpIdentity = true;
template<typename T>
__device__ static T preOp(Fn, T x) { return x; }
template<typename T>
__device__ static T postOp(Fn, T x) { return x; }
};
#define MASK0 0x00ff00ff
#define MASK1 0xff00ff00
static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) {
/* This can be used both for signed and unsigned 8-bit addition */
const uint32_t x0 = x & MASK0;
const uint32_t x1 = x & MASK1;
const uint32_t y0 = y & MASK0;
const uint32_t y1 = y & MASK1;
const uint32_t r0 = (x0+y0);
const uint32_t r1 = (x1+y1);
return (r0 & MASK0) | (r1 & MASK1);
}
template<>
struct FuncSum<int8_t> {
__device__ FuncSum(uint64_t opArg=0) {}
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
int32_t rv, z=0;
asm("vadd4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
#else
return addChar4(x, y);
#endif
}
__device__ int8_t operator()(const int8_t x, const int8_t y) const {
return x+y;
}
};
template<>
struct FuncSum<uint8_t> {
__device__ FuncSum(uint64_t opArg=0) {}
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
int32_t rv, z=0;
asm("vadd4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
#else
return addChar4(x, y);
#endif
}
__device__ uint8_t operator()(const uint8_t x, const uint8_t y) const {
return x+y;
}
};
static __device__ uint32_t mulChar4(const uint32_t x, const uint32_t y) {
/* This can be used both for signed and unsigned 8-bit multiplication */
union converter { uint32_t storage; char4 a; };
converter cx, cy, cr;
cx.storage = x;
cy.storage = y;
cr.a.x = cx.a.x * cy.a.x;
cr.a.y = cx.a.y * cy.a.y;
cr.a.z = cx.a.z * cy.a.z;
cr.a.w = cx.a.w * cy.a.w;
return cr.storage;
}
template<>
struct FuncProd<int8_t> {
__device__ FuncProd(uint64_t opArg=0) {}
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
return mulChar4(x, y);
}
__device__ int8_t operator()(const int8_t x, const int8_t y) const {
return x*y;
}
};
template<>
struct FuncProd<uint8_t> {
__device__ FuncProd(uint64_t opArg=0) {}
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
return mulChar4(x, y);
}
__device__ uint8_t operator()(const uint8_t x, const uint8_t y) const {
return x*y;
}
};
template<>
struct FuncMax<int8_t> {
__device__ FuncMax(uint64_t opArg=0) {}
union converter { uint32_t storage; char4 a; };
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
int32_t rv, z=0;
asm("vmax4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
#else
converter cx, cy, cr;
cx.storage = x;
cy.storage = y;
cr.a.x = max(cx.a.x, cy.a.x);
cr.a.y = max(cx.a.y, cy.a.y);
cr.a.z = max(cx.a.z, cy.a.z);
cr.a.w = max(cx.a.w, cy.a.w);
return cr.storage;
#endif
}
__device__ int8_t operator()(const int8_t x, const int8_t y) const {
return (x>y) ? x : y;
}
};
template<>
struct FuncMax<uint8_t> {
__device__ FuncMax(uint64_t opArg=0) {}
union converter { uint32_t storage; uchar4 a; };
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
int32_t rv, z=0;
asm("vmax4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
#else
converter cx, cy, cr;
cx.storage = x;
cy.storage = y;
cr.a.x = max(cx.a.x, cy.a.x);
cr.a.y = max(cx.a.y, cy.a.y);
cr.a.z = max(cx.a.z, cy.a.z);
cr.a.w = max(cx.a.w, cy.a.w);
return cr.storage;
#endif
}
__device__ uint8_t operator()(const uint8_t x, const uint8_t y) const {
return (x>y) ? x : y;
}
};
template<>
struct FuncMin<int8_t> {
__device__ FuncMin(uint64_t opArg=0) {}
union converter { uint32_t storage; char4 a; };
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
int32_t rv, z=0;
asm("vmin4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
#else
converter cx, cy, cr;
cx.storage = x;
cy.storage = y;
cr.a.x = min(cx.a.x, cy.a.x);
cr.a.y = min(cx.a.y, cy.a.y);
cr.a.z = min(cx.a.z, cy.a.z);
cr.a.w = min(cx.a.w, cy.a.w);
return cr.storage;
#endif
}
__device__ int8_t operator()(const int8_t x, const int8_t y) const {
return (x<y) ? x : y;
}
};
template<>
struct FuncMin<uint8_t> {
__device__ FuncMin(uint64_t opArg=0) {}
union converter { uint32_t storage; uchar4 a; };
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
int32_t rv, z=0;
asm("vmin4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
#else
converter cx, cy, cr;
cx.storage = x;
cy.storage = y;
cr.a.x = min(cx.a.x, cy.a.x);
cr.a.y = min(cx.a.y, cy.a.y);
cr.a.z = min(cx.a.z, cy.a.z);
cr.a.w = min(cx.a.w, cy.a.w);
return cr.storage;
#endif
}
__device__ uint8_t operator()(const uint8_t x, const uint8_t y) const {
return (x<y) ? x : y;
}
};
template<>
struct FuncSum<half> {
__device__ FuncSum(uint64_t opArg=0) {}
__device__ half2 operator()(const half2 x, const half2 y) const {
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
return __hadd2(x, y);
#else
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
fr.x = fx.x + fy.x;
fr.y = fx.y + fy.y;
return __float22half2_rn(fr);
#endif
}
__device__ half operator()(const half x, const half y) const {
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
return __hadd(x, y);
#else
return __float2half( __half2float(x) + __half2float(y) );
#endif
}
};
#if defined(__CUDA_BF16_TYPES_EXIST__)
template<>
struct FuncSum<__nv_bfloat16> {
__device__ FuncSum(uint64_t opArg=0) {}
__device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const {
#if __CUDA_ARCH__ >= 800
return __hadd2(x, y);
#else
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#endif
}
__device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const {
#if __CUDA_ARCH__ >= 800
return __hadd(x, y);
#else
return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
#endif
}
};
#endif
template<>
struct FuncProd<half> {
__device__ FuncProd(uint64_t opArg=0) {}
__device__ half2 operator()(const half2 x, const half2 y) const {
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
return __hmul2(x, y);
#else
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
fr.x = fx.x * fy.x;
fr.y = fx.y * fy.y;
return __float22half2_rn(fr);
#endif
}
__device__ half operator()(const half x, const half y) const {
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
return __hmul(x, y);
#else
return __float2half( __half2float(x) * __half2float(y) );
#endif
}
};
#if defined(__CUDA_BF16_TYPES_EXIST__)
template<>
struct FuncProd<__nv_bfloat16> {
__device__ FuncProd(uint64_t opArg=0) {}
__device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const {
#if __CUDA_ARCH__ >= 800
return __hmul2(x, y);
#else
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#endif
}
__device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const {
#if __CUDA_ARCH__ >= 800
return __hmul(x, y);
#else
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
#endif
}
};
#endif
template<>
struct FuncMax<half> {
__device__ FuncMax(uint64_t opArg=0) {}
__device__ half2 operator()(const half2 x, const half2 y) const {
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
fr.x = fmaxf(fx.x, fy.x);
fr.y = fmaxf(fx.y, fy.y);
return __float22half2_rn(fr);
}
__device__ half operator()(const half x, const half y) const {
float fx, fy, fm;
fx = __half2float(x);
fy = __half2float(y);
fm = fmaxf(fx, fy);
return __float2half(fm);
}
};
#if defined(__CUDA_BF16_TYPES_EXIST__)
template<>
struct FuncMax<__nv_bfloat16> {
__device__ FuncMax(uint64_t opArg=0) {}
__device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const {
#if __CUDA_ARCH__ >= 800
return __hmax2(x, y);
#else
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fmaxf(fxl, fyl), fmaxf(fxh, fyh));
#endif
}
__device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const {
#if __CUDA_ARCH__ >= 800
return __hmax(x, y);
#else
float fx, fy;
fx = __bfloat162float(x);
fy = __bfloat162float(y);
return __float2bfloat16(fmaxf(fx, fy));
#endif
}
};
#endif
template<>
struct FuncMin<half> {
__device__ FuncMin(uint64_t opArg=0) {}
__device__ half2 operator()(const half2 x, const half2 y) const {
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
fr.x = fminf(fx.x, fy.x);
fr.y = fminf(fx.y, fy.y);
return __float22half2_rn(fr);
}
__device__ half operator()(const half x, const half y) const {
float fx, fy, fm;
fx = __half2float(x);
fy = __half2float(y);
fm = fminf(fx, fy);
return __float2half(fm);
}
};
#if defined(__CUDA_BF16_TYPES_EXIST__)
template<>
struct FuncMin<__nv_bfloat16> {
__device__ FuncMin(uint64_t opArg=0) {}
__device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const {
#if __CUDA_ARCH__ >= 800
return __hmin2(x, y);
#else
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fminf(fxl, fyl), fminf(fxh, fyh));
#endif
}
__device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const {
#if __CUDA_ARCH__ >= 800
return __hmin(x, y);
#else
float fx, fy;
fx = __bfloat162float(x);
fy = __bfloat162float(y);
return __float2bfloat16(fminf(fx, fy));
#endif
}
};
#endif
template<>
struct FuncMax<float> {
__device__ FuncMax(uint64_t opArg=0) {}
__device__ float operator()(float x, float y) const {
return fmaxf(x, y);
}
};
template<>
struct FuncMin<float> {
__device__ FuncMin(uint64_t opArg=0) {}
__device__ float operator()(float x, float y) const {
return fminf(x, y);
}
};
template<>
struct FuncMax<double> {
__device__ FuncMax(uint64_t opArg=0) {}
__device__ double operator()(double x, double y) const {
return fmax(x, y);
}
};
template<>
struct FuncMin<double> {
__device__ FuncMin(uint64_t opArg=0) {}
__device__ double operator()(double x, double y) const {
return fmin(x, y);
}
};
template<typename T>
struct IsFloatingPoint: std::false_type {};
template<>
struct IsFloatingPoint<half>: std::true_type {};
#if defined(__CUDA_BF16_TYPES_EXIST__)
template<>
struct IsFloatingPoint<__nv_bfloat16>: std::true_type {};
#endif
template<>
struct IsFloatingPoint<float>: std::true_type {};
template<>
struct IsFloatingPoint<double>: std::true_type {};
template<typename T, bool IsFloating=IsFloatingPoint<T>::value>
struct FuncSumPostDiv;
template<typename T>
struct FuncSumPostDiv<T, /*IsFloating=*/false>: FuncSum<T> {
static constexpr bool IsPreOpIdentity = true;
static constexpr bool IsPostOpIdentity = false;
int n;
__device__ FuncSumPostDiv(uint64_t opArg): n(opArg) {}
// inherits FuncSum::operator()
__device__ T preOp(T x) const { return x; }
__device__ T postOp(T x) const { return T(x/n); }
};
template<typename T>
struct FuncSumPostDiv<T, /*IsFloating=*/true> {
static_assert(sizeof(T)!=sizeof(T), "FuncSumPostDiv is only for implementing ncclAvg on integral types.");
};
template<typename T>
struct FuncPreMulSum: FuncSum<T> { // integral T since all floats are specialized below
static constexpr bool IsPreOpIdentity = false;
static constexpr bool IsPostOpIdentity = true;
T scale;
__device__ FuncPreMulSum(uint64_t opArg) { scale = *(T*)&opArg; }
// inherits FuncSum::operator()
__device__ T preOp(T x) const { return x*scale; }
__device__ T postOp(T x) const { return x; }
};
template<>
struct FuncPreMulSum<double>: FuncSum<double> {
static constexpr bool IsPreOpIdentity = false;
static constexpr bool IsPostOpIdentity = true;
double scale;
__device__ FuncPreMulSum(uint64_t opArg) {
scale = *(double*)&opArg;
}
// inherits FuncSum::operator()
__device__ double preOp(double x) const {
return IsPreOpIdentity ? x : x*scale;
}
__device__ double postOp(double x) const {
return IsPostOpIdentity ? x : x*scale;
}
};
template<>
struct FuncPreMulSum<float>: FuncSum<float> {
static constexpr bool IsPreOpIdentity = false;
static constexpr bool IsPostOpIdentity = true;
float scale;
__device__ FuncPreMulSum(uint64_t opArg) {
scale = *(float*)&opArg;
}
// inherits FuncSum::operator()
__device__ float preOp(float x) const {
return IsPreOpIdentity ? x : x*scale;
}
__device__ float postOp(float x) const {
return IsPostOpIdentity ? x : x*scale;
}
};
template<>
struct FuncPreMulSum<half>: FuncSum<half> {
// Change these to switch between all prescale, all postscale, or both by sqrt(N).
// Obviously, the only invalid combination is both true. An improvement would be
// make this parameterized as a build time setting and passed here through
// preprocessor definitions.
static constexpr bool IsPreOpIdentity = false;
static constexpr bool IsPostOpIdentity = true;
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
half2 scale;
__device__ FuncPreMulSum(uint64_t opArg) {
scale.x = *(half*)&opArg;
scale.y = scale.x;
}
// inherits FuncSum::operator()
__device__ half preOp(half x) const {
return IsPreOpIdentity ? x : __hmul(x, scale.x);
}
__device__ half2 preOp(half2 x) const {
return IsPreOpIdentity ? x : __hmul2(x, scale);
}
__device__ half postOp(half x) const {
return IsPostOpIdentity ? x : __hmul(x, scale.x);
}
__device__ half2 postOp(half2 x) const {
return IsPostOpIdentity ? x : __hmul2(x, scale);
}
#else
float scale;
__device__ FuncPreMulSum(uint64_t opArg) {
scale = __half2float(*(half*)&opArg);
}
// inherits FuncSum::operator()
__device__ half preOp(half x) const {
return IsPreOpIdentity ? x : __float2half(__half2float(x)*scale);
}
__device__ half2 preOp(half2 x) const {
if (IsPreOpIdentity)
return x;
else {
float2 a = __half22float2(x);
a.x *= scale;
a.y *= scale;
return __float22half2_rn(a);
}
}
__device__ half postOp(half x) const {
return IsPostOpIdentity ? x : __float2half(__half2float(x)*scale);
}
__device__ half2 postOp(half2 x) const {
if (IsPostOpIdentity)
return x;
else {
float2 a = __half22float2(x);
a.x *= scale;
a.y *= scale;
return __float22half2_rn(a);
}
}
#endif
};
#if defined(__CUDA_BF16_TYPES_EXIST__)
template<>
struct FuncPreMulSum<__nv_bfloat16>: FuncSum<__nv_bfloat16> {
// Change these to switch between all prescale, all postscale, or both by sqrt(N).
// Obviously, the only invalid combination is both true. An improvement would be
// make this parameterized as a build time setting and passed here through
// preprocessor definitions.
static constexpr bool IsPreOpIdentity = false;
static constexpr bool IsPostOpIdentity = true;
#if __CUDA_ARCH__ >= 800
__nv_bfloat162 scale;
__device__ FuncPreMulSum(uint64_t opArg) {
scale.x = *(__nv_bfloat16*)&opArg;
scale.y = scale.x;
}
// inherits FuncSum::operator()
__device__ __nv_bfloat16 preOp(__nv_bfloat16 x) const {
return IsPreOpIdentity ? x : __hmul(x, scale.x);
}
__device__ __nv_bfloat162 preOp(__nv_bfloat162 x) const {
return IsPreOpIdentity ? x : __hmul2(x, scale);
}
__device__ __nv_bfloat16 postOp(__nv_bfloat16 x) const {
return IsPostOpIdentity ? x : __hmul(x, scale.x);
}
__device__ __nv_bfloat162 postOp(__nv_bfloat162 x) const {
return IsPostOpIdentity ? x : __hmul2(x, scale);
}
#else
float scale;
__device__ FuncPreMulSum(uint64_t opArg) {
scale = *(__nv_bfloat16*)&opArg;
}
// inherits FuncSum::operator()
__device__ __nv_bfloat16 preOp(__nv_bfloat16 x) const {
return IsPreOpIdentity ? x : __float2bfloat16(__bfloat162float(x)*scale);
}
__device__ __nv_bfloat162 preOp(__nv_bfloat162 x) const {
if (IsPreOpIdentity)
return x;
else {
float fxl, fxh;
fxl = __low2float(x);
fxh = __high2float(x);
return __floats2bfloat162_rn(fxl * scale, fxh * scale);
}
}
__device__ __nv_bfloat16 postOp(__nv_bfloat16 x) const {
return IsPostOpIdentity ? x : __float2bfloat16(__bfloat162float(x)*scale);
}
__device__ __nv_bfloat162 postOp(__nv_bfloat162 x) const {
if (IsPostOpIdentity)
return x;
else {
float fxl, fxh;
fxl = __low2float(x);
fxh = __high2float(x);
return __floats2bfloat162_rn(fxl * scale, fxh * scale);
}
}
#endif
};
#endif
template<typename T>
struct FuncTraits<FuncPreMulSum<T>> {
static constexpr bool IsPreOpIdentity = FuncPreMulSum<T>::IsPreOpIdentity;
static constexpr bool IsPostOpIdentity = FuncPreMulSum<T>::IsPostOpIdentity;
template<typename U>
__device__ static U preOp(FuncPreMulSum<T> fn, U x) {
return fn.preOp(x);
}
template<typename U>
__device__ static U postOp(FuncPreMulSum<T> fn, U x) {
return fn.postOp(x);
}
};
template<typename T>
struct FuncTraits<FuncSumPostDiv<T>> {
static constexpr bool IsPreOpIdentity = FuncSumPostDiv<T>::IsPreOpIdentity;
static constexpr bool IsPostOpIdentity = FuncSumPostDiv<T>::IsPostOpIdentity;
template<typename U>
__device__ static U preOp(FuncSumPostDiv<T> fn, U x) {
return fn.preOp(x);
}
template<typename U>
__device__ static U postOp(FuncSumPostDiv<T> fn, U x) {
return fn.postOp(x);
}
};
#endif // REDUCE_KERNEL_H_