#
# Function classes and methods
#
import numbers
import autograd
import numpy as np
import sympy
from scipy import special
import pybamm
class Function(pybamm.Symbol):
    """
    A node in the expression tree representing an arbitrary function.

    Parameters
    ----------
    function : method
        A function can have 0 or many inputs. If no inputs are given,
        self.evaluate() simply returns func(). Otherwise,
        self.evaluate(t, y, y_dot, inputs) returns
        func(child0.evaluate(t, y, y_dot, inputs),
        child1.evaluate(t, y, y_dot, inputs), etc).
    children : :class:`pybamm.Symbol`
        The children nodes to apply the function to. Plain numbers are
        wrapped in :class:`pybamm.Scalar`.
    name : str, optional
        Name of the node. If None (default), the name is
        "function (<function.__name__>)", or "function (<function.__class__>)"
        when the callable has no ``__name__``.
    derivative : str, optional
        Which derivative to use when differentiating ("autograd" or "derivative").
        Default is "autograd".
    differentiated_function : method, optional
        The function which was differentiated to obtain this one. Default is None.
    """

    def __init__(
        self,
        function,
        *children,
        name=None,
        derivative="autograd",
        differentiated_function=None,
    ):
        # Turn numbers into scalars
        children = list(children)
        for idx, child in enumerate(children):
            if isinstance(child, numbers.Number):
                children[idx] = pybamm.Scalar(child)

        if name is not None:
            self.name = name
        else:
            try:
                name = "function ({})".format(function.__name__)
            except AttributeError:
                # callables without a __name__ (e.g. functools.partial objects)
                name = "function ({})".format(function.__class__)
        domains = self.get_children_domains(children)

        self.function = function
        self.derivative = derivative
        self.differentiated_function = differentiated_function

        super().__init__(name, children=children, domains=domains)

    def __str__(self):
        """
        See :meth:`pybamm.Symbol.__str__()`.

        Default names of the form "function (<name>)" are printed as
        "<name>(child0, child1, ...)"; custom names are printed verbatim.
        A function with no children prints as "<name>()".
        """
        name = self.name
        # Only strip the "function (" ... ")" wrapper when it is actually there;
        # slicing unconditionally garbled user-supplied names.
        if name.startswith("function (") and name.endswith(")"):
            name = name[10:-1]
        args = ", ".join("{!s}".format(child) for child in self.children)
        return "{}({})".format(name, args)

    def diff(self, variable):
        """See :meth:`pybamm.Symbol.diff()`."""
        if variable == self:
            return pybamm.Scalar(1)
        else:
            children = self.orphans
            partial_derivatives = [None] * len(children)
            for i, child in enumerate(self.children):
                # if variable appears in the function, differentiate
                # function, and apply chain rule
                if variable in child.pre_order():
                    partial_derivatives[i] = self._function_diff(
                        children, i
                    ) * child.diff(variable)

            # remove None entries
            partial_derivatives = [x for x in partial_derivatives if x is not None]

            derivative = sum(partial_derivatives)
            # e.g. sum() of an empty list is the plain int 0; always return a
            # pybamm Scalar in that case
            if derivative == 0:
                derivative = pybamm.Scalar(0)

            return derivative

    def _function_diff(self, children, idx):
        """
        Derivative with respect to child number 'idx'.
        See :meth:`pybamm.Symbol._diff()`.

        Raises
        ------
        ValueError
            If ``self.derivative == "derivative"`` and there is more than one
            child.
        """
        # Store differentiated function, needed in case we want to convert to CasADi
        if self.derivative == "autograd":
            return Function(
                autograd.elementwise_grad(self.function, idx),
                *children,
                differentiated_function=self.function,
            )
        elif self.derivative == "derivative":
            if len(children) > 1:
                raise ValueError(
                    """
                    differentiation using '.derivative()' not implemented for functions
                    with more than one child
                    """
                )
            else:
                # keep using "derivative" as derivative
                return pybamm.Function(
                    self.function.derivative(),
                    *children,
                    derivative="derivative",
                    differentiated_function=self.function,
                )

    def _function_jac(self, children_jacs):
        """
        Calculate the Jacobian of a function.

        ``children_jacs[i]`` is the Jacobian of child ``i``; the result is the
        sum over non-constant children of (partial derivative * child Jacobian).
        """
        if all(child.evaluates_to_constant_number() for child in self.children):
            jacobian = pybamm.Scalar(0)
        else:
            # if at least one child contains variable dependence, then
            # calculate the required partial Jacobians and add them
            jacobian = None
            children = self.orphans
            for i, child in enumerate(children):
                if not child.evaluates_to_constant_number():
                    jac_fun = self._function_diff(children, i) * children_jacs[i]
                    jac_fun.clear_domains()
                    if jacobian is None:
                        jacobian = jac_fun
                    else:
                        jacobian += jac_fun

        return jacobian

    def evaluate(self, t=None, y=None, y_dot=None, inputs=None):
        """See :meth:`pybamm.Symbol.evaluate()`."""
        evaluated_children = [
            child.evaluate(t, y, y_dot, inputs) for child in self.children
        ]
        return self._function_evaluate(evaluated_children)

    def _evaluates_on_edges(self, dimension):
        """See :meth:`pybamm.Symbol._evaluates_on_edges()`."""
        return any(child.evaluates_on_edges(dimension) for child in self.children)

    def is_constant(self):
        """See :meth:`pybamm.Symbol.is_constant()`."""
        return all(child.is_constant() for child in self.children)

    def _evaluate_for_shape(self):
        """
        Default behaviour: apply the function to the shape-evaluations of all
        children. See :meth:`pybamm.Symbol.evaluate_for_shape()`
        """
        evaluated_children = [child.evaluate_for_shape() for child in self.children]
        return self._function_evaluate(evaluated_children)

    def _function_evaluate(self, evaluated_children):
        """Apply ``self.function`` to the already-evaluated children."""
        return self.function(*evaluated_children)

    def create_copy(self):
        """See :meth:`pybamm.Symbol.new_copy()`."""
        children_copy = [child.new_copy() for child in self.children]
        return self._function_new_copy(children_copy)

    def _function_new_copy(self, children):
        """
        Returns a new copy of the function.

        Inputs
        ------
        children : : list
            A list of the children of the function

        Returns
        -------
            : :pybamm.Function
            A new copy of the function
        """
        return pybamm.simplify_if_constant(
            pybamm.Function(
                self.function,
                *children,
                name=self.name,
                derivative=self.derivative,
                differentiated_function=self.differentiated_function,
            )
        )

    def _sympy_operator(self, child):
        """
        Apply appropriate SymPy operators.

        Default: return the child's SymPy expression unchanged; subclasses
        override this to wrap it in the matching SymPy function.
        """
        return child

    def to_equation(self):
        """Convert the node and its subtree into a SymPy equation."""
        if self.print_name is not None:
            return sympy.Symbol(self.print_name)
        else:
            eq_list = []
            for child in self.children:
                eq = child.to_equation()
                eq_list.append(eq)
            return self._sympy_operator(*eq_list)
