From 3f1db21ae440682cb18e00dcaf0492ae1dfbd7d3 Mon Sep 17 00:00:00 2001 From: Ben Gyori Date: Wed, 1 Jan 2025 22:58:16 -0500 Subject: [PATCH 1/4] Implement parameter stratification --- mira/metamodel/ops.py | 65 +++++++++++++++++++++++--------- mira/metamodel/template_model.py | 37 +++++++++--------- tests/test_ops.py | 17 +++++++++ 3 files changed, 84 insertions(+), 35 deletions(-) diff --git a/mira/metamodel/ops.py b/mira/metamodel/ops.py index 3a8f2051d..27d9532c4 100644 --- a/mira/metamodel/ops.py +++ b/mira/metamodel/ops.py @@ -258,6 +258,53 @@ def stratify( all_param_mappings[old_param].add(new_param) templates.append(stratified_template) + # Create new initial values for each of the strata + # of the original compartments, copied from the initial + # values of the original compartments + initials = {} + for initial_key, initial in template_model.initials.items(): + if initial.concept.name in exclude_concepts: + initials[initial.concept.name] = deepcopy(initial) + continue + any_param_stratified = False + for stratum_idx, stratum in enumerate(strata): + new_concept = initial.concept.with_context( + do_rename=modify_names, + curie_to_name_map=strata_curie_to_name, + **{key: stratum}, + ) + new_expression = deepcopy(initial.expression) + init_expr_params = template_model.get_parameters_from_expression( + new_expression.args[0] + ) + template_strata = [stratum if + param_renaming_uses_strata_names else stratum_idx] + for parameter in init_expr_params: + # If a parameter is explicitly listed as one to preserve, then + # don't stratify it + if params_to_preserve is not None and parameter in params_to_preserve: + continue + # If we have an explicit stratification list then if something isn't + # in the list then don't stratify it. + elif params_to_stratify is not None and parameter not in params_to_stratify: + continue + # Otherwise we go ahead with stratification, i.e., in cases + # where nothing was said about parameter stratification or the + # parameter was listed explicitly to be stratified + else: + param_suffix = '_'.join([str(s) for s in template_strata]) + new_param = f'{parameter}_{param_suffix}' + all_param_mappings[parameter].add(new_param) + new_expression = new_expression.subs(parameter, + sympy.Symbol(new_param)) + any_param_stratified = True + if not any_param_stratified: + new_initial = SympyExprStr(new_expression.args[0] / len(strata)) + else: + new_initial = new_expression + initials[new_concept.name] = \ + Initial(concept=new_concept, expression=new_initial) + parameters = {} for parameter_key, parameter in template_model.parameters.items(): if parameter_key not in all_param_mappings: @@ -275,24 +322,6 @@ def stratify( d.name = stratified_param parameters[stratified_param] = d - # Create new initial values for each of the strata - # of the original compartments, copied from the initial - # values of the original compartments - initials = {} - for initial_key, initial in template_model.initials.items(): - if initial.concept.name in exclude_concepts: - initials[initial.concept.name] = deepcopy(initial) - continue - for stratum in strata: - new_concept = initial.concept.with_context( - do_rename=modify_names, - curie_to_name_map=strata_curie_to_name, - **{key: stratum}, - ) - initials[new_concept.name] = Initial( - concept=new_concept, expression=SympyExprStr(initial.expression.args[0] / len(strata)) - ) - observables = {} for observable_key, observable in template_model.observables.items(): syms = {s.name for s in observable.expression.args[0].free_symbols} diff --git a/mira/metamodel/template_model.py b/mira/metamodel/template_model.py index 2c4fade32..9cac17e2f 100644 --- a/mira/metamodel/template_model.py +++ b/mira/metamodel/template_model.py @@ -372,6 +372,25 @@ class TemplateModel(BaseModel): "Note that all annotations are optional.", ) + def get_parameters_from_expression(self, expression) -> Set[str]: + if expression is None: + return set() + params = set() + if isinstance(expression, sympy.Symbol): + if expression.name in self.parameters: + # add the string name to the set + params.add(expression.name) + # There are many sympy classes that have args that can occur here + # so it's better to check for the presence of args + elif not hasattr(expression, "args"): + raise ValueError( + f"Rate law is of invalid type {type(expression)}: {expression}" + ) + else: + for arg in expression.args: + params |= self.get_parameters_from_expression(arg) + return params + def get_parameters_from_rate_law(self, rate_law) -> Set[str]: """Given a rate law, find its elements that are model parameters. @@ -389,23 +408,7 @@ def get_parameters_from_rate_law(self, rate_law) -> Set[str]: : A set of parameter names (as strings). """ - if rate_law is None: - return set() - params = set() - if isinstance(rate_law, sympy.Symbol): - if rate_law.name in self.parameters: - # add the string name to the set - params.add(rate_law.name) - # There are many sympy classes that have args that can occur here - # so it's better to check for the presence of args - elif not hasattr(rate_law, "args"): - raise ValueError( - f"Rate law is of invalid type {type(rate_law)}: {rate_law}" - ) - else: - for arg in rate_law.args: - params |= self.get_parameters_from_rate_law(arg) - return params + return self.get_parameters_from_expression(rate_law) def update_parameters(self, parameter_dict): """ diff --git a/tests/test_ops.py b/tests/test_ops.py index 585daab74..1f219f682 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -636,3 +636,20 @@ def test_add_observable_pattern(): assert 'young' in tm.observables obs = tm.observables['young'] assert obs.expression.args[0] == sympy.Symbol('A_young') + sympy.Symbol('B_young') + + +def test_stratify_initials_parameters(): + s = Concept(name='S') + t = NaturalDegradation(subject=s, rate_law=sympy.Symbol('alpha') * + sympy.Symbol(s.name)) + S0 = Initial(concept=s, expression=sympy.Symbol('S0')) + tm = TemplateModel(templates=[t], + parameters={'alpha': Parameter(name='alpha', value=0.1), + 'S0': Parameter(name='S0', value=1000)}, + initials={'S': S0}) + tm = stratify(tm, key='age', strata=['young', 'old'], structure=[], + param_renaming_uses_strata_names=True) + assert 'S_young' in tm.initials + assert tm.initials['S_young'].expression.args[0] == sympy.Symbol('S0_young') + assert 'S_old' in tm.initials + assert tm.initials['S_old'].expression.args[0] == sympy.Symbol('S0_old') From 30021971a8085127d811790593dd81b54ef46e14 Mon Sep 17 00:00:00 2001 From: Ben Gyori Date: Thu, 2 Jan 2025 17:58:05 -0500 Subject: [PATCH 2/4] Handle more complex stratification of initial expressions --- mira/metamodel/ops.py | 69 +++++++++++++++++++++++++++++++++---------- tests/test_ops.py | 27 +++++++++++++---- 2 files changed, 75 insertions(+), 21 deletions(-) diff --git a/mira/metamodel/ops.py b/mira/metamodel/ops.py index 27d9532c4..649b80f18 100644 --- a/mira/metamodel/ops.py +++ b/mira/metamodel/ops.py @@ -258,21 +258,35 @@ def stratify( all_param_mappings[old_param].add(new_param) templates.append(stratified_template) - # Create new initial values for each of the strata - # of the original compartments, copied from the initial - # values of the original compartments + # Handle initial values and expressions depending on different + # criteria initials = {} for initial_key, initial in template_model.initials.items(): - if initial.concept.name in exclude_concepts: - initials[initial.concept.name] = deepcopy(initial) - continue + # We need to keep track of whether we stratified any parameters in + # the expression for this initial and if the parameter is being + # replaced by multiple stratified parameters any_param_stratified = False + param_replacements = defaultdict(set) + for stratum_idx, stratum in enumerate(strata): - new_concept = initial.concept.with_context( - do_rename=modify_names, - curie_to_name_map=strata_curie_to_name, - **{key: stratum}, - ) + # Figure out if the concept for this initial is one that we + # need to stratify or not + if (exclude_concepts and initial.concept.name in exclude_concepts) or \ + (concepts_to_preserve and initial.concept.name in concepts_to_preserve): + # Just make a copy of the original initial concept + new_concept = deepcopy(initial.concept) + concept_stratified = False + else: + # We create a new concept for the given stratum + new_concept = initial.concept.with_context( + do_rename=modify_names, + curie_to_name_map=strata_curie_to_name, + **{key: stratum}, + ) + concept_stratified = True + # Now we may have to rewrite the expression so that we can + # update for stratified parameters so we make a copy and figure + # out what parameters are in the expression new_expression = deepcopy(initial.expression) init_expr_params = template_model.get_parameters_from_expression( new_expression.args[0] @@ -292,16 +306,41 @@ def stratify( # where nothing was said about parameter stratification or the # parameter was listed explicitly to be stratified else: + # We create a new parameter symbol for the given stratum param_suffix = '_'.join([str(s) for s in template_strata]) new_param = f'{parameter}_{param_suffix}' + # If the concept is not stratified then we have to replace + # the original parameter with the sum of stratified ones + # so we just keep track of that in a set + any_param_stratified = True + if not concept_stratified: + param_replacements[parameter].add(new_param) + continue + # Otherwise we have to rewrite the expression to use the + # new parameter as replacement for the original one all_param_mappings[parameter].add(new_param) new_expression = new_expression.subs(parameter, sympy.Symbol(new_param)) - any_param_stratified = True - if not any_param_stratified: - new_initial = SympyExprStr(new_expression.args[0] / len(strata)) - else: + + # If we stratified any parameters in the expression then we have + # to update the initial value expression to reflect that + if any_param_stratified: + if param_replacements: + for orig_param, new_params in param_replacements.items(): + new_expression = new_expression.subs( + orig_param, + sympy.Add(*[sympy.Symbol(np) for np in new_params]) + ) new_initial = new_expression + # Otherwise we can just use the original expression, except if the + # concept was stratified, then we have to divide the initial + # expression into as many parts as there are strata + else: + if concept_stratified: + new_initial = SympyExprStr(new_expression.args[0] / len(strata)) + else: + new_initial = new_expression + initials[new_concept.name] = \ Initial(concept=new_concept, expression=new_initial) diff --git a/tests/test_ops.py b/tests/test_ops.py index 1f219f682..41479d0ab 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -647,9 +647,24 @@ def test_stratify_initials_parameters(): parameters={'alpha': Parameter(name='alpha', value=0.1), 'S0': Parameter(name='S0', value=1000)}, initials={'S': S0}) - tm = stratify(tm, key='age', strata=['young', 'old'], structure=[], - param_renaming_uses_strata_names=True) - assert 'S_young' in tm.initials - assert tm.initials['S_young'].expression.args[0] == sympy.Symbol('S0_young') - assert 'S_old' in tm.initials - assert tm.initials['S_old'].expression.args[0] == sympy.Symbol('S0_old') + tm1 = stratify(tm, key='age', strata=['young', 'old'], structure=[], + param_renaming_uses_strata_names=True) + assert 'S_young' in tm1.initials + assert tm1.initials['S_young'].expression.args[0] == sympy.Symbol('S0_young') + assert 'S_old' in tm1.initials + assert tm1.initials['S_old'].expression.args[0] == sympy.Symbol('S0_old') + + tm2 = stratify(tm, key='age', strata=['young', 'old'], structure=[], + param_renaming_uses_strata_names=True, + params_to_preserve={'S0'}) + assert 'S_young' in tm2.initials + assert tm2.initials['S_young'].expression.args[0] == sympy.Symbol('S0') / 2 + assert 'S_old' in tm2.initials + assert tm2.initials['S_old'].expression.args[0] == sympy.Symbol('S0') / 2 + + tm3 = stratify(tm, key='age', strata=['young', 'old'], structure=[], + param_renaming_uses_strata_names=True, + concepts_to_preserve={'S'}) + assert set(tm3.initials) == {'S'} + assert tm3.initials['S'].expression.args[0] == \ + sympy.Symbol('S0_old') + sympy.Symbol('S0_young') From a2ba072aea709aae5b3ee0a8e0a5a398cb0120b4 Mon Sep 17 00:00:00 2001 From: Ben Gyori Date: Fri, 3 Jan 2025 10:15:55 -0500 Subject: [PATCH 3/4] Implement update of parameter values --- mira/metamodel/ops.py | 18 +++++++++++++----- tests/test_ops.py | 9 +++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/mira/metamodel/ops.py b/mira/metamodel/ops.py index 649b80f18..609c4f533 100644 --- a/mira/metamodel/ops.py +++ b/mira/metamodel/ops.py @@ -261,6 +261,7 @@ def stratify( # Handle initial values and expressions depending on different # criteria initials = {} + param_value_mappings = {} for initial_key, initial in template_model.initials.items(): # We need to keep track of whether we stratified any parameters in # the expression for this initial and if the parameter is being @@ -309,18 +310,23 @@ def stratify( # We create a new parameter symbol for the given stratum param_suffix = '_'.join([str(s) for s in template_strata]) new_param = f'{parameter}_{param_suffix}' + any_param_stratified = True + all_param_mappings[parameter].add(new_param) + # We need to update the new, stratified parameter's value + # to be the original parameter's value divided by the number + # of strata + param_value_mappings[new_param] = \ + template_model.parameters[parameter].value / len(strata) # If the concept is not stratified then we have to replace # the original parameter with the sum of stratified ones # so we just keep track of that in a set - any_param_stratified = True if not concept_stratified: param_replacements[parameter].add(new_param) - continue # Otherwise we have to rewrite the expression to use the # new parameter as replacement for the original one - all_param_mappings[parameter].add(new_param) - new_expression = new_expression.subs(parameter, - sympy.Symbol(new_param)) + else: + new_expression = new_expression.subs(parameter, + sympy.Symbol(new_param)) # If we stratified any parameters in the expression then we have # to update the initial value expression to reflect that @@ -359,6 +365,8 @@ def stratify( for stratified_param in all_param_mappings[parameter_key]: d = deepcopy(parameter) d.name = stratified_param + if stratified_param in param_value_mappings: + d.value = param_value_mappings[stratified_param] parameters[stratified_param] = d observables = {} diff --git a/tests/test_ops.py b/tests/test_ops.py index 41479d0ab..f40f199b8 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -653,6 +653,10 @@ def test_stratify_initials_parameters(): assert tm1.initials['S_young'].expression.args[0] == sympy.Symbol('S0_young') assert 'S_old' in tm1.initials assert tm1.initials['S_old'].expression.args[0] == sympy.Symbol('S0_old') + assert 'S0_young' in tm1.parameters + assert tm1.parameters['S0_young'].value == 500 + assert 'S0_old' in tm1.parameters + assert tm1.parameters['S0_old'].value == 500 tm2 = stratify(tm, key='age', strata=['young', 'old'], structure=[], param_renaming_uses_strata_names=True, @@ -661,6 +665,8 @@ def test_stratify_initials_parameters(): assert tm2.initials['S_young'].expression.args[0] == sympy.Symbol('S0') / 2 assert 'S_old' in tm2.initials assert tm2.initials['S_old'].expression.args[0] == sympy.Symbol('S0') / 2 + assert 'S0' in tm2.parameters + assert tm2.parameters['S0'].value == 1000 tm3 = stratify(tm, key='age', strata=['young', 'old'], structure=[], param_renaming_uses_strata_names=True, @@ -668,3 +674,6 @@ def test_stratify_initials_parameters(): assert set(tm3.initials) == {'S'} assert tm3.initials['S'].expression.args[0] == \ sympy.Symbol('S0_old') + sympy.Symbol('S0_young') + assert set(tm3.parameters) == {'alpha', 'S0_old', 'S0_young'} + assert tm3.parameters['S0_old'].value == 500 + assert tm3.parameters['S0_young'].value == 500 From 59376f280ad2d9a5d45849769e3add6d37d6ef94 Mon Sep 17 00:00:00 2001 From: Ben Gyori Date: Fri, 3 Jan 2025 10:18:50 -0500 Subject: [PATCH 4/4] Add docstring for new method --- mira/metamodel/template_model.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/mira/metamodel/template_model.py b/mira/metamodel/template_model.py index 9cac17e2f..fb4874d2d 100644 --- a/mira/metamodel/template_model.py +++ b/mira/metamodel/template_model.py @@ -373,6 +373,22 @@ class TemplateModel(BaseModel): ) def get_parameters_from_expression(self, expression) -> Set[str]: + """Given a symbolic expression, find its elements that are model parameters. + + Expressions such as rate laws consist of some combination of participants, + rate parameters and potentially other factors. This function finds those + elements of expressions that are rate parameters. + + Parameters + ---------- + expression : sympy.Symbol | sympy.Expr + A sympy expression or symbol, whose parameters are extracted. + + Returns + ------- + : + A set of parameter names (as strings). + """ if expression is None: return set() params = set() @@ -401,7 +417,7 @@ def get_parameters_from_rate_law(self, rate_law) -> Set[str]: Parameters ---------- rate_law : sympy.Symbol | sympy.Expr - A sympy expression or symbol, whose names are extracted. + A sympy expression or symbol, whose parameters are extracted. Returns -------