nccl/src/misc/socket.cc
Kamil Iskra f44ac759fe NCCL 2.26.2-1
Profiler improvements
 * Add events for CUDA kernel start and end.
 * Allow network plugins to generate profiling events
 * Enable profiling on a per-operation basis, rather than per-communicator.
 * Add support for graph capturing.

Add implicit launch order
 * Allow to prevent deadlocks when using multiple NCCL communicators per
   device by implicitly ordering NCCL operations using the host program
   order. Disabled by default, set NCCL_LAUNCH_ORDER_IMPLICIT=1 to enable.
 * Add a complementary mechanism to detect host threads racing to launch
   to the same device. Enabled by default, set NCCL_LAUNCH_RACE_FATAL=0 to
   disable.

Optimize the PAT algorithm
 * Separate the computation and execution of PAT steps on different warps,
   allowing to run up to 16 PAT steps in parallel to significantly
   accelerate PAT and reduce its linear part.

Add support for setting QoS per communicator
 * Add a new trafficClass field to the communicator configuration, to
   allow the application to select a particular traffic class for a
   given communicator. The meaning of the traffic class is
   network-specific and should be set in accordance with the network
   configuration.
 * For the IB/RoCE plugin, existing config variables such as NCCL_IB_SL
   and NCCL_IB_TC take precedence.

Allow to enable GPU Direct RDMA specifically on C2C platforms
 * Disabled by default, set NCCL_NET_GDR_C2C=1 to enable.

Do not disable user buffer registration unless PXN is really used
 * Only disable UB when a communicator has more than one rank per
   node on any node.

RAS subsystem improvements
 * Report operation counts separately for each collective operation type.
 * Provide details about missing communicator ranks and reliably
   distinguish ranks that are no longer a given communicator's members
   (now reported as NOCOMM) from those that failed to respond.

Add support for timestamps to NCCL diagnostic messages
 * On by default for WARN messages; NCCL_DEBUG_TIMESTAMP_LEVELS can be
   used to enable them for other debug levels as well.
 * The format can be changed using the NCCL_DEBUG_TIMESTAMP_FORMAT config
   variable.

Reduce the memory usage with NVLink SHARP (NVLS)
 * Potentially save hundreds of MBs of device memory, considering the
   multicast buffer size granularity separately from the address alignment.

Update performance tuning for recent Intel CPUs
 * Improve algorithm/protocol selection on recent CPUs such as Emerald
   Rapids and Sapphire Rapids.

Improve channel scheduling when mixing LL and Simple operations.
 * Make LL operations account for 4x more traffic to ensure LL and simple
   operations complete at the same time.

Refactor the plugin code
 * Clean up and harmonize the support code across the network, tuner,
   and profiler plugins.

