Sequential Macro #2565

Open

ImTheSquid wants to merge 3 commits into tracel-ai:main from ImTheSquid:main
@@ -0,0 +1,92 @@
/// Create a sequential neural network, similar to PyTorch's nn.Sequential.
///
/// To use this macro, separate your modules into three categories:
/// - Unit modules: modules that don't take any parameters (e.g. Relu, Sigmoid)
/// - Config modules: modules that take a config, but don't have a backend parameter (e.g. Dropout, LeakyRelu)
/// - Backend modules: modules that take a backend parameter (e.g. Linear)
///
/// List the modules of each category separated by commas, with semicolons between the categories, like so:
/// ```ignore
/// gen_sequential! {
///     // No config
///     Relu,
///     Sigmoid;
///     // Has config
///     DropoutConfig => Dropout,
///     LeakyReluConfig => LeakyRelu;
///     // Requires a backend (<B>)
///     LinearConfig => Linear
/// }
/// ```
///
/// If a particular category has no members, the semicolon is still needed:
/// ```ignore
/// gen_sequential! {
///     Relu,
///     Sigmoid;
///     // No config-only modules in this example
///     ;
///     LinearConfig => Linear
/// }
/// ```
///
/// After invoking the macro, use the generated types `SequentialConfig` and `Sequential<B>` in your code.
#[macro_export]
macro_rules! gen_sequential {
    ($($unit:tt),*; $($cfg:ty => $module:tt),*; $($bcfg:ty => $bmodule:tt),*) => {
        #[derive(Debug, burn::config::Config)]
        pub enum SequentialLayerConfig {
            $($unit,)*
            $($module($cfg),)*
            $($bmodule($bcfg),)*
        }

        #[derive(Debug, burn::config::Config)]
        pub struct SequentialConfig {
            pub layers: Vec<SequentialLayerConfig>
        }

        impl SequentialConfig {
            pub fn init<B: burn::prelude::Backend>(&self, device: &B::Device) -> Sequential<B> {
                Sequential {
                    layers: self.layers.iter().map(|l| match l {
                        $(SequentialLayerConfig::$unit => SequentialLayer::$unit($unit),)*
                        $(SequentialLayerConfig::$module(c) => SequentialLayer::$module(c.init()),)*
                        $(SequentialLayerConfig::$bmodule(c) => SequentialLayer::$bmodule(c.init(device)),)*
                    }).collect()
                }
            }
        }

        #[derive(Debug, burn::module::Module)]
        pub enum SequentialLayer<B: burn::prelude::Backend> {
            /// In case the expansion doesn't use any backend-based layers. This should never be used.
            _PhantomData(::core::marker::PhantomData<B>),
            $($unit($unit),)*
            $($module($module),)*
            $($bmodule($bmodule<B>),)*
        }

        #[derive(Debug, burn::module::Module)]
        pub struct Sequential<B: burn::prelude::Backend> {
            pub layers: Vec<SequentialLayer<B>>
        }

        impl<B: burn::prelude::Backend> Sequential<B> {
            pub fn forward<const D: usize>(&self, mut input: burn::tensor::Tensor<B, D>) -> burn::tensor::Tensor<B, D> {
                for layer in &self.layers {
                    input = match layer {
                        SequentialLayer::_PhantomData(_) => unreachable!("PhantomData should never be instantiated"),
                        $(SequentialLayer::$unit(u) => u.forward(input),)*
                        $(SequentialLayer::$module(m) => m.forward(input),)*
                        $(SequentialLayer::$bmodule(b) => b.forward(input),)*
                    };
                }

                input
            }
        }
    }
}

pub use gen_sequential;
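To make the generated API concrete, a small invocation such as `gen_sequential! { Relu;; LinearConfig => Linear }` expands to roughly the following. This is a sketch of the expansion for illustration, not part of the diff:

```rust
#[derive(Debug, burn::config::Config)]
pub enum SequentialLayerConfig {
    Relu,
    Linear(LinearConfig),
}

#[derive(Debug, burn::config::Config)]
pub struct SequentialConfig {
    pub layers: Vec<SequentialLayerConfig>,
}

impl SequentialConfig {
    pub fn init<B: burn::prelude::Backend>(&self, device: &B::Device) -> Sequential<B> {
        Sequential {
            layers: self
                .layers
                .iter()
                .map(|l| match l {
                    // Unit modules are constructed directly, config modules via init.
                    SequentialLayerConfig::Relu => SequentialLayer::Relu(Relu),
                    SequentialLayerConfig::Linear(c) => SequentialLayer::Linear(c.init(device)),
                })
                .collect(),
        }
    }
}

#[derive(Debug, burn::module::Module)]
pub enum SequentialLayer<B: burn::prelude::Backend> {
    _PhantomData(::core::marker::PhantomData<B>),
    Relu(Relu),
    Linear(Linear<B>),
}

// ...plus `Sequential<B>` and its `forward`, exactly as written in the macro body above.
```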
@@ -0,0 +1,33 @@
use burn_core::nn::{
    gen_sequential, Dropout, DropoutConfig, LeakyRelu, LeakyReluConfig, Linear, LinearConfig, Relu,
};

use burn_core as burn;

gen_sequential! {
    Relu;
    DropoutConfig => Dropout,
    LeakyReluConfig => LeakyRelu;
    LinearConfig => Linear
}

type TestBackend = burn_ndarray::NdArray;

#[test]
fn sequential_should_construct() {
    let cfg = SequentialConfig {
        layers: vec![
            SequentialLayerConfig::Relu,
            SequentialLayerConfig::Dropout(DropoutConfig { prob: 0.3 }),
            SequentialLayerConfig::LeakyRelu(LeakyReluConfig {
                negative_slope: 0.01,
            }),
            SequentialLayerConfig::Linear(LinearConfig::new(10, 10)),
        ],
    };

    let device = Default::default();

    let module: Sequential<TestBackend> = cfg.init(&device);
    assert_eq!(module.layers.len(), 4);
}
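A natural follow-up (hypothetical, not part of this PR) would be a test that also exercises `forward` on the generated module and checks that shapes propagate. It reuses the `gen_sequential!` invocation and `TestBackend` defined above:

```rust
#[test]
fn sequential_should_forward() {
    use burn_core::tensor::{Distribution, Tensor};

    let cfg = SequentialConfig {
        layers: vec![
            SequentialLayerConfig::Linear(LinearConfig::new(10, 4)),
            SequentialLayerConfig::Relu,
            SequentialLayerConfig::Dropout(DropoutConfig { prob: 0.3 }),
        ],
    };

    let device = Default::default();
    let module: Sequential<TestBackend> = cfg.init(&device);

    // Batch of 2 samples with 10 features each; Relu and Dropout preserve the shape.
    let input = Tensor::<TestBackend, 2>::random([2, 10], Distribution::Default, &device);
    let output = module.forward(input);
    assert_eq!(output.dims(), [2, 4]);
}
```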
I would use pattern matching to differentiate modules that require a backend and config:
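(The suggested snippet isn't captured in this extract.) For illustration only, one reading of "pattern matching" here is a per-entry helper whose rules distinguish `Name`, `Name(Config)`, and `Name(Config)<B>`. The entry syntax and helper below are hypothetical, not code from the PR or the review:

```rust
// Hypothetical helper (illustrative syntax, not from the PR): each rule matches one
// entry shape and picks the matching construction for the generated enum.
macro_rules! init_layer {
    // Unit module, e.g. `Relu`: construct the unit struct directly.
    ($cfg_val:expr, $device:expr; $name:ident) => {
        SequentialLayer::$name($name)
    };
    // Config-only module, e.g. `Dropout(DropoutConfig)`: call `.init()`.
    ($cfg_val:expr, $device:expr; $name:ident($cfg:ty)) => {
        SequentialLayer::$name($cfg_val.init())
    };
    // Backend module, e.g. `Linear(LinearConfig)<B>`: call `.init(device)`.
    ($cfg_val:expr, $device:expr; $name:ident($cfg:ty)<B>) => {
        SequentialLayer::$name($cfg_val.init($device))
    };
}
```

The sticking point, as the reply below explains, is folding a single flat list of such entries into one enum and one match without resorting to a proc macro.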
While this would be my ideal method for doing this, it would require me to rewrite it as a proc macro, since there isn't a great way in declarative macros to differentiate depending on whether a value is present or not (specifically the whole `($cfg$(; $($generic),+))?`). Separating the entries into multiple blocks lets me know that the structs in a given block need `.init()` called if they have a config, or `.init(device)` called if they have a backend-dependent config, without needing actual Rust code to differentiate them. A proc macro would fix this, but would also require a new crate just for the macro and some more advanced parsing techniques.

Additionally, how would multiple generics work? Do all generics need to be unique? If I define `A` as a generic across `Custom2` and `Custom3`, is it the same generic? I assume it would be. We would probably designate `B` as reserved, meaning "needs a device passed to it on initialization".

Should I look into rewriting this as a more complicated but more flexible proc macro, or should I keep the simpler but slightly more rigid declarative macro?
Hmm, I actually don't think the macro helps much, and I don't think we should implement a proc macro. Maybe the real solution would be to implement a trait `Forward` instead of simply having a method. We could then support tuples as sequential layers. The `Forward` trait would be totally decoupled from the `Module` trait and only used to simplify composing multiple `forward` methods.
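For what that could look like, here is a minimal sketch of a `Forward` trait with a tuple impl, assuming burn's `Tensor<B, D>` type; the names and signatures are illustrative rather than a settled design:

```rust
use burn::prelude::Backend;
use burn::tensor::Tensor;

/// Hypothetical trait, decoupled from `Module`: anything that maps a tensor to a tensor.
pub trait Forward<B: Backend, const D: usize> {
    fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D>;
}

/// Tuples compose by forwarding through each element in order; impls for longer tuples
/// would follow the same pattern (typically generated by a small macro).
impl<B: Backend, const D: usize, L0, L1> Forward<B, D> for (L0, L1)
where
    L0: Forward<B, D>,
    L1: Forward<B, D>,
{
    fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        self.1.forward(self.0.forward(input))
    }
}
```

A sequential stack would then just be a tuple such as `(Linear<B>, Relu, Linear<B>)`, and composing `forward` calls no longer needs an enum variant per layer type.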