Skip to content

Commit

Permalink
Merge pull request #1799 from zdata-inc/export_tensorrt_v10
Browse files (browse the repository at this point in the history)
  • Loading branch information
mikel-brostrom authored Jan 24, 2025
2 parents 7f5dcea + 8c7aab1 commit 000cb0b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 21 deletions.
57 changes: 39 additions & 18 deletions boxmot/appearance/backends/tensorrt_backend.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,6 +7,7 @@

class TensorRTBackend(BaseModelBackend):
def __init__(self, weights, device, half):
self.is_trt10 = False
super().__init__(weights, device, half)
self.nhwc = False
self.half = half
Expand Down Expand Up @@ -40,22 +41,37 @@ def load_model(self, w):
self.context = self.model_.create_execution_context()
self.bindings = OrderedDict()

self.is_trt10 = not hasattr(self.model_, "num_bindings")
num = range(self.model_.num_io_tensors) if self.is_trt10 else range(self.model_.num_bindings)

# Parse bindings
for index in range(self.model_.num_bindings):
name = self.model_.get_binding_name(index)
dtype = trt.nptype(self.model_.get_binding_dtype(index))
is_input = self.model_.binding_is_input(index)
for index in num:
if self.is_trt10:
name = self.model_.get_tensor_name(index)
dtype = trt.nptype(self.model_.get_tensor_dtype(name))
is_input = self.model_.get_tensor_mode(name) == trt.TensorIOMode.INPUT
if is_input and -1 in tuple(self.model_.get_tensor_shape(name)):
self.context.set_input_shape(name, tuple(self.model_.get_tensor_profile_shape(name, 0)[1]))
if is_input and dtype == np.float16:
self.fp16 = True

shape = tuple(self.context.get_tensor_shape(name))

# Handle dynamic shapes
if is_input and -1 in self.model_.get_binding_shape(index):
profile_index = 0
min_shape, opt_shape, max_shape = self.model_.get_profile_shape(profile_index, index)
self.context.set_binding_shape(index, opt_shape)
else:
name = self.model_.get_binding_name(index)
dtype = trt.nptype(self.model_.get_binding_dtype(index))
is_input = self.model_.binding_is_input(index)

if is_input and dtype == np.float16:
self.fp16 = True
# Handle dynamic shapes
if is_input and -1 in self.model_.get_binding_shape(index):
profile_index = 0
min_shape, opt_shape, max_shape = self.model_.get_profile_shape(profile_index, index)
self.context.set_binding_shape(index, opt_shape)

shape = tuple(self.context.get_binding_shape(index))
if is_input and dtype == np.float16:
self.fp16 = True

shape = tuple(self.context.get_binding_shape(index))
data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(self.device)
self.bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))

Expand All @@ -64,12 +80,17 @@ def load_model(self, w):
def forward(self, im_batch):
# Adjust for dynamic shapes
if im_batch.shape != self.bindings["images"].shape:
i_in = self.model_.get_binding_index("images")
i_out = self.model_.get_binding_index("output")
self.context.set_binding_shape(i_in, im_batch.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im_batch.shape)
output_shape = tuple(self.context.get_binding_shape(i_out))
self.bindings["output"].data.resize_(output_shape)
if self.is_trt10:
self.context.set_input_shape("images", im_batch.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im_batch.shape)
self.bindings["output"].data.resize_(tuple(self.context.get_tensor_shape("output")))
else:
i_in = self.model_.get_binding_index("images")
i_out = self.model_.get_binding_index("output")
self.context.set_binding_shape(i_in, im_batch.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im_batch.shape)
output_shape = tuple(self.context.get_binding_shape(i_out))
self.bindings["output"].data.resize_(output_shape)

s = self.bindings["images"].shape
assert im_batch.shape == s, f"Input size {im_batch.shape} does not match model size {s}"
Expand Down
14 changes: 11 additions & 3 deletions boxmot/appearance/exporters/tensorrt_exporter.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,6 +19,7 @@ def export(self):

onnx_file = self.export_onnx()
LOGGER.info(f"\nStarting export with TensorRT {trt.__version__}...")
is_trt10 = int(trt.__version__.split(".")[0]) >= 10 # is TensorRT >= 10
assert onnx_file.exists(), f"Failed to export ONNX file: {onnx_file}"
f = self.file.with_suffix(".engine")
logger = trt.Logger(trt.Logger.INFO)
Expand All @@ -27,7 +28,11 @@ def export(self):

builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = self.workspace * 1 << 30
workspace = int(self.workspace * (1 << 30))
if is_trt10:
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
else: # TensorRT versions 7, 8
config.max_workspace_size = workspace

flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
Expand Down Expand Up @@ -62,8 +67,11 @@ def export(self):
if builder.platform_has_fast_fp16 and self.half:
config.set_flag(trt.BuilderFlag.FP16)
config.default_device_type = trt.DeviceType.GPU
with builder.build_engine(network, config) as engine, open(f, "wb") as t:
t.write(engine.serialize())

build = builder.build_serialized_network if is_trt10 else builder.build_engine
with build(network, config) as engine, open(f, "wb") as t:
t.write(engine if is_trt10 else engine.serialize())

return f


Expand Down

0 comments on commit 000cb0b

Please sign in to comment.