import numpy as np
import pybamm
from pybamm.util import import_optional_dependency
[docs]
def plot_3d_cross_section(
solution: "pybamm.Solution",
variable: str,
t: float | None,
plane: str = "yz",
position: float = 0.5,
n_pts: int = 100,
ax=None,
show_plot: bool = True,
cmap: str = "inferno",
levels: int = 20,
use_offset: bool = False,
show_mesh: bool = False,
mesh_color: str = "white",
mesh_alpha: float = 0.4,
mesh_linewidth: float = 0.7,
mesh_tolerance: float | None = None,
**kwargs,
):
"""
Plots a high-quality 2D cross-section of a 3D variable from a PyBaMM solution,
with mesh overlay support.
Parameters
----------
solution : pybamm.Solution
The solution object containing the 3D variable.
variable : str, optional
The name of the 3D variable to plot (default: "Cell temperature [K]").
t : float, optional
The time at which to plot. If None, the last timestep is used.
plane : str, optional
The plane for the cross-section ('xy', 'yz', 'xz', 'rz').
position : float, optional
The relative position (0 to 1) along the third axis to take the slice.
n_pts : int, optional
The number of points for the plotting grid in each direction.
ax : matplotlib.axes.Axes, optional
The axes on which to draw the plot. If None, a new figure is created.
show_plot : bool, optional
Whether to display the plot.
cmap : str, optional
The colormap for the plot.
levels : int, optional
The number of contour levels for the plot.
use_offset : bool, optional
Parameter to control color bar format to use scientific notation (default: False).
show_mesh : bool, optional
Whether to overlay the calculated FEM mesh slice on the plot.
mesh_color : str, optional
Color of the mesh lines.
mesh_alpha : float, optional
Transparency of the mesh lines.
mesh_linewidth : float, optional
Width of the mesh lines.
**kwargs
Additional keyword arguments passed to matplotlib.contourf.
"""
plt = import_optional_dependency("matplotlib.pyplot")
model = solution.all_models[0]
if model.options.get("dimensionality") != 3:
raise TypeError("This function requires a 3D model solution.")
if t is None:
t = solution.t[-1]
var_obj = solution[variable]
mesh = var_obj.mesh
nodes = mesh.nodes
elements = mesh.elements
x_min, x_max = np.min(nodes[:, 0]), np.max(nodes[:, 0])
y_min, y_max = np.min(nodes[:, 1]), np.max(nodes[:, 1])
z_min, z_max = np.min(nodes[:, 2]), np.max(nodes[:, 2])
geometry = model.options.get("cell geometry", "pouch")
if mesh_tolerance is None:
domain_size = max(x_max - x_min, y_max - y_min, z_max - z_min)
mesh_tolerance = domain_size * 0.01 # 1% of domain size
fig = None
if ax is None:
if geometry == "cylindrical" and plane == "xy":
fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
else:
fig, ax = plt.subplots(figsize=(8, 6))
fig = fig or ax.get_figure()
title_suffix = f"at t={t:.1f}s"
def add_mesh_overlay(
ax, nodes, elements, plane, slice_coord_val, geometry, mesh_tolerance
):
"""Calculates and draws the true intersection of the mesh with the plane."""
if plane == "yz":
slice_axis_idx = 0 # slice on x
plot_axes_indices = [1, 2] # plot y,z
elif plane == "xz":
slice_axis_idx = 1 # slice on y
plot_axes_indices = [0, 2] # plot x,z
elif plane == "xy":
slice_axis_idx = 2 # slice on z
plot_axes_indices = [0, 1] # plot x,y
elif plane == "rz":
slice_axis_idx = 1 # slice on y
slice_coord_val = 0.0
mesh_segments = []
for element in elements:
intersection_points = []
edges = [
(element[0], element[1]),
(element[0], element[2]),
(element[0], element[3]),
(element[1], element[2]),
(element[1], element[3]),
(element[2], element[3]),
]
for p1_idx, p2_idx in edges:
p1, p2 = nodes[p1_idx], nodes[p2_idx]
c1, c2 = p1[slice_axis_idx], p2[slice_axis_idx]
# Check if edge crosses the slice plane
if abs(c1 - slice_coord_val) <= mesh_tolerance:
intersection_points.append(p1)
elif abs(c2 - slice_coord_val) <= mesh_tolerance:
intersection_points.append(p2)
elif (c1 < slice_coord_val < c2) or (c2 < slice_coord_val < c1):
# Edge crosses the plane
ratio = (slice_coord_val - c1) / (c2 - c1)
intersection_point = p1 + ratio * (p2 - p1)
intersection_points.append(intersection_point)
# Remove duplicate points
if len(intersection_points) >= 2:
unique_points = []
for pt in intersection_points:
is_duplicate = False
for existing_pt in unique_points:
if np.linalg.norm(pt - existing_pt) < mesh_tolerance:
is_duplicate = True
break
if not is_duplicate:
unique_points.append(pt)
if len(unique_points) >= 2:
poly_xyz = np.array(unique_points)
if geometry == "cylindrical" and plane == "xy":
x_coords, y_coords = poly_xyz[:, 0], poly_xyz[:, 1]
r_coords = np.sqrt(x_coords**2 + y_coords**2)
theta_coords = np.arctan2(y_coords, x_coords)
theta_coords[theta_coords < 0] += (
2 * np.pi
) # Ensure positive angles
plot_coords = np.column_stack([theta_coords, r_coords])
elif geometry == "cylindrical" and plane == "rz":
mask = poly_xyz[:, 0] >= 0
if np.sum(mask) >= 2:
filtered_points = poly_xyz[mask]
r_coords = np.sqrt(
filtered_points[:, 0] ** 2 + filtered_points[:, 1] ** 2
)
z_coords = filtered_points[:, 2]
plot_coords = np.column_stack([r_coords, z_coords])
else:
continue
else: # Cartesian
plot_coords = poly_xyz[:, plot_axes_indices]
if len(plot_coords) >= 2:
mesh_segments.append(plot_coords)
segments_plotted = 0
for segment in mesh_segments:
if len(segment) >= 2:
if len(segment) > 2:
# Sort points to form a proper polygon
centroid = np.mean(segment, axis=0)
angles = np.arctan2(
segment[:, 1] - centroid[1], segment[:, 0] - centroid[0]
)
sorted_segment = segment[np.argsort(angles)]
# Close the polygon
final_segment = np.vstack([sorted_segment, sorted_segment[0]])
else:
final_segment = segment
ax.plot(
final_segment[:, 0],
final_segment[:, 1],
color=mesh_color,
alpha=mesh_alpha,
linewidth=mesh_linewidth,
)
segments_plotted += 1
print(f"Plotted {segments_plotted} mesh segments")
x_label, y_label = "", ""
slice_coord_val = None
if geometry == "cylindrical":
r_coords = np.sqrt(nodes[:, 0] ** 2 + nodes[:, 1] ** 2)
r_min, r_max = np.min(r_coords), np.max(r_coords)
if plane == "xy":
slice_coord_val = z_min + (z_max - z_min) * position
r_grid = np.linspace(r_min, r_max, n_pts)
theta_grid = np.linspace(0, 2 * np.pi, n_pts)
R_mesh, Theta_mesh = np.meshgrid(r_grid, theta_grid)
X_eval, Y_eval = R_mesh * np.cos(Theta_mesh), R_mesh * np.sin(Theta_mesh)
Z_eval = np.full_like(X_eval, slice_coord_val)
data = var_obj(t=t, x=X_eval, y=Y_eval, z=Z_eval)
pcm = ax.contourf(
Theta_mesh, R_mesh, data, levels=levels, cmap=cmap, **kwargs
)
ax.set_ylim(r_min, r_max)
plot_title = f"T(r,θ) at z={slice_coord_val:.2f}m, {title_suffix}"
elif plane == "rz":
slice_coord_val = 0.0 # Slice at theta=0
r_grid = np.linspace(r_min, r_max, n_pts)
z_grid = np.linspace(z_min, z_max, n_pts)
R_mesh, Z_mesh = np.meshgrid(r_grid, z_grid)
X_eval, Y_eval = R_mesh, np.zeros_like(R_mesh)
Z_eval = Z_mesh
data = var_obj(t=t, x=X_eval, y=Y_eval, z=Z_eval)
pcm = ax.contourf(R_mesh, Z_mesh, data, levels=levels, cmap=cmap, **kwargs)
x_label, y_label = "Radius r [m]", "Height z [m]"
plot_title = f"T(r,z) Cross-Section, {title_suffix}"
else:
raise ValueError(f"Plane '{plane}' invalid for cylindrical geometry.")
else: # Cartesian geometry
if plane == "yz":
slice_coord_val = x_min + (x_max - x_min) * position
grid_1, grid_2 = (
np.linspace(y_min, y_max, n_pts),
np.linspace(z_min, z_max, n_pts),
)
Y_eval, Z_eval = np.meshgrid(grid_1, grid_2)
X_eval = np.full_like(Y_eval, slice_coord_val)
data = var_obj(
t=t, x=X_eval.ravel(), y=Y_eval.ravel(), z=Z_eval.ravel()
).reshape(X_eval.shape)
pcm = ax.contourf(Y_eval, Z_eval, data, levels=levels, cmap=cmap, **kwargs)
x_label, y_label = "y [m]", "z [m]"
plot_title = f"T(y,z) at x={slice_coord_val:.2f}m, {title_suffix}"
elif plane == "xz":
slice_coord_val = y_min + (y_max - y_min) * position
grid_1, grid_2 = (
np.linspace(x_min, x_max, n_pts),
np.linspace(z_min, z_max, n_pts),
)
X_eval, Z_eval = np.meshgrid(grid_1, grid_2)
Y_eval = np.full_like(X_eval, slice_coord_val)
data = var_obj(
t=t, x=X_eval.ravel(), y=Y_eval.ravel(), z=Z_eval.ravel()
).reshape(X_eval.shape)
pcm = ax.contourf(X_eval, Z_eval, data, levels=levels, cmap=cmap, **kwargs)
x_label, y_label = "x [m]", "z [m]"
plot_title = f"T(x,z) at y={slice_coord_val:.2f}m, {title_suffix}"
elif plane == "xy":
slice_coord_val = z_min + (z_max - z_min) * position
grid_1, grid_2 = (
np.linspace(x_min, x_max, n_pts),
np.linspace(y_min, y_max, n_pts),
)
X_eval, Y_eval = np.meshgrid(grid_1, grid_2)
Z_eval = np.full_like(X_eval, slice_coord_val)
data = var_obj(
t=t, x=X_eval.ravel(), y=Y_eval.ravel(), z=Z_eval.ravel()
).reshape(X_eval.shape)
pcm = ax.contourf(X_eval, Y_eval, data, levels=levels, cmap=cmap, **kwargs)
x_label, y_label = "x [m]", "y [m]"
plot_title = f"T(x,y) at z={slice_coord_val:.2f}m, {title_suffix}"
else:
raise ValueError(
f"Plane '{plane}' invalid for Cartesian geometry. Use 'xy', 'yz', or 'xz'."
)
# Add mesh overlay
if show_mesh:
add_mesh_overlay(
ax, nodes, elements, plane, slice_coord_val, geometry, mesh_tolerance
)
cbar = fig.colorbar(pcm, ax=ax, label=f"{variable}")
if not use_offset:
cbar.formatter.set_useOffset(False)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_title(plot_title)
ax.set_aspect("auto", "box")
if show_plot:
plt.tight_layout()
plt.show()
return ax