Sequential Macro #2565

Open · wants to merge 3 commits into main
2 changes: 2 additions & 0 deletions crates/burn-core/src/nn/mod.rs
@@ -33,6 +33,7 @@
mod prelu;
mod relu;
mod rnn;
mod rope_encoding;
mod sequential;
mod sigmoid;
mod swiglu;
mod tanh;
@@ -52,6 +53,7 @@
pub use prelu::*;
pub use relu::*;
pub use rnn::*;
pub use rope_encoding::*;
pub use sequential::*;
pub use sigmoid::*;
pub use swiglu::*;
pub use tanh::*;
92 changes: 92 additions & 0 deletions crates/burn-core/src/nn/sequential.rs
@@ -0,0 +1,92 @@
/// Generate a sequential neural network container, similar to PyTorch's nn.Sequential.
///
/// To use this macro, separate your modules into three categories:
/// - Unit modules: modules that take no configuration (e.g. Relu, Sigmoid)
/// - Modules: modules that take a configuration but no backend parameter (e.g. Dropout, LeakyRelu)
/// - Backend modules: modules that take a backend type parameter (e.g. Linear)
///
/// List the modules comma-separated within each category, with a semicolon between categories, like so:
/// ```ignore
/// gen_sequential! {
///     // No config
///     Relu,
///     Sigmoid;
///     // Has config
///     DropoutConfig => Dropout,
///     LeakyReluConfig => LeakyRelu;
///     // Requires a backend (<B>)
///     LinearConfig => Linear
/// }
/// ```
Member commented on lines +10 to +20:
I would use pattern matching to differentiate modules that require a backend and/or a config:

Sequential!(
    Relu, // Without config
    Dropout(DropoutConfig), // With config
    Linear(LinearConfig; B), // With config + Generics
    Custom(; B), // No config + Generics
    Custom2(config; B, A, C), // With config + many generics
)

Author replied:

While this would be my ideal interface, it would require rewriting this as a proc macro: in declarative macros there isn't a great way to branch on whether an optional fragment is present (specifically the whole `($cfg$(; $($generic),+))?`). Separating the modules into blocks lets me know that the structs in a given block need `.init()` called if they have a config, or `.init(device)` if their config is backend-dependent, without needing actual Rust code to differentiate them. A proc macro would fix this, but it would also require a new crate just for the macro and some more advanced parsing techniques.

Additionally, how would multiple generics work? Do all generics need to be unique? If I define A as a generic across Custom2 and Custom3, is it the same generic? I assume it would be. We would probably designate B as reserved, meaning "needs a device passed to it on initialization".

Should I look into rewriting this as a more complicated but more flexible proc macro or should I keep the simpler but slightly more rigid declarative macro?

Member replied:

Hmm, I actually don't think the macro helps much, and I don't think we should implement a proc macro. Maybe the real solution would be to implement a `Forward` trait instead of simply having a method. We could then support tuples as sequential layers. The `Forward` trait would be totally decoupled from the `Module` trait and only used to simplify composing multiple forward methods.
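
To make the suggestion concrete, here is a minimal sketch of what such a trait could look like; the trait definition and the tuple impl below are assumptions for illustration, not an existing burn API:

```rust
use burn::prelude::Backend;
use burn::tensor::Tensor;

/// Hypothetical trait: anything that maps a tensor to a tensor.
/// Deliberately decoupled from `Module`; it exists only to compose forward passes.
pub trait Forward<B: Backend, const D: usize> {
    fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D>;
}

/// Pairs compose sequentially: the first element's output feeds the second.
impl<B: Backend, const D: usize, T0, T1> Forward<B, D> for (T0, T1)
where
    T0: Forward<B, D>,
    T1: Forward<B, D>,
{
    fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        self.1.forward(self.0.forward(input))
    }
}
```

With only the pair impl, longer chains can be written as nested pairs, e.g. `(relu, (dropout, linear))`; impls for larger tuple arities could be generated with a small macro.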

///
/// If a particular category has no members, its semicolon is still required:
/// ```ignore
/// gen_sequential! {
///     Relu,
///     Sigmoid;
///     // No config-only modules
///     ;
///     LinearConfig => Linear
/// }
/// ```
///
/// After invoking the macro, use the generated types `SequentialConfig` and `Sequential<B>` in your code.
#[macro_export]
macro_rules! gen_sequential {
    ($($unit:tt),*; $($cfg:ty => $module:tt),*; $($bcfg:ty => $bmodule:tt),*) => {
        #[derive(Debug, burn::config::Config)]
        pub enum SequentialLayerConfig {
            $($unit,)*
            $($module($cfg),)*
            $($bmodule($bcfg),)*
        }

        #[derive(Debug, burn::config::Config)]
        pub struct SequentialConfig {
            pub layers: Vec<SequentialLayerConfig>
        }

        impl SequentialConfig {
            pub fn init<B: burn::prelude::Backend>(&self, device: &B::Device) -> Sequential<B> {
                Sequential {
                    layers: self.layers.iter().map(|l| match l {
                        $(SequentialLayerConfig::$unit => SequentialLayer::$unit($unit),)*
                        $(SequentialLayerConfig::$module(c) => SequentialLayer::$module(c.init()),)*
                        $(SequentialLayerConfig::$bmodule(c) => SequentialLayer::$bmodule(c.init(device)),)*
                    }).collect()
                }
            }
        }

        #[derive(Debug, burn::module::Module)]
        pub enum SequentialLayer<B: burn::prelude::Backend> {
            /// In case the expansion doesn't use any backend-based layers. This should never be used.
            _PhantomData(::core::marker::PhantomData<B>),
            $($unit($unit),)*
            $($module($module),)*
            $($bmodule($bmodule<B>),)*
        }

        #[derive(Debug, burn::module::Module)]
        pub struct Sequential<B: burn::prelude::Backend> {
            pub layers: Vec<SequentialLayer<B>>
        }

        impl<B: burn::prelude::Backend> Sequential<B> {
            pub fn forward<const D: usize>(&self, mut input: burn::tensor::Tensor<B, D>) -> burn::tensor::Tensor<B, D> {
                for layer in &self.layers {
                    input = match layer {
                        SequentialLayer::_PhantomData(_) => unreachable!("PhantomData should never be instantiated"),
                        $(SequentialLayer::$unit(u) => u.forward(input),)*
                        $(SequentialLayer::$module(m) => m.forward(input),)*
                        $(SequentialLayer::$bmodule(b) => b.forward(input),)*
                    };
                }

                input
            }
        }
    }
}

pub use gen_sequential;
33 changes: 33 additions & 0 deletions crates/burn-core/tests/test_gen_sequential.rs
@@ -0,0 +1,33 @@
use burn_core::nn::{
    gen_sequential, Dropout, DropoutConfig, LeakyRelu, LeakyReluConfig, Linear, LinearConfig, Relu,
};

use burn_core as burn;

gen_sequential! {
    Relu;
    DropoutConfig => Dropout,
    LeakyReluConfig => LeakyRelu;
    LinearConfig => Linear
}

type TestBackend = burn_ndarray::NdArray;

#[test]
fn sequential_should_construct() {
    let cfg = SequentialConfig {
        layers: vec![
            SequentialLayerConfig::Relu,
            SequentialLayerConfig::Dropout(DropoutConfig { prob: 0.3 }),
            SequentialLayerConfig::LeakyRelu(LeakyReluConfig {
                negative_slope: 0.01,
            }),
            SequentialLayerConfig::Linear(LinearConfig::new(10, 10)),
        ],
    };

    let device = Default::default();

    let module: Sequential<TestBackend> = cfg.init(&device);
    assert_eq!(module.layers.len(), 4);
}
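
The test only checks construction; a natural extension (not part of this PR, with shapes assumed for illustration) would be to push a tensor through the generated `forward` method at the end of the test:

```rust
// Hypothetical addition at the end of sequential_should_construct:
// run a forward pass and check the output shape. Linear(10, 10) maps a
// [batch, 10] input to a [batch, 10] output, and Relu/Dropout/LeakyRelu
// are shape-preserving.
let input = burn::tensor::Tensor::<TestBackend, 2>::zeros([1, 10], &device);
let output = module.forward(input);
assert_eq!(output.dims(), [1, 10]);
```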