Add autocomplete for runtime model config field
kthui committed Nov 3, 2023
1 parent 90309ef commit 84a6c6a
Showing 3 changed files with 94 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/constants.h
@@ -60,6 +60,8 @@ constexpr char kPyTorchBackend[] = "pytorch";
 constexpr char kPythonFilename[] = "model.py";
 constexpr char kPythonBackend[] = "python";
 
+constexpr char kVLLMBackend[] = "vllm";
+
 #ifdef TRITON_ENABLE_ENSEMBLE
 constexpr char kEnsemblePlatform[] = "ensemble";
 #endif  // TRITON_ENABLE_ENSEMBLE
77 changes: 74 additions & 3 deletions src/model_config_utils.cc
@@ -979,7 +979,7 @@ AutoCompleteBackendFields(
     config->set_name(model_name);
   }
 
-  // Trying to fill the 'backend', 'default_model_filename' field.
+  // Trying to fill the 'backend', 'default_model_filename' and 'runtime' field.
 
   // TensorFlow
   // For TF backend, the platform is required
@@ -1016,7 +1016,7 @@ AutoCompleteBackendFields(
     }
   }
 
-  // Fill 'backend' and 'default_model_filename' if missing
+  // Fill 'backend', 'default_model_filename' and 'runtime' if missing
   if ((config->platform() == kTensorFlowSavedModelPlatform) ||
       (config->platform() == kTensorFlowGraphDefPlatform)) {
     if (config->backend().empty()) {
@@ -1029,6 +1029,8 @@ AutoCompleteBackendFields(
         config->set_default_model_filename(kTensorFlowGraphDefFilename);
       }
     }
+    RETURN_IF_ERROR(
+        AutoCompleteBackendRuntimeField(RuntimeType::RUNTIME_TYPE_CPP, config));
     return Status::Success;
   }
 
@@ -1058,6 +1060,8 @@ AutoCompleteBackendFields(
     if (config->default_model_filename().empty()) {
       config->set_default_model_filename(kTensorRTPlanFilename);
     }
+    RETURN_IF_ERROR(
+        AutoCompleteBackendRuntimeField(RuntimeType::RUNTIME_TYPE_CPP, config));
     return Status::Success;
   }
 
@@ -1083,6 +1087,8 @@ AutoCompleteBackendFields(
     if (config->default_model_filename().empty()) {
       config->set_default_model_filename(kOnnxRuntimeOnnxFilename);
     }
+    RETURN_IF_ERROR(
+        AutoCompleteBackendRuntimeField(RuntimeType::RUNTIME_TYPE_CPP, config));
     return Status::Success;
   }
 
@@ -1103,10 +1109,12 @@ AutoCompleteBackendFields(
     if (config->default_model_filename().empty()) {
       config->set_default_model_filename(kOpenVINORuntimeOpenVINOFilename);
     }
+    RETURN_IF_ERROR(
+        AutoCompleteBackendRuntimeField(RuntimeType::RUNTIME_TYPE_CPP, config));
     return Status::Success;
   }
 
-  // PyTorch (TorchScript, LibTorch)
+  // PyTorch
   if (config->backend().empty()) {
     if ((config->platform() == kPyTorchLibTorchPlatform) ||
         (config->default_model_filename() == kPyTorchLibTorchFilename)) {
@@ -1132,6 +1140,8 @@ AutoCompleteBackendFields(
     if (config->default_model_filename().empty()) {
       config->set_default_model_filename(kPyTorchLibTorchFilename);
     }
+    RETURN_IF_ERROR(AutoCompleteBackendRuntimeField(
+        RuntimeType::RUNTIME_TYPE_UNKNOWN, config));
     return Status::Success;
   }
 
@@ -1152,6 +1162,18 @@ AutoCompleteBackendFields(
     if (config->default_model_filename().empty()) {
       config->set_default_model_filename(kPythonFilename);
     }
+    RETURN_IF_ERROR(
+        AutoCompleteBackendRuntimeField(RuntimeType::RUNTIME_TYPE_CPP, config));
     return Status::Success;
   }
 
+  // vLLM
+  if (config->backend() == kVLLMBackend) {
+    if (config->default_model_filename().empty()) {
+      config->set_default_model_filename(kPythonFilename);
+    }
+    RETURN_IF_ERROR(AutoCompleteBackendRuntimeField(
+        RuntimeType::RUNTIME_TYPE_PYTHON, config));
+    return Status::Success;
+  }
+
@@ -1180,9 +1202,58 @@ AutoCompleteBackendFields(
     config->set_backend(backend_name);
     config->set_default_model_filename(
         (std::string("model.") + backend_name).c_str());
+    RETURN_IF_ERROR(AutoCompleteBackendRuntimeField(
+        RuntimeType::RUNTIME_TYPE_UNKNOWN, config));
     return Status::Success;
   }
 
+  RETURN_IF_ERROR(AutoCompleteBackendRuntimeField(
+      RuntimeType::RUNTIME_TYPE_UNKNOWN, config));
   return Status::Success;
 }
 
+Status
+AutoCompleteBackendRuntimeField(
+    RuntimeType runtime_type, inference::ModelConfig* config)
+{
+  bool fill_runtime = config->runtime().empty();
+#ifdef TRITON_ENABLE_ENSEMBLE
+  fill_runtime = fill_runtime && config->platform() != kEnsemblePlatform;
+#endif  // TRITON_ENABLE_ENSEMBLE
+  if (fill_runtime) {
+    // auto detect C++ vs Python runtime if unknown
+    if (runtime_type == RuntimeType::RUNTIME_TYPE_UNKNOWN) {
+      // default to C++ runtime
+      runtime_type = RuntimeType::RUNTIME_TYPE_CPP;
+      // unless the default model filename ends with '.py'
+      const static std::string py_model_extension = ".py";
+      const std::string& model_filename = config->default_model_filename();
+      if (model_filename.length() >= py_model_extension.length()) {
+        auto start_pos = model_filename.length() - py_model_extension.length();
+        if (model_filename.substr(start_pos) == py_model_extension) {
+          runtime_type = RuntimeType::RUNTIME_TYPE_PYTHON;
+        }
+      }
+    }
+    // set runtime library from runtime type
+    if (runtime_type == RuntimeType::RUNTIME_TYPE_CPP) {
+      if (config->backend().empty()) {
+        return Status(
+            Status::Code::INTERNAL,
+            "Model config 'backend' field cannot be empty when auto completing "
+            "for C++ 'runtime' field.");
+      }
+#ifdef _WIN32
+      config->set_runtime("triton_" + config->backend() + ".dll");
+#else
+      config->set_runtime("libtriton_" + config->backend() + ".so");
+#endif
+    } else if (runtime_type == RuntimeType::RUNTIME_TYPE_PYTHON) {
+      config->set_runtime(kPythonFilename);
+    } else {
+      return Status(Status::Code::INTERNAL, "Unimplemented runtime type.");
+    }
+  }
+  return Status::Success;
+}
+
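To make the new behavior concrete, here is a minimal, self-contained C++ sketch (illustration only, not part of the commit) of the rule AutoCompleteBackendRuntimeField applies: default to the backend's C++ shared library unless the default model filename ends in ".py", in which case the Python-based runtime stub is used. GuessRuntime is a hypothetical name introduced here for clarity.

#include <iostream>
#include <string>

// Hypothetical helper mirroring the suffix rule in
// AutoCompleteBackendRuntimeField; not a Triton symbol.
std::string
GuessRuntime(const std::string& backend, const std::string& model_filename)
{
  static const std::string py_ext = ".py";
  // Python-based runtime when the default model filename ends with ".py".
  if (model_filename.length() >= py_ext.length() &&
      model_filename.substr(model_filename.length() - py_ext.length()) ==
          py_ext) {
    return "model.py";  // kPythonFilename
  }
  // Otherwise the platform-specific C++ backend shared library.
#ifdef _WIN32
  return "triton_" + backend + ".dll";
#else
  return "libtriton_" + backend + ".so";
#endif
}

int
main()
{
  std::cout << GuessRuntime("onnxruntime", "model.onnx") << "\n";  // libtriton_onnxruntime.so
  std::cout << GuessRuntime("vllm", "model.py") << "\n";           // model.py
  return 0;
}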
21 changes: 18 additions & 3 deletions src/model_config_utils.h
@@ -42,6 +42,13 @@ enum BackendType {
   BACKEND_TYPE_PYTORCH = 4
 };
 
+/// Enumeration for different runtime types.
+enum RuntimeType {
+  RUNTIME_TYPE_UNKNOWN = 0,
+  RUNTIME_TYPE_CPP = 1,
+  RUNTIME_TYPE_PYTHON = 2
+};
+
 // Get version of a model from the path containing the model
 /// definition file.
 /// \param path The path to the model definition file.
@@ -87,9 +94,9 @@ Status GetNormalizedModelConfig(
     const std::string& model_name, const std::string& path,
     const double min_compute_capability, inference::ModelConfig* config);
 
-/// Auto-complete backend related fields (platform, backend and default model
-/// filename) if not set, note that only Triton recognized backends will be
-/// checked.
+/// Auto-complete backend related fields (platform, backend, default model
+/// filename and runtime) if not set, note that only Triton recognized backends
+/// will be checked.
 /// \param model_name The name of the model.
 /// \param model_path The full-path to the directory containing the
 /// model configuration.
@@ -99,6 +106,14 @@ Status AutoCompleteBackendFields(
     const std::string& model_name, const std::string& model_path,
     inference::ModelConfig* config);
 
+/// Auto-complete backend runtime field if not set.
+/// \param runtime_type Specify the runtime type for the model (C++ or Python).
+/// If unknown, it will be determined from the default model filename.
+/// \param config Returns the auto-completed model configuration.
+/// \return The error status.
+Status AutoCompleteBackendRuntimeField(
+    RuntimeType runtime_type, inference::ModelConfig* config);
+
 /// Detects and adds missing fields in the model configuration.
 /// \param min_compute_capability The minimum supported CUDA compute
 /// capability.
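For orientation, a hypothetical call site for the declaration above, assuming Triton core's build environment (the generated model_config.pb.h providing inference::ModelConfig, and the codebase's Status type; namespace wiring elided). FillVllmRuntime is an invented name, not part of the commit.

#include "model_config.pb.h"     // generated inference::ModelConfig (assumed)
#include "model_config_utils.h"  // RuntimeType, AutoCompleteBackendRuntimeField

// Hypothetical helper showing how the new API would be called.
Status
FillVllmRuntime(inference::ModelConfig* config)
{
  config->set_backend("vllm");
  config->set_default_model_filename("model.py");
  // RUNTIME_TYPE_UNKNOWN triggers the ".py" suffix detection, so on success
  // config->runtime() becomes "model.py" (kPythonFilename). The call is a
  // no-op if 'runtime' is already set and, when ensembles are enabled, for
  // models on the ensemble platform.
  return AutoCompleteBackendRuntimeField(
      RuntimeType::RUNTIME_TYPE_UNKNOWN, config);
}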
