├── .gitignore
├── LICENSE.txt
├── README.md
├── ckpt.py
├── demo.gif
├── main.py
├── plot
│   ├── __init__.py
│   ├── animation.py
│   ├── navierstokes.py
│   ├── poisson.py
│   ├── reaction.py
│   ├── reaction.solver.py
│   └── visualize.py
├── run
│   ├── .sh
│   ├── navierstokes.forced.sh
│   ├── navierstokes.sh
│   ├── navierstokes.yaml
│   ├── poisson.dirichlet.sh
│   ├── poisson.periodic.sh
│   ├── reaction.sh
│   └── requirements.txt
└── src
    ├── README
    ├── __init__.py
    ├── basis
    │   ├── README
    │   ├── __init__.py
    │   ├── chebyshev.py
    │   └── fourier.py
    ├── dists.py
    ├── model
    │   ├── README
    │   ├── __init__.py
    │   ├── _base.py
    │   ├── fno
    │   │   ├── __init__.py
    │   │   └── spectral.py
    │   └── sno
    │       └── spectral.py
    ├── pde
    │   ├── README
    │   ├── __init__.py
    │   ├── _domain.py
    │   ├── _params.py
    │   ├── mollifier.py
    │   ├── navierstokes
    │   │   ├── __init__.py
    │   │   └── generate.py
    │   ├── poisson
    │   │   ├── __init__.py
    │   │   ├── generate.py
    │   │   └── generate.wls
    │   └── reaction
    │       ├── __init__.py
    │       └── generate.py
    ├── train.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # PYTHON #
2 | ##########
3 | __pycache__
4 |
5 | # FILE #
6 | ##########
7 | *.npy
8 | *.jpg
9 | *.gif
10 |
11 | test*
12 | log
13 |
14 | # DEMO #
15 | ##########
16 |
17 | !demo.gif
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 ASK-Berkeley
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Neural Spectral Methods
3 |
4 | This repo contains the JAX implementation of our ICLR 2024 paper,
5 |
6 | *Neural Spectral Methods: Self-supervised learning in the spectral domain*.
7 |
8 | Yiheng Du, Nithin Chalapathi, Aditi Krishnapriyan.
9 |
10 | https://arxiv.org/abs/2312.05225.
11 |
12 | Neural Spectral Methods (NSM) is a class of machine learning methods designed for solving parametric PDEs in the spectral domain. Above is a demonstration of NSM's prediction of the 2D Navier-Stokes equation.
13 |
14 | > We present Neural Spectral Methods, a technique to solve parametric Partial Differential Equations (PDEs), grounded in classical spectral methods. Our method uses orthogonal bases to learn PDE solutions as mappings between spectral coefficients. In contrast to current machine learning approaches which enforce PDE constraints by minimizing the numerical quadrature of the residuals in the spatiotemporal domain, we leverage Parseval’s identity and introduce a new training strategy through a spectral loss. Our spectral loss enables more efficient differentiation through the neural network, and substantially reduces training complexity. At inference time, the computational cost of our method remains constant, regardless of the spatiotemporal resolution of the domain. Our experimental results demonstrate that our method significantly outperforms previous machine learning approaches in terms of speed and accuracy by one to two orders of magnitude on multiple different problems. When compared to numerical solvers of the same accuracy, our method demonstrates a 10× increase in performance speed.
15 |
16 | ## Code structure
17 |
18 | The structure of this codebase is outlined in the [`src`](src) directory. This includes definitions for [PDE systems](src/pde), utilities for [orthogonal bases](src/basis), and general implementations of [baseline models](src/model). This codebase is self-contained and can serve as a standalone module for other purposes. To experiment with new problems, follow the templates for instantiating the abstract PDE class in the [`src/pde`](src/pde) directory.
19 |
20 | ### Arguments
21 |
22 | `main.py` accepts command line arguments as specified below. Refer to each bash script in the `run` directory for configurations of each experiment.
23 |
24 | ```
25 | usage: main.py [-h] [--seed SEED] [--f64] [--smi] [--pde PDE] [--model {fno,sno}]
26 | [--spectral] [--hdim HDIM] [--depth DEPTH] [--activate ACTIVATE]
27 | [--mode MODE [MODE ...]] [--grid GRID] [--fourier] [--cheb]
28 | {train,test} ...
29 |
30 | positional arguments:
31 | {train,test}
32 | train train model from scratch
33 | test enter REPL after loading
34 |
35 | optional arguments:
36 | -h, --help show this help message and exit
37 | --seed SEED random seed
38 | --f64 use double precision
39 | --smi profile memory usage
40 | --pde PDE PDE name
41 | --model {fno,sno} model name
42 | --spectral spectral training
43 | --hdim HDIM hidden dimension
44 | --depth DEPTH number of layers
45 | --activate ACTIVATE activation name
46 | --mode MODE [MODE ...]
47 | number of modes per dim
48 | --grid GRID training grid size
49 | --fourier fourier basis only
50 | --cheb using chebyshev
51 | ```
52 |
53 | ```
54 | usage: main.py train [-h] --bs BS --lr LR [--clip CLIP] --schd SCHD --iter ITER
55 | --ckpt CKPT --note NOTE [--vmap VMAP] [--save]
56 |
57 | optional arguments:
58 | -h, --help show this help message and exit
59 | --bs BS batch size
60 | --lr LR learning rate
61 | --clip CLIP gradient clipping
62 | --schd SCHD scheduler name
63 | --iter ITER total iterations
64 | --ckpt CKPT checkpoint every n iters
65 | --note NOTE leave a note here
66 | --vmap VMAP vectorization size
67 | --save save model checkpoints
68 | ```
69 |
70 | ```
71 | usage: main.py test [-h] [--load LOAD]
72 |
73 | optional arguments:
74 | -h, --help show this help message and exit
75 | --load LOAD saved model path
76 | ```
77 |
78 | ## Run
79 |
80 | ### Environment
81 |
82 | ```bash
83 | ## We are using jax version 0.4.7.
84 |
85 | pip install -r run/requirements.txt
86 |
87 | ## Please install jaxlib based on your own machine configuration.
88 |
89 | pip install https://storage.googleapis.com/jax-releases/
90 | ```
91 |
92 | ### Quick test
93 |
94 | ```bash
95 | python main.py test
96 | ```
97 |
98 | ### Train models
99 |
100 | - Local machine
101 |
102 | ```bash
103 | ## Generate data ..
104 |
105 | python -m src.pde..generate
106 |
107 | ## .. and launch.
108 |
109 | bash run/.sh
110 | ```
111 |
112 | - Cloud compute
113 |
114 | ```bash
115 | ## Using [SkyPilot](https://skypilot.readthedocs.io).
116 | ## Launch cloud jobs based on .yaml configurations ..
117 |
118 | sky launch -c ns run/navierstokes.yaml
119 | sky exec ns bash run/navierstokes.sh --env seed=
120 |
121 | ## .. and collect data.
122 |
123 | rsync -Pavzr ns:~/sky_workdir/log log/gcp
124 | ```
125 |
126 | ### Plot results
127 |
128 | ```bash
129 | python -m plot.
130 | ```
131 |
132 | ### Troubleshooting
133 |
134 | 1. The `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variable in the script maximizes memory utilization. This is beneficial if your machine has limited memory (e.g. less than 32GB of GPU memory), but it can lead to initialization issues in certain edge cases. If such issues arise, simply remove the line.
135 |
136 | 2. For machines with very limited memory (e.g. 16GB of GPU memory), consider setting the `vmap` environment variable to a small integer. The batch dimension is then processed with a loop, and only `vmap` elements are vectorized at a time instead of the whole batch. While this approach saves GPU memory, it makes the FLOP measurements for each model inaccurate. Keep this in mind when interpreting the cost estimation message in each run.
137 |
138 | ## Timeline
139 |
140 | - 2023.12.9: initial commit
141 | - 2023.12.11: add arXiv link
142 | - 2023.12.22: release code
143 | - 2024.01.19: update citation
144 | - 2024.01.31: update ICLR citation
145 |
146 | ## Citation
147 |
148 | If you find this repository useful, please cite our work:
149 |
150 | ```
151 | @article{du2024neural,
152 | title={Neural Spectral Methods: Self-supervised learning in the spectral domain},
153 | journal={The Twelfth International Conference on Learning Representations},
154 | author={Du, Yiheng and Chalapathi, Nithin and Krishnapriyan, Aditi},
155 | year={2024}
156 | }
157 | ```
158 |
159 |
--------------------------------------------------------------------------------
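The abstract's central step is the spectral loss. It replaces a quadrature of the squared residual over the spatiotemporal grid with a sum of squared spectral coefficients, using Parseval's identity. The identity can be checked numerically. The following is a minimal NumPy sketch under several assumptions: a 1D periodic domain [0, 2π), a Fourier basis, and a made-up residual. The function names are illustrative and are not taken from `src/`. According to the abstract, NSM works with the residual's coefficients directly rather than evaluating on a grid, so the FFT here only serves to verify the identity.

```python
import numpy as np


def grid_loss(residual: np.ndarray) -> float:
    """Mean squared residual on a uniform periodic grid (rectangle-rule quadrature)."""
    return float(np.mean(residual ** 2))


def spectral_loss(coeffs: np.ndarray) -> float:
    """Sum of squared magnitudes of normalized Fourier coefficients (Parseval's identity)."""
    return float(np.sum(np.abs(coeffs) ** 2))


# Hypothetical residual made of a few low-frequency modes on [0, 2*pi)
n = 64
x = 2 * np.pi * np.arange(n) / n
residual = 0.5 * np.sin(3 * x) + 0.2 * np.cos(5 * x) + 0.1

# Normalize so that residual(x_j) = sum_k coeffs[k] * exp(i * k * x_j)
coeffs = np.fft.fft(residual) / n

print(grid_loss(residual), spectral_loss(coeffs))  # both approximately 0.155
```

Both numbers agree up to floating-point rounding. With a truncated spectral representation, the sum runs over a fixed number of modes no matter how fine the evaluation grid is, which is consistent with the abstract's claim of resolution-independent cost.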
/ckpt.py:
--------------------------------------------------------------------------------
1 | from src import *
2 | from threading import *
3 |
4 | class Checkpoint(Thread):
5 |
6 | def __init__(self, **kwargs):
7 | super().__init__(daemon=True)
8 |
9 | from queue import Queue
10 | self.metric = Queue()
11 | self.predict = Lock()
12 |
13 | # ---------------------------------------------------------------------------- #
14 | # INITIALIZE #
15 | # ---------------------------------------------------------------------------- #
16 |
17 | import datetime, os
18 | time = datetime.datetime.now().strftime("%c")
19 |
20 | self.path = "log/" + (kwargs["note"] or time)
21 | self.title = kwargs["pde"] + ":" + kwargs["model"]
22 | if kwargs["spectral"]: self.title += "+spectral"
23 |
24 | os.makedirs(self.path, exist_ok=False)
25 | with open(f"{self.path}/cfg", "w") as f: print(kwargs, file=f)
26 |
27 | @property
28 | def prediction(self):
29 |
30 | with self.predict:
31 | return self._predict
32 |
33 | @prediction.setter
34 | def prediction(self, answer):
35 |
36 | with self.predict:
37 | self._predict = answer
38 |
39 | def run(self):
40 |
41 | from collections import defaultdict
42 | log = defaultdict(lambda: list())
43 |
44 | while True:
45 |
46 | # ---------------------------------------------------------------------------- #
47 | # CHECKPOINT #
48 | # ---------------------------------------------------------------------------- #
49 |
50 | try:
51 | for _ in range(max(1, self.metric.qsize())):
52 | metric, it = self.metric.get()
53 |
54 | for key, value in metric.items():
55 | log[key].append(value)
56 |
57 | except: pass
58 |
59 | import matplotlib.pyplot as plt
60 | import matplotlib.colors as clr
61 |
62 | import scienceplots; plt.style.use(["science", "no-latex"])
63 |
64 | # ---------------------------------------------------------------------------- #
65 | # METRIC #
66 | # ---------------------------------------------------------------------------- #
67 |
68 | for key, values in log.items():
69 |
70 | fig, ax = plt.subplots()
71 | xs = np.arange(len(values))+1
72 |
73 | if all(isinstance(value, Array) for value in values):
74 |
75 | ax.plot(xs, ys:=np.array(values), label=key)
76 | np.save(f"{self.path}/metric.{key}.npy", ys)
77 |
78 | if all(isinstance(value, Dict) for value in values):
79 |
80 | for subkey in set.union(*map(set, map(dict.keys, values))):
81 | ys = np.array(list(map(O.itemgetter(subkey), values)))
82 |
83 | ax.plot(xs, ys, label=f"{key}:{subkey}")
84 | np.save(f"{self.path}/metric.{key}:{subkey}.npy", ys)
85 |
86 | ax.set_xscale("log")
87 | ax.set_yscale("log")
88 | ax.set_title(self.title)
89 | ax.legend()
90 |
91 | fig.savefig(f"{self.path}/metric.{key}.jpg")
92 | plt.close(fig)
93 |
94 | # ---------------------------------------------------------------------------- #
95 | # IMAGE #
96 | # ---------------------------------------------------------------------------- #
97 |
98 | u, uhat = self.prediction
99 | np.save(f"{self.path}/uhat.npy", uhat)
100 |
101 | N = max(min(16, len(u)), 2)
102 | u, uhat = u[:N], uhat[:N]
103 |
104 | fig = plt.figure(figsize=(5 * N, 5 * 3))
105 | subfig = fig.subfigures(ncols=N)
106 |
107 | def create(subfig, u: X, uhat: X, i: int = None):
108 | axes = subfig.subplots(nrows=3)
109 |
110 | vmin, vmax = u.min().item(), u.max().item()
111 | if i is not None: u = u[i]; uhat = uhat[i]
112 |
113 | true = axes[0].imshow(u, vmin=vmin, vmax=vmax)
114 | pred = axes[1].imshow(uhat, vmin=vmin, vmax=vmax)
115 |
116 | vlim = max(abs(vmin), abs(vmax))
117 | diff = axes[2].imshow(uhat - u,
118 | cmap="Spectral",
119 | norm=clr.SymLogNorm(
120 | linthresh=vlim / 100,
121 | vmin=-vlim / 10,
122 | vmax=+vlim / 10,
123 | )
124 | )
125 |
126 | axes[0].axis("off")
127 | axes[1].axis("off")
128 | axes[2].axis("off")
129 |
130 | subfig.colorbar(true, ax=axes[:2])
131 | subfig.colorbar(diff, ax=axes[2:])
132 |
133 | return true, pred, diff
134 |
135 | if u.ndim-2 == 2:
136 |
137 | _ = list(map(create, subfig, u, uhat))
138 | fig.savefig(f"{self.path}/image.jpg")
139 |
140 | plt.close(fig)
141 |
142 | # ---------------------------------------------------------------------------- #
143 | # EXIT #
144 | # ---------------------------------------------------------------------------- #
145 |
146 | if it is None: break
147 |
--------------------------------------------------------------------------------
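`Checkpoint` in `ckpt.py` keeps plotting off the training loop. The loop pushes `(metric, it)` tuples onto a `Queue`. A daemon thread drains whatever has accumulated, and a lock guards the latest prediction. An iteration value of `None` acts as the stop sentinel. The sketch below strips this down to the standard library to illustrate the producer/consumer pattern. `MetricLogger` and its attributes are hypothetical names, and appending to a dict stands in for the plotting.

```python
from collections import defaultdict
from queue import Queue
from threading import Lock, Thread


class MetricLogger(Thread):
    """Hypothetical consumer thread mirroring the structure of ckpt.Checkpoint."""

    def __init__(self):
        super().__init__(daemon=True)
        self.metric = Queue()         # producer puts (metric_dict, iteration)
        self._pred_lock = Lock()      # guards the latest prediction
        self._pred = None
        self.log = defaultdict(list)  # key -> list of logged values

    @property
    def prediction(self):
        with self._pred_lock:
            return self._pred

    @prediction.setter
    def prediction(self, value):
        with self._pred_lock:
            self._pred = value

    def run(self):
        while True:
            # Block for at least one item, then drain whatever else is already queued.
            for _ in range(max(1, self.metric.qsize())):
                metric, it = self.metric.get()
                for key, value in metric.items():
                    self.log[key].append(value)
            # ... expensive plotting / saving of self.prediction would go here ...
            if it is None:  # stop sentinel: exit after the final flush
                break


logger = MetricLogger()
logger.start()
for it in range(5):
    logger.metric.put(({"loss": 1.0 / (it + 1)}, it))
    logger.prediction = ("u", "uhat")  # placeholder for (ground truth, prediction)
logger.metric.put(({"loss": 0.1}, None))  # final metrics plus the stop sentinel
logger.join()
print(len(logger.log["loss"]))  # 6
```

Each pass drains the items already queued before plotting once, so the consumer re-plots at most once per pass no matter how many iterations it fell behind. This appears to be the purpose of the `qsize()` loop in the original.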
/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASK-Berkeley/Neural-Spectral-Methods/090e7a173f27734dbae5ec479a410ca1748981af/demo.gif
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from src import *
2 | from src.pde import *
3 | from src.model import *
4 | from src.basis import *
5 |
6 | def main(cfg: Dict[str, Any]):
7 |
8 | # configure precision at the very beginning
9 | jax.config.update("jax_enable_x64", cfg["f64"])
10 |
11 | if cfg["smi"]:
12 |
13 | import jax_smi as smi
14 | smi.initialise_tracking()
15 |
16 | prng = random.PRNGKey(cfg["seed"])
17 | rngs = RNGS(prng, ["params", "sample"])
18 |
19 | from importlib import import_module
20 |
21 | if cfg["pde"] is not None:
22 |
23 | col, name = cfg["pde"].split(".", 2)
24 | mod = import_module(f"src.pde.{col}")
25 |
26 | pde: PDE = getattr(mod, name)
27 | pde.solution # load solution
28 |
29 | if cfg["model"] is not None:
30 |
31 | col, name = cfg["model"], cfg["model"].upper()
32 |
33 | if cfg["spectral"]: col += ".spectral"
34 | mod = import_module(f"src.model.{col}")
35 |
36 | Model: Solver = getattr(mod, name)
37 | model = Model(pde, cfg)
38 |
39 | # ---------------------------------------------------------------------------- #
40 | # TRAIN #
41 | # ---------------------------------------------------------------------------- #
42 |
43 | if cfg["action"] == "train":
44 |
45 | from src.train import step, eval
46 | train = Trainer(model, pde, cfg)
47 | variable, state = train.init_with_output(next(rngs), method="init")
48 |
49 | step = utils.jit(F.partial(train.apply, method=step, mutable=True))
50 | step(state, variable, rngs=next(rngs)) # compile train iteration
51 |
52 | def evaluate():
53 |
54 | global metric, predictions
55 | metric, predictions = train.apply(state, variable,
56 | method=eval, rngs=next(rngs))
57 |
58 | if cfg["save"]:
59 |
60 | from flax.training.checkpoints import save_checkpoint
61 | save_checkpoint(f"{ckpt.path}/variable", variable, it)
62 |
63 | from ckpt import Checkpoint
64 | ckpt = Checkpoint(**cfg)
65 | ckpt.start()
66 |
67 | from tqdm import trange
68 | for it in (pbar:=trange(cfg["iter"])):
69 |
70 | if not it % cfg["ckpt"]: evaluate()
71 |
72 | # ----------------------------------- STEP ----------------------------------- #
73 |
74 | import time
75 | rate = time.time()
76 |
77 | (variable, loss), state = jax.tree_map(jax.block_until_ready,
78 | step(state, variable, rngs=next(rngs)))
79 |
80 | rate = np.array(time.time() - rate)
81 |
82 | # ----------------------------------- CKPT ----------------------------------- #
83 |
84 | metric.update(loss=loss, rate=rate)
85 |
86 | ckpt.metric.put((metric.copy(), it))
87 | ckpt.prediction = predictions
88 |
89 | pbar.set_postfix(jax.tree_map(lambda x: f"{x:.2e}", metric))
90 |
91 | evaluate()
92 |
93 | ckpt.metric.put((metric, None))
94 | ckpt.prediction = predictions
95 | ckpt.join()
96 |
97 | return pde, model.bind(variable, rngs=next(rngs))
98 |
99 | # ---------------------------------------------------------------------------- #
100 | # LOAD #
101 | # ---------------------------------------------------------------------------- #
102 |
103 | else:
104 |
105 | if cfg["load"]:
106 |
107 | from flax.training.checkpoints import restore_checkpoint
108 | variable = restore_checkpoint(cfg["load"], target=None)
109 | model = model.bind(variable, rngs=next(rngs))
110 |
111 | exit(utils.repl(locals()))
112 |
113 | # ---------------------------------------------------------------------------- #
114 | # ARGPARSE #
115 | # ---------------------------------------------------------------------------- #
116 |
117 | if __name__ == "__main__":
118 |
119 | import argparse
120 | args = argparse.ArgumentParser()
121 | action = args.add_subparsers(dest="action")
122 |
123 | args.add_argument("--seed", type=int, default=19260817, help="random seed")
124 | args.add_argument("--f64", dest="f64", action="store_true", help="use double precision")
125 | args.add_argument("--smi", dest="smi", action="store_true", help="profile memory usage")
126 |
127 | args.add_argument("--pde", type=str, help="PDE name")
128 | args.add_argument("--model", type=str, help="model name", choices=["fno", "sno"]) # --cheb=cno
129 | args.add_argument("--spectral", dest="spectral", action="store_true", help="spectral training")
130 |
131 | # ----------------------------------- MODEL ---------------------------------- #
132 |
133 | args.add_argument("--hdim", type=int, help="hidden dimension")
134 | args.add_argument("--depth", type=int, help="number of layers")
135 | args.add_argument("--activate", type=str, help="activation name")
136 |
137 | args.add_argument("--mode", type=int, nargs="+", help="number of modes per dim")
138 | args.add_argument("--grid", type=int, default=256, help="training grid size")
139 |
140 | ## ablation study
141 |
142 | args.add_argument("--fourier", dest="fourier", action="store_true", help="fourier basis only")
143 | args.add_argument("--cheb", dest="cheb", action="store_true", help="using chebyshev")
144 |
145 | # ----------------------------------- TRAIN ---------------------------------- #
146 |
147 | args_train = action.add_parser("train", help="train model from scratch")
148 |
149 | args_train.add_argument("--bs", type=int, required=True, help="batch size")
150 | args_train.add_argument("--lr", type=float, required=True, help="learning rate")
151 | args_train.add_argument("--clip", type=float, required=False, help="gradient clipping")
152 | args_train.add_argument("--schd", type=str, required=True, help="scheduler name")
153 | args_train.add_argument("--iter", type=int, required=True, help="total iterations")
154 | args_train.add_argument("--ckpt", type=int, required=True, help="checkpoint every n iters")
155 | args_train.add_argument("--note", type=str, required=True, help="leave a note here")
156 |
157 | args_train.add_argument("--vmap", type=lambda x: int(x) if x else None, help="vectorization size")
158 | args_train.add_argument("--save", dest="save", action="store_true", help="save model checkpoints")
159 |
160 | # ----------------------------------- TEST ----------------------------------- #
161 |
162 | args_test = action.add_parser("test", help="enter REPL after loading")
163 |
164 | args_test.add_argument("--load", type=str, help="saved model path")
165 |
166 | # ---------------------------------------------------------------------------- #
167 | # MAIN #
168 | # ---------------------------------------------------------------------------- #
169 |
170 | args = args.parse_args()
171 | cfg = vars(args); print(f"{cfg=}")
172 |
173 | pde, model = main(cfg)
174 | # utils.repl(locals())
175 |
--------------------------------------------------------------------------------
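The `--vmap` option in `main.py`, which is also item 2 of the README's troubleshooting section, trades speed for memory: the batch is looped over in chunks and vectorized only within each chunk. The actual implementation lives in `src/` and is not shown here. The sketch below illustrates the general technique with NumPy and is not the repository's code. `chunked_map` is a hypothetical helper; `parse_vmap` copies the `type=` conversion used for `--vmap`. In JAX, `batched_fn` would typically be a `jax.vmap`-ed per-example function, and the Python loop could be replaced by `jax.lax.map` over a reshaped batch.

```python
from typing import Callable, Optional

import numpy as np


def parse_vmap(x: str) -> Optional[int]:
    """Same conversion main.py applies to --vmap: empty string -> None, else int."""
    return int(x) if x else None


def chunked_map(batched_fn: Callable[[np.ndarray], np.ndarray],
                xs: np.ndarray, chunk: Optional[int]) -> np.ndarray:
    """Apply `batched_fn` (which accepts a leading batch axis) to `xs` chunk by chunk.

    chunk=None (or 0) runs the whole batch in one vectorized call; a small
    integer keeps only `chunk` examples in flight at a time.
    """
    if not chunk or chunk >= len(xs):
        return batched_fn(xs)
    outs = [batched_fn(xs[i:i + chunk]) for i in range(0, len(xs), chunk)]
    return np.concatenate(outs, axis=0)


xs = np.arange(10.0).reshape(10, 1)
square = lambda batch: batch ** 2  # stands in for a vectorized model call

full = chunked_map(square, xs, parse_vmap(""))     # one call on all 10 rows
looped = chunked_map(square, xs, parse_vmap("3"))  # calls on 3, 3, 3 and 1 rows
print(np.array_equal(full, looped))  # True
```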
/plot/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import scienceplots
4 | import functools as F
5 |
6 | import matplotlib.pyplot as plt
7 | import matplotlib.colors as clr
8 | import matplotlib.ticker as tkr
9 |
10 | plt.style.use(style=["science", "nature", "std-colors", "grid"])
11 | CLR = plt.rcParams["axes.prop_cycle"].by_key()["color"]
12 | CLR = [CLR[0], CLR[2], CLR[3], CLR[1]]
13 |
14 | TIMES = "\\!\\times\\!"
15 | PLUS = "\\!+\\!"
16 |
17 | def load(pde: str, metric: str, dir: str, METHOD):
18 |
19 | def random(method: str):
20 | def call(seed: int):
21 | try: return np.load(f"{dir}/{pde}{method}.{seed}/{metric}.npy")
22 | except: return np.array([np.nan])
23 | return call
24 |
25 | return { method: list(map(random(method), range(4))) for method in METHOD }
26 |
27 | def color(method: str) -> str:
28 |
29 | if method[0] == ":":
30 |
31 | if method[-1] == "M":
32 |
33 | return CLR[0]
34 |
35 | return CLR[-1]
36 |
37 | if method[0] == "x":
38 |
39 | if method[-1] == "C":
40 |
41 | return CLR[2]
42 |
43 | return CLR[1]
44 |
45 | def lines(method: str) -> str:
46 |
47 | if method[0] == ":":
48 |
49 | return "-"
50 |
51 | if method[0] == "x":
52 |
53 | if method[-1] == "C":
54 |
55 | return "-"
56 |
57 | return {
58 | "64": ":",
59 | "128": "--",
60 | "256": "-",
61 |
62 | "96": "-.",
63 | }[method[1:]]
64 |
65 | def reorder(ax, order):
66 |
67 | handles, labels = ax.get_legend_handles_labels()
68 | return [handles[i] for i in order], [labels[i] for i in order]
69 |
--------------------------------------------------------------------------------
/plot/animation.py:
--------------------------------------------------------------------------------
1 | from . import *
2 | from sys import argv
3 |
4 | N = list(map(int, argv[1:]))
5 |
6 | u = np.load("log/aws/u.npy")[np.r_[N]]
7 | unsm = np.load(f"log/aws/unsm.npy")[np.r_[N]]
8 | ufno = np.load(f"log/aws/ux96.npy")[np.r_[N]]
9 |
10 | fig, ax = plt.subplots(ncols=len(N), nrows=3, figsize=(2.1*len(N), 4),
11 | constrained_layout=True, squeeze=False, dpi=72)
12 |
13 | fig.set_constrained_layout_pads(hspace=0.1, wspace=0.1)
14 |
15 | def create(sol, nsm, fno, i: int):
16 |
17 | N, K = 15, 4 # remove dark colors
18 | vmin = u[i].min() * (N + 2*K) / N
19 | vmax = u[i].max() * (N + 2*K) / N
20 |
21 | cmap = plt.get_cmap("twilight_shifted")
22 | norm = clr.BoundaryNorm(np.linspace(vmin, vmax, N+2*K+1), cmap.N)
23 |
24 | im_sol = sol.imshow(u[i, 0], cmap=cmap, norm=norm)
25 | im_nsm = nsm.imshow(unsm[i, 0], cmap=cmap, norm=norm)
26 | im_fno = fno.imshow(ufno[i, 0], cmap=cmap, norm=norm)
27 |
28 | sol.set_xticks([]); sol.set_yticks([])
29 | nsm.set_xticks([]); nsm.set_yticks([])
30 | fno.set_xticks([]); fno.set_yticks([])
31 |
32 | return im_sol, im_nsm, im_fno
33 |
34 | im = []
35 | for i in range(len(N)):
36 | im.append(create(*ax[:, i], i))
37 |
38 | kw = dict(rotation=90, labelpad=10, fontsize=10)
39 | ax[0, 0].set_ylabel("Numerical solver", **kw)
40 | ax[1, 0].set_ylabel("NSM (ours) Prediction", **kw)
41 | ax[2, 0].set_ylabel("FNO+PINN Prediction", **kw)
42 |
43 | plt.colorbar(im[-1][0],fraction=0.046, pad=0.1, ticks=[-2, -1, 0, 1])
44 | plt.colorbar(im[-1][1],fraction=0.046, pad=0.1, ticks=[-2, -1, 0, 1])
45 | plt.colorbar(im[-1][2],fraction=0.046, pad=0.1, ticks=[-2, -1, 0, 1])
46 |
47 | def frame(index, total: int = 64):
48 | i = int(total * (index/T/fps))
49 |
50 | for n, [sol, nsm, fno] in enumerate(im):
51 |
52 | sol.set_array(u[n, i])
53 | nsm.set_array(unsm[n, i])
54 | fno.set_array(ufno[n, i])
55 |
56 | return sum(im, ())
57 |
58 | from matplotlib import animation
59 | ani = animation.FuncAnimation(fig, frame, (T:=8)*(fps:=24), interval=1000/fps, blit=True)
60 | ani.save("plot/animation.gif", writer="pillow", fps=fps)
61 |
--------------------------------------------------------------------------------
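`plot/animation.py` and `plot/navierstokes.py` both "remove dark colors". They widen the color range by a factor of (N + 2K)/N and quantize it with `BoundaryNorm`, so the data fall in the middle bins and the darkest ends of `twilight_shifted` go unused. The NumPy sketch below checks that arithmetic without matplotlib. It assumes a field whose range is symmetric about zero, which is roughly true for vorticity. With an asymmetric range, the padding is no longer exactly K bins per side. It uses N + 2K + 1 edges (N + 2K bins), as `plot/navierstokes.py` does. `padded_edges` is a hypothetical helper name.

```python
import numpy as np


def padded_edges(vmin: float, vmax: float, n: int = 15, k: int = 4) -> np.ndarray:
    """Widen [vmin, vmax] by (n + 2k) / n and split it into n + 2k equal bins."""
    scale = (n + 2 * k) / n
    return np.linspace(vmin * scale, vmax * scale, n + 2 * k + 1)


edges = padded_edges(-2.0, 2.0)              # symmetric toy range, e.g. vorticity
field = np.linspace(-2.0, 2.0, 1001)[1:-1]  # interior values of the field

# np.digitize returns i with edges[i-1] <= x < edges[i]; subtract 1 for a 0-based bin index
bins = np.digitize(field, edges) - 1
print(len(edges) - 1, bins.min(), bins.max())  # 23 4 18
```

The field occupies bins 4 to 18, which is N = 15 bins, and leaves K = 4 unused bins at each end of the colormap.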
/plot/navierstokes.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
3 | METHOD = ["x64", "x96", ":NSM"]
4 | LABELS = [f"FNO${TIMES}64^3$+PINN", f"FNO${TIMES}96^3$+PINN", "NSM (ours)"]
5 |
6 | res = ["re2", "re3", "re4"]
7 |
8 | def load(pde: str, metric: str, dir: str = "log/gcp/ns.T3"):
9 |
10 | def random(method: str):
11 | def call(seed: int):
12 | try:
13 | if method == "x64": return np.load(f"{dir}/{pde}{method}.{seed}/{metric}.npy")[:int(3600*24 // 0.69)]
14 | if method == "x96": return np.load(f"{dir}/{pde}{method}.{seed}/{metric}.npy")[:int(3600*72 // 2.33)]
15 | return np.load(f"{dir}/{pde}{method}.{seed}/{metric}.npy")
16 | except: print("oh no not finished!"); return np.array([np.inf])
17 | return call
18 |
19 | return { method: list(map(random(method), range(4))) for method in METHOD }
20 |
21 | # ---------------------------------------------------------------------------- #
22 | # TABLE #
23 | # ---------------------------------------------------------------------------- #
24 |
25 | for re in res:
26 |
27 | print(f"{re = }", end=":\t")
28 |
29 | for method, errs in load(re, "metric.errr").items():
30 |
31 | errs = np.array([err[-1] for err in errs]) * 100
32 | print(f"{np.mean(errs):.2f}±{np.std(errs):.2f}", end="\t")
33 |
34 | print("")
35 |
36 | # ---------------------------------------------------------------------------- #
37 | # VISUALIZE #
38 | # ---------------------------------------------------------------------------- #
39 |
40 | import sys
41 | if len(sys.argv) > 1:
42 |
43 | N = int(sys.argv[1])
44 |
45 | fig, [ok, nsm, fno] = plt.subplots(ncols=(T:=3)+1, nrows=3, figsize=(T*2, 5))
46 |
47 | u = np.load("log/gcp/ns.T3/u.npy")[N]
48 | unsm = np.load(f"log/{sys.argv[2]}/uhat.npy")[N]
49 | ufno = np.load(f"log/{sys.argv[3]}/uhat.npy")[N]
50 |
51 | # remove dark colors
52 |
53 | L, K = 15, 4
54 | vmin = u.min() * (L+2*K)/L
55 | vmax = u.max() * (L+2*K)/L
56 |
57 | def show(ax, img):
58 |
59 | cmap = plt.get_cmap("twilight_shifted")
60 | norm = clr.BoundaryNorm(np.linspace(vmin, vmax, L+2*K+1), cmap.N)
61 | ax.imshow(img, cmap=cmap, norm=norm)
62 | ax.set_xticks([])
63 | ax.set_yticks([])
64 |
65 | for i in range(T+1):
66 |
67 | t = int(63 * i/T)
68 |
69 | show(ok[i], u[t])
70 | show(nsm[i], unsm[t])
71 | show(fno[i], ufno[t])
72 |
73 | ok[0].set_title("Initial vorticity", fontsize=10)
74 | for t in range(1, T + 1):
75 | ok[t].set_title(f"$T={t}$", fontsize=10)
76 |
77 | ok[0].set_ylabel("Numerical solver", rotation=90, labelpad=10, fontsize=10)
78 | nsm[0].set_ylabel("NSM (ours) prediction", rotation=90, labelpad=10, fontsize=10)
79 | fno[0].set_ylabel("FNO+PINN prediction", rotation=90, labelpad=10, fontsize=10)
80 |
81 | ok[0].text(-0.21, 1.03, "(c)", transform=ok[0].transAxes, fontsize=10)
82 |
83 | fig.tight_layout()
84 | fig.savefig(f"plot/navierstokes.{N}.jpg", dpi=300)
85 |
86 | # ---------------------------------------------------------------------------- #
87 | # CURVE #
88 | # ---------------------------------------------------------------------------- #
89 |
90 | for re in ["re4"]:
91 |
92 | figure = plt.figure(figsize=(3, 4.5))
93 | res, err = figure.subplots(nrows=2, sharex=True)
94 |
95 | pct = tkr.FuncFormatter(lambda y, _: f"{y:.1%}"[:3])
96 |
97 | rate = load(re, "metric.rate")
98 |
99 | # --------------------------------- RESIDUAL --------------------------------- #
100 |
101 | residual = load(re, "metric.residual")
102 |
103 | for method, label in zip(METHOD, LABELS):
104 |
105 | xs = np.linspace(1, max(map(np.sum, rate[method])), N:=1000)
106 |
107 | ys = []
108 | for rat, rrr in zip(rate[method], residual[method]):
109 |
110 | a = np.concatenate([np.zeros(1), np.cumsum(rat)])[:-1]
111 |
112 | ys.append(np.interp(xs, a, rrr[:len(a)]))
113 |
114 | ys_mean = np.mean(np.stack(ys, axis=0), axis=0)
115 | ys_std = np.std(np.stack(ys, axis=0), axis=0)
116 |
117 | res.plot(xs, ys_mean, label=label, c=(c:=color(method)), ls=lines(method))
118 | res.fill_between(xs, ys_mean-ys_std, ys_mean+ys_std, alpha=0.2, color=c, edgecolor=None)
119 |
120 | res.set_xscale("log")
121 | res.set_yscale("log")
122 |
123 | res.xaxis.set_label_position("top")
124 | res.set_xlabel("Training time (seconds)", fontsize=10)
125 | res.set_ylabel(f"PDE residual on test set", fontsize=10)
126 |
127 | res.yaxis.grid(False, which='minor')
128 | res.set_yticks([0.1, 0.08, 0.06, 0.04, 0.02])
129 | res.yaxis.set_major_formatter(lambda x, _: str(x))
130 | res.yaxis.set_minor_formatter(lambda *_: "")
131 |
132 | err.set_xlim(10, 3e5)
133 | err.xaxis.tick_top()
134 | for tick in err.xaxis.get_majorticklabels():
135 | tick.set_horizontalalignment("left")
136 |
137 | # ----------------------------------- ERROR ---------------------------------- #
138 |
139 | # TRAIN
140 |
141 | errr = load(re, "metric.errr")
142 |
143 | for method, label in zip(METHOD, LABELS):
144 |
145 | xs = np.linspace(1, max(map(np.sum, rate[method])), N:=1000)
146 |
147 | ys = []
148 | for rat, rrr in zip(rate[method], errr[method]):
149 |
150 | a = np.concatenate([np.zeros(1), np.cumsum(rat)])[:-1]
151 |
152 | ys.append(np.interp(xs, a, rrr[:len(a)]))
153 |
154 | ys_mean = np.mean(np.stack(ys, axis=0), axis=0)
155 | ys_std = np.std(np.stack(ys, axis=0), axis=0)
156 |
157 | err.plot(xs, ys_mean, label=label, c=(c:=color(method)), ls=lines(method))
158 | err.fill_between(xs, ys_mean-ys_std, ys_mean+ys_std, alpha=0.2, color=c, edgecolor=None)
159 |
160 | err.set_xscale("log")
161 | err.set_yscale("log")
162 |
163 | err.set_ylabel(f"$L_2$ rel. error (\\%) on test set", fontsize=10)
164 |
165 | err.yaxis.set_major_formatter(pct)
166 | err.set_yticks([0.05, 0.1, 0.2, 0.3])
167 |
168 | res.legend(fontsize=7, loc="lower left", handlelength=1.7)
169 | err.legend(fontsize=7, loc="lower left", handlelength=1.7)
170 |
171 | res.xaxis.set_tick_params(labelsize=8)
172 | res.yaxis.set_tick_params(labelsize=8)
173 | err.yaxis.set_tick_params(labelsize=8)
174 | err.xaxis.set_tick_params(labelsize=8)
175 |
176 | res.yaxis.set_label_coords(-0.15, 0.5)
177 | err.yaxis.set_label_coords(-0.15, 0.5)
178 |
179 | res.text(-0.21, 1, "(a)", transform=res.transAxes, fontsize=10)
180 | err.text(-0.21, 1, "(b)", transform=err.transAxes, fontsize=10)
181 |
182 | figure.tight_layout(h_pad=0.1)
183 | figure.savefig(f"plot/ns.curve.{re}.jpg", dpi=300)
184 |
--------------------------------------------------------------------------------
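The CURVE section of `plot/navierstokes.py` puts every run on a common wall-clock axis in three steps:

1. It turns the per-iteration timings in `metric.rate` into cumulative training time.
2. It interpolates each seed's metric onto a shared time grid with `np.interp`.
3. It averages across seeds.

A self-contained sketch of that step follows. `align_on_walltime` is a hypothetical name, and the toy timings are invented for illustration.

```python
import numpy as np


def align_on_walltime(rates, values, num=1000):
    """Resample per-seed metric curves onto a shared wall-clock grid.

    rates:  list of 1D arrays, seconds spent on each iteration (one array per seed)
    values: list of 1D arrays, metric logged at each iteration (one array per seed)
    Returns (grid, mean, std) across seeds.
    """
    grid = np.linspace(1, max(np.sum(r) for r in rates), num)
    curves = []
    for rate, value in zip(rates, values):
        # Elapsed time at the start of each iteration: 0, r0, r0 + r1, ...
        # (strictly increasing as long as every iteration takes > 0 s)
        elapsed = np.concatenate([np.zeros(1), np.cumsum(rate)])[:-1]
        # Beyond the last time stamp np.interp repeats the last value
        curves.append(np.interp(grid, elapsed, value[:len(elapsed)]))
    curves = np.stack(curves, axis=0)
    return grid, curves.mean(axis=0), curves.std(axis=0)


# Toy example: seed A takes 0.5 s per iteration, seed B takes 1.0 s
rates = [np.full(100, 0.5), np.full(50, 1.0)]
values = [1.0 / np.arange(1, 101), 1.0 / np.arange(1, 51)]
grid, mean, std = align_on_walltime(rates, values, num=5)
```

Aligning on wall-clock time rather than iteration count is what allows models with very different per-step costs to be compared in the same panel.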
/plot/poisson.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
3 | METHOD = [":SNO", "x64", "x128", "x256", ":NSM"]
4 |
5 | load = F.partial(load, dir="log/ps.dirichlet", METHOD=METHOD)
6 | bcs = ["periodic", "dirichlet"]
7 |
8 | # ---------------------------------------------------------------------------- #
9 | # TABLE #
10 | # ---------------------------------------------------------------------------- #
11 |
12 | print("="*116)
13 | print("*relu*\t\t\t", "\t\t".join(METHOD))
14 | print("-"*116)
15 |
16 | for bc in bcs:
17 |
18 | print(f"{bc = }", end=":\t")
19 |
20 | for method, errs in load("relu/"+bc, "metric.errr").items():
21 |
22 | errs = np.array([err[-1] for err in errs]) * 100
23 | print(f"{np.mean(errs):.3f}±{np.std(errs):.3f}", end=" \t")
24 |
25 | print("")
26 |
27 | print("="*116)
28 | print("*tanh*\t\t\t", "\t\t".join(METHOD))
29 | print("-"*116)
30 |
31 | for bc in bcs:
32 |
33 | print(f"{bc = }", end=":\t")
34 |
35 | for method, errs in load("tanh/"+bc, "metric.errr").items():
36 |
37 | errs = np.array([err[-1] for err in errs]) * 100
38 | print(f"{np.mean(errs):.3f}±{np.std(errs):.3f}", end=" \t")
39 |
40 | print("")
41 |
42 | print("="*116)
43 | print("*long*\t\t\t", "\t\t".join(METHOD))
44 | print("-"*116)
45 |
46 | for bc in bcs:
47 |
48 | print(f"{bc = }", end=":\t")
49 |
50 | for method, errs in load("long/"+bc, "metric.errr").items():
51 |
52 | errs = np.array([err[-1] for err in errs]) * 100
53 | print(f"{np.mean(errs):.3f}±{np.std(errs):.3f}", end=" \t")
54 |
55 | print("")
56 |
57 | print("="*116)
58 |
--------------------------------------------------------------------------------
/plot/reaction.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
3 | METHOD = [":SNO", "x64", "x128", "x256", "x256C", ":NSM"]
4 |
5 | # FNOPINN = lambda n, m="": "FNO{\\scriptsize$\\!\\times\\!"+str(n)+"^2$"+m+"}+PINN"
6 | FNOPINN = lambda n, m="": "FNO$\\!\\times\\!" + str(n) + "^2$" + m + "+PINN"
7 | LABELS = [None, FNOPINN(64, "\\ \\ "), FNOPINN(128), FNOPINN(256), f"CNO+PINN(ours)", f"NSM (ours)"]
8 |
9 | load = F.partial(load, dir="log/rd", METHOD=METHOD)
10 |
11 | # ---------------------------------------------------------------------------- #
12 | # TABLE #
13 | # ---------------------------------------------------------------------------- #
14 |
15 | for nu in ["005", "01", "05", "1"]:
16 |
17 | print(f"{nu = }", end=":\t")
18 |
19 | for method, errs in load(f"nu{nu}", "metric.errr").items():
20 |
21 | errs = np.array([err[-1] for err in errs]) * 100
22 | print(f"{np.mean(errs):.3f}±{np.std(errs):.3f}", end="\t")
23 |
24 | print("")
25 |
26 |
27 | for nu in ["01"]:
28 |
29 | figure = plt.figure(figsize=(8, 4), constrained_layout=False)
30 | train, test = figure.subfigures(ncols=2, width_ratios=[1, 1.2])
31 |
32 | train.subplots_adjust(hspace=0.31)
33 | test.subplots_adjust(left=0.08)
34 |
35 | pct = tkr.FuncFormatter(lambda y, _: f"{y:.1%}"[:3])
36 |
37 | # ---------------------------------------------------------------------------- #
38 | # TRAIN #
39 | # ---------------------------------------------------------------------------- #
40 |
41 | rate = load(f"nu{nu}", "metric.rate")
42 | res, err = train.subplots(nrows=2, sharex=True)
43 |
44 | # --------------------------------- RESIDUAL --------------------------------- #
45 |
46 | residual = load(f"nu{nu}", "metric.residual")
47 |
48 | for method, label in zip(METHOD, LABELS):
49 | if label is None: continue
50 |
51 | xs = np.linspace(1, max(map(np.sum, rate[method])), N:=1000)
52 | ys = [np.interp(xs, np.concatenate([np.zeros(1), np.cumsum(rate)]), residual)
53 | for rate, residual in zip(rate[method], residual[method])]
54 |
55 | ys_mean = np.mean(np.stack(ys, axis=0), axis=0)
56 | ys_std = np.std(np.stack(ys, axis=0), axis=0)
57 |
58 | res.plot(xs, ys_mean, label=label, c=(c:=color(method)), ls=lines(method))
59 | res.fill_between(xs, ys_mean-ys_std, ys_mean+ys_std, alpha=0.2, color=c, edgecolor=None)
60 |
61 | res.set_xscale("log")
62 | res.set_yscale("log")
63 |
64 | res.set_xlim(1, 1e4)
65 | res.set_xticks([1, 1e1, 1e2, 1e3, 1e4])
66 |
67 | # res.yaxis.set_major_formatter(tkr.ScalarFormatter())
68 | res.yaxis.set_label_coords(-0.13, None)
69 |
70 | # ---------------------------------- ERROR % --------------------------------- #
71 |
72 | errr = load(f"nu{nu}", "metric.errr")
73 |
74 | for method, label in zip(METHOD, LABELS):
75 | if label is None: continue
76 |
77 | xs = np.linspace(1, max(map(np.sum, rate[method])), N:=1000)
78 | ys = [np.interp(xs, np.concatenate([np.zeros(1), np.cumsum(rate)]), errr)
79 | for rate, errr in zip(rate[method], errr[method])]
80 |
81 | ys_mean = np.mean(np.stack(ys, axis=0), axis=0)
82 | ys_std = np.std(np.stack(ys, axis=0), axis=0)
83 |
84 | err.plot(xs, ys_mean, label=label, c=(c:=color(method)), ls=lines(method))
85 | err.fill_between(xs, ys_mean-ys_std, ys_mean+ys_std, alpha=0.2, color=c, edgecolor=None)
86 |
87 | err.set_xscale("log")
88 | err.set_yscale("log")
89 |
90 | err.set_xlim(1, 1e4)
91 | err.set_xticks([1, 1e1, 1e2, 1e3, 1e4])
92 |
93 | err.set_ylim(6e-4)
94 |
95 | err.yaxis.set_major_formatter(pct)
96 | err.yaxis.set_label_coords(-0.13, None)
97 |
98 | # ---------------------------------------------------------------------------- #
99 | # TEST #
100 | # ---------------------------------------------------------------------------- #
101 |
102 | ax = test.subplots()
103 |
104 | mkr = dict(marker=".", markersize=7)
105 |
106 | # ------------------------------------ FNO ----------------------------------- #
107 |
108 | time = np.load("log/rd/solver/time.fno.npy")
109 |
110 | for method, label in zip(METHOD, LABELS):
111 |
112 | if method[0] == "x" and method[-1] != "C":
113 |
114 | ax.plot(X:=time.mean(-1), Y:=np.load(f"log/rd/solver/errr.fnox{method[1:]}.npy"), **mkr, label=label+" (training)", c=color(method), ls=lines(method))
115 |
116 | def getl(n):
117 |
118 | style = dict(text=f"${2**n*32}{TIMES}{2**n*32}$", xy=(x:=X[n], y:=Y[n]), textcoords="offset pixels", fontsize=9)
119 |
120 | if n < 2: style.update(xytext=(x-20, y+40))
121 | if n > 1: style.update(xytext=(x-78, y+40))
122 |
123 | return style
124 |
125 | if method == "x256":
126 | for n in range(2): ax.annotate(**getl(n))
127 | if method == "x64":
128 | for n in range(2, 5): ax.annotate(**getl(n))
129 |
130 | # ---------------------------------- SOLVER ---------------------------------- #
131 |
132 | time = np.load("log/rd/solver/time.solver.npy")
133 | errr = np.load(f"log/rd/solver/errr.solver.npy")
134 |
135 | ax.plot(X:=time.mean(0), Y:=errr.mean(-1), c="black", **mkr, label="Numerical solver")
136 | for n in range(1, len(errr)+1): ax.annotate(f"${2**n}{TIMES}{2**n}$", (x:=X[n-1], y:=Y[n-1]), (x+16, y), textcoords="offset pixels", fontsize=9)
137 |
138 | # ------------------------------------ NSM ----------------------------------- #
139 |
140 | time = np.load("log/rd/solver/time.nsm.npy")
141 |
142 | ax.plot(x:=time.mean(), y:=np.load("log/rd/solver/errr.nsm.npy").mean(), **mkr, c=CLR[0], label="NSM (ours)")
143 | ax.scatter(x, y, c=CLR[0], s=20)
144 | ax.annotate(f"$32{TIMES}32$", (x, y), (x-9, y+152), textcoords="offset pixels", fontsize=9)
145 | ax.annotate(f"$64{TIMES}64$", (x, y), (x-9, y+120), textcoords="offset pixels", fontsize=9)
146 | ax.annotate(f"$128{TIMES}128$", (x, y), (x-9, y+88), textcoords="offset pixels", fontsize=9)
147 | ax.annotate(f"$256{TIMES}256$", (x, y), (x-9, y+56), textcoords="offset pixels", fontsize=9)
148 | ax.annotate(f"$512{TIMES}512$", (x, y), (x-9, y+24), textcoords="offset pixels", fontsize=9)
149 |
150 | ax.set_xscale("log")
151 | ax.set_yscale("log")
152 |
153 | ax.set_xlim(6e-4, 1e-1)
154 | ax.set_ylim(6e-4, 1e-1)
155 |
156 | ax.yaxis.set_major_formatter(pct)
157 | ax.yaxis.set_label_coords(-0.09, None)
158 |
159 | # ---------------------------------------------------------------------------- #
160 | # LAYOUT #
161 | # ---------------------------------------------------------------------------- #
162 |
163 | res.legend(*reorder(res, [0, 1, 2, 3, 4]), loc="lower center", handlelength=1.5, bbox_to_anchor=(0.5, -0.315), ncol=3, fontsize=7, labelspacing=0.5, columnspacing=1)
164 | ax.legend(loc="upper right", handlelength=1.8, ncols=2, fontsize=8, labelspacing=0.3, columnspacing=0.8)
165 |
166 | res.set_ylabel(f"PDE residual on\n $512{TIMES}512$ test res.", fontsize=10)
167 | err.set_xlabel(f"Training time (seconds)", fontsize=10)
168 | err.set_ylabel(f"$L_2$ rel. error (\\%) on\n $512{TIMES}512$ test res.", fontsize=10)
169 |
170 | ax.set_xlabel("Inference time (seconds)", fontsize=10)
171 | ax.set_ylabel("$L_2$ rel. error (\\%) on different test resolutions", fontsize=10)
172 |
173 | res.xaxis.set_tick_params(labelsize=8)
174 | res.yaxis.set_tick_params(labelsize=8)
175 | err.xaxis.set_tick_params(labelsize=8)
176 | err.yaxis.set_tick_params(labelsize=8)
177 | ax.xaxis.set_tick_params(labelsize=8)
178 | ax.yaxis.set_tick_params(labelsize=8)
179 |
180 | res.text(-0.23, 1, "(a)", transform=res.transAxes, fontsize=10)
181 | err.text(-0.23, 1, "(b)", transform=err.transAxes, fontsize=10)
182 | ax.text(-0.13, 1, "(c)", transform=ax.transAxes, fontsize=10)
183 |
184 | figure.savefig(f"plot/re.curve.jpg", dpi=300)
185 |
186 | # ---------------------------------------------------------------------------- #
187 | # BOX #
188 | # ---------------------------------------------------------------------------- #
189 |
190 | FNOPINN = lambda n, m="": "FNO$\\!\\times\\!" + str(n) + "^2$" + m
191 | LABELS = [None, FNOPINN(64, "\\ \\ "), FNOPINN(128), FNOPINN(256), f"CNO (ours)", f"NSM (ours)"]
192 |
193 | fig, axes = plt.subplots(figsize=(16, 3.4), ncols=4)
194 | for ax, nu in zip(axes, ["005", "01", "05", "1"]):
195 |
196 | u = np.load(f"src/pde/reaction/u.rho=5:nu=0.{nu.ljust(3, '0')}.npy")
197 | uhat = load(f"nu{nu}", "uhat")
198 |
199 | ax.boxplot([np.concatenate([np.mean(np.abs(uhat-u), axis=(1, 2, 3))
200 | for uhat in uhat[method][:1]]) for method in METHOD[1:]], labels=LABELS[1:])
201 |
202 | ax.set_title(f"Absolute error distribution for $\\nu=0.{nu}$", fontsize=10)
203 |
204 | ax.set_ylim(0, 0.06)
205 |
206 | axes[0].text(-0.08, 1.04, "(a)", transform=axes[0].transAxes, fontsize=10)
207 | axes[1].text(-0.08, 1.04, "(b)", transform=axes[1].transAxes, fontsize=10)
208 | axes[2].text(-0.08, 1.04, "(c)", transform=axes[2].transAxes, fontsize=10)
209 | axes[3].text(-0.08, 1.04, "(d)", transform=axes[3].transAxes, fontsize=10)
210 |
211 | fig.savefig(f"plot/re.box.jpg", dpi=300)
212 |
--------------------------------------------------------------------------------
/plot/reaction.solver.py:
--------------------------------------------------------------------------------
1 | from src.utils import *
2 | from src.pde.reaction import *
3 | from src.pde.reaction.generate import *
4 |
5 | from tqdm import tqdm
6 |
7 | def time(N: int = 12) -> X:
8 |
9 | h, s, u = nu005.solution
10 | h = h.map(lambda x: x[0])
11 |
12 | def check(n: int) -> float:
13 |
14 | return timeit(lambda:
15 | solution(lambda x: h(np.array([0, x]))[0],
16 | 1.0, 1.0, nx=2**n, nt=2**n+1))()
17 |
18 | return list(map(check, tqdm(range(1, N))))
19 |
20 | def error(rd: ReactionDiffusion, N: int = 12) -> X:
21 |
22 | h, s, u = rd.solution
23 |
24 | def solve(f: Basis, n: int) -> X:
25 |
26 | h = lambda x: f(np.array([0, x])).squeeze()
27 | return solution(h, rd.nu, rd.rho, n, n+1)[1]
28 |
29 | U = jax.lax.map(F.partial(solve, n=(K:=2**N)), h)
30 |
31 | def call(h: Fx, u: X, n: int) -> X:
32 |
33 | un = solve(h, k:=2**n)
34 | uN = u[::K//k, ::K//k]
35 |
36 | return np.linalg.norm(np.ravel(uN - un)) \
37 | / np.linalg.norm(np.ravel(uN))
38 |
39 | return [jax.vmap(F.partial(call, n=n))(h, U)
40 | for n in tqdm(range(1, N))]
41 |
42 | # ---------------------------------------------------------------------------- #
43 | # SOLVER #
44 | # ---------------------------------------------------------------------------- #
45 |
46 | np.save("log/rd/solver/errr.solver.npy", np.array(error(nu01)))
47 | np.save("log/rd/solver/time.solver.npy", np.array([time() for _ in tqdm(range(16))]))
48 |
49 | # ---------------------------------------------------------------------------- #
50 | # MODEL #
51 | # ---------------------------------------------------------------------------- #
52 |
53 | pde, model = "don't run me this way"
54 | prng = "run main with `--test` flag"
55 |
56 | def solve(f: Basis, n: int) -> X:
57 | h = lambda x: f(np.array([0, x])).squeeze()
58 | return solution(h, pde.nu, pde.rho, n, n+1)[1]
59 |
60 | U = jax.lax.map(F.partial(solve, n=4096), pde.solution[0])
61 | # np.save("test.U.npy", U);;;;;; U=np.load("test.U.npy")
62 |
63 | def acc(n: int, K=4096):
64 | uhat = jax.lax.map(F.partial(model.apply, model.variables, x=((k:=2**n)+1, k+1), method="u"), pde.solution[0])
65 | return np.linalg.norm((u:=U[:, ::K//k, ::K//k]) - uhat[:, :, :-1, 0]) / np.linalg.norm(u)
66 |
67 | grid = model.cfg["grid"]
68 | accuracy = np.stack([acc(n) for n in range(5, 10)])
69 | np.save(f"log/rd/solver/errr.fnox{grid}.npy", accuracy)
70 |
71 | def tim(n: int):
72 | v = model.init(prng, p:=pde.params.sample(prng), (s:=2**n+1, s), method="forward")
73 | timer = timeit(lambda: model.apply(v, p, (s:=2**n+1, s), method="forward"))
74 | return np.array([timer() for _ in range(16)])
75 |
76 | time = np.stack([tim(n) for n in range(5, 10)])
77 | np.save(f"log/rd/solver/time.fno.npy", time)
78 |
79 | # ------------------------------------ NSM ----------------------------------- #
80 |
81 | accuracy = np.stack([acc(n) for n in range(5, 10)])
82 | np.save(f"log/rd/solver/errr.nsm.npy", accuracy)
83 |
84 | v = model.init(prng, p:=pde.params.sample(prng), method="forward")
85 | time = timeit(lambda: model.apply(v, p, method="forward"))
86 |
87 | np.save("log/rd/solver/time.nsm.npy", np.array([time() for _ in range(16)]))
88 |
--------------------------------------------------------------------------------
/plot/visualize.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
3 | from sys import argv; N = int(argv[1])
4 |
5 | # ---------------------------------------------------------------------------- #
6 | # POISSON #
7 | # ---------------------------------------------------------------------------- #
8 |
9 | u = np.load(f"src/pde/poisson/u.periodic.npy")[N]
10 | s = np.load("test.s.npy")
11 |
12 | fig, [ax_s, ax_u, err] = plt.subplots(ncols=3, figsize=(8.5, 2.2),
13 | width_ratios=[1, 1, 1.3],
14 | constrained_layout=True)
15 |
16 | plt.colorbar(ax_s.imshow(s, origin='lower', extent=[0, 255/256, 0, 255/256]), fraction=0.046, pad=0.04 ); ax_s.grid(False); ax_s.set_xticks([0, 0.2, 0.4, 0.6, 0.8]); ax_s.set_yticks([0, 0.2, 0.4, 0.6, 0.8]); ax_s.set_title("Source term $s$", fontsize=10)
17 | plt.colorbar(ax_u.imshow(u, origin='lower', extent=[0, 255/256, 0, 255/256]), fraction=0.046, pad=0.04, ticks=[0, 0.01, 0.02]); ax_u.grid(False); ax_u.set_xticks([0, 0.2, 0.4, 0.6, 0.8]); ax_u.set_yticks([0, 0.2, 0.4, 0.6, 0.8]); ax_u.set_title("Solution $u$", fontsize=10)
18 |
19 | ax_s.xaxis.set_tick_params(labelsize=8)
20 | ax_s.yaxis.set_tick_params(labelsize=8)
21 | ax_u.xaxis.set_tick_params(labelsize=8)
22 | ax_u.yaxis.set_tick_params(labelsize=8)
23 |
24 | def work(METHOD, LABELS, run: str):
25 |
26 | load_now = F.partial(load, dir=f"log/ps/{run}")
27 | rate = load_now("periodic", "metric.rate", METHOD=METHOD)
28 | errr = load_now("periodic", "metric.errr", METHOD=METHOD)
29 |
30 | for method, label in zip(METHOD, LABELS):
31 |
32 | xs = np.linspace(1, max(map(np.sum, rate[method])), 1000)
33 | ys = [np.interp(xs, np.concatenate([np.zeros(1), np.cumsum(rate)]), errr)
34 | for rate, errr in zip(rate[method], errr[method])]
35 |
36 | ys_mean = np.mean(np.stack(ys, axis=0), axis=0)
37 | ys_std = np.std(np.stack(ys, axis=0), axis=0); ys_std = np.minimum(ys_std, ys_mean/2)
38 |
39 | err.plot(xs, ys_mean, label=label, c=(c:=color(method)), ls=lines(method))
40 | err.fill_between(xs, ys_mean-ys_std, ys_mean+ys_std, alpha=0.2, color=c, edgecolor=None)
41 |
42 | err.set_xscale("log")
43 | err.set_yscale("log")
44 |
45 | err.set_xlim(1, 1e5)
46 | err.set_xticks([1, 1e1, 1e2, 1e3, 1e4, 1e5])
47 |
48 | err.yaxis.set_major_formatter(tkr.FuncFormatter(lambda y, _: f"{y:.1%}"[:3]))
49 |
50 | work([":NSM", ":SNO"], ["NSM", "SNO"], "relu")
51 | work(["x64", "x128"], [f"FNO${TIMES}64^2\\ \\ $+PINN", f"FNO${TIMES}128^2$+PINN"], "tanh")
52 | work(["x256"], [f"FNO${TIMES}256^2$+PINN"], "long")
53 |
54 | err.set_title(f"$L_2$ rel. error (\\%) on $256{TIMES}256$ test res.", fontsize=10)
55 | err.legend(loc="lower right")#, handlelength=1.3, bbox_to_anchor=(1.5, 0.5), fontsize=8, labelspacing=0.3)
56 | err.xaxis.set_tick_params(labelsize=8)
57 | err.yaxis.set_tick_params(labelsize=8)
58 |
59 | ax_s.text(-0.11, 1.05, "(a)", transform=ax_s.transAxes, fontsize=10)
60 | ax_u.text(-0.11, 1.05, "(b)", transform=ax_u.transAxes, fontsize=10)
61 | err.text(-0.08, 1.05, "(c)", transform=err.transAxes, fontsize=10)
62 |
63 | fig.savefig(f"plot/poisson.{N}.jpg", dpi=300)
64 |
65 | # ---------------------------------------------------------------------------- #
66 | # REACTION #
67 | # ---------------------------------------------------------------------------- #
68 |
69 | nus = [0.005, 0.01, 0.05, 0.1]
70 | us = [np.load(f"src/pde/reaction/u.rho=5:nu={nu:.3f}.npy")[N] for nu in nus]
71 |
72 | fig, axes = plt.subplots(ncols=4, figsize=(7, 2), constrained_layout=True)
73 | for u, ax, nu in zip(us:=np.stack(us), axes, nus):
74 |
75 | ax.grid(False)
76 | im = ax.imshow(u, cmap="Spectral", origin="lower",
77 | vmin=us.min(), vmax=us.max(),
78 | extent=[0, 511/512, 0, 1])
79 |
80 | ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8])
81 | ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1])
82 |
83 | ax.set_title(f"$\\nu={nu}$")
84 |
85 | plt.colorbar(im,fraction=0.046, pad=0.1)
86 |
87 | fig.savefig(f"plot/reaction.{N}.jpg", dpi=300)
88 |
--------------------------------------------------------------------------------
/run/.sh:
--------------------------------------------------------------------------------
1 | XLA_PYTHON_CLIENT_PREALLOCATE=True \
2 | XLA_PYTHON_CLIENT_MEM_FRACTION=.95 \
3 | python main.py \
4 | --seed ${seed:-$RANDOM} \
5 | --hdim 64 \
6 | --depth 4 \
7 | --activate relu \
8 | "$@" train \
9 | --bs ${bs:-16} \
10 | --lr ${lr:-1e-3} \
11 | --schd ${schd:-exp} \
12 | --iter ${iter:-30000} \
13 | --vmap ${vmap:-""} \
14 | --ckpt ${ckpt:-100} \
15 | --note ${note:-"$(date)"}
--------------------------------------------------------------------------------
/run/navierstokes.forced.sh:
--------------------------------------------------------------------------------
1 | note=ns.T50/NSM."$seed" iter=200000 ckpt=1000 bash run/.sh --pde navierstokes.tf --model fno --hdim 32 --depth 5 --mode 18 31 31 --spectral
2 | note=ns.T50/x96."$seed" iter=200000 ckpt=1000 vmap=1 bash run/.sh --pde navierstokes.tf --model fno --hdim 32 --depth 5 --mode 18 31 31 --grid 96
--------------------------------------------------------------------------------
/run/navierstokes.sh:
--------------------------------------------------------------------------------
1 | for re in 4 3 2; do
2 |
3 | note=ns.T3/re"$re":NSM."$seed" iter=200000 ckpt=1000 bash run/.sh --pde navierstokes.re"$re" --model fno --hdim 32 --depth 10 --mode 12 31 31 --spectral
4 | note=ns.T3/re"$re"x64."$seed" iter=200000 ckpt=1000 vmap=1 bash run/.sh --pde navierstokes.re"$re" --model fno --hdim 32 --depth 10 --mode 12 31 31 --grid 64
5 | note=ns.T3/re"$re"x96."$seed" iter=200000 ckpt=1000 vmap=1 bash run/.sh --pde navierstokes.re"$re" --model fno --hdim 32 --depth 10 --mode 12 31 31 --grid 96
6 |
7 | done
--------------------------------------------------------------------------------
/run/navierstokes.yaml:
--------------------------------------------------------------------------------
1 | name: ns
2 | workdir: .
3 | resources:
4 |
5 | cloud: gcp
6 | disk_size: 256
7 | accelerators: A100:1
8 |
9 | image_id: "projects/deeplearning-platform-release/global\
10 | /images/common-cu113-v20230615-ubuntu-2004-py310"
11 |
12 | setup: |
13 |
14 | echo ====================
15 | echo executing init......
16 | echo ====================
17 |
18 | JAXLIB=jaxlib-0.4.7+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64
19 | pip install https://storage.googleapis.com/jax-releases/cuda11/$JAXLIB.whl
20 |
21 | pip install -r run/requirements.txt
22 |
23 | echo =====================
24 | echo generating data......
25 | echo =====================
26 |
27 | python -m src.pde.navierstokes.generate ns
28 | python -m src.pde.navierstokes.generate tf
29 |
--------------------------------------------------------------------------------
/run/poisson.dirichlet.sh:
--------------------------------------------------------------------------------
1 | for seed in 0 1 2 3; do
2 |
3 | seed=$seed iter=100000 note=ps/relu/dirichlet:SNO."$seed" bash run/.sh --pde poisson.dirichlet --model sno --mode 31 31 --spectral
4 | seed=$seed iter=100000 note=ps/relu/dirichlet:NSM."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --spectral
5 | seed=$seed iter=100000 note=ps/relu/dirichletx64."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 64
6 | seed=$seed iter=100000 note=ps/relu/dirichletx128."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 128
7 | seed=$seed iter=100000 note=ps/relu/dirichletx256."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 256
8 | seed=$seed iter=100000 note=ps/relu/dirichletx256C."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 256 --cheb --vmap 16
9 |
10 | seed=$seed iter=100000 note=ps/tanh/dirichlet:NSM."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --spectral --activate tanh
11 | seed=$seed iter=100000 note=ps/tanh/dirichletx64."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 64 --activate tanh
12 | seed=$seed iter=100000 note=ps/tanh/dirichletx128."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 128 --activate tanh
13 | seed=$seed iter=100000 note=ps/tanh/dirichletx256."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 256 --activate tanh
14 | seed=$seed iter=100000 note=ps/tanh/dirichletx256C."$seed" bash run/.sh --pde poisson.dirichlet --model fno --mode 31 31 --grid 256 --activate tanh --cheb --vmap 16
15 |
16 | done
--------------------------------------------------------------------------------
/run/poisson.periodic.sh:
--------------------------------------------------------------------------------
1 | for seed in 0 1 2 3; do
2 |
3 | seed=$seed note=ps/relu/periodic:SNO."$seed" bash run/.sh --pde poisson.periodic --model sno --mode 31 31 --spectral
4 | seed=$seed note=ps/relu/periodic:NSM."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --spectral
5 | seed=$seed note=ps/relu/periodicx64."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 64
6 | seed=$seed note=ps/relu/periodicx128."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 128
7 | seed=$seed note=ps/relu/periodicx256."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 256
8 |
9 | seed=$seed note=ps/tanh/periodicx64."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 64 --activate tanh
10 | seed=$seed note=ps/tanh/periodicx128."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 128 --activate tanh
11 | seed=$seed note=ps/tanh/periodicx256."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 256 --activate tanh
12 |
13 | iter=100000 \
14 | seed=$seed note=ps/long/periodicx256."$seed" bash run/.sh --pde poisson.periodic --model fno --mode 31 31 --grid 256 --activate tanh
15 |
16 | done
--------------------------------------------------------------------------------
/run/reaction.sh:
--------------------------------------------------------------------------------
1 | function run() {
2 | note=rd/"$name":SNO."$seed" bash run/.sh --pde reaction."$name" --model sno --mode 32 64 --spectral
3 | note=rd/"$name":NSM."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --spectral
4 | note=rd/"$name"x64."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --grid 64
5 | note=rd/"$name"x128."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --grid 128
6 | note=rd/"$name"x256."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --grid 256
7 | note=rd/"$name"x256C."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --grid 256 --cheb
8 | note=rd/"$name"x256F."$seed" bash run/.sh --pde reaction."$name" --model fno --mode 32 64 --spectral --fourier
9 | }
10 |
11 | for seed in 0 1 2 3; do
12 |
13 | seed=$seed name=nu005 run
14 | seed=$seed name=nu01 run
15 | seed=$seed name=nu05 iter=100000 run
16 | seed=$seed name=nu1 iter=200000 run
17 |
18 | done
--------------------------------------------------------------------------------
/run/requirements.txt:
--------------------------------------------------------------------------------
1 | ################################################################
2 | # PLEASE INSTALL JAXLIB BASED ON YOUR MACHINE CONFIGURATION    #
3 | ################################################################
4 |
5 | jax==0.4.7
6 | flax==0.7.0
7 | optax==0.1.5
8 |
9 | tqdm==4.65.0
10 | matplotlib==3.7.0
11 | scienceplots==2.1.0
12 |
--------------------------------------------------------------------------------
/src/README:
--------------------------------------------------------------------------------
1 | Please refer to the README in each directory for details.
2 |
3 | FILE STRUCTURE
4 |
5 | |- main.py : entry point.
6 | |- ckpt.py : checkpoint subproc.
7 | |- plot/ : matplotlib utilities.
8 | |- run/ : command line scripts.
9 | |- src/
10 | |- dists.py : common distributions.
11 | |- train.py : training routines: step & eval.
12 | |- utils.py : grid generation, differentiation schemes, etc.
13 | |- basis/
14 | | |- fourier.py : 1-d trigonometric series.
15 | | |- chebyshev.py : 1-d Chebyshev polynomials of the first kind (T_n).
16 | |- pde/
17 | | |- _domain.py : physical domain definition. N-d unit rect.
18 | | |- _params.py : parameterize PDE with interpolating series.
19 | | |- mollifier.py : how boundary condition applies to solution.
20 | |- model/
21 | |- _base.py : shared modules.
22 | |- sno/spectral.py: SNO.
23 | |- fno/ : FNO.
24 | |- __init__.py : PINN.
25 | |- spectral.py : NSM (ours).
26 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 |
5 | import operator as O
6 | import itertools as I
7 | import functools as F
8 |
9 | # ---------------------------------------------------------------------------- #
10 | # JAX #
11 | # ---------------------------------------------------------------------------- #
12 |
13 | import jax
14 |
15 | import flax
16 | import optax
17 |
18 | import jax.numpy as np
19 | import flax.linen as nn
20 |
21 | # ---------------------------------------------------------------------------- #
22 | # TYPE #
23 | # ---------------------------------------------------------------------------- #
24 |
25 | from abc import *
26 | from typing import *
27 |
28 | from jax import Array
29 | from flax import struct
30 |
31 | X = Union[Tuple["X", ...], List["X"], Array]
32 | ϴ = Union[struct.PyTreeNode, "X", None]
33 |
34 | Fx = Callable[..., "X"] # real-valued function
35 | Fϴ = Callable[..., "Fx"] # parametrized function
36 |
37 | # ---------------------------------------------------------------------------- #
38 | # CONST #
39 | # ---------------------------------------------------------------------------- #
40 |
41 | e = np.e
42 | π = np.pi
43 |
44 | Δ = F.partial(np.einsum, "...ii -> ...")
45 |
46 | # ---------------------------------------------------------------------------- #
47 | # RANDOM #
48 | # ---------------------------------------------------------------------------- #
49 |
50 | from jax import random
51 | RNG = Dict[str, random.KeyArray]
52 |
53 | class RNGS(RNG):
54 |
55 | def __init__(self, prng: random.KeyArray, name: List[str]):
56 | keys = random.split(prng, len(name))
57 | super().__init__(zip(name, keys))
58 |
59 | def __next__(self) -> RNG:
60 | self.it = getattr(self, "it", 0) + 1
61 | return self.fold_in(self.it)
62 |
63 | def fold_in(self, key: Any) -> RNG:
64 | return { name: random.fold_in(data, hash(key))
65 | for name, data in self.items() }
66 |
--------------------------------------------------------------------------------
/src/basis/README:
--------------------------------------------------------------------------------
1 | This is a self-contained module of different orthogonal basis functions. `Basis`
2 | is the abstract class defined for a general N-d basis, in terms of its spectral
3 | coefficients. Each class is associated with a static `ndim`, the dimensionality
4 | of the basis. The coefficient array, `coef`, has at least `ndim` dimensions:
5 |
6 | - The first `ndim` dimensions correspond to the basis dimensions;
7 |
8 | - The remaining dimensions are interpreted as arbitrary channels, i.e. they
9 | are broadcast by each operation and have no special meaning to the basis.
10 |
11 | The `Basis` class requires the following implementations:
12 |
13 | - `ix`: defines the ordered indices of `coef` for a given number of modes. It
14 | is used to truncate or extend basis instances to other modes.
15 |
16 | - `fn`: defines the array of basis functions for a given number of modes;
17 |
18 | - `grid`: defines the collocation points for a given number of modes;
19 |
20 | - `transform`/`inv`: transforms between coefficients and function values on
21 | the collocation points defined by the `grid` function;
22 |
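The following is a minimal sketch of a `grid`/`transform`/`inv` triple for a Chebyshev series. It assumes Chebyshev-Gauss-Lobatto points on [-1, 1]; the actual `chebyshev.py` works on [0, 1], and its exact grid is not shown here. The names `cheb_grid`, `cheb_transform` and `cheb_inv` are hypothetical. Both transforms mirror the data into an even sequence and apply a plain FFT (a DCT-I), so they round-trip exactly on the grid:

```python
import numpy as np

def cheb_grid(n):
    # Chebyshev-Gauss-Lobatto points x_j = cos(pi * j / (n - 1)), j = 0..n-1, on [-1, 1]
    return np.cos(np.pi * np.arange(n) / (n - 1))

def cheb_inv(coef):
    # coefficients -> values on cheb_grid(n): f_j = sum_k c_k cos(pi * j * k / (n - 1))
    coef = np.asarray(coef, dtype=float)
    n = len(coef)
    h = np.concatenate([coef, coef[-2:0:-1]])  # even extension, length 2(n - 1)
    h[0] *= 2.0      # the two end modes appear only once in the extension,
    h[n - 1] *= 2.0  # while interior modes appear twice, so double the ends
    return 0.5 * np.fft.fft(h).real[:n]

def cheb_transform(values):
    # values on cheb_grid(n) -> coefficients (inverse of cheb_inv)
    values = np.asarray(values, dtype=float)
    n = len(values)
    g = np.concatenate([values, values[-2:0:-1]])  # even extension of grid values
    coef = np.fft.fft(g).real[:n] / (n - 1)
    coef[0] /= 2.0   # end modes carry twice the weight in the discrete
    coef[n - 1] /= 2.0  # orthogonality relation, so halve them
    return coef
```

This sketch also breaks for `n = 1`, since `n - 1 = 0`. That is a variant of the singleton pitfall listed under SHARP BITS.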
23 | The `Basis` class provides the following functionalities:
24 |
25 | - `__call__`: evaluates function values at arbitrary positions. Values at
26 | the `grid` points are identical to the result of the `inv`erse transform,
27 | but calling `inv` is usually faster because it uses FFTs.
28 |
29 | - `to`: aligns a basis instance to another number of modes; coefficients
30 | beyond the target modes are truncated, and missing ones are zero-padded.
31 |
32 | - `add`/`mul`: sums / multiplies basis functions of the same class. Operands
33 | are first aligned to the maximum (`add`) or the sum (`mul`) of their modes.
34 |
35 | - `grad`/`int`: obtains the derivatives / indefinite integrals along each
36 | dim. The result has an extra trailing dimension, indexing the
37 | operation along each of the `ndim` dimensions.
38 |
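Below is a 1-d sketch of the alignment rules behind `to`, `add` and `mul`. It uses NumPy's Chebyshev helpers on [-1, 1] instead of the repo's classes; `to`, `add` and `mul` here are hypothetical stand-ins, and `chebfit` plays the role of `transform`. `add` only needs the larger mode count. `mul` evaluates both factors on a grid with the *sum* of their modes, which is enough to hold the product exactly: series with n and m modes multiply to n + m - 1 modes.

```python
import numpy as np
from numpy.polynomial import chebyshev as C

def to(coef, n):
    # 1-d analogue of `Basis.to`: keep the first min(n, len(coef)) coefficients, zero-pad the rest
    out = np.zeros(n)
    k = min(n, len(coef))
    out[:k] = coef[:k]
    return out

def add(a, b):
    # align both operands to the larger mode count (scheme=max), then add coefficients
    n = max(len(a), len(b))
    return to(a, n) + to(b, n)

def mul(a, b):
    # align to the summed mode count (scheme=sum), multiply values on the
    # collocation points, then transform the product back to coefficients
    n = len(a) + len(b)
    x = C.chebpts2(n)  # n Chebyshev-Lobatto points on [-1, 1]
    values = C.chebval(x, to(a, n)) * C.chebval(x, to(b, n))
    return C.chebfit(x, values, n - 1)  # n points, degree n - 1: exact interpolation
```

Zero-padding is lossless while truncation drops high modes, which is why both operations align upward rather than downward.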
39 | IMPLEMENTATIONS
40 |
41 | `Series` inherits from `Basis` and specializes it to 1-d bases, which are
42 | realized by the `fourier` and `chebyshev` modules. Given any sequence of
43 | `Series` classes, the class factory function `series` recursively builds
44 | an N-d `Basis` class on top of the given 1-d series.
45 |
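The sketch below shows how an N-d basis built by `series` could be evaluated. It assumes `series` forms the tensor product of its 1-d factors; the factory body is not shown here. `SeriesMeta.__getitem__` returns `series(*(cls,)*n)`, so indexing a series class with `n` yields the n-fold product of that series. The `call` function mirrors `Basis.__call__`: it contracts the trailing mode axes of the basis-function array with the leading `ndim` axes of `coef` via `np.tensordot`, and trailing channel axes pass through untouched. `fn_1d`, `fn_2d` and `call` are hypothetical names, and Chebyshev polynomials on [-1, 1] stand in for the repo's bases:

```python
import numpy as np
from numpy.polynomial import chebyshev as C

def fn_1d(n, x):
    # values of T_0..T_{n-1} at points x, shape x.shape + (n,)
    return C.chebvander(x, n - 1)

def fn_2d(n, m, x):
    # tensor-product basis T_i(x[..., 0]) * T_j(x[..., 1]), shape x.shape[:-1] + (n, m)
    return fn_1d(n, x[..., 0])[..., :, None] * fn_1d(m, x[..., 1])[..., None, :]

def call(coef, x, ndim=2):
    # contract the trailing `ndim` axes of the basis array with the
    # leading `ndim` axes of `coef`; any channel axes of `coef` remain
    return np.tensordot(fn_2d(*coef.shape[:ndim], x=x), coef, ndim)
```

For a single channel this reproduces `numpy.polynomial.chebyshev.chebval2d`. With extra channel axes, each channel is evaluated independently.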
46 | **SHARP BITS**
47 |
48 | 1. Taking derivatives of a Fourier basis may be lossy if the number of modes
49 | is even. I store Fourier coefficients compactly by squashing the Hermitian
50 | spectrum into reals of the same shape, so the last (Nyquist) coefficient has
51 | no partner and is dropped when taking the gradient.
52 |
53 | 2. The (inverse) transform of the Chebyshev basis does not work with singletons.
54 | I'm using a Hermitian FFT, which returns a zero-sized array for a one-sized
55 | input. Try broadcasting the input to size two or more before transforming.
56 |
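The first pitfall can be reproduced with a plain real FFT. This sketch uses NumPy's complex `rfft` layout rather than the squashed real storage described above, and `fourier_grad` is a hypothetical name. For even `n`, differentiating the Nyquist mode cos(pi n x) gives a multiple of sin(pi n x). That function is zero at every grid point x_j = j / n, so there is nowhere to store it and it is dropped:

```python
import numpy as np

def fourier_grad(u):
    # spectral derivative on [0, 1) of samples u_j = u(j / n), using a real FFT
    n = len(u)
    k = np.fft.rfftfreq(n, d=1.0 / n)  # integer wavenumbers 0, 1, ..., n // 2
    uhat = 2j * np.pi * k * np.fft.rfft(u)
    if n % 2 == 0:
        # Nyquist mode: its derivative is a multiple of sin(pi * n * x), which
        # vanishes on every grid point, so no coefficient can represent it
        uhat[-1] = 0.0
    return np.fft.irfft(uhat, n)
```

For odd `n` there is no Nyquist mode, and every resolved mode differentiates exactly.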
--------------------------------------------------------------------------------
/src/basis/__init__.py:
--------------------------------------------------------------------------------
1 | from .. import *
2 | from .. import utils
3 |
4 | @struct.dataclass
5 | class Basis(ABC):
6 |
7 | coef: Array
8 |
9 | """
10 | Basis function on [0, 1]^ndim. The leading `ndim` axes are for:
11 |
12 | - (spectral) coefficients of basis functions
13 | - (physical) values evaluated on collocation points
14 | """
15 |
16 | @staticmethod
17 |
18 | @abstractmethod
19 | def repr() -> str: pass
20 |
21 | @staticmethod
22 |
23 | @abstractmethod
24 | def ndim() -> int: pass
25 |
26 | @property
27 | def mode(self): return self.coef.shape[:self.ndim()]
28 |
29 | def map(self, f: Fx): return self.__class__(f(self.coef))
30 |
31 | @staticmethod
32 |
33 | @abstractmethod
34 | def grid(*mode: int) -> X: pass
35 |
36 | @staticmethod
37 |
38 | @abstractmethod
39 | def ix(*mode: int) -> X: pass
40 |
41 | @staticmethod
42 |
43 | @abstractmethod
44 | def fn(*mode: int, x: X) -> X: pass
45 | def __call__(self, x: X) -> X:
46 |
47 | assert x.shape[-1] == self.ndim(), f"{x.shape[-1]=} =/= {self.ndim()=}"
48 | return np.tensordot(self.fn(*self.mode, x=x), self.coef, self.ndim())
49 |
50 | def to(self, *mode: int):
51 |
52 | if self.mode == mode: return self
53 | ax = self.ix(*map(min, mode, self.mode))
54 |
55 | zero = np.zeros(mode + self.coef.shape[self.ndim():])
56 | return self.__class__(zero.at[ax].set(self.coef[ax]))
57 |
58 | @classmethod
59 | def add(cls, *terms): return cls(sum(map(O.attrgetter("coef"), align(*terms, scheme=max))))
60 |
61 | @classmethod
62 | def mul(cls, *terms): return cls.transform(math.prod(map(cls.inv, align(*terms, scheme=sum))))
63 |
64 | # --------------------------------- TRANSFORM -------------------------------- #
65 |
66 | @staticmethod
67 |
68 | @abstractmethod
69 | def transform(x: X): pass
70 |
71 | @abstractmethod
72 | def inv(self) -> X: pass
73 |
74 | # --------------------------------- OPERATOR --------------------------------- #
75 |
76 | @abstractmethod
77 | def grad(self, k: int = 1): pass
78 |
79 | @abstractmethod
80 | def int(self, k: int = 1): pass
81 |
82 | def align(*basis: Basis, scheme: Fx = max) -> Tuple[Basis]:
83 |
84 | # assert uniform properties (unpacking fails unless each set has one element):
85 |
86 | _, = set(map(lambda cls: cls.repr(), basis))
87 | _, = set(map(lambda cls: cls.ndim(), basis))
88 | _, = set(map(lambda self: self.coef.ndim, basis))
89 |
90 | mode = tuple(map(scheme, zip(*map(O.attrgetter("mode"), basis))))
91 | return tuple(map(lambda self: self.to(*mode), basis))
92 |
93 | # ---------------------------------------------------------------------------- #
94 | # SERIES #
95 | # ---------------------------------------------------------------------------- #
96 |
97 | class SeriesMeta(ABCMeta, type):
98 |
99 | def __getitem__(cls, n: int):
100 | return series(*(cls,)*n)
101 |
102 | @struct.dataclass
103 | class Series(Basis, metaclass=SeriesMeta):
104 |
105 | """1-dimensional series on interval"""
106 |
107 | @staticmethod
108 | def ndim() -> int: return 1 # on [0, 1]
109 | def __len__(self): return len(self.coef)
110 |
111 | @abstractmethod
112 | def __getitem__(self, s: int) -> X: pass
113 |
114 | def series(*types: Type[Series]) -> Type[Basis]:
115 |
116 | """
117 | Generate new basis using finite product of given series. Each argument
118 | type corresponds to certain kind of series used for each dimension.
119 | """
120 |
121 | @struct.dataclass
122 | class Class(Basis):
123 |
124 | @staticmethod
125 | def repr() -> str: return "".join(map(O.methodcaller("repr"), types))
126 |
127 | @staticmethod
128 | def ndim() -> int: return len(types)
129 |
130 | @staticmethod
131 | def grid(*mode: int) -> X:
132 |
133 | assert len(mode) == len(types)
134 |
135 | axes = mesh(lambda i, cls: cls.grid(mode[i]).squeeze(1))
136 | return np.stack(axes, axis=-1)
137 |
138 | def ix(self, *mode: int) -> X:
139 |
140 | return np.ix_(*map(lambda self, n: self.ix(n), types, mode))
141 |
142 | def fn(self, *mode: int, x: X) -> X:
143 |
144 | axes = mesh(lambda i, self: self.fn(mode[i], x=x[..., [i]]))
145 | return np.prod(np.stack(axes, axis=-1), axis=-1)
146 |
147 | def __getitem__(self, s: Tuple[int]) -> X:
148 |
149 | return jax.vmap(F.partial(Super.__getitem__, s=s[1:]))(Super(Self(self.coef)[s[0]]))
150 |
151 | # --------------------------------- TRANSFORM -------------------------------- #
152 |
153 | @staticmethod
154 | def transform(x: X):
155 |
156 | return Class(Self.transform(jax.vmap(Super.transform)(x).coef).coef)
157 |
158 | def inv(self) -> X:
159 |
160 | return jax.vmap(Super.inv)(Super(Self(self.coef).inv()))
161 |
162 | # --------------------------------- OPERATOR --------------------------------- #
163 |
164 | def grad(self, k: int = 1):
165 |
166 | coef = jax.vmap(F.partial(Super.grad, k=k))(Super(self.coef)).coef
167 | return Class(np.concatenate([Self(self.coef).grad(k).coef, coef], axis=-1))
168 |
169 | def int(self, k: int = 1):
170 |
171 | coef = jax.vmap(F.partial(Super.int, k=k))(Super(self.coef)).coef
172 | return Class(np.concatenate([Self(self.coef).int(k).coef, coef], axis=-1))
173 |
174 | def mesh(call: Fx) -> Tuple[X]:
175 | def cat(*x: X) -> Tuple[X]:
176 |
177 | n, = set(map(np.ndim, x))
178 |
179 | if n != 1: return jax.vmap(cat)(*x)
180 | return np.meshgrid(*x, indexing="ij")
181 |
182 | args = zip(*enumerate(types))
183 | return cat(*map(call, *args))
184 |
185 | try: cls, = types; return cls
186 | except ValueError: # more than one series: split off the first and recurse
187 |
188 | Self, *other = types
189 | Super = series(*other)
190 |
191 | return Class
192 |
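`Class.fn` builds the product of the 1-d basis functions on a mesh, and `Basis.__call__` contracts it with the coefficient array using `np.tensordot` over the leading `ndim` axes. A minimal NumPy sketch of the same evaluation for a two-dimensional Chebyshev-Chebyshev product, assuming the shifted basis T*_k(x) = T_k(2x - 1) from `chebyshev.py`; `cheb_fn` and `eval_product_series` are hypothetical helpers, checked against NumPy's `chebval2d`:

```python
import numpy as np
from numpy.polynomial import chebyshev as C

def cheb_fn(n: int, x: np.ndarray) -> np.ndarray:
    """Shifted Chebyshev basis T*_k(x) = T_k(2x - 1), k < n; shape x.shape + (n,)."""
    return np.cos(np.arange(n) * np.arccos(2 * x[..., None] - 1))

def eval_product_series(coef: np.ndarray, pts: np.ndarray) -> np.ndarray:
    """Evaluate sum_ij coef[i, j] T*_i(x) T*_j(y) at points pts[..., :] = (x, y)."""
    m1, m2 = coef.shape
    # Outer product of the 1-d bases, like the mesh + product in `Class.fn`
    phi = cheb_fn(m1, pts[..., 0])[..., :, None] * cheb_fn(m2, pts[..., 1])[..., None, :]
    # Contract the two mode axes with the coefficients, like `Basis.__call__`
    return np.tensordot(phi, coef, axes=2)

rng = np.random.default_rng(0)
coef = rng.standard_normal((4, 5))
pts = rng.uniform(size=(10, 2))
ref = C.chebval2d(2 * pts[:, 0] - 1, 2 * pts[:, 1] - 1, coef)
print(np.allclose(eval_product_series(coef, pts), ref))  # True
```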
--------------------------------------------------------------------------------
/src/basis/chebyshev.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
3 | @struct.dataclass
4 | class Chebyshev(Series):
5 |
6 | """
7 | Shifted Chebyshev polynomials of the first kind (T) on [0, 1]
8 | - Tn(x) = cos(n cos^-1(x))
9 | - Tn^*(x) = Tn(2 x - 1)
10 | """
11 |
12 | @staticmethod
13 | def repr() -> str: return "C"
14 |
15 | @staticmethod
16 | def grid(n: int) -> X: return np.cos(π * utils.grid(n))/2+0.5
17 |
18 | @staticmethod
19 | def ix(n: int) -> X: return np.arange(n)
20 |
21 | @staticmethod
22 | def fn(n: int, x: X) -> X: return np.cos(np.arange(n) * np.arccos(x*2-1))
23 |
24 | def __getitem__(self, s: int) -> X:
25 | if isinstance(s, Tuple): s, = s
26 | return self(utils.grid(s))
27 |
28 | # --------------------------------- TRANSFORM -------------------------------- #
29 |
30 | @staticmethod
31 | def transform(x: X):
32 |
33 | assert len(x) > 1, "sharp bits: need at least two collocation points"
34 | coef = np.fft.hfft(x, axis=0, norm="forward")[:len(x)]
35 | coef = coef.at[1:-1].multiply(2)
36 |
37 | return Chebyshev(coef)
38 |
39 | def inv(self) -> X:
40 |
41 | coef = self.coef.at[1:-1].divide(2)
42 | coef = np.concatenate([coef, coef[::-1][1:-1]])
43 |
44 | return np.fft.ihfft(coef, axis=0, norm="forward").real
45 |
46 | # --------------------------------- OPERATOR --------------------------------- #
47 |
48 | def grad(self, k: int = 1):
49 |
50 | coef = np.linalg.matrix_power(np.pad(gradient(len(self)), [(0, 1), (0, 0)]), k)
51 | return Chebyshev(np.tensordot(coef, self.coef, (1, 0))[..., np.newaxis])
52 |
53 | def int(self, k: int = 1):
54 |
55 | coef = np.linalg.matrix_power(integrate(len(self))[:-1], k)
56 | return Chebyshev(np.tensordot(coef, self.coef, (1, 0))[..., np.newaxis])
57 |
58 | # ---------------------------------------------------------------------------- #
59 | # MATRIX #
60 | # ---------------------------------------------------------------------------- #
61 |
62 | """
63 | Chebyshev gradient and integrate matrix
64 |
65 | - gradient(n) ∈ R^{(n-1)⨉n}; integrate(n) ∈ R^{(n+1)⨉n}
66 | - When the sizes are aligned, they are pseudo-inverses of each other:
67 | `gradient(n+1) @ integrate(n) == identity(n)`
68 | """
69 |
70 | def gradient(n: int) -> X:
71 |
72 | alternate = np.pad(np.eye(2), [(0, n-3), (0, n-3)], mode="reflect").at[0].divide(2)
73 | coef = np.concatenate([np.zeros(n - 1)[:, np.newaxis], np.triu(alternate)], axis=1)
74 |
75 | return coef * 4 * np.arange(n)
76 |
77 | def integrate(n: int) -> X:
78 |
79 | shift = np.identity(n).at[0, 0].set(2) - np.eye(n, k=2)
80 | coef = np.concatenate([np.zeros(n)[np.newaxis], shift])
81 |
82 | return coef.at[1:].divide(4 * np.arange(1, n+1)[:, np.newaxis])
83 |
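The pseudo-inverse relation in the docstring can be checked numerically. The sketch below re-implements the two matrices in NumPy with explicit loops (`cheb_gradient` and `cheb_integrate` are hypothetical names meant to mirror `gradient` and `integrate`, not imports of them). The derivative matrix is compared with `numpy.polynomial.chebyshev.chebder` using `scl=2`, since d/dx = 2 d/dy for x = (y + 1)/2.

```python
import numpy as np
from numpy.polynomial import chebyshev as C

def cheb_gradient(n: int) -> np.ndarray:
    """(n-1) x n matrix mapping coefficients of f to those of df/dx on [0, 1]."""
    D = np.zeros((n - 1, n))
    for j in range(1, n):
        for i in range(j - 1, -1, -2):          # every i < j with j - i odd
            D[i, j] = 4 * j * (0.5 if i == 0 else 1.0)
    return D

def cheb_integrate(n: int) -> np.ndarray:
    """(n+1) x n antiderivative matrix on [0, 1]; the constant (row 0) stays 0."""
    Q = np.zeros((n + 1, n))
    for j in range(n):
        Q[j + 1, j] += (2.0 if j == 0 else 1.0) / (4 * (j + 1))
        if j >= 2:
            Q[j - 1, j] -= 1.0 / (4 * (j - 1))
    return Q

n = 6
print(np.allclose(cheb_gradient(n), C.chebder(np.eye(n), scl=2, axis=0)))  # True
print(np.allclose(cheb_gradient(n + 1) @ cheb_integrate(n), np.eye(n)))    # True
```

Row 0 of the antiderivative is the free integration constant; the derivative matrix annihilates it, which is why the product is the identity even though `integrate` leaves that row at zero.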
--------------------------------------------------------------------------------
/src/basis/fourier.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
3 | @struct.dataclass
4 | class Fourier(Series):
5 |
6 | """
7 | Trigonometric series
8 |
9 | fk(x) = e^{2πik·x}, stored as real cos/sin coefficients
10 | """
11 |
12 | @staticmethod
13 | def repr() -> str: return "F"
14 |
15 | @staticmethod
16 | def grid(n: int) -> X: return utils.grid(n, mode="left")
17 |
18 | @staticmethod
19 | def ix(n: int) -> X: return np.r_[-n//2+1:n//2+1]
20 |
21 | @staticmethod
22 | def fn(n: int, x: X) -> X:
23 |
24 | return np.moveaxis(real(np.moveaxis(np.exp(x * -freq(n)), -1, 0), n), 0, -1)
25 |
26 | def __getitem__(self, s: int) -> X:
27 | if isinstance(s, Tuple): s, = s
28 |
29 | return np.concatenate([x:=self.to(s - 1).inv(), x[np.r_[0]]])
30 |
31 | # --------------------------------- TRANSFORM -------------------------------- #
32 |
33 | @staticmethod
34 | def transform(x: X):
35 |
36 | coef = np.fft.rfft(x, axis=0, norm="forward")
37 | coef = coef.at[1:-(len(x)//-2)].multiply(2)
38 |
39 | return Fourier(real(coef, len(x)))
40 |
41 | def inv(self) -> X:
42 |
43 | coef = comp(self.coef, n:=len(self))
44 | coef = coef.at[1:-(n//-2)].divide(2)
45 |
46 | return np.fft.irfft(coef, len(self), axis=0, norm="forward")
47 |
48 | # --------------------------------- OPERATOR --------------------------------- #
49 |
50 | def grad(self, k: int = 1):
51 |
52 | coef = np.expand_dims(freq(len(self))**k, range(1, self.coef.ndim))
53 | return Fourier(real(comp(self.coef, n:=len(self)) * coef, n)[..., None])
54 |
55 | def int(self, k: int = 1):
56 |
57 | coef = np.expand_dims(freq(n:=len(self)).at[0].set(1), range(1, self.coef.ndim))
58 | return Fourier(real(comp(self.coef, n) / coef ** k, n).at[0].set(0)[..., np.newaxis])
59 |
60 | # ---------------------------------------------------------------------------- #
61 | # HELPER #
62 | # ---------------------------------------------------------------------------- #
63 |
64 | def freq(n: int) -> X: return np.arange(n//2+1) * 2j * π
65 |
66 | def real(coef: X, n: int) -> X:
67 |
68 | """Complex coef -> Real coef"""
69 |
70 | cos, sin = coef.real, coef.imag[1:-(n//-2)]
71 | return np.concatenate((cos, sin[::-1]), 0)
72 |
73 | def comp(coef: X, n: int) -> X:
74 |
75 | """Real coef -> Complex coef"""
76 |
77 | cos, sin = np.split(coef, (m:=n//2+1, ))
78 | return (cos+0j).at[n-m:0:-1].add(sin*1j)
79 |
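`real` and `comp` squash the n//2 + 1 complex `rfft` coefficients into n reals (cosines first, then sines in reverse order) and back. A minimal NumPy sketch of the same layout, assuming plain `rfft` coefficients (it skips the factor of 2 that `Fourier.transform` applies to the interior modes); `pack` and `unpack` are hypothetical names. Storing the sines in reverse order is what lets `Fourier.ix` address the k-th sine with the negative index -k.

```python
import numpy as np

def pack(coef: np.ndarray, n: int) -> np.ndarray:
    """n//2 + 1 complex coefficients -> n reals: [Re c_0 .. Re c_{n//2}, Im c_s .. Im c_1]."""
    s = -(n // -2) - 1                     # number of sine slots: ceil(n / 2) - 1
    return np.concatenate([coef.real, coef.imag[1:s + 1][::-1]])

def unpack(packed: np.ndarray, n: int) -> np.ndarray:
    """Inverse of `pack`; Im c_0 (and Im c_{n/2} for even n) come back as 0."""
    m = n // 2 + 1
    coef = packed[:m].astype(complex)
    sin = packed[m:][::-1]                 # back to Im c_1, Im c_2, ...
    coef[1:1 + len(sin)] += 1j * sin
    return coef

rng = np.random.default_rng(0)
for n in (7, 8):
    x = rng.standard_normal(n)
    packed = pack(np.fft.rfft(x), n)
    print(n, packed.shape, np.allclose(np.fft.irfft(unpack(packed, n), n), x))
# prints "7 (7,) True", then "8 (8,) True"
```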
--------------------------------------------------------------------------------
/src/dists.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
3 | class Ω(ABC):
4 |
5 | """
6 | Distribution
7 | """
8 |
9 | @abstractmethod
10 | def sample(self, prng) -> X: pass
11 |
12 | # ---------------------------------------------------------------------------- #
13 | # UNIFORM #
14 | # ---------------------------------------------------------------------------- #
15 |
16 | class Uniform:
17 |
18 | """
19 | Uniform distribution
20 | """
21 |
22 | min: X
23 | max: X
24 |
25 | def __init__(self, min: X, max: X):
26 |
27 | self.min = np.array(min)
28 | self.max = np.array(max)
29 |
30 | def sample(self, prng, shape=()) -> X:
31 |
32 | scale = self.max - self.min
33 | x = random.uniform(prng, shape + scale.shape)
34 |
35 | return x * scale + self.min
36 |
37 | # ---------------------------------------------------------------------------- #
38 | # NORMAL #
39 | # ---------------------------------------------------------------------------- #
40 |
41 | class Normal:
42 |
43 | """
44 | Normal distribution
45 | """
46 |
47 | μ: X
48 | λ: X
49 |
50 | def __init__(self, μ: X, Σ: X):
51 |
52 | self.μ = μ
53 |
54 | U, Λ, _ = np.linalg.svd(Σ)
55 | self.λ = U * np.sqrt(Λ)
56 |
57 | def sample(self, prng, shape=()) -> X:
58 |
59 | var = random.normal(prng, shape + self.μ.shape)
60 | ε = np.einsum("...ij,...j->...i", self.λ, var)
61 |
62 | return self.μ + ε
63 |
64 | # ---------------------------------------------------------------------------- #
65 | # GAUSSIAN #
66 | # ---------------------------------------------------------------------------- #
67 |
68 | class Gaussian(Normal):
69 |
70 | """
71 | Gaussian Process
72 | """
73 |
74 | dim: Tuple[int]
75 |
76 | def __init__(self, grid: X, kernel: Fx):
77 |
78 | *dim, ndim = grid.shape
79 | assert len(dim) == ndim
80 |
81 | X = grid.reshape(-1, ndim)
82 | K = jax.vmap(kernel, (0, None))
83 | Σ = jax.vmap(lambda y: K(X, y))(X)
84 |
85 | super().__init__(np.zeros(len(Σ)), Σ)
86 | self.dim = tuple(dim)
87 |
88 | def sample(self, prng, shape=()) -> X:
89 |
90 | x = super().sample(prng, shape)
91 | return x.reshape(shape + self.dim)
92 |
93 | # ---------------------------------- KERNEL ---------------------------------- #
94 |
95 | RBF = lambda ƛ: lambda x, y: np.exp(-np.sum((x-y)**2) / ƛ**2/2)
96 | Per = lambda ƛ: lambda x, y: np.exp(-np.sum((np.sin(π*(x-y))/2)**2) / ƛ**2*2)
97 |
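`Normal` factors the covariance with an SVD, Σ = U Λ Uᵀ for a symmetric positive semi-definite Σ, and stores λ = U √Λ; then ε = λ z with z ~ N(0, I) has covariance λ λᵀ = Σ. A NumPy sketch of that factorisation for the periodic kernel `Per` on a 1-d grid (the module itself uses `jax.numpy` and `jax.vmap`; `periodic_kernel` and `gp_factor` are hypothetical names):

```python
import numpy as np

def periodic_kernel(length: float):
    """Same formula as `Gaussian.Per`, vectorised over the last axis."""
    def k(x, y):
        return np.exp(-np.sum((np.sin(np.pi * (x - y)) / 2) ** 2, axis=-1) / length**2 * 2)
    return k

def gp_factor(grid: np.ndarray, kernel):
    """Covariance on the flattened grid and a factor L with L @ L.T == cov (via SVD)."""
    pts = grid.reshape(-1, grid.shape[-1])
    cov = kernel(pts[:, None, :], pts[None, :, :])   # (N, N)
    U, S, _ = np.linalg.svd(cov)
    return cov, U * np.sqrt(S)                       # scale column i of U by sqrt(S[i])

grid = (np.arange(32) / 32)[:, None]                 # 1-d periodic grid on [0, 1)
cov, L = gp_factor(grid, periodic_kernel(0.8))
rng = np.random.default_rng(0)
samples = rng.standard_normal((5, L.shape[1])) @ L.T  # each row ~ N(0, cov)
print(samples.shape, np.allclose(L @ L.T, cov))        # (5, 32) True
```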
--------------------------------------------------------------------------------
/src/model/README:
--------------------------------------------------------------------------------
1 | This module implements SNO, FNO and NSM. The `_base` module defines the classes
2 | shared by the neural operators, namely the spectral convolution layers. They are
3 | implemented for arbitrary input dimensions, as defined by the `Basis` class.
4 | Similarly, the SNO and FNO models are compatible with any dimension.
5 |
6 | Each model is associated with a `loss` function, which defines how the loss is
7 | obtained for a given sample of the parameter function. Given an instance of the
8 | `PDE` class, it returns a dictionary that holds each loss term separately.
9 |
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .. import *
2 | from ..pde import *
3 | from ..basis import *
4 |
5 | # ---------------------------------------------------------------------------- #
6 | # SOLVER #
7 | # ---------------------------------------------------------------------------- #
8 |
9 | class Solver(ABC, nn.Module):
10 |
11 | pde: PDE
12 | cfg: Dict
13 |
14 | @F.cached_property
15 | def activate(self) -> Fx: return \
16 | getattr(nn, self.cfg["activate"])
17 |
18 | @abstractmethod
19 | def u(self, ϕ: Fx, x: X) -> X: pass
20 |
21 | @abstractmethod
22 | def loss(self, ϕ: Fx) -> Dict[str, X]: pass
23 |
24 | # ----------------------------------- PINN ----------------------------------- #
25 |
26 | class PINN(Solver, ABC):
27 |
28 | @abstractmethod
29 | def forward(self, ϕ: Basis, s: Tuple[int]) -> Tuple[X, X]: pass
30 |
31 | def u(self, ϕ: Basis, x: Tuple[int]) -> X:
32 |
33 | assert isinstance(x, Tuple), "x must be a tuple of grid sizes (uniform grid)"
34 | assert len(x) == self.pde.domain.ndim
35 |
36 | return self.forward(ϕ, x)[-1]
37 |
38 | def loss(self, ϕ: Basis) -> Dict[str, X]:
39 |
40 | x, y, u = self.forward(ϕ, (self.cfg["grid"], ) * (d:=self.pde.domain.ndim))
41 | edges = [(np.take(u, 0, axis=n), np.take(u, -1, axis=n)) for n in range(d)]
42 |
43 | R = self.pde.equation(x, y, u)
44 | B = self.pde.boundary(edges)
45 |
46 | return dict(
47 | residual=np.mean(R**2), **{
48 | f"boundary{n}": np.mean(Bn**2)
49 | for n, Bn in enumerate(B)
50 | })
51 |
52 | # --------------------------------- SPECTRAL --------------------------------- #
53 |
54 | class Spectral(Solver, ABC):
55 |
56 | @abstractmethod
57 | def forward(self, ϕ: Basis) -> Basis: pass
58 |
59 | def u(self, ϕ: Basis, x: X) -> X:
60 |
61 | u = self.forward(ϕ)
62 |
63 | if isinstance(x, Tuple): return u[x]
64 | if isinstance(x, Array): return u(x)
65 |
66 | def loss(self, ϕ: Basis) -> Dict[str, X]:
67 |
68 | R = self.pde.spectral(ϕ, self.forward(ϕ))
69 | return dict(residual=np.sum(np.square(R.coef)))
70 |
71 | # ---------------------------------------------------------------------------- #
72 | # TRAINER #
73 | # ---------------------------------------------------------------------------- #
74 |
75 | class Trainer(ABC, nn.Module):
76 |
77 | mod: Solver
78 | pde: PDE
79 | cfg: Dict
80 |
81 | def setup(self):
82 |
83 | # --------------------------------- SCHEDULER -------------------------------- #
84 |
85 | if self.cfg["schd"] is None:
86 |
87 | scheduler = self.cfg["lr"]
88 |
89 | elif self.cfg["schd"] == "cos":
90 |
91 | scheduler = optax.cosine_decay_schedule(self.cfg["lr"], self.cfg["iter"])
92 |
93 | elif self.cfg["schd"] == "exp":
94 |
95 | decay_rate = 1e-3 ** (1.0 / self.cfg["iter"]) # lr shrinks 1000x over `iter` steps
96 | scheduler = optax.exponential_decay(self.cfg["lr"], 1, decay_rate)
97 | else: raise ValueError(f"unknown scheduler: {self.cfg['schd']!r}")
98 | # --------------------------------- OPTIMIZER -------------------------------- #
99 |
100 | self.optimizer = optax.adam(scheduler)
101 |
102 | @nn.compact
103 | def init(self):
104 |
105 | ϕ = self.pde.params.sample(prng:=self.make_rng("sample"))
106 | s = tuple([self.cfg["grid"]] * self.pde.domain.ndim)
107 |
108 | variable = self.mod.init(prng, ϕ, s, method="u")
109 | print(self.mod.tabulate(prng, ϕ, s, method="u"))
110 |
111 | self.variable("optim", "state", self.optimizer.init, variable["params"])
112 |
113 | return variable
114 |
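With `transition_steps=1`, the `exp` branch gives lr(t) = lr · decay_rate^t, and decay_rate = 1e-3^(1/iter) makes the learning rate reach lr · 1e-3 exactly at t = iter. A plain-Python sketch of both schedules, assuming Optax's documented formulas for `exponential_decay` (no staircase, no transition delay) and `cosine_decay_schedule` (default `alpha=0`); `exp_schedule` and `cos_schedule` are hypothetical names:

```python
import math

def exp_schedule(lr: float, iters: int):
    """lr * rate**t with rate = 1e-3 ** (1 / iters): a 1000x decay over `iters` steps."""
    rate = 1e-3 ** (1.0 / iters)
    return lambda t: lr * rate ** t

def cos_schedule(lr: float, iters: int):
    """Half-cosine decay from lr to 0 over `iters` steps, constant afterwards."""
    return lambda t: lr * 0.5 * (1 + math.cos(math.pi * min(t, iters) / iters))

lr, iters = 1e-3, 10_000
print(exp_schedule(lr, iters)(iters))   # approximately 1e-06
print(cos_schedule(lr, iters)(iters))   # 0.0
```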
--------------------------------------------------------------------------------
/src/model/_base.py:
--------------------------------------------------------------------------------
1 | from . import *
2 | from ..basis import *
3 |
4 | class SpectralConv(nn.Module):
5 |
6 | odim: int
7 |
8 | # truncate to these modes (None: keep all modes)
9 | mode: Tuple[int] = None
10 |
11 | # reference modes for scaling the initialization (None: scale by 1/idim/odim)
12 | init: Tuple[int] = None
13 |
14 | @nn.compact
15 | def __call__(self, u: Basis) -> Basis:
16 |
17 | def W(a: X) -> X:
18 |
19 | def init(prng, *shape: Tuple[int]) -> X:
20 | x = random.normal(prng, shape)
21 |
22 | *mode, idim, odim = shape
23 |
24 | if self.init is None:
25 |
26 | scale = 1 / idim / odim
27 |
28 | else:
29 |
30 | from math import prod
31 | rate = prod(self.init) / prod(mode)
32 |
33 | scale = np.sqrt(rate / idim)
34 |
35 | return x * scale
36 |
37 | W = self.param("W", init, *a.shape, self.odim)
38 |
39 | dims = (N:=u.ndim(), N), (B:=range(N), B)
40 | return jax.lax.dot_general(a, W, dims)
41 |
42 | mode = self.mode or u.mode
43 | return u.to(*mode).map(W).to(*u.mode)
44 |
45 | class SpatialMixing(nn.Module):
46 |
47 | @nn.compact
48 | def __call__(self, u: Basis) -> Basis:
49 |
50 | def M(a: X, i: int) -> X:
51 |
52 | def init(prng, *shape: Tuple[int]) -> X:
53 | x = random.normal(prng, shape)
54 |
55 | return x / np.sqrt(shape[-1])
56 |
57 | M = self.param(f"M{i}", init, *a.shape, a.shape[i])
58 |
59 | batch = [*range(i), *range(i + 1, u.ndim()), u.ndim()]
60 | a = jax.lax.dot_general(a, M, ((i, i), (batch, batch)))
61 |
62 | return np.moveaxis(a, -1, i)
63 |
64 | return u.map(F.partial(F.reduce, M, range(u.ndim())))
65 |
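For each retained mode, `SpectralConv` applies its own `idim ⨉ odim` matrix through `jax.lax.dot_general`, batching over the mode axes; the surrounding `u.to(*mode)` and `.to(*u.mode)` truncate and then zero-pad the coefficients. A 1-d NumPy sketch of that computation for a basis whose modes are indexed 0..n-1 (as for `Chebyshev`); `spectral_conv` is a hypothetical name and the weights are random stand-ins scaled like the `init=None` branch:

```python
import numpy as np

def spectral_conv(coef: np.ndarray, W: np.ndarray, keep: int) -> np.ndarray:
    """coef: (n, idim) coefficients; W: (keep, idim, odim), one matrix per kept mode.

    Modes >= keep are zeroed, mirroring u.to(keep).map(W).to(n).
    """
    n = coef.shape[0]
    out = np.zeros((n, W.shape[-1]), dtype=np.result_type(coef, W))
    out[:keep] = np.einsum("mi,mio->mo", coef[:keep], W)  # channel mixing per mode
    return out

rng = np.random.default_rng(0)
coef = rng.standard_normal((16, 3))
W = rng.standard_normal((6, 3, 4)) / 3 / 4   # scale 1 / idim / odim
out = spectral_conv(coef, W, keep=6)
print(out.shape, np.all(out[6:] == 0), np.allclose(out[2], coef[2] @ W[2]))  # (16, 4) True True
```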
--------------------------------------------------------------------------------
/src/model/fno/__init__.py:
--------------------------------------------------------------------------------
1 | from .. import *
2 | from .._base import *
3 | from ...basis.fourier import *
4 |
5 | # ---------------------------------------------------------------------------- #
6 | # SOLVER #
7 | # ---------------------------------------------------------------------------- #
8 |
9 | class FNO(PINN):
10 |
11 | def __repr__(self): return f"FNOx{self.cfg['grid']}+PINN"
12 |
13 | @nn.compact
14 | def forward(self, ϕ: Basis, s: Tuple[int]) -> Tuple[X, X]:
15 |
16 | if self.cfg["cheb"]: T = self.pde.basis
17 | else: T = Fourier[self.pde.domain.ndim]
18 |
19 | from .. import utils
20 | x = utils.grid(*s)
21 |
22 | z = np.concatenate([x, y:=ϕ[s]], -1)
23 |
24 | z = nn.Dense(self.cfg["hdim"] * 4)(z)
25 | z = self.activate(z)
26 |
27 | z = nn.Dense(self.cfg["hdim"])(z)
28 | z = self.activate(z)
29 |
30 | for _ in range(self.cfg["depth"]):
31 |
32 | conv = SpectralConv(self.cfg["hdim"], self.cfg["mode"])(T.transform(z)).inv()
33 | fc = nn.Dense(self.cfg["hdim"])(z)
34 |
35 | z = conv + fc
36 | z = self.activate(z)
37 |
38 | z = nn.Dense(self.cfg["hdim"])(z)
39 | z = self.activate(z)
40 |
41 | z = nn.Dense(self.cfg["hdim"] * 4)(z)
42 | z = self.activate(z)
43 |
44 | z = nn.Dense(self.pde.odim)(z)
45 | return x, y, self.pde.mollifier(y, (x, z))
46 |
--------------------------------------------------------------------------------
/src/model/fno/spectral.py:
--------------------------------------------------------------------------------
1 | from .. import *
2 | from .._base import *
3 | from ...basis.fourier import *
4 |
5 | # ---------------------------------------------------------------------------- #
6 | # SOLVER #
7 | # ---------------------------------------------------------------------------- #
8 |
9 | class FNO(Spectral):
10 |
11 | def __repr__(self): return "NSM"
12 |
13 | @nn.compact
14 | def forward(self, ϕ: Basis) -> Basis:
15 |
16 | if not self.cfg["fourier"]: T = self.pde.basis
17 | else: T = Fourier[self.pde.domain.ndim]
18 |
19 | u = ϕ.to(*self.cfg["mode"])
20 |
21 | bias = T.transform(u.grid(*u.mode)).coef
22 | u = u.map(lambda coef: np.concatenate([coef, bias], axis=-1))
23 |
24 | u = u.map(nn.Dense(self.cfg["hdim"] * 4))
25 | u = T.transform(self.activate(u.inv()))
26 |
27 | u = u.map(nn.Dense(self.cfg["hdim"]))
28 | u = T.transform(self.activate(u.inv()))
29 |
30 | for _ in range(self.cfg["depth"]):
31 |
32 | conv = SpectralConv(self.cfg["hdim"])(u)
33 | fc = u.map(nn.Dense(self.cfg["hdim"]))
34 |
35 | u = T.add(conv, fc)
36 | u = T.transform(self.activate(u.inv()))
37 |
38 | u = u.map(nn.Dense(self.cfg["hdim"]))
39 | u = T.transform(self.activate(u.inv()))
40 |
41 | u = u.map(nn.Dense(self.cfg["hdim"] * 4))
42 | u = T.transform(self.activate(u.inv()))
43 |
44 | u = u.map(nn.Dense(self.pde.odim))
45 | return self.pde.mollifier(ϕ, u)
46 |
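This model (`NSM`) applies each activation on grid values, `T.transform(self.activate(u.inv()))`, while `SNO` in `sno/spectral.py` applies it to the coefficients directly with `u.map(self.activate)`. A minimal NumPy sketch of the difference for a 1-d Fourier series, using complex `rfft` coefficients as a stand-in for the packed real layout; `act_physical` and `act_coefficient` are hypothetical names:

```python
import numpy as np

def act_physical(coef: np.ndarray, n: int, act) -> np.ndarray:
    """NSM-style: coefficients -> grid values -> pointwise activation -> coefficients."""
    return np.fft.rfft(act(np.fft.irfft(coef, n)))

def act_coefficient(coef: np.ndarray, act) -> np.ndarray:
    """SNO-style: the activation acts entrywise on the coefficients themselves."""
    return act(coef.real) + 1j * act(coef.imag)

def relu(z: np.ndarray) -> np.ndarray:
    return np.maximum(z, 0.0)

n = 16
x = np.arange(n) / n
coef = np.fft.rfft(np.sin(2 * np.pi * x))
phys = act_physical(coef, n, relu)   # spectrum of max(sin(2 pi x), 0) on the grid
spec = act_coefficient(coef, relu)   # relu of the coefficients: a different function
print(np.allclose(np.fft.irfft(phys, n), relu(np.sin(2 * np.pi * x))))  # True
print(np.allclose(phys, spec))                                          # False
```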
--------------------------------------------------------------------------------
/src/model/sno/spectral.py:
--------------------------------------------------------------------------------
1 | from .. import *
2 | from .._base import *
3 |
4 | # ---------------------------------------------------------------------------- #
5 | # SOLVER #
6 | # ---------------------------------------------------------------------------- #
7 |
8 | class SNO(Spectral):
9 |
10 | def __repr__(self): return "SNO"
11 |
12 | @nn.compact
13 | def forward(self, ϕ: Basis) -> Basis:
14 |
15 | u = ϕ.to(*self.cfg["mode"])
16 |
17 | bias = u.transform(u.grid(*u.mode)).coef
18 | u = u.map(lambda coef: np.concatenate([coef, bias], axis=-1))
19 |
20 | u = u.map(nn.Dense(self.cfg["hdim"] * 4))
21 | u = u.map(self.activate)
22 |
23 | u = u.map(nn.Dense(self.cfg["hdim"]))
24 | u = u.map(self.activate)
25 |
26 | for _ in range(self.cfg["depth"]):
27 |
28 | def Integral(coef: X) -> X:
29 |
30 | K = nn.DenseGeneral(u.mode, axis=range(-u.ndim(), 0))
31 | return np.moveaxis(K(np.moveaxis(coef, -1, 0)), 0, -1)
32 |
33 | u = u.map(Integral)
34 |
35 | u = u.map(nn.Dense(self.cfg["hdim"]))
36 | u = u.map(self.activate)
37 |
38 | u = u.map(nn.Dense(self.cfg["hdim"]))
39 | u = u.map(self.activate)
40 |
41 | u = u.map(nn.Dense(self.cfg["hdim"] * 4))
42 | u = u.map(self.activate)
43 |
44 | u = u.map(nn.Dense(self.pde.odim))
45 | return self.pde.mollifier(ϕ, u)
46 |
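`Integral` applies `nn.DenseGeneral` over the mode axes, so one dense map (with a bias, Flax's default) mixes the modes and is shared by all channels; `SpectralConv` in `_base.py` does the opposite and mixes channels separately for each mode. A 1-d NumPy sketch of the SNO layer, with the kernel `K` and bias `b` as explicit hypothetical arrays instead of Flax parameters:

```python
import numpy as np

def mode_mixing(coef: np.ndarray, K: np.ndarray, b: np.ndarray) -> np.ndarray:
    """coef: (m, c) coefficients; K: (m, m) kernel; b: (m,) bias shared by all channels."""
    z = np.moveaxis(coef, -1, 0)          # (c, m): channels first, as in `Integral`
    z = np.tensordot(z, K, axes=1) + b    # dense layer over the trailing mode axis
    return np.moveaxis(z, 0, -1)          # back to (m, c)

rng = np.random.default_rng(0)
coef, K, b = rng.standard_normal((8, 3)), rng.standard_normal((8, 8)), rng.standard_normal(8)
out = mode_mixing(coef, K, b)
print(out.shape, np.allclose(out, np.einsum("mc,mn->nc", coef, K) + b[:, None]))  # (8, 3) True
```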
--------------------------------------------------------------------------------
/src/pde/README:
--------------------------------------------------------------------------------
1 | This module defines the interface of the PDE classes. The abstract classes in the
2 | `_domain` and `_params` modules define the domain and the PDE parameters. The
3 | functions in `mollifier` define how each type of boundary condition is satisfied
4 | by transforming the predicted solution. Built on these classes, the general `PDE`
5 | class defines how the residual terms are computed in both the physical and the
6 | spectral domain; the models use these residuals to compute the loss.
7 |
8 | Each subdirectory contains an instantiation of this interface, together with a
9 | `generate` module that generates numerical solutions. Follow these templates if
10 | you want to try a new problem.
11 |
--------------------------------------------------------------------------------
/src/pde/__init__.py:
--------------------------------------------------------------------------------
1 | from .. import *
2 |
3 | class PDE(ABC):
4 |
5 | from ._domain import R
6 | from ._params import G
7 |
8 | odim: int # output dimension
9 |
10 | domain: R # interior domain
11 | params: G # parameter function
12 |
13 | mollifier: Fx # transformation
14 |
15 | equation: Fx # PDE (equation)
16 | boundary: Fx # PDE (boundary)
17 |
18 | from ..basis import Basis
19 | basis: Basis # basis function
20 | spectral: Fx # PDE (spectral)
21 |
22 | solution: Any
23 |
--------------------------------------------------------------------------------
/src/pde/_domain.py:
--------------------------------------------------------------------------------
1 | from . import *
2 | from ..dists import *
3 |
4 | class R(Ω):
5 |
6 | """
7 | Euclidean space
8 | """
9 |
10 | ndim: int
11 | boundary: List[Ω]
12 |
13 | # ---------------------------------------------------------------------------- #
14 | # RECT #
15 | # ---------------------------------------------------------------------------- #
16 |
17 | class Rect(Uniform, R):
18 |
19 | """
20 | N-d unit rectangle
21 | """
22 |
23 | def __init__(self, ndim: int):
24 | super().__init__(np.zeros(ndim),
25 | np.ones(ndim))
26 |
27 | class Boundary(Uniform):
28 |
29 | def __init__(self, dim: int):
30 | super().__init__(np.zeros(ndim-1),
31 | np.ones(ndim-1))
32 |
33 | self.dim = dim
34 |
35 | def sample(self, prng, shape=()) -> X:
36 |
37 | x = super().sample(prng, shape)
38 | return np.insert(x, self.dim, np.zeros(shape), axis=-1), \
39 | np.insert(x, self.dim, np.ones(shape), axis=-1)
40 |
41 | self.ndim = ndim
42 | self.boundary = list(map(Boundary, range(ndim)))
43 |
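`Boundary.sample` draws the remaining `ndim - 1` coordinates uniformly and inserts a 0 (or a 1) at position `dim` with `np.insert`, which yields matching point pairs on opposite faces. A NumPy sketch of that step; `sample_face_pair` is a hypothetical stand-alone function, not part of this module:

```python
import numpy as np

def sample_face_pair(rng, ndim: int, dim: int, shape=()):
    """Matching points on the faces x[dim] = 0 and x[dim] = 1 of the unit hypercube."""
    x = rng.uniform(size=shape + (ndim - 1,))           # the free coordinates
    lo = np.insert(x, dim, np.zeros(shape), axis=-1)    # x[..., dim] = 0
    hi = np.insert(x, dim, np.ones(shape), axis=-1)     # x[..., dim] = 1
    return lo, hi

rng = np.random.default_rng(0)
lo, hi = sample_face_pair(rng, ndim=3, dim=1, shape=(4,))
print(lo.shape, lo[:, 1], hi[:, 1])  # (4, 3) [0. 0. 0. 0.] [1. 1. 1. 1.]
```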
--------------------------------------------------------------------------------
/src/pde/_params.py:
--------------------------------------------------------------------------------
1 | from . import *
2 | from ..dists import *
3 |
4 | class G(ABC):
5 |
6 | """
7 | Function space
8 | """
9 |
10 | idim: int
11 | odim: int
12 |
13 | @abstractmethod
14 | def sample(self, prng) -> Fx: pass
15 |
16 | # ---------------------------------------------------------------------------- #
17 | # INTERPOLATE #
18 | # ---------------------------------------------------------------------------- #
19 |
20 | from ..basis import *
21 | class Interpolate(G):
22 |
23 | """
24 | Interpolated function
25 | """
26 |
27 | def __init__(self, dist: Ω, basis: Type[Basis]):
28 |
29 | self.dist = dist
30 | self.basis = basis
31 |
32 | self.idim = len(dist.dim)
33 | self.odim = 1
34 |
35 | def sample(self, prng, shape=()) -> Basis:
36 |
37 | x = self.dist.sample(prng, shape)[..., None]
38 | return utils.nmap(self.basis.transform, len(shape))(x)
39 |
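`Interpolate.sample` draws values on the grid and converts them to coefficients with `basis.transform`. Along a Chebyshev axis, `Chebyshev.transform` is a DCT-I computed with the Hermitian FFT: `hfft(..., norm="forward")` returns half of each interior coefficient, hence the doubling of `coef[1:-1]`. A NumPy sketch of this transform and its inverse, assuming `utils.grid(n)` is `linspace(0, 1, n)` (that module is not shown here); `cheb_grid`, `cheb_transform` and `cheb_inverse` are hypothetical names:

```python
import numpy as np
from numpy.polynomial import chebyshev as C

def cheb_grid(n: int) -> np.ndarray:
    """Chebyshev-Gauss-Lobatto points on [0, 1] (assumes utils.grid(n) == linspace(0, 1, n))."""
    return np.cos(np.pi * np.linspace(0, 1, n)) / 2 + 0.5

def cheb_transform(values: np.ndarray) -> np.ndarray:
    """Grid values -> coefficients of T*_k, via the Hermitian FFT (a DCT-I)."""
    n = len(values)
    assert n > 1, "needs at least two points (see SHARP BITS)"
    coef = np.fft.hfft(values, axis=0, norm="forward")[:n]
    coef[1:-1] *= 2
    return coef

def cheb_inverse(coef: np.ndarray) -> np.ndarray:
    """Coefficients -> values on cheb_grid(len(coef))."""
    half = coef.copy()
    half[1:-1] /= 2
    full = np.concatenate([half, half[::-1][1:-1]])   # even extension, length 2(n - 1)
    return np.fft.ihfft(full, axis=0, norm="forward").real

x = cheb_grid(9)
values = np.exp(x)                                       # any function sampled on the grid
coef = cheb_transform(values)
print(np.allclose(C.chebval(2 * x - 1, coef), values))   # True: interpolates at the nodes
print(np.allclose(cheb_inverse(coef), values))            # True: round trip
```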
--------------------------------------------------------------------------------
/src/pde/mollifier.py:
--------------------------------------------------------------------------------
1 | from . import *
2 | from ..basis import *
3 |
4 | SCALE = 1e-3
5 |
6 | # ---------------------------------------------------------------------------- #
7 | # PERIODIC #
8 | # ---------------------------------------------------------------------------- #
9 |
10 | def periodic(ϕ: X, u: X) -> X:
11 |
12 | if isinstance(ϕ, Basis):
13 | base = (0, ) * u.ndim()
14 | origin = np.array(base)
15 |
16 | return u.map(lambda coef: coef.at[base].add(-u(origin)) * SCALE)
17 |
18 | if isinstance(ϕ, Array):
19 | x, uofx = u
20 |
21 | return (uofx - uofx[(0,)*(uofx.ndim-1)]) * SCALE
22 |
23 | # ---------------------------------------------------------------------------- #
24 | # DIRICHLET #
25 | # ---------------------------------------------------------------------------- #
26 |
27 | def dirichlet(ϕ: X, u: X) -> X:
28 |
29 | if isinstance(ϕ, Basis):
30 | x = u.grid(*u.mode)
31 |
32 | mol = np.prod(np.sin(π*x), axis=-1, keepdims=True)
33 | return u.transform(u.inv() * mol * SCALE)
34 |
35 | if isinstance(ϕ, Array):
36 | x, uofx = u
37 |
38 | mol = np.prod(np.sin(π*x), axis=-1, keepdims=True)
39 | return uofx * mol * SCALE
40 |
41 | # ---------------------------------------------------------------------------- #
42 | # INITIAL-CONDITION #
43 | # ---------------------------------------------------------------------------- #
44 |
45 | def initial_condition(ϕ: X, u: X) -> X:
46 |
47 | """
48 | Initial-value problem. The first dimension is time, and the remaining
49 | dimensions have periodic boundaries.
50 | """
51 |
52 | if isinstance(ϕ, Basis):
53 |
54 | mol = u.grid(*u.mode)[..., [0]] * SCALE
55 | return u.__class__.add(u.transform(u.inv() * mol), ϕ)
56 |
57 | if isinstance(ϕ, Array):
58 | x, uofx = u
59 |
60 | mol = x[..., [0]] * SCALE
61 | return uofx * mol + ϕ
62 |
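The physical-space branches of `dirichlet` and `initial_condition` can be exercised on a small grid. A minimal NumPy sketch, reusing `SCALE` and the same formulas but with plain arrays in place of the network output and of `ϕ` (the names `x`, `u` and `phi` are illustrative):

```python
import numpy as np

SCALE = 1e-3

def dirichlet(x: np.ndarray, u: np.ndarray) -> np.ndarray:
    """u(x) * prod_i sin(pi x_i) * SCALE: vanishes wherever some x_i is 0 or 1."""
    return u * np.prod(np.sin(np.pi * x), axis=-1, keepdims=True) * SCALE

def initial_condition(x: np.ndarray, u: np.ndarray, phi: np.ndarray) -> np.ndarray:
    """t * u(x) * SCALE + phi: equals phi at t = x[..., 0] = 0."""
    return u * x[..., [0]] * SCALE + phi

g = np.linspace(0, 1, 5)
x = np.stack(np.meshgrid(g, g, indexing="ij"), axis=-1)   # (5, 5, 2) grid on [0, 1]^2
u = np.random.default_rng(0).standard_normal((5, 5, 1))  # stand-in network output
d = dirichlet(x, u)
print(np.allclose(d[0], 0), np.allclose(d[-1], 0), np.allclose(d[:, 0], 0))  # True True True
phi = np.ones((5, 1))                                     # initial condition along the 2nd axis
ic = initial_condition(x, u, phi)
print(np.allclose(ic[0], phi))                            # True
```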
--------------------------------------------------------------------------------
/src/pde/navierstokes/__init__.py:
--------------------------------------------------------------------------------
1 | from .. import *
2 | from ...dists import *
3 | from ...basis import *
4 |
5 | from .._domain import *
6 | from .._params import *
7 |
8 | from ...basis.fourier import *
9 | from ...basis.chebyshev import *
10 |
11 | def k(nx: int, ny: int) -> Tuple[X]:
12 |
13 | return np.tile(np.fft.fftfreq(nx)[:, None] * nx * 2*π, (1, ny)), \
14 | np.tile(np.fft.fftfreq(ny)[None, :] * ny * 2*π, (nx, 1))
15 |
16 | def velocity(what: X = None, *, w: X = None) -> X:
17 |
18 | if what is None:
19 | what = np.fft.fft2(w)
20 | kx, ky = k(*what.shape)
21 |
22 | Δ = kx ** 2 + ky ** 2
23 | Δ = Δ.at[0, 0].set(1)
24 |
25 | vx = np.fft.irfft2(what * 1j*ky / Δ, what.shape)
26 | vy = np.fft.irfft2(what *-1j*kx / Δ, what.shape)
27 |
28 | return vx, vy
29 |
30 | class Initial(Gaussian):
31 |
32 | grid = Fourier[2].grid(64, 64)
33 |
34 | def __str__(self): return f"{self.length}x{self.scaling}"
35 |
36 | def __init__(self, length: float, scaling: float = 1.0):
37 | super().__init__(Initial.grid, Gaussian.Per(length))
38 |
39 | self.length = length
40 | self.scaling = scaling
41 |
42 | def sample(self, prng, shape=()) -> X:
43 |
44 | x = super().sample(prng, shape)
45 | x-= np.mean(x, (-2, -1), keepdims=True)
46 |
47 | x = self.scaling * x[..., np.newaxis, :, :]
48 | return np.broadcast_to(x, shape + (2, *self.dim))
49 |
50 | class NavierStokes(PDE):
51 |
52 | """
53 | wt + v ∇w = nu ∆w
54 | where
55 | ∇⨉v = w
56 | ∇·v = 0
57 | """
58 |
59 | T: float # end time
60 | nu: float # viscosity
61 | l: float # length scale
62 |
63 | # forcing term?
64 | fn: Optional[Fx]
65 |
66 | def __str__(self): return f"Re={int(self.Re)}:T={self.T}:{self.F}"
67 | def __init__(self, ic: Initial, T: float, nu: float, fn: Fx = None):
68 |
69 | self.odim = 1
70 | self.ic = ic
71 |
72 | self.T = T
73 | self.nu = nu
74 | self.fn = fn
75 |
76 | self.l = (l:=ic.length)
77 | self.Re = l / nu * ic.scaling
78 |
79 | if fn is None: self.F = None
80 | else: self.F = fn.__name__
81 |
82 | self.domain = Rect(3)
83 |
84 | self.basis = series(Chebyshev, Fourier, Fourier)
85 | self.params = Interpolate(ic, self.basis)
86 |
87 | from ..mollifier import initial_condition
88 | self.mollifier = initial_condition
89 |
90 | @F.cached_property
91 | def solution(self):
92 |
93 | dir = os.path.dirname(__file__)
94 |
95 | with jax.default_device(jax.devices("cpu")[0]):
96 |
97 | # solution data are typically large
98 | # transfer to RAM in the first place
99 |
100 | w = np.load(f"{dir}/w.{self.ic}.npy")
101 | u = np.load(f"{dir}/u.{self}.npy")
102 |
103 | return jax.vmap(self.basis)(w), u.shape[1:-1], u
104 |
105 | def equation(self, x: X, w0: X, w: X) -> X:
106 | w, w1, w2 = utils.fdm(w, n=2)
107 |
108 | wt = w1[..., 0, 0]
109 | wx = w1[..., 0, 1]
110 | wy = w1[..., 0, 2]
111 | Δw = Δ(w2[..., 0, 1:, 1:])
112 |
113 | vx, vy = jax.vmap(velocity)(w=w.squeeze(-1))
114 | Dwdt = wt / self.T + (vx * wx + vy * wy)
115 |
116 | if self.fn is None: f = np.zeros_like(Dwdt)
117 | else: f = self.fn(*w[0].squeeze(-1).shape)
118 |
119 | return Dwdt - self.nu * Δw - f
120 |
121 | def boundary(self, w: List[Tuple[X]]) -> List[X]:
122 |
123 | _, (wt, wb), (wl, wr) = w
124 | return [wt - wb, wl - wr]
125 |
126 | def spectral(self, w0: Basis, w: Basis) -> Basis:
127 | w1 = w.grad(); w2 = w1.grad()
128 |
129 | wt = self.basis(w1.coef[..., 0, 0])
130 | wx = self.basis(w1.coef[..., 0, 1])
131 | wy = self.basis(w1.coef[..., 0, 2])
132 | Δw = self.basis(Δ(w2.coef[..., 0, 1:, 1:]))
133 |
134 | vx, vy = jax.vmap(velocity)(w=w.inv().squeeze(-1))
135 | Dwdt = self.basis.add(wt.map(lambda coef: coef / self.T),
136 | self.basis.transform(vx * wx.inv() + vy * wy.inv()))
137 |
138 | if self.fn is None: f = self.basis(np.zeros_like(Dwdt.coef))
139 | else: f = self.basis.transform(np.broadcast_to(self.fn(*w.mode[1:]), w.mode))
140 |
141 | return self.basis.add(Dwdt, self.basis(-self.nu * Δw.coef), f.map(np.negative))
142 |
143 | ic = Initial(0.8)
144 |
145 | # ------------------------------- UNFORCED FLOW ------------------------------ #
146 |
147 | re2 = NavierStokes(ic, T=3, nu=1e-2)
148 | re3 = NavierStokes(ic, T=3, nu=1e-3)
149 | re4 = NavierStokes(ic, T=3, nu=1e-4)
150 |
151 | # ------------------------------ TRANSIENT FLOW ------------------------------ #
152 |
153 | def transient(nx: int, ny: int) -> X:
154 |
155 | xy = utils.grid(nx, ny, mode="left").sum(-1)
156 | return 0.1*(np.sin(2*π*xy) + np.cos(2*π*xy))
157 |
158 | tf = NavierStokes(ic, T=50, nu=2e-3, fn=transient)
159 |
--------------------------------------------------------------------------------
/src/pde/navierstokes/generate.py:
--------------------------------------------------------------------------------
1 | # Modified from neuraloperator. Commit ef3de3bb1140175c69a9fe3a8b45afd1335077d9
2 | # https://github.com/neuraloperator/neuraloperator/blob/master/data_generation/navier_stokes/ns_2d.py
3 |
4 | from . import *
5 |
6 | def simulate(w0: X, nu: float, f: X) -> Fx:
7 |
8 | """
9 | Returns:
10 | call: callable (what, dt) -> what' advancing the vorticity
11 | by one step `dt` in the spectral domain
12 | """
13 |
14 | s, _ = w0.shape # square grid assumed
15 | kx, ky = k(s, s)
16 |
17 | diffuse = (Δ := kx ** 2 + ky ** 2) * nu
18 |
19 | dealias = (Δ < (2/3 * π * s) ** 2).astype(float)
20 | dealias = dealias.at[0, 0].set(0) # zero-mean
21 |
22 | def Δhat(what: X) -> X:
23 |
24 | vx, vy = velocity(what)
25 |
26 | wx = np.fft.irfft2(what * 1j * kx, (s, s))
27 | wy = np.fft.irfft2(what * 1j * ky, (s, s))
28 |
29 | vxwx = np.fft.fft2(vx * wx, (s, s))
30 | vywy = np.fft.fft2(vy * wy, (s, s))
31 |
32 | return np.fft.fft2(f) - vxwx - vywy
33 |
34 | def call(what: X, dt: float) -> X:
35 |
36 | Δhat1 = Δhat(what) # Heun's method
37 |
38 | what_tilde = what + dt * (Δhat1 - diffuse * what / 2)
39 | what_tilde/= 1 + dt * diffuse / 2
40 |
41 | Δhat2 = Δhat(what_tilde) # Crank-Nicolson + Heun
42 |
43 | what = what + dt * ((Δhat1 + Δhat2) - diffuse * what) / 2
44 | what/= 1 + dt * diffuse / 2
45 |
46 | return what * dealias
47 |
48 | return call
49 |
50 | def solution(w0: X, T: float, nu: float, force: Fx,
51 | dt: float, nt: int) -> X:
52 |
53 | """
54 | Args:
55 | w0: initial condition
56 |
57 | T: total time
58 | nu: viscosity
59 | force: forcing term
60 |
61 | dt: internal solver time step
62 | nt: number of recorded snapshots
63 |
64 | Returns:
65 | w: vorticity recorded at `nt` evenly spaced times
66 | from 0 to T, both ends inclusive
67 | shape = (nt, *w0.shape)
68 | """
69 |
70 | if not force: f = np.zeros_like(w0)
71 | else: f = force(*w0.shape)
72 |
73 | step = simulate(w0, nu, f)
74 | Δt = T / (N := nt - 1)
75 |
76 | def record(what: X, _) -> Tuple[X, X]:
77 | call = lambda _, what: step(what, dt)
78 |
79 | what = step(jax.lax.fori_loop(0, int(Δt // dt), call, what), Δt % dt)
80 | return what, np.fft.irfft2(what, s=w0.shape)
81 |
82 | _, w = jax.lax.scan(record, np.fft.fft2(w0), None, N)
83 | return np.concatenate([w0[np.newaxis, :], w], axis=0)
84 |
85 | # ---------------------------------------------------------------------------- #
86 | # GENERATE #
87 | # ---------------------------------------------------------------------------- #
88 |
89 | def generate(pde: NavierStokes, dt: float = 1e-3, T: int = 64, X: int = 256):
90 |
91 | params = pde.params.sample(random.PRNGKey(0), (128, ))
92 | solve = F.partial(solution, T=pde.T, nu=pde.nu, force=pde.fn, dt=dt, nt=T)
93 |
94 | w = jax.vmap(solve)(jax.vmap(lambda w: w.to(1, X, X).inv().squeeze())(params))
95 | w = np.pad(w, [(0, 0), (0, 0), (0, 1), (0, 1)], mode="wrap")[..., np.newaxis]
96 |
97 | dir = os.path.dirname(__file__)
98 |
99 | np.save(f"{dir}/w.{pde.ic}.npy", params.coef)
100 | np.save(f"{dir}/u.{pde}.npy", w)
101 |
102 | return w
103 |
104 | if __name__ == "__main__":
105 |
106 | from sys import argv
107 |
108 | if argv[1] == "ns":
109 |
110 | generate(re2)
111 | generate(re3)
112 | generate(re4)
113 |
114 | if argv[1] == "tf":
115 |
116 | generate(tf, dt=5e-3)
117 |
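A minimal NumPy sketch of the predictor-corrector update in `simulate.call` (Heun's method for the nonlinear and forcing terms, Crank-Nicolson for diffusion) and of the circular 2/3 dealiasing mask. It assumes that `k(s, s)` returns angular wavenumbers `2πm` on the full FFT grid; the helper names `cn_heun_step` and `dealias_mask` are hypothetical. With the nonlinear term switched off the scheme reduces to plain Crank-Nicolson, so a single Fourier mode should decay almost exactly like `exp(-ν|k|²t)`.

```python
import numpy as np


def cn_heun_step(what, dt, diffuse, nonlinear):
    """One step shaped like `simulate.call` (hypothetical helper).

    `nonlinear(what)` plays the role of `Δhat` (forcing minus advection) and is
    treated with Heun's method; the linear term `-diffuse * what` is treated
    with Crank-Nicolson.
    """
    n1 = nonlinear(what)
    # predictor: Crank-Nicolson for diffusion, forward Euler for the rest
    tilde = (what + dt * (n1 - diffuse * what / 2)) / (1 + dt * diffuse / 2)
    n2 = nonlinear(tilde)
    # corrector: average the two nonlinear evaluations (Heun)
    return (what + dt * ((n1 + n2) - diffuse * what) / 2) / (1 + dt * diffuse / 2)


def dealias_mask(s):
    """Circular 2/3-rule mask on an s x s FFT grid with the mean mode zeroed."""
    k = 2 * np.pi * np.fft.fftfreq(s, d=1 / s)  # assumed: angular wavenumbers 2*pi*m
    kx, ky = np.meshgrid(k, k, indexing="ij")
    mask = (kx**2 + ky**2 < (2 / 3 * np.pi * s) ** 2).astype(float)
    mask[0, 0] = 0.0
    return mask


# usage: pure diffusion of sin(2*pi*x) on the periodic unit square
s, nu, dt, steps = 16, 1e-2, 1e-3, 100
x = np.arange(s) / s
X, _ = np.meshgrid(x, x, indexing="ij")
w0 = np.sin(2 * np.pi * X)

k = 2 * np.pi * np.fft.fftfreq(s, d=1 / s)
kx, ky = np.meshgrid(k, k, indexing="ij")
diffuse = nu * (kx**2 + ky**2)

what = np.fft.fft2(w0)
for _ in range(steps):
    what = cn_heun_step(what, dt, diffuse, lambda w: np.zeros_like(w))
w = np.fft.ifft2(what).real
exact = w0 * np.exp(-nu * (2 * np.pi) ** 2 * dt * steps)
```

Keeping the diffusion implicit means the step size is limited by the explicit advection term rather than by `ν|k|²`.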
--------------------------------------------------------------------------------
/src/pde/poisson/__init__.py:
--------------------------------------------------------------------------------
1 | from .. import *
2 | from ...dists import *
3 | from ...basis import *
4 |
5 | from .._domain import *
6 | from .._params import *
7 |
8 | class Poisson(PDE):
9 |
10 | """
11 | -Δ u(x) = s(x)
12 | """
13 |
14 | def __init__(self):
15 |
16 | self.odim = 1
17 |
18 | self.domain = Rect(2)
19 |
20 | def spectral(self, s: Basis, u: Basis) -> Basis:
21 |
22 | return self.basis.add(s, u.grad().grad().map(Δ))
23 |
24 | # ---------------------------------------------------------------------------- #
25 | # PERIODIC #
26 | # ---------------------------------------------------------------------------- #
27 |
28 | from ...basis.fourier import *
29 |
30 | class Periodic(Poisson):
31 |
32 | def __init__(self, res: int):
33 | super().__init__()
34 |
35 | class Source(Gaussian):
36 |
37 | def sample(self, prng, shape=()) -> X:
38 |
39 | x = super().sample(prng, shape)
40 | μ = np.mean(x, (-2, -1), keepdims=True)
41 |
42 | return x - μ
43 |
44 | source = Source(Fourier[2].grid(res, res), Gaussian.Per(0.2))
45 | self.params = Interpolate(source, Fourier[2])
46 | self.basis = Fourier[2]
47 |
48 | from ..mollifier import periodic
49 | self.mollifier = periodic
50 |
51 | def equation(self, x: X, s: X, u: X) -> X:
52 |
53 | s = s[:-1, :-1] # drop the duplicated
54 | u = u[:-1, :-1] # periodic endpoint
55 |
56 | # 5-point stencil for the discrete Laplacian, h = 1 / len(u)
57 |
58 | Δ = np.roll(u, 1, 0) + np.roll(u, -1, 0) \
59 | + np.roll(u, 1, 1) + np.roll(u, -1, 1) - 4 * u
60 |
61 | return s + Δ * len(u) ** 2
62 |
63 | def boundary(self, u: List[Tuple[X]]) -> List[X]:
64 |
65 | return [ul - ur for ul, ur in u]
66 |
67 | @F.cached_property
68 | def solution(self):
69 |
70 | dir = os.path.dirname(__file__)
71 |
72 | s = np.load(f"{dir}/s.periodic.npy")
73 | u = np.load(f"{dir}/u.periodic.npy")
74 |
75 | return jax.vmap(self.basis)(s), u.shape[1:-1], u
76 |
77 | # --------------------------------- INSTANCE --------------------------------- #
78 |
79 | periodic = Periodic(16)
80 |
81 | # ---------------------------------------------------------------------------- #
82 | # DIRICHLET #
83 | # ---------------------------------------------------------------------------- #
84 |
85 | from ...basis.chebyshev import *
86 |
87 | class Dirichlet(Poisson):
88 |
89 | def __init__(self, res: int):
90 | super().__init__()
91 |
92 | source = Gaussian(Chebyshev[2].grid(res, res), Gaussian.RBF(0.2))
93 | self.params = Interpolate(source, Chebyshev[2])
94 | self.basis = Chebyshev[2]
95 |
96 | from ..mollifier import dirichlet
97 | self.mollifier = dirichlet
98 |
99 | def equation(self, x: X, s: X, u: X) -> X:
100 |
101 | return s + Δ(utils.fdm(u, 2)[2])
102 |
103 | def boundary(self, u: List[Tuple[X]]) -> List[X]:
104 |
105 | return []
106 |
107 | @F.cached_property
108 | def solution(self):
109 |
110 | dir = os.path.dirname(__file__)
111 |
112 | s = np.load(f"{dir}/s.dirichlet.npy")
113 | u = np.load(f"{dir}/u.dirichlet.npy")
114 |
115 | return jax.vmap(self.basis)(s), u.shape[1:-1], u
116 |
117 | # --------------------------------- INSTANCE --------------------------------- #
118 |
119 | dirichlet = Dirichlet(16)
120 |
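`Periodic.equation` drops the duplicated endpoint and applies a 5-point stencil with `np.roll`, scaled by `len(u) ** 2 = 1 / h²`. A small NumPy check of that scaling, assuming a square grid of `n` points on [0, 1) (the helper names `periodic_laplacian` and `stencil_error` are hypothetical): on `u = sin(2πx) sin(2πy)` the stencil should approach the exact Laplacian `-8π² u` at second order in `h`, so halving `h` should cut the error by roughly 4.

```python
import numpy as np


def periodic_laplacian(u):
    """5-point Laplacian on an n x n periodic grid over [0, 1)^2 (endpoint excluded)."""
    lap = (np.roll(u, 1, 0) + np.roll(u, -1, 0)
           + np.roll(u, 1, 1) + np.roll(u, -1, 1) - 4 * u)
    return lap * len(u) ** 2  # divide by h^2 with h = 1 / n


def stencil_error(n):
    """Relative max error of the stencil against the exact Laplacian."""
    x = np.arange(n) / n
    X, Y = np.meshgrid(x, x, indexing="ij")
    u = np.sin(2 * np.pi * X) * np.sin(2 * np.pi * Y)
    exact = -8 * np.pi**2 * u
    return np.max(np.abs(periodic_laplacian(u) - exact)) / np.max(np.abs(exact))
```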
--------------------------------------------------------------------------------
/src/pde/poisson/generate.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
3 | def solution(s: Basis, res: int = 256) -> X:
4 |
5 | freq = map(lambda n: np.square(np.fft.fftfreq(n) * 2 * π * n), s.mode)
6 | u = s.map(lambda coef: coef / sum(np.meshgrid(*freq, indexing="ij"))[..., np.newaxis])
7 |
8 | u = u.map(lambda coef: coef.at[(0, ) * u.ndim()].set(0))
9 | u = u.map(lambda coef: coef.at[(0, ) * u.ndim()].add(-u(np.zeros(u.ndim()))))
10 |
11 | return u[res, res]
12 |
13 | def generate(pde: Periodic, N: int = 128, X: int = 256):
14 |
15 | params = pde.params.sample(random.PRNGKey(0), (N, ))
16 | u = jax.vmap(F.partial(solution, res=X))(params)
17 |
18 | dir = os.path.dirname(__file__)
19 |
20 | np.save(f"{dir}/s.periodic.npy", params.coef)
21 | np.save(f"{dir}/u.periodic.npy", u)
22 |
23 | return u
24 |
25 | if __name__ == "__main__":
26 |
27 | generate(periodic)
28 |
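`solution` inverts `-Δ` mode by mode: each Fourier coefficient of the source is divided by `|k|²` (with the same `np.fft.fftfreq(n) * 2 * π * n` wavenumbers), and the free constant is fixed afterwards (the code zeroes the mean mode, then shifts so that `u(0) = 0`). A grid-based NumPy analogue, assuming a uniform periodic grid on [0, 1)² and a zero-mean source; `poisson_periodic` is a hypothetical name, and this sketch pins the constant by zero mean only.

```python
import numpy as np


def poisson_periodic(s_grid):
    """Solve -Δu = s on the periodic unit square via FFT (zero-mean solution)."""
    n0, n1 = s_grid.shape
    k0 = 2 * np.pi * np.fft.fftfreq(n0, d=1 / n0)
    k1 = 2 * np.pi * np.fft.fftfreq(n1, d=1 / n1)
    k2 = k0[:, None] ** 2 + k1[None, :] ** 2
    shat = np.fft.fft2(s_grid)
    uhat = np.zeros_like(shat)
    nonzero = k2 > 0
    uhat[nonzero] = shat[nonzero] / k2[nonzero]  # -Δ becomes |k|^2 in Fourier space
    # the k = 0 mode stays zero, which fixes the additive constant (zero mean)
    return np.fft.ifft2(uhat).real


# usage: manufactured solution u = sin(2πx) cos(4πy), so -Δu = 20π² u
n = 32
x = np.arange(n) / n
X, Y = np.meshgrid(x, x, indexing="ij")
u_true = np.sin(2 * np.pi * X) * np.cos(4 * np.pi * Y)
u = poisson_periodic(20 * np.pi**2 * u_true)
```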
--------------------------------------------------------------------------------
/src/pde/poisson/generate.wls:
--------------------------------------------------------------------------------
1 | (* Load parameters from numpy file *)
2 | Source = Last[ExternalEvaluate["Python", {
3 | "import numpy as np",
4 | "np.load(\"s.dirichlet.npy\")"
5 | }]];
6 |
7 | (* Query grid points: uniformly spaced *)
8 | X = CoordinateBoundsArray[{{0, 255}, {0, 255}}] / 255;
9 |
10 | PDE = -Div[Grad[u[x, y], {x, y}], {x, y}] == s;
11 | BCs = u[0, y] == u[1, y] == u[x, 0] == u[x, 1] == 0;
12 |
13 | U = Table[
14 |
15 | {nx, ny, one} = Dimensions[S];
16 |
17 | cx = ChebyshevT[#, x] & /@ Range[0, nx-1] /. x -> 2x - 1;
18 | cy = ChebyshevT[#, y] & /@ Range[0, ny-1] /. y -> 2y - 1;
19 | s = Sum[S[[i, j, 1]] cx[[i]] cy[[j]], {i, nx}, {j, ny}];
20 |
21 | (* Solve PDE & measure cost *)
22 | { Cost, Evaluation } = Timing[
23 | Solution = NDSolveValue[{PDE, BCs}, u, {x, 0, 1}, {y, 0, 1}];
24 | ArrayReshape[Solution @@@ Flatten[X, 1], {256, 256, 1}]
25 | ];
26 |
27 | Print[Cost];
28 | Evaluation
29 |
30 | , { S, Normal[Source] }];
31 |
32 | (* Store answer to `.npy` *)
33 | ExternalEvaluate["Python", {
34 | "import numpy as np",
35 | "np.save(\"u.dirichlet.npy\", <*NumericArray[U, \"Real32\"]*>)",
36 | "print('DONE')"
37 | }];
38 |
--------------------------------------------------------------------------------
/src/pde/reaction/__init__.py:
--------------------------------------------------------------------------------
1 | from .. import *
2 | from ...dists import *
3 | from ...basis import *
4 |
5 | from .._domain import *
6 | from .._params import *
7 |
8 | from ...basis.fourier import *
9 | from ...basis.chebyshev import *
10 |
11 | class ReactionDiffusion(PDE):
12 |
13 | """
14 | ut - ν uxx = ρ u (1 - u)
15 | """
16 |
17 | rho: float # reaction coefficient
18 | nu: float # diffusion coefficient
19 |
20 | def __str__(self): return f"rho={self.rho}:nu={self.nu:.3f}"
21 | def __init__(self, res: int, rho: float, nu: float):
22 |
23 | self.odim = 1
24 |
25 | self.rho = rho
26 | self.nu = nu
27 |
28 | self.domain = Rect(2)
29 | self.basis = series(Chebyshev, Fourier)
30 |
31 | class Initial(Gaussian):
32 |
33 | def sample(self, prng, shape=()) -> X:
34 | x = super().sample(prng, shape)
35 |
36 | inf = np.min(x, axis=-1, keepdims=True)
37 | sup = np.max(x, axis=-1, keepdims=True)
38 |
39 | x, shape = (x - inf) / (sup - inf), shape + (2, res)
40 | return np.broadcast_to(x[..., np.newaxis, :], shape)
41 |
42 | initial = Initial(Fourier.grid(res), Gaussian.Per(0.2))
43 | self.params = Interpolate(initial, self.basis)
44 |
45 | from ..mollifier import initial_condition
46 | self.mollifier = initial_condition
47 |
48 | @F.cached_property
49 | def solution(self):
50 |
51 | dir = os.path.dirname(__file__)
52 |
53 | h = np.load(f"{dir}/h.npy")
54 | u = np.load(f"{dir}/u.{self}.npy")
55 |
56 | return jax.vmap(self.basis)(h), u.shape[1:-1], u
57 |
58 | def equation(self, x: X, h: X, u: X) -> X:
59 | u, u1, u2 = utils.fdm(u, n=2)
60 |
61 | ut = u1[..., 0]
62 | uxx = u2[..., 1, 1]
63 |
64 | reaction = -self.rho * u * (1 - u)
65 | diffusion = -self.nu * uxx
66 |
67 | return ut + reaction + diffusion
68 |
69 | def boundary(self, u: List[Tuple[X]]) -> List[X]:
70 |
71 | _, (ul, ur) = u
72 | return [ul - ur]
73 |
74 | def spectral(self, h: Basis, u: Basis) -> Basis:
75 |
76 | u1 = u.map(lambda coef: coef.at[(0, 0)].add(-1))
77 |
78 | reaction = u.__class__.mul(u, u1).map(lambda uu1: uu1 * self.rho)
79 | diffusion = u.grad(2).map(lambda u2: -u2[..., 1] * self.nu)
80 |
81 | ut = u.grad().map(lambda u1: u1[..., 0])
82 | return u.__class__.add(ut, reaction, diffusion)
83 |
84 | # --------------------------------- INSTANCE --------------------------------- #
85 |
86 | nu005 = ReactionDiffusion(64, rho=5, nu=0.005)
87 | nu01 = ReactionDiffusion(64, rho=5, nu=0.01)
88 | nu05 = ReactionDiffusion(64, rho=5, nu=0.05)
89 | nu1 = ReactionDiffusion(64, rho=5, nu=0.1)
90 |
--------------------------------------------------------------------------------
/src/pde/reaction/generate.py:
--------------------------------------------------------------------------------
1 | # Modified from characterizing-pinns-failure-modes. Commit 4390d09c507c117a37e621ab1b785a43f0c32f57
2 | # https://github.com/a1k12/characterizing-pinns-failure-modes/blob/main/pbc_examples/systems_pbc.py
3 |
4 | from . import *
5 |
6 | def reaction(u: X, rho: float, dt: float) -> X:
7 |
8 | """
9 | du/dt = rho*u*(1-u)
10 | """
11 |
12 | factor_1 = u * np.exp(rho * dt)
13 | factor_2 = 1 - u
14 |
15 | return factor_1 \
16 | / (factor_1 + factor_2)
17 |
18 | def diffusion(u: X, nu: float, dt: float, IKX2: X) -> X:
19 |
20 | """
21 | du/dt = nu*d2u/dx2
22 | """
23 |
24 | factor = np.exp(nu * IKX2 * dt)
25 | u_hat = np.fft.fft(u) * factor
26 |
27 | return np.fft.ifft(u_hat).real
28 |
29 | def solution(h: Fx, nu: float, rho: float, nx=4096, nt=4097) -> Tuple[X, X]:
30 |
31 | """
32 | Computes the discrete solution of the reaction-diffusion PDE using
33 | pseudo-spectral operator splitting.
34 |
35 | Args:
36 | h: initial condition, a callable x -> u(0, x)
37 | nu: diffusion coefficient
38 | rho: reaction coefficient
39 | nx: number of points in the x grid
40 | nt: number of points in the t grid
41 |
42 | Returns:
43 | x: (t, x) grid of shape (nt, nx, 2)
44 | u: solution of shape (nt, nx)
45 | """
46 |
47 | L = 1
48 | T = 1
49 | dx = L/nx
50 | dt = T/(nt-1)
51 | x = np.arange(0, L, dx) # not inclusive of the last point
52 | t = np.linspace(0, T, nt) # inclusive of the end time 1
53 | # u is collected below as a list of nt time slices of shape (nx,)
54 |
55 | IKX_pos = 2j * π * np.arange(0, nx/2+1, 1)
56 | IKX_neg = 2j * π * np.arange(-nx/2+1, 0, 1)
57 | IKX = np.concatenate((IKX_pos, IKX_neg))
58 | IKX2 = IKX * IKX
59 |
60 | u = [_u:=jax.vmap(h)(x)]
61 |
62 | for _ in range(nt-1):
63 |
64 | _u = reaction(_u, rho, dt)
65 | _u = diffusion(_u, nu, dt, IKX2)
66 |
67 | u.append(_u)
68 |
69 | return np.dstack(np.meshgrid(t, x, indexing="ij")), np.stack(u)
70 |
71 | # ---------------------------------------------------------------------------- #
72 | # GENERATE #
73 | # ---------------------------------------------------------------------------- #
74 |
75 | def generate(pde: ReactionDiffusion, N: int = 128):
76 |
77 | params = pde.params.sample(random.PRNGKey(0), (N, ))
78 | solve = F.partial(solution, nu=pde.nu, rho=pde.rho)
79 |
80 | u = jax.lax.map(lambda h: solve(lambda x: h(np.array([0, x])).squeeze()), params)[1]
81 | u = np.pad(u[:, ::8, ::8], [(0, 0), (0, 0), (0, 1)], mode="wrap")[..., np.newaxis]
82 |
83 | dir = os.path.dirname(__file__)
84 |
85 | np.save(f"{dir}/h.npy", params.coef)
86 | np.save(f"{dir}/u.{pde}.npy", u)
87 |
88 | return u
89 |
90 | if __name__ == "__main__":
91 |
92 | for nu in [0.005, 0.01, 0.05, 0.1]:
93 |
94 | generate(ReactionDiffusion(64, 5, nu))
95 |
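`solution` advances the PDE by Lie (sequential) operator splitting: each step applies the exact logistic flow of `du/dt = ρu(1 - u)` and then the exact Fourier-space flow of `du/dt = ν u_xx`. A NumPy sketch of the same scheme follows; the function names are hypothetical, and `np.fft.fftfreq` stands in for the hand-built `IKX` array (the two agree up to the sign of the Nyquist mode, which cancels after squaring). Two consequences are easy to check: the reaction step is an exact flow, so two steps of `dt` equal one step of `2 dt`, and a spatially constant state is untouched by diffusion, so it follows the closed-form logistic curve.

```python
import numpy as np


def reaction_step(u, rho, dt):
    """Exact flow of du/dt = rho * u * (1 - u) over a time dt."""
    grown = u * np.exp(rho * dt)
    return grown / (grown + 1 - u)


def diffusion_step(u, nu, dt):
    """Exact flow of du/dt = nu * u_xx on the periodic interval [0, 1)."""
    n = u.shape[-1]
    k = 2 * np.pi * np.fft.fftfreq(n, d=1 / n)  # angular wavenumbers
    return np.fft.ifft(np.fft.fft(u) * np.exp(-nu * k**2 * dt)).real


def split_solve(u0, rho, nu, nt):
    """Lie splitting on t in [0, 1] with nt time points (end time inclusive)."""
    dt = 1.0 / (nt - 1)
    u = [u0]
    for _ in range(nt - 1):
        u.append(diffusion_step(reaction_step(u[-1], rho, dt), nu, dt))
    return np.stack(u)  # shape (nt, nx)
```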
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | from . import *
2 | from .pde import *
3 | from .model import *
4 | from .basis import *
5 |
6 | # ---------------------------------------------------------------------------- #
7 | # STEP #
8 | # ---------------------------------------------------------------------------- #
9 |
10 | def step(self: Trainer, variable: ϴ) -> Tuple[ϴ, Dict[str, X]]:
11 |
12 | params, prng = variable.get("params"), self.make_rng("sample")
13 | ϕ = self.pde.params.sample(prng, (self.cfg["bs"], ))
14 |
15 | @F.partial(jax.grad, has_aux=True)
16 | def loss(params: ϴ, ϕ: X) -> Tuple[X, Dict[str, X]]:
17 |
18 | loss = self.mod.apply(variable.copy(dict(params=params)), ϕ, method="loss")
19 | return jax.tree_util.tree_reduce(O.add, loss), loss
20 |
21 | grads, loss = jax.tree_map(F.partial(np.mean, axis=0),
22 | utils.cmap(F.partial(loss, params), self.cfg["vmap"])(ϕ))
23 |
24 | if self.cfg["clip"] is not None:
25 |
26 | loss["norm"] = np.sqrt(sum(np.sum(np.square(grad))
27 | for grad in jax.tree_util.tree_leaves(grads)))
28 |
29 | grads = jax.tree_map(F.partial(jax.lax.cond, loss["norm"] < self.cfg["clip"],
30 | lambda grad: grad, lambda grad: grad / loss["norm"] * self.cfg["clip"]), grads)
31 |
32 | updates, state = self.optimizer.update(grads, self.get_variable("optim", "state"), params)
33 |
34 | self.put_variable("optim", "state", state)
35 | return variable.copy(dict(params=optax.apply_updates(params, updates))), loss
36 |
37 | # ---------------------------------------------------------------------------- #
38 | # EVAL #
39 | # ---------------------------------------------------------------------------- #
40 |
41 | def eval(self: Trainer, variable: ϴ) -> Tuple[Dict, X]:
42 |
43 | if isinstance(self.pde.solution, Tuple):
44 | ϕ, s, u = self.pde.solution
45 |
46 | v = utils.cmap(F.partial(self.mod.apply, variable, x=s, method="u"), self.cfg["vmap"])(ϕ)
47 | with jax.default_device(cpu:=jax.devices("cpu")[0]):
48 |
49 | u, v = jax.device_put((u, v), device=cpu)
50 | return jax.tree_map(np.mean, jax.vmap(metric(self.pde, s))(ϕ, u, v)), (u, v)
51 |
52 | def metric(pde: PDE, s: Tuple[int]) -> Fx:
53 |
54 | def call(ϕ: Basis, u: X, v: X) -> Dict[str, X]:
55 |
56 | return dict(
57 | erra=np.mean(np.abs(np.ravel(u - v))),
58 | errr=np.linalg.norm(np.ravel(u - v)) / np.linalg.norm(np.ravel(u)),
59 | residual=np.mean(np.abs(pde.equation(utils.grid(*s), ϕ[s], v))),
60 | )
61 |
62 | return call
63 |
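`step` clips gradients by their global L2 norm: if the norm over all parameter leaves reaches `cfg["clip"]`, every leaf is rescaled by `clip / norm`; otherwise the gradients pass through unchanged (the `jax.lax.cond` mapped over the tree). A dependency-free sketch of that rule on a flat dict of arrays, with the hypothetical name `clip_by_global_norm`; optax also provides `optax.clip_by_global_norm`, which applies the same rescaling as a gradient transformation.

```python
import numpy as np


def clip_by_global_norm(grads, clip):
    """Rescale a dict of gradient arrays so their joint L2 norm is at most `clip`.

    Returns the (possibly rescaled) gradients and the pre-clipping norm, which
    `step` logs as loss["norm"].
    """
    norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads.values()))
    if norm < clip:
        return grads, norm
    return {name: g / norm * clip for name, g in grads.items()}, norm
```

Rescaling all leaves by one shared factor preserves the update direction, unlike per-leaf or per-element clipping.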
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
3 | def grid(*s: int, mode: str = None, flatten: bool = False) -> X:
4 |
5 | """
6 | Return a grid on [0, 1]^n, excluding the endpoint unless `mode` is None.
7 | If not flatten, shape=(*s, len(s)); else shape=(∏s, len(s)).
8 |
9 | Mode:
10 | - `None`: uniformly spaced
11 | - "left": exclude endpoint
12 | - "cell": centers of rects
13 | """
14 |
15 | axes = F.partial(np.linspace, 0, 1, endpoint=mode is None)
16 | grid = np.stack(np.meshgrid(*map(axes, s), indexing="ij"), -1)
17 |
18 | if mode == "cell": grid += .5 / np.array(s)
19 | if flatten: return grid.reshape(-1, len(s))
20 |
21 | return grid
22 |
23 | def nmap(f: Fx, n: int = 1, **kwargs) -> Fx:
24 |
25 | """
26 | Nested vmap: returns `f` vectorized over its `n` leading dimensions, with
27 | the semantics of `jax.vmap` (`kwargs` apply to the outermost vmap only).
28 | """
29 |
30 | if not n: return f
31 |
32 | if n > 1: f = nmap(f, n - 1)
33 | return jax.vmap(f, **kwargs)
34 |
35 | def cmap(f: Fx, n: int = None, **kwargs) -> Fx:
36 |
37 | """
38 | Chunked vmap. Same semantics as `jax.vmap`, but vectorizes over only `n`
39 | items at a time and loops over the chunks with `jax.lax.map`; the leading dimension must be a multiple of `n`.
40 | """
41 |
42 | f = jax.vmap(f, **kwargs)
43 | def call(*args, **kwargs):
44 | into = lambda x: x.reshape(-1, n, *x.shape[1:]) # (batch, ...) -> (batch // n, n, ...)
45 | out = jax.lax.map(lambda chunk: f(*chunk[0], **chunk[1]), jax.tree_map(into, (args, kwargs)))
46 | return jax.tree_map(np.concatenate, out) # merge the chunk axis back
47 |
48 | if n is None: return f
49 |
50 | return call
51 |
52 | def jit(f: Fx, **options) -> Fx:
53 |
54 | """
55 | JIT-compile `f` and print its XLA cost analysis on the first call. Note that
56 | loops are not accounted for correctly, so with cfg.vmap set (which chunks
57 | the batch with `jax.lax.map`) the reported numbers are unreliable.
58 | """
59 |
60 | f = jax.jit(f, **options)
61 | def call(*args, **kwargs) -> Any:
62 |
63 | nonlocal f
64 |
65 | if not isinstance(f, jax.stages.Compiled):
66 |
67 | print("=" * 116)
68 |
69 | print(f"compiling function {f} ......")
70 | f = f.lower(*args, **kwargs).compile()
71 |
72 | cost, = f.cost_analysis()
73 | print("flop:", cost["flops"])
74 | print("memory:", cost["bytes accessed"])
75 |
76 | print("=" * 116)
77 |
78 | return f(*args, **kwargs)
79 |
80 | return call
81 |
82 | def timeit(f: Fx, **options) -> Fx:
83 | fjit = jit(f, **options)
84 |
85 | def call(*args, **kwargs) -> float:
86 | jax.tree_map(jax.block_until_ready, fjit(*args, **kwargs)) # compile & warm up
87 |
88 | import time
89 | start = time.perf_counter()
90 |
91 | jax.tree_map(jax.block_until_ready,
92 | fjit(*args, **kwargs))
93 |
94 | return time.perf_counter() - start
95 |
96 | return call
97 |
98 | # ---------------------------------------------------------------------------- #
99 | # DERIVATIVE #
100 | # ---------------------------------------------------------------------------- #
101 |
102 | def grad(f: Fx, n: int, D: Fx = jax.jacfwd) -> Fx:
103 |
104 | """
105 | Calculate derivatives up to order `n`. Returns a function mapping `x` to a
106 | List[Array] of length `n + 1`, whose `k`th element is the `k`th derivative (k = 0 gives f(x)) of shape `(..., d^k)`.
107 |
108 | Differential scheme `D :: (X -> X) -> (X -> X)` determines how to obtain
109 | the Jacobian. Defaults to JAX's forward-mode autodiff.
110 | """
111 |
112 | u = [f]
113 |
114 | for _ in range(n): u.append(f:=D(f))
115 | return lambda x: [fk(x) for fk in u]
116 |
117 | def fdm(x: X, n: int) -> List[X]:
118 |
119 | """
120 | Approximate the derivatives returned by `grad` using finite differences
121 | (`np.gradient`). `x` is assumed to be evaluated on uniform grids on [0, 1]^d
122 | (end-points included), where `d = x.ndim - 1`, i.e. `x` has a trailing channel dim.
123 | """
124 |
125 | u = [x]
126 | d = x.ndim - 1
127 | s = x.shape[:-1]
128 |
129 | for _ in range(n):
130 |
131 | grad = map(lambda i: np.gradient(x, axis=i), range(d))
132 | u.append(x:=np.stack(tuple(grad), axis=-1) * (np.array(s)-1))
133 |
134 | return u
135 |
136 | # ---------------------------------------------------------------------------- #
137 | # DEBUG #
138 | # ---------------------------------------------------------------------------- #
139 |
140 | def repl(local):
141 |
142 | # ---------------------------------- IMPORT ---------------------------------- #
143 |
144 | import matplotlib.pyplot as plt
145 | import matplotlib.colors as clr
146 |
147 | import scienceplots
148 | plt.style.use(["science",
149 | "no-latex"])
150 |
151 | from src.basis import Basis, series
152 | from src.basis.fourier import Fourier
153 | from src.basis.chebyshev import Chebyshev
154 |
155 | # ---------------------------------- HELPER ---------------------------------- #
156 |
157 | def save(fig=plt): fig.savefig("test.jpg", dpi=300); fig.clf()
158 | def show(img, **kw): plt.colorbar(plt.imshow(img, **kw))
159 | def gif(*imgs, fps: int = 50, **kw):
160 | fig, ax = plt.subplots()
161 |
162 | vmin = kw.pop("vmin", min(map(np.min, imgs)))
163 | vmax = kw.pop("vmax", max(map(np.max, imgs)))
164 | im = ax.imshow(imgs[0], vmin=vmin, vmax=vmax, **kw)
165 | id = ax.text(0.98, 0.02, "", transform=ax.transAxes, ha="right", va="bottom")
166 |
167 | plt.colorbar(im)
168 | def frame(index):
169 |
170 | i = len(imgs) * index//fps
171 | id.set_text(f"#{i:03}")
172 | im.set_array(imgs[i])
173 | return im, id
174 |
175 | from matplotlib import animation
176 | ani = animation.FuncAnimation(fig, frame, fps, blit=True)
177 | ani.save("test.gif", writer="pillow", fps=fps, dpi=300)
178 |
179 | import code; code.interact(local=dict(globals(), **dict(locals(), **local)))
180 |
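`fdm` differentiates with `np.gradient` along each of the `d` grid axes and multiplies by `s - 1`, i.e. divides by the spacing `h = 1 / (s - 1)` of a grid that includes both endpoints. `np.gradient` uses second-order central differences in the interior and, by default, first-order one-sided differences at the edges, so a quadratic is differentiated exactly away from the boundary. A NumPy transcription for checking shapes and scaling (a sketch, not the JAX original):

```python
import numpy as np


def fdm(x, n):
    """NumPy transcription of utils.fdm: [x, Dx, D^2x, ...] of length n + 1."""
    u = [x]
    d = x.ndim - 1                 # spatial dims; the last axis is the channel
    s = np.array(x.shape[:-1])     # grid points per spatial axis
    for _ in range(n):
        grads = [np.gradient(x, axis=i) for i in range(d)]
        x = np.stack(grads, axis=-1) * (s - 1)  # divide by h = 1 / (s - 1)
        u.append(x)
    return u


# usage: f(x, y) = x^2 + 3y on a 101 x 101 grid over [0, 1]^2 with one channel
g = np.linspace(0, 1, 101)
X, Y = np.meshgrid(g, g, indexing="ij")
f, f1, f2 = fdm((X**2 + 3 * Y)[..., np.newaxis], n=2)
```

Each order appends one trailing axis of size `d`, so `f2[..., c, i, j]` is the derivative of channel `c` along axis `i` and then axis `j`.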
--------------------------------------------------------------------------------