Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce a new trait to represent types that can be used as output from a tensor #62

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 0 additions & 148 deletions onnxruntime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,154 +322,6 @@ impl Into<sys::GraphOptimizationLevel> for GraphOptimizationLevel {
}
}

// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved as-is into tensor.rs. I figured with this + the new trait it was making lib.rs pretty fat, and I think crate::tensor is a reasonable home for these types. WDYT?

// FIXME: Add tests to cover the commented out types
/// Enum mapping ONNX Runtime's supported tensor types
///
/// Commented-out variants exist in the ONNX Runtime C API but are not yet
/// supported by this crate (see the FIXME above).
#[derive(Debug)]
// The discriminant width must match the C enum's underlying integer type,
// which differs by platform in the generated bindings — presumably this
// mirrors `OnnxEnumInt`; TODO confirm against the sys crate.
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum TensorElementDataType {
    /// 32-bit floating point, equivalent to Rust's `f32`
    Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
    /// Unsigned 8-bit int, equivalent to Rust's `u8`
    Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
    /// Signed 8-bit int, equivalent to Rust's `i8`
    Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
    /// Unsigned 16-bit int, equivalent to Rust's `u16`
    Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
    /// Signed 16-bit int, equivalent to Rust's `i16`
    Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
    /// Signed 32-bit int, equivalent to Rust's `i32`
    Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
    /// Signed 64-bit int, equivalent to Rust's `i64`
    Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
    /// String, equivalent to Rust's `String`
    String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
    // /// Boolean, equivalent to Rust's `bool`
    // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
    // /// 16-bit floating point, equivalent to Rust's `f16`
    // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
    /// 64-bit floating point, equivalent to Rust's `f64`
    Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
    /// Unsigned 32-bit int, equivalent to Rust's `u32`
    Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
    /// Unsigned 64-bit int, equivalent to Rust's `u64`
    Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
    // /// Complex 64-bit floating point, equivalent to Rust's `???`
    // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt,
    // /// Complex 128-bit floating point, equivalent to Rust's `???`
    // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt,
    // /// Brain 16-bit floating point
    // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt,
}

/// Conversion into the raw ONNX Runtime element-type enum.
///
/// Implemented as `From` (rather than a hand-written `Into`) per Rust
/// convention: the standard library's blanket impl provides
/// `Into<sys::ONNXTensorElementDataType>` automatically, so existing
/// `.into()` call sites keep working.
impl From<TensorElementDataType> for sys::ONNXTensorElementDataType {
    fn from(val: TensorElementDataType) -> Self {
        use TensorElementDataType::*;
        match val {
            Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
            Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
            Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
            Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
            Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
            Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
            Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
            String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
            // Bool => {
            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
            // }
            // Float16 => {
            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
            // }
            Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
            Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
            Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
            // Complex64 => {
            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
            // }
            // Complex128 => {
            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
            // }
            // Bfloat16 => {
            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
            // }
        }
    }
}

