feat: Add BF16 tensor support via dlpack #371

Merged · 5 commits · Jul 30, 2024

Changes from all commits

README.md (4 additions, 0 deletions)

@@ -1557,6 +1557,10 @@ input0 = pb_utils.Tensor.from_dlpack("INPUT0", pytorch_tensor)
This method only supports contiguous Tensors that are in C-order. If the tensor
is not C-order contiguous an exception will be raised.

For python models with input or output tensors of type BFloat16 (BF16), the
`as_numpy()` method is not supported, and the `from_dlpack` and `to_dlpack`
methods must be used instead.

## `pb_utils.Tensor.is_cpu() -> bool`

This function can be used to check whether a tensor is placed in CPU or not.
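
The README note above boils down to the following usage pattern inside a model's `execute` method. This is a minimal sketch, not part of the PR: it assumes PyTorch is available in the backend environment and a model configured with a BF16 input `INPUT0` and a BF16 output `OUTPUT0` (illustrative names only).

```python
import torch
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def execute(self, requests):
        responses = []
        for request in requests:
            # BF16 tensors cannot go through as_numpy(); move the data into a
            # framework tensor via the DLPack protocol instead.
            in0 = pb_utils.get_input_tensor_by_name(request, "INPUT0")
            bf16_in = torch.from_dlpack(in0.to_dlpack())

            bf16_out = bf16_in * 2  # placeholder computation, stays in BF16

            # Wrap the result back into a Triton tensor via DLPack so the
            # BF16 dtype is preserved end to end.
            out0 = pb_utils.Tensor.from_dlpack("OUTPUT0", bf16_out)
            responses.append(pb_utils.InferenceResponse(output_tensors=[out0]))
        return responses
```
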
src/pb_stub_utils.cc (16 additions, 1 deletion)

@@ -1,4 +1,4 @@
// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -168,6 +168,8 @@ triton_to_pybind_dtype(TRITONSERVER_DataType data_type)
dtype_numpy = py::dtype(py::format_descriptor<uint8_t>::format());
break;
case TRITONSERVER_TYPE_BF16:
// NOTE: Currently skipping this call via `if (BF16)` check, but may
// want to better handle this or set some default/invalid dtype.
throw PythonBackendException("TYPE_BF16 not currently supported.");
case TRITONSERVER_TYPE_INVALID:
throw PythonBackendException("Dtype is invalid.");
@@ -240,6 +242,10 @@ triton_to_dlpack_type(TRITONSERVER_DataType triton_dtype)
case TRITONSERVER_TYPE_BYTES:
throw PythonBackendException(
"TYPE_BYTES tensors cannot be converted to DLPack.");
case TRITONSERVER_TYPE_BF16:
dl_code = DLDataTypeCode::kDLBfloat;
dt_size = 16;
break;

default:
throw PythonBackendException(
@@ -301,6 +307,15 @@ dlpack_to_triton_type(const DLDataType& data_type)
}
}

if (data_type.code == DLDataTypeCode::kDLBfloat) {
if (data_type.bits != 16) {
throw PythonBackendException(
"Expected BF16 tensor to have 16 bits, but had: " +
std::to_string(data_type.bits));
}
return TRITONSERVER_TYPE_BF16;
}

return TRITONSERVER_TYPE_INVALID;
}
}}} // namespace triton::backend::python
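
For reference (not part of the diff): DLPack encodes a dtype as a (type code, bits, lanes) triple, so the two mappings above translate `TRITONSERVER_TYPE_BF16` to and from `(kDLBfloat, 16, 1)`. A small framework-side sanity check of that round trip, assuming a recent PyTorch, looks like this:

```python
import torch
from torch.utils.dlpack import to_dlpack

# bfloat16 is described in DLPack as (code=kDLBfloat, bits=16, lanes=1);
# dlpack_to_triton_type() only accepts the 16-bit descriptor for BF16.
x = torch.ones(2, 3, dtype=torch.bfloat16)

# Round-trip through the DLPack protocol; the dtype survives because both
# directions agree on the (kDLBfloat, 16) encoding.
y = torch.from_dlpack(to_dlpack(x))
assert y.dtype == torch.bfloat16 and y.shape == x.shape
```
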
src/pb_tensor.cc (18 additions, 6 deletions)

@@ -1,4 +1,4 @@
// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -152,7 +152,10 @@ PbTensor::PbTensor(
#ifdef TRITON_PB_STUB
if (memory_type_ == TRITONSERVER_MEMORY_CPU ||
memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) {
if (dtype != TRITONSERVER_TYPE_BYTES) {
if (dtype == TRITONSERVER_TYPE_BF16) {
// No native numpy representation for BF16. DLPack should be used instead.
numpy_array_ = py::none();
} else if (dtype != TRITONSERVER_TYPE_BYTES) {
py::object numpy_array =
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
@@ -512,12 +515,18 @@ PbTensor::Name() const
const py::array*
PbTensor::AsNumpy() const
{
if (IsCPU()) {
return &numpy_array_;
} else {
if (!IsCPU()) {
throw PythonBackendException(
"Tensor is stored in GPU and cannot be converted to NumPy.");
}

if (dtype_ == TRITONSERVER_TYPE_BF16) {
throw PythonBackendException(
"Tensor dtype is BF16 and cannot be converted to NumPy. Use "
"to_dlpack() and from_dlpack() instead.");
}

return &numpy_array_;
}
#endif // TRITON_PB_STUB

@@ -643,7 +652,10 @@ PbTensor::PbTensor(
#ifdef TRITON_PB_STUB
if (memory_type_ == TRITONSERVER_MEMORY_CPU ||
memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) {
if (dtype_ != TRITONSERVER_TYPE_BYTES) {
if (dtype_ == TRITONSERVER_TYPE_BF16) {
// No native numpy representation for BF16. DLPack should be used instead.
numpy_array_ = py::none();
} else if (dtype_ != TRITONSERVER_TYPE_BYTES) {
py::object numpy_array =
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
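
With the `AsNumpy()` change above, calling `as_numpy()` on a BF16 tensor now raises an error that points to DLPack. If a NumPy array is genuinely needed, one workaround is to go through DLPack and upcast in the framework; the helper below is an illustrative sketch (the function name and the float32 upcast are choices of this example, not part of the PR).

```python
import torch


def bf16_tensor_to_numpy(pb_tensor):
    # `pb_tensor` is assumed to be a CPU pb_utils.Tensor holding BF16 data.
    # as_numpy() raises for BF16 because NumPy has no native bfloat16 dtype,
    # so view the data through DLPack and upcast before converting.
    torch_view = torch.from_dlpack(pb_tensor.to_dlpack())
    return torch_view.to(torch.float32).numpy()
```

The upcast copies the data; models that can stay in BF16 end to end should keep using `from_dlpack`/`to_dlpack` directly, as the README note added in this PR recommends.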