├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── pit.py ├── supplementary_data ├── data_burgers.mat ├── data_sod.mat └── shape_coords.npy ├── tensorflow ├── 1_InviscidBurgers │ ├── train.py │ └── utils.py ├── 2_ShockTube │ ├── train.py │ └── utils.py ├── 3_Darcy2D │ ├── evaluate.py │ ├── train.py │ └── utils.py ├── 4_Vorticity │ ├── evaluate.py │ ├── train.py │ └── utils.py ├── 5_Elasticity │ ├── evaluate.py │ ├── train.py │ └── utils.py ├── 6_NACA │ ├── evaluate.py │ ├── train.py │ └── utils.py ├── LICENSE ├── README.md ├── figures │ ├── Darcy2D.png │ ├── Elasticity.png │ ├── InviscidBurgers.png │ ├── NACA.png │ ├── ShockTube.png │ ├── err_t20.png │ ├── pred_t20.png │ └── true_t20.png └── requirements.txt ├── train_burgers.py ├── train_cylinder.py ├── train_darcy.py ├── train_elasticity.py ├── train_naca.py ├── train_sod.py ├── train_vorticity.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.mat filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.npy 2 | *.pdf 3 | *.pth 4 | *.mat 5 | *.log 6 | *.csv 7 | __pycache__ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 junfeng-chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The 
above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Position-induced Transformer 2 | 3 | The code in this repository presents the implementation of **Position-induced Transformer (PiT)** and the numerical experiments of using PiT for learning operators in partial differential equations. It is built upon the position-attention mechanism, proposed in the paper *Positional Knowledge is All You Need: Position-induced Transformer (PiT) for Operator Learning*. The paper can be accessed here. 4 | 5 | ## Updates May 2025 6 | - PiT is now part of **DUE**, our open-source toolkit for data-driven equation modeling with modern deep learning methods. For installation and examples, visit the GitHub repository. 7 | - Added a numerical example showcasing PiT for learning the unsteady flow past a cylinder. 8 | ## Contents 9 | - `train_burgers`: One-dimensional inviscid Burgers' equation. 10 | - `train_sod`: One-dimensional compressible Euler equations. 11 | - `train_darcy`: Two-dimensional Darcy flow problem. 12 | - `train_vorticity`: Two-dimensional incompressible Navier–Stokes equations with periodic boundary conditions. 13 | - `train_elasticity`: Two-dimensional hyper-elastic problem. 
14 | - `train_naca`: Two-dimensional transonic flow over airfoils. 15 | - `train_cylinder`: Two-dimensional flow past a cylinder at Reynolds number 100. 16 | 17 | ## Datasets 18 | The raw data required to reproduce the main results can be obtained from some of the baseline methods selected in our paper. 19 | - For InviscidBurgers and ShockTube, datasets are provided by Lanthaler et al. They can be downloaded here. 20 | - For Darcy2D and Vorticity, datasets are provided by Li et al. They can be downloaded here. 21 | - For Elasticity and NACA, datasets are provided by Li et al. They can be downloaded here. 22 | - For Cylinder, the dataset is generated using FEniCS. It can be downloaded here. 23 | 24 | ## Requirements 25 | - This code is primarily based on PyTorch. We have observed significant improvements in PiT's training speed with PyTorch 2.x, especially when using `torch.compile`. Therefore, we highly recommend using PyTorch 2.x with `torch.compile` enabled for optimal performance. 26 | - If any issues arise with `torch.compile` that cannot be resolved, the code is also compatible with recent versions of PyTorch 1.x. In such cases, simply comment out the line `model = torch.compile(model)` in the scripts. 27 | - Matplotlib and SciPy are also required. 
"""Position-induced Transformer (PiT) building blocks (PyTorch).

Defines the position-attention layers (``posatt`` and its cross-attention,
fixed-mesh and periodic variants) and the encoder/processor/decoder
containers (``pit`` and its subclasses).  The containers define no
``forward`` here; ``encoder``, ``processor`` and ``decoder`` are exposed as
separate steps.

Reference (as cited in the repository README): J. Chen and K. Wu,
"Positional Knowledge is All You Need: Position-induced Transformer (PiT)
for Operator Learning", International Conference on Machine Learning
(PMLR), pp. 7526--7552, 2024.
"""
# NOTE: importing this module has global side effects (matmul precision,
# RNG seeding, cuDNN flags).  They are kept exactly as in the original file.
import torch
torch.set_float32_matmul_precision('high')
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
import torch.nn as nn
from torch.nn.functional import gelu
import numpy as np
np.random.seed(0)
from math import pi


class kaiming_mlp(nn.Module):
    """Two linear layers with a GELU in between; weights use Kaiming-normal init."""

    def __init__(self, n_filters0, n_filters1, n_filters2):
        """
        n_filters0: input width.
        n_filters1: hidden width.
        n_filters2: output width.
        """
        super().__init__()
        self.mlp1 = torch.nn.Linear(n_filters0, n_filters1)
        self.mlp2 = torch.nn.Linear(n_filters1, n_filters2)
        nn.init.kaiming_normal_(self.mlp1.weight)
        nn.init.kaiming_normal_(self.mlp2.weight)

    def forward(self, x):
        """Apply linear -> GELU -> linear along the last axis of ``x``."""
        return self.mlp2(gelu(self.mlp1(x)))


class posatt(nn.Module):
    """Multi-head position-attention on per-sample (batched) meshes.

    Attention weights depend only on the squared Euclidean distances between
    mesh points, scaled per head by a bandwidth derived from the learnable
    parameter ``lmda``.
    """

    def __init__(self, n_head, in_dim, locality):
        """
        n_head:   number of attention heads.
        in_dim:   stored on the instance; not used in this layer's computations.
        locality: quantile handed to ``torch.quantile``; per query point only
                  distances up to that quantile receive attention.  1.0 keeps
                  every point (global attention).
        """
        super().__init__()

        self.locality = locality
        self.n_head = n_head
        self.in_dim = in_dim
        self.lmda = torch.nn.Parameter(torch.rand(n_head, 1, 1))

    def forward(self, mesh, inputs):
        """
        mesh:   (batch, L, space_dim)
        inputs: (batch, L, in_dim)

        Returns ``inputs`` concatenated with the attended features:
        (batch, L, (1 + n_head) * in_dim).
        """
        att = self.dist2att(mesh, mesh, self.lmda, self.locality)  # (batch, n_head, L, L)
        convoluted = self.convolution(att, inputs)
        return torch.cat((inputs, convoluted), dim=-1)

    @staticmethod
    def _bandwidth(scale):
        """Map the raw parameter to a non-negative per-head distance multiplier.

        ``1 + sin(scale)`` lies in [0, 2]; the ``(1 - 1e-7)`` factor keeps the
        argument of ``tan`` just below pi/2.
        """
        return torch.tan(0.25*pi*(1-1e-7)*(1.0+torch.sin(scale)))

    @staticmethod
    def _masked_softmax(scaled_dist, locality):
        """Turn scaled distances into attention weights over the last axis.

        Entries above the row-wise ``locality`` quantile are replaced by the
        largest float32 value, so the softmax of the negated distances gives
        them (numerically) zero weight.
        """
        # For locality == 1 the quantile is the row maximum, so nothing would
        # be masked; skip the sort behind torch.quantile in that case.  Any
        # other value still goes through torch.quantile and its own checks.
        if locality != 1:
            mask = torch.quantile(scaled_dist, locality, dim=-1, keepdim=True)
            far = torch.tensor(torch.finfo(torch.float32).max, device=scaled_dist.device)
            scaled_dist = torch.where(scaled_dist <= mask, scaled_dist, far)
        return torch.softmax(-scaled_dist, dim=-1)

    def dist2att(self, mesh_out, mesh_in, scale, locality):
        """
        mesh_out: (batch, L_out, space_dim) query points.
        mesh_in:  (batch, L_in, space_dim) source points.
        scale:    (n_head, 1, 1) raw bandwidth parameter.
        locality: quantile, see ``__init__``.

        Returns attention weights of shape (batch, n_head, L_out, L_in);
        each row sums to one.
        """
        m_dist = torch.sum((mesh_out.unsqueeze(-2) - mesh_in.unsqueeze(-3))**2, dim=-1)  # (batch, L_out, L_in)
        scaled_dist = m_dist.unsqueeze(1) * self._bandwidth(scale)  # (batch, n_head, L_out, L_in)
        return self._masked_softmax(scaled_dist, locality)

    def convolution(self, A, U):
        """
        A: (batch, n_head, L_out, L_in) attention weights.
        U: (batch, L_in, d) features.

        Returns the per-head weighted sums, heads concatenated along the
        feature axis: (batch, L_out, n_head * d).
        """
        convoluted = torch.einsum("bhnj,bjd->bnhd", A, U)  # (batch, L_out, n_head, d)
        return convoluted.reshape(U.shape[0], -1, self.n_head * U.shape[-1])


class posatt_cross(posatt):
    """Position cross-attention between two batched meshes (no skip concatenation)."""

    def forward(self, mesh_out, mesh_in, inputs):
        """
        mesh_out: (batch, L_out, space_dim)
        mesh_in:  (batch, L_in, space_dim)
        inputs:   (batch, L_in, in_dim)

        Returns (batch, L_out, n_head * in_dim).
        """
        att = self.dist2att(mesh_out, mesh_in, self.lmda, self.locality)
        return self.convolution(att, inputs)


class pit(nn.Module):
    """Encoder / processor / decoder container built from position-attention.

    ``mesh_ltt`` is kept as a plain attribute (not a registered buffer), so
    ``Module.to``/``cuda`` do not move it.
    """

    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 en_loc,
                 de_loc):
        """
        space_dim: number of spatial coordinates per mesh point.
        in_dim:    number of input function channels; the encoder MLP expects
                   ``n_head * (in_dim + space_dim)`` attended features.
        out_dim:   number of output channels.
        hid_dim:   latent width.
        n_head:    number of attention heads.
        n_blocks:  number of attention + MLP blocks in the processor.
        mesh_ltt:  latent mesh, reshaped to (-1, space_dim), or None.
        en_loc:    locality quantile of the encoder cross-attention.
        de_loc:    locality quantile of the decoder cross-attention.
        """
        super().__init__()
        self.space_dim = space_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.n_blocks = n_blocks
        # Identity test: "!= None" is elementwise (and ambiguous) for array-likes.
        if mesh_ltt is not None:
            self.mesh_ltt = mesh_ltt.reshape(-1, self.space_dim)
        else:
            self.mesh_ltt = mesh_ltt
        self.en_local = en_loc
        self.de_local = de_loc

        # Construction order is kept as in the original so that seeded
        # parameter initialisation is unchanged.
        self.down = posatt_cross(self.n_head, self.in_dim, self.en_local)
        self.en_layer = kaiming_mlp(self.n_head * (self.in_dim+self.space_dim), self.hid_dim, self.hid_dim)

        # Processor blocks always use global attention (locality 1.0).
        self.conv = torch.nn.ModuleList([posatt(self.n_head, self.hid_dim, 1.0) for _ in range(self.n_blocks)])
        # (1 + n_head): the attention layers concatenate their input (residual) with the attended features.
        self.mlp = torch.nn.ModuleList([kaiming_mlp((1 + self.n_head) * self.hid_dim, self.hid_dim, self.hid_dim) for _ in range(self.n_blocks)])

        self.up = posatt_cross(self.n_head, self.hid_dim, self.de_local)
        self.de = kaiming_mlp(self.n_head * self.hid_dim, self.hid_dim, self.out_dim)

    def encoder(self, mesh_in, func_in, mesh_ltt):
        """Transfer ``func_in`` from ``mesh_in`` onto ``mesh_ltt`` and lift it to ``hid_dim`` channels."""
        func_ltt = self.down(mesh_ltt, mesh_in, func_in)
        func_ltt = self.en_layer(func_ltt)
        return gelu(func_ltt)

    def processor(self, func_ltt, mesh_ltt):
        """Apply the ``n_blocks`` attention + MLP blocks on the latent mesh."""
        for a, w in zip(self.conv, self.mlp):
            # U = AUW
            func_ltt = a(mesh_ltt, func_ltt)
            func_ltt = w(func_ltt)
            func_ltt = gelu(func_ltt)
        return func_ltt

    def decoder(self, mesh_ltt, func_ltt, mesh_out):
        """Transfer latent features from ``mesh_ltt`` onto ``mesh_out`` and project to ``out_dim`` channels."""
        func_out = self.up(mesh_out, mesh_ltt, func_ltt)
        return self.de(func_out)


class posatt_fixed(posatt):
    """Position-attention for a mesh shared by the whole batch (no batch axis on meshes)."""

    def dist2att(self, mesh_out, mesh_in, scale, locality):
        """
        mesh_out: (L_out, space_dim), mesh_in: (L_in, space_dim).

        Returns attention weights of shape (n_head, L_out, L_in).
        """
        m_dist = torch.sum((mesh_out.unsqueeze(-2) - mesh_in.unsqueeze(-3))**2, dim=-1)  # (L_out, L_in)
        scaled_dist = m_dist * self._bandwidth(scale)  # (n_head, L_out, L_in)
        return self._masked_softmax(scaled_dist, locality)

    def convolution(self, A, U):
        """
        A: (n_head, L_out, L_in) attention weights shared across the batch.
        U: (batch, L_in, d) features.

        Returns (batch, L_out, n_head * d).
        """
        convoluted = torch.einsum("hnj,bjd->bnhd", A, U)  # (batch, L_out, n_head, d)
        return convoluted.reshape(U.shape[0], -1, self.n_head * U.shape[-1])


class posatt_cross_fixed(posatt_fixed):
    """Cross-attention variant of ``posatt_fixed``."""

    def forward(self, mesh_out, mesh_in, inputs):
        """
        mesh_out: (L_out, space_dim)
        mesh_in:  (L_in, space_dim)
        inputs:   (batch, L_in, in_dim)
        """
        att = self.dist2att(mesh_out, mesh_in, self.lmda, self.locality)
        return self.convolution(att, inputs)


class pit_fixed(pit):
    """``pit`` whose attention layers assume meshes shared across the batch."""

    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 en_loc,
                 de_loc):
        super().__init__(space_dim,
                         in_dim,
                         out_dim,
                         hid_dim,
                         n_head,
                         n_blocks,
                         mesh_ltt,
                         en_loc,
                         de_loc)
        # Replace the batched-mesh attention layers built by the base class.
        self.down = posatt_cross_fixed(self.n_head, self.in_dim, self.en_local)
        self.conv = torch.nn.ModuleList([posatt_fixed(self.n_head, self.hid_dim, 1.0) for _ in range(self.n_blocks)])
        self.up = posatt_cross_fixed(self.n_head, self.hid_dim, self.de_local)


class posatt_periodic1d(posatt_fixed):
    """Fixed-mesh position-attention with a periodic distance along the first coordinate."""

    def dist2att(self, mesh_out, mesh_in, scale, locality):
        """
        The period is ``dx * L_in`` with ``dx`` the spacing of the first two
        points of ``mesh_in``.
        NOTE(review): this presumes ``mesh_in`` is a uniform grid - confirm.

        Returns attention weights of shape (n_head, L_out, L_in).
        """
        dx = torch.abs(mesh_in[1, 0] - mesh_in[0, 0])
        period = dx * mesh_in.shape[0]
        m_diff = torch.abs(mesh_out.unsqueeze(-2) - mesh_in.unsqueeze(-3))
        m_diff = torch.minimum(m_diff, period - m_diff)  # shorter way around the period
        m_dist = m_diff[..., 0]**2
        scaled_dist = m_dist * self._bandwidth(scale)  # (n_head, L_out, L_in)
        return self._masked_softmax(scaled_dist, locality)


class posatt_cross_periodic1d(posatt_periodic1d):
    """Cross-attention variant of ``posatt_periodic1d``."""

    def forward(self, mesh_out, mesh_in, inputs):
        """
        mesh_out: (L_out, space_dim)
        mesh_in:  (L_in, space_dim)
        inputs:   (batch, L_in, in_dim)
        """
        att = self.dist2att(mesh_out, mesh_in, self.lmda, self.locality)
        return self.convolution(att, inputs)


class pit_periodic1d(pit):
    """``pit`` using the 1-D periodic attention layers."""

    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 en_loc,
                 de_loc):
        super().__init__(space_dim,
                         in_dim,
                         out_dim,
                         hid_dim,
                         n_head,
                         n_blocks,
                         mesh_ltt,
                         en_loc,
                         de_loc)
        self.down = posatt_cross_periodic1d(self.n_head, self.in_dim, self.en_local)
        self.conv = torch.nn.ModuleList([posatt_periodic1d(self.n_head, self.hid_dim, 1.0) for _ in range(self.n_blocks)])
        self.up = posatt_cross_periodic1d(self.n_head, self.hid_dim, self.de_local)


class posatt_periodic2d(posatt_fixed):
    """Fixed-mesh position-attention with a periodic distance in every coordinate."""

    def dist2att(self, mesh_out, mesh_in, scale, locality):
        """
        The grid resolution is taken as ``int(sqrt(L_in))`` and the period is
        derived from the extent of the first coordinate, then applied to all
        coordinates.
        NOTE(review): this presumes a uniform square grid - confirm.

        Returns attention weights of shape (n_head, L_out, L_in).
        """
        res = int(mesh_in.shape[0]**0.5)
        dx = (torch.max(mesh_in[:, 0]) - torch.min(mesh_in[:, 0])) / (res - 1)
        period = dx * res
        m_diff = torch.abs(mesh_out.unsqueeze(-2) - mesh_in.unsqueeze(-3))
        m_diff = torch.minimum(m_diff, period - m_diff)  # shorter way around the period
        m_dist = torch.sum(m_diff**2, dim=-1)
        scaled_dist = m_dist * self._bandwidth(scale)  # (n_head, L_out, L_in)
        return self._masked_softmax(scaled_dist, locality)


class posatt_cross_periodic2d(posatt_periodic2d):
    """Cross-attention variant of ``posatt_periodic2d``."""

    def forward(self, mesh_out, mesh_in, inputs):
        """
        mesh_out: (L_out, space_dim)
        mesh_in:  (L_in, space_dim)
        inputs:   (batch, L_in, in_dim)
        """
        att = self.dist2att(mesh_out, mesh_in, self.lmda, self.locality)
        return self.convolution(att, inputs)


class pit_periodic2d(pit):
    """``pit`` using the 2-D periodic attention layers."""

    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 en_loc,
                 de_loc):
        super().__init__(space_dim,
                         in_dim,
                         out_dim,
                         hid_dim,
                         n_head,
                         n_blocks,
                         mesh_ltt,
                         en_loc,
                         de_loc)
        self.down = posatt_cross_periodic2d(self.n_head, self.in_dim, self.en_local)
        self.conv = torch.nn.ModuleList([posatt_periodic2d(self.n_head, self.hid_dim, 1.0) for _ in range(self.n_blocks)])
        self.up = posatt_cross_periodic2d(self.n_head, self.hid_dim, self.de_local)
"""Train a TensorFlow PiT on the one-dimensional inviscid Burgers' dataset.

Loads ``./1_InviscidBurgers.mat``, trains with a cosine-decayed Adam optimiser
and the relative-l1 loss, then writes the weights, the loss history (CSV and
PNG) and the test predictions below ``save_path``.
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['PYTHONHASHSEED'] = '1'

from numpy import savetxt
from scipy.io import savemat
import matplotlib.pyplot as plt
from time import time

### custom imports
# The star import also provides `tf`, the model classes, the loss and the data
# helpers used below; it is kept so the commented-out model alternatives can be
# enabled without touching the imports.
from utils import *

### hyper-parameters
n_epochs = 500
lr = 0.001
batch_size = 5
encode_dim = 64
out_dim = 1
n_head = 2
n_train = 950
n_test = 128
qry_res = 1024
ltt_res = 1024
en_loc = 1 # locality parameter in Encoder
de_loc = 8 # locality parameter in Decoder
save_path = './model/'


### load dataset
trainX, trainY, testX, testY = load_data("./1_InviscidBurgers.mat", n_train, n_test)
print(trainX.shape, trainY.shape, testX.shape, testY.shape)

### define a model
m_query = pairwise_dist(qry_res, qry_res) # pairwise distance matrix for Encoder and Decoder
m_cross = pairwise_dist(qry_res, ltt_res) # pairwise distance matrix for Encoder and Decoder
m_latent = pairwise_dist(ltt_res, ltt_res) # pairwise distance matrix for Processor
# The locality settings are passed by name; the original call repeated them as
# the literals 1 and 8 and left the variables above unused.
network = PiT(m_query, m_cross, m_latent, out_dim, encode_dim, n_head, en_loc, de_loc)
#network = LiteTransformer(m_query, m_cross, out_dim, encode_dim, n_head, en_loc, de_loc)
#network = Transformer(qry_res, out_dim, encode_dim, n_head)

inputs = tf.keras.Input(shape=(qry_res,trainX.shape[-1]))
outputs = network(inputs)
model = tf.keras.Model(inputs, outputs)
network.summary()

### compile model
# The cosine schedule decays over the total number of optimiser steps.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=tf.keras.optimizers.schedules.CosineDecay(lr, n_epochs*(n_train//batch_size))),
              loss=rel_norm(),
              jit_compile=True)

### fit model
start = time()
train_history = model.fit(trainX, trainY,
                          batch_size, n_epochs, verbose=1, # verbose to 0 when nohup, to 1 when run in shell
                          validation_data=(testX, testY),
                          validation_batch_size=128)
end = time()
print(' ')
print('Training cost is {} seconds per epoch !'.format((end-start)/n_epochs))
print(' ')

### save model
model.save_weights(save_path + 'checkpoints/my_checkpoint')

### plot training history
loss_hist = train_history.history
train_loss = tf.reshape(tf.convert_to_tensor(loss_hist["loss"]), (-1,1))
test_loss = tf.reshape(tf.convert_to_tensor(loss_hist["val_loss"]), (-1,1))
savetxt(save_path + 'training_history.csv', tf.concat([train_loss, test_loss], axis=-1).numpy(), delimiter=',', header='train,test', fmt='%1.16f', comments='')

plt.plot(train_loss, label='train')
plt.plot(test_loss, label='test')
plt.legend()
plt.yscale('log', base=10)
plt.savefig(save_path + 'training_history.png')
plt.close()

### do some tests
pred = model.predict(testX)
savemat(save_path + 'pred.mat', mdict={'pred':pred, 'X':testX, 'Y':testY})
print(rel_l1_median(testY, pred))
"""Train a TensorFlow PiT on the shock-tube dataset.

