from __future__ import annotations
import base64
import importlib
import inspect
import json
import numbers
import re
import warnings
import zlib
from datetime import datetime
from enum import Enum
from pathlib import Path
import black
import numpy as np
import pybamm
from pybamm.expression_tree.tracing import tracing
# Schema version accepted by the ``load_*`` helpers on :class:`Serialise`
# (e.g. ``load_custom_geometry``, ``load_spatial_methods``, ``load_var_pts``).
# Data declaring any other ``schema_version`` is rejected with ValueError;
# data with no ``schema_version`` field is treated as this version.
SUPPORTED_SCHEMA_VERSION = "1.1"
def _experiment_step_factories() -> dict:
    """Map each serialised experiment-step ``"type"`` string to its factory.

    The returned dict is the single source of truth shared by
    :meth:`Serialise.deserialise_experiment` and the Hypothesis step strategy
    in ``tests/strategies/serialise_values.py``. Sharing it keeps the fuzzed
    step types in lockstep with the set the serialiser supports.

    This is a function rather than a module-level constant because the
    ``pybamm.step`` factories are not yet bound when this module is first
    imported during ``pybamm`` initialisation. The attribute lookups must
    therefore wait until call time.
    """
    step_module = pybamm.step
    # (serialised type string, attribute name on pybamm.step)
    type_to_attribute = (
        ("current", "current"),
        ("voltage", "voltage"),
        ("power", "power"),
        ("c-rate", "c_rate"),
        ("rest", "rest"),
        ("resistance", "resistance"),
    )
    return {
        step_type: getattr(step_module, attribute)
        for step_type, attribute in type_to_attribute
    }
class ExpressionFunctionParameter(pybamm.UnaryOperator):
    """Symbolic node marking ``child`` as the body of a named Python function.

    The node is a pass-through unary operator: evaluating it yields the
    child's value unchanged. Its extra attributes let :meth:`to_source`
    re-emit the wrapped expression as a ``def`` statement.

    Parameters
    ----------
    name : str
        Name of this node, forwarded to :class:`pybamm.UnaryOperator`.
    child : :class:`pybamm.Symbol`
        The expression that forms the function body.
    func_name : str
        Name given to the function emitted by :meth:`to_source`.
    func_args : list of str
        Argument names of the emitted function. A ``pybamm.Parameter`` whose
        name appears here is emitted as a bare name rather than being
        relabelled as ``Parameter("...")``.
    """

    def __init__(self, name, child, func_name, func_args):
        super().__init__(name, child)
        self.func_name = func_name
        self.func_args = func_args

    def _unary_evaluate(self, child):
        """Evaluate the symbolic expression (the child)"""
        # Identity: this node only annotates the child, it does not transform it.
        return child

    def to_json(self):
        """Return this node's own JSON fields.

        The child expression is not written here. :meth:`_from_json` reads it
        from ``snippet["children"][0]``, which is presumably populated by the
        generic symbol-serialisation code (TODO confirm).
        """
        return {
            "name": self.name,
            "domains": self.domains,
            "func_name": self.func_name,
            "func_args": self.func_args,
        }

    @classmethod
    def _from_json(cls, snippet):
        """Rebuild the node from a JSON snippet.

        ``snippet["children"][0]`` is passed straight through as ``child``, so
        it is presumably already a deserialised symbol by the time this runs
        (TODO confirm against the caller). ``snippet["domains"]`` is not read
        by this method.
        """
        return cls(
            snippet["name"],
            snippet["children"][0],
            snippet["func_name"],
            snippet["func_args"],
        )

    def to_source(self):
        """
        Creates python source code for the function.

        The result is a black-formatted string containing a ``def`` named
        ``func_name`` that takes ``func_args`` and returns the child
        expression, rendered with ``to_equation()``. Before rendering, nodes
        in a *copy* of the child are relabelled so they print as constructor
        calls:

        * ``pybamm.Interpolant`` becomes ``pybamm.Interpolant([...], np.array(...),
          ...)`` with its data arrays inlined;
        * ``pybamm.FunctionParameter`` becomes ``FunctionParameter("name", {...})``;
        * ``pybamm.Parameter`` not listed in ``func_args`` becomes
          ``Parameter("name")``.

        The emitted text refers to ``pybamm``, ``np``, ``FunctionParameter``
        and ``Parameter``. Those names must be in scope wherever the
        generated source is executed.
        """
        src = f"def {self.func_name}({', '.join(self.func_args)}):\n"
        # Fix printing of parameters so they print as Parameter('name'). Do this on a
        # copy to avoid modifying the original expression.
        expression = self.child.create_copy()
        for child in expression.pre_order():
            if isinstance(child, pybamm.Interpolant):
                # Replace Interpolant with a constructor call that preserves all data
                # This works for 1D, 2D, and 3D interpolants
                # Format x arrays (list of arrays, one per dimension)
                x_arrays = ", ".join(f"np.array({x.tolist()})" for x in child.x)
                # Format y array (can be 1D, 2D, or 3D)
                y_array = f"np.array({child.y.tolist()})"
                # Get the input variable names from children (one per dimension)
                if len(child.children) == 1:
                    # Single child - pass directly without brackets
                    input_vars = child.children[0].name
                else:
                    # Multiple children - pass as list
                    input_vars = "[" + ", ".join(c.name for c in child.children) + "]"
                # Build the full Interpolant constructor call
                # Set _print_name directly to bypass prettify_print_name which
                # mangles the output for LaTeX display
                child._print_name = (
                    f"pybamm.Interpolant([{x_arrays}], {y_array}, {input_vars}, "
                    f'name="{child.name}", interpolator="{child.interpolator}", '
                    f"extrapolate={child.extrapolate})"
                )
            elif isinstance(child, pybamm.FunctionParameter):
                # Replace FunctionParameter with a constructor call
                # Build the inputs dict string mapping input names to actual parameter
                # names
                inputs_str = ", ".join(
                    f'"{input_name}": {child.children[i].name}'
                    for i, input_name in enumerate(child.input_names)
                )
                # Assigned through the public print_name attribute, unlike the
                # Interpolant branch above which writes _print_name directly.
                child.print_name = (
                    f'FunctionParameter("{child.name}", {{{inputs_str}}})'
                )
            elif (
                isinstance(child, pybamm.Parameter) and child.name not in self.func_args
            ):
                # Function arguments stay as bare names; every other Parameter is
                # renamed so it prints as a Parameter("...") constructor call.
                child.name = f'Parameter("{child.name}")'
        src += f" return {expression.to_equation()}"
        # black normalises the one-space body indent and overall layout.
        formatted_src = black.format_str(src, mode=black.FileMode())
        return formatted_src
[docs]
class Serialise:
"""
Converts a discretised model to and from a JSON file.
"""
def __init__(self):
pass
[docs]
def serialise_model(
    self,
    model: pybamm.BaseModel,
    mesh: pybamm.Mesh | None = None,
    variables: None = None,
) -> dict:
    """Convert a discretised model into a JSON-serialisable dictionary.

    The model is already discretised and ready to solve, so only its
    concatenated right-hand side, algebraic equations, initial conditions,
    events, mass matrix and processed variables are stored (plus the
    geometry when variables are present).

    Parameters
    ----------
    model : :class:`pybamm.BaseModel`
        The discretised model to serialise.
    mesh : :class:`pybamm.Mesh` (optional)
        The mesh the model was discretised on. Not needed to solve the model
        after loading, but required for PyBaMM's plotting tools.
    variables: None (optional)
        Deprecated. Passing any value only emits a DeprecationWarning;
        processed variables are always serialised.

    Returns
    -------
    dict
        A JSON-serialisable dictionary representation of the model.

    Raises
    ------
    NotImplementedError
        If ``model.is_discretised`` is ``False``.
    """
    if model.is_discretised is False:
        raise NotImplementedError(
            "PyBaMM can only serialise a discretised, ready-to-solve model."
        )

    if variables is not None:
        warnings.warn(
            "The `variables` parameter is deprecated and will be removed in a future version. "
            "Use `model._variables_processed` instead.",
            DeprecationWarning,
            stacklevel=2,
        )

    # Touch every variable so the processed-variables dict below is complete.
    for variable_name in model.variables.keys():
        model.get_processed_variable(variable_name)
    processed = model.get_processed_variables_dict()

    from pybamm.expression_tree.operations.serialise_kernel import (
        TAG,
        _class_path,
        encode,
    )

    model_json = {
        TAG: _class_path(type(model)),
        "pybamm_version": pybamm.__version__,
        "name": model.name,
        "options": model.options,
        "bounds": [bound.tolist() for bound in model.bounds],  # type: ignore[attr-defined]
        "concatenated_rhs": encode(model._concatenated_rhs),
        "concatenated_algebraic": encode(model._concatenated_algebraic),
        "concatenated_initial_conditions": encode(
            model._concatenated_initial_conditions
        ),
        "events": [encode(event) for event in model.events],
        "mass_matrix": encode(model.mass_matrix),
        "_solution_observable": model._solution_observable.name,
    }

    if mesh:
        model_json["mesh"] = encode(mesh)

    if not processed:
        return model_json

    processed = dict(processed)
    if model._geometry:
        model_json["geometry"] = self._deconstruct_pybamm_dicts(model._geometry)
    model_json["_variables_processed"] = {
        var_name: encode(symbol) for var_name, symbol in processed.items()
    }
    return model_json
[docs]
def save_model(
    self,
    model: pybamm.BaseModel,
    mesh: pybamm.Mesh | None = None,
    variables: None = None,
    filename: str | None = None,
):
    """Write a discretised model to a JSON file.

    The model is already discretised and ready to solve, so only the right
    hand side, algebraic and initial condition variables (and related data
    produced by :meth:`serialise_model`) are saved.

    Parameters
    ----------
    model : :class:`pybamm.BaseModel`
        The discretised model to be saved.
    mesh : :class:`pybamm.Mesh` (optional)
        The mesh the model has been discretised over. Not necessary to solve
        the model when read in, but required to use pybamm's plotting tools.
    variables: None (optional)
        Deprecated; forwarded to :meth:`serialise_model`, which only warns.
    filename: str (optional)
        Name of the JSON file *without* extension. ``".json"`` is always
        appended. If omitted, the model name plus the current date and time
        is used.
    """
    model_json = self.serialise_model(model, mesh, variables)

    if filename is None:
        timestamp = datetime.now().strftime("%Y_%m_%d-%p%I_%M")
        filename = model.name + "_" + timestamp

    json_path = filename + ".json"
    with open(json_path, "w") as f:
        json.dump(model_json, f)