/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`)
///
/// Implemented for primitive numeric types via the `impl_type_trait!` macro,
/// and for string-like types via the blanket impl over [Utf8Data].
pub trait TypeToTensorElementDataType {
    /// Return the ONNX type for a Rust type
    fn tensor_element_data_type() -> TensorElementDataType;

    /// If the type is `String`, returns `Some` with utf8 contents, else `None`.
    fn try_utf8_bytes(&self) -> Option<&[u8]>;
}

/// Implements [TypeToTensorElementDataType] for a primitive numeric type by
/// mapping it to the given [TensorElementDataType] variant.
macro_rules! impl_type_trait {
    ($type_:ty, $variant:ident) => {
        impl TypeToTensorElementDataType for $type_ {
            fn tensor_element_data_type() -> TensorElementDataType {
                TensorElementDataType::$variant
            }

            fn try_utf8_bytes(&self) -> Option<&[u8]> {
                // Numeric types never carry a utf8 payload; only string-like
                // types (via `Utf8Data`) return `Some`.
                None
            }
        }
    };
}

// Map each supported primitive Rust type to its ONNX tensor element type.
// Commented-out lines track element types not yet supported by the crate.
impl_type_trait!(f32, Float);
impl_type_trait!(u8, Uint8);
impl_type_trait!(i8, Int8);
impl_type_trait!(u16, Uint16);
impl_type_trait!(i16, Int16);
impl_type_trait!(i32, Int32);
impl_type_trait!(i64, Int64);
// impl_type_trait!(bool, Bool);
// impl_type_trait!(f16, Float16);
impl_type_trait!(f64, Double);
impl_type_trait!(u32, Uint32);
impl_type_trait!(u64, Uint64);
// impl_type_trait!(, Complex64);
// impl_type_trait!(, Complex128);
// impl_type_trait!(, Bfloat16);

/// Adapter for common Rust string types to Onnx strings.
///
/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but
/// we can't define an automatic implementation for anything that implements `AsRef<str>` as it
/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric
/// types (which might implement `AsRef<str>` at some point in the future).
///
/// Types implementing this trait get [TypeToTensorElementDataType] for free
/// via the blanket impl below.
pub trait Utf8Data {
    /// Returns the utf8 contents.
    fn utf8_bytes(&self) -> &[u8];
}

/// Owned strings expose their backing utf8 bytes directly.
impl Utf8Data for String {
    fn utf8_bytes(&self) -> &[u8] {
        self.as_str().as_bytes()
    }
}

/// Borrowed string slices likewise expose their backing utf8 bytes.
impl<'a> Utf8Data for &'a str {
    fn utf8_bytes(&self) -> &[u8] {
        let s: &str = self;
        s.as_bytes()
    }
}

impl<T: Utf8Data> TypeToTensorElementDataType for T {
fn tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::String
}

fn try_utf8_bytes(&self) -> Option<&[u8]> {
Some(self.utf8_bytes())
}
}

/// Allocator type
#[derive(Debug, Clone)]
#[repr(i32)]
Expand Down
112 changes: 73 additions & 39 deletions onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ use onnxruntime_sys as sys;
use crate::{
char_p_to_string,
environment::Environment,
error::{status_to_result, NonMatchingDimensionsError, OrtError, Result},
error::{call_ort, status_to_result, NonMatchingDimensionsError, OrtError, Result},
g_ort,
memory::MemoryInfo,
tensor::{
ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor},
OrtTensor,
ort_owned_tensor::OrtOwnedTensor, OrtTensor, TensorDataToType, TensorElementDataType,
TypeToTensorElementDataType,
},
AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType,
TypeToTensorElementDataType,
AllocatorType, GraphOptimizationLevel, MemType,
};

#[cfg(feature = "model-fetching")]
Expand Down Expand Up @@ -371,7 +370,7 @@ impl<'a> Session<'a> {
) -> Result<Vec<OrtOwnedTensor<'t, 'm, TOut, ndarray::IxDyn>>>
where
TIn: TypeToTensorElementDataType + Debug + Clone,
TOut: TypeToTensorElementDataType + Debug + Clone,
TOut: TensorDataToType,
D: ndarray::Dimension,
'm: 't, // 'm outlives 't (memory info outlives tensor)
's: 'm, // 's outlives 'm (session outlives memory info)
Expand Down Expand Up @@ -440,21 +439,30 @@ impl<'a> Session<'a> {
let outputs: Result<Vec<OrtOwnedTensor<TOut, ndarray::Dim<ndarray::IxDynImpl>>>> =
output_tensor_extractors_ptrs
.into_iter()
.map(|ptr| {
let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo =
std::ptr::null_mut();
let status = unsafe {
g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _)
};
status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?;
let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) };
unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) };
let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect();

let mut output_tensor_extractor =
OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims));
output_tensor_extractor.tensor_ptr = ptr;
output_tensor_extractor.extract::<TOut>()
.map(|tensor_ptr| {
let dims = unsafe {
call_with_tensor_info(tensor_ptr, |tensor_info_ptr| {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A little pre-cleanup -- in my WIP branch to extract string output, I needed to get multiple things out of the tensor info, so I made this helper to make it hard to forget to clean up tensor_info_ptr in all error handling cases. Soon this will be getting more than just dims out of the info ptr.

get_tensor_dimensions(tensor_info_ptr)
.map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>())
})
}?;

// Note: Both tensor and array will point to the same data, nothing is copied.
// As such, there is no need to free the pointer used to create the ArrayView.
assert_ne!(tensor_ptr, std::ptr::null_mut());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

taken from what was OrtOwnedTensorExtractor, the remaining contents of which now live mostly in the numeric type impls of TensorDataToType::extract_array


let mut is_tensor = 0;
unsafe { call_ort(|ort| ort.IsTensor.unwrap()(tensor_ptr, &mut is_tensor)) }
.map_err(OrtError::IsTensor)?;
assert_eq!(is_tensor, 1);

let array_view = TOut::extract_array(ndarray::IxDyn(&dims), tensor_ptr)?;

Ok(OrtOwnedTensor::new(
tensor_ptr,
array_view,
&memory_info_ref,
))
})
.collect();

Expand Down Expand Up @@ -554,25 +562,60 @@ unsafe fn get_tensor_dimensions(
tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo,
) -> Result<Vec<i64>> {
let mut num_dims = 0;
let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims);
status_to_result(status).map_err(OrtError::GetDimensionsCount)?;
call_ort(|ort| ort.GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims))
.map_err(OrtError::GetDimensionsCount)?;
assert_ne!(num_dims, 0);

let mut node_dims: Vec<i64> = vec![0; num_dims as usize];
let status = g_ort().GetDimensions.unwrap()(
tensor_info_ptr,
node_dims.as_mut_ptr(), // FIXME: UB?
num_dims,
);
status_to_result(status).map_err(OrtError::GetDimensions)?;
call_ort(|ort| {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just converting a few g_ort()'s to call_ort

ort.GetDimensions.unwrap()(
tensor_info_ptr,
node_dims.as_mut_ptr(), // FIXME: UB?
num_dims,
)
})
.map_err(OrtError::GetDimensions)?;
Ok(node_dims)
}

/// Reads the element data type out of a tensor's type-and-shape info.
///
/// # Safety
/// `tensor_info_ptr` must point to a valid `OrtTensorTypeAndShapeInfo`.
unsafe fn extract_data_type(
    tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo,
) -> Result<TensorElementDataType> {
    let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
    call_ort(|ort| ort.GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys))
        .map_err(OrtError::TensorElementType)?;
    // UNDEFINED would indicate the call failed to populate the out-param.
    assert_ne!(
        type_sys,
        sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
    );
    // This transmute should be safe since its value is read from GetTensorElementType which we must trust.
    Ok(std::mem::transmute(type_sys))
}

/// Calls the provided closure with the result of `GetTensorTypeAndShape`, deallocating the
/// resulting `*OrtTensorTypeAndShapeInfo` before returning.
///
/// The closure is invoked exactly once, so `FnOnce` suffices (a relaxation of
/// the previous `FnMut` bound — every `FnMut` closure still satisfies it).
///
/// # Safety
/// `tensor_ptr` must point to a valid `OrtValue`.
unsafe fn call_with_tensor_info<F, T>(tensor_ptr: *const sys::OrtValue, f: F) -> Result<T>
where
    F: FnOnce(*const sys::OrtTensorTypeAndShapeInfo) -> Result<T>,
{
    let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
    call_ort(|ort| ort.GetTensorTypeAndShape.unwrap()(tensor_ptr, &mut tensor_info_ptr as _))
        .map_err(OrtError::GetTensorTypeAndShape)?;

    // Run the closure while the info pointer is still valid; release after.
    let res = f(tensor_info_ptr);

    // ReleaseTensorTypeAndShapeInfo returns no status, so there are no errors to check.
    // NOTE(review): if `f` panics the info struct leaks; a drop guard would fix
    // that — TODO confirm whether panic-safety matters to callers here.
    g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr);

    res
}

/// This module contains dangerous functions working on raw pointers.
/// Those functions are only to be used from inside the
/// `SessionBuilder::with_model_from_file()` method.
mod dangerous {
use super::*;
use crate::tensor::TensorElementDataType;

pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> Result<u64> {
let f = g_ort().SessionGetInputCount.unwrap();
Expand Down Expand Up @@ -689,16 +732,7 @@ mod dangerous {
status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?;
assert_ne!(tensor_info_ptr, std::ptr::null_mut());

let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
let status =
unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) };
status_to_result(status).map_err(OrtError::TensorElementType)?;
assert_ne!(
type_sys,
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
);
// This transmute should be safe since its value is read from GetTensorElementType which we must trust.
let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) };
let io_type: TensorElementDataType = unsafe { extract_data_type(tensor_info_ptr)? };

// info!("{} : type={}", i, type_);

Expand Down
Loading