Source code for pybamm.codegen.compilation

import hashlib
import os
import re
import subprocess  # nosec B404 - compiler validated against _ALLOWED_COMPILERS
import sys
import tempfile
import time
import uuid

import casadi

from pybamm import logger

# In-process memo: bundle hash -> the ``casadi.external`` wrappers built for
# that bundle, one per non-External input, in input order. Compiling a lone
# Function is simply a bundle of size one, so it shares this same cache.
_CACHE: dict[str, list[casadi.Function]] = {}

# Minimum age (seconds) before the stale sweep may delete a build artifact.
# Anything younger might still belong to another process's in-flight compile.
_STALE_TMP_AGE_S = 60 * 60

# Matches the per-attempt token ``.<pid>.<32 lowercase hex chars>`` embedded
# in temporary build files, e.g. ``<stem>.<pid>.<uuid>.c`` or
# ``<stem>.<pid>.<uuid><ext>.tmp``.
_PER_ATTEMPT_TOKEN = re.compile(r"\.\d+\.[0-9a-f]{32}(?:\.|$)")

# Name prefix shared by the default cache directory and by every build
# artifact; the stale sweep ignores any file that does not carry it.
_TMP_FILE_PREFIX = "pybamm_"

# A top-level ``int NAME(const casadi_real** ...`` in generated C declares an
# External sub-Function ``NAME``. That is fine when NAME is also defined in
# the same translation unit; otherwise an inner Function was wrapped as an
# External before being fed into a composite, and it cannot be linked.
_EXTERN_DECL = re.compile(
    r"^\s*int\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(\s*const\s+casadi_real\s*\*\*",
    flags=re.MULTILINE,
)

# Cache directories that have already been swept during this process.
_swept_dirs: set[str] = set()

# Basenames of compiler executables this module is willing to run.
_ALLOWED_COMPILERS = frozenset(("cc", "clang", "clang++", "g++", "gcc"))


def _default_cache_dir() -> str:
    """Return (and create if needed) the on-disk cache directory.

    A non-empty ``$PYBAMM_CASADI_AOT_CACHE`` takes precedence; otherwise a
    ``<prefix>casadi_aot`` folder under ``tempfile.gettempdir()`` is used.
    """
    override = os.environ.get("PYBAMM_CASADI_AOT_CACHE")
    cache_dir = (
        override
        if override
        else os.path.join(tempfile.gettempdir(), f"{_TMP_FILE_PREFIX}casadi_aot")
    )
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


def _shared_ext() -> str:
    """Return the shared-library filename suffix for this platform, dot included."""
    suffix_by_platform = {"darwin": ".dylib", "win32": ".dll"}
    return suffix_by_platform.get(sys.platform, ".so")


def aot_compile(fn_or_fns, **kwargs):
    """Compile casadi ``Function`` objects ahead of time into one shared library.

    Accepts a single ``casadi.Function`` (a single Function is returned) or a
    list/tuple of Functions (a list is returned, one entry per input, in the
    same order). Either way, every input is lowered through one
    ``CodeGenerator`` and one compiler invocation; a lone Function is just a
    bundle of size one.

    Call this on the *outermost* Functions a solver hands off (e.g.
    ``rhs_algebraic``, ``jac_times_cjmass``, ``rootfn``, output-variable
    evaluators). Keep intermediate Functions as MX/SX so
    ``casadi.CodeGenerator`` can inline them into the same translation unit;
    wrapping them as Externals first forces cross-library dispatch and leaves
    ``extern`` declarations that cannot be resolved.

    Compiled wrappers are cached in-process (keyed by a hash of the
    Functions' serialised forms) and on disk under
    ``$PYBAMM_CASADI_AOT_CACHE`` (default ``$TMPDIR/pybamm_casadi_aot``).
    Inputs whose class is already ``External`` are passed through unchanged.
    If anything goes wrong, a warning is logged and the original Function(s)
    are returned instead.

    Parameters
    ----------
    fn_or_fns : casadi.Function or list of casadi.Function
        Function(s) to compile.
    **kwargs
        Optional ``cache_dir``, ``compiler`` and ``flags`` overrides.

    Returns
    -------
    casadi.Function or list of casadi.Function
        Same shape as ``fn_or_fns``.
    """
    single = isinstance(fn_or_fns, casadi.Function)
    if single:
        fns = [fn_or_fns]
    else:
        fns = list(fn_or_fns)

    try:
        compiled = _aot_compile(fns, **kwargs)
    except Exception as e:
        # Deliberate best-effort boundary: compilation is an optimisation, so
        # any failure degrades to the uncompiled Functions.
        names = ", ".join(fn.name() for fn in fns)
        logger.warning(f"Failed to compile [{names}] with error: {e}")
        compiled = list(fns)

    if single:
        return compiled[0]
    return compiled
