# Source code for pybamm.solvers.idaklu_solver

# mypy: ignore-errors
import logging
import math
import numbers
import warnings
from enum import IntEnum

import casadi
import numpy as np
import pybammsolvers.idaklu as idaklu
from scipy.sparse.linalg import spsolve

import pybamm
from pybamm.codegen.compilation import aot_compile

# Sentinel used as the default for ``root_method`` in ``IDAKLUSolver.__init__``
# so the constructor can tell "argument omitted" apart from an explicit ``None``.
_UNSET = object()


def _flatten_inputs(inputs_dict):
    """Flatten ``{name: value}`` into a 1-D float array in dict-key order."""
    if not inputs_dict:
        return np.zeros(0)
    return np.concatenate([np.asarray(v).reshape(-1) for v in inputs_dict.values()])


# Mirrors SUNDIALS ``IDA_ROOT_RETURN`` in ``sundials/include/ida/ida.h``.
# Returned by ``IDASolve`` (and surfaced via ``Solution.flag``) when the
# integrator has located one or more root function zeros.
# NOTE(review): not referenced in the part of the file visible here;
# presumably compared against the solver flag during post-processing — confirm.
_IDA_ROOT_RETURN = 2

# Function entries staged in ``_setup`` while building the C++ solver
# group; popped after construction (C++ owns them via shared_ptr).
# The order matters: ``set_up`` zips this tuple (``strict=True``) against a
# positional tuple of casadi Functions, so entry ``i`` here names function ``i``
# there.
_SETUP_FCN_KEYS = (
    "rhs_algebraic",
    "jac_times_cjmass",
    "jac_rhs_algebraic_action",
    "mass_action",
    "sensfn",
    "rootfn",
    "alg_res",
    "alg_jac",
)
# ``_setup`` entries that hold *lists* of serialised output-variable functions;
# popped from ``_setup`` together with the keys above once the solver exists.
_SETUP_FCN_LIST_KEYS = (
    "var_idaklu_fcns",
    "dvar_dy_idaklu_fcns",
    "dvar_dp_idaklu_fcns",
)

# Attributes holding casadi.Function graphs (or the C++ solver) that are
# rebuilt from the model on the next solve(). Dropped in __getstate__ so a
# pickled solver doesn't carry these heavy, non-portable objects.
# ``_setup`` is the dict built by ``set_up``; it contains the C++ solver handle
# under ``"solver"`` and the Python-callable ``"rootfn_casadi"``.
_REBUILDABLE_STATE_KEYS = (
    "_setup",
    "_model_set_up",
    "computed_var_fcns",
    "computed_dvar_dy_fcns",
    "computed_dvar_dp_fcns",
    "_time_integral_vars",
)


