# Source code for pybamm.solvers.scikits_dae_solver

#
# Solver class using the scikits.odes DAE solver
#
import casadi
import pybamm

import numpy as np
import importlib
import scipy.sparse as sparse

# `import importlib` alone does not guarantee the `importlib.util` submodule is
# loaded, so import it explicitly before using it.
import importlib.util

# scikits.odes is an optional dependency: probe for it without failing at
# import time. The parent package is checked first because find_spec on a
# dotted name imports the parent and raises ModuleNotFoundError if it is absent.
# After this block, `scikits_odes_spec` is None when scikits.odes is missing,
# and `scikits_odes` is bound only when it is available.
scikits_odes_spec = importlib.util.find_spec("scikits")
if scikits_odes_spec is not None:
    scikits_odes_spec = importlib.util.find_spec("scikits.odes")
    if scikits_odes_spec is not None:
        # import_module registers the package in sys.modules, so this module and
        # any other importer (including the package's own submodules) share a
        # single module object instead of executing it twice.
        scikits_odes = importlib.import_module("scikits.odes")


class ScikitsDaeSolver(pybamm.BaseSolver):
    """Solve a discretised model, using scikits.odes.

    Parameters
    ----------
    method : str, optional
        The scikits.odes DAE method to use (default is "ida").
    rtol : float, optional
        The relative tolerance for the solver (default is 1e-6).
    atol : float, optional
        The absolute tolerance for the solver (default is 1e-6).
    root_method : str or pybamm algebraic solver class, optional
        The method to use to find initial conditions (for DAE solvers).
        If a solver class, must be an algebraic solver class.
        If "casadi",
        the solver uses casadi's Newton rootfinding algorithm to find initial
        conditions. Otherwise, the solver uses 'scipy.optimize.root' with method
        specified by 'root_method' (e.g. "lm", "hybr", ...)
    root_tol : float, optional
        The tolerance for the initial-condition solver (default is 1e-6).
    extrap_tol : float, optional
        The tolerance to assert whether extrapolation occurs or not
        (default is None; the value is forwarded unchanged to
        :class:`pybamm.BaseSolver`).
    extra_options : dict, optional
        Any options to pass to the solver.
        Please consult `scikits.odes documentation
        <https://bmcage.github.io/odes/dev/index.html>`_ for details.
        Some common keys:

        - 'max_steps': maximum (int) number of steps the solver can take

        The keys 'old_api', 'rtol' and 'atol' are always overridden by this
        solver. 'jacfn', 'rootfn' and 'nr_rootfns' are overridden whenever the
        model provides a Jacobian / terminating events.

    Raises
    ------
    ImportError
        If scikits.odes is not installed.
    """

    def __init__(
        self,
        method="ida",
        rtol=1e-6,
        atol=1e-6,
        root_method="casadi",
        root_tol=1e-6,
        extrap_tol=None,
        extra_options=None,
    ):
        if scikits_odes_spec is None:
            raise ImportError("scikits.odes is not installed")

        super().__init__(method, rtol, atol, root_method, root_tol, extrap_tol)
        self.name = "Scikits DAE solver ({})".format(method)

        # None (or an empty mapping) means "no extra solver options"
        self.extra_options = extra_options or {}

        pybamm.citations.register("Malengier2018")
        pybamm.citations.register("Hindmarsh2000")
        pybamm.citations.register("Hindmarsh2005")

    def _integrate(self, model, t_eval, inputs_dict=None):
        """
        Solve a model defined by dydt with initial conditions y0.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            The model whose solution to calculate.
        t_eval : numeric type
            The times at which to compute the solution
        inputs_dict : dict, optional
            Any input parameters to pass to the model when solving

        Returns
        -------
        :class:`pybamm.Solution`
            The solution. Its ``integration_time`` attribute holds the time
            measured around the scikits.odes ``solve`` call.

        Raises
        ------
        :class:`pybamm.SolverError`
            If scikits.odes reports a flag other than 0 (solved for all of
            ``t_eval``) or 2 (stopped at an event root).
        """
        inputs_dict = inputs_dict or {}
        if model.convert_to_format == "casadi":
            # casadi-compiled functions take all inputs as one stacked vector
            inputs = casadi.vertcat(*inputs_dict.values())
        else:
            inputs = inputs_dict

        y0 = model.y0
        if isinstance(y0, casadi.DM):
            y0 = y0.full()
        y0 = y0.flatten()

        rhs_algebraic_eval = model.rhs_algebraic_eval
        events = model.terminate_events_eval
        jacobian = model.jac_rhs_algebraic_eval

        if model.convert_to_format == "jax":
            mass_matrix = model.mass_matrix.entries.toarray()
        else:
            mass_matrix = model.mass_matrix.entries

        # Residual F(t, y, ydot) = f(t, y) - M @ ydot, written in place into the
        # buffer supplied by scikits.odes.
        if model.convert_to_format == "casadi":

            def eqsres(t, y, ydot, return_residuals):
                return_residuals[:] = (
                    rhs_algebraic_eval(t, y, inputs).full().flatten()
                    - mass_matrix @ ydot
                )

        else:

            def eqsres(t, y, ydot, return_residuals):
                return_residuals[:] = (
                    rhs_algebraic_eval(t, y, inputs).flatten() - mass_matrix @ ydot
                )

        # Solver-controlled settings take precedence over user extra_options
        extra_options = {
            **self.extra_options,
            "old_api": False,
            "rtol": self.rtol,
            "atol": self.atol,
        }

        if jacobian:
            # Probe once at t0 to decide whether the Jacobian comes back sparse
            jac_y0_t0 = jacobian(t_eval[0], y0, inputs)
            # IDA requests dF/dy + cj * dF/d(ydot); with F = f - M ydot this is
            # df/dy - cj * M. J is filled in place.
            if sparse.issparse(jac_y0_t0):

                def jacfn(t, y, ydot, residuals, cj, J):
                    jac_eval = jacobian(t, y, inputs) - cj * mass_matrix
                    J[:][:] = jac_eval.toarray()

            else:

                def jacfn(t, y, ydot, residuals, cj, J):
                    jac_eval = jacobian(t, y, inputs) - cj * mass_matrix
                    J[:][:] = jac_eval

            extra_options.update({"jacfn": jacfn})

        if events:

            def rootfn(t, y, ydot, return_root):
                return_root[:] = [float(event(t, y, inputs)) for event in events]

            extra_options.update({"rootfn": rootfn, "nr_rootfns": len(events)})

        # solver works with ydot0 set to zero
        ydot0 = np.zeros_like(y0)

        # set up and solve
        dae_solver = scikits_odes.dae(self.method, eqsres, **extra_options)
        timer = pybamm.Timer()
        scikits_sol = dae_solver.solve(t_eval, y0, ydot0)
        integration_time = timer.time()

        # flag 0 = solved for all t_eval; flag 2 = found root(s)
        if scikits_sol.flag not in (0, 2):
            raise pybamm.SolverError(scikits_sol.message)
        termination = "final time" if scikits_sol.flag == 0 else "event"

        roots = scikits_sol.roots
        # Keep "no event" as None for both time and state, rather than letting
        # np.transpose(None) turn it into a 0-d object array.
        y_event = None if roots.y is None else np.transpose(roots.y)

        # y is transposed to match scipy's (n_states, n_times) interface
        solution = pybamm.Solution(
            scikits_sol.values.t,
            np.transpose(scikits_sol.values.y),
            model,
            inputs_dict,
            roots.t,
            y_event,
            termination,
        )
        solution.integration_time = integration_time
        return solution