Fix shared memory collision in multi-communicator case.

The current SHM object name uses only the pidHash and the ranks as
identification, so names collide when a program runs with multiple
communicators. Here we fold commId information into the pidHash, making
the pidHashes of different communicators kept in the same process
distinct from each other.
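
To illustrate the failure mode, a minimal standalone sketch; the
segment-name format and all values below are made up for the example
and are not the actual NCCL SHM paths:

#include <cinttypes>
#include <cstdint>
#include <cstdio>

// Same DJB2-style byte hash as the getnHash() added in this commit.
static uint64_t hashBytes(const char* s, int n) {
  uint64_t result = 9527;
  for (int c = 0; c < n; c++) result = ((result << 5) + result) + s[c];
  return result;
}

int main() {
  uint64_t pidHash = 0x1234abcd;        // stand-in for getPidHash()
  char idA[8] = {1,2,3,4,5,6,7,8};      // stand-ins for two ncclUniqueIds
  char idB[8] = {8,7,6,5,4,3,2,1};
  char name[64];
  // Old scheme: the name depends only on pidHash and ranks, so two
  // communicators in the same process produce the same SHM object name.
  snprintf(name, sizeof(name), "nccl-shm-%" PRIx64 "-%d-%d", pidHash, 0, 1);
  printf("old: %s (both communicators)\n", name);
  // New scheme: fold a commId-derived hash into pidHash; names differ.
  snprintf(name, sizeof(name), "nccl-shm-%" PRIx64 "-%d-%d",
           pidHash + hashBytes(idA, 8), 0, 1);
  printf("new, comm A: %s\n", name);
  snprintf(name, sizeof(name), "nccl-shm-%" PRIx64 "-%d-%d",
           pidHash + hashBytes(idB, 8), 0, 1);
  printf("new, comm B: %s\n", name);
  return 0;
}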
Cao Zongyan 2019-03-13 17:13:39 +08:00
parent 14e0cf644b
commit 161763aab2
3 changed files with 17 additions and 6 deletions


@@ -11,6 +11,7 @@
 #include <stdint.h>

 ncclResult_t getHostName(char* hostname, int maxlen);
+uint64_t getnHash(const char* string, int n);
 uint64_t getHostHash();
 uint64_t getPidHash();


@@ -302,12 +302,12 @@ static void showVersion() {
   }
 }

-static ncclResult_t fillInfo(struct ncclPeerInfo* info, int rank) {
+static ncclResult_t fillInfo(struct ncclPeerInfo* info, int rank, uint64_t commHash) {
   info->rank = rank;
   CUDACHECK(cudaGetDevice(&info->cudaDev));
   NCCLCHECK(getNvmlDevice(info->cudaDev, &info->nvmlDev))
-  info->hostHash=getHostHash();
-  info->pidHash=getPidHash();
+  info->hostHash=getHostHash()+commHash;
+  info->pidHash=getPidHash()+commHash;

   // Get PCI Bus Id. We need to get the bus ID through CUDA first, since the
   // cudaDev is a CUDA runtime dev number which could be different from the
@@ -679,7 +679,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
   int rank = comm->rank;
   int nranks = comm->nRanks;
-  TRACE(NCCL_INIT, "rank %d nranks %d - BEGIN", rank, nranks);
+  uint64_t commHash = getnHash(commId->internal, NCCL_UNIQUE_ID_BYTES);
+  TRACE(NCCL_INIT, "comm %p, commHash %lu, rank %d nranks %d - BEGIN", comm, commHash, rank, nranks);
   NCCLCHECK(bootstrapInit(commId, rank, nranks, &comm->bootstrap));

   // AllGather1 - begin
@@ -690,7 +691,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
   NCCLCHECK(ncclCalloc(&allGather1Data, nranks));
   allGather1Data[rank].comm = comm;
-  NCCLCHECK(fillInfo(&allGather1Data[rank].peerInfo, rank));
+  NCCLCHECK(fillInfo(&allGather1Data[rank].peerInfo, rank, commHash));
   NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGather1Data, sizeof(*allGather1Data)));
   NCCLCHECK(ncclCalloc(&comm->peerInfo, nranks));
@@ -960,7 +961,7 @@ static ncclResult_t initTransportsAll(struct ncclComm** comms, const int* devs,
   NCCLCHECK(ncclCalloc(&allInfo, nranks));
   for (int rank=0; rank<nranks; rank++) {
     CUDACHECK(cudaSetDevice(devs[rank]));
-    NCCLCHECK(fillInfo(allInfo+rank, rank));
+    NCCLCHECK(fillInfo(allInfo+rank, rank, 0));
   }
   int* connectTransport;
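
A note on why plain addition is enough: every rank of a given
communicator folds in the same commHash, so equality comparisons on
hostHash/pidHash within one communicator behave exactly as before,
while two communicators in one process no longer agree. The
single-process initTransportsAll path has no ncclUniqueId to hash,
so it passes 0 and keeps its hashes unchanged. A toy standalone
illustration with made-up values:

#include <cassert>
#include <cstdint>

int main() {
  uint64_t hostHash = 0xAAAA;                      // stand-in: one host, two ranks
  uint64_t commHashA = 0x1111, commHashB = 0x2222; // two communicators' hashes

  // Two ranks of communicator A on the same host still compare equal:
  uint64_t rank0 = hostHash + commHashA;
  uint64_t rank1 = hostHash + commHashA;
  assert(rank0 == rank1);

  // The same host seen from communicator B maps to a different value, so
  // identifiers derived from it (e.g. SHM names) cannot collide with A's.
  assert(hostHash + commHashA != hostHash + commHashB);
  return 0;
}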


@@ -96,6 +96,15 @@ uint64_t getHash(const char* string) {
   return result;
 }

+uint64_t getnHash(const char* string, int n) {
+  // Based on DJB2, result = result * 33 + char
+  uint64_t result = 9527;
+  for (int c = 0; c < n; c++) {
+    result = ((result << 5) + result) + string[c];
+  }
+  return result;
+}
+
 /* Generate a hash of the unique identifying string for this host
  * that will be unique for both bare-metal and container instances
  * Equivalent of a hash of;
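
getnHash() is a length-delimited variant of getHash() above: it takes an
explicit byte count instead of stopping at a NUL terminator, which matters
here because the ncclUniqueId bytes are opaque binary data that may contain
embedded zeros. A quick standalone sanity check (the 128-byte id size is
only an assumption for this sketch):

#include <cstdint>
#include <cstdio>

// Copy of the getnHash() added above, for a standalone test.
uint64_t getnHash(const char* string, int n) {
  // Based on DJB2, result = result * 33 + char
  uint64_t result = 9527;
  for (int c = 0; c < n; c++) {
    result = ((result << 5) + result) + string[c];
  }
  return result;
}

int main() {
  // Stand-in for a unique id buffer (assumed 128 bytes here). It is
  // mostly zeros, so a NUL-terminated string hash would stop at byte 0
  // and hash every id to the same value; getnHash sees all the bytes.
  char id[128] = {0};
  id[100] = 42;
  uint64_t a = getnHash(id, sizeof(id));
  id[100] = 43;
  uint64_t b = getnHash(id, sizeof(id));
  printf("hashes differ: %s\n", a != b ? "yes" : "no");  // prints "yes"
  return 0;
}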