├── .flake8 ├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── backends.py ├── benchmarks ├── __init__.py ├── equation_of_state │ ├── README.md │ ├── __init__.py │ ├── eos_aesara.py │ ├── eos_cupy.py │ ├── eos_jax.py │ ├── eos_numba.py │ ├── eos_numpy.py │ ├── eos_pytorch.py │ ├── eos_taichi.py │ └── eos_tensorflow.py ├── isoneutral_mixing │ ├── README.md │ ├── __init__.py │ ├── isoneutral_aesara.py │ ├── isoneutral_cupy.py │ ├── isoneutral_jax.py │ ├── isoneutral_numba.py │ ├── isoneutral_numpy.py │ ├── isoneutral_pytorch.py │ └── isoneutral_taichi.py └── turbulent_kinetic_energy │ ├── README.md │ ├── __init__.py │ ├── tke_jax.py │ ├── tke_numba.py │ ├── tke_numpy.py │ └── tke_pytorch.py ├── environment-cpu.yml ├── environment-gpu.yml ├── plot.py ├── results ├── aws-plots │ ├── bench-equation_of_state-CPU.png │ ├── bench-equation_of_state-GPU.png │ ├── bench-isoneutral_mixing-CPU.png │ ├── bench-isoneutral_mixing-GPU.png │ ├── bench-turbulent_kinetic_energy-CPU.png │ └── bench-turbulent_kinetic_energy-GPU.png ├── aws.md ├── colab.md ├── magni-plots │ ├── bench-equation_of_state-CPU.png │ ├── bench-equation_of_state-GPU.png │ ├── bench-isoneutral_mixing-CPU.png │ ├── bench-isoneutral_mixing-GPU.png │ ├── bench-turbulent_kinetic_energy-CPU.png │ └── bench-turbulent_kinetic_energy-GPU.png ├── magni-run-all.sh ├── magni.md └── pyhpc_benchmarks_colab.ipynb ├── run.py └── utilities.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = __init__.py 3 | ignore = 4 | # Comments 5 | E26, 6 | # Lines starting / ending with binary operators 7 | W503, 8 | W504, 9 | max-line-length = 120 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # OSX temporary fiile 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 
| # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | # VSCode 110 | .vscode/ 111 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Häfner" 5 | given-names: "Dion" 6 | orcid: "https://orcid.org/0000-0002-4465-7317" 7 | title: "pyhpc-benchmarks" 8 | version: 3.0 9 | date-released: 2021-10-28 10 | url: "https://github.com/dionhaefner/pyhpc-benchmarks" 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DOI](https://zenodo.org/badge/212333820.svg)](https://zenodo.org/badge/latestdoi/212333820) 2 | 3 | # HPC benchmarks for Python 4 | 5 | This is a suite of benchmarks to test the *sequential CPU* and GPU performance of various computational backends with Python frontends. 6 | 7 | Specifically, we want to test which high-performance backend is best for *geophysical* (finite-difference based) *simulations*. 8 | 9 | **Contents** 10 | 11 | - [FAQ](#faq) 12 | - [Installation](#environment-setup) 13 | - [Usage](#usage) 14 | - [Example results](#example-results) 15 | - [Conclusion](#conclusion) 16 | - [Contributing](#contributing) 17 | 18 | ## FAQ 19 | 20 | ### Why? 21 | 22 | The scientific Python ecosystem is thriving, but high-performance computing in Python isn't really a thing yet. 23 | We try to change this [with our pure Python ocean simulator Veros](https://github.com/dionhaefner/veros), but which backend should we use for computations? 24 | 25 | Tremendous amounts of time and resources go into the development of Python frontends to high-performance backends, 26 | but those are usually tailored towards deep learning. We wanted to see whether we can profit from those advances, by 27 | (ab-)using these libraries for geophysical modelling. 28 | 29 | ### Why do the benchmarks look so weird? 30 | 31 | These are more or less verbatim copies from [Veros](https://github.com/dionhaefner/veros) (i.e., actual parts of a physical model). 32 | Most earth system and climate model components are based on finite-difference schemes to compute derivatives. This can be represented 33 | in vectorized form by index shifts of arrays (such as `0.5 * (arr[1:] + arr[:-1])`, the first-order derivative of `arr` at every point). 
The most common index range is `[2:-2]`, which represents the full domain (the two outermost grid cells are overlap / "ghost cells" that allow us to shift the array across the boundary). 34 | 35 | Now, maths is difficult, and numerics are weird. When many different physical quantities (defined on different grids) interact, things 36 | get messy very fast. 37 | 38 | ### Why only test sequential CPU performance? 39 | 40 | Two reasons: 41 | - I was curious to see how good the compilers are without being able to fall back to thread parallelism. 42 | - In many physical models, it is pretty straightforward to parallelize the model "by hand" via MPI. 43 | Therefore, we are not really dependent on good parallel performance out of the box. 44 | 45 | ### Which backends are currently supported? 46 | 47 | - [NumPy](https://numpy.org) (CPU only) 48 | - [Numba](https://numba.pydata.org) (CPU only) 49 | - [Aesara](https://github.com/aesara-devs/aesara) (CPU only) 50 | - [Jax](https://github.com/google/jax) 51 | - [Tensorflow](https://www.tensorflow.org) 52 | - [Pytorch](https://pytorch.org) 53 | - [CuPy](https://cupy.chainer.org/) (GPU only) 54 | - [Taichi](https://www.taichi-lang.org/) 55 | 56 | (not every backend is available for every benchmark) 57 | 58 | ### What is included in the measurements? 59 | 60 | Pure time spent number crunching. Preparing the inputs, copying stuff from and to GPU, compilation time, time it takes to check results etc. are excluded. 61 | This is based on the assumption that these things are only done a few times per simulation (i.e., that their cost is 62 | amortized during long-running simulations). 63 | 64 | ### How does this compare to a low-level implementation? 65 | 66 | As a rule of thumb (from our experience with Veros), the performance of a Fortran implementation is very close to that of the Numba backend, or ~3 times faster than NumPy. 
67 | 68 | 69 | ## Environment setup 70 | 71 | For CPU: 72 | 73 | ```bash 74 | $ conda env create -f environment-cpu.yml 75 | $ conda activate pyhpc-bench-cpu 76 | ``` 77 | 78 | GPU: 79 | 80 | ```bash 81 | $ conda env create -f environment-gpu.yml 82 | $ conda activate pyhpc-bench-gpu 83 | ``` 84 | 85 | If you prefer to install things by hand, just have a look at the environment files to see what you need. You don't need to install all backends; if a module is unavailable, it is skipped automatically. 86 | 87 | ## Usage 88 | 89 | Your entrypoint is the script `run.py`: 90 | 91 | ```bash 92 | $ python run.py --help 93 | Usage: run.py [OPTIONS] BENCHMARK 94 | 95 | HPC benchmarks for Python 96 | 97 | Usage: 98 | 99 | $ python run.py benchmarks/ 100 | 101 | Examples: 102 | 103 | $ taskset -c 0 python run.py benchmarks/equation_of_state 104 | 105 | $ python run.py benchmarks/equation_of_state -b numpy -b jax --device 106 | gpu 107 | 108 | More information: 109 | 110 | https://github.com/dionhaefner/pyhpc-benchmarks 111 | 112 | Options: 113 | -s, --size INTEGER Run benchmark for this array size 114 | (repeatable) [default: 4096, 16384, 65536, 115 | 262144, 1048576, 4194304] 116 | -b, --backend [numpy|cupy|jax|aesara|numba|pytorch|taichi|tensorflow] 117 | Run benchmark with this backend (repeatable) 118 | [default: run all backends] 119 | -r, --repetitions INTEGER Fixed number of iterations to run for each 120 | size and backend [default: auto-detect] 121 | --burnin INTEGER Number of initial iterations that are 122 | disregarded for final statistics [default: 123 | 1] 124 | --device [cpu|gpu|tpu] Run benchmarks on given device where 125 | supported by the backend [default: cpu] 126 | --help Show this message and exit. 127 | ``` 128 | 129 | Benchmarks are run for all combinations of the chosen sizes (`-s`) and backends (`-b`), in random order. 
130 | 131 | ### CPU 132 | 133 | Some backends refuse to be confined to a single thread, so I recommend you wrap your benchmarks 134 | in `taskset` to set processor affinity to a single core (only works on Linux): 135 | 136 | ```bash 137 | $ conda activate pyhpc-bench-cpu 138 | $ taskset -c 0 python run.py benchmarks/ 139 | ``` 140 | 141 | ### GPU 142 | 143 | Some backends use all available GPUs by default, some don't. If you have multiple GPUs, you can set the 144 | one to be used through `CUDA_VISIBLE_DEVICES`, to keep things fair. 145 | 146 | Some backends are greedy with allocating memory. On GPU, you can only run one backend at a time (add NumPy for reference): 147 | 148 | ```bash 149 | $ conda activate pyhpc-bench-gpu 150 | $ export CUDA_VISIBLE_DEVICES="0" 151 | $ for backend in jax cupy pytorch tensorflow; do 152 | ... python run.py benchmarks/ --device gpu -b $backend -b numpy -s 10_000_000 153 | ... done 154 | ``` 155 | 156 | ## Example results 157 | 158 | ### Summary 159 | 160 | #### Equation of state 161 | 162 |

163 | 164 | 165 |

166 | 167 | #### Isoneutral mixing 168 | 169 |

170 | 171 | 172 |

173 | 174 | #### Turbulent kinetic energy 175 | 176 |

177 | 178 | 179 |

