├── .idea
│   └── vcs.xml
├── LICENSE
├── README.md
├── graph-neural-operator
│   ├── README.md
│   ├── UAI1_full_resolution.py
│   ├── UAI2_full_equation.py
│   ├── UAI3_resolution.py
│   ├── UAI4_equation_sample.py
│   ├── UAI5_sample_generalize.py
│   ├── UAI6_sample_radius.py
│   ├── UAI7_evaluate.py
│   ├── UAI7_evaluate2.py
│   ├── UAI8_kernel.py
│   ├── model
│   │   ├── grain_new_r64_s64testm100
│   │   └── grain_torus_r64_radius0.4testm100
│   ├── nn_conv.py
│   └── utilities.py
└── multipole-graph-neural-operator
    ├── MGKN_general_darcy2d.py
    ├── MGKN_orthogonal_burgers1d.py
    ├── README.md
    ├── neurips1_GKN.py
    ├── neurips1_MGKN.py
    ├── neurips2_MGKN.py
    ├── neurips3_MGKN.py
    ├── neurips4_GCN.py
    ├── neurips5_GKN.py
    └── utilities.py

/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="VcsDirectoryMappings">
4 |     <mapping directory="$PROJECT_DIR$" vcs="Git" />
5 |   </component>
6 | </project>
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Zongyi Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 📢 DEPRECATION NOTICE 📢
2 | ----------------------------
3 |
4 | 🚨 **This repository is no longer maintained.** 🚨 The code in this repository is **deprecated** and may not work with newer dependencies or frameworks. For the most up-to-date implementation and continued development, please visit:
5 |
6 | ## ➡️ **[NeuralOperator](https://github.com/neuraloperator/neuraloperator)** ⬅️
7 |
8 |
9 | 🔴 We strongly recommend using the latest version to ensure compatibility, performance, and support. 🔴
10 |
11 | ----------------------------
12 |
13 | # Graph-based neural operators
14 | This repository contains the code for the following two papers:
15 | - [(GKN) Neural Operator: Graph Kernel Network for Partial Differential Equations](https://arxiv.org/abs/2003.03485)
16 | - [(MGKN) Multipole Graph Neural Operator for Parametric Partial Differential Equations](https://arxiv.org/abs/2006.09535)
17 |
18 | ## Graph Kernel Network (GKN)
19 | We propose to use graph neural networks to learn the solution operator of partial differential equations. The key innovation in our work is that a single set of network parameters, within a carefully designed network architecture, may be used to describe mappings between infinite-dimensional spaces and between different finite-dimensional approximations of those spaces.
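Concretely, the network only ever sees node coordinates, node features, and edge features, so one set of weights can be applied to discretizations of any size. Below is a schematic sketch of this property, not a training script: it assumes the `KernelNN` class and the `SquareMeshGenerator` helper from `graph-neural-operator/`, and it feeds a random stand-in coefficient field through an untrained model at three resolutions (see `UAI1_full_resolution.py` for the real experiment):

```python
import torch
from torch_geometric.data import Data
from utilities import SquareMeshGenerator  # from graph-neural-operator/

# KernelNN as defined in UAI1_full_resolution.py: width=64, ker_width=1024,
# depth=6, 6 edge features, 6 node features.
model = KernelNN(64, 1024, 6, 6, in_width=6)

for s in (16, 31, 61):  # one set of weights, three grid resolutions
    mesh = SquareMeshGenerator([[0, 1], [0, 1]], [s, s])
    edge_index = mesh.ball_connectivity(0.1)   # connect nodes closer than radius 0.1
    grid = mesh.get_grid()                     # (s*s, 2) node coordinates
    a = torch.randn(s * s)                     # stand-in for a coefficient field
    edge_attr = mesh.attributes(theta=a)       # (n_edges, 6) inputs to the kernel network
    x = torch.cat([grid, a.view(-1, 1).repeat(1, 4)], dim=1)  # (s*s, 6) node features
    print(s, model(Data(x=x, edge_index=edge_index, edge_attr=edge_attr)).shape)
```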
20 |
21 | ## Multipole Graph Kernel Network (MGKN)
22 | Inspired by classical multipole methods, we propose a multi-level graph neural network framework that captures interactions at all ranges with only linear complexity. Our multi-level formulation is equivalent to recursively adding inducing points to the kernel matrix, unifying GNNs with multi-resolution matrix factorizations of the kernel. Experiments confirm that our multi-graph network learns discretization-invariant solution operators to PDEs and can be evaluated in linear time.
23 |
24 | ## Requirements
25 | - [PyTorch](https://pytorch.org/)
26 | - [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/)
27 |
28 |
29 | ## Files
30 | The code takes the form of simple scripts; each script is stand-alone and directly runnable.
31 |
32 | ## Datasets
33 | We provide the Burgers equation and Darcy flow datasets used in the papers; the data-generation procedure is described there.
34 | The data are given as MATLAB (.mat) files and can be loaded with the readers provided in utilities.py.
35 |
36 | - [PDE datasets](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-?usp=sharing)
37 |
38 |
--------------------------------------------------------------------------------
/graph-neural-operator/README.md:
--------------------------------------------------------------------------------
1 | # Graph-PDE
2 | This repository contains the code for the paper:
3 | [Neural Operator: Graph Kernel Network for Partial Differential Equations](https://arxiv.org/abs/2003.03485)
4 |
5 | It depends on PyTorch and PyTorch Geometric (torch-geometric).
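The Darcy dataset can be loaded with the `MatReader` helper from `utilities.py`; a minimal sketch (the file name and field names follow `UAI1_full_resolution.py`, and the stride-4 subsampling is just illustrative):

```python
from utilities import MatReader

reader = MatReader('data/piececonst_r241_N1024_smooth1.mat')
a = reader.read_field('coeff')  # (1024, 241, 241) coefficient fields
u = reader.read_field('sol')    # (1024, 241, 241) corresponding solutions
a61 = a[:, ::4, ::4]            # subsample the 241x241 grid down to 61x61
```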
6 | 7 | 8 | -------------------------------------------------------------------------------- /graph-neural-operator/UAI1_full_resolution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | 6 | from torch_geometric.data import Data, DataLoader 7 | import matplotlib.pyplot as plt 8 | from utilities import * 9 | from nn_conv import NNConv_old 10 | 11 | from timeit import default_timer 12 | 13 | 14 | class KernelNN(torch.nn.Module): 15 | def __init__(self, width, ker_width, depth, ker_in, in_width=1, out_width=1): 16 | super(KernelNN, self).__init__() 17 | self.depth = depth 18 | 19 | self.fc1 = torch.nn.Linear(in_width, width) 20 | 21 | kernel = DenseNet([ker_in, ker_width, ker_width, width**2], torch.nn.ReLU) 22 | self.conv1 = NNConv_old(width, width, kernel, aggr='mean') 23 | 24 | self.fc2 = torch.nn.Linear(width, 1) 25 | 26 | def forward(self, data): 27 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 28 | x = self.fc1(x) 29 | for k in range(self.depth): 30 | x = F.relu(self.conv1(x, edge_index, edge_attr)) 31 | 32 | x = self.fc2(x) 33 | return x 34 | 35 | 36 | TRAIN_PATH = 'data/piececonst_r241_N1024_smooth1.mat' 37 | TEST_PATH = 'data/piececonst_r241_N1024_smooth2.mat' 38 | 39 | r = 4 40 | s = int(((241 - 1)/r) + 1) 41 | n = s**2 42 | m = 100 43 | k = 1 44 | 45 | radius_train = 0.1 46 | radius_test = 0.1 47 | 48 | print('resolution', s) 49 | 50 | 51 | ntrain = 100 52 | ntest = 40 53 | 54 | batch_size = 1 55 | batch_size2 = 2 56 | width = 64 57 | ker_width = 1024 58 | depth = 6 59 | edge_features = 6 60 | node_features = 6 61 | 62 | epochs = 200 63 | learning_rate = 0.0001 64 | scheduler_step = 50 65 | scheduler_gamma = 0.8 66 | 67 | path = 'UAI1_r'+str(s)+'_n'+ str(ntrain) 68 | path_model = 'model/'+path+'' 69 | path_train_err = 'results/'+path+'train.txt' 70 | path_test_err = 'results/'+path+'test.txt' 71 | path_image = 'image/'+path+'' 72 | path_train_err = 'results/'+path+'train' 73 | path_test_err16 = 'results/'+path+'test16' 74 | path_test_err31 = 'results/'+path+'test31' 75 | path_test_err61 = 'results/'+path+'test61' 76 | path_image_train = 'image/'+path+'train' 77 | path_image_test16 = 'image/'+path+'test16' 78 | path_image_test31 = 'image/'+path+'test31' 79 | path_image_test61 = 'image/'+path+'test61' 80 | 81 | t1 = default_timer() 82 | 83 | 84 | reader = MatReader(TRAIN_PATH) 85 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 86 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 87 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 88 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 89 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 90 | train_u64 = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 91 | 92 | reader.load_file(TEST_PATH) 93 | test_a = reader.read_field('coeff')[:ntest,::4,::4].reshape(ntest,-1) 94 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,::4,::4].reshape(ntest,-1) 95 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::4,::4].reshape(ntest,-1) 96 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::4,::4].reshape(ntest,-1) 97 | test_u = reader.read_field('sol')[:ntest,::4,::4].reshape(ntest,-1) 98 | 99 | 100 | a_normalizer = GaussianNormalizer(train_a) 101 | train_a = a_normalizer.encode(train_a) 102 | test_a = a_normalizer.encode(test_a) 103 | 
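# Note: each input channel (the coefficient and its smoothed/gradient versions) is
# normalized with Gaussian statistics computed on the training set only; the same
# encoders are then reused for the test data below.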
as_normalizer = GaussianNormalizer(train_a_smooth) 104 | train_a_smooth = as_normalizer.encode(train_a_smooth) 105 | test_a_smooth = as_normalizer.encode(test_a_smooth) 106 | agx_normalizer = GaussianNormalizer(train_a_gradx) 107 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 108 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 109 | agy_normalizer = GaussianNormalizer(train_a_grady) 110 | train_a_grady = agy_normalizer.encode(train_a_grady) 111 | test_a_grady = agy_normalizer.encode(test_a_grady) 112 | 113 | 114 | test_a = test_a.reshape(ntest,61,61) 115 | test_a_smooth = test_a_smooth.reshape(ntest,61,61) 116 | test_a_gradx = test_a_gradx.reshape(ntest,61,61) 117 | test_a_grady = test_a_grady.reshape(ntest,61,61) 118 | test_u = test_u.reshape(ntest,61,61) 119 | 120 | test_a16 =test_a[:ntest,::4,::4].reshape(ntest,-1) 121 | test_a_smooth16 = test_a_smooth[:ntest,::4,::4].reshape(ntest,-1) 122 | test_a_gradx16 = test_a_gradx[:ntest,::4,::4].reshape(ntest,-1) 123 | test_a_grady16 = test_a_grady[:ntest,::4,::4].reshape(ntest,-1) 124 | test_u16 = test_u[:ntest,::4,::4].reshape(ntest,-1) 125 | test_a31 =test_a[:ntest,::2,::2].reshape(ntest,-1) 126 | test_a_smooth31 = test_a_smooth[:ntest,::2,::2].reshape(ntest,-1) 127 | test_a_gradx31 = test_a_gradx[:ntest,::2,::2].reshape(ntest,-1) 128 | test_a_grady31 = test_a_grady[:ntest,::2,::2].reshape(ntest,-1) 129 | test_u31 = test_u[:ntest,::2,::2].reshape(ntest,-1) 130 | test_a =test_a.reshape(ntest,-1) 131 | test_a_smooth = test_a_smooth.reshape(ntest,-1) 132 | test_a_gradx = test_a_gradx.reshape(ntest,-1) 133 | test_a_grady = test_a_grady.reshape(ntest,-1) 134 | test_u = test_u.reshape(ntest,-1) 135 | 136 | 137 | u_normalizer = GaussianNormalizer(train_u) 138 | train_u = u_normalizer.encode(train_u) 139 | # test_u = y_normalizer.encode(test_u) 140 | 141 | 142 | 143 | meshgenerator = SquareMeshGenerator([[0,1],[0,1]],[s,s]) 144 | edge_index = meshgenerator.ball_connectivity(radius_train) 145 | grid = meshgenerator.get_grid() 146 | # meshgenerator.get_boundary() 147 | # edge_index_boundary = meshgenerator.boundary_connectivity2d(stride = stride) 148 | 149 | data_train = [] 150 | for j in range(ntrain): 151 | edge_attr = meshgenerator.attributes(theta=train_a[j,:]) 152 | # edge_attr_boundary = meshgenerator.attributes_boundary(theta=train_u[j,:]) 153 | data_train.append(Data(x=torch.cat([grid, train_a[j,:].reshape(-1, 1), 154 | train_a_smooth[j,:].reshape(-1, 1), train_a_gradx[j,:].reshape(-1, 1), train_a_grady[j,:].reshape(-1, 1) 155 | ], dim=1), 156 | y=train_u[j,:], coeff=train_a[j,:], 157 | edge_index=edge_index, edge_attr=edge_attr, 158 | # edge_index_boundary=edge_index_boundary, edge_attr_boundary= edge_attr_boundary 159 | )) 160 | 161 | print('train grid', grid.shape, 'edge_index', edge_index.shape, 'edge_attr', edge_attr.shape) 162 | 163 | meshgenerator = SquareMeshGenerator([[0,1],[0,1]],[16,16]) 164 | edge_index = meshgenerator.ball_connectivity(radius_test) 165 | grid = meshgenerator.get_grid() 166 | # meshgenerator.get_boundary() 167 | # edge_index_boundary = meshgenerator.boundary_connectivity2d(stride = stride) 168 | data_test16 = [] 169 | for j in range(ntest): 170 | edge_attr = meshgenerator.attributes(theta=test_a16[j,:]) 171 | # edge_attr_boundary = meshgenerator.attributes_boundary(theta=test_a[j, :]) 172 | data_test16.append(Data(x=torch.cat([grid, test_a16[j,:].reshape(-1, 1), 173 | test_a_smooth16[j,:].reshape(-1, 1), test_a_gradx16[j,:].reshape(-1, 1), test_a_grady16[j,:].reshape(-1, 1) 174 | ], dim=1), 175 | 
y=test_u16[j, :], coeff=test_a16[j,:], 176 | edge_index=edge_index, edge_attr=edge_attr, 177 | # edge_index_boundary=edge_index_boundary, edge_attr_boundary=edge_attr_boundary 178 | )) 179 | 180 | print('16 grid', grid.shape, 'edge_index', edge_index.shape, 'edge_attr', edge_attr.shape) 181 | # print('edge_index_boundary', edge_index_boundary.shape, 'edge_attr', edge_attr_boundary.shape) 182 | 183 | meshgenerator = SquareMeshGenerator([[0,1],[0,1]],[31,31]) 184 | edge_index = meshgenerator.ball_connectivity(radius_test) 185 | grid = meshgenerator.get_grid() 186 | # meshgenerator.get_boundary() 187 | # edge_index_boundary = meshgenerator.boundary_connectivity2d(stride = stride) 188 | data_test31 = [] 189 | for j in range(ntest): 190 | edge_attr = meshgenerator.attributes(theta=test_a31[j,:]) 191 | # edge_attr_boundary = meshgenerator.attributes_boundary(theta=test_a[j, :]) 192 | data_test31.append(Data(x=torch.cat([grid, test_a31[j,:].reshape(-1, 1), 193 | test_a_smooth31[j,:].reshape(-1, 1), test_a_gradx31[j,:].reshape(-1, 1), test_a_grady31[j,:].reshape(-1, 1) 194 | ], dim=1), 195 | y=test_u31[j, :], coeff=test_a31[j,:], 196 | edge_index=edge_index, edge_attr=edge_attr, 197 | # edge_index_boundary=edge_index_boundary, edge_attr_boundary=edge_attr_boundary 198 | )) 199 | 200 | print('31 grid', grid.shape, 'edge_index', edge_index.shape, 'edge_attr', edge_attr.shape) 201 | # print('edge_index_boundary', edge_index_boundary.shape, 'edge_attr', edge_attr_boundary.shape) 202 | 203 | meshgenerator = SquareMeshGenerator([[0,1],[0,1]],[61,61]) 204 | edge_index = meshgenerator.ball_connectivity(radius_test) 205 | grid = meshgenerator.get_grid() 206 | # meshgenerator.get_boundary() 207 | # edge_index_boundary = meshgenerator.boundary_connectivity2d(stride = stride) 208 | data_test61 = [] 209 | for j in range(ntest): 210 | edge_attr = meshgenerator.attributes(theta=test_a[j,:]) 211 | # edge_attr_boundary = meshgenerator.attributes_boundary(theta=test_a[j, :]) 212 | data_test61.append(Data(x=torch.cat([grid, test_a[j,:].reshape(-1, 1), 213 | test_a_smooth[j,:].reshape(-1, 1), test_a_gradx[j,:].reshape(-1, 1), test_a_grady[j,:].reshape(-1, 1) 214 | ], dim=1), 215 | y=test_u[j, :], coeff=test_a[j,:], 216 | edge_index=edge_index, edge_attr=edge_attr, 217 | # edge_index_boundary=edge_index_boundary, edge_attr_boundary=edge_attr_boundary 218 | )) 219 | 220 | print('61 grid', grid.shape, 'edge_index', edge_index.shape, 'edge_attr', edge_attr.shape) 221 | # print('edge_index_boundary', edge_index_boundary.shape, 'edge_attr', edge_attr_boundary.shape) 222 | 223 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 224 | test_loader16 = DataLoader(data_test16, batch_size=batch_size2, shuffle=False) 225 | test_loader31 = DataLoader(data_test31, batch_size=batch_size2, shuffle=False) 226 | test_loader61 = DataLoader(data_test61, batch_size=batch_size2, shuffle=False) 227 | 228 | 229 | 230 | ################################################################################################## 231 | 232 | ### training 233 | 234 | ################################################################################################## 235 | t2 = default_timer() 236 | 237 | print('preprocessing finished, time used:', t2-t1) 238 | device = torch.device('cuda') 239 | 240 | model = KernelNN(width,ker_width,depth,edge_features,node_features).cuda() 241 | 242 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 243 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 
step_size=scheduler_step, gamma=scheduler_gamma) 244 | 245 | myloss = LpLoss(size_average=False) 246 | u_normalizer.cuda() 247 | 248 | model.train() 249 | ttrain = np.zeros((epochs, )) 250 | ttest16 = np.zeros((epochs,)) 251 | ttest31 = np.zeros((epochs,)) 252 | ttest61 = np.zeros((epochs,)) 253 | 254 | for ep in range(epochs): 255 | t1 = default_timer() 256 | train_mse = 0.0 257 | train_l2 = 0.0 258 | for batch in train_loader: 259 | batch = batch.to(device) 260 | 261 | optimizer.zero_grad() 262 | out = model(batch) 263 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 264 | # mse.backward() 265 | loss = torch.norm(out.view(-1) - batch.y.view(-1),1) 266 | loss.backward() 267 | 268 | l2 = myloss(u_normalizer.decode(out.view(batch_size,-1)), u_normalizer.decode(batch.y.view(batch_size, -1))) 269 | # l2.backward() 270 | 271 | optimizer.step() 272 | train_mse += mse.item() 273 | train_l2 += l2.item() 274 | 275 | scheduler.step() 276 | t2 = default_timer() 277 | 278 | model.eval() 279 | 280 | 281 | ttrain[ep] = train_l2/(ntrain * k) 282 | 283 | print(ep, ' time:', t2-t1, ' train_mse:', train_mse/len(train_loader)) 284 | 285 | t1 = default_timer() 286 | u_normalizer.cpu() 287 | model = model.cpu() 288 | test_l2_16 = 0.0 289 | test_l2_31 = 0.0 290 | test_l2_61 = 0.0 291 | with torch.no_grad(): 292 | for batch in test_loader16: 293 | out = model(batch) 294 | test_l2_16 += myloss(u_normalizer.decode(out.view(batch_size2,-1)), 295 | batch.y.view(batch_size2, -1)) 296 | for batch in test_loader31: 297 | out = model(batch) 298 | test_l2_31 += myloss(u_normalizer.decode(out.view(batch_size2, -1)), 299 | batch.y.view(batch_size2, -1)) 300 | for batch in test_loader61: 301 | out = model(batch) 302 | test_l2_61 += myloss(u_normalizer.decode(out.view(batch_size2, -1)), 303 | batch.y.view(batch_size2, -1)) 304 | 305 | ttest16[ep] = test_l2_16 / ntest 306 | ttest31[ep] = test_l2_31 / ntest 307 | ttest61[ep] = test_l2_61 / ntest 308 | t2 = default_timer() 309 | 310 | print(' time:', t2-t1, ' train_mse:', train_mse/len(train_loader), 311 | ' test16:', test_l2_16/ntest, ' test31:', test_l2_31/ntest, ' test61:', test_l2_61/ntest) 312 | np.savetxt(path_train_err + '.txt', ttrain) 313 | np.savetxt(path_test_err16 + '.txt', ttest16) 314 | np.savetxt(path_test_err31 + '.txt', ttest31) 315 | np.savetxt(path_test_err61 + '.txt', ttest61) 316 | 317 | torch.save(model, path_model) 318 | 319 | ################################################################################################## 320 | 321 | ### Ploting 322 | 323 | ################################################################################################## 324 | 325 | 326 | 327 | resolution = s 328 | data = train_loader.dataset[0] 329 | coeff = data.coeff.numpy().reshape((resolution, resolution)) 330 | truth = u_normalizer.decode(data.y.reshape(1,-1)).numpy().reshape((resolution, resolution)) 331 | approx = u_normalizer.decode(model(data).reshape(1,-1)).detach().numpy().reshape((resolution, resolution)) 332 | _min = np.min(np.min(truth)) 333 | _max = np.max(np.max(truth)) 334 | 335 | plt.figure() 336 | plt.subplot(1, 3, 1) 337 | plt.imshow(truth, vmin = _min, vmax=_max) 338 | plt.xticks([], []) 339 | plt.yticks([], []) 340 | plt.colorbar(fraction=0.046, pad=0.04) 341 | plt.title('Ground Truth') 342 | 343 | plt.subplot(1, 3, 2) 344 | plt.imshow(approx, vmin = _min, vmax=_max) 345 | plt.xticks([], []) 346 | plt.yticks([], []) 347 | plt.colorbar(fraction=0.046, pad=0.04) 348 | plt.title('Approximation') 349 | 350 | plt.subplot(1, 3, 3) 351 | 
plt.imshow((approx - truth) ** 2) 352 | plt.xticks([], []) 353 | plt.yticks([], []) 354 | plt.colorbar(fraction=0.046, pad=0.04) 355 | plt.title('Error') 356 | 357 | plt.subplots_adjust(wspace=0.5, hspace=0.5) 358 | plt.savefig(path_image_train + '.png') 359 | 360 | 361 | resolution = 16 362 | data = test_loader16.dataset[0] 363 | coeff = data.coeff.numpy().reshape((resolution, resolution)) 364 | truth = data.y.numpy().reshape((resolution, resolution)) 365 | approx = u_normalizer.decode(model(data).reshape(1,-1)).detach().numpy().reshape((resolution, resolution)) 366 | _min = np.min(np.min(truth)) 367 | _max = np.max(np.max(truth)) 368 | 369 | plt.figure() 370 | plt.subplot(1, 3, 1) 371 | plt.imshow(truth, vmin = _min, vmax=_max) 372 | plt.xticks([], []) 373 | plt.yticks([], []) 374 | plt.colorbar(fraction=0.046, pad=0.04) 375 | plt.title('Ground Truth') 376 | 377 | plt.subplot(1, 3, 2) 378 | plt.imshow(approx, vmin = _min, vmax=_max) 379 | plt.xticks([], []) 380 | plt.yticks([], []) 381 | plt.colorbar(fraction=0.046, pad=0.04) 382 | plt.title('Approximation') 383 | 384 | plt.subplot(1, 3, 3) 385 | plt.imshow((approx - truth) ** 2) 386 | plt.xticks([], []) 387 | plt.yticks([], []) 388 | plt.colorbar(fraction=0.046, pad=0.04) 389 | plt.title('Error') 390 | 391 | plt.subplots_adjust(wspace=0.5, hspace=0.5) 392 | plt.savefig(path_image_test16 + '.png') 393 | 394 | resolution = 31 395 | data = test_loader31.dataset[0] 396 | coeff = data.coeff.numpy().reshape((resolution, resolution)) 397 | truth = data.y.numpy().reshape((resolution, resolution)) 398 | approx = u_normalizer.decode(model(data).reshape(1,-1)).detach().numpy().reshape((resolution, resolution)) 399 | _min = np.min(np.min(truth)) 400 | _max = np.max(np.max(truth)) 401 | 402 | # plt.figure() 403 | plt.figure() 404 | plt.subplot(1, 3, 1) 405 | plt.imshow(truth, vmin = _min, vmax=_max) 406 | plt.xticks([], []) 407 | plt.yticks([], []) 408 | plt.colorbar(fraction=0.046, pad=0.04) 409 | plt.title('Ground Truth') 410 | 411 | plt.subplot(1, 3, 2) 412 | plt.imshow(approx, vmin = _min, vmax=_max) 413 | plt.xticks([], []) 414 | plt.yticks([], []) 415 | plt.colorbar(fraction=0.046, pad=0.04) 416 | plt.title('Approximation') 417 | 418 | plt.subplot(1, 3, 3) 419 | plt.imshow((approx - truth) ** 2) 420 | plt.xticks([], []) 421 | plt.yticks([], []) 422 | plt.colorbar(fraction=0.046, pad=0.04) 423 | plt.title('Error') 424 | 425 | plt.subplots_adjust(wspace=0.5, hspace=0.5) 426 | plt.savefig(path_image_test31 + '.png') 427 | 428 | 429 | resolution = 61 430 | data = test_loader61.dataset[0] 431 | coeff = data.coeff.numpy().reshape((resolution, resolution)) 432 | truth = data.y.numpy().reshape((resolution, resolution)) 433 | approx = u_normalizer.decode(model(data).reshape(1,-1)).detach().numpy().reshape((resolution, resolution)) 434 | _min = np.min(np.min(truth)) 435 | _max = np.max(np.max(truth)) 436 | 437 | # plt.figure() 438 | plt.figure() 439 | plt.subplot(1, 3, 1) 440 | plt.imshow(truth, vmin = _min, vmax=_max) 441 | plt.xticks([], []) 442 | plt.yticks([], []) 443 | plt.colorbar(fraction=0.046, pad=0.04) 444 | plt.title('Ground Truth') 445 | 446 | plt.subplot(1, 3, 2) 447 | plt.imshow(approx, vmin = _min, vmax=_max) 448 | plt.xticks([], []) 449 | plt.yticks([], []) 450 | plt.colorbar(fraction=0.046, pad=0.04) 451 | plt.title('Approximation') 452 | 453 | plt.subplot(1, 3, 3) 454 | plt.imshow((approx - truth) ** 2) 455 | plt.xticks([], []) 456 | plt.yticks([], []) 457 | plt.colorbar(fraction=0.046, pad=0.04) 458 | plt.title('Error') 459 | 460 | 
plt.subplots_adjust(wspace=0.5, hspace=0.5) 461 | plt.savefig(path_image_test61 + '.png') 462 | -------------------------------------------------------------------------------- /graph-neural-operator/UAI2_full_equation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | 6 | from torch_geometric.data import Data, DataLoader 7 | import matplotlib.pyplot as plt 8 | from utilities import * 9 | from nn_conv import NNConv_old 10 | 11 | from timeit import default_timer 12 | 13 | 14 | class KernelNN(torch.nn.Module): 15 | def __init__(self, width, ker_width, depth, ker_in, in_width=1, out_width=1): 16 | super(KernelNN, self).__init__() 17 | self.depth = depth 18 | 19 | self.fc1 = torch.nn.Linear(in_width, width) 20 | 21 | kernel = DenseNet([ker_in, ker_width, ker_width, width**2], torch.nn.ReLU) 22 | self.conv1 = NNConv_old(width, width, kernel, aggr='mean') 23 | 24 | self.fc2 = torch.nn.Linear(width, 1) 25 | 26 | def forward(self, data): 27 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 28 | x = self.fc1(x) 29 | for k in range(self.depth): 30 | x = F.relu(self.conv1(x, edge_index, edge_attr)) 31 | 32 | x = self.fc2(x) 33 | return x 34 | 35 | 36 | TRAIN_PATH = 'data/piececonst_r241_N1024_smooth1.mat' 37 | TEST_PATH = 'data/piececonst_r241_N1024_smooth2.mat' 38 | 39 | r = 8 40 | s = int(((241 - 1)/r) + 1) 41 | n = s**2 42 | 43 | radius_train = 0.10 44 | radius_test = 0.10 45 | 46 | print('resolution', s) 47 | 48 | 49 | ntrain = 10 50 | ntest = 100 51 | 52 | batch_size = 2 53 | batch_size2 = 2 54 | width = 64 55 | ker_width = 1024 56 | depth = 6 57 | edge_features = 6 58 | node_features = 6 59 | 60 | epochs = 5000 61 | learning_rate = 0.0001 62 | scheduler_step = 50 63 | scheduler_gamma = 0.5 64 | 65 | path = 'UAI2_new2_r'+str(s)+'_n'+ str(ntrain) 66 | path_model = 'model/'+path+'' 67 | path_train_err = 'results/'+path+'train.txt' 68 | path_test_err = 'results/'+path+'test.txt' 69 | path_image = 'image/'+path+'' 70 | 71 | 72 | t1 = default_timer() 73 | 74 | 75 | reader = MatReader(TRAIN_PATH) 76 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 77 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 78 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 79 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 80 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 81 | 82 | reader.load_file(TEST_PATH) 83 | test_a = reader.read_field('coeff')[:ntest,::r,::r].reshape(ntest,-1) 84 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,::r,::r].reshape(ntest,-1) 85 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::r,::r].reshape(ntest,-1) 86 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::r,::r].reshape(ntest,-1) 87 | test_u = reader.read_field('sol')[:ntest,::r,::r].reshape(ntest,-1) 88 | 89 | 90 | a_normalizer = GaussianNormalizer(train_a) 91 | train_a = a_normalizer.encode(train_a) 92 | test_a = a_normalizer.encode(test_a) 93 | as_normalizer = GaussianNormalizer(train_a_smooth) 94 | train_a_smooth = as_normalizer.encode(train_a_smooth) 95 | test_a_smooth = as_normalizer.encode(test_a_smooth) 96 | agx_normalizer = GaussianNormalizer(train_a_gradx) 97 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 98 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 99 | agy_normalizer = GaussianNormalizer(train_a_grady) 100 
| train_a_grady = agy_normalizer.encode(train_a_grady) 101 | test_a_grady = agy_normalizer.encode(test_a_grady) 102 | 103 | u_normalizer = UnitGaussianNormalizer(train_u) 104 | train_u = u_normalizer.encode(train_u) 105 | # test_u = y_normalizer.encode(test_u) 106 | 107 | 108 | meshgenerator = SquareMeshGenerator([[0,1],[0,1]],[s,s]) 109 | edge_index = meshgenerator.ball_connectivity(radius_train) 110 | grid = meshgenerator.get_grid() 111 | # meshgenerator.get_boundary() 112 | # edge_index_boundary = meshgenerator.boundary_connectivity2d(stride = stride) 113 | 114 | data_train = [] 115 | for j in range(ntrain): 116 | edge_attr = meshgenerator.attributes(theta=train_a[j,:]) 117 | # edge_attr_boundary = meshgenerator.attributes_boundary(theta=train_u[j,:]) 118 | data_train.append(Data(x=torch.cat([grid, train_a[j,:].reshape(-1, 1), 119 | train_a_smooth[j,:].reshape(-1, 1), train_a_gradx[j,:].reshape(-1, 1), train_a_grady[j,:].reshape(-1, 1) 120 | ], dim=1), 121 | y=train_u[j,:], coeff=train_a[j,:], 122 | edge_index=edge_index, edge_attr=edge_attr, 123 | # edge_index_boundary=edge_index_boundary, edge_attr_boundary= edge_attr_boundary 124 | )) 125 | 126 | meshgenerator = SquareMeshGenerator([[0,1],[0,1]],[s,s]) 127 | edge_index = meshgenerator.ball_connectivity(radius_test) 128 | grid = meshgenerator.get_grid() 129 | # meshgenerator.get_boundary() 130 | # edge_index_boundary = meshgenerator.boundary_connectivity2d(stride = stride) 131 | data_test = [] 132 | for j in range(ntest): 133 | edge_attr = meshgenerator.attributes(theta=test_a[j,:]) 134 | # edge_attr_boundary = meshgenerator.attributes_boundary(theta=test_a[j, :]) 135 | data_test.append(Data(x=torch.cat([grid, test_a[j,:].reshape(-1, 1), 136 | test_a_smooth[j,:].reshape(-1, 1), test_a_gradx[j,:].reshape(-1, 1), test_a_grady[j,:].reshape(-1, 1) 137 | ], dim=1), 138 | y=test_u[j, :], coeff=test_a[j,:], 139 | edge_index=edge_index, edge_attr=edge_attr, 140 | # edge_index_boundary=edge_index_boundary, edge_attr_boundary=edge_attr_boundary 141 | )) 142 | 143 | print('grid', grid.shape, 'edge_index', edge_index.shape, 'edge_attr', edge_attr.shape) 144 | # print('edge_index_boundary', edge_index_boundary.shape, 'edge_attr', edge_attr_boundary.shape) 145 | 146 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 147 | test_loader = DataLoader(data_test, batch_size=batch_size2, shuffle=False) 148 | 149 | 150 | 151 | ################################################################################################## 152 | 153 | ### training 154 | 155 | ################################################################################################## 156 | t2 = default_timer() 157 | 158 | print('preprocessing finished, time used:', t2-t1) 159 | device = torch.device('cuda') 160 | 161 | model = KernelNN(width,ker_width,depth,edge_features,node_features).cuda() 162 | 163 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 164 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 165 | 166 | myloss = LpLoss(size_average=False) 167 | u_normalizer.cuda() 168 | 169 | model.train() 170 | ttrain = np.zeros((epochs, )) 171 | ttest = np.zeros((epochs,)) 172 | for ep in range(epochs): 173 | t1 = default_timer() 174 | train_mse = 0.0 175 | train_l2 = 0.0 176 | for batch in train_loader: 177 | batch = batch.to(device) 178 | 179 | optimizer.zero_grad() 180 | out = model(batch) 181 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 182 | # 
mse.backward() 183 | loss = torch.norm(out.view(-1) - batch.y.view(-1),1) 184 | loss.backward() 185 | 186 | l2 = myloss(u_normalizer.decode(out.view(batch_size,-1)), u_normalizer.decode(batch.y.view(batch_size, -1))) 187 | # l2.backward() 188 | 189 | optimizer.step() 190 | train_mse += mse.item() 191 | train_l2 += l2.item() 192 | 193 | scheduler.step() 194 | t2 = default_timer() 195 | 196 | model.eval() 197 | test_l2 = 0.0 198 | if ep%100==99: 199 | with torch.no_grad(): 200 | for batch in test_loader: 201 | batch = batch.to(device) 202 | out = model(batch) 203 | test_l2 += myloss(u_normalizer.decode(out.view(batch_size2,-1)), batch.y.view(batch_size2, -1)).item() 204 | # test_l2 += myloss(out.view(batch_size2,-1), y_normalizer.encode(batch.y.view(batch_size2, -1))).item() 205 | 206 | ttrain[ep] = train_l2/(ntrain) 207 | ttest[ep] = test_l2/ntest 208 | 209 | print(ep, t2-t1, train_mse/len(train_loader), train_l2/(ntrain), test_l2/ntest) 210 | 211 | np.savetxt(path_train_err, ttrain) 212 | np.savetxt(path_test_err, ttest) 213 | 214 | torch.save(model, path_model) 215 | ################################################################################################## 216 | 217 | ### Ploting 218 | 219 | ################################################################################################## 220 | 221 | 222 | plt.figure() 223 | # plt.plot(ttrain, label='train loss') 224 | plt.plot(ttest, label='test loss') 225 | plt.legend(loc='upper right') 226 | plt.show() 227 | 228 | -------------------------------------------------------------------------------- /graph-neural-operator/UAI3_resolution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | 6 | from torch_geometric.data import Data, DataLoader 7 | from utilities import * 8 | from nn_conv import NNConv_old 9 | 10 | from timeit import default_timer 11 | 12 | 13 | class KernelNN3(torch.nn.Module): 14 | def __init__(self, width_node, width_kernel, depth, ker_in, in_width=1, out_width=1): 15 | super(KernelNN3, self).__init__() 16 | self.depth = depth 17 | 18 | self.fc1 = torch.nn.Linear(in_width, width_node) 19 | 20 | kernel = DenseNet([ker_in, width_kernel // 2, width_kernel, width_node**2], torch.nn.ReLU) 21 | self.conv1 = NNConv_old(width_node, width_node, kernel, aggr='mean') 22 | 23 | self.fc2 = torch.nn.Linear(width_node, 1) 24 | 25 | def forward(self, data): 26 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 27 | x = self.fc1(x) 28 | for k in range(self.depth): 29 | x = F.relu(self.conv1(x, edge_index, edge_attr)) 30 | 31 | x = self.fc2(x) 32 | return x 33 | 34 | 35 | TRAIN_PATH = 'data/piececonst_r241_N1024_smooth1.mat' 36 | TEST_PATH = 'data/piececonst_r241_N1024_smooth2.mat' 37 | 38 | for r in (1,2,4,8,16): 39 | 40 | # r = 2 41 | s = int(((241 - 1)/r) + 1) 42 | n = s**2 43 | m = 200 44 | k = 2 45 | 46 | radius_train = 0.25 47 | radius_test = 0.25 48 | print('resolution', s) 49 | 50 | 51 | ntrain = 100 52 | ntest = 100 53 | 54 | 55 | batch_size = 10 56 | batch_size2 = 10 57 | width = 64 58 | ker_width = 1000 59 | depth = 6 60 | edge_features = 6 61 | node_features = 6 62 | 63 | epochs = 200 64 | learning_rate = 0.0001 65 | scheduler_step = 50 66 | scheduler_gamma = 0.5 67 | 68 | 69 | path = 'UAI3_s'+str(s) 70 | path_model = 'model/' + path 71 | path_train_err = 'results/'+path+'train.txt' 72 | path_test_err1 = 'results/'+path+'test61.txt' 73 | path_test_err2 = 'results/'+path+'test121.txt' 74 | 
path_test_err3 = 'results/'+path+'test241.txt' 75 | 76 | t1 = default_timer() 77 | 78 | 79 | reader = MatReader(TRAIN_PATH) 80 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 81 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 82 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 83 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 84 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 85 | 86 | reader.load_file(TEST_PATH) 87 | test_a = reader.read_field('coeff')[:ntest,:,:] 88 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,:,:] 89 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,:,:] 90 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,:,:] 91 | test_u = reader.read_field('sol')[:ntest,:,:] 92 | 93 | a_normalizer = GaussianNormalizer(train_a) 94 | train_a = a_normalizer.encode(train_a) 95 | test_a = a_normalizer.encode(test_a) 96 | as_normalizer = GaussianNormalizer(train_a_smooth) 97 | train_a_smooth = as_normalizer.encode(train_a_smooth) 98 | test_a_smooth = as_normalizer.encode(test_a_smooth) 99 | agx_normalizer = GaussianNormalizer(train_a_gradx) 100 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 101 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 102 | agy_normalizer = GaussianNormalizer(train_a_grady) 103 | train_a_grady = agy_normalizer.encode(train_a_grady) 104 | test_a_grady = agy_normalizer.encode(test_a_grady) 105 | 106 | 107 | test_a61 = test_a[:ntest, ::4, ::4].reshape(ntest, -1) 108 | test_a_smooth61 = test_a_smooth[:ntest, ::4, ::4].reshape(ntest, -1) 109 | test_a_gradx61 = test_a_gradx[:ntest, ::4, ::4].reshape(ntest, -1) 110 | test_a_grady61 = test_a_grady[:ntest, ::4, ::4].reshape(ntest, -1) 111 | test_u61 = test_u[:ntest, ::4, ::4].reshape(ntest, -1) 112 | 113 | test_a121 = test_a[:ntest, ::2, ::2].reshape(ntest, -1) 114 | test_a_smooth121 = test_a_smooth[:ntest, ::2, ::2].reshape(ntest, -1) 115 | test_a_gradx121 = test_a_gradx[:ntest, ::2, ::2].reshape(ntest, -1) 116 | test_a_grady121 = test_a_grady[:ntest, ::2, ::2].reshape(ntest, -1) 117 | test_u121 = test_u[:ntest, ::2, ::2].reshape(ntest, -1) 118 | 119 | test_a241 = test_a.reshape(ntest, -1) 120 | test_a_smooth241 = test_a_smooth.reshape(ntest, -1) 121 | test_a_gradx241 = test_a_gradx.reshape(ntest, -1) 122 | test_a_grady241 = test_a_grady.reshape(ntest, -1) 123 | test_u241 = test_u.reshape(ntest, -1) 124 | 125 | u_normalizer = GaussianNormalizer(train_u) 126 | train_u = u_normalizer.encode(train_u) 127 | # test_u = y_normalizer.encode(test_u) 128 | 129 | 130 | 131 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=m) 132 | data_train = [] 133 | for j in range(ntrain): 134 | for i in range(k): 135 | idx = meshgenerator.sample() 136 | grid = meshgenerator.get_grid() 137 | edge_index = meshgenerator.ball_connectivity(radius_train) 138 | edge_attr = meshgenerator.attributes(theta=train_a[j,:]) 139 | #data_train.append(Data(x=init_point.clone().view(-1,1), y=train_y[j,:], edge_index=edge_index, edge_attr=edge_attr)) 140 | data_train.append(Data(x=torch.cat([grid, train_a[j, idx].reshape(-1, 1), 141 | train_a_smooth[j, idx].reshape(-1, 1), train_a_gradx[j, idx].reshape(-1, 1), 142 | train_a_grady[j, idx].reshape(-1, 1) 143 | ], dim=1), 144 | y=train_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 145 | )) 146 | 147 | 148 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[61,61], sample_size=m) 149 | 
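# Note: the training graphs above are subsampled from the s x s grid; the test sets
# built here draw graphs from 61x61, 121x121 and 241x241 discretizations, so the
# same trained weights are evaluated across resolutions without retraining.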
data_test1 = [] 150 | for j in range(ntest): 151 | idx = meshgenerator.sample() 152 | grid = meshgenerator.get_grid() 153 | edge_index = meshgenerator.ball_connectivity(radius_test) 154 | edge_attr = meshgenerator.attributes(theta=test_a61[j,:]) 155 | data_test1.append(Data(x=torch.cat([grid, test_a61[j, idx].reshape(-1, 1), 156 | test_a_smooth61[j, idx].reshape(-1, 1), test_a_gradx61[j, idx].reshape(-1, 1), 157 | test_a_grady61[j, idx].reshape(-1, 1) 158 | ], dim=1), 159 | y=test_u61[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 160 | )) 161 | # 162 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[121,121], sample_size=m) 163 | data_test2 = [] 164 | for j in range(ntest): 165 | idx = meshgenerator.sample() 166 | grid = meshgenerator.get_grid() 167 | edge_index = meshgenerator.ball_connectivity(radius_test) 168 | edge_attr = meshgenerator.attributes(theta=test_a121[j,:]) 169 | data_test2.append(Data(x=torch.cat([grid, test_a121[j, idx].reshape(-1, 1), 170 | test_a_smooth121[j, idx].reshape(-1, 1), test_a_gradx121[j, idx].reshape(-1, 1), 171 | test_a_grady121[j, idx].reshape(-1, 1) 172 | ], dim=1), 173 | y=test_u121[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 174 | )) 175 | # 176 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[241,241], sample_size=m) 177 | data_test3 = [] 178 | for j in range(ntest): 179 | idx = meshgenerator.sample() 180 | grid = meshgenerator.get_grid() 181 | edge_index = meshgenerator.ball_connectivity(radius_test) 182 | edge_attr = meshgenerator.attributes(theta=test_a241[j,:]) 183 | data_test3.append(Data(x=torch.cat([grid, test_a241[j, idx].reshape(-1, 1), 184 | test_a_smooth241[j, idx].reshape(-1, 1), test_a_gradx241[j, idx].reshape(-1, 1), 185 | test_a_grady241[j, idx].reshape(-1, 1) 186 | ], dim=1), 187 | y=test_u241[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 188 | )) 189 | # 190 | # 191 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 192 | test_loader1 = DataLoader(data_test1, batch_size=batch_size2, shuffle=False) 193 | test_loader2 = DataLoader(data_test2, batch_size=batch_size2, shuffle=False) 194 | test_loader3 = DataLoader(data_test3, batch_size=batch_size2, shuffle=False) 195 | 196 | 197 | t2 = default_timer() 198 | 199 | print('preprocessing finished, time used:', t2-t1) 200 | device = torch.device('cuda') 201 | 202 | model = KernelNN3(width, ker_width,depth,edge_features,in_width=node_features).cuda() 203 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 204 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 205 | 206 | myloss = LpLoss(size_average=False) 207 | u_normalizer.cuda() 208 | ttrain = np.zeros((epochs, )) 209 | ttest1 = np.zeros((epochs,)) 210 | ttest2 = np.zeros((epochs,)) 211 | ttest3 = np.zeros((epochs,)) 212 | model.train() 213 | for ep in range(epochs): 214 | t1 = default_timer() 215 | train_mse = 0.0 216 | train_l2 = 0.0 217 | for batch in train_loader: 218 | batch = batch.to(device) 219 | 220 | optimizer.zero_grad() 221 | out = model(batch) 222 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 223 | mse.backward() 224 | 225 | l2 = myloss(u_normalizer.decode(out.view(batch_size,-1), sample_idx=batch.sample_idx.view(batch_size,-1)), 226 | u_normalizer.decode(batch.y.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size,-1))) 227 | 228 | optimizer.step() 229 | train_mse += mse.item() 230 | train_l2 += l2.item() 231 | 232 | 
scheduler.step() 233 | t2 = default_timer() 234 | 235 | model.eval() 236 | test1_l2 = 0.0 237 | test2_l2 = 0.0 238 | test3_l2 = 0.0 239 | 240 | with torch.no_grad(): 241 | for batch in test_loader1: 242 | batch = batch.to(device) 243 | out = model(batch) 244 | out = u_normalizer.decode(out.view(batch_size2,-1), sample_idx=batch.sample_idx.view(batch_size2,-1)) 245 | test1_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 246 | for batch in test_loader2: 247 | batch = batch.to(device) 248 | out = model(batch) 249 | out = u_normalizer.decode(out.view(batch_size2,-1), sample_idx=batch.sample_idx.view(batch_size2,-1)) 250 | test2_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 251 | for batch in test_loader3: 252 | batch = batch.to(device) 253 | out = model(batch) 254 | out = u_normalizer.decode(out.view(batch_size2,-1), sample_idx=batch.sample_idx.view(batch_size2,-1)) 255 | test3_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 256 | 257 | 258 | ttrain[ep] = train_l2/(ntrain * k) 259 | ttest1[ep] = test1_l2 / ntest 260 | ttest2[ep] = test2_l2 / ntest 261 | ttest3[ep] = test3_l2 / ntest 262 | 263 | 264 | print(s, ep, t2-t1, train_mse/len(train_loader), train_l2/(ntrain * k)) 265 | print(test1_l2/ntest, test2_l2/ntest, test3_l2/ntest) 266 | 267 | np.savetxt(path_train_err, ttrain) 268 | np.savetxt(path_test_err1, ttest1) 269 | np.savetxt(path_test_err2, ttest2) 270 | np.savetxt(path_test_err3, ttest3) 271 | torch.save(model, path_model) 272 | 273 | -------------------------------------------------------------------------------- /graph-neural-operator/UAI4_equation_sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | 6 | from torch_geometric.data import Data, DataLoader 7 | import matplotlib.pyplot as plt 8 | from utilities import * 9 | from nn_conv import NNConv_old 10 | 11 | from timeit import default_timer 12 | 13 | 14 | class KernelNN3(torch.nn.Module): 15 | def __init__(self, width_node, width_kernel, depth, ker_in, in_width=1, out_width=1): 16 | super(KernelNN3, self).__init__() 17 | self.depth = depth 18 | 19 | self.fc1 = torch.nn.Linear(in_width, width_node) 20 | 21 | kernel = DenseNet([ker_in, width_kernel // 2, width_kernel, width_node**2], torch.nn.ReLU) 22 | self.conv1 = NNConv_old(width_node, width_node, kernel, aggr='mean') 23 | 24 | self.fc2 = torch.nn.Linear(width_node, 1) 25 | 26 | def forward(self, data): 27 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 28 | x = self.fc1(x) 29 | for k in range(self.depth): 30 | x = self.conv1(x, edge_index, edge_attr) 31 | if k != self.depth - 1: 32 | x = F.relu(x) 33 | 34 | x = self.fc2(x) 35 | return x 36 | 37 | 38 | TRAIN_PATH = 'data/piececonst_r241_N1024_smooth1.mat' 39 | TEST_PATH = 'data/piececonst_r241_N1024_smooth2.mat' 40 | 41 | for k in (2,): 42 | for ntrain in (100,): 43 | r = 1 44 | s = int(((241 - 1)/r) + 1) 45 | n = s**2 46 | m = 200 47 | # k = 5 48 | 49 | radius_train = 0.25 50 | radius_test = 0.25 51 | print('resolution', s) 52 | 53 | 54 | # ntrain = 100 55 | ntest = 100 56 | 57 | if ntrain <= 50: 58 | batch_size = 5 59 | batch_size2 = 5 60 | else: 61 | batch_size = 20 62 | batch_size2 = 20 63 | width = 64 64 | ker_width = 64 65 | depth = 6 66 | edge_features = 6 67 | node_features = 6 68 | 69 | epochs = 200 70 | learning_rate = 0.0001 71 | scheduler_step = 50 72 | scheduler_gamma = 0.5 73 | 74 | 75 | path = 'UAI4_s'+str(s)+'_n'+ str(ntrain)+'_k'+ str(k) 76 | 
path_model = 'model/' + path 77 | path_train_err = 'results/' + path + 'train.txt' 78 | path_test_err = 'results/' + path + 'test.txt' 79 | path_image = 'results/' + path 80 | 81 | 82 | t1 = default_timer() 83 | 84 | 85 | reader = MatReader(TRAIN_PATH) 86 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 87 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 88 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 89 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 90 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 91 | 92 | reader.load_file(TEST_PATH) 93 | test_a = reader.read_field('coeff')[:ntest,::r,::r].reshape(ntest,-1) 94 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,::r,::r].reshape(ntest,-1) 95 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::r,::r].reshape(ntest,-1) 96 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::r,::r].reshape(ntest,-1) 97 | test_u = reader.read_field('sol')[:ntest,::r,::r].reshape(ntest,-1) 98 | 99 | 100 | a_normalizer = GaussianNormalizer(train_a) 101 | train_a = a_normalizer.encode(train_a) 102 | test_a = a_normalizer.encode(test_a) 103 | as_normalizer = GaussianNormalizer(train_a_smooth) 104 | train_a_smooth = as_normalizer.encode(train_a_smooth) 105 | test_a_smooth = as_normalizer.encode(test_a_smooth) 106 | agx_normalizer = GaussianNormalizer(train_a_gradx) 107 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 108 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 109 | agy_normalizer = GaussianNormalizer(train_a_grady) 110 | train_a_grady = agy_normalizer.encode(train_a_grady) 111 | test_a_grady = agy_normalizer.encode(test_a_grady) 112 | 113 | u_normalizer = UnitGaussianNormalizer(train_u) 114 | train_u = u_normalizer.encode(train_u) 115 | # test_u = y_normalizer.encode(test_u) 116 | 117 | 118 | 119 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=m) 120 | data_train = [] 121 | for j in range(ntrain): 122 | for i in range(k): 123 | idx = meshgenerator.sample() 124 | grid = meshgenerator.get_grid() 125 | edge_index = meshgenerator.ball_connectivity(radius_train) 126 | edge_attr = meshgenerator.attributes(theta=train_a[j,:]) 127 | #data_train.append(Data(x=init_point.clone().view(-1,1), y=train_y[j,:], edge_index=edge_index, edge_attr=edge_attr)) 128 | data_train.append(Data(x=torch.cat([grid, train_a[j, idx].reshape(-1, 1), 129 | train_a_smooth[j, idx].reshape(-1, 1), train_a_gradx[j, idx].reshape(-1, 1), 130 | train_a_grady[j, idx].reshape(-1, 1) 131 | ], dim=1), 132 | y=train_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 133 | )) 134 | 135 | 136 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=m) 137 | data_test = [] 138 | for j in range(ntest): 139 | idx = meshgenerator.sample() 140 | grid = meshgenerator.get_grid() 141 | edge_index = meshgenerator.ball_connectivity(radius_test) 142 | edge_attr = meshgenerator.attributes(theta=test_a[j,:]) 143 | data_test.append(Data(x=torch.cat([grid, test_a[j, idx].reshape(-1, 1), 144 | test_a_smooth[j, idx].reshape(-1, 1), test_a_gradx[j, idx].reshape(-1, 1), 145 | test_a_grady[j, idx].reshape(-1, 1) 146 | ], dim=1), 147 | y=test_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 148 | )) 149 | # 150 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 151 | test_loader = DataLoader(data_test, batch_size=batch_size2, shuffle=False) 
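# Note: each of the ntrain coefficient fields is subsampled k times into independent
# random graphs of m nodes (via RandomMeshGenerator.sample()); edges connect nodes
# within radius_train, and edge_attr carries the two endpoints' coordinates and
# coefficient values, so the kernel network is trained on sparse point clouds rather
# than the full 241x241 grid.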
152 | 153 | t2 = default_timer() 154 | 155 | print('preprocessing finished, time used:', t2-t1) 156 | device = torch.device('cuda') 157 | 158 | model = KernelNN3(width, ker_width,depth,edge_features,in_width=node_features).cuda() 159 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 160 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 161 | 162 | myloss = LpLoss(size_average=False) 163 | u_normalizer.cuda() 164 | ttrain = np.zeros((epochs, )) 165 | ttest = np.zeros((epochs,)) 166 | model.train() 167 | for ep in range(epochs): 168 | t1 = default_timer() 169 | train_mse = 0.0 170 | train_l2 = 0.0 171 | for batch in train_loader: 172 | batch = batch.to(device) 173 | 174 | optimizer.zero_grad() 175 | out = model(batch) 176 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 177 | mse.backward() 178 | 179 | l2 = myloss( 180 | u_normalizer.decode(out.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)), 181 | u_normalizer.decode(batch.y.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1))) 182 | optimizer.step() 183 | train_mse += mse.item() 184 | train_l2 += l2.item() 185 | 186 | scheduler.step() 187 | t2 = default_timer() 188 | 189 | model.eval() 190 | test_l2 = 0.0 191 | with torch.no_grad(): 192 | for batch in test_loader: 193 | batch = batch.to(device) 194 | out = model(batch) 195 | out = u_normalizer.decode(out.view(batch_size2,-1), sample_idx=batch.sample_idx.view(batch_size2,-1)) 196 | test_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 197 | # test_l2 += myloss(out.view(batch_size2,-1), y_normalizer.encode(batch.y.view(batch_size2, -1))).item() 198 | 199 | ttrain[ep] = train_l2/(ntrain * k) 200 | ttest[ep] = test_l2/ntest 201 | 202 | print(k, ntrain, ep, t2-t1, train_mse/len(train_loader), train_l2/(ntrain * k), test_l2/ntest) 203 | 204 | np.savetxt(path_train_err, ttrain) 205 | np.savetxt(path_test_err, ttest) 206 | torch.save(model, path_model) 207 | 208 | plt.figure() 209 | # plt.plot(ttrain, label='train loss') 210 | plt.plot(ttest, label='test loss') 211 | plt.legend(loc='upper right') 212 | plt.show() 213 | -------------------------------------------------------------------------------- /graph-neural-operator/UAI5_sample_generalize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | 6 | from torch_geometric.data import Data, DataLoader 7 | from utilities import * 8 | from nn_conv import NNConv_old 9 | 10 | from timeit import default_timer 11 | 12 | 13 | class KernelNN3(torch.nn.Module): 14 | def __init__(self, width_node, width_kernel, depth, ker_in, in_width=1, out_width=1): 15 | super(KernelNN3, self).__init__() 16 | self.depth = depth 17 | 18 | self.fc1 = torch.nn.Linear(in_width, width_node) 19 | 20 | kernel = DenseNet([ker_in, width_kernel // 2, width_kernel, width_node**2], torch.nn.ReLU) 21 | self.conv1 = NNConv_old(width_node, width_node, kernel, aggr='mean') 22 | 23 | self.fc2 = torch.nn.Linear(width_node, 1) 24 | 25 | def forward(self, data): 26 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 27 | x = self.fc1(x) 28 | for k in range(self.depth): 29 | x = F.relu(self.conv1(x, edge_index, edge_attr)) 30 | 31 | x = self.fc2(x) 32 | return x 33 | 34 | 35 | TRAIN_PATH = 'data/piececonst_r241_N1024_smooth1.mat' 36 | TEST_PATH = 'data/piececonst_r241_N1024_smooth2.mat' 37 | 38 | for m in 
(100,200,400,800): 39 | mtest1 = 100 40 | mtest2 = 200 41 | mtest3 = 400 42 | mtest4 = 800 43 | 44 | r = 2 45 | s = int(((241 - 1)/r) + 1) 46 | n = s**2 47 | # m = 200 48 | k = 5 49 | 50 | radius_train = 0.15 51 | radius_test = 0.15 52 | print('resolution', s) 53 | 54 | 55 | ntrain = 100 56 | ntest = 100 57 | 58 | 59 | batch_size = 10 60 | batch_size2 = 10 61 | width = 64 62 | ker_width = 1000 63 | depth = 6 64 | edge_features = 6 65 | node_features = 6 66 | 67 | epochs = 200 68 | learning_rate = 0.0001 69 | scheduler_step = 50 70 | scheduler_gamma = 0.5 71 | 72 | if m==800: 73 | batch_size = 2 74 | epochs = 100 75 | 76 | path_model = 'model/UAI5_s'+str(s)+'_m'+ str(m) 77 | path_train_err = 'results/UAI5_s'+str(s)+'_m'+ str(m) + 'train.txt' 78 | path_test_err1 = 'results/UAI5_s'+str(s)+'_m'+ str(m)+'_mtest'+ str(mtest1) + 'test.txt' 79 | path_test_err2 = 'results/UAI5_s' + str(s) + '_m' + str(m) + '_mtest' + str(mtest2) + 'test.txt' 80 | path_test_err3 = 'results/UAI5_s' + str(s) + '_m' + str(m) + '_mtest' + str(mtest3) + 'test.txt' 81 | path_test_err4 = 'results/UAI5_s' + str(s) + '_m' + str(m) + '_mtest' + str(mtest4) + 'test.txt' 82 | 83 | 84 | t1 = default_timer() 85 | 86 | 87 | reader = MatReader(TRAIN_PATH) 88 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 89 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 90 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 91 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 92 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 93 | 94 | reader.load_file(TEST_PATH) 95 | test_a = reader.read_field('coeff')[:ntest,::r,::r].reshape(ntest,-1) 96 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,::r,::r].reshape(ntest,-1) 97 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::r,::r].reshape(ntest,-1) 98 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::r,::r].reshape(ntest,-1) 99 | test_u = reader.read_field('sol')[:ntest,::r,::r].reshape(ntest,-1) 100 | 101 | 102 | a_normalizer = GaussianNormalizer(train_a) 103 | train_a = a_normalizer.encode(train_a) 104 | test_a = a_normalizer.encode(test_a) 105 | as_normalizer = GaussianNormalizer(train_a_smooth) 106 | train_a_smooth = as_normalizer.encode(train_a_smooth) 107 | test_a_smooth = as_normalizer.encode(test_a_smooth) 108 | agx_normalizer = GaussianNormalizer(train_a_gradx) 109 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 110 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 111 | agy_normalizer = GaussianNormalizer(train_a_grady) 112 | train_a_grady = agy_normalizer.encode(train_a_grady) 113 | test_a_grady = agy_normalizer.encode(test_a_grady) 114 | 115 | u_normalizer = UnitGaussianNormalizer(train_u) 116 | train_u = u_normalizer.encode(train_u) 117 | # test_u = y_normalizer.encode(test_u) 118 | 119 | 120 | 121 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=m) 122 | data_train = [] 123 | for j in range(ntrain): 124 | for i in range(k): 125 | idx = meshgenerator.sample() 126 | grid = meshgenerator.get_grid() 127 | edge_index = meshgenerator.ball_connectivity(radius_train) 128 | edge_attr = meshgenerator.attributes(theta=train_a[j,:]) 129 | #data_train.append(Data(x=init_point.clone().view(-1,1), y=train_y[j,:], edge_index=edge_index, edge_attr=edge_attr)) 130 | data_train.append(Data(x=torch.cat([grid, train_a[j, idx].reshape(-1, 1), 131 | train_a_smooth[j, idx].reshape(-1, 1), train_a_gradx[j, 
idx].reshape(-1, 1), 132 | train_a_grady[j, idx].reshape(-1, 1) 133 | ], dim=1), 134 | y=train_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 135 | )) 136 | 137 | 138 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=mtest1) 139 | data_test1 = [] 140 | for j in range(ntest): 141 | idx = meshgenerator.sample() 142 | grid = meshgenerator.get_grid() 143 | edge_index = meshgenerator.ball_connectivity(radius_test) 144 | edge_attr = meshgenerator.attributes(theta=test_a[j,:]) 145 | data_test1.append(Data(x=torch.cat([grid, test_a[j, idx].reshape(-1, 1), 146 | test_a_smooth[j, idx].reshape(-1, 1), test_a_gradx[j, idx].reshape(-1, 1), 147 | test_a_grady[j, idx].reshape(-1, 1) 148 | ], dim=1), 149 | y=test_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 150 | )) 151 | # 152 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=mtest2) 153 | data_test2 = [] 154 | for j in range(ntest): 155 | idx = meshgenerator.sample() 156 | grid = meshgenerator.get_grid() 157 | edge_index = meshgenerator.ball_connectivity(radius_test) 158 | edge_attr = meshgenerator.attributes(theta=test_a[j,:]) 159 | data_test2.append(Data(x=torch.cat([grid, test_a[j, idx].reshape(-1, 1), 160 | test_a_smooth[j, idx].reshape(-1, 1), test_a_gradx[j, idx].reshape(-1, 1), 161 | test_a_grady[j, idx].reshape(-1, 1) 162 | ], dim=1), 163 | y=test_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 164 | )) 165 | # 166 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=mtest3) 167 | data_test3 = [] 168 | for j in range(ntest): 169 | idx = meshgenerator.sample() 170 | grid = meshgenerator.get_grid() 171 | edge_index = meshgenerator.ball_connectivity(radius_test) 172 | edge_attr = meshgenerator.attributes(theta=test_a[j,:]) 173 | data_test3.append(Data(x=torch.cat([grid, test_a[j, idx].reshape(-1, 1), 174 | test_a_smooth[j, idx].reshape(-1, 1), test_a_gradx[j, idx].reshape(-1, 1), 175 | test_a_grady[j, idx].reshape(-1, 1) 176 | ], dim=1), 177 | y=test_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 178 | )) 179 | # 180 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=mtest4) 181 | data_test4 = [] 182 | for j in range(ntest): 183 | idx = meshgenerator.sample() 184 | grid = meshgenerator.get_grid() 185 | edge_index = meshgenerator.ball_connectivity(radius_test) 186 | edge_attr = meshgenerator.attributes(theta=test_a[j,:]) 187 | data_test4.append(Data(x=torch.cat([grid, test_a[j, idx].reshape(-1, 1), 188 | test_a_smooth[j, idx].reshape(-1, 1), test_a_gradx[j, idx].reshape(-1, 1), 189 | test_a_grady[j, idx].reshape(-1, 1) 190 | ], dim=1), 191 | y=test_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 192 | )) 193 | # 194 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 195 | test_loader1 = DataLoader(data_test1, batch_size=batch_size2, shuffle=False) 196 | test_loader2 = DataLoader(data_test2, batch_size=batch_size2, shuffle=False) 197 | test_loader3 = DataLoader(data_test3, batch_size=batch_size2, shuffle=False) 198 | test_loader4 = DataLoader(data_test4, batch_size=2, shuffle=False) 199 | 200 | t2 = default_timer() 201 | 202 | print('preprocessing finished, time used:', t2-t1) 203 | device = torch.device('cuda') 204 | 205 | model = KernelNN3(width, ker_width,depth,edge_features,in_width=node_features).cuda() 206 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 207 | scheduler = 
torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 208 | 209 | myloss = LpLoss(size_average=False) 210 | u_normalizer.cuda() 211 | ttrain = np.zeros((epochs, )) 212 | ttest1 = np.zeros((epochs,)) 213 | ttest2 = np.zeros((epochs,)) 214 | ttest3 = np.zeros((epochs,)) 215 | ttest4 = np.zeros((epochs,)) 216 | model.train() 217 | for ep in range(epochs): 218 | t1 = default_timer() 219 | train_mse = 0.0 220 | train_l2 = 0.0 221 | for batch in train_loader: 222 | batch = batch.to(device) 223 | 224 | optimizer.zero_grad() 225 | out = model(batch) 226 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 227 | mse.backward() 228 | 229 | l2 = myloss(u_normalizer.decode(out.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)), 230 | u_normalizer.decode(batch.y.view(batch_size, -1), 231 | sample_idx=batch.sample_idx.view(batch_size, -1))) 232 | optimizer.step() 233 | train_mse += mse.item() 234 | train_l2 += l2.item() 235 | 236 | scheduler.step() 237 | t2 = default_timer() 238 | 239 | model.eval() 240 | test1_l2 = 0.0 241 | test2_l2 = 0.0 242 | test3_l2 = 0.0 243 | test4_l2 = 0.0 244 | with torch.no_grad(): 245 | for batch in test_loader1: 246 | batch = batch.to(device) 247 | out = model(batch) 248 | out = u_normalizer.decode(out.view(batch_size2,-1), sample_idx=batch.sample_idx.view(batch_size2,-1)) 249 | test1_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 250 | for batch in test_loader2: 251 | batch = batch.to(device) 252 | out = model(batch) 253 | out = u_normalizer.decode(out.view(batch_size2,-1), sample_idx=batch.sample_idx.view(batch_size2,-1)) 254 | test2_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 255 | for batch in test_loader3: 256 | batch = batch.to(device) 257 | out = model(batch) 258 | out = u_normalizer.decode(out.view(batch_size2,-1), sample_idx=batch.sample_idx.view(batch_size2,-1)) 259 | test3_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 260 | for batch in test_loader4: 261 | batch = batch.to(device) 262 | out = model(batch) 263 | out = u_normalizer.decode(out.view(2,-1), sample_idx=batch.sample_idx.view(2,-1)) 264 | test4_l2 += myloss(out, batch.y.view(2, -1)).item() 265 | 266 | ttrain[ep] = train_l2/(ntrain * k) 267 | ttest1[ep] = test1_l2/ntest 268 | ttest2[ep] = test2_l2 / ntest 269 | ttest3[ep] = test3_l2 / ntest 270 | ttest4[ep] = test4_l2 / ntest 271 | 272 | print(m, t2-t1, train_mse/len(train_loader), train_l2/(ntrain * k)) 273 | print(test1_l2/ntest, test2_l2/ntest, test3_l2/ntest, test4_l2/ntest) 274 | 275 | np.savetxt(path_train_err, ttrain) 276 | np.savetxt(path_test_err1, ttest1) 277 | np.savetxt(path_test_err2, ttest2) 278 | np.savetxt(path_test_err3, ttest3) 279 | np.savetxt(path_test_err4, ttest4) 280 | torch.save(model, path_model) 281 | -------------------------------------------------------------------------------- /graph-neural-operator/UAI6_sample_radius.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | 6 | from torch_geometric.data import Data, DataLoader 7 | import matplotlib.pyplot as plt 8 | from utilities import * 9 | from nn_conv import NNConv_old 10 | 11 | from timeit import default_timer 12 | 13 | 14 | class KernelNN3(torch.nn.Module): 15 | def __init__(self, width_node, width_kernel, depth, ker_in, in_width=1, out_width=1): 16 | super(KernelNN3, self).__init__() 17 | self.depth = depth 18 | 19 | self.fc1 = torch.nn.Linear(in_width, width_node) 
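        # Note: the DenseNet 'kernel' below maps each 6-dimensional edge feature
        # vector produced by meshgenerator.attributes (endpoint coordinates plus
        # the coefficient values at both endpoints) to width_node**2 entries,
        # which NNConv_old reshapes into one width_node x width_node matrix per
        # edge: the learned kernel of the graph kernel operator, applied 'depth'
        # times with shared weights.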
20 | 21 | kernel = DenseNet([ker_in, width_kernel // 2, width_kernel, width_node**2], torch.nn.ReLU) 22 | self.conv1 = NNConv_old(width_node, width_node, kernel, aggr='mean') 23 | 24 | self.fc2 = torch.nn.Linear(width_node, 1) 25 | 26 | def forward(self, data): 27 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 28 | x = self.fc1(x) 29 | for k in range(self.depth): 30 | x = F.relu(self.conv1(x, edge_index, edge_attr)) 31 | 32 | x = self.fc2(x) 33 | return x 34 | 35 | 36 | TRAIN_PATH = 'data/piececonst_r241_N1024_smooth1.mat' 37 | TEST_PATH = 'data/piececonst_r241_N1024_smooth2.mat' 38 | 39 | for m in (100, 200, 400): 40 | for radius_train in (0.05, 0.15, 0.4): 41 | r = 2 42 | s = int(((241 - 1)/r) + 1) 43 | n = s**2 44 | # m = 200 45 | k = 5 46 | 47 | # radius_train = 0.15 48 | radius_test = radius_train 49 | print('resolution', s) 50 | 51 | 52 | ntrain = 100 53 | ntest = 100 54 | 55 | batch_size = 10 56 | batch_size2 = 10 57 | 58 | if radius_train == 0.4 and m==400: 59 | batch_size = 2 60 | batch_size2 = 2 61 | if radius_train == 0.4 and m == 200: 62 | batch_size = 5 63 | batch_size2 = 5 64 | # else: 65 | 66 | width = 64 67 | ker_width = 1000 68 | depth = 6 69 | edge_features = 6 70 | node_features = 6 71 | 72 | epochs = 200 73 | learning_rate = 0.0001 74 | scheduler_step = 50 75 | scheduler_gamma = 0.5 76 | 77 | path = 'UAI6_s'+str(s)+'_m'+ str(m)+'_radius'+ str(radius_train) 78 | path_model = 'model/'+ path 79 | path_train_err = 'results/'+ path + 'train.txt' 80 | path_test_err = 'results/'+ path + 'test.txt' 81 | path_image = 'results/'+ path 82 | 83 | 84 | t1 = default_timer() 85 | 86 | 87 | reader = MatReader(TRAIN_PATH) 88 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 89 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 90 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 91 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 92 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 93 | 94 | reader.load_file(TEST_PATH) 95 | test_a = reader.read_field('coeff')[:ntest,::r,::r].reshape(ntest,-1) 96 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,::r,::r].reshape(ntest,-1) 97 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::r,::r].reshape(ntest,-1) 98 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::r,::r].reshape(ntest,-1) 99 | test_u = reader.read_field('sol')[:ntest,::r,::r].reshape(ntest,-1) 100 | 101 | 102 | a_normalizer = GaussianNormalizer(train_a) 103 | train_a = a_normalizer.encode(train_a) 104 | test_a = a_normalizer.encode(test_a) 105 | as_normalizer = GaussianNormalizer(train_a_smooth) 106 | train_a_smooth = as_normalizer.encode(train_a_smooth) 107 | test_a_smooth = as_normalizer.encode(test_a_smooth) 108 | agx_normalizer = GaussianNormalizer(train_a_gradx) 109 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 110 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 111 | agy_normalizer = GaussianNormalizer(train_a_grady) 112 | train_a_grady = agy_normalizer.encode(train_a_grady) 113 | test_a_grady = agy_normalizer.encode(test_a_grady) 114 | 115 | u_normalizer = UnitGaussianNormalizer(train_u) 116 | train_u = u_normalizer.encode(train_u) 117 | # test_u = y_normalizer.encode(test_u) 118 | 119 | 120 | 121 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=m) 122 | data_train = [] 123 | for j in range(ntrain): 124 | for i in range(k): 125 | idx = meshgenerator.sample() 
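        # Note: a fresh random sub-mesh of m nodes is drawn for each of the k
        # graphs per training pair, so one PDE instance is seen under several
        # discretizations; ball_connectivity below then connects all sampled
        # node pairs within radius_train.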
126 | grid = meshgenerator.get_grid() 127 | edge_index = meshgenerator.ball_connectivity(radius_train) 128 | edge_attr = meshgenerator.attributes(theta=train_a[j,:]) 129 | #data_train.append(Data(x=init_point.clone().view(-1,1), y=train_y[j,:], edge_index=edge_index, edge_attr=edge_attr)) 130 | data_train.append(Data(x=torch.cat([grid, train_a[j, idx].reshape(-1, 1), 131 | train_a_smooth[j, idx].reshape(-1, 1), train_a_gradx[j, idx].reshape(-1, 1), 132 | train_a_grady[j, idx].reshape(-1, 1) 133 | ], dim=1), 134 | y=train_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 135 | )) 136 | 137 | 138 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=m) 139 | data_test = [] 140 | for j in range(ntest): 141 | idx = meshgenerator.sample() 142 | grid = meshgenerator.get_grid() 143 | edge_index = meshgenerator.ball_connectivity(radius_test) 144 | edge_attr = meshgenerator.attributes(theta=test_a[j,:]) 145 | data_test.append(Data(x=torch.cat([grid, test_a[j, idx].reshape(-1, 1), 146 | test_a_smooth[j, idx].reshape(-1, 1), test_a_gradx[j, idx].reshape(-1, 1), 147 | test_a_grady[j, idx].reshape(-1, 1) 148 | ], dim=1), 149 | y=test_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 150 | )) 151 | # 152 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 153 | test_loader = DataLoader(data_test, batch_size=batch_size2, shuffle=False) 154 | 155 | t2 = default_timer() 156 | 157 | print('preprocessing finished, time used:', t2-t1) 158 | device = torch.device('cuda') 159 | 160 | model = KernelNN3(width, ker_width,depth,edge_features,in_width=node_features).cuda() 161 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 162 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 163 | 164 | myloss = LpLoss(size_average=False) 165 | u_normalizer.cuda() 166 | ttrain = np.zeros((epochs, )) 167 | ttest = np.zeros((epochs,)) 168 | model.train() 169 | for ep in range(epochs): 170 | t1 = default_timer() 171 | train_mse = 0.0 172 | train_l2 = 0.0 173 | for batch in train_loader: 174 | batch = batch.to(device) 175 | 176 | optimizer.zero_grad() 177 | out = model(batch) 178 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 179 | mse.backward() 180 | 181 | l2 = myloss( 182 | u_normalizer.decode(out.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)), 183 | u_normalizer.decode(batch.y.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1))) 184 | optimizer.step() 185 | train_mse += mse.item() 186 | train_l2 += l2.item() 187 | 188 | scheduler.step() 189 | t2 = default_timer() 190 | 191 | model.eval() 192 | test_l2 = 0.0 193 | with torch.no_grad(): 194 | for batch in test_loader: 195 | batch = batch.to(device) 196 | out = model(batch) 197 | out = u_normalizer.decode(out.view(batch_size2,-1), sample_idx=batch.sample_idx.view(batch_size2,-1)) 198 | test_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 199 | # test_l2 += myloss(out.view(batch_size2,-1), y_normalizer.encode(batch.y.view(batch_size2, -1))).item() 200 | 201 | ttrain[ep] = train_l2/(ntrain * k) 202 | ttest[ep] = test_l2/ntest 203 | 204 | print(m, radius_train, ep, t2-t1, train_mse/len(train_loader), train_l2/(ntrain * k), test_l2/ntest) 205 | 206 | np.savetxt(path_train_err, ttrain) 207 | np.savetxt(path_test_err, ttest) 208 | torch.save(model, path_model) 209 | 210 | plt.figure() 211 | # plt.plot(ttrain, label='train loss') 212 | plt.plot(ttest, 
label='test loss') 213 | plt.legend(loc='upper right') 214 | plt.show() 215 | -------------------------------------------------------------------------------- /graph-neural-operator/UAI7_evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | 6 | from torch_geometric.data import DataLoader 7 | import matplotlib.pyplot as plt 8 | from utilities import * 9 | from nn_conv import NNConv_old 10 | 11 | from timeit import default_timer 12 | 13 | 14 | class KernelNN(torch.nn.Module): 15 | def __init__(self, width, ker_width, depth, ker_in, in_width=1, out_width=1): 16 | super(KernelNN, self).__init__() 17 | self.depth = depth 18 | 19 | self.fc1 = torch.nn.Linear(in_width, width) 20 | 21 | kernel = DenseNet([ker_in, ker_width//2, ker_width, width**2], torch.nn.ReLU) 22 | self.conv1 = NNConv_old(width, width, kernel, aggr='mean') 23 | 24 | self.fc2 = torch.nn.Linear(width, 1) 25 | 26 | def forward(self, data): 27 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 28 | x = self.fc1(x) 29 | for k in range(self.depth): 30 | x = self.conv1(x, edge_index, edge_attr) 31 | if k != self.depth-1: 32 | x = F.relu(x) 33 | 34 | x = self.fc2(x) 35 | return x 36 | 37 | # torch.cuda.set_device('cuda:3') 38 | s0 = 421 39 | 40 | TRAIN_PATH = 'data/piececonst_r'+str(s0)+'_N1024_smooth1.mat' 41 | TEST_PATH = 'data/piececonst_r'+str(s0)+'_N1024_smooth2.mat' 42 | 43 | ntrain = 10 44 | ntest = 1 45 | 46 | 47 | 48 | r = 1 49 | s = int(((s0 - 1)/r) + 1) 50 | n = s**2 51 | m = 421 52 | k = 2 53 | trainm = m 54 | train_split = 30 55 | assert ((s0 - 1)/r) % train_split == 0 # the split must divide s-1 56 | 57 | testr1 = r 58 | tests1 = int(((s0 - 1)/testr1) + 1) 59 | test_split = train_split 60 | testn1 = s**2 61 | testm = trainm 62 | 63 | radius_train = 0.2 64 | radius_test = 0.2 65 | # rbf_sigma = 0.2 66 | 67 | print('resolution', s) 68 | 69 | 70 | batch_size = 2 # factor of ntrain * k 71 | batch_size2 = 2 # factor of test_split 72 | assert test_split%batch_size2 == 0 # the batchsize must divide the split 73 | 74 | width = 64 75 | ker_width = 1024 76 | depth = 6 77 | edge_features = 6 78 | node_features = 6 79 | 80 | epochs = 20 81 | learning_rate = 0.0001 82 | scheduler_step = 50 83 | scheduler_gamma = 0.5 84 | 85 | 86 | path = 'UAI7_new_r'+str(s)+'_s'+ str(tests1)+'testm'+str(testm) 87 | path_model = 'model/'+path 88 | path_train_err = 'results/'+path+'train.txt' 89 | path_test_err = 'results/'+path+'test.txt' 90 | path_image = 'image/'+path 91 | 92 | 93 | t1 = default_timer() 94 | 95 | 96 | reader = MatReader(TRAIN_PATH) 97 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 98 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 99 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 100 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 101 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 102 | 103 | reader.load_file(TEST_PATH) 104 | test_a = reader.read_field('coeff')[:ntest,::testr1,::testr1].reshape(ntest,-1) 105 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,::testr1,::testr1].reshape(ntest,-1) 106 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::testr1,::testr1].reshape(ntest,-1) 107 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::testr1,::testr1].reshape(ntest,-1) 108 | test_u = 
reader.read_field('sol')[:ntest,::testr1,::testr1].reshape(ntest,-1) 109 | 110 | 111 | a_normalizer = GaussianNormalizer(train_a) 112 | train_a = a_normalizer.encode(train_a) 113 | test_a = a_normalizer.encode(test_a) 114 | as_normalizer = GaussianNormalizer(train_a_smooth) 115 | train_a_smooth = as_normalizer.encode(train_a_smooth) 116 | test_a_smooth = as_normalizer.encode(test_a_smooth) 117 | agx_normalizer = GaussianNormalizer(train_a_gradx) 118 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 119 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 120 | agy_normalizer = GaussianNormalizer(train_a_grady) 121 | train_a_grady = agy_normalizer.encode(train_a_grady) 122 | test_a_grady = agy_normalizer.encode(test_a_grady) 123 | 124 | u_normalizer = UnitGaussianNormalizer(train_u) 125 | train_u = u_normalizer.encode(train_u) 126 | # test_u = y_normalizer.encode(test_u) 127 | 128 | 129 | meshgenerator = SquareMeshGenerator([[0, 1], [0, 1]], [s, s]) 130 | grid = meshgenerator.get_grid() 131 | gridsplitter = DownsampleGridSplitter(grid, resolution=s, r=train_split, m=trainm, radius=radius_test) 132 | data_train = [] 133 | for j in range(ntrain): 134 | for i in range(k): 135 | theta = torch.cat([train_a[j, :].reshape(-1, 1), 136 | train_a_smooth[j, :].reshape(-1, 1), train_a_gradx[j, :].reshape(-1, 1), 137 | train_a_grady[j, :].reshape(-1, 1) 138 | ], dim=1) 139 | y = train_u[j,:].reshape(-1, 1) 140 | data_train.append(gridsplitter.sample(theta, y)) 141 | 142 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 143 | # print('grid', grid.shape, 'edge_index', edge_index.shape, 'edge_attr', edge_attr.shape) 144 | # print('edge_index_boundary', edge_index_boundary.shape, 'edge_attr', edge_attr_boundary.shape) 145 | 146 | 147 | meshgenerator = SquareMeshGenerator([[0,1],[0,1]],[tests1,tests1]) 148 | grid = meshgenerator.get_grid() 149 | gridsplitter = DownsampleGridSplitter(grid, resolution=tests1, r=test_split, m=testm, radius=radius_test) 150 | 151 | data_test = [] 152 | for j in range(ntest): 153 | theta =torch.cat([test_a[j,:].reshape(-1, 1), 154 | test_a_smooth[j,:].reshape(-1, 1), test_a_gradx[j,:].reshape(-1, 1), test_a_grady[j,:].reshape(-1, 1) 155 | ], dim=1) 156 | data_equation = gridsplitter.get_data(theta) 157 | equation_loader = DataLoader(data_equation, batch_size=batch_size2, shuffle=False) 158 | data_test.append(equation_loader) 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | ################################################################################################## 169 | 170 | ### training 171 | 172 | ################################################################################################## 173 | t2 = default_timer() 174 | 175 | print('preprocessing finished, time used:', t2-t1) 176 | device = torch.device('cuda') 177 | 178 | model = KernelNN(width,ker_width,depth,edge_features,node_features).cuda() 179 | 180 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 181 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 182 | 183 | myloss = LpLoss(size_average=False) 184 | u_normalizer.cuda() 185 | # gridsplitter.cuda() 186 | 187 | model.train() 188 | ttrain = np.zeros((epochs, )) 189 | ttest = np.zeros((epochs,)) 190 | for ep in range(epochs): 191 | t1 = default_timer() 192 | train_mse = 0.0 193 | for batch in train_loader: 194 | batch = batch.to(device) 195 | 196 | optimizer.zero_grad() 197 | out = model(batch) 198 | mse = 
F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 199 | # mse.backward() 200 | loss = torch.norm(out.view(-1) - batch.y.view(-1),1) 201 | loss.backward() 202 | 203 | 204 | optimizer.step() 205 | train_mse += mse.item() 206 | 207 | ttrain[ep] = train_mse / len(train_loader) 208 | scheduler.step() 209 | t2 = default_timer() 210 | 211 | 212 | print(ep, t2-t1, train_mse/len(train_loader)) 213 | 214 | model.eval() 215 | test_l2 = 0.0 216 | u_normalizer.cpu() 217 | with torch.no_grad(): 218 | for i, equation_loader in enumerate(data_test): 219 | pred = [] 220 | split_idx = [] 221 | for batch in equation_loader: 222 | batch = batch.to(device) 223 | out = model(batch) 224 | pred.append(out) 225 | split_idx.append(batch.split_idx.tolist()) 226 | 227 | out = gridsplitter.assemble(pred, split_idx, batch_size2, sigma=1) 228 | y = test_u[i] 229 | test_l2 += myloss(u_normalizer.decode(out.view(1, -1)), y.view(1, -1)) 230 | 231 | if i <= 5: 232 | resolution = tests1 233 | truth = test_u[i].numpy().reshape((resolution, resolution)) 234 | approx = u_normalizer.decode(out.view(1, -1)).detach().numpy().reshape((resolution, resolution)) 235 | _min = np.min(np.min(truth)) 236 | _max = np.max(np.max(truth)) 237 | 238 | plt.figure() 239 | plt.subplot(1, 3, 1) 240 | plt.imshow(truth, vmin=_min, vmax=_max) 241 | plt.xticks([], []) 242 | plt.yticks([], []) 243 | plt.colorbar(fraction=0.046, pad=0.04) 244 | plt.title('Ground Truth') 245 | 246 | plt.subplot(1, 3, 2) 247 | plt.imshow(approx, vmin=_min, vmax=_max) 248 | plt.xticks([], []) 249 | plt.yticks([], []) 250 | plt.colorbar(fraction=0.046, pad=0.04) 251 | plt.title('Approximation') 252 | 253 | plt.subplot(1, 3, 3) 254 | plt.imshow((approx - truth) ** 2) 255 | plt.xticks([], []) 256 | plt.yticks([], []) 257 | plt.colorbar(fraction=0.046, pad=0.04) 258 | plt.title('Error') 259 | 260 | plt.subplots_adjust(wspace=0.5, hspace=0.5) 261 | # plt.savefig(path_image + str(i) + '.png') 262 | plt.savefig(path_image + str(i) + '.eps', format = 'eps', bbox_inches="tight") 263 | # plt.show() 264 | 265 | 266 | t3 = default_timer() 267 | print(ep, t3-t2, train_mse/len(train_loader), test_l2/ntest) 268 | 269 | ttest[ep] = test_l2 / ntest 270 | 271 | np.savetxt(path_train_err, ttrain) 272 | np.savetxt(path_test_err, ttest) 273 | torch.save(model, path_model) 274 | ################################################################################################## 275 | 276 | ### Ploting 277 | 278 | ################################################################################################## 279 | 280 | 281 | 282 | plt.figure() 283 | # plt.plot(ttrain, label='train loss') 284 | plt.plot(ttest, label='test loss') 285 | plt.legend(loc='upper right') 286 | plt.show() 287 | 288 | 289 | 290 | 291 | -------------------------------------------------------------------------------- /graph-neural-operator/UAI7_evaluate2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | 6 | from torch_geometric.data import Data, DataLoader 7 | import matplotlib.pyplot as plt 8 | from utilities import * 9 | from nn_conv import NNConv_old 10 | 11 | from timeit import default_timer 12 | 13 | 14 | class KernelNN(torch.nn.Module): 15 | def __init__(self, width, ker_width, depth, ker_in, in_width=1, out_width=1): 16 | super(KernelNN, self).__init__() 17 | self.depth = depth 18 | 19 | self.fc1 = torch.nn.Linear(in_width, width) 20 | 21 | kernel = DenseNet([ker_in, ker_width//2, ker_width, 
width**2], torch.nn.ReLU) 22 | self.conv1 = NNConv_old(width, width, kernel, aggr='mean') 23 | 24 | self.fc2 = torch.nn.Linear(width, 1) 25 | 26 | def forward(self, data): 27 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 28 | x = self.fc1(x) 29 | for k in range(self.depth): 30 | x = self.conv1(x, edge_index, edge_attr) 31 | if k != self.depth-1: 32 | x = F.relu(x) 33 | 34 | x = self.fc2(x) 35 | return x 36 | 37 | # torch.cuda.set_device('cuda:3') 38 | s0 = 421 39 | 40 | TRAIN_PATH = 'data/piececonst_r'+str(s0)+'_N1024_smooth1.mat' 41 | TEST_PATH = 'data/piececonst_r'+str(s0)+'_N1024_smooth2.mat' 42 | 43 | ntrain = 10 44 | ntest = 1 45 | 46 | 47 | 48 | r = 1 49 | s = int(((s0 - 1)/r) + 1) 50 | n = s**2 51 | k = 2 52 | trainm = 421 53 | assert n % trainm == 0 54 | train_split = 30 55 | assert ((s0 - 1)/r) % train_split == 0 # the split must divide s-1 56 | 57 | testr1 = r 58 | tests1 = int(((s0 - 1)/testr1) + 1) 59 | test_split = train_split 60 | testn1 = s**2 61 | testm = trainm 62 | 63 | radius_train = 0.2 64 | radius_test = 0.2 65 | # rbf_sigma = 0.2 66 | 67 | print('resolution', s) 68 | 69 | 70 | batch_size = 2 # factor of ntrain * k 71 | batch_size2 = 2 # factor of test_split 72 | assert test_split%batch_size2 == 0 # the batchsize must divide the split 73 | 74 | width = 64 75 | ker_width = 1024 76 | depth = 6 77 | edge_features = 6 78 | node_features = 6 79 | 80 | epochs = 20 81 | learning_rate = 0.0001 82 | scheduler_step = 50 83 | scheduler_gamma = 0.5 84 | 85 | 86 | path = 'UAI7_new_r'+str(s)+'_s'+ str(tests1)+'testm'+str(testm) 87 | path_model = 'model/'+path 88 | path_train_err = 'results/'+path+'train.txt' 89 | path_test_err = 'results/'+path+'test.txt' 90 | path_image = 'image/'+path 91 | 92 | 93 | t1 = default_timer() 94 | 95 | 96 | reader = MatReader(TRAIN_PATH) 97 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 98 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 99 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 100 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 101 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 102 | 103 | reader.load_file(TEST_PATH) 104 | test_a = reader.read_field('coeff')[:ntest,::testr1,::testr1].reshape(ntest,-1) 105 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,::testr1,::testr1].reshape(ntest,-1) 106 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::testr1,::testr1].reshape(ntest,-1) 107 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::testr1,::testr1].reshape(ntest,-1) 108 | test_u = reader.read_field('sol')[:ntest,::testr1,::testr1].reshape(ntest,-1) 109 | 110 | 111 | a_normalizer = GaussianNormalizer(train_a) 112 | train_a = a_normalizer.encode(train_a) 113 | test_a = a_normalizer.encode(test_a) 114 | as_normalizer = GaussianNormalizer(train_a_smooth) 115 | train_a_smooth = as_normalizer.encode(train_a_smooth) 116 | test_a_smooth = as_normalizer.encode(test_a_smooth) 117 | agx_normalizer = GaussianNormalizer(train_a_gradx) 118 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 119 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 120 | agy_normalizer = GaussianNormalizer(train_a_grady) 121 | train_a_grady = agy_normalizer.encode(train_a_grady) 122 | test_a_grady = agy_normalizer.encode(test_a_grady) 123 | 124 | u_normalizer = UnitGaussianNormalizer(train_u) 125 | train_u = u_normalizer.encode(train_u) 126 | # test_u = 
y_normalizer.encode(test_u) 127 | 128 | 129 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=trainm) 130 | data_train = [] 131 | for j in range(ntrain): 132 | for i in range(k): 133 | idx = meshgenerator.sample() 134 | grid = meshgenerator.get_grid() 135 | edge_index = meshgenerator.ball_connectivity(radius_train) 136 | edge_attr = meshgenerator.attributes(theta=train_a[j, :]) 137 | # data_train.append(Data(x=init_point.clone().view(-1,1), y=train_y[j,:], edge_index=edge_index, edge_attr=edge_attr)) 138 | data_train.append(Data(x=torch.cat([grid, train_a[j, idx].reshape(-1, 1), 139 | train_a_smooth[j, idx].reshape(-1, 1), train_a_gradx[j, idx].reshape(-1, 1), 140 | train_a_grady[j, idx].reshape(-1, 1) 141 | ], dim=1), 142 | y=train_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 143 | )) 144 | 145 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 146 | # print('grid', grid.shape, 'edge_index', edge_index.shape, 'edge_attr', edge_attr.shape) 147 | # print('edge_index_boundary', edge_index_boundary.shape, 'edge_attr', edge_attr_boundary.shape) 148 | 149 | 150 | meshgenerator = SquareMeshGenerator([[0,1],[0,1]],[tests1,tests1]) 151 | grid = meshgenerator.get_grid() 152 | gridsplitter = RandomGridSplitter(grid, resolution=tests1, l=2, m=testm, radius=radius_test) 153 | 154 | data_test = [] 155 | for j in range(ntest): 156 | theta =torch.cat([test_a[j,:].reshape(-1, 1), 157 | test_a_smooth[j,:].reshape(-1, 1), test_a_gradx[j,:].reshape(-1, 1), test_a_grady[j,:].reshape(-1, 1) 158 | ], dim=1) 159 | data_equation = gridsplitter.get_data(theta) 160 | equation_loader = DataLoader(data_equation, batch_size=batch_size2, shuffle=False) 161 | data_test.append(equation_loader) 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | ################################################################################################## 172 | 173 | ### training 174 | 175 | ################################################################################################## 176 | t2 = default_timer() 177 | 178 | print('preprocessing finished, time used:', t2-t1) 179 | device = torch.device('cuda') 180 | 181 | model = KernelNN(width,ker_width,depth,edge_features,node_features).cuda() 182 | 183 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 184 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 185 | 186 | myloss = LpLoss(size_average=False) 187 | u_normalizer.cuda() 188 | # gridsplitter.cuda() 189 | 190 | model.train() 191 | ttrain = np.zeros((epochs, )) 192 | ttest = np.zeros((epochs,)) 193 | for ep in range(epochs): 194 | t1 = default_timer() 195 | train_mse = 0.0 196 | for batch in train_loader: 197 | batch = batch.to(device) 198 | 199 | optimizer.zero_grad() 200 | out = model(batch) 201 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 202 | # mse.backward() 203 | loss = torch.norm(out.view(-1) - batch.y.view(-1),1) 204 | loss.backward() 205 | 206 | 207 | optimizer.step() 208 | train_mse += mse.item() 209 | 210 | ttrain[ep] = train_mse / len(train_loader) 211 | scheduler.step() 212 | t2 = default_timer() 213 | 214 | 215 | print(ep, t2-t1, train_mse/len(train_loader)) 216 | 217 | model.eval() 218 | test_l2 = 0.0 219 | u_normalizer.cpu() 220 | with torch.no_grad(): 221 | for i, equation_loader in enumerate(data_test): 222 | pred = [] 223 | split_idx = [] 224 | for batch in equation_loader: 225 | batch = batch.to(device) 226 | out = 
model(batch).detach().cpu() 227 | pred.append(out) 228 | split_idx.append(batch.split_idx) 229 | 230 | out = gridsplitter.assemble(pred, split_idx, batch_size2, sigma=1) 231 | y = test_u[i] 232 | test_l2 += myloss(u_normalizer.decode(out.view(1, -1)), y.view(1, -1)) 233 | 234 | if i <= 5: 235 | resolution = tests1 236 | truth = test_u[i].numpy().reshape((resolution, resolution)) 237 | approx = u_normalizer.decode(out.view(1, -1)).detach().numpy().reshape((resolution, resolution)) 238 | _min = np.min(np.min(truth)) 239 | _max = np.max(np.max(truth)) 240 | 241 | plt.figure() 242 | plt.subplot(1, 3, 1) 243 | plt.imshow(truth, vmin=_min, vmax=_max) 244 | plt.xticks([], []) 245 | plt.yticks([], []) 246 | plt.colorbar(fraction=0.046, pad=0.04) 247 | plt.title('Ground Truth') 248 | 249 | plt.subplot(1, 3, 2) 250 | plt.imshow(approx, vmin=_min, vmax=_max) 251 | plt.xticks([], []) 252 | plt.yticks([], []) 253 | plt.colorbar(fraction=0.046, pad=0.04) 254 | plt.title('Approximation') 255 | 256 | plt.subplot(1, 3, 3) 257 | plt.imshow((approx - truth) ** 2) 258 | plt.xticks([], []) 259 | plt.yticks([], []) 260 | plt.colorbar(fraction=0.046, pad=0.04) 261 | plt.title('Error') 262 | 263 | plt.subplots_adjust(wspace=0.5, hspace=0.5) 264 | # plt.savefig(path_image + str(i) + '.png') 265 | plt.savefig(path_image + str(i) + '.eps', format = 'eps', bbox_inches="tight") 266 | # plt.show() 267 | 268 | 269 | t3 = default_timer() 270 | print(ep, t3-t2, train_mse/len(train_loader), test_l2/ntest) 271 | 272 | ttest[ep] = test_l2 / ntest 273 | 274 | np.savetxt(path_train_err, ttrain) 275 | np.savetxt(path_test_err, ttest) 276 | torch.save(model, path_model) 277 | ################################################################################################## 278 | 279 | ### Ploting 280 | 281 | ################################################################################################## 282 | 283 | 284 | 285 | plt.figure() 286 | # plt.plot(ttrain, label='train loss') 287 | plt.plot(ttest, label='test loss') 288 | plt.legend(loc='upper right') 289 | plt.show() 290 | 291 | 292 | 293 | 294 | -------------------------------------------------------------------------------- /graph-neural-operator/UAI8_kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | 6 | from torch_geometric.data import Data, DataLoader 7 | import matplotlib.pyplot as plt 8 | from utilities import * 9 | from nn_conv import NNConv_old 10 | 11 | from timeit import default_timer 12 | 13 | 14 | class KernelNN3(torch.nn.Module): 15 | def __init__(self, width_node, width_kernel, depth, ker_in, in_width=1, out_width=1): 16 | super(KernelNN3, self).__init__() 17 | self.depth = depth 18 | 19 | self.fc1 = torch.nn.Linear(in_width, width_node) 20 | 21 | kernel = DenseNet([ker_in, width_kernel // 4, width_kernel // 2, width_kernel, width_kernel, width_node ** 2], torch.nn.ReLU) 22 | # kernel = DenseNet([ker_in, width_kernel // 2, width_kernel, width_node**2], torch.nn.ReLU) 23 | # kernel = DenseNet([ker_in, width_kernel, width_node**2], torch.nn.ReLU) 24 | 25 | self.conv1 = NNConv_old(width_node, width_node, kernel, aggr='mean') 26 | 27 | self.fc2 = torch.nn.Linear(width_node, 1) 28 | 29 | def forward(self, data): 30 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 31 | x = self.fc1(x) 32 | for k in range(self.depth): 33 | x = self.conv1(x, edge_index, edge_attr) 34 | if k != self.depth-1: 35 | x = F.relu(x) 36 | 37 | 
x = self.fc2(x) 38 | return x 39 | 40 | 41 | TRAIN_PATH = 'data/piececonst_r241_N1024_smooth1.mat' 42 | TEST_PATH = 'data/piececonst_r241_N1024_smooth2.mat' 43 | 44 | for ker_width in (256,): 45 | r = 1 46 | s = int(((241 - 1)/r) + 1) 47 | n = s**2 48 | m = 200 49 | k = 2 50 | 51 | radius_train = 0.25 52 | radius_test = 0.25 53 | print('resolution', s) 54 | 55 | 56 | ntrain = 100 57 | ntest = 100 58 | 59 | batch_size = 5 60 | batch_size2 = 5 61 | width = 64 62 | # ker_width = 1000 63 | depth = 6 64 | edge_features = 6 65 | node_features = 6 66 | 67 | epochs = 200 68 | learning_rate = 0.0001 69 | scheduler_step = 50 70 | scheduler_gamma = 0.5 71 | 72 | path = 'UAI8_s'+str(s)+'_ker_width'+ str(ker_width)+'_depth5_' 73 | path_model = 'model/'+path 74 | path_train_err = 'results/'+path+'train.txt' 75 | path_test_err = 'results/'+path+'test.txt' 76 | path_image = 'results/'+path 77 | 78 | 79 | t1 = default_timer() 80 | 81 | 82 | reader = MatReader(TRAIN_PATH) 83 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 84 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 85 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 86 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 87 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 88 | 89 | reader.load_file(TEST_PATH) 90 | test_a = reader.read_field('coeff')[:ntest,::r,::r].reshape(ntest,-1) 91 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,::r,::r].reshape(ntest,-1) 92 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::r,::r].reshape(ntest,-1) 93 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::r,::r].reshape(ntest,-1) 94 | test_u = reader.read_field('sol')[:ntest,::r,::r].reshape(ntest,-1) 95 | 96 | 97 | a_normalizer = GaussianNormalizer(train_a) 98 | train_a = a_normalizer.encode(train_a) 99 | test_a = a_normalizer.encode(test_a) 100 | as_normalizer = GaussianNormalizer(train_a_smooth) 101 | train_a_smooth = as_normalizer.encode(train_a_smooth) 102 | test_a_smooth = as_normalizer.encode(test_a_smooth) 103 | agx_normalizer = GaussianNormalizer(train_a_gradx) 104 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 105 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 106 | agy_normalizer = GaussianNormalizer(train_a_grady) 107 | train_a_grady = agy_normalizer.encode(train_a_grady) 108 | test_a_grady = agy_normalizer.encode(test_a_grady) 109 | 110 | u_normalizer = UnitGaussianNormalizer(train_u) 111 | train_u = u_normalizer.encode(train_u) 112 | # test_u = y_normalizer.encode(test_u) 113 | 114 | 115 | 116 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=m) 117 | data_train = [] 118 | for j in range(ntrain): 119 | for i in range(k): 120 | idx = meshgenerator.sample() 121 | grid = meshgenerator.get_grid() 122 | edge_index = meshgenerator.ball_connectivity(radius_train) 123 | edge_attr = meshgenerator.attributes(theta=train_a[j,:]) 124 | #data_train.append(Data(x=init_point.clone().view(-1,1), y=train_y[j,:], edge_index=edge_index, edge_attr=edge_attr)) 125 | data_train.append(Data(x=torch.cat([grid, train_a[j, idx].reshape(-1, 1), 126 | train_a_smooth[j, idx].reshape(-1, 1), train_a_gradx[j, idx].reshape(-1, 1), 127 | train_a_grady[j, idx].reshape(-1, 1) 128 | ], dim=1), 129 | y=train_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 130 | )) 131 | 132 | 133 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=m) 134 | data_test 
= [] 135 | for j in range(ntest): 136 | idx = meshgenerator.sample() 137 | grid = meshgenerator.get_grid() 138 | edge_index = meshgenerator.ball_connectivity(radius_test) 139 | edge_attr = meshgenerator.attributes(theta=test_a[j,:]) 140 | data_test.append(Data(x=torch.cat([grid, test_a[j, idx].reshape(-1, 1), 141 | test_a_smooth[j, idx].reshape(-1, 1), test_a_gradx[j, idx].reshape(-1, 1), 142 | test_a_grady[j, idx].reshape(-1, 1) 143 | ], dim=1), 144 | y=test_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 145 | )) 146 | # 147 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 148 | test_loader = DataLoader(data_test, batch_size=batch_size2, shuffle=False) 149 | 150 | t2 = default_timer() 151 | 152 | print('preprocessing finished, time used:', t2-t1) 153 | device = torch.device('cuda') 154 | 155 | model = KernelNN3(width, ker_width,depth,edge_features,in_width=node_features).cuda() 156 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 157 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 158 | 159 | myloss = LpLoss(size_average=False) 160 | u_normalizer.cuda() 161 | ttrain = np.zeros((epochs, )) 162 | ttest = np.zeros((epochs,)) 163 | model.train() 164 | for ep in range(epochs): 165 | t1 = default_timer() 166 | train_mse = 0.0 167 | train_l2 = 0.0 168 | for batch in train_loader: 169 | batch = batch.to(device) 170 | 171 | optimizer.zero_grad() 172 | out = model(batch) 173 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 174 | mse.backward() 175 | 176 | l2 = myloss(u_normalizer.decode(out.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)), 177 | u_normalizer.decode(batch.y.view(batch_size, -1), 178 | sample_idx=batch.sample_idx.view(batch_size, -1))) 179 | optimizer.step() 180 | train_mse += mse.item() 181 | train_l2 += l2.item() 182 | 183 | scheduler.step() 184 | t2 = default_timer() 185 | 186 | model.eval() 187 | test_l2 = 0.0 188 | with torch.no_grad(): 189 | for batch in test_loader: 190 | batch = batch.to(device) 191 | out = model(batch) 192 | out = u_normalizer.decode(out.view(batch_size2,-1), sample_idx=batch.sample_idx.view(batch_size2,-1)) 193 | test_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 194 | # test_l2 += myloss(out.view(batch_size2,-1), y_normalizer.encode(batch.y.view(batch_size2, -1))).item() 195 | 196 | ttrain[ep] = train_l2/(ntrain * k) 197 | ttest[ep] = test_l2/ntest 198 | 199 | print(ker_width, ep, t2-t1, train_mse/len(train_loader), train_l2/(ntrain * k), test_l2/ntest) 200 | 201 | np.savetxt(path_train_err, ttrain) 202 | np.savetxt(path_test_err, ttest) 203 | torch.save(model, path_model) 204 | 205 | plt.figure() 206 | # plt.plot(ttrain, label='train loss') 207 | plt.plot(ttest, label='test loss') 208 | plt.legend(loc='upper right') 209 | plt.show() 210 | -------------------------------------------------------------------------------- /graph-neural-operator/model/grain_new_r64_s64testm100: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/graph-pde/c28220a6558554a193303975adb60d8857d48c0c/graph-neural-operator/model/grain_new_r64_s64testm100 -------------------------------------------------------------------------------- /graph-neural-operator/model/grain_torus_r64_radius0.4testm100: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/neuraloperator/graph-pde/c28220a6558554a193303975adb60d8857d48c0c/graph-neural-operator/model/grain_torus_r64_radius0.4testm100
--------------------------------------------------------------------------------
/graph-neural-operator/nn_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from torch_geometric.nn.conv import MessagePassing
4 | from torch_geometric.nn.inits import reset, uniform
5 |
6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7 |
8 | class NNConv(MessagePassing):
9 |     r"""The continuous kernel-based convolutional operator from the
10 |     `"Neural Message Passing for Quantum Chemistry"
11 |     <https://arxiv.org/abs/1704.01212>`_ paper.
12 |     This convolution is also known as the edge-conditioned convolution from the
13 |     `"Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on
14 |     Graphs" <https://arxiv.org/abs/1704.02901>`_ paper (see
15 |     :class:`torch_geometric.nn.conv.ECConv` for an alias):
16 |
17 |     .. math::
18 |         \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i +
19 |         \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot
20 |         h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),
21 |
22 |     where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.*
23 |     an MLP.
24 |
25 |     Args:
26 |         in_channels (int): Size of each input sample.
27 |         out_channels (int): Size of each output sample.
28 |         nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that
29 |             maps edge features :obj:`edge_attr` of shape :obj:`[-1, num_edge_features]`
30 |             to shape :obj:`[-1, in_channels]`; this variant builds a diagonal
31 |             kernel per edge in :meth:`message`, so it assumes
32 |             :obj:`in_channels == out_channels`. *E.g.*, defined by :class:`torch.nn.Sequential`.
33 |         aggr (string, optional): The aggregation scheme to use
34 |             (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
35 |             (default: :obj:`"add"`)
36 |         root_weight (bool, optional): If set to :obj:`False`, the layer will
37 |             not add the transformed root node features to the output.
38 |             (default: :obj:`True`)
39 |         bias (bool, optional): If set to :obj:`False`, the layer will not learn
40 |             an additive bias. (default: :obj:`True`)
41 |         **kwargs (optional): Additional arguments of
42 |             :class:`torch_geometric.nn.conv.MessagePassing`.
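    Example (a minimal sketch; module names and sizes here are illustrative,
    and the shapes follow the diagonal-kernel :meth:`message` below)::

        edge_net = torch.nn.Sequential(
            torch.nn.Linear(6, 128), torch.nn.ReLU(), torch.nn.Linear(128, 64))
        conv = NNConv(64, 64, edge_net, aggr='mean')
        x = torch.randn(100, 64)                  # 100 nodes, 64 channels
        edge_index = torch.randint(0, 100, (2, 500))
        edge_attr = torch.randn(500, 6)           # 6 features per edge
        out = conv(x, edge_index, edge_attr)      # -> (100, 64)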
43 |     """
44 |
45 |     def __init__(self,
46 |                  in_channels,
47 |                  out_channels,
48 |                  nn,
49 |                  aggr='add',
50 |                  root_weight=True,
51 |                  bias=True,
52 |                  **kwargs):
53 |         super(NNConv, self).__init__(aggr=aggr, **kwargs)
54 |
55 |         self.in_channels = in_channels
56 |         self.out_channels = out_channels
57 |         self.nn = nn
58 |         self.aggr = aggr
59 |
60 |         if root_weight:
61 |             self.root = Parameter(torch.Tensor(in_channels, out_channels))
62 |         else:
63 |             self.register_parameter('root', None)
64 |
65 |         if bias:
66 |             self.bias = Parameter(torch.Tensor(out_channels))
67 |         else:
68 |             self.register_parameter('bias', None)
69 |
70 |         self.reset_parameters()
71 |
72 |     def reset_parameters(self):
73 |         reset(self.nn)
74 |         uniform(self.in_channels, self.root)
75 |         uniform(self.in_channels, self.bias)
76 |
77 |     def forward(self, x, edge_index, edge_attr):
78 |         """"""
79 |         x = x.unsqueeze(-1) if x.dim() == 1 else x
80 |         pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr
81 |         return self.propagate(edge_index, x=x, pseudo=pseudo)
82 |
83 |     def message(self, x_j, pseudo):
84 |         weight_diag = torch.diag_embed(self.nn(pseudo)).view(-1, self.in_channels, self.out_channels)
85 |         return torch.matmul(x_j.unsqueeze(1), weight_diag).squeeze(1)
86 |
87 |     def update(self, aggr_out, x):
88 |         if self.root is not None:
89 |             aggr_out = aggr_out + torch.mm(x, self.root)
90 |         if self.bias is not None:
91 |             aggr_out = aggr_out + self.bias
92 |         return aggr_out
93 |
94 |     def __repr__(self):
95 |         return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
96 |                                    self.out_channels)
97 |
98 |
99 | class NNConv_Gaussian(MessagePassing):
100 |     r"""The continuous kernel-based convolutional operator from the
101 |     `"Neural Message Passing for Quantum Chemistry"
102 |     <https://arxiv.org/abs/1704.01212>`_ paper.
103 |     This convolution is also known as the edge-conditioned convolution from the
104 |     `"Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on
105 |     Graphs" <https://arxiv.org/abs/1704.02901>`_ paper (see
106 |     :class:`torch_geometric.nn.conv.ECConv` for an alias):
107 |
108 |     .. math::
109 |         \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i +
110 |         \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot
111 |         h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),
112 |
113 |     where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.*
114 |     an MLP.
115 |
116 |     Args:
117 |         in_channels (int): Size of each input sample.
118 |         out_channels (int): Size of each output sample.
119 |         nn (torch.nn.Module): A neural network that here is only evaluated on a
120 |             constant scalar input; its :obj:`in_channels` outputs are used as
121 |             per-channel Gaussian bandwidths in :meth:`message`, while the edge
122 |             features enter through a fixed radial formula. *E.g.*, defined by
123 |             :class:`torch.nn.Sequential`.
124 |         aggr (string, optional): The aggregation scheme to use
125 |             (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
126 |             (default: :obj:`"add"`)
127 |         root_weight (bool, optional): If set to :obj:`False`, the layer will
128 |             not add the transformed root node features to the output.
129 |             (default: :obj:`True`)
130 |         bias (bool, optional): If set to :obj:`False`, the layer will not learn
131 |             an additive bias. (default: :obj:`True`)
132 |         **kwargs (optional): Additional arguments of
133 |             :class:`torch_geometric.nn.conv.MessagePassing`.
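    Example (a minimal sketch; names and sizes are illustrative, and the
    constant-input bandwidth behaviour follows :meth:`message` below)::

        bandwidth_net = torch.nn.Linear(1, 64)    # constant input -> 64 bandwidths
        conv = NNConv_Gaussian(64, 64, bandwidth_net, aggr='mean').to(device)
        x = torch.randn(100, 64, device=device)
        edge_index = torch.randint(0, 100, (2, 500), device=device)
        edge_attr = torch.rand(500, 6, device=device) + 0.1  # keep columns 1-2 nonzero
        out = conv(x, edge_index, edge_attr)                 # -> (100, 64)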
134 |     """
135 |
136 |     def __init__(self,
137 |                  in_channels,
138 |                  out_channels,
139 |                  nn,
140 |                  aggr='add',
141 |                  root_weight=True,
142 |                  bias=True,
143 |                  **kwargs):
144 |         super(NNConv_Gaussian, self).__init__(aggr=aggr, **kwargs)
145 |
146 |         self.in_channels = in_channels
147 |         self.out_channels = out_channels
148 |         self.nn = nn
149 |         self.aggr = aggr
150 |
151 |         if root_weight:
152 |             self.root = Parameter(torch.Tensor(in_channels, out_channels))
153 |         else:
154 |             self.register_parameter('root', None)
155 |
156 |         if bias:
157 |             self.bias = Parameter(torch.Tensor(out_channels))
158 |         else:
159 |             self.register_parameter('bias', None)
160 |
161 |         self.reset_parameters()
162 |
163 |     def reset_parameters(self):
164 |         reset(self.nn)
165 |         uniform(self.in_channels, self.root)
166 |         uniform(self.in_channels, self.bias)
167 |
168 |     def forward(self, x, edge_index, edge_attr):
169 |         """"""
170 |         x = x.unsqueeze(-1) if x.dim() == 1 else x
171 |         pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr
172 |         return self.propagate(edge_index, x=x, pseudo=pseudo)
173 |
174 |     def message(self, x_j, pseudo):
175 |         one = torch.ones(1).to(device)
176 |         a = 1 / torch.sqrt(torch.abs(pseudo[:, 1] * pseudo[:, 2]))
177 |         # print('a', torch.isnan(a))
178 |         b = torch.exp(-1 * (pseudo[:, 0] ** 2).view(-1, 1) / (self.nn(one) ** 2).view(1, -1))
179 |         # print('b', torch.isnan(b))
180 |         weight_gauss = a.reshape(-1, 1).repeat(1, self.in_channels) * b  # was hardcoded to 64 channels
181 |         # print('w', torch.isnan(weight_gauss))
182 |         weight_gauss = torch.diag_embed(weight_gauss).view(-1, self.in_channels, self.out_channels)
183 |         return torch.matmul(x_j.unsqueeze(1), weight_gauss).squeeze(1)
184 |
185 |     def update(self, aggr_out, x):
186 |         if self.root is not None:
187 |             aggr_out = aggr_out + torch.mm(x, self.root)
188 |         if self.bias is not None:
189 |             aggr_out = aggr_out + self.bias
190 |         return aggr_out
191 |
192 |     def __repr__(self):
193 |         return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
194 |                                    self.out_channels)
195 |
196 |
197 | class NNConv_old(MessagePassing):
198 |     r"""The continuous kernel-based convolutional operator from the
199 |     `"Neural Message Passing for Quantum Chemistry"
200 |     <https://arxiv.org/abs/1704.01212>`_ paper.
201 |     This convolution is also known as the edge-conditioned convolution from the
202 |     `"Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on
203 |     Graphs" <https://arxiv.org/abs/1704.02901>`_ paper (see
204 |     :class:`torch_geometric.nn.conv.ECConv` for an alias):
205 |
206 |     .. math::
207 |         \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i +
208 |         \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot
209 |         h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),
210 |
211 |     where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.*
212 |     an MLP.
213 |
214 |     Args:
215 |         in_channels (int): Size of each input sample.
216 |         out_channels (int): Size of each output sample.
217 |         nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that
218 |             maps edge features :obj:`edge_attr` of shape :obj:`[-1,
219 |             num_edge_features]` to shape
220 |             :obj:`[-1, in_channels * out_channels]`, *e.g.*, defined by
221 |             :class:`torch.nn.Sequential`.
222 |         aggr (string, optional): The aggregation scheme to use
223 |             (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
224 |             (default: :obj:`"add"`)
225 |         root_weight (bool, optional): If set to :obj:`False`, the layer will
226 |             not add the transformed root node features to the output.
227 |             (default: :obj:`True`)
228 |         bias (bool, optional): If set to :obj:`False`, the layer will not learn
229 |             an additive bias.
(default: :obj:`True`) 230 | **kwargs (optional): Additional arguments of 231 | :class:`torch_geometric.nn.conv.MessagePassing`. 232 | """ 233 | 234 | def __init__(self, 235 | in_channels, 236 | out_channels, 237 | nn, 238 | aggr='add', 239 | root_weight=True, 240 | bias=True, 241 | **kwargs): 242 | super(NNConv_old, self).__init__(aggr=aggr, **kwargs) 243 | 244 | self.in_channels = in_channels 245 | self.out_channels = out_channels 246 | self.nn = nn 247 | self.aggr = aggr 248 | 249 | if root_weight: 250 | self.root = Parameter(torch.Tensor(in_channels, out_channels)) 251 | else: 252 | self.register_parameter('root', None) 253 | 254 | if bias: 255 | self.bias = Parameter(torch.Tensor(out_channels)) 256 | else: 257 | self.register_parameter('bias', None) 258 | 259 | self.reset_parameters() 260 | 261 | def reset_parameters(self): 262 | reset(self.nn) 263 | size = self.in_channels 264 | uniform(size, self.root) 265 | uniform(size, self.bias) 266 | 267 | def forward(self, x, edge_index, edge_attr): 268 | """""" 269 | x = x.unsqueeze(-1) if x.dim() == 1 else x 270 | pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr 271 | return self.propagate(edge_index, x=x, pseudo=pseudo) 272 | 273 | def message(self, x_j, pseudo): 274 | weight = self.nn(pseudo).view(-1, self.in_channels, self.out_channels) 275 | return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1) 276 | 277 | def update(self, aggr_out, x): 278 | if self.root is not None: 279 | aggr_out = aggr_out + torch.mm(x, self.root) 280 | if self.bias is not None: 281 | aggr_out = aggr_out + self.bias 282 | return aggr_out 283 | 284 | def __repr__(self): 285 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 286 | self.out_channels) 287 | 288 | ECConv = NNConv 289 | -------------------------------------------------------------------------------- /multipole-graph-neural-operator/MGKN_general_darcy2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import matplotlib.pyplot as plt 6 | from utilities import * 7 | from torch_geometric.data import Data, DataLoader 8 | from torch_geometric.nn import NNConv 9 | from timeit import default_timer 10 | 11 | torch.manual_seed(0) 12 | np.random.seed(0) 13 | 14 | 15 | ######################################################################## 16 | # 17 | # The neural networks architecture 18 | # 19 | ######################################################################## 20 | 21 | class MKGN(torch.nn.Module): 22 | def __init__(self, width, ker_width, depth, ker_in, points, level, in_width=1, out_width=1): 23 | super(MKGN, self).__init__() 24 | self.depth = depth 25 | self.width = width 26 | self.level = level 27 | 28 | index = 0 29 | self.points = [0] 30 | for point in points: 31 | index = index + point 32 | self.points.append(index) 33 | print(level, self.points) 34 | 35 | self.points_total = np.sum(points) 36 | 37 | # in (P) 38 | self.fc_in = torch.nn.Linear(in_width, width) 39 | 40 | # K12 K23 K34 ... 
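        # Note: conv_down_list holds the inter-level transition kernels
        # K_{l,l+1} used in the downward (fine-to-coarse) sweep of forward();
        # conv_list and conv_up_list below hold the within-level kernels K_{l,l}
        # and the coarse-to-fine transitions, so each depth step performs one
        # V-cycle. ker_width is halved per level, giving the coarser levels
        # cheaper kernel networks.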
41 | self.conv_down_list = [] 42 | for l in range(1, level): 43 | ker_width_l = ker_width // (2 ** l) 44 | kernel_l = DenseNet([ker_in, ker_width_l, width ** 2], torch.nn.ReLU) 45 | self.conv_down_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=False, bias=False)) 46 | self.conv_down_list = torch.nn.ModuleList(self.conv_down_list) 47 | 48 | # K11 K22 K33 49 | self.conv_list = [] 50 | for l in range(level): 51 | ker_width_l = ker_width // (2 ** l) 52 | kernel_l = DenseNet([ker_in, ker_width_l, ker_width_l, width ** 2], torch.nn.ReLU) 53 | self.conv_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=True, bias=False)) 54 | self.conv_list = torch.nn.ModuleList(self.conv_list) 55 | 56 | # K21 K32 K43 57 | self.conv_up_list = [] 58 | for l in range(1, level): 59 | ker_width_l = ker_width // (2 ** l) 60 | kernel_l = DenseNet([ker_in, ker_width_l, width ** 2], torch.nn.ReLU) 61 | self.conv_up_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=False, bias=False)) 62 | self.conv_up_list = torch.nn.ModuleList(self.conv_up_list) 63 | 64 | # out (Q) 65 | self.fc_out1 = torch.nn.Linear(width, ker_width) 66 | self.fc_out2 = torch.nn.Linear(ker_width, 1) 67 | 68 | 69 | def forward(self, data): 70 | edge_index_down, edge_attr_down, range_down = data.edge_index_down, data.edge_attr_down, data.edge_index_down_range 71 | edge_index_mid, edge_attr_mid, range_mid = data.edge_index_mid, data.edge_attr_mid, data.edge_index_range 72 | edge_index_up, edge_attr_up, range_up = data.edge_index_up, data.edge_attr_up, data.edge_index_up_range 73 | 74 | x = self.fc_in(data.x) 75 | 76 | for t in range(self.depth): 77 | #downward 78 | for l in range(self.level-1): 79 | x = x + self.conv_down_list[l](x, edge_index_down[:,range_down[l,0]:range_down[l,1]], edge_attr_down[range_down[l,0]:range_down[l,1],:]) 80 | x = F.relu(x) 81 | 82 | #upward 83 | for l in reversed(range(self.level)): 84 | x[self.points[l]:self.points[l+1]] = self.conv_list[l](x[self.points[l]:self.points[l+1]].clone(), 85 | edge_index_mid[:,range_mid[l,0]:range_mid[l,1]]-self.points[l], 86 | edge_attr_mid[range_mid[l,0]:range_mid[l,1],:]) 87 | 88 | if l > 0: 89 | x = x + self.conv_up_list[l-1](x, edge_index_up[:,range_up[l-1,0]:range_up[l-1,1]], edge_attr_up[range_up[l-1,0]:range_up[l-1,1],:]) 90 | x = F.relu(x) 91 | 92 | x = F.relu(self.fc_out1(x[:self.points[1]])) 93 | x = self.fc_out2(x) 94 | return x 95 | 96 | 97 | ######################################################################## 98 | # 99 | # Hyperparameters 100 | # 101 | ######################################################################## 102 | 103 | 104 | s0 = 421 #the grid size 105 | 106 | r = 5 #downsample 107 | s = int(((s0 - 1)/r) + 1) #grid size after downsample 108 | n = s**2 # number of nodes 109 | k = 1 # graph sampled per training pairs 110 | 111 | 112 | m = [400, 100, 25] # number of nodes sampled for each layers 113 | radius_inner = [0.25, 0.5, 1] # r_{l,l} 114 | radius_inter = [0.125, 0.25] # r_{l,l+1} = r_{l+1,l} 115 | 116 | level = len(m) # number of levels L 117 | print('resolution', s) 118 | 119 | splits = n // m[0] 120 | if splits * m[0] < n: 121 | splits = splits + 1 122 | 123 | ntrain = 1024 # number of training pairs N 124 | ntest = 100 # number of testing pairs 125 | 126 | # don't change this 127 | batch_size = 1 #train 128 | batch_size2 = 1 #test 129 | 130 | width = 64 #d_v 131 | ker_width = 256 #1024 132 | depth = 5 #T 133 | edge_features = 6 134 | node_features = 6 135 | 136 | epochs = 200 137 | learning_rate = 
0.0001 138 | scheduler_step = 20 139 | scheduler_gamma = 0.80 140 | 141 | 142 | TRAIN_PATH = 'data/piececonst_r'+str(s0)+'_N1024_smooth1.mat' 143 | TEST_PATH = 'data/piececonst_r'+str(s0)+'_N1024_smooth2.mat' 144 | 145 | path = 'multigraph_fullnew_s'+str(s)+'_ntrain'+str(ntrain)+'_kerwidth'+str(ker_width) 146 | path_model = 'model/' + path 147 | path_train_err = 'results/' + path + 'train.txt' 148 | path_test_err = 'results/' + path + 'test.txt' 149 | path_image = 'image/' + path 150 | 151 | 152 | ######################################################################## 153 | # 154 | # Read the data 155 | # 156 | ######################################################################## 157 | 158 | t1 = default_timer() 159 | 160 | reader = MatReader(TRAIN_PATH) 161 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 162 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 163 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 164 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 165 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 166 | 167 | reader.load_file(TEST_PATH) 168 | nstart = 0 169 | test_a = reader.read_field('coeff')[nstart:nstart+ntest,::r,::r].reshape(ntest,-1) 170 | test_a_smooth = reader.read_field('Kcoeff')[nstart:nstart+ntest,::r,::r].reshape(ntest,-1) 171 | test_a_gradx = reader.read_field('Kcoeff_x')[nstart:nstart+ntest,::r,::r].reshape(ntest,-1) 172 | test_a_grady = reader.read_field('Kcoeff_y')[nstart:nstart+ntest,::r,::r].reshape(ntest,-1) 173 | test_u = reader.read_field('sol')[nstart:nstart+ntest,::r,::r].reshape(ntest,-1) 174 | 175 | # normalize the data 176 | a_normalizer = GaussianNormalizer(train_a) 177 | train_a = a_normalizer.encode(train_a) 178 | test_a = a_normalizer.encode(test_a) 179 | as_normalizer = GaussianNormalizer(train_a_smooth) 180 | train_a_smooth = as_normalizer.encode(train_a_smooth) 181 | test_a_smooth = as_normalizer.encode(test_a_smooth) 182 | agx_normalizer = GaussianNormalizer(train_a_gradx) 183 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 184 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 185 | agy_normalizer = GaussianNormalizer(train_a_grady) 186 | train_a_grady = agy_normalizer.encode(train_a_grady) 187 | test_a_grady = agy_normalizer.encode(test_a_grady) 188 | 189 | u_normalizer = UnitGaussianNormalizer(train_u) 190 | train_u = u_normalizer.encode(train_u) 191 | # test_u = y_normalizer.encode(test_u) 192 | 193 | 194 | 195 | ######################################################################## 196 | # 197 | # Construct Graphs 198 | # 199 | ######################################################################## 200 | 201 | meshgenerator = RandomMultiMeshGenerator([[0,1],[0,1]],[s,s], level=level, sample_sizes=m) 202 | data_train = [] 203 | for j in range(ntrain): 204 | for i in range(k): 205 | idx, idx_all = meshgenerator.sample() 206 | grid, grid_all = meshgenerator.get_grid() 207 | edge_index, edge_index_down, edge_index_up = meshgenerator.ball_connectivity(radius_inner, radius_inter) 208 | edge_index_range, edge_index_down_range, edge_index_up_range = meshgenerator.get_edge_index_range() 209 | edge_attr, edge_attr_down, edge_attr_up = meshgenerator.attributes(theta=train_a[j,:]) 210 | x = torch.cat([grid_all, train_a[j, idx_all].reshape(-1, 1), 211 | train_a_smooth[j, idx_all].reshape(-1, 1), 212 | train_a_gradx[j, idx_all].reshape(-1, 1), 213 | train_a_grady[j, 
idx_all].reshape(-1, 1) 214 | ], dim=1) 215 | data_train.append(Data(x=x, y=train_u[j, idx[0]], 216 | edge_index_mid=edge_index, edge_index_down=edge_index_down, edge_index_up=edge_index_up, 217 | edge_index_range=edge_index_range, edge_index_down_range=edge_index_down_range, edge_index_up_range=edge_index_up_range, 218 | edge_attr_mid=edge_attr, edge_attr_down=edge_attr_down, edge_attr_up=edge_attr_up, 219 | sample_idx=idx[0])) 220 | 221 | print(x.shape, edge_index.shape, edge_index_down.shape, edge_index_range.shape, edge_attr.shape, edge_attr_down.shape) 222 | 223 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 224 | 225 | 226 | meshgenerator = RandomMultiMeshSplitter([[0,1],[0,1]],[s,s], level=level, sample_sizes=m) 227 | data_test = [] 228 | test_theta = torch.stack([test_a, test_a_smooth, test_a_gradx, test_a_grady], dim=2) 229 | for j in range(ntest): 230 | data = meshgenerator.splitter(radius_inner, radius_inter, test_a[j,:], test_theta[j,:,:]) 231 | test_loader = DataLoader(data, batch_size=batch_size2, shuffle=False) 232 | data_test.append(test_loader) 233 | 234 | 235 | t2 = default_timer() 236 | 237 | ######################################################################## 238 | # 239 | # Training 240 | # 241 | ######################################################################## 242 | 243 | 244 | print('preprocessing finished, time used:', t2-t1) 245 | device = torch.device('cuda') 246 | 247 | # print('use pre-train model') 248 | # model = torch.load('model/multigraph_fullnew_s'+str(s)+'_ntrain1024_kerwidth256_150') 249 | 250 | model = MKGN(width=width, ker_width=ker_width, depth=depth, ker_in=edge_features, 251 | points=m, level=level, in_width=node_features, out_width=1).cuda() 252 | 253 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 254 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 255 | 256 | myloss = LpLoss(size_average=False) 257 | ttrain = np.zeros((epochs, )) 258 | ttest = np.zeros((epochs,)) 259 | 260 | for ep in range(epochs): 261 | t1 = default_timer() 262 | train_mse = 0.0 263 | train_l2 = 0.0 264 | model.train() 265 | u_normalizer.cuda() 266 | for batch in train_loader: 267 | batch = batch.to(device) 268 | 269 | optimizer.zero_grad() 270 | out = model(batch) 271 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 272 | # mse.backward() 273 | 274 | loss = torch.norm(out.view(-1) - batch.y.view(-1),1) 275 | # loss.backward() 276 | 277 | l2 = myloss( 278 | u_normalizer.decode(out.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)), 279 | u_normalizer.decode(batch.y.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1))) 280 | l2.backward() 281 | 282 | optimizer.step() 283 | train_mse += mse.item() 284 | train_l2 += l2.item() 285 | 286 | scheduler.step() 287 | t2 = default_timer() 288 | ttrain[ep] = train_l2 / (ntrain * k) 289 | 290 | print(ep, t2 - t1, train_mse / len(train_loader), train_l2 / (ntrain * k)) 291 | 292 | torch.save(model, path_model) 293 | 294 | ######################################################################## 295 | # 296 | # Testing 297 | # 298 | ######################################################################## 299 | 300 | 301 | ep = epochs - 1 302 | model.eval() 303 | test_l2_all = 0.0 304 | test_l2_split = 0.0 305 | u_normalizer.cpu() 306 | with torch.no_grad(): 307 | t1 = default_timer() 308 | for i, test_loader in enumerate(data_test): 309 | out_list = [] 310 | 
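        # Each test sample was split by RandomMultiMeshSplitter into subgraphs
        # that jointly cover all grid nodes; the model is evaluated split by
        # split here, and the per-node predictions are stitched back together by
        # meshgenerator.assembler below, so the final L2 error is measured on
        # the full solution field.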
sample_idx_list = [] 311 | test_l2_split = 0.0 312 | for data in test_loader: 313 | data = data.to(device) 314 | out_split = model(data).cpu().detach() 315 | assert len(out_split) == len(data.sample_idx) 316 | out_split = u_normalizer.decode(out_split.view(batch_size2, -1), 317 | sample_idx=data.sample_idx.view(batch_size2, -1)) 318 | test_l2_split += myloss(out_split, test_u[i, data.sample_idx].view(batch_size2, -1)).item() 319 | 320 | out_list.append(out_split) 321 | sample_idx_list.append(data.sample_idx) 322 | 323 | 324 | 325 | out = meshgenerator.assembler(out_list, sample_idx_list) 326 | l2 = myloss(out.view(1, -1), test_u[i].view(batch_size2, -1)).item() 327 | test_l2_all += l2 328 | print('test i =',i, l2, test_l2_split / len(test_loader)) 329 | 330 | 331 | t2 = default_timer() 332 | print(ep, t2 - t1, test_l2_all / ntest) 333 | ttest[ep] = test_l2_all / ntest 334 | 335 | 336 | np.savetxt(path_train_err, ttrain) 337 | np.savetxt(path_test_err, ttest) 338 | 339 | 340 | -------------------------------------------------------------------------------- /multipole-graph-neural-operator/MGKN_orthogonal_burgers1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import matplotlib.pyplot as plt 6 | from utilities import * 7 | from torch_geometric.data import Data, DataLoader 8 | from torch_geometric.nn import NNConv 9 | from timeit import default_timer 10 | 11 | torch.manual_seed(0) 12 | np.random.seed(0) 13 | 14 | ######################################################################## 15 | # 16 | # The neural networks architecture 17 | # 18 | ######################################################################## 19 | 20 | 21 | class MGKN(torch.nn.Module): 22 | def __init__(self, width, ker_width, depth, ker_in, in_width, s): 23 | super(MGKN, self).__init__() 24 | self.depth = depth 25 | self.width = width 26 | self.s = s 27 | self.level = int(np.log2(s)-1) 28 | 29 | # P 30 | self.fc1 = torch.nn.Linear(in_width, width) 31 | 32 | # K_ll 33 | self.conv_list = [] 34 | for l in range(self.level + 1): 35 | ker_width_l = max( ker_width // (2**l), 16) 36 | kernel_l = DenseNet([ker_in, ker_width_l, ker_width_l, width ** 2], torch.nn.ReLU) 37 | self.conv_list.append(NNConv(width, width, kernel_l, aggr='mean')) 38 | self.conv_list = torch.nn.ModuleList(self.conv_list) 39 | 40 | # Q 41 | self.fc2 = torch.nn.Linear(width, ker_width) 42 | self.fc3 = torch.nn.Linear(ker_width, 1) 43 | 44 | 45 | # K_{l,l+1} 46 | def Upsample(self, x, channels, scale, s): 47 | x = x.transpose(0, 1).view(1,channels,s) # (K,width) to (1, width, s) 48 | x = F.upsample(x, scale_factor=scale, mode='nearest') # (1, width, s) to (1, width, s*2) 49 | x = x.view(channels, -1).transpose(0, 1) # (1, width, s*2, s*2) to (K*4, width) 50 | return x 51 | 52 | # K_{l+1,l} 53 | def Downsample(self, x, channels, scale, s): 54 | x = x.transpose(0, 1).view(1,channels,s) # (K,width) to (1, width, s) 55 | x = F.avg_pool1d(x, kernel_size=scale) 56 | x = x.view(channels, -1).transpose(0, 1) # (1, width, s/2, s/2) to (K/4, width) 57 | return x 58 | 59 | def forward(self, data): 60 | X_list,_, edge_index_list, edge_attr_list = data 61 | level = len(X_list) 62 | x = X_list[0] 63 | x = self.fc1(x) 64 | phi = [None] * level # list of x, len=level 65 | for k in range(self.depth): 66 | # downward 67 | for l in range(level): 68 | phi[l] = x 69 | if (l != level - 1): 70 | # downsample 71 | x = self.Downsample(x, 
channels=self.width, scale=2, s=self.s // (2 ** l) ) 72 | 73 | # upward 74 | x = F.relu(x + self.conv_list[-1](phi[-1], edge_index_list[-1], edge_attr_list[-1])) 75 | for l in reversed(range(level)): 76 | if (l != 0): 77 | # upsample 78 | x = self.Upsample(x, channels=self.width, scale=2, s=self.s // (2 ** l)) 79 | # interactive neighbors 80 | x = F.relu(x + self.conv_list[l](phi[l-1], edge_index_list[l], edge_attr_list[l])) 81 | else: 82 | x = F.relu(x + self.conv_list[0](phi[0], edge_index_list[0], edge_attr_list[0])) 83 | 84 | x = F.relu(self.fc2(x)) 85 | x = self.fc3(x) 86 | return x 87 | 88 | 89 | ######################################################################## 90 | # 91 | # Hyperparameters 92 | # 93 | ######################################################################## 94 | 95 | 96 | r = 8 #downsample 97 | s = 2**13//r #grid size after downsample 98 | 99 | ntrain = 1024 # number of training pairs N 100 | ntest = 100 # number of testing pairs 101 | 102 | batch_size = 1 #train 103 | batch_size2 = 1 #test 104 | width = 64 #d_v 105 | ker_width = 1024 #1024 106 | depth = 4 #T 107 | edge_features = 4 108 | theta_d = 1 109 | node_features = 1 + theta_d 110 | 111 | 112 | epochs = 200 113 | learning_rate = 0.00001 114 | scheduler_step = 10 115 | scheduler_gamma = 0.80 116 | 117 | 118 | TRAIN_PATH = 'data/burgers_data_R10.mat' 119 | TEST_PATH = 'data/burgers_data_R10.mat' 120 | 121 | path = 'multipole_burgersR10_s'+str(s)+'_ntrain'+str(ntrain)+'_kerwidth'+str(ker_width) 122 | path_model = 'model/'+path 123 | path_train_err = 'results/'+path+'train.txt' 124 | path_test_err = 'results/'+path+'test.txt' 125 | path_image = 'image/'+path 126 | 127 | 128 | 129 | ######################################################################## 130 | # 131 | # Read the data 132 | # 133 | ######################################################################## 134 | 135 | 136 | reader = MatReader(TRAIN_PATH) 137 | train_a = reader.read_field('a')[:ntrain,::r].reshape(ntrain,-1) 138 | train_u = reader.read_field('u')[:ntrain,::r].reshape(ntrain,-1) 139 | 140 | reader.load_file(TEST_PATH) 141 | test_a = reader.read_field('a')[-ntest:,::r].reshape(ntest,-1) 142 | test_u = reader.read_field('u')[-ntest:,::r].reshape(ntest,-1) 143 | 144 | 145 | a_normalizer = GaussianNormalizer(train_a) 146 | train_a = a_normalizer.encode(train_a) 147 | test_a = a_normalizer.encode(test_a) 148 | 149 | u_normalizer = UnitGaussianNormalizer(train_u) 150 | train_u = u_normalizer.encode(train_u) 151 | # test_u = y_normalizer.encode(test_u) 152 | 153 | train_theta = train_a.reshape(ntrain,s,1) 154 | test_theta = test_a.reshape(ntest,s,1) 155 | 156 | 157 | ######################################################################## 158 | # 159 | # Construct Graphs 160 | # 161 | ######################################################################## 162 | 163 | 164 | 165 | grid_list, train_theta_list, edge_index_list, edge_index_list_cuda = multi_pole_grid1d(theta = train_theta, theta_d=theta_d,s=s, N=ntrain, is_periodic=True) 166 | grid_list, test_theta_list, edge_index_list, edge_index_list_cuda = multi_pole_grid1d(theta = test_theta, theta_d=theta_d,s=s, N=ntest, is_periodic=True) 167 | 168 | data_train = [] 169 | for j in range(ntrain): 170 | X_list = [] 171 | edge_attr_list = [] 172 | for l in range(len(grid_list)): 173 | X_l = torch.cat([grid_list[l].reshape(-1, 1), train_theta_list[l][j].reshape(-1, theta_d)], dim=1).cuda() 174 | X_list.append(X_l) 175 | for i in range(len(edge_index_list)): 176 | if i==0: 177 | l = 0 178 | 
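        # get_edge_attr builds the attributes for edge_index_list[i] from grid
        # level l = i - 1; the first entry (i == 0) is the finest graph and
        # reuses level 0.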
else: 179 | l = i-1 180 | edge_attr_l = get_edge_attr(grid_list[l], train_theta_list[l][j,:,0], edge_index_list[i]).cuda() 181 | edge_attr_list.append(edge_attr_l) 182 | 183 | data_train.append((X_list, train_u[j].cuda(), edge_index_list_cuda, edge_attr_list)) 184 | 185 | # train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 186 | 187 | data_test = [] 188 | for j in range(ntest): 189 | X_list = [] 190 | edge_attr_list = [] 191 | for l in range(len(grid_list)): 192 | X_l = torch.cat([grid_list[l].reshape(-1, 1), test_theta_list[l][j].reshape(-1, theta_d)], dim=1).cuda() 193 | X_list.append(X_l) 194 | for i in range(len(edge_index_list)): 195 | if i==0: 196 | l = 0 197 | else: 198 | l = i-1 199 | edge_attr_l = get_edge_attr(grid_list[l], test_theta_list[l][j,:,0], edge_index_list[i]).cuda() 200 | edge_attr_list.append(edge_attr_l) 201 | 202 | data_test.append((X_list, test_u[j].cuda(), edge_index_list_cuda, edge_attr_list)) 203 | 204 | 205 | ######################################################################## 206 | # 207 | # Training 208 | # 209 | ######################################################################## 210 | 211 | # print('use pre-train model') 212 | # model = torch.load('model/multipole_burgersR10_s8192_ntrain1024_kerwidth1024') 213 | 214 | model = MGKN(width, ker_width, depth, edge_features, in_width=node_features, s=s).cuda() 215 | 216 | 217 | 218 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 219 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 220 | 221 | myloss = LpLoss(size_average=False) 222 | u_normalizer.cuda() 223 | ttrain = np.zeros((epochs,)) 224 | ttest = np.zeros((epochs,)) 225 | model.train() 226 | for ep in range(epochs): 227 | t1 = default_timer() 228 | train_mse = 0.0 229 | train_l2 = 0.0 230 | for data in data_train: 231 | X_list, y, edge_index_list, edge_attr_list = data 232 | 233 | optimizer.zero_grad() 234 | out = model(data) 235 | mse = F.mse_loss(out.view(-1, 1), y.view(-1, 1)) 236 | # mse.backward() 237 | l2_loss = myloss(u_normalizer.decode(out.view(1, -1)), 238 | u_normalizer.decode(y.view(1, -1))) 239 | l2_loss.backward() 240 | train_l2 += l2_loss.item() 241 | 242 | optimizer.step() 243 | train_mse += mse.item() 244 | 245 | scheduler.step() 246 | t2 = default_timer() 247 | print(ep, t2 - t1, train_mse / len(data_train), train_l2 / len(data_train)) 248 | ttrain[ep] = train_l2 / len(data_train) 249 | 250 | torch.save(model, path_model) 251 | 252 | ######################################################################## 253 | # 254 | # Testing 255 | # 256 | ######################################################################## 257 | 258 | 259 | model.eval() 260 | test_l2 = 0.0 261 | with torch.no_grad(): 262 | t1 = default_timer() 263 | for i, data in enumerate(data_test): 264 | X_list, y, edge_index_list, edge_attr_list = data 265 | out = model(data) 266 | out = u_normalizer.decode(out.view(1, -1)) 267 | loss = myloss(out, y.view(1, -1)).item() 268 | test_l2 += loss 269 | print(i, loss) 270 | 271 | # resolution = s 272 | # coeff = test_a[i] 273 | # truth = y.detach().cpu().numpy() 274 | # approx = out.detach().cpu().numpy() 275 | # 276 | # np.savetxt('results/coeff'+str(i)+'.txt', coeff) 277 | # np.savetxt('results/truth' + str(i) + '.txt', truth) 278 | # np.savetxt('results/approx' + str(i) + '.txt', approx) 279 | 280 | 281 | t2 = default_timer() 282 | 283 | print(epochs, t2 - t1, test_l2 / ntest) 284 | ttest[0] = test_l2 
/ ntest 285 | 286 | 287 | np.savetxt(path_train_err, ttrain) 288 | np.savetxt(path_test_err, ttest) 289 | 290 | -------------------------------------------------------------------------------- /multipole-graph-neural-operator/README.md: -------------------------------------------------------------------------------- 1 | # Multipole Graph Kernel Network (MGKN) 2 | 3 | The code for "Multipole Graph Neural Operator for Parametric Partial Differential Equations". 4 | Inspired by the classical multipole methods, 5 | we propose a novel multi-level graph neural network framework 6 | that captures interaction at all ranges with only linear complexity. 7 | Our multi-level formulation is equivalent 8 | to recursively adding inducing points to the kernel matrix, 9 | unifying GNNs with multi-resolution matrix factorization of the kernel. 10 | Experiments confirm our multi-graph network 11 | learns discretization-invariant solution operators to PDEs 12 | and can be evaluated in linear time. 13 | 14 | ## Requirements 15 | - [PyTorch](https://pytorch.org/) 16 | - [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) 17 | 18 | 19 | ### Files 20 | - multigraph1.py corresponds to section 4.1 21 | - multigraph2.py corresponds to section 4.2 22 | - utilities.py contains helper functions. 23 | 24 | ### Usage 25 | ``` 26 | python multigraph1.py 27 | ``` 28 | ``` 29 | python multigraph2.py 30 | ``` 31 | -------------------------------------------------------------------------------- /multipole-graph-neural-operator/neurips1_GKN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | from torch_geometric.data import Data, DataLoader 8 | import matplotlib.pyplot as plt 9 | from utilities import * 10 | from nn_conv import NNConv, NNConv_old 11 | 12 | from timeit import default_timer 13 | import scipy.io 14 | 15 | 16 | class KernelNN3(torch.nn.Module): 17 | def __init__(self, width_node, width_kernel, depth, ker_in, in_width=1, out_width=1): 18 | super(KernelNN3, self).__init__() 19 | self.depth = depth 20 | 21 | self.fc1 = torch.nn.Linear(in_width, width_node) 22 | 23 | kernel = DenseNet([ker_in, width_kernel // 2, width_kernel, width_node**2], torch.nn.ReLU) 24 | self.conv1 = NNConv_old(width_node, width_node, kernel, aggr='mean') 25 | 26 | self.fc2 = torch.nn.Linear(width_node, 1) 27 | 28 | def forward(self, data): 29 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 30 | x = self.fc1(x) 31 | for k in range(self.depth): 32 | x = self.conv1(x, edge_index, edge_attr) 33 | if k != self.depth - 1: 34 | x = F.relu(x) 35 | 36 | x = self.fc2(x) 37 | return x 38 | 39 | 40 | TRAIN_PATH = 'data/piececonst_r241_N1024_smooth1.mat' 41 | TEST_PATH = 'data/piececonst_r241_N1024_smooth2.mat' 42 | 43 | ms = [200] 44 | for case in range(1): 45 | r = 1 46 | s = int(((241 - 1)/r) + 1) 47 | n = s**2 48 | m = ms[case] 49 | k = 1 50 | 51 | radius_train = 0.2 52 | radius_test = 0.2 53 | print('resolution', s) 54 | 55 | 56 | ntrain = 100 57 | ntest = 100 58 | 59 | 60 | batch_size = 1 61 | batch_size2 = 1 62 | width = 64 63 | ker_width = 256 64 | depth = 4 65 | edge_features = 6 66 | node_features = 6 67 | 68 | epochs = 100 69 | learning_rate = 0.0001 70 | scheduler_step = 50 71 | scheduler_gamma = 0.5 72 | 73 | 74 | path = 'neurips1_GKN_s'+str(s)+'_ntrain'+str(ntrain)+'_kerwidth'+str(ker_width) + '_m0' + str(m) 75 | path_model = 'model/' + path 76 | path_train_err = 
'results/' + path + 'train.txt' 77 | path_test_err = 'results/' + path + 'test.txt' 78 | path_runtime = 'results/' + path + 'time.txt' 79 | path_image = 'results/' + path 80 | 81 | runtime = np.zeros(2, ) 82 | t1 = default_timer() 83 | 84 | 85 | reader = MatReader(TRAIN_PATH) 86 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 87 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 88 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 89 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 90 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 91 | 92 | reader.load_file(TEST_PATH) 93 | test_a = reader.read_field('coeff')[:ntest,::r,::r].reshape(ntest,-1) 94 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,::r,::r].reshape(ntest,-1) 95 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::r,::r].reshape(ntest,-1) 96 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::r,::r].reshape(ntest,-1) 97 | test_u = reader.read_field('sol')[:ntest,::r,::r].reshape(ntest,-1) 98 | 99 | 100 | a_normalizer = GaussianNormalizer(train_a) 101 | train_a = a_normalizer.encode(train_a) 102 | test_a = a_normalizer.encode(test_a) 103 | as_normalizer = GaussianNormalizer(train_a_smooth) 104 | train_a_smooth = as_normalizer.encode(train_a_smooth) 105 | test_a_smooth = as_normalizer.encode(test_a_smooth) 106 | agx_normalizer = GaussianNormalizer(train_a_gradx) 107 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 108 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 109 | agy_normalizer = GaussianNormalizer(train_a_grady) 110 | train_a_grady = agy_normalizer.encode(train_a_grady) 111 | test_a_grady = agy_normalizer.encode(test_a_grady) 112 | 113 | u_normalizer = UnitGaussianNormalizer(train_u) 114 | train_u = u_normalizer.encode(train_u) 115 | # test_u = y_normalizer.encode(test_u) 116 | 117 | 118 | 119 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=m) 120 | data_train = [] 121 | for j in range(ntrain): 122 | for i in range(k): 123 | idx = meshgenerator.sample() 124 | grid = meshgenerator.get_grid() 125 | edge_index = meshgenerator.ball_connectivity(radius_train) 126 | edge_attr = meshgenerator.attributes(theta=train_a[j,:]) 127 | #data_train.append(Data(x=init_point.clone().view(-1,1), y=train_y[j,:], edge_index=edge_index, edge_attr=edge_attr)) 128 | data_train.append(Data(x=torch.cat([grid, train_a[j, idx].reshape(-1, 1), 129 | train_a_smooth[j, idx].reshape(-1, 1), train_a_gradx[j, idx].reshape(-1, 1), 130 | train_a_grady[j, idx].reshape(-1, 1) 131 | ], dim=1), 132 | y=train_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 133 | )) 134 | 135 | 136 | meshgenerator = RandomMeshGenerator([[0,1],[0,1]],[s,s], sample_size=m) 137 | data_test = [] 138 | for j in range(ntest): 139 | idx = meshgenerator.sample() 140 | grid = meshgenerator.get_grid() 141 | edge_index = meshgenerator.ball_connectivity(radius_test) 142 | edge_attr = meshgenerator.attributes(theta=test_a[j,:]) 143 | data_test.append(Data(x=torch.cat([grid, test_a[j, idx].reshape(-1, 1), 144 | test_a_smooth[j, idx].reshape(-1, 1), test_a_gradx[j, idx].reshape(-1, 1), 145 | test_a_grady[j, idx].reshape(-1, 1) 146 | ], dim=1), 147 | y=test_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx 148 | )) 149 | # 150 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 151 | test_loader = DataLoader(data_test, 
batch_size=batch_size2, shuffle=False) 152 | 153 | t2 = default_timer() 154 | 155 | print('preprocessing finished, time used:', t2-t1) 156 | device = torch.device('cuda') 157 | 158 | model = KernelNN3(width, ker_width,depth,edge_features,in_width=node_features).cuda() 159 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 160 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 161 | 162 | myloss = LpLoss(size_average=False) 163 | u_normalizer.cuda() 164 | ttrain = np.zeros((epochs, )) 165 | ttest = np.zeros((epochs,)) 166 | model.train() 167 | for ep in range(epochs): 168 | t1 = default_timer() 169 | train_mse = 0.0 170 | train_l2 = 0.0 171 | for batch in train_loader: 172 | batch = batch.to(device) 173 | 174 | optimizer.zero_grad() 175 | out = model(batch) 176 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 177 | mse.backward() 178 | 179 | l2 = myloss( 180 | u_normalizer.decode(out.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)), 181 | u_normalizer.decode(batch.y.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1))) 182 | optimizer.step() 183 | train_mse += mse.item() 184 | train_l2 += l2.item() 185 | 186 | scheduler.step() 187 | t2 = default_timer() 188 | 189 | model.eval() 190 | test_l2 = 0.0 191 | with torch.no_grad(): 192 | for batch in test_loader: 193 | batch = batch.to(device) 194 | out = model(batch) 195 | out = u_normalizer.decode(out.view(batch_size2,-1), sample_idx=batch.sample_idx.view(batch_size2,-1)) 196 | test_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 197 | # test_l2 += myloss(out.view(batch_size2,-1), y_normalizer.encode(batch.y.view(batch_size2, -1))).item() 198 | 199 | t3 = default_timer() 200 | ttrain[ep] = train_l2/(ntrain * k) 201 | ttest[ep] = test_l2/ntest 202 | 203 | print(k, ntrain, ep, t2-t1, t3-t2, train_mse/len(train_loader), train_l2/(ntrain * k), test_l2/ntest) 204 | 205 | runtime[0] = t2-t1 206 | runtime[1] = t3-t2 207 | np.savetxt(path_train_err, ttrain) 208 | np.savetxt(path_test_err, ttest) 209 | np.savetxt(path_runtime, runtime) 210 | torch.save(model, path_model) 211 | 212 | 213 | -------------------------------------------------------------------------------- /multipole-graph-neural-operator/neurips1_MGKN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | from torch_geometric.data import Data, DataLoader 8 | import matplotlib.pyplot as plt 9 | from utilities import * 10 | from torch_geometric.nn import GCNConv, NNConv 11 | 12 | from timeit import default_timer 13 | import scipy.io 14 | 15 | torch.manual_seed(0) 16 | np.random.seed(0) 17 | 18 | 19 | 20 | class KernelInduced(torch.nn.Module): 21 | def __init__(self, width, ker_width, depth, ker_in, points, level, in_width=1, out_width=1): 22 | super(KernelInduced, self).__init__() 23 | self.depth = depth 24 | self.width = width 25 | self.level = level 26 | self.points = points 27 | self.points_total = np.sum(points) 28 | 29 | # in 30 | self.fc_in = torch.nn.Linear(in_width, width) 31 | # self.fc_in_list = [] 32 | # for l in range(level): 33 | # self.fc_in_list.append(torch.nn.Linear(in_width, width)) 34 | # self.fc_in_list = torch.nn.ModuleList(self.fc_in_list) 35 | 36 | # K12 K23 K34 ... 
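        # The three ModuleLists built below mirror the block structure of the
        # multi-level kernel matrix: conv_down_list holds the transition kernels
        # K_{l,l+1} (fine -> coarse), conv_list the within-level kernels K_{ll},
        # and conv_up_list the transitions K_{l+1,l} back up. Each block is an
        # edge-conditioned NNConv whose width x width weight matrix is generated
        # per edge by a small DenseNet from the edge features, and the kernel
        # network width is halved at every level (ker_width // 2**l), so coarser
        # levels are progressively cheaper.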
37 | self.conv_down_list = [] 38 | for l in range(1, level): 39 | ker_width_l = ker_width // (2 ** l) 40 | kernel_l = DenseNet([ker_in, ker_width_l, width ** 2], torch.nn.ReLU) 41 | self.conv_down_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=False, bias=False)) 42 | self.conv_down_list = torch.nn.ModuleList(self.conv_down_list) 43 | 44 | # K11 K22 K33 45 | self.conv_list = [] 46 | for l in range(level): 47 | ker_width_l = ker_width // (2 ** l) 48 | kernel_l = DenseNet([ker_in, ker_width_l, ker_width_l, width ** 2], torch.nn.ReLU) 49 | self.conv_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=False, bias=False)) 50 | self.conv_list = torch.nn.ModuleList(self.conv_list) 51 | 52 | # K21 K32 K43 53 | self.conv_up_list = [] 54 | for l in range(1, level): 55 | ker_width_l = ker_width // (2 ** l) 56 | kernel_l = DenseNet([ker_in, ker_width_l, width ** 2], torch.nn.ReLU) 57 | self.conv_up_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=False, bias=False)) 58 | self.conv_up_list = torch.nn.ModuleList(self.conv_up_list) 59 | 60 | # out 61 | self.fc_out1 = torch.nn.Linear(width, ker_width) 62 | self.fc_out2 = torch.nn.Linear(ker_width, 1) 63 | 64 | 65 | def forward(self, data): 66 | edge_index_down, edge_attr_down, range_down = data.edge_index_down, data.edge_attr_down, data.edge_index_down_range 67 | edge_index_mid, edge_attr_mid, range_mid = data.edge_index_mid, data.edge_attr_mid, data.edge_index_range 68 | edge_index_up, edge_attr_up, range_up = data.edge_index_up, data.edge_attr_up, data.edge_index_up_range 69 | 70 | x = self.fc_in(data.x) 71 | 72 | for t in range(self.depth): 73 | #downward 74 | for l in range(self.level-1): 75 | x = x + self.conv_down_list[l](x, edge_index_down[:,range_down[l,0]:range_down[l,1]], edge_attr_down[range_down[l,0]:range_down[l,1],:]) 76 | x = F.relu(x) 77 | 78 | #upward 79 | for l in reversed(range(self.level)): 80 | x = x + self.conv_list[l](x, edge_index_mid[:,range_mid[l,0]:range_mid[l,1]], edge_attr_mid[range_mid[l,0]:range_mid[l,1],:]) 81 | x = F.relu(x) 82 | if l > 0: 83 | x = x + self.conv_up_list[l-1](x, edge_index_up[:,range_up[l-1,0]:range_up[l-1,1]], edge_attr_up[range_up[l-1,0]:range_up[l-1,1],:]) 84 | x = F.relu(x) 85 | 86 | 87 | x = F.relu(self.fc_out1(x[:self.points[0]])) 88 | x = self.fc_out2(x) 89 | return x 90 | 91 | 92 | 93 | TRAIN_PATH = 'data/piececonst_r241_N1024_smooth1.mat' 94 | TEST_PATH = 'data/piececonst_r241_N1024_smooth2.mat' 95 | 96 | 97 | r = 1 98 | s = int(((241 - 1)/r) + 1) 99 | n = s**2 100 | k = 1 101 | 102 | # this is too large 103 | # m = [6400, 1600, 400, 100, 25] 104 | # radius_inner = [0.5/16, 0.5/8, 0.5/4, 0.5/2, 0.5] 105 | # radius_inter = [0.5/16 * 1.41, 0.5/8* 1.41, 0.5/4* 1.41, 0.5/2* 1.41] 106 | 107 | for case in range(1): 108 | 109 | print('!!!!!!!!!!!!!! 
case ', case, ' !!!!!!!!!!!!!!!!!!!!!!!!') 110 | 111 | if case == 0: 112 | m = [2400, 1600, 400, 100, 25] 113 | radius_inner = [0.5/8 * 1.41, 0.5/8, 0.5/4, 0.5/2, 0.5] 114 | radius_inter = [0.5/8 * 1.1 , 0.5/8* 1.41, 0.5/4* 1.41, 0.5/2* 1.41] 115 | 116 | # if case == 0: 117 | # m = [1600, 400, 100, 25] 118 | # radius_inner = [ 0.5/8, 0.5/4, 0.5/2, 0.5] 119 | # radius_inter = [0.5/8* 1.41, 0.5/4* 1.41, 0.5/2* 1.41] 120 | 121 | if case == 1: 122 | m = [400, 100, 25] 123 | radius_inner = [0.5/4, 0.5/2, 0.5] 124 | radius_inter = [0.5/4* 1.41, 0.5/2* 1.41] 125 | 126 | if case == 2: 127 | m = [ 100, 25] 128 | radius_inner = [0.5/2, 0.5] 129 | radius_inter = [0.5/2* 1.41] 130 | 131 | level = len(m) 132 | print('resolution', s) 133 | 134 | ntrain = 100 135 | ntest = 100 136 | 137 | # don't change this 138 | batch_size = 1 139 | batch_size2 = 1 140 | 141 | width = 64 142 | ker_width = 256 143 | depth = 4 144 | edge_features = 6 145 | node_features = 6 146 | 147 | epochs = 200 148 | learning_rate = 0.1 / ntrain 149 | scheduler_step = 10 150 | scheduler_gamma = 0.8 151 | 152 | 153 | 154 | path = 'neurips1_multigraph_s'+str(s)+'_ntrain'+str(ntrain)+'_kerwidth'+str(ker_width) + '_m0' + str(m[0]) 155 | path_model = 'model/' + path 156 | path_train_err = 'results/' + path + 'train.txt' 157 | path_test_err = 'results/' + path + 'test.txt' 158 | path_runtime = 'results/' + path + 'time.txt' 159 | path_image = 'results/' + path 160 | 161 | runtime = np.zeros(2,) 162 | 163 | t1 = default_timer() 164 | 165 | 166 | reader = MatReader(TRAIN_PATH) 167 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 168 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 169 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 170 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 171 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 172 | 173 | reader.load_file(TEST_PATH) 174 | test_a = reader.read_field('coeff')[:ntest,::r,::r].reshape(ntest,-1) 175 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,::r,::r].reshape(ntest,-1) 176 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::r,::r].reshape(ntest,-1) 177 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::r,::r].reshape(ntest,-1) 178 | test_u = reader.read_field('sol')[:ntest,::r,::r].reshape(ntest,-1) 179 | 180 | 181 | a_normalizer = GaussianNormalizer(train_a) 182 | train_a = a_normalizer.encode(train_a) 183 | test_a = a_normalizer.encode(test_a) 184 | as_normalizer = GaussianNormalizer(train_a_smooth) 185 | train_a_smooth = as_normalizer.encode(train_a_smooth) 186 | test_a_smooth = as_normalizer.encode(test_a_smooth) 187 | agx_normalizer = GaussianNormalizer(train_a_gradx) 188 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 189 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 190 | agy_normalizer = GaussianNormalizer(train_a_grady) 191 | train_a_grady = agy_normalizer.encode(train_a_grady) 192 | test_a_grady = agy_normalizer.encode(test_a_grady) 193 | 194 | u_normalizer = UnitGaussianNormalizer(train_u) 195 | train_u = u_normalizer.encode(train_u) 196 | # test_u = y_normalizer.encode(test_u) 197 | 198 | meshgenerator = RandomMultiMeshGenerator([[0, 1], [0, 1]], [s, s], level=level, sample_sizes=m) 199 | data_train = [] 200 | for j in range(ntrain): 201 | for i in range(k): 202 | idx, idx_all = meshgenerator.sample() 203 | grid, grid_all = meshgenerator.get_grid() 204 | edge_index, edge_index_down, 
edge_index_up = meshgenerator.ball_connectivity(radius_inner, radius_inter) 205 | edge_index_range, edge_index_down_range, edge_index_up_range = meshgenerator.get_edge_index_range() 206 | edge_attr, edge_attr_down, edge_attr_up = meshgenerator.attributes(theta=train_a[j, :]) 207 | x = torch.cat([grid_all, train_a[j, idx_all].reshape(-1, 1), 208 | train_a_smooth[j, idx_all].reshape(-1, 1), 209 | train_a_gradx[j, idx_all].reshape(-1, 1), 210 | train_a_grady[j, idx_all].reshape(-1, 1) 211 | ], dim=1) 212 | data_train.append(Data(x=x, y=train_u[j, idx[0]], 213 | edge_index_mid=edge_index, edge_index_down=edge_index_down, 214 | edge_index_up=edge_index_up, 215 | edge_index_range=edge_index_range, edge_index_down_range=edge_index_down_range, 216 | edge_index_up_range=edge_index_up_range, 217 | edge_attr_mid=edge_attr, edge_attr_down=edge_attr_down, edge_attr_up=edge_attr_up, 218 | sample_idx=idx[0])) 219 | 220 | print(x.shape) 221 | print(edge_index_range) 222 | print(edge_index_down_range) 223 | print(edge_index_up_range) 224 | 225 | print(edge_index.shape, edge_attr.shape) 226 | print(edge_index_down.shape, edge_attr_down.shape) 227 | print(edge_index_up.shape, edge_attr_up.shape) 228 | 229 | meshgenerator = RandomMultiMeshGenerator([[0, 1], [0, 1]], [s, s], level=level, sample_sizes=m) 230 | data_test = [] 231 | for j in range(ntest): 232 | for i in range(k): 233 | idx, idx_all = meshgenerator.sample() 234 | grid, grid_all = meshgenerator.get_grid() 235 | edge_index, edge_index_down, edge_index_up = meshgenerator.ball_connectivity(radius_inner, radius_inter) 236 | edge_index_range, edge_index_down_range, edge_index_up_range = meshgenerator.get_edge_index_range() 237 | edge_attr, edge_attr_down, edge_attr_up = meshgenerator.attributes(theta=test_a[j, :]) 238 | x = torch.cat([grid_all, test_a[j, idx_all].reshape(-1, 1), 239 | test_a_smooth[j, idx_all].reshape(-1, 1), 240 | test_a_gradx[j, idx_all].reshape(-1, 1), 241 | test_a_grady[j, idx_all].reshape(-1, 1) 242 | ], dim=1) 243 | data_test.append(Data(x=x, y=test_u[j, idx[0]], 244 | edge_index_mid=edge_index, edge_index_down=edge_index_down, 245 | edge_index_up=edge_index_up, 246 | edge_index_range=edge_index_range, edge_index_down_range=edge_index_down_range, 247 | edge_index_up_range=edge_index_up_range, 248 | edge_attr_mid=edge_attr, edge_attr_down=edge_attr_down, edge_attr_up=edge_attr_up, 249 | sample_idx=idx[0])) 250 | # 251 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 252 | test_loader = DataLoader(data_test, batch_size=batch_size2, shuffle=False) 253 | 254 | t2 = default_timer() 255 | 256 | print('preprocessing finished, time used:', t2-t1) 257 | device = torch.device('cuda') 258 | 259 | # print('use pre-train model') 260 | # model = torch.load('model/multigraph2241_n100') 261 | 262 | model = KernelInduced(width=width, ker_width=ker_width, depth=depth, ker_in=edge_features, 263 | points=m, level=level, in_width=node_features, out_width=1).cuda() 264 | 265 | # model = KernelInduced_SUM(width=width, ker_width=ker_width, depth=depth, ker_in=edge_features, 266 | # points=m, level=level, in_width=node_features, out_width=1).cuda() 267 | 268 | 269 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 270 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 271 | 272 | myloss = LpLoss(size_average=False) 273 | u_normalizer.cuda() 274 | ttrain = np.zeros((epochs, )) 275 | ttest = np.zeros((epochs,)) 276 | model.train() 277 | 
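    # The loop below backpropagates the relative L2 loss computed on
    # de-normalized outputs; the MSE is accumulated only for the per-epoch log
    # line. For reference, a minimal sketch of the per-sample error that LpLoss
    # (p=2, relative, as defined in utilities.py) reduces to -- an illustrative
    # helper only, not part of the original script and not called below:
    def _relative_l2(pred, truth):
        # ||pred - truth||_2 / ||truth||_2
        return torch.norm(pred - truth) / torch.norm(truth)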
for ep in range(epochs): 278 | t1 = default_timer() 279 | train_mse = 0.0 280 | train_l2 = 0.0 281 | for batch in train_loader: 282 | batch = batch.to(device) 283 | 284 | optimizer.zero_grad() 285 | out = model(batch) 286 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 287 | # mse.backward() 288 | 289 | l2 = myloss( 290 | u_normalizer.decode(out.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)), 291 | u_normalizer.decode(batch.y.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1))) 292 | l2.backward() 293 | 294 | optimizer.step() 295 | train_mse += mse.item() 296 | train_l2 += l2.item() 297 | 298 | scheduler.step() 299 | t2 = default_timer() 300 | ttrain[ep] = train_l2 / (ntrain * k) 301 | 302 | print(ep, t2 - t1, train_mse / len(train_loader), train_l2 / (ntrain * k)) 303 | 304 | runtime[0] = t2 - t1 305 | 306 | t1 = default_timer() 307 | model.eval() 308 | test_l2 = 0.0 309 | with torch.no_grad(): 310 | for batch in test_loader: 311 | batch = batch.to(device) 312 | out = model(batch) 313 | out = u_normalizer.decode(out.view(batch_size2, -1), sample_idx=batch.sample_idx.view(batch_size2, -1)) 314 | test_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 315 | # test_l2 += myloss(out.view(batch_size2,-1), y_normalizer.encode(batch.y.view(batch_size2, -1))).item() 316 | 317 | ttest[ep] = test_l2 / ntest 318 | t2 = default_timer() 319 | print(ep, t2 - t1, test_l2 / ntest) 320 | 321 | runtime[1] = t2 - t1 322 | 323 | np.savetxt(path_train_err, ttrain) 324 | np.savetxt(path_test_err, ttest) 325 | np.savetxt(path_runtime, runtime) 326 | torch.save(model, path_model) 327 | 328 | -------------------------------------------------------------------------------- /multipole-graph-neural-operator/neurips2_MGKN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | from torch_geometric.data import Data, DataLoader 8 | import matplotlib.pyplot as plt 9 | from utilities import * 10 | from torch_geometric.nn import GCNConv, NNConv 11 | 12 | from timeit import default_timer 13 | import scipy.io 14 | 15 | torch.manual_seed(0) 16 | np.random.seed(0) 17 | 18 | 19 | 20 | class KernelInduced(torch.nn.Module): 21 | def __init__(self, width, ker_width, depth, ker_in, points, level, in_width=1, out_width=1): 22 | super(KernelInduced, self).__init__() 23 | self.depth = depth 24 | self.width = width 25 | self.level = level 26 | self.points = points 27 | self.points_total = np.sum(points) 28 | 29 | # in 30 | self.fc_in = torch.nn.Linear(in_width, width) 31 | # self.fc_in_list = [] 32 | # for l in range(level): 33 | # self.fc_in_list.append(torch.nn.Linear(in_width, width)) 34 | # self.fc_in_list = torch.nn.ModuleList(self.fc_in_list) 35 | 36 | # K12 K23 K34 ... 
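        # Note: this script is the single-graph ablation. All three ModuleLists
        # are constructed exactly as in neurips1_MGKN.py, but the forward pass
        # below applies only the finest-level kernel (the multi-level V-cycle is
        # left commented out), so conv_down_list and conv_up_list go unused here.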
37 | self.conv_down_list = [] 38 | for l in range(1, level): 39 | ker_width_l = ker_width // (2 ** l) 40 | kernel_l = DenseNet([ker_in, ker_width_l, width ** 2], torch.nn.ReLU) 41 | self.conv_down_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=False, bias=False)) 42 | self.conv_down_list = torch.nn.ModuleList(self.conv_down_list) 43 | 44 | # K11 K22 K33 45 | self.conv_list = [] 46 | for l in range(level): 47 | ker_width_l = ker_width // (2 ** l) 48 | kernel_l = DenseNet([ker_in, ker_width_l, ker_width_l, width ** 2], torch.nn.ReLU) 49 | self.conv_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=False, bias=False)) 50 | self.conv_list = torch.nn.ModuleList(self.conv_list) 51 | 52 | # K21 K32 K43 53 | self.conv_up_list = [] 54 | for l in range(1, level): 55 | ker_width_l = ker_width // (2 ** l) 56 | kernel_l = DenseNet([ker_in, ker_width_l, width ** 2], torch.nn.ReLU) 57 | self.conv_up_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=False, bias=False)) 58 | self.conv_up_list = torch.nn.ModuleList(self.conv_up_list) 59 | 60 | # out 61 | self.fc_out1 = torch.nn.Linear(width, ker_width) 62 | self.fc_out2 = torch.nn.Linear(ker_width, 1) 63 | 64 | 65 | def forward(self, data): 66 | edge_index_down, edge_attr_down, range_down = data.edge_index_down, data.edge_attr_down, data.edge_index_down_range 67 | edge_index_mid, edge_attr_mid, range_mid = data.edge_index_mid, data.edge_attr_mid, data.edge_index_range 68 | edge_index_up, edge_attr_up, range_up = data.edge_index_up, data.edge_attr_up, data.edge_index_up_range 69 | 70 | x = self.fc_in(data.x) 71 | 72 | for t in range(self.depth): 73 | # if single graph 74 | l = 0 75 | x = x + self.conv_list[l](x, edge_index_mid[:, range_mid[l, 0]:range_mid[l, 1]], 76 | edge_attr_mid[range_mid[l, 0]:range_mid[l, 1], :]) 77 | x = F.relu(x) 78 | # #downward 79 | # for l in range(self.level-1): 80 | # x = x + self.conv_down_list[l](x, edge_index_down[:,range_down[l,0]:range_down[l,1]], edge_attr_down[range_down[l,0]:range_down[l,1],:]) 81 | # x = F.relu(x) 82 | # 83 | # #upward 84 | # for l in reversed(range(self.level)): 85 | # x = x + self.conv_list[l](x, edge_index_mid[:,range_mid[l,0]:range_mid[l,1]], edge_attr_mid[range_mid[l,0]:range_mid[l,1],:]) 86 | # x = F.relu(x) 87 | # if l > 0: 88 | # x = x + self.conv_up_list[l-1](x, edge_index_up[:,range_up[l-1,0]:range_up[l-1,1]], edge_attr_up[range_up[l-1,0]:range_up[l-1,1],:]) 89 | # x = F.relu(x) 90 | 91 | 92 | x = F.relu(self.fc_out1(x[:self.points[0]])) 93 | x = self.fc_out2(x) 94 | return x 95 | 96 | 97 | 98 | TRAIN_PATH = 'data/piececonst_r241_N1024_smooth1.mat' 99 | TEST_PATH = 'data/piececonst_r241_N1024_smooth2.mat' 100 | 101 | 102 | r = 1 103 | s = int(((241 - 1)/r) + 1) 104 | n = s**2 105 | k = 1 106 | 107 | # this is too large 108 | # m = [6400, 1600, 400, 100, 25] 109 | # radius_inner = [0.5/16, 0.5/8, 0.5/4, 0.5/2, 0.5] 110 | # radius_inter = [0.5/16 * 1.41, 0.5/8* 1.41, 0.5/4* 1.41, 0.5/2* 1.41] 111 | 112 | for case in range(1): 113 | 114 | # this is done in experiment 1 115 | # if case == 0: 116 | # m = [1600, 400, 100, 25] 117 | # radius_inner = [ 0.5/8, 0.5/4, 0.5/2, 0.5] 118 | # radius_inter = [0.5/8* 1.41, 0.5/4* 1.41, 0.5/2* 1.41] 119 | 120 | if case == 1: 121 | m = [1600, 400, 100] 122 | radius_inner = [ 0.5/8, 0.5/4, 0.5/2] 123 | radius_inter = [0.5/8* 1.41, 0.5/4* 1.41] 124 | 125 | # if case == 0: 126 | # m = [1600, 400] 127 | # radius_inner = [ 0.5/8, 0.5/4] 128 | # radius_inter = [0.5/8* 1.41] 129 | 130 | if case == 0: 131 | m = 
[25, 25] 132 | radius_inner = [ 0.5, 0.5/4] 133 | radius_inter = [0.5/8* 1.41] 134 | 135 | level = len(m) 136 | print('resolution', s) 137 | 138 | ntrain = 100 139 | ntest = 100 140 | 141 | # don't change this 142 | batch_size = 1 143 | batch_size2 = 1 144 | 145 | width = 64 146 | ker_width = 256 147 | depth = 4 148 | edge_features = 6 149 | node_features = 6 150 | 151 | epochs = 200 152 | learning_rate = 0.1 / ntrain 153 | scheduler_step = 10 154 | scheduler_gamma = 0.8 155 | 156 | 157 | 158 | path = 'neurips1_multigraph_s'+str(s)+'_ntrain'+str(ntrain)+'_kerwidth'+str(ker_width) + '_lenm1' #+ str(len(m)) 159 | path_model = 'model/' + path 160 | path_train_err = 'results/' + path + 'train.txt' 161 | path_test_err = 'results/' + path + 'test.txt' 162 | path_runtime = 'results/' + path + 'time.txt' 163 | path_image = 'results/' + path 164 | 165 | runtime = np.zeros(2,) 166 | 167 | t1 = default_timer() 168 | 169 | 170 | reader = MatReader(TRAIN_PATH) 171 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 172 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 173 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 174 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 175 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 176 | 177 | reader.load_file(TEST_PATH) 178 | test_a = reader.read_field('coeff')[:ntest,::r,::r].reshape(ntest,-1) 179 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,::r,::r].reshape(ntest,-1) 180 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::r,::r].reshape(ntest,-1) 181 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::r,::r].reshape(ntest,-1) 182 | test_u = reader.read_field('sol')[:ntest,::r,::r].reshape(ntest,-1) 183 | 184 | 185 | a_normalizer = GaussianNormalizer(train_a) 186 | train_a = a_normalizer.encode(train_a) 187 | test_a = a_normalizer.encode(test_a) 188 | as_normalizer = GaussianNormalizer(train_a_smooth) 189 | train_a_smooth = as_normalizer.encode(train_a_smooth) 190 | test_a_smooth = as_normalizer.encode(test_a_smooth) 191 | agx_normalizer = GaussianNormalizer(train_a_gradx) 192 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 193 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 194 | agy_normalizer = GaussianNormalizer(train_a_grady) 195 | train_a_grady = agy_normalizer.encode(train_a_grady) 196 | test_a_grady = agy_normalizer.encode(test_a_grady) 197 | 198 | u_normalizer = UnitGaussianNormalizer(train_u) 199 | train_u = u_normalizer.encode(train_u) 200 | # test_u = y_normalizer.encode(test_u) 201 | 202 | meshgenerator = RandomMultiMeshGenerator([[0, 1], [0, 1]], [s, s], level=level, sample_sizes=m) 203 | data_train = [] 204 | for j in range(ntrain): 205 | for i in range(k): 206 | idx, idx_all = meshgenerator.sample() 207 | grid, grid_all = meshgenerator.get_grid() 208 | edge_index, edge_index_down, edge_index_up = meshgenerator.ball_connectivity(radius_inner, radius_inter) 209 | edge_index_range, edge_index_down_range, edge_index_up_range = meshgenerator.get_edge_index_range() 210 | edge_attr, edge_attr_down, edge_attr_up = meshgenerator.attributes(theta=train_a[j, :]) 211 | x = torch.cat([grid_all, train_a[j, idx_all].reshape(-1, 1), 212 | train_a_smooth[j, idx_all].reshape(-1, 1), 213 | train_a_gradx[j, idx_all].reshape(-1, 1), 214 | train_a_grady[j, idx_all].reshape(-1, 1) 215 | ], dim=1) 216 | data_train.append(Data(x=x, y=train_u[j, idx[0]], 217 | edge_index_mid=edge_index, 
edge_index_down=edge_index_down, 218 | edge_index_up=edge_index_up, 219 | edge_index_range=edge_index_range, edge_index_down_range=edge_index_down_range, 220 | edge_index_up_range=edge_index_up_range, 221 | edge_attr_mid=edge_attr, edge_attr_down=edge_attr_down, edge_attr_up=edge_attr_up, 222 | sample_idx=idx[0])) 223 | 224 | print(x.shape) 225 | print(edge_index_range) 226 | print(edge_index_down_range) 227 | print(edge_index_up_range) 228 | 229 | print(edge_index.shape, edge_attr.shape) 230 | print(edge_index_down.shape, edge_attr_down.shape) 231 | print(edge_index_up.shape, edge_attr_up.shape) 232 | 233 | meshgenerator = RandomMultiMeshGenerator([[0, 1], [0, 1]], [s, s], level=level, sample_sizes=m) 234 | data_test = [] 235 | for j in range(ntest): 236 | for i in range(k): 237 | idx, idx_all = meshgenerator.sample() 238 | grid, grid_all = meshgenerator.get_grid() 239 | edge_index, edge_index_down, edge_index_up = meshgenerator.ball_connectivity(radius_inner, radius_inter) 240 | edge_index_range, edge_index_down_range, edge_index_up_range = meshgenerator.get_edge_index_range() 241 | edge_attr, edge_attr_down, edge_attr_up = meshgenerator.attributes(theta=test_a[j, :]) 242 | x = torch.cat([grid_all, test_a[j, idx_all].reshape(-1, 1), 243 | test_a_smooth[j, idx_all].reshape(-1, 1), 244 | test_a_gradx[j, idx_all].reshape(-1, 1), 245 | test_a_grady[j, idx_all].reshape(-1, 1) 246 | ], dim=1) 247 | data_test.append(Data(x=x, y=test_u[j, idx[0]], 248 | edge_index_mid=edge_index, edge_index_down=edge_index_down, 249 | edge_index_up=edge_index_up, 250 | edge_index_range=edge_index_range, edge_index_down_range=edge_index_down_range, 251 | edge_index_up_range=edge_index_up_range, 252 | edge_attr_mid=edge_attr, edge_attr_down=edge_attr_down, edge_attr_up=edge_attr_up, 253 | sample_idx=idx[0])) 254 | # 255 | train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True) 256 | test_loader = DataLoader(data_test, batch_size=batch_size2, shuffle=False) 257 | 258 | t2 = default_timer() 259 | 260 | print('preprocessing finished, time used:', t2-t1) 261 | device = torch.device('cuda') 262 | 263 | # print('use pre-train model') 264 | # model = torch.load('model/multigraph2241_n100') 265 | 266 | model = KernelInduced(width=width, ker_width=ker_width, depth=depth, ker_in=edge_features, 267 | points=m, level=level, in_width=node_features, out_width=1).cuda() 268 | 269 | # model = KernelInduced_SUM(width=width, ker_width=ker_width, depth=depth, ker_in=edge_features, 270 | # points=m, level=level, in_width=node_features, out_width=1).cuda() 271 | 272 | 273 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 274 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 275 | 276 | myloss = LpLoss(size_average=False) 277 | u_normalizer.cuda() 278 | ttrain = np.zeros((epochs, )) 279 | ttest = np.zeros((epochs,)) 280 | model.train() 281 | for ep in range(epochs): 282 | t1 = default_timer() 283 | train_mse = 0.0 284 | train_l2 = 0.0 285 | for batch in train_loader: 286 | batch = batch.to(device) 287 | 288 | optimizer.zero_grad() 289 | out = model(batch) 290 | mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1,1)) 291 | # mse.backward() 292 | 293 | l2 = myloss( 294 | u_normalizer.decode(out.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)), 295 | u_normalizer.decode(batch.y.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1))) 296 | l2.backward() 297 | 298 | optimizer.step() 299 | 
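            # Only the relative L2 loss is backpropagated; the MSE accumulated
            # below feeds the per-epoch log line and nothing else.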
train_mse += mse.item() 300 | train_l2 += l2.item() 301 | 302 | scheduler.step() 303 | t2 = default_timer() 304 | ttrain[ep] = train_l2 / (ntrain * k) 305 | 306 | print(ep, t2 - t1, train_mse / len(train_loader), train_l2 / (ntrain * k)) 307 | 308 | runtime[0] = t2 - t1 309 | 310 | t1 = default_timer() 311 | 312 | model.eval() 313 | test_l2 = 0.0 314 | with torch.no_grad(): 315 | for batch in test_loader: 316 | batch = batch.to(device) 317 | out = model(batch) 318 | out = u_normalizer.decode(out.view(batch_size2, -1), sample_idx=batch.sample_idx.view(batch_size2, -1)) 319 | test_l2 += myloss(out, batch.y.view(batch_size2, -1)).item() 320 | # test_l2 += myloss(out.view(batch_size2,-1), y_normalizer.encode(batch.y.view(batch_size2, -1))).item() 321 | 322 | ttest[ep] = test_l2 / ntest 323 | t2 = default_timer() 324 | print(ep, t2 - t1, test_l2 / ntest) 325 | 326 | runtime[1] = t2 - t1 327 | 328 | np.savetxt(path_train_err, ttrain) 329 | np.savetxt(path_test_err, ttest) 330 | np.savetxt(path_runtime, runtime) 331 | torch.save(model, path_model) 332 | 333 | -------------------------------------------------------------------------------- /multipole-graph-neural-operator/neurips3_MGKN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | from torch_geometric.data import Data, DataLoader 8 | import matplotlib.pyplot as plt 9 | from utilities import * 10 | from torch_geometric.nn import GCNConv, NNConv 11 | 12 | from timeit import default_timer 13 | import scipy.io 14 | 15 | torch.manual_seed(0) 16 | np.random.seed(0) 17 | 18 | 19 | 20 | class KernelInduced(torch.nn.Module): 21 | def __init__(self, width, ker_width, depth, ker_in, points, level, in_width=1, out_width=1): 22 | super(KernelInduced, self).__init__() 23 | self.depth = depth 24 | self.width = width 25 | self.level = level 26 | self.points = points 27 | self.points_total = np.sum(points) 28 | 29 | # in 30 | self.fc_in = torch.nn.Linear(in_width, width) 31 | # self.fc_in_list = [] 32 | # for l in range(level): 33 | # self.fc_in_list.append(torch.nn.Linear(in_width, width)) 34 | # self.fc_in_list = torch.nn.ModuleList(self.fc_in_list) 35 | 36 | # K12 K23 K34 ... 
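        # With root_weight=False and bias=False, every NNConv below acts as a
        # pure kernel integration with no learned self-term: using aggr='mean',
        # a node update is the average over neighbors y of k(edge_attr(x, y)) v(y),
        # where k is the per-edge matrix produced by the DenseNet kernel.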
37 | self.conv_down_list = [] 38 | for l in range(1, level): 39 | ker_width_l = ker_width // (2 ** l) 40 | kernel_l = DenseNet([ker_in, ker_width_l, width ** 2], torch.nn.ReLU) 41 | self.conv_down_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=False, bias=False)) 42 | self.conv_down_list = torch.nn.ModuleList(self.conv_down_list) 43 | 44 | # K11 K22 K33 45 | self.conv_list = [] 46 | for l in range(level): 47 | ker_width_l = ker_width // (2 ** l) 48 | kernel_l = DenseNet([ker_in, ker_width_l, ker_width_l, width ** 2], torch.nn.ReLU) 49 | self.conv_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=False, bias=False)) 50 | self.conv_list = torch.nn.ModuleList(self.conv_list) 51 | 52 | # K21 K32 K43 53 | self.conv_up_list = [] 54 | for l in range(1, level): 55 | ker_width_l = ker_width // (2 ** l) 56 | kernel_l = DenseNet([ker_in, ker_width_l, width ** 2], torch.nn.ReLU) 57 | self.conv_up_list.append(NNConv(width, width, kernel_l, aggr='mean', root_weight=False, bias=False)) 58 | self.conv_up_list = torch.nn.ModuleList(self.conv_up_list) 59 | 60 | # out 61 | self.fc_out1 = torch.nn.Linear(width, ker_width) 62 | self.fc_out2 = torch.nn.Linear(ker_width, 1) 63 | 64 | 65 | def forward(self, data): 66 | edge_index_down, edge_attr_down, range_down = data.edge_index_down, data.edge_attr_down, data.edge_index_down_range 67 | edge_index_mid, edge_attr_mid, range_mid = data.edge_index_mid, data.edge_attr_mid, data.edge_index_range 68 | edge_index_up, edge_attr_up, range_up = data.edge_index_up, data.edge_attr_up, data.edge_index_up_range 69 | 70 | x = self.fc_in(data.x) 71 | 72 | for t in range(self.depth): 73 | #downward 74 | for l in range(self.level-1): 75 | x = x + self.conv_down_list[l](x, edge_index_down[:,range_down[l,0]:range_down[l,1]], edge_attr_down[range_down[l,0]:range_down[l,1],:]) 76 | x = F.relu(x) 77 | 78 | #upward 79 | for l in reversed(range(self.level)): 80 | x = x + self.conv_list[l](x, edge_index_mid[:,range_mid[l,0]:range_mid[l,1]], edge_attr_mid[range_mid[l,0]:range_mid[l,1],:]) 81 | x = F.relu(x) 82 | if l > 0: 83 | x = x + self.conv_up_list[l-1](x, edge_index_up[:,range_up[l-1,0]:range_up[l-1,1]], edge_attr_up[range_up[l-1,0]:range_up[l-1,1],:]) 84 | x = F.relu(x) 85 | 86 | 87 | x = F.relu(self.fc_out1(x[:self.points[0]])) 88 | x = self.fc_out2(x) 89 | return x 90 | 91 | 92 | 93 | TRAIN_PATH = 'data/piececonst_r241_N1024_smooth1.mat' 94 | TEST_PATH = 'data/piececonst_r241_N1024_smooth2.mat' 95 | 96 | 97 | for r in [8,6,4,2,1]: 98 | 99 | s = int(((241 - 1) / r) + 1) 100 | n = s ** 2 101 | k = 1 102 | 103 | print('!!!!!!!!!!!!!! 
s ', s, ' !!!!!!!!!!!!!!!!!!!!!!!!') 104 | 105 | m = [400, 100, 25] 106 | radius_inner = [0.5 / 4, 0.5 / 2, 0.5] 107 | radius_inter = [0.5 / 4 * 1.41, 0.5 / 2 * 1.41] 108 | 109 | 110 | level = len(m) 111 | print('resolution', s) 112 | 113 | ntrain = 100 114 | ntest = 100 115 | 116 | # don't change this 117 | batch_size = 1 118 | batch_size2 = 1 119 | 120 | width = 64 121 | ker_width = 256 122 | depth = 4 123 | edge_features = 6 124 | node_features = 6 125 | 126 | epochs = 200 127 | learning_rate = 0.1 / ntrain 128 | scheduler_step = 10 129 | scheduler_gamma = 0.8 130 | 131 | 132 | 133 | path = 'neurips3_multigraph_s'+str(s)+'_ntrain'+str(ntrain)+'_kerwidth'+str(ker_width) + 'r' + str(r) 134 | path_model = 'model/' + path 135 | path_train_err = 'results/' + path + 'train.txt' 136 | path_test_err61 = 'results/'+path+'test61.txt' 137 | path_test_err121 = 'results/'+path+'test121.txt' 138 | path_test_err241 = 'results/'+path+'test241.txt' 139 | path_runtime = 'results/' + path + 'time.txt' 140 | path_image = 'results/' + path 141 | 142 | runtime = np.zeros(2,) 143 | 144 | t1 = default_timer() 145 | 146 | 147 | reader = MatReader(TRAIN_PATH) 148 | train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1) 149 | train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1) 150 | train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1) 151 | train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1) 152 | train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1) 153 | 154 | reader.load_file(TEST_PATH) 155 | test_a = reader.read_field('coeff')[:ntest,:,:] 156 | test_a_smooth = reader.read_field('Kcoeff')[:ntest,:,:] 157 | test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,:,:] 158 | test_a_grady = reader.read_field('Kcoeff_y')[:ntest,:,:] 159 | test_u = reader.read_field('sol')[:ntest,:,:] 160 | 161 | 162 | a_normalizer = GaussianNormalizer(train_a) 163 | train_a = a_normalizer.encode(train_a) 164 | test_a = a_normalizer.encode(test_a) 165 | as_normalizer = GaussianNormalizer(train_a_smooth) 166 | train_a_smooth = as_normalizer.encode(train_a_smooth) 167 | test_a_smooth = as_normalizer.encode(test_a_smooth) 168 | agx_normalizer = GaussianNormalizer(train_a_gradx) 169 | train_a_gradx = agx_normalizer.encode(train_a_gradx) 170 | test_a_gradx = agx_normalizer.encode(test_a_gradx) 171 | agy_normalizer = GaussianNormalizer(train_a_grady) 172 | train_a_grady = agy_normalizer.encode(train_a_grady) 173 | test_a_grady = agy_normalizer.encode(test_a_grady) 174 | 175 | test_a61 = test_a[:ntest, ::4, ::4].reshape(ntest, -1) 176 | test_a_smooth61 = test_a_smooth[:ntest, ::4, ::4].reshape(ntest, -1) 177 | test_a_gradx61 = test_a_gradx[:ntest, ::4, ::4].reshape(ntest, -1) 178 | test_a_grady61 = test_a_grady[:ntest, ::4, ::4].reshape(ntest, -1) 179 | test_u61 = test_u[:ntest, ::4, ::4].reshape(ntest, -1) 180 | 181 | test_a121 = test_a[:ntest, ::2, ::2].reshape(ntest, -1) 182 | test_a_smooth121 = test_a_smooth[:ntest, ::2, ::2].reshape(ntest, -1) 183 | test_a_gradx121 = test_a_gradx[:ntest, ::2, ::2].reshape(ntest, -1) 184 | test_a_grady121 = test_a_grady[:ntest, ::2, ::2].reshape(ntest, -1) 185 | test_u121 = test_u[:ntest, ::2, ::2].reshape(ntest, -1) 186 | 187 | test_a241 = test_a.reshape(ntest, -1) 188 | test_a_smooth241 = test_a_smooth.reshape(ntest, -1) 189 | test_a_gradx241 = test_a_gradx.reshape(ntest, -1) 190 | test_a_grady241 = test_a_grady.reshape(ntest, -1) 191 | test_u241 = test_u.reshape(ntest, -1) 192 | 193 
| u_normalizer = GaussianNormalizer(train_u) 194 | train_u = u_normalizer.encode(train_u) 195 | # test_u = y_normalizer.encode(test_u) 196 | 197 | meshgenerator = RandomMultiMeshGenerator([[0, 1], [0, 1]], [s, s], level=level, sample_sizes=m) 198 | data_train = [] 199 | for j in range(ntrain): 200 | for i in range(k): 201 | idx, idx_all = meshgenerator.sample() 202 | grid, grid_all = meshgenerator.get_grid() 203 | edge_index, edge_index_down, edge_index_up = meshgenerator.ball_connectivity(radius_inner, radius_inter) 204 | edge_index_range, edge_index_down_range, edge_index_up_range = meshgenerator.get_edge_index_range() 205 | edge_attr, edge_attr_down, edge_attr_up = meshgenerator.attributes(theta=train_a[j, :]) 206 | x = torch.cat([grid_all, train_a[j, idx_all].reshape(-1, 1), 207 | train_a_smooth[j, idx_all].reshape(-1, 1), 208 | train_a_gradx[j, idx_all].reshape(-1, 1), 209 | train_a_grady[j, idx_all].reshape(-1, 1) 210 | ], dim=1) 211 | data_train.append(Data(x=x, y=train_u[j, idx[0]], 212 | edge_index_mid=edge_index, edge_index_down=edge_index_down, 213 | edge_index_up=edge_index_up, 214 | edge_index_range=edge_index_range, edge_index_down_range=edge_index_down_range, 215 | edge_index_up_range=edge_index_up_range, 216 | edge_attr_mid=edge_attr, edge_attr_down=edge_attr_down, edge_attr_up=edge_attr_up, 217 | sample_idx=idx[0])) 218 | 219 | print(x.shape) 220 | print(edge_index_range) 221 | print(edge_index_down_range) 222 | print(edge_index_up_range) 223 | 224 | print(edge_index.shape, edge_attr.shape) 225 | print(edge_index_down.shape, edge_attr_down.shape) 226 | print(edge_index_up.shape, edge_attr_up.shape) 227 | 228 | meshgenerator = RandomMultiMeshGenerator([[0, 1], [0, 1]], [241, 241], level=level, sample_sizes=m) 229 | data_test241 = [] 230 | for j in range(ntest): 231 | for i in range(k): 232 | idx, idx_all = meshgenerator.sample() 233 | grid, grid_all = meshgenerator.get_grid() 234 | edge_index, edge_index_down, edge_index_up = meshgenerator.ball_connectivity(radius_inner, radius_inter) 235 | edge_index_range, edge_index_down_range, edge_index_up_range = meshgenerator.get_edge_index_range() 236 | edge_attr, edge_attr_down, edge_attr_up = meshgenerator.attributes(theta=test_a241[j, :]) 237 | x = torch.cat([grid_all, test_a241[j, idx_all].reshape(-1, 1), 238 | test_a_smooth241[j, idx_all].reshape(-1, 1), 239 | test_a_gradx241[j, idx_all].reshape(-1, 1), 240 | test_a_grady241[j, idx_all].reshape(-1, 1) 241 | ], dim=1) 242 | data_test241.append(Data(x=x, y=test_u241[j, idx[0]], 243 | edge_index_mid=edge_index, edge_index_down=edge_index_down, 244 | edge_index_up=edge_index_up, 245 | edge_index_range=edge_index_range, edge_index_down_range=edge_index_down_range, 246 | edge_index_up_range=edge_index_up_range, 247 | edge_attr_mid=edge_attr, edge_attr_down=edge_attr_down, edge_attr_up=edge_attr_up, 248 | sample_idx=idx[0])) 249 | 250 | meshgenerator = RandomMultiMeshGenerator([[0, 1], [0, 1]], [121, 121], level=level, sample_sizes=m) 251 | data_test121 = [] 252 | for j in range(ntest): 253 | for i in range(k): 254 | idx, idx_all = meshgenerator.sample() 255 | grid, grid_all = meshgenerator.get_grid() 256 | edge_index, edge_index_down, edge_index_up = meshgenerator.ball_connectivity(radius_inner, radius_inter) 257 | edge_index_range, edge_index_down_range, edge_index_up_range = meshgenerator.get_edge_index_range() 258 | edge_attr, edge_attr_down, edge_attr_up = meshgenerator.attributes(theta=test_a121[j, :]) 259 | x = torch.cat([grid_all, test_a121[j, idx_all].reshape(-1, 1), 260 | 
meshgenerator = RandomMultiMeshGenerator([[0, 1], [0, 1]], [241, 241], level=level, sample_sizes=m)
data_test241 = []
for j in range(ntest):
    for i in range(k):
        idx, idx_all = meshgenerator.sample()
        grid, grid_all = meshgenerator.get_grid()
        edge_index, edge_index_down, edge_index_up = meshgenerator.ball_connectivity(radius_inner, radius_inter)
        edge_index_range, edge_index_down_range, edge_index_up_range = meshgenerator.get_edge_index_range()
        edge_attr, edge_attr_down, edge_attr_up = meshgenerator.attributes(theta=test_a241[j, :])
        x = torch.cat([grid_all, test_a241[j, idx_all].reshape(-1, 1),
                       test_a_smooth241[j, idx_all].reshape(-1, 1),
                       test_a_gradx241[j, idx_all].reshape(-1, 1),
                       test_a_grady241[j, idx_all].reshape(-1, 1)
                       ], dim=1)
        data_test241.append(Data(x=x, y=test_u241[j, idx[0]],
                                 edge_index_mid=edge_index, edge_index_down=edge_index_down,
                                 edge_index_up=edge_index_up,
                                 edge_index_range=edge_index_range, edge_index_down_range=edge_index_down_range,
                                 edge_index_up_range=edge_index_up_range,
                                 edge_attr_mid=edge_attr, edge_attr_down=edge_attr_down, edge_attr_up=edge_attr_up,
                                 sample_idx=idx[0]))

meshgenerator = RandomMultiMeshGenerator([[0, 1], [0, 1]], [121, 121], level=level, sample_sizes=m)
data_test121 = []
for j in range(ntest):
    for i in range(k):
        idx, idx_all = meshgenerator.sample()
        grid, grid_all = meshgenerator.get_grid()
        edge_index, edge_index_down, edge_index_up = meshgenerator.ball_connectivity(radius_inner, radius_inter)
        edge_index_range, edge_index_down_range, edge_index_up_range = meshgenerator.get_edge_index_range()
        edge_attr, edge_attr_down, edge_attr_up = meshgenerator.attributes(theta=test_a121[j, :])
        x = torch.cat([grid_all, test_a121[j, idx_all].reshape(-1, 1),
                       test_a_smooth121[j, idx_all].reshape(-1, 1),
                       test_a_gradx121[j, idx_all].reshape(-1, 1),
                       test_a_grady121[j, idx_all].reshape(-1, 1)
                       ], dim=1)
        data_test121.append(Data(x=x, y=test_u121[j, idx[0]],
                                 edge_index_mid=edge_index, edge_index_down=edge_index_down,
                                 edge_index_up=edge_index_up,
                                 edge_index_range=edge_index_range, edge_index_down_range=edge_index_down_range,
                                 edge_index_up_range=edge_index_up_range,
                                 edge_attr_mid=edge_attr, edge_attr_down=edge_attr_down, edge_attr_up=edge_attr_up,
                                 sample_idx=idx[0]))

meshgenerator = RandomMultiMeshGenerator([[0, 1], [0, 1]], [61, 61], level=level, sample_sizes=m)
data_test61 = []
for j in range(ntest):
    for i in range(k):
        idx, idx_all = meshgenerator.sample()
        grid, grid_all = meshgenerator.get_grid()
        edge_index, edge_index_down, edge_index_up = meshgenerator.ball_connectivity(radius_inner, radius_inter)
        edge_index_range, edge_index_down_range, edge_index_up_range = meshgenerator.get_edge_index_range()
        edge_attr, edge_attr_down, edge_attr_up = meshgenerator.attributes(theta=test_a61[j, :])
        x = torch.cat([grid_all, test_a61[j, idx_all].reshape(-1, 1),
                       test_a_smooth61[j, idx_all].reshape(-1, 1),
                       test_a_gradx61[j, idx_all].reshape(-1, 1),
                       test_a_grady61[j, idx_all].reshape(-1, 1)
                       ], dim=1)
        data_test61.append(Data(x=x, y=test_u61[j, idx[0]],
                                edge_index_mid=edge_index, edge_index_down=edge_index_down,
                                edge_index_up=edge_index_up,
                                edge_index_range=edge_index_range,
                                edge_index_down_range=edge_index_down_range,
                                edge_index_up_range=edge_index_up_range,
                                edge_attr_mid=edge_attr, edge_attr_down=edge_attr_down,
                                edge_attr_up=edge_attr_up,
                                sample_idx=idx[0]))

train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
test_loader241 = DataLoader(data_test241, batch_size=batch_size2, shuffle=False)
test_loader121 = DataLoader(data_test121, batch_size=batch_size2, shuffle=False)
test_loader61 = DataLoader(data_test61, batch_size=batch_size2, shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')

# print('use pre-train model')
# model = torch.load('model/multigraph2241_n100')

model = KernelInduced(width=width, ker_width=ker_width, depth=depth, ker_in=edge_features,
                      points=m, level=level, in_width=node_features, out_width=1).cuda()

# model = KernelInduced_SUM(width=width, ker_width=ker_width, depth=depth, ker_in=edge_features,
#                           points=m, level=level, in_width=node_features, out_width=1).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

myloss = LpLoss(size_average=False)
u_normalizer.cuda()
ttrain = np.zeros((epochs, ))
ttest241 = np.zeros((epochs,))
ttest121 = np.zeros((epochs,))
ttest61 = np.zeros((epochs,))
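# --- illustration, not part of the original script -----------------------------
# LpLoss(size_average=False) from utilities.py is assumed to sum the per-sample
# *relative* Lp errors over a batch; for p = 2 that is:
def _relative_l2_sketch(pred, truth):
    # pred, truth: (batch, n) tensors; returns sum_i ||pred_i - truth_i||_2 / ||truth_i||_2
    diff = torch.norm(pred - truth, p=2, dim=1)
    return torch.sum(diff / torch.norm(truth, p=2, dim=1))
# The loop below backpropagates this relative error on decoded (de-normalized)
# values; sample_idx records which grid nodes each prediction covers, which a
# pointwise normalizer needs (a scalar normalizer can ignore it).
# --------------------------------------------------------------------------------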
for ep in range(epochs):
    t1 = default_timer()
    model.train()  # re-enter train mode; model.eval() below would otherwise persist across epochs
    train_mse = 0.0
    train_l2 = 0.0
    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        out = model(batch)
        mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1, 1))
        # mse.backward()

        l2 = myloss(
            u_normalizer.decode(out.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)),
            u_normalizer.decode(batch.y.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)))
        l2.backward()

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()
    t2 = default_timer()
    ttrain[ep] = train_l2 / (ntrain * k)

    print(ep, t2 - t1, train_mse / len(train_loader), train_l2 / (ntrain * k))

    runtime[0] = t2 - t1  # keep the most recent epoch's training time

    # evaluate at all three resolutions after every epoch
    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for batch in test_loader241:
            batch = batch.to(device)
            out = model(batch)
            out = u_normalizer.decode(out.view(batch_size2, -1), sample_idx=batch.sample_idx.view(batch_size2, -1))
            test_l2 += myloss(out, batch.y.view(batch_size2, -1)).item()
            # test_l2 += myloss(out.view(batch_size2,-1), y_normalizer.encode(batch.y.view(batch_size2, -1))).item()
    ttest241[ep] = test_l2 / ntest
    print(ep, '241', t2 - t1, test_l2 / ntest)

    test_l2 = 0.0
    with torch.no_grad():
        for batch in test_loader121:
            batch = batch.to(device)
            out = model(batch)
            out = u_normalizer.decode(out.view(batch_size2, -1), sample_idx=batch.sample_idx.view(batch_size2, -1))
            test_l2 += myloss(out, batch.y.view(batch_size2, -1)).item()
            # test_l2 += myloss(out.view(batch_size2,-1), y_normalizer.encode(batch.y.view(batch_size2, -1))).item()
    ttest121[ep] = test_l2 / ntest
    print(ep, '121', t2 - t1, test_l2 / ntest)

    test_l2 = 0.0
    with torch.no_grad():
        for batch in test_loader61:
            batch = batch.to(device)
            out = model(batch)
            out = u_normalizer.decode(out.view(batch_size2, -1), sample_idx=batch.sample_idx.view(batch_size2, -1))
            test_l2 += myloss(out, batch.y.view(batch_size2, -1)).item()
            # test_l2 += myloss(out.view(batch_size2,-1), y_normalizer.encode(batch.y.view(batch_size2, -1))).item()
    ttest61[ep] = test_l2 / ntest
    print(ep, '61', t2 - t1, test_l2 / ntest)


np.savetxt(path_train_err, ttrain)
np.savetxt(path_test_err61, ttest61)
np.savetxt(path_test_err121, ttest121)
np.savetxt(path_test_err241, ttest241)
np.savetxt(path_runtime, runtime)  # previously recorded but never written out
torch.save(model, path_model)

--------------------------------------------------------------------------------
/multipole-graph-neural-operator/neurips4_GCN.py:
--------------------------------------------------------------------------------
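# neurips4_GCN.py: graph convolutional network (GCN) baseline for the 2d Darcy
# flow benchmark. Unlike the kernel networks, it stacks plain GCNConv layers
# with no learned edge kernel, operating on the full-resolution grid graph.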
import torch
import numpy as np

import torch.nn.functional as F
import torch.nn as nn

from torch_geometric.data import Data, DataLoader
import matplotlib.pyplot as plt
from utilities import *
from torch_geometric.nn import GCNConv, NNConv

from timeit import default_timer
import scipy.io

torch.manual_seed(0)
np.random.seed(0)


class GCN_Net(torch.nn.Module):
    def __init__(self, width, ker_width, depth, in_width=1, out_width=1):
        super(GCN_Net, self).__init__()
        self.depth = depth
        self.width = width

        self.fc_in = torch.nn.Linear(in_width, width)

        self.conv1 = GCNConv(width, width)
        self.conv2 = GCNConv(width, width)
        self.conv3 = GCNConv(width, width)
        self.conv4 = GCNConv(width, width)

        self.fc_out1 = torch.nn.Linear(width, ker_width)
        self.fc_out2 = torch.nn.Linear(ker_width, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.fc_in(x)

        # each of the depth iterations applies all four (distinct) GCN layers
        for t in range(self.depth):
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = self.conv2(x, edge_index)
            x = F.relu(x)
            x = self.conv3(x, edge_index)
            x = F.relu(x)
            x = self.conv4(x, edge_index)
            x = F.relu(x)

        x = F.relu(self.fc_out1(x))
        x = self.fc_out2(x)
        return x


TRAIN_PATH = 'data/piececonst_r421_N1024_smooth1.mat'
TEST_PATH = 'data/piececonst_r421_N1024_smooth2.mat'

r = 1
s = int(((421 - 1)/r) + 1)
n = s**2
k = 1

print('resolution', s)

ntrain = 1024
ntest = 100

batch_size = 1
batch_size2 = 1

width = 128
ker_width = 1024
depth = 4

node_features = 6

epochs = 51
learning_rate = 0.0001
scheduler_step = 10
scheduler_gamma = 0.85

path = 'neurips4_GCN_s'+str(s)+'_ntrain'+str(ntrain)+'_kerwidth'+str(ker_width)
path_model = 'model/' + path
path_train_err = 'results/' + path + 'train.txt'
path_test_err = 'results/' + path + 'test.txt'
path_image = 'results/' + path

t1 = default_timer()

reader = MatReader(TRAIN_PATH)
train_a = reader.read_field('coeff')[:ntrain,::r,::r].reshape(ntrain,-1)
train_a_smooth = reader.read_field('Kcoeff')[:ntrain,::r,::r].reshape(ntrain,-1)
train_a_gradx = reader.read_field('Kcoeff_x')[:ntrain,::r,::r].reshape(ntrain,-1)
train_a_grady = reader.read_field('Kcoeff_y')[:ntrain,::r,::r].reshape(ntrain,-1)
train_u = reader.read_field('sol')[:ntrain,::r,::r].reshape(ntrain,-1)

reader.load_file(TEST_PATH)
test_a = reader.read_field('coeff')[:ntest,::r,::r].reshape(ntest,-1)
test_a_smooth = reader.read_field('Kcoeff')[:ntest,::r,::r].reshape(ntest,-1)
test_a_gradx = reader.read_field('Kcoeff_x')[:ntest,::r,::r].reshape(ntest,-1)
test_a_grady = reader.read_field('Kcoeff_y')[:ntest,::r,::r].reshape(ntest,-1)
test_u = reader.read_field('sol')[:ntest,::r,::r].reshape(ntest,-1)

a_normalizer = GaussianNormalizer(train_a)
train_a = a_normalizer.encode(train_a)
test_a = a_normalizer.encode(test_a)
as_normalizer = GaussianNormalizer(train_a_smooth)
train_a_smooth = as_normalizer.encode(train_a_smooth)
test_a_smooth = as_normalizer.encode(test_a_smooth)
agx_normalizer = GaussianNormalizer(train_a_gradx)
train_a_gradx = agx_normalizer.encode(train_a_gradx)
test_a_gradx = agx_normalizer.encode(test_a_gradx)
agy_normalizer = GaussianNormalizer(train_a_grady)
train_a_grady = agy_normalizer.encode(train_a_grady)
test_a_grady = agy_normalizer.encode(test_a_grady)

u_normalizer = UnitGaussianNormalizer(train_u)
train_u = u_normalizer.encode(train_u)
# test_u = y_normalizer.encode(test_u)

X, edge_index, _ = grid_edge(s, s)

data_train = []
for j in range(ntrain):
    for i in range(k):
        x = torch.cat([X, train_a[j].reshape(-1, 1),
                       train_a_smooth[j].reshape(-1, 1),
                       train_a_gradx[j].reshape(-1, 1),
                       train_a_grady[j].reshape(-1, 1)
                       ], dim=1)
        data_train.append(Data(x=x, y=train_u[j], edge_index=edge_index))

print(x.shape)
print(edge_index.shape)
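# --- illustration, not part of the original script -----------------------------
# grid_edge(s, s) comes from utilities.py; from its use above it is assumed to
# return the (s*s, 2) grid coordinates and an edge_index joining neighbouring
# grid points. A minimal sketch of that assumed behaviour:
def _grid_edge_sketch(n1, n2):
    xs, ys = torch.linspace(0, 1, n1), torch.linspace(0, 1, n2)
    X = torch.stack(torch.meshgrid(xs, ys), dim=-1).reshape(-1, 2)
    idx = torch.arange(n1 * n2).reshape(n1, n2)
    right = torch.stack([idx[:, :-1].reshape(-1), idx[:, 1:].reshape(-1)])
    down = torch.stack([idx[:-1, :].reshape(-1), idx[1:, :].reshape(-1)])
    # undirected graph: include both edge directions
    edge_index = torch.cat([right, down, right.flip(0), down.flip(0)], dim=1)
    return X, edge_index
# --------------------------------------------------------------------------------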
data_test = []
for j in range(ntest):
    x = torch.cat([X, test_a[j].reshape(-1, 1),
                   test_a_smooth[j].reshape(-1, 1),
                   test_a_gradx[j].reshape(-1, 1),
                   test_a_grady[j].reshape(-1, 1)
                   ], dim=1)
    data_test.append(Data(x=x, y=test_u[j], edge_index=edge_index))

print(x.shape)
print(edge_index.shape)

train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(data_test, batch_size=batch_size2, shuffle=False)
t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')

# print('use pre-train model')
# model = torch.load('model/multigraph_full_s141_ntrain1000_kerwidth1024')

model = GCN_Net(width=width, ker_width=ker_width, depth=depth, in_width=node_features, out_width=1).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

myloss = LpLoss(size_average=False)
ttrain = np.zeros((epochs, ))
ttest = np.zeros((epochs,))

for ep in range(epochs):
    t1 = default_timer()
    train_mse = 0.0
    train_l2 = 0.0
    model.train()
    u_normalizer.cuda()
    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        out = model(batch)
        mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1, 1))
        # mse.backward()

        l2 = myloss(
            u_normalizer.decode(out.view(batch_size, -1)),
            u_normalizer.decode(batch.y.view(batch_size, -1)))
        l2.backward()

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()
    t2 = default_timer()
    ttrain[ep] = train_l2 / (ntrain * k)

    print(ep, t2 - t1, train_mse / len(train_loader), train_l2 / (ntrain * k))

    # evaluate every 10 epochs (ttest keeps zeros for the other epochs)
    if ep % 10 == 0:
        model.eval()
        test_l2 = 0.0
        with torch.no_grad():
            t1 = default_timer()
            for batch in test_loader:
                batch = batch.to(device)
                out = model(batch)
                l2 = myloss(u_normalizer.decode(out.view(batch_size2, -1)), batch.y.view(batch_size2, -1))
                test_l2 += l2.item()

        ttest[ep] = test_l2 / ntest
        t2 = default_timer()
        print(ep, t2 - t1, test_l2 / (ntest * k))
        # torch.save(model, path_model+str(ep))


np.savetxt(path_train_err, ttrain)
np.savetxt(path_test_err, ttest)

--------------------------------------------------------------------------------
/multipole-graph-neural-operator/neurips5_GKN.py:
--------------------------------------------------------------------------------
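# neurips5_GKN.py: graph kernel network (GKN) baseline for 1d Burgers' equation.
# Training uses random sub-sampled graphs of trainm nodes; testing splits the
# full-resolution grid into pieces and re-assembles the per-piece predictions.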
import torch
import numpy as np

import torch.nn.functional as F
import torch.nn as nn

from torch_geometric.data import Data, DataLoader
import matplotlib.pyplot as plt
from utilities import *
from nn_conv import NNConv, NNConv_old

from timeit import default_timer
import scipy.io

class KernelNN(torch.nn.Module):
    def __init__(self, width, ker_width, depth, ker_in, in_width=1, out_width=1):
        super(KernelNN, self).__init__()
        self.depth = depth

        self.fc1 = torch.nn.Linear(in_width, width)

        kernel = DenseNet([ker_in, ker_width//2, ker_width, width**2], torch.nn.ReLU)
        self.conv1 = NNConv_old(width, width, kernel, aggr='mean')

        self.fc2 = torch.nn.Linear(width, ker_width)
        self.fc3 = torch.nn.Linear(ker_width, 1)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = self.fc1(x)
        for k in range(self.depth):
            x = self.conv1(x, edge_index, edge_attr)
            if k != self.depth-1:
                x = F.relu(x)

        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

# torch.cuda.set_device('cuda:3')
TRAIN_PATH = 'data/burgers_data_R10.mat'
TEST_PATH = 'data/burgers_data_R10.mat'
# TRAIN_PATH = 'data/burgers1d_small.mat'
# TEST_PATH = 'data/burgers1d_small.mat'

r = 8
s = 2**13//r
K = s

ntrain = 32
ntest = 32

n = s
k = 2
trainm = 128
assert n % trainm == 0
train_split = s // trainm

testr1 = r
tests1 = 2**13 // testr1
test_split = train_split
testn1 = s
testm = trainm

batch_size = 4  # factor of ntrain * k
batch_size2 = 4  # factor of test_split

radius_train = 0.20
radius_test = 0.20
# rbf_sigma = 0.2

print('resolution', s)

assert test_split % batch_size2 == 0  # the batch size must divide the split

width = 64
ker_width = 1024
depth = 6
edge_features = 4
node_features = 2

epochs = 101
learning_rate = 0.0001
scheduler_step = 10
scheduler_gamma = 0.85

path = 'neurips5_GKN_r'+str(s)+'_s'+ str(tests1)+'testm'+str(testm)
path_model = 'model/'+path
path_train_err = 'results/'+path+'train.txt'
path_test_err = 'results/'+path+'test.txt'
path_image = 'image/'+path

t1 = default_timer()

reader = MatReader(TRAIN_PATH)
train_a = reader.read_field('a')[:ntrain,::r].reshape(ntrain,-1)
train_u = reader.read_field('u')[:ntrain,::r].reshape(ntrain,-1)

reader.load_file(TEST_PATH)
test_a = reader.read_field('a')[-ntest:,::r].reshape(ntest,-1)
test_u = reader.read_field('u')[-ntest:,::r].reshape(ntest,-1)

a_normalizer = GaussianNormalizer(train_a)
train_a = a_normalizer.encode(train_a)
test_a = a_normalizer.encode(test_a)

u_normalizer = UnitGaussianNormalizer(train_u)
train_u = u_normalizer.encode(train_u)
# test_u = y_normalizer.encode(test_u)

meshgenerator = RandomMeshGenerator([[0,1]],[s], sample_size=trainm)
data_train = []
for j in range(ntrain):
    for i in range(k):
        idx = meshgenerator.sample()
        grid = meshgenerator.get_grid()
        edge_index = meshgenerator.ball_connectivity(radius_train)
        edge_attr = meshgenerator.attributes(theta=train_a[j, :])
        # data_train.append(Data(x=init_point.clone().view(-1,1), y=train_y[j,:], edge_index=edge_index, edge_attr=edge_attr))
        data_train.append(Data(x=torch.cat([grid.reshape(-1, 1), train_a[j, idx].reshape(-1, 1)], dim=1),
                               y=train_u[j, idx], edge_index=edge_index, edge_attr=edge_attr, sample_idx=idx
                               ))

train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
# print('grid', grid.shape, 'edge_index', edge_index.shape, 'edge_attr', edge_attr.shape)
# print('edge_index_boundary', edge_index_boundary.shape, 'edge_attr', edge_attr_boundary.shape)
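# --- note, not part of the original script --------------------------------------
# RandomGridSplitter (utilities.py) is assumed, from its use below, to partition
# the tests1-point grid into test_split overlapping subsets of testm nodes each:
# get_data(theta) yields one Data object per subset (tagged with split_idx), and
# assemble(pred, split_idx, batch_size2, sigma) maps the per-subset predictions
# back to the full grid, averaging the overlaps (sigma sets the blending width).
# --------------------------------------------------------------------------------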
meshgenerator = SquareMeshGenerator([[0,1]],[tests1])
grid = meshgenerator.get_grid()
gridsplitter = RandomGridSplitter(grid, resolution=tests1, d=1, l=1, m=testm, radius=radius_test)

data_test = []
for j in range(ntest):
    theta = test_a[j,:].reshape(-1, 1)
    data_equation = gridsplitter.get_data(theta)
    equation_loader = DataLoader(data_equation, batch_size=batch_size2, shuffle=False)
    data_test.append(equation_loader)


##################################################################################################
### training
##################################################################################################
t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')

model = KernelNN(width, ker_width, depth, edge_features, node_features).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

myloss = LpLoss(size_average=False)

# gridsplitter.cuda()

ttrain = np.zeros((epochs, ))
ttest = np.zeros((epochs,))
for ep in range(epochs):
    u_normalizer.cuda()
    model.train()
    t1 = default_timer()
    train_mse = 0.0
    train_l2 = 0.0
    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        out = model(batch)
        mse = F.mse_loss(out.view(-1, 1), batch.y.view(-1, 1))
        # mse.backward()
        # training objective: L1 loss on the normalized outputs (the relative L2
        # below is tracked for reporting only)
        loss = torch.norm(out.view(-1) - batch.y.view(-1), 1)
        loss.backward()

        l2 = myloss(
            u_normalizer.decode(out.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)),
            u_normalizer.decode(batch.y.view(batch_size, -1), sample_idx=batch.sample_idx.view(batch_size, -1)))
        # l2.backward()

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    ttrain[ep] = train_l2 / (ntrain * k)
    scheduler.step()
    t2 = default_timer()

    print(ep, t2-t1, train_mse/len(train_loader), train_l2 / (ntrain * k))

    if ep % 20 == 0:
        model.eval()
        test_l2 = 0.0
        test_l2_split = 0.0
        u_normalizer.cpu()
        with torch.no_grad():
            for i, equation_loader in enumerate(data_test):
                pred = []
                split_idx = []
                for batch in equation_loader:
                    batch = batch.to(device)
                    out = model(batch).detach().cpu()
                    pred.append(out)
                    split_idx.append(batch.split_idx)

                    out_split = u_normalizer.decode(out.view(batch_size2, -1),
                                                    sample_idx=batch.split_idx.view(batch_size2, -1))
                    test_l2_split += myloss(out_split, test_u[i, batch.split_idx]).item()

                # recombine the per-split predictions into one full-grid prediction
                out = gridsplitter.assemble(pred, split_idx, batch_size2, sigma=1)
                out = u_normalizer.decode(out.view(1, -1))
                test_l2 += myloss(out, test_u[i].view(1, -1)).item()

        ttest[ep] = test_l2 / ntest
        t3 = default_timer()
        print(ep, t3-t2, test_l2/ntest, test_l2_split/(ntest*test_split))


np.savetxt(path_train_err, ttrain)
np.savetxt(path_test_err, ttest)
torch.save(model, path_model)

##################################################################################################
### plotting
##################################################################################################

plt.figure()
# plt.plot(ttrain, label='train loss')
plt.plot(ttest, label='test loss')
plt.legend(loc='upper right')
plt.show()
--------------------------------------------------------------------------------