# Source code for pybamm.solvers.idaklu_solver

# mypy: ignore-errors
import logging
import math
import numbers
import warnings

import casadi
import numpy as np
import pybammsolvers.idaklu as idaklu
from scipy.sparse.linalg import spsolve

import pybamm
from pybamm.codegen.compilation import aot_compile


[docs] class IDAKLUSolver(pybamm.BaseSolver): """ Solve a discretised model, using sundials with the KLU sparse linear solver. Parameters ---------- rtol : float, optional The relative tolerance for the solver (default is 1e-4). atol : float, optional The absolute tolerance for the solver (default is 1e-6). root_method : str or pybamm algebraic solver class, optional The method to use to find initial conditions (for DAE solvers). If a solver class, must be an algebraic solver class. If "casadi", the solver uses casadi's Newton rootfinding algorithm to find initial conditions. Otherwise, the solver uses 'scipy.optimize.root' with method specified by 'root_method' (e.g. "lm", "hybr", ...) root_tol : float, optional The tolerance for the initial-condition solver (default is 1e-6). extrap_tol : float, optional The tolerance to assert whether extrapolation occurs or not (default is 0). on_extrapolation : str, optional What to do if the solver is extrapolating. Options are "warn", "error", or "ignore". Default is "warn". on_failure : str, optional What to do if a solver error flag occurs. Options are "warn", "error", or "ignore". Default is "error". output_variables : list[str], optional List of variables to calculate and return. If none are specified then the complete state vector is returned (can be very large) (default is []) options: dict, optional Addititional options to pass to the solver, by default: .. code-block:: python options = { # Print statistics of the solver after every solve "print_stats": False, # If ``True``, ahead-of-time compile each casadi ``Function`` # to a shared library via the system C compiler (see # :mod:`pybamm.codegen.compilation`). If ``False`` (default) # the casadi in-process virtual machine is used. 
"compile": False, # Number of threads available for OpenMP (must be greater than or equal to `num_solvers`) "num_threads": 1, # Number of solvers to use in parallel (for solving multiple sets of input parameters in parallel) "num_solvers": num_threads, ## Linear solver interface # name of sundials linear solver to use options are: "SUNLinSol_KLU", # "SUNLinSol_Dense", "SUNLinSol_Band", "SUNLinSol_SPBCGS", # "SUNLinSol_SPFGMR", "SUNLinSol_SPGMR", "SUNLinSol_SPTFQMR", "linear_solver": "SUNLinSol_KLU", # Jacobian form, can be "none", "dense", # "banded", "sparse", "matrix-free" "jacobian": "sparse", # Preconditioner for iterative solvers, can be "none", "BBDP" "preconditioner": "BBDP", # For iterative linear solver preconditioner, bandwidth of # approximate jacobian "precon_half_bandwidth": 5, # For iterative linear solver preconditioner, bandwidth of # approximate jacobian that is kept "precon_half_bandwidth_keep": 5, # For iterative linear solvers, max number of iterations "linsol_max_iterations": 5, # Ratio between linear and nonlinear tolerances "epsilon_linear_tolerance": 0.05, # Increment factor used in DQ Jacobian-vector product approximation "increment_factor": 1.0, # Enable or disable linear solution scaling "linear_solution_scaling": True, # Silence Sundials errors during solve "silence_sundials_errors": False, ## Main solver # Maximum order of the linear multistep method "max_order_bdf": 5, # Maximum number of steps to be taken by the solver in its attempt to # reach the next output time. # Note: this value differs from the IDA default of 500 "max_num_steps": 100000, # Initial step size. The solver default is used if this is left at 0.0 "dt_init": 0.0, # Minimum absolute step size. The solver default is used if this is # left at 0.0 "dt_min": 0.0, # Maximum absolute step size. 
The solver default is used if this is # left at 0.0 "dt_max": 0.0, # Maximum number of error test failures in attempting one step "max_error_test_failures": 10, # Maximum number of nonlinear solver iterations at one step # Note: this value differs from the IDA default of 4 "max_nonlinear_iterations": 40, # Maximum number of nonlinear solver convergence failures at one step # Note: this value differs from the IDA default of 10 "max_convergence_failures": 100, # Safety factor in the nonlinear convergence test "nonlinear_convergence_coefficient": 0.33, # Suppress algebraic variables from error test "suppress_algebraic_error": False, # Store Hermite interpolation data for the solution. # Note: this option is always disabled if output_variables are given # or if t_interp values are specified "hermite_interpolation": True, # Setting hermite_reduction_factor > 1.0 compresses the solution size # by introducing a small amount of error to the Hermite spline # interpolant. A value of `2.0` roughly corresponds to a maximum 2x # increase in error (practically the error is much smaller), while # reducing the number of saved states by around 5-6x. This option is # only active if `hermite_interpolation` is True and sensitivities # are disabled. 
"hermite_reduction_factor": 1.0, ## Initial conditions calculation # Positive constant in the Newton iteration convergence test within the # initial condition calculation "nonlinear_convergence_coefficient_ic": 0.0033, # Maximum number of steps allowed when `init_all_y_ic = False` # Note: this value differs from the IDA default of 5 "max_num_steps_ic": 50, # Maximum number of the approximate Jacobian or preconditioner evaluations # allowed when the Newton iteration appears to be slowly converging # Note: this value differs from the IDA default of 4 "max_num_jacobians_ic": 40, # Maximum number of Newton iterations allowed in any one attempt to solve # the initial conditions calculation problem # Note: this value differs from the IDA default of 10 "max_num_iterations_ic": 100, # Maximum number of linesearch backtracks allowed in any Newton iteration, # when solving the initial conditions calculation problem "max_linesearch_backtracks_ic": 100, # Turn off linesearch "linesearch_off_ic": False, # How to calculate the initial conditions. # "True": calculate all y0 given ydot0 # "False": calculate y_alg0 and ydot_diff0 given y_diff0 "init_all_y_ic": False, # Calculate consistent initial conditions "calc_ic": True, ## Early termination # Maximum number of consecutive steps allowed without advancing # the solution time by at least `t_no_progress` seconds. # If set to 0, this feature is disabled. "num_steps_no_progress": 0, # Minimum required time advancement (in seconds) after # `num_steps_no_progress` consecutive steps. # If set to 0.0, this feature is disabled. 
"t_no_progress": 0.0, } Note: These options only have an effect if model.convert_to_format == 'casadi' """ def __init__( self, rtol=1e-4, atol=1e-6, root_method="casadi", root_tol=1e-6, extrap_tol=None, on_extrapolation=None, output_variables=None, on_failure=None, options=None, ): self.output_variables = [] if output_variables is None else output_variables self._options = self._combine_options(options) super().__init__( method="ida", rtol=rtol, atol=atol, root_method=root_method, root_tol=root_tol, extrap_tol=extrap_tol, output_variables=output_variables, on_extrapolation=on_extrapolation, on_failure=on_failure, ) self.name = "IDA KLU solver" self._supports_interp = True self._supports_t_eval_discontinuities = True pybamm.citations.register("Hindmarsh2000") pybamm.citations.register("Hindmarsh2005") def _combine_options(self, user_options: dict | None) -> dict: num_solvers = user_options.get("num_threads", 1) if user_options else 1 default_options = { "print_stats": False, "compile": False, "jacobian": "sparse", "preconditioner": "BBDP", "precon_half_bandwidth": 5, "precon_half_bandwidth_keep": 5, "num_threads": 1, "num_solvers": num_solvers, "linear_solver": "SUNLinSol_KLU", "linsol_max_iterations": 5, "epsilon_linear_tolerance": 0.05, "increment_factor": 1.0, "linear_solution_scaling": True, "silence_sundials_errors": False, "max_order_bdf": 5, "max_num_steps": 100000, "dt_init": 0.0, "dt_min": 0.0, "dt_max": 0.0, "max_error_test_failures": 10, "max_nonlinear_iterations": 40, "max_convergence_failures": 100, "nonlinear_convergence_coefficient": 0.33, "suppress_algebraic_error": False, "hermite_interpolation": True, "hermite_reduction_factor": 1.0, "nonlinear_convergence_coefficient_ic": 0.0033, "max_num_steps_ic": 50, "max_num_jacobians_ic": 40, "max_num_iterations_ic": 100, "max_linesearch_backtracks_ic": 100, "linesearch_off_ic": False, "init_all_y_ic": False, "calc_ic": True, "num_steps_no_progress": 0, "t_no_progress": 0.0, } if not user_options: return 
default_options options = default_options | user_options self._check_options(options) return options def _check_options(self, options: dict): hermite_reduction_factor = options["hermite_reduction_factor"] if hermite_reduction_factor > 1.0: if self.output_variables: raise pybamm.SolverError( "hermite_reduction_factor cannot be used with " "output_variables. Both are memory-saving options " "that are mutually exclusive." ) if not options["hermite_interpolation"]: raise pybamm.SolverError( "hermite_reduction_factor requires " "hermite_interpolation to be enabled." ) else: if hermite_reduction_factor < 1.0: raise pybamm.SolverError("hermite_reduction_factor must be >= 1.0.") if not isinstance(options["compile"], bool): raise pybamm.SolverError("compile must be a bool") def _check_atol_type(self, atol, model): if isinstance(atol, float): return np.full(model.len_rhs_and_alg, atol) elif isinstance(atol, np.ndarray): return atol else: raise pybamm.SolverError( "Absolute tolerances must be a numpy array or float" )
[docs] def set_up(self, model, inputs=None, t_eval=None, ics_only=False): base_set_up_return = super().set_up(model, inputs, t_eval, ics_only) if isinstance(inputs, list): # for setup, just use first input dict inputs_dict = inputs[0] else: inputs_dict = inputs or {} # stack inputs if inputs_dict: arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()] stacked_inputs = np.vstack(arrays_to_stack) else: stacked_inputs = np.array([[]]) y0 = model.y0_list[0] if isinstance(y0, casadi.DM): y0 = y0.full() y0 = y0.flatten() if ics_only: return base_set_up_return if model.convert_to_format != "casadi": msg = "The python-idaklu solver has been deprecated." warnings.warn(msg, DeprecationWarning, stacklevel=2) raise pybamm.SolverError( f"Unsupported option for convert_to_format={model.convert_to_format}" ) if self._options["jacobian"] == "dense": mass_matrix = casadi.DM(model.mass_matrix.entries.toarray()) else: mass_matrix = casadi.DM(model.mass_matrix.entries) # construct residuals function by binding inputs # TODO: do we need densify here? 
rhs_algebraic = model.rhs_algebraic_eval if not model.use_jacobian: raise pybamm.SolverError("KLU requires the Jacobian") # need to provide jacobian_rhs_alg - cj * mass_matrix t_casadi = casadi.MX.sym("t") y_casadi = casadi.MX.sym("y", model.len_rhs_and_alg) cj_casadi = casadi.MX.sym("cj") p_casadi = {} for name, value in inputs_dict.items(): if isinstance(value, numbers.Number): p_casadi[name] = casadi.MX.sym(name) else: p_casadi[name] = casadi.MX.sym(name, value.shape[0]) p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()]) jac_times_cjmass = casadi.Function( "jac_times_cjmass", [t_casadi, y_casadi, p_casadi_stacked, cj_casadi], [ model.jac_rhs_algebraic_eval(t_casadi, y_casadi, p_casadi_stacked) - cj_casadi * mass_matrix ], ) jac_times_cjmass_sparsity = jac_times_cjmass.sparsity_out(0) jac_bw_lower = jac_times_cjmass_sparsity.bw_lower() jac_bw_upper = jac_times_cjmass_sparsity.bw_upper() jac_times_cjmass_nnz = jac_times_cjmass_sparsity.nnz() jac_times_cjmass_colptrs = np.array( jac_times_cjmass_sparsity.colind(), dtype=np.int64 ) jac_times_cjmass_rowvals = np.array( jac_times_cjmass_sparsity.row(), dtype=np.int64 ) v_casadi = casadi.MX.sym("v", model.len_rhs_and_alg) jac_rhs_algebraic_action = model.jac_rhs_algebraic_action_eval # also need the action of the mass matrix on a vector mass_action = casadi.Function( "mass_action", [v_casadi], [casadi.densify(mass_matrix @ v_casadi)] ) num_of_events = len(model.terminate_events_eval) # rootfn needs to return an array of length num_of_events rootfn = casadi.Function( "rootfn", [t_casadi, y_casadi, p_casadi_stacked], [ casadi.vertcat( *[ event(t_casadi, y_casadi, p_casadi_stacked) for event in model.terminate_events_eval ] ) ], ) # get ids of rhs and algebraic variables rhs_ids = np.ones(model.rhs_eval(0, y0, stacked_inputs).shape[0]) alg_ids = np.zeros(len(y0) - len(rhs_ids)) ids = np.concatenate((rhs_ids, alg_ids)) if model.jacp_rhs_algebraic_eval is not None: sensitivity_names = 
model.calculate_sensitivities if model.convert_to_format == "casadi": number_of_sensitivity_parameters = model.jacp_rhs_algebraic_eval.n_out() else: number_of_sensitivity_parameters = len(sensitivity_names) else: number_of_sensitivity_parameters = 0 sensitivity_names = [] # for the casadi solver we just give it dFdp_i if model.jacp_rhs_algebraic_eval is None: sensfn = casadi.Function("sensfn", [], []) else: sensfn = model.jacp_rhs_algebraic_eval atol = getattr(model, "atol", self.atol) atol = self._check_atol_type(atol, model) if ( self._options["hermite_reduction_factor"] > 1.0 and number_of_sensitivity_parameters > 0 ): warnings.warn( "Setting hermite_reduction_factor > 1.0 is not currently supported " "with sensitivities. The hermite_reduction_factor option will be " "ignored.", pybamm.SolverWarning, stacklevel=2, ) # Collect all casadi functions for AOT compilation. With compile=True, # all functions are bundled into a single shared library via one gcc # invocation; each serialized Function then points at its entry point. 
has_sens = (len(stacked_inputs) > 0) and model.calculate_sensitivities solver_fn_names = [ "rhs_algebraic", "jac_times_cjmass", "jac_rhs_algebraic_action", "mass_action", "sensfn", "rootfn", ] fns = { "rhs_algebraic": rhs_algebraic, "jac_times_cjmass": jac_times_cjmass, "jac_rhs_algebraic_action": jac_rhs_algebraic_action, "mass_action": mass_action, "sensfn": sensfn, "rootfn": rootfn, } for key in self.output_variables: fns[f"var:{key}"] = self.computed_var_fcns[key] if has_sens: fns[f"dvar_dy:{key}"] = self.computed_dvar_dy_fcns[key] fns[f"dvar_dp:{key}"] = self.computed_dvar_dp_fcns[key] if self._options["compile"]: compiled = aot_compile(list(fns.values())) fns = dict(zip(fns.keys(), compiled, strict=True)) # Build _setup dict with serialized functions for idaklu C++ wrapper def to_idaklu(fn): pkl = fn.serialize() return idaklu.generate_function(pkl), pkl self._setup = { "number_of_states": len(y0), "inputs": len(stacked_inputs), "jac_bandwidth_upper": jac_bw_upper, "jac_bandwidth_lower": jac_bw_lower, "jac_times_cjmass_colptrs": jac_times_cjmass_colptrs, "jac_times_cjmass_rowvals": jac_times_cjmass_rowvals, "jac_times_cjmass_nnz": jac_times_cjmass_nnz, "num_of_events": num_of_events, "ids": ids, "atol": atol, "sensitivity_names": sensitivity_names, "number_of_sensitivity_parameters": number_of_sensitivity_parameters, "standard_form_dae": model.is_standard_form_dae, "output_variables": self.output_variables, "var_fcns": self.computed_var_fcns, "var_idaklu_fcns": [], "var_idaklu_fcns_pkl": [], "dvar_dy_idaklu_fcns": [], "dvar_dy_idaklu_fcns_pkl": [], "dvar_dp_idaklu_fcns": [], "dvar_dp_idaklu_fcns_pkl": [], } for name in solver_fn_names: fn, pkl = to_idaklu(fns[name]) self._setup[name] = fn self._setup[f"{name}_pkl"] = pkl for key in self.output_variables: fn, pkl = to_idaklu(fns[f"var:{key}"]) self._setup["var_idaklu_fcns"].append(fn) self._setup["var_idaklu_fcns_pkl"].append(pkl) if has_sens: fn, pkl = to_idaklu(fns[f"dvar_dy:{key}"]) 
self._setup["dvar_dy_idaklu_fcns"].append(fn) self._setup["dvar_dy_idaklu_fcns_pkl"].append(pkl) fn, pkl = to_idaklu(fns[f"dvar_dp:{key}"]) self._setup["dvar_dp_idaklu_fcns"].append(fn) self._setup["dvar_dp_idaklu_fcns_pkl"].append(pkl) self._create_solver() return base_set_up_return
def _create_solver(self): """Create the idaklu solver group from _setup. Used by set_up and __setstate__.""" self._setup["solver"] = idaklu.create_casadi_solver_group( number_of_states=self._setup["number_of_states"], number_of_parameters=self._setup["number_of_sensitivity_parameters"], rhs_alg=self._setup["rhs_algebraic"], jac_times_cjmass=self._setup["jac_times_cjmass"], jac_times_cjmass_colptrs=self._setup["jac_times_cjmass_colptrs"], jac_times_cjmass_rowvals=self._setup["jac_times_cjmass_rowvals"], jac_times_cjmass_nnz=self._setup["jac_times_cjmass_nnz"], jac_bandwidth_lower=self._setup["jac_bandwidth_lower"], jac_bandwidth_upper=self._setup["jac_bandwidth_upper"], jac_action=self._setup["jac_rhs_algebraic_action"], mass_action=self._setup["mass_action"], sens=self._setup["sensfn"], events=self._setup["rootfn"], number_of_events=self._setup["num_of_events"], rhs_alg_id=self._setup["ids"], atol=self._setup["atol"], rtol=self.rtol, inputs=self._setup["inputs"], var_fcns=self._setup["var_idaklu_fcns"], dvar_dy_fcns=self._setup["dvar_dy_idaklu_fcns"], dvar_dp_fcns=self._setup["dvar_dp_idaklu_fcns"], options=self._options, ) def __getstate__(self): # if _setup is not defined then we haven't called set_up yet if not hasattr(self, "_setup"): return self.__dict__ for key in [ "solver", "rhs_algebraic", "jac_times_cjmass", "jac_rhs_algebraic_action", "mass_action", "sensfn", "rootfn", "var_idaklu_fcns", "dvar_dy_idaklu_fcns", "dvar_dp_idaklu_fcns", ]: self._setup.pop(key, None) return self.__dict__ def __setstate__(self, d): self.__dict__.update(d) # if _setup is not defined then we haven't called set_up yet if not hasattr(self, "_setup"): return for key in [ "rhs_algebraic", "jac_times_cjmass", "jac_rhs_algebraic_action", "mass_action", "sensfn", "rootfn", ]: self._setup[key] = idaklu.generate_function(self._setup[f"{key}_pkl"]) for key in ["var_idaklu_fcns", "dvar_dy_idaklu_fcns", "dvar_dp_idaklu_fcns"]: self._setup[key] = [ idaklu.generate_function(f) for f in 
self._setup[f"{key}_pkl"] ] self._create_solver() @property def options(self): return self._options def _integrate( self, model, t_eval, inputs_list: list[dict] | None = None, t_interp=None, nproc=None, ): """ Overloads the _integrate method from BaseSolver to use the IDAKLU solver """ if model.convert_to_format != "casadi": # pragma: no cover raise pybamm.SolverError("Unsupported IDAKLU solver configuration.") inputs_list = inputs_list or [{}] # stack inputs so that they are a 2D array of shape (number_of_inputs, number_of_parameters) if inputs_list and inputs_list[0]: inputs = np.vstack( [ np.hstack([np.array(x).reshape(-1) for x in inputs_dict.values()]) for inputs_dict in inputs_list ] ) else: inputs = np.array([[]] * len(inputs_list)) # y0full is now a list with length = number of input sets y0full = np.vstack(model.y0full) ydot0full = np.vstack(model.ydot0full) atol = getattr(model, "atol", self.atol) atol = self._check_atol_type(atol, model) logger = ( pybamm.logger.debug if pybamm.logger.isEnabledFor(logging.DEBUG) else None ) timer = pybamm.Timer() try: solns = self._setup["solver"].solve( t_eval, t_interp, y0full, ydot0full, inputs, logger=logger, ) except ValueError as e: # Return from None to replace the C++ runtime error raise pybamm.SolverError(str(e)) from None integration_time = timer.time() return [ self._post_process_solution( soln, model, integration_time, inputs_dict, t_eval ) for soln, inputs_dict in zip(solns, inputs_list, strict=False) ] def _post_process_solution(self, sol, model, integration_time, inputs_dict, t_eval): number_of_sensitivity_parameters = self._setup[ "number_of_sensitivity_parameters" ] sensitivity_names = self._setup["sensitivity_names"] number_of_timesteps = sol.t.size number_of_states = model.len_rhs_and_alg save_outputs_only = self.output_variables if save_outputs_only: # Substitute empty vectors for state vector 'y' y_out = np.zeros((number_of_timesteps * number_of_states, 0)) y_event = sol.y_term else: y_out = 
sol.y.reshape((number_of_timesteps, number_of_states)) y_event = y_out[-1] # return sensitivity solution, we need to flatten yS to # (#timesteps * #states (where t is changing the quickest),) # to match format used by Solution # note that yS is (n_p, n_t, n_y) if number_of_sensitivity_parameters != 0: yS_out = { name: sol.yS[i].reshape(-1, 1) for i, name in enumerate(sensitivity_names) } # add "all" stacked sensitivities ((#timesteps * #states,#sens_params)) yS_out["all"] = np.hstack([yS_out[name] for name in sensitivity_names]) else: yS_out = {} # 0 = solved for all t_eval # 2 = found root(s) # < 0 = solver failure if sol.flag == 2: termination = "event" elif sol.flag >= 0: termination = "final time" elif sol.flag < 0: termination = "failure" msg = idaklu.sundials_error_message(sol.flag) match self._on_failure: case "warn": warnings.warn( msg + ", returning a partial solution.", stacklevel=2, ) case "error": raise pybamm.SolverError(msg) if sol.yp.size > 0: yp = sol.yp.reshape((number_of_timesteps, number_of_states)).T else: yp = None t = sol.t t_eval = np.array(t_eval) if t[-1] != t_eval[-1]: idx_final = np.searchsorted(t_eval, t[-1]) + 1 t_eval = t_eval[:idx_final] # the final index may differ due to an event; # manually set it to the true final time t_eval[-1] = t[-1] # Forward the compile flag so post-solve observation uses the same # backend as the integration. 
solution_options = {"compile": self._options["compile"]} newsol = pybamm.Solution( t, np.transpose(y_out), model, inputs_dict, np.array([sol.t[-1]]), np.transpose(y_event)[:, np.newaxis], termination, all_sensitivities=yS_out, all_yps=yp, all_t_evals=t_eval, variables_returned=bool(save_outputs_only), options=solution_options, ) newsol.integration_time = integration_time if not save_outputs_only: return newsol # Populate variables and sensitivities dictionaries directly number_of_samples = sol.y.shape[0] // number_of_timesteps sol.y = sol.y.reshape((number_of_timesteps, number_of_samples)) sensitivity_params = ( model.calculate_sensitivities if model.calculate_sensitivities else [] ) start_idx = 0 for var in self.output_variables: var_nnz, var_shape, base_variables = self._get_variable_info(model, var) end_idx = start_idx + var_nnz data = sol.y[:, start_idx:end_idx] time_indep = False # handle any time integral variables if var in self._time_integral_vars: # time integral variables should all be 1D tiv = self._time_integral_vars[var] data = tiv.postfix(data.reshape(-1), sol.t, inputs_dict) time_indep = True newsol._variables[var] = pybamm.ProcessedVariableComputed( [model.get_processed_variable_or_event(var)], base_variables, [data], newsol, time_indep=time_indep, ) # Add sensitivities newsol[var]._sensitivities = {} if sensitivity_params: if var_nnz != math.prod(var_shape): raise pybamm.SolverError( f"Sensitivity of sparse variables not supported. 
{var} is a sparse variable with number of non-zeros {var_nnz} and shape {var_shape}" ) sens_data = sol.yS[:, start_idx:end_idx, :] sens_data = sens_data.reshape( number_of_timesteps * (end_idx - start_idx), number_of_sensitivity_parameters, ) if var in self._time_integral_vars: tiv = self._time_integral_vars[var] sens_data = tiv.postfix_sensitivities( var, data, sol.t, inputs_dict, sens_data ) newsol[var]._sensitivities["all"] = sens_data # Add the individual sensitivity for i, name in enumerate(inputs_dict.keys()): sens = newsol[var]._sensitivities["all"][:, i : i + 1].reshape(-1) newsol[var]._sensitivities[name] = sens start_idx += var_nnz return newsol def _get_variable_info(self, model, var) -> tuple: """Get variable length and base variables based on model format.""" if model.convert_to_format == "casadi": base_var = self._setup["var_fcns"][var] var_eval = base_var(0.0, 0.0, 0.0) var_nnz = var_eval.sparsity().nnz() var_shape = var_eval.shape return var_nnz, var_shape, [base_var] else: # pragma: no cover raise pybamm.SolverError( f"Unsupported evaluation engine for convert_to_format=" f"{model.convert_to_format}" ) def _set_consistent_initialization(self, model, time, inputs_list): """ Initialize y0 and ydot0 for the solver. In addition to calculating y0 from BaseSolver, we also calculate ydot0 for semi-explicit DAEs Parameters ---------- model : :class:`pybamm.BaseModel` The model for which to calculate initial conditions. time : numeric type The time at which to calculate the initial conditions. inputs_list : list of dict Any input parameters to pass to the model when solving. 
""" # set model.y0_list super()._set_consistent_initialization(model, time, inputs_list) casadi_format = model.convert_to_format == "casadi" def handle_y0(y0): if isinstance(y0, casadi.DM): y0 = y0.full() return y0.flatten() y0_list = [handle_y0(y0) for y0 in model.y0_list] # calculate the time derivatives of the differential equations # for semi-explicit DAEs if model.len_rhs > 0: ydot0_list = [ self._rhs_dot_consistent_initialization(y0, model, time, inputs_dict) for y0, inputs_dict in zip(y0_list, inputs_list, strict=True) ] else: ydot0_list = [np.zeros_like(y0) for y0 in y0_list] sensitivity = model.y0S_list and casadi_format if sensitivity: y0S_list = model.y0S_list y0full = [] ydot0full = [] for y0, ydot0, y0S, inputs_dict in zip( y0_list, ydot0_list, y0S_list, inputs_list, strict=True ): y0f, ydot0f = self._sensitivity_consistent_initialization( y0, ydot0, y0S, time, inputs_dict ) y0full.append(y0f) ydot0full.append(ydot0f) else: y0full = y0_list ydot0full = ydot0_list model.y0full = y0full model.ydot0full = ydot0full def _rhs_dot_consistent_initialization(self, y0, model, time, inputs_dict): """ Compute the consistent initialization of ydot0 for the differential terms for the solver. If we have a semi-explicit DAE, we can explicitly solve for this value using the consistently initialized y0 vector. Parameters ---------- y0 : :class:`numpy.array` The initial values of the state vector. model : :class:`pybamm.BaseModel` The model for which to calculate initial conditions. time : numeric type The time at which to calculate the initial conditions. inputs_dict : dict Any input parameters to pass to the model when solving. 
""" casadi_format = model.convert_to_format == "casadi" inputs_dict = inputs_dict or {} # stack inputs if inputs_dict: arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()] inputs = np.vstack(arrays_to_stack) else: inputs = np.array([[]]) ydot0 = np.zeros_like(y0) # calculate the time derivatives of the differential equations input_eval = inputs if casadi_format else inputs_dict rhs0 = model.rhs_eval(time, y0, input_eval) if isinstance(rhs0, casadi.DM): rhs0 = rhs0.full() rhs0 = rhs0.flatten() # for the differential terms, ydot = M^-1 * (rhs) if model.is_standard_form_dae: # M^-1 is the identity matrix, so we can just use rhs ydot0[: model.len_rhs] = rhs0 else: # M^-1 is not the identity matrix, so we need to use the mass matrix M_ode = model.mass_matrix.entries[: model.len_rhs, : model.len_rhs] ydot0[: model.len_rhs] = spsolve(M_ode, rhs0) return ydot0 def _sensitivity_consistent_initialization(self, y0, ydot0, y0S, time, inputs_dict): """ Extend the consistent initialization to include the sensitivty equations Parameters ---------- y0 : :class:`numpy.array` The initial values of the state vector. ydot0 : :class:`numpy.array` The initial values of the time derivatives of the state vector. y0S : :class:`numpy.array` The initial values of the sensitivity state vectors. time : numeric type The time at which to calculate the initial conditions. inputs_dict : dict Any input parameters to pass to the model when solving. """ if isinstance(y0S, casadi.DM): y0S = (y0S,) if isinstance(y0S[0], casadi.DM): y0S = (x.full() for x in y0S) y0S = [x.flatten() for x in y0S] y0full = np.concatenate([y0, *y0S]) ydot0S = [np.zeros_like(y0S_i) for y0S_i in y0S] ydot0full = np.concatenate([ydot0, *ydot0S]) return y0full, ydot0full
[docs] def jaxify( self, model, t_eval, *, output_variables=None, calculate_sensitivities=True, t_interp=None, ): """JAXify the solver object Creates a JAX expression representing the IDAKLU-wrapped solver object. Parameters ---------- model : :class:`pybamm.BaseModel` The model to be solved t_eval : numeric type, optional The times at which to stop the integration due to a discontinuity in time. output_variables : list of str, optional The variables to be returned. If None, all variables in the model are used. calculate_sensitivities : bool, optional Whether to calculate sensitivities. Default is True. t_interp : None, list or ndarray, optional The times (in seconds) at which to interpolate the solution. Defaults to `None`, which returns the adaptive time-stepping times. """ obj = pybamm.IDAKLUJax( self, # IDAKLU solver instance model, t_eval, output_variables=output_variables, calculate_sensitivities=calculate_sensitivities, t_interp=t_interp, ) return obj
[docs] def reduce_solution( self, solution, hermite_reduction_factor=2.0, ) -> pybamm.Solution: """Reduce knots in a pybamm.Solution using Hermite spline compression. The multiplier M controls the total error budget relative to the solver's own WRMS tolerance. The knot reducer is allowed an additional WRMS error of (M-1), so the total error satisfies ``||e_total||_WRMS <= M``. M = 2 (default) means the reduced solution may have up to 2x the solver's local error. Parameters ---------- solution : :class:`pybamm.Solution` The solution to reduce. Must have hermite interpolation data (all_yps is not None). hermite_reduction_factor : float, optional Total error multiplier (>= 1.0, default 2.0). The knot reducer's WRMS threshold is N * (M-1)^2. Larger values allow more aggressive reduction. Returns ------- :class:`pybamm.Solution` The modified solution. """ if not solution.hermite_interpolation: raise pybamm.SolverError( "reduce_solution requires Hermite interpolation data (all_yps)." ) if self.options["hermite_reduction_factor"] != 1.0: raise pybamm.SolverError( "reduce_solution requires the original solver to have " "`hermite_reduction_factor = 1.0`" ) all_ts = solution.all_ts all_ys = solution.all_ys all_yps = solution.all_yps all_models = solution.all_models n_seg = len(all_ts) rtol = self.rtol atol = self.atol # Build flat time-major arrays for the C++ reducer. # all_ys[i] is (n_states, M) transposed view whose underlying # buffer is already time-major (M * n_states,). .T.ravel() # returns a 1D view of that buffer -- no copy. 
flat_ys = [all_ys[i].T.ravel() for i in range(n_seg)] flat_yps = [all_yps[i].T.ravel() for i in range(n_seg)] # Build per-segment atol vectors atol_vecs = [None] * n_seg for i, model in enumerate(all_models): atol_vecs[i] = self._check_atol_type(atol, model) ts_vec = idaklu.VectorRealtypeNdArray(all_ts) ys_vec = idaklu.VectorRealtypeNdArray(flat_ys) yps_vec = idaklu.VectorRealtypeNdArray(flat_yps) atols_vec = idaklu.VectorRealtypeNdArray(atol_vecs) t_evals_vec = idaklu.VectorRealtypeNdArray(solution.all_t_evals) red_ts, red_ys, red_yps = idaklu.reduce_knots( ts_vec, ys_vec, yps_vec, atols_vec, t_evals_vec, float(rtol), float(hermite_reduction_factor), ) # Reshape reduced flat arrays back to (n_states, K) convention new_ts = [np.asarray(red_ts[i]) for i in range(n_seg)] new_ys = [] new_yps = [] for i in range(n_seg): K = len(new_ts[i]) N = all_ys[i].shape[0] new_ys.append(np.asarray(red_ys[i]).reshape(K, N).T) new_yps.append(np.asarray(red_yps[i]).reshape(K, N).T) new_sol = pybamm.Solution( all_ts=new_ts, all_ys=new_ys, all_yps=new_yps, all_models=all_models, all_inputs=solution.all_inputs, t_event=solution.t_event, y_event=solution.y_event, termination=solution.termination, all_sensitivities=solution._all_sensitivities, all_t_evals=solution.all_t_evals, variables_returned=solution.variables_returned, options=solution.user_options, ) # Propagate metadata from the original solution new_sol._all_inputs_stacked = solution.all_inputs_stacked new_sol._all_inputs_casadi = solution.all_inputs_casadi new_sol.closest_event_idx = solution.closest_event_idx new_sol.solve_time = solution.solve_time new_sol.integration_time = solution.integration_time new_sol.set_up_time = solution.set_up_time return new_sol