# Source code for pybamm.expression_tree.averages

#
# Classes and methods for averaging
#
from __future__ import annotations

from collections.abc import Callable
from typing import ClassVar

import pybamm


def _is_independent_of(
    symbol: pybamm.Symbol, domain_matches: Callable[[str], bool]
) -> bool:
    """Return True if ``symbol`` cannot vary over the domains that
    ``domain_matches`` accepts.

    The check is deliberately conservative. ``symbol`` is treated as
    *dependent* if any node in its tree is:

    * a ``Variable``, ``VariableDot`` or ``SpatialVariable`` that lives on a
      matching domain at *any* level (primary, secondary, tertiary or
      quaternary). For example, a particle concentration whose secondary
      domain is an electrode still varies in x, and an x-averaged quantity
      whose leaves carry a "current collector" auxiliary domain still varies
      in z; or
    * a ``Concatenation`` on a matching domain. Gluing per-subdomain pieces
      together (even pieces that are constants) gives a piecewise function.

    Broadcast nodes themselves are not inspected. A broadcast of a quantity
    from a non-matching domain is therefore still treated as independent,
    even though the broadcast node carries a matching domain.

    Parameters
    ----------
    symbol : :class:`pybamm.Symbol`
        Expression to inspect.
    domain_matches : callable
        Predicate returning True for domain names relevant to the average.

    Returns
    -------
    bool
        False as soon as a dependent node is found, True otherwise.
    """
    dependent_types = (
        pybamm.Variable,
        pybamm.VariableDot,
        pybamm.SpatialVariable,
        pybamm.Concatenation,
    )
    for node in symbol.pre_order():
        if not isinstance(node, dependent_types):
            continue
        # Check every domain level. Checking only the primary level misses
        # auxiliary-domain dependence (e.g. particle variables in x).
        if any(
            domain_matches(dom)
            for domain in node.domains.values()
            for dom in domain
        ):
            return False
    return True


class _BaseAverage(pybamm.Integral):
    """
    Base class for a symbol representing an average

    Parameters
    -----------
    child : :class:`pybamm.Symbol`
        The child node
    name : str
        Name given to the resulting node, e.g. ``"x-average"``
    integration_variable : :class:`pybamm.IndependentVariable` or list
        The variable(s) passed on to :class:`pybamm.Integral`
    """

    def __init__(
        self,
        child: pybamm.Symbol,
        name: str,
        integration_variable: (
            list[pybamm.IndependentVariable] | pybamm.IndependentVariable
        ),
    ) -> None:
        super().__init__(child, integration_variable)
        # Assigned after the base constructor so it replaces any name that
        # constructor may have set.
        self.name = name

    @classmethod
    def domain_matches(cls, d: str) -> bool:
        """Return True if domain string ``d`` is relevant for this average type."""
        raise NotImplementedError  # pragma: no cover

    @classmethod
    def symbol_is_constant(cls, symbol: pybamm.Symbol) -> bool:
        """Return True if ``symbol`` is independent of this average's domain."""
        return _is_independent_of(symbol, cls.domain_matches)

    @classmethod
    def _try_separable(
        cls,
        symbol: pybamm.Symbol,
        average_fn,
    ) -> pybamm.Symbol | None:
        """Rewrite ``avg(symbol)`` using linearity and constant-factor pull-out.

        * ``Addition`` / ``Subtraction`` always split: ``avg(a±b) = avg(a) ± avg(b)``.
        * ``Multiplication`` splits when at least one factor is constant under
          this average: ``avg(c*f) = c*avg(f)``.
        * ``Division`` splits only when the *denominator* is constant:
          ``avg(f/c) = avg(f)/c``. A constant numerator is not enough, because
          ``avg(c/f) != c/avg(f)`` in general.

        Parameters
        ----------
        symbol : :class:`pybamm.Symbol`
            Expression being averaged.
        average_fn : callable
            Applied to each operand (normally the subclass's ``from_symbol``).

        Returns
        -------
        :class:`pybamm.Symbol` or None
            The rewritten expression, or ``None`` when no rule applies.
        """
        operator = type(symbol)
        if isinstance(symbol, pybamm.Addition | pybamm.Subtraction):
            left, right = symbol.orphans
            return operator(average_fn(left), average_fn(right))
        if isinstance(symbol, pybamm.Multiplication):
            left, right = symbol.orphans
            if cls.symbol_is_constant(left) or cls.symbol_is_constant(right):
                return operator(average_fn(left), average_fn(right))
        elif isinstance(symbol, pybamm.Division):
            numerator, denominator = symbol.orphans
            # Only a constant denominator can be pulled out of the average.
            if cls.symbol_is_constant(denominator):
                return operator(average_fn(numerator), average_fn(denominator))
        return None

    @classmethod
    def from_symbol(cls, symbol: pybamm.Symbol) -> pybamm.Symbol:
        """Create average from symbol with simplifications."""
        raise NotImplementedError  # pragma: no cover


