#
# Solver class using JAX-compiled ODE/DAE integrators (RK45 via odeint, or BDF)
#
import numpy as onp
import asyncio
import pybamm
if pybamm.has_jax():
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
class JaxSolver(pybamm.BaseSolver):
    """
    Solve a discretised model using a JAX compiled solver.

    **Note**: this solver will not work with models that have
    termination events or are not converted to jax format

    Raises
    ------
    RuntimeError
        if model has any termination events
    RuntimeError
        if `model.convert_to_format != 'jax'`

    Parameters
    ----------
    method: str, optional (see `jax.experimental.ode.odeint` for details)
        * 'BDF' (default) uses custom jax_bdf_integrate (see `jax_bdf_integrate.py`
          for details)
        * 'RK45' uses jax.experimental.ode.odeint
    root_method: str, optional
        Method to use to calculate consistent initial conditions. By default, this uses
        the newton chord method internal to the jax bdf solver, otherwise choose from
        the set of default options defined in docs for pybamm.BaseSolver
    rtol : float, optional
        The relative tolerance for the solver (default is 1e-6).
    atol : float, optional
        The absolute tolerance for the solver (default is 1e-6).
    extrap_tol : float, optional
        The tolerance to assert whether extrapolation occurs or not. The default
        is None, which is passed unchanged to :class:`pybamm.BaseSolver`.
    extra_options : dict, optional
        Any options to pass to the solver.
        Please consult `JAX documentation
        <https://github.com/jax-ml/jax/blob/master/jax/experimental/ode.py>`_
        for details.
    """

    def __init__(
        self,
        method="BDF",
        root_method=None,
        rtol=1e-6,
        atol=1e-6,
        extrap_tol=None,
        extra_options=None,
    ):
        if not pybamm.has_jax():
            raise ModuleNotFoundError(
                "Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
            )

        # note: bdf solver itself calculates consistent initial conditions so can set
        # root_method to none, allow user to override this behavior
        super().__init__(
            method, rtol, atol, root_method=root_method, extrap_tol=extrap_tol
        )
        method_options = ["RK45", "BDF"]
        if method not in method_options:
            raise ValueError(f"method must be one of {method_options}")
        self._ode_solver = method == "RK45"
        self.extra_options = extra_options or {}
        self.name = f"JAX solver ({method})"
        # model -> jitted solve function (see create_solve)
        self._cached_solves = dict()
        # model -> the t_eval its cached solve was built for; the solve closes
        # over t_eval, so a different t_eval requires a rebuild
        self._cached_t_evals = dict()
        pybamm.citations.register("jax2018")

    def _get_cached_solve(self, model, t_eval):
        """
        Return the cached compiled solve for `model`, building it if needed.

        The function returned by :meth:`create_solve` bakes `t_eval` into its
        closure, so a cached entry built for a different `t_eval` is rebuilt
        rather than silently reused.
        """
        # NOTE(review): the compiled solve also captures `model.y0` and the mass
        # matrix when it is created; later changes to those are not detected
        # here - confirm whether callers rely on that.
        t_eval_arr = onp.asarray(t_eval)
        built_for = self._cached_t_evals.get(model)
        stale = built_for is not None and not onp.array_equal(built_for, t_eval_arr)
        if model not in self._cached_solves or stale:
            self._cached_solves[model] = self.create_solve(model, t_eval)
            self._cached_t_evals[model] = t_eval_arr.copy()
        return self._cached_solves[model]

    def get_solve(self, model, t_eval):
        """
        Return a compiled JAX function that solves an ode model with input arguments.

        The function is cached per model and rebuilt if `t_eval` differs from the
        one it was compiled for.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            The model whose solution to calculate.
        t_eval : :class:`numpy.array`, size (k,)
            The times at which to compute the solution

        Returns
        -------
        function
            A function with signature `f(inputs)`, where inputs are a dict containing
            any input parameters to pass to the model when solving

        Raises
        ------
        RuntimeError
            If no solve is cached for `model` and the model has not been set up
            (i.e. `solver.solve(model)` has not been run yet).
        """
        if model not in self._cached_solves and model not in self._model_set_up:
            raise RuntimeError(
                "Model is not set up for solving, run `solver.solve(model)` first"
            )
        return self._get_cached_solve(model, t_eval)

    def create_solve(self, model, t_eval):
        """
        Return a compiled JAX function that solves an ode model with input arguments.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            The model whose solution to calculate.
        t_eval : :class:`numpy.array`, size (k,)
            The times at which to compute the solution

        Returns
        -------
        function
            A function with signature `f(inputs)`, where inputs are a dict containing
            any input parameters to pass to the model when solving

        Raises
        ------
        RuntimeError
            If the model is not converted to JAX format or has terminate events.
        """
        if model.convert_to_format != "jax":
            raise RuntimeError(
                "Model must be converted to JAX to use this solver"
                " (i.e. `model.convert_to_format = 'jax')"
            )

        if model.terminate_events_eval:
            raise RuntimeError(
                "Terminate events not supported for this solver."
                " Model has the following events:"
                f" {model.events}.\nYou can remove events using `model.events = []`."
                " It might be useful to first solve the model using a"
                " different solver to obtain the time of the event, then"
                " re-solve using no events and a fixed"
                " end-time"
            )

        # Initial conditions, flattened to a 1D array
        y0 = jnp.array(model.y0).reshape(-1)
        mass = None
        if self.method == "BDF":
            # the BDF integrator takes a dense mass matrix
            mass = model.mass_matrix.entries.toarray()

        def rhs_ode(y, t, inputs):
            # NOTE(review): returns a 1-tuple although y0 is a plain array; this
            # relies on odeint flattening the returned pytree - confirm before
            # changing.
            return (model.rhs_eval(t, y, inputs),)

        def rhs_dae(y, t, inputs):
            # stack differential and algebraic residuals into one vector
            return jnp.concatenate(
                [model.rhs_eval(t, y, inputs), model.algebraic_eval(t, y, inputs)]
            )

        def solve_model_rk45(inputs):
            y = odeint(
                rhs_ode,
                y0,
                t_eval,
                inputs,
                rtol=self.rtol,
                atol=self.atol,
                **self.extra_options,
            )
            # odeint stacks one state per time along the leading axis;
            # transpose so rows are states and columns are times
            return jnp.transpose(y)

        def solve_model_bdf(inputs):
            y = pybamm.jax_bdf_integrate(
                rhs_dae,
                y0,
                t_eval,
                inputs,
                rtol=self.rtol,
                atol=self.atol,
                mass=mass,
                **self.extra_options,
            )
            # presumably (n_times, n_states) like odeint; transposed the same way
            return jnp.transpose(y)

        if self.method == "RK45":
            return jax.jit(solve_model_rk45)
        else:
            return jax.jit(solve_model_bdf)

    @property
    def supports_parallel_solve(self):
        """True: `_integrate` accepts a list of input dicts in a single call."""
        return True

    @property
    def requires_explicit_sensitivities(self):
        """False: this solver does not require explicit sensitivities."""
        return False

    def _integrate(self, model, t_eval, inputs=None, t_interp=None):
        """
        Solve a model defined by dydt with initial conditions y0.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            The model whose solution to calculate.
        t_eval : :class:`numpy.array`, size (k,)
            The times at which to compute the solution
        inputs : dict, list[dict], optional
            Any input parameters to pass to the model when solving. A single dict
            is treated as a one-element list; None is treated as one empty dict.
        t_interp : optional
            Not used by this solver; accepted for interface compatibility.

        Returns
        -------
        list of `pybamm.Solution`
            One solution per input dict, in the same order as `inputs`, each
            containing the times and values of the solution, as well as
            various diagnostic messages.
        """
        if inputs is None:
            inputs = [{}]
        elif isinstance(inputs, dict):
            inputs = [inputs]

        timer = pybamm.Timer()
        solve = self._get_cached_solve(model, t_eval)

        platform = jax.default_backend().casefold()
        if len(inputs) <= 1 or platform.startswith("cpu"):
            # Sequential calls, in input order. Wrapping them in asyncio adds no
            # concurrency (nothing is awaited) and asyncio.run() cannot be used
            # from a thread that already runs an event loop (e.g. Jupyter).
            y = [solve(inputs_v) for inputs_v in inputs]
        elif platform.startswith(("gpu", "tpu", "metal")):
            # gpu execution runs faster when parallelised with vmap
            # (see also comment below regarding single-program multiple-data
            # execution (SPMD) using pmap on multiple XLAs)
            # convert inputs (list of dicts) to a dict of arrays for vmap
            inputs_v = {
                key: jnp.array([dic[key] for dic in inputs]) for key in inputs[0]
            }
            y = list(jax.vmap(solve)(inputs_v))
        else:
            # Unknown platform, use serial execution as fallback
            print(
                f'Unknown platform requested: "{platform}", '
                "falling back to serial execution"
            )
            y = [solve(inputs_v) for inputs_v in inputs]

        # Single-program multiple-data execution using pmap across multiple XLAs
        # is deliberately not used: it produces bus errors for even
        # moderate-sized models. It is suspected that this is due to either a
        # bug in JAX, insufficient sparse matrix support in JAX resulting in
        # high memory usage, or a bug in the BDF solver. This issue on GitHub
        # appears related: https://github.com/jax-ml/jax/discussions/13930
        #
        # # Split input list based on the number of available xla devices
        # device_count = jax.local_device_count()
        # inputs_listoflists = [inputs[x:x + device_count]
        #                       for x in range(0, len(inputs), device_count)]
        # y = []
        # for inputs_list in inputs_listoflists:
        #     inputs_v = {key: jnp.array([dic[key] for dic in inputs_list])
        #                 for key in inputs_list[0]}
        #     y.extend(jax.pmap(self._cached_solves[model])(inputs_v))

        # Convert to a normal numpy array *before* stopping the timer: JAX calls
        # can return before the computation has finished, and the host copy
        # waits for the results.
        y = onp.array(y)
        integration_time = timer.time()

        termination = "final time"
        t_event = None
        y_event = onp.array(None)

        # Extract solutions from y with their associated input dicts
        solutions = []
        for k, inputs_dict in enumerate(inputs):
            sol = pybamm.Solution(
                t_eval,
                y[k],
                model,
                inputs_dict,
                t_event,
                y_event,
                termination,
            )
            sol.integration_time = integration_time
            solutions.append(sol)

        return solutions