-
Notifications
You must be signed in to change notification settings - Fork 99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce a new trait to represent types that can be used as output from a tensor #62
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,15 +18,14 @@ use onnxruntime_sys as sys; | |
use crate::{ | ||
char_p_to_string, | ||
environment::Environment, | ||
error::{status_to_result, NonMatchingDimensionsError, OrtError, Result}, | ||
error::{call_ort, status_to_result, NonMatchingDimensionsError, OrtError, Result}, | ||
g_ort, | ||
memory::MemoryInfo, | ||
tensor::{ | ||
ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor}, | ||
OrtTensor, | ||
ort_owned_tensor::OrtOwnedTensor, OrtTensor, TensorDataToType, TensorElementDataType, | ||
TypeToTensorElementDataType, | ||
}, | ||
AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType, | ||
TypeToTensorElementDataType, | ||
AllocatorType, GraphOptimizationLevel, MemType, | ||
}; | ||
|
||
#[cfg(feature = "model-fetching")] | ||
|
@@ -371,7 +370,7 @@ impl<'a> Session<'a> { | |
) -> Result<Vec<OrtOwnedTensor<'t, 'm, TOut, ndarray::IxDyn>>> | ||
where | ||
TIn: TypeToTensorElementDataType + Debug + Clone, | ||
TOut: TypeToTensorElementDataType + Debug + Clone, | ||
TOut: TensorDataToType, | ||
D: ndarray::Dimension, | ||
'm: 't, // 'm outlives 't (memory info outlives tensor) | ||
's: 'm, // 's outlives 'm (session outlives memory info) | ||
|
@@ -440,21 +439,30 @@ impl<'a> Session<'a> { | |
let outputs: Result<Vec<OrtOwnedTensor<TOut, ndarray::Dim<ndarray::IxDynImpl>>>> = | ||
output_tensor_extractors_ptrs | ||
.into_iter() | ||
.map(|ptr| { | ||
let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = | ||
std::ptr::null_mut(); | ||
let status = unsafe { | ||
g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _) | ||
}; | ||
status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?; | ||
let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) }; | ||
unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) }; | ||
let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect(); | ||
|
||
let mut output_tensor_extractor = | ||
OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims)); | ||
output_tensor_extractor.tensor_ptr = ptr; | ||
output_tensor_extractor.extract::<TOut>() | ||
.map(|tensor_ptr| { | ||
let dims = unsafe { | ||
call_with_tensor_info(tensor_ptr, |tensor_info_ptr| { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A little pre-cleanup -- in my WIP branch to extract string output, I needed to get multiple things out of the tensor info, so I made this helper to make it hard to forget to clean up |
||
get_tensor_dimensions(tensor_info_ptr) | ||
.map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>()) | ||
}) | ||
}?; | ||
|
||
// Note: Both tensor and array will point to the same data, nothing is copied. | ||
// As such, there is no need to free the pointer used to create the ArrayView. | ||
assert_ne!(tensor_ptr, std::ptr::null_mut()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. taken from what was |
||
|
||
let mut is_tensor = 0; | ||
unsafe { call_ort(|ort| ort.IsTensor.unwrap()(tensor_ptr, &mut is_tensor)) } | ||
.map_err(OrtError::IsTensor)?; | ||
assert_eq!(is_tensor, 1); | ||
|
||
let array_view = TOut::extract_array(ndarray::IxDyn(&dims), tensor_ptr)?; | ||
|
||
Ok(OrtOwnedTensor::new( | ||
tensor_ptr, | ||
array_view, | ||
&memory_info_ref, | ||
)) | ||
}) | ||
.collect(); | ||
|
||
|
@@ -554,25 +562,60 @@ unsafe fn get_tensor_dimensions( | |
tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo, | ||
) -> Result<Vec<i64>> { | ||
let mut num_dims = 0; | ||
let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims); | ||
status_to_result(status).map_err(OrtError::GetDimensionsCount)?; | ||
call_ort(|ort| ort.GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims)) | ||
.map_err(OrtError::GetDimensionsCount)?; | ||
assert_ne!(num_dims, 0); | ||
|
||
let mut node_dims: Vec<i64> = vec![0; num_dims as usize]; | ||
let status = g_ort().GetDimensions.unwrap()( | ||
tensor_info_ptr, | ||
node_dims.as_mut_ptr(), // FIXME: UB? | ||
num_dims, | ||
); | ||
status_to_result(status).map_err(OrtError::GetDimensions)?; | ||
call_ort(|ort| { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just converting a few g_ort()'s to call_ort |
||
ort.GetDimensions.unwrap()( | ||
tensor_info_ptr, | ||
node_dims.as_mut_ptr(), // FIXME: UB? | ||
num_dims, | ||
) | ||
}) | ||
.map_err(OrtError::GetDimensions)?; | ||
Ok(node_dims) | ||
} | ||
|
||
unsafe fn extract_data_type( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. extracted a function for this logic as-is as I had need of it in one additional callsite for strings |
||
tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo, | ||
) -> Result<TensorElementDataType> { | ||
let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; | ||
call_ort(|ort| ort.GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys)) | ||
.map_err(OrtError::TensorElementType)?; | ||
assert_ne!( | ||
type_sys, | ||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED | ||
); | ||
// This transmute should be safe since its value is read from GetTensorElementType which we must trust. | ||
Ok(std::mem::transmute(type_sys)) | ||
} | ||
|
||
/// Calls the provided closure with the result of `GetTensorTypeAndShape`, deallocating the | ||
/// resulting `*OrtTensorTypeAndShapeInfo` before returning. | ||
unsafe fn call_with_tensor_info<F, T>(tensor_ptr: *const sys::OrtValue, mut f: F) -> Result<T> | ||
where | ||
F: FnMut(*const sys::OrtTensorTypeAndShapeInfo) -> Result<T>, | ||
{ | ||
let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); | ||
call_ort(|ort| ort.GetTensorTypeAndShape.unwrap()(tensor_ptr, &mut tensor_info_ptr as _)) | ||
.map_err(OrtError::GetTensorTypeAndShape)?; | ||
|
||
let res = f(tensor_info_ptr); | ||
|
||
// no return code, so no errors to check for | ||
g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr); | ||
|
||
res | ||
} | ||
|
||
/// This module contains dangerous functions working on raw pointers. | ||
/// Those functions are only to be used from inside the | ||
/// `SessionBuilder::with_model_from_file()` method. | ||
mod dangerous { | ||
use super::*; | ||
use crate::tensor::TensorElementDataType; | ||
|
||
pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> Result<u64> { | ||
let f = g_ort().SessionGetInputCount.unwrap(); | ||
|
@@ -689,16 +732,7 @@ mod dangerous { | |
status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?; | ||
assert_ne!(tensor_info_ptr, std::ptr::null_mut()); | ||
|
||
let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; | ||
let status = | ||
unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) }; | ||
status_to_result(status).map_err(OrtError::TensorElementType)?; | ||
assert_ne!( | ||
type_sys, | ||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED | ||
); | ||
// This transmute should be safe since its value is read from GetTensorElementType which we must trust. | ||
let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) }; | ||
let io_type: TensorElementDataType = unsafe { extract_data_type(tensor_info_ptr)? }; | ||
|
||
// info!("{} : type={}", i, type_); | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved as-is into
tensor.rs
. I figured with this + the new trait it was makinglib.rs
pretty fat, and I thinkcrate::tensor
is a reasonable home for these types. WDYT?