180 | 181 | ### Full reports 182 | 183 | - [Example results on EC2 with Tesla V100 GPU](/results/aws.md) (more reliable) 184 | - [Example results on Google Colab](/results/colab.md) (easier to reproduce) 185 | - [Example results on bare metal](/results/magni.md) (most reliable, but outdated) 186 | 187 | ## Conclusion 188 | 189 | Lessons I learned by assembling these benchmarks: (your mileage may vary) 190 | 191 | - The performance of JAX is very competitive, both on GPU and CPU. It is consistently among the top implementations on both platforms. 192 | - Pytorch performs very well on GPU for large problems (slightly better than JAX), but its CPU performance is not great for tasks with many slicing operations. 193 | - Numba is a great choice on CPU if you don't mind writing explicit for loops (which can be more readable than a vectorized implementation), being slightly faster than JAX with little effort. 194 | - JAX performance on GPU seems to be quite hardware dependent. JAX performs significantly better (relatively speaking) on a Tesla P100 than a Tesla K80. 195 | - If you have embarrassingly parallel workloads, speedups of > 1000x are easy to achieve on high-end GPUs. 196 | - TPUs are catching up to GPUs. We can now get similar performance to a high-end GPU on these workloads. 197 | - Tensorflow is not great for applications like ours, since it lacks tools to apply partial updates to tensors (such as `tensor[2:-2] = 0.`). 198 | - If you use Tensorflow on CPU, make sure to use XLA (`experimental_compile`) for tremendous speedups. 199 | - CuPy is nice! Often you don't need to change anything in your NumPy code to have it run on GPU (with decent, but not outstanding performance). 200 | - Reaching Fortran performance on CPU for non-trivial tasks is hard :) 201 | 202 | ## Contributing 203 | 204 | Community contributions are encouraged! 
Whether you want to donate another benchmark, share your experience, optimize an implementation, or suggest another backend - [feel free to ask](https://github.com/dionhaefner/pyhpc-benchmarks/issues) or [open a PR](https://github.com/dionhaefner/pyhpc-benchmarks/pulls). 205 | 206 | ### Adding a new backend 207 | 208 | Adding a new backend is easy! 209 | 210 | Let's assume that you want to add support for a library called `speedygonzales`. All you need to do is this: 211 | 212 | - Implement a benchmark to use your library, e.g. `benchmarks/equation_of_state/eos_speedygonzales.py`. 213 | - Register the benchmark in the respective `__init__.py` file (`benchmarks/equation_of_state/__init__.py`), by adding `"speedygonzales"` to its `__implementations__` tuple. 214 | - Register the backend, by adding its setup function to the `__backends__` dict in [`backends.py`](https://github.com/dionhaefner/pyhpc-benchmarks/blob/master/backends.py). 215 | 216 | A setup function is what is called before every call to your benchmark, and can be used for custom setup and teardown. 
In the simplest case, it is just 217 | 218 | ```python 219 | def setup_speedygonzales(device='cpu'): 220 | # code to run before benchmark 221 | yield 222 | # code to run after benchmark 223 | ``` 224 | 225 | Then, you can run the benchmark with your new backend: 226 | 227 | ```bash 228 | $ python run.py benchmarks/equation_of_state -b speedygonzales 229 | ``` 230 | -------------------------------------------------------------------------------- /backends.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy 4 | 5 | 6 | def convert_to_numpy(arr, backend, device="cpu"): 7 | """Converts an array or collection of arrays to np.ndarray""" 8 | if isinstance(arr, (list, tuple)): 9 | return [convert_to_numpy(subarr, backend, device) for subarr in arr] 10 | 11 | if type(arr) is numpy.ndarray: 12 | # this is stricter than isinstance, 13 | # we don't want subclasses to get passed through 14 | return arr 15 | 16 | if backend == "cupy": 17 | return arr.get() 18 | 19 | if backend == "jax": 20 | return numpy.asarray(arr) 21 | 22 | if backend == "pytorch": 23 | if device == "gpu": 24 | return numpy.asarray(arr.cpu()) 25 | else: 26 | return numpy.asarray(arr) 27 | 28 | if backend == "tensorflow": 29 | return numpy.asarray(arr) 30 | 31 | if backend == "aesara": 32 | return numpy.asarray(arr) 33 | 34 | if backend == "taichi": 35 | return arr.to_numpy() 36 | 37 | raise RuntimeError( 38 | f"Got unexpected array / backend combination: {type(arr)} / {backend}" 39 | ) 40 | 41 | 42 | class BackendNotSupported(Exception): 43 | pass 44 | 45 | 46 | class BackendConflict(Exception): 47 | pass 48 | 49 | 50 | def check_backend_conflicts(backends, device): 51 | if device == "gpu": 52 | gpu_backends = set(backends) - {"numba", "numpy", "aesara"} 53 | if len(gpu_backends) > 1: 54 | raise BackendConflict( 55 | f"Can only use one GPU backend at the same time (got: {gpu_backends})" 56 | ) 57 | 58 | 59 | class SetupContext: 60 | def 
__init__(self, f): 61 | self._f = f 62 | self._f_args = (tuple(), dict()) 63 | 64 | def __call__(self, *args, **kwargs): 65 | self._f_args = (args, kwargs) 66 | return self 67 | 68 | def __enter__(self): 69 | self._env = os.environ.copy() 70 | args, kwargs = self._f_args 71 | self._f_iter = iter(self._f(*args, **kwargs)) 72 | 73 | try: 74 | module = next(self._f_iter) 75 | except Exception as e: 76 | raise BackendNotSupported(str(e)) from None 77 | 78 | return module 79 | 80 | def __exit__(self, *args, **kwargs): 81 | try: 82 | next(self._f_iter) 83 | except StopIteration: 84 | pass 85 | os.environ = self._env 86 | 87 | 88 | setup_function = SetupContext 89 | 90 | 91 | # setup function definitions 92 | 93 | 94 | @setup_function 95 | def setup_numpy(device="cpu"): 96 | import numpy 97 | 98 | os.environ.update( 99 | OMP_NUM_THREADS="1", 100 | ) 101 | yield numpy 102 | 103 | 104 | @setup_function 105 | def setup_aesara(device="cpu"): 106 | os.environ.update( 107 | OMP_NUM_THREADS="1", 108 | ) 109 | if device == "gpu": 110 | raise RuntimeError("aesara uses JAX on GPU") 111 | 112 | import aesara 113 | 114 | # clang needs this, aesara#127 115 | aesara.config.gcc__cxxflags = "-Wno-c++11-narrowing" 116 | yield aesara 117 | 118 | 119 | @setup_function 120 | def setup_numba(device="cpu"): 121 | os.environ.update( 122 | OMP_NUM_THREADS="1", 123 | ) 124 | import numba 125 | 126 | yield numba 127 | 128 | 129 | @setup_function 130 | def setup_cupy(device="cpu"): 131 | if device != "gpu": 132 | raise RuntimeError("cupy requires GPU mode") 133 | import cupy 134 | 135 | yield cupy 136 | 137 | 138 | @setup_function 139 | def setup_jax(device="cpu"): 140 | os.environ.update( 141 | XLA_FLAGS=( 142 | "--xla_cpu_multi_thread_eigen=false " 143 | "intra_op_parallelism_threads=1 " 144 | "inter_op_parallelism_threads=1 " 145 | ), 146 | ) 147 | 148 | if device in ("cpu", "gpu"): 149 | os.environ.update(JAX_PLATFORM_NAME=device) 150 | 151 | import jax 152 | 153 | if device == "tpu": 154 | 
jax.config.update("jax_xla_backend", "tpu_driver") 155 | jax.config.update("jax_backend_target", os.environ.get("JAX_BACKEND_TARGET")) 156 | 157 | if device != "tpu": 158 | # use 64 bit floats (not supported on TPU) 159 | jax.config.update("jax_enable_x64", True) 160 | 161 | if device == "gpu": 162 | assert len(jax.devices()) > 0 163 | 164 | yield jax 165 | 166 | 167 | @setup_function 168 | def setup_pytorch(device="cpu"): 169 | os.environ.update( 170 | OMP_NUM_THREADS="1", 171 | ) 172 | import torch 173 | 174 | if device == "gpu": 175 | assert torch.cuda.is_available() 176 | assert torch.cuda.device_count() > 0 177 | 178 | yield torch 179 | 180 | 181 | @setup_function 182 | def setup_tensorflow(device="cpu"): 183 | os.environ.update( 184 | OMP_NUM_THREADS="1", 185 | ) 186 | import tensorflow as tf 187 | 188 | tf.config.threading.set_inter_op_parallelism_threads(1) 189 | tf.config.threading.set_intra_op_parallelism_threads(1) 190 | 191 | if device == "gpu": 192 | gpus = tf.config.experimental.list_physical_devices("GPU") 193 | assert gpus 194 | else: 195 | tf.config.experimental.set_visible_devices([], "GPU") 196 | 197 | yield tf 198 | 199 | 200 | TAICHI_SETUP_DONE = False 201 | 202 | @setup_function 203 | def setup_taichi(device="cpu"): 204 | global TAICHI_SETUP_DONE 205 | import taichi 206 | 207 | if not TAICHI_SETUP_DONE: 208 | taichi.init( 209 | arch=taichi.cpu if device == "cpu" else taichi.gpu, 210 | cpu_max_num_threads=1, 211 | default_fp=taichi.f64, 212 | ) 213 | TAICHI_SETUP_DONE = True 214 | 215 | yield taichi 216 | 217 | __backends__ = { 218 | "numpy": setup_numpy, 219 | "cupy": setup_cupy, 220 | "jax": setup_jax, 221 | "aesara": setup_aesara, 222 | "numba": setup_numba, 223 | "pytorch": setup_pytorch, 224 | "tensorflow": setup_tensorflow, 225 | "taichi": setup_taichi, 226 | } 227 | -------------------------------------------------------------------------------- /benchmarks/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/benchmarks/__init__.py -------------------------------------------------------------------------------- /benchmarks/equation_of_state/README.md: -------------------------------------------------------------------------------- 1 | # Equation of state benchmark 2 | 3 | [Gibbs sea water](http://www.teos-10.org/software.htm) equation of state. 4 | The equation of state describes how certain thermodynamic parameters of sea water behave 5 | depending on its temperature, salinity content, and ambient pressure. 6 | 7 | This routine is embarassingly parallel and pure elementary math (i.e., a GPU's dream). 8 | -------------------------------------------------------------------------------- /benchmarks/equation_of_state/__init__.py: -------------------------------------------------------------------------------- 1 | import math 2 | import importlib 3 | import functools 4 | 5 | 6 | def generate_inputs(size): 7 | import numpy as np 8 | 9 | np.random.seed(17) 10 | 11 | shape = ( 12 | math.ceil(2 * size ** (1 / 3)), 13 | math.ceil(2 * size ** (1 / 3)), 14 | math.ceil(0.25 * size ** (1 / 3)), 15 | ) 16 | 17 | s = np.random.uniform(1e-2, 10, size=shape) 18 | t = np.random.uniform(-12, 20, size=shape) 19 | p = np.random.uniform(0, 1000, size=(1, 1, shape[-1])) 20 | return s, t, p 21 | 22 | 23 | def try_import(backend): 24 | try: 25 | return importlib.import_module(f".eos_{backend}", __name__) 26 | except ImportError: 27 | return None 28 | 29 | 30 | def get_callable(backend, size, device="cpu"): 31 | backend_module = try_import(backend) 32 | inputs = generate_inputs(size) 33 | if hasattr(backend_module, "prepare_inputs"): 34 | inputs = backend_module.prepare_inputs(*inputs, device=device) 35 | return functools.partial(backend_module.run, *inputs, device=device) 36 | 37 | 38 | __implementations__ = ( 39 
| "aesara", 40 | "cupy", 41 | "jax", 42 | "numba", 43 | "numpy", 44 | "pytorch", 45 | "tensorflow", 46 | "taichi", 47 | ) 48 | -------------------------------------------------------------------------------- /benchmarks/equation_of_state/eos_aesara.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================================================== 3 | in-situ density, dynamic enthalpy and derivatives 4 | from Absolute Salinity and Conservative 5 | Temperature, using the computationally-efficient 48-term expression for 6 | density in terms of SA, CT and p (IOC et al., 2010). 7 | ========================================================================== 8 | """ 9 | 10 | import aesara 11 | 12 | v01 = 9.998420897506056e2 13 | v02 = 2.839940833161907e0 14 | v03 = -3.147759265588511e-2 15 | v04 = 1.181805545074306e-3 16 | v05 = -6.698001071123802e0 17 | v06 = -2.986498947203215e-2 18 | v07 = 2.327859407479162e-4 19 | v08 = -3.988822378968490e-2 20 | v09 = 5.095422573880500e-4 21 | v10 = -1.426984671633621e-5 22 | v11 = 1.645039373682922e-7 23 | v12 = -2.233269627352527e-2 24 | v13 = -3.436090079851880e-4 25 | v14 = 3.726050720345733e-6 26 | v15 = -1.806789763745328e-4 27 | v16 = 6.876837219536232e-7 28 | v17 = -3.087032500374211e-7 29 | v18 = -1.988366587925593e-8 30 | v19 = -1.061519070296458e-11 31 | v20 = 1.550932729220080e-10 32 | v21 = 1.0e0 33 | v22 = 2.775927747785646e-3 34 | v23 = -2.349607444135925e-5 35 | v24 = 1.119513357486743e-6 36 | v25 = 6.743689325042773e-10 37 | v26 = -7.521448093615448e-3 38 | v27 = -2.764306979894411e-5 39 | v28 = 1.262937315098546e-7 40 | v29 = 9.527875081696435e-10 41 | v30 = -1.811147201949891e-11 42 | v31 = -3.303308871386421e-5 43 | v32 = 3.801564588876298e-7 44 | v33 = -7.672876869259043e-9 45 | v34 = -4.634182341116144e-11 46 | v35 = 2.681097235569143e-12 47 | v36 = 5.419326551148740e-6 48 | v37 = -2.742185394906099e-5 49 | v38 = -3.212746477974189e-7 50 | 
v39 = 3.191413910561627e-9 51 | v40 = -1.931012931541776e-12 52 | v41 = -1.105097577149576e-7 53 | v42 = 6.211426728363857e-10 54 | v43 = -1.119011592875110e-10 55 | v44 = -1.941660213148725e-11 56 | v45 = -1.864826425365600e-14 57 | v46 = 1.119522344879478e-14 58 | v47 = -1.200507748551599e-15 59 | v48 = 6.057902487546866e-17 60 | rho0 = 1024.0 61 | 62 | 63 | def gsw_dHdT(sa, ct, p): 64 | """ 65 | d/dT of dynamic enthalpy, analytical derivative 66 | 67 | sa : Absolute Salinity [g/kg] 68 | ct : Conservative Temperature [deg C] 69 | p : sea pressure [dbar] 70 | """ 71 | t1 = v45 * ct 72 | t2 = 0.2e1 * t1 73 | t3 = v46 * sa 74 | t4 = 0.5 * v12 75 | t5 = v14 * ct 76 | t7 = ct * (v13 + t5) 77 | t8 = 0.5 * t7 78 | t11 = sa * (v15 + v16 * ct) 79 | t12 = 0.5 * t11 80 | t13 = t4 + t8 + t12 81 | t15 = v19 * ct 82 | t19 = v17 + ct * (v18 + t15) + v20 * sa 83 | t20 = 1.0 / t19 84 | t24 = v47 + v48 * ct 85 | t25 = 0.5 * v13 86 | t26 = 1.0 * t5 87 | t27 = sa * v16 88 | t28 = 0.5 * t27 89 | t29 = t25 + t26 + t28 90 | t33 = t24 * t13 91 | t34 = t19 ** 2 92 | t35 = 1.0 / t34 93 | t37 = v18 + 2.0 * t15 94 | t38 = t35 * t37 95 | t48 = ct * (v44 + t1 + t3) 96 | t57 = v40 * ct 97 | t59 = ct * (v39 + t57) 98 | t64 = t13 ** 2 99 | t68 = t20 * t29 100 | t71 = t24 * t64 101 | t74 = v04 * ct 102 | t76 = ct * (v03 + t74) 103 | t79 = v07 * ct 104 | t82 = aesara.tensor.sqrt(sa) 105 | t83 = v11 * ct 106 | t85 = ct * (v10 + t83) 107 | t92 = ( 108 | v01 109 | + ct * (v02 + t76) 110 | + sa * (v05 + ct * (v06 + t79) + t82 * (v08 + ct * (v09 + t85))) 111 | ) 112 | t93 = v48 * t92 113 | t105 = ( 114 | v02 115 | + t76 116 | + ct * (v03 + 2.0 * t74) 117 | + sa * (v06 + 2.0 * t79 + t82 * (v09 + t85 + ct * (v10 + 2.0 * t83))) 118 | ) 119 | t106 = t24 * t105 120 | t107 = v44 + t2 + t3 121 | t110 = v43 + t48 122 | t117 = t24 * t92 123 | t120 = 4.0 * t71 * t20 - t117 - 2.0 * t110 * t13 124 | t123 = ( 125 | v38 126 | + t59 127 | + ct * (v39 + 2.0 * t57) 128 | + sa * v42 129 | + ( 130 | 4.0 * v48 * t64 * t20 
131 | + 8.0 * t33 * t68 132 | - 4.0 * t71 * t38 133 | - t93 134 | - t106 135 | - 2.0 * t107 * t13 136 | - 2.0 * t110 * t29 137 | ) 138 | * t20 139 | - t120 * t35 * t37 140 | ) 141 | t128 = t19 * p 142 | t130 = p * (1.0 * v12 + 1.0 * t7 + 1.0 * t11 + t128) 143 | t131 = 1.0 / t92 144 | t133 = 1.0 + t130 * t131 145 | t134 = aesara.tensor.log(t133) 146 | t143 = v37 + ct * (v38 + t59) + sa * (v41 + v42 * ct) + t120 * t20 147 | t152 = t37 * p 148 | t156 = t92 ** 2 149 | t165 = v25 * ct 150 | t167 = ct * (v24 + t165) 151 | t169 = ct * (v23 + t167) 152 | t175 = v30 * ct 153 | t177 = ct * (v29 + t175) 154 | t179 = ct * (v28 + t177) 155 | t185 = v35 * ct 156 | t187 = ct * (v34 + t185) 157 | t189 = ct * (v33 + t187) 158 | t199 = t13 * t20 159 | t217 = 2.0 * t117 * t199 - t110 * t92 160 | t234 = ( 161 | v21 162 | + ct * (v22 + t169) 163 | + sa * (v26 + ct * (v27 + t179) + v36 * sa + t82 * (v31 + ct * (v32 + t189))) 164 | + t217 * t20 165 | ) 166 | t241 = t64 - t92 * t19 167 | t242 = aesara.tensor.sqrt(t241) 168 | t243 = 1.0 / t242 169 | t244 = t4 + t8 + t12 - t242 170 | t245 = 1.0 / t244 171 | t247 = t4 + t8 + t12 + t242 + t128 172 | t248 = 1.0 / t247 173 | t249 = t242 * t245 * t248 174 | t252 = 1.0 + 2.0 * t128 * t249 175 | t253 = aesara.tensor.log(t252) 176 | t254 = t243 * t253 177 | t259 = t234 * t19 - t143 * t13 178 | t264 = t259 * t20 179 | t272 = 2.0 * t13 * t29 - t105 * t19 - t92 * t37 180 | t282 = t128 * t242 181 | t283 = t244 ** 2 182 | t287 = t243 * t272 / 2.0 183 | t292 = t247 ** 2 184 | t305 = ( 185 | 0.1e5 186 | * p 187 | * ( 188 | v44 189 | + t2 190 | + t3 191 | - 2.0 * v48 * t13 * t20 192 | - 2.0 * t24 * t29 * t20 193 | + 2.0 * t33 * t38 194 | + 0.5 * v48 * p 195 | ) 196 | * t20 197 | - 0.1e5 * p * (v43 + t48 - 2.0 * t33 * t20 + 0.5 * t24 * p) * t38 198 | + 0.5e4 * t123 * t20 * t134 199 | - 0.5e4 * t143 * t35 * t134 * t37 200 | + 0.5e4 201 | * t143 202 | * t20 203 | * (p * (1.0 * v13 + 2.0 * t5 + 1.0 * t27 + t152) * t131 - t130 / t156 * t105) 204 | / t133 205 | 
+ 0.5e4 206 | * ( 207 | ( 208 | v22 209 | + t169 210 | + ct * (v23 + t167 + ct * (v24 + 2.0 * t165)) 211 | + sa 212 | * ( 213 | v27 214 | + t179 215 | + ct * (v28 + t177 + ct * (v29 + 2.0 * t175)) 216 | + t82 * (v32 + t189 + ct * (v33 + t187 + ct * (v34 + 2.0 * t185))) 217 | ) 218 | + ( 219 | 2.0 * t93 * t199 220 | + 2.0 * t106 * t199 221 | + 2.0 * t117 * t68 222 | - 2.0 * t117 * t13 * t35 * t37 223 | - t107 * t92 224 | - t110 * t105 225 | ) 226 | * t20 227 | - t217 * t35 * t37 228 | ) 229 | * t19 230 | + t234 * t37 231 | - t123 * t13 232 | - t143 * t29 233 | ) 234 | * t20 235 | * t254 236 | - 0.5e4 * t259 * t35 * t254 * t37 237 | - 0.25e4 * t264 / t242 / t241 * t253 * t272 238 | + 0.5e4 239 | * t264 240 | * t243 241 | * ( 242 | 2.0 * t152 * t249 243 | + t128 * t243 * t245 * t248 * t272 244 | - 2.0 * t282 / t283 * t248 * (t25 + t26 + t28 - t287) 245 | - 2.0 * t282 * t245 / t292 * (t25 + t26 + t28 + t287 + t152) 246 | ) 247 | / t252 248 | ) 249 | 250 | return t305 251 | 252 | 253 | symbolic_inputs = [ 254 | aesara.tensor.dtensor3("s"), 255 | aesara.tensor.dtensor3("t"), 256 | aesara.tensor.TensorType(dtype="float64", broadcastable=(True, True, False))("p"), 257 | ] 258 | gsw_dHdT_aesara = aesara.function(symbolic_inputs, gsw_dHdT(*symbolic_inputs)) 259 | 260 | 261 | def run(sa, ct, p, device="cpu"): 262 | return gsw_dHdT_aesara(sa, ct, p) 263 | -------------------------------------------------------------------------------- /benchmarks/equation_of_state/eos_cupy.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================================================== 3 | in-situ density, dynamic enthalpy and derivatives 4 | from Absolute Salinity and Conservative 5 | Temperature, using the computationally-efficient 48-term expression for 6 | density in terms of SA, CT and p (IOC et al., 2010). 
7 | ========================================================================== 8 | """ 9 | 10 | import cupy as cp 11 | 12 | v01 = 9.998420897506056e2 13 | v02 = 2.839940833161907e0 14 | v03 = -3.147759265588511e-2 15 | v04 = 1.181805545074306e-3 16 | v05 = -6.698001071123802e0 17 | v06 = -2.986498947203215e-2 18 | v07 = 2.327859407479162e-4 19 | v08 = -3.988822378968490e-2 20 | v09 = 5.095422573880500e-4 21 | v10 = -1.426984671633621e-5 22 | v11 = 1.645039373682922e-7 23 | v12 = -2.233269627352527e-2 24 | v13 = -3.436090079851880e-4 25 | v14 = 3.726050720345733e-6 26 | v15 = -1.806789763745328e-4 27 | v16 = 6.876837219536232e-7 28 | v17 = -3.087032500374211e-7 29 | v18 = -1.988366587925593e-8 30 | v19 = -1.061519070296458e-11 31 | v20 = 1.550932729220080e-10 32 | v21 = 1.0e0 33 | v22 = 2.775927747785646e-3 34 | v23 = -2.349607444135925e-5 35 | v24 = 1.119513357486743e-6 36 | v25 = 6.743689325042773e-10 37 | v26 = -7.521448093615448e-3 38 | v27 = -2.764306979894411e-5 39 | v28 = 1.262937315098546e-7 40 | v29 = 9.527875081696435e-10 41 | v30 = -1.811147201949891e-11 42 | v31 = -3.303308871386421e-5 43 | v32 = 3.801564588876298e-7 44 | v33 = -7.672876869259043e-9 45 | v34 = -4.634182341116144e-11 46 | v35 = 2.681097235569143e-12 47 | v36 = 5.419326551148740e-6 48 | v37 = -2.742185394906099e-5 49 | v38 = -3.212746477974189e-7 50 | v39 = 3.191413910561627e-9 51 | v40 = -1.931012931541776e-12 52 | v41 = -1.105097577149576e-7 53 | v42 = 6.211426728363857e-10 54 | v43 = -1.119011592875110e-10 55 | v44 = -1.941660213148725e-11 56 | v45 = -1.864826425365600e-14 57 | v46 = 1.119522344879478e-14 58 | v47 = -1.200507748551599e-15 59 | v48 = 6.057902487546866e-17 60 | rho0 = 1024.0 61 | 62 | 63 | def gsw_dHdT(sa, ct, p): 64 | """ 65 | d/dT of dynamic enthalpy, analytical derivative 66 | 67 | sa : Absolute Salinity [g/kg] 68 | ct : Conservative Temperature [deg C] 69 | p : sea pressure [dbar] 70 | """ 71 | t1 = v45 * ct 72 | t2 = 0.2e1 * t1 73 | t3 = v46 * sa 74 | t4 = 0.5 * v12 
75 | t5 = v14 * ct 76 | t7 = ct * (v13 + t5) 77 | t8 = 0.5 * t7 78 | t11 = sa * (v15 + v16 * ct) 79 | t12 = 0.5 * t11 80 | t13 = t4 + t8 + t12 81 | t15 = v19 * ct 82 | t19 = v17 + ct * (v18 + t15) + v20 * sa 83 | t20 = 1.0 / t19 84 | t24 = v47 + v48 * ct 85 | t25 = 0.5 * v13 86 | t26 = 1.0 * t5 87 | t27 = sa * v16 88 | t28 = 0.5 * t27 89 | t29 = t25 + t26 + t28 90 | t33 = t24 * t13 91 | t34 = t19 ** 2 92 | t35 = 1.0 / t34 93 | t37 = v18 + 2.0 * t15 94 | t38 = t35 * t37 95 | t48 = ct * (v44 + t1 + t3) 96 | t57 = v40 * ct 97 | t59 = ct * (v39 + t57) 98 | t64 = t13 ** 2 99 | t68 = t20 * t29 100 | t71 = t24 * t64 101 | t74 = v04 * ct 102 | t76 = ct * (v03 + t74) 103 | t79 = v07 * ct 104 | t82 = cp.sqrt(sa) 105 | t83 = v11 * ct 106 | t85 = ct * (v10 + t83) 107 | t92 = ( 108 | v01 109 | + ct * (v02 + t76) 110 | + sa * (v05 + ct * (v06 + t79) + t82 * (v08 + ct * (v09 + t85))) 111 | ) 112 | t93 = v48 * t92 113 | t105 = ( 114 | v02 115 | + t76 116 | + ct * (v03 + 2.0 * t74) 117 | + sa * (v06 + 2.0 * t79 + t82 * (v09 + t85 + ct * (v10 + 2.0 * t83))) 118 | ) 119 | t106 = t24 * t105 120 | t107 = v44 + t2 + t3 121 | t110 = v43 + t48 122 | t117 = t24 * t92 123 | t120 = 4.0 * t71 * t20 - t117 - 2.0 * t110 * t13 124 | t123 = ( 125 | v38 126 | + t59 127 | + ct * (v39 + 2.0 * t57) 128 | + sa * v42 129 | + ( 130 | 4.0 * v48 * t64 * t20 131 | + 8.0 * t33 * t68 132 | - 4.0 * t71 * t38 133 | - t93 134 | - t106 135 | - 2.0 * t107 * t13 136 | - 2.0 * t110 * t29 137 | ) 138 | * t20 139 | - t120 * t35 * t37 140 | ) 141 | t128 = t19 * p 142 | t130 = p * (1.0 * v12 + 1.0 * t7 + 1.0 * t11 + t128) 143 | t131 = 1.0 / t92 144 | t133 = 1.0 + t130 * t131 145 | t134 = cp.log(t133) 146 | t143 = v37 + ct * (v38 + t59) + sa * (v41 + v42 * ct) + t120 * t20 147 | t152 = t37 * p 148 | t156 = t92 ** 2 149 | t165 = v25 * ct 150 | t167 = ct * (v24 + t165) 151 | t169 = ct * (v23 + t167) 152 | t175 = v30 * ct 153 | t177 = ct * (v29 + t175) 154 | t179 = ct * (v28 + t177) 155 | t185 = v35 * ct 156 | t187 = ct * 
(v34 + t185) 157 | t189 = ct * (v33 + t187) 158 | t199 = t13 * t20 159 | t217 = 2.0 * t117 * t199 - t110 * t92 160 | t234 = ( 161 | v21 162 | + ct * (v22 + t169) 163 | + sa * (v26 + ct * (v27 + t179) + v36 * sa + t82 * (v31 + ct * (v32 + t189))) 164 | + t217 * t20 165 | ) 166 | t241 = t64 - t92 * t19 167 | t242 = cp.sqrt(t241) 168 | t243 = 1.0 / t242 169 | t244 = t4 + t8 + t12 - t242 170 | t245 = 1.0 / t244 171 | t247 = t4 + t8 + t12 + t242 + t128 172 | t248 = 1.0 / t247 173 | t249 = t242 * t245 * t248 174 | t252 = 1.0 + 2.0 * t128 * t249 175 | t253 = cp.log(t252) 176 | t254 = t243 * t253 177 | t259 = t234 * t19 - t143 * t13 178 | t264 = t259 * t20 179 | t272 = 2.0 * t13 * t29 - t105 * t19 - t92 * t37 180 | t282 = t128 * t242 181 | t283 = t244 ** 2 182 | t287 = t243 * t272 / 2.0 183 | t292 = t247 ** 2 184 | t305 = ( 185 | 0.1e5 186 | * p 187 | * ( 188 | v44 189 | + t2 190 | + t3 191 | - 2.0 * v48 * t13 * t20 192 | - 2.0 * t24 * t29 * t20 193 | + 2.0 * t33 * t38 194 | + 0.5 * v48 * p 195 | ) 196 | * t20 197 | - 0.1e5 * p * (v43 + t48 - 2.0 * t33 * t20 + 0.5 * t24 * p) * t38 198 | + 0.5e4 * t123 * t20 * t134 199 | - 0.5e4 * t143 * t35 * t134 * t37 200 | + 0.5e4 201 | * t143 202 | * t20 203 | * (p * (1.0 * v13 + 2.0 * t5 + 1.0 * t27 + t152) * t131 - t130 / t156 * t105) 204 | / t133 205 | + 0.5e4 206 | * ( 207 | ( 208 | v22 209 | + t169 210 | + ct * (v23 + t167 + ct * (v24 + 2.0 * t165)) 211 | + sa 212 | * ( 213 | v27 214 | + t179 215 | + ct * (v28 + t177 + ct * (v29 + 2.0 * t175)) 216 | + t82 * (v32 + t189 + ct * (v33 + t187 + ct * (v34 + 2.0 * t185))) 217 | ) 218 | + ( 219 | 2.0 * t93 * t199 220 | + 2.0 * t106 * t199 221 | + 2.0 * t117 * t68 222 | - 2.0 * t117 * t13 * t35 * t37 223 | - t107 * t92 224 | - t110 * t105 225 | ) 226 | * t20 227 | - t217 * t35 * t37 228 | ) 229 | * t19 230 | + t234 * t37 231 | - t123 * t13 232 | - t143 * t29 233 | ) 234 | * t20 235 | * t254 236 | - 0.5e4 * t259 * t35 * t254 * t37 237 | - 0.25e4 * t264 / t242 / t241 * t253 * t272 238 | + 
0.5e4 239 | * t264 240 | * t243 241 | * ( 242 | 2.0 * t152 * t249 243 | + t128 * t243 * t245 * t248 * t272 244 | - 2.0 * t282 / t283 * t248 * (t25 + t26 + t28 - t287) 245 | - 2.0 * t282 * t245 / t292 * (t25 + t26 + t28 + t287 + t152) 246 | ) 247 | / t252 248 | ) 249 | 250 | return t305 251 | 252 | 253 | def prepare_inputs(sa, ct, p, device): 254 | out = [cp.asarray(k) for k in (sa, ct, p)] 255 | cp.cuda.stream.get_current_stream().synchronize() 256 | return out 257 | 258 | 259 | def run(sa, ct, p, device="cpu"): 260 | out = gsw_dHdT(sa, ct, p) 261 | cp.cuda.stream.get_current_stream().synchronize() 262 | return out 263 | -------------------------------------------------------------------------------- /benchmarks/equation_of_state/eos_jax.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================================================== 3 | in-situ density, dynamic enthalpy and derivatives 4 | from Absolute Salinity and Conservative 5 | Temperature, using the computationally-efficient 48-term expression for 6 | density in terms of SA, CT and p (IOC et al., 2010). 
7 | ========================================================================== 8 | """ 9 | 10 | import jax 11 | import jax.numpy as np 12 | 13 | v01 = 9.998420897506056e2 14 | v02 = 2.839940833161907e0 15 | v03 = -3.147759265588511e-2 16 | v04 = 1.181805545074306e-3 17 | v05 = -6.698001071123802e0 18 | v06 = -2.986498947203215e-2 19 | v07 = 2.327859407479162e-4 20 | v08 = -3.988822378968490e-2 21 | v09 = 5.095422573880500e-4 22 | v10 = -1.426984671633621e-5 23 | v11 = 1.645039373682922e-7 24 | v12 = -2.233269627352527e-2 25 | v13 = -3.436090079851880e-4 26 | v14 = 3.726050720345733e-6 27 | v15 = -1.806789763745328e-4 28 | v16 = 6.876837219536232e-7 29 | v17 = -3.087032500374211e-7 30 | v18 = -1.988366587925593e-8 31 | v19 = -1.061519070296458e-11 32 | v20 = 1.550932729220080e-10 33 | v21 = 1.0e0 34 | v22 = 2.775927747785646e-3 35 | v23 = -2.349607444135925e-5 36 | v24 = 1.119513357486743e-6 37 | v25 = 6.743689325042773e-10 38 | v26 = -7.521448093615448e-3 39 | v27 = -2.764306979894411e-5 40 | v28 = 1.262937315098546e-7 41 | v29 = 9.527875081696435e-10 42 | v30 = -1.811147201949891e-11 43 | v31 = -3.303308871386421e-5 44 | v32 = 3.801564588876298e-7 45 | v33 = -7.672876869259043e-9 46 | v34 = -4.634182341116144e-11 47 | v35 = 2.681097235569143e-12 48 | v36 = 5.419326551148740e-6 49 | v37 = -2.742185394906099e-5 50 | v38 = -3.212746477974189e-7 51 | v39 = 3.191413910561627e-9 52 | v40 = -1.931012931541776e-12 53 | v41 = -1.105097577149576e-7 54 | v42 = 6.211426728363857e-10 55 | v43 = -1.119011592875110e-10 56 | v44 = -1.941660213148725e-11 57 | v45 = -1.864826425365600e-14 58 | v46 = 1.119522344879478e-14 59 | v47 = -1.200507748551599e-15 60 | v48 = 6.057902487546866e-17 61 | rho0 = 1024.0 62 | 63 | 64 | @jax.jit 65 | def gsw_dHdT(sa, ct, p): 66 | """ 67 | d/dT of dynamic enthalpy, analytical derivative 68 | 69 | sa : Absolute Salinity [g/kg] 70 | ct : Conservative Temperature [deg C] 71 | p : sea pressure [dbar] 72 | """ 73 | t1 = v45 * ct 74 | t2 = 0.2e1 * t1 75 | 
t3 = v46 * sa 76 | t4 = 0.5 * v12 77 | t5 = v14 * ct 78 | t7 = ct * (v13 + t5) 79 | t8 = 0.5 * t7 80 | t11 = sa * (v15 + v16 * ct) 81 | t12 = 0.5 * t11 82 | t13 = t4 + t8 + t12 83 | t15 = v19 * ct 84 | t19 = v17 + ct * (v18 + t15) + v20 * sa 85 | t20 = 1.0 / t19 86 | t24 = v47 + v48 * ct 87 | t25 = 0.5 * v13 88 | t26 = 1.0 * t5 89 | t27 = sa * v16 90 | t28 = 0.5 * t27 91 | t29 = t25 + t26 + t28 92 | t33 = t24 * t13 93 | t34 = t19 ** 2 94 | t35 = 1.0 / t34 95 | t37 = v18 + 2.0 * t15 96 | t38 = t35 * t37 97 | t48 = ct * (v44 + t1 + t3) 98 | t57 = v40 * ct 99 | t59 = ct * (v39 + t57) 100 | t64 = t13 ** 2 101 | t68 = t20 * t29 102 | t71 = t24 * t64 103 | t74 = v04 * ct 104 | t76 = ct * (v03 + t74) 105 | t79 = v07 * ct 106 | t82 = np.sqrt(sa) 107 | t83 = v11 * ct 108 | t85 = ct * (v10 + t83) 109 | t92 = ( 110 | v01 111 | + ct * (v02 + t76) 112 | + sa * (v05 + ct * (v06 + t79) + t82 * (v08 + ct * (v09 + t85))) 113 | ) 114 | t93 = v48 * t92 115 | t105 = ( 116 | v02 117 | + t76 118 | + ct * (v03 + 2.0 * t74) 119 | + sa * (v06 + 2.0 * t79 + t82 * (v09 + t85 + ct * (v10 + 2.0 * t83))) 120 | ) 121 | t106 = t24 * t105 122 | t107 = v44 + t2 + t3 123 | t110 = v43 + t48 124 | t117 = t24 * t92 125 | t120 = 4.0 * t71 * t20 - t117 - 2.0 * t110 * t13 126 | t123 = ( 127 | v38 128 | + t59 129 | + ct * (v39 + 2.0 * t57) 130 | + sa * v42 131 | + ( 132 | 4.0 * v48 * t64 * t20 133 | + 8.0 * t33 * t68 134 | - 4.0 * t71 * t38 135 | - t93 136 | - t106 137 | - 2.0 * t107 * t13 138 | - 2.0 * t110 * t29 139 | ) 140 | * t20 141 | - t120 * t35 * t37 142 | ) 143 | t128 = t19 * p 144 | t130 = p * (1.0 * v12 + 1.0 * t7 + 1.0 * t11 + t128) 145 | t131 = 1.0 / t92 146 | t133 = 1.0 + t130 * t131 147 | t134 = np.log(t133) 148 | t143 = v37 + ct * (v38 + t59) + sa * (v41 + v42 * ct) + t120 * t20 149 | t152 = t37 * p 150 | t156 = t92 ** 2 151 | t165 = v25 * ct 152 | t167 = ct * (v24 + t165) 153 | t169 = ct * (v23 + t167) 154 | t175 = v30 * ct 155 | t177 = ct * (v29 + t175) 156 | t179 = ct * (v28 + t177) 157 
| t185 = v35 * ct 158 | t187 = ct * (v34 + t185) 159 | t189 = ct * (v33 + t187) 160 | t199 = t13 * t20 161 | t217 = 2.0 * t117 * t199 - t110 * t92 162 | t234 = ( 163 | v21 164 | + ct * (v22 + t169) 165 | + sa * (v26 + ct * (v27 + t179) + v36 * sa + t82 * (v31 + ct * (v32 + t189))) 166 | + t217 * t20 167 | ) 168 | t241 = t64 - t92 * t19 169 | t242 = np.sqrt(t241) 170 | t243 = 1.0 / t242 171 | t244 = t4 + t8 + t12 - t242 172 | t245 = 1.0 / t244 173 | t247 = t4 + t8 + t12 + t242 + t128 174 | t248 = 1.0 / t247 175 | t249 = t242 * t245 * t248 176 | t252 = 1.0 + 2.0 * t128 * t249 177 | t253 = np.log(t252) 178 | t254 = t243 * t253 179 | t259 = t234 * t19 - t143 * t13 180 | t264 = t259 * t20 181 | t272 = 2.0 * t13 * t29 - t105 * t19 - t92 * t37 182 | t282 = t128 * t242 183 | t283 = t244 ** 2 184 | t287 = t243 * t272 / 2.0 185 | t292 = t247 ** 2 186 | t305 = ( 187 | 0.1e5 188 | * p 189 | * ( 190 | v44 191 | + t2 192 | + t3 193 | - 2.0 * v48 * t13 * t20 194 | - 2.0 * t24 * t29 * t20 195 | + 2.0 * t33 * t38 196 | + 0.5 * v48 * p 197 | ) 198 | * t20 199 | - 0.1e5 * p * (v43 + t48 - 2.0 * t33 * t20 + 0.5 * t24 * p) * t38 200 | + 0.5e4 * t123 * t20 * t134 201 | - 0.5e4 * t143 * t35 * t134 * t37 202 | + 0.5e4 203 | * t143 204 | * t20 205 | * (p * (1.0 * v13 + 2.0 * t5 + 1.0 * t27 + t152) * t131 - t130 / t156 * t105) 206 | / t133 207 | + 0.5e4 208 | * ( 209 | ( 210 | v22 211 | + t169 212 | + ct * (v23 + t167 + ct * (v24 + 2.0 * t165)) 213 | + sa 214 | * ( 215 | v27 216 | + t179 217 | + ct * (v28 + t177 + ct * (v29 + 2.0 * t175)) 218 | + t82 * (v32 + t189 + ct * (v33 + t187 + ct * (v34 + 2.0 * t185))) 219 | ) 220 | + ( 221 | 2.0 * t93 * t199 222 | + 2.0 * t106 * t199 223 | + 2.0 * t117 * t68 224 | - 2.0 * t117 * t13 * t35 * t37 225 | - t107 * t92 226 | - t110 * t105 227 | ) 228 | * t20 229 | - t217 * t35 * t37 230 | ) 231 | * t19 232 | + t234 * t37 233 | - t123 * t13 234 | - t143 * t29 235 | ) 236 | * t20 237 | * t254 238 | - 0.5e4 * t259 * t35 * t254 * t37 239 | - 0.25e4 * t264 / 
t242 / t241 * t253 * t272 240 | + 0.5e4 241 | * t264 242 | * t243 243 | * ( 244 | 2.0 * t152 * t249 245 | + t128 * t243 * t245 * t248 * t272 246 | - 2.0 * t282 / t283 * t248 * (t25 + t26 + t28 - t287) 247 | - 2.0 * t282 * t245 / t292 * (t25 + t26 + t28 + t287 + t152) 248 | ) 249 | / t252 250 | ) 251 | 252 | return t305 253 | 254 | 255 | def prepare_inputs(sa, ct, p, device): 256 | out = [np.array(k) for k in (sa, ct, p)] 257 | for o in out: 258 | o.block_until_ready() 259 | return out 260 | 261 | 262 | def run(sa, ct, p, device="cpu"): 263 | out = gsw_dHdT(sa, ct, p) 264 | out.block_until_ready() 265 | return out 266 | -------------------------------------------------------------------------------- /benchmarks/equation_of_state/eos_numba.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================================================== 3 | in-situ density, dynamic enthalpy and derivatives 4 | from Absolute Salinity and Conservative 5 | Temperature, using the computationally-efficient 48-term expression for 6 | density in terms of SA, CT and p (IOC et al., 2010). 
7 | ========================================================================== 8 | """ 9 | 10 | import numpy as np 11 | import numba as nb 12 | 13 | v01 = 9.998420897506056e2 14 | v02 = 2.839940833161907e0 15 | v03 = -3.147759265588511e-2 16 | v04 = 1.181805545074306e-3 17 | v05 = -6.698001071123802e0 18 | v06 = -2.986498947203215e-2 19 | v07 = 2.327859407479162e-4 20 | v08 = -3.988822378968490e-2 21 | v09 = 5.095422573880500e-4 22 | v10 = -1.426984671633621e-5 23 | v11 = 1.645039373682922e-7 24 | v12 = -2.233269627352527e-2 25 | v13 = -3.436090079851880e-4 26 | v14 = 3.726050720345733e-6 27 | v15 = -1.806789763745328e-4 28 | v16 = 6.876837219536232e-7 29 | v17 = -3.087032500374211e-7 30 | v18 = -1.988366587925593e-8 31 | v19 = -1.061519070296458e-11 32 | v20 = 1.550932729220080e-10 33 | v21 = 1.0e0 34 | v22 = 2.775927747785646e-3 35 | v23 = -2.349607444135925e-5 36 | v24 = 1.119513357486743e-6 37 | v25 = 6.743689325042773e-10 38 | v26 = -7.521448093615448e-3 39 | v27 = -2.764306979894411e-5 40 | v28 = 1.262937315098546e-7 41 | v29 = 9.527875081696435e-10 42 | v30 = -1.811147201949891e-11 43 | v31 = -3.303308871386421e-5 44 | v32 = 3.801564588876298e-7 45 | v33 = -7.672876869259043e-9 46 | v34 = -4.634182341116144e-11 47 | v35 = 2.681097235569143e-12 48 | v36 = 5.419326551148740e-6 49 | v37 = -2.742185394906099e-5 50 | v38 = -3.212746477974189e-7 51 | v39 = 3.191413910561627e-9 52 | v40 = -1.931012931541776e-12 53 | v41 = -1.105097577149576e-7 54 | v42 = 6.211426728363857e-10 55 | v43 = -1.119011592875110e-10 56 | v44 = -1.941660213148725e-11 57 | v45 = -1.864826425365600e-14 58 | v46 = 1.119522344879478e-14 59 | v47 = -1.200507748551599e-15 60 | v48 = 6.057902487546866e-17 61 | rho0 = 1024.0 62 | 63 | 64 | @nb.jit(nopython=True, fastmath=True) 65 | def gsw_dHdT(sa, ct, p): 66 | """ 67 | d/dT of dynamic enthalpy, analytical derivative 68 | 69 | sa : Absolute Salinity [g/kg] 70 | ct : Conservative Temperature [deg C] 71 | p : sea pressure [dbar] 72 | """ 73 | t1 = 
v45 * ct 74 | t2 = 0.2e1 * t1 75 | t3 = v46 * sa 76 | t4 = 0.5 * v12 77 | t5 = v14 * ct 78 | t7 = ct * (v13 + t5) 79 | t8 = 0.5 * t7 80 | t11 = sa * (v15 + v16 * ct) 81 | t12 = 0.5 * t11 82 | t13 = t4 + t8 + t12 83 | t15 = v19 * ct 84 | t19 = v17 + ct * (v18 + t15) + v20 * sa 85 | t20 = 1.0 / t19 86 | t24 = v47 + v48 * ct 87 | t25 = 0.5 * v13 88 | t26 = 1.0 * t5 89 | t27 = sa * v16 90 | t28 = 0.5 * t27 91 | t29 = t25 + t26 + t28 92 | t33 = t24 * t13 93 | t34 = t19 ** 2 94 | t35 = 1.0 / t34 95 | t37 = v18 + 2.0 * t15 96 | t38 = t35 * t37 97 | t48 = ct * (v44 + t1 + t3) 98 | t57 = v40 * ct 99 | t59 = ct * (v39 + t57) 100 | t64 = t13 ** 2 101 | t68 = t20 * t29 102 | t71 = t24 * t64 103 | t74 = v04 * ct 104 | t76 = ct * (v03 + t74) 105 | t79 = v07 * ct 106 | t82 = np.sqrt(sa) 107 | t83 = v11 * ct 108 | t85 = ct * (v10 + t83) 109 | t92 = ( 110 | v01 111 | + ct * (v02 + t76) 112 | + sa * (v05 + ct * (v06 + t79) + t82 * (v08 + ct * (v09 + t85))) 113 | ) 114 | t93 = v48 * t92 115 | t105 = ( 116 | v02 117 | + t76 118 | + ct * (v03 + 2.0 * t74) 119 | + sa * (v06 + 2.0 * t79 + t82 * (v09 + t85 + ct * (v10 + 2.0 * t83))) 120 | ) 121 | t106 = t24 * t105 122 | t107 = v44 + t2 + t3 123 | t110 = v43 + t48 124 | t117 = t24 * t92 125 | t120 = 4.0 * t71 * t20 - t117 - 2.0 * t110 * t13 126 | t123 = ( 127 | v38 128 | + t59 129 | + ct * (v39 + 2.0 * t57) 130 | + sa * v42 131 | + ( 132 | 4.0 * v48 * t64 * t20 133 | + 8.0 * t33 * t68 134 | - 4.0 * t71 * t38 135 | - t93 136 | - t106 137 | - 2.0 * t107 * t13 138 | - 2.0 * t110 * t29 139 | ) 140 | * t20 141 | - t120 * t35 * t37 142 | ) 143 | t128 = t19 * p 144 | t130 = p * (1.0 * v12 + 1.0 * t7 + 1.0 * t11 + t128) 145 | t131 = 1.0 / t92 146 | t133 = 1.0 + t130 * t131 147 | t134 = np.log(t133) 148 | t143 = v37 + ct * (v38 + t59) + sa * (v41 + v42 * ct) + t120 * t20 149 | t152 = t37 * p 150 | t156 = t92 ** 2 151 | t165 = v25 * ct 152 | t167 = ct * (v24 + t165) 153 | t169 = ct * (v23 + t167) 154 | t175 = v30 * ct 155 | t177 = ct * (v29 + t175) 
156 | t179 = ct * (v28 + t177) 157 | t185 = v35 * ct 158 | t187 = ct * (v34 + t185) 159 | t189 = ct * (v33 + t187) 160 | t199 = t13 * t20 161 | t217 = 2.0 * t117 * t199 - t110 * t92 162 | t234 = ( 163 | v21 164 | + ct * (v22 + t169) 165 | + sa * (v26 + ct * (v27 + t179) + v36 * sa + t82 * (v31 + ct * (v32 + t189))) 166 | + t217 * t20 167 | ) 168 | t241 = t64 - t92 * t19 169 | t242 = np.sqrt(t241) 170 | t243 = 1.0 / t242 171 | t244 = t4 + t8 + t12 - t242 172 | t245 = 1.0 / t244 173 | t247 = t4 + t8 + t12 + t242 + t128 174 | t248 = 1.0 / t247 175 | t249 = t242 * t245 * t248 176 | t252 = 1.0 + 2.0 * t128 * t249 177 | t253 = np.log(t252) 178 | t254 = t243 * t253 179 | t259 = t234 * t19 - t143 * t13 180 | t264 = t259 * t20 181 | t272 = 2.0 * t13 * t29 - t105 * t19 - t92 * t37 182 | t282 = t128 * t242 183 | t283 = t244 ** 2 184 | t287 = t243 * t272 / 2.0 185 | t292 = t247 ** 2 186 | t305 = ( 187 | 0.1e5 188 | * p 189 | * ( 190 | v44 191 | + t2 192 | + t3 193 | - 2.0 * v48 * t13 * t20 194 | - 2.0 * t24 * t29 * t20 195 | + 2.0 * t33 * t38 196 | + 0.5 * v48 * p 197 | ) 198 | * t20 199 | - 0.1e5 * p * (v43 + t48 - 2.0 * t33 * t20 + 0.5 * t24 * p) * t38 200 | + 0.5e4 * t123 * t20 * t134 201 | - 0.5e4 * t143 * t35 * t134 * t37 202 | + 0.5e4 203 | * t143 204 | * t20 205 | * (p * (1.0 * v13 + 2.0 * t5 + 1.0 * t27 + t152) * t131 - t130 / t156 * t105) 206 | / t133 207 | + 0.5e4 208 | * ( 209 | ( 210 | v22 211 | + t169 212 | + ct * (v23 + t167 + ct * (v24 + 2.0 * t165)) 213 | + sa 214 | * ( 215 | v27 216 | + t179 217 | + ct * (v28 + t177 + ct * (v29 + 2.0 * t175)) 218 | + t82 * (v32 + t189 + ct * (v33 + t187 + ct * (v34 + 2.0 * t185))) 219 | ) 220 | + ( 221 | 2.0 * t93 * t199 222 | + 2.0 * t106 * t199 223 | + 2.0 * t117 * t68 224 | - 2.0 * t117 * t13 * t35 * t37 225 | - t107 * t92 226 | - t110 * t105 227 | ) 228 | * t20 229 | - t217 * t35 * t37 230 | ) 231 | * t19 232 | + t234 * t37 233 | - t123 * t13 234 | - t143 * t29 235 | ) 236 | * t20 237 | * t254 238 | - 0.5e4 * t259 * t35 * 
t254 * t37 239 | - 0.25e4 * t264 / t242 / t241 * t253 * t272 240 | + 0.5e4 241 | * t264 242 | * t243 243 | * ( 244 | 2.0 * t152 * t249 245 | + t128 * t243 * t245 * t248 * t272 246 | - 2.0 * t282 / t283 * t248 * (t25 + t26 + t28 - t287) 247 | - 2.0 * t282 * t245 / t292 * (t25 + t26 + t28 + t287 + t152) 248 | ) 249 | / t252 250 | ) 251 | 252 | return t305 253 | 254 | 255 | @nb.jit(nopython=True, boundscheck=False, nogil=True, fastmath=True) 256 | def gsw_dHdT_vec(sa, ct, p, out): 257 | ix, iy, iz = sa.shape 258 | for i in range(ix): 259 | for j in range(iy): 260 | for k in range(iz): 261 | out[i, j, k] = gsw_dHdT(sa[i, j, k], ct[i, j, k], p[0, 0, k]) 262 | 263 | 264 | def run(sa, ct, p, device="cpu"): 265 | out = np.empty_like(sa) 266 | gsw_dHdT_vec(sa, ct, p, out) 267 | return out 268 | -------------------------------------------------------------------------------- /benchmarks/equation_of_state/eos_numpy.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================================================== 3 | in-situ density, dynamic enthalpy and derivatives 4 | from Absolute Salinity and Conservative 5 | Temperature, using the computationally-efficient 48-term expression for 6 | density in terms of SA, CT and p (IOC et al., 2010). 
7 | ========================================================================== 8 | """ 9 | 10 | import numpy as np 11 | 12 | v01 = 9.998420897506056e2 13 | v02 = 2.839940833161907e0 14 | v03 = -3.147759265588511e-2 15 | v04 = 1.181805545074306e-3 16 | v05 = -6.698001071123802e0 17 | v06 = -2.986498947203215e-2 18 | v07 = 2.327859407479162e-4 19 | v08 = -3.988822378968490e-2 20 | v09 = 5.095422573880500e-4 21 | v10 = -1.426984671633621e-5 22 | v11 = 1.645039373682922e-7 23 | v12 = -2.233269627352527e-2 24 | v13 = -3.436090079851880e-4 25 | v14 = 3.726050720345733e-6 26 | v15 = -1.806789763745328e-4 27 | v16 = 6.876837219536232e-7 28 | v17 = -3.087032500374211e-7 29 | v18 = -1.988366587925593e-8 30 | v19 = -1.061519070296458e-11 31 | v20 = 1.550932729220080e-10 32 | v21 = 1.0e0 33 | v22 = 2.775927747785646e-3 34 | v23 = -2.349607444135925e-5 35 | v24 = 1.119513357486743e-6 36 | v25 = 6.743689325042773e-10 37 | v26 = -7.521448093615448e-3 38 | v27 = -2.764306979894411e-5 39 | v28 = 1.262937315098546e-7 40 | v29 = 9.527875081696435e-10 41 | v30 = -1.811147201949891e-11 42 | v31 = -3.303308871386421e-5 43 | v32 = 3.801564588876298e-7 44 | v33 = -7.672876869259043e-9 45 | v34 = -4.634182341116144e-11 46 | v35 = 2.681097235569143e-12 47 | v36 = 5.419326551148740e-6 48 | v37 = -2.742185394906099e-5 49 | v38 = -3.212746477974189e-7 50 | v39 = 3.191413910561627e-9 51 | v40 = -1.931012931541776e-12 52 | v41 = -1.105097577149576e-7 53 | v42 = 6.211426728363857e-10 54 | v43 = -1.119011592875110e-10 55 | v44 = -1.941660213148725e-11 56 | v45 = -1.864826425365600e-14 57 | v46 = 1.119522344879478e-14 58 | v47 = -1.200507748551599e-15 59 | v48 = 6.057902487546866e-17 60 | rho0 = 1024.0 61 | 62 | 63 | def gsw_dHdT(sa, ct, p): 64 | """ 65 | d/dT of dynamic enthalpy, analytical derivative 66 | 67 | sa : Absolute Salinity [g/kg] 68 | ct : Conservative Temperature [deg C] 69 | p : sea pressure [dbar] 70 | """ 71 | t1 = v45 * ct 72 | t2 = 0.2e1 * t1 73 | t3 = v46 * sa 74 | t4 = 0.5 * v12 
75 | t5 = v14 * ct 76 | t7 = ct * (v13 + t5) 77 | t8 = 0.5 * t7 78 | t11 = sa * (v15 + v16 * ct) 79 | t12 = 0.5 * t11 80 | t13 = t4 + t8 + t12 81 | t15 = v19 * ct 82 | t19 = v17 + ct * (v18 + t15) + v20 * sa 83 | t20 = 1.0 / t19 84 | t24 = v47 + v48 * ct 85 | t25 = 0.5 * v13 86 | t26 = 1.0 * t5 87 | t27 = sa * v16 88 | t28 = 0.5 * t27 89 | t29 = t25 + t26 + t28 90 | t33 = t24 * t13 91 | t34 = t19 ** 2 92 | t35 = 1.0 / t34 93 | t37 = v18 + 2.0 * t15 94 | t38 = t35 * t37 95 | t48 = ct * (v44 + t1 + t3) 96 | t57 = v40 * ct 97 | t59 = ct * (v39 + t57) 98 | t64 = t13 ** 2 99 | t68 = t20 * t29 100 | t71 = t24 * t64 101 | t74 = v04 * ct 102 | t76 = ct * (v03 + t74) 103 | t79 = v07 * ct 104 | t82 = np.sqrt(sa) 105 | t83 = v11 * ct 106 | t85 = ct * (v10 + t83) 107 | t92 = ( 108 | v01 109 | + ct * (v02 + t76) 110 | + sa * (v05 + ct * (v06 + t79) + t82 * (v08 + ct * (v09 + t85))) 111 | ) 112 | t93 = v48 * t92 113 | t105 = ( 114 | v02 115 | + t76 116 | + ct * (v03 + 2.0 * t74) 117 | + sa * (v06 + 2.0 * t79 + t82 * (v09 + t85 + ct * (v10 + 2.0 * t83))) 118 | ) 119 | t106 = t24 * t105 120 | t107 = v44 + t2 + t3 121 | t110 = v43 + t48 122 | t117 = t24 * t92 123 | t120 = 4.0 * t71 * t20 - t117 - 2.0 * t110 * t13 124 | t123 = ( 125 | v38 126 | + t59 127 | + ct * (v39 + 2.0 * t57) 128 | + sa * v42 129 | + ( 130 | 4.0 * v48 * t64 * t20 131 | + 8.0 * t33 * t68 132 | - 4.0 * t71 * t38 133 | - t93 134 | - t106 135 | - 2.0 * t107 * t13 136 | - 2.0 * t110 * t29 137 | ) 138 | * t20 139 | - t120 * t35 * t37 140 | ) 141 | t128 = t19 * p 142 | t130 = p * (1.0 * v12 + 1.0 * t7 + 1.0 * t11 + t128) 143 | t131 = 1.0 / t92 144 | t133 = 1.0 + t130 * t131 145 | t134 = np.log(t133) 146 | t143 = v37 + ct * (v38 + t59) + sa * (v41 + v42 * ct) + t120 * t20 147 | t152 = t37 * p 148 | t156 = t92 ** 2 149 | t165 = v25 * ct 150 | t167 = ct * (v24 + t165) 151 | t169 = ct * (v23 + t167) 152 | t175 = v30 * ct 153 | t177 = ct * (v29 + t175) 154 | t179 = ct * (v28 + t177) 155 | t185 = v35 * ct 156 | t187 = ct * 
(v34 + t185) 157 | t189 = ct * (v33 + t187) 158 | t199 = t13 * t20 159 | t217 = 2.0 * t117 * t199 - t110 * t92 160 | t234 = ( 161 | v21 162 | + ct * (v22 + t169) 163 | + sa * (v26 + ct * (v27 + t179) + v36 * sa + t82 * (v31 + ct * (v32 + t189))) 164 | + t217 * t20 165 | ) 166 | t241 = t64 - t92 * t19 167 | t242 = np.sqrt(t241) 168 | t243 = 1.0 / t242 169 | t244 = t4 + t8 + t12 - t242 170 | t245 = 1.0 / t244 171 | t247 = t4 + t8 + t12 + t242 + t128 172 | t248 = 1.0 / t247 173 | t249 = t242 * t245 * t248 174 | t252 = 1.0 + 2.0 * t128 * t249 175 | t253 = np.log(t252) 176 | t254 = t243 * t253 177 | t259 = t234 * t19 - t143 * t13 178 | t264 = t259 * t20 179 | t272 = 2.0 * t13 * t29 - t105 * t19 - t92 * t37 180 | t282 = t128 * t242 181 | t283 = t244 ** 2 182 | t287 = t243 * t272 / 2.0 183 | t292 = t247 ** 2 184 | t305 = ( 185 | 0.1e5 186 | * p 187 | * ( 188 | v44 189 | + t2 190 | + t3 191 | - 2.0 * v48 * t13 * t20 192 | - 2.0 * t24 * t29 * t20 193 | + 2.0 * t33 * t38 194 | + 0.5 * v48 * p 195 | ) 196 | * t20 197 | - 0.1e5 * p * (v43 + t48 - 2.0 * t33 * t20 + 0.5 * t24 * p) * t38 198 | + 0.5e4 * t123 * t20 * t134 199 | - 0.5e4 * t143 * t35 * t134 * t37 200 | + 0.5e4 201 | * t143 202 | * t20 203 | * (p * (1.0 * v13 + 2.0 * t5 + 1.0 * t27 + t152) * t131 - t130 / t156 * t105) 204 | / t133 205 | + 0.5e4 206 | * ( 207 | ( 208 | v22 209 | + t169 210 | + ct * (v23 + t167 + ct * (v24 + 2.0 * t165)) 211 | + sa 212 | * ( 213 | v27 214 | + t179 215 | + ct * (v28 + t177 + ct * (v29 + 2.0 * t175)) 216 | + t82 * (v32 + t189 + ct * (v33 + t187 + ct * (v34 + 2.0 * t185))) 217 | ) 218 | + ( 219 | 2.0 * t93 * t199 220 | + 2.0 * t106 * t199 221 | + 2.0 * t117 * t68 222 | - 2.0 * t117 * t13 * t35 * t37 223 | - t107 * t92 224 | - t110 * t105 225 | ) 226 | * t20 227 | - t217 * t35 * t37 228 | ) 229 | * t19 230 | + t234 * t37 231 | - t123 * t13 232 | - t143 * t29 233 | ) 234 | * t20 235 | * t254 236 | - 0.5e4 * t259 * t35 * t254 * t37 237 | - 0.25e4 * t264 / t242 / t241 * t253 * t272 238 | + 
0.5e4 239 | * t264 240 | * t243 241 | * ( 242 | 2.0 * t152 * t249 243 | + t128 * t243 * t245 * t248 * t272 244 | - 2.0 * t282 / t283 * t248 * (t25 + t26 + t28 - t287) 245 | - 2.0 * t282 * t245 / t292 * (t25 + t26 + t28 + t287 + t152) 246 | ) 247 | / t252 248 | ) 249 | 250 | return t305 251 | 252 | 253 | def run(sa, ct, p, device="cpu"): 254 | return gsw_dHdT(sa, ct, p) 255 | -------------------------------------------------------------------------------- /benchmarks/equation_of_state/eos_pytorch.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================================================== 3 | in-situ density, dynamic enthalpy and derivatives 4 | from Absolute Salinity and Conservative 5 | Temperature, using the computationally-efficient 48-term expression for 6 | density in terms of SA, CT and p (IOC et al., 2010). 7 | ========================================================================== 8 | """ 9 | 10 | import torch 11 | 12 | 13 | @torch.jit.script 14 | def gsw_dHdT(sa, ct, p): 15 | """ 16 | d/dT of dynamic enthalpy, analytical derivative 17 | 18 | sa : Absolute Salinity [g/kg] 19 | ct : Conservative Temperature [deg C] 20 | p : sea pressure [dbar] 21 | """ 22 | v01 = 9.998420897506056e2 23 | v02 = 2.839940833161907e0 24 | v03 = -3.147759265588511e-2 25 | v04 = 1.181805545074306e-3 26 | v05 = -6.698001071123802e0 27 | v06 = -2.986498947203215e-2 28 | v07 = 2.327859407479162e-4 29 | v08 = -3.988822378968490e-2 30 | v09 = 5.095422573880500e-4 31 | v10 = -1.426984671633621e-5 32 | v11 = 1.645039373682922e-7 33 | v12 = -2.233269627352527e-2 34 | v13 = -3.436090079851880e-4 35 | v14 = 3.726050720345733e-6 36 | v15 = -1.806789763745328e-4 37 | v16 = 6.876837219536232e-7 38 | v17 = -3.087032500374211e-7 39 | v18 = -1.988366587925593e-8 40 | v19 = -1.061519070296458e-11 41 | v20 = 1.550932729220080e-10 42 | v21 = 1.0e0 43 | v22 = 2.775927747785646e-3 44 | v23 = -2.349607444135925e-5 45 | v24 
= 1.119513357486743e-6 46 | v25 = 6.743689325042773e-10 47 | v26 = -7.521448093615448e-3 48 | v27 = -2.764306979894411e-5 49 | v28 = 1.262937315098546e-7 50 | v29 = 9.527875081696435e-10 51 | v30 = -1.811147201949891e-11 52 | v31 = -3.303308871386421e-5 53 | v32 = 3.801564588876298e-7 54 | v33 = -7.672876869259043e-9 55 | v34 = -4.634182341116144e-11 56 | v35 = 2.681097235569143e-12 57 | v36 = 5.419326551148740e-6 58 | v37 = -2.742185394906099e-5 59 | v38 = -3.212746477974189e-7 60 | v39 = 3.191413910561627e-9 61 | v40 = -1.931012931541776e-12 62 | v41 = -1.105097577149576e-7 63 | v42 = 6.211426728363857e-10 64 | v43 = -1.119011592875110e-10 65 | v44 = -1.941660213148725e-11 66 | v45 = -1.864826425365600e-14 67 | v46 = 1.119522344879478e-14 68 | v47 = -1.200507748551599e-15 69 | v48 = 6.057902487546866e-17 70 | 71 | t1 = v45 * ct 72 | t2 = 0.2e1 * t1 73 | t3 = v46 * sa 74 | t4 = 0.5 * v12 75 | t5 = v14 * ct 76 | t7 = ct * (v13 + t5) 77 | t8 = 0.5 * t7 78 | t11 = sa * (v15 + v16 * ct) 79 | t12 = 0.5 * t11 80 | t13 = t4 + t8 + t12 81 | t15 = v19 * ct 82 | t19 = v17 + ct * (v18 + t15) + v20 * sa 83 | t20 = 1.0 / t19 84 | t24 = v47 + v48 * ct 85 | t25 = 0.5 * v13 86 | t26 = 1.0 * t5 87 | t27 = sa * v16 88 | t28 = 0.5 * t27 89 | t29 = t25 + t26 + t28 90 | t33 = t24 * t13 91 | t34 = t19 ** 2 92 | t35 = 1.0 / t34 93 | t37 = v18 + 2.0 * t15 94 | t38 = t35 * t37 95 | t48 = ct * (v44 + t1 + t3) 96 | t57 = v40 * ct 97 | t59 = ct * (v39 + t57) 98 | t64 = t13 ** 2 99 | t68 = t20 * t29 100 | t71 = t24 * t64 101 | t74 = v04 * ct 102 | t76 = ct * (v03 + t74) 103 | t79 = v07 * ct 104 | t82 = torch.sqrt(sa) 105 | t83 = v11 * ct 106 | t85 = ct * (v10 + t83) 107 | t92 = ( 108 | v01 109 | + ct * (v02 + t76) 110 | + sa * (v05 + ct * (v06 + t79) + t82 * (v08 + ct * (v09 + t85))) 111 | ) 112 | t93 = v48 * t92 113 | t105 = ( 114 | v02 115 | + t76 116 | + ct * (v03 + 2.0 * t74) 117 | + sa * (v06 + 2.0 * t79 + t82 * (v09 + t85 + ct * (v10 + 2.0 * t83))) 118 | ) 119 | t106 = t24 * t105 120 | 
t107 = v44 + t2 + t3 121 | t110 = v43 + t48 122 | t117 = t24 * t92 123 | t120 = 4.0 * t71 * t20 - t117 - 2.0 * t110 * t13 124 | t123 = ( 125 | v38 126 | + t59 127 | + ct * (v39 + 2.0 * t57) 128 | + sa * v42 129 | + ( 130 | 4.0 * v48 * t64 * t20 131 | + 8.0 * t33 * t68 132 | - 4.0 * t71 * t38 133 | - t93 134 | - t106 135 | - 2.0 * t107 * t13 136 | - 2.0 * t110 * t29 137 | ) 138 | * t20 139 | - t120 * t35 * t37 140 | ) 141 | t128 = t19 * p 142 | t130 = p * (1.0 * v12 + 1.0 * t7 + 1.0 * t11 + t128) 143 | t131 = 1.0 / t92 144 | t133 = 1.0 + t130 * t131 145 | t134 = torch.log(t133) 146 | t143 = v37 + ct * (v38 + t59) + sa * (v41 + v42 * ct) + t120 * t20 147 | t152 = t37 * p 148 | t156 = t92 ** 2 149 | t165 = v25 * ct 150 | t167 = ct * (v24 + t165) 151 | t169 = ct * (v23 + t167) 152 | t175 = v30 * ct 153 | t177 = ct * (v29 + t175) 154 | t179 = ct * (v28 + t177) 155 | t185 = v35 * ct 156 | t187 = ct * (v34 + t185) 157 | t189 = ct * (v33 + t187) 158 | t199 = t13 * t20 159 | t217 = 2.0 * t117 * t199 - t110 * t92 160 | t234 = ( 161 | v21 162 | + ct * (v22 + t169) 163 | + sa * (v26 + ct * (v27 + t179) + v36 * sa + t82 * (v31 + ct * (v32 + t189))) 164 | + t217 * t20 165 | ) 166 | t241 = t64 - t92 * t19 167 | t242 = torch.sqrt(t241) 168 | t243 = 1.0 / t242 169 | t244 = t4 + t8 + t12 - t242 170 | t245 = 1.0 / t244 171 | t247 = t4 + t8 + t12 + t242 + t128 172 | t248 = 1.0 / t247 173 | t249 = t242 * t245 * t248 174 | t252 = 1.0 + 2.0 * t128 * t249 175 | t253 = torch.log(t252) 176 | t254 = t243 * t253 177 | t259 = t234 * t19 - t143 * t13 178 | t264 = t259 * t20 179 | t272 = 2.0 * t13 * t29 - t105 * t19 - t92 * t37 180 | t282 = t128 * t242 181 | t283 = t244 ** 2 182 | t287 = t243 * t272 / 2.0 183 | t292 = t247 ** 2 184 | t305 = ( 185 | 0.1e5 186 | * p 187 | * ( 188 | v44 189 | + t2 190 | + t3 191 | - 2.0 * v48 * t13 * t20 192 | - 2.0 * t24 * t29 * t20 193 | + 2.0 * t33 * t38 194 | + 0.5 * v48 * p 195 | ) 196 | * t20 197 | - 0.1e5 * p * (v43 + t48 - 2.0 * t33 * t20 + 0.5 * t24 * p) * 
t38 198 | + 0.5e4 * t123 * t20 * t134 199 | - 0.5e4 * t143 * t35 * t134 * t37 200 | + 0.5e4 201 | * t143 202 | * t20 203 | * (p * (1.0 * v13 + 2.0 * t5 + 1.0 * t27 + t152) * t131 - t130 / t156 * t105) 204 | / t133 205 | + 0.5e4 206 | * ( 207 | ( 208 | v22 209 | + t169 210 | + ct * (v23 + t167 + ct * (v24 + 2.0 * t165)) 211 | + sa 212 | * ( 213 | v27 214 | + t179 215 | + ct * (v28 + t177 + ct * (v29 + 2.0 * t175)) 216 | + t82 * (v32 + t189 + ct * (v33 + t187 + ct * (v34 + 2.0 * t185))) 217 | ) 218 | + ( 219 | 2.0 * t93 * t199 220 | + 2.0 * t106 * t199 221 | + 2.0 * t117 * t68 222 | - 2.0 * t117 * t13 * t35 * t37 223 | - t107 * t92 224 | - t110 * t105 225 | ) 226 | * t20 227 | - t217 * t35 * t37 228 | ) 229 | * t19 230 | + t234 * t37 231 | - t123 * t13 232 | - t143 * t29 233 | ) 234 | * t20 235 | * t254 236 | - 0.5e4 * t259 * t35 * t254 * t37 237 | - 0.25e4 * t264 / t242 / t241 * t253 * t272 238 | + 0.5e4 239 | * t264 240 | * t243 241 | * ( 242 | 2.0 * t152 * t249 243 | + t128 * t243 * t245 * t248 * t272 244 | - 2.0 * t282 / t283 * t248 * (t25 + t26 + t28 - t287) 245 | - 2.0 * t282 * t245 / t292 * (t25 + t26 + t28 + t287 + t152) 246 | ) 247 | / t252 248 | ) 249 | 250 | return t305 251 | 252 | 253 | def prepare_inputs(sa, ct, p, device): 254 | out = [ 255 | torch.as_tensor(a, device="cuda" if device == "gpu" else "cpu") 256 | for a in (sa, ct, p) 257 | ] 258 | if device == "gpu": 259 | torch.cuda.synchronize() 260 | return out 261 | 262 | 263 | def run(sa, ct, p, device="cpu"): 264 | with torch.no_grad(): 265 | out = gsw_dHdT(sa, ct, p) 266 | if device == "gpu": 267 | torch.cuda.synchronize() 268 | return out 269 | -------------------------------------------------------------------------------- /benchmarks/equation_of_state/eos_taichi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================================================== 3 | in-situ density, dynamic enthalpy and derivatives 4 | from 
Absolute Salinity and Conservative 5 | Temperature, using the computationally-efficient 48-term expression for 6 | density in terms of SA, CT and p (IOC et al., 2010). 7 | ========================================================================== 8 | """ 9 | 10 | from functools import lru_cache 11 | import taichi as ti 12 | 13 | v01 = 9.998420897506056e2 14 | v02 = 2.839940833161907e0 15 | v03 = -3.147759265588511e-2 16 | v04 = 1.181805545074306e-3 17 | v05 = -6.698001071123802e0 18 | v06 = -2.986498947203215e-2 19 | v07 = 2.327859407479162e-4 20 | v08 = -3.988822378968490e-2 21 | v09 = 5.095422573880500e-4 22 | v10 = -1.426984671633621e-5 23 | v11 = 1.645039373682922e-7 24 | v12 = -2.233269627352527e-2 25 | v13 = -3.436090079851880e-4 26 | v14 = 3.726050720345733e-6 27 | v15 = -1.806789763745328e-4 28 | v16 = 6.876837219536232e-7 29 | v17 = -3.087032500374211e-7 30 | v18 = -1.988366587925593e-8 31 | v19 = -1.061519070296458e-11 32 | v20 = 1.550932729220080e-10 33 | v21 = 1.0e0 34 | v22 = 2.775927747785646e-3 35 | v23 = -2.349607444135925e-5 36 | v24 = 1.119513357486743e-6 37 | v25 = 6.743689325042773e-10 38 | v26 = -7.521448093615448e-3 39 | v27 = -2.764306979894411e-5 40 | v28 = 1.262937315098546e-7 41 | v29 = 9.527875081696435e-10 42 | v30 = -1.811147201949891e-11 43 | v31 = -3.303308871386421e-5 44 | v32 = 3.801564588876298e-7 45 | v33 = -7.672876869259043e-9 46 | v34 = -4.634182341116144e-11 47 | v35 = 2.681097235569143e-12 48 | v36 = 5.419326551148740e-6 49 | v37 = -2.742185394906099e-5 50 | v38 = -3.212746477974189e-7 51 | v39 = 3.191413910561627e-9 52 | v40 = -1.931012931541776e-12 53 | v41 = -1.105097577149576e-7 54 | v42 = 6.211426728363857e-10 55 | v43 = -1.119011592875110e-10 56 | v44 = -1.941660213148725e-11 57 | v45 = -1.864826425365600e-14 58 | v46 = 1.119522344879478e-14 59 | v47 = -1.200507748551599e-15 60 | v48 = 6.057902487546866e-17 61 | rho0 = 1024.0 62 | 63 | 64 | @ti.kernel 65 | def gsw_dHdT(sa_arr: ti.template(), ct_arr: ti.template(), 
p_arr: ti.template(), out: ti.template()): 66 | """ 67 | d/dT of dynamic enthalpy, analytical derivative 68 | 69 | sa : Absolute Salinity [g/kg] 70 | ct : Conservative Temperature [deg C] 71 | p : sea pressure [dbar] 72 | """ 73 | for i, j, k in ti.ndrange(*out.shape): 74 | sa = sa_arr[i, j, k] 75 | ct = ct_arr[i, j, k] 76 | p = p_arr[0, 0, k] 77 | 78 | t1 = v45 * ct 79 | t2 = 0.2e1 * t1 80 | t3 = v46 * sa 81 | t4 = 0.5 * v12 82 | t5 = v14 * ct 83 | t7 = ct * (v13 + t5) 84 | t8 = 0.5 * t7 85 | t11 = sa * (v15 + v16 * ct) 86 | t12 = 0.5 * t11 87 | t13 = t4 + t8 + t12 88 | t15 = v19 * ct 89 | t19 = v17 + ct * (v18 + t15) + v20 * sa 90 | t20 = 1.0 / t19 91 | t24 = v47 + v48 * ct 92 | t25 = 0.5 * v13 93 | t26 = 1.0 * t5 94 | t27 = sa * v16 95 | t28 = 0.5 * t27 96 | t29 = t25 + t26 + t28 97 | t33 = t24 * t13 98 | t34 = t19 ** 2 99 | t35 = 1.0 / t34 100 | t37 = v18 + 2.0 * t15 101 | t38 = t35 * t37 102 | t48 = ct * (v44 + t1 + t3) 103 | t57 = v40 * ct 104 | t59 = ct * (v39 + t57) 105 | t64 = t13 ** 2 106 | t68 = t20 * t29 107 | t71 = t24 * t64 108 | t74 = v04 * ct 109 | t76 = ct * (v03 + t74) 110 | t79 = v07 * ct 111 | t82 = ti.sqrt(sa) 112 | t83 = v11 * ct 113 | t85 = ct * (v10 + t83) 114 | t92 = ( 115 | v01 116 | + ct * (v02 + t76) 117 | + sa * (v05 + ct * (v06 + t79) + t82 * (v08 + ct * (v09 + t85))) 118 | ) 119 | t93 = v48 * t92 120 | t105 = ( 121 | v02 122 | + t76 123 | + ct * (v03 + 2.0 * t74) 124 | + sa * (v06 + 2.0 * t79 + t82 * (v09 + t85 + ct * (v10 + 2.0 * t83))) 125 | ) 126 | t106 = t24 * t105 127 | t107 = v44 + t2 + t3 128 | t110 = v43 + t48 129 | t117 = t24 * t92 130 | t120 = 4.0 * t71 * t20 - t117 - 2.0 * t110 * t13 131 | t123 = ( 132 | v38 133 | + t59 134 | + ct * (v39 + 2.0 * t57) 135 | + sa * v42 136 | + ( 137 | 4.0 * v48 * t64 * t20 138 | + 8.0 * t33 * t68 139 | - 4.0 * t71 * t38 140 | - t93 141 | - t106 142 | - 2.0 * t107 * t13 143 | - 2.0 * t110 * t29 144 | ) 145 | * t20 146 | - t120 * t35 * t37 147 | ) 148 | t128 = t19 * p 149 | t130 = p * (1.0 * 
v12 + 1.0 * t7 + 1.0 * t11 + t128) 150 | t131 = 1.0 / t92 151 | t133 = 1.0 + t130 * t131 152 | t134 = ti.log(t133) 153 | t143 = v37 + ct * (v38 + t59) + sa * (v41 + v42 * ct) + t120 * t20 154 | t152 = t37 * p 155 | t156 = t92 ** 2 156 | t165 = v25 * ct 157 | t167 = ct * (v24 + t165) 158 | t169 = ct * (v23 + t167) 159 | t175 = v30 * ct 160 | t177 = ct * (v29 + t175) 161 | t179 = ct * (v28 + t177) 162 | t185 = v35 * ct 163 | t187 = ct * (v34 + t185) 164 | t189 = ct * (v33 + t187) 165 | t199 = t13 * t20 166 | t217 = 2.0 * t117 * t199 - t110 * t92 167 | t234 = ( 168 | v21 169 | + ct * (v22 + t169) 170 | + sa * (v26 + ct * (v27 + t179) + v36 * sa + t82 * (v31 + ct * (v32 + t189))) 171 | + t217 * t20 172 | ) 173 | t241 = t64 - t92 * t19 174 | t242 = ti.sqrt(t241) 175 | t243 = 1.0 / t242 176 | t244 = t4 + t8 + t12 - t242 177 | t245 = 1.0 / t244 178 | t247 = t4 + t8 + t12 + t242 + t128 179 | t248 = 1.0 / t247 180 | t249 = t242 * t245 * t248 181 | t252 = 1.0 + 2.0 * t128 * t249 182 | t253 = ti.log(t252) 183 | t254 = t243 * t253 184 | t259 = t234 * t19 - t143 * t13 185 | t264 = t259 * t20 186 | t272 = 2.0 * t13 * t29 - t105 * t19 - t92 * t37 187 | t282 = t128 * t242 188 | t283 = t244 ** 2 189 | t287 = t243 * t272 / 2.0 190 | t292 = t247 ** 2 191 | t305 = ( 192 | 0.1e5 193 | * p 194 | * ( 195 | v44 196 | + t2 197 | + t3 198 | - 2.0 * v48 * t13 * t20 199 | - 2.0 * t24 * t29 * t20 200 | + 2.0 * t33 * t38 201 | + 0.5 * v48 * p 202 | ) 203 | * t20 204 | - 0.1e5 * p * (v43 + t48 - 2.0 * t33 * t20 + 0.5 * t24 * p) * t38 205 | + 0.5e4 * t123 * t20 * t134 206 | - 0.5e4 * t143 * t35 * t134 * t37 207 | + 0.5e4 208 | * t143 209 | * t20 210 | * (p * (1.0 * v13 + 2.0 * t5 + 1.0 * t27 + t152) * t131 - t130 / t156 * t105) 211 | / t133 212 | + 0.5e4 213 | * ( 214 | ( 215 | v22 216 | + t169 217 | + ct * (v23 + t167 + ct * (v24 + 2.0 * t165)) 218 | + sa 219 | * ( 220 | v27 221 | + t179 222 | + ct * (v28 + t177 + ct * (v29 + 2.0 * t175)) 223 | + t82 * (v32 + t189 + ct * (v33 + t187 + ct * (v34 
+ 2.0 * t185))) 224 | ) 225 | + ( 226 | 2.0 * t93 * t199 227 | + 2.0 * t106 * t199 228 | + 2.0 * t117 * t68 229 | - 2.0 * t117 * t13 * t35 * t37 230 | - t107 * t92 231 | - t110 * t105 232 | ) 233 | * t20 234 | - t217 * t35 * t37 235 | ) 236 | * t19 237 | + t234 * t37 238 | - t123 * t13 239 | - t143 * t29 240 | ) 241 | * t20 242 | * t254 243 | - 0.5e4 * t259 * t35 * t254 * t37 244 | - 0.25e4 * t264 / t242 / t241 * t253 * t272 245 | + 0.5e4 246 | * t264 247 | * t243 248 | * ( 249 | 2.0 * t152 * t249 250 | + t128 * t243 * t245 * t248 * t272 251 | - 2.0 * t282 / t283 * t248 * (t25 + t26 + t28 - t287) 252 | - 2.0 * t282 * t245 / t292 * (t25 + t26 + t28 + t287 + t152) 253 | ) 254 | / t252 255 | ) 256 | 257 | out[i, j, k] = t305 258 | 259 | 260 | @lru_cache 261 | def get_fields(sizes, num_scratch=0): 262 | fields = [] 263 | for size in sizes: 264 | fields.append(ti.field(dtype=ti.f64, shape=size)) 265 | 266 | for _ in range(num_scratch): 267 | fields.append(ti.field(dtype=ti.f64, shape=sizes[0])) 268 | 269 | return fields 270 | 271 | 272 | def prepare_inputs(sa, ct, p, device="cpu"): 273 | sa_field, ct_field, p_field, out_field = get_fields((sa.shape, ct.shape, p.shape), num_scratch=1) 274 | sa_field.from_numpy(sa) 275 | ct_field.from_numpy(ct) 276 | p_field.from_numpy(p) 277 | out_field.fill(0) 278 | return sa_field, ct_field, p_field, out_field 279 | 280 | 281 | def run(sa, ct, p, out, device="cpu"): 282 | gsw_dHdT(sa, ct, p, out) 283 | ti.sync() 284 | return out 285 | -------------------------------------------------------------------------------- /benchmarks/equation_of_state/eos_tensorflow.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================================================== 3 | in-situ density, dynamic enthalpy and derivatives 4 | from Absolute Salinity and Conservative 5 | Temperature, using the computationally-efficient 48-term expression for 6 | density in terms of SA, CT and p 
(IOC et al., 2010). 7 | ========================================================================== 8 | """ 9 | 10 | import tensorflow as tf 11 | 12 | v01 = 9.998420897506056e2 13 | v02 = 2.839940833161907e0 14 | v03 = -3.147759265588511e-2 15 | v04 = 1.181805545074306e-3 16 | v05 = -6.698001071123802e0 17 | v06 = -2.986498947203215e-2 18 | v07 = 2.327859407479162e-4 19 | v08 = -3.988822378968490e-2 20 | v09 = 5.095422573880500e-4 21 | v10 = -1.426984671633621e-5 22 | v11 = 1.645039373682922e-7 23 | v12 = -2.233269627352527e-2 24 | v13 = -3.436090079851880e-4 25 | v14 = 3.726050720345733e-6 26 | v15 = -1.806789763745328e-4 27 | v16 = 6.876837219536232e-7 28 | v17 = -3.087032500374211e-7 29 | v18 = -1.988366587925593e-8 30 | v19 = -1.061519070296458e-11 31 | v20 = 1.550932729220080e-10 32 | v21 = 1.0e0 33 | v22 = 2.775927747785646e-3 34 | v23 = -2.349607444135925e-5 35 | v24 = 1.119513357486743e-6 36 | v25 = 6.743689325042773e-10 37 | v26 = -7.521448093615448e-3 38 | v27 = -2.764306979894411e-5 39 | v28 = 1.262937315098546e-7 40 | v29 = 9.527875081696435e-10 41 | v30 = -1.811147201949891e-11 42 | v31 = -3.303308871386421e-5 43 | v32 = 3.801564588876298e-7 44 | v33 = -7.672876869259043e-9 45 | v34 = -4.634182341116144e-11 46 | v35 = 2.681097235569143e-12 47 | v36 = 5.419326551148740e-6 48 | v37 = -2.742185394906099e-5 49 | v38 = -3.212746477974189e-7 50 | v39 = 3.191413910561627e-9 51 | v40 = -1.931012931541776e-12 52 | v41 = -1.105097577149576e-7 53 | v42 = 6.211426728363857e-10 54 | v43 = -1.119011592875110e-10 55 | v44 = -1.941660213148725e-11 56 | v45 = -1.864826425365600e-14 57 | v46 = 1.119522344879478e-14 58 | v47 = -1.200507748551599e-15 59 | v48 = 6.057902487546866e-17 60 | rho0 = 1024.0 61 | 62 | 63 | def gsw_dHdT(sa, ct, p): 64 | """ 65 | d/dT of dynamic enthalpy, analytical derivative 66 | 67 | sa : Absolute Salinity [g/kg] 68 | ct : Conservative Temperature [deg C] 69 | p : sea pressure [dbar] 70 | """ 71 | t1 = v45 * ct 72 | t2 = 0.2e1 * t1 73 | t3 = v46 
* sa 74 | t4 = 0.5 * v12 75 | t5 = v14 * ct 76 | t7 = ct * (v13 + t5) 77 | t8 = 0.5 * t7 78 | t11 = sa * (v15 + v16 * ct) 79 | t12 = 0.5 * t11 80 | t13 = t4 + t8 + t12 81 | t15 = v19 * ct 82 | t19 = v17 + ct * (v18 + t15) + v20 * sa 83 | t20 = 1.0 / t19 84 | t24 = v47 + v48 * ct 85 | t25 = 0.5 * v13 86 | t26 = 1.0 * t5 87 | t27 = sa * v16 88 | t28 = 0.5 * t27 89 | t29 = t25 + t26 + t28 90 | t33 = t24 * t13 91 | t34 = t19 ** 2 92 | t35 = 1.0 / t34 93 | t37 = v18 + 2.0 * t15 94 | t38 = t35 * t37 95 | t48 = ct * (v44 + t1 + t3) 96 | t57 = v40 * ct 97 | t59 = ct * (v39 + t57) 98 | t64 = t13 ** 2 99 | t68 = t20 * t29 100 | t71 = t24 * t64 101 | t74 = v04 * ct 102 | t76 = ct * (v03 + t74) 103 | t79 = v07 * ct 104 | t82 = tf.math.sqrt(sa) 105 | t83 = v11 * ct 106 | t85 = ct * (v10 + t83) 107 | t92 = ( 108 | v01 109 | + ct * (v02 + t76) 110 | + sa * (v05 + ct * (v06 + t79) + t82 * (v08 + ct * (v09 + t85))) 111 | ) 112 | t93 = v48 * t92 113 | t105 = ( 114 | v02 115 | + t76 116 | + ct * (v03 + 2.0 * t74) 117 | + sa * (v06 + 2.0 * t79 + t82 * (v09 + t85 + ct * (v10 + 2.0 * t83))) 118 | ) 119 | t106 = t24 * t105 120 | t107 = v44 + t2 + t3 121 | t110 = v43 + t48 122 | t117 = t24 * t92 123 | t120 = 4.0 * t71 * t20 - t117 - 2.0 * t110 * t13 124 | t123 = ( 125 | v38 126 | + t59 127 | + ct * (v39 + 2.0 * t57) 128 | + sa * v42 129 | + ( 130 | 4.0 * v48 * t64 * t20 131 | + 8.0 * t33 * t68 132 | - 4.0 * t71 * t38 133 | - t93 134 | - t106 135 | - 2.0 * t107 * t13 136 | - 2.0 * t110 * t29 137 | ) 138 | * t20 139 | - t120 * t35 * t37 140 | ) 141 | t128 = t19 * p 142 | t130 = p * (1.0 * v12 + 1.0 * t7 + 1.0 * t11 + t128) 143 | t131 = 1.0 / t92 144 | t133 = 1.0 + t130 * t131 145 | t134 = tf.math.log(t133) 146 | t143 = v37 + ct * (v38 + t59) + sa * (v41 + v42 * ct) + t120 * t20 147 | t152 = t37 * p 148 | t156 = t92 ** 2 149 | t165 = v25 * ct 150 | t167 = ct * (v24 + t165) 151 | t169 = ct * (v23 + t167) 152 | t175 = v30 * ct 153 | t177 = ct * (v29 + t175) 154 | t179 = ct * (v28 + t177) 155 | 
t185 = v35 * ct 156 | t187 = ct * (v34 + t185) 157 | t189 = ct * (v33 + t187) 158 | t199 = t13 * t20 159 | t217 = 2.0 * t117 * t199 - t110 * t92 160 | t234 = ( 161 | v21 162 | + ct * (v22 + t169) 163 | + sa * (v26 + ct * (v27 + t179) + v36 * sa + t82 * (v31 + ct * (v32 + t189))) 164 | + t217 * t20 165 | ) 166 | t241 = t64 - t92 * t19 167 | t242 = tf.math.sqrt(t241) 168 | t243 = 1.0 / t242 169 | t244 = t4 + t8 + t12 - t242 170 | t245 = 1.0 / t244 171 | t247 = t4 + t8 + t12 + t242 + t128 172 | t248 = 1.0 / t247 173 | t249 = t242 * t245 * t248 174 | t252 = 1.0 + 2.0 * t128 * t249 175 | t253 = tf.math.log(t252) 176 | t254 = t243 * t253 177 | t259 = t234 * t19 - t143 * t13 178 | t264 = t259 * t20 179 | t272 = 2.0 * t13 * t29 - t105 * t19 - t92 * t37 180 | t282 = t128 * t242 181 | t283 = t244 ** 2 182 | t287 = t243 * t272 / 2.0 183 | t292 = t247 ** 2 184 | t305 = ( 185 | 0.1e5 186 | * p 187 | * ( 188 | v44 189 | + t2 190 | + t3 191 | - 2.0 * v48 * t13 * t20 192 | - 2.0 * t24 * t29 * t20 193 | + 2.0 * t33 * t38 194 | + 0.5 * v48 * p 195 | ) 196 | * t20 197 | - 0.1e5 * p * (v43 + t48 - 2.0 * t33 * t20 + 0.5 * t24 * p) * t38 198 | + 0.5e4 * t123 * t20 * t134 199 | - 0.5e4 * t143 * t35 * t134 * t37 200 | + 0.5e4 201 | * t143 202 | * t20 203 | * (p * (1.0 * v13 + 2.0 * t5 + 1.0 * t27 + t152) * t131 - t130 / t156 * t105) 204 | / t133 205 | + 0.5e4 206 | * ( 207 | ( 208 | v22 209 | + t169 210 | + ct * (v23 + t167 + ct * (v24 + 2.0 * t165)) 211 | + sa 212 | * ( 213 | v27 214 | + t179 215 | + ct * (v28 + t177 + ct * (v29 + 2.0 * t175)) 216 | + t82 * (v32 + t189 + ct * (v33 + t187 + ct * (v34 + 2.0 * t185))) 217 | ) 218 | + ( 219 | 2.0 * t93 * t199 220 | + 2.0 * t106 * t199 221 | + 2.0 * t117 * t68 222 | - 2.0 * t117 * t13 * t35 * t37 223 | - t107 * t92 224 | - t110 * t105 225 | ) 226 | * t20 227 | - t217 * t35 * t37 228 | ) 229 | * t19 230 | + t234 * t37 231 | - t123 * t13 232 | - t143 * t29 233 | ) 234 | * t20 235 | * t254 236 | - 0.5e4 * t259 * t35 * t254 * t37 237 | - 0.25e4 * 
t264 / t242 / t241 * t253 * t272 238 | + 0.5e4 239 | * t264 240 | * t243 241 | * ( 242 | 2.0 * t152 * t249 243 | + t128 * t243 * t245 * t248 * t272 244 | - 2.0 * t282 / t283 * t248 * (t25 + t26 + t28 - t287) 245 | - 2.0 * t282 * t245 / t292 * (t25 + t26 + t28 + t287 + t152) 246 | ) 247 | / t252 248 | ) 249 | 250 | return t305 251 | 252 | 253 | gsw_dHdT_tf = tf.function(gsw_dHdT, experimental_compile=True) 254 | 255 | 256 | def prepare_inputs(sa, ct, p, device): 257 | return [tf.convert_to_tensor(a) for a in (sa, ct, p)] 258 | 259 | 260 | def run(sa, ct, p, device="cpu"): 261 | return gsw_dHdT_tf(sa, ct, p) 262 | -------------------------------------------------------------------------------- /benchmarks/isoneutral_mixing/README.md: -------------------------------------------------------------------------------- 1 | # Isoneutral mixing benchmark 2 | 3 | This routine computes the mixing tensors that we need to simulate mixing in the ocean. 4 | 5 | In the ocean, mixing takes place along surfaces of constant neutral density. At every 6 | time step, we need to figure out where these surfaces lie in 3D space. This is usually 7 | the single most costly operation in an ocean model. 8 | 9 | Numerically, this routine consists of many finite difference derivatives and grid 10 | interpolations. There are many horizontal data dependencies and some arrays have 11 | up to 5 dimensions, but still only elementary math, and everything is vectorized. 
# --------------------------------------------------------------------------------
# /benchmarks/isoneutral_mixing/__init__.py:
# --------------------------------------------------------------------------------
import math
import importlib
import functools


