From 6bd5a5a61e47ddb0863819ecf040f0b101887c58 Mon Sep 17 00:00:00 2001 From: Johannes Brachem <37882800+jobrachem@users.noreply.github.com> Date: Mon, 16 Dec 2024 00:01:41 +0100 Subject: [PATCH] allow transformation of weak vars --- liesel/model/nodes.py | 36 ++++++++++++++++++++++++++---------- tests/model/test_var.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 10 deletions(-) diff --git a/liesel/model/nodes.py b/liesel/model/nodes.py index bc86e3f..e856b7b 100644 --- a/liesel/model/nodes.py +++ b/liesel/model/nodes.py @@ -1596,8 +1596,8 @@ def transform(original_var: lsl.Var, bijector: tfb.Bijector): >>> scale.update().log_prob 0.0 """ - if self.weak: - raise RuntimeError(f"{repr(self)} is weak") + # if self.weak: + # raise RuntimeError(f"{repr(self)} is weak") if is_bijector_class(bijector) and not (bijector_args or bijector_kwargs): raise ValueError( @@ -2044,10 +2044,17 @@ def bijector_fn(value, dist_inputs, bijector_inputs): def _transform_var_without_dist_with_bijector_instance( var: Var, bijector_inst: jb.Bijector ) -> Var: - transformed_var = Var( - bijector_inst.inverse(var.value), - name=f"{var.name}_transformed", - ) + if var.strong: + transformed_var = Var( + bijector_inst.inverse(var.value), + name=f"{var.name}_transformed", + ) + else: + transformed_var = Var.new_calc( + bijector_inst.inverse, + var.value_node, + name=f"{var.name}_transformed", + ) var.value_node = Calc(bijector_inst.forward, transformed_var) @@ -2084,10 +2091,19 @@ def bijection_forward(x, *bjargs, **bjkwargs): bijector_inst = bijector_cls(*bjargs, **bjkwargs) return bijector_inst(x) - transformed_var = Var( - bijection_inverse(var.value, *args, **kwargs), - name=f"{var.name}_transformed", - ) + if var.strong: + transformed_var = Var( + bijection_inverse(var.value, *args, **kwargs), + name=f"{var.name}_transformed", + ) + else: + transformed_var = Var.new_calc( + bijection_inverse, + var.value_node, + *args, + **kwargs, + name=f"{var.name}_transformed", + ) var.value_node = Calc(bijection_forward, transformed_var, *args, **kwargs) diff --git a/tests/model/test_var.py b/tests/model/test_var.py index 95e4a56..631dae5 100644 --- a/tests/model/test_var.py +++ b/tests/model/test_var.py @@ -556,6 +556,40 @@ def test_new_const(self): class TestVarTransform: + def test_transform_weak_var_with_bijector_instance(self) -> None: + tau = lnodes.Var.new_param(10.0, name="tau") + tau_sqrt = lnodes.Var.new_calc(jnp.sqrt, tau) + log_tau_sqrt = tau_sqrt.transform(tfp.bijectors.Exp()) + + assert tau.value == pytest.approx(10.0) + assert tau_sqrt.value == pytest.approx(jnp.sqrt(10.0)) + assert log_tau_sqrt.value == pytest.approx(jnp.log(jnp.sqrt(10.0))) + + assert tau.strong + assert tau_sqrt.weak + assert log_tau_sqrt.weak + assert tau.parameter + assert not log_tau_sqrt.parameter + assert not tau_sqrt.parameter + + def test_transform_weak_var_with_bijector_class(self) -> None: + tau = lnodes.Var.new_param(10.0, name="tau") + tau_sqrt = lnodes.Var.new_calc(jnp.sqrt, tau) + + scale = lnodes.Var.new_param(2.0, name="bijector_scale") + scaled_tau_sqrt = tau_sqrt.transform(tfp.bijectors.Scale, scale=scale) + + assert tau.value == pytest.approx(10.0) + assert tau_sqrt.value == pytest.approx(jnp.sqrt(10.0)) + assert scaled_tau_sqrt.value == pytest.approx(jnp.sqrt(10.0) / 2) + + assert tau.strong + assert tau_sqrt.weak + assert scaled_tau_sqrt.weak + assert tau.parameter + assert not scaled_tau_sqrt.parameter + assert not tau_sqrt.parameter + def test_transform_without_dist_with_bijector_instance(self) -> None: tau = lnodes.Var.new_param(10.0, name="tau") log_tau = tau.transform(tfp.bijectors.Exp())