JAX Solver
- class pybamm.JaxSolver(method='RK45', root_method=None, rtol=1e-06, atol=1e-06, extrap_tol=None, extra_options=None)
Solve a discretised model using a JAX-compiled solver.
- Note: This solver does not work with models that have termination events or that have not been converted to JAX format.
- Raises:
RuntimeError – if model has any termination events
RuntimeError – if model.convert_to_format != 'jax'
- Parameters:
method (str) – 'RK45' (default) uses jax.experimental.ode.odeint; 'BDF' uses the custom jax_bdf_integrate (see jax_bdf_integrate.py for details).
root_method (str, optional) – Method used to calculate consistent initial conditions. By default this is the Newton chord method internal to the JAX BDF solver; otherwise, choose from the default options defined in the docs for pybamm.BaseSolver.
rtol (float, optional) – The relative tolerance for the solver (default is 1e-6).
atol (float, optional) – The absolute tolerance for the solver (default is 1e-6).
extrap_tol (float, optional) – The tolerance used to determine whether extrapolation occurs (default is 0).
extra_options (dict, optional) – Any options to pass to the solver. Consult the JAX documentation for details.
Extends:
pybamm.solvers.base_solver.BaseSolver
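The rtol and atol parameters above are typically combined into a per-component error scale that an adaptive integrator compares its local error estimate against. The sketch below assumes the convention used by SciPy's adaptive integrators (scale = atol + rtol·max(|y_old|, |y_new|), root-mean-square norm, step accepted when the norm is at most 1); whether the JAX integrators behind this solver use exactly the same rule is an assumption, and `scaled_error_norm` is a hypothetical helper, not part of PyBaMM.

```python
import numpy as np


def scaled_error_norm(err, y_old, y_new, rtol=1e-6, atol=1e-6):
    """RMS norm of a local error estimate, scaled by the mixed tolerance.

    Hypothetical helper (assumed SciPy-style convention): a step would be
    accepted when the returned value is <= 1.
    """
    # Per-component tolerance: an absolute floor plus a part relative to |y|
    scale = atol + rtol * np.maximum(np.abs(y_old), np.abs(y_new))
    # Root-mean-square of the scaled error components
    return np.sqrt(np.mean((err / scale) ** 2))


y_old = np.array([1.0, 1e-8])
y_new = np.array([0.99, 1e-8])
err = np.array([5e-7, 5e-7])
print(scaled_error_norm(err, y_old, y_new))  # about 0.395, so the step passes
```

The second component shows why atol matters: when a state is close to zero, rtol·|y| is negligible and atol alone sets the allowed error.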
- create_solve(model, t_eval)
Return a compiled JAX function that solves an ODE model with input arguments.
- Parameters:
model (pybamm.BaseModel) – The model whose solution to calculate.
t_eval (numpy.array, size (k,)) – The times at which to compute the solution.
- Returns:
A function with signature f(inputs), where inputs is a dict containing any input parameters to pass to the model when solving.
- Return type:
function
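Conceptually, the returned function has the model, initial state, and t_eval bound once and takes only the input parameters, which is what makes it convenient to compile and call repeatedly with different inputs. The NumPy sketch below shows only that closure shape. `make_solve`, the `rhs(y, t, inputs)` signature, and the fixed-step classical RK4 integrator are hypothetical stand-ins: PyBaMM's actual implementation is JAX-compiled and uses the integrator selected by `method`.

```python
import numpy as np


def make_solve(rhs, y0, t_eval, substeps=20):
    """Return f(inputs) that integrates dy/dt = rhs(y, t, inputs) over t_eval.

    Hypothetical stand-in for a create_solve-style closure: everything except
    the input parameters is fixed when the function is built.
    """
    y0 = np.asarray(y0, dtype=float).reshape(-1)
    t_eval = np.asarray(t_eval, dtype=float)

    def solve(inputs):
        ys = [y0]
        y = y0
        # Fixed-step classical RK4 between consecutive output times
        for t0, t1 in zip(t_eval[:-1], t_eval[1:]):
            h = (t1 - t0) / substeps
            t = t0
            for _ in range(substeps):
                k1 = rhs(y, t, inputs)
                k2 = rhs(y + 0.5 * h * k1, t + 0.5 * h, inputs)
                k3 = rhs(y + 0.5 * h * k2, t + 0.5 * h, inputs)
                k4 = rhs(y + h * k3, t + h, inputs)
                y = y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
                t += h
            ys.append(y)
        # One column per output time: shape (n, k)
        return np.stack(ys, axis=1)

    return solve


def decay(y, t, inputs):
    # Exponential decay whose rate comes from the inputs dict
    return -inputs["a"] * y


t_eval = np.linspace(0.0, 1.0, 11)
solve = make_solve(decay, [1.0], t_eval)
y_slow = solve({"a": 0.5})  # the same function is reused with new inputs
y_fast = solve({"a": 2.0})
print(y_slow.shape)  # (1, 11)
```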
- get_solve(model, t_eval)
Return a compiled JAX function that solves an ODE model with input arguments.
- Parameters:
model (pybamm.BaseModel) – The model whose solution to calculate.
t_eval (numpy.array, size (k,)) – The times at which to compute the solution.
- Returns:
A function with signature f(inputs), where inputs is a dict containing any input parameters to pass to the model when solving.
- Return type:
function
- pybamm.jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-06, atol=1e-06, mass=None)
Backward differentiation formula (BDF) implicit multistep integrator. The basic algorithm is derived in Byrne and Hindmarsh [1]. This particular implementation follows the MATLAB routine ode15s described in Shampine and Reichelt [2] (and the SciPy implementation, Virtanen et al. [3]): it uses the NDF formulas for improved stability, with associated differences in the error constants, and calculates the Jacobian at J(t_{n+1}, y^0_{n+1}). The code itself was based on the SciPy implementation [3], which also mainly follows Shampine and Reichelt [2] but uses the more standard Jacobian update.
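To make the idea of an implicit multistep method concrete, here is a minimal fixed-step, fixed-order BDF2 sketch in NumPy. It is deliberately far simpler than jax_bdf_integrate: no variable order, no NDF correction terms, no adaptive step size or error control, and a finite-difference Jacobian recomputed at every Newton iteration. All helper names are hypothetical. The first step uses backward Euler (BDF1) to supply the second history value that the two-step formula needs.

```python
import numpy as np


def fd_jacobian(f, y, t, eps=1e-8):
    """Forward-difference Jacobian of f(y, t) with respect to y."""
    f0 = f(y, t)
    J = np.empty((y.size, y.size))
    for j in range(y.size):
        dy = np.zeros_like(y)
        dy[j] = eps * max(1.0, abs(y[j]))
        J[:, j] = (f(y + dy, t) - f0) / dy[j]
    return J


def newton_solve(residual, jac, y_guess, tol=1e-12, max_iter=20):
    """Solve residual(y) = 0 by Newton's method starting from y_guess."""
    y = y_guess.copy()
    for _ in range(max_iter):
        delta = np.linalg.solve(jac(y), -residual(y))
        y = y + delta
        if np.linalg.norm(delta) <= tol * (1.0 + np.linalg.norm(y)):
            break
    return y


def bdf2_fixed_step(f, y0, t0, t1, n_steps):
    """Integrate y' = f(y, t) from t0 to t1 with fixed-step BDF2."""
    h = (t1 - t0) / n_steps
    y_prev = np.asarray(y0, dtype=float)
    I = np.eye(y_prev.size)
    # Bootstrap with one backward-Euler step: y1 - y0 - h f(y1, t1) = 0
    t = t0 + h
    y = newton_solve(lambda z: z - y_prev - h * f(z, t),
                     lambda z: I - h * fd_jacobian(f, z, t), y_prev)
    for _ in range(n_steps - 1):
        t += h
        # BDF2: y_{n+1} - 4/3 y_n + 1/3 y_{n-1} - 2/3 h f(y_{n+1}, t_{n+1}) = 0
        history = (4.0 * y - y_prev) / 3.0
        y_next = newton_solve(
            lambda z: z - history - (2.0 / 3.0) * h * f(z, t),
            lambda z: I - (2.0 / 3.0) * h * fd_jacobian(f, z, t),
            y,
        )
        y_prev, y = y, y_next
    return y


k = 1000.0


def stiff_rhs(y, t):
    # y relaxes very quickly towards cos(t)
    return -k * (y - np.cos(t))


y_end = bdf2_fixed_step(stiff_rhs, np.array([0.0]), 0.0, 1.0, n_steps=100)
```

With h = 0.01 and k = 1000, explicit Euler would be unstable (|1 - hk| = 9), while the implicit BDF2 step damps the fast transient and tracks the slow solution; this stability on stiff problems is the reason for paying the cost of a Newton solve per step.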
- Parameters:
func (callable) – function to evaluate the time derivative of the solution y at time t as func(y, t, *args), producing the same shape/structure as y0.
y0 (ndarray) – initial state vector
t_eval (ndarray) – time points at which to evaluate the solution, with shape (m,)
args (optional) – tuple of additional arguments for func, which must be arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types.
rtol (float, optional) – relative tolerance for the solver
atol (float, optional) – absolute tolerance for the solver
mass (ndarray, optional) – diagonal of the mass matrix, with shape (n,)
- Returns:
y – calculated state vector at each of the m time points
- Return type:
ndarray with shape (n, m)
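Assuming the usual semi-explicit convention diag(mass)·y' = func(y, t), a zero on the mass diagonal turns that row into an algebraic constraint 0 = func_i(y, t). The NumPy sketch below illustrates that idea with fixed-step backward Euler (the first-order BDF) instead of the adaptive variable-order scheme, and, like the documented return value, produces one column per time point, i.e. shape (n, m). `backward_euler_mass` is a hypothetical helper, and it assumes y0 already satisfies the algebraic equations (in JaxSolver, computing consistent initial conditions is the job of root_method).

```python
import numpy as np


def backward_euler_mass(f, y0, t_eval, mass, substeps=250, newton_iters=10):
    """Solve diag(mass) * y' = f(y, t) with fixed-step backward Euler.

    Rows with mass == 0 are algebraic constraints 0 = f_i(y, t).
    Returns an array of shape (n, m): column j is the state at t_eval[j].
    """
    y = np.asarray(y0, dtype=float).copy()
    M = np.diag(np.asarray(mass, dtype=float))
    n = y.size
    out = np.empty((n, len(t_eval)))
    out[:, 0] = y
    for j in range(1, len(t_eval)):
        h = (t_eval[j] - t_eval[j - 1]) / substeps
        for s in range(1, substeps + 1):
            t = t_eval[j - 1] + s * h
            y_old = y
            z = y_old.copy()
            for _ in range(newton_iters):
                f0 = f(z, t)
                # Residual of M (z - y_old) - h f(z, t) = 0
                r = M @ (z - y_old) - h * f0
                # Forward-difference Jacobian of f at z
                Jf = np.empty((n, n))
                for k in range(n):
                    dz = np.zeros(n)
                    dz[k] = 1e-8 * max(1.0, abs(z[k]))
                    Jf[:, k] = (f(z + dz, t) - f0) / dz[k]
                delta = np.linalg.solve(M - h * Jf, r)
                z = z - delta
                if np.linalg.norm(delta) < 1e-12 * (1.0 + np.linalg.norm(z)):
                    break
            y = z
        out[:, j] = y
    return out


def semi_explicit(y, t):
    # Row 0 is differential: y0' = -y0; row 1 is algebraic: 0 = y1 - 2 y0
    return np.array([-y[0], y[1] - 2.0 * y[0]])


t_eval = np.linspace(0.0, 1.0, 5)
y = backward_euler_mass(semi_explicit, [1.0, 2.0], t_eval, mass=[1.0, 0.0])
print(y.shape)  # (2, 5)
```

Because the algebraic row has no time derivative, the Newton matrix M - h·J stays nonsingular only if the algebraic block of the Jacobian is nonsingular, which is why such solvers are restricted to index-1 DAEs.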
References