Skip to content

Commit

Permalink
Initial Implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
RohitP2005 committed Jan 14, 2025
1 parent 8b33589 commit 5024850
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 31 deletions.
9 changes: 9 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,19 @@ def run_tests(session):
set_environment_variables(PYBAMM_ENV, session=session)
session.install("setuptools", silent=False)
session.install("-e", ".[all,dev,jax]", silent=False)
warning_flags = [
"-W",
"error::DeprecationWarning",
"-W",
"error::PendingDeprecationWarning",
"-W",
"error::FutureWarning",
]
session.run(
"python",
"-m",
"pytest",
*warning_flags,
*(session.posargs if session.posargs else ["-m", "unit or integration"]),
)

Expand Down
21 changes: 13 additions & 8 deletions src/pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#
from __future__ import annotations
import numbers
import warnings

import numpy as np
import sympy
Expand All @@ -15,6 +14,7 @@
import pybamm
from pybamm.util import import_optional_dependency
from pybamm.expression_tree.printing.print_name import prettify_print_name
from utils import deprecate_function

if TYPE_CHECKING: # pragma: no cover
import casadi
Expand Down Expand Up @@ -356,6 +356,9 @@ def domain(self, domain):
)

@property
@deprecate_function(
msg="symbol.auxiliary_domains has been deprecated, use symbol.domains instead"
)
def auxiliary_domains(self):
"""Returns auxiliary domains."""
raise NotImplementedError(
Expand Down Expand Up @@ -994,18 +997,20 @@ def create_copy(
children = self._children_for_copying(new_children)
return self.__class__(self.name, children, domains=self.domains)

# Assuming `deprecate_function` is imported from deprecation_decorators

@deprecate_function(
version="2.0.0",
msg="The 'new_copy' function for expression tree symbols is deprecated, use 'create_copy' instead.",
)
def new_copy(
self,
new_children: list[Symbol] | None = None,
perform_simplifications: bool = True,
):
""" """
warnings.warn(
"The 'new_copy' function for expression tree symbols is deprecated, use "
"'create_copy' instead.",
DeprecationWarning,
stacklevel=2,
)
"""
This function is deprecated. Use 'create_copy' instead.
"""
return self.create_copy(new_children, perform_simplifications)

@cached_property
Expand Down
13 changes: 10 additions & 3 deletions src/pybamm/parameters/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pprint import pformat
from warnings import warn
from collections import defaultdict
from utils import deprecate_multiple_renamedParameters


class ParameterValues:
Expand Down Expand Up @@ -417,18 +418,24 @@ def set_initial_ocps(
return parameter_values

@staticmethod
@deprecate_multiple_renamedParameters(
{
"propotional term": "... proportional term [s-1]",
"1 + dlnf/dlnc": "Thermodynamic factor",
"electrode diffusivity": "particle diffusivity",
}
)
def check_parameter_values(values):
for param in list(values.keys()):
if "propotional term" in param:
raise ValueError(
f"The parameter '{param}' has been renamed to "
"'... proportional term [s-1]', and its value should now be divided"
"by 3600 to get the same results as before."
" by 3600 to get the same results as before."
)
# specific check for renamed parameter "1 + dlnf/dlnc"
if "1 + dlnf/dlnc" in param:
raise ValueError(
f"parameter '{param}' has been renamed to 'Thermodynamic factor'"
f"The parameter '{param}' has been renamed to 'Thermodynamic factor'"
)
if "electrode diffusivity" in param:
new_param = param.replace("electrode", "particle")
Expand Down
12 changes: 11 additions & 1 deletion src/pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datetime import timedelta
import pybamm.telemetry
from pybamm.util import import_optional_dependency
from utils import deprecate_function

from pybamm.expression_tree.operations.serialise import Serialise

Expand Down Expand Up @@ -171,7 +172,11 @@ def _set_random_seed(self):
% (2**32)
)

@deprecate_function(
msg="pybamm.simulation.set_up_and_parameterise_experiment is deprecated and not meant to be accessed by users."
)
def set_up_and_parameterise_experiment(self, solve_kwargs=None):
"""Sets up and parameterizes the experiment."""
msg = "pybamm.simulation.set_up_and_parameterise_experiment is deprecated and not meant to be accessed by users."
warnings.warn(msg, DeprecationWarning, stacklevel=2)
self._set_up_and_parameterise_experiment(solve_kwargs=solve_kwargs)
Expand Down Expand Up @@ -254,12 +259,17 @@ def _set_up_and_parameterise_experiment(self, solve_kwargs=None):
parameterised_model
)

@deprecate_function(
version="2.0.0",
msg="pybamm.set_parameters is deprecated and not meant to be accessed by users.",
)
def set_parameters(self):
"""Sets the parameters."""
msg = (
"pybamm.set_parameters is deprecated and not meant to be accessed by users."
)
warnings.warn(msg, DeprecationWarning, stacklevel=2)
self._set_parameters()
self._set_parameters() # Call the internal method

def _set_parameters(self):
"""
Expand Down
30 changes: 11 additions & 19 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pybamm
from pybamm.expression_tree.binary_operators import _Heaviside
from pybamm import ParameterValues
from utils import deprecated_params


class BaseSolver:
Expand Down Expand Up @@ -1162,6 +1163,14 @@ def process_t_interp(self, t_interp):

return t_interp

@deprecated_params(
{
"npts": (
"t_eval",
"The 'npts' parameter is deprecated, use 't_eval' instead.",
)
}
)
def step(
self,
old_solution,
Expand Down Expand Up @@ -1199,12 +1208,6 @@ def step(
Save solution with all previous timesteps. Defaults to True.
calculate_sensitivities : list of str or bool, optional
Whether the solver calculates sensitivities of all input parameters. Defaults to False.
If only a subset of sensitivities are required, can also pass a
list of input parameter names. **Limitations**: sensitivities are not calculated up to numerical tolerances
so are not guarenteed to be within the tolerances set by the solver, please raise an issue if you
require this functionality. Also, when using this feature with `pybamm.Experiment`, the sensitivities
do not take into account the movement of step-transitions wrt input parameters, so do not use this feature
if the timings of your experimental protocol change rapidly with respect to your input parameters.
t_interp : None, list or ndarray, optional
The times (in seconds) at which to interpolate the solution. Defaults to None.
Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`).
Expand All @@ -1223,31 +1226,20 @@ def step(
or old_solution.termination == "final time"
or "[experiment]" in old_solution.termination
):
# Return same solution as an event has already been triggered
# With hack to allow stepping past experiment current / voltage cut-off
return old_solution

# Make sure model isn't empty
self._check_empty_model(model)

# Make sure dt is greater than zero
if dt <= 0:
raise pybamm.SolverError("Step time must be >0")

# Raise deprecation warning for npts and convert it to t_eval
if npts is not None:
warnings.warn(
"The 'npts' parameter is deprecated, use 't_eval' instead.",
DeprecationWarning,
stacklevel=2,
)
t_eval = np.linspace(0, dt, npts)
elif t_eval is None:
if t_eval is None:
t_eval = np.array([0, dt])
elif t_eval[0] != 0 or t_eval[-1] != dt:
raise pybamm.SolverError(
"Elements inside array t_eval must lie in the closed interval 0 to dt"
)

else:
pass

Expand Down
1 change: 1 addition & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .decorators import *
102 changes: 102 additions & 0 deletions utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import functools
import warnings
from .exceptions import DeprecatedFunctionWarning


def deprecate_function(func=None, version=None, msg=None):
"""
A decorator to mark a function as deprecated.
Parameters:
- func: The function to decorate. If no function is provided, the decorator will be used as a factory.
- version: The version in which the function was deprecated.
- msg: Custom message to display alongside the deprecation warning.
If no message is provided, a default message is used.
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Construct the warning message
message = (
msg
or f"Function '{func.__name__}' is deprecated and will be removed in future versions."
)
if version:
message += f" Deprecated since version {version}."

# Raise the deprecation warning
warnings.warn(message, category=DeprecatedFunctionWarning, stacklevel=2)

return func(*args, **kwargs)

return wrapper

# If no function is passed (when used as a decorator with parameters)
if func is None:
return decorator
else:
return decorator(func)


def deprecate_multiple_renamedParameters(param_dict: dict):
"""
Decorator to handle deprecated parameter names dynamically, issuing warnings when old
parameter names are found and replacing them with the new ones provided in the param_dict.
param_dict is a dictionary mapping old parameter names to new parameter names.
"""

def decorator(func):
@functools.wraps(func)
def wrapper(self, values, *args, **kwargs):
# Iterate over the old-to-new parameter mapping
for old_param, new_param in param_dict.items():
if old_param in values:
# Issue a deprecation warning and update the parameter
warnings.warn(
f"The parameter '{old_param}' has been renamed to '{new_param}'",
DeprecationWarning,
stacklevel=2,
)
values[new_param] = values.pop(old_param)
# Call the original function with the updated values
return func(self, values, *args, **kwargs)

return wrapper

return decorator


def deprecated_params(param_map):
"""
A decorator to handle deprecated parameters in a function.
Parameters
----------
param_map : dict
A dictionary mapping deprecated parameter names to a tuple containing the
new parameter name (or None if it's removed) and a message explaining the deprecation.
Example
-------
@handle_deprecated_params({
'npts': ('t_eval', "The 'npts' parameter is deprecated, use 't_eval' instead.")
})
def step(...):
...
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
for deprecated_param, (new_param, message) in param_map.items():
if deprecated_param in kwargs:
warnings.warn(message, DeprecationWarning, stacklevel=2)
if new_param:
kwargs[new_param] = kwargs.pop(deprecated_param)
return func(*args, **kwargs)

return wrapper

return decorator
Loading

0 comments on commit 5024850

Please sign in to comment.