class XAverage(_BaseAverage):
    """Average of a symbol over the through-cell (x) direction.

    Parameters
    ----------
    child : :class:`pybamm.Symbol`
        The symbol to be averaged. Its primary domain selects the x variable.
    """

    DOMAINS: ClassVar[tuple[str, ...]] = (
        "negative electrode",
        "separator",
        "positive electrode",
    )

    def __init__(self, child: pybamm.Symbol) -> None:
        # Particle-domain children integrate over the standard electrode
        # coordinate (x_n / x_p). Anything else gets an x variable on the
        # child's own domain.
        if all(n in child.domain[0] for n in ["negative", "particle"]):
            x = pybamm.standard_spatial_vars.x_n
        elif all(n in child.domain[0] for n in ["positive", "particle"]):
            x = pybamm.standard_spatial_vars.x_p
        else:
            x = pybamm.SpatialVariable("x", domain=child.domain)
        super().__init__(child, "x-average", x)

    def _unary_new_copy(
        self, child: pybamm.Symbol, perform_simplifications: bool = True
    ):
        """Rebuild this average around ``child``, simplifying unless disabled."""
        if perform_simplifications:
            return self.from_symbol(child)
        return XAverage(child)

    @classmethod
    def domain_matches(cls, d: str) -> bool:
        """Return True for the electrode/separator domains averaged over."""
        return d in cls.DOMAINS

    @classmethod
    def from_symbol(cls, symbol: pybamm.Symbol) -> pybamm.Symbol:
        """Create x-average with simplifications.

        Parameters
        ----------
        symbol : :class:`pybamm.Symbol`
            The symbol to be averaged.

        Returns
        -------
        :class:`pybamm.Symbol`
            ``symbol`` unchanged if it has no electrode/separator domain at any
            level, a simplified equivalent where a rewrite rule applies, or an
            :class:`XAverage` node otherwise.

        Raises
        ------
        ValueError
            If ``symbol`` evaluates on edges and is not a broadcast.
        """
        # Can't take average if symbol evaluates on edges (unless broadcasted)
        if symbol.evaluates_on_edges("primary") and not isinstance(
            symbol, pybamm.Broadcast
        ):
            raise ValueError(
                "Can't take the x-average of a symbol that evaluates on edges"
            )

        # If symbol doesn't have an electrode domain at any level, return unchanged
        if not any(
            any(dom in cls.DOMAINS for dom in domain)
            for domain in symbol.domains.values()
        ):
            return symbol

        # If symbol is a broadcast, reduce by one dimension where possible
        if isinstance(
            symbol,
            pybamm.PrimaryBroadcast | pybamm.SecondaryBroadcast | pybamm.FullBroadcast,
        ):
            if all(dom in cls.DOMAINS for dom in symbol.broadcast_domain):
                # Broadcast *onto* x: the broadcast child does not vary in x.
                return symbol.reduce_one_dimension()
            if isinstance(symbol, pybamm.PrimaryBroadcast):
                # Broadcast onto a non-x domain: average the child, re-broadcast.
                return pybamm.PrimaryBroadcast(
                    cls.from_symbol(symbol.orphans[0]), symbol.broadcast_domain
                )
            if isinstance(symbol, pybamm.FullBroadcast) and all(
                dom in cls.DOMAINS for dom in symbol.secondary_domain
            ):
                # Drop the x-dependent secondary level and shift the rest up.
                domains = {
                    "primary": symbol.domains["primary"],
                    "secondary": symbol.domains["tertiary"],
                    "tertiary": symbol.domains["quaternary"],
                }
                return pybamm.FullBroadcast(
                    symbol.orphans[0], broadcast_domains=domains
                )
            if isinstance(symbol, pybamm.FullBroadcast) and all(
                dom in cls.DOMAINS for dom in symbol.tertiary_domain
            ):
                # Drop the x-dependent tertiary level and shift the rest up.
                domains = {
                    "primary": symbol.domains["primary"],
                    "secondary": symbol.domains["secondary"],
                    "tertiary": symbol.domains["quaternary"],
                }
                return pybamm.FullBroadcast(
                    symbol.orphans[0], broadcast_domains=domains
                )
            # No shortcut for this layout (e.g. a SecondaryBroadcast onto a
            # non-electrode domain of an x-dependent child). The general
            # integral is still valid, so use it instead of raising.
            return cls(symbol)

        # Concatenation: thickness-weighted average of children
        if isinstance(symbol, pybamm.Concatenation) and not isinstance(
            symbol, pybamm.ConcatenationVariable
        ):
            geo = pybamm.geometric_parameters
            ls = {
                ("negative electrode",): geo.n.L,
                ("separator",): geo.s.L,
                ("positive electrode",): geo.p.L,
                ("separator", "positive electrode"): geo.s.L + geo.p.L,
            }
            children = symbol.orphans
            keys = [tuple(orp.domain) for orp in children]
            if all(key in ls for key in keys):
                weighted = sum(
                    ls[key] * cls.from_symbol(orp)
                    for key, orp in zip(keys, children)
                )
                return weighted / sum(ls[key] for key in keys)
            # Unrecognised sub-domain layout: fall through to the general
            # integral rather than failing with a KeyError.

        # Linearity + constant-factor pull-out
        simplified = cls._try_separable(symbol, cls.from_symbol)
        if simplified is not None:
            return simplified

        return cls(symbol)