Loads ``./2_ShockTube.mat``, trains with a cosine-decayed Adam optimiser and
the relative-l1 loss, then writes the weights, the loss history (CSV and PNG)
and the test predictions below ``save_path``.
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['PYTHONHASHSEED'] = '1'

from numpy import savetxt
from scipy.io import savemat
import matplotlib.pyplot as plt
from time import time

### custom imports (also provides `tf`, the models, the loss and the data helpers)
from utils import *

### hyper-parameters
n_epochs = 500
lr = 0.001
batch_size = 8
encode_dim = 64
out_dim = 1
n_head = 2
n_train = 1024
n_test = 128
qry_res = 2048
ltt_res = 1024
en_loc = 4 # locality parameter in Encoder
de_loc = 2 # locality parameter in Decoder
save_path = './model/'

### load dataset
trainX, trainY, testX, testY = load_data("./2_ShockTube.mat", n_train, n_test)
print(trainX.shape, trainY.shape, testX.shape, testY.shape)

### define a model
m_query = pairwise_dist(qry_res, qry_res) # pairwise distance matrix for Encoder and Decoder
m_cross = pairwise_dist(qry_res, ltt_res) # pairwise distance matrix for Encoder and Decoder
m_latent = pairwise_dist(ltt_res, ltt_res) # pairwise distance matrix for Processor
network = PiT(m_query, m_cross, m_latent, out_dim, encode_dim, n_head, en_loc, de_loc)
#network = LiteTransformer(m_query, m_cross, out_dim, encode_dim, n_head, en_loc, de_loc)
#network = Transformer(qry_res, out_dim, encode_dim, n_head)

inputs = tf.keras.Input(shape=(qry_res, trainX.shape[-1]))
outputs = network(inputs)
model = tf.keras.Model(inputs, outputs)
network.summary()

### compile model
# The learning rate follows a cosine decay over all optimiser steps.
decay_steps = n_epochs * (n_train // batch_size)
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(lr, decay_steps)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
    loss=rel_norm(),
    jit_compile=True,  # jit_compile always on, for faster training
)

### fit model
tic = time()
train_history = model.fit(
    trainX, trainY,
    batch_size=batch_size,
    epochs=n_epochs,
    verbose=1,  # verbose to 0 when nohup, to 1 when run in shell
    validation_data=(testX, testY),
    validation_batch_size=128,
)
elapsed = time() - tic
print(' ')
print('Training cost is {} seconds per epoch !'.format(elapsed/n_epochs))

print(' ')
### save model
model.save_weights(save_path + 'checkpoints/my_checkpoint')


def _as_column(values):
    """Return a list of per-epoch losses as an (n_epochs, 1) tensor."""
    return tf.reshape(tf.convert_to_tensor(values), (-1, 1))


### plot training history
loss_hist = train_history.history
train_loss = _as_column(loss_hist["loss"])
test_loss = _as_column(loss_hist["val_loss"])
history_table = tf.concat([train_loss, test_loss], axis=-1).numpy()
savetxt(save_path + 'training_history.csv', history_table, delimiter=',', header='train,test', fmt='%1.16f', comments='')

for curve, label in ((train_loss, 'train'), (test_loss, 'test')):
    plt.plot(curve, label=label)
plt.legend()
plt.yscale('log', base=10)
plt.savefig(save_path + 'training_history.png')
plt.close()

### evaluate on the test set
pred = model.predict(testX)
savemat(save_path + 'pred.mat', mdict={'pred':pred, 'X':testX, 'Y':testY})
print(rel_l1_median(testY, pred))
call(self, true, pred): 18 | ''' 19 | true: (batch_size, L, d). 20 | pred: (batch_size, L, d). 21 | number of variables d=1 22 | ''' 23 | rel_error = tf.math.divide(tf.norm(tf.keras.layers.Reshape((-1,))(true-pred), ord=1, axis=1), tf.norm(tf.keras.layers.Reshape((-1,))(true), ord=1, axis=1)) 24 | return tf.math.reduce_mean(rel_error) 25 | 26 | def rel_l1_median(true, pred): 27 | ''' 28 | Compute the 25%, 50%, and 75% quantile of the relative l1 loss between a batch of targets and predictions 29 | ''' 30 | rel_error = tf.norm(true[...,0]-pred[...,0], ord=1, axis=1) / tf.norm(true[...,0], ord=1, axis=1) 31 | return tfp.stats.percentile(rel_error, 25, interpolation="linear").numpy(), tfp.stats.percentile(rel_error, 50, interpolation="linear").numpy(), tfp.stats.percentile(rel_error, 75, interpolation="linear").numpy() 32 | 33 | def pairwise_dist(res1, res2): 34 | 35 | grid1 = tf.reshape(tf.linspace(0, 1, res1+1)[:-1], (-1,1)) 36 | grid1 = tf.tile(grid1, [1,res2]) 37 | grid2 = tf.reshape(tf.linspace(0, 1, res2+1)[:-1], (1,-1)) 38 | grid2 = tf.tile(grid2, [res1,1]) 39 | print(grid1.shape, grid2.shape) 40 | 41 | dist2 = (grid1-grid2)**2 42 | print(dist2.shape, tf.math.reduce_max(dist2)) 43 | 44 | return tf.cast(dist2, 'float32') 45 | 46 | def load_data(path_data, ntrain = 1024, ntest=128): 47 | 48 | data = loadmat(path_data) 49 | X_data = data["x"].astype('float32') 50 | Y_data = data["y"].astype('float32') 51 | X_train = X_data[:ntrain,:] 52 | Y_train = Y_data[:ntrain,:, newaxis] 53 | 54 | X_test = X_data[-ntest:,:] 55 | Y_test = Y_data[-ntest:,:, newaxis] 56 | 57 | return X_train, Y_train, X_test, Y_test 58 | 59 | class mlp(tf.keras.layers.Layer): 60 | ''' 61 | A two-layer MLP with GELU activation. 
62 | ''' 63 | def __init__(self, n_filters1, n_filters2): 64 | super(mlp, self).__init__() 65 | 66 | self.width1 = n_filters1 67 | self.width2 = n_filters2 68 | self.mlp1 = tf.keras.layers.Dense(self.width1, activation='gelu', kernel_initializer="he_normal") 69 | self.mlp2 = tf.keras.layers.Dense(self.width2, kernel_initializer="he_normal") 70 | 71 | def call(self, inputs): 72 | x = self.mlp1(inputs) 73 | x = self.mlp2(x) 74 | return x 75 | 76 | def get_config(self): 77 | config = { 78 | 'n_filters1': self.width1, 79 | 'n_filters2': self.width2, 80 | } 81 | return config 82 | 83 | class MultiHeadPosAtt(tf.keras.layers.Layer): 84 | ''' 85 | Global, local and cross variants of the multi-head position-attention mechanism. 86 | ''' 87 | def __init__(self, m_dist, n_head, hid_dim, locality): 88 | super(MultiHeadPosAtt, self).__init__() 89 | ''' 90 | m_dist: distance matrix 91 | n_head: number of attention heads 92 | hid_dim: encoding dimension 93 | locality: quantile parameter to customize receptive field in position-attention 94 | ''' 95 | self.dist = m_dist 96 | self.locality = locality 97 | self.hid_dim = hid_dim 98 | self.n_head = n_head 99 | self.v_dim = round(self.hid_dim/self.n_head) 100 | 101 | def build(self, input_shape): 102 | 103 | self.r = self.add_weight( 104 | shape=(self.n_head, 1, 1), 105 | trainable=True, 106 | name="band_width", 107 | ) 108 | 109 | self.weight = self.add_weight( 110 | shape=(self.n_head, input_shape[-1], self.v_dim), 111 | initializer="he_normal", 112 | trainable=True, 113 | name="weight", 114 | ) 115 | self.built = True 116 | 117 | def call(self, inputs): 118 | scaled_dist = self.dist * self.r**2 119 | if self.locality <= 100: 120 | mask = tfp.stats.percentile(scaled_dist, self.locality, interpolation="linear", axis=-1, keepdims=True) 121 | scaled_dist = tf.where(scaled_dist<=mask, scaled_dist, tf.float32.max) 122 | else: 123 | pass 124 | scaled_dist = - scaled_dist # (n_head, L, L) 125 | att = tf.nn.softmax(scaled_dist, axis=2) # 
(n_heads, L, L) 126 | 127 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.weight) # (batch_size, n_head, L, v_dim) 128 | 129 | concat = tf.einsum("hnj,bhjd->bhnd", att, value) # (batch_size, n_head, L, v_dim) 130 | concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim) 131 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim) 132 | return tf.keras.activations.gelu(concat) 133 | 134 | def get_config(self): 135 | config = { 136 | 'm_dist': self.dist, 137 | 'hid_dim': self.hid_dim, 138 | 'n_head': self.n_head, 139 | 'locality': self.locality 140 | } 141 | return config 142 | 143 | class PiT(tf.keras.Model): 144 | ''' 145 | Position-induced Transfomer, built upon the multi-head position-attention mechanism. 146 | PiT can be trained to decompose and learn the global and local dependcencies of operators in partial differential equations. 147 | ''' 148 | def __init__(self, m_qry, m_cross, m_ltt, out_dim, hid_dim, n_head, locality_encoder, locality_decoder): 149 | super(PiT, self).__init__() 150 | ''' 151 | m_qry: distance matrix between X_query and X_query; (L_qry,L_qry) 152 | m_cross: distance matrix between X_query and X_latent; (L_qry,L_ltt) 153 | m_ltt: distance matrix between X_latent and X_latent; (L_ltt,L_ltt) 154 | out_dim: number of variables 155 | hid_dim: encoding dimension (network width) 156 | n_head: number of heads in multi-head attention modules 157 | locality_encoder: quantile parameter of local position-attention in the Encoder, allowing to customize the size of receptive filed 158 | locality_decoder: quantile parameter of local position-attention in the Decoder, allowing to customize the size of receptive filed 159 | ''' 160 | self.m_qry = m_qry 161 | self.res = m_qry.shape[0] 162 | self.m_cross = m_cross 163 | self.m_ltt = m_ltt 164 | self.out_dim = out_dim 165 | self.hid_dim = hid_dim 166 | self.n_head = n_head 167 | self.en_local = locality_encoder 168 | self.de_local = locality_decoder 169 | 
self.n_blocks = 4 # number of position-attention modules in the Processor 170 | 171 | # Encoder 172 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 173 | self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local) 174 | 175 | # Processor 176 | self.MHPA = [MultiHeadPosAtt(self.m_ltt, self.n_head, self.hid_dim, locality=200) for i in range(self.n_blocks)] 177 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 178 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 179 | 180 | # Decoder 181 | self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local) 182 | self.up2 = MultiHeadPosAtt(self.m_qry, self.n_head, self.hid_dim, locality=self.de_local) 183 | self.mlp = mlp(self.hid_dim, self.hid_dim) 184 | self.w = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 185 | self.de_layer = mlp(self.hid_dim, self.out_dim) 186 | 187 | def call(self, inputs): 188 | 189 | # Encoder 190 | grid = self.get_mesh(inputs) # (batch_size, L_qry, 1) 191 | en = tf.concat([grid, inputs], axis=-1) # (batch_size, L_qry, input_dim+1) 192 | en = self.en_layer(en) # (batch_size, L_qry, hid_dim) 193 | x = self.down(en) # (batch_size, L_ltt, hid_dim) 194 | 195 | # Processor 196 | for i in range(self.n_blocks): 197 | x = self.MLP[i](self.MHPA[i](x)) + self.W[i](x) # (batch_size, L_ltt, hid_dim) 198 | x = tf.keras.activations.gelu(x) # (batch_size, L_ltt, hid_dim) 199 | 200 | # Decoder 201 | de = self.up(x) # (batch_size, L_qry, hid_dim) 202 | de = self.mlp(self.up2(de)) + self.w(de) # (batch_size, L_ltt, hid_dim) 203 | de = tf.keras.activations.gelu(de) # (batch_size, L_ltt, hid_dim) 204 | de = self.de_layer(de) # (batch_size, L_ltt, out_dim) 205 | return de 206 | 207 | def get_mesh(self, inputs): 208 | grid = tf.reshape(tf.linspace(0, 1, self.res+1)[:-1], (1,-1,1)) 
class MultiHeadSelfAtt(tf.keras.layers.Layer):
    '''
    Scaled dot-product multi-head self-attention.

    Each head projects the input with its own query, key and value matrices,
    attends over the sequence axis, and the heads are concatenated back to
    hid_dim features before a GELU is applied.

    n_head: number of attention heads
    hid_dim: output width; must be a multiple of n_head

    Raises ValueError if hid_dim is not divisible by n_head.
    '''
    def __init__(self, n_head, hid_dim):
        super(MultiHeadSelfAtt, self).__init__()

        # call() reshapes the concatenated heads to (-1, hid_dim). With a rounded
        # per-head width, n_head*v_dim can differ from hid_dim, and that reshape
        # would then either fail or silently mix neighbouring sequence positions.
        if hid_dim % n_head != 0:
            raise ValueError(
                "hid_dim ({}) must be divisible by n_head ({})".format(hid_dim, n_head))

        self.hid_dim = hid_dim
        self.n_head = n_head
        self.v_dim = hid_dim // n_head

    def build(self, input_shape):
        # One (input_dim, v_dim) projection per head for queries, keys and values.
        # Creation order (query, key, value) is kept from the original layer.
        self.q = self.add_weight(
            shape=(self.n_head, input_shape[-1], self.v_dim),
            initializer="he_normal",
            trainable=True,
            name="query",
        )

        self.k = self.add_weight(
            shape=(self.n_head, input_shape[-1], self.v_dim),
            initializer="he_normal",
            trainable=True,
            name="key",
        )

        self.v = self.add_weight(
            shape=(self.n_head, input_shape[-1], self.v_dim),
            initializer="he_normal",
            trainable=True,
            name="value",
        )
        self.built = True

    def call(self, inputs):

        query = tf.einsum("bnj,hjk->bhnk", inputs, self.q) # (batch_size, n_head, L, v_dim)
        key = tf.einsum("bnj,hjk->bhnk", inputs, self.k) # (batch_size, n_head, L, v_dim)
        # softmax over the key axis of QK^T / sqrt(v_dim)
        att = tf.nn.softmax(tf.einsum("...ij,...kj->...ik", query, key)/self.v_dim**0.5, axis=-1) # (batch_size, n_heads, L, L)

        value = tf.einsum("bnj,hjk->bhnk", inputs, self.v) # (batch_size, n_head, L, v_dim)

        concat = tf.einsum("...nj,...jd->...nd", att, value) # (batch_size, n_head, L, v_dim)
        concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim)
        concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim)
        return tf.keras.activations.gelu(concat)

    def get_config(self):
        config = {
            'n_head': self.n_head,
            'hid_dim': self.hid_dim
        }
        return config
class Transformer(tf.keras.Model):
    '''
    Baseline in which every position-attention module of a PiT (Encoder,
    Processor and Decoder) is replaced by scaled dot-product self-attention.

    res: number of points of the uniform 1D mesh built by get_mesh
    out_dim: width of the final output layer
    hid_dim: encoding dimension (network width)
    n_head: number of heads in the attention modules
    '''
    def __init__(self, res, out_dim, hid_dim, n_head):
        super(Transformer, self).__init__()

        self.res = res
        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.n_blocks = 4

        # Small factories; the layers are still created in the original order.
        def attention():
            return MultiHeadSelfAtt(self.n_head, self.hid_dim)

        def linear():
            return tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal")

        def hidden_mlp():
            return mlp(self.hid_dim, self.hid_dim)

        # Encoder
        self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal")
        self.down = attention()

        # Processor
        self.PA = [attention() for _ in range(self.n_blocks)]
        self.MLP = [hidden_mlp() for _ in range(self.n_blocks)]
        self.W = [linear() for _ in range(self.n_blocks)]

        # Decoder
        self.up = attention()
        self.up2 = attention()
        self.mlp = hidden_mlp()
        self.w = linear()
        self.de_layer = mlp(self.hid_dim, self.out_dim)

    def call(self, inputs):

        # Encoder: prepend the mesh coordinate to the input features, then embed
        mesh = self.get_mesh(inputs)
        features = self.en_layer(tf.concat([mesh, inputs], axis=-1))
        x = self.down(features)

        # Processor: attention -> MLP branch plus a linear skip branch, then GELU
        for block in range(self.n_blocks):
            attended = self.PA[block](x)
            x = tf.keras.activations.gelu(self.MLP[block](attended) + self.W[block](x))

        # Decoder: same residual pattern once more, then the output MLP
        de = self.up(x)
        refined = self.up2(de)
        de = tf.keras.activations.gelu(self.mlp(refined) + self.w(de))
        return self.de_layer(de)

    def get_mesh(self, inputs):
        # res equispaced points on [0, 1): the right endpoint is dropped
        nodes = tf.linspace(0, 1, self.res+1)[:-1]
        mesh = tf.reshape(nodes, (1,-1,1))
        # one copy of the mesh per sample of the batch
        mesh = tf.repeat(mesh, tf.shape(inputs)[0], 0)
        return tf.cast(mesh, dtype="float32")

    def get_config(self):
        return {
            'res': self.res,
            'out_dim': self.out_dim,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head,
        }
class SelfPiT(tf.keras.Model):
    '''
    PiT variant in which the position-attention modules of the Encoder,
    Processor and Decoder are all SelfMultiHeadPosAtt layers.

    m_qry: distance matrix given to the second Decoder attention (up2); its first dimension sets the mesh resolution
    m_cross: distance matrix given to the Decoder (up) and, transposed, to the Encoder (down)
    m_ltt: distance matrix given to the Processor blocks
    out_dim: width of the final output layer
    hid_dim: encoding dimension (network width)
    n_head: number of heads in the attention modules
    locality_encoder: locality (percentile) parameter of the Encoder attention
    locality_decoder: locality (percentile) parameter of both Decoder attentions
    '''
    def __init__(self, m_qry, m_cross, m_ltt, out_dim, hid_dim, n_head, locality_encoder, locality_decoder):
        super(SelfPiT, self).__init__()

        self.m_qry = m_qry
        self.res = m_qry.shape[0]
        self.m_cross = m_cross
        self.m_ltt = m_ltt
        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.en_local = locality_encoder
        self.de_local = locality_decoder
        self.n_blocks = 4

        # Small factories; the layers are still created in the original order.
        def attention(distance, locality):
            return SelfMultiHeadPosAtt(distance, self.n_head, self.hid_dim, locality=locality)

        def linear():
            return tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal")

        def hidden_mlp():
            return mlp(self.hid_dim, self.hid_dim)

        # Encoder
        self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal")
        self.down = attention(tf.transpose(self.m_cross), self.en_local)

        # Processor
        # locality=200 is above 100, so SelfMultiHeadPosAtt applies no percentile mask here
        self.MHPA = [attention(self.m_ltt, 200) for _ in range(self.n_blocks)]
        self.MLP = [hidden_mlp() for _ in range(self.n_blocks)]
        self.W = [linear() for _ in range(self.n_blocks)]

        # Decoder
        self.up = attention(self.m_cross, self.de_local)
        self.up2 = attention(self.m_qry, self.de_local)
        self.mlp = hidden_mlp()
        self.w = linear()
        self.de_layer = mlp(self.hid_dim, self.out_dim)

    def call(self, inputs):

        # Encoder: prepend the mesh coordinate to the input features, then embed
        mesh = self.get_mesh(inputs)
        features = self.en_layer(tf.concat([mesh, inputs], axis=-1))
        x = self.down(features)

        # Processor: attention -> MLP branch plus a linear skip branch, then GELU
        for block in range(self.n_blocks):
            attended = self.MHPA[block](x)
            x = tf.keras.activations.gelu(self.MLP[block](attended) + self.W[block](x))

        # Decoder: same residual pattern once more, then the output MLP
        de = self.up(x)
        refined = self.up2(de)
        de = tf.keras.activations.gelu(self.mlp(refined) + self.w(de))
        return self.de_layer(de)

    def get_mesh(self, inputs):
        # res equispaced points on [0, 1): the right endpoint is dropped
        nodes = tf.linspace(0, 1, self.res+1)[:-1]
        mesh = tf.reshape(nodes, (1,-1,1))
        # one copy of the mesh per sample of the batch
        mesh = tf.repeat(mesh, tf.shape(inputs)[0], 0)
        return tf.cast(mesh, dtype="float32")

    def get_config(self):
        return {
            'm_qry': self.m_qry,
            'm_cross': self.m_cross,
            'm_ltt': self.m_ltt,
            'out_dim': self.out_dim,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head,
            'locality_encoder': self.en_local,
            'locality_decoder': self.de_local,
        }
