Skip to content

Commit

Permalink
initial support for burn tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
sunny-g committed Mar 11, 2023
1 parent a80751a commit f348fc3
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ image = {version = "^0.23", optional = true}
nalgebra = {version = "^0.26", optional = true}
num-traits = {version = "^0.2", optional = true}
ndarray = {version = "^0.15", optional = true}
# burn-tensor = {version = "0.5", optional = true}
burn-tensor = {git = "https://github.com/burn-rs/burn", optional = true}

[dev-dependencies]
mime = "0.3.14"
Expand All @@ -31,3 +33,4 @@ all = ["show_nalgebra", "show_ndarray", "show_image"]
show_nalgebra = ["nalgebra", "num-traits"]
show_ndarray = ["ndarray"]
show_image = ["image"]
show_burn = ["burn-tensor"]
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ mod show_ndarray;
#[cfg(feature = "show_image")]
mod show_image;

#[cfg(feature = "show_burn")]
mod show_burn;

use anyhow::Error;
use std::path::Path;
use std::path::PathBuf;
Expand Down
93 changes: 93 additions & 0 deletions src/show_burn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::ContentInfo;
use crate::Showable;
use anyhow::Error;
use burn_tensor;
use core::fmt::{Debug, Write};

impl<B, const D: usize, K> Showable for burn_tensor::Tensor<B, D, K>
where
B: burn_tensor::backend::Backend,
K: burn_tensor::TensorKind<B> + burn_tensor::BasicOps<B>,
K::Elem: Debug,
{
fn to_content_info(&self) -> Result<ContentInfo, Error> {
self.to_data().to_content_info()
}
}

const START: &str = "tensor([";
const END: &str = "])";

impl<P, const D: usize> Showable for burn_tensor::Data<P, D>
where
P: Debug,
{
fn to_content_info(&self) -> Result<ContentInfo, Error> {
let dims = self.shape.dims;
let data = &self.value;

// if single scalar, just print
if dims.len() == 1 && dims[0] == 1 {
return Ok(ContentInfo {
mime_type: "text/plain".into(),
content: format!("tensor({:?})", data[0]),
});
}

let col_size = match dims.len() {
1 => dims[0],
2 => dims[1],
_ => unimplemented!("only supports 1D/2D tensors"),
};
let row_size = data.len() / col_size;

let mut out = String::new();
let padding = " ".repeat(8);

write!(out, "{}", START)?;
for (r, row) in data.as_slice().chunks(col_size).enumerate() {
let (pre, post) = match r {
0 => ("", ",\n"),
_ if r < row_size - 1 => (padding.as_str(), ",\n"),
_ => (padding.as_str(), END),
};

write!(out, "{}{:?}{}", pre, row, post)?;
}

Ok(ContentInfo {
mime_type: "text/plain".into(),
content: out,
})
}
}

#[cfg(test)]
mod tests {
use crate::Showable;

#[test]
fn test_no_crash_on_3x4() {
use burn_tensor::Tensor;
let m = Tensor::<_, 2>::zeros([3, 4]);
m.show().unwrap();
}

#[test]
fn test_no_crash_on_3x1() {
use burn_tensor::Tensor;
let m = Tensor::<_, 1>::zeros([3]);
m.show().unwrap();
}
}

0 comments on commit f348fc3

Please sign in to comment.