├── .gitignore ├── LICENSE ├── README.md ├── euler.py ├── euler_distributed.py ├── euler_numpy.py ├── make_gif.sh ├── requirements.txt ├── results ├── plot_results.py ├── result_16384_single.png ├── runs.json ├── scaling_strong.png └── scaling_weak.png ├── sbatch_profile.sh └── sbatch_rusty_gpu.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | output* 3 | results/output* 4 | timing_data/ 5 | jax-trace/ 6 | slurm* 7 | report* 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Philip Mocz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simple Euler Equation JAX benchmarking 2 | 3 | Philip Mocz (2024) 4 | 5 | Flatiron Institute 6 | 7 | Benchmarking on `macbook` (Apple M3 Max) and `rusty` (Nvidia A100) 8 | 9 | 10 | ## Files 11 | 12 | * `euler.py` simple JAX version on a single node 13 | * `euler_distributed.py` JAX version for distributed systems 14 | * `euler_numpy.py` simple NumPy version (based on my [blog tutorial](https://levelup.gitconnected.com/create-your-own-finite-volume-fluid-simulation-with-python-8f9eab0b8305?sk=584a56a12a551ca1b74ba19b2a9dffbb)) 15 | 16 | 17 | ## Setup 18 | 19 | * Create a Python virtual environment and install the required modules: 20 | 21 | ```console 22 | python -m venv --system-site-packages $VENVDIR/my-jax-venv 23 | source $VENVDIR/my-jax-venv/bin/activate 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | 28 | ## Strong Scaling on `macbook`: 29 | 30 | ![strong scaling](results/scaling_strong.png) 31 | 32 | 33 | ## Weak Scaling on `rusty`: 34 | 35 | 36 | ![weak scaling](results/scaling_weak.png) 37 | 38 | 39 | ## Final Simulation Result 40 | 41 | 16384^2 resolution JAX (single-precision) simulation after 277300 iterations on 16 GPUs in 64.1 minutes 42 | 43 | (for reference, my macbook run (single-precision) at 1024^2 resolution after 15426 iterations took 4.6 minutes) 44 | 45 | The GPU calculations achieved 335x higher throughput (mcups)!
46 | 47 | ![final snapshot](results/result_16384_single.png) 48 | -------------------------------------------------------------------------------- /euler.py: -------------------------------------------------------------------------------- 1 | # A simple example of solving the Euler equations with JAX 2 | # Philip Mocz (2024) 3 | 4 | import os 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | import matplotlib.pyplot as plt 9 | import time 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--resolution", type=int, default=1024) # 1024 512 # 256 # 128 # 64 14 | parser.add_argument("--double", action="store_true") 15 | args = parser.parse_args() 16 | 17 | if args.double: 18 | print("Using double precision") 19 | jax.config.update("jax_enable_x64", True) 20 | else: 21 | print("Using single precision") 22 | 23 | 24 | @jax.jit 25 | def get_conserved(rho, vx, vy, P, gamma, vol): 26 | """Calculate the conserved variables from the primitive variables""" 27 | 28 | Mass = rho * vol 29 | Momx = rho * vx * vol 30 | Momy = rho * vy * vol 31 | Energy = (P / (gamma - 1) + 0.5 * rho * (vx**2 + vy**2)) * vol 32 | 33 | return Mass, Momx, Momy, Energy 34 | 35 | 36 | @jax.jit 37 | def get_primitive(Mass, Momx, Momy, Energy, gamma, vol): 38 | """Calculate the primitive variable from the conserved variables""" 39 | 40 | rho = Mass / vol 41 | vx = Momx / rho / vol 42 | vy = Momy / rho / vol 43 | P = (Energy / vol - 0.5 * rho * (vx**2 + vy**2)) * (gamma - 1) 44 | 45 | return rho, vx, vy, P 46 | 47 | 48 | @jax.jit 49 | def get_gradient(f, dx): 50 | """Calculate the gradients of a field""" 51 | 52 | # (right - left) / 2dx 53 | f_dx = (jnp.roll(f, -1, axis=0) - jnp.roll(f, 1, axis=0)) / (2 * dx) 54 | f_dy = (jnp.roll(f, -1, axis=1) - jnp.roll(f, 1, axis=1)) / (2 * dx) 55 | 56 | return f_dx, f_dy 57 | 58 | 59 | @jax.jit 60 | def extrapolate_to_face(f, f_dx, f_dy, dx): 61 | """Extrapolate the field from face centers to faces using gradients""" 62 | 63 | 
f_XL = f - f_dx * dx / 2 64 | f_XL = jnp.roll(f_XL, -1, axis=0) # right/up roll 65 | f_XR = f + f_dx * dx / 2 66 | 67 | f_YL = f - f_dy * dx / 2 68 | f_YL = jnp.roll(f_YL, -1, axis=1) 69 | f_YR = f + f_dy * dx / 2 70 | 71 | return f_XL, f_XR, f_YL, f_YR 72 | 73 | 74 | @jax.jit 75 | def apply_fluxes(F, flux_F_X, flux_F_Y, dx, dt): 76 | """Apply fluxes to conserved variables to update solution state""" 77 | 78 | F += -dt * dx * flux_F_X 79 | F += dt * dx * jnp.roll(flux_F_X, 1, axis=0) # left/down roll 80 | F += -dt * dx * flux_F_Y 81 | F += dt * dx * jnp.roll(flux_F_Y, 1, axis=1) 82 | 83 | return F 84 | 85 | 86 | @jax.jit 87 | def get_flux(rho_L, rho_R, vx_L, vx_R, vy_L, vy_R, P_L, P_R, gamma): 88 | """Calculate fluxes between 2 states with local Lax-Friedrichs/Rusanov rule""" 89 | 90 | # left and right energies 91 | en_L = P_L / (gamma - 1) + 0.5 * rho_L * (vx_L**2 + vy_L**2) 92 | en_R = P_R / (gamma - 1) + 0.5 * rho_R * (vx_R**2 + vy_R**2) 93 | 94 | # compute star (averaged) states 95 | rho_star = 0.5 * (rho_L + rho_R) 96 | momx_star = 0.5 * (rho_L * vx_L + rho_R * vx_R) 97 | momy_star = 0.5 * (rho_L * vy_L + rho_R * vy_R) 98 | en_star = 0.5 * (en_L + en_R) 99 | 100 | P_star = (gamma - 1) * (en_star - 0.5 * (momx_star**2 + momy_star**2) / rho_star) 101 | 102 | # compute fluxes (local Lax-Friedrichs/Rusanov) 103 | flux_Mass = momx_star 104 | flux_Momx = momx_star**2 / rho_star + P_star 105 | flux_Momy = momx_star * momy_star / rho_star 106 | flux_Energy = (en_star + P_star) * momx_star / rho_star 107 | 108 | # find wavespeeds 109 | C_L = jnp.sqrt(gamma * P_L / rho_L) + jnp.abs(vx_L) 110 | C_R = jnp.sqrt(gamma * P_R / rho_R) + jnp.abs(vx_R) 111 | C = jnp.maximum(C_L, C_R) 112 | 113 | # add stabilizing diffusive term 114 | flux_Mass -= C * 0.5 * (rho_L - rho_R) 115 | flux_Momx -= C * 0.5 * (rho_L * vx_L - rho_R * vx_R) 116 | flux_Momy -= C * 0.5 * (rho_L * vy_L - rho_R * vy_R) 117 | flux_Energy -= C * 0.5 * (en_L - en_R) 118 | 119 | return flux_Mass, flux_Momx, 
flux_Momy, flux_Energy 120 | 121 | 122 | def update(Mass, Momx, Momy, Energy, vol, dx, gamma, courant_fac): 123 | """Take a simulation timestep""" 124 | 125 | # get Primitive variables 126 | rho, vx, vy, P = get_primitive(Mass, Momx, Momy, Energy, gamma, vol) 127 | 128 | # get time step (CFL) = dx / max signal speed 129 | dt = courant_fac * jnp.min( 130 | dx / (jnp.sqrt(gamma * P / rho) + jnp.sqrt(vx**2 + vy**2)) 131 | ) 132 | 133 | # calculate gradients 134 | rho_dx, rho_dy = get_gradient(rho, dx) 135 | vx_dx, vx_dy = get_gradient(vx, dx) 136 | vy_dx, vy_dy = get_gradient(vy, dx) 137 | P_dx, P_dy = get_gradient(P, dx) 138 | 139 | # extrapolate half-step in time 140 | rho_prime = rho - 0.5 * dt * (vx * rho_dx + rho * vx_dx + vy * rho_dy + rho * vy_dy) 141 | vx_prime = vx - 0.5 * dt * (vx * vx_dx + vy * vx_dy + (1 / rho) * P_dx) 142 | vy_prime = vy - 0.5 * dt * (vx * vy_dx + vy * vy_dy + (1 / rho) * P_dy) 143 | P_prime = P - 0.5 * dt * (gamma * P * (vx_dx + vy_dy) + vx * P_dx + vy * P_dy) 144 | 145 | # extrapolate in space to face centers 146 | rho_XL, rho_XR, rho_YL, rho_YR = extrapolate_to_face(rho_prime, rho_dx, rho_dy, dx) 147 | vx_XL, vx_XR, vx_YL, vx_YR = extrapolate_to_face(vx_prime, vx_dx, vx_dy, dx) 148 | vy_XL, vy_XR, vy_YL, vy_YR = extrapolate_to_face(vy_prime, vy_dx, vy_dy, dx) 149 | P_XL, P_XR, P_YL, P_YR = extrapolate_to_face(P_prime, P_dx, P_dy, dx) 150 | 151 | # compute fluxes (local Lax-Friedrichs/Rusanov) 152 | flux_Mass_X, flux_Momx_X, flux_Momy_X, flux_Energy_X = get_flux( 153 | rho_XL, rho_XR, vx_XL, vx_XR, vy_XL, vy_XR, P_XL, P_XR, gamma 154 | ) 155 | flux_Mass_Y, flux_Momy_Y, flux_Momx_Y, flux_Energy_Y = get_flux( 156 | rho_YL, rho_YR, vy_YL, vy_YR, vx_YL, vx_YR, P_YL, P_YR, gamma 157 | ) 158 | 159 | # update solution 160 | Mass = apply_fluxes(Mass, flux_Mass_X, flux_Mass_Y, dx, dt) 161 | Momx = apply_fluxes(Momx, flux_Momx_X, flux_Momx_Y, dx, dt) 162 | Momy = apply_fluxes(Momy, flux_Momy_X, flux_Momy_Y, dx, dt) 163 | Energy = 
apply_fluxes(Energy, flux_Energy_X, flux_Energy_Y, dx, dt) 164 | 165 | return Mass, Momx, Momy, Energy, dt, rho 166 | 167 | 168 | def main(): 169 | """Finite Volume simulation""" 170 | 171 | # Simulation parameters 172 | N = args.resolution 173 | boxsize = 1.0 174 | gamma = 5.0 / 3.0 # ideal gas gamma 175 | courant_fac = 0.4 176 | t_stop = 2.0 177 | save_freq = 0.1 178 | save_animation_path = ( 179 | "output_euler_" + str(N) + ("double" if args.double else "single") 180 | ) 181 | 182 | # Mesh 183 | dx = boxsize / N 184 | vol = dx**2 185 | xlin = jnp.linspace(0.5 * dx, boxsize - 0.5 * dx, N) 186 | X, Y = jnp.meshgrid(xlin, xlin, indexing="ij") 187 | 188 | # Generate Initial Conditions - opposite moving streams with perturbation 189 | w0 = 0.1 190 | sigma = 0.05 / jnp.sqrt(2.0) 191 | rho = 1.0 + (jnp.abs(Y - 0.5) < 0.25) 192 | vx = -0.5 + (jnp.abs(Y - 0.5) < 0.25) 193 | vy = ( 194 | w0 195 | * jnp.sin(4 * jnp.pi * X) 196 | * ( 197 | jnp.exp(-((Y - 0.25) ** 2) / (2 * sigma**2)) 198 | + jnp.exp(-((Y - 0.75) ** 2) / (2 * sigma**2)) 199 | ) 200 | ) 201 | P = 2.5 * jnp.ones(X.shape) 202 | 203 | # Get conserved variables 204 | Mass, Momx, Momy, Energy = get_conserved(rho, vx, vy, P, gamma, vol) 205 | 206 | # Make animation directory if it doesn't exist 207 | if not os.path.exists(save_animation_path): 208 | os.makedirs(save_animation_path, exist_ok=True) 209 | 210 | # Simulation Main Loop 211 | tic = time.time() 212 | t = 0 213 | output_counter = 0 214 | n_iter = 0 215 | save_freq = 0.05 216 | while t < t_stop: 217 | 218 | # Time step 219 | Mass, Momx, Momy, Energy, dt, rho = update( 220 | Mass, Momx, Momy, Energy, vol, dx, gamma, courant_fac 221 | ) 222 | 223 | # determine if we should save the plot 224 | save_plot = False 225 | if t + dt > output_counter * save_freq: 226 | save_plot = True 227 | output_counter += 1 228 | 229 | # update time 230 | t += dt 231 | 232 | # update iteration counter 233 | n_iter += 1 234 | 235 | # save plot 236 | if save_plot: 237 | plt.imsave( 
238 | save_animation_path + "/rho" + str(output_counter).zfill(6) + ".png", 239 | jnp.rot90(rho), 240 | cmap="jet", 241 | vmin=0.8, 242 | vmax=2.2, 243 | ) 244 | 245 | # Print progress 246 | print("[it=" + str(n_iter) + " t=" + "{:.6f}".format(t) + "]") 247 | print( 248 | " saved state " 249 | + str(output_counter).zfill(6) 250 | + " of " 251 | + str(int(jnp.ceil(t_stop / save_freq))) 252 | ) 253 | 254 | # Print million updates per second 255 | cell_updates = X.shape[0] * X.shape[1] * n_iter 256 | total_time = time.time() - tic 257 | mcups = cell_updates / (1e6 * total_time) 258 | print(" million cell updates / second: ", mcups) 259 | 260 | print("Total time: ", total_time) 261 | 262 | 263 | if __name__ == "__main__": 264 | main() 265 | -------------------------------------------------------------------------------- /euler_distributed.py: -------------------------------------------------------------------------------- 1 | # A simple example of solving the Euler equations with JAX on distributed systems 2 | # Philip Mocz (2024) 3 | 4 | USE_CPU_ONLY = False # True # False 5 | 6 | import os 7 | 8 | if USE_CPU_ONLY: 9 | flags = os.environ.get("XLA_FLAGS", "") 10 | flags = os.environ.get("XLA_FLAGS", "") 11 | flags += " --xla_force_host_platform_device_count=8" 12 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 13 | os.environ["XLA_FLAGS"] = flags 14 | 15 | # del os.environ["QUADD_INJECTION_PROXY"] 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | from jax.experimental import mesh_utils 20 | from jax.sharding import Mesh, PartitionSpec, NamedSharding 21 | 22 | import matplotlib.pyplot as plt 23 | import time 24 | import argparse 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--resolution", type=int, default=8192) # 1024 512 # 256 # 128 # 64 28 | parser.add_argument("--double", action="store_true") 29 | args = parser.parse_args() 30 | 31 | if args.double: 32 | print("Using double precision") 33 | jax.config.update("jax_enable_x64", True) 34 | else: 
35 | print("Using single precision") 36 | 37 | 38 | @jax.jit 39 | def get_conserved(rho, vx, vy, P, gamma, vol): 40 | """Calculate the conserved variables from the primitive variables""" 41 | 42 | Mass = rho * vol 43 | Momx = rho * vx * vol 44 | Momy = rho * vy * vol 45 | Energy = (P / (gamma - 1) + 0.5 * rho * (vx**2 + vy**2)) * vol 46 | 47 | return Mass, Momx, Momy, Energy 48 | 49 | 50 | @jax.jit 51 | def get_primitive(Mass, Momx, Momy, Energy, gamma, vol): 52 | """Calculate the primitive variable from the conserved variables""" 53 | 54 | rho = Mass / vol 55 | vx = Momx / rho / vol 56 | vy = Momy / rho / vol 57 | P = (Energy / vol - 0.5 * rho * (vx**2 + vy**2)) * (gamma - 1) 58 | 59 | return rho, vx, vy, P 60 | 61 | 62 | @jax.jit 63 | def get_gradient(f, dx): 64 | """Calculate the gradients of a field""" 65 | 66 | # (right - left) / 2dx 67 | f_dx = (jnp.roll(f, -1, axis=0) - jnp.roll(f, 1, axis=0)) / (2 * dx) 68 | f_dy = (jnp.roll(f, -1, axis=1) - jnp.roll(f, 1, axis=1)) / (2 * dx) 69 | 70 | return f_dx, f_dy 71 | 72 | 73 | @jax.jit 74 | def extrapolate_to_face(f, f_dx, f_dy, dx): 75 | """Extrapolate the field from face centers to faces using gradients""" 76 | 77 | f_XL = f - f_dx * dx / 2 78 | f_XL = jnp.roll(f_XL, -1, axis=0) # right/up roll 79 | f_XR = f + f_dx * dx / 2 80 | 81 | f_YL = f - f_dy * dx / 2 82 | f_YL = jnp.roll(f_YL, -1, axis=1) 83 | f_YR = f + f_dy * dx / 2 84 | 85 | return f_XL, f_XR, f_YL, f_YR 86 | 87 | 88 | @jax.jit 89 | def apply_fluxes(F, flux_F_X, flux_F_Y, dx, dt): 90 | """Apply fluxes to conserved variables to update solution state""" 91 | 92 | F += -dt * dx * flux_F_X 93 | F += dt * dx * jnp.roll(flux_F_X, 1, axis=0) # left/down roll 94 | F += -dt * dx * flux_F_Y 95 | F += dt * dx * jnp.roll(flux_F_Y, 1, axis=1) 96 | 97 | return F 98 | 99 | 100 | @jax.jit 101 | def get_flux(rho_L, rho_R, vx_L, vx_R, vy_L, vy_R, P_L, P_R, gamma): 102 | """Calculate fluxes between 2 states with local Lax-Friedrichs/Rusanov rule""" 103 | 104 | # left and 
right energies 105 | en_L = P_L / (gamma - 1) + 0.5 * rho_L * (vx_L**2 + vy_L**2) 106 | en_R = P_R / (gamma - 1) + 0.5 * rho_R * (vx_R**2 + vy_R**2) 107 | 108 | # compute star (averaged) states 109 | rho_star = 0.5 * (rho_L + rho_R) 110 | momx_star = 0.5 * (rho_L * vx_L + rho_R * vx_R) 111 | momy_star = 0.5 * (rho_L * vy_L + rho_R * vy_R) 112 | en_star = 0.5 * (en_L + en_R) 113 | 114 | P_star = (gamma - 1) * (en_star - 0.5 * (momx_star**2 + momy_star**2) / rho_star) 115 | 116 | # compute fluxes (local Lax-Friedrichs/Rusanov) 117 | flux_Mass = momx_star 118 | flux_Momx = momx_star**2 / rho_star + P_star 119 | flux_Momy = momx_star * momy_star / rho_star 120 | flux_Energy = (en_star + P_star) * momx_star / rho_star 121 | 122 | # find wavespeeds 123 | C_L = jnp.sqrt(gamma * P_L / rho_L) + jnp.abs(vx_L) 124 | C_R = jnp.sqrt(gamma * P_R / rho_R) + jnp.abs(vx_R) 125 | C = jnp.maximum(C_L, C_R) 126 | 127 | # add stabilizing diffusive term 128 | flux_Mass -= C * 0.5 * (rho_L - rho_R) 129 | flux_Momx -= C * 0.5 * (rho_L * vx_L - rho_R * vx_R) 130 | flux_Momy -= C * 0.5 * (rho_L * vy_L - rho_R * vy_R) 131 | flux_Energy -= C * 0.5 * (en_L - en_R) 132 | 133 | return flux_Mass, flux_Momx, flux_Momy, flux_Energy 134 | 135 | 136 | def update(Mass, Momx, Momy, Energy, vol, dx, gamma, courant_fac): 137 | """Take a simulation timestep""" 138 | 139 | # get Primitive variables 140 | rho, vx, vy, P = get_primitive(Mass, Momx, Momy, Energy, gamma, vol) 141 | 142 | # get time step (CFL) = dx / max signal speed 143 | dt = courant_fac * jnp.min( 144 | dx / (jnp.sqrt(gamma * P / rho) + jnp.sqrt(vx**2 + vy**2)) 145 | ) 146 | 147 | # calculate gradients 148 | rho_dx, rho_dy = get_gradient(rho, dx) 149 | vx_dx, vx_dy = get_gradient(vx, dx) 150 | vy_dx, vy_dy = get_gradient(vy, dx) 151 | P_dx, P_dy = get_gradient(P, dx) 152 | 153 | # extrapolate half-step in time 154 | rho_prime = rho - 0.5 * dt * (vx * rho_dx + rho * vx_dx + vy * rho_dy + rho * vy_dy) 155 | vx_prime = vx - 0.5 * dt * (vx * 
vx_dx + vy * vx_dy + (1 / rho) * P_dx) 156 | vy_prime = vy - 0.5 * dt * (vx * vy_dx + vy * vy_dy + (1 / rho) * P_dy) 157 | P_prime = P - 0.5 * dt * (gamma * P * (vx_dx + vy_dy) + vx * P_dx + vy * P_dy) 158 | 159 | # extrapolate in space to face centers 160 | rho_XL, rho_XR, rho_YL, rho_YR = extrapolate_to_face(rho_prime, rho_dx, rho_dy, dx) 161 | vx_XL, vx_XR, vx_YL, vx_YR = extrapolate_to_face(vx_prime, vx_dx, vx_dy, dx) 162 | vy_XL, vy_XR, vy_YL, vy_YR = extrapolate_to_face(vy_prime, vy_dx, vy_dy, dx) 163 | P_XL, P_XR, P_YL, P_YR = extrapolate_to_face(P_prime, P_dx, P_dy, dx) 164 | 165 | # compute fluxes (local Lax-Friedrichs/Rusanov) 166 | flux_Mass_X, flux_Momx_X, flux_Momy_X, flux_Energy_X = get_flux( 167 | rho_XL, rho_XR, vx_XL, vx_XR, vy_XL, vy_XR, P_XL, P_XR, gamma 168 | ) 169 | flux_Mass_Y, flux_Momy_Y, flux_Momx_Y, flux_Energy_Y = get_flux( 170 | rho_YL, rho_YR, vy_YL, vy_YR, vx_YL, vx_YR, P_YL, P_YR, gamma 171 | ) 172 | 173 | # update solution 174 | Mass = apply_fluxes(Mass, flux_Mass_X, flux_Mass_Y, dx, dt) 175 | Momx = apply_fluxes(Momx, flux_Momx_X, flux_Momx_Y, dx, dt) 176 | Momy = apply_fluxes(Momy, flux_Momy_X, flux_Momy_Y, dx, dt) 177 | Energy = apply_fluxes(Energy, flux_Energy_X, flux_Energy_Y, dx, dt) 178 | 179 | return Mass, Momx, Momy, Energy, dt, rho 180 | 181 | 182 | def main(): 183 | """Finite Volume simulation""" 184 | 185 | if not USE_CPU_ONLY: 186 | jax.distributed.initialize() 187 | n_devices = jax.device_count() 188 | mesh = Mesh(mesh_utils.create_device_mesh((n_devices, 1)), ("x", "y")) 189 | sharding = NamedSharding(mesh, PartitionSpec("x", "y")) 190 | 191 | if jax.process_index() == 0: 192 | for env_var in [ 193 | "SLURM_JOB_ID", 194 | "SLURM_NTASKS", 195 | "SLURM_NODELIST", 196 | "SLURM_STEP_NODELIST", 197 | "SLURM_STEP_GPUS", 198 | "SLURM_GPUS", 199 | ]: 200 | print(f'{env_var}: {os.getenv(env_var,"")}') 201 | print("Total number of processes: ", jax.process_count()) 202 | print("Total number of devices: ", jax.device_count()) 203 
| print("List of devices: ", jax.devices()) 204 | print("Number of devices on this process: ", jax.local_device_count()) 205 | 206 | # Simulation parameters 207 | N = args.resolution 208 | boxsize = 1.0 209 | gamma = 5.0 / 3.0 # ideal gas gamma 210 | courant_fac = 0.4 211 | t_stop = 2.0 212 | save_freq = 0.1 213 | if USE_CPU_ONLY: 214 | save_animation_path = ( 215 | "output_euler_distributed_" 216 | + str(N) 217 | + ("double" if args.double else "single") 218 | ) 219 | else: 220 | save_animation_path = ( 221 | "/mnt/home/pmocz/ceph/jax-euler-benchmarks/output_euler_distributed_" 222 | + str(N) 223 | + ("double" if args.double else "single") 224 | ) 225 | 226 | # Mesh 227 | dx = boxsize / N 228 | vol = dx**2 229 | xlin = jnp.linspace(0.5 * dx, boxsize - 0.5 * dx, N) 230 | X, Y = jnp.meshgrid(xlin, xlin, indexing="ij") 231 | 232 | X = jax.lax.with_sharding_constraint(X, sharding) 233 | Y = jax.lax.with_sharding_constraint(Y, sharding) 234 | 235 | if jax.process_index() == 0: 236 | print("X:") 237 | jax.debug.visualize_array_sharding(X) 238 | 239 | # Generate Initial Conditions - opposite moving streams with perturbation 240 | w0 = 0.1 241 | sigma = 0.05 / jnp.sqrt(2.0) 242 | rho = 1.0 + (jnp.abs(Y - 0.5) < 0.25) 243 | vx = -0.5 + (jnp.abs(Y - 0.5) < 0.25) 244 | vy = ( 245 | w0 246 | * jnp.sin(4 * jnp.pi * X) 247 | * ( 248 | jnp.exp(-((Y - 0.25) ** 2) / (2 * sigma**2)) 249 | + jnp.exp(-((Y - 0.75) ** 2) / (2 * sigma**2)) 250 | ) 251 | ) 252 | P = 2.5 * jnp.ones(X.shape) 253 | P = jax.lax.with_sharding_constraint(P, sharding) 254 | 255 | if jax.process_index() == 0: 256 | print("rho:") 257 | jax.debug.visualize_array_sharding(rho) 258 | print("P:") 259 | jax.debug.visualize_array_sharding(P) 260 | 261 | # Get conserved variables 262 | Mass, Momx, Momy, Energy = get_conserved(rho, vx, vy, P, gamma, vol) 263 | 264 | if jax.process_index() == 0: 265 | print("Mass:") 266 | jax.debug.visualize_array_sharding(Mass) 267 | 268 | # Make animation directory if it doesn't exist 
269 | if not os.path.exists(save_animation_path): 270 | os.makedirs(save_animation_path, exist_ok=True) 271 | 272 | # Simulation Main Loop 273 | tic = time.time() 274 | t = 0 275 | output_counter = 0 276 | n_iter = 0 277 | save_freq = 0.05 278 | while t < t_stop: 279 | 280 | # Time step 281 | Mass, Momx, Momy, Energy, dt, rho = update( 282 | Mass, Momx, Momy, Energy, vol, dx, gamma, courant_fac 283 | ) 284 | 285 | # determine if we should save the plot 286 | save_plot = False 287 | if t + dt > output_counter * save_freq: 288 | save_plot = True 289 | output_counter += 1 290 | 291 | # update time 292 | t += dt 293 | 294 | # update iteration counter 295 | n_iter += 1 296 | 297 | # save plot 298 | if save_plot: 299 | for d in range(jax.local_device_count()): 300 | plt.imsave( 301 | save_animation_path 302 | + "/dump_rho" 303 | + str(output_counter).zfill(6) 304 | + "_" 305 | + str(jax.process_index()).zfill(2) 306 | + "_" 307 | + str(d).zfill(2) 308 | + ".png", 309 | jnp.rot90(rho.addressable_data(d)), 310 | cmap="jet", 311 | vmin=0.8, 312 | vmax=2.2, 313 | ) 314 | 315 | # Print progress 316 | if jax.process_index() == 0: 317 | print("[it=" + str(n_iter) + " t=" + "{:.6f}".format(t) + "]") 318 | print( 319 | " saved state " 320 | + str(output_counter).zfill(6) 321 | + " of " 322 | + str(int(jnp.ceil(t_stop / save_freq))) 323 | ) 324 | 325 | # Print million updates per second 326 | cell_updates = X.shape[0] * X.shape[1] * n_iter 327 | total_time = time.time() - tic 328 | mcups = cell_updates / (1e6 * total_time) 329 | print(" million cell updates / second: ", mcups) 330 | 331 | if jax.process_index() == 0: 332 | print("Total time: ", total_time) 333 | print("rho:") 334 | jax.debug.visualize_array_sharding(rho) 335 | 336 | import numpy as np 337 | 338 | # for each output time, stich together images from each of the processes 339 | for snap in range(1, output_counter + 1): 340 | images = [] 341 | for p in range(jax.process_count()): 342 | for d in 
range(jax.local_device_count()): 343 | images.append( 344 | plt.imread( 345 | save_animation_path 346 | + "/dump_rho" 347 | + str(snap).zfill(6) 348 | + "_" 349 | + str(p).zfill(2) 350 | + "_" 351 | + str(d).zfill(2) 352 | + ".png" 353 | ) 354 | ) 355 | 356 | plt.imsave( 357 | save_animation_path + "/rho" + str(snap).zfill(6) + ".png", 358 | np.concatenate(images, axis=1), 359 | cmap="jet", 360 | vmin=0.8, 361 | vmax=2.2, 362 | ) 363 | 364 | 365 | if __name__ == "__main__": 366 | main() 367 | -------------------------------------------------------------------------------- /euler_numpy.py: -------------------------------------------------------------------------------- 1 | # A simple example of solving the Euler equations with Numpy 2 | # Philip Mocz (2024) 3 | 4 | import os 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import time 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--resolution", type=int, default=1024) # 1024 512 # 256 # 128 # 64 12 | args = parser.parse_args() 13 | 14 | 15 | def get_conserved(rho, vx, vy, P, gamma, vol): 16 | """Calculate the conserved variables from the primitive variables""" 17 | 18 | Mass = rho * vol 19 | Momx = rho * vx * vol 20 | Momy = rho * vy * vol 21 | Energy = (P / (gamma - 1) + 0.5 * rho * (vx**2 + vy**2)) * vol 22 | 23 | return Mass, Momx, Momy, Energy 24 | 25 | 26 | def get_primitive(Mass, Momx, Momy, Energy, gamma, vol): 27 | """Calculate the primitive variable from the conserved variables""" 28 | 29 | rho = Mass / vol 30 | vx = Momx / rho / vol 31 | vy = Momy / rho / vol 32 | P = (Energy / vol - 0.5 * rho * (vx**2 + vy**2)) * (gamma - 1) 33 | 34 | return rho, vx, vy, P 35 | 36 | 37 | def get_gradient(f, dx): 38 | """Calculate the gradients of a field""" 39 | 40 | # (right - left) / 2dx 41 | f_dx = (np.roll(f, -1, axis=0) - np.roll(f, 1, axis=0)) / (2 * dx) 42 | f_dy = (np.roll(f, -1, axis=1) - np.roll(f, 1, axis=1)) / (2 * dx) 43 | 44 | return f_dx, f_dy 45 | 
46 | 47 | def extrapolate_to_face(f, f_dx, f_dy, dx): 48 | """Extrapolate the field from face centers to faces using gradients""" 49 | 50 | f_XL = f - f_dx * dx / 2 51 | f_XL = np.roll(f_XL, -1, axis=0) # right/up roll 52 | f_XR = f + f_dx * dx / 2 53 | 54 | f_YL = f - f_dy * dx / 2 55 | f_YL = np.roll(f_YL, -1, axis=1) 56 | f_YR = f + f_dy * dx / 2 57 | 58 | return f_XL, f_XR, f_YL, f_YR 59 | 60 | 61 | def apply_fluxes(F, flux_F_X, flux_F_Y, dx, dt): 62 | """Apply fluxes to conserved variables to update solution state""" 63 | 64 | F += -dt * dx * flux_F_X 65 | F += dt * dx * np.roll(flux_F_X, 1, axis=0) # left/down roll 66 | F += -dt * dx * flux_F_Y 67 | F += dt * dx * np.roll(flux_F_Y, 1, axis=1) 68 | 69 | return F 70 | 71 | 72 | def get_flux(rho_L, rho_R, vx_L, vx_R, vy_L, vy_R, P_L, P_R, gamma): 73 | """Calculate fluxes between 2 states with local Lax-Friedrichs/Rusanov rule""" 74 | 75 | # left and right energies 76 | en_L = P_L / (gamma - 1) + 0.5 * rho_L * (vx_L**2 + vy_L**2) 77 | en_R = P_R / (gamma - 1) + 0.5 * rho_R * (vx_R**2 + vy_R**2) 78 | 79 | # compute star (averaged) states 80 | rho_star = 0.5 * (rho_L + rho_R) 81 | momx_star = 0.5 * (rho_L * vx_L + rho_R * vx_R) 82 | momy_star = 0.5 * (rho_L * vy_L + rho_R * vy_R) 83 | en_star = 0.5 * (en_L + en_R) 84 | 85 | P_star = (gamma - 1) * (en_star - 0.5 * (momx_star**2 + momy_star**2) / rho_star) 86 | 87 | # compute fluxes (local Lax-Friedrichs/Rusanov) 88 | flux_Mass = momx_star 89 | flux_Momx = momx_star**2 / rho_star + P_star 90 | flux_Momy = momx_star * momy_star / rho_star 91 | flux_Energy = (en_star + P_star) * momx_star / rho_star 92 | 93 | # find wavespeeds 94 | C_L = np.sqrt(gamma * P_L / rho_L) + np.abs(vx_L) 95 | C_R = np.sqrt(gamma * P_R / rho_R) + np.abs(vx_R) 96 | C = np.maximum(C_L, C_R) 97 | 98 | # add stabilizing diffusive term 99 | flux_Mass -= C * 0.5 * (rho_L - rho_R) 100 | flux_Momx -= C * 0.5 * (rho_L * vx_L - rho_R * vx_R) 101 | flux_Momy -= C * 0.5 * (rho_L * vy_L - rho_R * vy_R) 102 
| flux_Energy -= C * 0.5 * (en_L - en_R) 103 | 104 | return flux_Mass, flux_Momx, flux_Momy, flux_Energy 105 | 106 | 107 | def update(Mass, Momx, Momy, Energy, vol, dx, gamma, courant_fac): 108 | """Take a simulation timestep""" 109 | 110 | # get Primitive variables 111 | rho, vx, vy, P = get_primitive(Mass, Momx, Momy, Energy, gamma, vol) 112 | 113 | # get time step (CFL) = dx / max signal speed 114 | dt = courant_fac * np.min(dx / (np.sqrt(gamma * P / rho) + np.sqrt(vx**2 + vy**2))) 115 | 116 | # calculate gradients 117 | rho_dx, rho_dy = get_gradient(rho, dx) 118 | vx_dx, vx_dy = get_gradient(vx, dx) 119 | vy_dx, vy_dy = get_gradient(vy, dx) 120 | P_dx, P_dy = get_gradient(P, dx) 121 | 122 | # extrapolate half-step in time 123 | rho_prime = rho - 0.5 * dt * (vx * rho_dx + rho * vx_dx + vy * rho_dy + rho * vy_dy) 124 | vx_prime = vx - 0.5 * dt * (vx * vx_dx + vy * vx_dy + (1 / rho) * P_dx) 125 | vy_prime = vy - 0.5 * dt * (vx * vy_dx + vy * vy_dy + (1 / rho) * P_dy) 126 | P_prime = P - 0.5 * dt * (gamma * P * (vx_dx + vy_dy) + vx * P_dx + vy * P_dy) 127 | 128 | # extrapolate in space to face centers 129 | rho_XL, rho_XR, rho_YL, rho_YR = extrapolate_to_face(rho_prime, rho_dx, rho_dy, dx) 130 | vx_XL, vx_XR, vx_YL, vx_YR = extrapolate_to_face(vx_prime, vx_dx, vx_dy, dx) 131 | vy_XL, vy_XR, vy_YL, vy_YR = extrapolate_to_face(vy_prime, vy_dx, vy_dy, dx) 132 | P_XL, P_XR, P_YL, P_YR = extrapolate_to_face(P_prime, P_dx, P_dy, dx) 133 | 134 | # compute fluxes (local Lax-Friedrichs/Rusanov) 135 | flux_Mass_X, flux_Momx_X, flux_Momy_X, flux_Energy_X = get_flux( 136 | rho_XL, rho_XR, vx_XL, vx_XR, vy_XL, vy_XR, P_XL, P_XR, gamma 137 | ) 138 | flux_Mass_Y, flux_Momy_Y, flux_Momx_Y, flux_Energy_Y = get_flux( 139 | rho_YL, rho_YR, vy_YL, vy_YR, vx_YL, vx_YR, P_YL, P_YR, gamma 140 | ) 141 | 142 | # update solution 143 | Mass = apply_fluxes(Mass, flux_Mass_X, flux_Mass_Y, dx, dt) 144 | Momx = apply_fluxes(Momx, flux_Momx_X, flux_Momx_Y, dx, dt) 145 | Momy = apply_fluxes(Momy, 
flux_Momy_X, flux_Momy_Y, dx, dt) 146 | Energy = apply_fluxes(Energy, flux_Energy_X, flux_Energy_Y, dx, dt) 147 | 148 | return Mass, Momx, Momy, Energy, dt, rho 149 | 150 | 151 | def main(): 152 | """Finite Volume simulation""" 153 | 154 | # Simulation parameters 155 | N = args.resolution 156 | boxsize = 1.0 157 | gamma = 5.0 / 3.0 # ideal gas gamma 158 | courant_fac = 0.4 159 | t_stop = 2.0 160 | save_freq = 0.1 161 | save_animation_path = "output_euler_numpy_" + str(N) + "double" 162 | 163 | # Mesh 164 | dx = boxsize / N 165 | vol = dx**2 166 | xlin = np.linspace(0.5 * dx, boxsize - 0.5 * dx, N) 167 | X, Y = np.meshgrid(xlin, xlin, indexing="ij") 168 | 169 | # Generate Initial Conditions - opposite moving streams with perturbation 170 | w0 = 0.1 171 | sigma = 0.05 / np.sqrt(2.0) 172 | rho = 1.0 + (np.abs(Y - 0.5) < 0.25) 173 | vx = -0.5 + (np.abs(Y - 0.5) < 0.25) 174 | vy = ( 175 | w0 176 | * np.sin(4 * np.pi * X) 177 | * ( 178 | np.exp(-((Y - 0.25) ** 2) / (2 * sigma**2)) 179 | + np.exp(-((Y - 0.75) ** 2) / (2 * sigma**2)) 180 | ) 181 | ) 182 | P = 2.5 * np.ones(X.shape) 183 | 184 | # Get conserved variables 185 | Mass, Momx, Momy, Energy = get_conserved(rho, vx, vy, P, gamma, vol) 186 | 187 | # Make animation directory if it doesn't exist 188 | if not os.path.exists(save_animation_path): 189 | os.makedirs(save_animation_path, exist_ok=True) 190 | 191 | # Simulation Main Loop 192 | tic = time.time() 193 | t = 0 194 | output_counter = 0 195 | n_iter = 0 196 | save_freq = 0.05 197 | while t < t_stop: 198 | 199 | # Time step 200 | Mass, Momx, Momy, Energy, dt, rho = update( 201 | Mass, Momx, Momy, Energy, vol, dx, gamma, courant_fac 202 | ) 203 | 204 | # determine if we should save the plot 205 | save_plot = False 206 | if t + dt > output_counter * save_freq: 207 | save_plot = True 208 | output_counter += 1 209 | 210 | # update time 211 | t += dt 212 | 213 | # update iteration counter 214 | n_iter += 1 215 | 216 | # save plot 217 | if save_plot: 218 | plt.imsave( 
219 | save_animation_path + "/rho" + str(output_counter).zfill(6) + ".png", 220 | np.rot90(rho), 221 | cmap="jet", 222 | vmin=0.8, 223 | vmax=2.2, 224 | ) 225 | 226 | # Print progress 227 | print("[it=" + str(n_iter) + " t=" + "{:.6f}".format(t) + "]") 228 | print( 229 | " saved state " 230 | + str(output_counter).zfill(6) 231 | + " of " 232 | + str(int(np.ceil(t_stop / save_freq))) 233 | ) 234 | 235 | # Print million updates per second 236 | cell_updates = X.shape[0] * X.shape[1] * n_iter 237 | total_time = time.time() - tic 238 | mcups = cell_updates / (1e6 * total_time) 239 | print(" million cell updates / second: ", mcups) 240 | 241 | print("Total time: ", total_time) 242 | 243 | 244 | if __name__ == "__main__": 245 | main() 246 | -------------------------------------------------------------------------------- /make_gif.sh: -------------------------------------------------------------------------------- 1 | 2 | ffmpeg -i output_euler_distributed_4096single/rho%06d.png output_euler_distributed_4096single.mp4 3 | 4 | ffmpeg -y -i output_euler_distributed_4096single.mp4 -vf fps=20,scale=300:-1:flags=lanczos,palettegen palette.png 5 | 6 | ffmpeg -i output_euler_distributed_4096single.mp4 -i palette.png -filter_complex "fps=10,scale=300:-1:flags=lanczos[x];[x][1:v]paletteuse" output_euler_distributed_4096single.gif 7 | 8 | rm palette.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax[cuda12] 2 | numpy 3 | -------------------------------------------------------------------------------- /results/plot_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import json 4 | 5 | 6 | # Plot scaling results of1024euler_distributed.py 7 | # mcups = million cell updates per second 8 | # time in seconds 9 | 10 | with open("runs.json", "r") 
as file: 11 | runs = json.load(file) 12 | 13 | 14 | def main(): 15 | 16 | # Plot scaling results 17 | 18 | # plot macbook mcups vs resolution for jax and numpy 19 | fig, ax = plt.subplots() 20 | libraries = ["jax", "jax", "numpy"] 21 | precisions = ["single", "double", "double"] 22 | styles = ["gs-", "bs-", "ro-"] 23 | for lib, prec, style in zip(libraries, precisions, styles): 24 | n_cells = [ 25 | r["resolution"] ** 2 26 | for r in runs 27 | if r["computer"] == "macbook" 28 | and r["library"] == lib 29 | and r["precision"] == prec 30 | ] 31 | mcups = [ 32 | r["mcups"] 33 | for r in runs 34 | if r["computer"] == "macbook" 35 | and r["library"] == lib 36 | and r["precision"] == prec 37 | ] 38 | ax.plot(n_cells, mcups, style, label=lib + " (" + prec + "-precision)") 39 | ax.set_xlabel("problem size: # cells") 40 | ax.set_ylabel("million cell updates per second") 41 | ax.set_xscale("log") 42 | ax.set_yscale("log") 43 | ax.set_xlim([2e3, 2e6]) 44 | ax.set_ylim([3e0, 1e2]) 45 | ax.set_title("Macbook M3 Max - strong scaling") 46 | ax.legend() 47 | plt.show() 48 | fig.savefig("scaling_strong.png") 49 | 50 | # plot rusty mcups vs number of GPUs(/resolution) for jax 51 | fig, ax = plt.subplots() 52 | libraries = ["jax", "jax"] 53 | precisions = ["single", "double"] 54 | styles = ["gs-", "bs-"] 55 | for lib, prec, style in zip(libraries, precisions, styles): 56 | n_devices = [ 57 | r["n_devices"] 58 | for r in runs 59 | if r["computer"] == "rusty" 60 | and r["library"] == lib 61 | and r["precision"] == prec 62 | ] 63 | mcups = [ 64 | r["mcups"] 65 | for r in runs 66 | if r["computer"] == "rusty" 67 | and r["library"] == lib 68 | and r["precision"] == prec 69 | ] 70 | ax.plot(n_devices, mcups, style, label=lib + " (" + prec + "-precision)") 71 | ax.set_xlabel("# gpus (# cells)") 72 | ax.set_ylabel("million cell updates per second") 73 | ax.set_xscale("log") 74 | ax.set_yscale("log") 75 | ax.set_xlim([0.5, 32]) 76 | ax.set_xticks([1, 4, 16]) 77 | ax.set_xticklabels( 78 | [ 79 
| "1 (4096^2)", 80 | "4 (8192^2)", 81 | "16 (16384^2)", 82 | ] 83 | ) 84 | ax.set_ylim([5e2, 3e4]) 85 | ax.set_title("Rusty - weak scaling") 86 | ax.legend() 87 | plt.show() 88 | fig.savefig("scaling_weak.png") 89 | 90 | return 91 | 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /results/result_16384_single.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pmocz/jax-euler-benchmarks/f47dedf0ae61e9bbb0591159f9edfe5271c3d393/results/result_16384_single.png -------------------------------------------------------------------------------- /results/runs.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "computer": "macbook", 4 | "chip": "M3 Max", 5 | "library": "jax", 6 | "precision": "double", 7 | "resolution": 64, 8 | "n_devices": 1, 9 | "n_nodes": 1, 10 | "mcups": 4.6, 11 | "iterations": 874, 12 | "total_time": 0.7 13 | }, 14 | { 15 | "computer": "macbook", 16 | "chip": "M3 Max", 17 | "library": "jax", 18 | "precision": "double", 19 | "resolution": 128, 20 | "n_devices": 1, 21 | "n_nodes": 1, 22 | "mcups": 15.0, 23 | "iterations": 1813, 24 | "total_time": 1.9 25 | }, 26 | { 27 | "computer": "macbook", 28 | "chip": "M3 Max", 29 | "library": "jax", 30 | "precision": "double", 31 | "resolution": 256, 32 | "n_devices": 1, 33 | "n_nodes": 1, 34 | "mcups": 22.2, 35 | "iterations": 3694, 36 | "total_time": 10.9 37 | }, 38 | { 39 | "computer": "macbook", 40 | "chip": "M3 Max", 41 | "library": "jax", 42 | "precision": "double", 43 | "resolution": 512, 44 | "n_devices": 1, 45 | "n_nodes": 1, 46 | "mcups": 27.0, 47 | "iterations": 7627, 48 | "total_time": 74.0 49 | }, 50 | { 51 | "computer": "macbook", 52 | "chip": "M3 Max", 53 | "library": "jax", 54 | "precision": "double", 55 | "resolution": 1024, 56 | "n_devices": 1, 57 | "n_nodes": 1, 58 | "mcups": 36.2, 59 | 
"iterations": 15424, 60 | "total_time": 446.2 61 | }, 62 | { 63 | "computer": "macbook", 64 | "chip": "M3 Max", 65 | "library": "numpy", 66 | "precision": "double", 67 | "resolution": 64, 68 | "n_devices": 1, 69 | "n_nodes": 1, 70 | "mcups": 7.4, 71 | "iterations": 874, 72 | "total_time": 0.4 73 | }, 74 | { 75 | "computer": "macbook", 76 | "chip": "M3 Max", 77 | "library": "jax", 78 | "precision": "single", 79 | "resolution": 64, 80 | "n_devices": 1, 81 | "n_nodes": 1, 82 | "mcups": 4.1, 83 | "iterations": 874, 84 | "total_time": 0.8 85 | }, 86 | { 87 | "computer": "macbook", 88 | "chip": "M3 Max", 89 | "library": "jax", 90 | "precision": "single", 91 | "resolution": 128, 92 | "n_devices": 1, 93 | "n_nodes": 1, 94 | "mcups": 17.9, 95 | "iterations": 1813, 96 | "total_time": 1.6 97 | }, 98 | { 99 | "computer": "macbook", 100 | "chip": "M3 Max", 101 | "library": "jax", 102 | "precision": "single", 103 | "resolution": 256, 104 | "n_devices": 1, 105 | "n_nodes": 1, 106 | "mcups": 31.8, 107 | "iterations": 3694, 108 | "total_time": 7.5 109 | }, 110 | { 111 | "computer": "macbook", 112 | "chip": "M3 Max", 113 | "library": "jax", 114 | "precision": "single", 115 | "resolution": 512, 116 | "n_devices": 1, 117 | "n_nodes": 1, 118 | "mcups": 53.6, 119 | "iterations": 7627, 120 | "total_time": 37.2 121 | }, 122 | { 123 | "computer": "macbook", 124 | "chip": "M3 Max", 125 | "library": "jax", 126 | "precision": "single", 127 | "resolution": 1024, 128 | "n_devices": 1, 129 | "n_nodes": 1, 130 | "mcups": 57.8, 131 | "iterations": 15426, 132 | "total_time": 279.4 133 | }, 134 | { 135 | "computer": "macbook", 136 | "chip": "M3 Max", 137 | "library": "numpy", 138 | "precision": "double", 139 | "resolution": 128, 140 | "n_devices": 1, 141 | "n_nodes": 1, 142 | "mcups": 10.8, 143 | "iterations": 1813, 144 | "total_time": 2.7 145 | }, 146 | { 147 | "computer": "macbook", 148 | "chip": "M3 Max", 149 | "library": "numpy", 150 | "precision": "double", 151 | "resolution": 256, 152 | 
"n_devices": 1, 153 | "n_nodes": 1, 154 | "mcups": 12.0, 155 | "iterations": 3694, 156 | "total_time": 20.1 157 | }, 158 | { 159 | "computer": "macbook", 160 | "chip": "M3 Max", 161 | "library": "numpy", 162 | "precision": "double", 163 | "resolution": 512, 164 | "n_devices": 1, 165 | "n_nodes": 1, 166 | "mcups": 10.2, 167 | "iterations": 7627, 168 | "total_time": 194.4 169 | }, 170 | { 171 | "computer": "macbook", 172 | "chip": "M3 Max", 173 | "library": "numpy", 174 | "precision": "double", 175 | "resolution": 1024, 176 | "n_devices": 1, 177 | "n_nodes": 1, 178 | "mcups": 9.9, 179 | "iterations": 15424, 180 | "total_time": 1632.3 181 | }, 182 | { 183 | "computer": "rusty", 184 | "chip": "A100", 185 | "library": "jax", 186 | "precision": "single", 187 | "resolution": 4096, 188 | "n_devices": 1, 189 | "n_nodes": 1, 190 | "mcups": 1298.9, 191 | "iterations": 66515, 192 | "total_time": 859.0 193 | }, 194 | { 195 | "computer": "rusty", 196 | "chip": "A100", 197 | "library": "jax", 198 | "precision": "single", 199 | "resolution": 8192, 200 | "n_devices": 4, 201 | "n_nodes": 1, 202 | "mcups": 6070.0, 203 | "iterations": 137145, 204 | "total_time": 1516.2 205 | }, 206 | { 207 | "computer": "rusty", 208 | "chip": "A100", 209 | "library": "jax", 210 | "precision": "single", 211 | "resolution": 16384, 212 | "n_devices": 16, 213 | "n_nodes": 4, 214 | "mcups": 19334.5, 215 | "iterations": 277300, 216 | "total_time": 3849.9 217 | }, 218 | { 219 | "computer": "rusty", 220 | "chip": "A100", 221 | "library": "jax", 222 | "precision": "double", 223 | "resolution": 4096, 224 | "n_devices": 1, 225 | "n_nodes": 1, 226 | "mcups": 652.2, 227 | "iterations": 65277, 228 | "total_time": 1679.0 229 | }, 230 | { 231 | "computer": "rusty", 232 | "chip": "A100", 233 | "library": "jax", 234 | "precision": "double", 235 | "resolution": 8192, 236 | "n_devices": 4, 237 | "n_nodes": 1, 238 | "mcups": 2592.2, 239 | "iterations": 134313, 240 | "total_time": 3477.1 241 | }, 242 | { 243 | "computer": 
"rusty", 244 | "chip": "A100", 245 | "library": "jax", 246 | "precision": "double", 247 | "resolution": 16384, 248 | "n_devices": 16, 249 | "n_nodes": 4, 250 | "mcups": 10141.3, 251 | "iterations": 276335, 252 | "total_time": 7314.4 253 | } 254 | ] -------------------------------------------------------------------------------- /results/scaling_strong.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pmocz/jax-euler-benchmarks/f47dedf0ae61e9bbb0591159f9edfe5271c3d393/results/scaling_strong.png -------------------------------------------------------------------------------- /results/scaling_weak.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pmocz/jax-euler-benchmarks/f47dedf0ae61e9bbb0591159f9edfe5271c3d393/results/scaling_weak.png -------------------------------------------------------------------------------- /sbatch_profile.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | #SBATCH --job-name=euler_parallel 3 | #SBATCH --output=slurm-%j.out 4 | #SBATCH --error=slurm-%j.err 5 | #SBATCH --partition gpu 6 | #SBATCH --constraint=a100 7 | #SBATCH --nodes=1 8 | #SBATCH --ntasks-per-node=1 9 | #SBATCH --gpus-per-node=1 10 | #SBATCH --cpus-per-task=2 11 | #SBATCH --mem=16G 12 | #SBATCH --time=00-00:10 13 | 14 | module purge 15 | module load cuda 16 | module load python 17 | 18 | export PYTHONUNBUFFERED=TRUE 19 | 20 | source $VENVDIR/my-jax-venv/bin/activate 21 | 22 | srun nsys profile --cuda-graph-trace=node --stats=true python euler_distributed.py --resolution=512 23 | -------------------------------------------------------------------------------- /sbatch_rusty_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | #SBATCH --job-name=euler_parallel 3 | #SBATCH --output=slurm-%j.out 4 | #SBATCH --error=slurm-%j.err 
5 | #SBATCH --partition gpu 6 | #SBATCH --constraint=a100 7 | #SBATCH --nodes=2 8 | #SBATCH --ntasks-per-node=4 9 | #SBATCH --gpus-per-node=4 10 | #SBATCH --cpus-per-task=2 11 | #SBATCH --mem=16G 12 | #SBATCH --time=00-01:00 13 | 14 | module purge 15 | module load cuda 16 | module load python 17 | 18 | export PYTHONUNBUFFERED=TRUE 19 | 20 | source $VENVDIR/my-jax-venv/bin/activate 21 | 22 | srun python euler_distributed.py --resolution=8192 --double 23 | --------------------------------------------------------------------------------