
Implement the principled initialisation #484

Open
mrTsjolder opened this issue Jan 16, 2024 · 4 comments
Labels: enhancement (New feature or request)

Comments

@mrTsjolder

Is your feature request related to a problem? Please describe.
Training input-convex neural networks can be slow.

Describe the solution you'd like
Good initialisations can accelerate learning in a variety of scenarios.
I published a paper at NeurIPS that proposes a principled weight initialisation for input-convex networks:
https://openreview.net/forum?id=pWZ97hUQtQ

Describe alternatives you've considered
@marcocuturi came to visit my poster and mentioned that an alternative approach is currently being used for initialisation in OTT.
However, he acknowledged that it might be a good idea to implement my principled initialisation in OTT as well.

Additional context
I might have some time to implement my proposed initialisation in this framework.
My main question is whether, and where, you would want this initialiser in the framework.

@michalk8 michalk8 added the enhancement New feature or request label Jan 16, 2024
@michalk8
Collaborator

Hi @mrTsjolder , and congrats on the acceptance!
Yes, we'd be happy if you added this initialization. I think for now, the simplest thing to do would be to add an initializers.py in src/ott/neural. There's an ongoing PR (#468) that further restructures the neural module (and adds many new functionalities), so the placement will most likely change later.

For inspiration, here's how the existing Gaussian initialization is implemented.
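For concreteness, here's a minimal sketch of what such an initializer could look like, following the standard `(key, shape, dtype)` convention used by `jax.nn.initializers` and Flax (the name and scaling below are just placeholders, not our actual Gaussian initializer):

```python
import jax
import jax.numpy as jnp


def scaled_normal_init(stddev: float = 1e-2):
  """Return an initializer sampling from N(0, stddev^2)."""

  def init(key, shape, dtype=jnp.float32):
    return stddev * jax.random.normal(key, shape, dtype)

  return init


# Usage when declaring a parameter inside a flax.linen.Module:
#   kernel = self.param("kernel", scaled_normal_init(0.05), (in_dim, out_dim))
```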

@marcocuturi
Contributor

Thanks a lot @mrTsjolder! Yes, we're definitely very interested in any principled ICNN init! You might have noticed that our most recent ICNN has a diagonal quadratic + low-rank block; I think these are interesting in their own right.

@mrTsjolder
Author

I am having a bit of a problem with the implementation:
There seems to be no way to couple the initialisation of weights and biases in flax.
Concretely, I would need access to the fan-in when initialising the bias parameters.

I could think of workarounds, but these all seem to require some sort of redesign of the code/API, which is probably not desired.
One idea I had was to use an initialisation builder function that takes the fan-in as input and returns two initialisation functions that can be passed to flax.linen.Module.param (rough sketch below).
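To make the idea concrete, here is a rough sketch of that builder (the names and the scaling factors are placeholders, not the actual formulas from the paper):

```python
from typing import Callable, Tuple

import jax
import jax.numpy as jnp

Initializer = Callable[..., jnp.ndarray]


def principled_init_builder(fan_in: int) -> Tuple[Initializer, Initializer]:
  """Build coupled kernel/bias initializers that both know the fan-in."""
  kernel_std = 1.0 / jnp.sqrt(fan_in)  # placeholder scaling, not the paper's formula
  bias_std = 1.0 / fan_in              # placeholder scaling, not the paper's formula

  def kernel_init(key, shape, dtype=jnp.float32):
    return kernel_std * jax.random.normal(key, shape, dtype)

  def bias_init(key, shape, dtype=jnp.float32):
    return bias_std * jax.random.normal(key, shape, dtype)

  return kernel_init, bias_init


# Inside a flax.linen.Module, where the fan-in is known from the input shape:
#   kernel_init, bias_init = principled_init_builder(fan_in=inputs.shape[-1])
#   kernel = self.param("kernel", kernel_init, (inputs.shape[-1], self.features))
#   bias = self.param("bias", bias_init, (self.features,))
```

This still requires every module to wire the two returned functions to the right parameters by hand, which is the kind of API change I was referring to.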

Does anyone have suggestions for possible workarounds?
Or is it okay for me to mess with the interface? If yes, to what extent?

PS: sorry for the late reply, but something else got in my way...

@michalk8
Collaborator

michalk8 commented Apr 3, 2024

Hi @mrTsjolder , sorry for the delayed response. Will re-read your paper and come up with ideas on how best to implement it!
