JAX Solver

class pybamm.JaxSolver(method='RK45', root_method=None, rtol=1e-06, atol=1e-06, extrap_tol=0, extra_options=None)

Solve a discretised model using a JAX compiled solver.

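Note: this solver will not work with models that have termination events or that have not been converted to JAX format.

Raises
  • RuntimeError – if the model has any termination events, or if model.convert_to_format is not 'jax'.

Parameters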
  • method (str, optional) – ‘RK45’ (default) uses jax.experimental.ode.odeint; ‘BDF’ uses the custom pybamm.jax_bdf_integrate (see jax_bdf_integrate.py for details).

  • root_method (str, optional) – Method to use to calculate consistent initial conditions. By default, this uses the Newton chord method internal to the JAX BDF solver; otherwise, choose from the default options described in the documentation for pybamm.BaseSolver.

  • rtol (float, optional) – The relative tolerance for the solver (default is 1e-6).

  • atol (float, optional) – The absolute tolerance for the solver (default is 1e-6).

  • extrap_tol (float, optional) – The tolerance used to determine whether extrapolation occurs (default is 0).

  • extra_options (dict, optional) – Any options to pass to the solver. Consult the JAX documentation for details.
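Adaptive solvers usually combine rtol and atol into a per-component error scale, so that atol governs components near zero and rtol governs large ones. The sketch below assumes the convention used by SciPy's RK45 (scale = atol + rtol * max(|y_old|, |y_new|), with a step accepted when the root-mean-square of the scaled error is at most 1); it illustrates how the two tolerances interact and is not taken from the JaxSolver source. The function name scaled_error_norm is hypothetical.

```python
import numpy as np


def scaled_error_norm(err, y_old, y_new, rtol=1e-6, atol=1e-6):
    """RMS norm of a local error estimate, scaled per component by the tolerances.

    Assumed convention (as in SciPy's RK45): scale = atol + rtol * max(|y_old|, |y_new|).
    """
    scale = atol + rtol * np.maximum(np.abs(y_old), np.abs(y_new))
    return np.sqrt(np.mean((err / scale) ** 2))


y_old = np.array([1.0, 1e-9])
y_new = np.array([0.9, 1e-9])
err = np.array([5e-7, 5e-7])  # same absolute error in both components

# Default tolerances: atol dominates the scale of the tiny second component.
norm_default = scaled_error_norm(err, y_old, y_new)
# Tighter tolerances shrink the scale, so the same error now fails the test.
norm_tight = scaled_error_norm(err, y_old, y_new, rtol=1e-8, atol=1e-8)

for name, norm in [("default", norm_default), ("tight", norm_tight)]:
    print(f"{name}: error norm = {norm:.3g}, step accepted: {norm <= 1.0}")
```

With these numbers the default tolerances accept the step and the tighter ones reject it, which is what would force an adaptive solver to reduce its step size.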

create_solve(model, t_eval)

Return a compiled JAX function that solves an ODE model with input arguments.

Parameters
  • model (pybamm.BaseModel) – The model whose solution to calculate.

  • t_eval (numpy.array, size (k,)) – The times at which to compute the solution.

Returns

A function with signature f(inputs), where inputs is a dict containing any input parameters to pass to the model when solving.

Return type

function

get_solve(model, t_eval)

Return a compiled JAX function that solves an ODE model with input arguments.

Parameters
  • model (pybamm.BaseModel) – The model whose solution to calculate.

  • t_eval (numpy.array, size (k,)) – The times at which to compute the solution.

Returns

A function with signature f(inputs), where inputs is a dict containing any input parameters to pass to the model when solving.

Return type

function
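create_solve and get_solve fix the model and t_eval once and return a function of the input parameters only, so one set-up can be reused for many parameter values. The plain-NumPy sketch below mimics that f(inputs) calling pattern for a hypothetical one-state model dy/dt = -rate * y; make_solve, the "rate" input and the (n, k) output layout are assumptions of the sketch, and no JAX compilation is involved.

```python
import numpy as np


def make_solve(t_eval):
    """Hypothetical stand-in for create_solve/get_solve (no JAX compilation).

    The model (dy/dt = -rate * y, y(0) = 1) and the output times are fixed here;
    the returned function takes only a dict of input parameters.
    """
    t_eval = np.asarray(t_eval, dtype=float)
    steps = np.diff(t_eval)

    def solve(inputs):
        rate = inputs["rate"]
        y = np.empty((1, t_eval.size))  # one state, one column per output time
        y[0, 0] = 1.0
        for i, h in enumerate(steps, start=1):
            # Backward Euler for this linear ODE has the closed form y_new = y_old / (1 + h * rate)
            y[0, i] = y[0, i - 1] / (1.0 + h * rate)
        return y

    return solve


t_eval = np.linspace(0.0, 5.0, 501)
solve = make_solve(t_eval)      # set up once
y_slow = solve({"rate": 0.1})   # then call repeatedly with different inputs
y_fast = solve({"rate": 1.0})
print(y_slow[0, -1], np.exp(-0.5), y_fast[0, -1], np.exp(-5.0))
```

With the real solver, the analogous pattern is f = solver.get_solve(model, t_eval) followed by f({...}), where the dict keys are the names of the model's input parameters.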

pybamm.jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-06, atol=1e-06, mass=None)

Backward differentiation formula (BDF) implicit multistep integrator. The basic algorithm is derived in [2]. This particular implementation follows the MATLAB routine ode15s described in [1] and the SciPy implementation [3], and is based mainly on the latter. The SciPy implementation, which also mainly follows [1] but uses the more standard Jacobian update, features the numerical differentiation formulas (NDFs) for improved stability, with associated differences in the error constants, and evaluates the Jacobian J at (t_{n+1}, y^0_{n+1}), where y^0_{n+1} is the predicted value of y at t_{n+1}.
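The NDF formulas differ from classic BDF by a single scalar kappa_k per order k, which changes the formula coefficients and the error constants. The sketch below computes those quantities, assuming the kappa values tabulated by Shampine and Reichelt and used in SciPy's BDF solver; it illustrates the formulas and is not an excerpt from jax_bdf_integrate.

```python
import numpy as np

MAX_ORDER = 5

# NDF modification kappa_k for orders k = 1..5 (index 0 is unused), as tabulated by
# Shampine and Reichelt and used in SciPy's BDF solver; kappa_k = 0 gives classic BDF.
kappa = np.array([0.0, -0.1850, -1.0 / 9.0, -0.0823, -0.0415, 0.0])

# gamma_k = 1 + 1/2 + ... + 1/k
gamma = np.hstack((0.0, np.cumsum(1.0 / np.arange(1, MAX_ORDER + 1))))

# alpha_k = (1 - kappa_k) * gamma_k, as in SciPy's bdf.py
alpha = (1.0 - kappa) * gamma

# Error constants: NDF versus classic BDF (kappa = 0)
ndf_error_const = kappa * gamma + 1.0 / np.arange(1, MAX_ORDER + 2)
bdf_error_const = 1.0 / np.arange(1, MAX_ORDER + 2)

for k in range(1, MAX_ORDER + 1):
    print(f"order {k}: alpha = {alpha[k]:.4f}, "
          f"NDF error const = {ndf_error_const[k]:.4f}, BDF error const = {bdf_error_const[k]:.4f}")
```

For orders 1 to 4 the NDF error constant is smaller than the BDF one (for example 0.315 versus 0.5 at order 1); at order 5, kappa is 0 and the two coincide.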

Parameters
  • func (callable) – function to evaluate the time derivative of the solution y at time t as func(y, t, *args), producing the same shape/structure as y0.

  • y0 (ndarray) – initial state vector

  • t_eval (ndarray) – time points to evaluate the solution, has shape (m,)

  • args (optional) – tuple of additional arguments for func, which must be arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types.

  • rtol (float, optional) – relative tolerance for the solver (default is 1e-6)

  • atol (float, optional) – absolute tolerance for the solver (default is 1e-6)

  • mass (ndarray, optional) – diagonal of the mass matrix, with shape (n,)
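The func(y, t, *args) signature and the diagonal mass argument describe systems of the form diag(mass) * dy/dt = func(y, t, *args). The NumPy sketch below assumes that convention, with zero mass entries marking algebraic equations, and integrates such a system with fixed-step backward Euler (BDF1, the lowest-order member of the BDF family). Each implicit step is solved with a Newton chord iteration that reuses one iteration matrix M - h * J per step. The fixed step, the finite-difference Jacobian and the names fd_jacobian and backward_euler_mass are simplifications of this sketch, not features of jax_bdf_integrate.

```python
import numpy as np


def fd_jacobian(func, y, t, args, eps=1e-8):
    """Forward-difference approximation of d func / d y at (y, t)."""
    f0 = func(y, t, *args)
    jac = np.empty((y.size, y.size))
    for j in range(y.size):
        step = eps * max(1.0, abs(y[j]))
        y_pert = y.copy()
        y_pert[j] += step
        jac[:, j] = (func(y_pert, t, *args) - f0) / step
    return jac


def backward_euler_mass(func, y0, t_eval, *args, mass=None, tol=1e-10, max_iter=20):
    """Integrate diag(mass) * dy/dt = func(y, t, *args) with fixed-step backward Euler.

    Each step solves M (y - y_prev) - h * func(y, t_new, *args) = 0 with a Newton
    chord iteration: the matrix M - h J is formed once per step and then reused.
    Returns an array of shape (n, m), one column per time in t_eval.
    """
    y = np.asarray(y0, dtype=float)
    M = np.diag(np.ones(y.size) if mass is None else np.asarray(mass, dtype=float))
    out = np.empty((y.size, len(t_eval)))
    out[:, 0] = y
    for i in range(1, len(t_eval)):
        t_new = t_eval[i]
        h = t_eval[i] - t_eval[i - 1]
        y_new = y.copy()  # predictor: the previous value
        iter_matrix = M - h * fd_jacobian(func, y_new, t_new, args)  # frozen (chord) matrix
        for _ in range(max_iter):  # a real solver would shrink h if this failed to converge
            residual = M @ (y_new - y) - h * func(y_new, t_new, *args)
            delta = np.linalg.solve(iter_matrix, -residual)
            y_new = y_new + delta
            if np.linalg.norm(delta) < tol:
                break
        y = y_new
        out[:, i] = y
    return out


def rhs(y, t, k):
    # Row 0 (mass 1): dy0/dt = -k * y0.  Row 1 (mass 0): algebraic constraint 0 = y1 - 2 * y0.
    return np.array([-k * y[0], y[1] - 2.0 * y[0]])


t_eval = np.linspace(0.0, 1.0, 101)
sol = backward_euler_mass(rhs, np.array([1.0, 2.0]), t_eval, 0.5, mass=np.array([1.0, 0.0]))
print(sol.shape, sol[0, -1], np.exp(-0.5))  # first-order method: expect agreement to about 1e-3
```

The extra argument k = 0.5 travels through *args exactly as in the func(y, t, *args) convention above, and the algebraic row is satisfied at every output time because its mass entry is zero.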

Returns

y – calculated state vector at each of the m time points

Return type

ndarray with shape (n, m)

References

[1] L. F. Shampine, M. W. Reichelt, “The MATLAB ODE Suite”, SIAM J. Sci. Comput., Vol. 18, No. 1, pp. 1-22, January 1997.

[2] G. D. Byrne, A. C. Hindmarsh, “A Polyalgorithm for the Numerical Solution of Ordinary Differential Equations”, ACM Transactions on Mathematical Software, Vol. 1, No. 1, pp. 71-96, March 1975.

[3] P. Virtanen, R. Gommers, T. E. Oliphant, M. Haberland, T. Reddy, D. Cournapeau, … & S. J. van der Walt, “SciPy 1.0: Fundamental Algorithms for Scientific Computing in Python”, Nature Methods, Vol. 17, No. 3, pp. 261-272, 2020.