/*************************************************************************
 * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/
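
// ncclLLPrimitives implements NCCL's LL ("low latency") protocol: payload and
// synchronization flags travel together in small FIFO lines (see readLL /
// storeLL below), so receivers poll flags instead of waiting on memory
// fences. Each buffer is cut into NCCL_STEPS slots of stepLines lines, and
// the send/recv steps advance through them under credit-based flow control.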
template <typename T, class FUNC, int NRECV, int NSEND>
class ncclLLPrimitives {
 private:
  const int tid;
  const int nthreads;
  const int wid;
  const int stepLines;
  int nrecv = 0;
  int nsend = 0;
  struct ncclConnInfo* recvConn = NULL;
  volatile uint64_t* recvConnHeadPtr = NULL;
  uint64_t recvConnHead;

  struct ncclConnInfo* sendConn = NULL;
  volatile int* sendConnFifoPtr = NULL;
  volatile uint64_t* sendConnHeadPtr = NULL;
  uint64_t sendConnHead;
  uint64_t sendConnHeadCache; // Cache last seen value

  uint64_t recvStep[NRECV];
  uint64_t sendStep[NSEND];
  union ncclLLFifoLine* recvBuff[NRECV];
  union ncclLLFifoLine* sendBuff[NSEND];
  struct ncclDevComm* comm;

  inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepLines; }
  inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepLines; }
  inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); }
  inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); }
  inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); }
  inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); }

  // Synchronize the nthreads threads running this primitive on named barrier
  // resource 1, leaving barrier 0 free for block-wide synchronization.
  inline __device__ void barrier() {
    asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
  }

  uint32_t mismatch = 0;
  const uint64_t opCount;

  inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
    if (mismatch > 20) {
      // The peer has advanced its opCount many times while we are still
      // waiting for credit on the current op, so this is _most likely_ a
      // mismatch.
      // Note that LL does not use __threadfence_system, so the error cannot
      // be asserted with certainty.
      *(comm->fatalDevError) = ncclDevSuspectedMismatch;
    } else if (conn && *conn->opCountRem > opCount) {
      mismatch += 1;
    }
  }

  uint32_t spins = 0;
  uint32_t abort = 0;

  inline __device__ int checkAbort(int i, int send) {
    spins++;
    if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
      abort = *(comm->abortFlag);
      if (wid == i) checkMismatch(send ? sendConn : recvConn);
      spins = 0;
    }
    return abort;
  }

  // Wait for send credit: the peer publishes the last step it has consumed
  // via sendConnHeadPtr, and we may run at most NCCL_STEPS steps ahead of it.
  inline __device__ void waitSend(int nbytes) {
    spins = 0;
    mismatch = 0;
    if (sendConnHeadPtr) {
      while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
        sendConnHeadCache = *sendConnHeadPtr;
        if (checkAbort(wid, 1)) break;
      }
      if (sendConnFifoPtr) {
        int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes;
        sendConnFifoPtr[sendConnHead%NCCL_STEPS] = size;
      }
      sendConnHead += 1;
    }
    barrier();
  }

  inline __device__ void incRecv(int i) {
    recvStep[i] += 1;
  }
  inline __device__ void postRecv() {
    barrier();
    if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1;
  }

  inline __device__ void incSend(int i, int offset) {
    // LL Cleanup: write all flags in the slice to make sure we don't get
    // data corruption when the flag wraps around.
    if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) {
      for (int o = offset; o<stepLines; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i));
    }
    sendStep[i]++;
  }
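
  // Each LL FIFO line packs 8 bytes of payload with two 4-byte flags
  // (data1, flag1, data2, flag2) in 16 bytes, moved with a single vectorized
  // 128-bit access. The reader spins until both flags match the value
  // expected for the current step, which is what lets LL avoid memory fences.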
  __device__ uint64_t readLL(int i, int offset) {
    union ncclLLFifoLine* src = recvPtr(i) + offset;
    uint32_t flag = recvFlag(i);
    uint32_t data1, flag1, data2, flag2;
    spins = 0;
    mismatch = 0;
    do {
      asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
      if (checkAbort(i, 0)) break;
    } while ((flag1 != flag) || (flag2 != flag));
    uint64_t val64 = data1 + (((uint64_t)data2) << 32);
    return val64;
  }

  __device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) {
    asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag));
  }

  // Using memcpy handles misaligned pointers.
  __device__ uint64_t readAL(uint64_t* src) {
    uint64_t val;
    memcpy((char*)&val, (char*)src, sizeof(uint64_t));
    return val;
  }

  __device__ void storeAL(uint64_t* dst, uint64_t val, uint32_t nbytes) {
    memcpy((char*)dst, (char*)&val, nbytes);
  }
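
  // LLGenericOp is the single work loop behind every public operation below.
  // The template flags select the data path: RECV/SEND enable the peer FIFOs,
  // SRC/DST enable the local input/output buffers. For example,
  // recvReduceCopySend is LLGenericOp<1, 1, 1, 1>: read the peers, reduce
  // with the local source, store the result locally and forward it.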
  template <int RECV, int SEND, int SRC, int DST>
  __device__ void LLGenericOp(const T* srcPtr, T* dstPtr, int nelem) {
    uint32_t nbytes = nelem < 0 ? 0 : nelem*sizeof(T);
    uint32_t npack = DIVUP(nbytes, sizeof(uint64_t));
    uint64_t* srcPack = (uint64_t*)srcPtr;
    uint64_t* dstPack = (uint64_t*)dstPtr;
    int offset = tid;

    // Always waitSend in case of cleanup
    if (SEND) waitSend(npack*sizeof(union ncclLLFifoLine));

    // Do multiples of 64 bits
    #pragma unroll 2
    for (; offset<npack; offset+=nthreads) {
      // Recv : local, then intra-node, then inter-node
      uint64_t val = SRC ? readAL(srcPack+offset) : readLL(0, offset);
      if (RECV) {
        if (SRC) val = MULTI<FUNC, T>()(readLL(0, offset), val);
        for (int i=1; i<NRECV && i<nrecv; i++) {
          val = MULTI<FUNC, T>()(readLL(i, offset), val);
        }
      }

      // Send : inter-node, then intra-node, then local
      if (SEND) {
        for (int i=1; i<NSEND && i<nsend; i++) storeLL(sendPtr(i)+offset, val, sendFlag(i));
        storeLL(sendPtr(0)+offset, val, sendFlag(0));
      }
      if (DST) {
        if (((offset*sizeof(uint64_t)) ^ nbytes) < sizeof(uint64_t)) {
          // Last incomplete word
          storeAL(dstPack+offset, val, nbytes & 0x7);
        } else {
          storeAL(dstPack+offset, val, sizeof(uint64_t));
        }
      }
    }
    FOR_RECV(incRecv); if (RECV) postRecv();
    FOR_SEND(incSend, offset);
  }
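
  // Connection bookkeeping: each connection's step is loaded once at
  // construction and written back in the destructor, so consecutive
  // operations resume where the previous one left off. Only one thread (on
  // the send side) or one lane of the last warp (on the recv side) owns each
  // connection's synchronization state.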
  __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
    recvBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL];
    recvStep[i] = conn->step;
    if (wid == i) recvConn = conn;
    nrecv++;
  }
  __device__ __forceinline__ void loadRecvSync() {
    if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
      recvConnHeadPtr = recvConn->head;
      recvConnHead = recvConn->step;
      // Update opCount in case we skipped some operations
      *(recvConn->opCountLoc) = opCount;
    }
  }

  __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
    sendBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL];
    sendStep[i] = conn->step;
    if (wid == i) sendConn = conn;
    nsend++;
  }
  __device__ __forceinline__ void loadSendSync() {
    if (tid < nsend) {
      sendConnHeadPtr = sendConn->head;
      sendConnHeadCache = *sendConnHeadPtr;
      sendConnHead = sendConn->step;
      sendConnFifoPtr = sendConn->fifo;
      *(sendConn->opCountLoc) = opCount;
    }
  }

  __device__ __forceinline__ void saveRecvSync() {
    if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
      recvConn->step = recvConnHead;
      *(recvConn->opCountLoc) = opCount+1;
      __threadfence_block();
    }
  }

  __device__ __forceinline__ void saveSendSync() {
    if (tid < nsend) {
      sendConn->step = sendConnHead;
      *(sendConn->opCountLoc) = opCount+1;
      __threadfence_block();
    }
  }

 public:
  __device__ __forceinline__
  ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
    : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount) {
    // Make sure step is updated before we read it.
    barrier();

    for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i);
    for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i);
    loadRecvSync();
    loadSendSync();
  }

  __device__ void send(const T* src, int nelem) {
    return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem);
  }

  __device__ void recv(T* dst, int nelem) {
    return LLGenericOp<1, 0, 0, 1>(NULL, dst, nelem);
  }

  __device__ void recvReduceSend(const T* src, int nelem) {
    return LLGenericOp<1, 1, 1, 0>(src, NULL, nelem);
  }

  __device__ void recvReduceCopy(const T* src, T* dst, int nelem) {
    return LLGenericOp<1, 0, 1, 1>(src, dst, nelem);
  }

  __device__ void copySend(const T* src, T* dst, int nelem) {
    return LLGenericOp<0, 1, 1, 1>(src, dst, nelem);
  }

  __device__ void recvCopySend(T* dst, int nelem) {
    return LLGenericOp<1, 1, 0, 1>(NULL, dst, nelem);
  }

  __device__ void recvReduceCopySend(const T* src, T* dst, int nelem) {
    return LLGenericOp<1, 1, 1, 1>(src, dst, nelem);
  }

  __device__ __forceinline__ ~ncclLLPrimitives() {
    // Save steps for the next operation
    saveRecvSync();
    saveSendSync();
  }
};
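
// Usage sketch (illustrative, not part of this header): a NCCL device kernel
// typically constructs one ncclLLPrimitives object per channel and drives it
// with the methods above. The ring/argument plumbing below is a hypothetical
// simplification of how a ring allreduce kernel uses these primitives:
//
//   ncclLLPrimitives<T, FUNC, 1, 1>
//     LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
//   LLprims.send(thisInput+offset, nelem);            // first step: inject local data
//   LLprims.recvReduceSend(thisInput+offset, nelem);  // middle steps: reduce and forward
//   LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem); // final reduce
//   LLprims.recvCopySend(thisOutput+offset, nelem);   // broadcast the reduced result
//   LLprims.recv(thisOutput+offset, nelem);           // last step: final copy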