# mypy: ignore-errors
import logging
import math
import numbers
import warnings
import casadi
import numpy as np
import pybammsolvers.idaklu as idaklu
from scipy.sparse.linalg import spsolve
import pybamm
from pybamm.codegen.compilation import aot_compile
class IDAKLUSolver(pybamm.BaseSolver):
    """
    Solve a discretised model, using sundials with the KLU sparse linear solver.

    Parameters
    ----------
    rtol : float, optional
        The relative tolerance for the solver (default is 1e-4).
    atol : float or numpy.ndarray, optional
        The absolute tolerance for the solver (default is 1e-6).
    root_method : str or pybamm algebraic solver class, optional
        The method to use to find initial conditions (for DAE solvers).
        If a solver class, must be an algebraic solver class.
        If "casadi",
        the solver uses casadi's Newton rootfinding algorithm to find initial
        conditions. Otherwise, the solver uses 'scipy.optimize.root' with method
        specified by 'root_method' (e.g. "lm", "hybr", ...)
    root_tol : float, optional
        The tolerance for the initial-condition solver (default is 1e-6).
    extrap_tol : float, optional
        The tolerance to assert whether extrapolation occurs or not (default is 0).
    on_extrapolation : str, optional
        What to do if the solver is extrapolating. Options are "warn", "error",
        or "ignore". Default is "warn".
    on_failure : str, optional
        What to do if a solver error flag occurs. Options are "warn", "error",
        or "ignore". Default is "error".
    output_variables : list[str], optional
        List of variables to calculate and return. If none are specified then
        the complete state vector is returned (can be very large) (default is [])
    options : dict, optional
        Additional options to pass to the solver, by default:

        .. code-block:: python

            options = {
                # Print statistics of the solver after every solve
                "print_stats": False,
                # If ``True``, ahead-of-time compile each casadi ``Function``
                # to a shared library via the system C compiler (see
                # :mod:`pybamm.codegen.compilation`). If ``False`` (default)
                # the casadi in-process virtual machine is used.
                "compile": False,
                # Number of threads available for OpenMP (must be greater than or equal to `num_solvers`)
                "num_threads": 1,
                # Number of solvers to use in parallel (for solving multiple sets of input parameters in parallel)
                "num_solvers": num_threads,
                ## Linear solver interface
                # name of sundials linear solver to use options are: "SUNLinSol_KLU",
                # "SUNLinSol_Dense", "SUNLinSol_Band", "SUNLinSol_SPBCGS",
                # "SUNLinSol_SPFGMR", "SUNLinSol_SPGMR", "SUNLinSol_SPTFQMR",
                "linear_solver": "SUNLinSol_KLU",
                # Jacobian form, can be "none", "dense",
                # "banded", "sparse", "matrix-free"
                "jacobian": "sparse",
                # Preconditioner for iterative solvers, can be "none", "BBDP"
                "preconditioner": "BBDP",
                # For iterative linear solver preconditioner, bandwidth of
                # approximate jacobian
                "precon_half_bandwidth": 5,
                # For iterative linear solver preconditioner, bandwidth of
                # approximate jacobian that is kept
                "precon_half_bandwidth_keep": 5,
                # For iterative linear solvers, max number of iterations
                "linsol_max_iterations": 5,
                # Ratio between linear and nonlinear tolerances
                "epsilon_linear_tolerance": 0.05,
                # Increment factor used in DQ Jacobian-vector product approximation
                "increment_factor": 1.0,
                # Enable or disable linear solution scaling
                "linear_solution_scaling": True,
                # Silence Sundials errors during solve
                "silence_sundials_errors": False,
                ## Main solver
                # Maximum order of the linear multistep method
                "max_order_bdf": 5,
                # Maximum number of steps to be taken by the solver in its attempt to
                # reach the next output time.
                # Note: this value differs from the IDA default of 500
                "max_num_steps": 100000,
                # Initial step size. The solver default is used if this is left at 0.0
                "dt_init": 0.0,
                # Minimum absolute step size. The solver default is used if this is
                # left at 0.0
                "dt_min": 0.0,
                # Maximum absolute step size. The solver default is used if this is
                # left at 0.0
                "dt_max": 0.0,
                # Maximum number of error test failures in attempting one step
                "max_error_test_failures": 10,
                # Maximum number of nonlinear solver iterations at one step
                # Note: this value differs from the IDA default of 4
                "max_nonlinear_iterations": 40,
                # Maximum number of nonlinear solver convergence failures at one step
                # Note: this value differs from the IDA default of 10
                "max_convergence_failures": 100,
                # Safety factor in the nonlinear convergence test
                "nonlinear_convergence_coefficient": 0.33,
                # Suppress algebraic variables from error test
                "suppress_algebraic_error": False,
                # Store Hermite interpolation data for the solution.
                # Note: this option is always disabled if output_variables are given
                # or if t_interp values are specified
                "hermite_interpolation": True,
                # Setting hermite_reduction_factor > 1.0 compresses the solution size
                # by introducing a small amount of error to the Hermite spline
                # interpolant. A value of `2.0` roughly corresponds to a maximum 2x
                # increase in error (practically the error is much smaller), while
                # reducing the number of saved states by around 5-6x. This option is
                # only active if `hermite_interpolation` is True and sensitivities
                # are disabled.
                "hermite_reduction_factor": 1.0,
                ## Initial conditions calculation
                # Positive constant in the Newton iteration convergence test within the
                # initial condition calculation
                "nonlinear_convergence_coefficient_ic": 0.0033,
                # Maximum number of steps allowed when `init_all_y_ic = False`
                # Note: this value differs from the IDA default of 5
                "max_num_steps_ic": 50,
                # Maximum number of the approximate Jacobian or preconditioner evaluations
                # allowed when the Newton iteration appears to be slowly converging
                # Note: this value differs from the IDA default of 4
                "max_num_jacobians_ic": 40,
                # Maximum number of Newton iterations allowed in any one attempt to solve
                # the initial conditions calculation problem
                # Note: this value differs from the IDA default of 10
                "max_num_iterations_ic": 100,
                # Maximum number of linesearch backtracks allowed in any Newton iteration,
                # when solving the initial conditions calculation problem
                "max_linesearch_backtracks_ic": 100,
                # Turn off linesearch
                "linesearch_off_ic": False,
                # How to calculate the initial conditions.
                # "True": calculate all y0 given ydot0
                # "False": calculate y_alg0 and ydot_diff0 given y_diff0
                "init_all_y_ic": False,
                # Calculate consistent initial conditions
                "calc_ic": True,
                ## Early termination
                # Maximum number of consecutive steps allowed without advancing
                # the solution time by at least `t_no_progress` seconds.
                # If set to 0, this feature is disabled.
                "num_steps_no_progress": 0,
                # Minimum required time advancement (in seconds) after
                # `num_steps_no_progress` consecutive steps.
                # If set to 0.0, this feature is disabled.
                "t_no_progress": 0.0,
            }

        Note: These options only have an effect if
        ``model.convert_to_format == 'casadi'``.
    """