[docs] class IDAKLUSolver(pybamm.BaseSolver): """ Solve a discretised model, using sundials with the KLU sparse linear solver. Parameters ---------- rtol : float, optional The relative tolerance for the solver (default is 1e-4). atol : float, optional The absolute tolerance for the solver (default is 1e-6). root_method : str or pybamm algebraic solver class, optional The method to use to find initial conditions (for DAE solvers). Default is None, which uses a custom Newton solver for consistent initial conditions (recommended). If "casadi", the solver uses casadi's Newton rootfinding algorithm as a fallback. Otherwise, the solver uses 'scipy.optimize.root' with method specified by 'root_method' (e.g. "lm", "hybr", ...) root_tol : float, optional The tolerance for the initial-condition solver (default is 1e-6). extrap_tol : float, optional The tolerance to assert whether extrapolation occurs or not (default is 0). on_extrapolation : str, optional What to do if the solver is extrapolating. Options are "warn", "error", or "ignore". Default is "warn". on_failure : str, optional What to do if a solver error flag occurs. Options are "warn", "error", or "ignore". Default is "error". output_variables : list[str], optional List of variables to calculate and return. If none are specified then the complete state vector is returned (can be very large) (default is []) store_first_last : bool, optional If True, only the first and last sample of each integration window are stored (one experiment step in :meth:`pybamm.Simulation.solve`, or the full ``[t_eval[0], t_eval[-1]]`` window in :meth:`solve`). Intended for memory-light long experiments whose post-processing only reads per-step first/last values. 
Note: with this flag on, IDAKLU's Hermite interpolation is disabled (see :attr:`options["hermite_interpolation"]`) and any query at a non-endpoint time within a step falls back to linear interpolation across the whole step, so this flag is **not** appropriate when post-processing queries an intra-step time. Default is False. options: dict, optional Additional options to pass to the solver, by default: .. code-block:: python options = { # Print statistics of the solver after every solve "print_stats": False, # If ``True``, ahead-of-time compile each casadi ``Function`` # to a shared library via the system C compiler (see # :mod:`pybamm.codegen.compilation`). If ``False`` (default) # the casadi in-process virtual machine is used. "compile": False, # Number of threads available for OpenMP (must be greater than or equal to `num_solvers`) "num_threads": 1, # Number of solvers to use in parallel (for solving multiple sets of input parameters in parallel) "num_solvers": num_threads, ## Linear solver interface # name of sundials linear solver to use; options are: "SUNLinSol_KLU", # "SUNLinSol_Dense", "SUNLinSol_Band", "SUNLinSol_SPBCGS", # "SUNLinSol_SPFGMR", "SUNLinSol_SPGMR", "SUNLinSol_SPTFQMR", "linear_solver": "SUNLinSol_KLU", # Jacobian form, can be "none", "dense", # "banded", "sparse", "matrix-free" "jacobian": "sparse", # Preconditioner for iterative solvers, can be "none", "BBDP" "preconditioner": "BBDP", # For iterative linear solver preconditioner, bandwidth of # approximate jacobian "precon_half_bandwidth": 5, # For iterative linear solver preconditioner, bandwidth of # approximate jacobian that is kept "precon_half_bandwidth_keep": 5, # For iterative linear solvers, max number of iterations "linsol_max_iterations": 5, # Ratio between linear and nonlinear tolerances "epsilon_linear_tolerance": 0.05, # Increment factor used in DQ Jacobian-vector product approximation "increment_factor": 1.0, # Enable or disable linear solution scaling
"linear_solution_scaling": True, # Silence Sundials errors during solve "silence_sundials_errors": False, ## Main solver # Maximum order of the linear multistep method "max_order_bdf": 5, # Maximum number of steps to be taken by the solver in its attempt to # reach the next output time. # Note: this value differs from the IDA default of 500 "max_num_steps": 100000, # Initial step size. The solver default is used if this is left at 0.0 "dt_init": 0.0, # Minimum absolute step size. The solver default is used if this is # left at 0.0 "dt_min": 0.0, # Maximum absolute step size. The solver default is used if this is # left at 0.0 "dt_max": 0.0, # Maximum number of error test failures in attempting one step "max_error_test_failures": 10, # Maximum number of nonlinear solver iterations at one step # Note: this value differs from the IDA default of 4 "max_nonlinear_iterations": 40, # Maximum number of nonlinear solver convergence failures at one step # Note: this value differs from the IDA default of 10 "max_convergence_failures": 100, # Safety factor in the nonlinear convergence test "nonlinear_convergence_coefficient": 0.33, # Suppress algebraic variables from error test "suppress_algebraic_error": False, # Store Hermite interpolation data for the solution. # Note: this option is always disabled if output_variables are given # or if t_interp values are specified "hermite_interpolation": True, # Setting hermite_reduction_factor > 1.0 compresses the solution size # by introducing a small amount of error to the Hermite spline # interpolant. A value of `2.0` roughly corresponds to a maximum 2x # increase in error (practically the error is much smaller), while # reducing the number of saved states by around 5-6x. This option is # only active if `hermite_interpolation` is True and sensitivities # are disabled.
"hermite_reduction_factor": 1.0, ## Initial conditions calculation # Positive constant in the Newton iteration convergence test within the # initial condition calculation "nonlinear_convergence_coefficient_ic": 0.0033, # Maximum number of steps allowed when `init_all_y_ic = False` # Note: this value differs from the IDA default of 5 "max_num_steps_ic": 50, # Maximum number of the approximate Jacobian or preconditioner evaluations # allowed when the Newton iteration appears to be slowly converging # Note: this value differs from the IDA default of 4 "max_num_jacobians_ic": 40, # Maximum number of Newton iterations allowed in any one attempt to solve # the initial conditions calculation problem # Note: this value differs from the IDA default of 10 "max_num_iterations_ic": 100, # Maximum number of linesearch backtracks allowed in any Newton iteration, # when solving the initial conditions calculation problem "max_linesearch_backtracks_ic": 5, # Turn off linesearch "linesearch_off_ic": False, # How to calculate the initial conditions. # "True": calculate all y0 given ydot0 # "False": calculate y_alg0 and ydot_diff0 given y_diff0 "init_all_y_ic": False, # Internally calculate consistent initial conditions "calc_ic": True, ## Newton IC solver # "auto": use dedicated sub-block solver when possible. # This can result in a potentially smaller system of only the # algebraic variables. Requires a direct linear solver and # a standard form DAE. # "full": always use IDA's full-system linear solve. This uses the full # system of equations and can handle non-standard form DAEs and # all classes of linear solvers. "newton_mode": "auto", ## Early termination # Maximum number of consecutive steps allowed without advancing # the solution time by at least `t_no_progress` seconds. # If set to 0, this feature is disabled. "num_steps_no_progress": 0, # Minimum required time advancement (in seconds) after # `num_steps_no_progress` consecutive steps.
# If set to 0.0, this feature is disabled. "t_no_progress": 0.0, } Note: These options only have an effect if model.convert_to_format == 'casadi' """
class StateID(IntEnum):
    """Integer tag marking each state entry as algebraic or differential.

    ``set_up`` builds the ``ids`` array from these values (the first
    ``model.len_rhs`` entries are ``DIFFERENTIAL``, the following
    ``model.len_alg`` entries are ``ALGEBRAIC``) and passes it to the solver
    group as ``rhs_alg_id``.
    """

    ALGEBRAIC = 0
    DIFFERENTIAL = 1
def __init__(
    self,
    rtol=1e-4,
    atol=1e-6,
    root_method=_UNSET,
    root_tol=1e-6,
    extrap_tol=None,
    on_extrapolation=None,
    output_variables=None,
    on_failure=None,
    options=None,
    store_first_last=False,
):
    """Create the solver; see the class docstring for every parameter.

    ``output_variables`` is stored before the options are combined because
    option validation (``_check_options``) reads ``self.output_variables``.
    """
    self.output_variables = [] if output_variables is None else output_variables
    self._options = self._combine_options(options)

    # By default, we use an internal nonlinear solver within pybammsolvers
    # to compute the initial conditions. As a fallback, we use python bindings
    # for the same solver
    if root_method is _UNSET:
        root_method = None if self._internal_initialisation else "nonlinear_solver"

    super().__init__(
        method="ida",
        rtol=rtol,
        atol=atol,
        root_method=root_method,
        root_tol=root_tol,
        extrap_tol=extrap_tol,
        output_variables=output_variables,
        on_extrapolation=on_extrapolation,
        on_failure=on_failure,
        store_first_last=store_first_last,
    )
    self.name = "IDA KLU solver"
    self._supports_interp = True
    self._supports_t_eval_discontinuities = True

    pybamm.citations.register("Hindmarsh2000")
    pybamm.citations.register("Hindmarsh2005")