@staticmethod
def _is_legacy_node(node) -> bool:
"""True if any node in the tree uses the legacy py/object tag."""
if isinstance(node, dict):
if "py/object" in node:
return True
return any(Serialise._is_legacy_node(v) for v in node.values())
if isinstance(node, list):
return any(Serialise._is_legacy_node(v) for v in node)
return False
@staticmethod
def _decode_model_node(node):
    """Decode one serialised model field through the kernel's ``decode``.

    Trees that use the legacy ``py/object`` tag are first reshaped by
    ``_relocate_legacy_model_tree``. This lets old files go through the same
    single decode path as new ones.
    """
    from pybamm.expression_tree.operations.serialise_kernel import decode

    is_legacy = Serialise._is_legacy_node(node)
    return decode(_relocate_legacy_model_tree(node) if is_legacy else node)
[docs]
def load_model(
    self, filename: str | dict, battery_model: pybamm.BaseModel | None = None
) -> pybamm.BaseModel:
    """
    Loads a discretised, ready to solve model into PyBaMM.

    A new pybamm battery model instance is created, which can be solved and
    its results plotted as usual. Only models previously written out with
    :meth:`save_model` (or produced by :meth:`serialise_model`) are supported.

    Warning: This only loads in discretised models. If you wish to make edits to
    the model or initial conditions, a new model will need to be constructed
    separately.

    Parameters
    ----------
    filename: str or dict
        Path to the JSON file containing the serialised model, or a dictionary
        containing the serialised model data.
    battery_model: :class:`pybamm.BaseModel` (optional)
        PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will
        override any model type recorded in the file. If None, the class path
        stored under ``"$type"`` (or legacy ``"py/object"``) is used.

    Returns
    -------
    :class:`pybamm.BaseModel`
        A PyBaMM model object, of the type given either in the JSON or by
        ``battery_model``.

    Raises
    ------
    TypeError
        If no ``battery_model`` is given and the data records no model type.
    """
    if isinstance(filename, dict):
        model_data = filename
    else:
        with open(filename) as f:
            model_data = json.load(f)

    decode_node = self._decode_model_node
    recon_model_dict = {
        "name": model_data["name"],
        "options": self._convert_options(model_data["options"]),
        "bounds": tuple(np.array(bound) for bound in model_data["bounds"]),
        "concatenated_rhs": decode_node(model_data["concatenated_rhs"]),
        "concatenated_algebraic": decode_node(model_data["concatenated_algebraic"]),
        "concatenated_initial_conditions": decode_node(
            model_data["concatenated_initial_conditions"]
        ),
        "events": [decode_node(event) for event in model_data["events"]],
        "mass_matrix": decode_node(model_data["mass_matrix"]),
    }

    if "geometry" in model_data:
        recon_model_dict["geometry"] = self._reconstruct_pybamm_dict(
            model_data["geometry"]
        )
    else:
        recon_model_dict["geometry"] = None

    if "mesh" in model_data:
        recon_model_dict["mesh"] = decode_node(model_data["mesh"])
    else:
        recon_model_dict["mesh"] = None

    # A missing or empty section both yield an empty dict.
    stored_variables = model_data.get("_variables_processed") or {}
    recon_model_dict["_variables_processed"] = {
        var_name: decode_node(var_node)
        for var_name, var_node in stored_variables.items()
    }
    recon_model_dict["_solution_observable"] = model_data.get(
        "_solution_observable", False
    )

    if battery_model:
        return battery_model.deserialise(recon_model_dict)

    tag = model_data.get("$type") or model_data.get("py/object")
    if not tag:
        raise TypeError("The PyBaMM battery model to use has not been provided.")

    from pybamm.expression_tree.operations.serialise_kernel import (
        _resolve_class,
    )

    # Bare class names are assumed to live in the top-level pybamm namespace.
    class_path = tag if "." in tag else f"pybamm.{tag}"
    model_framework = _resolve_class(class_path)
    # deserialise is a BaseModel classmethod inherited by every model class.
    return model_framework.deserialise(recon_model_dict)
@staticmethod
def _json_encoder(obj):
if isinstance(obj, Enum):
return obj.name
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.integer):
return int(obj)
else:
raise TypeError(f"Object of type {type(obj)} is not JSON serializable.")
@staticmethod
def _import_dotted_class(dotted_path: str):
"""Import ``module.ClassName`` and return the class."""
if "." not in dotted_path:
raise ValueError(f"Expected 'module.ClassName' but got {dotted_path!r}")
module_name, class_name = dotted_path.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
@staticmethod
def _resolve_base_class(base_class: str, base_class_mro: list[str]):
"""Resolve a serialised ``base_class`` to a Python class.
Tries ``base_class`` first, then each entry in ``base_class_mro`` (the
ancestor chain recorded at serialise time). Falls back to
``pybamm.BaseModel`` only if nothing in the MRO is importable, so a
custom subclass loaded in an environment without its defining package
is still reconstructed against the closest pybamm-provided ancestor.
"""
try:
return Serialise._import_dotted_class(base_class)
except (ModuleNotFoundError, AttributeError, ValueError) as primary_err:
for ancestor in base_class_mro:
if ancestor == base_class:
continue
try:
resolved = Serialise._import_dotted_class(ancestor)
except (ModuleNotFoundError, AttributeError):
# ValueError is left to propagate: MRO entries are produced
# by ``serialise_custom_model`` and are always well-formed.
continue
warnings.warn(
f"Could not import base class '{base_class}': "
f"{primary_err}. Falling back to ancestor "
f"'{ancestor}' from the recorded MRO.",
stacklevel=3,
)
return resolved
warnings.warn(
f"Could not import base class '{base_class}': "
f"{primary_err}. Falling back to pybamm.BaseModel; the loaded "
"model will contain the symbolic equations but not any "
"subclass-specific Python behaviour.",
stacklevel=3,
)
return pybamm.BaseModel
[docs]
@staticmethod
def serialise_custom_model(model: pybamm.BaseModel, compress: bool = False) -> dict:
    """
    Converts a custom (non-discretised) PyBaMM model to a JSON-serialisable dictionary.

    This includes symbolic expressions for rhs, algebraic, initial and boundary
    conditions, events, and variables. Works for user defined models that are
    subclasses of BaseModel.

    Parameters
    ----------
    model : :class:`pybamm.BaseModel`
        The custom symbolic model to be serialised.
    compress : bool, optional
        If True, the resulting dictionary will be compressed using zlib and
        encoded as base64. The output will contain a "compressed" flag set to
        True and a "data" field with the compressed payload. Default is False.

    Returns
    -------
    dict
        A JSON-serialisable dictionary with keys ``"schema_version"`` (equal to
        ``SUPPORTED_SCHEMA_VERSION``), ``"pybamm_version"`` and ``"model"``. If
        compress is True, returns
        {"compressed": True, "data": <base64-encoded-zlib-data>} instead.

    Raises
    ------
    ValueError
        If the model has already been built (``model.is_processed`` is truthy),
        or if the object has no ``is_processed`` attribute at all.
    AttributeError
        If the model is missing required sections.
    """
    if getattr(model, "is_processed", True):
        raise ValueError("Cannot serialise a built model.")

    required_attrs = [
        "rhs",
        "algebraic",
        "initial_conditions",
        "boundary_conditions",
        "events",
        "variables",
    ]
    missing = [attr for attr in required_attrs if not hasattr(model, attr)]
    if missing:
        raise AttributeError(f"Model is missing required sections: {missing}")

    base_cls = model.__class__
    # If the class is object or builtins.object, use pybamm.BaseModel instead
    if base_cls is object or (
        base_cls.__module__ == "builtins" and base_cls.__name__ == "object"
    ):
        base_cls_str = "pybamm.BaseModel"
    else:
        base_cls_str = f"{base_cls.__module__}.{base_cls.__name__}"
    # Record the whole ancestor chain so a loader can fall back to the nearest
    # importable ancestor when the concrete class is unavailable.
    base_class_mro = [
        f"{ancestor.__module__}.{ancestor.__name__}"
        for ancestor in base_cls.__mro__
        if ancestor is not object and ancestor.__module__ != "builtins"
    ]

    def _encode_pairs(mapping):
        """Encode a {symbol: expression} mapping as a list of JSON pairs."""
        return [
            (convert_symbol_to_json(key), convert_symbol_to_json(value))
            for key, value in mapping.items()
        ]

    # All required sections are known to exist at this point.
    model_content = {
        "name": getattr(model, "name", "unnamed_model"),
        "base_class": base_cls_str,
        "base_class_mro": base_class_mro,
        "options": getattr(model, "options", {}),
        "rhs": _encode_pairs(model.rhs),
        "algebraic": _encode_pairs(model.algebraic),
        "initial_conditions": _encode_pairs(model.initial_conditions),
        "boundary_conditions": [
            (
                convert_symbol_to_json(variable),
                {
                    side: [convert_symbol_to_json(expression), boundary_type]
                    for side, (expression, boundary_type) in conditions.items()
                },
            )
            for variable, conditions in model.boundary_conditions.items()
        ],
        "events": [
            {
                "name": event.name,
                "expression": convert_symbol_to_json(event.expression),
                "event_type": event.event_type,
            }
            for event in model.events
        ],
        "variables": {
            str(variable_name): convert_symbol_to_json(expression)
            for variable_name, expression in model.variables.items()
        },
    }

    # Stamp the same version the loaders validate against, so the two can
    # never drift apart.
    model_json = {
        "schema_version": SUPPORTED_SCHEMA_VERSION,
        "pybamm_version": pybamm.__version__,
        "model": model_content,
    }

    if compress:
        # Serialize to JSON string, compress with zlib, and encode as base64
        json_str = json.dumps(model_json, default=Serialise._json_encoder)
        compressed_bytes = zlib.compress(json_str.encode("utf-8"))
        compressed_b64 = base64.b64encode(compressed_bytes).decode("ascii")
        return {
            "compressed": True,
            "data": compressed_b64,
        }

    return model_json
