2.10.3-1
Add support for bfloat16. Add ncclAvg reduction operation. Improve performance for aggregated operations. Improve performance for tree. Improve network error reporting. Add NCCL_NET parameter to force a specific network. Add NCCL_IB_QPS_PER_CONNECTION parameter to split IB traffic onto multiple queue pairs. Fix topology detection error in WSL2. Fix proxy memory elements affinity (improve alltoall performance). Fix graph search on cubemesh topologies. Fix hang in cubemesh during NVB connections.
This commit is contained in:
parent
3fec2fa5ee
commit
7e51592129
@ -16,9 +16,11 @@ __hidden ncclResult_t pluginPtrSupport(int dev, int* supportedTypes) { return nc
|
||||
__hidden ncclResult_t pluginListen(int dev, void* handle, void** listenComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginConnect(int dev, void* handle, void** sendComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginAccept(void* listenComm, void** recvComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginIsend(void* sendComm, void* data, int size, int type, void** request) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginIrecv(void* recvComm, void* data, int size, int type, void** request) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginFlush(void* recvComm, void* data, int size) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginRegMr(void* collComm, void* data, int size, int type, void** mhandle) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginDeregMr(void* collComm, void* mhandle) { return ncclInternalError;}
|
||||
__hidden ncclResult_t pluginIsend(void* sendComm, void* data, int size, void* mhandle, void** request) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginIrecv(void* recvComm, void* data, int size, void* mhandle, void** request) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginFlush(void* recvComm, void* data, int size, void* mhandle) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginTest(void* request, int* done, int* size) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCloseSend(void* sendComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCloseRecv(void* recvComm) { return ncclInternalError; }
|
||||
@ -33,6 +35,8 @@ ncclNet_t NCCL_PLUGIN_SYMBOL = {
|
||||
pluginListen,
|
||||
pluginConnect,
|
||||
pluginAccept,
|
||||
pluginRegMr,
|
||||
pluginDeregMr,
|
||||
pluginIsend,
|
||||
pluginIrecv,
|
||||
pluginFlush,
|
||||
@ -41,3 +45,36 @@ ncclNet_t NCCL_PLUGIN_SYMBOL = {
|
||||
pluginCloseRecv,
|
||||
pluginCloseListen
|
||||
};
|
||||
|
||||
__hidden ncclResult_t pluginCollNetInit(ncclDebugLogger_t logFunction) { return ncclSuccess; }
|
||||
__hidden ncclResult_t pluginCollNetDevices(int* ndev) { *ndev = 0; return ncclSuccess; }
|
||||
__hidden ncclResult_t pluginCollNetPciPath(int dev, char** path) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetPtrSupport(int dev, int* supportedTypes) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetListen(int dev, void* handle, void** listenComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetConnect(void* handles[], int nranks, int rank, void* listenComm, void** collComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetReduceSupport(ncclDataType_t dataType, ncclRedOp_t redOp, int* supported) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetRegMr(void* collComm, void* data, int size, int type, void** mhandle) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetDeregMr(void* collComm, void* mhandle) { return ncclInternalError;}
|
||||
__hidden ncclResult_t pluginCollNetIallreduce(void* collComm, void* sendData, void* recvData, int count, ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetFlush(void* collComm, void* data, int size, void* mhandle) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetTest(void* request, int* done, int* size) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetCloseColl(void* collComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetCloseListen(void* listenComm) { return ncclInternalError; }
|
||||
|
||||
ncclCollNet_t NCCL_COLLNET_PLUGIN_SYMBOL = {
|
||||
"Dummy",
|
||||
pluginCollNetInit,
|
||||
pluginCollNetDevices,
|
||||
pluginCollNetPciPath,
|
||||
pluginCollNetPtrSupport,
|
||||
pluginCollNetListen,
|
||||
pluginCollNetConnect,
|
||||
pluginCollNetReduceSupport,
|
||||
pluginCollNetRegMr,
|
||||
pluginCollNetDeregMr,
|
||||
pluginCollNetIallreduce,
|
||||
pluginCollNetFlush,
|
||||
pluginCollNetTest,
|
||||
pluginCollNetCloseColl,
|
||||
pluginCollNetCloseListen
|
||||
};
|
||||
|
@ -55,7 +55,7 @@ CXXFLAGS := -DCUDA_MAJOR=$(CUDA_MAJOR) -DCUDA_MINOR=$(CUDA_MINOR) -fPIC -fvisi
|
||||
# Maxrregcount needs to be set accordingly to NCCL_MAX_NTHREADS (otherwise it will cause kernel launch errors)
|
||||
# 512 : 120, 640 : 96, 768 : 80, 1024 : 60
|
||||
# We would not have to set this if we used __launch_bounds__, but this only works on kernels, not on functions.
|
||||
NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 -Xptxas -maxrregcount=96 -Xfatbin -compress-all
|
||||
NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 --expt-extended-lambda -Xptxas -maxrregcount=96 -Xfatbin -compress-all
|
||||
# Use addprefix so that we can specify more than one path
|
||||
NVLDFLAGS := -L${CUDA_LIB} -lcudart -lrt
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
##### version
|
||||
NCCL_MAJOR := 2
|
||||
NCCL_MINOR := 9
|
||||
NCCL_PATCH := 9
|
||||
NCCL_MINOR := 10
|
||||
NCCL_PATCH := 3
|
||||
NCCL_SUFFIX :=
|
||||
PKG_REVISION := 1
|
||||
|
@ -1,3 +1,4 @@
|
||||
include/nccl.h /usr/include
|
||||
include/nccl_net.h /usr/include
|
||||
lib/libnccl.so /usr/lib/${pkg:MultiArch}
|
||||
lib/libnccl_static.a /usr/lib/${pkg:MultiArch}
|
||||
|
@ -7,7 +7,7 @@ Group: Development/Libraries
|
||||
License: BSD
|
||||
URL: http://developer.nvidia.com/nccl
|
||||
Source0: nccl_${nccl:Major}.${nccl:Minor}.${nccl:Patch}${nccl:Suffix}-${pkg:Revision}+cuda${cuda:Major}.${cuda:Minor}_${pkg:Arch}.txz
|
||||
Prereq: /sbin/ldconfig
|
||||
Requires(pre,preun): /sbin/ldconfig
|
||||
|
||||
%description
|
||||
NCCL (pronounced "Nickel") is a stand-alone library of standard collective
|
||||
@ -46,6 +46,7 @@ ln -s libnccl.so.${nccl:Major}.${nccl:Minor}.${nccl:Patch} $RPM_BUILD_ROOT/%{_li
|
||||
# devel
|
||||
install -m 755 -d $RPM_BUILD_ROOT/%{_includedir}
|
||||
install -m 644 include/nccl.h $RPM_BUILD_ROOT/%{_includedir}
|
||||
install -m 644 include/nccl_net.h $RPM_BUILD_ROOT/%{_includedir}
|
||||
ln -s libnccl.so.${nccl:Major} $RPM_BUILD_ROOT/%{_libdir}/libnccl.so
|
||||
|
||||
# static
|
||||
@ -64,6 +65,7 @@ rm -rf $RPM_BUILD_ROOT
|
||||
%doc LICENSE.txt
|
||||
%defattr(-,root,root,-)
|
||||
%{_includedir}/nccl.h
|
||||
%{_includedir}/nccl_net.h
|
||||
%{_libdir}/libnccl.so
|
||||
|
||||
%files static
|
||||
|
121
src/bootstrap.cc
121
src/bootstrap.cc
@ -43,7 +43,7 @@ ncclResult_t bootstrapNetInit() {
|
||||
}
|
||||
char line[SOCKET_NAME_MAXLEN+MAX_IF_NAME_SIZE+2];
|
||||
sprintf(line, " %s:", bootstrapNetIfName);
|
||||
socketToString(&bootstrapNetIfAddr.sa, line+strlen(line));
|
||||
socketToString(&bootstrapNetIfAddr, line+strlen(line));
|
||||
INFO(NCCL_INIT, "Bootstrap : Using%s", line);
|
||||
bootstrapNetInitDone = 1;
|
||||
}
|
||||
@ -55,27 +55,27 @@ ncclResult_t bootstrapNetInit() {
|
||||
/* Socket Interface Selection type */
|
||||
enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };
|
||||
|
||||
static ncclResult_t bootstrapNetAccept(int listenFd, int* recvFd) {
|
||||
struct sockaddr_in sockaddr;
|
||||
socklen_t socklen = sizeof(struct sockaddr_in);
|
||||
SYSCHECKVAL(accept(listenFd, (struct sockaddr*)&sockaddr, &socklen), "accept", *recvFd);
|
||||
static ncclResult_t bootstrapNetAccept(int listenFd, int* recvFd, union socketAddress *addr) {
|
||||
struct sockaddr *saddr = &addr->sa;
|
||||
socklen_t socklen = sizeof(union socketAddress);
|
||||
SYSCHECKVAL(accept(listenFd, saddr, &socklen), "accept", *recvFd);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Additional sync functions
|
||||
static ncclResult_t bootstrapNetSend(int fd, void* data, int size) {
|
||||
NCCLCHECK(socketSend(fd, &size, sizeof(int)));
|
||||
NCCLCHECK(socketSend(fd, data, size));
|
||||
static ncclResult_t bootstrapNetSend(int fd, union socketAddress *addr, void* data, int size) {
|
||||
NCCLCHECK(socketSend(fd, addr, &size, sizeof(int)));
|
||||
NCCLCHECK(socketSend(fd, addr, data, size));
|
||||
return ncclSuccess;
|
||||
}
|
||||
static ncclResult_t bootstrapNetRecv(int fd, void* data, int size) {
|
||||
static ncclResult_t bootstrapNetRecv(int fd, union socketAddress *addr, void* data, int size) {
|
||||
int recvSize;
|
||||
NCCLCHECK(socketRecv(fd, &recvSize, sizeof(int)));
|
||||
NCCLCHECK(socketRecv(fd, addr, &recvSize, sizeof(int)));
|
||||
if (recvSize > size) {
|
||||
WARN("Message truncated : received %d bytes instead of %d", recvSize, size);
|
||||
return ncclInternalError;
|
||||
}
|
||||
NCCLCHECK(socketRecv(fd, data, std::min(recvSize, size)));
|
||||
NCCLCHECK(socketRecv(fd, addr, data, std::min(recvSize, size)));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@ -111,8 +111,9 @@ static void *bootstrapRoot(void* args) {
|
||||
/* Receive addresses from all ranks */
|
||||
do {
|
||||
int tmpFd;
|
||||
NCCLCHECKGOTO(bootstrapNetAccept(listenFd, &tmpFd), res, out);
|
||||
NCCLCHECKGOTO(bootstrapNetRecv(tmpFd, &info, sizeof(info)), res, out);
|
||||
union socketAddress addr;
|
||||
NCCLCHECKGOTO(bootstrapNetAccept(listenFd, &tmpFd, &addr), res, out);
|
||||
NCCLCHECKGOTO(bootstrapNetRecv(tmpFd, &addr, &info, sizeof(info)), res, out);
|
||||
close(tmpFd);
|
||||
|
||||
if (c == 0) {
|
||||
@ -145,7 +146,7 @@ static void *bootstrapRoot(void* args) {
|
||||
int next = (r+1) % nranks;
|
||||
int tmpSendFd;
|
||||
NCCLCHECKGOTO(connectAddress(&tmpSendFd, rankAddressesRoot+r), res, out);
|
||||
NCCLCHECKGOTO(bootstrapNetSend(tmpSendFd, rankAddresses+next, sizeof(union socketAddress)), res, out);
|
||||
NCCLCHECKGOTO(bootstrapNetSend(tmpSendFd, rankAddressesRoot+r, rankAddresses+next, sizeof(union socketAddress)), res, out);
|
||||
close(tmpSendFd);
|
||||
}
|
||||
TRACE(NCCL_INIT, "SENT OUT ALL %d HANDLES", nranks);
|
||||
@ -193,6 +194,7 @@ struct unexConn {
|
||||
int peer;
|
||||
int tag;
|
||||
int fd;
|
||||
union socketAddress addr;
|
||||
struct unexConn* next;
|
||||
};
|
||||
|
||||
@ -207,6 +209,7 @@ struct extState {
|
||||
int extListenFd;
|
||||
int extRingRecvFd;
|
||||
int extRingSendFd;
|
||||
union socketAddress extRingRecvAddr, extRingSendAddr;
|
||||
union socketAddress* peerCommAddresses;
|
||||
union socketAddress* peerAllocAddresses;
|
||||
struct unexConn* unexpectedConnections;
|
||||
@ -221,9 +224,9 @@ struct extState {
|
||||
|
||||
#define MAX_SEGMENTS 128
|
||||
|
||||
static ncclResult_t remoteAlloc(void** ptr, int fd) {
|
||||
static ncclResult_t remoteAlloc(void** ptr, int fd, union socketAddress *addr) {
|
||||
size_t size;
|
||||
NCCLCHECK(socketRecv(fd, &size, sizeof(size_t)));
|
||||
NCCLCHECK(socketRecv(fd, addr, &size, sizeof(size_t)));
|
||||
cudaIpcMemHandle_t devIpc;
|
||||
NCCLCHECK(ncclCudaCalloc((char**)ptr, size));
|
||||
cudaError_t res = cudaIpcGetMemHandle(&devIpc, *ptr);
|
||||
@ -233,9 +236,9 @@ static ncclResult_t remoteAlloc(void** ptr, int fd) {
|
||||
CUDACHECK(res);
|
||||
}
|
||||
// The CUDA IPC
|
||||
NCCLCHECK(socketSend(fd, &devIpc, sizeof(cudaIpcMemHandle_t)));
|
||||
NCCLCHECK(socketSend(fd, addr, &devIpc, sizeof(cudaIpcMemHandle_t)));
|
||||
// And the direct pointer
|
||||
NCCLCHECK(socketSend(fd, ptr, sizeof(void*)));
|
||||
NCCLCHECK(socketSend(fd, addr, ptr, sizeof(void*)));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@ -267,11 +270,12 @@ void* ncclRemoteMemAllocationService(void* args) {
|
||||
}
|
||||
if (pollfds[MAX_SEGMENTS].revents) {
|
||||
int s = 0;
|
||||
union socketAddress addr;
|
||||
while (segments[s] != NULL && s < MAX_SEGMENTS) s++;
|
||||
if (bootstrapNetAccept(pollfds[MAX_SEGMENTS].fd, &pollfds[s].fd) != ncclSuccess) {
|
||||
if (bootstrapNetAccept(pollfds[MAX_SEGMENTS].fd, &pollfds[s].fd, &addr) != ncclSuccess) {
|
||||
pollfds[s].fd = -1;
|
||||
} else {
|
||||
if (s == MAX_SEGMENTS || (remoteAlloc(segments+s, pollfds[s].fd) != ncclSuccess)) {
|
||||
if (s == MAX_SEGMENTS || (remoteAlloc(segments+s, pollfds[s].fd, &addr) != ncclSuccess)) {
|
||||
WARN("[Rem Allocator] Allocation failed (segment %d, fd %d)", s, pollfds[s].fd);
|
||||
close(pollfds[s].fd);
|
||||
pollfds[s].fd = -1;
|
||||
@ -306,10 +310,11 @@ ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id,
|
||||
int fd;
|
||||
ncclResult_t res;
|
||||
*id = -1;
|
||||
NCCLCHECK(connectAddress(&fd, state->peerAllocAddresses+rank));
|
||||
NCCLCHECKGOTO(socketSend(fd, &size, sizeof(size_t)), res, end);
|
||||
NCCLCHECKGOTO(socketRecv(fd, ipc, sizeof(cudaIpcMemHandle_t)), res, end);
|
||||
NCCLCHECKGOTO(socketRecv(fd, ptr, sizeof(void*)), res, end);
|
||||
union socketAddress *addr = state->peerAllocAddresses+rank;
|
||||
NCCLCHECK(connectAddress(&fd, addr));
|
||||
NCCLCHECKGOTO(socketSend(fd, addr, &size, sizeof(size_t)), res, end);
|
||||
NCCLCHECKGOTO(socketRecv(fd, addr, ipc, sizeof(cudaIpcMemHandle_t)), res, end);
|
||||
NCCLCHECKGOTO(socketRecv(fd, addr, ptr, sizeof(void*)), res, end);
|
||||
*id = fd;
|
||||
end:
|
||||
return res;
|
||||
@ -353,19 +358,19 @@ ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commS
|
||||
// send info on my listening socket to root
|
||||
union socketAddress* rootAddr = (union socketAddress*)id;
|
||||
NCCLCHECK(connectAddress(&tmpSendFd, rootAddr));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, &info, sizeof(info)));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, rootAddr, &info, sizeof(info)));
|
||||
close(tmpSendFd);
|
||||
|
||||
// get info on my "next" rank in the bootstrap ring from root
|
||||
union socketAddress extAddressNext;
|
||||
NCCLCHECK(bootstrapNetAccept(extListenFdRoot, &tmpRecvFd));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &extAddressNext, sizeof(extAddressNext)));
|
||||
union socketAddress addr;
|
||||
NCCLCHECK(bootstrapNetAccept(extListenFdRoot, &tmpRecvFd, &addr));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, &state->extRingSendAddr, sizeof(state->extRingSendAddr)));
|
||||
close(tmpRecvFd);
|
||||
close(extListenFdRoot);
|
||||
|
||||
NCCLCHECK(connectAddress(&state->extRingSendFd, &extAddressNext));
|
||||
NCCLCHECK(connectAddress(&state->extRingSendFd, &state->extRingSendAddr));
|
||||
// Accept the connect request from the previous rank in the AllGather ring
|
||||
NCCLCHECK(bootstrapNetAccept(state->extListenFd, &state->extRingRecvFd));
|
||||
NCCLCHECK(bootstrapNetAccept(state->extListenFd, &state->extRingRecvFd, &state->extRingRecvAddr));
|
||||
|
||||
// AllGather all listen handlers
|
||||
NCCLCHECK(ncclCalloc(&state->peerCommAddresses, nranks));
|
||||
@ -403,9 +408,9 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
|
||||
size_t sslice = (rank - i + nranks) % nranks;
|
||||
|
||||
// Send slice to the right
|
||||
NCCLCHECK(bootstrapNetSend(state->extRingSendFd, data+sslice*size, size));
|
||||
NCCLCHECK(bootstrapNetSend(state->extRingSendFd, &state->extRingSendAddr, data+sslice*size, size));
|
||||
// Recv slice from the left
|
||||
NCCLCHECK(bootstrapNetRecv(state->extRingRecvFd, data+rslice*size, size));
|
||||
NCCLCHECK(bootstrapNetRecv(state->extRingRecvFd, &state->extRingRecvAddr, data+rslice*size, size));
|
||||
}
|
||||
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
|
||||
@ -415,21 +420,44 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
|
||||
ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size) {
|
||||
struct extState* state = (struct extState*)commState;
|
||||
int tmpSendFd;
|
||||
NCCLCHECK(connectAddress(&tmpSendFd, state->peerCommAddresses+peer));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, &state->rank, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, &tag, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, data, size));
|
||||
union socketAddress *addr = state->peerCommAddresses+peer;
|
||||
NCCLCHECK(connectAddress(&tmpSendFd, addr));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, addr, &state->rank, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, addr, &tag, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, addr, data, size));
|
||||
close(tmpSendFd);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int tag, int fd) {
|
||||
ncclResult_t bootstrapBarrier(void* commState, int *ranks, int tag, int rank, int nranks) {
|
||||
if (nranks == 1) return ncclSuccess;
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d tag %x - ENTER", rank, nranks, tag);
|
||||
|
||||
/* Simple intra process barrier
|
||||
*
|
||||
* Based on the dissemination algorithm by Debra Hensgen, Raphael Finkel, and Udi Manbet,
|
||||
* "Two Algorithms for Barrier Synchronization," International Journal of Parallel Programming, 17(1):1-17, 1988"
|
||||
*/
|
||||
int data[1];
|
||||
for (int mask=1; mask<nranks; mask<<=1) {
|
||||
int src = (rank - mask + nranks) % nranks;
|
||||
int dst = (rank + mask) % nranks;
|
||||
NCCLCHECK(bootstrapSend(commState, ranks[dst], tag, data, sizeof(data)));
|
||||
NCCLCHECK(bootstrapRecv(commState, ranks[src], tag, data, sizeof(data)));
|
||||
}
|
||||
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d tag %x - DONE", rank, nranks, tag);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int tag, int fd, union socketAddress *addr) {
|
||||
// New unex
|
||||
struct unexConn* unex;
|
||||
NCCLCHECK(ncclCalloc(&unex, 1));
|
||||
unex->peer = peer;
|
||||
unex->tag = tag;
|
||||
unex->fd = fd;
|
||||
unex->addr = *addr;
|
||||
|
||||
// Enqueue
|
||||
struct unexConn* list = state->unexpectedConnections;
|
||||
@ -442,7 +470,7 @@ ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int tag, int fd
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
int unexpectedDequeue(struct extState* state, int peer, int tag) {
|
||||
int unexpectedDequeue(struct extState* state, int peer, int tag, union socketAddress *addr) {
|
||||
struct unexConn* elem = state->unexpectedConnections;
|
||||
struct unexConn* prev = NULL;
|
||||
while (elem) {
|
||||
@ -453,6 +481,7 @@ int unexpectedDequeue(struct extState* state, int peer, int tag) {
|
||||
prev->next = elem->next;
|
||||
}
|
||||
int fd = elem->fd;
|
||||
*addr = elem->addr;
|
||||
free(elem);
|
||||
return fd;
|
||||
}
|
||||
@ -467,27 +496,29 @@ ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int s
|
||||
struct extState* state = (struct extState*)commState;
|
||||
|
||||
int tmpRecvFd;
|
||||
union socketAddress addr;
|
||||
|
||||
// Search unexpected connections first
|
||||
if ((tmpRecvFd = unexpectedDequeue(state, peer, tag)) != -1) {
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size));
|
||||
if ((tmpRecvFd = unexpectedDequeue(state, peer, tag, &addr)) != -1) {
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, ((char*)data), size));
|
||||
close(tmpRecvFd);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Then look for new connections
|
||||
while (1) {
|
||||
NCCLCHECK(bootstrapNetAccept(state->extListenFd, &tmpRecvFd));
|
||||
union socketAddress addr;
|
||||
NCCLCHECK(bootstrapNetAccept(state->extListenFd, &tmpRecvFd, &addr));
|
||||
int newPeer, newTag;
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &newPeer, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &newTag, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, &newPeer, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, &newTag, sizeof(int)));
|
||||
if (newPeer == peer && newTag == tag) {
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, ((char*)data), size));
|
||||
close(tmpRecvFd);
|
||||
return ncclSuccess;
|
||||
}
|
||||
// Unexpected connection. Save for later.
|
||||
NCCLCHECK(unexpectedEnqueue(state, newPeer, newTag, tmpRecvFd));
|
||||
NCCLCHECK(unexpectedEnqueue(state, newPeer, newTag, tmpRecvFd, &addr));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# See LICENSE.txt for license information
|
||||
#
|
||||
@ -32,7 +32,7 @@ all_deps: $(DEPENDFILES)
|
||||
$(RULESFILE) :
|
||||
@printf "Generating %-35s > %s\n" rules $@
|
||||
@mkdir -p $(OBJDIR)
|
||||
@./gen_rules.sh $(OBJDIR) > $@
|
||||
@CUDA_MAJOR=${CUDA_MAJOR} CUDA_MINOR=${CUDA_MINOR} ./gen_rules.sh $(OBJDIR) > $@
|
||||
|
||||
-include $(RULESFILE)
|
||||
|
||||
|
@ -5,204 +5,95 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllGather, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads-WARP_SIZE;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
const int chunkSize = stepSize * ALLGATHER_CHUNKSTEPS;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
ncclRing *ring = &ncclShmem.channel.ring;
|
||||
const int *ringRanks = ring->devUserRanks;
|
||||
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLGATHER_CHUNKSTEPS : 1));
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2);
|
||||
const int nranks = ncclShmem.comm.nRanks;
|
||||
const ssize_t loopSize = nChannels*int(chunkSize);
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
T *inputBuf = (T*)args->sendbuff;
|
||||
T *outputBuf = (T*)args->recvbuff;
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 1, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf);
|
||||
|
||||
ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, 1, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
ssize_t chunkOffset = gridOffset + bid*realChunkSize;
|
||||
|
||||
/////////////// begin AllGather steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(realChunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
if (thisInput + chunkOffset == thisOutput + offset) { // In place
|
||||
prims.directSend(thisInput+chunkOffset, offset, nelem);
|
||||
} else {
|
||||
prims.directCopySend(thisInput+chunkOffset, thisOutput+offset, offset, nelem);
|
||||
}
|
||||
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
prims.directRecvCopySend(thisOutput+offset, offset, nelem);
|
||||
}
|
||||
|
||||
// Make final copy from buffer to dest.
|
||||
rankDest = ring->devUserRanks[1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
// Final wait/copy.
|
||||
prims.directRecv(thisOutput+offset, offset, nelem);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset,nChannels));
|
||||
realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
}
|
||||
};
|
||||
else if (Proto::Id == NCCL_PROTO_LL)
|
||||
realChunkSize = size-gridOffset < loopSize ? args->coll.lastChunkSize : chunkSize;
|
||||
else if (Proto::Id == NCCL_PROTO_LL128)
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128);
|
||||
realChunkSize = int(realChunkSize);
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllGather, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
ssize_t chunkOffset = gridOffset + int(bid*realChunkSize);
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
/////////////// begin AllGather steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(realChunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ringRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
if (size-gridOffset < loopSize) {
|
||||
chunkSize = args->coll.lastChunkSize;
|
||||
}
|
||||
ssize_t chunkOffset = gridOffset + bid*chunkSize;
|
||||
|
||||
/////////////// begin AllGather steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(chunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
if (thisInput + chunkOffset == thisOutput + offset) { // In place
|
||||
LLprims.send(thisInput+chunkOffset, nelem);
|
||||
} else {
|
||||
LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem);
|
||||
}
|
||||
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvCopySend(thisOutput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: final store
|
||||
rankDest = ring->devUserRanks[1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recv(thisOutput+offset, nelem);
|
||||
if (inputBuf + chunkOffset == outputBuf + offset) { // In place
|
||||
prims.directSend(chunkOffset, offset, nelem);
|
||||
} else {
|
||||
prims.directCopySend(chunkOffset, offset, offset, nelem);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#include "prims_ll128.h"
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllGather, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
|
||||
|
||||
ssize_t chunkOffset = gridOffset + bid*chunkSize;
|
||||
|
||||
/////////////// begin AllGather steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(chunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[0];
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
rankDest = ringRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
if (thisInput + chunkOffset == thisOutput + offset) { // In place
|
||||
LLprims.send(thisInput+chunkOffset, nelem);
|
||||
} else {
|
||||
LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem);
|
||||
}
|
||||
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvCopySend(thisOutput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: final store
|
||||
rankDest = ring->devUserRanks[1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recv(thisOutput+offset, nelem);
|
||||
prims.directRecvCopySend(offset, offset, nelem);
|
||||
}
|
||||
|
||||
// Make final copy from buffer to dest.
|
||||
rankDest = ringRanks[1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
// Final wait/copy.
|
||||
prims.directRecv(offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllGather, NCCL_ALGO_TREE, PROTO, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllGather, NCCL_ALGO_COLLNET, PROTO, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -5,566 +5,384 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads-WARP_SIZE;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
|
||||
ncclRing *ring = &ncclShmem.channel.ring;
|
||||
int ringIx = ring->index;
|
||||
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLREDUCE_CHUNKSTEPS : 1));
|
||||
const int nranks = ncclShmem.comm.nRanks;
|
||||
const ssize_t loopSize = nChannels*nranks*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
int minChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_LL)
|
||||
minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T));
|
||||
if (Proto::Id == NCCL_PROTO_LL128) {
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2;
|
||||
}
|
||||
|
||||
ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, 1, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 1, Proto> prims
|
||||
(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) {
|
||||
ssize_t realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*nChannels));
|
||||
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
ssize_t chunkOffset = gridOffset + bid*nranks*realChunkSize;
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*nranks));
|
||||
realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
else
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize);
|
||||
realChunkSize = int(realChunkSize);
|
||||
|
||||
auto calcOffset = [&]__device__(int chunk)->ssize_t {
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE)
|
||||
return gridOffset + bid*nranks*realChunkSize + chunk*realChunkSize;
|
||||
else
|
||||
return gridOffset + (chunk*nChannels + bid)*realChunkSize;
|
||||
};
|
||||
auto modRanks = [&]__device__(int r)->int {
|
||||
return r - (r >= nranks ? nranks : 0);
|
||||
};
|
||||
|
||||
/////////////// begin AllReduce steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem;
|
||||
int chunk;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
chunk = ring->devUserRanks[nranks-1];
|
||||
offset = chunkOffset + chunk * realChunkSize;
|
||||
chunk = modRanks(ringIx + nranks-1);
|
||||
offset = calcOffset(chunk);
|
||||
nelem = min(realChunkSize, size-offset);
|
||||
|
||||
prims.send(thisInput+offset, nelem);
|
||||
prims.send(offset, nelem);
|
||||
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
chunk = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + chunk * realChunkSize;
|
||||
chunk = modRanks(ringIx + nranks-j);
|
||||
offset = calcOffset(chunk);
|
||||
nelem = min(realChunkSize, size-offset);
|
||||
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final
|
||||
// result that we store in this data and push to the next GPU
|
||||
chunk = ring->devUserRanks[0];
|
||||
offset = chunkOffset + chunk * realChunkSize;
|
||||
chunk = ringIx + 0;
|
||||
offset = calcOffset(chunk);
|
||||
nelem = min(realChunkSize, size-offset);
|
||||
|
||||
prims.directRecvReduceCopySend(thisInput+offset, thisOutput+offset, offset, nelem);
|
||||
prims.directRecvReduceCopySend(offset, offset, offset, nelem, /*postOp=*/true);
|
||||
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
chunk = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + chunk * realChunkSize;
|
||||
chunk = modRanks(ringIx + nranks-j);
|
||||
offset = calcOffset(chunk);
|
||||
nelem = min(realChunkSize, size-offset);
|
||||
|
||||
prims.directRecvCopySend(thisOutput+offset, offset, nelem);
|
||||
prims.directRecvCopySend(offset, offset, nelem);
|
||||
}
|
||||
|
||||
// Make final copy from buffer to dest.
|
||||
chunk = ring->devUserRanks[1];
|
||||
offset = chunkOffset + chunk * realChunkSize;
|
||||
chunk = modRanks(ringIx + 1);
|
||||
offset = calcOffset(chunk);
|
||||
nelem = min(realChunkSize, size-offset);
|
||||
|
||||
// Final wait/copy.
|
||||
prims.directRecv(thisOutput+offset, offset, nelem);
|
||||
prims.directRecv(offset, nelem);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runTreeUpDown(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads-2*WARP_SIZE;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclTree* tree = &channel->tree;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
int chunkSize = args->coll.lastChunkSize;
|
||||
const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
ncclTree *tree = &ncclShmem.channel.tree;
|
||||
ssize_t chunkSize = int(
|
||||
Proto::Id == NCCL_PROTO_SIMPLE ? args->coll.lastChunkSize
|
||||
/* LL & LL128 */ : Proto::calcBytePerStep()/sizeof(T));
|
||||
const ssize_t minChunkSize = int(
|
||||
Proto::Id == NCCL_PROTO_SIMPLE ? (nthreads-2*WARP_SIZE)*8*(sizeof(uint64_t)/sizeof(T))
|
||||
/* LL & LL128 */ : nthreads*(Proto::calcBytePerGrain()/sizeof(T)));
|
||||
const ssize_t loopSize = int(nChannels*chunkSize);
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
if (loopSize > size) {
|
||||
chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
|
||||
}
|
||||
if (loopSize > size)
|
||||
chunkSize = divUp((int)size, int(nChannels*minChunkSize))*int(minChunkSize);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
#if 1
|
||||
if (tid < nthreads+WARP_SIZE) {
|
||||
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DEV_ARITY, 1, 0, FUNC>
|
||||
prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->up == -1) {
|
||||
prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
|
||||
} else if (tree->down[0] == -1) {
|
||||
prims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
{ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto> prims
|
||||
(tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff);
|
||||
if (tree->up == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.recvReduceCopy(offset, offset, nelem, /*postOp=*/true);
|
||||
}
|
||||
}
|
||||
else if (tree->down[0] == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.send(offset, nelem);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tid < nthreads+WARP_SIZE) {
|
||||
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_DEV_ARITY, 1, FUNC>
|
||||
prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->up == -1) {
|
||||
prims.directSend(thisOutput+offset, offset, nelem);
|
||||
} else if (tree->down[0] == -1) {
|
||||
prims.directRecv(thisOutput+offset, offset, nelem);
|
||||
} else {
|
||||
prims.directRecvCopySend(thisOutput+offset, offset, nelem);
|
||||
{ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/1, Proto> prims
|
||||
(tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff);
|
||||
if (tree->up == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directSendFromOutput(offset, offset, nelem);
|
||||
}
|
||||
}
|
||||
else if (tree->down[0] == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecv(offset, nelem);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecvCopySend(offset, offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
int nthreadsSplit = nthreads/2;
|
||||
if (nthreadsSplit >= 256) nthreadsSplit += 64;
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runTreeSplit(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
ncclTree *tree = &ncclShmem.channel.tree;
|
||||
ssize_t chunkSize = int(
|
||||
Proto::Id != NCCL_PROTO_LL ? args->coll.lastChunkSize
|
||||
: Proto::calcBytePerStep()/sizeof(T));
|
||||
const ssize_t minChunkSize = int(
|
||||
Proto::Id == NCCL_PROTO_SIMPLE ? (nthreads - 2*WARP_SIZE)*8*(sizeof(uint64_t)/sizeof(T)) :
|
||||
Proto::Id == NCCL_PROTO_LL ? nthreads*(Proto::calcBytePerGrain()/sizeof(T))
|
||||
/* LL128 */ : nthreads*(Proto::calcBytePerGrain()/sizeof(T))/8);
|
||||
const ssize_t loopSize = int(nChannels*chunkSize);
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
int nthreadsSplit;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
nthreadsSplit = nthreads/2;
|
||||
if (nthreadsSplit >= 256) nthreadsSplit += 64;
|
||||
} else { // LL & LL128
|
||||
// Receiving from up to 3 sources is more compute intensive than sending
|
||||
// to 3 dests. Use 70% for reduce and 30% for bcast.
|
||||
nthreadsSplit = (nthreads*7/(10*WARP_SIZE))*WARP_SIZE;
|
||||
}
|
||||
|
||||
if (loopSize > size)
|
||||
chunkSize = divUp((int)size, nChannels*int(minChunkSize))*int(minChunkSize);
|
||||
|
||||
if (tree->up == -1) {
|
||||
if (tid < nthreads+WARP_SIZE) {
|
||||
// ReduceAndBroadcast : max number of recv is 3, max number of send is 3
|
||||
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DEV_ARITY, NCCL_MAX_DEV_ARITY, 1, FUNC>
|
||||
prims(tid, nthreads, tree->down, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
// Reduce and broadcast. Max number of recv is 3, max number of send is 3
|
||||
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/1, Proto>
|
||||
prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecvReduceCopySend(offset, offset, offset, nelem, /*doPost=*/true);
|
||||
}
|
||||
}
|
||||
else if (tid < nthreadsSplit) {
|
||||
/* Reduce up. Max number of recv is 3, max number of send is 1 (binary tree + local).
|
||||
* Why Direct=1????
|
||||
* Answer: Because despite not performing any direct operations, the ctor
|
||||
* must assume Direct so that it can exchange direct pointers with remote ctors
|
||||
* that are Direct, otherwise it hangs. A cleaner solution would be to seperate
|
||||
* into DirectRecv and DirectSend capabilities, this ctor would have both=0,
|
||||
* but the ctor above for tree roots would be DirectRecv=0 DirectSend=1.
|
||||
*/
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/1, Proto>
|
||||
prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, 0*Proto::MaxGroupWidth);
|
||||
if (tree->down[0] == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecvReduceCopySend(thisInput+offset, thisOutput+offset, offset, nelem);
|
||||
prims.send(offset, nelem);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (tid < nthreadsSplit + WARP_SIZE) {
|
||||
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DEV_ARITY, 1, 0, FUNC>
|
||||
prims(tid, nthreadsSplit, tree->down, &tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
else {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->down[0] == -1) {
|
||||
prims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_DEV_ARITY, 1, FUNC>
|
||||
prims(tid-nthreadsSplit-WARP_SIZE, nthreads-nthreadsSplit, &tree->up, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 2);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->down[0] == -1) {
|
||||
prims.directRecv(thisOutput+offset, offset, nelem);
|
||||
} else {
|
||||
prims.directRecvCopySend(thisOutput+offset, offset, nelem);
|
||||
}
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
// Broadcast down. Max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/1, Proto>
|
||||
prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth);
|
||||
if (tree->down[0] == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecv(offset, nelem);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecvCopySend(offset, offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_COLLNET, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
#define COLLNET_COPY_THREADS 96
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
#if CUDART_VERSION >= 11020 && CUDART_VERSION < 11040 && __CUDA_ARCH__ >= 800
|
||||
runTreeUpDown<T, RedOp, ProtoSimple<1, 1>>(args);
|
||||
#else
|
||||
runTreeSplit<T, RedOp, ProtoSimple<1, 1>>(args);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
static constexpr int COLLNET_COPY_THREADS = 96;
|
||||
const int tid = threadIdx.x;
|
||||
//const int nthreads = args->nThreads-3*WARP_SIZE;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclDirect* tree = &channel->collTree;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
int chunkSize = args->coll.lastChunkSize;
|
||||
struct ncclDirect* tree = &ncclShmem.channel.collTree;
|
||||
const ssize_t chunkSize = int(args->coll.lastChunkSize);
|
||||
const ssize_t size = args->coll.count;
|
||||
const ssize_t loopSize = nChannels*tree->nHeads*chunkSize;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
const int hasUp = (tree->up[0] >= 0) ? 1 : 0;
|
||||
const int hasDn = (tree->down[0] >= 0) ? 1 : 0;
|
||||
const int nThreadsScatter = (hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 3*COLLNET_COPY_THREADS : 0;
|
||||
const int nThreadsGather = (hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 2*COLLNET_COPY_THREADS : 0;
|
||||
const int nThreadsBcast = (hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 0 : 2*COLLNET_COPY_THREADS;
|
||||
// Gather does not need sync threads, sparing one more warp for reduce
|
||||
const int nThreadsReduce = NCCL_SIMPLE_MAX_NTHREADS + WARP_SIZE - nThreadsScatter - nThreadsGather - nThreadsBcast;
|
||||
const int nThreadsScatter = WARP_SIZE + ((hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 3*COLLNET_COPY_THREADS : 0);
|
||||
const int nThreadsGather = ((hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 2*COLLNET_COPY_THREADS : 0);
|
||||
const int nThreadsBcast = WARP_SIZE + ((hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 0 : 2*COLLNET_COPY_THREADS);
|
||||
const int nThreadsReduce = args->nThreads - nThreadsScatter - nThreadsGather - nThreadsBcast;
|
||||
const int tidStartBcast = nThreadsGather;
|
||||
const int tidStartScatter = tidStartBcast + nThreadsBcast + WARP_SIZE;
|
||||
const int tidStartReduce = tidStartScatter + nThreadsScatter + WARP_SIZE;
|
||||
const int tidStartScatter = tidStartBcast + nThreadsBcast;
|
||||
const int tidStartReduce = tidStartScatter + nThreadsScatter;
|
||||
|
||||
using Proto = ProtoSimple<1, 1>;
|
||||
|
||||
if (tid >= tidStartScatter && tid < tidStartReduce && hasUp) {
|
||||
// Scatter
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 0, NCCL_MAX_DIRECT_ARITY, 0, FUNC>
|
||||
prims(tid-tidStartScatter, nThreadsScatter, NULL, tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 4);
|
||||
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartScatter, nThreadsScatter, NULL, tree->up, args->sendbuff, args->recvbuff, 2*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*tree->nHeads*chunkSize;
|
||||
int nelem = min(tree->nHeads*chunkSize, size-offset);
|
||||
prims.scatter(thisInput+offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
prims.scatter(offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
}
|
||||
} else if (tid >= tidStartReduce && tree->out != -1) {
|
||||
// Reduce, send to network
|
||||
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DIRECT_ARITY, 1, 0, FUNC>
|
||||
prims(tid-tidStartReduce, nThreadsReduce, tree->down, &tree->out, NULL, stepSize, channel, comm, ncclShmem->ptrs, 6);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (hasDn) {
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
} else {
|
||||
prims.send(thisInput+offset, nelem);
|
||||
if (hasDn) {
|
||||
// Reduce, send to network
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 1>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartReduce, nThreadsReduce, tree->down, &tree->out, args->sendbuff, args->recvbuff, 3*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
} else {
|
||||
// Directly send to network
|
||||
Primitives<T, RedOp, FanAsymmetric<0, 1>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartReduce, nThreadsReduce, nullptr, &tree->out, args->sendbuff, args->recvbuff, 3*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.send(offset, nelem);
|
||||
}
|
||||
}
|
||||
} else if (tid < tidStartBcast && hasUp) {
|
||||
// Gather
|
||||
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DIRECT_ARITY, 0, 0, FUNC>
|
||||
prims(tid, nThreadsGather, tree->up, NULL, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 0>, /*Direct=*/0, Proto>
|
||||
prims(tid, nThreadsGather, tree->up, NULL, args->sendbuff, args->recvbuff, 0*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*tree->nHeads*chunkSize;
|
||||
int nelem = min(tree->nHeads*chunkSize, size-offset);
|
||||
prims.gather(thisOutput+offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
prims.gather(offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
}
|
||||
} else if (tid >= tidStartBcast && tid < tidStartScatter && tree->out != -1) {
|
||||
// Recv from network, broadcast
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_DIRECT_ARITY, 0, FUNC>
|
||||
prims(tid-tidStartBcast, nThreadsBcast, &tree->out, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 2);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (hasDn) {
|
||||
prims.recvCopySend(thisOutput+offset, nelem);
|
||||
} else {
|
||||
prims.recv(thisOutput+offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const ssize_t minChunkSize = nthreads * (sizeof(uint64_t)) / sizeof(T);
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*nranks*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
|
||||
|
||||
/////////////// begin AllReduce steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem;
|
||||
int chunk;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
chunk = ring->devUserRanks[nranks-1];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
chunk = ring->devUserRanks[nranks-j];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final
|
||||
// result that we store in this data and push to the next GPU
|
||||
chunk = ring->devUserRanks[0];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
|
||||
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
chunk = ring->devUserRanks[nranks-j];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.recvCopySend(thisOutput+offset, nelem);
|
||||
}
|
||||
|
||||
// Make final copy from buffer to dest.
|
||||
chunk = ring->devUserRanks[1];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
// Here we need to copy from buffer to this output.
|
||||
LLprims.recv(thisOutput+offset, nelem);
|
||||
}
|
||||
}
|
||||
};
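// Per chunk, each rank issues 2*nranks-1 LL primitive calls: one send, (nranks-2)
// recvReduceSend, one recvReduceCopySend producing the fully reduced value, then
// (nranks-2) recvCopySend and a final recv. Worked example (sketch) for nranks = 4:
//   send, recvReduceSend, recvReduceSend, recvReduceCopySend,
//   recvCopySend, recvCopySend, recv   -> 7 = 2*4-1 calls per chunk.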
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_TREE, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclTree* tree = &channel->tree;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
if (loopSize > size) {
|
||||
chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
|
||||
}
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
do {
|
||||
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
ncclLLPrimitives<T, FUNC, NCCL_MAX_DEV_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->up == -1) {
|
||||
LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
|
||||
} else if (tree->down[0] == -1) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
}
|
||||
} while(0);
|
||||
|
||||
do {
|
||||
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_DEV_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->up == -1) {
|
||||
LLprims.send(thisOutput+offset, nelem);
|
||||
} else if (tree->down[0] == -1) {
|
||||
LLprims.recv(thisOutput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvCopySend(thisOutput+offset, nelem);
|
||||
}
|
||||
}
|
||||
} while(0);
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_COLLNET, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) { }
|
||||
};
|
||||
|
||||
#include "prims_ll128.h"
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*nranks*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
|
||||
|
||||
/////////////// begin AllReduce steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem;
|
||||
int chunk;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
chunk = ring->devUserRanks[nranks-1];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
chunk = ring->devUserRanks[nranks-j];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final
|
||||
// result that we store in this data and push to the next GPU
|
||||
chunk = ring->devUserRanks[0];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
|
||||
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
chunk = ring->devUserRanks[nranks-j];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.recvCopySend(thisOutput+offset, nelem);
|
||||
}
|
||||
|
||||
// Make final copy from buffer to dest.
|
||||
chunk = ring->devUserRanks[1];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
// Here we need to copy from buffer to this output.
|
||||
LLprims.recv(thisOutput+offset, nelem);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_TREE, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclTree* tree = &channel->tree;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
|
||||
ssize_t chunkSize = args->coll.lastChunkSize;
|
||||
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/8;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
int nthreadsSplit = NCCL_LL128_SPLIT(nthreads);
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
if (loopSize > size) {
|
||||
chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
|
||||
}
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
if (tree->up == -1) {
|
||||
// ReduceAndBroadcast : max number of recv is 3, max number of send is 3
|
||||
ncclLL128Primitives<T, FUNC, NCCL_MAX_DEV_ARITY, NCCL_MAX_DEV_ARITY> LLprims(tid, nthreads, tree->down, tree->down, stepSize, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
|
||||
}
|
||||
} else {
|
||||
if (tid < nthreadsSplit) {
|
||||
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
ncclLL128Primitives<T, FUNC, NCCL_MAX_DEV_ARITY, 1> LLprims(tid, nthreadsSplit, tree->down, &tree->up, stepSize, channel, comm);
|
||||
if (hasDn) {
|
||||
// Recv from network, broadcast
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartBcast, nThreadsBcast, &tree->out, tree->down, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->down[0] == -1) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
prims.recvCopySend(offset, nelem, /*postOp=*/true);
|
||||
}
|
||||
} else {
|
||||
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_DEV_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, stepSize, channel, comm);
|
||||
// Recv from network (no post thread needed)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartBcast, nThreadsBcast, &tree->out, nullptr, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->down[0] == -1) {
|
||||
LLprims.recv(thisOutput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvCopySend(thisOutput+offset, nelem);
|
||||
}
|
||||
prims.recv(offset, nelem, /*postOp=*/true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_COLLNET, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) { }
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runTreeSplit<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runTreeSplit<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
@ -5,158 +5,78 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads-WARP_SIZE;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
const int chunkSize = stepSize * BROADCAST_CHUNKSTEPS;
|
||||
const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ring->devUserRanks[0];
|
||||
const int nextRank = ring->devUserRanks[1];
|
||||
const int root = args->coll.root;
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
ncclRing *ring = &ncclShmem.channel.ring;
|
||||
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? BROADCAST_CHUNKSTEPS : 1));
|
||||
const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T)));
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ring->devUserRanks[0];
|
||||
const int nextRank = ring->devUserRanks[1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
T *inputBuf = (T*)args->sendbuff;
|
||||
T *outputBuf = (T*)args->recvbuff;
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf);
|
||||
|
||||
ncclPrimitives<UNROLL, BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS, T, 1, 1, 0, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels));
|
||||
realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
else if (Proto::Id == NCCL_PROTO_LL)
|
||||
realChunkSize = size-gridOffset < loopSize ? args->coll.lastChunkSize : chunkSize;
|
||||
else if (Proto::Id == NCCL_PROTO_LL128)
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128);
|
||||
realChunkSize = int(realChunkSize);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
ssize_t offset = gridOffset + bid*realChunkSize;
|
||||
int nelem = min(realChunkSize, size-offset);
|
||||
ssize_t offset = gridOffset + int(bid*realChunkSize);
|
||||
int nelem = min(realChunkSize, size-offset);
|
||||
|
||||
if (rank == root) {
|
||||
if (thisInput == thisOutput) {
|
||||
prims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
prims.copySend(thisInput+offset, thisOutput+offset, nelem);
|
||||
}
|
||||
} else if (nextRank == root) {
|
||||
prims.recv(thisOutput+offset, nelem);
|
||||
if (rank == root) {
|
||||
if (inputBuf == outputBuf) {
|
||||
prims.send(offset, nelem);
|
||||
} else {
|
||||
prims.recvCopySend(thisOutput+offset, nelem);
|
||||
prims.copySend(offset, offset, nelem);
|
||||
}
|
||||
} else if (nextRank == root) {
|
||||
prims.recv(offset, nelem);
|
||||
} else {
|
||||
prims.recvCopySend(offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
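// Role summary (sketch): the root sends (or copySend when the user buffers differ, so it
// also writes its own output), the rank whose next peer is the root only receives, and
// every other rank forwards with recvCopySend. realChunkSize is recomputed per iteration
// according to the protocol (Simple, LL or LL128) selected through Proto.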
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ring->devUserRanks[0];
|
||||
const int nextRank = ring->devUserRanks[1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
if (size-gridOffset < loopSize) {
|
||||
chunkSize = args->coll.lastChunkSize;
|
||||
}
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (rank == root) {
|
||||
if (thisInput == thisOutput) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
LLprims.copySend(thisInput + offset, thisOutput + offset, nelem);
|
||||
}
|
||||
} else if (nextRank == root) {
|
||||
LLprims.recv(thisOutput + offset, nelem);
|
||||
} else {
|
||||
LLprims.recvCopySend(thisOutput + offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
#include "prims_ll128.h"
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ring->devUserRanks[0];
|
||||
const int nextRank = ring->devUserRanks[1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (rank == root) {
|
||||
if (thisInput == thisOutput) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
LLprims.copySend(thisInput + offset, thisOutput + offset, nelem);
|
||||
}
|
||||
} else if (nextRank == root) {
|
||||
LLprims.recv(thisOutput + offset, nelem);
|
||||
} else {
|
||||
LLprims.recvCopySend(thisOutput + offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_TREE, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_COLLNET, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
@ -10,108 +10,152 @@
|
||||
#include "collectives.h"
|
||||
#include "devcomm.h"
|
||||
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
#define COLL_UNROLL 8
|
||||
#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree
|
||||
#else
|
||||
#define COLL_UNROLL 4
|
||||
#define NCCL_MAX_DEV_ARITY NCCL_MAX_TREE_ARITY
|
||||
#endif
|
||||
|
||||
// Exit If Abort Barrier across CTA: make sure all threads exit consistently
|
||||
// Each thread sets a predicate to true if abort == 1
|
||||
// all CTA's threads enter the barrier and do a popc on their predicates being True
|
||||
// If any of the thread's predicate was True, all the threads call exit()
|
||||
static inline __device__ void exitIfAbortBarrier(int abort) {
|
||||
#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree
|
||||
|
||||
__device__ inline bool barrierReduceAny(int bit) {
|
||||
uint32_t popc;
|
||||
asm ("{");
|
||||
asm volatile (" .reg .pred barr_pred;");
|
||||
asm volatile (" setp.eq.u32 barr_pred,%0,1;" :: "r"(abort));
|
||||
asm volatile (" bar.red.popc.u32 %0, 0, barr_pred;" : "=r"(popc));
|
||||
asm ("}");
|
||||
if (popc) { asm volatile ("exit;"); }
|
||||
asm ("{"
|
||||
".reg .pred barr_pred;"
|
||||
"setp.eq.u32 barr_pred, %1, 1;"
|
||||
"bar.red.popc.u32 %0, 0, barr_pred;"
|
||||
"}" : "=r"(popc) : "r"(bit));
|
||||
return popc != 0;
|
||||
}
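// barrierReduceAny is a CTA-wide "any" vote: bar.red.popc counts how many threads passed
// bit == 1, and every thread observes the same count. Usage sketch (mirrors the abort
// check in ncclKernel further down):
//   int aborted = (threadIdx.x == 0) ? *ncclShmem.comm.abortFlag : 0;
//   if (barrierReduceAny(aborted)) return;   // all threads leave together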
|
||||
|
||||
typedef void(*ncclKern_t)(struct ncclWorkElem* args);
|
||||
extern __device__ ncclKern_t ncclFuncs[];
|
||||
template<typename T>
|
||||
__device__ int copyToShmem(T *dst, T const *src, int turn=0) {
|
||||
static_assert(sizeof(uint64_t) <= alignof(T), "Uhoh");
|
||||
uint64_t *d = reinterpret_cast<uint64_t*>(dst);
|
||||
uint64_t const *s = reinterpret_cast<uint64_t const*>(src);
|
||||
int t = threadIdx.x - turn;
|
||||
if (t < 0) t += blockDim.x;
|
||||
int n = sizeof(T)/sizeof(uint64_t);
|
||||
|
||||
static __device__ void load_parallel(void* dst, void* src, size_t size, int tid) {
|
||||
int* d = (int*)dst;
|
||||
int* s = (int*)src;
|
||||
for (int o = tid; o < (size/sizeof(int)); o += blockDim.x) d[o] = s[o];
|
||||
int delta = (n + WARP_SIZE-1) & -WARP_SIZE; // round up to warp lane 0
|
||||
if (delta < blockDim.x) {
|
||||
turn += delta;
|
||||
if (turn >= blockDim.x) turn -= blockDim.x;
|
||||
}
|
||||
else
|
||||
turn = 0;
|
||||
|
||||
n -= t;
|
||||
d += t;
|
||||
s += t;
|
||||
#pragma unroll
|
||||
for (int i=0; i < divUp(sizeof(T), WARP_SIZE*sizeof(uint64_t)); i++) {
|
||||
if (n > 0) {
|
||||
*d = *s;
|
||||
d += blockDim.x;
|
||||
s += blockDim.x;
|
||||
n -= blockDim.x;
|
||||
}
|
||||
}
|
||||
return turn;
|
||||
}
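// copyToShmem copies a trivially-copyable struct into shared memory as 64-bit words,
// spread over the whole thread block; the returned 'turn' rotates which warp starts the
// next copy so that chained copies keep all warps busy. Usage sketch (as in ncclKernel):
//   int turn = copyToShmem(&ncclShmem.comm, first.comm);
//   turn     = copyToShmem(&ncclShmem.channel, channel, turn);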
|
||||
|
||||
static __device__ void load_coll(struct ncclWork* localWork, struct ncclWork *hostWork, struct ncclWork* workFifo, int tid, struct ncclDevComm* comm) {
|
||||
load_parallel(localWork, workFifo, sizeof(struct ncclWork), tid);
|
||||
// Check whether the last operation was aborted and make sure all threads exit
|
||||
int abort = tid == 0 ? *(comm->abortFlag) : 0;
|
||||
exitIfAbortBarrier(abort);
|
||||
if (tid == 0) hostWork->elems[0].active = 0;
|
||||
}
|
||||
|
||||
template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {}
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto>
|
||||
struct RunWorkElement {
|
||||
__device__ void run(ncclWorkElem*) {
|
||||
// Put NOT IMPLEMENTED behavior here.
|
||||
}
|
||||
};
|
||||
|
||||
struct ncclShmemPtrs {
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto>
|
||||
struct RunWork {
|
||||
__device__ void run(ncclWork *w) {
|
||||
int tid = threadIdx.x;
|
||||
#pragma unroll 1
|
||||
for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++) {
|
||||
if (tid < w->elems[e].nThreads)
|
||||
RunWorkElement<Fn, T, RedOp, Algo, Proto>().run(&w->elems[e]);
|
||||
}
|
||||
}
|
||||
};
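// RunWork walks every active element of the ncclWork, so aggregated operations queued as
// one work item all run inside a single kernel launch; a thread only executes an element
// when its tid is below that element's nThreads.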
|
||||
|
||||
typedef void(*ncclKern_t)();
|
||||
extern __device__ ncclKern_t ncclFuncs[];
|
||||
|
||||
struct ncclShmemGroup {
|
||||
ncclConnInfo *recvConns[NCCL_MAX_DIRECT_ARITY];
|
||||
ncclConnInfo *sendConns[NCCL_MAX_DIRECT_ARITY];
|
||||
void* srcs[NCCL_MAX_DIRECT_ARITY+1];
|
||||
void* dsts[NCCL_MAX_DIRECT_ARITY+1];
|
||||
};
|
||||
|
||||
struct ncclShmemData {
|
||||
union {
|
||||
volatile uint64_t data[NCCL_LL128_SHMEM_SIZE];
|
||||
struct ncclShmemPtrs ptrs[NCCL_MAX_GROUPS];
|
||||
uint64_t ll128warp[NCCL_LL128_MAX_NTHREADS/WARP_SIZE][NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE];
|
||||
struct ncclShmemGroup groups[NCCL_MAX_GROUPS];
|
||||
};
|
||||
struct ncclWork localWork;
|
||||
ncclDevComm comm;
|
||||
ncclChannel channel;
|
||||
ncclWork work;
|
||||
};
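// The union overlays the LL128 per-warp staging area (ll128warp) with the per-group
// connection and pointer bookkeeping (groups), since a kernel only needs one of the two
// at a time; comm, channel and the current work are shadowed here to avoid re-reading
// them from global memory on every step.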
|
||||
|
||||
extern __device__ struct ncclShmemData *ncclShmem;
|
||||
template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL, int FINDEX>
|
||||
__device__ void ncclKernel(struct ncclWorkElem first) {
|
||||
extern __shared__ ncclShmemData ncclShmem;
|
||||
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int FnIndex>
|
||||
__device__ void ncclKernel(ncclWorkElem first) {
|
||||
int tid = threadIdx.x;
|
||||
int bid = blockIdx.x;
|
||||
__shared__ struct ncclShmemData shmem;
|
||||
ncclShmem = &shmem;
|
||||
|
||||
auto f = ncclFunction<FUNCTION, ALGO, PROTO, REDOP, T, UNROLL>();
|
||||
int turn = copyToShmem(&ncclShmem.comm, first.comm);
|
||||
// get address of channel without incurring indirect load from ncclDevCom::channels
|
||||
ncclChannel *channel = &((ncclDevCommAndChannels*)first.comm)->channels[bid];
|
||||
turn = copyToShmem(&ncclShmem.channel, channel, turn);
|
||||
|
||||
struct ncclDevComm* comm = first.comm;
|
||||
struct ncclChannel* channel = comm->channels+bid;
|
||||
struct ncclWorkElem* w = NULL;
|
||||
// To optimize for latency, (only) the first operation is passed as argument.
|
||||
if (bid == 0 && first.active != 0) {
|
||||
turn = copyToShmem(&ncclShmem.work.elems[0], &first, turn);
|
||||
if (tid == 0) ncclShmem.work.elems[1].active = 0;
|
||||
}
|
||||
__syncthreads(); // publish ncclShmem
|
||||
|
||||
/* To optimize for latency, (only) the first operation is passed as argument.*/
|
||||
if (bid == 0 && first.funcIndex != FUNC_INDEX_P2P) w = &first;
|
||||
ncclWork *workFifoHost = ncclShmem.channel.workFifo;
|
||||
ncclWork *workFifoDev = ncclShmem.channel.workFifoDev;
|
||||
int workFifoIx = ncclShmem.channel.index;
|
||||
|
||||
while (1) {
|
||||
if (w == NULL) {
|
||||
w = shmem.localWork.elems;
|
||||
__syncthreads();
|
||||
load_coll(&shmem.localWork, channel->workFifo+channel->index, channel->workFifoDev+channel->index, tid, comm);
|
||||
if (bid == 0 && first.active != 0)
|
||||
goto SkipLoadWork;
|
||||
|
||||
while (true) {
|
||||
copyToShmem(&ncclShmem.work, &workFifoDev[workFifoIx]); // turn no longer helps
|
||||
{ // Check whether the last operation was aborted and make sure all threads exit
|
||||
int aborted = tid == 0 ? *ncclShmem.comm.abortFlag : 0;
|
||||
if (barrierReduceAny(aborted)) // publish ncclShmem.work
|
||||
break;
|
||||
if (tid == 0)
|
||||
workFifoHost[workFifoIx].elems[0].active = 0;
|
||||
}
|
||||
if (tid < w->nThreads) {
|
||||
if (w->funcIndex == FINDEX) {
|
||||
f.run(w);
|
||||
} else {
|
||||
ncclFuncs[w->funcIndex](w);
|
||||
}
|
||||
}
|
||||
if (tid == 0) channel->index = (channel->index+1) % NCCL_MAX_OPS;
|
||||
if (w->active == 2) {
|
||||
return;
|
||||
}
|
||||
w = NULL;
|
||||
|
||||
SkipLoadWork:
|
||||
workFifoIx = (workFifoIx + 1)%NCCL_MAX_OPS;
|
||||
if (tid == 0)
|
||||
channel->index = workFifoIx; // write back to real channel, not shmem shadow
|
||||
|
||||
if (ncclShmem.work.elems[0].funcIndex == FnIndex)
|
||||
RunWork<Fn, T, RedOp, Algo, Proto>().run(&ncclShmem.work);
|
||||
else
|
||||
ncclFuncs[ncclShmem.work.elems[0].funcIndex]();
|
||||
|
||||
if (ncclShmem.work.elems[0].active == 2)
|
||||
break;
|
||||
__syncthreads();
|
||||
}
|
||||
}
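// Main loop sketch: copy the next ncclWork from the per-channel device FIFO into shared
// memory, vote on the abort flag with barrierReduceAny, clear the host-side slot and
// advance the FIFO index, then run the work (inlined when funcIndex == FnIndex, otherwise
// via ncclFuncs[]), stopping once an element is flagged as the last one (active == 2).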
|
||||
|
||||
// Only generate kernels for SUM
|
||||
#if NCCL_OP == 0
|
||||
#define IMPL_COLL_KERN(func, algo, proto, redop, type, fIndex) \
|
||||
__global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem first) { \
|
||||
ncclKernel<ncclFunc##func, NCCL_ALGO_##algo, NCCL_PROTO_##proto, Func##redop<type>, type, COLL_UNROLL, fIndex>(first); \
|
||||
__global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(ncclWorkElem first) { \
|
||||
ncclKernel<ncclFunc##func, type, Func##redop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex>(first); \
|
||||
}
|
||||
#else
|
||||
#define IMPL_COLL_KERN(func, algo, proto, redop, type, fIndex)
|
||||
@ -119,9 +163,8 @@ __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkEl
|
||||
|
||||
// Examples : AllReduce, RING, LL, Sum, uint8
|
||||
#define IMPL_COLL_FUNC(func, algo, proto, redop, type) \
|
||||
__device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args) { \
|
||||
auto f = ncclFunction<ncclFunc##func, NCCL_ALGO_##algo, NCCL_PROTO_##proto, Func##redop<type>, type, COLL_UNROLL>(); \
|
||||
f.run(args); \
|
||||
__device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)() { \
|
||||
RunWork<ncclFunc##func, type, Func##redop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto>().run(&ncclShmem.work); \
|
||||
}
|
||||
|
||||
// Only generate inline kernels for LL
|
||||
@ -154,6 +197,8 @@ __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkEl
|
||||
#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, float, ncclFloat32)
|
||||
#elif NCCL_TYPE == 8
|
||||
#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, double, ncclFloat64)
|
||||
#elif NCCL_TYPE == 9 && defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, __nv_bfloat16, ncclBfloat16)
|
||||
#endif
|
||||
|
||||
// Reduction define all functions
|
||||
@ -165,6 +210,8 @@ __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkEl
|
||||
#define IMPL_COLL_R(func) IMPL_COLL2(func, Min);
|
||||
#elif NCCL_OP == 3
|
||||
#define IMPL_COLL_R(func) IMPL_COLL2(func, Max);
|
||||
#elif NCCL_OP == 4
|
||||
#define IMPL_COLL_R(func) IMPL_COLL2(func, Avg);
|
||||
#endif
|
||||
|
||||
#if NCCL_OP == 0 && NCCL_TYPE == 0
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -24,10 +24,19 @@ inline __device__ void loadPtr(void** ptr, T* &v) {
|
||||
|
||||
typedef uint64_t PackType;
|
||||
|
||||
template<typename Fn>
|
||||
struct FuncTraits /*{
|
||||
__device__ static Fn make();
|
||||
__device__ static T preOp(Fn, T);
|
||||
__device__ static T postOp(Fn, T);
|
||||
}*/;
|
||||
|
||||
// unpack x and y to elements of type T and apply FUNC to each element
|
||||
template<class FUNC, typename T>
|
||||
struct MULTI {
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const;
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const;
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const;
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const;
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
@ -41,17 +50,39 @@ struct MULTI<FUNC, int8_t> {
|
||||
};
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
|
||||
// for char, we do these as vector ops
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
cr.a = fn(cx.a, cy.a);
|
||||
cr.b = fn(cx.b, cy.b);
|
||||
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int8_t elt[8];
|
||||
} u;
|
||||
u.pack = x;
|
||||
#pragma unroll
|
||||
for (int i=0; i < 8; i++)
|
||||
u.elt[i] = FuncTraits<FUNC>().preOp(fn, u.elt[i]);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int8_t elt[8];
|
||||
} u;
|
||||
u.pack = x;
|
||||
#pragma unroll
|
||||
for (int i=0; i < 8; i++)
|
||||
u.elt[i] = FuncTraits<FUNC>().postOp(fn, u.elt[i]);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
@ -65,17 +96,39 @@ struct MULTI<FUNC, uint8_t> {
|
||||
};
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
|
||||
// for char, we do these as vector ops
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
cr.a = fn(cx.a, cy.a);
|
||||
cr.b = fn(cx.b, cy.b);
|
||||
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint8_t elt[8];
|
||||
} u;
|
||||
u.pack = x;
|
||||
#pragma unroll
|
||||
for (int i=0; i < 8; i++)
|
||||
u.elt[i] = FuncTraits<FUNC>().preOp(fn, u.elt[i]);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint8_t elt[8];
|
||||
} u;
|
||||
u.pack = x;
|
||||
#pragma unroll
|
||||
for (int i=0; i < 8; i++)
|
||||
u.elt[i] = FuncTraits<FUNC>().postOp(fn, u.elt[i]);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
@ -89,16 +142,36 @@ struct MULTI<FUNC, int32_t> {
|
||||
};
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
cr.a = fn(cx.a, cy.a);
|
||||
cr.b = fn(cx.b, cy.b);
|
||||
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int32_t elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().preOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().preOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int32_t elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().postOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().postOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
@ -112,16 +185,36 @@ struct MULTI<FUNC, uint32_t> {
|
||||
};
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
cr.a = fn(cx.a, cy.a);
|
||||
cr.b = fn(cx.b, cy.b);
|
||||
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint32_t elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().preOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().preOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint32_t elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().postOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().postOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
@ -129,22 +222,69 @@ struct MULTI<FUNC, half> {
|
||||
static_assert(sizeof(PackType) == 4 * sizeof(half),
|
||||
"PackType must be four times the size of half.");
|
||||
|
||||
struct PackHalf2 {
|
||||
half2 a, b;
|
||||
union Converter {
|
||||
PackType pack;
|
||||
half2 h2[2];
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
struct PackHalf2 cx, cy, cr;
|
||||
cx = *(reinterpret_cast<const struct PackHalf2*>(&x));
|
||||
cy = *(reinterpret_cast<const struct PackHalf2*>(&y));
|
||||
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
|
||||
return *(reinterpret_cast<PackType*>(&cr));
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
Converter cx, cy, cr;
|
||||
cx.pack = x;
|
||||
cy.pack = y;
|
||||
cr.h2[0] = fn(cx.h2[0], cy.h2[0]);
|
||||
cr.h2[1] = fn(cx.h2[1], cy.h2[1]);
|
||||
return cr.pack;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
Converter c;
|
||||
c.pack = x;
|
||||
c.h2[0] = FuncTraits<FUNC>().preOp(fn, c.h2[0]);
|
||||
c.h2[1] = FuncTraits<FUNC>().preOp(fn, c.h2[1]);
|
||||
return c.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
Converter c;
|
||||
c.pack = x;
|
||||
c.h2[0] = FuncTraits<FUNC>().postOp(fn, c.h2[0]);
|
||||
c.h2[1] = FuncTraits<FUNC>().postOp(fn, c.h2[1]);
|
||||
return c.pack;
|
||||
}
|
||||
};
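// preOp/postOp let a reduction transform elements before the first use of the send buffer
// and after the final reduction step; this is what allows ncclAvg to be expressed as a
// Sum plus a scaling pass. Illustrative sketch only (FuncAvg and its member 'n' are
// assumed here; the real FuncTraits specializations live in the reduction headers):
//   template<typename T>
//   struct FuncTraits<FuncAvg<T>> {
//     __device__ static T preOp (FuncAvg<T> fn, T x) { return x; }
//     __device__ static T postOp(FuncAvg<T> fn, T x) { return x / (T)fn.n; }  // n = nranks
//   };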
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
template<class FUNC>
|
||||
struct MULTI<FUNC, __nv_bfloat16> {
|
||||
static_assert(sizeof(PackType) == 4 * sizeof(__nv_bfloat16),
|
||||
"PackType must be four times the size of __nv_bfloat16.");
|
||||
|
||||
union Converter {
|
||||
PackType pack;
|
||||
__nv_bfloat162 h2[2];
|
||||
};
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
Converter cx, cy, cr;
|
||||
cx.pack = x;
|
||||
cy.pack = y;
|
||||
cr.h2[0] = fn(cx.h2[0], cy.h2[0]);
|
||||
cr.h2[1] = fn(cx.h2[1], cy.h2[1]);
|
||||
return cr.pack;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
Converter c;
|
||||
c.pack = x;
|
||||
c.h2[0] = FuncTraits<FUNC>().preOp(fn, c.h2[0]);
|
||||
c.h2[1] = FuncTraits<FUNC>().preOp(fn, c.h2[1]);
|
||||
return c.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
Converter c;
|
||||
c.pack = x;
|
||||
c.h2[0] = FuncTraits<FUNC>().postOp(fn, c.h2[0]);
|
||||
c.h2[1] = FuncTraits<FUNC>().postOp(fn, c.h2[1]);
|
||||
return c.pack;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<class FUNC>
|
||||
struct MULTI<FUNC, float> {
|
||||
static_assert(sizeof(PackType) == 2 * sizeof(float),
|
||||
@ -156,46 +296,120 @@ struct MULTI<FUNC, float> {
|
||||
};
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
cr.a = fn(cx.a, cy.a);
|
||||
cr.b = fn(cx.b, cy.b);
|
||||
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
float elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().preOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().preOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
float elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().postOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().postOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
struct MULTI<FUNC, double> {
|
||||
static_assert(sizeof(PackType) == sizeof(double),
|
||||
"PackType must be the same size as double.");
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
double rv = FUNC()(__longlong_as_double(x), __longlong_as_double(y));
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
double rv = fn(__longlong_as_double(x), __longlong_as_double(y));
|
||||
return __double_as_longlong(rv);
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
double elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().preOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
double elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().postOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
struct MULTI<FUNC, uint64_t> {
|
||||
static_assert(sizeof(PackType) == sizeof(uint64_t),
|
||||
"PackType must be the same size as uint64_t.");
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
uint64_t rv = FUNC()(x, y);
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
uint64_t rv = fn(x, y);
|
||||
return rv;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint64_t elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().preOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint64_t elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().postOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
struct MULTI<FUNC, int64_t> {
|
||||
static_assert(sizeof(PackType) == sizeof(int64_t),
|
||||
"PackType must be the same size as int64_t.");
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
int64_t rv = FUNC()((int64_t)x, (int64_t)y);
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
int64_t rv = fn((int64_t)x, (int64_t)y);
|
||||
return rv;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int64_t elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().preOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int64_t elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().postOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T> inline __device__
|
||||
@ -234,13 +448,35 @@ void vStore<half>(volatile half* ptr, const half val) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
template<> inline __device__
|
||||
__nv_bfloat16 vFetch<__nv_bfloat16>(const volatile __nv_bfloat16* ptr) {
|
||||
__nv_bfloat16 r;
|
||||
r = ((__nv_bfloat16*)ptr)[0];
|
||||
return r;
|
||||
}
|
||||
|
||||
template<> inline __device__
|
||||
void vStore<__nv_bfloat16>(volatile __nv_bfloat16* ptr, const __nv_bfloat16 val) {
|
||||
((__nv_bfloat16*)ptr)[0] = val;
|
||||
}
|
||||
#endif
|
||||
|
||||
typedef ulong2 Pack128;
|
||||
|
||||
template<class FUNC, typename T>
|
||||
struct MULTI128 {
|
||||
__device__ void operator()(Pack128& x, Pack128& y) {
|
||||
x.x = MULTI<FUNC, T>()(x.x, y.x);
|
||||
x.y = MULTI<FUNC, T>()(x.y, y.y);
|
||||
__device__ void operator()(FUNC fn, Pack128& x, Pack128 const& y) const {
|
||||
x.x = MULTI<FUNC, T>()(fn, x.x, y.x);
|
||||
x.y = MULTI<FUNC, T>()(fn, x.y, y.y);
|
||||
}
|
||||
__device__ void preOp(FUNC fn, Pack128 &x) const {
|
||||
x.x = MULTI<FUNC, T>().preOp(fn, x.x);
|
||||
x.y = MULTI<FUNC, T>().preOp(fn, x.y);
|
||||
}
|
||||
__device__ void postOp(FUNC fn, Pack128 &x) const {
|
||||
x.x = MULTI<FUNC, T>().postOp(fn, x.x);
|
||||
x.y = MULTI<FUNC, T>().postOp(fn, x.y);
|
||||
}
|
||||
};
|
||||
|
||||
@ -253,7 +489,8 @@ inline __device__ void Store128(Pack128* p, Pack128& v) {
|
||||
|
||||
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
|
||||
__device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t,
|
||||
int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem) {
|
||||
FUNC fn, bool preOpSrc0, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem
|
||||
) {
|
||||
const int inc = nw * UNROLL * WARP_SIZE;
|
||||
int offset = w * UNROLL * WARP_SIZE + t;
|
||||
|
||||
@ -266,22 +503,30 @@ __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const
|
||||
T vals[UNROLL];
|
||||
// Load and reduce
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = vFetch(srcs[0]+u*WARP_SIZE);
|
||||
if (preOpSrc0) {
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = FuncTraits<FUNC>().preOp(fn, vals[u]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i=1; i<MINSRCS; i++) {
|
||||
T vals2[UNROLL];
|
||||
for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = FUNC()(vals[u], vals2[u]);
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i=MINSRCS; i<MAXSRCS; i++) {
|
||||
if (i<nsrcs) {
|
||||
T vals2[UNROLL];
|
||||
for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = FUNC()(vals[u], vals2[u]);
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]);
|
||||
}
|
||||
}
|
||||
|
||||
if (postOp) {
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = FuncTraits<FUNC>().postOp(fn, vals[u]);
|
||||
}
|
||||
|
||||
// Store
|
||||
#pragma unroll
|
||||
for (int i = 0; i < MINDSTS; i++) {
|
||||
@ -301,7 +546,8 @@ __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const
|
||||
|
||||
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
|
||||
__device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, const int t,
|
||||
int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack) {
|
||||
FUNC fn, bool preOpSrc0, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack
|
||||
) {
|
||||
const int inc = nw * UNROLL * WARP_SIZE;
|
||||
int offset = w * UNROLL * WARP_SIZE + t;
|
||||
|
||||
@ -314,22 +560,30 @@ __device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, c
|
||||
Pack128 vals[UNROLL];
|
||||
// Load and reduce
|
||||
for (int u = 0; u < UNROLL; ++u) Fetch128(vals[u], srcs[0]+u*WARP_SIZE);
|
||||
if (preOpSrc0) {
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>().preOp(fn, vals[u]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i=1; i<MINSRCS; i++) {
|
||||
Pack128 vals2[UNROLL];
|
||||
for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(fn, vals[u], vals2[u]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i=MINSRCS; i<MAXSRCS; i++) {
|
||||
if (i<nsrcs) {
|
||||
Pack128 vals2[UNROLL];
|
||||
for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(fn, vals[u], vals2[u]);
|
||||
}
|
||||
}
|
||||
|
||||
if (postOp) {
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>().postOp(fn, vals[u]);
|
||||
}
|
||||
|
||||
// Store
|
||||
#pragma unroll
|
||||
for (int i = 0; i < MINDSTS; i++) {
|
||||
@ -353,9 +607,9 @@ __device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(Pack128); }
|
||||
#define PACKELEMS (sizeof(Pack128) / sizeof(T))
|
||||
|
||||
template<int UNROLL, class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
|
||||
__device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads,
|
||||
int nsrcs, const T** srcs, int ndsts, T** dsts,
|
||||
int N) {
|
||||
__device__ __forceinline__ void ReduceOrCopyMulti(
|
||||
const int tid, const int nthreads, FUNC fn, bool preOpSrc0, bool postOp, int nsrcs, const T** srcs, int ndsts, T** dsts, int N
|
||||
) {
|
||||
int Nrem = N;
|
||||
if (Nrem <= 0) return;
|
||||
|
||||
@ -381,7 +635,8 @@ __device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthre
|
||||
int Npack = (Nrem / (PACKELEMS*UNROLL*WARP_SIZE)) * (UNROLL*WARP_SIZE); // round down
|
||||
int Nelem = Npack * PACKELEMS;
|
||||
|
||||
ReduceCopy128bMulti<FUNC, T, UNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
ReduceCopy128bMulti<FUNC, T, UNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
|
||||
Nrem -= Nelem;
|
||||
if (Nrem == 0) return;
|
||||
@ -391,7 +646,8 @@ __device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthre
|
||||
Npack = Nrem / PACKELEMS;
|
||||
Nelem = Npack * PACKELEMS;
|
||||
|
||||
ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
|
||||
Nrem -= Nelem;
|
||||
if (Nrem == 0) return;
|
||||
@ -401,14 +657,16 @@ __device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthre
|
||||
// unrolled, by-type (mostly for unaligned buffers)
|
||||
int Nelem = (Nrem / (UNROLL*PACKELEMS/2*WARP_SIZE)) * (UNROLL*PACKELEMS/2*WARP_SIZE); // round down
|
||||
|
||||
ReduceCopyMulti<FUNC, T, UNROLL*PACKELEMS/2, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nelem);
|
||||
ReduceCopyMulti<FUNC, T, UNROLL*PACKELEMS/2, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Nelem);
|
||||
|
||||
Nrem -= Nelem;
|
||||
if (Nrem == 0) return;
|
||||
offset += Nelem;
|
||||
|
||||
// no unroll, by type. Should finish what's remaining.
|
||||
ReduceCopyMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nrem);
|
||||
ReduceCopyMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Nrem);
|
||||
}
|
||||
|
||||
#endif // COMMON_KERNEL_H_
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -8,7 +8,7 @@
|
||||
#include "collectives.h"
|
||||
#include "common.h"
|
||||
|
||||
__device__ struct ncclShmemData* ncclShmem;
|
||||
__shared__ ncclShmemData ncclShmem;
|
||||
|
||||
#define NCCL_FUNC5(func, algo, redop, type) \
|
||||
NCCL_FUNC_NAME(func, algo, LL, redop, type), \
|
||||
@ -20,6 +20,31 @@ __device__ struct ncclShmemData* ncclShmem;
|
||||
NCCL_FUNC5(func, RING, redop, type), \
|
||||
NCCL_FUNC5(func, COLLNET, redop, type)
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, redop) \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, uint8_t), \
|
||||
NCCL_FUNC4(func, redop, int32_t), \
|
||||
NCCL_FUNC4(func, redop, uint32_t), \
|
||||
NCCL_FUNC4(func, redop, int64_t), \
|
||||
NCCL_FUNC4(func, redop, uint64_t), \
|
||||
NCCL_FUNC4(func, redop, half), \
|
||||
NCCL_FUNC4(func, redop, float), \
|
||||
NCCL_FUNC4(func, redop, double), \
|
||||
NCCL_FUNC4(func, redop, __nv_bfloat16)
|
||||
#define NCCL_FUNCS3B(func, redop) \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t)
|
||||
#else
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, redop) \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
@ -41,17 +66,21 @@ __device__ struct ncclShmemData* ncclShmem;
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t)
|
||||
#endif
|
||||
|
||||
// Must be consistent with ncclRedOp_t
|
||||
#define NCCL_FUNCS2A(func) \
|
||||
NCCL_FUNCS3A(func, Sum ), \
|
||||
NCCL_FUNCS3A(func, Prod), \
|
||||
NCCL_FUNCS3A(func, Max ), \
|
||||
NCCL_FUNCS3A(func, Min )
|
||||
NCCL_FUNCS3A(func, Min ), \
|
||||
NCCL_FUNCS3A(func, Avg)
|
||||
|
||||
#define NCCL_FUNCS2B(func) \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum)
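// Illustrative sketch (hypothetical functor, not this patch's actual ncclAvg
// implementation): an average can be expressed as a sum whose functor carries
// 1/nRanks and applies it as a preOp on each input, which is why the Primitives
// constructors in this patch build the reduction op via
// FuncTraits<RedOp>().make(nRanks) and call preOp()/postOp() around the sum.
template<typename T>
struct HypotheticalAvg {
  T scale; // 1/nRanks, filled in at construction time
  __device__ T operator()(T a, T b) const { return a + b; } // reduce is a plain sum
  __device__ T preOp(T x) const { return x * scale; }       // scale each input once
  __device__ T postOp(T x) const { return x; }              // nothing left to do
};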
// Must be consistent with ncclFunc_t
|
||||
|
@ -1,19 +1,26 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# See LICENSE.txt for license information
|
||||
#
|
||||
|
||||
dir=$1
|
||||
|
||||
datatypes="i8 u8 i32 u32 i64 u64 f16 f32 f64"
|
||||
if [ "$CUDA_MAJOR" -ge 11 ]
|
||||
then
|
||||
datatypes+=" bf16"
|
||||
fi
|
||||
|
||||
targets="GENOBJS := \\\\\n"
|
||||
|
||||
for base in sendrecv all_reduce all_gather broadcast reduce reduce_scatter; do
|
||||
opn=0
|
||||
for op in sum prod min max; do
|
||||
for op in sum prod min max avg; do
|
||||
dtn=0
|
||||
for dt in i8 u8 i32 u32 i64 u64 f16 f32 f64; do
|
||||
# Order must match that of the ncclDataType_t enum
|
||||
for dt in ${datatypes}; do
|
||||
echo "${dir}/${base}_${op}_${dt}.o : ${base}.cu ${dir}/${base}.dep"
|
||||
echo " @printf \"Compiling %-35s > %s\\\\n\" ${base}.cu ${dir}/${base}_${op}_${dt}.o"
|
||||
echo " mkdir -p ${dir}"
|
||||
|
@ -33,4 +33,36 @@ inline __device__ void storeShmem128(uint64_t* shmemAsmPtr, uint64_t v0, uint64_
|
||||
:: "l"(v0), "l"(v1), "l"(shmemAsmPtr));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline __device__ void loadShmemMisaligned128(T *ptr, uint64_t &v0, uint64_t &v1) {
|
||||
union {
|
||||
uint32_t tmp4[4];
|
||||
uint64_t tmp8[2];
|
||||
};
|
||||
if(sizeof(T) < 4) {
|
||||
uint32_t *ptr4 = reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(ptr) & -uintptr_t(4));
|
||||
#pragma unroll
|
||||
for(int e=0; e < 4; e++) {
|
||||
// Produce 4 bytes of sub-register type by reading 2 4-byte
|
||||
// aligned values and shifting.
|
||||
uint32_t lo, hi;
|
||||
asm("ld.shared.b32 %0,[%1];" : "=r"(lo) : "l"(ptr4+e+0));
|
||||
asm("ld.shared.b32 %0,[%1];" : "=r"(hi) : "l"(ptr4+e+1));
|
||||
tmp4[e] = __funnelshift_r(lo, hi, 8*(int(reinterpret_cast<uintptr_t>(ptr))%4));
|
||||
}
|
||||
}
|
||||
else if(sizeof(T) == 4) {
|
||||
#pragma unroll
|
||||
for(int e=0; e < 4; e++)
|
||||
asm("ld.shared.b32 %0,[%1];" : "=r"(tmp4[e]) : "l"(ptr+e));
|
||||
}
|
||||
else /*sizeof(T)==8*/ {
|
||||
#pragma unroll
|
||||
for(int e=0; e < 2; e++)
|
||||
asm("ld.shared.b64 %0,[%1];" : "=l"(tmp8[e]) : "l"(ptr+e));
|
||||
}
|
||||
v0 = tmp8[0];
|
||||
v1 = tmp8[1];
|
||||
}
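// Worked example (illustration only): suppose ptr is misaligned by 2 bytes and
// the two aligned 32-bit words around it hold, little-endian,
//   lo = 0xDDCCBBAA  (bytes AA BB CC DD)    hi = 0x44332211  (bytes 11 22 33 44)
// __funnelshift_r(lo, hi, 8*2) shifts the 64-bit value hi:lo right by 16 bits
// and keeps the low word, giving 0x2211DDCC, i.e. bytes CC DD 11 22 -- exactly
// the 4 bytes that start at ptr. That is how loadShmemMisaligned128 rebuilds an
// unaligned 16-byte region out of aligned 4-byte shared-memory loads.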
#endif
|
||||
|
@ -11,381 +11,132 @@
|
||||
#include "reduce_kernel.h" // for reduction funcs
|
||||
#include "common.h"
|
||||
|
||||
#define SPINS_BEFORE_CHECK_ABORT 1000000
|
||||
#define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000
|
||||
|
||||
// Unroll unconditionally the first send/recv since nsend/nrecv should be at
|
||||
// least 1 if SEND/RECV is set.
|
||||
#define FOR_SEND(func, ...) do { \
|
||||
if (SEND) { \
|
||||
/* Send to far first, then close */ \
|
||||
for (int i=1; i<NSEND && i<nsend; i++) func(i, ##__VA_ARGS__); \
|
||||
func(0, ##__VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
/* Protocol classes: ProtoSimple, ProtoLL, ProtoLL128
* We use these as template args to the Primitives class instead of integral
* enums (e.g. NCCL_PROTO_LL) because for SIMPLE we need to carry a few extra
* numbers. These types also hold methods that let us compute numbers important
* to how that protocol operates, behind a consistent interface, so that our
* algorithm code can operate protocol-parametrically.
*/
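// Minimal sketch (illustration only, not part of this patch) of what
// "protocol parametric" means here: the same algorithm body can be
// instantiated for ProtoSimple<...>, ProtoLL or ProtoLL128 and derive its
// chunking from the protocol's own methods instead of switching on an enum.
template<typename T, typename Proto>
__device__ void exampleChunkLoop(int nelem) {
  const int elemPerStep = Proto::calcBytePerStep() / (int)sizeof(T);
  for (int off = 0; off < nelem; off += elemPerStep) {
    int chunk = min(elemPerStep, nelem - off);
    // ... hand [off, off+chunk) to a Primitives instance built for Proto ...
    (void)chunk;
  }
}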
template<int SlicePerChunk_1, int StepPerSlice_1, int Unroll_1 = COLL_UNROLL>
|
||||
struct ProtoSimple {
|
||||
static constexpr int Id = NCCL_PROTO_SIMPLE;
|
||||
static constexpr int SlicePerChunk = SlicePerChunk_1;
|
||||
static constexpr int StepPerSlice = StepPerSlice_1;
|
||||
static constexpr int Unroll = Unroll_1;
|
||||
|
||||
#define FOR_RECV(func, ...) do { \
|
||||
if (RECV) { \
|
||||
/* Recv from close first, then far */ \
|
||||
func(0, ##__VA_ARGS__); \
|
||||
for (int i=1; i<NRECV && i<nrecv; i++) func(i, ##__VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ROLE_SRC 0x01
|
||||
#define ROLE_DST 0x02
|
||||
#define ROLE_WAIT_RECV 0x04
|
||||
#define ROLE_WAIT_SEND 0x08
|
||||
#define ROLE_POST_SEND 0x10
|
||||
#define ROLE_POST_RECV 0x20
|
||||
|
||||
// Implementation of primitive types
|
||||
template <int UNROLL, int SLICESPERCHUNK, int SLICESTEPS, typename T, int NRECV, int NSEND, int DIRECT, class FUNC>
|
||||
class ncclPrimitives {
|
||||
private:
|
||||
const int tid;
|
||||
int nthreads;
|
||||
int nworkers;
|
||||
const int stepSize;
|
||||
int nrecv = 0;
|
||||
int nsend = 0;
|
||||
struct ncclConnInfo* conn = NULL;
|
||||
volatile int* connSizesFifoPtr = NULL;
|
||||
void** connPtrsFifoPtr = NULL;
|
||||
volatile uint64_t* connHeadPtr = NULL;
|
||||
volatile uint64_t* connTailPtr = NULL;
|
||||
uint64_t connTailCache; // Cache last seen value
|
||||
uint64_t connHeadCache; // Cache last seen value
|
||||
|
||||
int index; // Peer index I'm responsible for
|
||||
int peer = -1;
|
||||
int role = 0;
|
||||
int group;
|
||||
uint64_t step;
|
||||
T* direct = NULL;
|
||||
T* buff;
|
||||
struct ncclDevComm* comm;
|
||||
|
||||
const T** srcs;
|
||||
T** dsts;
|
||||
|
||||
// Don't use barrier 0 as it's used by the final sync
|
||||
inline __device__ void barrier() {
|
||||
if (nthreads == WARP_SIZE) __syncwarp();
|
||||
else asm volatile ("bar.sync %0, %1;" :: "r"(group+1), "r"(nthreads));
|
||||
// Data bytes (no flags etc) in one step of the fifo queue.
|
||||
__device__ static int calcBytePerStep() {
|
||||
return ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS;
|
||||
}
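// Worked example (assuming the default ~4 MiB NCCL_PROTO_SIMPLE buffer and
// NCCL_STEPS == 8; both are configuration dependent): calcBytePerStep() would
// return 512 KiB, and one slice of StepPerSlice steps advances the fifo by
// StepPerSlice * 512 KiB.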
inline __device__ void subBarrier() {
|
||||
if (nworkers == nthreads) barrier();
|
||||
else asm volatile ("bar.sync %0, %1;" :: "r"(group+2), "r"(nworkers));
|
||||
// Granularity of data bytes transferred per thread.
|
||||
__device__ static int calcBytePerGrain() {
|
||||
return sizeof(uint64_t); // Placeholder value; nothing currently queries this metric for the simple protocol.
|
||||
}
|
||||
|
||||
uint32_t spins = 0;
|
||||
uint32_t abort = 0;
|
||||
|
||||
inline __device__ int checkAbort() {
|
||||
spins++;
|
||||
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = *(comm->abortFlag);
|
||||
spins = 0;
|
||||
}
|
||||
return abort;
|
||||
}
|
||||
|
||||
template <int DIRECTPTR>
|
||||
inline __device__ T* directPtr(ssize_t directOffset) {
|
||||
return DIRECTPTR && direct ? direct+directOffset : buff+(step%NCCL_STEPS)*stepSize;
|
||||
}
|
||||
|
||||
template <int DST, int DIRECTSEND>
|
||||
inline __device__ void waitSend(ssize_t directOffset, int nbytes) {
|
||||
spins = 0;
|
||||
while (connHeadCache + NCCL_STEPS < step + SLICESTEPS) {
|
||||
connHeadCache = *connHeadPtr;
|
||||
if (checkAbort()) break;
|
||||
}
|
||||
if (connSizesFifoPtr) {
|
||||
connSizesFifoPtr[step%NCCL_STEPS] = nbytes;
|
||||
}
|
||||
|
||||
if (connPtrsFifoPtr) loadPtr(connPtrsFifoPtr+step%NCCL_STEPS, dsts[DST+index]);
|
||||
else dsts[DST+index] = directPtr<DIRECTSEND>(directOffset);
|
||||
step += SLICESTEPS;
|
||||
}
|
||||
|
||||
template <int SRC, int DIRECTRECV>
|
||||
inline __device__ void waitRecv(ssize_t directOffset) {
|
||||
spins = 0;
|
||||
while (connTailCache < step + SLICESTEPS) {
|
||||
connTailCache = *connTailPtr;
|
||||
if (checkAbort()) break;
|
||||
}
|
||||
if (connPtrsFifoPtr) loadPtr(connPtrsFifoPtr+step%NCCL_STEPS, srcs[SRC+index]);
|
||||
else srcs[SRC+index] = directPtr<DIRECTRECV>(directOffset);
|
||||
step += SLICESTEPS;
|
||||
}
|
||||
|
||||
inline __device__ void postRecv() {
|
||||
*connHeadPtr = step += SLICESTEPS;
|
||||
}
|
||||
|
||||
inline __device__ void postSend() {
|
||||
*connTailPtr = step += SLICESTEPS;
|
||||
}
|
||||
|
||||
template <int DIRECTRECV, int DIRECTSEND, int RECV, int SEND, int SRC, int DST>
|
||||
inline __device__ void
|
||||
GenericOp(const T* srcPtr, T* dstPtr, int nelem, ssize_t directOffset) {
|
||||
int offset = 0;
|
||||
int sliceSize = stepSize*SLICESTEPS;
|
||||
int dataSize = max(DIVUP(nelem, 16*SLICESPERCHUNK)*16, sliceSize/32);
|
||||
|
||||
#pragma unroll
|
||||
for (int slice=0; slice<SLICESPERCHUNK; ++slice) {
|
||||
int realSize = max(0, min(dataSize, nelem-offset));
|
||||
if (tid < nworkers) {
|
||||
if (SRC && (role & ROLE_SRC)) srcs[0] = srcPtr+offset;
|
||||
if (RECV && (role & ROLE_WAIT_RECV)) waitRecv<SRC, DIRECTRECV>(directOffset+offset);
|
||||
if (DST && (role & ROLE_DST)) dsts[0] = dstPtr+offset;
|
||||
if (SEND && (role & ROLE_WAIT_SEND)) waitSend<DST, DIRECTSEND>(directOffset+offset, realSize*sizeof(T));
|
||||
if (realSize > 0) {
|
||||
subBarrier();
|
||||
if (DIRECTRECV && srcs[0] == dsts[0]) {
|
||||
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
|
||||
if (SEND) {
|
||||
// (1-SEND) is only there to avoid compilation errors in case NSEND=0 (and SEND=0).
|
||||
ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, (1-SEND)+NSEND>(tid, nworkers, 1, srcs, nsend, dsts+1, realSize);
|
||||
}
|
||||
} else {
|
||||
ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nworkers, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
if (SEND && (role & ROLE_POST_SEND) && realSize > 0 && index == 0) __threadfence_system();
|
||||
__syncwarp();
|
||||
if (SEND && (role & ROLE_POST_SEND)) postSend();
|
||||
if (RECV && (role & ROLE_POST_RECV)) postRecv();
|
||||
offset += realSize;
|
||||
}
|
||||
}
|
||||
|
||||
// Scatter and gather do not support DIRECT
|
||||
template <int RECV, int SEND>
|
||||
inline __device__ void
|
||||
ScatterGatherOp(const T* srcPtr, T* dstPtr, int totalElem, int peerElem, int skip, int shift) {
|
||||
int offset = 0; // slice offset
|
||||
int sliceSize = stepSize*SLICESTEPS;
|
||||
int dataSize = max(DIVUP(peerElem, 16*SLICESPERCHUNK)*16, sliceSize/32); // per-peer slice size
|
||||
|
||||
#pragma unroll
|
||||
for (int slice=0; slice<SLICESPERCHUNK; ++slice) {
|
||||
int realSize = max(0, min(dataSize, peerElem-offset));
|
||||
if (tid < nworkers) {
|
||||
if (RECV && (role & ROLE_WAIT_RECV)) waitRecv<0, 0>(0);
|
||||
// realSize is not accurate here; but intra-node does not rely on sizes FIFO
|
||||
if (SEND && (role & ROLE_WAIT_SEND)) waitSend<0, 0>(0, realSize*sizeof(T));
|
||||
subBarrier();
|
||||
if (SEND) {
|
||||
#pragma unroll
|
||||
for (int j=0; j<nsend; j++) {
|
||||
int i = (j+shift)%nsend;
|
||||
int peerOffset = i*peerElem + offset;
|
||||
if (skip >=0 && i >= skip) peerOffset += peerElem;
|
||||
const T* src0 = srcPtr + peerOffset;
|
||||
int realPeerSize = min(realSize, totalElem-peerOffset);
|
||||
if (realPeerSize > 0) ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, 1>(tid, nworkers, 1, &src0, 1, dsts+i, realPeerSize);
|
||||
}
|
||||
} else if (RECV) {
|
||||
#pragma unroll
|
||||
for (int j=0; j<nrecv; j++) {
|
||||
int i = (j+shift)%nrecv;
|
||||
int peerOffset = i*peerElem + offset;
|
||||
if (skip >= 0 && i >= skip) peerOffset += peerElem;
|
||||
T* dst0 = dstPtr + peerOffset;
|
||||
int realPeerSize = min(realSize, totalElem-peerOffset);
|
||||
if (realPeerSize > 0) ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, 1>(tid, nworkers, 1, srcs+i, 1, &dst0, realPeerSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
if (SEND && (role & ROLE_POST_SEND) && realSize > 0 && index == 0) __threadfence_system();
|
||||
__syncwarp();
|
||||
if (SEND && (role & ROLE_POST_SEND)) postSend();
|
||||
if (RECV && (role & ROLE_POST_RECV)) postRecv();
|
||||
offset += realSize;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadRecvConn(struct ncclChannel* channel, T* directBuff) {
|
||||
if (role & (ROLE_WAIT_RECV|ROLE_POST_RECV)) {
|
||||
// For oneshot: groups 0,2 use conn 0, groups 4,6 use conn 1
|
||||
const int connIndex = (NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? group/4 : 0;
|
||||
conn = &channel->devPeers[peer].recv[connIndex].conn;
|
||||
step = conn->step;
|
||||
step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS);
|
||||
if (role & ROLE_POST_RECV) {
|
||||
connHeadPtr = conn->head;
|
||||
// Return credits in case we rounded up.
|
||||
*connHeadPtr = step;
|
||||
}
|
||||
if (role & ROLE_WAIT_RECV) {
|
||||
buff = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
|
||||
if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) {
|
||||
direct = directBuff;
|
||||
*conn->ptrExchange = directBuff;
|
||||
}
|
||||
connTailPtr = conn->tail;
|
||||
connTailCache = *connTailPtr;
|
||||
connPtrsFifoPtr = conn->ptrsFifo;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadSendConn(struct ncclChannel* channel) {
|
||||
if (role & (ROLE_WAIT_SEND|ROLE_POST_SEND)) {
|
||||
// For oneshot: groups 0,2 use conn 0, groups 4,6 use conn 1
|
||||
const int connIndex = (NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? group/4 : 0;
|
||||
conn = &channel->devPeers[peer].send[connIndex].conn;
|
||||
step = conn->step;
|
||||
step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS);
|
||||
if (role & ROLE_POST_SEND) {
|
||||
connTailPtr = conn->tail;
|
||||
}
|
||||
if (role & ROLE_WAIT_SEND) {
|
||||
buff = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
|
||||
if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) {
|
||||
void* volatile* ptr = conn->ptrExchange;
|
||||
while ((direct = (T*)(*ptr)) == NULL) { if (checkAbort()) break; }
|
||||
*ptr = NULL;
|
||||
}
|
||||
connHeadPtr = conn->head;
|
||||
connHeadCache = *connHeadPtr;
|
||||
connSizesFifoPtr = conn->sizesFifo;
|
||||
connPtrsFifoPtr = conn->ptrsFifo;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void saveSync() {
|
||||
if (role & (ROLE_POST_SEND|ROLE_POST_RECV)) {
|
||||
conn->step = step;
|
||||
__threadfence_system();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ __forceinline__
|
||||
ncclPrimitives(const int tid, const int nworkers, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, struct ncclShmemPtrs* ptrs, int group)
|
||||
: comm(comm), tid(tid), nworkers(nworkers), stepSize(stepSize), srcs((const T**)ptrs[group].srcs), dsts((T**)ptrs[group].dsts), group(group) {
|
||||
nthreads = nworkers;
|
||||
// For send operations, we need an extra warp to overlap the threadfence and the copy
|
||||
int postThreads = NSEND && nworkers >= 64 ? WARP_SIZE : 0;
|
||||
nthreads += postThreads;
|
||||
|
||||
// Make sure step is updated before we read it.
|
||||
barrier();
|
||||
|
||||
for (int i=0; i<NRECV; i++) if (recvPeers[i] != -1) nrecv++;
|
||||
for (int i=0; i<NSEND; i++) if (sendPeers[i] != -1) nsend++;
|
||||
|
||||
#define SYNC_GROUP 8
|
||||
static_assert(NSEND < SYNC_GROUP && NRECV < SYNC_GROUP, "Not enough threads to cover all peers");
|
||||
|
||||
int g = tid / SYNC_GROUP;
|
||||
int ng = nthreads / SYNC_GROUP;
|
||||
index = tid % SYNC_GROUP;
|
||||
|
||||
if (g == 0) {
|
||||
if (index < nrecv) role |= ROLE_WAIT_RECV;
|
||||
if (index == nrecv) role |= ROLE_SRC;
|
||||
} else if (g == 1) {
|
||||
if (index < nsend) role |= ROLE_WAIT_SEND;
|
||||
if (index == nsend) role |= ROLE_DST;
|
||||
} else if (g == ng - 2) {
|
||||
if (index < nrecv) role |= ROLE_POST_RECV;
|
||||
} else if (g == ng - 1) {
|
||||
if (index < nsend) role |= ROLE_POST_SEND;
|
||||
}
|
||||
|
||||
if (role & (ROLE_WAIT_RECV|ROLE_POST_RECV)) peer = recvPeers[index];
|
||||
if (role & (ROLE_WAIT_SEND|ROLE_POST_SEND)) peer = sendPeers[index];
|
||||
|
||||
loadRecvConn(channel, directBuff);
|
||||
loadSendConn(channel);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
send(const T* src, int nelem) {
|
||||
GenericOp<0, 0, 0, 1, 1, 0>(src, NULL, nelem, 0);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directSend(const T* src, ssize_t directOffset, int nelem) {
|
||||
GenericOp<0, 1, 0, 1, 1, 0>(src, NULL, nelem, directOffset);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
recv(T* dst, int nelem) {
|
||||
GenericOp<0, 0, 1, 0, 0, 1>(NULL, dst, nelem, 0);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directRecv(T* dst, ssize_t directOffset, int nelem) {
|
||||
GenericOp<1, 0, 1, 0, 0, 1>(NULL, dst, nelem, directOffset);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
copySend(const T* src, T* dst, int nelem) {
|
||||
GenericOp<0, 0, 0, 1, 1, 1>(src, dst, nelem, 0);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directCopySend(const T* src, T* dst, ssize_t directOffset, int nelem) {
|
||||
GenericOp<0, 1, 0, 1, 1, 1>(src, dst, nelem, directOffset);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
recvCopySend(T* dst, int nelem) {
|
||||
GenericOp<0, 0, 1, 1, 0, 1>(NULL, dst, nelem, 0);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directRecvCopySend(T* dst, ssize_t directOffset, int nelem) {
|
||||
GenericOp<1, 1, 1, 1, 0, 1>(NULL, dst, nelem, directOffset);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
recvReduceCopy(const T* src, T* dst, int nelem) {
|
||||
GenericOp<0, 0, 1, 0, 1, 1>(src, dst, nelem, 0);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
recvReduceSend(const T* src, int nelem) {
|
||||
GenericOp<0, 0, 1, 1, 1, 0>(src, NULL, nelem, 0);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
recvReduceCopySend(const T* src, T* dst, int nelem) {
|
||||
GenericOp<0, 0, 1, 1, 1, 1>(src, dst, nelem, 0);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directRecvReduceCopySend(const T* src, T* dst, ssize_t directOffset, int nelem) {
|
||||
// Direct is only for the send part
|
||||
GenericOp<0, 1, 1, 1, 1, 1>(src, dst, nelem, directOffset);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
scatter(const T* src, int totalElem, int peerElem, int skip, int shift) {
|
||||
ScatterGatherOp<0, 1>(src, NULL, totalElem, peerElem, skip, shift);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
gather(T* dst, int totalElem, int peerElem, int skip, int shift) {
|
||||
ScatterGatherOp<1, 0>(NULL, dst, totalElem, peerElem, skip, shift);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ ~ncclPrimitives() {
|
||||
// Save steps for the next operation
|
||||
saveSync();
|
||||
// Group width is how many consecutive group values a subchannel occupies.
|
||||
static constexpr int MaxGroupWidth = 2;
|
||||
__device__ static int calcGroupWidth(bool send, int nthreads) {
|
||||
return send && nthreads-WARP_SIZE >= 64 ? 2 : 1;
|
||||
}
|
||||
};
|
||||
|
||||
#include "prims_ll.h"
|
||||
//#include "prims_ll128.h"
|
||||
struct ProtoLL {
|
||||
static constexpr int Id = NCCL_PROTO_LL;
|
||||
|
||||
// Data bytes (no flags etc) in one step of the fifo queue.
|
||||
__device__ static int calcBytePerStep() {
|
||||
return ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/2; // Half is data
|
||||
}
|
||||
// Granularity of data bytes transferred per thread.
|
||||
__device__ static int calcBytePerGrain() {
|
||||
return sizeof(uint64_t); // One 16-byte line has 8 bytes of data
|
||||
}
|
||||
// Group width is how many consecutive group values a subchannel occupies.
|
||||
static constexpr int MaxGroupWidth = 1;
|
||||
__device__ static int calcGroupWidth(bool send, int nthreads) {
|
||||
return 1;
|
||||
}
|
||||
};
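// Layout sketch (field names taken from the readLL/storeLL code in this patch;
// shown only to illustrate why "half is data" above): each 16-byte fifo line
// interleaves 8 bytes of payload with two 4-byte flags.
//   union ncclLLFifoLine {
//     struct { uint32_t data1, flag1, data2, flag2; };
//     int4 i4;   // used for the 16-byte volatile ld/st
//     // (other views omitted)
//   };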
struct ProtoLL128 {
|
||||
static constexpr int Id = NCCL_PROTO_LL128;
|
||||
|
||||
// Data bytes (no flags etc) in one step of the fifo queue.
|
||||
__device__ static int calcBytePerStep() {
|
||||
return (ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS)*NCCL_LL128_DATAELEMS/NCCL_LL128_LINEELEMS;
|
||||
}
|
||||
// Granularity of data bytes transferred per thread.
|
||||
__device__ static int calcBytePerGrain() {
|
||||
return NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_DATAELEMS*sizeof(uint64_t)/NCCL_LL128_LINEELEMS;
|
||||
}
|
||||
// Group width is how many consecutive group values a subchannel occupies.
|
||||
static constexpr int MaxGroupWidth = 1;
|
||||
__device__ static int calcGroupWidth(bool send, int nthreads) {
|
||||
return 1;
|
||||
}
|
||||
};
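// Worked example (assuming NCCL_LL128_LINEELEMS == 16 and
// NCCL_LL128_DATAELEMS == 15, i.e. one flag word per 128-byte line):
// calcBytePerStep() scales the raw per-step size by 15/16, so LL128 carries
// 120 payload bytes per 128 wire bytes (~94% efficiency) versus 50% for LL.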
/* Fan (as in fan-in & fan-out) classes hold recv and send counts. The template
* arguments are static bounds on the maximum values. Asymmetric counts are
* independent. Symmetric is a static guarantee that nrecv==nsend, so it only
* stores one value at runtime. This optimization saves a 32-bit register, but more
* importantly uses fewer predicate registers when unrolling loops.
*/
|
||||
template<int MaxRecv_, int MaxSend_>
|
||||
struct FanAsymmetric {
|
||||
static constexpr int MaxRecv = MaxRecv_, MaxSend = MaxSend_;
|
||||
int nr, ns;
|
||||
FanAsymmetric() = default;
|
||||
__device__ FanAsymmetric(int nrecv, int nsend): nr(nrecv), ns(nsend) {
|
||||
// assert(nrecv <= MaxRecv && nsend <= MaxSend);
|
||||
}
|
||||
__device__ int nrecv() const { return MaxRecv ? nr : 0; }
|
||||
__device__ int nsend() const { return MaxSend ? ns : 0; }
|
||||
};
|
||||
|
||||
template<int MaxArity>
|
||||
struct FanSymmetric {
|
||||
static constexpr int MaxRecv = MaxArity, MaxSend = MaxArity;
|
||||
int n;
|
||||
FanSymmetric() = default;
|
||||
__device__ FanSymmetric(int nrecv, int nsend): n(nrecv) {
|
||||
// assert(nrecv == nsend && nrecv <= MaxArity);
|
||||
}
|
||||
__device__ int nrecv() const { return n; }
|
||||
__device__ int nsend() const { return n; }
|
||||
};
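// Usage sketch (illustration only; the bound shown for the asymmetric case is
// hypothetical):
//   FanSymmetric<1>     ringFan(1, 1);      // ring: one recv peer, one send peer
//   FanAsymmetric<3, 1> fanInOnly(nUp, 1);  // independent recv/send bounds
// The Primitives specializations below take one of these as their Fan template
// argument and query fan.nrecv()/fan.nsend() at runtime.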
// The primitives class. Specialized per protocol in the other headers.
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, typename Proto>
|
||||
class Primitives;
|
||||
|
||||
// Used by LL & LL128 to implement direct members in the naive way.
|
||||
template<typename RealPrimitives>
|
||||
struct PrimitivesWithoutDirect {
|
||||
__device__ void directSend(intptr_t inpIx, intptr_t remoteOutIx, int eltN) {
|
||||
static_cast<RealPrimitives*>(this)->send(inpIx, eltN);
|
||||
}
|
||||
__device__ void directSendFromOutput(intptr_t outIx, intptr_t remoteOutIx, int eltN) {
|
||||
static_cast<RealPrimitives*>(this)->sendFromOutput(outIx, eltN);
|
||||
}
|
||||
__device__ void directRecv(intptr_t outIx, int eltN) {
|
||||
static_cast<RealPrimitives*>(this)->recv(outIx, eltN, /*postOp=*/false);
|
||||
}
|
||||
__device__ void directCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) {
|
||||
static_cast<RealPrimitives*>(this)->copySend(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void directRecvCopySend(intptr_t outIx, intptr_t remoteOutIx, int eltN) {
|
||||
static_cast<RealPrimitives*>(this)->recvCopySend(outIx, eltN, /*postOp=*/false);
|
||||
}
|
||||
__device__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) {
|
||||
// Direct is only for the send part
|
||||
static_cast<RealPrimitives*>(this)->recvReduceCopySend(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
};
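// CRTP sketch (illustration only) of how the forwarding above is consumed: a
// protocol with no real "direct" path inherits these members and they simply
// call back into its plain implementations.
//   template<typename T, typename RedOp, typename Fan, int Direct>
//   class Primitives<T, RedOp, Fan, Direct, ProtoLL> :
//     public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL>> {
//     // implements send()/recv()/...; directSend() etc. come from the base
//   };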
#include "prims_simple.h"
|
||||
#include "prims_ll.h"
|
||||
#include "prims_ll128.h"
|
||||
#endif
|
||||
|
@ -4,15 +4,20 @@
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
template <typename T, class FUNC, int NRECV, int NSEND>
|
||||
class ncclLLPrimitives {
|
||||
private:
|
||||
template<typename T, typename RedOp, typename Fan, int Direct>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL>> {
|
||||
|
||||
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1;
|
||||
RedOp redOp;
|
||||
const int tid;
|
||||
const int nthreads;
|
||||
const int wid;
|
||||
const int group;
|
||||
const int stepLines;
|
||||
int nrecv = 0;
|
||||
int nsend = 0;
|
||||
Fan fan;
|
||||
T *userBufs[2];
|
||||
struct ncclConnInfo* recvConn = NULL;
|
||||
volatile uint64_t* recvConnHeadPtr = NULL;
|
||||
uint64_t recvConnHead;
|
||||
@ -23,11 +28,10 @@ class ncclLLPrimitives {
|
||||
uint64_t sendConnHead;
|
||||
uint64_t sendConnHeadCache; // Cache last seen value
|
||||
|
||||
uint64_t recvStep[NRECV];
|
||||
uint64_t sendStep[NSEND];
|
||||
union ncclLLFifoLine* recvBuff[NRECV];
|
||||
union ncclLLFifoLine* sendBuff[NSEND];
|
||||
struct ncclDevComm* comm;
|
||||
uint64_t recvStep[MaxRecv];
|
||||
uint64_t sendStep[MaxSend];
|
||||
union ncclLLFifoLine* recvBuff[MaxRecv];
|
||||
union ncclLLFifoLine* sendBuff[MaxSend];
|
||||
|
||||
inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepLines; }
|
||||
inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepLines; }
|
||||
@ -37,27 +41,26 @@ class ncclLLPrimitives {
|
||||
inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); }
|
||||
|
||||
inline __device__ void barrier() {
|
||||
asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
|
||||
asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(1+group));
|
||||
}
|
||||
|
||||
uint32_t spins = 0;
|
||||
uint32_t abort = 0;
|
||||
|
||||
inline __device__ int checkAbort(int i, int send) {
|
||||
inline __device__ int checkAbort(int &spins, int send) {
|
||||
spins++;
|
||||
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = *(comm->abortFlag);
|
||||
if (abort == 0 && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = *ncclShmem.comm.abortFlag;
|
||||
spins = 0;
|
||||
}
|
||||
return abort;
|
||||
}
|
||||
|
||||
inline __device__ void waitSend(int nbytes) {
|
||||
spins = 0;
|
||||
if (sendConnHeadPtr) {
|
||||
int spins = 0;
|
||||
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
if (checkAbort(wid, 1)) break;
|
||||
if (checkAbort(spins, 1)) break;
|
||||
}
|
||||
if (sendConnFifoPtr) {
|
||||
int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes;
|
||||
@ -85,83 +88,212 @@ class ncclLLPrimitives {
|
||||
sendStep[i]++;
|
||||
}
|
||||
|
||||
__device__ uint64_t readLL(int i, int offset) {
|
||||
__device__ uint64_t readLL(int offset, int i) {
|
||||
union ncclLLFifoLine* src = recvPtr(i) + offset;
|
||||
uint32_t flag = recvFlag(i);
|
||||
uint32_t data1, flag1, data2, flag2;
|
||||
spins = 0;
|
||||
int spins = 0;
|
||||
do {
|
||||
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
|
||||
if (checkAbort(i, 0)) break;
|
||||
asm("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
|
||||
if (checkAbort(spins, 0)) break;
|
||||
} while ((flag1 != flag) || (flag2 != flag));
|
||||
uint64_t val64 = data1 + (((uint64_t)data2) << 32);
|
||||
return val64;
|
||||
}
|
||||
|
||||
template<int BeginIx>
|
||||
__device__ void readLLBeginAll(int offset, ncclLLFifoLine(&line)[MaxRecv]) {
|
||||
#pragma unroll
|
||||
for (int i=BeginIx; i < MaxRecv; i++) {
|
||||
if (i < fan.nrecv()) {
|
||||
union ncclLLFifoLine* src = recvPtr(i) + offset;
|
||||
asm("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(line[i].data1), "=r"(line[i].flag1), "=r"(line[i].data2), "=r"(line[i].flag2) : "l"(&src->i4));
|
||||
}
|
||||
}
|
||||
}
|
||||
__device__ uint64_t readLLFinish(int offset, ncclLLFifoLine(&line)[MaxRecv], int i) {
|
||||
union ncclLLFifoLine* src = recvPtr(i) + offset;
|
||||
uint32_t flag = recvFlag(i);
|
||||
int spins = 0;
|
||||
while (line[i].flag1 != flag || line[i].flag2 != flag) {
|
||||
asm("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(line[i].data1), "=r"(line[i].flag1), "=r"(line[i].data2), "=r"(line[i].flag2) : "l"(&src->i4));
|
||||
if (checkAbort(spins, 0)) break;
|
||||
}
|
||||
uint64_t val64 = line[i].data1 + (((uint64_t)line[i].data2) << 32);
|
||||
return val64;
|
||||
}
|
||||
|
||||
__device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) {
|
||||
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag));
|
||||
}
|
||||
|
||||
// Using memcpy handles misaligned pointers.
|
||||
__device__ uint64_t readAL(uint64_t* src) {
|
||||
uint64_t val;
|
||||
memcpy((char*)&val, (char*)src, sizeof(uint64_t));
|
||||
return val;
|
||||
static constexpr int EltPerLine = sizeof(uint64_t)/sizeof(T);
|
||||
|
||||
template<typename U>
|
||||
__device__ static U load(U *src) {
|
||||
union {
|
||||
U elt;
|
||||
uint16_t u2;
|
||||
uint32_t u4;
|
||||
uint64_t u8;
|
||||
};
|
||||
if(sizeof(U) == 1)
|
||||
asm("ld.volatile.global.b8 %0,[%1];" : "=r"(u4) : "l"(src));
|
||||
else if(sizeof(U) == 2)
|
||||
asm("ld.volatile.global.b16 %0,[%1];" : "=h"(u2) : "l"(src));
|
||||
else if(sizeof(U) == 4)
|
||||
asm("ld.volatile.global.b32 %0,[%1];" : "=r"(u4) : "l"(src));
|
||||
else
|
||||
asm("ld.volatile.global.b64 %0,[%1];" : "=l"(u8) : "l"(src));
|
||||
return elt;
|
||||
}
|
||||
|
||||
__device__ void storeAL(uint64_t* dst, uint64_t val, uint32_t nbytes) {
|
||||
memcpy((char*)dst, (char*)&val, nbytes);
|
||||
template<typename U>
|
||||
__device__ static void store(U *dst, U val) {
|
||||
union {
|
||||
U elt;
|
||||
uint16_t u2;
|
||||
uint32_t u4;
|
||||
uint64_t u8;
|
||||
};
|
||||
elt = val;
|
||||
if(sizeof(U) == 1)
|
||||
asm("st.volatile.global.b8 [%0],%1;" :: "l"(dst), "r"(u4));
|
||||
else if(sizeof(U) == 2)
|
||||
asm("st.volatile.global.b16 [%0],%1;" :: "l"(dst), "h"(u2));
|
||||
else if(sizeof(U) == 4)
|
||||
asm("st.volatile.global.b32 [%0],%1;" :: "l"(dst), "r"(u4));
|
||||
else
|
||||
asm("st.volatile.global.b64 [%0],%1;" :: "l"(dst), "l"(u8));
|
||||
}
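// Usage sketch (illustration only): these helpers replace the old memcpy-based
// readAL/storeAL with volatile PTX accesses sized to T, e.g.
//   half h = load(srcHalf);   // emits ld.volatile.global.b16
//   store(dstHalf, h);        // emits st.volatile.global.b16
// so element-granular loads/stores cannot be cached in registers or reordered
// away by the compiler.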
template <int RECV, int SEND, int SRC, int DST>
|
||||
__device__ void LLGenericOp(const T* srcPtr, T* dstPtr, int nelem) {
|
||||
uint32_t nbytes = nelem < 0 ? 0 : nelem*sizeof(T);
|
||||
uint32_t npack = DIVUP(nbytes, sizeof(uint64_t));
|
||||
uint64_t* srcPack = (uint64_t*)srcPtr;
|
||||
uint64_t* dstPack = (uint64_t*)dstPtr;
|
||||
int offset = tid;
|
||||
struct DataLoader {
|
||||
int misalign;
|
||||
union {
|
||||
uint32_t u4[sizeof(T) <= 2 ? 3 : 2];
|
||||
uint64_t u8;
|
||||
T elt[EltPerLine];
|
||||
};
|
||||
|
||||
// Always waitSend in case of cleanup
|
||||
if (SEND) waitSend(npack*sizeof(union ncclLLFifoLine));
|
||||
|
||||
// Do multiples of 64 bits
|
||||
#pragma unroll 2
|
||||
for (; offset<npack; offset+=nthreads) {
|
||||
// Recv : local, then intra-node, then inter-node
|
||||
uint64_t val = SRC ? readAL(srcPack+offset) : readLL(0, offset);
|
||||
if (RECV) {
|
||||
if (SRC) val = MULTI<FUNC, T>()(readLL(0, offset), val);
|
||||
for (int i=1; i<NRECV && i<nrecv; i++) {
|
||||
val = MULTI<FUNC, T>()(readLL(i, offset), val);
|
||||
}
|
||||
__device__ void loadBegin(T *src, int eltN) {
|
||||
if (sizeof(T) <= 2) {
|
||||
misalign = reinterpret_cast<uintptr_t>(src)%4;
|
||||
uint32_t *p = reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(src) & -uintptr_t(4));
|
||||
u4[0] = load(p+0);
|
||||
u4[1] = misalign + eltN*sizeof(T) > 4 ? load(p+1) : 0;
|
||||
// u4[2] would be simpler, but that throws warnings on some compilers
|
||||
u4[sizeof(T) <= 2 ? 2 : 0] = misalign + eltN*sizeof(T) > 8 ? load(p+2) : 0;
|
||||
}
|
||||
|
||||
// Send : inter-node, then intra-node, then local
|
||||
if (SEND) {
|
||||
for (int i=1; i<NSEND && i<nsend; i++) storeLL(sendPtr(i)+offset, val, sendFlag(i));
|
||||
storeLL(sendPtr(0)+offset, val, sendFlag(0));
|
||||
}
|
||||
if (DST) {
|
||||
if (((offset*sizeof(uint64_t)) ^ nbytes) < sizeof(uint64_t)) {
|
||||
// Last incomplete word
|
||||
storeAL(dstPack+offset, val, nbytes & 0x7);
|
||||
} else {
|
||||
storeAL(dstPack+offset, val, sizeof(uint64_t));
|
||||
else {
|
||||
#pragma unroll
|
||||
for(int i=0; i < EltPerLine; i++) {
|
||||
if(i==0 || i < eltN)
|
||||
elt[i] = load(src + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
FOR_RECV(incRecv); if (RECV) postRecv();
|
||||
FOR_SEND(incSend, offset);
|
||||
|
||||
__device__ uint64_t loadFinish() {
|
||||
if (sizeof(T) <= 2) {
|
||||
u4[0] = __funnelshift_r(u4[0], u4[1], 8*misalign);
|
||||
// u4[2] would be simpler, but that throws warnings on some compilers
|
||||
u4[1] = __funnelshift_r(u4[1], u4[sizeof(T) <= 2 ? 2 : 0], 8*misalign);
|
||||
}
|
||||
return u8;
|
||||
}
|
||||
};
|
||||
|
||||
__device__ void storeData(T *dst, uint64_t val, int eltN) {
|
||||
union {
|
||||
uint64_t u8;
|
||||
T elt[EltPerLine];
|
||||
};
|
||||
u8 = val;
|
||||
#pragma unroll
|
||||
for(int i=0; i < EltPerLine; i++) {
|
||||
if (i==0 || i < eltN)
|
||||
//store(dst+i, elt[i]);
|
||||
dst[i] = elt[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <int RECV, int SEND, int SrcBuf, int DstBuf>
|
||||
__device__ void LLGenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
|
||||
constexpr int SRC = SrcBuf != -1 ? 1 : 0;
|
||||
constexpr int DST = DstBuf != -1 ? 1 : 0;
|
||||
T *srcElts = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
|
||||
T *dstElts = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
|
||||
|
||||
// Always waitSend in case of cleanup
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
if (SEND) waitSend(divUp(nelem, EltPerLine)*sizeof(ncclLLFifoLine));
|
||||
|
||||
nelem -= tid*EltPerLine;
|
||||
srcElts += tid*EltPerLine;
|
||||
dstElts += tid*EltPerLine;
|
||||
int offset = tid;
|
||||
int eltPerTrip = nthreads*EltPerLine;
|
||||
while (nelem > 0) {
|
||||
int eltInLine = EltPerLine < nelem ? EltPerLine : nelem;
|
||||
|
||||
DataLoader dl;
|
||||
ncclLLFifoLine line[MaxRecv];
|
||||
uint64_t data, peerData;
|
||||
if (SRC) {
|
||||
dl.loadBegin(srcElts, eltInLine);
|
||||
srcElts += eltPerTrip;
|
||||
}
|
||||
if (RECV) {
|
||||
readLLBeginAll<1>(offset, line);
|
||||
peerData = readLL(offset, 0);
|
||||
}
|
||||
if (SRC) {
|
||||
data = dl.loadFinish();
|
||||
if (SrcBuf == Input) data = MULTI<RedOp, T>().preOp(redOp, data);
|
||||
}
|
||||
if (RECV) {
|
||||
data = !SRC ? peerData : MULTI<RedOp,T>()(redOp, peerData, data);
|
||||
#pragma unroll MaxRecv
|
||||
for (int i=1; i < MaxRecv && i < fan.nrecv(); i++) {
|
||||
peerData = readLLFinish(offset, line, i);
|
||||
data = MULTI<RedOp,T>()(redOp, peerData, data);
|
||||
}
|
||||
}
|
||||
|
||||
if (postOp) data = MULTI<RedOp, T>().postOp(redOp, data);
|
||||
|
||||
// Send : inter-node, then intra-node, then local
|
||||
if (SEND) {
|
||||
for (int i=1; i < MaxSend && i < fan.nsend(); i++)
|
||||
storeLL(sendPtr(i)+offset, data, sendFlag(i));
|
||||
storeLL(sendPtr(0)+offset, data, sendFlag(0));
|
||||
}
|
||||
if (DST) {
|
||||
storeData(dstElts, data, eltInLine);
|
||||
dstElts += eltPerTrip;
|
||||
}
|
||||
nelem -= eltPerTrip;
|
||||
offset += nthreads;
|
||||
}
|
||||
|
||||
if (RECV) {
|
||||
for (int i=0; i < MaxRecv; i++) incRecv(i);
|
||||
postRecv();
|
||||
}
|
||||
if (SEND) {
|
||||
for (int i=1; i < MaxSend && i < fan.nsend(); i++)
|
||||
incSend(i, offset);
|
||||
incSend(0, offset);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
|
||||
recvBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL];
|
||||
recvStep[i] = conn->step;
|
||||
if (wid == i) recvConn = conn;
|
||||
nrecv++;
|
||||
}
|
||||
__device__ __forceinline__ void loadRecvSync() {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv()) {
|
||||
recvConnHeadPtr = recvConn->head;
|
||||
recvConnHead = recvConn->step;
|
||||
}
|
||||
@ -171,10 +303,9 @@ class ncclLLPrimitives {
|
||||
sendBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL];
|
||||
sendStep[i] = conn->step;
|
||||
if (wid == i) sendConn = conn;
|
||||
nsend++;
|
||||
}
|
||||
__device__ __forceinline__ void loadSendSync() {
|
||||
if (tid < nsend) {
|
||||
if (tid < fan.nsend()) {
|
||||
sendConnHeadPtr = sendConn->head;
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
sendConnHead = sendConn->step;
|
||||
@ -182,65 +313,74 @@ class ncclLLPrimitives {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void saveRecvSync() {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
recvConn->step = recvConnHead;
|
||||
__threadfence_block();
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void saveSendSync() {
|
||||
if (tid < nsend) {
|
||||
sendConn->step = sendConnHead;
|
||||
__threadfence_block();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ __forceinline__
|
||||
ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm)
|
||||
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines) {
|
||||
// Make sure step is updated before we read it.
|
||||
barrier();
|
||||
__device__ Primitives(
|
||||
const int tid, const int nthreads, int const *recvPeers, int const *sendPeers,
|
||||
void const *inputBuf, void *outputBuf, int group=0
|
||||
):
|
||||
redOp(FuncTraits<RedOp>().make(ncclShmem.comm.nRanks)),
|
||||
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group),
|
||||
stepLines(ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)) {
|
||||
|
||||
auto *channel = &ncclShmem.channel;
|
||||
// If we are going to support oneshot collNet + LL, then we would need to add connector index here
|
||||
for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv->conn, i);
|
||||
for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send->conn, i);
|
||||
int nrecv=0, nsend=0;
|
||||
while (nrecv < MaxRecv && recvPeers[nrecv] >= 0) {
|
||||
loadRecvConn(&channel->devPeers[recvPeers[nrecv]].recv->conn, nrecv);
|
||||
nrecv++;
|
||||
}
|
||||
while (nsend < MaxSend && sendPeers[nsend] >= 0) {
|
||||
loadSendConn(&channel->devPeers[sendPeers[nsend]].send->conn, nsend);
|
||||
nsend++;
|
||||
}
|
||||
this->fan = Fan(nrecv, nsend);
|
||||
loadRecvSync();
|
||||
loadSendSync();
|
||||
setDataPtrs(inputBuf, outputBuf);
|
||||
}
|
||||
|
||||
__device__ void send(const T* src, int nelem) {
|
||||
return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem);
|
||||
}
|
||||
|
||||
__device__ void recv(T* dst, int nelem) {
|
||||
return LLGenericOp<1, 0, 0, 1>(NULL, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceSend(const T* src, int nelem) {
|
||||
return LLGenericOp<1, 1, 1, 0>(src, NULL, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceCopy(const T* src, T* dst, int nelem) {
|
||||
return LLGenericOp<1, 0, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void copySend(const T* src, T* dst, int nelem) {
|
||||
return LLGenericOp<0, 1, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvCopySend(T* dst, int nelem) {
|
||||
return LLGenericOp<1, 1, 0, 1>(NULL, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceCopySend(const T* src, T* dst, int nelem) {
|
||||
return LLGenericOp<1, 1, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ ~ncclLLPrimitives() {
|
||||
__device__ ~Primitives() {
|
||||
// Save steps for the next operation
|
||||
saveRecvSync();
|
||||
saveSendSync();
|
||||
if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv())
|
||||
recvConn->step = recvConnHead;
|
||||
if (tid < fan.nsend())
|
||||
sendConn->step = sendConnHead;
|
||||
// Ensure all steps written back
|
||||
barrier();
|
||||
}
|
||||
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
|
||||
userBufs[Input] = (T*)inputBuf;
|
||||
userBufs[Output] = (T*)outputBuf;
|
||||
}
|
||||
|
||||
__device__ void moveDataPtrs(intptr_t delta) {
|
||||
userBufs[Input] += delta;
|
||||
userBufs[Output] += delta;
|
||||
}
|
||||
|
||||
__device__ void send(intptr_t inpIx, int eltN) {
|
||||
return LLGenericOp<0, 1, Input, -1>(inpIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void sendFromOutput(intptr_t outIx, int eltN) {
|
||||
return LLGenericOp<0, 1, Output, -1>(outIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void recv(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return LLGenericOp<1, 0, -1, Output>(-1, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvReduceSend(intptr_t inpIx, int eltN) {
|
||||
return LLGenericOp<1, 1, Input, -1>(inpIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return LLGenericOp<1, 0, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return LLGenericOp<0, 1, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return LLGenericOp<1, 1, -1, Output>(-1, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return LLGenericOp<1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
};
|
||||
|
@ -8,17 +8,22 @@
|
||||
|
||||
#define NCCL_LL128_FLAGTHREAD (NCCL_LL128_LINEELEMS-1)
|
||||
|
||||
template <typename T, class FUNC, int NRECV, int NSEND>
|
||||
class ncclLL128Primitives {
|
||||
private:
|
||||
template<typename T, typename RedOp, typename Fan, int Direct>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL128>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL128>> {
|
||||
|
||||
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1;
|
||||
RedOp redOp;
|
||||
const int tid;
|
||||
const int nthreads;
|
||||
const int wid;
|
||||
const int stepSize;
|
||||
const int warp;
|
||||
const bool flagThread;
|
||||
int nrecv = 0;
|
||||
int nsend = 0;
|
||||
const int group;
|
||||
Fan fan;
|
||||
T *userBufs[2];
|
||||
struct ncclConnInfo* recvConn = NULL;
|
||||
volatile uint64_t* recvConnHeadPtr = NULL;
|
||||
uint64_t recvConnHead;
|
||||
@ -31,13 +36,10 @@ class ncclLL128Primitives {
|
||||
uint64_t sendConnHead;
|
||||
uint64_t sendConnHeadCache; // Cache last seen value
|
||||
|
||||
uint64_t recvStep[NRECV];
|
||||
uint64_t sendStep[NSEND];
|
||||
uint64_t* recvBuff[NRECV];
|
||||
uint64_t* sendBuff[NSEND];
|
||||
struct ncclDevComm* comm;
|
||||
|
||||
volatile uint64_t* shmem;
|
||||
uint64_t recvStep[MaxRecv];
|
||||
uint64_t sendStep[MaxSend];
|
||||
uint64_t* recvBuff[MaxRecv];
|
||||
uint64_t* sendBuff[MaxSend];
|
||||
|
||||
inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepSize; }
|
||||
inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepSize; }
|
||||
@ -47,31 +49,26 @@ class ncclLL128Primitives {
|
||||
inline __device__ uint64_t sendFlag(int i) { return sendStep[i]+1; }
|
||||
|
||||
inline __device__ void barrier() {
|
||||
if (NSEND>NRECV) {
|
||||
asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
|
||||
} else {
|
||||
asm volatile ("bar.sync 2, %0;" :: "r"(nthreads));
|
||||
}
|
||||
asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(1+group));
|
||||
}
|
||||
|
||||
uint32_t spins = 0;
|
||||
uint32_t abort = 0;
|
||||
|
||||
inline __device__ int checkAbort(int i, int send) {
|
||||
inline __device__ int checkAbort(int &spins, int i, int send) {
|
||||
spins++;
|
||||
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = *(comm->abortFlag);
|
||||
if (abort == 0 && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = *ncclShmem.comm.abortFlag;
|
||||
spins = 0;
|
||||
}
|
||||
return abort;
|
||||
}
|
||||
|
||||
inline __device__ void waitSend(int nbytes) {
|
||||
spins = 0;
|
||||
if (sendConnHeadPtr) {
|
||||
int spins = 0;
|
||||
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
if (checkAbort(wid, 1)) break;
|
||||
if (checkAbort(spins, wid, 1)) break;
|
||||
}
|
||||
if (sendConnFifoPtr) {
|
||||
sendConnFifoPtr[sendStep[wid]%NCCL_STEPS] = nbytes;
|
||||
@ -80,137 +77,185 @@ class ncclLL128Primitives {
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void incRecv(int i) {
|
||||
recvStep[i] += 1;
|
||||
}
|
||||
inline __device__ void postRecv() {
|
||||
if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1;
|
||||
}
|
||||
|
||||
inline __device__ void incSend(int i) {
|
||||
sendStep[i] += 1;
|
||||
}
|
||||
inline __device__ void postSend() {
|
||||
if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += 1; }
|
||||
}
|
||||
|
||||
template <int ELEMS_PER_THREAD>
|
||||
inline __device__ void loadSrcToShmem128(int maxOffset, const uint64_t* src64Ptr) {
|
||||
#if 0
|
||||
uint64_t v[ELEMS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
if (u*WARP_SIZE < maxOffset) load128(src64Ptr+u*WARP_SIZE, v[u], v[u+1]);
|
||||
}
|
||||
uint64_t* shmemAsmPtr = shmemCvtPtr(shmem);
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
storeShmem128(shmemAsmPtr+u*WARP_SIZE, v[u], v[u+1]);
|
||||
}
|
||||
#else
|
||||
uint64_t* shmemAsmPtr = shmemCvtPtr(shmem);
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
if (u*WARP_SIZE < maxOffset) {
|
||||
uint64_t v0, v1;
|
||||
load128(src64Ptr+u*WARP_SIZE, v0, v1);
|
||||
storeShmem128(shmemAsmPtr+u*WARP_SIZE, v0, v1);
|
||||
template<int WordPerThread>
|
||||
__device__ __forceinline__ void loadRegsBegin(uint64_t(®s)[WordPerThread], T const *src, int eltN) {
|
||||
constexpr int EltPer16B = 16/sizeof(T);
|
||||
if(reinterpret_cast<uintptr_t>(src)%16 == 0) {
|
||||
/* We are aligned to 16 bytes, so load directly to registers, no shmem staging.
* Flag threads load half as much data, which gets shuffled into the even
* registers during Finish. The point of splitting into two phases is to
* defer that shuffle, which incurs a dependency stall, until after other
* memops have been launched by the caller.
*/
|
||||
#pragma unroll
|
||||
for(int g=0; g < WordPerThread/2; g++) {
|
||||
int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
|
||||
if(!flagThread || g%2==0) {
|
||||
if(ix*EltPer16B < eltN)
|
||||
load128((uint64_t*)(src + ix*EltPer16B), regs[2*g+0], regs[2*g+1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else {
|
||||
// Not aligned. Stage the smallest 16 byte aligned region subsuming the
|
||||
// buffer into shmem.
|
||||
int misalignment = reinterpret_cast<uintptr_t>(src) % 16;
|
||||
uint64_t *src8 = reinterpret_cast<uint64_t*>(reinterpret_cast<uintptr_t>(src) & -uintptr_t(16));
|
||||
uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warp]);
|
||||
#pragma unroll
|
||||
for(int g=0; g < WordPerThread/2; g++)
|
||||
if((g*WARP_SIZE + wid)*16 < misalignment + eltN*sizeof(T))
|
||||
load128(src8 + 2*(g*WARP_SIZE + wid), regs[2*g+0], regs[2*g+1]);
|
||||
#pragma unroll
|
||||
for(int g=0; g < WordPerThread/2; g++)
|
||||
storeShmem128(shm8 + 2*(g*WARP_SIZE + wid), regs[2*g+0], regs[2*g+1]);
|
||||
|
||||
inline __device__ void loadSrcToShmem(int start, int end, const T* srcPtr) {
|
||||
T* shmemPtr = (T*)(shmem-2*wid);
|
||||
for (int offset = start+wid; offset < end; offset += WARP_SIZE) {
|
||||
shmemPtr[offset] = srcPtr[offset];
|
||||
__syncwarp();
|
||||
|
||||
// Now load from shmem stage to regs. Preserve the same pre-shuffled layout
|
||||
// as the aligned case since Finish() will be applied regardless.
|
||||
T *shm = (T*)shm8 + misalignment/sizeof(T);
|
||||
#pragma unroll
|
||||
for(int g=0; g < WordPerThread/2; g++) {
|
||||
int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
|
||||
if(!flagThread || g%2==0) {
|
||||
if(ix*EltPer16B < eltN)
|
||||
loadShmemMisaligned128(shm + ix*EltPer16B, regs[2*g+0], regs[2*g+1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int ELEMS_PER_THREAD>
|
||||
inline __device__ void storeShmemToDst128(int maxOffset, uint64_t* dst64Ptr) {
|
||||
uint64_t v[ELEMS_PER_THREAD];
|
||||
uint64_t* shmemAsmPtr = shmemCvtPtr(shmem);
|
||||
template<int WordPerThread>
|
||||
__device__ __forceinline__ void loadRegsFinish(uint64_t(®s)[WordPerThread]) {
|
||||
// Move data out of flag registers into the vacant registers.
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
loadShmem128(shmemAsmPtr+u*WARP_SIZE, v[u], v[u+1]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
if (u*WARP_SIZE < maxOffset) store128(dst64Ptr+u*WARP_SIZE, v[u], v[u+1]);
|
||||
for (int g=1; g < WordPerThread/2; g+=2) {
|
||||
if (flagThread) regs[2*g] = regs[2*g-1];
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void storeShmemToDst(int start, int end, T* dstPtr) {
|
||||
T* shmemPtr = (T*)(shmem-2*wid);
|
||||
for (int offset = start+wid; offset < end; offset += WARP_SIZE) {
|
||||
dstPtr[offset] = shmemPtr[offset];
|
||||
template<int WordPerThread>
|
||||
__device__ __forceinline__ void storeRegs(T *dst, uint64_t(®s)[WordPerThread], int eltN) {
|
||||
constexpr int EltPer16B = 16/sizeof(T);
|
||||
// Reverse the Finish() register permutation.
|
||||
#pragma unroll
|
||||
for (int g=1; g < WordPerThread/2; g+=2) {
|
||||
if (flagThread) regs[2*g-1] = regs[2*g];
|
||||
}
|
||||
// Write to dst if 16-byte aligned, shmem otherwise.
|
||||
int misalignment = reinterpret_cast<uintptr_t>(dst)%16;
|
||||
uint64_t *shm8 = shmemCvtPtr(ncclShmem.ll128warp[warp]);
|
||||
#pragma unroll
|
||||
for(int g=0; g < WordPerThread/2; g++) {
|
||||
int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
|
||||
if (!flagThread || g%2==0) {
|
||||
if(misalignment == 0 && (ix+1)*EltPer16B <= eltN)
|
||||
store128((uint64_t*)(dst + ix*EltPer16B), regs[2*g+0], regs[2*g+1]);
|
||||
else
|
||||
storeShmem128(shm8+2*ix, regs[2*g+0], regs[2*g+1]);
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
// Write the rest from shmem to dst. No need to coalesce stores to 16 bytes,
|
||||
// the hardware keeps up fine.
|
||||
T *shm = (T*)ncclShmem.ll128warp[warp];
|
||||
int skip = misalignment == 0 ? eltN & -EltPer16B : 0;
|
||||
for(int i=skip+wid; i < eltN; i += WARP_SIZE)
|
||||
dst[i] = shm[i];
|
||||
}
|
||||
|
||||
#define WARP_MASK 0xffffffff
|
||||
|
||||
template <int ELEMS_PER_THREAD, int RECV, int SEND, int SRC, int DST>
|
||||
__device__ __forceinline__ void recvReduceSendCopy(int ll128Offset) {
|
||||
uint64_t v[ELEMS_PER_THREAD];
|
||||
template <int ELEMS_PER_THREAD, int RECV, int SEND, int SrcBuf, int DstBuf>
|
||||
__device__ __forceinline__ void recvReduceSendCopy(uint64_t(&v)[ELEMS_PER_THREAD], int ll128Offset, bool postOp) {
|
||||
constexpr int SRC = SrcBuf != -1 ? 1 : 0;
|
||||
uint64_t vr[ELEMS_PER_THREAD];
|
||||
|
||||
/************* Data Loading : SHMEM -> REG **************/
|
||||
if (SRC) {
|
||||
volatile uint64_t* shmem64Ptr = shmem - (2*wid)/NCCL_LL128_LINEELEMS;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
v[u] = shmem64Ptr[u*(WARP_SIZE-2)];
|
||||
if (!flagThread) v[u+1] = shmem64Ptr[u*(WARP_SIZE-2)+1];
|
||||
}
|
||||
}
|
||||
/*********** End Data Loading : SHMEM -> REG ************/
|
||||
|
||||
/************************ Recv **************************/
|
||||
__syncwarp();
|
||||
/************************ Wait first recv ********************/
|
||||
if (RECV) {
|
||||
uint64_t flag = recvFlag(0);
|
||||
uint64_t* ptr = recvPtr(0)+ll128Offset;
|
||||
uint64_t flag = recvFlag(0);
|
||||
bool needReload;
|
||||
uint64_t v0, v1;
|
||||
int spins = 0;
|
||||
do {
|
||||
needReload = false;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
load128(ptr+u*WARP_SIZE, v0, v1);
|
||||
needReload |= flagThread && (v1 != flag);
|
||||
load128(ptr+u*WARP_SIZE, vr[u], vr[u+1]);
|
||||
needReload |= flagThread && (vr[u+1] != flag);
|
||||
}
|
||||
} while (__any_sync(WARP_MASK, needReload) && checkAbort(spins, 0, 0) == 0);
|
||||
}
|
||||
|
||||
/************* Finish register load **************/
|
||||
if (SRC) {
|
||||
// By deferring register shuffle here we've overlapped spinning on first
|
||||
// peer's data with memory loads of src data.
|
||||
loadRegsFinish(v);
|
||||
if (SrcBuf == Input) {
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
v[u] = MULTI<RedOp, T>().preOp(redOp, v[u]);
|
||||
if (!flagThread)
|
||||
v[u+1] = MULTI<RedOp, T>().preOp(redOp, v[u+1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/************************ Recv rest *********************/
|
||||
if (RECV) {
|
||||
{ // Consume data from first recv
|
||||
uint64_t* ptr = recvPtr(0)+ll128Offset;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
v[u] = SRC ? MULTI<RedOp, T>()(redOp, vr[u], v[u]) : vr[u];
|
||||
v[u+1] = SRC ? MULTI<RedOp, T>()(redOp, vr[u+1], v[u+1]) : vr[u+1];
|
||||
}
|
||||
} while (__any_sync(WARP_MASK, needReload) && checkAbort(0, 0) == 0);
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
load128(ptr+u*WARP_SIZE, v0, v1);
|
||||
v[u] = SRC ? MULTI<FUNC, T>()(v0, v[u]) : v0;
|
||||
v[u+1] = SRC ? MULTI<FUNC, T>()(v1, v[u+1]) : v1;
|
||||
}
|
||||
|
||||
for (int i=1; i<NRECV && i<nrecv; i++) {
|
||||
for (int i=1; i<MaxRecv && i<fan.nrecv(); i++) {
|
||||
uint64_t flag = recvFlag(i);
|
||||
uint64_t* ptr = recvPtr(i)+ll128Offset;
|
||||
uint64_t v0, v1;
|
||||
bool needReload;
|
||||
int spins = 0;
|
||||
do {
|
||||
needReload = false;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
load128(ptr+u*WARP_SIZE, v0, v1);
|
||||
needReload |= flagThread && (v1 != flag);
|
||||
load128(ptr+u*WARP_SIZE, vr[u], vr[u+1]);
|
||||
needReload |= flagThread && (vr[u+1] != flag);
|
||||
}
|
||||
} while (__any_sync(WARP_MASK, needReload) && checkAbort(i, 0) == 0);
|
||||
} while (__any_sync(WARP_MASK, needReload) && checkAbort(spins, i, 0) == 0);
|
||||
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
load128(ptr+u*WARP_SIZE, v0, v1);
|
||||
v[u] = MULTI<FUNC, T>()(v0, v[u]);
|
||||
v[u+1] = MULTI<FUNC, T>()(v1, v[u+1]);
|
||||
v[u] = MULTI<RedOp, T>()(redOp, vr[u], v[u]);
|
||||
v[u+1] = MULTI<RedOp, T>()(redOp, vr[u+1], v[u+1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
/********************** End Recv ************************/
|
||||
|
||||
if (postOp && !FuncTraits<RedOp>::IsPostOpIdentity) {
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
v[u] = MULTI<RedOp, T>().postOp(redOp, v[u]);
|
||||
v[u+1] = MULTI<RedOp, T>().postOp(redOp, v[u+1]);
|
||||
}
|
||||
}
|
||||
|
||||
/************************ Send **************************/
|
||||
if (SEND) {
|
||||
for (int i=1; i<NSEND && i<nsend; i++) {
|
||||
for (int i=1; i<MaxSend && i<fan.nsend(); i++) {
|
||||
uint64_t flag = sendFlag(i);
|
||||
uint64_t* ptr = sendPtr(i)+ll128Offset;
|
||||
#pragma unroll
|
||||
@@ -226,82 +271,61 @@ class ncclLL128Primitives {
|
||||
}
|
||||
}
|
||||
/********************** End Send ************************/
|
||||
|
||||
/************* Data Storing : REG -> SHMEM **************/
|
||||
if (DST) {
|
||||
volatile uint64_t* shmem64Ptr = shmem - (2*wid)/NCCL_LL128_LINEELEMS;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
shmem64Ptr[u*(WARP_SIZE-2)] = v[u];
|
||||
if (!flagThread) shmem64Ptr[u*(WARP_SIZE-2)+1] = v[u+1];
|
||||
}
|
||||
}
|
||||
/*********** End data Storing : REG -> SHMEM ************/
|
||||
}
|
||||
|
||||
#define LL128INC (WARP_SIZE*NCCL_LL128_SHMEM_ELEMS_PER_THREAD)
|
||||
#define ELEMINC (LL128INC-(LL128INC/NCCL_LL128_LINEELEMS))
|
||||
static constexpr int WireWordPerSlice = WARP_SIZE*NCCL_LL128_SHMEM_ELEMS_PER_THREAD;
|
||||
static constexpr int DataEltPerSlice = (WireWordPerSlice - WireWordPerSlice/NCCL_LL128_LINEELEMS)*(sizeof(uint64_t)/sizeof(T));
|
||||
|
||||
template <int RECV, int SEND, int SRC, int DST>
|
||||
__device__ void GenericOp(const T* srcPtr, T* dstPtr, int nelem) {
|
||||
if (nelem <= 0) {
|
||||
// Don't move any data but still increase steps and sync with prev/next
|
||||
if (SEND) waitSend(0);
|
||||
FOR_SEND(incSend); if (SEND) postSend();
|
||||
FOR_RECV(incRecv); if (RECV) postRecv();
|
||||
return;
|
||||
}
|
||||
const int nelem64 = ((nelem*sizeof(T))/(2*sizeof(uint64_t)))*2;
|
||||
const uint64_t* src64Ptr = ((uint64_t*)srcPtr);
|
||||
uint64_t* dst64Ptr = ((uint64_t*)dstPtr);
|
||||
template <int RECV, int SEND, int SrcBuf, int DstBuf>
|
||||
__device__ void GenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
|
||||
constexpr int SRC = SrcBuf != -1 ? 1 : 0;
|
||||
constexpr int DST = DstBuf != -1 ? 1 : 0;
|
||||
static_assert(-1<=SrcBuf && SrcBuf < 2, "Uhoh");
|
||||
static_assert(-1<=DstBuf && DstBuf < 2, "Uhoh");
|
||||
static_assert(DstBuf!=Input, "Mistake?");
|
||||
#if 0
|
||||
assert((SrcBuf==-1) == (srcIx==-1));
|
||||
assert((DstBuf==-1) == (dstIx==-1));
|
||||
#endif
|
||||
|
||||
int ll128Offset = LL128INC*warp+2*wid;
|
||||
int elemOffset = ELEMINC*warp;
|
||||
T const *srcPtr = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
|
||||
T *dstPtr = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
|
||||
int wireOffset = WireWordPerSlice*warp + 2*wid;
|
||||
const int nwarps = nthreads/WARP_SIZE;
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
|
||||
if (SEND) waitSend(DIVUP(nelem*sizeof(T), ELEMINC*sizeof(uint64_t))*LL128INC*sizeof(uint64_t));
|
||||
if (SEND) waitSend(divUp(nelem, DataEltPerSlice)*WireWordPerSlice*sizeof(uint64_t));
|
||||
barrier();
|
||||
nelem -= DataEltPerSlice*warp;
|
||||
srcPtr += DataEltPerSlice*warp;
|
||||
dstPtr += DataEltPerSlice*warp;
|
||||
while (nelem > 0) {
|
||||
const int eltInSlice = min(nelem, DataEltPerSlice);
|
||||
uint64_t regs[NCCL_LL128_SHMEM_ELEMS_PER_THREAD];
|
||||
if (SRC) loadRegsBegin(regs, srcPtr, eltInSlice);
|
||||
recvReduceSendCopy<NCCL_LL128_SHMEM_ELEMS_PER_THREAD, RECV, SEND, SrcBuf, DstBuf>(regs, wireOffset, postOp);
|
||||
if (DST) storeRegs(dstPtr, regs, eltInSlice);
|
||||
|
||||
while (elemOffset*(sizeof(uint64_t)/sizeof(T)) < nelem) {
|
||||
const int maxOffset128 = min(nelem64-elemOffset, (int)ELEMINC);
|
||||
const int maxOffset = min(nelem-(elemOffset*((int)(sizeof(uint64_t)/sizeof(T)))), (int)(ELEMINC*(sizeof(uint64_t)/sizeof(T))));
|
||||
if (SRC) {
|
||||
int done = 0;
|
||||
if ((((uint64_t)srcPtr)&0xf) == 0) {
|
||||
loadSrcToShmem128<NCCL_LL128_SHMEM_ELEMS_PER_THREAD>(maxOffset128-2*wid, src64Ptr+elemOffset+2*wid);
|
||||
done = maxOffset128*(sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
loadSrcToShmem(done, maxOffset, (T*)(src64Ptr+elemOffset));
|
||||
}
|
||||
__syncwarp();
|
||||
recvReduceSendCopy<NCCL_LL128_SHMEM_ELEMS_PER_THREAD, RECV, SEND, SRC, DST>(ll128Offset);
|
||||
__syncwarp();
|
||||
if (DST) {
|
||||
int done = 0;
|
||||
if ((((uint64_t)dstPtr)&0xf) == 0) {
|
||||
storeShmemToDst128<NCCL_LL128_SHMEM_ELEMS_PER_THREAD>(maxOffset128-2*wid, dst64Ptr+elemOffset+2*wid);
|
||||
done = maxOffset128*(sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
storeShmemToDst(done, maxOffset, (T*)(dst64Ptr+elemOffset));
|
||||
}
|
||||
__syncwarp();
|
||||
ll128Offset += LL128INC*nwarps;
|
||||
elemOffset += ELEMINC*nwarps;
|
||||
wireOffset += WireWordPerSlice*nwarps;
|
||||
srcPtr += DataEltPerSlice*nwarps;
|
||||
dstPtr += DataEltPerSlice*nwarps;
|
||||
nelem -= DataEltPerSlice*nwarps;
|
||||
}
|
||||
|
||||
barrier();
|
||||
FOR_SEND(incSend); if (SEND) postSend();
|
||||
FOR_RECV(incRecv); if (RECV) postRecv();
|
||||
if (SEND) for (int i=0; i < MaxSend; i++) sendStep[i] += 1;
|
||||
if (SEND) postSend();
|
||||
if (RECV) for (int i=0; i < MaxRecv; i++) recvStep[i] += 1;
|
||||
if (RECV) postRecv();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
|
||||
recvBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128];
|
||||
recvStep[i] = conn->step;
|
||||
if (wid == i) recvConn = conn;
|
||||
nrecv++;
|
||||
}
|
||||
__device__ __forceinline__ void loadRecvSync() {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv()) {
|
||||
recvConnHeadPtr = recvConn->head;
|
||||
recvConnHead = recvConn->step;
|
||||
}
|
||||
@@ -311,16 +335,15 @@ class ncclLL128Primitives {
|
||||
sendBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128];
|
||||
sendStep[i] = conn->step;
|
||||
if (wid == i) sendConn = conn;
|
||||
nsend++;
|
||||
}
|
||||
__device__ __forceinline__ void loadSendSync() {
|
||||
if (tid < nsend) {
|
||||
if (tid < fan.nsend()) {
|
||||
sendConnHeadPtr = sendConn->head;
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
sendConnHead = sendConn->step;
|
||||
sendConnFifoPtr = sendConn->sizesFifo;
|
||||
}
|
||||
if (tid >= nthreads-WARP_SIZE && wid<nsend) {
|
||||
if (tid >= nthreads-WARP_SIZE && wid<fan.nsend()) {
|
||||
if (sendConn->sizesFifo) {
|
||||
sendConnTailPtr = sendConn->tail;
|
||||
sendConnTail = sendConn->step;
|
||||
@@ -328,64 +351,74 @@ class ncclLL128Primitives {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void saveRecvSync() {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
recvConn->step = recvConnHead;
|
||||
__threadfence_block();
|
||||
public:
|
||||
__device__ Primitives(
|
||||
const int tid, const int nthreads, int const *recvPeers, int const *sendPeers,
|
||||
void const *inputBuf, void *outputBuf, int group=0
|
||||
):
|
||||
redOp(FuncTraits<RedOp>().make(ncclShmem.comm.nRanks)),
|
||||
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE),
|
||||
flagThread((tid%8)==7), group(group),
|
||||
stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) {
|
||||
|
||||
auto *channel = &ncclShmem.channel;
|
||||
int nrecv=0, nsend=0;
|
||||
while (nrecv < MaxRecv && recvPeers[nrecv] >= 0) {
|
||||
loadRecvConn(&channel->devPeers[recvPeers[nrecv]].recv->conn, nrecv);
|
||||
nrecv++;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void saveSendSync() {
|
||||
if (tid < nsend) {
|
||||
sendConn->step = sendConnHead;
|
||||
__threadfence_block();
|
||||
while (nsend < MaxSend && sendPeers[nsend] >= 0) {
|
||||
loadSendConn(&channel->devPeers[sendPeers[nsend]].send->conn, nsend);
|
||||
nsend++;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ __forceinline__
|
||||
ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm)
|
||||
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), shmem(ncclShmem->data+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
|
||||
// Make sure step is updated before we read it.
|
||||
barrier();
|
||||
|
||||
for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv->conn, i);
|
||||
for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send->conn, i);
|
||||
this->fan = Fan(nrecv, nsend);
|
||||
loadRecvSync();
|
||||
loadSendSync();
|
||||
setDataPtrs(inputBuf, outputBuf);
|
||||
}
|
||||
|
||||
__device__ void send(const T* src, int nelem) {
|
||||
return GenericOp<0, 1, 1, 0>(src, NULL, nelem);
|
||||
}
|
||||
|
||||
__device__ void recv(T* dst, int nelem) {
|
||||
return GenericOp<1, 0, 0, 1>(NULL, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceSend(const T* src, int nelem) {
|
||||
return GenericOp<1, 1, 1, 0>(src, NULL, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceCopy(const T* src, T* dst, int nelem) {
|
||||
return GenericOp<1, 0, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void copySend(const T* src, T* dst, int nelem) {
|
||||
return GenericOp<0, 1, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvCopySend(T* dst, int nelem) {
|
||||
return GenericOp<1, 1, 0, 1>(NULL, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceCopySend(const T* src, T* dst, int nelem) {
|
||||
return GenericOp<1, 1, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ ~ncclLL128Primitives() {
|
||||
__device__ ~Primitives() {
|
||||
// Save steps for the next operation
|
||||
saveRecvSync();
|
||||
saveSendSync();
|
||||
if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv())
|
||||
recvConn->step = recvConnHead;
|
||||
if (tid < fan.nsend())
|
||||
sendConn->step = sendConnHead;
|
||||
// Ensure all steps written back
|
||||
barrier();
|
||||
}
|
||||
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
|
||||
userBufs[Input] = (T*)inputBuf;
|
||||
userBufs[Output] = (T*)outputBuf;
|
||||
}
|
||||
|
||||
__device__ void moveDataPtrs(intptr_t delta) {
|
||||
userBufs[Input] += delta;
|
||||
userBufs[Output] += delta;
|
||||
}
|
||||
|
||||
__device__ void send(intptr_t inpIx, int eltN) {
|
||||
return GenericOp<0, 1, Input, -1>(inpIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void sendFromOutput(intptr_t outIx, int eltN) {
|
||||
return GenericOp<0, 1, Output, -1>(outIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void recv(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return GenericOp<1, 0, -1, Output>(-1, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvReduceSend(intptr_t inpIx, int eltN) {
|
||||
return GenericOp<1, 1, Input, -1>(inpIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return GenericOp<1, 0, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return GenericOp<0, 1, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return GenericOp<1, 1, -1, Output>(-1, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return GenericOp<1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
};
|
||||
|
463 src/collectives/device/prims_simple.h (Normal file)
@@ -0,0 +1,463 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
template<typename T, typename RedOp, typename Fan, int Direct,
|
||||
int SlicePerChunk, int StepPerSlice, int Unroll>
|
||||
class Primitives<
|
||||
T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, Unroll>
|
||||
> {
|
||||
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1;
|
||||
static constexpr int RoleInput = 0x01,
|
||||
RoleOutput = 0x02,
|
||||
RoleWaitRecv = 0x04,
|
||||
RoleWaitSend = 0x08,
|
||||
RolePostSend = 0x10,
|
||||
RolePostRecv = 0x20,
|
||||
Aborted = 0x40,
|
||||
PtrsFifoEnabled = 0x80,
|
||||
SizesFifoEnabled = 0x100,
|
||||
DirectEnabled = 0x200,
|
||||
ThreadsSynced = 0x400;
|
||||
const int tid;
|
||||
int nthreads;
|
||||
int nworkers;
|
||||
const int stepSize;
|
||||
Fan fan;
|
||||
RedOp const redOp;
|
||||
int index; // Peer index I'm responsible for
|
||||
int flags;
|
||||
int group;
|
||||
uint64_t step;
|
||||
union {
|
||||
void **connPtrsFifoPtr; // (flags & PtrsFifoEnabled)
|
||||
T *userBuff; // (flags & (RoleInput|RoleOutput))
|
||||
T *connEltsFifo; // !(flags & (PtrsFifoEnabled|RoleInput|RoleOutput))
|
||||
};
|
||||
union {
|
||||
int volatile *connSizesFifoPtr; // (flags & SizesFifoEnabled)
|
||||
T *directBuff; // !(flags & SizesFifoEnabled)
|
||||
};
|
||||
uint64_t volatile *connStepPtr;
|
||||
uint64_t connStepCache; // Cache last seen value of (*connStepPtr)
|
||||
|
||||
// Don't use barrier 0 as it's used by the final sync
|
||||
inline __device__ void barrier() {
|
||||
if (nthreads == WARP_SIZE)
|
||||
__syncwarp();
|
||||
else
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(group+1), "r"(nthreads));
|
||||
flags |= ThreadsSynced;
|
||||
}
|
||||
inline __device__ void subBarrier() {
|
||||
if (nworkers == nthreads)
|
||||
barrier();
|
||||
else
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(group+2), "r"(nworkers));
|
||||
}
|
||||
|
||||
inline __device__ bool checkAbort(int &spins) {
|
||||
spins++;
|
||||
if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
|
||||
flags |= *ncclShmem.comm.abortFlag ? Aborted : 0;
|
||||
spins = 0;
|
||||
}
|
||||
return flags & Aborted;
|
||||
}
|
||||
|
||||
template <int DirectRecv, int DirectSend, int Recv, int Send, int Src, int Dst>
|
||||
inline __device__ void waitPeer(intptr_t dstIx, intptr_t remoteOutIx, int offset, int nelts) {
|
||||
if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) {
|
||||
bool const isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
|
||||
int spins = 0;
|
||||
while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) {
|
||||
connStepCache = *connStepPtr;
|
||||
if (checkAbort(spins)) break;
|
||||
//if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
|
||||
}
|
||||
|
||||
if (isSendNotRecv && (flags & SizesFifoEnabled))
|
||||
connSizesFifoPtr[step%NCCL_STEPS] = nelts*sizeof(T);
|
||||
|
||||
void **ptrs = isSendNotRecv ? (ncclShmem.groups[group].dsts + Dst)
|
||||
: (ncclShmem.groups[group].srcs + Src);
|
||||
if (flags & PtrsFifoEnabled)
|
||||
loadPtr(connPtrsFifoPtr + step%NCCL_STEPS, ptrs[index]);
|
||||
else if ((isSendNotRecv ? DirectSend : DirectRecv) && (flags & DirectEnabled))
|
||||
ptrs[index] = directBuff + (isSendNotRecv ? remoteOutIx : dstIx) + offset;
|
||||
else
|
||||
ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
|
||||
step += StepPerSlice;
|
||||
}
|
||||
}
|
||||
|
||||
template<int Recv, int Send>
|
||||
inline __device__ void postPeer() {
|
||||
if (flags & (Recv*RolePostRecv | Send*RolePostSend)) {
|
||||
step += StepPerSlice;
|
||||
*connStepPtr = step;
|
||||
}
|
||||
}
|
||||
|
||||
template <int DirectRecv1, int DirectSend1, int Recv, int Send, int SrcBuf, int DstBuf>
|
||||
inline __device__ void genericOp(
|
||||
intptr_t srcIx, intptr_t dstIx, intptr_t remoteOutIx, int nelem, bool postOp
|
||||
) {
|
||||
constexpr int DirectRecv = 1 && Direct && DirectRecv1;
|
||||
constexpr int DirectSend = 1 && Direct && DirectSend1;
|
||||
constexpr int Src = SrcBuf != -1;
|
||||
constexpr int Dst = DstBuf != -1;
|
||||
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
int sliceSize = stepSize*StepPerSlice;
|
||||
sliceSize = max(divUp(nelem, 16*SlicePerChunk)*16, sliceSize/32);
|
||||
int slice = 0;
|
||||
int offset = 0;
|
||||
|
||||
if (tid < nworkers && offset < nelem) {
|
||||
// Worker-only loop for non-empty slices. Non-workers and empty slices are
|
||||
// processed in the loop following this if block. The benefit of splitting
|
||||
// the loop like this is that we pull two branches out of the critical path.
|
||||
// Using "number of branch insns (taken or not) encountered dynamically"
|
||||
// as the performance metric, then:
|
||||
// perf_orig = 2*numslices
|
||||
// perf_new = 2+numslices
|
||||
// So the new code and old code behave the same for numslices=2, and for
|
||||
// numslices>2 the new code is superior. And note that in the case
|
||||
// numslices=1, the loop is trivially unrollable (single iteration) so we
|
||||
// don't incur that tail branch and we still have perf_new=2.
|
||||
//
|
||||
// ORIGINAL CODE:
|
||||
// unrolled for(slices) {
|
||||
// if(worker) { // This branch removed
|
||||
// wait();
|
||||
// subBarrier();
|
||||
// if(slice not empty) // This branch removed
|
||||
// ReduceCopyMulti();
|
||||
// }
|
||||
// barrier();
|
||||
// post();
|
||||
// } // Since we no longer unroll, new branch added here
|
||||
#if __CUDA_ARCH__ < 700
|
||||
// Yeah, so all of the above doesn't matter a lick on older hardware.
|
||||
#pragma unroll SlicePerChunk
|
||||
#else
|
||||
#pragma unroll 1
|
||||
#endif
|
||||
do {
|
||||
sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset;
|
||||
if (Src && (flags & (SrcBuf==Input ? RoleInput : RoleOutput)))
|
||||
ncclShmem.groups[group].srcs[0] = userBuff + srcIx + offset;
|
||||
if (Dst && (flags & (DstBuf==Input ? RoleInput : RoleOutput)))
|
||||
ncclShmem.groups[group].dsts[0] = userBuff + dstIx + offset;
|
||||
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(dstIx, remoteOutIx, offset, sliceSize);
|
||||
subBarrier();
|
||||
if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) {
|
||||
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
|
||||
if (Send) {
|
||||
// (1-Send) is only there to avoid compilation errors in case MaxSend=0 (and Send=0).
|
||||
ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, (1-Send)+MaxSend>
|
||||
(tid, nworkers, redOp, false, false,
|
||||
1, (T const**)ncclShmem.groups[group].srcs,
|
||||
fan.nsend(), (T**)ncclShmem.groups[group].dsts+1,
|
||||
sliceSize);
|
||||
}
|
||||
} else {
|
||||
ReduceOrCopyMulti<Unroll, RedOp, T, Recv+Src, Recv*MaxRecv+Src, Send+Dst, Send*MaxSend+Dst>
|
||||
(tid, nworkers, redOp, SrcBuf==Input, postOp,
|
||||
Recv*fan.nrecv()+Src, (T const**)ncclShmem.groups[group].srcs,
|
||||
Send*fan.nsend()+Dst, (T**)ncclShmem.groups[group].dsts,
|
||||
sliceSize);
|
||||
}
|
||||
barrier(); // This barrier has a counterpart in following loop
|
||||
if (Send && (flags & RolePostSend) && index == 0) __threadfence_system();
|
||||
__syncwarp();
|
||||
postPeer<Recv, Send>();
|
||||
offset += sliceSize;
|
||||
slice += 1;
|
||||
} while (slice < SlicePerChunk && offset < nelem);
|
||||
}
|
||||
|
||||
// Non-workers come straight here. Workers too but only once the remaining
|
||||
// slices are all empty. Since empty slices are the uncommon case, and
|
||||
// worker perf is the limiter, perf-wise this loop is effectively unentered,
|
||||
// hence just a single branch insn.
|
||||
#pragma unroll 1
|
||||
while (slice < SlicePerChunk) {
|
||||
sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset;
|
||||
{ // Only workers could have Wait roles so we know the slice must be empty
|
||||
// since we've exited the loop above.
|
||||
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(0, 0, 0, 0);
|
||||
}
|
||||
barrier(); // Has counterpart in preceding worker-only loop.
|
||||
if (Send && (flags & RolePostSend) && sliceSize > 0 && index == 0) __threadfence_system();
|
||||
__syncwarp();
|
||||
postPeer<Recv, Send>();
|
||||
offset += sliceSize;
|
||||
slice += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Scatter and gather do not support Direct
|
||||
template <int Recv, int Send>
|
||||
inline __device__ void
|
||||
ScatterGatherOp(intptr_t inpIx, intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp) {
|
||||
int offset = 0; // slice offset
|
||||
int sliceSize = stepSize*StepPerSlice;
|
||||
int dataSize = max(DIVUP(peerElem, 16*SlicePerChunk)*16, sliceSize/32); // per-peer slice size
|
||||
|
||||
#pragma unroll
|
||||
for (int slice=0; slice<SlicePerChunk; ++slice) {
|
||||
int realSize = max(0, min(dataSize, peerElem-offset));
|
||||
if (tid < nworkers) {
|
||||
if (Send && (flags & RoleInput)) ncclShmem.groups[group].srcs[0] = userBuff + inpIx + offset;
|
||||
if (Recv && (flags & RoleOutput)) ncclShmem.groups[group].dsts[0] = userBuff + outIx + offset;
|
||||
// realSize is not accurate here; but intra-node does not rely on sizes FIFO
|
||||
waitPeer<0, 0, Recv, Send, 0, 0>(0, 0, 0, realSize);
|
||||
subBarrier();
|
||||
if (Send) {
|
||||
#pragma unroll
|
||||
for (int j=0; j<fan.nsend(); j++) {
|
||||
int i = (j+shift)%fan.nsend();
|
||||
int peerOffset = i*peerElem;
|
||||
if (skip >= 0 && i >= skip) peerOffset += peerElem;
|
||||
const T* src0 = (T*)ncclShmem.groups[group].srcs[0] + peerOffset;
|
||||
int realPeerSize = min(realSize, totalElem-peerOffset);
|
||||
if (realPeerSize > 0) ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1>(tid, nworkers, redOp, true, false, 1, &src0, 1, (T**)ncclShmem.groups[group].dsts+i, realPeerSize);
|
||||
}
|
||||
} else if (Recv) {
|
||||
#pragma unroll
|
||||
for (int j=0; j<fan.nrecv(); j++) {
|
||||
int i = (j+shift)%fan.nrecv();
|
||||
int peerOffset = i*peerElem;
|
||||
if (skip >= 0 && i >= skip) peerOffset += peerElem;
|
||||
T* dst0 = (T*)ncclShmem.groups[group].dsts[0] + peerOffset;
|
||||
int realPeerSize = min(realSize, totalElem-peerOffset);
|
||||
if (realPeerSize > 0) ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1>(tid, nworkers, redOp, false, postOp, 1, (T const**)ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
if (Send && (flags & RolePostSend) && realSize > 0 && index == 0) __threadfence_system();
|
||||
__syncwarp();
|
||||
postPeer<Recv, Send>();
|
||||
offset += realSize;
|
||||
}
|
||||
}
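// (Worked example of the peerOffset indexing above, with assumed values
//  fan.nsend()=3, peerElem=1024, shift=1, skip=1: j=0 -> i=1 -> offset 2048,
//  j=1 -> i=2 -> offset 3072, j=2 -> i=0 -> offset 0; indices at or beyond
//  the skipped index are offset by one extra peerElem, stepping over one
//  chunk of the flat buffer.)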
|
||||
|
||||
__device__ __forceinline__ void loadRecvConn(ncclPeer *peer) {
|
||||
if (flags & (RoleWaitRecv|RolePostRecv)) {
|
||||
// For other colls: group <= 2, hence always use conn 0
|
||||
// For P2P: Direct is set to 1, hence always use conn 0
|
||||
// Ideally we should be accepting connIndex from the constructor!
|
||||
const int connIndex = Direct ? 0 : group/4;
|
||||
auto *conn = &peer->recv[connIndex].conn;
|
||||
step = conn->step;
|
||||
step = roundUp(step, SlicePerChunk*StepPerSlice);
|
||||
if (flags & RolePostRecv) {
|
||||
connStepPtr = conn->head;
|
||||
*connStepPtr = step; // Return credits in case we rounded up.
|
||||
}
|
||||
if (flags & RoleWaitRecv) {
|
||||
ncclShmem.groups[group].recvConns[index] = conn; // WaitRecv role saves since that's who needs it in setDataPtrs()
|
||||
connStepPtr = conn->tail;
|
||||
connStepCache = *connStepPtr;
|
||||
flags |= (conn->ptrsFifo != nullptr) ? PtrsFifoEnabled : 0;
|
||||
flags |= (Direct && (conn->direct & NCCL_DIRECT_GPU)) ? DirectEnabled : 0;
|
||||
if (flags & PtrsFifoEnabled)
|
||||
connPtrsFifoPtr = conn->ptrsFifo;
|
||||
else
|
||||
connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadSendConn(ncclPeer *peer) {
|
||||
if (flags & (RoleWaitSend|RolePostSend)) {
|
||||
// For other colls: group <= 2, hence always use conn 0
|
||||
// For P2P: Direct is set to 1, hence always use conn 0
|
||||
// Ideally we should be accepting connIndex from the constructor!
|
||||
const int connIndex = Direct ? 0 : group/4;
|
||||
auto *conn = &peer->send[connIndex].conn;
|
||||
step = conn->step;
|
||||
step = roundUp(step, SlicePerChunk*StepPerSlice);
|
||||
if (flags & RolePostSend) {
|
||||
connStepPtr = conn->tail;
|
||||
}
|
||||
if (flags & RoleWaitSend) {
|
||||
ncclShmem.groups[group].sendConns[index] = conn; // WaitSend role saves since that's who needs it in setDataPtrs()
|
||||
connStepPtr = conn->head;
|
||||
connStepCache = *connStepPtr;
|
||||
flags |= (conn->ptrsFifo != nullptr) ? PtrsFifoEnabled : 0;
|
||||
if (flags & PtrsFifoEnabled)
|
||||
connPtrsFifoPtr = conn->ptrsFifo;
|
||||
else
|
||||
connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
|
||||
|
||||
if (conn->sizesFifo != nullptr) {
|
||||
flags |= SizesFifoEnabled;
|
||||
connSizesFifoPtr = conn->sizesFifo;
|
||||
}
|
||||
else if (Direct && (conn->direct & NCCL_DIRECT_GPU))
|
||||
flags |= DirectEnabled;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ Primitives(
|
||||
int tid, int nthreads, int const *recvPeers, int const *sendPeers,
|
||||
void const *inputBuf, void *outputBuf, int group=0
|
||||
):
|
||||
tid(tid),
|
||||
stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)),
|
||||
redOp(FuncTraits<RedOp>::make(ncclShmem.comm.nRanks)) {
|
||||
|
||||
// For send operations, we need an extra warp to overlap the threadfence and the copy
|
||||
this->nthreads = nthreads;
|
||||
this->nworkers = nthreads - (MaxSend > 0 && nthreads-WARP_SIZE >= 64 ? WARP_SIZE : 0);
|
||||
this->group = group;
|
||||
|
||||
int nrecv=0, nsend=0;
|
||||
while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++;
|
||||
while (nsend < MaxSend && sendPeers[nsend] != -1) nsend++;
|
||||
this->fan = Fan(nrecv, nsend);
|
||||
|
||||
constexpr int ThreadPerSync = 8;
|
||||
static_assert(MaxSend < ThreadPerSync && MaxRecv < ThreadPerSync, "Not enough threads to cover all peers");
|
||||
|
||||
int g = tid / ThreadPerSync;
|
||||
int ng = nthreads / ThreadPerSync;
|
||||
index = tid % ThreadPerSync;
|
||||
flags = 0;
|
||||
if (g == 0) {
|
||||
if (index < nrecv) flags |= RoleWaitRecv;
|
||||
if (index == nrecv) flags |= RoleInput;
|
||||
} else if (g == 1) {
|
||||
if (index < nsend) flags |= RoleWaitSend;
|
||||
if (index == nsend) flags |= RoleOutput;
|
||||
} else if (g == ng - 2) {
|
||||
if (index < nrecv) flags |= RolePostRecv;
|
||||
} else if (g == ng - 1) {
|
||||
if (index < nsend) flags |= RolePostSend;
|
||||
}
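// (Illustration of the role assignment above, assuming nthreads=256 and
// nrecv=nsend=1: ThreadPerSync=8 gives ng=32 groups, so tid 0 takes
// RoleWaitRecv, tid 1 RoleInput, tid 8 RoleWaitSend, tid 9 RoleOutput,
// tid 240 RolePostRecv and tid 248 RolePostSend; every other thread
// carries no role flags.)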
|
||||
|
||||
int peer = 0;
|
||||
if (flags & (RoleWaitRecv|RolePostRecv)) peer = recvPeers[index];
|
||||
if (flags & (RoleWaitSend|RolePostSend)) peer = sendPeers[index];
|
||||
|
||||
loadRecvConn(&ncclShmem.channel.devPeers[peer]);
|
||||
loadSendConn(&ncclShmem.channel.devPeers[peer]);
|
||||
|
||||
setDataPtrs(inputBuf, outputBuf);
|
||||
}
|
||||
|
||||
__device__ ~Primitives() {
|
||||
// Ensure ncclShmem.groups[].send/recvConns are available
|
||||
if (!(flags & ThreadsSynced))
|
||||
barrier();
|
||||
// Save steps for the next operation
|
||||
if (flags & (RolePostSend|RolePostRecv)) {
|
||||
auto *conns = (flags & RolePostSend) ? ncclShmem.groups[group].sendConns : ncclShmem.groups[group].recvConns;
|
||||
conns[index]->step = step;
|
||||
}
|
||||
// Make sure all threads are done writing back conn->step and done using
|
||||
// ncclShmem.groups[group]
|
||||
barrier();
|
||||
}
|
||||
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
|
||||
if (flags & RoleInput) userBuff = (T*)inputBuf;
|
||||
if (flags & RoleOutput) userBuff = (T*)outputBuf;
|
||||
if (Direct && flags == (flags|RoleWaitRecv|DirectEnabled)) {
|
||||
int spins = 0;
|
||||
void *volatile *slot = ncclShmem.groups[group].recvConns[index]->ptrExchange;
|
||||
// Wait for consumer to consume previous value before trampling it.
|
||||
while (*slot != nullptr && !checkAbort(spins));
|
||||
directBuff = (T*)outputBuf;
|
||||
// Encode pointer by XOR'ing against some address they definitely wouldn't send
|
||||
// since we want to allow them sending us nullptr while not colliding with
|
||||
// the empty slot value.
|
||||
*slot = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(outputBuf) ^ reinterpret_cast<uintptr_t>(slot));
|
||||
}
|
||||
if (Direct && flags == (flags|RoleWaitSend|DirectEnabled)) {
|
||||
int spins = 0;
|
||||
void *volatile *slot = ncclShmem.groups[group].sendConns[index]->ptrExchange;
|
||||
void *ptr;
|
||||
while (true) {
|
||||
ptr = *slot;
|
||||
if (ptr != nullptr || checkAbort(spins)) break;
|
||||
}
|
||||
directBuff = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(ptr) ^ reinterpret_cast<uintptr_t>(slot));
|
||||
*slot = nullptr;
|
||||
}
|
||||
}
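A standalone host-side sketch (editorial illustration, not NCCL code; encodePtr is a hypothetical helper name) of the XOR encoding used in setDataPtrs above: XOR'ing the payload pointer with the exchange slot's own address lets a producer publish a legitimate nullptr without it ever being mistaken for the empty-slot value.
#include <cassert>
#include <cstdint>

// XOR encode/decode against the slot address; XOR is its own inverse.
static void* encodePtr(void* payload, void* volatile* slot) {
  return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(payload) ^
                                 reinterpret_cast<uintptr_t>(slot));
}

int main() {
  void* volatile slot = nullptr;                   // empty slot
  int buf = 0;
  slot = encodePtr(&buf, &slot);                   // publish a real buffer pointer
  assert(slot != nullptr);                         // never collides with "empty"
  assert(encodePtr((void*)slot, &slot) == &buf);   // decoding is the same XOR
  slot = encodePtr(nullptr, &slot);                // publishing nullptr still reads as "full"
  assert(slot != nullptr && encodePtr((void*)slot, &slot) == nullptr);
  return 0;
}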
|
||||
|
||||
__device__ void moveDataPtrs(intptr_t delta) {
|
||||
if (flags & (RoleInput|RoleOutput))
|
||||
userBuff += delta;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void send(intptr_t inpIx, int eltN) {
|
||||
genericOp<0, 0, 0, 1, Input, -1>(inpIx, -1, -1, eltN, false);
|
||||
}
|
||||
__device__ __forceinline__ void sendFromOutput(intptr_t outIx, int eltN) {
|
||||
genericOp<0, 0, 0, 1, Output, -1>(outIx, -1, -1, eltN, false);
|
||||
}
|
||||
__device__ __forceinline__ void directSend(intptr_t inpIx, intptr_t remoteOutIx, int eltN) {
|
||||
genericOp<0, 1, 0, 1, Input, -1>(inpIx, -1, remoteOutIx, eltN, false);
|
||||
}
|
||||
__device__ __forceinline__ void directSendFromOutput(intptr_t outIx, intptr_t remoteOutIx, int eltN) {
|
||||
genericOp<0, 1, 0, 1, Output, -1>(outIx, -1, remoteOutIx, eltN, false);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recv(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 0, -1, Output>(-1, outIx, -1, eltN, postOp);
|
||||
}
|
||||
__device__ __forceinline__ void directRecv(intptr_t outIx, int eltN) {
|
||||
genericOp<1, 0, 1, 0, -1, Output>(-1, outIx, -1, eltN, /*postOp=*/false);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 0, 1, Input, Output>(inpIx, outIx, -1, eltN, postOp);
|
||||
}
|
||||
__device__ __forceinline__ void directCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 1, 0, 1, Input, Output>(inpIx, outIx, remoteOutIx, eltN, postOp);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 1, -1, Output>(-1, outIx, -1, eltN, postOp);
|
||||
}
|
||||
__device__ __forceinline__ void directRecvCopySend(intptr_t outIx, intptr_t remoteOutIx, int eltN) {
|
||||
genericOp<1, 1, 1, 1, -1, Output>(-1, outIx, remoteOutIx, eltN, false);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 0, Input, Output>(inpIx, outIx, -1, eltN, postOp);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 1, Input, -1>(inpIx, -1, -1, eltN, postOp);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 1, Input, Output>(inpIx, outIx, -1, eltN, postOp);
|
||||
}
|
||||
__device__ __forceinline__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) {
|
||||
// Direct is only for the send part
|
||||
genericOp<0, 1, 1, 1, Input, Output>(inpIx, outIx, remoteOutIx, eltN, postOp);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
scatter(intptr_t inpIx, int totalElem, int peerElem, int skip, int shift) {
|
||||
ScatterGatherOp<0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
gather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp=false) {
|
||||
ScatterGatherOp<1, 0>(-1, outIx, totalElem, peerElem, skip, shift, postOp);
|
||||
}
|
||||
};
|
@@ -5,148 +5,87 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduce, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads-WARP_SIZE;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
const int chunkSize = stepSize * REDUCE_CHUNKSTEPS;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ring->devUserRanks[0];
|
||||
const int prevRank = ring->devUserRanks[nranks-1];
|
||||
const int root = args->coll.root;
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
ncclRing *ring = &ncclShmem.channel.ring;
|
||||
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCE_CHUNKSTEPS : 1));
|
||||
const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T)));
|
||||
const int nranks = ncclShmem.comm.nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ncclShmem.comm.rank;
|
||||
const int prevRank = ring->devUserRanks[nranks-1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff);
|
||||
|
||||
ncclPrimitives<UNROLL, REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS, T, 1, 1, 0, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
auto calcChunkSize = [&]__device__(ssize_t gridOffset)->int {
|
||||
int realChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels));
|
||||
realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
else if (Proto::Id == NCCL_PROTO_LL)
|
||||
realChunkSize = size-gridOffset < loopSize ? args->coll.lastChunkSize : chunkSize;
|
||||
else if (Proto::Id == NCCL_PROTO_LL128)
|
||||
realChunkSize = min(divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128, chunkSize);
|
||||
return realChunkSize;
|
||||
};
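// (Hedged example with made-up numbers: T=float, nthreads=512, nChannels=2,
//  chunkSize=131072, size-gridOffset=100000. For NCCL_PROTO_SIMPLE this gives
//  min(131072, divUp(100000,2)) = 50000, rounded up to a multiple of
//  (512-32)*8/4 = 960, i.e. realChunkSize = 50880 elements.)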
|
||||
|
||||
if (prevRank == root) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
int realChunkSize = calcChunkSize(gridOffset);
|
||||
ssize_t offset = gridOffset + bid*realChunkSize;
|
||||
int nelem = min(realChunkSize, size-offset);
|
||||
if (prevRank == root) {
|
||||
prims.send(thisInput+offset, nelem);
|
||||
} else if (rank == root) {
|
||||
prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
|
||||
} else {
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
prims.send(offset, nelem);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduce, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = comm->rank;
|
||||
const int prevRank = ring->devUserRanks[nranks-1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
else if (rank == root) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
if (size-gridOffset < loopSize) {
|
||||
chunkSize = args->coll.lastChunkSize;
|
||||
}
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (prevRank == root) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else if (rank == root) {
|
||||
LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
int realChunkSize = calcChunkSize(gridOffset);
|
||||
ssize_t offset = gridOffset + bid*realChunkSize;
|
||||
int nelem = min(realChunkSize, size-offset);
|
||||
prims.recvReduceCopy(offset, offset, nelem, /*postOp=*/true);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#include "prims_ll128.h"
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduce, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = comm->rank;
|
||||
const int prevRank = ring->devUserRanks[nranks-1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
else {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (prevRank == root) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else if (rank == root) {
|
||||
LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
int realChunkSize = calcChunkSize(gridOffset);
|
||||
ssize_t offset = gridOffset + bid*realChunkSize;
|
||||
int nelem = min(realChunkSize, size-offset);
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduce, NCCL_ALGO_TREE, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduce, NCCL_ALGO_COLLNET, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
@@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "common_kernel.h"
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
template<typename T>
|
||||
struct FuncNull {
|
||||
@@ -46,6 +47,18 @@ struct FuncMin {
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Fn>
|
||||
struct FuncTraits { // generic implementation for FuncSum,Prod,Min,Max
|
||||
static constexpr bool IsPreOpIdentity = true;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
|
||||
__device__ static Fn make(int rankN) { return Fn(); }
|
||||
template<typename T>
|
||||
__device__ static T preOp(Fn, T x) { return x; }
|
||||
template<typename T>
|
||||
__device__ static T postOp(Fn, T x) { return x; }
|
||||
};
|
||||
|
||||
#define MASK0 0x00ff00ff
|
||||
#define MASK1 0xff00ff00
|
||||
static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) {
|
||||
@@ -239,6 +252,31 @@ struct FuncSum<half> {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
template<>
|
||||
struct FuncSum<__nv_bfloat16> {
|
||||
__device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hadd2(x, y);
|
||||
#else
|
||||
float fxl, fxh, fyl, fyh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
|
||||
#endif
|
||||
}
|
||||
__device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hadd(x, y);
|
||||
#else
|
||||
return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncProd<half> {
|
||||
__device__ half2 operator()(const half2 x, const half2 y) const {
|
||||
@@ -262,6 +300,31 @@ struct FuncProd<half> {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
template<>
|
||||
struct FuncProd<__nv_bfloat16> {
|
||||
__device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmul2(x, y);
|
||||
#else
|
||||
float fxl, fxh, fyl, fyh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
|
||||
#endif
|
||||
}
|
||||
__device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmul(x, y);
|
||||
#else
|
||||
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncMax<half> {
|
||||
__device__ half2 operator()(const half2 x, const half2 y) const {
|
||||
@@ -281,6 +344,34 @@ struct FuncMax<half> {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
template<>
|
||||
struct FuncMax<__nv_bfloat16> {
|
||||
__device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmax2(x, y);
|
||||
#else
|
||||
float fxl, fxh, fyl, fyh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
return __floats2bfloat162_rn(fmaxf(fxl, fyl), fmaxf(fxh, fyh));
|
||||
#endif
|
||||
}
|
||||
__device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmax(x, y);
|
||||
#else
|
||||
float fx, fy;
|
||||
fx = __bfloat162float(x);
|
||||
fy = __bfloat162float(y);
|
||||
return __float2bfloat16(fmaxf(fx, fy));
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncMin<half> {
|
||||
__device__ half2 operator()(const half2 x, const half2 y) const {
|
||||
@@ -299,4 +390,269 @@ struct FuncMin<half> {
|
||||
return __float2half(fm);
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
template<>
|
||||
struct FuncMin<__nv_bfloat16> {
|
||||
__device__ __nv_bfloat162 operator()(const __nv_bfloat162 x, const __nv_bfloat162 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmin2(x, y);
|
||||
#else
|
||||
float fxl, fxh, fyl, fyh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
return __floats2bfloat162_rn(fminf(fxl, fyl), fminf(fxh, fyh));
|
||||
#endif
|
||||
}
|
||||
__device__ __nv_bfloat16 operator()(const __nv_bfloat16 x, const __nv_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmin(x, y);
|
||||
#else
|
||||
float fx, fy;
|
||||
fx = __bfloat162float(x);
|
||||
fy = __bfloat162float(y);
|
||||
return __float2bfloat16(fminf(fx, fy));
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncMax<float> {
|
||||
__device__ float operator()(float x, float y) const {
|
||||
return fmaxf(x, y);
|
||||
}
|
||||
};
|
||||
template<>
|
||||
struct FuncMin<float> {
|
||||
__device__ float operator()(float x, float y) const {
|
||||
return fminf(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncMax<double> {
|
||||
__device__ double operator()(double x, double y) const {
|
||||
return fmax(x, y);
|
||||
}
|
||||
};
|
||||
template<>
|
||||
struct FuncMin<double> {
|
||||
__device__ double operator()(double x, double y) const {
|
||||
return fmin(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct FuncAvg: FuncSum<T> {
|
||||
static_assert(!std::is_floating_point<T>::value, "Uhoh");
|
||||
static constexpr bool IsPreOpIdentity = true;
|
||||
static constexpr bool IsPostOpIdentity = false;
|
||||
int n;
|
||||
|
||||
template<typename ...Arg>
|
||||
__device__ FuncAvg(int n): n(n) {}
|
||||
|
||||
__device__ T preOp(T x) const {
|
||||
return x;
|
||||
}
|
||||
__device__ T postOp(T x) const {
|
||||
return T(x/n);
|
||||
}
|
||||
};
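For reference, a minimal host-side sketch (editorial, with made-up values) of what the new ncclAvg reduction computes for integral types: a plain sum across ranks followed by the postOp division by the rank count, matching the postOp(x) = x/n above.
#include <cstdio>
#include <vector>

int main() {
  const int nranks = 4, count = 3;
  // Hypothetical per-rank contributions.
  std::vector<std::vector<int>> in = {{1,2,3},{5,6,7},{9,10,11},{13,14,15}};
  std::vector<int> out(count, 0);
  for (int r = 0; r < nranks; r++)
    for (int i = 0; i < count; i++) out[i] += in[r][i];  // reduce: FuncSum
  for (int i = 0; i < count; i++) out[i] /= nranks;      // postOp: x/n
  for (int i = 0; i < count; i++) printf("%d ", out[i]); // prints: 7 8 9
  printf("\n");
  return 0;
}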
|
||||
|
||||
template<>
|
||||
struct FuncAvg<double>: FuncSum<double> {
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
double rcp;
|
||||
__device__ FuncAvg(int n) {
|
||||
rcp = __drcp_rn(double(n));
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ double preOp(double x) const {
|
||||
return IsPreOpIdentity ? x : x*rcp;
|
||||
}
|
||||
__device__ double postOp(double x) const {
|
||||
return IsPostOpIdentity ? x : x*rcp;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncAvg<float>: FuncSum<float> {
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
float rcp;
|
||||
__device__ FuncAvg(int n) {
|
||||
rcp = __frcp_rn(float(n));
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ float preOp(float x) const {
|
||||
return IsPreOpIdentity ? x : x*rcp;
|
||||
}
|
||||
__device__ float postOp(float x) const {
|
||||
return IsPostOpIdentity ? x : x*rcp;
|
||||
}
|
||||
};
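The float and double specializations above replace the division with a precomputed reciprocal applied either to every contribution (preOp) or once to the final sum (postOp). A small standalone sketch (editorial, made-up data) showing that the two orderings agree up to rounding error:
#include <cstdio>

int main() {
  const int n = 8;
  const float x[n] = {0.5f, 1.25f, -3.f, 7.5f, 2.f, 0.125f, -1.f, 4.f};
  float rcp = 1.f/n, pre = 0.f, post = 0.f;
  for (int i = 0; i < n; i++) pre += x[i]*rcp;  // scale each contribution (preOp)
  for (int i = 0; i < n; i++) post += x[i];
  post *= rcp;                                  // scale the final sum (postOp)
  printf("pre=%.9g post=%.9g\n", pre, post);    // equal up to rounding error
  return 0;
}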
|
||||
|
||||
template<>
|
||||
struct FuncAvg<half>: FuncSum<half> {
|
||||
// Change these to switch between all prescale, all postscale, or both by sqrt(N).
|
||||
// Obviously, the only invalid combination is both true. An improvement would be
|
||||
// to make this parameterized as a build-time setting and passed here through
|
||||
// preprocessor definitions.
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
|
||||
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
|
||||
half2 scale;
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale.x = __float2half(__frsqrt_rn(float(n)));
|
||||
else
|
||||
scale.x = __float2half(__frcp_rn(float(n)));
|
||||
scale.y = scale.x;
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ half preOp(half x) const {
|
||||
return IsPreOpIdentity ? x : __hmul(x, scale.x);
|
||||
}
|
||||
__device__ half2 preOp(half2 x) const {
|
||||
return IsPreOpIdentity ? x : __hmul2(x, scale);
|
||||
}
|
||||
__device__ half postOp(half x) const {
|
||||
return IsPostOpIdentity ? x : __hmul(x, scale.x);
|
||||
}
|
||||
__device__ half2 postOp(half2 x) const {
|
||||
return IsPostOpIdentity ? x : __hmul2(x, scale);
|
||||
}
|
||||
#else
|
||||
float scale;
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale = __frsqrt_rn(float(n));
|
||||
else
|
||||
scale = __frcp_rn(float(n));
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ half preOp(half x) const {
|
||||
return IsPreOpIdentity ? x : __float2half(__half2float(x)*scale);
|
||||
}
|
||||
__device__ half2 preOp(half2 x) const {
|
||||
if (IsPreOpIdentity)
|
||||
return x;
|
||||
else {
|
||||
float2 a = __half22float2(x);
|
||||
a.x *= scale;
|
||||
a.y *= scale;
|
||||
return __float22half2_rn(a);
|
||||
}
|
||||
}
|
||||
__device__ half postOp(half x) const {
|
||||
return IsPostOpIdentity ? x : __float2half(__half2float(x)*scale);
|
||||
}
|
||||
__device__ half2 postOp(half2 x) const {
|
||||
if (IsPostOpIdentity)
|
||||
return x;
|
||||
else {
|
||||
float2 a = __half22float2(x);
|
||||
a.x *= scale;
|
||||
a.y *= scale;
|
||||
return __float22half2_rn(a);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
template<>
|
||||
struct FuncAvg<__nv_bfloat16>: FuncSum<__nv_bfloat16> {
|
||||
// Change these to switch between all prescale, all postscale, or both by sqrt(N).
|
||||
// Obviously, the only invalid combination is both true. An improvement would be
|
||||
// to make this parameterized as a build-time setting and passed here through
|
||||
// preprocessor definitions.
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
__nv_bfloat162 scale;
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale.x = __float2bfloat16(__frsqrt_rn(float(n)));
|
||||
else
|
||||
scale.x = __float2bfloat16(__frcp_rn(float(n)));
|
||||
scale.y = scale.x;
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ __nv_bfloat16 preOp(__nv_bfloat16 x) const {
|
||||
return IsPreOpIdentity ? x : __hmul(x, scale.x);
|
||||
}
|
||||
__device__ __nv_bfloat162 preOp(__nv_bfloat162 x) const {
|
||||
return IsPreOpIdentity ? x : __hmul2(x, scale);
|
||||
}
|
||||
__device__ __nv_bfloat16 postOp(__nv_bfloat16 x) const {
|
||||
return IsPostOpIdentity ? x : __hmul(x, scale.x);
|
||||
}
|
||||
__device__ __nv_bfloat162 postOp(__nv_bfloat162 x) const {
|
||||
return IsPostOpIdentity ? x : __hmul2(x, scale);
|
||||
}
|
||||
#else
|
||||
float scale;
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale = __frsqrt_rn(float(n));
|
||||
else
|
||||
scale = __frcp_rn(float(n));
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ __nv_bfloat16 preOp(__nv_bfloat16 x) const {
|
||||
return IsPreOpIdentity ? x : __float2bfloat16(__bfloat162float(x)*scale);
|
||||
}
|
||||
__device__ __nv_bfloat162 preOp(__nv_bfloat162 x) const {
|
||||
if (IsPreOpIdentity)
|
||||
return x;
|
||||
else {
|
||||
float fxl, fxh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
return __floats2bfloat162_rn(fxl * scale, fxh * scale);
|
||||
}
|
||||
}
|
||||
__device__ __nv_bfloat16 postOp(__nv_bfloat16 x) const {
|
||||
return IsPostOpIdentity ? x : __float2bfloat16(__bfloat162float(x)*scale);
|
||||
}
|
||||
__device__ __nv_bfloat162 postOp(__nv_bfloat162 x) const {
|
||||
if (IsPostOpIdentity)
|
||||
return x;
|
||||
else {
|
||||
float fxl, fxh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
return __floats2bfloat162_rn(fxl * scale, fxh * scale);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
};
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
struct FuncTraits<FuncAvg<T>> {
|
||||
static constexpr bool IsPreOpIdentity = FuncAvg<T>::IsPreOpIdentity;
|
||||
static constexpr bool IsPostOpIdentity = FuncAvg<T>::IsPostOpIdentity;
|
||||
|
||||
__device__ static FuncAvg<T> make(int rankN) {
|
||||
return FuncAvg<T>(rankN);
|
||||
}
|
||||
template<typename U>
|
||||
__device__ static U preOp(FuncAvg<T> fn, U x) {
|
||||
return fn.preOp(x);
|
||||
}
|
||||
template<typename U>
|
||||
__device__ static U postOp(FuncAvg<T> fn, U x) {
|
||||
return fn.postOp(x);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // REDUCE_KERNEL_H_
|
||||
|
@@ -5,192 +5,85 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads-WARP_SIZE;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
const int chunkSize = stepSize * REDUCESCATTER_CHUNKSTEPS;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
ncclRing *ring = &ncclShmem.channel.ring;
|
||||
int const *ringRanks = ring->devUserRanks;
|
||||
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCESCATTER_CHUNKSTEPS : 1));
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2);
|
||||
const int nranks = ncclShmem.comm.nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff);
|
||||
|
||||
ncclPrimitives<UNROLL, REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS, T, 1, 1, 0, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
ssize_t chunkOffset = gridOffset + bid*realChunkSize;
|
||||
|
||||
/////////////// begin ReduceScatter steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(realChunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[nranks-1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
prims.send(thisInput+offset, nelem);
|
||||
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final result
|
||||
rankDest = ring->devUserRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
prims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels));
|
||||
realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
}
|
||||
};
|
||||
else if (Proto::Id == NCCL_PROTO_LL)
|
||||
realChunkSize = size-gridOffset < loopSize ? args->coll.lastChunkSize : chunkSize;
|
||||
else if (Proto::Id == NCCL_PROTO_LL128)
|
||||
realChunkSize = min(divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128, chunkSize);
|
||||
realChunkSize = int(realChunkSize);
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
ssize_t chunkOffset = gridOffset + bid*int(realChunkSize);
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
/////////////// begin ReduceScatter steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(realChunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ringRanks[nranks-1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
prims.send(offset, nelem);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
if (size-gridOffset < loopSize) {
|
||||
chunkSize = args->coll.lastChunkSize;
|
||||
}
|
||||
ssize_t chunkOffset = gridOffset + bid*chunkSize;
|
||||
|
||||
/////////////// begin ReduceScatter steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(chunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[nranks-1];
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
rankDest = ringRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final
|
||||
// result that we store in this data
|
||||
rankDest = ring->devUserRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final result
|
||||
rankDest = ringRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
prims.recvReduceCopy(offset, chunkOffset, nelem, /*postOp=*/true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
};
|
||||
|
||||
#include "prims_ll128.h"
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
|
||||
|
||||
ssize_t chunkOffset = gridOffset + bid*chunkSize;
|
||||
|
||||
/////////////// begin ReduceScatter steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(chunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[nranks-1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final
|
||||
// result that we store in this data
|
||||
rankDest = ring->devUserRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
|
||||
}
|
||||
}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_TREE, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_COLLNET, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
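For readers following the new runRing() above, the ring reduce-scatter schedule is: send the chunk owned by ringRanks[nranks-1], then receive/reduce/forward for k-2 steps, and finally reduce into the local output chunk, where the postOp (for example the ncclAvg 1/n scaling) is applied exactly once. A minimal host-side sketch of that schedule, under assumed simplifications, is:

// Illustration (not part of NCCL): the per-loop-iteration ring schedule of runRing().
#include <cstdio>

int main() {
  const int nranks = 4;
  int ringRanks[nranks] = {2, 3, 0, 1};  // hypothetical devUserRanks as seen by one rank
  // step 0: push the chunk of the "farthest" rank to the next GPU
  printf("send           chunk of rank %d\n", ringRanks[nranks-1]);
  // k-2 steps: receive a partial sum, add our contribution, forward it
  for (int j = 2; j < nranks; ++j)
    printf("recvReduceSend chunk of rank %d\n", ringRanks[nranks-j]);
  // last step: our own chunk arrives; reduce, apply postOp, store at chunkOffset
  printf("recvReduceCopy chunk of rank %d (postOp=true)\n", ringRanks[0]);
  return 0;
}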
@@ -5,89 +5,87 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncSendRecv, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* firstArgs) {
|
||||
struct ncclWorkElem* args = firstArgs;
|
||||
int tid = threadIdx.x;
|
||||
int group = 0;
|
||||
for (int s=0; s<NCCL_MAX_WORK_ELEMENTS; s++) {
|
||||
int nThreadsSegment = args->p2p.nThreads;
|
||||
if (nThreadsSegment == 0) return; // Nothing else to do
|
||||
int groupRecv = group;
|
||||
group += 1;
|
||||
int groupSend = group;
|
||||
group += nThreadsSegment > 128 ? 2 : 1;
|
||||
if (tid < nThreadsSegment) {
|
||||
const int nThreads = nThreadsSegment > 128 ? nThreadsSegment-WARP_SIZE : nThreadsSegment;
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWork *work) {
|
||||
int tid = threadIdx.x;
|
||||
int group = 0;
|
||||
const int rank = ncclShmem.comm.rank;
|
||||
const int nRanks = ncclShmem.comm.nRanks;
|
||||
using Proto = ProtoSimple<1, 1>;
|
||||
|
||||
// Compute pointers
|
||||
const T* sendbuff = (const T*)args->sendbuff;
|
||||
T* recvbuff = (T*)args->recvbuff;
|
||||
const ssize_t sendCount = args->p2p.sendCount;
|
||||
const ssize_t recvCount = args->p2p.recvCount;
|
||||
for (int s=0; s<NCCL_MAX_WORK_ELEMENTS; s++) {
|
||||
ncclWorkElem *args = &work->elems[s];
|
||||
int nThreadsSegment = args->p2p.nThreads;
|
||||
if (args->active == 0 || nThreadsSegment == 0) break;
|
||||
|
||||
const int delta = args->p2p.delta;
|
||||
if (delta == 0) {
|
||||
if (tid < nThreads && sendbuff != recvbuff) {
|
||||
// local copy : ReduceOrCopyMulti takes an int as number of elements,
|
||||
// so we split it in blocks of 1G elements.
|
||||
int blockSize = 1<<30;
|
||||
for (size_t offset=0; offset<sendCount; offset += blockSize) {
|
||||
size_t remaining = sendCount - offset;
|
||||
if (remaining < blockSize) blockSize = remaining;
|
||||
ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, 1>(tid, nThreads, 1, &sendbuff, 1, &recvbuff, blockSize);
|
||||
sendbuff += blockSize; recvbuff += blockSize;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
int nThreadsSplit = (nThreadsSegment - (nThreadsSegment > 128 ? WARP_SIZE : 0))/2;
|
||||
int groupRecv = group;
|
||||
group += Proto::calcGroupWidth(/*send=*/false, nThreadsSplit);
|
||||
int groupSend = group;
|
||||
group += Proto::calcGroupWidth(/*send=*/true, nThreadsSegment - nThreadsSplit);
|
||||
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/(sizeof(T)*NCCL_STEPS);
|
||||
if (tid < nThreadsSegment) {
|
||||
// Compute pointers
|
||||
T const* sendbuff = (const T*)args->sendbuff;
|
||||
T* recvbuff = (T*)args->recvbuff;
|
||||
ssize_t const sendCount = args->p2p.sendCount;
|
||||
ssize_t const recvCount = args->p2p.recvCount;
|
||||
int const delta = args->p2p.delta;
|
||||
|
||||
int nThreadsSplit = nThreads/2;
|
||||
if ((tid < nThreadsSplit) && recvCount >= 0) {
|
||||
const int chunkSize = args->p2p.recvChunkSize/sizeof(T);
|
||||
int peer = (comm->rank-delta+comm->nRanks)%comm->nRanks;
|
||||
int nt = nThreadsSplit;
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, 0, 1, FUNC>
|
||||
prims(tid, nt, &peer, NULL, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, groupRecv);
|
||||
|
||||
if (recvCount == 0) {
|
||||
prims.recv(recvbuff, 0);
|
||||
} else for (ssize_t offset = 0; offset < recvCount; offset += chunkSize) {
|
||||
int realChunkSize = min(chunkSize, recvCount-offset);
|
||||
ALIGN_SIZE(realChunkSize, nt*sizeof(uint64_t)/sizeof(T));
|
||||
int nelem = min(realChunkSize, recvCount-offset);
|
||||
prims.directRecv(recvbuff+offset, offset, nelem);
|
||||
}
|
||||
}
|
||||
if ((tid >= nThreadsSplit) && sendCount >= 0) {
|
||||
const int chunkSize = args->p2p.sendChunkSize/sizeof(T);
|
||||
int peer = (comm->rank+delta)%comm->nRanks;
|
||||
int nt = nThreads-nThreadsSplit;
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 0, 1, 1, FUNC>
|
||||
prims(tid-nThreadsSplit, nt, NULL, &peer, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, groupSend);
|
||||
|
||||
if (sendCount == 0) {
|
||||
prims.send(sendbuff, 0);
|
||||
} else for (ssize_t offset = 0; offset < sendCount; offset += chunkSize) {
|
||||
int realChunkSize = min(chunkSize, sendCount-offset);
|
||||
ALIGN_SIZE(realChunkSize, nt*sizeof(uint64_t)/sizeof(T));
|
||||
int nelem = min(realChunkSize, sendCount-offset);
|
||||
prims.directSend(sendbuff+offset, offset, nelem);
|
||||
}
|
||||
if (delta == 0) {
|
||||
if (sendbuff != recvbuff) {
|
||||
// local copy : ReduceOrCopyMulti takes an int as number of elements,
|
||||
// so we split it in blocks of 1G elements.
|
||||
int blockSize = 1<<30;
|
||||
for (size_t offset=0; offset<sendCount; offset += blockSize) {
|
||||
size_t remaining = sendCount - offset;
|
||||
if (remaining < blockSize) blockSize = remaining;
|
||||
ReduceOrCopyMulti<COLL_UNROLL, RedOp, T, 1, 1, 1, 1>(tid, nThreadsSegment, RedOp(), false, false, 1, &sendbuff, 1, &recvbuff, blockSize);
|
||||
sendbuff += blockSize;
|
||||
recvbuff += blockSize;
|
||||
}
|
||||
}
|
||||
}
|
||||
tid -= nThreadsSegment;
|
||||
if (tid < 0) return;
|
||||
args++;
|
||||
else {
|
||||
if ((tid < nThreadsSplit) && recvCount >= 0) {
|
||||
int const peer = (rank - delta + nRanks)%nRanks;
|
||||
int const t0 = 0;
|
||||
int const nt = nThreadsSplit;
|
||||
int const chunkSize = args->p2p.recvChunkSize/sizeof(T);
|
||||
Primitives<T, RedOp, FanAsymmetric<1, 0>, 1, Proto> prims
|
||||
(tid-t0, nt, &peer, nullptr, nullptr, recvbuff, groupRecv);
|
||||
ssize_t offset = 0;
|
||||
do {
|
||||
int nelem = roundUp(chunkSize, nt*(sizeof(uint64_t)/sizeof(T)));
|
||||
nelem = min(chunkSize, recvCount-offset);
|
||||
prims.directRecv(offset, nelem);
|
||||
offset += nelem;
|
||||
} while(offset < recvCount);
|
||||
}
|
||||
|
||||
if ((tid >= nThreadsSplit) && sendCount >= 0) {
|
||||
int const peer = (rank + delta)%nRanks;
|
||||
int const t0 = nThreadsSplit;
|
||||
int const nt = nThreadsSegment - nThreadsSplit;
|
||||
int const chunkSize = args->p2p.sendChunkSize/sizeof(T);
|
||||
Primitives<T, RedOp, FanAsymmetric<0, 1>, 1, Proto> prims
|
||||
(tid-t0, nt, nullptr, &peer, sendbuff, nullptr, groupSend);
|
||||
ssize_t offset = 0;
|
||||
do {
|
||||
int nelem = roundUp(chunkSize, nt*(sizeof(uint64_t)/sizeof(T)));
|
||||
nelem = min(chunkSize, sendCount-offset);
|
||||
prims.directSend(offset, offset, nelem);
|
||||
offset += nelem;
|
||||
} while(offset < sendCount);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
tid -= nThreadsSegment;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
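The send/recv path above splits each segment's threads in two: the lower half drives the receive primitive and the upper half drives the send primitive, each walking its buffer in chunkSize pieces. A sketch with made-up sizes (illustration only, not NCCL code):

// Illustration (not part of NCCL): thread split and chunk loop for one p2p work element.
#include <algorithm>
#include <cstdio>

int main() {
  const int nThreadsSegment = 256;
  const int nThreadsSplit = nThreadsSegment / 2;   // same split as above
  const long sendCount = 1 << 20;                  // elements to send (assumption)
  const int chunkSize = 64 * 1024;                 // elements per step (assumption)
  printf("recv side: threads 0..%d, send side: threads %d..%d\n",
         nThreadsSplit - 1, nThreadsSplit, nThreadsSegment - 1);
  long offset = 0;
  int steps = 0;
  do {                                             // same do/while shape as the directSend loop
    int nelem = (int)std::min((long)chunkSize, sendCount - offset);
    offset += nelem;
    steps++;
  } while (offset < sendCount);
  printf("sending %ld elements takes %d chunks\n", sendCount, steps);
  return 0;
}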
src/enqueue.cc (238 lines changed)
@@ -20,6 +20,31 @@
|
||||
(void*)NCCL_FUNC5(func, RING, redop, type), \
|
||||
(void*)NCCL_FUNC5(func, COLLNET, redop, type)
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A(func, redop) \
  (void*)NCCL_FUNC4(func, redop, int8_t), \
  (void*)NCCL_FUNC4(func, redop, uint8_t), \
  (void*)NCCL_FUNC4(func, redop, int32_t), \
  (void*)NCCL_FUNC4(func, redop, uint32_t), \
  (void*)NCCL_FUNC4(func, redop, int64_t), \
  (void*)NCCL_FUNC4(func, redop, uint64_t), \
  (void*)NCCL_FUNC4(func, redop, half), \
  (void*)NCCL_FUNC4(func, redop, float), \
  (void*)NCCL_FUNC4(func, redop, double), \
  (void*)NCCL_FUNC4(func, redop, __nv_bfloat16)
#define NCCL_FUNCS3B(func, redop) \
  (void*)NCCL_FUNC4(func, redop, int8_t), \
  (void*)NCCL_FUNC4(func, redop, int8_t), \
  (void*)NCCL_FUNC4(func, redop, int8_t), \
  (void*)NCCL_FUNC4(func, redop, int8_t), \
  (void*)NCCL_FUNC4(func, redop, int8_t), \
  (void*)NCCL_FUNC4(func, redop, int8_t), \
  (void*)NCCL_FUNC4(func, redop, int8_t), \
  (void*)NCCL_FUNC4(func, redop, int8_t), \
  (void*)NCCL_FUNC4(func, redop, int8_t), \
  (void*)NCCL_FUNC4(func, redop, int8_t)
#else
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, redop) \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
@ -41,17 +66,20 @@
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t)
|
||||
#endif
|
||||
|
||||
// Must be consistent with ncclRedOp_t -- but we only generate kernel for sums.
|
||||
#define NCCL_FUNCS2A(func) \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum)
|
||||
#define NCCL_FUNCS2B(func) \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum)
|
||||
|
||||
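The macros above expand into one flat table of kernel pointers, so adding __nv_bfloat16 simply appends a tenth type column. The table is looked up with a flat index; the sketch below shows the kind of nesting the indexing macro (FUNC_INDEX in the source) performs. The enum values and dimension sizes used here are assumptions for illustration only.

// Illustration (not part of NCCL): flat indexing into a (func, op, type, algo, proto) table.
#include <cstdio>

constexpr int nOps = 5, nTypes = 10, nAlgos = 3, nProtos = 3;   // assumed sizes

constexpr int funcIndex(int func, int op, int type, int algo, int proto) {
  return (((func * nOps + op) * nTypes + type) * nAlgos + algo) * nProtos + proto;
}

int main() {
  // e.g. AllReduce(func=4), Sum(op=0), bfloat16(type=9), RING(algo=1), SIMPLE(proto=2)
  // -- these positions are assumptions matching the expansion order above.
  printf("index = %d\n", funcIndex(4, 0, 9, 1, 2));
  return 0;
}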
// Must be consistent with the ncclFuncSet enum
|
||||
@ -154,16 +182,11 @@ static ncclResult_t setupLaunch(struct ncclQueueInfo* eqInfo, int usingCudaGraph
|
||||
channel->workFifo[(channel->workFifoTail-1)%NCCL_MAX_OPS].elems[0].active = 2;
|
||||
|
||||
if (c == 0) {
|
||||
// Find the first operation, choose the kernel accordingly and pass it as the first argument.
|
||||
// Note that changing cuda launch argument after capture is not supported by cudaGraph
|
||||
// As we inline the first coll directly, we can free it immediately.
|
||||
// Except P2P or aggregation cases
|
||||
struct ncclWork* work = channel->workFifo+((channel->workFifoTail-channel->workCount)%NCCL_MAX_OPS);
|
||||
struct ncclWorkElem* elem = work->elems;
|
||||
if (!usingCudaGraph) {
|
||||
params->func = ncclKerns[elem->funcIndex];
|
||||
memcpy(&comm->args, elem, sizeof(struct ncclWorkElem));
|
||||
}
|
||||
// As we inline the first coll directly, we can free it immediately.
|
||||
if (elem->funcIndex != FUNC_INDEX_P2P) elem->active = 0;
|
||||
if (elem->funcIndex != FUNC_INDEX_P2P && eqInfo->elemList->count() == 1) elem->active = 0;
|
||||
}
|
||||
|
||||
if (channel->gdrMemDesc) {
|
||||
@ -292,6 +315,7 @@ static ncclResult_t ncclLaunchProxy(struct ncclQueueInfo* eqInfo) {
|
||||
for (int r=0; r<eqInfo->maxChannels; r++) {
|
||||
struct ncclChannel* channel = comm->channels+r;
|
||||
channel->workCount = 0;
|
||||
channel->totalSize = 0;
|
||||
}
|
||||
comm->lastChannel = 0;
|
||||
NCCLCHECK(ncclProxyStart(comm));
|
||||
@ -323,8 +347,7 @@ ncclResult_t ncclLaunchReset(ncclComm_t comm) {
|
||||
// But we need to keep the current enqueue info for CUDA graph
|
||||
// Thus we need to creating a new enqueue info for the next run
|
||||
if (comm->usingCudaGraph) {
|
||||
NCCLCHECK(ncclCalloc(&comm->enqueueInfo, 1));
|
||||
comm->enqueueInfo->comm = comm;
|
||||
NCCLCHECK(ncclCreateQueueInfo(&comm->enqueueInfo, comm));
|
||||
} else {
|
||||
// If not in CUDA graph mode, we reuse the same info space
|
||||
NCCLCHECK(ncclResetQueueInfo(comm->enqueueInfo));
|
||||
@ -345,22 +368,29 @@ ncclResult_t ncclLaunchReset(ncclComm_t comm) {
|
||||
/* Enqueueing system : computation of kernel and proxy operations parameters */
|
||||
/*****************************************************************************/
|
||||
|
||||
static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
static inline ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetTypeSupport) {
  if (info->comm->collNetSupport > 0) {
    ncclRedOp_t netOp = info->op == ncclAvg ? ncclSum : info->op;
    NCCLCHECK(collNetReduceSupport(info->datatype, netOp, collNetTypeSupport));
  } else {
    *collNetTypeSupport = 0;
  }
  return ncclSuccess;
}

static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, int numPipeOps) {
  struct ncclComm* comm = info->comm;
  float minTime = 3600000000.0; // Hopefully no operation will take an hour to complete.
  // Find algorithm / protocol.
  info->algorithm = -1;
  info->protocol = -1;
  if (comm->nRanks == 1) return ncclSuccess;
  int nAlgos = NCCL_NUM_ALGORITHMS;
  // Check collNet support
  int collNetTypeSupport = 0;
  if (info->comm->collNetSupport > 0)
    NCCLCHECK(collNetReduceSupport(info->datatype, info->op, &collNetTypeSupport));
  for (int a=0; a<nAlgos; a++) {
    if (a == NCCL_ALGO_COLLNET && collNetTypeSupport != 1) continue;
    for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
      float time;
      NCCLCHECK(ncclTopoGetAlgoTime(info, a, p, &time));
      NCCLCHECK(ncclTopoGetAlgoTime(info, a, p, numPipeOps, &time));
      if (time >= 0 && time < minTime) {
        info->algorithm = a;
        info->protocol = p;
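The selection above is a brute-force argmin over (algorithm, protocol) pairs using the tuner's predicted time, skipping CollNet whenever the network cannot reduce the requested type/op. A standalone sketch with a made-up time table (illustration only):

// Illustration (not part of NCCL): pick the cheapest (algorithm, protocol) pair.
#include <cstdio>

int main() {
  const int nAlgos = 3, nProtos = 3;               // TREE, RING, COLLNET x LL, LL128, SIMPLE (assumed)
  float predicted[3][3] = { {12.f, 9.f, 8.f}, {11.f, 7.f, 6.5f}, {-1.f, -1.f, 5.f} };
  bool collNetTypeSupport = false;                 // e.g. a type/op the plugin cannot reduce
  float minTime = 3.6e9f;
  int bestAlgo = -1, bestProto = -1;
  for (int a = 0; a < nAlgos; a++) {
    if (a == 2 && !collNetTypeSupport) continue;   // a==2 stands for NCCL_ALGO_COLLNET
    for (int p = 0; p < nProtos; p++) {
      float t = predicted[a][p];
      if (t >= 0 && t < minTime) { minTime = t; bestAlgo = a; bestProto = p; }
    }
  }
  printf("algo=%d proto=%d time=%.1f\n", bestAlgo, bestProto, minTime);
  return 0;
}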
@ -397,7 +427,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
|
||||
}
|
||||
if (info->protocol == NCCL_PROTO_SIMPLE) {
|
||||
nt += WARP_SIZE; // Extra warp for sync
|
||||
if (info->algorithm == NCCL_ALGO_TREE) nt += WARP_SIZE;
|
||||
if (info->algorithm == NCCL_ALGO_TREE) nt += 3*WARP_SIZE;
|
||||
if (info->algorithm == NCCL_ALGO_COLLNET) nt += 3*WARP_SIZE;
|
||||
}
|
||||
info->nChannels = nc;
|
||||
@ -447,8 +477,14 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) {
|
||||
static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWorkElem* work, struct ncclProxyArgs* proxyArgs /* output */) {
|
||||
work->comm = info->comm->devComm;
|
||||
|
||||
int collNetTypeSupport = 0;
|
||||
// Check whether algo and proto have been preset
|
||||
if (info->nChannels > 0 && info->nThreads > 0) goto comp_next;
|
||||
NCCLCHECK(getCollNetSupport(info, &collNetTypeSupport));
|
||||
NCCLCHECK(getAlgoInfo(info, collNetTypeSupport, 1));
|
||||
|
||||
comp_next:
|
||||
// Set nstepsPerLoop and nchunksPerLoop
|
||||
NCCLCHECK(getAlgoInfo(info));
|
||||
NCCLCHECK(getPatternInfo(info));
|
||||
NCCLCHECK(getLoopInfo(info));
|
||||
|
||||
@ -478,10 +514,9 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo
|
||||
work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
|
||||
} else if (info->algorithm == NCCL_ALGO_COLLNET && info->protocol == NCCL_PROTO_SIMPLE) {
|
||||
// Optimize chunkSize / nSteps
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*32 && chunkSize > 262144) chunkSize /= 2;
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*16 && chunkSize > 131072) chunkSize /= 2;
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*64 && chunkSize > 131072) chunkSize /= 2;
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*8 && chunkSize > 65536) chunkSize /= 2;
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*8 && chunkSize > 32768) chunkSize /= 2;
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth/2 && chunkSize > 16384) chunkSize /= 2;
|
||||
// Use lastChunkSize as chunkSize
|
||||
work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
|
||||
} else if (info->protocol == NCCL_PROTO_LL) {
|
||||
@ -512,7 +547,9 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo
|
||||
proxyArgs->chunkSize = chunkSize;
|
||||
proxyArgs->protocol = info->protocol;
|
||||
proxyArgs->dtype = info->datatype;
|
||||
proxyArgs->redOp = (info->algorithm == NCCL_ALGO_COLLNET) ? info->op : ncclNumOps; // Only set redOp when using CollNet
|
||||
proxyArgs->redOp = info->algorithm != NCCL_ALGO_COLLNET ? ncclNumOps : // Only set redOp when using CollNet
|
||||
info->op == ncclAvg ? ncclSum : // Network sees avg as sum
|
||||
info->op;
|
||||
proxyArgs->pattern = info->pattern;
|
||||
proxyArgs->root = info->root;
|
||||
// This is used by P2P to reduce the receive buffer size. We don't use it in collectives
|
||||
@ -550,7 +587,7 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) {
|
||||
|
||||
// Compute cuda kernel arg and proxy arg templates
|
||||
struct ncclQueueElem* eqElem;
|
||||
NCCLCHECK(ncclAddQueueElem(comm->enqueueInfo, &eqElem));
|
||||
NCCLCHECK(comm->enqueueInfo->elemList->getNewElem(&eqElem));
|
||||
struct ncclWorkElem* work = &eqElem->work;
|
||||
eqElem->proxyArgs.nsubs = 1;
|
||||
NCCLCHECK(computeColl(info, work, &eqElem->proxyArgs));
|
||||
@ -573,6 +610,29 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static inline int findShortestChannel(ncclComm_t comm) {
  size_t minSize = SIZE_MAX;
  int minC = 0;
  for (int c=0; c<comm->nChannels; c++) {
    struct ncclChannel* channel = comm->channels+c;
    if (channel->totalSize < minSize) {
      minSize = channel->totalSize;
      minC = c;
    }
  }
  return minC;
}

static inline ncclResult_t getNextChannel(ncclComm_t comm, int* nextChannel) {
  if (comm->asyncAllocMode == ncclComm::SHORTEST_QUEUE) {
    *nextChannel = findShortestChannel(comm);
  } else {
    *nextChannel = comm->lastChannel % comm->nChannels;
    comm->lastChannel++;
  }
  return ncclSuccess;
}
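The two allocation modes above differ only in how the next channel index is produced; a toy comparison with made-up queue depths (illustration only):

// Illustration (not part of NCCL): round-robin vs. shortest-queue channel selection.
#include <cstddef>
#include <cstdio>

int main() {
  const int nChannels = 4;
  size_t totalSize[nChannels] = {1 << 20, 4 << 20, 2 << 20, 0};  // bytes queued per channel (assumed)
  int lastChannel = 6;
  // Round robin: ignore queue depth, just cycle.
  int rr = lastChannel % nChannels;                // -> 2
  // Shortest queue: pick the channel with the fewest bytes outstanding.
  int sq = 0;
  for (int c = 1; c < nChannels; c++)
    if (totalSize[c] < totalSize[sq]) sq = c;      // -> 3
  printf("round-robin=%d shortest-queue=%d\n", rr, sq);
  return 0;
}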
|
||||
// Dynamic enqueue code
|
||||
static ncclResult_t ncclEnqueueCollKernel(ncclComm_t comm, struct ncclQueueElem* eqElem) {
|
||||
struct ncclWorkElem* work = &eqElem->work;
|
||||
@ -600,9 +660,6 @@ static ncclResult_t ncclEnqueueCollKernel(ncclComm_t comm, struct ncclQueueElem*
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
#define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64)
|
||||
#define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */
|
||||
|
||||
ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) {
|
||||
if (comm->asyncOpCount == 0) {
|
||||
return ncclSuccess;
|
||||
@ -613,19 +670,47 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) {
|
||||
NCCLCHECK(ncclSetupCollKernel(info));
|
||||
} else {
|
||||
// Aggregation
|
||||
size_t channelSize = NCCL_AGG_CHANNEL_SIZE * comm->nRanks; // scale channel size based on nranks as latency increases
|
||||
size_t channelSize;
|
||||
if (comm->channelSize > 0) {
|
||||
channelSize = comm->channelSize;
|
||||
} else if (comm->collNetSupport && comm->asyncOps[0].coll == ncclFuncAllReduce) {
|
||||
channelSize = 256 * 1024;
|
||||
} else {
|
||||
channelSize = NCCL_AGG_CHANNEL_SIZE * std::min(16, comm->nRanks); // scale channel size based on nranks as latency increases
|
||||
}
|
||||
// Reduce the per-channel size if we cannot fully utilize the channels
|
||||
while (comm->asyncTotalSize < channelSize * comm->nChannels && channelSize > NCCL_MIN_CHANNEL_SIZE) channelSize /= 2;
|
||||
int channelUsed = 0;
|
||||
ncclFunc_t commonColl = ncclNumFuncs;
|
||||
int fastPath = 1;
|
||||
int allCollNetSupport = comm->collNetSupport;
|
||||
for (int c = 0; c < comm->asyncOpCount; c++) {
|
||||
struct ncclInfo* info = comm->asyncOps+c;
|
||||
info->nChannels = std::min((int)DIVUP(info->nBytes, channelSize), comm->nChannels); // assign number of channels
|
||||
info->nChannels = std::min(std::max(1, (int)DIVUP(info->nBytes, channelSize)), comm->nChannels); // assign number of channels
|
||||
channelUsed += info->nChannels;
|
||||
// We can use fast path if all collectives are the same
|
||||
if (commonColl == ncclNumFuncs) commonColl = info->coll;
|
||||
else if (commonColl != info->coll) fastPath = 0;
|
||||
else if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport));
|
||||
}
|
||||
// Compute algo, proto, nthreads for the entire kernel
|
||||
struct ncclInfo total;
|
||||
total.comm = comm;
|
||||
total.coll = commonColl;
|
||||
total.nBytes = comm->asyncTotalSize;
|
||||
total.nChannels = std::min(channelUsed, comm->nChannels);
|
||||
int perChannelOps = DIVUP(channelUsed, total.nChannels);
|
||||
if (fastPath) NCCLCHECK(getAlgoInfo(&total, allCollNetSupport, perChannelOps));
|
||||
for (int c = 0; c < comm->asyncOpCount; c++) {
|
||||
struct ncclInfo* info = comm->asyncOps+c;
|
||||
if (fastPath) {
|
||||
info->algorithm = total.algorithm;
|
||||
info->protocol = total.protocol;
|
||||
info->nThreads = total.nThreads;
|
||||
}
|
||||
NCCLCHECK(ncclSetupCollKernel(info));
|
||||
}
|
||||
// If we wrap around on channels, then the inlined op on channel 0 is not the last one on this channel
|
||||
// Then we need to change active from 2 to 1
|
||||
if (channelUsed > comm->nChannels) comm->args.active = 1;
|
||||
comm->args.active = 0; // disable inline argument
|
||||
}
|
||||
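The aggregation path above boils down to two small pieces of arithmetic: shrink the per-channel quantum until the aggregated bytes can still keep every channel busy, then give each operation a proportional share of channels, clamped to [1, nChannels]. A standalone sketch with made-up sizes (illustration only, as discussed above):

// Illustration (not part of NCCL): per-channel size selection and channel assignment.
#include <algorithm>
#include <cstdio>

static long divUp(long x, long y) { return (x + y - 1) / y; }

int main() {
  const long MIN_CHANNEL_SIZE = 64 * 1024;       // stand-in for NCCL_MIN_CHANNEL_SIZE (assumption)
  const int nChannels = 16;
  long channelSize = 4 << 20;                    // initial per-channel quantum (assumption)
  long asyncTotalSize = 24L << 20;               // sum of all aggregated ops (assumption)
  // Halve the quantum until all channels can be fed.
  while (asyncTotalSize < channelSize * nChannels && channelSize > MIN_CHANNEL_SIZE)
    channelSize /= 2;
  // Each op then gets ceil(nBytes/channelSize) channels, clamped to [1, nChannels].
  long opBytes[3] = {16L << 20, 6L << 20, 2L << 20};
  for (int i = 0; i < 3; i++) {
    int ch = (int)std::min((long)nChannels, std::max(1L, divUp(opBytes[i], channelSize)));
    printf("op %d: %ld bytes -> %d channel(s) of %ld bytes\n", i, opBytes[i], ch, channelSize);
  }
  return 0;
}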
// Reset counters
|
||||
comm->asyncOpCount = 0;
|
||||
@ -662,7 +747,7 @@ static ncclResult_t ncclSaveP2p(struct ncclInfo* info) {
|
||||
}
|
||||
}
|
||||
}
|
||||
NCCLCHECK(enqueueP2pInfo(comm->p2pSends+info->root, (void*)info->sendbuff, nBytes));
|
||||
NCCLCHECK(ncclSaveP2pInfo(comm->p2pSends[info->root], (void*)info->sendbuff, nBytes));
|
||||
comm->p2pSendCount++;
|
||||
} else {
|
||||
if (peer != comm->rank) {
|
||||
@ -675,15 +760,22 @@ static ncclResult_t ncclSaveP2p(struct ncclInfo* info) {
|
||||
}
|
||||
}
|
||||
}
|
||||
NCCLCHECK(enqueueP2pInfo(comm->p2pRecvs+info->root, info->recvbuff, nBytes));
|
||||
NCCLCHECK(ncclSaveP2pInfo(comm->p2pRecvs[info->root], info->recvbuff, nBytes));
|
||||
comm->p2pRecvCount++;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static int getSegment(int delta, struct ncclWork* work) {
  for (int s=0; s<NCCL_MAX_WORK_ELEMENTS && work->elems[s].p2p.delta != delta; s++) {
    if (work->elems[s].p2p.nThreads == 0) return s;
enum { COLL_SEGMENT=0, P2P_SEGMENT=1 };
static int getSegment(int type, int delta, struct ncclWork* work) {
  if (type == P2P_SEGMENT) { // P2P
    for (int s=0; s<NCCL_MAX_WORK_ELEMENTS && work->elems[s].p2p.delta != delta; s++) {
      if (work->elems[s].active == 0) return s;
    }
  } else { // aggregation
    for (int s=0; s<NCCL_MAX_WORK_ELEMENTS; s++) {
      if (work->elems[s].active == 0) return s;
    }
  }
  return -1;
}
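getSegment() above only finds a free slot inside the current ncclWork; the per-segment thread sizing that enqueueSegOp() then applies for P2P segments can be reproduced with plain integers (illustration only, not NCCL code):

// Illustration (not part of NCCL): once s+1 p2p segments share one ncclWork, the
// per-segment thread count is shrunk so the whole work fits a 512-thread block,
// plus one extra warp when at least 128 threads remain.
#include <cstdio>

int main() {
  const int WARP = 32;
  for (int nsegments = 1; nsegments <= 8; nsegments++) {
    int nThreads = 512;
    while (nsegments * nThreads > 512) nThreads /= 2;
    if (nThreads >= 128) nThreads += WARP;
    printf("%d segment(s) -> %d threads each\n", nsegments, nThreads);
  }
  return 0;
}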
@ -702,16 +794,19 @@ static ncclResult_t computeP2pWorkElem(struct ncclInfo* info /* input */, struct
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t enqueueP2pOp(struct ncclWorkElem* elem /* input */, struct ncclWork* work, int s) {
|
||||
static ncclResult_t enqueueSegOp(int type, struct ncclWorkElem* elem /* input */, struct ncclWork* work, int s) {
|
||||
// Copy element into corresponding segment of ncclWork
|
||||
memcpy(work->elems+s, elem, sizeof(struct ncclWorkElem));
|
||||
work->elems[s].active = 1;
|
||||
|
||||
// Determine nThreads at dynamic time
|
||||
const int nsegments = s+1;
|
||||
int nThreads = 512;
|
||||
while (nsegments*nThreads > 512) nThreads /= 2;
|
||||
if (nThreads >= 128) nThreads += WARP_SIZE;
|
||||
for (int i=0; i<nsegments; i++) work->elems[i].p2p.nThreads = nThreads;
|
||||
if (type == P2P_SEGMENT) {
|
||||
const int nsegments = s+1;
|
||||
int nThreads = 512;
|
||||
while (nsegments*nThreads > 512) nThreads /= 2;
|
||||
if (nThreads >= 128) nThreads += WARP_SIZE;
|
||||
for (int i=0; i<nsegments; i++) work->elems[i].p2p.nThreads = nThreads;
|
||||
}
|
||||
|
||||
return ncclSuccess;
|
||||
}
|
||||
@ -725,9 +820,9 @@ ncclResult_t ncclEnqueueP2pKernel(struct ncclComm* comm, struct ncclQueueElem* e
|
||||
int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
|
||||
struct ncclWork* w = channel->workFifo+opIndex;
|
||||
int segment = -1;
|
||||
if (channel->workCount && w->elems[0].funcIndex == FUNC_INDEX_P2P && w->elems[NCCL_MAX_WORK_ELEMENTS-1].p2p.nThreads == 0) {
|
||||
if (channel->workCount && w->elems[0].funcIndex == FUNC_INDEX_P2P && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) {
|
||||
// Try to pack more segments into a single operation
|
||||
segment = getSegment(workElem->p2p.delta, w);
|
||||
segment = getSegment(P2P_SEGMENT, workElem->p2p.delta, w);
|
||||
}
|
||||
if (segment == -1) {
|
||||
NCCLCHECK(getNextOp(channel, &w, NULL));
|
||||
@ -736,7 +831,7 @@ ncclResult_t ncclEnqueueP2pKernel(struct ncclComm* comm, struct ncclQueueElem* e
|
||||
|
||||
// store work element into FIFO
|
||||
NCCLCHECK(ncclProxySaveP2p(comm, proxyArgs));
|
||||
NCCLCHECK(enqueueP2pOp(workElem, w, segment));
|
||||
NCCLCHECK(enqueueSegOp(P2P_SEGMENT, workElem, w, segment));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@ -744,7 +839,7 @@ ncclResult_t ncclSetupP2pKernel(struct ncclInfo* info) {
|
||||
ncclComm* comm = info->comm;
|
||||
// Compute cuda kernel arg and proxy arg templates
|
||||
struct ncclQueueElem* eqElem;
|
||||
NCCLCHECK(ncclAddQueueElem(comm->enqueueInfo, &eqElem));
|
||||
NCCLCHECK(comm->enqueueInfo->elemList->getNewElem(&eqElem));
|
||||
// The proxy code will set and tune the send/recv chunk size, make sure to run it first.
|
||||
NCCLCHECK(ncclProxyComputeP2p(info, &eqElem->proxyArgs));
|
||||
NCCLCHECK(computeP2pWorkElem(info, &eqElem->work));
|
||||
@ -760,11 +855,51 @@ ncclResult_t ncclSetupP2pKernel(struct ncclInfo* info) {
|
||||
// The CUDA kernel does not use the inlined first work element as fastpath argument
|
||||
if (params->func == NULL) {
|
||||
params->func = ncclKerns[eqElem->work.funcIndex];
|
||||
memcpy(&comm->args, &eqElem->work, sizeof(struct ncclWorkElem));
|
||||
comm->args.comm = eqElem->work.comm;
|
||||
comm->args.active = 0;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem* eqElem) {
  struct ncclWorkElem* work = &eqElem->work;
  struct ncclProxyArgs* proxyArgs = &eqElem->proxyArgs;

  int nChannels = work->coll.nChannels;
  size_t channelSize = work->coll.count*ncclTypeSize(proxyArgs->dtype)/work->coll.nChannels;
  for (int bid=0; bid<nChannels; bid++) {
    int channelId;
    NCCLCHECK(getNextChannel(comm, &channelId));
    struct ncclChannel* channel = comm->channels+channelId;

    // Proxy
    proxyArgs->subs[0].channel = channel;
    proxyArgs->opCount = comm->collOpCount;
    proxyArgs->commOpCount = comm->opCount;
    if (proxyArgs->subs[0].nsteps) NCCLCHECK(ncclProxySaveColl(proxyArgs, comm->nRanks));

    // Try to reuse last work if not full yet
    work->coll.bid = bid % nChannels;
    int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
    struct ncclWork* w = channel->workFifo+opIndex;
    int segment = -1;
    if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) {
      // Try to pack more segments into a single operation
      segment = getSegment(COLL_SEGMENT, 0, w);
    }
    if (segment == -1) {
      NCCLCHECK(getNextOp(channel, &w, NULL));
      segment = 0;
    }

    // store work element into FIFO
    NCCLCHECK(enqueueSegOp(COLL_SEGMENT, work, w, segment));
    channel->totalSize += channelSize;
  }
  comm->collOpCount++;
  return ncclSuccess;
}
|
||||
|
||||
template<int USING_CUDA_GRAPH>
|
||||
void CUDART_CB ncclEnqueueHostSetup(void* arg) {
|
||||
ncclResult_t ret;
|
||||
@ -772,14 +907,17 @@ void CUDART_CB ncclEnqueueHostSetup(void* arg) {
|
||||
ncclComm_t comm = eqInfo->comm;
|
||||
|
||||
// Iterate through the element list
|
||||
struct ncclQueueElem* eqElem = eqInfo->elemList.head;
|
||||
while (eqElem != eqInfo->elemList.tail) { // The queue always has one extra element
|
||||
struct ncclQueueElem* eqElem = eqInfo->elemList->begin();
|
||||
while (eqElem != NULL) {
|
||||
if (eqElem->work.funcIndex == FUNC_INDEX_P2P) {
|
||||
NCCLCHECKGOTO(ncclEnqueueP2pKernel(comm, eqElem), ret, cb_end);
|
||||
} else if (eqInfo->elemList->count() > 1) {
|
||||
// We have more than one operation, hence aggregating
|
||||
NCCLCHECKGOTO(ncclEnqueueAsyncKernel(comm, eqElem), ret, cb_end);
|
||||
} else {
|
||||
NCCLCHECKGOTO(ncclEnqueueCollKernel(comm, eqElem), ret, cb_end);
|
||||
}
|
||||
eqElem = eqElem->next;
|
||||
eqElem = eqInfo->elemList->getNext();
|
||||
}
|
||||
|
||||
NCCLCHECKGOTO(setupLaunch(eqInfo, USING_CUDA_GRAPH), ret, cb_end);
|
||||
|
@@ -388,7 +388,9 @@ ncclResult_t ncclTopoComputePaths(struct ncclTopoSystem* system, struct ncclPeer
|
||||
struct ncclPeerInfo* srcInfo = peerInfos+system->nodes[GPU].nodes[p].gpu.rank;
|
||||
int shm;
|
||||
NCCLCHECK(ncclTransports[TRANSPORT_SHM].canConnect(&shm, system, NULL, srcInfo, dstInfo));
|
||||
if (shm == 0) {
|
||||
int p2p;
|
||||
NCCLCHECK(ncclTransports[TRANSPORT_P2P].canConnect(&p2p, system, NULL, srcInfo, dstInfo));
|
||||
if (shm == 0 && p2p == 0) {
|
||||
// Mark this peer as inaccessible. We'll trim it later.
|
||||
system->nodes[GPU].nodes[p].paths[GPU][g].count = 0;
|
||||
}
|
||||
|
@ -707,8 +707,10 @@ ncclResult_t ncclTopoGetXmlFromGraphs(int ngraphs, struct ncclTopoGraph** graphs
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
float speedArray[] = { 42.0, 30.0, 24.0, 21.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
#define NSPEEDS (sizeof(speedArray)/sizeof(float))
float speedArrayIntra[] = { 44.0, 30.0, 22.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0 };
float speedArrayInter[] = { 48.0, 30.0, 24.0, 22.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
#define NSPEEDSINTRA (sizeof(speedArrayIntra)/sizeof(float))
#define NSPEEDSINTER (sizeof(speedArrayInter)/sizeof(float))
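The graph search below starts from the fastest entry of whichever table applies (intra-node or inter-node) and walks down until the candidate speed does not exceed the detected maximum link width. A trivial sketch of that starting-point scan, with a made-up maxWidth (illustration only):

// Illustration (not part of NCCL): pick the starting index in a speed table.
#include <cstdio>

int main() {
  float speedArrayInter[] = { 48.0, 30.0, 24.0, 22.0, 18.0, 15.0, 12.0, 10.0,
                              9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
  const int nspeeds = sizeof(speedArrayInter)/sizeof(float);
  float maxWidth = 21.0f;            // e.g. a PCIe-limited system (assumption)
  int speedIndex = 0;
  while (speedArrayInter[speedIndex] > maxWidth && speedIndex < nspeeds-1) speedIndex++;
  printf("start at index %d (%.1f GB/s)\n", speedIndex, speedArrayInter[speedIndex]);
  return 0;
}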
|
||||
ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph) {
|
||||
int ngpus = system->nodes[GPU].count;
|
||||
@ -738,15 +740,23 @@ ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph
|
||||
// SPLIT_TREE works better on older archs.
|
||||
int ccMin;
|
||||
NCCLCHECK(ncclTopoGetCompCap(system, &ccMin, NULL));
|
||||
if (ccMin < 80 && graph->pattern == NCCL_TOPO_PATTERN_BALANCED_TREE) graph->pattern = NCCL_TOPO_PATTERN_SPLIT_TREE;
|
||||
|
||||
struct ncclTopoGraph tmpGraph;
|
||||
memcpy(&tmpGraph, graph, sizeof(struct ncclTopoGraph));
|
||||
|
||||
// First try crossnic, then decrease speed and finally increase speedIntra.
|
||||
int nspeeds = 0;
|
||||
float* speedArray = NULL;
|
||||
if (system->nodes[NET].count == 0) {
|
||||
nspeeds = NSPEEDSINTRA;
|
||||
speedArray = speedArrayIntra;
|
||||
} else {
|
||||
nspeeds = NSPEEDSINTER;
|
||||
speedArray = speedArrayInter;
|
||||
}
|
||||
int pass = 1;
|
||||
int speedIndex = 0;
|
||||
while (speedArray[speedIndex] > system->maxWidth && speedIndex < NSPEEDS-1) speedIndex++;
|
||||
while (speedArray[speedIndex] > system->maxWidth && speedIndex < nspeeds-1) speedIndex++;
|
||||
tmpGraph.speedIntra = tmpGraph.speedInter = speedArray[speedIndex];
|
||||
int64_t globalTimeout = NCCL_SEARCH_GLOBAL_TIMEOUT;
|
||||
|
||||
@ -813,12 +823,12 @@ search:
|
||||
tmpGraph.crossNic = 0;
|
||||
|
||||
// Decrease speed until we find a solution
|
||||
if ((speedIndex < NSPEEDS-1) && (graph->nChannels == 0 || (speedArray[speedIndex+1]/graph->speedInter > .49))) {
|
||||
if ((speedIndex < nspeeds-1) && (graph->nChannels == 0 || (speedArray[speedIndex+1]/graph->speedInter > .49))) {
|
||||
tmpGraph.speedInter = tmpGraph.speedIntra = speedArray[++speedIndex];
|
||||
goto search;
|
||||
}
|
||||
speedIndex = 0;
|
||||
while (speedArray[speedIndex] > system->maxWidth && speedIndex < NSPEEDS-1) speedIndex++;
|
||||
while (speedArray[speedIndex] > system->maxWidth && speedIndex < nspeeds-1) speedIndex++;
|
||||
tmpGraph.speedIntra = tmpGraph.speedInter = speedArray[speedIndex];
|
||||
|
||||
}
|
||||
@ -829,7 +839,7 @@ done:
|
||||
time = -1;
|
||||
memcpy(&tmpGraph, graph, sizeof(tmpGraph));
|
||||
speedIndex = 0;
|
||||
while (speedArray[speedIndex] > graph->speedInter && speedIndex < NSPEEDS-1) speedIndex++;
|
||||
while (speedArray[speedIndex] > graph->speedInter && speedIndex < nspeeds-1) speedIndex++;
|
||||
tmpGraph.speedIntra = tmpGraph.speedInter = speedArray[speedIndex];
|
||||
tmpGraph.minChannels = graph->nChannels;
|
||||
pass = 2;
|
||||
|
@ -583,7 +583,10 @@ ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** sy
|
||||
char* xmlTopoFile = getenv("NCCL_TOPO_FILE");
|
||||
if (xmlTopoFile) {
|
||||
INFO(NCCL_ENV, "NCCL_TOPO_FILE set by environment to %s", xmlTopoFile);
|
||||
NCCLCHECK(ncclTopoGetXmlFromFile(xmlTopoFile, xml));
|
||||
NCCLCHECK(ncclTopoGetXmlFromFile(xmlTopoFile, xml, 1));
|
||||
} else {
|
||||
// Try default XML topology location
|
||||
NCCLCHECK(ncclTopoGetXmlFromFile("/var/run/nvidia-topologyd/virtualTopology.xml", xml, 0));
|
||||
}
|
||||
if (xml->maxIndex == 0) {
|
||||
// Create top tag
|
||||
@ -691,7 +694,7 @@ ncclResult_t ncclTopoCpuType(struct ncclTopoSystem* system, int* arch, int* vend
|
||||
|
||||
NCCL_PARAM(IgnoreCpuAffinity, "IGNORE_CPU_AFFINITY", 0);
|
||||
|
||||
ncclResult_t ncclTopoSetAffinity(struct ncclTopoSystem* system, int rank) {
ncclResult_t ncclTopoGetCpuAffinity(struct ncclTopoSystem* system, int rank, cpu_set_t* affinity) {
  struct ncclTopoNode* cpu = NULL, *gpu = NULL;
  for (int g=0; g<system->nodes[GPU].count; g++) {
    if (system->nodes[GPU].nodes[g].gpu.rank == rank) {
@@ -744,12 +747,13 @@ ncclResult_t ncclTopoSetAffinity(struct ncclTopoSystem* system, int rank) {
  // Use a subset of the GPU affinity set
  CPU_AND(&finalMask, &mask, &cpuMask);

  memcpy(affinity, &finalMask, sizeof(cpu_set_t));

  // If there is a non empty set, use it to set affinity
  if (CPU_COUNT(&finalMask)) {
    char affinityStr[sizeof(cpu_set_t)*2];
    NCCLCHECK(ncclCpusetToStr(&finalMask, affinityStr));
    INFO(NCCL_INIT, "Setting affinity for GPU %d to %s", gpu->gpu.dev, affinityStr);
    SYSCHECK(sched_setaffinity(0, sizeof(cpu_set_t), &finalMask), "sched_setaffinity");
  }
  return ncclSuccess;
}
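The affinity handling above is ordinary cpu_set_t arithmetic; a self-contained example of the same pattern (intersect the process mask with a NUMA-local mask and only narrow the affinity when the intersection is non-empty) is shown below. The CPU range is made up and the example is not NCCL code.

// Illustration (not part of NCCL): narrow affinity to the NUMA-local CPUs, if any overlap.
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <sched.h>
#include <stdio.h>

int main() {
  cpu_set_t processMask, numaMask, finalMask;
  CPU_ZERO(&numaMask);
  for (int c = 0; c < 8; c++) CPU_SET(c, &numaMask);        // pretend CPUs 0-7 are NUMA-local
  if (sched_getaffinity(0, sizeof(cpu_set_t), &processMask) != 0) return 1;
  CPU_AND(&finalMask, &processMask, &numaMask);
  if (CPU_COUNT(&finalMask)) {
    printf("narrowing affinity to %d CPUs\n", CPU_COUNT(&finalMask));
    if (sched_setaffinity(0, sizeof(cpu_set_t), &finalMask) != 0) return 1;
  } else {
    printf("no overlap with the NUMA-local set, keeping current affinity\n");
  }
  return 0;
}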
|
@ -9,12 +9,11 @@
|
||||
|
||||
#include "graph.h"
|
||||
#include "core.h"
|
||||
#include <sched.h>
|
||||
|
||||
#define LOC_WIDTH 5000.0
|
||||
#define SM60_NVLINK_WIDTH 18.0
|
||||
#define SM70_NVLINK_WIDTH 21.0
|
||||
#define SM80_NVLINK_WIDTH 21.0
|
||||
#define SM70_NVLINK_WIDTH 22.0
|
||||
#define SM80_NVLINK_WIDTH 22.0
|
||||
#define SM86_NVLINK_WIDTH 12.0
|
||||
#define PCI_WIDTH 12.0 // PCI Gen3 x16
|
||||
#define QPI_WIDTH 6.0
|
||||
|
@ -60,20 +60,19 @@ static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { { 4.4,
|
||||
#define NCCL_HW_PCI 1
|
||||
#define NCCL_HW_NET 2
|
||||
// Tree/Simple is the latency a 256kB chunk, which is ~ base lat + 256k/12GB/s (+ 256k/12GB/s for the network).
|
||||
static const float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] =
|
||||
static float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] =
|
||||
{ /* NVLINK */
|
||||
{ /* Tree (LL/LL128/Simple)*/ { .52, 1.25, 28 }, /* Ring (LL/LL128/Simple)*/ { .47, 1.9, 3.4 }, /* CollNet (LL/LL128/Simple)*/ { .5, 1.2, 4.0 } },
|
||||
{ /* Tree (LL/LL128/Simple)*/ { .52, 1.25, 28 }, /* Ring (LL/LL128/Simple)*/ { .47, 1.9, 3.4 }, /* CollNet (LL/LL128/Simple)*/ { .5, 1.2, 8.0 } },
|
||||
/* PCI */
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 1.0, 1.9, 28 }, /* Ring (LL/LL128/Simple)*/ { 1.0, 2.5, 5.7 }, /* CollNet (LL/LL128/Simple)*/ { 1.0, 1.9, 5.5 } },
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 1.0, 1.9, 28 }, /* Ring (LL/LL128/Simple)*/ { 1.0, 2.5, 5.7 }, /* CollNet (LL/LL128/Simple)*/ { 1.0, 1.9, 8.0 } },
|
||||
/* NET */
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 5.0, 8.5, 28 }, /* Ring (LL/LL128/Simple)*/ { 2.7, 4.0, 9.6 }, /* CollNet (LL/LL128/Simple)*/ { 5.0, 5.0, 10.7 } }
|
||||
};
|
||||
|
||||
// LL128 max BW (per channel) for the different collectives
|
||||
// ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce
|
||||
static const double ll128MaxBwPerCh[NCCL_NUM_FUNCTIONS] = { 18.8, 12.0, 18.3, 15.2, 16.9 };
|
||||
// LL128 max BW per channel
|
||||
static const double ll128MaxBwPerCh = 20.0;
|
||||
static const double llMaxBws[2][3] = { /* Volta-N1/Intel-N2/Intel-N4) */ {39.0, 39.0, 20.4}, /* Ampere-N1/AMD-N2/AMD-N4) */ {87.7, 22.5 /*avg of ring & tree*/, 19.0} };
|
||||
static const double perChMaxTreeBws[2][3] = { /* Volta (N1/N2/N4) */ {26.5, 18.5, 10.0}, /* Ampere (N1/N2/N4) */ {24.0, 22.5, 16.0} };
|
||||
static const double perChMaxTreeBws[2][3] = { /* Volta (N1/N2/N4) */ {26.5, 18.5, 10.0}, /* Ampere (N1/N2/N4) */ {24.0, 23.6, 17.8} };
|
||||
|
||||
ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph) {
|
||||
int simpleDefaultThreads = (ringGraph->speedIntra*ringGraph->nChannels <= PCI_WIDTH) ? 256 : NCCL_SIMPLE_MAX_NTHREADS;
|
||||
@ -100,6 +99,8 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
int index1 = nNodes == 1 ? compCap80 : cpuVendor == NCCL_TOPO_CPU_VENDOR_AMD ? 1 : 0;
|
||||
double llMaxBw = llMaxBws[index1][index2];
|
||||
double perChMaxTreeBw = perChMaxTreeBws[compCap80][index2];
|
||||
// De-penalize Tree/Simple latency on Power systems to favor Tree than Ring
|
||||
if (cpuArch == NCCL_TOPO_CPU_ARCH_POWER) hwLat[NCCL_HW_PCI][NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = hwLat[NCCL_HW_PCI][NCCL_ALGO_RING][NCCL_PROTO_SIMPLE];
|
||||
float ppn = (float)nRanks / nNodes; // if ppn < 2, then we are sending/receiving at the same GPU through the NIC, apply some bw discount
|
||||
|
||||
struct ncclTopoGraph* graphs[NCCL_NUM_ALGORITHMS] = { treeGraph, ringGraph, collNetGraph };
|
||||
@ -125,11 +126,10 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
// Various model refinements
|
||||
if (compCap80) busBw = std::min(busBw, 235.0f);
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL) { busBw = std::min(llMaxBw, busBw * ((nNodes > 1 || coll == ncclFuncAllReduce || coll == ncclFuncReduce) ? 1.0/4.0 : 1.0/3.0)); }
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh[coll]*graphs[a]->nChannels);
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh*graphs[a]->nChannels);
|
||||
if (a == NCCL_ALGO_TREE) busBw = std::min(busBw*.92, graphs[a]->nChannels*perChMaxTreeBw);
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL) busBw = std::min(busBw*1.0/3.8, llMaxBw);
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 0.915 /*120.0/128.0*/), ll128MaxBwPerCh[coll]*graphs[a]->nChannels);
|
||||
if (a == NCCL_ALGO_COLLNET) busBw *= .9;
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 120.0/128.0), ll128MaxBwPerCh*graphs[a]->nChannels);
|
||||
if (a == NCCL_ALGO_COLLNET && p != NCCL_PROTO_SIMPLE) busBw = 0; // Oneshot CollNet only supports Simple
|
||||
|
||||
// Convert bus BW to algorithm BW
|
||||
@ -157,7 +157,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
2 * ((nRanks/nNodes-1) * intraLat + log2i(nNodes) * interLat);
|
||||
} else {
|
||||
comm->latencies[coll][a][p] +=
|
||||
2 * (nRanks/nNodes-1) * intraLat + interLat;
|
||||
2 * (std::min(1, (nRanks/nNodes-1)) * intraLat + (nRanks/nNodes-1) * 0.5) + interLat; // Add 0.5 arity serialization latency
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -266,11 +266,11 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
// factor is not ideal but works quite well. Powers of two, 64 B to 256MB.
|
||||
static float treeCorrectionFactor[NCCL_NUM_PROTOCOLS][23] = {
|
||||
{ 1.0, 1.0, 1.0, 1.0, .9, .8, .7, .7, .7, .7, .6, .5, .4, .4, .5, .6, .7, .8, .9, 1.0, 1.0, 1.0, 1.0 },
|
||||
{ 1.0, 1.0, 1.0, 1.0, 1.0, .9, .8, .8, .8, .7, .6, .6, .6, .5, .6, .6, .7, .7, .8, .9, .9, .92, .92 },
|
||||
{ 1.0, 1.0, 1.0, 1.0, 1.0, .9, .8, .8, .8, .7, .6, .6, .6, .6, .6, .6, .8, .9, .9, .9, .9, 1.0, 1.0 },
|
||||
{ .9, .9, .9, .9, .9, .9, .9, .8, .7, .6, .6, .5, .5, .5, .5, .6, .7, .8, .7, .7, .8, .9, .9 }
|
||||
};
|
||||
|
||||
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, float* time) {
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time) {
  float bw = info->comm->bandwidths[info->coll][algorithm][protocol];
  float lat = info->comm->latencies[info->coll][algorithm][protocol];
  if (bw == 0) {
@@ -281,6 +281,8 @@ ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int proto
  if (info->nChannels != 0) bw = bw / info->comm->nChannels * info->nChannels;
  if (algorithm == NCCL_ALGO_RING && protocol == NCCL_PROTO_SIMPLE && info->comm->nNodes > 1
      && info->coll == ncclFuncAllReduce && info->nBytes >= info->comm->nRanks/16.0*65536) lat *= 1.9; // Plateau effect of ring
  *time = lat + (info->nBytes) / (1000 * bw);
  // Tree pipelining saves latency in aggregation cases
  int latCount = algorithm == NCCL_ALGO_RING ? numPipeOps : DIVUP(numPipeOps, NCCL_MAX_WORK_ELEMENTS);
  *time = lat * latCount + (info->nBytes) / (1000 * bw);
  return ncclSuccess;
}
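As a concrete, made-up evaluation of the updated formula above (illustration only, not NCCL code): with bandwidth in GB/s and sizes in bytes, nBytes/(1000*bw) comes out in microseconds, matching the latency term.

// Illustration (not part of NCCL): time = lat * latCount + nBytes / (1000 * bw).
#include <cstdio>

int main() {
  float lat = 10.0f;                 // us, assumption
  float bw = 12.0f;                  // GB/s, assumption
  double nBytes = 8.0 * 1024 * 1024; // 8 MB, assumption
  int numPipeOps = 4;                // aggregated ops sharing the channels, assumption
  int latCountRing = numPipeOps;     // Ring pays the latency once per pipelined op
  double time = lat * latCountRing + nBytes / (1000 * bw);
  printf("predicted time = %.1f us\n", time);
  return 0;
}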
|
src/graph/xml.cc (153 lines changed)
@ -300,12 +300,15 @@ ncclResult_t ncclTopoXmlLoadSystem(FILE* file, struct ncclXml* xml, struct ncclX
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml) {
|
||||
ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml, int warn) {
|
||||
FILE* file = fopen(xmlTopoFile, "r");
|
||||
if (file == NULL) {
|
||||
WARN("Could not open XML topology file %s : %s", xmlTopoFile, strerror(errno));
|
||||
if (warn) {
|
||||
WARN("Could not open XML topology file %s : %s", xmlTopoFile, strerror(errno));
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
INFO(NCCL_GRAPH, "Loading topology file %s", xmlTopoFile);
|
||||
struct xmlHandler handlers[] = { { "system", ncclTopoXmlLoadSystem } };
|
||||
xml->maxIndex = 0;
|
||||
NCCLCHECK(xmlLoadSub(file, xml, NULL, handlers, 1));
|
||||
@ -441,8 +444,8 @@ ncclResult_t ncclTopoGetPciNode(struct ncclXml* xml, const char* busId, struct n
|
||||
NCCLCHECK(xmlFindTagKv(xml, "pci", pciNode, "busid", busId));
|
||||
if (*pciNode == NULL) {
|
||||
NCCLCHECK(xmlAddNode(xml, NULL, "pci", pciNode));
|
||||
NCCLCHECK(xmlSetAttr(*pciNode, "busid", busId));
|
||||
}
|
||||
NCCLCHECK(xmlSetAttr(*pciNode, "busid", busId));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@ -463,100 +466,114 @@ ncclResult_t ncclTopoGetXmlFromSys(struct ncclXmlNode* pciNode, struct ncclXml*
|
||||
const char* busId;
|
||||
NCCLCHECK(xmlGetAttr(pciNode, "busid", &busId));
|
||||
char* path = NULL;
|
||||
int index;
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "class", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) NCCLCHECK(getPciPath(busId, &path));
|
||||
ncclDebugNoWarn = NCCL_GRAPH;
|
||||
getPciPath(busId, &path);
|
||||
ncclDebugNoWarn = 0;
|
||||
|
||||
if (path) {
|
||||
NCCLCHECK(ncclTopoSetAttrFromSys(pciNode, path, "class", "class"));
|
||||
}
|
||||
int index;
|
||||
ncclDebugNoWarn = NCCL_GRAPH;
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "vendor", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) getPciPath(busId, &path);
|
||||
if (path) ncclTopoSetAttrFromSys(pciNode, path, "vendor", "vendor");
|
||||
}
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "device", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) getPciPath(busId, &path);
|
||||
if (path) ncclTopoSetAttrFromSys(pciNode, path, "device", "device");
|
||||
}
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "subsystem_vendor", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) getPciPath(busId, &path);
|
||||
if (path) ncclTopoSetAttrFromSys(pciNode, path, "subsystem_vendor", "subsystem_vendor");
|
||||
}
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "subsystem_device", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) getPciPath(busId, &path);
|
||||
if (path) ncclTopoSetAttrFromSys(pciNode, path, "subsystem_device", "subsystem_device");
|
||||
}
|
||||
ncclDebugNoWarn = 0;
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "link_speed", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) NCCLCHECK(getPciPath(busId, &path));
|
||||
char deviceSpeedStr[MAX_STR_LEN];
|
||||
float deviceSpeed;
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_speed", deviceSpeedStr));
|
||||
sscanf(deviceSpeedStr, "%f GT/s", &deviceSpeed);
|
||||
char portSpeedStr[MAX_STR_LEN];
|
||||
float portSpeed;
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_speed", portSpeedStr));
|
||||
sscanf(portSpeedStr, "%f GT/s", &portSpeed);
|
||||
NCCLCHECK(xmlSetAttr(pciNode, "link_speed", portSpeed < deviceSpeed ? portSpeedStr : deviceSpeedStr));
|
||||
if (path) {
|
||||
char deviceSpeedStr[MAX_STR_LEN];
|
||||
float deviceSpeed;
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_speed", deviceSpeedStr));
|
||||
sscanf(deviceSpeedStr, "%f GT/s", &deviceSpeed);
|
||||
char portSpeedStr[MAX_STR_LEN];
|
||||
float portSpeed;
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_speed", portSpeedStr));
|
||||
sscanf(portSpeedStr, "%f GT/s", &portSpeed);
|
||||
NCCLCHECK(xmlSetAttr(pciNode, "link_speed", portSpeed < deviceSpeed ? portSpeedStr : deviceSpeedStr));
|
||||
} else {
|
||||
NCCLCHECK(xmlSetAttr(pciNode, "link_speed", ""));
|
||||
}
|
||||
}
|
||||
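The link_speed branch above reads two sysfs files and keeps whichever reports the lower rate. A standalone sketch of that selection with invented input strings (the real values come from max_link_speed in sysfs):

#include <cstdio>

// Parse two sysfs-style "<float> GT/s" strings and keep the slower one,
// mirroring the selection above. The strings here are invented examples.
int main() {
  const char deviceSpeedStr[] = "16.0 GT/s PCIe";  // e.g. <dev>/max_link_speed
  const char portSpeedStr[]   = "8.0 GT/s";        // e.g. ../max_link_speed
  float deviceSpeed = 0.0f, portSpeed = 0.0f;
  sscanf(deviceSpeedStr, "%f GT/s", &deviceSpeed);
  sscanf(portSpeedStr, "%f GT/s", &portSpeed);
  const char* chosen = portSpeed < deviceSpeed ? portSpeedStr : deviceSpeedStr;
  printf("link_speed = \"%s\" (device %.1f, port %.1f GT/s)\n", chosen, deviceSpeed, portSpeed);
  return 0;
}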
NCCLCHECK(xmlGetAttrIndex(pciNode, "link_width", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) NCCLCHECK(getPciPath(busId, &path));
|
||||
char strValue[MAX_STR_LEN];
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_width", strValue));
|
||||
int deviceWidth = strtol(strValue, NULL, 0);
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_width", strValue));
|
||||
int portWidth = strtol(strValue, NULL, 0);
|
||||
NCCLCHECK(xmlSetAttrInt(pciNode, "link_width", std::min(deviceWidth,portWidth)));
|
||||
if (path) {
|
||||
char strValue[MAX_STR_LEN];
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_width", strValue));
|
||||
int deviceWidth = strtol(strValue, NULL, 0);
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_width", strValue));
|
||||
int portWidth = strtol(strValue, NULL, 0);
|
||||
NCCLCHECK(xmlSetAttrInt(pciNode, "link_width", std::min(deviceWidth,portWidth)));
|
||||
} else {
|
||||
NCCLCHECK(xmlSetAttr(pciNode, "link_width", ""));
|
||||
}
|
||||
}
|
||||
struct ncclXmlNode* parent = pciNode->parent;
|
||||
if (parent == NULL) {
|
||||
if (path == NULL) NCCLCHECK(getPciPath(busId, &path));
|
||||
if (path) {
|
||||
// Save that for later in case next step is a CPU
|
||||
char numaIdStr[MAX_STR_LEN];
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "numa_node", numaIdStr));
|
||||
|
||||
// Save that for later in case next step is a CPU
|
||||
char numaIdStr[MAX_STR_LEN];
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "numa_node", numaIdStr));
|
||||
|
||||
// Go up one level in the PCI tree. Rewind two "/" and follow the upper PCI
|
||||
// switch, or stop if we reach a CPU root complex.
|
||||
int slashCount = 0;
|
||||
int parentOffset;
|
||||
for (parentOffset = strlen(path)-1; parentOffset>0; parentOffset--) {
|
||||
if (path[parentOffset] == '/') {
|
||||
slashCount++;
|
||||
path[parentOffset] = '\0';
|
||||
int start = parentOffset - 1;
|
||||
while (start>0 && path[start] != '/') start--;
|
||||
// Check whether the parent path looks like "BBBB:BB:DD.F" or not.
|
||||
if (checkBDFFormat(path+start+1) == 0) {
|
||||
// This a CPU root complex. Create a CPU tag and stop there.
|
||||
struct ncclXmlNode* topNode;
|
||||
NCCLCHECK(xmlFindTag(xml, "system", &topNode));
|
||||
NCCLCHECK(xmlGetSubKv(topNode, "cpu", &parent, "numaid", numaIdStr));
|
||||
if (parent == NULL) {
|
||||
NCCLCHECK(xmlAddNode(xml, topNode, "cpu", &parent));
|
||||
NCCLCHECK(xmlSetAttr(parent, "numaid", numaIdStr));
|
||||
}
|
||||
} else if (slashCount == 2) {
|
||||
// Continue on the upper PCI switch
|
||||
for (int i = strlen(path)-1; i>0; i--) {
|
||||
if (path[i] == '/') {
|
||||
NCCLCHECK(xmlFindTagKv(xml, "pci", &parent, "busid", path+i+1));
|
||||
if (parent == NULL) {
|
||||
NCCLCHECK(xmlAddNode(xml, NULL, "pci", &parent));
|
||||
NCCLCHECK(xmlSetAttr(parent, "busid", path+i+1));
|
||||
// Go up one level in the PCI tree. Rewind two "/" and follow the upper PCI
|
||||
// switch, or stop if we reach a CPU root complex.
|
||||
int slashCount = 0;
|
||||
int parentOffset;
|
||||
for (parentOffset = strlen(path)-1; parentOffset>0; parentOffset--) {
|
||||
if (path[parentOffset] == '/') {
|
||||
slashCount++;
|
||||
path[parentOffset] = '\0';
|
||||
int start = parentOffset - 1;
|
||||
while (start>0 && path[start] != '/') start--;
|
||||
// Check whether the parent path looks like "BBBB:BB:DD.F" or not.
|
||||
if (checkBDFFormat(path+start+1) == 0) {
|
||||
// This a CPU root complex. Create a CPU tag and stop there.
|
||||
struct ncclXmlNode* topNode;
|
||||
NCCLCHECK(xmlFindTag(xml, "system", &topNode));
|
||||
NCCLCHECK(xmlGetSubKv(topNode, "cpu", &parent, "numaid", numaIdStr));
|
||||
if (parent == NULL) {
|
||||
NCCLCHECK(xmlAddNode(xml, topNode, "cpu", &parent));
|
||||
NCCLCHECK(xmlSetAttr(parent, "numaid", numaIdStr));
|
||||
}
|
||||
} else if (slashCount == 2) {
|
||||
// Continue on the upper PCI switch
|
||||
for (int i = strlen(path)-1; i>0; i--) {
|
||||
if (path[i] == '/') {
|
||||
NCCLCHECK(xmlFindTagKv(xml, "pci", &parent, "busid", path+i+1));
|
||||
if (parent == NULL) {
|
||||
NCCLCHECK(xmlAddNode(xml, NULL, "pci", &parent));
|
||||
NCCLCHECK(xmlSetAttr(parent, "busid", path+i+1));
|
||||
}
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (parent) break;
|
||||
}
|
||||
} else {
|
||||
// No information on /sys, attach GPU to unknown CPU
|
||||
NCCLCHECK(xmlFindTagKv(xml, "cpu", &parent, "numaid", "-1"));
|
||||
if (parent == NULL) {
|
||||
struct ncclXmlNode* topNode;
|
||||
NCCLCHECK(xmlFindTag(xml, "system", &topNode));
|
||||
NCCLCHECK(xmlAddNode(xml, topNode, "cpu", &parent));
|
||||
NCCLCHECK(xmlSetAttr(parent, "numaid", "-1"));
|
||||
NCCLCHECK(ncclTopoGetXmlFromCpu(parent, xml));
|
||||
}
|
||||
if (parent) break;
|
||||
}
|
||||
pciNode->parent = parent;
|
||||
parent->subs[parent->nSubs++] = pciNode;
|
||||
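The parent walk above is plain string surgery on the /sys device path: cut the last component, then check whether the new last component still looks like a PCI busid ("BBBB:BB:DD.F") or is the root complex, in which case the device gets attached to a CPU node. A self-contained sketch of that test, with a made-up path and a simplified format check (not NCCL's checkBDFFormat):

#include <cstdio>
#include <cstring>

// Simplified busid check: "BBBB:BB:DD.F" is 12 characters with ':' at 4 and 7
// and '.' at 10. NCCL's checkBDFFormat is the real test; this is illustrative.
static bool looksLikeBdf(const char* s) {
  return strlen(s) == 12 && s[4] == ':' && s[7] == ':' && s[10] == '.';
}

int main() {
  // Made-up /sys path: a device behind one PCI switch.
  char path[] = "/sys/devices/pci0000:10/0000:10:00.0/0000:11:00.0";
  *strrchr(path, '/') = '\0';                   // drop our own busid
  const char* parent = strrchr(path, '/') + 1;  // what is now the last component
  printf("%s -> %s\n", parent,
         looksLikeBdf(parent) ? "PCI parent, keep walking up"
                              : "root complex, attach to a CPU node");
  return 0;
}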
@ -661,12 +678,14 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm
|
||||
if (index == -1) {
|
||||
const char* busId;
|
||||
NCCLCHECK(xmlGetAttr(sub, "target", &busId));
|
||||
if (strcmp(busId, "fffffff:ffff:ff") == 0) {
|
||||
char* path;
|
||||
ncclDebugNoWarn = NCCL_GRAPH;
|
||||
getPciPath(busId, &path);
|
||||
ncclDebugNoWarn = 0;
|
||||
if (path == NULL || strcmp(busId, "fffffff:ffff:ff") == 0) {
|
||||
// Remote NVLink device is not visible inside this VM. Assume NVSwitch.
|
||||
NCCLCHECK(xmlSetAttr(sub, "tclass", "0x068000"));
|
||||
} else {
|
||||
char* path;
|
||||
NCCLCHECK(getPciPath(busId, &path));
|
||||
NCCLCHECK(ncclTopoSetAttrFromSys(sub, path, "class", "tclass"));
|
||||
free(path);
|
||||
}
|
||||
@ -679,6 +698,7 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm
|
||||
ncclResult_t ncclTopoFillGpu(struct ncclXml* xml, const char* busId, struct ncclXmlNode** gpuNode) {
|
||||
struct ncclXmlNode* node;
|
||||
NCCLCHECK(ncclTopoGetPciNode(xml, busId, &node));
|
||||
NCCLCHECK(xmlSetAttrIfUnset(node, "class", "0x03"));
|
||||
NCCLCHECK(ncclTopoGetXmlFromSys(node, xml));
|
||||
nvmlDevice_t nvmlDev = NULL;
|
||||
static int nvmlInit = 0;
|
||||
@ -731,6 +751,7 @@ ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const cha
|
||||
char busId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE];
|
||||
strcpy(busId, pciSysPath+offset+1);
|
||||
NCCLCHECK(ncclTopoGetPciNode(xml, busId, &parent));
|
||||
NCCLCHECK(xmlSetAttrIfUnset(parent, "class", "0x02"));
|
||||
NCCLCHECK(ncclTopoGetXmlFromSys(parent, xml));
|
||||
} else {
|
||||
// Virtual NIC, no PCI device, attach to first CPU
|
||||
|
@ -38,7 +38,7 @@ struct ncclXml {
|
||||
|
||||
/* File functions */
|
||||
#define NCCL_TOPO_XML_VERSION 1
|
||||
ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml);
|
||||
ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml, int warn);
|
||||
ncclResult_t ncclTopoDumpXmlToFile(const char* xmlTopoFile, struct ncclXml* xml);
|
||||
#define NCCL_GRAPH_XML_VERSION 1
|
||||
ncclResult_t ncclTopoGetXmlGraphFromFile(const char* xmlGraphFile, struct ncclXml* xml);
|
||||
@ -137,6 +137,18 @@ static ncclResult_t xmlSetAttr(struct ncclXmlNode* node, const char* attrName, c
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t xmlSetAttrIfUnset(struct ncclXmlNode* node, const char* attrName, const char* value) {
|
||||
int index;
|
||||
NCCLCHECK(xmlGetAttrIndex(node, attrName, &index));
|
||||
if (index != -1) return ncclSuccess;
|
||||
index = node->nAttrs++;
|
||||
strncpy(node->attrs[index].key, attrName, MAX_STR_LEN);
|
||||
node->attrs[index].key[MAX_STR_LEN] = '\0';
|
||||
strncpy(node->attrs[index].value, value, MAX_STR_LEN);
|
||||
node->attrs[index].value[MAX_STR_LEN] = '\0';
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t xmlSetAttrInt(struct ncclXmlNode* node, const char* attrName, const int value) {
|
||||
int index;
|
||||
NCCLCHECK(xmlGetAttrIndex(node, attrName, &index));
|
||||
|
src/group.cc
@ -133,6 +133,7 @@ void* ncclAsyncThreadPreconnect(void* args_) {
|
||||
struct ncclAsyncArgs* args = (struct ncclAsyncArgs*)args_;
|
||||
struct ncclComm* comm = args->coll.comm;
|
||||
CUDACHECKTHREAD(cudaSetDevice(comm->cudaDev));
|
||||
if (CPU_COUNT(&comm->cpuAffinity)) sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity);
|
||||
NCCLCHECKTHREAD(ncclTransportP2pSetup(comm, NULL, 0));
|
||||
return args;
|
||||
}
|
||||
@ -217,8 +218,6 @@ ncclResult_t ncclGroupEnd() {
|
||||
struct ncclComm* comm = args->coll.comm;
|
||||
int rank = comm->rank;
|
||||
int nRanks = comm->nRanks;
|
||||
struct ncclP2Plist* p2pSends = comm->p2pSends;
|
||||
struct ncclP2Plist* p2pRecvs = comm->p2pRecvs;
|
||||
|
||||
// Compute how much to split operations
|
||||
// Natural step size matching buffer steps.
|
||||
@ -241,8 +240,8 @@ ncclResult_t ncclGroupEnd() {
|
||||
sched_delta:
|
||||
uint32_t from = (rank+nRanks-delta)%nRanks;
|
||||
uint32_t to = (rank+delta)%nRanks;
|
||||
struct ncclP2Pinfo* recv = p2pRecvs[from].head;
|
||||
struct ncclP2Pinfo* send = p2pSends[to].head;
|
||||
struct ncclP2Pinfo* recv = comm->p2pRecvs[from] ? comm->p2pRecvs[from]->getNext() : NULL;
|
||||
struct ncclP2Pinfo* send = comm->p2pSends[to] ? comm->p2pSends[to]->getNext() : NULL;
|
||||
if (recv != NULL || send != NULL) {
|
||||
ssize_t totRecvBytes = -1, totSendBytes = -1;
|
||||
if (recv != NULL) totRecvBytes = recv->nbytes;
|
||||
@ -273,15 +272,11 @@ sched_delta:
|
||||
sendOffset += sendChunkSize;
|
||||
chunk++;
|
||||
} while (sendRemaining || recvRemaining);
|
||||
if (recv) {
|
||||
NCCLCHECKGOTO(dequeueP2pInfo(p2pRecvs+from), ret, group_cleanup);
|
||||
comm->p2pRecvCount--;
|
||||
}
|
||||
if (send) {
|
||||
NCCLCHECKGOTO(dequeueP2pInfo(p2pSends+to), ret, group_cleanup);
|
||||
comm->p2pSendCount--;
|
||||
}
|
||||
if (recv) comm->p2pRecvCount--;
|
||||
if (send) comm->p2pSendCount--;
|
||||
}
|
||||
if (recv == NULL && comm->p2pRecvs[from]) comm->p2pRecvs[from]->recycle();
|
||||
if (send == NULL && comm->p2pSends[to]) comm->p2pSends[to]->recycle();
|
||||
index++;
|
||||
if (index == 1 && deltas[1] == deltas[0]) index++;
|
||||
if (index == 2 && deltas[2] == deltas[0]) index++;
|
||||
@ -381,11 +376,9 @@ group_cleanup:
|
||||
comm->asyncTotalSize = 0;
|
||||
// Dequeue p2p lists
|
||||
if (comm->p2pSendCount > 0 || comm->p2pRecvCount > 0) {
|
||||
struct ncclP2Plist* p2pSends = comm->p2pSends;
|
||||
struct ncclP2Plist* p2pRecvs = comm->p2pRecvs;
|
||||
for (int peer=0; peer<comm->nRanks; peer++) {
|
||||
while (p2pSends[peer].head != NULL) dequeueP2pInfo(p2pSends+peer);
|
||||
while (p2pRecvs[peer].head != NULL) dequeueP2pInfo(p2pRecvs+peer);
|
||||
if (comm->p2pSends[peer]) comm->p2pSends[peer]->recycle();
|
||||
if (comm->p2pRecvs[peer]) comm->p2pRecvs[peer]->recycle();
|
||||
}
|
||||
comm->p2pSendCount = comm->p2pRecvCount = 0;
|
||||
}
|
||||
|
@ -16,4 +16,29 @@
|
||||
#define ALIGN_SIZE(size, align) \
|
||||
size = ((size + (align) - 1) / (align)) * (align);
|
||||
|
||||
#if !__CUDA_ARCH__
#ifndef __host__
#define __host__
#endif
#ifndef __device__
#define __device__
#endif
#endif

template<typename X, typename Y, typename Z = decltype(X()+Y())>
__host__ __device__ constexpr Z divUp(X x, Y y) {
return (x+y-1)/y;
}

template<typename X, typename Y, typename Z = decltype(X()+Y())>
__host__ __device__ constexpr Z roundUp(X x, Y y) {
return (x+y-1) - (x+y-1)%y;
}

// assumes second argument is a power of 2
template<typename X, typename Z = decltype(X()+int())>
__host__ __device__ constexpr Z alignUp(X x, int a) {
return (x+a-1) & Z(-a);
}

#endif
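The three helpers above are easy to sanity-check in isolation. The following standalone sketch (host-only copies, an illustration rather than NCCL code) shows the values they produce for a typical buffer-sizing case:

#include <cstddef>

// Host-only copies of the helpers above, for illustration.
template<typename X, typename Y, typename Z = decltype(X()+Y())>
constexpr Z divUp(X x, Y y) { return (x+y-1)/y; }
template<typename X, typename Y, typename Z = decltype(X()+Y())>
constexpr Z roundUp(X x, Y y) { return (x+y-1) - (x+y-1)%y; }
template<typename X, typename Z = decltype(X()+int())>
constexpr Z alignUp(X x, int a) { return (x+a-1) & Z(-a); }

int main() {
  // 1000 bytes split into 256-byte slots: 4 slots, padded up to 1024 bytes.
  static_assert(divUp(1000, 256) == 4, "ceiling division");
  static_assert(roundUp(1000, 256) == 1024, "round up to a multiple");
  // alignUp assumes a power-of-two alignment; page-align a 5000-byte size.
  static_assert(alignUp((size_t)5000, 4096) == 8192, "align to a power of two");
  return 0;
}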
|
@ -13,12 +13,13 @@
|
||||
#include <sys/mman.h>
|
||||
|
||||
template <typename T>
|
||||
static ncclResult_t ncclCudaHostCalloc(T** ptr, size_t nelem) {
|
||||
static ncclResult_t ncclCudaHostCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
|
||||
CUDACHECK(cudaHostAlloc(ptr, nelem*sizeof(T), cudaHostAllocMapped));
|
||||
memset(*ptr, 0, nelem*sizeof(T));
|
||||
INFO(NCCL_ALLOC, "Cuda Host Alloc Size %ld pointer %p", nelem*sizeof(T), *ptr);
|
||||
INFO(NCCL_ALLOC, "%s:%d Cuda Host Alloc Size %ld pointer %p", filefunc, line, nelem*sizeof(T), *ptr);
|
||||
return ncclSuccess;
|
||||
}
|
||||
#define ncclCudaHostCalloc(...) ncclCudaHostCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
static inline ncclResult_t ncclCudaHostFree(void* ptr) {
|
||||
CUDACHECK(cudaFreeHost(ptr));
|
||||
@ -26,7 +27,7 @@ static inline ncclResult_t ncclCudaHostFree(void* ptr) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static ncclResult_t ncclCalloc(T** ptr, size_t nelem) {
|
||||
static ncclResult_t ncclCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
|
||||
void* p = malloc(nelem*sizeof(T));
|
||||
if (p == NULL) {
|
||||
WARN("Failed to malloc %ld bytes", nelem*sizeof(T));
|
||||
@ -34,12 +35,13 @@ static ncclResult_t ncclCalloc(T** ptr, size_t nelem) {
|
||||
}
|
||||
memset(p, 0, nelem*sizeof(T));
|
||||
*ptr = (T*)p;
|
||||
INFO(NCCL_ALLOC, "Mem Alloc Size %ld pointer %p", nelem*sizeof(T), *ptr);
|
||||
INFO(NCCL_ALLOC, "%s:%d Mem Alloc Size %ld pointer %p", filefunc, line, nelem*sizeof(T), *ptr);
|
||||
return ncclSuccess;
|
||||
}
|
||||
#define ncclCalloc(...) ncclCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
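The pattern used throughout this header is: rename each allocator to a *Debug variant that takes the call site as trailing parameters, then recreate the original name with a variadic macro that appends __FILE__ and __LINE__. A minimal standalone sketch of the same trick, with made-up names:

#include <cstdio>
#include <cstdlib>

// The real function takes the call site as trailing parameters...
static int myCallocDebug(void** ptr, size_t bytes, const char* file, int line) {
  *ptr = calloc(1, bytes);
  if (*ptr == NULL) return 1;
  printf("%s:%d allocated %zu bytes at %p\n", file, line, bytes, *ptr);
  return 0;
}
// ...and a variadic macro restores the original name while appending the call site.
#define myCalloc(...) myCallocDebug(__VA_ARGS__, __FILE__, __LINE__)

int main() {
  void* p = NULL;
  if (myCalloc(&p, 1024) == 0) free(p);  // the log line points at this file and line
  return 0;
}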
|
||||
template <typename T>
|
||||
static ncclResult_t ncclCudaCalloc(T** ptr, size_t nelem) {
|
||||
static ncclResult_t ncclCudaCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
|
||||
// Need async stream for P2P pre-connect + CUDA Graph
|
||||
cudaStream_t stream;
|
||||
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
@ -47,9 +49,10 @@ static ncclResult_t ncclCudaCalloc(T** ptr, size_t nelem) {
|
||||
CUDACHECK(cudaMemsetAsync(*ptr, 0, nelem*sizeof(T), stream));
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
CUDACHECK(cudaStreamDestroy(stream));
|
||||
INFO(NCCL_ALLOC, "Cuda Alloc Size %ld pointer %p", nelem*sizeof(T), *ptr);
|
||||
INFO(NCCL_ALLOC, "%s:%d Cuda Alloc Size %ld pointer %p", filefunc, line, nelem*sizeof(T), *ptr);
|
||||
return ncclSuccess;
|
||||
}
|
||||
#define ncclCudaCalloc(...) ncclCudaCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
template <typename T>
|
||||
static ncclResult_t ncclCudaMemcpy(T* dst, T* src, size_t nelem) {
|
||||
@ -60,7 +63,7 @@ static ncclResult_t ncclCudaMemcpy(T* dst, T* src, size_t nelem) {
|
||||
// Allocate memory to be potentially ibv_reg_mr'd. This needs to be
|
||||
// allocated on separate pages as those pages will be marked DONTFORK
|
||||
// and if they are shared, that could cause a crash in a child process
|
||||
static ncclResult_t ncclIbMalloc(void** ptr, size_t size) {
|
||||
static ncclResult_t ncclIbMallocDebug(void** ptr, size_t size, const char *filefunc, int line) {
|
||||
size_t page_size = sysconf(_SC_PAGESIZE);
|
||||
void* p;
|
||||
int size_aligned = ROUNDUP(size, page_size);
|
||||
@ -68,8 +71,9 @@ static ncclResult_t ncclIbMalloc(void** ptr, size_t size) {
|
||||
if (ret != 0) return ncclSystemError;
|
||||
memset(p, 0, size);
|
||||
*ptr = p;
|
||||
INFO(NCCL_ALLOC, "Ib Alloc Size %ld pointer %p", size, *ptr);
|
||||
INFO(NCCL_ALLOC, "%s:%d Ib Alloc Size %ld pointer %p", filefunc, line, size, *ptr);
|
||||
return ncclSuccess;
|
||||
}
|
||||
#define ncclIbMalloc(...) ncclIbMallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
#endif
|
||||
|
@ -16,6 +16,7 @@ ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commSt
|
||||
ncclResult_t bootstrapAllGather(void* commState, void* allData, int size);
|
||||
ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size);
|
||||
ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size);
|
||||
ncclResult_t bootstrapBarrier(void* commState, int *ranks, int tag, int rank, int nranks);
|
||||
ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id, cudaIpcMemHandle_t* ipc, void** ptr);
|
||||
ncclResult_t bootstrapRemFree(int id, int rank, void* commState);
|
||||
ncclResult_t bootstrapClose(void* commState);
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -21,8 +21,8 @@
|
||||
|
||||
/* Declare all collective operations */
|
||||
#define DECL5(func, algo, proto, redop, type) \
|
||||
extern __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args); \
|
||||
extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem c); \
|
||||
extern __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(); \
|
||||
extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(ncclWorkElem c); \
|
||||
|
||||
#define DECL4(func, algo, redop, type) \
|
||||
DECL5(func, algo, SIMPLE, redop, type) \
|
||||
@ -34,6 +34,19 @@
|
||||
DECL4(func, TREE, redop, type) \
|
||||
DECL4(func, COLLNET, redop, type)
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
#define DECL2(func, redop) \
|
||||
DECL3(func, redop, int8_t) \
|
||||
DECL3(func, redop, uint8_t) \
|
||||
DECL3(func, redop, int32_t) \
|
||||
DECL3(func, redop, uint32_t) \
|
||||
DECL3(func, redop, int64_t) \
|
||||
DECL3(func, redop, uint64_t) \
|
||||
DECL3(func, redop, half) \
|
||||
DECL3(func, redop, float) \
|
||||
DECL3(func, redop, double) \
|
||||
DECL3(func, redop, __nv_bfloat16)
|
||||
#else
|
||||
#define DECL2(func, redop) \
|
||||
DECL3(func, redop, int8_t) \
|
||||
DECL3(func, redop, uint8_t) \
|
||||
@ -44,12 +57,14 @@
|
||||
DECL3(func, redop, half) \
|
||||
DECL3(func, redop, float) \
|
||||
DECL3(func, redop, double)
|
||||
#endif
|
||||
|
||||
#define DECL(func) \
|
||||
DECL2(func, Sum) \
|
||||
DECL2(func, Prod) \
|
||||
DECL2(func, Min) \
|
||||
DECL2(func, Max)
|
||||
DECL2(func, Max) \
|
||||
DECL2(func, Avg)
|
||||
|
||||
#define DECL_ALL \
|
||||
DECL2(Broadcast, Sum) \
|
||||
|
@ -72,6 +72,7 @@ struct ncclComm {
|
||||
int nRanks; // number of GPUs in communicator
|
||||
int cudaDev; // my cuda device index
|
||||
int64_t busId; // my PCI bus ID in int format
|
||||
cpu_set_t cpuAffinity; // CPU affinity of the GPU
|
||||
|
||||
int node;
|
||||
int nNodes;
|
||||
@ -146,11 +147,13 @@ struct ncclComm {
|
||||
struct ncclInfo* asyncOps;
|
||||
int asyncOpCount;
|
||||
size_t asyncTotalSize;
|
||||
ssize_t channelSize;
|
||||
int lastChannel;
|
||||
enum { ROUND_ROBIN, SHORTEST_QUEUE } asyncAllocMode;
|
||||
|
||||
//list of async p2p operation queued in a group semantics
|
||||
struct ncclP2Plist* p2pSends;
|
||||
struct ncclP2Plist* p2pRecvs;
|
||||
ncclP2Plist** p2pSends;
|
||||
ncclP2Plist** p2pRecvs;
|
||||
int p2pSendCount;
|
||||
int p2pRecvCount;
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -36,6 +36,9 @@ static __inline__ int ncclTypeSize(ncclDataType_t type) {
|
||||
case ncclUint8:
|
||||
return 1;
|
||||
case ncclFloat16:
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
case ncclBfloat16:
|
||||
#endif
|
||||
return 2;
|
||||
case ncclInt32:
|
||||
case ncclUint32:
|
||||
|
@ -12,7 +12,7 @@
|
||||
#include <stdint.h>
|
||||
|
||||
#define NCCL_NUM_FUNCTIONS 5 // SendRecv not included for now
|
||||
typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv} ncclFunc_t;
|
||||
typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclNumFuncs} ncclFunc_t;
|
||||
extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS];
|
||||
|
||||
#define NCCL_NUM_ALGORITHMS 3 // Tree/Ring/CollNet
|
||||
@ -69,10 +69,6 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK
|
||||
#define NCCL_LL128_MAX_NTHREADS 640
|
||||
#define NCCL_LL128_ELEMS_PER_THREAD 120
|
||||
|
||||
// Receiving from up to 3 sources is more compute intensive than sending
|
||||
// to 3 dests. Use 70% for reduce and 30% for bcast.
|
||||
#define NCCL_LL128_SPLIT(nt) ((nt*7/(10*32))*32)
|
||||
|
||||
#define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 8
|
||||
#define NCCL_LL128_SHMEM_SIZE (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS)
|
||||
|
||||
@ -116,6 +112,8 @@ struct ncclRing {
|
||||
// devices. Ordered from current device.
|
||||
int* userRanks;
|
||||
int* devUserRanks;
|
||||
|
||||
int index; // This rank's index in the ring
|
||||
};
|
||||
|
||||
|
||||
@ -203,6 +201,7 @@ struct ncclChannel {
|
||||
// Operation list for aggregation
|
||||
struct ncclWork* workFifo;
|
||||
int workCount;
|
||||
size_t totalSize;
|
||||
uint64_t workFifoTail; // Only used by CPU
|
||||
uint16_t index; // Only used by GPU
|
||||
|
||||
@ -228,4 +227,9 @@ struct ncclDevComm {
|
||||
struct ncclChannel* channels;
|
||||
};
|
||||
|
||||
struct ncclDevCommAndChannels {
|
||||
ncclDevComm comm;
|
||||
ncclChannel channels[MAXCHANNELS];
|
||||
};
|
||||
|
||||
#endif
|
||||
|
@ -11,6 +11,9 @@
|
||||
#include "group.h"
|
||||
#include "collectives.h"
|
||||
|
||||
#define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64)
|
||||
#define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */
|
||||
|
||||
size_t ncclKernMaxLocalSize();
|
||||
ncclResult_t ncclEnqueueCheck(struct ncclInfo* info);
|
||||
ncclResult_t ncclCpuBarrierIn(struct ncclComm* comm, int* isLast);
|
||||
@ -31,39 +34,22 @@ ncclResult_t ncclCudaGraphHostSetup(ncclComm_t comm, cudaGraph_t graph);
|
||||
struct ncclQueueElem {
|
||||
struct ncclWorkElem work;
|
||||
struct ncclProxyArgs proxyArgs;
|
||||
struct ncclQueueElem* next;
|
||||
};
|
||||
|
||||
// Store enqueue elements in a list
|
||||
struct ncclQueueElemList {
|
||||
struct ncclQueueElem* head;
|
||||
struct ncclQueueElem* tail;
|
||||
};
|
||||
typedef ncclRecyclableList<struct ncclQueueElem> ncclQueueElemList;
|
||||
|
||||
// Structure passed to CUDA graph
|
||||
struct ncclQueueInfo {
|
||||
ncclComm_t comm;
|
||||
int maxChannels; // Dynamic version of gridDim
|
||||
ncclResult_t ret; // Return value of host setup call
|
||||
struct ncclQueueElemList elemList;
|
||||
ncclQueueElemList* elemList;
|
||||
};
|
||||
|
||||
// Get next element from enqueue list
|
||||
static ncclResult_t ncclAddQueueElem(struct ncclQueueInfo* eqInfo, struct ncclQueueElem** elemOut) {
|
||||
if (eqInfo == NULL) return ncclInternalError;
|
||||
struct ncclQueueElemList* list = &eqInfo->elemList;
|
||||
if (list->tail != NULL) {
|
||||
*elemOut = list->tail;
|
||||
memset(*elemOut, 0, sizeof(struct ncclWorkElem) + sizeof(struct ncclProxyArgs));
|
||||
} else {
|
||||
NCCLCHECK(ncclCalloc(&list->tail, 1));
|
||||
*elemOut = list->tail;
|
||||
list->head = list->tail;
|
||||
}
|
||||
if (list->tail->next == NULL) {
|
||||
NCCLCHECK(ncclCalloc(&list->tail->next, 1));
|
||||
}
|
||||
list->tail = list->tail->next;
|
||||
static ncclResult_t ncclCreateQueueInfo(struct ncclQueueInfo** eqInfo, ncclComm_t comm) {
|
||||
NCCLCHECK(ncclCalloc(eqInfo, 1));
|
||||
(*eqInfo)->comm = comm;
|
||||
(*eqInfo)->elemList = new ncclQueueElemList();
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@ -72,7 +58,7 @@ static ncclResult_t ncclResetQueueInfo(struct ncclQueueInfo* eqInfo) {
|
||||
if (eqInfo == NULL) return ncclInternalError;
|
||||
eqInfo->maxChannels = 0;
|
||||
eqInfo->ret = ncclSuccess;
|
||||
eqInfo->elemList.tail = eqInfo->elemList.head;
|
||||
eqInfo->elemList->recycle();
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@ -81,12 +67,7 @@ static ncclResult_t ncclResetQueueInfo(struct ncclQueueInfo* eqInfo) {
|
||||
static void ncclDestroyQueueInfo(void* ptr) {
|
||||
if (ptr == NULL) return;
|
||||
struct ncclQueueInfo* eqInfo = (struct ncclQueueInfo*)ptr;
|
||||
struct ncclQueueElem* head = eqInfo->elemList.head;
|
||||
while (head != NULL) {
|
||||
struct ncclQueueElem* temp = head;
|
||||
head = head->next;
|
||||
free(temp);
|
||||
}
|
||||
delete eqInfo->elemList;
|
||||
free(eqInfo);
|
||||
}
|
||||
#endif // End include guard
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <ctype.h>
|
||||
#include <stdio.h>
|
||||
#include <sched.h>
|
||||
|
||||
ncclResult_t ncclTopoCudaPath(int cudaDev, char** path);
|
||||
|
||||
@ -33,8 +34,8 @@ ncclResult_t ncclTopoGetNetDev(struct ncclTopoSystem* system, int rank, struct n
|
||||
ncclResult_t ncclTopoCheckP2p(struct ncclTopoSystem* system, int64_t id1, int64_t id2, int* p2p, int *read, int* intermediateRank);
|
||||
ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* topo, int64_t busId, int netDev, int read, int* useGdr);
|
||||
|
||||
// Set CPU affinity
|
||||
ncclResult_t ncclTopoSetAffinity(struct ncclTopoSystem* system, int rank);
|
||||
// Find CPU affinity
|
||||
ncclResult_t ncclTopoGetCpuAffinity(struct ncclTopoSystem* system, int rank, cpu_set_t* affinity);
|
||||
|
||||
#define NCCL_TOPO_CPU_ARCH_X86 1
|
||||
#define NCCL_TOPO_CPU_ARCH_POWER 2
|
||||
@ -100,6 +101,6 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa
|
||||
|
||||
ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph);
|
||||
#include "info.h"
|
||||
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, float* time);
|
||||
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time);
|
||||
|
||||
#endif
|
||||
|
@ -12,32 +12,16 @@
|
||||
struct ncclP2Pinfo {
|
||||
void* buff;
|
||||
ssize_t nbytes;
|
||||
struct ncclP2Pinfo* next;
|
||||
};
|
||||
|
||||
struct ncclP2Plist {
|
||||
struct ncclP2Pinfo *head;
|
||||
struct ncclP2Pinfo *tail;
|
||||
};
|
||||
typedef ncclRecyclableList<struct ncclP2Pinfo> ncclP2Plist;
|
||||
|
||||
static ncclResult_t enqueueP2pInfo(ncclP2Plist* p2p, void* buff, ssize_t nBytes) {
|
||||
if (p2p == NULL) return ncclInternalError;
|
||||
static ncclResult_t ncclSaveP2pInfo(ncclP2Plist* &p2p, void* buff, ssize_t nBytes) {
|
||||
if (p2p == NULL) p2p = new ncclP2Plist();
|
||||
struct ncclP2Pinfo* next;
|
||||
NCCLCHECK(ncclCalloc(&next, 1));
|
||||
NCCLCHECK(p2p->getNewElem(&next));
|
||||
next->buff = buff;
|
||||
next->nbytes = nBytes;
|
||||
if (p2p->tail != NULL) p2p->tail->next = next;
|
||||
p2p->tail = next;
|
||||
if (p2p->head == NULL) p2p->head = next;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t dequeueP2pInfo(ncclP2Plist* p2p) {
|
||||
if (p2p == NULL) return ncclInternalError;
|
||||
struct ncclP2Pinfo* temp = p2p->head;
|
||||
p2p->head = p2p->head->next;
|
||||
if (p2p->tail == temp) p2p->tail = NULL;
|
||||
free(temp);
|
||||
return ncclSuccess;
|
||||
}
|
||||
#endif
|
||||
|
@ -30,12 +30,13 @@ union socketAddress {
|
||||
struct sockaddr_in6 sin6;
|
||||
};
|
||||
|
||||
/* Format a string representation of a (struct sockaddr *) socket address using getnameinfo()
|
||||
/* Format a string representation of a (union socketAddress *) socket address using getnameinfo()
|
||||
*
|
||||
* Output: "IPv4/IPv6 address<port>"
|
||||
*/
|
||||
static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
|
||||
if (buf == NULL || saddr == NULL) return NULL;
|
||||
static inline const char *socketToString(union socketAddress *addr, char *buf) {
|
||||
if (buf == NULL || addr == NULL) return NULL;
|
||||
struct sockaddr *saddr = &addr->sa;
|
||||
if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) { buf[0]='\0'; return buf; }
|
||||
char host[NI_MAXHOST], service[NI_MAXSERV];
|
||||
(void) getnameinfo(saddr, sizeof(union socketAddress), host, NI_MAXHOST, service, NI_MAXSERV, NI_NUMERICHOST|NI_NUMERICSERV);
|
||||
@ -43,8 +44,9 @@ static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
|
||||
return buf;
|
||||
}
|
||||
|
||||
static inline uint16_t socketToPort(struct sockaddr *saddr) {
|
||||
return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port);
|
||||
static inline uint16_t socketToPort(union socketAddress *addr) {
|
||||
struct sockaddr *saddr = &addr->sa;
|
||||
return ntohs(saddr->sa_family == AF_INET ? addr->sin.sin_port : addr->sin6.sin6_port);
|
||||
}
|
||||
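Passing the whole union socketAddress instead of a bare struct sockaddr lets the helpers reach the family-specific port fields without extra casts. A standalone sketch of the same shape (a hypothetical SockAddr union, not NCCL's type), printing an IPv4 address in the documented "address<port>" form:

#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <cstdio>
#include <cstring>

// Same shape as the union above: one blob large enough for either family.
union SockAddr { struct sockaddr sa; struct sockaddr_in sin; struct sockaddr_in6 sin6; };

static const char* toString(union SockAddr* addr, char* buf, size_t len) {
  char host[NI_MAXHOST], serv[NI_MAXSERV];
  if (getnameinfo(&addr->sa, sizeof(union SockAddr), host, sizeof(host),
                  serv, sizeof(serv), NI_NUMERICHOST | NI_NUMERICSERV) != 0) return NULL;
  snprintf(buf, len, "%s<%s>", host, serv);
  return buf;
}

int main() {
  union SockAddr a;
  memset(&a, 0, sizeof(a));
  a.sin.sin_family = AF_INET;
  a.sin.sin_port = htons(12345);
  inet_pton(AF_INET, "127.0.0.1", &a.sin.sin_addr);
  char line[256];
  printf("%s\n", toString(&a, line, sizeof(line)));  // 127.0.0.1<12345>
  return 0;
}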
|
||||
/* Allow the user to force the IPv4/IPv6 interface selection */
|
||||
@ -85,7 +87,7 @@ static int findInterfaces(const char* prefixList, char* names, union socketAddre
|
||||
if (family != AF_INET && family != AF_INET6)
|
||||
continue;
|
||||
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Found interface %s:%s", interface->ifa_name, socketToString(interface->ifa_addr, line));
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Found interface %s:%s", interface->ifa_name, socketToString((union socketAddress *)interface->ifa_addr, line));
|
||||
|
||||
/* Allow the caller to force the socket family type */
|
||||
if (sock_family != -1 && family != sock_family)
|
||||
@ -194,13 +196,13 @@ static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAdd
|
||||
// Store the interface name
|
||||
strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
|
||||
|
||||
TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr->sa), line_a));
|
||||
TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(localAddrs+found, line), socketToString(remoteAddr, line_a));
|
||||
found++;
|
||||
if (found == maxIfs) break;
|
||||
}
|
||||
|
||||
if (found == 0) {
|
||||
WARN("Net : No interface found in the same subnet as remote address %s", socketToString(&(remoteAddr->sa), line_a));
|
||||
WARN("Net : No interface found in the same subnet as remote address %s", socketToString(remoteAddr, line_a));
|
||||
}
|
||||
freeifaddrs(interfaces);
|
||||
return found;
|
||||
@ -333,7 +335,7 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
|
||||
return ncclSystemError;
|
||||
}
|
||||
|
||||
if (socketToPort(&localAddr->sa)) {
|
||||
if (socketToPort(localAddr)) {
|
||||
// Port is forced by env. Make sure we get the port.
|
||||
int opt = 1;
|
||||
#if defined(SO_REUSEPORT)
|
||||
@ -352,7 +354,7 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
|
||||
|
||||
#ifdef ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN+1];
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Listening on socket %s", socketToString(&localAddr->sa, line));
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Listening on socket %s", socketToString(localAddr, line));
|
||||
#endif
|
||||
|
||||
/* Put the socket in listen mode
|
||||
@ -364,10 +366,12 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
|
||||
}
|
||||
|
||||
static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
|
||||
char line[SOCKET_NAME_MAXLEN+1];
|
||||
/* IPv4/IPv6 support */
|
||||
int family = remoteAddr->sa.sa_family;
|
||||
if (family != AF_INET && family != AF_INET6) {
|
||||
WARN("Error : connecting to address with family %d is neither AF_INET(%d) nor AF_INET6(%d)", family, AF_INET, AF_INET6);
|
||||
WARN("Net : connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
|
||||
socketToString(remoteAddr, line), family, AF_INET, AF_INET6);
|
||||
return ncclInternalError;
|
||||
}
|
||||
int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
|
||||
@ -386,8 +390,7 @@ static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
|
||||
SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_SNDBUF, (char*)&bufsize, sizeof(int)), "setsockopt");
|
||||
SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_RCVBUF, (char*)&bufsize, sizeof(int)), "setsockopt");*/
|
||||
|
||||
char line[SOCKET_NAME_MAXLEN+1];
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", socketToString(&remoteAddr->sa, line));
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", socketToString(remoteAddr, line));
|
||||
|
||||
int ret;
|
||||
int timedout_retries = 0;
|
||||
@ -403,25 +406,26 @@ retry:
|
||||
goto retry;
|
||||
}
|
||||
}
|
||||
WARN("Connect to %s failed : %s", socketToString(&remoteAddr->sa, line), strerror(errno));
|
||||
WARN("Net : Connect to %s failed : %s", socketToString(remoteAddr, line), strerror(errno));
|
||||
return ncclSystemError;
|
||||
}
|
||||
|
||||
#define NCCL_SOCKET_SEND 0
|
||||
#define NCCL_SOCKET_RECV 1
|
||||
static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) {
|
||||
static ncclResult_t socketProgressOpt(int op, int fd, union socketAddress *addr, void* ptr, int size, int* offset, int block) {
|
||||
int bytes = 0;
|
||||
char* data = (char*)ptr;
|
||||
char line[SOCKET_NAME_MAXLEN+1];
|
||||
do {
|
||||
if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
|
||||
if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
|
||||
if (op == NCCL_SOCKET_RECV && bytes == 0) {
|
||||
WARN("Net : Connection closed by remote peer");
|
||||
WARN("Net : Connection closed by remote peer %s", socketToString(addr, line));
|
||||
return ncclSystemError;
|
||||
}
|
||||
if (bytes == -1) {
|
||||
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
|
||||
WARN("Call to recv failed : %s", strerror(errno));
|
||||
WARN("Net : Call to recv from %s failed : %s", socketToString(addr, line), strerror(errno));
|
||||
return ncclSystemError;
|
||||
} else {
|
||||
bytes = 0;
|
||||
@ -432,25 +436,25 @@ static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int*
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
|
||||
return socketProgressOpt(op, fd, ptr, size, offset, 0);
|
||||
static ncclResult_t socketProgress(int op, int fd, union socketAddress *addr, void* ptr, int size, int* offset) {
|
||||
return socketProgressOpt(op, fd, addr, ptr, size, offset, 0);
|
||||
}
|
||||
|
||||
static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) {
|
||||
static ncclResult_t socketWait(int op, int fd, union socketAddress *addr, void* ptr, int size, int* offset) {
|
||||
while (*offset < size)
|
||||
NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1));
|
||||
NCCLCHECK(socketProgressOpt(op, fd, addr, ptr, size, offset, 1));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t socketSend(int fd, void* ptr, int size) {
|
||||
static ncclResult_t socketSend(int fd, union socketAddress *addr, void* ptr, int size) {
|
||||
int offset = 0;
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, fd, ptr, size, &offset));
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, fd, addr, ptr, size, &offset));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t socketRecv(int fd, void* ptr, int size) {
|
||||
static ncclResult_t socketRecv(int fd, union socketAddress *addr, void* ptr, int size) {
|
||||
int offset = 0;
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, ptr, size, &offset));
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, addr, ptr, size, &offset));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
|
@ -37,4 +37,76 @@ static long log2i(long n) {
|
||||
return l;
|
||||
}
|
||||
|
||||
// Recyclable list that avoids frequent malloc/free
|
||||
template<typename T>
|
||||
struct ncclListElem {
|
||||
T data;
|
||||
struct ncclListElem* next;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class ncclRecyclableList {
|
||||
private:
|
||||
struct ncclListElem<T>* head;
|
||||
struct ncclListElem<T>* tail;
|
||||
struct ncclListElem<T>* cursor;
|
||||
int n;
|
||||
|
||||
public:
|
||||
ncclRecyclableList() {
|
||||
tail = cursor = head = NULL;
|
||||
n = 0;
|
||||
}
|
||||
|
||||
int count() const { return n; }
|
||||
|
||||
// Get a new element from the list and return pointer
|
||||
ncclResult_t getNewElem(T** dataOut) {
|
||||
if (tail != NULL) {
|
||||
*dataOut = &tail->data;
|
||||
memset(*dataOut, 0, sizeof(T));
|
||||
} else {
|
||||
NCCLCHECK(ncclCalloc(&tail, 1));
|
||||
*dataOut = &tail->data;
|
||||
cursor = head = tail;
|
||||
}
|
||||
if (tail->next == NULL) {
|
||||
NCCLCHECK(ncclCalloc(&tail->next, 1));
|
||||
}
|
||||
tail = tail->next;
|
||||
n += 1;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
T* begin() {
|
||||
if (head == NULL || head == tail) return NULL;
|
||||
cursor = head->next;
|
||||
return &head->data;
|
||||
}
|
||||
|
||||
// Get next element from the list during an iteration
|
||||
T* getNext() {
|
||||
// tail always points to the next element to be enqueued
|
||||
// hence does not contain valid data
|
||||
if (cursor == NULL || cursor == tail) return NULL;
|
||||
T* rv = &cursor->data;
|
||||
cursor = cursor->next;
|
||||
return rv;
|
||||
}
|
||||
|
||||
// Recycle the list without freeing the space
|
||||
void recycle() {
|
||||
tail = cursor = head;
|
||||
n = 0;
|
||||
}
|
||||
|
||||
~ncclRecyclableList() {
|
||||
while (head != NULL) {
|
||||
struct ncclListElem<T>* temp = head;
|
||||
head = head->next;
|
||||
free(temp);
|
||||
}
|
||||
}
|
||||
};
|
||||
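A short usage sketch of the class above, assuming it is compiled together with NCCL's NCCLCHECK/INFO macros and ncclCalloc; it walks the intended life cycle: getNewElem() to append, begin()/getNext() to iterate, recycle() to reuse the nodes for the next batch:

struct DemoElem { int value; };

// Illustrative only; assumes NCCL's NCCLCHECK/INFO macros and ncclCalloc are visible.
static ncclResult_t demoRecyclableList() {
  ncclRecyclableList<struct DemoElem> list;
  for (int i = 0; i < 3; i++) {
    struct DemoElem* e;
    NCCLCHECK(list.getNewElem(&e));  // reuses an existing node when one is free
    e->value = i;
  }
  // begin() returns the first element, getNext() the following ones,
  // and NULL once the tail sentinel is reached.
  for (struct DemoElem* e = list.begin(); e != NULL; e = list.getNext()) {
    INFO(NCCL_INIT, "element %d", e->value);
  }
  list.recycle();  // keep the allocated nodes, start the next batch from scratch
  return ncclSuccess;
}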
|
||||
#endif
|
||||
|
src/init.cc
@ -79,21 +79,17 @@ ncclResult_t initNetPlugin(ncclNet_t** net, ncclCollNet_t** collnet) {
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
ncclNet_t* extNet = (ncclNet_t*) dlsym(netPluginLib, STR(NCCL_PLUGIN_SYMBOL));
|
||||
if (extNet == NULL) {
|
||||
*net = (ncclNet_t*) dlsym(netPluginLib, STR(NCCL_PLUGIN_SYMBOL));
|
||||
if (*net == NULL) {
|
||||
INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_PLUGIN_SYMBOL) " symbol.");
|
||||
} else if (initNet(extNet) == ncclSuccess) {
|
||||
*net = extNet;
|
||||
// Check for CollNet
|
||||
ncclCollNet_t* extCollNet = (ncclCollNet_t*) dlsym(netPluginLib, STR(NCCL_COLLNET_PLUGIN_SYMBOL));
|
||||
if (extCollNet == NULL) {
|
||||
INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_COLLNET_PLUGIN_SYMBOL) " symbol.");
|
||||
} else if (initCollNet(extCollNet) == ncclSuccess) {
|
||||
*collnet = extCollNet;
|
||||
}
|
||||
if (netPluginLib != NULL) dlclose(netPluginLib);
|
||||
return ncclSuccess;
|
||||
}
|
||||
if (netPluginLib != NULL) dlclose(netPluginLib);
|
||||
// Check for CollNet
|
||||
*collnet = (ncclCollNet_t*) dlsym(netPluginLib, STR(NCCL_COLLNET_PLUGIN_SYMBOL));
|
||||
if (*collnet == NULL) {
|
||||
INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_COLLNET_PLUGIN_SYMBOL) " symbol.");
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@ -101,13 +97,27 @@ ncclResult_t initNet() {
|
||||
// Always initialize bootstrap network
|
||||
NCCLCHECK(bootstrapNetInit());
|
||||
|
||||
NCCLCHECK(initNetPlugin(&ncclNet, &ncclCollNet));
|
||||
if (ncclNet != NULL) return ncclSuccess;
|
||||
if (initNet(&ncclNetIb) == ncclSuccess) {
|
||||
ncclNet = &ncclNetIb;
|
||||
} else {
|
||||
NCCLCHECK(initNet(&ncclNetSocket));
|
||||
ncclNet = &ncclNetSocket;
|
||||
// Initialize main communication network
ncclNet_t* nets[3] = { NULL, &ncclNetIb, &ncclNetSocket };
ncclCollNet_t* collNets[3] = { NULL, NULL, NULL };
NCCLCHECK(initNetPlugin(nets+0, collNets+0));
char* netName = getenv("NCCL_NET");

for (int i=0; i<3; i++) {
if (nets[i] == NULL) continue;
if (netName && strcmp(netName, nets[i]->name) != 0) continue;
// net plugin is already initialized
if (initNet(nets[i]) != ncclSuccess) continue;
ncclNet = nets[i];
if (collNets[i] && initCollNet(collNets[i]) == ncclSuccess) {
ncclCollNet = collNets[i];
}
break;
}

if (ncclNet == NULL) {
WARN("Error: network %s not found.", netName ? netName : "");
return ncclInvalidUsage;
}
return ncclSuccess;
}
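The selection loop above amounts to: walk an ordered list of backends, skip the ones that do not match NCCL_NET, and take the first one that initializes. A minimal standalone sketch of that shape (the backend names and init functions here are invented for illustration):

#include <cstdio>
#include <cstdlib>
#include <cstring>

struct Backend { const char* name; bool (*init)(); };

static bool initPlugin() { return false; }  // e.g. no plugin installed
static bool initIb()     { return true; }
static bool initSocket() { return true; }

static const Backend* pickBackend() {
  static const Backend backends[] = {
    { "Plugin", initPlugin }, { "IB", initIb }, { "Socket", initSocket } };
  const char* forced = getenv("NCCL_NET");  // NULL means "first one that works"
  for (const Backend& b : backends) {
    if (forced && strcmp(forced, b.name) != 0) continue;  // honor the forced name
    if (!b.init()) continue;                              // skip backends that fail to init
    return &b;
  }
  return NULL;  // nothing usable, or the forced name matched no backend
}

int main() {
  const Backend* b = pickBackend();
  printf("selected: %s\n", b ? b->name : "(none)");
  return 0;
}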
@ -177,6 +187,10 @@ static ncclResult_t commFree(ncclComm_t comm) {
|
||||
return ncclSuccess;
|
||||
free(comm->connectSend);
|
||||
free(comm->connectRecv);
|
||||
for (int peer=0; peer<comm->nRanks; peer++) {
|
||||
delete comm->p2pSends[peer];
|
||||
delete comm->p2pRecvs[peer];
|
||||
}
|
||||
free(comm->p2pSends);
|
||||
free(comm->p2pRecvs);
|
||||
free(comm->asyncOps);
|
||||
@ -187,8 +201,7 @@ static ncclResult_t commFree(ncclComm_t comm) {
|
||||
if (comm->bootstrap)
|
||||
NCCLCHECK(bootstrapClose(comm->bootstrap));
|
||||
|
||||
CUDACHECK(cudaFree(comm->hostDevComm.channels));
|
||||
CUDACHECK(cudaFree(comm->devComm));
|
||||
CUDACHECK(cudaFree((ncclDevCommAndChannels*)comm->devComm));
|
||||
|
||||
for (int channel=0; channel<MAXCHANNELS; channel++)
|
||||
NCCLCHECK(freeChannel(comm->channels+channel, comm->nRanks));
|
||||
@ -224,6 +237,8 @@ static ncclResult_t commFree(ncclComm_t comm) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
NCCL_PARAM(AggChannelSize, "AGG_CHANNEL_SIZE", -2);
|
||||
|
||||
static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
if (ndev < 1) {
|
||||
WARN("invalid device count (%d) requested", ndev);
|
||||
@ -271,9 +286,15 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
NCCLCHECK(ncclCalloc(&comm->asyncOps, NCCL_MAX_OPS));
|
||||
comm->asyncOpCount = 0;
|
||||
comm->asyncTotalSize = 0;
|
||||
comm->channelSize = ncclParamAggChannelSize();
|
||||
comm->asyncAllocMode = ncclComm::SHORTEST_QUEUE;
|
||||
char* str = getenv("NCCL_AGG_ALLOC_MODE");
|
||||
if (str) INFO(NCCL_ENV, "NCCL_AGG_ALLOC_MODE set by environment to %s", str);
|
||||
if (str && strcmp(str, "ROUND_ROBIN") == 0) {
|
||||
comm->asyncAllocMode = ncclComm::ROUND_ROBIN;
|
||||
}
|
||||
|
||||
NCCLCHECK(ncclCalloc(&comm->enqueueInfo, 1));
|
||||
comm->enqueueInfo->comm = comm;
|
||||
NCCLCHECK(ncclCreateQueueInfo(&comm->enqueueInfo, comm));
|
||||
comm->lastSetupNode = NULL;
|
||||
comm->lastCudaGraphId = -1;
|
||||
|
||||
@ -296,9 +317,13 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
}
|
||||
|
||||
static ncclResult_t devCommSetup(ncclComm_t comm) {
|
||||
ncclDevCommAndChannels *devCommAndChans;
|
||||
NCCLCHECK(ncclCudaCalloc(&devCommAndChans, 1));
|
||||
comm->devComm = &devCommAndChans->comm;
|
||||
comm->hostDevComm.channels = devCommAndChans->channels;
|
||||
|
||||
// Duplicate the channels on the device
|
||||
int nChannels = std::max(comm->nChannels, comm->p2pnChannels);
|
||||
NCCLCHECK(ncclCudaCalloc(&comm->hostDevComm.channels, nChannels));
|
||||
NCCLCHECK(ncclCudaMemcpy(comm->hostDevComm.channels, comm->channels, nChannels));
|
||||
|
||||
// Copy userRanks and peers
|
||||
@ -307,7 +332,6 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
|
||||
}
|
||||
|
||||
// Duplicate the dev comm on the device
|
||||
NCCLCHECK(ncclCudaCalloc(&comm->devComm, 1));
|
||||
NCCLCHECK(ncclCudaMemcpy(comm->devComm, &comm->hostDevComm, 1));
|
||||
return ncclSuccess;
|
||||
}
|
||||
@ -349,15 +373,15 @@ static ncclResult_t setupChannel(struct ncclComm* comm, int channelId, int rank,
|
||||
NCCLCHECK(initChannel(comm, channelId));
|
||||
|
||||
struct ncclRing* ring = &comm->channels[channelId].ring;
|
||||
// Reorganize ranks to start with rank.
|
||||
int shift;
|
||||
for (shift = 0; shift<nranks; shift++) {
|
||||
if (ringRanks[shift] == rank) {
|
||||
break;
|
||||
}
|
||||
// Find our ring-distance from rank zero and reorganize ranks to start with rank.
int ixZero=0, ixRank=0;
for (int i=0; i < nranks; i++) {
if (ringRanks[i] == 0) ixZero = i;
if (ringRanks[i] == rank) ixRank = i;
}
ring->index = (ixRank-ixZero + nranks)%nranks;
for (int i=0; i<nranks; i++) {
ring->userRanks[i] = ringRanks[(i+shift)%nranks];
ring->userRanks[i] = ringRanks[(i+ixRank)%nranks];
}
return ncclSuccess;
}
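A worked example of the new index computation, with a hypothetical ring order; ring->index becomes the distance from rank 0 along the ring, and userRanks is rotated so the local rank comes first:

#include <cstdio>

int main() {
  // Hypothetical ring order produced by the graph search, seen from rank 2.
  const int nranks = 4;
  int ringRanks[nranks] = { 3, 0, 2, 1 };
  int rank = 2;

  int ixZero = 0, ixRank = 0;
  for (int i = 0; i < nranks; i++) {
    if (ringRanks[i] == 0)    ixZero = i;
    if (ringRanks[i] == rank) ixRank = i;
  }
  int index = (ixRank - ixZero + nranks) % nranks;  // ring distance from rank 0
  printf("ring index = %d\n", index);               // (2 - 1 + 4) % 4 = 1

  // userRanks rotated so the local rank comes first: 2 1 3 0
  for (int i = 0; i < nranks; i++) printf("%d ", ringRanks[(i + ixRank) % nranks]);
  printf("\n");
  return 0;
}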
@ -379,7 +403,7 @@ ncclResult_t initParams(struct ncclComm* comm) {
|
||||
}
|
||||
|
||||
// Allocate/Set Intra Process Structures and set CG options
|
||||
ncclResult_t ncclCommSetIntra(struct ncclComm* comm, int rank, int ranks, struct ncclComm* comm0) {
|
||||
ncclResult_t ncclCommSetIntraProc(struct ncclComm* comm, int rank, int ranks, struct ncclComm* comm0) {
|
||||
comm->intraRank = rank;
|
||||
comm->intraRanks = ranks;
|
||||
comm->intraPhase = 0;
|
||||
@ -500,37 +524,45 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
}
|
||||
|
||||
// Compute intra ranks and minimum CUDA Compute capabilities of intra-node GPUs and all GPUs
|
||||
int intraRank0 = -1, intraRank = -1, intraRanks = 0;
|
||||
int intraProcRank0 = -1, intraProcRank = -1, intraProcRanks = 0;
|
||||
int intraNodeRank0 = -1, intraNodeRank = -1, intraNodeRanks = 0;
|
||||
int myCompCap = allGather1Data[rank].cudaCompCap;
|
||||
int minCompCap = myCompCap, maxCompCap = myCompCap;
|
||||
uint64_t otherHostHash;
|
||||
int tmpNnodes = 1;
|
||||
int intraNodeGlobalRanks[256];
|
||||
for (int i = 0; i < nranks; i++) {
|
||||
if (allGather1Data[i].peerInfo.hostHash == allGather1Data[rank].peerInfo.hostHash) {
|
||||
// Rank is on same node
|
||||
if (intraNodeRanks == 0) intraNodeRank0 = i;
|
||||
if (i == rank) intraNodeRank = intraNodeRanks;
|
||||
intraNodeGlobalRanks[intraNodeRanks++] = i;
|
||||
if (allGather1Data[i].peerInfo.pidHash == allGather1Data[rank].peerInfo.pidHash) {
|
||||
if (intraRanks == 0) intraRank0 = i;
|
||||
if (i == rank) intraRank = intraRanks;
|
||||
intraRanks++;
|
||||
}
|
||||
} else { // Determine whether number of nodes is 2 (for use in tree pattern determination)
|
||||
if (tmpNnodes == 1) {
|
||||
otherHostHash = allGather1Data[i].peerInfo.hostHash;
|
||||
tmpNnodes = 2;
|
||||
} else if (tmpNnodes == 2 && otherHostHash != allGather1Data[i].peerInfo.hostHash) {
|
||||
tmpNnodes = 3;
|
||||
// Rank is in same process
|
||||
if (intraProcRanks == 0) intraProcRank0 = i;
|
||||
if (i == rank) intraProcRank = intraProcRanks;
|
||||
intraProcRanks++;
|
||||
}
|
||||
}
|
||||
minCompCap = std::min(allGather1Data[i].cudaCompCap, minCompCap);
|
||||
maxCompCap = std::max(allGather1Data[i].cudaCompCap, maxCompCap);
|
||||
}
|
||||
TRACE(NCCL_INIT,"hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
|
||||
if (intraRank == -1 || intraRank0 == -1 || allGather1Data[intraRank0].comm == NULL) {
|
||||
WARN("Failed to determine intra ranks hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
|
||||
TRACE(NCCL_INIT,"hostHash[%d] %lx intraNodeRank %d intraNodeRanks %d intraNodeRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, intraNodeRank, intraNodeRanks, intraNodeRank0);
|
||||
TRACE(NCCL_INIT,"pidHash[%d] %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.pidHash, intraProcRank, intraProcRanks, intraProcRank0);
|
||||
if (intraProcRank == -1 || intraProcRank0 == -1 || allGather1Data[intraProcRank0].comm == NULL) {
|
||||
WARN("Failed to determine intra proc ranks rank %d hostHash %lx pidHash %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, allGather1Data[rank].peerInfo.pidHash,
|
||||
intraProcRank, intraProcRanks, intraProcRank0);
|
||||
return ncclInternalError;
|
||||
}
|
||||
struct ncclComm* intraRank0Comm = allGather1Data[intraRank0].comm;
|
||||
if (intraNodeRank == -1 || intraNodeRank0 == -1 || intraNodeRanks == 0) {
|
||||
WARN("Failed to determine intra node ranks rank %d hostHash %lx pidHash %lx intraNodeRank %d intraNodeRanks %d intraNodeRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, allGather1Data[rank].peerInfo.pidHash,
|
||||
intraNodeRank, intraNodeRanks, intraNodeRank0);
|
||||
return ncclInternalError;
|
||||
}
|
||||
struct ncclComm* intraProcRank0Comm = allGather1Data[intraProcRank0].comm;
|
||||
uint64_t intraNodeRank0pidHash = allGather1Data[intraNodeRank0].peerInfo.pidHash;
|
||||
|
||||
free(allGather1Data);
|
||||
|
||||
@ -562,7 +594,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
|
||||
struct ncclTopoGraph treeGraph;
|
||||
treeGraph.id = 1;
|
||||
treeGraph.pattern = tmpNnodes <= 2 ? NCCL_TOPO_PATTERN_TREE : NCCL_TOPO_PATTERN_BALANCED_TREE;
|
||||
treeGraph.pattern = NCCL_TOPO_PATTERN_BALANCED_TREE;
|
||||
treeGraph.crossNic = ncclParamCrossNic();
|
||||
treeGraph.collNet = 0;
|
||||
treeGraph.minChannels = 1;
|
||||
@ -585,8 +617,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
}
|
||||
|
||||
// Determine local CollNet support before all-gather
|
||||
if (tmpNnodes > 1 && ncclParamCollNetEnable() == 1 && collNetSupport() == 1 && collNetGraph.nChannels > 0) comm->collNetSupport = 1;
|
||||
if (intraRanks > 8) {
|
||||
if (ncclParamCollNetEnable() == 1 && collNetSupport() == 1 && collNetGraph.nChannels > 0) comm->collNetSupport = 1;
|
||||
if (intraNodeRanks > 8) {
|
||||
if (comm->collNetSupport == 1) WARN("CollNet currently only supports up to 8 GPUs per node");
|
||||
comm->collNetSupport = 0;
|
||||
}
|
||||
@ -719,15 +751,19 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
struct ncclTree* tree = &comm->channels[c].tree;
|
||||
snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d/%d/%d->%d->%d",
|
||||
c, tree->down[0], tree->down[1], tree->down[2], rank, tree->up);
|
||||
INFO(NCCL_GRAPH, "Ring %02d : %d -> %d -> %d", c, comm->channels[c].ring.prev, comm->rank, comm->channels[c].ring.next);
|
||||
}
|
||||
line[1023] = '\0';
|
||||
INFO(NCCL_INIT, "Trees%s", line);
|
||||
|
||||
// Set Affinity to a CPU local the our GPU, so that all memory we allocate
|
||||
// on the host is local.
|
||||
NCCLCHECK(ncclTopoGetCpuAffinity(comm->topo, comm->rank, &comm->cpuAffinity));
|
||||
cpu_set_t affinitySave;
|
||||
sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
NCCLCHECK(ncclTopoSetAffinity(comm->topo, comm->rank));
|
||||
if (CPU_COUNT(&comm->cpuAffinity)) {
|
||||
sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity);
|
||||
}
|
||||
ncclResult_t ret;
|
||||
|
||||
NCCLCHECK(computeBuffSizes(comm));
|
||||
@ -768,10 +804,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
struct ncclChannel* channel = comm->channels+c;
|
||||
for (int h=0; h<nHeads; h++) {
|
||||
const int head = heads[h];
|
||||
if (ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetRecv) != 1)
|
||||
collNetSetupFail = 1;
|
||||
else if (ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetSend) != 1)
|
||||
collNetSetupFail = 1;
|
||||
collNetSetupFail = ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetRecv);
|
||||
if (!collNetSetupFail) collNetSetupFail = ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetSend);
|
||||
}
|
||||
// Verify CollNet setup across ranks after trying the first channel
|
||||
if (c == 0) {
|
||||
@ -837,14 +871,17 @@ collnet_cleanup:
|
||||
free(nvbPeers);
|
||||
}
|
||||
|
||||
NCCLCHECK(ncclCommSetIntra(comm, intraRank, intraRanks, intraRank0Comm));
|
||||
NCCLCHECK(ncclCommSetIntraProc(comm, intraProcRank, intraProcRanks, intraProcRank0Comm));
|
||||
|
||||
/* Local intra-node barrier */
|
||||
NCCLCHECK(bootstrapBarrier(comm->bootstrap, intraNodeGlobalRanks, (int)intraNodeRank0pidHash, intraNodeRank, intraNodeRanks));
|
||||
|
||||
if (comm->nNodes) NCCLCHECK(ncclProxyCreate(comm));
|
||||
|
||||
// We should have allocated all buffers, collective fifos, ... we can
|
||||
// restore the affinity.
|
||||
affinity_restore:
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
if (CPU_COUNT(&comm->cpuAffinity)) sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
if (ret != ncclSuccess) return ret;
|
||||
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks);
|
||||
|
@ -9,6 +9,9 @@
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#if CUDART_VERSION >= 11000
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#define NCCL_MAJOR ${nccl:Major}
|
||||
#define NCCL_MINOR ${nccl:Minor}
|
||||
@ -103,7 +106,8 @@ typedef enum { ncclSum = 0,
ncclProd = 1,
ncclMax = 2,
ncclMin = 3,
ncclNumOps = 4 } ncclRedOp_t;
ncclAvg = 4,
ncclNumOps = 5 } ncclRedOp_t;
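With ncclAvg an allreduce can return the mean directly, instead of ncclSum followed by a manual divide by the rank count. A caller-side sketch (comm, stream and the device buffers are assumed to be set up elsewhere):

#include <nccl.h>
#include <cuda_runtime.h>

// Average `count` floats across all ranks of `comm`; buffers are device memory,
// and comm/stream are assumed to be created by the caller as usual.
ncclResult_t allreduceMean(const float* sendbuff, float* recvbuff, size_t count,
                           ncclComm_t comm, cudaStream_t stream) {
  return ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclAvg, comm, stream);
}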
/* Data types */
typedef enum { ncclInt8 = 0, ncclChar = 0,
@ -115,7 +119,13 @@ typedef enum { ncclInt8 = 0, ncclChar = 0,
ncclFloat16 = 6, ncclHalf = 6,
ncclFloat32 = 7, ncclFloat = 7,
ncclFloat64 = 8, ncclDouble = 8,
ncclNumTypes = 9 } ncclDataType_t;
#if defined(__CUDA_BF16_TYPES_EXIST__)
ncclBfloat16 = 9,
ncclNumTypes = 10
#else
ncclNumTypes = 9
#endif
} ncclDataType_t;
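Since ncclBfloat16 only exists when the CUDA toolkit defines __CUDA_BF16_TYPES_EXIST__ (cuda_bf16.h, CUDA 11+), callers that must also build against older toolkits need the same guard. A small sketch:

#include <nccl.h>
#include <cuda_runtime.h>

// Pick the datatype for a 16-bit float buffer, preferring bfloat16 when the
// toolkit (and therefore this NCCL build) exposes it.
static ncclDataType_t pick16BitType(bool wantBf16) {
#if defined(__CUDA_BF16_TYPES_EXIST__)
  if (wantBf16) return ncclBfloat16;
#endif
  (void)wantBf16;      // older toolkit: fall back to half precision
  return ncclFloat16;
}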
/*
|
||||
* Collective communication operations
|
||||
|
src/proxy.cc
@ -41,9 +41,19 @@ static ncclResult_t allocateArgs(struct ncclComm* comm, struct ncclProxyArgs** a
|
||||
state->poolReturned = NULL;
|
||||
pthread_mutex_unlock(&state->poolMutex);
|
||||
} else {
|
||||
// Allocate a new pool of elements
|
||||
// Allocate a new pool of elements. Make sure we allocate the memory close
|
||||
// to the network thread
|
||||
struct ncclProxyPool* newPool;
|
||||
cpu_set_t affinitySave;
|
||||
if (CPU_COUNT(&comm->cpuAffinity)) {
|
||||
sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity);
|
||||
}
|
||||
NCCLCHECK(ncclCalloc(&newPool, 1));
|
||||
if (CPU_COUNT(&comm->cpuAffinity)) {
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
}
|
||||
|
||||
struct ncclProxyArgs* newElems = newPool->elems;
|
||||
// Chain newly allocated elements
|
||||
for (int i=0; i<PROXYARGS_ALLOCATE_SIZE; i++) {
|
||||
|
@ -131,14 +131,13 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
extern struct ncclTransport collNetTransport;

// All ranks must participate in collNetSetup call
// return: 0 - unsupported, 1 - supported
// We do not NCCLCHECK this call because we would fall back to P2P network in case CollNet setup fails
int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collNetGraph, struct ncclChannel* channel, int masterRank, int masterPeer, int collNetGraphChannelId, int type) {
int fail = 1;
int rank = comm->rank;
int nranks = comm->nRanks;
int nMasters = comm->nNodes;
int rankInCollNet = -1;
int supported = 0;
int isMaster = (rank == masterRank) ? 1 : 0;
struct {
int collNetRank;
@ -148,9 +147,9 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
// check if we can connect to collnet, whose root is the nranks-th rank
struct ncclPeerInfo *myInfo = comm->peerInfo+rank, *peerInfo = comm->peerInfo+nranks;
peerInfo->rank = nranks;
int ret = 1;
int support = 1;
if (isMaster) {
NCCLCHECK(collNetTransport.canConnect(&ret, comm->topo, collNetGraph, myInfo, peerInfo));
NCCLCHECK(collNetTransport.canConnect(&support, comm->topo, collNetGraph, myInfo, peerInfo));
}

// send master receives connect info from peer recv master
@ -168,7 +167,7 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
conn->transportComm = transportComm;
// setup
struct ncclConnect myConnect;
if (isMaster && ret > 0) {
if (isMaster && support) {
NCCLCHECK(transportComm->setup(comm, collNetGraph, myInfo, peerInfo, &myConnect, conn, collNetGraphChannelId, type));
}
// prepare connect handles
@ -198,7 +197,7 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
if (isMaster) memcpy(masterConnects+rankInCollNet, &(sendrecvExchange.connect), sizeof(struct ncclConnect));
}
// connect
if (isMaster && ret > 0) {
if (isMaster && support) {
NCCLCHECKGOTO(transportComm->connect(comm, masterConnects, nMasters, rankInCollNet, conn), res, cleanup);
struct ncclPeer* devRoot = channel->devPeers+nranks;
struct ncclConnector* devConn = (type == collNetRecv) ? devRoot->recv+type : devRoot->send+type;
@ -211,13 +210,11 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, masterPeer, collNetGraph->id, &sendrecvExchange, sizeof(sendrecvExchange)), res, cleanup);
TRACE(NCCL_INIT, "CollNet [recv] : rank %d collNetRank %d collNetNranks %d sent connect to rank %d", rank, rankInCollNet, nMasters, masterPeer);
}
if (ret > 0) {
supported = 1;
}
if (support) fail = 0;
cleanup:
if (allConnects != NULL) free(allConnects);
if (masterConnects != NULL) free(masterConnects);
return supported;
return fail;
}

ncclResult_t ncclTransportCollNetCheck(struct ncclComm* comm, int collNetSetupFail) {
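The return convention flips here: ncclTransportCollNetSetup now returns a failure flag (non-zero on failure) instead of a "supported" flag, which the caller presumably passes to ncclTransportCollNetCheck so every rank can agree before CollNet is used, falling back to the P2P network otherwise. A minimal sketch of that agreement pattern, independent of NCCL internals (the gather callback is a hypothetical stand-in for whatever bootstrap/allgather primitive is available):

#include <stdlib.h>

// Sketch: only enable a feature if no rank failed to set it up.
int featureUsable(int myFail, int nranks, int (*gatherFlags)(int in, int* out /*[nranks]*/)) {
  int* fails = (int*)malloc(nranks * sizeof(int));
  if (fails == NULL || gatherFlags(myFail, fails) != 0) { free(fails); return 0; }
  int anyFail = 0;
  for (int r = 0; r < nranks; r++) anyFail |= fails[r];
  free(fails);
  return anyFail ? 0 : 1;   // use the feature only if every rank succeeded
}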
@ -459,10 +459,9 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
int buffSlot = (sub->base+sub->posted)%NCCL_STEPS;
char* ptr;
int sharedBuffSlot = sub->posted%NCCL_STEPS;
NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, p == NCCL_PROTO_SIMPLE ? resources->useGdr : 0, 1, sharedBuffSlot, 0, &ptr));
args->sharedBuff[sharedBuffSlot] = ptr;
int slotSize = sub->connector->comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
reqFifo[group][buffSlot].recvBuff = args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*slotSize;
int startChannel = group*COLLNET_GROUP_NSUBS;
NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, p == NCCL_PROTO_SIMPLE ? resources->useGdr : 0, 1, sharedBuffSlot, startChannel, &ptr));
reqFifo[group][buffSlot].recvBuff = ptr;
TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff);
sub->posted += args->sliceSteps;
args->idle = 0;
@ -478,9 +477,10 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
TRACE(NCCL_NET, "recvProxy [%d/%d/%d] received, size %d", sub->received, group, buffSlot, totalSize);
sub->received += args->sliceSteps;
if (reqFifo[group][buffSlot].size > 0 && p == NCCL_PROTO_SIMPLE && resources->useGdr) {
int slotSize = sub->connector->comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
char* recvAddress = (char*)args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*slotSize;
NCCLCHECK(collNetIflush(resources->collNetComm, recvAddress, totalSize, mhandle, sub->requests+buffSlot));
int startChannel = group*COLLNET_GROUP_NSUBS;
char* groupRecvAddress;
NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, 1, 1, sharedBuffSlot, startChannel, &groupRecvAddress));
NCCLCHECK(collNetIflush(resources->collNetComm, groupRecvAddress, totalSize, mhandle, sub->requests+buffSlot));
} else {
for (int i=group*COLLNET_GROUP_NSUBS; i<=s; i++) args->subs[i].flushed += args->sliceSteps;
}
@ -505,8 +505,10 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
int group = s / COLLNET_GROUP_NSUBS;
int buffSlot = (sub->base + sub->transmitted)%NCCL_STEPS;
int sharedBuffSlot = sub->transmitted%NCCL_STEPS;
int slotSize = sub->connector->comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
char* ptr = args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*slotSize + (s%COLLNET_GROUP_NSUBS)*args->sharedSize[sharedBuffSlot];
int startChannel = group*COLLNET_GROUP_NSUBS;
char* groupRecvAddress;
NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, 1, 1, sharedBuffSlot, startChannel, &groupRecvAddress));
char* ptr = groupRecvAddress + (s%COLLNET_GROUP_NSUBS)*args->sharedSize[sharedBuffSlot];
if (p == NCCL_PROTO_SIMPLE) {
volatile void** ptrsFifo = (volatile void**)resources->recvMem->ptrsFifo;
ptrsFifo[buffSlot] = ptr;
@ -201,7 +201,7 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) {
}
line[1023] = '\0';
char addrline[SOCKET_NAME_MAXLEN+1];
INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s ; OOB %s:%s", line, ncclIbIfName, socketToString(&ncclIbIfAddr.sa, addrline));
INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s ; OOB %s:%s", line, ncclIbIfName, socketToString(&ncclIbIfAddr, addrline));
}
pthread_mutex_unlock(&ncclIbLock);
}
@ -252,10 +252,12 @@ ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) {

#define MAX_REQUESTS NCCL_NET_MAX_REQUESTS

#define NCCL_IB_MAX_QPS 128

struct ncclIbQpInfo {
uint32_t lid;
uint8_t ib_port;
uint32_t qpn;
uint32_t qpn[NCCL_IB_MAX_QPS];

// For RoCE
uint64_t spn;
@ -277,6 +279,7 @@ struct ncclIbRequest {
struct ncclIbVerbs* verbs;
int events;
int size;
union socketAddress *addr;
};

struct ncclIbVerbs {
@ -305,8 +308,10 @@ struct ncclIbSendComm {
struct ncclIbSendFifo fifo[MAX_REQUESTS];
uint32_t fifoHead;
int fd;
union socketAddress addr;
int ready;
struct ibv_qp* qp;
struct ibv_qp* qps[NCCL_IB_MAX_QPS];
int nqps;
struct ibv_mr* fifoMr;
};
// The SendFifo needs to be 32-byte aligned and each element needs
@ -337,16 +342,20 @@ struct ncclIbRecvComm {
struct ncclIbVerbs verbs;
struct ncclIbRemFifo remFifo;
int fd;
union socketAddress addr;
int ready;
struct ibv_qp* qp;
struct ibv_qp* qps[NCCL_IB_MAX_QPS];
int nqps;
struct ncclIbGpuFlush gpuFlush;
};
static_assert((offsetof(struct ncclIbRecvComm, remFifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned");

NCCL_PARAM(IbQpsPerConn, "IB_QPS_PER_CONNECTION", 1);

ncclResult_t ncclIbInitVerbs(ibv_context* ctx, struct ncclIbVerbs* verbs) {
NCCLCHECK(wrap_ibv_alloc_pd(&verbs->pd, ctx));
// Recv requests can generate 2 completions (one for the post FIFO, one for the Recv).
NCCLCHECK(wrap_ibv_create_cq(&verbs->cq, ctx, 2*MAX_REQUESTS, NULL, NULL, 0));
NCCLCHECK(wrap_ibv_create_cq(&verbs->cq, ctx, 2*MAX_REQUESTS*ncclParamIbQpsPerConn(), NULL, NULL, 0));
return ncclSuccess;
}
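NCCL_PARAM(IbQpsPerConn, "IB_QPS_PER_CONNECTION", 1) defines ncclParamIbQpsPerConn() with a default of one QP per connection, and the completion queue is sized up accordingly. NCCL_PARAM conventionally reads the NCCL_-prefixed environment variable; as a rough standalone illustration of that kind of lookup (not NCCL's macro, and the clamp to the NCCL_IB_MAX_QPS array bound is an assumption):

#include <stdlib.h>

// Hypothetical reader for NCCL_IB_QPS_PER_CONNECTION: env override, default 1,
// clamped to [1, 128] to fit the NCCL_IB_MAX_QPS-sized arrays declared above.
static int qpsPerConnection(void) {
  const char* s = getenv("NCCL_IB_QPS_PER_CONNECTION");
  long v = s ? strtol(s, NULL, 10) : 1;
  if (v < 1) v = 1;
  if (v > 128) v = 128;
  return (int)v;
}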
@ -379,12 +388,12 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbVerbs* verbs, int acce
return ncclSuccess;
}

ncclResult_t ncclIbRtrQp(ibv_qp* qp, struct ncclIbQpInfo* info) {
ncclResult_t ncclIbRtrQp(ibv_qp* qp, uint32_t qpn, struct ncclIbQpInfo* info) {
struct ibv_qp_attr qpAttr;
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
qpAttr.qp_state = IBV_QPS_RTR;
qpAttr.path_mtu = info->mtu;
qpAttr.dest_qp_num = info->qpn;
qpAttr.dest_qp_num = qpn;
qpAttr.rq_psn = 0;
qpAttr.max_dest_rd_atomic = 1;
qpAttr.min_rnr_timer = 12;
@ -441,18 +450,23 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) {
NCCLCHECK(connectAddress(&comm->fd, &handle->connectAddr));
*sendComm = comm;

comm->addr = handle->connectAddr;

// IB Setup
ibv_context* ctx = ncclIbDevs[dev].context;
NCCLCHECK(ncclIbInitVerbs(ctx, &comm->verbs));
uint8_t ib_port = ncclIbDevs[dev].port;
NCCLCHECK(ncclIbCreateQp(ib_port, &comm->verbs, IBV_ACCESS_REMOTE_WRITE, &comm->qp));
comm->nqps = ncclParamIbQpsPerConn();
for (int q=0; q<comm->nqps; q++) {
NCCLCHECK(ncclIbCreateQp(ib_port, &comm->verbs, IBV_ACCESS_REMOTE_WRITE, comm->qps+q));
}

// Send my QP Info to receiver through the socket. Hope this won't block.
struct ibv_port_attr portAttr;
NCCLCHECK(wrap_ibv_query_port(ctx, ib_port, &portAttr));
struct ncclIbQpInfo qpInfo;
qpInfo.ib_port = ib_port;
qpInfo.qpn = comm->qp->qp_num;
for (int q=0; q<comm->nqps; q++) qpInfo.qpn[q] = comm->qps[q]->qp_num;
qpInfo.mtu = portAttr.active_mtu;

// Prepare my fifo
@ -463,16 +477,18 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) {
// RoCE support
qpInfo.lid = portAttr.lid;
if (qpInfo.lid) { // IB
INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d LID %d", dev, ib_port, qpInfo.qpn, qpInfo.mtu, qpInfo.lid);
for (int q=0; q<comm->nqps; q++)
INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d LID %d", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, qpInfo.lid);
} else { // RoCE
union ibv_gid gid;
NCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, ncclParamIbGidIndex(), &gid));
qpInfo.spn = gid.global.subnet_prefix;
qpInfo.iid = gid.global.interface_id;
INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d GID %ld (%lX/%lX)", dev, ib_port, qpInfo.qpn, qpInfo.mtu, ncclParamIbGidIndex(), qpInfo.spn, qpInfo.iid);
for (int q=0; q<comm->nqps; q++)
INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d GID %ld (%lX/%lX)", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, ncclParamIbGidIndex(), qpInfo.spn, qpInfo.iid);
}

NCCLCHECK(socketSend(comm->fd, &qpInfo, sizeof(qpInfo)));
NCCLCHECK(socketSend(comm->fd, &comm->addr, &qpInfo, sizeof(qpInfo)));
return ncclSuccess;
}

@ -483,11 +499,10 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {
struct ncclIbRecvComm* rComm;
NCCLCHECK(ncclIbMalloc((void**)&rComm, sizeof(struct ncclIbRecvComm)));

struct sockaddr_in sockaddr;
socklen_t socklen = sizeof(struct sockaddr_in);
SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd);
socklen_t socklen = sizeof(union socketAddress);
SYSCHECKVAL(accept(lComm->fd, &rComm->addr.sa, &socklen), "accept", rComm->fd);
struct ncclIbQpInfo remQpInfo;
NCCLCHECK(socketRecv(rComm->fd, &remQpInfo, sizeof(remQpInfo)));
NCCLCHECK(socketRecv(rComm->fd, &rComm->addr, &remQpInfo, sizeof(remQpInfo)));

// IB setup
ibv_context* ctx = ncclIbDevs[lComm->dev].context;
@ -499,15 +514,20 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {

// QP Creation
NCCLCHECK(ncclIbInitVerbs(ctx, &rComm->verbs));
NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_REMOTE_WRITE, &rComm->qp));
rComm->nqps = ncclParamIbQpsPerConn();
for (int q=0; q<rComm->nqps; q++) {
NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_REMOTE_WRITE, rComm->qps+q));
}

// Adjust the MTU
remQpInfo.mtu = (enum ibv_mtu)std::min(remQpInfo.mtu, portAttr.active_mtu);

// Setup QP
struct ibv_qp* qp = rComm->qp;
NCCLCHECK(ncclIbRtrQp(qp, &remQpInfo));
NCCLCHECK(ncclIbRtsQp(qp));
for (int q=0; q<rComm->nqps; q++) {
struct ibv_qp* qp = rComm->qps[q];
NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo));
NCCLCHECK(ncclIbRtsQp(qp));
}

// Retain remote fifo info and prepare my RDMA ops
rComm->remFifo.rkey = remQpInfo.fifoRkey;
@ -525,29 +545,26 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {
rComm->gpuFlush.sge.length = 1;
rComm->gpuFlush.sge.lkey = rComm->gpuFlush.hostMr->lkey;
NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ, &rComm->gpuFlush.qp));
struct ncclIbQpInfo localQpInfo = {
.lid=portAttr.lid,
.ib_port=ib_port,
.qpn=rComm->gpuFlush.qp->qp_num,
.spn=gid.global.subnet_prefix,
.iid=gid.global.interface_id,
.mtu=portAttr.active_mtu
};
NCCLCHECK(ncclIbRtrQp(rComm->gpuFlush.qp, &localQpInfo));
struct ncclIbQpInfo localQpInfo;
localQpInfo.lid=portAttr.lid;
localQpInfo.ib_port=ib_port;
localQpInfo.spn=gid.global.subnet_prefix;
localQpInfo.iid=gid.global.interface_id;
localQpInfo.mtu=portAttr.active_mtu;
NCCLCHECK(ncclIbRtrQp(rComm->gpuFlush.qp, rComm->gpuFlush.qp->qp_num, &localQpInfo));
NCCLCHECK(ncclIbRtsQp(rComm->gpuFlush.qp));
}

// Fill Handle
struct ncclIbQpInfo qpInfo = {
.lid=portAttr.lid,
.ib_port=ib_port,
.qpn=qp->qp_num,
.spn=gid.global.subnet_prefix,
.iid=gid.global.interface_id,
.mtu=remQpInfo.mtu
};
struct ncclIbQpInfo qpInfo;
qpInfo.lid=portAttr.lid;
qpInfo.ib_port=ib_port;
for (int q=0; q<rComm->nqps; q++) qpInfo.qpn[q]=rComm->qps[q]->qp_num;
qpInfo.spn=gid.global.subnet_prefix;
qpInfo.iid=gid.global.interface_id;
qpInfo.mtu=remQpInfo.mtu;

NCCLCHECK(socketSend(rComm->fd, &qpInfo, sizeof(qpInfo)));
NCCLCHECK(socketSend(rComm->fd, &rComm->addr, &qpInfo, sizeof(qpInfo)));
*recvComm = rComm;
return ncclSuccess;
}
@ -561,6 +578,7 @@ ncclResult_t ncclIbGetRequest(struct ncclIbVerbs* verbs, struct ncclIbRequest**
r->verbs = verbs;
r->events = 1;
r->size = -1;
r->addr = NULL;
*req = r;
return ncclSuccess;
}
@ -576,19 +594,21 @@ ncclResult_t ncclIbFreeRequest(struct ncclIbRequest* r) {

ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) {
struct ncclIbQpInfo remQpInfo;
struct ibv_qp* qp = comm->qp;

// Do not block on this receive, return if not ready.
int bytes = 0;
NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &remQpInfo, sizeof(remQpInfo), &bytes));
NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &remQpInfo, sizeof(remQpInfo), &bytes));
if (bytes == 0) return ncclSuccess; // Try again later
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &remQpInfo, sizeof(remQpInfo), &bytes));
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &remQpInfo, sizeof(remQpInfo), &bytes));

NCCLCHECK(ncclIbRtrQp(qp, &remQpInfo));
NCCLCHECK(ncclIbRtsQp(qp));
for (int q=0; q<comm->nqps; q++) {
struct ibv_qp* qp = comm->qps[q];
NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo));
NCCLCHECK(ncclIbRtsQp(qp));
}
comm->ready = 1;
// Block until this is done. It *should* not block indefinitely.
NCCLCHECK(socketSend(comm->fd, &comm->ready, sizeof(int)));
NCCLCHECK(socketSend(comm->fd, &comm->addr, &comm->ready, sizeof(int)));

return ncclSuccess;
}
@ -596,9 +616,9 @@ ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) {
ncclResult_t ncclRecvCheck(struct ncclIbRecvComm* comm) {
// Do not block on this receive, return if not ready.
int bytes = 0;
NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &comm->ready, sizeof(int), &bytes));
NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &comm->ready, sizeof(int), &bytes));
if (bytes == 0) return ncclSuccess; // Try again later
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &comm->ready, sizeof(int), &bytes));
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &comm->ready, sizeof(int), &bytes));
return ncclSuccess;
}

@ -643,20 +663,15 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
struct ncclIbRequest* req;
NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
req->size = size;
req->addr = &comm->addr;

struct ibv_send_wr wr[2];
memset(&wr[0], 0, sizeof(wr[0]));
wr[0].wr_id = (uint64_t)req;

struct ibv_sge sge;
if (size == 0) {
wr[0].sg_list = NULL;
wr[0].num_sge = 0;
} else {
sge.addr=(uintptr_t)data; sge.length=(unsigned int)size; sge.lkey=mr->lkey;
wr[0].sg_list = &sge;
wr[0].num_sge = 1;
}
sge.addr=(uintptr_t)data; sge.lkey=mr->lkey;

#if USE_RDMA_WRITE == 0
wr[0].opcode = IBV_WR_SEND;
wr[0].send_flags = IBV_SEND_SIGNALED;
@ -665,8 +680,9 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
// Sanity checks to catch user collective call count/size mismatches
// plus any potential programming errors
if (size > slot->size || slot->size < 0 || slot->addr == 0 || slot->rkey == 0 || slot->seq != comm->fifoHead) {
WARN("NET/IB : collective mismatch error local size %d remote %d addr %lx rkey %x seq %x/%x",
size, slot->size, slot->addr, slot->rkey, slot->seq, comm->fifoHead);
char line[SOCKET_NAME_MAXLEN+1];
WARN("NET/IB : peer %s collective mismatch error local size %d remote %d addr %lx rkey %x seq %x/%x",
socketToString(req->addr, line), size, slot->size, slot->addr, slot->rkey, slot->seq, comm->fifoHead);
return ncclInternalError;
}
wr[0].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
@ -703,8 +719,26 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
}
#endif

struct ibv_send_wr* bad_wr;
NCCLCHECK(wrap_ibv_post_send(comm->qp, wr, &bad_wr));
int chunkSize = std::max(8, DIVUP(size, comm->nqps));

int offset = 0;
for (int q=0; q<comm->nqps; q++) {
int length = std::min(size-offset, chunkSize);
if (length <= 0) {
wr[0].sg_list = NULL;
wr[0].num_sge = 0;
} else {
sge.length = length;
wr[0].sg_list = &sge;
wr[0].num_sge = 1;
}
struct ibv_send_wr* bad_wr;
NCCLCHECK(wrap_ibv_post_send(comm->qps[q], wr, &bad_wr));
offset += chunkSize;
sge.addr += chunkSize;
wr[0].wr.rdma.remote_addr += chunkSize;
}
req->events = comm->nqps;

*request = req;
return ncclSuccess;
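With several QPs per connection, ncclIbIsend now splits each message into chunks of max(8, DIVUP(size, nqps)) bytes and posts one work request per QP, so the request only completes after nqps completions (empty chunks are still posted so the count matches). A small worked example of that arithmetic, assuming DIVUP is round-up division as sketched below:

#include <algorithm>
#include <cstdio>

#define DIVUP(x, y) (((x) + (y) - 1) / (y))   // round-up division, as assumed above

int main() {
  int size = 10, nqps = 4;
  int chunkSize = std::max(8, DIVUP(size, nqps));       // max(8, 3) = 8
  for (int q = 0, offset = 0; q < nqps; q++, offset += chunkSize) {
    int length = std::min(size - offset, chunkSize);    // 8, then 2, then <= 0
    // QPs with nothing left to send still get an empty post so every QP signals a completion.
    printf("QP %d sends %d bytes\n", q, length > 0 ? length : 0);
  }
  return 0;
}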
@ -757,7 +791,7 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, uint32_t rkey, uint64_t
}

struct ibv_send_wr* bad_wr;
NCCLCHECK(wrap_ibv_post_send(comm->qp, &wr, &bad_wr));
NCCLCHECK(wrap_ibv_post_send(comm->qps[0], &wr, &bad_wr));
comm->remFifo.tail++;

return ncclSuccess;
@ -773,23 +807,22 @@ ncclResult_t ncclIbIrecv(void* recvComm, void* data, int size, void* mhandle, vo
struct ncclIbRequest* req;
NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
req->size = size;
req->addr = &comm->addr;

struct ibv_recv_wr wr;
memset(&wr, 0, sizeof(wr));
wr.wr_id = (uint64_t)req;

struct ibv_sge sge;
if (size == 0) {
wr.sg_list = NULL;
wr.num_sge = 0;
} else {
sge.addr=(uintptr_t)data; sge.length=(unsigned int)size; sge.lkey=mr->lkey;
wr.sg_list = &sge;
wr.num_sge = 1;
}
wr.sg_list = NULL;
wr.num_sge = 0;

for (int q=0; q<comm->nqps; q++) {
struct ibv_qp* qp = comm->qps[q];
struct ibv_recv_wr* bad_wr;
NCCLCHECK(wrap_ibv_post_recv(qp, &wr, &bad_wr));
}
req->events = comm->nqps;

struct ibv_recv_wr* bad_wr;
NCCLCHECK(wrap_ibv_post_recv(comm->qp, &wr, &bad_wr));
*request = req;

// Post to FIFO to notify sender
@ -803,6 +836,7 @@ ncclResult_t ncclIbIflush(void* recvComm, void* data, int size, void* mhandle, v

struct ncclIbRequest* req;
NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
req->addr = &comm->addr;
struct ibv_mr* mr = (struct ibv_mr*)mhandle;

struct ibv_send_wr wr;
@ -843,7 +877,9 @@ ncclResult_t ncclIbTest(void* request, int* done, int* size) {
for (int w=0; w<wrDone; w++) {
struct ibv_wc *wc = wcs+w;
if (wc->status != IBV_WC_SUCCESS) {
WARN("NET/IB : Got completion with error %d, opcode %d, len %d, vendor err %d", wc->status, wc->opcode, wc->byte_len, wc->vendor_err);
char line[SOCKET_NAME_MAXLEN+1];
WARN("NET/IB : Got completion from peer %s with error %d, opcode %d, len %d, vendor err %d",
socketToString(r->addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err);
return ncclSystemError;
}

@ -853,7 +889,10 @@ ncclResult_t ncclIbTest(void* request, int* done, int* size) {
doneReq->size = wc->byte_len;
#if USE_RDMA_WRITE
} else if (wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
doneReq->size = wc->imm_data;
if (doneReq->size == -1)
doneReq->size = wc->imm_data;
else
doneReq->size += wc->imm_data;
#endif
}
doneReq->events--;
@ -866,7 +905,8 @@ ncclResult_t ncclIbCloseSend(void* sendComm) {
struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm;
if (comm) {
close(comm->fd);
if (comm->qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qp));
for (int q=0; q<comm->nqps; q++)
if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q]));
if (comm->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->fifoMr));
NCCLCHECK(ncclIbDestroyVerbs(&comm->verbs));
free(comm);
@ -878,7 +918,8 @@ ncclResult_t ncclIbCloseRecv(void* recvComm) {
struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm;
if (comm) {
close(comm->fd);
if (comm->qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qp));
for (int q=0; q<comm->nqps; q++)
if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q]));
if (comm->gpuFlush.enabled) {
if (comm->gpuFlush.qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->gpuFlush.qp));
if (comm->gpuFlush.hostMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->gpuFlush.hostMr));
@ -56,7 +56,7 @@ ncclResult_t ncclSocketInit(ncclDebugLogger_t logFunction) {
memcpy(&ncclSocketDevs[i].addr, addrs+i, sizeof(union socketAddress));
NCCLCHECK(ncclSocketGetPciPath(ncclSocketDevs[i].devName, &ncclSocketDevs[i].pciPath));
snprintf(line+strlen(line), MAX_LINE_LEN-strlen(line), " [%d]%s:%s", i, names+i*MAX_IF_NAME_SIZE,
socketToString(&addrs[i].sa, addrline));
socketToString(&addrs[i], addrline));
}
line[MAX_LINE_LEN] = '\0';
INFO(NCCL_INIT|NCCL_NET,"NET/Socket : Using%s", line);
@ -129,6 +129,7 @@ struct ncclSocketTask {
void* data;
int size;
int fd;
union socketAddress *addr;
int offset;
int used;
ncclResult_t result;
@ -139,6 +140,7 @@ struct ncclSocketRequest {
void* data;
int size;
int ctrlFd;
union socketAddress *addr;
int offset;
int used;
struct ncclSocketComm* comm;
@ -170,6 +172,7 @@ struct ncclSocketListenComm {

struct ncclSocketComm {
int ctrlFd;
union socketAddress addr;
int fds[MAX_SOCKETS];
int nSocks;
int nThreads;
@ -195,7 +198,7 @@ void* persistentSocketThread(void *args_) {
for (int j=0; j<nSocksPerThread; j++) {
struct ncclSocketTask* r = myQueue->tasks+i+j;
if (r != NULL && r->used == 1 && r->offset < r->size) {
r->result = socketProgress(r->op, r->fd, r->data, r->size, &r->offset);
r->result = socketProgress(r->op, r->fd, r->addr, r->data, r->size, &r->offset);
if (r->result != ncclSuccess) {
WARN("NET/Socket : socket progress error");
return NULL;
@ -311,11 +314,12 @@ ncclResult_t ncclSocketConnect(int dev, void* opaqueHandle, void** sendComm) {
for (int i=0; i<comm->nSocks+1; i++) {
int tmpFd, offset=0;
NCCLCHECK(connectAddress(&tmpFd, &handle->connectAddr));
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &i, sizeof(int), &offset));
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &handle->connectAddr, &i, sizeof(int), &offset));
if (i == comm->nSocks) comm->ctrlFd = tmpFd;
else comm->fds[i] = tmpFd;
}
*sendComm = comm;
comm->addr = handle->connectAddr;
return ncclSuccess;
}

@ -327,10 +331,9 @@ ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) {
rComm->nThreads = lComm->nThreads;
for (int i=0; i<rComm->nSocks+1; i++) {
int tmpFd, sendSockIdx, offset=0;
struct sockaddr_in sockaddr;
socklen_t socklen = sizeof(struct sockaddr_in);
SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", tmpFd);
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &sendSockIdx, sizeof(int), &offset));
socklen_t socklen = sizeof(union socketAddress);
SYSCHECKVAL(accept(lComm->fd, &rComm->addr.sa, &socklen), "accept", tmpFd);
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &rComm->addr, &sendSockIdx, sizeof(int), &offset));
if (sendSockIdx == rComm->nSocks) rComm->ctrlFd = tmpFd;
else rComm->fds[sendSockIdx] = tmpFd;
}
@ -346,6 +349,7 @@ ncclResult_t ncclSocketGetRequest(struct ncclSocketComm* comm, int op, void* dat
r->data = data;
r->size = size;
r->ctrlFd = comm->ctrlFd;
r->addr = &comm->addr;
r->used = 1;
r->comm = comm;
r->nSubs = 0;
@ -380,6 +384,7 @@ ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data,
r->data = data;
r->size = size;
r->fd = comm->fds[comm->nextFd];
r->addr = &comm->addr;
r->offset = 0;
r->result = ncclSuccess;
comm->nextFd = (comm->nextFd + 1) % comm->nSocks;
@ -406,16 +411,17 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
if (r->used == 1) { /* try to send/recv size */
int data = r->size;
int offset = 0;
NCCLCHECK(socketProgress(r->op, r->ctrlFd, &data, sizeof(int), &offset));
NCCLCHECK(socketProgress(r->op, r->ctrlFd, r->addr, &data, sizeof(int), &offset));

if (offset == 0) return ncclSuccess; /* Not ready -- retry later */

// Not sure we could ever receive less than 4 bytes, but just in case ...
if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->ctrlFd, &data, sizeof(int), &offset));
if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->ctrlFd, r->addr, &data, sizeof(int), &offset));

// Check size is less or equal to the size provided by the user
if (r->op == NCCL_SOCKET_RECV && data > r->size) {
WARN("NET/Socket : message truncated : receiving %d bytes instead of %d", data, r->size);
char line[SOCKET_NAME_MAXLEN+1];
WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d", socketToString(r->addr, line), data, r->size);
return ncclInternalError;
}
r->size = data;
@ -453,7 +459,7 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
}
} else { // progress request using main thread
if (r->offset < r->size) {
NCCLCHECK(socketProgress(r->op, r->ctrlFd, r->data, r->size, &r->offset));
NCCLCHECK(socketProgress(r->op, r->ctrlFd, r->addr, r->data, r->size, &r->offset));
}
if (r->offset == r->size) {
if (size) *size = r->size;
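Both network backends now keep the peer's socketAddress next to each request so that error and truncation warnings can name the remote endpoint. socketToString is NCCL's own helper; a rough standalone equivalent using getnameinfo might look like this (illustrative only, with an assumed "host<port>" output format):

#include <netdb.h>
#include <stdio.h>
#include <sys/socket.h>

// Hypothetical stand-in for socketToString: format a sockaddr as "host<port>".
static const char* peerToString(const struct sockaddr* sa, socklen_t len, char* buf, size_t bufLen) {
  char host[256], serv[32];
  if (getnameinfo(sa, len, host, sizeof(host), serv, sizeof(serv),
                  NI_NUMERICHOST | NI_NUMERICSERV) != 0) {
    snprintf(buf, bufLen, "<unknown>");
  } else {
    snprintf(buf, bufLen, "%s<%s>", host, serv);
  }
  return buf;
}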
@ -53,8 +53,8 @@ static int busIdToCudaDev(int64_t busId) {

/* Determine if two peers can communicate through p2p */
ncclResult_t p2pCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) {
// Rule out different nodes
if (info1->hostHash != info2->hostHash) {
// Rule out different nodes / isolated containers
if (info1->hostHash != info2->hostHash || info1->shmDev != info2->shmDev) {
*ret = 0;
return ncclSuccess;
}