# Source code for pybamm.util

from __future__ import annotations

import difflib
import importlib.metadata
import importlib.util
import numbers
import os
import pathlib
import pickle
import timeit
from warnings import warn

import pybamm


def root_dir():
    """Return, as a string, the directory two levels above the installed ``pybamm`` package directory."""
    package_dir = pathlib.Path(pybamm.__path__[0])
    return str(package_dir.parent.parent)
[docs] class FuzzyDict(dict):
[docs] def get_best_matches(self, key): """Get best matches from keys""" return difflib.get_close_matches(key, list(self.keys()), n=3, cutoff=0.5)
def __getitem__(self, key): try: return super().__getitem__(key) except KeyError as error: if "electrode diffusivity" in key or "particle diffusivity" in key: old_term, new_term = ( ("electrode", "particle") if "electrode diffusivity" in key else ("particle", "electrode") ) alternative_key = key.replace(old_term, new_term) if old_term == "electrode": warn( f"The parameter '{alternative_key}' has been renamed to '{key}' and will be removed in a future release. Using '{key}'", DeprecationWarning, stacklevel=2, ) return super().__getitem__(alternative_key) if key in ["Negative electrode SOC", "Positive electrode SOC"]: domain = key.split(" ")[0] raise KeyError( f"Variable '{domain} electrode SOC' has been renamed to " f"'{domain} electrode stoichiometry' to avoid confusion " "with cell SOC" ) from error if "Measured open circuit voltage" in key: raise KeyError( "The variable that used to be called " "'Measured open circuit voltage [V]' is now called " "'Surface open-circuit voltage [V]'. There is also another " "variable called 'Bulk open-circuit voltage [V]' which is the" "open-circuit voltage evaluated at the average particle " "concentrations." ) from error if "Open-circuit voltage at 0% SOC [V]" in key: raise KeyError( "Parameter 'Open-circuit voltage at 0% SOC [V]' not found." "In most cases this should be set to be equal to " "'Lower voltage cut-off [V]'" ) from error if "Open-circuit voltage at 100% SOC [V]" in key: raise KeyError( "Parameter 'Open-circuit voltage at 100% SOC [V]' not found." "In most cases this should be set to be equal to " "'Upper voltage cut-off [V]'" ) from error best_matches = self.get_best_matches(key) for k in best_matches: if key in k and k.endswith("]") and not key.endswith("]"): raise KeyError( f"'{key}' not found. Use the dimensional version '{k}' instead." ) from error elif key in k and ( k.startswith("Primary") or k.startswith("Secondary") ): raise KeyError( f"'{key}' not found. 
If you are using a composite model, you may need to use {k} instead. Otherwise, best matches are {best_matches}" ) from error raise KeyError( f"'{key}' not found. Best matches are {best_matches}" ) from error def _find_matches( self, search_key: str, known_keys: list[str], min_similarity: float = 0.4 ): """ Helper method to find exact and partial matches for a given search key. Parameters ---------- search_key : str The term to search for in the keys. known_keys : list of str The list of known dictionary keys to search within. min_similarity : float, optional The minimum similarity threshold for a match. Default is 0.4 """ search_key = search_key.lower() exact_matches = [] partial_matches = [] for key in known_keys: key_lower = key.lower() if search_key in key_lower: key_words = key_lower.split() for word in key_words: similarity = difflib.SequenceMatcher(None, search_key, word).ratio() if similarity >= min_similarity: exact_matches.append(key) else: partial_matches = difflib.get_close_matches( search_key, known_keys, n=5, cutoff=0.5 ) return exact_matches, partial_matches
[docs] def search( self, keys: str | list[str], print_values: bool = False, min_similarity: float = 0.4, ): """ Search dictionary for keys containing all terms in 'keys'. If print_values is True, both the keys and values will be printed. Otherwise, just the keys will be printed. If no results are found, the best matches are printed. Parameters ---------- keys : str or list of str Search term(s) print_values : bool, optional If True, print both keys and values. Otherwise, print only keys. Default is False. min_similarity : float, optional The minimum similarity threshold for a match. Default is 0.4 """ if not isinstance(keys, str | list) or not all( isinstance(k, str) for k in keys ): msg = f"'keys' must be a string or a list of strings, got {type(keys)}" raise TypeError(msg) if isinstance(keys, str): if not keys.strip(): msg = "The search term cannot be an empty or whitespace-only string" raise ValueError(msg) original_keys = [keys] search_keys = [keys.strip().lower()] elif isinstance(keys, list): if all(not str(k).strip() for k in keys): msg = "The 'keys' list cannot contain only empty or whitespace strings" raise ValueError(msg) original_keys = keys search_keys = [k.strip().lower() for k in keys if k.strip()] known_keys = list(self.keys()) # Check for exact matches where all search keys appear together in a key exact_matches = [] for key in known_keys: key_lower = key.lower() if all(term in key_lower for term in search_keys): key_words = key_lower.split() # Ensure all search terms match at least one word in the key if all( any( difflib.SequenceMatcher(None, term, word).ratio() >= min_similarity for word in key_words ) for term in search_keys ): exact_matches.append(key) if exact_matches: print( f"Results for '{' '.join(k for k in original_keys if k.strip())}': {exact_matches}" ) if print_values: for match in exact_matches: print(f"{match} -> {self[match]}") return # If no exact matches, iterate over search keys individually for original_key, search_key in 
zip(original_keys, search_keys, strict=False): exact_key_matches, partial_matches = self._find_matches( search_key, known_keys, min_similarity ) if exact_key_matches: print(f"Exact matches for '{original_key}': {exact_key_matches}") if print_values: for match in exact_key_matches: print(f"{match} -> {self[match]}") else: if partial_matches: print( f"No exact matches found for '{original_key}'. Best matches are: {partial_matches}" ) else: print(f"No matches found for '{original_key}'")
[docs] def copy(self): return FuzzyDict(super().copy())
class Timer:
    """
    Provides accurate timing.

    Example
    -------
    timer = pybamm.Timer()
    print(timer.time())
    """

    def __init__(self):
        # Start time, read from timeit.default_timer
        self._start = timeit.default_timer()

    def reset(self):
        """
        Resets this timer's start time.
        """
        self._start = timeit.default_timer()

    def time(self):
        """
        Returns the time elapsed since this timer was created, or since
        :meth:`reset` was last called.

        Returns
        -------
        TimerTime
            Wrapper whose ``value`` attribute is the elapsed time in seconds
            (a float); printing it gives a human-readable string.
        """
        return TimerTime(timeit.default_timer() - self._start)
class TimerTime:
    """
    A duration in seconds that prints in human-readable form.

    Supports ``+``, ``-``, ``*`` and ``/`` with plain numbers (treated as
    seconds) and with other objects exposing a ``value`` attribute.
    """

    def __init__(self, value):
        """Store ``value``, the duration in seconds."""
        self.value = value

    def __str__(self):
        """
        Formats a number of seconds as a human-readable string.

        Durations under a minute use the most suitable unit with three
        decimals, e.g. "1.900 ms" or "12.345 s". Longer durations are rounded
        to whole seconds and broken into units, e.g.
        "5 weeks, 3 days, 1 hour, 4 minutes, 9 seconds".
        """
        time = self.value
        if time < 1e-6:
            return f"{time * 1e9:.3f} ns"
        if time < 1e-3:
            return f"{time * 1e6:.3f} us"
        if time < 1:
            return f"{time * 1e3:.3f} ms"
        elif time < 60:
            return f"{time:.3f} s"
        output = []
        time = round(time)
        units = [(604800, "week"), (86400, "day"), (3600, "hour"), (60, "minute")]
        for k, name in units:
            f = time // k
            # Once a larger unit has been emitted, smaller ones are always shown
            if f > 0 or output:
                output.append(str(f) + " " + (name if f == 1 else name + "s"))
                time -= f * k
        output.append("1 second" if time == 1 else str(time) + " seconds")
        return ", ".join(output)

    def __repr__(self):
        return f"pybamm.TimerTime({self.value})"

    def __add__(self, other):
        if isinstance(other, numbers.Number):
            return TimerTime(self.value + other)
        else:
            return TimerTime(self.value + other.value)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, numbers.Number):
            return TimerTime(self.value - other)
        else:
            return TimerTime(self.value - other.value)

    def __rsub__(self, other):
        if isinstance(other, numbers.Number):
            return TimerTime(other - self.value)
        # Let Python raise the usual TypeError instead of returning None
        return NotImplemented

    def __mul__(self, other):
        if isinstance(other, numbers.Number):
            return TimerTime(self.value * other)
        else:
            return TimerTime(self.value * other.value)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        if isinstance(other, numbers.Number):
            return TimerTime(self.value / other)
        else:
            return TimerTime(self.value / other.value)

    def __rtruediv__(self, other):
        if isinstance(other, numbers.Number):
            return TimerTime(other / self.value)
        # Let Python raise the usual TypeError instead of returning None
        return NotImplemented

    def __eq__(self, other):
        # Plain numbers are compared as seconds, matching the arithmetic above
        if isinstance(other, numbers.Number):
            return self.value == other
        try:
            other_value = other.value
        except AttributeError:
            # Unrelated objects (e.g. None) compare unequal instead of raising
            return NotImplemented
        return self.value == other_value
def load(filename):
    """
    Load and return an object previously saved with :mod:`pickle`.

    Warning: unpickling can execute arbitrary code, so only load files that
    come from a trusted source.
    """
    with open(filename, "rb") as handle:
        return pickle.load(handle)
def get_parameters_filepath(path):
    """
    Resolve ``path`` for reading parameters.

    If ``path`` exists relative to the current working directory (or is an
    existing absolute path), it is returned unchanged; otherwise it is joined
    onto the installed ``pybamm`` package directory.
    """
    if not os.path.exists(path):
        return os.path.join(pybamm.__path__[0], path)
    return path
def has_jax():
    """
    Check whether jax and jaxlib are installed.

    Only the presence of both packages is checked (via
    :func:`importlib.util.find_spec`, without importing them); their versions
    are not inspected.

    Returns
    -------
    bool
        True if both jax and jaxlib can be found, False otherwise
    """
    return (importlib.util.find_spec("jax") is not None) and (
        importlib.util.find_spec("jaxlib") is not None
    )
def is_constant_and_can_evaluate(symbol):
    """
    Returns True if symbol is constant and evaluation does not raise any errors.
    Returns False otherwise.
    An example of a constant symbol that cannot be "evaluated" is PrimaryBroadcast(0).

    Only ``NotImplementedError`` from ``symbol.evaluate()`` is treated as
    "cannot evaluate"; any other exception propagates to the caller.
    """
    if symbol.is_constant():
        try:
            symbol.evaluate()
            return True
        except NotImplementedError:
            return False
    else:
        return False


def import_optional_dependency(module_name, attribute=None):
    """
    Import an optional dependency, or a single attribute of it.

    Parameters
    ----------
    module_name : str
        Name of the module to import (passed to :func:`importlib.import_module`).
    attribute : str, optional
        If given (and truthy), return this attribute of the module instead of
        the module itself.

    Returns
    -------
    module or object
        The imported module, or the requested attribute of it.

    Raises
    ------
    ModuleNotFoundError
        If the module cannot be imported, or it has no attribute ``attribute``.
        The message points to the optional-dependency installation docs.
    """
    err_msg = f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."
    try:
        module = importlib.import_module(module_name)
        if not attribute:
            # Return the entire module if no attribute is specified
            return module
        # Attribute access stays inside the try: lazily-loading packages may
        # raise ModuleNotFoundError here, which is reported the same way.
        if hasattr(module, attribute):
            # Return the imported attribute
            return getattr(module, attribute)
    except ModuleNotFoundError as error:
        # Raise an ModuleNotFoundError if the module or attribute is not available
        raise ModuleNotFoundError(err_msg) from error
    # Raised outside the try so it is not caught and re-wrapped by the handler
    # above (which would chain a duplicate exception with the same message).
    raise ModuleNotFoundError(err_msg)  # pragma: no cover