[docs]
@staticmethod
def save_custom_model(
    model: pybamm.BaseModel,
    filename: str | Path | None = None,
    compress: bool = False,
) -> None:
    """
    Saves a custom (non-discretised) PyBaMM model to a JSON file. Works for user defined models that are subclasses of BaseModel.

    The file holds symbolic expressions for rhs, algebraic, initial and
    boundary conditions, events, and variables, which makes it useful for
    storing or sharing models before discretisation.

    Parameters
    ----------
    model : :class:`pybamm.BaseModel`
        The custom symbolic model to be saved.
    filename : str or Path, optional
        Target file; must end in ``.json``. Only the final path component is
        sanitised. If not provided, a name is generated from the model name
        and the current datetime.
    compress : bool, optional
        If True, the model data will be compressed using zlib before saving.
        This can significantly reduce file size. Default is False.

    Raises
    ------
    AttributeError
        Propagated unchanged if the model is missing required sections.
    ValueError
        Wraps every other failure, including a bad extension or a write error.

    Example
    -------
    >>> import pybamm
    >>> model = pybamm.lithium_ion.BasicDFN()
    >>> from pybamm.expression_tree.operations.serialise import Serialise
    >>> Serialise.save_custom_model(model, "basicdfn_model.json")
    >>> # Or with compression:
    >>> Serialise.save_custom_model(model, "basicdfn_model.json", compress=True)
    """
    try:
        payload = Serialise.serialise_custom_model(model, compress=compress)

        # A compressed payload hides the "model" section, so read the name
        # from the model object itself in that case.
        model_name = (
            getattr(model, "name", "unnamed_model")
            if compress
            else payload["model"]["name"]
        )

        if filename is not None:
            target = Path(filename)
            if not target.name.endswith(".json"):
                raise ValueError(
                    f"Filename '{target}' must end with '.json' extension."
                )
            # Sanitise only the final path component, never the directories.
            cleaned_stem = re.sub(r"[^\w\-_.]", "_", target.stem)
            target = target.with_name(f"{cleaned_stem}.json")
        else:
            cleaned_name = re.sub(r"[^\w\-_.]", "_", model_name or "unnamed_model")
            stamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
            target = Path(f"{cleaned_name}_{stamp}.json")

        try:
            with open(target, "w") as f:
                json.dump(payload, f, indent=2, default=Serialise._json_encoder)
        except OSError as file_err:
            raise OSError(
                f"Failed to write model JSON to file '{target}': {file_err}"
            ) from file_err
    except AttributeError:
        # Let AttributeError propagate directly
        raise
    except Exception as e:
        raise ValueError(f"Failed to save custom model: {e}") from e
[docs]
@staticmethod
def serialise_custom_geometry(geometry: pybamm.Geometry) -> dict:
    """
    Converts a custom PyBaMM geometry to a JSON-serialisable dictionary.

    For each domain, a :class:`pybamm.Symbol` key (e.g. a SpatialVariable)
    is stored twice: its own serialised form goes under ``"symbol_<name>"``,
    and its value dict goes under ``"<name>"``. String keys such as
    ``"tabs"`` are kept as-is. Inside value dicts, Symbol entries are
    serialised and all other entries are copied unchanged. Keys that are
    neither Symbols nor strings are skipped.

    Parameters
    ----------
    geometry : :class:`pybamm.Geometry`
        The geometry object to be serialised.

    Returns
    -------
    dict
        A JSON-serialisable dictionary with keys ``"schema_version"`` (equal to
        ``SUPPORTED_SCHEMA_VERSION``), ``"pybamm_version"`` and ``"geometry"``.
    """

    def _encode_entries(entries: dict) -> dict:
        # Serialise Symbol values (e.g. bound expressions); keep plain values.
        return {
            k: convert_symbol_to_json(v) if isinstance(v, pybamm.Symbol) else v
            for k, v in entries.items()
        }

    geometry_dict_serialized: dict = {}
    for domain, domain_geom in geometry.items():
        serialized_domain: dict = {}
        geometry_dict_serialized[domain] = serialized_domain
        for key, value in domain_geom.items():
            if isinstance(key, pybamm.Symbol):
                # Convert SpatialVariable keys to strings and serialize the key itself
                key_str = key.name if hasattr(key, "name") else str(key)
                serialized_domain["symbol_" + key_str] = convert_symbol_to_json(key)
                serialized_domain[key_str] = _encode_entries(value)
            elif isinstance(key, str):
                # String keys (like 'tabs') - keep as is
                serialized_domain[key] = (
                    _encode_entries(value) if isinstance(value, dict) else value
                )

    # Stamp the same version the loaders validate against.
    return {
        "schema_version": SUPPORTED_SCHEMA_VERSION,
        "pybamm_version": pybamm.__version__,
        "geometry": geometry_dict_serialized,
    }
[docs]
@staticmethod
def save_custom_geometry(
    geometry: pybamm.Geometry, filename: str | Path | None = None
) -> None:
    """
    Saves a custom PyBaMM geometry to a JSON file.

    Parameters
    ----------
    geometry : :class:`pybamm.Geometry`
        The geometry object to be saved.
    filename : str or Path, optional
        Target file; must end in ``.json``. Only the final path component is
        sanitised. If not provided, ``geometry_<timestamp>.json`` is used.

    Raises
    ------
    ValueError
        Wraps any failure, including a bad extension or a write error.
    """
    try:
        payload = Serialise.serialise_custom_geometry(geometry)

        if filename is not None:
            target = Path(filename)
            if not target.name.endswith(".json"):
                raise ValueError(
                    f"Filename '{target}' must end with '.json' extension."
                )
            # Sanitise only the final path component, never the directories.
            cleaned_stem = re.sub(r"[^\w\-_.]", "_", target.stem)
            target = target.with_name(f"{cleaned_stem}.json")
        else:
            stamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
            target = Path(f"geometry_{stamp}.json")

        try:
            with open(target, "w") as f:
                json.dump(payload, f, indent=2, default=Serialise._json_encoder)
        except OSError as file_err:
            raise OSError(
                f"Failed to write geometry JSON to file '{target}': {file_err}"
            ) from file_err
    except Exception as e:
        raise ValueError(f"Failed to save custom geometry: {e}") from e
[docs]
@staticmethod
def load_custom_geometry(filename: str | dict) -> pybamm.Geometry:
    """
    Loads a custom PyBaMM geometry from a JSON file or dictionary.

    This is the inverse of :meth:`serialise_custom_geometry`. ``"symbol_<name>"``
    entries are decoded into symbols and used as the keys for the matching
    ``"<name>"`` entries. Any other key is kept as a string. In value dicts,
    entries that look like serialised symbols (dicts with ``"$type"`` or
    ``"type"``) are decoded.

    Parameters
    ----------
    filename : str or dict
        Path to the JSON file containing the saved geometry, or a dictionary
        containing the serialised geometry data.

    Returns
    -------
    :class:`pybamm.Geometry`
        The reconstructed geometry object.

    Raises
    ------
    FileNotFoundError
        If ``filename`` is a path that does not exist.
    ValueError
        If the file is not valid JSON or declares an unsupported schema version.
    KeyError
        If the ``"geometry"`` section is missing.
    """
    if isinstance(filename, dict):
        data = filename
    else:
        try:
            with open(filename) as file:
                data = json.load(file)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Could not find file: {filename}") from err
        except json.JSONDecodeError as e:
            raise ValueError(
                f"The file '{filename}' contains invalid JSON: {e!s}"
            ) from e

    declared_version = data.get("schema_version", SUPPORTED_SCHEMA_VERSION)
    if declared_version != SUPPORTED_SCHEMA_VERSION:
        raise ValueError(
            f"Unsupported schema version: {declared_version}. "
            f"Expected: {SUPPORTED_SCHEMA_VERSION}"
        )

    geometry_data = data.get("geometry")
    if geometry_data is None:
        raise KeyError("Missing 'geometry' section in JSON data.")

    def _decode_entries(entries: dict) -> dict:
        # Decode values that look like serialised symbols; keep the rest.
        return {
            k: (
                convert_symbol_from_json(v)
                if isinstance(v, dict) and ("$type" in v or "type" in v)
                else v
            )
            for k, v in entries.items()
        }

    prefix = "symbol_"
    rebuilt_geometry: dict = {}
    for domain, domain_geom in geometry_data.items():
        # First pass: rebuild the SpatialVariable keys from "symbol_<name>".
        spatial_vars = {
            key[len(prefix):]: convert_symbol_from_json(domain_geom[key])
            for key in domain_geom
            if key.startswith(prefix)
        }
        # Second pass: attach each value to its symbol key or plain string key.
        rebuilt_domain: dict = {}
        for key, value in domain_geom.items():
            if key.startswith(prefix):
                continue
            if key in spatial_vars:
                rebuilt_domain[spatial_vars[key]] = _decode_entries(value)
            elif isinstance(value, dict):
                rebuilt_domain[key] = _decode_entries(value)
            else:
                rebuilt_domain[key] = value
        rebuilt_geometry[domain] = rebuilt_domain

    return pybamm.Geometry(rebuilt_geometry)
[docs]
@staticmethod
def serialise_spatial_method_item(method) -> dict:
    """Serialise one spatial method instance.

    The method's class is encoded with the kernel's class-reference codec.
    The instance's ``options`` (or ``{}`` if it has none) are stored
    alongside under ``"options"``.
    """
    from pybamm.expression_tree.operations.serialise_kernel import encode

    encoded = encode(type(method))
    encoded["options"] = getattr(method, "options", {})
    return encoded
[docs]
@staticmethod
def deserialise_spatial_method_item(method_info: dict):
    """Rebuild a spatial method instance from its serialised form.

    Both the kernel class-reference shape and the legacy ``{class, module}``
    shape are accepted; ``normalise_legacy`` is applied to a shallow copy
    first. The class is instantiated with ``options=`` taken from
    ``method_info["options"]``, or ``{}`` when that entry is absent or empty.
    """
    from pybamm.expression_tree.operations.serialise_kernel import (
        _resolve_class,
        normalise_legacy,
    )

    method_class = _resolve_class(normalise_legacy(dict(method_info))["class"])
    return method_class(options=method_info.get("options") or {})
