"""Parameter storage, metadata and comparison utilities (pybamm.parameters.parameter_store)."""

from __future__ import annotations

import re
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Literal
from warnings import warn

import pybamm

if TYPE_CHECKING:
    from collections.abc import Mapping


class ParameterCategory(Enum):
    """
    Categories used to group battery parameters.

    Each member's value is a lower-case label string (e.g. ``"thermal"``).
    """

    # Cell regions / components
    NEGATIVE_ELECTRODE = "negative electrode"
    POSITIVE_ELECTRODE = "positive electrode"
    SEPARATOR = "separator"
    ELECTROLYTE = "electrolyte"

    # Physical aspects
    THERMAL = "thermal"
    KINETIC = "kinetic"
    GEOMETRIC = "geometric"
    ELECTRICAL = "electrical"

    # Catch-all for parameters that fit none of the above
    OTHER = "other"
@dataclass
class ParameterInfo:
    """
    Descriptive metadata for a single stored parameter.

    Attributes
    ----------
    name : str
        Key under which the parameter is stored.
    value : Any
        The stored value itself.
    units : str | None
        Units text taken from the parameter name (e.g. ``"K"`` for
        ``"Temperature [K]"``), or None when the name carries no units.
    category : str | None
        Category label for the parameter, such as ``"negative electrode"``
        or ``"thermal"``.
    is_function : bool
        Whether ``value`` is callable.
    is_input : bool
        Whether ``value`` is an InputParameter.
    """

    name: str  # parameter key
    value: Any  # stored value
    units: str | None  # units parsed from the name, if any
    category: str | None  # category label
    is_function: bool  # value is callable
    is_input: bool  # value is an InputParameter
@dataclass
class ParameterDiff:
    """
    Outcome of comparing one parameter set against another.

    Attributes
    ----------
    added : dict[str, Any]
        Entries present only in `other`.
    removed : dict[str, Any]
        Entries present only in `self`.
    changed : dict[str, tuple[Any, Any]]
        Entries present in both but with differing values, stored as
        ``(self_value, other_value)``.
    """

    added: dict[str, Any]  # key -> value from `other`
    removed: dict[str, Any]  # key -> value from `self`
    changed: dict[str, tuple[Any, Any]]  # key -> (self_value, other_value)
# Regex to parse units from parameter names like "Temperature [K]": captures the
# text inside a bracket pair that ends the name (trailing whitespace allowed).
_UNITS_RE = re.compile(r"\[([^\]]+)\]\s*$")

# Keywords for category detection. Categories are tried in insertion order and
# the first one with a keyword occurring in the lower-cased name wins, so the
# region categories (electrodes, separator, electrolyte) take precedence over
# the physics categories.
#
# "conductivity" is deliberately listed only under ELECTRICAL. When it was also
# under THERMAL (which is checked first), the ELECTRICAL entry could never
# match. Electronic conductivities such as "Negative current collector
# conductivity [S.m-1]" were then labelled "thermal". Thermal conductivities
# are still caught by the "thermal" keyword, e.g.
# "Negative current collector thermal conductivity [W.m-1.K-1]".
_CATEGORY_KEYWORDS: dict[ParameterCategory, list[str]] = {
    ParameterCategory.NEGATIVE_ELECTRODE: ["negative electrode", "negative particle"],
    ParameterCategory.POSITIVE_ELECTRODE: ["positive electrode", "positive particle"],
    ParameterCategory.SEPARATOR: ["separator"],
    ParameterCategory.ELECTROLYTE: ["electrolyte"],
    ParameterCategory.THERMAL: [
        "thermal",
        "temperature",
        "heat",
        "cooling",
    ],
    ParameterCategory.KINETIC: [
        "exchange-current",
        "reaction",
        "kinetic",
        "transfer coefficient",
    ],
    ParameterCategory.GEOMETRIC: [
        "thickness",
        "length",
        "width",
        "height",
        "radius",
        "area",
        "volume",
        "porosity",
    ],
    ParameterCategory.ELECTRICAL: [
        "conductivity",
        "current",
        "voltage",
        "capacity",
        "resistance",
    ],
}


def _parse_units(name: str) -> str | None:
    """
    Extract units from a parameter name like 'Temperature [K]' -> 'K'.

    Returns None when the name does not end in a ``[...]`` suffix.
    """
    match = _UNITS_RE.search(name)
    return match.group(1) if match else None