class ZAverage(_BaseAverage):
    """Average of a symbol over the z-direction of the current collector."""

    DOMAINS: ClassVar[tuple[str, ...]] = ("current collector",)

    def __init__(self, child: pybamm.Symbol) -> None:
        # z-averages always integrate along the standard z coordinate.
        z_coord = pybamm.standard_spatial_vars.z
        super().__init__(child, "z-average", [z_coord])

    def _unary_new_copy(
        self, child: pybamm.Symbol, perform_simplifications: bool = True
    ):
        """Rebuild this average around ``child``, simplifying unless disabled."""
        if not perform_simplifications:
            return ZAverage(child)
        return self.from_symbol(child)

    @classmethod
    def domain_matches(cls, d: str) -> bool:
        """Only the current-collector domain is relevant for a z-average."""
        return d in cls.DOMAINS

    @classmethod
    def from_symbol(cls, symbol: pybamm.Symbol) -> pybamm.Symbol:
        """Create a z-average of ``symbol``, simplifying where possible.

        Raises
        ------
        ValueError
            If ``symbol`` evaluates on edges.
        :class:`pybamm.DomainError`
            If the primary domain is neither empty nor ``["current collector"]``.
        """
        if symbol.evaluates_on_edges("primary"):
            raise ValueError(
                "Can't take the z-average of a symbol that evaluates on edges"
            )

        primary = symbol.domain
        if primary not in [[], ["current collector"]]:
            raise pybamm.DomainError(
                "z-average only implemented in the 'current collector' domain, "
                f"but symbol has domains {symbol.domain}"
            )
        if not primary:
            # No spatial extent: the average is the symbol itself.
            return symbol
        if isinstance(symbol, pybamm.Broadcast):
            return symbol.reduce_one_dimension()

        separated = cls._try_separable(symbol, cls.from_symbol)
        return cls(symbol) if separated is None else separated