[docs]
@staticmethod
def serialise_spatial_methods(spatial_methods: dict) -> dict:
    """
    Converts a dictionary of spatial methods to a JSON-serialisable dictionary.

    Parameters
    ----------
    spatial_methods : dict
        Dictionary mapping domain names to spatial method instances.

    Returns
    -------
    dict
        A JSON-serialisable dictionary with keys ``"schema_version"`` (equal to
        ``SUPPORTED_SCHEMA_VERSION``, the version :meth:`load_spatial_methods`
        accepts), ``"pybamm_version"`` and ``"spatial_methods"``.
    """
    spatial_methods_dict = {
        domain: Serialise.serialise_spatial_method_item(method)
        for domain, method in spatial_methods.items()
    }
    return {
        "schema_version": SUPPORTED_SCHEMA_VERSION,
        "pybamm_version": pybamm.__version__,
        "spatial_methods": spatial_methods_dict,
    }
[docs]
@staticmethod
def save_spatial_methods(
    spatial_methods: dict, filename: str | Path | None = None
) -> None:
    """
    Saves spatial methods to a JSON file.

    Parameters
    ----------
    spatial_methods : dict
        Dictionary mapping domain names to spatial method instances.
    filename : str or Path, optional
        Target file; must end in ``.json``. Only the final path component is
        sanitised. If not provided, ``spatial_methods_<timestamp>.json`` is
        used.

    Raises
    ------
    ValueError
        Wraps any failure, including a bad extension or a write error.
    """
    try:
        payload = Serialise.serialise_spatial_methods(spatial_methods)

        if filename is not None:
            target = Path(filename)
            if not target.name.endswith(".json"):
                raise ValueError(
                    f"Filename '{target}' must end with '.json' extension."
                )
            # Sanitise only the final path component, never the directories.
            cleaned_stem = re.sub(r"[^\w\-_.]", "_", target.stem)
            target = target.with_name(f"{cleaned_stem}.json")
        else:
            stamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
            target = Path(f"spatial_methods_{stamp}.json")

        try:
            with open(target, "w") as f:
                json.dump(payload, f, indent=2, default=Serialise._json_encoder)
        except OSError as file_err:
            raise OSError(
                f"Failed to write spatial methods JSON to file '{target}': {file_err}"
            ) from file_err
    except Exception as e:
        raise ValueError(f"Failed to save spatial methods: {e}") from e
[docs]
@staticmethod
def load_spatial_methods(filename: str | dict) -> dict:
    """
    Loads spatial methods from a JSON file or dictionary.

    Parameters
    ----------
    filename : str or dict
        Path to the JSON file containing the saved spatial methods, or a
        dictionary containing the serialised spatial methods data.

    Returns
    -------
    dict
        Dictionary mapping domain names to spatial method instances.

    Raises
    ------
    FileNotFoundError
        If ``filename`` is a path that does not exist.
    ValueError
        If the file is not valid JSON, declares an unsupported schema version,
        or a method fails to reconstruct for a reason other than importing.
    KeyError
        If the ``"spatial_methods"`` section is missing.
    ImportError
        If a method's class cannot be located.
    """
    if isinstance(filename, dict):
        data = filename
    else:
        try:
            with open(filename) as file:
                data = json.load(file)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Could not find file: {filename}") from err
        except json.JSONDecodeError as e:
            raise ValueError(
                f"The file '{filename}' contains invalid JSON: {e!s}"
            ) from e

    declared_version = data.get("schema_version", SUPPORTED_SCHEMA_VERSION)
    if declared_version != SUPPORTED_SCHEMA_VERSION:
        raise ValueError(
            f"Unsupported schema version: {declared_version}. "
            f"Expected: {SUPPORTED_SCHEMA_VERSION}"
        )

    methods_data = data.get("spatial_methods")
    if methods_data is None:
        raise KeyError("Missing 'spatial_methods' section in JSON data.")

    from pybamm.expression_tree.operations.serialise_kernel import (
        SerialisationError,
    )

    # Failures that mean "the class could not be located" are reported as
    # ImportError; anything else becomes a ValueError naming the domain.
    lookup_failures = (ModuleNotFoundError, AttributeError, SerialisationError)
    methods = {}
    for domain, info in methods_data.items():
        try:
            methods[domain] = Serialise.deserialise_spatial_method_item(info)
        except lookup_failures as e:
            raise ImportError(
                f"Could not import spatial method '{info.get('class', '?')}': {e}"
            ) from e
        except Exception as e:
            raise ValueError(
                f"Failed to reconstruct spatial method for domain '{domain}': {e}"
            ) from e
    return methods
[docs]
@staticmethod
def serialise_var_pts(var_pts: dict) -> dict:
    """
    Converts a var_pts dictionary to a JSON-serialisable dictionary.

    Parameters
    ----------
    var_pts : dict
        Dictionary mapping spatial variable names (str) or objects with a
        ``name`` attribute (e.g. SpatialVariable) to number of points (int).

    Returns
    -------
    dict
        A JSON-serialisable dictionary with keys ``"schema_version"`` (equal to
        ``SUPPORTED_SCHEMA_VERSION``, the version :meth:`load_var_pts`
        accepts), ``"pybamm_version"`` and ``"var_pts"``. In ``"var_pts"``
        every key is a string.

    Raises
    ------
    ValueError
        If a key is neither a string nor has a ``name`` attribute.
    """
    # Convert all keys to strings
    var_pts_dict = {}
    for key, value in var_pts.items():
        if isinstance(key, str):
            var_pts_dict[key] = value
        elif hasattr(key, "name"):
            # SpatialVariable or similar object with name attribute
            var_pts_dict[key.name] = value
        else:
            raise ValueError(f"Unexpected key type in var_pts: {type(key)}")

    return {
        "schema_version": SUPPORTED_SCHEMA_VERSION,
        "pybamm_version": pybamm.__version__,
        "var_pts": var_pts_dict,
    }
[docs]
@staticmethod
def save_var_pts(var_pts: dict, filename: str | Path | None = None) -> None:
    """
    Saves var_pts to a JSON file.

    Parameters
    ----------
    var_pts : dict
        Dictionary mapping spatial variable names to number of points.
    filename : str or Path, optional
        Target file; must end in ``.json``. Only the final path component is
        sanitised. If not provided, ``var_pts_<timestamp>.json`` is used.

    Raises
    ------
    ValueError
        Wraps any failure, including a bad extension or a write error.
    """
    try:
        payload = Serialise.serialise_var_pts(var_pts)

        if filename is not None:
            target = Path(filename)
            if not target.name.endswith(".json"):
                raise ValueError(
                    f"Filename '{target}' must end with '.json' extension."
                )
            # Sanitise only the final path component, never the directories.
            cleaned_stem = re.sub(r"[^\w\-_.]", "_", target.stem)
            target = target.with_name(f"{cleaned_stem}.json")
        else:
            stamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
            target = Path(f"var_pts_{stamp}.json")

        try:
            with open(target, "w") as f:
                json.dump(payload, f, indent=2, default=Serialise._json_encoder)
        except OSError as file_err:
            raise OSError(
                f"Failed to write var_pts JSON to file '{target}': {file_err}"
            ) from file_err
    except Exception as e:
        raise ValueError(f"Failed to save var_pts: {e}") from e
[docs]
@staticmethod
def load_var_pts(filename: str | dict) -> dict:
    """
    Loads var_pts from a JSON file or dictionary.

    Parameters
    ----------
    filename : str or dict
        Path to the JSON file containing the saved var_pts, or a dictionary
        containing the serialised var_pts data.

    Returns
    -------
    dict
        The stored ``"var_pts"`` mapping (spatial variable names as strings to
        number of points), returned as-is.

    Raises
    ------
    FileNotFoundError
        If ``filename`` is a path that does not exist.
    ValueError
        If the file is not valid JSON or declares an unsupported schema version.
    KeyError
        If the ``"var_pts"`` section is missing.
    """
    if not isinstance(filename, dict):
        try:
            with open(filename) as file:
                data = json.load(file)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Could not find file: {filename}") from err
        except json.JSONDecodeError as e:
            raise ValueError(
                f"The file '{filename}' contains invalid JSON: {e!s}"
            ) from e
    else:
        data = filename

    declared_version = data.get("schema_version", SUPPORTED_SCHEMA_VERSION)
    if declared_version != SUPPORTED_SCHEMA_VERSION:
        raise ValueError(
            f"Unsupported schema version: {declared_version}. "
            f"Expected: {SUPPORTED_SCHEMA_VERSION}"
        )

    stored = data.get("var_pts")
    if stored is None:
        raise KeyError("Missing 'var_pts' section in JSON data.")
    return stored
[docs]
@staticmethod
def serialise_submesh_item(submesh_item) -> dict:
    """Serialise a SubMesh class or a MeshGenerator instance.

    The class is encoded with the kernel's class-reference codec
    (``{"$type": "type", "class": <dotted path>}``). An object with a
    ``submesh_type`` attribute (a MeshGenerator) is encoded via that type.
    Its ``submesh_params``, when truthy, are copied alongside under
    ``"submesh_params"``.
    """
    from pybamm.expression_tree.operations.serialise_kernel import encode

    if not hasattr(submesh_item, "submesh_type"):
        # A bare SubMesh class.
        return encode(submesh_item)

    encoded = encode(submesh_item.submesh_type)
    params = getattr(submesh_item, "submesh_params", None)
    if params:
        encoded["submesh_params"] = dict(params)
    return encoded
[docs]
@staticmethod
def deserialise_submesh_item(submesh_info: dict, return_class_only: bool = False):
    """Rebuild a SubMesh class or a MeshGenerator from its serialised form.

    Accepts both the kernel class_reference shape and the legacy
    ``{class, module}`` shape; the latter is normalised first.

    Parameters
    ----------
    submesh_info : dict
        Serialised entry, optionally carrying ``"submesh_params"``.
    return_class_only : bool, optional
        If True, return the resolved SubMesh class instead of wrapping it in
        a :class:`pybamm.MeshGenerator`.

    Returns
    -------
    type or :class:`pybamm.MeshGenerator`
    """
    from pybamm.expression_tree.operations.serialise_kernel import (
        _resolve_class,
        normalise_legacy,
    )

    # Normalise a copy so the caller's dict is never handed to the shim.
    canonical = normalise_legacy(dict(submesh_info))
    mesh_cls = _resolve_class(canonical["class"])
    if return_class_only:
        return mesh_cls
    mesh_params = submesh_info.get("submesh_params") or {}
    return pybamm.MeshGenerator(mesh_cls, mesh_params)
