File size: 12,871 Bytes
484b847 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 |
#!/usr/bin/env python3
"""
Compute energy spectra from vorticity field data.
This script loads vorticity trajectory data from a .npy file and computes
the shell-summed (azimuthally integrated) energy spectrum E(k). It writes the
spectrum data to a .npz file and a log-log plot of the mean spectrum to a .png file.
To generate the vorticity fields by direct numerical simulation, use JAX-CFD: https://github.com/google/jax-cfd
JAX-CFD commit used: 0c17e3855702f884265b97bd6ff0793c34f3155e
Usage:
uv run python fluid_stats.py path/to/vorticity.npy --out_dir results/
"""
import argparse
import logging
import os
from functools import partial
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import jit, vmap
from tqdm import tqdm
# Root logging configuration applied at import time: INFO level with
# timestamped "time - LEVEL - message" lines.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
# Module-scoped logger used by every function below.
logger = logging.getLogger(__name__)
# =============================================================================
# Core computation functions
# =============================================================================
@jit
def vorticity_to_velocity(vorticity):
    """
    Recover the 2D velocity field (u_x, u_y) from a periodic vorticity field.

    The streamfunction is obtained spectrally from omega = laplacian(psi),
    i.e. psi_hat = -omega_hat / |k|^2. The velocity then follows as
    u_x = -d(psi)/dy and u_y = d(psi)/dx.

    Wavenumbers come from ``fftfreq(N, d=1.0)``, i.e. unit grid spacing
    (a periodic domain of length N). A different physical domain length
    would only rescale the returned velocities by a constant factor.

    Parameters
    ----------
    vorticity : jnp.ndarray, shape (N, N)
        Real 2D vorticity on a square periodic grid. Axis 0 is treated as x
        and axis 1 as y.

    Returns
    -------
    u_x : jnp.ndarray, shape (N, N)
        x-component of velocity.
    u_y : jnp.ndarray, shape (N, N)
        y-component of velocity.
    """
    n = vorticity.shape[0]
    omega_hat = jnp.fft.fft2(vorticity)

    # Angular wavenumbers for unit spacing; "ij" indexing puts kx along axis 0.
    k = jnp.fft.fftfreq(n, d=1.0) * 2 * jnp.pi
    kx_grid, ky_grid = jnp.meshgrid(k, k, indexing="ij")
    k_sq = kx_grid**2 + ky_grid**2

    # |k|^2 is 0 for the mean mode. Divide by 1 there instead, then zero that
    # coefficient: the streamfunction is only defined up to a constant.
    safe_k_sq = k_sq.at[0, 0].set(1.0)
    psi_hat = (-omega_hat / safe_k_sq).at[0, 0].set(0.0)

    # Spectral derivatives: d/dx -> i*kx, d/dy -> i*ky.
    u_x = jnp.fft.ifft2(-1j * ky_grid * psi_hat).real
    u_y = jnp.fft.ifft2(1j * kx_grid * psi_hat).real
    return u_x, u_y