def _aot_compile(
    fns: list[casadi.Function],
    *,
    cache_dir: str | None = None,
    compiler: str | None = None,
    flags: tuple[str, ...] | None = None,
) -> list[casadi.Function]:
    """Compile the non-External members of ``fns`` into one shared library.

    Parameters
    ----------
    fns : list of casadi.Function
        Functions to lower. Entries whose ``class_name()`` is ``"External"``
        are passed through untouched; all others go into one translation unit.
    cache_dir : str, optional
        Directory for build artifacts; defaults to ``_default_cache_dir()``.
    compiler : str, optional
        C compiler executable (default ``"gcc"``); its basename must be in
        ``_ALLOWED_COMPILERS``.
    flags : tuple of str, optional
        Compiler flags; defaults to ``("-O3", "-march=native", "-fPIC")``.

    Returns
    -------
    list of casadi.Function
        Same length and order as ``fns``. Compiled entries are replaced by
        ``casadi.external`` wrappers; External entries are returned as-is.

    Raises
    ------
    ValueError
        If the compiler's basename is not in ``_ALLOWED_COMPILERS``.
    RuntimeError
        If the generated C references External sub-Functions that are not
        part of the bundle.
    subprocess.CalledProcessError
        If the compiler exits with a non-zero status.
    """
    # Pass-through Externals; compile the rest together in one TU.
    result: list[casadi.Function] = list(fns)
    indices_to_compile = [
        i for i, fn in enumerate(fns) if fn.class_name() != "External"
    ]
    if not indices_to_compile:
        return result
    fns_to_compile = [fns[idx] for idx in indices_to_compile]

    # Resolve build settings before hashing: they determine the produced
    # binary, so they must be part of the cache key.
    if compiler is None:
        compiler = "gcc"
    if os.path.basename(compiler) not in _ALLOWED_COMPILERS:
        raise ValueError(
            f"Compiler '{compiler}' not in allowed list: {sorted(_ALLOWED_COMPILERS)}"
        )
    flags = ("-O3", "-march=native", "-fPIC") if flags is None else tuple(flags)

    key = _bundle_key(fns_to_compile, compiler, flags)
    cached = _CACHE.get(key)
    if cached is not None:
        for idx, ext_fn in zip(indices_to_compile, cached, strict=True):
            result[idx] = ext_fn
        return result

    cdir = cache_dir or _default_cache_dir()
    _maybe_sweep_stale(cdir)

    # Single-fn bundles get named after the fn for readability; multi-fn
    # bundles are hash-only since the member list isn't knowable from the
    # filename anyway.
    label = fns_to_compile[0].name() if len(fns_to_compile) == 1 else "bundle"
    stem = f"{_TMP_FILE_PREFIX}{label}_{key}"
    ext = _shared_ext()
    sofile = os.path.join(cdir, stem + ext)
    if not os.path.exists(sofile):
        c_source = _generate_c_source(stem, fns_to_compile)
        _compile_shared_library(c_source, cdir, stem, ext, compiler, flags)

    ext_fns: list[casadi.Function] = []
    for idx, fn in zip(indices_to_compile, fns_to_compile, strict=True):
        ext_fn = casadi.external(fn.name(), sofile)
        result[idx] = ext_fn
        ext_fns.append(ext_fn)
    _CACHE[key] = ext_fns
    return result


def _bundle_key(
    fns: list[casadi.Function], compiler: str, flags: tuple[str, ...]
) -> str:
    """Return a 16-hex-char key for ``fns`` built with ``compiler``/``flags``."""
    hasher = hashlib.sha1(usedforsecurity=False)
    # Ordered: the same Functions in a different order form a different bundle.
    for fn in fns:
        hasher.update(fn.name().encode())
        hasher.update(b"\0")
        hasher.update(fn.serialize().encode())
        hasher.update(b"\0")
    hasher.update(b"compiler=" + compiler.encode() + b"\0")
    for flag in flags:
        hasher.update(b"flag=" + flag.encode() + b"\0")
    return hasher.hexdigest()[:16]


