diff --git a/liesel/model/nodes.py b/liesel/model/nodes.py
index bc86e3f..b02f7b6 100644
--- a/liesel/model/nodes.py
+++ b/liesel/model/nodes.py
@@ -756,8 +756,10 @@ def __init__(
     ):
         super().__init__(*inputs, **kwinputs, _name=_name, _needs_seed=_needs_seed)
         self._function = function
+        # Stored so that calculators derived from this node can reuse the setting.
+        self.update_on_init = update_on_init
 
-        if update_on_init:
+        if self.update_on_init:
             try:
                 self.update()
             except Exception as e:
@@ -1596,8 +1598,6 @@ def transform(original_var: lsl.Var, bijector: tfb.Bijector):
         >>> scale.update().log_prob
         0.0
         """
-        if self.weak:
-            raise RuntimeError(f"{repr(self)} is weak")
 
         if is_bijector_class(bijector) and not (bijector_args or bijector_kwargs):
             raise ValueError(
@@ -1971,11 +1971,42 @@ def transform_dist(*args, **kwargs):
 
     transformed_dist.per_obs = var.dist_node.per_obs
 
-    transformed_var = Var(
-        bijector_inv.forward(var.value),
-        transformed_dist,
-        name=f"{var.name}_transformed",
-    )
+    # For a weak var, wrap the original calculator function so the transformed
+    # value is recomputed from the same inputs instead of stored as a constant.
+    if var.weak:
+        try:
+            value_function = var.value_node.function  # type: ignore
+        except AttributeError as e:
+            raise AttributeError(
+                "Trying to transform a weak variable without calculator node."
+            ) from e
+
+        def forward(*args, **kwargs):
+            return bijector_inv.forward(value_function(*args, **kwargs))
+
+        value_inputs = var.value_node.inputs
+        value_kwinputs = var.value_node.kwinputs
+        value_node_needs_seed = var.value_node.needs_seed
+        value_node_update_on_init = var.value_node.update_on_init  # type: ignore
+
+        transformed_var = Var(
+            Calc(
+                forward,
+                *value_inputs,
+                _name="",
+                _needs_seed=value_node_needs_seed,
+                update_on_init=value_node_update_on_init,
+                **value_kwinputs,
+            ),
+            transformed_dist,
+            name=f"{var.name}_transformed",
+        )
+    else:
+        transformed_var = Var(
+            bijector_inv.forward(var.value),
+            transformed_dist,
+            name=f"{var.name}_transformed",
+        )
 
     var.value_node = Calc(bijector_inst.forward, transformed_var)
     return transformed_var
@@ -2026,11 +2057,42 @@ def transform_dist(dist_args: ArgGroup, bijector_args: ArgGroup):
 
     bijector_inv = dist_node_transformed.init_dist().bijector
 
-    transformed_var = Var(
-        bijector_inv.forward(var.value),
-        dist_node_transformed,
-        name=f"{var.name}_transformed",
-    )
+    # For a weak var, wrap the original calculator function so the transformed
+    # value is recomputed from the same inputs instead of stored as a constant.
+    if var.weak:
+        try:
+            value_function = var.value_node.function  # type: ignore
+        except AttributeError as e:
+            raise AttributeError(
+                "Trying to transform a weak variable without calculator node."
+            ) from e
+
+        def forward(*args, **kwargs):
+            return bijector_inv.forward(value_function(*args, **kwargs))
+
+        value_inputs = var.value_node.inputs
+        value_kwinputs = var.value_node.kwinputs
+        value_node_needs_seed = var.value_node.needs_seed
+        value_node_update_on_init = var.value_node.update_on_init  # type: ignore
+
+        transformed_var = Var(
+            Calc(
+                forward,
+                *value_inputs,
+                _name="",
+                _needs_seed=value_node_needs_seed,
+                update_on_init=value_node_update_on_init,
+                **value_kwinputs,
+            ),
+            dist_node_transformed,
+            name=f"{var.name}_transformed",
+        )
+    else:
+        transformed_var = Var(
+            bijector_inv.forward(var.value),
+            dist_node_transformed,
+            name=f"{var.name}_transformed",
+        )
 
     def bijector_fn(value, dist_inputs, bijector_inputs):
         bijector = transform_dist(dist_inputs, bijector_inputs).bijector
@@ -2044,10 +2106,17 @@ def bijector_fn(value, dist_inputs, bijector_inputs):
 def _transform_var_without_dist_with_bijector_instance(
     var: Var, bijector_inst: jb.Bijector
 ) -> Var:
-    transformed_var = Var(
-        bijector_inst.inverse(var.value),
-        name=f"{var.name}_transformed",
-    )
+    if var.strong:
+        transformed_var = Var(
+            bijector_inst.inverse(var.value),
+            name=f"{var.name}_transformed",
+        )
+    else:
+        transformed_var = Var.new_calc(
+            bijector_inst.inverse,
+            var.value_node,
+            name=f"{var.name}_transformed",
+        )
 
     var.value_node = Calc(bijector_inst.forward, transformed_var)
 
@@ -2084,10 +2153,19 @@ def bijection_forward(x, *bjargs, **bjkwargs):
         bijector_inst = bijector_cls(*bjargs, **bjkwargs)
         return bijector_inst(x)
 
-    transformed_var = Var(
-        bijection_inverse(var.value, *args, **kwargs),
-        name=f"{var.name}_transformed",
-    )
+    if var.strong:
+        transformed_var = Var(
+            bijection_inverse(var.value, *args, **kwargs),
+            name=f"{var.name}_transformed",
+        )
+    else:
+        transformed_var = Var.new_calc(
+            bijection_inverse,
+            var.value_node,
+            *args,
+            **kwargs,
+            name=f"{var.name}_transformed",
+        )
 
     var.value_node = Calc(bijection_forward, transformed_var, *args, **kwargs)
 
diff --git a/tests/model/test_var.py b/tests/model/test_var.py
index 95e4a56..f519ca1 100644
--- a/tests/model/test_var.py
+++ b/tests/model/test_var.py
@@ -556,6 +556,88 @@ def test_new_const(self):
 
 
 class TestVarTransform:
+    def test_transform_weak_var_with_distribution_inst(self) -> None:
+        """
+        Tests transformation of a weak var with distribution when the bijector is passed
+        as an instance.
+        """
+        x = lnodes.Var.new_value(jnp.linspace(0.1, 2, 5), name="all_x")
+        batch_index = lnodes.Var.new_value(1, name="index")
+
+        x_batched = lnodes.Var(
+            value=lnodes.Calc(lambda i, x: x[i], batch_index, x),
+            distribution=lnodes.Dist(tfp.distributions.Normal, loc=0.0, scale=1.0),
+            name="x_batched",
+        )
+
+        x_batched_transformed = x_batched.transform(tfp.bijectors.Exp())
+
+        assert x_batched_transformed.value == pytest.approx(jnp.log(x.value[1]))
+
+        batch_index.value = 2
+        x_batched_transformed.update()
+        x_batched.update()
+        assert x_batched_transformed.value == pytest.approx(jnp.log(x.value[2]))
+
+    def test_transform_weak_var_with_distribution_class(self) -> None:
+        """
+        Tests transformation of a weak var with distribution when the bijector is passed
+        as a class.
+        """
+        x = lnodes.Var.new_value(jnp.linspace(0.1, 2, 5), name="all_x")
+        batch_index = lnodes.Var.new_value(1, name="index")
+
+        x_batched = lnodes.Var(
+            value=lnodes.Calc(lambda i, x: x[i], i=batch_index, x=x),
+            distribution=lnodes.Dist(tfp.distributions.Normal, loc=0.0, scale=1.0),
+            name="x_batched",
+        )
+
+        x_batched_transformed = x_batched.transform(tfp.bijectors.Scale, scale=2.0)
+
+        assert x_batched_transformed.value == pytest.approx(x.value[1] / 2.0)
+
+        batch_index.value = 2
+        x_batched_transformed.update()
+        x_batched.update()
+        assert x_batched_transformed.value == pytest.approx(x.value[2] / 2.0)
+
+    def test_transform_weak_var_with_bijector_instance(self) -> None:
+        """Transforms a weak var without distribution via a bijector instance."""
+        tau = lnodes.Var.new_param(10.0, name="tau")
+        tau_sqrt = lnodes.Var.new_calc(jnp.sqrt, tau)
+        log_tau_sqrt = tau_sqrt.transform(tfp.bijectors.Exp())
+
+        assert tau.value == pytest.approx(10.0)
+        assert tau_sqrt.value == pytest.approx(jnp.sqrt(10.0))
+        assert log_tau_sqrt.value == pytest.approx(jnp.log(jnp.sqrt(10.0)))
+
+        assert tau.strong
+        assert tau_sqrt.weak
+        assert log_tau_sqrt.weak
+        assert tau.parameter
+        assert not log_tau_sqrt.parameter
+        assert not tau_sqrt.parameter
+
+    def test_transform_weak_var_with_bijector_class(self) -> None:
+        """Transforms a weak var without distribution via a bijector class."""
+        tau = lnodes.Var.new_param(10.0, name="tau")
+        tau_sqrt = lnodes.Var.new_calc(jnp.sqrt, tau)
+
+        scale = lnodes.Var.new_param(2.0, name="bijector_scale")
+        scaled_tau_sqrt = tau_sqrt.transform(tfp.bijectors.Scale, scale=scale)
+
+        assert tau.value == pytest.approx(10.0)
+        assert tau_sqrt.value == pytest.approx(jnp.sqrt(10.0))
+        assert scaled_tau_sqrt.value == pytest.approx(jnp.sqrt(10.0) / 2)
+
+        assert tau.strong
+        assert tau_sqrt.weak
+        assert scaled_tau_sqrt.weak
+        assert tau.parameter
+        assert not scaled_tau_sqrt.parameter
+        assert not tau_sqrt.parameter
+
     def test_transform_without_dist_with_bijector_instance(self) -> None:
         tau = lnodes.Var.new_param(10.0, name="tau")
         log_tau = tau.transform(tfp.bijectors.Exp())