├── .gitignore ├── LICENSE.txt ├── README.md ├── ckpt.py ├── demo.gif ├── main.py ├── plot ├── __init__.py ├── animation.py ├── navierstokes.py ├── poisson.py ├── reaction.py ├── reaction.solver.py └── visualize.py ├── run ├── .sh ├── navierstokes.forced.sh ├── navierstokes.sh ├── navierstokes.yaml ├── poisson.dirichlet.sh ├── poisson.periodic.sh ├── reaction.sh └── requirements.txt └── src ├── README ├── __init__.py ├── basis ├── README ├── __init__.py ├── chebyshev.py └── fourier.py ├── dists.py ├── model ├── README ├── __init__.py ├── _base.py ├── fno │ ├── __init__.py │ └── spectral.py └── sno │ └── spectral.py ├── pde ├── README ├── __init__.py ├── _domain.py ├── _params.py ├── mollifier.py ├── navierstokes │ ├── __init__.py │ └── generate.py ├── poisson │ ├── __init__.py │ ├── generate.py │ └── generate.wls └── reaction │ ├── __init__.py │ └── generate.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # PYTHON # 2 | ########## 3 | __pycache__ 4 | 5 | # FILE # 6 | ########## 7 | *.npy 8 | *.jpg 9 | *.gif 10 | 11 | test* 12 | log 13 | 14 | # DEMO # 15 | ########## 16 | 17 | !demo.gif -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ASK-Berkeley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial 
portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

Neural Spectral Methods

