├── LICENSE.txt ├── README.md ├── docs └── README.md ├── examples ├── README.md ├── regression_diff_op_1d.ipynb ├── regression_diff_op_1d.py ├── regression_diff_op_2d.ipynb ├── regression_diff_op_2d.py └── z_doc_img │ ├── diff_op_1d_1.png │ └── diff_op_2d_1.png ├── gmlsnets_pytorch-1.0.0.tar.gz ├── misc └── overview.png └── src ├── __init__.py ├── dataset.py ├── nn.py ├── util.py └── vis.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Paul J. Atzberger 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GMLS-Nets 2 | 3 |
4 | 5 |
6 | 7 | [Examples](https://github.com/atzberg/gmls-nets/tree/master/examples) | [Documentation](http://web.math.ucsb.edu/~atzberg/gmlsnets_docs/html/index.html) 8 | 9 | __GMLS-Nets (PyTorch Implementation)__ 10 | 11 | This package provides machine learning methods for learning features from scattered, unstructured data sets using Generalized Moving Least Squares (GMLS). It provides techniques for generalizing approaches such as Convolutional Neural Networks (CNNs), which exploit translational and other symmetries, to unstructured data. The GMLS-Nets package also provides approaches for learning differential operators, PDEs, and other features from scattered data. 12 | 13 | __Quick Start__ 14 | 15 | *Method 1:* Install for Python using pip 16 | 17 | ```pip install -U gmlsnets-pytorch``` 18 | 19 | To use the package, see the [examples page](https://github.com/atzberg/gmls-nets/tree/master/examples). More information on the structure of the package can be found on the [documentation page](https://github.com/atzberg/gmls-nets/tree/master/docs). 20 | 21 | If you previously installed the package, please update to the latest version using ```pip install --upgrade gmlsnets-pytorch``` 22 | 23 | __Manual Installation__ 24 | 25 | *Method 2:* Download the [gmlsnets_pytorch-1.0.0.tar.gz](https://github.com/atzberg/gmls-nets-testing/blob/master/gmlsnets_pytorch-1.0.0.tar.gz) file above, then uncompress it 26 | 27 | ``tar -xvf gmlsnets_pytorch-1.0.0.tar.gz`` 28 | 29 | For a local install, add the base directory of the package to the path in your code: 30 | 31 | ``sys.path.append('package-path-location-here');`` 32 | 33 | Note that the package resides in the sub-directory ``./gmlsnets_pytorch-1.0.0/gmlsnets_pytorch/`` 34 | 35 | __Packages__ 36 | 37 | Please be sure to install the [PyTorch](https://pytorch.org/) package >= 1.2.0 with Python 3 (ideally >= 3.7). 
Also, be sure to install the following packages: numpy>=1.16, scipy>=1.3, matplotlib>=3.0. 38 | 39 | __Use__ 40 | 41 | For examples and documentation, see 42 | 43 | [Examples](https://github.com/atzberg/gmls-nets/tree/master/examples) 44 | 45 | [Documentation](http://web.math.ucsb.edu/~atzberg/gmlsnets_docs/html/index.html) 46 | 47 | __Additional Information__ 48 | 49 | If you find these codes or methods helpful for your project, please cite: 50 | 51 | *GMLS-Nets: A Framework for Learning from Unstructured Data,* 52 | N. Trask, R. G. Patel, B. J. Gross, and P. J. Atzberger, arXiv:1909.05371, (2019), [[arXiv]](https://arxiv.org/abs/1909.05371). 53 | ``` 54 | @article{trask_patel_gross_atzberger_GMLS_Nets_2019, 55 | title={GMLS-Nets: A framework for learning from unstructured data}, 56 | author={Nathaniel Trask and Ravi G. Patel and Ben J. Gross and Paul J. Atzberger}, 57 | journal={arXiv:1909.05371}, 58 | month={September}, 59 | year={2019}, 60 | url={https://arxiv.org/abs/1909.05371} 61 | } 62 | ``` 63 | For a [TensorFlow](https://www.tensorflow.org/) implementation of GMLS-Nets, see https://github.com/rgp62/gmls-nets. 64 | 65 | __Acknowledgements__ 66 | We gratefully acknowledge support from DOE Grant ASCR PHILMS DE-SC0019246. 67 | 68 | ---- 69 | 70 | 71 | [Examples](https://github.com/atzberg/gmls-nets/tree/master/examples) | [Documentation](http://web.math.ucsb.edu/~atzberg/gmlsnets_docs/html/index.html) | [Atzberger Homepage](http://atzberger.org/) 72 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | For more information on how to use this package, please see the [Documentation](http://web.math.ucsb.edu/~atzberg/gmlsnets_docs/html/index.html) and look at the [Examples](https://github.com/atzberg/gmls-nets/tree/master/examples). 
2 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 |

5 | 6 | For more information on how to use this package, see the example Jupyter notebooks and python scripts above. 7 | 8 | More information on the structure of the package can also be found on the 9 | [documentation page](http://web.math.ucsb.edu/~atzberg/gmlsnets_docs/html/index.html) 10 | -------------------------------------------------------------------------------- /examples/regression_diff_op_1d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # ## GMLS-Nets: 1D Regression of Linear and Non-linear Operators $L[u]$. 5 | # 6 | # __Ben J. Gross__, __Paul J. Atzberger__
7 | # http://atzberger.org/ 8 | # 9 | # Examples showing how GMLS-Nets can be used to perform regression for some basic linear and non-linear differential operators in 1D. 10 | # 11 | # __Parameters:__
12 | # The key parameter terms to adjust are:
13 | # ``op_type``: The operator type.
14 | # ``flag_mlp_case``: The type of mapping unit to use.
15 | # 16 | # __Examples of Non-linear Operators ($u{u_x},u_x^2,u{u_{xx}},u_{xx}^2$) :__
17 | # To run training for a non-linear operator like ``u*ux`` using MLP for the non-linear GMLS mapping unit, you can use:
18 | # ``op_type='u*ux';``
19 | # ``flag_mlp_case = 'Nonlinear1';``
20 | # You can obtain different performance by adjusting the mapping architecture and hyperparameters of the network. 21 | # 22 | # __Examples of linear Operators ($u_x,u_{xx}$):__
23 | # To run training for a linear operator like the 1d Laplacian ``uxx`` with a linear mapping unit, you can use
24 | # ``op_type='uxx';``
25 | # ``flag_mlp_case = 'Linear1';``
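# The pairing of ``op_type`` with ``flag_mlp_case`` described above can be summarized as a small lookup. This is only a sketch under stated assumptions: the helper ``suggest_mlp_case`` is hypothetical and not part of gmlsnets_pytorch, the operator strings are the ones listed in this script (``'ux'`` is assumed by analogy with ``'uxx'``), and the spelling ``'Nonlinear1'`` follows the string that the script's code compares against.

```python
# Hypothetical helper (not part of gmlsnets_pytorch): choose a mapping unit
# that matches the kind of operator being regressed.

# Operator strings taken from this script; 'ux' is an assumed spelling.
LINEAR_OPS = {"ux", "uxx"}
NONLINEAR_OPS = {"u*ux", "ux*ux", "u*uxx", "uxx*uxx"}


def suggest_mlp_case(op_type):
    """Return 'Linear1' for linear operators and 'Nonlinear1' otherwise."""
    if op_type in LINEAR_OPS:
        return "Linear1"
    if op_type in NONLINEAR_OPS:
        return "Nonlinear1"
    raise ValueError("op_type not recognized: %r" % (op_type,))


for op in ("uxx", "u*ux"):
    print("op_type = %s -> flag_mlp_case = %s" % (op, suggest_mlp_case(op)))
```

# A linear mapping unit can only represent linear operators, so the non-linear cases need the MLP variant; the reverse pairing (an MLP for a linear operator) is still possible, just less economical.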
26 | # 27 | # These are organized for different combinations of these settings allowing for exploring the methods. The codes are easy to modify and adjust to also experiment with other operators. For example, see the dataset classes. 28 | # 29 | 30 | # ### Imports 31 | 32 | # In[1]: 33 | 34 | print("="*80); 35 | print("GMLS-Nets: 1D Regression of Linear and Non-linear Operators $L[u]$."); 36 | print("-"*80); 37 | 38 | import sys; 39 | 40 | # setup path to location of gmlsnets_pytorch (if not install system-wide) 41 | path_gmlsnets_pytorch = '../../'; 42 | sys.path.append(path_gmlsnets_pytorch); 43 | 44 | import torch; 45 | import torch.nn as nn; 46 | 47 | import numpy as np; 48 | import pickle; 49 | 50 | import matplotlib.pyplot as plt; 51 | 52 | import pdb 53 | import time 54 | 55 | import os 56 | 57 | # setup gmlsnets package 58 | import gmlsnets_pytorch as gmlsnets; 59 | import gmlsnets_pytorch.nn; 60 | import gmlsnets_pytorch.vis; 61 | import gmlsnets_pytorch.dataset; 62 | 63 | # dereference a few common items 64 | MapToPoly_Function = gmlsnets.nn.MapToPoly_Function; 65 | get_num_polys = MapToPoly_Function.get_num_polys; 66 | weight_one_minus_r = MapToPoly_Function.weight_one_minus_r; 67 | eval_poly = MapToPoly_Function.eval_poly; 68 | 69 | print("Packages:"); 70 | print("torch.__version__ = " + str(torch.__version__)); 71 | print("numpy.__version__ = " + str(np.__version__)); 72 | print("gmlsnets.__version__ = " + str(gmlsnets.__version__)); 73 | 74 | 75 | # ### Parameters and basic setup 76 | 77 | # In[2]: 78 | 79 | 80 | # Setup the parameters 81 | batch_size = int(1e2); 82 | flag_extend_periodic = False; # periodic boundaries 83 | flag_dataset = 'diffOp1'; 84 | run_name = '%s_Test1'%flag_dataset; 85 | base_dir = './output/regression_diff_op_1d/%s'%run_name; 86 | flag_print_model = False; 87 | 88 | print("Settings:"); 89 | print("flag_dataset = " + flag_dataset); 90 | print("run_name = " + run_name); 91 | print("base_dir = " + base_dir); 92 | 93 | if not 
os.path.exists(base_dir): 94 | os.makedirs(base_dir); 95 | 96 | # Configure devices 97 | if torch.cuda.is_available(): 98 | num_gpus = torch.cuda.device_count(); 99 | print("num_gpus = " + str(num_gpus)); 100 | if num_gpus >= 4: 101 | device = torch.device('cuda:3'); 102 | else: 103 | device = torch.device('cuda:0'); 104 | else: 105 | device = torch.device('cpu'); 106 | print("device = " + str(device)); 107 | 108 | 109 | # ### Setup GMLS-Net for regressing differential operator 110 | 111 | # In[3]: 112 | 113 | 114 | class gmlsNetRegressionDiffOp1(nn.Module): 115 | """Sets up a GMLS-Net for regression differential operator in 1D.""" 116 | 117 | def __init__(self, 118 | flag_GMLS_type=None, 119 | porder1=None,Nc=None, 120 | pts_x1=None,layer1_epsilon=None, 121 | weight_func1=None,weight_func1_params=None, 122 | mlp_q1=None,pts_x2=None, 123 | device=None,flag_verbose=0, 124 | **extra_params): 125 | 126 | super(gmlsNetRegressionDiffOp1, self).__init__(); 127 | 128 | self.layer_types = []; 129 | 130 | if device is None: 131 | device = torch.device('cpu'); # default 132 | 133 | # -- 134 | Ncp1 = mlp_q1.channels_out; # number of channels out of the MLP-Pointwise layer 135 | 136 | num_features1 = mlp_q1.channels_out; # number of channels out (16 typical) 137 | 138 | GMLS_Layer = gmlsnets.nn.GMLS_Layer; 139 | ExtractFromTuple = gmlsnets.nn.ExtractFromTuple; 140 | PermuteLayer = gmlsnets.nn.PermuteLayer; 141 | PdbSetTraceLayer = gmlsnets.nn.PdbSetTraceLayer; 142 | 143 | # --- Layer 1 144 | #flag_layer1 = 'standard_conv1'; 145 | flag_layer1 = 'gmls1d_1'; 146 | self.layer_types.append(flag_layer1); 147 | if flag_layer1 == 'standard_conv1': 148 | self.layer1 = nn.Sequential( 149 | nn.Conv1d(in_channels=Nc,out_channels=num_features1, 150 | kernel_size=5,stride=1,padding=2,bias=False), 151 | ).to(device); 152 | elif flag_layer1 == 'gmls1d_1': 153 | self.layer1 = nn.Sequential( 154 | GMLS_Layer(flag_GMLS_type, porder1, 155 | pts_x1, layer1_epsilon, 156 | weight_func1, 
weight_func1_params, 157 | mlp_q=mlp_q1, pts_x2=pts_x2, device=device, 158 | flag_verbose=flag_verbose), 159 | #PdbSetTraceLayer(), 160 | ExtractFromTuple(index=0), # just get the forward output associated with the mapping and not pts_x2 161 | #PdbSetTraceLayer(), 162 | PermuteLayer((0,2,1)) 163 | ).to(device); 164 | 165 | else: 166 | raise Exception('flag_layer1 type not recognized.'); 167 | 168 | def forward(self, x): 169 | out = self.layer1(x); 170 | return out; 171 | 172 | 173 | # ### Setup the Model: Neural Network 174 | 175 | # In[4]: 176 | 177 | 178 | # setup sample point locations 179 | xj = torch.linspace(0,1.0,steps=101,device=device).unsqueeze(1); 180 | xi = torch.linspace(0,1.0,steps=101,device=device).unsqueeze(1); 181 | 182 | # make a numpy copy for plotting and some other routines 183 | np_xj = xj.cpu().numpy(); np_xi = xi.cpu().numpy(); 184 | 185 | # setup parameters 186 | Nc = 1; # scalar field 187 | Nx = xj.shape[0]; num_dim = xj.shape[1]; 188 | porder = 2; num_polys = get_num_polys(porder,num_dim); 189 | 190 | weight_func1 = MapToPoly_Function.weight_one_minus_r; 191 | targ_kernel_width = 11.5; layer1_epsilon = 0.4*0.5*np.sqrt(2)*targ_kernel_width/Nx; 192 | #targ_kernel_width = 21.5; layer1_epsilon = 0.4*0.5*np.sqrt(2)*targ_kernel_width/Nx; 193 | weight_func1_params = {'epsilon': layer1_epsilon,'p':4}; 194 | 195 | color_input = (0.05,0.44,0.69); 196 | color_output = (0.44,0.30,0.60); 197 | color_predict = (0.05,0.40,0.5); 198 | color_target = (221/255,103/255,103/255); 199 | 200 | # print the current settings 201 | print("GMLS Parameters:") 202 | print("porder = " + str(porder)); 203 | print("num_dim = " + str(num_dim)); 204 | print("num_polys = " + str(num_polys)); 205 | print("layer1_epsilon = %.3e"%layer1_epsilon); 206 | print("weight_func1 = " + str(weight_func1)); 207 | print("weight_func1_params = " + str(weight_func1_params)); 208 | print("xj.shape = " + str(xj.shape)); 209 | print("xi.shape = " + str(xi.shape)); 210 | 211 | 212 | # In[5]: 
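# The kernel-width settings above are easier to judge once ``layer1_epsilon`` is related to the grid spacing. The sketch below recomputes it in plain Python and counts how many sample points fall inside the support around an interior point. The function ``weight_sketch`` is an assumption about the general shape of a "one minus r" weight, namely (1 - r/epsilon)^p inside the support and zero outside; it is not the package's ``weight_one_minus_r``.

```python
import math

# Values copied from the script above.
Nx = 101
targ_kernel_width = 11.5
p = 4
layer1_epsilon = 0.4 * 0.5 * math.sqrt(2) * targ_kernel_width / Nx


def weight_sketch(r, epsilon, p):
    """ASSUMED weight shape: (1 - r/epsilon)^p for r < epsilon, else 0."""
    if r >= epsilon:
        return 0.0
    return (1.0 - r / epsilon) ** p


# Uniform grid on [0, 1] with Nx points, like torch.linspace(0, 1.0, steps=101).
x = [i / (Nx - 1) for i in range(Nx)]
center = x[Nx // 2]

# Points whose distance to the center lies strictly inside the support.
neighbors = [xj for xj in x if abs(xj - center) < layer1_epsilon]
weights = [weight_sketch(abs(xj - center), layer1_epsilon, p) for xj in neighbors]

print("layer1_epsilon = %.3e" % layer1_epsilon)
print("grid points inside the support: %d" % len(neighbors))
```

# With these settings ``layer1_epsilon`` is roughly 0.032 while the grid spacing is 0.01, so the support around an interior point covers seven sample points. The commented-out ``targ_kernel_width = 21.5`` would widen that support accordingly.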
213 | 214 | 215 | # create an MLP for training the non-linear part of the GMLS Net 216 | #flag_mlp_case = 'Linear1';flag_mlp_case = 'Nonlinear1' 217 | flag_mlp_case = 'Nonlinear1'; 218 | if (flag_mlp_case == 'Linear1'): 219 | layer_sizes = []; 220 | 221 | num_depth = 0; # number of internal layers 222 | num_hidden = -1; # number of hidden per layer 223 | 224 | channels_in = Nc; # number of poly channels (matches input u channel size) 225 | channels_out = 1; # number of output filters 226 | 227 | layer_sizes.append(num_polys); # input 228 | layer_sizes.append(1); # output, single channel always, for vectors, we use channels_out separate units. 229 | 230 | mlp_q1 = gmlsnets.nn.MLP_Pointwise(layer_sizes,channels_in=channels_in,channels_out=channels_out, 231 | flag_bias=False).to(device); 232 | elif (flag_mlp_case == 'Nonlinear1'): 233 | layer_sizes = []; 234 | num_input = Nc*num_polys; # number of channels*num_polys, allows for cross-channel coupling 235 | num_depth = 4; # number of internal layers 236 | num_hidden = 100; # number of hidden per layer 237 | num_out_channels = 16; # number of output filters 238 | layer_sizes.append(num_polys); 239 | for k in range(num_depth): 240 | layer_sizes.append(num_hidden); 241 | layer_sizes.append(1); # output, single channel always, for vectors, we use channels_out separate units. 
242 | 243 | mlp_q1 = gmlsnets.nn.MLP_Pointwise(layer_sizes,channels_out=num_out_channels, 244 | flag_bias=True).to(device); 245 | 246 | if flag_print_model: 247 | print("mlp_q1:"); 248 | print(mlp_q1); 249 | 250 | 251 | # In[6]: 252 | 253 | 254 | # Setup the Neural Network for Regression 255 | flag_verbose = 0; 256 | flag_case = 'standard'; 257 | 258 | # Setup the model 259 | xi = xi.float(); 260 | xj = xj.float(); 261 | model = gmlsNetRegressionDiffOp1(flag_case,porder,Nc,xj,layer1_epsilon, 262 | weight_func1,weight_func1_params, 263 | mlp_q1=mlp_q1,pts_x2=xi, 264 | device=device, 265 | flag_verbose=flag_verbose); 266 | 267 | if flag_print_model: 268 | print("model:"); 269 | print(model); 270 | 271 | 272 | # ## Setup the training and test data 273 | 274 | # In[7]: 275 | 276 | 277 | ### Generate Dataset 278 | 279 | if flag_dataset == 'diffOp1': 280 | # Use the FFT to represent differential operators for training data sets. 281 | # 282 | # Setup a data set of the following: 283 | # To start let's do regression for the Laplacian (not inverse, just action of it, like finding FD) 284 | # 285 | 286 | #op_type = 'u*ux';op_type = 'ux*ux';op_type = 'uxx';op_type = 'u*uxx';op_type = 'uxx*uxx'; 287 | op_type = 'u*ux'; 288 | print("op_type = " + op_type); 289 | 290 | num_training_samples = int(5e4); 291 | nchannels = 1; 292 | nx = np_xj.shape[0]; 293 | #alpha1 = 0.05; 294 | alpha1 = 0.1; 295 | scale_factor = 1e2; 296 | train_dataset = gmlsnets.dataset.diffOp1(op_type=op_type,op_params=None, 297 | gen_mode='exp1',gen_params={'alpha1':alpha1}, 298 | num_samples=num_training_samples, 299 | nchannels=nchannels,nx=nx, 300 | noise_factor=0,scale_factor=scale_factor, 301 | flag_verbose=1); 302 | 303 | train_dataset = train_dataset.to(device); 304 | print("done."); 305 | 306 | num_test_samples = int(1e4); 307 | scale_factor = 1e2; 308 | test_dataset = gmlsnets.dataset.diffOp1(op_type=op_type,op_params=None, 309 | gen_mode='exp1',gen_params={'alpha1':alpha1}, 310 | 
num_samples=num_test_samples, 311 | nchannels=nchannels,nx=nx, 312 | noise_factor=0,scale_factor=scale_factor, 313 | flag_verbose=1); 314 | test_dataset = test_dataset.to(device); 315 | print("done."); 316 | 317 | # Put the data into the 318 | #train_dataset and test_dataset structures for processing 319 | 320 | else: 321 | msg = "flag_dataset not recognized."; 322 | msg += " flag_dataset = " + str(flag_dataset); 323 | raise Exception(msg); 324 | 325 | # Data loader 326 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True); 327 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False); 328 | 329 | 330 | # In[8]: 331 | 332 | 333 | #get_ipython().run_line_magic('matplotlib', 'inline') 334 | 335 | # plot sample of the training data 336 | gmlsnets.vis.plot_dataset_diffOp1(train_dataset,np_xj,np_xi,rows=4,cols=6, 337 | title="Data Samples: u, f=L[u], L = %s"%op_type); 338 | 339 | 340 | # ## Train the Model 341 | 342 | # ### Custom Functions 343 | 344 | # In[9]: 345 | 346 | 347 | def custom_loss_least_squares(val1,val2): 348 | r"""Computes the Mean-Square-Error (MSE) over the entire batch.""" 349 | diff_flat = (val1 - val2).flatten(); 350 | N = diff_flat.shape[0]; 351 | loss = torch.sum(torch.pow(diff_flat,2),-1)/N; 352 | return loss; 353 | 354 | def domain_periodic_repeat(Z): 355 | r"""Extends the input periodically by tiling three copies along the last axis.""" 356 | Z_periodic = torch.cat((Z, Z, Z), 2); 357 | return Z_periodic; 358 | 359 | def domain_periodic_extract(Z_periodic): 360 | r"""Extracts the middle unit cell portion of the extended data.""" 361 | n_cell = Z_periodic.shape[2]//3; # avoid shadowing the torch.nn alias 'nn' 362 | Z = Z_periodic[:,:,n_cell:2*n_cell]; 363 | return Z; 364 | 365 | 366 | # ### Initialize 367 | 368 | # In[10]: 369 | 370 | 371 | loss_list = np.empty(0); loss_step_list = np.empty(0); 372 | save_skip = 1; step_count = 0; 373 | 374 | 375 | # ### Train the network. 
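# Before training, the custom helpers defined above can be sanity-checked in isolation. The sketch below mirrors their logic on plain Python lists instead of torch tensors, so the names ``periodic_repeat``, ``periodic_extract`` and ``mse`` are illustrative stand-ins and not the functions used by the script. It shows that extracting the middle cell undoes the three-fold tiling and that the loss is the mean of the squared differences.

```python
def periodic_repeat(row):
    """Tile a 1D sequence three times (stand-in for torch.cat((Z, Z, Z), 2))."""
    return list(row) * 3


def periodic_extract(row_periodic):
    """Return the middle third of a tiled sequence, i.e. the original unit cell."""
    n = len(row_periodic) // 3
    return row_periodic[n:2 * n]


def mse(a, b):
    """Mean of squared differences over all entries."""
    diffs = [x - y for x, y in zip(a, b)]
    return sum(d * d for d in diffs) / len(diffs)


u = [0.0, 1.0, 4.0, 9.0]
u_tiled = periodic_repeat(u)        # three copies back to back
u_back = periodic_extract(u_tiled)  # middle copy only

loss = mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])  # one entry differs by 2

print("round trip recovers u:", u_back == u)
print("loss = %.4f" % loss)
```

# The same round-trip property is what lets the training loop pass a tiled input through the model and compare only the middle cell of the output with the target when ``flag_extend_periodic`` is enabled.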
376 | 377 | # In[11]: 378 | 379 | 380 | num_epochs = int(3e0); #int(1e4); 381 | learning_rate = 1e-2; 382 | 383 | print("Training the network with:"); 384 | print(""); 385 | print("model:"); 386 | print("model.layer_types = " + str(model.layer_types)); 387 | print(""); 388 | 389 | # setup the optimization method and loss function 390 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate); 391 | 392 | #loss_func = nn.CrossEntropyLoss(); 393 | #loss_func = nn.MSELoss(); 394 | loss_func = custom_loss_least_squares; 395 | 396 | print("num_epochs = %d"%num_epochs); 397 | print("batch_size = %d"%batch_size); 398 | print(" "); 399 | 400 | # Train the model 401 | flag_time_it = True; 402 | if flag_time_it: 403 | time_1 = time.time(); 404 | print("-"*80); 405 | num_steps = len(train_loader); 406 | for epoch in range(num_epochs): 407 | for i, (input,target) in enumerate(train_loader): 408 | input = input.to(device); 409 | target = target.to(device); 410 | 411 | if flag_extend_periodic: 412 | # Extend input periodically 413 | input_periodic = domain_periodic_repeat(input); 414 | 415 | # Forward pass 416 | output_periodic = model(input_periodic); 417 | output = domain_periodic_extract(output_periodic); 418 | else: 419 | output = model(input); 420 | 421 | # Compute loss 422 | loss = loss_func(output,target); 423 | 424 | # Display 425 | if step_count % save_skip == 0: 426 | np_loss = loss.cpu().detach().numpy(); 427 | loss_list = np.append(loss_list,np_loss); 428 | loss_step_list = np.append(loss_step_list,step_count); 429 | 430 | # Back-propagation for gradients and use to optimize 431 | optimizer.zero_grad(); 432 | loss.backward(); 433 | 434 | optimizer.step(); 435 | 436 | step_count += 1; 437 | 438 | if ((i + 1) % 100) == 0 or i == 0: 439 | msg = 'epoch: [%d/%d]; '%(epoch+1,num_epochs); 440 | msg += 'batch_step = [%d/%d]; '%(i + 1,num_steps); 441 | msg += 'loss_MSE: %.3e.'%(loss.item()); 442 | print(msg); 443 | 444 | if flag_time_it and i > 0: 445 | msg = 
'elapsed_time = %.4e secs \n'%(time.time() - time_1); 446 | print(msg); 447 | time_1 = time.time(); 448 | 449 | 450 | print("done training.") 451 | print("-"*80); 452 | 453 | 454 | # ### Plot Loss 455 | 456 | # In[12]: 457 | 458 | 459 | #get_ipython().run_line_magic('matplotlib', 'inline') 460 | 461 | plt.figure(figsize=(8,6)); 462 | 463 | plt.plot(loss_step_list,loss_list,'b-'); 464 | plt.yscale('log'); 465 | plt.xlabel('step'); 466 | plt.ylabel('loss'); 467 | 468 | plt.title('Loss'); 469 | 470 | 471 | # ### Test the Neural Network Predictions 472 | 473 | # In[13]: 474 | 475 | 476 | print("Testing predictions of the neural network:"); 477 | 478 | flag_save_tests = True; 479 | if flag_save_tests: 480 | test_data = {}; 481 | 482 | # Save the first few to show as examples of labeling 483 | saved_test_input = []; 484 | saved_test_target = []; 485 | saved_test_output_pred = []; 486 | 487 | count_batch = 0; 488 | with torch.no_grad(): 489 | total = 0; II = 0; 490 | avg_error = 0; 491 | for input,target in test_loader: # loads data in batches and then sums up 492 | 493 | if (II >= 1000): 494 | print("tested on %d samples"%total); 495 | II = 0; 496 | 497 | input = input.to(device); target = target.to(device); 498 | 499 | # Compute model 500 | flag_extend_periodic = False; 501 | if flag_extend_periodic: 502 | # Extend input periodically 503 | input_periodic = domain_periodic_repeat(input); 504 | 505 | # Forward pass 506 | output_periodic = model(input_periodic); 507 | output = domain_periodic_extract(output_periodic); 508 | else: 509 | output = model(input); 510 | 511 | # Compute loss 512 | loss = loss_func(output,target); 513 | 514 | # Record the results 515 | avg_error += loss; 516 | 517 | total += output.shape[0]; 518 | II += output.shape[0]; 519 | count_batch += 1; 520 | 521 | NN = output.shape[0]; 522 | for k in range(min(NN,20)): # save first 10 images of each batch 523 | saved_test_input.append(input[k]); 524 | saved_test_target.append(target[k]); 525 | 
saved_test_output_pred.append(output[k]); 526 | 527 | print(""); 528 | print("Tested on a total of %d samples."%total); 529 | print(""); 530 | 531 | # Compute RMSD error 532 | test_accuracy = avg_error.cpu()/count_batch; 533 | test_accuracy = np.sqrt(test_accuracy); 534 | 535 | print("The neural network has RMSD error %.2e on the %d test samples."%(test_accuracy,total)); 536 | print(""); 537 | 538 | 539 | # ### Show a Sample of the Predictions 540 | 541 | # In[14]: 542 | 543 | 544 | # collect a subset of the data to show and attach named labels 545 | #get_ipython().run_line_magic('matplotlib', 'inline') 546 | 547 | num_prediction_samples = len(saved_test_input); 548 | print("num_prediction_samples = " + str(num_prediction_samples)); 549 | 550 | #II = np.random.permutation(num_samples); # compute random collection of indices @optimize 551 | II = np.arange(num_prediction_samples); 552 | 553 | if flag_dataset == 'name-here' or 0 == 0: 554 | u_list = []; f_list = []; f_pred_list = []; 555 | for I in np.arange(0,min(num_prediction_samples,16)): 556 | u_list.append(saved_test_input[II[I]].cpu()); 557 | f_list.append(saved_test_target[II[I]].cpu()); 558 | f_pred_list.append(saved_test_output_pred[II[I]].cpu()); 559 | 560 | # plot predictions against test data 561 | gmlsnets.vis.plot_samples_u_f_fp_1d(u_list,f_list,f_pred_list,np_xj,np_xi,rows=4,cols=6, 562 | title="Test Samples and Predictions: u, f=L[u], L = %s"%op_type); 563 | 564 | 565 | # ### Save Model 566 | 567 | # In[15]: 568 | 569 | 570 | model_filename = '%s/model.ckpt'%base_dir; 571 | print("model_filename = " + model_filename); 572 | torch.save(model.state_dict(), model_filename); 573 | 574 | model_filename = "%s/model_state.pickle"%base_dir; 575 | print("model_filename = " + model_filename); 576 | f = open(model_filename,'wb'); 577 | pickle.dump(model.state_dict(),f); 578 | f.close(); 579 | 580 | 581 | # ### Display the GMLS-Nets Learned Parameters 582 | 583 | # In[16]: 584 | 585 | 586 | flag_run_cell = 
flag_print_model; 587 | 588 | if flag_run_cell: 589 | print("-"*80) 590 | print("model.parameters():"); 591 | ll = model.parameters(); 592 | for l in ll: 593 | print(l); 594 | 595 | if flag_run_cell: 596 | print("-"*80) 597 | print("model.state_dict():"); 598 | print(model.state_dict()); 599 | print("-"*80) 600 | 601 | 602 | # ### Done 603 | 604 | # In[ ]: 605 | 606 | print("="*80); 607 | 608 | -------------------------------------------------------------------------------- /examples/regression_diff_op_2d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # ## GMLS-Nets: 2D Regression of Linear and Non-linear Operators $L[u]$. 5 | # 6 | # __Ben J. Gross__, __Paul J. Atzberger__
7 | # http://atzberger.org/ 8 | # 9 | # Examples showing how GMLS-Nets can be used to perform regression for some basic linear and non-linear differential operators in 2D. 10 | # 11 | # __Parameters:__
12 | # The key parameter terms to adjust are:
13 | # ``op_type``: The operator type.
14 | # ``flag_mlp_case``: The type of mapping unit to use.
15 | # 16 | # __Examples of Non-linear Operators ($u\Delta{u}, u\nabla{u},\nabla{u}\cdot\nabla{u}$) :__
17 | # To run training for a non-linear operator like $u\nabla{u}$ using MLP for the non-linear GMLS mapping unit, you can use:
18 | # ``op_type=r'u\grad{u}';``
19 | # ``flag_mlp_case = 'Nonlinear1';``
20 | # You can obtain different performance by adjusting the mapping architecture and hyperparameters of the network. 21 | # 22 | # __Examples of linear Operators ($\nabla{u}, \Delta{u}$):__
23 | # To run training for a linear operator like the 2D Laplacian $\Delta{u}$ with a linear mapping unit, you can use
24 | # ``op_type=r'\Delta{u}';``
25 | # ``flag_mlp_case = 'Linear1';``
26 | # 27 | # These are organized for different combinations of these settings allowing for exploring the methods. The codes are easy to modify and adjust to also experiment with regressing other operators. For example, see the dataset classes. 28 | # 29 | 30 | # In[1]: 31 | 32 | print("="*80); 33 | print("GMLS-Nets: 2D Regression of Linear and Non-linear Operators $L[u]$."); 34 | print("-"*80); 35 | 36 | import sys; 37 | 38 | # setup path to location of gmlsnets_pytorch (if not installed system-wide) 39 | path_gmlsnets_pytorch = '../../'; 40 | sys.path.append(path_gmlsnets_pytorch); 41 | 42 | import torch; 43 | import torch.nn as nn; 44 | 45 | import numpy as np; 46 | import pickle; 47 | 48 | import matplotlib.pyplot as plt; 49 | 50 | import pdb 51 | import time 52 | 53 | import os 54 | 55 | # setup gmlsnets package 56 | import gmlsnets_pytorch as gmlsnets; 57 | import gmlsnets_pytorch.nn; 58 | import gmlsnets_pytorch.vis; 59 | import gmlsnets_pytorch.dataset; 60 | import gmlsnets_pytorch.util; 61 | 62 | # dereference a few common items 63 | MapToPoly_Function = gmlsnets.nn.MapToPoly_Function; 64 | get_num_polys = MapToPoly_Function.get_num_polys; 65 | weight_one_minus_r = MapToPoly_Function.weight_one_minus_r; 66 | eval_poly = MapToPoly_Function.eval_poly; 67 | 68 | print("Packages:"); 69 | print("torch.__version__ = " + str(torch.__version__)); 70 | print("numpy.__version__ = " + str(np.__version__)); 71 | print("gmlsnets.__version__ = " + str(gmlsnets.__version__)); 72 | 73 | 74 | # ### Parameters and basic setup 75 | 76 | # In[2]: 77 | 78 | 79 | # Setup the parameters 80 | batch_size = int(1e1); 81 | flag_extend_periodic = False; # periodic boundaries 82 | flag_dataset = 'diffOp2'; 83 | run_name = '%s_Test1'%flag_dataset; 84 | base_dir = './output/regression_diff_op_2d/%s'%run_name; 85 | flag_save_figs = True; 86 | fig_base_dir = '%s/fig'%base_dir; 87 | flag_print_model = False; 88 | 89 | print("Settings:"); 90 | print("flag_dataset = " + flag_dataset); 91 | 
print("run_name = " + run_name); 92 | print("base_dir = " + base_dir); 93 | 94 | if not os.path.exists(base_dir): 95 | os.makedirs(base_dir); 96 | 97 | if not os.path.exists(fig_base_dir): 98 | os.makedirs(fig_base_dir); 99 | 100 | # Configure devices 101 | if torch.cuda.is_available(): 102 | num_gpus = torch.cuda.device_count(); 103 | print("num_gpus = " + str(num_gpus)); 104 | if num_gpus >= 4: 105 | device = torch.device('cuda:3'); 106 | else: 107 | device = torch.device('cuda:0'); 108 | else: 109 | device = torch.device('cpu'); 110 | print("device = " + str(device)); 111 | 112 | 113 | # ## Setup the training and test data 114 | 115 | # ### Input Points 116 | 117 | # In[3]: 118 | 119 | 120 | Nx = 21;Ny = 21; 121 | nx = Nx; ny = Ny; 122 | 123 | NNx = 3*Nx; NNy = 3*Ny; # simple periodic by tiling for now 124 | aspect_ratio = NNx/float(NNy); 125 | xx = np.linspace(-1.5,aspect_ratio*1.5,NNx); xx = xx.astype(float); 126 | yy = np.linspace(-1.5,1.5,NNy); yy = yy.astype(float); 127 | 128 | aa = np.meshgrid(xx,yy); 129 | np_xj = np.array([aa[0].flatten(), aa[1].flatten()]).T; 130 | 131 | aa = np.meshgrid(xx,yy); 132 | np_xi = np.array([aa[0].flatten(), aa[1].flatten()]).T; 133 | 134 | # make torch tensors 135 | xj = torch.from_numpy(np_xj).float().to(device); # convert to torch tensors 136 | xj.requires_grad = False; 137 | 138 | xi = torch.from_numpy(np_xi).float().to(device); # convert to torch tensors 139 | xi.requires_grad = False; 140 | 141 | 142 | # ### Generate the Dataset 143 | 144 | # In[4]: 145 | 146 | 147 | if flag_dataset == 'diffOp2': 148 | 149 | flag_verbose = 1; 150 | #op_type = r'\Delta{u}';op_type = r'\Delta{u}*\Delta{u}'; 151 | #op_type = r'u\Delta{u}';op_type = r'\grad{u}';op_type = r'u\grad{u}'; 152 | #op_type = r'\grad{u}\cdot\grad{u}'; 153 | op_type = r'u\grad{u}'; 154 | print("op_type = " + op_type); 155 | 156 | num_dim = 2; 157 | num_training_samples = int(5e3); 158 | nchannels_u = 1; 159 | Nc = nchannels_u; 160 | 161 | #alpha1 = 0.05; 162 | 
alpha1 = 0.3; 163 | scale_factor = 1e2; 164 | train_dataset = gmlsnets.dataset.diffOp2(op_type=op_type,op_params=None, 165 | gen_mode='exp1',gen_params={'alpha1':alpha1}, 166 | num_samples=num_training_samples, 167 | nchannels=nchannels_u,nx=nx,ny=ny, 168 | noise_factor=0,scale_factor=scale_factor, 169 | flag_verbose=1); 170 | 171 | train_dataset = train_dataset.to(device); 172 | print("done."); 173 | 174 | num_test_samples = int(1e3); 175 | scale_factor = 1e2; 176 | test_dataset = gmlsnets.dataset.diffOp2(op_type=op_type,op_params=None, 177 | gen_mode='exp1',gen_params={'alpha1':alpha1}, 178 | num_samples=num_test_samples, 179 | nchannels=nchannels_u,nx=nx,ny=ny, 180 | noise_factor=0,scale_factor=scale_factor, 181 | flag_verbose=1); 182 | test_dataset = test_dataset.to(device); 183 | print("done."); 184 | 185 | else: 186 | msg = "flag_dataset not recognized."; 187 | msg += "flag_dataset = " + str(flag_dataset); 188 | raise Exception(msg); 189 | 190 | # Data loader 191 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True); 192 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False); 193 | 194 | # Cound number of output channels in f, determines if scalar or vector valued data 195 | nchannels_f = train_dataset.samples_Y.shape[1]; 196 | 197 | 198 | # ## Show Data 199 | 200 | # In[5]: 201 | 202 | 203 | flag_run_cell = True; 204 | if flag_run_cell: 205 | # Show subset of the data 206 | img_arr = []; 207 | label_str_arr = []; 208 | 209 | numImages = len(train_dataset); 210 | #II = np.random.permutation(numImages); # compute random collection of indices @optimize 211 | II = np.arange(numImages); 212 | 213 | if flag_dataset == '' or 0 == 0: 214 | img_arr = []; 215 | channelI = 0; # for vector-valued fields, choose component. 
216 | for I in np.arange(0,min(num_training_samples,16)): 217 | img_arr.append(train_dataset[II[I],channelI,:][0].cpu()); 218 | gmlsnets.vis.plot_image_array(img_arr,title=r'$u^{[i]}$ Samples',figSize=(6,6),title_yp=0.95); 219 | 220 | if flag_save_figs: 221 | fig_name = 'samples_u'; 222 | gmlsnets.vis.save_fig('%s/%s'%(fig_base_dir,fig_name),'',flag_verbose=True,dpi_set=200,flag_pdf=True); 223 | 224 | img_arr = []; 225 | channelI = 0; 226 | for I in np.arange(0,min(num_training_samples,16)): 227 | img_arr.append(train_dataset[II[I],channelI,:][1].cpu()); 228 | gmlsnets.vis.plot_image_array(img_arr,title=r'$f^{[i]}$ Samples',figSize=(6,6),title_yp=0.95); 229 | 230 | if flag_save_figs: 231 | fig_name = 'samples_f'; 232 | gmlsnets.vis.save_fig('%s/%s'%(fig_base_dir,fig_name),'',flag_verbose=True,dpi_set=200,flag_pdf=True); 233 | 234 | 235 | 236 | 237 | # ### Side-by-Side View 238 | 239 | # In[6]: 240 | 241 | 242 | #get_ipython().run_line_magic('matplotlib', 'inline') 243 | 244 | for Ic_f in range(0,nchannels_f): 245 | # plot sample of the training data 246 | gmlsnets.vis.plot_dataset_diffOp2(train_dataset,np_xj,np_xi,rows=4,cols=6,channelI_f=Ic_f, 247 | title="Data Samples: u, f=L[u], L = %s, Ic_f = %d"%(op_type,Ic_f)); 248 | 249 | if flag_save_figs: 250 | fig_name = 'samples_u_f_Ic_f_%d'%Ic_f; 251 | gmlsnets.vis.save_fig('%s/%s'%(fig_base_dir,fig_name),'',flag_verbose=True,dpi_set=200,flag_pdf=True); 252 | 253 | 254 | # ### Setup GMLS-Net for regressing differential operator 255 | 256 | # In[7]: 257 | 258 | 259 | class gmlsNetRegressionDiffOp2(nn.Module): 260 | """Sets up a GMLS-Net for regression differential operator in 2D.""" 261 | def __init__(self, 262 | flag_GMLS_type=None, 263 | porder1=None,Nx=None,Ny=None,Nc=None,pts_x1=None,layer1_epsilon=None, 264 | weight_func1=None,weight_func1_params=None, 265 | mlp_q1=None,pts_x2=None, 266 | device=None,flag_verbose=0, 267 | **extra_params): 268 | 269 | super(gmlsNetRegressionDiffOp2, self).__init__(); 270 | 271 | # 
setup the layers 272 | self.layer_types = []; 273 | 274 | if device is None: 275 | device = torch.device('cpu'); # default 276 | 277 | # -- 278 | Ncp1 = mlp_q1.channels_out; # number of channels out of the MLP-Pointwise layer 279 | 280 | num_features1 = mlp_q1.channels_out; # number of channels out 281 | 282 | GMLS_Layer = gmlsnets.nn.GMLS_Layer; 283 | ExtractFromTuple = gmlsnets.nn.ExtractFromTuple; 284 | PermuteLayer = gmlsnets.nn.PermuteLayer; 285 | ReshapeLayer = gmlsnets.nn.ReshapeLayer; 286 | PdbSetTraceLayer = gmlsnets.nn.PdbSetTraceLayer; 287 | 288 | PeriodicPad2d = gmlsnets.util.PeriodicPad2d; 289 | ExtractUnitCell2d = gmlsnets.util.ExtractUnitCell2d; 290 | 291 | # --- Layer 1 292 | #flag_layer1 = 'standard_conv1'; 293 | flag_layer1 = 'gmls2d_1'; 294 | self.layer_types.append(flag_layer1); 295 | if flag_layer1 == 'standard_conv1': 296 | self.layer1 = nn.Sequential( 297 | PeriodicPad2d(), # expands width by 3 298 | nn.Conv2d(in_channels=Nc,out_channels=num_features1, 299 | kernel_size=5,stride=1,padding=2,bias=False), 300 | ExtractUnitCell2d() 301 | ).to(device); 302 | elif flag_layer1 == 'gmls2d_1': 303 | NNx = 3*Nx; NNy = 3*Ny; # periodic extensions 304 | reshape_ucell_to_GMLS_data = (-1,Nc,NNx*NNy); # map from assumed (batch,Nc,Nx,Ny) --> (batch,Nc,Nx*Ny) 305 | permute_ucell_to_GMLS_data = None; # identity map 306 | 307 | permute_pre_GMLS_data_to_ucell = (0,2,1); # maps (batch,Nx*Ny,Ncp) --> (batch,Ncp,Nx*Ny) 308 | 309 | reshape_GMLS_data_to_ucell = (-1,Ncp1,NNx,NNy); # maps to (batch,Ncp,Nx,Ny), Ncp new channels 310 | permute_GMLS_data_to_ucell = None; # identity map 311 | 312 | self.layer1 = nn.Sequential( 313 | PeriodicPad2d(), # expands 2D unit cell perodically with width by x3 314 | ReshapeLayer(reshape_ucell_to_GMLS_data, permute_ucell_to_GMLS_data), # reshape for GMLS-Net 315 | #PdbSetTraceLayer(), 316 | GMLS_Layer(flag_GMLS_type, porder1, 317 | pts_x1, layer1_epsilon, 318 | weight_func1, weight_func1_params, 319 | mlp_q=mlp_q1, pts_x2=pts_x2, 
device=device, 320 | flag_verbose=flag_verbose), 321 | #PdbSetTraceLayer(), 322 | ExtractFromTuple(index=0), # just get the forward output associated with the mapping and not pts_x2 323 | PermuteLayer(permute_pre_GMLS_data_to_ucell), 324 | ReshapeLayer(reshape_GMLS_data_to_ucell, permute_GMLS_data_to_ucell), # reshape the data out of GMLS 325 | #PdbSetTraceLayer(), 326 | ExtractUnitCell2d() 327 | ).to(device); 328 | 329 | else: 330 | raise Exception('The flag_layer1 type was not recognized.'); 331 | 332 | def forward(self, x): 333 | out = self.layer1(x); 334 | return out; 335 | 336 | 337 | # ### Setup the Model: Neural Network 338 | 339 | # In[8]: 340 | 341 | 342 | # setup parameters 343 | porder = 2; num_polys = get_num_polys(porder,num_dim); 344 | 345 | weight_func1 = MapToPoly_Function.weight_one_minus_r; 346 | targ_kernel_width = 1.5; layer1_epsilon = 1.1*0.5*np.sqrt(2)*targ_kernel_width/Nx; 347 | weight_func1_params = {'epsilon': layer1_epsilon,'p':4}; 348 | 349 | #color_input = (0.05,0.44,0.69); 350 | #color_output = (0.44,0.30,0.60); 351 | #color_predict = (0.05,0.40,0.5); 352 | #color_target = (221/255,103/255,103/255); 353 | 354 | # print the current settings 355 | print("GMLS Parameters:") 356 | print("porder = " + str(porder)); 357 | print("num_polys = " + str(num_polys)); 358 | print("layer1_epsilon = %.3e"%layer1_epsilon); 359 | print("weight_func1 = " + str(weight_func1)); 360 | print("weight_func1_params = " + str(weight_func1_params)); 361 | print("num_dim = " + str(num_dim)); 362 | print("Nx = %d, Ny = %d, Nc = %d"%(Nx,Ny,Nc)); 363 | print("xj.shape = " + str(xj.shape)); 364 | print("xi.shape = " + str(xi.shape)); 365 | 366 | 367 | # In[9]: 368 | 369 | 370 | # create an MLP for training the non-linear part of the GMLS Net 371 | num_polys = get_num_polys(porder,num_dim); 372 | 373 | print("num_dim = " + str(num_dim)); 374 | print("num_polys = " + str(num_polys)); 375 | print("Nx = " + str(Nx) + " Ny = " + str(Ny)); 376 | print("NNx = " + str(NNx) + 
" NNy = " + str(NNy)); 377 | print("nchannels_f = " + str(nchannels_f)); 378 | 379 | print(""); 380 | 381 | #flag_mlp_case = 'Linear1';flag_mlp_case = 'Nonlinear1'; 382 | flag_mlp_case = 'Nonlinear1'; 383 | 384 | if (flag_mlp_case == 'Linear1'): 385 | 386 | # -- Layer 1 387 | layer_sizes = []; 388 | 389 | num_depth = 0; # number of internal depth 390 | num_hidden = -1; 391 | 392 | channels_in = Nc; # number of poly channels for input 393 | #channels_out = 16; # number of output filters [might want to decouple by using separate weights here] 394 | channels_out = nchannels_f; # number of output filters, (scalar=1, vector=2,3,...) 395 | Nc2 = channels_out; 396 | 397 | layer_sizes.append(num_polys); 398 | for k in range(num_depth): 399 | layer_sizes.append(num_hidden); 400 | layer_sizes.append(1); # for single unit always give scalar output, we then use channels_out units. 401 | 402 | mlp_q1 = gmlsnets.nn.MLP_Pointwise(layer_sizes, 403 | channels_in=channels_in, 404 | channels_out=channels_out, 405 | flag_bias=False); 406 | mlp_q1.to(device); 407 | 408 | elif (flag_mlp_case == 'Nonlinear1'): 409 | layer_sizes = []; 410 | num_input = Nc*num_polys; # number of channels*num_polys (cross-channel coupling allowed) 411 | #num_depth = 4; # number of internal depth 412 | num_depth = 1; # number of internal depth 413 | num_hidden = 500; 414 | num_out_channels = nchannels_f; # number of output filters, (scalar=1, vector=2,3,...) 415 | layer_sizes.append(num_polys); 416 | for k in range(num_depth): 417 | layer_sizes.append(num_hidden); 418 | layer_sizes.append(1); # for single unit always give scalar output, we then use channels_out units. 
419 | #layer_sizes.append(num_out_channels); 420 | 421 | #mlp_q = gmlsnets.nn.atzGMLS_MLP1_Module(layer_sizes); 422 | mlp_q1 = gmlsnets.nn.MLP_Pointwise(layer_sizes,channels_out=num_out_channels); 423 | mlp_q1.to(device); 424 | 425 | if flag_print_model: 426 | print("mlp_q1:"); 427 | print(mlp_q1); 428 | 429 | 430 | # In[10]: 431 | 432 | 433 | # Setup the Neural Network for Regression 434 | flag_verbose = 0; 435 | flag_case = 'standard'; 436 | 437 | # Setup the model 438 | model = gmlsNetRegressionDiffOp2(flag_case,porder,Nx,Ny,Nc,xj,layer1_epsilon, 439 | weight_func1,weight_func1_params, 440 | mlp_q1=mlp_q1,pts_x2=xi, 441 | device=device, 442 | flag_verbose=flag_verbose); 443 | 444 | if flag_print_model: 445 | print("model:"); 446 | print(model); 447 | 448 | 449 | # ## Train the Model 450 | 451 | # ### Custom Functions 452 | 453 | # In[11]: 454 | 455 | 456 | def custom_loss_least_squares(val1,val2): 457 | r"""Computes the Mean-Square-Error (MSE) over the entire batch.""" 458 | diff_flat = (val1 - val2).flatten(); 459 | N = diff_flat.shape[0]; 460 | loss = torch.sum(torch.pow(diff_flat,2),-1)/N; 461 | return loss; 462 | 463 | 464 | # ### Initialize 465 | 466 | # In[12]: 467 | 468 | 469 | # setup one-time data structures 470 | loss_list = np.empty(0); loss_step_list = np.empty(0); 471 | save_skip = 1; step_count = 0; 472 | 473 | 474 | # ### Train the network. 
475 | 476 | # In[13]: 477 | 478 | 479 | num_epochs = int(2e0); 480 | learning_rate = 1e-1; 481 | 482 | print("Training the model with:"); 483 | 484 | print("model:"); 485 | print("model.layer_types = " + str(model.layer_types)); 486 | print(""); 487 | 488 | # setup the optimization method and loss function 489 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate); 490 | 491 | loss_func = custom_loss_least_squares; 492 | 493 | print("num_epochs = %d"%num_epochs); 494 | print("batch_size = %d"%batch_size); 495 | print(" "); 496 | 497 | # Train the model 498 | flag_time_it = True; 499 | if flag_time_it: 500 | time_1 = time.time(); 501 | 502 | numSteps = len(train_loader); 503 | for epoch in range(num_epochs): 504 | for i, (input,target) in enumerate(train_loader): 505 | #if 1 == 1: 506 | input = input.to(device); 507 | target = target.to(device); 508 | 509 | # Compute model 510 | output = model(input); 511 | 512 | # Compute loss 513 | loss = loss_func(output,target); 514 | 515 | # Display 516 | if step_count % save_skip == 0: 517 | np_loss = loss.cpu().detach().numpy(); 518 | loss_list = np.append(loss_list,np_loss); 519 | loss_step_list = np.append(loss_step_list,step_count); 520 | 521 | # Backward and optimize 522 | optimizer.zero_grad(); 523 | loss.backward(); 524 | 525 | #for p in model.parameters(): 526 | # print(p.grad); 527 | 528 | optimizer.step(); 529 | 530 | step_count += 1; 531 | 532 | if ((i + 1) % 100) == 0 or i == 0: 533 | #if ((step_count + 1) % int(1e3)) == 0 or step_count == 0: 534 | #print("WARNING: Debug mode..."); 535 | msg = 'epoch: [%d/%d]; '%(epoch+1,num_epochs); 536 | msg += 'batch_step = [%d/%d]; '%(i + 1,numSteps); 537 | msg += 'loss_MSE: %.3e.'%(loss.item()); 538 | print(msg); 539 | 540 | if flag_time_it and i > 0: 541 | msg = 'elapsed_time = %.1e secs \n'%(time.time() - time_1); 542 | print(msg); 543 | time_1 = time.time(); 544 | 545 | print("done."); 546 | 547 | 548 | # ### Plot Loss 549 | 550 | # In[14]: 551 | 552 | 553 | 
#get_ipython().run_line_magic('matplotlib', 'inline') 554 | 555 | plt.figure(figsize=(8,6)); 556 | 557 | plt.plot(loss_step_list,loss_list,'b-'); 558 | plt.yscale('log'); 559 | plt.xlabel('step'); 560 | plt.ylabel('loss'); 561 | 562 | plt.title('Loss'); 563 | 564 | if flag_save_figs: 565 | fig_name = 'training_loss'; 566 | gmlsnets.vis.save_fig('%s/%s'%(fig_base_dir,fig_name),'',flag_verbose=True,dpi_set=200,flag_pdf=True); 567 | 568 | 569 | # ### Test the Neural Network Predictions 570 | 571 | # In[15]: 572 | 573 | 574 | print("Testing predictions of the neural network:"); 575 | 576 | flag_save_tests = True; 577 | if flag_save_tests: 578 | test_data = {}; 579 | 580 | # Save the first few to show as examples of labeling 581 | saved_test_input = []; 582 | saved_test_target = []; 583 | saved_test_output_pred = []; 584 | 585 | count_batch = 0; 586 | with torch.no_grad(): 587 | total = 0; II = 0; 588 | avg_error = 0; 589 | for input,target in test_loader: # loads data in batches and then sums up 590 | 591 | if (II >= 1000): 592 | print("tested on %d samples"%total); 593 | II = 0; 594 | 595 | input = input.to(device); target = target.to(device); 596 | 597 | # Compute model 598 | output = model(input); 599 | 600 | # Compute loss 601 | loss = loss_func(output,target); 602 | 603 | # Record the results 604 | avg_error += loss; 605 | 606 | total += output.shape[0]; 607 | II += output.shape[0]; 608 | count_batch += 1; 609 | 610 | NN = output.shape[0]; 611 | for k in range(min(NN,20)): # save up to the first 20 samples of each batch 612 | saved_test_input.append(input[k]); 613 | saved_test_target.append(target[k]); 614 | saved_test_output_pred.append(output[k]); 615 | 616 | print(""); 617 | print("Tested on a total of %d samples."%total); 618 | print(""); 619 | 620 | # Compute RMSD error (square root of the batch-averaged MSE) 621 | test_accuracy = avg_error.cpu()/count_batch; 622 | test_accuracy = np.sqrt(test_accuracy); 623 | 624 | print("The neural network has RMSD error %.2e on the %d test samples."%(test_accuracy,total)); 
625 | print(""); 626 | 627 | 628 | # ### Show a Sample of the Predictions 629 | 630 | # In[16]: 631 | 632 | 633 | # collect a subset of the data to show and attach named labels 634 | #get_ipython().run_line_magic('matplotlib', 'inline') 635 | 636 | num_prediction_samples = len(saved_test_input); 637 | print("num_prediction_samples = " + str(num_prediction_samples)); 638 | 639 | #II = np.random.permutation(num_samples); # compute random collection of indices @optimize 640 | II = np.arange(num_prediction_samples); 641 | 642 | if flag_dataset == 'name-here' or 0 == 0: 643 | u_list = []; f_target_list = []; f_pred_list = []; 644 | for I in np.arange(0,min(num_prediction_samples,16)): 645 | u_list.append(saved_test_input[II[I]].cpu()); 646 | f_target_list.append(saved_test_target[II[I]].cpu()); 647 | f_pred_list.append(saved_test_output_pred[II[I]].cpu()); 648 | 649 | # plot predictions against test data 650 | for Ic_f in range(0,nchannels_f): 651 | title = "Test Samples and Predictions: u, f=L[u], L = %s, Ic_f = %d"%(op_type,Ic_f); 652 | gmlsnets.vis.plot_samples_u_f_fp_2d(u_list,f_target_list,f_pred_list,np_xj,np_xi, 653 | channelI_f=Ic_f,rows=4,cols=6, 654 | title=title); 655 | 656 | if flag_save_figs: 657 | fig_name = 'predictions_Ic_f_%d'%Ic_f; 658 | gmlsnets.vis.save_fig('%s/%s'%(fig_base_dir,fig_name),'',flag_verbose=True,dpi_set=200,flag_pdf=True); 659 | 660 | 661 | # ### Save Model 662 | 663 | # In[17]: 664 | 665 | 666 | model_filename = '%s/model.ckpt'%base_dir; 667 | print("model_filename = " + model_filename); 668 | torch.save(model.state_dict(), model_filename); 669 | 670 | model_filename = "%s/model_state.pickle"%base_dir; 671 | print("model_filename = " + model_filename); 672 | f = open(model_filename,'wb'); 673 | pickle.dump(model.state_dict(),f); 674 | f.close(); 675 | 676 | 677 | # ### Display the GMLS-Nets Learned Parameters 678 | 679 | # In[18]: 680 | 681 | 682 | flag_run_cell = flag_print_model; 683 | 684 | if flag_run_cell: 685 | print("-"*80) 686 
| print("model.parameters():"); 687 | ll = model.parameters(); 688 | for l in ll: 689 | print(l); 690 | 691 | if flag_run_cell: 692 | print("-"*80) 693 | print("model.state_dict():"); 694 | print(model.state_dict()); 695 | print("-"*80) 696 | 697 | 698 | # ### Done 699 | 700 | # In[ ]: 701 | 702 | 703 | print("="*80); 704 | 705 | -------------------------------------------------------------------------------- /examples/z_doc_img/diff_op_1d_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atzberg/gmls-nets/aaa036b9ea9fcdbcf7b5892fd55886dbd3b61915/examples/z_doc_img/diff_op_1d_1.png -------------------------------------------------------------------------------- /examples/z_doc_img/diff_op_2d_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atzberg/gmls-nets/aaa036b9ea9fcdbcf7b5892fd55886dbd3b61915/examples/z_doc_img/diff_op_2d_1.png -------------------------------------------------------------------------------- /gmlsnets_pytorch-1.0.0.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atzberg/gmls-nets/aaa036b9ea9fcdbcf7b5892fd55886dbd3b61915/gmlsnets_pytorch-1.0.0.tar.gz -------------------------------------------------------------------------------- /misc/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atzberg/gmls-nets/aaa036b9ea9fcdbcf7b5892fd55886dbd3b61915/misc/overview.png -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # global initialization 2 | name="gmlsnets_pytorch"; # package name 3 | __version__="1.0.0"; # package version 4 | 5 | 6 | -------------------------------------------------------------------------------- 
/src/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Collection of codes for generating some training data sets. 3 | """ 4 | 5 | # Authors: B.J. Gross and P.J. Atzberger 6 | # Website: http://atzberger.org/ 7 | 8 | import torch; 9 | import numpy as np; 10 | import pdb; 11 | 12 | class diffOp1(torch.utils.data.Dataset): 13 | r""" 14 | Generates samples of the form :math:`(u^{[i]},f^{[i]})` where :math:`f^{[i]} = L[u^{[i]}]`, 15 | where :math:`i` denotes the index of the sample. 16 | 17 | Stores data samples in the form :math:`(u,f)`. 18 | 19 | The samples of u are represented as a tensor of size [nsamples,nchannels,nx] 20 | and samples of f as a tensor of size [nsamples,nchannels,nx]. 21 | 22 | Note: 23 | For now, please use nx that is odd. In this initial implementation, we use a 24 | method based on conjugated flips with formula for the odd case which is slightly 25 | simpler than the even case. 26 | 27 | """ 28 | def flipForFFT(self,u_k_part): 29 | r"""We flip as :math:`f_k = f_{N-k}`. Notice that only :math:`0,\ldots,N-1` entries 30 | are stored. This is useful for constructing real-valued function representations 31 | from random coefficients. Real-valued function requires :math:`conj(f_k) = f_{N-k}`. 32 | We can use this flip to construct from random coefficients the term 33 | :math:`u_k = f_k + conj(flip(f_k))`, then above constraint is satisfied. 34 | 35 | Args: 36 | u_k_part (ndarray): array of size [nsamples,nchannels,nx] to flip along the last index. 37 | 38 | Returns: 39 | ndarray: The flipped array, used to build coefficients symmetric under conjugation. 40 | """ 41 | nx = self.nx; 42 | uu = u_k_part[:,:,nx:0:-1]; 43 | vv = u_k_part[:,:,0]; 44 | vv = np.expand_dims(vv,2); 45 | uu_k_flip = np.concatenate([vv,uu],2); 46 | 47 | return uu_k_flip; 48 | 49 | def getComplex(self,a,b): 50 | j = 1j; # imaginary unit; the np.complex alias was removed in NumPy 1.24. 
51 | c = a + j*b; 52 | return c; 53 | 54 | def getRealImag(self,c): 55 | a = np.real(c); 56 | b = np.imag(c); 57 | return a,b; 58 | 59 | def computeLSymbol_ux(self): 60 | r"""Compute associated Fourier symbols for use under DFT for the operator L[u].""" 61 | nx = self.nx; 62 | vec_k1 = torch.zeros(nx); 63 | vec_k1_pp = torch.zeros(nx); 64 | vec_k_sq = torch.zeros(nx); 65 | L_symbol_real = torch.zeros(nx,dtype=torch.float32); 66 | L_symbol_imag = torch.zeros(nx,dtype=torch.float32); 67 | two_pi = 2.0*np.pi; 68 | #two_pi_i = two_pi*1j; # $2\pi{i}$, 1j = sqrt(-1) 69 | for i in range(0,nx): 70 | vec_k1[i] = i; 71 | if (vec_k1[i] < nx/2): 72 | vec_k1_p = vec_k1[i]; 73 | else: 74 | vec_k1_p = vec_k1[i] - nx; 75 | vec_k1_pp[i] = vec_k1_p; 76 | L_symbol_real[i] = 0.0; 77 | L_symbol_imag[i] = two_pi*vec_k1_p; 78 | 79 | L_hat = self.getComplex(L_symbol_real.numpy(),L_symbol_imag.numpy()); 80 | 81 | return L_hat, vec_k1_pp; 82 | 83 | def computeLSymbol_uxx(self): 84 | r"""Compute associated Fourier symbols for use under DFT for the operator L[u].""" 85 | nx = self.nx; 86 | vec_k1 = torch.zeros(nx); 87 | vec_k1_pp = torch.zeros(nx); 88 | vec_k_sq = torch.zeros(nx); 89 | L_symbol_real = torch.zeros(nx,dtype=torch.float32); 90 | L_symbol_imag = torch.zeros(nx,dtype=torch.float32); 91 | neg_four_pi_sq = -4.0*np.pi*np.pi; 92 | for i in range(0,nx): 93 | vec_k1[i] = i; 94 | vec_k_sq[i] = vec_k1[i]*vec_k1[i]; 95 | if (vec_k1[i] < nx/2): 96 | vec_k1_p = vec_k1[i]; 97 | else: 98 | vec_k1_p = vec_k1[i] - nx; 99 | vec_k1_pp[i] = vec_k1_p; 100 | vec_k_p_sq = vec_k1_p*vec_k1_p; 101 | L_symbol_real[i] = neg_four_pi_sq*vec_k_p_sq; 102 | L_symbol_imag[i] = 0.0; 103 | 104 | L_hat = self.getComplex(L_symbol_real.numpy(),L_symbol_imag.numpy()); 105 | 106 | return L_hat, vec_k1_pp; 107 | 108 | def computeCoeffActionL(self,u_hat,L_hat): 109 | r"""Computes the action of operator L used for data generation in Fourier space.""" 110 | u_k_real, u_k_imag = self.getRealImag(u_hat); 111 | L_symbol_real, 
L_symbol_imag = self.getRealImag(L_hat); 112 | 113 | f_k_real = L_symbol_real*u_k_real - L_symbol_imag*u_k_imag; #broadcast will distr over copies of u. 114 | f_k_imag = L_symbol_real*u_k_imag + L_symbol_imag*u_k_real; 115 | 116 | # Generate samples u and f using ifft 117 | f_hat = self.getComplex(f_k_real,f_k_imag); 118 | 119 | return f_hat; 120 | 121 | def computeActionL(self,u,L_hat): 122 | r"""Computes the action of operator L used for data generation.""" 123 | raise Exception('Currently this routine not debugged, need to test first.') 124 | 125 | if flag_verbose > 0: 126 | print("computeActionL(): WARNING: Not yet fully tested."); 127 | 128 | # perform FFT to get u_hat 129 | u_hat = np.fft.fft(u); 130 | 131 | # compute action of L_hat 132 | f_hat = self.computeCoeffActionL(u_hat,L_hat); 133 | 134 | # compute inverse FFT to get f 135 | f = np.fft.ifft(f_hat); 136 | 137 | return f; 138 | 139 | def __init__(self,op_type='uxx',op_params=None, 140 | gen_mode='exp1',gen_params={'alpha1':0.1}, 141 | num_samples=int(1e4),nchannels=1,nx=15, 142 | flag_verbose=0, **extra_params): 143 | r"""Setup for data generation. 144 | 145 | Args: 146 | op_type (str): The differential operator to sample. 147 | op_params (dict): The operator parameters. 148 | gen_mode (str): The mode for the data generator. 149 | gen_params (dict): The parameters for the given generator. 150 | num_samples (int): The number of samples to generate. 151 | nchannels (int): The number of channels. 152 | nx (int): The number of input sample points. 153 | flag_verbose (int): Level of reporting during calculations. 154 | extra_params (dict): Extra parameters for the sampler. 155 | 156 | For extra_params we have: 157 | noise_factor (float): The amount of noise to add to samples. 158 | scale_factor (float): A factor to scale magnitude of the samples. 159 | flagComputeL (bool): If the fourier symbol of operator should be computed. 
160 | 161 | For generator modes we have: 162 | gen_mode == 'exp1': 163 | alpha1 (float): The decay rate. 164 | 165 | Note: 166 | For now, please use only nx that is odd. In this initial implementation, we use a 167 | method based on conjugated flips with formula for the odd case which is slightly 168 | simpler than the even case. 169 | """ 170 | super(diffOp1, self).__init__(); 171 | 172 | if flag_verbose > 0: 173 | print("Generating the data samples which can take some time."); 174 | print("num_samples = %d"%num_samples); 175 | 176 | self.op_type=op_type; 177 | self.op_params=op_params; 178 | 179 | self.gen_mode=gen_mode; 180 | self.gen_params=gen_params; 181 | 182 | self.num_samples=num_samples; 183 | self.nchannels=nchannels; 184 | self.nx=nx; 185 | 186 | if (nx % 2 == 0): 187 | msg = "Not allowed yet to use nx that is even. "; 188 | msg += "For now, please just use nx that is odd given the flips currently used." 189 | raise Exception(msg); 190 | 191 | noise_factor=0;scale_factor=1.0;flagComputeL=False; # default values 192 | if 'noise_factor' in extra_params: 193 | noise_factor = extra_params['noise_factor']; 194 | 195 | if 'scale_factor' in extra_params: 196 | scale_factor = extra_params['scale_factor']; 197 | 198 | if 'flagComputeL' in extra_params: 199 | flagComputeL = extra_params['flagComputeL']; 200 | 201 | # Generate for the operator the Fourier symbols 202 | if self.op_type == 'ux' or self.op_type == 'u*ux' or self.op_type == 'ux*ux': 203 | L_hat, vec_k1_pp = self.computeLSymbol_ux(); 204 | elif self.op_type == 'uxx' or self.op_type == 'u*uxx' or self.op_type == 'uxx*uxx': 205 | L_hat, vec_k1_pp = self.computeLSymbol_uxx(); 206 | else: 207 | raise Exception("Unknown operator type."); 208 | 209 | if (flagComputeL): 210 | L_i = np.fft.ifft(L_hat); 211 | self.L_hat = L_hat; 212 | self.L_i = L_i; 213 | u = np.zeros(nx); 214 | i0 = int(nx/2); 215 | u[i0] = 1.0; 216 | self.G_i = self.computeActionL(u,L_hat); 217 | 218 | # Generate random input function (want 
real-valued) 219 | # conj(u_k) = u_{N -k} needs to hold. 220 | u_k_real = np.random.randn(num_samples,nchannels,nx); 221 | u_k_imag = np.random.randn(num_samples,nchannels,nx); 222 | 223 | # scale modes to make smooth 224 | if gen_mode=='exp1': 225 | alpha1 = gen_params['alpha1']; 226 | factor_k = scale_factor*np.exp(-alpha1*vec_k1_pp**2); 227 | factor_k = factor_k.numpy(); 228 | else: 229 | raise Exception("Generation mode not recognized."); 230 | 231 | u_k_real = u_k_real*factor_k; # broadcast applies factor_k along the last dimension 232 | u_k_imag = u_k_imag*factor_k; # broadcast applies factor_k along the last dimension 233 | 234 | flag_debug = False; 235 | if flag_debug: 236 | if flag_verbose > 0: 237 | print("WARNING: debugging mode on."); 238 | 239 | u_k_real = 0.0*u_k_real; 240 | u_k_imag = 0.0*u_k_imag; 241 | 242 | u_k_real[0,0,1] = nx; 243 | u_k_imag[0,0,1] = 0; 244 | 245 | u_k_real[1,0,1] = 0; 246 | u_k_imag[1,0,1] = nx; 247 | 248 | u_k_real[2,0,1] = nx; 249 | u_k_imag[2,0,1] = nx; 250 | 251 | # flip modes for constructing rep of real-valued function 252 | u_k_real_flip = self.flipForFFT(u_k_real); 253 | u_k_imag_flip = self.flipForFFT(u_k_imag); 254 | 255 | u_k_real_p = 0.5*u_k_real + 0.5*u_k_real_flip; # make conjugate conj(u_k) = u_{N -k} 256 | u_k_imag_p = 0.5*u_k_imag - 0.5*u_k_imag_flip; # make conjugate conj(u_k) = u_{N -k} 257 | 258 | u_k_real_p = torch.from_numpy(u_k_real_p); 259 | u_k_imag_p = torch.from_numpy(u_k_imag_p); 260 | 261 | u_k_real_p = u_k_real_p.type(torch.float32); 262 | u_k_imag_p = u_k_imag_p.type(torch.float32); 263 | 264 | u_hat = self.getComplex(u_k_real_p.numpy(),u_k_imag_p.numpy()); 265 | 266 | f_hat = self.computeCoeffActionL(u_hat,L_hat); 267 | f_hat = f_hat; # no sign change is applied here, so the stored samples satisfy f = L[u]. 268 | 269 | # Generate samples u and f in 1d using ifft. 270 | # ifft is applied along the last index 271 | # perform inverse DFT to get u and f. 
272 | u_i = np.fft.ifft(u_hat); 273 | f_i = np.fft.ifft(f_hat); 274 | 275 | if self.op_type == 'u*ux': 276 | f_i = u_i*f_i; 277 | elif self.op_type == 'ux*ux': 278 | f_i = f_i*f_i; 279 | elif self.op_type == 'u*uxx': 280 | f_i = u_i*f_i; 281 | elif self.op_type == 'uxx*uxx': 282 | f_i = f_i*f_i; 283 | 284 | self.samples_X = torch.from_numpy(np.real(u_i)).type(torch.float32); # only grab real part 285 | self.samples_Y = torch.from_numpy(np.real(f_i)).type(torch.float32); 286 | 287 | if noise_factor > 0: 288 | self.samples_Y += noise_factor*torch.randn(*self.samples_Y.shape); 289 | 290 | def __len__(self): 291 | return self.samples_X.size()[0]; 292 | 293 | def __getitem__(self,index): 294 | return self.samples_X[index],self.samples_Y[index]; 295 | 296 | def to(self,device): 297 | self.samples_X = self.samples_X.to(device); 298 | self.samples_Y = self.samples_Y.to(device); 299 | 300 | return self; 301 | 302 | class diffOp2(torch.utils.data.Dataset): 303 | r""" 304 | Generates samples of the form :math:`(u^{[i]},f^{[i]})` where :math:`f^{[i]} = L[u^{[i]}]`, 305 | where :math:`i` denotes the index of the sample. 306 | 307 | Stores data samples in the form :math:`(u,f)`. 308 | 309 | The samples of u are represented as a tensor of size [nsamples,nchannels,nx] 310 | and sample of f as a tensor of size [nsamples,nchannels,nx]. 311 | 312 | Note: 313 | For now, please use nx that is odd. In this initial implementation, we use a 314 | method based on conjugated flips with formula for the odd case which is slightly 315 | simpler than other case. 316 | 317 | """ 318 | def flipForFFT(self,u_k_part): 319 | r"""We flip as :math:`f_k = f_{N-k}`. Notice that only :math:`0,\ldots,N-1` entries 320 | stored. This is useful for constructing real-valued function representations 321 | from random coefficients. Real-valued function requires :math:`conj(f_k) = f_{N-k}`. 
322 | We can use this flip to construct from random coefficients the term 323 | :math:`u_k = f_k + conj(flip(f_k))`, then above constraint is satisfied. 324 | 325 | Args: 326 | u_k_part (ndarray): array of size [nsamples,nchannels,nx,ny] to flip along the last two indices. 327 | 328 | Returns: 329 | ndarray: The flipped array, used to build coefficients symmetric under conjugation. 330 | """ 331 | nx = self.nx;ny = self.ny; 332 | 333 | u_k_part_row0 = u_k_part[:,:,0,:]; 334 | u_k_part_row0 = np.expand_dims(u_k_part_row0,2); 335 | u_k_part_ex = np.concatenate([u_k_part,u_k_part_row0],2); 336 | 337 | u_k_part_col0 = u_k_part_ex[:,:,:,0]; 338 | u_k_part_col0 = np.expand_dims(u_k_part_col0,3); 339 | u_k_part_ex = np.concatenate([u_k_part_ex,u_k_part_col0],3); 340 | 341 | u_k_part_ex_flip = np.flip(u_k_part_ex,2); 342 | u_k_part_ex_flip = np.flip(u_k_part_ex_flip,3); 343 | 344 | u_k_part_flip = np.delete(u_k_part_ex_flip,nx,2); 345 | u_k_part_flip = np.delete(u_k_part_flip,ny,3); 346 | 347 | return u_k_part_flip; 348 | 349 | def getComplex(self,a,b): 350 | j = 1j; # imaginary unit; the np.complex alias was removed in NumPy 1.24. 
351 | c = a + j*b; 352 | return c; 353 | 354 | def getRealImag(self,c): 355 | a = np.real(c); 356 | b = np.imag(c); 357 | return a,b; 358 | 359 | def computeLSymbol_laplacian_u(self): 360 | r"""Compute associated Fourier symbols for use under DFT for the operator L[u].""" 361 | num_dim = 1;nx=self.nx;ny=self.ny; 362 | vec_k1 = torch.zeros(nx,ny); 363 | vec_k2 = torch.zeros(nx,ny); 364 | vec_k1_pp = torch.zeros(nx,ny); 365 | vec_k2_pp = torch.zeros(nx,ny); 366 | vec_k_sq = torch.zeros(nx,ny); 367 | L_symbol_real = torch.zeros(nx,ny,dtype=torch.float32); 368 | L_symbol_imag = torch.zeros(nx,ny,dtype=torch.float32); 369 | neg_four_pi_sq = -4.0*np.pi*np.pi; 370 | for i in range(0,nx): 371 | for j in range(0,ny): 372 | vec_k1[i,j] = i; 373 | vec_k2[i,j] = j; 374 | vec_k_sq[i,j] = vec_k1[i,j]*vec_k1[i,j] + vec_k2[i,j]*vec_k2[i,j]; 375 | if (vec_k1[i,j] < nx/2): 376 | vec_k1_p = vec_k1[i,j]; 377 | else: 378 | vec_k1_p = vec_k1[i,j] - nx; 379 | if (vec_k2[i,j] < ny/2): 380 | vec_k2_p = vec_k2[i,j]; 381 | else: 382 | vec_k2_p = vec_k2[i,j] - ny; 383 | vec_k1_pp[i,j] = vec_k1_p; 384 | vec_k2_pp[i,j] = vec_k2_p; 385 | vec_k_p_sq = vec_k1_p*vec_k1_p + vec_k2_p*vec_k2_p; 386 | L_symbol_real[i,j] = neg_four_pi_sq*vec_k_p_sq; 387 | L_symbol_imag[i,j] = 0.0; 388 | 389 | L_hat = self.getComplex(L_symbol_real.numpy(),L_symbol_imag.numpy()); 390 | 391 | return L_hat, vec_k1_pp, vec_k2_pp; 392 | 393 | def computeLSymbol_grad_u(self): 394 | r"""Compute associated Fourier symbols for use under DFT for the operator L[u].""" 395 | num_dim = 2;nx=self.nx;ny=self.ny; 396 | vec_k1 = torch.zeros(nx,ny); 397 | vec_k2 = torch.zeros(nx,ny); 398 | vec_k1_pp = torch.zeros(nx,ny); 399 | vec_k2_pp = torch.zeros(nx,ny); 400 | vec_k_sq = torch.zeros(nx,ny); 401 | L_symbol_real = torch.zeros(num_dim,nx,ny,dtype=torch.float32); 402 | L_symbol_imag = torch.zeros(num_dim,nx,ny,dtype=torch.float32); 403 | two_pi = 2.0*np.pi; 404 | #two_pi_i = two_pi*1j; # $2\pi{i}$, 1j = sqrt(-1) 405 | for i in 
range(0,nx): 406 | for j in range(0,ny): 407 | vec_k1[i,j] = i; 408 | vec_k2[i,j] = j; 409 | vec_k_sq[i,j] = vec_k1[i,j]*vec_k1[i,j] + vec_k2[i,j]*vec_k2[i,j]; 410 | if (vec_k1[i,j] < nx/2): 411 | vec_k1_p = vec_k1[i,j]; 412 | else: 413 | vec_k1_p = vec_k1[i,j] - nx; 414 | if (vec_k2[i,j] < ny/2): 415 | vec_k2_p = vec_k2[i,j]; 416 | else: 417 | vec_k2_p = vec_k2[i,j] - ny; 418 | vec_k1_pp[i,j] = vec_k1_p; 419 | vec_k2_pp[i,j] = vec_k2_p; 420 | vec_k_p_sq = vec_k1_p*vec_k1_p + vec_k2_p*vec_k2_p; 421 | L_symbol_real[0,i,j] = 0.0; 422 | L_symbol_imag[0,i,j] = two_pi*vec_k1_p; 423 | L_symbol_real[1,i,j] = 0.0; 424 | L_symbol_imag[1,i,j] = two_pi*vec_k2_p; 425 | 426 | L_hat_0 = self.getComplex(L_symbol_real[0,:,:].numpy(),L_symbol_imag[0,:,:].numpy()); 427 | L_hat_1 = self.getComplex(L_symbol_real[1,:,:].numpy(),L_symbol_imag[1,:,:].numpy()); 428 | 429 | L_hat = np.stack((L_hat_0,L_hat_1)); 430 | 431 | return L_hat, vec_k1_pp, vec_k2_pp; 432 | 433 | def computeCoeffActionL(self,u_hat,L_hat): 434 | r"""Computes the action of operator L used for data generation in Fourier space.""" 435 | u_k_real, u_k_imag = self.getRealImag(u_hat); 436 | L_symbol_real, L_symbol_imag = self.getRealImag(L_hat); 437 | 438 | f_k_real = L_symbol_real*u_k_real - L_symbol_imag*u_k_imag; #broadcast will distr over copies of u. 439 | #f_k_real = -1.0*f_k_real; 440 | f_k_imag = L_symbol_real*u_k_imag + L_symbol_imag*u_k_real; 441 | #f_k_imag = -1.0*f_k_imag; 442 | 443 | # Generate samples u and f using ifft2. 
444 | f_hat = self.getComplex(f_k_real,f_k_imag); 445 | 446 | return f_hat; 447 | 448 | def computeActionL(self,u,L_hat): 449 | r"""Computes the action of operator L used for data generation.""" 450 | raise Exception('This routine is not yet debugged and needs to be tested first.') 451 | 452 | # perform FFT to get u_hat 453 | u_hat = np.fft.fft2(u); 454 | 455 | # compute action of L_hat 456 | f_hat = self.computeCoeffActionL(u_hat,L_hat); 457 | 458 | # compute inverse FFT to get f 459 | f = np.fft.ifft2(f_hat) 460 | 461 | return f; 462 | 463 | def __init__(self,op_type=r'\Delta{u}',op_params=None, 464 | gen_mode='exp1',gen_params={'alpha1':0.1}, 465 | num_samples=int(1e4),nchannels=1,nx=15,ny=15, 466 | flag_verbose=0, **extra_params): 467 | r"""Setup for data generation. 468 | 469 | Args: 470 | op_type (str): The differential operator to sample. 471 | op_params (dict): The operator parameters. 472 | gen_mode (str): The mode for the data generator. 473 | gen_params (dict): The parameters for the given generator. 474 | num_samples (int): The number of samples to generate. 475 | nchannels (int): The number of channels. 476 | nx (int): The number of input sample points in x-direction. 477 | ny (int): The number of input sample points in y-direction. 478 | flag_verbose (int): Level of reporting during calculations. 479 | extra_params (dict): Extra parameters for the sampler. 480 | 481 | For extra_params we have: 482 | noise_factor (float): The amount of noise to add to samples. 483 | scale_factor (float): A factor to scale magnitude of the samples. 484 | flagComputeL (bool): Whether the Fourier symbol of the operator should be computed. 485 | 486 | For generator modes we have: 487 | gen_mode == 'exp1': 488 | alpha1 (float): The decay rate. 489 | 490 | Note: 491 | For now, please use only nx and ny that are odd and equal. In this initial implementation, we use a 492 | method based on conjugated flips with the formula for the odd case, which is slightly 493 | simpler than the even case.
494 | """ 495 | if flag_verbose > 0: 496 | print("Generating the data samples which can take some time."); 497 | print("num_samples = %d"%num_samples); 498 | 499 | self.op_type=op_type; 500 | self.op_params=op_params; 501 | 502 | self.gen_mode=gen_mode; 503 | self.gen_params=gen_params; 504 | 505 | self.num_samples=num_samples; 506 | self.nchannels=nchannels; 507 | self.nx=nx; self.ny=ny; 508 | 509 | if (nx % 2 == 0) or (ny % 2 == 0) or (nx != ny): # may be able to relax nx != ny (just for safety) 510 | msg = "Using nx,ny that are even or unequal is not yet supported. "; 511 | msg += "For now, please use odd nx,ny with nx == ny, given the flips currently used." 512 | raise Exception(msg); 513 | 514 | noise_factor=0;scale_factor=1.0;flagComputeL=False; # default values 515 | if 'noise_factor' in extra_params: 516 | noise_factor = extra_params['noise_factor']; 517 | 518 | if 'scale_factor' in extra_params: 519 | scale_factor = extra_params['scale_factor']; 520 | 521 | if 'flagComputeL' in extra_params: 522 | flagComputeL = extra_params['flagComputeL']; 523 | 524 | # Generate the Fourier symbols for the operator 525 | flag_vv = 'null'; 526 | if self.op_type == r'\grad{u}' or self.op_type == r'u\grad{u}' or self.op_type == r'\grad{u}\cdot\grad{u}': 527 | L_hat, vec_k1_pp, vec_k2_pp = self.computeLSymbol_grad_u(); 528 | flag_vv = 'vector2'; 529 | elif self.op_type == r'\Delta{u}' or self.op_type == r'u\Delta{u}' or self.op_type == r'\Delta{u}*\Delta{u}': 530 | L_hat, vec_k1_pp, vec_k2_pp = self.computeLSymbol_laplacian_u(); 531 | flag_vv = 'scalar'; 532 | else: 533 | raise Exception("Unknown operator type."); 534 | 535 | if (flagComputeL): 536 | raise Exception("The option flagComputeL is not yet supported."); 537 | L_i = np.fft.ifft2(L_hat); 538 | self.L_hat = L_hat; 539 | self.L_i = L_i; 540 | u = np.zeros((nx,ny)); 541 | i0 = int(nx/2); 542 | j0 = int(ny/2); 543 | u[i0,j0] = 1.0; 544 | self.G_i = self.computeActionL(u,L_hat); 545 | 546 | # Generate random input function (want
real-valued) 547 | # conj(u_k) = u_{N -k} needs to hold. 548 | u_k_real = np.random.randn(num_samples,nchannels,nx,ny); 549 | u_k_imag = np.random.randn(num_samples,nchannels,nx,ny); 550 | 551 | # scale the modes to make the samples smooth 552 | if gen_mode=='exp1': 553 | alpha1 = gen_params['alpha1']; 554 | factor_k = scale_factor*np.exp(-alpha1*(vec_k1_pp**2 + vec_k2_pp**2)); 555 | factor_k = factor_k.numpy(); 556 | else: 557 | raise Exception("Generation mode not recognized."); 558 | 559 | u_k_real = u_k_real*factor_k; # broadcast will apply over last two dimensions 560 | u_k_imag = u_k_imag*factor_k; # broadcast will apply over last two dimensions 561 | 562 | # flip modes for constructing the representation of a real-valued function 563 | u_k_real_flip = self.flipForFFT(u_k_real); 564 | u_k_imag_flip = self.flipForFFT(u_k_imag); 565 | 566 | u_k_real = 0.5*u_k_real + 0.5*u_k_real_flip; # make conjugate conj(u_k) = u_{N -k} 567 | u_k_imag = 0.5*u_k_imag - 0.5*u_k_imag_flip; # make conjugate conj(u_k) = u_{N -k} 568 | 569 | u_k_real = torch.from_numpy(u_k_real); 570 | u_k_imag = torch.from_numpy(u_k_imag); 571 | 572 | u_k_real = u_k_real.type(torch.float32); 573 | u_k_imag = u_k_imag.type(torch.float32); 574 | 575 | u_hat = self.getComplex(u_k_real.numpy(),u_k_imag.numpy()); 576 | if flag_vv == 'scalar': 577 | f_hat = self.computeCoeffActionL(u_hat,L_hat); 578 | elif flag_vv == 'vector2': 579 | f_hat_0 = self.computeCoeffActionL(u_hat,L_hat[0,:,:]); 580 | f_hat_1 = self.computeCoeffActionL(u_hat,L_hat[1,:,:]); 581 | f_hat = np.concatenate((f_hat_0,f_hat_1),-3); 582 | else: 583 | raise Exception("Unknown operator type."); 584 | 585 | # Generate samples u and f using ifft2.
586 | # ifft2 is broadcast over last two indices 587 | # perform inverse DFT to get u and f 588 | u_i = np.fft.ifft2(u_hat); 589 | if flag_vv == 'scalar': 590 | f_i = np.fft.ifft2(f_hat); 591 | elif flag_vv == 'vector2': 592 | f_i_0 = np.fft.ifft2(f_hat[:,0,:,:]); 593 | f_i_1 = np.fft.ifft2(f_hat[:,1,:,:]); 594 | f_i = np.stack((f_i_0,f_i_1),-3); 595 | else: 596 | raise Exception("Unknown operator type."); 597 | 598 | if self.op_type == r'\grad{u}': 599 | f_i = f_i; # nothing to do. 600 | elif self.op_type == r'u\grad{u}': 601 | f_i = u_i*f_i; # matches up by broadcast rules 602 | elif self.op_type == r'\grad{u}\cdot\grad{u}': 603 | f_i = np.sum(f_i**2,1); # sum on axis for channels, [batch,channel,nx,ny]. 604 | f_i = np.expand_dims(f_i,1); # keep in form [batch,1,nx,ny] 605 | elif self.op_type == r'\Delta{u}': 606 | f_i = f_i; # nothing to do. 607 | elif self.op_type == r'u\Delta{u}': 608 | f_i = u_i*f_i; 609 | elif self.op_type == r'\Delta{u}*\Delta{u}': 610 | f_i = f_i**2; 611 | else: 612 | raise Exception("Unknown operator type."); 613 | 614 | self.samples_X = torch.from_numpy(np.real(u_i)).type(torch.float32); # only grab real part 615 | self.samples_Y = torch.from_numpy(np.real(f_i)).type(torch.float32); 616 | 617 | if noise_factor > 0: 618 | self.samples_Y += noise_factor*torch.randn(*self.samples_Y.shape); 619 | 620 | def __len__(self): 621 | return self.samples_X.size()[0]; 622 | 623 | def __getitem__(self,index): 624 | return self.samples_X[index],self.samples_Y[index]; 625 | 626 | def to(self,device): 627 | self.samples_X = self.samples_X.to(device); 628 | self.samples_Y = self.samples_Y.to(device); 629 | return self; 630 | 631 | -------------------------------------------------------------------------------- /src/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. image:: overview.png 3 | 4 | PyTorch implementation of GMLS-Nets.
Module for neural networks for 5 | processing scattered data sets using Generalized Moving Least Squares (GMLS). 6 | 7 | If you find these codes or methods helpful for your project, please cite: 8 | 9 | | @article{trask_patel_gross_atzberger_GMLS_Nets_2019, 10 | | title={GMLS-Nets: A framework for learning from unstructured data}, 11 | | author={Nathaniel Trask, Ravi G. Patel, Ben J. Gross, Paul J. Atzberger}, 12 | | journal={arXiv:1909.05371}, 13 | | month={September}, 14 | | year={2019}, 15 | | url={https://arxiv.org/abs/1909.05371} 16 | | } 17 | 18 | """ 19 | 20 | # Authors: B.J. Gross and P.J. Atzberger 21 | # Website: http://atzberger.org/ 22 | 23 | import torch; 24 | import torch.nn as nn; 25 | import torchvision; 26 | import torchvision.transforms as transforms; 27 | 28 | import numpy as np; 29 | 30 | import scipy.spatial as spatial # used for finding neighbors within distance $\delta$ 31 | 32 | from collections import OrderedDict; 33 | 34 | import pickle as p; 35 | 36 | import pdb; 37 | 38 | import time; 39 | 40 | # ==================================== 41 | # Custom Functions 42 | # ==================================== 43 | class MapToPoly_Function(torch.autograd.Function): 44 | r""" 45 | This layer processes a collection of scattered data points consisting of a collection 46 | of values :math:`u_j` at points :math:`x_j`. For a collection of target points 47 | :math:`x_i`, local least-squares problems are solved for obtaining a local representation 48 | of the data over a polynomial space. The layer outputs a collection of polynomial 49 | coefficients :math:`c(x_i)` at each point and the collection of target points :math:`x_i`. 50 | """ 51 | @staticmethod 52 | def weight_one_minus_r(z1,z2,params): 53 | r"""Weight function :math:`\omega(x_j,x_i) = \left(1 - r/\epsilon\right)^{\bar{p}}_+.` 54 | 55 | Args: 56 | z1 (Tensor): The first point. Tensor of shape [1,num_dims]. 57 | z2 (Tensor): The second point. Tensor of shape [1,num_dims]. 
58 | params (dict): The parameters are 'p' for decay power and 'epsilon' for support size. 59 | 60 | Returns: 61 | Tensor: The weight evaluation over points. 62 | """ 63 | epsilon = params['epsilon']; p = params['p']; 64 | r = torch.sqrt(torch.sum(torch.pow(z1 - z2,2),1)); 65 | 66 | diff = torch.clamp(1 - (r/epsilon),min=0); 67 | eval = torch.pow(diff,p); 68 | return eval; 69 | 70 | @staticmethod 71 | def get_num_polys(porder,num_dim=None): 72 | r""" Returns the number of polynomials of given porder. """ 73 | if num_dim == 1: 74 | num_polys = porder + 1; 75 | elif num_dim == 2: 76 | num_polys = int((porder + 2)*(porder + 1)/2); 77 | elif num_dim == 3: 78 | num_polys = 0; 79 | for beta in range(0,porder + 1): 80 | num_polys += int((porder - beta + 2)*(porder - beta + 1)/2); 81 | else: 82 | raise Exception("Number of dimensions not implemented currently. \n num_dim = %d."%num_dim); 83 | 84 | return num_polys; 85 | 86 | @staticmethod 87 | def eval_poly(pts_x,pts_x2_i0,c_star_i0,porder,flag_verbose): 88 | r""" Evaluates the polynomials locally around a target point xi given coefficients c. """ 89 | # Evaluates the polynomial locally (this helps to assess the current fit). 90 | # Implemented for 1D, 2D, and 3D. 91 | # 92 | # 2D: 93 | # Computes Taylor Polynomials over x and y. 94 | # T_{k1,k2}(x1,x2) = (1.0/(k1 + k2)!)*(x1 - x01)^{k1}*(x2 - x02)^{k2}. 95 | # of terms is N = (porder + 1)*(porder + 2)/2. 96 | # 97 | # WARNING: Note the role of factorials and orthogonality here. The Taylor 98 | # expansion/polynomial formulation is not ideal and can give ill-conditioning. 99 | # It would be better to use orthogonal polynomials or other bases. 
100 | # 101 | num_dim = pts_x.shape[1]; 102 | if num_dim == 1: 103 | II = 0; 104 | alpha_factorial = 1.0; 105 | eval_p = torch.zeros(pts_x.shape[0],device=c_star_i0.device); 106 | for alpha in np.arange(0,porder + 1): 107 | if alpha >= 2: 108 | alpha_factorial *= alpha; 109 | if flag_verbose > 1: print("alpha = " + str(alpha)); 110 | # for now, (x - x_*)^alpha, but ideally use orthogonal polynomials 111 | base_poly = torch.pow(pts_x[:,0] - pts_x2_i0[0],alpha); 112 | base_poly = base_poly/alpha_factorial; 113 | eval_p += c_star_i0[II]*base_poly; 114 | II += 1; 115 | elif num_dim == 2: 116 | II = 0; 117 | alpha_factorial = 1.0; 118 | eval_p = torch.zeros(pts_x.shape[0],device=c_star_i0.device); 119 | for alpha in np.arange(0,porder + 1): 120 | if alpha >= 2: 121 | alpha_factorial *= alpha; 122 | for k in np.arange(0,alpha + 1): 123 | if flag_verbose > 1: print("alpha = " + str(alpha)); print("k = " + str(k)); 124 | # for now, (x - x_*)^k, but ideally use orthogonal polynomials 125 | base_poly = torch.pow(pts_x[:,0] - pts_x2_i0[0],alpha - k); 126 | # for now, (x - x_*)^k, but ideally use orthogonal polynomials 127 | base_poly = base_poly*torch.pow(pts_x[:,1] - pts_x2_i0[1],k); 128 | base_poly = base_poly/alpha_factorial; 129 | eval_p += c_star_i0[II]*base_poly; 130 | II += 1; 131 | elif num_dim == 3: # caution, below gives initial results, but should be more fully validated 132 | II = 0; 133 | alpha_factorial = 1.0; 134 | eval_p = torch.zeros(pts_x.shape[0],device=c_star_i0.device); 135 | for beta in np.arange(0,porder + 1): 136 | base_poly = torch.pow(pts_x[:,2] - pts_x2_i0[2],beta); 137 | for alpha in np.arange(0,porder - beta + 1): 138 | if alpha >= 2: 139 | alpha_factorial *= alpha; 140 | for k in np.arange(0,alpha + 1): 141 | if flag_verbose > 1: print("alpha = " + str(alpha)); print("k = " + str(k)); 142 | # for now, (x - x_*)^k, but ideally use orthogonal polynomials 143 | base_poly = base_poly*torch.pow(pts_x[:,0] - pts_x2_i0[0],alpha - k);
144 | base_poly = base_poly*torch.pow(pts_x[:,1] - pts_x2_i0[1],k); 145 | base_poly = base_poly/alpha_factorial; 146 | eval_p += c_star_i0[II]*base_poly; 147 | II += 1; 148 | else: 149 | raise Exception("Number of dimensions not implemented currently. \n num_dim = %d."%num_dim); 150 | 151 | return eval_p; 152 | 153 | @staticmethod 154 | def generate_mapping(weight_func,weight_func_params, 155 | porder,epsilon, 156 | pts_x1,pts_x2, 157 | tree_points=None,device=None, 158 | flag_verbose=0): 159 | r""" Generates for caching the data for the mapping from field values (uj,xj) :math:`\rightarrow` (ci,xi). 160 | This help optimize codes and speed up later calculations that are done repeatedly.""" 161 | if device is None: 162 | device = torch.device('cpu'); 163 | 164 | map_data = {}; 165 | 166 | num_dim = pts_x1.shape[1]; 167 | 168 | if pts_x2 is None: 169 | pts_x2 = pts_x1; 170 | 171 | pts_x1 = pts_x1.to(device); 172 | pts_x2 = pts_x2.to(device); 173 | 174 | pts_x1_numpy = None; pts_x2_numpy = None; 175 | if tree_points is None: # build kd-tree of points for neighbor listing 176 | if pts_x1_numpy is None: pts_x1_numpy = pts_x1.cpu().numpy(); 177 | tree_points = spatial.cKDTree(pts_x1_numpy); 178 | 179 | # Maps from u(x_j) on $x_j \in \mathcal{S}^1$ to a 180 | # polynomial representations in overlapping regions $\Omega_i$ at locations 181 | # around points $x_i \in \mathcal{S}^2$. 182 | # These two sample sets need not be the same allowing mappings between point locations. 183 | # Computes polynomials over x and y. 184 | # Number of terms in 2D is num_polys = (porder + 1)*(porder + 2)/2. 
185 | num_pts1 = pts_x1.shape[0]; num_pts2 = pts_x2.shape[0]; 186 | num_polys = MapToPoly_Function.get_num_polys(porder,num_dim); 187 | if flag_verbose > 0: 188 | print("num_polys = " + str(num_polys)); 189 | 190 | M = torch.zeros((num_pts2,num_polys,num_polys),device=device); # assemble matrix at each grid-point 191 | M_inv = torch.zeros((num_pts2,num_polys,num_polys),device=device); # assemble matrix at each grid-point 192 | 193 | #svd_U = torch.zeros((num_pts2,num_polys,num_polys)); # assemble matrix at each grid-point 194 | #svd_S = torch.zeros((num_pts2,num_polys,num_polys)); # assemble matrix at each grid-point 195 | #svd_V = torch.zeros((num_pts2,num_polys,num_polys)); # assemble matrix at each grid-point 196 | 197 | vec_rij = torch.zeros((num_pts2,num_polys,num_pts1),device=device); # @optimize: ideally should be sparse matrix. 198 | 199 | # build up the batch of linear systems for each target point 200 | for i in np.arange(0,num_pts2): # loop over the points $x_i$ 201 | 202 | if (flag_verbose > 0) & (i % 100 == 0): print("i = " + str(i) + " : num_pts2 = " + str(num_pts2)); 203 | 204 | if pts_x2_numpy is None: pts_x2_numpy = pts_x2.cpu().numpy(); 205 | indices_xj_i = tree_points.query_ball_point(pts_x2_numpy[i,:], epsilon); # find all points with distance 206 | # less than epsilon from xi. 207 | 208 | for j in indices_xj_i: # @optimize later to use only local points, and where weights are non-zero. 209 | 210 | if flag_verbose > 1: print("j = " + str(j)); 211 | 212 | vec_p_j = torch.zeros(num_polys,device=device); 213 | w_ij = weight_func(pts_x1[j,:].unsqueeze(0), pts_x2[i,:].unsqueeze(0), weight_func_params); # can optimize for sub-lists outer-product 214 | 215 | # Computes Taylor Polynomials over x,y,z. 216 | # 217 | # 2D Case: 218 | # T_{k1,k2}(x1,x2) = (1.0/(k1 + k2)!)*(x1 - x01)^{k1}*(x2 - x02)^{k2}. 219 | # number of terms is N = (porder + 1)*(porder + 2)/2. 220 | # computes polynomials over x and y. 
221 | # 222 | # WARNING: The monomial basis is non-ideal and can lead to ill-conditioned linear algebra. 223 | # This ultimately should be generalized in the future to other bases, ideally orthogonal, 224 | # which would help both with efficiency and conditioning of the linear algebra. 225 | # 226 | if num_dim == 1: 227 | # number of terms is N = porder + 1. 228 | II = 0; 229 | for alpha in np.arange(0,porder + 1): 230 | if flag_verbose > 1: print("alpha = " + str(alpha)); 231 | # for now, (x - x_*)^alpha, but ideally use orthogonal polynomials 232 | vec_p_j[II] = torch.pow(pts_x1[j,0] - pts_x2[i,0], alpha); 233 | II += 1; 234 | elif num_dim == 2: 235 | # number of terms is N = (porder + 1)*(porder + 2)/2. 236 | II = 0; 237 | for alpha in np.arange(0,porder + 1): 238 | for k in np.arange(0,alpha + 1): 239 | if flag_verbose > 1: print("alpha = " + str(alpha)); print("k = " + str(k)); 240 | # for now, (x - x_*)^k, but ideally use orthogonal polynomials 241 | vec_p_j[II] = torch.pow(pts_x1[j,0] - pts_x2[i,0], alpha - k); 242 | vec_p_j[II] = vec_p_j[II]*torch.pow(pts_x1[j,1] - pts_x2[i,1], k); 243 | II += 1; 244 | elif num_dim == 3: 245 | # number of terms is N = sum_{alpha_3 = 0}^{porder} (porder - alpha_3 + 1)*(porder - alpha_3 + 2)/2.
246 | II = 0; 247 | for beta in np.arange(0,porder + 1): 248 | for alpha in np.arange(0,porder - beta + 1): 249 | for k in np.arange(0,alpha + 1): 250 | if flag_verbose > 1: 251 | print("beta = " + str(beta)); print("alpha = " + str(alpha)); print("k = " + str(k)); 252 | # for now, monomials (x - x_*)^k, but ideally use orthogonal polynomials 253 | vec_p_j[II] = torch.pow(pts_x1[j,2] - pts_x2[i,2],beta); # z-factor, set for each basis term 254 | vec_p_j[II] = vec_p_j[II]*torch.pow(pts_x1[j,0] - pts_x2[i,0],alpha - k); 255 | vec_p_j[II] = vec_p_j[II]*torch.pow(pts_x1[j,1] - pts_x2[i,1],k); 256 | II += 1; 257 | 258 | # add contributions to the M(x_i) and r(x_i) terms 259 | # r += (w_ij*u[j])*vec_p_j; 260 | vec_rij[i,:,j] = w_ij*vec_p_j; 261 | M[i,:,:] += torch.ger(vec_p_j,vec_p_j)*w_ij; # outer-product of vectors (build batch of matrices) 262 | 263 | # Compute the SVD of M for purposes of computing the pseudo-inverse (for solving least-squares problem). 264 | # Note: M is always symmetric positive semi-definite, so U and V should coincide 265 | # and the singular values equal the eigenvalues. This simplifies some expressions. 266 | 267 | U,S,V = torch.svd(M[i,:,:]); # M = U*SS*V^T, note SS = diag(S) 268 | threshold_nonzero = 1e-9; # threshold below which a singular value is treated as zero. 269 | I_nonzero = (S > threshold_nonzero); 270 | S_inv = 0.0*S; 271 | S_inv[I_nonzero] = 1.0/S[I_nonzero]; 272 | SS_inv = torch.diag(S_inv); 273 | M_inv[i,:,:] = torch.matmul(V,torch.matmul(SS_inv,U.t())); # pseudo-inverse M^{+} = V*S^{+}*U^T 274 | 275 | # Save the linear system information for the least-squares problem at each target point $xi$.
276 | map_data['M'] = M; 277 | map_data['M_inv'] = M_inv; 278 | map_data['vec_rij'] = vec_rij; 279 | 280 | return map_data; 281 | 282 | @staticmethod 283 | def get_poly_1D_u(u, porder, weight_func, weight_func_params, 284 | pts_x1, epsilon = None, pts_x2 = None, cached_data=None, 285 | tree_points = None, device=None, flag_verbose = 0): 286 | r""" Compute the polynomial coefficients in the case of a scalar field. Would not typically call directly, used for internal purposes. """ 287 | 288 | # We assume that all inputs are pytorch tensors 289 | # Assumes: 290 | # pts_x1.size = [num_pts,num_dim] 291 | # pts_x2.size = [num_pts,num_dim] 292 | # 293 | # @optimize: Should cache the points and neighbor lists... then using torch.solve, torch.ger. 294 | # Should vectorize all of the for-loop operations via Lambdifying polynomial evals. 295 | # Should avoid numpy calculations, maybe cache numpy copy of data if needed to avoid .cpu() transfer calls. 296 | # Use batching over points to do solves, then GPU parallizable and faster. 
297 | # 298 | if device is None: 299 | device = torch.device('cpu'); # default cpu device 300 | 301 | if (u.dim() > 1): 302 | print("u.dim = " + str(u.dim())); 303 | print("u.shape = " + str(u.shape)); 304 | raise Exception("Assumes input with dimension == 1."); 305 | 306 | if (cached_data is None) or ('map_data' not in cached_data) or (cached_data['map_data'] is None): 307 | generate_mapping = MapToPoly_Function.generate_mapping; 308 | 309 | if pts_x2 is None: 310 | pts_x2 = pts_x1; 311 | 312 | map_data = generate_mapping(weight_func,weight_func_params, 313 | porder,epsilon, 314 | pts_x1,pts_x2,tree_points,device); 315 | 316 | if cached_data is not None: 317 | cached_data['map_data'] = map_data; 318 | 319 | else: 320 | map_data = cached_data['map_data']; # use cached data 321 | 322 | if flag_verbose > 0: 323 | print("num_pts1 = " + str(pts_x1.shape[0]) + ", num_pts2 = " + str(map_data['M_inv'].shape[0])); 324 | 325 | if epsilon is None: 326 | raise Exception('The epsilon ball size to use around xi must be specified.') 327 | 328 | # Maps from u(x_j) on $x_j \in \mathcal{S}^1$ to a 329 | # polynomial representations in overlapping regions $\Omega_i$ at locations 330 | # around points $x_i \in \mathcal{S}^2$. 331 | 332 | # These two sample sets need not be the same allowing mappings between point sets. 333 | # Computes polynomials over x and y. 334 | # For 2D case, number of terms is num_polys = (porder + 1)*(porder + 2)/2. 335 | 336 | #c_star[:,i] = np.linalg.solve(np_M,np_r); # "c^*(x_i) = M^{-1}*r." 337 | vec_rij = map_data['vec_rij']; 338 | M_inv = map_data['M_inv']; 339 | 340 | r_all = torch.matmul(vec_rij,u); 341 | c_star = torch.bmm(M_inv,r_all.unsqueeze(2)); # perform batch matrix-vector multiplications 342 | c_star = c_star.squeeze(2); # convert to list of vectors 343 | 344 | output = c_star; 345 | output = output.float(); # Map to float type for GPU / PyTorch Module compatibilities.
346 | 347 | return output, pts_x2; 348 | 349 | @staticmethod 350 | def forward(ctx, input, porder, weight_func, weight_func_params, 351 | pts_x1, epsilon = None, pts_x2 = None, cached_data=None, 352 | tree_points = None, device = None, flag_verbose = 0): 353 | r""" 354 | 355 | For a field u specified at points xj, performs the mapping to coefficients c at points xi, (uj,xj) :math:`\rightarrow` (ci,xi). 356 | 357 | Args: 358 | input (Tensor): The input field data uj. 359 | porder (int): Order of the basis to use (polynomial degree). 360 | weight_func (function): Weight function to use. 361 | weight_func_params (dict): Weight function parameters. 362 | pts_x1 (Tensor): The collection of domain points :math:`x_j`. 363 | epsilon (float): The :math:`\epsilon`-neighborhood size to use to sort points (should be compatible with choice of weight_func_params). 364 | pts_x2 (Tensor): The collection of target points :math:`x_i` (defaults to pts_x1 when None). 365 | cached_data (dict): Stored mapping data (under key 'map_data') to help speed up repeated calculations. 366 | tree_points (scipy.spatial.cKDTree): The kd-tree built over the points :math:`x_j`, reused to speed up neighbor searches. 367 | device (torch.device): Device on which to perform calculations (GPU or other, default is CPU). 368 | flag_verbose (int): Level of reporting on progress during the calculations. 369 | 370 | Returns: 371 | tuple of (ci,xi): The coefficient values ci at the target points xi, and the target points xi.
372 | 373 | """ 374 | if device is None: 375 | device = torch.device('cpu'); 376 | 377 | ctx.atz_name = 'MapToPoly_Function'; 378 | 379 | ctx.save_for_backward(input,pts_x1,pts_x2); 380 | 381 | ctx.atz_porder = porder; 382 | ctx.atz_weight_func = weight_func; 383 | ctx.atz_weight_func_params = weight_func_params; 384 | 385 | get_poly_1D_u = MapToPoly_Function.get_poly_1D_u; 386 | get_num_polys = MapToPoly_Function.get_num_polys; 387 | 388 | input_dim = input.dim(); 389 | if input_dim >= 1: # compute c_star in batches 390 | pts_x1_numpy = None; 391 | pts_x2_numpy = None; 392 | 393 | if pts_x2 is None: 394 | pts_x2 = pts_x1; 395 | 396 | # reshape the data to handle as a batch [batch_size, uj_data_size] 397 | # We assume u is input in the form [I,k,xj], u_I(k,xj), the index I is arbitrary. 398 | u = input; 399 | 400 | if input_dim == 2: # need to unsqueeze, so 2D we are mapping 401 | # [k,xj] --> [I,k,xj] --> [II,c] --> [I,k,xi,c] --> [k,xi,c] 402 | u = u.unsqueeze(0); # u(k,xj) assumed in our calculations here 403 | 404 | if input_dim == 1: # need to unsqueeze, so 1D we are mapping 405 | # [xj] --> [I,k,xj] --> [II,c] --> [I,k,xi,c] --> [xi,c] 406 | u = u.unsqueeze(0); # u(k,xj) assumed in our calculations here 407 | u = u.unsqueeze(0); # u(k,xj) assumed in our calculations here 408 | 409 | u_num_dim = u.dim(); 410 | size_nm1 = 1; 411 | for d in range(u_num_dim - 1): 412 | size_nm1 *= u.shape[d]; 413 | 414 | uu = u.contiguous().view((size_nm1,u.shape[-1])); 415 | 416 | # compute the sizes of c_star and number of points 417 | num_dim = pts_x1.shape[1]; 418 | num_polys = get_num_polys(porder,num_dim); 419 | num_pts2 = pts_x2.shape[0]; 420 | 421 | # output needs to be of size [batch_size, xi_data_size, num_polys] 422 | output = torch.zeros((uu.shape[0],num_pts2,num_polys),device=device); # will reshape at the end 423 | 424 | # loop over the batches and compute the c_star in each case 425 | if cached_data is None: 426 | cached_data = {}; # create empty, which can be 
computed first time to store data. 427 | 428 | if tree_points is None: 429 | if pts_x1_numpy is None: pts_x1_numpy = pts_x1.cpu().numpy(); 430 | tree_points = spatial.cKDTree(pts_x1_numpy); 431 | 432 | for k in range(uu.shape[0]): 433 | uuu = uu[k,:]; 434 | out, pts_x2 = get_poly_1D_u(uuu,porder,weight_func,weight_func_params, 435 | pts_x1,epsilon,pts_x2,cached_data, 436 | tree_points,device=device,flag_verbose=flag_verbose); 437 | output[k,:,:] = out; 438 | 439 | # final output should be [*, xi_data_size, num_polys], where * is the original sizes 440 | # for indices [i1,i2,...in,k_channel,xi_data,c_poly_coeff]. 441 | output = output.view(*u.shape[0:u_num_dim-1],num_pts2,num_polys); 442 | 443 | if input_dim == 2: # 2D special case we just return k, xi, c (otherwise feed input 3D [I,k,u(xj)] I=1,k=1). 444 | output = output.squeeze(0); 445 | 446 | if input_dim == 1: # 1D special case we just return xi, c (otherwise feed input 3D [I,k,u(xj)] I=1,k=1). 447 | output = output.squeeze(0); 448 | output = output.squeeze(0); 449 | 450 | else: 451 | print("input.dim = " + str(input.dim())); 452 | print("input.shape = " + str(input.shape)); 453 | raise Exception("input tensor dimension not yet supported, only dim >= 1 currently."); 454 | 455 | ctx.atz_cached_data = cached_data; 456 | 457 | pts_x2_clone = pts_x2.clone(); 458 | return output, pts_x2_clone; 459 | 460 | @staticmethod 461 | def backward(ctx,grad_output,grad_pts_x2): 462 | r""" Consider a field u specified at points xj and the mapping to coefficients c at points xi, (uj,xj) --> (ci,xi). 463 | Computes the gradient of the mapping for backward propagation.
464 | """ 465 | 466 | flag_time_it = False; 467 | if flag_time_it: 468 | time_1 = time.time(); 469 | 470 | input,pts_x1,pts_x2 = ctx.saved_tensors; 471 | 472 | porder = ctx.atz_porder; 473 | weight_func = ctx.atz_weight_func; 474 | weight_func_params = ctx.atz_weight_func_params; 475 | cached_data = ctx.atz_cached_data; 476 | 477 | #grad_input = grad_weight_func = grad_weight_func_params = None; 478 | grad_uj = None; 479 | 480 | # we only compute the gradient in u_j, and only if it is requested (for efficiency) 481 | if ctx.needs_input_grad[0]: # derivative in uj 482 | map_data = cached_data['map_data']; # use cached data 483 | 484 | vec_rij = map_data['vec_rij']; 485 | M_inv = map_data['M_inv']; 486 | 487 | # c_i = M_{i}^{-1} r_i^T u 488 | # dF/du = dF/dc*dc/du, 489 | # 490 | # We can express this using dF/duj = sum_i dF/dci*dci/duj 491 | # 492 | # grad_output = dF/dc, grad_input = dF/du 493 | # 494 | # [grad_input]_j = sum_i dF/dci*dci/duj. 495 | # 496 | # In practice, we have both batch and channel indices so 497 | # grad_output.shape = [batchI,channelI,i,compK] 498 | # grad_output[batchI,channelI,i,compK] = derivative of F(batchI,channelI) with respect to ci[compK](batchI,channelI). 499 | # 500 | # grad_input[batchI,channelI,j] = sum_{i,compK} grad_output[batchI,channelI,i,compK]*A1[i,compK,j], with A1 = M_inv*vec_rij as computed below. 501 | # 502 | # We use matrix broadcasting to get this outcome in practice. 503 | # 504 | 505 | # @optimize can optimize, since uj only contributes non-zero to a few ci's... and could try to use sparse matrix multiplications.
506 | A1 = torch.bmm(M_inv,vec_rij); # dci/du, grad = grad[i,compK,j] 507 | A2 = A1.unsqueeze(0).unsqueeze(0); # match grad_output tensor rank, for grad[batchI,channelI,i,compK,j] 508 | A3 = grad_output.unsqueeze(4); # grad_output[batchI,channelI,i,compK,j] 509 | A4 = A3*A2; # elementwise multiplication 510 | A5 = torch.sum(A4,3); # contract on index compK 511 | A6 = torch.sum(A5,2); # contract on index i 512 | 513 | grad_uj = A6; 514 | 515 | else: 516 | msg_str = "Requested a currently un-implemented gradient for this map: \n"; 517 | msg_str += "ctx.needs_input_grad = \n" + str(ctx.needs_input_grad); 518 | raise Exception(msg_str); 519 | 520 | if flag_time_it: 521 | msg = 'MapToPoly_Function->backward():'; 522 | msg += 'elapsed_time = %.4e'%(time.time() - time_1); 523 | print(msg); 524 | 525 | return grad_uj,None,None,None,None,None,None,None,None,None,None; # since no trainable parts for these components of map 526 | 527 | 528 | class MaxPoolOverPoints_Function(torch.autograd.Function): 529 | r"""Applies a max-pooling operation to obtain values :math:`v_i = \max_{j \in \mathcal{N}_i(\epsilon)} \{u_j\}.` """ 530 | # @optimize: Should cache the points and neighbor lists. 531 | # Should avoid numpy calculations, maybe cache numpy copy of data if needed to avoid .cpu() transfer calls. 532 | # Use batching over points to do solves, then GPU parallizable and faster. 533 | @staticmethod 534 | def forward(ctx,input,pts_x1,epsilon=None,pts_x2=None, 535 | indices_xj_i_cache=None,tree_points=None, 536 | flag_verbose=0): 537 | r"""Compute max pool operation from values at points (uj,xj) to obtain (vi,xi). 538 | 539 | Args: 540 | input (Tensor): The uj values at the location of points xj. 541 | pts_x1 (Tensor): The collection of domain points :math:`x_j`. 542 | epsilon (float): The :math:`\epsilon`-neighborhood size to use to sort points (should be compatible with choice of weight_func_params). 543 | pts_x2 (Tensor): The collection of target points :math:`x_i`. 
544 | tree_points (dict): Stored data to help speed up repeated calculations. 545 | flag_verbose (int): Level of reporting on progress during the calculations. 546 | 547 | Returns: 548 | tuple: The collection ui at target points (same size as uj in the non-j indices). The collection xi of target points. Tuple of form (ui,xi). 549 | 550 | Note: 551 | We assume that all inputs are pytorch tensors with pts_x1.shape = [num_pts,num_dim] and similarly for pts_x2. 552 | 553 | """ 554 | 555 | ctx.atz_name = 'MaxPoolOverPoints_Function'; 556 | 557 | ctx.save_for_backward(input,pts_x1,pts_x2); 558 | 559 | u = input.clone(); # map input values u(xj) at xj to max value in epsilon neighborhood to u(xi) at xi points. 560 | 561 | # Assumes that input is of size [k1,k2,...,kn,j], where k1,...,kn are any indices. 562 | # We perform maxing over batch over all non-indices in j. 563 | # We reshape tensor to the form [*,j] where one index in *=index(k1,...,kn). 564 | u_num_dim = u.dim(); 565 | size_nm1 = 1; 566 | for d in range(u_num_dim - 1): 567 | size_nm1 *= u.shape[d]; 568 | 569 | uj = u.contiguous().view((size_nm1,u.shape[-1])); # reshape so indices --> [I,j], I = index(k1,...,kn). 
570 | 571 | # reshaped 572 | if pts_x2 is None: 573 | pts_x2 = pts_x1; 574 | 575 | pts_x1_numpy = pts_x1.cpu().numpy(); pts_x2_numpy = pts_x2.cpu().numpy(); # move to cpu to get numpy data 576 | pts_x1 = pts_x1.to(input.device); pts_x2 = pts_x2.to(input.device); # push back to GPU [@optimize later] 577 | num_pts1 = pts_x1.size()[0]; num_pts2 = pts_x2.size()[0]; 578 | if flag_verbose > 0: 579 | print("num_pts1 = " + str(num_pts1) + ", num_pts2 = " + str(num_pts2)); 580 | 581 | if epsilon is None: 582 | raise Exception('The epsilon ball size to use around xi must be specified.'); 583 | 584 | ctx.atz_epsilon = epsilon; 585 | 586 | if indices_xj_i_cache is None: 587 | flag_need_indices_xj_i = True; 588 | else: 589 | flag_need_indices_xj_i = False; 590 | 591 | if flag_need_indices_xj_i and tree_points is None: # build kd-tree of points for neighbor listing 592 | tree_points = spatial.cKDTree(pts_x1_numpy); 593 | 594 | ctx.atz_tree_points = tree_points; 595 | ctx.indices_xj_i_cache = indices_xj_i_cache; 596 | 597 | # Maps from u(x_j) on $x_j \in \mathcal{S}^1$ to a u(x_i) giving max values in epsilon neighborhoods. 598 | # @optimize by caching these data structure for re-use later 599 | ui = torch.zeros(size_nm1,num_pts2,requires_grad=False,device=input.device); 600 | ui_argmax_j = torch.zeros(size_nm1,num_pts2,dtype=torch.int64,requires_grad=False,device=input.device); 601 | # assumes array of form [*,num_pts2], will be reshaped to match uj, [*,num_pts2]. 
602 | 603 | for i in np.arange(0,num_pts2): # loop over the points $x_i$ 604 | if flag_verbose > 1: print("i = " + str(i) + " : num_pts2 = " + str(num_pts2)); 605 | 606 | # find all points within distance epsilon of xi 607 | if flag_need_indices_xj_i: 608 | indices_xj_i = tree_points.query_ball_point(pts_x2_numpy[i,:], epsilon); 609 | indices_xj_i = torch.Tensor(indices_xj_i).long(); 610 | indices_xj_i = indices_xj_i.to(uj.device); # .to() is not in-place for tensors, so keep the result 611 | else: 612 | indices_xj_i = indices_xj_i_cache[i,:]; # @optimize should consider replacing with better data structures 613 | 614 | # take max over neighborhood. Assumes for now that ui is scalar. 615 | uuj = uj[:,indices_xj_i]; 616 | qq = torch.max(uuj,dim=-1,keepdim=True); 617 | ui[:,i] = qq[0].squeeze(-1); # store max value 618 | jj = qq[1].squeeze(-1); # store index of max value 619 | ui_argmax_j[:,i] = indices_xj_i[jj]; # store global index of the max value 620 | 621 | # reshape the tensor from ui[I,i] to the form uui[k1,k2,...kn,i] 622 | uui = ui.view(*u.shape[0:u_num_dim-1],num_pts2); 623 | uui_argmax_j = ui_argmax_j.view(*u.shape[0:u_num_dim-1],num_pts2); 624 | ctx.atz_uui_argmax_j = uui_argmax_j; # save for gradient calculation 625 | 626 | output = uui; # for now, we assume that ui is a scalar array of size [num_pts2] 627 | output = output.to(input.device); 628 | 629 | return output, pts_x2.clone(); 630 | 631 | @staticmethod 632 | def backward(ctx,grad_output,grad_pts_x2): 633 | r"""Compute gradients of the max pool operations from values at points (uj,xj) --> (max_ui,xi). """ 634 | 635 | flag_time_it = False; 636 | if flag_time_it: 637 | time_11 = time.time(); 638 | 639 | # Compute df/dx from df/dy using the Chain Rule df/dx = df/dy*dy/dx. 640 | # Compute the gradient with respect to inputs, dz/dx. 641 | # 642 | # Consider z = f(g(x)), where we refer to x as the inputs and y = g(x) as outputs. 643 | # If we know dz/dy, we would like to compute dz/dx. This will follow from the chain-rule 644 | # as dz/dx = (dz/dy)*(dy/dx). 
We call dz/dy the gradient with respect to output and we call 645 | # dy/dx the gradient with respect to input. 646 | # 647 | # Note: the grad_output can be larger than the size of the input vector if we include in our 648 | # definition of gradient_input the derivatives with respect to weights. Should think of everything 649 | # input as tilde_x = [x,weights,bias,etc...], then grad_output = dz/dtilde_x. 650 | 651 | input,pts_x1,pts_x2 = ctx.saved_tensors; 652 | 653 | uui_argmax_j = ctx.atz_uui_argmax_j; 654 | 655 | #grad_input = grad_weight_func = grad_weight_func_params = None; 656 | grad_input = None; 657 | 658 | # We only compute the gradient in xi, if it is requested (for efficiency) 659 | # stubs for later possible use, but not needed for now 660 | if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]: 661 | msg_str = "Currently requested a non-trainable gradient for this map: \n"; 662 | msg_str += "ctx.needs_input_grad = \n" + str(ctx.needs_input_grad); 663 | raise Exception(msg_str); 664 | 665 | if ctx.needs_input_grad[0]: 666 | # Compute dL/duj = (dL/dvi)*(dvi/duj), here vi = uui. 667 | # For the max-pool case, notice that dvi/duj is non-zero only when the index uj 668 | # was the maximum value in the neighborhood of vi. Notice subtle issue with 669 | # right and left derivatives being different, so max is not differentiable for ties. 670 | # We use the right derivative lim_h (q(x + h) - q(x))/h, here. 671 | 672 | # We assume that uj.size = [k1,k2,...,kn,j], ui.size = [k1,k2,...,kn,i]. 673 | # These are reshaped so that uuj.size = [I,j] and uui.size = [I,i]. 674 | input_dim = input.dim(); 675 | size_uj = input.size(); 676 | size_uj_nm1 = np.prod(size_uj[0:input_dim-1]); # exclude last index size 677 | 678 | #ss_grad_input = input.new_zeros(size_uj_nm1,size_uj[-1]); # to store dL/duj, [I,j] indexing. 679 | ss_grad_output = grad_output.contiguous().view((size_uj_nm1,grad_output.shape[-1])); # reshape so index [I,i]. 
680 | ss_uui_argmax_j = uui_argmax_j.contiguous().view((size_uj_nm1,grad_output.shape[-1])); # reshape so index [I,i]. 681 | 682 | # assign the entries k_i = argmax_{j in Omega_i} uj, reshaped so [*,j] = val[*,j]. 683 | flag_method = 'method1'; 684 | if flag_method == 'method1': 685 | 686 | flag_time_it = False; 687 | if flag_time_it: 688 | time_0 = time.time(); 689 | 690 | I = torch.arange(0,size_uj_nm1,dtype=torch.int64,device=input.device); 691 | vec_ones = torch.ones(grad_output.shape[-1],dtype=torch.int64,device=input.device); 692 | II = torch.ger(I.float(),vec_ones.float()); # careful int --> float conv 693 | II = II.flatten(); 694 | JJ = ss_uui_argmax_j.flatten(); 695 | IJ_indices1 = torch.stack((II,JJ.float())).long(); 696 | 697 | i_index = torch.arange(0,grad_output.shape[-1],dtype=torch.int64,device=input.device); 698 | vec_ones = torch.ones(size_uj_nm1,dtype=torch.int64,device=input.device); 699 | KK = torch.ger(vec_ones.float(),i_index.float()); # careful int --> float conv 700 | KK = KK.flatten(); 701 | IJ_indices2 = torch.stack((II,KK)).long(); 702 | 703 | # We aim to compute dL/duj = dL/d\bar{u}_i*d\bar{u}_i/duj. 704 | # 705 | # This is done efficiently by constructing a sparse matrix using how \bar{u}_i 706 | # depends on the uj. Sometimes the same uj is the maximum for several 707 | # \bar{u}_i entries, so we add together those contributions, as would 708 | # occur in an explicit multiplication of the terms above for dL/duj. 709 | # This is achieved efficiently by coalescing a sparse tensor in PyTorch, which sums repeated entries. 710 | 711 | # We construct entries of the sparse matrix and coalesce them (add repeats). 
712 | vals = ss_grad_output[IJ_indices2[0,:],IJ_indices2[1,:]]; # @optimize, maybe just flatten 713 | N1 = size_uj_nm1; N2 = size_uj[-1]; sz = torch.Size([N1,N2]); 714 | ss_grad_input = torch.sparse.FloatTensor(IJ_indices1,vals,sz).coalesce().to_dense(); 715 | 716 | if flag_time_it: 717 | time_1 = time.time(); 718 | 719 | print("time: backward(): compute ss_grad_input = %.4e sec"%(time_1 - time_0)); 720 | 721 | elif flag_method == 'method2': 722 | II = torch.arange(0,size_uj_nm1,dtype=torch.int64); ss_grad_input = input.new_zeros(int(size_uj_nm1),size_uj[-1]); # to store dL/duj, [I,j] indexing. 723 | i_index = torch.arange(0,grad_output.shape[-1],dtype=torch.int64); 724 | 725 | # @optimize by vectorizing this calculation 726 | for I in II: 727 | for j in range(0,i_index.shape[0]): 728 | ss_grad_input[I,ss_uui_argmax_j[I,j]] += ss_grad_output[I,i_index[j]]; 729 | 730 | else: 731 | raise Exception("flag_method type not recognized.\n flag_method = %s"%flag_method); 732 | 733 | # reshape 734 | grad_input = ss_grad_input.view(*size_uj[0:input_dim - 1],size_uj[-1]); 735 | 736 | if flag_time_it: 737 | msg = 'MaxPoolOverPoints_Function->backward(): '; 738 | msg += 'elapsed_time = %.4e'%(time.time() - time_11); 739 | print(msg); 740 | 741 | return grad_input,None,None,None,None,None,None; # since no trainable parts for components of this map 742 | 743 | class ExtractFromTuple_Function(torch.autograd.Function): 744 | r"""Extracts from a tuple of outputs one of the components.""" 745 | 746 | @staticmethod 747 | def forward(ctx,input,index): 748 | r"""Extracts tuple entry with the specified index.""" 749 | ctx.atz_name = 'ExtractFromTuple_Function'; 750 | 751 | extracted = input[index]; 752 | output = extracted.clone(); # clone added for safety 753 | 754 | return output; 755 | 756 | @staticmethod 757 | def backward(ctx,grad_output): # number of returned grads needs to match the inputs of forward 758 | r"""Computes gradient of the extraction.""" 759 | 760 | raise Exception('This backward is not implemented, since PyTorch automatically handled this in the past.'); 761 | return None,None; # unreachable; documents the expected arity (input,index) 
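# Editorial sketch (not part of the package): a pure-Python reference for what
# MaxPoolOverPoints_Function computes on a single row of values, assuming the
# epsilon-neighbor lists are already known. The function names and the sample data
# are hypothetical; ties go to the first maximiser here, which need not match the
# tie-breaking of torch.max. It can serve as a small check of the gradient routing
# described in backward(): each dL/dv_i is sent to the u_j that attained the max,
# and repeated winners are summed.

```python
def max_pool_over_points(u, neighbors):
    """Return (v, argmax) with v[i] = max over j in neighbors[i] of u[j]."""
    v, argmax = [], []
    for nbrs in neighbors:
        j_star = max(nbrs, key=lambda j: u[j])  # global index of the winning point
        v.append(u[j_star])
        argmax.append(j_star)
    return v, argmax


def max_pool_backward(grad_v, argmax, num_pts1):
    """Route each dL/dv_i to the winning u_j, summing when one u_j wins several times."""
    grad_u = [0.0] * num_pts1
    for i, j_star in enumerate(argmax):
        grad_u[j_star] += grad_v[i]
    return grad_u


u = [0.5, 2.0, -1.0, 1.5]               # values u_j at four points x_j
neighbors = [[0, 1], [1, 2], [2, 3]]    # hypothetical neighbor lists for x_0, x_1, x_2
v, argmax = max_pool_over_points(u, neighbors)
grad_u = max_pool_backward([1.0, 10.0, 100.0], argmax, len(u))
print(v, argmax, grad_u)
```

# Here u_1 is the maximum in the first two neighborhoods, so it accumulates both
# incoming gradients, which is the "add repeats" step that method1 performs by coalescing.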
762 | 763 | # ==================================== 764 | # Custom Modules 765 | # ==================================== 766 | class PdbSetTraceLayer(nn.Module): 767 | r"""Allows for placing break-points within the call sequence of layers using pdb.set_trace(). Helpful for debugging networks.""" 768 | 769 | def __init__(self): 770 | r"""Initialization (currently nothing to do, but call super-class).""" 771 | super(PdbSetTraceLayer, self).__init__() 772 | 773 | def forward(self, input): 774 | r"""Executes a PDB breakpoint inside of a running network to help with debugging.""" 775 | out = input.clone(); # added clone to avoid .grad_fn overwrite 776 | pdb.set_trace(); 777 | return out; 778 | 779 | class ExtractFromTuple(nn.Module): 780 | r"""Extracts from a tuple of outputs one of the components.""" 781 | 782 | def __init__(self,index=0): 783 | r"""Initializes the index to extract.""" 784 | super(ExtractFromTuple, self).__init__() 785 | self.index = index; 786 | 787 | def forward(self, input): 788 | r"""Extracts the tuple entry with the specified index.""" 789 | extracted = input[self.index]; 790 | extracted_clone = extracted.clone(); # cloned to avoid overwrite of .grad_fn 791 | return extracted_clone; 792 | 793 | class ReshapeLayer(nn.Module): 794 | r"""Performs reshaping of a tensor output within a network.""" 795 | 796 | def __init__(self,reshape,permute=None): 797 | r"""Initializes the reshaping form to use followed by the indexing permutation to apply.""" 798 | super(ReshapeLayer, self).__init__() 799 | self.reshape = reshape; 800 | self.permute = permute; 801 | 802 | def forward(self, input): 803 | r"""Reshapes the tensor followed by applying a permutation to the indexing.""" 804 | reshape = self.reshape; 805 | permute = self.permute; 806 | A = input.contiguous(); 807 | out = A.view(*reshape); 808 | if permute is not None: 809 | out = out.permute(*permute); 810 | return out; 811 | 812 | class PermuteLayer(nn.Module): 813 | r"""Performs permutation of indices of a 
tensor output within a network.""" 814 | 815 | def __init__(self,permute=None): 816 | r"""Initializes the indexing permutation to apply to tensors.""" 817 | super(PermuteLayer, self).__init__() 818 | self.permute = permute; 819 | 820 | def forward(self, input): 821 | r"""Applies an indexing permutation to the input tensor.""" 822 | permute = self.permute; 823 | input_clone = input.clone(); # adding clone to avoid .grad_fn overwrites 824 | out = input_clone.permute(*permute); 825 | return out; 826 | 827 | class MLP_Pointwise(nn.Module): 828 | r"""Creates a collection of multilayer perceptrons (MLPs) for each output channel. 829 | The MLPs are then applied at each target point xi. 830 | """ 831 | 832 | def create_mlp_unit(self,layer_sizes,unit_name='',flag_bias=True): 833 | r"""Creates an instance of an MLP with specified layer sizes. """ 834 | layer_dict = OrderedDict(); 835 | NN = len(layer_sizes); 836 | for i in range(NN - 1): 837 | key_str = unit_name + ':hidden_layer_%.4d'%(i + 1); 838 | layer_dict[key_str] = nn.Linear(layer_sizes[i], layer_sizes[i+1],bias=flag_bias); 839 | if i < NN - 2: # last layer should be linear 840 | key_str = unit_name + ':relu_%.4d'%(i + 1); 841 | layer_dict[key_str] = nn.ReLU(); 842 | 843 | mlp_unit = nn.Sequential(layer_dict); # uses ordered dictionary to create network 844 | 845 | return mlp_unit; 846 | 847 | def __init__(self,layer_sizes,channels_in=1,channels_out=1,flag_bias=True,flag_verbose=0): 848 | r"""Initializes the structure of the pointwise MLP module with layer sizes, number of input channels, number of output channels. 849 | 850 | Args: 851 | layer_sizes (list): The size of each layer, from the per-channel input size through the hidden layers to the output size. 852 | channels_in (int): The number of input channels. 853 | channels_out (int): The number of output channels. 854 | flag_bias (bool): If the MLP should include the additive bias b added into layers. 855 | flag_verbose (int): The level of messages generated on progress of the calculation. 
856 | 857 | """ 858 | super(MLP_Pointwise, self).__init__(); 859 | 860 | self.layer_sizes = layer_sizes; 861 | self.flag_bias = flag_bias; 862 | self.depth = len(layer_sizes); 863 | 864 | self.channels_in = channels_in; 865 | self.channels_out = channels_out; 866 | 867 | # create intermediate layers 868 | mlp_list = nn.ModuleList(); 869 | layer_sizes_unit = layer_sizes.copy(); # we use inputs k*c to cross channels in practice in our unit MLPs 870 | layer_sizes_unit[0] = layer_sizes_unit[0]*channels_in; # modify the input to have proper size combined k*c 871 | for ell in range(channels_out): 872 | mlp_unit = self.create_mlp_unit(layer_sizes_unit,'unit_ell_%.4d'%ell,flag_bias=flag_bias); 873 | mlp_list.append(mlp_unit); 874 | 875 | self.mlp_list = mlp_list; 876 | 877 | def forward(self, input, params = None): 878 | r"""Applies the specified MLP pointwise to the collection of input data to produce pointwise entries of the output channels.""" 879 | # 880 | # Assumes the tensor has the form [i1,i2,...in,k,c], the last two indices are the 881 | # channel index k, and the coefficient index c, combine for ease of use, but can reshape. 882 | # We collapse input tensor with indexing [i1,i2,...in,k,c] to a [I,k*c] tensor, where 883 | # I is general index, k is channel, and c are coefficient index. 
884 | # 885 | s = input.shape; 886 | num_dim = input.dim(); 887 | 888 | if (s[-2] != self.channels_in) or (s[-1] != self.layer_sizes[0]): # check correct sized inputs 889 | print("input.shape = " + str(input.shape)); 890 | raise Exception("MLP assumes an input tensor of size [*,%d,%d]"%(self.channels_in,self.layer_sizes[0])); 891 | 892 | calc_size1 = 1.0; 893 | for d in range(num_dim-2): 894 | calc_size1 *= s[d]; 895 | calc_size1 = int(calc_size1); 896 | 897 | x = input.contiguous().view(calc_size1,s[num_dim-2]*s[num_dim-1]); # shape input to have indexing [I,k*NN + c] 898 | 899 | if params is None: 900 | output = torch.zeros((self.channels_out,x.shape[0]),device=input.device); # shape [ell,*] 901 | for ell in range(self.channels_out): 902 | mlp_q = self.mlp_list[ell]; 903 | output[ell,:] = mlp_q.forward(x).squeeze(-1); # reduce from [N,1] to [N] 904 | 905 | s = input.shape; 906 | output = output.view(self.channels_out,*s[0:num_dim-2]); # shape to have index [ell,i1,i2,...,in] 907 | nn = output.dim(); 908 | p_ind = np.arange(nn) + 1; 909 | p_ind[nn-1] = 0; 910 | p_ind = tuple(p_ind); 911 | output = output.permute(p_ind); # [*,ell] indexing of final shape 912 | else: 913 | raise Exception("Not yet implemented for setting parameters."); 914 | 915 | return output; # [*,ell] indexing of final shape 916 | 917 | def to(self, device): 918 | r"""Moves data to GPU or other specified device.""" 919 | super(MLP_Pointwise, self).to(device); 920 | for ell in range(self.channels_out): 921 | mlp_q = self.mlp_list[ell]; 922 | mlp_q.to(device); 923 | return self; 924 | 925 | 926 | class MLP1(nn.Module): 927 | r"""Creates a multilayer perceptron (MLP). 
""" 928 | 929 | def __init__(self, layer_sizes, flag_bias = True, flag_verbose=0): 930 | r"""Initializes MLP and specified layer sizes.""" 931 | super(MLP1, self).__init__(); 932 | 933 | self.layer_sizes = layer_sizes; 934 | self.flag_bias = flag_bias; 935 | self.depth = len(layer_sizes); 936 | 937 | # create intermediate layers 938 | layer_dict = OrderedDict(); 939 | NN = len(layer_sizes); 940 | for i in range(NN - 1): 941 | key_str = 'hidden_layer_%.4d'%(i + 1); 942 | layer_dict[key_str] = nn.Linear(layer_sizes[i], layer_sizes[i+1],bias=flag_bias); 943 | if i < NN - 2: # last layer should be linear 944 | key_str = 'relu_%.4d'%(i + 1); 945 | layer_dict[key_str] = nn.ReLU(); 946 | 947 | self.layers = nn.Sequential(layer_dict); # uses ordered dictionary to create network 948 | 949 | def forward(self, input, params = None): 950 | r"""Applies the MLP to the input data. 951 | 952 | Args: 953 | input (Tensor): The coefficient channel data organized as one stacked 954 | vector of size Nc*M, where Nc is number of channels and M is number of 955 | coefficients per channel. 956 | 957 | Returns: 958 | Tensor: The evaluation of the network. Returns tensor of size [batch,1]. 959 | 960 | """ 961 | # evaluate network with specified layers 962 | if params is None: 963 | eval = self.layers.forward(input); 964 | else: 965 | raise Exception("Not yet implemented for setting parameters."); 966 | 967 | return eval; 968 | 969 | def to(self, device): 970 | r"""Moves data to GPU or other specified device.""" 971 | super(MLP1, self).to(device); 972 | self.layers = self.layers.to(device); 973 | return self; 974 | 975 | class MapToPoly(nn.Module): 976 | r""" 977 | This layer processes a collection of scattered data points consisting of a collection 978 | of values :math:`u_j` at points :math:`x_j`. For a collection of target points 979 | :math:`x_i`, local least-squares problems are solved for obtaining a local representation 980 | of the data over a polynomial space. 
The layer outputs a collection of polynomial 981 | coefficients :math:`c(x_i)` at each point and the collection of target points :math:`x_i`. 982 | """ 983 | def __init__(self, porder, weight_func, weight_func_params, pts_x1, 984 | epsilon = None,pts_x2 = None,tree_points = None, 985 | device = None,flag_verbose = 0,**extra_params): 986 | r"""Initializes the layer for mapping between field data uj at points xj to the 987 | local polynomial reconstruction represented by coefficients ci at target points xi. 988 | 989 | Args: 990 | porder (int): Order of the basis to use. For polynomial basis is the degree. 991 | weight_func (func): Weight function to use. 992 | weight_func_params (dict): Weight function parameters. 993 | pts_x1 (Tensor): The collection of domain points :math:`x_j`. 994 | epsilon (float): The :math:`\epsilon`-neighborhood size to use to sort points (should be compatible with choice of weight_func_params). 995 | pts_x2 (Tensor): The collection of target points :math:`x_i`. 996 | tree_points (dict): Stored data to help speed up repeated calculations. 997 | device: Device on which to perform calculations (GPU or other, default is CPU). 998 | flag_verbose (int): Level of reporting on progress during the calculations. 999 | **extra_params: Extra parameters allowing for specifying layer name and caching mode. 
1000 | 1001 | """ 1002 | super(MapToPoly, self).__init__(); 1003 | 1004 | self.flag_verbose = flag_verbose; 1005 | 1006 | if device is None: 1007 | device = torch.device('cpu'); 1008 | 1009 | self.device = device; 1010 | 1011 | if 'name' in extra_params: 1012 | self.name = extra_params['name']; 1013 | else: 1014 | self.name = "default_name"; 1015 | 1016 | if 'flag_cache_mode' in extra_params: 1017 | flag_cache_mode = extra_params['flag_cache_mode']; 1018 | else: 1019 | flag_cache_mode = 'generate1'; 1020 | 1021 | if flag_cache_mode == 'generate1': # setup from scratch 1022 | self.porder = porder; 1023 | self.weight_func = weight_func; 1024 | self.weight_func_params = weight_func_params; 1025 | 1026 | self.pts_x1 = pts_x1; 1027 | self.pts_x2 = pts_x2; 1028 | 1029 | self.pts_x1_numpy = None; 1030 | self.pts_x2_numpy = None; 1031 | 1032 | if self.pts_x2 is None: 1033 | self.pts_x2 = pts_x1; 1034 | 1035 | self.epsilon = epsilon; 1036 | 1037 | if tree_points is None: # build kd-tree of points for neighbor listing 1038 | if self.pts_x1_numpy is None: self.pts_x1_numpy = pts_x1.cpu().numpy(); 1039 | tree_points = spatial.cKDTree(self.pts_x1_numpy); 1040 | self.tree_points = tree_points; # set outside the if-block, so a tree passed in by the caller is also stored 1041 | if device is None: 1042 | device = torch.device('cpu'); 1043 | 1044 | self.device = device; 1045 | 1046 | self.cached_data = {}; # create empty cache for storing data 1047 | generate_mapping = MapToPoly_Function.generate_mapping; 1048 | self.cached_data['map_data'] = generate_mapping(self.weight_func,self.weight_func_params, 1049 | self.porder,self.epsilon, 1050 | self.pts_x1,self.pts_x2, 1051 | self.tree_points,self.device, 1052 | self.flag_verbose); 1053 | 1054 | elif flag_cache_mode == 'load_from_file': # setup by loading data from cache file 1055 | 1056 | if 'cache_filename' in extra_params: 1057 | cache_filename = extra_params['cache_filename']; 1058 | else: 1059 | raise Exception('No cache_filename specified.'); 1060 | 1061 | self.load_cache_data(cache_filename); # load data from file 1062 | 1063 | 
else: 1064 | print("flag_cache_mode = " + str(flag_cache_mode)); 1065 | raise Exception('flag_cache_mode is invalid.'); 1066 | 1067 | def save_cache_data(self,cache_filename): 1068 | r"""Save needed matrices and related data to .pickle for later cached use. (Warning: prototype code, currently untested).""" 1069 | # collect the data to save 1070 | d = {}; 1071 | d['porder'] = self.porder; 1072 | d['epsilon'] = self.epsilon; 1073 | if self.pts_x1_numpy is None: self.pts_x1_numpy = self.pts_x1.cpu().numpy(); 1074 | d['pts_x1'] = self.pts_x1_numpy; 1075 | if self.pts_x2_numpy is None: self.pts_x2_numpy = self.pts_x2.cpu().numpy(); 1076 | d['pts_x2'] = self.pts_x2_numpy; 1077 | d['weight_func_str'] = str(self.weight_func); 1078 | d['weight_func_params'] = self.weight_func_params; 1079 | d['version'] = __version__; # Module version 1080 | d['cached_data'] = self.cached_data; 1081 | 1082 | # write the data to disk 1083 | f = open(cache_filename,'wb'); 1084 | p.dump(d,f); # save the data to file 1085 | f.close(); 1086 | 1087 | def load_cache_data(self,cache_filename): 1088 | r"""Load the needed matrices and related data from .pickle. 
(Warning: prototype code, currently untested).""" 1089 | f = open(cache_filename,'rb'); 1090 | d = p.load(f); # load the data from file 1091 | f.close(); 1092 | 1093 | print(d.keys()) 1094 | self.porder = d['porder']; 1095 | self.epsilon = d['epsilon']; 1096 | 1097 | self.weight_func = d['weight_func_str']; # note: this is the saved string form, not a callable 1098 | self.weight_func_params = d['weight_func_params']; 1099 | 1100 | self.pts_x1 = torch.from_numpy(d['pts_x1']).to(self.device); 1101 | self.pts_x2 = torch.from_numpy(d['pts_x2']).to(self.device); 1102 | 1103 | self.pts_x1_numpy = d['pts_x1']; 1104 | self.pts_x2_numpy = d['pts_x2']; 1105 | 1106 | if self.pts_x2 is None: 1107 | self.pts_x2 = self.pts_x1; 1108 | 1109 | # build kd-tree of points for neighbor listing 1110 | if self.pts_x1_numpy is None: self.pts_x1_numpy = self.pts_x1.cpu().numpy(); 1111 | self.tree_points = spatial.cKDTree(self.pts_x1_numpy); 1112 | 1113 | self.cached_data = d['cached_data']; 1114 | 1115 | def eval_poly(self,pts_x,pts_x2_i0,c_star_i0,porder=None,flag_verbose=None): 1116 | r"""Evaluates the polynomial reconstruction around a given target point pts_x2_i0.""" 1117 | if porder is None: 1118 | porder = self.porder; 1119 | if flag_verbose is None: 1120 | flag_verbose = self.flag_verbose; 1121 | return MapToPoly_Function.eval_poly(pts_x,pts_x2_i0,c_star_i0,porder,flag_verbose); # pass the evaluation back to the caller 1122 | 1123 | def forward(self, input): # define the action of this layer 1124 | r"""For a field u specified at points xj, performs the mapping to coefficients c at points xi, (uj,xj) :math:`\rightarrow` (ci,xi).""" 1125 | flag_time_it = False; 1126 | if flag_time_it: 1127 | time_1 = time.time(); 1128 | 1129 | # We evaluate the action of the function, backward will be called automatically when computing gradients. 
1130 | uj = input; 1131 | output = MapToPoly_Function.apply(uj,self.porder, 1132 | self.weight_func,self.weight_func_params, 1133 | self.pts_x1,self.epsilon,self.pts_x2, 1134 | self.cached_data,self.tree_points,self.device, 1135 | self.flag_verbose); 1136 | if flag_time_it: 1137 | msg = 'MapToPoly->forward(): '; 1138 | msg += 'elapsed_time = %.4e'%(time.time() - time_1); 1139 | print(msg); 1140 | 1141 | return output; 1142 | 1143 | def extra_repr(self): 1144 | r"""Displays information associated with this module.""" 1145 | # Display some extra information about this layer. 1146 | return 'porder={}, weight_func={}, weight_func_params={}, pts_x1={}, pts_x2={}'.format( 1147 | self.porder, self.weight_func, self.weight_func_params, self.pts_x1.shape, self.pts_x2.shape 1148 | ); 1149 | 1150 | def to(self, device): 1151 | r"""Moves data to GPU or other specified device.""" 1152 | super(MapToPoly,self).to(device); 1153 | self.pts_x1 = self.pts_x1.to(device); 1154 | self.pts_x2 = self.pts_x2.to(device); 1155 | return self; 1156 | 1157 | class MaxPoolOverPoints(nn.Module): 1158 | r"""Applies a max-pooling operation to obtain values :math:`v_i = \max_{j \in \mathcal{N}_i(\epsilon)} \{u_j\}.` """ 1159 | def __init__(self,pts_x1,epsilon=None,pts_x2=None, 1160 | indices_xj_i_cache=None,tree_points=None, 1161 | device=None,flag_verbose=0,**extra_params): 1162 | r"""Setup of max-pooling operation. 1163 | 1164 | Args: 1165 | pts_x1 (Tensor): The collection of domain points :math:`x_j`. We assume size [num_pts,num_dim]. 1166 | epsilon (float): The :math:`\epsilon`-neighborhood size to use to sort points (should be compatible with choice of weight_func_params). 1167 | pts_x2 (Tensor): The collection of target points :math:`x_i`. 1168 | indices_xj_i_cache (dict): Stored data to help speed up repeated calculations. 1169 | tree_points (dict): Stored data to help speed up repeated calculations. 1170 | device: Device on which to perform calculations (GPU or other, default is CPU). 
1171 | flag_verbose (int): Level of reporting on progress during the calculations. 1172 | **extra_params (dict): Extra parameters allowing for specifying layer name and caching mode. 1173 | 1174 | """ 1175 | super(MaxPoolOverPoints,self).__init__(); 1176 | 1177 | self.flag_verbose = flag_verbose; 1178 | 1179 | if device is None: 1180 | device = torch.device('cpu'); 1181 | 1182 | self.device = device; 1183 | 1184 | if 'name' in extra_params: 1185 | self.name = extra_params['name']; 1186 | else: 1187 | self.name = "default_name"; 1188 | 1189 | if 'flag_cache_mode' in extra_params: 1190 | flag_cache_mode = extra_params['flag_cache_mode']; 1191 | else: 1192 | flag_cache_mode = 'generate1'; 1193 | 1194 | if flag_cache_mode == 'generate1': # setup from scratch 1195 | self.pts_x1 = pts_x1; 1196 | self.pts_x2 = pts_x2; 1197 | 1198 | self.pts_x1_numpy = None; 1199 | self.pts_x2_numpy = None; 1200 | 1201 | if self.pts_x2 is None: 1202 | self.pts_x2 = pts_x1; 1203 | 1204 | self.epsilon = epsilon; 1205 | 1206 | if tree_points is None: # build kd-tree of points for neighbor listing 1207 | if self.pts_x1_numpy is None: self.pts_x1_numpy = pts_x1.cpu().numpy(); 1208 | tree_points = spatial.cKDTree(self.pts_x1_numpy); 1209 | self.tree_points = tree_points; # set outside the if-block, so a tree passed in by the caller is also stored 1210 | if indices_xj_i_cache is None: 1211 | self.indices_xj_i_cache = None; # cache the neighbor lists around each xi 1212 | else: 1213 | self.indices_xj_i_cache = indices_xj_i_cache; 1214 | 1215 | if device is None: 1216 | device = torch.device('cpu'); 1217 | 1218 | self.device = device; 1219 | 1220 | self.cached_data = {}; # create empty cache for storing data 1221 | 1222 | elif flag_cache_mode == 'load_from_file': # setup by loading data from cache file 1223 | 1224 | if 'cache_filename' in extra_params: 1225 | cache_filename = extra_params['cache_filename']; 1226 | else: 1227 | raise Exception('No cache_filename specified.'); 1228 | 1229 | self.load_cache_data(cache_filename); # load data from file 1230 | 1231 | else: 1232 | print("flag_cache_mode = " 
+ str(flag_cache_mode)); 1233 | raise Exception('flag_cache_mode is invalid.'); 1234 | 1235 | def save_cache_data(self,cache_filename): 1236 | r"""Save data to .pickle file for caching. (Warning: Prototype placeholder code.)""" 1237 | # collect the data to save 1238 | d = {}; 1239 | d['epsilon'] = self.epsilon; 1240 | if self.pts_x1_numpy is None: self.pts_x1_numpy = self.pts_x1.cpu().numpy(); 1241 | d['pts_x1'] = self.pts_x1_numpy; 1242 | if self.pts_x2_numpy is None: self.pts_x2_numpy = self.pts_x2.cpu().numpy(); 1243 | d['pts_x2'] = self.pts_x2_numpy; 1244 | d['version'] = __version__; # Module version 1245 | d['cached_data'] = self.cached_data; 1246 | 1247 | # write the data to disk 1248 | f = open(cache_filename,'wb'); 1249 | p.dump(d,f); # save the data to file 1250 | f.close(); 1251 | 1252 | def load_cache_data(self,cache_filename): 1253 | r"""Load data from .pickle file for caching. (Warning: Prototype placeholder code.)""" 1254 | f = open(cache_filename,'rb'); 1255 | d = p.load(f); # load the data from file 1256 | f.close(); 1257 | 1258 | print(d.keys()) 1259 | self.epsilon = d['epsilon']; 1260 | 1261 | self.pts_x1 = torch.from_numpy(d['pts_x1']).to(self.device); 1262 | self.pts_x2 = torch.from_numpy(d['pts_x2']).to(self.device); 1263 | 1264 | self.pts_x1_numpy = d['pts_x1']; 1265 | self.pts_x2_numpy = d['pts_x2']; 1266 | 1267 | if self.pts_x2 is None: 1268 | self.pts_x2 = self.pts_x1; 1269 | 1270 | # build kd-tree of points for neighbor listing 1271 | if self.pts_x1_numpy is None: self.pts_x1_numpy = self.pts_x1.cpu().numpy(); 1272 | self.tree_points = spatial.cKDTree(self.pts_x1_numpy); 1273 | 1274 | self.cached_data = d['cached_data']; 1275 | 1276 | def forward(self, input): # define the action of this layer 1277 | r"""Applies a max-pooling operation to obtain values :math:`v_i = \max_{j \in \mathcal{N}_i(\epsilon)} \{u_j\}.` 1278 | 1279 | Args: 1280 | input (Tensor): The collection uj of field values at the points xj. 
1281 | 1282 | Returns: 1283 | Tensor: The collection of field values vi at the target points xi. 1284 | 1285 | """ 1286 | flag_time_it = False; 1287 | if flag_time_it: 1288 | time_1 = time.time(); 1289 | 1290 | uj = input; 1291 | output = MaxPoolOverPoints_Function.apply(uj,self.pts_x1,self.epsilon,self.pts_x2, 1292 | self.indices_xj_i_cache,self.tree_points, 1293 | self.flag_verbose); 1294 | 1295 | if flag_time_it: 1296 | msg = 'MaxPoolOverPoints->forward(): '; 1297 | msg += 'elapsed_time = %.4e'%(time.time() - time_1); 1298 | print(msg); 1299 | 1300 | return output; 1301 | 1302 | def extra_repr(self): 1303 | r"""Displays information associated with this module.""" 1304 | return 'pts_x1={}, pts_x2={}'.format(self.pts_x1.shape, self.pts_x2.shape); 1305 | 1306 | def to(self, device): 1307 | r"""Moves data to GPU or other specified device.""" 1308 | super(MaxPoolOverPoints, self).to(device); 1309 | self.pts_x1 = self.pts_x1.to(device); 1310 | self.pts_x2 = self.pts_x2.to(device); 1311 | return self; 1312 | 1313 | class GMLS_Layer(nn.Module): 1314 | r"""The GMLS-Layer processes scattered data by using Generalized Moving Least 1315 | Squares (GMLS) to construct a local reconstruction of the data (here polynomials). 1316 | This is represented by coefficients that are mapped to approximate the action of 1317 | linear or non-linear operators on the input field. 1318 | 1319 | As depicted above, the architecture processes a collection of input channels 1320 | into intermediate coefficient channels. The coefficient channels are 1321 | then collectively mapped to output channels. The mappings can be any unit 1322 | for which back-propagation can be performed. This includes linear 1323 | layers or non-linear maps based on multilayer perceptrons (MLPs). 1324 | 1325 | Examples: 1326 | 1327 | Here is a typical way to construct a GMLS-Layer. This is done in 1328 | the following stages. 
1329 | 1330 | ``(i)`` Construct the scattered data locations xj, xi at which processing will occur. Here, we create points in 2D. 1331 | 1332 | >>> xj = torch.randn((100,2),device=device); xi = torch.randn((100,2),device=device); 1333 | 1334 | ``(ii)`` Construct the mapping unit that will be applied pointwise. Here we create an MLP 1335 | with Nc input coefficient channels and channels_out output channels. 1336 | 1337 | >>> layer_sizes = []; 1338 | >>> num_input = Nc*num_polys; # number of channels (Nc) X number polynomials (num_polys) (cross-channel coupling allowed) 1339 | >>> num_depth = 4; num_hidden = 100; channels_out = 16; # depth, width, number of output filters 1340 | >>> layer_sizes.append(num_polys); 1341 | >>> for k in range(num_depth): 1342 | >>> layer_sizes.append(num_hidden); 1343 | >>> layer_sizes.append(1); # a single unit always gives scalar output, we then use channels_out units. 1344 | >>> mlp_q_map1 = gmlsnets_pytorch.nn.MLP_Pointwise(layer_sizes,channels_out=channels_out); 1345 | 1346 | ``(iii)`` Create the GMLS-Layer using these components. 1347 | 1348 | >>> weight_func1 = gmlsnets_pytorch.nn.MapToPoly_Function.weight_one_minus_r; 1349 | >>> weight_func_params = {'epsilon':1e-3,'p':4}; 1350 | >>> gmls_layer_params = { 1351 | 'flag_case':'standard','porder':4, 1352 | 'mlp_q':mlp_q_map1, 1353 | 'pts_x1':xj,'pts_x2':xi,'epsilon':1e-3, 1354 | 'weight_func':weight_func1,'weight_func_params':weight_func_params, 1355 | 'device':device,'flag_verbose':0 1356 | }; 1357 | >>> gmls_layer=gmlsnets_pytorch.nn.GMLS_Layer(**gmls_layer_params); 1358 | 1359 | Here is an example of how a GMLS-Layer and other modules in this package 1360 | can be used to process scattered data. This could be part of a larger 1361 | neural network in practice (see example codes for more information). For instance, 1362 | 1363 | >>> layer1 = nn.Sequential(gmls_layer, # produces output tuple of tensors (ci,xi) with shapes ([batch,ci,xi],[xi]).
1364 | #PdbSetTraceLayer(), 1365 | ExtractFromTuple(index=0), # from output keep only the ui part and discard the xi part. 1366 | #PdbSetTraceLayer(), 1367 | PermuteLayer((0,2,1)) # organize indexing to be [batch,xi,ci], for further processing. 1368 | ).to(device); 1369 | 1370 | You can uncomment the PdbSetTraceLayer() to get breakpoints for state information and tensor shapes during processing. 1371 | The PermuteLayer() changes the order of the indexing. Also can use ReshapeLayer() to reshape the tensors, which is 1372 | especially useful for processing data related to CNNs. 1373 | 1374 | Much of the construction can be further simplified by writing a few wrapper classes for your most common use cases. 1375 | 1376 | More information also can be found in the example codes directory. 1377 | """ 1378 | def __init__(self, flag_case, porder, pts_x1, epsilon, weight_func, weight_func_params, 1379 | mlp_q = None,pts_x2 = None, device = None, flag_verbose = 0): 1380 | r""" 1381 | Initializes the GMLS layer. 1382 | 1383 | Args: 1384 | flag_case (str): Flag for the type of architecture to use (default is 'standard'). 1385 | porder (int): Order of the basis to use (polynomial degree). 1386 | pts_x1 (Tensor): The collection of domain points :math:`x_j`. 1387 | epsilon (float): The :math:`\epsilon`-neighborhood size to use to sort points (should be compatible with choice of weight_func_params). 1388 | weight_func (func): Weight function to use. 1389 | weight_func_params (dict): Weight function parameters. 1390 | mlp_q (module): Mapping q unit for computing :math:`q(c)`, where c are the coefficients. 1391 | pts_x2 (Tensor): The collection of target points :math:`x_i`. 1392 | device: Device on which to perform calculations (GPU or other, default is CPU). 1393 | flag_verbose (int): Level of reporting on progress during the calculations. 
1394 | """ 1395 | super(GMLS_Layer, self).__init__(); 1396 | 1397 | if flag_case is None: 1398 | self.flag_case = 'standard'; 1399 | else: 1400 | self.flag_case = flag_case; 1401 | 1402 | if device is None: 1403 | device = torch.device('cpu'); 1404 | 1405 | self.device = device; 1406 | 1407 | if self.flag_case == 'standard': 1408 | tree_points = None; 1409 | self.MapToPoly_1 = MapToPoly(porder, weight_func, weight_func_params, 1410 | pts_x1, epsilon, pts_x2, tree_points, 1411 | device, flag_verbose); 1412 | 1413 | if mlp_q is None: # if not specified then create some default custom layers 1414 | raise Exception("Need to specify the mlp_q module for mapping coefficients to output."); 1415 | else: # in this case initialized outside 1416 | self.mlp_q = mlp_q; 1417 | 1418 | else: 1419 | print("flag_case = " + str(flag_case)); 1420 | print("self.flag_case = " + str(self.flag_case)); 1421 | raise Exception('flag_case not valid.'); 1422 | 1423 | def forward(self, input): 1424 | r"""Computes GMLS-Layer processing scattered data input field uj to obtain output field vi. 1425 | 1426 | Args: 1427 | input (Tensor): Input channels uj organized in the shape [batch,xj,uj]. 1428 | 1429 | Returns: 1430 | tuple: The output channels and point locations (vi,xi). The field vi = q(ci). 1431 | 1432 | """ 1433 | if self.flag_case == 'standard': 1434 | 1435 | map_output = self.MapToPoly_1.forward(input); 1436 | c_star_i = map_output[0]; 1437 | pts_x2 = map_output[1]; 1438 | 1439 | # MLP should apply across all channels and coefficients (coeff capture spatial, like kernel) 1440 | fc_input = c_star_i.permute((0,2,1,3)); # we organize as [batchI,ptsI,channelsI,coeffI] 1441 | 1442 | # We assume MLP can process channelI*Nc + coeffI. 1443 | # We assume output of out = fc, has shape [batchI,ptsI,channelsNew] 1444 | # Outside routines can reshape that into an nD array again for structure samples or use over scattered samples. 
1445 | q_of_c_star_i = self.mlp_q.forward(fc_input); 1446 | 1447 | pts_x2_p = None; # currently returns None to simplify back-prop and debugging, but could just return the pts_x2. 1448 | 1449 | return_vals = q_of_c_star_i, pts_x2_p; 1450 | 1451 | return return_vals; 1452 | 1453 | def to(self, device): 1454 | r"""Moves data to GPU or other specified device.""" 1455 | super(GMLS_Layer, self).to(device); 1456 | self.MapToPoly_1 = self.MapToPoly_1.to(device); 1457 | self.mlp_q = self.mlp_q.to(device); 1458 | return self; 1459 | 1460 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Collection of utility routines. 3 | """ 4 | 5 | # Authors: B.J. Gross and P.J. Atzberger 6 | # Website: http://atzberger.org/ 7 | 8 | import torch; 9 | import torch.nn as nn; 10 | import numpy as np; 11 | 12 | import pdb 13 | import time 14 | 15 | #****************************************** 16 | # Custom Functions 17 | #****************************************** 18 | class PeriodicPad2dFunc(torch.autograd.Function): 19 | """Performs periodic padding/tiling of a 2d lattice.""" 20 | 21 | @staticmethod 22 | def forward(ctx, input, pad=None,flag_coords=False): 23 | """ 24 | Periodically tiles the input into a 2d array for 3x3 repeat pattern. 25 | This allows for a quick way for handling periodic boundary conditions 26 | on the units cell. 27 | 28 | Args: 29 | input (Tensor): A tensor of size [nbatch,nchannel,nx,ny] or [nx,ny]. 30 | pad (float): The pad value to use. 31 | flag_coord (bool): If beginning components are coordinates (x,y,z,...,u). 32 | 33 | We then adjust (x,y,u) --> (x + i, y + j, u) for image (i,j). 34 | 35 | Returns: 36 | output (Tensor): A tensor of the same size [nbatch,nchannel,3*nx,3*ny] or [3*nx,3*ny]. 37 | 38 | """ 39 | # We use alternative by concatenating the arrays together to tile. 
40 | # Process Tensor inputs of the shape [nbatch,nchannel,nx,ny]. 41 | # We also allow the case of [nx,ny] for single input. 42 | nd = input.dim(); 43 | if (nd > 4): 44 | raise Exception('Expects tensor that has number of dimensions <= 4 (dim = {}).'.format(nd)); 45 | 46 | if (nd < 2): 47 | raise Exception('Expects tensor that has at least number of dimensions >= 2 (dim = {}).'.format(nd)); 48 | 49 | a = input; 50 | if (nd == 2): 51 | # add extra dimensions so we can process all tensors the same way 52 | a = input.unsqueeze(0).unsqueeze(0); # size --> [1,1,nx,ny] 53 | 54 | w1 = w2 = a; 55 | aa = torch.cat([w1,a,w2],dim=2); 56 | h1 = h2 = aa; 57 | output = torch.cat([h1,aa,h2],dim=3); 58 | 59 | if flag_coords: # indicates (x,y,u) input, so need 60 | # to adjust x + i-1 and y + j-1, i,j = 0,1,2 to extend. 61 | coordI1 = 0; coordI2 = 1; #x and y components 62 | N1 = output.shape[2]//3; # block size 63 | N2 = output.shape[3]//3; 64 | for j in range(0,3): 65 | blockI2 = np.arange(N2*j,N2*(j + 1),dtype=int); 66 | output[:,coordI1,:,blockI2] += (j - 1); 67 | 68 | for i in range(0,3): 69 | blockI1 = np.arange(N1*i,N1*(i + 1),dtype=int); 70 | output[:,coordI2,blockI1,:] += (i - 1); 71 | 72 | if (nd == 2): # if only [nx,ny] then squeeze out the extra dimensions 73 | output = output.squeeze(0).squeeze(0); 74 | 75 | ctx.pad = pad; 76 | ctx.size = input.size(); 77 | ctx.numel = input.numel(); 78 | ctx.num_tiles = 3; 79 | 80 | return output; 81 | 82 | @staticmethod 83 | def backward(ctx, grad_output): 84 | r"""Compute df/dx from df/dy using the Chain Rule df/dx = df/dy*dy/dx. 85 | For periodic padding we use df/dx = sum_{y_i ~ x} df/dy_i, where 86 | the y_i ~ x are all points y_i that are equivalent to x under the periodic extension.
87 | """ 88 | num_tiles = ctx.num_tiles; 89 | 90 | nd = grad_output.dim(); 91 | if (nd > 4): 92 | raise Exception('Expects tensor that has number of dimensions <= 4 (dim = {}).'.format(nd)); 93 | 94 | if (nd < 2): 95 | raise Exception('Expects tensor that has at least number of dimensions >= 2 (dim = {}).'.format(nd)); 96 | 97 | b = grad_output; # short-hand for grad_output (size of output) 98 | if (nd == 2): 99 | # add extra dimensions so we can process all tensors the same way 100 | b = b.unsqueeze(0).unsqueeze(0); # size --> [1,1,3*nx,3*ny] 101 | 102 | # construct indices to contract the periodic images in each dimension 103 | s = b.size(); nx = s[2]//num_tiles; ny = s[3]//num_tiles; # unit cell size 104 | ind_r = torch.arange(0, nx*num_tiles, dtype=torch.int64, device=b.device); 105 | ind_r = ind_r.fmod(nx); # repeat indices number of rows [0,1,..nrow-1,0,...nrow-1]. 106 | ind_r = ind_r.view(-1); 107 | 108 | ind_c = torch.arange(0, ny*num_tiles, dtype=torch.int64, device=b.device); 109 | ind_c = ind_c.fmod(ny); # repeat indices number of cols [0,1,..ncol-1,0,...ncol-1]. 110 | ind_c = ind_c.view(-1); 111 | 112 | # rows are contracted first, so c has size [nbatch,nchannel,nx,3*ny] 113 | c = b.new_zeros(s[0],s[1],nx,s[3]); 114 | 115 | c = c.index_add(2,ind_r,b); # add the rows together to start contracting periodicity 116 | 117 | d = b.new_zeros(s[0],s[1],nx,ny); 118 | d = d.index_add(3,ind_c,c); # add the cols together to contract periodicity 119 | grad_input = d; # Note, this includes contraction of grad_output already, so is df/dx. 120 | 121 | if (nd == 2): # if only [nx,ny] then squeeze out extra dimensions 122 | grad_input = grad_input.squeeze(0).squeeze(0); 123 | 124 | return grad_input, None, None; # no gradients for pad and flag_coords 125 | 126 | #------------------ 127 | class ExtractUnitCell2dFunc(torch.autograd.Function): 128 | r"""Extracts the 2d unit cell from periodic tiling.""" 129 | 130 | @staticmethod 131 | def forward(ctx, input, pad=None): 132 | r"""Extracts the 2d unit cell from periodic tiling.
133 | 134 | Args: 135 | input (Tensor): Tensor of size [nbatch,nchannel,nx,ny] or [nx,ny]. 136 | 137 | Returns: 138 | Tensor: The unit cell of size [nbatch,nchannel,nx/3,ny/3] or [nx/3,ny/3]. 139 | 140 | """ 141 | num_tiles = 3; 142 | ctx.num_tiles = num_tiles; 143 | 144 | nd = input.dim(); 145 | if (nd > 4): 146 | raise Exception('Expects tensor that has number of dimensions <= 4 (dim = {}).'.format(nd)); 147 | 148 | if (nd < 2): 149 | raise Exception('Expects tensor that has at least number of dimensions >= 2 (dim = {}).'.format(nd)); 150 | 151 | a = input; # short-hand 152 | if (nd == 2): 153 | # add extra dimensions so we can process all tensors the same way 154 | a = a.unsqueeze(0).unsqueeze(0); # size --> [1,1,nx,ny] 155 | 156 | chunks_col = torch.chunk(a,num_tiles,dim=2); 157 | aa = torch.chunk(chunks_col[1],num_tiles,dim=3); # extract middle of middle 158 | 159 | output = aa[1]; # choose middle one (assumes num_tiles==3) 160 | 161 | if (nd == 2): # if only [nx,ny] then squeeze out extra dimensions 162 | output = output.squeeze(0).squeeze(0); 163 | 164 | return output; 165 | 166 | @staticmethod 167 | def backward(ctx, grad_output): 168 | r""" Compute df/dx from df/dy using the Chain Rule df/dx = df/dy*dy/dx. 169 | For periodic padding we use df/dx = sum_{y_i ~ x} df/dy_i, where 170 | the y_i ~ x are all points y_i that are equivalent to x under the periodic extension. 171 | 172 | For extracting from periodic tiling the unit cell, the derivatives dy/dx are zero 173 | unless x is within the unit cell. The gradient has the block form df/dx = [[Z,Z,Z],[Z,W,Z],[Z,Z,Z]], 174 | where W holds the derivative values in the unit cell (df/dy) and Z is the zero matrix.
175 | """ 176 | num_tiles = ctx.num_tiles; 177 | 178 | nd = grad_output.dim(); 179 | if (nd > 4): 180 | raise Exception('Expects tensor that has number of dimensions <= 4 (dim = {}).'.format(nd)); 181 | 182 | if (nd < 2): 183 | raise Exception('Expects tensor that has at least number of dimensions >= 2 (dim = {}).'.format(nd)); 184 | 185 | W = grad_output; # short-hand for grad_output (size of output) 186 | if (nd == 2): 187 | # add extra dimensions so we can process all tensors the same way 188 | W = W.unsqueeze(0).unsqueeze(0); # size --> [1,1,nx,ny] 189 | 190 | s = W.size(); 191 | Z = grad_output.new_zeros(s[0],s[1],s[2],s[3]); 192 | A1 = torch.cat([Z,Z,Z],dim=3); 193 | B1 = torch.cat([Z,W,Z],dim=3); 194 | bb = torch.cat([A1,B1,A1],dim=2); 195 | 196 | grad_input = bb; 197 | 198 | if (nd == 2): # if only [nx,ny] then squeeze out extra dimensions 199 | grad_input = grad_input.squeeze(0).squeeze(0); 200 | 201 | return grad_input, None; # no gradient for pad 202 | 203 | #****************************************** 204 | # Custom Modules 205 | #****************************************** 206 | class PeriodicPad2d(nn.Module): 207 | def __init__(self,flag_coords=False): 208 | r"""Setup for computing the periodic tiling.""" 209 | super(PeriodicPad2d, self).__init__() 210 | self.flag_coords = flag_coords; 211 | 212 | def forward(self, input): 213 | r"""Compute the periodic padding of the input. """ 214 | return PeriodicPad2dFunc.apply(input,None,self.flag_coords); 215 | 216 | def extra_repr(self): 217 | r"""Displays some of the information associated with the module. """ 218 | return 'PeriodicPad2d: (no internal parameters)'; 219 | 220 | class ExtractUnitCell2d(nn.Module): 221 | def __init__(self): 222 | super(ExtractUnitCell2d, self).__init__() 223 | 224 | def forward(self, input): 225 | r"""Extracts the unit cell from the periodic tiling of the input.""" 226 | return ExtractUnitCell2dFunc.apply(input); 227 | 228 | def extra_repr(self): 229 | r"""Displays some of the information associated with the module.
""" 230 | return 'ExtractUnitCell2d: (no internal parameters)'; 231 | 232 | -------------------------------------------------------------------------------- /src/vis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Collection of routines helpful for visualizing results and generating figures. 3 | """ 4 | 5 | # Authors: B.J. Gross and P.J. Atzberger 6 | # Website: http://atzberger.org/ 7 | 8 | import matplotlib; 9 | import matplotlib as mtpl; 10 | import matplotlib.pyplot as plt; 11 | import matplotlib.gridspec as gridspec; 12 | 13 | import numpy as np; 14 | 15 | #---------------------------- 16 | def plot_samples_u_f_1d(u_list,f_list,np_xj,np_xi,rows=4,cols=6, 17 | figsize = (20*0.9,10*0.9),title="Data Samples: u, f=L[u]", 18 | xlabel='x',ylabel='', 19 | left=0.125,bottom=0.1,right=0.9, top=0.94,wspace=0.4,hspace=0.4, 20 | fontsize=16,y=1.00,flag_draw=True,**extra_params): 21 | 22 | r"""Plots a collection of data samples in a panel.""" 23 | # -- 24 | fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=figsize,sharey=False); 25 | fig.subplots_adjust(left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace); 26 | plt.suptitle(title, fontsize=fontsize,y=y); 27 | 28 | # -- 29 | I1 = 0; I2 = 0; 30 | for i1 in range(0,rows): 31 | for i2 in range(0,cols): 32 | ax = axs[i1,i2]; 33 | 34 | if i2 % 2 == 0: 35 | xx = np_xj[:,0]; yy = u_list[I1].numpy()[0,:]; 36 | #yy = yy.squeeze(0); 37 | ax.plot(xx,yy,'m.-'); 38 | if i1 == 0: 39 | ax.set_title('u'); 40 | I1 += 1; 41 | else: 42 | xx = np_xi[:,0]; yy = f_list[I2].numpy()[0,:]; 43 | #yy = yy.squeeze(0); 44 | ax.plot(xx,yy,'r.-'); 45 | if i1 == 0: 46 | ax.set_title('f'); 47 | I2 += 1; 48 | 49 | ax.set_xlabel(xlabel); 50 | ax.set_ylabel(ylabel); 51 | 52 | if flag_draw: 53 | plt.draw(); 54 | 55 | def plot_samples_u_f_fp_1d(u_list,f_target_list,f_pred_list,np_xj,np_xi,rows=4,cols=6, 56 | figsize = (20*0.9,10*0.9),title="Data Samples: u, f=L[u]", 57 | 
xlabel='x',ylabel='', 58 | left=0.125,bottom=0.1,right=0.9, top=0.94,wspace=0.4,hspace=0.4, 59 | fontsize=16,y=1.00,flag_draw=True,**extra_params): 60 | 61 | r"""Plots a collection of data samples and predictions in a panel.""" 62 | # -- 63 | fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=figsize,sharey=False); 64 | fig.subplots_adjust(left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace); 65 | plt.suptitle(title, fontsize=fontsize,y=y); 66 | 67 | # -- 68 | I1 = 0; I2 = 0; 69 | for i1 in range(0,rows): 70 | for i2 in range(0,cols): 71 | ax = axs[i1,i2]; 72 | 73 | if i2 % 2 == 0: 74 | ax.plot(np_xj[:,0],u_list[I1].numpy()[0,:],'m.-'); 75 | if i1 == 0: 76 | ax.set_title('u'); 77 | I1 += 1; 78 | ax.set_xlabel(xlabel); 79 | ax.set_ylabel(ylabel); 80 | else: 81 | ax.plot(np_xj[:,0],f_target_list[I2].numpy()[0,:],'r.-'); 82 | ax.plot(np_xj[:,0],f_pred_list[I2].numpy()[0,:],'b.-'); 83 | if i1 == 0: 84 | ax.set_title('f'); 85 | I2 += 1; 86 | ax.set_xlabel(xlabel); 87 | ax.set_ylabel(ylabel); 88 | 89 | if flag_draw: 90 | plt.draw(); 91 | 92 | def plot_dataset_diffOp1(dataset,np_xj=None,np_xi=None,rows=4,cols=6,II=None, 93 | figsize=(20*0.9,10*0.9),title="Data Samples: u, f=L[u]", 94 | xlabel='x',ylabel='', 95 | left=0.125,bottom=0.1,right=0.9, top=0.94,wspace=0.4,hspace=0.4, 96 | fontsize=16,y=1.00,flag_draw=True,**extra_params): 97 | 98 | r"""Plots a collection of data samples in a panel.""" 99 | u_list = []; f_list = []; f_pred_list = []; 100 | num_samples = len(dataset); 101 | for I in np.arange(0,min(num_samples,int(rows*cols/2))): 102 | if II is None: 103 | u_list.append(dataset[I][0].cpu()); # just make plain lists for convenience 104 | f_list.append(dataset[I][1].cpu()); 105 | else: 106 | u_list.append(dataset[II[I]][0].cpu()); 107 | f_list.append(dataset[II[I]][1].cpu()); 108 | 109 | # -- 110 | plot_samples_u_f_1d(u_list=u_list,f_list=f_list,np_xj=np_xj,np_xi=np_xi, 111 | rows=rows,cols=cols, 112 | figsize=figsize,title=title, 113 | 
xlabel=xlabel,ylabel=ylabel, 114 | left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace, 115 | fontsize=fontsize,y=y,flag_draw=flag_draw,**extra_params); 116 | 117 | #---------------------------- 118 | def plot_samples_u_f_2d(u_list,f_list,np_xj,np_xi,channelI_u=0,channelI_f=0,rows=4,cols=6, 119 | figsize = (20*0.9,10*0.9),title="Data Samples: u, f=L[u]", 120 | xlabel='x',ylabel='', 121 | left=0.125, bottom=0.1, right=0.9, top=0.94, wspace=0.01, hspace=0.1, 122 | fontsize=16,y=1.00,flag_draw=True,**extra_params): 123 | 124 | r"""Plots a collection of data samples in a panel.""" 125 | # -- 126 | fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=figsize,sharey=False); 127 | fig.subplots_adjust(left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace); 128 | plt.suptitle(title, fontsize=fontsize,y=y); 129 | 130 | # -- 131 | for ax in axs.flatten(): 132 | ax.axis('off'); 133 | 134 | # -- 135 | I1 = 0; I2 = 0; Ic_u = channelI_u;Ic_f = channelI_f; 136 | for i1 in range(0,rows): 137 | for i2 in range(0,cols): 138 | ax = axs[i1,i2]; 139 | 140 | if i2 % 2 == 0: 141 | uu = u_list[I1][Ic_u,:,:]; 142 | ax.imshow(uu,cmap='Blues_r'); 143 | if i1 == 0: 144 | ax.set_title('u'); 145 | I1 += 1; 146 | else: 147 | ff = f_list[I2][Ic_f,:,:]; 148 | ax.imshow(ff,cmap='Purples_r'); 149 | if i1 == 0: 150 | ax.set_title('f'); 151 | I2 += 1; 152 | 153 | if flag_draw: 154 | plt.draw(); 155 | 156 | def plot_samples_u_f_fp_2d(u_list,f_target_list,f_pred_list,np_xj,np_xi,channelI_u=0,channelI_f=0,rows=4,cols=6, 157 | figsize = (20*0.9,10*0.9),title="Data Samples: u, f=L[u]", 158 | xlabel='x',ylabel='', 159 | left=0.125, bottom=0.1, right=0.9, top=0.94, wspace=0.01, hspace=0.1, 160 | fontsize=16,y=1.00,flag_draw=True,**extra_params): 161 | 162 | r"""Plots a collection of data samples and predictions in a panel.""" 163 | # -- 164 | fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=figsize,sharey=False); 165 | 
fig.subplots_adjust(left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace); 166 | plt.suptitle(title, fontsize=fontsize,y=y); 167 | 168 | # -- 169 | for ax in axs.flatten(): 170 | ax.axis('off'); 171 | 172 | # -- 173 | Ic_u = channelI_u; Ic_f = channelI_f; 174 | I1 = 0; I2 = 0; I3 = 0; 175 | for i1 in range(0,rows): 176 | for i2 in range(0,cols): 177 | ax = axs[i1,i2]; 178 | 179 | if i2 % 3 == 0: 180 | uu = u_list[I1][Ic_u,:,:]; 181 | ax.imshow(uu,cmap='Blues_r'); 182 | if i1 == 0: 183 | ax.set_title('u'); 184 | I1 += 1; 185 | elif i2 % 3 == 1: 186 | ff = f_pred_list[I2][Ic_f,:,:]; 187 | ax.imshow(ff,cmap='Purples_r'); 188 | if i1 == 0: 189 | ax.set_title('f:predicted'); 190 | I2 += 1; 191 | elif i2 % 3 == 2: 192 | ff = f_target_list[I3][Ic_f,:,:]; 193 | ax.imshow(ff,cmap='Purples_r'); 194 | if i1 == 0: 195 | ax.set_title('f:target'); 196 | I3 += 1; 197 | 198 | if flag_draw: 199 | plt.draw(); 200 | 201 | def plot_dataset_diffOp2(dataset,np_xj=None,np_xi=None,channelI_u=0,channelI_f=0,rows=4,cols=6,II=None, 202 | figsize=(20*0.9,10*0.9),title="Data Samples: u, f=L[u]", 203 | xlabel='x',ylabel='', 204 | left=0.125, bottom=0.1, right=0.9, top=0.94, wspace=0.01, hspace=0.1, 205 | fontsize=16,y=1.00,flag_draw=True,**extra_params): 206 | 207 | 208 | 209 | r"""Plots a collection of data samples in a panel.""" 210 | u_list = []; f_list = []; f_pred_list = []; 211 | num_samples = len(dataset); 212 | for I in np.arange(0,min(num_samples,int(rows*cols/2))): 213 | if II is None: 214 | u_list.append(dataset[I][0].cpu()); # just make plain lists for convenience 215 | f_list.append(dataset[I][1].cpu()); 216 | else: 217 | u_list.append(dataset[II[I]][0].cpu()); 218 | f_list.append(dataset[II[I]][1].cpu()); 219 | 220 | # -- 221 | plot_samples_u_f_2d(u_list=u_list,f_list=f_list,np_xj=np_xj,np_xi=np_xi, 222 | channelI_u=channelI_u,channelI_f=channelI_f, 223 | rows=rows,cols=cols, 224 | figsize=figsize,title=title, 225 | xlabel=xlabel,ylabel=ylabel, 226 | 
left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace, 227 | fontsize=fontsize,y=y,flag_draw=flag_draw,**extra_params); 228 | 229 | #---------------------------- 230 | def plot_images_in_array(axs,img_arr,label_arr=None,cmap=None, **extra_params): 231 | r"""Plots an array of images as a collection of panels.""" 232 | 233 | numSamples = len(img_arr); 234 | sqrtS = int(np.sqrt(numSamples)); 235 | 236 | # Default values 237 | flag_plot_rect = False; 238 | list_correct_class = None; 239 | 240 | if 'list_correct_class' in extra_params: 241 | list_correct_class = extra_params['list_correct_class']; 242 | flag_plot_rect = True; 243 | 244 | if 'flag_plot_rect' in extra_params: 245 | flag_plot_rect = extra_params['flag_plot_rect']; 246 | 247 | I = 0; 248 | for i in range(0,sqrtS): 249 | for j in range(0,sqrtS): 250 | ax = axs[i][j]; 251 | img = img_arr[I]; 252 | 253 | if len(img.shape) >= 3: 254 | if img.shape[2] == 1: # For BW case of (Nx,Ny,1) --> (Nx,Ny), RGB has (Nx,Ny,3). 255 | img = img.squeeze(2); 256 | 257 | if cmap is not None: 258 | ax.imshow(img, cmap=cmap); 259 | else: 260 | ax.imshow(img); 261 | 262 | if label_arr is not None: 263 | ax.set_title("%s"%label_arr[I]); 264 | 265 | ax.set_xticks([]); 266 | ax.set_yticks([]); 267 | 268 | if flag_plot_rect: 269 | 270 | if list_correct_class[I]: 271 | edge_color = 'g'; 272 | else: 273 | edge_color = 'r'; 274 | 275 | # draw a rectangle 276 | Nx = img.shape[0]; Ny = img.shape[1]; 277 | rectangle = mtpl.patches.Rectangle((0,0),Nx-1,Ny-1, 278 | linewidth=5,edgecolor=edge_color, 279 | facecolor='none'); 280 | ax.add_patch(rectangle); 281 | 282 | I += 1; 283 | 284 | def plot_image_array(img_arr,label_arr=None,title=None,figSize=(18,18), 285 | title_yp=0.95,cmap="gray",**extra_params): 286 | r"""Plots an array of images as a collection of panels.""" 287 | 288 | # determine number of images we need to plot 289 | numSamples = len(img_arr); 290 | sqrtS = int(np.sqrt(numSamples)); 291 | rows = sqrtS; 292 | cols 
= sqrtS; 293 | 294 | fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=figSize); 295 | 296 | plot_images_in_array(axs,img_arr,label_arr,cmap=cmap,**extra_params); 297 | 298 | if title is None: 299 | plt.suptitle("Collection of Images", fontsize=18,y=title_yp); 300 | else: 301 | plt.suptitle(title, fontsize=18,y=title_yp); 302 | 303 | def plot_gridspec_image_array(outer_h,img_arr,label_arr=None,title=None,figSize=(18,18),title_yp=0.95,cmap="gray",title_x=0.0,title_y=1.0,title_fsize=14): 304 | r"""Plots an array of images as a collection of panels.""" 305 | fig = plt.gcf(); 306 | 307 | sqrtS = int(np.sqrt(len(img_arr))); 308 | 309 | inner = gridspec.GridSpecFromSubplotSpec(sqrtS, sqrtS, 310 | subplot_spec=outer_h, wspace=0.1, hspace=0.1); 311 | 312 | # collect the axes 313 | axs =[]; 314 | I = 0; 315 | for i in range(0,sqrtS): 316 | axs_r = []; 317 | for j in range(0,sqrtS): 318 | ax = plt.Subplot(fig, inner[I]); 319 | fig.add_subplot(ax); 320 | axs_r.append(ax); 321 | I += 1; 322 | 323 | axs.append(axs_r); 324 | 325 | # plot the images 326 | plot_images_in_array(axs,img_arr,cmap="gray"); 327 | 328 | if title is None: 329 | a = 1; 330 | else: 331 | axs[0][0].text(title_x,title_y,title,fontsize=title_fsize); 332 | 333 | #---------------------------- 334 | def save_fig(baseFilename,extraLabel,flag_verbose=0,dpi_set=200,flag_pdf=False): 335 | r"""Saves figures to disk.""" 336 | 337 | fig = plt.gcf(); 338 | fig.patch.set_alpha(1.0); 339 | fig.patch.set_facecolor((1.0,1.0,1.0,1.0)); 340 | 341 | if flag_pdf: 342 | saveFilename = '%s%s.pdf'%(baseFilename,extraLabel); 343 | if flag_verbose > 0: 344 | print('saveFilename = %s'%saveFilename); 345 | plt.savefig(saveFilename, format='pdf',dpi=dpi_set,facecolor=(1,1,1,1),alpha=1.0); 346 | 347 | saveFilename = '%s%s.png'%(baseFilename,extraLabel); 348 | if flag_verbose > 0: 349 | print('saveFilename = %s'%saveFilename); 350 | plt.savefig(saveFilename, format='png',dpi=dpi_set,facecolor=(1,1,1,1),alpha=1.0); 351 | 352 | 353 | 
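The image-array helpers above (`plot_images_in_array`, `plot_image_array`, `plot_gridspec_image_array`) lay panels out on a square grid with side `int(np.sqrt(len(img_arr)))`, so any images beyond `side*side` are not drawn. The following is a minimal pure-Python sketch of that indexing, not code from `vis.py`: the helper name `square_grid_layout` is hypothetical, and `math.isqrt` stands in for `int(np.sqrt(...))`.

```python
import math

def square_grid_layout(num_images):
    """Return (side, cells, num_dropped) for a square panel grid.

    Hypothetical helper, not part of vis.py: it mirrors how the plotting
    routines walk a side x side grid with a running image index I.
    """
    side = math.isqrt(num_images)  # stands in for int(np.sqrt(num_images))
    cells = []
    I = 0
    for i in range(side):
        for j in range(side):
            cells.append((i, j, I))  # panel (row i, col j) shows image I
            I += 1
    return side, cells, num_images - side * side

side, cells, dropped = square_grid_layout(10)
print(side, len(cells), dropped)  # 3 9 1
```

Passing a perfect-square number of images (4, 9, 16, ...) avoids silently dropping any of them.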
--------------------------------------------------------------------------------