├── figures ├── advection.png ├── burgers.png ├── reaction.png ├── allen_cahn.png └── reaction_diffusion.png ├── pinns ├── __pycache__ │ ├── nn.cpython-311.pyc │ ├── ivps.cpython-311.pyc │ └── training.cpython-311.pyc ├── nn.py └── ivps.py ├── requirements.txt ├── spinn.py ├── pinn.py ├── README.md └── .gitignore /figures/advection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaeminoh/Siren_PINNs/HEAD/figures/advection.png -------------------------------------------------------------------------------- /figures/burgers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaeminoh/Siren_PINNs/HEAD/figures/burgers.png -------------------------------------------------------------------------------- /figures/reaction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaeminoh/Siren_PINNs/HEAD/figures/reaction.png -------------------------------------------------------------------------------- /figures/allen_cahn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaeminoh/Siren_PINNs/HEAD/figures/allen_cahn.png -------------------------------------------------------------------------------- /figures/reaction_diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaeminoh/Siren_PINNs/HEAD/figures/reaction_diffusion.png -------------------------------------------------------------------------------- /pinns/__pycache__/nn.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaeminoh/Siren_PINNs/HEAD/pinns/__pycache__/nn.cpython-311.pyc -------------------------------------------------------------------------------- /pinns/__pycache__/ivps.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaeminoh/Siren_PINNs/HEAD/pinns/__pycache__/ivps.cpython-311.pyc -------------------------------------------------------------------------------- /pinns/__pycache__/training.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaeminoh/Siren_PINNs/HEAD/pinns/__pycache__/training.cpython-311.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | chex==0.1.88 3 | contourpy==1.3.1 4 | cycler==0.12.1 5 | etils==1.11.0 6 | fonttools==4.55.5 7 | jax==0.5.0 8 | jax-cuda12-pjrt==0.5.0 9 | jax-cuda12-plugin==0.5.0 10 | jaxlib==0.5.0 11 | jaxopt==0.8.3 12 | kiwisolver==1.4.8 13 | matplotlib==3.10.0 14 | ml-dtypes==0.5.1 15 | numpy==2.2.2 16 | nvidia-cublas-cu12==12.8.3.14 17 | nvidia-cuda-cupti-cu12==12.8.57 18 | nvidia-cuda-nvcc-cu12==12.8.61 19 | nvidia-cuda-runtime-cu12==12.8.57 20 | nvidia-cudnn-cu12==9.6.0.74 21 | nvidia-cufft-cu12==11.3.3.41 22 | nvidia-cusolver-cu12==11.7.2.55 23 | nvidia-cusparse-cu12==12.5.7.53 24 | nvidia-nccl-cu12==2.24.3 25 | nvidia-nvjitlink-cu12==12.8.61 26 | opt-einsum==3.4.0 27 | optax==0.2.4 28 | packaging==24.2 29 | pillow==11.1.0 30 | pyparsing==3.2.1 31 | python-dateutil==2.9.0.post0 32 | scipy==1.15.1 33 | setuptools==75.8.0 34 | six==1.17.0 35 | toolz==1.0.0 36 | tqdm==4.67.1 37 | typing-extensions==4.12.2 38 | -------------------------------------------------------------------------------- /spinn.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.random as jr 3 | import jaxopt 4 | from optax import adam, cosine_decay_schedule 5 | 6 | from pinns.ivps import advection 7 | from pinns.nn import Siren 8 | 9 | 10 | class SeparablePINN(advection): 11 | def __init__(self, width=64, depth=4, d_out=64, w0=8.0): 12 | super().__init__() 13 | layers = [1] + [width for _ in range(depth - 1)] + [d_out] 14 | self.init, self.apply = Siren(layers, w0) 15 | 16 | def u(self, params, t, x): # (Nt, Nx) 17 | t, x = t.reshape(-1, 1), x.reshape(-1, 1) 18 | outputs = self.apply(params[0], t) @ self.apply(params[1], x).T 19 | return outputs 20 | 21 | 22 | spinn = SeparablePINN() 23 | *init_keys, train_key = jr.split(jr.key(0), 3) 24 | init_params = [spinn.init(_key) for _key in init_keys] 25 | 26 | nIter = 1 * 10**5 27 | lr = cosine_decay_schedule(1e-03, nIter) 28 | 29 | 30 | Nt, Nx = 128, 128 31 | domain_tr = ( 32 | spinn.T * jnp.linspace(0, 1, Nt), 33 | spinn.X * jnp.linspace(*spinn.x_bd, Nx), 34 | ) 35 | 36 | optimizer = jaxopt.OptaxSolver(fun=spinn.loss, opt=adam(lr),) 37 | 38 | spinn.train(optimizer, domain_tr, train_key, init_params, nIter=nIter) 39 | spinn.drawing(save=True) 40 | -------------------------------------------------------------------------------- /pinn.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import jax.random as jr 4 | import jaxopt 5 | from optax import adam, cosine_decay_schedule 6 | 7 | from pinns.ivps import advection 8 | from pinns.nn import Siren 9 | 10 | 11 | class PINN(advection): 12 | def __init__(self, width=64, depth=5, w0=8.0): 13 | super().__init__() 14 | layers = [2] + [width for _ in range(depth - 1)] + [1] 15 | self.init, self.apply = Siren(layers, w0) 16 | # (Nt, Nx) 17 | self.u = jax.vmap(jax.vmap(self._u, (None, 0, None), 0), (None, None, 0), 1) 18 | 19 | def _u(self, params, t, x): # scalar function 20 | inputs = jnp.hstack([t, x]) 21 | output = self.apply(params, inputs).squeeze() 22 | return output 23 | 24 | 25 | pinn = PINN() 26 | init_key, train_key = jr.split(jr.key(0)) 27 | init_params = pinn.init(init_key) 28 | 29 | nIter = 1 * 10**5 30 | lr = cosine_decay_schedule(1e-03, nIter) 31 | optimizer = jaxopt.OptaxSolver(fun=pinn.loss, opt=adam(lr)) 32 | 33 | Nt, Nx = 128, 128 34 | domain_tr = ( 35 | pinn.T * jnp.linspace(0, 1, Nt), 36 | pinn.X * jnp.linspace(*pinn.x_bd, Nx), 37 | ) 38 | 39 | 40 | pinn.train(optimizer, domain_tr, train_key, init_params, nIter=nIter) 41 | pinn.drawing(save=True) 42 | -------------------------------------------------------------------------------- /pinns/nn.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.random as jr 3 | 4 | 5 | def Siren(layers, w0=4.0): 6 | """ 7 | Multi-layer Perceptron with sine activation function, known as 'SIREN'. 8 | It has the same architecture as MLP, but different initialization is employed. 9 | w0 controls the frequency of the output. 10 | """ 11 | 12 | def siren_init(key, d_in, d_out, is_first=False): 13 | if is_first: 14 | variance = 1 / d_in 15 | W = jr.uniform(key, (d_in, d_out), minval=-variance, maxval=variance) 16 | std = jnp.sqrt(1 / d_in) 17 | b = jr.uniform(key, (d_out,), minval=-std, maxval=std) 18 | return W, b 19 | else: 20 | variance = jnp.sqrt(6 / d_in) / w0 21 | W = jr.uniform(key, (d_in, d_out), minval=-variance, maxval=variance) 22 | std = jnp.sqrt(1 / d_in) 23 | b = jr.uniform(key, (d_out,), minval=-std, maxval=std) 24 | return W, b 25 | 26 | def init(rng_key): 27 | _, *keys = jr.split(rng_key, len(layers)) 28 | params = [siren_init(keys[0], layers[0], layers[1], is_first=True)] + list( 29 | map(siren_init, keys[1:], layers[1:-1], layers[2:]) 30 | ) 31 | return params 32 | 33 | def activation(x): 34 | return jnp.sin(w0 * x) 35 | 36 | def apply(params, inputs): 37 | # Forward Pass 38 | for W, b in params[:-1]: 39 | outputs = jnp.dot(inputs, W) + b 40 | inputs = activation(outputs) 41 | # Final inner product 42 | W, b = params[-1] 43 | outputs = jnp.dot(inputs, W) + b 44 | return outputs 45 | 46 | return init, apply 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Resolving failure modes of PINNs with osciliatory activation functions. 2 | 3 | [![DOI](https://zenodo.org/badge/664425702.svg)](https://doi.org/10.5281/zenodo.15207858) 4 | 5 | It is widely acknowledged that standard Multilayer Perceptrons (MLPs) have inherent limitations in effectively learning high-frequency signals. 6 | Consequently, Partial Differential Equations (PDEs) with periodic, sharp, and highly variable solutions pose a significant challenge when trained using [Physics-Informed Neural Networks (PINNs)](https://doi.org/10.1016/j.jcp.2018.10.045). 7 | 8 | To address this issue, [Krishnapriyan et al.](https://arxiv.org/abs/2109.01050) propose a "curriculum" learning approach, starting with easily learnable parameters in PINNs and gradually increasing the complexity towards more challenging cases. 9 | By initializing network parameters from the previous step, these authors have successfully tackled "convection, reaction, and reaction-diffusion" equations, which were previously difficult to handle using standard MLPs. 10 | 11 | In this repository, we take a different approach, focusing on the application of "[Siren](https://arxiv.org/abs/2006.09661)" - a widely recognized architecture in the realm of Implicit Neural Representations (INRs). 12 | Siren employs a sine activation function and a corresponding initialization scheme, enabling efficient learning of high-frequency signals with MLPs. 13 | By leveraging Siren, we aim to overcome the challenges posed by complex PDEs without resorting to the aforementioned curriculum-based methods. 14 | 15 | Further speedups can be possible with "[Separable PINN](https://arxiv.org/abs/2211.08761)". 16 | 17 | Burgers, advection, reaction, reaction-diffusion, Allen-Cahn equations are implemented in `pinns/ivps.py`. 18 | Simple example code is available in `pinn.py` and `spinn.py`. 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Scheduler related 2 | *slurm* 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ -------------------------------------------------------------------------------- /pinns/ivps.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | import matplotlib.pyplot as plt 7 | from jax.experimental.jet import jet 8 | from tqdm import trange 9 | 10 | 11 | class ivps: 12 | name = "ivps" 13 | T = 1.0 14 | 15 | # solution, initial condition and pde must be overrided 16 | def u(self): 17 | raise NotImplementedError 18 | 19 | def u0(self): 20 | raise NotImplementedError 21 | 22 | def pde(self): 23 | raise NotImplementedError 24 | 25 | def loss_ic(self, params, x): 26 | t = jnp.zeros((1,)) 27 | init_data = self.u0(x) 28 | init_pred = self.u(params, t, x)[0, ...] 29 | loss_ic = jnp.mean((init_pred - init_data) ** 2) 30 | return loss_ic 31 | 32 | # Default: periodic on x 33 | def loss_bc(self, params, t): 34 | x = self.X * self.x_bd 35 | u = self.u(params, t, x) 36 | loss_bc = jnp.mean((u[..., -1] - u[..., 0]) ** 2) 37 | return loss_bc 38 | 39 | @partial(jax.jit, static_argnums=(0,)) 40 | def loss(self, params, t, x): 41 | loss = ( 42 | self.pde(params, t, x).mean() 43 | + 1e3 * self.loss_ic(params, x) 44 | + self.loss_bc(params, t) 45 | ) 46 | return loss 47 | 48 | def train(self, optimizer, domain_tr, key, params, nIter=5 * 10**4): 49 | print(self.equation) 50 | T, X = self.T, self.X 51 | x_L, x_R = self.x_bd 52 | domain = [*domain_tr] 53 | Nt, Nx = domain[0].size, domain[1].size 54 | state = optimizer.init_state(params, *domain_tr) 55 | loss_log = [] 56 | 57 | @jax.jit 58 | def step(params, state, *args, **kwargs): 59 | params, state = optimizer.update(params, state, *args, **kwargs) 60 | return params, state 61 | 62 | for it in (pbar := trange(1, nIter + 1)): 63 | params, state = step(params, state, *domain) 64 | if it % 100 == 0: 65 | loss = state.value 66 | loss_log.append(loss) 67 | # domain sampling 68 | key, *subkey = jr.split(key, 3) 69 | domain[0] = T * jr.uniform(subkey[0], (Nt,)) 70 | domain[1] = X * jr.uniform(subkey[1], (Nx,), minval=x_L, maxval=x_R) 71 | pbar.set_postfix({"pinn loss": f"{loss:.3e}"}) 72 | 73 | self.opt_params, self.loss_log = params, loss_log 74 | 75 | def drawing(self, save=True): 76 | print("Drawing...") 77 | dir = f"figures/{self.name}" 78 | fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 5)) 79 | # loss log 80 | ax1.semilogy(self.loss_log, label="PINN Loss") 81 | ax1.set_xlabel("100 iterations") 82 | ax1.set_ylabel("Mean Squared Error") 83 | # Solution profile 84 | opt_params = self.opt_params 85 | domain = ( 86 | self.T * jnp.linspace(0, 1, 200), 87 | self.X * jnp.linspace(*self.x_bd, 200), 88 | ) 89 | pred = self.u(opt_params, *domain) 90 | im = ax2.imshow(pred.T, origin="lower", cmap="jet", aspect="auto") 91 | ax2.axis("off") 92 | fig.colorbar(im) 93 | if save: 94 | fig.savefig(dir) 95 | else: 96 | fig.show() 97 | print("Done!") 98 | 99 | 100 | class burgers(ivps): 101 | name = "burgers" 102 | equation = "u_t + uu_x = νu_xx/𝝅" 103 | X = 1.0 104 | x_bd = jnp.array([-1, 1]) 105 | 106 | def __init__(self, nu=1e-02): 107 | self.nu = nu 108 | 109 | def u0(self, x): 110 | return -jnp.sin(jnp.pi * x) 111 | 112 | def pde(self, params, t, x): 113 | u_t = jax.jvp(lambda t: self.u(params, t, x), (t,), (jnp.ones(t.shape),))[1] 114 | u, (u_x, u_xx) = jet( 115 | lambda x: self.u(params, t, x), 116 | (x,), 117 | ((jnp.ones(x.shape), jnp.zeros(x.shape)),), 118 | ) 119 | pde = (u_t + u * u_x - self.nu * u_xx / jnp.pi) ** 2 120 | return pde 121 | 122 | def loss_bc(self, params, t): 123 | x = self.X * self.x_bd 124 | # Dirichelt on x 125 | u = self.u(params, t, x) 126 | loss_bc = jnp.mean(u**2) 127 | return loss_bc 128 | 129 | 130 | class advection(ivps): 131 | name = "advection" 132 | equation = "u_t + βu_x = 0" 133 | T = 1.0 134 | X = 2 * jnp.pi 135 | x_bd = jnp.array([0, 1]) 136 | 137 | def __init__(self, beta=30.0): 138 | self.beta = beta 139 | 140 | def u0(self, x): 141 | return jnp.sin(x) 142 | 143 | def pde(self, params, t, x): 144 | _, u_t = jax.jvp(lambda t: self.u(params, t, x), (t,), (jnp.ones(t.shape),)) 145 | _, u_x = jax.jvp(lambda x: self.u(params, t, x), (x,), (jnp.ones(x.shape),)) 146 | pde = (u_t + self.beta * u_x) ** 2 147 | return pde 148 | 149 | 150 | class reaction(ivps): 151 | name = "reaction" 152 | equation = "u_t = ρu(1-u)" 153 | X = 2 * jnp.pi 154 | x_bd = jnp.array([0, 1]) 155 | 156 | def __init__(self, rho=5.0): 157 | self.rho = rho 158 | 159 | def u0(self, x): 160 | exponent = 4 * (x - jnp.pi) / jnp.pi 161 | return jnp.exp(-0.5 * (exponent**2)) 162 | 163 | def pde(self, params, t, x): 164 | u, u_t = jax.jvp(lambda t: self.u(params, t, x), (t,), (jnp.ones(t.shape),)) 165 | pde = (u_t - self.rho * u * (1 - u)) ** 2 166 | return pde 167 | 168 | 169 | class reaction_diffusion(ivps): 170 | name = "reaction_diffusion" 171 | equation = "u_t = νu_xx + ρu(1-u)" 172 | X = 2 * jnp.pi 173 | x_bd = jnp.array([0, 1]) 174 | 175 | def __init__(self, nu=5.0, rho=5.0): 176 | self.nu, self.rho = nu, rho 177 | 178 | def u0(self, x): 179 | exponent = 4 * (x - jnp.pi) / jnp.pi 180 | return jnp.exp(-0.5 * (exponent**2)) 181 | 182 | def pde(self, params, t, x): 183 | _, u_t = jax.jvp(lambda t: self.u(params, t, x), (t,), (jnp.ones(t.shape),)) 184 | u, (_, u_xx) = jet( 185 | lambda x: self.u(params, t, x), 186 | (x,), 187 | ((jnp.ones(x.shape), jnp.zeros(x.shape)),), 188 | ) 189 | pde = (u_t - self.nu * u_xx - self.rho * u * (1 - u)) ** 2 190 | return pde 191 | 192 | 193 | class allen_cahn(ivps): 194 | name = "allen_cahn" 195 | equation = "u_t = νu_xx + ρu(1-u^2)" 196 | X = 1.0 197 | x_bd = jnp.array([-1, 1]) 198 | 199 | def __init__(self, nu=1e-04, rho=5.0): 200 | self.nu, self.rho = nu, rho 201 | 202 | def u0(self, x): 203 | return x**2 * jnp.cos(jnp.pi * x) 204 | 205 | def pde(self, params, t, x): 206 | _, u_t = jax.jvp(lambda t: self.u(params, t, x), (t,), (jnp.ones(t.shape),)) 207 | u, (_, u_xx) = jet( 208 | lambda x: self.u(params, t, x), 209 | (x,), 210 | ((jnp.ones(x.shape), jnp.zeros(x.shape)),), 211 | ) 212 | pde = (u_t - self.nu * u_xx + self.rho * u**3 - self.rho * u) ** 2 213 | return pde 214 | --------------------------------------------------------------------------------