2.7.8-1
Fix collective mismatch error when using ncclSend/ncclRecv
This commit is contained in:
parent
2d8601701d
commit
033d799524
@ -1,6 +1,6 @@
|
||||
##### version
|
||||
NCCL_MAJOR := 2
|
||||
NCCL_MINOR := 7
|
||||
NCCL_PATCH := 6
|
||||
NCCL_PATCH := 8
|
||||
NCCL_SUFFIX :=
|
||||
PKG_REVISION := 1
|
||||
|
@ -28,7 +28,7 @@ __device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) {
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, 1, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
@ -88,7 +88,7 @@ __device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) {
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
@ -155,7 +155,7 @@ __device__ void ncclAllGatherRingLL128Kernel(struct CollectiveArgs* args) {
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
|
@ -28,7 +28,7 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, 1, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) {
|
||||
ssize_t realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*nChannels));
|
||||
@ -108,7 +108,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
|
||||
do {
|
||||
struct ncclTree* tree = &channel->treeUp;
|
||||
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
|
||||
ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
@ -126,7 +126,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
|
||||
do {
|
||||
struct ncclTree* tree = &channel->treeDn;
|
||||
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, args->opCount);
|
||||
ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
@ -166,7 +166,7 @@ __device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
|
||||
|
||||
if (blockIdx.x < nChannels) { // first half of the channels do reduce
|
||||
struct ncclTree* tree = &channel->collTreeUp;
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
@ -183,7 +183,7 @@ __device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
|
||||
|
||||
if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
|
||||
struct ncclTree* tree = &channel->collTreeDn;
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
@ -215,7 +215,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
|
||||
const ssize_t loopSize = nChannels*nranks*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
@ -297,7 +297,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
|
||||
do {
|
||||
struct ncclTree* tree = &channel->treeUp;
|
||||
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount);
|
||||
ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
@ -315,7 +315,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
|
||||
do {
|
||||
struct ncclTree* tree = &channel->treeDn;
|
||||
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount);
|
||||
ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
@ -355,7 +355,7 @@ __device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
|
||||
|
||||
if (blockIdx.x < nChannels) { // first half of the channels do reduce
|
||||
struct ncclTree* tree = &channel->collTreeUp;
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount);
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
@ -372,7 +372,7 @@ __device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
|
||||
|
||||
if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
|
||||
struct ncclTree* tree = &channel->collTreeDn;
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount);
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
@ -406,7 +406,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
|
||||
const ssize_t loopSize = nChannels*nranks*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
@ -490,7 +490,7 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
|
||||
|
||||
if (treeUp->up == -1) {
|
||||
// ReduceAndBroadcast : max number of recv is 3, max number of send is 3
|
||||
ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, treeUp->down, treeDn->down, stepSize, channel, comm, args->opCount);
|
||||
ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, treeUp->down, treeDn->down, stepSize, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
@ -499,7 +499,7 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
|
||||
} else {
|
||||
if (tid < nthreadsSplit) {
|
||||
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreadsSplit, treeUp->down, &treeUp->up, stepSize, channel, comm, args->opCount);
|
||||
ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreadsSplit, treeUp->down, &treeUp->up, stepSize, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
@ -512,7 +512,7 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
|
||||
}
|
||||
} else {
|
||||
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &treeDn->up, treeDn->down, stepSize, channel, comm, args->opCount);
|
||||
ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &treeDn->up, treeDn->down, stepSize, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
|
@ -30,7 +30,7 @@ __device__ void ncclBroadcastRingKernel(struct CollectiveArgs* args) {
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
ncclPrimitives<UNROLL, BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS, T, 1, 1, 0, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
@ -75,7 +75,7 @@ __device__ void ncclBroadcastRingLLKernel(struct CollectiveArgs* args) {
|
||||
const int nextRank = ring->devUserRanks[1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
@ -127,7 +127,7 @@ __device__ void ncclBroadcastRingLL128Kernel(struct CollectiveArgs* args) {
|
||||
const int nextRank = ring->devUserRanks[1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
|
@ -84,18 +84,6 @@ class ncclPrimitives {
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t mismatch = 0;
|
||||
const uint64_t opCount;
|
||||
|
||||
inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
|
||||
if (mismatch) {
|
||||
// In non-LL, we use _threadfence_system before incrementing opCount, yet we are still waiting for credits here, so there must be a size mismatch
|
||||
*(comm->fatalDevError) = ncclDevAssertedMismatch;
|
||||
} else if (conn && *conn->opCountRem > opCount) {
|
||||
mismatch += 1;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t spins = 0;
|
||||
uint32_t abort = 0;
|
||||
|
||||
@ -103,7 +91,6 @@ class ncclPrimitives {
|
||||
spins++;
|
||||
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = *(comm->abortFlag);
|
||||
if (wid == i) checkMismatch(send ? sendConn : recvConn);
|
||||
spins = 0;
|
||||
}
|
||||
return abort;
|
||||
@ -111,7 +98,6 @@ class ncclPrimitives {
|
||||
|
||||
inline __device__ void waitSend(int nbytes) {
|
||||
spins = 0;
|
||||
mismatch = 0;
|
||||
if (sendConnHeadPtr) {
|
||||
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + SLICESTEPS) {
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
@ -126,7 +112,6 @@ class ncclPrimitives {
|
||||
|
||||
inline __device__ void waitRecv() {
|
||||
spins = 0;
|
||||
mismatch = 0;
|
||||
if (recvConnTailPtr) {
|
||||
while (recvConnTailCache < recvConnTail + SLICESTEPS) {
|
||||
recvConnTailCache = *recvConnTailPtr;
|
||||
@ -252,8 +237,6 @@ class ncclPrimitives {
|
||||
recvConnHeadPtr = recvConn->head;
|
||||
// Return credits in case we rounded up.
|
||||
*recvConnHeadPtr = recvConnHead;
|
||||
// Update opCount in case we skipped some operations
|
||||
*(recvConn->opCountLoc) = opCount;
|
||||
}
|
||||
}
|
||||
|
||||
@ -277,7 +260,6 @@ class ncclPrimitives {
|
||||
sendConnHeadPtr = sendConn->head;
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
sendConnFifoPtr = sendConn->fifo;
|
||||
*(sendConn->opCountLoc) = opCount;
|
||||
}
|
||||
if (tid >= nthreads && wid<nsend) {
|
||||
sendConnTailPtr = sendConn->tail;
|
||||
@ -287,7 +269,6 @@ class ncclPrimitives {
|
||||
__device__ __forceinline__ void saveRecvSync() {
|
||||
if (tid >= nthreads && wid < nrecv) {
|
||||
recvConn->step = recvConnHead;
|
||||
*(recvConn->opCountLoc) = opCount+1;
|
||||
__threadfence_system();
|
||||
}
|
||||
}
|
||||
@ -295,15 +276,14 @@ class ncclPrimitives {
|
||||
__device__ __forceinline__ void saveSendSync() {
|
||||
if (tid < nsend) {
|
||||
sendConn->step = sendConnHead;
|
||||
*(sendConn->opCountLoc) = opCount+1;
|
||||
__threadfence_system();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ __forceinline__
|
||||
ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
|
||||
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount) {
|
||||
ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm)
|
||||
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize) {
|
||||
// Make sure step is updated before we read it.
|
||||
barrier();
|
||||
|
||||
|
@ -40,19 +40,6 @@ class ncclLLPrimitives {
|
||||
asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
|
||||
}
|
||||
|
||||
uint32_t mismatch = 0;
|
||||
const uint64_t opCount;
|
||||
|
||||
inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
|
||||
if (mismatch > 20) {
|
||||
// We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most likely_ a mismatch
|
||||
// Note that we are not using _threadfence_system in LL so the error cannot be asserted
|
||||
*(comm->fatalDevError) = ncclDevSuspectedMismatch;
|
||||
} else if (conn && *conn->opCountRem > opCount) {
|
||||
mismatch += 1;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t spins = 0;
|
||||
uint32_t abort = 0;
|
||||
|
||||
@ -60,7 +47,6 @@ class ncclLLPrimitives {
|
||||
spins++;
|
||||
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = *(comm->abortFlag);
|
||||
if (wid == i) checkMismatch(send ? sendConn : recvConn);
|
||||
spins = 0;
|
||||
}
|
||||
return abort;
|
||||
@ -68,7 +54,6 @@ class ncclLLPrimitives {
|
||||
|
||||
inline __device__ void waitSend(int nbytes) {
|
||||
spins = 0;
|
||||
mismatch = 0;
|
||||
if (sendConnHeadPtr) {
|
||||
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
@ -105,7 +90,6 @@ class ncclLLPrimitives {
|
||||
uint32_t flag = recvFlag(i);
|
||||
uint32_t data1, flag1, data2, flag2;
|
||||
spins = 0;
|
||||
mismatch = 0;
|
||||
do {
|
||||
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
|
||||
if (checkAbort(i, 0)) break;
|
||||
@ -180,8 +164,6 @@ class ncclLLPrimitives {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
recvConnHeadPtr = recvConn->head;
|
||||
recvConnHead = recvConn->step;
|
||||
// Update opCount in case we skipped some operations
|
||||
*(recvConn->opCountLoc) = opCount;
|
||||
}
|
||||
}
|
||||
|
||||
@ -197,14 +179,12 @@ class ncclLLPrimitives {
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
sendConnHead = sendConn->step;
|
||||
sendConnFifoPtr = sendConn->fifo;
|
||||
*(sendConn->opCountLoc) = opCount;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void saveRecvSync() {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
recvConn->step = recvConnHead;
|
||||
*(recvConn->opCountLoc) = opCount+1;
|
||||
__threadfence_block();
|
||||
}
|
||||
}
|
||||
@ -212,15 +192,14 @@ class ncclLLPrimitives {
|
||||
__device__ __forceinline__ void saveSendSync() {
|
||||
if (tid < nsend) {
|
||||
sendConn->step = sendConnHead;
|
||||
*(sendConn->opCountLoc) = opCount+1;
|
||||
__threadfence_block();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ __forceinline__
|
||||
ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
|
||||
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount) {
|
||||
ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm)
|
||||
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines) {
|
||||
// Make sure step is updated before we read it.
|
||||
barrier();
|
||||
|
||||
|
@ -54,19 +54,6 @@ class ncclLL128Primitives {
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t mismatch = 0;
|
||||
const uint64_t opCount;
|
||||
|
||||
inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
|
||||
if (mismatch > 20) {
|
||||
// We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most likely_ a mismatch
|
||||
// Note that we are not using _threadfence_system in LL so the error cannot be asserted
|
||||
*(comm->fatalDevError) = ncclDevSuspectedMismatch;
|
||||
} else if (conn && *conn->opCountRem > opCount) {
|
||||
mismatch += 1;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t spins = 0;
|
||||
uint32_t abort = 0;
|
||||
|
||||
@ -74,7 +61,6 @@ class ncclLL128Primitives {
|
||||
spins++;
|
||||
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = *(comm->abortFlag);
|
||||
if (wid == i) checkMismatch(send ? sendConn : recvConn);
|
||||
spins = 0;
|
||||
}
|
||||
return abort;
|
||||
@ -82,7 +68,6 @@ class ncclLL128Primitives {
|
||||
|
||||
inline __device__ void waitSend(int nbytes) {
|
||||
spins = 0;
|
||||
mismatch = 0;
|
||||
if (sendConnHeadPtr) {
|
||||
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
@ -319,8 +304,6 @@ class ncclLL128Primitives {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
recvConnHeadPtr = recvConn->head;
|
||||
recvConnHead = recvConn->step;
|
||||
// Update opCount in case we skipped some operations
|
||||
*(recvConn->opCountLoc) = opCount;
|
||||
}
|
||||
}
|
||||
|
||||
@ -336,7 +319,6 @@ class ncclLL128Primitives {
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
sendConnHead = sendConn->step;
|
||||
sendConnFifoPtr = sendConn->fifo;
|
||||
*(sendConn->opCountLoc) = opCount;
|
||||
}
|
||||
if (tid >= nthreads-WARP_SIZE && wid<nsend) {
|
||||
if (sendConn->fifo) {
|
||||
@ -349,7 +331,6 @@ class ncclLL128Primitives {
|
||||
__device__ __forceinline__ void saveRecvSync() {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
recvConn->step = recvConnHead;
|
||||
*(recvConn->opCountLoc) = opCount+1;
|
||||
__threadfence_block();
|
||||
}
|
||||
}
|
||||
@ -357,15 +338,14 @@ class ncclLL128Primitives {
|
||||
__device__ __forceinline__ void saveSendSync() {
|
||||
if (tid < nsend) {
|
||||
sendConn->step = sendConnHead;
|
||||
*(sendConn->opCountLoc) = opCount+1;
|
||||
__threadfence_block();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ __forceinline__
|
||||
ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
|
||||
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), opCount(opCount), shmem(ncclShmem+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
|
||||
ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm)
|
||||
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), shmem(ncclShmem+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
|
||||
// Make sure step is updated before we read it.
|
||||
barrier();
|
||||
|
||||
|
@ -31,7 +31,7 @@ __device__ void ncclReduceRingKernel(struct CollectiveArgs* args) {
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
ncclPrimitives<UNROLL, REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS, T, 1, 1, 0, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
@ -72,7 +72,7 @@ __device__ void ncclReduceRingLLKernel(struct CollectiveArgs* args) {
|
||||
const int prevRank = ring->devUserRanks[nranks-1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
@ -121,7 +121,7 @@ __device__ void ncclReduceRingLL128Kernel(struct CollectiveArgs* args) {
|
||||
const int prevRank = ring->devUserRanks[nranks-1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
|
@ -28,7 +28,7 @@ __device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) {
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
ncclPrimitives<UNROLL, REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS, T, 1, 1, 0, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
@ -83,7 +83,7 @@ __device__ void ncclReduceScatterRingLLKernel(struct CollectiveArgs* args) {
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
@ -147,7 +147,7 @@ __device__ void ncclReduceScatterRingLL128Kernel(struct CollectiveArgs* args) {
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
|
@ -51,7 +51,7 @@ __device__ void ncclSendRecvKernel(struct CollectiveArgs* args) {
|
||||
|
||||
int peer = (comm->rank+(int)args->p2p.delta)%comm->nRanks;
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 2, 1, 1, FUNC>
|
||||
prims(tid, nthreadsSplit, peerNone, &peer, recvbuff, stepSize*4, channel, comm, args->opCount);
|
||||
prims(tid, nthreadsSplit, peerNone, &peer, recvbuff, stepSize*4, channel, comm);
|
||||
|
||||
if (sendSize == 0) {
|
||||
prims.send(sendbuff, 0);
|
||||
@ -67,7 +67,7 @@ __device__ void ncclSendRecvKernel(struct CollectiveArgs* args) {
|
||||
|
||||
int peer = (comm->rank-(int)args->p2p.delta+comm->nRanks)%comm->nRanks;
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, 2, 1, FUNC>
|
||||
prims(tid-nthreadsSplit-WARP_SIZE, nthreads-nthreadsSplit, &peer, peerNone, recvbuff, stepSize*4, channel, comm, args->opCount);
|
||||
prims(tid-nthreadsSplit-WARP_SIZE, nthreads-nthreadsSplit, &peer, peerNone, recvbuff, stepSize*4, channel, comm);
|
||||
|
||||
if (recvSize == 0) {
|
||||
prims.recv(recvbuff, 0);
|
||||
|
@ -337,7 +337,6 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclCo
|
||||
coll->args.sendbuff = info->sendbuff;
|
||||
coll->args.recvbuff = info->recvbuff;
|
||||
coll->args.comm = info->comm->devComm;
|
||||
coll->args.opCount = info->comm->opCount;
|
||||
|
||||
if (info->coll == ncclCollSendRecv) {
|
||||
coll->args.p2p.sendCount = info->sendbytes;
|
||||
|
@ -37,7 +37,6 @@ struct ncclSendMem {
|
||||
char pad1[CACHE_LINE_SIZE-sizeof(uint64_t)];
|
||||
void* ptrExchange;
|
||||
char pad2[CACHE_LINE_SIZE-sizeof(void*)];
|
||||
uint64_t opCount;
|
||||
};
|
||||
char pad3[MEM_ALIGN];
|
||||
};
|
||||
@ -49,7 +48,6 @@ struct ncclRecvMem {
|
||||
struct {
|
||||
uint64_t tail;
|
||||
char pad1[CACHE_LINE_SIZE-sizeof(uint64_t)];
|
||||
uint64_t opCount;
|
||||
char pad2[CACHE_LINE_SIZE-sizeof(uint64_t)];
|
||||
int sizesFifo[NCCL_STEPS];
|
||||
};
|
||||
@ -109,9 +107,6 @@ struct ncclComm {
|
||||
// Whether there has been a fatal error in this communicator.
|
||||
ncclResult_t fatalError;
|
||||
|
||||
// Error reported by GPU
|
||||
volatile ncclDevError_t* fatalDevError;
|
||||
|
||||
// Flag to ask NCCL kernels to abort
|
||||
volatile uint32_t *abortFlag;
|
||||
|
||||
|
@ -83,8 +83,6 @@ struct ncclConnInfo {
|
||||
char *buffs[NCCL_NUM_PROTOCOLS]; // Local for recv, remote for send
|
||||
uint64_t *tail; // Local for recv, remote for send
|
||||
uint64_t *head; // Local for send, remote for recv
|
||||
uint64_t *opCountLoc; // opCount of local rank
|
||||
uint64_t *opCountRem; // opCount of remote rank
|
||||
|
||||
int direct; // Direct communication
|
||||
void **ptrExchange; // Pointer exchange for direct communication
|
||||
@ -136,7 +134,6 @@ struct ncclDevComm;
|
||||
/* Make sure to adjust padding at the end of ncclColl. */
|
||||
struct CollectiveArgs {
|
||||
struct ncclDevComm* comm;
|
||||
uint64_t opCount;
|
||||
|
||||
// local and remote input, output, and buffer
|
||||
const void * sendbuff;
|
||||
@ -205,12 +202,6 @@ struct ncclChannel {
|
||||
};
|
||||
static_assert(sizeof(struct ncclChannel) == 0x80*sizeof(int), "ncclChannel must have a pow2 size");
|
||||
|
||||
typedef enum {
|
||||
ncclDevSuccess,
|
||||
ncclDevAssertedMismatch,
|
||||
ncclDevSuspectedMismatch
|
||||
} ncclDevError_t;
|
||||
|
||||
struct ncclDevComm {
|
||||
int rank;
|
||||
int nRanks;
|
||||
@ -218,7 +209,6 @@ struct ncclDevComm {
|
||||
|
||||
// Flag to ask NCCL kernels to abort
|
||||
volatile uint32_t *abortFlag;
|
||||
volatile ncclDevError_t *fatalDevError;
|
||||
|
||||
// Channels, device side
|
||||
struct ncclChannel* channels;
|
||||
|
30
src/init.cc
30
src/init.cc
@ -192,7 +192,6 @@ static ncclResult_t commFree(ncclComm_t comm) {
|
||||
free(comm->intraCC);
|
||||
}
|
||||
CUDACHECK(cudaFreeHost((void *)comm->abortFlag));
|
||||
CUDACHECK(cudaFreeHost((void *)comm->fatalDevError));
|
||||
|
||||
// Poison comm to try and catch a double free
|
||||
commPoison(comm);
|
||||
@ -235,10 +234,6 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
#endif
|
||||
comm->fatalError = ncclSuccess;
|
||||
|
||||
NCCLCHECK(ncclCudaHostCalloc((ncclDevError_t**)&comm->fatalDevError, 1));
|
||||
comm->hostDevComm.fatalDevError = comm->fatalDevError;
|
||||
*comm->fatalDevError = ncclDevSuccess;
|
||||
|
||||
NCCLCHECK(ncclCudaHostCalloc((uint32_t**)&comm->abortFlag, 1));
|
||||
comm->hostDevComm.abortFlag = comm->abortFlag;
|
||||
*comm->abortFlag = 0;
|
||||
@ -982,31 +977,6 @@ NCCL_API(ncclResult_t, ncclCommGetAsyncError, ncclComm_t comm, ncclResult_t *asy
|
||||
ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError) {
|
||||
NCCLCHECK(PtrCheck(comm, "ncclGetAsyncError", "comm"));
|
||||
NCCLCHECK(PtrCheck(asyncError, "ncclGetAsyncError", "asyncError"));
|
||||
|
||||
// Check device reported error
|
||||
static ncclDevError_t printedDevErr = ncclDevSuccess;
|
||||
switch(*comm->fatalDevError) {
|
||||
case ncclDevSuccess :
|
||||
break;
|
||||
case ncclDevAssertedMismatch :
|
||||
if (printedDevErr != ncclDevAssertedMismatch) {
|
||||
WARN("Mismatched collective detected, please check your collective calls at and around rank %d. You can use NCCL_DEBUG=INFO and NCCL_DEBUG_SUBSYS=COLL to see the collective logs", comm->rank);
|
||||
printedDevErr = ncclDevAssertedMismatch;
|
||||
}
|
||||
if (comm->fatalError == ncclSuccess) {
|
||||
comm->fatalError = ncclInvalidUsage;
|
||||
}
|
||||
break;
|
||||
case ncclDevSuspectedMismatch :
|
||||
if (printedDevErr != ncclDevSuspectedMismatch) {
|
||||
WARN("Your program may be hanging, this may be caused by a collective mismatch around rank %d. Please check your collective calls at and around this rank. You can use NCCL_DEBUG=INFO and NCCL_DEBUG_SUBSYS=COLL to see the collective logs", comm->rank);
|
||||
printedDevErr = ncclDevSuspectedMismatch;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
WARN("Unknown device error %d", *comm->fatalDevError);
|
||||
return ncclInternalError;
|
||||
}
|
||||
*asyncError = comm->fatalError;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
@ -139,10 +139,8 @@ ncclResult_t collNetSendConnect(struct ncclConnect* connectInfos, int nranks, in
|
||||
|
||||
// Head/Tail/Opcount/Fifos are always on host
|
||||
send->conn.tail = &resources->devHostRecvMem->tail;
|
||||
send->conn.opCountRem = &resources->devHostRecvMem->opCount;
|
||||
send->conn.fifo = resources->devHostRecvMem->sizesFifo;
|
||||
send->conn.head = &resources->devHostSendMem->head;
|
||||
send->conn.opCountLoc = &resources->devHostSendMem->opCount;
|
||||
for (int i=0; i<NCCL_STEPS; i++) send->conn.fifo[i] = -1;
|
||||
|
||||
// Get info from recv side
|
||||
@ -178,9 +176,7 @@ ncclResult_t collNetRecvConnect(struct ncclConnect* connectInfos, int nranks, in
|
||||
|
||||
// Head/Tail/Opcount are always on host
|
||||
recv->conn.tail = &resources->devHostRecvMem->tail;
|
||||
recv->conn.opCountLoc = &resources->devHostRecvMem->opCount;
|
||||
recv->conn.head = &resources->devHostSendMem->head;
|
||||
recv->conn.opCountRem = &resources->devHostSendMem->opCount;
|
||||
|
||||
// Connect to coll comm
|
||||
collNetHandle_t** handlePtrs = NULL;
|
||||
@ -258,9 +254,6 @@ ncclResult_t collNetSendProxy(struct ncclProxyArgs* args) {
|
||||
}
|
||||
struct collNetSendResources* resources = (struct collNetSendResources*) (args->connector->transportResources);
|
||||
if (args->state == ncclProxyOpReady) {
|
||||
// Update opCount
|
||||
resources->hostRecvMem->opCount = args->opCount;
|
||||
|
||||
// Round to next multiple of sliceSteps
|
||||
resources->step = ROUNDUP(resources->step, args->chunkSteps);
|
||||
args->head = resources->step;
|
||||
@ -365,9 +358,6 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
|
||||
}
|
||||
struct collNetRecvResources* resources = (struct collNetRecvResources*) (args->connector->transportResources);
|
||||
if (args->state == ncclProxyOpReady) {
|
||||
// Update opCount
|
||||
resources->hostSendMem->opCount = args->opCount;
|
||||
|
||||
// Round to next multiple of sliceSteps
|
||||
resources->step = ROUNDUP(resources->step, args->chunkSteps);
|
||||
args->head = resources->step;
|
||||
|
@ -66,10 +66,8 @@ ncclResult_t netSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* gra
|
||||
|
||||
send->conn.direct |= resources->useGdr ? NCCL_DIRECT_NIC : 0;
|
||||
send->conn.tail = &resources->recvMem->tail;
|
||||
send->conn.opCountRem = &resources->recvMem->opCount;
|
||||
send->conn.fifo = resources->recvMem->sizesFifo;
|
||||
send->conn.head = &resources->sendMem->head;
|
||||
send->conn.opCountLoc = &resources->sendMem->opCount;
|
||||
for (int i=0; i<NCCL_STEPS; i++) send->conn.fifo[i] = -1;
|
||||
|
||||
int protoLoc[NCCL_NUM_PROTOCOLS];
|
||||
@ -117,9 +115,7 @@ ncclResult_t netRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* gra
|
||||
|
||||
recv->conn.direct |= resources->useGdr ? NCCL_DIRECT_NIC : 0;
|
||||
recv->conn.tail = &resources->recvMem->tail;
|
||||
recv->conn.opCountLoc = &resources->recvMem->opCount;
|
||||
recv->conn.head = &resources->sendMem->head;
|
||||
recv->conn.opCountRem = &resources->sendMem->opCount;
|
||||
|
||||
int protoLoc[NCCL_NUM_PROTOCOLS];
|
||||
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
|
||||
@ -224,9 +220,6 @@ ncclResult_t netRecvFree(void* transportResources) {
|
||||
ncclResult_t netSendProxy(struct ncclProxyArgs* args) {
|
||||
struct netSendResources* resources = (struct netSendResources*) (args->connector->transportResources);
|
||||
if (args->state == ncclProxyOpReady) {
|
||||
// Update opCount
|
||||
resources->recvMem->opCount = args->opCount;
|
||||
|
||||
// Round to next multiple of sliceSteps
|
||||
resources->step = ROUNDUP(resources->step, args->chunkSteps);
|
||||
args->head = resources->step;
|
||||
@ -334,9 +327,6 @@ ncclResult_t netSendProxy(struct ncclProxyArgs* args) {
|
||||
ncclResult_t netRecvProxy(struct ncclProxyArgs* args) {
|
||||
struct netRecvResources* resources = (struct netRecvResources*) (args->connector->transportResources);
|
||||
if (args->state == ncclProxyOpReady) {
|
||||
// Update opCount
|
||||
resources->sendMem->opCount = args->opCount;
|
||||
|
||||
// Round to next multiple of sliceSteps
|
||||
resources->step = ROUNDUP(resources->step, args->chunkSteps);
|
||||
args->head = resources->step;
|
||||
|
@ -251,10 +251,8 @@ static ncclResult_t p2pSendConnect(struct ncclConnect* connectInfo, int nranks,
|
||||
}
|
||||
}
|
||||
send->conn.tail = &remDevMem->tail;
|
||||
send->conn.opCountRem = &remDevMem->opCount;
|
||||
send->conn.head = &resources->devMem->head;
|
||||
send->conn.ptrExchange = &resources->devMem->ptrExchange;
|
||||
send->conn.opCountLoc = &resources->devMem->opCount;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@ -291,9 +289,7 @@ ncclResult_t p2pRecvConnect(struct ncclConnect* connectInfo, int nranks, int ran
|
||||
}
|
||||
}
|
||||
recv->conn.tail = &resources->devMem->tail;
|
||||
recv->conn.opCountLoc = &resources->devMem->opCount;
|
||||
recv->conn.head = &remDevMem->head;
|
||||
recv->conn.opCountRem = &remDevMem->opCount;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
|
@ -126,10 +126,8 @@ ncclResult_t shmSendConnect(struct ncclConnect* connectInfo, int nranks, int ran
|
||||
offset += send->comm->buffSizes[p];
|
||||
}
|
||||
send->conn.tail = &resources->devRemHostMem->tail;
|
||||
send->conn.opCountRem = &resources->devRemHostMem->opCount;
|
||||
|
||||
send->conn.head = &resources->devHostMem->head;
|
||||
send->conn.opCountLoc = &resources->devHostMem->opCount;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@ -145,7 +143,6 @@ ncclResult_t shmRecvConnect(struct ncclConnect* connectInfo, int nranks, int ran
|
||||
NCCLCHECK(shmOpen(shmName, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, 0));
|
||||
NCCLCHECK(shmUnlink(shmName));
|
||||
recv->conn.head = &resources->devRemHostMem->head;
|
||||
recv->conn.opCountRem = &resources->devRemHostMem->opCount;
|
||||
|
||||
int offset = 0;
|
||||
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
|
||||
@ -153,7 +150,6 @@ ncclResult_t shmRecvConnect(struct ncclConnect* connectInfo, int nranks, int ran
|
||||
offset += recv->comm->buffSizes[p];
|
||||
}
|
||||
recv->conn.tail = &resources->devHostMem->tail;
|
||||
recv->conn.opCountLoc = &resources->devHostMem->opCount;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user