def generate_inputs(size):
    """Create the seeded random input arrays for the isoneutral mixing benchmark.

    The grid shape is derived from ``size`` as
    ``(ceil(2 * s), ceil(2 * s), ceil(0.25 * s))`` with ``s = size ** (1 / 3)``.

    Returns a 23-tuple of NumPy arrays, in the positional order in which
    ``get_callable`` forwards them to a backend's ``prepare_inputs`` / ``run``:
    four 3-D masks, eleven 1-D grid vectors interleaved with the two 4-D
    tracer fields, four 3-D diffusivity arrays and four zero-filled 5-D
    ``Ai_*`` arrays.
    """
    # numpy is only needed by this function, so it is imported locally.
    import numpy as np

    # Fixed seed: every call yields the same arrays for a given ``size``.
    # NOTE: this reseeds numpy's *global* legacy RNG as a side effect.
    np.random.seed(17)

    shape = (
        math.ceil(2 * size ** (1 / 3)),
        math.ceil(2 * size ** (1 / 3)),
        math.ceil(0.25 * size ** (1 / 3)),
    )

    # The order of the random draws below determines the generated values;
    # do not reorder these statements.

    # masks: roughly 80 % ones, stored as float64
    maskT, maskU, maskV, maskW = (
        (np.random.rand(*shape) < 0.8).astype("float64") for _ in range(4)
    )

    # 1d arrays, one entry per cell along the matching axis
    dxt, dxu = (np.random.randn(shape[0]) for _ in range(2))
    dyt, dyu = (np.random.randn(shape[1]) for _ in range(2))
    dzt, dzw, zt = (np.random.randn(shape[2]) for _ in range(3))
    cost, cosu = (np.random.randn(shape[1]) for _ in range(2))

    # 3d arrays
    K_iso, K_11, K_22, K_33 = (np.random.randn(*shape) for _ in range(4))

    # 4d arrays (trailing axis of length 3)
    salt, temp = (np.random.randn(*shape, 3) for _ in range(2))

    # 5d arrays, zero-initialised (trailing 2 x 2 axes)
    Ai_ez, Ai_nz, Ai_bx, Ai_by = (np.zeros((*shape, 2, 2)) for _ in range(4))

    return (
        maskT,
        maskU,
        maskV,
        maskW,
        dxt,
        dxu,
        dyt,
        dyu,
        dzt,
        dzw,
        cost,
        cosu,
        salt,
        temp,
        zt,
        K_iso,
        K_11,
        K_22,
        K_33,
        Ai_ez,
        Ai_nz,
        Ai_bx,
        Ai_by,
    )


