from __future__ import annotations
#
# Concatenation classes
#
from __future__ import annotations
import copy
from collections import defaultdict
import numpy as np
import numpy.typing as npt
import sympy
from scipy.sparse import issparse, vstack
from collections.abc import Sequence
import pybamm
class Concatenation(pybamm.Symbol):
    """
    A node in the expression tree representing a concatenation of symbols.

    Parameters
    ----------
    children : iterable of :class:`pybamm.Symbol`
        The symbols to concatenate
    name : str, optional
        Name of the node. Defaults to ``"concatenation"``.
    check_domain : bool, optional
        If True (default), the children's domains are combined and validated with
        :meth:`get_children_domains`. If False, no validation is done and the node
        gets an empty primary domain.
    concat_fun : callable, optional
        Function that joins the list of evaluated children into one array. If None,
        ``np.concatenate`` is used.
    """

    def __init__(
        self,
        *children: pybamm.Symbol,
        name: str | None = None,
        check_domain=True,
        concat_fun=None,
    ):
        # A plain Concatenation of Variables must be built as a
        # ConcatenationVariable instead. ``issubclass(Concatenation, type(self))``
        # is only True when ``type(self)`` is Concatenation itself, so subclasses
        # (ConcatenationVariable, NumpyConcatenation, ...) are let through.
        if all(
            isinstance(child, pybamm.Variable) for child in children
        ) and issubclass(Concatenation, type(self)):
            raise TypeError(
                "'ConcatenationVariable' should be used for concatenating 'Variable' "
                "objects. We recommend using the 'concatenation' function, which will "
                "automatically choose the best form."
            )
        if name is None:
            name = "concatenation"
        if check_domain:
            domains = self.get_children_domains(children)
        else:
            domains = {"primary": []}
        self.concatenation_function = concat_fun
        super().__init__(name, children, domains=domains)

    @classmethod
    def _from_json(cls, snippet: dict):
        """
        Create a new instance from a JSON snippet.

        ``snippet`` must contain ``"name"``, ``"children"``, ``"domains"`` and
        ``"concat_fun"``. ``cls.__init__`` is bypassed, so no domain validation
        is performed.
        """
        instance = cls.__new__(cls)
        instance.concatenation_function = snippet["concat_fun"]
        super(Concatenation, instance).__init__(
            snippet["name"], tuple(snippet["children"]), domains=snippet["domains"]
        )
        return instance

    def __str__(self):
        """See :meth:`pybamm.Symbol.__str__()`."""
        # join() also handles the zero-children case, giving "name()"
        return f"{self.name}({', '.join(str(child) for child in self.children)})"

    def _diff(self, variable: pybamm.Symbol):
        """See :meth:`pybamm.Symbol._diff()`."""
        children_diffs = [child.diff(variable) for child in self.children]
        if len(children_diffs) == 1:
            # a one-child concatenation is just its child
            return children_diffs[0]
        return self.__class__(*children_diffs)

    def get_children_domains(self, children: Sequence[pybamm.Symbol]):
        """
        Combine the domains of ``children`` into the domains of the concatenation.

        The primary domains are concatenated in child order. Each child's primary
        domain must be non-empty and disjoint from those of the previous children.
        The auxiliary domains are taken from the first child. For every auxiliary
        level where the first child's domain is non-empty, each other child must
        have the same domain or an empty one at that level.

        Raises
        ------
        TypeError
            If a child is not a :class:`pybamm.Symbol`.
        :class:`pybamm.DomainError`
            If a child has an empty domain, the primary domains overlap, or the
            auxiliary domains are incompatible.
        """
        domain: list = []
        for child in children:
            if not isinstance(child, pybamm.Symbol):
                raise TypeError(f"{child} is not a pybamm symbol")
            child_domain = child.domain
            if child_domain == []:
                raise pybamm.DomainError(
                    f"Cannot concatenate child '{child}' with empty domain"
                )
            if set(domain).isdisjoint(child_domain):
                domain += child_domain
            else:
                raise pybamm.DomainError("domain of children must be disjoint")

        auxiliary_domains = children[0].domains
        for level, dom in auxiliary_domains.items():
            if level != "primary" and dom != []:
                for child in children[1:]:
                    if child.domains[level] not in [dom, []]:
                        raise pybamm.DomainError(
                            "children must have same or empty auxiliary domains"
                        )

        return {**auxiliary_domains, "primary": domain}

    def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
        """
        Join the evaluated children into a single array.

        Uses ``self.concatenation_function``, falling back to ``np.concatenate``
        (the same default used by :meth:`_evaluate_for_shape`) when none was
        given. Returns an empty array when there are no children.
        """
        if len(children_eval) == 0:
            return np.array([])
        concatenation_function = self.concatenation_function or np.concatenate
        return concatenation_function(children_eval)

    def evaluate(
        self,
        t: float | None = None,
        y: npt.NDArray | None = None,
        y_dot: npt.NDArray | None = None,
        inputs: dict | str | None = None,
    ):
        """See :meth:`pybamm.Symbol.evaluate()`."""
        children_eval = [child.evaluate(t, y, y_dot, inputs) for child in self.children]
        return self._concatenation_evaluate(children_eval)

    def create_copy(
        self,
        new_children: list[pybamm.Symbol] | None = None,
        perform_simplifications: bool = True,
    ):
        """See :meth:`pybamm.Symbol.new_copy()`."""
        children = self._children_for_copying(new_children)
        return self._concatenation_new_copy(children, perform_simplifications)

    def _concatenation_new_copy(self, children, perform_simplifications: bool = True):
        """
        Creates a copy for the current concatenation class using the convenience
        function :meth:`concatenation` to perform simplifications based on the new
        children before creating the new copy.

        If ``perform_simplifications`` is False, the class is instantiated
        directly with the same name.
        """
        if perform_simplifications:
            return concatenation(*children, name=self.name)
        return self.__class__(*children, name=self.name)

    def _concatenation_jac(self, children_jacs):
        """Calculate the Jacobian of a concatenation (implemented by subclasses)."""
        raise NotImplementedError

    def _evaluate_for_shape(self):
        """See :meth:`pybamm.Symbol.evaluate_for_shape`"""
        if len(self.children) == 0:
            return np.array([])
        # Default: use np.concatenate
        concatenation_function = self.concatenation_function or np.concatenate
        return concatenation_function(
            [child.evaluate_for_shape() for child in self.children]
        )

    def is_constant(self):
        """See :meth:`pybamm.Symbol.is_constant()`."""
        return all(child.is_constant() for child in self.children)

    def _sympy_operator(self, *children):
        """
        Apply appropriate SymPy operators.

        Stores the LaTeX of each child in ``self.concat_latex``. Returns a SymPy
        symbol named after ``print_name`` if it is set; otherwise returns a
        ``cases`` block with one child per row.
        """
        self.concat_latex = tuple(map(sympy.latex, children))
        if self.print_name is not None:
            return sympy.Symbol(self.print_name)
        concat_str = r"\\".join(self.concat_latex)
        return sympy.Symbol(r"\begin{cases}" + concat_str + r"\end{cases}")

    def to_equation(self):
        """Convert the node and its subtree into a SymPy equation."""
        return self._sympy_operator(*[child.to_equation() for child in self.children])
