# Source code for pybamm.expression_tree.operations.evaluate

#
# Write a symbol to python
#
import pybamm

# need numpy imported for code generated in EvaluatorPython
import numpy as np  # noqa: F401
import scipy.sparse  # noqa: F401
from collections import OrderedDict


def id_to_python_variable(symbol_id, constant=False):
    """
    Build the python variable name that find_symbols and to_python use for a node.

    The name is derived from the node's id so that it is unique per node: the id is
    zero-padded to a width of five characters and prefixed with ``self.const_``
    when `constant` is true, or ``self.var_`` otherwise.
    """
    prefix = "self.const_" if constant else "self.var_"
    name = prefix + format(symbol_id, "05d")

    # A negative id is rendered with a leading "-", which is not allowed in a python
    # identifier, so it is swapped for the letter "m"
    return name.replace("-", "m")


def find_symbols(symbol, constant_symbols, variable_symbols):
    """
    This function converts an expression tree to a dictionary of node id's and strings
    specifying valid python code to calculate that nodes value, given y and t.

    The function distinguishes between nodes that represent constant nodes in the tree
    (e.g. a pybamm.Matrix), and those that are variable (e.g. subtrees that contain
    pybamm.StateVector). The former are put in `constant_symbols`, the latter in
    `variable_symbols`

    Note that it is important that the arguments `constant_symbols` and
    `variable_symbols` be an *ordered* dict, since the final ordering of the code lines
    are important for the calculations. A dict is specified rather than a list so that
    identical subtrees (which give identical id's) are not recalculated in the code

    Parameters
    ----------
    symbol : :class:`pybamm.Symbol`
        The symbol or expression tree to convert

    constant_symbols: collections.OrderedDict
        The output dictionary, filled in place. Maps the id of each constant node to
        its evaluated value, and the id of each pybamm.Function node to its function
        handle

    variable_symbols: collections.OrderedDict
        The output dictionary, filled in place, of variable (with y or t) symbol ids
        to lines of code

    Raises
    ------
    NotImplementedError
        If the tree contains a non-constant symbol type (or a Concatenation subtype)
        that this function does not know how to convert

    """
    # a constant subtree is evaluated once, up front, and never descended into
    if symbol.is_constant():
        constant_symbols[symbol.id] = symbol.evaluate()
        return

    # process children recursively, so that their lines of code are inserted into the
    # ordered dicts before the line for this symbol
    for child in symbol.children:
        find_symbols(child, constant_symbols, variable_symbols)

    # calculate the variable names that will hold the result of calculating the
    # children variables
    children_vars = [
        id_to_python_variable(child.id, child.is_constant())
        for child in symbol.children
    ]

    if isinstance(symbol, pybamm.BinaryOperator):
        # Multiplication and Division need special handling for scipy sparse matrices
        # TODO: we can pass through a dummy y and t to get the type and then hardcode
        # the right line, avoiding these checks
        if isinstance(symbol, pybamm.Multiplication):
            symbol_str = (
                "scipy.sparse.csr_matrix({0}.multiply({1})) "
                "if scipy.sparse.issparse({0}) else "
                "scipy.sparse.csr_matrix({1}.multiply({0})) "
                "if scipy.sparse.issparse({1}) else "
                "{0} * {1}".format(children_vars[0], children_vars[1])
            )
        elif isinstance(symbol, pybamm.Division):
            symbol_str = (
                "scipy.sparse.csr_matrix({0}.multiply(1/{1})) "
                "if scipy.sparse.issparse({0}) else "
                "{0} / {1}".format(children_vars[0], children_vars[1])
            )
        elif isinstance(symbol, pybamm.Inner):
            symbol_str = (
                "{0}.multiply({1}) "
                "if scipy.sparse.issparse({0}) else "
                "{1}.multiply({0}) "
                "if scipy.sparse.issparse({1}) else "
                "{0} * {1}".format(children_vars[0], children_vars[1])
            )
        elif isinstance(symbol, pybamm.Minimum):
            symbol_str = "np.minimum({},{})".format(children_vars[0], children_vars[1])
        elif isinstance(symbol, pybamm.Maximum):
            symbol_str = "np.maximum({},{})".format(children_vars[0], children_vars[1])
        else:
            # any other binary operator: its name is used directly as the infix
            # operator in the generated code
            symbol_str = children_vars[0] + " " + symbol.name + " " + children_vars[1]

    elif isinstance(symbol, pybamm.UnaryOperator):
        # Index has a different syntax than other univariate operations
        if isinstance(symbol, pybamm.Index):
            symbol_str = "{}[{}:{}]".format(
                children_vars[0], symbol.slice.start, symbol.slice.stop
            )
        else:
            # the operator name is placed directly in front of the child variable
            symbol_str = symbol.name + children_vars[0]

    # For a Function we create two lines of code, one in constant_symbols that
    # contains the function handle, the other in variable_symbols that calls that
    # function on the children variables
    elif isinstance(symbol, pybamm.Function):
        constant_symbols[symbol.id] = symbol.function
        funct_var = id_to_python_variable(symbol.id, True)
        symbol_str = "{}({})".format(funct_var, ", ".join(children_vars))

    elif isinstance(symbol, pybamm.Concatenation):

        # don't bother to concatenate if there is only a single child
        if isinstance(symbol, pybamm.NumpyConcatenation):
            if len(children_vars) > 1:
                symbol_str = "np.concatenate(({}))".format(",".join(children_vars))
            else:
                symbol_str = ",".join(children_vars)

        elif isinstance(symbol, pybamm.SparseStack):
            if len(children_vars) > 1:
                symbol_str = "scipy.sparse.vstack(({}))".format(",".join(children_vars))
            else:
                symbol_str = ",".join(children_vars)

        # DomainConcatenation specifies a particular ordering for the concatenation,
        # which we must follow
        elif isinstance(symbol, pybamm.DomainConcatenation):
            all_child_vectors = []
            for i in range(symbol.secondary_dimensions_npts):
                # Both lists are rebuilt for every secondary point, so that each
                # child vector is paired with (and sorted by) its own start position
                # in the concatenated result rather than one left over from an
                # earlier secondary point
                slice_starts = []
                child_vectors = []
                for child_var, slices in zip(children_vars, symbol._children_slices):
                    for child_dom, child_slice in slices.items():
                        slice_starts.append(symbol._slices[child_dom][i].start)
                        child_vectors.append(
                            "{}[{}:{}]".format(
                                child_var, child_slice[i].start, child_slice[i].stop
                            )
                        )
                all_child_vectors.extend(
                    [v for _, v in sorted(zip(slice_starts, child_vectors))]
                )
            if len(children_vars) > 1 or symbol.secondary_dimensions_npts > 1:
                symbol_str = "np.concatenate(({}))".format(",".join(all_child_vectors))
            else:
                symbol_str = ",".join(children_vars)
        else:
            raise NotImplementedError

    # Note: we assume that y is being passed as a column vector
    elif isinstance(symbol, pybamm.StateVector):
        # trim y to the length of the evaluation array, then index it with that array
        symbol_str = "y[:{}][{}]".format(
            len(symbol.evaluation_array), symbol.evaluation_array
        )

    elif isinstance(symbol, pybamm.Time):
        symbol_str = "t"

    elif isinstance(symbol, pybamm.InputParameter):
        symbol_str = "inputs['{}']".format(symbol.name)

    else:
        raise NotImplementedError(
            "Not implemented for a symbol of type '{}'".format(type(symbol))
        )

    variable_symbols[symbol.id] = symbol_str