def _detect_category(name: str) -> str:
    """
    Detect the category of a parameter based on keywords in its name.

    Matching is a case-insensitive substring test. Categories are checked in
    ``_CATEGORY_KEYWORDS`` order and the first hit wins.

    Returns
    -------
    str
        The matching category's value, or ``"other"`` when no keyword matches.
        (Never None.)
    """
    name_lower = name.lower()
    for category, keywords in _CATEGORY_KEYWORDS.items():
        if any(keyword in name_lower for keyword in keywords):
            return category.value
    return ParameterCategory.OTHER.value
class ParameterStore:
    """
    Manages parameter key-value storage with FuzzyDict lookup.

    This class provides a clean interface for storing and retrieving
    parameters with explicit control over update behavior.

    Parameters
    ----------
    initial_data : dict | None
        Initial parameter data. If None, starts empty.

    Examples
    --------
    >>> store = ParameterStore({"Temperature [K]": 298.15})
    >>> store["Temperature [K]"]
    298.15
    >>> store.set("New param", 42, allow_new=True)
    >>> store.get_info("Temperature [K]")  # doctest: +ELLIPSIS
    ParameterInfo(name='Temperature [K]', value=298.15, units='K', ...)
    """

    def __init__(self, initial_data: dict[str, Any] | None = None) -> None:
        # ``or {}`` turns None (and any empty mapping) into a fresh empty dict.
        self._data: pybamm.FuzzyDict = pybamm.FuzzyDict(initial_data or {})

    def __getitem__(self, key: str) -> Any:
        """Get a parameter value by key."""
        try:
            return self._data[key]
        except KeyError as e:
            # Re-raise with more context
            raise KeyError(
                f"Parameter '{key}' not found. {e.args[0] if e.args else ''}"
            ) from e

    def __setitem__(self, key: str, value: Any) -> None:
        """Set a parameter value (always allows new parameters)."""
        self._data[key] = value

    def __delitem__(self, key: str) -> None:
        """Delete a parameter."""
        del self._data[key]

    def __contains__(self, key: str) -> bool:
        """Check if a parameter exists."""
        return key in self._data

    def __iter__(self) -> Iterator[str]:
        """Iterate over parameter keys."""
        return iter(self._data)

    def __len__(self) -> int:
        """Return the number of parameters."""
        return len(self._data)

    def _missing_key_error(self, key: str) -> KeyError:
        """Build the KeyError raised when ``allow_new=False`` and ``key`` is absent."""
        best_matches = self._data.get_best_matches(key)
        return KeyError(
            f"Parameter '{key}' does not exist. "
            f"Use allow_new=True to add new parameters. "
            f"Best matches: {best_matches}"
        )

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a parameter value, returning default if not found.

        Example
        -------
        >>> store = ParameterStore({"a": 1})
        >>> store.get("a")
        1
        >>> store.get("missing", default=0)
        0
        """
        try:
            return self._data[key]
        except KeyError:
            return default

    def set(self, key: str, value: Any, *, allow_new: bool = True) -> None:
        """
        Set a parameter value.

        Parameters
        ----------
        key : str
            Parameter name.
        value : Any
            Parameter value.
        allow_new : bool
            If False, raises KeyError when key doesn't exist.
            If True (default), allows adding new parameters.

        Raises
        ------
        KeyError
            If allow_new=False and the key doesn't exist.

        Example
        -------
        >>> store = ParameterStore({"a": 1})
        >>> store.set("a", 10)  # Update existing
        >>> store.set("b", 2)  # Add new (allow_new=True by default)
        """
        if not allow_new and key not in self._data:
            raise self._missing_key_error(key)
        self._data[key] = value

    def update(
        self,
        values: Mapping[str, Any],
        *,
        allow_new: bool = True,
        conflict: Literal["raise", "warn", "ignore"] = "ignore",
    ) -> None:
        """
        Bulk update parameters.

        The update is all-or-nothing. Every entry is validated before any
        value is written, so if a KeyError or ValueError is raised, the store
        is left unchanged.

        Parameters
        ----------
        values : Mapping[str, Any]
            Dictionary of parameter values to update.
        allow_new : bool
            If False, raises KeyError for unknown parameters.
            If True (default), allows adding new parameters.
        conflict : {"raise", "warn", "ignore"}
            How to handle conflicts when a parameter already exists with a
            different value:
            - "raise": Raise ValueError
            - "warn": Emit a warning and update
            - "ignore": Silently update (default)

        Raises
        ------
        KeyError
            If allow_new=False and a key doesn't exist.
        ValueError
            If conflict="raise" and an existing value would change, or if
            ``conflict`` is not one of the accepted options.

        Example
        -------
        >>> store = ParameterStore({"a": 1, "b": 2})
        >>> store.update({"a": 10, "c": 3})
        >>> store["a"], store["c"]
        (10, 3)
        """
        if conflict not in ("raise", "warn", "ignore"):
            raise ValueError(
                f"conflict must be 'raise', 'warn' or 'ignore', got {conflict!r}"
            )

        # Snapshot so that validation and application see the same entries.
        items = list(values.items())

        # Pass 1: validate everything before mutating anything.
        pending_warnings: list[str] = []
        for key, value in items:
            if not allow_new and key not in self._data:
                raise self._missing_key_error(key)
            if conflict != "ignore" and key in self._data:
                existing = self._data[key]
                # _values_equal copes with numpy arrays. A bare ``!=`` on an
                # array is element-wise and cannot be used in an ``if``.
                if not _values_equal(existing, value):
                    msg = (
                        f"Parameter '{key}' already exists with value "
                        f"'{existing}', updating to '{value}'"
                    )
                    if conflict == "raise":
                        raise ValueError(msg)
                    pending_warnings.append(msg)

        # Pass 2: report conflicts, then apply.
        for msg in pending_warnings:
            warn(msg, stacklevel=2)
        for key, value in items:
            self._data[key] = value

    def keys(self):
        """Return parameter keys."""
        return self._data.keys()

    def values(self):
        """Return parameter values."""
        return self._data.values()

    def items(self):
        """Return parameter items."""
        return self._data.items()

    def pop(self, key: str, *args) -> Any:
        """
        Remove and return a parameter value.

        An optional second positional argument is forwarded as the default,
        as with ``dict.pop``.

        Example
        -------
        >>> store = ParameterStore({"a": 1, "b": 2})
        >>> store.pop("a")
        1
        >>> "a" in store
        False
        """
        return self._data.pop(key, *args)

    def copy(self) -> ParameterStore:
        """
        Return a shallow copy of the store.

        Example
        -------
        >>> store = ParameterStore({"a": 1})
        >>> store_copy = store.copy()
        >>> store_copy["a"] = 99
        >>> store["a"]  # Original unchanged
        1
        """
        return ParameterStore(dict(self._data))

    def search(self, key: str, print_values: bool = True) -> None:
        """
        Search for parameters containing the given key.

        Delegates to the underlying FuzzyDict's ``search`` and returns its result.

        Example
        -------
        >>> store = ParameterStore({"Temperature [K]": 298.15, "Voltage [V]": 3.7})
        >>> store.search("Temperature", print_values=False)
        Results for 'Temperature': ...
        """
        return self._data.search(key, print_values)

    def get_info(self, key: str) -> ParameterInfo:
        """
        Get metadata about a parameter.

        Parameters
        ----------
        key : str
            The parameter name.

        Returns
        -------
        ParameterInfo
            Metadata including value, units, category, and type information.

        Raises
        ------
        KeyError
            If the parameter does not exist.

        Examples
        --------
        >>> store = ParameterStore({"Maximum concentration [mol.m-3]": 51765})
        >>> info = store.get_info("Maximum concentration [mol.m-3]")
        >>> info.units
        'mol.m-3'
        >>> info.is_function
        False
        """
        value = self[key]
        return ParameterInfo(
            name=key,
            value=value,
            units=_parse_units(key),
            category=_detect_category(key),
            is_function=callable(value),
            is_input=isinstance(value, pybamm.InputParameter),
        )

    def list_by_category(self, category: ParameterCategory | str) -> list[str]:
        """
        Return all parameter names in a given category.

        Parameters
        ----------
        category : ParameterCategory or str
            The category to filter by. Can be a ParameterCategory enum value
            or a string like "negative electrode" (compared case-insensitively).

        Returns
        -------
        list[str]
            List of parameter names in the category.

        Example
        -------
        >>> store = ParameterStore({"Negative electrode thickness [m]": 1e-4})
        >>> store.list_by_category("negative electrode")
        ['Negative electrode thickness [m]']
        """
        if isinstance(category, ParameterCategory):
            category_str = category.value
        else:
            category_str = category.lower()
        return [key for key in self._data if _detect_category(key) == category_str]

    def categories(self) -> dict[str, list[str]]:
        """
        Return all parameters grouped by category.

        Returns
        -------
        dict[str, list[str]]
            Dictionary mapping category names to lists of parameter names.
            Only categories with at least one parameter appear.

        Example
        -------
        >>> store = ParameterStore({"Temperature [K]": 298.15})
        >>> cats = store.categories()
        >>> "thermal" in cats
        True
        """
        result: dict[str, list[str]] = {}
        for key in self._data:
            cat = _detect_category(key) or ParameterCategory.OTHER.value
            result.setdefault(cat, []).append(key)
        return result

    def diff(self, other: ParameterStore, *, rtol: float = 0.0) -> ParameterDiff:
        """
        Compare this store with another and return differences.

        Parameters
        ----------
        other : ParameterStore
            The other parameter store to compare against.
        rtol : float, optional
            Relative tolerance for numerical comparisons. Differences smaller
            than ``rtol * max(|a|, |b|)`` are ignored. Default is 0.0 (exact
            comparison). Set to e.g. 1e-6 to ignore tiny floating-point
            differences.

        Returns
        -------
        ParameterDiff
            Object containing added, removed, and changed parameters. Entries
            in each dict follow the insertion order of the store they come
            from (``other`` for ``added``, ``self`` otherwise).

        Examples
        --------
        >>> store1 = ParameterStore({"a": 1, "b": 2})
        >>> store2 = ParameterStore({"b": 3, "c": 4})
        >>> diff = store1.diff(store2)
        >>> diff.added
        {'c': 4}
        >>> diff.removed
        {'a': 1}
        >>> diff.changed
        {'b': (2, 3)}

        With tolerance to ignore small differences:

        >>> store1 = ParameterStore({"x": 1.0})
        >>> store2 = ParameterStore({"x": 1.0 + 1e-10})
        >>> diff = store1.diff(store2, rtol=1e-9)
        >>> diff.changed  # Empty because difference is within tolerance
        {}
        """
        self_keys = set(self._data.keys())
        other_keys = set(other._data.keys())

        # Iterate the stores themselves (insertion order), not the key sets.
        # Set order depends on string hashing, which is randomized per
        # interpreter run, so it would make the result order non-reproducible.
        added = {
            k: other._data[k] for k in other._data.keys() if k not in self_keys
        }
        removed = {
            k: self._data[k] for k in self._data.keys() if k not in other_keys
        }

        changed: dict[str, tuple[Any, Any]] = {}
        for key in self._data.keys():
            if key not in other_keys:
                continue
            self_val = self._data[key]
            other_val = other._data[key]
            if not _values_equal(self_val, other_val, rtol=rtol):
                changed[key] = (self_val, other_val)

        return ParameterDiff(added=added, removed=removed, changed=changed)

    def to_dict(self) -> dict[str, Any]:
        """
        Return a plain dictionary copy of the parameters.

        Example
        -------
        >>> store = ParameterStore({"a": 1, "b": 2})
        >>> store.to_dict()
        {'a': 1, 'b': 2}
        """
        return dict(self._data)
def _values_equal(a: Any, b: Any, *, rtol: float = 0.0) -> bool: """ Compare two values for equality, handling special cases. Parameters ---------- a, b : Any Values to compare. rtol : float Relative tolerance for numerical comparisons. """ import numpy as np # Handle numpy arrays if isinstance(a, np.ndarray) or isinstance(b, np.ndarray): try: if rtol > 0: return np.allclose(a, b, rtol=rtol, atol=0) return np.array_equal(a, b) except (TypeError, ValueError): return False # Handle numeric types with tolerance if isinstance(a, (int, float)) and isinstance(b, (int, float)): if rtol > 0: # Check if difference is within relative tolerance max_val = max(abs(a), abs(b)) if max_val == 0: return a == b return abs(a - b) <= rtol * max_val return a == b # Handle callables - compare by identity if callable(a) and callable(b): return a is b # Handle pybamm symbols if isinstance(a, pybamm.Symbol) and isinstance(b, pybamm.Symbol): return a == b # Default comparison try: return a == b except (TypeError, ValueError): return False