class NumpyConcatenation(Concatenation):
    """
    A node in the expression tree representing a concatenation of equations, when we
    *don't* care about domains. The class :class:`pybamm.DomainConcatenation`, which
    *is* careful about domains and uses broadcasting where appropriate, should be used
    whenever possible instead.

    Upon evaluation, equations are concatenated using numpy concatenation.

    Parameters
    ----------
    children : iterable of :class:`pybamm.Symbol`
        The equations to concatenate
    """

    def __init__(self, *children: pybamm.Symbol):
        # Scalars cannot be joined with np.concatenate, so anything that evaluates
        # to a number is turned into a length-1 vector first.
        children = [
            child * pybamm.Vector([1]) if child.evaluates_to_number() else child
            for child in children
        ]
        super().__init__(
            *children,
            name="numpy_concatenation",
            check_domain=False,
            concat_fun=np.concatenate,
        )

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.Concatenation._from_json()`."""
        # Build a new dict rather than writing into the caller's snippet
        snippet = {
            **snippet,
            "name": "numpy_concatenation",
            "concat_fun": np.concatenate,
        }
        return super()._from_json(snippet)

    def _concatenation_jac(self, children_jacs):
        """See :meth:`pybamm.Concatenation._concatenation_jac()`."""
        if len(self.children) == 0:
            return pybamm.Scalar(0)
        # Jacobian rows are stacked in the same order as the concatenated children
        return SparseStack(*children_jacs)

    def _concatenation_new_copy(
        self,
        children,
        perform_simplifications: bool = True,
    ):
        """
        See :meth:`pybamm.Concatenation._concatenation_new_copy()`.

        Raises
        ------
        NotImplementedError
            If ``perform_simplifications`` is False.
        """
        if perform_simplifications:
            return numpy_concatenation(*children)
        raise NotImplementedError(
            f"{self.__class__.__name__} should always be copied using "
            "simplification checks"
        )
class DomainConcatenation(Concatenation):
    """
    A node in the expression tree representing a concatenation of symbols, being
    careful about domains.

    It is assumed that each child has a domain. The final concatenated vector
    respects the number of mesh points of each domain in ``full_mesh`` and the order
    of the children's domains. For each repeat of the auxiliary (secondary,
    tertiary, ...) dimensions, the blocks of every primary domain follow one
    another.

    Parameters
    ----------
    children : iterable of :class:`pybamm.Symbol`
        The symbols to concatenate
    full_mesh : :class:`pybamm.Mesh`
        The underlying mesh for discretisation, used to obtain the number of mesh points
        in each domain.
    copy_this : :class:`pybamm.DomainConcatenation` (optional)
        if provided, this class is initialised by copying everything except the children
        from `copy_this`. `full_mesh` is not used in this case
    """

    def __init__(
        self,
        children: Sequence[pybamm.Symbol],
        full_mesh: pybamm.Mesh,
        copy_this: pybamm.DomainConcatenation | None = None,
    ):
        children = list(children)
        # The base class combines and validates the children's domains (in order)
        super().__init__(*children, name="domain_concatenation")
        if copy_this is None:
            self._full_mesh = full_mesh
            # number of times the primary-domain blocks repeat: the product of the
            # mesh sizes of all non-empty auxiliary domains
            self.secondary_dimensions_npts = self._get_auxiliary_domain_repeats(
                self.domains
            )
            # dict of domain => slices of the final vector
            self._slices = self.create_slices(self)
            # size of the final vector: end of the last slice of the last domain
            self._size = self._slices[self.domain[-1]][-1].stop
            # for each child, dict of domain => slices of that child's vector
            self._children_slices = [
                self.create_slices(child) for child in self.children
            ]
        else:
            self._full_mesh = copy.copy(copy_this._full_mesh)
            self._slices = copy.copy(copy_this._slices)
            self._size = copy.copy(copy_this._size)
            self._children_slices = copy.copy(copy_this._children_slices)
            self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.Concatenation._from_json()`."""
        # Build a new dict rather than writing into the caller's snippet
        snippet = {**snippet, "name": "domain_concatenation", "concat_fun": None}
        instance = super()._from_json(snippet)

        def repack_defaultDict(slices):
            # JSON stores slices as {"start", "stop", "step"} dicts
            return defaultdict(
                list,
                {
                    domain: [slice(s["start"], s["stop"], s["step"]) for s in sls]
                    for domain, sls in slices.items()
                },
            )

        # The mesh is not serialised. Setting it to None keeps ``full_mesh`` and
        # ``create_copy`` (which only shallow-copies the mesh) usable on
        # deserialised nodes instead of raising AttributeError.
        instance._full_mesh = None
        instance._size = snippet["size"]
        instance._slices = repack_defaultDict(snippet["slices"])
        instance._children_slices = [
            repack_defaultDict(s) for s in snippet["children_slices"]
        ]
        instance.secondary_dimensions_npts = snippet["secondary_dimensions_npts"]
        return instance

    def _get_auxiliary_domain_repeats(self, auxiliary_domains: dict) -> int:
        """
        Return the product of the mesh sizes of all non-empty, non-primary
        domains in ``auxiliary_domains`` (1 if there are none).
        """
        mesh_pts = 1
        for level, dom in auxiliary_domains.items():
            if level != "primary" and dom != []:
                mesh_pts *= self.full_mesh[dom].npts
        return mesh_pts

    @property
    def full_mesh(self):
        return self._full_mesh

    def create_slices(self, node: pybamm.Symbol) -> defaultdict:
        """
        Map each primary domain of ``node`` to its slices of ``node``'s vector.

        For each repeat of the auxiliary dimensions, one contiguous slice per
        primary domain is appended, sized by that domain's number of mesh points.

        Raises
        ------
        ValueError
            If ``node``'s auxiliary domains repeat a different number of times than
            the concatenation's.
        """
        node_repeats = self._get_auxiliary_domain_repeats(node.domains)
        if node_repeats != self.secondary_dimensions_npts:
            raise ValueError(
                "Concatenation and children must have the same number of "
                "points in secondary dimensions"
            )
        slices = defaultdict(list)
        start = 0
        for _ in range(node_repeats):
            for dom in node.domain:
                end = start + self.full_mesh[dom].npts
                slices[dom].append(slice(start, end))
                start = end
        return slices

    def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
        """See :meth:`Concatenation._concatenation_evaluate()`."""
        # preallocate the column vector, then copy each child's blocks into place
        vector = np.empty((self._size, 1))
        for child_vector, slices in zip(children_eval, self._children_slices):
            for child_dom, child_slice in slices.items():
                for i, _slice in enumerate(child_slice):
                    vector[self._slices[child_dom][i]] = child_vector[_slice]
        return vector

    def _concatenation_jac(self, children_jacs):
        """See :meth:`pybamm.Concatenation._concatenation_jac()`."""
        # Rows are stacked in the same order as the concatenated vector: for each
        # repeat of the auxiliary dimensions, one block per child. This assumes
        # the children are in the right order and have one domain each.
        jacs = []
        for i in range(self.secondary_dimensions_npts):
            for child_jac, slices in zip(children_jacs, self._children_slices):
                if len(slices) > 1:
                    raise NotImplementedError(
                        "jacobian only implemented for when each child has "
                        "a single domain"
                    )
                child_slice = next(iter(slices.values()))
                jacs.append(pybamm.Index(child_jac, child_slice[i]))
        return SparseStack(*jacs)

    def _concatenation_new_copy(
        self, children: list[pybamm.Symbol], perform_simplifications: bool = True
    ):
        """See :meth:`pybamm.Concatenation._concatenation_new_copy()`."""
        if perform_simplifications:
            return simplified_domain_concatenation(
                children, self.full_mesh, copy_this=self
            )
        return DomainConcatenation(children, self.full_mesh, copy_this=self)

    def to_json(self):
        """
        Method to serialise a DomainConcatenation object into JSON.

        Slices are stored as ``{"start", "stop", "step"}`` dicts.
        """

        def unpack_defaultDict(slices):
            return {
                domain: [{"start": s.start, "stop": s.stop, "step": s.step} for s in sls]
                for domain, sls in slices.items()
            }

        return {
            "name": self.name,
            "id": self.id,
            "domains": self.domains,
            "slices": unpack_defaultDict(self._slices),
            "size": self._size,
            "children_slices": [
                unpack_defaultDict(child_slice) for child_slice in self._children_slices
            ],
            "secondary_dimensions_npts": self.secondary_dimensions_npts,
        }
class SparseStack(Concatenation):
    """
    A node in the expression tree representing a concatenation of sparse
    matrices. As with NumpyConcatenation, we *don't* care about domains.
    The class :class:`pybamm.DomainConcatenation`, which *is* careful about
    domains and uses broadcasting where appropriate, should be used whenever
    possible instead.

    The children are stacked vertically: with ``scipy.sparse.vstack`` if at least
    one child's shape-evaluation is sparse, otherwise with ``np.vstack``.

    Parameters
    ----------
    children : iterable of :class:`Concatenation`
        The equations to concatenate
    """

    def __init__(self, *children):
        children = list(children)
        has_sparse_child = any(
            issparse(child.evaluate_for_shape()) for child in children
        )
        super().__init__(
            *children,
            name="sparse_stack",
            check_domain=False,
            concat_fun=vstack if has_sparse_child else np.vstack,
        )

    def _concatenation_new_copy(self, children, perform_simplifications=True):
        """
        See :meth:`pybamm.Concatenation._concatenation_new_copy()`.

        ``perform_simplifications`` is ignored; a plain SparseStack is returned.
        """
        return SparseStack(*children)
class ConcatenationVariable(Concatenation):
    """
    A Variable representing a concatenation of variables.

    Unless ``name`` is given, the name is the longest substring shared by all the
    children's names (capitalised if longer than one character). ``print_name`` is
    derived the same way from the children's raw print names, with one trailing
    underscore removed. All children must share the same scale, reference and
    bounds, which are copied onto this node.
    """

    @staticmethod
    def _common_text(texts):
        # Fold ``intersect`` over the texts, starting from the first two
        shared = intersect(texts[0], texts[1])
        for text in texts[2:]:
            shared = intersect(shared, text)
        return shared

    def __init__(self, *children, name: str | None = None):
        if name is None:
            shared = self._common_text([child.name for child in children])
            if not shared:
                name = None
            elif len(shared) == 1:
                # a single character is left as-is
                name = shared
            else:
                name = shared[0].capitalize() + shared[1:]

        if len(children) > 0:
            ref = children[0]
            if not all(child.scale == ref.scale for child in children):
                raise ValueError("Cannot concatenate symbols with different scales")
            self._scale = ref.scale

            if not all(child.reference == ref.reference for child in children):
                raise ValueError("Cannot concatenate symbols with different references")
            self._reference = ref.reference

            lower_bounds_match = all(
                child.bounds[0] == ref.bounds[0] for child in children
            )
            if not (
                lower_bounds_match
                and all(child.bounds[1] == ref.bounds[1] for child in children)
            ):
                raise ValueError("Cannot concatenate symbols with different bounds")
            self.bounds = ref.bounds

        super().__init__(*children, name=name)

        print_name = self._common_text([child._raw_print_name for child in children])
        self.print_name = print_name[:-1] if print_name.endswith("_") else print_name
def substrings(s: str):
    """
    Yield every non-empty contiguous substring of ``s``.

    Substrings are produced ordered by start index, then by end index;
    duplicates are not removed.
    """
    length = len(s)
    for start in range(length):
        for stop in range(start + 1, length + 1):
            yield s[start:stop]
def intersect(s1: str, s2: str):
    """
    Return the longest common substring of ``s1`` and ``s2``, stripped of
    leading and trailing whitespace.

    If several different substrings share the maximum length, the one that occurs
    earliest in ``s1`` is returned, so the result is deterministic. It does not
    depend on set iteration order, which varies with string hash randomisation.
    Returns ``""`` when the strings have no character in common.
    """
    best_len = 0
    best_end = 0  # end index (exclusive) in s1 of the best match so far
    # prev[j] is the length of the longest common suffix of s1[:i-1] and s2[:j]
    prev = [0] * (len(s2) + 1)
    for i in range(1, len(s1) + 1):
        curr = [0] * (len(s2) + 1)
        char = s1[i - 1]
        for j in range(1, len(s2) + 1):
            if char == s2[j - 1]:
                run = prev[j - 1] + 1
                curr[j] = run
                # strict ">" keeps the earliest occurrence in s1 on ties
                if run > best_len:
                    best_len = run
                    best_end = i
        prev = curr
    return s1[best_end - best_len : best_end].strip()
def simplified_concatenation(*children, name: str | None = None):
    """
    Perform simplifications on a concatenation.

    ``None`` children are dropped. An empty result raises ValueError, and a single
    child is returned as-is. If every child is a Variable, a ConcatenationVariable
    is built. If every child is a Broadcast of the same child, a single broadcast
    across all the domains is returned. Otherwise a plain Concatenation is
    returned.
    """
    kept = [child for child in children if child is not None]

    if not kept:
        raise ValueError("Cannot create empty concatenation")
    if len(kept) == 1:
        return kept[0]
    if all(isinstance(child, pybamm.Variable) for child in kept):
        return pybamm.ConcatenationVariable(*kept, name=name)

    # Build the Concatenation first: it combines and validates the domains
    concat = Concatenation(*kept, name=name)

    # Simplify a concatenation of broadcasts that all share the same child to a
    # single broadcast across all domains
    first = kept[0]
    all_same_broadcast = all(
        isinstance(child, pybamm.Broadcast) and child.child == first.child
        for child in kept
    )
    if not all_same_broadcast:
        return concat

    unique_child = first.orphans[0]
    if isinstance(first, pybamm.PrimaryBroadcast):
        return pybamm.PrimaryBroadcast(unique_child, concat.domain, name=name)
    return pybamm.FullBroadcast(
        unique_child, broadcast_domains=concat.domains, name=name
    )
def concatenation(*children, name: str | None = None):
    """
    Helper function to create concatenations.

    Delegates to :func:`simplified_concatenation`, so the returned symbol may be a
    simpler form than a plain Concatenation.
    """
    # TODO: add option to turn off simplifications
    simplified = simplified_concatenation(*children, name=name)
    return simplified
def simplified_numpy_concatenation(*children):
    """
    Perform simplifications on a numpy concatenation.

    Any child that is itself a NumpyConcatenation is replaced by its own children
    (one level deep), so nested numpy concatenations are flattened. The result is
    passed through ``pybamm.simplify_if_constant``.
    """
    flattened = []
    for child in children:
        flattened.extend(
            child.orphans if isinstance(child, NumpyConcatenation) else [child]
        )
    return pybamm.simplify_if_constant(NumpyConcatenation(*flattened))
def numpy_concatenation(*children):
    """
    Helper function to create numpy concatenations.

    Delegates to :func:`simplified_numpy_concatenation`.
    """
    # TODO: add option to turn off simplifications
    simplified = simplified_numpy_concatenation(*children)
    return simplified
def simplified_domain_concatenation(
    children: list[pybamm.Symbol],
    mesh: pybamm.Mesh,
    copy_this: DomainConcatenation | None = None,
):
    """
    Perform simplifications on a domain concatenation.

    Builds a :class:`DomainConcatenation` of ``children``. If every child is a
    :class:`pybamm.StateVector`, the children's zero-padded evaluation arrays are
    summed. When that sum is exactly 1 everywhere between the first child's first
    slice start and the last child's last slice stop, a single StateVector over
    that range (with the concatenation's domains) is returned. Otherwise the
    concatenation is passed through ``pybamm.simplify_if_constant``.

    Parameters
    ----------
    children : list of :class:`pybamm.Symbol`
        The symbols to concatenate.
    mesh : :class:`pybamm.Mesh`
        Mesh passed on to :class:`DomainConcatenation`.
    copy_this : :class:`DomainConcatenation`, optional
        Passed on to :class:`DomainConcatenation`; if given, the slices and size
        are copied from it instead of being computed from ``mesh``.
    """
    # Create the DomainConcatenation to read domain and child domain
    concat = DomainConcatenation(children, mesh, copy_this=copy_this)
    # Simplify Concatenation of StateVectors to a single StateVector
    # The sum of the evaluation arrays of the StateVectors must be exactly 1
    # over the spanned range (no gaps, no overlaps)
    if all(isinstance(child, pybamm.StateVector) for child in children):
        sv_children: list[pybamm.StateVector] = children  # type: ignore[assignment]
        # NOTE(review): assumes the last child has the longest evaluation array,
        # i.e. children are ordered by their position in y. Otherwise the padding
        # length below is negative and np.zeros raises. Children may presumably
        # own several interleaved y-slices (one per auxiliary-dimension repeat),
        # which is why evaluation arrays are compared rather than slice bounds.
        longest_eval_array = len(sv_children[-1]._evaluation_array)
        eval_arrays = {}
        for child in sv_children:
            # zero-pad every evaluation array to a common length so they can be summed
            eval_arrays[child] = np.concatenate(
                [
                    child.evaluation_array,
                    np.zeros(longest_eval_array - len(child.evaluation_array)),
                ]
            )
        first_start = sv_children[0].y_slices[0].start
        last_stop = sv_children[-1].y_slices[-1].stop
        if all(
            sum(array for array in eval_arrays.values())[first_start:last_stop] == 1
        ):
            return pybamm.StateVector(
                slice(first_start, last_stop), domains=concat.domains
            )

    return pybamm.simplify_if_constant(concat)
def domain_concatenation(children: list[pybamm.Symbol], mesh: pybamm.Mesh):
    """
    Helper function to create domain concatenations.

    Delegates to :func:`simplified_domain_concatenation` (without ``copy_this``).
    """
    # TODO: add option to turn off simplifications
    simplified = simplified_domain_concatenation(children, mesh)
    return simplified