@partial(jit, static_argnames=["k_max"])
def energy_spectrum_single(u_x, u_y, k_max=None):
    """
    Compute the shell-summed energy spectrum E(k) of a single 2D velocity field.

    Each Fourier mode (kx, ky) with integer wavenumbers contributes
    0.5 * (|Ux|^2 + |Uy|^2) to shell k = floor(sqrt(kx^2 + ky^2)).
    The coefficients are normalised by N^2. Modes whose shell index exceeds
    k_max are discarded.

    Parameters
    ----------
    u_x : jnp.ndarray, shape (N, N)
        x-component of velocity on a square grid.
    u_y : jnp.ndarray, shape (N, N)
        y-component of velocity on a square grid.
    k_max : int, optional (static under jit)
        Largest shell index to return. If None, uses N//3 (2/3 dealiasing rule).

    Returns
    -------
    E : jnp.ndarray, shape (k_max+1,)
        Energy spectrum E(k) for k = 0, 1, ..., k_max.

    Notes
    -----
    Shell indices are computed with exact integer arithmetic. Modes whose |k|
    is exactly an integer (e.g. (12, 5) -> 13) are therefore never pushed
    into a lower shell by floating-point rounding. The integer squares stay
    within int32 for N < 65536.

    Binning is a single scatter-add (``bincount``), so memory is O(N^2)
    rather than O(k_max * N^2).
    """
    N = u_x.shape[0]
    if k_max is None:  # Nyquist under 2/3 de-alias
        k_max = N // 3

    # Normalised Fourier coefficients, shifted so k=0 is at the centre.
    Ux = jnp.fft.fftshift(jnp.fft.fft2(u_x)) / (N**2)
    Uy = jnp.fft.fftshift(jnp.fft.fft2(u_y)) / (N**2)
    power = 0.5 * (jnp.abs(Ux) ** 2 + jnp.abs(Uy) ** 2)

    # Integer wavenumbers in fftshift order: -(N//2), ..., ceil(N/2) - 1.
    k_int = jnp.arange(N, dtype=jnp.int32) - N // 2
    KX, KY = jnp.meshgrid(k_int, k_int)
    K2 = KX * KX + KY * KY  # exact |k|^2

    # Exact floor(sqrt(K2)): the float estimate is off by at most one,
    # so a single correction step in each direction suffices.
    K = jnp.floor(jnp.sqrt(K2.astype(jnp.float32))).astype(jnp.int32)
    K = jnp.where(K * K > K2, K - 1, K)
    K = jnp.where((K + 1) * (K + 1) <= K2, K + 1, K)

    # Sum power per shell; indices >= k_max + 1 are dropped by bincount.
    E = jnp.bincount(K.ravel(), weights=power.ravel(), length=k_max + 1)
    return E
@partial(jit, static_argnames=["k_max"])
def energy_spectrum_from_vorticity(vorticity, k_max=None):
    """
    Energy spectrum of every timestep of a vorticity trajectory, via vmap.

    All timesteps are evaluated in one vectorised call. This is fast, but the
    intermediates of every frame are held in memory at once. It is suitable
    for moderate resolution fields (up to ~1024x1024). For larger
    resolutions, use energy_spectrum_from_vorticity_lax_map.

    Parameters
    ----------
    vorticity : jnp.ndarray, shape (T, X, Y)
        Vorticity field over T time steps on an X x Y grid.
    k_max : int, optional (static under jit)
        Maximum wavenumber. If None, uses N//3 with N = vorticity.shape[1].

    Returns
    -------
    E : jnp.ndarray, shape (T, k_max+1)
        Energy spectrum for each time step.
    """
    if k_max is None:
        k_max = vorticity.shape[1] // 3

    def _spectrum_of_frame(frame):
        velocity_x, velocity_y = vorticity_to_velocity(frame)
        return energy_spectrum_single(velocity_x, velocity_y, k_max)

    # Map over the leading (time) axis.
    return vmap(_spectrum_of_frame)(vorticity)
@partial(jit, static_argnames=["k_max", "batch_size"])
def energy_spectrum_from_vorticity_lax_map(vorticity, k_max=None, batch_size=16):
    """
    Energy spectrum of every timestep of a vorticity trajectory, via lax.map.

    This is the memory-bounded variant for high resolution fields
    (>1024x1024). ``jax.lax.map`` loops over the time axis and vectorises
    only ``batch_size`` frames per iteration. Peak memory therefore scales
    with ``batch_size`` rather than with T.

    Parameters
    ----------
    vorticity : jnp.ndarray, shape (T, X, Y)
        Vorticity field over T time steps on an X x Y grid.
    k_max : int, optional (static under jit)
        Maximum wavenumber. If None, uses N//3 with N = vorticity.shape[1].
    batch_size : int, optional (static under jit)
        Number of frames processed together per lax.map step. Default is 16.

    Returns
    -------
    E : jnp.ndarray, shape (T, k_max+1)
        Energy spectrum for each time step.
    """
    resolved_k_max = vorticity.shape[1] // 3 if k_max is None else k_max

    def _spectrum_of_frame(frame):
        return energy_spectrum_single(*vorticity_to_velocity(frame), resolved_k_max)

    return jax.lax.map(_spectrum_of_frame, vorticity, batch_size=batch_size)
