Add support for request rescheduling (#319) #321

Merged (1 commit) on Nov 9, 2023
97 changes: 97 additions & 0 deletions README.md
@@ -50,6 +50,7 @@ any C++ code.
- [Decoupled mode](#decoupled-mode)
- [Use Cases](#use-cases)
- [Known Issues](#known-issues)
- [Request Rescheduling](#request-rescheduling)
- [`finalize`](#finalize)
- [Model Config File](#model-config-file)
- [Inference Request Parameters](#inference-request-parameters)
@@ -623,6 +624,102 @@ for more details on how to host a decoupled model.

* Currently, decoupled Python models cannot make async infer requests.

#### Request Rescheduling

Starting from 23.11, the Python backend supports request rescheduling. By
calling the `set_release_flags` function on the request object with the flag
`pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE`, you can reschedule the
request for further execution in a future batch. This feature is useful for
handling generative sequences.

To use the request rescheduling API, the model config must enable generative
sequence batching:

```
sequence_batching {
generative_sequence : true
}
```

For non-decoupled models, there can only be one response for each request. Since
the rescheduled request is the same request object and will be executed again in
a later batch, you must append a `None` placeholder to the response list for it
in the current execution. For example:

```python
import triton_python_backend_utils as pb_utils

class TritonPythonModel:
...

def execute(self, requests):
responses = []

for request in requests:
# Explicitly reschedule the first request
if self.idx == 0:
request.set_release_flags(
pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
)
responses.append(None)
self.idx += 1
else:
responses.append(inference_response)

return responses
```
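
The example above elides where `self.idx` and `inference_response` come from. A
minimal, self-contained sketch of one way to wire those pieces together is shown
below; the `initialize` method, the counter logic, and the `"IN"`/`"OUT"` tensor
names are illustrative assumptions rather than anything required by the backend.

```python
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        # Hypothetical counter used to decide which request gets rescheduled.
        self.idx = 0

    def execute(self, requests):
        responses = []

        for request in requests:
            if self.idx == 0:
                # Reschedule the first request; its slot in the returned list
                # must be None because no response is produced yet.
                request.set_release_flags(
                    pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
                )
                responses.append(None)
                self.idx += 1
            else:
                # Produce a normal response (assumed "IN"/"OUT" tensor names).
                in_tensor = pb_utils.get_input_tensor_by_name(request, "IN")
                out_tensor = pb_utils.Tensor("OUT", in_tensor.as_numpy())
                responses.append(
                    pb_utils.InferenceResponse(output_tensors=[out_tensor])
                )

        return responses
```

The essential rule is unchanged: a rescheduled request contributes a `None`
entry to the returned list, and every other request contributes exactly one
`pb_utils.InferenceResponse`.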

For decoupled models, the request must be rescheduled *before* returning from
the `execute` function.
Below is an example of a decoupled model using request rescheduling. The model
takes one input tensor, an INT32 [ 1 ] tensor named "IN", and produces an output
tensor "OUT" with the same shape. The input value indicates the total number of
responses to be generated, and the output value indicates the number of
remaining responses. For example, if the request input has value 2, the model
will:
- Send a response with value 1.
- Release the request with the RESCHEDULE flag.
- When `execute` is called again for the same request, send the last response
  with value 0.
- Release the request with the ALL flag.

```python
import numpy as np
import triton_python_backend_utils as pb_utils

class TritonPythonModel:
...

def execute(self, requests):
responses = []

for request in requests:
in_input = pb_utils.get_input_tensor_by_name(request, "IN").as_numpy()

if self.reset_flag:
self.remaining_response = in_input[0]
self.reset_flag = False

response_sender = request.get_response_sender()

self.remaining_response -= 1

out_output = pb_utils.Tensor(
"OUT", np.array([self.remaining_response], np.int32)
)
response = pb_utils.InferenceResponse(output_tensors=[out_output])

if self.remaining_response <= 0:
response_sender.send(
response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
)
self.reset_flag = True
else:
request.set_release_flags(
pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
)
response_sender.send(response)

return None
```
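
The decoupled example assumes that `self.reset_flag` and `self.remaining_response`
are initialized elsewhere (for example in `initialize`). For reference, a minimal
`config.pbtxt` sketch for a model like this might look as follows; the model name
is hypothetical, and only the fields relevant to this example are shown:

```
name: "generative_sequence_example"
backend: "python"

model_transaction_policy { decoupled: true }

sequence_batching {
  generative_sequence : true
}

input [
  {
    name: "IN"
    data_type: TYPE_INT32
    dims: [ 1 ]
  }
]
output [
  {
    name: "OUT"
    data_type: TYPE_INT32
    dims: [ 1 ]
  }
]
```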

### `finalize`

Implementing `finalize` is optional. This function allows you to do any clean
18 changes: 17 additions & 1 deletion src/infer_request.cc
@@ -50,7 +50,7 @@ InferRequest::InferRequest(
model_version_(model_version), parameters_(parameters), flags_(flags),
timeout_(timeout), response_factory_address_(response_factory_address),
request_address_(request_address), preferred_memory_(preferred_memory),
trace_(trace)
trace_(trace), request_release_flags_(TRITONSERVER_REQUEST_RELEASE_ALL)
{
for (auto& input : inputs) {
if (!input) {
@@ -175,6 +175,20 @@ InferRequest::Trace()
return trace_;
}

uint32_t
InferRequest::ReleaseFlags()
{
request_release_flags_ = infer_request_shm_ptr_->request_release_flags;
return request_release_flags_;
}

void
InferRequest::SetReleaseFlags(const uint32_t& flags)
{
request_release_flags_ = flags;
infer_request_shm_ptr_->request_release_flags = request_release_flags_;
}

void
InferRequest::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
{
@@ -201,6 +215,7 @@ InferRequest::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
infer_request_shm_ptr_->timeout = timeout_;
infer_request_shm_ptr_->preferred_memory = preferred_memory_;
infer_request_shm_ptr_->trace = trace_;
infer_request_shm_ptr_->request_release_flags = request_release_flags_;

output_names_handle_shm_ptr_ =
reinterpret_cast<bi::managed_external_buffer::handle_t*>(
@@ -379,6 +394,7 @@ InferRequest::InferRequest(
timeout_ = infer_request_shm_ptr_->timeout;
preferred_memory_ = infer_request_shm_ptr_->preferred_memory;
trace_ = infer_request_shm_ptr_->trace;
request_release_flags_ = infer_request_shm_ptr_->request_release_flags;

#ifdef TRITON_PB_STUB
pb_cancel_ =
4 changes: 4 additions & 0 deletions src/infer_request.h
@@ -73,6 +73,7 @@ struct InferRequestShm {
int32_t timeout;
PreferredMemory preferred_memory;
InferenceTrace trace;
uint32_t request_release_flags;
};

class InferRequest {
@@ -104,6 +105,8 @@ class InferRequest {
void SetIsDecoupled(const bool is_decoupled);
PreferredMemory& GetPreferredMemory();
InferenceTrace& Trace();
uint32_t ReleaseFlags();
void SetReleaseFlags(const uint32_t& flags);

#ifdef TRITON_PB_STUB
std::shared_ptr<InferResponse> Exec(const bool is_decoupled);
@@ -161,6 +164,7 @@ class InferRequest {
bool is_decoupled_;
PreferredMemory preferred_memory_;
InferenceTrace trace_;
uint32_t request_release_flags_;

// Shared Memory Data Structures
AllocatedSharedMemory<char> infer_request_shm_;
49 changes: 32 additions & 17 deletions src/pb_stub.cc
@@ -793,26 +793,39 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
std::to_string(response_size) + "\n";
throw PythonBackendException(err);
}
for (auto& response : responses) {

for (size_t i = 0; i < response_size; i++) {
// Check the return type of execute function.
if (!py::isinstance<InferResponse>(response)) {
std::string str = py::str(response.get_type());
throw PythonBackendException(
std::string("Expected an 'InferenceResponse' object in the execute "
"function return list, found type '") +
str + "'.");
InferRequest* infer_request = py_request_list[i].cast<InferRequest*>();
if (infer_request->ReleaseFlags() ==
TRITONSERVER_REQUEST_RELEASE_RESCHEDULE) {
if (!py::isinstance<py::none>(responses[i])) {
// When the request is rescheduled in non-decoupled model, the
// response must be None.
std::string str = py::str(responses[i].get_type());
throw PythonBackendException(
"Expected a None object in the execute function return list for "
"reschduled request, "
"found type '" +
str + "'.");
}
} else {
if (!py::isinstance<InferResponse>(responses[i])) {
std::string str = py::str(responses[i].get_type());
throw PythonBackendException(
std::string(
"Expected an 'InferenceResponse' object in the execute "
"function return list, found type '") +
str + "'.");
}
InferResponse* infer_response = responses[i].cast<InferResponse*>();
infer_response->PruneOutputTensors(
infer_request->RequestedOutputNames());
ProcessResponse(infer_response);
responses_shm_handle[i] = infer_response->ShmHandle();
}
}
response_batch_shm_ptr->batch_size = response_size;

for (size_t i = 0; i < batch_size; i++) {
InferResponse* infer_response = responses[i].cast<InferResponse*>();
InferRequest* infer_request = py_request_list[i].cast<InferRequest*>();
infer_response->PruneOutputTensors(infer_request->RequestedOutputNames());

ProcessResponse(infer_response);
responses_shm_handle[i] = infer_response->ShmHandle();
}
}
catch (const PythonBackendException& pb_exception) {
has_exception = true;
@@ -1675,7 +1688,9 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
"requested_output_names", &InferRequest::RequestedOutputNames,
py::return_value_policy::reference_internal)
.def("get_response_sender", &InferRequest::GetResponseSender)
.def("is_cancelled", &InferRequest::IsCancelled);
.def("is_cancelled", &InferRequest::IsCancelled)
.def("set_release_flags", &InferRequest::SetReleaseFlags),
py::arg("flags").none(false);

py::class_<PbTensor, std::shared_ptr<PbTensor>>(module, "Tensor")
.def(py::init(&PbTensor::FromNumpy))