#
# CasADi Solver class
#
import casadi
import pybamm
import numpy as np
from scipy.interpolate import interp1d
from scipy.optimize import brentq
class CasadiSolver(pybamm.BaseSolver):
    """Solve a discretised model, using CasADi.

    **Extends**: :class:`pybamm.BaseSolver`

    Parameters
    ----------
    mode : str
        How to solve the model (default is "safe"):

        - "fast": perform direct integration, without accounting for events.
          Recommended when simulating a drive cycle or other simulation where
          no events should be triggered.
        - "safe": perform step-and-check integration in global steps of size
          dt_max, checking whether events have been triggered. Recommended for
          simulations of a full charge or discharge.
        - "old safe": perform step-and-check integration in steps of size dt
          for each dt in t_eval, checking whether events have been triggered.
    rtol : float, optional
        The relative tolerance for the solver (default is 1e-6).
    atol : float, optional
        The absolute tolerance for the solver (default is 1e-6).
    root_method : str or pybamm algebraic solver class, optional
        The method to use to find initial conditions (for DAE solvers).
        If a solver class, must be an algebraic solver class.
        If "casadi", the solver uses casadi's Newton rootfinding algorithm to
        find initial conditions. Otherwise, the solver uses
        'scipy.optimize.root' with method specified by 'root_method'
        (e.g. "lm", "hybr", ...)
    root_tol : float, optional
        The tolerance for root-finding. Default is 1e-6.
    max_step_decrease_count : int, optional
        The maximum number of times step size can be decreased before an error
        is raised. Default is 5.
    dt_max : float, optional
        The maximum global step size (in seconds) used in "safe" mode. If None
        the default value corresponds to a non-dimensional time of 0.01
        (i.e. ``0.01 * model.timescale_eval``).
    extra_options_setup : dict, optional
        Any options to pass to the CasADi integrator when creating the
        integrator. Please consult `CasADi documentation
        <https://tinyurl.com/y5rk76os>`_ for details. Some typical options:

        - "max_num_steps": Maximum number of integrator steps
    extra_options_call : dict, optional
        Any options to pass to the CasADi integrator when calling the
        integrator. Please consult `CasADi documentation
        <https://tinyurl.com/y5rk76os>`_ for details.

    Raises
    ------
    ValueError
        If ``mode`` is not one of "safe", "fast" or "old safe".
    """

    def __init__(
        self,
        mode="safe",
        rtol=1e-6,
        atol=1e-6,
        root_method="casadi",
        root_tol=1e-6,
        max_step_decrease_count=5,
        dt_max=None,
        extra_options_setup=None,
        extra_options_call=None,
    ):
        super().__init__("problem dependent", rtol, atol, root_method, root_tol)
        if mode not in ("safe", "fast", "old safe"):
            raise ValueError(
                """
                invalid mode '{}'. Must be either 'safe' or 'old safe', for solving
                with events, or 'fast', for solving quickly without events""".format(
                    mode
                )
            )
        self.mode = mode
        self.max_step_decrease_count = max_step_decrease_count
        self.dt_max = dt_max
        # Fresh dicts per instance so no mutable default is ever shared
        self.extra_options_setup = extra_options_setup or {}
        self.extra_options_call = extra_options_call or {}
        self.name = "CasADi solver with '{}' mode".format(mode)

        # Per-model caches filled in by get_integrator
        self.problems = {}
        self.options = {}
        self.methods = {}

        pybamm.citations.register("Andersson2019")

    def _integrate(self, model, t_eval, inputs=None):
        """
        Solve a DAE model defined by residuals with initial conditions y0.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            The model whose solution to calculate.
        t_eval : numeric type
            The times at which to compute the solution
        inputs : dict, optional
            Any external variables or input parameters to pass to the model
            when solving

        Returns
        -------
        :class:`pybamm.Solution`
            The computed solution. In the step-and-check modes the solution
            stops early, with ``termination = "event"``, if an event is hit.

        Raises
        ------
        :class:`pybamm.SolverError`
            If the integrator fails, or (step-and-check modes) if the step
            size had to be decreased ``max_step_decrease_count`` times.
        """
        inputs = inputs or {}
        # convert inputs to casadi format
        inputs = casadi.vertcat(*inputs.values())

        if self.mode == "fast" or not model.events:
            if self.mode != "fast":
                pybamm.logger.info("No events found, running fast mode")
            integrator = self.get_integrator(model, t_eval, inputs)
            solution = self._run_integrator(
                integrator, model, model.y0, inputs, t_eval
            )
            solution.termination = "final time"
            return solution
        elif self.mode == "safe":
            return self._integrate_safe(model, t_eval, inputs)
        elif self.mode == "old safe":
            return self._integrate_old_safe(model, t_eval, inputs)

    @staticmethod
    def _initial_state(model):
        """Return ``model.y0``, flattened to a numpy array if it is a casadi DM."""
        y0 = model.y0
        if isinstance(y0, casadi.DM):
            y0 = y0.full().flatten()
        return y0

    @staticmethod
    def _event_signs(model, t, y, inputs):
        """Return the sign of every termination event evaluated at (t, y)."""
        return np.sign(
            np.concatenate(
                [event(t, y, inputs) for event in model.terminate_events_eval]
            )
        )

    @staticmethod
    def _locate_first_event(active_events, step_sol, inputs):
        """
        Find the earliest time at which one of ``active_events`` is triggered
        inside the step ``step_sol``.

        The state is linearly interpolated over the step and each event is
        located with a bracketed root search. Returns ``(t_event, y_event)``.
        """
        # interpolant to evaluate y anywhere in the current integration window
        y_sol = interp1d(step_sol.t, step_sol.y)
        t_events = []
        for event in active_events:

            def event_fun(t, event=event):
                return event(t, y_sol(t), inputs)

            if np.isnan(event_fun(step_sol.t[-1])[0]):
                # bracketed search fails if f(a) or f(b) is NaN, so keep only
                # the times for which the event can be evaluated
                times = [
                    t_i for t_i in step_sol.t if not np.isnan(event_fun(t_i)[0])
                ]
            else:
                times = step_sol.t
            # only search if the sign actually changes over the bracket
            if np.sign(event_fun(times[0])) != np.sign(event_fun(times[-1])):
                t_events.append(brentq(event_fun, times[0], times[-1]))
            else:
                t_events.append(np.nan)
        # the event that matters is the earliest one triggered
        t_event = np.nanmin(t_events)
        return t_event, y_sol(t_event)

    def _integrate_safe(self, model, t_eval, inputs):
        """
        Step-and-check integration in global steps of (at most) dt_max.

        After every step the termination events are checked; if one changed
        sign the event time is located, the step is truncated just before it
        and the solution is returned with ``termination = "event"``.
        """
        y0 = self._initial_state(model)
        t = t_eval[0]
        t_f = t_eval[-1]
        init_event_signs = self._event_signs(model, t, y0, inputs)
        pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))

        # Initialize solution
        solution = pybamm.Solution(np.array([t]), y0[:, np.newaxis])
        solution.solve_time = 0

        # Try to integrate in global steps of size dt_max. Note: dt_max must
        # be at least as big as the biggest step in t_eval (multiplied by some
        # tolerance, here 1.01) to avoid an empty integration window below
        if self.dt_max:
            # Non-dimensionalise provided dt_max
            dt_max = self.dt_max / model.timescale_eval
        else:
            dt_max = 0.01
        dt_eval_max = np.max(np.diff(t_eval)) * 1.01
        dt_max = np.max([dt_max, dt_eval_max])

        while t < t_f:
            solved = False
            count = 0
            dt = dt_max
            while not solved:
                # Get window of time to integrate over (so that we return
                # all the points in t_eval, not just t and t+dt)
                t_window = np.concatenate(
                    ([t], t_eval[(t_eval > t) & (t_eval < t + dt)])
                )
                # Sometimes near events the solver fails between two time
                # points in t_eval (i.e. no points t < t_i < t+dt for t_i
                # in t_eval), so we simply integrate from t to t+dt
                if len(t_window) == 1:
                    t_window = np.array([t, t + dt])
                integrator = self.get_integrator(model, t_window, inputs)
                # Try to solve with the current global step, if it fails then
                # halve the step size and try again.
                try:
                    current_step_sol = self._run_integrator(
                        integrator, model, y0, inputs, t_window
                    )
                    solved = True
                except pybamm.SolverError as err:
                    dt /= 2
                    # also reduce maximum step size for future global steps
                    dt_max = dt
                    # only failed attempts count towards the limit, so a
                    # successful final attempt is never thrown away
                    count += 1
                    if count >= self.max_step_decrease_count:
                        raise pybamm.SolverError(
                            """
                            Maximum number of decreased steps occurred at t={}. Try
                            solving the model up to this time only or reducing dt_max.
                            """.format(
                                t
                            )
                        ) from err

            # Check most recent y to see if any events have been crossed. The
            # events are evaluated at the end-of-step time so that time and
            # state belong to the same point.
            new_event_signs = self._event_signs(
                model, t_window[-1], current_step_sol.y[:, -1], inputs
            )
            crossed = new_event_signs != init_event_signs
            if crossed.any():
                # get the events that have been crossed
                # NOTE(review): indexing events by position in the concatenated
                # sign vector assumes each event returns one value - confirm
                event_ind = np.where(crossed)[0]
                active_events = [model.terminate_events_eval[i] for i in event_ind]
                t_event, y_event = self._locate_first_event(
                    active_events, current_step_sol, inputs
                )
                # keep only the times strictly before the event
                t_truncated = current_step_sol.t[current_step_sol.t < t_event]
                y_truncated = current_step_sol.y[:, 0 : len(t_truncated)]
                truncated_step_sol = pybamm.Solution(t_truncated, y_truncated)
                # assign temporary solve time
                truncated_step_sol.solve_time = np.nan
                # append solution from the current step to solution
                solution.append(truncated_step_sol)
                solution.termination = "event"
                solution.t_event = t_event
                solution.y_event = y_event
                break

            # assign temporary solve time
            current_step_sol.solve_time = np.nan
            # append solution from the current step to solution
            solution.append(current_step_sol)
            # move on to the end of the window that was just integrated
            t = t_window[-1]
            y0 = solution.y[:, -1]
        return solution

    def _integrate_old_safe(self, model, t_eval, inputs):
        """
        Step-and-check integration using one step per interval of t_eval.

        If an event changes sign over a step, that step is discarded and the
        last accepted point is reported as the event point.
        """
        y0 = self._initial_state(model)
        t = t_eval[0]
        init_event_signs = self._event_signs(model, t, y0, inputs)
        pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))

        # Initialize solution
        solution = pybamm.Solution(np.array([t]), y0[:, np.newaxis])
        solution.solve_time = 0

        for dt in np.diff(t_eval):
            solved = False
            count = 0
            while not solved:
                t_window = np.array([t, t + dt])
                integrator = self.get_integrator(model, t_window, inputs)
                # Try to solve with the current step, if it fails then halve the
                # step size and try again. This will make solution.t slightly
                # different to t_eval, but shouldn't matter too much as it should
                # only happen near events.
                try:
                    current_step_sol = self._run_integrator(
                        integrator, model, y0, inputs, t_window
                    )
                    solved = True
                except pybamm.SolverError as err:
                    dt /= 2
                    # only failed attempts count towards the limit
                    count += 1
                    if count >= self.max_step_decrease_count:
                        raise pybamm.SolverError(
                            """
                            Maximum number of decreased steps occurred at t={}. Try
                            solving the model up to this time only.
                            """.format(
                                t
                            )
                        ) from err

            # Check most recent y, at the time it was computed for
            new_event_signs = self._event_signs(
                model, t + dt, current_step_sol.y[:, -1], inputs
            )
            # Exit loop if the sign of an event changes
            if (new_event_signs != init_event_signs).any():
                solution.termination = "event"
                solution.t_event = solution.t[-1]
                solution.y_event = solution.y[:, -1]
                break

            # assign temporary solve time
            current_step_sol.solve_time = np.nan
            # append solution from the current step to solution
            solution.append(current_step_sol)
            t += dt
            y0 = solution.y[:, -1]
        return solution

    def get_integrator(self, model, t_eval, inputs):
        """
        Create a CasADi integrator for ``model`` over the time grid ``t_eval``.

        The symbolic problem, the integration method ("cvodes" when the model
        has no algebraic equations, "idas" otherwise) and the options are built
        the first time a model is seen and cached on the solver; later calls
        for the same model only replace the "grid" option.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            The model to integrate; provides ``y0``, ``casadi_rhs`` and
            ``casadi_algebraic``.
        t_eval : array-like
            The time grid handed to the integrator.
        inputs : casadi matrix
            The stacked input parameters; only its first dimension is used,
            to size the symbolic parameter vector on first set-up.

        Returns
        -------
        The object returned by ``casadi.integrator``.
        """
        # Only set up problem once
        if model not in self.problems:
            y0 = model.y0
            rhs = model.casadi_rhs
            algebraic = model.casadi_algebraic

            # When not in DEBUG mode (level=10), suppress warnings from CasADi
            show_eval_warnings = (
                pybamm.logger.getEffectiveLevel() == 10
                or pybamm.settings.debug_mode is True
            )
            # the fixed keys come last so they override extra_options_setup
            options = {
                **self.extra_options_setup,
                "grid": t_eval,
                "reltol": self.rtol,
                "abstol": self.atol,
                "output_t0": True,
                "show_eval_warnings": show_eval_warnings,
            }

            # set up the symbolic problem
            t = casadi.MX.sym("t")
            p = casadi.MX.sym("p", inputs.shape[0])
            y_diff = casadi.MX.sym("y_diff", rhs(t_eval[0], y0, p).shape[0])
            problem = {"t": t, "x": y_diff, "p": p}

            algebraic_at_start = algebraic(t_eval[0], y0, p)
            if algebraic_at_start.is_empty():
                # pure ODE
                method = "cvodes"
                problem["ode"] = rhs(t, y_diff, p)
            else:
                # DAE: add algebraic states and equations
                options["calc_ic"] = True
                method = "idas"
                y_alg = casadi.MX.sym("y_alg", algebraic_at_start.shape[0])
                y_full = casadi.vertcat(y_diff, y_alg)
                problem["z"] = y_alg
                problem["ode"] = rhs(t, y_full, p)
                problem["alg"] = algebraic(t, y_full, p)

            self.problems[model] = problem
            self.options[model] = options
            self.methods[model] = method
        else:
            # problem stays the same, just update the time grid
            self.options[model]["grid"] = t_eval

        return casadi.integrator(
            "F", self.methods[model], self.problems[model], self.options[model]
        )

    def _run_integrator(self, integrator, model, y0, inputs, t_eval):
        """
        Call ``integrator`` from the state ``y0`` and wrap the result.

        ``y0`` is split into differential and algebraic parts at
        ``model.concatenated_rhs.size``. The integrator outputs "xf" and "zf"
        are stacked back together into the solution array.

        Returns
        -------
        :class:`pybamm.Solution`
            Solution built from ``t_eval`` and the stacked states.

        Raises
        ------
        :class:`pybamm.SolverError`
            If a ``RuntimeError`` is raised while solving; the original
            exception is kept as the cause.
        """
        rhs_size = model.concatenated_rhs.size
        y0_diff, y0_alg = np.split(y0, [rhs_size])
        try:
            # Try solving
            sol = integrator(
                x0=y0_diff, z0=y0_alg, p=inputs, **self.extra_options_call
            )
            y_values = np.concatenate([sol["xf"].full(), sol["zf"].full()])
            return pybamm.Solution(t_eval, y_values)
        except RuntimeError as e:
            # Re-raise as a solver error, keeping the original traceback
            raise pybamm.SolverError(e.args[0]) from e