Fix collective mismatch error when using ncclSend/ncclRecv
This commit is contained in:
David Addison 2020-07-23 12:08:08 -07:00
parent 2d8601701d
commit 033d799524
18 changed files with 35 additions and 170 deletions

View File

@ -1,6 +1,6 @@
##### version
NCCL_MAJOR := 2
NCCL_MINOR := 7
NCCL_PATCH := 6
NCCL_PATCH := 8
NCCL_SUFFIX :=
PKG_REVISION := 1

View File

@ -28,7 +28,7 @@ __device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) {
T * __restrict__ thisOutput = (T*)args->recvbuff;
ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, 1, FUNC>
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
@ -88,7 +88,7 @@ __device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) {
const ssize_t loopSize = nChannels*chunkSize;
const ssize_t size = args->coll.count;
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->sendbuff;
@ -155,7 +155,7 @@ __device__ void ncclAllGatherRingLL128Kernel(struct CollectiveArgs* args) {
const ssize_t loopSize = nChannels*chunkSize;
const ssize_t size = args->coll.count;
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->sendbuff;

View File

@ -28,7 +28,7 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
T * __restrict__ thisOutput = (T*)args->recvbuff;
ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, 1, FUNC>
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) {
ssize_t realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*nChannels));
@ -108,7 +108,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeUp;
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@ -126,7 +126,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeDn;
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, args->opCount);
ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
@ -166,7 +166,7 @@ __device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
if (blockIdx.x < nChannels) { // first half of the channels do reduce
struct ncclTree* tree = &channel->collTreeUp;
ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@ -183,7 +183,7 @@ __device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
struct ncclTree* tree = &channel->collTreeDn;
ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
@ -215,7 +215,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
const ssize_t loopSize = nChannels*nranks*chunkSize;
const ssize_t size = args->coll.count;
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->sendbuff;
@ -297,7 +297,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeUp;
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount);
ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@ -315,7 +315,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeDn;
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount);
ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
@ -355,7 +355,7 @@ __device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
if (blockIdx.x < nChannels) { // first half of the channels do reduce
struct ncclTree* tree = &channel->collTreeUp;
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount);
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@ -372,7 +372,7 @@ __device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
struct ncclTree* tree = &channel->collTreeDn;
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount);
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
@ -406,7 +406,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
const ssize_t loopSize = nChannels*nranks*chunkSize;
const ssize_t size = args->coll.count;
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->sendbuff;
@ -490,7 +490,7 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
if (treeUp->up == -1) {
// ReduceAndBroadcast : max number of recv is 3, max number of send is 3
ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, treeUp->down, treeDn->down, stepSize, channel, comm, args->opCount);
ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, treeUp->down, treeDn->down, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*chunkSize;
int nelem = min(chunkSize, size-offset);
@ -499,7 +499,7 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
} else {
if (tid < nthreadsSplit) {
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreadsSplit, treeUp->down, &treeUp->up, stepSize, channel, comm, args->opCount);
ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreadsSplit, treeUp->down, &treeUp->up, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@ -512,7 +512,7 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
}
} else {
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &treeDn->up, treeDn->down, stepSize, channel, comm, args->opCount);
ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &treeDn->up, treeDn->down, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;

View File

