Source code for pybamm.expression_tree.concatenations

#
# Concatenation classes
#
from __future__ import annotations

import copy
from collections import defaultdict
from collections.abc import Sequence
from typing import Any

import casadi
import numpy as np
import numpy.typing as npt
import sympy
from scipy.sparse import issparse, vstack

import pybamm


class Concatenation(pybamm.Symbol):
    """
    A node in the expression tree representing a concatenation of symbols.

    Parameters
    ----------
    children : iterable of :class:`pybamm.Symbol`
        The symbols to concatenate
    name : str, optional
        Name of the node; "concatenation" is used when omitted.
    check_domain : bool, optional
        When True (default) the children's domains are combined and validated
        with :meth:`get_children_domains`; when False the node is given an
        empty primary domain.
    concat_fun : callable, optional
        Function applied to the list of evaluated children when the node is
        evaluated.
    """

    def __init__(
        self,
        *children: pybamm.Symbol,
        name: str | None = None,
        check_domain=True,
        concat_fun=None,
    ):
        # The second condition checks whether this is the base Concatenation class
        # or a subclass of Concatenation
        # (ConcatenationVariable, NumpyConcatenation, ...)
        if all(isinstance(child, pybamm.Variable) for child in children) and issubclass(
            Concatenation, type(self)
        ):
            raise TypeError(
                "'ConcatenationVariable' should be used for concatenating 'Variable' "
                "objects. We recommend using the 'concatenation' function, which will "
                "automatically choose the best form."
            )
        if name is None:
            name = "concatenation"
        if check_domain:
            domains = self.get_children_domains(children)
        else:
            domains = {"primary": []}
        self.concatenation_function = concat_fun

        super().__init__(name, children, domains=domains)

    @classmethod
    def _from_json(cls, snippet: dict):
        """
        Creates a new Concatenation instance from a json object.

        Reads the "concat_fun", "name", "children" and "domains" entries of
        ``snippet``. The instance is built with ``__new__`` and initialised
        through the parent-class constructor, so this class's own ``__init__``
        checks are not run.
        """
        instance = cls.__new__(cls)

        instance.concatenation_function = snippet["concat_fun"]

        super(Concatenation, instance).__init__(
            snippet["name"], tuple(snippet["children"]), domains=snippet["domains"]
        )

        return instance

    def __str__(self):
        """See :meth:`pybamm.Symbol.__str__()`."""
        # Joining (rather than appending ", " and slicing it off afterwards) keeps
        # the name intact when there are no children, e.g. "numpy_concatenation()"
        joined_children = ", ".join(str(child) for child in self.children)
        return f"{self.name}({joined_children})"

    def _diff(self, variable: pybamm.Symbol):
        """See :meth:`pybamm.Symbol._diff()`."""
        children_diffs = [child.diff(variable) for child in self.children]
        if len(children_diffs) == 1:
            diff = children_diffs[0]
        else:
            diff = self.__class__(*children_diffs)

        return diff

    def get_children_domains(self, children: Sequence[pybamm.Symbol]):
        """
        Combine the domains of ``children`` into the domains of the concatenation.

        The primary domain is the children's primary domains joined in order;
        the remaining levels are taken from the first child.

        Raises
        ------
        TypeError
            If a child is not a :class:`pybamm.Symbol`.
        pybamm.DomainError
            If a child has an empty domain, if the children's domains overlap,
            or if a later child has a non-empty auxiliary domain that differs
            from a non-empty auxiliary domain of the first child.
        """
        # combine domains from children
        domain: list = []
        for child in children:
            if not isinstance(child, pybamm.Symbol):
                raise TypeError(f"{child} is not a pybamm symbol")
            child_domain = child.domain
            if child_domain == []:
                raise pybamm.DomainError(
                    f"Cannot concatenate child '{child}' with empty domain"
                )
            if set(domain).isdisjoint(child_domain):
                domain += child_domain
            else:
                raise pybamm.DomainError("domain of children must be disjoint")

        # every other child must either share the first child's auxiliary domains
        # or leave them empty
        auxiliary_domains = children[0].domains
        for level, dom in auxiliary_domains.items():
            if level != "primary" and dom != []:
                for child in children[1:]:
                    if child.domains[level] not in [dom, []]:
                        raise pybamm.DomainError(
                            "children must have same or empty auxiliary domains"
                        )

        domains = {**auxiliary_domains, "primary": domain}

        return domains

    def _concatenation_evaluate(self, children_eval: list[npt.NDArray[Any]]):
        """
        Concatenate the already-evaluated children.

        Returns an empty array when there are no children; otherwise applies
        ``self.concatenation_function`` to the list of evaluated children.
        """
        if len(children_eval) == 0:
            return np.array([])
        else:
            return self.concatenation_function(children_eval)

    def evaluate(
        self,
        t: float | None = None,
        y: npt.NDArray[np.float64] | None = None,
        y_dot: npt.NDArray[np.float64] | None = None,
        inputs: dict | str | None = None,
    ):
        """See :meth:`pybamm.Symbol.evaluate()`."""
        children_eval = [child.evaluate(t, y, y_dot, inputs) for child in self.children]
        return self._concatenation_evaluate(children_eval)

    def create_copy(
        self,
        new_children: list[pybamm.Symbol] | None = None,
        perform_simplifications: bool = True,
    ):
        """See :meth:`pybamm.Symbol.new_copy()`."""
        children = self._children_for_copying(new_children)

        return self._concatenation_new_copy(children, perform_simplifications)

    def _concatenation_new_copy(self, children, perform_simplifications: bool = True):
        """
        Creates a copy for the current concatenation class using the convenience
        function :meth:`concatenation` to perform simplifications based on the new
        children before creating the new copy.
        """
        if perform_simplifications:
            return concatenation(*children, name=self.name)
        else:
            return self.__class__(*children, name=self.name)

    def _concatenation_jac(self, children_jacs):
        """Calculate the Jacobian of a concatenation."""
        raise NotImplementedError

    def _evaluate_for_shape(self):
        """See :meth:`pybamm.Symbol.evaluate_for_shape`"""
        if len(self.children) == 0:
            return np.array([])
        else:
            # Default: use np.concatenate
            concatenation_function = self.concatenation_function or np.concatenate
            return concatenation_function(
                [child.evaluate_for_shape() for child in self.children]
            )

    def is_constant(self):
        """See :meth:`pybamm.Symbol.is_constant()`."""
        return all(child.is_constant() for child in self.children)

    def _sympy_operator(self, *children):
        """Apply appropriate SymPy operators."""
        # the LaTeX of each child is also stored on the instance
        self.concat_latex = tuple(map(sympy.latex, children))

        if self.print_name is not None:
            return sympy.Symbol(self.print_name)
        else:
            concat_str = r"\\".join(self.concat_latex)
            concat_sym = sympy.Symbol(r"\begin{cases}" + concat_str + r"\end{cases}")
            return concat_sym

    def to_equation(self):
        """Convert the node and its subtree into a SymPy equation."""
        eq_list = [child.to_equation() for child in self.children]
        return self._sympy_operator(*eq_list)
