Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions gloo/test/pair_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "gloo/test/base_test.h"

#include <string>

#include "gloo/common/logging.h"

#if GLOO_HAVE_TRANSPORT_TCP
Expand All @@ -23,6 +25,19 @@ using Param = Transport;
// Test fixture for transport pair tests. It is value-parameterized on Param
// (an alias for Transport), so every TEST_P below runs once per transport
// that the test suite is instantiated with.
class PairTest : public BaseTest,
public ::testing::WithParamInterface<Param> {};

// Verifies that a pair reports the rank of the process on its remote end,
// both through getPeerRank() and in the text returned by peerDescription().
TEST_P(PairTest, ReportsPeerRank) {
  const auto transport = GetParam();

  // Two processes are spawned, so each rank's only peer is the other rank.
  spawn(transport, 2, [&](std::shared_ptr<Context> context) {
    const int remoteRank = (context->rank + 1) % 2;
    auto& remotePair = context->getPair(remoteRank);

    EXPECT_EQ(remoteRank, remotePair->getPeerRank());

    // The description has to mention the peer as "rank <N>".
    const std::string expectedTag = "rank " + std::to_string(remoteRank);
    const std::string description = remotePair->peerDescription();
    EXPECT_NE(description.find(expectedTag), std::string::npos);
  });
}