def simplified_function(func_class, child):
    """
    Build ``func_class(child)``, applying simplifications first.

    If ``child`` is a :class:`pybamm.Broadcast`, the function is applied
    (recursively) to the broadcast's own child and the result is re-broadcast
    with the same broadcast type. Otherwise the function node is built
    directly. In both cases :func:`pybamm.simplify_if_constant` is applied.
    Currently only implemented for one-child functions.
    """
    if not isinstance(child, pybamm.Broadcast):
        return pybamm.simplify_if_constant(func_class(child))

    # Push the function inside the broadcast, recursing through nested ones
    inner = simplified_function(func_class, child.orphans[0])
    inner = pybamm.simplify_if_constant(inner)
    return child._unary_new_copy(inner)
class SpecificFunction(Function):
    """
    Parent class for the specific functions, which (typically) implement their
    own derivative (``_function_diff``) directly.

    Parameters
    ----------
    function : method
        Function to be applied to child
    child : :class:`pybamm.Symbol`
        The child to apply the function to
    """

    def __init__(self, function, child):
        super().__init__(function, child)

    def _function_new_copy(self, children):
        """See :meth:`pybamm.Function._function_new_copy()`"""
        # Subclass constructors take only the child, so rebuild via the
        # concrete class rather than the generic Function constructor.
        return pybamm.simplify_if_constant(self.__class__(*children))

    def _sympy_operator(self, child):
        """
        Apply the SymPy function named after the lower-cased class name
        (e.g. ``Exp`` -> ``sympy.exp``).

        If SymPy has no attribute of that name (e.g. ``Max`` -> ``sympy.max``),
        fall back to an undefined SymPy function with that name, so that
        ``to_equation`` yields a symbolic expression instead of raising
        ``AttributeError``.
        """
        class_name = self.__class__.__name__.lower()
        sympy_function = getattr(sympy, class_name, None)
        if sympy_function is None:
            sympy_function = sympy.Function(class_name)
        return sympy_function(child)
