Source code for pybamm.solvers.casadi_solver

import warnings

import casadi
import numpy as np
from scipy.interpolate import interp1d

import pybamm

from .lrudict import LRUDict


[docs] class CasadiSolver(pybamm.BaseSolver): """Solve a discretised model, using CasADi. Parameters ---------- mode : str How to solve the model (default is "safe"): - "fast": perform direct integration, without accounting for events. \ Recommended when simulating a drive cycle or other simulation where \ no events should be triggered. - "fast with events": perform direct integration of the whole timespan, \ then go back and check where events were crossed. Experimental only. - "safe": perform step-and-check integration in global steps of size \ dt_max, checking whether events have been triggered. Recommended for \ simulations of a full charge or discharge. - "safe without grid": perform step-and-check integration step-by-step. \ Takes more steps than "safe" mode, but doesn't require creating the grid \ each time, so may be faster. Experimental only. rtol : float, optional The relative tolerance for the solver (default is 1e-6). atol : float, optional The absolute tolerance for the solver (default is 1e-6). root_method : str or pybamm algebraic solver class, optional The method to use to find initial conditions (for DAE solvers). Default is "nonlinear_solver", which uses a custom Newton solver for consistent initial conditions (recommended). If "casadi", the solver uses casadi's Newton rootfinding algorithm as a fallback. Otherwise, the solver uses 'scipy.optimize.root' with method specified by 'root_method' (e.g. "lm", "hybr", ...) root_tol : float, optional The tolerance for root-finding. Default is 1e-6. max_step_decrease_count : float, optional The maximum number of times step size can be decreased before an error is raised. Default is 5. dt_max : float, optional The maximum global step size (in seconds) used in "safe" mode. If None the default value is 600 seconds. extrap_tol : float, optional The tolerance to assert whether extrapolation occurs or not. Default is 0. on_extrapolation : str, optional What to do if the solver is extrapolating. Options are "warn", "error", or "ignore". Default is "error". extra_options_setup : dict, optional Any options to pass to the CasADi integrator when creating the integrator. Please consult `CasADi documentation <https://web.casadi.org/python-api/#integrator>`_ for details. Some useful options: - "max_num_steps": Maximum number of integrator steps - "print_stats": Print out statistics after integration extra_options_call : dict, optional Any options to pass to the CasADi integrator when calling the integrator. Please consult `CasADi documentation <https://web.casadi.org/python-api/#integrator>`_ for details. return_solution_if_failed_early : bool, optional Whether to return a Solution object if the solver fails to reach the end of the simulation, but managed to take some successful steps. Default is False. perturb_algebraic_initial_conditions : bool, optional Whether to perturb algebraic initial conditions to avoid a singularity. This can sometimes slow down the solver, but is kept True as default for "safe" mode as it seems to be more robust (False by default for other modes). integrators_maxcount : int, optional The maximum number of integrators that the solver will retain before ejecting past integrators using an LRU methodology. A value of 0 or None leaves the number of integrators unbound. Default is 100. """ def __init__( self, mode="safe", rtol=1e-6, atol=1e-6, root_method="nonlinear_solver", root_tol=1e-6, max_step_decrease_count=5, dt_max=None, extrap_tol=None, on_extrapolation=None, extra_options_setup=None, extra_options_call=None, return_solution_if_failed_early=False, perturb_algebraic_initial_conditions=None, integrators_maxcount=100, store_first_last=False, ): on_extrapolation = on_extrapolation or "error" super().__init__( method="problem dependent", rtol=rtol, atol=atol, root_method=root_method, root_tol=root_tol, extrap_tol=extrap_tol, on_extrapolation=on_extrapolation, store_first_last=store_first_last, ) if mode in ["safe", "fast", "fast with events", "safe without grid"]: self.mode = mode else: raise ValueError( f"invalid mode '{mode}'. Must be 'safe', for solving with events, " "'fast', for solving quickly without events, or 'safe without grid' or " "'fast with events' (both experimental)" ) self.max_step_decrease_count = max_step_decrease_count self.dt_max = dt_max or 600 self.extra_options_setup = extra_options_setup or {} self.extra_options_call = extra_options_call or {} self.return_solution_if_failed_early = return_solution_if_failed_early # Decide whether to perturb algebraic initial conditions, True by default for # "safe" mode, False by default for other modes if perturb_algebraic_initial_conditions is None: if mode == "safe": self.perturb_algebraic_initial_conditions = True else: self.perturb_algebraic_initial_conditions = False else: self.perturb_algebraic_initial_conditions = ( perturb_algebraic_initial_conditions ) self.name = f"CasADi solver with '{mode}' mode" # Initialize self.integrators_maxcount = integrators_maxcount self.integrators = LRUDict(maxsize=self.integrators_maxcount) self.integrator_specs = LRUDict(maxsize=self.integrators_maxcount) self.y_sols = {} pybamm.citations.register("Andersson2019") def _integrate_single(self, model, t_eval, inputs_dict, y0): """ Solve a single DAE model defined by residuals with initial conditions y0. Parameters ---------- model : :class:`pybamm.BaseModel` The model whose solution to calculate. t_eval : numeric type The times at which to compute the solution inputs_dict : dict, optional Any input parameters to pass to the model when solving y0 : array-like The initial conditions for the model Returns ------- :class:`pybamm.Solution` A Solution object containing the times and values of the solution, as well as various diagnostic messages. """ # casadi solver does not support sensitivity analysis if model.calculate_sensitivities: raise NotImplementedError( "Sensitivity analysis is not implemented for the CasADi solver." ) # Record whether there are any symbolic inputs inputs_dict = inputs_dict or {} # convert inputs to casadi format inputs = casadi.vertcat(*[x for x in inputs_dict.values()]) if self.mode in ["fast", "fast with events"] or not model.events: if not model.events: pybamm.logger.info("No events found, running fast mode") if self.mode == "fast with events": # Create the integrator with an event switch that will set the rhs to # zero when voltage limits are crossed use_event_switch = True else: use_event_switch = False # Create an integrator with the grid (we just need to do this once) self.create_integrator( model, y0, inputs, t_eval, use_event_switch=use_event_switch ) solution = self._run_integrator(model, y0, inputs_dict, inputs, t_eval) # Check if the sign of an event changes, if so find an accurate # termination point and exit solution = self._solve_for_event(solution, y0) solution.check_ys_are_not_too_large() return solution elif self.mode in ["safe", "safe without grid"]: # Step-and-check t = t_eval[0] t_f = t_eval[-1] pybamm.logger.debug(f"Start solving {model.name} with {self.name}") if self.mode == "safe without grid": # in "safe without grid" mode, # create integrator once, without grid, # to avoid having to create several times self.create_integrator(model, y0, inputs) # Initialize solution. Coerce y0 to a CasADi DM so that when # subsequent CasADi integration segments are appended, the # concatenated ``solution.y`` stays a CasADi type. When the # root method is NonlinearSolver, y0 arrives as an ndarray. y0_init = y0 if isinstance(y0, casadi.DM | casadi.MX) else casadi.DM(y0) solution = pybamm.Solution( np.array([t]), y0_init, model, inputs_dict, ) solution.solve_time = 0 solution.integration_time = 0 use_grid = False else: solution = None use_grid = True # Try to integrate in global steps of size dt_max. Note: dt_max must # be at least as big as the the biggest step in t_eval (multiplied # by some tolerance, here 1.01) to avoid an empty integration window below dt_max = self.dt_max dt_eval_max = np.max(np.diff(t_eval)) * 1.01 if dt_max < dt_eval_max: pybamm.logger.debug( "Setting dt_max to be as big as the largest step in " f"t_eval ({dt_eval_max})" ) dt_max = dt_eval_max termination_due_to_small_dt = False first_ts_solved = False while t < t_f: # Step solved = False count = 0 dt = dt_max while not solved: # Get window of time to integrate over (so that we return # all the points in t_eval, not just t and t+dt) t_window = np.concatenate( ([t], t_eval[(t_eval > t) & (t_eval < t + dt)]) ) # Sometimes near events the solver fails between two time # points in t_eval (i.e. no points t < t_i < t+dt for t_i # in t_eval), so we simply integrate from t to t+dt if len(t_window) == 1: t_window = np.array([t, t + dt]) if self.mode == "safe": # update integrator with the grid self.create_integrator(model, y0, inputs, t_window) # Try to solve with the current global step, if it fails then # halve the step size and try again. try: pybamm.logger.debug( "Running integrator for " f"{t_window[0]:.2f} < t < {t_window[-1]:.2f}" ) current_step_sol = self._run_integrator( model, y0, inputs_dict, inputs, t_window, use_grid=use_grid, ) first_ts_solved = True solved = True except pybamm.SolverError as error: pybamm.logger.debug("Failed, halving step size") dt /= 2 count += 1 # also reduce maximum step size for future global steps, # but skip them in the beginning # sometimes, for the first integrator smaller timesteps are # needed, but this won't affect the global timesteps. The # global timestep will only be reduced after the first timestep. if first_ts_solved: dt_max = dt if count > self.max_step_decrease_count: message = ( "Maximum number of decreased steps occurred at " f"t={t} (final SolverError: '{error}'). " "For a full solution try reducing dt_max (currently, " f"dt_max={dt_max}) and/or reducing the size of the " "time steps or period of the experiment." ) if first_ts_solved and self.return_solution_if_failed_early: warnings.warn( message, pybamm.SolverWarning, stacklevel=2 ) termination_due_to_small_dt = True break else: raise pybamm.SolverError( message + " Set `return_solution_if_failed_early=True` to " "return the solution object up to the point where " "failure occured." ) from error if termination_due_to_small_dt: break # Check if the sign of an event changes, if so find an accurate # termination point and exit current_step_sol = self._solve_for_event(current_step_sol, y0) # assign temporary solve time current_step_sol.solve_time = np.nan # append solution from the current step to solution solution = solution + current_step_sol if current_step_sol.termination == "event": break else: # update time as time # from which to start the new casadi integrator t = t_window[-1] # update y0 as initial_values # from which to start the new casadi integrator y0 = solution.all_ys[-1][:, -1] solution.check_ys_are_not_too_large() return solution def _solve_for_event(self, coarse_solution, y0): """ Check if the sign of an event changes, if so find an accurate termination point and exit Locate the event time using a root finding algorithm and event state using interpolation. The solution is then truncated so that only the times up to the event are returned """ pybamm.logger.debug("Solving for events") model = coarse_solution.all_models[-1] inputs_dict = coarse_solution.all_inputs[-1] inputs = casadi.vertcat(*[x for x in inputs_dict.values()]) def find_t_event(sol, typ): # Check most recent y to see if any events have been crossed if model.terminate_events_eval: y_last = sol.all_ys[-1][:, -1] crossed_events = np.sign( np.concatenate( [ event(sol.t[-1], y_last, inputs) for event in model.terminate_events_eval ] ) - 1e-5 ) else: crossed_events = np.sign([]) # Return None if no events have been triggered if (crossed_events == 1).all(): return None, None, None # get the index of the events that have been crossed event_idx = np.where(crossed_events != 1)[0] active_events = [model.terminate_events_eval[i] for i in event_idx] # loop over events to compute the time at which they were triggered t_events = [None] * len(active_events) event_idcs_lower = [None] * len(active_events) for i, event in enumerate(active_events): # Implement our own bisection algorithm for speed # This is used to find the time range in which the event is triggered # Evaluations of the "event" function are (relatively) expensive f_eval = {} def f(idx, f_eval=f_eval, event=event): try: return f_eval[idx] except KeyError: # We take away 1e-5 to deal with the case where the event sits # exactly on zero, as can happen when the event switch is used # (fast with events mode) f_eval[idx] = event(sol.t[idx], sol.y[:, idx], inputs) - 1e-5 return f_eval[idx] def integer_bisect(): a_n = 0 b_n = len(sol.t) - 1 for _ in range(len(sol.t)): if a_n + 1 == b_n: return a_n m_n = (a_n + b_n) // 2 f_m_n = f(m_n) if np.isnan(f_m_n): a_n = a_n b_n = m_n elif f_m_n < 0: a_n = a_n b_n = m_n elif f_m_n > 0: a_n = m_n b_n = b_n event_idx_lower = integer_bisect() if typ == "window": event_idcs_lower[i] = event_idx_lower elif typ == "exact": # Linear interpolation between the two indices to find the root time # We could do cubic interpolation here instead but it would be # slower t_lower = sol.t[event_idx_lower] t_upper = sol.t[event_idx_lower + 1] event_lower = abs(f(event_idx_lower)) event_upper = abs(f(event_idx_lower + 1)) t_events[i] = (event_lower * t_upper + event_upper * t_lower) / ( event_lower + event_upper ) if typ == "window": event_idx_lower = np.nanmin(event_idcs_lower) return event_idx_lower, None, None elif typ == "exact": # t_event is the earliest event triggered t_event = np.nanmin(t_events) # create interpolant to evaluate y in the current integration # window y_sol = interp1d(sol.t, sol.y, kind="linear") y_event = y_sol(t_event) closest_event_idx = event_idx[np.nanargmin(t_events)] return t_event, y_event, closest_event_idx # Find the interval in which the event was triggered event_idx_lower, _, _ = find_t_event(coarse_solution, "window") # Return the existing solution if no events have been triggered if event_idx_lower is None: # Flag "final time" for termination coarse_solution.termination = "final time" return coarse_solution # If events have been triggered, we solve for a dense window in the interval # where the event was triggered, then find the precise location of the event # Solve again with a more dense idx_window, starting from the start of the # window where the event was triggered t_window_event_dense = np.linspace( coarse_solution.t[event_idx_lower], coarse_solution.t[event_idx_lower + 1], 100, ) if self.mode == "safe without grid": use_grid = False else: self.create_integrator(model, y0, inputs, t_window_event_dense) use_grid = True y0 = coarse_solution.y[:, event_idx_lower] dense_step_sol = self._run_integrator( model, y0, inputs_dict, inputs, t_window_event_dense, use_grid=use_grid, ) # Find the exact time at which the event was triggered t_event, y_event, closest_event_idx = find_t_event(dense_step_sol, "exact") # If this returns None, no event was crossed in dense_step_sol. This can happen # if the event crossing was right at the end of the interval in the coarse # solution. In this case, return the t and y from the end of the interval # (i.e. next point in the coarse solution) if y_event is None: # pragma: no cover # This is extremely rare, it's difficult to find a test that triggers this # hence no coverage check t_event = coarse_solution.t[event_idx_lower + 1] y_event = coarse_solution.y[:, event_idx_lower + 1].full().flatten() # Return solution truncated at the first coarse event time # Also assign t_event t_sol = coarse_solution.t[: event_idx_lower + 1] y_sol = coarse_solution.y[:, : event_idx_lower + 1] solution = pybamm.Solution( t_sol, y_sol, model, inputs_dict, np.array([t_event]), y_event[:, np.newaxis], "event", all_t_evals=t_sol, ) solution.integration_time = ( coarse_solution.integration_time + dense_step_sol.integration_time ) solution.closest_event_idx = closest_event_idx return solution
[docs] def create_integrator(self, model, y0, inputs, t_eval=None, use_event_switch=False): """ Method to create a casadi integrator object. If t_eval is provided, the integrator uses t_eval to make the grid. Otherwise, the integrator has grid [0,1]. """ pybamm.logger.debug("Creating CasADi integrator") # Use grid if t_eval is given use_grid = t_eval is not None if use_grid is True: t_eval_shifted = t_eval - t_eval[0] t_eval_shifted_rounded = np.round(t_eval_shifted, decimals=12).tobytes() # Only set up problem once if model in self.integrators: # If we're not using the grid, we don't need to change the integrator if use_grid is False: return self.integrators[model]["no grid"] # Otherwise, create new integrator with an updated grid # We don't need to update the grid if reusing the same t_eval # (up to a shift by a constant) else: if t_eval_shifted_rounded in self.integrators[model]: return self.integrators[model][t_eval_shifted_rounded] else: method, problem, options, time_args = self.integrator_specs[model] time_args = [t_eval_shifted[0], t_eval_shifted[1:]] integrator = casadi.integrator( "F", method, problem, *time_args, options ) self.integrators[model][t_eval_shifted_rounded] = integrator return integrator else: rhs = model.casadi_rhs algebraic = model.casadi_algebraic options = { "show_eval_warnings": False, **self.extra_options_setup, "reltol": self.rtol, "abstol": self.atol, } # set up and solve t = casadi.MX.sym("t") p = casadi.MX.sym("p", inputs.shape[0]) y_diff = casadi.MX.sym("y_diff", rhs(0, y0, p).shape[0]) y_alg = casadi.MX.sym("y_alg", algebraic(0, y0, p).shape[0]) y_full = casadi.vertcat(y_diff, y_alg) if use_grid is False: time_args = [] # rescale time t_min = casadi.MX.sym("t_min") t_max = casadi.MX.sym("t_max") t_max_minus_t_min = t_max - t_min t_scaled = t_min + (t_max - t_min) * t # add time limits as inputs p_with_tlims = casadi.vertcat(p, t_min, t_max) else: time_args = [t_eval_shifted[0], t_eval_shifted[1:]] # rescale time t_min = casadi.MX.sym("t_min") # Set dummy parameters for consistency with rescaled time t_max_minus_t_min = 1 t_scaled = t_min + t p_with_tlims = casadi.vertcat(p, t_min) # define the event switch as the point when an event is crossed # we don't do this for ODE models # see #1082 event_switch = 1 if use_event_switch is True and not algebraic(0, y0, p).is_empty(): for event in model.casadi_switch_events: event_switch *= event(t_scaled, y_full, p) problem = { "t": t, "x": y_diff, # rescale rhs by (t_max - t_min) "ode": (t_max_minus_t_min) * rhs(t_scaled, y_full, p) * event_switch, "p": p_with_tlims, } if algebraic(0, y0, p).is_empty(): method = "cvodes" else: method = "idas" problem.update( { "z": y_alg, "alg": algebraic(t_scaled, y_full, p), } ) integrator = casadi.integrator("F", method, problem, *time_args, options) self.integrator_specs[model] = method, problem, options, time_args if use_grid is False: self.integrators[model] = {"no grid": integrator} else: self.integrators[model] = {t_eval_shifted_rounded: integrator} return integrator
def _run_integrator( self, model, y0, inputs_dict, inputs, t_eval, use_grid=True, ): """ Run the integrator. Parameters ---------- model : :class:`pybamm.BaseModel` The model whose solution to calculate. y0: casadi vector of initial conditions inputs_dict : dict, optional Any input parameters to pass to the model when solving inputs: Casadi vector of inputs t_eval : numeric type The times at which to compute the solution use_grid: bool, optional Determines whether the casadi solver uses a grid or rescales time to (0,1) """ pybamm.logger.debug("Running CasADi integrator") if use_grid is True: pybamm.logger.spam("Calculating t_eval_shifted") t_eval_shifted = t_eval - t_eval[0] t_eval_shifted_rounded = np.round(t_eval_shifted, decimals=12).tobytes() pybamm.logger.spam("Finished calculating t_eval_shifted") integrator = self.integrators[model][t_eval_shifted_rounded] else: integrator = self.integrators[model]["no grid"] len_rhs = model.concatenated_rhs.size len_alg = model.concatenated_algebraic.size y0_diff = y0[:len_rhs] y0_alg_exact = y0[len_rhs:] if self.perturb_algebraic_initial_conditions and len_alg > 0: # Add a tiny perturbation to the algebraic initial conditions # For some reason this helps with convergence # The actual value of the initial conditions for the algebraic variables # doesn't matter y0_alg = y0_alg_exact * (1 + 1e-6 * casadi.DM(np.random.rand(len_alg))) else: y0_alg = y0_alg_exact pybamm.logger.spam("Finished preliminary setup for integrator run") # Solve # Try solving if use_grid is True: t_min = t_eval[0] inputs_with_tmin = casadi.vertcat(inputs, t_min) # Call the integrator once, with the grid timer = pybamm.Timer() pybamm.logger.debug("Calling casadi integrator") try: casadi_sol = integrator( x0=y0_diff, z0=y0_alg, p=inputs_with_tmin, **self.extra_options_call ) except RuntimeError as error: # If it doesn't work raise error pybamm.logger.debug(f"Casadi integrator failed with error {error}") raise pybamm.SolverError(error.args[0]) from error pybamm.logger.debug("Finished casadi integrator") integration_time = timer.time() # Manually add initial conditions and concatenate x_sol = casadi.horzcat(y0_diff, casadi_sol["xf"]) if len_alg > 0: z_sol = casadi.horzcat(y0_alg_exact, casadi_sol["zf"]) y_sol = casadi.vertcat(x_sol, z_sol) else: y_sol = x_sol sol = pybamm.Solution( t_eval, y_sol, model, inputs_dict, check_solution=False, all_t_evals=t_eval, ) sol.integration_time = integration_time return sol else: # Repeated calls to the integrator x = y0_diff z = y0_alg_exact y_diff = x y_alg = z for i in range(len(t_eval) - 1): t_min = t_eval[i] t_max = t_eval[i + 1] inputs_with_tlims = casadi.vertcat(inputs, t_min, t_max) timer = pybamm.Timer() try: casadi_sol = integrator( x0=x, z0=z, p=inputs_with_tlims, **self.extra_options_call ) except RuntimeError as error: # If it doesn't work raise error pybamm.logger.debug(f"Casadi integrator failed with error {error}") raise pybamm.SolverError(error.args[0]) from error integration_time = timer.time() x = casadi_sol["xf"] z = casadi_sol["zf"] y_diff = casadi.horzcat(y_diff, x) if not z.is_empty(): y_alg = casadi.horzcat(y_alg, z) if z.is_empty(): y_sol = y_diff else: y_sol = casadi.vertcat(y_diff, y_alg) sol = pybamm.Solution( t_eval, y_sol, model, inputs_dict, all_t_evals=t_eval, check_solution=False, ) sol.integration_time = integration_time return sol