From a31e5a24e8152166c5a72de24a9e1c79d57d7e7d Mon Sep 17 00:00:00 2001
From: Chris Granger
Date: Tue, 30 Apr 2024 15:00:56 -0600
Subject: [PATCH] Add reaper thread for model instance in python_backend
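
A new PYTHONSTUB_CheckCorrid IPC command lets the backend ping the
Python stub from a per-instance reaper thread. The thread starts when
the stub is launched, sends a CheckCorrid message every 30 seconds, and
is signaled and joined in the instance destructor. On the stub side,
Stub::CheckCorrid() invokes the model's check_corrid method if the
model defines one, giving models a periodic hook for reaping state tied
to stale correlation IDs; models without the method are unaffected.

For illustration, a model could opt in along these lines (a minimal
sketch: the last-seen bookkeeping, the 60-second cutoff, and the empty
responses are hypothetical, not part of this patch):

    import time

    import triton_python_backend_utils as pb_utils

    class TritonPythonModel:
        def initialize(self, args):
            # Hypothetical bookkeeping: last-seen time per correlation ID.
            self.last_seen = {}

        def execute(self, requests):
            responses = []
            for request in requests:
                # Record activity for this request's correlation ID.
                self.last_seen[request.correlation_id()] = time.monotonic()
                # A real model would compute output tensors here.
                responses.append(pb_utils.InferenceResponse(output_tensors=[]))
            return responses

        def check_corrid(self):
            # Called by the reaper thread roughly every 30 seconds, with no
            # arguments; drop per-sequence state idle for over 60 seconds.
            cutoff = time.monotonic() - 60.0
            for corrid in [c for c, t in self.last_seen.items() if t < cutoff]:
                del self.last_seen[corrid]

Since the CheckCorrid message carries no payload, the hook is purely a
timer tick; any actual correlation-ID bookkeeping lives in the model.
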
---
 src/ipc_message.h |  3 ++-
 src/pb_stub.cc    | 10 ++++++++++
 src/pb_stub.h     |  2 ++
 src/python_be.cc  | 33 +++++++++++++++++++++++++++++++++
 src/python_be.h   |  5 +++++
 5 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/src/ipc_message.h b/src/ipc_message.h
index ac28238c..96390ba1 100644
--- a/src/ipc_message.h
+++ b/src/ipc_message.h
@@ -66,7 +66,8 @@ typedef enum PYTHONSTUB_commandtype_enum {
   PYTHONSTUB_LoadModelRequest,
   PYTHONSTUB_UnloadModelRequest,
   PYTHONSTUB_ModelReadinessRequest,
-  PYTHONSTUB_IsRequestCancelled
+  PYTHONSTUB_IsRequestCancelled,
+  PYTHONSTUB_CheckCorrid
 } PYTHONSTUB_CommandType;
 
 ///
diff --git a/src/pb_stub.cc b/src/pb_stub.cc
index a9a910a1..363be8b7 100644
--- a/src/pb_stub.cc
+++ b/src/pb_stub.cc
@@ -246,6 +246,9 @@ Stub::RunCommand()
     ipc_message = this->PopMessage();
   }
   switch (ipc_message->Command()) {
+    case PYTHONSTUB_CommandType::PYTHONSTUB_CheckCorrid: {
+      CheckCorrid();
+    } break;
     case PYTHONSTUB_CommandType::PYTHONSTUB_AutoCompleteRequest: {
       // Only run this case when auto complete was requested by
       // Triton core.
@@ -721,6 +724,13 @@ Stub::ProcessRequestsDecoupled(RequestBatch* request_batch_shm_ptr)
   }
 }
 
+void Stub::CheckCorrid()
+{
+  if (py::hasattr(model_instance_, "check_corrid")) {
+    model_instance_.attr("check_corrid")();
+  }
+}
+
 void
 Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
 {
diff --git a/src/pb_stub.h b/src/pb_stub.h
index a51f25f5..110f04bf 100644
--- a/src/pb_stub.h
+++ b/src/pb_stub.h
@@ -253,6 +253,8 @@ class Stub {
   /// Execute a batch of requests.
   void ProcessRequests(RequestBatch* request_batch_shm_ptr);
 
+  void CheckCorrid();
+
   void ProcessRequestsDecoupled(RequestBatch* request_batch_shm_ptr);
 
   /// Get the memory manager message queue
diff --git a/src/python_be.cc b/src/python_be.cc
index b688fdfd..a74f2e93 100644
--- a/src/python_be.cc
+++ b/src/python_be.cc
@@ -37,6 +37,24 @@ namespace triton { namespace backend { namespace python {
 
 namespace bi = boost::interprocess;
 
+void ModelInstanceState::ReaperThread()
+{
+  // Wait for the stub process to come up before pinging it.
+  while (!IsStubProcessAlive() && !reaper_thread_exit) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+  }
+
+  while (true) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(30000));
+
+    // Stop pinging once shutdown has been requested.
+    if (reaper_thread_exit) {
+      break;
+    }
+    SendMessageToStub(ipc_message_->ShmHandle());
+  }
+}
+
 ModelInstanceState::ModelInstanceState(
     ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance)
     : BackendModelInstance(model_state, triton_model_instance),
@@ -456,6 +474,13 @@ ModelInstanceState::LaunchStubProcess()
   }
   request_executor_ = std::make_unique<RequestExecutor>(
       Stub()->ShmPool(), model_state->TritonServer());
+
+  // Keep this message alive so the stub message queue can always reach it.
+  ipc_message_ = IPCMessage::Create(Stub()->ShmPool(), false);
+  ipc_message_->Command() = PYTHONSTUB_CommandType::PYTHONSTUB_CheckCorrid;
+
+  // Launch the reaper thread only after the stub process is up.
+  reaper_thread_ = new std::thread([this]() { this->ReaperThread(); });
   return nullptr;
 }
 
@@ -1871,6 +1896,10 @@ ModelInstanceState::ShareCUDAMemoryPool(const int32_t device_id)
 
 ModelInstanceState::~ModelInstanceState()
 {
+  // Signal the reaper thread to exit before the stub is torn down.
+  reaper_thread_exit = true;
+  ipc_message_.reset();
+
   ModelState* model_state = reinterpret_cast<ModelState*>(Model());
   Stub()->UpdateHealth();
   if (Stub()->IsHealthy()) {
@@ -1891,6 +1920,10 @@ ModelInstanceState::~ModelInstanceState()
   Stub()->ClearQueues();
   received_message_.reset();
   Stub().reset();
+
+  // Join last so the reaper cannot outlive the state it touches.
+  reaper_thread_->join();
+  delete reaper_thread_;
 }
 
 TRITONSERVER_Error*
diff --git a/src/python_be.h b/src/python_be.h
index 4430767c..be37d3d5 100644
--- a/src/python_be.h
+++ b/src/python_be.h
@@ -297,6 +297,9 @@ class ModelInstanceState : public BackendModelInstance {
   std::unique_ptr<boost::asio::thread_pool> thread_pool_;
   std::unordered_map<intptr_t, std::shared_ptr<InferPayload>> infer_payload_;
   std::unique_ptr<RequestExecutor> request_executor_;
+  std::thread* reaper_thread_;
+  std::shared_ptr<IPCMessage> ipc_message_;
+  std::atomic<bool> reaper_thread_exit{false};
 
  public:
   static TRITONSERVER_Error* Create(
@@ -322,6 +325,8 @@ class ModelInstanceState : public BackendModelInstance {
       std::shared_ptr<std::vector<TRITONBACKEND_Response*>>& responses,
       TRITONBACKEND_Request** requests, const uint32_t request_count);
 
+  void ReaperThread();
+
   // Responds to all the requests with an error message.
   void RespondErrorToAllRequests(
       const char* message,