def _combine_options(self, user_options: dict | None) -> dict:
    """Merge user options over the defaults and validate the result.

    Parameters
    ----------
    user_options : dict or None
        Options supplied by the caller; ``None`` or an empty dict selects
        the defaults unchanged.

    Returns
    -------
    dict
        The default options overridden by ``user_options``. ``num_solvers``
        defaults to the user's ``num_threads`` (or 1 if that is not given).

    Raises
    ------
    pybamm.SolverError
        If the merged options fail ``_check_options``. The pure-default
        path returns without validation.
    """
    user_options = user_options or {}
    num_solvers = user_options.get("num_threads", 1)

    default_options = {
        "print_stats": False,
        "compile": False,
        "jacobian": "sparse",
        "preconditioner": "BBDP",
        "precon_half_bandwidth": 5,
        "precon_half_bandwidth_keep": 5,
        "num_threads": 1,
        "num_solvers": num_solvers,
        "linear_solver": "SUNLinSol_KLU",
        "linsol_max_iterations": 5,
        "epsilon_linear_tolerance": 0.05,
        "increment_factor": 1.0,
        "linear_solution_scaling": True,
        "silence_sundials_errors": False,
        "max_order_bdf": 5,
        "max_num_steps": 100000,
        "dt_init": 0.0,
        "dt_min": 0.0,
        "dt_max": 0.0,
        "max_error_test_failures": 10,
        "max_nonlinear_iterations": 40,
        "max_convergence_failures": 100,
        "nonlinear_convergence_coefficient": 0.33,
        "suppress_algebraic_error": False,
        "hermite_interpolation": True,
        "hermite_reduction_factor": 1.0,
        "nonlinear_convergence_coefficient_ic": 0.0033,
        "max_num_steps_ic": 50,
        "max_num_jacobians_ic": 40,
        "max_num_iterations_ic": 100,
        "max_linesearch_backtracks_ic": 5,
        "linesearch_off_ic": False,
        "init_all_y_ic": False,
        "calc_ic": True,
        "newton_step_tol": 1e-4,
        "newton_mode": "auto",
        "num_steps_no_progress": 0,
        "t_no_progress": 0.0,
    }

    if not user_options:
        return default_options

    options = default_options | user_options
    self._check_options(options)
    return options


def _check_options(self, options: dict):
    """Validate option combinations that cannot be honoured.

    Raises
    ------
    pybamm.SolverError
        If ``hermite_reduction_factor`` is above 1.0 together with
        ``output_variables`` or with ``hermite_interpolation`` disabled, if
        it is below 1.0, or if ``compile`` is not a ``bool``.
    """
    hermite_reduction_factor = options["hermite_reduction_factor"]
    if hermite_reduction_factor > 1.0:
        if self.output_variables:
            raise pybamm.SolverError(
                "hermite_reduction_factor cannot be used with "
                "output_variables. Both are memory-saving options "
                "that are mutually exclusive."
            )
        if not options["hermite_interpolation"]:
            raise pybamm.SolverError(
                "hermite_reduction_factor requires "
                "hermite_interpolation to be enabled."
            )
    elif hermite_reduction_factor < 1.0:
        raise pybamm.SolverError("hermite_reduction_factor must be >= 1.0.")

    if not isinstance(options["compile"], bool):
        raise pybamm.SolverError("compile must be a bool")


def _check_atol_type(self, atol, model):
    """Return the absolute tolerance as a per-state array.

    Parameters
    ----------
    atol : real scalar or numpy.ndarray
        A scalar is broadcast to ``model.len_rhs_and_alg`` entries; an array
        is returned unchanged.
    model
        Only ``model.len_rhs_and_alg`` is read.

    Returns
    -------
    numpy.ndarray

    Raises
    ------
    pybamm.SolverError
        If ``atol`` is neither a real scalar nor a numpy array.
    """
    if isinstance(atol, np.ndarray):
        return atol
    # Accept any real scalar (Python int, NumPy float32, ...), not only
    # ``float``. ``bool`` is a ``numbers.Real`` too but is rejected, as it was
    # before, since a boolean tolerance is almost certainly a mistake.
    if isinstance(atol, numbers.Real) and not isinstance(atol, bool):
        # float() keeps the array float64 even for integer input.
        return np.full(model.len_rhs_and_alg, float(atol))
    raise pybamm.SolverError(
        "Absolute tolerances must be a numpy array or float"
    )