def try_import(backend):
    """Import the sibling module ``isoneutral_<backend>`` of this package.

    Returns the imported module, or ``None`` when the import raises
    ``ImportError`` (for example because the backend library is missing).
    """
    try:
        return importlib.import_module(f".isoneutral_{backend}", __name__)
    except ImportError:
        return None


def get_callable(backend, size, device="cpu"):
    """Return a zero-argument callable that runs the benchmark on ``backend``.

    Inputs are generated once with ``generate_inputs(size)``. If the backend
    module defines ``prepare_inputs`` they are passed through it first. The
    result is a ``functools.partial`` of the backend's ``run`` bound to those
    inputs and to ``device``.

    Raises:
        ImportError: if the backend module cannot be imported.
    """
    backend_module = try_import(backend)
    if backend_module is None:
        # Fail fast, before generating inputs. Previously this surfaced later
        # as "'NoneType' object has no attribute 'run'".
        raise ImportError(
            f"isoneutral_mixing backend {backend!r} is not available "
            "(its module could not be imported)"
        )

    inputs = generate_inputs(size)
    if hasattr(backend_module, "prepare_inputs"):
        inputs = backend_module.prepare_inputs(*inputs, device=device)
    return functools.partial(backend_module.run, *inputs, device=device)


__implementations__ = ( 80 | "aesara", 81 | "cupy", 82 | "numba", 83 | "numpy", 84 | "jax", 85 | "pytorch", 86 | "taichi", 87 | ) 88 | -------------------------------------------------------------------------------- /benchmarks/isoneutral_mixing/isoneutral_aesara.py: -------------------------------------------------------------------------------- 1 | import aesara 2 | import aesara.tensor as aet 3 | 4 | 5 | def get_drhodT(salt, temp, p): 6 | rho0 = 1024.0 7 | z0 = 0.0 8 | theta0 = 283.0 - 273.15 9 | grav = 9.81 10 | betaT = 1.67e-4 11 | betaTs = 1e-5 12 | gammas = 1.1e-8 13 | 14 | zz = -p - z0 15 | thetas = temp - theta0 16 | return -(betaTs * thetas + betaT * (1 - gammas * grav * zz * rho0)) * rho0 17 | 18 | 19 | def get_drhodS(salt, temp, p): 20 | betaS = 0.78e-3 21 | rho0 = 1024.0 22 | return betaS * rho0 * aet.ones_like(temp) 23 | 24 | 25 | def isoneutral_diffusion_pre( 26 | maskT, 27 | maskU, 28 | maskV, 29 | maskW, 30 | dxt, 31 | dxu, 32 | dyt, 33 | dyu, 34 | dzt, 35 | dzw, 36 | cost, 37 | cosu, 38 | salt, 39 | temp, 40 | zt, 41 | K_iso, 42 | K_11, 43 | K_22, 44 | K_33, 45 | Ai_ez, 46 | Ai_nz, 47 | Ai_bx, 48 | Ai_by, 49 | ): 50 | """ 51 | Isopycnal diffusion for tracer 52 | following functional formulation by Griffies et al 53 | Code adopted from MOM2.1 54 | """ 55 | epsln = 1e-20 56 | iso_slopec = 1e-3 57 | iso_dslope = 1e-3 58 | K_iso_steep = 50.0 59 | tau = 0 60 | 61 | dTdx = aet.zeros_like(K_11) 62 | dSdx = aet.zeros_like(K_11) 63 | dTdy = aet.zeros_like(K_11) 64 | dSdy = aet.zeros_like(K_11) 65 | dTdz = aet.zeros_like(K_11) 66 | dSdz = aet.zeros_like(K_11) 67 | 68 | """ 69 | drho_dt and drho_ds at centers of T cells 70 | """ 71 | drdT = maskT * get_drhodT(salt[:, :, :, tau], temp[:, :, :, tau], abs(zt)) 72 | drdS = maskT * get_drhodS(salt[:, :, :, tau], temp[:, :, :, tau], abs(zt)) 73 | 74 | """ 75 | gradients at top face of T cells 76 | """ 77 | dTdz = aet.set_subtensor( 78 | dTdz[:, :, :-1], 79 | maskW[:, :, :-1] 80 | * (temp[:, :, 1:, tau] - temp[:, 
:, :-1, tau]) 81 | / dzw[:, :, :-1], 82 | ) 83 | dSdz = aet.set_subtensor( 84 | dSdz[:, :, :-1], 85 | maskW[:, :, :-1] 86 | * (salt[:, :, 1:, tau] - salt[:, :, :-1, tau]) 87 | / dzw[:, :, :-1], 88 | ) 89 | 90 | """ 91 | gradients at eastern face of T cells 92 | """ 93 | dTdx = aet.set_subtensor( 94 | dTdx[:-1, :, :], 95 | maskU[:-1, :, :] 96 | * (temp[1:, :, :, tau] - temp[:-1, :, :, tau]) 97 | / (dxu[:-1, :, :] * cost[:, :, :]), 98 | ) 99 | dSdx = aet.set_subtensor( 100 | dSdx[:-1, :, :], 101 | maskU[:-1, :, :] 102 | * (salt[1:, :, :, tau] - salt[:-1, :, :, tau]) 103 | / (dxu[:-1, :, :] * cost[:, :, :]), 104 | ) 105 | 106 | """ 107 | gradients at northern face of T cells 108 | """ 109 | dTdy = aet.set_subtensor( 110 | dTdy[:, :-1, :], 111 | maskV[:, :-1, :] 112 | * (temp[:, 1:, :, tau] - temp[:, :-1, :, tau]) 113 | / dyu[:, :-1, :], 114 | ) 115 | dSdy = aet.set_subtensor( 116 | dSdy[:, :-1, :], 117 | maskV[:, :-1, :] 118 | * (salt[:, 1:, :, tau] - salt[:, :-1, :, tau]) 119 | / dyu[:, :-1, :], 120 | ) 121 | 122 | def dm_taper(sx): 123 | """ 124 | tapering function for isopycnal slopes 125 | """ 126 | return 0.5 * (1.0 + aet.tanh((-abs(sx) + iso_slopec) / iso_dslope)) 127 | 128 | """ 129 | Compute Ai_ez and K11 on center of east face of T cell. 
130 | """ 131 | diffloc = aet.zeros_like(K_11) 132 | diffloc = aet.set_subtensor( 133 | diffloc[1:-2, 2:-2, 1:], 134 | 0.25 135 | * ( 136 | K_iso[1:-2, 2:-2, 1:] 137 | + K_iso[1:-2, 2:-2, :-1] 138 | + K_iso[2:-1, 2:-2, 1:] 139 | + K_iso[2:-1, 2:-2, :-1] 140 | ), 141 | ) 142 | diffloc = aet.set_subtensor( 143 | diffloc[1:-2, 2:-2, 0], 0.5 * (K_iso[1:-2, 2:-2, 0] + K_iso[2:-1, 2:-2, 0]) 144 | ) 145 | 146 | sumz = aet.zeros_like(K_11)[1:-2, 2:-2] 147 | for kr in range(2): 148 | ki = 0 if kr == 1 else 1 149 | for ip in range(2): 150 | drodxe = ( 151 | drdT[1 + ip : -2 + ip, 2:-2, ki:] * dTdx[1:-2, 2:-2, ki:] 152 | + drdS[1 + ip : -2 + ip, 2:-2, ki:] * dSdx[1:-2, 2:-2, ki:] 153 | ) 154 | drodze = ( 155 | drdT[1 + ip : -2 + ip, 2:-2, ki:] 156 | * dTdz[1 + ip : -2 + ip, 2:-2, : -1 + kr or None] 157 | + drdS[1 + ip : -2 + ip, 2:-2, ki:] 158 | * dSdz[1 + ip : -2 + ip, 2:-2, : -1 + kr or None] 159 | ) 160 | sxe = -drodxe / (aet.minimum(0.0, drodze) - epsln) 161 | taper = dm_taper(sxe) 162 | sumz = aet.inc_subtensor( 163 | sumz[:, :, ki:], 164 | dzw[:, :, : -1 + kr or None] 165 | * maskU[1:-2, 2:-2, ki:] 166 | * aet.maximum(K_iso_steep, diffloc[1:-2, 2:-2, ki:] * taper), 167 | ) 168 | Ai_ez = aet.set_subtensor( 169 | Ai_ez[1:-2, 2:-2, ki:, ip, kr], taper * sxe * maskU[1:-2, 2:-2, ki:] 170 | ) 171 | K_11 = aet.set_subtensor(K_11[1:-2, 2:-2, :], sumz / (4.0 * dzt[:, :, :])) 172 | 173 | """ 174 | Compute Ai_nz and K_22 on center of north face of T cell. 
175 | """ 176 | diffloc = aet.set_subtensor(diffloc[...], 0) 177 | diffloc = aet.set_subtensor( 178 | diffloc[2:-2, 1:-2, 1:], 179 | 0.25 180 | * ( 181 | K_iso[2:-2, 1:-2, 1:] 182 | + K_iso[2:-2, 1:-2, :-1] 183 | + K_iso[2:-2, 2:-1, 1:] 184 | + K_iso[2:-2, 2:-1, :-1] 185 | ), 186 | ) 187 | diffloc = aet.set_subtensor( 188 | diffloc[2:-2, 1:-2, 0], 0.5 * (K_iso[2:-2, 1:-2, 0] + K_iso[2:-2, 2:-1, 0]) 189 | ) 190 | 191 | sumz = aet.zeros_like(K_11)[2:-2, 1:-2] 192 | for kr in range(2): 193 | ki = 0 if kr == 1 else 1 194 | for jp in range(2): 195 | drodyn = ( 196 | drdT[2:-2, 1 + jp : -2 + jp, ki:] * dTdy[2:-2, 1:-2, ki:] 197 | + drdS[2:-2, 1 + jp : -2 + jp, ki:] * dSdy[2:-2, 1:-2, ki:] 198 | ) 199 | drodzn = ( 200 | drdT[2:-2, 1 + jp : -2 + jp, ki:] 201 | * dTdz[2:-2, 1 + jp : -2 + jp, : -1 + kr or None] 202 | + drdS[2:-2, 1 + jp : -2 + jp, ki:] 203 | * dSdz[2:-2, 1 + jp : -2 + jp, : -1 + kr or None] 204 | ) 205 | syn = -drodyn / (aet.minimum(0.0, drodzn) - epsln) 206 | taper = dm_taper(syn) 207 | sumz = aet.inc_subtensor( 208 | sumz[:, :, ki:], 209 | dzw[:, :, : -1 + kr or None] 210 | * maskV[2:-2, 1:-2, ki:] 211 | * aet.maximum(K_iso_steep, diffloc[2:-2, 1:-2, ki:] * taper), 212 | ) 213 | Ai_nz = aet.set_subtensor( 214 | Ai_nz[2:-2, 1:-2, ki:, jp, kr], taper * syn * maskV[2:-2, 1:-2, ki:] 215 | ) 216 | K_22 = aet.set_subtensor(K_22[2:-2, 1:-2, :], sumz / (4.0 * dzt[:, :, :])) 217 | 218 | """ 219 | compute Ai_bx, Ai_by and K33 on top face of T cell. 
220 | """ 221 | sumx = aet.zeros_like(K_11)[2:-2, 2:-2, :-1] 222 | sumy = aet.zeros_like(K_11)[2:-2, 2:-2, :-1] 223 | 224 | for kr in range(2): 225 | drodzb = ( 226 | drdT[2:-2, 2:-2, kr : -1 + kr or None] * dTdz[2:-2, 2:-2, :-1] 227 | + drdS[2:-2, 2:-2, kr : -1 + kr or None] * dSdz[2:-2, 2:-2, :-1] 228 | ) 229 | 230 | # eastward slopes at the top of T cells 231 | for ip in range(2): 232 | drodxb = ( 233 | drdT[2:-2, 2:-2, kr : -1 + kr or None] 234 | * dTdx[1 + ip : -3 + ip, 2:-2, kr : -1 + kr or None] 235 | + drdS[2:-2, 2:-2, kr : -1 + kr or None] 236 | * dSdx[1 + ip : -3 + ip, 2:-2, kr : -1 + kr or None] 237 | ) 238 | sxb = -drodxb / (aet.minimum(0.0, drodzb) - epsln) 239 | taper = dm_taper(sxb) 240 | sumx += ( 241 | dxu[1 + ip : -3 + ip, :, :] 242 | * K_iso[2:-2, 2:-2, :-1] 243 | * taper 244 | * sxb ** 2 245 | * maskW[2:-2, 2:-2, :-1] 246 | ) 247 | Ai_bx = aet.set_subtensor( 248 | Ai_bx[2:-2, 2:-2, :-1, ip, kr], taper * sxb * maskW[2:-2, 2:-2, :-1] 249 | ) 250 | 251 | # northward slopes at the top of T cells 252 | for jp in range(2): 253 | facty = cosu[:, 1 + jp : -3 + jp] * dyu[:, 1 + jp : -3 + jp] 254 | drodyb = ( 255 | drdT[2:-2, 2:-2, kr : -1 + kr or None] 256 | * dTdy[2:-2, 1 + jp : -3 + jp, kr : -1 + kr or None] 257 | + drdS[2:-2, 2:-2, kr : -1 + kr or None] 258 | * dSdy[2:-2, 1 + jp : -3 + jp, kr : -1 + kr or None] 259 | ) 260 | syb = -drodyb / (aet.minimum(0.0, drodzb) - epsln) 261 | taper = dm_taper(syb) 262 | sumy += ( 263 | facty 264 | * K_iso[2:-2, 2:-2, :-1] 265 | * taper 266 | * syb ** 2 267 | * maskW[2:-2, 2:-2, :-1] 268 | ) 269 | Ai_by = aet.set_subtensor( 270 | Ai_by[2:-2, 2:-2, :-1, jp, kr], taper * syb * maskW[2:-2, 2:-2, :-1] 271 | ) 272 | 273 | K_33 = aet.set_subtensor( 274 | K_33[2:-2, 2:-2, :-1], 275 | sumx / (4 * dxt[2:-2, :, :]) + sumy / (4 * dyt[:, 2:-2, :] * cost[:, 2:-2, :]), 276 | ) 277 | K_33 = aet.set_subtensor(K_33[2:-2, 2:-2, -1], 0.0) 278 | 279 | return K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by 280 | 281 | 282 | t1d_x = 
aesara.tensor.TensorType(dtype="float64", broadcastable=(False, True, True)) 283 | t1d_y = aesara.tensor.TensorType(dtype="float64", broadcastable=(True, False, True)) 284 | t1d_z = aesara.tensor.TensorType(dtype="float64", broadcastable=(True, True, False)) 285 | 286 | symbolic_inputs = [ 287 | aesara.tensor.dtensor3("maskT"), 288 | aesara.tensor.dtensor3("maskU"), 289 | aesara.tensor.dtensor3("maskV"), 290 | aesara.tensor.dtensor3("maskW"), 291 | t1d_x("dxt"), 292 | t1d_x("dxu"), 293 | t1d_y("dyt"), 294 | t1d_y("dyu"), 295 | t1d_z("dzt"), 296 | t1d_z("dzw"), 297 | t1d_y("cost"), 298 | t1d_y("cosu"), 299 | aesara.tensor.dtensor4("salt"), 300 | aesara.tensor.dtensor4("temp"), 301 | t1d_z("zt"), 302 | aesara.tensor.dtensor3("K_iso"), 303 | aesara.tensor.dtensor3("K_11"), 304 | aesara.tensor.dtensor3("K_22"), 305 | aesara.tensor.dtensor3("K_33"), 306 | aesara.tensor.dtensor5("Ai_ez"), 307 | aesara.tensor.dtensor5("Ai_nz"), 308 | aesara.tensor.dtensor5("Ai_bx"), 309 | aesara.tensor.dtensor5("Ai_by"), 310 | ] 311 | isoneutral_aesara = aesara.function( 312 | symbolic_inputs, isoneutral_diffusion_pre(*symbolic_inputs) 313 | ) 314 | 315 | 316 | def prepare_inputs(*inputs, device): 317 | inputs = list(inputs) 318 | 319 | for i in (4, 5): 320 | inputs[i] = inputs[i].reshape(-1, 1, 1) 321 | 322 | for i in (6, 7, 10, 11): 323 | inputs[i] = inputs[i].reshape(1, -1, 1) 324 | 325 | for i in (8, 9, 14): 326 | inputs[i] = inputs[i].reshape(1, 1, -1) 327 | 328 | return inputs 329 | 330 | 331 | def run(*inputs, device="cpu"): 332 | outputs = isoneutral_aesara(*inputs) 333 | return outputs 334 | -------------------------------------------------------------------------------- /benchmarks/isoneutral_mixing/isoneutral_cupy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cupy as cp 3 | 4 | 5 | def get_drhodT(salt, temp, p): 6 | rho0 = 1024.0 7 | z0 = 0.0 8 | theta0 = 283.0 - 273.15 9 | grav = 9.81 10 | betaT = 1.67e-4 
11 | betaTs = 1e-5 12 | gammas = 1.1e-8 13 | 14 | zz = -p - z0 15 | thetas = temp - theta0 16 | return -(betaTs * thetas + betaT * (1 - gammas * grav * zz * rho0)) * rho0 17 | 18 | 19 | def get_drhodS(salt, temp, p): 20 | betaS = 0.78e-3 21 | rho0 = 1024.0 22 | return betaS * rho0 * cp.ones_like(temp) 23 | 24 | 25 | def isoneutral_diffusion_pre( 26 | maskT, 27 | maskU, 28 | maskV, 29 | maskW, 30 | dxt, 31 | dxu, 32 | dyt, 33 | dyu, 34 | dzt, 35 | dzw, 36 | cost, 37 | cosu, 38 | salt, 39 | temp, 40 | zt, 41 | K_iso, 42 | K_11, 43 | K_22, 44 | K_33, 45 | Ai_ez, 46 | Ai_nz, 47 | Ai_bx, 48 | Ai_by, 49 | ): 50 | """ 51 | Isopycnal diffusion for tracer 52 | following functional formulation by Griffies et al 53 | Code adopted from MOM2.1 54 | """ 55 | epsln = 1e-20 56 | iso_slopec = 1e-3 57 | iso_dslope = 1e-3 58 | K_iso_steep = 50.0 59 | tau = 0 60 | 61 | dTdx = cp.zeros_like(K_11) 62 | dSdx = cp.zeros_like(K_11) 63 | dTdy = cp.zeros_like(K_11) 64 | dSdy = cp.zeros_like(K_11) 65 | dTdz = cp.zeros_like(K_11) 66 | dSdz = cp.zeros_like(K_11) 67 | 68 | """ 69 | drho_dt and drho_ds at centers of T cells 70 | """ 71 | drdT = maskT * get_drhodT(salt[:, :, :, tau], temp[:, :, :, tau], cp.abs(zt)) 72 | drdS = maskT * get_drhodS(salt[:, :, :, tau], temp[:, :, :, tau], cp.abs(zt)) 73 | 74 | """ 75 | gradients at top face of T cells 76 | """ 77 | dTdz[:, :, :-1] = ( 78 | maskW[:, :, :-1] 79 | * (temp[:, :, 1:, tau] - temp[:, :, :-1, tau]) 80 | / dzw[np.newaxis, np.newaxis, :-1] 81 | ) 82 | dSdz[:, :, :-1] = ( 83 | maskW[:, :, :-1] 84 | * (salt[:, :, 1:, tau] - salt[:, :, :-1, tau]) 85 | / dzw[np.newaxis, np.newaxis, :-1] 86 | ) 87 | 88 | """ 89 | gradients at eastern face of T cells 90 | """ 91 | dTdx[:-1, :, :] = ( 92 | maskU[:-1, :, :] 93 | * (temp[1:, :, :, tau] - temp[:-1, :, :, tau]) 94 | / (dxu[:-1, np.newaxis, np.newaxis] * cost[np.newaxis, :, np.newaxis]) 95 | ) 96 | dSdx[:-1, :, :] = ( 97 | maskU[:-1, :, :] 98 | * (salt[1:, :, :, tau] - salt[:-1, :, :, tau]) 99 | / 
(dxu[:-1, np.newaxis, np.newaxis] * cost[np.newaxis, :, np.newaxis]) 100 | ) 101 | 102 | """ 103 | gradients at northern face of T cells 104 | """ 105 | dTdy[:, :-1, :] = ( 106 | maskV[:, :-1, :] 107 | * (temp[:, 1:, :, tau] - temp[:, :-1, :, tau]) 108 | / dyu[np.newaxis, :-1, np.newaxis] 109 | ) 110 | dSdy[:, :-1, :] = ( 111 | maskV[:, :-1, :] 112 | * (salt[:, 1:, :, tau] - salt[:, :-1, :, tau]) 113 | / dyu[np.newaxis, :-1, np.newaxis] 114 | ) 115 | 116 | def dm_taper(sx): 117 | """ 118 | tapering function for isopycnal slopes 119 | """ 120 | return 0.5 * (1.0 + cp.tanh((-cp.abs(sx) + iso_slopec) / iso_dslope)) 121 | 122 | """ 123 | Compute Ai_ez and K11 on center of east face of T cell. 124 | """ 125 | diffloc = cp.zeros_like(K_11) 126 | diffloc[1:-2, 2:-2, 1:] = 0.25 * ( 127 | K_iso[1:-2, 2:-2, 1:] 128 | + K_iso[1:-2, 2:-2, :-1] 129 | + K_iso[2:-1, 2:-2, 1:] 130 | + K_iso[2:-1, 2:-2, :-1] 131 | ) 132 | diffloc[1:-2, 2:-2, 0] = 0.5 * (K_iso[1:-2, 2:-2, 0] + K_iso[2:-1, 2:-2, 0]) 133 | 134 | sumz = cp.zeros_like(K_11)[1:-2, 2:-2] 135 | for kr in range(2): 136 | ki = 0 if kr == 1 else 1 137 | for ip in range(2): 138 | drodxe = ( 139 | drdT[1 + ip : -2 + ip, 2:-2, ki:] * dTdx[1:-2, 2:-2, ki:] 140 | + drdS[1 + ip : -2 + ip, 2:-2, ki:] * dSdx[1:-2, 2:-2, ki:] 141 | ) 142 | drodze = ( 143 | drdT[1 + ip : -2 + ip, 2:-2, ki:] 144 | * dTdz[1 + ip : -2 + ip, 2:-2, : -1 + kr or None] 145 | + drdS[1 + ip : -2 + ip, 2:-2, ki:] 146 | * dSdz[1 + ip : -2 + ip, 2:-2, : -1 + kr or None] 147 | ) 148 | sxe = -drodxe / (cp.minimum(0.0, drodze) - epsln) 149 | taper = dm_taper(sxe) 150 | sumz[:, :, ki:] += ( 151 | dzw[np.newaxis, np.newaxis, : -1 + kr or None] 152 | * maskU[1:-2, 2:-2, ki:] 153 | * cp.maximum(K_iso_steep, diffloc[1:-2, 2:-2, ki:] * taper) 154 | ) 155 | Ai_ez[1:-2, 2:-2, ki:, ip, kr] = taper * sxe * maskU[1:-2, 2:-2, ki:] 156 | K_11[1:-2, 2:-2, :] = sumz / (4.0 * dzt[np.newaxis, np.newaxis, :]) 157 | 158 | """ 159 | Compute Ai_nz and K_22 on center of north face of T 
cell. 160 | """ 161 | diffloc[...] = 0 162 | diffloc[2:-2, 1:-2, 1:] = 0.25 * ( 163 | K_iso[2:-2, 1:-2, 1:] 164 | + K_iso[2:-2, 1:-2, :-1] 165 | + K_iso[2:-2, 2:-1, 1:] 166 | + K_iso[2:-2, 2:-1, :-1] 167 | ) 168 | diffloc[2:-2, 1:-2, 0] = 0.5 * (K_iso[2:-2, 1:-2, 0] + K_iso[2:-2, 2:-1, 0]) 169 | 170 | sumz = cp.zeros_like(K_11)[2:-2, 1:-2] 171 | for kr in range(2): 172 | ki = 0 if kr == 1 else 1 173 | for jp in range(2): 174 | drodyn = ( 175 | drdT[2:-2, 1 + jp : -2 + jp, ki:] * dTdy[2:-2, 1:-2, ki:] 176 | + drdS[2:-2, 1 + jp : -2 + jp, ki:] * dSdy[2:-2, 1:-2, ki:] 177 | ) 178 | drodzn = ( 179 | drdT[2:-2, 1 + jp : -2 + jp, ki:] 180 | * dTdz[2:-2, 1 + jp : -2 + jp, : -1 + kr or None] 181 | + drdS[2:-2, 1 + jp : -2 + jp, ki:] 182 | * dSdz[2:-2, 1 + jp : -2 + jp, : -1 + kr or None] 183 | ) 184 | syn = -drodyn / (cp.minimum(0.0, drodzn) - epsln) 185 | taper = dm_taper(syn) 186 | sumz[:, :, ki:] += ( 187 | dzw[np.newaxis, np.newaxis, : -1 + kr or None] 188 | * maskV[2:-2, 1:-2, ki:] 189 | * cp.maximum(K_iso_steep, diffloc[2:-2, 1:-2, ki:] * taper) 190 | ) 191 | Ai_nz[2:-2, 1:-2, ki:, jp, kr] = taper * syn * maskV[2:-2, 1:-2, ki:] 192 | K_22[2:-2, 1:-2, :] = sumz / (4.0 * dzt[np.newaxis, np.newaxis, :]) 193 | 194 | """ 195 | compute Ai_bx, Ai_by and K33 on top face of T cell. 
196 | """ 197 | sumx = cp.zeros_like(K_11)[2:-2, 2:-2, :-1] 198 | sumy = cp.zeros_like(K_11)[2:-2, 2:-2, :-1] 199 | 200 | for kr in range(2): 201 | drodzb = ( 202 | drdT[2:-2, 2:-2, kr : -1 + kr or None] * dTdz[2:-2, 2:-2, :-1] 203 | + drdS[2:-2, 2:-2, kr : -1 + kr or None] * dSdz[2:-2, 2:-2, :-1] 204 | ) 205 | 206 | # eastward slopes at the top of T cells 207 | for ip in range(2): 208 | drodxb = ( 209 | drdT[2:-2, 2:-2, kr : -1 + kr or None] 210 | * dTdx[1 + ip : -3 + ip, 2:-2, kr : -1 + kr or None] 211 | + drdS[2:-2, 2:-2, kr : -1 + kr or None] 212 | * dSdx[1 + ip : -3 + ip, 2:-2, kr : -1 + kr or None] 213 | ) 214 | sxb = -drodxb / (cp.minimum(0.0, drodzb) - epsln) 215 | taper = dm_taper(sxb) 216 | sumx += ( 217 | dxu[1 + ip : -3 + ip, np.newaxis, np.newaxis] 218 | * K_iso[2:-2, 2:-2, :-1] 219 | * taper 220 | * sxb ** 2 221 | * maskW[2:-2, 2:-2, :-1] 222 | ) 223 | Ai_bx[2:-2, 2:-2, :-1, ip, kr] = taper * sxb * maskW[2:-2, 2:-2, :-1] 224 | 225 | # northward slopes at the top of T cells 226 | for jp in range(2): 227 | facty = cosu[1 + jp : -3 + jp] * dyu[1 + jp : -3 + jp] 228 | drodyb = ( 229 | drdT[2:-2, 2:-2, kr : -1 + kr or None] 230 | * dTdy[2:-2, 1 + jp : -3 + jp, kr : -1 + kr or None] 231 | + drdS[2:-2, 2:-2, kr : -1 + kr or None] 232 | * dSdy[2:-2, 1 + jp : -3 + jp, kr : -1 + kr or None] 233 | ) 234 | syb = -drodyb / (cp.minimum(0.0, drodzb) - epsln) 235 | taper = dm_taper(syb) 236 | sumy += ( 237 | facty[np.newaxis, :, np.newaxis] 238 | * K_iso[2:-2, 2:-2, :-1] 239 | * taper 240 | * syb ** 2 241 | * maskW[2:-2, 2:-2, :-1] 242 | ) 243 | Ai_by[2:-2, 2:-2, :-1, jp, kr] = taper * syb * maskW[2:-2, 2:-2, :-1] 244 | 245 | K_33[2:-2, 2:-2, :-1] = sumx / (4 * dxt[2:-2, np.newaxis, np.newaxis]) + sumy / ( 246 | 4 * dyt[np.newaxis, 2:-2, np.newaxis] * cost[np.newaxis, 2:-2, np.newaxis] 247 | ) 248 | K_33[2:-2, 2:-2, -1] = 0.0 249 | 250 | 251 | def prepare_inputs(*inputs, device): 252 | out = [cp.asarray(k) for k in inputs] 253 | 
@jax.jit
def get_drhodT(salt, temp, p):
    """
    Evaluate the density-vs-temperature derivative expression of this benchmark.

    ``salt`` is accepted but never read. ``temp`` and ``p`` must broadcast
    against each other; the result has their broadcast shape.
    """
    rho_ref = 1024.0
    z_ref = 0.0
    theta_ref = 283.0 - 273.15
    gravity = 9.81
    beta_t = 1.67e-4
    beta_ts = 1e-5
    gamma_s = 1.1e-8

    depth = -p - z_ref
    theta_anom = temp - theta_ref
    pressure_part = beta_t * (1 - gamma_s * gravity * depth * rho_ref)
    return -(beta_ts * theta_anom + pressure_part) * rho_ref