3 | 4 | This repo contains the JAX implementation of our ICLR 2024 paper, 5 | 6 | *Neural Spectral Methods: Self-supervised learning in the spectral domain*. 7 |
8 | Yiheng Du, Nithin Chalapathi, Aditi Krishnapriyan. 9 |
10 | https://arxiv.org/abs/2312.05225. 11 | 12 | Neural Spectral Methods (NSM) is a class of machine learning methods designed for solving parametric PDEs within the spectral domain. Above is a demonstration of NSM's prediction of the 2D Navier-Stokes equation. 13 | 14 | > We present Neural Spectral Methods, a technique to solve parametric Partial Differential Equations (PDEs), grounded in classical spectral methods. Our method uses orthogonal bases to learn PDE solutions as mappings between spectral coefficients. In contrast to current machine learning approaches which enforce PDE constraints by minimizing the numerical quadrature of the residuals in the spatiotemporal domain, we leverage Parseval’s identity and introduce a new training strategy through a spectral loss. Our spectral loss enables more efficient differentiation through the neural network, and substantially reduces training complexity. At inference time, the computational cost of our method remains constant, regardless of the spatiotemporal resolution of the domain. Our experimental results demonstrate that our method significantly outperforms previous machine learning approaches in terms of speed and accuracy by one to two orders of magnitude on multiple different problems. When compared to numerical solvers of the same accuracy, our method demonstrates a 10× increase in performance speed. 15 | 16 | ## Code structure 17 | 18 | The structure of this codebase is outlined in the [`src`](src) directory. This includes definitions for [PDE systems](src/pde), utilities for [orthogonal bases](src/basis), and general implementations of [baseline models](src/model). This codebase is self-contained and can serve as a standalone module for other purposes. To experiment with new problems, follow the templates for instantiating the abstract PDE class in the [`src/pde`](src/pde) directory. 19 | 20 | ### Arguments 21 | 22 | `main.py` accepts command-line arguments as specified below.
Refer to each bash script in the `run` directory for configurations of each experiment. 23 | 24 | ``` 25 | usage: main.py [-h] [--seed SEED] [--f64] [--smi] [--pde PDE] [--model {fno,sno}] 26 | [--spectral] [--hdim HDIM] [--depth DEPTH] [--activate ACTIVATE] 27 | [--mode MODE [MODE ...]] [--grid GRID] [--fourier] [--cheb] 28 | {train,test} ... 29 | 30 | positional arguments: 31 | {train,test} 32 | train train model from scratch 33 | test enter REPL after loading 34 | 35 | optional arguments: 36 | -h, --help show this help message and exit 37 | --seed SEED random seed 38 | --f64 use double precision 39 | --smi profile memory usage 40 | --pde PDE PDE name 41 | --model {fno,sno} model name 42 | --spectral spectral training 43 | --hdim HDIM hidden dimension 44 | --depth DEPTH number of layers 45 | --activate ACTIVATE activation name 46 | --mode MODE [MODE ...] 47 | number of modes per dim 48 | --grid GRID training grid size 49 | --fourier fourier basis only 50 | --cheb using chebyshev 51 | ``` 52 | 53 | ``` 54 | usage: main.py train [-h] --bs BS --lr LR [--clip CLIP] --schd SCHD --iter ITER 55 | --ckpt CKPT --note NOTE [--vmap VMAP] [--save] 56 | 57 | optional arguments: 58 | -h, --help show this help message and exit 59 | --bs BS batch size 60 | --lr LR learning rate 61 | --clip CLIP gradient clipping 62 | --schd SCHD scheduler name 63 | --iter ITER total iterations 64 | --ckpt CKPT checkpoint every n iters 65 | --note NOTE leave a note here 66 | --vmap VMAP vectorization size 67 | --save save model checkpoints 68 | ``` 69 | 70 | ``` 71 | usage: main.py test [-h] [--load LOAD] 72 | 73 | optional arguments: 74 | -h, --help show this help message and exit 75 | --load LOAD saved model path 76 | ``` 77 | 78 | ## Run 79 | 80 | ### Environment 81 | 82 | ```bash 83 | ## We are using jax version 0.4.7. 84 | 85 | pip install -r run/requirements.txt 86 | 87 | ## Please install jaxlib based on your own machine configuration. 
88 | 89 | pip install https://storage.googleapis.com/jax-releases/ 90 | ``` 91 | 92 | ### Quick test 93 | 94 | ```bash 95 | python main.py test 96 | ``` 97 | 98 | ### Train models 99 | 100 | - Local machine 101 | 102 | ```bash 103 | ## Generate data .. 104 | 105 | python -m src.pde..generate 106 | 107 | ## .. and launch. 108 | 109 | bash run/.sh 110 | ``` 111 | 112 | - Cloud compute 113 | 114 | ```bash 115 | ## Using [SkyPilot](https://skypilot.readthedocs.io). 116 | ## Launch cloud jobs based on .yaml configurations .. 117 | 118 | sky launch -c ns run/navierstokes.yaml 119 | sky exec ns bash run/navierstokes.sh --env seed= 120 | 121 | ## .. and collect data. 122 | 123 | rsync -Pavzr ns:~/sky_workdir/log log/gcp 124 | ``` 125 | 126 | ### Plot results 127 | 128 | ```bash 129 | python -m plot. 130 | ``` 131 | 132 | ### Troubleshooting 133 | 134 | 1. The `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variable in the script maximizes memory utilization. This is beneficial if your machine has limited memory (e.g. less than 32GB of GPU memory), but it can lead to initialization issues in certain edge cases. If such issues arise, simply remove the line. 135 | 136 | 2. For machines with very limited memory (e.g. 16GB of GPU memory), consider setting the `vmap` environment variable to a small integer. This allows loop-based mapping across the batch dimension and uses vectorization only for the number of elements specified by the `vmap` argument. While this approach saves GPU memory, it makes the FLOP measurements for each model inaccurate. Keep this in mind when interpreting the cost estimation message in each run.
137 | 138 | ## Timeline 139 | 140 | - 2023.12.9: initial commit 141 | - 2023.12.11: add arXiv link 142 | - 2023.12.22: release code 143 | - 2024.01.19: update citation 144 | - 2024.01.31: update ICLR citation 145 | 146 | ## Citation 147 | 148 | If you find this repository useful, please cite our work: 149 | 150 | ``` 151 | @article{du2024neural, 152 | title={Neural Spectral Methods: Self-supervised learning in the spectral domain}, 153 | journal={The Twelfth International Conference on Learning Representations}, 154 | author={Du, Yiheng and Chalapathi, Nithin and Krishnapriyan, Aditi}, 155 | year={2024} 156 | } 157 | ``` 158 | 159 | -------------------------------------------------------------------------------- /ckpt.py: -------------------------------------------------------------------------------- 1 | from src import * 2 | from threading import * 3 | 4 | class Checkpoint(Thread): 5 | 6 | def __init__(self, **kwargs): 7 | super().__init__(daemon=True) 8 | 9 | from queue import Queue 10 | self.metric = Queue() 11 | self.predict = Lock() 12 | 13 | # ---------------------------------------------------------------------------- # 14 | # INITIALIZE # 15 | # ---------------------------------------------------------------------------- # 16 | 17 | import datetime, os 18 | time = datetime.datetime.now().strftime("%c") 19 | 20 | self.path = "log/" + (kwargs["note"] or time) 21 | self.title = kwargs["pde"] + ":" + kwargs["model"] 22 | if kwargs["spectral"]: self.title += "+spectral" 23 | 24 | os.makedirs(self.path, exist_ok=False) 25 | print(kwargs, file=open(f"{self.path}/cfg", "w")) 26 | 27 | @property 28 | def prediction(self): 29 | 30 | with self.predict: 31 | return self._predict 32 | 33 | @prediction.setter 34 | def prediction(self, answer): 35 | 36 | with self.predict: 37 | self._predict = answer 38 | 39 | def run(self): 40 | 41 | from collections import defaultdict 42 | log = defaultdict(lambda: list()) 43 | 44 | while True: 45 | 46 | # 
---------------------------------------------------------------------------- # 47 | # CHECKPOINT # 48 | # ---------------------------------------------------------------------------- # 49 | 50 | try: 51 | for _ in range(max(1, self.metric.qsize())): 52 | metric, it = self.metric.get() 53 | 54 | for key, value in metric.items(): 55 | log[key].append(value) 56 | 57 | except: pass 58 | 59 | import matplotlib.pyplot as plt 60 | import matplotlib.colors as clr 61 | 62 | import scienceplots; plt.style.use(["science", "no-latex"]) 63 | 64 | # ---------------------------------------------------------------------------- # 65 | # METRIC # 66 | # ---------------------------------------------------------------------------- # 67 | 68 | for key, values in log.items(): 69 | 70 | fig, ax = plt.subplots() 71 | xs = np.arange(len(values))+1 72 | 73 | if all(isinstance(value, Array) for value in values): 74 | 75 | ax.plot(xs, ys:=np.array(values), label=key) 76 | np.save(f"{self.path}/metric.{key}.npy", ys) 77 | 78 | if all(isinstance(value, Dict) for value in values): 79 | 80 | for subkey in set.union(*map(set, map(dict.keys, values))): 81 | ys = np.array(list(map(O.itemgetter(subkey), values))) 82 | 83 | ax.plot(xs, ys, label=f"{key}:{subkey}") 84 | np.save(f"{self.path}/metric.{key}:{subkey}.npy", ys) 85 | 86 | ax.set_xscale("log") 87 | ax.set_yscale("log") 88 | ax.set_title(self.title) 89 | ax.legend() 90 | 91 | fig.savefig(f"{self.path}/metric.{key}.jpg") 92 | plt.close(fig) 93 | 94 | # ---------------------------------------------------------------------------- # 95 | # IMAGE # 96 | # ---------------------------------------------------------------------------- # 97 | 98 | u, uhat = self.prediction 99 | np.save(f"{self.path}/uhat.npy", uhat) 100 | 101 | N = max(min(16, len(u)), 2) 102 | u, uhat = u[:N], uhat[:N] 103 | 104 | fig = plt.figure(figsize=(5 * N, 5 * 3)) 105 | subfig = fig.subfigures(ncols=N) 106 | 107 | def create(subfig, u: X, uhat: X, i: int = None): 108 | axes = 
subfig.subplots(nrows=3) 109 | 110 | vmin, vmax = u.min().item(), u.max().item() 111 | if i is not None: u = u[i]; uhat = uhat[i] 112 | 113 | true = axes[0].imshow(u, vmin=vmin, vmax=vmax) 114 | pred = axes[1].imshow(uhat, vmin=vmin, vmax=vmax) 115 | 116 | vlim = max(abs(vmin), abs(vmax)) 117 | diff = axes[2].imshow(uhat - u, 118 | cmap="Spectral", 119 | norm=clr.SymLogNorm( 120 | linthresh=vlim / 100, 121 | vmin=-vlim / 10, 122 | vmax=+vlim / 10, 123 | ) 124 | ) 125 | 126 | axes[0].axis("off") 127 | axes[1].axis("off") 128 | axes[2].axis("off") 129 | 130 | subfig.colorbar(true, ax=axes[:2]) 131 | subfig.colorbar(diff, ax=axes[2:]) 132 | 133 | return true, pred, diff 134 | 135 | if u.ndim-2 == 2: 136 | 137 | _ = list(map(create, subfig, u, uhat)) 138 | fig.savefig(f"{self.path}/image.jpg") 139 | 140 | plt.close(fig) 141 | 142 | # ---------------------------------------------------------------------------- # 143 | # EXIT # 144 | # ---------------------------------------------------------------------------- # 145 | 146 | if it is None: break 147 | -------------------------------------------------------------------------------- /demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASK-Berkeley/Neural-Spectral-Methods/090e7a173f27734dbae5ec479a410ca1748981af/demo.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from src import * 2 | from src.pde import * 3 | from src.model import * 4 | from src.basis import * 5 | 6 | def main(cfg: Dict[str, Any]): 7 | 8 | # configure precision at the very beginning 9 | jax.config.update("jax_enable_x64", cfg["f64"]) 10 | 11 | if cfg["smi"]: 12 | 13 | import jax_smi as smi 14 | smi.initialise_tracking() 15 | 16 | prng = random.PRNGKey(cfg["seed"]) 17 | rngs = RNGS(prng, ["params", "sample"]) 18 | 19 | from importlib import import_module 
20 | 21 | if cfg["pde"] is not None: 22 | 23 | col, name = cfg["pde"].split(".", 2) 24 | mod = import_module(f"src.pde.{col}") 25 | 26 | pde: PDE = getattr(mod, name) 27 | pde.solution # load solution 28 | 29 | if cfg["model"] is not None: 30 | 31 | col, name = cfg["model"], cfg["model"].upper() 32 | 33 | if cfg["spectral"]: col += ".spectral" 34 | mod = import_module(f"src.model.{col}") 35 | 36 | Model: Solver = getattr(mod, name) 37 | model = Model(pde, cfg) 38 | 39 | # ---------------------------------------------------------------------------- # 40 | # TRAIN # 41 | # ---------------------------------------------------------------------------- # 42 | 43 | if cfg["action"] == "train": 44 | 45 | from src.train import step, eval 46 | train = Trainer(model, pde, cfg) 47 | variable, state = train.init_with_output(next(rngs), method="init") 48 | 49 | step = utils.jit(F.partial(train.apply, method=step, mutable=True)) 50 | step(state, variable, rngs=next(rngs)) # compile train iteration 51 | 52 | def evaluate(): 53 | 54 | global metric, predictions 55 | metric, predictions = train.apply(state, variable, 56 | method=eval, rngs=next(rngs)) 57 | 58 | if cfg["save"]: 59 | 60 | from flax.training.checkpoints import save_checkpoint 61 | save_checkpoint(f"{ckpt.path}/variable", variable, it) 62 | 63 | from ckpt import Checkpoint 64 | ckpt = Checkpoint(**cfg) 65 | ckpt.start() 66 | 67 | from tqdm import trange 68 | for it in (pbar:=trange(cfg["iter"])): 69 | 70 | if not it % cfg["ckpt"]: evaluate() 71 | 72 | # ----------------------------------- STEP ----------------------------------- # 73 | 74 | import time 75 | rate = time.time() 76 | 77 | (variable, loss), state = jax.tree_map(jax.block_until_ready, 78 | step(state, variable, rngs=next(rngs))) 79 | 80 | rate = np.array(time.time() - rate) 81 | 82 | # ----------------------------------- CKPT ----------------------------------- # 83 | 84 | metric.update(loss=loss, rate=rate) 85 | 86 | ckpt.metric.put((metric.copy(), it)) 87 
| ckpt.prediction = predictions 88 | 89 | pbar.set_postfix(jax.tree_map(lambda x: f"{x:.2e}", metric)) 90 | 91 | evaluate() 92 | 93 | ckpt.metric.put((metric, None)) 94 | ckpt.prediction = predictions 95 | ckpt.join() 96 | 97 | return pde, model.bind(variable, rngs=next(rngs)) 98 | 99 | # ---------------------------------------------------------------------------- # 100 | # LOAD # 101 | # ---------------------------------------------------------------------------- # 102 | 103 | else: 104 | 105 | if cfg["load"]: 106 | 107 | from flax.training.checkpoints import restore_checkpoint 108 | variable = restore_checkpoint(cfg["load"], target=None) 109 | model = model.bind(variable, rngs=next(rngs)) 110 | 111 | exit(utils.repl(locals())) 112 | 113 | # ---------------------------------------------------------------------------- # 114 | # ARGPARSE # 115 | # ---------------------------------------------------------------------------- # 116 | 117 | if __name__ == "__main__": 118 | 119 | import argparse 120 | args = argparse.ArgumentParser() 121 | action = args.add_subparsers(dest="action") 122 | 123 | args.add_argument("--seed", type=int, default=19260817, help="random seed") 124 | args.add_argument("--f64", dest="f64", action="store_true", help="use double precision") 125 | args.add_argument("--smi", dest="smi", action="store_true", help="profile memory usage") 126 | 127 | args.add_argument("--pde", type=str, help="PDE name") 128 | args.add_argument("--model", type=str, help="model name", choices=["fno", "sno"]) # --cheb=cno 129 | args.add_argument("--spectral", dest="spectral", action="store_true", help="spectral training") 130 | 131 | # ----------------------------------- MODEL ---------------------------------- # 132 | 133 | args.add_argument("--hdim", type=int, help="hidden dimension") 134 | args.add_argument("--depth", type=int, help="number of layers") 135 | args.add_argument("--activate", type=str, help="activation name") 136 | 137 | args.add_argument("--mode", 
type=int, nargs="+", help="number of modes per dim") 138 | args.add_argument("--grid", type=int, default=256, help="training grid size") 139 | 140 | ## ablation study 141 | 142 | args.add_argument("--fourier", dest="fourier", action="store_true", help="fourier basis only") 143 | args.add_argument("--cheb", dest="cheb", action="store_true", help="using chebyshev") 144 | 145 | # ----------------------------------- TRAIN ---------------------------------- # 146 | 147 | args_train = action.add_parser("train", help="train model from scratch") 148 | 149 | args_train.add_argument("--bs", type=int, required=True, help="batch size") 150 | args_train.add_argument("--lr", type=float, required=True, help="learning rate") 151 | args_train.add_argument("--clip", type=float, required=False, help="gradient clipping") 152 | args_train.add_argument("--schd", type=str, required=True, help="scheduler name") 153 | args_train.add_argument("--iter", type=int, required=True, help="total iterations") 154 | args_train.add_argument("--ckpt", type=int, required=True, help="checkpoint every n iters") 155 | args_train.add_argument("--note", type=str, required=True, help="leave a note here") 156 | 157 | args_train.add_argument("--vmap", type=lambda x: int(x) if x else None, help="vectorization size") 158 | args_train.add_argument("--save", dest="save", action="store_true", help="save model checkpoints") 159 | 160 | # ----------------------------------- TEST ----------------------------------- # 161 | 162 | args_test = action.add_parser("test", help="enter REPL after loading") 163 | 164 | args_test.add_argument("--load", type=str, help="saved model path") 165 | 166 | # ---------------------------------------------------------------------------- # 167 | # MAIN # 168 | # ---------------------------------------------------------------------------- # 169 | 170 | args = args.parse_args() 171 | cfg = vars(args); print(f"{cfg=}") 172 | 173 | pde, model = main(cfg) 174 | # utils.repl(locals()) 175 | 
-------------------------------------------------------------------------------- /plot/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import scienceplots 4 | import functools as F 5 | 6 | import matplotlib.pyplot as plt 7 | import matplotlib.colors as clr 8 | import matplotlib.ticker as tkr 9 | 10 | plt.style.use(style=["science", "nature", "std-colors", "grid"]) 11 | CLR = plt.rcParams["axes.prop_cycle"].by_key()["color"] 12 | CLR = [CLR[0], CLR[2], CLR[3], CLR[1]] 13 | 14 | TIMES = "\\!\\times\\!" 15 | PLUS = "\\!+\\!" 16 | 17 | def load(pde: str, metric: str, dir: str, METHOD): 18 | 19 | def random(method: str): 20 | def call(seed: int): 21 | try: return np.load(f"{dir}/{pde}{method}.{seed}/{metric}.npy") 22 | except: return np.array([np.nan]) 23 | return call 24 | 25 | return { method: list(map(random(method), range(4))) for method in METHOD } 26 | 27 | def color(method: str) -> str: 28 | 29 | if method[0] == ":": 30 | 31 | if method[-1] == "M": 32 | 33 | return CLR[0] 34 | 35 | return CLR[-1] 36 | 37 | if method[0] == "x": 38 | 39 | if method[-1] == "C": 40 | 41 | return CLR[2] 42 | 43 | return CLR[1] 44 | 45 | def lines(method: str) -> str: 46 | 47 | if method[0] == ":": 48 | 49 | return "-" 50 | 51 | if method[0] == "x": 52 | 53 | if method[-1] == "C": 54 | 55 | return "-" 56 | 57 | return { 58 | "64": ":", 59 | "128": "--", 60 | "256": "-", 61 | 62 | "96": "-.", 63 | }[method[1:]] 64 | 65 | def reorder(ax, order): 66 | 67 | handles, labels = ax.get_legend_handles_labels() 68 | return [handles[i] for i in order], [labels[i] for i in order] 69 | -------------------------------------------------------------------------------- /plot/animation.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | from sys import argv 3 | 4 | N = list(map(int, argv[1:])) 5 | 6 | u = np.load("log/aws/u.npy")[np.r_[N]] 7 | unsm = np.load(f"log/aws/unsm.npy")[np.r_[N]] 8 | ufno = np.load(f"log/aws/ux96.npy")[np.r_[N]] 9 | 10 | fig, ax = plt.subplots(ncols=len(N), nrows=3, figsize=(2.1*len(N), 4), 11 | constrained_layout=True, squeeze=False, dpi=72) 12 | 13 | fig.set_constrained_layout_pads(hspace=0.1, wspace=0.1) 14 | 15 | def create(sol, nsm, fno, i: int): 16 | 17 | N, K = 15, 4 # remove dark colors 18 | vmin = u[i].min() * (N + 2*K) / N 19 | vmax = u[i].max() * (N + 2*K) / N 20 | 21 | cmap = plt.get_cmap("twilight_shifted") 22 | norm = clr.BoundaryNorm(np.linspace(vmin, vmax, N+2*K), cmap.N) 23 | 24 | im_sol = sol.imshow(u[i, 0], cmap=cmap, norm=norm) 25 | im_nsm = nsm.imshow(unsm[i, 0], cmap=cmap, norm=norm) 26 | im_fno = fno.imshow(ufno[i, 0], cmap=cmap, norm=norm) 27 | 28 | sol.set_xticks([]); sol.set_yticks([]) 29 | nsm.set_xticks([]); nsm.set_yticks([]) 30 | fno.set_xticks([]); fno.set_yticks([]) 31 | 32 | return im_sol, im_nsm, im_fno 33 | 34 | im = [] 35 | for i in range(len(N)): 36 | im.append(create(*ax[:, i], i)) 37 | 38 | kw = dict(rotation=90, labelpad=10, fontsize=10) 39 | ax[0, 0].set_ylabel("Numerical solver", **kw) 40 | ax[1, 0].set_ylabel("NSM (ours) Prediction", **kw) 41 | ax[2, 0].set_ylabel("FNO+PINN Prediction", **kw) 42 | 43 | plt.colorbar(im[-1][0],fraction=0.046, pad=0.1, ticks=[-2, -1, 0, 1]) 44 | plt.colorbar(im[-1][1],fraction=0.046, pad=0.1, ticks=[-2, -1, 0, 1]) 45 | plt.colorbar(im[-1][2],fraction=0.046, pad=0.1, ticks=[-2, -1, 0, 1]) 46 | 47 | def frame(index, total: int = 64): 48 | i = int(total * (index/T/fps)) 49 | 50 | for n, [sol, nsm, fno] in enumerate(im): 51 | 52 | sol.set_array(u[n, i]) 53 | nsm.set_array(unsm[n, i]) 54 | fno.set_array(ufno[n, i]) 55 | 56 | return sum(im, ()) 57 | 58 | from matplotlib import animation 59 | ani = animation.FuncAnimation(fig, frame, (T:=8)*(fps:=24), interval=T/fps * 1000, blit=True) 60 | 
ani.save("plot/animation.gif", writer="pillow", fps=fps) 61 | -------------------------------------------------------------------------------- /plot/navierstokes.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | METHOD = ["x64", "x96", ":NSM"] 4 | LABELS = [f"FNO${TIMES}64^3$+PINN", f"FNO${TIMES}96^3$+PINN", "NSM (ours)"] 5 | 6 | res = ["re2", "re3", "re4"] 7 | 8 | def load(pde: str, metric: str, dir: str = "log/gcp/ns.T3"): 9 | 10 | def random(method: str): 11 | def call(seed: int): 12 | try: 13 | if method == "x64": return np.load(f"{dir}/{pde}{method}.{seed}/{metric}.npy")[:int(3600*24 // 0.69)] 14 | if method == "x96": return np.load(f"{dir}/{pde}{method}.{seed}/{metric}.npy")[:int(3600*72 // 2.33)] 15 | return np.load(f"{dir}/{pde}{method}.{seed}/{metric}.npy") 16 | except: print("oh no not finished!"); return np.array([np.inf]) 17 | return call 18 | 19 | return { method: list(map(random(method), range(4))) for method in METHOD } 20 | 21 | # ---------------------------------------------------------------------------- # 22 | # TABLE # 23 | # ---------------------------------------------------------------------------- # 24 | 25 | for re in res: 26 | 27 | print(f"{re = }", end=":\t") 28 | 29 | for method, errs in load(re, "metric.errr").items(): 30 | 31 | errs = np.array([err[-1] for err in errs]) * 100 32 | print(f"{np.mean(errs):.2f}±{np.std(errs):.2f}", end="\t") 33 | 34 | print("") 35 | 36 | # ---------------------------------------------------------------------------- # 37 | # VISUALIZE # 38 | # ---------------------------------------------------------------------------- # 39 | 40 | import sys 41 | if len(sys.argv) > 1: 42 | 43 | N = int(sys.argv[1]) 44 | 45 | fig, [ok, nsm, fno] = plt.subplots(ncols=(T:=3)+1, nrows=3, figsize=(T*2, 5)) 46 | 47 | u = np.load("log/gcp/ns.T3/u.npy")[N] 48 | unsm = np.load(f"log/{sys.argv[2]}/uhat.npy")[N] 49 | ufno = np.load(f"log/{sys.argv[3]}/uhat.npy")[N] 50 | 
51 | # remove dark colors 52 | 53 | L, K = 15, 4  # L bins span the data range; vmin/vmax below are stretched by (L+2*K)/L 54 | vmin = u.min() * (L+2*K)/L 55 | vmax = u.max() * (L+2*K)/L 56 | 57 | def show(ax, img): 58 | 59 | cmap = plt.get_cmap("twilight_shifted") 60 | norm = clr.BoundaryNorm(np.linspace(vmin, vmax, L+2*K+1), cmap.N) 61 | ax.imshow(img, cmap=cmap, norm=norm) 62 | ax.set_xticks([]) 63 | ax.set_yticks([]) 64 | 65 | for i in range(T+1): 66 | 67 | t = int(63 * i/T)  # NOTE(review): assumes 64 time frames (index 63 is the last) - confirm against u.npy 68 | 69 | show(ok[i], u[t]) 70 | show(nsm[i], unsm[t]) 71 | show(fno[i], ufno[t]) 72 | 73 | ok[0].set_title("Initial vorticity", fontsize=10) 74 | for t in range(1, T + 1): 75 | ok[t].set_title(f"$T={t}$", fontsize=10) 76 | 77 | ok[0].set_ylabel("Numerical solver", rotation=90, labelpad=10, fontsize=10) 78 | nsm[0].set_ylabel("NSM (ours) prediction", rotation=90, labelpad=10, fontsize=10) 79 | fno[0].set_ylabel("FNO+PINN prediction", rotation=90, labelpad=10, fontsize=10) 80 | 81 | ok[0].text(-0.21, 1.03, "(c)", transform=ok[0].transAxes, fontsize=10) 82 | 83 | fig.tight_layout() 84 | fig.savefig(f"plot/navierstokes.{N}.jpg", dpi=300) 85 | 86 | # ---------------------------------------------------------------------------- # 87 | # CURVE # 88 | # ---------------------------------------------------------------------------- # 89 | 90 | for re in ["re4"]: 91 | 92 | figure = plt.figure(figsize=(3, 4.5)) 93 | res, err = figure.subplots(nrows=2, sharex=True) 94 | 95 | pct = tkr.FuncFormatter(lambda y, _: f"{y:.1%}"[:3])  # percent tick labels, truncated to 3 characters 96 | 97 | rate = load(re, "metric.rate") 98 | 99 | # --------------------------------- RESIDUAL --------------------------------- # 100 | 101 | residual = load(re, "metric.residual") 102 | 103 | for method, label in zip(METHOD, LABELS): 104 | 105 | xs = np.linspace(1, max(map(np.sum, rate[method])), N:=1000)  # common x grid up to the longest seed's total time 106 | 107 | ys = [] 108 | for rat, rrr in zip(rate[method], residual[method]): 109 | 110 | a = np.concatenate([np.zeros(1), np.cumsum(rat)])[:-1]  # elapsed time at the start of each logged step 111 | 112 | ys.append(np.interp(xs, a, rrr[:len(a)]))  # trim so the metric has one value per entry of a 113 | 114 | ys_mean = np.mean(np.stack(ys, 
axis=0), axis=0) 115 | ys_std = np.std(np.stack(ys, axis=0), axis=0) 116 | 117 | res.plot(xs, ys_mean, label=label, c=(c:=color(method)), ls=lines(method)) 118 | res.fill_between(xs, ys_mean-ys_std, ys_mean+ys_std, alpha=0.2, color=c, edgecolor=None) 119 | 120 | res.set_xscale("log") 121 | res.set_yscale("log") 122 | 123 | res.xaxis.set_label_position("top") 124 | res.set_xlabel("Training time (seconds)", fontsize=10) 125 | res.set_ylabel(f"PDE residual on test set", fontsize=10) 126 | 127 | res.yaxis.grid(False, which='minor') 128 | res.set_yticks([0.1, 0.08, 0.06, 0.04, 0.02]) 129 | res.yaxis.set_major_formatter(lambda x, _: str(x)) 130 | res.yaxis.set_minor_formatter(lambda *_: "") 131 | 132 | err.set_xlim(10, 3e5) 133 | err.xaxis.tick_top() 134 | for tick in err.xaxis.get_majorticklabels(): 135 | tick.set_horizontalalignment("left") 136 | 137 | # ----------------------------------- ERROR ---------------------------------- # 138 | 139 | # TRAIN 140 | 141 | errr = load(re, "metric.errr") 142 | 143 | for method, label in zip(METHOD, LABELS): 144 | 145 | xs = np.linspace(1, max(map(np.sum, rate[method])), N:=1000)  # common x grid up to the longest seed's total time 146 | 147 | ys = [] 148 | for rat, rrr in zip(rate[method], errr[method]): 149 | 150 | a = np.concatenate([np.zeros(1), np.cumsum(rat)])[:-1]  # elapsed time at the start of each logged step 151 | 152 | ys.append(np.interp(xs, a, rrr[:len(a)]))  # trim so the metric has one value per entry of a 153 | 154 | ys_mean = np.mean(np.stack(ys, axis=0), axis=0) 155 | ys_std = np.std(np.stack(ys, axis=0), axis=0) 156 | 157 | err.plot(xs, ys_mean, label=label, c=(c:=color(method)), ls=lines(method)) 158 | err.fill_between(xs, ys_mean-ys_std, ys_mean+ys_std, alpha=0.2, color=c, edgecolor=None) 159 | 160 | err.set_xscale("log") 161 | err.set_yscale("log") 162 | 163 | err.set_ylabel(f"$L_2$ rel. 
error (\\%) on test set", fontsize=10) 164 | 165 | err.yaxis.set_major_formatter(pct) 166 | err.set_yticks([0.05, 0.1, 0.2, 0.3]) 167 | 168 | res.legend(fontsize=7, loc="lower left", handlelength=1.7) 169 | err.legend(fontsize=7, loc="lower left", handlelength=1.7) 170 | 171 | res.xaxis.set_tick_params(labelsize=8) 172 | res.yaxis.set_tick_params(labelsize=8) 173 | err.yaxis.set_tick_params(labelsize=8) 174 | err.xaxis.set_tick_params(labelsize=8) 175 | 176 | res.yaxis.set_label_coords(-0.15, 0.5) 177 | err.yaxis.set_label_coords(-0.15, 0.5) 178 | 179 | res.text(-0.21, 1, "(a)", transform=res.transAxes, fontsize=10) 180 | err.text(-0.21, 1, "(b)", transform=err.transAxes, fontsize=10) 181 | 182 | figure.tight_layout(h_pad=0.1) 183 | figure.savefig(f"plot/ns.curve.{re}.jpg", dpi=300) 184 | -------------------------------------------------------------------------------- /plot/poisson.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | METHOD = [":SNO", "x64", "x128", "x256", ":NSM"] 4 | 5 | load = F.partial(load, dir="log/ps.dirichlet", METHOD=METHOD)  # NOTE(review): both bcs below are read from log/ps.dirichlet - confirm the periodic runs are stored there too 6 | bcs = ["periodic", "dirichlet"] 7 | 8 | # ---------------------------------------------------------------------------- # 9 | # TABLE # 10 | # ---------------------------------------------------------------------------- # 11 | 12 | print("="*116) 13 | print("*relu*\t\t\t", "\t\t".join(METHOD)) 14 | print("-"*116) 15 | 16 | for bc in bcs: 17 | 18 | print(f"{bc = }", end=":\t") 19 | 20 | for method, errs in load("relu/"+bc, "metric.errr").items(): 21 | 22 | errs = np.array([err[-1] for err in errs]) * 100  # last logged error of each seed, in percent 23 | print(f"{np.mean(errs):.3f}±{np.std(errs):.3f}", end=" \t") 24 | 25 | print("") 26 | 27 | print("="*116) 28 | print("*tanh*\t\t\t", "\t\t".join(METHOD)) 29 | print("-"*116) 30 | 31 | for bc in bcs: 32 | 33 | print(f"{bc = }", end=":\t") 34 | 35 | for method, errs in load("tanh/"+bc, "metric.errr").items(): 36 | 37 | errs = np.array([err[-1] for err in 
errs]) * 100 38 | print(f"{np.mean(errs):.3f}±{np.std(errs):.3f}", end=" \t") 39 | 40 | print("") 41 | 42 | print("="*116) 43 | print("*long*\t\t\t", "\t\t".join(METHOD)) 44 | print("-"*116) 45 | 46 | for bc in bcs: 47 | 48 | print(f"{bc = }", end=":\t") 49 | 50 | for method, errs in load("long/"+bc, "metric.errr").items(): 51 | 52 | errs = np.array([err[-1] for err in errs]) * 100  # last logged error of each seed, in percent 53 | print(f"{np.mean(errs):.3f}±{np.std(errs):.3f}", end=" \t") 54 | 55 | print("") 56 | 57 | print("="*116) 58 | -------------------------------------------------------------------------------- /plot/reaction.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | METHOD = [":SNO", "x64", "x128", "x256", "x256C", ":NSM"] 4 | 5 | # FNOPINN = lambda n, m="": "FNO{\\scriptsize$\\!\\times\\!"+str(n)+"^2$"+m+"}+PINN" 6 | FNOPINN = lambda n, m="": "FNO$\\!\\times\\!" + str(n) + "^2$" + m + "+PINN" 7 | LABELS = [None, FNOPINN(64, "\\ \\ "), FNOPINN(128), FNOPINN(256), f"CNO+PINN(ours)", f"NSM (ours)"] 8 | 9 | load = F.partial(load, dir="log/rd", METHOD=METHOD) 10 | 11 | # ---------------------------------------------------------------------------- # 12 | # TABLE # 13 | # ---------------------------------------------------------------------------- # 14 | 15 | for nu in ["005", "01", "05", "1"]: 16 | 17 | print(f"{nu = }", end=":\t") 18 | 19 | for method, errs in load(f"nu{nu}", "metric.errr").items(): 20 | 21 | errs = np.array([err[-1] for err in errs]) * 100 22 | print(f"{np.mean(errs):.3f}±{np.std(errs):.3f}", end="\t") 23 | 24 | print("") 25 | 26 | 27 | for nu in ["01"]: 28 | 29 | figure = plt.figure(figsize=(8, 4), constrained_layout=False) 30 | train, test = figure.subfigures(ncols=2, width_ratios=[1, 1.2]) 31 | 32 | train.subplots_adjust(hspace=0.31) 33 | test.subplots_adjust(left=0.08) 34 | 35 | pct = tkr.FuncFormatter(lambda y, _: f"{y:.1%}"[:3]) 36 | 37 | # 
---------------------------------------------------------------------------- # 38 | # TRAIN # 39 | # ---------------------------------------------------------------------------- # 40 | 41 | rate = load(f"nu{nu}", "metric.rate") 42 | res, err = train.subplots(nrows=2, sharex=True) 43 | 44 | # --------------------------------- RESIDUAL --------------------------------- # 45 | 46 | residual = load(f"nu{nu}", "metric.residual") 47 | 48 | for method, label in zip(METHOD, LABELS): 49 | if label is None: continue 50 | 51 | xs = np.linspace(1, max(map(np.sum, rate[method])), N:=1000) 52 | ys = [np.interp(xs, np.concatenate([np.zeros(1), np.cumsum(rate)]), residual) 53 | for rate, residual in zip(rate[method], residual[method])] 54 | 55 | ys_mean = np.mean(np.stack(ys, axis=0), axis=0) 56 | ys_std = np.std(np.stack(ys, axis=0), axis=0) 57 | 58 | res.plot(xs, ys_mean, label=label, c=(c:=color(method)), ls=lines(method)) 59 | res.fill_between(xs, ys_mean-ys_std, ys_mean+ys_std, alpha=0.2, color=c, edgecolor=None) 60 | 61 | res.set_xscale("log") 62 | res.set_yscale("log") 63 | 64 | res.set_xlim(1, 1e4) 65 | res.set_xticks([1, 1e1, 1e2, 1e3, 1e4]) 66 | 67 | # res.yaxis.set_major_formatter(tkr.ScalarFormatter()) 68 | res.yaxis.set_label_coords(-0.13, None) 69 | 70 | # ---------------------------------- ERROR % --------------------------------- # 71 | 72 | errr = load(f"nu{nu}", "metric.errr") 73 | 74 | for method, label in zip(METHOD, LABELS): 75 | if label is None: continue 76 | 77 | xs = np.linspace(1, max(map(np.sum, rate[method])), N:=1000) 78 | ys = [np.interp(xs, np.concatenate([np.zeros(1), np.cumsum(rate)]), errr) 79 | for rate, errr in zip(rate[method], errr[method])] 80 | 81 | ys_mean = np.mean(np.stack(ys, axis=0), axis=0) 82 | ys_std = np.std(np.stack(ys, axis=0), axis=0) 83 | 84 | err.plot(xs, ys_mean, label=label, c=(c:=color(method)), ls=lines(method)) 85 | err.fill_between(xs, ys_mean-ys_std, ys_mean+ys_std, alpha=0.2, color=c, edgecolor=None) 86 | 87 | 
err.set_xscale("log") 88 | err.set_yscale("log") 89 | 90 | err.set_xlim(1, 1e4) 91 | err.set_xticks([1, 1e1, 1e2, 1e3, 1e4]) 92 | 93 | err.set_ylim(6e-4) 94 | 95 | err.yaxis.set_major_formatter(pct) 96 | err.yaxis.set_label_coords(-0.13, None) 97 | 98 | # ---------------------------------------------------------------------------- # 99 | # TEST # 100 | # ---------------------------------------------------------------------------- # 101 | 102 | ax = test.subplots() 103 | 104 | mkr = dict(marker=".", markersize=7) 105 | 106 | # ------------------------------------ FNO ----------------------------------- # 107 | 108 | time = np.load("log/rd/solver/time.fno.npy") 109 | 110 | for method, label in zip(METHOD, LABELS): 111 | 112 | if method[0] == "x" and method[-1] != "C": 113 | 114 | ax.plot(X:=time.mean(-1), Y:=np.load(f"log/rd/solver/errr.fnox{method[1:]}.npy"), **mkr, label=label+" (training)", c=color(method), ls=lines(method)) 115 | 116 | def getl(n): 117 | 118 | style = dict(text=f"${2**n*32}{TIMES}{2**n*32}$", xy=(x:=X[n], y:=Y[n]), textcoords="offset pixels", fontsize=9) 119 | 120 | if n < 2: style.update(xytext=(x-20, y+40)) 121 | if n > 1: style.update(xytext=(x-78, y+40)) 122 | 123 | return style 124 | 125 | if method == "x256": 126 | for n in range(2): ax.annotate(**getl(n)) 127 | if method == "x64": 128 | for n in range(2, 5): ax.annotate(**getl(n)) 129 | 130 | # ---------------------------------- SOLVER ---------------------------------- # 131 | 132 | time = np.load("log/rd/solver/time.solver.npy") 133 | errr = np.load(f"log/rd/solver/errr.solver.npy") 134 | 135 | ax.plot(X:=time.mean(-1), Y:=errr.mean(-1), c="black", **mkr, label="Numerical solver") 136 | for n in range(1, len(errr)+1): ax.annotate(f"${2**n}{TIMES}{2**n}$", (x:=X[n-1], y:=Y[n-1]), (x+16, y), textcoords="offset pixels", fontsize=9) 137 | 138 | # ------------------------------------ NSM ----------------------------------- # 139 | 140 | time = np.load("log/rd/solver/time.nsm.npy") 141 | 142 | 
ax.plot(x:=time.mean(), y:=np.load("log/rd/solver/errr.nsm.npy").mean(), **mkr, c=CLR[0], label="NSM (ours)") 143 | ax.scatter(x, y, c=CLR[0], s=20) 144 | ax.annotate(f"$32{TIMES}32$", (x, y), (x-9, y+152), textcoords="offset pixels", fontsize=9) 145 | ax.annotate(f"$64{TIMES}64$", (x, y), (x-9, y+120), textcoords="offset pixels", fontsize=9) 146 | ax.annotate(f"$128{TIMES}128$", (x, y), (x-9, y+88), textcoords="offset pixels", fontsize=9) 147 | ax.annotate(f"$256{TIMES}256$", (x, y), (x-9, y+56), textcoords="offset pixels", fontsize=9) 148 | ax.annotate(f"$512{TIMES}512$", (x, y), (x-9, y+24), textcoords="offset pixels", fontsize=9) 149 | 150 | ax.set_xscale("log") 151 | ax.set_yscale("log") 152 | 153 | ax.set_xlim(6e-4, 1e-1) 154 | ax.set_ylim(6e-4, 1e-1) 155 | 156 | ax.yaxis.set_major_formatter(pct) 157 | ax.yaxis.set_label_coords(-0.09, None) 158 | 159 | # ---------------------------------------------------------------------------- # 160 | # LAYOUT # 161 | # ---------------------------------------------------------------------------- # 162 | 163 | res.legend(*reorder(res, [0, 1, 2, 3, 4]), loc="lower center", handlelength=1.5, bbox_to_anchor=(0.5, -0.315), ncol=3, fontsize=7, labelspacing=0.5, columnspacing=1) 164 | ax.legend(loc="upper right", handlelength=1.8, ncols=2, fontsize=8, labelspacing=0.3, columnspacing=0.8) 165 | 166 | res.set_ylabel(f"PDE residual on\n $512{TIMES}512$ test res.", fontsize=10) 167 | err.set_xlabel(f"Training time (seconds)", fontsize=10) 168 | err.set_ylabel(f"$L_2$ rel. error (\\%) on\n $512{TIMES}512$ test res.", fontsize=10) 169 | 170 | ax.set_xlabel("Inference time (seconds)", fontsize=10) 171 | ax.set_ylabel("$L_2$ rel. 
error (\\%) on different test resolutions", fontsize=10) 172 | 173 | res.xaxis.set_tick_params(labelsize=8) 174 | res.yaxis.set_tick_params(labelsize=8) 175 | err.xaxis.set_tick_params(labelsize=8) 176 | err.yaxis.set_tick_params(labelsize=8) 177 | ax.xaxis.set_tick_params(labelsize=8) 178 | ax.yaxis.set_tick_params(labelsize=8) 179 | 180 | res.text(-0.23, 1, "(a)", transform=res.transAxes, fontsize=10) 181 | err.text(-0.23, 1, "(b)", transform=err.transAxes, fontsize=10) 182 | ax.text(-0.13, 1, "(c)", transform=ax.transAxes, fontsize=10) 183 | 184 | figure.savefig(f"plot/re.curve.jpg", dpi=300) 185 | 186 | # ---------------------------------------------------------------------------- # 187 | # BOX # 188 | # ---------------------------------------------------------------------------- # 189 | 190 | FNOPINN = lambda n, m="": "FNO$\\!\\times\\!" + str(n) + "^2$" + m 191 | LABELS = [None, FNOPINN(64, "\\ \\ "), FNOPINN(128), FNOPINN(256), f"CNO (ours)", f"NSM (ours)"] 192 | 193 | fig, axes = plt.subplots(figsize=(16, 3.4), ncols=4) 194 | for ax, nu in zip(axes, ["005", "01", "05", "1"]): 195 | 196 | u = np.load(f"src/pde/reaction/u.rho=5:nu=0.{nu.ljust(3, '0')}.npy") 197 | uhat = load(f"nu{nu}", "uhat") 198 | 199 | ax.boxplot([np.concatenate([np.mean(np.abs(uhat-u), axis=(1, 2, 3)) 200 | for uhat in uhat[method][:1]]) for method in METHOD[1:]], labels=LABELS[1:]) 201 | 202 | ax.set_title(f"Absolute error distribution for $\\nu=0.{nu}$", fontsize=10) 203 | 204 | ax.set_ylim(0, 0.06) 205 | 206 | axes[0].text(-0.08, 1.04, "(a)", transform=axes[0].transAxes, fontsize=10) 207 | axes[1].text(-0.08, 1.04, "(b)", transform=axes[1].transAxes, fontsize=10) 208 | axes[2].text(-0.08, 1.04, "(c)", transform=axes[2].transAxes, fontsize=10) 209 | axes[3].text(-0.08, 1.04, "(d)", transform=axes[3].transAxes, fontsize=10) 210 | 211 | fig.savefig(f"plot/re.box.jpg", dpi=300) 212 | -------------------------------------------------------------------------------- /plot/reaction.solver.py: 
-------------------------------------------------------------------------------- 1 | from src.utils import * 2 | from src.pde.reaction import * 3 | from src.pde.reaction.generate import * 4 | 5 | from tqdm import tqdm 6 | 7 | def time(N: int = 12) -> X: 8 | 9 | h, s, u = nu005.solution 10 | h = h.map(lambda x: x[0]) 11 | 12 | def check(n: int) -> float: 13 | 14 | return timeit(lambda: 15 | solution(lambda x: h(np.array([0, x]))[0], 16 | 1.0, 1.0, nx=2**n, nt=2**n+1))() 17 | 18 | return list(map(check, tqdm(range(1, N)))) 19 | 20 | def error(rd: ReactionDiffusion, N: int = 12) -> X: 21 | 22 | h, s, u = rd.solution 23 | 24 | def solve(f: Basis, n: int) -> X: 25 | 26 | h = lambda x: f(np.array([0, x])).squeeze() 27 | return solution(h, rd.nu, rd.rho, n, n+1)[1] 28 | 29 | U = jax.lax.map(F.partial(solve, n=(K:=2**N)), h) 30 | 31 | def call(h: Fx, u: X, n: int) -> X: 32 | 33 | un = solve(h, k:=2**n) 34 | uN = u[::K//k, ::K//k] 35 | 36 | return np.linalg.norm(np.ravel(uN - un)) \ 37 | / np.linalg.norm(np.ravel(uN)) 38 | 39 | return [jax.vmap(F.partial(call, n=n))(h, U) 40 | for n in tqdm(range(1, N))] 41 | 42 | # ---------------------------------------------------------------------------- # 43 | # SOLVER # 44 | # ---------------------------------------------------------------------------- # 45 | 46 | np.save("log/rd/solver/errr.solver.npy", np.array(error(nu01))) 47 | np.save("log/rd/solver/time.solver.npy", np.array([time() for _ in tqdm(range(16))])) 48 | 49 | # ---------------------------------------------------------------------------- # 50 | # MODEL # 51 | # ---------------------------------------------------------------------------- # 52 | 53 | pde, model = "don't run me this way" 54 | prng = "run main with `--test` flag" 55 | 56 | def solve(f: Basis, n: int) -> X: 57 | h = lambda x: f(np.array([0, x])).squeeze() 58 | return solution(h, pde.nu, pde.rho, n, n+1)[1] 59 | 60 | U = jax.lax.map(F.partial(solve, n=4096), pde.solution[0]) 61 | # np.save("test.U.npy", 
U);;;;;; U=np.load("test.U.npy") 62 | 63 | def acc(n: int, K=4096): 64 | uhat = jax.lax.map(F.partial(model.apply, model.variables, x=((k:=2**n)+1, k+1), method="u"), pde.solution[0]) 65 | return np.linalg.norm((u:=U[:, ::K//k, ::K//k]) - uhat[:, :, :-1, 0]) / np.linalg.norm(u) 66 | 67 | grid = model.cfg["grid"] 68 | accuracy = np.stack([acc(n) for n in range(5, 10)]) 69 | np.save(f"log/rd/solver/errr.fnox{grid}.npy", accuracy) 70 | 71 | def tim(n: int): 72 | v = model.init(prng, p:=pde.params.sample(prng), (s:=2**n+1, s), method="forward") 73 | timer = timeit(lambda: model.apply(v, p, (s:=2**n+1, s), method="forward")) 74 | return np.array([timer() for _ in range(16)]) 75 | 76 | time = np.stack([tim(n) for n in range(5, 10)]) 77 | np.save(f"log/rd/solver/time.fno.npy", time) 78 | 79 | # ------------------------------------ NSM ----------------------------------- # 80 | 81 | accuracy = np.stack([acc(n) for n in range(5, 10)]) 82 | np.save(f"log/rd/solver/errr.nsm.npy", accuracy) 83 | 84 | v = model.init(prng, p:=pde.params.sample(prng), method="forward") 85 | time = timeit(lambda: model.apply(v, p, method="forward")) 86 | 87 | np.save("log/rd/solver/time.nsm.npy", np.array([time() for _ in range(16)])) 88 | -------------------------------------------------------------------------------- /plot/visualize.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | 3 | from sys import argv; N = int(argv[1]) 4 | 5 | # ---------------------------------------------------------------------------- # 6 | # POISSON # 7 | # ---------------------------------------------------------------------------- # 8 | 9 | u = np.load(f"src/pde/poisson/u.periodic.npy")[N] 10 | s = np.load("test.s.npy") 11 | 12 | fig, [ax_s, ax_u, err] = plt.subplots(ncols=3, figsize=(8.5, 2.2), 13 | width_ratios=[1, 1, 1.3], 14 | constrained_layout=True) 15 | 16 | plt.colorbar(ax_s.imshow(s, origin='lower', extent=[0, 255/256, 0, 255/256]), fraction=0.046, pad=0.04 ); ax_s.grid(False); ax_s.set_xticks([0, 0.2, 0.4, 0.6, 0.8]); ax_s.set_yticks([0, 0.2, 0.4, 0.6, 0.8]); ax_s.set_title("Source term $s$", fontsize=10) 17 | plt.colorbar(ax_u.imshow(u, origin='lower', extent=[0, 255/256, 0, 255/256]), fraction=0.046, pad=0.04, ticks=[0, 0.01, 0.02]); ax_u.grid(False); ax_u.set_xticks([0, 0.2, 0.4, 0.6, 0.8]); ax_u.set_yticks([0, 0.2, 0.4, 0.6, 0.8]); ax_u.set_title("Solution $u$", fontsize=10) 18 | 19 | ax_s.xaxis.set_tick_params(labelsize=8) 20 | ax_s.yaxis.set_tick_params(labelsize=8) 21 | ax_u.xaxis.set_tick_params(labelsize=8) 22 | ax_u.yaxis.set_tick_params(labelsize=8) 23 | 24 | def work(METHOD, LABELS, run: str): 25 | 26 | load_now = F.partial(load, dir=f"log/ps/{run}") 27 | rate = load_now("periodic", "metric.rate", METHOD=METHOD) 28 | errr = load_now("periodic", "metric.errr", METHOD=METHOD) 29 | 30 | for method, label in zip(METHOD, LABELS): 31 | 32 | xs = np.linspace(1, max(map(np.sum, rate[method])), 1000) 33 | ys = [np.interp(xs, np.concatenate([np.zeros(1), np.cumsum(rate)]), errr) 34 | for rate, errr in zip(rate[method], errr[method])] 35 | 36 | ys_mean = np.mean(np.stack(ys, axis=0), axis=0) 37 | ys_std = np.std(np.stack(ys, axis=0), axis=0); ys_std = np.minimum(ys_std, ys_mean/2) 38 | 39 | err.plot(xs, ys_mean, label=label, c=(c:=color(method)), ls=lines(method)) 40 | err.fill_between(xs, ys_mean-ys_std, ys_mean+ys_std, alpha=0.2, color=c, 
edgecolor=None) 41 | 42 | err.set_xscale("log") 43 | err.set_yscale("log") 44 | 45 | err.set_xlim(1, 1e5) 46 | err.set_xticks([1, 1e1, 1e2, 1e3, 1e4, 1e5]) 47 | 48 | err.yaxis.set_major_formatter(tkr.FuncFormatter(lambda y, _: f"{y:.1%}"[:3])) 49 | 50 | work([":NSM", ":SNO"], ["NSM", "SNO"], "relu") 51 | work(["x64", "x128"], [f"FNO${TIMES}64^2\\ \\ $+PINN", f"FNO${TIMES}128^2$+PINN"], "tanh") 52 | work(["x256"], [f"FNO${TIMES}256^2$+PINN"], "long") 53 | 54 | err.set_title(f"$L_2$ rel. error (\\%) on $256{TIMES}256$ test res.", fontsize=10) 55 | err.legend(loc="lower right")#, handlelength=1.3, bbox_to_anchor=(1.5, 0.5), fontsize=8, labelspacing=0.3) 56 | err.xaxis.set_tick_params(labelsize=8) 57 | err.yaxis.set_tick_params(labelsize=8) 58 | 59 | ax_s.text(-0.11, 1.05, "(a)", transform=ax_s.transAxes, fontsize=10) 60 | ax_u.text(-0.11, 1.05, "(b)", transform=ax_u.transAxes, fontsize=10) 61 | err.text(-0.08, 1.05, "(c)", transform=err.transAxes, fontsize=10) 62 | 63 | fig.savefig(f"plot/poisson.{N}.jpg", dpi=300) 64 | 65 | # ---------------------------------------------------------------------------- # 66 | # REACTION # 67 | # ---------------------------------------------------------------------------- # 68 | 69 | nus = [0.005, 0.01, 0.05, 0.1] 70 | us = [np.load(f"src/pde/reaction/u.rho=5:nu={nu:.3f}.npy")[N] for nu in nus] 71 | 72 | fig, axes = plt.subplots(ncols=4, figsize=(7, 2), constrained_layout=True) 73 | for u, ax, nu in zip(us:=np.stack(us), axes, nus): 74 | 75 | ax.grid(False) 76 | im = ax.imshow(u, cmap="Spectral", origin="lower", 77 | vmin=us.min(), vmax=us.max(), 78 | extent=[0, 511/512, 0, 1]) 79 | 80 | ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8]) 81 | ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1]) 82 | 83 | ax.set_title(f"$\\nu={nu}$") 84 | 85 | plt.colorbar(im,fraction=0.046, pad=0.1) 86 | 87 | fig.savefig(f"plot/reaction.{N}.jpg", dpi=300) 88 | -------------------------------------------------------------------------------- /run/.sh: 
-------------------------------------------------------------------------------- 1 | XLA_PYTHON_CLIENT_PREALLOCATE=True \ 2 | XLA_PYTHON_CLIENT_MEM_FRACTION=.95 \ 3 | python main.py \ 4 | --seed ${seed:-$RANDOM} \ 5 | --hdim 64 \ 6 | --depth 4 \ 7 | --activate relu \ 8 | "$@" train \ 9 | --bs ${bs:-16} \ 10 | --lr ${lr:-1e-3} \ 11 | --schd ${schd:-exp} \ 12 | --iter ${iter:-30000} \ 13 | --vmap ${vmap:-""} \ 14 | --ckpt ${ckpt:-100} \ 15 | --note "${note:-$(date)}" -------------------------------------------------------------------------------- /run/navierstokes.forced.sh: -------------------------------------------------------------------------------- 1 | note=ns.T50/NSM."$seed" iter=200000 ckpt=1000 bash run/.sh --pde navierstokes.tf --model fno --hdim 32 --depth 5 --mode 18 31 31 --spectral 2 | note=ns.T50/x96."$seed" iter=200000 ckpt=1000 vmap=1 bash run/.sh --pde navierstokes.tf --model fno --hdim 32 --depth 5 --mode 18 31 31 --grid 96 -------------------------------------------------------------------------------- /run/navierstokes.sh: -------------------------------------------------------------------------------- 1 | for re in 4 3 2; do 2 | 3 | note=ns.T3/re"$re":NSM."$seed" iter=200000 ckpt=1000 bash run/.sh --pde navierstokes.re"$re" --model fno --hdim 32 --depth 10 --mode 12 31 31 --spectral 4 | note=ns.T3/re"$re"x64."$seed" iter=200000 ckpt=1000 vmap=1 bash run/.sh --pde navierstokes.re"$re" --model fno --hdim 32 --depth 10 --mode 12 31 31 --grid 64 5 | note=ns.T3/re"$re"x96."$seed" iter=200000 ckpt=1000 vmap=1 bash run/.sh --pde navierstokes.re"$re" --model fno --hdim 32 --depth 10 --mode 12 31 31 --grid 96 6 | 7 | done -------------------------------------------------------------------------------- /run/navierstokes.yaml: 
-------------------------------------------------------------------------------- 1 | name: ns 2 | workdir: .
3 | resources: 4 | 5 | cloud: gcp 6 | disk_size: 256 7 | accelerators: A100:1 8 | 9 | image_id: "projects/deeplearning-platform-release/global\ 10 | /images/common-cu113-v20230615-ubuntu-2004-py310" 11 | 12 | setup: | 13 | 14 | echo ==================== 15 | echo executing init...... 16 | echo ==================== 17 | 18 | JAXLIB=jaxlib-0.4.7+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64 19 | pip install https://storage.googleapis.com/jax-releases/cuda11/$JAXLIB.whl 20 | 21 | pip install -r run/requirements.txt 22 | 23 | echo ===================== 24 | echo generating data...... 25 | echo ===================== 26 | 27 | python -m src.pde.navierstokes.generate ns 28 | python -m src.pde.navierstokes.generate tf 29 | -------------------------------------------------------------------------------- /run/poisson.dirichlet.sh: -------------------------------------------------------------------------------- 1 | for seed in 0 1 2 3; do 2 | 3 | seed=$seed iter=100000 note=ps/relu/dirichlet:SNO."$seed" bash run/.sh --pde poisson.dirichlet --model sno --mode 31 31 --spectral 4 | seed=$seed iter=100000 note=ps/relu/dirichlet:NSM."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --spectral 5 | seed=$seed iter=100000 note=ps/relu/dirichletx64."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 64 6 | seed=$seed iter=100000 note=ps/relu/dirichletx128."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 128 7 | seed=$seed iter=100000 note=ps/relu/dirichletx256."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 256 8 | seed=$seed iter=100000 note=ps/relu/dirichletx256C."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 256 --cheb --vmap 16 9 | 10 | seed=$seed iter=100000 note=ps/tanh/dirichlet:NSM."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --spectral --activate tanh 11 | seed=$seed iter=100000 note=ps/tanh/dirichletx64."$seed" bash 
run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 64 --activate tanh 12 | seed=$seed iter=100000 note=ps/tanh/dirichletx128."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 128 --activate tanh 13 | seed=$seed iter=100000 note=ps/tanh/dirichletx256."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 256 --activate tanh 14 | seed=$seed iter=100000 note=ps/tanh/dirichletx256C."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 256 --activate tanh --cheb --vmap 16 15 | 16 | done -------------------------------------------------------------------------------- /run/poisson.periodic.sh: -------------------------------------------------------------------------------- 1 | for seed in 0 1 2 3; do 2 | 3 | seed=$seed note=ps/relu/periodic:SNO."$seed" bash run/.sh --pde poisson.periodic --model sno --mode 31 31 --spectral 4 | seed=$seed note=ps/relu/periodic:NSM."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --spectral 5 | seed=$seed note=ps/relu/periodicx64."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 64 6 | seed=$seed note=ps/relu/periodicx128."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 128 7 | seed=$seed note=ps/relu/periodicx256."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 256 8 | 9 | seed=$seed note=ps/tanh/periodicx64."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 64 --activate tanh 10 | seed=$seed note=ps/tanh/periodicx128."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 128 --activate tanh 11 | seed=$seed note=ps/tanh/periodicx256."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 256 --activate tanh 12 | 13 | iter=100000 \ 14 | seed=$seed note=ps/long/periodicx256."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 256 --activate tanh 15 | 16 | done 
-------------------------------------------------------------------------------- /run/reaction.sh: -------------------------------------------------------------------------------- 1 | function run() { 2 | note=rd/"$name":SNO."$seed" bash run/.sh --pde reaction."$name" --model sno --mode 32 64 --spectral 3 | note=rd/"$name":NSM."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --spectral 4 | note=rd/"$name"x64."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --grid 64 5 | note=rd/"$name"x128."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --grid 128 6 | note=rd/"$name"x256."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --grid 256 7 | note=rd/"$name"x256C."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --grid 256 --cheb 8 | note=rd/"$name"x256F."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --spectral --fourier 9 | } 10 | 11 | for seed in 0 1 2 3; do 12 | 13 | seed=$seed name=nu005 run 14 | seed=$seed name=nu01 run 15 | seed=$seed name=nu05 iter=100000 run 16 | seed=$seed name=nu1 iter=200000 run 17 | 18 | done -------------------------------------------------------------------------------- /run/requirements.txt: -------------------------------------------------------------------------------- 1 | ################################################################# 2 | # PLEASE INSTALL JAXLIB BASED ON YOUR OWN MACHINE CONFIGURATION # 3 | ################################################################# 4 | 5 | jax==0.4.7 6 | flax==0.7.0 7 | optax==0.1.5 8 | 9 | tqdm==4.65.0 10 | matplotlib==3.7.0 11 | scienceplots==2.1.0 12 | -------------------------------------------------------------------------------- /src/README: -------------------------------------------------------------------------------- 1 | Please refer to the README in each directory for details. 2 | 3 | FILE STRUCTURE 4 | 5 | |- main.py : entry point.
6 | |- ckpt.py : checkpoint subproc. 7 | |- plot/ : matplotlib utilities. 8 | |- run/ : command line scripts. 9 | |- src/ 10 | |- dists.py : common distributions. 11 | |- train.py : training routines: step & eval. 12 | |- utils.py : grid generation, differentiate schemes, etc. 13 | |- basis/ 14 | | |- fourier.py : 1-d Trigonometric series. 15 | | |- chebyshev.py : 1-d T type Chebyshev polynomials. 16 | |- pde/ 17 | | |- _domain.py : physical domain definition. N-d unit rect. 18 | | |- _params.py : parameterize PDE with interpolating series. 19 | | |- mollifier.py : how boundary condition applies to solution. 20 | |- model/ 21 | |- _base.py : shared modules. 22 | |- sno/spectral.py: SNO. 23 | |- fno/ : FNO. 24 | |- __init__.py : PINN. 25 | |- spectral.py : NSM (ours). 26 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | 5 | import operator as O 6 | import itertools as I 7 | import functools as F 8 | 9 | # ---------------------------------------------------------------------------- # 10 | # JAX # 11 | # ---------------------------------------------------------------------------- # 12 | 13 | import jax 14 | 15 | import flax 16 | import optax 17 | 18 | import jax.numpy as np 19 | import flax.linen as nn 20 | 21 | # ---------------------------------------------------------------------------- # 22 | # TYPE # 23 | # ---------------------------------------------------------------------------- # 24 | 25 | from abc import * 26 | from typing import * 27 | 28 | from jax import Array 29 | from flax import struct 30 | 31 | X = Union[Tuple["X", ...], List["X"], Array] 32 | ϴ = Union[struct.PyTreeNode, "X", None] 33 | 34 | Fx = Callable[..., "X"] # real-valued function 35 | Fϴ = Callable[..., "Fx"] # parametrized function 36 | 37 | # 
---------------------------------------------------------------------------- # 38 | # CONST # 39 | # ---------------------------------------------------------------------------- # 40 | 41 | e = np.e 42 | π = np.pi 43 | 44 | Δ = F.partial(np.einsum, "...ii -> ...") 45 | 46 | # ---------------------------------------------------------------------------- # 47 | # RANDOM # 48 | # ---------------------------------------------------------------------------- # 49 | 50 | from jax import random 51 | RNG = Dict[str, random.KeyArray] 52 | 53 | class RNGS(RNG): 54 | 55 | def __init__(self, prng: random.KeyArray, name: List[str]): 56 | keys = random.split(prng, len(name)) 57 | super().__init__(zip(name, keys)) 58 | 59 | def __next__(self) -> RNG: 60 | self.it = getattr(self, "it", 0) + 1 61 | return self.fold_in(self.it) 62 | 63 | def fold_in(self, key: Any) -> RNG: 64 | return { name: random.fold_in(data, hash(key)) 65 | for name, data in self.items() } 66 | -------------------------------------------------------------------------------- /src/basis/README: -------------------------------------------------------------------------------- 1 | This is a self-contained module of different orthogonal basis functions. `Basis` 2 | is the abstract class defined for a general N-d basis, in terms of its spectral 3 | coefficients. Each class is associated with a static `ndim`, the dimensionality 4 | of the basis. The coefficient array, `coef`, has at least `ndim` axes: 5 | 6 | - The first `ndim` axes correspond to the basis dimensions; 7 | 8 | - The remaining axes are interpreted as arbitrary channels, i.e. they are 9 | broadcast over by each operation, and have no special meaning to the basis. 10 | 11 | The `Basis` class requires the following implementations: 12 | 13 | - `ix`: defines ordered indices of `coef`s w.r.t. given number of modes. It 14 | is used to truncate or extend instances of basis to other modes. 15 | 16 | - `fn`: defines array of basis functions w.r.t.
given number of modes; 17 | 18 | - `grid`: defines the collocation points w.r.t. given number of modes; 19 | 20 | - `transform`/`inv`: transforms between coefficients and function values on 21 | the collocation points defined by the `grid` function; 22 | 23 | The `Basis` class provides the following functionalities: 24 | 25 | - `__call__`: evaluates function values at any position. `grid`-point values 26 | are identical to the results of the `inv`erse transform, but 27 | calling `inv` is usually faster due to the use of FFTs. 28 | 29 | - `to`: aligns the instance of basis function to another mode; larger modes 30 | are truncated, and smaller modes are padded with zeros. 31 | 32 | - `add`/`mul`: sums / multiplies basis functions from the same class. Given 33 | operands are aligned to the same number of modes first. 34 | 35 | - `grad`/`int`: computes the derivatives and indefinite integrals along each 36 | dim. The resulting function has an extra trailing dimension 37 | to it, representing the operation on each of the `ndim`s. 38 | 39 | IMPLEMENTATIONS 40 | 41 | Inheriting from `Basis`, `Series` specializes it to a 1-d basis, which is further 42 | realized by the `fourier` and `chebyshev` bases. Built on any sequence of `Series`, 43 | the class factory function `series` recursively instantiates a `Basis` class on 44 | top of the given classes of 1-d series. 45 | 46 | **SHARP BITS** 47 | 48 | 1. Taking derivatives on a Fourier basis might not be lossless if the number 49 | of modes is even. I'm using a compact way of storing Fourier coefficients by 50 | squashing the Hermitian spectrum into reals of the same shape. Therefore the 51 | last coefficient is on its own, and it is dropped when taking the gradient. 52 | 53 | 2. The (inverse) transform of the Chebyshev basis does not work with singletons. I'm 54 | using a Hermitian FFT, which returns a zero-sized array for a size-one input. Try 55 | broadcasting the input to a size of at least two before transforming.
56 | -------------------------------------------------------------------------------- /src/basis/__init__.py: -------------------------------------------------------------------------------- 1 | from .. import * 2 | from .. import utils 3 | 4 | @struct.dataclass 5 | class Basis(ABC): 6 | 7 | coef: Array 8 | 9 | """ 10 | Basis function on [0, 1]^ndim. The leading `ndim` axes are for: 11 | 12 | - (spectral) coefficients of basis functions 13 | - (physical) values evaluated on collocation points 14 | """ 15 | 16 | @staticmethod 17 | 18 | @abstractmethod 19 | def repr(self) -> str: pass 20 | 21 | @staticmethod 22 | 23 | @abstractmethod 24 | def ndim() -> int: pass 25 | 26 | @property 27 | def mode(self): return self.coef.shape[:self.ndim()] 28 | 29 | def map(self, f: Fx): return self.__class__(f(self.coef)) 30 | 31 | @staticmethod 32 | 33 | @abstractmethod 34 | def grid(*mode: int) -> X: pass 35 | 36 | @staticmethod 37 | 38 | @abstractmethod 39 | def ix(*mode: int) -> X: pass 40 | 41 | @staticmethod 42 | 43 | @abstractmethod 44 | def fn(*mode: int, x: X) -> X: pass 45 | def __call__(self, x: X) -> X: 46 | 47 | assert x.shape[-1] == self.ndim(), f"{x.shape[-1]=} =/= {self.ndim()=}" 48 | return np.tensordot(self.fn(*self.mode, x=x), self.coef, self.ndim()) 49 | 50 | def to(self, *mode: int): 51 | 52 | if self.mode == mode: return self 53 | ax = self.ix(*map(min, mode, self.mode)) 54 | 55 | zero = np.zeros(mode + self.coef.shape[self.ndim():]) 56 | return self.__class__(zero.at[ax].set(self.coef[ax])) 57 | 58 | @classmethod 59 | def add(cls, *terms): return cls(sum(map(O.attrgetter("coef"), align(*terms, scheme=max)))) 60 | 61 | @classmethod 62 | def mul(cls, *terms): return cls.transform(math.prod(map(cls.inv, align(*terms, scheme=sum)))) 63 | 64 | # --------------------------------- TRANSFORM -------------------------------- # 65 | 66 | @staticmethod 67 | 68 | @abstractmethod 69 | def transform(x: X): pass 70 | 71 | @abstractmethod 72 | def inv(self) -> X: pass 73 | 74 
| # --------------------------------- OPERATOR --------------------------------- # 75 | 76 | @abstractmethod 77 | def grad(self, k: int = 1): pass 78 | 79 | @abstractmethod 80 | def int(self, k: int = 1): pass 81 | 82 | def align(*basis: Basis, scheme: Fx = max) -> Tuple[Basis]: 83 | 84 | # asserting uniform properties: 85 | 86 | _ = set(map(lambda cls: cls.repr(), basis)) 87 | _ = set(map(lambda cls: cls.ndim(), basis)) 88 | _ = set(map(lambda self: self.coef.ndim, basis)) 89 | 90 | mode = tuple(map(scheme, zip(*map(O.attrgetter("mode"), basis)))) 91 | return tuple(map(lambda self: self.to(*mode), basis)) 92 | 93 | # ---------------------------------------------------------------------------- # 94 | # SERIES # 95 | # ---------------------------------------------------------------------------- # 96 | 97 | class SeriesMeta(ABCMeta, type): 98 | 99 | def __getitem__(cls, n: int): 100 | return series(*(cls,)*n) 101 | 102 | @struct.dataclass 103 | class Series(Basis, metaclass=SeriesMeta): 104 | 105 | """1-dimensional series on interval""" 106 | 107 | @staticmethod 108 | def ndim() -> int: return 1 # on [0, 1] 109 | def __len__(self): return len(self.coef) 110 | 111 | @abstractmethod 112 | def __getitem__(self, s: int) -> X: pass 113 | 114 | def series(*types: Type[Series]) -> Type[Basis]: 115 | 116 | """ 117 | Generate new basis using finite product of given series. Each argument 118 | type corresponds to certain kind of series used for each dimension. 
119 | """ 120 | 121 | @struct.dataclass 122 | class Class(Basis): 123 | 124 | @staticmethod 125 | def repr() -> str: return "".join(map(O.methodcaller("repr"), types)) 126 | 127 | @staticmethod 128 | def ndim() -> int: return len(types) 129 | 130 | @staticmethod 131 | def grid(*mode: int) -> X: 132 | 133 | assert len(mode) == len(types) 134 | 135 | axes = mesh(lambda i, cls: cls.grid(mode[i]).squeeze(1)) 136 | return np.stack(axes, axis=-1) 137 | 138 | def ix(self, *mode: int) -> X: 139 | 140 | return np.ix_(*map(lambda self, n: self.ix(n), types, mode)) 141 | 142 | def fn(self, *mode: int, x: X) -> X: 143 | 144 | axes = mesh(lambda i, self: self.fn(mode[i], x=x[..., [i]])) 145 | return np.product(np.stack(axes, axis=-1), axis=-1) 146 | 147 | def __getitem__(self, s: Tuple[int]) -> X: 148 | 149 | return jax.vmap(F.partial(Super.__getitem__, s=s[1:]))(Super(Self(self.coef)[s[0]])) 150 | 151 | # --------------------------------- TRANSFORM -------------------------------- # 152 | 153 | @staticmethod 154 | def transform(x: X): 155 | 156 | return Class(Self.transform(jax.vmap(Super.transform)(x).coef).coef) 157 | 158 | def inv(self) -> X: 159 | 160 | return jax.vmap(Super.inv)(Super(Self(self.coef).inv())) 161 | 162 | # --------------------------------- OPERATOR --------------------------------- # 163 | 164 | def grad(self, k: int = 1): 165 | 166 | coef = jax.vmap(F.partial(Super.grad, k=k))(Super(self.coef)).coef 167 | return Class(np.concatenate([Self(self.coef).grad(k).coef, coef], axis=-1)) 168 | 169 | def int(self, k: int = 1): 170 | 171 | coef = jax.vmap(F.partial(Super.int, k=k))(Super(self.coef)).coef 172 | return Class(np.concatenate([Self(self.coef).int(k).coef, coef], axis=-1)) 173 | 174 | def mesh(call: Fx) -> Tuple[X]: 175 | def cat(*x: X) -> Tuple[X]: 176 | 177 | n, = set(map(np.ndim, x)) 178 | 179 | if n != 1: return jax.vmap(cat)(*x) 180 | return np.meshgrid(*x, indexing="ij") 181 | 182 | args = zip(*enumerate(types)) 183 | return cat(*map(call, *args)) 
184 | 185 | try: cls, = types; return cls 186 | except: 187 | 188 | Self, *other = types 189 | Super = series(*other) 190 | 191 | return Class 192 | -------------------------------------------------------------------------------- /src/basis/chebyshev.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | @struct.dataclass 4 | class Chebyshev(Series): 5 | 6 | """ 7 | Chebyshev polynomial of T kind 8 | - Tn(x) = cos(n cos^-1(x)) 9 | - Tn^*(x) = Tn(2 x - 1) 10 | """ 11 | 12 | @staticmethod 13 | def repr() -> str: return "C" 14 | 15 | @staticmethod 16 | def grid(n: int) -> X: return np.cos(π * utils.grid(n))/2+0.5 17 | 18 | @staticmethod 19 | def ix(n: int) -> X: return np.arange(n) 20 | 21 | @staticmethod 22 | def fn(n: int, x: X) -> X: return np.cos(np.arange(n) * np.arccos(x*2-1)) 23 | 24 | def __getitem__(self, s: int) -> X: 25 | if isinstance(s, Tuple): s, = s 26 | return self(utils.grid(s)) 27 | 28 | # --------------------------------- TRANSFORM -------------------------------- # 29 | 30 | @staticmethod 31 | def transform(x: X): 32 | 33 | coef = np.fft.hfft(x, axis=0, norm="forward")[:len(x)] 34 | coef = coef.at[1:-1].multiply(2) 35 | 36 | assert len(x) > 1, "sharp bits!" 
37 | return Chebyshev(coef) 38 | 39 | def inv(self) -> X: 40 | 41 | coef = self.coef.at[1:-1].divide(2) 42 | coef = np.concatenate([coef, coef[::-1][1:-1]]) 43 | 44 | return np.fft.ihfft(coef, axis=0, norm="forward").real 45 | 46 | # --------------------------------- OPERATOR --------------------------------- # 47 | 48 | def grad(self, k: int = 1): 49 | 50 | coef = np.linalg.matrix_power(np.pad(gradient(len(self)), [(0, 1), (0, 0)]), k) 51 | return Chebyshev(np.tensordot(coef, self.coef, (1, 0))[..., np.newaxis]) 52 | 53 | def int(self, k: int = 1): 54 | 55 | coef = np.linalg.matrix_power(integrate(len(self))[:-1], k) 56 | return Chebyshev(np.tensordot(coef, self.coef, (1, 0))[..., np.newaxis]) 57 | 58 | # ---------------------------------------------------------------------------- # 59 | # MATRIX # 60 | # ---------------------------------------------------------------------------- # 61 | 62 | """ 63 | Chebyshev gradient and integrate matrix 64 | 65 | - gradient ∈ R ^ n-1⨉n; integrate ∈ R ^ n+1⨉n 66 | - When aligned, they are pseudo-inverse of each other: 67 | `gradient(n+1) @ integrate(n) == identity(n)` 68 | """ 69 | 70 | def gradient(n: int) -> X: 71 | 72 | alternate = np.pad(np.eye(2), [(0, n-3), (0, n-3)], mode="reflect").at[0].divide(2) 73 | coef = np.concatenate([np.zeros(n - 1)[:, np.newaxis], np.triu(alternate)], axis=1) 74 | 75 | return coef * 4 * np.arange(n) 76 | 77 | def integrate(n: int) -> X: 78 | 79 | shift = np.identity(n).at[0, 0].set(2) - np.eye(n, k=2) 80 | coef = np.concatenate([np.zeros(n)[np.newaxis], shift]) 81 | 82 | return coef.at[1:].divide(4 * np.arange(1, n+1)[:, np.newaxis]) 83 | -------------------------------------------------------------------------------- /src/basis/fourier.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | 3 | @struct.dataclass 4 | class Fourier(Series): 5 | 6 | """ 7 | Trigonometric series 8 | 9 | fk(x) = e^{ik·x} 10 | """ 11 | 12 | @staticmethod 13 | def repr() -> str: return "F" 14 | 15 | @staticmethod 16 | def grid(n: int) -> X: return utils.grid(n, mode="left") 17 | 18 | @staticmethod 19 | def ix(n: int) -> X: return np.r_[-n//2+1:n//2+1] 20 | 21 | @staticmethod 22 | def fn(n: int, x: X) -> X: 23 | 24 | return np.moveaxis(real(np.moveaxis(np.exp(x * -freq(n)), -1, 0), n), 0, -1) 25 | 26 | def __getitem__(self, s: int) -> X: 27 | if isinstance(s, Tuple): s, = s 28 | 29 | return np.concatenate([x:=self.to(s - 1).inv(), x[np.r_[0]]]) 30 | 31 | # --------------------------------- TRANSFORM -------------------------------- # 32 | 33 | @staticmethod 34 | def transform(x: X): 35 | 36 | coef = np.fft.rfft(x, axis=0, norm="forward") 37 | coef = coef.at[1:-(len(x)//-2)].multiply(2) 38 | 39 | return Fourier(real(coef, len(x))) 40 | 41 | def inv(self) -> X: 42 | 43 | coef = comp(self.coef, n:=len(self)) 44 | coef = coef.at[1:-(n//-2)].divide(2) 45 | 46 | return np.fft.irfft(coef, len(self), axis=0, norm="forward") 47 | 48 | # --------------------------------- OPERATOR --------------------------------- # 49 | 50 | def grad(self, k: int = 1): 51 | 52 | coef = np.expand_dims(freq(len(self))**k, range(1, self.coef.ndim)) 53 | return Fourier(real(comp(self.coef, n:=len(self)) * coef, n)[..., None]) 54 | 55 | def int(self, k: int = 1): 56 | 57 | coef = np.expand_dims(self.freq(len(self)), range(1, self.coef.ndim)) 58 | return Fourier((self.coef / coef ** k)[..., np.newaxis].at[0].set(0)) 59 | 60 | # ---------------------------------------------------------------------------- # 61 | # HELPER # 62 | # ---------------------------------------------------------------------------- # 63 | 64 | def freq(n: int) -> X: return np.arange(n//2+1) * 2j * π 65 | 66 | def real(coef: X, n: int) -> X: 67 | 68 | """Complex coef -> Real coef""" 69 | 70 | cos, sin = coef.real, 
coef.imag[1:-(n//-2)] 71 | return np.concatenate((cos, sin[::-1]), 0) 72 | 73 | def comp(coef: X, n: int) -> X: 74 | 75 | """Real coef -> Complex coef""" 76 | 77 | cos, sin = np.split(coef, (m:=n//2+1, )) 78 | return (cos+0j).at[n-m:0:-1].add(sin*1j) 79 | -------------------------------------------------------------------------------- /src/dists.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | class Ω(ABC): 4 | 5 | """ 6 | Distribution 7 | """ 8 | 9 | @abstractmethod 10 | def sample(self, prng) -> X: pass 11 | 12 | # ---------------------------------------------------------------------------- # 13 | # UNIFORM # 14 | # ---------------------------------------------------------------------------- # 15 | 16 | class Uniform: 17 | 18 | """ 19 | Uniform distribution 20 | """ 21 | 22 | min: X 23 | max: X 24 | 25 | def __init__(self, min: X, max: X): 26 | 27 | self.min = np.array(min) 28 | self.max = np.array(max) 29 | 30 | def sample(self, prng, shape=()) -> X: 31 | 32 | scale = self.max - self.min 33 | x = random.uniform(prng, shape + scale.shape) 34 | 35 | return x * scale + self.min 36 | 37 | # ---------------------------------------------------------------------------- # 38 | # NORMAL # 39 | # ---------------------------------------------------------------------------- # 40 | 41 | class Normal: 42 | 43 | """ 44 | Normal distribution 45 | """ 46 | 47 | μ: X 48 | λ: X 49 | 50 | def __init__(self, μ: X, Σ: X): 51 | 52 | self.μ = μ 53 | 54 | U, Λ, _ = np.linalg.svd(Σ) 55 | self.λ = U * np.sqrt(Λ) 56 | 57 | def sample(self, prng, shape=()) -> X: 58 | 59 | var = random.normal(prng, shape + self.μ.shape) 60 | ε = np.einsum("...ij,...j->...i", self.λ, var) 61 | 62 | return self.μ + ε 63 | 64 | # ---------------------------------------------------------------------------- # 65 | # GAUSSIAN # 66 | # ---------------------------------------------------------------------------- # 67 | 68 | class Gaussian(Normal): 69 | 
70 | """ 71 | Gaussian Process 72 | """ 73 | 74 | dim: Tuple[int] 75 | 76 | def __init__(self, grid: X, kernel: Fx): 77 | 78 | *dim, ndim = grid.shape 79 | assert len(dim) == ndim 80 | 81 | X = grid.reshape(-1, ndim) 82 | K = jax.vmap(kernel, (0, None)) 83 | Σ = jax.vmap(lambda y: K(X, y))(X) 84 | 85 | super().__init__(np.zeros(len(Σ)), Σ) 86 | self.dim = tuple(dim) 87 | 88 | def sample(self, prng, shape=()) -> X: 89 | 90 | x = super().sample(prng, shape) 91 | return x.reshape(shape + self.dim) 92 | 93 | # ---------------------------------- KERNEL ---------------------------------- # 94 | 95 | RBF = lambda ƛ: lambda x, y: np.exp(-np.sum((x-y)**2) / ƛ**2/2) 96 | Per = lambda ƛ: lambda x, y: np.exp(-np.sum((np.sin(π*(x-y))/2)**2) / ƛ**2*2) 97 | -------------------------------------------------------------------------------- /src/model/README: -------------------------------------------------------------------------------- 1 | This module implements SNO, FNO and NSM. The `_base` module defines shared 2 | classes used by neural operators, namely the spectral convolution layers. It is 3 | a general implementation for arbitrary input dimensions, defined by the `Basis` 4 | class. Similarly, the SNO and FNO models are also compatible with any dimension. 5 | 6 | Each model is associated with a `loss` function, which defines how the loss is 7 | obtained for a given sample of the parameter function. For an instantiation of 8 | the PDE class, this function returns a dictionary of the loss terms separately. 9 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from ..
import * 2 | from ..pde import * 3 | from ..basis import * 4 | 5 | # ---------------------------------------------------------------------------- # 6 | # SOLVER # 7 | # ---------------------------------------------------------------------------- # 8 | 9 | class Solver(ABC, nn.Module): 10 | 11 | pde: PDE 12 | cfg: Dict 13 | 14 | @F.cached_property 15 | def activate(self) -> Fx: return \ 16 | getattr(nn, self.cfg["activate"]) 17 | 18 | @abstractmethod 19 | def u(self, ϕ: Fx, x: X) -> X: pass 20 | 21 | @abstractmethod 22 | def loss(self, ϕ: Fx) -> Dict[str, X]: pass 23 | 24 | # ----------------------------------- PINN ----------------------------------- # 25 | 26 | class PINN(Solver, ABC): 27 | 28 | @abstractmethod 29 | def forward(self, ϕ: Basis, s: Tuple[int]) -> Tuple[X, X]: pass 30 | 31 | def u(self, ϕ: Basis, x: Tuple[int]) -> X: 32 | 33 | assert isinstance(x, Tuple), "uniform" 34 | assert len(x) == self.pde.domain.ndim 35 | 36 | return self.forward(ϕ, x)[-1] 37 | 38 | def loss(self, ϕ: Basis) -> Dict[str, X]: 39 | 40 | x, y, u = self.forward(ϕ, (self.cfg["grid"], ) * (d:=self.pde.domain.ndim)) 41 | edges = [(np.take(u, 0, axis=n), np.take(u, -1, axis=n)) for n in range(d)] 42 | 43 | R = self.pde.equation(x, y, u) 44 | B = self.pde.boundary(edges) 45 | 46 | return dict( 47 | residual=np.mean(R**2), **{ 48 | f"boundary{n}": np.mean(Bn**2) 49 | for n, Bn in enumerate(B) 50 | }) 51 | 52 | # --------------------------------- SPECTRAL --------------------------------- # 53 | 54 | class Spectral(Solver, ABC): 55 | 56 | @abstractmethod 57 | def forward(self, ϕ: Basis) -> Basis: pass 58 | 59 | def u(self, ϕ: Basis, x: X) -> X: 60 | 61 | u = self.forward(ϕ) 62 | 63 | if isinstance(x, Tuple): return u[x] 64 | if isinstance(x, Array): return u(x) 65 | 66 | def loss(self, ϕ: Basis) -> Dict[str, X]: 67 | 68 | R = self.pde.spectral(ϕ, self.forward(ϕ)) 69 | return dict(residual=np.sum(np.square(R.coef))) 70 | 71 | # 
---------------------------------------------------------------------------- # 72 | # TRAINER # 73 | # ---------------------------------------------------------------------------- # 74 | 75 | class Trainer(ABC, nn.Module): 76 | 77 | mod: Solver 78 | pde: PDE 79 | cfg: Dict 80 | 81 | def setup(self): 82 | 83 | # --------------------------------- SCHEDULER -------------------------------- # 84 | 85 | if self.cfg["schd"] is None: 86 | 87 | scheduler = self.cfg["lr"] 88 | 89 | if self.cfg["schd"] == "cos": 90 | 91 | scheduler = optax.cosine_decay_schedule(self.cfg["lr"], self.cfg["iter"]) 92 | 93 | if self.cfg["schd"] == "exp": 94 | 95 | decay_rate = 1e-3 ** (1.0 / self.cfg["iter"]) 96 | scheduler = optax.exponential_decay(self.cfg["lr"], 1, decay_rate) 97 | 98 | # --------------------------------- OPTIMIZER -------------------------------- # 99 | 100 | self.optimizer = optax.adam(scheduler) 101 | 102 | @nn.compact 103 | def init(self): 104 | 105 | ϕ = self.pde.params.sample(prng:=self.make_rng("sample")) 106 | s = tuple([self.cfg["grid"]] * self.pde.domain.ndim) 107 | 108 | variable = self.mod.init(prng, ϕ, s, method="u") 109 | print(self.mod.tabulate(prng, ϕ, s, method="u")) 110 | 111 | self.variable("optim", "state", self.optimizer.init, variable["params"]) 112 | 113 | return variable 114 | -------------------------------------------------------------------------------- /src/model/_base.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | from ..basis import * 3 | 4 | class SpectralConv(nn.Module): 5 | 6 | odim: int 7 | 8 | # do truncation or not 9 | mode: Tuple[int] = None 10 | 11 | # initialization modes 12 | init: Tuple[int] = None 13 | 14 | @nn.compact 15 | def __call__(self, u: Basis) -> Basis: 16 | 17 | def W(a: X) -> X: 18 | 19 | def init(prng, *shape: Tuple[int]) -> X: 20 | x = random.normal(prng, shape) 21 | 22 | *mode, idim, odim = shape 23 | 24 | if self.init is None: 25 | 26 | scale = 1 / idim / odim 27 | 28 | else: 29 | 30 | from math import prod 31 | rate = prod(self.init) / prod(mode) 32 | 33 | scale = np.sqrt(rate / idim) 34 | 35 | return x * scale 36 | 37 | W = self.param("W", init, *a.shape, self.odim) 38 | 39 | dims = (N:=u.ndim(), N), (B:=range(N), B) 40 | return jax.lax.dot_general(a, W, dims) 41 | 42 | mode = self.mode or u.mode 43 | return u.to(*mode).map(W).to(*u.mode) 44 | 45 | class SpatialMixing(nn.Module): 46 | 47 | @nn.compact 48 | def __call__(self, u: Basis) -> Basis: 49 | 50 | def M(a: X, i: int) -> X: 51 | 52 | def init(prng, *shape: Tuple[int]) -> X: 53 | x = random.normal(prng, shape) 54 | 55 | return x / np.sqrt(shape[-1]) 56 | 57 | M = self.param(f"M{i}", init, *a.shape, a.shape[i]) 58 | 59 | batch = [*range(i), *range(i + 1, u.ndim()), u.ndim()] 60 | a = jax.lax.dot_general(a, M, ((i, i), (batch, batch))) 61 | 62 | return np.moveaxis(a, -1, i) 63 | 64 | return u.map(F.partial(F.reduce, M, range(u.ndim()))) 65 | -------------------------------------------------------------------------------- /src/model/fno/__init__.py: -------------------------------------------------------------------------------- 1 | from .. 
import * 2 | from .._base import * 3 | from ...basis.fourier import * 4 | 5 | # ---------------------------------------------------------------------------- # 6 | # SOLVER # 7 | # ---------------------------------------------------------------------------- # 8 | 9 | class FNO(PINN): 10 | 11 | def __repr__(self): return f"FNOx{self.cfg['grid']}+PINN" 12 | 13 | @nn.compact 14 | def forward(self, ϕ: Basis, s: Tuple[int]) -> Tuple[X, X]: 15 | 16 | if self.cfg["cheb"]: T = self.pde.basis 17 | else: T = Fourier[self.pde.domain.ndim] 18 | 19 | from .. import utils 20 | x = utils.grid(*s) 21 | 22 | z = np.concatenate([x, y:=ϕ[s]], -1) 23 | 24 | z = nn.Dense(self.cfg["hdim"] * 4)(z) 25 | z = self.activate(z) 26 | 27 | z = nn.Dense(self.cfg["hdim"])(z) 28 | z = self.activate(z) 29 | 30 | for _ in range(self.cfg["depth"]): 31 | 32 | conv = SpectralConv(self.cfg["hdim"], self.cfg["mode"])(T.transform(z)).inv() 33 | fc = nn.Dense(self.cfg["hdim"])(z) 34 | 35 | z = conv + fc 36 | z = self.activate(z) 37 | 38 | z = nn.Dense(self.cfg["hdim"])(z) 39 | z = self.activate(z) 40 | 41 | z = nn.Dense(self.cfg["hdim"] * 4)(z) 42 | z = self.activate(z) 43 | 44 | z = nn.Dense(self.pde.odim)(z) 45 | return x, y, self.pde.mollifier(y, (x, z)) 46 | -------------------------------------------------------------------------------- /src/model/fno/spectral.py: -------------------------------------------------------------------------------- 1 | from .. 
import * 2 | from .._base import * 3 | from ...basis.fourier import * 4 | 5 | # ---------------------------------------------------------------------------- # 6 | # SOLVER # 7 | # ---------------------------------------------------------------------------- # 8 | 9 | class FNO(Spectral): 10 | 11 | def __repr__(self): return "NSM" 12 | 13 | @nn.compact 14 | def forward(self, ϕ: Basis) -> Basis: 15 | 16 | if not self.cfg["fourier"]: T = self.pde.basis 17 | else: T = Fourier[self.pde.domain.ndim] 18 | 19 | u = ϕ.to(*self.cfg["mode"]) 20 | 21 | bias = T.transform(u.grid(*u.mode)).coef 22 | u = u.map(lambda coef: np.concatenate([coef, bias], axis=-1)) 23 | 24 | u = u.map(nn.Dense(self.cfg["hdim"] * 4)) 25 | u = T.transform(self.activate(u.inv())) 26 | 27 | u = u.map(nn.Dense(self.cfg["hdim"])) 28 | u = T.transform(self.activate(u.inv())) 29 | 30 | for _ in range(self.cfg["depth"]): 31 | 32 | conv = SpectralConv(self.cfg["hdim"])(u) 33 | fc = u.map(nn.Dense(self.cfg["hdim"])) 34 | 35 | u = T.add(conv, fc) 36 | u = T.transform(self.activate(u.inv())) 37 | 38 | u = u.map(nn.Dense(self.cfg["hdim"])) 39 | u = T.transform(self.activate(u.inv())) 40 | 41 | u = u.map(nn.Dense(self.cfg["hdim"] * 4)) 42 | u = T.transform(self.activate(u.inv())) 43 | 44 | u = u.map(nn.Dense(self.pde.odim)) 45 | return self.pde.mollifier(ϕ, u) 46 | -------------------------------------------------------------------------------- /src/model/sno/spectral.py: -------------------------------------------------------------------------------- 1 | from .. 
import * 2 | from .._base import * 3 | 4 | # ---------------------------------------------------------------------------- # 5 | # SOLVER # 6 | # ---------------------------------------------------------------------------- # 7 | 8 | class SNO(Spectral): 9 | 10 | def __repr__(self): return "SNO" 11 | 12 | @nn.compact 13 | def forward(self, ϕ: Basis) -> Basis: 14 | 15 | u = ϕ.to(*self.cfg["mode"]) 16 | 17 | bias = u.transform(u.grid(*u.mode)).coef 18 | u = u.map(lambda coef: np.concatenate([coef, bias], axis=-1)) 19 | 20 | u = u.map(nn.Dense(self.cfg["hdim"] * 4)) 21 | u = u.map(self.activate) 22 | 23 | u = u.map(nn.Dense(self.cfg["hdim"])) 24 | u = u.map(self.activate) 25 | 26 | for _ in range(self.cfg["depth"]): 27 | 28 | def Integral(coef: X) -> X: 29 | 30 | K = nn.DenseGeneral(u.mode, axis=range(-u.ndim(), 0)) 31 | return np.moveaxis(K(np.moveaxis(coef, -1, 0)), 0, -1) 32 | 33 | u = u.map(Integral) 34 | 35 | u = u.map(nn.Dense(self.cfg["hdim"])) 36 | u = u.map(self.activate) 37 | 38 | u = u.map(nn.Dense(self.cfg["hdim"])) 39 | u = u.map(self.activate) 40 | 41 | u = u.map(nn.Dense(self.cfg["hdim"] * 4)) 42 | u = u.map(self.activate) 43 | 44 | u = u.map(nn.Dense(self.pde.odim)) 45 | return self.pde.mollifier(ϕ, u) 46 | -------------------------------------------------------------------------------- /src/pde/README: -------------------------------------------------------------------------------- 1 | This module defines the interface of PDE classes. Abstract classes in the `_domain` 2 | and `_params` modules define the domain and PDE parameters. Functions in `mollifier` 3 | define how each type of boundary condition can be satisfied by transforming the 4 | predicted solution. Built on these classes, the general `PDE` class defines how 5 | the residual terms are computed in both the physical and spectral domains, which 6 | will be used by the models for loss calculation.
7 | 8 | Each subdirectory contains an instantiation of this interface, together with a 9 | `generate` module for numerical solution generation. Follow these templates if 10 | you want to try a new problem. 11 | -------------------------------------------------------------------------------- /src/pde/__init__.py: -------------------------------------------------------------------------------- 1 | from .. import * 2 | 3 | class PDE(ABC): 4 | 5 | from ._domain import R 6 | from ._params import G 7 | 8 | odim: int # output dimension 9 | 10 | domain: R # interior domain 11 | params: G # parameter function 12 | 13 | mollifier: Fx # transformation 14 | 15 | equation: Fx # PDE (equation) 16 | boundary: Fx # PDE (boundary) 17 | 18 | from ..basis import Basis 19 | basis: Basis # basis function 20 | spectral: Fx # PDE (spectral) 21 | 22 | solution: Any 23 | -------------------------------------------------------------------------------- /src/pde/_domain.py: -------------------------------------------------------------------------------- 1 | from .
import * 2 | from ..dists import * 3 | 4 | class R(Ω): 5 | 6 | """ 7 | Euclidean space 8 | """ 9 | 10 | ndim: int 11 | boundary: List[Ω] 12 | 13 | # ---------------------------------------------------------------------------- # 14 | # RECT # 15 | # ---------------------------------------------------------------------------- # 16 | 17 | class Rect(Uniform, R): 18 | 19 | """ 20 | N-d unit rectangle 21 | """ 22 | 23 | def __init__(self, ndim: int): 24 | super().__init__(np.zeros(ndim), 25 | np.ones(ndim)) 26 | 27 | class Boundary(Uniform): 28 | 29 | def __init__(self, dim: int): 30 | super().__init__(np.zeros(ndim-1), 31 | np.ones(ndim-1)) 32 | 33 | self.dim = dim 34 | 35 | def sample(self, prng, shape=()) -> X: 36 | 37 | x = super().sample(prng, shape) 38 | return np.insert(x, self.dim, np.zeros(shape), axis=-1), \ 39 | np.insert(x, self.dim, np.ones(shape), axis=-1) 40 | 41 | self.ndim = ndim 42 | self.boundary = list(map(Boundary, range(ndim))) 43 | -------------------------------------------------------------------------------- /src/pde/_params.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | from ..dists import * 3 | 4 | class G(ABC): 5 | 6 | """ 7 | Function space 8 | """ 9 | 10 | idim: int 11 | odim: int 12 | 13 | @abstractmethod 14 | def sample(self, prng) -> Fx: pass 15 | 16 | # ---------------------------------------------------------------------------- # 17 | # INTERPOLATE # 18 | # ---------------------------------------------------------------------------- # 19 | 20 | from ..basis import * 21 | class Interpolate(G): 22 | 23 | """ 24 | Interpolated function 25 | """ 26 | 27 | def __init__(self, dist: Ω, basis: Type[Basis]): 28 | 29 | self.dist = dist 30 | self.basis = basis 31 | 32 | self.idim = len(dist.dim) 33 | self.odim = 1 34 | 35 | def sample(self, prng, shape=()) -> Basis: 36 | 37 | x = self.dist.sample(prng, shape)[..., None] 38 | return utils.nmap(self.basis.transform, len(shape))(x) 39 | -------------------------------------------------------------------------------- /src/pde/mollifier.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | from ..basis import * 3 | 4 | SCALE = 1e-3 5 | 6 | # ---------------------------------------------------------------------------- # 7 | # PERIODIC # 8 | # ---------------------------------------------------------------------------- # 9 | 10 | def periodic(ϕ: X, u: X) -> X: 11 | 12 | if isinstance(ϕ, Basis): 13 | base = (0, ) * u.ndim() 14 | origin = np.array(base) 15 | 16 | return u.map(lambda coef: coef.at[base].add(-u(origin)) * SCALE) 17 | 18 | if isinstance(ϕ, Array): 19 | x, uofx = u 20 | 21 | return (uofx - uofx[(0,)*(uofx.ndim-1)]) * SCALE 22 | 23 | # ---------------------------------------------------------------------------- # 24 | # DIRICHLET # 25 | # ---------------------------------------------------------------------------- # 26 | 27 | def dirichlet(ϕ: X, u: X) -> X: 28 | 29 | if isinstance(ϕ, Basis): 30 | x = u.grid(*u.mode) 31 | 32 | mol = np.prod(np.sin(π*x), axis=-1, keepdims=True) 33 | return u.transform(u.inv() * mol * SCALE) 34 | 35 | if isinstance(ϕ, Array): 36 | x, uofx = u 37 | 38 | mol = np.prod(np.sin(π*x), axis=-1, keepdims=True) 39 | return uofx * mol * SCALE 40 | 41 | # ---------------------------------------------------------------------------- # 42 | # INITIAL-CONDITION # 43 | # ---------------------------------------------------------------------------- # 44 | 45 | def initial_condition(ϕ: X, u: X) -> X: 46 | 47 | """ 48 | Initial condition problem. The first dimension is temporal and the rest 49 | of them have periodic boundaries. 50 | """ 51 | 52 | if isinstance(ϕ, Basis): 53 | 54 | mol = u.grid(*u.mode)[..., [0]] * SCALE 55 | return u.__class__.add(u.transform(u.inv() * mol), ϕ) 56 | 57 | if isinstance(ϕ, Array): 58 | x, uofx = u 59 | 60 | mol = x[..., [0]] * SCALE 61 | return uofx * mol + ϕ 62 | -------------------------------------------------------------------------------- /src/pde/navierstokes/__init__.py: -------------------------------------------------------------------------------- 1 | from .. 
import * 2 | from ...dists import * 3 | from ...basis import * 4 | 5 | from .._domain import * 6 | from .._params import * 7 | 8 | from ...basis.fourier import * 9 | from ...basis.chebyshev import * 10 | 11 | def k(nx: int, ny: int) -> Tuple[X]: 12 | 13 | return np.tile(np.fft.fftfreq(nx)[:, None] * nx * 2*π, (1, ny)), \ 14 | np.tile(np.fft.fftfreq(ny)[None, :] * ny * 2*π, (nx, 1)) 15 | 16 | def velocity(what: X = None, *, w: X = None) -> X: 17 | 18 | if what is None: 19 | what = np.fft.fft2(w) 20 | kx, ky = k(*what.shape) 21 | 22 | Δ = kx ** 2 + ky ** 2 23 | Δ = Δ.at[0, 0].set(1) 24 | 25 | vx = np.fft.irfft2(what * 1j*ky / Δ, what.shape) 26 | vy = np.fft.irfft2(what *-1j*kx / Δ, what.shape) 27 | 28 | return vx, vy 29 | 30 | class Initial(Gaussian): 31 | 32 | grid = Fourier[2].grid(64, 64) 33 | 34 | def __str__(self): return f"{self.length}x{self.scaling}" 35 | 36 | def __init__(self, length: float, scaling: float = 1.0): 37 | super().__init__(Initial.grid, Gaussian.Per(length)) 38 | 39 | self.length = length 40 | self.scaling = scaling 41 | 42 | def sample(self, prng, shape=()) -> X: 43 | 44 | x = super().sample(prng, shape) 45 | x-= np.mean(x, (-2, -1), keepdims=True) 46 | 47 | x = self.scaling * x[..., np.newaxis, :, :] 48 | return np.broadcast_to(x, shape + (2, *self.dim)) 49 | 50 | class NavierStokes(PDE): 51 | 52 | """ 53 | wt + v ∇w = nu ∆w 54 | where 55 | ∇⨉v = w 56 | ∇·v = 0 57 | """ 58 | 59 | T: int # end time 60 | nu: float # viscosity 61 | l: float # length scale 62 | 63 | # forcing term? 
64 | fn: Optional[Fx] 65 | 66 | def __str__(self): return f"Re={int(self.Re)}:T={self.T}:{self.F}" 67 | def __init__(self, ic: Initial, T: float, nu: float, fn: Fx = None): 68 | 69 | self.odim = 1 70 | self.ic = ic 71 | 72 | self.T = T 73 | self.nu = nu 74 | self.fn = fn 75 | 76 | self.l = (l:=ic.length) 77 | self.Re = l / nu * ic.scaling 78 | 79 | if fn is None: self.F = None 80 | else: self.F = fn.__name__ 81 | 82 | self.domain = Rect(3) 83 | 84 | self.basis = series(Chebyshev, Fourier, Fourier) 85 | self.params = Interpolate(ic, self.basis) 86 | 87 | from ..mollifier import initial_condition 88 | self.mollifier = initial_condition 89 | 90 | @F.cached_property 91 | def solution(self): 92 | 93 | dir = os.path.dirname(__file__) 94 | 95 | with jax.default_device(jax.devices("cpu")[0]): 96 | 97 | # solution data are typically large 98 | # transfer to RAM in the first place 99 | 100 | w = np.load(f"{dir}/w.{self.ic}.npy") 101 | u = np.load(f"{dir}/u.{self}.npy") 102 | 103 | return jax.vmap(self.basis)(w), u.shape[1:-1], u 104 | 105 | def equation(self, x: X, w0: X, w: X) -> X: 106 | w, w1, w2 = utils.fdm(w, n=2) 107 | 108 | wt = w1[..., 0, 0] 109 | wx = w1[..., 0, 1] 110 | wy = w1[..., 0, 2] 111 | Δw = Δ(w2[..., 0, 1:, 1:]) 112 | 113 | vx, vy = jax.vmap(velocity)(w=w.squeeze(-1)) 114 | Dwdt = wt / self.T + (vx * wx + vy * wy) 115 | 116 | if self.fn is None: f = np.zeros_like(Dwdt) 117 | else: f = self.fn(*w[0].squeeze(-1).shape) 118 | 119 | return Dwdt - self.nu * Δw - f 120 | 121 | def boundary(self, w: List[Tuple[X]]) -> List[X]: 122 | 123 | _, (wt, wb), (wl, wr) = w 124 | return [wt - wb, wl - wr] 125 | 126 | def spectral(self, w0: Basis, w: Basis) -> Basis: 127 | w1 = w.grad(); w2 = w1.grad() 128 | 129 | wt = self.basis(w1.coef[..., 0, 0]) 130 | wx = self.basis(w1.coef[..., 0, 1]) 131 | wy = self.basis(w1.coef[..., 0, 2]) 132 | Δw = self.basis(Δ(w2.coef[..., 0, 1:, 1:])) 133 | 134 | vx, vy = jax.vmap(velocity)(w=w.inv().squeeze(-1)) 135 | Dwdt = 
self.basis.add(wt.map(lambda coef: coef / self.T), 136 | self.basis.transform(vx * wx.inv() + vy * wy.inv())) 137 | 138 | if self.fn is None: f = self.basis(np.zeros_like(Dwdt.coef)) 139 | else: f = self.basis.transform(np.broadcast_to(self.fn(*w.mode[1:]), w.mode)) 140 | 141 | return self.basis.add(Dwdt, self.basis(-self.nu * Δw.coef), f.map(np.negative)) 142 | 143 | ic = Initial(0.8) 144 | 145 | # ------------------------------- UNFORCED FLOW ------------------------------ # 146 | 147 | re2 = NavierStokes(ic, T=3, nu=1e-2) 148 | re3 = NavierStokes(ic, T=3, nu=1e-3) 149 | re4 = NavierStokes(ic, T=3, nu=1e-4) 150 | 151 | # ------------------------------ TRANSIENT FLOW ------------------------------ # 152 | 153 | def transient(nx: int, ny: int) -> X: 154 | 155 | xy = utils.grid(nx, ny, mode="left").sum(-1) 156 | return 0.1*(np.sin(2*π*xy) + np.cos(2*π*xy)) 157 | 158 | tf = NavierStokes(ic, T=50, nu=2e-3, fn=transient) 159 | -------------------------------------------------------------------------------- /src/pde/navierstokes/generate.py: -------------------------------------------------------------------------------- 1 | # Modified from neuraloperator. Commit ef3de3bb1140175c69a9fe3a8b45afd1335077d9 2 | # https://github.com/neuraloperator/neuraloperator/blob/master/data_generation/navier_stokes/ns_2d.py 3 | 4 | from . 
import * 5 | 6 | def simulate(w0: X, nu: float, f: X) -> Fx: 7 | 8 | """ 9 | Returns: 10 | u: callable what -> what' for next step 11 | advance vorticity in spectral domain 12 | """ 13 | 14 | s, s = w0.shape 15 | kx, ky = k(s, s) 16 | 17 | diffuse = (Δ := kx ** 2 + ky ** 2) * nu 18 | 19 | dealias = (Δ < (2/3 * π * s) ** 2).astype(float) 20 | dealias = dealias.at[0, 0].set(0) # zero-mean 21 | 22 | def Δhat(what: X) -> X: 23 | 24 | vx, vy = velocity(what) 25 | 26 | wx = np.fft.irfft2(what * 1j * kx, (s, s)) 27 | wy = np.fft.irfft2(what * 1j * ky, (s, s)) 28 | 29 | vxwx = np.fft.fft2(vx * wx, (s, s)) 30 | vywy = np.fft.fft2(vy * wy, (s, s)) 31 | 32 | return np.fft.fft2(f) - vxwx - vywy 33 | 34 | def call(what: X, dt: float) -> X: 35 | 36 | Δhat1 = Δhat(what) # Heun's method 37 | 38 | what_tilde = what + dt * (Δhat1 - diffuse * what / 2) 39 | what_tilde/= 1 + dt * diffuse / 2 40 | 41 | Δhat2 = Δhat(what_tilde) # Cranck-Nicholson + Heun 42 | 43 | what = what + dt * ((Δhat1 + Δhat2) - diffuse * what) / 2 44 | what/= 1 + dt * diffuse / 2 45 | 46 | return what * dealias 47 | 48 | return call 49 | 50 | def solution(w0: X, T: float, nu: float, force: Fx, 51 | dt: float, nt: int) -> X: 52 | 53 | """ 54 | Args: 55 | w0: initial condition 56 | 57 | T: total time 58 | nu: viscosity 59 | force: -ing term 60 | 61 | dt: advance step 62 | nt: record step 63 | 64 | Returns: 65 | u: solution recorded at each timestep 66 | inclusive of the end time 67 | shape = (nt, *w0.shape) 68 | """ 69 | 70 | if not force: f = np.zeros_like(w0) 71 | else: f = force(*w0.shape) 72 | 73 | step = simulate(w0, nu, f) 74 | Δt = T / (N := nt - 1) 75 | 76 | def record(what: X, _) -> Tuple[X, X]: 77 | call = lambda _, what: step(what, dt) 78 | 79 | what = step(jax.lax.fori_loop(0., Δt // dt, call, what), Δt % dt) 80 | return what, np.fft.irfft2(what, s=w0.shape) 81 | 82 | _, w = jax.lax.scan(record, np.fft.fft2(w0), None, N) 83 | return np.concatenate([w0[np.newaxis, :], w], axis=0) 84 | 85 | # 
---------------------------------------------------------------------------- # 86 | # GENERATE # 87 | # ---------------------------------------------------------------------------- # 88 | 89 | def generate(pde: NavierStokes, dt: float = 1e-3, T: int = 64, X: int = 256): 90 | 91 | params = pde.params.sample(random.PRNGKey(0), (128, )) 92 | solve = F.partial(solution, T=pde.T, nu=pde.nu, force=pde.fn, dt=dt, nt=T) 93 | 94 | w = jax.vmap(solve)(jax.vmap(lambda w: w.to(1, X, X).inv().squeeze())(params)) 95 | w = np.pad(w, [(0, 0), (0, 0), (0, 1), (0, 1)], mode="wrap")[..., np.newaxis] 96 | 97 | dir = os.path.dirname(__file__) 98 | 99 | np.save(f"{dir}/w.{pde.ic}.npy", params.coef) 100 | np.save(f"{dir}/u.{pde}.npy", w) 101 | 102 | return w 103 | 104 | if __name__ == "__main__": 105 | 106 | from sys import argv 107 | 108 | if argv[1] == "ns": 109 | 110 | generate(re2) 111 | generate(re3) 112 | generate(re4) 113 | 114 | if argv[1] == "tf": 115 | 116 | generate(tf, dt=5e-3) 117 | -------------------------------------------------------------------------------- /src/pde/poisson/__init__.py: -------------------------------------------------------------------------------- 1 | from .. 
import * 2 | from ...dists import * 3 | from ...basis import * 4 | 5 | from .._domain import * 6 | from .._params import * 7 | 8 | class Poisson(PDE): 9 | 10 | """ 11 | -Δ u(x) = s(x) 12 | """ 13 | 14 | def __init__(self): 15 | 16 | self.odim = 1 17 | 18 | self.domain = Rect(2) 19 | 20 | def spectral(self, s: Basis, u: Basis) -> Basis: 21 | 22 | return self.basis.add(s, u.grad().grad().map(Δ)) 23 | 24 | # ---------------------------------------------------------------------------- # 25 | # PERIODIC # 26 | # ---------------------------------------------------------------------------- # 27 | 28 | from ...basis.fourier import * 29 | 30 | class Periodic(Poisson): 31 | 32 | def __init__(self, res: int): 33 | super().__init__() 34 | 35 | class Source(Gaussian): 36 | 37 | def sample(self, prng, shape=()) -> X: 38 | 39 | x = super().sample(prng, shape) 40 | μ = np.mean(x, (-2, -1), keepdims=True) 41 | 42 | return x - μ 43 | 44 | source = Source(Fourier[2].grid(res, res), Gaussian.Per(0.2)) 45 | self.params = Interpolate(source, Fourier[2]) 46 | self.basis = Fourier[2] 47 | 48 | from ..mollifier import periodic 49 | self.mollifier = periodic 50 | 51 | def equation(self, x: X, s: X, u: X) -> X: 52 | 53 | s = s[:-1, :-1] # periodic 54 | u = u[:-1, :-1] # boundary 55 | 56 | # 5-point stencil for discrete laplacian 57 | 58 | Δ = np.roll(u, 1, 0) + np.roll(u, -1, 0) \ 59 | + np.roll(u, 1, 1) + np.roll(u, -1, 1) - 4 * u 60 | 61 | return s + Δ * len(u) ** 2 62 | 63 | def boundary(self, u: List[Tuple[X]]) -> List[X]: 64 | 65 | return [ul - ur for ul, ur in u] 66 | 67 | @F.cached_property 68 | def solution(self): 69 | 70 | dir = os.path.dirname(__file__) 71 | 72 | s = np.load(f"{dir}/s.periodic.npy") 73 | u = np.load(f"{dir}/u.periodic.npy") 74 | 75 | return jax.vmap(self.basis)(s), u.shape[1:-1], u 76 | 77 | # --------------------------------- INSTANCE --------------------------------- # 78 | 79 | periodic = Periodic(16) 80 | 81 | # 
---------------------------------------------------------------------------- # 82 | # DIRICHLET # 83 | # ---------------------------------------------------------------------------- # 84 | 85 | from ...basis.chebyshev import * 86 | 87 | class Dirichlet(Poisson): 88 | 89 | def __init__(self, res: int): 90 | super().__init__() 91 | 92 | source = Gaussian(Chebyshev[2].grid(res, res), Gaussian.RBF(0.2)) 93 | self.params = Interpolate(source, Chebyshev[2]) 94 | self.basis = Chebyshev[2] 95 | 96 | from ..mollifier import dirichlet 97 | self.mollifier = dirichlet 98 | 99 | def equation(self, x: X, s: X, u: X) -> X: 100 | 101 | return s + Δ(utils.fdm(u, 2)[2]) 102 | 103 | def boundary(self, u: List[Tuple[X]]) -> List[X]: 104 | 105 | return [] 106 | 107 | @F.cached_property 108 | def solution(self): 109 | 110 | dir = os.path.dirname(__file__) 111 | 112 | s = np.load(f"{dir}/s.dirichlet.npy") 113 | u = np.load(f"{dir}/u.dirichlet.npy") 114 | 115 | return jax.vmap(self.basis)(s), u.shape[1:-1], u 116 | 117 | # --------------------------------- INSTANCE --------------------------------- # 118 | 119 | dirichlet = Dirichlet(16) 120 | -------------------------------------------------------------------------------- /src/pde/poisson/generate.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | 3 | def solution(s: Basis, res: int = 256) -> X: 4 | 5 | freq = map(lambda n: np.square(np.fft.fftfreq(n) * 2 * π * n), s.mode) 6 | u = s.map(lambda coef: coef / sum(np.meshgrid(*freq, indexing="ij"))[..., np.newaxis]) 7 | 8 | u = u.map(lambda coef: coef.at[(0, ) * u.ndim()].set(0)) 9 | u = u.map(lambda coef: coef.at[(0, ) * u.ndim()].add(-u(np.zeros(u.ndim())))) 10 | 11 | return u[res, res] 12 | 13 | def generate(pde: Periodic, N: int = 128, X: int = 256): 14 | 15 | params = pde.params.sample(random.PRNGKey(0), (N, )) 16 | u = jax.vmap(F.partial(solution, res=X))(params) 17 | 18 | dir = os.path.dirname(__file__) 19 | 20 | np.save(f"{dir}/s.periodic.npy", params.coef) 21 | np.save(f"{dir}/u.periodic.npy", u) 22 | 23 | return u 24 | 25 | if __name__ == "__main__": 26 | 27 | generate(periodic) 28 | -------------------------------------------------------------------------------- /src/pde/poisson/generate.wls: -------------------------------------------------------------------------------- 1 | (* Load parameters from numpy file *) 2 | Source = Last[ExternalEvaluate["Python", { 3 | "import numpy as np", 4 | "np.load(\"s.dirichlet.npy\")" 5 | }]]; 6 | 7 | (* Qurey grid points: uniformly spaced *) 8 | X = CoordinateBoundsArray[{{0, 255}, {0, 255}}] / 255; 9 | 10 | PDE = -Div[Grad[u[x, y], {x, y}], {x, y}] == s; 11 | BCs = u[0, y] == u[1, y] == u[x, 0] == u[x, 1] == 0; 12 | 13 | U = Table[ 14 | 15 | {nx, ny, one} = Dimensions[S]; 16 | 17 | cx = ChebyshevT[#, x] & /@ Range[0, nx-1] /. x -> 2x - 1; 18 | cy = ChebyshevT[#, y] & /@ Range[0, ny-1] /. 
y -> 2y - 1; 19 | s = Sum[S[[i, j, 1]] cx[[i]] cy[[j]], {i, nx}, {j, ny}]; 20 | 21 | (* Solve PDE & measure cost *) 22 | { Cost, Evaluation } = Timing[ 23 | Solution = NDSolveValue[{PDE, BCs}, u, {x, 0, 1}, {y, 0, 1}]; 24 | ArrayReshape[Solution @@@ Flatten[X, 1], {255, 255, 1}] 25 | ]; 26 | 27 | Print[Cost]; 28 | Evaluation 29 | 30 | , { S, Normal[Source] }]; 31 | 32 | (* Store answer to `.npy` *) 33 | ExternalEvaluate["Python", { 34 | "import numpy as np", 35 | "np.save(\"u.dirichlet.npy\", <*NumericArray[U, \"Real32\"]*>)", 36 | "print('DONE')" 37 | }]; 38 | -------------------------------------------------------------------------------- /src/pde/reaction/__init__.py: -------------------------------------------------------------------------------- 1 | from .. import * 2 | from ...dists import * 3 | from ...basis import * 4 | 5 | from .._domain import * 6 | from .._params import * 7 | 8 | from ...basis.fourier import * 9 | from ...basis.chebyshev import * 10 | 11 | class ReactionDiffusion(PDE): 12 | 13 | """ 14 | ut - \nu uxx = \rho u (1 - u) 15 | """ 16 | 17 | rho: float # reaction coefficient 18 | nu: float # diffusion coefficient 19 | 20 | def __str__(self): return f"rho={self.rho}:nu={self.nu:.3f}" 21 | def __init__(self, res: int, rho: float, nu: float): 22 | 23 | self.odim = 1 24 | 25 | self.rho = rho 26 | self.nu = nu 27 | 28 | self.domain = Rect(2) 29 | self.basis = series(Chebyshev, Fourier) 30 | 31 | class Initial(Gaussian): 32 | 33 | def sample(self, prng, shape=()) -> X: 34 | x = super().sample(prng, shape) 35 | 36 | inf = np.min(x, axis=-1, keepdims=True) 37 | sup = np.max(x, axis=-1, keepdims=True) 38 | 39 | x, shape = (x - inf) / (sup - inf), shape + (2, res) 40 | return np.broadcast_to(x[..., np.newaxis, :], shape) 41 | 42 | initial = Initial(Fourier.grid(res), Gaussian.Per(0.2)) 43 | self.params = Interpolate(initial, self.basis) 44 | 45 | from ..mollifier import initial_condition 46 | self.mollifier = initial_condition 47 | 48 | 
@F.cached_property 49 | def solution(self): 50 | 51 | dir = os.path.dirname(__file__) 52 | 53 | h = np.load(f"{dir}/h.npy") 54 | u = np.load(f"{dir}/u.{self}.npy") 55 | 56 | return jax.vmap(self.basis)(h), u.shape[1:-1], u 57 | 58 | def equation(self, x: X, h: X, u: X) -> X: 59 | u, u1, u2 = utils.fdm(u, n=2) 60 | 61 | ut = u1[..., 0] 62 | uxx = u2[..., 1, 1] 63 | 64 | reaction = -self.rho * u * (1 - u) 65 | diffusion = -self.nu * uxx 66 | 67 | return ut + reaction + diffusion 68 | 69 | def boundary(self, u: List[Tuple[X]]) -> List[X]: 70 | 71 | _, (ul, ur) = u 72 | return [ul - ur] 73 | 74 | def spectral(self, h: Basis, u: Basis) -> Basis: 75 | 76 | u1 = u.map(lambda coef: coef.at[(0, 0)].add(-1)) 77 | 78 | reaction = u.__class__.mul(u, u1).map(lambda uu1: uu1 * self.rho) 79 | diffusion = u.grad(2).map(lambda u2: -u2[..., 1] * self.nu) 80 | 81 | ut = u.grad().map(lambda u1: u1[..., 0]) 82 | return u.__class__.add(ut, reaction, diffusion) 83 | 84 | # --------------------------------- INSTANCE --------------------------------- # 85 | 86 | nu005 = ReactionDiffusion(64, rho=5, nu=0.005) 87 | nu01 = ReactionDiffusion(64, rho=5, nu=0.01) 88 | nu05 = ReactionDiffusion(64, rho=5, nu=0.05) 89 | nu1 = ReactionDiffusion(64, rho=5, nu=0.1) 90 | -------------------------------------------------------------------------------- /src/pde/reaction/generate.py: -------------------------------------------------------------------------------- 1 | # Modified from characterizing-pinns-failure-modes. Commit 4390d09c507c117a37e621ab1b785a43f0c32f57 2 | # https://github.com/a1k12/characterizing-pinns-failure-modes/blob/main/pbc_examples/systems_pbc.py 3 | 4 | from . 
import * 5 | 6 | def reaction(u: X, rho: float, dt: float) -> X: 7 | 8 | """ 9 | du/dt = rho*u*(1-u) 10 | """ 11 | 12 | factor_1 = u * np.exp(rho * dt) 13 | factor_2 = 1 - u 14 | 15 | return factor_1 \ 16 | / (factor_1 + factor_2) 17 | 18 | def diffusion(u: X, nu: float, dt: float, IKX2: X) -> X: 19 | 20 | """ 21 | du/dt = nu*d2u/dx2 22 | """ 23 | 24 | factor = np.exp(nu * IKX2 * dt) 25 | u_hat = np.fft.fft(u) * factor 26 | 27 | return np.fft.ifft(u_hat).real 28 | 29 | def solution(h: Fx, nu: float, rho: float, nx=4096, nt=4097) -> X: 30 | 31 | """ 32 | Computes the discrete solution of the reaction-diffusion PDE using pseudo 33 | spectral operator splitting. 34 | 35 | Args: 36 | h: initial condition 37 | nu: diffusion coefficient 38 | rho: reaction coefficient 39 | nx: number of points in the x grid 40 | nt: number of points in the t grid 41 | 42 | Returns: 43 | x: grids 44 | u: solution 45 | """ 46 | 47 | L = 1 48 | T = 1 49 | dx = L/nx 50 | dt = T/(nt-1) 51 | x = np.arange(0, L, dx) # not inclusive of the last point 52 | t = np.linspace(0, T, nt) # inclusive of the end time 1 53 | u = np.zeros((nx, nt)) 54 | 55 | IKX_pos = 2j * π * np.arange(0, nx/2+1, 1) 56 | IKX_neg = 2j * π * np.arange(-nx/2+1, 0, 1) 57 | IKX = np.concatenate((IKX_pos, IKX_neg)) 58 | IKX2 = IKX * IKX 59 | 60 | u = [_u:=jax.vmap(h)(x)] 61 | 62 | for _ in range(nt-1): 63 | 64 | _u = reaction(_u, rho, dt) 65 | _u = diffusion(_u, nu, dt, IKX2) 66 | 67 | u.append(_u) 68 | 69 | return np.dstack(np.meshgrid(t, x, indexing="ij")), np.stack(u) 70 | 71 | # ---------------------------------------------------------------------------- # 72 | # GENERATE # 73 | # ---------------------------------------------------------------------------- # 74 | 75 | def generate(pde: ReactionDiffusion, N: int = 128): 76 | 77 | params = pde.params.sample(random.PRNGKey(0), (N, )) 78 | solve = F.partial(solution, nu=pde.nu, rho=pde.rho) 79 | 80 | u = jax.lax.map(lambda h: solve(lambda x: h(np.array([0, x])).squeeze()), 
params)[1] 81 | u = np.pad(u[:, ::8, ::8], [(0, 0), (0, 0), (0, 1)], mode="wrap")[..., np.newaxis] 82 | 83 | dir = os.path.dirname(__file__) 84 | 85 | np.save(f"{dir}/h.npy", params.coef) 86 | np.save(f"{dir}/u.{pde}.npy", u) 87 | 88 | return u 89 | 90 | if __name__ == "__main__": 91 | 92 | for nu in [0.005, 0.01, 0.05, 0.1]: 93 | 94 | generate(ReactionDiffusion(64, 5, nu)) 95 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | from .pde import * 3 | from .model import * 4 | from .basis import * 5 | 6 | # ---------------------------------------------------------------------------- # 7 | # STEP # 8 | # ---------------------------------------------------------------------------- # 9 | 10 | def step(self: Trainer, variable: ϴ) -> Tuple[ϴ, Dict[str, X]]: 11 | 12 | params, prng = variable.get("params"), self.make_rng("sample") 13 | ϕ = self.pde.params.sample(prng, (self.cfg["bs"], )) 14 | 15 | @F.partial(jax.grad, has_aux=True) 16 | def loss(params: ϴ, ϕ: X) -> Tuple[X, Dict[str, X]]: 17 | 18 | loss = self.mod.apply(variable.copy(dict(params=params)), ϕ, method="loss") 19 | return jax.tree_util.tree_reduce(O.add, loss), loss 20 | 21 | grads, loss = jax.tree_map(F.partial(np.mean, axis=0), 22 | utils.cmap(F.partial(loss, params), self.cfg["vmap"])(ϕ)) 23 | 24 | if self.cfg["clip"] is not None: 25 | 26 | loss["norm"] = np.sqrt(sum(np.sum(np.square(grad)) 27 | for grad in jax.tree_util.tree_leaves(grads))) 28 | 29 | grads = jax.tree_map(F.partial(jax.lax.cond, loss["norm"] < self.cfg["clip"], 30 | lambda grad: grad, lambda grad: grad / loss["norm"] * self.cfg["clip"]), grads) 31 | 32 | updates, state = self.optimizer.update(grads, self.get_variable("optim", "state"), params) 33 | 34 | self.put_variable("optim", "state", state) 35 | return variable.copy(dict(params=optax.apply_updates(params, updates))), loss 36 | 37 | # 
---------------------------------------------------------------------------- # 38 | # EVAL # 39 | # ---------------------------------------------------------------------------- # 40 | 41 | def eval(self: Trainer, variable: ϴ) -> Tuple[Dict, X]: 42 | 43 | if isinstance(self.pde.solution, Tuple): 44 | ϕ, s, u = self.pde.solution 45 | 46 | v = utils.cmap(F.partial(self.mod.apply, variable, x=s, method="u"), self.cfg["vmap"])(ϕ) 47 | with jax.default_device(cpu:=jax.devices("cpu")[0]): 48 | 49 | u, v = jax.device_put((u, v), device=cpu) 50 | return jax.tree_map(np.mean, jax.vmap(metric(self.pde, s))(ϕ, u, v)), (u, v) 51 | 52 | def metric(pde: PDE, s: Tuple[int]) -> Fx: 53 | 54 | def call(ϕ: Basis, u: X, v: X) -> Dict[str, X]: 55 | 56 | return dict( 57 | erra=np.mean(np.abs(np.ravel(u - v))), 58 | errr=np.linalg.norm(np.ravel(u - v)) / np.linalg.norm(np.ravel(u)), 59 | residual=np.mean(np.abs(pde.equation(utils.grid(*s), ϕ[s], v))), 60 | ) 61 | 62 | return call 63 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | def grid(*s: int, mode: str = None, flatten: bool = False) -> X: 4 | 5 | """ 6 | Return grid on [0, 1)^n. If not flatten, shape=(*s, len(s)); 7 | else shape=(∏s, len(s)). 8 | 9 | Mode: 10 | - `None`: uniformly spaced 11 | - "left": exclude endpoint 12 | - "cell": centers of rects 13 | """ 14 | 15 | axes = F.partial(np.linspace, 0, 1, endpoint=mode is None) 16 | grid = np.stack(np.meshgrid(*map(axes, s), indexing="ij"), -1) 17 | 18 | if mode == "cell": grid += .5 / np.array(s) 19 | if flatten: return grid.reshape(-1, len(s)) 20 | 21 | return grid 22 | 23 | def nmap(f: Fx, n: int = 1, **kwargs) -> Fx: 24 | 25 | """ 26 | Nested vmap. Keeps the same semantics as `jax.vmap` except that arbitrary 27 | `n` leading dimensions are vectorized. Returns the vmapped function. 
28 | """ 29 | 30 | if not n: return f 31 | 32 | if n > 1: f = nmap(f, n - 1) 33 | return jax.vmap(f, **kwargs) 34 | 35 | def cmap(f: Fx, n: int = None, **kwargs) -> Fx: 36 | 37 | """ 38 | Chunck vmap. Keeps the same semantics as `jax.vmap` but only vectorizing 39 | over `n` items, and uses loop-based map over the chunks of that size. 40 | """ 41 | 42 | f = jax.vmap(f, **kwargs) 43 | def call(*args, **kwargs): 44 | 45 | return jax.tree_map(np.concatenate, jax.lax.map(f, *jax.tree_map(into:=lambda x: 46 | x.reshape(-1, n, *x.shape[1:]), args), **jax.tree_map(into, kwargs))) 47 | 48 | if n is None: return f 49 | 50 | return call 51 | 52 | def jit(f: Fx, **options) -> Fx: 53 | 54 | """ 55 | JIT function with cost analysis on the first run. Keep in mind that loops 56 | are not taken in to account correctly (which means with cfg.vmap set, the 57 | results are not reliable). 58 | """ 59 | 60 | f = jax.jit(f, **options) 61 | def call(*args, **kwargs) -> Any: 62 | 63 | nonlocal f 64 | 65 | if not isinstance(f, jax.stages.Compiled): 66 | 67 | print("=" * 116) 68 | 69 | print(f"compling function {f} ......") 70 | f = f.lower(*args, **kwargs).compile() 71 | 72 | cost, = f.cost_analysis() 73 | print("flop:", cost["flops"]) 74 | print("memory:", cost["bytes accessed"]) 75 | 76 | print("=" * 116) 77 | 78 | return f(*args, **kwargs) 79 | 80 | return call 81 | 82 | def timeit(f: Fx, **options) -> Fx: 83 | fjit = jit(f, **options) 84 | 85 | def call(*args, **kwargs) -> float: 86 | fjit(*args, **kwargs) # compile 87 | 88 | import time 89 | iter = time.time() 90 | 91 | jax.tree_map(jax.block_until_ready, 92 | fjit(*args, **kwargs)) 93 | 94 | return time.time() - iter 95 | 96 | return call 97 | 98 | # ---------------------------------------------------------------------------- # 99 | # DERIVATIVE # 100 | # ---------------------------------------------------------------------------- # 101 | 102 | def grad(f: Fx, n: int, D: Fx = jax.jacfwd) -> Fx: 103 | 104 | """ 105 | Calculate up to 
`n`th order derivatives. Return List[Array] of length `n` 106 | where `k`th element is the `k`th derivative of shape `(..., d^k)`. 107 | 108 | Differential scheme `D :: (X -> X) -> (X -> X)` determines how to obtain 109 | the Jacobian. Default to JAX's forward mode autograd. 110 | """ 111 | 112 | u = [f] 113 | 114 | for _ in range(n): u.append(f:=D(f)) 115 | return lambda x: [fk(x) for fk in u] 116 | 117 | def fdm(x: X, n: int) -> List[X]: 118 | 119 | """ 120 | Approximate the above derivative using finite difference. `x` is assumed 121 | to be evaluated on uniform grids on [0, 1]^d (include end-points), where 122 | dimension is taken as `x.ndim - 1`, i.e. `x` has a trailing channel dim. 123 | """ 124 | 125 | u = [x] 126 | d = x.ndim - 1 127 | s = x.shape[:-1] 128 | 129 | for _ in range(n): 130 | 131 | grad = map(lambda i: np.gradient(x, axis=i), range(d)) 132 | u.append(x:=np.stack(tuple(grad), axis=-1) * (np.array(s)-1)) 133 | 134 | return u 135 | 136 | # ---------------------------------------------------------------------------- # 137 | # DEBUG # 138 | # ---------------------------------------------------------------------------- # 139 | 140 | def repl(local): 141 | 142 | # ---------------------------------- IMPORT ---------------------------------- # 143 | 144 | import matplotlib.pyplot as plt 145 | import matplotlib.colors as clr 146 | 147 | import scienceplots 148 | plt.style.use(["science", 149 | "no-latex"]) 150 | 151 | from src.basis import Basis, series 152 | from src.basis.fourier import Fourier 153 | from src.basis.chebyshev import Chebyshev 154 | 155 | # ---------------------------------- HELPER ---------------------------------- # 156 | 157 | def save(fig=plt): fig.savefig("test.jpg", dpi=300); fig.clf() 158 | def show(img, **kw): plt.colorbar(plt.imshow(img, **kw)) 159 | def gif(*imgs, fps: int = 50, **kw): 160 | fig, ax = plt.subplots() 161 | 162 | vmin = kw.pop("vmin", min(map(np.min, imgs))) 163 | vmax = kw.pop("vmax", max(map(np.max, imgs))) 164 
| im = ax.imshow(imgs[0], vmin=vmin, vmax=vmax, **kw) 165 | id = ax.text(0.98, 0.02, "", transform=ax.transAxes, ha="right", va="bottom") 166 | 167 | plt.colorbar(im) 168 | def frame(index): 169 | 170 | i = len(imgs) * index//fps 171 | id.set_text(f"#{i:03}") 172 | im.set_array(imgs[i]) 173 | return im, id 174 | 175 | from matplotlib import animation 176 | ani = animation.FuncAnimation(fig, frame, fps, blit=True) 177 | ani.save("test.gif", writer="pillow", fps=fps, dpi=300) 178 | 179 | import code; code.interact(local=dict(globals(), **dict(locals(), **local))) 180 | --------------------------------------------------------------------------------