diff --git a/gloo/test/pair_test.cc b/gloo/test/pair_test.cc index 05eec4ba2..c8ca3ae46 100644 --- a/gloo/test/pair_test.cc +++ b/gloo/test/pair_test.cc @@ -8,6 +8,8 @@ #include "gloo/test/base_test.h" +#include + #include "gloo/common/logging.h" #if GLOO_HAVE_TRANSPORT_TCP @@ -23,6 +25,19 @@ using Param = Transport; class PairTest : public BaseTest, public ::testing::WithParamInterface {}; +TEST_P(PairTest, ReportsPeerRank) { + const auto transport = GetParam(); + + spawn(transport, 2, [&](std::shared_ptr context) { + const int peer = (context->rank + 1) % 2; + auto& pair = context->getPair(peer); + EXPECT_EQ(peer, pair->getPeerRank()); + EXPECT_NE( + pair->peerDescription().find("rank " + std::to_string(peer)), + std::string::npos); + }); +} + // Regression test for the size_t overflow in tcp::Pair::prepareRead. // // A SEND_BUFFER preamble carries `roffset` and `length` as size_t values read diff --git a/gloo/transport/ibverbs/pair.cc b/gloo/transport/ibverbs/pair.cc index d3a0410bb..3fb98f928 100644 --- a/gloo/transport/ibverbs/pair.cc +++ b/gloo/transport/ibverbs/pair.cc @@ -262,7 +262,7 @@ void Pair::recvMemoryRegion( (std::chrono::steady_clock::now() - start) >= timeout_) { lock.unlock(); signalIoFailure(GLOO_ERROR_MSG( - "Timeout waiting for memory region from ", peer_.str())); + "Timeout waiting for memory region from ", peerDescription())); GLOO_ENFORCE(false, "Unexpected code path"); } it = peerMemoryRegions_.find(slot); @@ -447,7 +447,7 @@ void Pair::put( GLOO_DEBUG( self_.str(), "->", - peer_.str(), + peerDescription(), ": ", "put UnboundBuffer async slot=", wr.wr_id, @@ -509,7 +509,7 @@ void Pair::get( GLOO_DEBUG( self_.str(), "->", - peer_.str(), + peerDescription(), ": ", "get UnboundBuffer async slot=", wr.wr_id, @@ -587,7 +587,7 @@ void Pair::pollCompletions() { GLOO_ERROR( self_.str(), "->", - peer_.str(), + peerDescription(), ": ", "Exception in handleCompletion: ", ex.what()); @@ -606,7 +606,7 @@ void Pair::handleCompletion(struct ibv_wc* wc) { GLOO_DEBUG( self_.str(), "->", - peer_.str(), + peerDescription(), ": handleCompletion id=", wc->wr_id, " opcode=", @@ -659,7 +659,7 @@ void Pair::handleCompletion(struct ibv_wc* wc) { GLOO_DEBUG( self_.str(), "->", - peer_.str(), + peerDescription(), ": handleCompletion id=", wc->wr_id, " opcode=IBV_WC_RECV slot=", @@ -708,7 +708,7 @@ void Pair::handleCompletion(struct ibv_wc* wc) { GLOO_DEBUG( self_.str(), "->", - peer_.str(), + peerDescription(), ": handleCompletion id=", wc->wr_id, " opcode=IBV_WC_SEND"); diff --git a/gloo/transport/ibverbs/pair.h b/gloo/transport/ibverbs/pair.h index 8566fd5f6..772b23259 100644 --- a/gloo/transport/ibverbs/pair.h +++ b/gloo/transport/ibverbs/pair.h @@ -87,6 +87,10 @@ class Pair : public ::gloo::transport::Pair { virtual bool isConnected() override; + int getPeerRank() const override { + return rank_; + } + // Send from the specified buffer to remote side of pair. virtual void send( transport::UnboundBuffer* tbuf, @@ -199,7 +203,7 @@ class Pair : public ::gloo::transport::Pair { return timeout_; } - const Address& peer() const { + const Address& peer() const override { return peer_; } diff --git a/gloo/transport/pair.cc b/gloo/transport/pair.cc index a1af2d1fa..291a05ec0 100644 --- a/gloo/transport/pair.cc +++ b/gloo/transport/pair.cc @@ -8,11 +8,17 @@ #include "gloo/transport/pair.h" +#include "gloo/common/string.h" + namespace gloo { namespace transport { // Have to provide implementation for pure virtual destructor. Pair::~Pair() {} +std::string Pair::peerDescription() const { + return ::gloo::MakeString("rank ", getPeerRank(), " (", peer().str(), ")"); +} + } // namespace transport } // namespace gloo diff --git a/gloo/transport/pair.h b/gloo/transport/pair.h index 42ef7cd69..b5f25c050 100644 --- a/gloo/transport/pair.h +++ b/gloo/transport/pair.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include "gloo/common/logging.h" #include "gloo/transport/address.h" @@ -42,6 +43,11 @@ class Pair { virtual bool isConnected() = 0; + // Returns the global rank and address of the peer process this pair connects + // to + virtual int getPeerRank() const = 0; + std::string peerDescription() const; + // Send from the specified buffer to remote side of pair. virtual void send( UnboundBuffer* buf, @@ -92,6 +98,8 @@ class Pair { } protected: + virtual const Address& peer() const = 0; + // Rank of the process on the local machine // e.g. Suppose we have 2 machines with 8 GPUs per machine. // This means we have a total of 16 processes with diff --git a/gloo/transport/tcp/pair.cc b/gloo/transport/tcp/pair.cc index b8772384b..7f1be0d1e 100644 --- a/gloo/transport/tcp/pair.cc +++ b/gloo/transport/tcp/pair.cc @@ -198,7 +198,7 @@ void Pair::setSync(bool sync, bool busyPoll) { waitUntilConnected(lock, false); if (state_ == CLOSED) { signalAndThrowException( - GLOO_ERROR_MSG("Socket unexpectedly closed ", peer_.str())); + GLOO_ERROR_MSG("Socket unexpectedly closed ", peerDescription())); } if (!sync_) { @@ -311,7 +311,7 @@ bool Pair::write(Op& op) { if (errno == EAGAIN) { if (sync_) { // Sync mode: blocking call returning with EAGAIN indicates timeout. - signalException(GLOO_ERROR_MSG("Write timeout ", peer_.str())); + signalException(GLOO_ERROR_MSG("Write timeout ", peerDescription())); } else { // Async mode: can't write more than this. } @@ -336,7 +336,7 @@ bool Pair::write(Op& op) { // Unexpected error signalException( - GLOO_ERROR_MSG("writev ", peer_.str(), ": ", strerror(errno))); + GLOO_ERROR_MSG("writev ", peerDescription(), ": ", strerror(errno))); return false; } @@ -528,7 +528,8 @@ bool Pair::read() { } else { // Either timeout on poll or blocking call returning with EAGAIN // indicates timeout - signalException(GLOO_ERROR_MSG("Read timeout ", peer_.str())); + signalException( + GLOO_ERROR_MSG("Read timeout ", peerDescription())); } } else { // Async mode: can't read more than this. @@ -544,7 +545,7 @@ bool Pair::read() { // Unexpected error signalException(GLOO_ERROR_MSG( "Read error ", - peer_.str(), + peerDescription(), ": ", strerror(errno), ". ", @@ -560,7 +561,7 @@ bool Pair::read() { if (rv == 0) { signalException(GLOO_ERROR_MSG( "Connection closed by peer ", - peer_.str(), + peerDescription(), ". ", "This is typically caused by a remote worker crashing. ", "Check the logs of the remote worker before reporting an error. ", @@ -812,12 +813,13 @@ void Pair::verifyConnected(std::unique_lock& lock) { "Pair is not connected (", self_.str(), " <--> ", - peer_.str(), + peerDescription(), ")"); // Check if the socket has been closed. We were unable to tell if this was an // error or normal tear down, but now throw since we are trying to do IO. if (state_ == CLOSED) { - signalAndThrowException(GLOO_ERROR_MSG("Socket closed ", peer_.str())); + signalAndThrowException( + GLOO_ERROR_MSG("Socket closed ", peerDescription())); } } diff --git a/gloo/transport/tcp/pair.h b/gloo/transport/tcp/pair.h index fb08adb56..8f247869b 100644 --- a/gloo/transport/tcp/pair.h +++ b/gloo/transport/tcp/pair.h @@ -150,6 +150,10 @@ class Pair : public ::gloo::transport::Pair, public Handler { bool isConnected() override; + int getPeerRank() const override { + return rank_; + } + protected: // Refer to parent context using raw pointer. This could be a // weak_ptr, seeing as the context class is a shared_ptr, but: @@ -206,7 +210,7 @@ class Pair : public ::gloo::transport::Pair, public Handler { void send(Op& op); void recv(); - const Address& peer() const { + const Address& peer() const override { return peer_; } diff --git a/gloo/transport/tcp/tls/pair.cc b/gloo/transport/tcp/tls/pair.cc index 3e0fcfde3..5261e8aaf 100644 --- a/gloo/transport/tcp/tls/pair.cc +++ b/gloo/transport/tcp/tls/pair.cc @@ -109,7 +109,7 @@ bool Pair::write(Op& op) { // Unexpected error signalException(GLOO_ERROR_MSG( "SSL_write ", - peer_.str(), + peerDescription(), " failed: ", "ssl error: ", err, @@ -182,7 +182,7 @@ bool Pair::read() { // Unexpected error signalException(GLOO_ERROR_MSG( "SSL_read ", - peer_.str(), + peerDescription(), " failed: ", "ssl error: ", err, @@ -198,7 +198,7 @@ bool Pair::read() { // Transition to CLOSED on EOF if (rv == 0) { signalException( - GLOO_ERROR_MSG("Connection closed by peer ", peer_.str())); + GLOO_ERROR_MSG("Connection closed by peer ", peerDescription())); return false; } @@ -305,7 +305,7 @@ void Pair::verifyConnected(std::unique_lock& lock) { "Pair is not SSL connected (", self_.str(), " <--> ", - peer_.str(), + peerDescription(), ")"); } diff --git a/gloo/transport/uv/pair.cc b/gloo/transport/uv/pair.cc index 4db8bf32c..4ba8556b4 100644 --- a/gloo/transport/uv/pair.cc +++ b/gloo/transport/uv/pair.cc @@ -62,6 +62,7 @@ void Pair::connect(const std::vector& bytes) { std::unique_lock lock(mutex_); GLOO_ENFORCE_EQ(state_, INITIALIZED); + peer_ = peer; state_ = CONNECTING; // Both processes call the `Pair::connect` function with the address @@ -101,7 +102,7 @@ void Pair::connect(const std::vector& bytes) { if (errno_) { throw ::gloo::IoException(GLOO_ERROR_MSG( "Error connecting to ", - peer.str(), + peerDescription(), ": ", libuv::ErrorEvent(errno_).what())); } diff --git a/gloo/transport/uv/pair.h b/gloo/transport/uv/pair.h index 510c9619b..35c68ade1 100644 --- a/gloo/transport/uv/pair.h +++ b/gloo/transport/uv/pair.h @@ -152,6 +152,15 @@ class Pair : public ::gloo::transport::Pair { bool isConnected() override; + int getPeerRank() const override { + return rank_; + } + + protected: + const Address& peer() const override { + return peer_; + } + private: std::mutex mutex_; std::condition_variable cv_; @@ -201,6 +210,8 @@ class Pair : public ::gloo::transport::Pair { // underlying connection is closed before we destruct. State state_; + Address peer_; + // Error state of the handle, if set. int errno_;