From 85fb6ef31e5d07c7fa96483b296c64e6b4756442 Mon Sep 17 00:00:00 2001
From: Jack Hogan
Date: Thu, 28 Nov 2024 14:56:38 -0500
Subject: [PATCH 1/3] Added sequential macro

---
 crates/burn-core/src/nn/mod.rs        |  2 +
 crates/burn-core/src/nn/sequential.rs | 89 +++++++++++++++++++++++++++
 2 files changed, 91 insertions(+)
 create mode 100644 crates/burn-core/src/nn/sequential.rs

diff --git a/crates/burn-core/src/nn/mod.rs b/crates/burn-core/src/nn/mod.rs
index ac428c3063..b3897760ee 100644
--- a/crates/burn-core/src/nn/mod.rs
+++ b/crates/burn-core/src/nn/mod.rs
@@ -33,6 +33,7 @@ mod prelu;
 mod relu;
 mod rnn;
 mod rope_encoding;
+mod sequential;
 mod sigmoid;
 mod swiglu;
 mod tanh;
@@ -52,6 +53,7 @@ pub use prelu::*;
 pub use relu::*;
 pub use rnn::*;
 pub use rope_encoding::*;
+pub use sequential::*;
 pub use sigmoid::*;
 pub use swiglu::*;
 pub use tanh::*;
diff --git a/crates/burn-core/src/nn/sequential.rs b/crates/burn-core/src/nn/sequential.rs
new file mode 100644
index 0000000000..54361ff7d8
--- /dev/null
+++ b/crates/burn-core/src/nn/sequential.rs
@@ -0,0 +1,89 @@
+/// Create a sequential neural network, similar to PyTorch's nn.Sequential.
+///
+/// To use this macro, separate your modules into three categories:
+/// - Unit modules: Modules that don't take any parameters (eg. Relu, Sigmoid)
+/// - Modules: Modules that take parameters, but don't have a backend parameter (eg. Dropout, LeakyRelu)
+/// - Backend modules: Modules that take a backend parameter (eg. Linear)
+///
+/// List the modules in each class separated by commas, with semicolons between the classes, like so:
+/// ```
+/// gen_sequential! {
+///     // No config
+///     Relu,
+///     Sigmoid;
+///     // Has config
+///     DropoutConfig => Dropout,
+///     LeakyReluConfig => LeakyRelu;
+///     // Requires a backend ()
+///     LinearConfig => Linear
+/// }
+/// ```
+///
+/// If there aren't any members of a particular class, the semicolon is still needed
+/// ```
+/// gen_sequential! {
+///     Relu,
+///     Sigmoid;
+///     // Nothing with no config
+///     ;
+///     LinearConfig -> Linear
+/// }
+///
+/// To use this macro, use the type `SequentialConfig` and `Sequential` in your code.
+#[macro_export]
+macro_rules! gen_sequential {
+    ($($unit:tt),*; $($cfg:ty => $module:tt),*; $($bcfg:ty => $bmodule:tt),*) => {
+        #[derive(Debug, ::burn::config::Config)]
+        pub enum SequentialLayerConfig {
+            $($unit,)*
+            $($module($cfg),)*
+            $($bmodule($bcfg),)*
+        }
+
+        #[derive(Debug, ::burn::config::Config)]
+        pub struct SequentialConfig {
+            pub layers: Vec<SequentialLayerConfig>
+        }
+
+        impl SequentialConfig {
+            pub fn init<B: ::burn::tensor::backend::Backend>(&self, device: &B::Device) -> Sequential<B> {
+                Sequential {
+                    layers: self.layers.iter().map(|l| match l {
+                        $(SequentialLayerConfig::$unit => SequentialLayer::$unit($unit),)*
+                        $(SequentialLayerConfig::$module(c) => SequentialLayer::$module(c.init()),)*
+                        $(SequentialLayerConfig::$bmodule(c) => SequentialLayer::$bmodule(c.init(device)),)*
+                    }).collect()
+                }
+            }
+        }
+
+        #[derive(Debug, ::burn::module::Module)]
+        pub enum SequentialLayer<B: ::burn::tensor::backend::Backend> {
+            /// In case the expansion doesn't use any backend-based layers. This should never be used.
+            _PhantomData(::core::marker::PhantomData<B>),
+            $($unit($unit),)*
+            $($module($module),)*
+            $($bmodule($bmodule<B>),)*
+        }
+
+        #[derive(Debug, ::burn::module::Module)]
+        pub struct Sequential<B: ::burn::tensor::backend::Backend> {
+            pub layers: Vec<SequentialLayer<B>>
+        }
+
+        impl<B: ::burn::tensor::backend::Backend> Sequential<B> {
+            pub fn forward<const D: usize>(&self, mut input: ::burn::tensor::Tensor<B, D>) -> ::burn::tensor::Tensor<B, D> {
+                for layer in &self.layers {
+                    input = match layer {
+                        SequentialLayer::_PhantomData(_) => unreachable!("PhantomData should never be instantiated"),
+                        $(SequentialLayer::$unit(u) => u.forward(input),)*
+                        $(SequentialLayer::$module(m) => m.forward(input),)*
+                        $(SequentialLayer::$bmodule(b) => b.forward(input),)*
+                    };
+                }
+
+                input
+            }
+        }
+    }
+}

From 97aec98279b1b02be99446cdfb8cc601c7de463b Mon Sep 17 00:00:00 2001
From: Jack Hogan
Date: Thu, 28 Nov 2024 15:21:16 -0500
Subject: [PATCH 2/3] Slight export fixes

---
 crates/burn-core/src/nn/sequential.rs | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/crates/burn-core/src/nn/sequential.rs b/crates/burn-core/src/nn/sequential.rs
index 54361ff7d8..3b8b472c09 100644
--- a/crates/burn-core/src/nn/sequential.rs
+++ b/crates/burn-core/src/nn/sequential.rs
@@ -19,17 +19,18 @@
 /// }
 /// ```
 ///
-/// If there aren't any members of a particular class, the semicolon is still needed
+/// If there aren't any members of a particular class, the semicolon is still needed:
 /// ```
 /// gen_sequential! {
 ///     Relu,
 ///     Sigmoid;
 ///     // Nothing with no config
 ///     ;
-///     LinearConfig -> Linear
+///     LinearConfig => Linear
 /// }
+/// ```
 ///
-/// To use this macro, use the type `SequentialConfig` and `Sequential` in your code.
+/// To use this macro, use the types `SequentialConfig` and `Sequential` in your code.
 #[macro_export]
 macro_rules! gen_sequential {
     ($($unit:tt),*; $($cfg:ty => $module:tt),*; $($bcfg:ty => $bmodule:tt),*) => {
@@ -87,3 +88,5 @@ macro_rules! gen_sequential {
         }
     }
 }
+
+pub use gen_sequential;

From 0c019b21f4f38db055a299d85f7095a887b342c0 Mon Sep 17 00:00:00 2001
From: Jack Hogan
Date: Fri, 29 Nov 2024 12:57:29 -0500
Subject: [PATCH 3/3] Added unit test, fixed doc test

---
 crates/burn-core/src/nn/sequential.rs         | 22 ++++++------
 crates/burn-core/tests/test_gen_sequential.rs | 33 +++++++++++++++++++
 2 files changed, 44 insertions(+), 11 deletions(-)
 create mode 100644 crates/burn-core/tests/test_gen_sequential.rs

diff --git a/crates/burn-core/src/nn/sequential.rs b/crates/burn-core/src/nn/sequential.rs
index 3b8b472c09..64a3dc68d4 100644
--- a/crates/burn-core/src/nn/sequential.rs
+++ b/crates/burn-core/src/nn/sequential.rs
@@ -6,7 +6,7 @@
 /// - Backend modules: Modules that take a backend parameter (eg. Linear)
 ///
 /// List the modules in each class separated by commas, with semicolons between the classes, like so:
-/// ```
+/// ```ignore
 /// gen_sequential! {
 ///     // No config
 ///     Relu,
@@ -20,7 +20,7 @@
 /// ```
 ///
 /// If there aren't any members of a particular class, the semicolon is still needed:
-/// ```
+/// ```ignore
 /// gen_sequential! {
 ///     Relu,
 ///     Sigmoid;
@@ -34,20 +34,20 @@
 #[macro_export]
 macro_rules! gen_sequential {
     ($($unit:tt),*; $($cfg:ty => $module:tt),*; $($bcfg:ty => $bmodule:tt),*) => {
-        #[derive(Debug, ::burn::config::Config)]
+        #[derive(Debug, burn::config::Config)]
         pub enum SequentialLayerConfig {
             $($unit,)*
             $($module($cfg),)*
             $($bmodule($bcfg),)*
         }
 
-        #[derive(Debug, ::burn::config::Config)]
+        #[derive(Debug, burn::config::Config)]
         pub struct SequentialConfig {
             pub layers: Vec<SequentialLayerConfig>
         }
 
         impl SequentialConfig {
-            pub fn init<B: ::burn::tensor::backend::Backend>(&self, device: &B::Device) -> Sequential<B> {
+            pub fn init<B: burn::tensor::backend::Backend>(&self, device: &B::Device) -> Sequential<B> {
                 Sequential {
                     layers: self.layers.iter().map(|l| match l {
                         $(SequentialLayerConfig::$unit => SequentialLayer::$unit($unit),)*
@@ -58,8 +58,8 @@ macro_rules! gen_sequential {
             }
         }
 
-        #[derive(Debug, ::burn::module::Module)]
-        pub enum SequentialLayer<B: ::burn::tensor::backend::Backend> {
+        #[derive(Debug, burn::module::Module)]
+        pub enum SequentialLayer<B: burn::tensor::backend::Backend> {
             /// In case the expansion doesn't use any backend-based layers. This should never be used.
             _PhantomData(::core::marker::PhantomData<B>),
             $($unit($unit),)*
@@ -67,13 +67,13 @@ macro_rules! gen_sequential {
             $($bmodule($bmodule<B>),)*
         }
 
-        #[derive(Debug, ::burn::module::Module)]
-        pub struct Sequential<B: ::burn::tensor::backend::Backend> {
+        #[derive(Debug, burn::module::Module)]
+        pub struct Sequential<B: burn::tensor::backend::Backend> {
             pub layers: Vec<SequentialLayer<B>>
         }
 
-        impl<B: ::burn::tensor::backend::Backend> Sequential<B> {
-            pub fn forward<const D: usize>(&self, mut input: ::burn::tensor::Tensor<B, D>) -> ::burn::tensor::Tensor<B, D> {
+        impl<B: burn::tensor::backend::Backend> Sequential<B> {
+            pub fn forward<const D: usize>(&self, mut input: burn::tensor::Tensor<B, D>) -> burn::tensor::Tensor<B, D> {
                 for layer in &self.layers {
                     input = match layer {
                         SequentialLayer::_PhantomData(_) => unreachable!("PhantomData should never be instantiated"),
diff --git a/crates/burn-core/tests/test_gen_sequential.rs b/crates/burn-core/tests/test_gen_sequential.rs
new file mode 100644
index 0000000000..11b729f725
--- /dev/null
+++ b/crates/burn-core/tests/test_gen_sequential.rs
@@ -0,0 +1,33 @@
+use burn_core::nn::{
+    gen_sequential, Dropout, DropoutConfig, LeakyRelu, LeakyReluConfig, Linear, LinearConfig, Relu,
+};
+
+use burn_core as burn;
+
+gen_sequential! {
+    Relu;
+    DropoutConfig => Dropout,
+    LeakyReluConfig => LeakyRelu;
+    LinearConfig => Linear
+}
+
+type TestBackend = burn_ndarray::NdArray;
+
+#[test]
+fn sequential_should_construct() {
+    let cfg = SequentialConfig {
+        layers: vec![
+            SequentialLayerConfig::Relu,
+            SequentialLayerConfig::Dropout(DropoutConfig { prob: 0.3 }),
+            SequentialLayerConfig::LeakyRelu(LeakyReluConfig {
+                negative_slope: 0.01,
+            }),
+            SequentialLayerConfig::Linear(LinearConfig::new(10, 10)),
+        ],
+    };
+
+    let device = Default::default();
+
+    let module: Sequential<TestBackend> = cfg.init(&device);
+    assert_eq!(module.layers.len(), 4);
+}
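
A possible follow-up, not part of the patches above: the construction test could be paired with a forward-pass check on the generated `Sequential`. The sketch below is illustrative only; it assumes the same `gen_sequential!` invocation, imports, and `TestBackend` alias as `test_gen_sequential.rs`, and the zero-filled `[1, 10]` input is an arbitrary shape chosen to match `LinearConfig::new(10, 10)`.

```rust
// Hypothetical addition to test_gen_sequential.rs; relies on the items generated
// by the gen_sequential! call above (SequentialConfig, SequentialLayerConfig, Sequential).
use burn_core::tensor::Tensor;

#[test]
fn sequential_should_forward() {
    // Relu -> Linear(10, 10): every layer in this stack preserves the input shape.
    let cfg = SequentialConfig {
        layers: vec![
            SequentialLayerConfig::Relu,
            SequentialLayerConfig::Linear(LinearConfig::new(10, 10)),
        ],
    };

    let device = Default::default();
    let module: Sequential<TestBackend> = cfg.init(&device);

    // Zero-filled batch of one 10-feature sample (illustrative values only).
    let input = Tensor::<TestBackend, 2>::zeros([1, 10], &device);
    let output = module.forward(input);

    assert_eq!(output.dims(), [1, 10]);
}
```

Because every `forward` arm in this particular stack preserves the tensor's shape, the assertion only checks dimensions rather than values.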