@ -30,7 +30,7 @@ __device__ void ncclBroadcastRingKernel(struct CollectiveArgs* args) {
T * __restrict__ thisOutput = (T*)args->recvbuff;
ncclPrimitives<UNROLL, BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS, T, 1, 1, 0, FUNC>
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
@ -75,7 +75,7 @@ __device__ void ncclBroadcastRingLLKernel(struct CollectiveArgs* args) {
const int nextRank = ring->devUserRanks[1];
const int root = args->coll.root;
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->sendbuff;
@ -127,7 +127,7 @@ __device__ void ncclBroadcastRingLL128Kernel(struct CollectiveArgs* args) {
const int nextRank = ring->devUserRanks[1];
const int root = args->coll.root;
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->sendbuff;

View File

@ -84,18 +84,6 @@ class ncclPrimitives {
}
}
uint32_t mismatch = 0;
const uint64_t opCount;
inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
if (mismatch) {
// In non-LL, we use _threadfence_system before incrementing opCount, yet we are still waiting for credits here, so there must be a size mismatch
*(comm->fatalDevError) = ncclDevAssertedMismatch;
} else if (conn && *conn->opCountRem > opCount) {
mismatch += 1;
}
}
uint32_t spins = 0;
uint32_t abort = 0;
@ -103,7 +91,6 @@ class ncclPrimitives {
spins++;
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
abort = *(comm->abortFlag);
if (wid == i) checkMismatch(send ? sendConn : recvConn);
spins = 0;
}
return abort;
@ -111,7 +98,6 @@ class ncclPrimitives {
inline __device__ void waitSend(int nbytes) {
spins = 0;
mismatch = 0;
if (sendConnHeadPtr) {
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + SLICESTEPS) {
sendConnHeadCache = *sendConnHeadPtr;
@ -126,7 +112,6 @@ class ncclPrimitives {
inline __device__ void waitRecv() {
spins = 0;
mismatch = 0;
if (recvConnTailPtr) {
while (recvConnTailCache < recvConnTail + SLICESTEPS) {
recvConnTailCache = *recvConnTailPtr;
@ -252,8 +237,6 @@ class ncclPrimitives {
recvConnHeadPtr = recvConn->head;
// Return credits in case we rounded up.
*recvConnHeadPtr = recvConnHead;
// Update opCount in case we skipped some operations
*(recvConn->opCountLoc) = opCount;
}
}
@ -277,7 +260,6 @@ class ncclPrimitives {
sendConnHeadPtr = sendConn->head;
sendConnHeadCache = *sendConnHeadPtr;
sendConnFifoPtr = sendConn->fifo;
*(sendConn->opCountLoc) = opCount;
}
if (tid >= nthreads && wid<nsend) {
sendConnTailPtr = sendConn->tail;
@ -287,7 +269,6 @@ class ncclPrimitives {
__device__ __forceinline__ void saveRecvSync() {
if (tid >= nthreads && wid < nrecv) {
recvConn->step = recvConnHead;
*(recvConn->opCountLoc) = opCount+1;
__threadfence_system();
}
}
@ -295,15 +276,14 @@ class ncclPrimitives {
__device__ __forceinline__ void saveSendSync() {
if (tid < nsend) {
sendConn->step = sendConnHead;
*(sendConn->opCountLoc) = opCount+1;
__threadfence_system();
}
}
public:
__device__ __forceinline__
ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount) {
ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize) {
// Make sure step is updated before we read it.
barrier();

View File

@ -40,19 +40,6 @@ class ncclLLPrimitives {
asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
}
uint32_t mismatch = 0;
const uint64_t opCount;
inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
if (mismatch > 20) {
// We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most likely_ a mismatch
// Note that we are not using _threadfence_system in LL so the error cannot be asserted
*(comm->fatalDevError) = ncclDevSuspectedMismatch;
} else if (conn && *conn->opCountRem > opCount) {
mismatch += 1;
}
}
uint32_t spins = 0;
uint32_t abort = 0;
@ -60,7 +47,6 @@ class ncclLLPrimitives {
spins++;
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
abort = *(comm->abortFlag);
if (wid == i) checkMismatch(send ? sendConn : recvConn);
spins = 0;
}
return abort;
@ -68,7 +54,6 @@ class ncclLLPrimitives {
inline __device__ void waitSend(int nbytes) {
spins = 0;
mismatch = 0;
if (sendConnHeadPtr) {
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
sendConnHeadCache = *sendConnHeadPtr;
@ -105,7 +90,6 @@ class ncclLLPrimitives {
uint32_t flag = recvFlag(i);
uint32_t data1, flag1, data2, flag2;
spins = 0;
mismatch = 0;
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
if (checkAbort(i, 0)) break;
@ -180,8 +164,6 @@ class ncclLLPrimitives {
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
recvConnHeadPtr = recvConn->head;
recvConnHead = recvConn->step;
// Update opCount in case we skipped some operations
*(recvConn->opCountLoc) = opCount;
}
}
@ -197,14 +179,12 @@ class ncclLLPrimitives {
sendConnHeadCache = *sendConnHeadPtr;
sendConnHead = sendConn->step;
sendConnFifoPtr = sendConn->fifo;
*(sendConn->opCountLoc) = opCount;
}
}
__device__ __forceinline__ void saveRecvSync() {
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
recvConn->step = recvConnHead;
*(recvConn->opCountLoc) = opCount+1;
__threadfence_block();
}
}
@ -212,15 +192,14 @@ class ncclLLPrimitives {
__device__ __forceinline__ void saveSendSync() {
if (tid < nsend) {
sendConn->step = sendConnHead;
*(sendConn->opCountLoc) = opCount+1;
__threadfence_block();
}
}
public:
__device__ __forceinline__
ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount) {
ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines) {
// Make sure step is updated before we read it.
barrier();

View File

@ -54,19 +54,6 @@ class ncclLL128Primitives {
}
}
uint32_t mismatch = 0;
const uint64_t opCount;
inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
if (mismatch > 20) {
// We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most likely_ a mismatch
// Note that we are not using _threadfence_system in LL so the error cannot be asserted
*(comm->fatalDevError) = ncclDevSuspectedMismatch;
} else if (conn && *conn->opCountRem > opCount) {
mismatch += 1;
}
}
uint32_t spins = 0;
uint32_t abort = 0;
@ -74,7 +61,6 @@ class ncclLL128Primitives {
spins++;
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
abort = *(comm->abortFlag);
if (wid == i) checkMismatch(send ? sendConn : recvConn);
spins = 0;
}
return abort;
@ -82,7 +68,6 @@ class ncclLL128Primitives {
inline __device__ void waitSend(int nbytes) {
spins = 0;
mismatch = 0;
if (sendConnHeadPtr) {
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
sendConnHeadCache = *sendConnHeadPtr;
@ -319,8 +304,6 @@ class ncclLL128Primitives {
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
recvConnHeadPtr = recvConn->head;
recvConnHead = recvConn->step;
// Update opCount in case we skipped some operations
*(recvConn->opCountLoc) = opCount;
}
}
@ -336,7 +319,6 @@ class ncclLL128Primitives {
sendConnHeadCache = *sendConnHeadPtr;
sendConnHead = sendConn->step;
sendConnFifoPtr = sendConn->fifo;
*(sendConn->opCountLoc) = opCount;
}
if (tid >= nthreads-WARP_SIZE && wid<nsend) {
if (sendConn->fifo) {
@ -349,7 +331,6 @@ class ncclLL128Primitives {
__device__ __forceinline__ void saveRecvSync() {
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
recvConn->step = recvConnHead;
*(recvConn->opCountLoc) = opCount+1;
__threadfence_block();
}
}
@ -357,15 +338,14 @@ class ncclLL128Primitives {
__device__ __forceinline__ void saveSendSync() {
if (tid < nsend) {
sendConn->step = sendConnHead;
*(sendConn->opCountLoc) = opCount+1;
__threadfence_block();
}
}
public:
__device__ __forceinline__
ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), opCount(opCount), shmem(ncclShmem+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), shmem(ncclShmem+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
// Make sure step is updated before we read it.
barrier();

View File

@ -31,7 +31,7 @@ __device__ void ncclReduceRingKernel(struct CollectiveArgs* args) {
T * __restrict__ thisOutput = (T*)args->recvbuff;
ncclPrimitives<UNROLL, REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS, T, 1, 1, 0, FUNC>
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
@ -72,7 +72,7 @@ __device__ void ncclReduceRingLLKernel(struct CollectiveArgs* args) {
const int prevRank = ring->devUserRanks[nranks-1];
const int root = args->coll.root;
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->sendbuff;
@ -121,7 +121,7 @@ __device__ void ncclReduceRingLL128Kernel(struct CollectiveArgs* args) {
const int prevRank = ring->devUserRanks[nranks-1];
const int root = args->coll.root;
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->sendbuff;

View File

@ -28,7 +28,7 @@ __device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) {
T * __restrict__ thisOutput = (T*)args->recvbuff;
ncclPrimitives<UNROLL, REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS, T, 1, 1, 0, FUNC>
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
@ -83,7 +83,7 @@ __device__ void ncclReduceScatterRingLLKernel(struct CollectiveArgs* args) {
const ssize_t loopSize = nChannels*chunkSize;
const ssize_t size = args->coll.count;
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->sendbuff;
@ -147,7 +147,7 @@ __device__ void ncclReduceScatterRingLL128Kernel(struct CollectiveArgs* args) {
const ssize_t loopSize = nChannels*chunkSize;
const ssize_t size = args->coll.count;
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->sendbuff;

View File

@ -51,7 +51,7 @@ __device__ void ncclSendRecvKernel(struct CollectiveArgs* args) {
int peer = (comm->rank+(int)args->p2p.delta)%comm->nRanks;
ncclPrimitives<UNROLL, 1, 1, T, 2, 1, 1, FUNC>
prims(tid, nthreadsSplit, peerNone, &peer, recvbuff, stepSize*4, channel, comm, args->opCount);
prims(tid, nthreadsSplit, peerNone, &peer, recvbuff, stepSize*4, channel, comm);
if (sendSize == 0) {
prims.send(sendbuff, 0);
@ -67,7 +67,7 @@ __device__ void ncclSendRecvKernel(struct CollectiveArgs* args) {
int peer = (comm->rank-(int)args->p2p.delta+comm->nRanks)%comm->nRanks;
ncclPrimitives<UNROLL, 1, 1, T, 1, 2, 1, FUNC>
prims(tid-nthreadsSplit-WARP_SIZE, nthreads-nthreadsSplit, &peer, peerNone, recvbuff, stepSize*4, channel, comm, args->opCount);
prims(tid-nthreadsSplit-WARP_SIZE, nthreads-nthreadsSplit, &peer, peerNone, recvbuff, stepSize*4, channel, comm);
if (recvSize == 0) {
prims.recv(recvbuff, 0);

View File

@ -337,7 +337,6 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclCo
coll->args.sendbuff = info->sendbuff;
coll->args.recvbuff = info->recvbuff;
coll->args.comm = info->comm->devComm;
coll->args.opCount = info->comm->opCount;
if (info->coll == ncclCollSendRecv) {
coll->args.p2p.sendCount = info->sendbytes;

View File

@ -37,7 +37,6 @@ struct ncclSendMem {
char pad1[CACHE_LINE_SIZE-sizeof(uint64_t)];
void* ptrExchange;
char pad2[CACHE_LINE_SIZE-sizeof(void*)];
uint64_t opCount;
};
char pad3[MEM_ALIGN];
};
@ -49,7 +48,6 @@ struct ncclRecvMem {
struct {
uint64_t tail;
char pad1[CACHE_LINE_SIZE-sizeof(uint64_t)];
uint64_t opCount;
char pad2[CACHE_LINE_SIZE-sizeof(uint64_t)];
int sizesFifo[NCCL_STEPS];
};
@ -109,9 +107,6 @@ struct ncclComm {
// Whether there has been a fatal error in this communicator.
ncclResult_t fatalError;
// Error reported by GPU
volatile ncclDevError_t* fatalDevError;
// Flag to ask NCCL kernels to abort
volatile uint32_t *abortFlag;

View File

@ -83,8 +83,6 @@ struct ncclConnInfo {
char *buffs[NCCL_NUM_PROTOCOLS]; // Local for recv, remote for send
uint64_t *tail; // Local for recv, remote for send
uint64_t *head; // Local for send, remote for recv
uint64_t *opCountLoc; // opCount of local rank
uint64_t *opCountRem; // opCount of remote rank
int direct; // Direct communication
void **ptrExchange; // Pointer exchange for direct communication
@ -136,7 +134,6 @@ struct ncclDevComm;
/* Make sure to adjust padding at the end of ncclColl. */
struct CollectiveArgs {
struct ncclDevComm* comm;
uint64_t opCount;
// local and remote input, output, and buffer
const void * sendbuff;
@ -205,12 +202,6 @@ struct ncclChannel {
};
static_assert(sizeof(struct ncclChannel) == 0x80*sizeof(int), "ncclChannel must have a pow2 size");
typedef enum {
ncclDevSuccess,
ncclDevAssertedMismatch,
ncclDevSuspectedMismatch
} ncclDevError_t;
struct ncclDevComm {
int rank;
int nRanks;
@ -218,7 +209,6 @@ struct ncclDevComm {
// Flag to ask NCCL kernels to abort
volatile uint32_t *abortFlag;
volatile ncclDevError_t *fatalDevError;
// Channels, device side
struct ncclChannel* channels;

View File

@ -192,7 +192,6 @@ static ncclResult_t commFree(ncclComm_t comm) {
free(comm->intraCC);
}
CUDACHECK(cudaFreeHost((void *)comm->abortFlag));
CUDACHECK(cudaFreeHost((void *)comm->fatalDevError));
// Poison comm to try and catch a double free
commPoison(comm);
@ -235,10 +234,6 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
#endif
comm->fatalError = ncclSuccess;
NCCLCHECK(ncclCudaHostCalloc((ncclDevError_t**)&comm->fatalDevError, 1));
comm->hostDevComm.fatalDevError = comm->fatalDevError;
*comm->fatalDevError = ncclDevSuccess;
NCCLCHECK(ncclCudaHostCalloc((uint32_t**)&comm->abortFlag, 1));
comm->hostDevComm.abortFlag = comm->abortFlag;
*comm->abortFlag = 0;
@ -982,31 +977,6 @@ NCCL_API(ncclResult_t, ncclCommGetAsyncError, ncclComm_t comm, ncclResult_t *asy
ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError) {
NCCLCHECK(PtrCheck(comm, "ncclGetAsyncError", "comm"));
NCCLCHECK(PtrCheck(asyncError, "ncclGetAsyncError", "asyncError"));
// Check device reported error
static ncclDevError_t printedDevErr = ncclDevSuccess;
switch(*comm->fatalDevError) {
case ncclDevSuccess :
break;
case ncclDevAssertedMismatch :
if (printedDevErr != ncclDevAssertedMismatch) {
WARN("Mismatched collective detected, please check your collective calls at and around rank %d. You can use NCCL_DEBUG=INFO and NCCL_DEBUG_SUBSYS=COLL to see the collective logs", comm->rank);
printedDevErr = ncclDevAssertedMismatch;
}
if (comm->fatalError == ncclSuccess) {
comm->fatalError = ncclInvalidUsage;
}
break;
case ncclDevSuspectedMismatch :
if (printedDevErr != ncclDevSuspectedMismatch) {
WARN("Your program may be hanging, this may be caused by a collective mismatch around rank %d. Please check your collective calls at and around this rank. You can use NCCL_DEBUG=INFO and NCCL_DEBUG_SUBSYS=COLL to see the collective logs", comm->rank);
printedDevErr = ncclDevSuspectedMismatch;
}
break;
default:
WARN("Unknown device error %d", *comm->fatalDevError);
return ncclInternalError;
}
*asyncError = comm->fatalError;
return ncclSuccess;
}

View File

@ -139,10 +139,8 @@ ncclResult_t collNetSendConnect(struct ncclConnect* connectInfos, int nranks, in
// Head/Tail/Opcount/Fifos are always on host
send->conn.tail = &resources->devHostRecvMem->tail;
send->conn.opCountRem = &resources->devHostRecvMem->opCount;
send->conn.fifo = resources->devHostRecvMem->sizesFifo;
send->conn.head = &resources->devHostSendMem->head;
send->conn.opCountLoc = &resources->devHostSendMem->opCount;
for (int i=0; i<NCCL_STEPS; i++) send->conn.fifo[i] = -1;
// Get info from recv side
@ -178,9 +176,7 @@ ncclResult_t collNetRecvConnect(struct ncclConnect* connectInfos, int nranks, in
// Head/Tail/Opcount are always on host
recv->conn.tail = &resources->devHostRecvMem->tail;
recv->conn.opCountLoc = &resources->devHostRecvMem->opCount;
recv->conn.head = &resources->devHostSendMem->head;
recv->conn.opCountRem = &resources->devHostSendMem->opCount;
// Connect to coll comm
collNetHandle_t** handlePtrs = NULL;
@ -258,9 +254,6 @@ ncclResult_t collNetSendProxy(struct ncclProxyArgs* args) {
}
struct collNetSendResources* resources = (struct collNetSendResources*) (args->connector->transportResources);
if (args->state == ncclProxyOpReady) {
// Update opCount
resources->hostRecvMem->opCount = args->opCount;
// Round to next multiple of sliceSteps
resources->step = ROUNDUP(resources->step, args->chunkSteps);
args->head = resources->step;
@ -365,9 +358,6 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
}
struct collNetRecvResources* resources = (struct collNetRecvResources*) (args->connector->transportResources);
if (args->state == ncclProxyOpReady) {
// Update opCount
resources->hostSendMem->opCount = args->opCount;
// Round to next multiple of sliceSteps
resources->step = ROUNDUP(resources->step, args->chunkSteps);
args->head = resources->step;

View File

@ -66,10 +66,8 @@ ncclResult_t netSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* gra
send->conn.direct |= resources->useGdr ? NCCL_DIRECT_NIC : 0;
send->conn.tail = &resources->recvMem->tail;
send->conn.opCountRem = &resources->recvMem->opCount;
send->conn.fifo = resources->recvMem->sizesFifo;
send->conn.head = &resources->sendMem->head;
send->conn.opCountLoc = &resources->sendMem->opCount;
for (int i=0; i<NCCL_STEPS; i++) send->conn.fifo[i] = -1;
int protoLoc[NCCL_NUM_PROTOCOLS];
@ -117,9 +115,7 @@ ncclResult_t netRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* gra
recv->conn.direct |= resources->useGdr ? NCCL_DIRECT_NIC : 0;
recv->conn.tail = &resources->recvMem->tail;
recv->conn.opCountLoc = &resources->recvMem->opCount;
recv->conn.head = &resources->sendMem->head;
recv->conn.opCountRem = &resources->sendMem->opCount;
int protoLoc[NCCL_NUM_PROTOCOLS];
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
@ -224,9 +220,6 @@ ncclResult_t netRecvFree(void* transportResources) {
ncclResult_t netSendProxy(struct ncclProxyArgs* args) {
struct netSendResources* resources = (struct netSendResources*) (args->connector->transportResources);
if (args->state == ncclProxyOpReady) {
// Update opCount
resources->recvMem->opCount = args->opCount;
// Round to next multiple of sliceSteps
resources->step = ROUNDUP(resources->step, args->chunkSteps);
args->head = resources->step;
@ -334,9 +327,6 @@ ncclResult_t netSendProxy(struct ncclProxyArgs* args) {
ncclResult_t netRecvProxy(struct ncclProxyArgs* args) {
struct netRecvResources* resources = (struct netRecvResources*) (args->connector->transportResources);
if (args->state == ncclProxyOpReady) {
// Update opCount
resources->sendMem->opCount = args->opCount;
// Round to next multiple of sliceSteps
resources->step = ROUNDUP(resources->step, args->chunkSteps);
args->head = resources->step;

View File

@ -251,10 +251,8 @@ static ncclResult_t p2pSendConnect(struct ncclConnect* connectInfo, int nranks,
}
}
send->conn.tail = &remDevMem->tail;
send->conn.opCountRem = &remDevMem->opCount;
send->conn.head = &resources->devMem->head;
send->conn.ptrExchange = &resources->devMem->ptrExchange;
send->conn.opCountLoc = &resources->devMem->opCount;
return ncclSuccess;
}
@ -291,9 +289,7 @@ ncclResult_t p2pRecvConnect(struct ncclConnect* connectInfo, int nranks, int ran
}
}
recv->conn.tail = &resources->devMem->tail;
recv->conn.opCountLoc = &resources->devMem->opCount;
recv->conn.head = &remDevMem->head;
recv->conn.opCountRem = &remDevMem->opCount;
return ncclSuccess;
}

View File

@ -126,10 +126,8 @@ ncclResult_t shmSendConnect(struct ncclConnect* connectInfo, int nranks, int ran
offset += send->comm->buffSizes[p];
}
send->conn.tail = &resources->devRemHostMem->tail;
send->conn.opCountRem = &resources->devRemHostMem->opCount;
send->conn.head = &resources->devHostMem->head;
send->conn.opCountLoc = &resources->devHostMem->opCount;
return ncclSuccess;
}
@ -145,7 +143,6 @@ ncclResult_t shmRecvConnect(struct ncclConnect* connectInfo, int nranks, int ran
NCCLCHECK(shmOpen(shmName, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, 0));
NCCLCHECK(shmUnlink(shmName));
recv->conn.head = &resources->devRemHostMem->head;
recv->conn.opCountRem = &resources->devRemHostMem->opCount;
int offset = 0;
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
@ -153,7 +150,6 @@ ncclResult_t shmRecvConnect(struct ncclConnect* connectInfo, int nranks, int ran
offset += recv->comm->buffSizes[p];
}
recv->conn.tail = &resources->devHostMem->tail;
recv->conn.opCountLoc = &resources->devHostMem->opCount;
return ncclSuccess;
}