Skip to content

Commit

Permalink
Merge pull request #414 from gyorilab/stratify_initial_params
Browse files Browse the repository at this point in the history
Stratify initial expressions and corresponding parameters
  • Loading branch information
bgyori authored Jan 3, 2025
2 parents 63d4308 + 59376f2 commit 2f777c4
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 33 deletions.
112 changes: 94 additions & 18 deletions mira/metamodel/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,98 @@ def stratify(
all_param_mappings[old_param].add(new_param)
templates.append(stratified_template)

# Handle initial values and expressions depending on different
# criteria
initials = {}
param_value_mappings = {}
for initial_key, initial in template_model.initials.items():
# We need to keep track of whether we stratified any parameters in
# the expression for this initial and if the parameter is being
# replaced by multiple stratified parameters
any_param_stratified = False
param_replacements = defaultdict(set)

for stratum_idx, stratum in enumerate(strata):
# Figure out if the concept for this initial is one that we
# need to stratify or not
if (exclude_concepts and initial.concept.name in exclude_concepts) or \
(concepts_to_preserve and initial.concept.name in concepts_to_preserve):
# Just make a copy of the original initial concept
new_concept = deepcopy(initial.concept)
concept_stratified = False
else:
# We create a new concept for the given stratum
new_concept = initial.concept.with_context(
do_rename=modify_names,
curie_to_name_map=strata_curie_to_name,
**{key: stratum},
)
concept_stratified = True
# Now we may have to rewrite the expression so that we can
# update for stratified parameters so we make a copy and figure
# out what parameters are in the expression
new_expression = deepcopy(initial.expression)
init_expr_params = template_model.get_parameters_from_expression(
new_expression.args[0]
)
template_strata = [stratum if
param_renaming_uses_strata_names else stratum_idx]
for parameter in init_expr_params:
# If a parameter is explicitly listed as one to preserve, then
# don't stratify it
if params_to_preserve is not None and parameter in params_to_preserve:
continue
# If we have an explicit stratification list then if something isn't
# in the list then don't stratify it.
elif params_to_stratify is not None and parameter not in params_to_stratify:
continue
# Otherwise we go ahead with stratification, i.e., in cases
# where nothing was said about parameter stratification or the
# parameter was listed explicitly to be stratified
else:
# We create a new parameter symbol for the given stratum
param_suffix = '_'.join([str(s) for s in template_strata])
new_param = f'{parameter}_{param_suffix}'
any_param_stratified = True
all_param_mappings[parameter].add(new_param)
# We need to update the new, stratified parameter's value
# to be the original parameter's value divided by the number
# of strata
param_value_mappings[new_param] = \
template_model.parameters[parameter].value / len(strata)
# If the concept is not stratified then we have to replace
# the original parameter with the sum of stratified ones
# so we just keep track of that in a set
if not concept_stratified:
param_replacements[parameter].add(new_param)
# Otherwise we have to rewrite the expression to use the
# new parameter as replacement for the original one
else:
new_expression = new_expression.subs(parameter,
sympy.Symbol(new_param))

# If we stratified any parameters in the expression then we have
# to update the initial value expression to reflect that
if any_param_stratified:
if param_replacements:
for orig_param, new_params in param_replacements.items():
new_expression = new_expression.subs(
orig_param,
sympy.Add(*[sympy.Symbol(np) for np in new_params])
)
new_initial = new_expression
# Otherwise we can just use the original expression, except if the
# concept was stratified, then we have to divide the initial
# expression into as many parts as there are strata
else:
if concept_stratified:
new_initial = SympyExprStr(new_expression.args[0] / len(strata))
else:
new_initial = new_expression

initials[new_concept.name] = \
Initial(concept=new_concept, expression=new_initial)

parameters = {}
for parameter_key, parameter in template_model.parameters.items():
if parameter_key not in all_param_mappings:
Expand All @@ -273,26 +365,10 @@ def stratify(
for stratified_param in all_param_mappings[parameter_key]:
d = deepcopy(parameter)
d.name = stratified_param
if stratified_param in param_value_mappings:
d.value = param_value_mappings[stratified_param]
parameters[stratified_param] = d

# Create new initial values for each of the strata
# of the original compartments, copied from the initial
# values of the original compartments
initials = {}
for initial_key, initial in template_model.initials.items():
if initial.concept.name in exclude_concepts:
initials[initial.concept.name] = deepcopy(initial)
continue
for stratum in strata:
new_concept = initial.concept.with_context(
do_rename=modify_names,
curie_to_name_map=strata_curie_to_name,
**{key: stratum},
)
initials[new_concept.name] = Initial(
concept=new_concept, expression=SympyExprStr(initial.expression.args[0] / len(strata))
)

observables = {}
for observable_key, observable in template_model.observables.items():
syms = {s.name for s in observable.expression.args[0].free_symbols}
Expand Down
49 changes: 34 additions & 15 deletions mira/metamodel/template_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,41 +372,60 @@ class TemplateModel(BaseModel):
"Note that all annotations are optional.",
)

def get_parameters_from_rate_law(self, rate_law) -> Set[str]:
"""Given a rate law, find its elements that are model parameters.
def get_parameters_from_expression(self, expression) -> Set[str]:
"""Given a symbolic expression, find its elements that are model parameters.
Rate laws consist of some combination of participants, rate parameters
and potentially other factors. This function finds those elements of
rate laws that are rate parameters.
Expressions such as rate laws consist of some combination of participants,
rate parameters and potentially other factors. This function finds those
elements of expressions that are rate parameters.
Parameters
----------
rate_law : sympy.Symbol | sympy.Expr
A sympy expression or symbol, whose names are extracted.
expression : sympy.Symbol | sympy.Expr
A sympy expression or symbol, whose parameters are extracted.
Returns
-------
:
A set of parameter names (as strings).
"""
if rate_law is None:
if expression is None:
return set()
params = set()
if isinstance(rate_law, sympy.Symbol):
if rate_law.name in self.parameters:
if isinstance(expression, sympy.Symbol):
if expression.name in self.parameters:
# add the string name to the set
params.add(rate_law.name)
params.add(expression.name)
# There are many sympy classes that have args that can occur here
# so it's better to check for the presence of args
elif not hasattr(rate_law, "args"):
elif not hasattr(expression, "args"):
raise ValueError(
f"Rate law is of invalid type {type(rate_law)}: {rate_law}"
f"Rate law is of invalid type {type(expression)}: {expression}"
)
else:
for arg in rate_law.args:
params |= self.get_parameters_from_rate_law(arg)
for arg in expression.args:
params |= self.get_parameters_from_expression(arg)
return params

def get_parameters_from_rate_law(self, rate_law) -> Set[str]:
"""Given a rate law, find its elements that are model parameters.
Rate laws consist of some combination of participants, rate parameters
and potentially other factors. This function finds those elements of
rate laws that are rate parameters.
Parameters
----------
rate_law : sympy.Symbol | sympy.Expr
A sympy expression or symbol, whose parameters are extracted.
Returns
-------
:
A set of parameter names (as strings).
"""
return self.get_parameters_from_expression(rate_law)

def update_parameters(self, parameter_dict):
"""
Update parameter values.
Expand Down
41 changes: 41 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,3 +636,44 @@ def test_add_observable_pattern():
assert 'young' in tm.observables
obs = tm.observables['young']
assert obs.expression.args[0] == sympy.Symbol('A_young') + sympy.Symbol('B_young')


def test_stratify_initials_parameters():
s = Concept(name='S')
t = NaturalDegradation(subject=s, rate_law=sympy.Symbol('alpha') *
sympy.Symbol(s.name))
S0 = Initial(concept=s, expression=sympy.Symbol('S0'))
tm = TemplateModel(templates=[t],
parameters={'alpha': Parameter(name='alpha', value=0.1),
'S0': Parameter(name='S0', value=1000)},
initials={'S': S0})
tm1 = stratify(tm, key='age', strata=['young', 'old'], structure=[],
param_renaming_uses_strata_names=True)
assert 'S_young' in tm1.initials
assert tm1.initials['S_young'].expression.args[0] == sympy.Symbol('S0_young')
assert 'S_old' in tm1.initials
assert tm1.initials['S_old'].expression.args[0] == sympy.Symbol('S0_old')
assert 'S0_young' in tm1.parameters
assert tm1.parameters['S0_young'].value == 500
assert 'S0_old' in tm1.parameters
assert tm1.parameters['S0_old'].value == 500

tm2 = stratify(tm, key='age', strata=['young', 'old'], structure=[],
param_renaming_uses_strata_names=True,
params_to_preserve={'S0'})
assert 'S_young' in tm2.initials
assert tm2.initials['S_young'].expression.args[0] == sympy.Symbol('S0') / 2
assert 'S_old' in tm2.initials
assert tm2.initials['S_old'].expression.args[0] == sympy.Symbol('S0') / 2
assert 'S0' in tm2.parameters
assert tm2.parameters['S0'].value == 1000

tm3 = stratify(tm, key='age', strata=['young', 'old'], structure=[],
param_renaming_uses_strata_names=True,
concepts_to_preserve={'S'})
assert set(tm3.initials) == {'S'}
assert tm3.initials['S'].expression.args[0] == \
sympy.Symbol('S0_old') + sympy.Symbol('S0_young')
assert set(tm3.parameters) == {'alpha', 'S0_old', 'S0_young'}
assert tm3.parameters['S0_old'].value == 500
assert tm3.parameters['S0_young'].value == 500

0 comments on commit 2f777c4

Please sign in to comment.