Skip to content

Commit

Permalink
Added reaper thread for model instance in python_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
CGranger-sorenson committed Apr 30, 2024
1 parent 4d42111 commit a31e5a2
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/ipc_message.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ typedef enum PYTHONSTUB_commandtype_enum {
PYTHONSTUB_LoadModelRequest,
PYTHONSTUB_UnloadModelRequest,
PYTHONSTUB_ModelReadinessRequest,
PYTHONSTUB_IsRequestCancelled
PYTHONSTUB_IsRequestCancelled,
PYTHONSTUB_CheckCorrid
} PYTHONSTUB_CommandType;

///
Expand Down
10 changes: 10 additions & 0 deletions src/pb_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ Stub::RunCommand()
ipc_message = this->PopMessage();
}
switch (ipc_message->Command()) {
case PYTHONSTUB_CommandType::PYTHONSTUB_CheckCorrid: {
CheckCorrid();
} break;
case PYTHONSTUB_CommandType::PYTHONSTUB_AutoCompleteRequest: {
// Only run this case when auto complete was requested by
// Triton core.
Expand Down Expand Up @@ -721,6 +724,13 @@ Stub::ProcessRequestsDecoupled(RequestBatch* request_batch_shm_ptr)
}
}

// Invoke the optional `check_corrid` hook on the Python model instance.
// Called when the backend's reaper thread sends a PYTHONSTUB_CheckCorrid
// message; a no-op if the model does not define the hook.
void
Stub::CheckCorrid()
{
  if (py::hasattr(model_instance_, "check_corrid")) {
    model_instance_.attr("check_corrid")();
  }
}

void
Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
{
Expand Down
2 changes: 2 additions & 0 deletions src/pb_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ class Stub {
/// Execute a batch of requests.
void ProcessRequests(RequestBatch* request_batch_shm_ptr);

void CheckCorrid();

void ProcessRequestsDecoupled(RequestBatch* request_batch_shm_ptr);

/// Get the memory manager message queue
Expand Down
33 changes: 33 additions & 0 deletions src/python_be.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ namespace triton { namespace backend { namespace python {

namespace bi = boost::interprocess;

// Periodically (every ~30s) sends a PYTHONSTUB_CheckCorrid message to the
// stub so the Python model can run its `check_corrid` hook. Runs until
// `reaper_thread_exit` is set by the destructor.
//
// NOTE(review): `reaper_thread_exit` is a plain bool written by the
// destructor thread and read here — it should be std::atomic<bool> in the
// header to avoid a data race; confirm and fix where it is declared.
void
ModelInstanceState::ReaperThread()
{
  constexpr auto kSlice = std::chrono::milliseconds(100);
  // 300 slices of 100ms == the 30-second ping period.
  constexpr int kSlicesPerPeriod = 300;

  // Wait for the stub to fully load. Sleep between polls instead of
  // busy-spinning, and honor the exit flag so the destructor's join()
  // cannot hang if the stub never becomes alive.
  while (!IsStubProcessAlive()) {
    if (reaper_thread_exit) {
      return;
    }
    std::this_thread::sleep_for(kSlice);
  }

  while (!reaper_thread_exit) {
    // Sleep the 30-second period in short slices so that shutdown is not
    // delayed by up to a full period while the destructor joins us.
    for (int i = 0; i < kSlicesPerPeriod && !reaper_thread_exit; ++i) {
      std::this_thread::sleep_for(kSlice);
    }

    // Only ping if we were not asked to exit while sleeping.
    if (!reaper_thread_exit) {
      SendMessageToStub(ipc_message->ShmHandle());
    }
  }
}

ModelInstanceState::ModelInstanceState(
ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance)
: BackendModelInstance(model_state, triton_model_instance),
Expand Down Expand Up @@ -456,6 +474,13 @@ ModelInstanceState::LaunchStubProcess()
}
request_executor_ = std::make_unique<RequestExecutor>(
Stub()->ShmPool(), model_state->TritonServer());

// Save this object so it doesnt get deleted before the stub MQ can access it
ipc_message = IPCMessage::Create(Stub()->ShmPool(), false);
ipc_message->Command() = PYTHONSTUB_CommandType::PYTHONSTUB_CheckCorrid;

// Launch reaper thread after stub creation
reaper_thread_ = new std::thread([this]() { this->ReaperThread(); });

return nullptr;
}
Expand Down Expand Up @@ -1871,6 +1896,10 @@ ModelInstanceState::ShareCUDAMemoryPool(const int32_t device_id)

ModelInstanceState::~ModelInstanceState()
{
// Allow reaper thread to exit first so we don't run into issues on close
reaper_thread_exit = true;
ipc_message.reset();

ModelState* model_state = reinterpret_cast<ModelState*>(Model());
Stub()->UpdateHealth();
if (Stub()->IsHealthy()) {
Expand All @@ -1891,6 +1920,10 @@ ModelInstanceState::~ModelInstanceState()
Stub()->ClearQueues();
received_message_.reset();
Stub().reset();

// join last so we don't block closing
reaper_thread_->join();
delete reaper_thread_;
}

TRITONSERVER_Error*
Expand Down
5 changes: 5 additions & 0 deletions src/python_be.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ class ModelInstanceState : public BackendModelInstance {
std::unique_ptr<boost::asio::thread_pool> thread_pool_;
std::unordered_map<intptr_t, std::shared_ptr<InferPayload>> infer_payload_;
std::unique_ptr<RequestExecutor> request_executor_;
std::thread* reaper_thread_;
std::shared_ptr<IPCMessage> ipc_message;
bool reaper_thread_exit = false;

public:
static TRITONSERVER_Error* Create(
Expand All @@ -322,6 +325,8 @@ class ModelInstanceState : public BackendModelInstance {
std::shared_ptr<std::vector<TRITONBACKEND_Response*>>& responses,
TRITONBACKEND_Request** requests, const uint32_t request_count);

void ReaperThread();

// Responds to all the requests with an error message.
void RespondErrorToAllRequests(
const char* message,
Expand Down

0 comments on commit a31e5a2

Please sign in to comment.