Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/Split ONNX Import #2568

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ represent the corresponding Burn Op.
| [Softplus][170] | ❌ | ❌ |
| [Softsign][171] | ❌ | ❌ |
| [SpaceToDepth][172] | ❌ | ❌ |
| [Split][173] | | |
| [Split][173] | ✅ | ✅ |
| [SplitToSequence][174] | ❌ | ❌ |
| [Sqrt][175] | ✅ | ✅ |
| [Squeeze][176] | ✅ | ✅ |
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ fn main() {
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
.input("tests/split/split.onnx")
.out_dir("model/")
.run_from_script();

Expand Down
Binary file added crates/burn-import/onnx-tests/tests/split/split.onnx
Binary file not shown.
40 changes: 40 additions & 0 deletions crates/burn-import/onnx-tests/tests/split/split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/usr/bin/env python3

# used to generate model: split.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    """Minimal module whose forward pass splits its input with ``torch.split``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Split ``x`` into chunks of size 2 along dim 0 and return the tuple of chunks."""
        # torch.split takes a chunk size (not a chunk count); the last chunk is
        # smaller when the dimension is not divisible by 2.
        return torch.split(x, 2)


def main():
    """Export the split model to ``split.onnx`` and print a summary of the test data."""
    # Set seed for reproducibility
    torch.manual_seed(42)

    torch.set_printoptions(precision=8)

    model = Model()
    model.eval()
    device = torch.device("cpu")

    file_name = "split.onnx"
    # 5 rows split into chunks of 2 leaves a smaller final chunk (2, 2, 1).
    test_input = torch.arange(10, device=device).reshape(5, 2)
    torch.onnx.export(model, test_input, file_name,
                      verbose=False, opset_version=16)
    print(f"Finished exporting model to {file_name}")

    print(f"Test input data shape: {test_input.shape}")
    print("Splitting input tensor into chunks of size 2")
    output = model.forward(test_input)
    print(f"Test output data length: {len(output)}")


if __name__ == '__main__':
    main()
6 changes: 3 additions & 3 deletions crates/burn-import/onnx-tests/tests/squeeze/squeeze.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# used to generate models: squeeze_opset13.onnx,
# used to generate models: squeeze_opset13.onnx,
# squeeze_opset16.onnx, and squeeze_multiple.onnx

import torch
Expand Down Expand Up @@ -35,7 +35,7 @@ def main():
torch.onnx.export(model, test_input, "squeeze_opset16.onnx", verbose=False, opset_version=16)
torch.onnx.export(model, test_input, "squeeze_opset13.onnx", verbose=False, opset_version=13)

print(f"Finished exporting model to 16 and 13")
print("Finished exporting model to 16 and 13")

# Output some test data for use in the test
output = model(test_input)
Expand All @@ -56,7 +56,7 @@ def main():
onnx.checker.check_model(m, full_check=True)
onnx.save(m, "squeeze_multiple.onnx")

print(f"Finished exporting model with multiple squeeze axes specified to 13")
print("Finished exporting model with multiple squeeze axes specified to 13")

if __name__ == "__main__":
main()
14 changes: 13 additions & 1 deletion crates/burn-import/onnx-tests/tests/test_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ include_models!(
transpose,
unsqueeze,
unsqueeze_opset11,
unsqueeze_opset16
unsqueeze_opset16,
split
);

#[cfg(test)]
Expand Down Expand Up @@ -2214,4 +2215,15 @@ mod tests {
assert!(i_output.equal(i_expected).all().into_scalar());
assert!(b_output.equal(b_expected).all().into_scalar());
}

#[test]
fn split() {
    let device = Default::default();
    let model = split::Model::<Backend>::new(&device);
    // Same shape as the tensor used to export `split.onnx`.
    let shape = [5, 2];
    let input = Tensor::ones(shape, &device);

    let split_tensors = model.forward(input);
    assert_eq!(split_tensors.len(), 3);

    // Splitting 5 rows into chunks of 2 must leave a smaller final chunk;
    // checking only the count would miss wrongly sized chunks.
    let chunk_dims: Vec<[usize; 2]> = split_tensors.iter().map(|chunk| chunk.dims()).collect();
    assert_eq!(chunk_dims, vec![[2, 2], [2, 2], [1, 2]]);
}
}
5 changes: 4 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::{
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode,
random_uniform::RandomUniformNode, random_uniform_like::RandomUniformLikeNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, split::SplitNode,
squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
};
Expand Down Expand Up @@ -116,6 +116,7 @@ pub enum Node<PS: PrecisionSettings> {
Resize(ResizeNode),
Slice(SliceNode),
Squeeze(SqueezeNode),
Split(SplitNode),
Sum(SumNode),
Tile(TileNode),
Trilu(TriluNode),
Expand Down Expand Up @@ -179,6 +180,7 @@ macro_rules! match_all {
Node::RandomUniform(node) => $func(node),
Node::RandomUniformLike(node) => $func(node),
Node::ConstantOfShape(node) => $func(node),
Node::Split(node) => $func(node),
_ => unimplemented!(),
}
}};
Expand Down Expand Up @@ -239,6 +241,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::RandomUniform(_) => "random_uniform",
Node::RandomUniformLike(_) => "random_uniform_like",
Node::ConstantOfShape(_) => "constant_of_shape",
Node::Split(_) => "split",
_ => unimplemented!(),
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub(crate) mod range;
pub(crate) mod reshape;
pub(crate) mod resize;
pub(crate) mod slice;
pub(crate) mod split;
pub(crate) mod squeeze;
pub(crate) mod sum;
pub(crate) mod tile;
Expand Down
137 changes: 137 additions & 0 deletions crates/burn-import/src/burn/node/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use super::{Node, NodeCodegen};
use crate::burn::{OtherType, Scope, TensorType, ToTokens, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

/// Configuration of a `Split` node, produced from the ONNX node by `split_config`.
///
/// Code generation uses `split_sizes` when it is `Some` and only consults
/// `num_outputs` otherwise.
#[derive(Config, Debug)]
pub struct SplitConfig {
/// Axis to split along.
pub axis: usize,
/// Number of outputs, taken from the ONNX `num_outputs` attribute or, when
/// neither that attribute nor explicit sizes are available, from the node's
/// output count.
pub num_outputs: Option<usize>,
/// Explicit chunk sizes, read from the node's second input when it carries
/// constant `Int64s` data.
pub split_sizes: Option<Vec<usize>>,
}

/// Graph node that emits the code splitting one tensor into several tensors.
#[derive(Debug, Clone, new)]
pub struct SplitNode {
/// Tensor being split.
pub input: TensorType,
/// Output tensors of the ONNX node; the rank of the first one types the
/// generated `Vec<Tensor<B, D>>`.
pub outputs: Vec<TensorType>,
/// Axis and chunking parameters used by code generation.
pub config: SplitConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for SplitNode {
    /// Declares a single `Vec<Tensor<B, D>>` output named `split_tensors`, where
    /// `D` is the rank of the first output tensor.
    ///
    /// # Panics
    ///
    /// Panics if the node has no outputs.
    fn output_types(&self) -> Vec<Type> {
        let rank = self
            .outputs
            .first()
            .expect("Split: at least one output tensor is required")
            .dim;
        let rank_literal = proc_macro2::Literal::usize_unsuffixed(rank);

        vec![Type::Other(OtherType {
            name: syn::Ident::new("split_tensors", proc_macro2::Span::call_site()),
            ty: quote! {Vec<Tensor<B, #rank_literal>>},
        })]
    }

    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    /// Emits the statement binding `split_tensors`.
    ///
    /// With explicit `split_sizes` the generated code calls `split_with_sizes`.
    /// Otherwise the input is divided into `num_outputs` chunks: `Tensor::split`
    /// expects the size of each chunk rather than the number of chunks, so the
    /// chunk size is derived at runtime as `ceil(dim / num_outputs)`, which
    /// leaves a smaller final chunk when the axis is not evenly divisible.
    ///
    /// # Panics
    ///
    /// Panics if neither `split_sizes` nor `num_outputs` is set, or if
    /// `num_outputs` is zero.
    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let axis = self.config.axis.to_tokens();
        let split_tensors = syn::Ident::new("split_tensors", proc_macro2::Span::call_site());

        if let Some(split_sizes) = &self.config.split_sizes {
            let split_sizes_tokens = split_sizes.to_tokens();
            return quote! {
                let #split_tensors = #input.split_with_sizes(#split_sizes_tokens, #axis);
            };
        }

        let num_outputs = self
            .config
            .num_outputs
            .expect("Split: either 'split_sizes' or 'num_outputs' must be set");
        // A zero count would make the generated `div_ceil` divide by zero at runtime.
        assert!(
            num_outputs > 0,
            "Split: 'num_outputs' must be greater than zero"
        );
        let num_outputs_tokens = num_outputs.to_tokens();

        quote! {
            let #split_tensors = {
                let input = #input;
                let split_size = input.dims()[#axis].div_ceil(#num_outputs_tokens);
                input.split(split_size, #axis)
            };
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Split(self)
    }
}

#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{split::SplitNode, test::assert_tokens},
        TensorType,
    };

    #[test]
    fn test_codegen_split() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(SplitNode::new(
            TensorType::new_float("tensor1", 2),
            vec![
                TensorType::new_float("tensor2", 2),
                TensorType::new_float("tensor3", 2),
            ],
            SplitConfig {
                axis: 0,
                num_outputs: Some(2),
                split_sizes: None,
            },
        ));

        graph.register_input_output(
            vec!["tensor1".to_string()],
            vec!["split_tensors".to_string()],
        );

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model <B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }

                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(
                    &self,
                    tensor1: Tensor<B, 2>,
                ) -> Vec<Tensor<B, 2>> {
                    let split_tensors = {
                        let input = tensor1;
                        let split_size = input.dims()[0].div_ceil(2);
                        input.split(split_size, 0)
                    };

                    split_tensors
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
55 changes: 54 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use burn::nn::{
};

use crate::burn::node::{
expand::ExpandShape, pad::PadConfig, tile::TileConfig, trilu::TriluConfig,
expand::ExpandShape, pad::PadConfig, split::SplitConfig, tile::TileConfig, trilu::TriluConfig,
};
use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node};

Expand Down Expand Up @@ -1805,3 +1805,56 @@ pub fn squeeze_config(curr: &Node) -> Vec<i64> {

axes
}

/// Builds the [`SplitConfig`] for an ONNX `Split` node.
///
/// * `axis` comes from the `axis` attribute (default 0); a negative value is
///   offset by the rank of the first input.
/// * `split_sizes` is read from the second input when it carries constant
///   `Int64s` data.
/// * `num_outputs` comes from the `num_outputs` attribute, or falls back to the
///   node's output count when neither explicit sizes nor the attribute exist.
///
/// # Panics
///
/// Panics if the first input is missing or is not a tensor, if both explicit
/// sizes and `num_outputs` are given, or if the axis, `num_outputs` or any
/// split size is negative (these previously wrapped silently to huge `usize`
/// values).
pub fn split_config(node: &Node) -> SplitConfig {
    // Borrow the input type; only its rank is needed to normalise the axis.
    let first_input = node
        .inputs
        .first()
        .expect("Split: an input tensor is required");
    let rank = match &first_input.ty {
        ArgType::Tensor(tensor) => tensor.dim as i64,
        _ => panic!("Only tensor input is valid"),
    };

    // Axis to split along (default is 0 per ONNX spec)
    let mut axis: i64 = 0;
    let mut num_outputs: Option<usize> = None;

    for (key, value) in node.attrs.iter() {
        match key.as_str() {
            "axis" => axis = value.clone().into_i64(),
            "num_outputs" => {
                let requested = value.clone().into_i64();
                num_outputs = Some(usize::try_from(requested).unwrap_or_else(|_| {
                    panic!(
                        "Split: 'num_outputs' must be non-negative, got {}",
                        requested
                    )
                }));
            }
            _ => {}
        }
    }

    if axis < 0 {
        axis += rank;
    }
    let axis = usize::try_from(axis).unwrap_or_else(|_| {
        panic!(
            "Split: axis is out of range for a tensor of rank {}",
            rank
        )
    });

    // Explicit sizes are only available when the second input is a constant.
    let split_sizes: Option<Vec<usize>> =
        match node.inputs.get(1).and_then(|arg| arg.value.as_ref()) {
            Some(Data::Int64s(sizes)) => Some(
                sizes
                    .iter()
                    .map(|&size| {
                        usize::try_from(size).unwrap_or_else(|_| {
                            panic!("Split: 'split' sizes must be non-negative, got {}", size)
                        })
                    })
                    .collect(),
            ),
            _ => None,
        };

    let num_outputs = match (split_sizes.is_some(), num_outputs) {
        // Only one of 'split_sizes' or 'num_outputs' may be provided
        (true, Some(_)) => panic!(
            "Split: Either 'split' input or 'num_outputs' attribute should be specified, but not both."
        ),
        // Neither was provided: infer the count from the node's outputs
        (false, None) => Some(node.outputs.len()),
        (_, requested) => requested,
    };

    SplitConfig {
        axis,
        num_outputs,
        split_sizes,
    }
}
12 changes: 11 additions & 1 deletion crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ use crate::{
reshape::ReshapeNode,
resize::ResizeNode,
slice::SliceNode,
split::SplitNode,
squeeze::SqueezeNode,
sum::SumNode,
tile::TileNode,
Expand All @@ -72,7 +73,7 @@ use super::op_configuration::{
linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, pad_config,
reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config,
reduce_sum_config, reshape_config, resize_config, shape_config, slice_config, softmax_config,
squeeze_config, tile_config, transpose_config, trilu_config, unsqueeze_config,
split_config, squeeze_config, tile_config, transpose_config, trilu_config, unsqueeze_config,
};
use onnx_ir::{
convert_constant_value,
Expand Down Expand Up @@ -356,6 +357,7 @@ impl ParsedOnnxGraph {
NodeType::ConstantOfShape => {
graph.register(Self::constant_of_shape_conversion(node))
}
NodeType::Split => graph.register(Self::split_conversion(node)),
node_type => unsupported_ops.push(node_type),
}
}
Expand Down Expand Up @@ -1264,6 +1266,14 @@ impl ParsedOnnxGraph {
let config = trilu_config(&node);
TriluNode::new(input, output, config)
}

fn split_conversion(node: Node) -> SplitNode {
let input = TensorType::from(node.inputs.first().unwrap());
let outputs = node.outputs.iter().map(TensorType::from).collect();
let config = split_config(&node);

SplitNode::new(input, outputs, config)
}
}

/// Extract data from node states and convert it to `TensorData`.
Expand Down
Loading
Loading