= '2' 3 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 4 | os.environ['PYTHONHASHSEED'] = '1' 5 | from scipy.io import savemat 6 | from utils import *#load_data, pairwise_dist, PixelWiseNormalization 7 | from numpy import zeros 8 | import matplotlib.pyplot as plt 9 | downsampling = 10 10 | qry_res = int((421-1)/downsampling+1) 11 | ltt_res = 32 12 | en_loc = 2 13 | de_loc = 5 14 | encode_dim = 128 15 | out_dim = 1 16 | n_head = 2 17 | n_train = 1024 18 | n_test = 100 19 | 20 | 21 | ## Load training data and construct normalizer 22 | trainX, trainY, testX, testY= load_data("./piececonst_r421_N1024_smooth1.mat", "./piececonst_r421_N1024_smooth2.mat", downsampling, n_train, n_test) 23 | print(trainX.shape, trainY.shape, testX.shape, testY.shape) 24 | XNormalizer = PixelWiseNormalization(trainX) 25 | YNormalizer = PixelWiseNormalization(trainY) 26 | 27 | ## creat model, load trained weights 28 | m_cross = pairwise_dist(qry_res, qry_res, ltt_res, ltt_res) # pairwise distance matrix for encoder and decoder 29 | m_latent = pairwise_dist(ltt_res, ltt_res, ltt_res, ltt_res) # pairwise distance matrix for processor 30 | network = PiT(m_cross, m_latent, out_dim, encode_dim, n_head, en_loc, de_loc, YNormalizer) 31 | inputs = tf.keras.Input(shape=(qry_res,qry_res,trainX.shape[-1])) 32 | outputs = network(inputs) 33 | model = tf.keras.Model(inputs, outputs) 34 | model.load_weights('./results_43/checkpoints/my_checkpoint').expect_partial() 35 | 36 | #testX = XNormalizer.normalize(testX) 37 | #pred = tf.convert_to_tensor(model.predict(testX), "float32") 38 | #rel_err = rel_norm()(testY,pred) 39 | #print(rel_err) 40 | #pred = tf.convert_to_tensor(network(testX), "float32") 41 | #rel_err = rel_norm()(testY,pred) 42 | #print(rel_err) 43 | 44 | #network.summary() 45 | ################## 46 | #zeor-shot super-resolution 47 | downsampling = 1 48 | qry_res = int((421-1)/downsampling+1) 49 | trainX, trainY, testX, testY= load_data("./piececonst_r421_N1024_smooth1.mat", 
"./piececonst_r421_N1024_smooth2.mat", downsampling, n_train, n_test) 50 | m_cross = pairwise_dist(qry_res, qry_res, ltt_res, ltt_res) # pairwise distance matrix for encoder and decoder 51 | network2 = PiT(m_cross, m_latent, out_dim, encode_dim, n_head, 2, 5, YNormalizer) 52 | inputs2 = tf.keras.Input(shape=(qry_res,qry_res,trainX.shape[-1])) 53 | outputs2 = network2(inputs2) 54 | model2 = tf.keras.Model(inputs2, outputs2) 55 | model2.set_weights(model.get_weights()) 56 | 57 | testX = XNormalizer.normalize(testX) 58 | pred = tf.convert_to_tensor(model2.predict(testX), "float32") 59 | rel_err = rel_norm()(testY,pred) 60 | print(rel_err) 61 | 62 | ######### plots 63 | index = 43 64 | a = XNormalizer.denormalize(testX)[index,:,:,0] 65 | u = testY[index,:,:,0] 66 | pred = pred[index,:,:,0] 67 | 68 | abs_err = tf.math.abs(u-pred) * 10000 69 | emax = tf.math.reduce_max(abs_err) 70 | emin = tf.math.reduce_min(abs_err) 71 | print(emax, emin) 72 | amax = tf.math.reduce_max(a) 73 | amin = tf.math.reduce_min(a) 74 | umax = tf.math.reduce_max(u) 75 | umin = tf.math.reduce_min(u) 76 | print(amax, amin, umax, umin) 77 | 78 | # plot the contours 79 | plt.figure(figsize=(17,4),dpi=300) 80 | plt.subplot(141) 81 | plt.imshow(a, vmax=12, vmin=3, interpolation='spline16', cmap='jet') 82 | plt.colorbar(location="bottom", fraction=0.046, pad=0.04, ticks=[3, 12], format='%1.0f') 83 | plt.axis('off') 84 | plt.axis("equal") 85 | plt.title('Permeability') 86 | 87 | plt.subplot(142) 88 | plt.imshow(u, vmax=umax, vmin=umin, interpolation='spline16', cmap='jet') 89 | plt.colorbar(location="bottom", fraction=0.046, pad=0.04, ticks=[umin, umax], format='%1.3f') 90 | plt.axis('off') 91 | plt.axis("equal") 92 | plt.title('Reference') 93 | 94 | plt.subplot(143) 95 | plt.imshow(pred, vmax=umax, vmin=umin, interpolation='spline16', cmap='jet') 96 | plt.colorbar(location="bottom", fraction=0.046, pad=0.04, ticks=[umin, umax], format='%1.3f') 97 | plt.axis('off') 98 | plt.axis("equal") 99 | 
import os
# Logging / determinism switches, set before TensorFlow is imported (through utils)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['PYTHONHASHSEED'] = '1'

from numpy import savetxt
from scipy.io import savemat
import matplotlib.pyplot as plt
from time import time

### custom imports (also brings `tf` into scope)
from utils import *

### hyper-parameters
n_epochs = 500
lr = 0.001
batch_size = 8
encode_dim = 128
out_dim = 1
n_head = 2
n_train = 1024
n_test = 100
downsampling = 2
qry_res = int((421-1)/downsampling+1)  # resolution of the (downsampled) query mesh
ltt_res = 32                           # resolution of the latent mesh
en_loc = 2
de_loc = 5

### load dataset
trainX, trainY, testX, testY = load_data(
    "./piececonst_r421_N1024_smooth1.mat",
    "./piececonst_r421_N1024_smooth2.mat",
    downsampling, n_train, n_test)
print(trainX.shape, trainY.shape, testX.shape, testY.shape)

### normalize the inputs with training statistics; targets stay raw, the
### network receives YNormalizer and de-normalizes its own output
XNormalizer = PixelWiseNormalization(trainX)
trainX = XNormalizer.normalize(trainX)
testX = XNormalizer.normalize(testX)
YNormalizer = PixelWiseNormalization(trainY)

### define a model
m_cross = pairwise_dist(qry_res, qry_res, ltt_res, ltt_res)   # pairwise distance matrix for encoder and decoder
m_latent = pairwise_dist(ltt_res, ltt_res, ltt_res, ltt_res)  # pairwise distance matrix for processor
network = PiT(m_cross, m_latent, out_dim, encode_dim, n_head, en_loc, de_loc, YNormalizer)
inputs = tf.keras.Input(shape=(qry_res, qry_res, trainX.shape[-1]))
outputs = network(inputs)
model = tf.keras.Model(inputs, outputs)
network.summary()

### compile model: Adam with a cosine-decayed learning rate over all training steps
steps_per_epoch = n_train//batch_size
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(lr, n_epochs*steps_per_epoch)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
    loss=rel_norm(),
    jit_compile=True)

### fit model
t_begin = time()
train_history = model.fit(
    x=trainX, y=trainY,
    batch_size=batch_size, epochs=n_epochs, verbose=1,
    validation_data=(testX, testY),
    validation_batch_size=50)
t_finish = time()
seconds_per_epoch = (t_finish-t_begin)/n_epochs
print(' ')
print('Training cost is {} seconds per epoch !'.format(seconds_per_epoch))

print(' ')
### save model
model.save_weights('./checkpoints/my_checkpoint')

### training history: one column per curve, one row per epoch
history = train_history.history
train_loss = tf.reshape(tf.convert_to_tensor(history["loss"]), (-1,1))
test_loss = tf.reshape(tf.convert_to_tensor(history["val_loss"]), (-1,1))
curves = tf.concat([train_loss, test_loss], axis=-1).numpy()
savetxt('./training_history.csv', curves, delimiter=',', header='train,test', fmt='%1.16f', comments='')

plt.plot(train_loss, label='train')
plt.plot(test_loss, label='test')
plt.legend()
plt.yscale('log', base=10)
plt.savefig('training_history.png')
plt.close()
class PixelWiseNormalization():
    '''
    Pixel-wise standardization with statistics taken over the sample axis.

    x: data the statistics are computed from; mean and std are reduced over axis 0 with keepdims
    eps: constant added to the standard deviation before dividing / multiplying

    When the stored statistics cannot be combined with the given tensor
    (different resolution), they are bilinearly resized to the spatial
    size of that tensor first.
    '''
    def __init__(self, x, eps=1e-5):

        self.mean = tf.math.reduce_mean(x, axis=0, keepdims=True) #(1,h,w,1)
        self.std = tf.math.reduce_std(x, axis=0, keepdims=True) #(1,h,w,1)
        self.eps = eps

    def _resized_stats(self, x):
        # Resample the stored mean/std to the spatial size (axes 1 and 2) of x.
        h = x.shape[1]
        w = x.shape[2]
        mean = tf.image.resize(self.mean, (h,w), method='bilinear')
        std = tf.image.resize(self.std, (h,w), method='bilinear')
        return mean, std

    def normalize(self, x):
        '''Return (x - mean) / (std + eps).'''
        try:
            x = (x - self.mean) / (self.std + self.eps)
        # Only shape mismatches trigger the fallback: ValueError while tracing a
        # graph, InvalidArgumentError in eager mode. A bare `except:` would also
        # hide unrelated failures and KeyboardInterrupt.
        except (ValueError, tf.errors.InvalidArgumentError): #do upsampling
            mean, std = self._resized_stats(x)
            x = (x - mean) / (std + self.eps)
        return x

    def denormalize(self, x):
        '''Return x * (std + eps) + mean, the inverse of normalize.'''
        try:
            x = x * (self.std + self.eps) + self.mean
        # Same narrowed shape-mismatch fallback as in normalize.
        except (ValueError, tf.errors.InvalidArgumentError): #do upsampling
            mean, std = self._resized_stats(x)
            x = x * (std + self.eps) + mean
        return x
def load_data(train_path, test_path, downsampling, ntrain, ntest):
    '''
    Read the "coeff" (input) and "sol" (target) fields from two .mat files.

    train_path, test_path: paths of the .mat files
    downsampling: stride applied along both spatial axes
    ntrain, ntest: number of leading samples kept from each file

    Returns trainX, trainY, testX, testY as float32 tensors with a trailing
    channel axis of size 1.
    '''
    # target resolution after striding; 421 is the hard-coded full resolution
    # (assumes the raw fields are 421 x 421 -- TODO confirm against the data files)
    s = int(((421 - 1)/downsampling) + 1)

    def subsample(field, count):
        # keep the first `count` samples, stride both spatial axes, trim to s x s
        field = field.astype('float32')
        return field[:count, ::downsampling, ::downsampling][:, :s, :s]

    train_mat = loadmat(train_path)
    trainX = subsample(train_mat["coeff"], ntrain)
    trainY = subsample(train_mat["sol"], ntrain)

    test_mat = loadmat(test_path)
    testX = subsample(test_mat["coeff"], ntest)
    testY = subsample(test_mat["sol"], ntest)

    # append the channel axis and hand everything over as tensors
    arrays = (trainX, trainY, testX, testY)
    return tuple(tf.convert_to_tensor(arr[..., newaxis]) for arr in arrays)
class MultiHeadPosAtt(tf.keras.layers.Layer):
    '''
    Global, local and cross variants of the multi-head position-attention mechanism.

    m_dist: distance matrix
    n_head: number of attention heads
    hid_dim: encoding dimension
    locality: quantile parameter to customize receptive field in position-attention;
              values above 100 disable the percentile mask
    '''
    def __init__(self, m_dist, n_head, hid_dim, locality):
        super(MultiHeadPosAtt, self).__init__()

        self.dist = m_dist
        self.locality = locality
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.v_dim = round(self.hid_dim/self.n_head)

    def build(self, input_shape):

        in_dim = input_shape[-1]

        # one trainable band-width parameter per head
        self.r = self.add_weight(
            shape=(self.n_head, 1, 1),
            trainable=True,
            name="band_width",
        )

        # per-head value projection
        self.weight = self.add_weight(
            shape=(self.n_head, in_dim, self.v_dim),
            initializer="he_normal",
            trainable=True,
            name="weight",
        )
        self.built = True

    def call(self, inputs):
        # tan(0.25*pi*(1+sin(r))) leads to higher accuracy than using r^2 and tan(r)
        band_width = tf.math.tan(0.25*pi*(1-1e-7)*(1.0+tf.math.sin(self.r)))
        scores = self.dist * band_width # (n_head, L, L)

        if self.locality <= 100:
            # per row, keep the entries up to the `locality`-th percentile and push
            # the rest to float32 max so they vanish in the softmax below
            cutoff = tfp.stats.percentile(scores, self.locality, interpolation="linear", axis=-1, keepdims=True)
            within = scores<=cutoff
            scores = tf.where(within, scores, tf.float32.max)

        att = tf.nn.softmax(-scores, axis=2) # (n_heads, L, L)

        value = tf.einsum("bnj,hjk->bhnk", inputs, self.weight) # (batch_size, n_head, L, v_dim)

        heads = tf.einsum("hnj,bhjd->bhnd", att, value) # (batch_size, n_head, L, v_dim)
        heads = tf.transpose(heads, (0,2,1,3)) # (batch_size, L, n_head, v_dim)
        merged = tf.keras.layers.Reshape((-1,self.hid_dim))(heads) # (batch_size, L, hid_dim)
        return tf.keras.activations.gelu(merged)

    def get_config(self):
        return {
            'm_dist': self.dist,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head,
            'locality': self.locality
        }
