Skip to content

Commit

Permalink
Delete response factory when response sender goes out of scope
Browse files Browse the repository at this point in the history
  • Loading branch information
krishung5 committed Dec 14, 2023
1 parent f8b2eb6 commit 489217f
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 100 deletions.
14 changes: 0 additions & 14 deletions src/infer_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -405,20 +405,6 @@ InferRequest::InferRequest(
#endif
}

#ifndef TRITON_PB_STUB
// Deletes the response factory referenced by response_factory_address_.
// The status reported by TRITONBACKEND_ResponseFactoryDelete is forwarded to
// the caller unchanged.
TRITONSERVER_Error*
InferRequest::DeleteResponseFactory()
{
  auto* factory = reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
      response_factory_address_);
  return TRITONBACKEND_ResponseFactoryDelete(factory);
}
#endif

#ifdef TRITON_PB_STUB
bool
InferRequest::IsCancelled()
Expand Down
4 changes: 0 additions & 4 deletions src/infer_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,6 @@ class InferRequest {
intptr_t RequestAddress();
~InferRequest() {}

#ifndef TRITON_PB_STUB
TRITONSERVER_Error* DeleteResponseFactory();
#endif

private:
InferRequest(
AllocatedSharedMemory<char>& infer_request_shm,
Expand Down
3 changes: 2 additions & 1 deletion src/ipc_message.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ typedef enum PYTHONSTUB_commandtype_enum {
PYTHONSTUB_AutoCompleteRequest,
PYTHONSTUB_AutoCompleteResponse,
PYTHONSTUB_LogRequest,
PYTHONSTUB_CleanupRequest,
PYTHONSTUB_BLSDecoupledInferPayloadCleanup,
PYTHONSTUB_BLSDecoupledResponseFactoryCleanup,
PYTHONSTUB_MetricFamilyRequestNew,
PYTHONSTUB_MetricFamilyRequestDelete,
PYTHONSTUB_MetricRequestNew,
Expand Down
2 changes: 1 addition & 1 deletion src/pb_response_iterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void
ResponseIterator::Clear()
{
std::unique_ptr<Stub>& stub = Stub::GetOrCreateInstance();
stub->EnqueueCleanupId(id_);
stub->EnqueueCleanupId(id_, PYTHONSTUB_BLSDecoupledInferPayloadCleanup);
{
std::lock_guard<std::mutex> lock{mu_};
response_buffer_.push(DUMMY_MESSAGE);
Expand Down
18 changes: 12 additions & 6 deletions src/pb_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -993,8 +993,12 @@ Stub::ServiceStubToParentRequests()
stub_to_parent_buffer_.pop();
if (utils_msg_payload->command_type == PYTHONSTUB_LogRequest) {
SendLogMessage(utils_msg_payload);
} else if (utils_msg_payload->command_type == PYTHONSTUB_CleanupRequest) {
SendCleanupId(utils_msg_payload);
} else if (
(utils_msg_payload->command_type ==
PYTHONSTUB_BLSDecoupledInferPayloadCleanup) ||
(utils_msg_payload->command_type ==
PYTHONSTUB_BLSDecoupledResponseFactoryCleanup)) {
SendCleanupId(utils_msg_payload, utils_msg_payload->command_type);
} else if (
utils_msg_payload->command_type == PYTHONSTUB_IsRequestCancelled) {
SendIsCancelled(utils_msg_payload);
Expand Down Expand Up @@ -1040,7 +1044,9 @@ Stub::SendLogMessage(std::unique_ptr<UtilsMessagePayload>& utils_msg_payload)
}

void
Stub::SendCleanupId(std::unique_ptr<UtilsMessagePayload>& utils_msg_payload)
Stub::SendCleanupId(
std::unique_ptr<UtilsMessagePayload>& utils_msg_payload,
const PYTHONSTUB_CommandType& command_type)
{
void* id = utils_msg_payload->utils_message_ptr;
{
Expand All @@ -1050,7 +1056,7 @@ Stub::SendCleanupId(std::unique_ptr<UtilsMessagePayload>& utils_msg_payload)

std::unique_ptr<IPCMessage> ipc_message =
IPCMessage::Create(shm_pool_, true /* inline_response */);
ipc_message->Command() = PYTHONSTUB_CleanupRequest;
ipc_message->Command() = command_type;
AllocatedSharedMemory<char> cleanup_request_message =
shm_pool_->Construct<char>(
sizeof(CleanupMessage) +
Expand All @@ -1072,11 +1078,11 @@ Stub::SendCleanupId(std::unique_ptr<UtilsMessagePayload>& utils_msg_payload)
}

void
Stub::EnqueueCleanupId(void* id)
Stub::EnqueueCleanupId(void* id, const PYTHONSTUB_CommandType& command_type)
{
if (id != nullptr) {
std::unique_ptr<UtilsMessagePayload> utils_msg_payload =
std::make_unique<UtilsMessagePayload>(PYTHONSTUB_CleanupRequest, id);
std::make_unique<UtilsMessagePayload>(command_type, id);
EnqueueUtilsMessage(std::move(utils_msg_payload));
}
}
Expand Down
9 changes: 6 additions & 3 deletions src/pb_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,13 @@ class Stub {
std::shared_ptr<InferResponse> infer_response);

/// Send the id to the python backend for object cleanup
void SendCleanupId(std::unique_ptr<UtilsMessagePayload>& utils_msg_payload);
void SendCleanupId(
std::unique_ptr<UtilsMessagePayload>& utils_msg_payload,
const PYTHONSTUB_CommandType& command_type);

/// Add cleanup id to queue
void EnqueueCleanupId(void* id);
/// Add cleanup id to queue. This is used for cleaning up the infer_payload
/// and the response factory for BLS decoupled response.
void EnqueueCleanupId(void* id, const PYTHONSTUB_CommandType& command_type);

/// Add request cancellation query to queue
void EnqueueIsCancelled(PbCancel* pb_cancel);
Expand Down
79 changes: 16 additions & 63 deletions src/python_be.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,21 +379,7 @@ ModelInstanceState::SaveRequestsToSharedMemory(
std::unique_ptr<InferRequest> infer_request;
if (model_state->IsDecoupled()) {
TRITONBACKEND_ResponseFactory* factory_ptr;
// Reuse the response factory if there is already a response factory
// associated with the request
std::lock_guard<std::mutex> guard{response_factory_map_mutex_};
{
if (response_factory_map_.find(reinterpret_cast<intptr_t>(request)) !=
response_factory_map_.end()) {
factory_ptr =
response_factory_map_[reinterpret_cast<intptr_t>(request)];
} else {
RETURN_IF_ERROR(
TRITONBACKEND_ResponseFactoryNew(&factory_ptr, request));
response_factory_map_[reinterpret_cast<intptr_t>(request)] =
factory_ptr;
}
}
RETURN_IF_ERROR(TRITONBACKEND_ResponseFactoryNew(&factory_ptr, request));

infer_request = std::make_unique<InferRequest>(
id, correlation_id, pb_input_tensors, requested_output_names,
Expand Down Expand Up @@ -843,7 +829,8 @@ ModelInstanceState::StubToParentMQMonitor()
ProcessLogRequest(message);
break;
}
case PYTHONSTUB_CleanupRequest: {
case PYTHONSTUB_BLSDecoupledInferPayloadCleanup:
case PYTHONSTUB_BLSDecoupledResponseFactoryCleanup: {
ProcessBLSCleanupRequest(message);
break;
}
Expand Down Expand Up @@ -941,9 +928,17 @@ ModelInstanceState::ProcessBLSCleanupRequest(
Stub()->ShmPool()->Load<char>(message->Args());
CleanupMessage* cleanup_message_ptr =
reinterpret_cast<CleanupMessage*>(cleanup_request_message.data_.get());

void* id = cleanup_message_ptr->id;
infer_payload_.erase(reinterpret_cast<intptr_t>(id));
intptr_t id = reinterpret_cast<intptr_t>(cleanup_message_ptr->id);
if (message->Command() == PYTHONSTUB_BLSDecoupledInferPayloadCleanup) {
// Remove the InferPayload object from the map.
infer_payload_.erase(id);
} else if (
message->Command() == PYTHONSTUB_BLSDecoupledResponseFactoryCleanup) {
// Delete response factory
std::unique_ptr<
TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
response_factory(reinterpret_cast<TRITONBACKEND_ResponseFactory*>(id));
}

{
bi::scoped_lock<bi::interprocess_mutex> lock{*(message->ResponseMutex())};
Expand Down Expand Up @@ -1172,12 +1167,6 @@ ModelInstanceState::ResponseSendDecoupled(
std::lock_guard<std::mutex> guard{closed_requests_mutex_};
closed_requests_.push_back(send_message_payload->request_address);
}

// Clean up the response factory map.
{
std::lock_guard<std::mutex> guard{response_factory_map_mutex_};
response_factory_map_.erase(send_message_payload->request_address);
}
}

if (send_message_payload->response != 0) {
Expand All @@ -1195,14 +1184,7 @@ ModelInstanceState::ResponseSendDecoupled(
error_message);

std::vector<std::pair<std::unique_ptr<PbMemory>, void*>> gpu_output_buffers;
std::unique_ptr<
TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
response_factory_ptr;
GPUBuffersHelper gpu_buffer_helper;
if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
response_factory_ptr.reset(
reinterpret_cast<TRITONBACKEND_ResponseFactory*>(response_factory));
}

#ifdef TRITON_ENABLE_GPU
for (auto& output_tensor : infer_response->OutputTensors()) {
Expand Down Expand Up @@ -1289,13 +1271,6 @@ ModelInstanceState::ResponseSendDecoupled(
response_factory, send_message_payload->flags);
SetErrorForResponseSendMessage(
send_message_payload, WrapTritonErrorInSharedPtr(error), error_message);

if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
std::unique_ptr<
TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
response_factory(reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
send_message_payload->response_factory_address));
}
}
}

Expand Down Expand Up @@ -1864,29 +1839,6 @@ ModelInstanceState::ShareCUDAMemoryPool(const int32_t device_id)
#endif // TRITON_ENABLE_GPU
}

// Cleans up a decoupled request: resets its release flags, drops its entry
// from the response factory map, and deletes its response factory unless the
// request is already recorded as closed.
void
ModelInstanceState::CleanupDecoupledRequests(
    const std::unique_ptr<InferRequest>& infer_request)
{
  // Release flags are reset to TRITONSERVER_REQUEST_RELEASE_ALL.
  infer_request->SetReleaseFlags(TRITONSERVER_REQUEST_RELEASE_ALL);

  // Drop the map entry for this request; the mutex is held only for the
  // erase itself.
  {
    std::lock_guard<std::mutex> map_lock{response_factory_map_mutex_};
    response_factory_map_.erase(
        reinterpret_cast<intptr_t>(infer_request->RequestAddress()));
  }

  // Only requests that have not been closed get their response factory
  // deleted here.
  if (ExistsInClosedRequests(infer_request->RequestAddress())) {
    return;
  }

  LOG_IF_ERROR(
      infer_request->DeleteResponseFactory(),
      "Failed to delete the response factory.");
}

ModelInstanceState::~ModelInstanceState()
{
ModelState* model_state = reinterpret_cast<ModelState*>(Model());
Expand Down Expand Up @@ -2518,7 +2470,8 @@ TRITONBACKEND_ModelInstanceExecute(
}

for (auto& infer_request : infer_requests) {
instance_state->CleanupDecoupledRequests(infer_request);
// Reset the release flags for all the requests.
infer_request->SetReleaseFlags(TRITONSERVER_REQUEST_RELEASE_ALL);
}
}
}
Expand Down
10 changes: 2 additions & 8 deletions src/python_be.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,6 @@ class ModelInstanceState : public BackendModelInstance {
std::unique_ptr<boost::asio::thread_pool> thread_pool_;
std::unordered_map<intptr_t, std::shared_ptr<InferPayload>> infer_payload_;
std::unique_ptr<RequestExecutor> request_executor_;
std::mutex response_factory_map_mutex_;
std::unordered_map<intptr_t, TRITONBACKEND_ResponseFactory*>
response_factory_map_;

public:
static TRITONSERVER_Error* Create(
Expand Down Expand Up @@ -403,7 +400,8 @@ class ModelInstanceState : public BackendModelInstance {
std::unique_ptr<InferResponse>* infer_response,
bi::managed_external_buffer::handle_t* response_handle);

// Process the bls decoupled cleanup request
// Process the bls decoupled cleanup request for InferPayload and
// ResponseFactory
void ProcessBLSCleanupRequest(const std::unique_ptr<IPCMessage>& message);

// Process request cancellation query
Expand All @@ -429,9 +427,5 @@ class ModelInstanceState : public BackendModelInstance {

// Attempt to share CUDA memory pool with the stub process
void ShareCUDAMemoryPool(const int32_t device_id);

// Cleanup the decoupled requests when there is an error in the response.
void CleanupDecoupledRequests(
const std::unique_ptr<InferRequest>& infer_request);
};
}}} // namespace triton::backend::python
7 changes: 7 additions & 0 deletions src/response_sender.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ ResponseSender::ResponseSender(
{
}

// When the response sender is destroyed, enqueue a
// PYTHONSTUB_BLSDecoupledResponseFactoryCleanup message carrying the response
// factory address so the factory can be cleaned up.
ResponseSender::~ResponseSender()
{
  void* factory_id = reinterpret_cast<void*>(response_factory_address_);
  Stub::GetOrCreateInstance()->EnqueueCleanupId(
      factory_id, PYTHONSTUB_BLSDecoupledResponseFactoryCleanup);
}

void
ResponseSender::Send(
Expand Down
1 change: 1 addition & 0 deletions src/response_sender.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class ResponseSender {
intptr_t request_address, intptr_t response_factory_address,
std::unique_ptr<SharedMemoryManager>& shm_pool,
const std::shared_ptr<PbCancel>& pb_cancel);
~ResponseSender();
void Send(std::shared_ptr<InferResponse> response, const uint32_t flags);
bool IsCancelled();

Expand Down

0 comments on commit 489217f

Please sign in to comment.