class YZAverage(_BaseAverage):
    """Average of a symbol over both the y- and z-directions of the current
    collector."""

    DOMAINS: ClassVar[tuple[str, ...]] = ("current collector",)

    def __init__(self, child: pybamm.Symbol) -> None:
        coords = pybamm.standard_spatial_vars
        # Integrate over y first, then z, using the standard coordinates.
        super().__init__(child, "yz-average", [coords.y, coords.z])

    def _unary_new_copy(
        self, child: pybamm.Symbol, perform_simplifications: bool = True
    ):
        """Rebuild this average around ``child``, simplifying unless disabled."""
        if not perform_simplifications:
            return YZAverage(child)
        return self.from_symbol(child)

    @classmethod
    def domain_matches(cls, d: str) -> bool:
        """Only the current-collector domain is relevant for a y-z average."""
        return d in cls.DOMAINS

    @classmethod
    def from_symbol(cls, symbol: pybamm.Symbol) -> pybamm.Symbol:
        """Create a y-z average of ``symbol``, simplifying where possible.

        Raises
        ------
        :class:`pybamm.DomainError`
            If the primary domain is neither empty nor ``["current collector"]``.
        """
        primary = symbol.domain
        if primary not in [[], ["current collector"]]:
            raise pybamm.DomainError(
                "y-z-average only implemented in the 'current collector' domain, "
                f"but symbol has domains {symbol.domain}"
            )
        if not primary:
            # No spatial extent: the average is the symbol itself.
            return symbol
        if isinstance(symbol, pybamm.Broadcast):
            return symbol.reduce_one_dimension()

        separated = cls._try_separable(symbol, cls.from_symbol)
        return cls(symbol) if separated is None else separated


class RAverage(_BaseAverage):
    """Average of a symbol over the radial (r) direction of a particle."""

    def __init__(self, child: pybamm.Symbol) -> None:
        # The radial coordinate lives on the child's own (particle) domain.
        radial = pybamm.SpatialVariable("r", child.domain)
        super().__init__(child, "r-average", [radial])

    def _unary_new_copy(
        self, child: pybamm.Symbol, perform_simplifications: bool = True
    ):
        """Rebuild this average around ``child``, simplifying unless disabled."""
        if not perform_simplifications:
            return RAverage(child)
        return self.from_symbol(child)

    @classmethod
    def domain_matches(cls, d: str) -> bool:
        """True for particle domains, excluding particle-size domains."""
        if d.endswith("particle size"):
            return False
        return d.endswith("particle")

    @classmethod
    def from_symbol(cls, symbol: pybamm.Symbol) -> pybamm.Symbol:
        """Create an r-average of ``symbol``, simplifying where possible.

        Raises
        ------
        ValueError
            If ``symbol`` evaluates on edges.
        """
        if symbol.evaluates_on_edges("primary"):
            raise ValueError(
                "Can't take the r-average of a symbol that evaluates on edges"
            )

        primary = symbol.domain
        if not primary or not primary[0].endswith("particle"):
            # Nothing radial to average over.
            return symbol

        # Secondary broadcast onto an electrode: r-average the child, then
        # broadcast the result back over that electrode.
        if isinstance(symbol, pybamm.SecondaryBroadcast):
            electrode = symbol.domains["secondary"]
            if electrode in [["positive electrode"], ["negative electrode"]]:
                return pybamm.PrimaryBroadcast(
                    cls.from_symbol(symbol.orphans[0]), electrode
                )

        # Primary/full broadcast onto the particle domain is constant in r.
        if isinstance(symbol, pybamm.PrimaryBroadcast | pybamm.FullBroadcast):
            return symbol.reduce_one_dimension()

        separated = cls._try_separable(symbol, cls.from_symbol)
        return cls(symbol) if separated is None else separated


