diff --git a/benchmarks/work_precision_sets/time_vs_abstols.py b/benchmarks/work_precision_sets/time_vs_abstols.py index d680766c43..af76493abc 100644 --- a/benchmarks/work_precision_sets/time_vs_abstols.py +++ b/benchmarks/work_precision_sets/time_vs_abstols.py @@ -72,7 +72,7 @@ solver.solve(model, t_eval=t_eval) time = 0 runs = 20 - for k in range(0, runs): + for _ in range(0, runs): solution = solver.solve(model, t_eval=t_eval) time += solution.solve_time.value time = time / runs diff --git a/benchmarks/work_precision_sets/time_vs_dt_max.py b/benchmarks/work_precision_sets/time_vs_dt_max.py index a1f8ca06bc..c9979c4e47 100644 --- a/benchmarks/work_precision_sets/time_vs_dt_max.py +++ b/benchmarks/work_precision_sets/time_vs_dt_max.py @@ -76,7 +76,7 @@ solver.solve(model, t_eval=t_eval) time = 0 runs = 20 - for k in range(0, runs): + for _ in range(0, runs): solution = solver.solve(model, t_eval=t_eval) time += solution.solve_time.value time = time / runs diff --git a/benchmarks/work_precision_sets/time_vs_mesh_size.py b/benchmarks/work_precision_sets/time_vs_mesh_size.py index cbab18d16c..7b8ad525df 100644 --- a/benchmarks/work_precision_sets/time_vs_mesh_size.py +++ b/benchmarks/work_precision_sets/time_vs_mesh_size.py @@ -54,7 +54,7 @@ time = 0 runs = 20 - for k in range(0, runs): + for _ in range(0, runs): solution = sim.solve([0, 3500]) time += solution.solve_time.value time = time / runs diff --git a/benchmarks/work_precision_sets/time_vs_no_of_states.py b/benchmarks/work_precision_sets/time_vs_no_of_states.py index febc69f0a1..fdc039587f 100644 --- a/benchmarks/work_precision_sets/time_vs_no_of_states.py +++ b/benchmarks/work_precision_sets/time_vs_no_of_states.py @@ -54,7 +54,7 @@ time = 0 runs = 20 - for k in range(0, runs): + for _ in range(0, runs): solution = sim.solve([0, 3500]) time += solution.solve_time.value time = time / runs diff --git a/benchmarks/work_precision_sets/time_vs_reltols.py b/benchmarks/work_precision_sets/time_vs_reltols.py index 42e9a1bab1..4afcddf94d 100644 --- a/benchmarks/work_precision_sets/time_vs_reltols.py +++ b/benchmarks/work_precision_sets/time_vs_reltols.py @@ -78,7 +78,7 @@ solver.solve(model, t_eval=t_eval) time = 0 runs = 20 - for k in range(0, runs): + for _ in range(0, runs): solution = solver.solve(model, t_eval=t_eval) time += solution.solve_time.value time = time / runs diff --git a/docs/source/examples/notebooks/batch_study.ipynb b/docs/source/examples/notebooks/batch_study.ipynb index 0c0d216763..63169e6a07 100644 --- a/docs/source/examples/notebooks/batch_study.ipynb +++ b/docs/source/examples/notebooks/batch_study.ipynb @@ -199,7 +199,7 @@ "\n", "# changing the value of \"Current function [A]\" in all the parameter values present in the\n", "# parameter_values dictionary\n", - "for k, v, current_value in zip(\n", + "for _, v, current_value in zip(\n", " parameter_values.keys(), parameter_values.values(), current_values\n", "):\n", " v[\"Current function [A]\"] = current_value\n", @@ -505,7 +505,7 @@ "inner_sei_oc_v_values = [2.0e-4, 2.7e-4, 3.4e-4]\n", "\n", "# updating the value of \"Inner SEI open-circuit potential [V]\" in all the dictionary items\n", - "for k, v, inner_sei_oc_v in zip(\n", + "for _, v, inner_sei_oc_v in zip(\n", " parameter_values.keys(), parameter_values.values(), inner_sei_oc_v_values\n", "):\n", " v.update(\n", diff --git a/pybamm/citations.py b/pybamm/citations.py index c2930e5826..886437e6c2 100644 --- a/pybamm/citations.py +++ b/pybamm/citations.py @@ -92,7 +92,7 @@ def _add_citation(self, key, entry): # Warn if overwriting a previous citation new_citation = entry.to_string("bibtex") if key in self._all_citations and new_citation != self._all_citations[key]: - warnings.warn(f"Replacing citation for {key}") + warnings.warn(f"Replacing citation for {key}", stacklevel=2) # Add to database self._all_citations[key] = new_citation @@ -165,9 +165,9 @@ def _parse_citation(self, key): # Add to _papers_to_cite set self._papers_to_cite.add(key) return - except PybtexError: + except PybtexError as error: # Unable to parse / unknown key - raise KeyError(f"Not a bibtex citation or known citation: {key}") + raise KeyError(f"Not a bibtex citation or known citation: {key}") from error def _tag_citations(self): """Prints the citation tags for the citations that have been registered @@ -226,6 +226,7 @@ def print(self, filename=None, output_format="text", verbose=False): warnings.warn( message=f'\nCitation with key "{key}" is invalid. Please try again\n', category=UserWarning, + stacklevel=2, ) # delete the invalid citation from the set self._unknown_citations.remove(key) diff --git a/pybamm/discretisations/discretisation.py b/pybamm/discretisations/discretisation.py index 90814ea508..68c2e9f19a 100644 --- a/pybamm/discretisations/discretisation.py +++ b/pybamm/discretisations/discretisation.py @@ -280,7 +280,7 @@ def set_variable_slices(self, variables): sec_points = spatial_method._get_auxiliary_domain_repeats( variable.domains ) - for i in range(sec_points): + for _ in range(sec_points): for child, mesh in meshes.items(): for domain_mesh in mesh: end += domain_mesh.npts_for_broadcast_to_nodes @@ -886,14 +886,14 @@ def _process_symbol(self, symbol): # model.check_well_posedness, but won't be if debug_mode is False try: y_slices = self.y_slices[symbol] - except KeyError: + except KeyError as error: raise pybamm.ModelError( f""" No key set for variable '{symbol.name}'. Make sure it is included in either model.rhs or model.algebraic in an unmodified form (e.g. not Broadcasted) """ - ) + ) from error # Add symbol's reference and multiply by the symbol's scale # so that the state vector is of order 1 return symbol.reference + symbol.scale * pybamm.StateVector( diff --git a/pybamm/experiment/step/_steps_util.py b/pybamm/experiment/step/_steps_util.py index 6bf137ed1b..26f9fe815a 100644 --- a/pybamm/experiment/step/_steps_util.py +++ b/pybamm/experiment/step/_steps_util.py @@ -287,8 +287,8 @@ def _convert_electric(value_string): } try: typ = units_to_type[unit] - except KeyError: + except KeyError as error: raise ValueError( f"units must be 'A', 'V', 'W', 'Ohm', or 'C'. For example: {_examples}" - ) + ) from error return typ, value diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index a476649f05..ce47fc60a8 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -337,7 +337,7 @@ def create_slices(self, node: pybamm.Symbol) -> defaultdict: """Concatenation and children must have the same number of points in secondary dimensions""" ) - for i in range(second_pts): + for _ in range(second_pts): for dom in node.domain: end += self.full_mesh[dom].npts slices[dom].append(slice(start, end)) diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index f161c99b13..b19d1726b5 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -102,8 +102,8 @@ def _base_evaluate( try: input_eval = inputs[self.name] # raise more informative error if can't find name in dict - except KeyError: - raise KeyError(f"Input parameter '{self.name}' not found") + except KeyError as error: + raise KeyError(f"Input parameter '{self.name}' not found") from error if isinstance(input_eval, numbers.Number): input_size = 1 diff --git a/pybamm/expression_tree/operations/jacobian.py b/pybamm/expression_tree/operations/jacobian.py index 6348e1fdc0..fd31284703 100644 --- a/pybamm/expression_tree/operations/jacobian.py +++ b/pybamm/expression_tree/operations/jacobian.py @@ -90,10 +90,10 @@ def _jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol): else: try: jac = symbol._jac(variable) - except NotImplementedError: + except NotImplementedError as error: raise NotImplementedError( f"Cannot calculate Jacobian of symbol of type '{type(symbol)}'" - ) + ) from error # Jacobian by default removes the domain(s) if self._clear_domain: diff --git a/pybamm/expression_tree/operations/latexify.py b/pybamm/expression_tree/operations/latexify.py index f1c3734deb..e2817ed2d5 100644 --- a/pybamm/expression_tree/operations/latexify.py +++ b/pybamm/expression_tree/operations/latexify.py @@ -76,7 +76,7 @@ def _get_geometry_displays(self, var): rng_min = get_rng_min_max_name(rng, "min") # Take range maximum from the last domain - for var_name, rng in self.model.default_geometry[var.domain[-1]].items(): + for _, rng in self.model.default_geometry[var.domain[-1]].items(): rng_max = get_rng_min_max_name(rng, "max") geo_latex = f"\quad {rng_min} < {name} < {rng_max}" @@ -303,7 +303,8 @@ def latexify(self, output_variables=None): # When equations are too huge, set output resolution to default except RuntimeError: # pragma: no cover warnings.warn( - "RuntimeError - Setting the output resolution to default" + "RuntimeError - Setting the output resolution to default", + stacklevel=2, ) return sympy.preview( eqn_new_line, diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index cb53e6a787..466e5b7be6 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -7,7 +7,7 @@ import numpy as np import sympy from scipy.sparse import csr_matrix, issparse -from functools import lru_cache, cached_property +from functools import cached_property from typing import TYPE_CHECKING, Sequence, cast import pybamm @@ -876,7 +876,7 @@ def evaluate_ignoring_errors(self, t: float | None = 0): return None raise pybamm.ShapeError( f"Cannot find shape (original error: {error})" - ) # pragma: no cover + ) from error # pragma: no cover return result def evaluates_to_number(self): @@ -895,7 +895,6 @@ def evaluates_to_number(self): def evaluates_to_constant_number(self): return self.evaluates_to_number() and self.is_constant() - @lru_cache def evaluates_on_edges(self, dimension: str) -> bool: """ Returns True if a symbol evaluates on an edge, i.e. symbol contains a gradient @@ -914,9 +913,12 @@ def evaluates_on_edges(self, dimension: str) -> bool: Whether the symbol evaluates on edges (in the finite volume discretisation sense) """ - eval_on_edges = self._evaluates_on_edges(dimension) - self._saved_evaluates_on_edges[dimension] = eval_on_edges - return eval_on_edges + if dimension not in self._saved_evaluates_on_edges: + self._saved_evaluates_on_edges[dimension] = self._evaluates_on_edges( + dimension + ) + + return self._saved_evaluates_on_edges[dimension] def _evaluates_on_edges(self, dimension): # Default behaviour: return False @@ -1039,7 +1041,7 @@ def test_shape(self): try: self.shape_for_testing except ValueError as e: - raise pybamm.ShapeError(f"Cannot find shape (original error: {e})") + raise pybamm.ShapeError(f"Cannot find shape (original error: {e})") from e @property def print_name(self): diff --git a/pybamm/install_odes.py b/pybamm/install_odes.py index 0c93234623..b9f918eb85 100644 --- a/pybamm/install_odes.py +++ b/pybamm/install_odes.py @@ -47,8 +47,8 @@ def install_sundials(download_dir, install_dir): try: subprocess.run(["cmake", "--version"]) - except OSError: - raise RuntimeError("CMake must be installed to build SUNDIALS.") + except OSError as error: + raise RuntimeError("CMake must be installed to build SUNDIALS.") from error url = f"https://github.com/LLNL/sundials/releases/download/v{SUNDIALS_VERSION}/sundials-{SUNDIALS_VERSION}.tar.gz" logger.info("Downloading sundials") diff --git a/pybamm/meshes/meshes.py b/pybamm/meshes/meshes.py index 7fdcd0eede..3ec291b1b8 100644 --- a/pybamm/meshes/meshes.py +++ b/pybamm/meshes/meshes.py @@ -85,8 +85,8 @@ def __init__(self, geometry, submesh_types, var_pts): for spatial_variable, spatial_limits in geometry[domain].items(): # process tab information if using 1 or 2D current collectors if spatial_variable == "tabs": - for tab, position_size in spatial_limits.items(): - for position_size, sym in position_size.items(): + for tab, position_info in spatial_limits.items(): + for position_size, sym in position_info.items(): if isinstance(sym, pybamm.Symbol): sym_eval = sym.evaluate() geometry[domain]["tabs"][tab][position_size] = sym_eval @@ -102,7 +102,7 @@ def __init__(self, geometry, submesh_types, var_pts): "geometry. Make sure that something like " "`param.process_geometry(geometry)` has been " "run." - ) + ) from error else: raise error elif isinstance(sym, numbers.Number): diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index b8575c0f30..d0820334d6 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -590,7 +590,7 @@ def build_coupled_variables(self): f"Missing variable for submodel '{submodel_name}': {key}.\n" + "Check the selected " "submodels provide all of the required variables." - ) + ) from key else: # try setting coupled variables on next loop through pybamm.logger.debug( @@ -677,7 +677,7 @@ def set_initial_conditions_from(self, solution, inplace=True, return_type="model "model.initial_conditions must appear in the solution with " "the same key as the variable name. In the solution provided, " f"'{e.args[0]}' was not found." - ) + ) from e if isinstance(solution, pybamm.Solution): final_state = final_state.data if final_state.ndim == 0: @@ -701,7 +701,7 @@ def set_initial_conditions_from(self, solution, inplace=True, return_type="model "model.initial_conditions must appear in the solution with " "the same key as the variable name. In the solution " f"provided, {e.args[0]}" - ) + ) from e if isinstance(solution, pybamm.Solution): final_state = final_state.data if final_state.ndim == 2: @@ -873,8 +873,8 @@ def check_well_determined(self, post_discretisation): ] ) all_vars_in_eqns.update(vars_in_eqns) - for var, side_eqn in self.boundary_conditions.items(): - for side, (eqn, typ) in side_eqn.items(): + for _, side_eqn in self.boundary_conditions.items(): + for _, (eqn, _) in side_eqn.items(): vars_in_eqns = unpacker.unpack_symbol(eqn) all_vars_in_eqns.update(vars_in_eqns) @@ -1001,7 +1001,7 @@ def check_discretised_or_discretise_inplace_if_0D(self): raise pybamm.DiscretisationError( "Cannot automatically discretise model, model should be " f"discretised before exporting casadi functions ({e})" - ) + ) from e def export_casadi_objects(self, variable_names, input_parameter_order=None): """ @@ -1242,6 +1242,7 @@ def save_model(self, filename=None, mesh=None, variables=None): Plotting may not be available. """, pybamm.ModelWarning, + stacklevel=2, ) Serialise().save_model(self, filename=filename, mesh=mesh, variables=variables) diff --git a/pybamm/models/full_battery_models/lithium_ion/electrode_soh_half_cell.py b/pybamm/models/full_battery_models/lithium_ion/electrode_soh_half_cell.py index 4fc8c904b6..92cbc73dcb 100644 --- a/pybamm/models/full_battery_models/lithium_ion/electrode_soh_half_cell.py +++ b/pybamm/models/full_battery_models/lithium_ion/electrode_soh_half_cell.py @@ -129,9 +129,7 @@ def get_initial_stoichiometry_half_cell( return x -def get_min_max_stoichiometries( - parameter_values, options={"working electrode": "positive"} -): +def get_min_max_stoichiometries(parameter_values, options=None): """ Get the minimum and maximum stoichiometries from the parameter values @@ -139,7 +137,13 @@ def get_min_max_stoichiometries( ---------- parameter_values : pybamm.ParameterValues The parameter values to use in the calculation + options : dict, optional + A dictionary of options to be passed to the parameters, see + :class:`pybamm.BatteryModelOptions`. + If None, the default is used: {"working electrode": "positive"} """ + if options is None: + options = {"working electrode": "positive"} esoh_model = pybamm.lithium_ion.ElectrodeSOHHalfCell("ElectrodeSOH") param = pybamm.LithiumIonParameters(options) esoh_sim = pybamm.Simulation(esoh_model, parameter_values=parameter_values) diff --git a/pybamm/models/submodels/thermal/base_thermal.py b/pybamm/models/submodels/thermal/base_thermal.py index 808cdefc67..b530391bde 100644 --- a/pybamm/models/submodels/thermal/base_thermal.py +++ b/pybamm/models/submodels/thermal/base_thermal.py @@ -117,7 +117,7 @@ def _get_standard_coupled_variables(self, variables): # Total Ohmic heating Q_ohm = Q_ohm_s + Q_ohm_e - num_phases = int(getattr(self.options, "positive")["particle phases"]) + num_phases = int(self.options.positive["particle phases"]) phase_names = [""] if num_phases > 1: phase_names = ["primary ", "secondary "] @@ -135,7 +135,7 @@ def _get_standard_coupled_variables(self, variables): dUdT_p = variables[f"Positive electrode {phase}entropic change [V.K-1]"] Q_rev_p += a_j_p * T_p * dUdT_p - num_phases = int(getattr(self.options, "negative")["particle phases"]) + num_phases = int(self.options.negative["particle phases"]) phase_names = [""] if num_phases > 1: phase_names = ["primary", "secondary"] diff --git a/pybamm/parameters/base_parameters.py b/pybamm/parameters/base_parameters.py index a7b319ec81..c686665019 100644 --- a/pybamm/parameters/base_parameters.py +++ b/pybamm/parameters/base_parameters.py @@ -19,7 +19,9 @@ def __getattribute__(self, name): return super().__getattribute__(name) except AttributeError as e: if name == "cap_init": - warnings.warn("Parameter 'cap_init' has been renamed to 'Q_init'") + warnings.warn( + "Parameter 'cap_init' has been renamed to 'Q_init'", stacklevel=2 + ) return self.Q_init for domain in ["n", "s", "p"]: if f"_{domain}_" in name or name.endswith(f"_{domain}"): @@ -32,14 +34,14 @@ def __getattribute__(self, name): raise AttributeError( f"param.{name} does not exist. It has been renamed to " f"param.{domain}.{name_without_domain}" - ) + ) from e elif hasattr(self_domain, "prim") and hasattr( self_domain.prim, name_without_domain ): raise AttributeError( f"param.{name} does not exist. It has been renamed to " f"param.{domain}.prim.{name_without_domain}" - ) + ) from e else: raise e else: diff --git a/pybamm/parameters/bpx.py b/pybamm/parameters/bpx.py index e1c7fed43f..3d95ae4bb4 100644 --- a/pybamm/parameters/bpx.py +++ b/pybamm/parameters/bpx.py @@ -94,7 +94,7 @@ def _get_phase_names(domain): Return a list of the phase names in a given domain """ if isinstance(domain, (ElectrodeBlended, ElectrodeBlendedSPM)): - phases = len(getattr(domain, "particle").keys()) + phases = len(domain.particle.keys()) else: phases = 1 if phases == 1: @@ -468,7 +468,7 @@ def _bpx_to_domain_param_dict(instance: BPX, pybamm_dict: dict, domain: Domain) isinstance(instance, (ElectrodeBlended, ElectrodeBlendedSPM)) and name == "particle" ): - particle_instance = getattr(instance, "particle") + particle_instance = instance.particle # Loop over phases for i, phase_name in enumerate(particle_instance.keys()): phase_instance = particle_instance[phase_name] diff --git a/pybamm/parameters/parameter_sets.py b/pybamm/parameters/parameter_sets.py index e7872f8a28..4b8ccbca51 100644 --- a/pybamm/parameters/parameter_sets.py +++ b/pybamm/parameters/parameter_sets.py @@ -92,7 +92,7 @@ def __getattribute__(self, name): f"Parameter sets should be called directly by their name ({name}), " f"instead of via pybamm.parameter_sets (pybamm.parameter_sets.{name})." ) - warnings.warn(msg, DeprecationWarning) + warnings.warn(msg, DeprecationWarning, stacklevel=2) return name raise error diff --git a/pybamm/parameters/parameter_values.py b/pybamm/parameters/parameter_values.py index 5a2c61b1d5..f2dd9ec630 100644 --- a/pybamm/parameters/parameter_values.py +++ b/pybamm/parameters/parameter_values.py @@ -151,7 +151,7 @@ def __getitem__(self, key): "density for the lithium plating reaction in a porous negative " "electrode. To avoid this error, change your parameter file to use " "the new name." - ) + ) from err else: raise err @@ -251,7 +251,7 @@ def update(self, values, check_conflict=False, check_already_exists=True, path=" + f"have a default value. ({err.args[0]}). If you are " + "sure you want to update this parameter, use " + "param.update({{name: value}}, check_already_exists=False)" - ) + ) from err # if no conflicts, update if isinstance(value, str): if ( @@ -542,7 +542,7 @@ def process_boundary_conditions(self, model): pass # do raise error otherwise (e.g. can't process symbol) else: - raise KeyError(err) + raise err return new_boundary_conditions @@ -568,8 +568,8 @@ def process_and_check(sym): for spatial_variable, spatial_limits in geometry[domain].items(): # process tab information if using 1 or 2D current collectors if spatial_variable == "tabs": - for tab, position_size in spatial_limits.items(): - for position_size, sym in position_size.items(): + for tab, position_info in spatial_limits.items(): + for position_size, sym in position_info.items(): geometry[domain]["tabs"][tab][position_size] = ( process_and_check(sym) ) diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py index 7780f092bf..c43498c41b 100644 --- a/pybamm/plotting/quick_plot.py +++ b/pybamm/plotting/quick_plot.py @@ -222,10 +222,10 @@ def __init__( except KeyError: # if variable_tuple is not provided, default to "fixed" self.variable_limits[variable_tuple] = "fixed" - except TypeError: + except TypeError as error: raise TypeError( "variable_limits must be 'fixed', 'tight', or a dict" - ) + ) from error self.set_output_variables(output_variable_tuples, solutions) self.reset_axis() diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 493c0069b2..a2b260ab43 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -566,6 +566,7 @@ def solve( to be the points in the data. """, pybamm.SolverWarning, + stacklevel=2, ) dt_data_min = np.min(np.diff(time_data)) dt_eval_max = np.max(np.diff(t_eval)) @@ -580,6 +581,7 @@ def solve( points in the data. """, pybamm.SolverWarning, + stacklevel=2, ) self._solution = solver.solve(self.built_model, t_eval, **kwargs) diff --git a/pybamm/solvers/algebraic_solver.py b/pybamm/solvers/algebraic_solver.py index fed2f4a1c0..bc711ff02a 100644 --- a/pybamm/solvers/algebraic_solver.py +++ b/pybamm/solvers/algebraic_solver.py @@ -101,7 +101,7 @@ def algebraic(t, y): integration_time = 0 for idx, t in enumerate(t_eval): - def root_fun(y_alg): + def root_fun(y_alg, t=t): "Evaluates algebraic using y" y = np.concatenate([y0_diff, y_alg]) out = algebraic(t, y) @@ -114,7 +114,7 @@ def root_fun(y_alg): if jac: if issparse(jac(t_eval[0], y0, inputs)): - def jac_fn(y_alg): + def jac_fn(y_alg, jac=jac): """ Evaluates Jacobian using y0_diff (fixed) and y_alg (varying) """ @@ -123,7 +123,7 @@ def jac_fn(y_alg): else: - def jac_fn(y_alg): + def jac_fn(y_alg, jac=jac): """ Evaluates Jacobian using y0_diff (fixed) and y_alg (varying) """ @@ -168,7 +168,7 @@ def root_norm(y): jac_norm = None else: - def jac_norm(y): + def jac_norm(y, jac_fn=jac_fn): return np.sum(2 * root_fun(y) * jac_fn(y), 0) if self.method == "minimize": diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 5a636b89bb..a2b4c305c2 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -51,7 +51,7 @@ def __init__( root_method=None, root_tol=1e-6, extrap_tol=None, - output_variables=[], + output_variables=None, ): self.method = method self.rtol = rtol @@ -59,7 +59,7 @@ def __init__( self.root_tol = root_tol self.root_method = root_method self.extrap_tol = extrap_tol or -1e-10 - self.output_variables = output_variables + self.output_variables = [] if output_variables is None else output_variables self._model_set_up = {} # Defaults, can be overwritten by specific solver @@ -339,7 +339,7 @@ def _check_and_prepare_model_inplace(self, model, inputs, ics_only): raise pybamm.DiscretisationError( "Cannot automatically discretise model, " f"model should be discretised before solving ({e})" - ) + ) from e if ( isinstance(self, (pybamm.CasadiSolver, pybamm.CasadiAlgebraicSolver)) @@ -684,7 +684,9 @@ def calculate_consistent_state(self, model, time=0, inputs=None): try: root_sol = self.root_method._integrate(model, np.array([time]), inputs) except pybamm.SolverError as e: - raise pybamm.SolverError(f"Could not find consistent states: {e.args[0]}") + raise pybamm.SolverError( + f"Could not find consistent states: {e.args[0]}" + ) from e pybamm.logger.debug("Found consistent states") self.check_extrapolation(root_sol, model.events) @@ -1354,6 +1356,7 @@ def check_extrapolation(self, solution, events): f"While solving {name} extrapolation occurred " f"for {extrap_events}", pybamm.SolverWarning, + stacklevel=2, ) # Add the event dictionaryto the solution object solution.extrap_events = extrap_events diff --git a/pybamm/solvers/casadi_algebraic_solver.py b/pybamm/solvers/casadi_algebraic_solver.py index 40cb2130ca..4c7f1dd290 100644 --- a/pybamm/solvers/casadi_algebraic_solver.py +++ b/pybamm/solvers/casadi_algebraic_solver.py @@ -110,7 +110,7 @@ def _integrate(self, model, t_eval, inputs_dict=None): timer = pybamm.Timer() integration_time = 0 - for idx, t in enumerate(t_eval): + for _, t in enumerate(t_eval): # Solve try: timer.reset() diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index 02ff4a2cd9..daa5233624 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -276,7 +276,9 @@ def _integrate(self, model, t_eval, inputs_dict=None): "time steps or period of the experiment." ) if first_ts_solved and self.return_solution_if_failed_early: - warnings.warn(message, pybamm.SolverWarning) + warnings.warn( + message, pybamm.SolverWarning, stacklevel=2 + ) termination_due_to_small_dt = True break else: @@ -285,7 +287,7 @@ def _integrate(self, model, t_eval, inputs_dict=None): + " Set `return_solution_if_failed_early=True` to " "return the solution object up to the point where " "failure occured." - ) + ) from error if termination_due_to_small_dt: break # Check if the sign of an event changes, if so find an accurate @@ -360,7 +362,7 @@ def find_t_event(sol, typ): # Evaluations of the "event" function are (relatively) expensive f_eval = {} - def f(idx): + def f(idx, f_eval=f_eval, event=event): try: return f_eval[idx] except KeyError: @@ -682,7 +684,7 @@ def _run_integrator( except RuntimeError as error: # If it doesn't work raise error pybamm.logger.debug(f"Casadi integrator failed with error {error}") - raise pybamm.SolverError(error.args[0]) + raise pybamm.SolverError(error.args[0]) from error pybamm.logger.debug("Finished casadi integrator") integration_time = timer.time() # Manually add initial conditions and concatenate @@ -720,7 +722,7 @@ def _run_integrator( except RuntimeError as error: # If it doesn't work raise error pybamm.logger.debug(f"Casadi integrator failed with error {error}") - raise pybamm.SolverError(error.args[0]) + raise pybamm.SolverError(error.args[0]) from error integration_time = timer.time() x = casadi_sol["xf"] z = casadi_sol["zf"] diff --git a/pybamm/solvers/idaklu_jax.py b/pybamm/solvers/idaklu_jax.py index a213eb41fa..9a5315e120 100644 --- a/pybamm/solvers/idaklu_jax.py +++ b/pybamm/solvers/idaklu_jax.py @@ -344,10 +344,10 @@ class _hashabledict(dict): def __hash__(self): return hash(tuple(sorted(self.items()))) - @lru_cache(maxsize=1) + @lru_cache(maxsize=1) # noqa: B019 def _cached_solve(self, model, t_hashable, *args, **kwargs): """Cache the last solve for reuse""" - return self.solver.solve(model, t_hashable, *args, **kwargs) + return self.solve(model, t_hashable, *args, **kwargs) def _jaxify_solve(self, t, invar, *inputs_values): """Solve the model using the IDAKLU solver @@ -370,7 +370,8 @@ def _jaxify_solve(self, t, invar, *inputs_values): logger.debug(f" invar: {invar}") logger.debug(f" inputs: {dict(d)}") logger.debug(f" calculate_sensitivities: {invar is not None}") - sim = self._cached_solve( + sim = IDAKLUJax._cached_solve( + self.solver, self.jax_model, tuple(self.jax_t_eval), inputs=self._hashabledict(d), @@ -572,6 +573,7 @@ def jaxify( "JAX expression has already been created. " "Overwriting with new expression.", UserWarning, + stacklevel=2, ) self.jaxpr = self._jaxify( model, diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index e9976fc28c..fef4cbce3c 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -89,7 +89,7 @@ def __init__( root_method="casadi", root_tol=1e-6, extrap_tol=None, - output_variables=[], + output_variables=None, options=None, ): # set default options, @@ -112,7 +112,7 @@ def __init__( options[key] = value self._options = options - self.output_variables = output_variables + self.output_variables = [] if output_variables is None else output_variables if idaklu_spec is None: # pragma: no cover raise ImportError("KLU is not installed") @@ -649,7 +649,7 @@ def _integrate(self, model, t_eval, inputs_dict=None): number_of_samples = sol.y.shape[0] // number_of_timesteps sol.y = sol.y.reshape((number_of_timesteps, number_of_samples)) startk = 0 - for vark, var in enumerate(self.output_variables): + for _, var in enumerate(self.output_variables): # ExplicitTimeIntegral's are not computed as part of the solver and # do not need to be converted if isinstance( diff --git a/pybamm/solvers/processed_variable_computed.py b/pybamm/solvers/processed_variable_computed.py index fd17dfab7b..a069342254 100644 --- a/pybamm/solvers/processed_variable_computed.py +++ b/pybamm/solvers/processed_variable_computed.py @@ -149,8 +149,10 @@ def unroll_1D(self, realdata=None): .transpose() ) - def unroll_2D(self, realdata=None, n_dim1=None, n_dim2=None, axis_swaps=[]): + def unroll_2D(self, realdata=None, n_dim1=None, n_dim2=None, axis_swaps=None): # initialise settings on first run + if axis_swaps is None: + axis_swaps = [] if not self.unroll_params: self.unroll_params["n_dim1"] = n_dim1 self.unroll_params["n_dim2"] = n_dim2 diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py index 735dd8e396..b456bfb2d0 100644 --- a/pybamm/solvers/solution.py +++ b/pybamm/solvers/solution.py @@ -297,11 +297,11 @@ def set_y(self): self._y = casadi.horzcat(*self.all_ys) else: self._y = np.hstack(self.all_ys) - except ValueError: + except ValueError as error: raise pybamm.SolverError( "The solution is made up from different models, so `y` cannot be " "computed explicitly." - ) + ) from error def check_ys_are_not_too_large(self): # Only check last one so that it doesn't take too long @@ -978,11 +978,11 @@ def _get_cycle_summary_variables(cycle_solution, esoh_solver): try: esoh_sol = esoh_solver.solve(inputs) - except pybamm.SolverError: # pragma: no cover + except pybamm.SolverError as error: # pragma: no cover raise pybamm.SolverError( "Could not solve for summary variables, run " "`sim.solve(calc_esoh=False)` to skip this step" - ) + ) from error cycle_summary_variables.update(esoh_sol) diff --git a/pybamm/spatial_methods/scikit_finite_element.py b/pybamm/spatial_methods/scikit_finite_element.py index e65e29f7f8..e212ef71f7 100644 --- a/pybamm/spatial_methods/scikit_finite_element.py +++ b/pybamm/spatial_methods/scikit_finite_element.py @@ -275,10 +275,10 @@ def stiffness_form(u, v, w): try: _, neg_bc_type = boundary_conditions[symbol]["negative tab"] _, pos_bc_type = boundary_conditions[symbol]["positive tab"] - except KeyError: + except KeyError as error: raise pybamm.ModelError( f"No boundary conditions provided for symbol `{symbol}``" - ) + ) from error # adjust matrix for Dirichlet boundary conditions if neg_bc_type == "Dirichlet": diff --git a/pybamm/util.py b/pybamm/util.py index b9d0fc5124..1ecd7dbaba 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -58,12 +58,13 @@ def get_best_matches(self, key): def __getitem__(self, key): try: return super().__getitem__(key) - except KeyError: + except KeyError as error: if "particle diffusivity" in key: warn( f"The parameter '{key.replace('particle', 'electrode')}' " f"has been renamed to '{key}'", DeprecationWarning, + stacklevel=2, ) return super().__getitem__(key.replace("particle", "electrode")) if key in ["Negative electrode SOC", "Positive electrode SOC"]: @@ -72,7 +73,7 @@ def __getitem__(self, key): f"Variable '{domain} electrode SOC' has been renamed to " f"'{domain} electrode stoichiometry' to avoid confusion " "with cell SOC" - ) + ) from error if "Measured open circuit voltage" in key: raise KeyError( "The variable that used to be called " @@ -81,26 +82,28 @@ def __getitem__(self, key): "variable called 'Bulk open-circuit voltage [V]' which is the" "open-circuit voltage evaluated at the average particle " "concentrations." - ) + ) from error if "Open-circuit voltage at 0% SOC [V]" in key: raise KeyError( "Parameter 'Open-circuit voltage at 0% SOC [V]' not found." "In most cases this should be set to be equal to " "'Lower voltage cut-off [V]'" - ) + ) from error if "Open-circuit voltage at 100% SOC [V]" in key: raise KeyError( "Parameter 'Open-circuit voltage at 100% SOC [V]' not found." "In most cases this should be set to be equal to " "'Upper voltage cut-off [V]'" - ) + ) from error best_matches = self.get_best_matches(key) for k in best_matches: if key in k and k.endswith("]"): raise KeyError( f"'{key}' not found. Use the dimensional version '{k}' instead." - ) - raise KeyError(f"'{key}' not found. Best matches are {best_matches}") + ) from error + raise KeyError( + f"'{key}' not found. Best matches are {best_matches}" + ) from error def search(self, key, print_values=False): """ @@ -341,7 +344,7 @@ def install_jax(arguments=None): # pragma: no cover "pybamm_install_jax is deprecated," " use 'pip install pybamm[jax]' to install jax & jaxlib" ) - warn(msg, DeprecationWarning) + warn(msg, DeprecationWarning, stacklevel=2) subprocess.check_call( [ sys.executable, @@ -369,5 +372,7 @@ def import_optional_dependency(module_name, attribute=None): else: # Return the entire module if no attribute is specified return module - except ModuleNotFoundError: - raise ModuleNotFoundError(err_msg) + + except ModuleNotFoundError as error: + # Raise an ModuleNotFoundError if the module or attribute is not available + raise ModuleNotFoundError(err_msg) from error diff --git a/pyproject.toml b/pyproject.toml index 192d0645c7..b5f875d49b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,7 +174,7 @@ extend-exclude = ["__init__.py"] [tool.ruff.lint] extend-select = [ - # "B", # flake8-bugbear + "B", # flake8-bugbear # "I", # isort # "ARG", # flake8-unused-arguments # "C4", # flake8-comprehensions diff --git a/scripts/install_KLU_Sundials.py b/scripts/install_KLU_Sundials.py index cd521f431f..9929a250a2 100755 --- a/scripts/install_KLU_Sundials.py +++ b/scripts/install_KLU_Sundials.py @@ -260,12 +260,12 @@ def parallel_download(urls, download_dir): # First check requirements: make and cmake try: subprocess.run(["make", "--version"]) -except OSError: - raise RuntimeError("Make must be installed.") +except OSError as error: + raise RuntimeError("Make must be installed.") from error try: subprocess.run(["cmake", "--version"]) -except OSError: - raise RuntimeError("CMake must be installed.") +except OSError as error: + raise RuntimeError("CMake must be installed.") from error # Build in parallel wherever possible os.environ["CMAKE_BUILD_PARALLEL_LEVEL"] = str(cpu_count()) diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index 43eba8894e..ac06f95c9f 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -170,7 +170,7 @@ def test_serialisation(self, solver=None, t_eval=None): new_solution = new_solver.solve(new_model, t_eval) - for x, val in enumerate(self.solution.all_ys): + for x, _ in enumerate(self.solution.all_ys): np.testing.assert_array_almost_equal( new_solution.all_ys[x], self.solution.all_ys[x], decimal=accuracy ) diff --git a/tests/shared.py b/tests/shared.py index 1f0b033582..c863cac175 100644 --- a/tests/shared.py +++ b/tests/shared.py @@ -153,8 +153,10 @@ def get_2p1d_mesh_for_testing( ypts=15, zpts=15, include_particles=True, - cc_submesh=pybamm.MeshGenerator(pybamm.ScikitUniform2DSubMesh), + cc_submesh=None, ): + if cc_submesh is None: + cc_submesh = pybamm.MeshGenerator(pybamm.ScikitUniform2DSubMesh) geometry = pybamm.battery_geometry( include_particles=include_particles, options={"dimensionality": 2} ) diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 0bc9138de7..b180c86990 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -121,7 +121,7 @@ def test_symbol_methods(self): self.assertIsInstance(-a, pybamm.Negate) self.assertIsInstance(abs(a), pybamm.AbsoluteValue) # special cases - self.assertEqual(-(-a), a) + self.assertEqual(-(-a), a) # noqa: B002 self.assertEqual(-(a - b), b - a) self.assertEqual(abs(abs(a)), abs(a)) diff --git a/tests/unit/test_parameters/test_base_parameters.py b/tests/unit/test_parameters/test_base_parameters.py index 5ca21b1c7c..6c87cdcd88 100644 --- a/tests/unit/test_parameters/test_base_parameters.py +++ b/tests/unit/test_parameters/test_base_parameters.py @@ -12,23 +12,23 @@ def test_getattr__(self): param = pybamm.LithiumIonParameters() # ending in _n / _s / _p with self.assertRaisesRegex(AttributeError, "param.n.L"): - getattr(param, "L_n") + param.L_n with self.assertRaisesRegex(AttributeError, "param.s.L"): - getattr(param, "L_s") + param.L_s with self.assertRaisesRegex(AttributeError, "param.p.L"): - getattr(param, "L_p") + param.L_p # _n_ in the name with self.assertRaisesRegex(AttributeError, "param.n.prim.c_max"): - getattr(param, "c_n_max") + param.c_n_max # _n_ or _p_ not in name with self.assertRaisesRegex( AttributeError, "has no attribute 'c_n_not_a_parameter" ): - getattr(param, "c_n_not_a_parameter") + param.c_n_not_a_parameter with self.assertRaisesRegex(AttributeError, "has no attribute 'c_s_test"): - getattr(pybamm.electrical_parameters, "c_s_test") + pybamm.electrical_parameters.c_s_test self.assertEqual(param.n.cap_init, param.n.Q_init) self.assertEqual(param.p.prim.cap_init, param.p.prim.Q_init) diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index c0c1bd22cb..e7dcba6702 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -139,7 +139,7 @@ def test_user_defined_model_recreaction(self): new_solver = pybamm.ScipySolver() new_solution = new_solver.solve(new_model, t) - for x, val in enumerate(solution.all_ys): + for x, _ in enumerate(solution.all_ys): np.testing.assert_array_almost_equal( solution.all_ys[x], new_solution.all_ys[x] ) diff --git a/tests/unit/test_solvers/test_idaklu_jax.py b/tests/unit/test_solvers/test_idaklu_jax.py index c2d29c31ae..b1311001aa 100644 --- a/tests/unit/test_solvers/test_idaklu_jax.py +++ b/tests/unit/test_solvers/test_idaklu_jax.py @@ -862,7 +862,7 @@ def sse(t, inputs): # Check grad against actual sse_grad_actual = {} - for k, v in inputs_pred.items(): + for k, _ in inputs_pred.items(): sse_grad_actual[k] = 2 * np.sum( (pred(t_eval) - data) * pred.sensitivities[k] ) diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index c6b482e3f9..1a84f2bea4 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -125,7 +125,7 @@ def test_solver_sensitivities(self): solve = solver.get_solve(model, t_eval) # create a dummy "model" where we calculate the sum of the time series - def solve_model(rate): + def solve_model(rate, solve=solve): return jax.numpy.sum(solve({"rate": rate})) # check answers with finite difference diff --git a/tests/unit/test_solvers/test_processed_variable.py b/tests/unit/test_solvers/test_processed_variable.py index d8b4ccfd0c..4cf3f9392e 100644 --- a/tests/unit/test_solvers/test_processed_variable.py +++ b/tests/unit/test_solvers/test_processed_variable.py @@ -28,9 +28,11 @@ def to_casadi(var_pybamm, y, inputs=None): def process_and_check_2D_variable( - var, first_spatial_var, second_spatial_var, disc=None, geometry_options={} + var, first_spatial_var, second_spatial_var, disc=None, geometry_options=None ): # first_spatial_var should be on the "smaller" domain, i.e "r" for an "r-x" variable + if geometry_options is None: + geometry_options = {} if disc is None: disc = tests.get_discretisation_for_testing() disc.set_variable_slices([var]) diff --git a/tests/unit/test_solvers/test_processed_variable_computed.py b/tests/unit/test_solvers/test_processed_variable_computed.py index b5f105b34b..7e0616c81b 100644 --- a/tests/unit/test_solvers/test_processed_variable_computed.py +++ b/tests/unit/test_solvers/test_processed_variable_computed.py @@ -32,9 +32,11 @@ def to_casadi(var_pybamm, y, inputs=None): def process_and_check_2D_variable( - var, first_spatial_var, second_spatial_var, disc=None, geometry_options={} + var, first_spatial_var, second_spatial_var, disc=None, geometry_options=None ): # first_spatial_var should be on the "smaller" domain, i.e "r" for an "r-x" variable + if geometry_options is None: + geometry_options = {} if disc is None: disc = tests.get_discretisation_for_testing() disc.set_variable_slices([var])