185 | ''' 186 | def __init__(self, m_cross, m_small, out_dim, hid_dim, n_head, locality_encoder, locality_decoder, YNormalizer): 187 | super(PiT, self).__init__() 188 | ''' 189 | m_cross: distance matrix between X_query and X_latent; (L_qry,L_ltt) 190 | m_small: distance matrix between X_latent and X_latent; (L_ltt,L_ltt) 191 | out_dim: number of variables 192 | hid_dim: encoding dimension (network width) 193 | n_head: number of heads in multi-head attention modules 194 | locality_encoder: quantile parameter of local position-attention in the Encoder, allowing to customize the size of receptive field 195 | locality_decoder: quantile parameter of local position-attention in the Decoder, allowing to customize the size of receptive field 196 | YNormalizer: PixelWiseNormalization object with mean and std of the training data 197 | ''' 198 | self.m_cross = m_cross 199 | self.m_small = m_small 200 | self.res = int(m_cross.shape[0]**0.5) 201 | self.out_dim = out_dim 202 | self.hid_dim = hid_dim 203 | self.n_head = n_head 204 | self.en_local = locality_encoder 205 | self.de_local = locality_decoder 206 | self.n_blocks = 4 207 | self.YNormalizer = YNormalizer 208 | 209 | # Encoder 210 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 211 | self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local) 212 | 213 | # Processor 214 | self.PA = [MultiHeadPosAtt(self.m_small, self.n_head, self.hid_dim, locality=200) for i in range(self.n_blocks)] 215 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 216 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 217 | 218 | # Decoder 219 | self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local) 220 | self.de_layer = mlp(self.hid_dim, self.out_dim) 221 | 222 | def call(self, inputs): 223 | 224 | # Encoder 225 | grid = 
class MultiHeadSelfAtt(tf.keras.layers.Layer):
    '''
    Scaled dot-product multi-head self-attention.

    Each head projects the input with its own query, key and value matrices,
    attends over the sequence axis, and the heads are concatenated back to
    hid_dim features before a GELU is applied.

    n_head: number of attention heads
    hid_dim: output width; must be a multiple of n_head

    Raises ValueError if hid_dim is not divisible by n_head.
    '''
    def __init__(self, n_head, hid_dim):
        super(MultiHeadSelfAtt, self).__init__()

        # call() reshapes the concatenated heads to (-1, hid_dim). With a rounded
        # per-head width, n_head*v_dim can differ from hid_dim, and that reshape
        # would then either fail or silently mix neighbouring sequence positions.
        if hid_dim % n_head != 0:
            raise ValueError(
                "hid_dim ({}) must be divisible by n_head ({})".format(hid_dim, n_head))

        self.hid_dim = hid_dim
        self.n_head = n_head
        self.v_dim = hid_dim // n_head

    def build(self, input_shape):
        # One (input_dim, v_dim) projection per head for queries, keys and values.
        # Creation order (query, key, value) is kept from the original layer.
        self.q = self.add_weight(
            shape=(self.n_head, input_shape[-1], self.v_dim),
            initializer="he_normal",
            trainable=True,
            name="query",
        )

        self.k = self.add_weight(
            shape=(self.n_head, input_shape[-1], self.v_dim),
            initializer="he_normal",
            trainable=True,
            name="key",
        )

        self.v = self.add_weight(
            shape=(self.n_head, input_shape[-1], self.v_dim),
            initializer="he_normal",
            trainable=True,
            name="value",
        )
        self.built = True

    def call(self, inputs):

        query = tf.einsum("bnj,hjk->bhnk", inputs, self.q) # (batch_size, n_head, L, v_dim)
        key = tf.einsum("bnj,hjk->bhnk", inputs, self.k) # (batch_size, n_head, L, v_dim)
        # softmax over the key axis of QK^T / sqrt(v_dim)
        att = tf.nn.softmax(tf.einsum("...ij,...kj->...ik", query, key)/self.v_dim**0.5, axis=-1) # (batch_size, n_heads, L, L)

        value = tf.einsum("bnj,hjk->bhnk", inputs, self.v) # (batch_size, n_head, L, v_dim)

        concat = tf.einsum("...nj,...jd->...nd", att, value) # (batch_size, n_head, L, v_dim)
        concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim)
        concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim)
        return tf.keras.activations.gelu(concat)

    def get_config(self):
        config = {
            'n_head': self.n_head,
            'hid_dim': self.hid_dim
        }
        return config
self.YNormalizer = YNormalizer 334 | self.n_blocks = 4 335 | 336 | # Encoder 337 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 338 | self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local) 339 | 340 | # Processor 341 | self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)] 342 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 343 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 344 | 345 | # Decoder 346 | self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local) 347 | self.de_layer = mlp(self.hid_dim, self.out_dim) 348 | 349 | def call(self, inputs): 350 | 351 | # Encoder 352 | grid = self.get_mesh(inputs) 353 | en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) 354 | en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) 355 | en = self.en_layer(en) 356 | x = self.down(en) 357 | 358 | # Processor 359 | for i in range(self.n_blocks): 360 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 361 | x = tf.keras.activations.gelu(x) 362 | 363 | # Decoder 364 | de = self.up(x) 365 | de = self.de_layer(de) 366 | de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de) 367 | return self.YNormalizer.denormalize(de) 368 | 369 | def get_mesh(self, inputs): 370 | size_x = size_y = self.res 371 | gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1)) 372 | gridx = tf.tile(gridx, [1,1,size_y,1]) 373 | gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1)) 374 | gridy = tf.tile(gridy, [1,size_x,1,1]) 375 | grid = tf.concat([gridx, gridy], axis=-1) 376 | return tf.tile(grid, [tf.shape(inputs)[0],1,1,1]) 377 | 378 | def get_config(self): 379 | config = { 380 | 'm_cross': self.m_cross, 381 | 'res': self.res, 382 | 'out_dim': self.out_dim, 383 | 'hid_dim': self.hid_dim, 384 
| 'n_head': self.n_head, 385 | 'locality_encoder': self.en_local, 386 | 'locality_decoder': self.de_local, 387 | 'YNormalizer': self.YNormalizer 388 | } 389 | return config 390 | 391 | class Transformer(tf.keras.Model): 392 | ''' 393 | Replace position-attention of a PiT with self-attention. 394 | ''' 395 | def __init__(self, res, out_dim, hid_dim, n_head, YNormalizer): 396 | super(Transformer, self).__init__() 397 | 398 | self.res = res 399 | self.out_dim = out_dim 400 | self.hid_dim = hid_dim 401 | self.n_head = n_head 402 | self.n_blocks = 4 403 | self.YNormalizer = YNormalizer 404 | 405 | # Encoder 406 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 407 | self.down = MultiHeadSelfAtt(self.n_head, self.hid_dim) 408 | 409 | # Processor 410 | self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)] 411 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 412 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 413 | 414 | # Decoder 415 | self.up = MultiHeadSelfAtt(self.n_head, self.hid_dim) 416 | self.de_layer = mlp(self.hid_dim, self.out_dim) 417 | 418 | def call(self, inputs): 419 | 420 | # Encoder 421 | grid = self.get_mesh(inputs) 422 | en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) 423 | en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) 424 | en = self.en_layer(en) 425 | x = self.down(en) 426 | 427 | # Processor 428 | for i in range(self.n_blocks): 429 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 430 | x = tf.keras.activations.gelu(x) 431 | 432 | # Decoder 433 | de = self.up(x) 434 | de = self.de_layer(de) 435 | de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de) 436 | return self.YNormalizer.denormalize(de) 437 | 438 | def get_mesh(self, inputs): 439 | size_x = size_y = self.res 440 | gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], 
import os
# These must be set before TensorFlow is imported (it is pulled in by `utils`).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['PYTHONHASHSEED'] = '1'

import matplotlib.pyplot as plt
from numpy import arange, mean, savetxt
from numpy.linalg import norm
from scipy.io import loadmat

from utils import rel_norm_traj

# Time index of the first predicted snapshot; the x axis of the error curve
# previously ran over the hard-coded range 11..30.
FIRST_STEP = 11


def save_field(field, path, vmin, vmax):
    '''Render one 2-D field as a borderless image with a fixed colour range and save it to `path`.'''
    plt.figure(figsize=(4,4),dpi=300)
    plt.axes([0,0,1,1])
    plt.imshow(field, vmax=vmax, vmin=vmin, interpolation='spline16', cmap='jet')
    plt.axis('off')
    plt.axis('equal')
    plt.savefig(path)
    plt.close()


# load the predictions written by train.py
data = loadmat('./pred.mat')
testY = data['true']
pred = data['pred']
print(rel_norm_traj(testY, pred).numpy())

# step-wise relative l2 error, averaged over the test samples
err = testY - pred
n_samples = testY.shape[0]  # was hard-coded to 200
T = testY.shape[-1]         # was hard-coded to 20
rel_err = norm(err.reshape(n_samples,-1,T), ord=2, axis=1) / norm(testY.reshape(n_samples,-1,T), ord=2, axis=1)
rel_err = mean(rel_err, axis=0)
plt.figure(figsize=(16,9),dpi=100)
plt.plot(arange(FIRST_STEP, FIRST_STEP+T), rel_err)
plt.xlabel('t')
plt.savefig('./rel_err.pdf')
plt.close()  # the figure used to stay open for the rest of the script
savetxt('./rel_err.csv', rel_err)

############## plots
index = 0
abs_err = abs(err[index,...])
emax = abs_err.max()
emin = abs_err.min()
print(emax, emin)
omega = testY[index,...]
vmax = omega.max()
vmin = omega.min()
print(vmax, vmin)
omega_p = pred[index,...]

for i in range(T):
    # truth and prediction share the colour range of the truth; the error has its own
    save_field(omega[...,i], './testY_{}.pdf'.format(i+1), vmin, vmax)
    save_field(omega_p[...,i], './pred_{}.pdf'.format(i+1), vmin, vmax)
    save_field(abs_err[...,i], './err_{}.pdf'.format(i+1), emin, emax)
processor 32 | network = PiT(m_cross, m_latent, out_dim, encode_dim, n_head, en_loc, de_loc) 33 | r_network = reccurent_PiT(network, steps) 34 | inputs = tf.keras.Input(shape=(64,64,trainX.shape[-1])) 35 | outputs = r_network(inputs) 36 | model = tf.keras.Model(inputs, outputs) 37 | r_network.summary() 38 | ### compile model 39 | model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=tf.keras.optimizers.schedules.CosineDecay(lr, n_epochs*(n_train//batch_size))), 40 | loss=rel_norm_step(steps), 41 | jit_compile=True) 42 | 43 | ### fit model 44 | start = time() 45 | train_history = model.fit(trainX, trainY, 46 | batch_size, n_epochs, verbose=1, 47 | validation_data=(testX, testY), 48 | validation_batch_size=20) 49 | end = time() 50 | print(' ') 51 | print('Training cost is {} seconds per epoch !'.format((end-start)/n_epochs)) 52 | 53 | print(' ') 54 | ### save model 55 | model.save_weights('./checkpoints/my_checkpoint') 56 | ### training history 57 | loss_hist = train_history.history 58 | train_loss = tf.reshape(tf.convert_to_tensor(loss_hist["loss"]), (-1,1)) 59 | test_loss = tf.reshape(tf.convert_to_tensor(loss_hist["val_loss"]), (-1,1)) 60 | savetxt('./training_history.csv', tf.concat([train_loss, test_loss], axis=-1).numpy(), delimiter=',', header='train,test', fmt='%1.16f', comments='') 61 | 62 | plt.plot(train_loss, label='train') 63 | plt.plot(test_loss, label='test') 64 | plt.legend() 65 | plt.yscale('log', base=10) 66 | plt.savefig('training_history.png') 67 | plt.close() 68 | 69 | ### evaluation, visualization 70 | from scipy.io import savemat 71 | pred = model.predict(testX) 72 | print(rel_norm_traj(testY,pred)) 73 | savemat("pred.mat", mdict={"pred":pred, "true":testY}) 74 | -------------------------------------------------------------------------------- /tensorflow/4_Vorticity/utils.py: -------------------------------------------------------------------------------- 1 | from scipy.io import loadmat 2 | import tensorflow as tf 3 | 
def pairwise_dist(res1x, res1y, res2x, res2y):
    '''
    Squared pairwise distances between the nodes of two uniform grids on the periodic unit square.

    res1x, res1y: number of nodes of the first grid along x and y
    res2x, res2y: number of nodes of the second grid along x and y

    Returns a float32 tensor of shape (res1x*res1y, res2x*res2y).

    The distance is the shortest one on the torus. It is obtained by wrapping the coordinate
    difference of each axis independently, which also covers the diagonal periodic images
    (shifts by (+-1, +-1)); the previous version only compared the four axis-aligned shifts and
    therefore over-estimated the distance between nodes close to opposite corners.
    '''
    def _flat_grid(resx, resy):
        # Node coordinates in [0, 1) (right end point dropped, it coincides with 0 by periodicity),
        # flattened to (resx*resy, 2) with the x index varying fastest.
        xs = tf.reshape(tf.linspace(0, 1, resx+1)[:-1], (1,-1,1))
        ys = tf.reshape(tf.linspace(0, 1, resy+1)[:-1], (-1,1,1))
        grid = tf.concat([tf.tile(xs, [resy,1,1]), tf.tile(ys, [1,resx,1])], axis=-1)
        return tf.reshape(grid, (resx*resy, 2))

    grid1 = _flat_grid(res1x, res1y)
    grid2 = _flat_grid(res2x, res2y)
    print(grid1.shape, grid2.shape)

    # Per-axis separation, broadcast to (L1, L2, 2); every entry lies in [0, 1).
    delta = tf.abs(tf.expand_dims(grid1, 1) - tf.expand_dims(grid2, 0))
    # On a circle of length 1 the shortest way between two points is min(d, 1-d).
    delta = tf.math.minimum(delta, 1.0 - delta)

    dist2 = tf.cast(tf.reduce_sum(delta**2, axis=-1), 'float32')
    return dist2
class reccurent_PiT(tf.keras.Model):
    '''
    Autoregressive roll-out of a one-step network.

    network: one-step model; called on the current window of snapshots (stacked along the last axis)
             and returning the next snapshot(s) along the same axis
    steps: number of roll-out steps
    kwargs: forwarded to tf.keras.Model (e.g. `name`), so that a config produced by get_config
            can be passed back to the constructor

    The output is the concatenation of the `steps` predictions along the last axis.
    '''
    def __init__(self, network, steps, **kwargs):
        super(reccurent_PiT, self).__init__(**kwargs)
        self.PiT = network
        self.steps = steps

    def call(self, inputs):
        x = inputs
        preds = []
        for t in range(self.steps):
            y = self.PiT(x)
            preds.append(y)
            # Slide the window: drop the oldest snapshot, append the new prediction.
            x = tf.concat([x[...,1:],y], axis=-1)
        if not preds:
            # steps == 0: keep returning an empty tensor along the last axis, as before.
            return inputs[...,:0]
        return tf.concat(preds, axis=-1)

    def get_config(self):
        config = super(reccurent_PiT, self).get_config()
        config.update({
            # Store how to rebuild the network (class name + constructor kwargs) instead of the network itself
            'network': {
                'class_name': type(self.PiT).__name__,
                'config': self.PiT.get_config(),
            },
            'steps': self.steps,
        })
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config) # do not mutate the caller's dict
        network_config = config.pop('network')
        if 'class_name' in network_config:
            class_name = network_config['class_name']
            network_config = network_config['config']
        else:
            # Older configs stored the bare kwargs of a PiT.
            class_name = 'PiT'
        # The wrapped network is a plain subclassed layer/model of this module whose get_config
        # returns constructor kwargs; tf.keras.Model.from_config cannot rebuild it.
        try:
            network_cls = globals()[class_name]
        except KeyError:
            raise ValueError("Unknown network class '{}'".format(class_name)) from None
        network = network_cls(**network_config)
        return cls(network=network, **config)
139 | ''' 140 | def __init__(self, m_dist, n_head, hid_dim, locality): 141 | super(MultiHeadPosAtt, self).__init__() 142 | ''' 143 | m_dist: distance matrix 144 | n_head: number of attention heads 145 | hid_dim: encoding dimension 146 | locality: quantile parameter to customize receptive field in position-attention 147 | ''' 148 | self.dist = m_dist 149 | self.locality = locality 150 | self.hid_dim = hid_dim 151 | self.n_head = n_head 152 | self.v_dim = round(self.hid_dim/self.n_head) 153 | 154 | def build(self, input_shape): 155 | 156 | self.r = self.add_weight( 157 | shape=(self.n_head, 1, 1), 158 | trainable=True, 159 | name="band_width", 160 | ) 161 | 162 | self.weight = self.add_weight( 163 | shape=(self.n_head, input_shape[-1], self.v_dim), 164 | initializer="he_normal", 165 | trainable=True, 166 | name="weight", 167 | ) 168 | self.built = True 169 | 170 | def call(self, inputs): 171 | scaled_dist = self.dist * tf.math.tan(0.25*pi*(1-1e-7)*(1.0+tf.math.sin(self.r)))# tan(0.25*pi*(1+sin(r))) leads to higher accuracy than using r^2 and tan(r) # (n_head, L, L) 172 | if self.locality <= 100: 173 | mask = tfp.stats.percentile(scaled_dist, self.locality, interpolation="linear", axis=-1, keepdims=True) 174 | scaled_dist = tf.where(scaled_dist<=mask, scaled_dist, tf.float32.max) 175 | else: 176 | pass 177 | scaled_dist = - scaled_dist # (n_head, L, L) 178 | att = tf.nn.softmax(scaled_dist, axis=2) # (n_heads, L, L) 179 | 180 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.weight) # (batch_size, n_head, L, v_dim) 181 | 182 | concat = tf.einsum("hnj,bhjd->bhnd", att, value) # (batch_size, n_head, L, v_dim) 183 | concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim) 184 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim) 185 | return tf.keras.activations.gelu(concat) 186 | 187 | def get_config(self): 188 | config = { 189 | 'm_dist': self.dist, 190 | 'hid_dim': self.hid_dim, 191 | 'n_head': self.n_head, 192 | 
class PiT(tf.keras.layers.Layer):
    '''
    Position-induced Transformer, built upon the multi-head position-attention mechanism.
    PiT can be trained to decompose and learn the global and local dependencies of operators in partial differential equations.

    m_cross: distance matrix between X_query and X_latent; (L_qry,L_ltt)
    m_ltt: distance matrix between X_latent and X_latent; (L_ltt,L_ltt)
    out_dim: number of variables
    hid_dim: encoding dimension (network width)
    n_head: number of heads in multi-head attention modules
    locality_encoder: quantile parameter of local position-attention in the Encoder, allowing to customize the size of the receptive field
    locality_decoder: quantile parameter of local position-attention in the Decoder, allowing to customize the size of the receptive field
    '''
    def __init__(self, m_cross, m_ltt, out_dim, hid_dim, n_head, locality_encoder, locality_decoder):
        super(PiT, self).__init__()

        self.m_cross = m_cross
        self.m_ltt = m_ltt

        # The query mesh is assumed to be square (res x res). Fail early otherwise instead of
        # truncating the square root and breaking later in the final reshape.
        n_query = int(m_cross.shape[0])
        self.res = int(round(n_query**0.5))
        if self.res*self.res != n_query:
            raise ValueError(
                "m_cross has {} query points, which is not a square number".format(n_query))

        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.en_local = locality_encoder
        self.de_local = locality_decoder
        self.n_blocks = 4 # number of position-attention modules in the Processor

        # Encoder
        self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal")
        self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local)

        # Processor (locality=200 is above 100, i.e. no quantile mask: global attention)
        self.MHPA = [MultiHeadPosAtt(self.m_ltt, self.n_head, self.hid_dim, locality=200) for i in range(self.n_blocks)]
        self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)]
        self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)]

        # Decoder
        self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local)
        self.de_layer = mlp(self.hid_dim, self.out_dim)

    def call(self, inputs):

        # Encoder
        grid = self.get_mesh(inputs) # (batch_size, res_qry, res_qry, 2)
        en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) # (batch_size, res_qry, res_qry, input_dim+2)
        en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) # (batch_size, L_qry, input_dim+2)
        en = self.en_layer(en) # (batch_size, L_qry, hid_dim)
        x = self.down(en) # (batch_size, L_ltt, hid_dim)

        # Processor
        for i in range(self.n_blocks):
            x = self.MLP[i](self.MHPA[i](x)) + self.W[i](x) # (batch_size, L_ltt, hid_dim)
            x = tf.keras.activations.gelu(x) # (batch_size, L_ltt, hid_dim)

        # Decoder
        de = self.up(x) # (batch_size, L_qry, hid_dim)
        de = self.de_layer(de) # (batch_size, L_qry, out_dim)
        de = tf.keras.layers.Reshape((self.res, self.res, self.out_dim))(de) # (batch_size, res_qry, res_qry, out_dim)
        return de

    def get_mesh(self, inputs):
        '''
        Coordinates of the query mesh, repeated for every sample of the batch; (batch_size, res, res, 2).
        The nodes are uniform in [0,1) along both axes (the right end point is dropped).
        '''
        size_x = size_y = self.res
        gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1))
        gridx = tf.tile(gridx, [1,1,size_y,1])
        gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1))
        gridy = tf.tile(gridy, [1,size_x,1,1])
        grid = tf.concat([gridx, gridy], axis=-1)
        return tf.tile(grid, [tf.shape(inputs)[0],1,1,1])

    def get_config(self):
        # Keys match the constructor parameter names.
        config = {
            'm_cross': self.m_cross,
            'm_ltt': self.m_ltt,
            'out_dim': self.out_dim,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head,
            'locality_encoder': self.en_local,
            'locality_decoder': self.de_local
        }
        return config
class LiteTransformer(tf.keras.Model):
    '''
    Replace position-attention of the Processor in a PiT with self-attention.

    m_cross: distance matrix between X_query and X_latent; (L_qry,L_ltt)
    res: number of query nodes along each axis of the (square) mesh
    out_dim: number of variables
    hid_dim: encoding dimension (network width)
    n_head: number of heads in the attention modules
    en_local: quantile parameter of local position-attention in the Encoder
    de_local: quantile parameter of local position-attention in the Decoder
    '''
    def __init__(self, m_cross, res, out_dim, hid_dim, n_head, en_local, de_local):
        super(LiteTransformer, self).__init__()

        self.m_cross = m_cross
        self.res = res
        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.en_local = en_local
        self.de_local = de_local
        self.n_blocks = 4 # number of self-attention modules in the Processor

        # Encoder
        self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal")
        self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local)

        # Processor
        self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)]
        self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)]
        self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)]

        # Decoder
        self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local)
        self.de_layer = mlp(self.hid_dim, self.out_dim)

    def call(self, inputs):

        # Encoder: append the node coordinates to the input channels and flatten the mesh
        grid = self.get_mesh(inputs) # (b, s1, s2, 2)
        en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1)
        en = tf.keras.layers.Reshape((-1, en.shape[3]))(en)
        en = self.en_layer(en)
        x = self.down(en)

        # Processor
        for i in range(self.n_blocks):
            x = self.MLP[i](self.PA[i](x)) + self.W[i](x)
            x = tf.keras.activations.gelu(x)

        # Decoder: back to the query mesh, then to the spatial layout of the input
        de = self.up(x)
        de = self.de_layer(de)
        de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de)
        return de

    def get_mesh(self, inputs):
        '''
        Coordinates of the query mesh, repeated for every sample of the batch; (batch_size, res, res, 2).
        The nodes are uniform in [0,1) along both axes (the right end point is dropped).
        '''
        size_x = size_y = self.res
        gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1))
        gridx = tf.tile(gridx, [1,1,size_y,1])
        gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1))
        gridy = tf.tile(gridy, [1,size_x,1,1])
        grid = tf.concat([gridx, gridy], axis=-1)
        return tf.tile(grid, [tf.shape(inputs)[0],1,1,1])

    def get_config(self):
        # Keys must be the constructor parameter names: the previous 'locality_encoder' /
        # 'locality_decoder' keys made cls(**config) fail with a TypeError.
        config = {
            'm_cross': self.m_cross,
            'res': self.res,
            'out_dim': self.out_dim,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head,
            'en_local': self.en_local,
            'de_local': self.de_local,
        }
        return config
gridx = tf.tile(gridx, [1,1,size_y,1]) 451 | gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1)) 452 | gridy = tf.tile(gridy, [1,size_x,1,1]) 453 | grid = tf.concat([gridx, gridy], axis=-1) 454 | return tf.tile(grid, [tf.shape(inputs)[0],1,1,1]) 455 | 456 | def get_config(self): 457 | config = { 458 | 'res':self.res, 459 | 'out_dim': self.out_dim, 460 | 'hid_dim': self.hid_dim, 461 | 'n_head': self.n_head 462 | } 463 | return config 464 | 465 | 466 | 467 | -------------------------------------------------------------------------------- /tensorflow/5_Elasticity/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 4 | os.environ['PYTHONHASHSEED'] = '1' 5 | import matplotlib.pyplot as plt 6 | from time import time 7 | from utils import * 8 | 9 | # params ##################################33 10 | encode_dim = 512 11 | out_dim = 1 12 | n_head = 8 13 | n_train = 1000 14 | n_test = 200 15 | 16 | # load dataset 17 | trainX, trainY, testX, testY= load_data("./", n_train, n_test) 18 | print(trainX.shape, trainY.shape, testX.shape, testY.shape) 19 | 20 | # define a model 21 | network = PiT(out_dim, encode_dim, n_head, 2, 2) 22 | inputs = tf.keras.Input(shape=trainX.shape[1:]) 23 | outputs = network(inputs) 24 | model = tf.keras.Model(inputs, outputs) 25 | network.summary() 26 | 27 | 28 | ### load model 29 | model.load_weights('./512/checkpoints/my_checkpoint').expect_partial() 30 | 31 | ### evaluation, visualization 32 | index = 89 33 | pred = model.predict(testX) 34 | print(rel_norm()(testY, pred)) 35 | pred = pred[index:index+1,...] 36 | true = testY[index:index+1,...] 
import os
# Environment switches have to be in place before TensorFlow is imported (via `utils`).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['PYTHONHASHSEED'] = '1'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import matplotlib.pyplot as plt
from time import time
from utils import *

# params ##################################
n_epochs = 500
lr = 0.001
batch_size = 10
encode_dim = 512
out_dim = 1
n_head = 8
n_train = 1000
n_test = 200

# load dataset
trainX, trainY, testX, testY = load_data("./", n_train, n_test)
print(trainX.shape, trainY.shape, testX.shape, testY.shape)

# define a model
network = PiT(out_dim, encode_dim, n_head, 2, 2)
inputs = tf.keras.Input(shape=trainX.shape[1:])
model = tf.keras.Model(inputs, network(inputs))
network.summary()

### compile model
# cosine decay of the learning rate over the total number of optimizer steps
total_steps = n_epochs*(n_train//batch_size)
schedule = tf.keras.optimizers.schedules.CosineDecay(lr, total_steps)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=schedule),
              loss=rel_norm(),
              jit_compile=True)

### fit model
start = time()
train_history = model.fit(trainX, trainY,
                          batch_size=batch_size, epochs=n_epochs, verbose=1,
                          validation_data=(testX, testY),
                          validation_batch_size=20)
elapsed = time() - start
print(' ')
print('Training cost is {} seconds per epoch !'.format(elapsed/n_epochs))
print(' ')

### save model
model.save_weights('./checkpoints/my_checkpoint')

### training history
loss_hist = train_history.history
train_loss = tf.reshape(tf.convert_to_tensor(loss_hist["loss"]), (-1,1))
test_loss = tf.reshape(tf.convert_to_tensor(loss_hist["val_loss"]), (-1,1))
both_losses = tf.concat([train_loss, test_loss], axis=-1)
np.savetxt('./training_history.csv', both_losses, delimiter=',', header='train,test', fmt='%1.16f', comments='')

for curve, label in ((train_loss, 'train'), (test_loss, 'test')):
    plt.plot(curve, label=label)
plt.legend()
plt.yscale('log', base=10)
plt.savefig('training_history.png')
plt.close()

### evaluation, visualization
index = 2
pred = model.predict(testX)
print(rel_norm()(testY, pred))
pred = pred[index:index+1,...]
true = testY[index:index+1,...]

err = abs(true-pred)
emax = err.max()
emin = err.min()
vmax = true.max()
vmin = true.min()
print(vmax, vmin, emax, emin)

# the first two input channels are the point coordinates
x = testX[index,:,:1]
y = testX[index,:,1:2]

# prediction and truth share the colour range of the truth; the error uses its own
panels = (
    (pred, vmin, vmax, "pred.pdf"),
    (true, vmin, vmax, "true.pdf"),
    (err, emin, emax, "error.pdf"),
)
for values, lo, hi, out_file in panels:
    plt.figure(figsize=(6,6),dpi=300)
    plt.axes([0,0,1,1])
    plt.scatter(x, y, c=values, cmap="jet", s=160, vmin=lo, vmax=hi)
    plt.axis("off")
    plt.axis("equal")
    plt.savefig(out_file)
    plt.close()
def load_data(path, ntrain, ntest):
    '''
    Load the elasticity dataset and split it into training and test sets.

    path: prefix prepended verbatim to the three file names (no separator is
          inserted, so a directory must end with "/").
    ntrain: number of leading samples returned as the training set.
    ntest: number of trailing samples returned as the test set.

    Returns (trainX, trainY, testX, testY) as float32 arrays with
    X: (n_samples, n_points, 2 + n_params), point coordinates followed by the
       rescaled per-sample parameter vector repeated at every point;
    Y: (n_samples, n_points, 1).
    '''
    R = np.transpose(np.load(path + "Random_UnitCell_rr_10.npy"), (1,0))[:,np.newaxis,:] # e.g. (2000,1,42)
    X = np.transpose(np.load(path + "Random_UnitCell_XY_10.npy"), (2,0,1)) # e.g. (2000,972,2)
    R = np.repeat(5*R-1, X.shape[1], 1) # affine rescaling, then one copy per point: e.g. (2000,972,42)
    X = np.concatenate((X,R), axis=-1) # e.g. (2000,972,44)
    Y = np.transpose(np.load(path + "Random_UnitCell_sigma_10.npy"), (1,0))[...,np.newaxis]

    def split(arr):
        # arr[-ntest:] would return the whole array for ntest == 0, because
        # -0 == 0; an explicit start index yields an empty test set instead.
        test_start = max(arr.shape[0] - ntest, 0)
        return arr[:ntrain,...].astype("float32"), arr[test_start:,...].astype("float32")

    trainX, testX = split(X)
    trainY, testY = split(Y)
    return trainX, trainY, testX, testY
class PiT(tf.keras.Model):
    '''
    Position-induced Transformer, built upon the multi-head position-attention mechanism.
    PiT can be trained to decompose and learn the global and local dependencies of operators in partial differential equations.

    out_dim: number of variables
    hid_dim: encoding dimension (network width)
    n_head: number of heads in multi-head attention modules
    locality_encoder: quantile parameter of local position-attention in the Encoder, allowing to customize the size of receptive field
    locality_decoder: quantile parameter of local position-attention in the Decoder, allowing to customize the size of receptive field
    '''
    def __init__(self, out_dim, hid_dim, n_head, locality_encoder, locality_decoder):
        super().__init__()

        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.en_local = locality_encoder
        self.de_local = locality_decoder
        self.n_blocks = 4 # number of position-attention modules in the Processor

        width, heads, depth = self.hid_dim, self.n_head, self.n_blocks
        he = "he_normal"

        # Encoder
        self.en_layer = tf.keras.layers.Dense(width, activation="gelu", kernel_initializer=he)
        self.down = MultiHeadPosAtt(heads, width, locality=self.en_local)
        self.mlp1 = mlp(width, width)
        self.w1 = tf.keras.layers.Dense(width, kernel_initializer=he)

        # Processor; locality=200 is above the 100 threshold at which
        # MultiHeadPosAtt applies its percentile mask, so no mask is used here
        self.PA = [MultiHeadPosAtt(heads, width, locality=200) for _ in range(depth)]
        self.MLP = [mlp(width, width) for _ in range(depth)]
        self.W = [tf.keras.layers.Dense(width, kernel_initializer=he) for _ in range(depth)]

        # Decoder
        self.up = MultiHeadPosAtt(heads, width, locality=self.de_local)
        self.mlp2 = mlp(width, width)
        self.w2 = tf.keras.layers.Dense(width, kernel_initializer=he)
        self.de_layer = mlp(width, self.out_dim)

    def call(self, inputs):
        '''
        inputs: (batch_size, L_qry, d); the first two channels are used as point coordinates.
        returns: (batch_size, L_qry, out_dim)
        '''
        gelu = tf.keras.activations.gelu
        m_dist = self.pairwise_dist(inputs[...,:2])

        # Encoder
        encoded = self.en_layer(inputs) # (batch_size, L_qry, hid_dim)
        attended = self.down(m_dist, encoded)
        x = gelu(self.mlp1(attended) + self.w1(encoded)) # (batch_size, L_qry, hid_dim)

        # Processor
        for i in range(self.n_blocks):
            attended = self.PA[i](m_dist, x)
            x = gelu(self.MLP[i](attended) + self.W[i](x)) # (batch_size, L_qry, hid_dim)

        # Decoder
        attended = self.up(m_dist, x)
        de = gelu(self.mlp2(attended) + self.w2(x)) # (batch_size, L_qry, hid_dim)
        return self.de_layer(de) # (batch_size, L_qry, out_dim)

    def pairwise_dist(self, mesh):
        '''
        Half of the squared Euclidean distance between every pair of points in each sample.

        mesh: (batch_size, L, 2)
        returns: (batch_size, L, L)
        '''
        points = tf.expand_dims(mesh, axis=1) # (batch_size, 1, L, 2)
        offsets = points - tf.transpose(points, (0,2,1,3)) # (batch_size, L, L, 2)
        distances = tf.norm(offsets, ord=2, axis=-1)
        return distances**2 / 2.0

    def get_config(self):
        return {
            'out_dim': self.out_dim,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head,
            'locality_encoder': self.en_local,
            'locality_decoder': self.de_local
        }