# =============================================================================
# Main script
# =============================================================================
def parse_args():
    """
    Build the command-line interface and parse ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Attributes ``input_file`` (str, required positional) and
        ``out_dir`` (str, default ".").
    """
    description = (
        "Compute energy spectra from 2D vorticity trajectory data. "
        "Loads vorticity fields from a .npy file, computes the azimuthally "
        "averaged energy spectrum E(k), and saves both the spectrum data "
        "and a visualization plot."
    )
    # Raw text shown after the options list (RawDescriptionHelpFormatter
    # keeps the line breaks as written).
    epilog = "\n".join(
        [
            "",
            "Examples:",
            "uv run python fluid_stats.py simulation.npy",
            "uv run python fluid_stats.py data/vorticity.npy --out_dir results/",
            "Input format:",
            "The input .npy file should contain a 4D array with shape (batch, time, X, Y)",
            "where batch is the number of independent trajectories, time is the number",
            "of snapshots, and X, Y are the spatial grid dimensions.",
            "",
        ]
    )
    cli = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog,
    )
    cli.add_argument(
        "input_file",
        type=str,
        help=(
            "Path to the input .npy file containing vorticity data. "
            "Expected shape: (batch, time, X, Y) where X and Y are the "
            "spatial grid dimensions (must be square, i.e., X == Y)."
        ),
    )
    cli.add_argument(
        "--out_dir",
        type=str,
        default=".",
        help=(
            "Directory to save output files. Will be created if it does not "
            "exist. Output files are named based on the input filename. "
            "Default: current directory."
        ),
    )
    return cli.parse_args()
def _validate_field_shape(field):
    """Check that ``field`` is a non-empty (batch, time, N, N) array.

    Returns (batch_size, time_steps, resolution). Raises ValueError otherwise.
    """
    if field.ndim != 4:
        logger.error(
            "Expected 4D array (batch, time, X, Y), got %dD array", field.ndim
        )
        raise ValueError(
            f"Expected 4D array (batch, time, X, Y), got {field.ndim}D array"
        )
    batch_size, time_steps, height, width = field.shape
    if height != width:
        logger.error(
            "Expected square spatial grid (X == Y), got %d x %d", height, width
        )
        raise ValueError(
            f"Expected square spatial grid (X == Y), got {height} x {width}"
        )
    # An empty axis would otherwise fail deep inside jnp.stack or yield a NaN mean.
    if min(batch_size, time_steps, height) == 0:
        logger.error("Expected a non-empty array, got shape %s", field.shape)
        raise ValueError(f"Expected a non-empty array, got shape {field.shape}")
    return batch_size, time_steps, height


def _compute_spectra(field, k_max):
    """Return the spectra of all trajectories, shape (batch, time, k_max+1)."""
    resolution = field.shape[-1]
    spectra_list = []
    for trajectory in tqdm(field, desc="Computing spectra"):
        # np.asarray reads this trajectory (e.g. from a memmap) as a plain ndarray.
        trajectory = np.asarray(trajectory)
        if resolution > 1024:
            # Use memory-efficient lax.map for large resolutions
            single_spectrum = energy_spectrum_from_vorticity_lax_map(
                trajectory, k_max
            )
        else:
            # Use vmap for moderate resolutions
            single_spectrum = energy_spectrum_from_vorticity(trajectory, k_max)
        spectra_list.append(single_spectrum)
    return jnp.stack(spectra_list)


def _save_spectrum_data(
    data_path, mean_spectrum, all_spectra, resolution, batch_size, time_steps
):
    """Write the mean and per-snapshot spectra plus metadata to a compressed .npz."""
    logger.info("Saving spectrum data to: %s", data_path)
    np.savez_compressed(
        data_path,
        mean_spectrum=np.array(mean_spectrum),
        all_spectra=np.array(all_spectra),
        k_values=np.arange(len(mean_spectrum)),
        resolution=resolution,
        batch_size=batch_size,
        time_steps=time_steps,
    )