def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
    """Prepare ``model`` for IDAKLU and build the C++ solver group.

    Forces the model to the CasADi format, runs the base-class set-up, then
    assembles the casadi Functions (residual, Jacobian minus ``cj`` times the
    mass matrix, Jacobian action, mass action, sensitivities, events, and the
    algebraic sub-block pair), serialises them for ``idaklu`` and stores the
    resulting solver in ``self._setup["solver"]``.

    Parameters
    ----------
    model
        The discretised model. ``model.convert_to_format`` is overwritten
        with ``"casadi"`` if it has another value.
    inputs : dict or list of dict, optional
        Input parameters. When a list is given only its first dict is used
        for set-up.
    t_eval : optional
        Forwarded unchanged to the base-class ``set_up``.
    ics_only : bool, optional
        Forwarded to the base-class ``set_up``. When true this method returns
        right after the base set-up and the ``y0`` extraction, without
        building the solver.

    Returns
    -------
    Whatever the base-class ``set_up`` returned.

    Raises
    ------
    pybamm.SolverError
        If ``model.use_jacobian`` is false, or from ``_check_atol_type``.
    """
    if model.convert_to_format != "casadi":
        pybamm.logger.warning(
            f"Converting {model.name} to CasADi for solving with IDAKLUSolver"
        )
        model.convert_to_format = "casadi"

    base_set_up_return = super().set_up(model, inputs, t_eval, ics_only)

    if isinstance(inputs, list):
        # for setup, just use first input dict
        inputs_dict = inputs[0]
    else:
        inputs_dict = inputs or {}

    # stack inputs into a single column vector
    if inputs_dict:
        arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()]
        stacked_inputs = np.vstack(arrays_to_stack)
    else:
        # NOTE(review): this array has shape (1, 0), so len(stacked_inputs) is 1
        # (not 0) when there are no inputs; that value feeds both
        # ``has_sens`` and ``_setup["inputs"]`` below — confirm this is intended.
        stacked_inputs = np.array([[]])

    y0 = model.y0_list[0]
    if isinstance(y0, casadi.DM):
        y0 = y0.full()
    y0 = y0.flatten()

    if ics_only:
        return base_set_up_return

    # NOTE(review): convert_to_format was set to "casadi" at the top of this
    # method, so this branch looks unreachable unless the base set_up changes
    # the attribute — confirm before removing.
    if model.convert_to_format != "casadi":
        msg = "The python-idaklu solver has been deprecated."
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        raise pybamm.SolverError(
            f"Unsupported option for convert_to_format={model.convert_to_format}"
        )

    mass_matrix = model.mass_matrix.entries

    # construct residuals function by binding inputs
    # TODO: do we need densify here?
    rhs_algebraic = model.rhs_algebraic_eval

    if not model.use_jacobian:
        raise pybamm.SolverError("KLU requires the Jacobian")

    # need to provide jacobian_rhs_alg - cj * mass_matrix
    t_casadi = casadi.MX.sym("t")
    y_casadi = casadi.MX.sym("y", model.len_rhs_and_alg)
    cj_casadi = casadi.MX.sym("cj")
    # One symbol per input: scalar for numbers, otherwise a vector whose length
    # is the value's first dimension (so non-numeric values must expose .shape).
    p_casadi = {}
    for name, value in inputs_dict.items():
        if isinstance(value, numbers.Number):
            p_casadi[name] = casadi.MX.sym(name)
        else:
            p_casadi[name] = casadi.MX.sym(name, value.shape[0])
    p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()])

    jac_times_cjmass_expression = (
        model.jac_rhs_algebraic_eval(t_casadi, y_casadi, p_casadi_stacked)
        - cj_casadi * mass_matrix
    )
    if self._options["jacobian"] == "dense":
        jac_times_cjmass_expression = casadi.densify(jac_times_cjmass_expression)
    jac_times_cjmass = casadi.Function(
        "jac_times_cjmass",
        [t_casadi, y_casadi, p_casadi_stacked, cj_casadi],
        [jac_times_cjmass_expression],
    )

    # Sparsity pattern of the Jacobian: bandwidths, number of non-zeros, and
    # the column-pointer / row-index arrays handed to the solver.
    jac_times_cjmass_sparsity = jac_times_cjmass.sparsity_out(0)
    jac_bw_lower = jac_times_cjmass_sparsity.bw_lower()
    jac_bw_upper = jac_times_cjmass_sparsity.bw_upper()
    jac_times_cjmass_nnz = jac_times_cjmass_sparsity.nnz()
    jac_times_cjmass_colptrs = np.array(
        jac_times_cjmass_sparsity.colind(), dtype=np.int64
    )
    jac_times_cjmass_rowvals = np.array(
        jac_times_cjmass_sparsity.row(), dtype=np.int64
    )

    v_casadi = casadi.MX.sym("v", model.len_rhs_and_alg)

    jac_rhs_algebraic_action = model.jac_rhs_algebraic_action_eval

    if model.is_standard_form_dae:
        # index v_casadi at the first model.len_rhs entries
        # concatenate it with zeros at the end
        mass_action_expression = casadi.vertcat(
            v_casadi[: model.len_rhs],
            np.zeros(model.len_alg),
        )
    else:
        mass_action_expression = casadi.densify(casadi.DM(mass_matrix) @ v_casadi)

    # also need the action of the mass matrix on a vector
    mass_action = casadi.Function(
        "mass_action", [v_casadi], [mass_action_expression]
    )

    num_of_events = len(model.terminate_events_eval)

    # rootfn needs to return an array of length num_of_events
    rootfn = casadi.Function(
        "rootfn",
        [t_casadi, y_casadi, p_casadi_stacked],
        [
            casadi.vertcat(
                *[
                    event(t_casadi, y_casadi, p_casadi_stacked)
                    for event in model.terminate_events_eval
                ]
            )
        ],
    )

    # get ids of rhs and algebraic variables
    ids = np.concatenate(
        (
            np.full(model.len_rhs, self.StateID.DIFFERENTIAL, dtype=np.int64),
            np.full(model.len_alg, self.StateID.ALGEBRAIC, dtype=np.int64),
        )
    )

    if model.jacp_rhs_algebraic_eval is not None:
        sensitivity_names = model.calculate_sensitivities
        # NOTE(review): the ``else`` arm below pairs with the seemingly
        # unreachable non-casadi branch flagged above.
        if model.convert_to_format == "casadi":
            number_of_sensitivity_parameters = model.jacp_rhs_algebraic_eval.n_out()
        else:
            number_of_sensitivity_parameters = len(sensitivity_names)
    else:
        number_of_sensitivity_parameters = 0
        sensitivity_names = []

    # for the casadi solver we just give it dFdp_i
    if model.jacp_rhs_algebraic_eval is None:
        sensfn = casadi.Function("sensfn", [], [])
    else:
        sensfn = model.jacp_rhs_algebraic_eval

    # A model-level ``atol`` attribute, when present, overrides the solver's.
    atol = getattr(model, "atol", self.atol)
    atol = self._check_atol_type(atol, model)

    # Build algebraic-only residual and Jacobian for Newton sub-block mode.
    # When newton_mode="full", skip these so the C++ solver uses the
    # full-system IDA linear solve (DECOUPLED_FULL or COUPLED_FULL),
    # which supports any linear solver including iterative ones.
    if self._options.get("newton_mode", "auto") == "auto":
        alg_res_fn = model.algebraic_eval
        jac_alg_fn = model.jac_algebraic_eval
    else:
        alg_res_fn = casadi.Function("empty_alg_res", [], [])
        jac_alg_fn = casadi.Function("empty_alg_jac", [], [])

    if (
        self._options["hermite_reduction_factor"] > 1.0
        and number_of_sensitivity_parameters > 0
    ):
        warnings.warn(
            "Setting hermite_reduction_factor > 1.0 is not currently supported "
            "with sensitivities. The hermite_reduction_factor option will be "
            "ignored.",
            pybamm.SolverWarning,
            stacklevel=2,
        )

    # Collect all casadi functions for AOT compilation. With compile=True,
    # all functions are bundled into a single shared library via one gcc
    # invocation; each serialized Function then points at its entry point.
    has_sens = (len(stacked_inputs) > 0) and model.calculate_sensitivities
    # The positional tuple below must stay in the same order as
    # _SETUP_FCN_KEYS; strict=True only guards against a length mismatch.
    fns = dict(
        zip(
            _SETUP_FCN_KEYS,
            (
                rhs_algebraic,
                jac_times_cjmass,
                jac_rhs_algebraic_action,
                mass_action,
                sensfn,
                rootfn,
                alg_res_fn,
                jac_alg_fn,
            ),
            strict=True,
        )
    )
    for key in self.output_variables:
        fns[f"var:{key}"] = self.computed_var_fcns[key]
        if has_sens:
            fns[f"dvar_dy:{key}"] = self.computed_dvar_dy_fcns[key]
            fns[f"dvar_dp:{key}"] = self.computed_dvar_dp_fcns[key]

    if self._options["compile"]:
        compiled = aot_compile(list(fns.values()))
        fns = dict(zip(fns.keys(), compiled, strict=True))

    def to_idaklu(fn):
        # Hand a casadi Function to idaklu through its serialised form.
        return idaklu.generate_function(fn.serialize())

    self._setup = {
        "number_of_states": len(y0),
        "inputs": len(stacked_inputs),
        "jac_bandwidth_upper": jac_bw_upper,
        "jac_bandwidth_lower": jac_bw_lower,
        "jac_times_cjmass_colptrs": jac_times_cjmass_colptrs,
        "jac_times_cjmass_rowvals": jac_times_cjmass_rowvals,
        "jac_times_cjmass_nnz": jac_times_cjmass_nnz,
        "num_of_events": num_of_events,
        "ids": ids,
        "atol": atol,
        "sensitivity_names": sensitivity_names,
        "number_of_sensitivity_parameters": number_of_sensitivity_parameters,
        "standard_form_dae": model.is_standard_form_dae,
        "output_variables": self.output_variables,
        "var_fcns": self.computed_var_fcns,
        "var_idaklu_fcns": [],
        "dvar_dy_idaklu_fcns": [],
        "dvar_dp_idaklu_fcns": [],
    }
    for name in _SETUP_FCN_KEYS:
        self._setup[name] = to_idaklu(fns[name])
    # The idaklu-bound Function isn't callable from Python; keep the
    # original for the closest_event_idx lookup in _post_process_solution.
    self._setup["rootfn_casadi"] = fns["rootfn"]

    for key in self.output_variables:
        self._setup["var_idaklu_fcns"].append(to_idaklu(fns[f"var:{key}"]))
        if has_sens:
            self._setup["dvar_dy_idaklu_fcns"].append(
                to_idaklu(fns[f"dvar_dy:{key}"])
            )
            self._setup["dvar_dp_idaklu_fcns"].append(
                to_idaklu(fns[f"dvar_dp:{key}"])
            )

    self._setup["solver"] = idaklu.create_casadi_solver_group(
        number_of_states=self._setup["number_of_states"],
        number_of_parameters=self._setup["number_of_sensitivity_parameters"],
        rhs_alg=self._setup["rhs_algebraic"],
        jac_times_cjmass=self._setup["jac_times_cjmass"],
        jac_times_cjmass_colptrs=self._setup["jac_times_cjmass_colptrs"],
        jac_times_cjmass_rowvals=self._setup["jac_times_cjmass_rowvals"],
        jac_times_cjmass_nnz=self._setup["jac_times_cjmass_nnz"],
        jac_bandwidth_lower=self._setup["jac_bandwidth_lower"],
        jac_bandwidth_upper=self._setup["jac_bandwidth_upper"],
        jac_action=self._setup["jac_rhs_algebraic_action"],
        mass_action=self._setup["mass_action"],
        sens=self._setup["sensfn"],
        events=self._setup["rootfn"],
        number_of_events=self._setup["num_of_events"],
        rhs_alg_id=self._setup["ids"],
        atol=self._setup["atol"],
        rtol=self.rtol,
        inputs=self._setup["inputs"],
        var_fcns=self._setup["var_idaklu_fcns"],
        dvar_dy_fcns=self._setup["dvar_dy_idaklu_fcns"],
        dvar_dp_fcns=self._setup["dvar_dp_idaklu_fcns"],
        options=self._options,
        alg_res=self._setup["alg_res"],
        alg_jac=self._setup["alg_jac"],
    )

    # C++ solver group now owns each casadi::Function via shared_ptr;
    # drop the Python references. rootfn_casadi stays for runtime use.
    for key in (*_SETUP_FCN_KEYS, *_SETUP_FCN_LIST_KEYS):
        self._setup.pop(key, None)

    # Release the public casadi.Function caches now that the C++ group
    # owns the functions. _setup["var_fcns"] keeps the references that
    # _post_process_solution still needs; the dvar caches are unused
    # after setup, so dropping them frees that memory immediately.
    self.computed_var_fcns = {}
    self.computed_dvar_dy_fcns = {}
    self.computed_dvar_dp_fcns = {}

    return base_set_up_return