class LiteTransformer(tf.keras.Model):
    '''
    Replace position-attention of the Processor in a PiT with self-attention.
    The Encoder and Decoder keep (local) position-attention.

    out_dim: number of variables
    hid_dim: encoding dimension (network width)
    n_head: number of heads in multi-head attention modules
    en_local: quantile parameter of local position-attention in the Encoder
    de_local: quantile parameter of local position-attention in the Decoder
    '''
    def __init__(self, out_dim, hid_dim, n_head, en_local, de_local):
        super(LiteTransformer, self).__init__()

        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.en_local = en_local
        self.de_local = de_local
        self.n_blocks = 4 # number of self-attention modules in the Processor

        # Encoder
        self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal")
        self.down = MultiHeadPosAtt(self.n_head, self.hid_dim, self.en_local)
        self.mlp1 = mlp(self.hid_dim, self.hid_dim)
        self.w1 = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal")

        # Processor
        self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)]
        self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)]
        self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)]

        # Decoder
        self.up = MultiHeadPosAtt(self.n_head, self.hid_dim, self.de_local)
        self.mlp2 = mlp(self.hid_dim, self.hid_dim)
        self.w2 = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal")
        self.de_layer = mlp(self.hid_dim, self.out_dim)

    def call(self, m_dist, inputs=None):
        '''
        Supported call forms:
            model(inputs)          -- same as PiT; works with fit/predict
            model(m_dist, inputs)  -- legacy form; m_dist is ignored

        inputs: (batch_size, L_qry, d); the first two channels are used as point coordinates.
        returns: (batch_size, L_qry, out_dim)
        '''
        # The distance matrix is always recomputed from the coordinates, so a
        # single positional argument is the input tensor itself. Requiring two
        # arguments made model(x) fail with a missing-argument TypeError.
        if inputs is None:
            inputs = m_dist
        m_dist = self.pairwise_dist(inputs[...,:2])

        # Encoder
        en = self.en_layer(inputs) # (batch_size, L_qry, hid_dim)
        x = self.mlp1(self.down(m_dist, en)) + self.w1(en)
        x = tf.keras.activations.gelu(x)

        # Processor
        for i in range(self.n_blocks):
            x = self.MLP[i](self.PA[i](x)) + self.W[i](x)
            x = tf.keras.activations.gelu(x)

        # Decoder
        de = self.mlp2(self.up(m_dist, x)) + self.w2(x)
        de = tf.keras.activations.gelu(de)
        de = self.de_layer(de) # (batch_size, L_qry, out_dim)

        return de

    def pairwise_dist(self, mesh):
        '''
        Half of the squared Euclidean distance between every pair of points in each sample.

        mesh: (batch_size, L, 2)
        returns: (batch_size, L, L)
        '''
        mesh = tf.expand_dims(mesh, axis=1) # (batch_size, 1, L, 2)
        pairwise_diff = mesh - tf.transpose(mesh, (0,2,1,3)) # (batch_size, L, L, 2)
        pairwise_dist = tf.norm(pairwise_diff, ord=2, axis=-1)
        return pairwise_dist**2 / 2.0

    def get_config(self):
        config = {
            'out_dim': self.out_dim,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head,
            'en_local':self.en_local,
            'de_local':self.de_local
        }
        return config
