Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make transform a method #174

Merged
merged 16 commits into from
Sep 25, 2024
Merged

Make transform a method #174

merged 16 commits into from
Sep 25, 2024

Conversation

jobrachem
Copy link
Contributor

@jobrachem jobrachem commented Dec 27, 2023

This PR introduces a Var.transform method. Some notes:

Replacing GraphBuilder.transform

We currently have GraphBuilder.transform. The new method is in fact intended as a replacement for GraphBuilder.transform. Having Var.transform as a method on Var has the advantage that it is easier to find for users. Since the transformation is "doing something to a Var" and it does not actually require any functionality within the GraphBuilder, living as a method on a Var is a natural development for the transform method.

Behavior change

The method behaves similar to GraphBuilder.transform, with a few notable differences:

  1. It will not use the default event space bijector from tensorflow by default. This change is made to encourage users to either select their desired bijector manually, which is often sensible, or to request an automatic bijector manually. In both cases, users are more aware of what they are doing.
  2. The new method accepts bijector instances in addition to bijector classes. In fact, passing an instance is the preferred way of passing a bijector. Passing a bijector class is only supported if you actually defined *bijector_args or *bijector_kwargs to be passed to the bijector. This simplifies the code for the default case. More importantly, this fixes the graph representation after transformation, see below.

Using a bijector instance:

image

Compare to using the default event space bijector. Note that in this case, the same bijector is being used, but there are spurious edges from the nodes "v0" and "v1", the prior parameters, to the original variable "tau".

image

Deprecation and updated documentation

  1. I marked GraphBuilder.transform as deprecated and included directions towards the new method in its documentation.
  2. I updated the usage of the transformation method in the tutorials 01a-transform.md and 04-mcycle.md. As it turned out, usage in these tutorials was outdated anyway.

Notebook for testing

You can play around with the method in this notebook:

050-transform.ipynb.zip

Related issues

@jobrachem jobrachem added enhancement New feature or request comp:model This issue is related to the model module labels Dec 27, 2023
@jobrachem jobrachem self-assigned this Dec 27, 2023
@jobrachem
Copy link
Contributor Author

A log message on the default bijector being used would be nice.

@jobrachem
Copy link
Contributor Author

@jobrachem will transfer the core benefits of the Var.transform method to the GraphBuilder.transform: 1) accepting a bijector instance, 2) logging the default event space bijector

@jobrachem
Copy link
Contributor Author

@jobrachem will continue as follows:

  1. Implement Var.transform
  2. Make API between Var.transform and GraphBuilder.transform consistent
  3. Add a note to the docs of GraphBuilder.transform, referring to Var.transform

@jobrachem jobrachem force-pushed the make-transform-a-method branch from 29f0749 to 75c7383 Compare August 14, 2024 20:17
@jobrachem
Copy link
Contributor Author

I now harmonized the API between Var.transform and GraphBuilder.transform. Namely:

  1. Var.transform now has the same default behavior as GraphBuilder.transform, i.e. it tries to use the default event space bijector from tensorflow probability.
  2. GraphBuilder.transform now also accepts a bijector instance.

There is one small difference in how the methods handle it, if the user passes a bijector class: Var.transform will raise a RuntimeError, if you pass a bijector class, but do not actually use any arguments for the bijector. GraphBuilder.transform will emit a UserWarning instead, in order to avoid breaking existing code.

@jobrachem jobrachem marked this pull request as ready for review August 14, 2024 22:50
@jobrachem jobrachem requested a review from wiep August 28, 2024 14:35
@jobrachem
Copy link
Contributor Author

@wiep gentle reminder :)

Copy link
Contributor

@wiep wiep left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the work.

I made a few suggestions to make the if statement better comprehensible. I think I would handle the bijector is None case such that I would cast it into a bijector is instance of BijectorType case. Than you need to handle only two cases. See annotations below.

However, both is a bit of a style question. Should you prefer the current version. I'm also ok with that.

Also, please feel free to revert the commit changing doc strings if you do not agree (dfb5217).

liesel/model/model.py Outdated Show resolved Hide resolved
liesel/model/model.py Outdated Show resolved Hide resolved
liesel/model/nodes.py Show resolved Hide resolved
liesel/model/nodes.py Outdated Show resolved Hide resolved
liesel/model/nodes.py Outdated Show resolved Hide resolved
liesel/model/nodes.py Show resolved Hide resolved
@jobrachem
Copy link
Contributor Author

jobrachem commented Sep 25, 2024

Hey @wiep, thanks a lot for your suggestions! I implemented them with only slight deviations. Your change to the docstring is fine with me, too 👍

Your comment about removing the code duplication also alerted me to two issues:

  1. Var.transform lacked a test for whether the default event space bijector exists. I added that test.
  2. There was a subtle error in my implementation of using the default event space bijector. In the case of using the default bijector, we want to account for the possibility that the default bijector can change depending on the distribution's arguments. So the default bijector has to be obtained and initialized anew for every .update() call. This was not the case in my previous implementation. Not it is, so this problem is fixed.

@jobrachem jobrachem merged commit 255d8a7 into main Sep 25, 2024
4 checks passed
@jobrachem jobrachem deleted the make-transform-a-method branch September 25, 2024 13:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
comp:model This issue is related to the model module enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add transformation method to lsl.Var Refactor GraphBuilder.transform() method
2 participants