def _generate_c_source(stem: str, fns: list[casadi.Function]) -> str:
    """Emit one C translation unit defining ``fns``; reject unlinkable externs."""
    gen = casadi.CodeGenerator(stem, {"with_header": False})
    for fn in fns:
        gen.add(fn)
    c_source = gen.dump()

    # Declarations for names defined in this TU are fine; anything else is an
    # External sub-Function that the shared library could not resolve.
    bundled = {fn.name() for fn in fns}
    externs = set(_EXTERN_DECL.findall(c_source)) - bundled
    if externs:
        raise RuntimeError(
            f"References to External sub-Function(s) {sorted(externs)} "
            "cannot be linked. aot_compile should only be called on "
            "top-level Functions; keep intermediate Functions as MX/SX."
        )
    return c_source


def _compile_shared_library(
    c_source: str,
    cdir: str,
    stem: str,
    ext: str,
    compiler: str,
    flags: tuple[str, ...],
) -> None:
    """Compile ``c_source`` and move the result to ``<cdir>/<stem><ext>``."""
    # Per-attempt temp paths so concurrent compiles of the same bundle can't
    # clobber each other, and so an interrupted build can be detected and
    # cleaned up later by ``_maybe_sweep_stale``.
    suffix = f".{os.getpid()}.{uuid.uuid4().hex}"
    tmp_cfile = os.path.join(cdir, stem + suffix + ".c")
    tmp_sofile = os.path.join(cdir, stem + suffix + ext + ".tmp")
    try:
        with open(tmp_cfile, "w", encoding="utf-8") as f:
            f.write(c_source)
        subprocess.run(  # nosec B603 B607 - compiler validated against allowlist
            [compiler, *flags, "-shared", tmp_cfile, "-o", tmp_sofile],
            check=True,
        )
        # os.replace is atomic on POSIX, so other readers never observe a
        # partially written library under the final name.
        os.replace(tmp_sofile, os.path.join(cdir, stem + ext))
        if os.environ.get("PYBAMM_CASADI_AOT_KEEP_C"):
            os.replace(tmp_cfile, os.path.join(cdir, stem + ".c"))
    finally:
        # Best-effort cleanup; files already moved above no longer exist.
        for path in (tmp_cfile, tmp_sofile):
            try:
                os.remove(path)
            except OSError:
                pass


def _maybe_sweep_stale(cdir: str) -> None:
    """Best-effort removal of leaked build artifacts in ``cdir``.

    Runs at most once per directory per process. Only names starting with
    ``_TMP_FILE_PREFIX`` whose mtime is at least ``_STALE_TMP_AGE_S`` seconds
    old are touched:

    * per-attempt ``.c`` / ``.tmp`` files left by an interrupted build;
    * other ``<stem>.c`` files (e.g. kept via ``PYBAMM_CASADI_AOT_KEEP_C``)
      whose ``<stem><ext>`` library is not present.

    Every ``OSError`` (listing, stat, unlink) is ignored.
    """
    if cdir in _swept_dirs:
        return
    _swept_dirs.add(cdir)
    try:
        entries = os.listdir(cdir)
    except OSError:
        return

    cutoff = time.time() - _STALE_TMP_AGE_S
    ext = _shared_ext()
    have_so = {
        n for n in entries if n.endswith(ext) and n.startswith(_TMP_FILE_PREFIX)
    }
    for name in entries:
        if not name.startswith(_TMP_FILE_PREFIX):
            continue
        path = os.path.join(cdir, name)
        try:
            # Too recent: may still be in use by a concurrent compile.
            if os.path.getmtime(path) > cutoff:
                continue
            is_per_attempt = bool(_PER_ATTEMPT_TOKEN.search(name))
            if is_per_attempt and name.endswith((".tmp", ".c")):
                os.remove(path)
                continue
            if name.endswith(".c") and not is_per_attempt:
                stem = name[: -len(".c")]
                if (stem + ext) not in have_so:
                    os.remove(path)
        except OSError:
            pass