// Regression test for the size_t overflow in tcp::Pair::prepareRead.
//
// A SEND_BUFFER preamble carries `roffset` and `length` as size_t values read
Expand Down
14 changes: 7 additions & 7 deletions gloo/transport/ibverbs/pair.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ void Pair::recvMemoryRegion(
(std::chrono::steady_clock::now() - start) >= timeout_) {
lock.unlock();
signalIoFailure(GLOO_ERROR_MSG(
"Timeout waiting for memory region from ", peer_.str()));
"Timeout waiting for memory region from ", peerDescription()));
GLOO_ENFORCE(false, "Unexpected code path");
}
it = peerMemoryRegions_.find(slot);
Expand Down Expand Up @@ -447,7 +447,7 @@ void Pair::put(
GLOO_DEBUG(
self_.str(),
"->",
peer_.str(),
peerDescription(),
": ",
"put UnboundBuffer async slot=",
wr.wr_id,
Expand Down Expand Up @@ -509,7 +509,7 @@ void Pair::get(
GLOO_DEBUG(
self_.str(),
"->",
peer_.str(),
peerDescription(),
": ",
"get UnboundBuffer async slot=",
wr.wr_id,
Expand Down Expand Up @@ -587,7 +587,7 @@ void Pair::pollCompletions() {
GLOO_ERROR(
self_.str(),
"->",
peer_.str(),
peerDescription(),
": ",
"Exception in handleCompletion: ",
ex.what());
Expand All @@ -606,7 +606,7 @@ void Pair::handleCompletion(struct ibv_wc* wc) {
GLOO_DEBUG(
self_.str(),
"->",
peer_.str(),
peerDescription(),
": handleCompletion id=",
wc->wr_id,
" opcode=",
Expand Down Expand Up @@ -659,7 +659,7 @@ void Pair::handleCompletion(struct ibv_wc* wc) {
GLOO_DEBUG(
self_.str(),
"->",
peer_.str(),
peerDescription(),
": handleCompletion id=",
wc->wr_id,
" opcode=IBV_WC_RECV slot=",
Expand Down Expand Up @@ -708,7 +708,7 @@ void Pair::handleCompletion(struct ibv_wc* wc) {
GLOO_DEBUG(
self_.str(),
"->",
peer_.str(),
peerDescription(),
": handleCompletion id=",
wc->wr_id,
" opcode=IBV_WC_SEND");
Expand Down
6 changes: 5 additions & 1 deletion gloo/transport/ibverbs/pair.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ class Pair : public ::gloo::transport::Pair {

virtual bool isConnected() override;

// Returns the rank of the peer on the other end of this pair, as stored in
// rank_. Overrides the pure virtual accessor declared on transport::Pair.
int getPeerRank() const override {
return rank_;
}

// Send from the specified buffer to remote side of pair.
virtual void send(
transport::UnboundBuffer* tbuf,
Expand Down Expand Up @@ -199,7 +203,7 @@ class Pair : public ::gloo::transport::Pair {
return timeout_;
}

const Address& peer() const {
const Address& peer() const override {
return peer_;
}

Expand Down
6 changes: 6 additions & 0 deletions gloo/transport/pair.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@

#include "gloo/transport/pair.h"

#include "gloo/common/string.h"

namespace gloo {
namespace transport {

// The destructor is declared pure virtual, but it still needs a definition
// because derived-class destructors invoke it.
Pair::~Pair() = default;

std::string Pair::peerDescription() const {
return ::gloo::MakeString("rank ", getPeerRank(), " (", peer().str(), ")");
}

} // namespace transport
} // namespace gloo
8 changes: 8 additions & 0 deletions gloo/transport/pair.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <memory>
#include <string>

#include "gloo/common/logging.h"
#include "gloo/transport/address.h"
Expand Down Expand Up @@ -42,6 +43,11 @@ class Pair {

virtual bool isConnected() = 0;

// Returns the global rank of the peer process this pair connects to.
// Pure virtual: each transport's Pair implementation supplies the value.
virtual int getPeerRank() const = 0;
// Returns a human-readable description of the peer process, combining its
// rank and its address in the form "rank <rank> (<address>)". Intended for
// log and error messages.
std::string peerDescription() const;

// Send from the specified buffer to remote side of pair.
virtual void send(
UnboundBuffer* buf,
Expand Down Expand Up @@ -92,6 +98,8 @@ class Pair {
}

protected:
// Returns the address of the remote end of this pair. Implemented by each
// transport; peerDescription() uses it to render the peer's address.
virtual const Address& peer() const = 0;

// Rank of the process on the local machine
// e.g. Suppose we have 2 machines with 8 GPUs per machine.
// This means we have a total of 16 processes with
Expand Down
18 changes: 10 additions & 8 deletions gloo/transport/tcp/pair.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void Pair::setSync(bool sync, bool busyPoll) {
waitUntilConnected(lock, false);
if (state_ == CLOSED) {
signalAndThrowException(
GLOO_ERROR_MSG("Socket unexpectedly closed ", peer_.str()));
GLOO_ERROR_MSG("Socket unexpectedly closed ", peerDescription()));
}

if (!sync_) {
Expand Down Expand Up @@ -311,7 +311,7 @@ bool Pair::write(Op& op) {
if (errno == EAGAIN) {
if (sync_) {
// Sync mode: blocking call returning with EAGAIN indicates timeout.
signalException(GLOO_ERROR_MSG("Write timeout ", peer_.str()));
signalException(GLOO_ERROR_MSG("Write timeout ", peerDescription()));
} else {
// Async mode: can't write more than this.
}
Expand All @@ -336,7 +336,7 @@ bool Pair::write(Op& op) {

// Unexpected error
signalException(
GLOO_ERROR_MSG("writev ", peer_.str(), ": ", strerror(errno)));
GLOO_ERROR_MSG("writev ", peerDescription(), ": ", strerror(errno)));
return false;
}

Expand Down Expand Up @@ -528,7 +528,8 @@ bool Pair::read() {
} else {
// Either timeout on poll or blocking call returning with EAGAIN
// indicates timeout
signalException(GLOO_ERROR_MSG("Read timeout ", peer_.str()));
signalException(
GLOO_ERROR_MSG("Read timeout ", peerDescription()));
}
} else {
// Async mode: can't read more than this.
Expand All @@ -544,7 +545,7 @@ bool Pair::read() {
// Unexpected error
signalException(GLOO_ERROR_MSG(
"Read error ",
peer_.str(),
peerDescription(),
": ",
strerror(errno),
". ",
Expand All @@ -560,7 +561,7 @@ bool Pair::read() {
if (rv == 0) {
signalException(GLOO_ERROR_MSG(
"Connection closed by peer ",
peer_.str(),
peerDescription(),
". ",
"This is typically caused by a remote worker crashing. ",
"Check the logs of the remote worker before reporting an error. ",
Expand Down Expand Up @@ -812,12 +813,13 @@ void Pair::verifyConnected(std::unique_lock<std::mutex>& lock) {
"Pair is not connected (",
self_.str(),
" <--> ",
peer_.str(),
peerDescription(),
")");
// Check if the socket has been closed. We were unable to tell if this was an
// error or normal tear down, but now throw since we are trying to do IO.
if (state_ == CLOSED) {
signalAndThrowException(GLOO_ERROR_MSG("Socket closed ", peer_.str()));
signalAndThrowException(
GLOO_ERROR_MSG("Socket closed ", peerDescription()));
}
}

Expand Down
6 changes: 5 additions & 1 deletion gloo/transport/tcp/pair.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ class Pair : public ::gloo::transport::Pair, public Handler {

bool isConnected() override;

// Returns the rank of the peer on the other end of this pair, as stored in
// rank_. Overrides the pure virtual accessor declared on transport::Pair.
int getPeerRank() const override {
return rank_;
}

protected:
// Refer to parent context using raw pointer. This could be a
// weak_ptr, seeing as the context class is a shared_ptr, but:
Expand Down Expand Up @@ -206,7 +210,7 @@ class Pair : public ::gloo::transport::Pair, public Handler {
void send(Op& op);
void recv();

const Address& peer() const {
const Address& peer() const override {
return peer_;
}

Expand Down
8 changes: 4 additions & 4 deletions gloo/transport/tcp/tls/pair.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ bool Pair::write(Op& op) {
// Unexpected error
signalException(GLOO_ERROR_MSG(
"SSL_write ",
peer_.str(),
peerDescription(),
" failed: ",
"ssl error: ",
err,
Expand Down Expand Up @@ -182,7 +182,7 @@ bool Pair::read() {
// Unexpected error
signalException(GLOO_ERROR_MSG(
"SSL_read ",
peer_.str(),
peerDescription(),
" failed: ",
"ssl error: ",
err,
Expand All @@ -198,7 +198,7 @@ bool Pair::read() {
// Transition to CLOSED on EOF
if (rv == 0) {
signalException(
GLOO_ERROR_MSG("Connection closed by peer ", peer_.str()));
GLOO_ERROR_MSG("Connection closed by peer ", peerDescription()));
return false;
}

Expand Down Expand Up @@ -305,7 +305,7 @@ void Pair::verifyConnected(std::unique_lock<std::mutex>& lock) {
"Pair is not SSL connected (",
self_.str(),
" <--> ",
peer_.str(),
peerDescription(),
")");
}

Expand Down
3 changes: 2 additions & 1 deletion gloo/transport/uv/pair.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ void Pair::connect(const std::vector<char>& bytes) {

std::unique_lock<std::mutex> lock(mutex_);
GLOO_ENFORCE_EQ(state_, INITIALIZED);
peer_ = peer;
state_ = CONNECTING;

// Both processes call the `Pair::connect` function with the address
Expand Down Expand Up @@ -101,7 +102,7 @@ void Pair::connect(const std::vector<char>& bytes) {
if (errno_) {
throw ::gloo::IoException(GLOO_ERROR_MSG(
"Error connecting to ",
peer.str(),
peerDescription(),
": ",
libuv::ErrorEvent(errno_).what()));
}
Expand Down
11 changes: 11 additions & 0 deletions gloo/transport/uv/pair.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@ class Pair : public ::gloo::transport::Pair {

bool isConnected() override;

// Returns the rank of the peer on the other end of this pair, as stored in
// rank_. Overrides the pure virtual accessor declared on transport::Pair.
int getPeerRank() const override {
return rank_;
}

protected:
// Returns the stored peer address (peer_). The reference remains owned by
// this pair; callers must not hold it beyond the pair's lifetime.
const Address& peer() const override {
return peer_;
}

private:
std::mutex mutex_;
std::condition_variable cv_;
Expand Down Expand Up @@ -201,6 +210,8 @@ class Pair : public ::gloo::transport::Pair {
// underlying connection is closed before we destruct.
State state_;

// Address of the remote end of this pair. Assigned in connect() and exposed
// through peer().
Address peer_;

// Error state of the handle, if set.
int errno_;

Expand Down
Loading