@jax.jit
def get_drhodS(salt, temp, p):
    """
    Return the constant ``0.78e-3 * 1024.0`` broadcast to the shape of ``temp``.

    ``salt`` and ``p`` are accepted but never read.
    """
    beta_s = 0.78e-3
    rho_ref = 1024.0
    return beta_s * rho_ref * jnp.ones_like(temp)


@jax.jit
def isoneutral_diffusion_pre(
    maskT, maskU, maskV, maskW,
    dxt, dxu, dyt, dyu, dzt, dzw, cost, cosu,
    salt, temp, zt, K_iso,
    K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by,
):
    """
    Isopycnal diffusion for tracer
    following functional formulation by Griffies et al
    Code adopted from MOM2.1

    Functional (JAX) variant: the inputs are not modified. Updated copies of
    ``K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by`` are returned as a tuple.
    Only time level ``tau = 0`` of the 4-D ``salt`` / ``temp`` arrays is read,
    and only the index windows written below are changed in the outputs.
    """
    epsln = 1e-20
    iso_slopec = 1e-3
    iso_dslope = 1e-3
    K_iso_steep = 50.0
    tau = 0

    def taper_of(slope):
        # tapering function for isopycnal slopes
        return 0.5 * (1.0 + jnp.tanh((-jnp.abs(slope) + iso_slopec) / iso_dslope))

    def slope_of(horizontal, vertical):
        # the vertical term is capped at zero and shifted by epsln, so the
        # denominator is always strictly negative
        return -horizontal / (jnp.minimum(0.0, vertical) - epsln)

    temp_now = temp[:, :, :, tau]
    salt_now = salt[:, :, :, tau]
    depth = jnp.abs(zt)

    # drho_dt and drho_ds at centers of T cells
    drdT = maskT * get_drhodT(salt_now, temp_now, depth)
    drdS = maskT * get_drhodS(salt_now, temp_now, depth)

    blank = jnp.zeros_like(K_11)

    # gradients at top face of T cells (last level stays zero)
    dz_top = dzw[None, None, :-1]
    dTdz = blank.at[:, :, :-1].set(
        maskW[:, :, :-1] * jnp.diff(temp_now, axis=2) / dz_top
    )
    dSdz = blank.at[:, :, :-1].set(
        maskW[:, :, :-1] * jnp.diff(salt_now, axis=2) / dz_top
    )

    # gradients at eastern face of T cells (last row stays zero)
    dx_east = dxu[:-1, None, None] * cost[None, :, None]
    dTdx = blank.at[:-1, :, :].set(
        maskU[:-1, :, :] * jnp.diff(temp_now, axis=0) / dx_east
    )
    dSdx = blank.at[:-1, :, :].set(
        maskU[:-1, :, :] * jnp.diff(salt_now, axis=0) / dx_east
    )

    # gradients at northern face of T cells (last column stays zero)
    dy_north = dyu[None, :-1, None]
    dTdy = blank.at[:, :-1, :].set(
        maskV[:, :-1, :] * jnp.diff(temp_now, axis=1) / dy_north
    )
    dSdy = blank.at[:, :-1, :].set(
        maskV[:, :-1, :] * jnp.diff(salt_now, axis=1) / dy_north
    )

    # ------------------------------------------------------------------
    # Compute Ai_ez and K11 on center of east face of T cell.
    # ------------------------------------------------------------------
    diff_e = jnp.zeros_like(K_11[1:-2, 2:-2])
    diff_e = diff_e.at[:, :, 1:].set(
        0.25
        * (
            K_iso[1:-2, 2:-2, 1:]
            + K_iso[1:-2, 2:-2, :-1]
            + K_iso[2:-1, 2:-2, 1:]
            + K_iso[2:-1, 2:-2, :-1]
        )
    )
    diff_e = diff_e.at[:, :, 0].set(
        0.5 * (K_iso[1:-2, 2:-2, 0] + K_iso[2:-1, 2:-2, 0])
    )

    acc_e = jnp.zeros_like(K_11[1:-2, 2:-2])
    for kr in (0, 1):
        # kr == 0 pairs level k with the face k - 1, so level 0 is skipped
        k_lo = 1 - kr
        w_hi = None if kr else -1
        mask_u = maskU[1:-2, 2:-2, k_lo:]
        for ip in (0, 1):
            rows = slice(1 + ip, -2 + ip)
            rho_t = drdT[rows, 2:-2, k_lo:]
            rho_s = drdS[rows, 2:-2, k_lo:]
            drodxe = (
                rho_t * dTdx[1:-2, 2:-2, k_lo:] + rho_s * dSdx[1:-2, 2:-2, k_lo:]
            )
            drodze = (
                rho_t * dTdz[rows, 2:-2, :w_hi] + rho_s * dSdz[rows, 2:-2, :w_hi]
            )
            sxe = slope_of(drodxe, drodze)
            taper = taper_of(sxe)
            acc_e = acc_e.at[:, :, k_lo:].add(
                dzw[None, None, :w_hi]
                * mask_u
                * jnp.maximum(K_iso_steep, diff_e[:, :, k_lo:] * taper)
            )
            Ai_ez = Ai_ez.at[1:-2, 2:-2, k_lo:, ip, kr].set(taper * sxe * mask_u)

    K_11 = K_11.at[1:-2, 2:-2, :].set(acc_e / (4.0 * dzt[None, None, :]))

    # ------------------------------------------------------------------
    # Compute Ai_nz and K_22 on center of north face of T cell.
    # ------------------------------------------------------------------
    diff_n = jnp.zeros_like(K_11[2:-2, 1:-2])
    diff_n = diff_n.at[:, :, 1:].set(
        0.25
        * (
            K_iso[2:-2, 1:-2, 1:]
            + K_iso[2:-2, 1:-2, :-1]
            + K_iso[2:-2, 2:-1, 1:]
            + K_iso[2:-2, 2:-1, :-1]
        )
    )
    diff_n = diff_n.at[:, :, 0].set(
        0.5 * (K_iso[2:-2, 1:-2, 0] + K_iso[2:-2, 2:-1, 0])
    )

    acc_n = jnp.zeros_like(K_11[2:-2, 1:-2])
    for kr in (0, 1):
        k_lo = 1 - kr
        w_hi = None if kr else -1
        mask_v = maskV[2:-2, 1:-2, k_lo:]
        for jp in (0, 1):
            cols = slice(1 + jp, -2 + jp)
            rho_t = drdT[2:-2, cols, k_lo:]
            rho_s = drdS[2:-2, cols, k_lo:]
            drodyn = (
                rho_t * dTdy[2:-2, 1:-2, k_lo:] + rho_s * dSdy[2:-2, 1:-2, k_lo:]
            )
            drodzn = (
                rho_t * dTdz[2:-2, cols, :w_hi] + rho_s * dSdz[2:-2, cols, :w_hi]
            )
            syn = slope_of(drodyn, drodzn)
            taper = taper_of(syn)
            acc_n = acc_n.at[:, :, k_lo:].add(
                dzw[None, None, :w_hi]
                * mask_v
                * jnp.maximum(K_iso_steep, diff_n[:, :, k_lo:] * taper)
            )
            Ai_nz = Ai_nz.at[2:-2, 1:-2, k_lo:, jp, kr].set(taper * syn * mask_v)

    K_22 = K_22.at[2:-2, 1:-2, :].set(acc_n / (4.0 * dzt[None, None, :]))

    # ------------------------------------------------------------------
    # compute Ai_bx, Ai_by and K33 on top face of T cell.
    # ------------------------------------------------------------------
    top = (slice(2, -2), slice(2, -2), slice(None, -1))
    mask_w = maskW[top]
    k_top = K_iso[top]
    dTdz_top = dTdz[top]
    dSdz_top = dSdz[top]

    sumx = jnp.zeros_like(K_11[top])
    sumy = jnp.zeros_like(K_11[top])

    for kr in (0, 1):
        # kr selects the level below (0) or above (1) the face
        levels = slice(kr, None if kr else -1)
        rho_t = drdT[2:-2, 2:-2, levels]
        rho_s = drdS[2:-2, 2:-2, levels]
        drodzb = rho_t * dTdz_top + rho_s * dSdz_top

        # eastward slopes at the top of T cells
        for ip in (0, 1):
            rows = slice(1 + ip, -3 + ip)
            drodxb = rho_t * dTdx[rows, 2:-2, levels] + rho_s * dSdx[rows, 2:-2, levels]
            sxb = slope_of(drodxb, drodzb)
            taper = taper_of(sxb)
            sumx = sumx + (
                dxu[rows, None, None] * k_top * taper * sxb ** 2 * mask_w
            )
            Ai_bx = Ai_bx.at[2:-2, 2:-2, :-1, ip, kr].set(taper * sxb * mask_w)

        # northward slopes at the top of T cells
        for jp in (0, 1):
            cols = slice(1 + jp, -3 + jp)
            facty = cosu[cols] * dyu[cols]
            drodyb = rho_t * dTdy[2:-2, cols, levels] + rho_s * dSdy[2:-2, cols, levels]
            syb = slope_of(drodyb, drodzb)
            taper = taper_of(syb)
            sumy = sumy + (
                facty[None, :, None] * k_top * taper * syb ** 2 * mask_w
            )
            Ai_by = Ai_by.at[2:-2, 2:-2, :-1, jp, kr].set(taper * syb * mask_w)

    K_33 = K_33.at[2:-2, 2:-2, :-1].set(
        sumx / (4 * dxt[2:-2, None, None])
        + sumy / (4 * dyt[None, 2:-2, None] * cost[None, 2:-2, None])
    )
    K_33 = K_33.at[2:-2, 2:-2, -1].set(0.0)

    return K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by


def prepare_inputs(*inputs, device):
    """
    Convert every positional input to a JAX array and wait until all of them
    are materialised. ``device`` is required by the call signature but unused.
    """
    converted = [jnp.array(item) for item in inputs]
    for arr in converted:
        arr.block_until_ready()
    return converted


def run(*inputs, device="cpu"):
    """
    Run the jitted kernel and block until every output is ready, so that
    asynchronous dispatch does not hide work from the caller.
    """
    results = isoneutral_diffusion_pre(*inputs)
    for arr in results:
        arr.block_until_ready()
    return results