316 | ''' 317 | def __init__(self, out_dim, hid_dim, n_head): 318 | super(Transformer, self).__init__() 319 | 320 | self.out_dim = out_dim 321 | self.hid_dim = hid_dim 322 | self.n_head = n_head 323 | self.n_blocks = 4 324 | 325 | # Encoder 326 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 327 | self.down = MultiHeadSelfAtt(self.n_head, self.hid_dim) 328 | self.mlp1 = mlp(self.hid_dim, self.hid_dim) 329 | self.w1 = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 330 | 331 | # Processor 332 | self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)] 333 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 334 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 335 | 336 | # Decoder 337 | self.up = MultiHeadSelfAtt(self.n_head, self.hid_dim) 338 | self.mlp2 = mlp(self.hid_dim, self.hid_dim) 339 | self.w2 = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 340 | self.de_layer = mlp(self.hid_dim, self.out_dim) 341 | 342 | def call(self, inputs): 343 | 344 | # Encoder 345 | en = self.en_layer(inputs) 346 | x = self.mlp1(self.down(en)) + self.w1(en) 347 | x = tf.keras.activations.gelu(x) 348 | 349 | # Processor 350 | for i in range(self.n_blocks): 351 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 352 | x = tf.keras.activations.gelu(x) 353 | 354 | # Decoder 355 | de = self.mlp2(self.up(x)) + self.w2(x) 356 | de = tf.keras.activations.gelu(de) 357 | de = self.de_layer(de) 358 | 359 | return de 360 | 361 | def get_config(self): 362 | config = { 363 | 'out_dim': self.out_dim, 364 | 'hid_dim': self.hid_dim, 365 | 'n_head': self.n_head, 366 | } 367 | return config 368 | 369 | -------------------------------------------------------------------------------- /tensorflow/6_NACA/evaluate.py: -------------------------------------------------------------------------------- 1 | 
import os
# Set before TensorFlow is imported (it is imported through utils below).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['PYTHONHASHSEED'] = '1'

from numpy import savetxt
import matplotlib.pyplot as plt
from time import time

### custom imports
from utils import *

# params
encode_dim = 256
out_dim = 1
n_head = 2
n_train = 1000
n_test = 200

# load dataset
trainX, trainY, testX, testY = load_data("./", ntrain=n_train, ntest=n_test)
print(trainX.shape, trainY.shape, testX.shape, testY.shape)

# define a model
m_cross = pairwise_dist(51, 221, 26, 111) # pairwise distance matrix for encoder and decoder
m_latent = pairwise_dist(26, 111, 26, 111) # pairwise distance matrix for processor
network = PiT(m_cross, m_latent, out_dim, encode_dim, n_head, 0.5, 2) # locality_encoder=0.5, locality_decoder=2
inputs = tf.keras.Input(shape=(221,51,trainX.shape[-1]))
outputs = network(inputs)
model = tf.keras.Model(inputs, outputs)
network.summary()
print(' ')

### load model
model.load_weights('./256_0.5_2_selected/checkpoints/my_checkpoint').expect_partial()

### evaluation, visualization
# One batched pass over the test set. The original discarded this result and
# ran the whole test set through model(testX) again as a single batch.
pred_all = model.predict(testX)

index = -67
true = testY[index,40:-40,:20,0].reshape(-1,1)
pred = pred_all[index,40:-40,:20,0].reshape(-1,1)
err = abs(true-pred)
emax = err.max()
emin = err.min()
vmax = true.max()
vmin = true.min()
print(vmax, vmin, emax, emin)

x = testX[index,40:-40,:20,0].reshape(-1,1)
y = testX[index,40:-40,:20,1].reshape(-1,1)
print(x.max(), x.min(), y.max(), y.min())


def _save_scatter(filename, **scatter_kwargs):
    '''Scatter the selected mesh points with the given colouring and save the figure to filename.'''
    plt.figure(figsize=(12,12),dpi=100)
    plt.scatter(x, y, s=160, **scatter_kwargs)
    plt.ylim(-0.5,0.5)
    plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False)
    plt.tight_layout(pad=0)
    plt.savefig(filename)
    plt.close()


_save_scatter("pred.pdf", c=pred, cmap="jet")
_save_scatter("true.pdf", c=true, cmap="jet")
_save_scatter("error.pdf", c=err, cmap="jet")
_save_scatter("points.pdf", color="black")
pairwise distance matrix for processor 32 | network = PiT(m_cross, m_latent, out_dim, encode_dim, n_head, en_loc, de_loc) 33 | 34 | inputs = tf.keras.Input(shape=(221,51,trainX.shape[-1])) 35 | outputs = network(inputs) 36 | model = tf.keras.Model(inputs, outputs) 37 | model.summary() 38 | ### compile model 39 | model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=tf.keras.optimizers.schedules.CosineDecay(lr, n_epochs*(n_train//batch_size))), 40 | loss=rel_norm(), 41 | jit_compile=True) 42 | 43 | ### fit model 44 | start = time() 45 | train_history = model.fit(trainX, trainY, 46 | batch_size, n_epochs, verbose=1, 47 | validation_data=(testX, testY), 48 | validation_batch_size=100) 49 | end = time() 50 | print(' ') 51 | print('Training cost is {} seconds per epoch !'.format((end-start)/n_epochs)) 52 | 53 | print(' ') 54 | ### save model 55 | model.save_weights('./checkpoints/my_checkpoint') 56 | ### training history 57 | loss_hist = train_history.history 58 | train_loss = tf.reshape(tf.convert_to_tensor(loss_hist["loss"]), (-1,1)) 59 | test_loss = tf.reshape(tf.convert_to_tensor(loss_hist["val_loss"]), (-1,1)) 60 | savetxt('./training_history.csv', tf.concat([train_loss, test_loss], axis=-1).numpy(), delimiter=',', header='train,test', fmt='%1.16f', comments='') 61 | 62 | plt.plot(train_loss, label='train') 63 | plt.plot(test_loss, label='test') 64 | plt.legend() 65 | plt.yscale('log', base=10) 66 | plt.savefig('training_history.png') 67 | plt.close() 68 | 69 | ### evaluation, visualization 70 | pred = model.predict(testX) 71 | 72 | index = -2 73 | true = testY[index,40:-40,:20,0].reshape(-1,1) 74 | pred = model(testX)[index,40:-40,:20,0].numpy().reshape(-1,1) 75 | err = abs(true-pred) 76 | emax = err.max() 77 | emin = err.min() 78 | vmax = true.max() 79 | vmin = true.min() 80 | print(vmax, vmin, emax, emin) 81 | 82 | x = testX[index,40:-40,:20,0].reshape(-1,1) 83 | y = testX[index,40:-40,:20,1].reshape(-1,1) 84 | print(x.max(), x.min(), y.max(), y.min()) 
85 | 86 | plt.figure(figsize=(12,12),dpi=100) 87 | plt.scatter(x, y, c=pred, cmap="jet", s=160) 88 | plt.ylim(-0.5,0.5) 89 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 90 | plt.tight_layout(pad=0) 91 | plt.savefig("pred.pdf") 92 | plt.close() 93 | 94 | plt.figure(figsize=(12,12),dpi=100) 95 | plt.scatter(x, y, c=true, cmap="jet", s=160) 96 | plt.ylim(-0.5,0.5) 97 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 98 | plt.tight_layout(pad=0) 99 | plt.savefig("true.pdf") 100 | plt.close() 101 | 102 | plt.figure(figsize=(12,12),dpi=100) 103 | plt.scatter(x, y, c=err, cmap="jet", s=160) 104 | plt.ylim(-0.5,0.5) 105 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 106 | plt.tight_layout(pad=0) 107 | plt.savefig("error.pdf") 108 | plt.close() 109 | 110 | plt.figure(figsize=(12,12),dpi=100) 111 | plt.scatter(x, y, color="black", s=160) 112 | plt.ylim(-0.5,0.5) 113 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 114 | plt.tight_layout(pad=0) 115 | plt.savefig("points.pdf") 116 | plt.close() 117 | -------------------------------------------------------------------------------- /tensorflow/6_NACA/utils.py: -------------------------------------------------------------------------------- 1 | from numpy import load, concatenate, newaxis, zeros 2 | import tensorflow as tf 3 | tf.keras.utils.set_random_seed(0) 4 | tf.config.experimental.enable_op_determinism() 5 | tf.random.set_seed(0) 6 | physical_devices = tf.config.list_physical_devices('GPU') 7 | tf.config.set_visible_devices(physical_devices[0:],'GPU') 8 | import tensorflow_probability as tfp 9 | 10 | class rel_norm(tf.keras.losses.Loss): 11 | ''' 12 | Compute the average relative l2 loss between a batch of targets and predictions 13 | ''' 14 | def __init__(self): 15 | super().__init__() 16 | 
def load_data(path, ntrain, ntest):
    '''
    Load the NACA dataset and split it into training and test sets.

    path: prefix prepended verbatim to the three file names.
    ntrain: number of leading samples returned as the training set.
    ntest: number of samples, taken right after the training ones, returned as the test set.

    Returns (trainX, trainY, testX, testY) as float32 arrays; X stacks the two
    vertex-coordinate arrays on a new last axis, Y is channel 4 of the Q array
    with a trailing singleton axis.
    '''
    coord_files = ("NACA_Cylinder_X.npy", "NACA_Cylinder_Y.npy")
    coords = tuple(load(path + fname)[...,newaxis] for fname in coord_files)
    X = concatenate(coords, -1).astype("float32")

    Y = load(path + "NACA_Cylinder_Q.npy")[:,4,...][...,newaxis].astype("float32")

    train = slice(None, ntrain)
    test = slice(ntrain, ntrain+ntest)
    return X[train,...], Y[train,...], X[test,...], Y[test,...]
class MultiHeadPosAtt(tf.keras.layers.Layer):
    '''
    Global, local and cross variants of the multi-head position-attention mechanism.
    The attention weights are computed from a fixed distance matrix and a
    learnable per-head band width only; the input features enter through the
    value projection.

    m_dist: distance matrix, (L_out, L_in); row n holds the distances from output point n to every input point
    n_head: number of attention heads
    hid_dim: encoding dimension; must be a multiple of n_head
    locality: percentile (<= 100) of each row's distances kept in the attention; any value above 100 disables the mask
    '''
    def __init__(self, m_dist, n_head, hid_dim, locality):
        super(MultiHeadPosAtt, self).__init__()

        # round(hid_dim/n_head) accepted non-divisible widths; the final
        # Reshape((-1, hid_dim)) could then silently change the sequence
        # length instead of failing, so reject such widths up front.
        if n_head < 1 or hid_dim % n_head != 0:
            raise ValueError('hid_dim ({}) must be a positive multiple of n_head ({}).'.format(hid_dim, n_head))

        self.dist = m_dist
        self.locality = locality
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.v_dim = int(self.hid_dim // self.n_head)

    def build(self, input_shape):

        # per-head band width, kept non-negative by the constraint
        self.r = self.add_weight(
            shape=(self.n_head, 1, 1),
            trainable=True,
            constraint=tf.keras.constraints.NonNeg(),
            name="band_width",
        )

        # per-head value projection
        self.weight = self.add_weight(
            shape=(self.n_head, input_shape[-1], self.v_dim),
            initializer="he_normal",
            trainable=True,
            name="weight",
        )
        self.built = True

    def call(self, inputs):
        '''
        inputs: (batch_size, L_in, input_dim)
        returns: (batch_size, L_out, hid_dim)
        '''
        scaled_dist = self.dist * tf.math.tan(self.r) # (n_head, L_out, L_in)
        if self.locality <= 100:
            # Per output point, keep only the inputs whose scaled distance is
            # within the requested percentile; the others get the largest
            # float32 so that their softmax weight vanishes after negation.
            mask = tfp.stats.percentile(scaled_dist, self.locality, interpolation="linear", axis=-1, keepdims=True)
            scaled_dist = tf.where(scaled_dist<=mask, scaled_dist, tf.float32.max)
        scaled_dist = - scaled_dist
        att = tf.nn.softmax(scaled_dist, axis=2) # (n_head, L_out, L_in)

        value = tf.einsum("bnj,hjk->bhnk", inputs, self.weight) # (batch_size, n_head, L_in, v_dim)

        concat = tf.einsum("hnj,bhjd->bhnd", att, value) # (batch_size, n_head, L_out, v_dim)
        concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L_out, n_head, v_dim)
        concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L_out, hid_dim)
        return tf.keras.activations.gelu(concat)

    def get_config(self):
        config = {
            'm_dist': self.dist,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head,
            'locality': self.locality
        }
        return config
tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 167 | self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local) 168 | 169 | # Processor 170 | self.PA = [MultiHeadPosAtt(self.m_small, self.n_head, self.hid_dim, locality=200) for i in range(self.n_blocks)] 171 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 172 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 173 | 174 | # Decoder 175 | self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local) 176 | self.de_layer = mlp(self.hid_dim, self.out_dim) 177 | 178 | def call(self, inputs): 179 | 180 | # Encoder 181 | grid = self.get_mesh(inputs) #(batch_size, res_qry1, res_qry2, 2) 182 | en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) # (batch_size, res_qry, res_qry, input_dim+2) 183 | en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) # (batch_size, L_qry, hid_dim) 184 | en = self.en_layer(en) # (batch_size, L_qry, hid_dim) 185 | x = self.down(en) # (batch_size, L_ltt, hid_dim) 186 | 187 | # Processor 188 | for i in range(self.n_blocks): 189 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) # (batch_size, L_ltt, hid_dim) 190 | x = tf.keras.activations.gelu(x) # (batch_size, L_ltt, hid_dim) 191 | 192 | # Decoder 193 | de = self.up(x) # (batch_size, L_qry, hid_dim) 194 | de = self.de_layer(de) # (batch_size, L_qry, hid_dim) 195 | de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de) # (batch_size, res_qry1, res_qry2, out_dim) 196 | return de 197 | 198 | def get_mesh(self, inputs): 199 | size_x, size_y = tf.shape(inputs)[1], tf.shape(inputs)[2] 200 | gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1)) 201 | gridx = tf.tile(gridx, [1,1,size_y,1]) 202 | gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1)) 203 | gridy = tf.tile(gridy, [1,size_x,1,1]) 
204 | grid = tf.concat([gridx, gridy], axis=-1) 205 | return tf.tile(grid, [tf.shape(inputs)[0],1,1,1]) 206 | 207 | def get_config(self): 208 | config = { 209 | 'm_cross': self.m_cross, 210 | 'm_small': self.m_small, 211 | 'out_dim': self.out_dim, 212 | 'hid_dim': self.hid_dim, 213 | 'n_head': self.n_head, 214 | 'locality_encoder': self.en_local, 215 | 'locality_decoder': self.de_local 216 | } 217 | return config 218 | 219 | class MultiHeadSelfAtt(tf.keras.layers.Layer): 220 | ''' 221 | Scaled dot-product multi-head self-attention 222 | ''' 223 | def __init__(self, n_head, hid_dim): 224 | super(MultiHeadSelfAtt, self).__init__() 225 | 226 | self.hid_dim = hid_dim 227 | self.n_head = n_head 228 | self.v_dim = round(self.hid_dim/self.n_head) 229 | 230 | def build(self, input_shape): 231 | 232 | self.q = self.add_weight( 233 | shape=(self.n_head, input_shape[-1], self.v_dim), 234 | initializer="he_normal", 235 | trainable=True, 236 | name="query", 237 | ) 238 | 239 | self.k = self.add_weight( 240 | shape=(self.n_head, input_shape[-1], self.v_dim), 241 | initializer="he_normal", 242 | trainable=True, 243 | name="key", 244 | ) 245 | 246 | self.v = self.add_weight( 247 | shape=(self.n_head, input_shape[-1], self.v_dim), 248 | initializer="he_normal", 249 | trainable=True, 250 | name="value", 251 | ) 252 | self.built = True 253 | 254 | def call(self, inputs): 255 | 256 | query = tf.einsum("bnj,hjk->bhnk", inputs, self.q)#(batch_size, n_head, L, v_dim) 257 | key = tf.einsum("bnj,hjk->bhnk", inputs, self.k)#(batch_size, n_head, L ,v_dim) 258 | att = tf.nn.softmax(tf.einsum("...ij,...kj->...ik", query, key)/self.v_dim**0.5, axis=-1)#(batch_size, n_heads, L, L) 259 | 260 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.v)#(batch_size, n_head, L, v_dim) 261 | 262 | concat = tf.einsum("...nj,...jd->...nd", att, value)#(batch_size, n_head, L, v_dim) 263 | concat = tf.transpose(concat, (0,2,1,3)) 264 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) 265 | return 
class LiteTransformer(tf.keras.Model):
    '''
    Replace position-attention of the Processor in a PiT with self-attention.
    The Encoder and Decoder keep cross position-attention between the query
    grid and the latent grid.

    m_cross: distance matrix between X_query and X_latent; (L_qry,L_ltt)
    out_dim: number of variables
    hid_dim: encoding dimension (network width)
    n_head: number of heads in multi-head attention modules
    en_local: quantile parameter of local position-attention in the Encoder
    de_local: quantile parameter of local position-attention in the Decoder
    '''
    def __init__(self, m_cross, out_dim, hid_dim, n_head, en_local, de_local):
        super(LiteTransformer, self).__init__()

        self.m_cross = m_cross
        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.en_local = en_local
        self.de_local = de_local
        self.n_blocks = 4 # number of self-attention modules in the Processor

        # Encoder; the transposed matrix lets each latent point attend over the query points
        self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal")
        self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local)

        # Processor
        self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)]
        self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)]
        self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)]

        # Decoder
        self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local)
        self.de_layer = mlp(self.hid_dim, self.out_dim)

    def call(self, inputs):
        '''
        inputs: (batch_size, res_qry1, res_qry2, input_dim)
        returns: (batch_size, res_qry1, res_qry2, out_dim)
        '''
        # Encoder
        grid = self.get_mesh(inputs) # (batch_size, res_qry1, res_qry2, 2)
        en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) # (batch_size, res_qry1, res_qry2, input_dim+2)
        en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) # (batch_size, L_qry, input_dim+2)
        en = self.en_layer(en) # (batch_size, L_qry, hid_dim)
        x = self.down(en) # (batch_size, L_ltt, hid_dim)

        # Processor
        for i in range(self.n_blocks):
            x = self.MLP[i](self.PA[i](x)) + self.W[i](x) # (batch_size, L_ltt, hid_dim)
            x = tf.keras.activations.gelu(x)

        # Decoder
        de = self.up(x) # (batch_size, L_qry, hid_dim)
        de = self.de_layer(de) # (batch_size, L_qry, out_dim)
        de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de)
        return de

    def get_mesh(self, inputs):
        '''
        Uniform grid coordinates in [0,1) along the two spatial axes of inputs, repeated for every sample.

        returns: (batch_size, res_qry1, res_qry2, 2)
        '''
        size_x, size_y = tf.shape(inputs)[1], tf.shape(inputs)[2]
        gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1)) # drop the end point 1
        gridx = tf.tile(gridx, [1,1,size_y,1])
        gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1))
        gridy = tf.tile(gridy, [1,size_x,1,1])
        grid = tf.concat([gridx, gridy], axis=-1)
        return tf.tile(grid, [tf.shape(inputs)[0],1,1,1])

    def get_config(self):
        config = {
            'm_cross': self.m_cross,
            'out_dim': self.out_dim,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head,
            'locality_encoder': self.en_local,
            'locality_decoder': self.de_local,
        }
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        '''
        Rebuild the model from the dictionary returned by get_config().
        '''
        # get_config() reports the locality parameters under the PiT-style
        # names, which the constructor does not accept; map them back so the
        # config round-trips. Configs already using the constructor names
        # pass through unchanged.
        config = dict(config)
        for reported, accepted in (('locality_encoder', 'en_local'), ('locality_decoder', 'de_local')):
            if reported in config:
                config[accepted] = config.pop(reported)
        return cls(**config)
for i in range(self.n_blocks): 378 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 379 | x = tf.keras.activations.gelu(x) 380 | 381 | # Decoder 382 | de = self.up(x) 383 | de = self.de_layer(de) 384 | de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de) 385 | return de 386 | 387 | def get_mesh(self, inputs): 388 | size_x, size_y = tf.shape(inputs)[1], tf.shape(inputs)[2] 389 | gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1)) 390 | gridx = tf.tile(gridx, [1,1,size_y,1]) 391 | gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1)) 392 | gridy = tf.tile(gridy, [1,size_x,1,1]) 393 | grid = tf.concat([gridx, gridy], axis=-1) 394 | return tf.tile(grid, [tf.shape(inputs)[0],1,1,1]) 395 | 396 | def get_config(self): 397 | config = { 398 | 'res':self.res, 399 | 'out_dim': self.out_dim, 400 | 'hid_dim': self.hid_dim 401 | } 402 | return config 403 | 404 | 405 | 406 | 407 | -------------------------------------------------------------------------------- /tensorflow/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 junfeng-chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tensorflow/README.md: -------------------------------------------------------------------------------- 1 | # Position-induced Transformer 2 | 3 | The code in this repository presents six numerical experiments using Position-induced Transformer (PiT) for learning operators in partial differential equations. PiT is built upon the position-attention mechanism, proposed in the paper *Positional Knowledge is All You Need: Position-induced Transformer (PiT) for Operator Learning*. The paper can be downloaded here. 4 | 5 | PiT is discretization convergent, giving consistent and convergent predictions as the input data is refined. On the Darcy2D benchmark, a PiT model trained with data at 43x43 resolution can produce accurate predictions given input data at 421x421 resolution. 6 | 7 |

8 | 9 |

10 | 11 | PiT can learn the dynamics governed by the incompressible Navier–Stokes equations, showing potential for surrogate modeling of fluid motion. Left: reference vorticity at t=20. Right: predicted vorticity at t=20. 12 | 13 |

14 |      15 |

16 | 17 | PiT is able to approximate highly nonlinear operators. A PiT model can capture discontinuities displayed in the solutions of hyperbolic PDEs. Top left: one-dimensional inviscid Burgers' equation. Top right: one-dimensional compressible Euler equations. Bottom: two-dimensional compressible Euler equations. 18 | 19 |

20 | 21 | 22 |

23 | 24 | PiT can handle irregular point clouds and effectively treat arbitrary boundary conditions. 25 | 26 |

27 | 28 |

29 | 30 | ## Contents 31 | - The numerical experiment on the one-dimensional inviscid Burgers' equation. 32 | - The numerical experiment on the one-dimensional compressible Euler equations. 33 | - The numerical experiment on the two-dimensional Darcy flow problem. 34 | - The numerical experiment on the two-dimensional incompressible Navier–Stokes equations. 35 | - The numerical experiment on the two-dimensional hyper-elastic problem. 36 | - The numerical experiment on the two-dimensional compressible Euler equations. 37 | 38 | ## Data sets 39 | We provide the preprocessed data sets. The raw data required to reproduce the main results can be obtained from some of the baseline methods selected in our paper. 40 | - For InviscidBurgers and ShockTube, data sets are provided by Lanthaler et al. They can be downloaded here. 41 | - For Darcy2D and Vorticity, data sets are provided by Li et al. They can be downloaded here. 42 | - For Elasticity and NACA, data sets are provided by Li et al. They can be downloaded here. 43 | 44 | ## Requirements 45 | - We have run the experiments on a Linux system, with `python==3.10.0`, `CUDA==11.8`, `tensorflow==2.10.0`, and `tensorflow_probability==0.18`. `matplotlib` and `scipy` are also needed for plotting and data loading. 46 | - Since version `2.15`, TensorFlow supports installing its NVIDIA CUDA library dependencies through pip. For people who know both TensorFlow and PyTorch, this represents a great improvement for TensorFlow. We provide the `requirements.txt` file 47 | 48 | --extra-index-url https://pypi.nvidia.com 49 | tensorflow[and-cuda]==2.15 50 | tensorflow_probability==0.23 51 | matplotlib 52 | scipy 53 | 54 | to facilitate running our code with `tensorflow==2.15` and `tensorflow_probability==0.23`. As long as the NVIDIA driver is up to date, simply run 55 | 56 | `python -m pip install -r requirements.txt` 57 | 58 | in an environment with `python=3.9-3.11`. 
This will set up all the necessary dependencies, with no more need for manually installing CUDA stuff. 59 | 60 | ## Citations 61 | ``` 62 | @inproceedings{chen2024positional, 63 | title={Positional Knowledge is All You Need: Position-induced Transformer (PiT) for Operator Learning}, 64 | author={Junfeng Chen and Kailiang Wu}, 65 | booktitle={International conference on machine learning}, 66 | year={2024}, 67 | organization={PMLR} 68 | } 69 | ``` 70 | -------------------------------------------------------------------------------- /tensorflow/figures/Darcy2D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junfeng-chen/position_induced_transformer/05d1c2db041aa665bc04f252b31a213e15972a2d/tensorflow/figures/Darcy2D.png -------------------------------------------------------------------------------- /tensorflow/figures/Elasticity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junfeng-chen/position_induced_transformer/05d1c2db041aa665bc04f252b31a213e15972a2d/tensorflow/figures/Elasticity.png -------------------------------------------------------------------------------- /tensorflow/figures/InviscidBurgers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junfeng-chen/position_induced_transformer/05d1c2db041aa665bc04f252b31a213e15972a2d/tensorflow/figures/InviscidBurgers.png -------------------------------------------------------------------------------- /tensorflow/figures/NACA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junfeng-chen/position_induced_transformer/05d1c2db041aa665bc04f252b31a213e15972a2d/tensorflow/figures/NACA.png -------------------------------------------------------------------------------- /tensorflow/figures/ShockTube.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/junfeng-chen/position_induced_transformer/05d1c2db041aa665bc04f252b31a213e15972a2d/tensorflow/figures/ShockTube.png -------------------------------------------------------------------------------- /tensorflow/figures/err_t20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junfeng-chen/position_induced_transformer/05d1c2db041aa665bc04f252b31a213e15972a2d/tensorflow/figures/err_t20.png -------------------------------------------------------------------------------- /tensorflow/figures/pred_t20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junfeng-chen/position_induced_transformer/05d1c2db041aa665bc04f252b31a213e15972a2d/tensorflow/figures/pred_t20.png -------------------------------------------------------------------------------- /tensorflow/figures/true_t20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junfeng-chen/position_induced_transformer/05d1c2db041aa665bc04f252b31a213e15972a2d/tensorflow/figures/true_t20.png -------------------------------------------------------------------------------- /tensorflow/requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://pypi.nvidia.com 2 | tensorflow[and-cuda]==2.15 3 | tensorflow_probability==0.23 4 | matplotlib 5 | scipy 6 | -------------------------------------------------------------------------------- /train_burgers.py: -------------------------------------------------------------------------------- 1 | from pit import * 2 | from timeit import default_timer 3 | import matplotlib.pyplot as plt 4 | from scipy.io import savemat, loadmat 5 | from utils import * 6 | 7 | def load_data(path_data, ntrain = 1024, ntest=128): 8 
| 9 | data = loadmat(path_data) 10 | 11 | X_data = data["x"].astype('float32') 12 | Y_data = data["y"].astype('float32') 13 | X_train = X_data[:ntrain,:] 14 | Y_train = Y_data[:ntrain,:] 15 | X_test = X_data[-ntest:,:] 16 | Y_test = Y_data[-ntest:,:] 17 | return torch.from_numpy(X_train[...,np.newaxis]), torch.from_numpy(Y_train[...,np.newaxis]), torch.from_numpy(X_test[...,np.newaxis]), torch.from_numpy(Y_test[...,np.newaxis]) 18 | 19 | class pit_burgers(pit_periodic1d): 20 | def __init__(self, 21 | space_dim, 22 | in_dim, 23 | out_dim, 24 | hid_dim, 25 | n_head, 26 | n_blocks, 27 | mesh_ltt, 28 | en_loc, 29 | de_loc): 30 | super(pit_burgers, self).__init__(space_dim, 31 | in_dim, 32 | out_dim, 33 | hid_dim, 34 | n_head, 35 | n_blocks, 36 | mesh_ltt, 37 | en_loc, 38 | de_loc) 39 | 40 | def forward(self, mesh_in, func_in, mesh_out): 41 | ''' 42 | func_in: (batch_size, L, self.in_dim) 43 | ext: (batch_size, L, self.space_dim) 44 | ''' 45 | func_in = torch.cat((torch.tile(mesh_in.unsqueeze(0), [func_in.shape[0],1,1]), func_in),-1) 46 | func_ltt = self.encoder(mesh_in, func_in, self.mesh_ltt) 47 | func_ltt = self.processor(func_ltt, self.mesh_ltt) 48 | func_out = self.decoder(self.mesh_ltt, func_ltt, mesh_out) 49 | return func_out 50 | 51 | ntrain = 1024 52 | ntest = 128 53 | batch_size = 8 54 | learning_rate = 0.001 55 | epochs = 500 56 | iterations = epochs*(ntrain//batch_size) 57 | 58 | x_train, y_train, x_test, y_test = load_data('./supplementary_data/data_burgers.mat', ntrain, ntest) 59 | mesh = torch.linspace(0,1,x_train.shape[1]+1)[:-1].reshape(-1,1).cuda() 60 | mesh_ltt = torch.linspace(0,1,256+1)[:-1].reshape(-1,1).cuda() 61 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 62 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 63 | 64 | model = pit_burgers(space_dim=1, 65 | in_dim=1, 66 | out_dim=1, 
67 | hid_dim=64, 68 | n_head=2, 69 | n_blocks=5, 70 | mesh_ltt=mesh_ltt, 71 | en_loc=0.02, 72 | de_loc=0.02).cuda() 73 | model = torch.compile(model) 74 | print(count_params(model)) 75 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 76 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations) 77 | 78 | myloss = RelLpNorm(out_dim=1, p=1) 79 | l2err = RelLpNorm(out_dim=1, p=2) 80 | maxerr = RelMaxNorm(out_dim=1) 81 | 82 | for ep in range(epochs): 83 | model.train() 84 | t1 = default_timer() 85 | train_loss = 0 86 | for x, y in train_loader: 87 | x, y = x.cuda(), y.cuda() 88 | optimizer.zero_grad() 89 | out = model(mesh, x, mesh) 90 | loss = myloss(y, out) 91 | loss.backward() 92 | optimizer.step() 93 | scheduler.step() 94 | train_loss += loss.item() 95 | 96 | model.eval() 97 | test_loss = 0.0 98 | test_l2 = 0.0 99 | test_max = 0.0 100 | with torch.no_grad(): 101 | for x, y in test_loader: 102 | x, y = x.cuda(), y.cuda() 103 | out = model(mesh, x, mesh) 104 | test_loss += myloss(y, out).item() 105 | test_l2 += l2err(y,out).item() 106 | test_max += maxerr(y,out).item() 107 | 108 | train_loss /= ntrain 109 | test_loss /= ntest 110 | test_l2 /= ntest 111 | test_max /= ntest 112 | 113 | t2 = default_timer() 114 | print(ep, t2-t1, train_loss, test_loss, test_l2, test_max) 115 | 116 | torch.save({'model_state': model.state_dict()}, 'model.pth') 117 | ################################### 118 | model.eval() 119 | pred = np.zeros_like(y_test.numpy()) 120 | count = 0 121 | with torch.no_grad(): 122 | for x, y in test_loader: 123 | x, y = x.cuda(), y.cuda() 124 | out = model(mesh, x, mesh) 125 | pred[count*batch_size:(count+1)*batch_size,...] 
= out.detach().cpu().numpy() 126 | count += 1 127 | 128 | y_test = y_test.numpy() 129 | print("relative l1 error", (np.linalg.norm(y_test-pred, axis=1, ord=1) / np.linalg.norm(y_test, axis=1, ord=1)).mean()) 130 | print("relative l2 error", (np.linalg.norm(y_test-pred, axis=1, ord=2) / np.linalg.norm(y_test, axis=1, ord=2)).mean()) 131 | print("relative l_inf error", (abs(y_test-pred).max(axis=1) / abs(y_test).max(axis=1)).mean() ) 132 | savemat("pred.mat", mdict={'pred':pred, 'trueX':x_test, 'trueY':y_test}) 133 | 134 | 135 | index = -1 136 | true = y_test[index,...].reshape(-1,) 137 | pred = pred[index,...].reshape(-1,) 138 | mesh = mesh.detach().cpu().numpy().reshape(-1,) 139 | plt.figure(figsize=(12,12),dpi=100) 140 | plt.plot(mesh, true, label='true') 141 | plt.plot(mesh, pred, label='pred') 142 | plt.savefig("{}_pred.pdf".format(index)) 143 | plt.close() -------------------------------------------------------------------------------- /train_cylinder.py: -------------------------------------------------------------------------------- 1 | from pit import * 2 | from timeit import default_timer 3 | import matplotlib.pyplot as plt 4 | import matplotlib.tri as tri 5 | from scipy.io import loadmat, savemat 6 | from utils import * 7 | 8 | def load_data(file_path_train, file_path_test, ntrain, ntest): 9 | data_train = loadmat(file_path_train)['trajectories'].astype('float32')# (1000, 4390, 3, 11) 10 | data_test = loadmat(file_path_test)['trajectories'].astype('float32')# (100, 4390, 3, 11) 11 | trainX = data_train[:ntrain,:,:,:-1].transpose(0,3,1,2).reshape(-1, 4390, 3) # ntrain 1000 12 | trainY = data_train[:ntrain,:,:,1:].transpose(0,3,1,2).reshape(-1, 4390, 3) 13 | testX = data_test[:ntest,:,:,:-1].transpose(0,3,1,2).reshape(-1, 4390, 3) 14 | testY = data_test[:ntest,:,:,1:].transpose(0,3,1,2).reshape(-1, 4390, 3) 15 | 16 | return torch.from_numpy(trainX), torch.from_numpy(trainY), torch.from_numpy(testX), torch.from_numpy(testY) 17 | 18 | class 
pit_cylinder(pit_fixed): 19 | def __init__(self, 20 | space_dim, 21 | in_dim, 22 | out_dim, 23 | hid_dim, 24 | n_head, 25 | n_blocks, 26 | mesh_ltt, 27 | en_loc, 28 | de_loc): 29 | super(pit_cylinder, self).__init__(space_dim, 30 | in_dim, 31 | out_dim, 32 | hid_dim, 33 | n_head, 34 | n_blocks, 35 | mesh_ltt, 36 | en_loc, 37 | de_loc) 38 | def forward(self, mesh_in, func_in, mesh_out): 39 | ''' 40 | func_in: (batch_size, L, self.in_dim) 41 | mesh_out: (L, self.space_dim) 42 | ''' 43 | x = func_in 44 | size = mesh_out.shape[:-1] 45 | mesh_in = mesh_in.reshape(-1, self.space_dim) 46 | mesh_out = mesh_out.reshape(-1, self.space_dim) 47 | 48 | func_in = torch.cat((torch.tile(mesh_in.unsqueeze(0), [func_in.shape[0],1,1]), func_in),-1) 49 | func_ltt = self.encoder(mesh_in, func_in, self.mesh_ltt) 50 | func_ltt = self.processor(func_ltt, self.mesh_ltt) 51 | func_out = self.decoder(self.mesh_ltt, func_ltt, mesh_out) 52 | return func_out + x 53 | 54 | ################################################################ 55 | ntrain = 1000 56 | ntest = 100 57 | batch_size = 200 58 | learning_rate = 0.001 59 | epochs = 500 60 | iterations = epochs*(ntrain*10//batch_size) 61 | ################################################################ 62 | # load data and data normalization 63 | ################################################################ 64 | x_train, y_train, x_test, y_test = load_data("./WakeCylinder_train.mat", "./WakeCylinder_test.mat", ntrain, ntest) 65 | mesh = torch.from_numpy(np.genfromtxt("vertices.csv", delimiter=",").astype("float32")).to('cuda')# (4390, 2) 66 | mesh_ltt = torch.from_numpy(np.genfromtxt("vertices_small.csv", delimiter=",").astype("float32")).to('cuda')# (896, 2) 67 | elements = np.genfromtxt("elements.csv", delimiter=",").astype("int32")-1# (8552, 3) 68 | steps = y_train.shape[-1] 69 | print(x_train.shape, y_train.shape, x_test.shape, y_test.shape, mesh.shape, mesh_ltt.shape, elements.shape) 70 | ###### 71 | train_loader = 
torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 72 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False) 73 | ################################################################ 74 | # training and evaluation 75 | ################################################################ 76 | model = pit_cylinder(space_dim=2, 77 | in_dim=3, 78 | out_dim=3, 79 | hid_dim=256, 80 | n_head=1, 81 | n_blocks=4, 82 | mesh_ltt=mesh_ltt, 83 | en_loc=0.01, 84 | de_loc=0.01).cuda() 85 | model = torch.compile(model) 86 | print(count_params(model)) 87 | 88 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 89 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations) 90 | 91 | myloss = RelLpNorm(out_dim=3, p=2) 92 | for ep in range(epochs): 93 | model.train() 94 | t1 = default_timer() 95 | train_l2 = 0 96 | for x, y in train_loader: 97 | loss = 0. 98 | x, y = x.cuda(), y.cuda() 99 | optimizer.zero_grad() 100 | out = model(mesh, x, mesh) 101 | loss += myloss(out, y) 102 | # for t in range(steps): 103 | # out = model(mesh, x, mesh) 104 | # loss += myloss(out, y[..., t]) 105 | # x = out 106 | loss.backward() 107 | optimizer.step() 108 | train_l2 += loss.item() 109 | scheduler.step() 110 | 111 | model.eval() 112 | test_l2 = 0. 113 | with torch.no_grad(): 114 | for x, y in test_loader: 115 | loss = 0. 
116 | x, y = x.cuda(), y.cuda() 117 | out = model(mesh, x, mesh) 118 | loss += myloss(out, y) 119 | # for t in range(steps): 120 | # out = model(mesh, x, mesh) 121 | # loss += myloss(out, y[..., t]) 122 | # x = out 123 | test_l2 += loss.item() 124 | 125 | # train_l2/= (ntrain*steps) 126 | # test_l2 /= (ntest*steps) 127 | train_l2/= (10*ntrain) 128 | test_l2 /= (10*ntest) 129 | t2 = default_timer() 130 | print(ep, t2-t1, train_l2, test_l2) 131 | torch.save({'model_state': model.state_dict()}, 'model.pth') 132 | # checkpoint = torch.load('model.pth', weights_only=True) 133 | # model = torch.compile(model) 134 | # model.load_state_dict(checkpoint['model_state']) 135 | # model = torch._dynamo.disable(model) 136 | ############################## do evaluation 137 | y_test = loadmat("./WakeCylinder_test.mat")['trajectories'].astype('float32') 138 | y_test = torch.from_numpy(y_test) 139 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(y_test[...,0], y_test[...,1:]), batch_size=100, shuffle=False) 140 | steps = y_test.shape[-1] - 1 141 | pred = torch.zeros_like(y_test).cuda() 142 | pred[...,0] = y_test[...,0] 143 | i = 0 144 | with torch.no_grad(): 145 | for x, y in test_loader: 146 | loss = 0. 
147 | x, y = x.cuda(), y.cuda() 148 | print(i) 149 | for t in range(steps): 150 | out = model(mesh, pred[batch_size*i:batch_size*(i+1),:,:,t], mesh) 151 | loss += myloss(out, y[...,t]) 152 | pred[batch_size*i:batch_size*(i+1),:,:,t+1] = out 153 | i += 1 154 | 155 | savemat('pred.mat', mdict={'true':y_test.detach().cpu().numpy(), 'pred':pred.detach().cpu().numpy()}) 156 | rel_err = loss / (steps*y_test.shape[0]) 157 | print(rel_err) 158 | 159 | ######### plots 160 | index = 89 161 | y_test = y_test.detach().cpu().numpy() 162 | pred = pred.detach().cpu().numpy() 163 | err = y_test - pred 164 | abs_err = abs(err[index,...]) 165 | emax = abs_err.max(axis=(0,2)) 166 | emin = abs_err.min(axis=(0,2)) 167 | print("error range", emax, emin) 168 | y_test = y_test[index,...] 169 | vmax = y_test.max(axis=(0,2)) 170 | vmin = y_test.min(axis=(0,2)) 171 | print("vorticity range", vmax, vmin) 172 | pred = pred[index,...] 173 | 174 | save_path="." 175 | triangulation = tri.Triangulation(mesh[:,0].cpu(), mesh[:,1].cpu(), elements) 176 | for d in range(3): 177 | print("Plot variable {}.".format(d+1)) 178 | for t in range(steps+1): 179 | print(" Plot time {}.".format(t)) 180 | plt.figure(figsize=(8,4),dpi=100) 181 | plt.axes([0,0,1,1]) 182 | plt.tricontourf(triangulation, y_test[:,d,t], vmax=vmax[d], vmin=vmin[d], levels=512, cmap='plasma') 183 | plt.axis('off') 184 | plt.axis('equal') 185 | plt.savefig(save_path+'/true_variable{}_time{}.pdf'.format(d+1,t)) 186 | plt.close() 187 | 188 | plt.figure(figsize=(8,4),dpi=100) 189 | plt.axes([0,0,1,1]) 190 | plt.tricontourf(triangulation, pred[:,d,t], vmax=vmax[d], vmin=vmin[d], levels=512, cmap='plasma') 191 | plt.axis('off') 192 | plt.axis('equal') 193 | plt.savefig(save_path+'/pred_variable{}_time{}.pdf'.format(d+1,t)) 194 | plt.close() 195 | 196 | plt.figure(figsize=(8,4),dpi=100) 197 | plt.axes([0,0,1,1]) 198 | plt.tricontourf(triangulation, abs_err[:,d,t], vmax=emax[d], vmin=emin[d], levels=512, cmap='plasma') 199 | plt.axis('off') 200 
| plt.axis('equal') 201 | plt.savefig(save_path+'/err_variable{}_time{}.pdf'.format(d+1,t)) 202 | plt.close() -------------------------------------------------------------------------------- /train_darcy.py: -------------------------------------------------------------------------------- 1 | from pit import * 2 | from timeit import default_timer 3 | import matplotlib.pyplot as plt 4 | from scipy.io import loadmat, savemat 5 | from utils import * 6 | 7 | def load_data(train_path, test_path, downsampling, ntrain, ntest): 8 | 9 | s = int(((421 - 1)/downsampling) + 1) 10 | 11 | train = loadmat(train_path) 12 | a = train["coeff"].astype('float32') 13 | u = train["sol"].astype('float32') 14 | 15 | trainX = a[:ntrain,::downsampling,::downsampling][:,:s,:s] 16 | trainY = u[:ntrain,::downsampling,::downsampling][:,:s,:s] 17 | 18 | test = loadmat(test_path) 19 | a = test["coeff"].astype('float32') 20 | u = test["sol"].astype('float32') 21 | testX = a[:ntest,::downsampling,::downsampling][:,:s,:s] 22 | testY = u[:ntest,::downsampling,::downsampling][:,:s,:s] 23 | return torch.from_numpy(trainX[...,np.newaxis]), torch.from_numpy(trainY[...,np.newaxis]), torch.from_numpy(testX[...,np.newaxis]), torch.from_numpy(testY[...,np.newaxis]) 24 | 25 | class pit_darcy(pit_fixed): 26 | def __init__(self, 27 | space_dim, 28 | in_dim, 29 | out_dim, 30 | hid_dim, 31 | n_head, 32 | n_blocks, 33 | mesh_ltt, 34 | en_loc, 35 | de_loc): 36 | super(pit_darcy, self).__init__(space_dim, 37 | in_dim, 38 | out_dim, 39 | hid_dim, 40 | n_head, 41 | n_blocks, 42 | mesh_ltt, 43 | en_loc, 44 | de_loc) 45 | 46 | def forward(self, mesh_in, func_in, mesh_out): 47 | ''' 48 | func_in: (batch_size, L, self.in_dim) 49 | ext: (batch_size, h, w, self.space_dim) 50 | ''' 51 | size = mesh_out.shape[:-1] 52 | mesh_in = mesh_in.reshape(-1, self.space_dim) 53 | func_in = func_in.reshape(func_in.shape[0], -1, self.in_dim) 54 | mesh_out = mesh_out.reshape(-1, self.space_dim) 55 | func_in = 
torch.cat((torch.tile(mesh_in.unsqueeze(0), [func_in.shape[0],1,1]), func_in),-1) 56 | func_ltt = self.encoder(mesh_in, func_in, self.mesh_ltt) 57 | func_ltt = self.processor(func_ltt, self.mesh_ltt) 58 | func_out = self.decoder(self.mesh_ltt, func_ltt, mesh_out) 59 | return func_out.reshape(func_in.shape[0], *size, self.out_dim) 60 | 61 | ################################################################ 62 | TRAIN_PATH = './piececonst_r421_N1024_smooth1.mat' 63 | TEST_PATH = './piececonst_r421_N1024_smooth2.mat' 64 | ntrain = 1024 65 | ntest = 100 66 | batch_size = 8 67 | learning_rate = 0.001 68 | epochs = 30 69 | iterations = epochs*(ntrain//batch_size) 70 | r = 10 71 | s = int(((421 - 1)/r) + 1) 72 | ################################################################ 73 | # load data and data normalization 74 | ################################################################ 75 | x_train, y_train, x_test, y_test = load_data(TRAIN_PATH, TEST_PATH, r, ntrain, ntest) 76 | x_normalizer = PixelWiseNormalization(x_train) 77 | x_train = x_normalizer.normalize(x_train) 78 | x_test = x_normalizer.normalize(x_test) 79 | y_normalizer = PixelWiseNormalization(y_train) 80 | 81 | 82 | ### This part of code for mesh generation is adapted from Li et al. 
(Fourier Neural Operator for Parametric Partial Differential Equations) 83 | mesh = [] 84 | mesh.append(np.linspace(0, 1, s)) 85 | mesh.append(np.linspace(0, 1, s)) 86 | mesh = np.vstack([xx.ravel() for xx in np.meshgrid(*mesh)]).T 87 | mesh = mesh.reshape(s,s,2) 88 | mesh = torch.tensor(mesh, dtype=torch.float).cuda() 89 | 90 | s_ltt = 16 91 | mesh_ltt = [] 92 | mesh_ltt.append(np.linspace(0, 1, s_ltt)) 93 | mesh_ltt.append(np.linspace(0, 1, s_ltt)) 94 | mesh_ltt = np.vstack([xx.ravel() for xx in np.meshgrid(*mesh_ltt)]).T 95 | mesh_ltt = mesh_ltt.reshape(s_ltt,s_ltt,2) 96 | mesh_ltt = torch.tensor(mesh_ltt, dtype=torch.float).cuda() 97 | ####### 98 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) 99 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=10, shuffle=False) 100 | ################################################################ 101 | # training and evaluation 102 | ################################################################ 103 | model = pit_darcy(space_dim=2, 104 | in_dim=1, 105 | out_dim=1, 106 | hid_dim=64, 107 | n_head=2, 108 | n_blocks=4, 109 | mesh_ltt=mesh_ltt, 110 | en_loc=0.02, 111 | de_loc=0.02).cuda() 112 | model = torch.compile(model) 113 | print(count_params(model)) 114 | 115 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 116 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations) 117 | 118 | myloss = RelLpNorm(out_dim=1, p=2) # LpLoss(size_average=False) 119 | y_normalizer.cuda() 120 | for ep in range(epochs): 121 | model.train() 122 | t1 = default_timer() 123 | train_l2 = 0 124 | for x, y in train_loader: 125 | 126 | x, y = x.cuda(), y.cuda() 127 | optimizer.zero_grad() 128 | out = model(mesh, x, mesh) 129 | out = y_normalizer.denormalize(out) 130 | loss = myloss(y, out) 131 | loss.backward() 132 | optimizer.step() 133 | train_l2 += loss.item() 134 | 
scheduler.step() 135 | 136 | model.eval() 137 | test_l2 = 0.0 138 | with torch.no_grad(): 139 | for x, y in test_loader: 140 | 141 | x, y = x.cuda(), y.cuda() 142 | out = model(mesh, x, mesh) 143 | out = y_normalizer.denormalize(out) 144 | test_l2 += myloss(y, out).item() 145 | 146 | train_l2/= ntrain 147 | test_l2 /= ntest 148 | t2 = default_timer() 149 | print(ep, t2-t1, train_l2, test_l2) 150 | torch.save({'model_state': model.state_dict()}, 'model.pth') 151 | ############################## do zero-shot super-resolution evaluation 152 | model = torch._dynamo.disable(model) 153 | r = 1 154 | s = 421 155 | x_train, y_train, x_test, y_test = load_data(TRAIN_PATH, TEST_PATH, r, ntrain, ntest) 156 | x_test = x_normalizer.normalize(x_test) 157 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=10, shuffle=False) 158 | del x_train, y_train 159 | ###### 160 | 161 | mesh = [] 162 | mesh.append(np.linspace(0, 1, s)) 163 | mesh.append(np.linspace(0, 1, s)) 164 | mesh = np.vstack([xx.ravel() for xx in np.meshgrid(*mesh)]).T 165 | mesh = mesh.reshape(s,s,2) 166 | mesh = torch.tensor(mesh, dtype=torch.float).cuda() 167 | pred = torch.zeros_like(y_test) 168 | batch_size = 10 169 | i = 0 170 | with torch.no_grad(): 171 | for x, y in test_loader: 172 | x, y = x.cuda(), y.cuda() 173 | print(i) 174 | out = model(mesh, x, mesh) 175 | out = y_normalizer.denormalize(out) 176 | pred[batch_size*i:batch_size*(i+1),...] 
= out 177 | i += 1 178 | zssr_err = myloss(y_test, pred) / y_test.shape[0] 179 | savemat('zssr.mat', mdict={'true':y_test.detach().cpu().numpy(), 'pred':pred.detach().cpu().numpy()}) 180 | print(zssr_err) 181 | 182 | ######### plots 183 | index = 89 184 | a = x_normalizer.denormalize(x_test)[index,:,:,0].detach().cpu().numpy() 185 | u = y_test[index,:,:,0].detach().cpu().numpy() 186 | pred = pred[index,:,:,0].detach().cpu().numpy() 187 | 188 | abs_err = abs(u-pred) * 10000 189 | emax = np.max(abs_err) 190 | emin = np.min(abs_err) 191 | print("Maximum and minimum error", emax, emin) 192 | amax = np.max(a) 193 | amin = np.min(a) 194 | umax = np.max(u) 195 | umin = np.min(u) 196 | print(amax, amin, umax, umin) 197 | 198 | # plot the contours 199 | plt.figure(figsize=(14,4),dpi=300) 200 | plt.subplot(141) 201 | plt.imshow(a, vmax=12, vmin=3, interpolation='spline16', cmap='plasma') 202 | cbar = plt.colorbar(location="bottom", fraction=0.046, pad=0.04, ticks=[3, 12], format='%1.0f') 203 | cbar.ax.tick_params(labelsize=14) # Enlarge colorbar ticks 204 | plt.axis('off') 205 | plt.axis("equal") 206 | plt.title('Permeability', fontsize=16, fontweight='bold') 207 | 208 | plt.subplot(142) 209 | plt.imshow(u, vmax=umax, vmin=umin, interpolation='spline16', cmap='plasma') 210 | cbar = plt.colorbar(location="bottom", fraction=0.046, pad=0.04, ticks=[umin, umax], format='%1.3f') 211 | cbar.ax.tick_params(labelsize=14) # Enlarge colorbar ticks 212 | plt.axis('off') 213 | plt.axis("equal") 214 | plt.title('Reference', fontsize=16, fontweight='bold') 215 | 216 | plt.subplot(143) 217 | plt.imshow(pred, vmax=umax, vmin=umin, interpolation='spline16', cmap='plasma') 218 | cbar = plt.colorbar(location="bottom", fraction=0.046, pad=0.04, ticks=[umin, umax], format='%1.3f') 219 | cbar.ax.tick_params(labelsize=14) # Enlarge colorbar ticks 220 | plt.axis('off') 221 | plt.axis("equal") 222 | plt.title('Prediction', fontsize=16, fontweight='bold') 223 | 224 | plt.subplot(144) 225 | 
plt.imshow(abs_err, vmax=emax, vmin=0, interpolation='spline16', cmap='plasma')
cbar = plt.colorbar(location="bottom", fraction=0.046, pad=0.04, ticks=[0, emax], format='%1.3f')
cbar.ax.tick_params(labelsize=14)  # Enlarge colorbar ticks
plt.axis('off')
plt.axis("equal")
plt.title('Absolute error ('+r"$\times 10^{-4}$"+")", fontsize=16, fontweight='bold')

plt.subplots_adjust(left=0.0, right=1.0, wspace=0.005)
plt.savefig('./prediction.pdf')
plt.close()

# --------------------------------------------------------------------------------
# /train_elasticity.py:
# --------------------------------------------------------------------------------
from pit import *
from timeit import default_timer
import matplotlib.pyplot as plt
from scipy.io import savemat
from utils import *


def load_data(path, ntrain, ntest):
    """Load the elasticity arrays found under ``path`` and split them.

    The first ``ntrain`` samples form the training split and the last
    ``ntest`` samples the test split.

    Returns:
        ``(x_train, ext_train, y_train, x_test, ext_test, y_test)`` as
        float32 tensors, where ``x`` is the point coordinates concatenated
        with the rescaled ``rr`` array repeated at every point, ``ext`` is
        the point coordinates alone and ``y`` is the ``sigma`` array with a
        trailing singleton axis.
    """
    R = np.transpose(np.load(path + "Random_UnitCell_rr_10.npy"), (1,0))[:,np.newaxis,:]  # (2000,1,42)
    X = np.transpose(np.load(path + "Random_UnitCell_XY_10.npy"), (2,0,1))  # (2000,972,2)
    ext = X
    R = np.repeat(5*R-1, X.shape[1], 1)  # (2000,972,42)
    X = np.concatenate((X,R), axis=-1)  # (2000,972,44)
    Y = np.transpose(np.load(path + "Random_UnitCell_sigma_10.npy"), (1,0))[...,np.newaxis]

    def as_tensor(a):
        return torch.from_numpy(a.astype("float32"))

    return (as_tensor(X[:ntrain,...]), as_tensor(ext[:ntrain,...]), as_tensor(Y[:ntrain,...]),
            as_tensor(X[-ntest:,...]), as_tensor(ext[-ntest:,...]), as_tensor(Y[-ntest:,...]))


class pit_elasticity(pit):
    """``pit`` variant whose latent mesh is a copy of the output mesh."""

    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 en_loc,
                 de_loc):
        super(pit_elasticity, self).__init__(space_dim,
                                             in_dim,
                                             out_dim,
                                             hid_dim,
                                             n_head,
                                             n_blocks,
                                             mesh_ltt,
                                             en_loc,
                                             de_loc)

        # Replace the encoder input layer set up by the base class.
        self.en_layer = kaiming_mlp(self.n_head * self.in_dim, self.hid_dim, self.hid_dim)

    def forward(self, mesh_in, func_in, mesh_out):
        '''
        mesh_in:  input point coordinates; the script passes (batch_size, L, 2)
        func_in:  input features sampled on mesh_in; the script passes (batch_size, L, 44)
        mesh_out: query point coordinates; also reused as the latent mesh

        Returns the decoder output reshaped to (*mesh_out.shape[:-1], out_dim).
        '''
        mesh_ltt = mesh_out.clone()
        size = mesh_out.shape[:-1]
        # Encoder
        func_ltt = self.encoder(mesh_in, func_in, mesh_ltt)
        # Processor
        func_ltt = self.processor(func_ltt, mesh_ltt)
        # Decoder
        func_out = self.decoder(mesh_ltt, func_ltt, mesh_out)
        return func_out.reshape(*size, self.out_dim)


ntrain = 1000
ntest = 200
batch_size = 10
learning_rate = 0.001
epochs = 500
iterations = epochs*(ntrain//batch_size)

x_train, ext_train, y_train, x_test, ext_test, y_test = load_data('./', ntrain, ntest)
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, ext_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, ext_test, y_test), batch_size=batch_size, shuffle=False)

model = pit_elasticity(space_dim=2,
                       in_dim=44,
                       out_dim=1,
                       hid_dim=256,
                       n_head=2,
                       n_blocks=4,
                       mesh_ltt=None,
                       en_loc=0.02,
                       de_loc=0.02).cuda()
model = torch.compile(model)
print(count_params(model))
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# The scheduler is stepped once per batch, hence T_max counts iterations.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

myloss = RelLpNorm(out_dim=1, p=2)
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, ext, y in train_loader:
        x, ext, y = x.cuda(), ext.cuda(), y.cuda()
        optimizer.zero_grad()
        out = model(ext, x, ext)
        loss = myloss(y, out)
        loss.backward()

        optimizer.step()
        scheduler.step()
        train_l2 += loss.item()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, ext, y in test_loader:
            x, ext, y = x.cuda(), ext.cuda(), y.cuda()
            out = model(ext, x, ext)
            test_l2 += myloss(y, out).item()

    # myloss sums over the batch, so dividing by the sample count gives a mean.
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print(ep, t2-t1, train_l2, test_l2)
torch.save({'model_state': model.state_dict()}, 'model.pth')
################################################

# checkpoint = torch.load('model.pth', weights_only=True)
# model.load_state_dict(checkpoint['model_state'])
#########
rel1err = RelLpNorm(out_dim=1, p=1)
rel2err = RelLpNorm(out_dim=1, p=2)
relMaxerr = RelMaxNorm(out_dim=1)
pred = torch.zeros_like(y_test, device='cpu')
count = 0

model.eval()
with torch.no_grad():
    for x, ext, y in test_loader:
        x, ext, y = x.cuda(), ext.cuda(), y.cuda()

        out = model(ext, x, ext)
        pred[count*batch_size:(count+1)*batch_size,...] = out.detach().cpu()
        count += 1
print("relative l1 error", rel1err(y_test, pred) / ntest)
print("relative l2 error", rel2err(y_test, pred) / ntest)
print("relative l_inf error", relMaxerr(y_test, pred) / ntest)
savemat("pred.mat", mdict={'pred':pred.numpy(), 'trueX':x_test.numpy(), 'ext':ext_test.numpy(), 'trueY':y_test.numpy()})
##############################
index = -1
nvariables = 1
true = y_test.numpy()[index,...].reshape(-1,nvariables)
pred = pred.numpy()[index,...].reshape(-1,nvariables)
err = abs(true-pred)
emax = err.max(axis=0)
emin = err.min(axis=0)
vmax = true.max(axis=0)
vmin = true.min(axis=0)
print(vmax, vmin, emax, emin)

x = ext_test.numpy()[index,:,0].reshape(-1,1)
y = ext_test.numpy()[index,:,1].reshape(-1,1)
print(x.max(), x.min(), y.max(), y.min())

for i in range(nvariables):
    # One scatter figure per quantity; file names are "<index>_<tag>_<variable>.pdf".
    for tag, values in (("pred", pred), ("true", true), ("error", err)):
        plt.figure(figsize=(12,12),dpi=100)
        plt.scatter(x, y, c=values[:,i], cmap="plasma", s=160)
        plt.ylim(0,1.0)
        plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False)
        plt.tight_layout(pad=0)
        plt.savefig("{}_{}_{}.pdf".format(index, tag, i+1))
        plt.close()

# --------------------------------------------------------------------------------
# /train_naca.py:
# --------------------------------------------------------------------------------
from pit import *
from timeit import default_timer
import matplotlib.pyplot as plt
from scipy.io import savemat
from utils import *


def load_data(path, ntrain, ntest):
    """Load the NACA arrays found under ``path`` and split them.

    The first ``ntrain`` samples form the training split and the last
    ``ntest`` samples the test split.

    Returns:
        ``(x_train, ext_train, y_train, x_test, ext_test, y_test)`` as
        float32 tensors, where ``x`` is ``shape_coords.npy``, ``ext`` is the
        stacked X/Y vertex coordinates and ``y`` is the first four channels
        of ``NACA_Cylinder_Q.npy`` moved to the last axis.
    """
    coords = np.load(path + "shape_coords.npy").astype("float32")  # (N,120,2)

    vertices_x = np.load(path + "NACA_Cylinder_X.npy")[...,np.newaxis]
    vertices_y = np.load(path + "NACA_Cylinder_Y.npy")[...,np.newaxis]
    X = np.concatenate((vertices_x, vertices_y), -1).astype("float32")
    Y = np.load(path + "NACA_Cylinder_Q.npy")[:,:4,...].transpose(0,2,3,1).astype("float32")

    return (torch.from_numpy(coords[:ntrain,...]), torch.from_numpy(X[:ntrain,...]), torch.from_numpy(Y[:ntrain,...]),
            torch.from_numpy(coords[-ntest:,...]), torch.from_numpy(X[-ntest:,...]), torch.from_numpy(Y[-ntest:,...]))


class pit_naca(pit):
    """``pit`` variant whose latent mesh is a strided subsample of the output mesh."""

    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 x_downsample,
                 y_downsample,
                 en_loc,
                 de_loc):
        super(pit_naca, self).__init__(space_dim,
                                       in_dim,
                                       out_dim,
                                       hid_dim,
                                       n_head,
                                       n_blocks,
                                       mesh_ltt,
                                       en_loc,
                                       de_loc)

        self.x_down = x_downsample
        self.y_down = y_downsample
        # NOTE(review): 220 and 50 look like the output-grid extents minus one - confirm against the data.
        self.x_res = int(220/x_downsample) + 1
        self.y_res = int(50/y_downsample) + 1

        # Replace the encoder input layer set up by the base class.
        self.en_layer = kaiming_mlp(self.n_head * self.in_dim, self.hid_dim, self.hid_dim)

    def forward(self, mesh_in, func_in, mesh_out):
        '''
        mesh_in:  input point coordinates (the script passes the shape coordinates)
        func_in:  input features on mesh_in (the script passes the same tensor)
        mesh_out: (batch_size, h, w, self.space_dim) query grid

        Returns the decoder output reshaped to (batch_size, h, w, out_dim).
        '''
        size = mesh_out.shape[:-1]
        mesh_ltt, mesh_out = self.ltt_mesh(mesh_out)  # (batch_size, l_latent, self.space_dim), (batch_size, l_out, self.space_dim)
        # Encoder
        func_ltt = self.encoder(mesh_in, func_in, mesh_ltt)
        # Processor
        func_ltt = self.processor(func_ltt, mesh_ltt)
        # Decoder
        func_out = self.decoder(mesh_ltt, func_ltt, mesh_out)
        return func_out.reshape(*size, self.out_dim)

    def ltt_mesh(self, mesh_out):
        """Return ``(latent_mesh, flat_output_mesh)`` built from the output grid.

        The latent mesh takes every ``x_down``-th / ``y_down``-th grid point,
        cropped to ``x_res`` by ``y_res``; both results are flattened to
        (batch_size, n_points, space_dim).
        """
        mesh_ltt = mesh_out[:, ::self.x_down, ::self.y_down, :][:, :self.x_res, :self.y_res, :].reshape(mesh_out.shape[0], -1, self.space_dim)
        mesh_out = mesh_out.reshape(mesh_out.shape[0], -1, self.space_dim)
        return mesh_ltt, mesh_out


ntrain = 1000
ntest = 200
batch_size = 20
learning_rate = 0.001
epochs = 500
iterations = epochs*(ntrain//batch_size)

x_train, ext_train, y_train, x_test, ext_test, y_test = load_data('./', ntrain, ntest)
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, ext_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, ext_test, y_test), batch_size=batch_size, shuffle=False)

model = pit_naca(space_dim=2,
                 in_dim=2,
                 out_dim=4,
                 hid_dim=128,
                 n_head=1,
                 n_blocks=4,
                 mesh_ltt=None,
                 x_downsample=4,
                 y_downsample=4,
                 en_loc=0.02,
                 de_loc=0.02).cuda()
model = torch.compile(model)
print(count_params(model))
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# The scheduler is stepped once per batch, hence T_max counts iterations.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

myloss = RelLpNorm(out_dim=4, p=2)
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, ext, y in train_loader:
        x, ext, y = x.cuda(), ext.cuda(), y.cuda()
        optimizer.zero_grad()
        out = model(x, x, ext)
        loss = myloss(y, out)
        loss.backward()

        optimizer.step()
        scheduler.step()
        train_l2 += loss.item()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, ext, y in test_loader:
            x, ext, y = x.cuda(), ext.cuda(), y.cuda()
            out = model(x, x, ext)
            test_l2 += myloss(y, out).item()

    # myloss sums over the batch, so dividing by the sample count gives a mean.
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print(ep, t2-t1, train_l2, test_l2)
torch.save({'model_state': model.state_dict()}, 'model.pth')
################################################

# checkpoint = torch.load('model.pth', weights_only=True)
# model.load_state_dict(checkpoint['model_state'])
#########
rel1err = RelLpNorm(out_dim=4, p=1)
rel2err = RelLpNorm(out_dim=4, p=2)
relMaxerr = RelMaxNorm(out_dim=4)
pred = torch.zeros_like(y_test, device='cpu')
count = 0

model.eval()
with torch.no_grad():
    for x, ext, y in test_loader:
        x, ext, y = x.cuda(), ext.cuda(), y.cuda()

        out = model(x, x, ext)
        pred[count*batch_size:(count+1)*batch_size,...] = out.detach().cpu()
        count += 1
print("relative l1 error", rel1err(y_test, pred) / ntest)
print("relative l2 error", rel2err(y_test, pred) / ntest)
print("relative l_inf error", relMaxerr(y_test, pred) / ntest)
savemat("pred.mat", mdict={'pred':pred.numpy(), 'trueX':x_test.numpy(), 'ext':ext_test.numpy(), 'trueY':y_test.numpy()})
##############################
index = -1
nvariables = 4
# Only a sub-window of the grid is plotted.
true = y_test.numpy()[index,40:-40,:20,:].reshape(-1,nvariables)
pred = pred.numpy()[index,40:-40,:20,:].reshape(-1,nvariables)
err = abs(true-pred)
emax = err.max(axis=0)
emin = err.min(axis=0)
vmax = true.max(axis=0)
vmin = true.min(axis=0)
print(vmax, vmin, emax, emin)

x = ext_test.numpy()[index,40:-40,:20,0].reshape(-1,1)
y = ext_test.numpy()[index,40:-40,:20,1].reshape(-1,1)
print(x.max(), x.min(), y.max(), y.min())

for i in range(nvariables):
    # One scatter figure per quantity; file names are "<index>_<tag>_<variable>.pdf".
    for tag, values in (("pred", pred), ("true", true), ("error", err)):
        plt.figure(figsize=(12,12),dpi=100)
        plt.scatter(x, y, c=values[:,i], cmap="plasma", s=160)
        plt.ylim(-0.5,0.5)
        plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False)
        plt.tight_layout(pad=0)
        plt.savefig("{}_{}_{}.pdf".format(index, tag, i+1))
        plt.close()

# --------------------------------------------------------------------------------
# /train_sod.py:
# --------------------------------------------------------------------------------
from pit import *
from timeit import default_timer
import matplotlib.pyplot as plt
from scipy.io import savemat, loadmat
from utils import *


def load_data(path_data, ntrain = 1024, ntest=128):
    """Load the shock-tube ``.mat`` file and convert it to primitive variables.

    Channel 2 is turned into pressure and channel 1 into velocity, for both
    ``x`` and ``y``. The pressure must be computed first because it needs the
    still-unconverted channel 1.

    Returns:
        ``(x_train, y_train, x_test, y_test)`` float32 tensors: the first
        ``ntrain`` and the last ``ntest`` samples.
    """
    data = loadmat(path_data)

    X_data = data["x"].astype('float32')
    X_data[...,2] = (X_data[...,2]-0.5*X_data[...,1]**2/X_data[...,0])*(1.4-1)  # the primitive variable: pressure
    X_data[...,1] = X_data[...,1]/X_data[...,0]  # the primitive variable: velocity
    Y_data = data["y"].astype('float32')
    Y_data[...,2] = (Y_data[...,2]-0.5*Y_data[...,1]**2/Y_data[...,0])*(1.4-1)
    Y_data[...,1] = Y_data[...,1]/Y_data[...,0]
    X_train = X_data[:ntrain,:]
    Y_train = Y_data[:ntrain,:]
    X_test = X_data[-ntest:,:]
    Y_test = Y_data[-ntest:,:]
    return torch.from_numpy(X_train), torch.from_numpy(Y_train), torch.from_numpy(X_test), torch.from_numpy(Y_test)


class pit_sod(pit_fixed):
    """``pit_fixed`` variant that prepends the mesh coordinates to the input features."""

    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 en_loc,
                 de_loc):
        super(pit_sod, self).__init__(space_dim,
                                      in_dim,
                                      out_dim,
                                      hid_dim,
                                      n_head,
                                      n_blocks,
                                      mesh_ltt,
                                      en_loc,
                                      de_loc)

    def forward(self, mesh_in, func_in, mesh_out):
        '''
        mesh_in:  (L, self.space_dim) input mesh shared by the whole batch
        func_in:  (batch_size, L, self.in_dim)
        mesh_out: query mesh handed to the decoder
        '''
        # Tile the shared mesh over the batch and prepend it to the features.
        func_in = torch.cat((torch.tile(mesh_in.unsqueeze(0), [func_in.shape[0],1,1]), func_in),-1)
        func_ltt = self.encoder(mesh_in, func_in, self.mesh_ltt)
        func_ltt = self.processor(func_ltt, self.mesh_ltt)
        func_out = self.decoder(self.mesh_ltt, func_ltt, mesh_out)
        return func_out


ntrain = 1024
ntest = 128
batch_size = 8
learning_rate = 0.001
epochs = 500
iterations = epochs*(ntrain//batch_size)

x_train, y_train, x_test, y_test = load_data('./supplementary_data/data_sod.mat', ntrain, ntest)
mesh = torch.linspace(-5,5,x_train.shape[1]+1)[:-1].reshape(-1,1).cuda()
mesh_ltt = torch.linspace(-5,5,256+1)[:-1].reshape(-1,1).cuda()
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)

model = pit_sod(space_dim=1,
                in_dim=3,
                out_dim=3,
                hid_dim=32,
                n_head=1,
                n_blocks=2,
                mesh_ltt=mesh_ltt,
                en_loc=0.02,
                de_loc=0.02).cuda()
model = torch.compile(model)
print(count_params(model))
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# The scheduler is stepped once per batch, hence T_max counts iterations.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

myloss = RelLpNorm(out_dim=3, p=1)
l2err = RelLpNorm(out_dim=3, p=2)
maxerr = RelMaxNorm(out_dim=3)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_loss = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()
        optimizer.zero_grad()
        out = model(mesh, x, mesh)
        loss = myloss(y, out)
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss += loss.item()

    model.eval()
    test_loss = 0.0
    test_l2 = 0.0
    test_max = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()
            out = model(mesh, x, mesh)
            test_loss += myloss(y, out).item()
            test_l2 += l2err(y,out).item()
            test_max += maxerr(y,out).item()

    # The norms sum over the batch, so dividing by the sample count gives a mean.
    train_loss /= ntrain
    test_loss /= ntest
    test_l2 /= ntest
    test_max /= ntest

    t2 = default_timer()
    print(ep, t2-t1, train_loss, test_loss, test_l2, test_max)

torch.save({'model_state': model.state_dict()}, 'model.pth')
###################################
model.eval()
pred = np.zeros_like(y_test.numpy())
count = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.cuda(), y.cuda()
        out = model(mesh, x, mesh)
        pred[count*batch_size:(count+1)*batch_size,...] = out.detach().cpu().numpy()
        count += 1

y_test = y_test.numpy()
print("relative l1 error", (np.linalg.norm(y_test-pred, axis=1, ord=1) / np.linalg.norm(y_test, axis=1, ord=1)).mean())
print("relative l2 error", (np.linalg.norm(y_test-pred, axis=1, ord=2) / np.linalg.norm(y_test, axis=1, ord=2)).mean())
print("relative l_inf error", (abs(y_test-pred).max(axis=1) / abs(y_test).max(axis=1)).mean() )
# Hand savemat plain arrays only (x_test is still a tensor at this point).
savemat("pred.mat", mdict={'pred':pred, 'trueX':x_test.numpy(), 'trueY':y_test})


index = -1
true = y_test[index,...].reshape(-1,3)
pred = pred[index,...].reshape(-1,3)
mesh = mesh.detach().cpu().numpy().reshape(-1,)
for i in range(3):
    plt.figure(figsize=(12,12),dpi=100)
    plt.plot(mesh, true[:,i], label='true')
    plt.plot(mesh, pred[:,i], label='pred')
    plt.legend()  # the labels above are invisible without a legend
    plt.savefig("{}_pred_{}.pdf".format(index, i+1))
    plt.close()

# --------------------------------------------------------------------------------
# /train_vorticity.py:
# --------------------------------------------------------------------------------
from pit import *
from timeit import default_timer
import matplotlib.pyplot as plt
from scipy.io import loadmat, savemat
from utils import *


def load_data(file_path, ntrain, ntest, memory, steps):
    """Load the vorticity trajectories and split them into input and target frames.

    The last axis of ``data['u']`` is sliced into the first ``memory`` frames
    (model input) and the following ``steps`` frames (targets).

    Returns:
        ``(trainX, trainY, testX, testY)`` float32 tensors: the first
        ``ntrain`` and the last ``ntest`` trajectories.
    """
    try:
        data = loadmat(file_path)
    except NotImplementedError:
        # scipy cannot read MATLAB v7.3 (HDF5) files; fall back to mat73.
        import mat73
        data = mat73.loadmat(file_path)
    flow = data['u'].astype('float32')
    del data
    trainX = flow[:ntrain,:,:,:memory]
    trainY = flow[:ntrain,:,:,memory:memory+steps]
    testX = flow[-ntest:,:,:,:memory]
    testY = flow[-ntest:,:,:,memory:memory+steps]

    del flow
    return torch.from_numpy(trainX), torch.from_numpy(trainY), torch.from_numpy(testX), torch.from_numpy(testY)


class pit_vorticity(pit_periodic2d):
    """``pit_periodic2d`` variant with instance normalisation around the processor."""

    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 en_loc,
                 de_loc):
        super(pit_vorticity, self).__init__(space_dim,
                                            in_dim,
                                            out_dim,
                                            hid_dim,
                                            n_head,
                                            n_blocks,
                                            mesh_ltt,
                                            en_loc,
                                            de_loc)
        self.norm = nn.InstanceNorm1d(hid_dim)

    def forward(self, mesh_in, func_in, mesh_out):
        '''
        mesh_in:  (h, w, self.space_dim) input grid shared by the whole batch
        func_in:  (batch_size, h, w, self.in_dim)
        mesh_out: (h_out, w_out, self.space_dim) query grid

        Returns a tensor of shape (batch_size, h_out, w_out, self.out_dim).
        '''
        size = mesh_out.shape[:-1]
        # Flatten the grids so that every point becomes one token.
        mesh_in = mesh_in.reshape(-1, self.space_dim)
        func_in = func_in.reshape(func_in.shape[0], -1, self.in_dim)
        mesh_out = mesh_out.reshape(-1, self.space_dim)

        # Tile the shared mesh over the batch and prepend it to the features.
        func_in = torch.cat((torch.tile(mesh_in.unsqueeze(0), [func_in.shape[0],1,1]), func_in),-1)
        func_ltt = self.encoder(mesh_in, func_in, self.mesh_ltt)
        # InstanceNorm1d expects channels on dim 1, hence the permutes.
        func_ltt = self.norm(func_ltt.permute(0,2,1)).permute(0,2,1)

        func_ltt = self.processor(func_ltt, self.mesh_ltt)
        func_ltt = self.norm(func_ltt.permute(0,2,1)).permute(0,2,1)

        func_out = self.decoder(self.mesh_ltt, func_ltt, mesh_out)
        return func_out.reshape(func_in.shape[0], *size, self.out_dim)


################################################################
ntrain = 1000
ntest = 200
batch_size = 20
learning_rate = 0.001
epochs = 500
iterations = epochs*(ntrain//batch_size)
memory = 10
steps = 20
################################################################
# load data and data normalization
################################################################
x_train, y_train, x_test, y_test = load_data("./NavierStokes_V1e-4_N1200_T30.mat", ntrain, ntest, memory, steps)
s = 64
mesh = []
mesh.append(np.linspace(0, 1, s+1)[:-1])
mesh.append(np.linspace(0, 1, s+1)[:-1])
mesh = np.vstack([xx.ravel() for xx in np.meshgrid(*mesh)]).T
mesh = mesh.reshape(s,s,2)
mesh = torch.tensor(mesh, dtype=torch.float).cuda()

s_ltt = 16
mesh_ltt = []
mesh_ltt.append(np.linspace(0, 1, s_ltt+1)[:-1])
mesh_ltt.append(np.linspace(0, 1, s_ltt+1)[:-1])
mesh_ltt = np.vstack([xx.ravel() for xx in np.meshgrid(*mesh_ltt)]).T
mesh_ltt = mesh_ltt.reshape(s_ltt,s_ltt,2)
mesh_ltt = torch.tensor(mesh_ltt, dtype=torch.float).cuda()
#######
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)
################################################################
# training and evaluation
################################################################
model = pit_vorticity(space_dim=2,
                      in_dim=10,
                      out_dim=1,
                      hid_dim=256,
                      n_head=2,
                      n_blocks=4,
                      mesh_ltt=mesh_ltt,
                      en_loc=0.02,
                      de_loc=0.02).cuda()
model = torch.compile(model)
print(count_params(model))

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# The scheduler is stepped once per batch, hence T_max counts iterations.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

myloss = RelLpNorm(out_dim=1, p=2)
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        loss = 0.
        x, y = x.cuda(), y.cuda()
        optimizer.zero_grad()
        # Autoregressive rollout: predict one frame, then slide the input window.
        for t in range(steps):
            out = model(mesh, x, mesh)
            # RelLpNorm takes (true, pred); the reference must come first so
            # the error is normalised by the norm of the ground truth.
            loss += myloss(y[..., t:t+1], out)
            x = torch.cat((x[..., 1:], out), dim=-1)
        loss.backward()
        optimizer.step()
        train_l2 += loss.item()
        scheduler.step()

    model.eval()
    test_l2 = 0.
    with torch.no_grad():
        for x, y in test_loader:
            loss = 0.
            x, y = x.cuda(), y.cuda()
            for t in range(steps):
                out = model(mesh, x, mesh)
                loss += myloss(y[..., t:t+1], out)
                x = torch.cat((x[..., 1:], out), dim=-1)
            test_l2 += loss.item()

    train_l2 /= (ntrain*steps)
    test_l2 /= (ntest*steps)
    t2 = default_timer()
    print(ep, t2-t1, train_l2, test_l2)
torch.save({'model_state': model.state_dict()}, 'model.pth')
############################## do evaluation
model.eval()
pred = torch.zeros_like(y_test)
rel_err = 0.
with torch.no_grad():
    for i, (x, y) in enumerate(test_loader):
        x, y = x.cuda(), y.cuda()
        print(i)
        for t in range(steps):
            out = model(mesh, x, mesh)
            rel_err += myloss(y[..., t:t+1], out).item()
            # Store every predicted frame of the rollout in its own time slot.
            pred[batch_size*i:batch_size*(i+1), ..., t:t+1] = out.cpu()
            x = torch.cat((x[..., 1:], out), dim=-1)
# Mean relative L2 error per sample and per predicted frame.
rel_err /= (ntest*steps)

savemat('pred.mat', mdict={'true':y_test.detach().cpu().numpy(), 'pred':pred.detach().cpu().numpy()})
print(rel_err)

######### plots
index = 89
y_test = y_test.detach().cpu().numpy()
pred = pred.detach().cpu().numpy()
err = y_test - pred
abs_err = abs(err[index,...])
emax = abs_err.max()
emin = abs_err.min()
print("error range", emax, emin)
omega = y_test[index,...]
vmax = omega.max()
vmin = omega.min()
print("vorticity range", vmax, vmin)
omega_p = pred[index,...]

directory="."
for i in range(steps):
    # Plot the contours: one figure each for the reference, the prediction
    # and the absolute error at predicted frame i.
    panels = (
        ('reference', omega[...,i], vmax, vmin),
        ('pred', omega_p[...,i], vmax, vmin),
        ('err', abs_err[...,i], emax, emin),
    )
    for tag, field, hi, lo in panels:
        plt.figure(figsize=(4,4),dpi=300)
        plt.axes([0,0,1,1])
        plt.imshow(field, vmax=hi, vmin=lo, cmap="plasma", interpolation='spline16')
        plt.axis('off')
        plt.axis('equal')
        plt.savefig(directory + '/{}_{}.pdf'.format(tag, i+1))
        plt.close()

# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import operator
from functools import reduce
import torch
import torch.nn.functional as F


class PixelWiseNormalization():
    """Standardise tensors with per-location statistics taken over dim 0.

    ``mean`` and ``std`` are computed once from the tensor given to the
    constructor. Inputs whose shape does not broadcast against the stored
    statistics are handled by bilinearly resampling the statistics to the
    input's ``shape[1]`` x ``shape[2]`` resolution, which assumes a
    channels-last 4-D layout.
    """

    def __init__(self, x, eps=1e-5):
        self.mean = torch.mean(x, dim=0, keepdim=True)  # (1, h, w, 1)
        self.std = torch.std(x, dim=0, keepdim=True)  # (1, h, w, 1)
        self.eps = eps  # added to std to avoid division by zero

    def _resampled_stats(self, x):
        """Return ``(mean, std)`` bilinearly resized to the resolution of ``x``."""
        size = (x.shape[1], x.shape[2])
        # F.interpolate works channels-first, so permute there and back.
        mean = F.interpolate(self.mean.permute(0,3,1,2), size=size, mode='bilinear', align_corners=False).permute(0,2,3,1)
        std = F.interpolate(self.std.permute(0,3,1,2), size=size, mode='bilinear', align_corners=False).permute(0,2,3,1)
        return mean, std

    def normalize(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        try:
            return (x - self.mean) / (self.std + self.eps)
        except RuntimeError:
            # Shapes do not broadcast: resample the statistics to x's resolution.
            mean, std = self._resampled_stats(x)
            return (x - mean) / (std + self.eps)

    def denormalize(self, x):
        """Invert :meth:`normalize`: return ``x * (std + eps) + mean``."""
        try:
            return x * (self.std + self.eps) + self.mean
        except RuntimeError:
            # Shapes do not broadcast: resample the statistics to x's resolution.
            mean, std = self._resampled_stats(x)
            return x * (std + self.eps) + mean

    def to(self, device):
        """Move the stored statistics to ``device`` (a string or ``torch.device``)."""
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)

    def cuda(self):
        """Move the stored statistics to the current CUDA device."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()

    def cpu(self):
        """Move the stored statistics to the CPU."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()


def count_params(model):
    """Return the total number of elements over all parameters of ``model``."""
    # numel() also handles 0-dim parameters, whose size() is empty.
    return sum(p.numel() for p in model.parameters())


class RelMaxNorm(object):
    """Relative L_inf error, averaged over variables and summed over the batch."""

    def __init__(self, out_dim):
        super(RelMaxNorm, self).__init__()
        self._out_dim = out_dim

    def __call__(self, true, pred):
        """Return ``sum_b mean_c max|true - pred| / max|true|``.

        Both inputs are flattened to (batch_size, L, out_dim). A variable
        whose reference is identically zero yields inf/nan, as no epsilon
        is added to the denominator.
        """
        # reshape (not view) so non-contiguous inputs are accepted too
        true_reshaped = true.reshape(true.size(0), -1, self._out_dim)  # (batch_size, L, out_dim)
        pred_reshaped = pred.reshape(pred.size(0), -1, self._out_dim)  # (batch_size, L, out_dim)

        # L_inf norm along the second dimension (L)
        true_norm = torch.max(torch.abs(true_reshaped), dim=1)[0]  # (batch_size, out_dim)
        pred_diff_norm = torch.max(torch.abs(true_reshaped - pred_reshaped), dim=1)[0]  # (batch_size, out_dim)

        rel_error = pred_diff_norm / true_norm  # (batch_size, out_dim)

        # average over variables, sum over the batch
        return torch.sum(torch.mean(rel_error, dim=-1))


# loss function with relative Lp loss
class RelLpNorm(object):
    def __init__(self, out_dim, p):
        super(RelLpNorm, self).__init__()
        self._out_dim = out_dim  # number of variables in the last axis
        self._ord = p  # order of the vector norm
| 86 | def __call__(self, true, pred): 87 | # Reshape true and pred 88 | true_reshaped = true.reshape(true.size(0), -1, self._out_dim) # (batch_size, L, out_dim) 89 | pred_reshaped = pred.reshape(pred.size(0), -1, self._out_dim) # (batch_size, L, out_dim) 90 | # Compute the L2 norm along the second dimension (L) 91 | true_norm = torch.norm(true_reshaped, p=self._ord, dim=1) # (batch_size, out_dim) 92 | pred_diff_norm = torch.norm(true_reshaped - pred_reshaped, p=self._ord, dim=1) # (batch_size, out_dim) 93 | 94 | # Compute the relative error 95 | rel_error = pred_diff_norm / true_norm # (batch_size, out_dim) 96 | 97 | # Average across batch and out_dim 98 | return torch.sum(torch.mean(rel_error, dim=-1)) # average over variables, sum over the batch 99 | 100 | --------------------------------------------------------------------------------