/*************************************************************************
 * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

template <typename T, class FUNC, int NRECV, int NSEND>
class ncclLLPrimitives {
 private:
  const int tid;
  const int nthreads;
  const int wid;
  const int stepLines;
  int nrecv = 0;
  int nsend = 0;
  struct ncclConnInfo* recvConn = NULL;
  volatile uint64_t* recvConnHeadPtr = NULL;
  uint64_t recvConnHead;

  struct ncclConnInfo* sendConn = NULL;
  volatile int* sendConnFifoPtr = NULL;
  volatile uint64_t* sendConnHeadPtr = NULL;
  uint64_t sendConnHead;
  uint64_t sendConnHeadCache; // Cache last seen value

  uint64_t recvStep[NRECV];
  uint64_t sendStep[NSEND];
  union ncclLLFifoLine* recvBuff[NRECV];
  union ncclLLFifoLine* sendBuff[NSEND];
  struct ncclDevComm* comm;

  inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepLines; }
  inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepLines; }
  inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); }
  inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); }
  inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); }
  inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); }

  inline __device__ void barrier() {
    asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
  }

  uint32_t mismatch = 0;
  const uint64_t opCount;

  inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
    if (mismatch > 20) {
      // We have seen the peer advance its opCount many times while we are
      // still waiting for credit for the current op, so it is _most likely_
      // an opCount mismatch between ranks.
      // Note that LL does not use __threadfence_system, so the error cannot
      // be asserted with certainty.
      *(comm->fatalDevError) = ncclDevSuspectedMismatch;
    } else if (conn && *conn->opCountRem > opCount) {
      mismatch += 1;
    }
  }

  uint32_t spins = 0;
  uint32_t abort = 0;

  inline __device__ int checkAbort(int i, int send) {
    spins++;
    if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
      abort = *(comm->abortFlag);
      if (wid == i) checkMismatch(send ? sendConn : recvConn);
      spins = 0;
    }
    return abort;
  }

  inline __device__ void waitSend(int nbytes) {
    spins = 0;
    mismatch = 0;
    if (sendConnHeadPtr) {
      while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
        sendConnHeadCache = *sendConnHeadPtr;
        if (checkAbort(wid, 1)) break;
      }
      if (sendConnFifoPtr) {
        int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes;
        sendConnFifoPtr[sendConnHead%NCCL_STEPS] = size;
      }
      sendConnHead += 1;
    }
    barrier();
  }

  inline __device__ void incRecv(int i) {
    recvStep[i] += 1;
  }

  inline __device__ void postRecv() {
    barrier();
    if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1;
  }

  inline __device__ void incSend(int i, int offset) {
    // LL Cleanup : write all flags in the slice to make sure we don't have
    // data corruption when the flag wraps around.
    if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) {
      for (int o = offset; o < stepLines; o += nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i));
    }
    sendStep[i]++;
  }

  // Poll one 16-byte FIFO line until both flags match the expected value for
  // this step, then return the two 32-bit data words as one uint64_t.
  __device__ uint64_t readLL(int i, int offset) {
    union ncclLLFifoLine* src = recvPtr(i) + offset;
    uint32_t flag = recvFlag(i);
    uint32_t data1, flag1, data2, flag2;
    spins = 0;
    mismatch = 0;
    do {
      asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
      if (checkAbort(i, 0)) break;
    } while ((flag1 != flag) || (flag2 != flag));
    uint64_t val64 = data1 + (((uint64_t)data2) << 32);
    return val64;
  }

  // Write data and flags together in a single 16-byte volatile store.
  __device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) {
    asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag));
  }
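  // For reference, the LL protocol packs each 8 bytes of payload into a
  // 16-byte FIFO line of alternating 32-bit data and flag words, so that one
  // 16-byte volatile load/store moves data and its readiness flags together.
  // A sketch of the layout (the authoritative definition of
  // union ncclLLFifoLine lives in NCCL's device headers, not in this file):
  //
  //   union ncclLLFifoLine {
  //     struct {
  //       uint32_t data1;   // low  32 bits of the 64-bit value
  //       uint32_t flag1;   // step flag guarding data1
  //       uint32_t data2;   // high 32 bits of the 64-bit value
  //       uint32_t flag2;   // step flag guarding data2
  //     };
  //     uint64_t v[2];
  //     int4 i4;            // view used by the v4.u32 loads/stores above
  //   };
  //
  // Flags sit after their data words so a receiver never observes a valid
  // flag before the matching data has landed.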
  // Using memcpy handles misaligned pointers.
  __device__ uint64_t readAL(uint64_t* src) {
    uint64_t val;
    memcpy((char*)&val, (char*)src, sizeof(uint64_t));
    return val;
  }

  __device__ void storeAL(uint64_t* dst, uint64_t val, uint32_t nbytes) {
    memcpy((char*)dst, (char*)&val, nbytes);
  }

  template <int RECV, int SEND, int SRC, int DST>
  __device__ void LLGenericOp(const T* srcPtr, T* dstPtr, int nelem) {
    uint32_t nbytes = nelem < 0 ? 0 : nelem*sizeof(T);
    uint32_t npack = DIVUP(nbytes, sizeof(uint64_t));
    uint64_t* srcPack = (uint64_t*)srcPtr;
    uint64_t* dstPack = (uint64_t*)dstPtr;
    int offset = tid;

    // Always waitSend in case of cleanup
    if (SEND) waitSend(npack*sizeof(union ncclLLFifoLine));

    // Do multiples of 64 bits
    #pragma unroll 2
    for (; offset < npack; offset += nthreads) {
      // Recv : local, then intra-node, then inter-node
      uint64_t val = SRC ? readAL(srcPack+offset) : readLL(0, offset);
      if (RECV) {
        if (SRC) val = MULTI<FUNC, T>()(readLL(0, offset), val);
        for (int i=1; i<NRECV && i<nrecv; i++) {
          val = MULTI<FUNC, T>()(readLL(i, offset), val);
        }
      }

      // Send : inter-node, then intra-node, then local
      if (SEND) {
        for (int i=1; i<NSEND && i<nsend; i++) storeLL(sendPtr(i)+offset, val, sendFlag(i));
        storeLL(sendPtr(0)+offset, val, sendFlag(0));
      }
      if (DST) {
        if (((offset*sizeof(uint64_t)) ^ nbytes) < sizeof(uint64_t)) {
          // Last incomplete word
          storeAL(dstPack+offset, val, nbytes & 0x7);
        } else {
          storeAL(dstPack+offset, val, sizeof(uint64_t));
        }
      }
    }
    FOR_RECV(incRecv); if (RECV) postRecv();
    FOR_SEND(incSend, offset);
  }

  __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
    recvBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL];
    recvStep[i] = conn->step;
    if (wid == i) recvConn = conn;
    nrecv++;
  }

  __device__ __forceinline__ void loadRecvSync() {
    if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
      recvConnHeadPtr = recvConn->head;
      recvConnHead = recvConn->step;
      // Update opCount in case we skipped some operations
      *(recvConn->opCountLoc) = opCount;
    }
  }

  __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
    sendBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL];
    sendStep[i] = conn->step;
    if (wid == i) sendConn = conn;
    nsend++;
  }

  __device__ __forceinline__ void loadSendSync() {
    if (tid < nsend) {
      sendConnHeadPtr = sendConn->head;
      sendConnHeadCache = *sendConnHeadPtr;
      sendConnHead = sendConn->step;
      sendConnFifoPtr = sendConn->fifo;
      *(sendConn->opCountLoc) = opCount;
    }
  }

  __device__ __forceinline__ void saveRecvSync() {
    if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
      recvConn->step = recvConnHead;
      *(recvConn->opCountLoc) = opCount+1;
      __threadfence_block();
    }
  }

  __device__ __forceinline__ void saveSendSync() {
    if (tid < nsend) {
      sendConn->step = sendConnHead;
      *(sendConn->opCountLoc) = opCount+1;
      __threadfence_block();
    }
  }

 public:
  __device__ __forceinline__
  ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
    : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount) {
    // Make sure step is updated before we read it.
    barrier();

    for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i);
    for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i);
    loadRecvSync();
    loadSendSync();
  }

  __device__ void send(const T* src, int nelem) {
    return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem);
  }

  __device__ void recv(T* dst, int nelem) {
    return LLGenericOp<1, 0, 0, 1>(NULL, dst, nelem);
  }

  __device__ void recvReduceSend(const T* src, int nelem) {
    return LLGenericOp<1, 1, 1, 0>(src, NULL, nelem);
  }

  __device__ void recvReduceCopy(const T* src, T* dst, int nelem) {
    return LLGenericOp<1, 0, 1, 1>(src, dst, nelem);
  }

  __device__ void copySend(const T* src, T* dst, int nelem) {
    return LLGenericOp<0, 1, 1, 1>(src, dst, nelem);
  }

  __device__ void recvCopySend(T* dst, int nelem) {
    return LLGenericOp<1, 1, 0, 1>(NULL, dst, nelem);
  }

  __device__ void recvReduceCopySend(const T* src, T* dst, int nelem) {
    return LLGenericOp<1, 1, 1, 1>(src, dst, nelem);
  }

  __device__ __forceinline__ ~ncclLLPrimitives() {
    // Save steps for the next operation.
    saveRecvSync();
    saveSendSync();
  }
};
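// Usage sketch (illustrative, not part of this header): a ring allreduce at
// the LL protocol instantiates the class with one receive and one send peer
// and chains the primitives across the ring steps. The names below (ring,
// args, chunkOffset, nelem, thisInput, thisOutput) are hypothetical
// placeholders standing in for the collective kernel's state.
//
//   ncclLLPrimitives<T, FUNC, 1, 1>
//     LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
//
//   LLprims.send(thisInput+chunkOffset, nelem);            // step 0: inject local chunk
//   LLprims.recvReduceSend(thisInput+chunkOffset, nelem);  // middle steps: reduce and forward
//   LLprims.recvReduceCopySend(thisInput+chunkOffset,
//                              thisOutput+chunkOffset, nelem); // final reduce, start broadcast
//   LLprims.recvCopySend(thisOutput+chunkOffset, nelem);   // broadcast phase: copy and forward
//   LLprims.recv(thisOutput+chunkOffset, nelem);           // last step: receive final result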