IDAKLU-JAX Interface
- class pybamm.IDAKLUJax(solver, model, t_eval, output_variables=None, calculate_sensitivities=True, t_interp=None)
JAX wrapper for the IDAKLU solver.
Objects of this class should be created via an IDAKLUSolver object.
Log information is available for this module via the named ‘pybamm.solvers.idaklu_jax’ logger.
- Parameters:
solver (pybamm.IDAKLUSolver) – The IDAKLU solver object to be wrapped
- get_jaxpr()
Returns a JAX expression representing the IDAKLU-wrapped solver object
- Returns:
- A JAX expression with the following call signature:
f(t, inputs=None)
- where:
- t : float | np.ndarray
Time sample or vector of time samples
- inputs : dict, optional
Dictionary of input values, e.g. {"Current function [A]": 0.222, "Separator porosity": 0.3}
- Return type:
Callable
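The callable returned by get_jaxpr() is meant to compose with ordinary JAX transformations. The minimal sketch below runs without building a battery model. It uses a hypothetical stand-in `f` that has the documented `f(t, inputs=None)` signature and is passed through `jax.jit`. The toy formula, the default inputs, and the output layout are assumptions of the sketch and do not describe IDAKLU's exact return shape. The assumed layout is one column per output variable, here "Voltage [V]" and "Current [A]".

```python
import jax
import jax.numpy as jnp

# Hypothetical default inputs, in the dictionary format documented above
DEFAULT_INPUTS = {"Current function [A]": 0.222}


def f(t, inputs=None):
    """Toy stand-in with the documented f(t, inputs=None) signature (not IDAKLU).

    Assumed layout: shape (2,) for a scalar t, shape (len(t), 2) for a vector t,
    with columns ["Voltage [V]", "Current [A]"].
    """
    inputs = DEFAULT_INPUTS if inputs is None else inputs
    current = inputs["Current function [A]"]
    t = jnp.asarray(t)
    voltage = 4.2 - 0.05 * current - 1e-4 * current * t  # toy formula only
    current_col = current * jnp.ones_like(t)
    return jnp.stack([voltage, current_col], axis=-1)


# Like any JAX function, the expression can be JIT-compiled
f_jit = jax.jit(f)

t_eval = jnp.linspace(0.0, 3600.0, 5)
out_vector = f_jit(t_eval, {"Current function [A]": 0.5})  # vector of time samples
out_scalar = f_jit(0.0)  # single time sample, default inputs
print(out_vector.shape, out_scalar.shape)
```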
- get_var(*args)
Helper function to extract a single variable
Isolates a single variable from the model output. Can be called on a JAX expression (which returns a JAX expression), or on a numeric (np.ndarray) object (which returns a slice of the output).
Example call using default JAX expression, returns a JAX expression:
f = idaklu_jax.get_var("Voltage [V]")
data = f(t, inputs=None)
Example call using a custom function, returns a JAX expression:
f = idaklu_jax.get_var(jax.jit(f), "Voltage [V]")
data = f(t, inputs=None)
Example call to slice a matrix, returns an np.array:
data = idaklu_jax.get_var(
    jax.jacfwd(f, argnums=1)(t_eval, inputs)['Current function [A]'],
    'Voltage [V]'
)
- Parameters:
f (Callable | np.ndarray, optional) – Expression or array from which to extract the target variable
varname (str) – The name of the variable to extract
- Returns:
Callable – If called with a JAX expression, returns a JAX expression with the following call signature:
f(t, inputs=None)
- where:
- t : float | np.ndarray
Time sample or vector of time samples
- inputs : dict, optional
Dictionary of input values, e.g. {"Current function [A]": 0.222, "Separator porosity": 0.3}
np.ndarray – If called with a numeric (np.ndarray) object, returns a slice of the output corresponding to the target variable.
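To make the two modes of get_var concrete, the sketch below reimplements the idea on a toy stand-in. A callable goes in and a callable that selects one output column comes out, or an array goes in and a NumPy slice comes out. `get_var_sketch`, the stand-in `f`, and the column order in `OUTPUT_VARIABLES` are hypothetical and are not PyBaMM's implementation. The sketch also shows why isolating one variable is useful. Once the output is scalar, `jax.grad` applies directly. The sketch also shows that `jax.jacfwd` with respect to the inputs dictionary returns a dictionary keyed by input name.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumed column order of the stand-in's outputs (hypothetical)
OUTPUT_VARIABLES = ["Voltage [V]", "Current [A]"]


def f(t, inputs=None):
    """Toy stand-in with the f(t, inputs=None) signature (not IDAKLU)."""
    inputs = {"Current function [A]": 0.222} if inputs is None else inputs
    current = inputs["Current function [A]"]
    t = jnp.asarray(t)
    voltage = 4.2 - 0.05 * current - 1e-4 * current * t
    return jnp.stack([voltage, current * jnp.ones_like(t)], axis=-1)


def get_var_sketch(f_or_array, varname):
    """Hypothetical helper mimicking the two modes described for get_var."""
    idx = OUTPUT_VARIABLES.index(varname)
    if callable(f_or_array):
        # Callable in, callable out: select one column of the wrapped output
        return lambda t, inputs=None: f_or_array(t, inputs)[..., idx]
    # Array in, array out: slice the column belonging to varname
    return np.asarray(f_or_array)[..., idx]


inputs = {"Current function [A]": 0.5}
t_eval = jnp.linspace(0.0, 3600.0, 4)

# Callable mode: scalar output for a scalar t, so jax.grad applies directly
voltage_fn = get_var_sketch(f, "Voltage [V]")
dv_dinputs = jax.grad(voltage_fn, argnums=1)(100.0, inputs)

# Array mode: slice a forward-mode Jacobian (jacfwd w.r.t. a dict returns a dict)
jac = jax.jacfwd(f, argnums=1)(t_eval, inputs)["Current function [A]"]
dv_dcurrent = get_var_sketch(jac, "Voltage [V]")
```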
- get_vars(*args)
Helper function to extract a list of variables
Isolates a list of variables from the model output. Can be called on a JAX expression (which returns a JAX expression), or on a numeric (np.ndarray) object (which returns a slice of the output).
Example call using default JAX expression, returns a JAX expression:
f = idaklu_jax.get_vars(["Voltage [V]", "Current [A]"])
data = f(t, inputs=None)
Example call using a custom function, returns a JAX expression:
f = idaklu_jax.get_vars(jax.jit(f), ["Voltage [V]", "Current [A]"])
data = f(t, inputs=None)
Example call to slice a matrix, returns an np.array:
data = idaklu_jax.get_vars(
    jax.jacfwd(f, argnums=1)(t_eval, inputs)['Current function [A]'],
    ["Voltage [V]", "Current [A]"]
)
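- Parameters:
f (Callable | np.ndarray, optional) – Expression or array from which to extract the target variables
varnames (list of str) – The names of the variables to extract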
- Returns:
Callable – If called with a JAX expression, returns a JAX expression with the following call signature:
f(t, inputs=None)
- where:
- t : float | np.ndarray
Time sample or vector of time samples
- inputs : dict, optional
Dictionary of input values, e.g. {"Current function [A]": 0.222, "Separator porosity": 0.3}
np.ndarray – If called with a numeric (np.ndarray) object, returns a slice of the output corresponding to the target variables.
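get_vars generalises get_var from one column to several. The hypothetical sketch below selects a list of columns in the requested order, either from a wrapped callable or from a NumPy array. `get_vars_sketch`, the stand-in `f`, and the three-column layout in `OUTPUT_VARIABLES` are illustrative assumptions, not PyBaMM code.

```python
import jax.numpy as jnp
import numpy as np

# Assumed column order of the stand-in's outputs (hypothetical)
OUTPUT_VARIABLES = ["Voltage [V]", "Current [A]", "Time [s]"]


def f(t, inputs=None):
    """Toy stand-in with the f(t, inputs=None) signature (not IDAKLU)."""
    inputs = {"Current function [A]": 0.222} if inputs is None else inputs
    current = inputs["Current function [A]"]
    t = jnp.asarray(t)
    voltage = 4.2 - 0.05 * current - 1e-4 * current * t
    return jnp.stack([voltage, current * jnp.ones_like(t), t], axis=-1)


def get_vars_sketch(f_or_array, varnames):
    """Hypothetical helper: keep the columns for varnames, in that order."""
    idxs = np.array([OUTPUT_VARIABLES.index(name) for name in varnames])
    if callable(f_or_array):
        # Callable in, callable out
        return lambda t, inputs=None: f_or_array(t, inputs)[..., idxs]
    # Array in, array out
    return np.asarray(f_or_array)[..., idxs]


inputs = {"Current function [A]": 0.5}
t_eval = jnp.linspace(0.0, 3600.0, 3)

sub_fn = get_vars_sketch(f, ["Current [A]", "Voltage [V]"])
sub = sub_fn(t_eval, inputs)  # columns: Current, then Voltage
sub_np = get_vars_sketch(np.asarray(f(t_eval, inputs)), ["Time [s]"])
```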
- jax_grad(t: ndarray = None, inputs: dict | None = None, output_variables: list[str] | None = None)
Helper function to compute the gradient of a jaxified expression
Returns a numeric (np.ndarray) object (not a JAX expression). Parameters are inferred from the base object, but can be overridden.
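As a sketch of what a gradient helper of this kind can do, the example below handles each output variable in three steps. It takes `jax.grad` of that variable with respect to the inputs dictionary, applies `jax.vmap` over the time samples, and converts the results to NumPy arrays. The helper name `grad_sketch`, the toy stand-in `f`, and the dict-of-arrays return layout are assumptions of the sketch. This is not PyBaMM's `jax_grad` implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumed column order of the stand-in's outputs (hypothetical)
OUTPUT_VARIABLES = ["Voltage [V]", "Current [A]"]


def f(t, inputs=None):
    """Toy stand-in with the f(t, inputs=None) signature (not IDAKLU)."""
    inputs = {"Current function [A]": 0.222} if inputs is None else inputs
    current = inputs["Current function [A]"]
    t = jnp.asarray(t)
    voltage = 4.2 - 0.05 * current - 1e-4 * current * t
    return jnp.stack([voltage, current * jnp.ones_like(t)], axis=-1)


def grad_sketch(fn, t, inputs, output_variables):
    """Hypothetical helper: d(output)/d(inputs) at each time sample, as NumPy."""
    grads = {}
    for name in output_variables:
        idx = OUTPUT_VARIABLES.index(name)

        def scalar_fn(t_k, inp, idx=idx):
            # Scalar value of one output variable at one time sample
            return fn(t_k, inp)[idx]

        # Gradient w.r.t. the inputs dict, mapped over time (inputs not mapped)
        per_t = jax.vmap(jax.grad(scalar_fn, argnums=1), in_axes=(0, None))(t, inputs)
        grads[name] = {key: np.asarray(val) for key, val in per_t.items()}
    return grads


t_eval = jnp.linspace(0.0, 3600.0, 4)
inputs = {"Current function [A]": 0.5}
grads = grad_sketch(f, t_eval, inputs, ["Voltage [V]", "Current [A]"])
```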
- jax_value(t: ndarray = None, inputs: dict | None = None, output_variables: list[str] | None = None)
Helper function to compute the value of a jaxified expression
Returns a numeric (np.ndarray) object (not a JAX expression). Parameters are inferred from the base object, but can be overridden.
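A value helper is the forward-evaluation counterpart of the gradient helper. The hypothetical sketch below evaluates a stand-in expression at each time sample with `jax.vmap`. It then converts the JAX array to a plain `np.ndarray` and keeps only the requested output columns. The vmap is redundant for this already-vectorised toy `f`, but it shows the per-sample pattern. `value_sketch`, `f`, and the column layout are assumptions, not PyBaMM's `jax_value`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumed column order of the stand-in's outputs (hypothetical)
OUTPUT_VARIABLES = ["Voltage [V]", "Current [A]"]


def f(t, inputs=None):
    """Toy stand-in with the f(t, inputs=None) signature (not IDAKLU)."""
    inputs = {"Current function [A]": 0.222} if inputs is None else inputs
    current = inputs["Current function [A]"]
    t = jnp.asarray(t)
    voltage = 4.2 - 0.05 * current - 1e-4 * current * t
    return jnp.stack([voltage, current * jnp.ones_like(t)], axis=-1)


def value_sketch(fn, t, inputs, output_variables):
    """Hypothetical helper: evaluate fn and return a plain NumPy array."""
    idxs = np.array([OUTPUT_VARIABLES.index(name) for name in output_variables])
    # One evaluation per time sample; the inputs dict is shared (not mapped)
    values = jax.vmap(fn, in_axes=(0, None))(t, inputs)
    # Convert to NumPy and keep only the requested columns
    return np.asarray(values)[:, idxs]


t_eval = jnp.linspace(0.0, 3600.0, 4)
voltage = value_sketch(f, t_eval, {"Current function [A]": 0.5}, ["Voltage [V]"])
```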
- jaxify(model, t_eval, *, output_variables=None, calculate_sensitivities=True, t_interp=None)
JAXify the model and solver
Creates a JAX expression representing the IDAKLU-wrapped solver object.
- Parameters:
model (pybamm.BaseModel) – The model to be solved
t_eval (numeric type, optional) – The times at which to stop the integration due to a discontinuity in time.
output_variables (list of str, optional) – The variables to be returned. If None, the variables in the model are used.
calculate_sensitivities (bool, optional) – Whether to calculate sensitivities. Default is True.
t_interp (None, list or ndarray, optional) – The times (in seconds) at which to interpolate the solution. Defaults to None. Only valid for solvers that support intra-solve interpolation (IDAKLUSolver).
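A solver call that runs outside JAX can only be differentiated if derivative information is supplied with it. Sensitivities provide that information. The sketch below illustrates the general idea with standard JAX tools:

- A hypothetical NumPy "solver" for y' = -k·y returns both y and dy/dk.
- `jax.pure_callback` wraps that solver call.
- `jax.custom_jvp` turns the returned sensitivity into a derivative rule.

`external_solve` and `jaxify_sketch` are hypothetical names. This is not how PyBaMM implements `jaxify`; it only shows why an externally computed solve becomes differentiable once sensitivities are available.

```python
import jax
import jax.numpy as jnp
import numpy as np


def external_solve(t_eval, k):
    """Hypothetical non-JAX 'solver' for y' = -k*y with y(0) = 1.

    Returns the solution and its sensitivity dy/dk, standing in for a solver
    run with sensitivities enabled.
    """
    k = float(k)
    y = np.exp(-k * t_eval)
    dy_dk = -t_eval * y
    return y.astype(np.float32), dy_dk.astype(np.float32)


def jaxify_sketch(t_eval):
    """Wrap external_solve as a differentiable JAX function of k (hypothetical)."""
    t_eval = np.asarray(t_eval, dtype=np.float32)
    out = jax.ShapeDtypeStruct(t_eval.shape, jnp.float32)

    def solve_on_host(k):
        # Run the external solver outside JAX; it returns (y, dy/dk)
        k = jnp.asarray(k, dtype=jnp.float32)
        return jax.pure_callback(
            lambda k_host: external_solve(t_eval, k_host), (out, out), k
        )

    @jax.custom_jvp
    def f(k):
        y, _ = solve_on_host(k)
        return y

    @f.defjvp
    def f_jvp(primals, tangents):
        (k,), (k_dot,) = primals, tangents
        y, dy_dk = solve_on_host(k)
        # Directional derivative assembled from the solver's own sensitivities
        return y, dy_dk * k_dot

    return f


t_eval = np.linspace(0.0, 2.0, 5)
f = jaxify_sketch(t_eval)
y = f(0.5)
# Forward mode: push a unit tangent through the custom rule
_, dy_dk = jax.jvp(f, (jnp.float32(0.5),), (jnp.float32(1.0),))
# Reverse mode also works, because the custom JVP rule is linear in k_dot
dsum_dk = jax.grad(lambda k: jnp.sum(f(k)))(0.5)
```

In the sketch, every call computes sensitivities, even when only values are needed. That mirrors the trade-off behind a flag such as `calculate_sensitivities`: sensitivities add cost but are what make gradients available.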