Skip to content

Commit

Permalink
Flushing Proxy Channels at CPU side upon reaching the Inflight Request Limit (#415)
Browse files Browse the repository at this point in the history

Co-authored-by: Changho Hwang <[email protected]>
  • Loading branch information
caiomcbr and chhwang authored Jan 8, 2025
1 parent 1989d4b commit 80abce5
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 13 deletions.
42 changes: 34 additions & 8 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,11 @@ class Endpoint {
/// @return The transport used.
Transport transport();

/// Get the maximum write queue size.
///
/// @return The maximum number of write requests that can be queued.
int maxWriteQueueSize();

/// Serialize the Endpoint object to a vector of characters.
///
/// @return A vector of characters representing the serialized Endpoint object.
Expand Down Expand Up @@ -416,6 +421,10 @@ class Endpoint {
/// Represents a connection between two processes.
class Connection {
public:
/// Constructor.
/// @param maxWriteQueueSize The maximum number of write requests that can be queued.
Connection(int maxWriteQueueSize) : maxWriteQueueSize(maxWriteQueueSize){};

virtual ~Connection() = default;

/// Write data from a source @ref RegisteredMemory to a destination @ref RegisteredMemory.
Expand Down Expand Up @@ -454,10 +463,16 @@ class Connection {
/// @return name of @ref transport() -> @ref remoteTransport()
std::string getTransportName();

/// Get the maximum write queue size.
///
/// @return The maximum number of write requests that can be queued.
int getMaxWriteQueueSize();

protected:
// Internal methods for getting implementation pointers.
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
int maxWriteQueueSize;
};

/// Used to configure an endpoint.
Expand All @@ -468,18 +483,29 @@ struct EndpointConfig {
static const int DefaultMaxWrPerSend = 64;

Transport transport;
int ibMaxCqSize = DefaultMaxCqSize;
int ibMaxCqPollNum = DefaultMaxCqPollNum;
int ibMaxSendWr = DefaultMaxSendWr;
int ibMaxWrPerSend = DefaultMaxWrPerSend;

/// Default constructor. Sets transport to Transport::Unknown.
EndpointConfig() : transport(Transport::Unknown) {}
int ibMaxCqSize;
int ibMaxCqPollNum;
int ibMaxSendWr;
int ibMaxWrPerSend;
int maxWriteQueueSize;

/// Constructor. Every parameter is optional; any parameter left out takes its default value.
///
/// @param transport The transport to use.
EndpointConfig(Transport transport) : transport(transport) {}
/// @param ibMaxCqSize The maximum completion queue size.
/// @param ibMaxCqPollNum The maximum completion queue poll number.
/// @param ibMaxSendWr The maximum send work requests.
/// @param ibMaxWrPerSend The maximum work requests per send.
/// @param maxWriteQueueSize The maximum write queue size.
EndpointConfig(Transport transport = Transport::Unknown, int ibMaxCqSize = DefaultMaxCqSize,
int ibMaxCqPollNum = DefaultMaxCqPollNum, int ibMaxSendWr = DefaultMaxSendWr,
int ibMaxWrPerSend = DefaultMaxWrPerSend, int maxWriteQueueSize = -1)
: transport(transport),
ibMaxCqSize(ibMaxCqSize),
ibMaxCqPollNum(ibMaxCqPollNum),
ibMaxSendWr(ibMaxSendWr),
ibMaxWrPerSend(ibMaxWrPerSend),
maxWriteQueueSize(maxWriteQueueSize) {}
};

/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
Expand Down
1 change: 1 addition & 0 deletions include/mscclpp/proxy_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class ProxyService : public BaseProxyService {
std::vector<RegisteredMemory> memories_;
std::shared_ptr<Proxy> proxy_;
int deviceNumaNode;
std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests;

void bindThread();

Expand Down
13 changes: 10 additions & 3 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ std::string Connection::getTransportName() {
TransportNames[static_cast<int>(this->remoteTransport())];
}

int Connection::getMaxWriteQueueSize() { return maxWriteQueueSize; }

// CudaIpcConnection

CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream)
: stream_(stream) {
: Connection(localEndpoint.maxWriteQueueSize()), stream_(stream) {
if (localEndpoint.transport() != Transport::CudaIpc) {
throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage);
}
Expand Down Expand Up @@ -119,7 +121,9 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
// IBConnection

IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context)
: transport_(localEndpoint.transport()),
: Connection(localEndpoint.maxWriteQueueSize() != -1 ? localEndpoint.maxWriteQueueSize()
: EndpointConfig::DefaultMaxCqSize),
transport_(localEndpoint.transport()),
remoteTransport_(remoteEndpoint.transport()),
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
qp = getImpl(localEndpoint)->ibQp_;
Expand Down Expand Up @@ -231,7 +235,10 @@ void IBConnection::flush(int64_t timeoutUsec) {

EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize,
uint64_t recvBufferSize)
: abortFlag_(0), sendBufferSize_(sendBufferSize), recvBufferSize_(recvBufferSize) {
: Connection(localEndpoint.maxWriteQueueSize()),
abortFlag_(0),
sendBufferSize_(sendBufferSize),
recvBufferSize_(recvBufferSize) {
// Validating Transport Protocol
if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) {
throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);
Expand Down
4 changes: 3 additions & 1 deletion src/endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace mscclpp {

Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
: transport_(config.transport), hostHash_(getHostHash()) {
: transport_(config.transport), hostHash_(getHostHash()), maxWriteQueueSize_(config.maxWriteQueueSize) {
if (AllIBTransports.has(transport_)) {
ibLocal_ = true;
ibQp_ = contextImpl.getIbContext(transport_)
Expand All @@ -34,6 +34,8 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)

// Returns the transport recorded in the endpoint's implementation object.
MSCCLPP_API_CPP Transport Endpoint::transport() {
  return this->pimpl_->transport_;
}

// Returns the write-queue limit recorded in the endpoint's implementation object
// (copied from EndpointConfig::maxWriteQueueSize when the Impl is constructed).
MSCCLPP_API_CPP int Endpoint::maxWriteQueueSize() {
  return this->pimpl_->maxWriteQueueSize_;
}

MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() {
std::vector<char> data;
std::copy_n(reinterpret_cast<char*>(&pimpl_->transport_), sizeof(pimpl_->transport_), std::back_inserter(data));
Expand Down
1 change: 1 addition & 0 deletions src/include/endpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct Endpoint::Impl {

Transport transport_;
uint64_t hostHash_;
int maxWriteQueueSize_;

// The following are only used for IB and are undefined for other transports.
bool ibLocal_;
Expand Down
7 changes: 6 additions & 1 deletion src/proxy_channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,26 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger->fields.chanId];

auto result = ProxyHandlerResult::Continue;
int maxWriteQueueSize = semaphore->connection()->getMaxWriteQueueSize();

if (trigger->fields.type & TriggerData) {
RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId];
RegisteredMemory& src = memories_[trigger->fields.srcMemoryId];
semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset,
trigger->fields.size);
inflightRequests[semaphore->connection()]++;
}

if (trigger->fields.type & TriggerFlag) {
semaphore->signal();
inflightRequests[semaphore->connection()]++;
}

if (trigger->fields.type & TriggerSync) {
if (trigger->fields.type & TriggerSync ||
(maxWriteQueueSize != -1 && inflightRequests[semaphore->connection()] > maxWriteQueueSize)) {
semaphore->connection()->flush();
result = ProxyHandlerResult::FlushFifoTailAndContinue;
inflightRequests[semaphore->connection()] = 0;
}

return result;
Expand Down

0 comments on commit 80abce5

Please sign in to comment.