class Arcsinh(SpecificFunction):
    """Inverse hyperbolic sine, arcsinh(x)."""

    def __init__(self, child):
        super().__init__(np.arcsinh, child)

    def _function_diff(self, children, idx):
        """
        d/dx arcsinh(x) = 1 / sqrt(x**2 + 1).
        See :meth:`pybamm.Function._function_diff()`.
        """
        x = children[0]
        return 1 / sqrt(x**2 + 1)

    def _sympy_operator(self, child):
        """Use :func:`sympy.asinh`, SymPy's inverse hyperbolic sine."""
        return sympy.asinh(child)
def arcsinh(child):
    """Return the arcsinh of ``child`` as an expression-tree node."""
    node = simplified_function(Arcsinh, child)
    return node
class Arctan(SpecificFunction):
    """Inverse tangent, arctan(x)."""

    def __init__(self, child):
        super().__init__(np.arctan, child)

    def _function_diff(self, children, idx):
        """
        d/dx arctan(x) = 1 / (x**2 + 1).
        See :meth:`pybamm.Function._function_diff()`.
        """
        x = children[0]
        return 1 / (x**2 + 1)

    def _sympy_operator(self, child):
        """Use :func:`sympy.atan`, SymPy's inverse tangent."""
        return sympy.atan(child)
def arctan(child):
    """Returns arctan (inverse tangent) function of child."""
    return simplified_function(Arctan, child)
class Cos(SpecificFunction):
    """Cosine, cos(x)."""

    def __init__(self, child):
        super().__init__(np.cos, child)

    def _function_diff(self, children, idx):
        """
        d/dx cos(x) = -sin(x).
        See :meth:`pybamm.Function._function_diff()`.
        """
        x = children[0]
        return -sin(x)
def cos(child):
    """Return the cosine of ``child`` as an expression-tree node."""
    node = simplified_function(Cos, child)
    return node
class Cosh(SpecificFunction):
    """Hyperbolic cosine, cosh(x)."""

    def __init__(self, child):
        super().__init__(np.cosh, child)

    def _function_diff(self, children, idx):
        """
        d/dx cosh(x) = sinh(x).
        See :meth:`pybamm.Function._function_diff()`.
        """
        x = children[0]
        return sinh(x)
def cosh(child):
    """Return the hyperbolic cosine of ``child`` as an expression-tree node."""
    node = simplified_function(Cosh, child)
    return node
class Erf(SpecificFunction):
    """Error function, erf(x), evaluated with :func:`scipy.special.erf`."""

    def __init__(self, child):
        super().__init__(special.erf, child)

    def _function_diff(self, children, idx):
        """
        d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2).
        See :meth:`pybamm.Function._function_diff()`.
        """
        prefactor = 2 / np.sqrt(np.pi)
        x = children[0]
        return prefactor * exp(-(x**2))
def erf(child):
    """Return the error function of ``child`` as an expression-tree node."""
    node = simplified_function(Erf, child)
    return node
def erfc(child):
    """Return the complementary error function, erfc(child) = 1 - erf(child)."""
    erf_node = simplified_function(Erf, child)
    return 1 - erf_node
class Exp(SpecificFunction):
    """Exponential, exp(x)."""

    def __init__(self, child):
        super().__init__(np.exp, child)

    def _function_diff(self, children, idx):
        """
        d/dx exp(x) = exp(x).
        See :meth:`pybamm.Function._function_diff()`.
        """
        x = children[0]
        return exp(x)
def exp(child):
    """Return the exponential of ``child`` as an expression-tree node."""
    node = simplified_function(Exp, child)
    return node
class Log(SpecificFunction):
    """Natural logarithm, log(x)."""

    def __init__(self, child):
        super().__init__(np.log, child)

    def _function_evaluate(self, evaluated_children):
        """
        Evaluate ``np.log`` on the evaluated child, silencing numpy's
        "invalid value" RuntimeWarning (raised e.g. for negative inputs, which
        give NaN). Divide-by-zero warnings (log(0)) are not suppressed.
        """
        with np.errstate(invalid="ignore"):
            result = np.log(*evaluated_children)
        return result

    def _function_diff(self, children, idx):
        """
        d/dx log(x) = 1 / x.
        See :meth:`pybamm.Function._function_diff()`.
        """
        x = children[0]
        return 1 / x
