Rename compile_pymc to compile
ricardoV94 committed Dec 10, 2024
1 parent a714b24 commit 6cdfc30
Showing 21 changed files with 88 additions and 81 deletions.
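For downstream code this is a pure rename: `pymc.pytensorf.compile` accepts the same arguments as the old `compile_pymc`, and the old name remains available as a thin shim that emits a `FutureWarning` (see the `pymc/pytensorf.py` hunk below). A minimal migration sketch, using a hypothetical toy graph rather than code from this commit:

```python
import pytensor.tensor as pt

import pymc as pm

# Hypothetical toy graph: y = x**2
x = pt.scalar("x")
y = x**2

# Before this commit:
# fn = pm.pytensorf.compile_pymc([x], y)  # still works, but now warns

# After this commit:
fn = pm.pytensorf.compile([x], y)
print(fn(3.0))  # 9.0
```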
4 changes: 1 addition & 3 deletions docs/source/api/pytensorf.rst
@@ -6,7 +6,7 @@ PyTensor utils
 .. autosummary::
    :toctree: generated/

-   compile_pymc
+   compile
    gradient
    hessian
    hessian_diag
@@ -19,6 +19,4 @@ PyTensor utils
    CallableTensor
    join_nonshared_inputs
    make_shared_replacements
-   generator
-   convert_generator_data
    convert_data
4 changes: 2 additions & 2 deletions pymc/backends/base.py
@@ -34,7 +34,7 @@

 from pymc.backends.report import SamplerReport
 from pymc.model import modelcontext
-from pymc.pytensorf import compile_pymc
+from pymc.pytensorf import compile
 from pymc.util import get_var_name

 logger = logging.getLogger(__name__)
@@ -171,7 +171,7 @@ def __init__(

         if fn is None:
             # borrow=True avoids deepcopy when inputs=output which is the case for untransformed value variables
-            fn = compile_pymc(
+            fn = compile(
                 inputs=[pytensor.In(v, borrow=True) for v in model.value_vars],
                 outputs=[pytensor.Out(v, borrow=True) for v in vars],
                 on_unused_input="ignore",
8 changes: 4 additions & 4 deletions pymc/func_utils.py
@@ -169,18 +169,18 @@ def find_constrained_prior(
    )

    target = (pt.exp(logcdf_lower) - mass_below_lower) ** 2
-    target_fn = pm.pytensorf.compile_pymc([dist_params], target, allow_input_downcast=True)
+    target_fn = pm.pytensorf.compile([dist_params], target, allow_input_downcast=True)

    constraint = pt.exp(logcdf_upper) - pt.exp(logcdf_lower)
-    constraint_fn = pm.pytensorf.compile_pymc([dist_params], constraint, allow_input_downcast=True)
+    constraint_fn = pm.pytensorf.compile([dist_params], constraint, allow_input_downcast=True)

    jac: str | Callable
    constraint_jac: str | Callable
    try:
        pytensor_jac = pm.gradient(target, [dist_params])
-        jac = pm.pytensorf.compile_pymc([dist_params], pytensor_jac, allow_input_downcast=True)
+        jac = pm.pytensorf.compile([dist_params], pytensor_jac, allow_input_downcast=True)
        pytensor_constraint_jac = pm.gradient(constraint, [dist_params])
-        constraint_jac = pm.pytensorf.compile_pymc(
+        constraint_jac = pm.pytensorf.compile(
            [dist_params], pytensor_constraint_jac, allow_input_downcast=True
        )
    # when PyMC cannot compute the gradient
4 changes: 2 additions & 2 deletions pymc/gp/util.py
@@ -23,7 +23,7 @@
 from scipy.cluster.vq import kmeans

 from pymc.model.core import modelcontext
-from pymc.pytensorf import compile_pymc
+from pymc.pytensorf import compile

 JITTER_DEFAULT = 1e-6

@@ -55,7 +55,7 @@ def replace_with_values(vars_needed, replacements=None, model=None):
    if len(inputs) == 0:
        return tuple(v.eval() for v in vars_needed)

-    fn = compile_pymc(
+    fn = compile(
        inputs,
        vars_needed,
        allow_input_downcast=True,
4 changes: 2 additions & 2 deletions pymc/initial_point.py
@@ -26,7 +26,7 @@

 from pymc.logprob.transforms import Transform
 from pymc.pytensorf import (
-    compile_pymc,
+    compile,
     find_rng_nodes,
     replace_rng_nodes,
     reseed_rngs,
@@ -157,7 +157,7 @@ def make_initial_point_fn(
    # Replace original rng shared variables so that we don't mess with them
    # when calling the final seeded function
    initial_values = replace_rng_nodes(initial_values)
-    func = compile_pymc(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)
+    func = compile(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)

    varnames = []
    for var in model.free_RVs:
6 changes: 3 additions & 3 deletions pymc/model/core.py
@@ -56,7 +56,7 @@
 from pymc.pytensorf import (
     PointFunc,
     SeedSequenceSeed,
-    compile_pymc,
+    compile,
     convert_observed_data,
     gradient,
     hessian,
@@ -253,7 +253,7 @@ def __init__(
        )
        inputs = grad_vars

-        self._pytensor_function = compile_pymc(inputs, outputs, givens=givens, **kwargs)
+        self._pytensor_function = compile(inputs, outputs, givens=givens, **kwargs)
        self._raveled_inputs = ravel_inputs

    def set_weights(self, values):
@@ -1637,7 +1637,7 @@ def compile_fn(
            inputs = inputvars(outs)

        with self:
-            fn = compile_pymc(
+            fn = compile(
                inputs,
                outs,
                allow_input_downcast=True,
13 changes: 11 additions & 2 deletions pymc/pytensorf.py
@@ -60,6 +60,7 @@

 __all__ = [
     "CallableTensor",
+    "compile",
     "compile_pymc",
     "cont_inputs",
     "convert_data",
@@ -981,7 +982,7 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
    return rng_updates


-def compile_pymc(
+def compile(
    inputs,
    outputs,
    random_seed: SeedSequenceSeed = None,
@@ -990,7 +991,7 @@
) -> Function:
    """Use ``pytensor.function`` with specialized pymc rewrites always enabled.

-    This function also ensures shared RandomState/Generator used by RandomVariables
+    This function also ensures shared Generator used by RandomVariables
    in the graph are updated across calls, to ensure independent draws.

    Parameters
@@ -1061,6 +1062,14 @@
    return pytensor_function


+def compile_pymc(*args, **kwargs):
+    warnings.warn(
+        "compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC",
+        FutureWarning,
+    )
+    return compile(*args, **kwargs)
+
+
def constant_fold(
    xs: Sequence[TensorVariable], raise_not_constant: bool = True
) -> tuple[np.ndarray | Variable, ...]:
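A sketch of the behavior the new shim and the RNG-update machinery should exhibit (assumes a PyMC build that includes this commit):

```python
import warnings

import pymc as pm
from pymc.pytensorf import compile, compile_pymc

# compile still installs updates for the shared RNGs of random variables,
# so successive calls return independent draws (cf. the docstring above)
x = pm.Normal.dist()
draw_fn = compile(inputs=[], outputs=x, random_seed=431)
assert draw_fn() != draw_fn()

# The old name now just delegates to compile and emits a FutureWarning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    compile_pymc(inputs=[], outputs=x)
assert any(issubclass(w.category, FutureWarning) for w in caught)
```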
6 changes: 3 additions & 3 deletions pymc/sampling/forward.py
@@ -51,7 +51,7 @@
 from pymc.backends.base import MultiTrace
 from pymc.blocking import PointType
 from pymc.model import Model, modelcontext
-from pymc.pytensorf import compile_pymc
+from pymc.pytensorf import compile
 from pymc.util import (
     CustomProgress,
     RandomState,
@@ -273,7 +273,7 @@ def expand(node):
    ]

    return (
-        compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
+        compile(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
        set(basic_rvs) & (volatile_nodes - set(givens_dict)),  # Basic RVs that will be resampled
    )

@@ -329,7 +329,7 @@ def draw(
    if random_seed is not None:
        (random_seed,) = _get_seeds_per_chain(random_seed, 1)

-    draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs)
+    draw_fn = compile(inputs=[], outputs=vars, random_seed=random_seed, **kwargs)

    if draws == 1:
        return draw_fn()
4 changes: 2 additions & 2 deletions pymc/smc/kernels.py
@@ -30,7 +30,7 @@
 from pymc.initial_point import make_initial_point_expression
 from pymc.model import Point, modelcontext
 from pymc.pytensorf import (
-    compile_pymc,
+    compile,
     floatX,
     join_nonshared_inputs,
     make_shared_replacements,
@@ -636,6 +636,6 @@ def _logp_forw(point, out_vars, in_vars, shared):
    out_list, inarray0 = join_nonshared_inputs(
        point=point, outputs=out_vars, inputs=in_vars, shared_inputs=shared
    )
-    f = compile_pymc([inarray0], out_list[0])
+    f = compile([inarray0], out_list[0])
    f.trust_input = True
    return f
4 changes: 2 additions & 2 deletions pymc/step_methods/metropolis.py
@@ -31,7 +31,7 @@
 from pymc.initial_point import PointType
 from pymc.pytensorf import (
     CallableTensor,
-    compile_pymc,
+    compile,
     floatX,
     join_nonshared_inputs,
     replace_rng_nodes,
@@ -1241,6 +1241,6 @@ def delta_logp(

    if compile_kwargs is None:
        compile_kwargs = {}
-    f = compile_pymc([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
+    f = compile([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
    f.trust_input = True
    return f
4 changes: 2 additions & 2 deletions pymc/step_methods/slicer.py
@@ -20,7 +20,7 @@
 from pymc.blocking import RaveledVars, StatsType
 from pymc.initial_point import PointType
 from pymc.model import modelcontext
-from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements
+from pymc.pytensorf import compile, join_nonshared_inputs, make_shared_replacements
 from pymc.step_methods.arraystep import ArrayStepShared
 from pymc.step_methods.compound import Competence, StepMethodState
 from pymc.step_methods.state import dataclass_state
@@ -109,7 +109,7 @@ def __init__(
        )
        if compile_kwargs is None:
            compile_kwargs = {}
-        self.logp = compile_pymc([raveled_inp], logp, **compile_kwargs)
+        self.logp = compile([raveled_inp], logp, **compile_kwargs)
        self.logp.trust_input = True

        super().__init__(vars, shared, blocked=blocked, rng=rng)
8 changes: 4 additions & 4 deletions pymc/testing.py
@@ -43,7 +43,7 @@
     local_check_parameter_to_ninf_switch,
     rvs_in_graph,
 )
-from pymc.pytensorf import compile_pymc, floatX, inputvars
+from pymc.pytensorf import compile, floatX, inputvars

 # This mode can be used for tests where model compilations takes the bulk of the runtime
 # AND where we don't care about posterior numerical or sampling stability (e.g., when
@@ -645,7 +645,7 @@ def check_selfconsistency_discrete_logcdf(
    dist_logp_fn = pytensor.function(list(inputvars(dist_logp)), dist_logp)

    dist_logcdf = logcdf(dist, value)
-    dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf)
+    dist_logcdf_fn = compile(list(inputvars(dist_logcdf)), dist_logcdf)

    domains = paramdomains.copy()
    domains["value"] = domain
@@ -721,7 +721,7 @@ def continuous_random_tester(

    model, param_vars = build_model(dist, valuedomain, paramdomains, extra_args)
    model_dist = change_dist_size(model.named_vars["value"], size, expand=True)
-    pymc_rand = compile_pymc([], model_dist)
+    pymc_rand = compile([], model_dist)

    domains = paramdomains.copy()
    for point in product(domains, n_samples=100):
@@ -760,7 +760,7 @@ def discrete_random_tester(

    model, param_vars = build_model(dist, valuedomain, paramdomains)
    model_dist = change_dist_size(model.named_vars["value"], size, expand=True)
-    pymc_rand = compile_pymc([], model_dist)
+    pymc_rand = compile([], model_dist)

    domains = paramdomains.copy()
    for point in product(domains, n_samples=100):
10 changes: 5 additions & 5 deletions pymc/variational/opvi.py
@@ -72,7 +72,7 @@
 from pymc.model import modelcontext
 from pymc.pytensorf import (
     SeedSequenceSeed,
-    compile_pymc,
+    compile,
     find_rng_nodes,
     identity,
     reseed_rngs,
@@ -388,9 +388,9 @@ def step_function(
        )
        seed = self.approx.rng.randint(2**30, dtype=np.int64)
        if score:
-            step_fn = compile_pymc([], updates.loss, updates=updates, random_seed=seed, **fn_kwargs)
+            step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **fn_kwargs)
        else:
-            step_fn = compile_pymc([], [], updates=updates, random_seed=seed, **fn_kwargs)
+            step_fn = compile([], [], updates=updates, random_seed=seed, **fn_kwargs)
        return step_fn

    @pytensor.config.change_flags(compute_test_value="off")
@@ -420,7 +420,7 @@ def score_function(
            more_replacements = {}
        loss = self(sc_n_mc, more_replacements=more_replacements)
        seed = self.approx.rng.randint(2**30, dtype=np.int64)
-        return compile_pymc([], loss, random_seed=seed, **fn_kwargs)
+        return compile([], loss, random_seed=seed, **fn_kwargs)

    @pytensor.config.change_flags(compute_test_value="off")
    def __call__(self, nmc, **kwargs):
@@ -1517,7 +1517,7 @@ def sample_dict_fn(self):
        names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs]
        sampled = [self.rslice(name) for name in names]
        sampled = self.set_size_and_deterministic(sampled, s, 0)
-        sample_fn = compile_pymc([s], sampled)
+        sample_fn = compile([s], sampled)
        rng_nodes = find_rng_nodes(sampled)

        def inner(draws=100, *, random_seed: SeedSequenceSeed = None):
6 changes: 3 additions & 3 deletions tests/distributions/test_distribution.py
@@ -43,7 +43,7 @@
 )
 from pymc.distributions.shape_utils import change_dist_size
 from pymc.logprob.basic import conditional_logp, logp
-from pymc.pytensorf import compile_pymc
+from pymc.pytensorf import compile
 from pymc.testing import (
     BaseTestDistributionRandom,
     I,
@@ -169,7 +169,7 @@ def update(self, node):
            outputs=[dummy_next_rng, dummy_x],
            ndim_supp=0,
        )(rng)
-        fn = compile_pymc(inputs=[], outputs=x, random_seed=431)
+        fn = compile(inputs=[], outputs=x, random_seed=431)
        assert fn() != fn()

        # Check that custom updates are respected, by using one that's broken
@@ -182,7 +182,7 @@ def update(self, node):
            ValueError,
            match="No update found for at least one RNG used in SymbolicRandomVariable Op SymbolicRVCustomUpdates",
        ):
-            compile_pymc(inputs=[], outputs=x, random_seed=431)
+            compile(inputs=[], outputs=x, random_seed=431)

    def test_recreate_with_different_rng_inputs(self):
        """Test that we can recreate a SymbolicRandomVariable with new RNG inputs.
4 changes: 2 additions & 2 deletions tests/distributions/test_multivariate.py
@@ -45,7 +45,7 @@
 from pymc.logprob.basic import logp
 from pymc.logprob.utils import ParameterValueError
 from pymc.math import kronecker
-from pymc.pytensorf import compile_pymc, floatX
+from pymc.pytensorf import compile, floatX
 from pymc.sampling.forward import draw
 from pymc.testing import (
     BaseTestDistributionRandom,
@@ -168,7 +168,7 @@ def stickbreakingweights_logpdf():
    _alpha = pt.scalar()
    _k = pt.iscalar()
    _logp = logp(pm.StickBreakingWeights.dist(_alpha, _k), _value)
-    core_fn = compile_pymc([_value, _alpha, _k], _logp)
+    core_fn = compile([_value, _alpha, _k], _logp)

    return np.vectorize(core_fn, signature="(n),(),()->()")
4 changes: 2 additions & 2 deletions tests/distributions/test_shape_utils.py
@@ -326,7 +326,7 @@ def test_size_from_dims_rng_update(self):
        with pm.Model(coords={"x_dim": range(2)}):
            x = pm.Normal("x", dims=("x_dim",))

-        fn = pm.pytensorf.compile_pymc([], x)
+        fn = pm.pytensorf.compile([], x)
        # Check that both function outputs (rng and draws) come from the same Apply node
        assert fn.maker.fgraph.outputs[0].owner is fn.maker.fgraph.outputs[1].owner

@@ -341,7 +341,7 @@ def test_size_from_observed_rng_update(self):
        with pm.Model():
            x = pm.Normal("x", observed=[0, 1])

-        fn = pm.pytensorf.compile_pymc([], x)
+        fn = pm.pytensorf.compile([], x)
        # Check that both function outputs (rng and draws) come from the same Apply node
        assert fn.maker.fgraph.outputs[0].owner is fn.maker.fgraph.outputs[1].owner