epsln = 1e-20 71 | K_iso_steep = 50.0 72 | tau = 0 73 | 74 | drdT = np.empty_like(K_11) 75 | drdS = np.empty_like(K_11) 76 | dTdx = np.empty_like(K_11) 77 | dSdx = np.empty_like(K_11) 78 | dTdy = np.empty_like(K_11) 79 | dSdy = np.empty_like(K_11) 80 | dTdz = np.empty_like(K_11) 81 | dSdz = np.empty_like(K_11) 82 | 83 | """ 84 | drho_dt and drho_ds at centers of T cells 85 | """ 86 | for i in range(nx): 87 | for j in range(ny): 88 | for k in range(nz): 89 | drdT[i, j, k] = maskT[i, j, k] * get_drhodT( 90 | salt[i, j, k, tau], temp[i, j, k, tau], np.abs(zt[k]) 91 | ) 92 | drdS[i, j, k] = maskT[i, j, k] * get_drhodS( 93 | salt[i, j, k, tau], temp[i, j, k, tau], np.abs(zt[k]) 94 | ) 95 | 96 | """ 97 | gradients at top face of T cells 98 | """ 99 | for i in range(nx): 100 | for j in range(ny): 101 | for k in range(nz - 1): 102 | dTdz[i, j, k] = ( 103 | maskW[i, j, k] 104 | * (temp[i, j, k + 1, tau] - temp[i, j, k, tau]) 105 | / dzw[k] 106 | ) 107 | dSdz[i, j, k] = ( 108 | maskW[i, j, k] 109 | * (salt[i, j, k + 1, tau] - salt[i, j, k, tau]) 110 | / dzw[k] 111 | ) 112 | dTdz[i, j, -1] = 0.0 113 | dSdz[i, j, -1] = 0.0 114 | 115 | """ 116 | gradients at eastern face of T cells 117 | """ 118 | for i in range(nx - 1): 119 | for j in range(ny): 120 | for k in range(nz): 121 | dTdx[i, j, k] = ( 122 | maskU[i, j, k] 123 | * (temp[i + 1, j, k, tau] - temp[i, j, k, tau]) 124 | / (dxu[i] * cost[j]) 125 | ) 126 | dSdx[i, j, k] = ( 127 | maskU[i, j, k] 128 | * (salt[i + 1, j, k, tau] - salt[i, j, k, tau]) 129 | / (dxu[i] * cost[j]) 130 | ) 131 | dTdx[-1, :, :] = 0.0 132 | dSdx[-1, :, :] = 0.0 133 | 134 | """ 135 | gradients at northern face of T cells 136 | """ 137 | for i in range(nx): 138 | for j in range(ny - 1): 139 | for k in range(nz): 140 | dTdy[i, j, k] = ( 141 | maskV[i, j, k] 142 | * (temp[i, j + 1, k, tau] - temp[i, j, k, tau]) 143 | / dyu[j] 144 | ) 145 | dSdy[i, j, k] = ( 146 | maskV[i, j, k] 147 | * (salt[i, j + 1, k, tau] - salt[i, j, k, tau]) 148 | / dyu[j] 149 | ) 
150 | dTdy[:, -1, :] = 0.0 151 | dSdy[:, -1, :] = 0.0 152 | 153 | """ 154 | Compute Ai_ez and K11 on center of east face of T cell. 155 | """ 156 | for i in range(1, nx - 2): 157 | for j in range(2, ny - 2): 158 | for k in range(0, nz): 159 | if k == 0: 160 | diffloc = 0.5 * (K_iso[i, j, k] + K_iso[i + 1, j, k]) 161 | else: 162 | diffloc = 0.25 * ( 163 | K_iso[i, j, k] 164 | + K_iso[i, j, k - 1] 165 | + K_iso[i + 1, j, k] 166 | + K_iso[i + 1, j, k - 1] 167 | ) 168 | 169 | sumz = 0.0 170 | 171 | for kr in (0, 1): 172 | if k == 0 and kr == 0: 173 | continue 174 | 175 | for ip in (0, 1): 176 | drodxe = ( 177 | drdT[i + ip, j, k] * dTdx[i, j, k] 178 | + drdS[i + ip, j, k] * dSdx[i, j, k] 179 | ) 180 | drodze = ( 181 | drdT[i + ip, j, k] * dTdz[i + ip, j, k + kr - 1] 182 | + drdS[i + ip, j, k] * dSdz[i + ip, j, k + kr - 1] 183 | ) 184 | sxe = -drodxe / (min(0.0, drodze) - epsln) 185 | taper = dm_taper(sxe) 186 | sumz += ( 187 | dzw[k + kr - 1] 188 | * maskU[i, j, k] 189 | * max(K_iso_steep, diffloc * taper) 190 | ) 191 | Ai_ez[i, j, k, ip, kr] = taper * sxe * maskU[i, j, k] 192 | 193 | K_11[i, j, k] = sumz / (4.0 * dzt[k]) 194 | 195 | """ 196 | Compute Ai_nz and K_22 on center of north face of T cell. 
197 | """ 198 | for i in range(2, nx - 2): 199 | for j in range(1, ny - 2): 200 | for k in range(0, nz): 201 | if k == 0: 202 | diffloc = 0.5 * (K_iso[i, j, k] + K_iso[i, j + 1, k]) 203 | else: 204 | diffloc = 0.25 * ( 205 | K_iso[i, j, k] 206 | + K_iso[i, j, k - 1] 207 | + K_iso[i, j + 1, k] 208 | + K_iso[i, j + 1, k - 1] 209 | ) 210 | 211 | sumz = 0.0 212 | 213 | for kr in (0, 1): 214 | if k == 0 and kr == 0: 215 | continue 216 | 217 | for jp in (0, 1): 218 | drodyn = ( 219 | drdT[i, j + jp, k] * dTdy[i, j, k] 220 | + drdS[i, j + jp, k] * dSdy[i, j, k] 221 | ) 222 | drodzn = ( 223 | drdT[i, j + jp, k] * dTdz[i, j + jp, k + kr - 1] 224 | + drdS[i, j + jp, k] * dSdz[i, j + jp, k + kr - 1] 225 | ) 226 | syn = -drodyn / (min(0.0, drodzn) - epsln) 227 | taper = dm_taper(syn) 228 | sumz += ( 229 | dzw[k + kr - 1] 230 | * maskV[i, j, k] 231 | * max(K_iso_steep, diffloc * taper) 232 | ) 233 | Ai_nz[i, j, k, jp, kr] = taper * syn * maskV[i, j, k] 234 | 235 | K_22[i, j, k] = sumz / (4.0 * dzt[k]) 236 | 237 | """ 238 | compute Ai_bx, Ai_by and K33 on top face of T cell. 
def get_drhodT(salt, temp, p):
    """
    Evaluate the density-vs-temperature derivative expression of this benchmark.

    ``salt`` is accepted but never read. ``temp`` and ``p`` must broadcast
    against each other; the result has their broadcast shape.
    """
    rho_ref = 1024.0
    z_ref = 0.0
    theta_ref = 283.0 - 273.15
    gravity = 9.81
    beta_t = 1.67e-4
    beta_ts = 1e-5
    gamma_s = 1.1e-8

    depth = -p - z_ref
    theta_anom = temp - theta_ref
    pressure_part = beta_t * (1 - gamma_s * gravity * depth * rho_ref)
    return -(beta_ts * theta_anom + pressure_part) * rho_ref


def get_drhodS(salt, temp, p):
    """
    Return the constant ``0.78e-3 * 1024.0`` broadcast to the shape of ``temp``.

    ``salt`` and ``p`` are accepted but never read.
    """
    beta_s = 0.78e-3
    rho_ref = 1024.0
    return beta_s * rho_ref * np.ones_like(temp)


def isoneutral_diffusion_pre(
    maskT, maskU, maskV, maskW,
    dxt, dxu, dyt, dyu, dzt, dzw, cost, cosu,
    salt, temp, zt, K_iso,
    K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by,
):
    """
    Isopycnal diffusion for tracer
    following functional formulation by Griffies et al
    Code adopted from MOM2.1

    In-place (NumPy) variant: ``K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by``
    are overwritten inside the index windows used below and nothing is
    returned. Only time level ``tau = 0`` of the 4-D ``salt`` / ``temp``
    arrays is read.
    """
    epsln = 1e-20
    iso_slopec = 1e-3
    iso_dslope = 1e-3
    K_iso_steep = 50.0
    tau = 0

    def taper_of(slope):
        # tapering function for isopycnal slopes
        return 0.5 * (1.0 + np.tanh((-np.abs(slope) + iso_slopec) / iso_dslope))

    def slope_of(horizontal, vertical):
        # the vertical term is capped at zero and shifted by epsln, so the
        # denominator is always strictly negative
        return -horizontal / (np.minimum(0.0, vertical) - epsln)

    temp_now = temp[:, :, :, tau]
    salt_now = salt[:, :, :, tau]
    depth = np.abs(zt)

    # drho_dt and drho_ds at centers of T cells
    drdT = maskT * get_drhodT(salt_now, temp_now, depth)
    drdS = maskT * get_drhodS(salt_now, temp_now, depth)

    # gradients at top face of T cells (last level stays zero)
    dTdz = np.zeros_like(K_11)
    dSdz = np.zeros_like(K_11)
    dz_top = dzw[None, None, :-1]
    dTdz[:, :, :-1] = maskW[:, :, :-1] * np.diff(temp_now, axis=2) / dz_top
    dSdz[:, :, :-1] = maskW[:, :, :-1] * np.diff(salt_now, axis=2) / dz_top

    # gradients at eastern face of T cells (last row stays zero)
    dTdx = np.zeros_like(K_11)
    dSdx = np.zeros_like(K_11)
    dx_east = dxu[:-1, None, None] * cost[None, :, None]
    dTdx[:-1, :, :] = maskU[:-1, :, :] * np.diff(temp_now, axis=0) / dx_east
    dSdx[:-1, :, :] = maskU[:-1, :, :] * np.diff(salt_now, axis=0) / dx_east

    # gradients at northern face of T cells (last column stays zero)
    dTdy = np.zeros_like(K_11)
    dSdy = np.zeros_like(K_11)
    dy_north = dyu[None, :-1, None]
    dTdy[:, :-1, :] = maskV[:, :-1, :] * np.diff(temp_now, axis=1) / dy_north
    dSdy[:, :-1, :] = maskV[:, :-1, :] * np.diff(salt_now, axis=1) / dy_north

    # ------------------------------------------------------------------
    # Compute Ai_ez and K11 on center of east face of T cell.
    # ------------------------------------------------------------------
    diff_e = np.empty_like(K_11[1:-2, 2:-2])
    diff_e[:, :, 1:] = 0.25 * (
        K_iso[1:-2, 2:-2, 1:]
        + K_iso[1:-2, 2:-2, :-1]
        + K_iso[2:-1, 2:-2, 1:]
        + K_iso[2:-1, 2:-2, :-1]
    )
    diff_e[:, :, 0] = 0.5 * (K_iso[1:-2, 2:-2, 0] + K_iso[2:-1, 2:-2, 0])

    acc_e = np.zeros_like(K_11[1:-2, 2:-2])
    for kr in (0, 1):
        # kr == 0 pairs level k with the face k - 1, so level 0 is skipped
        k_lo = 1 - kr
        w_hi = None if kr else -1
        mask_u = maskU[1:-2, 2:-2, k_lo:]
        for ip in (0, 1):
            rows = slice(1 + ip, -2 + ip)
            rho_t = drdT[rows, 2:-2, k_lo:]
            rho_s = drdS[rows, 2:-2, k_lo:]
            drodxe = (
                rho_t * dTdx[1:-2, 2:-2, k_lo:] + rho_s * dSdx[1:-2, 2:-2, k_lo:]
            )
            drodze = (
                rho_t * dTdz[rows, 2:-2, :w_hi] + rho_s * dSdz[rows, 2:-2, :w_hi]
            )
            sxe = slope_of(drodxe, drodze)
            taper = taper_of(sxe)
            acc_e[:, :, k_lo:] += (
                dzw[None, None, :w_hi]
                * mask_u
                * np.maximum(K_iso_steep, diff_e[:, :, k_lo:] * taper)
            )
            Ai_ez[1:-2, 2:-2, k_lo:, ip, kr] = taper * sxe * mask_u

    K_11[1:-2, 2:-2, :] = acc_e / (4.0 * dzt[None, None, :])

    # ------------------------------------------------------------------
    # Compute Ai_nz and K_22 on center of north face of T cell.
    # ------------------------------------------------------------------
    diff_n = np.empty_like(K_11[2:-2, 1:-2])
    diff_n[:, :, 1:] = 0.25 * (
        K_iso[2:-2, 1:-2, 1:]
        + K_iso[2:-2, 1:-2, :-1]
        + K_iso[2:-2, 2:-1, 1:]
        + K_iso[2:-2, 2:-1, :-1]
    )
    diff_n[:, :, 0] = 0.5 * (K_iso[2:-2, 1:-2, 0] + K_iso[2:-2, 2:-1, 0])

    acc_n = np.zeros_like(K_11[2:-2, 1:-2])
    for kr in (0, 1):
        k_lo = 1 - kr
        w_hi = None if kr else -1
        mask_v = maskV[2:-2, 1:-2, k_lo:]
        for jp in (0, 1):
            cols = slice(1 + jp, -2 + jp)
            rho_t = drdT[2:-2, cols, k_lo:]
            rho_s = drdS[2:-2, cols, k_lo:]
            drodyn = (
                rho_t * dTdy[2:-2, 1:-2, k_lo:] + rho_s * dSdy[2:-2, 1:-2, k_lo:]
            )
            drodzn = (
                rho_t * dTdz[2:-2, cols, :w_hi] + rho_s * dSdz[2:-2, cols, :w_hi]
            )
            syn = slope_of(drodyn, drodzn)
            taper = taper_of(syn)
            acc_n[:, :, k_lo:] += (
                dzw[None, None, :w_hi]
                * mask_v
                * np.maximum(K_iso_steep, diff_n[:, :, k_lo:] * taper)
            )
            Ai_nz[2:-2, 1:-2, k_lo:, jp, kr] = taper * syn * mask_v

    K_22[2:-2, 1:-2, :] = acc_n / (4.0 * dzt[None, None, :])

    # ------------------------------------------------------------------
    # compute Ai_bx, Ai_by and K33 on top face of T cell.
    # ------------------------------------------------------------------
    top = (slice(2, -2), slice(2, -2), slice(None, -1))
    mask_w = maskW[top]
    k_top = K_iso[top]
    dTdz_top = dTdz[top]
    dSdz_top = dSdz[top]

    sumx = np.zeros_like(K_11[top])
    sumy = np.zeros_like(K_11[top])

    for kr in (0, 1):
        # kr selects the level below (0) or above (1) the face
        levels = slice(kr, None if kr else -1)
        rho_t = drdT[2:-2, 2:-2, levels]
        rho_s = drdS[2:-2, 2:-2, levels]
        drodzb = rho_t * dTdz_top + rho_s * dSdz_top

        # eastward slopes at the top of T cells
        for ip in (0, 1):
            rows = slice(1 + ip, -3 + ip)
            drodxb = rho_t * dTdx[rows, 2:-2, levels] + rho_s * dSdx[rows, 2:-2, levels]
            sxb = slope_of(drodxb, drodzb)
            taper = taper_of(sxb)
            sumx += dxu[rows, None, None] * k_top * taper * sxb ** 2 * mask_w
            Ai_bx[2:-2, 2:-2, :-1, ip, kr] = taper * sxb * mask_w

        # northward slopes at the top of T cells
        for jp in (0, 1):
            cols = slice(1 + jp, -3 + jp)
            facty = cosu[cols] * dyu[cols]
            drodyb = rho_t * dTdy[2:-2, cols, levels] + rho_s * dSdy[2:-2, cols, levels]
            syb = slope_of(drodyb, drodzb)
            taper = taper_of(syb)
            sumy += facty[None, :, None] * k_top * taper * syb ** 2 * mask_w
            Ai_by[2:-2, 2:-2, :-1, jp, kr] = taper * syb * mask_w

    K_33[2:-2, 2:-2, :-1] = sumx / (4 * dxt[2:-2, None, None]) + sumy / (
        4 * dyt[None, 2:-2, None] * cost[None, 2:-2, None]
    )
    K_33[2:-2, 2:-2, -1] = 0.0


def run(*inputs, device="cpu"):
    """
    Run the in-place kernel and hand back the last seven positional inputs,
    which are the arrays it writes to. ``device`` is accepted but unused.
    """
    isoneutral_diffusion_pre(*inputs)
    written = inputs[-7:]
    return written
@torch.jit.script
def get_drhodT(salt, temp, p):
    """
    Evaluate the density-vs-temperature derivative expression of this benchmark.

    ``salt`` is accepted but never read. ``temp`` and ``p`` must broadcast
    against each other; the result has their broadcast shape.
    """
    rho0 = 1024.0
    z0 = 0.0
    theta0 = 283.0 - 273.15
    grav = 9.81
    betaT = 1.67e-4
    betaTs = 1e-5
    gammas = 1.1e-8

    zz = -p - z0
    thetas = temp - theta0
    return -(betaTs * thetas + betaT * (1 - gammas * grav * zz * rho0)) * rho0


@torch.jit.script
def get_drhodS(salt, temp, p):
    """
    Return the constant ``0.78e-3 * 1024.0`` broadcast to the shape of ``temp``.

    ``salt`` and ``p`` are accepted but never read.
    """
    betaS = 0.78e-3
    rho0 = 1024.0
    return betaS * rho0 * torch.ones_like(temp)


@torch.jit.script
def dm_taper(sx):
    """
    tapering function for isopycnal slopes
    """
    iso_slopec = 1e-3
    iso_dslope = 1e-3
    return 0.5 * (1.0 + torch.tanh((-torch.abs(sx) + iso_slopec) / iso_dslope))


@torch.jit.script
def isoneutral_diffusion_pre(
    maskT,
    maskU,
    maskV,
    maskW,
    dxt,
    dxu,
    dyt,
    dyu,
    dzt,
    dzw,
    cost,
    cosu,
    salt,
    temp,
    zt,
    K_iso,
    K_11,
    K_22,
    K_33,
    Ai_ez,
    Ai_nz,
    Ai_bx,
    Ai_by,
):
    """
    Isopycnal diffusion for tracer
    following functional formulation by Griffies et al
    Code adopted from MOM2.1

    ``K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by`` are overwritten in place
    inside the index windows used below and are also returned as a tuple.
    Only time level ``tau = 0`` of the 4-D ``salt`` / ``temp`` tensors is read.

    Scalar bounds are applied with ``torch.clamp`` so that no helper tensors
    have to be created (and copied to the device) inside the loops.
    """
    epsln = 1e-20
    K_iso_steep = 50.0
    tau = 0

    dTdx = torch.zeros_like(K_11)
    dSdx = torch.zeros_like(K_11)
    dTdy = torch.zeros_like(K_11)
    dSdy = torch.zeros_like(K_11)
    dTdz = torch.zeros_like(K_11)
    dSdz = torch.zeros_like(K_11)

    # drho_dt and drho_ds at centers of T cells
    drdT = maskT * get_drhodT(salt[:, :, :, tau], temp[:, :, :, tau], torch.abs(zt))
    drdS = maskT * get_drhodS(salt[:, :, :, tau], temp[:, :, :, tau], torch.abs(zt))

    # gradients at top face of T cells (last level stays zero)
    dTdz[:, :, :-1] = (
        maskW[:, :, :-1]
        * (temp[:, :, 1:, tau] - temp[:, :, :-1, tau])
        / dzw[None, None, :-1]
    )
    dSdz[:, :, :-1] = (
        maskW[:, :, :-1]
        * (salt[:, :, 1:, tau] - salt[:, :, :-1, tau])
        / dzw[None, None, :-1]
    )

    # gradients at eastern face of T cells (last row stays zero)
    dTdx[:-1, :, :] = (
        maskU[:-1, :, :]
        * (temp[1:, :, :, tau] - temp[:-1, :, :, tau])
        / (dxu[:-1, None, None] * cost[None, :, None])
    )
    dSdx[:-1, :, :] = (
        maskU[:-1, :, :]
        * (salt[1:, :, :, tau] - salt[:-1, :, :, tau])
        / (dxu[:-1, None, None] * cost[None, :, None])
    )

    # gradients at northern face of T cells (last column stays zero)
    dTdy[:, :-1, :] = (
        maskV[:, :-1, :]
        * (temp[:, 1:, :, tau] - temp[:, :-1, :, tau])
        / dyu[None, :-1, None]
    )
    dSdy[:, :-1, :] = (
        maskV[:, :-1, :]
        * (salt[:, 1:, :, tau] - salt[:, :-1, :, tau])
        / dyu[None, :-1, None]
    )

    # Compute Ai_ez and K11 on center of east face of T cell.
    diffloc = torch.zeros_like(K_11)
    diffloc[1:-2, 2:-2, 1:] = 0.25 * (
        K_iso[1:-2, 2:-2, 1:]
        + K_iso[1:-2, 2:-2, :-1]
        + K_iso[2:-1, 2:-2, 1:]
        + K_iso[2:-1, 2:-2, :-1]
    )
    diffloc[1:-2, 2:-2, 0] = 0.5 * (K_iso[1:-2, 2:-2, 0] + K_iso[2:-1, 2:-2, 0])

    sumz = torch.zeros_like(K_11)[1:-2, 2:-2]
    for kr in range(2):
        # kr == 0 pairs level k with the face k - 1, so level 0 is skipped
        ki = 0 if kr == 1 else 1
        if kr == 1:
            su = K_11.shape[2]
        else:
            su = K_11.shape[2] - 1
        for ip in range(2):
            drodxe = (
                drdT[1 + ip : -2 + ip, 2:-2, ki:] * dTdx[1:-2, 2:-2, ki:]
                + drdS[1 + ip : -2 + ip, 2:-2, ki:] * dSdx[1:-2, 2:-2, ki:]
            )
            drodze = (
                drdT[1 + ip : -2 + ip, 2:-2, ki:] * dTdz[1 + ip : -2 + ip, 2:-2, :su]
                + drdS[1 + ip : -2 + ip, 2:-2, ki:] * dSdz[1 + ip : -2 + ip, 2:-2, :su]
            )
            # denominator is strictly negative: min(drodze, 0) - epsln
            sxe = -drodxe / (torch.clamp(drodze, max=0.0) - epsln)
            taper = dm_taper(sxe)
            sumz[:, :, ki:] += (
                dzw[None, None, :su]
                * maskU[1:-2, 2:-2, ki:]
                * torch.clamp(diffloc[1:-2, 2:-2, ki:] * taper, min=K_iso_steep)
            )
            Ai_ez[1:-2, 2:-2, ki:, ip, kr] = taper * sxe * maskU[1:-2, 2:-2, ki:]
    K_11[1:-2, 2:-2, :] = sumz / (4.0 * dzt[None, None, :])

    # Compute Ai_nz and K_22 on center of north face of T cell.
    diffloc[...] = 0
    diffloc[2:-2, 1:-2, 1:] = 0.25 * (
        K_iso[2:-2, 1:-2, 1:]
        + K_iso[2:-2, 1:-2, :-1]
        + K_iso[2:-2, 2:-1, 1:]
        + K_iso[2:-2, 2:-1, :-1]
    )
    diffloc[2:-2, 1:-2, 0] = 0.5 * (K_iso[2:-2, 1:-2, 0] + K_iso[2:-2, 2:-1, 0])

    sumz = torch.zeros_like(K_11)[2:-2, 1:-2]
    for kr in range(2):
        ki = 0 if kr == 1 else 1
        if kr == 1:
            su = K_11.shape[2]
        else:
            su = K_11.shape[2] - 1
        for jp in range(2):
            drodyn = (
                drdT[2:-2, 1 + jp : -2 + jp, ki:] * dTdy[2:-2, 1:-2, ki:]
                + drdS[2:-2, 1 + jp : -2 + jp, ki:] * dSdy[2:-2, 1:-2, ki:]
            )
            drodzn = (
                drdT[2:-2, 1 + jp : -2 + jp, ki:] * dTdz[2:-2, 1 + jp : -2 + jp, :su]
                + drdS[2:-2, 1 + jp : -2 + jp, ki:] * dSdz[2:-2, 1 + jp : -2 + jp, :su]
            )
            syn = -drodyn / (torch.clamp(drodzn, max=0.0) - epsln)
            taper = dm_taper(syn)
            sumz[:, :, ki:] += (
                dzw[None, None, :su]
                * maskV[2:-2, 1:-2, ki:]
                * torch.clamp(diffloc[2:-2, 1:-2, ki:] * taper, min=K_iso_steep)
            )
            Ai_nz[2:-2, 1:-2, ki:, jp, kr] = taper * syn * maskV[2:-2, 1:-2, ki:]
    K_22[2:-2, 1:-2, :] = sumz / (4.0 * dzt[None, None, :])

    # compute Ai_bx, Ai_by and K33 on top face of T cell.
    sumx = torch.zeros_like(K_11)[2:-2, 2:-2, :-1]
    sumy = torch.zeros_like(K_11)[2:-2, 2:-2, :-1]

    for kr in range(2):
        # kr selects the level below (0) or above (1) the face
        if kr == 1:
            sl = 1
            su = K_11.shape[2]
        else:
            sl = 0
            su = K_11.shape[2] - 1

        drodzb = (
            drdT[2:-2, 2:-2, sl:su] * dTdz[2:-2, 2:-2, :-1]
            + drdS[2:-2, 2:-2, sl:su] * dSdz[2:-2, 2:-2, :-1]
        )
        # shared by the eastward and northward slopes below
        drodzb_capped = torch.clamp(drodzb, max=0.0) - epsln

        # eastward slopes at the top of T cells
        for ip in range(2):
            drodxb = (
                drdT[2:-2, 2:-2, sl:su] * dTdx[1 + ip : -3 + ip, 2:-2, sl:su]
                + drdS[2:-2, 2:-2, sl:su] * dSdx[1 + ip : -3 + ip, 2:-2, sl:su]
            )
            sxb = -drodxb / drodzb_capped
            taper = dm_taper(sxb)
            sumx += (
                dxu[1 + ip : -3 + ip, None, None]
                * K_iso[2:-2, 2:-2, :-1]
                * taper
                * sxb ** 2
                * maskW[2:-2, 2:-2, :-1]
            )
            Ai_bx[2:-2, 2:-2, :-1, ip, kr] = taper * sxb * maskW[2:-2, 2:-2, :-1]

        # northward slopes at the top of T cells
        for jp in range(2):
            facty = cosu[1 + jp : -3 + jp] * dyu[1 + jp : -3 + jp]
            drodyb = (
                drdT[2:-2, 2:-2, sl:su] * dTdy[2:-2, 1 + jp : -3 + jp, sl:su]
                + drdS[2:-2, 2:-2, sl:su] * dSdy[2:-2, 1 + jp : -3 + jp, sl:su]
            )
            syb = -drodyb / drodzb_capped
            taper = dm_taper(syb)
            sumy += (
                facty[None, :, None]
                * K_iso[2:-2, 2:-2, :-1]
                * taper
                * syb ** 2
                * maskW[2:-2, 2:-2, :-1]
            )
            Ai_by[2:-2, 2:-2, :-1, jp, kr] = taper * syb * maskW[2:-2, 2:-2, :-1]

    K_33[2:-2, 2:-2, :-1] = sumx / (4 * dxt[2:-2, None, None]) + sumy / (
        4 * dyt[None, 2:-2, None] * cost[None, 2:-2, None]
    )
    K_33[2:-2, 2:-2, -1] = 0.0

    return K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by


def prepare_inputs(*inputs, device):
    """
    Convert every positional input to a tensor on the requested device.

    ``device == "gpu"`` selects ``"cuda"``; any other value selects ``"cpu"``.
    On GPU the call synchronises so the transfers are complete on return.
    """
    target = "cuda" if device == "gpu" else "cpu"
    out = [torch.as_tensor(a, device=target) for a in inputs]
    if device == "gpu":
        torch.cuda.synchronize()
    return out


def run(*inputs, device="cpu"):
    """
    Run the scripted kernel without autograd tracking and return its outputs.

    On GPU the call synchronises so that queued kernels have finished before
    the function returns.
    """
    with torch.no_grad():
        outputs = isoneutral_diffusion_pre(*inputs)
    if device == "gpu":
        torch.cuda.synchronize()

    return outputs
drdT: ti.template(), 64 | drdS: ti.template(), 65 | dTdx: ti.template(), 66 | dSdx: ti.template(), 67 | dTdy: ti.template(), 68 | dSdy: ti.template(), 69 | dTdz: ti.template(), 70 | dSdz: ti.template(), 71 | ): 72 | """ 73 | Isopycnal diffusion for tracer 74 | following functional formulation by Griffies et al 75 | Code adopted from MOM2.1 76 | """ 77 | nx, ny, nz = maskT.shape 78 | 79 | epsln = 1e-20 80 | K_iso_steep = 50.0 81 | tau = 0 82 | 83 | """ 84 | drho_dt and drho_ds at centers of T cells 85 | """ 86 | for i, j, k in ti.ndrange(nx, ny, nz): 87 | drdT[i, j, k] = maskT[i, j, k] * get_drhodT( 88 | salt[i, j, k, tau], temp[i, j, k, tau], ti.abs(zt[k]) 89 | ) 90 | drdS[i, j, k] = maskT[i, j, k] * get_drhodS( 91 | salt[i, j, k, tau], temp[i, j, k, tau], ti.abs(zt[k]) 92 | ) 93 | 94 | """ 95 | gradients at top face of T cells 96 | """ 97 | for i in range(nx): 98 | for j in range(ny): 99 | for k in range(nz - 1): 100 | dTdz[i, j, k] = ( 101 | maskW[i, j, k] 102 | * (temp[i, j, k + 1, tau] - temp[i, j, k, tau]) 103 | / dzw[k] 104 | ) 105 | dSdz[i, j, k] = ( 106 | maskW[i, j, k] 107 | * (salt[i, j, k + 1, tau] - salt[i, j, k, tau]) 108 | / dzw[k] 109 | ) 110 | dTdz[i, j, -1] = 0.0 111 | dSdz[i, j, -1] = 0.0 112 | 113 | """ 114 | gradients at eastern face of T cells 115 | """ 116 | for i in range(nx - 1): 117 | for j in range(ny): 118 | for k in range(nz): 119 | dTdx[i, j, k] = ( 120 | maskU[i, j, k] 121 | * (temp[i + 1, j, k, tau] - temp[i, j, k, tau]) 122 | / (dxu[i] * cost[j]) 123 | ) 124 | dSdx[i, j, k] = ( 125 | maskU[i, j, k] 126 | * (salt[i + 1, j, k, tau] - salt[i, j, k, tau]) 127 | / (dxu[i] * cost[j]) 128 | ) 129 | 130 | for j in range(ny): 131 | for k in range(nz): 132 | dTdx[-1, j, k] = 0.0 133 | dSdx[-1, j, k] = 0.0 134 | 135 | """ 136 | gradients at northern face of T cells 137 | """ 138 | for i in range(nx): 139 | for j in range(ny - 1): 140 | for k in range(nz): 141 | dTdy[i, j, k] = ( 142 | maskV[i, j, k] 143 | * (temp[i, j + 1, k, tau] - temp[i, j, 
k, tau]) 144 | / dyu[j] 145 | ) 146 | dSdy[i, j, k] = ( 147 | maskV[i, j, k] 148 | * (salt[i, j + 1, k, tau] - salt[i, j, k, tau]) 149 | / dyu[j] 150 | ) 151 | 152 | for i in range(nx): 153 | for k in range(nz): 154 | dTdy[i, -1, k] = 0.0 155 | dSdy[i, -1, k] = 0.0 156 | 157 | """ 158 | Compute Ai_ez and K11 on center of east face of T cell. 159 | """ 160 | diffloc = 0.0 161 | 162 | for i in range(1, nx - 2): 163 | for j in range(2, ny - 2): 164 | for k in range(0, nz): 165 | if k == 0: 166 | diffloc = 0.5 * (K_iso[i, j, k] + K_iso[i + 1, j, k]) 167 | else: 168 | diffloc = 0.25 * ( 169 | K_iso[i, j, k] 170 | + K_iso[i, j, k - 1] 171 | + K_iso[i + 1, j, k] 172 | + K_iso[i + 1, j, k - 1] 173 | ) 174 | 175 | sumz = 0.0 176 | 177 | for kr in range(2): 178 | if not (k == 0 and kr == 0): 179 | for ip in range(2): 180 | drodxe = ( 181 | drdT[i + ip, j, k] * dTdx[i, j, k] 182 | + drdS[i + ip, j, k] * dSdx[i, j, k] 183 | ) 184 | drodze = ( 185 | drdT[i + ip, j, k] * dTdz[i + ip, j, k + kr - 1] 186 | + drdS[i + ip, j, k] * dSdz[i + ip, j, k + kr - 1] 187 | ) 188 | sxe = -drodxe / (ti.min(0.0, drodze) - epsln) 189 | taper = dm_taper(sxe) 190 | sumz += ( 191 | dzw[k + kr - 1] 192 | * maskU[i, j, k] 193 | * ti.max(K_iso_steep, diffloc * taper) 194 | ) 195 | Ai_ez[i, j, k, ip, kr] = taper * sxe * maskU[i, j, k] 196 | 197 | K_11[i, j, k] = sumz / (4.0 * dzt[k]) 198 | 199 | """ 200 | Compute Ai_nz and K_22 on center of north face of T cell. 
201 | """ 202 | for i in range(2, nx - 2): 203 | for j in range(1, ny - 2): 204 | for k in range(0, nz): 205 | if k == 0: 206 | diffloc = 0.5 * (K_iso[i, j, k] + K_iso[i, j + 1, k]) 207 | else: 208 | diffloc = 0.25 * ( 209 | K_iso[i, j, k] 210 | + K_iso[i, j, k - 1] 211 | + K_iso[i, j + 1, k] 212 | + K_iso[i, j + 1, k - 1] 213 | ) 214 | 215 | sumz = 0.0 216 | 217 | for kr in range(2): 218 | if not (k == 0 and kr == 0): 219 | for jp in range(2): 220 | drodyn = ( 221 | drdT[i, j + jp, k] * dTdy[i, j, k] 222 | + drdS[i, j + jp, k] * dSdy[i, j, k] 223 | ) 224 | drodzn = ( 225 | drdT[i, j + jp, k] * dTdz[i, j + jp, k + kr - 1] 226 | + drdS[i, j + jp, k] * dSdz[i, j + jp, k + kr - 1] 227 | ) 228 | syn = -drodyn / (ti.min(0.0, drodzn) - epsln) 229 | taper = dm_taper(syn) 230 | sumz += ( 231 | dzw[k + kr - 1] 232 | * maskV[i, j, k] 233 | * ti.max(K_iso_steep, diffloc * taper) 234 | ) 235 | Ai_nz[i, j, k, jp, kr] = taper * syn * maskV[i, j, k] 236 | 237 | K_22[i, j, k] = sumz / (4.0 * dzt[k]) 238 | 239 | """ 240 | compute Ai_bx, Ai_by and K33 on top face of T cell. 
241 | """ 242 | for i in range(2, nx - 2): 243 | for j in range(2, ny - 2): 244 | for k in range(nz - 1): 245 | sumx = 0.0 246 | sumy = 0.0 247 | 248 | for kr in range(2): 249 | drodzb = ( 250 | drdT[i, j, k + kr] * dTdz[i, j, k] 251 | + drdS[i, j, k + kr] * dSdz[i, j, k] 252 | ) 253 | 254 | # eastward slopes at the top of T cells 255 | for ip in range(2): 256 | drodxb = ( 257 | drdT[i, j, k + kr] * dTdx[i - 1 + ip, j, k + kr] 258 | + drdS[i, j, k + kr] * dSdx[i - 1 + ip, j, k + kr] 259 | ) 260 | sxb = -drodxb / (ti.min(0.0, drodzb) - epsln) 261 | taper = dm_taper(sxb) 262 | sumx += ( 263 | dxu[i - 1 + ip] 264 | * K_iso[i, j, k] 265 | * taper 266 | * sxb ** 2 267 | * maskW[i, j, k] 268 | ) 269 | Ai_bx[i, j, k, ip, kr] = taper * sxb * maskW[i, j, k] 270 | 271 | # northward slopes at the top of T cells 272 | for jp in range(2): 273 | facty = cosu[j - 1 + jp] * dyu[j - 1 + jp] 274 | drodyb = ( 275 | drdT[i, j, k + kr] * dTdy[i, j + jp - 1, k + kr] 276 | + drdS[i, j, k + kr] * dSdy[i, j + jp - 1, k + kr] 277 | ) 278 | syb = -drodyb / (ti.min(0.0, drodzb) - epsln) 279 | taper = dm_taper(syb) 280 | sumy += ( 281 | facty * K_iso[i, j, k] * taper * syb ** 2 * maskW[i, j, k] 282 | ) 283 | Ai_by[i, j, k, jp, kr] = taper * syb * maskW[i, j, k] 284 | 285 | K_33[i, j, k] = sumx / (4 * dxt[i]) + sumy / (4 * dyt[j] * cost[j]) 286 | 287 | K_33[i, j, nz-1] = 0.0 288 | 289 | 290 | @lru_cache 291 | def get_fields(sizes, num_scratch=0): 292 | fields = [] 293 | for size in sizes: 294 | fields.append(ti.field(dtype=ti.f64, shape=size)) 295 | 296 | for _ in range(num_scratch): 297 | fields.append(ti.field(dtype=ti.f64, shape=sizes[0])) 298 | 299 | return fields 300 | 301 | 302 | def prepare_inputs(*inputs, device="cpu"): 303 | field_inputs = get_fields(tuple(inp.shape for inp in inputs), num_scratch=8) 304 | 305 | for inp, inp_orig in zip(field_inputs, inputs): 306 | inp.from_numpy(inp_orig) 307 | 308 | return field_inputs 309 | 310 | 311 | def run(*inputs, device="cpu"): 312 | 
isoneutral_diffusion_pre(*inputs) 313 | out = inputs[16:23] 314 | ti.sync() 315 | return out 316 | -------------------------------------------------------------------------------- /benchmarks/turbulent_kinetic_energy/README.md: -------------------------------------------------------------------------------- 1 | # Turbulent kinetic energy benchmark 2 | 3 | This is a parameterization for turbulence in large-scale ocean models. 4 | 5 | When we model the whole global ocean, every grid cell is orders of magnitude 6 | too large to resolve small-scale turbulence (even in our most costly simulations). 7 | Therefore, we need a *parameterization* that lets us quantify the effect of turbulence 8 | on the large-scale flow without resolving it explicitly. 9 | 10 | This routine consists of some finite difference derivatives, but also some more challenging 11 | operations like boolean mask operations and a tridiagonal matrix solver 12 | *which cannot be fully vectorized*. 13 | -------------------------------------------------------------------------------- /benchmarks/turbulent_kinetic_energy/__init__.py: -------------------------------------------------------------------------------- 1 | import math 2 | import importlib 3 | import functools 4 | 5 | 6 | def generate_inputs(size): 7 | import numpy as np 8 | 9 | np.random.seed(17) 10 | 11 | shape = ( 12 | math.ceil(2 * size ** (1 / 3)), 13 | math.ceil(2 * size ** (1 / 3)), 14 | math.ceil(0.25 * size ** (1 / 3)), 15 | ) 16 | 17 | # masks 18 | maskU, maskV, maskW = ( 19 | (np.random.rand(*shape) < 0.8).astype("float64") for _ in range(3) 20 | ) 21 | 22 | # 1d arrays 23 | dxt, dxu = (np.random.randn(shape[0]) for _ in range(2)) 24 | dyt, dyu = (np.random.randn(shape[1]) for _ in range(2)) 25 | dzt, dzw = (np.random.randn(shape[2]) for _ in range(2)) 26 | cost, cosu = (np.random.randn(shape[1]) for _ in range(2)) 27 | 28 | # 2d arrays 29 | kbot = np.random.randint(0, shape[2], size=shape[:2]) 30 | forc_tke_surface = 
np.random.randn(*shape[:2]) 31 | 32 | # 3d arrays 33 | kappaM, mxl, forc = (np.random.randn(*shape) for _ in range(3)) 34 | 35 | # 4d arrays 36 | u, v, w, tke, dtke = (np.random.randn(*shape, 3) for _ in range(5)) 37 | 38 | return ( 39 | u, 40 | v, 41 | w, 42 | maskU, 43 | maskV, 44 | maskW, 45 | dxt, 46 | dxu, 47 | dyt, 48 | dyu, 49 | dzt, 50 | dzw, 51 | cost, 52 | cosu, 53 | kbot, 54 | kappaM, 55 | mxl, 56 | forc, 57 | forc_tke_surface, 58 | tke, 59 | dtke, 60 | ) 61 | 62 | 63 | def try_import(backend): 64 | try: 65 | return importlib.import_module(f".tke_{backend}", __name__) 66 | except ImportError: 67 | return None 68 | 69 | 70 | def get_callable(backend, size, device="cpu"): 71 | backend_module = try_import(backend) 72 | inputs = generate_inputs(size) 73 | if hasattr(backend_module, "prepare_inputs"): 74 | inputs = backend_module.prepare_inputs(*inputs, device=device) 75 | return functools.partial(backend_module.run, *inputs, device=device) 76 | 77 | 78 | __implementations__ = ( 79 | "jax", 80 | "numba", 81 | "numpy", 82 | "pytorch", 83 | ) 84 | -------------------------------------------------------------------------------- /benchmarks/turbulent_kinetic_energy/tke_jax.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | 4 | 5 | @jax.jit 6 | def where(mask, a, b): 7 | return np.where(mask, a, b) 8 | 9 | 10 | @jax.jit 11 | def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None): 12 | land_mask = (ks >= 0)[:, :, np.newaxis] 13 | edge_mask = land_mask & ( 14 | np.arange(a.shape[2])[np.newaxis, np.newaxis, :] == ks[:, :, np.newaxis] 15 | ) 16 | water_mask = land_mask & ( 17 | np.arange(a.shape[2])[np.newaxis, np.newaxis, :] >= ks[:, :, np.newaxis] 18 | ) 19 | 20 | a_tri = water_mask * a * np.logical_not(edge_mask) 21 | b_tri = where(water_mask, b, 1.0) 22 | if b_edge is not None: 23 | b_tri = where(edge_mask, b_edge, b_tri) 24 | c_tri = water_mask * c 25 | d_tri = water_mask * d 26 
| if d_edge is not None: 27 | d_tri = where(edge_mask, d_edge, d_tri) 28 | 29 | return solve_tridiag(a_tri, b_tri, c_tri, d_tri), water_mask 30 | 31 | 32 | @jax.jit 33 | def solve_tridiag(a, b, c, d): 34 | """ 35 | Solves a tridiagonal matrix system with diagonals a, b, c and RHS vector d. 36 | """ 37 | assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape 38 | 39 | def compute_primes(last_primes, x): 40 | last_cp, last_dp = last_primes 41 | a, b, c, d = x 42 | cp = c / (b - a * last_cp) 43 | dp = (d - a * last_dp) / (b - a * last_cp) 44 | new_primes = np.stack((cp, dp)) 45 | return new_primes, new_primes 46 | 47 | diags_stacked = np.stack([arr.transpose((2, 0, 1)) for arr in (a, b, c, d)], axis=1) 48 | _, primes = jax.lax.scan( 49 | compute_primes, np.zeros((2, *a.shape[:-1])), diags_stacked 50 | ) 51 | 52 | def backsubstitution(last_x, x): 53 | cp, dp = x 54 | new_x = dp - cp * last_x 55 | return new_x, new_x 56 | 57 | _, sol = jax.lax.scan(backsubstitution, np.zeros(a.shape[:-1]), primes[::-1]) 58 | return sol[::-1].transpose((1, 2, 0)) 59 | 60 | 61 | @jax.jit 62 | def _calc_cr(rjp, rj, rjm, vel): 63 | """ 64 | Calculates cr value used in superbee advection scheme 65 | """ 66 | eps = 1e-20 # prevent division by 0 67 | return where(vel > 0.0, rjm, rjp) / where(np.abs(rj) < eps, eps, rj) 68 | 69 | 70 | @jax.jit 71 | def pad_z_edges(arr): 72 | arr_shape = list(arr.shape) 73 | arr_shape[2] += 2 74 | out = np.zeros(arr_shape, arr.dtype) 75 | out = out.at[:, :, 1:-1].set( arr) 76 | return out 77 | 78 | 79 | @jax.jit 80 | def limiter(cr): 81 | return np.maximum(0.0, np.maximum(np.minimum(1.0, 2 * cr), np.minimum(2.0, cr))) 82 | 83 | 84 | def _adv_superbee(vel, var, mask, dx, axis, cost, cosu, dt_tracer): 85 | velfac = 1 86 | if axis == 0: 87 | sm1, s, sp1, sp2 = ( 88 | (slice(1 + n, -2 + n or None), slice(2, -2), slice(None)) 89 | for n in range(-1, 3) 90 | ) 91 | dx = cost[np.newaxis, 2:-2, np.newaxis] * dx[1:-2, np.newaxis, np.newaxis] 92 | elif 
axis == 1: 93 | sm1, s, sp1, sp2 = ( 94 | (slice(2, -2), slice(1 + n, -2 + n or None), slice(None)) 95 | for n in range(-1, 3) 96 | ) 97 | dx = (cost * dx)[np.newaxis, 1:-2, np.newaxis] 98 | velfac = cosu[np.newaxis, 1:-2, np.newaxis] 99 | elif axis == 2: 100 | vel, var, mask = (pad_z_edges(a) for a in (vel, var, mask)) 101 | sm1, s, sp1, sp2 = ( 102 | (slice(2, -2), slice(2, -2), slice(1 + n, -2 + n or None)) 103 | for n in range(-1, 3) 104 | ) 105 | dx = dx[np.newaxis, np.newaxis, :-1] 106 | else: 107 | raise ValueError("axis must be 0, 1, or 2") 108 | uCFL = np.abs(velfac * vel[s] * dt_tracer / dx) 109 | rjp = (var[sp2] - var[sp1]) * mask[sp1] 110 | rj = (var[sp1] - var[s]) * mask[s] 111 | rjm = (var[s] - var[sm1]) * mask[sm1] 112 | cr = limiter(_calc_cr(rjp, rj, rjm, vel[s])) 113 | return ( 114 | velfac * vel[s] * (var[sp1] + var[s]) * 0.5 115 | - np.abs(velfac * vel[s]) * ((1.0 - cr) + uCFL * cr) * rj * 0.5 116 | ) 117 | 118 | 119 | _adv_superbee = jax.jit(_adv_superbee, static_argnums=(4,)) 120 | 121 | 122 | @jax.jit 123 | def adv_flux_superbee_wgrid( 124 | var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt, dzw, cost, cosu, dt_tracer 125 | ): 126 | """ 127 | Calculates advection of a tracer defined on Wgrid 128 | """ 129 | maskUtr = np.zeros_like(maskW) 130 | maskUtr = maskUtr.at[:-1, :, :].set( maskW[1:, :, :] * maskW[:-1, :, :] 131 | ) 132 | 133 | adv_fe = np.zeros_like(maskW) 134 | adv_fe = adv_fe.at[1:-2, 2:-2, :].set( 135 | _adv_superbee(u_wgrid, var, maskUtr, dxt, 0, cost, cosu, dt_tracer), 136 | ) 137 | 138 | maskVtr = np.zeros_like(maskW) 139 | maskVtr = maskVtr.at[:, :-1, :].set( maskW[:, 1:, :] * maskW[:, :-1, :] 140 | ) 141 | adv_fn = np.zeros_like(maskW) 142 | adv_fn = adv_fn.at[2:-2, 1:-2, :].set( 143 | _adv_superbee(v_wgrid, var, maskVtr, dyt, 1, cost, cosu, dt_tracer), 144 | ) 145 | 146 | maskWtr = np.zeros_like(maskW) 147 | maskWtr = maskWtr.at[:, :, :-1].set( maskW[:, :, 1:] * maskW[:, :, :-1] 148 | ) 149 | adv_ft = np.zeros_like(maskW) 150 | 
adv_ft = adv_ft.at[2:-2, 2:-2, :-1].set( 151 | _adv_superbee(w_wgrid, var, maskWtr, dzw, 2, cost, cosu, dt_tracer), 152 | ) 153 | 154 | return adv_fe, adv_fn, adv_ft 155 | 156 | 157 | @jax.jit 158 | def integrate_tke( 159 | u, 160 | v, 161 | w, 162 | maskU, 163 | maskV, 164 | maskW, 165 | dxt, 166 | dxu, 167 | dyt, 168 | dyu, 169 | dzt, 170 | dzw, 171 | cost, 172 | cosu, 173 | kbot, 174 | kappaM, 175 | mxl, 176 | forc, 177 | forc_tke_surface, 178 | tke, 179 | dtke, 180 | ): 181 | tau = 0 182 | taup1 = 1 183 | taum1 = 2 184 | 185 | dt_tracer = 1.0 186 | dt_mom = 1.0 187 | AB_eps = 0.1 188 | alpha_tke = 1.0 189 | c_eps = 0.7 190 | K_h_tke = 2000.0 191 | 192 | flux_east = np.zeros_like(maskU) 193 | flux_north = np.zeros_like(maskU) 194 | flux_top = np.zeros_like(maskU) 195 | 196 | sqrttke = np.sqrt(np.maximum(0.0, tke[:, :, :, tau])) 197 | 198 | """ 199 | integrate Tke equation on W grid with surface flux boundary condition 200 | """ 201 | dt_tke = dt_mom # use momentum time step to prevent spurious oscillations 202 | 203 | """ 204 | vertical mixing and dissipation of TKE 205 | """ 206 | ks = kbot[2:-2, 2:-2] - 1 207 | 208 | a_tri = np.zeros_like(maskU[2:-2, 2:-2]) 209 | b_tri = np.zeros_like(maskU[2:-2, 2:-2]) 210 | c_tri = np.zeros_like(maskU[2:-2, 2:-2]) 211 | d_tri = np.zeros_like(maskU[2:-2, 2:-2]) 212 | delta = np.zeros_like(maskU[2:-2, 2:-2]) 213 | 214 | delta = delta.at[:, :, :-1].set( 215 | dt_tke 216 | / dzt[np.newaxis, np.newaxis, 1:] 217 | * alpha_tke 218 | * 0.5 219 | * (kappaM[2:-2, 2:-2, :-1] + kappaM[2:-2, 2:-2, 1:]), 220 | ) 221 | 222 | a_tri = a_tri.at[:, :, 1:-1].set( 223 | -delta[:, :, :-2] / dzw[np.newaxis, np.newaxis, 1:-1], 224 | ) 225 | a_tri = a_tri.at[:, :, -1].set( -delta[:, :, -2] / (0.5 * dzw[-1]) 226 | ) 227 | 228 | b_tri = b_tri.at[:, :, 1:-1].set( 229 | 1 230 | + (delta[:, :, 1:-1] + delta[:, :, :-2]) / dzw[np.newaxis, np.newaxis, 1:-1] 231 | + dt_tke * c_eps * sqrttke[2:-2, 2:-2, 1:-1] / mxl[2:-2, 2:-2, 1:-1], 232 | ) 233 | b_tri = 
b_tri.at[:, :, -1].set( 234 | 1 235 | + delta[:, :, -2] / (0.5 * dzw[-1]) 236 | + dt_tke * c_eps / mxl[2:-2, 2:-2, -1] * sqrttke[2:-2, 2:-2, -1], 237 | ) 238 | b_tri_edge = ( 239 | 1 240 | + delta / dzw[np.newaxis, np.newaxis, :] 241 | + dt_tke * c_eps / mxl[2:-2, 2:-2, :] * sqrttke[2:-2, 2:-2, :] 242 | ) 243 | 244 | c_tri = c_tri.at[:, :, :-1].set( 245 | -delta[:, :, :-1] / dzw[np.newaxis, np.newaxis, :-1], 246 | ) 247 | 248 | d_tri = tke[2:-2, 2:-2, :, tau] + dt_tke * forc[2:-2, 2:-2, :] 249 | d_tri = d_tri.at[:, :, -1].add( 250 | dt_tke * forc_tke_surface[2:-2, 2:-2] / (0.5 * dzw[-1]), 251 | ) 252 | 253 | sol, water_mask = solve_implicit(ks, a_tri, b_tri, c_tri, d_tri, b_edge=b_tri_edge) 254 | tke = tke.at[2:-2, 2:-2, :, taup1].set( 255 | where(water_mask, sol, tke[2:-2, 2:-2, :, taup1]), 256 | ) 257 | 258 | """ 259 | Add TKE if surface density flux drains TKE in uppermost box 260 | """ 261 | mask = tke[2:-2, 2:-2, -1, taup1] < 0.0 262 | tke_surf_corr = np.zeros_like(maskU[..., -1]) 263 | tke_surf_corr = tke_surf_corr.at[2:-2, 2:-2].set( 264 | where(mask, -tke[2:-2, 2:-2, -1, taup1] * 0.5 * dzw[-1] / dt_tke, 0.0), 265 | ) 266 | 267 | tke = tke.at[2:-2, 2:-2, -1, taup1].set( 268 | np.maximum(0.0, tke[2:-2, 2:-2, -1, taup1]), 269 | ) 270 | 271 | """ 272 | add tendency due to lateral diffusion 273 | """ 274 | flux_east = flux_east.at[:-1, :, :].set( 275 | K_h_tke 276 | * (tke[1:, :, :, tau] - tke[:-1, :, :, tau]) 277 | / (cost[np.newaxis, :, np.newaxis] * dxu[:-1, np.newaxis, np.newaxis]) 278 | * maskU[:-1, :, :], 279 | ) 280 | flux_north = flux_north.at[:, :-1, :].set( 281 | K_h_tke 282 | * (tke[:, 1:, :, tau] - tke[:, :-1, :, tau]) 283 | / dyu[np.newaxis, :-1, np.newaxis] 284 | * maskV[:, :-1, :] 285 | * cosu[np.newaxis, :-1, np.newaxis], 286 | ) 287 | tke = tke.at[2:-2, 2:-2, :, taup1].add( 288 | dt_tke 289 | * maskW[2:-2, 2:-2, :] 290 | * ( 291 | (flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :]) 292 | / (cost[np.newaxis, 2:-2, np.newaxis] * dxt[2:-2, 
np.newaxis, np.newaxis]) 293 | + (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :]) 294 | / (cost[np.newaxis, 2:-2, np.newaxis] * dyt[np.newaxis, 2:-2, np.newaxis]) 295 | ), 296 | ) 297 | 298 | """ 299 | add tendency due to advection 300 | """ 301 | flux_east, flux_north, flux_top = adv_flux_superbee_wgrid( 302 | tke[:, :, :, tau], 303 | u[..., tau], 304 | v[..., tau], 305 | w[..., tau], 306 | maskW, 307 | dxt, 308 | dyt, 309 | dzw, 310 | cost, 311 | cosu, 312 | dt_tracer, 313 | ) 314 | 315 | dtke = dtke.at[2:-2, 2:-2, :, tau].set( 316 | maskW[2:-2, 2:-2, :] 317 | * ( 318 | -(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :]) 319 | / (cost[np.newaxis, 2:-2, np.newaxis] * dxt[2:-2, np.newaxis, np.newaxis]) 320 | - (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :]) 321 | / (cost[np.newaxis, 2:-2, np.newaxis] * dyt[np.newaxis, 2:-2, np.newaxis]) 322 | ), 323 | ) 324 | dtke = dtke.at[:, :, 0, tau].add( -flux_top[:, :, 0] / dzw[0] 325 | ) 326 | dtke = dtke.at[:, :, 1:-1, tau].add( 327 | -(flux_top[:, :, 1:-1] - flux_top[:, :, :-2]) / dzw[1:-1], 328 | ) 329 | dtke = dtke.at[:, :, -1, tau].add( 330 | -(flux_top[:, :, -1] - flux_top[:, :, -2]) / (0.5 * dzw[-1]), 331 | ) 332 | 333 | """ 334 | Adam Bashforth time stepping 335 | """ 336 | tke = tke.at[:, :, :, taup1].add( 337 | dt_tracer 338 | * ((1.5 + AB_eps) * dtke[:, :, :, tau] - (0.5 + AB_eps) * dtke[:, :, :, taum1]), 339 | ) 340 | 341 | return tke, dtke, tke_surf_corr 342 | 343 | 344 | def prepare_inputs(*inputs, device): 345 | out = [np.array(k) for k in inputs] 346 | for o in out: 347 | o.block_until_ready() 348 | return out 349 | 350 | 351 | def run(*inputs, device="cpu"): 352 | outputs = integrate_tke(*inputs) 353 | for o in outputs: 354 | o.block_until_ready() 355 | return outputs 356 | -------------------------------------------------------------------------------- /benchmarks/turbulent_kinetic_energy/tke_numba.py: -------------------------------------------------------------------------------- 1 | 
import numpy as np 2 | import numba as nb 3 | 4 | 5 | @nb.jit(nopython=True, fastmath=True, cache=True) 6 | def solve_tridiag(a, b, c, d): 7 | """ 8 | Solves a tridiagonal matrix system with diagonals a, b, c and RHS vector d. 9 | """ 10 | assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape 11 | 12 | n = a.shape[0] 13 | 14 | for i in range(1, n): 15 | w = a[i] / b[i - 1] 16 | b[i] += -w * c[i - 1] 17 | d[i] += -w * d[i - 1] 18 | 19 | out = np.empty_like(a) 20 | out[-1] = d[-1] / b[-1] 21 | 22 | for i in range(n - 2, -1, -1): 23 | out[i] = (d[i] - c[i] * out[i + 1]) / b[i] 24 | 25 | return out 26 | 27 | 28 | @nb.jit(nopython=True, fastmath=True, cache=True) 29 | def _calc_cr(r_jp, r_j, r_jm, vel): 30 | """ 31 | Calculates cr value used in superbee advection scheme 32 | """ 33 | eps = 1e-20 # prevent division by 0 34 | if abs(r_j) < eps: 35 | fac = eps 36 | else: 37 | fac = r_j 38 | 39 | if vel > 0: 40 | return r_jm / fac 41 | else: 42 | return r_jp / fac 43 | 44 | 45 | @nb.jit(nopython=True, fastmath=True, cache=True) 46 | def limiter(cr): 47 | return max(0.0, max(min(1.0, 2.0 * cr), min(2.0, cr))) 48 | 49 | 50 | @nb.jit(nopython=True, fastmath=True, cache=True) 51 | def adv_flux_superbee_wgrid( 52 | adv_fe, 53 | adv_fn, 54 | adv_ft, 55 | var, 56 | u_wgrid, 57 | v_wgrid, 58 | w_wgrid, 59 | maskW, 60 | dxt, 61 | dyt, 62 | dzw, 63 | cost, 64 | cosu, 65 | dt_tracer, 66 | ): 67 | """ 68 | Calculates advection of a tracer defined on Wgrid 69 | """ 70 | nx, ny, nz = var.shape 71 | 72 | maskUtr = np.zeros_like(maskW) 73 | maskUtr[:-1, :, :] = maskW[1:, :, :] * maskW[:-1, :, :] 74 | 75 | adv_fe[...] 
= 0.0 76 | for i in range(1, nx - 2): 77 | for j in range(2, ny - 2): 78 | for k in range(nz): 79 | vel = u_wgrid[i, j, k] 80 | u_cfl = abs(vel * dt_tracer / (cost[j] * dxt[i])) 81 | r_jp = (var[i + 2, j, k] - var[i + 1, j, k]) * maskUtr[i + 1, j, k] 82 | r_j = (var[i + 1, j, k] - var[i, j, k]) * maskUtr[i, j, k] 83 | r_jm = (var[i, j, k] - var[i - 1, j, k]) * maskUtr[i - 1, j, k] 84 | cr = limiter(_calc_cr(r_jp, r_j, r_jm, vel)) 85 | adv_fe[i, j, k] = ( 86 | vel * (var[i + 1, j, k] + var[i, j, k]) * 0.5 87 | - abs(vel) * ((1.0 - cr) + u_cfl * cr) * r_j * 0.5 88 | ) 89 | 90 | maskVtr = np.zeros_like(maskW) 91 | maskVtr[:, :-1, :] = maskW[:, 1:, :] * maskW[:, :-1, :] 92 | 93 | adv_fn[...] = 0.0 94 | for i in range(2, nx - 2): 95 | for j in range(1, ny - 2): 96 | for k in range(nz): 97 | vel = cosu[j] * v_wgrid[i, j, k] 98 | u_cfl = abs(vel * dt_tracer / (cost[j] * dyt[j])) 99 | r_jp = (var[i, j + 2, k] - var[i, j + 1, k]) * maskVtr[i, j + 1, k] 100 | r_j = (var[i, j + 1, k] - var[i, j, k]) * maskVtr[i, j, k] 101 | r_jm = (var[i, j, k] - var[i, j - 1, k]) * maskVtr[i, j - 1, k] 102 | cr = limiter(_calc_cr(r_jp, r_j, r_jm, v_wgrid[i, j, k])) 103 | adv_fn[i, j, k] = ( 104 | vel * (var[i, j + 1, k] + var[i, j, k]) * 0.5 105 | - abs(vel) * ((1.0 - cr) + u_cfl * cr) * r_j * 0.5 106 | ) 107 | 108 | maskWtr = np.zeros_like(maskW) 109 | maskWtr[:, :, :-1] = maskW[:, :, 1:] * maskW[:, :, :-1] 110 | 111 | adv_ft[...] 
= 0.0 112 | for i in range(2, nx - 2): 113 | for j in range(2, ny - 2): 114 | for k in range(nz - 1): 115 | kp1 = min(nz - 2, k + 1) 116 | kp2 = min(nz - 1, k + 2) 117 | km1 = max(0, k - 1) 118 | 119 | vel = w_wgrid[i, j, k] 120 | u_cfl = abs(vel * dt_tracer / dzw[k]) 121 | r_jp = (var[i, j, kp2] - var[i, j, k + 1]) * maskWtr[i, j, kp1] 122 | r_j = (var[i, j, k + 1] - var[i, j, k]) * maskWtr[i, j, k] 123 | r_jm = (var[i, j, k] - var[i, j, km1]) * maskWtr[i, j, km1] 124 | cr = limiter(_calc_cr(r_jp, r_j, r_jm, vel)) 125 | adv_ft[i, j, k] = ( 126 | vel * (var[i, j, k + 1] + var[i, j, k]) * 0.5 127 | - abs(vel) * ((1.0 - cr) + u_cfl * cr) * r_j * 0.5 128 | ) 129 | 130 | 131 | @nb.jit(nopython=True, boundscheck=False, nogil=True, fastmath=True, cache=True) 132 | def integrate_tke( 133 | u, 134 | v, 135 | w, 136 | maskU, 137 | maskV, 138 | maskW, 139 | dxt, 140 | dxu, 141 | dyt, 142 | dyu, 143 | dzt, 144 | dzw, 145 | cost, 146 | cosu, 147 | kbot, 148 | kappaM, 149 | mxl, 150 | forc, 151 | forc_tke_surface, 152 | tke, 153 | dtke, 154 | ): 155 | nx, ny, nz = maskU.shape 156 | 157 | tau = 0 158 | taup1 = 1 159 | taum1 = 2 160 | 161 | dt_tracer = 1 162 | dt_mom = 1 163 | AB_eps = 0.1 164 | alpha_tke = 1.0 165 | c_eps = 0.7 166 | K_h_tke = 2000.0 167 | 168 | flux_east = np.zeros_like(maskU) 169 | flux_north = np.zeros_like(maskU) 170 | flux_top = np.zeros_like(maskU) 171 | 172 | sqrttke = np.sqrt(np.maximum(0.0, tke[:, :, :, tau])) 173 | 174 | """ 175 | integrate Tke equation on W grid with surface flux boundary condition 176 | """ 177 | dt_tke = dt_mom # use momentum time step to prevent spurious oscillations 178 | 179 | """ 180 | vertical mixing and dissipation of TKE 181 | """ 182 | a_tri = np.empty(nz) 183 | b_tri = np.empty(nz) 184 | c_tri = np.empty(nz) 185 | d_tri = np.empty(nz) 186 | delta = np.empty(nz) 187 | 188 | ke = nz - 1 189 | for i in range(2, nx - 2): 190 | for j in range(2, ny - 2): 191 | ks = kbot[i, j] - 1 192 | if ks < 0: 193 | continue 194 | 195 | for k 
in range(ks, ke): 196 | delta[k] = ( 197 | dt_tke 198 | / dzt[k + 1] 199 | * alpha_tke 200 | * 0.5 201 | * (kappaM[i, j, k] + kappaM[i, j, k + 1]) 202 | ) 203 | delta[ke] = 0.0 204 | 205 | for k in range(ks + 1, ke): 206 | a_tri[k] = -delta[k - 1] / dzw[k] 207 | a_tri[ks] = 0.0 208 | a_tri[ke] = -delta[ke - 1] / (0.5 * dzw[ke]) 209 | 210 | for k in range(ks + 1, ke): 211 | b_tri[k] = ( 212 | 1 213 | + delta[k] / dzw[k] 214 | + delta[k - 1] / dzw[k] 215 | + dt_tke * c_eps * sqrttke[i, j, k] / mxl[i, j, k] 216 | ) 217 | b_tri[ke] = ( 218 | 1 219 | + delta[ke - 1] / (0.5 * dzw[ke]) 220 | + dt_tke * c_eps * sqrttke[i, j, ke] / mxl[i, j, ke] 221 | ) 222 | b_tri[ks] = ( 223 | 1 224 | + delta[ks] / dzw[ks] 225 | + dt_tke * c_eps * sqrttke[i, j, ks] / mxl[i, j, ks] 226 | ) 227 | 228 | for k in range(ks, ke): 229 | c_tri[k] = -delta[k] / dzw[k] 230 | c_tri[ke] = 0.0 231 | 232 | d_tri[ks:] = tke[i, j, ks:, tau] + dt_tke * forc[i, j, ks:] 233 | d_tri[ke] += dt_tke * forc_tke_surface[i, j] / (0.5 * dzw[ke]) 234 | 235 | tke[i, j, ks:, taup1] = solve_tridiag( 236 | a_tri[ks:], b_tri[ks:], c_tri[ks:], d_tri[ks:] 237 | ) 238 | 239 | """ 240 | Add TKE if surface density flux drains TKE in uppermost box 241 | """ 242 | tke_surf_corr = np.zeros((nx, ny)) 243 | for i in range(2, nx - 2): 244 | for j in range(2, ny - 2): 245 | if tke[i, j, -1, taup1] >= 0.0: 246 | continue 247 | tke_surf_corr[i, j] = -tke[i, j, -1, taup1] * (0.5 * dzw[-1]) / dt_tke 248 | tke[i, j, -1, taup1] = 0.0 249 | 250 | """ 251 | add tendency due to lateral diffusion 252 | """ 253 | for i in range(nx - 1): 254 | for j in range(ny): 255 | flux_east[i, j, :] = ( 256 | K_h_tke 257 | * (tke[i + 1, j, :, tau] - tke[i, j, :, tau]) 258 | / (cost[j] * dxu[i]) 259 | * maskU[i, j, :] 260 | ) 261 | flux_east[-1, :, :] = 0.0 262 | 263 | for j in range(ny - 1): 264 | flux_north[:, j, :] = ( 265 | K_h_tke 266 | * (tke[:, j + 1, :, tau] - tke[:, j, :, tau]) 267 | / dyu[j] 268 | * maskV[:, j, :] 269 | * cosu[j] 270 | ) 271 | 
flux_north[:, -1, :] = 0.0 272 | 273 | for i in range(2, nx - 2): 274 | for j in range(2, ny - 2): 275 | tke[i, j, :, taup1] += ( 276 | dt_tke 277 | * maskW[i, j, :] 278 | * ( 279 | (flux_east[i, j, :] - flux_east[i - 1, j, :]) / (cost[j] * dxt[i]) 280 | + (flux_north[i, j, :] - flux_north[i, j - 1, :]) 281 | / (cost[j] * dyt[j]) 282 | ) 283 | ) 284 | 285 | """ 286 | add tendency due to advection 287 | """ 288 | adv_flux_superbee_wgrid( 289 | flux_east, 290 | flux_north, 291 | flux_top, 292 | tke[:, :, :, tau], 293 | u[..., tau], 294 | v[..., tau], 295 | w[..., tau], 296 | maskW, 297 | dxt, 298 | dyt, 299 | dzw, 300 | cost, 301 | cosu, 302 | dt_tracer, 303 | ) 304 | 305 | for i in range(2, nx - 2): 306 | for j in range(2, ny - 2): 307 | dtke[i, j, :, tau] = maskW[i, j, :] * ( 308 | -(flux_east[i, j, :] - flux_east[i - 1, j, :]) / (cost[j] * dxt[i]) 309 | - (flux_north[i, j, :] - flux_north[i, j - 1, :]) / (cost[j] * dyt[j]) 310 | ) 311 | dtke[:, :, 0, tau] += -flux_top[:, :, 0] / dzw[0] 312 | dtke[:, :, 1:-1, tau] += -(flux_top[:, :, 1:-1] - flux_top[:, :, :-2]) / dzw[1:-1] 313 | dtke[:, :, -1, tau] += -(flux_top[:, :, -1] - flux_top[:, :, -2]) / (0.5 * dzw[-1]) 314 | 315 | """ 316 | Adam Bashforth time stepping 317 | """ 318 | tke[:, :, :, taup1] += dt_tracer * ( 319 | (1.5 + AB_eps) * dtke[:, :, :, tau] - (0.5 + AB_eps) * dtke[:, :, :, taum1] 320 | ) 321 | 322 | return tke, dtke, tke_surf_corr 323 | 324 | 325 | def run(*inputs, device="cpu"): 326 | outputs = integrate_tke(*inputs) 327 | return outputs 328 | -------------------------------------------------------------------------------- /benchmarks/turbulent_kinetic_energy/tke_numpy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.linalg import lapack 3 | 4 | 5 | def where(mask, a, b): 6 | return np.where(mask, a, b) 7 | 8 | 9 | def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None): 10 | land_mask = (ks >= 0)[:, :, np.newaxis] 11 | 
edge_mask = land_mask & ( 12 | np.arange(a.shape[2])[np.newaxis, np.newaxis, :] == ks[:, :, np.newaxis] 13 | ) 14 | water_mask = land_mask & ( 15 | np.arange(a.shape[2])[np.newaxis, np.newaxis, :] >= ks[:, :, np.newaxis] 16 | ) 17 | 18 | a_tri = water_mask * a * np.logical_not(edge_mask) 19 | b_tri = where(water_mask, b, 1.0) 20 | if b_edge is not None: 21 | b_tri = where(edge_mask, b_edge, b_tri) 22 | c_tri = water_mask * c 23 | d_tri = water_mask * d 24 | if d_edge is not None: 25 | d_tri = where(edge_mask, d_edge, d_tri) 26 | 27 | return solve_tridiag(a_tri, b_tri, c_tri, d_tri), water_mask 28 | 29 | 30 | def solve_tridiag(a, b, c, d): 31 | """ 32 | Solves a tridiagonal matrix system with diagonals a, b, c and RHS vector d. 33 | """ 34 | assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape 35 | a[..., 0] = c[..., -1] = 0 # remove couplings between slices 36 | return lapack.dgtsv(a.flatten()[1:], b.flatten(), c.flatten()[:-1], d.flatten())[ 37 | 3 38 | ].reshape(a.shape) 39 | 40 | 41 | def _calc_cr(rjp, rj, rjm, vel): 42 | """ 43 | Calculates cr value used in superbee advection scheme 44 | """ 45 | eps = 1e-20 # prevent division by 0 46 | return where(vel > 0.0, rjm, rjp) / where(np.abs(rj) < eps, eps, rj) 47 | 48 | 49 | def pad_z_edges(arr): 50 | arr_shape = list(arr.shape) 51 | arr_shape[2] += 2 52 | out = np.zeros(arr_shape, arr.dtype) 53 | out[:, :, 1:-1] = arr 54 | return out 55 | 56 | 57 | def limiter(cr): 58 | return np.maximum(0.0, np.maximum(np.minimum(1.0, 2 * cr), np.minimum(2.0, cr))) 59 | 60 | 61 | def _adv_superbee(vel, var, mask, dx, axis, cost, cosu, dt_tracer): 62 | velfac = 1 63 | if axis == 0: 64 | sm1, s, sp1, sp2 = ( 65 | (slice(1 + n, -2 + n or None), slice(2, -2), slice(None)) 66 | for n in range(-1, 3) 67 | ) 68 | dx = cost[np.newaxis, 2:-2, np.newaxis] * dx[1:-2, np.newaxis, np.newaxis] 69 | elif axis == 1: 70 | sm1, s, sp1, sp2 = ( 71 | (slice(2, -2), slice(1 + n, -2 + n or None), slice(None)) 72 | for n in range(-1, 
3) 73 | ) 74 | dx = (cost * dx)[np.newaxis, 1:-2, np.newaxis] 75 | velfac = cosu[np.newaxis, 1:-2, np.newaxis] 76 | elif axis == 2: 77 | vel, var, mask = (pad_z_edges(a) for a in (vel, var, mask)) 78 | sm1, s, sp1, sp2 = ( 79 | (slice(2, -2), slice(2, -2), slice(1 + n, -2 + n or None)) 80 | for n in range(-1, 3) 81 | ) 82 | dx = dx[np.newaxis, np.newaxis, :-1] 83 | else: 84 | raise ValueError("axis must be 0, 1, or 2") 85 | uCFL = np.abs(velfac * vel[s] * dt_tracer / dx) 86 | rjp = (var[sp2] - var[sp1]) * mask[sp1] 87 | rj = (var[sp1] - var[s]) * mask[s] 88 | rjm = (var[s] - var[sm1]) * mask[sm1] 89 | cr = limiter(_calc_cr(rjp, rj, rjm, vel[s])) 90 | return ( 91 | velfac * vel[s] * (var[sp1] + var[s]) * 0.5 92 | - np.abs(velfac * vel[s]) * ((1.0 - cr) + uCFL * cr) * rj * 0.5 93 | ) 94 | 95 | 96 | def adv_flux_superbee_wgrid( 97 | adv_fe, 98 | adv_fn, 99 | adv_ft, 100 | var, 101 | u_wgrid, 102 | v_wgrid, 103 | w_wgrid, 104 | maskW, 105 | dxt, 106 | dyt, 107 | dzw, 108 | cost, 109 | cosu, 110 | dt_tracer, 111 | ): 112 | """ 113 | Calculates advection of a tracer defined on Wgrid 114 | """ 115 | maskUtr = np.zeros_like(maskW) 116 | maskUtr[:-1, :, :] = maskW[1:, :, :] * maskW[:-1, :, :] 117 | adv_fe[...] = 0.0 118 | adv_fe[1:-2, 2:-2, :] = _adv_superbee( 119 | u_wgrid, var, maskUtr, dxt, 0, cost, cosu, dt_tracer 120 | ) 121 | 122 | maskVtr = np.zeros_like(maskW) 123 | maskVtr[:, :-1, :] = maskW[:, 1:, :] * maskW[:, :-1, :] 124 | adv_fn[...] = 0.0 125 | adv_fn[2:-2, 1:-2, :] = _adv_superbee( 126 | v_wgrid, var, maskVtr, dyt, 1, cost, cosu, dt_tracer 127 | ) 128 | 129 | maskWtr = np.zeros_like(maskW) 130 | maskWtr[:, :, :-1] = maskW[:, :, 1:] * maskW[:, :, :-1] 131 | adv_ft[...] 
= 0.0 132 | adv_ft[2:-2, 2:-2, :-1] = _adv_superbee( 133 | w_wgrid, var, maskWtr, dzw, 2, cost, cosu, dt_tracer 134 | ) 135 | 136 | 137 | def integrate_tke( 138 | u, 139 | v, 140 | w, 141 | maskU, 142 | maskV, 143 | maskW, 144 | dxt, 145 | dxu, 146 | dyt, 147 | dyu, 148 | dzt, 149 | dzw, 150 | cost, 151 | cosu, 152 | kbot, 153 | kappaM, 154 | mxl, 155 | forc, 156 | forc_tke_surface, 157 | tke, 158 | dtke, 159 | ): 160 | tau = 0 161 | taup1 = 1 162 | taum1 = 2 163 | 164 | dt_tracer = 1 165 | dt_mom = 1 166 | AB_eps = 0.1 167 | alpha_tke = 1.0 168 | c_eps = 0.7 169 | K_h_tke = 2000.0 170 | 171 | flux_east = np.zeros_like(maskU) 172 | flux_north = np.zeros_like(maskU) 173 | flux_top = np.zeros_like(maskU) 174 | 175 | sqrttke = np.sqrt(np.maximum(0.0, tke[:, :, :, tau])) 176 | 177 | """ 178 | integrate Tke equation on W grid with surface flux boundary condition 179 | """ 180 | dt_tke = dt_mom # use momentum time step to prevent spurious oscillations 181 | 182 | """ 183 | vertical mixing and dissipation of TKE 184 | """ 185 | ks = kbot[2:-2, 2:-2] - 1 186 | 187 | a_tri = np.zeros_like(maskU[2:-2, 2:-2]) 188 | b_tri = np.zeros_like(maskU[2:-2, 2:-2]) 189 | c_tri = np.zeros_like(maskU[2:-2, 2:-2]) 190 | d_tri = np.zeros_like(maskU[2:-2, 2:-2]) 191 | delta = np.zeros_like(maskU[2:-2, 2:-2]) 192 | 193 | delta[:, :, :-1] = ( 194 | dt_tke 195 | / dzt[np.newaxis, np.newaxis, 1:] 196 | * alpha_tke 197 | * 0.5 198 | * (kappaM[2:-2, 2:-2, :-1] + kappaM[2:-2, 2:-2, 1:]) 199 | ) 200 | 201 | a_tri[:, :, 1:-1] = -delta[:, :, :-2] / dzw[np.newaxis, np.newaxis, 1:-1] 202 | a_tri[:, :, -1] = -delta[:, :, -2] / (0.5 * dzw[-1]) 203 | 204 | b_tri[:, :, 1:-1] = ( 205 | 1 206 | + (delta[:, :, 1:-1] + delta[:, :, :-2]) / dzw[np.newaxis, np.newaxis, 1:-1] 207 | + dt_tke * c_eps * sqrttke[2:-2, 2:-2, 1:-1] / mxl[2:-2, 2:-2, 1:-1] 208 | ) 209 | b_tri[:, :, -1] = ( 210 | 1 211 | + delta[:, :, -2] / (0.5 * dzw[-1]) 212 | + dt_tke * c_eps / mxl[2:-2, 2:-2, -1] * sqrttke[2:-2, 2:-2, -1] 213 | ) 214 
| b_tri_edge = ( 215 | 1 216 | + delta / dzw[np.newaxis, np.newaxis, :] 217 | + dt_tke * c_eps / mxl[2:-2, 2:-2, :] * sqrttke[2:-2, 2:-2, :] 218 | ) 219 | 220 | c_tri[:, :, :-1] = -delta[:, :, :-1] / dzw[np.newaxis, np.newaxis, :-1] 221 | 222 | d_tri[...] = tke[2:-2, 2:-2, :, tau] + dt_tke * forc[2:-2, 2:-2, :] 223 | d_tri[:, :, -1] += dt_tke * forc_tke_surface[2:-2, 2:-2] / (0.5 * dzw[-1]) 224 | 225 | sol, water_mask = solve_implicit(ks, a_tri, b_tri, c_tri, d_tri, b_edge=b_tri_edge) 226 | tke[2:-2, 2:-2, :, taup1] = where(water_mask, sol, tke[2:-2, 2:-2, :, taup1]) 227 | 228 | """ 229 | Add TKE if surface density flux drains TKE in uppermost box 230 | """ 231 | tke_surf_corr = np.zeros(maskU.shape[:2]) 232 | mask = tke[2:-2, 2:-2, -1, taup1] < 0.0 233 | tke_surf_corr[2:-2, 2:-2] = where( 234 | mask, -tke[2:-2, 2:-2, -1, taup1] * 0.5 * dzw[-1] / dt_tke, 0.0 235 | ) 236 | tke[2:-2, 2:-2, -1, taup1] = np.maximum(0.0, tke[2:-2, 2:-2, -1, taup1]) 237 | 238 | """ 239 | add tendency due to lateral diffusion 240 | """ 241 | flux_east[:-1, :, :] = ( 242 | K_h_tke 243 | * (tke[1:, :, :, tau] - tke[:-1, :, :, tau]) 244 | / (cost[np.newaxis, :, np.newaxis] * dxu[:-1, np.newaxis, np.newaxis]) 245 | * maskU[:-1, :, :] 246 | ) 247 | flux_east[-1, :, :] = 0.0 248 | flux_north[:, :-1, :] = ( 249 | K_h_tke 250 | * (tke[:, 1:, :, tau] - tke[:, :-1, :, tau]) 251 | / dyu[np.newaxis, :-1, np.newaxis] 252 | * maskV[:, :-1, :] 253 | * cosu[np.newaxis, :-1, np.newaxis] 254 | ) 255 | flux_north[:, -1, :] = 0.0 256 | tke[2:-2, 2:-2, :, taup1] += ( 257 | dt_tke 258 | * maskW[2:-2, 2:-2, :] 259 | * ( 260 | (flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :]) 261 | / (cost[np.newaxis, 2:-2, np.newaxis] * dxt[2:-2, np.newaxis, np.newaxis]) 262 | + (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :]) 263 | / (cost[np.newaxis, 2:-2, np.newaxis] * dyt[np.newaxis, 2:-2, np.newaxis]) 264 | ) 265 | ) 266 | 267 | """ 268 | add tendency due to advection 269 | """ 270 | adv_flux_superbee_wgrid( 271 
| flux_east, 272 | flux_north, 273 | flux_top, 274 | tke[:, :, :, tau], 275 | u[..., tau], 276 | v[..., tau], 277 | w[..., tau], 278 | maskW, 279 | dxt, 280 | dyt, 281 | dzw, 282 | cost, 283 | cosu, 284 | dt_tracer, 285 | ) 286 | 287 | dtke[2:-2, 2:-2, :, tau] = maskW[2:-2, 2:-2, :] * ( 288 | -(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :]) 289 | / (cost[np.newaxis, 2:-2, np.newaxis] * dxt[2:-2, np.newaxis, np.newaxis]) 290 | - (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :]) 291 | / (cost[np.newaxis, 2:-2, np.newaxis] * dyt[np.newaxis, 2:-2, np.newaxis]) 292 | ) 293 | dtke[:, :, 0, tau] += -flux_top[:, :, 0] / dzw[0] 294 | dtke[:, :, 1:-1, tau] += -(flux_top[:, :, 1:-1] - flux_top[:, :, :-2]) / dzw[1:-1] 295 | dtke[:, :, -1, tau] += -(flux_top[:, :, -1] - flux_top[:, :, -2]) / (0.5 * dzw[-1]) 296 | 297 | """ 298 | Adam Bashforth time stepping 299 | """ 300 | tke[:, :, :, taup1] += dt_tracer * ( 301 | (1.5 + AB_eps) * dtke[:, :, :, tau] - (0.5 + AB_eps) * dtke[:, :, :, taum1] 302 | ) 303 | 304 | return tke, dtke, tke_surf_corr 305 | 306 | 307 | def run(*inputs, device="cpu"): 308 | outputs = integrate_tke(*inputs) 309 | return outputs 310 | -------------------------------------------------------------------------------- /environment-cpu.yml: -------------------------------------------------------------------------------- 1 | name: pyhpc-bench-cpu 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | dependencies: 6 | - python>3.6,<3.9 7 | - pip 8 | - pip: 9 | - click 10 | - aesara 11 | - numba 12 | - taichi 13 | - torch>=1.10 14 | - tensorflow>=2.0 15 | - jax[cpu] 16 | -------------------------------------------------------------------------------- /environment-gpu.yml: -------------------------------------------------------------------------------- 1 | name: pyhpc-bench-gpu 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | dependencies: 6 | - python>3.6,<3.9 7 | - pip 8 | - cudnn>=8.2 9 | - cudatoolkit>=11.1 10 | - pip: 11 | - click 12 | - aesara 13 
| - cupy 14 | - taichi 15 | - numba 16 | - torch>=1.10 17 | - tensorflow>=2.0 18 | - -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 19 | - jax[cuda] 20 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import re 5 | from collections import defaultdict 6 | 7 | import click 8 | import matplotlib 9 | 10 | matplotlib.use("Agg") 11 | 12 | # stupid regex matching ahead 13 | RE_RESULT = re.compile( 14 | r"".join( 15 | [ 16 | r"\s*", 17 | r"(?P<size>(?:\d|,)+)\s*", 18 | r"(?P<backend>\w+)\s*", 19 | *( 20 | rf"(?P<{name}>(?:\d|\.|,)+)\s*" 21 | for name in ( 22 | "calls", 23 | "mean", 24 | "stdev", 25 | "min", 26 | "p25", 27 | "median", 28 | "p75", 29 | "max", 30 | "delta", 31 | ) 32 | ), 33 | ] 34 | ) 35 | ) 36 | RE_BENCHMARK = re.compile(r"benchmarks\.(?P<name>\w+)") 37 | RE_PLATFORM = re.compile(r"Running on (?P<platform>\w+)") 38 | 39 | BACKEND_COLORS = { 40 | "numpy": "C0", 41 | "aesara": "C1", 42 | "cupy": "C2", 43 | "jax": "C3", 44 | "numba": "C4", 45 | "pytorch": "C5", 46 | "tensorflow": "C6", 47 | } 48 | 49 | 50 | def plot_results(records, benchmark, platform, outfile, plot_delta=False): 51 | import matplotlib.pyplot as plt 52 | 53 | fig, ax = plt.subplots(1, 1, figsize=(5.5, 4), dpi=75) 54 | 55 | this_record = records[(benchmark, platform)] 56 | last_coords = {} 57 | 58 | for backend, backend_values in this_record.items(): 59 | x = backend_values["size"] 60 | if plot_delta: 61 | y = backend_values["delta"] 62 | ylabel = "Relative speedup" 63 | else: 64 | y = backend_values["mean"] 65 | ylabel = "Mean runtime (s)" 66 | 67 | x, y = zip(*sorted(zip(x, y), key=lambda ix: ix[0])) 68 | 69 | plt.plot(x, y, "o--", label=backend, color=BACKEND_COLORS[backend]) 70 | last_coords[backend] = (x[-1], y[-1]) 71 | 72 | ax.spines["right"].set_visible(False) 73 | ax.spines["top"].set_visible(False) 74 | 75 | 
plt.xlabel("Problem size (# elements)") 76 | plt.ylabel(ylabel) 77 | 78 | plt.xscale("log") 79 | plt.yscale("log") 80 | 81 | plt.title(f'Benchmark "{benchmark}" on {platform.upper()}') 82 | fig.canvas.draw() 83 | 84 | # add annotations, make sure they don't overlap 85 | last_text_pos = 0 86 | for backend, (x, y) in sorted(last_coords.items(), key=lambda k: k[1][1]): 87 | trans = ax.transData 88 | _, tp = trans.transform((0, y)) 89 | tp = max(tp, last_text_pos + 15) 90 | _, y = trans.inverted().transform((0, tp)) 91 | 92 | plt.annotate( 93 | backend, 94 | (x, y), 95 | xytext=(10, 0), 96 | textcoords="offset points", 97 | annotation_clip=False, 98 | color=BACKEND_COLORS[backend], 99 | va="center", 100 | ) 101 | 102 | last_text_pos = tp 103 | 104 | fig.tight_layout() 105 | fig.savefig(outfile) 106 | plt.close(fig) 107 | 108 | 109 | def _parse_int(string): 110 | return int(string.replace(",", "_")) 111 | 112 | 113 | @click.command("plot") 114 | @click.argument("INFILE", type=click.File("r")) 115 | @click.option( 116 | "-o", 117 | "--outdir", 118 | required=True, 119 | type=click.Path(file_okay=False, writable=True), 120 | help="Output directory for plots", 121 | ) 122 | @click.option( 123 | "--plot-delta", 124 | is_flag=True, 125 | help="Plot relative speedup instead of absolute runtime", 126 | ) 127 | def main(infile, outdir, plot_delta): 128 | """Read a benchmark report from file or stdin and plot the results 129 | 130 | Example: 131 | 132 | $ python run.py benchmarks/equation_of_state > bench.txt 133 | 134 | $ python plot.py bench.txt -o plots 135 | 136 | """ 137 | records = {} 138 | 139 | for line in infile: 140 | bench_match = RE_BENCHMARK.match(line) 141 | if bench_match: 142 | current_benchmark = bench_match.group("name") 143 | continue 144 | 145 | platform_match = RE_PLATFORM.match(line) 146 | if platform_match: 147 | current_platform = platform_match.group("platform") 148 | continue 149 | 150 | result_match = RE_RESULT.match(line) 151 | if not result_match: 
152 | continue 153 | 154 | result_line = result_match.groupdict() 155 | backend = result_line["backend"] 156 | 157 | key = (current_benchmark, current_platform) 158 | if key not in records: 159 | records[key] = {} 160 | 161 | if backend not in records[key]: 162 | records[key][backend] = defaultdict(list) 163 | 164 | record = records[key][backend] 165 | 166 | if _parse_int(result_line["size"]) in record["size"]: 167 | click.echo( 168 | f"Warning: duplicate entry for benchmark {current_benchmark} " 169 | f'on {current_platform}, backend {backend}, size {result_line["size"]} ' 170 | "- skipping" 171 | ) 172 | continue 173 | 174 | for rkey, rval in result_line.items(): 175 | if rkey in ("calls", "size"): 176 | rval = _parse_int(rval) 177 | elif rkey in ( 178 | "mean", 179 | "stdev", 180 | "min", 181 | "p25", 182 | "median", 183 | "p75", 184 | "max", 185 | "delta", 186 | ): 187 | rval = float(rval) 188 | 189 | record[rkey].append(rval) 190 | 191 | os.makedirs(outdir, exist_ok=True) 192 | 193 | for benchmark, platform in records.keys(): 194 | outfile = os.path.join(outdir, f"bench-{benchmark}-{platform}.png") 195 | plot_results(records, benchmark, platform, outfile, plot_delta) 196 | click.echo(f"Wrote {outfile}") 197 | 198 | 199 | if __name__ == "__main__": 200 | main() 201 | -------------------------------------------------------------------------------- /results/aws-plots/bench-equation_of_state-CPU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/aws-plots/bench-equation_of_state-CPU.png -------------------------------------------------------------------------------- /results/aws-plots/bench-equation_of_state-GPU.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/aws-plots/bench-equation_of_state-GPU.png -------------------------------------------------------------------------------- /results/aws-plots/bench-isoneutral_mixing-CPU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/aws-plots/bench-isoneutral_mixing-CPU.png -------------------------------------------------------------------------------- /results/aws-plots/bench-isoneutral_mixing-GPU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/aws-plots/bench-isoneutral_mixing-GPU.png -------------------------------------------------------------------------------- /results/aws-plots/bench-turbulent_kinetic_energy-CPU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/aws-plots/bench-turbulent_kinetic_energy-CPU.png -------------------------------------------------------------------------------- /results/aws-plots/bench-turbulent_kinetic_energy-GPU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/aws-plots/bench-turbulent_kinetic_energy-GPU.png -------------------------------------------------------------------------------- /results/magni-plots/bench-equation_of_state-CPU.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/magni-plots/bench-equation_of_state-CPU.png -------------------------------------------------------------------------------- /results/magni-plots/bench-equation_of_state-GPU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/magni-plots/bench-equation_of_state-GPU.png -------------------------------------------------------------------------------- /results/magni-plots/bench-isoneutral_mixing-CPU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/magni-plots/bench-isoneutral_mixing-CPU.png -------------------------------------------------------------------------------- /results/magni-plots/bench-isoneutral_mixing-GPU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/magni-plots/bench-isoneutral_mixing-GPU.png -------------------------------------------------------------------------------- /results/magni-plots/bench-turbulent_kinetic_energy-CPU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/magni-plots/bench-turbulent_kinetic_energy-CPU.png -------------------------------------------------------------------------------- /results/magni-plots/bench-turbulent_kinetic_energy-GPU.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dionhaefner/pyhpc-benchmarks/d438ef8b76ecb51a019b105a6e06150e5a35c177/results/magni-plots/bench-turbulent_kinetic_energy-GPU.png -------------------------------------------------------------------------------- /results/magni-run-all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | 3 | set -e 4 | 5 | ml load nvtoolkit 6 | conda activate pyhpc-bench-gpu 7 | conda list 8 | 9 | export XLA_FLAGS="--xla_gpu_cuda_data_dir=/groups/ocean/software/software_gcc2020/nvtoolkit/11.2.2" 10 | 11 | cd `git rev-parse --show-toplevel` 12 | 13 | CUDA_VISIBLE_DEVICES="" taskset -c 23 python run.py benchmarks/equation_of_state/ --device cpu 14 | CUDA_VISIBLE_DEVICES="" taskset -c 23 python run.py benchmarks/equation_of_state/ --device cpu -s 16777216 15 | for backend in cupy jax pytorch tensorflow; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/equation_of_state/ --device gpu -b $backend -b numpy; done 16 | 17 | CUDA_VISIBLE_DEVICES="" taskset -c 23 python run.py benchmarks/isoneutral_mixing/ --device cpu 18 | CUDA_VISIBLE_DEVICES="" taskset -c 23 python run.py benchmarks/isoneutral_mixing/ --device cpu -s 16777216 19 | for backend in cupy jax pytorch; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/isoneutral_mixing/ --device gpu -b $backend -b numpy; done 20 | 21 | CUDA_VISIBLE_DEVICES="" taskset -c 23 python run.py benchmarks/turbulent_kinetic_energy/ --device cpu 22 | CUDA_VISIBLE_DEVICES="" taskset -c 23 python run.py benchmarks/turbulent_kinetic_energy/ --device cpu -s 16777216 23 | for backend in jax pytorch; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/turbulent_kinetic_energy/ --device gpu -b $backend -b numpy; done 24 | -------------------------------------------------------------------------------- /results/pyhpc_benchmarks_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 
4 | "metadata": { 5 | "colab": { 6 | "name": "pyhpc-benchmarks-colab.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "TPU" 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "A_5C1IvdfxkF" 21 | }, 22 | "source": [ 23 | "#pyhpc-benchmarks @ Google Colab\n", 24 | "\n", 25 | "To run all benchmarks, you need to switch the runtime type to match the corresponding section (CPU, TPU, GPU)." 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "id": "TTViNK-9OfRJ" 32 | }, 33 | "source": [ 34 | "!rm -rf pyhpc-benchmarks; git clone https://github.com/dionhaefner/pyhpc-benchmarks.git" 35 | ], 36 | "execution_count": null, 37 | "outputs": [] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "metadata": { 42 | "id": "Eyc45XkjQB1X" 43 | }, 44 | "source": [ 45 | "%cd pyhpc-benchmarks" 46 | ], 47 | "execution_count": null, 48 | "outputs": [] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "metadata": { 53 | "id": "RbM7XH04MwFA" 54 | }, 55 | "source": [ 56 | "# check CPU model\n", 57 | "!lscpu |grep 'Model name'" 58 | ], 59 | "execution_count": null, 60 | "outputs": [] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": { 65 | "id": "cK3jm6V_P4pB" 66 | }, 67 | "source": [ 68 | "## CPU" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "metadata": { 74 | "id": "exG5HvsIQtyE" 75 | }, 76 | "source": [ 77 | "!pip install -U -q numba aesara" 78 | ], 79 | "execution_count": null, 80 | "outputs": [] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "metadata": { 85 | "id": "tD19gJ_-QAiZ" 86 | }, 87 | "source": [ 88 | "!taskset -c 0 python run.py benchmarks/equation_of_state/" 89 | ], 90 | "execution_count": null, 91 | "outputs": [] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "metadata": { 96 | "id": "NYykl19BWfQI" 97 | }, 98 | "source": [ 99 | "!taskset -c 0 python run.py benchmarks/isoneutral_mixing/" 100 
| ], 101 | "execution_count": null, 102 | "outputs": [] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "metadata": { 107 | "id": "zf2RaRlPXpM6" 108 | }, 109 | "source": [ 110 | "!taskset -c 0 python run.py benchmarks/turbulent_kinetic_energy/" 111 | ], 112 | "execution_count": null, 113 | "outputs": [] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": { 118 | "id": "oOIzGKsPP0ui" 119 | }, 120 | "source": [ 121 | "## TPU\n", 122 | "\n", 123 | "Make sure to set accelerator to \"TPU\" before executing this." 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "metadata": { 129 | "id": "JHOxWXecn3kx" 130 | }, 131 | "source": [ 132 | "import jax.tools.colab_tpu\n", 133 | "jax.tools.colab_tpu.setup_tpu()" 134 | ], 135 | "execution_count": null, 136 | "outputs": [] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "metadata": { 141 | "id": "9-tQlcfZOzm0" 142 | }, 143 | "source": [ 144 | "!python run.py benchmarks/equation_of_state -b jax -b numpy --device tpu" 145 | ], 146 | "execution_count": null, 147 | "outputs": [] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "metadata": { 152 | "id": "rKcTVbiVPXFu" 153 | }, 154 | "source": [ 155 | "!python run.py benchmarks/isoneutral_mixing -b jax -b numpy --device tpu" 156 | ], 157 | "execution_count": null, 158 | "outputs": [] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "metadata": { 163 | "id": "gfIlTgZol9OA" 164 | }, 165 | "source": [ 166 | "!python run.py benchmarks/turbulent_kinetic_energy -b jax -b numpy --device tpu" 167 | ], 168 | "execution_count": null, 169 | "outputs": [] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": { 174 | "id": "RDoapE1YPrpN" 175 | }, 176 | "source": [ 177 | "## GPU\n", 178 | "\n", 179 | "Make sure to set accelerator to \"GPU\" before executing this." 
180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "metadata": { 185 | "id": "b4CQKseuMnzE" 186 | }, 187 | "source": [ 188 | "# get GPU model\n", 189 | "!nvidia-smi -L" 190 | ], 191 | "execution_count": null, 192 | "outputs": [] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "metadata": { 197 | "id": "Azo78zrdo88Y" 198 | }, 199 | "source": [ 200 | "!for backend in jax tensorflow pytorch cupy; do python run.py benchmarks/equation_of_state/ --device gpu -b $backend -b numpy; done" 201 | ], 202 | "execution_count": null, 203 | "outputs": [] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "metadata": { 208 | "id": "Ps8zEacsPWQW" 209 | }, 210 | "source": [ 211 | "!for backend in jax pytorch cupy; do python run.py benchmarks/isoneutral_mixing/ --device gpu -b $backend -b numpy; done" 212 | ], 213 | "execution_count": null, 214 | "outputs": [] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "metadata": { 219 | "id": "ogXoFxFAd0KI" 220 | }, 221 | "source": [ 222 | "!for backend in jax pytorch; do python run.py benchmarks/turbulent_kinetic_energy/ --device gpu -b $backend -b numpy; done" 223 | ], 224 | "execution_count": null, 225 | "outputs": [] 226 | } 227 | ] 228 | } -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import random 4 | import itertools 5 | 6 | import click 7 | 8 | from backends import ( 9 | __backends__ as setup_functions, 10 | check_backend_conflicts, 11 | convert_to_numpy, 12 | BackendNotSupported, 13 | BackendConflict, 14 | ) 15 | from utilities import ( 16 | Timer, 17 | estimate_repetitions, 18 | format_output, 19 | compute_statistics, 20 | get_benchmark_module, 21 | check_consistency, 22 | ) 23 | 24 | 25 | DEFAULT_SIZE = tuple(2 ** i for i in range(12, 23, 2)) 26 | 27 | 28 | @click.command("run") 29 | @click.argument( 30 | "BENCHMARK", 31 | required=True, 32 | 
type=click.Path(exists=True, file_okay=False, readable=True), 33 | ) 34 | @click.option( 35 | "-s", 36 | "--size", 37 | required=False, 38 | multiple=True, 39 | default=DEFAULT_SIZE, 40 | show_default=True, 41 | type=click.INT, 42 | help="Run benchmark for this array size (repeatable)", 43 | ) 44 | @click.option( 45 | "-b", 46 | "--backend", 47 | required=False, 48 | multiple=True, 49 | default=None, 50 | type=click.Choice(setup_functions.keys()), 51 | help="Run benchmark with this backend (repeatable) [default: run all backends]", 52 | ) 53 | @click.option( 54 | "-r", 55 | "--repetitions", 56 | required=False, 57 | default=None, 58 | type=click.INT, 59 | help="Fixed number of iterations to run for each size and backend [default: auto-detect]", 60 | ) 61 | @click.option( 62 | "--burnin", 63 | required=False, 64 | default=1, 65 | type=click.INT, 66 | show_default=True, 67 | help="Number of initial iterations that are disregarded for final statistics", 68 | ) 69 | @click.option( 70 | "--device", 71 | required=False, 72 | default="cpu", 73 | type=click.Choice(["cpu", "gpu", "tpu"]), 74 | show_default=True, 75 | help="Run benchmarks on given device where supported by the backend", 76 | ) 77 | def main(benchmark, size=None, backend=None, repetitions=None, burnin=1, device="cpu"): 78 | """HPC benchmarks for Python 79 | 80 | Usage: 81 | 82 | $ python run.py benchmarks/ 83 | 84 | Examples: 85 | 86 | $ taskset -c 0 python run.py benchmarks/equation_of_state 87 | 88 | $ python run.py benchmarks/equation_of_state -b numpy -b jax --device gpu 89 | 90 | More information: 91 | 92 | https://github.com/dionhaefner/pyhpc-benchmarks 93 | 94 | """ 95 | try: 96 | bm_module, bm_identifier = get_benchmark_module(benchmark) 97 | except ImportError as e: 98 | click.echo(f"Error while loading benchmark {benchmark}: {e!s}", err=True) 99 | raise click.Abort() 100 | 101 | available_backends = set(bm_module.__implementations__) 102 | 103 | if len(backend) == 0: 104 | backend = 
available_backends.copy() 105 | else: 106 | backend = set(backend) 107 | 108 | unsupported_backends = [b for b in backend if b not in available_backends] 109 | 110 | for b in unsupported_backends: 111 | click.echo( 112 | f'Backend "{b}" is not supported by chosen benchmark (skipping)', err=True 113 | ) 114 | backend.remove(b) 115 | 116 | for b in backend.copy(): 117 | try: 118 | with setup_functions[b](device=device) as bmod: 119 | click.echo(f"Using {b} version {bmod.__version__}") 120 | except BackendNotSupported as e: 121 | click.echo( 122 | f'Setup for backend "{b}" failed (skipping), reason: {e!s}', err=True 123 | ) 124 | backend.remove(b) 125 | 126 | try: 127 | check_backend_conflicts(backend, device) 128 | except BackendConflict as exc: 129 | click.echo(f"Backend conflict: {exc!s}", err=True) 130 | raise click.Abort() 131 | 132 | runs = sorted(itertools.product(backend, size)) 133 | 134 | if len(runs) == 0: 135 | click.echo("Nothing to do") 136 | return 137 | 138 | timings = {run: [] for run in runs} 139 | 140 | if repetitions is None: 141 | click.echo("Estimating repetitions...") 142 | repetitions = {} 143 | 144 | for b, s in runs: 145 | # use end-to-end runtime for repetition estimation 146 | with setup_functions[b](device=device): 147 | run = bm_module.get_callable(b, s, device=device) 148 | repetitions[(b, s)] = estimate_repetitions(run) 149 | else: 150 | repetitions = {(b, s): repetitions for b, s in runs} 151 | 152 | all_runs = list( 153 | itertools.chain.from_iterable( 154 | [run] * (repetitions[run] + burnin) for run in runs 155 | ) 156 | ) 157 | random.shuffle(all_runs) 158 | 159 | results = {} 160 | checked = {r: False for r in runs} 161 | 162 | pbar = click.progressbar( 163 | label=f"Running {len(all_runs)} benchmarks...", length=len(runs) 164 | ) 165 | 166 | try: 167 | with pbar: 168 | for (b, size) in all_runs: 169 | with setup_functions[b](device=device): 170 | run = bm_module.get_callable(b, size, device=device) 171 | with Timer() as t: 172 | 
res = run() 173 | 174 | # YOWO (you only warn once) 175 | if not checked[(b, size)]: 176 | if size in results: 177 | is_consistent = check_consistency( 178 | results[size], convert_to_numpy(res, b, device) 179 | ) 180 | if not is_consistent: 181 | click.echo( 182 | f"\nWarning: inconsistent results for size {size}", 183 | err=True, 184 | ) 185 | else: 186 | results[size] = convert_to_numpy(res, b, device) 187 | checked[(b, size)] = True 188 | 189 | timings[(b, size)].append(t.elapsed) 190 | pbar.update(1.0 / (repetitions[(b, size)] + burnin)) 191 | 192 | # push pbar to 100% 193 | pbar.update(1.0) 194 | 195 | for run in runs: 196 | assert len(timings[run]) == repetitions[run] + burnin 197 | 198 | finally: 199 | stats = compute_statistics(timings) 200 | click.echo(format_output(stats, bm_identifier, device=device)) 201 | 202 | 203 | if __name__ == "__main__": 204 | main() 205 | -------------------------------------------------------------------------------- /utilities.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import collections 4 | import importlib 5 | import os 6 | 7 | import numpy as np 8 | 9 | 10 | class Timer: 11 | def __init__(self): 12 | self.elapsed = float("nan") 13 | 14 | def __enter__(self): 15 | self._start = time.perf_counter() 16 | return self 17 | 18 | def __exit__(self, type, value, traceback): 19 | if value is None: 20 | self.elapsed = time.perf_counter() - self._start 21 | 22 | 23 | def estimate_repetitions(func, args=(), target_time=10, powers_of=10): 24 | # call function once for warm-up 25 | func(*args) 26 | 27 | # some backends need an extra nudge (looking at you, PyTorch) 28 | func(*args) 29 | 30 | # call again and measure time 31 | with Timer() as t: 32 | func(*args) 33 | 34 | time_per_rep = t.elapsed 35 | exponent = math.log(target_time / time_per_rep, powers_of) 36 | num_reps = int(powers_of ** round(exponent)) 37 | return max(powers_of, num_reps) 38 | 39 | 40 | 
def compute_statistics(timings, burnin=1):
    """Summarize raw timings into a structured array with one row per run.

    Args:
        timings: Mapping from ``(backend, size)`` to a sequence of elapsed
            times, one entry per call.
        burnin: Number of leading measurements discarded for every run.

    Returns:
        A NumPy structured array with the fields ``size``, ``backend``,
        ``calls``, ``mean``, ``stdev``, ``min``, ``25%``, ``median``, ``75%``,
        ``max`` and ``Δ``. Runs that have no measurements left after the
        burn-in get NaN for every statistic.
    """
    stats = []

    for (backend, size), measured in timings.items():
        measured = measured[burnin:]
        repetitions = len(measured)

        if repetitions:
            mean = np.mean(measured)
            stdev = np.std(measured)
            percentiles = np.percentile(measured, [0, 25, 50, 75, 100])
        else:
            mean = stdev = float("nan")
            percentiles = [float("nan")] * 5

        # the trailing NaN is a placeholder for Δ, which is filled in below
        stats.append(
            (size, backend, repetitions, mean, stdev, *percentiles, float("nan"))
        )

    stats = np.array(
        stats,
        dtype=[
            ("size", "i8"),
            ("backend", object),
            ("calls", "i8"),
            ("mean", "f4"),
            ("stdev", "f4"),
            ("min", "f4"),
            ("25%", "f4"),
            ("median", "f4"),
            ("75%", "f4"),
            ("max", "f4"),
            ("Δ", "f4"),
        ],
    )

    # add deltas
    for s in np.unique(stats["size"]):
        mask = stats["size"] == s

        # measure relative to NumPy if present, otherwise worst backend
        if "numpy" in stats["backend"][mask]:
            reference_time = stats["mean"][mask & (stats["backend"] == "numpy")]
        else:
            reference_time = np.nanmax(stats["mean"][mask])

        stats["Δ"][mask] = reference_time / stats["mean"][mask]

    return stats


