Skip to content

Commit

Permalink
Fix latex repr of symbolic distributions (#6231)
Browse files Browse the repository at this point in the history
* Fix latex repr of nested variables
* Do not put \operatorname inside \text
* Remove inner $ symbols from latex representations
* remove them from the string borders since they always represent the math environment
* fix tests to check the correct behavior
* Add test for a full model that used to be rendered wrong
* Use "\\operatorname{" instead of "operatorname" to determine if the latex command is used
* Update tests to new distributions' notation
  • Loading branch information
mattiadg authored Oct 30, 2022
1 parent 9105d74 commit ea721e4
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 11 deletions.
27 changes: 20 additions & 7 deletions pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def str_for_dist(

if "latex" in formatting:
if print_name is not None:
print_name = r"\text{" + _latex_escape(dist.name) + "}"
print_name = r"\text{" + _latex_escape(dist.name.strip("$")) + "}"

op_name = (
dist.owner.op._print_name[1]
Expand All @@ -67,9 +67,11 @@ def str_for_dist(
)
if include_params:
if print_name:
return r"${} \sim {}({})$".format(print_name, op_name, ",~".join(dist_args))
return r"${} \sim {}({})$".format(
print_name, op_name, ",~".join([d.strip("$") for d in dist_args])
)
else:
return r"${}({})$".format(op_name, ",~".join(dist_args))
return r"${}({})$".format(op_name, ",~".join([d.strip("$") for d in dist_args]))

else:
if print_name:
Expand Down Expand Up @@ -138,7 +140,7 @@ def str_for_potential_or_deterministic(
LaTeX or plain, optionally with distribution parameter values included."""
print_name = var.name if var.name is not None else "<unnamed>"
if "latex" in formatting:
print_name = r"\text{" + _latex_escape(print_name) + "}"
print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}"
if include_params:
return rf"${print_name} \sim \operatorname{{{dist_name}}}({_str_for_expression(var, formatting=formatting)})$"
else:
Expand Down Expand Up @@ -182,7 +184,7 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
else str_for_dist(var, formatting=formatting, include_params=True)
)
if "latex" in formatting:
return r"\text{" + _latex_escape(_str) + "}"
return _latex_text_format(_latex_escape(_str.strip("$")))
else:
return _str

Expand Down Expand Up @@ -215,9 +217,20 @@ def _expand(x):
names = [x.name for x in parents]

if "latex" in formatting:
return r"f(" + ",~".join([r"\text{" + _latex_escape(n) + "}" for n in names]) + ")"
return (
r"f("
+ ",~".join([_latex_text_format(_latex_escape(n.strip("$"))) for n in names])
+ ")"
)
else:
return r"f(" + ", ".join([n.strip("$") for n in names]) + ")"


def _latex_text_format(text: str) -> str:
if r"\operatorname{" in text:
return text
else:
return r"f(" + ", ".join(names) + ")"
return r"\text{" + text + "}"


def _latex_escape(text: str) -> str:
Expand Down
48 changes: 44 additions & 4 deletions pymc/tests/test_printing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from pymc import Bernoulli, Censored, Mixture
from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT
from pymc.aesaraf import floatX
from pymc.distributions import (
Dirichlet,
Expand Down Expand Up @@ -130,12 +130,12 @@ def setup_class(self):
r"$\text{beta} \sim \operatorname{N}(0,~10)$",
r"$\text{Z} \sim \operatorname{N}(f(),~f())$",
r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$",
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\text{\$\operatorname{DiracDelta}(0)\$},~\text{\$\operatorname{Pois}(5)\$})$",
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5))$",
r"$\text{w} \sim \operatorname{Dir}(\text{<constant>})$",
(
r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{w},"
r"~\text{\$\operatorname{MarginalMixture}(f(),~\text{\\$\operatorname{DiracDelta}(0)\\$},~\text{\\$\operatorname{Pois}(5)\\$})\$},"
r"~\text{\$\operatorname{Censored}(\text{\\$\operatorname{Bern}(0.5)\\$},~-1,~1)\$})$"
r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5)),"
r"~\operatorname{Censored}(\operatorname{Bern}(0.5),~-1,~1))$"
),
r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$",
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
Expand Down Expand Up @@ -178,3 +178,43 @@ def test_str_repr(self):
assert segment in model_text
else:
assert text in model_text


def test_model_latex_repr_three_levels_model():
with Model() as censored_model:
mu = Normal("mu", 0.0, 5.0)
sigma = HalfCauchy("sigma", 2.5)
normal_dist = Normal.dist(mu=mu, sigma=sigma)
censored_normal = Censored(
"censored_normal", normal_dist, lower=-2.0, upper=2.0, observed=[1, 0, 0.5]
)

latex_repr = censored_model.str_repr(formatting="latex")
expected = [
"$$",
"\\begin{array}{rcl}",
"\\text{mu} &\\sim & \\operatorname{N}(0,~5)\\\\\\text{sigma} &\\sim & "
"\\operatorname{C^{+}}(0,~2.5)\\\\\\text{censored_normal} &\\sim & "
"\\operatorname{Censored}(\\operatorname{N}(\\text{mu},~\\text{sigma}),~-2,~2)",
"\\end{array}",
"$$",
]
assert [line.strip() for line in latex_repr.split("\n")] == expected


def test_model_latex_repr_mixture_model():
with Model() as mix_model:
w = Dirichlet("w", [1, 1])
mix = Mixture("mix", w=w, comp_dists=[Normal.dist(0.0, 5.0), StudentT.dist(7.0)])

latex_repr = mix_model.str_repr(formatting="latex")
expected = [
"$$",
"\\begin{array}{rcl}",
"\\text{w} &\\sim & "
"\\operatorname{Dir}(\\text{<constant>})\\\\\\text{mix} &\\sim & "
"\\operatorname{MarginalMixture}(\\text{w},~\\operatorname{N}(0,~5),~\\operatorname{StudentT}(7,~0,~1))",
"\\end{array}",
"$$",
]
assert [line.strip() for line in latex_repr.split("\n")] == expected

0 comments on commit ea721e4

Please sign in to comment.