class SizeAverage(_BaseAverage):
    """Size average uses weighted distribution. Does NOT support separable rewrite
    because the weight (f_a_dist) depends on the symbol's domain and cannot be
    meaningfully reassigned to sub-expressions.

    Parameters
    ----------
    child : :class:`pybamm.Symbol`
        The symbol to be averaged over particle size.
    f_a_dist : :class:`pybamm.Symbol`
        Area-weighted particle-size distribution used as the averaging weight.
    """

    DOMAINS: ClassVar[tuple[list[str], ...]] = (
        ["negative particle size"],
        ["positive particle size"],
        ["negative primary particle size"],
        ["positive primary particle size"],
        ["negative secondary particle size"],
        ["positive secondary particle size"],
    )

    def __init__(self, child: pybamm.Symbol, f_a_dist) -> None:
        R = pybamm.SpatialVariable("R", domains=child.domains, coord_sys="cartesian")
        integration_variable: list[pybamm.IndependentVariable] = [R]
        super().__init__(child, "size-average", integration_variable)
        self.f_a_dist = f_a_dist

    def _unary_new_copy(
        self, child: pybamm.Symbol, perform_simplifications: bool = True
    ):
        """Rebuild this average around ``child``, keeping the same weight."""
        if perform_simplifications:
            return self.from_symbol(child, f_a_dist=self.f_a_dist)
        return SizeAverage(child, f_a_dist=self.f_a_dist)

    @classmethod
    def domain_matches(cls, d: str) -> bool:
        """True for any particle-size domain."""
        return d.endswith("particle size")

    @classmethod
    def _has_size_domain(cls, symbol: pybamm.Symbol) -> bool:
        """Check if symbol has any particle size domain."""
        return any(
            list(domain) in list(cls.DOMAINS) for domain in symbol.domains.values()
        )

    @classmethod
    def _get_f_a_dist(cls, symbol: pybamm.Symbol) -> pybamm.Symbol | None:
        """Compute area-weighted distribution for the symbol's domain."""
        geo = pybamm.geometric_parameters
        # Build a name such as "R_n_prim" from the primary domain string.
        name = "R"
        if "negative" in symbol.domain[0]:
            name += "_n"
        elif "positive" in symbol.domain[0]:
            name += "_p"
        if "primary" in symbol.domain[0]:
            name += "_prim"
        elif "secondary" in symbol.domain[0]:
            name += "_sec"

        R = pybamm.SpatialVariable(name, domains=symbol.domains, coord_sys="cartesian")

        # The plain "particle size" domains share the primary-phase distribution.
        domains = symbol.domains
        if ["negative particle size"] in domains.values() or [
            "negative primary particle size"
        ] in domains.values():
            return geo.n.prim.f_a_dist(R)
        if ["negative secondary particle size"] in domains.values():
            return geo.n.sec.f_a_dist(R)
        if ["positive particle size"] in domains.values() or [
            "positive primary particle size"
        ] in domains.values():
            return geo.p.prim.f_a_dist(R)
        if ["positive secondary particle size"] in domains.values():
            return geo.p.sec.f_a_dist(R)
        return None  # pragma: no cover

    @classmethod
    def from_symbol(
        cls, symbol: pybamm.Symbol, f_a_dist: pybamm.Symbol | None = None
    ) -> pybamm.Symbol:
        """Create size-average with simplifications.

        Note: Does NOT use separable rewrite because the weighted average
        with distribution-dependent weight cannot be naively split without
        breaking conservation.

        Parameters
        ----------
        symbol : :class:`pybamm.Symbol`
            The symbol to be averaged.
        f_a_dist : :class:`pybamm.Symbol`, optional
            Weighting distribution. If None, it is derived from the symbol's
            size domain via ``_get_f_a_dist``.

        Raises
        ------
        ValueError
            If ``symbol`` evaluates on edges.
        """
        if symbol.evaluates_on_edges("primary"):
            raise ValueError(
                "Can't take the size-average of a symbol that evaluates on edges"
            )

        # If no size domain, return unchanged
        if symbol.domain == [] or not cls._has_size_domain(symbol):
            return symbol

        # A broadcast onto any size domain in DOMAINS does not vary with R, so
        # (as was already assumed for the plain "particle size" domains) its
        # size average is the broadcast child. This applies equally to the
        # primary/secondary variants.
        if isinstance(symbol, pybamm.PrimaryBroadcast) and symbol.domain in cls.DOMAINS:
            return symbol.orphans[0]

        if (
            isinstance(symbol, pybamm.SecondaryBroadcast)
            and symbol.domains["secondary"] in cls.DOMAINS
        ):
            return symbol.orphans[0]

        # Compute f_a_dist if not provided
        if f_a_dist is None:
            f_a_dist = cls._get_f_a_dist(symbol)

        return cls(symbol, f_a_dist)