def log(child, base="e"):
    """
    Return the logarithm of ``child``.

    Parameters
    ----------
    child : :class:`pybamm.Symbol` or number
        The argument of the logarithm.
    base : str or number, optional
        Base of the logarithm. The default "e" gives the natural logarithm;
        any other value gives log(child) / np.log(base).
    """
    natural_log = simplified_function(Log, child)
    if base == "e":
        return natural_log
    return natural_log / np.log(base)
def log10(child):
    """Return the base-10 logarithm of ``child``."""
    return log(child, 10)
class Max(SpecificFunction):
    """Maximum over all entries of the child, computed with ``np.max``."""

    def __init__(self, child):
        super().__init__(np.max, child)

    def _evaluate_for_shape(self):
        """
        Shape placeholder for the result. ``np.max`` (no axis) reduces its
        input to a single value, so the shape is always (1, 1).
        """
        return np.full((1, 1), np.nan)
def max(child):
    """
    Return the maximum over all entries of ``child``.

    Not to be confused with :meth:`pybamm.maximum`, which returns the larger
    of two objects.
    """
    node = Max(child)
    return pybamm.simplify_if_constant(node)
class Min(SpecificFunction):
    """Minimum over all entries of the child, computed with ``np.min``."""

    def __init__(self, child):
        super().__init__(np.min, child)

    def _evaluate_for_shape(self):
        """
        Shape placeholder for the result. ``np.min`` (no axis) reduces its
        input to a single value, so the shape is always (1, 1).
        """
        return np.full((1, 1), np.nan)
def min(child):
    """
    Return the minimum over all entries of ``child``.

    Not to be confused with :meth:`pybamm.minimum`, which returns the smaller
    of two objects.
    """
    node = Min(child)
    return pybamm.simplify_if_constant(node)
def sech(child):
    """Return the hyperbolic secant, sech(child) = 1 / cosh(child)."""
    cosh_node = simplified_function(Cosh, child)
    return 1 / cosh_node
class Sin(SpecificFunction):
    """Sine, sin(x)."""

    def __init__(self, child):
        super().__init__(np.sin, child)

    def _function_diff(self, children, idx):
        """
        d/dx sin(x) = cos(x).
        See :meth:`pybamm.Function._function_diff()`.
        """
        x = children[0]
        return cos(x)
def sin(child):
    """Return the sine of ``child`` as an expression-tree node."""
    node = simplified_function(Sin, child)
    return node
class Sinh(SpecificFunction):
    """Hyperbolic sine, sinh(x)."""

    def __init__(self, child):
        super().__init__(np.sinh, child)

    def _function_diff(self, children, idx):
        """
        d/dx sinh(x) = cosh(x).
        See :meth:`pybamm.Function._function_diff()`.
        """
        x = children[0]
        return cosh(x)
def sinh(child):
    """Return the hyperbolic sine of ``child`` as an expression-tree node."""
    node = simplified_function(Sinh, child)
    return node
class Sqrt(SpecificFunction):
    """Square root, sqrt(x)."""

    def __init__(self, child):
        super().__init__(np.sqrt, child)

    def _function_evaluate(self, evaluated_children):
        """
        Evaluate ``np.sqrt`` on the evaluated child, silencing numpy's
        "invalid value" RuntimeWarning (raised for negative inputs, which
        give NaN).
        """
        with np.errstate(invalid="ignore"):
            result = np.sqrt(*evaluated_children)
        return result

    def _function_diff(self, children, idx):
        """
        d/dx sqrt(x) = 1 / (2 * sqrt(x)).
        See :meth:`pybamm.Function._function_diff()`.
        """
        x = children[0]
        return 1 / (2 * sqrt(x))
def sqrt(child):
    """Return the square root of ``child`` as an expression-tree node."""
    node = simplified_function(Sqrt, child)
    return node
class Tanh(SpecificFunction):
    """Hyperbolic tangent, tanh(x)."""

    def __init__(self, child):
        super().__init__(np.tanh, child)

    def _function_diff(self, children, idx):
        """
        d/dx tanh(x) = sech(x)**2.
        See :meth:`pybamm.Function._function_diff()`.
        """
        x = children[0]
        secant = sech(x)
        return secant**2
def tanh(child):
    """Return the hyperbolic tangent of ``child`` as an expression-tree node."""
    node = simplified_function(Tanh, child)
    return node