def __getstate__(self): # Drop the rebuildable state (C++ solver + casadi.Function graphs) # so the next solve() rebuilds from the model rather than shipping # serialised functions in the pickle. state = self.__dict__.copy() for key in _REBUILDABLE_STATE_KEYS: state.pop(key, None) return state def __setstate__(self, d): self.__dict__.update(d) # Restore the empty defaults BaseSolver.__init__ would set, so the # solver reads as "not yet set up" until the next solve(). self._model_set_up = {} self.computed_var_fcns = {} @property def options(self): return self._options def _integrate( self, model, t_eval, inputs_list: list[dict] | None = None, t_interp=None, nproc=None, ): """ Overloads the _integrate method from BaseSolver to use the IDAKLU solver """ if model.convert_to_format != "casadi": # pragma: no cover raise pybamm.SolverError("Unsupported IDAKLU solver configuration.") inputs_list = inputs_list or [{}] # stack inputs so that they are a 2D array of shape (number_of_inputs, number_of_parameters) if inputs_list and inputs_list[0]: inputs = np.vstack([_flatten_inputs(d) for d in inputs_list]) else: inputs = np.array([[]] * len(inputs_list)) # y0full is now a list with length = number of input sets y0full = np.vstack(model.y0full) ydot0full = np.vstack(model.ydot0full) atol = getattr(model, "atol", self.atol) atol = self._check_atol_type(atol, model) logger = ( pybamm.logger.debug if pybamm.logger.isEnabledFor(logging.DEBUG) else None ) timer = pybamm.Timer() try: solns = self._setup["solver"].solve( t_eval, t_interp, y0full, ydot0full, inputs, logger=logger, ) except ValueError as e: # Return from None to replace the C++ runtime error raise pybamm.SolverError(str(e)) from None integration_time = timer.time() return [ self._post_process_solution( soln, model, integration_time, inputs_dict, t_eval ) for soln, inputs_dict in zip(solns, inputs_list, strict=False) ] @property def _internal_initialisation(self) -> bool: return bool(self._options["calc_ic"]) def 
_check_event_violation_on_initialisation(self, *args, **kwargs): if self._internal_initialisation: return return self._check_event_violation(*args, **kwargs) def _check_event_violation_post_solve(self, *args, **kwargs): if not self._internal_initialisation: return return self._check_event_violation(*args, **kwargs) def _post_process_solution(self, sol, model, integration_time, inputs_dict, t_eval): number_of_sensitivity_parameters = self._setup[ "number_of_sensitivity_parameters" ] sensitivity_names = self._setup["sensitivity_names"] number_of_timesteps = sol.t.size number_of_states = model.len_rhs_and_alg save_outputs_only = self.output_variables if save_outputs_only: # Substitute empty vectors for state vector 'y' y_out = np.zeros((number_of_timesteps * number_of_states, 0)) y_event = sol.y_term else: y_out = sol.y.reshape((number_of_timesteps, number_of_states)) y_event = y_out[-1] # If there is only one step and an event was found, the event was # triggered at t0 after consistent initialization. y_event is the # post-IC state: sol.y_term (outputs-only) or y_out[-1] (full), # both stored after IC. This check identifies *which* event fired. 
if number_of_timesteps == 1 and sol.flag == _IDA_ROOT_RETURN: self._check_event_violation_post_solve(t_eval, model, y_event, inputs_dict) # return sensitivity solution, we need to flatten yS to # (#timesteps * #states (where t is changing the quickest),) # to match format used by Solution # note that yS is (n_p, n_t, n_y) if number_of_sensitivity_parameters != 0: yS_out = { name: sol.yS[i].reshape(-1, 1) for i, name in enumerate(sensitivity_names) } # add "all" stacked sensitivities ((#timesteps * #states,#sens_params)) yS_out["all"] = np.hstack([yS_out[name] for name in sensitivity_names]) else: yS_out = {} # IDA_SUCCESS (0) = solved for all t_eval # IDA_ROOT_RETURN (2) = found root(s) # < 0 = solver failure if sol.flag == _IDA_ROOT_RETURN: termination = "event" elif sol.flag >= 0: termination = "final time" elif sol.flag < 0: termination = "failure" msg = idaklu.sundials_error_message(sol.flag) match self._on_failure: case "warn": warnings.warn( msg + ", returning a partial solution.", stacklevel=2, ) case "error": raise pybamm.SolverError(msg) if sol.yp.size > 0: yp = sol.yp.reshape((number_of_timesteps, number_of_states)).T else: yp = None t = sol.t t_eval = np.array(t_eval) if t[-1] != t_eval[-1]: idx_final = np.searchsorted(t_eval, t[-1]) + 1 t_eval = t_eval[:idx_final] # the final index may differ due to an event; # manually set it to the true final time t_eval[-1] = t[-1] # Forward the compile flag so post-solve observation uses the same # backend as the integration. solution_options = {"compile": self._options["compile"]} newsol = pybamm.Solution( t, np.transpose(y_out), model, inputs_dict, np.array([sol.t[-1]]), np.transpose(y_event)[:, np.newaxis], termination, all_sensitivities=yS_out, all_yps=yp, all_t_evals=t_eval, variables_returned=bool(save_outputs_only), options=solution_options, ) # Set closest_event_idx so BaseSolver.get_termination_reason doesn't # re-walk every event's symbolic expression on the Python side. 
if sol.flag == _IDA_ROOT_RETURN and self._setup["num_of_events"] > 0: event_values = np.asarray( self._setup["rootfn_casadi"]( float(sol.t[-1]), np.asarray(y_event).reshape(-1), _flatten_inputs(inputs_dict), ) ).reshape(-1) newsol.closest_event_idx = int(np.nanargmin(np.abs(event_values))) newsol.integration_time = integration_time if not save_outputs_only: return newsol # Populate variables and sensitivities dictionaries directly number_of_samples = sol.y.shape[0] // number_of_timesteps sol.y = sol.y.reshape((number_of_timesteps, number_of_samples)) sensitivity_params = ( model.calculate_sensitivities if model.calculate_sensitivities else [] ) start_idx = 0 for var in self.output_variables: var_nnz, var_shape, base_variables = self._get_variable_info(model, var) end_idx = start_idx + var_nnz data = sol.y[:, start_idx:end_idx] time_indep = False # handle any time integral variables if var in self._time_integral_vars: # time integral variables should all be 1D tiv = self._time_integral_vars[var] data = tiv.postfix(data.reshape(-1), sol.t, inputs_dict) time_indep = True newsol._variables[var] = pybamm.ProcessedVariableComputed( [model.get_processed_variable_or_event(var)], base_variables, [data], newsol, time_indep=time_indep, ) # Add sensitivities newsol[var]._sensitivities = {} if sensitivity_params: if var_nnz != math.prod(var_shape): raise pybamm.SolverError( f"Sensitivity of sparse variables not supported. 
{var} is a sparse variable with number of non-zeros {var_nnz} and shape {var_shape}" ) sens_data = sol.yS[:, start_idx:end_idx, :] sens_data = sens_data.reshape( number_of_timesteps * (end_idx - start_idx), number_of_sensitivity_parameters, ) if var in self._time_integral_vars: tiv = self._time_integral_vars[var] sens_data = tiv.postfix_sensitivities( var, data, sol.t, inputs_dict, sens_data ) newsol[var]._sensitivities["all"] = sens_data # Add the individual sensitivity for i, name in enumerate(inputs_dict.keys()): sens = newsol[var]._sensitivities["all"][:, i : i + 1].reshape(-1) newsol[var]._sensitivities[name] = sens start_idx += var_nnz return newsol def _get_variable_info(self, model, var) -> tuple: """Get variable length and base variables based on model format.""" if model.convert_to_format == "casadi": base_var = self._setup["var_fcns"][var] var_eval = base_var(0.0, 0.0, 0.0) var_nnz = var_eval.sparsity().nnz() var_shape = var_eval.shape return var_nnz, var_shape, [base_var] else: # pragma: no cover raise pybamm.SolverError( f"Unsupported evaluation engine for convert_to_format=" f"{model.convert_to_format}" ) def _set_consistent_initialization(self, model, time, inputs_list): """ Initialize y0 and ydot0 for the solver. In addition to calculating y0 from BaseSolver, we also calculate ydot0 for semi-explicit DAEs Parameters ---------- model : :class:`pybamm.BaseModel` The model for which to calculate initial conditions. time : numeric type The time at which to calculate the initial conditions. inputs_list : list of dict Any input parameters to pass to the model when solving. 
""" # set model.y0_list super()._set_consistent_initialization(model, time, inputs_list) casadi_format = model.convert_to_format == "casadi" def handle_y0(y0): if isinstance(y0, casadi.DM): y0 = y0.full() return y0.flatten() y0_list = [handle_y0(y0) for y0 in model.y0_list] # calculate the time derivatives of the differential equations # for semi-explicit DAEs if model.len_rhs > 0: ydot0_list = [ self._rhs_dot_consistent_initialization(y0, model, time, inputs_dict) for y0, inputs_dict in zip(y0_list, inputs_list, strict=True) ] else: ydot0_list = [np.zeros_like(y0) for y0 in y0_list] sensitivity = model.y0S_list and casadi_format if sensitivity: y0S_list = model.y0S_list y0full = [] ydot0full = [] for y0, ydot0, y0S, inputs_dict in zip( y0_list, ydot0_list, y0S_list, inputs_list, strict=True ): y0f, ydot0f = self._sensitivity_consistent_initialization( y0, ydot0, y0S, time, inputs_dict ) y0full.append(y0f) ydot0full.append(ydot0f) else: y0full = y0_list ydot0full = ydot0_list model.y0full = y0full model.ydot0full = ydot0full def _rhs_dot_consistent_initialization(self, y0, model, time, inputs_dict): """ Compute the consistent initialization of ydot0 for the differential terms for the solver. If we have a semi-explicit DAE, we can explicitly solve for this value using the consistently initialized y0 vector. Parameters ---------- y0 : :class:`numpy.array` The initial values of the state vector. model : :class:`pybamm.BaseModel` The model for which to calculate initial conditions. time : numeric type The time at which to calculate the initial conditions. inputs_dict : dict Any input parameters to pass to the model when solving. 
""" casadi_format = model.convert_to_format == "casadi" inputs_dict = inputs_dict or {} # stack inputs if inputs_dict: arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()] inputs = np.vstack(arrays_to_stack) else: inputs = np.array([[]]) ydot0 = np.zeros_like(y0) # calculate the time derivatives of the differential equations input_eval = inputs if casadi_format else inputs_dict rhs0 = model.rhs_eval(time, y0, input_eval) if isinstance(rhs0, casadi.DM): rhs0 = rhs0.full() rhs0 = rhs0.flatten() # for the differential terms, ydot = M^-1 * (rhs) if model.is_standard_form_dae: # M^-1 is the identity matrix, so we can just use rhs ydot0[: model.len_rhs] = rhs0 else: # M^-1 is not the identity matrix, so we need to use the mass matrix M_ode = model.mass_matrix.entries[: model.len_rhs, : model.len_rhs] ydot0[: model.len_rhs] = spsolve(M_ode, rhs0) return ydot0 def _sensitivity_consistent_initialization(self, y0, ydot0, y0S, time, inputs_dict): """ Extend the consistent initialization to include the sensitivty equations Parameters ---------- y0 : :class:`numpy.array` The initial values of the state vector. ydot0 : :class:`numpy.array` The initial values of the time derivatives of the state vector. y0S : :class:`numpy.array` The initial values of the sensitivity state vectors. time : numeric type The time at which to calculate the initial conditions. inputs_dict : dict Any input parameters to pass to the model when solving. """ if isinstance(y0S, casadi.DM): y0S = (y0S,) if isinstance(y0S[0], casadi.DM): y0S = (x.full() for x in y0S) y0S = [x.flatten() for x in y0S] y0full = np.concatenate([y0, *y0S]) ydot0S = [np.zeros_like(y0S_i) for y0S_i in y0S] ydot0full = np.concatenate([ydot0, *ydot0S]) return y0full, ydot0full
def jaxify(
    self,
    model,
    t_eval,
    *,
    output_variables=None,
    calculate_sensitivities=True,
    t_interp=None,
):
    """Wrap this solver in a :class:`pybamm.IDAKLUJax` object

    Constructs a :class:`pybamm.IDAKLUJax` from this solver instance and
    the supplied arguments and returns it.

    Parameters
    ----------
    model : :class:`pybamm.BaseModel`
        The model to be solved
    t_eval : numeric type
        The times at which to stop the integration due to a discontinuity
        in time.
    output_variables : list of str, optional
        The variables to be returned. If None, all variables in the model
        are used.
    calculate_sensitivities : bool, optional
        Whether to calculate sensitivities. Default is True.
    t_interp : None, list or ndarray, optional
        The times (in seconds) at which to interpolate the solution.
        Defaults to `None`, which returns the adaptive time-stepping times.

    Returns
    -------
    :class:`pybamm.IDAKLUJax`
        The newly constructed wrapper object.
    """
    # Keyword-only settings are forwarded unchanged to the wrapper
    jax_kwargs = {
        "output_variables": output_variables,
        "calculate_sensitivities": calculate_sensitivities,
        "t_interp": t_interp,
    }
    # First positional argument is the IDAKLU solver instance itself
    return pybamm.IDAKLUJax(self, model, t_eval, **jax_kwargs)
def reduce_solution(
    self,
    solution,
    hermite_reduction_factor=2.0,
) -> pybamm.Solution:
    """Compress a pybamm.Solution by removing Hermite spline knots.

    The multiplier M (``hermite_reduction_factor``) sets the total error
    budget relative to the solver's own WRMS tolerance: the knot reducer
    may add a WRMS error of (M-1), so that ``||e_total||_WRMS <= M``.
    With the default M = 2 the reduced solution may carry up to 2x the
    solver's local error.

    Parameters
    ----------
    solution : :class:`pybamm.Solution`
        The solution to reduce. Must have hermite interpolation data
        (all_yps is not None).
    hermite_reduction_factor : float, optional
        Total error multiplier (>= 1.0, default 2.0). The knot reducer's
        WRMS threshold is N * (M-1)^2. Larger values allow more aggressive
        reduction.

    Returns
    -------
    :class:`pybamm.Solution`
        A new solution built from the reduced knots, with the event,
        input, sensitivity and timing metadata copied from ``solution``.

    Raises
    ------
    :class:`pybamm.SolverError`
        If ``solution`` has no Hermite interpolation data, or if this
        solver's ``hermite_reduction_factor`` option is not 1.0.
    """
    if not solution.hermite_interpolation:
        raise pybamm.SolverError(
            "reduce_solution requires Hermite interpolation data (all_yps)."
        )
    if self.options["hermite_reduction_factor"] != 1.0:
        raise pybamm.SolverError(
            "reduce_solution requires the original solver to have "
            "`hermite_reduction_factor = 1.0`"
        )

    seg_ts = solution.all_ts
    seg_ys = solution.all_ys
    seg_yps = solution.all_yps
    seg_models = solution.all_models
    segment_count = len(seg_ts)

    rel_tol = self.rtol
    abs_tol = self.atol

    # The C++ reducer takes flat, time-major arrays. Each all_ys[i] is
    # (n_states, M), so transposing before ravel gives time-major order.
    ys_flat = []
    yps_flat = []
    for seg in range(segment_count):
        ys_flat.append(seg_ys[seg].T.ravel())
        yps_flat.append(seg_yps[seg].T.ravel())

    # One absolute-tolerance vector per segment
    atol_per_segment = [None] * segment_count
    for seg, seg_model in enumerate(seg_models):
        atol_per_segment[seg] = self._check_atol_type(abs_tol, seg_model)

    wrap = idaklu.VectorRealtypeNdArray
    ts_arg = wrap(seg_ts)
    ys_arg = wrap(ys_flat)
    yps_arg = wrap(yps_flat)
    atols_arg = wrap(atol_per_segment)
    t_evals_arg = wrap(solution.all_t_evals)

    reduced = idaklu.reduce_knots(
        ts_arg,
        ys_arg,
        yps_arg,
        atols_arg,
        t_evals_arg,
        float(rel_tol),
        float(hermite_reduction_factor),
    )
    reduced_ts, reduced_ys, reduced_yps = reduced

    def _to_state_major(flat, n_knots, n_states):
        # Undo the time-major flattening: back to (n_states, n_knots)
        return np.asarray(flat).reshape(n_knots, n_states).T

    # Reshape reduced flat arrays back to (n_states, K) convention
    ts_out = [np.asarray(reduced_ts[seg]) for seg in range(segment_count)]
    ys_out = []
    yps_out = []
    for seg in range(segment_count):
        n_knots = len(ts_out[seg])
        n_states = seg_ys[seg].shape[0]
        ys_out.append(_to_state_major(reduced_ys[seg], n_knots, n_states))
        yps_out.append(_to_state_major(reduced_yps[seg], n_knots, n_states))

    reduced_solution = pybamm.Solution(
        all_ts=ts_out,
        all_ys=ys_out,
        all_yps=yps_out,
        all_models=seg_models,
        all_inputs=solution.all_inputs,
        t_event=solution.t_event,
        y_event=solution.y_event,
        termination=solution.termination,
        all_sensitivities=solution._all_sensitivities,
        all_t_evals=solution.all_t_evals,
        variables_returned=solution.variables_returned,
        options=solution.user_options,
    )

    # Propagate metadata from the original solution
    reduced_solution._all_inputs_stacked = solution.all_inputs_stacked
    reduced_solution._all_inputs_casadi = solution.all_inputs_casadi
    reduced_solution.closest_event_idx = solution.closest_event_idx
    reduced_solution.solve_time = solution.solve_time
    reduced_solution.integration_time = solution.integration_time
    reduced_solution.set_up_time = solution.set_up_time

    return reduced_solution