Skip to content

Commit

Permalink
replace
Browse files Browse the repository at this point in the history
  • Loading branch information
yf711 committed Jan 30, 2024
1 parent 2adcb66 commit 57f06cf
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions src/onnxruntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,22 +437,30 @@ ModelState::LoadModel(
OrtTensorRTProviderOptionsV2* trt_options;
THROW_IF_BACKEND_MODEL_ORT_ERROR(
ort_api->CreateTensorRTProviderOptions(&trt_options));
std::unique_ptr<
OrtTensorRTProviderOptionsV2,
decltype(ort_api->ReleaseTensorRTProviderOptions)>
rel_trt_options(
trt_options, ort_api->ReleaseTensorRTProviderOptions);
std::string int8_calibration_table_name;
std::string trt_engine_cache_path;
// Validate and set parameters
triton::common::TritonJson::Value params;
if (ea.Find("parameters", &params)) {
std::vector<std::string> param_keys;
std::vector<const char*> keys, values;
RETURN_IF_ERROR(params.Members(&param_keys));
for (const auto& param_key : param_keys) {
std::string value_string;
std::string value_string, key, value;
if (param_key == "precision_mode") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
if (value_string == "FP16") {
trt_options->trt_fp16_enable = 1;
key = "trt_fp16_enable";
value = "1"
} else if (value_string == "INT8") {
trt_options->trt_int8_enable = 1;
key = "trt_int8_enable";
value = "1"
} else if (value_string != "FP32") {
RETURN_ERROR_IF_FALSE(
false, TRITONSERVER_ERROR_INVALID_ARG,
Expand All @@ -465,33 +473,31 @@ ModelState::LoadModel(
size_t max_workspace_size_bytes;
RETURN_IF_ERROR(ParseUnsignedLongLongValue(
value_string, &max_workspace_size_bytes));
trt_options->trt_max_workspace_size =
max_workspace_size_bytes;
key = "trt_max_workspace_size";
value = value_string;
} else if (param_key == "int8_calibration_table_name") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &int8_calibration_table_name));
trt_options->trt_int8_calibration_table_name =
int8_calibration_table_name.c_str();
param_key.c_str(), &value));
key = "trt_int8_calibration_table_name";
} else if (param_key == "int8_use_native_calibration_table") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int use_native_calibration_table;
RETURN_IF_ERROR(ParseIntValue(
value_string, &use_native_calibration_table));
trt_options->trt_int8_use_native_calibration_table =
use_native_calibration_table;
bool use_native_calibration_table;
RETURN_IF_ERROR(ParseBoolValue(
value_string, &use_native_calibration_table));
key = "trt_int8_use_native_calibration_table";
value = value_string;
} else if (param_key == "trt_engine_cache_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool enable_cache;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &enable_cache));
trt_options->trt_engine_cache_enable = enable_cache;
key = "trt_engine_cache_enable";
value = value_string;
} else if (param_key == "trt_engine_cache_path") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &trt_engine_cache_path));
trt_options->trt_engine_cache_path =
trt_engine_cache_path.c_str();
RETURN_IF_ERROR(params.MemberAsString(param_key.c_str(), &value));
key = "trt_engine_cache_path";
} else {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
Expand All @@ -501,14 +507,14 @@ ModelState::LoadModel(
"Accelerator")
.c_str());
}
keys.push_back(key.c_str());
values.push_back(value.c_str());
}
ort_api->UpdateTensorRTProviderOptions(
rel_trt_options.get(), keys.data(), values.data(),
keys.size())
}

std::unique_ptr<
OrtTensorRTProviderOptionsV2,
decltype(ort_api->ReleaseTensorRTProviderOptions)>
rel_trt_options(
trt_options, ort_api->ReleaseTensorRTProviderOptions);

RETURN_IF_ORT_ERROR(
ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2(
static_cast<OrtSessionOptions*>(soptions),
Expand Down

0 comments on commit 57f06cf

Please sign in to comment.