├── pics ├── 1 ├── burgers mesh8.pdf └── burgers mesh8.png ├── README.md ├── PDEs.py ├── interpolate.py ├── gnn_2d.py ├── env.yml ├── models_cnn.py ├── mesh ├── dmm.py ├── dmm_model.py └── dmm_utils.py ├── train_helper_2d.py ├── data_creator_2d.py └── mmpde.py /pics/1: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pics/burgers mesh8.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Peiyannn/MM-PDE/HEAD/pics/burgers mesh8.pdf -------------------------------------------------------------------------------- /pics/burgers mesh8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Peiyannn/MM-PDE/HEAD/pics/burgers mesh8.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MM-PDE: Better Neural PDE Solvers Through Data-Free Mesh Movers 2 | 3 | [Link to the paper](https://openreview.net/pdf?id=hj9ZuNimRl) (ICLR 2024) 4 | 5 | This paper introduces a neural-network-based mesh adapter called **Data-free Mesh Mover (DMM)**, which is trained in a physics-informed, data-free manner. The DMM can be embedded into a neural PDE solver through a suitable architectural design; the resulting solver is called **MM-PDE**. 6 | 7 | 8 | 9 | ## Environment 10 | 11 | Install the environment using [conda](https://docs.conda.io/en/latest/miniconda.html) with the attached environment file as follows: 12 | 13 | ```code 14 | conda env create -f env.yml 15 | ``` 16 | 17 | ## Dataset 18 | 19 | Download the datasets via [this link](https://drive.google.com/drive/folders/1TI2xHsOqAIFNu7EBS6IrkNI7ivZtGXrX?usp=sharing) into the "mesh/data/" folder of your local copy of the repository.
20 | 21 | ## Training of Data-free Mesh Mover (DMM) 22 | 23 | - Burgers' equation: 24 | ```code 25 | cd mesh 26 | python dmm.py 27 | ``` 28 | - Flow around a cylinder: 29 | ```code 30 | cd mesh 31 | python dmm.py --experiment cy --train_sample_grid 1500 --branch_layers 4,3 --trunk_layers 16,512 32 | ``` 33 | 34 | ## Training of MM-PDE 35 | 36 | - Burgers' equation: 37 | ```code 38 | python mmpde.py --lr 6e-4 39 | ``` 40 | - Flow around a cylinder: 41 | ```code 42 | python mmpde.py --experiment cy --base_resolution 30,2521 43 | ``` 44 | 45 | ## Training of GNN 46 | 47 | - Burgers' equation: 48 | ```code 49 | python mmpde.py --lr 6e-4 --moving_mesh False 50 | ``` 51 | - Flow around a cylinder: 52 | ```code 53 | python mmpde.py --experiment cy --base_resolution 30,2521 --moving_mesh False 54 | ``` 55 | 56 | ## Citation 57 | 58 | If you find our work and/or our code useful, please cite us via: 59 | 60 | ```bibtex 61 | @inproceedings{ 62 | hu2024better, 63 | title={Better Neural {PDE} Solvers Through Data-Free Mesh Movers}, 64 | author={Peiyan Hu and Yue Wang and Zhi-Ming Ma}, 65 | booktitle={The Twelfth International Conference on Learning Representations}, 66 | year={2024} 67 | } 68 | ``` 69 | -------------------------------------------------------------------------------- /PDEs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class PDE(nn.Module): 10 | """Generic PDE template""" 11 | def __init__(self): 12 | # Data params for grid and initial conditions 13 | super().__init__() 14 | pass 15 | 16 | def __repr__(self): 17 | return "PDE" 18 | 19 | 20 | class burgers(PDE): 21 | def __init__(self, 22 | tmin: float=None, 23 | tmax: float=None, 24 | grid_size: list=None, 25 | L: float=None, 26 | flux_splitting: str=None, 27 | device: torch.cuda.device = "cpu") -> None: 28 | 29 | # Data params for grid 30 | 
super().__init__() 31 | # Start and end time of the trajectory 32 | self.tmin = 0 if tmin is None else tmin 33 | self.tmax = 30 if tmax is None else tmax 34 | # Length of the spatial domain 35 | self.Lx = 1 if L is None else L 36 | self.Ly = 1 if L is None else L 37 | self.grid_size = (31, 96, 96) if grid_size is None else grid_size 38 | self.movingmesh_grid_size = (31, 96, 96) 39 | self.ori_grid_size = (31, 96, 96) 40 | self.dt = self.tmax / (self.grid_size[0]-1) 41 | self.device = device 42 | 43 | 44 | class cy(PDE): 45 | def __init__(self, 46 | tmin: float=None, 47 | tmax: float=None, 48 | grid_size: list=None, 49 | ori_grid: torch.Tensor=None, 50 | L: float=None, 51 | flux_splitting: str=None, 52 | device: torch.cuda.device = "cpu") -> None: 53 | 54 | # Data params for grid 55 | super().__init__() 56 | # Start and end time of the trajectory 57 | self.tmin = 0 if tmin is None else tmin 58 | self.tmax = 2.9 if tmax is None else tmax 59 | # Length of the spatial domain 60 | self.Lx = 1 if L is None else L 61 | self.Ly = 1 if L is None else L 62 | self.grid_size = (30, 2521) if grid_size is None else grid_size 63 | self.ori_grid_size = (30, 2521) if grid_size is None else grid_size 64 | self.movingmesh_grid_size = (30, 2521) if grid_size is None else grid_size 65 | self.ori_grid = ori_grid 66 | self.dt = self.tmax / (self.grid_size[0]-1) 67 | self.device = device 68 | 69 | -------------------------------------------------------------------------------- /interpolate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ItpNet(nn.Module): 6 | def __init__(self, ori_nx, ori_ny, layers1, layers2, layers3, normalize=False): 7 | super(ItpNet, self).__init__() 8 | self.n = 30 9 | 10 | self.layers1_node = [self.n * 2 + 2] + layers1 11 | self.layers1_node.append(self.n) 12 | self.n_layers1 = len(self.layers1_node) - 1 13 | assert self.n_layers1 >= 1 14 | self.layers = nn.ModuleList() 
15 | self.act = torch.tanh 16 | 17 | for j in range(self.n_layers1): 18 | self.layers.append(nn.Linear(self.layers1_node[j], self.layers1_node[j+1])) 19 | 20 | if j != self.n_layers1 - 1: 21 | if normalize: 22 | self.layers.append(nn.BatchNorm1d(self.layers1_node[j+1])) 23 | 24 | self.layers2_node = [self.n * 2 + 2] + layers2 25 | self.layers2_node.append(self.n) 26 | self.n_layers2 = len(self.layers2_node) - 1 27 | self.layers2 = nn.ModuleList() 28 | self.act2 = torch.tanh 29 | 30 | for j in range(self.n_layers2): 31 | self.layers2.append(nn.Linear(self.layers2_node[j], self.layers2_node[j+1])) 32 | 33 | if j != self.n_layers2 - 1: 34 | if normalize: 35 | self.layers2.append(nn.BatchNorm1d(self.layers2_node[j+1])) 36 | 37 | if ori_ny != None: 38 | self.layers3_node = [ori_nx * ori_ny] + layers3 39 | self.layers3_node.append(ori_nx * ori_ny) 40 | else: 41 | self.layers3_node = [ori_nx] + layers3 42 | self.layers3_node.append(ori_nx) 43 | self.n_layers3 = len(self.layers3_node) - 1 44 | self.layers3 = nn.ModuleList() 45 | self.act3 = torch.tanh 46 | 47 | for j in range(self.n_layers3): 48 | self.layers3.append(nn.Linear(self.layers3_node[j], self.layers3_node[j+1])) 49 | 50 | if j != self.n_layers3 - 1: 51 | if normalize: 52 | self.layers3.append(nn.BatchNorm1d(self.layers3_node[j+1])) 53 | 54 | if ori_ny != None: 55 | self.down = nn.Sequential( 56 | nn.Conv2d(layers3[0], layers3[1], 5, padding=2), 57 | nn.Tanh(), 58 | nn.Conv2d(layers3[1], layers3[2], 5, padding=2), 59 | nn.Tanh(), 60 | nn.Conv2d(layers3[2], layers3[3], 5, padding=2), 61 | nn.Tanh(), 62 | nn.Conv2d(layers3[3], layers3[4], 5, padding=2), 63 | nn.Tanh(), 64 | ) 65 | else: 66 | self.down = nn.Sequential( 67 | nn.Linear(ori_nx, 2048), 68 | nn.Tanh(), 69 | nn.Linear(2048, 512), 70 | nn.Tanh(), 71 | nn.Linear(512, 2048), 72 | nn.Tanh(), 73 | nn.Linear(2048, ori_nx), 74 | ) 75 | 76 | 77 | def forward(self, neighbors, query_points, mode, data = None): 78 | 79 | if mode == '1': 80 | data = 
torch.cat((neighbors, query_points), dim=-2).reshape(neighbors.shape[0], neighbors.shape[1], -1) 81 | for _, l in enumerate(self.layers): 82 | if _ != self.n_layers1 - 1: 83 | data = self.act(l(data)) 84 | else: 85 | data = l(data) 86 | 87 | elif mode == '2': 88 | data = torch.cat((neighbors, query_points), dim=-2).reshape(neighbors.shape[0], neighbors.shape[1], -1) 89 | for _, l in enumerate(self.layers2): 90 | if _ != self.n_layers2 - 1: 91 | data = self.act2(l(data)) 92 | else: 93 | data = l(data) 94 | 95 | elif mode == 'res_cut': 96 | for _, l in enumerate(self.down): 97 | data = l(data) 98 | 99 | return data -------------------------------------------------------------------------------- /gnn_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch_geometric.nn import MessagePassing, global_mean_pool, InstanceNorm, avg_pool_x, BatchNorm 6 | # from einops import rearrange 7 | 8 | from IPython import embed 9 | 10 | class Swish(nn.Module): 11 | def __init__(self, beta=1): 12 | super(Swish, self).__init__() 13 | self.beta = beta 14 | 15 | def forward(self, x): 16 | return x * torch.sigmoid(self.beta*x) 17 | 18 | 19 | class GNN_Layer_FS_2D(MessagePassing): 20 | """ 21 | Parameters 22 | ---------- 23 | in_features : int 24 | Dimensionality of input features. 25 | out_features : int 26 | Dimensionality of output features. 27 | hidden_features : int 28 | Dimensionality of hidden features. 
29 | """ 30 | def __init__(self, 31 | in_features, 32 | out_features, 33 | hidden_features, 34 | time_window, 35 | n_variables): 36 | super(GNN_Layer_FS_2D, self).__init__(node_dim=-2, aggr='mean') 37 | 38 | self.message_net_1 = nn.Sequential(nn.Linear(2 * in_features + time_window + 2 + n_variables, hidden_features), 39 | nn.ReLU() 40 | ) 41 | self.message_net_2 = nn.Sequential(nn.Linear(hidden_features, out_features), 42 | nn.ReLU() 43 | ) 44 | self.update_net_1 = nn.Sequential(nn.Linear(in_features + hidden_features + n_variables, hidden_features), 45 | nn.ReLU() 46 | ) 47 | self.update_net_2 = nn.Sequential(nn.Linear(hidden_features, out_features), 48 | nn.ReLU() 49 | ) 50 | 51 | self.norm = BatchNorm(hidden_features) 52 | 53 | def forward(self, x, u, pos_x, pos_y, variables, edge_index, batch): 54 | """ Propagate messages along edges """ 55 | x = self.propagate(edge_index, x=x, u=u, pos_x=pos_x, pos_y=pos_y, variables=variables) 56 | x = self.norm(x) 57 | return x 58 | 59 | def message(self, x_i, x_j, u_i, u_j, pos_x_i, pos_x_j, pos_y_i, pos_y_j, variables_i): 60 | """ Message update """ 61 | message = self.message_net_1(torch.cat((x_i, x_j, u_i - u_j, pos_x_i - pos_x_j, pos_y_i - pos_y_j, variables_i), dim=-1)) 62 | message = self.message_net_2(message) 63 | return message 64 | 65 | def update(self, message, x, variables): 66 | """ Node update """ 67 | update = self.update_net_1(torch.cat((x, message, variables), dim=-1)) 68 | update = self.update_net_2(update) 69 | return x + update 70 | 71 | 72 | class MP_PDE_Solver_2D(torch.nn.Module): 73 | def __init__( 74 | self, 75 | pde, 76 | time_window=1, 77 | hidden_features=128, 78 | hidden_layer=6, 79 | eq_variables={} 80 | ): 81 | 82 | super(MP_PDE_Solver_2D, self).__init__() 83 | self.pde = pde 84 | self.out_features = time_window 85 | self.hidden_features = hidden_features 86 | self.hidden_layer = hidden_layer 87 | self.time_window = time_window 88 | self.eq_variables = eq_variables 89 | 90 | # in_features have 
to be of the same size as out_features for the time being 91 | self.gnn_layers = torch.nn.ModuleList(modules=(GNN_Layer_FS_2D( 92 | in_features=self.hidden_features, 93 | hidden_features=self.hidden_features, 94 | out_features=self.hidden_features, 95 | time_window=self.time_window, 96 | n_variables=len(self.eq_variables) + 1 # variables = eq_variables + time 97 | ) for _ in range(self.hidden_layer))) 98 | 99 | self.embedding_mlp = nn.Sequential( 100 | nn.Linear(self.time_window + 3 + len(self.eq_variables), self.hidden_features), 101 | nn.BatchNorm1d(self.hidden_features), 102 | nn.ReLU(), 103 | nn.Linear(self.hidden_features, self.hidden_features), 104 | nn.BatchNorm1d(self.hidden_features) 105 | #Swish() 106 | ) 107 | 108 | self.output_mlp = nn.Sequential(nn.Conv1d(1, 4, 16, stride=3), 109 | # nn.BatchNorm1d(8), 110 | nn.ReLU(), 111 | nn.Conv1d(4, 8, 12, stride=3), 112 | nn.ReLU(), 113 | nn.Conv1d(8, 1, 8, stride=2) 114 | ) 115 | 116 | def __repr__(self): 117 | return f'GNN' 118 | 119 | def forward(self, data): 120 | u = data.x 121 | pos = data.pos 122 | pos_x = pos[:, 1][:, None]/self.pde.Lx 123 | pos_y = pos[:, 2][:, None]/self.pde.Ly 124 | pos_t = pos[:, 0][:, None]/self.pde.tmax 125 | edge_index = data.edge_index 126 | batch = data.batch 127 | 128 | variables = pos_t # we put the time as equation variable 129 | 130 | node_input = torch.cat((u, pos_x, pos_y, variables), -1) 131 | h = self.embedding_mlp(node_input) 132 | for i in range(self.hidden_layer): 133 | h = self.gnn_layers[i](h, u, pos_x, pos_y, variables, edge_index, batch) 134 | 135 | 136 | diff = self.output_mlp(h[:, None]).squeeze(1) 137 | dt = (torch.ones(1, self.time_window) * self.pde.dt * 0.1).to(h.device) 138 | dt = torch.cumsum(dt, dim=1) 139 | out = dt * diff 140 | 141 | return out -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: mmpde 2 | channels: 3 | - 
rusty1s 4 | - pytorch 5 | - conda-forge 6 | - anaconda 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - _sysroot_linux-64_curr_repodata_hack=3=haa98f57_10 12 | - binutils_impl_linux-64=2.38=h2a08ee3_1 13 | - binutils_linux-64=2.38.0=hc2dff05_0 14 | - blas=1.0=mkl 15 | - bottleneck=1.3.5=py38h7deecbd_0 16 | - brotli=1.0.9=h5eee18b_7 17 | - brotli-bin=1.0.9=h5eee18b_7 18 | - brotli-python=1.0.9=py38h6a678d5_7 19 | - c-ares=1.19.1=h5eee18b_0 20 | - ca-certificates=2023.08.22=h06a4308_0 21 | - certifi=2023.7.22=py38h06a4308_0 22 | - cffi=1.15.1=py38h5eee18b_3 23 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 24 | - colorama=0.4.6=pyhd8ed1ab_0 25 | - contourpy=1.0.5=py38hdb19cb5_0 26 | - cryptography=41.0.3=py38hdda0065_0 27 | - cudatoolkit=11.1.1=ha002fc5_10 28 | - cycler=0.11.0=pyhd3eb1b0_0 29 | - cyrus-sasl=2.1.28=h52b45da_1 30 | - dbus=1.13.18=hb2f20db_0 31 | - expat=2.5.0=h6a678d5_0 32 | - fontconfig=2.14.1=h4c34cd2_2 33 | - fonttools=4.25.0=pyhd3eb1b0_0 34 | - freetype=2.12.1=h4a9f257_0 35 | - gcc_impl_linux-64=11.2.0=h1234567_1 36 | - gcc_linux-64=11.2.0=h5c386dc_0 37 | - giflib=5.2.1=h5eee18b_3 38 | - glib=2.69.1=he621ea3_2 39 | - googledrivedownloader=0.4=pyhd3deb0d_1 40 | - gst-plugins-base=1.14.1=h6a678d5_1 41 | - gstreamer=1.14.1=h5eee18b_1 42 | - h5py=3.9.0=py38he06866b_0 43 | - hdf5=1.12.1=h2b7332f_3 44 | - icu=73.1=h6a678d5_0 45 | - idna=3.4=py38h06a4308_0 46 | - importlib_resources=6.1.0=py38h06a4308_0 47 | - intel-openmp=2023.1.0=hdb19cb5_46306 48 | - jinja2=3.1.2=pyhd8ed1ab_1 49 | - joblib=1.2.0=py38h06a4308_0 50 | - jpeg=9e=h5eee18b_1 51 | - kernel-headers_linux-64=3.10.0=h57e8cba_10 52 | - kiwisolver=1.4.4=py38h6a678d5_0 53 | - krb5=1.20.1=h143b758_1 54 | - lcms2=2.12=h3be6417_0 55 | - ld_impl_linux-64=2.38=h1181459_1 56 | - lerc=3.0=h295c915_0 57 | - libbrotlicommon=1.0.9=h5eee18b_7 58 | - libbrotlidec=1.0.9=h5eee18b_7 59 | - libbrotlienc=1.0.9=h5eee18b_7 60 | - libclang=14.0.6=default_hc6dbbc7_1 61 | - 
libclang13=14.0.6=default_he11475f_1 62 | - libcups=2.4.2=h2d74bed_1 63 | - libcurl=7.88.1=h251f7ec_2 64 | - libdeflate=1.17=h5eee18b_1 65 | - libedit=3.1.20221030=h5eee18b_0 66 | - libev=4.33=h7f8727e_1 67 | - libffi=3.4.4=h6a678d5_0 68 | - libgcc-devel_linux-64=11.2.0=h1234567_1 69 | - libgcc-ng=11.2.0=h1234567_1 70 | - libgfortran-ng=11.2.0=h00389a5_1 71 | - libgfortran5=11.2.0=h1234567_1 72 | - libgomp=11.2.0=h1234567_1 73 | - libllvm14=14.0.6=hdb19cb5_3 74 | - libnghttp2=1.57.0=h2d74bed_0 75 | - libpng=1.6.39=h5eee18b_0 76 | - libpq=12.15=hdbd6064_1 77 | - libssh2=1.10.0=hdbd6064_2 78 | - libstdcxx-ng=11.2.0=h1234567_1 79 | - libtiff=4.5.1=h6a678d5_0 80 | - libuuid=1.41.5=h5eee18b_0 81 | - libuv=1.43.0=h7f98852_0 82 | - libwebp=1.3.2=h11a3e52_0 83 | - libwebp-base=1.3.2=h5eee18b_0 84 | - libxcb=1.15=h7f8727e_0 85 | - libxkbcommon=1.0.1=h5eee18b_1 86 | - libxml2=2.10.4=hf1b16e4_1 87 | - lz4-c=1.9.4=h6a678d5_0 88 | - markupsafe=2.1.1=py38h7f8727e_0 89 | - matplotlib=3.7.2=py38h06a4308_0 90 | - matplotlib-base=3.7.2=py38h1128e8f_0 91 | - mkl=2023.1.0=h213fc3f_46344 92 | - mkl-service=2.4.0=py38h5eee18b_1 93 | - mkl_fft=1.3.8=py38h5eee18b_0 94 | - mkl_random=1.2.4=py38hdb19cb5_0 95 | - munkres=1.1.4=py_0 96 | - mysql=5.7.24=h721c034_2 97 | - ncurses=6.4=h6a678d5_0 98 | - networkx=2.8.8=pyhd8ed1ab_0 99 | - ninja=1.11.0=h924138e_0 100 | - numexpr=2.8.4=py38hc78ab66_1 101 | - numpy=1.24.3=py38hf6e8229_1 102 | - numpy-base=1.24.3=py38h060ed82_1 103 | - openjpeg=2.4.0=h3ad879b_0 104 | - openssl=3.0.12=h7f8727e_0 105 | - packaging=23.1=py38h06a4308_0 106 | - pandas=2.0.3=py38h1128e8f_0 107 | - pcre=8.45=h295c915_0 108 | - pillow=10.0.1=py38ha6cbd5a_0 109 | - pip=23.3=py38h06a4308_0 110 | - platformdirs=3.10.0=py38h06a4308_0 111 | - ply=3.11=py38_0 112 | - pooch=1.7.0=py38h06a4308_0 113 | - pycparser=2.21=pyhd3eb1b0_0 114 | - pyopenssl=23.2.0=py38h06a4308_0 115 | - pyparsing=3.0.9=py38h06a4308_0 116 | - pyqt=5.15.10=py38h6a678d5_0 117 | - pyqt5-sip=12.13.0=py38h5eee18b_0 
118 | - pysocks=1.7.1=py38h06a4308_0 119 | - python=3.8.18=h955ad1f_0 120 | - python-dateutil=2.8.2=pyhd3eb1b0_0 121 | - python-louvain=0.16=pyhd8ed1ab_0 122 | - python-tzdata=2023.3=pyhd8ed1ab_0 123 | - python_abi=3.8=2_cp38 124 | - pytorch=1.9.1=py3.8_cuda11.1_cudnn8.0.5_0 125 | - pytorch-cluster=1.5.9=py38_torch_1.9.0_cu111 126 | - pytorch-geometric=2.0.3=py38_torch_1.9.0_cu111 127 | - pytorch-scatter=2.0.9=py38_torch_1.9.0_cu111 128 | - pytorch-sparse=0.6.12=py38_torch_1.9.0_cu111 129 | - pytorch-spline-conv=1.2.1=py38_torch_1.9.0_cu111 130 | - pytz=2023.3.post1=pyhd8ed1ab_0 131 | - pyyaml=6.0=py38h0a891b7_4 132 | - qt-main=5.15.2=h53bd1ea_10 133 | - readline=8.2=h5eee18b_0 134 | - requests=2.31.0=py38h06a4308_0 135 | - scikit-learn=1.3.0=py38h1128e8f_0 136 | - scipy=1.10.1=py38hf6e8229_1 137 | - setuptools=68.0.0=py38h06a4308_0 138 | - sip=6.7.12=py38h6a678d5_0 139 | - six=1.16.0=pyhd3eb1b0_1 140 | - sqlite=3.41.2=h5eee18b_0 141 | - sysroot_linux-64=2.17=h57e8cba_10 142 | - tbb=2021.8.0=hdb19cb5_0 143 | - threadpoolctl=2.2.0=pyh0d69192_0 144 | - tk=8.6.12=h1ccaba5_0 145 | - tomli=2.0.1=py38h06a4308_0 146 | - torchaudio=0.9.1=py38 147 | - torchvision=0.15.2=cpu_py38h83e0c9b_0 148 | - tornado=6.3.3=py38h5eee18b_0 149 | - tqdm=4.66.1=pyhd8ed1ab_0 150 | - typing_extensions=4.8.0=pyha770c72_0 151 | - urllib3=1.26.18=py38h06a4308_0 152 | - wheel=0.41.2=py38h06a4308_0 153 | - xz=5.4.2=h5eee18b_0 154 | - yacs=0.1.8=pyhd8ed1ab_0 155 | - yaml=0.2.5=h7f98852_2 156 | - zipp=3.11.0=py38h06a4308_0 157 | - zlib=1.2.13=h5eee18b_0 158 | - zstd=1.5.5=hc292b87_0 159 | - pip: 160 | - absl-py==2.0.0 161 | - asttokens==2.4.1 162 | - backcall==0.2.0 163 | - cachetools==5.3.2 164 | - cftime==1.6.3 165 | - comm==0.2.0 166 | - common==0.1.2 167 | - debugpy==1.8.0 168 | - decorator==5.1.1 169 | - executing==2.0.1 170 | - ffmpeg==1.4 171 | - google-auth==2.23.4 172 | - google-auth-oauthlib==1.0.0 173 | - grpcio==1.59.2 174 | - importlib-metadata==6.8.0 175 | - ipykernel==6.26.0 176 | - 
ipython==8.12.3 177 | - jedi==0.19.1 178 | - jupyter-client==8.6.0 179 | - jupyter-core==5.5.0 180 | - kaleido==0.2.1 181 | - llvmlite==0.41.1 182 | - markdown==3.5.1 183 | - matplotlib-inline==0.1.6 184 | - mpmath==1.3.0 185 | - nest-asyncio==1.5.8 186 | - netcdf4==1.6.5 187 | - normalization==0.4 188 | - numba==0.58.1 189 | - oauthlib==3.2.2 190 | - parso==0.8.3 191 | - pexpect==4.8.0 192 | - phiflow==2.5.1 193 | - phiml==1.2.1 194 | - pickleshare==0.7.5 195 | - plotly==5.18.0 196 | - prompt-toolkit==3.0.40 197 | - protobuf==4.25.0 198 | - psutil==5.9.6 199 | - ptyprocess==0.7.0 200 | - pure-eval==0.2.2 201 | - py-pde==0.33.1 202 | - pyasn1==0.5.0 203 | - pyasn1-modules==0.3.0 204 | - pygments==2.16.1 205 | - pytorch-minimize==0.0.2 206 | - pyzmq==25.1.1 207 | - requests-oauthlib==1.3.1 208 | - rsa==4.9 209 | - seaborn==0.13.1 210 | - stack-data==0.6.3 211 | - sympy==1.12 212 | - tenacity==8.2.3 213 | - tensorboard==2.14.0 214 | - tensorboard-data-server==0.7.2 215 | - torchdiffeq==0.2.3 216 | - traitlets==5.13.0 217 | - wcwidth==0.2.9 218 | - werkzeug==3.0.1 219 | prefix: /anaconda/envs/mmpde 220 | -------------------------------------------------------------------------------- /models_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from PDEs import PDE 5 | 6 | from IPython import embed 7 | 8 | class BaseCNN(nn.Module): 9 | ''' 10 | A simple baseline 2d Res CNN approach, the time dimension is stacked in the channels 11 | ''' 12 | def __init__(self, 13 | pde: PDE, 14 | time_window: int = 25, 15 | hidden_channels: int = 40, 16 | padding_mode: str = f'circular') -> None: 17 | """ 18 | Initialize the simple CNN architecture. It contains 8 2d CNN-layers with skip connections 19 | and increasing receptive field. 20 | The input to the forward pass has the shape [batch, time_window, x]. 21 | The output has the shape [batch, time_window, x]. 
22 | Args: 23 | pde (PDE): the PDE at hand 24 | time_window (int): input/output timesteps of the trajectory 25 | hidden_channels: hidden channel dimension 26 | padding_mode (str): circular mode as default for periodic boundary problems 27 | Returns: 28 | None 29 | """ 30 | super().__init__() 31 | self.pde = pde 32 | self.time_window = time_window 33 | self.hidden_channels = hidden_channels 34 | self.padding_mode = padding_mode 35 | 36 | self.conv1 = nn.Conv2d(in_channels=self.time_window, out_channels=self.hidden_channels, kernel_size=3, padding=1, 37 | padding_mode=self.padding_mode, bias=True) 38 | self.conv2 = nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=5, padding=2, 39 | padding_mode=self.padding_mode, bias=True) 40 | self.conv3 = nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=5, padding=2, 41 | padding_mode=self.padding_mode, bias=True) 42 | self.conv4 = nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=5, padding=2, 43 | padding_mode=self.padding_mode, bias=True) 44 | self.conv5 = nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=7, padding=3, 45 | padding_mode=self.padding_mode, bias=True) 46 | self.conv6 = nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=7, padding=3, 47 | padding_mode=self.padding_mode, bias=True) 48 | self.conv7 = nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=7, padding=3, 49 | padding_mode=self.padding_mode, bias=True) 50 | self.conv8 = nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.time_window, kernel_size=9, padding=4, 51 | padding_mode=self.padding_mode, bias=True) 52 | 53 | nn.init.xavier_uniform_(self.conv1.weight) 54 | nn.init.xavier_uniform_(self.conv2.weight) 55 | nn.init.xavier_uniform_(self.conv3.weight) 56 | nn.init.xavier_uniform_(self.conv4.weight) 57 | 
nn.init.xavier_uniform_(self.conv5.weight) 58 | nn.init.xavier_uniform_(self.conv6.weight) 59 | nn.init.xavier_uniform_(self.conv7.weight) 60 | nn.init.xavier_uniform_(self.conv8.weight) 61 | 62 | 63 | 64 | def __repr__(self): 65 | return f'BaseCNN' 66 | 67 | def forward(self, u): 68 | """Forward pass of solver 69 | """ 70 | 71 | x = F.elu(self.conv1(u)) 72 | x = x + F.elu(self.conv2(x)) 73 | x = x + F.elu(self.conv3(x)) 74 | x = x + F.elu(self.conv4(x)) 75 | x = x + F.elu(self.conv5(x)) 76 | x = x + F.elu(self.conv6(x)) 77 | x = x + F.elu(self.conv7(x)) 78 | x = self.conv8(x) 79 | 80 | dt = (torch.ones(1, self.time_window) * self.pde.dt).to(x.device) 81 | dt = torch.cumsum(dt, dim=1)[None, :, :, None, None] 82 | out = u[:, -1, :, :][:, None, None, :, :].repeat(1, 1, self.time_window, 1, 1) + dt * x[:, None, :, :, :] 83 | return out.squeeze() 84 | 85 | 86 | class BaseCNN3d(nn.Module): 87 | ''' 88 | A simple baseline 2d Res CNN approach, the time dimension is stacked in the channels 89 | ''' 90 | def __init__(self, 91 | pde: PDE, 92 | time_window: int = 25, 93 | hidden_channels: int = 40, 94 | padding_mode: str = f'circular') -> None: 95 | """ 96 | Initialize the simple CNN architecture. It contains 8 2d CNN-layers with skip connections 97 | and increasing receptive field. 98 | The input to the forward pass has the shape [batch, time_window, x]. 99 | The output has the shape [batch, time_window, x]. 
100 | Args: 101 | pde (PDE): the PDE at hand 102 | time_window (int): input/output timesteps of the trajectory 103 | hidden_channels: hidden channel dimension 104 | padding_mode (str): circular mode as default for periodic boundary problems 105 | Returns: 106 | None 107 | """ 108 | super().__init__() 109 | self.pde = pde 110 | self.time_window = time_window 111 | self.hidden_channels = hidden_channels 112 | self.padding_mode = padding_mode 113 | 114 | self.conv1 = nn.Conv3d(in_channels=self.time_window, out_channels=self.hidden_channels, kernel_size=3, padding=1, 115 | padding_mode=self.padding_mode, bias=True) 116 | self.conv2 = nn.Conv3d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=3, padding=1, 117 | padding_mode=self.padding_mode, bias=True) 118 | self.conv3 = nn.Conv3d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=3, padding=1, 119 | padding_mode=self.padding_mode, bias=True) 120 | self.conv4 = nn.Conv3d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=3, padding=1, 121 | padding_mode=self.padding_mode, bias=True) 122 | self.conv5 = nn.Conv3d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=3, padding=1, 123 | padding_mode=self.padding_mode, bias=True) 124 | # self.conv6 = nn.Conv3d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=3, padding=1, 125 | self.conv6 = nn.Conv3d(in_channels=self.hidden_channels, out_channels=self.time_window, kernel_size=3, padding=1, 126 | padding_mode=self.padding_mode, bias=True) 127 | self.conv7 = nn.Conv3d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=3, padding=1, 128 | padding_mode=self.padding_mode, bias=True) 129 | self.conv8 = nn.Conv3d(in_channels=self.hidden_channels, out_channels=self.time_window, kernel_size=9, padding=4, 130 | padding_mode=self.padding_mode, bias=True) 131 | 132 | 
nn.init.xavier_uniform_(self.conv1.weight) 133 | nn.init.xavier_uniform_(self.conv2.weight) 134 | nn.init.xavier_uniform_(self.conv3.weight) 135 | nn.init.xavier_uniform_(self.conv4.weight) 136 | nn.init.xavier_uniform_(self.conv5.weight) 137 | nn.init.xavier_uniform_(self.conv6.weight) 138 | nn.init.xavier_uniform_(self.conv7.weight) 139 | nn.init.xavier_uniform_(self.conv8.weight) 140 | 141 | 142 | 143 | def __repr__(self): 144 | return f'BaseCNN' 145 | 146 | def forward(self, u): 147 | """Forward pass of solver 148 | """ 149 | 150 | x = F.elu(self.conv1(u)) 151 | x = x + F.elu(self.conv2(x)) 152 | x = x + F.elu(self.conv3(x)) 153 | x = x + F.elu(self.conv4(x)) 154 | x = x + F.elu(self.conv5(x)) 155 | x = x + F.elu(self.conv6(x)) 156 | x = self.conv6(x) 157 | # x = x + F.elu(self.conv7(x)) 158 | # x = self.conv8(x) 159 | 160 | dt = (torch.ones(1, self.time_window) * self.pde.dt).to(x.device) 161 | dt = torch.cumsum(dt, dim=1)[None, :, :, None, None, None] 162 | out = u[:, -1, :, :, :][:, None, None, :, :, :].repeat(1, 1, self.time_window, 1, 1, 1) + dt * x[:, None, :, :, :, :] 163 | return out.squeeze() -------------------------------------------------------------------------------- /mesh/dmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import matplotlib.pyplot as plt 5 | import argparse 6 | import datetime 7 | 8 | from dmm_model import DMM 9 | from dmm_utils import * 10 | 11 | 12 | def mkdir(path): 13 | folder = os.path.exists(path) 14 | if not folder: 15 | os.makedirs(path) 16 | 17 | 18 | def get_args(argv=None): 19 | parser = argparse.ArgumentParser(description = 'Put your hyperparameters') 20 | 21 | parser.add_argument('--experiment', default='burgers', type=str, help='experiment: burgers | cy') 22 | parser.add_argument('--seed', default=0, type=int, help='random seed') 23 | parser.add_argument('--device', type=str, default='cuda:0', help='used device') 24 | 
parser.add_argument('--sub_u', default=4, type=int, help='subsample number when sampling') 25 | parser.add_argument('--train_sample_grid', default=5000, type=int, help='number of training grids per u') # 5000, 1500 26 | parser.add_argument('--test_grid_size', default=[6, 10, 20, 40], type=int, help='grid size for plotting') 27 | parser.add_argument('--branch_layers', type=lambda s: [int(item) for item in s.split(',')], default=7, metavar='N',\ 28 | help='number of hidden nodes of branch network') # 7, [4, 3] 29 | parser.add_argument('--trunk_layers', type=lambda s: [int(item) for item in s.split(',')], default=[32, 512], metavar='N',\ 30 | help='number of hidden nodes of trunk network') # [32, 512], [16, 512] 31 | parser.add_argument('--out_layers', type=lambda s: [int(item) for item in s.split(',')], default=[1024, 512, 1], metavar='N',\ 32 | help='number of hidden nodes of decoder network') 33 | parser.add_argument('--bound_constraint', default='soft', type=str, help='constraint of boundary condition: soft | hard') 34 | parser.add_argument('--batch_size_x_adam', default=120, type=int, help='batch size of training grids per u') # 120 35 | parser.add_argument('--batch_size_u_adam', default=160, type=int, help='batch size of u (should be divisible by sub_u)') # 160 36 | parser.add_argument('--batch_size_x_lbfgs', default=100, type=int, help='batch size') # 100 37 | parser.add_argument('--batch_size_u_lbfgs', default=120, type=int, help='batch size') # 120 38 | 39 | parser.add_argument('--rf', default=True, type=eval, help='random feature: True | False') 40 | parser.add_argument('--rf_opt_alg', default='BFGS', type=str, help='optimization algorithm of random feature method: BFGS | Newton') 41 | parser.add_argument('--convex_rel', default=0.00, type=float, help='hyperparameter of convex relaxation') 42 | parser.add_argument('--batch_size_x_rf', default=16, type=int, help='batch size') # 100 43 | parser.add_argument('--batch_size_u_rf', default=20, type=int, 
help='batch size') # 120 44 | parser.add_argument('--loss_bound_rf', default=True, type=eval, help='bound constraint of random feature method: True | False') 45 | parser.add_argument('--max_iter', default=300, type=int, help='max iteration of rf algorithm') 46 | parser.add_argument('--epochs_adam', default=150, type=int, help='number of epochs of Adam optimizer') # 200 47 | parser.add_argument('--epochs_lbfgs', default=0, type=int, help='number of epochs of LBFGS optimizer') # 25, 0 48 | parser.add_argument('--epochs_rf', default=5, type=int, help='number of epochs of random feature') 49 | parser.add_argument('--lr_adam', default=2e-4, type=float, help='learning rate') # 2e-4 50 | parser.add_argument('--lr_lbfgs', default=1e-3, type=float, help='learning rate') 51 | parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay') 52 | parser.add_argument('--gamma_adam', default=0.2, type=float, help='gamma of Adam optimizer') 53 | parser.add_argument('--gamma_lbfgs', default=0.2, type=float, help='gamma of LBFGS optimizer') 54 | parser.add_argument('--loss_weight0', default=1, type=float, help='weight of loss_in') 55 | parser.add_argument('--loss_weight1', default=1000, type=float, help='weight of loss_bound') 56 | parser.add_argument('--loss_weight2', default=1, type=float, help='weight of loss_convex') 57 | parser.add_argument('--loss_convex', default=True, type=eval, help='convex constraint: True | False') 58 | 59 | return parser.parse_args(argv) 60 | 61 | 62 | if __name__ == "__main__": 63 | args = get_args() 64 | print(args) 65 | 66 | torch.manual_seed(args.seed) 67 | np.random.seed(args.seed) 68 | 69 | device = args.device 70 | 71 | if args.experiment == 'burgers': 72 | ori_u = torch.tensor(np.load('data/burgers_192.npy'), dtype=torch.float).to(device).reshape(-1, 192, 192) 73 | u = torch.tensor(np.load('data/burgers_192.npy'), dtype=torch.float)[:80, :, ::args.sub_u, ::args.sub_u].reshape(-1, int(192/args.sub_u), int(192/args.sub_u)) 74 | 
test_u = torch.tensor(np.load('data/burgers_192.npy'), dtype=torch.float)[80:, :, ::args.sub_u, ::args.sub_u].reshape(-1, int(192/args.sub_u), int(192/args.sub_u)) 75 | elif args.experiment == 'cy': 76 | ori_u = torch.load('data/cylinder_rot_tri') 77 | u = torch.load('data/cylinder_rot_tri')[:80, 10:].reshape(-1, ori_u.shape[-2], 5) 78 | # scale to a 1*1 square 79 | u[:, :, :2] *= 2 80 | test_u = torch.load('data/cylinder_rot_tri')[80:, 10:].reshape(-1, ori_u.shape[-2], 5) 81 | test_u[:, :, :2] *= 2 82 | 83 | if args.experiment == 'burgers': 84 | mkdir('burgers') 85 | model = DMM(s=u.shape[-1], mode='array', branch_layer = args.branch_layers, trunk_layer = [2] + args.trunk_layers, out_layer = args.out_layers).to(device) 86 | elif args.experiment == 'cy': 87 | mkdir('cy') 88 | model = DMM(mode='graph', grid = u[0, :, :2].to(device), branch_layer = args.branch_layers, trunk_layer = [2] + args.trunk_layers, out_layer = args.out_layers).to(device) 89 | 90 | print('Train moving mesh operator:') 91 | model, loss_in, loss_bound, loss_convex, test_equ_loss, test_equ_max, test_equ_min, test_equ_mid,\ 92 | train_std_list, train_minmax_list, test_std_list, test_minmax_list, itp_list1, itp_list2, logs_txt\ 93 | = train_MA_res(ori_u, u, test_u, args, model, init_mesh=False, n_epoch_adam=args.epochs_adam, n_epoch_lbfgs=args.epochs_lbfgs, device=device) 94 | print('Finish!') 95 | 96 | 97 | # plot mesh 98 | if args.experiment == 'burgers': 99 | for s in args.test_grid_size: 100 | fig, axes = plt.subplots(1, 5, figsize=(20, 3), dpi=500) 101 | fig, axes = plot_mesh_res(s, u, model, fig, axes, args, device) 102 | save_path = "{}/{}_{}_bound{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.png"\ 103 | .format(args.experiment, datetime.now(), args.rf, args.loss_bound_rf, args.epochs_rf, args.max_iter, args.sub_u,\ 104 | args.epochs_adam, args.batch_size_u_adam, args.batch_size_x_adam, args.loss_weight1, args.train_sample_grid, args.branch_layers, args.lr_adam, args.trunk_layers, s) 105 | 
plt.savefig(save_path) 106 | print(save_path) 107 | elif args.experiment == 'cy': 108 | for s in args.test_grid_size: 109 | fig, axes = plt.subplots(1, 5, figsize=(20, 3), dpi=500) 110 | fig, axes = plot_mesh_res_tri_s(s, u, model, fig, axes, args, device) 111 | save_path = "{}/{}_{}_bound{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.png"\ 112 | .format(args.experiment, datetime.now(), args.rf, args.loss_bound_rf, args.epochs_rf, args.max_iter, args.sub_u,\ 113 | args.epochs_adam, args.batch_size_u_adam, args.batch_size_x_adam, args.loss_weight1, args.train_sample_grid, args.branch_layers, args.lr_adam, args.trunk_layers, s) 114 | plt.savefig(save_path) 115 | print(save_path) 116 | fig, axes = plt.subplots(1, 5, figsize=(20, 3), dpi=500) 117 | fig, axes = plot_mesh_res_tri(u, model, fig, axes, args, device) 118 | savepath = "{}/{}_{}_bound{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.png"\ 119 | .format(args.experiment, datetime.now(), args.rf, args.loss_bound_rf, args.epochs_rf, args.max_iter, args.sub_u, args.epochs_adam, \ 120 | args.batch_size_u_adam, args.batch_size_x_adam, args.loss_weight1, args.train_sample_grid, args.branch_layers, args.lr_adam, args.trunk_layers) 121 | plt.savefig(savepath) 122 | print(savepath) 123 | 124 | with open("{}/{}_{}_bound_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.txt".format(args.experiment, datetime.now(), args.rf, args.loss_bound_rf, args.epochs_rf, args.max_iter, args.sub_u, args.epochs_adam, args.batch_size_u_adam, args.batch_size_x_adam, args.loss_weight1, args.train_sample_grid, args.branch_layers, args.lr_adam, args.trunk_layers),"w") as f: 125 | f.write('\n'.join(logs_txt)) 126 | -------------------------------------------------------------------------------- /train_helper_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from torch import nn, optim 4 | from torch.utils.data import DataLoader 5 | from data_creator_2d import GraphCreator_FS_2D 6 | # from PDEs import * 
7 | 8 | 9 | def training_itp(itp_model: torch.nn.Module, 10 | mesh_model: torch.nn.Module, 11 | unrolling: list, 12 | batch_size: int, 13 | optimizer: torch.optim, 14 | optimizer2: torch.optim, 15 | loader: DataLoader, 16 | graph_creator: GraphCreator_FS_2D, 17 | criterion: torch.nn.modules.loss, 18 | device: torch.cuda.device="cpu") -> torch.Tensor: 19 | """ 20 | One training epoch with random starting points for every trajectory 21 | Args: 22 | mesh_model (torch.nn.Module): moving mesh operator 23 | unrolling (list): list of different unrolling steps for each batch entry 24 | batch_size (int): batch size 25 | optimizer (torch.optim): optimizer used for training 26 | loader (DataLoader): training dataloader 27 | graph_creator (GraphCreator_FS_2D): helper object to handle graph data 28 | criterion (torch.nn.modules.loss): criterion for training 29 | device (torch.cuda.device): device (cpu/gpu) 30 | Returns: 31 | torch.Tensor: training losses 32 | """ 33 | 34 | losses = [] 35 | for (u_base, u_super) in loader: 36 | optimizer.zero_grad() 37 | if optimizer2 != None: 38 | optimizer2.zero_grad() 39 | # Randomly choose number of unrollings 40 | unrolled_graphs = random.choice(unrolling) 41 | steps = [t for t in range(graph_creator.tw, 42 | graph_creator.t_res - graph_creator.tw - (graph_creator.tw * unrolled_graphs) + 1)] 43 | # Randomly choose starting (time) point at the PDE solution manifold 44 | random_steps = random.choices(steps, k=batch_size) 45 | data, labels = graph_creator.create_data(u_super, random_steps) 46 | 47 | graph = graph_creator.create_graph(itp_model, data, labels, random_steps, device, mesh_model) 48 | 49 | itp_u = graph.x 50 | u_uni = graph_creator.interpolate_pred(itp_model, itp_u, graph, data, device) 51 | # data_uni = graph_creator.interpolate_label(data, device) 52 | data = data.to(device) 53 | loss = criterion(u_uni, data.reshape(-1, 1)) 54 | 55 | loss.backward() 56 | losses.append(loss.detach() / 2) 57 | optimizer.step() 58 | if optimizer2 != 
None: 59 | optimizer2.step() 60 | 61 | losses = torch.stack(losses) 62 | return losses 63 | 64 | 65 | def training_loop_branch(model: torch.nn.Module, 66 | model_b: torch.nn.Module, 67 | itp_model: torch.nn.Module, 68 | mesh_model: torch.nn.Module, 69 | unrolling: list, 70 | batch_size: int, 71 | optimizer: torch.optim, 72 | optimizer2: torch.optim, 73 | loader: DataLoader, 74 | graph_creator: GraphCreator_FS_2D, 75 | criterion: torch.nn.modules.loss, 76 | device: torch.cuda.device="cpu") -> torch.Tensor: 77 | """ 78 | One training epoch with random starting points for every trajectory 79 | Args: 80 | model (torch.nn.Module): neural network PDE solver 81 | model_b (torch.nn.Module): branch neural network PDE solver 82 | mesh_model (torch.nn.Module): moving mesh operator 83 | unrolling (list): list of different unrolling steps for each batch entry 84 | batc-h_size (int): batch size 85 | optimizer (torch.optim): optimizer used for training 86 | loader (DataLoader): training dataloader 87 | graph_creator (GraphCreator_FS_2D): helper object to handle graph data 88 | criterion (torch.nn.modules.loss): criterion for training 89 | device (torch.cuda.device): device (cpu/gpu) 90 | Returns: 91 | torch.Tensor: training losses 92 | """ 93 | 94 | losses = [] 95 | for idx, (u_base, u_super) in enumerate(loader): 96 | optimizer.zero_grad() 97 | if optimizer2 != None: 98 | optimizer2.zero_grad() 99 | # Randomly choose number of unrollings 100 | unrolled_graphs = random.choice(unrolling) 101 | steps = [t for t in range(graph_creator.tw, 102 | graph_creator.t_res - graph_creator.tw - (graph_creator.tw * unrolled_graphs) + 1)] 103 | # Randomly choose starting (time) point at the PDE solution manifold 104 | random_steps = random.choices(steps, k=batch_size) 105 | data, labels = graph_creator.create_data(u_super, random_steps) 106 | 107 | if f'{model}' == 'GNN': 108 | graph = graph_creator.create_graph(itp_model, data, labels, random_steps, device, mesh_model) 109 | graph_uni = 
graph_creator.create_graph(itp_model, data, labels, random_steps, device, None) 110 | else: 111 | data, labels = data.to(device), labels.to(device) 112 | 113 | # Unrolling of the equation which serves as input at the current step 114 | if f'{model}' == 'GNN': 115 | if mesh_model != None: 116 | pred = graph_creator.interpolate_pred(itp_model, model_b(graph), graph, data, device) + model(graph_uni) 117 | else: 118 | pred = model(graph_uni) 119 | # labels_uni = graph_creator.interpolate_label(labels, device) 120 | labels = labels.to(device) 121 | loss = criterion(pred, labels.reshape(-1, 1)) 122 | else: 123 | pred = model(data) 124 | loss = criterion(pred, labels.squeeze()) 125 | 126 | loss.backward() 127 | losses.append(loss.detach()) 128 | optimizer.step() 129 | if optimizer2 != None: 130 | if idx % 1 == 0: 131 | optimizer2.step() 132 | 133 | losses = torch.stack(losses) 134 | return losses 135 | 136 | 137 | def test_timestep_losses(model: torch.nn.Module, 138 | model_b: torch.nn.Module, 139 | itp_model: torch.nn.Module, 140 | mesh_model: torch.nn.Module, 141 | steps: list, 142 | batch_size: int, 143 | loader: DataLoader, 144 | graph_creator: GraphCreator_FS_2D, 145 | criterion: torch.nn.modules.loss, 146 | device: torch.cuda.device = "cpu") -> None: 147 | """ 148 | Loss for one neural network forward pass at certain timepoints on the validation/test datasets 149 | Args: 150 | model (torch.nn.Module): neural network PDE solver 151 | model_b (torch.nn.Module): branch neural network PDE solver 152 | mesh_model (torch.nn.Module): moving mesh operator 153 | steps (list): input list of possible starting (time) points 154 | batch_size (int): batch size 155 | loader (DataLoader): dataloader [valid, test] 156 | graph_creator (GraphCreator_FS_2D): helper object to handle graph data 157 | criterion (torch.nn.modules.loss): criterion for training 158 | device (torch.cuda.device): device (cpu/gpu) 159 | Returns: 160 | None 161 | """ 162 | 163 | losses_t = [] 164 | losses_uni_t 
= [] 165 | for step in steps: 166 | 167 | if (step != graph_creator.tw and step % graph_creator.tw != 0): 168 | continue 169 | 170 | losses = [] 171 | for (u_base, u_super) in loader: 172 | same_steps = [step]*batch_size 173 | data, labels = graph_creator.create_data(u_super, same_steps) 174 | if f'{model}' == 'GNN': 175 | if mesh_model != None: 176 | graph = graph_creator.create_graph(itp_model, data, labels, same_steps, device, mesh_model) 177 | graph_uni = graph_creator.create_graph(itp_model, data, labels, same_steps, device, None) 178 | with torch.no_grad(): 179 | if f'{model}' == 'GNN': 180 | if mesh_model != None: 181 | pred = graph_creator.interpolate_pred(itp_model, model_b(graph), graph, data, device) + model(graph_uni) 182 | else: 183 | pred = model(graph_uni) 184 | labels = labels.to(device) 185 | loss = criterion(pred, labels.reshape(-1, 1)) 186 | else: 187 | data, labels = data.to(device), labels.to(device) 188 | pred = model(data) 189 | loss = criterion(pred, labels.squeeze()) 190 | losses.append(loss) 191 | 192 | losses = torch.stack(losses) 193 | losses_t.append(torch.mean(losses)) 194 | if step % 2 == 1: 195 | print(f'Step {step}, time step loss {torch.mean(losses)}') 196 | 197 | losses_t = torch.stack(losses_t) 198 | print(f'Mean Timestep Test Error: {torch.mean(losses_t)}') 199 | 200 | return torch.mean(losses_t) 201 | 202 | -------------------------------------------------------------------------------- /mesh/dmm_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch_geometric.nn import MessagePassing, global_mean_pool, InstanceNorm, avg_pool_x, BatchNorm 5 | from torch_cluster import radius_graph, knn_graph 6 | from torch_geometric.data import Data 7 | 8 | 9 | class DenseNet(nn.Module): 10 | def __init__(self, layers, width=32, normalize=False): 11 | super(DenseNet, self).__init__() 12 | 13 | self.n_layers = len(layers) - 1 14 | 
assert self.n_layers >= 1 15 | self.layers = nn.ModuleList() 16 | self.act = torch.tanh 17 | self.normalize = normalize 18 | 19 | for j in range(self.n_layers): 20 | self.layers.append(nn.Linear(layers[j], layers[j+1])) 21 | 22 | if j != self.n_layers - 1: 23 | if normalize: 24 | self.layers.append(nn.BatchNorm1d(layers[j+1])) 25 | 26 | self.width = width 27 | self.center = torch.tensor([0.5,0.5], device="cuda").reshape(1,2) 28 | self.B = np.pi*torch.pow(2, torch.arange(0, self.width//4, dtype=torch.float, device="cuda")).reshape(1,1,1,self.width//4) 29 | self.fc0 = nn.Linear(4, self.width) 30 | 31 | def forward(self, x): 32 | if self.normalize: 33 | for _, l in enumerate(self.layers): 34 | if _ != 2 * self.n_layers - 2 and _%2 != 1: 35 | x = self.act(l(x)) 36 | else: 37 | out = l(x) 38 | else: 39 | for _, l in enumerate(self.layers): 40 | if _ != self.n_layers - 1: 41 | x = self.act(l(x)) 42 | else: 43 | out = l(x) 44 | 45 | return out, x 46 | 47 | 48 | class ConvNet(nn.Module): 49 | def __init__(self, s, layers): 50 | super().__init__() 51 | self.layers = nn.ModuleList() 52 | 53 | if layers == 7: 54 | self.layers.append(nn.Conv2d(1, 8, 5, stride=2, padding=2)) 55 | self.layers.append(nn.Conv2d(8, 16, 5, padding=2)) 56 | self.layers.append(nn.Conv2d(16, 8, 5, padding=2)) 57 | self.layers.append(nn.Conv2d(8, 1, 5, stride=2, padding=2)) 58 | self.fc1 = None 59 | self.fc2 = nn.Linear(int(((s + 1) / 2 + 1) / 2)**2, 1024) 60 | self.fc3 = nn.Linear(1024, 512) # burgers: 1024,512 61 | 62 | self.act = torch.tanh 63 | 64 | 65 | def forward(self, x): 66 | for i, l in enumerate(self.layers): 67 | # x = self.pool(self.act(l(x))) 68 | if i != len(self.layers) - 2: 69 | x = self.act(l(x)) 70 | if i == 0: 71 | ori_x = x 72 | if i == len(self.layers) - 2: 73 | x = self.act(ori_x + l(x)) 74 | x = torch.flatten(x, 1) # flatten all dimensions except batch 75 | # x = self.bn(x) 76 | if self.fc1 != None: 77 | x = self.act(self.fc1(x)) 78 | if self.fc2 != None: 79 | x = 
self.act(self.fc2(x)) 80 | x = self.fc3(x) 81 | return x 82 | 83 | 84 | 85 | class Swish(nn.Module): 86 | def __init__(self, beta=1): 87 | super(Swish, self).__init__() 88 | self.beta = beta 89 | 90 | def forward(self, x): 91 | return x * torch.sigmoid(self.beta*x) 92 | 93 | 94 | class GNN_Layer_FS_2D(MessagePassing): 95 | """ 96 | Parameters 97 | ---------- 98 | in_features : int 99 | Dimensionality of input features. 100 | out_features : int 101 | Dimensionality of output features. 102 | hidden_features : int 103 | Dimensionality of hidden features. 104 | """ 105 | def __init__(self, 106 | in_features, 107 | out_features, 108 | hidden_features,): 109 | super(GNN_Layer_FS_2D, self).__init__(node_dim=-2, aggr='mean') 110 | 111 | self.message_net_1 = nn.Sequential(nn.Linear(2 * in_features + 3, hidden_features), 112 | nn.Tanh() 113 | ) 114 | self.message_net_2 = nn.Sequential(nn.Linear(hidden_features, out_features), 115 | nn.Tanh() 116 | ) 117 | self.update_net_1 = nn.Sequential(nn.Linear(in_features + hidden_features, hidden_features), 118 | nn.Tanh() 119 | ) 120 | self.update_net_2 = nn.Sequential(nn.Linear(hidden_features, out_features), 121 | nn.Tanh() 122 | ) 123 | 124 | self.norm = BatchNorm(hidden_features) 125 | 126 | def forward(self, x, u, pos_x, pos_y, edge_index, batch): 127 | """ Propagate messages along edges """ 128 | x = self.propagate(edge_index, x=x, u=u, pos_x=pos_x, pos_y=pos_y) 129 | x = self.norm(x) 130 | return x 131 | 132 | def message(self, x_i, x_j, u_i, u_j, pos_x_i, pos_x_j, pos_y_i, pos_y_j): 133 | """ Message update """ 134 | message = self.message_net_1(torch.cat((x_i, x_j, u_i - u_j, pos_x_i - pos_x_j, pos_y_i - pos_y_j), dim=-1)) 135 | message = self.message_net_2(message) 136 | return message 137 | 138 | def update(self, message, x): 139 | """ Node update """ 140 | update = self.update_net_1(torch.cat((x, message), dim=-1)) 141 | update = self.update_net_2(update) 142 | return x + update 143 | 144 | 145 | class DMM(nn.Module): 146 
| def __init__(self, branch_layer, trunk_layer, grid=None, out_layer=None, s=None, mode='array'): 147 | super(DMM, self).__init__() 148 | self.mode = mode 149 | self.ori_grid = grid 150 | if mode == 'array': 151 | self.branch = ConvNet(s, branch_layer) 152 | self.trunk = DenseNet(trunk_layer) 153 | self.out_nn = DenseNet(out_layer) 154 | elif mode == 'graph': 155 | self.hidden_features = branch_layer[0] 156 | self.hidden_layer = branch_layer[1] 157 | 158 | # in_features have to be of the same size as out_features for the time being 159 | self.gnn_layers = torch.nn.ModuleList(modules=(GNN_Layer_FS_2D( 160 | in_features=self.hidden_features, 161 | hidden_features=self.hidden_features, 162 | out_features=self.hidden_features, 163 | ) for _ in range(self.hidden_layer))) 164 | 165 | self.embedding_mlp = nn.Sequential( 166 | nn.Linear(3, self.hidden_features), 167 | nn.BatchNorm1d(self.hidden_features), 168 | nn.Tanh(), 169 | nn.Linear(self.hidden_features, self.hidden_features), 170 | nn.BatchNorm1d(self.hidden_features) 171 | #Swish() 172 | ) 173 | self.decoding_mlp = DenseNet([self.hidden_features, 128, 1]) 174 | 175 | self.output_mlp = nn.Sequential( 176 | nn.Linear(grid.shape[0], 512), 177 | nn.Tanh(), 178 | nn.Linear(512, 256), 179 | nn.Tanh(), 180 | nn.Linear(256, trunk_layer[-1]) 181 | ) 182 | self.trunk = DenseNet(trunk_layer) 183 | self.out_nn = DenseNet(out_layer) 184 | 185 | def forward(self, u, grid, rf = False): 186 | if self.mode == 'array': 187 | branch = self.branch(u.unsqueeze(1)).unsqueeze(1).repeat(1, int(grid.shape[0]/u.shape[0]), 1) 188 | # (batchsize, S, S) -> (batchsize, latent) -> (batchsize, grid_per_u, latent) 189 | trunk, second_out = self.trunk(grid) 190 | out, second_out = self.out_nn(torch.cat((branch.reshape(-1, branch.shape[-1]), trunk.reshape(-1, branch.shape[-1])), dim=-1)) 191 | if rf == False: 192 | return out # (batchsize) 193 | else: 194 | return out, second_out, torch.ones_like(second_out).type_as(trunk).reshape(-1, 1) 195 | 196 | 
elif self.mode == 'graph': 197 | data = create_graph(u, self.ori_grid, device=u.device) 198 | x = data.x 199 | pos = data.pos 200 | pos_x = pos[:, 0][:, None] 201 | pos_y = pos[:, 1][:, None] 202 | edge_index = data.edge_index 203 | batch = data.batch 204 | 205 | node_input = torch.cat((x, pos_x, pos_y), -1) 206 | h = self.embedding_mlp(node_input) 207 | for i in range(self.hidden_layer): 208 | h = self.gnn_layers[i](h, x, pos_x, pos_y, edge_index, batch) 209 | h, _ = self.decoding_mlp(h) 210 | branch = self.output_mlp(h.reshape(u.shape[0], 1, -1)).repeat(1, int(grid.shape[0]/u.shape[0]), 1) 211 | trunk, _ = self.trunk(grid) 212 | 213 | out, second_out = self.out_nn(torch.cat((branch.reshape(-1, branch.shape[-1]), trunk.reshape(-1, branch.shape[-1])), dim=-1)) 214 | if rf == False: 215 | return out # (batchsize) 216 | # return trunk # (batchsize) 217 | else: 218 | return out, second_out, torch.ones_like(second_out).type_as(trunk).reshape(-1, 1) 219 | # return trunk, second_out, torch.ones((trunk.shape[0], 1)).type_as(trunk).reshape(-1, 1) 220 | 221 | 222 | def create_graph(data, grid, device, n=35): 223 | """ 224 | getting graph structure out of data sample 225 | """ 226 | 227 | batch = torch.arange(0, data.shape[0], 1).to(device)[:, None].repeat(1, grid.shape[0]).reshape(-1) 228 | edge_index = knn_graph(grid[None].repeat(data.shape[0], 1, 1).reshape(-1, 2), n, batch=batch.long(), loop=False) 229 | 230 | graph = Data(x=data.reshape(-1, 1), edge_index=edge_index) 231 | graph.pos = grid[None].repeat(data.shape[0], 1, 1).reshape(-1, 2) 232 | graph.batch = batch.long() 233 | 234 | return graph.to(device) 235 | -------------------------------------------------------------------------------- /data_creator_2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import numpy as np 5 | import torch 6 | import sys 7 | 8 | from torch.utils.data import Dataset 9 | from torch import nn 10 | from torch.nn 
import functional as F 11 | from torch_geometric.data import Data 12 | from torch_cluster import radius_graph, knn_graph 13 | from sklearn.neighbors import NearestNeighbors 14 | from interpolate import ItpNet 15 | # from einops import rearrange 16 | 17 | 18 | class GraphCreator_FS_2D(nn.Module): 19 | """ 20 | Helper class to construct graph datasets 21 | params: 22 | neighbors: now many neighbors the graph has in each direction 23 | time_window: how many time steps are used for PDE prediction 24 | time_ratio: time ratio between base and super resolution 25 | space_ratio: space ratio between base and super resolution 26 | """ 27 | 28 | def __init__(self, 29 | pde, 30 | neighbors: int=2, 31 | connect_edge: str='knn', 32 | time_window: int=10, 33 | t_resolution: int=100, 34 | ): 35 | super().__init__() 36 | self.pde = pde 37 | self.n = neighbors 38 | self.e = connect_edge 39 | self.tw = time_window 40 | self.t_res = t_resolution 41 | 42 | assert isinstance(self.n, int) 43 | assert isinstance(self.tw, int) 44 | 45 | 46 | def interpolate(self, itp_model, u, init_x, init_y, x, y, mode): 47 | """ 48 | u: (nu,nx,ny) 49 | init_x: (nu*nx*ny,1) 50 | init_y: (nu*nx*ny,1) 51 | x: (nu*nx'*ny',1) 52 | y: (nu*nx'*ny',1) 53 | return: interpolated: (nu*nx'*ny') 54 | """ 55 | nu = u.shape[0] 56 | nx = u.shape[-2] 57 | ny = u.shape[-1] 58 | output_res = int(x.shape[0] / nu) 59 | 60 | all_points = torch.cat((init_x, init_y), -1).reshape(nu, -1, 2) 61 | all_query_points = torch.cat((x, y), dim=-1).reshape(nu, -1, 2) # (8, 2304, 2) 62 | if mode == '1': 63 | n_neighbors = 30 64 | elif mode == '2': 65 | n_neighbors = 30 66 | knn = NearestNeighbors(n_neighbors=n_neighbors) 67 | weights = [] 68 | neighbors = [] 69 | neighbor_labels = [] 70 | for k in range(nu): 71 | labels = u[k].reshape(-1) 72 | points = all_points[k] 73 | query_points = all_query_points[k] 74 | 75 | knn.fit(points.detach().cpu().numpy()) 76 | distances, indices = knn.kneighbors(query_points.detach().cpu().numpy()) 77 | 
neighbors.append(points[indices].to(u.device)) 78 | neighbor_labels.append(labels[indices]) 79 | 80 | neighbors = torch.stack(neighbors) # [8, 2304, n, 2] 81 | neighbor_labels = torch.stack(neighbor_labels) # [8, 2304, n] 82 | weights = itp_model(neighbors, all_query_points.unsqueeze(-2), mode) 83 | interpolated = torch.sum(weights * neighbor_labels, dim=-1).reshape(-1) 84 | 85 | return interpolated 86 | 87 | 88 | def moving_mesh(self, u, mesh_model, n_grid_x, n_grid_y): 89 | """ 90 | getting the moved mesh 91 | u: (nu,nx,ny) 92 | return: x1, x2: (nu*nx*ny, 1) 93 | """ 94 | grid_x = np.linspace(0, self.pde.Lx, n_grid_x) 95 | grid_y = np.linspace(0, self.pde.Ly, n_grid_y) 96 | grid = torch.tensor(np.array(np.meshgrid(grid_x, grid_y)), dtype=torch.float).reshape(2, -1).permute(1, 0).to(u.device) 97 | xi1, xi2 = grid[:, [0]].unsqueeze(0).repeat(u.shape[0], 1, 1).reshape(-1, 1), grid[:, [1]].unsqueeze(0).repeat(u.shape[0], 1, 1).reshape(-1, 1) 98 | xi1.requires_grad = True 99 | xi2.requires_grad = True 100 | xi = torch.cat((xi1, xi2), dim=-1) 101 | 102 | if self.pde.movingmesh_grid_size[-2] != n_grid_x or self.pde.movingmesh_grid_size[-1] != n_grid_y: 103 | u = F.interpolate(u.reshape(-1, 1, u.shape[-2], u.shape[-1]), size=(self.pde.movingmesh_grid_size[-2], self.pde.movingmesh_grid_size[-1]), mode='bilinear', align_corners=True).squeeze(1) 104 | phi = mesh_model(u, xi) 105 | w = torch.ones(phi.shape).to(u.device) 106 | x1 = (torch.autograd.grad(phi, xi1, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi1) 107 | x2 = (torch.autograd.grad(phi, xi2, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi2) 108 | 109 | alpha = 1 110 | x1 = alpha * x1 + (1 - alpha) * xi1 111 | x2 = alpha * x2 + (1 - alpha) * xi2 112 | 113 | return x1, x2 114 | 115 | def moving_mesh_tri(self, u, mesh_model, grid_x, grid_y): 116 | """ 117 | getting the moved mesh 118 | u: (nu,n) 119 | grid_x: (n) 120 | grid_y: (n) 121 | return: x1, 
x2: (nu*n, 1) 122 | """ 123 | xi1, xi2 = grid_x.reshape(-1, 1), grid_y.reshape(-1, 1) 124 | xi1.requires_grad = True 125 | xi2.requires_grad = True 126 | xi = torch.cat((xi1, xi2), dim=-1) 127 | 128 | phi = mesh_model(u, xi) 129 | w = torch.ones(phi.shape).to(u.device) 130 | x1 = (torch.autograd.grad(phi, xi1, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi1) 131 | x2 = (torch.autograd.grad(phi, xi2, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi2) 132 | 133 | alpha = 1 134 | x1 = alpha * x1 + (1 - alpha) * xi1 135 | x2 = alpha * x2 + (1 - alpha) * xi2 136 | 137 | return x1, x2 138 | 139 | def create_data(self, datapoints, steps): 140 | """ 141 | getting data out of PDEs 142 | """ 143 | data = torch.Tensor() 144 | labels = torch.Tensor() 145 | 146 | for (dp, step) in zip(datapoints, steps): 147 | # d = dp[step - self.tw*2:step] 148 | d = dp[step - self.tw:step] 149 | l = dp[step:self.tw + step] 150 | 151 | data = torch.cat((data, d[None, :]), 0) 152 | labels = torch.cat((labels, l[None, :]), 0) 153 | 154 | return data, labels 155 | 156 | 157 | def create_graph(self, itp_model, data, labels, steps, device, mesh_model=None): 158 | """ 159 | getting moved mesh and interpolate data 160 | getting graph structure out of data sample 161 | previous timesteps are combined in one node 162 | """ 163 | data = data.to(device) 164 | labels = labels.to(device) 165 | 166 | if len(self.pde.grid_size) == 3: 167 | # h = 2 168 | ori_nx = data.shape[-2] 169 | ori_ny = data.shape[-1] 170 | ori_x = torch.linspace(0, self.pde.Lx, ori_nx).to(device) 171 | ori_y = torch.linspace(0, self.pde.Ly, ori_ny).to(device) 172 | ori_grid_x, ori_grid_y = torch.meshgrid(ori_x, ori_y) 173 | 174 | # h = 2 175 | mm_nx = self.pde.movingmesh_grid_size[-2] 176 | mm_ny = self.pde.movingmesh_grid_size[-1] 177 | mm_x = torch.linspace(0, self.pde.Lx, mm_nx).to(device) 178 | mm_y = torch.linspace(0, self.pde.Ly, mm_ny).to(device) 179 | 
mm_grid_x, mm_grid_y = torch.meshgrid(mm_x, mm_y) 180 | 181 | nt = self.pde.grid_size[0] 182 | nx = self.pde.grid_size[1] 183 | ny = self.pde.grid_size[2] 184 | n = nx * ny 185 | t = torch.linspace(self.pde.tmin, self.pde.tmax, nt).to(device) 186 | dt = t[1] - t[0] 187 | x = torch.linspace(0, self.pde.Lx, nx).to(device) 188 | dx = x[1]-x[0] 189 | y = torch.linspace(0, self.pde.Ly, ny).to(device) 190 | dy = y[1]-y[0] 191 | 192 | grid_x, grid_y = torch.meshgrid(x, y) 193 | grid = torch.stack((grid_x, grid_y), 2).float() 194 | grid = grid.view(-1, 2)[None].repeat(data.shape[0], 1, 1) 195 | radius = self.n * torch.sqrt(dx**2 + dy**2) + 0.0001 196 | 197 | if mesh_model != None: 198 | mesh_x, mesh_y = self.moving_mesh(data.reshape(-1, ori_nx, ori_ny)[:, ::int(ori_nx / mm_nx), ::int(ori_ny / mm_ny)], mesh_model, nx, ny) 199 | mesh = torch.cat((mesh_x, mesh_y), dim=-1).reshape(-1, nx*ny, 2) 200 | 201 | else: 202 | mesh_x, mesh_y = grid[:, :, 0].reshape(-1, 1), grid[:, :, 1].reshape(-1, 1) 203 | mesh = grid 204 | 205 | if mesh_model != None: 206 | data = self.interpolate(itp_model, data.reshape(-1, ori_nx, ori_ny), ori_grid_x[None].repeat(data.shape[0], 1, 1).reshape(-1, 1), ori_grid_y[None].repeat(data.shape[0], 1, 1).reshape(-1, 1),\ 207 | mesh_x, mesh_y, mode='1').reshape(-1, self.tw, nx, ny) 208 | labels = self.interpolate(itp_model, labels.reshape(-1, ori_nx, ori_ny), ori_grid_x[None].repeat(data.shape[0], 1, 1).reshape(-1, 1), ori_grid_y[None].repeat(data.shape[0], 1, 1).reshape(-1, 1),\ 209 | mesh_x, mesh_y, mode='1').reshape(-1, self.tw, nx, ny) 210 | 211 | elif len(self.pde.grid_size) == 2: 212 | # h = 2 213 | n = self.pde.ori_grid_size[1] 214 | grid = self.pde.ori_grid[None].repeat(data.shape[0], 1, 1).to(device) 215 | grid_x, grid_y = grid[:, :, 0], grid[:, :, 1] 216 | 217 | nt = self.pde.grid_size[0] 218 | nx = int(np.sqrt(self.pde.grid_size[1])) 219 | ny = int(np.sqrt(self.pde.grid_size[1])) 220 | t = torch.linspace(self.pde.tmin, self.pde.tmax, nt).to(device) 
221 | dt = t[1] - t[0] 222 | x = torch.linspace(0, self.pde.Lx, nx).to(device) 223 | dx = x[1]-x[0] 224 | y = torch.linspace(0, self.pde.Ly, ny).to(device) 225 | dy = y[1]-y[0] 226 | radius = self.n * torch.sqrt(dx**2 + dy**2) + 0.0001 227 | 228 | if mesh_model != None: 229 | mesh_x, mesh_y = self.moving_mesh_tri(data.reshape(-1, n), mesh_model, grid_x, grid_y) 230 | mesh = torch.cat((mesh_x, mesh_y), dim=-1).reshape(-1, n, 2) 231 | 232 | else: 233 | mesh_x, mesh_y = grid_x.reshape(-1, 1), grid_y.reshape(-1, 1) 234 | mesh = grid 235 | 236 | u_new = torch.Tensor().to(device) 237 | x_new = torch.Tensor().to(device) 238 | grid_new = torch.Tensor().to(device) 239 | t_new = torch.Tensor().to(device) 240 | y_new = torch.Tensor().to(device) 241 | batch = torch.Tensor().to(device) 242 | for b, (data_batch, mesh_batch, grid_batch, labels_batch, step) in enumerate(zip(data, mesh, grid, labels, steps)): 243 | u_tmp = torch.transpose(torch.cat([d.reshape(-1, n) for d in data_batch]), 0, 1) 244 | y_tmp = torch.transpose(torch.cat([l.reshape(-1, n) for l in labels_batch]), 0, 1) 245 | 246 | u_new = torch.cat((u_new, u_tmp), ) 247 | x_new = torch.cat((x_new, mesh_batch), ) 248 | grid_new = torch.cat((grid_new, grid_batch), ) 249 | y_new = torch.cat((y_new, y_tmp), ) 250 | b_new = torch.ones(n).to(device)*b 251 | t_tmp = torch.ones(n).to(device)*t[step] 252 | 253 | t_new = torch.cat((t_new, t_tmp), ) 254 | batch = torch.cat((batch, b_new), ) 255 | 256 | # calculating the edge_index 257 | if self.e == 'radius': 258 | edge_index = radius_graph(x_new, r=radius, batch=batch.long(), loop=False) 259 | elif self.e == 'knn': 260 | edge_index = knn_graph(x_new, k=self.n, batch=batch.long(), loop=False) 261 | 262 | graph = Data(x=u_new, edge_index=edge_index) 263 | graph.y = y_new 264 | graph.pos = torch.cat((t_new[:, None], x_new), 1) 265 | graph.batch = batch.long() 266 | 267 | return graph.to(device) 268 | 269 | 270 | def interpolate_pred(self, itp_model, pred, graph, data, device): 271 
| """ 272 | interpolating the prediction to the uniform mesh 273 | """ 274 | data = data.to(device) 275 | 276 | if len(self.pde.grid_size) == 3: 277 | ori_nx = self.pde.ori_grid_size[1] 278 | ori_ny = self.pde.ori_grid_size[2] 279 | ori_x = torch.linspace(0, self.pde.Lx, ori_nx).to(device) 280 | ori_y = torch.linspace(0, self.pde.Ly, ori_ny).to(device) 281 | ori_grid_x, ori_grid_y = torch.meshgrid(ori_x, ori_y) 282 | 283 | nx = self.pde.grid_size[1] 284 | ny = self.pde.grid_size[2] 285 | nu = int(pred.shape[0] / (nx * ny)) 286 | x = torch.linspace(0, self.pde.Lx, nx).to(device) 287 | y = torch.linspace(0, self.pde.Ly, ny).to(device) 288 | grid_x, grid_y = torch.meshgrid(x, y) 289 | 290 | pred_grid = self.interpolate(itp_model, pred.reshape(-1, nx, ny), graph.pos[:, [1]], graph.pos[:, [2]], ori_grid_x[None].repeat(nu, 1, 1).reshape(-1, 1),\ 291 | ori_grid_y[None].repeat(nu, 1, 1).reshape(-1, 1), mode='2').reshape(-1, 1, ori_nx, ori_ny) 292 | 293 | out = itp_model(None, None, mode = 'res_cut', data = data).reshape(-1, 1, ori_nx, ori_ny) + pred_grid 294 | 295 | elif len(self.pde.grid_size) == 2: 296 | n = self.pde.ori_grid_size[1] 297 | nu = int(pred.shape[0] / n) 298 | grid_x, grid_y = self.pde.ori_grid[:, 0].to(device), self.pde.ori_grid[:, 1].to(device) 299 | 300 | pred_grid = self.interpolate(itp_model, pred.reshape(-1, n), graph.pos[:, [1]], graph.pos[:, [2]], grid_x[None].repeat(nu, 1).reshape(-1, 1),\ 301 | grid_y[None].repeat(nu, 1).reshape(-1, 1), mode='2').reshape(-1, n) 302 | 303 | out = itp_model(None, None, mode = 'res_cut', data = data.reshape(-1, n)).reshape(-1, n) + pred_grid 304 | 305 | return out.reshape(-1, 1) 306 | -------------------------------------------------------------------------------- /mmpde.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import copy 4 | import sys 5 | import time 6 | from datetime import datetime 7 | import torch 8 | import random 9 | import numpy 
as np 10 | from matplotlib import pyplot as plt 11 | from torch import nn, optim 12 | from torch.nn import functional as F 13 | from torch.utils.data import DataLoader, TensorDataset 14 | from data_creator_2d import GraphCreator_FS_2D 15 | from gnn_2d import MP_PDE_Solver_2D 16 | from models_cnn import BaseCNN 17 | from train_helper_2d import * 18 | from PDEs import * 19 | from mesh.dmm_model import DMM 20 | from interpolate import ItpNet 21 | from torch.utils.tensorboard import SummaryWriter 22 | 23 | 24 | def check_directory() -> None: 25 | """ 26 | Check if log directory exists within experiments 27 | """ 28 | if not os.path.exists(f'logs'): 29 | os.mkdir(f'logs') 30 | if not os.path.exists(f'models'): 31 | os.mkdir(f'models') 32 | 33 | def criterion(x, y): 34 | mse = torch.nn.MSELoss() 35 | # return torch.sqrt(mse(x, y)) / torch.sqrt(mse(x, torch.zeros_like(x).to(x.device))) 36 | return mse(x, y) 37 | 38 | def train(args: argparse, 39 | pde: PDE, 40 | epoch: int, 41 | model: torch.nn.Module, 42 | model_b: torch.nn.Module, 43 | itp_model: torch.nn.Module, 44 | mesh_model: torch.nn.Module, 45 | optimizer: torch.optim, 46 | optimizer2: torch.optim, 47 | loader: DataLoader, 48 | graph_creator: GraphCreator_FS_2D, 49 | criterion: torch.nn.modules.loss, 50 | device: torch.cuda.device="cpu") -> None: 51 | """ 52 | Training loop. 53 | Loop is over the mini-batches and for every batch we pick a random timestep. 54 | This is done for the number of timesteps in our training sample, which covers a whole episode. 55 | Args: 56 | args (argparse): command line inputs 57 | pde (PDE): PDE at hand [CE, WE, ...] 
58 | model (torch.nn.Module): neural network PDE solver 59 | model_b (torch.nn.Module): branch neural network PDE solver 60 | itp_model (torch.nn.Module): neural network for interpolation 61 | mesh_model (torch.nn.Module): moving mesh operator 62 | optimizer (torch.optim): optimizer used for training 63 | loader (DataLoader): training dataloader 64 | graph_creator (GraphCreator_FS_2D): helper object to handle graph data 65 | criterion (torch.nn.modules.loss): criterion for training 66 | device (torch.cuda.device): device (cpu/gpu) 67 | Returns: 68 | None 69 | """ 70 | print(f'Starting epoch {epoch}...') 71 | model.train() 72 | if model_b != None: 73 | model_b.train() 74 | 75 | # Sample number of unrolling steps during training (pushforward trick) 76 | # Default is to unroll zero steps in the first epoch and then increase the max amount of unrolling steps per additional epoch. 77 | max_unrolling = epoch if epoch <= args.unrolling else args.unrolling 78 | unrolling = [r for r in range(max_unrolling + 1)] 79 | 80 | # Loop over every epoch as often as the number of timesteps in one trajectory. 81 | # Since the starting point is randomly drawn, this in expectation has every possible starting point/sample combination of the training data. 82 | # Therefore in expectation the whole available training information is covered. 
83 | 84 | itp_losses = [] 85 | if mesh_model != None: 86 | itp_model.train() 87 | if epoch == 0: 88 | for i in range(graph_creator.t_res): 89 | losses = training_itp(itp_model, mesh_model, unrolling, 128 * args.batch_size, optimizer, optimizer2, loader, graph_creator, criterion, device) 90 | if(i % args.print_interval == 0): 91 | print(f'Training ItpNet Loss (progress: {i / graph_creator.t_res:.2f}): {torch.mean(losses)}') 92 | itp_losses.append(torch.mean(losses)) 93 | train_losses = [] 94 | for i in range(graph_creator.t_res): 95 | losses = training_loop_branch(model, model_b, itp_model, mesh_model, unrolling, args.batch_size, optimizer, optimizer2, loader, graph_creator, criterion, device) 96 | if(i % args.print_interval == 0): 97 | print(f'Training Loss (progress: {i / graph_creator.t_res:.2f}): {torch.mean(losses)}') 98 | train_losses.append(torch.mean(losses)) 99 | 100 | return train_losses, itp_losses 101 | 102 | def test(args: argparse, 103 | pde: PDE, 104 | model: torch.nn.Module, 105 | model_b: torch.nn.Module, 106 | itp_model: torch.nn.Module, 107 | mesh_model: torch.nn.Module, 108 | loader: DataLoader, 109 | graph_creator: GraphCreator_FS_2D, 110 | criterion: torch.nn.modules.loss, 111 | device: torch.cuda.device="cpu") -> torch.Tensor: 112 | """ 113 | Test routine 114 | Both step wise and unrolled forward losses are computed 115 | and compared against low resolution solvers 116 | step wise = loss for one neural network forward pass at certain timepoints 117 | unrolled forward loss = unrolling of the whole trajectory 118 | Args: 119 | args (argparse): command line inputs 120 | pde (PDE): PDE at hand [CE, WE, ...] 
121 | model (torch.nn.Module): neural network PDE solver 122 | model_b (torch.nn.Module): branch neural network PDE solver 123 | itp_model (torch.nn.Module): neural network for interpolation 124 | mesh_model (torch.nn.Module): moving mesh operator 125 | loader (DataLoader): dataloader [valid, test] 126 | graph_creator (GraphCreator_FS_2D): helper object to handle graph data 127 | criterion (torch.nn.modules.loss): criterion for training 128 | device (torch.cuda.device): device (cpu/gpu) 129 | Returns: 130 | torch.Tensor: unrolled forward loss 131 | """ 132 | model.eval() 133 | if model_b != None: 134 | model_b.eval() 135 | if itp_model != None: 136 | itp_model.eval() 137 | 138 | # first we check the losses for different timesteps (one forward prediction array!) 139 | steps = [t for t in range(graph_creator.tw, graph_creator.t_res-graph_creator.tw + 1)] 140 | timestep_loss = test_timestep_losses(model=model, 141 | model_b=model_b, 142 | itp_model=itp_model, 143 | mesh_model=mesh_model, 144 | steps=steps, 145 | batch_size=args.batch_size, 146 | loader=loader, 147 | graph_creator=graph_creator, 148 | criterion=criterion, 149 | device=device) 150 | 151 | return timestep_loss 152 | 153 | 154 | def main(args: argparse): 155 | torch.manual_seed(args.seed) 156 | np.random.seed(args.seed) 157 | 158 | device = args.device 159 | check_directory() 160 | 161 | base_resolution = args.base_resolution 162 | if args.experiment == 'cy': 163 | data = torch.load('mesh/data/cylinder_rot_tri') 164 | data[:, :, :, :2] *= 2 165 | pde = cy(ori_grid=data[0, 0, :, :2], device=device) 166 | u = data[:, 10:, :, 2] 167 | u_train = u[:80] 168 | u_test = u[80:] 169 | elif args.experiment == 'burgers': 170 | pde = burgers(device=device) 171 | u = torch.tensor(np.load('mesh/data/burgers_192.npy'), dtype=torch.float)[:, :, ::int(192/base_resolution[1]), ::int(192/base_resolution[2])] 172 | u_train = u[:80] 173 | u_test = u[80:] 174 | else: 175 | raise Exception("Wrong experiment") 176 | 177 | 178 | 
# Equation specific parameters 179 | pde.grid_size = base_resolution 180 | pde.movingmesh_grid_size = base_resolution 181 | pde.ori_grid_size = base_resolution 182 | 183 | if args.model == 'BaseCNN': 184 | args.moving_mesh = False 185 | if args.moving_mesh == False: 186 | itp_model = None 187 | mesh_model = None 188 | else: 189 | if args.experiment == 'cy': 190 | itp_model = ItpNet(pde.ori_grid_size[1], None, args.itpnet_node1, args.itpnet_node2, args.res_cut_node).to(device) 191 | checkpoint = torch.load('cy_checkpoint', \ 192 | map_location=lambda storage, loc: storage) 193 | mesh_model = DMM(mode='graph', grid = data[0, 0, :, :2].to(device), branch_layer = checkpoint['args'].branch_layers, trunk_layer = [2] + checkpoint['args'].trunk_layers,\ 194 | out_layer = checkpoint['args'].out_layers).to(device) 195 | elif args.experiment == 'burgers': 196 | itp_model = ItpNet(pde.ori_grid_size[-2], pde.ori_grid_size[-1], args.itpnet_node1, args.itpnet_node2, args.res_cut_node).to(device) 197 | if base_resolution[1] == 48: 198 | checkpoint = torch.load('burgers_checkpoint', map_location=lambda storage, loc: storage) 199 | mesh_model = DMM(s=pde.movingmesh_grid_size[-1], mode='array', branch_layer = checkpoint['args'].branch_layers, trunk_layer = [2] + checkpoint['args'].trunk_layers, out_layer = checkpoint['args'].out_layers).to(device) 200 | mesh_model.load_state_dict(checkpoint['model_state_dict']) 201 | mesh_model.eval() 202 | 203 | try: 204 | train_dataset = TensorDataset(u_train, u_train) 205 | train_loader = DataLoader(train_dataset, 206 | batch_size=args.batch_size, 207 | shuffle=True, 208 | num_workers=4) 209 | test_dataset = TensorDataset(u_test, u_test) 210 | test_loader = DataLoader(test_dataset, 211 | batch_size=args.batch_size, 212 | shuffle=False, 213 | num_workers=4) 214 | except: 215 | raise Exception("Datasets could not be loaded properly") 216 | 217 | dateTimeObj = datetime.now() 218 | timestring = 
f'{dateTimeObj.date().month}-{dateTimeObj.date().day}-{dateTimeObj.time().hour}-{dateTimeObj.time().minute}-{dateTimeObj.time().second}' 219 | 220 | save_path = f'{args.experiment}_{args.model}_{args.batch_size}_mesh{args.moving_mesh}_xresolution{args.base_resolution[0]}-{args.base_resolution[1]}_lr{args.lr}_n{args.neighbors}_{args.connect_edge}_tw{args.time_window}_unrolling{args.unrolling}_time{datetime.now()}' 221 | log_dir = os.path.join("logs", save_path) 222 | writer = SummaryWriter(log_dir) 223 | 224 | save_path = f'models/{args.model}_{pde}_{args.experiment}_mesh{args.moving_mesh}_xresolution{args.base_resolution[0]}-{args.base_resolution[1]}_n{args.neighbors}_{args.connect_edge}_tw{args.time_window}_unrolling{args.unrolling}_time{timestring}.pt' 225 | print(f'Training on dataset of {args.experiment}') 226 | print(device) 227 | print(save_path) 228 | 229 | # Equation specific input variables 230 | eq_variables = {} 231 | graph_creator = GraphCreator_FS_2D(pde=pde, 232 | neighbors=args.neighbors, 233 | connect_edge=args.connect_edge, 234 | time_window=args.time_window, 235 | t_resolution=args.base_resolution[0], 236 | ).to(device) 237 | 238 | if args.model == 'GNN': 239 | model = MP_PDE_Solver_2D(pde=pde, 240 | time_window=graph_creator.tw, 241 | eq_variables=eq_variables, 242 | ).to(device) 243 | if args.moving_mesh == True: 244 | model_b = MP_PDE_Solver_2D(pde=pde, 245 | time_window=graph_creator.tw, 246 | eq_variables=eq_variables).to(device) 247 | else: 248 | model_b = None 249 | elif args.model == 'BaseCNN': 250 | model = BaseCNN(pde=pde, 251 | hidden_channels=args.hidden_channels, 252 | time_window=args.time_window).to(device) 253 | model_b = None 254 | else: 255 | raise Exception("Wrong model specified") 256 | 257 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 258 | params = sum([np.prod(p.size()) for p in model_parameters]) 259 | if mesh_model != None: 260 | model_parameters = filter(lambda p: p.requires_grad, 
model_b.parameters()) 261 | params2 = sum([np.prod(p.size()) for p in model_parameters]) 262 | model_parameters = filter(lambda p: p.requires_grad, itp_model.parameters()) 263 | params3 = sum([np.prod(p.size()) for p in model_parameters]) 264 | params = params + params2 + params3 265 | print(f'Number of parameters: {params}') 266 | 267 | # Optimizer 268 | if mesh_model != None: 269 | optimizer = optim.AdamW([{'params': model.parameters()}, 270 | {'params': model_b.parameters()}, 271 | {'params': itp_model.parameters()}], lr=args.lr) 272 | else: 273 | optimizer = optim.AdamW(model.parameters(), lr=args.lr) 274 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.unrolling, 30, 50, 70], gamma=args.lr_decay) 275 | optimizer2 = None 276 | 277 | # Training loop 278 | test_loss = 10e30 279 | train_losses = [] 280 | itp_losses = [] 281 | test_timestep_losses = [] 282 | for epoch in range(args.num_epochs): 283 | print(f"Epoch {epoch}") 284 | train_loss, itp_loss = train(args, pde, epoch, model, model_b, itp_model, mesh_model, optimizer, optimizer2, train_loader, graph_creator, criterion, device=device) 285 | train_losses.append(train_loss) 286 | itp_losses.append(itp_loss) 287 | print("Testing:") 288 | timestep_loss = test(args, pde, model, model_b, itp_model, mesh_model, test_loader, graph_creator, criterion, device=device) 289 | test_timestep_losses.append(timestep_loss) 290 | 291 | # Save model 292 | if args.moving_mesh == True: 293 | torch.save({ 294 | 'model_state_dict': model.state_dict(), 295 | 'model_b_state_dict': model_b.state_dict(), 296 | 'mesh_model_state_dict': mesh_model.state_dict(), 297 | 'itp_model_state_dict': itp_model.state_dict(), 298 | 'args': args, 299 | 'train_losses': train_losses, 300 | 'itp_losses': itp_losses, 301 | 'test_timestep_losses': test_timestep_losses, 302 | }, save_path) 303 | else: 304 | torch.save({ 305 | 'model_state_dict': model.state_dict(), 306 | 'args': args, 307 | 'train_losses': train_losses, 308 | 
'itp_losses': itp_losses, 309 | 'test_timestep_losses': test_timestep_losses, 310 | }, save_path) 311 | print(f"Saved model at {save_path}\n") 312 | 313 | scheduler.step() 314 | 315 | for k, l in enumerate(train_loss): 316 | writer.add_scalar("train loss", l.item(), k+epoch*len(train_loss)) 317 | writer.add_scalar("test loss", timestep_loss.item(), epoch) 318 | 319 | print(f"Test loss: {test_loss}") 320 | 321 | 322 | if __name__ == "__main__": 323 | parser = argparse.ArgumentParser(description='Train a PDE solver') 324 | 325 | parser.add_argument('--seed', default=1, type=int, help='random seed') 326 | parser.add_argument('--device', type=str, default='cuda:0', 327 | help='Used device') 328 | # PDE 329 | parser.add_argument('--experiment', type=str, default='burgers', 330 | help='Experiment for PDE solver should be trained: [burgers, cy]') 331 | 332 | # Model 333 | parser.add_argument('--model', type=str, default='GNN', 334 | help='Model used as PDE solver: [GNN, BaseCNN]') 335 | parser.add_argument('--moving_mesh', type=eval, default=True, 336 | help='Use moving mesh method') 337 | 338 | # Model parameters 339 | parser.add_argument('--itpnet_node1', type=lambda s: [int(item) for item in s.split(',')], 340 | default=[128, 64], help="nodes of ItpNet1") 341 | parser.add_argument('--itpnet_node2', type=lambda s: [int(item) for item in s.split(',')], 342 | default=[128, 64], help="nodes of ItpNet2") 343 | parser.add_argument('--res_cut_node', type=lambda s: [int(item) for item in s.split(',')], 344 | default=[1, 4, 16, 4, 1], help="nodes of residual cut network") 345 | parser.add_argument('--hidden_channels', type=int, 346 | default=40, help="number of hidden channels of CNN") 347 | parser.add_argument('--batch_size', type=int, default=6, 348 | help='Number of samples in each minibatch') 349 | parser.add_argument('--num_epochs', type=int, default=80, 350 | help='Number of training epochs') 351 | parser.add_argument('--lr', type=float, default=2e-3, 352 | help='Learning 
rate') 353 | parser.add_argument('--lr_decay', type=float, 354 | default=0.4, help='multistep lr decay') 355 | 356 | # Base resolution and super resolution 357 | parser.add_argument('--base_resolution', type=lambda s: [int(item) for item in s.split(',')], 358 | default=[31, 48, 48], help="PDE base resolution on which network is applied") 359 | parser.add_argument('--neighbors', type=int, 360 | default=35, help="Neighbors to be considered in GNN solver") 361 | parser.add_argument('--connect_edge', type=str, default='knn', 362 | help='The way to connect edge: [knn, radius]') 363 | parser.add_argument('--time_window', type=int, 364 | default=1, help="Time steps to be considered in GNN solver") 365 | parser.add_argument('--unrolling', type=int, 366 | default=0, help="Unrolling which proceeds with each epoch") 367 | 368 | # Misc 369 | parser.add_argument('--print_interval', type=int, default=2, 370 | help='Interval between print statements') 371 | parser.add_argument('--log', type=eval, default=True, 372 | help='pip the output to log file') 373 | 374 | args = parser.parse_args() 375 | print(args) 376 | main(args) 377 | -------------------------------------------------------------------------------- /mesh/dmm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | import scipy 5 | from scipy.integrate import dblquad 6 | from scipy.spatial import Delaunay 7 | import functools 8 | import sympy as sp 9 | from sklearn.neighbors import NearestNeighbors 10 | from torch.utils.data import DataLoader,TensorDataset 11 | import os 12 | import torch.nn.functional as F 13 | import matplotlib 14 | import matplotlib.pyplot as plt 15 | import matplotlib.tri as tri 16 | import matplotlib.cm as cm 17 | from datetime import datetime 18 | from functools import reduce 19 | import operator 20 | from torchmin import minimize # pip install pytorch-minimize 21 | 22 | def count_params(model): 
23 | c = 0 24 | for p in list(model.parameters()): 25 | c += reduce(operator.mul, list(p.size())) 26 | return c 27 | 28 | 29 | def sample_train_data(u, nx, nu, device): 30 | grid = torch.tensor(np.random.uniform(0, 1, (nu, 40 * nx, 2)), dtype=torch.float).to(device) 31 | u_idx = np.random.choice(a=u.shape[0], size=nu, replace=True) 32 | u = u[u_idx].to(device) 33 | ux = diff_x(u) * (u.shape[-1] - 1) 34 | uy = diff_y(u) * (u.shape[-1] - 1) 35 | alpha = torch.sum((torch.abs(ux)**2 + torch.abs(uy)**2)**(1/2), dim=(-2, -1)) / (u.shape[-1]-1)**2 36 | m = monitor(alpha.unsqueeze(-1).unsqueeze(-1).repeat(1, ux.shape[-1], ux.shape[-1]), ux, uy) 37 | RHS = torch.sum(m, dim=(-2, -1)) / (u.shape[-1]-1)**2 38 | m_normalized = m 39 | 40 | all_p = [] 41 | sub_nu = 4 42 | N = int(nu / sub_nu) 43 | for i in range(sub_nu): 44 | all_p.append(interpolate(m_normalized[i * N : (i+1) * N].unsqueeze(1).repeat(1, grid.shape[1], 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]),\ 45 | grid[i * N : (i+1) * N, :, [0]].reshape(-1, 1), grid[i * N : (i+1) * N, :, [1]].reshape(-1, 1))[:, 0].cpu().numpy()) 46 | all_p = np.array(all_p).flatten().reshape(nu, grid.shape[1]) 47 | 48 | grid_choosed = torch.zeros(nu, nx, 2).to(device) 49 | for i in range(nu): 50 | p = all_p[i] / np.sum(all_p[i]) 51 | idx = np.random.choice(a=grid.shape[1], size=nx, replace=False, p=p) 52 | grid_choosed[i] = grid[i, idx] 53 | 54 | return u, ux, uy, alpha, m, RHS, grid_choosed.reshape(-1, 2) 55 | 56 | def sample_train_data_bound(u, nx, nu, device): 57 | u_idx = np.random.choice(a=u.shape[0], size=4*nu, replace=True) 58 | u = u[u_idx].to(device) 59 | ux = diff_x(u) * (u.shape[-1] - 1) 60 | uy = diff_y(u) * (u.shape[-1] - 1) 61 | alpha = torch.sum((torch.abs(ux)**2 + torch.abs(uy)**2)**(1/2), dim=(-2, -1)) / (u.shape[-1]-1)**2 62 | m = monitor(alpha.unsqueeze(-1).unsqueeze(-1).repeat(1, ux.shape[-1], ux.shape[-1]), ux, uy) 63 | RHS = torch.sum(m, dim=(-2, -1)) / (u.shape[-1]-1)**2 64 | 65 | n = int(nx/4) 66 | # n = nx 67 | 
bound1 = [] 68 | bound2 = [] 69 | bound3 = [] 70 | bound4 = [] 71 | X1 = np.linspace(0, 1, n) 72 | for i in X1: 73 | data1 = [0, i] 74 | bound1.append(data1) 75 | X2 = np.linspace(0, 1, n) 76 | for i in X2: 77 | data2 = [1, i] 78 | bound2.append(data2) 79 | X3 = np.linspace(0, 1, n) 80 | for i in X3: 81 | data3 = [i, 0] 82 | bound3.append(data3) 83 | X4 = np.linspace(0, 1, n) 84 | for i in X4: 85 | data4 = [i, 1] 86 | bound4.append(data4) 87 | bound1 = torch.tensor(bound1, dtype=torch.float).to(device) 88 | bound2 = torch.tensor(bound2, dtype=torch.float).to(device) 89 | bound3 = torch.tensor(bound3, dtype=torch.float).to(device) 90 | bound4 = torch.tensor(bound4, dtype=torch.float).to(device) 91 | 92 | bound1_u = u[:nu] 93 | bound2_u = u[nu:2*nu] 94 | bound3_u = u[2*nu:3*nu] 95 | bound4_u = u[3*nu:4*nu] 96 | 97 | bound1_m = m[:nu] 98 | bound2_m = m[nu:2*nu] 99 | bound3_m = m[2*nu:3*nu] 100 | bound4_m = m[3*nu:4*nu] 101 | 102 | return bound1.repeat(nu, 1, 1).reshape(-1, 2), bound2.repeat(nu, 1, 1).reshape(-1, 2), bound3.repeat(nu, 1, 1).reshape(-1, 2),\ 103 | bound4.repeat(nu, 1, 1).reshape(-1, 2), bound1_u, bound2_u, bound3_u, bound4_u, bound1_m, bound2_m, bound3_m, bound4_m 104 | 105 | 106 | def sample_train_data_tri(all_u, nx, nu, device): 107 | u = all_u[:, :, 2].to(device) 108 | grid = torch.tensor(np.random.uniform(0, 1, (nu, 40 * nx, 2)), dtype=torch.float).to(device) 109 | u_idx = np.random.choice(a=u.shape[0], size=nu, replace=True) 110 | u = u[u_idx].to(device) 111 | ori_mesh_x = all_u[u_idx, :, 0].unsqueeze(-1).to(device) 112 | ori_mesh_y = all_u[u_idx, :, 1].unsqueeze(-1).to(device) 113 | 114 | n = int(np.sqrt(u.shape[-1])) 115 | uni_grid = torch.tensor(np.array(np.meshgrid(np.linspace(0, 1, n), np.linspace(0, 1, n))), dtype=torch.float)\ 116 | .reshape(1, 2, -1).repeat(nu, 1, 1).permute(0, 2, 1).to(device) 117 | 118 | all_p = [] 119 | uni_ux = torch.Tensor().to(device) 120 | uni_uy = torch.Tensor().to(device) 121 | alpha = torch.Tensor().to(device) 122 
| uni_m = torch.Tensor().to(device) 123 | RHS = torch.Tensor().to(device) 124 | sub_nu = 10 125 | N = int(nu / sub_nu) 126 | for i in range(sub_nu): 127 | 128 | # alpha 129 | x1_ = uni_grid[i * N : (i+1) * N, :, [0]].reshape(-1, 1) 130 | x2_ = uni_grid[i * N : (i+1) * N, :, [1]].reshape(-1, 1) 131 | x1_.requires_grad = True 132 | x2_.requires_grad = True 133 | u_ = interpolate_tri(u[i*N : (i+1)*N].unsqueeze(1).repeat(1, n**2, 1).reshape(-1, u.shape[-1]), \ 134 | ori_mesh_x[i*N : (i+1)*N].unsqueeze(1).repeat(1, n**2, 1, 1).reshape(-1, u.shape[-1], 1), \ 135 | ori_mesh_y[i*N : (i+1)*N].unsqueeze(1).repeat(1, n**2, 1, 1).reshape(-1, u.shape[-1], 1), \ 136 | x1_.unsqueeze(1).repeat(1, u.shape[-1], 1), x2_.unsqueeze(1).repeat(1, u.shape[-1], 1)) 137 | w = torch.ones(u_.shape).to(device) 138 | uni_ux_ = torch.autograd.grad(u_, x1_, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0].reshape(N, n, n) 139 | uni_uy_ = torch.autograd.grad(u_, x2_, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0].reshape(N, n, n) 140 | alpha_ = torch.sum((torch.abs(uni_ux_)**2 + torch.abs(uni_uy_)**2)**(1/2), dim=(-2, -1)) / (n-1)**2 141 | uni_m_ = monitor(alpha_.unsqueeze(-1).unsqueeze(-1).repeat(1, n, n), uni_ux_, uni_uy_) 142 | RHS_ = torch.sum(uni_m_, dim=(-2, -1)) / (n-1)**2 143 | 144 | uni_ux = torch.cat((uni_ux, uni_ux_), ) 145 | uni_uy = torch.cat((uni_uy, uni_uy_), ) 146 | alpha = torch.cat((alpha, alpha_), ) 147 | uni_m = torch.cat((uni_m, uni_m_), ) 148 | RHS = torch.cat((RHS, RHS_), ) 149 | 150 | # ux, uy 151 | x1_ = grid[i * N : (i+1) * N, :, [0]].reshape(-1, 1) 152 | x2_ = grid[i * N : (i+1) * N, :, [1]].reshape(-1, 1) 153 | ux_ = interpolate(uni_ux_.unsqueeze(1).repeat(1, grid.shape[1], 1, 1).reshape(-1, n, n), x1_, x2_).reshape(N, grid.shape[1]) 154 | uy_ = interpolate(uni_uy_.unsqueeze(1).repeat(1, grid.shape[1], 1, 1).reshape(-1, n, n), x1_, x2_).reshape(N, grid.shape[1]) 155 | m = monitor(alpha_.unsqueeze(-1).repeat(1, 
grid.shape[1]), ux_, uy_) 156 | 157 | all_p.append(m.detach().cpu().numpy()) 158 | all_p = np.array(all_p).flatten().reshape(nu, grid.shape[1]) 159 | 160 | grid_choosed = torch.zeros(nu, nx, 2).to(device) 161 | for i in range(nu): 162 | p = all_p[i] / np.sum(all_p[i]) 163 | idx = np.random.choice(a=grid.shape[1], size=nx, replace=False, p=p) 164 | grid_choosed[i] = grid[i, idx] 165 | # plot_grid(grid_choosed[0], uni_m[0]) 166 | 167 | return u, uni_ux, uni_uy, alpha, uni_m, RHS, grid_choosed.reshape(-1, 2) 168 | 169 | def sample_train_data_bound_tri(u, nx, nu, device): 170 | u_idx = np.random.choice(a=u.shape[0], size=4*nu, replace=True) 171 | u = u[u_idx, :, 2].to(device) 172 | 173 | n = int(nx/4) 174 | # n = nx 175 | bound1 = [] 176 | bound2 = [] 177 | bound3 = [] 178 | bound4 = [] 179 | X1 = np.linspace(0, 1, n) 180 | for i in X1: 181 | data1 = [0, i] 182 | bound1.append(data1) 183 | X2 = np.linspace(0, 1, n) 184 | for i in X2: 185 | data2 = [1, i] 186 | bound2.append(data2) 187 | X3 = np.linspace(0, 1, n) 188 | for i in X3: 189 | data3 = [i, 0] 190 | bound3.append(data3) 191 | X4 = np.linspace(0, 1, n) 192 | for i in X4: 193 | data4 = [i, 1] 194 | bound4.append(data4) 195 | bound1 = torch.tensor(bound1, dtype=torch.float).to(device) 196 | bound2 = torch.tensor(bound2, dtype=torch.float).to(device) 197 | bound3 = torch.tensor(bound3, dtype=torch.float).to(device) 198 | bound4 = torch.tensor(bound4, dtype=torch.float).to(device) 199 | 200 | bound1_u = u[:nu] 201 | bound2_u = u[nu:2*nu] 202 | bound3_u = u[2*nu:3*nu] 203 | bound4_u = u[3*nu:4*nu] 204 | 205 | return bound1.repeat(nu, 1, 1).reshape(-1, 2), bound2.repeat(nu, 1, 1).reshape(-1, 2), bound3.repeat(nu, 1, 1).reshape(-1, 2),\ 206 | bound4.repeat(nu, 1, 1).reshape(-1, 2), bound1_u, bound2_u, bound3_u, bound4_u 207 | 208 | 209 | def monitor(alpha, ux, uy): 210 | return (1 + (torch.abs(ux)**2 + torch.abs(uy)**2) ** (1/2) / (0.01*alpha)) 211 | 212 | def monitor_np(alpha, ux, uy): 213 | return (1 + (np.abs(ux)**2 
+ np.abs(uy)**2) ** (1/2) / (0.01*alpha)) 214 | 215 | def diff_x(u): 216 | ux = torch.zeros_like(u) 217 | ux[:,:-1,:] = torch.diff(u, dim=-2) 218 | ux[:,-1,:] = ux[:,-2,:] 219 | return ux 220 | 221 | def diff_y(u): 222 | uy = torch.zeros_like(u) 223 | uy[:,:,:-1] = torch.diff(u, dim=-1) 224 | uy[:,:,-1] = uy[:,:,-2] 225 | return uy 226 | 227 | def init_weights(t): 228 | with torch.no_grad(): 229 | if type(t) == torch.nn.Linear: 230 | t.weight.normal_(0, 0.02) 231 | t.bias.normal_(0, 0.02) 232 | 233 | def interpolate(u, x, y, n_neighbors = 50): 234 | """ 235 | u: b*n*n 236 | x: b*1 237 | y: b*1 238 | """ 239 | 240 | n = u.shape[-1] 241 | grid_x = np.linspace(0, 1, n) 242 | grid_y = np.linspace(0, 1, n) 243 | grid = torch.tensor(np.array(np.meshgrid(grid_x, grid_y)), dtype=torch.float).reshape(1, 2, -1).permute(0, 2, 1).to(u.device) 244 | d = -torch.norm(grid.repeat(x.shape[0], 1, 1) - torch.cat((x, y), dim=-1).unsqueeze(1).repeat(1, n*n, 1), dim=-1) * n 245 | normalize = nn.Softmax(dim=-1) 246 | weight = normalize(d) 247 | interpolated = torch.sum(u.reshape(-1, n**2) * weight, dim=-1).unsqueeze(-1) 248 | 249 | return interpolated # b*1 250 | 251 | def interpolate_tri(u, ori_x, ori_y, x, y): 252 | """ 253 | u: b*n 254 | ori_x: b*n*1 255 | ori_y: b*n*1 256 | x: b*n*1 257 | y: b*n*1 258 | """ 259 | n = u.shape[-1] 260 | grid = torch.cat((ori_x, ori_y), dim=-1) 261 | d = -torch.norm(grid - torch.cat((x, y), dim=-1), dim=-1) * np.sqrt(n) 262 | normalize = nn.Softmax(dim=-1) 263 | weight = normalize(d) 264 | # weight = softmax(d, dim=-1) 265 | interpolated = torch.sum(u * weight, dim=-1).unsqueeze(-1) 266 | 267 | return interpolated # b*1 268 | 269 | def softmax(x, dim): 270 | exp_x = torch.exp(x) 271 | sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True) 272 | softmax_output = exp_x / sum_exp_x 273 | 274 | return softmax_output 275 | 276 | 277 | def train_data_loader_p(n, m): 278 | grid = torch.tensor(np.random.uniform(0, 1, (50 * n, 2)), dtype=torch.float).cuda() 279 | p 
= [] 280 | N = int(grid.shape[0] / 5) 281 | for i in range(5): 282 | p.append(interpolate(m.reshape(-1, 1).repeat(N, 1).reshape(N, 50, 50), grid[i * N : (i+1) * N, [0]], grid[i * N : (i+1) * N, [1]])[:, 0].cpu().numpy()) 283 | p = np.array(p).flatten() 284 | p = p / np.sum(p) 285 | idx = np.random.choice(a=grid.shape[0], size=n, replace=False, p=p) 286 | # dataset = TensorDataset(grid[idx].cpu()) 287 | return grid[idx].cpu() 288 | 289 | 290 | def random_feature_torch(weight, convex_rel, second_out, second_out_bound1, second_out_bound2, second_out_bound3, second_out_bound4, branch, branch_bound1, branch_bound2, branch_bound3, branch_bound4, args, m_xi,\ 291 | alpha, x1, x2, ux, uy, so_x_bound1, so_x_bound2, so_y_bound3, so_y_bound4, so_x, so_y, so_xx, so_yy, so_xy, so_yx, RHS): 292 | criterion = nn.MSELoss() 293 | device = second_out.device 294 | weight = weight.reshape(branch.shape[1], second_out.shape[1]) 295 | # weight_bias = params['weight_bias'].value 296 | 297 | # loss of boundary condition 298 | branch_bound1 = branch_bound1.reshape(-1, branch.shape[-1]) 299 | # so_x_one_bound1 = torch.cat((so_x_bound1, torch.ones((second_out_bound1.shape[0], 1)).to(device)), dim=1) 300 | trunkx_bound1 = torch.matmul(so_x_bound1, weight.T) 301 | phix_bound1 = torch.sum(branch_bound1 * trunkx_bound1, dim=1).reshape(-1, 1) 302 | loss_bound1 = criterion(phix_bound1, torch.zeros_like(phix_bound1).to(device)) 303 | 304 | branch_bound2 = branch_bound2.reshape(-1, branch.shape[-1]) 305 | # so_x_one_bound2 = torch.cat((so_x_bound2, torch.ones((second_out_bound2.shape[0], 1)).to(device)), dim=1) 306 | trunkx_bound2 = torch.matmul(so_x_bound2, weight.T) 307 | phix_bound2 = torch.sum(branch_bound2 * trunkx_bound2, dim=1).reshape(-1, 1) 308 | loss_bound2 = criterion(phix_bound2, torch.zeros_like(phix_bound2).to(device)) 309 | 310 | branch_bound3 = branch_bound3.reshape(-1, branch.shape[-1]) 311 | # so_y_one_bound3 = torch.cat((so_y_bound3, torch.ones((second_out_bound3.shape[0], 
1)).to(device)), dim=1) 312 | trunky_bound3 = torch.matmul(so_y_bound3, weight.T) 313 | phiy_bound3 = torch.sum(branch_bound3 * trunky_bound3, dim=1).reshape(-1, 1) 314 | loss_bound3 = criterion(phiy_bound3, torch.zeros_like(phiy_bound3).to(device)) 315 | 316 | branch_bound4 = branch_bound4.reshape(-1, branch.shape[-1]) 317 | # so_y_one_bound4 = torch.cat((so_y_bound4, torch.ones((second_out_bound4.shape[0], 1)).to(device)), dim=1) 318 | trunky_bound4 = torch.matmul(so_y_bound4, weight.T) 319 | phiy_bound4 = torch.sum(branch_bound4 * trunky_bound4, dim=1).reshape(-1, 1) 320 | loss_bound4 = criterion(phiy_bound4, torch.zeros_like(phiy_bound4).to(device)) 321 | 322 | loss_bound = (loss_bound1 + loss_bound2 + loss_bound3 + loss_bound4) / 4 323 | 324 | # loss in 325 | branch = branch.reshape(-1, branch.shape[-1]) 326 | trunkx = torch.matmul(so_x, weight.T) 327 | trunky = torch.matmul(so_y, weight.T) 328 | trunkxx = torch.matmul(so_xx, weight.T) 329 | trunkxy = torch.matmul(so_xy, weight.T) 330 | trunkyx = torch.matmul(so_yx, weight.T) 331 | trunkyy = torch.matmul(so_yy, weight.T) 332 | phix = torch.sum(branch * trunkx, dim=1).reshape(-1, 1) 333 | phiy = torch.sum(branch * trunky, dim=1).reshape(-1, 1) 334 | phixx = torch.sum(branch * trunkxx, dim=1).reshape(-1, 1) 335 | phixy = torch.sum(branch * trunkxy, dim=1).reshape(-1, 1) 336 | phiyx = torch.sum(branch * trunkyx, dim=1).reshape(-1, 1) 337 | phiyy = torch.sum(branch * trunkyy, dim=1).reshape(-1, 1) 338 | ux_ = interpolate(ux.unsqueeze(1).repeat(1, args.batch_size_x_rf, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 339 | uy_ = interpolate(uy.unsqueeze(1).repeat(1, args.batch_size_x_rf, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 340 | u_xi_x = ux_ * (1 + phixx) + uy_ * phiyx 341 | u_xi_y = ux_ * phixy + uy_ * (1 + phiyy) 342 | m_xi = monitor(alpha.unsqueeze(1).repeat(1, args.batch_size_x_rf).reshape(-1, 1), u_xi_x, u_xi_y) 343 | LHS = m_xi * ((1 + phixx) * (1 + phiyy) 
- phixy * phiyx) 344 | 345 | loss_in = criterion(LHS / RHS.unsqueeze(1).repeat(1, args.batch_size_x_rf).reshape(-1, 1), torch.ones_like(LHS)) 346 | loss_convex = torch.mean(torch.min(torch.tensor(0).type_as(phixx).to(device), 1 + phixx)**2 + torch.min(torch.tensor(0).type_as(phiyy).to(device), 1 + phiyy)**2) 347 | # print(loss_in.item(), loss_bound.item()) 348 | return convex_rel * (torch.sum(weight ** 2)) ** 2 + args.loss_weight1 * loss_bound + args.loss_weight0 * loss_in + args.loss_weight2 * loss_convex 349 | 350 | 351 | def random_feature_torch2(weight, convex_rel, second_out, second_out_bound1, second_out_bound2, second_out_bound3, second_out_bound4, args,\ 352 | alpha, x1, x2, ux, uy, so_x_bound1, so_x_bound2, so_y_bound3, so_y_bound4, so_x, so_y, so_xx, so_yy, so_xy, so_yx, RHS): 353 | criterion = nn.MSELoss() 354 | device = second_out.device 355 | weight = weight.reshape(1, second_out.shape[1]) 356 | 357 | # loss of boundary condition 358 | phix_bound1 = torch.matmul(so_x_bound1, weight.T).reshape(-1, 1) 359 | loss_bound1 = criterion(phix_bound1, torch.zeros_like(phix_bound1).to(device)) 360 | 361 | phix_bound2 = torch.matmul(so_x_bound2, weight.T).reshape(-1, 1) 362 | loss_bound2 = criterion(phix_bound2, torch.zeros_like(phix_bound2).to(device)) 363 | 364 | phiy_bound3 = torch.matmul(so_y_bound3, weight.T).reshape(-1, 1) 365 | loss_bound3 = criterion(phiy_bound3, torch.zeros_like(phiy_bound3).to(device)) 366 | 367 | phiy_bound4 = torch.matmul(so_y_bound4, weight.T).reshape(-1, 1) 368 | loss_bound4 = criterion(phiy_bound4, torch.zeros_like(phiy_bound4).to(device)) 369 | 370 | loss_bound = (loss_bound1 + loss_bound2 + loss_bound3 + loss_bound4) / 4 371 | 372 | # loss in 373 | phix = torch.matmul(so_x, weight.T).reshape(-1, 1) 374 | phiy = torch.matmul(so_y, weight.T).reshape(-1, 1) 375 | phixx = torch.matmul(so_xx, weight.T).reshape(-1, 1) 376 | phixy = torch.matmul(so_xy, weight.T).reshape(-1, 1) 377 | phiyx = torch.matmul(so_yx, weight.T).reshape(-1, 1) 
378 | phiyy = torch.matmul(so_yy, weight.T).reshape(-1, 1) 379 | ux_ = interpolate(ux.unsqueeze(1).repeat(1, args.batch_size_x_rf, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 380 | uy_ = interpolate(uy.unsqueeze(1).repeat(1, args.batch_size_x_rf, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 381 | u_xi_x = ux_ * (1 + phixx) + uy_ * phiyx 382 | u_xi_y = ux_ * phixy + uy_ * (1 + phiyy) 383 | m_xi = monitor(alpha.unsqueeze(1).repeat(1, args.batch_size_x_rf).reshape(-1, 1), u_xi_x, u_xi_y) 384 | LHS = m_xi * ((1 + phixx) * (1 + phiyy) - phixy * phiyx) 385 | 386 | loss_in = criterion(LHS / RHS.unsqueeze(1).repeat(1, args.batch_size_x_rf).reshape(-1, 1), torch.ones_like(LHS)) 387 | loss_convex = torch.mean(torch.min(torch.tensor(0).type_as(phixx).to(device), 1 + phixx)**2 + torch.min(torch.tensor(0).type_as(phiyy).to(device), 1 + phiyy)**2) 388 | return convex_rel * (torch.sum(weight ** 2)) ** 2 + args.loss_weight1 * loss_bound + args.loss_weight0 * loss_in + args.loss_weight2 * loss_convex 389 | 390 | 391 | def train_MA_res(ori_u, all_u, test_u, args, model, init_mesh, n_epoch_adam, n_epoch_lbfgs, device): 392 | # writer = SummaryWriter(logdir='runs') 393 | logs_txt = [] 394 | logs_txt.append(str(args)) 395 | 396 | optimizer_adam = torch.optim.Adam(model.parameters(), lr=args.lr_adam, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.weight_decay) 397 | scheduler_adam = torch.optim.lr_scheduler.MultiStepLR(optimizer_adam, milestones=[100, 150], gamma=args.gamma_adam) 398 | optimizer_lbfgs = torch.optim.LBFGS(model.parameters(), lr=args.lr_lbfgs, tolerance_grad=-1, tolerance_change=-1) 399 | scheduler_lbfgs = torch.optim.lr_scheduler.MultiStepLR(optimizer_lbfgs, milestones=[75, 125], gamma=args.gamma_lbfgs) 400 | criterion = nn.MSELoss() 401 | 402 | loss_in_list = [] 403 | test_loss_in_list = [] 404 | loss_bound_list = [] 405 | loss_convex_list = [] 406 | test_equ_loss_list = [] 407 | test_equ_max_list = [] 408 | 
test_equ_min_list = [] 409 | test_equ_mid_list = [] 410 | LHS_list = [] 411 | RHS_list = [] 412 | itp_list1 = [] 413 | itp_list2 = [] 414 | train_std_list = [] 415 | train_minmax_list = [] 416 | test_std_list = [] 417 | test_minmax_list = [] 418 | 419 | log_count1 = [] 420 | log_count2 = [] 421 | log_count1.append(0) 422 | log_count2.append(0) 423 | 424 | epoch = 0 425 | for epoch in range(1, n_epoch_adam + n_epoch_lbfgs + 1): 426 | start = datetime.now() 427 | # Adam 428 | if epoch < n_epoch_adam + 1: 429 | 430 | for i in range(np.max((1, int(args.train_sample_grid * all_u.shape[0] / (args.batch_size_x_adam * args.batch_size_u_adam))))): 431 | # sample points 432 | if args.experiment == 'burgers': 433 | u, ux, uy, alpha, m, RHS, x = sample_train_data(all_u, args.batch_size_x_adam, args.batch_size_u_adam, device) # b 434 | bound1, bound2, bound3, bound4, bound1_u, bound2_u, bound3_u, bound4_u, bound1_m, bound2_m, bound3_m, bound4_m = sample_train_data_bound(all_u, args.batch_size_x_adam, args.batch_size_u_adam, device) # 4 * (b//4) 435 | elif args.experiment == 'cy': 436 | u, ux, uy, alpha, m, RHS, x = sample_train_data_tri(all_u, args.batch_size_x_adam, args.batch_size_u_adam, device) # b 437 | bound1, bound2, bound3, bound4, bound1_u, bound2_u, bound3_u, bound4_u = sample_train_data_bound_tri(all_u, args.batch_size_x_adam, args.batch_size_u_adam, device) # 4 * (b//4) 438 | 439 | optimizer_adam.zero_grad() 440 | 441 | if args.bound_constraint == 'soft': 442 | # loss of boundary condition 443 | if len(bound1) == 0: 444 | loss_bound1 = torch.zeros(1).to(device) 445 | else: 446 | bound11 = bound1[:, 0].view(-1, 1) 447 | bound12 = bound1[:, 1].view(-1, 1) 448 | # bound1t = bound1[:, 2].view(-1, 1) 449 | bound11.requires_grad = True 450 | bound12.requires_grad = True 451 | # X1 = torch.cat((bound11, bound12, bound1t), dim=1) 452 | X1 = torch.cat((bound11, bound12), dim=1) 453 | output_bound1 = model(bound1_u, X1) 454 | v1 = torch.ones(output_bound1.shape).to(device) 
455 | bound1_x = torch.autograd.grad(output_bound1, bound11, grad_outputs=v1, retain_graph = True, create_graph=True, allow_unused=True)[0] 456 | loss_bound1 = criterion(bound1_x, torch.zeros_like(bound1_x)) 457 | 458 | if len(bound2) == 0: 459 | loss_bound2 = torch.zeros(1).to(device) 460 | else: 461 | bound21 = bound2[:, 0].view(-1, 1) 462 | bound22 = bound2[:, 1].view(-1, 1) 463 | # bound2t = bound2[:, 2].view(-1, 1) 464 | bound21.requires_grad = True 465 | bound22.requires_grad = True 466 | # X2 = torch.cat((bound21, bound22, bound2t), dim=1) 467 | X2 = torch.cat((bound21, bound22), dim=1) 468 | output_bound2 = model(bound2_u, X2) 469 | v2 = torch.ones(output_bound2.shape).to(device) 470 | bound2_x = torch.autograd.grad(output_bound2, bound21, grad_outputs=v2, retain_graph = True, create_graph=True, allow_unused=True)[0] 471 | loss_bound2 = criterion(bound2_x, torch.zeros_like(bound2_x)) 472 | 473 | if len(bound3) == 0: 474 | loss_bound3 = torch.zeros(1).to(device) 475 | else: 476 | bound31 = bound3[:, 0].view(-1, 1) 477 | bound32 = bound3[:, 1].view(-1, 1) 478 | # bound3t = bound3[:, 2].view(-1, 1) 479 | bound31.requires_grad = True 480 | bound32.requires_grad = True 481 | # X3 = torch.cat((bound31, bound32, bound3t), dim=1) 482 | X3 = torch.cat((bound31, bound32), dim=1) 483 | output_bound3 = model(bound3_u, X3) 484 | v3 = torch.ones(output_bound3.shape).to(device) 485 | bound3_y = torch.autograd.grad(output_bound3, bound32, grad_outputs=v3, retain_graph = True, create_graph=True, allow_unused=True)[0] 486 | loss_bound3 = criterion(bound3_y, torch.zeros_like(bound3_y)) 487 | 488 | if len(bound4) == 0: 489 | loss_bound4 = torch.zeros(1).to(device) 490 | else: 491 | bound41 = bound4[:, 0].view(-1, 1) 492 | bound42 = bound4[:, 1].view(-1, 1) 493 | # bound4t = bound4[:, 2].view(-1, 1) 494 | bound41.requires_grad = True 495 | bound42.requires_grad = True 496 | # X4 = torch.cat((bound41, bound42, bound4t), dim=1) 497 | X4 = torch.cat((bound41, bound42), dim=1) 498 
| output_bound4 = model(bound4_u, X4) 499 | v4 = torch.ones(output_bound4.shape).to(device) 500 | bound4_y = torch.autograd.grad(output_bound4, bound42, grad_outputs=v4, retain_graph = True, create_graph=True, allow_unused=True)[0] 501 | loss_bound4 = criterion(bound4_y, torch.zeros_like(bound4_y)) 502 | 503 | loss_bound = (loss_bound1 + loss_bound2 + loss_bound3 + loss_bound4) / 4 504 | else: 505 | loss_bound = torch.tensor(0).to(device) 506 | 507 | # loss inside 508 | x1 = x[:, 0].view(x.shape[0], 1) 509 | x2 = x[:, 1].view(x.shape[0], 1) 510 | # xt = x[:, 2].view(x.shape[0], 1) 511 | x1.requires_grad = True 512 | x2.requires_grad = True 513 | # x_ = torch.cat((x1, x2, xt), dim=1) 514 | x_ = torch.cat((x1, x2), dim=1) 515 | if args.bound_constraint == 'soft': 516 | output = model(u, x_) # nu*nx 517 | else: 518 | output = ((x1**2) * (x2**2) * ((x1-1)**2) * ((x2-1)**2)) * model(u, x_) + (1/2) * (x1**2) + (1/2) * (x2**2) 519 | w = torch.ones(output.shape).to(device) 520 | phix = torch.autograd.grad(output, x1, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 521 | phiy = torch.autograd.grad(output, x2, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 522 | if init_mesh == True: 523 | loss_in = (criterion(x1 + phix, x1) + criterion(x2 + phiy, x2)) / 2 524 | loss = args.loss_weight1 * loss_bound + args.loss_weight0 * loss_in 525 | loss.backward() 526 | 527 | else: 528 | w2 = torch.ones(phix.shape).to(device) 529 | phixy = torch.autograd.grad(phix, x2, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 530 | phixx = torch.autograd.grad(phix, x1, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 531 | phiyx = torch.autograd.grad(phiy, x1, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 532 | phiyy = torch.autograd.grad(phiy, x2, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 533 | 534 | if 
args.experiment == 'burgers': 535 | ux_ = interpolate(ux.unsqueeze(1).repeat(1, args.batch_size_x_adam, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 536 | uy_ = interpolate(uy.unsqueeze(1).repeat(1, args.batch_size_x_adam, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 537 | elif args.experiment == 'cy': 538 | ux_ = interpolate(ux.unsqueeze(1).repeat(1, args.batch_size_x_adam, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 539 | uy_ = interpolate(uy.unsqueeze(1).repeat(1, args.batch_size_x_adam, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 540 | u_xi_x = ux_ * (1 + phixx) + uy_ * phiyx 541 | u_xi_y = ux_ * phixy + uy_ * (1 + phiyy) 542 | m_xi = monitor(alpha.unsqueeze(1).repeat(1, args.batch_size_x_adam).reshape(-1, 1), u_xi_x, u_xi_y) 543 | LHS = m_xi * ((1 + phixx) * (1 + phiyy) - phixy * phiyx) 544 | 545 | loss_in = criterion(LHS / RHS.unsqueeze(1).repeat(1, args.batch_size_x_adam).reshape(-1, 1), torch.ones_like(LHS)) 546 | loss_convex = torch.mean(torch.min(torch.tensor(0).type_as(phixx), 1 + phixx)**2 + torch.min(torch.tensor(0).type_as(phiyy), 1 + phiyy)**2) 547 | 548 | if args.loss_convex == True: 549 | loss = args.loss_weight1 * loss_bound + args.loss_weight0 * loss_in + args.loss_weight2 * loss_convex 550 | else: 551 | loss = args.loss_weight1 * loss_bound + args.loss_weight0 * loss_in 552 | loss.backward() 553 | 554 | if log_count1[0] % 200 == 0: 555 | loss_in_list.append(loss_in.item()) 556 | loss_convex_list.append(loss_convex.item()) 557 | loss_bound_list.append(loss_bound.item()) 558 | LHS_list.append(LHS.detach()) 559 | RHS_list.append(RHS) 560 | log_count1[0] += 1 561 | 562 | optimizer_adam.step() 563 | 564 | # LBFGS 565 | else: 566 | for i in range(np.max((1, int(args.train_sample_grid * all_u.shape[0] / (args.batch_size_x_lbfgs * args.batch_size_u_lbfgs))))): 567 | def closure(): 568 | if args.experiment == 'burgers': 569 | u, ux, uy, alpha, m, RHS, x = 
sample_train_data(all_u, args.batch_size_x_lbfgs, args.batch_size_u_lbfgs, device) # b 570 | bound1, bound2, bound3, bound4, bound1_u, bound2_u, bound3_u, bound4_u, bound1_m, bound2_m, bound3_m, bound4_m = sample_train_data_bound(all_u, args.batch_size_x_lbfgs, args.batch_size_u_lbfgs, device) # 4 * (b//4) 571 | elif args.experiment == 'cy': 572 | u, ux, uy, alpha, m, RHS, x = sample_train_data_tri(all_u, args.batch_size_x_lbfgs, args.batch_size_u_lbfgs, device) # b 573 | bound1, bound2, bound3, bound4, bound1_u, bound2_u, bound3_u, bound4_u = sample_train_data_bound_tri(all_u, args.batch_size_x_lbfgs, args.batch_size_u_lbfgs, device) # 4 * (b//4) 574 | weight = torch.ones_like(RHS).unsqueeze(-1) 575 | 576 | optimizer_lbfgs.zero_grad() 577 | 578 | if args.bound_constraint == 'soft': 579 | # loss of boundary condition 580 | if len(bound1) == 0: 581 | loss_bound1 = torch.zeros(1).to(device) 582 | else: 583 | bound11 = bound1[:, 0].view(-1, 1) 584 | bound12 = bound1[:, 1].view(-1, 1) 585 | # bound1t = bound1[:, 2].view(-1, 1) 586 | bound11.requires_grad = True 587 | bound12.requires_grad = True 588 | # X1 = torch.cat((bound11, bound12, bound1t), dim=1) 589 | X1 = torch.cat((bound11, bound12), dim=1) 590 | output_bound1 = model(bound1_u, X1) 591 | v1 = torch.ones(output_bound1.shape).to(device) 592 | bound1_x = torch.autograd.grad(output_bound1, bound11, grad_outputs=v1, retain_graph = True, create_graph=True, allow_unused=True)[0] 593 | loss_bound1 = criterion(bound1_x, torch.zeros_like(bound1_x)) 594 | 595 | if len(bound2) == 0: 596 | loss_bound2 = torch.zeros(1).to(device) 597 | else: 598 | bound21 = bound2[:, 0].view(-1, 1) 599 | bound22 = bound2[:, 1].view(-1, 1) 600 | # bound2t = bound2[:, 2].view(-1, 1) 601 | bound21.requires_grad = True 602 | bound22.requires_grad = True 603 | # X2 = torch.cat((bound21, bound22, bound2t), dim=1) 604 | X2 = torch.cat((bound21, bound22), dim=1) 605 | output_bound2 = model(bound2_u, X2) 606 | v2 = 
torch.ones(output_bound2.shape).to(device) 607 | bound2_x = torch.autograd.grad(output_bound2, bound21, grad_outputs=v2, retain_graph = True, create_graph=True, allow_unused=True)[0] 608 | loss_bound2 = criterion(bound2_x, torch.zeros_like(bound2_x)) 609 | 610 | if len(bound3) == 0: 611 | loss_bound3 = torch.zeros(1).to(device) 612 | else: 613 | bound31 = bound3[:, 0].view(-1, 1) 614 | bound32 = bound3[:, 1].view(-1, 1) 615 | # bound3t = bound3[:, 2].view(-1, 1) 616 | bound31.requires_grad = True 617 | bound32.requires_grad = True 618 | # X3 = torch.cat((bound31, bound32, bound3t), dim=1) 619 | X3 = torch.cat((bound31, bound32), dim=1) 620 | output_bound3 = model(bound3_u, X3) 621 | v3 = torch.ones(output_bound3.shape).to(device) 622 | bound3_y = torch.autograd.grad(output_bound3, bound32, grad_outputs=v3, retain_graph = True, create_graph=True, allow_unused=True)[0] 623 | loss_bound3 = criterion(bound3_y, torch.zeros_like(bound3_y)) 624 | 625 | if len(bound4) == 0: 626 | loss_bound4 = torch.zeros(1).to(device) 627 | else: 628 | bound41 = bound4[:, 0].view(-1, 1) 629 | bound42 = bound4[:, 1].view(-1, 1) 630 | # bound4t = bound4[:, 2].view(-1, 1) 631 | bound41.requires_grad = True 632 | bound42.requires_grad = True 633 | # X4 = torch.cat((bound41, bound42, bound4t), dim=1) 634 | X4 = torch.cat((bound41, bound42), dim=1) 635 | output_bound4 = model(bound4_u, X4) 636 | v4 = torch.ones(output_bound4.shape).to(device) 637 | bound4_y = torch.autograd.grad(output_bound4, bound42, grad_outputs=v4, retain_graph = True, create_graph=True, allow_unused=True)[0] 638 | loss_bound4 = criterion(bound4_y, torch.zeros_like(bound4_y)) 639 | 640 | loss_bound = (loss_bound1 + loss_bound2 + loss_bound3 + loss_bound4) / 4 641 | else: 642 | loss_bound = torch.tensor(0).to(device) 643 | 644 | # loss inside 645 | x1 = x[:, 0].view(x.shape[0], 1) 646 | x2 = x[:, 1].view(x.shape[0], 1) 647 | # xt = x[:, 2].view(x.shape[0], 1) 648 | x1.requires_grad = True 649 | x2.requires_grad = True 650 | 
# x_ = torch.cat((x1, x2, xt), dim=1) 651 | x_ = torch.cat((x1, x2), dim=1) 652 | if args.bound_constraint == 'soft': 653 | output = model(u, x_) # nu*nx 654 | else: 655 | output = ((x1**2) * (x2**2) * ((x1-1)**2) * ((x2-1)**2)) * model(u, x_) + (1/2) * (x1**2) + (1/2) * (x2**2) 656 | w = torch.ones(output.shape).to(device) 657 | phix = torch.autograd.grad(output, x1, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 658 | phiy = torch.autograd.grad(output, x2, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 659 | if init_mesh == True: 660 | loss_in = (criterion(x1 + phix, x1) + criterion(x2 + phiy, x2)) / 2 661 | loss = args.loss_weight1 * loss_bound + args.loss_weight0 * loss_in 662 | 663 | else: 664 | w2 = torch.ones(phix.shape).to(device) 665 | phixy = torch.autograd.grad(phix, x2, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 666 | phixx = torch.autograd.grad(phix, x1, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 667 | phiyx = torch.autograd.grad(phiy, x1, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 668 | phiyy = torch.autograd.grad(phiy, x2, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 669 | 670 | if args.experiment == 'burgers': 671 | ux_ = interpolate(ux.unsqueeze(1).repeat(1, args.batch_size_x_lbfgs, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 672 | uy_ = interpolate(uy.unsqueeze(1).repeat(1, args.batch_size_x_lbfgs, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 673 | elif args.experiment == 'cy': 674 | ux_ = interpolate(ux.unsqueeze(1).repeat(1, args.batch_size_x_lbfgs, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 675 | uy_ = interpolate(uy.unsqueeze(1).repeat(1, args.batch_size_x_lbfgs, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 676 | u_xi_x = ux_ * (1 + phixx) + uy_ 
* phiyx 677 | u_xi_y = ux_ * phixy + uy_ * (1 + phiyy) 678 | m_xi = monitor(alpha.unsqueeze(1).repeat(1, args.batch_size_x_lbfgs).reshape(-1, 1), u_xi_x, u_xi_y) 679 | LHS = m_xi * ((1 + phixx) * (1 + phiyy) - phixy * phiyx) 680 | 681 | loss_in = criterion(LHS / RHS.unsqueeze(1).repeat(1, args.batch_size_x_lbfgs).reshape(-1, 1), torch.ones_like(LHS)) 682 | loss_convex = torch.mean(torch.min(torch.tensor(0).type_as(phixx), 1 + phixx)**2 + torch.min(torch.tensor(0).type_as(phiyy), 1 + phiyy)**2) 683 | 684 | if args.loss_convex == True: 685 | loss = args.loss_weight1 * loss_bound + args.loss_weight0 * loss_in + args.loss_weight2 * loss_convex 686 | else: 687 | loss = args.loss_weight1 * loss_bound + args.loss_weight0 * loss_in 688 | 689 | loss.backward() 690 | 691 | if log_count2[0] % 200 == 0: 692 | loss_in_list.append(loss_in.item()) 693 | loss_convex_list.append(loss_convex.item()) 694 | loss_bound_list.append(loss_bound.item()) 695 | LHS_list.append(LHS.detach()) 696 | RHS_list.append(RHS) 697 | log_count2[0] += 1 698 | 699 | return loss 700 | 701 | optimizer_lbfgs.step(closure) 702 | 703 | test_equ = LHS_list[-1] / RHS_list[-1] - torch.tensor(1).to(device) 704 | test_equ_max = torch.max(test_equ) 705 | test_equ_min = torch.min(test_equ) 706 | test_equ_min = torch.min(test_equ) 707 | test_equ_mid = torch.median(test_equ) 708 | test_equ_loss = torch.mean(torch.abs(test_equ)) 709 | test_equ_loss_list.append(test_equ_loss.item()) 710 | test_equ_max_list.append(test_equ_max.item()) 711 | test_equ_min_list.append(test_equ_min.item()) 712 | test_equ_mid_list.append(test_equ_mid.item()) 713 | 714 | torch.cuda.empty_cache() 715 | end = datetime.now() 716 | 717 | # print & evaluate 718 | if epoch < n_epoch_adam + 1: 719 | scheduler_adam.step() 720 | 721 | if epoch % 1 == 0: 722 | print(end - start) 723 | print('Epoch: {} | Loss in: {} | Loss bound: {} | Loss convex: {} | Test equ loss: {:1.4f}'\ 724 | .format(epoch, loss_in_list[-1], loss_bound_list[-1], 
loss_convex_list[-1], test_equ_loss_list[-1])) 725 | logs_txt.append('Epoch: {} | Loss in: {} | Loss bound: {} | Loss convex: {} | Test equ loss: {:1.4f}'\ 726 | .format(epoch, loss_in_list[-1], loss_bound_list[-1], loss_convex_list[-1], test_equ_loss_list[-1])) 727 | if epoch % 1 == 0 or epoch == n_epoch_adam: 728 | if args.experiment == 'burgers': 729 | train_mean, train_std, train_minmax = evaluate(model, all_u, device, epoch) 730 | test_mean, test_std, test_minmax = evaluate(model, test_u, device, epoch) 731 | elif args.experiment =='cy': 732 | train_mean, train_std, train_minmax = evaluate_tri(model, all_u[:, :, 2], all_u[0, :, :2], device, epoch) 733 | test_mean, test_std, test_minmax = evaluate_tri(model, test_u[:, :, 2], all_u[0, :, :2], device, epoch) 734 | train_std_list.append(train_std) 735 | train_minmax_list.append(train_minmax) 736 | test_std_list.append(test_std) 737 | test_minmax_list.append(test_minmax) 738 | print('Train mean: {:1.6f} | Train std: {:1.6f} | Train minmax: {:1.6f} | Test mean: {:1.6f} | Test std: {:1.6f} | Test minmax: {:1.6f}'\ 739 | .format(train_mean, train_std, train_minmax, test_mean, test_std, test_minmax)) 740 | logs_txt.append('Train mean: {:1.6f} | Train std: {:1.6f} | Train minmax: {:1.6f} | Test mean: {:1.6f} | Test std: {:1.6f} | Test minmax: {:1.6f}'\ 741 | .format(train_mean, train_std, train_minmax, test_mean, test_std, test_minmax)) 742 | 743 | else: 744 | scheduler_lbfgs.step() 745 | 746 | if epoch % 1 == 0: 747 | print(end - start) 748 | print('Epoch: {} | Loss in: {} | Loss bound: {} | Loss convex: {} | Test equ loss: {:1.4f}'\ 749 | .format(epoch, loss_in_list[-1], loss_bound_list[-1], loss_convex_list[-1], test_equ_loss_list[-1])) 750 | logs_txt.append('Epoch: {} | Loss in: {} | Loss bound: {} | Loss convex: {} | Test equ loss: {:1.4f}'\ 751 | .format(epoch, loss_in_list[-1], loss_bound_list[-1], loss_convex_list[-1], test_equ_loss_list[-1])) 752 | if epoch % 1 == 0: 753 | if args.experiment == 'burgers': 754 | 
train_mean, train_std, train_minmax = evaluate(model, all_u, device, epoch) 755 | test_mean, test_std, test_minmax = evaluate(model, test_u, device, epoch) 756 | elif args.experiment =='cy': 757 | train_mean, train_std, train_minmax = evaluate_tri(model, all_u[:, :, 2], all_u[0, :, :2], device, epoch) 758 | test_mean, test_std, test_minmax = evaluate_tri(model, test_u[:, :, 2], all_u[0, :, :2], device, epoch) 759 | train_std_list.append(train_std) 760 | train_minmax_list.append(train_minmax) 761 | test_std_list.append(test_std) 762 | test_minmax_list.append(test_minmax) 763 | print('Train mean: {:1.6f} | Train std: {:1.6f} | Train minmax: {:1.6f} | Test mean: {:1.6f} | Test std: {:1.6f} | Test minmax: {:1.6f}'\ 764 | .format(train_mean, train_std, train_minmax, test_mean, test_std, test_minmax)) 765 | logs_txt.append('Train mean: {:1.6f} | Train std: {:1.6f} | Train minmax: {:1.6f} | Test mean: {:1.6f} | Test std: {:1.6f} | Test minmax: {:1.6f}'\ 766 | .format(train_mean, train_std, train_minmax, test_mean, test_std, test_minmax)) 767 | 768 | save_path = '{}/{}_{}_bound{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'\ 769 | .format(args.experiment, datetime.now(), args.rf, args.loss_bound_rf, args.epochs_rf, args.max_iter, args.sub_u, args.epochs_lbfgs,\ 770 | args.batch_size_u_adam, args.batch_size_x_adam, args.loss_weight1, args.train_sample_grid, args.branch_layers, args.lr_adam, args.trunk_layers, args.gamma_adam) 771 | 772 | torch.save({ 773 | 'model_state_dict': model.state_dict(), 774 | 'loss_in': loss_in_list, 775 | 'loss_bound': loss_bound_list, 776 | 'loss_convex': loss_convex_list, 777 | 'args': args, 778 | 'train_std': train_std_list, 779 | 'train_minmax': train_minmax_list, 780 | 'test_std': test_std_list, 781 | 'test_minmax': test_minmax_list, 782 | }, save_path) 783 | 784 | # random feature method 785 | if args.rf == True: 786 | c = 1 787 | for i in range(args.epochs_rf): 788 | start = datetime.now() 789 | # print("time start: ", start) 790 | print('random 
feature method epoch No.', i) 791 | if args.experiment == 'burgers': 792 | u, ux, uy, alpha, m, RHS, x = sample_train_data(all_u, args.batch_size_x_rf, args.batch_size_u_rf, device) # b 793 | bound1, bound2, bound3, bound4, bound1_u, bound2_u, bound3_u, bound4_u, bound1_m, bound2_m, bound3_m, bound4_m = sample_train_data_bound(all_u, args.batch_size_x_rf, args.batch_size_u_rf, device) # 4 * (b//4) 794 | elif args.experiment == 'cy': 795 | u, ux, uy, alpha, m, RHS, x = sample_train_data_tri(all_u, args.batch_size_x_rf, args.batch_size_u_rf, device) # b 796 | bound1, bound2, bound3, bound4, bound1_u, bound2_u, bound3_u, bound4_u = sample_train_data_bound_tri(all_u, args.batch_size_x_rf, args.batch_size_u_rf, device) # 4 * (b//4) 797 | # loss of boundary condition 798 | bound11 = bound1[:, 0].view(-1, 1) 799 | bound12 = bound1[:, 1].view(-1, 1) 800 | # bound1t = bound1[:, 2].view(-1, 1) 801 | bound11.requires_grad = True 802 | bound12.requires_grad = True 803 | # X1 = torch.cat((bound11, bound12, bound1t), dim=1) 804 | X1 = torch.cat((bound11, bound12), dim=1) 805 | output_bound1, second_out_bound1, branch_bound1 = model(bound1_u, X1, rf = True) 806 | so_x_bound1, so_y_bound1 = [], [] 807 | for k in range(int(second_out_bound1.shape[-1])): 808 | second_out_bound1_ = second_out_bound1[:, k] 809 | w = torch.ones_like(second_out_bound1_).to(device) 810 | so_x_bound1_ = torch.autograd.grad(second_out_bound1_, bound11, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 811 | so_y_bound1_ = torch.autograd.grad(second_out_bound1_, bound12, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 812 | so_x_bound1.append(so_x_bound1_) 813 | so_y_bound1.append(so_y_bound1_) 814 | so_x_bound1 = torch.stack(so_x_bound1)[:, :, 0].permute(1, 0) 815 | so_y_bound1 = torch.stack(so_y_bound1)[:, :, 0].permute(1, 0) 816 | 817 | bound21 = bound2[:, 0].view(-1, 1) 818 | bound22 = bound2[:, 1].view(-1, 1) 819 | # bound2t = bound2[:, 
2].view(-1, 1) 820 | bound21.requires_grad = True 821 | bound22.requires_grad = True 822 | # X2 = torch.cat((bound21, bound22, bound2t), dim=1) 823 | X2 = torch.cat((bound21, bound22), dim=1) 824 | output_bound2, second_out_bound2, branch_bound2 = model(bound2_u, X2, rf = True) 825 | so_x_bound2, so_y_bound2 = [], [] 826 | for k in range(int(second_out_bound2.shape[-1])): 827 | second_out_bound2_ = second_out_bound2[:, k] 828 | w = torch.ones_like(second_out_bound2_).to(device) 829 | so_x_bound2_ = torch.autograd.grad(second_out_bound2_, bound21, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 830 | so_y_bound2_ = torch.autograd.grad(second_out_bound2_, bound22, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 831 | so_x_bound2.append(so_x_bound2_) 832 | so_y_bound2.append(so_y_bound2_) 833 | so_x_bound2 = torch.stack(so_x_bound2)[:, :, 0].permute(1, 0) 834 | so_y_bound2 = torch.stack(so_y_bound2)[:, :, 0].permute(1, 0) 835 | 836 | bound31 = bound3[:, 0].view(-1, 1) 837 | bound32 = bound3[:, 1].view(-1, 1) 838 | # bound3t = bound3[:, 2].view(-1, 1) 839 | bound31.requires_grad = True 840 | bound32.requires_grad = True 841 | # X3 = torch.cat((bound31, bound32, bound3t), dim=1) 842 | X3 = torch.cat((bound31, bound32), dim=1) 843 | output_bound3, second_out_bound3, branch_bound3 = model(bound3_u, X3, rf = True) 844 | so_x_bound3, so_y_bound3 = [], [] 845 | for k in range(int(second_out_bound3.shape[-1])): 846 | second_out_bound3_ = second_out_bound3[:, k] 847 | w = torch.ones_like(second_out_bound3_).to(device) 848 | so_x_bound3_ = torch.autograd.grad(second_out_bound3_, bound31, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 849 | so_y_bound3_ = torch.autograd.grad(second_out_bound3_, bound32, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 850 | so_x_bound3.append(so_x_bound3_) 851 | so_y_bound3.append(so_y_bound3_) 852 | so_x_bound3 = 
torch.stack(so_x_bound3)[:, :, 0].permute(1, 0) 853 | so_y_bound3 = torch.stack(so_y_bound3)[:, :, 0].permute(1, 0) 854 | 855 | bound41 = bound4[:, 0].view(-1, 1) 856 | bound42 = bound4[:, 1].view(-1, 1) 857 | # bound4t = bound4[:, 2].view(-1, 1) 858 | bound41.requires_grad = True 859 | bound42.requires_grad = True 860 | # X4 = torch.cat((bound41, bound42, bound4t), dim=1) 861 | X4 = torch.cat((bound41, bound42), dim=1) 862 | output_bound4, second_out_bound4, branch_bound4 = model(bound4_u, X4, rf = True) 863 | so_x_bound4, so_y_bound4 = [], [] 864 | for k in range(int(second_out_bound4.shape[-1])): 865 | second_out_bound4_ = second_out_bound4[:, k] 866 | w = torch.ones_like(second_out_bound4_).to(device) 867 | so_x_bound4_ = torch.autograd.grad(second_out_bound4_, bound41, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 868 | so_y_bound4_ = torch.autograd.grad(second_out_bound4_, bound42, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 869 | so_x_bound4.append(so_x_bound4_) 870 | so_y_bound4.append(so_y_bound4_) 871 | so_x_bound4 = torch.stack(so_x_bound4)[:, :, 0].permute(1, 0) 872 | so_y_bound4 = torch.stack(so_y_bound4)[:, :, 0].permute(1, 0) 873 | 874 | # loss in 875 | x1 = x[:, 0].view(x.shape[0], 1) 876 | x2 = x[:, 1].view(x.shape[0], 1) 877 | # xt = x[:, 2].view(x.shape[0], 1) 878 | x1.requires_grad = True 879 | x2.requires_grad = True 880 | # x_ = torch.cat((x1, x2, xt), dim=1) 881 | x_ = torch.cat((x1, x2), dim=1) 882 | output, second_out, branch = model(u, x_, rf = True) 883 | so_x, so_y, so_xx, so_xy, so_yx, so_yy = [], [], [], [], [], [] 884 | for k in range(int(second_out.shape[-1])): 885 | second_out_ = second_out[:, k] 886 | w = torch.ones_like(second_out_).to(device) 887 | so_x_ = torch.autograd.grad(second_out_, x1, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 888 | so_y_ = torch.autograd.grad(second_out_, x2, grad_outputs=w, retain_graph = True, 
create_graph=True, allow_unused=True)[0] 889 | w2 = torch.ones_like(so_x_).to(device) 890 | so_xy_ = torch.autograd.grad(so_x_, x2, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 891 | so_xx_ = torch.autograd.grad(so_x_, x1, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 892 | so_yx_ = torch.autograd.grad(so_y_, x1, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 893 | so_yy_ = torch.autograd.grad(so_y_, x2, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 894 | so_x.append(so_x_) 895 | so_y.append(so_y_) 896 | so_xx.append(so_xx_) 897 | so_xy.append(so_xy_) 898 | so_yx.append(so_yx_) 899 | so_yy.append(so_yy_) 900 | so_x = torch.stack(so_x)[:, :, 0].permute(1, 0) 901 | so_y = torch.stack(so_y)[:, :, 0].permute(1, 0) 902 | so_xx = torch.stack(so_xx)[:, :, 0].permute(1, 0) 903 | so_xy = torch.stack(so_xy)[:, :, 0].permute(1, 0) 904 | so_yx = torch.stack(so_yx)[:, :, 0].permute(1, 0) 905 | so_yy = torch.stack(so_yy)[:, :, 0].permute(1, 0) 906 | 907 | w = torch.ones(output.shape).to(device) 908 | phix = torch.autograd.grad(output, x1, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 909 | phiy = torch.autograd.grad(output, x2, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 910 | w2 = torch.ones(phix.shape).to(device) 911 | phixy = torch.autograd.grad(phix, x2, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 912 | phixx = torch.autograd.grad(phix, x1, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 913 | phiyx = torch.autograd.grad(phiy, x1, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 914 | phiyy = torch.autograd.grad(phiy, x2, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 915 | 916 | ux_ = interpolate(ux.unsqueeze(1).repeat(1, args.batch_size_x_rf, 1, 
1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 917 | uy_ = interpolate(uy.unsqueeze(1).repeat(1, args.batch_size_x_rf, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 918 | u_xi_x = ux_ * (1 + phixx) + uy_ * phiyx 919 | u_xi_y = ux_ * phixy + uy_ * (1 + phiyy) 920 | m_xi = monitor(alpha.unsqueeze(1).repeat(1, args.batch_size_x_rf).reshape(-1, 1), u_xi_x, u_xi_y) 921 | 922 | init = model.out_nn.layers[-1].weight.data.reshape(-1, 1) 923 | if args.rf_opt_alg == 'BFGS': 924 | desired_weights = minimize(lambda x: random_feature_torch2(x, args.convex_rel, second_out,\ 925 | second_out_bound1, second_out_bound2, second_out_bound3, second_out_bound4,\ 926 | args, alpha, x1, x2, ux, uy, so_x_bound1, so_x_bound2, so_y_bound3, so_y_bound4,\ 927 | so_x, so_y, so_xx,so_yy, so_xy, so_yx, RHS), 928 | init, 929 | method='bfgs', 930 | options=dict(line_search='strong-wolfe'), 931 | max_iter=args.max_iter, 932 | disp=0, 933 | tol=0) 934 | elif args.rf_opt_alg == 'Newton': 935 | desired_weights = minimize(lambda x: random_feature_torch2(x, args.convex_rel, second_out,\ 936 | second_out_bound1, second_out_bound2, second_out_bound3, second_out_bound4,\ 937 | args, alpha, x1, x2, ux, uy, so_x_bound1, so_x_bound2, so_y_bound3, so_y_bound4,\ 938 | so_x, so_y, so_xx, so_yy, so_xy, so_yx, RHS), 939 | init, 940 | method='newton-cg', 941 | options=dict(line_search='strong-wolfe'), 942 | max_iter=args.max_iter, 943 | disp=0, 944 | tol=0) 945 | model.out_nn.layers[-1].weight.data = desired_weights.x.reshape(1, second_out.shape[1]) 946 | end = datetime.now() 947 | print("time per epoch of random feature method: ", end - start) 948 | c = c + 1 949 | 950 | # test for random feature method 951 | if args.experiment == 'burgers': 952 | u, ux, uy, alpha, m, RHS, x = sample_train_data(all_u, args.batch_size_x_rf, args.batch_size_u_rf, device) # b 953 | bound1, bound2, bound3, bound4, bound1_u, bound2_u, bound3_u, bound4_u, bound1_m, bound2_m, bound3_m, bound4_m = 
sample_train_data_bound(all_u, args.batch_size_x_rf, args.batch_size_u_rf, device) # 4 * (b//4) 954 | elif args.experiment == 'cy': 955 | u, ux, uy, alpha, m, RHS, x = sample_train_data_tri(all_u, args.batch_size_x_rf, args.batch_size_u_rf, device) # b 956 | bound1, bound2, bound3, bound4, bound1_u, bound2_u, bound3_u, bound4_u = sample_train_data_bound_tri(all_u, args.batch_size_x_rf, args.batch_size_u_rf, device) # 4 * (b//4) 957 | 958 | # loss of boundary condition 959 | if len(bound1) == 0: 960 | loss_bound1 = torch.zeros(1).to(device) 961 | else: 962 | bound11 = bound1[:, 0].view(-1, 1) 963 | bound12 = bound1[:, 1].view(-1, 1) 964 | # bound1t = bound1[:, 2].view(-1, 1) 965 | bound11.requires_grad = True 966 | bound12.requires_grad = True 967 | # X1 = torch.cat((bound11, bound12, bound1t), dim=1) 968 | X1 = torch.cat((bound11, bound12), dim=1) 969 | output_bound1 = model(bound1_u, X1) 970 | v1 = torch.ones(output_bound1.shape).to(device) 971 | bound1_x = torch.autograd.grad(output_bound1, bound11, grad_outputs=v1, retain_graph = True, create_graph=True, allow_unused=True)[0] 972 | loss_bound1 = criterion(bound1_x, torch.zeros_like(bound1_x)) 973 | 974 | if len(bound2) == 0: 975 | loss_bound2 = torch.zeros(1).to(device) 976 | else: 977 | bound21 = bound2[:, 0].view(-1, 1) 978 | bound22 = bound2[:, 1].view(-1, 1) 979 | # bound2t = bound2[:, 2].view(-1, 1) 980 | bound21.requires_grad = True 981 | bound22.requires_grad = True 982 | # X2 = torch.cat((bound21, bound22, bound2t), dim=1) 983 | X2 = torch.cat((bound21, bound22), dim=1) 984 | output_bound2 = model(bound2_u, X2) 985 | v2 = torch.ones(output_bound2.shape).to(device) 986 | bound2_x = torch.autograd.grad(output_bound2, bound21, grad_outputs=v2, retain_graph = True, create_graph=True, allow_unused=True)[0] 987 | loss_bound2 = criterion(bound2_x, torch.zeros_like(bound2_x)) 988 | 989 | if len(bound3) == 0: 990 | loss_bound3 = torch.zeros(1).to(device) 991 | else: 992 | bound31 = bound3[:, 0].view(-1, 1) 993 | 
bound32 = bound3[:, 1].view(-1, 1) 994 | # bound3t = bound3[:, 2].view(-1, 1) 995 | bound31.requires_grad = True 996 | bound32.requires_grad = True 997 | # X3 = torch.cat((bound31, bound32, bound3t), dim=1) 998 | X3 = torch.cat((bound31, bound32), dim=1) 999 | output_bound3 = model(bound3_u, X3) 1000 | v3 = torch.ones(output_bound3.shape).to(device) 1001 | bound3_y = torch.autograd.grad(output_bound3, bound32, grad_outputs=v3, retain_graph = True, create_graph=True, allow_unused=True)[0] 1002 | loss_bound3 = criterion(bound3_y, torch.zeros_like(bound3_y)) 1003 | 1004 | if len(bound4) == 0: 1005 | loss_bound4 = torch.zeros(1).to(device) 1006 | else: 1007 | bound41 = bound4[:, 0].view(-1, 1) 1008 | bound42 = bound4[:, 1].view(-1, 1) 1009 | # bound4t = bound4[:, 2].view(-1, 1) 1010 | bound41.requires_grad = True 1011 | bound42.requires_grad = True 1012 | # X4 = torch.cat((bound41, bound42, bound4t), dim=1) 1013 | X4 = torch.cat((bound41, bound42), dim=1) 1014 | output_bound4, second_out_bound4, branch = model(bound4_u, X4, rf=True) 1015 | v4 = torch.ones(output_bound4.shape).to(device) 1016 | bound4_y = torch.autograd.grad(output_bound4, bound42, grad_outputs=v4, retain_graph = True, create_graph=True, allow_unused=True)[0] 1017 | loss_bound4 = criterion(bound4_y, torch.zeros_like(bound4_y)) 1018 | loss_bound = (loss_bound1 + loss_bound2 + loss_bound3 + loss_bound4) / 4 1019 | 1020 | # loss inside 1021 | x1 = x[:, 0].view(x.shape[0], 1) 1022 | x2 = x[:, 1].view(x.shape[0], 1) 1023 | # xt = x[:, 2].view(x.shape[0], 1) 1024 | x1.requires_grad = True 1025 | x2.requires_grad = True 1026 | # x_ = torch.cat((x1, x2, xt), dim=1) 1027 | x_ = torch.cat((x1, x2), dim=1) 1028 | if args.bound_constraint == 'soft': 1029 | output = model(u, x_) # nu*nx 1030 | else: 1031 | output = ((x1**2) * (x2**2) * ((x1-1)**2) * ((x2-1)**2)) * model(u, x_) + (1/2) * (x1**2) + (1/2) * (x2**2) 1032 | w = torch.ones(output.shape).to(device) 1033 | phix = torch.autograd.grad(output, x1, 
grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 1034 | phiy = torch.autograd.grad(output, x2, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] 1035 | 1036 | w2 = torch.ones(phix.shape).to(device) 1037 | phixy = torch.autograd.grad(phix, x2, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 1038 | phixx = torch.autograd.grad(phix, x1, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 1039 | phiyx = torch.autograd.grad(phiy, x1, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 1040 | phiyy = torch.autograd.grad(phiy, x2, grad_outputs=w2, retain_graph = True, create_graph=True, allow_unused=True)[0] 1041 | 1042 | ux_ = interpolate(ux.unsqueeze(1).repeat(1, args.batch_size_x_rf, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 1043 | uy_ = interpolate(uy.unsqueeze(1).repeat(1, args.batch_size_x_rf, 1, 1).reshape(-1, ux.shape[-1], ux.shape[-1]), x1 + phix, x2 + phiy) 1044 | u_xi_x = ux_ * (1 + phixx) + uy_ * phiyx 1045 | u_xi_y = ux_ * phixy + uy_ * (1 + phiyy) 1046 | m_xi = monitor(alpha.unsqueeze(1).repeat(1, args.batch_size_x_rf).reshape(-1, 1), u_xi_x, u_xi_y) 1047 | LHS = m_xi * ((1 + phixx) * (1 + phiyy) - phixy * phiyx) 1048 | 1049 | loss_in = criterion(LHS / RHS.unsqueeze(1).repeat(1, args.batch_size_x_rf).reshape(-1, 1), torch.ones_like(LHS)) 1050 | loss_convex = torch.mean(torch.min(torch.tensor(0).type_as(phixx), 1 + phixx)**2 + torch.min(torch.tensor(0).type_as(phiyy), 1 + phiyy)**2) 1051 | 1052 | test_equ = LHS / RHS - torch.tensor(1).to(device) 1053 | test_equ_loss = torch.mean(torch.abs(test_equ)) 1054 | test_equ_loss_list.append(test_equ_loss.item()) 1055 | 1056 | loss_in_list.append(loss_in.item()) 1057 | loss_bound_list.append(loss_bound.item()) 1058 | print('Epoch: {} | Loss in: {} | Loss bound: {} | Loss convex: {} | Test equ loss: {:1.4f}'\ 1059 | .format(epoch + c - 1, 
loss_in_list[-1], loss_bound_list[-1], 0, test_equ_loss_list[-1])) 1060 | logs_txt.append('Epoch: {} | Loss in: {} | Loss bound: {} | Loss convex: {} | Test equ loss: {:1.4f}'\ 1061 | .format(epoch + c - 1, loss_in_list[-1], loss_bound_list[-1], 0, test_equ_loss_list[-1])) 1062 | 1063 | if args.experiment == 'burgers': 1064 | train_mean, train_std, train_minmax = evaluate(model, all_u, device, epoch) 1065 | test_mean, test_std, test_minmax = evaluate(model, test_u, device, epoch) 1066 | elif args.experiment =='cy': 1067 | train_mean, train_std, train_minmax = evaluate_tri(model, all_u[:, :, 2], all_u[0, :, :2], device, epoch) 1068 | test_mean, test_std, test_minmax = evaluate_tri(model, test_u[:, :, 2], all_u[0, :, :2], device, epoch) 1069 | train_std_list.append(train_std) 1070 | train_minmax_list.append(train_minmax) 1071 | test_std_list.append(test_std) 1072 | test_minmax_list.append(test_minmax) 1073 | print('Train mean: {:1.6f} | Train std: {:1.6f} | Train minmax: {:1.6f} | Test mean: {:1.6f} | Test std: {:1.6f} | Test minmax: {:1.6f}'\ 1074 | .format(train_mean, train_std, train_minmax, test_mean, test_std, test_minmax)) 1075 | logs_txt.append('Train mean: {:1.6f} | Train std: {:1.6f} | Train minmax: {:1.6f} | Test mean: {:1.6f} | Test std: {:1.6f} | Test minmax: {:1.6f}'\ 1076 | .format(train_mean, train_std, train_minmax, test_mean, test_std, test_minmax)) 1077 | 1078 | save_path = '{}/{}_{}_bound{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'\ 1079 | .format(args.experiment, datetime.now(), args.rf, args.loss_bound_rf, args.epochs_rf, args.max_iter, args.sub_u, args.epochs_lbfgs,\ 1080 | args.batch_size_u_adam, args.batch_size_x_adam, args.loss_weight1, args.train_sample_grid, args.branch_layers, args.lr_adam, args.trunk_layers, args.gamma_adam) 1081 | torch.save({ 1082 | 'model_state_dict': model.state_dict(), 1083 | 'loss_in': loss_in_list, 1084 | 'loss_bound': loss_bound_list, 1085 | 'loss_convex': loss_convex_list, 1086 | 'args': args, 1087 | 'train_std': 
train_std_list, 1088 | 'train_minmax': train_minmax_list, 1089 | 'test_std': test_std_list, 1090 | 'test_minmax': test_minmax_list, 1091 | }, save_path) 1092 | print(save_path) 1093 | 1094 | return model, loss_in_list, loss_bound_list, loss_convex_list, test_equ_loss_list, test_equ_max_list, test_equ_min_list, test_equ_mid_list,\ 1095 | train_std_list, train_minmax_list, test_std_list, test_minmax_list, itp_list1, itp_list2, logs_txt 1096 | 1097 | 1098 | #####evaluate & plot##### 1099 | 1100 | def interpolate3(u, init_x, init_y, x, y, n): 1101 | d = -torch.norm(torch.cat((init_x, init_y), dim=-1).unsqueeze(0).repeat(x.shape[0], 1, 1) - torch.cat((x, y), dim=-1).unsqueeze(1).repeat(1, init_x.shape[0], 1), dim=-1) * n 1102 | normalize = nn.Softmax(dim=-1) 1103 | weight = normalize(d) 1104 | interpolated = torch.sum(u.reshape(1, init_x.shape[0]).repeat(x.shape[0], 1) * weight, dim=-1) 1105 | 1106 | return interpolated 1107 | 1108 | 1109 | def itp_error(u, model, device, plot): 1110 | ori_nx = u.shape[-2] 1111 | ori_ny = u.shape[-1] 1112 | ori_x = torch.linspace(0, 1, ori_nx).to(device) 1113 | ori_y = torch.linspace(0, 1, ori_ny).to(device) 1114 | ori_grid_x, ori_grid_y = torch.meshgrid(ori_x, ori_y) 1115 | 1116 | nx, ny = int(ori_nx / 4), int(ori_nx / 4) 1117 | grid1 = np.linspace(0, 1, nx) 1118 | grid2 = np.linspace(0, 1, ny) 1119 | grid = torch.tensor(np.array(np.meshgrid(grid1, grid2)), dtype=torch.float).reshape(2, -1).permute(1, 0).to(device) 1120 | xi1, xi2 = grid[:, [0]], grid[:, [1]] 1121 | xi1.requires_grad = True 1122 | xi2.requires_grad = True 1123 | xi = torch.cat((xi1, xi2), dim=-1) 1124 | phi = model(u[[0]], xi) 1125 | w = torch.ones(phi.shape).to(device) 1126 | mesh_x1 = (torch.autograd.grad(phi, xi1, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi1) 1127 | mesh_y1 = (torch.autograd.grad(phi, xi2, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi2) 1128 | 1129 | mesh_x2, mesh_y2 = xi1, 
xi2 1130 | 1131 | mesh_u1 = interpolate3(u.reshape(-1, ori_nx, ori_ny), ori_grid_x.reshape(-1, 1), ori_grid_y.reshape(-1, 1), mesh_x1, mesh_y1, ori_nx).reshape(-1, nx, ny) 1132 | mesh_u2 = interpolate3(u.reshape(-1, ori_nx, ori_ny), ori_grid_x.reshape(-1, 1), ori_grid_y.reshape(-1, 1), mesh_x2, mesh_y2, ori_nx).reshape(-1, nx, ny) 1133 | interpolated_uni_u1 = interpolate3(mesh_u1, mesh_x1, mesh_y1, ori_grid_x.reshape(-1, 1), ori_grid_y.reshape(-1, 1), ori_nx).reshape(-1, ori_nx, ori_ny) 1134 | interpolated_uni_u2 = interpolate3(mesh_u2, mesh_x2, mesh_y2, ori_grid_x.reshape(-1, 1), ori_grid_y.reshape(-1, 1), ori_nx).reshape(-1, ori_nx, ori_ny) 1135 | 1136 | itp_error1 = torch.norm((interpolated_uni_u1-u)).item()/torch.norm(u).item() 1137 | itp_error2 = torch.norm((interpolated_uni_u2-u)).item()/torch.norm(u).item() 1138 | 1139 | if plot == True: 1140 | norm = matplotlib.colors.Normalize(vmin=(torch.abs(interpolated_uni_u1-u)).cpu().min(), vmax=(torch.abs(interpolated_uni_u1-u)).cpu().max()) 1141 | plt.colorbar(cm.ScalarMappable(norm=norm, cmap=plt.cm.binary), format='%.2f') 1142 | plt.contourf((torch.abs(interpolated_uni_u1-u))[0].detach().cpu().numpy(), 50, cmap=plt.cm.binary, norm = norm) 1143 | plt.savefig("itp/{}.png".format(datetime.now())) 1144 | plt.clf() 1145 | 1146 | return itp_error1, itp_error2 1147 | 1148 | 1149 | def triangle_area_and_centroid(v1, v2, v3): 1150 | x1, y1 = v1 1151 | x2, y2 = v2 1152 | x3, y3 = v3 1153 | 1154 | area = 0.5 * abs((x1 * (y2 - y3)) + (x2 * (y3 - y1)) + (x3 * (y1 - y2))) 1155 | 1156 | centroid_x = (x1 + x2 + x3) / 3 1157 | centroid_y = (y1 + y2 + y3) / 3 1158 | 1159 | return area, (centroid_x, centroid_y) 1160 | 1161 | 1162 | def evaluate_tri(model, u, grid, device, epoch): 1163 | u = u.to(device) 1164 | grid = grid.to(device) 1165 | xi1, xi2 = grid[:, [0]], grid[:, [1]] 1166 | xi1.requires_grad = True 1167 | xi2.requires_grad = True 1168 | xi = torch.cat((xi1, xi2), dim=-1) 1169 | 1170 | n = int(np.sqrt(u.shape[-1])) 1171 | 
uni_grid = torch.tensor(np.array(np.meshgrid(np.linspace(0, 1, n), np.linspace(0, 1, n))), dtype=torch.float)\ 1172 | .reshape(2, -1).permute(1, 0).to(device) 1173 | 1174 | x1 = xi1[:, 0].cpu().detach().numpy() 1175 | x2 = xi2[:, 0].cpu().detach().numpy() 1176 | points = np.column_stack((x1, x2)) 1177 | tri = Delaunay(points) 1178 | triangles_indices = tri.simplices 1179 | 1180 | mean = [] 1181 | std = [] 1182 | minmax = [] 1183 | # idx = np.random.choice(u.shape[0], 30, replace=False) 1184 | idx = np.random.choice(u.shape[0], min(150, u.shape[0]), replace=False) 1185 | for t in idx: 1186 | phi = model(u[[t]], xi) 1187 | w = torch.ones(phi.shape).to(device) 1188 | x1 = (torch.autograd.grad(phi, xi1, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi1)[:, 0].detach().cpu().numpy() 1189 | x2 = (torch.autograd.grad(phi, xi2, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi2)[:, 0].detach().cpu().numpy() 1190 | points = np.column_stack((x1, x2)) 1191 | 1192 | areas = [] 1193 | centroids = [] 1194 | for triangle_indices in triangles_indices: 1195 | v1, v2, v3 = points[triangle_indices] 1196 | area, centroid = triangle_area_and_centroid(v1, v2, v3) 1197 | areas.append(area) 1198 | centroids.append(centroid) 1199 | areas = torch.tensor(np.array(areas), device=device) 1200 | centroids = torch.tensor(np.array(centroids), device=device) 1201 | center1 = centroids[:, [0]] 1202 | center2 = centroids[:, [1]] 1203 | 1204 | x1_ = uni_grid[:, [0]] 1205 | x2_ = uni_grid[:, [1]] 1206 | x1_.requires_grad = True 1207 | x2_.requires_grad = True 1208 | u_ = interpolate_tri(u[[t]].unsqueeze(1).repeat(1, n**2, 1).reshape(-1, u.shape[-1]), \ 1209 | xi1.unsqueeze(0).unsqueeze(0).repeat(1, n**2, 1, 1).reshape(-1, u.shape[-1], 1), \ 1210 | xi2.unsqueeze(0).unsqueeze(0).repeat(1, n**2, 1, 1).reshape(-1, u.shape[-1], 1), \ 1211 | x1_.unsqueeze(1).repeat(1, u.shape[-1], 1), x2_.unsqueeze(1).repeat(1, u.shape[-1], 1)) 1212 | w = 
torch.ones(u_.shape).to(device) 1213 | uni_ux = torch.autograd.grad(u_, x1_, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0].reshape(1, n, n) 1214 | uni_uy = torch.autograd.grad(u_, x2_, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0].reshape(1, n, n) 1215 | alpha = torch.sum((torch.abs(uni_ux)**2 + torch.abs(uni_uy)**2)**(1/2), dim=(-2, -1)) / (n-1)**2 1216 | m = monitor(alpha.unsqueeze(-1).unsqueeze(-1).repeat(1, n, n), uni_ux, uni_uy).reshape(1, -1) 1217 | RHS = torch.sum(m, dim=(-2, -1)) / (n-1)**2 1218 | 1219 | m_center = [] 1220 | N = int(center1.shape[0] / 2) 1221 | m_center.append(interpolate_tri(m.repeat(N, 1), x1_[None].repeat(N, 1, 1), x2_[None].repeat(N, 1, 1),\ 1222 | center1[:N].unsqueeze(1).repeat(1, n**2, 1), center2[:N].unsqueeze(1).repeat(1, n**2, 1))) 1223 | m_center.append(interpolate_tri(m.repeat(center1.shape[0] - N, 1), x1_[None].repeat(center1.shape[0] - N, 1, 1),\ 1224 | x2_[None].repeat(center1.shape[0] - N, 1, 1), center1[N:].unsqueeze(1).repeat(1, n**2, 1),\ 1225 | center2[N:].unsqueeze(1).repeat(1, n**2, 1))) 1226 | m_center = torch.cat(m_center).reshape(-1) 1227 | m_per_grid = m_center * areas 1228 | mean.append(torch.mean(m_per_grid).cpu().detach().numpy()) 1229 | std.append(torch.std(m_per_grid).cpu().detach().numpy()) 1230 | minmax.append((torch.max(m_per_grid) - torch.min(m_per_grid)).cpu().detach().numpy()) 1231 | 1232 | return np.mean(mean), np.mean(std), np.mean(minmax) 1233 | 1234 | 1235 | def evaluate(model, u, device, epoch): 1236 | # computational mesh 1237 | s = u.shape[-1] 1238 | grid1 = np.linspace(0, 1, s) 1239 | grid2 = np.linspace(0, 1, s) 1240 | grid = torch.tensor(np.array(np.meshgrid(grid1, grid2)), dtype=torch.float).reshape(2, -1).permute(1, 0).to(device) 1241 | xi1, xi2 = grid[:, [0]], grid[:, [1]] 1242 | xi1.requires_grad = True 1243 | xi2.requires_grad = True 1244 | xi = torch.cat((xi1, xi2), dim=-1) 1245 | 1246 | # monitor function 1247 | u = 
u.to(device) 1248 | ux = diff_x(u) * (u.shape[-1] - 1) 1249 | uy = diff_y(u) * (u.shape[-1] - 1) 1250 | alpha = torch.sum((torch.abs(ux)**2 + torch.abs(uy)**2)**(1/2), dim=(-2, -1)) / (u.shape[-1]-1)**2 1251 | m = monitor(alpha.unsqueeze(-1).unsqueeze(-1).repeat(1, ux.shape[-1], ux.shape[-1]), ux, uy) 1252 | ideal_m_per_grid = ((torch.sum(m, dim=(-2, -1)) / (u.shape[-1]-1)**2) / (s - 1)**2).cpu().numpy() 1253 | 1254 | mean = [] 1255 | std = [] 1256 | minmax = [] 1257 | # idx = np.random.choice(u.shape[0], 30, replace=False) 1258 | idx = np.random.choice(u.shape[0], u.shape[0], replace=False) 1259 | for t in idx: 1260 | phi = model(u[[t]], xi) 1261 | w = torch.ones(phi.shape).to(device) 1262 | x1 = (torch.autograd.grad(phi, xi1, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi1).reshape(s, s) 1263 | x2 = (torch.autograd.grad(phi, xi2, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi2).reshape(s, s) 1264 | bottom_left1, bottom_left2 = x1[:(s-1), :(s-1)], x2[:(s-1), :(s-1)] 1265 | bottom_right1, bottom_right2 = x1[1:s, :(s-1)], x2[1:s, :(s-1)] 1266 | top_left1, top_left2 = x1[:(s-1), 1:s], x2[:(s-1), 1:s] 1267 | top_right1, top_right2 = x1[1:s, 1:s], x2[1:s, 1:s] 1268 | # diagonal 1269 | d1 = ((bottom_left1 - top_right1) ** 2 + (bottom_left2 - top_right2) ** 2) ** 0.5 1270 | d2 = ((bottom_right1 - top_left1) ** 2 + (bottom_right2 - top_left2) ** 2) ** 0.5 1271 | # area of the quadrilateral 1272 | area = d1 * d2 / 2 1273 | center1, center2 = (bottom_left1 + bottom_right1 + top_left1 + top_right1) / 4, (bottom_left2 + bottom_right2 + top_left2 + top_right2) / 4 1274 | m_center = [] 1275 | N = int((s-1)**2 / 2) 1276 | m_center.append(interpolate(m[[t]].repeat(N, 1, 1), center1.reshape(-1, 1)[:N], center2.reshape(-1, 1)[:N])) 1277 | m_center.append(interpolate(m[[t]].repeat((s-1)**2 - N, 1, 1), center1.reshape(-1, 1)[N:], center2.reshape(-1, 1)[N:])) 1278 | m_center = torch.cat(m_center).reshape(s-1, s-1) 
1279 | m_per_grid = m_center * area 1280 | mean.append(torch.mean(m_per_grid).cpu().detach().numpy()) 1281 | std.append(torch.std(m_per_grid).cpu().detach().numpy()) 1282 | minmax.append((torch.max(m_per_grid) - torch.min(m_per_grid)).cpu().detach().numpy()) 1283 | 1284 | return np.mean(mean), np.mean(std), np.mean(minmax) 1285 | 1286 | 1287 | 1288 | def plot_mesh_res_tri_s(s, u, model, fig, axes, args, device): 1289 | u = u[:, :, 2].to(device) 1290 | 1291 | # mesh 1292 | grid = model.ori_grid 1293 | xi1, xi2 = grid[:, [0]], grid[:, [1]] 1294 | xi1.requires_grad = True 1295 | xi2.requires_grad = True 1296 | xi = torch.cat((xi1, xi2), dim=-1) 1297 | 1298 | grid1 = np.linspace(0, 1, s) 1299 | grid2 = np.linspace(0, 1, s) 1300 | grid = torch.tensor(np.array(np.meshgrid(grid1, grid2)), dtype=torch.float).reshape(2, -1).permute(1, 0).to(device) 1301 | xi1_, xi2_ = grid[:, [0]], grid[:, [1]] 1302 | xi1_.requires_grad = True 1303 | xi2_.requires_grad = True 1304 | xi_ = torch.cat((xi1_, xi2_), dim=-1) 1305 | 1306 | n = int(np.sqrt(u.shape[-1])) 1307 | uni_grid = torch.tensor(np.array(np.meshgrid(np.linspace(0, 1, n), np.linspace(0, 1, n))), dtype=torch.float)\ 1308 | .reshape(2, -1).permute(1, 0).to(device) 1309 | 1310 | plt.xticks([0, int(u.shape[-1]/2-1), n-1],['0.0','0.5','1.0']) 1311 | plt.yticks([0, int(u.shape[-1]/2-1), n-1],['0.0','0.5','1.0']) 1312 | 1313 | for i in range(5): 1314 | # t = 0 1315 | t = 6*i + 5 1316 | plt.subplot(1, 5, i+1) 1317 | plt.title('t={}'.format(t), fontsize=18) 1318 | ax = axes[i] 1319 | 1320 | x1_ = uni_grid[:, [0]] 1321 | x2_ = uni_grid[:, [1]] 1322 | x1_.requires_grad = True 1323 | x2_.requires_grad = True 1324 | u_ = interpolate_tri(u[[t]].unsqueeze(1).repeat(1, n**2, 1).reshape(-1, u.shape[-1]), \ 1325 | xi1.unsqueeze(0).unsqueeze(0).repeat(1, n**2, 1, 1).reshape(-1, u.shape[-1], 1), \ 1326 | xi2.unsqueeze(0).unsqueeze(0).repeat(1, n**2, 1, 1).reshape(-1, u.shape[-1], 1), \ 1327 | x1_.unsqueeze(1).repeat(1, u.shape[-1], 1), 
x2_.unsqueeze(1).repeat(1, u.shape[-1], 1)) 1328 | w = torch.ones(u_.shape).to(device) 1329 | uni_ux = torch.autograd.grad(u_, x1_, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0].reshape(1, n, n) 1330 | uni_uy = torch.autograd.grad(u_, x2_, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0].reshape(1, n, n) 1331 | alpha = torch.sum((torch.abs(uni_ux)**2 + torch.abs(uni_uy)**2)**(1/2), dim=(-2, -1)) / (n-1)**2 1332 | m = monitor(alpha.unsqueeze(-1).unsqueeze(-1).repeat(1, n, n), uni_ux, uni_uy)[0] 1333 | 1334 | norm = matplotlib.colors.Normalize(vmin=m.cpu().min(), vmax=m.cpu().max()) 1335 | im = ax.contourf(m.detach().cpu().numpy(), 50, cmap=plt.cm.binary, norm = norm) 1336 | plt.colorbar(cm.ScalarMappable(norm=norm, cmap=plt.cm.binary), ax=ax, format='%.2f') 1337 | 1338 | # xi_ = torch.cat((xi, t * torch.ones_like(xi1)), dim=1) 1339 | phi = model(u[[t]], xi_) 1340 | w = torch.ones(phi.shape).to(device) 1341 | x1 = (torch.autograd.grad(phi, xi1_, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi1_).cpu().detach().numpy() * (n-1) 1342 | x2 = (torch.autograd.grad(phi, xi2_, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi2_).cpu().detach().numpy() * (n-1) 1343 | for j in range(s): 1344 | for i in range(s - 1): 1345 | plt.plot(np.concatenate((x1[i + j*s], x1[i + j*s + 1]), axis=0),\ 1346 | np.concatenate((x2[i + j*s], x2[i + j*s + 1]), axis=0), lw=0.2, color='green') 1347 | plt.plot(np.concatenate((x1[j + i*s], x1[j + (i+1)*s]), axis=0),\ 1348 | np.concatenate((x2[j + i*s], x2[j + (i+1)*s]), axis=0), lw=0.2, color='green') 1349 | 1350 | return fig, axes 1351 | 1352 | 1353 | def plot_mesh_res_tri(u, model, fig, axes, args, device): 1354 | u = u[:, :, 2].to(device) 1355 | 1356 | # mesh 1357 | grid = model.ori_grid 1358 | xi1, xi2 = grid[:, [0]], grid[:, [1]] 1359 | xi1.requires_grad = True 1360 | xi2.requires_grad = True 1361 | xi = 
torch.cat((xi1, xi2), dim=-1) 1362 | 1363 | n = int(np.sqrt(u.shape[-1])) 1364 | uni_grid = torch.tensor(np.array(np.meshgrid(np.linspace(0, 1, n), np.linspace(0, 1, n))), dtype=torch.float)\ 1365 | .reshape(2, -1).permute(1, 0).to(device) 1366 | 1367 | plt.xticks([0, int(u.shape[-1]/2-1), n-1],['0.0','0.5','1.0']) 1368 | plt.yticks([0, int(u.shape[-1]/2-1), n-1],['0.0','0.5','1.0']) 1369 | 1370 | for i in range(5): 1371 | # t = 0 1372 | t = 6*i + 5 1373 | plt.subplot(1, 5, i+1) 1374 | plt.title('t={}'.format(t), fontsize=18) 1375 | ax = axes[i] 1376 | 1377 | x1_ = uni_grid[:, [0]] 1378 | x2_ = uni_grid[:, [1]] 1379 | x1_.requires_grad = True 1380 | x2_.requires_grad = True 1381 | u_ = interpolate_tri(u[[t]].unsqueeze(1).repeat(1, n**2, 1).reshape(-1, u.shape[-1]), \ 1382 | xi1.unsqueeze(0).unsqueeze(0).repeat(1, n**2, 1, 1).reshape(-1, u.shape[-1], 1), \ 1383 | xi2.unsqueeze(0).unsqueeze(0).repeat(1, n**2, 1, 1).reshape(-1, u.shape[-1], 1), \ 1384 | x1_.unsqueeze(1).repeat(1, u.shape[-1], 1), x2_.unsqueeze(1).repeat(1, u.shape[-1], 1)) 1385 | w = torch.ones(u_.shape).to(device) 1386 | uni_ux = torch.autograd.grad(u_, x1_, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0].reshape(1, n, n) 1387 | uni_uy = torch.autograd.grad(u_, x2_, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0].reshape(1, n, n) 1388 | alpha = torch.sum((torch.abs(uni_ux)**2 + torch.abs(uni_uy)**2)**(1/2), dim=(-2, -1)) / (n-1)**2 1389 | m = monitor(alpha.unsqueeze(-1).unsqueeze(-1).repeat(1, n, n), uni_ux, uni_uy)[0] 1390 | 1391 | norm = matplotlib.colors.Normalize(vmin=m.cpu().min(), vmax=m.cpu().max()) 1392 | im = ax.contourf(m.detach().cpu().numpy(), 50, cmap=plt.cm.binary, norm = norm) 1393 | plt.colorbar(cm.ScalarMappable(norm=norm, cmap=plt.cm.binary), ax=ax, format='%.2f') 1394 | 1395 | # xi_ = torch.cat((xi, t * torch.ones_like(xi1)), dim=1) 1396 | if args.bound_constraint == 'soft': 1397 | phi = model(u[[t]], xi) 1398 | else: 1399 
| phi = ((xi1**2) * (xi2**2) * ((xi1-1)**2) * ((xi2-1)**2)) * model(u, xi) + (1/2) * (xi1**2) + (1/2) * (xi2**2) 1400 | w = torch.ones(phi.shape).to(device) 1401 | x1 = xi1[:, 0].cpu().detach().numpy() * (n-1) 1402 | x2 = xi2[:, 0].cpu().detach().numpy() * (n-1) 1403 | # triangulation = tri.Triangulation(x1, x2) 1404 | # plt.triplot(triangulation, '-', linewidth=0.1, c='b') 1405 | points = np.column_stack((x1, x2)) 1406 | tri = Delaunay(points) 1407 | triangles_indices = tri.simplices 1408 | # plt.triplot(x1, x2, tri.simplices, '-', linewidth=0.1, c='b') 1409 | x1 = (torch.autograd.grad(phi, xi1, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi1)[:, 0].cpu().detach().numpy() * (n-1) 1410 | x2 = (torch.autograd.grad(phi, xi2, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi2)[:, 0].cpu().detach().numpy() * (n-1) 1411 | # triangulation = tri.Triangulation(x1, x2) 1412 | # plt.triplot(triangulation, '-', linewidth=0.1, c='g') 1413 | plt.triplot(x1, x2, tri.simplices, '-', linewidth=0.1, c='g') 1414 | 1415 | return fig, axes 1416 | 1417 | 1418 | def plot_mesh_res(s, u, model, fig, axes, args, device): 1419 | 1420 | # mesh 1421 | grid1 = np.linspace(0, 1, s) 1422 | grid2 = np.linspace(0, 1, s) 1423 | grid = torch.tensor(np.array(np.meshgrid(grid1, grid2)), dtype=torch.float).reshape(2, -1).permute(1, 0).to(device) 1424 | xi1, xi2 = grid[:, [0]], grid[:, [1]] 1425 | xi1.requires_grad = True 1426 | xi2.requires_grad = True 1427 | xi = torch.cat((xi1, xi2), dim=-1) 1428 | 1429 | # monitor function 1430 | u = u.to(device) 1431 | ux = diff_x(u) * (u.shape[-1] - 1) 1432 | uy = diff_y(u) * (u.shape[-1] - 1) 1433 | alpha = torch.sum((torch.abs(ux)**2 + torch.abs(uy)**2)**(1/2), dim=(-2, -1)) / (u.shape[-1]-1)**2 1434 | 1435 | plt.xticks([0, int(u.shape[-1]/2-1), u.shape[-1]-1],['0.0','0.5','1.0']) 1436 | plt.yticks([0, int(u.shape[-1]/2-1), u.shape[-1]-1],['0.0','0.5','1.0']) 1437 | m = 
monitor(alpha.unsqueeze(-1).unsqueeze(-1).repeat(1, ux.shape[-1], ux.shape[-1]), ux, uy) 1438 | norm = matplotlib.colors.Normalize(vmin=m.cpu().min(), vmax=m.cpu().max()) 1439 | 1440 | for i in range(5): 1441 | t = 22*i + 22 1442 | plt.subplot(1, 5, i+1) 1443 | plt.title('t={}'.format(t), fontsize=18) 1444 | ax = axes[i] 1445 | im = ax.contourf(m[t].cpu().numpy(), 50, cmap=plt.cm.binary, norm = norm) 1446 | plt.colorbar(cm.ScalarMappable(norm=norm, cmap=plt.cm.binary), ax=ax, format='%.2f') 1447 | 1448 | # xi_ = torch.cat((xi, t * torch.ones_like(xi1)), dim=1) 1449 | if args.bound_constraint == 'soft': 1450 | phi = model(u[[t]], xi) 1451 | else: 1452 | phi = ((xi1**2) * (xi2**2) * ((xi1-1)**2) * ((xi2-1)**2)) * model(u, xi) + (1/2) * (xi1**2) + (1/2) * (xi2**2) 1453 | w = torch.ones(phi.shape).to(device) 1454 | x1 = (torch.autograd.grad(phi, xi1, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi1).cpu().detach().numpy() * (u.shape[-1]-1) 1455 | x2 = (torch.autograd.grad(phi, xi2, grad_outputs=w, retain_graph = True, create_graph=True, allow_unused=True)[0] + xi2).cpu().detach().numpy() * (u.shape[-1]-1) 1456 | for j in range(s): 1457 | for i in range(s - 1): 1458 | plt.plot(np.concatenate((x1[i + j*s], x1[i + j*s + 1]), axis=0),\ 1459 | np.concatenate((x2[i + j*s], x2[i + j*s + 1]), axis=0), lw=0.2, color='black') 1460 | plt.plot(np.concatenate((x1[j + i*s], x1[j + (i+1)*s]), axis=0),\ 1461 | np.concatenate((x2[j + i*s], x2[j + (i+1)*s]), axis=0), lw=0.2, color='black') 1462 | 1463 | return fig, axes 1464 | 1465 | --------------------------------------------------------------------------------