class NumpyConcatenation(Concatenation):
    """
    A node in the expression tree representing a concatenation of equations, when
    we *don't* care about domains. The class :class:`pybamm.DomainConcatenation`,
    which *is* careful about domains and uses broadcasting where appropriate,
    should be used whenever possible instead.

    Upon evaluation, equations are concatenated using numpy concatenation.

    Parameters
    ----------
    children : iterable of :class:`pybamm.Symbol`
        The equations to concatenate
    """

    def __init__(self, *children: pybamm.Symbol):
        # Anything that evaluates to a scalar is multiplied by a length-one Vector
        # so that every child evaluates to a vector and can be concatenated
        vector_children = [
            child * pybamm.Vector([1]) if child.evaluates_to_number() else child
            for child in children
        ]
        super().__init__(
            *vector_children,
            name="numpy_concatenation",
            check_domain=False,
            concat_fun=np.concatenate,
        )

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.Concatenation._from_json()`."""
        # name and concatenation function are fixed for this class; they are
        # written into the caller's snippet before delegating to the base class
        snippet.update(name="numpy_concatenation", concat_fun=np.concatenate)
        return super()._from_json(snippet)

    def _to_casadi(self, t, y, y_dot, inputs, casadi_symbols):
        """See :meth:`pybamm.Symbol._to_casadi()`."""
        casadi_children = self._children_to_casadi(t, y, y_dot, inputs, casadi_symbols)
        return casadi.vertcat(*casadi_children)

    def _concatenation_jac(self, children_jacs):
        """See :meth:`pybamm.Concatenation._concatenation_jac()`."""
        if len(self.children) == 0:
            return pybamm.Scalar(0)
        return SparseStack(*children_jacs)

    def _concatenation_new_copy(
        self,
        children,
        perform_simplifications: bool = True,
    ):
        """See :meth:`pybamm.Concatenation._concatenation_new_copy()`."""
        if not perform_simplifications:
            raise NotImplementedError(
                f"{self.__class__.__name__} should always be copied using "
                "simplification checks"
            )
        return numpy_concatenation(*children)
class DomainConcatenation(Concatenation):
    """
    A node in the expression tree representing a concatenation of symbols, being
    careful about domains.

    It is assumed that each child has a domain, and the final concatenated vector
    will respect the sizes and ordering of domains established in mesh keys

    Parameters
    ----------
    children : iterable of :class:`pybamm.Symbol`
        The symbols to concatenate
    full_mesh : :class:`pybamm.Mesh`
        The underlying mesh for discretisation, used to obtain the number of mesh
        points in each domain.
    copy_this : :class:`pybamm.DomainConcatenation` (optional)
        if provided, this class is initialised by copying everything except the
        children from `copy_this`. `mesh` is not used in this case
    """

    def __init__(
        self,
        children: Sequence[pybamm.Symbol],
        full_mesh: pybamm.Mesh,
        copy_this: pybamm.DomainConcatenation | None = None,
    ):
        # Convert any constant symbols in children to a Vector of the right size for
        # concatenation
        children = list(children)

        # Allow the base class to sort the domains into the correct order
        super().__init__(*children, name="domain_concatenation")

        if copy_this is None:
            # store mesh
            self._full_mesh = full_mesh

            # create dict of domain => slice of final vector
            self.secondary_dimensions_npts = self._get_auxiliary_domain_repeats(
                self.domains
            )
            self._slices = self.create_slices(self)

            # store size of final vector
            self._size = self._slices[self.domain[-1]][-1].stop

            # create dict of domain => slice for each child
            self._children_slices = [
                self.create_slices(child) for child in self.children
            ]
        else:
            # shallow-copy the cached mesh and slice information from the template
            self._full_mesh = copy.copy(copy_this._full_mesh)
            self._slices = copy.copy(copy_this._slices)
            self._size = copy.copy(copy_this._size)
            self._children_slices = copy.copy(copy_this._children_slices)
            self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts

    @classmethod
    def _from_json(cls, snippet: dict):
        """
        See :meth:`pybamm.Concatenation._from_json()`.

        In addition to the base-class entries, reads "size", "slices",
        "children_slices" and "secondary_dimensions_npts" from ``snippet``.
        No mesh is restored by this method.
        """
        snippet["name"] = "domain_concatenation"
        snippet["concat_fun"] = None

        instance = super()._from_json(snippet)

        def repack_defaultDict(slices):
            # rebuild real slice objects from their {"start", "stop", "step"} form
            return defaultdict(
                list,
                {
                    domain: [slice(s["start"], s["stop"], s["step"]) for s in sls]
                    for domain, sls in slices.items()
                },
            )

        instance._size = snippet["size"]
        instance._slices = repack_defaultDict(snippet["slices"])
        instance._children_slices = [
            repack_defaultDict(s) for s in snippet["children_slices"]
        ]
        instance.secondary_dimensions_npts = snippet["secondary_dimensions_npts"]

        return instance

    def _get_auxiliary_domain_repeats(self, auxiliary_domains: dict) -> int:
        """Helper method to read the 'auxiliary_domain' meshes."""
        # product of the number of mesh points over all non-empty, non-primary levels
        mesh_pts = 1
        for level, dom in auxiliary_domains.items():
            if level != "primary" and dom != []:
                mesh_pts *= self.full_mesh[dom].npts
        return mesh_pts

    @property
    def full_mesh(self):
        """The mesh stored on this concatenation."""
        return self._full_mesh

    def create_slices(self, node: pybamm.Symbol) -> defaultdict:
        """
        Build a mapping from each domain of ``node`` to a list of slices, one per
        secondary-dimension point, laid out contiguously starting from zero.

        Raises
        ------
        ValueError
            If the number of secondary points computed from ``self.domains``
            differs from ``self.secondary_dimensions_npts``.
        """
        slices = defaultdict(list)
        start = 0
        end = 0
        # NOTE(review): the repeat count is derived from self.domains rather than
        # node.domains, so this compares the concatenation with its own stored
        # value - confirm whether node.domains was intended
        second_pts = self._get_auxiliary_domain_repeats(self.domains)
        if second_pts != self.secondary_dimensions_npts:
            raise ValueError(
                "Concatenation and children must have the same number of "
                "points in secondary dimensions"
            )
        for _ in range(second_pts):
            for dom in node.domain:
                end += self.full_mesh[dom].npts
                slices[dom].append(slice(start, end))
                start = end
        return slices

    def _concatenation_evaluate(self, children_eval: list[npt.NDArray[Any]]):
        """
        Write each evaluated child into its slices of a preallocated column
        vector of shape ``(self._size, 1)`` and return that vector.
        """
        # preallocate vector
        vector = np.empty((self._size, 1))

        # loop through domains of children writing subvectors to final vector
        for child_vector, slices in zip(
            children_eval, self._children_slices, strict=True
        ):
            for child_dom, child_slice in slices.items():
                for i, _slice in enumerate(child_slice):
                    vector[self._slices[child_dom][i]] = child_vector[_slice]

        return vector

    def _to_casadi(self, t, y, y_dot, inputs, casadi_symbols):
        """See :meth:`pybamm.Symbol._to_casadi()`."""
        converted_children = self._children_to_casadi(
            t, y, y_dot, inputs, casadi_symbols
        )
        all_child_vectors = []
        for i in range(self.secondary_dimensions_npts):
            # (start in the final vector, child sub-vector) pairs for this
            # secondary point only; rebuilding the list each time keeps every
            # start paired with the sub-vector it belongs to
            placed_vectors = []
            for child_var, slices in zip(
                converted_children, self._children_slices, strict=True
            ):
                for child_dom, child_slice in slices.items():
                    placed_vectors.append(
                        (
                            self._slices[child_dom][i].start,
                            child_var[child_slice[i].start : child_slice[i].stop],
                        )
                    )
            # order by position in the final vector; sorting on the start alone
            # means the casadi objects themselves are never compared
            placed_vectors.sort(key=lambda placed: placed[0])
            all_child_vectors.extend(vector for _, vector in placed_vectors)
        return casadi.vertcat(*all_child_vectors)

    def _concatenation_jac(self, children_jacs):
        """See :meth:`pybamm.Concatenation._concatenation_jac()`."""
        # note that this assumes that the children are in the right order and only have
        # one domain each
        jacs = []
        for i in range(self.secondary_dimensions_npts):
            for child_jac, slices in zip(
                children_jacs, self._children_slices, strict=True
            ):
                if len(slices) > 1:
                    raise NotImplementedError(
                        "jacobian only implemented for when each child has "
                        "a single domain"
                    )
                child_slice = next(iter(slices.values()))
                jacs.append(pybamm.Index(child_jac, child_slice[i]))
        return SparseStack(*jacs)

    def _concatenation_new_copy(
        self, children: list[pybamm.Symbol], perform_simplifications: bool = True
    ):
        """See :meth:`pybamm.Concatenation._concatenation_new_copy()`."""
        if perform_simplifications:
            return simplified_domain_concatenation(
                children, self.full_mesh, copy_this=self
            )
        else:
            return DomainConcatenation(children, self.full_mesh, copy_this=self)

    def to_json(self):
        """
        Method to serialise a DomainConcatenation object into JSON.

        Slices are written as ``{"start", "stop", "step"}`` dictionaries.
        """

        def unpack_defaultDict(slices):
            # plain dict of domain => list of serialisable slice descriptions
            return {
                domain: [
                    {"start": s.start, "stop": s.stop, "step": s.step} for s in sls
                ]
                for domain, sls in slices.items()
            }

        json_dict = {
            "name": self.name,
            "id": self.id,
            "domains": self.domains,
            "slices": unpack_defaultDict(self._slices),
            "size": self._size,
            "children_slices": [
                unpack_defaultDict(child_slice) for child_slice in self._children_slices
            ],
            "secondary_dimensions_npts": self.secondary_dimensions_npts,
        }

        return json_dict
class SparseStack(Concatenation):
    """
    A node in the expression tree representing a concatenation of sparse
    matrices. As with NumpyConcatenation, we *don't* care about domains.
    The class :class:`pybamm.DomainConcatenation`, which *is* careful about
    domains and uses broadcasting where appropriate, should be used whenever
    possible instead.

    Parameters
    ----------
    children : iterable of :class:`Concatenation`
        The equations to concatenate
    """

    def __init__(self, *children, name="sparse_stack"):
        children = list(children)
        # scipy's sparse vstack is needed as soon as one child has a sparse shape
        # evaluation; otherwise plain numpy stacking is used
        any_sparse = any(issparse(child.evaluate_for_shape()) for child in children)
        stacking_function = vstack if any_sparse else np.vstack
        super().__init__(
            *children,
            name=name,
            check_domain=False,
            concat_fun=stacking_function,
        )

    def _to_casadi(self, t, y, y_dot, inputs, casadi_symbols):
        """See :meth:`pybamm.Symbol._to_casadi()`."""
        casadi_children = self._children_to_casadi(t, y, y_dot, inputs, casadi_symbols)
        return casadi.vertcat(*casadi_children)

    def _concatenation_new_copy(self, children, perform_simplifications=True):
        """See :meth:`pybamm.Concatenation._concatenation_new_copy()`."""
        # perform_simplifications is accepted for interface compatibility only
        return SparseStack(*children)
class ConcatenationVariable(Concatenation):
    """
    A Variable representing a concatenation of variables.

    Parameters
    ----------
    children : iterable of :class:`pybamm.Variable`
        The variables to concatenate. They must share the same scale, reference
        and bounds.
    name : str, optional
        Name of the node. When omitted, it is built from the text common to all
        of the children's names.

    Raises
    ------
    ValueError
        If the children do not all have the same scale, reference or bounds.
    """

    def __init__(self, *children, name: str | None = None):
        if name is None:
            # Name is the intersection of the children names (should usually make
            # sense if the children have been named consistently). Folding from the
            # first child also covers the single-child case.
            name = children[0].name
            for child in children[1:]:
                name = intersect(name, child.name)
            if len(name) == 0:
                name = None
            # name is unchanged if its length is 1
            elif len(name) > 1:
                name = name[0].capitalize() + name[1:]

        if len(children) > 0:
            first = children[0]
            if all(child.scale == first.scale for child in children):
                self._scale = first.scale
            else:
                raise ValueError("Cannot concatenate symbols with different scales")
            if all(child.reference == first.reference for child in children):
                self._reference = first.reference
            else:
                raise ValueError("Cannot concatenate symbols with different references")
            if all(child.bounds[0] == first.bounds[0] for child in children) and all(
                child.bounds[1] == first.bounds[1] for child in children
            ):
                self.bounds = first.bounds
            else:
                raise ValueError("Cannot concatenate symbols with different bounds")

        super().__init__(*children, name=name)

        # the print name is the text common to all of the children's raw print names
        print_name = children[0]._raw_print_name
        for child in children[1:]:
            print_name = intersect(print_name, child._raw_print_name)
        if print_name.endswith("_"):
            print_name = print_name[:-1]

        self.print_name = print_name


def substrings(s: str):
    """Yield every non-empty contiguous substring of ``s``."""
    for i in range(len(s)):
        for j in range(i, len(s)):
            yield s[i : j + 1]


def intersect(s1: str, s2: str):
    """
    Return the longest substring common to ``s1`` and ``s2``, with leading and
    trailing white space removed.

    Returns an empty string when the two strings share no characters. When
    several common substrings have the same maximal length, the one that occurs
    first in ``s1`` is chosen, so the result does not depend on set ordering.
    """
    substrings_of_s2 = set(substrings(s2))
    longest = ""
    for candidate in substrings(s1):
        # strict comparison keeps the earliest of equally long candidates
        if len(candidate) > len(longest) and candidate in substrings_of_s2:
            longest = candidate
    # remove leading and trailing white space
    return longest.strip()


def simplified_concatenation(*children, name: str | None = None):
    """
    Perform simplifications on a concatenation.

    ``None`` children are dropped. A single remaining child is returned as is,
    children that are all Variables become a ConcatenationVariable, and
    Broadcasts that all share the same child collapse to one broadcast over the
    combined domains; otherwise a plain Concatenation is returned.

    Raises
    ------
    ValueError
        If no children remain once ``None`` entries are removed.
    """
    # remove children that are None
    children = [child for child in children if child is not None]

    if len(children) == 0:
        raise ValueError("Cannot create empty concatenation")
    if len(children) == 1:
        return children[0]
    if all(isinstance(child, pybamm.Variable) for child in children):
        return pybamm.ConcatenationVariable(*children, name=name)

    # Create Concatenation to easily read domains
    concat = Concatenation(*children, name=name)

    # Simplify concatenation of broadcasts all with the same child to a single
    # broadcast across all domains
    if all(
        isinstance(child, pybamm.Broadcast) and child.child == children[0].child
        for child in children
    ):
        unique_child = children[0].orphans[0]
        if isinstance(children[0], pybamm.PrimaryBroadcast):
            return pybamm.PrimaryBroadcast(unique_child, concat.domain, name=name)
        return pybamm.FullBroadcast(
            unique_child, broadcast_domains=concat.domains, name=name
        )
    return concat


def concatenation(*children, name: str | None = None):
    """Helper function to create concatenations."""
    # TODO: add option to turn off simplifications
    return simplified_concatenation(*children, name=name)


def simplified_numpy_concatenation(*children):
    """Perform simplifications on a numpy concatenation."""
    # Turn a concatenation of concatenations into a single concatenation
    new_children = []
    for child in children:
        # extract any children from numpy concatenation
        if isinstance(child, NumpyConcatenation):
            new_children.extend(child.orphans)
        else:
            new_children.append(child)
    return pybamm.simplify_if_constant(NumpyConcatenation(*new_children))
def numpy_concatenation(*children):
    """
    Helper function to create numpy concatenations.

    Forwards all children to :func:`simplified_numpy_concatenation` and returns
    its result.
    """
    # TODO: add option to turn off simplifications
    simplified = simplified_numpy_concatenation(*children)
    return simplified
def simplified_domain_concatenation(
    children: list[pybamm.Symbol],
    mesh: pybamm.Mesh,
    copy_this: DomainConcatenation | None = None,
):
    """
    Perform simplifications on a domain concatenation.

    When every child is a StateVector and their zero-padded evaluation arrays
    sum to exactly one over the range from the first child's first slice start
    to the last child's last slice stop, a single StateVector over that range is
    returned. Otherwise the DomainConcatenation is returned after passing
    through ``pybamm.simplify_if_constant``.
    """
    # Create the DomainConcatenation to read domain and child domain
    concat = DomainConcatenation(children, mesh, copy_this=copy_this)

    if not all(isinstance(child, pybamm.StateVector) for child in children):
        return pybamm.simplify_if_constant(concat)

    # Simplify Concatenation of StateVectors to a single StateVector
    # The sum of the evaluation arrays of the StateVectors must be exactly 1
    sv_children: list[pybamm.StateVector] = children  # type: ignore[assignment]
    longest_eval_array = len(sv_children[-1]._evaluation_array)
    # pad each evaluation array with zeros up to the length of the last child's
    eval_arrays = {
        child: np.concatenate(
            [
                child.evaluation_array,
                np.zeros(longest_eval_array - len(child.evaluation_array)),
            ]
        )
        for child in sv_children
    }
    first_start = sv_children[0].y_slices[0].start
    last_stop = sv_children[-1].y_slices[-1].stop
    coverage = sum(eval_arrays.values())[first_start:last_stop]
    if all(coverage == 1):
        return pybamm.StateVector(
            slice(first_start, last_stop), domains=concat.domains
        )

    return pybamm.simplify_if_constant(concat)
def domain_concatenation(children: list[pybamm.Symbol], mesh: pybamm.Mesh):
    """
    Helper function to create domain concatenations.

    Forwards ``children`` and ``mesh`` to
    :func:`simplified_domain_concatenation` and returns its result.
    """
    # TODO: add option to turn off simplifications
    simplified = simplified_domain_concatenation(children, mesh)
    return simplified