[docs]
@staticmethod
def serialise_submesh_types(submesh_types: dict) -> dict:
    """
    Converts a dictionary of submesh types to a JSON-serialisable dictionary.

    Parameters
    ----------
    submesh_types : dict
        Dictionary mapping domain names to submesh classes or MeshGenerator objects.

    Returns
    -------
    dict
        ``{"schema_version", "pybamm_version", "submesh_types"}``, where
        ``submesh_types`` maps each domain to the output of
        ``Serialise.serialise_submesh_item``.
    """
    submesh_types_dict = {
        domain: Serialise.serialise_submesh_item(submesh_item)
        for domain, submesh_item in submesh_types.items()
    }
    # Stamp the same version the loaders validate against. A hard-coded copy
    # here would silently diverge when SUPPORTED_SCHEMA_VERSION is bumped,
    # producing files that load_submesh_types then rejects.
    return {
        "schema_version": SUPPORTED_SCHEMA_VERSION,
        "pybamm_version": pybamm.__version__,
        "submesh_types": submesh_types_dict,
    }
[docs]
@staticmethod
def save_submesh_types(
    submesh_types: dict, filename: str | Path | None = None
) -> None:
    """
    Write submesh types to a JSON file.

    Parameters
    ----------
    submesh_types : dict
        Mapping from domain name to SubMesh class or MeshGenerator.
    filename : str or Path, optional
        Target file; must end in ``.json``. In the file stem, every character
        that is not a word character, ``-``, ``_`` or ``.`` is replaced by
        ``_``; the directory part is left alone. When omitted, the name is
        ``submesh_types_<YYYY_MM_DD_HH_MM_SS>.json`` (relative path).

    Raises
    ------
    ValueError
        On any failure. Every exception raised inside the outer ``try`` is
        re-wrapped, including a bad extension and the ``OSError`` built for
        write failures (which ends up as ``__cause__``).
    """
    try:
        payload = Serialise.serialise_submesh_types(submesh_types)
        if filename is not None:
            target = Path(filename)
            if not target.name.endswith(".json"):
                raise ValueError(
                    f"Filename '{target}' must end with '.json' extension."
                )
        else:
            stamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
            target = Path(f"submesh_types_{stamp}.json")

        # Only the stem is sanitised; the parent directory is kept verbatim.
        cleaned_stem = re.sub(r"[^\w\-_.]", "_", target.stem)
        target = target.with_name(f"{cleaned_stem}.json")

        try:
            with open(target, "w") as out:
                json.dump(payload, out, indent=2, default=Serialise._json_encoder)
        except OSError as io_err:
            raise OSError(
                f"Failed to write submesh types JSON to file '{target}': {io_err}"
            ) from io_err
    except Exception as err:
        # NOTE(review): this also catches the OSError raised just above and
        # converts it to ValueError; kept unchanged since callers may rely on
        # ValueError being the only exception type.
        raise ValueError(f"Failed to save submesh types: {err}") from err
[docs]
@staticmethod
def load_submesh_types(filename: str | dict) -> dict:
    """
    Read serialised submesh types from a JSON file or an in-memory dict.

    Parameters
    ----------
    filename : str or dict
        Path to a JSON file with saved submesh types, or the already-parsed
        dictionary with the same layout.

    Returns
    -------
    dict
        Mapping from domain name to :class:`pybamm.MeshGenerator`.

    Raises
    ------
    FileNotFoundError
        If ``filename`` is a path that does not exist.
    ValueError
        For invalid JSON, an unsupported schema version, or any other
        failure while rebuilding a domain's entry.
    KeyError
        If the ``"submesh_types"`` section is missing.
    ImportError
        If a submesh class cannot be resolved (module/attribute lookup or
        kernel SerialisationError).
    """
    if not isinstance(filename, dict):
        try:
            with open(filename) as fh:
                data = json.load(fh)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Could not find file: {filename}") from err
        except json.JSONDecodeError as err:
            raise ValueError(
                f"The file '{filename}' contains invalid JSON: {err!s}"
            ) from err
    else:
        data = filename

    found_version = data.get("schema_version", SUPPORTED_SCHEMA_VERSION)
    if found_version != SUPPORTED_SCHEMA_VERSION:
        raise ValueError(
            f"Unsupported schema version: {found_version}. "
            f"Expected: {SUPPORTED_SCHEMA_VERSION}"
        )

    entries = data.get("submesh_types")
    if entries is None:
        raise KeyError("Missing 'submesh_types' section in JSON data.")

    from pybamm.expression_tree.operations.serialise_kernel import (
        SerialisationError,
    )

    rebuilt = {}
    for domain, info in entries.items():
        try:
            generator = Serialise.deserialise_submesh_item(
                info, return_class_only=False
            )
        except (ModuleNotFoundError, AttributeError, SerialisationError) as err:
            unresolved = info.get("class", "?")
            raise ImportError(
                f"Could not import submesh type '{unresolved}': {err}"
            ) from err
        except Exception as err:
            raise ValueError(
                f"Failed to reconstruct submesh type for domain '{domain}': {err}"
            ) from err
        rebuilt[domain] = generator
    return rebuilt
@staticmethod
def _create_symbol_key(symbol_json: dict) -> str:
"""
Given the JSON‐dict for a symbol, return a unique, hashable key.
We just sort the dict keys and dump to a string.
"""
return json.dumps(symbol_json, sort_keys=True)
[docs]
@staticmethod
def load_custom_model(filename: str | dict) -> pybamm.BaseModel:
    """
    Loads a custom (symbolic) PyBaMM model from a JSON file or dictionary.

    Reconstructs a model saved using `save_custom_model`, including its rhs,
    algebraic equations, initial and boundary conditions, events, and variables.
    Returns a fully symbolic model ready for further processing or discretisation.

    Automatically detects and decompresses data that was serialised with
    compression enabled (compress=True in serialise_custom_model).

    Parameters
    ----------
    filename : str or dict
        Path to the JSON file containing the saved model, or a dictionary
        containing the serialised model data (optionally compressed).

    Returns
    -------
    :class:`pybamm.BaseModel` or subclass
        The reconstructed symbolic PyBaMM model. It is a plain
        :class:`pybamm.BaseModel` when ``base_class`` is empty,
        ``"pybamm.BaseModel"`` or ``"builtins.object"``; otherwise the class
        is resolved from ``base_class`` (and ``base_class_mro``).

    Raises
    ------
    FileNotFoundError
        If ``filename`` is a path that does not exist.
    pybamm.InvalidModelJSONError
        If the file does not contain valid JSON.
    ValueError
        If a compressed payload lacks ``"data"`` or is not valid
        base64/zlib, if the schema version is unsupported, or if any symbol,
        equation, boundary condition, event or variable fails to convert
        (the underlying error is chained as ``__cause__``).
    KeyError
        If the ``"model"`` section, or any required model section, is missing.

    Example
    -------
    >>> import pybamm
    >>> model = pybamm.lithium_ion.BasicDFN()
    >>> from pybamm.expression_tree.operations.serialise import Serialise
    >>> Serialise.save_custom_model(model, "basicdfn_model.json")
    >>> loaded_model = Serialise.load_custom_model("basicdfn_model.json")
    """
    # Accept either an already-parsed dict or a path to a JSON file.
    if isinstance(filename, dict):
        data = filename
    else:
        try:
            with open(filename) as file:
                data = json.load(file)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Could not find file: {filename}") from err
        except json.JSONDecodeError as e:
            raise pybamm.InvalidModelJSONError(
                f"The model defined in the file '{filename}' contains invalid JSON: {e!s}"
            ) from e

    # Check if the data is compressed and decompress if needed.
    # A compressed payload is an envelope whose "data" field holds base64 text
    # of zlib-compressed UTF-8 JSON; unwrap it before any validation.
    if data.get("compressed", False):
        try:
            compressed_b64 = data["data"]
            compressed_bytes = base64.b64decode(compressed_b64)
            json_str = zlib.decompress(compressed_bytes).decode("utf-8")
            data = json.loads(json_str)
        except (KeyError, zlib.error, base64.binascii.Error) as e:
            # NOTE(review): ``base64.binascii`` relies on base64's own import of
            # binascii. Invalid JSON inside the payload is not caught here; it
            # surfaces as json.JSONDecodeError (itself a ValueError).
            raise ValueError(f"Failed to decompress model data: {e}") from e

    # Validate outer structure. A missing schema_version is treated as the
    # supported one.
    schema_version = data.get("schema_version", SUPPORTED_SCHEMA_VERSION)
    if schema_version != SUPPORTED_SCHEMA_VERSION:
        raise ValueError(
            f"Unsupported schema version: {schema_version}. "
            f"Expected: {SUPPORTED_SCHEMA_VERSION}"
        )
    model_data = data.get("model")
    if model_data is None:
        raise KeyError("Missing 'model' section in JSON file.")
    required = [
        "name",
        "rhs",
        "initial_conditions",
        "base_class",
        "algebraic",
        "boundary_conditions",
        "events",
        "variables",
    ]
    missing = [k for k in required if k not in model_data]
    if missing:
        raise KeyError(f"Missing required model sections: {missing}")

    # Pick the class to instantiate: empty/generic names mean a plain
    # BaseModel; anything else is resolved by _resolve_base_class, which also
    # receives the recorded MRO list.
    battery_model = (model_data.get("base_class") or "").strip()
    if battery_model in ("", "pybamm.BaseModel", "builtins.object"):
        base_cls = pybamm.BaseModel
    else:
        base_cls = Serialise._resolve_base_class(
            battery_model,
            model_data.get("base_class_mro") or [],
        )
    model = base_cls()
    model.name = model_data["name"]
    model.schema_version = schema_version

    # Restore options so round-trip serialisation produces an equivalent model
    opts = model_data.get("options", {})
    if opts is not None:
        model.options = dict(opts)

    # Decode every LHS variable (rhs, initial conditions, algebraic) and every
    # boundary-condition variable up front, indexed by canonical JSON key.
    # The map is complete before any lookup, so a variable referenced from
    # several sections resolves to the same Symbol object everywhere.
    all_variable_keys = (
        [lhs_json for lhs_json, _ in model_data["rhs"]]
        + [lhs_json for lhs_json, _ in model_data["initial_conditions"]]
        + [lhs_json for lhs_json, _ in model_data["algebraic"]]
        + [variable_json for variable_json, _ in model_data["boundary_conditions"]]
    )
    symbol_map = {}
    for variable_json in all_variable_keys:
        try:
            symbol = _require_symbol(variable_json)
            key = Serialise._create_symbol_key(variable_json)
            symbol_map[key] = symbol
        except Exception as e:
            raise ValueError(
                f"Failed to process symbol key for variable {variable_json}: {e!s}"
            ) from e

    # rhs / algebraic / initial_conditions are lists of
    # [lhs_json, expression_json] pairs; the LHS comes from symbol_map.
    model.rhs = {}
    for lhs_json, rhs_expr_json in model_data["rhs"]:
        try:
            lhs = symbol_map[Serialise._create_symbol_key(lhs_json)]
            rhs = convert_symbol_from_json(rhs_expr_json)
            model.rhs[lhs] = rhs
        except Exception as e:
            raise ValueError(
                f"Failed to convert rhs entry for {lhs_json}: {e!s}"
            ) from e

    model.algebraic = {}
    for lhs_json, algebraic_expr_json in model_data["algebraic"]:
        try:
            lhs = symbol_map[Serialise._create_symbol_key(lhs_json)]
            rhs = convert_symbol_from_json(algebraic_expr_json)
            model.algebraic[lhs] = rhs
        except Exception as e:
            raise ValueError(
                f"Failed to convert algebraic entry for {lhs_json}: {e!s}"
            ) from e

    model.initial_conditions = {}
    for lhs_json, initial_value_json in model_data["initial_conditions"]:
        try:
            lhs = symbol_map[Serialise._create_symbol_key(lhs_json)]
            rhs = convert_symbol_from_json(initial_value_json)
            model.initial_conditions[lhs] = rhs
        except Exception as e:
            raise ValueError(
                f"Failed to convert initial condition entry for {lhs_json}: {e!s}"
            ) from e

    # Each boundary-condition entry is [variable_json, {side: [expr_json, type]}].
    # A per-side failure is wrapped once here and again by the outer handler.
    model.boundary_conditions = {}
    for variable_json, condition_dict in model_data["boundary_conditions"]:
        try:
            variable = symbol_map[Serialise._create_symbol_key(variable_json)]
            sides = {}
            for side, (expression_json, boundary_type) in condition_dict.items():
                try:
                    expr = _require_symbol(expression_json)
                    sides[side] = (expr, boundary_type)
                except Exception as e:
                    raise ValueError(
                        f"Failed to convert boundary expression for variable {variable_json} on side '{side}': {e!s}"
                    ) from e
            model.boundary_conditions[variable] = sides
        except Exception as e:
            raise ValueError(
                f"Failed to convert boundary condition entry for variable {variable_json}: {e!s}"
            ) from e

    model.events = []
    for event_data in model_data["events"]:
        try:
            name = event_data["name"]
            expr = convert_symbol_from_json(event_data["expression"])
            event_type = event_data["event_type"]
            # ``_json_encoder`` stores Enums as their ``.name``.
            if isinstance(event_type, str):
                event_type = pybamm.EventType[event_type]
            model.events.append(pybamm.Event(name, expr, event_type))
        except Exception as e:
            raise ValueError(
                f"Failed to convert event '{event_data.get('name', 'UNKNOWN')}': {e!s}"
            ) from e

    # Reuse an already-decoded LHS symbol when a variable's JSON matches one,
    # so model.variables shares objects with rhs/algebraic/etc.
    model.variables = {}
    for variable_name, expression_json in model_data["variables"].items():
        try:
            key = Serialise._create_symbol_key(expression_json)
            symbol = symbol_map.get(key)
            if symbol is None:
                symbol = _require_symbol(expression_json)
            model.variables[variable_name] = symbol
        except Exception as e:
            raise ValueError(
                f"Failed to convert variable '{variable_name}': {e!s}"
            ) from e

    # Restore observable state
    model._solution_observable = False
    return model