def to_python(symbol, debug=False):
    """
    This function converts an expression tree into a dict of constant input values, and
    valid python code that acts like the tree's :func:`pybamm.Symbol.evaluate` function

    Parameters
    ----------
    symbol : :class:`pybamm.Symbol`
        The symbol to convert to python code

    debug : bool
        If set to True, every generated statement is additionally preceded by a
        ``print`` of the statement itself and followed by a ``print`` of the type and
        ``.shape`` of the value it produced

    Returns
    -------
    collections.OrderedDict:
        dict mapping node id to a constant value. Represents all the constant nodes in
        the expression tree
    str:
        valid python code that will evaluate all the variable nodes in the tree, one
        assignment per line, in the order in which the nodes were found

    """

    constant_values = OrderedDict()
    variable_symbols = OrderedDict()
    find_symbols(symbol, constant_values, variable_symbols)

    variable_lines = []
    for symbol_id, symbol_line in variable_symbols.items():
        var_name = id_to_python_variable(symbol_id, False)
        line = "{} = {}".format(var_name, symbol_line)
        if debug:
            # {!r} turns the statement into a correctly quoted string literal. A
            # hard-coded pair of single quotes would yield invalid code whenever the
            # statement itself contains a single quote (e.g. "inputs['name']")
            line = "print({0!r}); {0}; print(type({1}),{1}.shape)".format(
                line, var_name
            )
        variable_lines.append(line)

    return constant_values, "\n".join(variable_lines)


class EvaluatorPython:
    """
    Converts a pybamm expression tree into pure python code that will calculate the
    result of calling `evaluate(t, y)` on the given expression tree.

    Parameters
    ----------

    symbol : :class:`pybamm.Symbol`
        The symbol to convert to python code
    """

    def __init__(self, symbol):
        constants, self._variable_function = pybamm.to_python(symbol, debug=False)

        # store all the constant symbols in the tree as internal variables of this
        # object. The generated code refers to them as "self.const_<id>", so the
        # "self." part is stripped to obtain the attribute name
        for symbol_id, value in constants.items():
            setattr(
                self, id_to_python_variable(symbol_id, True).replace("self.", ""), value
            )

        # calculate the final variable that will output the result of calling `evaluate`
        # on `symbol`
        self._result_var = id_to_python_variable(symbol.id, symbol.is_constant())

        # compile the generated python code
        self._variable_compiled = compile(
            self._variable_function, self._result_var, "exec"
        )

        # compile the line that will return the output of `evaluate`
        self._return_compiled = compile(
            self._result_var, "return" + self._result_var, "eval"
        )

    def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
        """
        Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`

        `y_dot` is accepted but not used by this method. If `known_evals` is not None
        it is returned, unchanged, alongside the result as ``(result, known_evals)``;
        otherwise only the result is returned.
        """
        # generated code assumes y is a column vector
        if y is not None and y.ndim == 1:
            y = y.reshape(-1, 1)

        # execute code. The generated statements refer to `self`, `t`, `y` and
        # `inputs` by name and are run in this method's scope, so these names must
        # not be renamed
        exec(self._variable_compiled)
        result = eval(self._return_compiled)

        # don't need known_evals, but need to reproduce Symbol.evaluate signature
        if known_evals is not None:
            return result, known_evals
        return result