def _plot_mean_spectrum(mean_spectrum, resolution, plot_path):
    """Save a log-log plot of E(k), k >= 1, with k^-3 and k^-5/3 guide lines."""
    logger.info("Generating energy spectrum plot...")
    mean_spectrum = np.asarray(mean_spectrum)
    fig, ax = plt.subplots(figsize=(10, 6))

    # Skip k = 0 (the mean mode), which cannot be placed on a log axis.
    offset = 1
    spectrum = mean_spectrum[offset:]
    k_values = np.arange(offset, len(mean_spectrum))
    ax.loglog(k_values, spectrum, "b-", linewidth=2, label="Mean spectrum")

    # Both reference power laws are anchored to pass through (k_match, E(k_match)).
    k_match = min(10, len(spectrum) // 3)
    if k_match > 0:
        ref_value = float(spectrum[k_match - 1])  # spectrum[j] is E(j + offset)
        k_theory = np.logspace(0, np.log10(len(mean_spectrum)), 100)

        # k^{-3}: forward enstrophy cascade in 2D turbulence.
        scaling_k3 = ref_value * (k_match**3)
        ax.loglog(
            k_theory,
            scaling_k3 * k_theory ** (-3),
            "k--",
            alpha=0.7,
            linewidth=1.5,
            label=r"$k^{-3}$ (enstrophy cascade)",
        )

        # k^{-5/3}: 2D inverse energy cascade (same exponent as Kolmogorov's 3D law).
        scaling_k53 = ref_value * (k_match ** (5 / 3))
        ax.loglog(
            k_theory,
            scaling_k53 * k_theory ** (-5 / 3),
            "r--",
            alpha=0.7,
            linewidth=1.5,
            label=r"$k^{-5/3}$ (energy cascade)",
        )

    ax.set_xlabel("Wavenumber k", fontsize=12)
    ax.set_ylabel("Energy Spectrum E(k)", fontsize=12)
    ax.set_title(f"Energy Spectrum ({resolution}x{resolution} resolution)", fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xlim(1, ax.get_xlim()[1])
    fig.tight_layout()

    fig.savefig(plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    logger.info("Plot saved to: %s", plot_path)


def main():
    """Main entry point for energy spectrum computation.

    Reads a (batch, time, N, N) vorticity array from the CLI input file.
    Computes E(k) for every snapshot, using lax.map when N > 1024 and vmap
    otherwise. Writes ``<name>_spectrum_data.npz`` and ``<name>_spectrum.png``
    into ``--out_dir``.

    Raises
    ------
    FileNotFoundError
        If the input file does not exist.
    ValueError
        If the array is not 4D, not square in space, or has an empty axis.
    """
    args = parse_args()
    logger.info("JAX devices: %s", jax.devices())

    # Validate input file
    if not os.path.exists(args.input_file):
        logger.error("Input file not found: %s", args.input_file)
        raise FileNotFoundError(f"Input file not found: {args.input_file}")
    if not args.input_file.endswith(".npy"):
        logger.warning(
            "Input file does not have .npy extension: %s", args.input_file
        )

    # Output paths are derived from the input file name.
    os.makedirs(args.out_dir, exist_ok=True)
    input_basename = os.path.splitext(os.path.basename(args.input_file))[0]
    data_path = os.path.join(args.out_dir, f"{input_basename}_spectrum_data.npz")
    plot_path = os.path.join(args.out_dir, f"{input_basename}_spectrum.png")

    logger.info("Loading data from: %s", args.input_file)
    # Memory-map so trajectories are read from disk one at a time instead of
    # loading the whole (possibly very large) array up front.
    field = np.load(args.input_file, mmap_mode="r")
    logger.info("Loaded field with shape: %s", field.shape)

    batch_size, time_steps, resolution = _validate_field_shape(field)
    k_max = resolution // 3
    logger.info(
        "Processing %d trajectories with %d timesteps at %dx%d resolution",
        batch_size,
        time_steps,
        resolution,
        resolution,
    )
    logger.info("Maximum wavenumber (k_max): %d", k_max)

    logger.info("Computing energy spectra...")
    all_spectra = _compute_spectra(field, k_max)
    logger.info("All spectra shape: %s", all_spectra.shape)

    # Mean over every (trajectory, timestep) pair.
    mean_spectrum = all_spectra.reshape(-1, all_spectra.shape[-1]).mean(axis=0)
    logger.info("Mean spectrum shape: %s", mean_spectrum.shape)

    _save_spectrum_data(
        data_path, mean_spectrum, all_spectra, resolution, batch_size, time_steps
    )
    _plot_mean_spectrum(mean_spectrum, resolution, plot_path)
    logger.info("Done!")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|