├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── pit.py
├── supplementary_data
│   ├── data_burgers.mat
│   ├── data_sod.mat
│   └── shape_coords.npy
├── tensorflow
│   ├── 1_InviscidBurgers
│   │   ├── train.py
│   │   └── utils.py
│   ├── 2_ShockTube
│   │   ├── train.py
│   │   └── utils.py
│   ├── 3_Darcy2D
│   │   ├── evaluate.py
│   │   ├── train.py
│   │   └── utils.py
│   ├── 4_Vorticity
│   │   ├── evaluate.py
│   │   ├── train.py
│   │   └── utils.py
│   ├── 5_Elasticity
│   │   ├── evaluate.py
│   │   ├── train.py
│   │   └── utils.py
│   ├── 6_NACA
│   │   ├── evaluate.py
│   │   ├── train.py
│   │   └── utils.py
│   ├── LICENSE
│   ├── README.md
│   ├── figures
│   │   ├── Darcy2D.png
│   │   ├── Elasticity.png
│   │   ├── InviscidBurgers.png
│   │   ├── NACA.png
│   │   ├── ShockTube.png
│   │   ├── err_t20.png
│   │   ├── pred_t20.png
│   │   └── true_t20.png
│   └── requirements.txt
├── train_burgers.py
├── train_cylinder.py
├── train_darcy.py
├── train_elasticity.py
├── train_naca.py
├── train_sod.py
├── train_vorticity.py
└── utils.py

/.gitattributes:
--------------------------------------------------------------------------------
*.mat filter=lfs diff=lfs merge=lfs -text

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.npy
*.pdf
*.pth
*.mat
*.log
*.csv
__pycache__

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 junfeng-chen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Position-induced Transformer

This repository contains the implementation of the **Position-induced Transformer (PiT)** and the numerical experiments that use PiT to learn operators in partial differential equations. PiT is built upon the position-attention mechanism proposed in the paper *Positional Knowledge is All You Need: Position-induced Transformer (PiT) for Operator Learning*. The paper can be accessed here.

## Updates May 2025
- PiT is now part of **DUE**, our open-source toolkit for data-driven equation modeling with modern deep learning methods. For installation and examples, visit the GitHub repository.
- Added a numerical example showcasing PiT for learning the unsteady flow past a cylinder.

## Contents
- `train_burgers`: One-dimensional inviscid Burgers' equation.
- `train_sod`: One-dimensional compressible Euler equations.
- `train_darcy`: Two-dimensional Darcy flow problem.
- `train_vorticity`: Two-dimensional incompressible Navier–Stokes equations with periodic boundary conditions.
- `train_elasticity`: Two-dimensional hyper-elastic problem.
- `train_naca`: Two-dimensional transonic flow over airfoils.
- `train_cylinder`: Two-dimensional flow past a cylinder at Reynolds number 100.

## Datasets
The raw data required to reproduce the main results can be obtained from some of the baseline methods compared in our paper.
- For InviscidBurgers and ShockTube, the datasets are provided by Lanthaler et al. They can be downloaded here.
- For Darcy2D and Vorticity, the datasets are provided by Li et al. They can be downloaded here.
- For Elasticity and NACA, the datasets are provided by Li et al. They can be downloaded here.
- For Cylinder, the dataset is generated using FEniCS. It can be downloaded here.

## Requirements
- This code is primarily based on PyTorch. We have observed significant improvements in PiT's training speed with PyTorch 2.x, especially when using `torch.compile`. We therefore recommend PyTorch 2.x with `torch.compile` enabled for the best performance.
- If issues with `torch.compile` arise that cannot be resolved, the code is also compatible with recent versions of PyTorch 1.x. In that case, comment out the line `model = torch.compile(model)` in the scripts.
- Matplotlib and SciPy are also required.
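
The position-attention mechanism that PiT is built upon can be illustrated without PyTorch. The NumPy sketch below mirrors the logic of `dist2att` and `convolution` in `pit.py` for a fixed mesh. It computes squared distances between query and input points, derives a per-head bandwidth through the same `tan`/`sin` reparameterization of an unconstrained parameter, applies a quantile mask that keeps only the nearest `locality` fraction of input points, and ends with a row-wise softmax. It is an illustrative approximation, not code from this repository. The helper name `position_attention` and the use of `-inf` (instead of the float32 maximum) for masked entries are assumptions of this sketch.

```python
import numpy as np


def position_attention(mesh_out, mesh_in, scale, locality):
    """Hypothetical helper sketching pit.py's dist2att for a fixed mesh.

    mesh_out: (L_out, d) query coordinates
    mesh_in:  (L_in, d) input coordinates
    scale:    (n_head,) unconstrained per-head parameters
    locality: fraction in (0, 1] of nearest input points each query attends to
    Returns row-stochastic weights of shape (n_head, L_out, L_in).
    """
    # Squared Euclidean distances between every query and input point.
    d2 = np.sum((mesh_out[:, None, :] - mesh_in[None, :, :]) ** 2, axis=-1)
    # Map each unconstrained parameter to a non-negative bandwidth, as in pit.py.
    lam = np.tan(0.25 * np.pi * (1 - 1e-7) * (1.0 + np.sin(scale)))
    scaled = lam[:, None, None] * d2[None, :, :]
    # Keep only distances at or below the `locality` quantile of each row.
    thresh = np.quantile(scaled, locality, axis=-1, keepdims=True)
    logits = np.where(scaled <= thresh, -scaled, -np.inf)
    # Numerically stable softmax over input points; masked entries become 0.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


# Usage on a uniform 1D mesh with two heads.
mesh = np.linspace(0.0, 1.0, 8)[:, None]
A = position_attention(mesh, mesh, scale=np.array([0.0, 1.0]), locality=0.5)

# Apply the weights to a batch of input functions and concatenate the heads,
# following the einsum pattern of posatt_fixed.convolution.
u = np.random.default_rng(0).standard_normal((2, 8, 3))   # (batch, L_in, in_dim)
out = np.einsum("hnj,bjd->bnhd", A, u).reshape(2, 8, -1)  # (batch, L_out, n_head*in_dim)
```

Setting `locality=1.0` keeps every input point, which corresponds to the global position-attention used in the Processor blocks of `pit.py`.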

## Citations
```
@inproceedings{chen2024positional,
  title={Positional Knowledge is All You Need: Position-induced Transformer (PiT) for Operator Learning},
  author={Chen, Junfeng and Wu, Kailiang},
  booktitle={International Conference on Machine Learning},
  pages={7526--7552},
  year={2024},
  organization={PMLR}
}
```

--------------------------------------------------------------------------------
/pit.py:
--------------------------------------------------------------------------------
import torch
torch.set_float32_matmul_precision('high')
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
import torch.nn as nn
from torch.nn.functional import gelu
import numpy as np
np.random.seed(0)
from math import pi

class kaiming_mlp(nn.Module):
    def __init__(self, n_filters0, n_filters1, n_filters2):
        super(kaiming_mlp, self).__init__()
        self.mlp1 = torch.nn.Linear(n_filters0, n_filters1)
        self.mlp2 = torch.nn.Linear(n_filters1, n_filters2)
        nn.init.kaiming_normal_(self.mlp1.weight)
        nn.init.kaiming_normal_(self.mlp2.weight)

    def forward(self, x):
        # print(x.shape)
        x = self.mlp1(x)
        x = gelu(x)
        x = self.mlp2(x)
        return x

class posatt(nn.Module):
    def __init__(self, n_head, in_dim, locality):
        super(posatt, self).__init__()

        self.locality = locality
        self.n_head = n_head
        self.in_dim = in_dim
        self.lmda = torch.nn.Parameter(torch.rand(n_head, 1, 1))

    def forward(self, mesh, inputs):
        """
        mesh: (batch, L, 2)
        inputs: (batch, L_in, in_dim)
        """
        att = self.dist2att(mesh, mesh, self.lmda, self.locality)  # (batch, n_head, L, L)
        convoluted = self.convolution(att, inputs)
        return torch.cat((inputs, convoluted), dim=-1)

    def dist2att(self, mesh_out, mesh_in, scale, locality):
        m_dist = torch.sum((mesh_out.unsqueeze(-2) - mesh_in.unsqueeze(-3))**2, dim=-1)  # (batch_size, L_out, L_in)
        scaled_dist = m_dist.unsqueeze(1) * torch.tan(0.25*pi*(1-1e-7)*(1.0+torch.sin(scale)))  # (batch_size, n_head, L_out, L_in)
        mask = torch.quantile(scaled_dist, locality, dim=-1, keepdim=True)
        scaled_dist = torch.where(scaled_dist <= mask, scaled_dist, torch.tensor(torch.finfo(torch.float32).max, device=scaled_dist.device))
        scaled_dist = -scaled_dist
        return torch.nn.Softmax(dim=-1)(scaled_dist)

    def convolution(self, A, U):
        convoluted = torch.einsum("bhnj,bjd->bnhd", A, U)  # (batch, L_out, n_head, in_dim)
        convoluted = convoluted.reshape(U.shape[0], -1, self.n_head * U.shape[-1])  # (batch, L_out, n_head*in_dim)
        return convoluted

class posatt_cross(posatt):
    def __init__(self, n_head, in_dim, locality):
        super(posatt_cross, self).__init__(n_head, in_dim, locality)

    def forward(self, mesh_out, mesh_in, inputs):
        """
        mesh_out: (batch, L_out, 2)
        mesh_in: (batch, L_in, 2)
        inputs: (batch, L_in, in_dim)
        """
        att = self.dist2att(mesh_out, mesh_in, self.lmda, self.locality)
        convoluted = self.convolution(att, inputs)
        return convoluted

class pit(nn.Module):
    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 en_loc,
                 de_loc):

        super(pit, self).__init__()
        self.space_dim = space_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.n_blocks = n_blocks
        if mesh_ltt is not None:
            self.mesh_ltt = mesh_ltt.reshape(-1, self.space_dim)
        else:
            self.mesh_ltt = mesh_ltt
        self.en_local = en_loc
        self.de_local = de_loc

        self.down = posatt_cross(self.n_head, self.in_dim, self.en_local)
        self.en_layer = kaiming_mlp(self.n_head * (self.in_dim+self.space_dim), self.hid_dim, self.hid_dim)

        self.conv = torch.nn.ModuleList([posatt(self.n_head, self.hid_dim, 1.0) for _ in range(self.n_blocks)])
        self.mlp = torch.nn.ModuleList([kaiming_mlp((1 + self.n_head) * self.hid_dim, self.hid_dim, self.hid_dim) for _ in range(self.n_blocks)])  # with residual in the convolutions

        self.up = posatt_cross(self.n_head, self.hid_dim, self.de_local)
        self.de = kaiming_mlp(self.n_head * self.hid_dim, self.hid_dim, self.out_dim)

    def encoder(self, mesh_in, func_in, mesh_ltt):
        func_ltt = self.down(mesh_ltt, mesh_in, func_in)
        func_ltt = self.en_layer(func_ltt)
        func_ltt = gelu(func_ltt)
        return func_ltt

    def processor(self, func_ltt, mesh_ltt):
        for a, w in zip(self.conv, self.mlp):
            # U = AUW
            func_ltt = a(mesh_ltt, func_ltt)
            func_ltt = w(func_ltt)
            func_ltt = gelu(func_ltt)
        return func_ltt

    def decoder(self, mesh_ltt, func_ltt, mesh_out):
        func_out = self.up(mesh_out, mesh_ltt, func_ltt)
        func_out = self.de(func_out)
        return func_out

class posatt_fixed(posatt):
    def __init__(self, n_head, in_dim, locality):
        super(posatt_fixed, self).__init__(n_head, in_dim, locality)

    def dist2att(self, mesh_out, mesh_in, scale, locality):
        m_dist = torch.sum((mesh_out.unsqueeze(-2) - mesh_in.unsqueeze(-3))**2, dim=-1)  # (L_out, L_in)
        scaled_dist = m_dist * torch.tan(0.25*pi*(1-1e-7)*(1.0+torch.sin(scale)))  # (n_head, L_out, L_in)
        mask = torch.quantile(scaled_dist, locality, dim=-1, keepdim=True)
        scaled_dist = torch.where(scaled_dist <= mask, scaled_dist, torch.tensor(torch.finfo(torch.float32).max, device=scaled_dist.device))
        scaled_dist = -scaled_dist
        return torch.nn.Softmax(dim=-1)(scaled_dist)

    def convolution(self, A, U):
        convoluted = torch.einsum("hnj,bjd->bnhd", A, U)  # (batch, L_out, n_head, in_dim)
        convoluted = convoluted.reshape(U.shape[0], -1, self.n_head * U.shape[-1])  # (batch, L_out, n_head*in_dim)
        return convoluted

class posatt_cross_fixed(posatt_fixed):

    def __init__(self, n_head, in_dim, locality):
        super(posatt_cross_fixed, self).__init__(n_head, in_dim, locality)

    def forward(self, mesh_out, mesh_in, inputs):
        """
        mesh_out: (L_out, 2)
        mesh_in: (L_in, 2)
        inputs: (batch, L_in, in_dim)
        """
        att = self.dist2att(mesh_out, mesh_in, self.lmda, self.locality)
        convoluted = self.convolution(att, inputs)
        return convoluted

class pit_fixed(pit):
    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 en_loc,
                 de_loc):

        super(pit_fixed, self).__init__(space_dim, in_dim, out_dim, hid_dim, n_head, n_blocks, mesh_ltt, en_loc, de_loc)
        self.down = posatt_cross_fixed(self.n_head, self.in_dim, self.en_local)
        self.conv = torch.nn.ModuleList([posatt_fixed(self.n_head, self.hid_dim, 1.0) for _ in range(self.n_blocks)])
        self.up = posatt_cross_fixed(self.n_head, self.hid_dim, self.de_local)
##############################
class posatt_periodic1d(posatt_fixed):
    def __init__(self, n_head, in_dim, locality):
        super(posatt_periodic1d, self).__init__(n_head, in_dim, locality)

    def dist2att(self, mesh_out, mesh_in, scale, locality):
        dx = torch.abs(mesh_in[1,0] - mesh_in[0,0])
        l = dx * mesh_in.shape[0]
        m_diff = abs(mesh_out.unsqueeze(-2) - mesh_in.unsqueeze(-3))
        m_diff = torch.minimum(m_diff, l-m_diff)
        m_dist = m_diff[...,0]**2
        scaled_dist = m_dist * torch.tan(0.25*pi*(1-1e-7)*(1.0+torch.sin(scale)))  # (n_head, L_out, L_in)
        mask = torch.quantile(scaled_dist, locality, dim=-1, keepdim=True)
        scaled_dist = torch.where(scaled_dist <= mask, scaled_dist, torch.tensor(torch.finfo(torch.float32).max, device=scaled_dist.device))
        scaled_dist = -scaled_dist
        return torch.nn.Softmax(dim=-1)(scaled_dist)

class posatt_cross_periodic1d(posatt_periodic1d):

    def __init__(self, n_head, in_dim, locality):
        super(posatt_cross_periodic1d, self).__init__(n_head, in_dim, locality)

    def forward(self, mesh_out, mesh_in, inputs):
        """
        mesh_out: (L_out, 2)
        mesh_in: (L_in, 2)
        inputs: (batch, L_in, in_dim)
        """
        att = self.dist2att(mesh_out, mesh_in, self.lmda, self.locality)
        convoluted = self.convolution(att, inputs)
        return convoluted

class pit_periodic1d(pit):
    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 en_loc,
                 de_loc):

        super(pit_periodic1d, self).__init__(space_dim, in_dim, out_dim, hid_dim, n_head, n_blocks, mesh_ltt, en_loc, de_loc)
        self.down = posatt_cross_periodic1d(self.n_head, self.in_dim, self.en_local)
        self.conv = torch.nn.ModuleList([posatt_periodic1d(self.n_head, self.hid_dim, 1.0) for _ in range(self.n_blocks)])
        self.up = posatt_cross_periodic1d(self.n_head, self.hid_dim, self.de_local)
################

class posatt_periodic2d(posatt_fixed):
    def __init__(self, n_head, in_dim, locality):
        super(posatt_periodic2d, self).__init__(n_head, in_dim, locality)

    def dist2att(self, mesh_out, mesh_in, scale, locality):
        res = int(mesh_in.shape[0]**0.5)
        dx = (torch.max(mesh_in[:,0]) - torch.min(mesh_in[:,0])) / (res - 1)
        l = dx * res
        m_diff = abs(mesh_out.unsqueeze(-2) - mesh_in.unsqueeze(-3))
        m_diff = torch.minimum(m_diff, l-m_diff)
        m_dist = torch.sum(m_diff**2, dim=-1)
        scaled_dist = m_dist * torch.tan(0.25*pi*(1-1e-7)*(1.0+torch.sin(scale)))  # (n_head, L_out, L_in)
        mask = torch.quantile(scaled_dist, locality, dim=-1, keepdim=True)
        scaled_dist = torch.where(scaled_dist <= mask, scaled_dist, torch.tensor(torch.finfo(torch.float32).max, device=scaled_dist.device))
        scaled_dist = -scaled_dist
        return torch.nn.Softmax(dim=-1)(scaled_dist)

class posatt_cross_periodic2d(posatt_periodic2d):

    def __init__(self, n_head, in_dim, locality):
        super(posatt_cross_periodic2d, self).__init__(n_head, in_dim, locality)

    def forward(self, mesh_out, mesh_in, inputs):
        """
        mesh_out: (L_out, 2)
        mesh_in: (L_in, 2)
        inputs: (batch, L_in, in_dim)
        """
        att = self.dist2att(mesh_out, mesh_in, self.lmda, self.locality)
        convoluted = self.convolution(att, inputs)
        return convoluted

class pit_periodic2d(pit):
    def __init__(self,
                 space_dim,
                 in_dim,
                 out_dim,
                 hid_dim,
                 n_head,
                 n_blocks,
                 mesh_ltt,
                 en_loc,
                 de_loc):

        super(pit_periodic2d, self).__init__(space_dim, in_dim, out_dim, hid_dim, n_head, n_blocks, mesh_ltt, en_loc, de_loc)
        self.down = posatt_cross_periodic2d(self.n_head, self.in_dim, self.en_local)
        self.conv = torch.nn.ModuleList([posatt_periodic2d(self.n_head, self.hid_dim, 1.0) for _ in range(self.n_blocks)])
        self.up = posatt_cross_periodic2d(self.n_head, self.hid_dim, self.de_local)

--------------------------------------------------------------------------------
/supplementary_data/data_burgers.mat:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:7dbf1534c7d8bde94f43c3bee886c6bd278debd9dacd93ea5747915f361ba8a0
size 18874608

--------------------------------------------------------------------------------
/supplementary_data/data_sod.mat:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:c57e134c73cdba4aa247ae36b498cc08bef65a2337fc99b6397b13f20e872443
size 125829376

--------------------------------------------------------------------------------
/supplementary_data/shape_coords.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junfeng-chen/position_induced_transformer/05d1c2db041aa665bc04f252b31a213e15972a2d/supplementary_data/shape_coords.npy

--------------------------------------------------------------------------------
/tensorflow/1_InviscidBurgers/train.py:
--------------------------------------------------------------------------------
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['PYTHONHASHSEED'] = '1'

from numpy import savetxt
from scipy.io import savemat
import matplotlib.pyplot as plt
from time import time

### custom imports
from utils import *

### params
n_epochs = 500
lr = 0.001
batch_size = 5
encode_dim = 64
out_dim = 1
n_head = 2
n_train = 950
n_test = 128
qry_res = 1024
ltt_res = 1024
en_loc = 1  # locality parameter in the Encoder
de_loc = 8  # locality parameter in the Decoder
save_path = './model/'


### load dataset
trainX, trainY, testX, testY = load_data("./1_InviscidBurgers.mat", n_train, n_test)
print(trainX.shape, trainY.shape, testX.shape, testY.shape)

### define a model
m_query = pairwise_dist(qry_res, qry_res)   # pairwise distance matrix for Encoder and Decoder
m_cross = pairwise_dist(qry_res, ltt_res)   # pairwise distance matrix for Encoder and Decoder
m_latent = pairwise_dist(ltt_res, ltt_res)  # pairwise distance matrix for Processor
network = PiT(m_query, m_cross, m_latent, out_dim, encode_dim, n_head, en_loc, de_loc)
#network = LiteTransformer(m_query, m_cross, out_dim, encode_dim, n_head, en_loc, de_loc)
#network = Transformer(qry_res, out_dim, encode_dim, n_head)

inputs = tf.keras.Input(shape=(qry_res, trainX.shape[-1]))
outputs = network(inputs)
model = tf.keras.Model(inputs, outputs)
network.summary()

### compile model
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=tf.keras.optimizers.schedules.CosineDecay(lr, n_epochs*(n_train//batch_size))),
              loss=rel_norm(),
              jit_compile=True)

### fit model
start = time()
train_history = model.fit(trainX, trainY,
                          batch_size, n_epochs, verbose=1,  # verbose to 0 when nohup, to 1 when run in shell
                          validation_data=(testX, testY),
                          validation_batch_size=128)
end = time()
print(' ')
print('Training cost is {} seconds per epoch !'.format((end-start)/n_epochs))
print(' ')

### save model
model.save_weights(save_path + 'checkpoints/my_checkpoint')

### plot training history
loss_hist = train_history.history
train_loss = tf.reshape(tf.convert_to_tensor(loss_hist["loss"]), (-1,1))
test_loss = tf.reshape(tf.convert_to_tensor(loss_hist["val_loss"]), (-1,1))
savetxt(save_path + 'training_history.csv', tf.concat([train_loss, test_loss], axis=-1).numpy(), delimiter=',', header='train,test', fmt='%1.16f', comments='')

plt.plot(train_loss, label='train')
plt.plot(test_loss, label='test')
plt.legend()
plt.yscale('log', base=10)
plt.savefig(save_path + 'training_history.png')
plt.close()

### do some tests
pred = model.predict(testX)
savemat(save_path + 'pred.mat', mdict={'pred':pred, 'X':testX, 'Y':testY})
print(rel_l1_median(testY, pred))

--------------------------------------------------------------------------------
/tensorflow/2_ShockTube/train.py:
--------------------------------------------------------------------------------
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['PYTHONHASHSEED'] = '1'

from numpy import savetxt
from scipy.io import savemat
import matplotlib.pyplot as plt
from time import time

### custom imports
from utils import *

### params
n_epochs = 500
lr = 0.001
batch_size = 8
encode_dim = 64
out_dim = 1
n_head = 2
n_train = 1024
n_test = 128
qry_res = 2048
ltt_res = 1024
en_loc = 4  # locality parameter in the Encoder
de_loc = 2  # locality parameter in the Decoder
save_path = './model/'

### load dataset
trainX, trainY, testX, testY = load_data("./2_ShockTube.mat", n_train, n_test)
print(trainX.shape, trainY.shape, testX.shape, testY.shape)

### define a model
m_query = pairwise_dist(qry_res, qry_res)   # pairwise distance matrix for Encoder and Decoder
m_cross = pairwise_dist(qry_res, ltt_res)   # pairwise distance matrix for Encoder and Decoder
m_latent = pairwise_dist(ltt_res, ltt_res)  # pairwise distance matrix for Processor
network = PiT(m_query, m_cross, m_latent, out_dim, encode_dim, n_head, en_loc, de_loc)
#network = LiteTransformer(m_query, m_cross, out_dim, encode_dim, n_head, en_loc, de_loc)
#network = Transformer(qry_res, out_dim, encode_dim, n_head)

inputs = tf.keras.Input(shape=(qry_res, trainX.shape[-1]))
outputs = network(inputs)
model = tf.keras.Model(inputs, outputs)
network.summary()

### compile model
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=tf.keras.optimizers.schedules.CosineDecay(lr, n_epochs*(n_train//batch_size))),
              loss=rel_norm(),
              jit_compile=True)  # jit_compile always on, for faster training

### fit model
start = time()
train_history = model.fit(trainX, trainY,
                          batch_size, n_epochs, verbose=1,  # verbose to 0 when nohup, to 1 when run in shell
                          validation_data=(testX, testY),
                          validation_batch_size=128)
end = time()
print(' ')
print('Training cost is {} seconds per epoch !'.format((end-start)/n_epochs))
print(' ')

### save model
model.save_weights(save_path + 'checkpoints/my_checkpoint')

### plot training history
loss_hist = train_history.history
train_loss = tf.reshape(tf.convert_to_tensor(loss_hist["loss"]), (-1,1))
test_loss = tf.reshape(tf.convert_to_tensor(loss_hist["val_loss"]), (-1,1))
savetxt(save_path + 'training_history.csv', tf.concat([train_loss, test_loss], axis=-1).numpy(), delimiter=',', header='train,test', fmt='%1.16f', comments='')

plt.plot(train_loss, label='train')
plt.plot(test_loss, label='test')
plt.legend()
plt.yscale('log', base=10)
plt.savefig(save_path + 'training_history.png')
plt.close()

### do some tests
pred = model.predict(testX)
savemat(save_path + 'pred.mat', mdict={'pred':pred, 'X':testX, 'Y':testY})
print(rel_l1_median(testY, pred))

--------------------------------------------------------------------------------
/tensorflow/2_ShockTube/utils.py:
--------------------------------------------------------------------------------
from numpy import newaxis
from scipy.io import loadmat
import tensorflow as tf
tf.keras.utils.set_random_seed(0)
tf.config.experimental.enable_op_determinism()
tf.random.set_seed(0)
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.set_visible_devices(physical_devices[0:], 'GPU')
import tensorflow_probability as tfp

class rel_norm(tf.keras.losses.Loss):
    '''
    Compute the average relative l1 loss between a batch of targets and predictions
    '''
    def __init__(self):
        super().__init__()

    def call(self, true, pred):
        '''
        true: (batch_size, L, d).
        pred: (batch_size, L, d).
        number of variables d=1
        '''
        rel_error = tf.math.divide(tf.norm(tf.keras.layers.Reshape((-1,))(true-pred), ord=1, axis=1), tf.norm(tf.keras.layers.Reshape((-1,))(true), ord=1, axis=1))
        return tf.math.reduce_mean(rel_error)

def rel_l1_median(true, pred):
    '''
    Compute the 25%, 50%, and 75% quantiles of the relative l1 loss between a batch of targets and predictions
    '''
    rel_error = tf.norm(true[...,0]-pred[...,0], ord=1, axis=1) / tf.norm(true[...,0], ord=1, axis=1)
    return tfp.stats.percentile(rel_error, 25, interpolation="linear").numpy(), tfp.stats.percentile(rel_error, 50, interpolation="linear").numpy(), tfp.stats.percentile(rel_error, 75, interpolation="linear").numpy()

def pairwise_dist(res1, res2):

    grid1 = tf.reshape(tf.linspace(0, 1, res1+1)[:-1], (-1,1))
    grid1 = tf.tile(grid1, [1,res2])
    grid2 = tf.reshape(tf.linspace(0, 1, res2+1)[:-1], (1,-1))
    grid2 = tf.tile(grid2, [res1,1])
    print(grid1.shape, grid2.shape)

    dist2 = (grid1-grid2)**2
    print(dist2.shape, tf.math.reduce_max(dist2))

    return tf.cast(dist2, 'float32')

def load_data(path_data, ntrain=1024, ntest=128):

    data = loadmat(path_data)
    X_data = data["x"].astype('float32')
    Y_data = data["y"].astype('float32')
    X_train = X_data[:ntrain,:]
    Y_train = Y_data[:ntrain,:, newaxis]

    X_test = X_data[-ntest:,:]
    Y_test = Y_data[-ntest:,:, newaxis]

    return X_train, Y_train, X_test, Y_test

class mlp(tf.keras.layers.Layer):
    '''
    A two-layer MLP with GELU activation.
    '''
    def __init__(self, n_filters1, n_filters2):
        super(mlp, self).__init__()

        self.width1 = n_filters1
        self.width2 = n_filters2
        self.mlp1 = tf.keras.layers.Dense(self.width1, activation='gelu', kernel_initializer="he_normal")
        self.mlp2 = tf.keras.layers.Dense(self.width2, kernel_initializer="he_normal")

    def call(self, inputs):
        x = self.mlp1(inputs)
        x = self.mlp2(x)
        return x

    def get_config(self):
        config = {
            'n_filters1': self.width1,
            'n_filters2': self.width2,
        }
        return config

class MultiHeadPosAtt(tf.keras.layers.Layer):
    '''
    Global, local and cross variants of the multi-head position-attention mechanism.
    '''
    def __init__(self, m_dist, n_head, hid_dim, locality):
        super(MultiHeadPosAtt, self).__init__()
        '''
        m_dist: distance matrix
        n_head: number of attention heads
        hid_dim: encoding dimension
        locality: quantile parameter to customize the receptive field in position-attention
        '''
        self.dist = m_dist
        self.locality = locality
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.v_dim = round(self.hid_dim/self.n_head)

    def build(self, input_shape):

        self.r = self.add_weight(
            shape=(self.n_head, 1, 1),
            trainable=True,
            name="band_width",
        )

        self.weight = self.add_weight(
            shape=(self.n_head, input_shape[-1], self.v_dim),
            initializer="he_normal",
            trainable=True,
            name="weight",
        )
        self.built = True

    def call(self, inputs):
        scaled_dist = self.dist * self.r**2
        if self.locality <= 100:
            mask = tfp.stats.percentile(scaled_dist, self.locality, interpolation="linear", axis=-1, keepdims=True)
            scaled_dist = tf.where(scaled_dist<=mask, scaled_dist, tf.float32.max)
        else:
            pass  # locality > 100: global position-attention, no mask
        scaled_dist = - scaled_dist  # (n_head, L, L)
        att = tf.nn.softmax(scaled_dist, axis=2)  # (n_heads, L, L)

        value = tf.einsum("bnj,hjk->bhnk", inputs, self.weight)  # (batch_size, n_head, L, v_dim)

        concat = tf.einsum("hnj,bhjd->bhnd", att, value)  # (batch_size, n_head, L, v_dim)
        concat = tf.transpose(concat, (0,2,1,3))  # (batch_size, L, n_head, v_dim)
        concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat)  # (batch_size, L, hid_dim)
        return tf.keras.activations.gelu(concat)

    def get_config(self):
        config = {
            'm_dist': self.dist,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head,
            'locality': self.locality
        }
        return config

class PiT(tf.keras.Model):
    '''
    Position-induced Transformer, built upon the multi-head position-attention mechanism.
    PiT can be trained to decompose and learn the global and local dependencies of operators in partial differential equations.
    '''
    def __init__(self, m_qry, m_cross, m_ltt, out_dim, hid_dim, n_head, locality_encoder, locality_decoder):
        super(PiT, self).__init__()
        '''
        m_qry: distance matrix between X_query and X_query; (L_qry,L_qry)
        m_cross: distance matrix between X_query and X_latent; (L_qry,L_ltt)
        m_ltt: distance matrix between X_latent and X_latent; (L_ltt,L_ltt)
        out_dim: number of variables
        hid_dim: encoding dimension (network width)
        n_head: number of heads in multi-head attention modules
        locality_encoder: quantile parameter of local position-attention in the Encoder, used to customize the size of the receptive field
        locality_decoder: quantile parameter of local position-attention in the Decoder, used to customize the size of the receptive field
        '''
        self.m_qry = m_qry
        self.res = m_qry.shape[0]
        self.m_cross = m_cross
        self.m_ltt = m_ltt
        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.en_local = locality_encoder
        self.de_local = locality_decoder
        self.n_blocks = 4  # number of position-attention modules in the Processor

        # Encoder
        self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal")
        self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local)

        # Processor
        self.MHPA = [MultiHeadPosAtt(self.m_ltt, self.n_head, self.hid_dim, locality=200) for i in range(self.n_blocks)]
        self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)]
        self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)]

        # Decoder
        self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local)
        self.up2 = MultiHeadPosAtt(self.m_qry, self.n_head, self.hid_dim, locality=self.de_local)
        self.mlp = mlp(self.hid_dim, self.hid_dim)
        self.w = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal")
        self.de_layer = mlp(self.hid_dim, self.out_dim)

    def call(self, inputs):

        # Encoder
        grid = self.get_mesh(inputs)  # (batch_size, L_qry, 1)
        en = tf.concat([grid, inputs], axis=-1)  # (batch_size, L_qry, input_dim+1)
        en = self.en_layer(en)  # (batch_size, L_qry, hid_dim)
        x = self.down(en)  # (batch_size, L_ltt, hid_dim)

        # Processor
        for i in range(self.n_blocks):
            x = self.MLP[i](self.MHPA[i](x)) + self.W[i](x)  # (batch_size, L_ltt, hid_dim)
            x = tf.keras.activations.gelu(x)  # (batch_size, L_ltt, hid_dim)

        # Decoder
        de = self.up(x)  # (batch_size, L_qry, hid_dim)
        de = self.mlp(self.up2(de)) + self.w(de)  # (batch_size, L_qry, hid_dim)
        de = tf.keras.activations.gelu(de)  # (batch_size, L_qry, hid_dim)
        de = self.de_layer(de)  # (batch_size, L_qry, out_dim)
        return de

    def get_mesh(self, inputs):
        grid = tf.reshape(tf.linspace(0, 1, self.res+1)[:-1], (1,-1,1))
209 | grid = tf.repeat(grid, tf.shape(inputs)[0], 0) 210 | return tf.cast(grid, dtype="float32") 211 | 212 | def get_config(self): 213 | config = { 214 | 'm_qry': self.m_qry, 215 | 'm_cross': self.m_cross, 216 | 'm_ltt': self.m_ltt, 217 | 'out_dim': self.out_dim, 218 | 'hid_dim': self.hid_dim, 219 | 'n_head': self.n_head, 220 | 'locality_encoder': self.en_local, 221 | 'locality_decoder': self.de_local, 222 | } 223 | return config 224 | 225 | class MultiHeadSelfAtt(tf.keras.layers.Layer): 226 | ''' 227 | Scaled dot-product multi-head self-attention 228 | ''' 229 | def __init__(self, n_head, hid_dim): 230 | super(MultiHeadSelfAtt, self).__init__() 231 | 232 | self.hid_dim = hid_dim 233 | self.n_head = n_head 234 | self.v_dim = round(self.hid_dim/self.n_head) 235 | 236 | def build(self, input_shape): 237 | 238 | self.q = self.add_weight( 239 | shape=(self.n_head, input_shape[-1], self.v_dim), 240 | initializer="he_normal", 241 | trainable=True, 242 | name="query", 243 | ) 244 | 245 | self.k = self.add_weight( 246 | shape=(self.n_head, input_shape[-1], self.v_dim), 247 | initializer="he_normal", 248 | trainable=True, 249 | name="key", 250 | ) 251 | 252 | self.v = self.add_weight( 253 | shape=(self.n_head, input_shape[-1], self.v_dim), 254 | initializer="he_normal", 255 | trainable=True, 256 | name="value", 257 | ) 258 | self.built = True 259 | 260 | def call(self, inputs): 261 | 262 | query = tf.einsum("bnj,hjk->bhnk", inputs, self.q) # (batch_size, n_head, L, v_dim) 263 | key = tf.einsum("bnj,hjk->bhnk", inputs, self.k) # (batch_size, n_head, L, v_dim) 264 | att = tf.nn.softmax(tf.einsum("...ij,...kj->...ik", query, key)/self.v_dim**0.5, axis=-1) # (batch_size, n_heads, L, L) 265 | 266 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.v)#(batch_size, n_head, L, v_dim) 267 | 268 | concat = tf.einsum("...nj,...jd->...nd", att, value) # (batch_size, n_head, L, v_dim) 269 | concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim) 270 | concat = 
tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim) 271 | return tf.keras.activations.gelu(concat) 272 | 273 | def get_config(self): 274 | config = { 275 | 'n_head': self.n_head, 276 | 'hid_dim': self.hid_dim 277 | } 278 | return config 279 | 280 | class LiteTransformer(tf.keras.Model): 281 | ''' 282 | Replace position-attention of the Processor in a PiT with self-attention 283 | ''' 284 | def __init__(self, m_qry, m_cross, out_dim, hid_dim, n_head, locality_encoder, locality_decoder): 285 | super(LiteTransformer, self).__init__() 286 | 287 | self.m_qry = m_qry 288 | self.res = m_qry.shape[0] 289 | self.m_cross = m_cross 290 | self.out_dim = out_dim 291 | self.hid_dim = hid_dim 292 | self.n_head = n_head 293 | self.en_local = locality_encoder 294 | self.de_local = locality_decoder 295 | self.n_blocks = 4 296 | 297 | # Encoder 298 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 299 | self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local) 300 | 301 | # Processor 302 | self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)] 303 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 304 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 305 | 306 | # Decoder 307 | self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local) 308 | self.up2 = MultiHeadPosAtt(self.m_qry, self.n_head, self.hid_dim, locality=self.de_local) 309 | self.mlp = mlp(self.hid_dim, self.hid_dim) 310 | self.w = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 311 | self.de_layer = mlp(self.hid_dim, self.out_dim) 312 | 313 | def call(self, inputs): 314 | 315 | # Encoder 316 | grid = self.get_mesh(inputs) 317 | en = tf.concat([grid, inputs], axis=-1) 318 | en = self.en_layer(en) 319 | x = self.down(en) 320 | 
321 | # Processor 322 | for i in range(self.n_blocks): 323 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 324 | x = tf.keras.activations.gelu(x) 325 | 326 | # Decoder 327 | de = self.up(x) 328 | de = self.mlp(self.up2(de)) + self.w(de) 329 | de = tf.keras.activations.gelu(de) 330 | de = self.de_layer(de) 331 | return de 332 | 333 | def get_mesh(self, inputs): 334 | grid = tf.reshape(tf.linspace(0, 1, self.res+1)[:-1], (1,-1,1)) 335 | grid = tf.repeat(grid, tf.shape(inputs)[0], 0) 336 | return tf.cast(grid, dtype="float32") 337 | 338 | def get_config(self): 339 | config = { 340 | 'm_qry':self.m_qry, 341 | 'm_cross':self.m_cross, 342 | 'out_dim': self.out_dim, 343 | 'hid_dim': self.hid_dim, 344 | 'n_head': self.n_head, 345 | 'locality_encoder': self.en_local, 346 | 'locality_decoder': self.de_local 347 | } 348 | return config 349 | 350 | class Transformer(tf.keras.Model): 351 | ''' 352 | Replace position-attention of a PiT with self-attention. 353 | ''' 354 | def __init__(self, res, out_dim, hid_dim, n_head): 355 | super(Transformer, self).__init__() 356 | 357 | self.res = res 358 | self.out_dim = out_dim 359 | self.hid_dim = hid_dim 360 | self.n_head = n_head 361 | self.n_blocks = 4 362 | 363 | # Encoder 364 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 365 | self.down = MultiHeadSelfAtt(self.n_head, self.hid_dim) 366 | 367 | # Processor 368 | self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)] 369 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 370 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 371 | 372 | # Decoder 373 | self.up = MultiHeadSelfAtt(self.n_head, self.hid_dim) 374 | self.up2 = MultiHeadSelfAtt(self.n_head, self.hid_dim) 375 | self.mlp = mlp(self.hid_dim, self.hid_dim) 376 | self.w = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 377 | 
self.de_layer = mlp(self.hid_dim, self.out_dim) 378 | 379 | def call(self, inputs): 380 | 381 | # Encoder 382 | grid = self.get_mesh(inputs) 383 | en = tf.concat([grid, inputs], axis=-1) 384 | en = self.en_layer(en) 385 | x = self.down(en) 386 | 387 | # Processor 388 | for i in range(self.n_blocks): 389 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 390 | x = tf.keras.activations.gelu(x) 391 | 392 | # Decoder 393 | de = self.up(x) 394 | de = self.mlp(self.up2(de)) + self.w(de) 395 | de = tf.keras.activations.gelu(de) 396 | de = self.de_layer(de) 397 | return de 398 | 399 | def get_mesh(self, inputs): 400 | grid = tf.reshape(tf.linspace(0, 1, self.res+1)[:-1], (1,-1,1)) 401 | grid = tf.repeat(grid, tf.shape(inputs)[0], 0) 402 | return tf.cast(grid, dtype="float32") 403 | 404 | def get_config(self): 405 | config = { 406 | 'res':self.res, 407 | 'out_dim': self.out_dim, 408 | 'hid_dim': self.hid_dim, 409 | 'n_head': self.n_head, 410 | } 411 | return config 412 | 413 | class SelfMultiHeadPosAtt(tf.keras.layers.Layer): 414 | ''' 415 | Combine self-attention with position-attention: A = QK^T/sqrt(d) - lambda D 416 | ''' 417 | def __init__(self, m_dist, n_head, hid_dim, locality): 418 | super(SelfMultiHeadPosAtt, self).__init__() 419 | 420 | self.dist = m_dist 421 | self.locality = locality 422 | self.hid_dim = hid_dim 423 | self.n_head = n_head 424 | self.v_dim = round(self.hid_dim/self.n_head) 425 | 426 | def build(self, input_shape): 427 | 428 | 429 | self.r = self.add_weight( 430 | shape=(self.n_head, 1, 1), 431 | trainable=True, 432 | constraint=tf.keras.constraints.NonNeg(), 433 | name="band_width", 434 | ) 435 | 436 | self.query = self.add_weight( 437 | shape=(self.n_head, input_shape[-1], self.v_dim), 438 | trainable=True, 439 | name="query", 440 | ) 441 | 442 | self.key = self.add_weight( 443 | shape=(self.n_head, input_shape[-1], self.v_dim), 444 | trainable=True, 445 | name="key", 446 | ) 447 | 448 | self.weight = self.add_weight( 449 | shape=(self.n_head, 
input_shape[-1], self.v_dim), 450 | initializer="he_normal", 451 | trainable=True, 452 | name="weight", 453 | ) 454 | self.built = True 455 | 456 | def call(self, inputs): 457 | 458 | scaled_dist = self.dist * self.r**2 459 | if self.locality <= 100: 460 | mask = tfp.stats.percentile(scaled_dist, self.locality, interpolation="linear", axis=-1, keepdims=True) 461 | scaled_dist = tf.where(scaled_dist<=mask, scaled_dist, tf.float32.max) 462 | else: 463 | pass 464 | scaled_dist = - scaled_dist 465 | 466 | query = tf.einsum("bnj,hjk->bhnk", inputs, self.query) 467 | key = tf.einsum("bnj,hjk->bhnk", inputs, self.key) 468 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.weight) 469 | product = tf.einsum("...mi,...ni->...mn", query, key) 470 | att = product / self.v_dim**0.5 + scaled_dist[tf.newaxis,...] 471 | att = tf.nn.softmax(att, axis=-1) 472 | concat = tf.einsum("...nj,...jd->...nd", att, value) 473 | concat = tf.transpose(concat, (0,2,1,3)) 474 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) 475 | 476 | return tf.keras.activations.gelu(concat) 477 | 478 | class SelfPiT(tf.keras.Model): 479 | ''' 480 | Replace position-attention of a PiT by SelfMultiHeadPosAtt 481 | ''' 482 | def __init__(self, m_qry, m_cross, m_ltt, out_dim, hid_dim, n_head, locality_encoder, locality_decoder): 483 | super(SelfPiT, self).__init__() 484 | 485 | self.m_qry = m_qry 486 | self.res = m_qry.shape[0] 487 | self.m_cross = m_cross 488 | self.m_ltt = m_ltt 489 | self.out_dim = out_dim 490 | self.hid_dim = hid_dim 491 | self.n_head = n_head 492 | self.en_local = locality_encoder 493 | self.de_local = locality_decoder 494 | self.n_blocks = 4 495 | 496 | # Encoder 497 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 498 | self.down = SelfMultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local) 499 | 500 | # Processor 501 | self.MHPA = [SelfMultiHeadPosAtt(self.m_ltt, self.n_head, 
self.hid_dim, locality=200) for i in range(self.n_blocks)] 502 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 503 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 504 | 505 | # Decoder 506 | self.up = SelfMultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local) 507 | self.up2 = SelfMultiHeadPosAtt(self.m_qry, self.n_head, self.hid_dim, locality=self.de_local) 508 | self.mlp = mlp(self.hid_dim, self.hid_dim) 509 | self.w = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 510 | self.de_layer = mlp(self.hid_dim, self.out_dim) 511 | 512 | def call(self, inputs): 513 | 514 | # Encoder 515 | grid = self.get_mesh(inputs) 516 | en = tf.concat([grid, inputs], axis=-1) 517 | en = self.en_layer(en) 518 | x = self.down(en) 519 | 520 | # Processor 521 | for i in range(self.n_blocks): 522 | x = self.MLP[i](self.MHPA[i](x)) + self.W[i](x) 523 | x = tf.keras.activations.gelu(x) 524 | 525 | # Decoder 526 | de = self.up(x) 527 | de = self.mlp(self.up2(de)) + self.w(de) 528 | de = tf.keras.activations.gelu(de) 529 | de = self.de_layer(de) 530 | return de 531 | 532 | def get_mesh(self, inputs): 533 | grid = tf.reshape(tf.linspace(0, 1, self.res+1)[:-1], (1,-1,1)) 534 | grid = tf.repeat(grid, tf.shape(inputs)[0], 0) 535 | return tf.cast(grid, dtype="float32") 536 | 537 | def get_config(self): 538 | config = { 539 | 'm_qry': self.m_qry, 540 | 'm_cross': self.m_cross, 541 | 'm_ltt': self.m_ltt, 542 | 'out_dim': self.out_dim, 543 | 'hid_dim': self.hid_dim, 544 | 'n_head': self.n_head, 545 | 'locality_encoder': self.en_local, 546 | 'locality_decoder': self.de_local, 547 | } 548 | return config 549 | 550 | 551 | -------------------------------------------------------------------------------- /tensorflow/3_Darcy2D/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] 
= '2' 3 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 4 | os.environ['PYTHONHASHSEED'] = '1' 5 | from scipy.io import savemat 6 | from utils import *#load_data, pairwise_dist, PixelWiseNormalization 7 | from numpy import zeros 8 | import matplotlib.pyplot as plt 9 | downsampling = 10 10 | qry_res = int((421-1)/downsampling+1) 11 | ltt_res = 32 12 | en_loc = 2 13 | de_loc = 5 14 | encode_dim = 128 15 | out_dim = 1 16 | n_head = 2 17 | n_train = 1024 18 | n_test = 100 19 | 20 | 21 | ## Load training data and construct normalizer 22 | trainX, trainY, testX, testY= load_data("./piececonst_r421_N1024_smooth1.mat", "./piececonst_r421_N1024_smooth2.mat", downsampling, n_train, n_test) 23 | print(trainX.shape, trainY.shape, testX.shape, testY.shape) 24 | XNormalizer = PixelWiseNormalization(trainX) 25 | YNormalizer = PixelWiseNormalization(trainY) 26 | 27 | ## create model, load trained weights 28 | m_cross = pairwise_dist(qry_res, qry_res, ltt_res, ltt_res) # pairwise distance matrix for encoder and decoder 29 | m_latent = pairwise_dist(ltt_res, ltt_res, ltt_res, ltt_res) # pairwise distance matrix for processor 30 | network = PiT(m_cross, m_latent, out_dim, encode_dim, n_head, en_loc, de_loc, YNormalizer) 31 | inputs = tf.keras.Input(shape=(qry_res,qry_res,trainX.shape[-1])) 32 | outputs = network(inputs) 33 | model = tf.keras.Model(inputs, outputs) 34 | model.load_weights('./results_43/checkpoints/my_checkpoint').expect_partial() 35 | 36 | #testX = XNormalizer.normalize(testX) 37 | #pred = tf.convert_to_tensor(model.predict(testX), "float32") 38 | #rel_err = rel_norm()(testY,pred) 39 | #print(rel_err) 40 | #pred = tf.convert_to_tensor(network(testX), "float32") 41 | #rel_err = rel_norm()(testY,pred) 42 | #print(rel_err) 43 | 44 | #network.summary() 45 | ################## 46 | # zero-shot super-resolution 47 | downsampling = 1 48 | qry_res = int((421-1)/downsampling+1) 49 | trainX, trainY, testX, testY= load_data("./piececonst_r421_N1024_smooth1.mat",
"./piececonst_r421_N1024_smooth2.mat", downsampling, n_train, n_test) 50 | m_cross = pairwise_dist(qry_res, qry_res, ltt_res, ltt_res) # pairwise distance matrix for encoder and decoder 51 | network2 = PiT(m_cross, m_latent, out_dim, encode_dim, n_head, 2, 5, YNormalizer) 52 | inputs2 = tf.keras.Input(shape=(qry_res,qry_res,trainX.shape[-1])) 53 | outputs2 = network2(inputs2) 54 | model2 = tf.keras.Model(inputs2, outputs2) 55 | model2.set_weights(model.get_weights()) 56 | 57 | testX = XNormalizer.normalize(testX) 58 | pred = tf.convert_to_tensor(model2.predict(testX), "float32") 59 | rel_err = rel_norm()(testY,pred) 60 | print(rel_err) 61 | 62 | ######### plots 63 | index = 43 64 | a = XNormalizer.denormalize(testX)[index,:,:,0] 65 | u = testY[index,:,:,0] 66 | pred = pred[index,:,:,0] 67 | 68 | abs_err = tf.math.abs(u-pred) * 10000 69 | emax = tf.math.reduce_max(abs_err) 70 | emin = tf.math.reduce_min(abs_err) 71 | print(emax, emin) 72 | amax = tf.math.reduce_max(a) 73 | amin = tf.math.reduce_min(a) 74 | umax = tf.math.reduce_max(u) 75 | umin = tf.math.reduce_min(u) 76 | print(amax, amin, umax, umin) 77 | 78 | # plot the contours 79 | plt.figure(figsize=(17,4),dpi=300) 80 | plt.subplot(141) 81 | plt.imshow(a, vmax=12, vmin=3, interpolation='spline16', cmap='jet') 82 | plt.colorbar(location="bottom", fraction=0.046, pad=0.04, ticks=[3, 12], format='%1.0f') 83 | plt.axis('off') 84 | plt.axis("equal") 85 | plt.title('Permeability') 86 | 87 | plt.subplot(142) 88 | plt.imshow(u, vmax=umax, vmin=umin, interpolation='spline16', cmap='jet') 89 | plt.colorbar(location="bottom", fraction=0.046, pad=0.04, ticks=[umin, umax], format='%1.3f') 90 | plt.axis('off') 91 | plt.axis("equal") 92 | plt.title('Reference') 93 | 94 | plt.subplot(143) 95 | plt.imshow(pred, vmax=umax, vmin=umin, interpolation='spline16', cmap='jet') 96 | plt.colorbar(location="bottom", fraction=0.046, pad=0.04, ticks=[umin, umax], format='%1.3f') 97 | plt.axis('off') 98 | plt.axis("equal") 99 | 
plt.title('Prediction') 100 | 101 | plt.subplot(144) 102 | plt.imshow(abs_err, vmax=emax, vmin=0, interpolation='spline16', cmap='jet') 103 | plt.colorbar(location="bottom", fraction=0.046, pad=0.04, ticks=[0, emax], format='%1.3f') 104 | plt.axis('off') 105 | plt.axis("equal") 106 | plt.title('Absolute error ('+r"$\times 10^{-4}$"+")") 107 | plt.savefig('./prediction.pdf') 108 | plt.close() 109 | 110 | 111 | -------------------------------------------------------------------------------- /tensorflow/3_Darcy2D/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 4 | os.environ['PYTHONHASHSEED'] = '1' 5 | 6 | from numpy import savetxt 7 | from scipy.io import savemat 8 | import matplotlib.pyplot as plt 9 | from time import time 10 | 11 | ### custom imports 12 | from utils import * 13 | 14 | # params ################################## 15 | n_epochs = 500 16 | lr = 0.001 17 | batch_size = 8 18 | encode_dim = 128 19 | out_dim = 1 20 | n_head = 2 21 | n_train = 1024 22 | n_test = 100 23 | downsampling = 2 24 | qry_res = int((421-1)/downsampling+1) 25 | ltt_res = 32 26 | en_loc = 2 27 | de_loc = 5 28 | # load dataset 29 | trainX, trainY, testX, testY= load_data("./piececonst_r421_N1024_smooth1.mat", "./piececonst_r421_N1024_smooth2.mat", downsampling, n_train, n_test) 30 | print(trainX.shape, trainY.shape, testX.shape, testY.shape) 31 | 32 | ### normalize input data 33 | XNormalizer = PixelWiseNormalization(trainX) 34 | trainX = XNormalizer.normalize(trainX) 35 | testX = XNormalizer.normalize(testX) 36 | YNormalizer = PixelWiseNormalization(trainY) 37 | 38 | ### define a model 39 | m_cross = pairwise_dist(qry_res, qry_res, ltt_res, ltt_res) # pairwise distance matrix for encoder and decoder 40 | m_latent = pairwise_dist(ltt_res, ltt_res, ltt_res, ltt_res) # pairwise distance matrix for processor 41 | network = PiT(m_cross,
m_latent, out_dim, encode_dim, n_head, en_loc, de_loc, YNormalizer) 42 | inputs = tf.keras.Input(shape=(qry_res,qry_res,trainX.shape[-1])) 43 | outputs = network(inputs) 44 | model = tf.keras.Model(inputs, outputs) 45 | network.summary() 46 | ### compile model 47 | model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=tf.keras.optimizers.schedules.CosineDecay(lr, n_epochs*(n_train//batch_size))), 48 | loss=rel_norm(), 49 | jit_compile=True) 50 | 51 | ### fit model 52 | start = time() 53 | train_history = model.fit(trainX, trainY, 54 | batch_size, n_epochs, verbose=1, 55 | validation_data=(testX, testY), 56 | validation_batch_size=50) 57 | end = time() 58 | print(' ') 59 | print('Training cost is {} seconds per epoch !'.format((end-start)/n_epochs)) 60 | 61 | print(' ') 62 | ### save model 63 | model.save_weights('./checkpoints/my_checkpoint') 64 | ### training history 65 | loss_hist = train_history.history 66 | train_loss = tf.reshape(tf.convert_to_tensor(loss_hist["loss"]), (-1,1)) 67 | test_loss = tf.reshape(tf.convert_to_tensor(loss_hist["val_loss"]), (-1,1)) 68 | savetxt('./training_history.csv', tf.concat([train_loss, test_loss], axis=-1).numpy(), delimiter=',', header='train,test', fmt='%1.16f', comments='') 69 | 70 | plt.plot(train_loss, label='train') 71 | plt.plot(test_loss, label='test') 72 | plt.legend() 73 | plt.yscale('log', base=10) 74 | plt.savefig('training_history.png') 75 | plt.close() 76 | -------------------------------------------------------------------------------- /tensorflow/3_Darcy2D/utils.py: -------------------------------------------------------------------------------- 1 | from numpy import mean, std, concatenate, newaxis, zeros 2 | from scipy.io import loadmat 3 | import tensorflow as tf 4 | tf.keras.utils.set_random_seed(0) 5 | tf.config.experimental.enable_op_determinism() 6 | tf.random.set_seed(0) 7 | physical_devices = tf.config.list_physical_devices('GPU') 8 | tf.config.set_visible_devices(physical_devices[0:],'GPU') 9 | 
import tensorflow_probability as tfp 10 | from math import pi 11 | 12 | class PixelWiseNormalization(): 13 | def __init__(self, x, eps=1e-5): 14 | 15 | self.mean = tf.math.reduce_mean(x, axis=0, keepdims=True) #(1,h,w,1) 16 | self.std = tf.math.reduce_std(x, axis=0, keepdims=True) #(1,h,w,1) 17 | self.eps = eps 18 | 19 | def normalize(self, x): 20 | try: 21 | x = (x - self.mean) / (self.std + self.eps) 22 | except:#do upsampling 23 | h = x.shape[1] 24 | w = x.shape[2] 25 | mean = tf.image.resize(self.mean, (h,w), method='bilinear') 26 | std = tf.image.resize(self.std, (h,w), method='bilinear') 27 | x = (x - mean) / (std + self.eps) 28 | return x 29 | 30 | def denormalize(self, x): 31 | 32 | try: 33 | x = x * (self.std + self.eps) + self.mean 34 | except:#do upsampling 35 | h = x.shape[1] 36 | w = x.shape[2] 37 | mean = tf.image.resize(self.mean, (h,w), method='bilinear') 38 | std = tf.image.resize(self.std, (h,w), method='bilinear') 39 | x = x * (std + self.eps) + mean 40 | return x 41 | 42 | class rel_norm(tf.keras.losses.Loss): 43 | ''' 44 | Compute the average relative l2 loss between a batch of targets and predictions 45 | ''' 46 | def __init__(self): 47 | super().__init__() 48 | def call(self, true, pred): 49 | rel_error = tf.math.divide(tf.norm(tf.keras.layers.Reshape((-1,))(true-pred), axis=1), tf.norm(tf.keras.layers.Reshape((-1,))(true), axis=1)) 50 | return tf.math.reduce_mean(rel_error) 51 | 52 | 53 | def pairwise_dist(res1x, res1y, res2x, res2y): 54 | 55 | gridx = tf.reshape(tf.linspace(0, 1, res1x+1)[:-1], (1,-1,1)) 56 | gridx = tf.tile(gridx, [res1y,1,1]) 57 | gridy = tf.reshape(tf.linspace(0, 1, res1y+1)[:-1], (-1,1,1)) 58 | gridy = tf.tile(gridy, [1,res1x,1]) 59 | grid1 = tf.concat([gridx, gridy], axis=-1) 60 | grid1 = tf.reshape(grid1, (res1x*res1y,2)) #(res1*res1,2) 61 | 62 | gridx = tf.reshape(tf.linspace(0, 1, res2x+1)[:-1], (1,-1,1)) 63 | gridx = tf.tile(gridx, [res2y,1,1]) 64 | gridy = tf.reshape(tf.linspace(0, 1, res2y+1)[:-1], (-1,1,1)) 65 | 
gridy = tf.tile(gridy, [1,res2x,1]) 66 | grid2 = tf.concat([gridx, gridy], axis=-1) 67 | grid2 = tf.reshape(grid2, (res2x*res2y,2)) #(res2*res2,2) 68 | 69 | print(grid1.shape, grid2.shape) 70 | grid1 = tf.tile(tf.expand_dims(grid1, 1), [1,grid2.shape[0],1]) 71 | grid2 = tf.tile(tf.expand_dims(grid2, 0), [grid1.shape[0],1,1]) 72 | 73 | dist = tf.norm(grid1-grid2, axis=-1) 74 | print(dist.shape, tf.math.reduce_max(dist)) 75 | dist2 = tf.cast(dist**2, 'float32') 76 | return dist2/2.0 77 | 78 | def load_data(train_path, test_path, downsampling, ntrain, ntest): 79 | 80 | s = int(((421 - 1)/downsampling) + 1) 81 | 82 | train = loadmat(train_path) 83 | a = train["coeff"].astype('float32') 84 | u = train["sol"].astype('float32') 85 | 86 | trainX = a[:ntrain,::downsampling,::downsampling][:,:s,:s] 87 | trainY = u[:ntrain,::downsampling,::downsampling][:,:s,:s] 88 | 89 | test = loadmat(test_path) 90 | a = test["coeff"].astype('float32') 91 | u = test["sol"].astype('float32') 92 | testX = a[:ntest,::downsampling,::downsampling][:,:s,:s] 93 | testY = u[:ntest,::downsampling,::downsampling][:,:s,:s] 94 | return tf.convert_to_tensor(trainX[...,newaxis]), tf.convert_to_tensor(trainY[...,newaxis]), tf.convert_to_tensor(testX[...,newaxis]), tf.convert_to_tensor(testY[...,newaxis]) 95 | 96 | class mlp(tf.keras.layers.Layer): 97 | ''' 98 | A two-layer MLP with GELU activation. 
99 | ''' 100 | def __init__(self, n_filters1, n_filters2): 101 | super(mlp, self).__init__() 102 | 103 | self.width1 = n_filters1 104 | self.width2 = n_filters2 105 | self.mlp1 = tf.keras.layers.Dense(self.width1, activation='gelu', kernel_initializer="he_normal") 106 | self.mlp2 = tf.keras.layers.Dense(self.width2, kernel_initializer="he_normal") 107 | 108 | def call(self, inputs): 109 | x = self.mlp1(inputs) 110 | x = self.mlp2(x) 111 | return x 112 | 113 | def get_config(self): 114 | config = { 115 | 'n_filters1': self.width1, 116 | 'n_filters2': self.width2, 117 | } 118 | return config 119 | 120 | class MultiHeadPosAtt(tf.keras.layers.Layer): 121 | ''' 122 | Global, local and cross variants of the multi-head position-attention mechanism. 123 | ''' 124 | def __init__(self, m_dist, n_head, hid_dim, locality): 125 | super(MultiHeadPosAtt, self).__init__() 126 | ''' 127 | m_dist: distance matrix 128 | n_head: number of attention heads 129 | hid_dim: encoding dimension 130 | locality: quantile parameter to customize receptive field in position-attention 131 | ''' 132 | self.dist = m_dist 133 | self.locality = locality 134 | self.hid_dim = hid_dim 135 | self.n_head = n_head 136 | self.v_dim = round(self.hid_dim/self.n_head) 137 | 138 | def build(self, input_shape): 139 | 140 | self.r = self.add_weight( 141 | shape=(self.n_head, 1, 1), 142 | trainable=True, 143 | # constraint=tf.keras.constraints.NonNeg(), 144 | name="band_width", 145 | ) 146 | 147 | self.weight = self.add_weight( 148 | shape=(self.n_head, input_shape[-1], self.v_dim), 149 | initializer="he_normal", 150 | trainable=True, 151 | name="weight", 152 | ) 153 | self.built = True 154 | 155 | def call(self, inputs): 156 | scaled_dist = self.dist * tf.math.tan(0.25*pi*(1-1e-7)*(1.0+tf.math.sin(self.r)))# tan(0.25*pi*(1+sin(r))) leads to higher accuracy than using r^2 and tan(r) # (n_head, L, L) 157 | if self.locality <= 100: 158 | mask = tfp.stats.percentile(scaled_dist, self.locality, interpolation="linear", 
axis=-1, keepdims=True) 159 | scaled_dist = tf.where(scaled_dist<=mask, scaled_dist, tf.float32.max) 160 | else: 161 | pass 162 | scaled_dist = - scaled_dist 163 | att = tf.nn.softmax(scaled_dist, axis=2) #(n_heads, L, L) 164 | 165 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.weight) # (batch_size, n_head, L, v_dim) 166 | 167 | concat = tf.einsum("hnj,bhjd->bhnd", att, value) # (batch_size, n_head, L, v_dim) 168 | concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim) 169 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim) 170 | return tf.keras.activations.gelu(concat) 171 | 172 | def get_config(self): 173 | config = { 174 | 'm_dist': self.dist, 175 | 'hid_dim': self.hid_dim, 176 | 'n_head': self.n_head, 177 | 'locality': self.locality 178 | } 179 | return config 180 | 181 | class PiT(tf.keras.Model): 182 | ''' 183 | Position-induced Transformer, built upon the multi-head position-attention mechanism. 184 | PiT can be trained to decompose and learn the global and local dependencies of operators in partial differential equations.
185 | ''' 186 | def __init__(self, m_cross, m_small, out_dim, hid_dim, n_head, locality_encoder, locality_decoder, YNormalizer): 187 | super(PiT, self).__init__() 188 | ''' 189 | m_cross: distance matrix between X_query and X_latent; (L_qry,L_ltt) 190 | m_small: distance matrix between X_latent and X_latent; (L_ltt,L_ltt) 191 | out_dim: number of variables 192 | hid_dim: encoding dimension (network width) 193 | n_head: number of heads in multi-head attention modules 194 | locality_encoder: quantile parameter of local position-attention in the Encoder, allowing to customize the size of receptive field 195 | locality_decoder: quantile parameter of local position-attention in the Decoder, allowing to customize the size of receptive field 196 | YNormalizer: PixelWiseNormalization object with mean and std of the training data 197 | ''' 198 | self.m_cross = m_cross 199 | self.m_small = m_small 200 | self.res = int(m_cross.shape[0]**0.5) 201 | self.out_dim = out_dim 202 | self.hid_dim = hid_dim 203 | self.n_head = n_head 204 | self.en_local = locality_encoder 205 | self.de_local = locality_decoder 206 | self.n_blocks = 4 207 | self.YNormalizer = YNormalizer 208 | 209 | # Encoder 210 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 211 | self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local) 212 | 213 | # Processor 214 | self.PA = [MultiHeadPosAtt(self.m_small, self.n_head, self.hid_dim, locality=200) for i in range(self.n_blocks)] 215 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 216 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 217 | 218 | # Decoder 219 | self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local) 220 | self.de_layer = mlp(self.hid_dim, self.out_dim) 221 | 222 | def call(self, inputs): 223 | 224 | # Encoder 225 | grid = 
self.get_mesh(inputs) #(batch_size, res_qry, res_qry, 2) 226 | en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) # (batch_size, res_qry, res_qry, input_dim+2) 227 | en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) # (batch_size, L_qry, input_dim+2) 228 | en = self.en_layer(en) # (batch_size, L_qry, hid_dim) 229 | x = self.down(en) # (batch_size, L_ltt, hid_dim) 230 | 231 | # Processor 232 | for i in range(self.n_blocks): 233 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) # (batch_size, L_ltt, hid_dim) 234 | x = tf.keras.activations.gelu(x) # (batch_size, L_ltt, hid_dim) 235 | 236 | # Decoder 237 | de = self.up(x) # (batch_size, L_qry, hid_dim) 238 | de = self.de_layer(de) # (batch_size, L_qry, out_dim) 239 | de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de) # (batch_size, res_qry, res_qry, out_dim) 240 | return self.YNormalizer.denormalize(de) 241 | 242 | def get_mesh(self, inputs): 243 | size_x = size_y = self.res 244 | gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1)) 245 | gridx = tf.tile(gridx, [1,1,size_y,1]) 246 | gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1)) 247 | gridy = tf.tile(gridy, [1,size_x,1,1]) 248 | grid = tf.concat([gridx, gridy], axis=-1) 249 | return tf.tile(grid, [tf.shape(inputs)[0],1,1,1]) 250 | 251 | def get_config(self): 252 | config = { 253 | 'm_cross': self.m_cross, 254 | 'm_small': self.m_small, 255 | 'out_dim': self.out_dim, 256 | 'hid_dim': self.hid_dim, 257 | 'n_head': self.n_head, 258 | 'locality_encoder': self.en_local, 259 | 'locality_decoder': self.de_local, 260 | 'YNormalizer': self.YNormalizer 261 | } 262 | return config 263 | 264 | class MultiHeadSelfAtt(tf.keras.layers.Layer): 265 | ''' 266 | Scaled dot-product multi-head self-attention 267 | ''' 268 | def __init__(self, n_head, hid_dim): 269 | super(MultiHeadSelfAtt, self).__init__() 270 | 271 | self.hid_dim = hid_dim 272 | self.n_head = n_head 273 | self.v_dim = round(self.hid_dim/self.n_head)
274 | 275 | def build(self, input_shape): 276 | 277 | self.q = self.add_weight( 278 | shape=(self.n_head, input_shape[-1], self.v_dim), 279 | initializer="he_normal", 280 | trainable=True, 281 | name="query", 282 | ) 283 | 284 | self.k = self.add_weight( 285 | shape=(self.n_head, input_shape[-1], self.v_dim), 286 | initializer="he_normal", 287 | trainable=True, 288 | name="key", 289 | ) 290 | 291 | self.v = self.add_weight( 292 | shape=(self.n_head, input_shape[-1], self.v_dim), 293 | initializer="he_normal", 294 | trainable=True, 295 | name="value", 296 | ) 297 | self.built = True 298 | 299 | def call(self, inputs): 300 | 301 | query = tf.einsum("bnj,hjk->bhnk", inputs, self.q) # (batch_size, n_head, L, v_dim) 302 | key = tf.einsum("bnj,hjk->bhnk", inputs, self.k) # (batch_size, n_head, L, v_dim) 303 | att = tf.nn.softmax(tf.einsum("...ij,...kj->...ik", query, key)/self.v_dim**0.5, axis=-1) # (batch_size, n_heads, L, L) 304 | 305 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.v) #(batch_size, n_head, L, v_dim) 306 | 307 | concat = tf.einsum("...nj,...jd->...nd", att, value) #(batch_size, n_head, L, v_dim) 308 | concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim) 309 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim) 310 | return tf.keras.activations.gelu(concat) 311 | 312 | def get_config(self): 313 | config = { 314 | 'n_head': self.n_head, 315 | 'hid_dim': self.hid_dim 316 | } 317 | return config 318 | 319 | class LiteTransformer(tf.keras.Model): 320 | ''' 321 | Replace position-attention of the Processor in a PiT with self-attention 322 | ''' 323 | def __init__(self, m_cross, res, out_dim, hid_dim, n_head, en_local, de_local, YNormalizer): 324 | super(LiteTransformer, self).__init__() 325 | 326 | self.m_cross = m_cross 327 | self.res = res 328 | self.out_dim = out_dim 329 | self.hid_dim = hid_dim 330 | self.n_head = n_head 331 | self.en_local = en_local 332 | self.de_local = de_local 333 | 
self.YNormalizer = YNormalizer 334 | self.n_blocks = 4 335 | 336 | # Encoder 337 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 338 | self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local) 339 | 340 | # Processor 341 | self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)] 342 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 343 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 344 | 345 | # Decoder 346 | self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local) 347 | self.de_layer = mlp(self.hid_dim, self.out_dim) 348 | 349 | def call(self, inputs): 350 | 351 | # Encoder 352 | grid = self.get_mesh(inputs) 353 | en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) 354 | en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) 355 | en = self.en_layer(en) 356 | x = self.down(en) 357 | 358 | # Processor 359 | for i in range(self.n_blocks): 360 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 361 | x = tf.keras.activations.gelu(x) 362 | 363 | # Decoder 364 | de = self.up(x) 365 | de = self.de_layer(de) 366 | de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de) 367 | return self.YNormalizer.denormalize(de) 368 | 369 | def get_mesh(self, inputs): 370 | size_x = size_y = self.res 371 | gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1)) 372 | gridx = tf.tile(gridx, [1,1,size_y,1]) 373 | gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1)) 374 | gridy = tf.tile(gridy, [1,size_x,1,1]) 375 | grid = tf.concat([gridx, gridy], axis=-1) 376 | return tf.tile(grid, [tf.shape(inputs)[0],1,1,1]) 377 | 378 | def get_config(self): 379 | config = { 380 | 'm_cross': self.m_cross, 381 | 'res': self.res, 382 | 'out_dim': self.out_dim, 383 | 'hid_dim': self.hid_dim, 384 
| 'n_head': self.n_head, 385 | 'locality_encoder': self.en_local, 386 | 'locality_decoder': self.de_local, 387 | 'YNormalizer': self.YNormalizer 388 | } 389 | return config 390 | 391 | class Transformer(tf.keras.Model): 392 | ''' 393 | Replace position-attention of a PiT with self-attention. 394 | ''' 395 | def __init__(self, res, out_dim, hid_dim, n_head, YNormalizer): 396 | super(Transformer, self).__init__() 397 | 398 | self.res = res 399 | self.out_dim = out_dim 400 | self.hid_dim = hid_dim 401 | self.n_head = n_head 402 | self.n_blocks = 4 403 | self.YNormalizer = YNormalizer 404 | 405 | # Encoder 406 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 407 | self.down = MultiHeadSelfAtt(self.n_head, self.hid_dim) 408 | 409 | # Processor 410 | self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)] 411 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 412 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 413 | 414 | # Decoder 415 | self.up = MultiHeadSelfAtt(self.n_head, self.hid_dim) 416 | self.de_layer = mlp(self.hid_dim, self.out_dim) 417 | 418 | def call(self, inputs): 419 | 420 | # Encoder 421 | grid = self.get_mesh(inputs) 422 | en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) 423 | en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) 424 | en = self.en_layer(en) 425 | x = self.down(en) 426 | 427 | # Processor 428 | for i in range(self.n_blocks): 429 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 430 | x = tf.keras.activations.gelu(x) 431 | 432 | # Decoder 433 | de = self.up(x) 434 | de = self.de_layer(de) 435 | de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de) 436 | return self.YNormalizer.denormalize(de) 437 | 438 | def get_mesh(self, inputs): 439 | size_x = size_y = self.res 440 | gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], 
(1,-1,1,1)) 441 | gridx = tf.tile(gridx, [1,1,size_y,1]) 442 | gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1)) 443 | gridy = tf.tile(gridy, [1,size_x,1,1]) 444 | grid = tf.concat([gridx, gridy], axis=-1) 445 | return tf.tile(grid, [tf.shape(inputs)[0],1,1,1]) 446 | 447 | def get_config(self): 448 | config = { 449 | 'res':self.res, 450 | 'out_dim': self.out_dim, 451 | 'hid_dim': self.hid_dim, 452 | 'n_head': self.n_head, 453 | 'YNormalizer': self.YNormalizer 454 | } 455 | return config 456 | 457 | 458 | -------------------------------------------------------------------------------- /tensorflow/4_Vorticity/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 4 | os.environ['PYTHONHASHSEED'] = '1' 5 | from scipy.io import loadmat 6 | from utils import *#load_data, pairwise_dist, PixelWiseNormalization 7 | from numpy import zeros, mean, savetxt 8 | from numpy.linalg import norm 9 | import matplotlib.pyplot as plt 10 | 11 | # params ##################################33 12 | 13 | # load dataset 14 | data = loadmat('./pred.mat') 15 | testY = data['true'] 16 | pred = data['pred'] 17 | rel_err = rel_norm_traj(testY,pred) 18 | print(rel_err.numpy()) 19 | err = testY - pred 20 | T = 20 21 | rel_err = norm(err.reshape(200,-1,T), ord=2, axis=1) / norm(testY.reshape(200,-1,T), ord=2, axis=1) 22 | rel_err = mean(rel_err, axis=0) 23 | plt.figure(figsize=(16,9),dpi=100) 24 | plt.plot(tf.range(11,31), rel_err) 25 | plt.xlabel('t') 26 | plt.savefig('./rel_err.pdf') 27 | savetxt('./rel_err.csv', rel_err) 28 | ############## plots 29 | index = 0 30 | abs_err = abs(err[index,...]) 31 | emax = abs_err.max() 32 | emin = abs_err.min() 33 | print(emax, emin) 34 | omega = testY[index,...] 35 | vmax = omega.max() 36 | vmin = omega.min() 37 | print(vmax, vmin) 38 | omega_p = pred[index,...] 
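The per-step curve computed in lines 19-22 of this `evaluate.py` is the relative L2 error of each predicted time step, taken over the spatial grid and then averaged over the test samples (whereas `rel_norm_traj` flattens space and time together). A minimal NumPy sketch of that computation follows, assuming arrays shaped `(n_samples, res_x, res_y, T)` as stored in `pred.mat`; the helper `stepwise_rel_l2` is hypothetical and not part of the repository.

```python
import numpy as np

def stepwise_rel_l2(true, pred):
    """Relative L2 error of each time step, averaged over samples.

    true, pred: arrays of shape (n_samples, res_x, res_y, T).
    Returns an array of shape (T,).
    """
    n, t = true.shape[0], true.shape[-1]
    err = (true - pred).reshape(n, -1, t)  # flatten the spatial grid, keep time last
    ref = true.reshape(n, -1, t)
    rel = np.linalg.norm(err, axis=1) / np.linalg.norm(ref, axis=1)  # (n_samples, T)
    return rel.mean(axis=0)

# Usage: a prediction that overshoots every value by 10%
rng = np.random.default_rng(0)
true = rng.normal(size=(4, 8, 8, 20))
print(stepwise_rel_l2(true, 1.1 * true))  # twenty values, each approximately 0.1
```

Unlike the hard-coded `reshape(200,-1,T)` in the script, the sketch reads the sample count and the number of steps from the array shape.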
39 | 40 | for i in range(T): 41 | # plot the contours 42 | plt.figure(figsize=(4,4),dpi=300) 43 | plt.axes([0,0,1,1]) 44 | plt.imshow(omega[...,i], vmax=vmax, vmin=vmin, interpolation='spline16', cmap='jet') 45 | plt.axis('off') 46 | plt.axis('equal') 47 | plt.savefig('./testY_{}.pdf'.format(i+1)) 48 | plt.close() 49 | 50 | plt.figure(figsize=(4,4),dpi=300) 51 | plt.axes([0,0,1,1]) 52 | plt.imshow(omega_p[...,i], vmax=vmax, vmin=vmin, interpolation='spline16', cmap='jet') 53 | plt.axis('off') 54 | plt.axis('equal') 55 | plt.savefig('./pred_{}.pdf'.format(i+1)) 56 | plt.close() 57 | 58 | plt.figure(figsize=(4,4),dpi=300) 59 | plt.axes([0,0,1,1]) 60 | plt.imshow(abs_err[...,i], vmax=emax, vmin=emin, interpolation='spline16', cmap='jet') 61 | plt.axis('off') 62 | plt.axis('equal') 63 | plt.savefig('./err_{}.pdf'.format(i+1)) 64 | plt.close() 65 | 66 | 67 | -------------------------------------------------------------------------------- /tensorflow/4_Vorticity/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 4 | os.environ['PYTHONHASHSEED'] = '1' 5 | 6 | from numpy import savetxt 7 | import matplotlib.pyplot as plt 8 | from time import time 9 | 10 | ### custom imports 11 | from utils import * 12 | 13 | # params ##################################33 14 | n_epochs = 500 15 | lr = 0.001 16 | batch_size = 8 17 | encode_dim = 256 18 | out_dim = 1 19 | n_head = 1 20 | n_train = 1000 21 | n_test = 200 22 | steps = 20 23 | en_loc = 1 24 | de_loc = 8 25 | # load dataset 26 | trainX, trainY, testX, testY= load_data("./NavierStokes_V1e-4_N1200_T30.mat", n_train, n_test) 27 | print(trainX.shape, trainY.shape, testX.shape, testY.shape) 28 | 29 | # define a model 30 | m_cross = pairwise_dist(64, 64, 16, 16) # pairwise distance matrix for encoder and decoder 31 | m_latent = pairwise_dist(16, 16, 16, 16) # pairwise distance matrix for 
processor 32 | network = PiT(m_cross, m_latent, out_dim, encode_dim, n_head, en_loc, de_loc) 33 | r_network = reccurent_PiT(network, steps) 34 | inputs = tf.keras.Input(shape=(64,64,trainX.shape[-1])) 35 | outputs = r_network(inputs) 36 | model = tf.keras.Model(inputs, outputs) 37 | r_network.summary() 38 | ### compile model 39 | model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=tf.keras.optimizers.schedules.CosineDecay(lr, n_epochs*(n_train//batch_size))), 40 | loss=rel_norm_step(steps), 41 | jit_compile=True) 42 | 43 | ### fit model 44 | start = time() 45 | train_history = model.fit(trainX, trainY, 46 | batch_size, n_epochs, verbose=1, 47 | validation_data=(testX, testY), 48 | validation_batch_size=20) 49 | end = time() 50 | print(' ') 51 | print('Training cost is {} seconds per epoch !'.format((end-start)/n_epochs)) 52 | 53 | print(' ') 54 | ### save model 55 | model.save_weights('./checkpoints/my_checkpoint') 56 | ### training history 57 | loss_hist = train_history.history 58 | train_loss = tf.reshape(tf.convert_to_tensor(loss_hist["loss"]), (-1,1)) 59 | test_loss = tf.reshape(tf.convert_to_tensor(loss_hist["val_loss"]), (-1,1)) 60 | savetxt('./training_history.csv', tf.concat([train_loss, test_loss], axis=-1).numpy(), delimiter=',', header='train,test', fmt='%1.16f', comments='') 61 | 62 | plt.plot(train_loss, label='train') 63 | plt.plot(test_loss, label='test') 64 | plt.legend() 65 | plt.yscale('log', base=10) 66 | plt.savefig('training_history.png') 67 | plt.close() 68 | 69 | ### evaluation, visualization 70 | from scipy.io import savemat 71 | pred = model.predict(testX) 72 | print(rel_norm_traj(testY,pred)) 73 | savemat("pred.mat", mdict={"pred":pred, "true":testY}) 74 | -------------------------------------------------------------------------------- /tensorflow/4_Vorticity/utils.py: -------------------------------------------------------------------------------- 1 | from scipy.io import loadmat 2 | import tensorflow as tf 3 | 
tf.keras.utils.set_random_seed(0) 4 | tf.config.experimental.enable_op_determinism() 5 | tf.random.set_seed(0) 6 | physical_devices = tf.config.list_physical_devices('GPU') 7 | tf.config.set_visible_devices(physical_devices[0:],'GPU') 8 | import tensorflow_probability as tfp 9 | from math import pi 10 | 11 | class rel_norm_step(tf.keras.losses.Loss): 12 | ''' 13 | Compute the average relative l2 loss between a batch of targets and predictions, step wise 14 | ''' 15 | def __init__(self, steps): 16 | super(rel_norm_step, self).__init__() 17 | self.steps = steps 18 | 19 | def call(self, true, pred): 20 | rel_error = tf.math.divide(tf.norm(tf.keras.layers.Reshape((-1, self.steps))(true-pred), axis=1), tf.norm(tf.keras.layers.Reshape((-1, self.steps))(true), axis=1)) # step-wise relative l2 loss# 21 | return tf.math.reduce_mean(rel_error) 22 | 23 | def get_config(self): 24 | config = { 25 | 'steps': self.steps 26 | } 27 | return config 28 | 29 | def rel_norm_traj(true, pred): 30 | ''' 31 | Compute the average relative l2 loss between a batch of targets and predictions, full trajectory 32 | ''' 33 | rel_error = tf.math.divide(tf.norm(tf.keras.layers.Reshape((-1,))(true-pred), axis=1), tf.norm(tf.keras.layers.Reshape((-1,))(true), axis=1)) 34 | return tf.math.reduce_mean(rel_error) 35 | 36 | 37 | def pairwise_dist(res1x, res1y, res2x, res2y): 38 | 39 | gridx = tf.reshape(tf.linspace(0, 1, res1x+1)[:-1], (1,-1,1)) 40 | gridx = tf.tile(gridx, [res1y,1,1]) 41 | gridy = tf.reshape(tf.linspace(0, 1, res1y+1)[:-1], (-1,1,1)) 42 | gridy = tf.tile(gridy, [1,res1x,1]) 43 | grid1 = tf.concat([gridx, gridy], axis=-1) 44 | grid1 = tf.reshape(grid1, (res1x*res1y,2)) #(res1*res1,2) 45 | 46 | gridx = tf.reshape(tf.linspace(0, 1, res2x+1)[:-1], (1,-1,1)) 47 | gridx = tf.tile(gridx, [res2y,1,1]) 48 | gridy = tf.reshape(tf.linspace(0, 1, res2y+1)[:-1], (-1,1,1)) 49 | gridy = tf.tile(gridy, [1,res2x,1]) 50 | grid2 = tf.concat([gridx, gridy], axis=-1) 51 | grid2 = tf.reshape(grid2, 
(res2x*res2y,2)) #(res2*res2,2) 52 | 53 | print(grid1.shape, grid2.shape) 54 | grid1 = tf.tile(tf.expand_dims(grid1, 1), [1,grid2.shape[0],1]) 55 | grid2 = tf.tile(tf.expand_dims(grid2, 0), [grid1.shape[0],1,1]) 56 | 57 | dist = tf.math.minimum(tf.norm(grid1-grid2, axis=-1), tf.norm(grid1+tf.constant([[1,0]], dtype='float64')-grid2, axis=-1)) 58 | dist = tf.math.minimum(dist, tf.norm(grid1+tf.constant([[-1,0]], dtype='float64')-grid2, axis=-1)) 59 | dist = tf.math.minimum(dist, tf.norm(grid1+tf.constant([[0,1]], dtype='float64')-grid2, axis=-1)) 60 | dist = tf.math.minimum(dist, tf.norm(grid1+tf.constant([[0,-1]], dtype='float64')-grid2, axis=-1)) 61 | dist2 = tf.cast(dist**2, 'float32') 62 | return dist2 63 | 64 | def load_data(file_path, ntrain, ntest): 65 | 66 | try: 67 | data = loadmat(file_path) 68 | except: 69 | import mat73 70 | data = mat73.loadmat(file_path) 71 | flow = data['u'].astype('float32') 72 | print(flow.shape) 73 | del data 74 | 75 | trainX = flow[:ntrain,:,:,:10] # ntrain 1000 76 | trainY = flow[:ntrain,:,:,10:30] 77 | testX = flow[-ntest:,:,:,:10] 78 | testY = flow[-ntest:,:,:,10:30] 79 | 80 | del flow 81 | return trainX, trainY, testX, testY 82 | 83 | 84 | class mlp(tf.keras.layers.Layer): 85 | ''' 86 | A two-layer MLP with GELU activation. 
87 | ''' 88 | def __init__(self, n_filters1, n_filters2): 89 | super(mlp, self).__init__() 90 | 91 | self.width1 = n_filters1 92 | self.width2 = n_filters2 93 | self.mlp1 = tf.keras.layers.Dense(self.width1, activation='gelu', kernel_initializer="he_normal") 94 | self.mlp2 = tf.keras.layers.Dense(self.width2, kernel_initializer="he_normal") 95 | 96 | def call(self, inputs): 97 | x = self.mlp1(inputs) 98 | x = self.mlp2(x) 99 | return x 100 | 101 | def get_config(self): 102 | config = { 103 | 'n_filters1': self.width1, 104 | 'n_filters2': self.width2 105 | } 106 | return config 107 | 108 | class reccurent_PiT(tf.keras.Model): 109 | def __init__(self, network, steps): 110 | super(reccurent_PiT, self).__init__() 111 | self.PiT = network 112 | self.steps = steps 113 | 114 | def call(self, inputs): 115 | x = tf.identity(inputs) 116 | pred = x[...,-1:] 117 | for t in range(self.steps): 118 | y = self.PiT(x) 119 | pred = tf.concat([pred,y], axis=-1) 120 | x = tf.concat([x[...,1:],y], axis=-1) 121 | return pred[...,1:] 122 | def get_config(self): 123 | config = super(reccurent_PiT, self).get_config() 124 | config.update({ 125 | 'network': self.PiT.get_config(), # Store the config of the network instead of the network itself 126 | 'steps': self.steps, 127 | }) 128 | return config 129 | 130 | @classmethod 131 | def from_config(cls, config): 132 | network_config = config.pop('network') 133 | network = tf.keras.Model.from_config(network_config) 134 | return cls(network=network, **config) 135 | 136 | class MultiHeadPosAtt(tf.keras.layers.Layer): 137 | ''' 138 | Global, local and cross variants of the multi-head position-attention mechanism. 
139 | ''' 140 | def __init__(self, m_dist, n_head, hid_dim, locality): 141 | super(MultiHeadPosAtt, self).__init__() 142 | ''' 143 | m_dist: distance matrix 144 | n_head: number of attention heads 145 | hid_dim: encoding dimension 146 | locality: quantile parameter to customize receptive field in position-attention 147 | ''' 148 | self.dist = m_dist 149 | self.locality = locality 150 | self.hid_dim = hid_dim 151 | self.n_head = n_head 152 | self.v_dim = round(self.hid_dim/self.n_head) 153 | 154 | def build(self, input_shape): 155 | 156 | self.r = self.add_weight( 157 | shape=(self.n_head, 1, 1), 158 | trainable=True, 159 | name="band_width", 160 | ) 161 | 162 | self.weight = self.add_weight( 163 | shape=(self.n_head, input_shape[-1], self.v_dim), 164 | initializer="he_normal", 165 | trainable=True, 166 | name="weight", 167 | ) 168 | self.built = True 169 | 170 | def call(self, inputs): 171 | scaled_dist = self.dist * tf.math.tan(0.25*pi*(1-1e-7)*(1.0+tf.math.sin(self.r)))# tan(0.25*pi*(1+sin(r))) leads to higher accuracy than using r^2 and tan(r) # (n_head, L, L) 172 | if self.locality <= 100: 173 | mask = tfp.stats.percentile(scaled_dist, self.locality, interpolation="linear", axis=-1, keepdims=True) 174 | scaled_dist = tf.where(scaled_dist<=mask, scaled_dist, tf.float32.max) 175 | else: 176 | pass 177 | scaled_dist = - scaled_dist # (n_head, L, L) 178 | att = tf.nn.softmax(scaled_dist, axis=2) # (n_heads, L, L) 179 | 180 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.weight) # (batch_size, n_head, L, v_dim) 181 | 182 | concat = tf.einsum("hnj,bhjd->bhnd", att, value) # (batch_size, n_head, L, v_dim) 183 | concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim) 184 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim) 185 | return tf.keras.activations.gelu(concat) 186 | 187 | def get_config(self): 188 | config = { 189 | 'm_dist': self.dist, 190 | 'hid_dim': self.hid_dim, 191 | 'n_head': self.n_head, 192 | 
'locality': self.locality 193 | } 194 | return config 195 | 196 | class PiT(tf.keras.layers.Layer): 197 | ''' 198 | Position-induced Transformer, built upon the multi-head position-attention mechanism. 199 | PiT can be trained to decompose and learn the global and local dependencies of operators in partial differential equations. 200 | ''' 201 | def __init__(self, m_cross, m_ltt, out_dim, hid_dim, n_head, locality_encoder, locality_decoder): 202 | super(PiT, self).__init__() 203 | ''' 204 | m_cross: distance matrix between X_query and X_latent; (L_qry,L_ltt) 205 | m_ltt: distance matrix between X_latent and X_latent; (L_ltt,L_ltt) 206 | out_dim: number of variables 207 | hid_dim: encoding dimension (network width) 208 | n_head: number of heads in multi-head attention modules 209 | locality_encoder: quantile parameter of local position-attention in the Encoder, allowing one to customize the size of the receptive field 210 | locality_decoder: quantile parameter of local position-attention in the Decoder, allowing one to customize the size of the receptive field 211 | ''' 212 | self.m_cross = m_cross 213 | self.m_ltt = m_ltt 214 | self.res = int(m_cross.shape[0]**0.5) 215 | self.out_dim = out_dim 216 | self.hid_dim = hid_dim 217 | self.n_head = n_head 218 | self.en_local = locality_encoder 219 | self.de_local = locality_decoder 220 | self.n_blocks = 4 # number of position-attention modules in the Processor 221 | 222 | # Encoder 223 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 224 | self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local) 225 | 226 | # Processor 227 | self.MHPA = [MultiHeadPosAtt(self.m_ltt, self.n_head, self.hid_dim, locality=200) for i in range(self.n_blocks)] 228 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 229 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 230 |
231 | # Decoder 232 | self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local) 233 | self.de_layer = mlp(self.hid_dim, self.out_dim) 234 | 235 | def call(self, inputs): 236 | 237 | # Encoder 238 | grid = self.get_mesh(inputs) #(batch_size, res_qry, res_qry, 2) 239 | en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) # (batch_size, res_qry, res_qry, input_dim+2) 240 | en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) # (batch_size, L_qry, input_dim+2) 241 | en = self.en_layer(en) # (batch_size, L_qry, hid_dim) 242 | x = self.down(en) # (batch_size, L_ltt, hid_dim) 243 | 244 | # Processor 245 | for i in range(self.n_blocks): 246 | x = self.MLP[i](self.MHPA[i](x)) + self.W[i](x) # (batch_size, L_ltt, hid_dim) 247 | x = tf.keras.activations.gelu(x) # (batch_size, L_ltt, hid_dim) 248 | 249 | # Decoder 250 | de = self.up(x) # (batch_size, L_qry, hid_dim) 251 | de = self.de_layer(de) # (batch_size, L_qry, out_dim) 252 | de = tf.keras.layers.Reshape((self.res, self.res, self.out_dim))(de) # (batch_size, res_qry, res_qry, out_dim) 253 | return de 254 | 255 | def get_mesh(self, inputs): 256 | size_x = size_y = self.res 257 | gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1)) 258 | gridx = tf.tile(gridx, [1,1,size_y,1]) 259 | gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1)) 260 | gridy = tf.tile(gridy, [1,size_x,1,1]) 261 | grid = tf.concat([gridx, gridy], axis=-1) 262 | return tf.tile(grid, [tf.shape(inputs)[0],1,1,1]) 263 | 264 | def get_config(self): 265 | config = { 266 | 'm_cross': self.m_cross, 267 | 'm_ltt': self.m_ltt, 268 | 'out_dim': self.out_dim, 269 | 'hid_dim': self.hid_dim, 270 | 'n_head': self.n_head, 271 | 'locality_encoder': self.en_local, 272 | 'locality_decoder': self.de_local 273 | } 274 | return config 275 | 276 | class MultiHeadSelfAtt(tf.keras.layers.Layer): 277 | ''' 278 | Scaled dot-product multi-head self-attention 279 | ''' 280 | def __init__(self, n_head, hid_dim): 281 |
super(MultiHeadSelfAtt, self).__init__() 282 | 283 | self.hid_dim = hid_dim 284 | self.n_head = n_head 285 | self.v_dim = round(self.hid_dim/self.n_head) 286 | 287 | def build(self, input_shape): 288 | 289 | self.q = self.add_weight( 290 | shape=(self.n_head, input_shape[-1], self.v_dim), 291 | initializer="he_normal", 292 | trainable=True, 293 | name="query", 294 | ) 295 | 296 | self.k = self.add_weight( 297 | shape=(self.n_head, input_shape[-1], self.v_dim), 298 | initializer="he_normal", 299 | trainable=True, 300 | name="key", 301 | ) 302 | 303 | self.v = self.add_weight( 304 | shape=(self.n_head, input_shape[-1], self.v_dim), 305 | initializer="he_normal", 306 | trainable=True, 307 | name="value", 308 | ) 309 | self.built = True 310 | 311 | def call(self, inputs): 312 | 313 | query = tf.einsum("bnj,hjk->bhnk", inputs, self.q) # (batch_size, n_head, L, v_dim) 314 | key = tf.einsum("bnj,hjk->bhnk", inputs, self.k) # (batch_size, n_head, L, v_dim) 315 | att = tf.nn.softmax(tf.einsum("...ij,...kj->...ik", query, key)/self.v_dim**0.5, axis=-1) # (batch_size, n_heads, L, L) 316 | 317 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.v) #(batch_size, n_head, L, v_dim) 318 | 319 | concat = tf.einsum("...nj,...jd->...nd", att, value) #(batch_size, n_head, L, v_dim) 320 | concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim) 321 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim) 322 | return tf.keras.activations.gelu(concat) 323 | 324 | def get_config(self): 325 | config = { 326 | 'n_head': self.n_head, 327 | 'hid_dim': self.hid_dim 328 | } 329 | return config 330 | 331 | class LiteTransformer(tf.keras.Model): 332 | ''' 333 | Replace position-attention of the Processor in a PiT with self-attention 334 | ''' 335 | def __init__(self, m_cross, res, out_dim, hid_dim, n_head, en_local, de_local): 336 | super(LiteTransformer, self).__init__() 337 | 338 | self.m_cross = m_cross 339 | self.res = res 340 | self.out_dim = 
out_dim 341 | self.hid_dim = hid_dim 342 | self.n_head = n_head 343 | self.en_local = en_local 344 | self.de_local = de_local 345 | self.n_blocks = 4 346 | 347 | # Encoder 348 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 349 | self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local) 350 | 351 | # Processor 352 | self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)] 353 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 354 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 355 | 356 | # Decoder 357 | self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local) 358 | self.de_layer = mlp(self.hid_dim, self.out_dim) 359 | 360 | def call(self, inputs): 361 | 362 | # Encoder 363 | grid = self.get_mesh(inputs) #(b, s1, s2, 2) 364 | en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) 365 | en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) 366 | en = self.en_layer(en) 367 | x = self.down(en) 368 | 369 | # Processor 370 | for i in range(self.n_blocks): 371 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 372 | x = tf.keras.activations.gelu(x) 373 | 374 | # Decoder 375 | de = self.up(x) 376 | de = self.de_layer(de) 377 | de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de) 378 | return de 379 | 380 | def get_mesh(self, inputs): 381 | size_x = size_y = self.res 382 | gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1)) 383 | gridx = tf.tile(gridx, [1,1,size_y,1]) 384 | gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1)) 385 | gridy = tf.tile(gridy, [1,size_x,1,1]) 386 | grid = tf.concat([gridx, gridy], axis=-1) 387 | return tf.tile(grid, [tf.shape(inputs)[0],1,1,1]) 388 | 389 | def get_config(self): 390 | config = { 391 | 'm_cross': self.m_cross, 392 | 
'res': self.res, 393 | 'out_dim': self.out_dim, 394 | 'hid_dim': self.hid_dim, 395 | 'n_head': self.n_head, 396 | 'locality_encoder': self.en_local, 397 | 'locality_decoder': self.de_local, 398 | } 399 | return config 400 | 401 | class Transformer(tf.keras.Model): 402 | ''' 403 | Replace position-attention of a PiT with self-attention. 404 | ''' 405 | def __init__(self, res, out_dim, hid_dim, n_head): 406 | super(Transformer, self).__init__() 407 | 408 | self.res = res 409 | self.out_dim = out_dim 410 | self.hid_dim = hid_dim 411 | self.n_head = n_head 412 | self.n_blocks = 4 413 | 414 | # Encoder 415 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 416 | self.down = MultiHeadSelfAtt(self.n_head, self.hid_dim) 417 | 418 | # Processor 419 | self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)] 420 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 421 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 422 | 423 | # Decoder 424 | self.up = MultiHeadSelfAtt(self.n_head, self.hid_dim) 425 | self.de_layer = mlp(self.hid_dim, self.out_dim) 426 | 427 | def call(self, inputs): 428 | 429 | # Encoder 430 | grid = self.get_mesh(inputs) #(b, s1, s2, 2) 431 | en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) 432 | en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) 433 | en = self.en_layer(en) 434 | x = self.down(en) 435 | 436 | # Processor 437 | for i in range(self.n_blocks): 438 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 439 | x = tf.keras.activations.gelu(x) 440 | 441 | # Decoder 442 | de = self.up(x) 443 | de = self.de_layer(de) 444 | de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de) 445 | return de 446 | 447 | def get_mesh(self, inputs): 448 | size_x = size_y = self.res 449 | gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1)) 450 | 
gridx = tf.tile(gridx, [1,1,size_y,1]) 451 | gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1)) 452 | gridy = tf.tile(gridy, [1,size_x,1,1]) 453 | grid = tf.concat([gridx, gridy], axis=-1) 454 | return tf.tile(grid, [tf.shape(inputs)[0],1,1,1]) 455 | 456 | def get_config(self): 457 | config = { 458 | 'res':self.res, 459 | 'out_dim': self.out_dim, 460 | 'hid_dim': self.hid_dim, 461 | 'n_head': self.n_head 462 | } 463 | return config 464 | 465 | 466 | 467 | -------------------------------------------------------------------------------- /tensorflow/5_Elasticity/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 4 | os.environ['PYTHONHASHSEED'] = '1' 5 | import matplotlib.pyplot as plt 6 | from time import time 7 | from utils import * 8 | 9 | # params ##################################33 10 | encode_dim = 512 11 | out_dim = 1 12 | n_head = 8 13 | n_train = 1000 14 | n_test = 200 15 | 16 | # load dataset 17 | trainX, trainY, testX, testY= load_data("./", n_train, n_test) 18 | print(trainX.shape, trainY.shape, testX.shape, testY.shape) 19 | 20 | # define a model 21 | network = PiT(out_dim, encode_dim, n_head, 2, 2) 22 | inputs = tf.keras.Input(shape=trainX.shape[1:]) 23 | outputs = network(inputs) 24 | model = tf.keras.Model(inputs, outputs) 25 | network.summary() 26 | 27 | 28 | ### load model 29 | model.load_weights('./512/checkpoints/my_checkpoint').expect_partial() 30 | 31 | ### evaluation, visualization 32 | index = 89 33 | pred = model.predict(testX) 34 | print(rel_norm()(testY, pred)) 35 | pred = pred[index:index+1,...] 36 | true = testY[index:index+1,...] 
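`MultiHeadPosAtt` in the Vorticity `utils.py` above builds its attention weights from positions alone: the squared distances are multiplied by a learnable per-head scale, optionally restricted in each row to the entries at or below the `locality` percentile, negated, and passed through a softmax. The single-head NumPy sketch below illustrates that computation under two assumptions: a fixed scalar `scale` stands in for the learned `tan(0.25*pi*(1+sin(r)))` bandwidth factor, and masked entries are set to `np.inf` rather than `tf.float32.max` (both drive the corresponding weights to zero). The name `position_attention` is hypothetical.

```python
import numpy as np

def position_attention(sq_dist, values, scale=1.0, locality=200):
    """Single-head position-attention (illustrative sketch).

    sq_dist:  (L_out, L_in) squared distances between output and input positions.
    values:   (L_in, d) input features, already multiplied by the value weights.
    scale:    positive bandwidth factor (learned per head in MultiHeadPosAtt).
    locality: percentile in [0, 100]; each row keeps only inputs at or below it.
              Values above 100 disable the mask (global attention).
    """
    scaled = scale * sq_dist
    if locality <= 100:
        thr = np.percentile(scaled, locality, axis=-1, keepdims=True)
        scaled = np.where(scaled <= thr, scaled, np.inf)  # masked weights become 0
    logits = -scaled
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerically stable softmax
    att = np.exp(logits)
    att = att / att.sum(axis=-1, keepdims=True)            # each row sums to 1
    return att @ values                                     # (L_out, d)

# Usage on five points of [0, 1] whose features are their own coordinates
x = np.linspace(0.0, 1.0, 5)
sq = (x[:, None] - x[None, :]) ** 2
print(position_attention(sq, x[:, None], locality=0).ravel())  # each point keeps its own coordinate
```

With `locality=0` only the zero self-distance survives the mask in each row, so every output position copies its own input; with `scale=0` and no mask the weights are uniform. The Processor blocks in the Vorticity `PiT` use `locality=200`, i.e. the unmasked, global branch.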
37 | 38 | err = abs(true-pred) 39 | emax = err.max() 40 | emin = err.min() 41 | vmax = true.max() 42 | vmin = true.min() 43 | print(vmax, vmin, emax, emin) 44 | 45 | plt.figure(figsize=(6,6),dpi=300) 46 | plt.axes([0,0,1,1]) 47 | x = testX[index,:,:1] 48 | y = testX[index,:,1:2] 49 | plt.scatter(x, y, c=pred, cmap="jet", s=160, vmin=vmin, vmax=vmax) 50 | plt.axis("off") 51 | plt.axis("equal") 52 | plt.savefig("pred.pdf") 53 | plt.close() 54 | 55 | plt.figure(figsize=(6,6),dpi=300) 56 | plt.axes([0,0,1,1]) 57 | plt.scatter(x, y, c=true, cmap="jet", s=160, vmin=vmin, vmax=vmax) 58 | plt.axis("off") 59 | plt.axis("equal") 60 | plt.savefig("true.pdf") 61 | plt.close() 62 | 63 | plt.figure(figsize=(6,6),dpi=300) 64 | plt.axes([0,0,1,1]) 65 | plt.scatter(x, y, c=err, cmap="jet", s=160, vmin=emin, vmax=emax) 66 | plt.axis("off") 67 | plt.axis("equal") 68 | plt.savefig("error.pdf") 69 | plt.close() 70 | -------------------------------------------------------------------------------- /tensorflow/5_Elasticity/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 4 | os.environ['PYTHONHASHSEED'] = '1' 5 | os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' 6 | import matplotlib.pyplot as plt 7 | from time import time 8 | from utils import * 9 | 10 | # params ##################################33 11 | n_epochs = 500 12 | lr = 0.001 13 | batch_size = 10 14 | encode_dim = 512 15 | out_dim = 1 16 | n_head = 8 17 | n_train = 1000 18 | n_test = 200 19 | 20 | # load dataset 21 | trainX, trainY, testX, testY= load_data("./", n_train, n_test) 22 | print(trainX.shape, trainY.shape, testX.shape, testY.shape) 23 | 24 | # define a model 25 | network = PiT(out_dim, encode_dim, n_head, 2, 2) 26 | inputs = tf.keras.Input(shape=trainX.shape[1:]) 27 | outputs = network(inputs) 28 | model = tf.keras.Model(inputs, outputs) 29 | network.summary() 30 | 31 | ### compile 
model 32 | model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=tf.keras.optimizers.schedules.CosineDecay(lr, n_epochs*(n_train//batch_size))), 33 | loss=rel_norm(), 34 | jit_compile=True) 35 | 36 | ### fit model 37 | start = time() 38 | train_history = model.fit(trainX, trainY, 39 | batch_size, n_epochs, verbose=1, 40 | validation_data=(testX, testY), 41 | validation_batch_size=20) 42 | end = time() 43 | print(' ') 44 | print('Training cost is {} seconds per epoch !'.format((end-start)/n_epochs)) 45 | print(' ') 46 | 47 | ### save model 48 | model.save_weights('./checkpoints/my_checkpoint') 49 | 50 | ### training history 51 | loss_hist = train_history.history 52 | train_loss = tf.reshape(tf.convert_to_tensor(loss_hist["loss"]), (-1,1)) 53 | test_loss = tf.reshape(tf.convert_to_tensor(loss_hist["val_loss"]), (-1,1)) 54 | np.savetxt('./training_history.csv', tf.concat([train_loss, test_loss], axis=-1), delimiter=',', header='train,test', fmt='%1.16f', comments='') 55 | 56 | plt.plot(train_loss, label='train') 57 | plt.plot(test_loss, label='test') 58 | plt.legend() 59 | plt.yscale('log', base=10) 60 | plt.savefig('training_history.png') 61 | plt.close() 62 | 63 | ### evaluation, visualization 64 | index = 2 65 | pred = model.predict(testX) 66 | print(rel_norm()(testY, pred)) 67 | pred = pred[index:index+1,...] 68 | true = testY[index:index+1,...] 
69 | 70 | err = abs(true-pred) 71 | emax = err.max() 72 | emin = err.min() 73 | vmax = true.max() 74 | vmin = true.min() 75 | print(vmax, vmin, emax, emin) 76 | 77 | plt.figure(figsize=(6,6),dpi=300) 78 | plt.axes([0,0,1,1]) 79 | x = testX[index,:,:1] 80 | y = testX[index,:,1:2] 81 | plt.scatter(x, y, c=pred, cmap="jet", s=160, vmin=vmin, vmax=vmax) 82 | plt.axis("off") 83 | plt.axis("equal") 84 | plt.savefig("pred.pdf") 85 | plt.close() 86 | 87 | plt.figure(figsize=(6,6),dpi=300) 88 | plt.axes([0,0,1,1]) 89 | plt.scatter(x, y, c=true, cmap="jet", s=160, vmin=vmin, vmax=vmax) 90 | plt.axis("off") 91 | plt.axis("equal") 92 | plt.savefig("true.pdf") 93 | plt.close() 94 | 95 | plt.figure(figsize=(6,6),dpi=300) 96 | plt.axes([0,0,1,1]) 97 | plt.scatter(x, y, c=err, cmap="jet", s=160, vmin=emin, vmax=emax) 98 | plt.axis("off") 99 | plt.axis("equal") 100 | plt.savefig("error.pdf") 101 | plt.close() 102 | -------------------------------------------------------------------------------- /tensorflow/5_Elasticity/utils.py: -------------------------------------------------------------------------------- 1 | from scipy.io import loadmat 2 | import numpy as np 3 | import tensorflow as tf 4 | tf.keras.utils.set_random_seed(0) 5 | tf.config.experimental.enable_op_determinism() 6 | tf.random.set_seed(0) 7 | physical_devices = tf.config.list_physical_devices('GPU') 8 | tf.config.set_visible_devices(physical_devices[0:],'GPU') 9 | import tensorflow_probability as tfp 10 | 11 | 12 | class rel_norm(tf.keras.losses.Loss): 13 | ''' 14 | Compute the average relative l2 loss between a batch of targets and predictions 15 | ''' 16 | def __init__(self): 17 | super().__init__() 18 | def call(self, true, pred): 19 | ''' 20 | true: (batch_size, L, d). 21 | pred: (batch_size, L, d). 
22 | number of variables d=1 23 | ''' 24 | rel_error = tf.math.divide(tf.norm(tf.keras.layers.Reshape((-1,))(true-pred), axis=1), tf.norm(tf.keras.layers.Reshape((-1,))(true), axis=1)) 25 | return tf.math.reduce_mean(rel_error) 26 | 27 | def load_data(path, ntrain, ntest): 28 | 29 | R = np.transpose(np.load(path + "Random_UnitCell_rr_10.npy"), (1,0))[:,np.newaxis,:] #(2000,1,42) 30 | X = np.transpose(np.load(path + "Random_UnitCell_XY_10.npy"), (2,0,1)) #(2000,972,2) 31 | R = np.repeat(5*R-1, X.shape[1], 1) #(2000,972,42) 32 | X = np.concatenate((X,R), axis=-1) #(2000,972,44): 2 coordinates + 42 shape parameters 33 | Y = np.transpose(np.load(path + "Random_UnitCell_sigma_10.npy"), (1,0))[...,np.newaxis] 34 | 35 | return X[:ntrain,...].astype("float32"), Y[:ntrain,...].astype("float32"), X[-ntest:,...].astype("float32"), Y[-ntest:,...].astype("float32") 36 | 37 | class mlp(tf.keras.layers.Layer): 38 | ''' 39 | A two-layer MLP with GELU activation. 40 | ''' 41 | def __init__(self, n_filters1, n_filters2): 42 | super(mlp, self).__init__() 43 | 44 | self.width1 = n_filters1 45 | self.width2 = n_filters2 46 | self.mlp1 = tf.keras.layers.Dense(self.width1, activation='gelu', kernel_initializer="he_normal") 47 | self.mlp2 = tf.keras.layers.Dense(self.width2, kernel_initializer="he_normal") 48 | 49 | def call(self, inputs): 50 | x = self.mlp1(inputs) 51 | x = self.mlp2(x) 52 | return x 53 | 54 | def get_config(self): 55 | config = { 56 | 'n_filters1': self.width1, 57 | 'n_filters2': self.width2 58 | } 59 | return config 60 | 61 | class MultiHeadPosAtt(tf.keras.layers.Layer): 62 | def __init__(self, n_head, hid_dim, locality): 63 | super(MultiHeadPosAtt, self).__init__() 64 | 65 | self.locality = locality 66 | self.hid_dim = hid_dim 67 | self.n_head = n_head 68 | self.v_dim = round(self.hid_dim/self.n_head) 69 | 70 | def build(self, input_shape): 71 | 72 | self.r = self.add_weight( 73 | shape=(1, self.n_head, 1, 1), 74 | trainable=True, 75 | name="dist", 76 | ) 77 | 78 | self.weight = self.add_weight( 79 |
shape=(self.n_head, self.hid_dim, self.v_dim), 80 | initializer="he_normal", 81 | trainable=True, 82 | name="weight", 83 | ) 84 | self.built = True 85 | 86 | def call(self, m_dist, inputs): 87 | """ 88 | m_dist: (batch, N, N) 89 | """ 90 | scaled_dist = tf.expand_dims(m_dist, 1) * self.r**2 #(batch_size, n_heads, L, L) 91 | if self.locality <= 100: 92 | mask = tfp.stats.percentile(scaled_dist, self.locality, interpolation="linear", axis=-1, keepdims=True) 93 | scaled_dist = tf.where(scaled_dist<=mask, scaled_dist, tf.float32.max) 94 | else: 95 | pass # locality > 100: global position-attention, no masking 96 | scaled_dist = -scaled_dist #(batch_size, n_heads, L, L) 97 | att = tf.nn.softmax(scaled_dist, axis=-1) #(batch_size, n_heads, L, L) 98 | 99 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.weight)#(batch_size, n_head, L, v_dim) 100 | concat = tf.einsum("bhnj,bhjd->bhnd", att, value) # (batch_size, n_head, L, v_dim) 101 | concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim) 102 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim) 103 | return tf.keras.activations.gelu(concat) 104 | 105 | 106 | def get_config(self): 107 | config = { 108 | 'hid_dim': self.hid_dim, 109 | 'n_head': self.n_head, 110 | 'locality': self.locality 111 | } 112 | return config 113 | 114 | class PiT(tf.keras.Model): 115 | ''' 116 | Position-induced Transformer, built upon the multi-head position-attention mechanism. 117 | PiT can be trained to decompose and learn the global and local dependencies of operators in partial differential equations.
118 | ''' 119 | def __init__(self, out_dim, hid_dim, n_head, locality_encoder, locality_decoder): 120 | super(PiT, self).__init__() 121 | ''' 122 | out_dim: number of variables 123 | hid_dim: encoding dimension (network width) 124 | n_head: number of heads in multi-head attention modules 125 | locality_encoder: quantile parameter of local position-attention in the Encoder, which sets the size of the receptive field 126 | locality_decoder: quantile parameter of local position-attention in the Decoder, which sets the size of the receptive field 127 | ''' 128 | self.out_dim = out_dim 129 | self.hid_dim = hid_dim 130 | self.n_head = n_head 131 | self.en_local = locality_encoder 132 | self.de_local = locality_decoder 133 | self.n_blocks = 4 # number of position-attention modules in the Processor 134 | 135 | # Encoder 136 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 137 | self.down = MultiHeadPosAtt(self.n_head, self.hid_dim, locality=self.en_local) 138 | self.mlp1 = mlp(self.hid_dim, self.hid_dim) 139 | self.w1 = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 140 | 141 | # Processor 142 | self.PA = [MultiHeadPosAtt(self.n_head, self.hid_dim, locality=200) for i in range(self.n_blocks)] 143 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 144 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 145 | 146 | # Decoder 147 | self.up = MultiHeadPosAtt(self.n_head, self.hid_dim, locality=self.de_local) 148 | self.mlp2 = mlp(self.hid_dim, self.hid_dim) 149 | self.w2 = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 150 | self.de_layer = mlp(self.hid_dim, self.out_dim) 151 | 152 | def call(self, inputs): 153 | 154 | m_dist = self.pairwise_dist(inputs[...,:2]) 155 | 156 | # Encoder 157 | en = self.en_layer(inputs) # (batch_size, L_qry, hid_dim) 158 | x = self.mlp1(self.down(m_dist,
en)) + self.w1(en) # (batch_size, L_qry, hid_dim) 159 | x = tf.keras.activations.gelu(x) # (batch_size, L_qry, hid_dim) 160 | 161 | # Processor 162 | for i in range(self.n_blocks): 163 | x = self.MLP[i](self.PA[i](m_dist, x)) + self.W[i](x) # (batch_size, L_qry, hid_dim) 164 | x = tf.keras.activations.gelu(x) # (batch_size, L_qry, hid_dim) 165 | 166 | # Decoder 167 | de = self.mlp2(self.up(m_dist, x)) + self.w2(x) # (batch_size, L_qry, hid_dim) 168 | de = tf.keras.activations.gelu(de) # (batch_size, L_qry, hid_dim) 169 | de = self.de_layer(de) # (batch_size, L_qry, out_dim) 170 | return de 171 | 172 | def pairwise_dist(self, mesh): 173 | 174 | mesh = tf.expand_dims(mesh, axis=1) 175 | pairwise_diff = mesh - tf.transpose(mesh, (0,2,1,3)) 176 | pairwise_dist = tf.norm(pairwise_diff, ord=2, axis=-1) 177 | return pairwise_dist**2 / 2.0 178 | 179 | def get_config(self): 180 | config = { 181 | 'out_dim': self.out_dim, 182 | 'hid_dim': self.hid_dim, 183 | 'n_head': self.n_head, 184 | 'locality_encoder': self.en_local, 185 | 'locality_decoder': self.de_local 186 | } 187 | return config 188 | 189 | class MultiHeadSelfAtt(tf.keras.layers.Layer): 190 | ''' 191 | Scaled dot-product multi-head self-attention 192 | ''' 193 | def __init__(self, n_head, hid_dim): 194 | super(MultiHeadSelfAtt, self).__init__() 195 | 196 | self.hid_dim = hid_dim 197 | self.n_head = n_head 198 | self.v_dim = round(self.hid_dim/self.n_head) 199 | 200 | def build(self, input_shape): 201 | 202 | self.q = self.add_weight( 203 | shape=(self.n_head, input_shape[-1], self.v_dim), 204 | initializer="he_normal", 205 | trainable=True, 206 | name="query", 207 | ) 208 | 209 | self.k = self.add_weight( 210 | shape=(self.n_head, input_shape[-1], self.v_dim), 211 | initializer="he_normal", 212 | trainable=True, 213 | name="key", 214 | ) 215 | 216 | self.v = self.add_weight( 217 | shape=(self.n_head, input_shape[-1], self.v_dim), 218 | initializer="he_normal", 219 | trainable=True, 220 | name="value", 221 | ) 222 | 
self.built = True 223 | 224 | def call(self, inputs): 225 | 226 | query = tf.einsum("bnj,hjk->bhnk", inputs, self.q)#(batch, n_head, L, v_dim) 227 | key = tf.einsum("bnj,hjk->bhnk", inputs, self.k)#(batch, n_head, L, v_dim) 228 | att = tf.nn.softmax(tf.einsum("...ij,...kj->...ik", query, key)/self.v_dim**0.5, axis=-1)#(batch, n_heads, L, L) 229 | 230 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.v)#(batch, n_head, L, v_dim) 231 | 232 | concat = tf.einsum("...nj,...jd->...nd", att, value)#(batch, n_head, L, v_dim) 233 | concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim) 234 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim) 235 | return tf.keras.activations.gelu(concat) 236 | 237 | def get_config(self): 238 | config = { 239 | 'n_head': self.n_head, 240 | 'hid_dim': self.hid_dim 241 | } 242 | return config 243 | 244 | class LiteTransformer(tf.keras.Model): 245 | ''' 246 | Replace position-attention of the Processor in a PiT with self-attention 247 | ''' 248 | def __init__(self, out_dim, hid_dim, n_head, en_local, de_local): 249 | super(LiteTransformer, self).__init__() 250 | 251 | self.out_dim = out_dim 252 | self.hid_dim = hid_dim 253 | self.n_head = n_head 254 | self.en_local = en_local 255 | self.de_local = de_local 256 | self.n_blocks = 4 257 | 258 | # Encoder 259 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 260 | self.down = MultiHeadPosAtt(self.n_head, self.hid_dim, self.en_local) 261 | self.mlp1 = mlp(self.hid_dim, self.hid_dim) 262 | self.w1 = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 263 | 264 | # Processor 265 | self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)] 266 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 267 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 268 | 269 | # 
Decoder 270 | self.up = MultiHeadPosAtt(self.n_head, self.hid_dim, self.de_local) 271 | self.mlp2 = mlp(self.hid_dim, self.hid_dim) 272 | self.w2 = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 273 | self.de_layer = mlp(self.hid_dim, self.out_dim) 274 | 275 | def call(self, inputs): 276 | 277 | m_dist = self.pairwise_dist(inputs[...,:2]) # distances are computed from the input coordinates, so m_dist is not a call argument 278 | 279 | # Encoder 280 | en = self.en_layer(inputs) 281 | x = self.mlp1(self.down(m_dist, en)) + self.w1(en) 282 | x = tf.keras.activations.gelu(x) 283 | 284 | # Processor 285 | for i in range(self.n_blocks): 286 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 287 | x = tf.keras.activations.gelu(x) 288 | 289 | # Decoder 290 | de = self.mlp2(self.up(m_dist, x)) + self.w2(x) 291 | de = tf.keras.activations.gelu(de) 292 | de = self.de_layer(de) 293 | 294 | return de 295 | 296 | def pairwise_dist(self, mesh): 297 | 298 | mesh = tf.expand_dims(mesh, axis=1) 299 | pairwise_diff = mesh - tf.transpose(mesh, (0,2,1,3)) 300 | pairwise_dist = tf.norm(pairwise_diff, ord=2, axis=-1) 301 | return pairwise_dist**2 / 2.0 302 | 303 | def get_config(self): 304 | config = { 305 | 'out_dim': self.out_dim, 306 | 'hid_dim': self.hid_dim, 307 | 'n_head': self.n_head, 308 | 'en_local':self.en_local, 309 | 'de_local':self.de_local 310 | } 311 | return config 312 | 313 | class Transformer(tf.keras.Model): 314 | ''' 315 | Replace position-attention of a PiT with self-attention.
316 | ''' 317 | def __init__(self, out_dim, hid_dim, n_head): 318 | super(Transformer, self).__init__() 319 | 320 | self.out_dim = out_dim 321 | self.hid_dim = hid_dim 322 | self.n_head = n_head 323 | self.n_blocks = 4 324 | 325 | # Encoder 326 | self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 327 | self.down = MultiHeadSelfAtt(self.n_head, self.hid_dim) 328 | self.mlp1 = mlp(self.hid_dim, self.hid_dim) 329 | self.w1 = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 330 | 331 | # Processor 332 | self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)] 333 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 334 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 335 | 336 | # Decoder 337 | self.up = MultiHeadSelfAtt(self.n_head, self.hid_dim) 338 | self.mlp2 = mlp(self.hid_dim, self.hid_dim) 339 | self.w2 = tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") 340 | self.de_layer = mlp(self.hid_dim, self.out_dim) 341 | 342 | def call(self, inputs): 343 | 344 | # Encoder 345 | en = self.en_layer(inputs) 346 | x = self.mlp1(self.down(en)) + self.w1(en) 347 | x = tf.keras.activations.gelu(x) 348 | 349 | # Processor 350 | for i in range(self.n_blocks): 351 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) 352 | x = tf.keras.activations.gelu(x) 353 | 354 | # Decoder 355 | de = self.mlp2(self.up(x)) + self.w2(x) 356 | de = tf.keras.activations.gelu(de) 357 | de = self.de_layer(de) 358 | 359 | return de 360 | 361 | def get_config(self): 362 | config = { 363 | 'out_dim': self.out_dim, 364 | 'hid_dim': self.hid_dim, 365 | 'n_head': self.n_head, 366 | } 367 | return config 368 | 369 | -------------------------------------------------------------------------------- /tensorflow/6_NACA/evaluate.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 4 | os.environ['PYTHONHASHSEED'] = '1' 5 | 6 | from numpy import savetxt 7 | import matplotlib.pyplot as plt 8 | from time import time 9 | 10 | ### custom imports 11 | from utils import * 12 | 13 | # params ################################## 14 | encode_dim = 256 15 | out_dim = 1 16 | n_head = 2 17 | n_train = 1000 18 | n_test = 200 19 | 20 | # load dataset 21 | trainX, trainY, testX, testY= load_data("./", ntrain=n_train, ntest=n_test) 22 | print(trainX.shape, trainY.shape, testX.shape, testY.shape) 23 | 24 | # define a model 25 | m_cross = pairwise_dist(51, 221, 26, 111) # pairwise distance matrix for encoder and decoder 26 | m_latent = pairwise_dist(26, 111, 26, 111) # pairwise distance matrix for processor 27 | network = PiT(m_cross, m_latent, out_dim, encode_dim, n_head, 0.5, 2) 28 | inputs = tf.keras.Input(shape=(221,51,trainX.shape[-1])) 29 | outputs = network(inputs) 30 | model = tf.keras.Model(inputs, outputs) 31 | network.summary() 32 | print(' ') 33 | ### load model 34 | model.load_weights('./256_0.5_2_selected/checkpoints/my_checkpoint').expect_partial() 35 | 36 | ### evaluation, visualization 37 | pred = model.predict(testX) # batched inference over the whole test set 38 | 39 | index = -67 40 | true = testY[index,40:-40,:20,0].reshape(-1,1) 41 | pred = pred[index,40:-40,:20,0].reshape(-1,1) # reuse the batched prediction instead of a second full forward pass 42 | err = abs(true-pred) 43 | emax = err.max() 44 | emin = err.min() 45 | vmax = true.max() 46 | vmin = true.min() 47 | print(vmax, vmin, emax, emin) 48 | 49 | x = testX[index,40:-40,:20,0].reshape(-1,1) 50 | y = testX[index,40:-40,:20,1].reshape(-1,1) 51 | print(x.max(), x.min(), y.max(), y.min()) 52 | 53 | plt.figure(figsize=(12,12),dpi=100) 54 | plt.scatter(x, y, c=pred, cmap="jet", s=160) 55 | plt.ylim(-0.5,0.5) 56 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 57 | plt.tight_layout(pad=0) 58 | plt.savefig("pred.pdf") 59 |
plt.close() 60 | 61 | plt.figure(figsize=(12,12),dpi=100) 62 | plt.scatter(x, y, c=true, cmap="jet", s=160) 63 | plt.ylim(-0.5,0.5) 64 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 65 | plt.tight_layout(pad=0) 66 | plt.savefig("true.pdf") 67 | plt.close() 68 | 69 | plt.figure(figsize=(12,12),dpi=100) 70 | plt.scatter(x, y, c=err, cmap="jet", s=160) 71 | plt.ylim(-0.5,0.5) 72 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 73 | plt.tight_layout(pad=0) 74 | plt.savefig("error.pdf") 75 | plt.close() 76 | 77 | plt.figure(figsize=(12,12),dpi=100) 78 | plt.scatter(x, y, color="black", s=160) 79 | plt.ylim(-0.5,0.5) 80 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 81 | plt.tight_layout(pad=0) 82 | plt.savefig("points.pdf") 83 | plt.close() 84 | -------------------------------------------------------------------------------- /tensorflow/6_NACA/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 4 | os.environ['PYTHONHASHSEED'] = '1' 5 | 6 | from numpy import savetxt 7 | import matplotlib.pyplot as plt 8 | from time import time 9 | 10 | ### custom imports 11 | from utils import * 12 | 13 | # params ################################## 14 | n_epochs = 500 15 | lr = 0.001 16 | batch_size = 8 17 | encode_dim = 256 18 | out_dim = 1 19 | n_head = 2 20 | n_train = 1000 21 | n_test = 200 22 | en_loc = 0.5 23 | de_loc = 2 24 | 25 | # load dataset 26 | trainX, trainY, testX, testY= load_data("./", ntrain=n_train, ntest=n_test) 27 | print(trainX.shape, trainY.shape, testX.shape, testY.shape) 28 | 29 | # define a model 30 | m_cross = pairwise_dist(51, 221, 26, 111) # pairwise distance matrix for encoder and decoder 31 | m_latent = pairwise_dist(26, 111, 26, 111) #
pairwise distance matrix for processor 32 | network = PiT(m_cross, m_latent, out_dim, encode_dim, n_head, en_loc, de_loc) 33 | 34 | inputs = tf.keras.Input(shape=(221,51,trainX.shape[-1])) 35 | outputs = network(inputs) 36 | model = tf.keras.Model(inputs, outputs) 37 | model.summary() 38 | ### compile model 39 | model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=tf.keras.optimizers.schedules.CosineDecay(lr, n_epochs*(n_train//batch_size))), 40 | loss=rel_norm(), 41 | jit_compile=True) 42 | 43 | ### fit model 44 | start = time() 45 | train_history = model.fit(trainX, trainY, 46 | batch_size, n_epochs, verbose=1, 47 | validation_data=(testX, testY), 48 | validation_batch_size=100) 49 | end = time() 50 | print(' ') 51 | print('Training cost is {} seconds per epoch !'.format((end-start)/n_epochs)) 52 | 53 | print(' ') 54 | ### save model 55 | model.save_weights('./checkpoints/my_checkpoint') 56 | ### training history 57 | loss_hist = train_history.history 58 | train_loss = tf.reshape(tf.convert_to_tensor(loss_hist["loss"]), (-1,1)) 59 | test_loss = tf.reshape(tf.convert_to_tensor(loss_hist["val_loss"]), (-1,1)) 60 | savetxt('./training_history.csv', tf.concat([train_loss, test_loss], axis=-1).numpy(), delimiter=',', header='train,test', fmt='%1.16f', comments='') 61 | 62 | plt.plot(train_loss, label='train') 63 | plt.plot(test_loss, label='test') 64 | plt.legend() 65 | plt.yscale('log', base=10) 66 | plt.savefig('training_history.png') 67 | plt.close() 68 | 69 | ### evaluation, visualization 70 | pred = model.predict(testX) # batched inference over the whole test set 71 | 72 | index = -2 73 | true = testY[index,40:-40,:20,0].reshape(-1,1) 74 | pred = pred[index,40:-40,:20,0].reshape(-1,1) # reuse the batched prediction instead of a second full forward pass 75 | err = abs(true-pred) 76 | emax = err.max() 77 | emin = err.min() 78 | vmax = true.max() 79 | vmin = true.min() 80 | print(vmax, vmin, emax, emin) 81 | 82 | x = testX[index,40:-40,:20,0].reshape(-1,1) 83 | y = testX[index,40:-40,:20,1].reshape(-1,1) 84 | print(x.max(), x.min(), y.max(), y.min())
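```python
# Sketch, not part of the original script: the near-field crop used for plotting.
# The slices [index, 40:-40, :20, 0] keep rows 40..180 of the 221-row first grid
# axis and the first 20 of the 51 columns of the second axis, i.e. 141*20 = 2820
# points, presumably the region closest to the airfoil surface.
# `crop_near_field` and `cropped_rel_l2` are hypothetical helpers.
import numpy as np

def crop_near_field(field, margin=40, n_layers=20):
    """Crop one (n_rows, n_cols, C) sample and flatten channel 0 to (N, 1)."""
    field = np.asarray(field)
    n_rows = field.shape[0]
    # Explicit stop index so that margin=0 keeps every row (a -0 stop would not).
    return field[margin:n_rows - margin, :n_layers, 0].reshape(-1, 1)

def cropped_rel_l2(true_field, pred_field, margin=40, n_layers=20):
    """Relative L2 error of one sample, restricted to the cropped window."""
    t = crop_near_field(true_field, margin, n_layers)
    p = crop_near_field(pred_field, margin, n_layers)
    return float(np.linalg.norm(t - p) / np.linalg.norm(t))

# Possible usage, assuming a full prediction array `full_pred = model.predict(testX)`
# (hypothetical name) and `testY`, `index` as in this script:
#   cropped_rel_l2(testY[index], full_pred[index])
```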
85 | 86 | plt.figure(figsize=(12,12),dpi=100) 87 | plt.scatter(x, y, c=pred, cmap="jet", s=160) 88 | plt.ylim(-0.5,0.5) 89 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 90 | plt.tight_layout(pad=0) 91 | plt.savefig("pred.pdf") 92 | plt.close() 93 | 94 | plt.figure(figsize=(12,12),dpi=100) 95 | plt.scatter(x, y, c=true, cmap="jet", s=160) 96 | plt.ylim(-0.5,0.5) 97 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 98 | plt.tight_layout(pad=0) 99 | plt.savefig("true.pdf") 100 | plt.close() 101 | 102 | plt.figure(figsize=(12,12),dpi=100) 103 | plt.scatter(x, y, c=err, cmap="jet", s=160) 104 | plt.ylim(-0.5,0.5) 105 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 106 | plt.tight_layout(pad=0) 107 | plt.savefig("error.pdf") 108 | plt.close() 109 | 110 | plt.figure(figsize=(12,12),dpi=100) 111 | plt.scatter(x, y, color="black", s=160) 112 | plt.ylim(-0.5,0.5) 113 | plt.tick_params(axis="both", which="both", bottom=False, left=False, labelleft=False, labelbottom=False) 114 | plt.tight_layout(pad=0) 115 | plt.savefig("points.pdf") 116 | plt.close() 117 | -------------------------------------------------------------------------------- /tensorflow/6_NACA/utils.py: -------------------------------------------------------------------------------- 1 | from numpy import load, concatenate, newaxis, zeros 2 | import tensorflow as tf 3 | tf.keras.utils.set_random_seed(0) 4 | tf.config.experimental.enable_op_determinism() 5 | tf.random.set_seed(0) 6 | physical_devices = tf.config.list_physical_devices('GPU') 7 | tf.config.set_visible_devices(physical_devices[0:],'GPU') 8 | import tensorflow_probability as tfp 9 | 10 | class rel_norm(tf.keras.losses.Loss): 11 | ''' 12 | Compute the average relative l2 loss between a batch of targets and predictions 13 | ''' 14 | def __init__(self): 15 | super().__init__() 16 | 
def call(self, true, pred): 17 | batch_size = tf.shape(true)[0] 18 | rel_error = tf.math.divide(tf.norm(tf.reshape(true-pred, (batch_size,-1)), axis=1), tf.norm(tf.reshape(true, (batch_size,-1)), axis=1)) 19 | return tf.math.reduce_mean(rel_error) 20 | 21 | def pairwise_dist(res1x, res1y, res2x, res2y): 22 | 23 | gridx = tf.reshape(tf.linspace(0, 1, res1x+1)[:-1], (1,-1,1)) 24 | gridx = tf.tile(gridx, [res1y,1,1]) 25 | gridy = tf.reshape(tf.linspace(0, 1, res1y+1)[:-1], (-1,1,1)) 26 | gridy = tf.tile(gridy, [1,res1x,1]) 27 | grid1 = tf.concat([gridx, gridy], axis=-1) 28 | grid1 = tf.reshape(grid1, (res1x*res1y,2)) #(res1x*res1y,2) 29 | 30 | gridx = tf.reshape(tf.linspace(0, 1, res2x+1)[:-1], (1,-1,1)) 31 | gridx = tf.tile(gridx, [res2y,1,1]) 32 | gridy = tf.reshape(tf.linspace(0, 1, res2y+1)[:-1], (-1,1,1)) 33 | gridy = tf.tile(gridy, [1,res2x,1]) 34 | grid2 = tf.concat([gridx, gridy], axis=-1) 35 | grid2 = tf.reshape(grid2, (res2x*res2y,2)) #(res2x*res2y,2) 36 | 37 | print(grid1.shape, grid2.shape) 38 | grid1 = tf.tile(tf.expand_dims(grid1, 1), [1,grid2.shape[0],1]) 39 | grid2 = tf.tile(tf.expand_dims(grid2, 0), [grid1.shape[0],1,1]) 40 | 41 | dist = tf.norm(grid1-grid2, axis=-1) #(res1x*res1y, res2x*res2y) 42 | dist2 = tf.cast(dist**2, 'float32') 43 | return dist2/2 # half the squared Euclidean distance 44 | 45 | def load_data(path, ntrain, ntest): 46 | vertices_x = load(path + "NACA_Cylinder_X.npy")[...,newaxis] 47 | vertices_y = load(path + "NACA_Cylinder_Y.npy")[...,newaxis] 48 | mach = load(path + "NACA_Cylinder_Q.npy")[:,4,...][...,newaxis] 49 | 50 | X = concatenate((vertices_x, vertices_y), -1).astype("float32") 51 | Y = mach.astype("float32") 52 | return X[:ntrain,...], Y[:ntrain,...], X[ntrain:ntrain+ntest,...], Y[ntrain:ntrain+ntest,...] 53 | 54 | 55 | class mlp(tf.keras.layers.Layer): 56 | ''' 57 | A two-layer MLP with GELU activation.
58 | ''' 59 | def __init__(self, n_filters1, n_filters2): 60 | super(mlp, self).__init__() 61 | 62 | self.width1 = n_filters1 63 | self.width2 = n_filters2 64 | self.mlp1 = tf.keras.layers.Dense(self.width1, activation='gelu', kernel_initializer="he_normal") 65 | self.mlp2 = tf.keras.layers.Dense(self.width2, kernel_initializer="he_normal") 66 | 67 | def call(self, inputs): 68 | x = self.mlp1(inputs) 69 | x = self.mlp2(x) 70 | return x 71 | 72 | def get_config(self): 73 | config = { 74 | 'n_filters1': self.width1, 75 | 'n_filters2': self.width2, 76 | } 77 | return config 78 | 79 | class MultiHeadPosAtt(tf.keras.layers.Layer): 80 | ''' 81 | Global, local and cross variants of the multi-head position-attention mechanism. 82 | ''' 83 | def __init__(self, m_dist, n_head, hid_dim, locality): 84 | super(MultiHeadPosAtt, self).__init__() 85 | ''' 86 | m_dist: distance matrix 87 | n_head: number of attention heads 88 | hid_dim: encoding dimension 89 | locality: quantile parameter to customize receptive field in position-attention 90 | ''' 91 | self.dist = m_dist 92 | self.locality = locality 93 | self.hid_dim = hid_dim 94 | self.n_head = n_head 95 | self.v_dim = round(self.hid_dim/self.n_head) 96 | 97 | def build(self, input_shape): 98 | 99 | self.r = self.add_weight( 100 | shape=(self.n_head, 1, 1), 101 | trainable=True, 102 | constraint=tf.keras.constraints.NonNeg(), 103 | name="band_width", 104 | ) 105 | 106 | self.weight = self.add_weight( 107 | shape=(self.n_head, input_shape[-1], self.v_dim), 108 | initializer="he_normal", 109 | trainable=True, 110 | name="weight", 111 | ) 112 | self.built = True 113 | 114 | def call(self, inputs): 115 | scaled_dist = self.dist * tf.math.tan(self.r) 116 | if self.locality <= 100: 117 | mask = tfp.stats.percentile(scaled_dist, self.locality, interpolation="linear", axis=-1, keepdims=True) 118 | scaled_dist = tf.where(scaled_dist<=mask, scaled_dist, tf.float32.max) 119 | else: 120 | pass 121 | scaled_dist = - scaled_dist 122 | att = 
tf.nn.softmax(scaled_dist, axis=2)#(n_heads, L, L) 123 | 124 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.weight) # (batch_size, n_head, L, v_dim) 125 | 126 | concat = tf.einsum("hnj,bhjd->bhnd", att, value) # (batch_size, n_head, L, v_dim) 127 | concat = tf.transpose(concat, (0,2,1,3)) # (batch_size, L, n_head, v_dim) 128 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) # (batch_size, L, hid_dim) 129 | return tf.keras.activations.gelu(concat) 130 | 131 | def get_config(self): 132 | config = { 133 | 'm_dist': self.dist, 134 | 'hid_dim': self.hid_dim, 135 | 'n_head': self.n_head, 136 | 'locality': self.locality 137 | } 138 | return config 139 | 140 | class PiT(tf.keras.Model): 141 | ''' 142 | Position-induced Transformer, built upon the multi-head position-attention mechanism. 143 | PiT can be trained to decompose and learn the global and local dependencies of operators in partial differential equations. 144 | ''' 145 | def __init__(self, m_cross, m_small, out_dim, hid_dim, n_head, locality_encoder, locality_decoder): 146 | super(PiT, self).__init__() 147 | ''' 148 | m_cross: distance matrix between X_query and X_latent; (L_qry,L_ltt) 149 | m_small: distance matrix between X_latent and X_latent; (L_ltt,L_ltt) 150 | out_dim: number of variables 151 | hid_dim: encoding dimension (network width) 152 | n_head: number of heads in multi-head attention modules 153 | locality_encoder: quantile parameter of local position-attention in the Encoder, which sets the size of the receptive field 154 | locality_decoder: quantile parameter of local position-attention in the Decoder, which sets the size of the receptive field 155 | ''' 156 | self.m_cross = m_cross 157 | self.m_small = m_small 158 | self.out_dim = out_dim 159 | self.hid_dim = hid_dim 160 | self.n_head = n_head 161 | self.en_local = locality_encoder 162 | self.de_local = locality_decoder 163 | self.n_blocks = 4 164 | 165 | # Encoder 166 | self.en_layer =
tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal") 167 | self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local) 168 | 169 | # Processor 170 | self.PA = [MultiHeadPosAtt(self.m_small, self.n_head, self.hid_dim, locality=200) for i in range(self.n_blocks)] 171 | self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)] 172 | self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)] 173 | 174 | # Decoder 175 | self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local) 176 | self.de_layer = mlp(self.hid_dim, self.out_dim) 177 | 178 | def call(self, inputs): 179 | 180 | # Encoder 181 | grid = self.get_mesh(inputs) #(batch_size, res_qry1, res_qry2, 2) 182 | en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1) # (batch_size, res_qry1, res_qry2, input_dim+2) 183 | en = tf.keras.layers.Reshape((-1, en.shape[3]))(en) # (batch_size, L_qry, input_dim+2) 184 | en = self.en_layer(en) # (batch_size, L_qry, hid_dim) 185 | x = self.down(en) # (batch_size, L_ltt, hid_dim) 186 | 187 | # Processor 188 | for i in range(self.n_blocks): 189 | x = self.MLP[i](self.PA[i](x)) + self.W[i](x) # (batch_size, L_ltt, hid_dim) 190 | x = tf.keras.activations.gelu(x) # (batch_size, L_ltt, hid_dim) 191 | 192 | # Decoder 193 | de = self.up(x) # (batch_size, L_qry, hid_dim) 194 | de = self.de_layer(de) # (batch_size, L_qry, out_dim) 195 | de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de) # (batch_size, res_qry1, res_qry2, out_dim) 196 | return de 197 | 198 | def get_mesh(self, inputs): 199 | size_x, size_y = tf.shape(inputs)[1], tf.shape(inputs)[2] 200 | gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1)) 201 | gridx = tf.tile(gridx, [1,1,size_y,1]) 202 | gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1)) 203 | gridy = tf.tile(gridy, [1,size_x,1,1])
204 | grid = tf.concat([gridx, gridy], axis=-1) 205 | return tf.tile(grid, [tf.shape(inputs)[0],1,1,1]) 206 | 207 | def get_config(self): 208 | config = { 209 | 'm_cross': self.m_cross, 210 | 'm_small': self.m_small, 211 | 'out_dim': self.out_dim, 212 | 'hid_dim': self.hid_dim, 213 | 'n_head': self.n_head, 214 | 'locality_encoder': self.en_local, 215 | 'locality_decoder': self.de_local 216 | } 217 | return config 218 | 219 | class MultiHeadSelfAtt(tf.keras.layers.Layer): 220 | ''' 221 | Scaled dot-product multi-head self-attention 222 | ''' 223 | def __init__(self, n_head, hid_dim): 224 | super(MultiHeadSelfAtt, self).__init__() 225 | 226 | self.hid_dim = hid_dim 227 | self.n_head = n_head 228 | self.v_dim = round(self.hid_dim/self.n_head) 229 | 230 | def build(self, input_shape): 231 | 232 | self.q = self.add_weight( 233 | shape=(self.n_head, input_shape[-1], self.v_dim), 234 | initializer="he_normal", 235 | trainable=True, 236 | name="query", 237 | ) 238 | 239 | self.k = self.add_weight( 240 | shape=(self.n_head, input_shape[-1], self.v_dim), 241 | initializer="he_normal", 242 | trainable=True, 243 | name="key", 244 | ) 245 | 246 | self.v = self.add_weight( 247 | shape=(self.n_head, input_shape[-1], self.v_dim), 248 | initializer="he_normal", 249 | trainable=True, 250 | name="value", 251 | ) 252 | self.built = True 253 | 254 | def call(self, inputs): 255 | 256 | query = tf.einsum("bnj,hjk->bhnk", inputs, self.q)#(batch_size, n_head, L, v_dim) 257 | key = tf.einsum("bnj,hjk->bhnk", inputs, self.k)#(batch_size, n_head, L ,v_dim) 258 | att = tf.nn.softmax(tf.einsum("...ij,...kj->...ik", query, key)/self.v_dim**0.5, axis=-1)#(batch_size, n_heads, L, L) 259 | 260 | value = tf.einsum("bnj,hjk->bhnk", inputs, self.v)#(batch_size, n_head, L, v_dim) 261 | 262 | concat = tf.einsum("...nj,...jd->...nd", att, value)#(batch_size, n_head, L, v_dim) 263 | concat = tf.transpose(concat, (0,2,1,3)) 264 | concat = tf.keras.layers.Reshape((-1,self.hid_dim))(concat) 265 | return 
tf.keras.activations.gelu(concat)

    def get_config(self):
        config = {
            'n_head': self.n_head,
            'hid_dim': self.hid_dim
        }
        return config

class LiteTransformer(tf.keras.Model):
    '''
    Replace position-attention of the Processor in a PiT with self-attention.
    '''
    def __init__(self, m_cross, out_dim, hid_dim, n_head, en_local, de_local):
        super(LiteTransformer, self).__init__()

        self.m_cross = m_cross
        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.en_local = en_local
        self.de_local = de_local
        self.n_blocks = 4

        # Encoder
        self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal")
        self.down = MultiHeadPosAtt(tf.transpose(self.m_cross), self.n_head, self.hid_dim, locality=self.en_local)

        # Processor
        self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)]
        self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)]
        self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)]

        # Decoder
        self.up = MultiHeadPosAtt(self.m_cross, self.n_head, self.hid_dim, locality=self.de_local)
        self.de_layer = mlp(self.hid_dim, self.out_dim)

    def call(self, inputs):

        # Encoder: append grid coordinates to the input channels, then flatten the 2D grid into tokens
        grid = self.get_mesh(inputs)
        en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1)
        en = tf.keras.layers.Reshape((-1, en.shape[3]))(en)
        en = self.en_layer(en)
        x = self.down(en)

        # Processor: residual blocks x <- gelu(MLP(SelfAtt(x)) + W x)
        for i in range(self.n_blocks):
            x = self.MLP[i](self.PA[i](x)) + self.W[i](x)
            x = tf.keras.activations.gelu(x)

        # Decoder: map back to the input grid and reshape to (batch, size_x, size_y, out_dim)
        de = self.up(x)
        de = self.de_layer(de)
        de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de)
        return de

    def get_mesh(self, inputs):
        # Uniform grid on [0, 1) x [0, 1), tiled over the batch dimension
        size_x, size_y = tf.shape(inputs)[1], tf.shape(inputs)[2]
        gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1))
        gridx = tf.tile(gridx, [1,1,size_y,1])
        gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1))
        gridy = tf.tile(gridy, [1,size_x,1,1])
        grid = tf.concat([gridx, gridy], axis=-1)
        return tf.tile(grid, [tf.shape(inputs)[0],1,1,1])

    def get_config(self):
        # Keys match the __init__ argument names so that cls(**config) can rebuild the model
        config = {
            'm_cross': self.m_cross,
            'out_dim': self.out_dim,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head,
            'en_local': self.en_local,
            'de_local': self.de_local,
        }
        return config

class Transformer(tf.keras.Model):
    '''
    Replace position-attention of a PiT with self-attention.
    '''
    def __init__(self, out_dim, hid_dim, n_head):
        super(Transformer, self).__init__()

        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.n_head = n_head
        self.n_blocks = 4

        # Encoder
        self.en_layer = tf.keras.layers.Dense(self.hid_dim, activation="gelu", kernel_initializer="he_normal")
        self.down = MultiHeadSelfAtt(self.n_head, self.hid_dim)

        # Processor
        self.PA = [MultiHeadSelfAtt(self.n_head, self.hid_dim) for i in range(self.n_blocks)]
        self.MLP = [mlp(self.hid_dim, self.hid_dim) for i in range(self.n_blocks)]
        self.W = [tf.keras.layers.Dense(self.hid_dim, kernel_initializer="he_normal") for i in range(self.n_blocks)]

        # Decoder
        self.up = MultiHeadSelfAtt(self.n_head, self.hid_dim)
        self.de_layer = mlp(self.hid_dim, self.out_dim)

    def call(self, inputs):

        # Encoder
        grid = self.get_mesh(inputs)
        en = tf.concat([tf.cast(grid, dtype="float32"), inputs], axis=-1)
        en = tf.keras.layers.Reshape((-1, en.shape[3]))(en)
        en = self.en_layer(en)
        x = self.down(en)

        # Processor
        for i in range(self.n_blocks):
            x = self.MLP[i](self.PA[i](x)) + self.W[i](x)
            x = tf.keras.activations.gelu(x)

        # Decoder
        de = self.up(x)
        de = self.de_layer(de)
        de = tf.keras.layers.Reshape((inputs.shape[1], inputs.shape[2], self.out_dim))(de)
        return de

    def get_mesh(self, inputs):
        # Uniform grid on [0, 1) x [0, 1), tiled over the batch dimension
        size_x, size_y = tf.shape(inputs)[1], tf.shape(inputs)[2]
        gridx = tf.reshape(tf.linspace(0, 1, size_x+1)[:-1], (1,-1,1,1))
        gridx = tf.tile(gridx, [1,1,size_y,1])
        gridy = tf.reshape(tf.linspace(0, 1, size_y+1)[:-1], (1,1,-1,1))
        gridy = tf.tile(gridy, [1,size_x,1,1])
        grid = tf.concat([gridx, gridy], axis=-1)
        return tf.tile(grid, [tf.shape(inputs)[0],1,1,1])

    def get_config(self):
        # Transformer has no `res` attribute; store the actual constructor arguments instead
        config = {
            'out_dim': self.out_dim,
            'hid_dim': self.hid_dim,
            'n_head': self.n_head
        }
        return config

--------------------------------------------------------------------------------
/tensorflow/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 junfeng-chen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/tensorflow/README.md:
--------------------------------------------------------------------------------
# Position-induced Transformer

The code in this repository presents six numerical experiments that use the Position-induced Transformer (PiT) to learn operators in partial differential equations. PiT is built upon the position-attention mechanism proposed in the paper *Positional Knowledge is All You Need: Position-induced Transformer (PiT) for Operator Learning*. The paper can be downloaded here.

PiT is discretization convergent: its predictions remain consistent and converge as the input data is refined. On the Darcy2D benchmark, a PiT model trained on data at 43x43 resolution can produce accurate predictions for input data at 421x421 resolution.
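The `LiteTransformer` and `Transformer` baselines above attach coordinates to every input point through `get_mesh`, which always spans the unit square [0, 1) x [0, 1) whatever the grid size, so a 43x43 input and a 421x421 input are placed on the same coordinate scale. The NumPy sketch below mirrors that method so its output can be inspected without TensorFlow. The standalone function `get_mesh(batch, size_x, size_y)` and its arguments are illustrative assumptions, not part of the repository, and the PiT model itself may build its grid differently.

```python
import numpy as np

def get_mesh(batch, size_x, size_y):
    """Hypothetical NumPy mirror of the TF get_mesh method.

    Returns a float32 array of shape (batch, size_x, size_y, 2) whose last
    axis holds the (x, y) coordinates of each grid point.
    """
    # linspace(0, 1, n + 1)[:-1] drops the right endpoint, giving spacing 1/n
    gx = np.linspace(0.0, 1.0, size_x + 1)[:-1]
    gy = np.linspace(0.0, 1.0, size_y + 1)[:-1]
    # indexing="ij" keeps x along axis 1 and y along axis 2, as in the TF tiling
    X, Y = np.meshgrid(gx, gy, indexing="ij")
    grid = np.stack([X, Y], axis=-1)[None]  # (1, size_x, size_y, 2)
    # Repeat the same grid for every sample in the batch
    return np.tile(grid, (batch, 1, 1, 1)).astype(np.float32)

coarse = get_mesh(2, 43, 43)    # training resolution quoted above
fine = get_mesh(2, 421, 421)    # evaluation resolution quoted above
print(coarse.shape, fine.shape)  # (2, 43, 43, 2) (2, 421, 421, 2)
```

Only the spacing changes with resolution (1/43 versus 1/421); both grids start at 0 and stay below 1.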
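Each Processor iteration in `LiteTransformer.call` and `Transformer.call` above computes x <- gelu(MLP(SelfAtt(x)) + W x). The NumPy sketch below spells out that update under stated assumptions: `MultiHeadSelfAtt` is replaced by single-head scaled dot-product attention, `mlp` is assumed to be two dense layers with a GELU in between, biases are omitted, and the weights are random. The internals of the repository's own layers are not shown in this section, so the helper names here are hypothetical.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
_erf = np.vectorize(math.erf)

def gelu(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, Wq, Wk, Wv):
    # Single-head scaled dot-product attention over the token axis (assumption)
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ np.swapaxes(k, -1, -2) / math.sqrt(q.shape[-1])
    return softmax(scores) @ v

def mlp(x, W1, W2):
    # Two dense layers with a GELU in between, no biases (assumption)
    return gelu(x @ W1) @ W2

def processor_block(x, p):
    # x <- gelu(MLP(SelfAtt(x)) + W x), the loop body of call()
    att = self_attention(x, p["Wq"], p["Wk"], p["Wv"])
    return gelu(mlp(att, p["W1"], p["W2"]) + x @ p["W"])

def init_block(hid_dim):
    names = ["Wq", "Wk", "Wv", "W1", "W2", "W"]
    return {n: rng.normal(0.0, 1.0 / math.sqrt(hid_dim), (hid_dim, hid_dim)) for n in names}

hid_dim, n_blocks = 16, 4  # n_blocks = 4 as in the models above
blocks = [init_block(hid_dim) for _ in range(n_blocks)]
x = rng.normal(size=(2, 50, hid_dim))  # (batch, tokens, hid_dim)
for p in blocks:
    x = processor_block(x, p)
print(x.shape)  # (2, 50, 16)
```

Self-attention and the token-wise layers are permutation equivariant: shuffling the tokens only shuffles the output. Plain self-attention therefore carries no notion of where a token sits, which is one reason the encoders above concatenate the `get_mesh` coordinates to the inputs before any attention layer.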