[docs]
@staticmethod
def save_parameters(parameters: dict, filename=None):
    """
    Serializes a dictionary of parameters to JSON.

    The values can be numbers, PyBaMM symbols, or callables. Callables are
    first traced into a symbolic expression with
    ``convert_function_to_symbolic_expression`` (using the parameter name),
    then every value is encoded with ``convert_symbol_to_json``.

    Parameters
    ----------
    parameters : dict
        A dictionary of parameter names and values.
        Values can be numeric, PyBaMM symbols, or callables.
    filename : str, optional
        If given, the serialized parameters are also written to this file as
        JSON indented by 4 spaces.

    Returns
    -------
    dict
        The serialized parameters, keyed like ``parameters``. The dict is
        returned in every case, so the method is useful without ``filename``
        (without it, the result would otherwise be discarded).
    """
    parameter_values_dict = {
        name: convert_symbol_to_json(
            convert_function_to_symbolic_expression(value, name)
            if callable(value)
            else value
        )
        for name, value in parameters.items()
    }
    if filename is not None:
        with open(filename, "w") as f:
            json.dump(parameter_values_dict, f, indent=4)
    return parameter_values_dict
[docs]
@staticmethod
def load_parameters(filename):
    """
    Read a parameters JSON file and wrap it in :class:`pybamm.ParameterValues`.

    Works for files written by ``Serialise.save_parameters`` as well as by
    ``pybamm.ParameterValues.save``. Each value is handled as follows:

    - a dict with a ``"$type"`` (kernel) or legacy ``"type"`` key is decoded
      back into a PyBaMM symbol;
    - any other dict, list, number, bool or string is kept as-is;
    - anything else (e.g. JSON ``null``) raises ``ValueError``.
    """
    with open(filename) as fh:
        raw_dict = json.load(fh)

    decoded = {}
    for key, val in raw_dict.items():
        tagged_node = isinstance(val, dict) and ("$type" in val or "type" in val)
        if tagged_node:
            decoded[key] = convert_symbol_from_json(val)
        elif isinstance(val, (list, numbers.Number, bool, str, dict)):
            decoded[key] = val
        else:
            raise ValueError(
                f"Unsupported parameter format for key '{key}': {val!r}"
            )
    return pybamm.ParameterValues(decoded)
# Helper functions
def _deconstruct_pybamm_dicts(self, dct: dict):
    """
    Make a dict that may use PyBaMM symbols as keys JSON-serialisable.

    Each symbol key is replaced by the ``name`` field of its encoded form,
    and the encoded symbol itself is stored next to it under
    ``"symbol_<name>"``. For example::

        {'rod': {SpatialVariable(name='spat_var'): {"min": 0.0, "max": 2.0}}}

    becomes::

        {'rod': {'symbol_spat_var': {<encoded SpatialVariable>},
                 'spat_var': {"min": 0.0, "max": 2.0}}}

    Only nested dicts are walked (lists are not). If ``dct`` is already
    JSON-serialisable, a shallow copy is returned unchanged.
    """
    from pybamm.expression_tree.operations.serialise_kernel import encode

    def _walk(node):
        if not isinstance(node, dict):
            return node
        converted = {}
        for key, value in node.items():
            if isinstance(key, pybamm.Symbol):
                encoded_key = encode(key)
                symbol_name = encoded_key["name"]
                converted["symbol_" + symbol_name] = encoded_key
                key = symbol_name
            converted[key] = _walk(value)
        return converted

    try:
        json.dumps(dct)
    except TypeError:  # some key/value (e.g. a Symbol key) is not JSON-native
        return _walk(dct)
    else:
        return dict(dct)
def _reconstruct_pybamm_dict(self, obj: dict):
    """
    Restore PyBaMM symbols as dictionary keys in exported JSON data.

    pybamm.Geometry can contain PyBaMM symbols as dictionary keys. Converts,
    e.g.::

        {"rod": {"symbol_spat_var": {<encoded SpatialVariable>},
                 "spat_var": {"min": 0.0, "max": 2.0}}}

    into::

        {"rod": {SpatialVariable(name="spat_var"): {"min": 0.0, "max": 2.0}}}

    A key that *starts with* ``"symbol_"`` holds an encoded symbol, which is
    decoded with ``self._decode_model_node``. The value under the matching
    bare name is then re-keyed by that symbol. All other nested dicts are
    processed recursively; non-dict input is returned unchanged.

    Raises
    ------
    KeyError
        If a ``"symbol_<name>"`` entry has no matching ``"<name>"`` entry.
    """
    prefix = "symbol_"

    def recurse(node):
        if not isinstance(node, dict):
            return node
        new_dict = {}
        symbol_keys = []
        for k, v in node.items():
            # Prefix test, matching the re-keying step below. A substring test
            # would decode keys such as "my_symbol_x" that are never re-keyed,
            # which would silently replace their values with decoded symbols.
            if isinstance(k, str) and k.startswith(prefix):
                new_dict[k] = self._decode_model_node(v)
                symbol_keys.append(k)
            elif isinstance(v, dict):
                new_dict[k] = recurse(v)
            else:
                new_dict[k] = v
        # Rearrange the dictionary so PyBaMM objects become the keys.
        for k in symbol_keys:
            symbol = new_dict.pop(k)
            new_dict[symbol] = new_dict.pop(k.removeprefix(prefix))
        return new_dict

    return recurse(obj)
def _convert_options(self, d):
    """
    Recursively turn lists into tuples inside a (possibly nested) structure.

    Used when restoring model options read back from JSON, which has no tuple
    type. Dicts are rebuilt with converted values, lists become tuples of
    converted items, and every other value is returned unchanged.
    """
    if isinstance(d, list):
        return tuple(self._convert_options(entry) for entry in d)
    if isinstance(d, dict):
        return {key: self._convert_options(value) for key, value in d.items()}
    return d