def format_output(stats, benchmark_title, device="cpu"):
    """Render the statistics array as a plain-text table.

    Args:
        stats: Structured array as returned by ``compute_statistics``.
        benchmark_title: Heading printed above the table.
        device: Device name; printed in upper case below the heading.

    Returns:
        The report as a single newline-joined string. Rows are sorted by
        size, then mean, max and median, with an empty line between sizes.
    """
    stats = np.sort(stats, axis=0, order=["size", "mean", "max", "median"])

    header = stats.dtype.names
    col_widths = collections.defaultdict(lambda: 8)
    col_widths.update(size=12, backend=10)

    def format_col(col_name, value, is_time=False):
        col_width = col_widths[col_name]

        # integers get thousands separators, everything else three decimals
        if np.issubdtype(type(value), np.integer):
            typecode = ","
        else:
            typecode = ".3f"

        if is_time:
            # numeric cell: right-aligned
            format_string = f"{{value:>{col_width}{typecode}}}"
        else:
            # textual cell (header or backend name): left-aligned
            format_string = f"{{value:<{col_width}}}"

        return format_string.format(value=value)

    out = [
        "",
        benchmark_title,
        "=" * len(benchmark_title),
        f"Running on {device.upper()}",
        "",
        "  ".join(format_col(s, s) for s in header),
    ]

    # rule below the header, as wide as the header line itself
    out.append("-" * len(out[-1]))

    current_size = None
    for row in stats:
        # print empty line on size change
        size = row[0]
        if current_size is not None and size != current_size:
            out.append("")
        current_size = size

        out.append(
            "  ".join(
                format_col(n, s, not isinstance(s, str)) for n, s in zip(header, row)
            )
        )

    out.extend(
        [
            "",
            "(time in wall seconds, less is better)",
        ]
    )

    return "\n".join(out)