Add support for comment lines (starting with #) in the nccl.conf file
* Issue #1540.

Make user buffer registration problems print an INFO instead of a WARN.

Drop support for network plugin interface version 5.

Fix a race condition with split-shared communicators
 * NCCL could hang during connection setup if multiple communicators
   were grouped together that share resources.

Fix a performance regression when using NCCL_CROSS_NIC=1
 * NCCL would unnecessarily alternate rings, breaking the GPU-NIC
   associations.

Make GID index detection code more resilient
 * Dynamic GID detection code was giving up too soon if the
   detected index was not available (e.g., wasn't mapped to the
   container's sysfs).
 * Issues #1538, #1573.

Fix a race condition with non-blocking operation
 * Fix issue when creating a non-blocking communicator after a non-
   blocking collective operation on another communicator.

Fix shared memory usage on recent Blackwell GPUs.
 * Issues NVIDIA/nccl-tests#287, NVIDIA/nccl-tests#291, #1637.

Fix an error with NIC fusion and IB SHARP when recreating communicators
 * Disable the unloading of network plugins

Make the auto-merge failures in the NIC fusion non-fatal
 * This could happen when trying to merge IB and RoCE devices.

Fixes to ncclCommAbort
 * Fix hangs due to the progress thread spinning indefinitely on the
   network progress.
 * Reduce the abort time by up to two orders of magnitude.

Fix a crash when libnccl.so was dynamically unloaded
 * The RAS subsystem was missing a clean-up handler.

Fix a hang if the network plugin's test() call returns an error.

Fix a hang on heterogeneous architectures
 * Ensure we harmonize the tuning to avoid different tuning choices,
   causing a hang.

Fix double-free on failed ncclCommInitRank and ncclCommFinalize.

Fix a potential list traversal bug during a group launch of multiple
communicators
 * Issue #1599.

Unify the handling of NCCL configuration variables
 * Under rare circumstances, some variables specified in the config file
   could be ignored.
2025-03-12 13:46:21 -07:00

950 lines
34 KiB
C++

/*************************************************************************
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "socket.h"
#include "utils.h"
#include <stdlib.h>
#include <unistd.h>
#include <ifaddrs.h>
#include <net/if.h>
#include "param.h"
#include <time.h>
NCCL_PARAM(RetryCnt, "SOCKET_RETRY_CNT", 34);
NCCL_PARAM(RetryTimeOut, "SOCKET_RETRY_SLEEP_MSEC", 100);
static void msleep(unsigned int time_msec) {
const long c_1e6 = 1e6;
struct timespec tv = (struct timespec){
.tv_sec = time_msec / 1000,
.tv_nsec = (time_msec % 1000) * c_1e6,
};
nanosleep(&tv, NULL);
}
static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) {
int bytes = 0;
*closed = 0;
char* data = (char*)ptr;
char line[SOCKET_NAME_MAXLEN+1];
do {
if (op == NCCL_SOCKET_RECV) bytes = recv(sock->fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
if (op == NCCL_SOCKET_SEND) bytes = send(sock->fd, data+(*offset), size-(*offset), block ? MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL);
if (op == NCCL_SOCKET_RECV && bytes == 0) {
*closed = 1;
return ncclSuccess;
}
if (bytes == -1) {
if ((op == NCCL_SOCKET_SEND && errno == EPIPE) || (op == NCCL_SOCKET_RECV && errno == ECONNRESET)) {
*closed = 1;
return ncclSuccess;
}
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
WARN("socketProgressOpt: Call to %s %s failed : %s", (op == NCCL_SOCKET_RECV ? "recv from" : "send to"),
ncclSocketToString(&sock->addr, line), strerror(errno));
return ncclRemoteError;
} else {
bytes = 0;
}
}
(*offset) += bytes;
if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) {
INFO(NCCL_NET, "socketProgressOpt: abort called");
return ncclInternalError;
}
} while (sock->asyncFlag == 0 && bytes > 0 && (*offset) < size);
return ncclSuccess;
}
static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int* pclosed = NULL) {
int closed;
NCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0 /*block*/, &closed));
if (closed) {
if (pclosed) {
*pclosed = closed;
return ncclSuccess;
} else {
char line[SOCKET_NAME_MAXLEN+1];
WARN("socketProgress: Connection closed by remote peer %s", ncclSocketToString(&sock->addr, line, 0));
return ncclRemoteError;
}
}
return ncclSuccess;
}
static ncclResult_t socketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) {
while (*offset < size)
NCCLCHECK(socketProgress(op, sock, ptr, size, offset));
return ncclSuccess;
}
/* Format a string representation of a (union ncclSocketAddress *) socket address using getnameinfo()
*
* Output: "IPv4/IPv6 address<port>"
*/
const char *ncclSocketToString(const union ncclSocketAddress *addr, char *buf, const int numericHostForm /*= 1*/) {
if (buf == NULL || addr == NULL) return NULL;
const struct sockaddr *saddr = &addr->sa;
if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) { buf[0]='\0'; return buf; }
char host[NI_MAXHOST], service[NI_MAXSERV];
/* NI_NUMERICHOST: If set, then the numeric form of the hostname is returned.
* (When not set, this will still happen in case the node's name cannot be determined.)
*/
int flag = NI_NUMERICSERV | (numericHostForm ? NI_NUMERICHOST : 0);
(void) getnameinfo(saddr, sizeof(union ncclSocketAddress), host, NI_MAXHOST, service, NI_MAXSERV, flag);
sprintf(buf, "%s<%s>", host, service);
return buf;
}
static uint16_t socketToPort(union ncclSocketAddress *addr) {
struct sockaddr *saddr = &addr->sa;
return ntohs(saddr->sa_family == AF_INET ? addr->sin.sin_port : addr->sin6.sin6_port);
}
/* Allow the user to force the IPv4/IPv6 interface selection */
static int envSocketFamily(void) {
int family = -1; // Family selection is not forced, will use first one found
const char* env = ncclGetEnv("NCCL_SOCKET_FAMILY");
if (env == NULL)
return family;
INFO(NCCL_ENV, "NCCL_SOCKET_FAMILY set by environment to %s", env);
if (strcmp(env, "AF_INET") == 0)
family = AF_INET; // IPv4
else if (strcmp(env, "AF_INET6") == 0)
family = AF_INET6; // IPv6
return family;
}
static int findInterfaces(const char* prefixList, char* names, union ncclSocketAddress *addrs, int sock_family, int maxIfNameSize, int maxIfs) {
#ifdef ENABLE_TRACE
char line[SOCKET_NAME_MAXLEN+1];
#endif
struct netIf userIfs[MAX_IFS];
bool searchNot = prefixList && prefixList[0] == '^';
if (searchNot) prefixList++;
bool searchExact = prefixList && prefixList[0] == '=';
if (searchExact) prefixList++;
int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
for (interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) {
if (interface->ifa_addr == NULL) continue;
/* We only support IPv4 & IPv6 */
int family = interface->ifa_addr->sa_family;
if (family != AF_INET && family != AF_INET6)
continue;
TRACE(NCCL_INIT|NCCL_NET,"Found interface %s:%s", interface->ifa_name, ncclSocketToString((union ncclSocketAddress *) interface->ifa_addr, line));
/* Allow the caller to force the socket family type */
if (sock_family != -1 && family != sock_family)
continue;
/* We also need to skip IPv6 loopback interfaces */
if (family == AF_INET6) {
struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr);
if (IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr)) continue;
}
// check against user specified interfaces
if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
continue;
}
// Check that this interface has not already been saved
// getifaddrs() normal order appears to be; IPv4, IPv6 Global, IPv6 Link
bool duplicate = false;
for (int i = 0; i < found; i++) {
if (strcmp(interface->ifa_name, names+i*maxIfNameSize) == 0) { duplicate = true; break; }
}
if (!duplicate) {
// Store the interface name
strncpy(names+found*maxIfNameSize, interface->ifa_name, maxIfNameSize);
// Store the IP address
int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
memset(addrs+found, '\0', sizeof(*addrs));
memcpy(addrs+found, interface->ifa_addr, salen);
found++;
}
}
freeifaddrs(interfaces);
return found;
}
static bool matchSubnet(struct ifaddrs local_if, union ncclSocketAddress* remote) {
/* Check family first */
int family = local_if.ifa_addr->sa_family;
if (family != remote->sa.sa_family) {
return false;
}
if (family == AF_INET) {
struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr);
struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask);
struct sockaddr_in& remote_addr = remote->sin;
struct in_addr local_subnet, remote_subnet;
local_subnet.s_addr = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr;
remote_subnet.s_addr = remote_addr.sin_addr.s_addr & mask->sin_addr.s_addr;
return (local_subnet.s_addr ^ remote_subnet.s_addr) ? false : true;
} else if (family == AF_INET6) {
struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr);
struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask);
struct sockaddr_in6& remote_addr = remote->sin6;
struct in6_addr& local_in6 = local_addr->sin6_addr;
struct in6_addr& mask_in6 = mask->sin6_addr;
struct in6_addr& remote_in6 = remote_addr.sin6_addr;
bool same = true;
int len = 16; //IPv6 address is 16 unsigned char
for (int c = 0; c < len; c++) { //Network byte order is big-endian
char c1 = local_in6.s6_addr[c] & mask_in6.s6_addr[c];
char c2 = remote_in6.s6_addr[c] & mask_in6.s6_addr[c];
if (c1 ^ c2) {
same = false;
break;
}
}
// At last, we need to compare scope id
// Two Link-type addresses can have the same subnet address even though they are not in the same scope
// For Global type, this field is 0, so a comparison wouldn't matter
same &= (local_addr->sin6_scope_id == remote_addr.sin6_scope_id);
return same;
} else {
WARN("Net : Unsupported address family type");
return false;
}
}
int ncclFindInterfaceMatchSubnet(char* ifNames, union ncclSocketAddress* localAddrs, union ncclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) {
#ifdef ENABLE_TRACE
char line[SOCKET_NAME_MAXLEN+1];
#endif
char line_a[SOCKET_NAME_MAXLEN+1];
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
for (interface = interfaces; interface && !found; interface = interface->ifa_next) {
if (interface->ifa_addr == NULL) continue;
/* We only support IPv4 & IPv6 */
int family = interface->ifa_addr->sa_family;
if (family != AF_INET && family != AF_INET6)
continue;
// check against user specified interfaces
if (!matchSubnet(*interface, remoteAddr)) {
continue;
}
// Store the local IP address
int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
memcpy(localAddrs+found, interface->ifa_addr, salen);
// Store the interface name
strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, ncclSocketToString(localAddrs+found, line), ncclSocketToString(remoteAddr, line_a));
found++;
if (found == maxIfs) break;
}
if (found == 0) {
WARN("Net : No interface found in the same subnet as remote address %s", ncclSocketToString(remoteAddr, line_a));
}
freeifaddrs(interfaces);
return found;
}
ncclResult_t ncclSocketGetAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair) {
if (!(ip_port_pair && strlen(ip_port_pair) > 1)) {
WARN("Net : string is null");
return ncclInvalidArgument;
}
bool ipv6 = ip_port_pair[0] == '[';
/* Construct the sockaddress structure */
if (!ipv6) {
struct netIf ni;
// parse <ip_or_hostname>:<port> string, expect one pair
if (parseStringList(ip_port_pair, &ni, 1) != 1) {
WARN("Net : No valid <IPv4_or_hostname>:<port> pair found");
return ncclInvalidArgument;
}
struct addrinfo hints, *p;
int rv;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
if ( (rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) {
WARN("Net : error encountered when getting address info : %s", gai_strerror(rv));
return ncclInvalidArgument;
}
// use the first
if (p->ai_family == AF_INET) {
struct sockaddr_in& sin = ua->sin;
memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in));
sin.sin_family = AF_INET; // IPv4
//inet_pton(AF_INET, ni.prefix, &(sin.sin_addr)); // IP address
sin.sin_port = htons(ni.port); // port
} else if (p->ai_family == AF_INET6) {
struct sockaddr_in6& sin6 = ua->sin6;
memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6));
sin6.sin6_family = AF_INET6; // IPv6
sin6.sin6_port = htons(ni.port); // port
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
sin6.sin6_scope_id = 0; // should be global scope, set to 0
} else {
WARN("Net : unsupported IP family");
freeaddrinfo(p);
return ncclInvalidArgument;
}
freeaddrinfo(p); // all done with this structure
} else {
int i, j = -1, len = strlen(ip_port_pair);
for (i = 1; i < len; i++) {
if (ip_port_pair[i] == '%') j = i;
if (ip_port_pair[i] == ']') break;
}
if (i == len) {
WARN("Net : No valid [IPv6]:port pair found");
return ncclInvalidArgument;
}
bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope
char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ];
memset(ip_str, '\0', sizeof(ip_str));
memset(port_str, '\0', sizeof(port_str));
memset(if_name, '\0', sizeof(if_name));
strncpy(ip_str, ip_port_pair+1, global_scope ? i-1 : j-1);
strncpy(port_str, ip_port_pair+i+2, len-i-1);
int port = atoi(port_str);
if (!global_scope) strncpy(if_name, ip_port_pair+j+1, i-j-1); // If not global scope, we need the intf name
struct sockaddr_in6& sin6 = ua->sin6;
sin6.sin6_family = AF_INET6; // IPv6
inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)); // IP address
sin6.sin6_port = htons(port); // port
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
sin6.sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope
}
return ncclSuccess;
}
int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNameMaxSize, int maxIfs) {
static int shownIfName = 0;
int nIfs = 0;
// Allow user to force the INET socket family selection
int sock_family = envSocketFamily();
// User specified interface
const char* env = ncclGetEnv("NCCL_SOCKET_IFNAME");
if (env && strlen(env) > 1) {
INFO(NCCL_ENV, "NCCL_SOCKET_IFNAME set by environment to %s", env);
// Specified by user : find or fail
if (shownIfName++ == 0) INFO(NCCL_NET, "NCCL_SOCKET_IFNAME set to %s", env);
nIfs = findInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
} else {
// Try to automatically pick the right one
// Start with IB
nIfs = findInterfaces("ib", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
// else see if we can get some hint from COMM ID
if (nIfs == 0) {
const char* commId = ncclGetEnv("NCCL_COMM_ID");
if (commId && strlen(commId) > 1) {
INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", commId);
// Try to find interface that is in the same subnet as the IP in comm id
union ncclSocketAddress idAddr;
ncclSocketGetAddrFromString(&idAddr, commId);
nIfs = ncclFindInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs);
}
}
// Then look for anything else (but not docker or lo)
if (nIfs == 0) nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
// Finally look for docker, then lo.
if (nIfs == 0) nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
if (nIfs == 0) nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
}
return nIfs;
}
ncclResult_t ncclSocketListen(struct ncclSocket* sock) {
if (sock == NULL) {
WARN("ncclSocketListen: pass NULL socket");
return ncclInvalidArgument;
}
if (sock->fd == -1) {
WARN("ncclSocketListen: file descriptor is -1");
return ncclInvalidArgument;
}
if (socketToPort(&sock->addr)) {
// Port is forced by env. Make sure we get the port.
int opt = 1;
SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt");
#if defined(SO_REUSEPORT)
SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
#endif
}
// addr port should be 0 (Any port)
SYSCHECK(bind(sock->fd, &sock->addr.sa, sock->salen), "bind");
/* Get the assigned Port */
socklen_t size = sock->salen;
SYSCHECK(getsockname(sock->fd, &sock->addr.sa, &size), "getsockname");
#ifdef ENABLE_TRACE
char line[SOCKET_NAME_MAXLEN+1];
TRACE(NCCL_INIT|NCCL_NET,"Listening on socket %s", ncclSocketToString(&sock->addr, line));
#endif
/* Put the socket in listen mode
* NB: The backlog will be silently truncated to the value in /proc/sys/net/core/somaxconn
*/
SYSCHECK(listen(sock->fd, 16384), "listen");
sock->state = ncclSocketStateReady;
return ncclSuccess;
}
ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr) {
if (sock == NULL) {
WARN("ncclSocketGetAddr: pass NULL socket");
return ncclInvalidArgument;
}
if (sock->state != ncclSocketStateReady) return ncclInternalError;
memcpy(addr, &sock->addr, sizeof(union ncclSocketAddress));
return ncclSuccess;
}
static ncclResult_t socketTryAccept(struct ncclSocket* sock) {
socklen_t socklen = sizeof(union ncclSocketAddress);
sock->fd = accept(sock->acceptFd, (struct sockaddr*)&sock->addr, &socklen);
if (sock->fd != -1) {
sock->state = ncclSocketStateAccepted;
} else if (errno == ENETDOWN || errno == EPROTO || errno == ENOPROTOOPT || errno == EHOSTDOWN ||
errno == ENONET || errno == EHOSTUNREACH || errno == EOPNOTSUPP || errno == ENETUNREACH) {
/* per accept's man page, for linux sockets, the following errors might be already pending errors
* and should be considered as EAGAIN. To avoid infinite loop in case of errors, we use the retry count*/
if (++sock->errorRetries == ncclParamRetryCnt()) {
WARN("socketTryAccept: exceeded error retry count (%d), %s", sock->errorRetries, strerror(errno));
return ncclSystemError;
}
INFO(NCCL_ALL, "Call to accept returned %s, retrying", strerror(errno));
} else if (errno != EAGAIN && errno != EWOULDBLOCK) {
WARN("socketTryAccept: Accept failed: %s", strerror(errno));
return ncclSystemError;
}
return ncclSuccess;
}
static ncclResult_t socketSetFlags(struct ncclSocket* sock) {
const int one = 1;
/* Set socket as non-blocking if async or if we need to be able to abort */
if ((sock->asyncFlag || sock->abortFlag) && sock->fd >= 0) {
int flags;
SYSCHECK(flags = fcntl(sock->fd, F_GETFL), "fcntl");
SYSCHECK(fcntl(sock->fd, F_SETFL, flags | O_NONBLOCK), "fcntl");
}
SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt");
return ncclSuccess;
}
static ncclResult_t socketFinalizeAccept(struct ncclSocket* sock) {
uint64_t magic;
enum ncclSocketType type;
int received;
// once accepted, linux sockets do NOT inherit file status flags such as O_NONBLOCK (BSD ones do)
NCCLCHECK(socketSetFlags(sock));
if (sock->asyncFlag == 0 || sock->finalizeCounter < sizeof(magic)) {
if (sock->asyncFlag == 0) {
received = 0;
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received));
} else {
received = sock->finalizeCounter;
NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, sock, sock->finalizeBuffer, sizeof(magic), &received));
sock->finalizeCounter = received;
if (received < sizeof(magic)) return ncclSuccess;
memcpy(&magic, sock->finalizeBuffer, sizeof(magic));
}
if (magic != sock->magic) {
WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic);
close(sock->fd);
sock->fd = -1;
// Ignore spurious connection and accept again
sock->state = ncclSocketStateAccepting;
return ncclSuccess;
}
}
if (sock->asyncFlag == 0) {
received = 0;
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &type, sizeof(type), &received));
} else {
received = sock->finalizeCounter - sizeof(magic);
NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, sock, sock->finalizeBuffer, sizeof(type), &received));
sock->finalizeCounter = received + sizeof(magic);
if (received < sizeof(type)) return ncclSuccess;
memcpy(&type, sock->finalizeBuffer, sizeof(type));
}
if (type != sock->type) {
WARN("socketFinalizeAccept: wrong type %d != %d", type, sock->type);
sock->state = ncclSocketStateError;
close(sock->fd);
sock->fd = -1;
return ncclInternalError;
} else {
sock->state = ncclSocketStateReady;
}
return ncclSuccess;
}
static ncclResult_t socketResetFd(struct ncclSocket* sock) {
ncclResult_t ret = ncclSuccess;
int fd = -1;
SYSCHECKGOTO(fd = socket(sock->addr.sa.sa_family, SOCK_STREAM, 0), "socket", ret, cleanup);
// if sock->fd is valid, close it and reuse its number
if (sock->fd != -1) {
SYSCHECKGOTO(dup2(fd, sock->fd), "dup2", ret, cleanup);
SYSCHECKGOTO(close(fd), "close", ret, cleanup);
} else {
sock->fd = fd;
}
NCCLCHECKGOTO(socketSetFlags(sock), ret, exit);
exit:
return ret;
cleanup:
// cleanup fd, leave sock->fd untouched
if (fd != -1) {
(void)close(fd);
}
goto exit;
}
static ncclResult_t socketConnectCheck(struct ncclSocket* sock, int errCode, const char funcName[]) {
if (errCode == 0) {
sock->state = ncclSocketStateConnected;
} else if (errCode == EINPROGRESS) {
sock->state = ncclSocketStateConnectPolling;
} else if (errCode == ETIMEDOUT || errCode == EHOSTUNREACH || errCode == ECONNREFUSED) {
if (sock->customRetry == 0) {
if (sock->errorRetries++ == ncclParamRetryCnt()) {
sock->state = ncclSocketStateError;
WARN("%s: connect returned %s, exceeded error retry count (%d)", funcName, strerror(errCode), sock->errorRetries);
return ncclRemoteError;
}
unsigned int sleepTime = sock->errorRetries * ncclParamRetryTimeOut();
INFO(NCCL_ALL, "%s: connect returned %s, retrying (%d/%ld) after sleep for %u msec", funcName, strerror(errCode), sock->errorRetries, ncclParamRetryCnt(), sleepTime);
msleep(sleepTime);
}
NCCLCHECK(socketResetFd(sock)); /* in case of failure in connect, socket state is unspecified */
sock->state = ncclSocketStateConnecting;
} else {
char line[SOCKET_NAME_MAXLEN+1];
sock->state = ncclSocketStateError;
WARN("%s: Connect to %s failed : %s", funcName, ncclSocketToString(&sock->addr, line), strerror(errCode));
return ncclSystemError;
}
return ncclSuccess;
}
static ncclResult_t socketStartConnect(struct ncclSocket* sock) {
/* blocking/non-blocking connect() is determined by asyncFlag. */
int ret = connect(sock->fd, &sock->addr.sa, sock->salen);
return socketConnectCheck(sock, (ret == -1) ? errno : 0, __func__);
}
static ncclResult_t socketPollConnect(struct ncclSocket* sock) {
struct pollfd pfd;
int timeout = 1, ret;
socklen_t rlen = sizeof(int);
memset(&pfd, 0, sizeof(struct pollfd));
pfd.fd = sock->fd;
pfd.events = POLLOUT;
ret = poll(&pfd, 1, timeout);
if (ret == 0 || (ret < 0 && errno == EINTR)) {
return ncclSuccess;
} else if (ret < 0) {
WARN("socketPollConnect poll() failed with error %s", strerror(errno));
return ncclRemoteError;
} else if (ret != 1 || (pfd.revents & POLLOUT) == 0) {
WARN("socketPollConnect poll() returned %d%s", ret, (pfd.revents & POLLOUT) ? "" : ", no POLLOUT events");
return ncclSystemError;
}
/* check socket status */
SYSCHECK(getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen), "getsockopt");
return socketConnectCheck(sock, ret, __func__);
}
ncclResult_t ncclSocketPollConnect(struct ncclSocket* sock) {
if (sock == NULL) {
WARN("ncclSocketPollConnect: pass NULL socket");
return ncclInvalidArgument;
}
NCCLCHECK(socketPollConnect(sock));
return ncclSuccess;
}
static ncclResult_t socketFinalizeConnect(struct ncclSocket* sock) {
int sent;
if (sock->asyncFlag == 0) {
sent = 0;
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent));
sent = 0;
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent));
} else {
if (sock->finalizeCounter < sizeof(sock->magic)) {
sent = sock->finalizeCounter;
NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent));
sock->finalizeCounter = sent;
if (sent < sizeof(sock->magic)) return ncclSuccess;
}
sent = sock->finalizeCounter - sizeof(sock->magic);
NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent));
sock->finalizeCounter = sent + sizeof(sock->magic);
if (sent < sizeof(sock->type)) return ncclSuccess;
}
sock->state = ncclSocketStateReady;
return ncclSuccess;
}
static ncclResult_t socketProgressState(struct ncclSocket* sock) {
if (sock->state == ncclSocketStateAccepting) {
NCCLCHECK(socketTryAccept(sock));
}
if (sock->state == ncclSocketStateAccepted) {
NCCLCHECK(socketFinalizeAccept(sock));
}
if (sock->state == ncclSocketStateConnecting) {
NCCLCHECK(socketStartConnect(sock));
}
if (sock->state == ncclSocketStateConnectPolling) {
NCCLCHECK(socketPollConnect(sock));
}
if (sock->state == ncclSocketStateConnected) {
NCCLCHECK(socketFinalizeConnect(sock));
}
return ncclSuccess;
}
ncclResult_t ncclSocketReady(struct ncclSocket* sock, int *running) {
if (sock == NULL) {
*running = 0;
return ncclSuccess;
}
if (sock->state == ncclSocketStateError || sock->state == ncclSocketStateClosed) {
WARN("ncclSocketReady: unexpected socket state %d", sock->state);
return ncclRemoteError;
}
*running = (sock->state == ncclSocketStateReady) ? 1 : 0;
if (*running == 0) {
NCCLCHECK(socketProgressState(sock));
*running = (sock->state == ncclSocketStateReady) ? 1 : 0;
}
return ncclSuccess;
}
ncclResult_t ncclSocketConnect(struct ncclSocket* sock) {
#ifdef ENABLE_TRACE
char line[SOCKET_NAME_MAXLEN+1];
#endif
if (sock == NULL) {
WARN("ncclSocketConnect: pass NULL socket");
return ncclInvalidArgument;
}
if (sock->fd == -1) {
WARN("ncclSocketConnect: file descriptor is -1");
return ncclInvalidArgument;
}
if (sock->state != ncclSocketStateInitialized) {
WARN("ncclSocketConnect: wrong socket state %d", sock->state);
if (sock->state == ncclSocketStateError) return ncclRemoteError;
return ncclInternalError;
}
TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", ncclSocketToString(&sock->addr, line));
sock->state = ncclSocketStateConnecting;
sock->finalizeCounter = 0;
do {
NCCLCHECK(socketProgressState(sock));
} while (sock->asyncFlag == 0 &&
(sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE) == 0) &&
(sock->state == ncclSocketStateConnecting ||
sock->state == ncclSocketStateConnectPolling ||
sock->state == ncclSocketStateConnected));
if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) return ncclInternalError;
switch (sock->state) {
case ncclSocketStateConnecting:
case ncclSocketStateConnectPolling:
case ncclSocketStateConnected:
case ncclSocketStateReady:
return ncclSuccess;
case ncclSocketStateError:
return ncclSystemError;
default:
WARN("ncclSocketConnect: wrong socket state %d", sock->state);
return ncclInternalError;
}
}
ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listenSock) {
ncclResult_t ret = ncclSuccess;
if (listenSock == NULL || sock == NULL) {
WARN("ncclSocketAccept: pass NULL socket");
ret = ncclInvalidArgument;
goto exit;
}
if (listenSock->state != ncclSocketStateReady) {
WARN("ncclSocketAccept: wrong socket state %d", listenSock->state);
if (listenSock->state == ncclSocketStateError)
ret = ncclSystemError;
else
ret = ncclInternalError;
goto exit;
}
if (sock->acceptFd == -1) {
memcpy(sock, listenSock, sizeof(struct ncclSocket));
sock->acceptFd = listenSock->fd;
sock->state = ncclSocketStateAccepting;
sock->finalizeCounter = 0;
}
do {
NCCLCHECKGOTO(socketProgressState(sock), ret, exit);
} while (sock->asyncFlag == 0 &&
(sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE) == 0) &&
(sock->state == ncclSocketStateAccepting ||
sock->state == ncclSocketStateAccepted));
if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) return ncclInternalError;
switch (sock->state) {
case ncclSocketStateAccepting:
case ncclSocketStateAccepted:
case ncclSocketStateReady:
ret = ncclSuccess;
break;
case ncclSocketStateError:
ret = ncclSystemError;
break;
default:
WARN("ncclSocketAccept: wrong socket state %d", sock->state);
ret = ncclInternalError;
break;
}
exit:
return ret;
}
ncclResult_t ncclSocketInit(struct ncclSocket* sock, const union ncclSocketAddress* addr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, int asyncFlag, int customRetry) {
ncclResult_t ret = ncclSuccess;
if (sock == NULL) goto exit;
sock->errorRetries = 0;
sock->abortFlag = abortFlag;
sock->asyncFlag = asyncFlag;
sock->state = ncclSocketStateInitialized;
sock->magic = magic;
sock->type = type;
sock->fd = -1;
sock->acceptFd = -1;
sock->customRetry = customRetry;
if (addr) {
/* IPv4/IPv6 support */
int family;
memcpy(&sock->addr, addr, sizeof(union ncclSocketAddress));
family = sock->addr.sa.sa_family;
if (family != AF_INET && family != AF_INET6) {
char line[SOCKET_NAME_MAXLEN+1];
WARN("ncclSocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
ncclSocketToString(&sock->addr, line), family, AF_INET, AF_INET6);
ret = ncclInternalError;
goto exit;
}
sock->salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
// in case of error, we close the fd before returning as it's unclear if the caller has to use ncclSocketClose for cleanup
NCCLCHECKGOTO(socketResetFd(sock), ret, fail);
} else {
memset(&sock->addr, 0, sizeof(union ncclSocketAddress));
}
exit:
return ret;
fail:
if (sock->fd != -1) {
close(sock->fd);
sock->fd = -1;
}
goto exit;
}
ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int* closed) {
if (sock == NULL) {
WARN("ncclSocketProgress: pass NULL socket");
return ncclInvalidArgument;
}
NCCLCHECK(socketProgress(op, sock, ptr, size, offset, closed));
return ncclSuccess;
}
ncclResult_t ncclSocketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) {
if (sock == NULL) {
WARN("ncclSocketWait: pass NULL socket");
return ncclInvalidArgument;
}
NCCLCHECK(socketWait(op, sock, ptr, size, offset));
return ncclSuccess;
}
ncclResult_t ncclSocketSend(struct ncclSocket* sock, void* ptr, int size) {
int offset = 0;
if (sock == NULL) {
WARN("ncclSocketSend: pass NULL socket");
return ncclInvalidArgument;
}
if (sock->state != ncclSocketStateReady) {
WARN("ncclSocketSend: socket state (%d) is not ready", sock->state);
return ncclInternalError;
}
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, ptr, size, &offset));
return ncclSuccess;
}
ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size) {
int offset = 0;
if (sock == NULL) {
WARN("ncclSocketRecv: pass NULL socket");
return ncclInvalidArgument;
}
if (sock->state != ncclSocketStateReady && sock->state != ncclSocketStateTerminating) {
WARN("ncclSocketRecv: socket state (%d) is not ready", sock->state);
return ncclInternalError;
}
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, ptr, size, &offset));
return ncclSuccess;
}
ncclResult_t ncclSocketSendRecv(struct ncclSocket* sendSock, void* sendPtr, int sendSize, struct ncclSocket* recvSock, void* recvPtr, int recvSize) {
int sendOffset = 0, recvOffset = 0;
if (sendSock == NULL || recvSock == NULL) {
WARN("ncclSocketSendRecv: invalid socket %p/%p", sendSock, recvSock);
return ncclInternalError;
}
if (sendSock->state != ncclSocketStateReady ||
(recvSock->state != ncclSocketStateReady && recvSock->state != ncclSocketStateTerminating)) {
WARN("ncclSocketSendRecv: socket state (%d/%d) is not ready", sendSock->state, recvSock->state);
return ncclInternalError;
}
while (sendOffset < sendSize || recvOffset < recvSize) {
if (sendOffset < sendSize) NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sendSock, sendPtr, sendSize, &sendOffset));
if (recvOffset < recvSize) NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, recvSock, recvPtr, recvSize, &recvOffset));
}
return ncclSuccess;
}
// Receive or detect connection closed
ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed, bool blocking) {
int offset = 0;
if (sock == NULL) {
WARN("ncclSocketTryRecv: pass NULL socket");
return ncclInvalidArgument;
}
*closed = 0;
// Block until connection closes or nbytes received
if (blocking) {
while (offset < size) {
NCCLCHECK(socketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed));
if (*closed) return ncclSuccess;
}
} else {
NCCLCHECK(socketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed));
if (*closed) return ncclSuccess;
// If any bytes were received, block waiting for the rest
if (offset > 0) {
while (offset < size) {
NCCLCHECK(socketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed));
if (*closed) return ncclSuccess;
}
// No bytes were received, return ncclInProgress
} else {
return ncclInProgress;
}
}
return ncclSuccess;
}
// Make it possible to close just one part of a socket.
ncclResult_t ncclSocketShutdown(struct ncclSocket* sock, int how) {
if (sock != NULL) {
if (sock->fd >= 0) {
shutdown(sock->fd, how);
}
sock->state = ncclSocketStateTerminating;
}
return ncclSuccess;
}
ncclResult_t ncclSocketClose(struct ncclSocket* sock, bool wait) {
if (sock != NULL) {
if (sock->state > ncclSocketStateNone && sock->state < ncclSocketStateNum && sock->fd >= 0) {
if (wait) {
char data;
int closed = 0;
do {
int offset = 0;
if (ncclSocketProgress(NCCL_SOCKET_RECV, sock, &data, sizeof(char), &offset, &closed) != ncclSuccess) break;
} while (closed == 0);
}
/* shutdown() is needed to send FIN packet to proxy thread; shutdown() is not affected
* by refcount of fd, but close() is. close() won't close a fd and send FIN packet if
* the fd is duplicated (e.g. fork()). So shutdown() guarantees the correct and graceful
* connection close here. */
shutdown(sock->fd, SHUT_RDWR);
close(sock->fd);
}
sock->state = ncclSocketStateClosed;
sock->fd = -1;
}
return ncclSuccess;
}
ncclResult_t ncclSocketGetFd(struct ncclSocket* sock, int* fd) {
if (sock == NULL) {
WARN("ncclSocketGetFd: pass NULL socket");
return ncclInvalidArgument;
}
if (fd) *fd = sock->fd;
return ncclSuccess;
}
ncclResult_t ncclSocketSetFd(int fd, struct ncclSocket* sock) {
if (sock == NULL) {
WARN("ncclSocketGetFd: pass NULL socket");
return ncclInvalidArgument;
}
sock->fd = fd;
return ncclSuccess;
}