# Convenience functions (thin wrappers)
def x_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
    """
    Convenience function for creating an average in the x-direction.

    Parameters
    ----------
    symbol : :class:`pybamm.Symbol`
        The function to be averaged

    Returns
    -------
    :class:`Symbol`
        the new averaged symbol
    """
    return XAverage.from_symbol(symbol)
def z_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
    """
    Convenience function for creating an average in the z-direction.

    Parameters
    ----------
    symbol : :class:`pybamm.Symbol`
        The function to be averaged

    Returns
    -------
    :class:`Symbol`
        the new averaged symbol
    """
    return ZAverage.from_symbol(symbol)
def yz_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
    """
    Convenience function for creating an average in the y-z-direction.

    Parameters
    ----------
    symbol : :class:`pybamm.Symbol`
        The function to be averaged

    Returns
    -------
    :class:`Symbol`
        the new averaged symbol
    """
    return YZAverage.from_symbol(symbol)
def xyz_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
    """
    Convenience function for averaging in x, then over y and z.

    Parameters
    ----------
    symbol : :class:`pybamm.Symbol`
        The function to be averaged

    Returns
    -------
    :class:`Symbol`
        the new averaged symbol
    """
    return YZAverage.from_symbol(XAverage.from_symbol(symbol))
def r_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
    """
    Convenience function for creating an average in the r-direction.

    Parameters
    ----------
    symbol : :class:`pybamm.Symbol`
        The function to be averaged

    Returns
    -------
    :class:`Symbol`
        the new averaged symbol
    """
    return RAverage.from_symbol(symbol)
def size_average(
    symbol: pybamm.Symbol, f_a_dist: pybamm.Symbol | None = None
) -> pybamm.Symbol:
    """Convenience function for averaging over particle size R using the
    area-weighted particle-size distribution.

    Parameters
    ----------
    symbol : :class:`pybamm.Symbol`
        The function to be averaged
    f_a_dist : :class:`pybamm.Symbol`, optional
        The weighting distribution. If None, it is derived from the symbol's
        particle-size domain.

    Returns
    -------
    :class:`Symbol`
        the new averaged symbol
    """
    return SizeAverage.from_symbol(symbol, f_a_dist)
def xyzs_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
    """
    Convenience function for averaging over particle size, then x, y and z.

    Parameters
    ----------
    symbol : :class:`pybamm.Symbol`
        The function to be averaged

    Returns
    -------
    :class:`Symbol`
        the new averaged symbol
    """
    return xyz_average(size_average(symbol))