def get_benchmark_module(file_path):
    """Import the benchmark package found at ``file_path``.

    The path is taken relative to the directory containing this file and
    converted to a dotted module path, which is then imported.

    Args:
        file_path: Path to the benchmark package, located below the
            directory of this file.

    Returns:
        Tuple ``(module, import_path)`` with the imported module and the
        dotted path used to import it.
    """
    base_path = os.path.dirname(os.path.abspath(__file__))
    module_path = os.path.relpath(file_path, base_path)
    # os.path.split() only separates the last component, which breaks for
    # anything but exactly two path levels; split on every separator instead
    import_path = ".".join(module_path.split(os.sep))
    bm_module = importlib.import_module(import_path)
    return bm_module, import_path


def check_consistency(res1, res2):
    """Check whether two benchmark results agree.

    Tuples and lists are compared element by element (recursively); arrays
    are compared with ``np.allclose``.

    Args:
        res1: NumPy array, or a (nested) tuple/list of NumPy arrays.
        res2: Result to compare against, with the same structure as ``res1``.

    Returns:
        ``True`` if both results have the same structure and shapes and all
        values are close, ``False`` otherwise.
    """
    if isinstance(res1, (tuple, list)):
        if not len(res1) == len(res2):
            return False

        return all(check_consistency(r1, r2) for r1, r2 in zip(res1, res2))

    assert isinstance(res1, np.ndarray)
    assert isinstance(res2, np.ndarray)

    # np.allclose broadcasts its arguments: mismatched shapes would either
    # raise or be silently accepted, so treat them as inconsistent up front
    if res1.shape != res2.shape:
        return False

    return np.allclose(res1, res2)