@staticmethod
def _to_json_safe(value):
"""Convert a value to a JSON-serializable form (native Python types).
Handles numpy scalars, arrays, booleans, and nested dicts/lists.
"""
if isinstance(value, (np.bool_, bool)):
return bool(value)
if isinstance(value, (np.floating, float)):
return float(value)
if isinstance(value, (np.integer, int)):
return int(value)
if isinstance(value, np.ndarray):
return value.tolist()
if isinstance(value, dict):
return {k: Serialise._to_json_safe(v) for k, v in value.items()}
if isinstance(value, list):
return [Serialise._to_json_safe(v) for v in value]
return value
[docs]
@staticmethod
def serialise_experiment(experiment) -> dict:
    """Convert a :class:`pybamm.Experiment` to a JSON-serialisable dict.

    Returns ``{"cycles": [[step_config, ...], ...]}``, grouping steps
    into cycles according to ``experiment.cycle_lengths``. Experiment-level
    ``period`` and ``temperature`` are added when not ``None``, as is
    ``termination`` (taken from ``experiment.termination_string``).

    Every step config has a ``"type"``. Other keys (``duration``, ``value``,
    ``terminations``, ``temperature``, ``period``, ``tags``,
    ``description``, ``direction``, ``start_time``, ``skip_ok``) are only
    written when they carry information beyond the defaults.

    Parameters
    ----------
    experiment : :class:`pybamm.Experiment`
        The experiment to serialise.

    Returns
    -------
    dict
        Config dict with key ``"cycles"``.

    Raises
    ------
    NotImplementedError
        If a step uses a custom callable (``CustomTermination``,
        ``CustomStepExplicit``, ``CustomStepImplicit``); these have no
        JSON representation. Also raised for any step or termination class
        that is not in the built-in maps below.
    """
    # Step class name -> serialised "type" string.
    step_type_map = {
        "Current": "current",
        "Rest": "rest",
        "Voltage": "voltage",
        "Power": "power",
        "CRate": "c-rate",
        "Resistance": "resistance",
    }
    # Termination class name -> serialised "type"; both C-rate spellings map
    # to "c-rate".
    termination_type_map = {
        "VoltageTermination": "voltage",
        "CurrentTermination": "current",
        "CrateTermination": "c-rate",
        "CRateTermination": "c-rate",
    }
    # Top-level defaults; per-step values are emitted only when they differ.
    experiment_period = getattr(experiment, "period", None)
    experiment_temperature = getattr(experiment, "temperature", None)

    def _serialise_step(step):
        """Build the config dict for one step (see the outer docstring)."""
        step_class_name = step.__class__.__name__
        if step_class_name in ("CustomStepExplicit", "CustomStepImplicit"):
            raise NotImplementedError(
                f"{step_class_name} cannot be serialised: it carries a "
                "user-supplied Python callable that has no JSON "
                "representation. Serialisation of custom steps is not "
                "supported."
            )
        if step_class_name not in step_type_map:
            raise NotImplementedError(
                f"Cannot serialise step of type {step_class_name!r}: only "
                f"the built-in step classes "
                f"({sorted(step_type_map)!r}) are supported."
            )
        step_type = step_type_map[step_class_name]
        # Current with value 0 is a rest step
        if step_class_name == "Current" and step.value == 0:
            step_type = "rest"
        step_config: dict = {"type": step_type}
        # Use ``input_duration`` so ``uses_default_duration`` round-trips.
        if step.input_duration is not None:
            step_config["duration"] = step.input_duration
        # Rest steps carry no value.
        if step_type != "rest":
            if step.is_drive_cycle:
                # Drive cycles: store the raw input data as a nested list.
                step_config["value"] = step.input_value.tolist()
            else:
                value = step.value
                if isinstance(value, pybamm.InputParameter):
                    # Stored by name; the loader turns non-numeric strings
                    # back into InputParameters.
                    param_name = value.name
                    step_config["value"] = (
                        param_name if isinstance(param_name, str) else str(value)
                    )
                elif isinstance(value, (int, float, str)):
                    step_config["value"] = value
                else:
                    # Fallback: any other value type is stringified.
                    step_config["value"] = str(value)
        if step.termination:
            # Each termination becomes {"type", "value"[, "operator"]}.
            terminations = []
            for term in step.termination:
                term_class_name = term.__class__.__name__
                if term_class_name == "CustomTermination":
                    raise NotImplementedError(
                        "CustomTermination cannot be serialised: it "
                        "carries a user-supplied Python callable "
                        "(``event_function``) that has no JSON "
                        "representation."
                    )
                if term_class_name not in termination_type_map:
                    raise NotImplementedError(
                        f"Cannot serialise termination of type "
                        f"{term_class_name!r}: only the built-in "
                        f"termination classes "
                        f"({sorted(termination_type_map)!r}) are "
                        f"supported."
                    )
                term_type = termination_type_map[term_class_name]
                term_config = {"type": term_type, "value": term.value}
                if hasattr(term, "operator") and term.operator:
                    term_config["operator"] = term.operator
                terminations.append(term_config)
            step_config["terminations"] = terminations
        field_defaults = {
            "temperature": experiment_temperature,
            "period": experiment_period,
        }
        for field, default in field_defaults.items():
            value = getattr(step, field, None)
            if value is not None and value != default:
                step_config[field] = value
        tags = getattr(step, "tags", None)
        if tags:
            step_config["tags"] = tags
        description = getattr(step, "description", None)
        if description is not None:
            step_config["description"] = description
        # Skip ``direction`` when the loader will recompute it from the value sign.
        direction = getattr(step, "direction", None)
        if direction is not None and not getattr(
            step, "calculate_charge_or_discharge", False
        ):
            step_config["direction"] = direction
        # datetimes are written as ISO-8601 strings; other types are dropped.
        start_time = getattr(step, "start_time", None)
        if isinstance(start_time, datetime):
            step_config["start_time"] = start_time.isoformat()
        # Only an explicit False is recorded; absence means the default.
        if getattr(step, "skip_ok", True) is False:
            step_config["skip_ok"] = False
        return step_config

    # Serialise the flat step list, then slice it back into cycles.
    steps_config = [_serialise_step(step) for step in experiment.steps]
    cycles_config = []
    step_idx = 0
    for cycle_length in experiment.cycle_lengths:
        cycles_config.append(steps_config[step_idx : step_idx + cycle_length])
        step_idx += cycle_length
    config: dict = {"cycles": cycles_config}
    for field in ("period", "temperature"):
        value = getattr(experiment, field, None)
        if value is not None:
            config[field] = value
    # Store the raw input so a single string isn't split into chars on reload.
    termination = getattr(experiment, "termination_string", None)
    if termination is not None:
        config["termination"] = termination
    return config
[docs]
@staticmethod
def deserialise_experiment(data: dict):
    """Rebuild a :class:`pybamm.Experiment` from a config dict.

    Two layouts are accepted:

    - ``{"cycles": [[step_config, ...], ...]}`` as written by
      :meth:`serialise_experiment`; each inner list becomes one cycle;
    - the legacy flat ``{"steps": [step_config, ...]}``.

    Top-level ``period``, ``temperature`` and ``termination`` are forwarded
    to :class:`pybamm.Experiment` when not ``None``.

    Parameters
    ----------
    data : dict
        Config dict as produced by :meth:`serialise_experiment`.

    Returns
    -------
    :class:`pybamm.Experiment`

    Raises
    ------
    ValueError
        For an unknown step/termination type, a non-rest step without a
        value, or a config with neither ``"cycles"`` nor ``"steps"``.
    TypeError
        If a step's ``skip_ok`` is not a bool.
    """
    factories = _experiment_step_factories()
    termination_classes = {
        "voltage": pybamm.step.VoltageTermination,
        "current": pybamm.step.CurrentTermination,
        "c-rate": pybamm.step.CrateTermination,
    }
    # Optional per-step keyword arguments copied through verbatim.
    passthrough_fields = (
        "temperature",
        "period",
        "tags",
        "description",
        "direction",
    )

    def build_termination(spec):
        kind = spec.get("type")
        if kind not in termination_classes:
            raise ValueError(
                f"Unknown termination type: {kind!r}. "
                f"Expected one of {list(termination_classes)!r}."
            )
        return termination_classes[kind](
            float(spec["value"]), operator=spec.get("operator")
        )

    def step_value(kind, spec):
        # Lists are drive-cycle data; numeric values become floats; other
        # strings name an InputParameter; anything else re-raises.
        raw = spec.get("value")
        if raw is None:
            raise ValueError(f"Value is required for {kind!r} steps.")
        if isinstance(raw, list):
            return np.array(raw)
        try:
            return float(raw)
        except (ValueError, TypeError):
            if not isinstance(raw, str):
                raise
            return pybamm.InputParameter(raw)

    def build_step(spec):
        kind = spec.get("type")
        if kind not in factories:
            raise ValueError(
                f"Unknown step type: {kind!r}. "
                f"Expected one of {list(factories)!r}."
            )
        factory = factories[kind]
        is_rest = kind == "rest"
        value = None if is_rest else step_value(kind, spec)

        kwargs = {}
        # Leave ``duration`` out entirely when absent so the step keeps
        # ``uses_default_duration`` (used by infeasibility handling).
        if spec.get("duration") is not None:
            kwargs["duration"] = spec["duration"]

        term_specs = spec.get("terminations")
        terminations = (
            [build_termination(t) for t in term_specs] if term_specs else None
        )

        for field in passthrough_fields:
            if spec.get(field) is not None:
                kwargs[field] = spec[field]
        if spec.get("start_time") is not None:
            kwargs["start_time"] = datetime.fromisoformat(spec["start_time"])
        if "skip_ok" in spec:
            # ``bool("False")`` is ``True``; reject non-bool input.
            flag = spec["skip_ok"]
            if not isinstance(flag, bool):
                raise TypeError(
                    f"skip_ok must be a bool, got "
                    f"{type(flag).__name__}: {flag!r}."
                )
            kwargs["skip_ok"] = flag

        if is_rest:
            return factory(termination=terminations, **kwargs)
        return factory(value, termination=terminations, **kwargs)

    experiment_kwargs = {
        key: data[key]
        for key in ("period", "temperature", "termination")
        if data.get(key) is not None
    }

    cycles = data.get("cycles")
    if cycles is not None:
        grouped = [tuple(build_step(s) for s in cycle) for cycle in cycles]
        return pybamm.Experiment(grouped, **experiment_kwargs)

    steps = data.get("steps")
    if steps is not None:
        flat = [build_step(s) for s in steps]
        return pybamm.Experiment(flat, **experiment_kwargs)

    raise ValueError("Experiment config must have 'steps' or 'cycles'.")