def __init__(
self,
rtol=1e-4,
atol=1e-6,
root_method="casadi",
root_tol=1e-6,
extrap_tol=None,
on_extrapolation=None,
output_variables=None,
on_failure=None,
options=None,
):
self.output_variables = [] if output_variables is None else output_variables
self._options = self._combine_options(options)
super().__init__(
method="ida",
rtol=rtol,
atol=atol,
root_method=root_method,
root_tol=root_tol,
extrap_tol=extrap_tol,
output_variables=output_variables,
on_extrapolation=on_extrapolation,
on_failure=on_failure,
)
self.name = "IDA KLU solver"
self._supports_interp = True
self._supports_t_eval_discontinuities = True
pybamm.citations.register("Hindmarsh2000")
pybamm.citations.register("Hindmarsh2005")
def _combine_options(self, user_options: dict | None) -> dict:
num_solvers = user_options.get("num_threads", 1) if user_options else 1
default_options = {
"print_stats": False,
"compile": False,
"jacobian": "sparse",
"preconditioner": "BBDP",
"precon_half_bandwidth": 5,
"precon_half_bandwidth_keep": 5,
"num_threads": 1,
"num_solvers": num_solvers,
"linear_solver": "SUNLinSol_KLU",
"linsol_max_iterations": 5,
"epsilon_linear_tolerance": 0.05,
"increment_factor": 1.0,
"linear_solution_scaling": True,
"silence_sundials_errors": False,
"max_order_bdf": 5,
"max_num_steps": 100000,
"dt_init": 0.0,
"dt_min": 0.0,
"dt_max": 0.0,
"max_error_test_failures": 10,
"max_nonlinear_iterations": 40,
"max_convergence_failures": 100,
"nonlinear_convergence_coefficient": 0.33,
"suppress_algebraic_error": False,
"hermite_interpolation": True,
"hermite_reduction_factor": 1.0,
"nonlinear_convergence_coefficient_ic": 0.0033,
"max_num_steps_ic": 50,
"max_num_jacobians_ic": 40,
"max_num_iterations_ic": 100,
"max_linesearch_backtracks_ic": 100,
"linesearch_off_ic": False,
"init_all_y_ic": False,
"calc_ic": True,
"num_steps_no_progress": 0,
"t_no_progress": 0.0,
}
if not user_options:
return default_options
options = default_options | user_options
self._check_options(options)
return options
def _check_options(self, options: dict):
hermite_reduction_factor = options["hermite_reduction_factor"]
if hermite_reduction_factor > 1.0:
if self.output_variables:
raise pybamm.SolverError(
"hermite_reduction_factor cannot be used with "
"output_variables. Both are memory-saving options "
"that are mutually exclusive."
)
if not options["hermite_interpolation"]:
raise pybamm.SolverError(
"hermite_reduction_factor requires "
"hermite_interpolation to be enabled."
)
else:
if hermite_reduction_factor < 1.0:
raise pybamm.SolverError("hermite_reduction_factor must be >= 1.0.")
if not isinstance(options["compile"], bool):
raise pybamm.SolverError("compile must be a bool")
def _check_atol_type(self, atol, model):
if isinstance(atol, float):
return np.full(model.len_rhs_and_alg, atol)
elif isinstance(atol, np.ndarray):
return atol
else:
raise pybamm.SolverError(
"Absolute tolerances must be a numpy array or float"
)
[docs]
def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
base_set_up_return = super().set_up(model, inputs, t_eval, ics_only)
if isinstance(inputs, list):
# for setup, just use first input dict
inputs_dict = inputs[0]
else:
inputs_dict = inputs or {}
# stack inputs
if inputs_dict:
arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()]
stacked_inputs = np.vstack(arrays_to_stack)
else:
stacked_inputs = np.array([[]])
y0 = model.y0_list[0]
if isinstance(y0, casadi.DM):
y0 = y0.full()
y0 = y0.flatten()
if ics_only:
return base_set_up_return
if model.convert_to_format != "casadi":
msg = "The python-idaklu solver has been deprecated."
warnings.warn(msg, DeprecationWarning, stacklevel=2)
raise pybamm.SolverError(
f"Unsupported option for convert_to_format={model.convert_to_format}"
)
if self._options["jacobian"] == "dense":
mass_matrix = casadi.DM(model.mass_matrix.entries.toarray())
else:
mass_matrix = casadi.DM(model.mass_matrix.entries)
# construct residuals function by binding inputs
# TODO: do we need densify here?
rhs_algebraic = model.rhs_algebraic_eval
if not model.use_jacobian:
raise pybamm.SolverError("KLU requires the Jacobian")
# need to provide jacobian_rhs_alg - cj * mass_matrix
t_casadi = casadi.MX.sym("t")
y_casadi = casadi.MX.sym("y", model.len_rhs_and_alg)
cj_casadi = casadi.MX.sym("cj")
p_casadi = {}
for name, value in inputs_dict.items():
if isinstance(value, numbers.Number):
p_casadi[name] = casadi.MX.sym(name)
else:
p_casadi[name] = casadi.MX.sym(name, value.shape[0])
p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()])
jac_times_cjmass = casadi.Function(
"jac_times_cjmass",
[t_casadi, y_casadi, p_casadi_stacked, cj_casadi],
[
model.jac_rhs_algebraic_eval(t_casadi, y_casadi, p_casadi_stacked)
- cj_casadi * mass_matrix
],
)
jac_times_cjmass_sparsity = jac_times_cjmass.sparsity_out(0)
jac_bw_lower = jac_times_cjmass_sparsity.bw_lower()
jac_bw_upper = jac_times_cjmass_sparsity.bw_upper()
jac_times_cjmass_nnz = jac_times_cjmass_sparsity.nnz()
jac_times_cjmass_colptrs = np.array(
jac_times_cjmass_sparsity.colind(), dtype=np.int64
)
jac_times_cjmass_rowvals = np.array(
jac_times_cjmass_sparsity.row(), dtype=np.int64
)
v_casadi = casadi.MX.sym("v", model.len_rhs_and_alg)
jac_rhs_algebraic_action = model.jac_rhs_algebraic_action_eval
# also need the action of the mass matrix on a vector
mass_action = casadi.Function(
"mass_action", [v_casadi], [casadi.densify(mass_matrix @ v_casadi)]
)
num_of_events = len(model.terminate_events_eval)
# rootfn needs to return an array of length num_of_events
rootfn = casadi.Function(
"rootfn",
[t_casadi, y_casadi, p_casadi_stacked],
[
casadi.vertcat(
*[
event(t_casadi, y_casadi, p_casadi_stacked)
for event in model.terminate_events_eval
]
)
],
)
# get ids of rhs and algebraic variables
rhs_ids = np.ones(model.rhs_eval(0, y0, stacked_inputs).shape[0])
alg_ids = np.zeros(len(y0) - len(rhs_ids))
ids = np.concatenate((rhs_ids, alg_ids))
if model.jacp_rhs_algebraic_eval is not None:
sensitivity_names = model.calculate_sensitivities
if model.convert_to_format == "casadi":
number_of_sensitivity_parameters = model.jacp_rhs_algebraic_eval.n_out()
else:
number_of_sensitivity_parameters = len(sensitivity_names)
else:
number_of_sensitivity_parameters = 0
sensitivity_names = []
# for the casadi solver we just give it dFdp_i
if model.jacp_rhs_algebraic_eval is None:
sensfn = casadi.Function("sensfn", [], [])
else:
sensfn = model.jacp_rhs_algebraic_eval
atol = getattr(model, "atol", self.atol)
atol = self._check_atol_type(atol, model)
if (
self._options["hermite_reduction_factor"] > 1.0
and number_of_sensitivity_parameters > 0
):
warnings.warn(
"Setting hermite_reduction_factor > 1.0 is not currently supported "
"with sensitivities. The hermite_reduction_factor option will be "
"ignored.",
pybamm.SolverWarning,
stacklevel=2,
)
# Collect all casadi functions for AOT compilation. With compile=True,
# all functions are bundled into a single shared library via one gcc
# invocation; each serialized Function then points at its entry point.
has_sens = (len(stacked_inputs) > 0) and model.calculate_sensitivities
solver_fn_names = [
"rhs_algebraic",
"jac_times_cjmass",
"jac_rhs_algebraic_action",
"mass_action",
"sensfn",
"rootfn",
]
fns = {
"rhs_algebraic": rhs_algebraic,
"jac_times_cjmass": jac_times_cjmass,
"jac_rhs_algebraic_action": jac_rhs_algebraic_action,
"mass_action": mass_action,
"sensfn": sensfn,
"rootfn": rootfn,
}
for key in self.output_variables:
fns[f"var:{key}"] = self.computed_var_fcns[key]
if has_sens:
fns[f"dvar_dy:{key}"] = self.computed_dvar_dy_fcns[key]
fns[f"dvar_dp:{key}"] = self.computed_dvar_dp_fcns[key]
if self._options["compile"]:
compiled = aot_compile(list(fns.values()))
fns = dict(zip(fns.keys(), compiled, strict=True))
# Build _setup dict with serialized functions for idaklu C++ wrapper
def to_idaklu(fn):
pkl = fn.serialize()
return idaklu.generate_function(pkl), pkl
self._setup = {
"number_of_states": len(y0),
"inputs": len(stacked_inputs),
"jac_bandwidth_upper": jac_bw_upper,
"jac_bandwidth_lower": jac_bw_lower,
"jac_times_cjmass_colptrs": jac_times_cjmass_colptrs,
"jac_times_cjmass_rowvals": jac_times_cjmass_rowvals,
"jac_times_cjmass_nnz": jac_times_cjmass_nnz,
"num_of_events": num_of_events,
"ids": ids,
"atol": atol,
"sensitivity_names": sensitivity_names,
"number_of_sensitivity_parameters": number_of_sensitivity_parameters,
"standard_form_dae": model.is_standard_form_dae,
"output_variables": self.output_variables,
"var_fcns": self.computed_var_fcns,
"var_idaklu_fcns": [],
"var_idaklu_fcns_pkl": [],
"dvar_dy_idaklu_fcns": [],
"dvar_dy_idaklu_fcns_pkl": [],
"dvar_dp_idaklu_fcns": [],
"dvar_dp_idaklu_fcns_pkl": [],
}
for name in solver_fn_names:
fn, pkl = to_idaklu(fns[name])
self._setup[name] = fn
self._setup[f"{name}_pkl"] = pkl
for key in self.output_variables:
fn, pkl = to_idaklu(fns[f"var:{key}"])
self._setup["var_idaklu_fcns"].append(fn)
self._setup["var_idaklu_fcns_pkl"].append(pkl)
if has_sens:
fn, pkl = to_idaklu(fns[f"dvar_dy:{key}"])
self._setup["dvar_dy_idaklu_fcns"].append(fn)
self._setup["dvar_dy_idaklu_fcns_pkl"].append(pkl)
fn, pkl = to_idaklu(fns[f"dvar_dp:{key}"])
self._setup["dvar_dp_idaklu_fcns"].append(fn)
self._setup["dvar_dp_idaklu_fcns_pkl"].append(pkl)
self._create_solver()
return base_set_up_return
def _create_solver(self):
"""Create the idaklu solver group from _setup. Used by set_up and __setstate__."""
self._setup["solver"] = idaklu.create_casadi_solver_group(
number_of_states=self._setup["number_of_states"],
number_of_parameters=self._setup["number_of_sensitivity_parameters"],
rhs_alg=self._setup["rhs_algebraic"],
jac_times_cjmass=self._setup["jac_times_cjmass"],
jac_times_cjmass_colptrs=self._setup["jac_times_cjmass_colptrs"],
jac_times_cjmass_rowvals=self._setup["jac_times_cjmass_rowvals"],
jac_times_cjmass_nnz=self._setup["jac_times_cjmass_nnz"],
jac_bandwidth_lower=self._setup["jac_bandwidth_lower"],
jac_bandwidth_upper=self._setup["jac_bandwidth_upper"],
jac_action=self._setup["jac_rhs_algebraic_action"],
mass_action=self._setup["mass_action"],
sens=self._setup["sensfn"],
events=self._setup["rootfn"],
number_of_events=self._setup["num_of_events"],
rhs_alg_id=self._setup["ids"],
atol=self._setup["atol"],
rtol=self.rtol,
inputs=self._setup["inputs"],
var_fcns=self._setup["var_idaklu_fcns"],
dvar_dy_fcns=self._setup["dvar_dy_idaklu_fcns"],
dvar_dp_fcns=self._setup["dvar_dp_idaklu_fcns"],
options=self._options,
)
def __getstate__(self):
# if _setup is not defined then we haven't called set_up yet
if not hasattr(self, "_setup"):
return self.__dict__
for key in [
"solver",
"rhs_algebraic",
"jac_times_cjmass",
"jac_rhs_algebraic_action",
"mass_action",
"sensfn",
"rootfn",
"var_idaklu_fcns",
"dvar_dy_idaklu_fcns",
"dvar_dp_idaklu_fcns",
]:
self._setup.pop(key, None)
return self.__dict__
def __setstate__(self, d):
self.__dict__.update(d)
# if _setup is not defined then we haven't called set_up yet
if not hasattr(self, "_setup"):
return
for key in [
"rhs_algebraic",
"jac_times_cjmass",
"jac_rhs_algebraic_action",
"mass_action",
"sensfn",
"rootfn",
]:
self._setup[key] = idaklu.generate_function(self._setup[f"{key}_pkl"])
for key in ["var_idaklu_fcns", "dvar_dy_idaklu_fcns", "dvar_dp_idaklu_fcns"]:
self._setup[key] = [
idaklu.generate_function(f) for f in self._setup[f"{key}_pkl"]
]
self._create_solver()
@property
def options(self):
return self._options
def _integrate(
self,
model,
t_eval,
inputs_list: list[dict] | None = None,
t_interp=None,
nproc=None,
):
"""
Overloads the _integrate method from BaseSolver to use the IDAKLU solver
"""
if model.convert_to_format != "casadi": # pragma: no cover
raise pybamm.SolverError("Unsupported IDAKLU solver configuration.")
inputs_list = inputs_list or [{}]
# stack inputs so that they are a 2D array of shape (number_of_inputs, number_of_parameters)
if inputs_list and inputs_list[0]:
inputs = np.vstack(
[
np.hstack([np.array(x).reshape(-1) for x in inputs_dict.values()])
for inputs_dict in inputs_list
]
)
else:
inputs = np.array([[]] * len(inputs_list))
# y0full is now a list with length = number of input sets
y0full = np.vstack(model.y0full)
ydot0full = np.vstack(model.ydot0full)
atol = getattr(model, "atol", self.atol)
atol = self._check_atol_type(atol, model)
logger = (
pybamm.logger.debug if pybamm.logger.isEnabledFor(logging.DEBUG) else None
)
timer = pybamm.Timer()
try:
solns = self._setup["solver"].solve(
t_eval,
t_interp,
y0full,
ydot0full,
inputs,
logger=logger,
)
except ValueError as e:
# Return from None to replace the C++ runtime error
raise pybamm.SolverError(str(e)) from None
integration_time = timer.time()
return [
self._post_process_solution(
soln, model, integration_time, inputs_dict, t_eval
)
for soln, inputs_dict in zip(solns, inputs_list, strict=False)
]
def _post_process_solution(self, sol, model, integration_time, inputs_dict, t_eval):
number_of_sensitivity_parameters = self._setup[
"number_of_sensitivity_parameters"
]
sensitivity_names = self._setup["sensitivity_names"]
number_of_timesteps = sol.t.size
number_of_states = model.len_rhs_and_alg
save_outputs_only = self.output_variables
if save_outputs_only:
# Substitute empty vectors for state vector 'y'
y_out = np.zeros((number_of_timesteps * number_of_states, 0))
y_event = sol.y_term
else:
y_out = sol.y.reshape((number_of_timesteps, number_of_states))
y_event = y_out[-1]
# return sensitivity solution, we need to flatten yS to
# (#timesteps * #states (where t is changing the quickest),)
# to match format used by Solution
# note that yS is (n_p, n_t, n_y)
if number_of_sensitivity_parameters != 0:
yS_out = {
name: sol.yS[i].reshape(-1, 1)
for i, name in enumerate(sensitivity_names)
}
# add "all" stacked sensitivities ((#timesteps * #states,#sens_params))
yS_out["all"] = np.hstack([yS_out[name] for name in sensitivity_names])
else:
yS_out = {}
# 0 = solved for all t_eval
# 2 = found root(s)
# < 0 = solver failure
if sol.flag == 2:
termination = "event"
elif sol.flag >= 0:
termination = "final time"
elif sol.flag < 0:
termination = "failure"
msg = idaklu.sundials_error_message(sol.flag)
match self._on_failure:
case "warn":
warnings.warn(
msg + ", returning a partial solution.",
stacklevel=2,
)
case "error":
raise pybamm.SolverError(msg)
if sol.yp.size > 0:
yp = sol.yp.reshape((number_of_timesteps, number_of_states)).T
else:
yp = None
t = sol.t
t_eval = np.array(t_eval)
if t[-1] != t_eval[-1]:
idx_final = np.searchsorted(t_eval, t[-1]) + 1
t_eval = t_eval[:idx_final]
# the final index may differ due to an event;
# manually set it to the true final time
t_eval[-1] = t[-1]
# Forward the compile flag so post-solve observation uses the same
# backend as the integration.
solution_options = {"compile": self._options["compile"]}
newsol = pybamm.Solution(
t,
np.transpose(y_out),
model,
inputs_dict,
np.array([sol.t[-1]]),
np.transpose(y_event)[:, np.newaxis],
termination,
all_sensitivities=yS_out,
all_yps=yp,
all_t_evals=t_eval,
variables_returned=bool(save_outputs_only),
options=solution_options,
)
newsol.integration_time = integration_time
if not save_outputs_only:
return newsol
# Populate variables and sensitivities dictionaries directly
number_of_samples = sol.y.shape[0] // number_of_timesteps
sol.y = sol.y.reshape((number_of_timesteps, number_of_samples))
sensitivity_params = (
model.calculate_sensitivities if model.calculate_sensitivities else []
)
start_idx = 0
for var in self.output_variables:
var_nnz, var_shape, base_variables = self._get_variable_info(model, var)
end_idx = start_idx + var_nnz
data = sol.y[:, start_idx:end_idx]
time_indep = False
# handle any time integral variables
if var in self._time_integral_vars:
# time integral variables should all be 1D
tiv = self._time_integral_vars[var]
data = tiv.postfix(data.reshape(-1), sol.t, inputs_dict)
time_indep = True
newsol._variables[var] = pybamm.ProcessedVariableComputed(
[model.get_processed_variable_or_event(var)],
base_variables,
[data],
newsol,
time_indep=time_indep,
)
# Add sensitivities
newsol[var]._sensitivities = {}
if sensitivity_params:
if var_nnz != math.prod(var_shape):
raise pybamm.SolverError(
f"Sensitivity of sparse variables not supported. {var} is a sparse variable with number of non-zeros {var_nnz} and shape {var_shape}"
)
sens_data = sol.yS[:, start_idx:end_idx, :]
sens_data = sens_data.reshape(
number_of_timesteps * (end_idx - start_idx),
number_of_sensitivity_parameters,
)
if var in self._time_integral_vars:
tiv = self._time_integral_vars[var]
sens_data = tiv.postfix_sensitivities(
var, data, sol.t, inputs_dict, sens_data
)
newsol[var]._sensitivities["all"] = sens_data
# Add the individual sensitivity
for i, name in enumerate(inputs_dict.keys()):
sens = newsol[var]._sensitivities["all"][:, i : i + 1].reshape(-1)
newsol[var]._sensitivities[name] = sens
start_idx += var_nnz
return newsol
def _get_variable_info(self, model, var) -> tuple:
"""Get variable length and base variables based on model format."""
if model.convert_to_format == "casadi":
base_var = self._setup["var_fcns"][var]
var_eval = base_var(0.0, 0.0, 0.0)
var_nnz = var_eval.sparsity().nnz()
var_shape = var_eval.shape
return var_nnz, var_shape, [base_var]
else: # pragma: no cover
raise pybamm.SolverError(
f"Unsupported evaluation engine for convert_to_format="
f"{model.convert_to_format}"
)
def _set_consistent_initialization(self, model, time, inputs_list):
"""
Initialize y0 and ydot0 for the solver. In addition to calculating
y0 from BaseSolver, we also calculate ydot0 for semi-explicit DAEs
Parameters
----------
model : :class:`pybamm.BaseModel`
The model for which to calculate initial conditions.
time : numeric type
The time at which to calculate the initial conditions.
inputs_list : list of dict
Any input parameters to pass to the model when solving.
"""
# set model.y0_list
super()._set_consistent_initialization(model, time, inputs_list)
casadi_format = model.convert_to_format == "casadi"
def handle_y0(y0):
if isinstance(y0, casadi.DM):
y0 = y0.full()
return y0.flatten()
y0_list = [handle_y0(y0) for y0 in model.y0_list]
# calculate the time derivatives of the differential equations
# for semi-explicit DAEs
if model.len_rhs > 0:
ydot0_list = [
self._rhs_dot_consistent_initialization(y0, model, time, inputs_dict)
for y0, inputs_dict in zip(y0_list, inputs_list, strict=True)
]
else:
ydot0_list = [np.zeros_like(y0) for y0 in y0_list]
sensitivity = model.y0S_list and casadi_format
if sensitivity:
y0S_list = model.y0S_list
y0full = []
ydot0full = []
for y0, ydot0, y0S, inputs_dict in zip(
y0_list, ydot0_list, y0S_list, inputs_list, strict=True
):
y0f, ydot0f = self._sensitivity_consistent_initialization(
y0, ydot0, y0S, time, inputs_dict
)
y0full.append(y0f)
ydot0full.append(ydot0f)
else:
y0full = y0_list
ydot0full = ydot0_list
model.y0full = y0full
model.ydot0full = ydot0full
def _rhs_dot_consistent_initialization(self, y0, model, time, inputs_dict):
"""
Compute the consistent initialization of ydot0 for the differential terms
for the solver. If we have a semi-explicit DAE, we can explicitly solve
for this value using the consistently initialized y0 vector.
Parameters
----------
y0 : :class:`numpy.array`
The initial values of the state vector.
model : :class:`pybamm.BaseModel`
The model for which to calculate initial conditions.
time : numeric type
The time at which to calculate the initial conditions.
inputs_dict : dict
Any input parameters to pass to the model when solving.
"""
casadi_format = model.convert_to_format == "casadi"
inputs_dict = inputs_dict or {}
# stack inputs
if inputs_dict:
arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()]
inputs = np.vstack(arrays_to_stack)
else:
inputs = np.array([[]])
ydot0 = np.zeros_like(y0)
# calculate the time derivatives of the differential equations
input_eval = inputs if casadi_format else inputs_dict
rhs0 = model.rhs_eval(time, y0, input_eval)
if isinstance(rhs0, casadi.DM):
rhs0 = rhs0.full()
rhs0 = rhs0.flatten()
# for the differential terms, ydot = M^-1 * (rhs)
if model.is_standard_form_dae:
# M^-1 is the identity matrix, so we can just use rhs
ydot0[: model.len_rhs] = rhs0
else:
# M^-1 is not the identity matrix, so we need to use the mass matrix
M_ode = model.mass_matrix.entries[: model.len_rhs, : model.len_rhs]
ydot0[: model.len_rhs] = spsolve(M_ode, rhs0)
return ydot0
def _sensitivity_consistent_initialization(self, y0, ydot0, y0S, time, inputs_dict):
"""
Extend the consistent initialization to include the sensitivty equations
Parameters
----------
y0 : :class:`numpy.array`
The initial values of the state vector.
ydot0 : :class:`numpy.array`
The initial values of the time derivatives of the state vector.
y0S : :class:`numpy.array`
The initial values of the sensitivity state vectors.
time : numeric type
The time at which to calculate the initial conditions.
inputs_dict : dict
Any input parameters to pass to the model when solving.
"""
if isinstance(y0S, casadi.DM):
y0S = (y0S,)
if isinstance(y0S[0], casadi.DM):
y0S = (x.full() for x in y0S)
y0S = [x.flatten() for x in y0S]
y0full = np.concatenate([y0, *y0S])
ydot0S = [np.zeros_like(y0S_i) for y0S_i in y0S]
ydot0full = np.concatenate([ydot0, *ydot0S])
return y0full, ydot0full
[docs]
def jaxify(
self,
model,
t_eval,
*,
output_variables=None,
calculate_sensitivities=True,
t_interp=None,
):
"""JAXify the solver object
Creates a JAX expression representing the IDAKLU-wrapped solver
object.
Parameters
----------
model : :class:`pybamm.BaseModel`
The model to be solved
t_eval : numeric type, optional
The times at which to stop the integration due to a discontinuity in time.
output_variables : list of str, optional
The variables to be returned. If None, all variables in the model are used.
calculate_sensitivities : bool, optional
Whether to calculate sensitivities. Default is True.
t_interp : None, list or ndarray, optional
The times (in seconds) at which to interpolate the solution. Defaults to `None`,
which returns the adaptive time-stepping times.
"""
obj = pybamm.IDAKLUJax(
self, # IDAKLU solver instance
model,
t_eval,
output_variables=output_variables,
calculate_sensitivities=calculate_sensitivities,
t_interp=t_interp,
)
return obj
[docs]
def reduce_solution(
self,
solution,
hermite_reduction_factor=2.0,
) -> pybamm.Solution:
"""Reduce knots in a pybamm.Solution using Hermite spline compression.
The multiplier M controls the total error budget relative to the
solver's own WRMS tolerance. The knot reducer is allowed an
additional WRMS error of (M-1), so the total error satisfies
``||e_total||_WRMS <= M``. M = 2 (default) means the reduced
solution may have up to 2x the solver's local error.
Parameters
----------
solution : :class:`pybamm.Solution`
The solution to reduce. Must have hermite interpolation data
(all_yps is not None).
hermite_reduction_factor : float, optional
Total error multiplier (>= 1.0, default 2.0). The knot
reducer's WRMS threshold is N * (M-1)^2. Larger values
allow more aggressive reduction.
Returns
-------
:class:`pybamm.Solution`
The modified solution.
"""
if not solution.hermite_interpolation:
raise pybamm.SolverError(
"reduce_solution requires Hermite interpolation data (all_yps)."
)
if self.options["hermite_reduction_factor"] != 1.0:
raise pybamm.SolverError(
"reduce_solution requires the original solver to have "
"`hermite_reduction_factor = 1.0`"
)
all_ts = solution.all_ts
all_ys = solution.all_ys
all_yps = solution.all_yps
all_models = solution.all_models
n_seg = len(all_ts)
rtol = self.rtol
atol = self.atol
# Build flat time-major arrays for the C++ reducer.
# all_ys[i] is (n_states, M) transposed view whose underlying
# buffer is already time-major (M * n_states,). .T.ravel()
# returns a 1D view of that buffer -- no copy.
flat_ys = [all_ys[i].T.ravel() for i in range(n_seg)]
flat_yps = [all_yps[i].T.ravel() for i in range(n_seg)]
# Build per-segment atol vectors
atol_vecs = [None] * n_seg
for i, model in enumerate(all_models):
atol_vecs[i] = self._check_atol_type(atol, model)
ts_vec = idaklu.VectorRealtypeNdArray(all_ts)
ys_vec = idaklu.VectorRealtypeNdArray(flat_ys)
yps_vec = idaklu.VectorRealtypeNdArray(flat_yps)
atols_vec = idaklu.VectorRealtypeNdArray(atol_vecs)
t_evals_vec = idaklu.VectorRealtypeNdArray(solution.all_t_evals)
red_ts, red_ys, red_yps = idaklu.reduce_knots(
ts_vec,
ys_vec,
yps_vec,
atols_vec,
t_evals_vec,
float(rtol),
float(hermite_reduction_factor),
)
# Reshape reduced flat arrays back to (n_states, K) convention
new_ts = [np.asarray(red_ts[i]) for i in range(n_seg)]
new_ys = []
new_yps = []
for i in range(n_seg):
K = len(new_ts[i])
N = all_ys[i].shape[0]
new_ys.append(np.asarray(red_ys[i]).reshape(K, N).T)
new_yps.append(np.asarray(red_yps[i]).reshape(K, N).T)
new_sol = pybamm.Solution(
all_ts=new_ts,
all_ys=new_ys,
all_yps=new_yps,
all_models=all_models,
all_inputs=solution.all_inputs,
t_event=solution.t_event,
y_event=solution.y_event,
termination=solution.termination,
all_sensitivities=solution._all_sensitivities,
all_t_evals=solution.all_t_evals,
variables_returned=solution.variables_returned,
options=solution.user_options,
)
# Propagate metadata from the original solution
new_sol._all_inputs_stacked = solution.all_inputs_stacked
new_sol._all_inputs_casadi = solution.all_inputs_casadi
new_sol.closest_event_idx = solution.closest_event_idx
new_sol.solve_time = solution.solve_time
new_sol.integration_time = solution.integration_time
new_sol.set_up_time = solution.set_up_time
return new_sol