# Names of solver ``__init__`` parameters that ``serialise_solver`` may drop
# when their value is not JSON-serialisable, because they are re-derived on
# construction or are non-serialisable transients. Empty today (no shipped
# solver omits anything); any other unserialisable parameter raises
# (safe-or-loud).
_SOLVER_DERIVED_PARAMS: frozenset[str] = frozenset(())
[docs]
@staticmethod
def serialise_solver(solver) -> dict:
    """Convert a :class:`pybamm.BaseSolver` to a JSON-serialisable config dict.

    Uses ``inspect.signature`` to discover the named ``__init__``
    parameters. For each one it reads the matching attribute from the
    instance, trying ``solver.<name>`` first and then ``solver._<name>``.
    Parameters with no matching attribute are skipped, as are ``*args`` /
    ``**kwargs``, which cannot be passed back by name. Nested solver
    instances (e.g. a resolved ``root_method``) are serialised recursively.
    ``CompositeSolver`` is handled through its ``sub_solvers``.

    Parameters
    ----------
    solver : :class:`pybamm.BaseSolver`
        The solver to serialise.

    Returns
    -------
    dict
        Config dict with a ``"type"`` key and one key per serialisable
        init parameter.

    Raises
    ------
    SerialisationError
        If a recorded value is not JSON-serialisable and its parameter is not
        listed in ``Serialise._SOLVER_DERIVED_PARAMS``.
    """
    from pybamm.expression_tree.operations.serialise_kernel import (
        SerialisationError,
    )

    if solver.__class__.__name__ == "CompositeSolver":
        return {
            "type": "CompositeSolver",
            "sub_solvers": [
                Serialise.serialise_solver(sub) for sub in solver.sub_solvers
            ],
        }

    config = {"type": solver.__class__.__name__}
    sig = inspect.signature(solver.__class__.__init__)
    # deserialise_solver rebuilds with ``solver_class(**config)``. A value
    # stored under a var-positional/var-keyword name (e.g. "kwargs") would be
    # passed as a literal ``kwargs=...`` keyword and end up nested inside the
    # **kwargs dict, so such parameters are never recorded.
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    for param_name, param in sig.parameters.items():
        if param_name == "self" or param.kind in variadic_kinds:
            continue
        value = None
        found = False
        for attr_name in (param_name, f"_{param_name}"):
            if hasattr(solver, attr_name):
                value = getattr(solver, attr_name)
                found = True
                break
        if not found:
            continue
        # ``root_method`` strings are resolved into ``BaseSolver`` instances
        # by ``BaseSolver.__init__``; recurse so they survive JSON.
        if isinstance(value, pybamm.BaseSolver):
            config[param_name] = Serialise.serialise_solver(value)
            continue
        value = Serialise._to_json_safe(value)
        try:
            json.dumps(value)
        except (TypeError, ValueError) as err:
            if param_name in Serialise._SOLVER_DERIVED_PARAMS:
                continue
            raise SerialisationError(
                f"Solver parameter '{param_name}' on "
                f"{solver.__class__.__name__} is not JSON-serialisable. Make it "
                f"serialisable (extend _to_json_safe), or -- if it is re-derived on "
                f"construction or a non-serialisable transient -- add it to "
                f"_SOLVER_DERIVED_PARAMS with a justification."
            ) from err
        config[param_name] = value
    return config
[docs]
@staticmethod
def deserialise_solver(data: dict):
    """Convert a config dict to a :class:`pybamm.BaseSolver` instance.

    Handles ``CompositeSolver`` by recursively deserialising ``sub_solvers``.
    Nested solver configs (dicts whose ``"type"`` names a
    :class:`pybamm.BaseSolver` subclass, e.g. a serialised ``root_method``)
    are rebuilt before the constructor is called with the remaining keys.

    Parameters
    ----------
    data : dict
        Config dict as produced by :meth:`serialise_solver`. It is not
        modified.

    Returns
    -------
    :class:`pybamm.BaseSolver`

    Raises
    ------
    ValueError
        If ``"type"`` is missing or does not name an attribute of ``pybamm``;
        if it names something other than ``CompositeSolver`` that is not a
        :class:`pybamm.BaseSolver` subclass; or if a ``CompositeSolver``
        config lacks ``"sub_solvers"``.
    """
    data = dict(data)
    solver_type = data.pop("type", None)
    if solver_type is None:
        raise ValueError("Solver config must include a 'type' key.")
    solver_class = getattr(pybamm, solver_type, None)
    if solver_class is None:
        raise ValueError(
            f"Unknown solver type '{solver_type}'. "
            "Must be a class available on the pybamm module."
        )
    if solver_type == "CompositeSolver":
        sub_solvers_config = data.pop("sub_solvers", None)
        if sub_solvers_config is None:
            raise ValueError(
                "CompositeSolver config must include a 'sub_solvers' list."
            )
        sub_solvers = [Serialise.deserialise_solver(c) for c in sub_solvers_config]
        return solver_class(sub_solvers)

    # The type name comes from the config, which may be untrusted. Refuse to
    # call arbitrary pybamm attributes (functions, non-solver classes) with
    # config-supplied keyword arguments, mirroring the check applied to
    # nested solver dicts below.
    if not (
        isinstance(solver_class, type) and issubclass(solver_class, pybamm.BaseSolver)
    ):
        raise ValueError(
            f"Solver type '{solver_type}' is not a pybamm.BaseSolver subclass."
        )

    # Rebuild any nested solver dicts (mirror of the recursion in
    # ``serialise_solver``) before passing them to the constructor.
    for k, v in list(data.items()):
        if isinstance(v, dict) and isinstance(v.get("type"), str):
            nested_cls = getattr(pybamm, v["type"], None)
            if isinstance(nested_cls, type) and issubclass(
                nested_cls, pybamm.BaseSolver
            ):
                data[k] = Serialise.deserialise_solver(v)
    return solver_class(**data)
def convert_function_to_symbolic_expression(func, name=None):
    """
    Trace a Python callable into a PyBaMM symbolic expression.

    Each argument of ``func`` is replaced by a ``pybamm.Parameter`` with the
    argument's name. The function is evaluated on those under ``tracing()``,
    and the result is wrapped in an :class:`ExpressionFunctionParameter`
    that records the function name and argument names.

    Parameters
    ----------
    func : callable
        The function to convert. Objects providing ``get_name()`` /
        ``get_args()`` are evaluated through their ``func`` attribute. For
        objects without ``__name__`` (e.g. ``functools.partial``), the name
        is taken from ``func.func.__name__``.
    name : str, optional
        Name of the resulting symbol; defaults to the function's name.

    Returns
    -------
    pybamm.Symbol
        The wrapped symbolic expression.
    """
    try:
        func_name = func.get_name()
        func_args = func.get_args()
        evaluate = func.func  # evaluate the underlying function
    except AttributeError:
        try:
            func_name = func.__name__
            func_args = list(inspect.signature(func).parameters)
            evaluate = func
        except AttributeError:
            # Last resort, e.g. a functools.partial with no __name__
            func_name = func.func.__name__
            func_args = list(inspect.signature(func).parameters)
            evaluate = func

    placeholders = [pybamm.Parameter(arg_name) for arg_name in func_args]
    with tracing():
        traced = evaluate(*placeholders)

    symbol_name = name or func_name
    return ExpressionFunctionParameter(symbol_name, traced, func_name, func_args)
def _relocate_legacy_model_tree(node):
    """Rewrite legacy ``py/object`` nested fields into canonical ``children``.

    Three legacy shapes are moved into the kernel ``children`` layout, which
    the tag-only ``normalise_legacy`` shim does not touch:

    - ``Event``: the ``expression`` sibling is prepended to ``children``.
    - ``ExplicitTimeIntegral``: ``initial_condition`` is appended to ``children``.
    - ``Mesh``: ``sub_meshes`` ``{domain: node}`` becomes ``children`` plus a
      parallel ``sub_mesh_domains`` list.

    Lists are mapped element-wise, and non-dict values are returned
    unchanged. Dicts are shallow-copied and never mutated. Recursion follows
    only ``children``, ``sub_meshes``, ``expression`` and
    ``initial_condition``. ``children`` is written back only when it was
    present or ends up non-empty.
    """
    if isinstance(node, list):
        return [_relocate_legacy_model_tree(item) for item in node]
    if not isinstance(node, dict):
        return node

    relocated = dict(node)
    kids = [_relocate_legacy_model_tree(c) for c in relocated.get("children", [])]

    if "sub_meshes" in relocated:
        meshes = relocated.pop("sub_meshes")
        relocated["sub_mesh_domains"] = list(meshes.keys())
        kids = [_relocate_legacy_model_tree(m) for m in meshes.values()]

    if "expression" in relocated:
        kids.insert(0, _relocate_legacy_model_tree(relocated.pop("expression")))

    if "initial_condition" in relocated:
        kids.append(_relocate_legacy_model_tree(relocated.pop("initial_condition")))

    if kids or "children" in relocated:
        relocated["children"] = kids
    return relocated
def convert_symbol_from_json(json_data):
    """Reconstruct a pybamm.Symbol (or a decoded leaf value) from kernel/legacy JSON.

    Strict: input that the kernel ``decode`` hands back as a plain dict, list
    or str carried no recognised type tag and cannot be a symbol node. It
    raises :class:`SerialisationError` instead of passing through silently,
    matching pre-kernel behaviour. Numeric scalars and decoded leaf values
    (tuples, ndarrays) are returned as-is. Use the kernel ``decode``
    directly for generic, non-symbol JSON.
    """
    from pybamm.expression_tree.operations.serialise_kernel import (
        SerialisationError,
        decode,
    )

    result = decode(json_data)
    # A dict/list/str coming back means decode passed the input through
    # untouched rather than decoding it.
    passed_through = isinstance(result, (dict, list, str))
    if not passed_through:
        return result
    raise SerialisationError(
        f"Cannot reconstruct a symbol from {json_data!r}: expected a JSON "
        f"node with a '$type'/'type' tag or a numeric scalar."
    )
def _require_symbol(raw):
    """Decode *raw* and insist that the result is a pybamm.Symbol.

    ``convert_symbol_from_json`` also returns numeric scalars and decoded
    leaf values. Callers that specifically need a Symbol use this helper to
    get a clear :class:`SerialisationError` up front instead of a confusing
    failure later on.
    """
    from pybamm.expression_tree.operations.serialise_kernel import (
        SerialisationError,
    )

    decoded = convert_symbol_from_json(raw)
    if isinstance(decoded, pybamm.Symbol):
        return decoded
    raise SerialisationError(f"expected a pybamm.Symbol, got {raw!r}")
def convert_symbol_to_json(symbol):
    """Encode ``symbol`` into a JSON-compatible structure with the kernel ``encode``."""
    from pybamm.expression_tree.operations.serialise_kernel import encode

    encoded = encode(symbol)
    return encoded