├── LICENSE.txt
├── README.md
├── docs
│   └── README.md
├── examples
│   ├── README.md
│   ├── regression_diff_op_1d.ipynb
│   ├── regression_diff_op_1d.py
│   ├── regression_diff_op_2d.ipynb
│   ├── regression_diff_op_2d.py
│   └── z_doc_img
│       ├── diff_op_1d_1.png
│       └── diff_op_2d_1.png
├── gmlsnets_pytorch-1.0.0.tar.gz
├── misc
│   └── overview.png
└── src
    ├── __init__.py
    ├── dataset.py
    ├── nn.py
    ├── util.py
    └── vis.py
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, Paul J. Atzberger
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GMLS-Nets
2 |
3 |
4 |
5 |
6 |
7 | [Examples](https://github.com/atzberg/gmls-nets/tree/master/examples) | [Documentation](http://web.math.ucsb.edu/~atzberg/gmlsnets_docs/html/index.html)
8 |
9 | __GMLS-Nets (PyTorch Implementation)__
10 |
11 | The package provides machine learning methods for learning features from scattered, unstructured data sets using Generalized Moving Least Squares (GMLS). It provides techniques for generalizing approaches such as Convolutional Neural Networks (CNNs), which exploit translational and other symmetries, to unstructured data. The GMLS-Nets package also provides approaches for learning differential operators, PDEs, and other features from scattered data.
12 |
13 | __Quick Start__
14 |
15 | *Method 1:* Install for python using pip
16 |
17 | ```pip install -U gmlsnets-pytorch```
18 |
19 | For use of the package, see the [examples page](https://github.com/atzberg/gmls-nets/tree/master/examples). More information on the structure of the package can also be found on the [documentation page](https://github.com/atzberg/gmls-nets/tree/master/docs).
20 |
21 | If you previously installed the package, please update to the latest version using ```pip install --upgrade gmlsnets-pytorch```.
22 |
23 | __Manual Installation__
24 |
25 | *Method 2:* Download the [gmlsnets_pytorch-1.0.0.tar.gz](https://github.com/atzberg/gmls-nets-testing/blob/master/gmlsnets_pytorch-1.0.0.tar.gz) file above, then uncompress
26 |
27 | ``tar -xvf gmlsnets_pytorch-1.0.0.tar.gz``
28 |
29 | For a local install, please be sure to add the location of the package's base directory to your scripts by adding
30 |
31 | ``sys.path.append('package-path-location-here');``
32 |
33 | Note the package resides in the sub-directory ``./gmlsnets_pytorch-1.0.0/gmlsnets_pytorch/``
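For a quick sanity check that a local install is visible to Python, something like the following should work (a minimal sketch; adjust the path to wherever you uncompressed the archive):

```python
import sys
sys.path.append('./gmlsnets_pytorch-1.0.0')  # parent directory containing gmlsnets_pytorch/
import gmlsnets_pytorch as gmlsnets
print(gmlsnets.__version__)  # should report 1.0.0
```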
34 |
35 | __Packages__
36 |
37 | Please be sure to install the [PyTorch](https://pytorch.org/) package (version >= 1.2.0) with Python 3 (ideally >= 3.7). Also, be sure to install the following packages: numpy>=1.16, scipy>=1.3, matplotlib>=3.0.
38 |
39 | __Use__
40 |
41 | For examples and documentation, see
42 |
43 | [Examples](https://github.com/atzberg/gmls-nets/tree/master/examples)
44 |
45 | [Documentation](http://web.math.ucsb.edu/~atzberg/gmlsnets_docs/html/index.html)
46 |
47 | __Additional Information__
48 |
49 | If you find these codes or methods helpful for your project, please cite:
50 |
51 | *GMLS-Nets: A Framework for Learning from Unstructured Data,*
52 | N. Trask, R. G. Patel, B. J. Gross, and P. J. Atzberger, arXiv:1909.05371, (2019), [[arXiv]](https://arxiv.org/abs/1909.05371).
53 | ```
54 | @article{trask_patel_gross_atzberger_GMLS_Nets_2019,
55 |   title={GMLS-Nets: A framework for learning from unstructured data},
56 |   author={Nathaniel Trask and Ravi G. Patel and Ben J. Gross and Paul J. Atzberger},
57 |   journal={arXiv:1909.05371},
58 |   month={September},
59 |   year={2019},
60 |   url={https://arxiv.org/abs/1909.05371}
61 | }
62 | ```
63 | For a [TensorFlow](https://www.tensorflow.org/) implementation of GMLS-Nets, see https://github.com/rgp62/gmls-nets.
64 |
65 | __Acknowledgements__
66 | We gratefully acknowledge support from DOE Grant ASCR PHILMS DE-SC0019246.
67 |
68 | ----
69 |
70 |
71 | [Examples](https://github.com/atzberg/gmls-nets/tree/master/examples) | [Documentation](http://web.math.ucsb.edu/~atzberg/gmlsnets_docs/html/index.html) | [Atzberger Homepage](http://atzberger.org/)
72 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | For more information on how to use this package, please see the [Documentation](http://web.math.ucsb.edu/~atzberg/gmlsnets_docs/html/index.html) and look at the [Examples](https://github.com/atzberg/gmls-nets/tree/master/examples).
2 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | For more information on how to use this package, see the example Jupyter notebooks and Python scripts above.
7 |
8 | More information on the structure of the package can also be found on the
9 | [documentation page](http://web.math.ucsb.edu/~atzberg/gmlsnets_docs/html/index.html).
10 |
--------------------------------------------------------------------------------
/examples/regression_diff_op_1d.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # ## GMLS-Nets: 1D Regression of Linear and Non-linear Operators $L[u]$.
5 | #
6 | # __Ben J. Gross__, __Paul J. Atzberger__
7 | # http://atzberger.org/
8 | #
9 | # Examples showing how GMLS-Nets can be used to perform regression for some basic linear and non-linear differential operators in 1D.
10 | #
11 | # __Parameters:__
12 | # The key parameter terms to adjust are:
13 | # ``op_type``: The operator type.
14 | # ``flag_mlp_case``: The type of mapping unit to use.
15 | #
16 | # __Examples of Non-linear Operators ($u{u_x},u_x^2,u{u_{xx}},u_{xx}^2$):__
17 | # To run training for a non-linear operator like ``u*ux`` using an MLP for the non-linear GMLS mapping unit, you can use:
18 | # ``op_type='u*ux';``
19 | # ``flag_mlp_case = 'Nonlinear1';``
20 | # You can obtain different performance by adjusting the mapping architecture and hyperparameters of the network.
21 | #
22 | # __Examples of Linear Operators ($u_x,u_{xx}$):__
23 | # To run training for a linear operator like the 1D Laplacian ``uxx`` with a linear mapping unit, you can use:
24 | # ``op_type='uxx';``
25 | # ``flag_mlp_case = 'Linear1';``
26 | #
27 | # These settings are organized into different combinations to allow for exploring the methods. The codes are easy to modify to experiment with other operators; for example, see the dataset classes.
28 | #
29 |
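As a quick reference, the setting combinations described above look as follows (a sketch; the actual assignments appear in the cells below):

```python
# Non-linear operator with an MLP mapping unit:
op_type = 'u*ux'               # the operator L[u] = u*u_x
flag_mlp_case = 'Nonlinear1'   # MLP (non-linear) GMLS mapping unit

# Linear operator with a linear mapping unit:
# op_type = 'uxx'              # the 1D Laplacian
# flag_mlp_case = 'Linear1'    # linear GMLS mapping unit
```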
30 | # ### Imports
31 |
32 | # In[1]:
33 |
34 | print("="*80);
35 | print("GMLS-Nets: 1D Regression of Linear and Non-linear Operators $L[u]$.");
36 | print("-"*80);
37 |
38 | import sys;
39 |
40 | # setup path to location of gmlsnets_pytorch (if not installed system-wide)
41 | path_gmlsnets_pytorch = '../../';
42 | sys.path.append(path_gmlsnets_pytorch);
43 |
44 | import torch;
45 | import torch.nn as nn;
46 |
47 | import numpy as np;
48 | import pickle;
49 |
50 | import matplotlib.pyplot as plt;
51 |
52 | import pdb
53 | import time
54 |
55 | import os
56 |
57 | # setup gmlsnets package
58 | import gmlsnets_pytorch as gmlsnets;
59 | import gmlsnets_pytorch.nn;
60 | import gmlsnets_pytorch.vis;
61 | import gmlsnets_pytorch.dataset;
62 |
63 | # dereference a few common items
64 | MapToPoly_Function = gmlsnets.nn.MapToPoly_Function;
65 | get_num_polys = MapToPoly_Function.get_num_polys;
66 | weight_one_minus_r = MapToPoly_Function.weight_one_minus_r;
67 | eval_poly = MapToPoly_Function.eval_poly;
68 |
69 | print("Packages:");
70 | print("torch.__version__ = " + str(torch.__version__));
71 | print("numpy.__version__ = " + str(np.__version__));
72 | print("gmlsnets.__version__ = " + str(gmlsnets.__version__));
73 |
74 |
75 | # ### Parameters and basic setup
76 |
77 | # In[2]:
78 |
79 |
80 | # Setup the parameters
81 | batch_size = int(1e2);
82 | flag_extend_periodic = False; # periodic boundaries
83 | flag_dataset = 'diffOp1';
84 | run_name = '%s_Test1'%flag_dataset;
85 | base_dir = './output/regression_diff_op_1d/%s'%run_name;
86 | flag_print_model = False;
87 |
88 | print("Settings:");
89 | print("flag_dataset = " + flag_dataset);
90 | print("run_name = " + run_name);
91 | print("base_dir = " + base_dir);
92 |
93 | if not os.path.exists(base_dir):
94 | os.makedirs(base_dir);
95 |
96 | # Configure devices
97 | if torch.cuda.is_available():
98 | num_gpus = torch.cuda.device_count();
99 | print("num_gpus = " + str(num_gpus));
100 | if num_gpus >= 4:
101 | device = torch.device('cuda:3');
102 | else:
103 | device = torch.device('cuda:0');
104 | else:
105 | device = torch.device('cpu');
106 | print("device = " + str(device));
107 |
108 |
109 | # ### Setup GMLS-Net for regressing differential operator
110 |
111 | # In[3]:
112 |
113 |
114 | class gmlsNetRegressionDiffOp1(nn.Module):
115 | """Sets up a GMLS-Net for regressing a differential operator in 1D."""
116 |
117 | def __init__(self,
118 | flag_GMLS_type=None,
119 | porder1=None,Nc=None,
120 | pts_x1=None,layer1_epsilon=None,
121 | weight_func1=None,weight_func1_params=None,
122 | mlp_q1=None,pts_x2=None,
123 | device=None,flag_verbose=0,
124 | **extra_params):
125 |
126 | super(gmlsNetRegressionDiffOp1, self).__init__();
127 |
128 | self.layer_types = [];
129 |
130 | if device is None:
131 | device = torch.device('cpu'); # default
132 |
133 | # --
134 | Ncp1 = mlp_q1.channels_out; # number of channels out of the MLP-Pointwise layer
135 |
136 | num_features1 = mlp_q1.channels_out; # number of channels out (16 typical)
137 |
138 | GMLS_Layer = gmlsnets.nn.GMLS_Layer;
139 | ExtractFromTuple = gmlsnets.nn.ExtractFromTuple;
140 | PermuteLayer = gmlsnets.nn.PermuteLayer;
141 | PdbSetTraceLayer = gmlsnets.nn.PdbSetTraceLayer;
142 |
143 | # --- Layer 1
144 | #flag_layer1 = 'standard_conv1';
145 | flag_layer1 = 'gmls1d_1';
146 | self.layer_types.append(flag_layer1);
147 | if flag_layer1 == 'standard_conv1':
148 | self.layer1 = nn.Sequential(
149 | nn.Conv1d(in_channels=Nc,out_channels=num_features1,
150 | kernel_size=5,stride=1,padding=2,bias=False),
151 | ).to(device);
152 | elif flag_layer1 == 'gmls1d_1':
153 | self.layer1 = nn.Sequential(
154 | GMLS_Layer(flag_GMLS_type, porder1,
155 | pts_x1, layer1_epsilon,
156 | weight_func1, weight_func1_params,
157 | mlp_q=mlp_q1, pts_x2=pts_x2, device=device,
158 | flag_verbose=flag_verbose),
159 | #PdbSetTraceLayer(),
160 | ExtractFromTuple(index=0), # just get the forward output associated with the mapping and not pts_x2
161 | #PdbSetTraceLayer(),
162 | PermuteLayer((0,2,1))
163 | ).to(device);
164 |
165 | else:
166 | raise Exception('flag_layer1 type not recognized.');
167 |
168 | def forward(self, x):
169 | out = self.layer1(x);
170 | return out;
171 |
172 |
173 | # ### Setup the Model: Neural Network
174 |
175 | # In[4]:
176 |
177 |
178 | # setup sample point locations
179 | xj = torch.linspace(0,1.0,steps=101,device=device).unsqueeze(1);
180 | xi = torch.linspace(0,1.0,steps=101,device=device).unsqueeze(1);
181 |
182 | # make a numpy copy for plotting and some other routines
183 | np_xj = xj.cpu().numpy(); np_xi = xi.cpu().numpy();
184 |
185 | # setup parameters
186 | Nc = 1; # scalar field
187 | Nx = xj.shape[0]; num_dim = xj.shape[1];
188 | porder = 2; num_polys = get_num_polys(porder,num_dim);
189 |
190 | weight_func1 = MapToPoly_Function.weight_one_minus_r;
191 | targ_kernel_width = 11.5; layer1_epsilon = 0.4*0.5*np.sqrt(2)*targ_kernel_width/Nx;
192 | #targ_kernel_width = 21.5; layer1_epsilon = 0.4*0.5*np.sqrt(2)*targ_kernel_width/Nx;
193 | weight_func1_params = {'epsilon': layer1_epsilon,'p':4};
194 |
195 | color_input = (0.05,0.44,0.69);
196 | color_output = (0.44,0.30,0.60);
197 | color_predict = (0.05,0.40,0.5);
198 | color_target = (221/255,103/255,103/255);
199 |
200 | # print the current settings
201 | print("GMLS Parameters:")
202 | print("porder = " + str(porder));
203 | print("num_dim = " + str(num_dim));
204 | print("num_polys = " + str(num_polys));
205 | print("layer1_epsilon = %.3e"%layer1_epsilon);
206 | print("weight_func1 = " + str(weight_func1));
207 | print("weight_func1_params = " + str(weight_func1_params));
208 | print("xj.shape = " + str(xj.shape));
209 | print("xi.shape = " + str(xi.shape));
210 |
211 |
212 | # In[5]:
213 |
214 |
215 | # create an MLP for training the non-linear part of the GMLS Net
216 | #flag_mlp_case = 'Linear1';flag_mlp_case = 'Nonlinear1'
217 | flag_mlp_case = 'Nonlinear1';
218 | if (flag_mlp_case == 'Linear1'):
219 | layer_sizes = [];
220 |
221 | num_depth = 0; # number of internal layers
222 | num_hidden = -1; # number of hidden units per layer (unused when num_depth = 0)
223 |
224 | channels_in = Nc; # number of poly channels (matches input u channel size)
225 | channels_out = 1; # number of output filters
226 |
227 | layer_sizes.append(num_polys); # input
228 | layer_sizes.append(1); # output, single channel always, for vectors, we use channels_out separate units.
229 |
230 | mlp_q1 = gmlsnets.nn.MLP_Pointwise(layer_sizes,channels_in=channels_in,channels_out=channels_out,
231 | flag_bias=False).to(device);
232 | elif (flag_mlp_case == 'Nonlinear1'):
233 | layer_sizes = [];
234 | num_input = Nc*num_polys; # number of channels*num_polys, allows for cross-channel coupling
235 | num_depth = 4; # number of internal layers
236 | num_hidden = 100; # number of hidden units per layer
237 | num_out_channels = 16; # number of output filters
238 | layer_sizes.append(num_polys);
239 | for k in range(num_depth):
240 | layer_sizes.append(num_hidden);
241 | layer_sizes.append(1); # output, single channel always, for vectors, we use channels_out separate units.
242 |
243 | mlp_q1 = gmlsnets.nn.MLP_Pointwise(layer_sizes,channels_out=num_out_channels,
244 | flag_bias=True).to(device);
245 |
246 | if flag_print_model:
247 | print("mlp_q1:");
248 | print(mlp_q1);
249 |
250 |
251 | # In[6]:
252 |
253 |
254 | # Setup the Neural Network for Regression
255 | flag_verbose = 0;
256 | flag_case = 'standard';
257 |
258 | # Setup the model
259 | xi = xi.float();
260 | xj = xj.float();
261 | model = gmlsNetRegressionDiffOp1(flag_case,porder,Nc,xj,layer1_epsilon,
262 | weight_func1,weight_func1_params,
263 | mlp_q1=mlp_q1,pts_x2=xi,
264 | device=device,
265 | flag_verbose=flag_verbose);
266 |
267 | if flag_print_model:
268 | print("model:");
269 | print(model);
270 |
271 |
272 | # ## Setup the training and test data
273 |
274 | # In[7]:
275 |
276 |
277 | ### Generate Dataset
278 |
279 | if flag_dataset == 'diffOp1':
280 | # Use the FFT to represent differential operators for training data sets.
281 | #
282 | # Setup a data set of the following:
283 | # To start, let's do regression for the Laplacian (not its inverse, just its action, akin to finding a finite-difference stencil)
284 | #
285 |
286 | #op_type = 'u*ux';op_type = 'ux*ux';op_type = 'uxx';op_type = 'u*uxx';op_type = 'uxx*uxx';
287 | op_type = 'u*ux';
288 | print("op_type = " + op_type);
289 |
290 | num_training_samples = int(5e4);
291 | nchannels = 1;
292 | nx = np_xj.shape[0];
293 | #alpha1 = 0.05;
294 | alpha1 = 0.1;
295 | scale_factor = 1e2;
296 | train_dataset = gmlsnets.dataset.diffOp1(op_type=op_type,op_params=None,
297 | gen_mode='exp1',gen_params={'alpha1':alpha1},
298 | num_samples=num_training_samples,
299 | nchannels=nchannels,nx=nx,
300 | noise_factor=0,scale_factor=scale_factor,
301 | flag_verbose=1);
302 |
303 | train_dataset = train_dataset.to(device);
304 | print("done.");
305 |
306 | num_test_samples = int(1e4);
307 | scale_factor = 1e2;
308 | test_dataset = gmlsnets.dataset.diffOp1(op_type=op_type,op_params=None,
309 | gen_mode='exp1',gen_params={'alpha1':alpha1},
310 | num_samples=num_test_samples,
311 | nchannels=nchannels,nx=nx,
312 | noise_factor=0,scale_factor=scale_factor,
313 | flag_verbose=1);
314 | test_dataset = test_dataset.to(device);
315 | print("done.");
316 |
317 | # Put the data into the train_dataset and test_dataset
318 | # structures for processing.
319 |
320 | else:
321 | msg = "flag_dataset not recognized. ";
322 | msg += "flag_dataset = " + str(flag_dataset);
323 | raise Exception(msg);
324 |
325 | # Data loader
326 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True);
327 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False);
328 |
329 |
330 | # In[8]:
331 |
332 |
333 | #get_ipython().run_line_magic('matplotlib', 'inline')
334 |
335 | # plot sample of the training data
336 | gmlsnets.vis.plot_dataset_diffOp1(train_dataset,np_xj,np_xi,rows=4,cols=6,
337 | title="Data Samples: u, f=L[u], L = %s"%op_type);
338 |
339 |
340 | # ## Train the Model
341 |
342 | # ### Custom Functions
343 |
344 | # In[9]:
345 |
346 |
347 | def custom_loss_least_squares(val1,val2):
348 | r"""Computes the Mean-Square-Error (MSE) over the entire batch."""
349 | diff_flat = (val1 - val2).flatten();
350 | N = diff_flat.shape[0];
351 | loss = torch.sum(torch.pow(diff_flat,2),-1)/N;
352 | return loss;
353 |
354 | def domain_periodic_repeat(Z):
355 | r"""Extends the input periodically."""
356 | Z_periodic = torch.cat((Z, Z, Z), 2);
357 | return Z_periodic;
358 |
359 | def domain_periodic_extract(Z_periodic):
360 | r"""Extracts the middle unit cell portion of the extended data."""
361 | nn = int(Z_periodic.shape[2]/3);
362 | Z = Z_periodic[:,:,nn:2*nn];
363 | return Z;
364 |
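A quick shape check for the periodic helpers above (a sketch with hypothetical sizes):

```python
# domain_periodic_repeat tiles along the last axis; extract inverts it.
Z = torch.randn(2, 1, 101)        # (batch, channels, nx)
Zp = domain_periodic_repeat(Z)    # -> (2, 1, 303)
assert Zp.shape == (2, 1, 303)
assert torch.equal(domain_periodic_extract(Zp), Z)
```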
365 |
366 | # ### Initialize
367 |
368 | # In[10]:
369 |
370 |
371 | loss_list = np.empty(0); loss_step_list = np.empty(0);
372 | save_skip = 1; step_count = 0;
373 |
374 |
375 | # ### Train the network.
376 |
377 | # In[11]:
378 |
379 |
380 | num_epochs = int(3e0); #int(1e4);
381 | learning_rate = 1e-2;
382 |
383 | print("Training the network with:");
384 | print("");
385 | print("model:");
386 | print("model.layer_types = " + str(model.layer_types));
387 | print("");
388 |
389 | # setup the optimization method and loss function
390 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate);
391 |
392 | #loss_func = nn.CrossEntropyLoss();
393 | #loss_func = nn.MSELoss();
394 | loss_func = custom_loss_least_squares;
395 |
396 | print("num_epochs = %d"%num_epochs);
397 | print("batch_size = %d"%batch_size);
398 | print(" ");
399 |
400 | # Train the model
401 | flag_time_it = True;
402 | if flag_time_it:
403 | time_1 = time.time();
404 | print("-"*80);
405 | num_steps = len(train_loader);
406 | for epoch in range(num_epochs):
407 | for i, (input,target) in enumerate(train_loader):
408 | input = input.to(device);
409 | target = target.to(device);
410 |
411 | if flag_extend_periodic:
412 | # Extend input periodically
413 | input_periodic = domain_periodic_repeat(input);
414 |
415 | # Forward pass
416 | output_periodic = model(input_periodic);
417 | output = domain_periodic_extract(output_periodic);
418 | else:
419 | output = model(input);
420 |
421 | # Compute loss
422 | loss = loss_func(output,target);
423 |
424 | # Display
425 | if step_count % save_skip == 0:
426 | np_loss = loss.cpu().detach().numpy();
427 | loss_list = np.append(loss_list,np_loss);
428 | loss_step_list = np.append(loss_step_list,step_count);
429 |
430 | # Back-propagate to compute gradients, then take an optimization step
431 | optimizer.zero_grad();
432 | loss.backward();
433 |
434 | optimizer.step();
435 |
436 | step_count += 1;
437 |
438 | if ((i + 1) % 100) == 0 or i == 0:
439 | msg = 'epoch: [%d/%d]; '%(epoch+1,num_epochs);
440 | msg += 'batch_step = [%d/%d]; '%(i + 1,num_steps);
441 | msg += 'loss_MSE: %.3e.'%(loss.item());
442 | print(msg);
443 |
444 | if flag_time_it and i > 0:
445 | msg = 'elapsed_time = %.4e secs \n'%(time.time() - time_1);
446 | print(msg);
447 | time_1 = time.time();
448 |
449 |
450 | print("done training.")
451 | print("-"*80);
452 |
453 |
454 | # ### Plot Loss
455 |
456 | # In[12]:
457 |
458 |
459 | #get_ipython().run_line_magic('matplotlib', 'inline')
460 |
461 | plt.figure(figsize=(8,6));
462 |
463 | plt.plot(loss_step_list,loss_list,'b-');
464 | plt.yscale('log');
465 | plt.xlabel('step');
466 | plt.ylabel('loss');
467 |
468 | plt.title('Loss');
469 |
470 |
471 | # ### Test the Neural Network Predictions
472 |
473 | # In[13]:
474 |
475 |
476 | print("Testing predictions of the neural network:");
477 |
478 | flag_save_tests = True;
479 | if flag_save_tests:
480 | test_data = {};
481 |
482 | # Save the first few to show as examples of labeling
483 | saved_test_input = [];
484 | saved_test_target = [];
485 | saved_test_output_pred = [];
486 |
487 | count_batch = 0;
488 | with torch.no_grad():
489 | total = 0; II = 0;
490 | avg_error = 0;
491 | for input,target in test_loader: # loads data in batches and then sums up
492 |
493 | if (II >= 1000):
494 | print("tested on %d samples"%total);
495 | II = 0;
496 |
497 | input = input.to(device); target = target.to(device);
498 |
499 | # Compute model
500 | flag_extend_periodic = False;
501 | if flag_extend_periodic:
502 | # Extend input periodically
503 | input_periodic = domain_periodic_repeat(input);
504 |
505 | # Forward pass
506 | output_periodic = model(input_periodic);
507 | output = domain_periodic_extract(output_periodic);
508 | else:
509 | output = model(input);
510 |
511 | # Compute loss
512 | loss = loss_func(output,target);
513 |
514 | # Record the results
515 | avg_error += loss;
516 |
517 | total += output.shape[0];
518 | II += output.shape[0];
519 | count_batch += 1;
520 |
521 | NN = output.shape[0];
522 | for k in range(min(NN,20)): # save the first 20 samples of each batch
523 | saved_test_input.append(input[k]);
524 | saved_test_target.append(target[k]);
525 | saved_test_output_pred.append(output[k]);
526 |
527 | print("");
528 | print("Tested on a total of %d samples."%total);
529 | print("");
530 |
531 | # Compute RMSD error
532 | test_accuracy = avg_error.cpu()/count_batch;
533 | test_accuracy = np.sqrt(test_accuracy);
534 |
535 | print("The neural network has RMSD error %.2e on the %d test samples."%(test_accuracy,total));
536 | print("");
537 |
538 |
539 | # ### Show a Sample of the Predictions
540 |
541 | # In[14]:
542 |
543 |
544 | # collect a subset of the data to show and attach named labels
545 | #get_ipython().run_line_magic('matplotlib', 'inline')
546 |
547 | num_prediction_samples = len(saved_test_input);
548 | print("num_prediction_samples = " + str(num_prediction_samples));
549 |
550 | #II = np.random.permutation(num_samples); # compute random collection of indices @optimize
551 | II = np.arange(num_prediction_samples);
552 |
553 | if flag_dataset == 'name-here' or 0 == 0: # currently always true; adjust to filter by dataset type
554 | u_list = []; f_list = []; f_pred_list = [];
555 | for I in np.arange(0,min(num_prediction_samples,16)):
556 | u_list.append(saved_test_input[II[I]].cpu());
557 | f_list.append(saved_test_target[II[I]].cpu());
558 | f_pred_list.append(saved_test_output_pred[II[I]].cpu());
559 |
560 | # plot predictions against test data
561 | gmlsnets.vis.plot_samples_u_f_fp_1d(u_list,f_list,f_pred_list,np_xj,np_xi,rows=4,cols=6,
562 | title="Test Samples and Predictions: u, f=L[u], L = %s"%op_type);
563 |
564 |
565 | # ### Save Model
566 |
567 | # In[15]:
568 |
569 |
570 | model_filename = '%s/model.ckpt'%base_dir;
571 | print("model_filename = " + model_filename);
572 | torch.save(model.state_dict(), model_filename);
573 |
574 | model_filename = "%s/model_state.pickle"%base_dir;
575 | print("model_filename = " + model_filename);
576 | f = open(model_filename,'wb');
577 | pickle.dump(model.state_dict(),f);
578 | f.close();
579 |
580 |
581 | # ### Display the GMLS-Nets Learned Parameters
582 |
583 | # In[16]:
584 |
585 |
586 | flag_run_cell = flag_print_model;
587 |
588 | if flag_run_cell:
589 | print("-"*80)
590 | print("model.parameters():");
591 | ll = model.parameters();
592 | for l in ll:
593 | print(l);
594 |
595 | if flag_run_cell:
596 | print("-"*80)
597 | print("model.state_dict():");
598 | print(model.state_dict());
599 | print("-"*80)
600 |
601 |
602 | # ### Done
603 |
604 | # In[ ]:
605 |
606 | print("="*80);
607 |
608 |
--------------------------------------------------------------------------------
/examples/regression_diff_op_2d.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # ## GMLS-Nets: 2D Regression of Linear and Non-linear Operators $L[u]$.
5 | #
6 | # __Ben J. Gross__, __Paul J. Atzberger__
7 | # http://atzberger.org/
8 | #
9 | # Examples showing how GMLS-Nets can be used to perform regression for some basic linear and non-linear differential operators in 2D.
10 | #
11 | # __Parameters:__
12 | # The key parameter terms to adjust are:
13 | # ``op_type``: The operator type.
14 | # ``flag_mlp_case``: The type of mapping unit to use.
15 | #
16 | # __Examples of Non-linear Operators ($u\Delta{u}, u\nabla{u},\nabla{u}\cdot\nabla{u}$):__
17 | # To run training for a non-linear operator like $u\nabla{u}$ using an MLP for the non-linear GMLS mapping unit, you can use:
18 | # ``op_type=r'u\grad{u}';``
19 | # ``flag_mlp_case = 'Nonlinear1';``
20 | # You can obtain different performance by adjusting the mapping architecture and hyperparameters of the network.
21 | #
22 | # __Examples of Linear Operators ($\nabla{u}, \Delta{u}$):__
23 | # To run training for a linear operator like the 2D Laplacian $\Delta{u}$ with a linear mapping unit, you can use:
24 | # ``op_type=r'\Delta{u}';``
25 | # ``flag_mlp_case = 'Linear1';``
26 | #
27 | # These settings are organized into different combinations to allow for exploring the methods. The codes are easy to modify to experiment with regressing other operators; for example, see the dataset classes.
28 | #
29 |
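As a quick reference, the setting combinations described above look as follows (a sketch; the actual assignments appear in the cells below):

```python
# Non-linear operator with an MLP mapping unit:
op_type = r'u\grad{u}'         # the operator L[u] = u*grad(u)
flag_mlp_case = 'Nonlinear1'   # MLP (non-linear) GMLS mapping unit

# Linear operator with a linear mapping unit:
# op_type = r'\Delta{u}'       # the 2D Laplacian
# flag_mlp_case = 'Linear1'    # linear GMLS mapping unit
```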
30 | # In[1]:
31 |
32 | print("="*80);
33 | print("GMLS-Nets: 2D Regression of Linear and Non-linear Operators $L[u]$.");
34 | print("-"*80);
35 |
36 | import sys;
37 |
38 | # setup path to location of gmlsnets_pytorch (if not installed system-wide)
39 | path_gmlsnets_pytorch = '../../';
40 | sys.path.append(path_gmlsnets_pytorch);
41 |
42 | import torch;
43 | import torch.nn as nn;
44 |
45 | import numpy as np;
46 | import pickle;
47 |
48 | import matplotlib.pyplot as plt;
49 |
50 | import pdb
51 | import time
52 |
53 | import os
54 |
55 | # setup gmlsnets package
56 | import gmlsnets_pytorch as gmlsnets;
57 | import gmlsnets_pytorch.nn;
58 | import gmlsnets_pytorch.vis;
59 | import gmlsnets_pytorch.dataset;
60 | import gmlsnets_pytorch.util;
61 |
62 | # dereference a few common items
63 | MapToPoly_Function = gmlsnets.nn.MapToPoly_Function;
64 | get_num_polys = MapToPoly_Function.get_num_polys;
65 | weight_one_minus_r = MapToPoly_Function.weight_one_minus_r;
66 | eval_poly = MapToPoly_Function.eval_poly;
67 |
68 | print("Packages:");
69 | print("torch.__version__ = " + str(torch.__version__));
70 | print("numpy.__version__ = " + str(np.__version__));
71 | print("gmlsnets.__version__ = " + str(gmlsnets.__version__));
72 |
73 |
74 | # ### Parameters and basic setup
75 |
76 | # In[2]:
77 |
78 |
79 | # Setup the parameters
80 | batch_size = int(1e1);
81 | flag_extend_periodic = False; # periodic boundaries
82 | flag_dataset = 'diffOp2';
83 | run_name = '%s_Test1'%flag_dataset;
84 | base_dir = './output/regression_diff_op_2d/%s'%run_name;
85 | flag_save_figs = True;
86 | fig_base_dir = '%s/fig'%base_dir;
87 | flag_print_model = False;
88 |
89 | print("Settings:");
90 | print("flag_dataset = " + flag_dataset);
91 | print("run_name = " + run_name);
92 | print("base_dir = " + base_dir);
93 |
94 | if not os.path.exists(base_dir):
95 | os.makedirs(base_dir);
96 |
97 | if not os.path.exists(fig_base_dir):
98 | os.makedirs(fig_base_dir);
99 |
100 | # Configure devices
101 | if torch.cuda.is_available():
102 | num_gpus = torch.cuda.device_count();
103 | print("num_gpus = " + str(num_gpus));
104 | if num_gpus >= 4:
105 | device = torch.device('cuda:3');
106 | else:
107 | device = torch.device('cuda:0');
108 | else:
109 | device = torch.device('cpu');
110 | print("device = " + str(device));
111 |
112 |
113 | # ## Setup the training and test data
114 |
115 | # ### Input Points
116 |
117 | # In[3]:
118 |
119 |
120 | Nx = 21;Ny = 21;
121 | nx = Nx; ny = Ny;
122 |
123 | NNx = 3*Nx; NNy = 3*Ny; # simple periodic extension by tiling for now
124 | aspect_ratio = NNx/float(NNy);
125 | xx = np.linspace(-1.5,aspect_ratio*1.5,NNx); xx = xx.astype(float);
126 | yy = np.linspace(-1.5,1.5,NNy); yy = yy.astype(float);
127 |
128 | aa = np.meshgrid(xx,yy);
129 | np_xj = np.array([aa[0].flatten(), aa[1].flatten()]).T;
130 |
131 | aa = np.meshgrid(xx,yy);
132 | np_xi = np.array([aa[0].flatten(), aa[1].flatten()]).T;
133 |
134 | # make torch tensors
135 | xj = torch.from_numpy(np_xj).float().to(device); # convert to torch tensors
136 | xj.requires_grad = False;
137 |
138 | xi = torch.from_numpy(np_xi).float().to(device); # convert to torch tensors
139 | xi.requires_grad = False;
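A quick check of the point-array shapes constructed above (a sketch):

```python
# np_xj and np_xi stack the flattened meshgrid into one (x, y) row per
# sample point on the periodically tiled grid of NNx*NNy locations.
assert np_xj.shape == (NNx*NNy, 2)
assert xj.shape == (NNx*NNy, 2)
```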
140 |
141 |
142 | # ### Generate the Dataset
143 |
144 | # In[4]:
145 |
146 |
147 | if flag_dataset == 'diffOp2':
148 |
149 | flag_verbose = 1;
150 | #op_type = r'\Delta{u}';op_type = r'\Delta{u}*\Delta{u}';
151 | #op_type = r'u\Delta{u}';op_type = r'\grad{u}';op_type = r'u\grad{u}';
152 | #op_type = r'\grad{u}\cdot\grad{u}';
153 | op_type = r'u\grad{u}';
154 | print("op_type = " + op_type);
155 |
156 | num_dim = 2;
157 | num_training_samples = int(5e3);
158 | nchannels_u = 1;
159 | Nc = nchannels_u;
160 |
161 | #alpha1 = 0.05;
162 | alpha1 = 0.3;
163 | scale_factor = 1e2;
164 | train_dataset = gmlsnets.dataset.diffOp2(op_type=op_type,op_params=None,
165 | gen_mode='exp1',gen_params={'alpha1':alpha1},
166 | num_samples=num_training_samples,
167 | nchannels=nchannels_u,nx=nx,ny=ny,
168 | noise_factor=0,scale_factor=scale_factor,
169 | flag_verbose=1);
170 |
171 | train_dataset = train_dataset.to(device);
172 | print("done.");
173 |
174 | num_test_samples = int(1e3);
175 | scale_factor = 1e2;
176 | test_dataset = gmlsnets.dataset.diffOp2(op_type=op_type,op_params=None,
177 | gen_mode='exp1',gen_params={'alpha1':alpha1},
178 | num_samples=num_test_samples,
179 | nchannels=nchannels_u,nx=nx,ny=ny,
180 | noise_factor=0,scale_factor=scale_factor,
181 | flag_verbose=1);
182 | test_dataset = test_dataset.to(device);
183 | print("done.");
184 |
185 | else:
186 | msg = "flag_dataset not recognized. ";
187 | msg += "flag_dataset = " + str(flag_dataset);
188 | raise Exception(msg);
189 |
190 | # Data loader
191 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True);
192 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False);
193 |
194 | # Count the number of output channels in f; this determines whether the data is scalar- or vector-valued
195 | nchannels_f = train_dataset.samples_Y.shape[1];
196 |
197 |
198 | # ## Show Data
199 |
200 | # In[5]:
201 |
202 |
203 | flag_run_cell = True;
204 | if flag_run_cell:
205 | # Show subset of the data
206 | img_arr = [];
207 | label_str_arr = [];
208 |
209 | numImages = len(train_dataset);
210 | #II = np.random.permutation(numImages); # compute random collection of indices @optimize
211 | II = np.arange(numImages);
212 |
213 | if flag_dataset == '' or 0 == 0: # currently always true; adjust to filter by dataset type
214 | img_arr = [];
215 | channelI = 0; # for vector-valued fields, choose component.
216 | for I in np.arange(0,min(num_training_samples,16)):
217 | img_arr.append(train_dataset[II[I],channelI,:][0].cpu());
218 | gmlsnets.vis.plot_image_array(img_arr,title=r'$u^{[i]}$ Samples',figSize=(6,6),title_yp=0.95);
219 |
220 | if flag_save_figs:
221 | fig_name = 'samples_u';
222 | gmlsnets.vis.save_fig('%s/%s'%(fig_base_dir,fig_name),'',flag_verbose=True,dpi_set=200,flag_pdf=True);
223 |
224 | img_arr = [];
225 | channelI = 0;
226 | for I in np.arange(0,min(num_training_samples,16)):
227 | img_arr.append(train_dataset[II[I],channelI,:][1].cpu());
228 | gmlsnets.vis.plot_image_array(img_arr,title=r'$f^{[i]}$ Samples',figSize=(6,6),title_yp=0.95);
229 |
230 | if flag_save_figs:
231 | fig_name = 'samples_f';
232 | gmlsnets.vis.save_fig('%s/%s'%(fig_base_dir,fig_name),'',flag_verbose=True,dpi_set=200,flag_pdf=True);
233 |
234 |
235 |
236 |
237 | # ### Side-by-Side View
238 |
239 | # In[6]:
240 |
241 |
242 | #get_ipython().run_line_magic('matplotlib', 'inline')
243 |
244 | for Ic_f in range(0,nchannels_f):
245 | # plot sample of the training data
246 | gmlsnets.vis.plot_dataset_diffOp2(train_dataset,np_xj,np_xi,rows=4,cols=6,channelI_f=Ic_f,
247 | title="Data Samples: u, f=L[u], L = %s, Ic_f = %d"%(op_type,Ic_f));
248 |
249 | if flag_save_figs:
250 | fig_name = 'samples_u_f_Ic_f_%d'%Ic_f;
251 | gmlsnets.vis.save_fig('%s/%s'%(fig_base_dir,fig_name),'',flag_verbose=True,dpi_set=200,flag_pdf=True);
252 |
253 |
254 | # ### Setup GMLS-Net for regressing differential operator
255 |
256 | # In[7]:
257 |
258 |
259 | class gmlsNetRegressionDiffOp2(nn.Module):
260 | """Sets up a GMLS-Net for regressing a differential operator in 2D."""
261 | def __init__(self,
262 | flag_GMLS_type=None,
263 | porder1=None,Nx=None,Ny=None,Nc=None,pts_x1=None,layer1_epsilon=None,
264 | weight_func1=None,weight_func1_params=None,
265 | mlp_q1=None,pts_x2=None,
266 | device=None,flag_verbose=0,
267 | **extra_params):
268 |
269 | super(gmlsNetRegressionDiffOp2, self).__init__();
270 |
271 | # setup the layers
272 | self.layer_types = [];
273 |
274 | if device is None:
275 | device = torch.device('cpu'); # default
276 |
277 | # --
278 | Ncp1 = mlp_q1.channels_out; # number of channels out of the MLP-Pointwise layer
279 |
280 | num_features1 = mlp_q1.channels_out; # number of channels out
281 |
282 | GMLS_Layer = gmlsnets.nn.GMLS_Layer;
283 | ExtractFromTuple = gmlsnets.nn.ExtractFromTuple;
284 | PermuteLayer = gmlsnets.nn.PermuteLayer;
285 | ReshapeLayer = gmlsnets.nn.ReshapeLayer;
286 | PdbSetTraceLayer = gmlsnets.nn.PdbSetTraceLayer;
287 |
288 | PeriodicPad2d = gmlsnets.util.PeriodicPad2d;
289 | ExtractUnitCell2d = gmlsnets.util.ExtractUnitCell2d;
290 |
291 | # --- Layer 1
292 | #flag_layer1 = 'standard_conv1';
293 | flag_layer1 = 'gmls2d_1';
294 | self.layer_types.append(flag_layer1);
295 | if flag_layer1 == 'standard_conv1':
296 | self.layer1 = nn.Sequential(
297 | PeriodicPad2d(), # expands width by 3
298 | nn.Conv2d(in_channels=Nc,out_channels=num_features1,
299 | kernel_size=5,stride=1,padding=2,bias=False),
300 | ExtractUnitCell2d()
301 | ).to(device);
302 | elif flag_layer1 == 'gmls2d_1':
303 | NNx = 3*Nx; NNy = 3*Ny; # periodic extensions
304 | reshape_ucell_to_GMLS_data = (-1,Nc,NNx*NNy); # map from assumed (batch,Nc,NNx,NNy) --> (batch,Nc,NNx*NNy)
305 | permute_ucell_to_GMLS_data = None; # identity map
306 |
307 | permute_pre_GMLS_data_to_ucell = (0,2,1); # maps (batch,NNx*NNy,Ncp1) --> (batch,Ncp1,NNx*NNy)
308 |
309 | reshape_GMLS_data_to_ucell = (-1,Ncp1,NNx,NNy); # maps to (batch,Ncp1,NNx,NNy), Ncp1 new channels
310 | permute_GMLS_data_to_ucell = None; # identity map
311 |
312 | self.layer1 = nn.Sequential(
313 | PeriodicPad2d(), # expands the 2D unit cell periodically (each width x3)
314 | ReshapeLayer(reshape_ucell_to_GMLS_data, permute_ucell_to_GMLS_data), # reshape for GMLS-Net
315 | #PdbSetTraceLayer(),
316 | GMLS_Layer(flag_GMLS_type, porder1,
317 | pts_x1, layer1_epsilon,
318 | weight_func1, weight_func1_params,
319 | mlp_q=mlp_q1, pts_x2=pts_x2, device=device,
320 | flag_verbose=flag_verbose),
321 | #PdbSetTraceLayer(),
322 | ExtractFromTuple(index=0), # just get the forward output associated with the mapping and not pts_x2
323 | PermuteLayer(permute_pre_GMLS_data_to_ucell),
324 | ReshapeLayer(reshape_GMLS_data_to_ucell, permute_GMLS_data_to_ucell), # reshape the data out of GMLS
325 | #PdbSetTraceLayer(),
326 | ExtractUnitCell2d()
327 | ).to(device);
328 |
329 | else:
330 | raise Exception('The flag_layer1 type was not recognized.');
331 |
332 | def forward(self, x):
333 | out = self.layer1(x);
334 | return out;
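To summarize the ``gmls2d_1`` branch above, the tensor shapes flow roughly as follows (a sketch inferred from the in-code comments; the sizes are assumed conventions):

```python
# input u:              (batch, Nc, Nx, Ny)
# PeriodicPad2d:        (batch, Nc, NNx, NNy), with NNx = 3*Nx, NNy = 3*Ny
# ReshapeLayer:         (batch, Nc, NNx*NNy)
# GMLS_Layer + Extract: (batch, NNx*NNy, Ncp1)
# PermuteLayer:         (batch, Ncp1, NNx*NNy)
# ReshapeLayer:         (batch, Ncp1, NNx, NNy)
# ExtractUnitCell2d:    (batch, Ncp1, Nx, Ny)
```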
335 |
336 |
337 | # ### Setup the Model: Neural Network
338 |
339 | # In[8]:
340 |
341 |
342 | # setup parameters
343 | porder = 2; num_polys = get_num_polys(porder,num_dim);
344 |
345 | weight_func1 = MapToPoly_Function.weight_one_minus_r;
346 | targ_kernel_width = 1.5; layer1_epsilon = 1.1*0.5*np.sqrt(2)*targ_kernel_width/Nx;
347 | weight_func1_params = {'epsilon': layer1_epsilon,'p':4};
348 |
349 | #color_input = (0.05,0.44,0.69);
350 | #color_output = (0.44,0.30,0.60);
351 | #color_predict = (0.05,0.40,0.5);
352 | #color_target = (221/255,103/255,103/255);
353 |
354 | # print the current settings
355 | print("GMLS Parameters:")
356 | print("porder = " + str(porder));
357 | print("num_polys = " + str(num_polys));
358 | print("layer1_epsilon = %.3e"%layer1_epsilon);
359 | print("weight_func1 = " + str(weight_func1));
360 | print("weight_func1_params = " + str(weight_func1_params));
361 | print("num_dim = " + str(num_dim));
362 | print("Nx = %d, Ny = %d, Nc = %d"%(Nx,Ny,Nc));
363 | print("xj.shape = " + str(xj.shape));
364 | print("xi.shape = " + str(xi.shape));
365 |
366 |
367 | # In[9]:
368 |
369 |
370 | # create an MLP for training the non-linear part of the GMLS Net
371 | num_polys = get_num_polys(porder,num_dim);
372 |
373 | print("num_dim = " + str(num_dim));
374 | print("num_polys = " + str(num_polys));
375 | print("Nx = " + str(Nx) + " Ny = " + str(Ny));
376 | print("NNx = " + str(NNx) + " NNy = " + str(NNy));
377 | print("nchannels_f = " + str(nchannels_f));
378 |
379 | print("");
380 |
381 | #flag_mlp_case = 'Linear1';flag_mlp_case = 'Nonlinear1';
382 | flag_mlp_case = 'Nonlinear1';
383 |
384 | if (flag_mlp_case == 'Linear1'):
385 |
386 | # -- Layer 1
387 | layer_sizes = [];
388 |
389 | num_depth = 0; # number of internal layers
390 | num_hidden = -1; # number of hidden units per layer (unused when num_depth = 0)
391 |
392 | channels_in = Nc; # number of poly channels for input
393 | #channels_out = 16; # number of output filters [might want to decouple by using separate weights here]
394 | channels_out = nchannels_f; # number of output filters, (scalar=1, vector=2,3,...)
395 | Nc2 = channels_out;
396 |
397 | layer_sizes.append(num_polys);
398 | for k in range(num_depth):
399 | layer_sizes.append(num_hidden);
400 | layer_sizes.append(1); # for single unit always give scalar output, we then use channels_out units.
401 |
402 | mlp_q1 = gmlsnets.nn.MLP_Pointwise(layer_sizes,
403 | channels_in=channels_in,
404 | channels_out=channels_out,
405 | flag_bias=False);
406 | mlp_q1.to(device);
407 |
408 | elif (flag_mlp_case == 'Nonlinear1'):
409 | layer_sizes = [];
410 | num_input = Nc*num_polys; # number of channels*num_polys (cross-channel coupling allowed)
411 | #num_depth = 4; # number of internal layers
412 | num_depth = 1; # number of internal layers
413 | num_hidden = 500; # number of hidden units per layer
414 | num_out_channels = nchannels_f; # number of output filters, (scalar=1, vector=2,3,...)
415 | layer_sizes.append(num_polys);
416 | for k in range(num_depth):
417 | layer_sizes.append(num_hidden);
418 | layer_sizes.append(1); # for single unit always give scalar output, we then use channels_out units.
419 | #layer_sizes.append(num_out_channels);
420 |
421 | #mlp_q = gmlsnets.nn.atzGMLS_MLP1_Module(layer_sizes);
422 | mlp_q1 = gmlsnets.nn.MLP_Pointwise(layer_sizes,channels_out=num_out_channels);
423 | mlp_q1.to(device);
424 |
425 | if flag_print_model:
426 | print("mlp_q1:");
427 | print(mlp_q1);
428 |
429 |
430 | # In[10]:
431 |
432 |
433 | # Setup the Neural Network for Regression
434 | flag_verbose = 0;
435 | flag_case = 'standard';
436 |
437 | # Setup the model
438 | model = gmlsNetRegressionDiffOp2(flag_case,porder,Nx,Ny,Nc,xj,layer1_epsilon,
439 | weight_func1,weight_func1_params,
440 | mlp_q1=mlp_q1,pts_x2=xi,
441 | device=device,
442 | flag_verbose=flag_verbose);
443 |
444 | if flag_print_model:
445 | print("model:");
446 | print(model);
447 |
448 |
449 | # ## Train the Model
450 |
451 | # ### Custom Functions
452 |
453 | # In[11]:
454 |
455 |
456 | def custom_loss_least_squares(val1,val2):
457 | r"""Computes the Mean-Square-Error (MSE) over the entire batch."""
458 | diff_flat = (val1 - val2).flatten();
459 | N = diff_flat.shape[0];
460 | loss = torch.sum(torch.pow(diff_flat,2),-1)/N;
461 | return loss;
462 |
463 |
464 | # ### Initialize
465 |
466 | # In[12]:
467 |
468 |
469 | # setup one-time data structures
470 | loss_list = np.empty(0); loss_step_list = np.empty(0);
471 | save_skip = 1; step_count = 0;
472 |
473 |
474 | # ### Train the network.
475 |
476 | # In[13]:
477 |
478 |
479 | num_epochs = int(2e0);
480 | learning_rate = 1e-1;
481 |
482 | print("Training the model with:");
483 |
484 | print("model:");
485 | print("model.layer_types = " + str(model.layer_types));
486 | print("");
487 |
488 | # setup the optimization method and loss function
489 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate);
490 |
491 | loss_func = custom_loss_least_squares;
492 |
493 | print("num_epochs = %d"%num_epochs);
494 | print("batch_size = %d"%batch_size);
495 | print(" ");
496 |
497 | # Train the model
498 | flag_time_it = True;
499 | if flag_time_it:
500 | time_1 = time.time();
501 |
502 | numSteps = len(train_loader);
503 | for epoch in range(num_epochs):
504 | for i, (input,target) in enumerate(train_loader):
505 | #if 1 == 1:
506 | input = input.to(device);
507 | target = target.to(device);
508 |
509 | # Compute model
510 | output = model(input);
511 |
512 | # Compute loss
513 | loss = loss_func(output,target);
514 |
515 | # Display
516 | if step_count % save_skip == 0:
517 | np_loss = loss.cpu().detach().numpy();
518 | loss_list = np.append(loss_list,np_loss);
519 | loss_step_list = np.append(loss_step_list,step_count);
520 |
521 | # Backward and optimize
522 | optimizer.zero_grad();
523 | loss.backward();
524 |
525 | #for p in model.parameters():
526 | # print(p.grad);
527 |
528 | optimizer.step();
529 |
530 | step_count += 1;
531 |
532 | if ((i + 1) % 100) == 0 or i == 0:
533 | #if ((step_count + 1) % int(1e3)) == 0 or step_count == 0:
534 | #print("WARNING: Debug mode...");
535 | msg = 'epoch: [%d/%d]; '%(epoch+1,num_epochs);
536 | msg += 'batch_step = [%d/%d]; '%(i + 1,numSteps);
537 | msg += 'loss_MSE: %.3e.'%(loss.item());
538 | print(msg);
539 |
540 | if flag_time_it and i > 0:
541 | msg = 'elapsed_time = %.1e secs \n'%(time.time() - time_1);
542 | print(msg);
543 | time_1 = time.time();
544 |
545 | print("done.");
546 |
547 |
548 | # ### Plot Loss
549 |
550 | # In[14]:
551 |
552 |
553 | #get_ipython().run_line_magic('matplotlib', 'inline')
554 |
555 | plt.figure(figsize=(8,6));
556 |
557 | plt.plot(loss_step_list,loss_list,'b-');
558 | plt.yscale('log');
559 | plt.xlabel('step');
560 | plt.ylabel('loss');
561 |
562 | plt.title('Loss');
563 |
564 | if flag_save_figs:
565 | fig_name = 'training_loss';
566 | gmlsnets.vis.save_fig('%s/%s'%(fig_base_dir,fig_name),'',flag_verbose=True,dpi_set=200,flag_pdf=True);
567 |
568 |
569 | # ### Test the Neural Network Predictions
570 |
571 | # In[15]:
572 |
573 |
574 | print("Testing predictions of the neural network:");
575 |
576 | flag_save_tests = True;
577 | if flag_save_tests:
578 | test_data = {};
579 |
580 | # Save the first few to show as examples of labeling
581 | saved_test_input = [];
582 | saved_test_target = [];
583 | saved_test_output_pred = [];
584 |
585 | count_batch = 0;
586 | with torch.no_grad():
587 | total = 0; II = 0;
588 | avg_error = 0;
589 | for input,target in test_loader: # loads data in batches and then sums up
590 |
591 | if (II >= 1000):
592 | print("tested on %d samples"%total);
593 | II = 0;
594 |
595 | input = input.to(device); target = target.to(device);
596 |
597 | # Compute model
598 | output = model(input);
599 |
600 | # Compute loss
601 | loss = loss_func(output,target);
602 |
603 | # Record the results
604 | avg_error += loss;
605 |
606 | total += output.shape[0];
607 | II += output.shape[0];
608 | count_batch += 1;
609 |
610 | NN = output.shape[0];
611 | for k in range(min(NN,20)): # save the first 20 samples of each batch
612 | saved_test_input.append(input[k]);
613 | saved_test_target.append(target[k]);
614 | saved_test_output_pred.append(output[k]);
615 |
616 | print("");
617 | print("Tested on a total of %d samples."%total);
618 | print("");
619 |
620 | # Compute RMSD error
621 | test_accuracy = avg_error.cpu()/count_batch;
622 | test_accuracy = np.sqrt(test_accuracy);
623 |
624 | print("The neural network has RMSD error %.2e on the %d test samples."%(test_accuracy,total));
625 | print("");
626 |
627 |
628 | # ### Show a Sample of the Predictions
629 |
630 | # In[16]:
631 |
632 |
633 | # collect a subset of the data to show and attach named labels
634 | #get_ipython().run_line_magic('matplotlib', 'inline')
635 |
636 | num_prediction_samples = len(saved_test_input);
637 | print("num_prediction_samples = " + str(num_prediction_samples));
638 |
639 | #II = np.random.permutation(num_samples); # compute random collection of indices @optimize
640 | II = np.arange(num_prediction_samples);
641 |
642 | if flag_dataset == 'name-here' or 0 == 0: # currently always true; adjust to filter by dataset type
643 | u_list = []; f_target_list = []; f_pred_list = [];
644 | for I in np.arange(0,min(num_prediction_samples,16)):
645 | u_list.append(saved_test_input[II[I]].cpu());
646 | f_target_list.append(saved_test_target[II[I]].cpu());
647 | f_pred_list.append(saved_test_output_pred[II[I]].cpu());
648 |
649 | # plot predictions against test data
650 | for Ic_f in range(0,nchannels_f):
651 | title = "Test Samples and Predictions: u, f=L[u], L = %s, Ic_f = %d"%(op_type,Ic_f);
652 | gmlsnets.vis.plot_samples_u_f_fp_2d(u_list,f_target_list,f_pred_list,np_xj,np_xi,
653 | channelI_f=Ic_f,rows=4,cols=6,
654 | title=title);
655 |
656 | if flag_save_figs:
657 | fig_name = 'predictions_Ic_f_%d'%Ic_f;
658 | gmlsnets.vis.save_fig('%s/%s'%(fig_base_dir,fig_name),'',flag_verbose=True,dpi_set=200,flag_pdf=True);
659 |
660 |
661 | # ### Save Model
662 |
663 | # In[17]:
664 |
665 |
666 | model_filename = '%s/model.ckpt'%base_dir;
667 | print("model_filename = " + model_filename);
668 | torch.save(model.state_dict(), model_filename);
669 |
670 | model_filename = "%s/model_state.pickle"%base_dir;
671 | print("model_filename = " + model_filename);
672 | f = open(model_filename,'wb');
673 | pickle.dump(model.state_dict(),f);
674 | f.close();
675 |
676 |
677 | # ### Display the GMLS-Nets Learned Parameters
678 |
679 | # In[18]:
680 |
681 |
682 | flag_run_cell = flag_print_model;
683 |
684 | if flag_run_cell:
685 | print("-"*80)
686 | print("model.parameters():");
687 | ll = model.parameters();
688 | for l in ll:
689 | print(l);
690 |
691 | if flag_run_cell:
692 | print("-"*80)
693 | print("model.state_dict():");
694 | print(model.state_dict());
695 | print("-"*80)
696 |
697 |
698 | # ### Done
699 |
700 | # In[ ]:
701 |
702 |
703 | print("="*80);
704 |
705 |
--------------------------------------------------------------------------------
/examples/z_doc_img/diff_op_1d_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atzberg/gmls-nets/aaa036b9ea9fcdbcf7b5892fd55886dbd3b61915/examples/z_doc_img/diff_op_1d_1.png
--------------------------------------------------------------------------------
/examples/z_doc_img/diff_op_2d_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atzberg/gmls-nets/aaa036b9ea9fcdbcf7b5892fd55886dbd3b61915/examples/z_doc_img/diff_op_2d_1.png
--------------------------------------------------------------------------------
/gmlsnets_pytorch-1.0.0.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atzberg/gmls-nets/aaa036b9ea9fcdbcf7b5892fd55886dbd3b61915/gmlsnets_pytorch-1.0.0.tar.gz
--------------------------------------------------------------------------------
/misc/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atzberg/gmls-nets/aaa036b9ea9fcdbcf7b5892fd55886dbd3b61915/misc/overview.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # global initialization
2 | name="gmlsnets_pytorch"; # package name
3 | __version__="1.0.0"; # package version
4 |
5 |
6 |
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Collection of codes for generating some training data sets.
3 | """
4 |
5 | # Authors: B.J. Gross and P.J. Atzberger
6 | # Website: http://atzberger.org/
7 |
8 | import torch;
9 | import numpy as np;
10 | import pdb;
11 |
12 | class diffOp1(torch.utils.data.Dataset):
13 | r"""
14 | Generates samples of the form :math:`(u^{[i]},f^{[i]})` where :math:`f^{[i]} = L[u^{[i]}]`,
15 | where :math:`i` denotes the index of the sample.
16 |
17 | Stores data samples in the form :math:`(u,f)`.
18 |
19 | The samples of u are represented as a tensor of size [nsamples,nchannels,nx]
20 | and the samples of f as a tensor of size [nsamples,nchannels,nx].
21 |
22 | Note:
23 | For now, please use nx that is odd. In this initial implementation, we use a
24 | method based on conjugated flips with a formula for the odd case, which is
25 | slightly simpler than the even case.
26 |
27 | """
28 | def flipForFFT(self,u_k_part):
29 | r"""We flip as :math:`f_k = f_{N-k}`. Notice that only the :math:`0,\ldots,N-1` entries are
30 | stored. This is useful for constructing real-valued function representations
31 | from random coefficients. A real-valued function requires :math:`conj(f_k) = f_{N-k}`.
32 | We can use this flip to construct from random coefficients the term
33 | :math:`u_k = f_k + conj(flip(f_k))`; then the above constraint is satisfied.
34 |
35 | Args:
36 | u_k_part (Tensor): coefficient array to flip (the flip acts on the last axis).
37 |
38 | Returns:
39 | Tensor: The flipped tensor, symmetric under conjugation.
40 | """
41 | nx = self.nx;
42 | uu = u_k_part[:,:,nx:0:-1];
43 | vv = u_k_part[:,:,0];
44 | vv = np.expand_dims(vv,2);
45 | uu_k_flip = np.concatenate([vv,uu],2);
46 |
47 | return uu_k_flip;
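A small sketch of the idea (assumes odd nx): the flip fixes index 0 and reverses the remaining entries, so coefficients that are even in the real part and odd in the imaginary part under the flip yield a real-valued inverse FFT:

```python
import numpy as np

nx = 7
flip = np.r_[0, nx-1:0:-1]                    # indices [0, N-1, ..., 1]
a = np.random.randn(nx); b = np.random.randn(nx)
u_k = 0.5*(a + a[flip]) + 0.5j*(b - b[flip])  # enforces conj(u_k) = u_{N-k}
assert np.allclose(np.fft.ifft(u_k).imag, 0.0)
```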
48 |
49 | def getComplex(self,a,b):
50 | j = complex(0,1); # imaginary unit (equivalently 1j; np.complex is removed in newer NumPy).
51 | c = a + j*b;
52 | return c;
53 |
54 | def getRealImag(self,c):
55 | a = np.real(c);
56 | b = np.imag(c);
57 | return a,b;
58 |
59 | def computeLSymbol_ux(self):
60 | r"""Compute associated Fourier symbols for use under DFT for the operator L[u]."""
61 | nx = self.nx;
62 | vec_k1 = torch.zeros(nx);
63 | vec_k1_pp = torch.zeros(nx);
64 | vec_k_sq = torch.zeros(nx);
65 | L_symbol_real = torch.zeros(nx,dtype=torch.float32);
66 | L_symbol_imag = torch.zeros(nx,dtype=torch.float32);
67 | two_pi = 2.0*np.pi;
68 | #two_pi_i = two_pi*1j; # $2\pi{i}$, 1j = sqrt(-1)
69 | for i in range(0,nx):
70 | vec_k1[i] = i;
71 | if (vec_k1[i] < nx/2):
72 | vec_k1_p = vec_k1[i];
73 | else:
74 | vec_k1_p = vec_k1[i] - nx;
75 | vec_k1_pp[i] = vec_k1_p;
76 | L_symbol_real[i] = 0.0;
77 | L_symbol_imag[i] = two_pi*vec_k1_p;
78 |
79 | L_hat = self.getComplex(L_symbol_real.numpy(),L_symbol_imag.numpy());
80 |
81 | return L_hat, vec_k1_pp;
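As a sanity check on the symbol (a sketch, assuming the unit-interval convention used here): applying ``L_hat`` for d/dx to u(x) = sin(2*pi*x) should recover 2*pi*cos(2*pi*x):

```python
import numpy as np

nx = 101
x = np.arange(nx)/nx
u = np.sin(2.0*np.pi*x)
k = np.fft.fftfreq(nx, d=1.0/nx)   # integer wavenumbers, matching vec_k1_pp
L_hat = 2.0j*np.pi*k               # symbol of d/dx on the unit interval
ux = np.fft.ifft(L_hat*np.fft.fft(u)).real
assert np.allclose(ux, 2.0*np.pi*np.cos(2.0*np.pi*x))
```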
82 |
83 | def computeLSymbol_uxx(self):
84 | r"""Compute associated Fourier symbols for use under DFT for the operator L[u]."""
85 | nx = self.nx;
86 | vec_k1 = torch.zeros(nx);
87 | vec_k1_pp = torch.zeros(nx);
88 | vec_k_sq = torch.zeros(nx);
89 | L_symbol_real = torch.zeros(nx,dtype=torch.float32);
90 | L_symbol_imag = torch.zeros(nx,dtype=torch.float32);
91 | neg_four_pi_sq = -4.0*np.pi*np.pi;
92 | for i in range(0,nx):
93 | vec_k1[i] = i;
94 | vec_k_sq[i] = vec_k1[i]*vec_k1[i];
95 | if (vec_k1[i] < nx/2):
96 | vec_k1_p = vec_k1[i];
97 | else:
98 | vec_k1_p = vec_k1[i] - nx;
99 | vec_k1_pp[i] = vec_k1_p;
100 | vec_k_p_sq = vec_k1_p*vec_k1_p;
101 | L_symbol_real[i] = neg_four_pi_sq*vec_k_p_sq;
102 | L_symbol_imag[i] = 0.0;
103 |
104 | L_hat = self.getComplex(L_symbol_real.numpy(),L_symbol_imag.numpy());
105 |
106 | return L_hat, vec_k1_pp;
107 |
108 | def computeCoeffActionL(self,u_hat,L_hat):
109 | r"""Computes the action of operator L used for data generation in Fourier space."""
110 | u_k_real, u_k_imag = self.getRealImag(u_hat);
111 | L_symbol_real, L_symbol_imag = self.getRealImag(L_hat);
112 |
113 | f_k_real = L_symbol_real*u_k_real - L_symbol_imag*u_k_imag; # broadcasting distributes over copies of u.
114 | f_k_imag = L_symbol_real*u_k_imag + L_symbol_imag*u_k_real;
115 |
116 | # Assemble the complex coefficients f_hat.
117 | f_hat = self.getComplex(f_k_real,f_k_imag);
118 |
119 | return f_hat;
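The real/imaginary updates above are just complex multiplication, f_hat = L_hat*u_hat; a quick equivalence check (sketch):

```python
import numpy as np

u = np.random.randn(5) + 1j*np.random.randn(5)
L = np.random.randn(5) + 1j*np.random.randn(5)
f = (L.real*u.real - L.imag*u.imag) + 1j*(L.real*u.imag + L.imag*u.real)
assert np.allclose(f, L*u)
```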
120 |
121 | def computeActionL(self,u,L_hat):
122 | r"""Computes the action of operator L used for data generation."""
123 | raise Exception('This routine is not yet debugged; it needs testing first.')
124 |
125 | if flag_verbose > 0:
126 | print("computeActionL(): WARNING: Not yet fully tested.");
127 |
128 | # perform FFT to get u_hat
129 | u_hat = np.fft.fft(u);
130 |
131 | # compute action of L_hat
132 | f_hat = self.computeCoeffActionL(u_hat,L_hat);
133 |
134 | # compute inverse FFT to get f
135 | f = np.fft.ifft(f_hat);
136 |
137 | return f;
138 |
139 | def __init__(self,op_type='uxx',op_params=None,
140 | gen_mode='exp1',gen_params={'alpha1':0.1},
141 | num_samples=int(1e4),nchannels=1,nx=15,
142 | flag_verbose=0, **extra_params):
143 | r"""Setup for data generation.
144 |
145 | Args:
146 | op_type (str): The differential operator to sample.
147 | op_params (dict): The operator parameters.
148 | gen_mode (str): The mode for the data generator.
149 | gen_params (dict): The parameters for the given generator.
150 | num_samples (int): The number of samples to generate.
151 | nchannels (int): The number of channels.
152 | nx (int): The number of input sample points.
153 | flag_verbose (int): Level of reporting during calculations.
154 | extra_params (dict): Extra parameters for the sampler.
155 |
156 | For extra_params we have:
157 | noise_factor (float): The amount of noise to add to samples.
158 | scale_factor (float): A factor to scale magnitude of the samples.
159 | flagComputeL (bool): Whether the Fourier symbol of the operator should be computed.
160 |
161 | For generator modes we have:
162 | gen_mode == 'exp1':
163 | alpha1 (float): The decay rate.
164 |
165 | Note:
166 | For now, please use only nx that is odd. In this initial implementation, we use a
167 | method based on conjugated flips with a formula for the odd case, which is
168 | slightly simpler than the even case.
169 | """
170 | super(diffOp1, self).__init__();
171 |
172 | if flag_verbose > 0:
173 | print("Generating the data samples which can take some time.");
174 | print("num_samples = %d"%num_samples);
175 |
176 | self.op_type=op_type;
177 | self.op_params=op_params;
178 |
179 | self.gen_mode=gen_mode;
180 | self.gen_params=gen_params;
181 |
182 | self.num_samples=num_samples;
183 | self.nchannels=nchannels;
184 | self.nx=nx;
185 |
186 | if (nx % 2 == 0):
187 | msg = "Even values of nx are not yet supported. ";
188 | msg += "For now, please use nx that is odd, given the flips currently used.";
189 | raise Exception(msg);
190 |
191 | noise_factor=0;scale_factor=1.0;flagComputeL=False; # default values
192 | if 'noise_factor' in extra_params:
193 | noise_factor = extra_params['noise_factor'];
194 |
195 | if 'scale_factor' in extra_params:
196 | scale_factor = extra_params['scale_factor'];
197 |
198 | if 'flagComputeL' in extra_params:
199 | flagComputeL = extra_params['flagComputeL'];
200 |
201 |         # Generate the Fourier symbols for the operator
202 | if self.op_type == 'ux' or self.op_type == 'u*ux' or self.op_type == 'ux*ux':
203 | L_hat, vec_k1_pp = self.computeLSymbol_ux();
204 | elif self.op_type == 'uxx' or self.op_type == 'u*uxx' or self.op_type == 'uxx*uxx':
205 | L_hat, vec_k1_pp = self.computeLSymbol_uxx();
206 | else:
207 |             raise Exception("Unknown operator type.");
208 |
209 | if (flagComputeL):
210 | L_i = np.fft.ifft(L_hat);
211 | self.L_hat = L_hat;
212 | self.L_i = L_i;
213 | u = np.zeros(nx);
214 | i0 = int(nx/2);
215 | u[i0] = 1.0;
216 |             self.G_i = self.computeActionL(u,L_hat); # pass the operator symbol explicitly
217 |
218 | # Generate random input function (want real-valued)
219 | # conj(u_k) = u_{N -k} needs to hold.
220 | u_k_real = np.random.randn(num_samples,nchannels,nx);
221 | u_k_imag = np.random.randn(num_samples,nchannels,nx);
222 |
223 | # scale modes to make smooth
224 | if gen_mode=='exp1':
225 | alpha1 = gen_params['alpha1'];
226 | factor_k = scale_factor*np.exp(-alpha1*vec_k1_pp**2);
227 | factor_k = factor_k.numpy();
228 | else:
229 | raise Exception("Generation mode not recognized.");
230 |
231 |         u_k_real = u_k_real*factor_k; # broadcast applies over the last dimension
232 |         u_k_imag = u_k_imag*factor_k; # broadcast applies over the last dimension
233 |
234 | flag_debug = False;
235 | if flag_debug:
236 | if flag_verbose > 0:
237 | print("WARNING: debugging mode on.");
238 |
239 | u_k_real = 0.0*u_k_real;
240 | u_k_imag = 0.0*u_k_imag;
241 |
242 | u_k_real[0,0,1] = nx;
243 | u_k_imag[0,0,1] = 0;
244 |
245 | u_k_real[1,0,1] = 0;
246 | u_k_imag[1,0,1] = nx;
247 |
248 | u_k_real[2,0,1] = nx;
249 | u_k_imag[2,0,1] = nx;
250 |
251 | # flip modes for constructing rep of real-valued function
252 | u_k_real_flip = self.flipForFFT(u_k_real);
253 | u_k_imag_flip = self.flipForFFT(u_k_imag);
254 |
255 | u_k_real_p = 0.5*u_k_real + 0.5*u_k_real_flip; # make conjugate conj(u_k) = u_{N -k}
256 | u_k_imag_p = 0.5*u_k_imag - 0.5*u_k_imag_flip; # make conjugate conj(u_k) = u_{N -k}
257 |
258 | u_k_real_p = torch.from_numpy(u_k_real_p);
259 | u_k_imag_p = torch.from_numpy(u_k_imag_p);
260 |
261 | u_k_real_p = u_k_real_p.type(torch.float32);
262 | u_k_imag_p = u_k_imag_p.type(torch.float32);
263 |
264 | u_hat = self.getComplex(u_k_real_p.numpy(),u_k_imag_p.numpy());
265 |
266 | f_hat = self.computeCoeffActionL(u_hat,L_hat);
267 |         # note: for PDE problems one often targets Lu = -f; that sign flip is not applied here.
268 |
269 |         # Generate samples u and f using ifft.
270 |         # ifft is broadcast over the last index;
271 |         # perform the inverse DFT to get u and f.
272 | u_i = np.fft.ifft(u_hat);
273 | f_i = np.fft.ifft(f_hat);
274 |
275 | if self.op_type == 'u*ux':
276 | f_i = u_i*f_i;
277 | elif self.op_type == 'ux*ux':
278 | f_i = f_i*f_i;
279 | elif self.op_type == 'u*uxx':
280 | f_i = u_i*f_i;
281 | elif self.op_type == 'uxx*uxx':
282 | f_i = f_i*f_i;
283 |
284 | self.samples_X = torch.from_numpy(np.real(u_i)).type(torch.float32); # only grab real part
285 | self.samples_Y = torch.from_numpy(np.real(f_i)).type(torch.float32);
286 |
287 | if noise_factor > 0:
288 | self.samples_Y += noise_factor*torch.randn(*self.samples_Y.shape);
289 |
290 | def __len__(self):
291 | return self.samples_X.size()[0];
292 |
293 | def __getitem__(self,index):
294 | return self.samples_X[index],self.samples_Y[index];
295 |
296 | def to(self,device):
297 | self.samples_X = self.samples_X.to(device);
298 | self.samples_Y = self.samples_Y.to(device);
299 |
300 | return self;
301 |
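A minimal usage sketch for `diffOp1` (the import path `gmlsnets_pytorch.dataset` is an assumption; adjust it to match your installation):

```python
# Minimal usage sketch for diffOp1 (import path is an assumption; adjust as needed).
import torch
import gmlsnets_pytorch.dataset as gmls_dataset

# Generate (u, f) pairs with f = u_xx on a periodic 1D grid of 15 points (odd nx required).
train_data = gmls_dataset.diffOp1(op_type='uxx', gen_mode='exp1', gen_params={'alpha1': 0.1},
                                  num_samples=1000, nchannels=1, nx=15, noise_factor=1e-2)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
u, f = next(iter(train_loader))
print(u.shape, f.shape)  # expected: torch.Size([32, 1, 15]) for both
```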
302 | class diffOp2(torch.utils.data.Dataset):
303 | r"""
304 | Generates samples of the form :math:`(u^{[i]},f^{[i]})` where :math:`f^{[i]} = L[u^{[i]}]`,
305 | where :math:`i` denotes the index of the sample.
306 |
307 | Stores data samples in the form :math:`(u,f)`.
308 |
309 |     The samples of u are represented as a tensor of size [nsamples,nchannels,nx,ny]
310 |     and the samples of f as a tensor of size [nsamples,nchannels,nx,ny] (for vector-valued
311 |     operators the components of f are stored in the channel index).
312 | 
313 |     Note:
314 |         For now, please use only nx,ny that are odd. This initial implementation uses a
315 |         method based on conjugate flips with a formula for the odd case, which is slightly simpler than the even case.
316 |
317 | """
318 | def flipForFFT(self,u_k_part):
319 |         r"""We flip as :math:`f_k = f_{N-k}`. Notice that only the entries :math:`0,\ldots,N-1`
320 |         are stored. This is useful for constructing real-valued function representations
321 |         from random coefficients. A real-valued function requires :math:`conj(f_k) = f_{N-k}`.
322 |         We can use this flip to construct from random coefficients the term
323 |         :math:`u_k = f_k + conj(flip(f_k))`; then the above constraint is satisfied.
324 | 
325 |         Args:
326 |             u_k_part (Tensor): Array of Fourier coefficients to flip (flipped over the last two axes).
327 | 
328 |         Returns:
329 |             Tensor: The flipped coefficients, for constructing tensors symmetric under conjugation.
330 | """
331 | nx = self.nx;ny = self.ny;
332 |
333 | u_k_part_row0 = u_k_part[:,:,0,:];
334 | u_k_part_row0 = np.expand_dims(u_k_part_row0,2);
335 | u_k_part_ex = np.concatenate([u_k_part,u_k_part_row0],2);
336 |
337 | u_k_part_col0 = u_k_part_ex[:,:,:,0];
338 | u_k_part_col0 = np.expand_dims(u_k_part_col0,3);
339 | u_k_part_ex = np.concatenate([u_k_part_ex,u_k_part_col0],3);
340 |
341 | u_k_part_ex_flip = np.flip(u_k_part_ex,2);
342 | u_k_part_ex_flip = np.flip(u_k_part_ex_flip,3);
343 |
344 | u_k_part_flip = np.delete(u_k_part_ex_flip,nx,2);
345 | u_k_part_flip = np.delete(u_k_part_flip,ny,3);
346 |
347 | return u_k_part_flip;
348 |
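A small standalone sketch of the conjugate-flip idea in 1D (illustration only; the method above applies the same flip over the last two axes):

```python
# 1D illustration of the conjugate-flip trick (standalone sketch, not the class method):
# for odd N, setting u_k <- 0.5*u_k + 0.5*conj(u_{(N-k) mod N}) makes ifft(u) real-valued.
import numpy as np
N = 15
u_k = np.random.randn(N) + 1j*np.random.randn(N)
u_k_flip = np.roll(u_k[::-1], 1)            # index map k -> (N - k) mod N
u_k_sym = 0.5*u_k + 0.5*np.conj(u_k_flip)   # enforces conj(u_k) = u_{N-k}
u = np.fft.ifft(u_k_sym)
print(np.max(np.abs(u.imag)))               # ~1e-16: the imaginary part vanishes
```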
349 | def getComplex(self,a,b):
350 |         j = 1j; # the complex unit (np.complex was removed in newer NumPy).
351 | c = a + j*b;
352 | return c;
353 |
354 | def getRealImag(self,c):
355 | a = np.real(c);
356 | b = np.imag(c);
357 | return a,b;
358 |
359 | def computeLSymbol_laplacian_u(self):
360 | r"""Compute associated Fourier symbols for use under DFT for the operator L[u]."""
361 |         num_dim = 2;nx=self.nx;ny=self.ny;
362 | vec_k1 = torch.zeros(nx,ny);
363 | vec_k2 = torch.zeros(nx,ny);
364 | vec_k1_pp = torch.zeros(nx,ny);
365 | vec_k2_pp = torch.zeros(nx,ny);
366 | vec_k_sq = torch.zeros(nx,ny);
367 | L_symbol_real = torch.zeros(nx,ny,dtype=torch.float32);
368 | L_symbol_imag = torch.zeros(nx,ny,dtype=torch.float32);
369 | neg_four_pi_sq = -4.0*np.pi*np.pi;
370 | for i in range(0,nx):
371 | for j in range(0,ny):
372 | vec_k1[i,j] = i;
373 | vec_k2[i,j] = j;
374 | vec_k_sq[i,j] = vec_k1[i,j]*vec_k1[i,j] + vec_k2[i,j]*vec_k2[i,j];
375 | if (vec_k1[i,j] < nx/2):
376 | vec_k1_p = vec_k1[i,j];
377 | else:
378 | vec_k1_p = vec_k1[i,j] - nx;
379 | if (vec_k2[i,j] < ny/2):
380 | vec_k2_p = vec_k2[i,j];
381 | else:
382 | vec_k2_p = vec_k2[i,j] - ny;
383 | vec_k1_pp[i,j] = vec_k1_p;
384 | vec_k2_pp[i,j] = vec_k2_p;
385 | vec_k_p_sq = vec_k1_p*vec_k1_p + vec_k2_p*vec_k2_p;
386 | L_symbol_real[i,j] = neg_four_pi_sq*vec_k_p_sq;
387 | L_symbol_imag[i,j] = 0.0;
388 |
389 | L_hat = self.getComplex(L_symbol_real.numpy(),L_symbol_imag.numpy());
390 |
391 | return L_hat, vec_k1_pp, vec_k2_pp;
392 |
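The symbol above can be sanity-checked against a known eigenfunction of the Laplacian; a standalone sketch (illustration only):

```python
# Standalone check of the Laplacian symbol on the periodic unit square (illustration only):
# Delta sin(2*pi*x) = -4*pi^2 sin(2*pi*x).
import numpy as np
n = 15
x = np.arange(n)/n
X, _ = np.meshgrid(x, x, indexing='ij')
u = np.sin(2*np.pi*X)
k = np.fft.fftfreq(n, d=1.0/n)             # integer wavenumbers, matching vec_k1_pp, vec_k2_pp
K1, K2 = np.meshgrid(k, k, indexing='ij')
L_hat = -4.0*np.pi**2*(K1**2 + K2**2)      # the same symbol as assembled above
f = np.real(np.fft.ifft2(L_hat*np.fft.fft2(u)))
print(np.max(np.abs(f + 4.0*np.pi**2*u)))  # ~1e-11, matches the analytic Laplacian
```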
393 | def computeLSymbol_grad_u(self):
394 | r"""Compute associated Fourier symbols for use under DFT for the operator L[u]."""
395 | num_dim = 2;nx=self.nx;ny=self.ny;
396 | vec_k1 = torch.zeros(nx,ny);
397 | vec_k2 = torch.zeros(nx,ny);
398 | vec_k1_pp = torch.zeros(nx,ny);
399 | vec_k2_pp = torch.zeros(nx,ny);
400 | vec_k_sq = torch.zeros(nx,ny);
401 | L_symbol_real = torch.zeros(num_dim,nx,ny,dtype=torch.float32);
402 | L_symbol_imag = torch.zeros(num_dim,nx,ny,dtype=torch.float32);
403 | two_pi = 2.0*np.pi;
404 | #two_pi_i = two_pi*1j; # $2\pi{i}$, 1j = sqrt(-1)
405 | for i in range(0,nx):
406 | for j in range(0,ny):
407 | vec_k1[i,j] = i;
408 | vec_k2[i,j] = j;
409 | vec_k_sq[i,j] = vec_k1[i,j]*vec_k1[i,j] + vec_k2[i,j]*vec_k2[i,j];
410 | if (vec_k1[i,j] < nx/2):
411 | vec_k1_p = vec_k1[i,j];
412 | else:
413 | vec_k1_p = vec_k1[i,j] - nx;
414 | if (vec_k2[i,j] < ny/2):
415 | vec_k2_p = vec_k2[i,j];
416 | else:
417 | vec_k2_p = vec_k2[i,j] - ny;
418 | vec_k1_pp[i,j] = vec_k1_p;
419 | vec_k2_pp[i,j] = vec_k2_p;
420 | vec_k_p_sq = vec_k1_p*vec_k1_p + vec_k2_p*vec_k2_p;
421 | L_symbol_real[0,i,j] = 0.0;
422 | L_symbol_imag[0,i,j] = two_pi*vec_k1_p;
423 | L_symbol_real[1,i,j] = 0.0;
424 | L_symbol_imag[1,i,j] = two_pi*vec_k2_p;
425 |
426 | L_hat_0 = self.getComplex(L_symbol_real[0,:,:].numpy(),L_symbol_imag[0,:,:].numpy());
427 | L_hat_1 = self.getComplex(L_symbol_real[1,:,:].numpy(),L_symbol_imag[1,:,:].numpy());
428 |
429 | L_hat = np.stack((L_hat_0,L_hat_1));
430 |
431 | return L_hat, vec_k1_pp, vec_k2_pp;
432 |
433 | def computeCoeffActionL(self,u_hat,L_hat):
434 | r"""Computes the action of operator L used for data generation in Fourier space."""
435 | u_k_real, u_k_imag = self.getRealImag(u_hat);
436 | L_symbol_real, L_symbol_imag = self.getRealImag(L_hat);
437 |
438 |         f_k_real = L_symbol_real*u_k_real - L_symbol_imag*u_k_imag; #broadcast will distribute over copies of u.
439 | #f_k_real = -1.0*f_k_real;
440 | f_k_imag = L_symbol_real*u_k_imag + L_symbol_imag*u_k_real;
441 | #f_k_imag = -1.0*f_k_imag;
442 |
443 |         # assemble the complex coefficients f_hat (u and f are recovered later via ifft2)
444 | f_hat = self.getComplex(f_k_real,f_k_imag);
445 |
446 | return f_hat;
447 |
448 | def computeActionL(self,u,L_hat):
449 | r"""Computes the action of operator L used for data generation."""
450 |         raise Exception('This routine is not yet debugged; it needs testing first.');
451 |
452 | # perform FFT to get u_hat
453 | u_hat = np.fft.fft2(u);
454 |
455 | # compute action of L_hat
456 | f_hat = self.computeCoeffActionL(u_hat,L_hat);
457 |
458 | # compute inverse FFT to get f
459 | f = np.fft.ifft2(f_hat)
460 |
461 | return f;
462 |
463 | def __init__(self,op_type=r'\Delta{u}',op_params=None,
464 | gen_mode='exp1',gen_params={'alpha1':0.1},
465 | num_samples=int(1e4),nchannels=1,nx=15,ny=15,
466 | flag_verbose=0, **extra_params):
467 | r"""Setup for data generation.
468 |
469 | Args:
470 | op_type (str): The differential operator to sample.
471 | op_params (dict): The operator parameters.
472 | gen_mode (str): The mode for the data generator.
473 | gen_params (dict): The parameters for the given generator.
474 | num_samples (int): The number of samples to generate.
475 | nchannels (int): The number of channels.
476 | nx (int): The number of input sample points in x-direction.
477 |           ny (int): The number of input sample points in the y-direction.
478 | flag_verbose (int): Level of reporting during calculations.
479 | extra_params (dict): Extra parameters for the sampler.
480 |
481 | For extra_params we have:
482 | noise_factor (float): The amount of noise to add to samples.
483 | scale_factor (float): A factor to scale magnitude of the samples.
484 |           flagComputeL (bool): Whether the Fourier symbol of the operator should be computed.
485 |
486 | For generator modes we have:
487 | gen_mode == 'exp1':
488 | alpha1 (float): The decay rate.
489 |
490 | Note:
491 |             For now, please use only nx,ny that are odd. This initial implementation uses a
492 |             method based on conjugate flips with a formula for the odd case, which is slightly
493 |             simpler than the even case.
494 | """
495 | if flag_verbose > 0:
496 |             print("Generating the data samples, which can take some time.");
497 | print("num_samples = %d"%num_samples);
498 |
499 | self.op_type=op_type;
500 | self.op_params=op_params;
501 |
502 | self.gen_mode=gen_mode;
503 | self.gen_params=gen_params;
504 |
505 | self.num_samples=num_samples;
506 | self.nchannels=nchannels;
507 | self.nx=nx; self.ny=ny;
508 |
509 | if (nx % 2 == 0) or (ny % 2 == 0) or (nx != ny): # may be able to relax nx != ny (just for safety)
510 |             msg = "nx,ny that are even or unequal are not yet supported. ";
511 |             msg += "For now, please use equal, odd nx,ny, given the conjugate flips currently used.";
512 | raise Exception(msg);
513 |
514 | noise_factor=0;scale_factor=1.0;flagComputeL=False; # default values
515 | if 'noise_factor' in extra_params:
516 | noise_factor = extra_params['noise_factor'];
517 |
518 | if 'scale_factor' in extra_params:
519 | scale_factor = extra_params['scale_factor'];
520 |
521 | if 'flagComputeL' in extra_params:
522 | flagComputeL = extra_params['flagComputeL'];
523 |
524 |         # Generate the Fourier symbols for the operator
525 | flag_vv = 'null';
526 | if self.op_type == r'\grad{u}' or self.op_type == r'u\grad{u}' or self.op_type == r'\grad{u}\cdot\grad{u}':
527 | L_hat, vec_k1_pp, vec_k2_pp = self.computeLSymbol_grad_u();
528 | flag_vv = 'vector2';
529 | elif self.op_type == r'\Delta{u}' or self.op_type == r'u\Delta{u}' or self.op_type == r'\Delta{u}*\Delta{u}':
530 | L_hat, vec_k1_pp, vec_k2_pp = self.computeLSymbol_laplacian_u();
531 | flag_vv = 'scalar';
532 | else:
533 | raise Exception("Unknown operator type.");
534 |
535 | if (flagComputeL):
536 |             raise Exception("The flagComputeL option is not yet supported.");
537 |             L_i = np.fft.ifft2(L_hat);
538 |             self.L_hat = L_hat;
539 |             self.L_i = L_i;
540 |             u = np.zeros((nx,ny));
541 |             i0 = int(nx/2);
542 |             j0 = int(ny/2);
543 |             u[i0,j0] = 1.0;
544 |             self.G_i = self.computeActionL(u,L_hat);
545 |
546 | # Generate random input function (want real-valued)
547 | # conj(u_k) = u_{N -k} needs to hold.
548 | u_k_real = np.random.randn(num_samples,nchannels,nx,ny);
549 | u_k_imag = np.random.randn(num_samples,nchannels,nx,ny);
550 |
551 | # scale modes to make smooth
552 | if gen_mode=='exp1':
553 | alpha1 = gen_params['alpha1'];
554 | factor_k = scale_factor*np.exp(-alpha1*(vec_k1_pp**2 + vec_k2_pp**2));
555 | factor_k = factor_k.numpy();
556 | else:
557 | raise Exception("Generation mode not recognized.");
558 |
559 | u_k_real = u_k_real*factor_k; # broadcast will apply over last two dimensions
560 | u_k_imag = u_k_imag*factor_k; # broadcast will apply over last two dimensions
561 |
562 | # flip modes for constructing rep of real-valued function
563 | u_k_real_flip = self.flipForFFT(u_k_real);
564 | u_k_imag_flip = self.flipForFFT(u_k_imag);
565 |
566 | u_k_real = 0.5*u_k_real + 0.5*u_k_real_flip; # make conjugate conj(u_k) = u_{N -k}
567 | u_k_imag = 0.5*u_k_imag - 0.5*u_k_imag_flip; # make conjugate conj(u_k) = u_{N -k}
568 |
569 | u_k_real = torch.from_numpy(u_k_real);
570 | u_k_imag = torch.from_numpy(u_k_imag);
571 |
572 | u_k_real = u_k_real.type(torch.float32);
573 | u_k_imag = u_k_imag.type(torch.float32);
574 |
575 | u_hat = self.getComplex(u_k_real.numpy(),u_k_imag.numpy());
576 | if flag_vv == 'scalar':
577 | f_hat = self.computeCoeffActionL(u_hat,L_hat);
578 | elif flag_vv == 'vector2':
579 | f_hat_0 = self.computeCoeffActionL(u_hat,L_hat[0,:,:]);
580 | f_hat_1 = self.computeCoeffActionL(u_hat,L_hat[1,:,:]);
581 | f_hat = np.concatenate((f_hat_0,f_hat_1),-3);
582 | else:
583 |             raise Exception("Unknown operator type.");
584 |
585 | # Generate samples u and f using ifft2.
586 | # ifft2 is broadcast over last two indices
587 | # perform inverse DFT to get u and f
588 | u_i = np.fft.ifft2(u_hat);
589 | if flag_vv == 'scalar':
590 | f_i = np.fft.ifft2(f_hat);
591 | elif flag_vv == 'vector2':
592 | f_i_0 = np.fft.ifft2(f_hat[:,0,:,:]);
593 | f_i_1 = np.fft.ifft2(f_hat[:,1,:,:]);
594 | f_i = np.stack((f_i_0,f_i_1),-3);
595 | else:
596 |             raise Exception("Unknown operator type.");
597 |
598 | if self.op_type == r'\grad{u}':
599 | f_i = f_i; # nothing to do.
600 | elif self.op_type == r'u\grad{u}':
601 | f_i = u_i*f_i; # matches up by broadcast rules
602 | elif self.op_type == r'\grad{u}\cdot\grad{u}':
603 | f_i = np.sum(f_i**2,1); # sum on axis for channels, [batch,channel,nx,ny].
604 | f_i = np.expand_dims(f_i,1); # keep in form [batch,1,nx,ny]
605 | elif self.op_type == r'\Delta{u}':
606 | f_i = f_i; # nothing to do.
607 | elif self.op_type == r'u\Delta{u}':
608 | f_i = u_i*f_i;
609 | elif self.op_type == r'\Delta{u}*\Delta{u}':
610 | f_i = f_i**2;
611 | else:
612 |             raise Exception("Unknown operator type.");
613 |
614 | self.samples_X = torch.from_numpy(np.real(u_i)).type(torch.float32); # only grab real part
615 | self.samples_Y = torch.from_numpy(np.real(f_i)).type(torch.float32);
616 |
617 | if noise_factor > 0:
618 | self.samples_Y += noise_factor*torch.randn(*self.samples_Y.shape);
619 |
620 | def __len__(self):
621 | return self.samples_X.size()[0];
622 |
623 | def __getitem__(self,index):
624 | return self.samples_X[index],self.samples_Y[index];
625 |
626 | def to(self,device):
627 | self.samples_X = self.samples_X.to(device);
628 | self.samples_Y = self.samples_Y.to(device);
629 | return self;
630 |
631 |
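A minimal usage sketch for the 2D dataset, continuing the hypothetical import alias from the `diffOp1` sketch above:

```python
# Minimal usage sketch for diffOp2 (import alias as in the diffOp1 sketch; an assumption).
dataset2d = gmls_dataset.diffOp2(op_type=r'\Delta{u}', gen_mode='exp1',
                                 gen_params={'alpha1': 0.1},
                                 num_samples=500, nchannels=1, nx=15, ny=15)
u0, f0 = dataset2d[0]
print(u0.shape, f0.shape)  # expected: torch.Size([1, 15, 15]) for both
```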
--------------------------------------------------------------------------------
/src/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | .. image:: overview.png
3 |
4 | PyTorch implementation of GMLS-Nets. Module for neural networks for
5 | processing scattered data sets using Generalized Moving Least Squares (GMLS).
6 |
7 | If you find these codes or methods helpful for your project, please cite:
8 |
9 | | @article{trask_patel_gross_atzberger_GMLS_Nets_2019,
10 | | title={GMLS-Nets: A framework for learning from unstructured data},
11 |     |    author={Nathaniel Trask and Ravi G. Patel and Ben J. Gross and Paul J. Atzberger},
12 | | journal={arXiv:1909.05371},
13 | | month={September},
14 | | year={2019},
15 | | url={https://arxiv.org/abs/1909.05371}
16 | | }
17 |
18 | """
19 |
20 | # Authors: B.J. Gross and P.J. Atzberger
21 | # Website: http://atzberger.org/
22 |
23 | import torch;
24 | import torch.nn as nn;
25 | import torchvision;
26 | import torchvision.transforms as transforms;
27 |
28 | import numpy as np;
29 |
30 | import scipy.spatial as spatial # used for finding neighbors within distance $\epsilon$
31 |
32 | from collections import OrderedDict;
33 |
34 | import pickle as p;
35 |
36 | import pdb;
37 |
38 | import time;
39 |
40 | # ====================================
41 | # Custom Functions
42 | # ====================================
43 | class MapToPoly_Function(torch.autograd.Function):
44 | r"""
45 | This layer processes a collection of scattered data points consisting of a collection
46 | of values :math:`u_j` at points :math:`x_j`. For a collection of target points
47 | :math:`x_i`, local least-squares problems are solved for obtaining a local representation
48 | of the data over a polynomial space. The layer outputs a collection of polynomial
49 | coefficients :math:`c(x_i)` at each point and the collection of target points :math:`x_i`.
50 | """
51 | @staticmethod
52 | def weight_one_minus_r(z1,z2,params):
53 | r"""Weight function :math:`\omega(x_j,x_i) = \left(1 - r/\epsilon\right)^{\bar{p}}_+.`
54 |
55 | Args:
56 | z1 (Tensor): The first point. Tensor of shape [1,num_dims].
57 | z2 (Tensor): The second point. Tensor of shape [1,num_dims].
58 | params (dict): The parameters are 'p' for decay power and 'epsilon' for support size.
59 |
60 | Returns:
61 | Tensor: The weight evaluation over points.
62 | """
63 | epsilon = params['epsilon']; p = params['p'];
64 | r = torch.sqrt(torch.sum(torch.pow(z1 - z2,2),1));
65 |
66 | diff = torch.clamp(1 - (r/epsilon),min=0);
67 | eval = torch.pow(diff,p);
68 | return eval;
69 |
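For instance, with epsilon = 0.5 and p = 4, two points a distance 0.25 apart receive weight (1 - 0.25/0.5)^4 = 0.0625; a quick check (illustration only):

```python
# Quick evaluation of the weight function (illustration only).
import torch
params = {'epsilon': 0.5, 'p': 4}
z1 = torch.tensor([[0.00, 0.0]])
z2 = torch.tensor([[0.25, 0.0]])
w = MapToPoly_Function.weight_one_minus_r(z1, z2, params)
print(w)  # tensor([0.0625]) = (1 - 0.25/0.5)^4
```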
70 | @staticmethod
71 | def get_num_polys(porder,num_dim=None):
72 | r""" Returns the number of polynomials of given porder. """
73 | if num_dim == 1:
74 | num_polys = porder + 1;
75 | elif num_dim == 2:
76 | num_polys = int((porder + 2)*(porder + 1)/2);
77 | elif num_dim == 3:
78 | num_polys = 0;
79 | for beta in range(0,porder + 1):
80 | num_polys += int((porder - beta + 2)*(porder - beta + 1)/2);
81 | else:
82 | raise Exception("Number of dimensions not implemented currently. \n num_dim = %d."%num_dim);
83 |
84 | return num_polys;
85 |
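For example, porder = 2 yields 3 coefficients in 1D, 6 in 2D, and 10 in 3D:

```python
# Basis-size check (illustration only).
for d in (1, 2, 3):
    print(d, MapToPoly_Function.get_num_polys(porder=2, num_dim=d))  # 3, 6, 10
```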
86 | @staticmethod
87 | def eval_poly(pts_x,pts_x2_i0,c_star_i0,porder,flag_verbose):
88 | r""" Evaluates the polynomials locally around a target point xi given coefficients c. """
89 | # Evaluates the polynomial locally (this helps to assess the current fit).
90 | # Implemented for 1D, 2D, and 3D.
91 | #
92 | # 2D:
93 | # Computes Taylor Polynomials over x and y.
94 | # T_{k1,k2}(x1,x2) = (1.0/(k1 + k2)!)*(x1 - x01)^{k1}*(x2 - x02)^{k2}.
95 |         #   The number of terms is N = (porder + 1)*(porder + 2)/2.
96 | #
97 | # WARNING: Note the role of factorials and orthogonality here. The Taylor
98 | # expansion/polynomial formulation is not ideal and can give ill-conditioning.
99 | # It would be better to use orthogonal polynomials or other bases.
100 | #
101 | num_dim = pts_x.shape[1];
102 | if num_dim == 1:
103 | II = 0;
104 | alpha_factorial = 1.0;
105 | eval_p = torch.zeros(pts_x.shape[0],device=c_star_i0.device);
106 | for alpha in np.arange(0,porder + 1):
107 | if alpha >= 2:
108 | alpha_factorial *= alpha;
109 |                 if flag_verbose > 1: print("alpha = " + str(alpha));
110 | # for now, (x - x_*)^k, but ideally use orthogonal polynomials
111 | base_poly = torch.pow(pts_x[:,0] - pts_x2_i0[0],alpha);
112 | base_poly = base_poly/alpha_factorial;
113 | eval_p += c_star_i0[II]*base_poly;
114 | II += 1;
115 | elif num_dim == 2:
116 | II = 0;
117 | alpha_factorial = 1.0;
118 | eval_p = torch.zeros(pts_x.shape[0],device=c_star_i0.device);
119 | for alpha in np.arange(0,porder + 1):
120 | if alpha >= 2:
121 | alpha_factorial *= alpha;
122 | for k in np.arange(0,alpha + 1):
123 | if flag_verbose > 1: print("alpha = " + str(alpha)); print("k = " + str(k));
124 | # for now, (x - x_*)^k, but ideally use orthogonal polynomials
125 | base_poly = torch.pow(pts_x[:,0] - pts_x2_i0[0],alpha - k);
126 | # for now, (x - x_*)^k, but ideally use orthogonal polynomials
127 | base_poly = base_poly*torch.pow(pts_x[:,1] - pts_x2_i0[1],k);
128 | base_poly = base_poly/alpha_factorial;
129 | eval_p += c_star_i0[II]*base_poly;
130 | II += 1;
131 |         elif num_dim == 3: # caution, below gives initial results, but should be more fully validated
132 |             II = 0;
133 |             eval_p = torch.zeros(pts_x.shape[0],device=c_star_i0.device);
134 |             for beta in np.arange(0,porder + 1):
135 |                 base_poly_z = torch.pow(pts_x[:,2] - pts_x2_i0[2],beta);
136 |                 alpha_factorial = 1.0; # reset for each beta block
137 |                 for alpha in np.arange(0,porder - beta + 1):
138 |                     if alpha >= 2:
139 |                         alpha_factorial *= alpha;
140 |                     for k in np.arange(0,alpha + 1):
141 |                         if flag_verbose > 1: print("alpha = " + str(alpha)); print("k = " + str(k));
142 |                         # rebuild each term fresh so powers do not accumulate across iterations
143 |                         base_poly = base_poly_z*torch.pow(pts_x[:,0] - pts_x2_i0[0],alpha - k);
144 |                         base_poly = base_poly*torch.pow(pts_x[:,1] - pts_x2_i0[1],k);
145 |                         base_poly = base_poly/alpha_factorial;
146 |                         eval_p += c_star_i0[II]*base_poly;
147 |                         II += 1;
148 | else:
149 | raise Exception("Number of dimensions not implemented currently. \n num_dim = %d."%num_dim);
150 |
151 | return eval_p;
152 |
153 | @staticmethod
154 | def generate_mapping(weight_func,weight_func_params,
155 | porder,epsilon,
156 | pts_x1,pts_x2,
157 | tree_points=None,device=None,
158 | flag_verbose=0):
159 |         r""" Generates, for caching, the data for the mapping from field values (uj,xj) :math:`\rightarrow` (ci,xi).
160 |         This helps optimize the code and speed up later calculations that are done repeatedly."""
161 | if device is None:
162 | device = torch.device('cpu');
163 |
164 | map_data = {};
165 |
166 | num_dim = pts_x1.shape[1];
167 |
168 | if pts_x2 is None:
169 | pts_x2 = pts_x1;
170 |
171 | pts_x1 = pts_x1.to(device);
172 | pts_x2 = pts_x2.to(device);
173 |
174 | pts_x1_numpy = None; pts_x2_numpy = None;
175 | if tree_points is None: # build kd-tree of points for neighbor listing
176 | if pts_x1_numpy is None: pts_x1_numpy = pts_x1.cpu().numpy();
177 | tree_points = spatial.cKDTree(pts_x1_numpy);
178 |
179 | # Maps from u(x_j) on $x_j \in \mathcal{S}^1$ to a
180 | # polynomial representations in overlapping regions $\Omega_i$ at locations
181 | # around points $x_i \in \mathcal{S}^2$.
182 | # These two sample sets need not be the same allowing mappings between point locations.
183 | # Computes polynomials over x and y.
184 | # Number of terms in 2D is num_polys = (porder + 1)*(porder + 2)/2.
185 | num_pts1 = pts_x1.shape[0]; num_pts2 = pts_x2.shape[0];
186 | num_polys = MapToPoly_Function.get_num_polys(porder,num_dim);
187 | if flag_verbose > 0:
188 | print("num_polys = " + str(num_polys));
189 |
190 | M = torch.zeros((num_pts2,num_polys,num_polys),device=device); # assemble matrix at each grid-point
191 | M_inv = torch.zeros((num_pts2,num_polys,num_polys),device=device); # assemble matrix at each grid-point
192 |
193 | #svd_U = torch.zeros((num_pts2,num_polys,num_polys)); # assemble matrix at each grid-point
194 | #svd_S = torch.zeros((num_pts2,num_polys,num_polys)); # assemble matrix at each grid-point
195 | #svd_V = torch.zeros((num_pts2,num_polys,num_polys)); # assemble matrix at each grid-point
196 |
197 | vec_rij = torch.zeros((num_pts2,num_polys,num_pts1),device=device); # @optimize: ideally should be sparse matrix.
198 |
199 | # build up the batch of linear systems for each target point
200 | for i in np.arange(0,num_pts2): # loop over the points $x_i$
201 |
202 | if (flag_verbose > 0) & (i % 100 == 0): print("i = " + str(i) + " : num_pts2 = " + str(num_pts2));
203 |
204 | if pts_x2_numpy is None: pts_x2_numpy = pts_x2.cpu().numpy();
205 | indices_xj_i = tree_points.query_ball_point(pts_x2_numpy[i,:], epsilon); # find all points with distance
206 | # less than epsilon from xi.
207 |
208 | for j in indices_xj_i: # @optimize later to use only local points, and where weights are non-zero.
209 |
210 | if flag_verbose > 1: print("j = " + str(j));
211 |
212 | vec_p_j = torch.zeros(num_polys,device=device);
213 | w_ij = weight_func(pts_x1[j,:].unsqueeze(0), pts_x2[i,:].unsqueeze(0), weight_func_params); # can optimize for sub-lists outer-product
214 |
215 | # Computes Taylor Polynomials over x,y,z.
216 | #
217 | # 2D Case:
218 | # T_{k1,k2}(x1,x2) = (1.0/(k1 + k2)!)*(x1 - x01)^{k1}*(x2 - x02)^{k2}.
219 | # number of terms is N = (porder + 1)*(porder + 2)/2.
220 | # computes polynomials over x and y.
221 | #
222 | # WARNING: The monomial basis is non-ideal and can lead to ill-conditioned linear algebra.
223 | # This ultimately should be generalized in the future to other bases, ideally orthogonal,
224 | # which would help both with efficiency and conditioning of the linear algebra.
225 | #
226 | if num_dim == 1:
227 | # number of terms is N = porder + 1.
228 | II = 0;
229 | for alpha in np.arange(0,porder + 1):
230 |                             if flag_verbose > 1: print("alpha = " + str(alpha));
231 | # for now, (x - x_*)^k, but ideally use orthogonal polynomials
232 | vec_p_j[II] = torch.pow(pts_x1[j,0] - pts_x2[i,0], alpha);
233 | II += 1;
234 | elif num_dim == 2:
235 | # number of terms is N = (porder + 1)*(porder + 2)/2.
236 | II = 0;
237 | for alpha in np.arange(0,porder + 1):
238 | for k in np.arange(0,alpha + 1):
239 | if flag_verbose > 1: print("alpha = " + str(alpha)); print("k = " + str(k));
240 | # for now, (x - x_*)^k, but ideally use orthogonal polynomials
241 | vec_p_j[II] = torch.pow(pts_x1[j,0] - pts_x2[i,0], alpha - k);
242 | vec_p_j[II] = vec_p_j[II]*torch.pow(pts_x1[j,1] - pts_x2[i,1], k);
243 | II += 1;
244 |                     elif num_dim == 3:
245 |                         # number of terms is N = sum_{beta = 0}^{porder} (porder - beta + 1)*(porder - beta + 2)/2.
246 |                         II = 0;
247 |                         for beta in np.arange(0,porder + 1):
248 |                             base_poly_z = torch.pow(pts_x1[j,2] - pts_x2[i,2],beta);
249 |                             for alpha in np.arange(0,porder - beta + 1):
250 |                                 for k in np.arange(0,alpha + 1):
251 |                                     if flag_verbose > 1:
252 |                                         print("beta = " + str(beta)); print("alpha = " + str(alpha)); print("k = " + str(k));
253 |                                     # rebuild each entry fresh so powers do not accumulate across entries
254 |                                     vec_p_j[II] = base_poly_z*torch.pow(pts_x1[j,0] - pts_x2[i,0],alpha - k);
255 |                                     vec_p_j[II] = vec_p_j[II]*torch.pow(pts_x1[j,1] - pts_x2[i,1],k);
256 |                                     II += 1;
257 |
258 | # add contributions to the M(x_i) and r(x_i) terms
259 | # r += (w_ij*u[j])*vec_p_j;
260 | vec_rij[i,:,j] = w_ij*vec_p_j;
261 |                 M[i,:,:] += torch.ger(vec_p_j,vec_p_j)*w_ij; # outer-product of vectors (builds the batch of matrices)
262 |
263 | # Compute the SVD of M for purposes of computing the pseudo-inverse (for solving least-squares problem).
264 |             # Note: M is always symmetric positive semi-definite, so U and V agree here
265 |             # and the singular values equal the eigenvalues. This simplifies some expressions.
266 |
267 | U,S,V = torch.svd(M[i,:,:]); # M = U*SS*V^T, note SS = diag(S)
268 | threshold_nonzero = 1e-9; # threshold for the largest singular value to consider being non-zero.
269 | I_nonzero = (S > threshold_nonzero);
270 | S_inv = 0.0*S;
271 | S_inv[I_nonzero] = 1.0/S[I_nonzero];
272 | SS_inv = torch.diag(S_inv);
273 |             M_inv[i,:,:] = torch.matmul(V,torch.matmul(SS_inv,U.t())); # pseudo-inverse M^{+} = V*S^{-1}*U^T
274 |
275 | # Save the linear system information for the least-squares problem at each target point $xi$.
276 | map_data['M'] = M;
277 | map_data['M_inv'] = M_inv;
278 | map_data['vec_rij'] = vec_rij;
279 |
280 | return map_data;
281 |
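The assembly above is the standard weighted least-squares normal-equation system, M(x_i) = sum_j w_ij p_j p_j^T with c*(x_i) = M^+ sum_j w_ij u_j p_j. A minimal dense sketch of the same computation at a single target point (illustration only):

```python
# Dense single-target sketch of the GMLS normal equations (illustration only):
# c* = argmin_c sum_j w_j (p(x_j)^T c - u_j)^2  =>  c* = M^+ r, M = P^T W P, r = P^T W u.
import torch
xj = torch.linspace(-0.5, 0.5, 11).unsqueeze(1)      # scattered 1D points; target xi = 0
u = torch.sin(xj[:, 0])                              # field samples u_j
P = torch.stack([xj[:, 0]**a for a in range(3)], 1)  # monomial basis, porder = 2
w = torch.clamp(1 - xj[:, 0].abs()/0.6, min=0)**4    # weight_one_minus_r with epsilon=0.6, p=4
M = (P*w.unsqueeze(1)).t() @ P                       # M = sum_j w_j p_j p_j^T
r = (P*w.unsqueeze(1)).t() @ u                       # r = sum_j w_j u_j p_j
c_star = torch.linalg.pinv(M) @ r                    # pseudo-inverse, as in the SVD step above
print(c_star[1])                                     # approximates u'(0) = 1
```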
282 | @staticmethod
283 | def get_poly_1D_u(u, porder, weight_func, weight_func_params,
284 | pts_x1, epsilon = None, pts_x2 = None, cached_data=None,
285 | tree_points = None, device=None, flag_verbose = 0):
286 |         r""" Compute the polynomial coefficients in the case of a scalar field. Not typically called directly; used for internal purposes. """
287 |
288 | # We assume that all inputs are pytorch tensors
289 | # Assumes:
290 | # pts_x1.size = [num_pts,num_dim]
291 | # pts_x2.size = [num_pts,num_dim]
292 | #
293 | # @optimize: Should cache the points and neighbor lists... then using torch.solve, torch.ger.
294 | # Should vectorize all of the for-loop operations via Lambdifying polynomial evals.
295 | # Should avoid numpy calculations, maybe cache numpy copy of data if needed to avoid .cpu() transfer calls.
296 |         #  Use batching over points to do solves, then GPU parallelizable and faster.
297 | #
298 | if device is None:
299 | device = torch.device('cpu'); # default cpu device
300 |
301 | if (u.dim() > 1):
302 | print("u.dim = " + str(u.dim()));
303 | print("u.shape = " + str(u.shape));
304 | raise Exception("Assumes input with dimension == 1.");
305 |
306 | if (cached_data is None) or ('map_data' not in cached_data) or (cached_data['map_data'] is None):
307 | generate_mapping = MapToPoly_Function.generate_mapping;
308 |
309 | if pts_x2 is None:
310 | pts_x2 = pts_x1;
311 |
312 | map_data = generate_mapping(weight_func,weight_func_params,
313 | porder,epsilon,
314 | pts_x1,pts_x2,tree_points,device);
315 |
316 | if cached_data is not None:
317 | cached_data['map_data'] = map_data;
318 |
319 | else:
320 | map_data = cached_data['map_data']; # use cached data
321 |
322 | if flag_verbose > 0:
323 |             print("num_pts1 = " + str(pts_x1.shape[0]) + ", num_pts2 = " + str(pts_x2.shape[0]));
324 |
325 | if epsilon is None:
326 | raise Exception('The epsilon ball size to use around xi must be specified.')
327 |
328 | # Maps from u(x_j) on $x_j \in \mathcal{S}^1$ to a
329 | # polynomial representations in overlapping regions $\Omega_i$ at locations
330 | # around points $x_i \in \mathcal{S}^2$.
331 |
332 | # These two sample sets need not be the same allowing mappings between point sets.
333 | # Computes polynomials over x and y.
334 | # For 2D case, number of terms is num_polys = (porder + 1)*(porder + 2)/2.
335 |
336 | #c_star[:,i] = np.linalg.solve(np_M,np_r); # "c^*(x_i) = M^{-1}*r."
337 | vec_rij = map_data['vec_rij'];
338 | M_inv = map_data['M_inv'];
339 |
340 | r_all = torch.matmul(vec_rij,u);
341 |         c_star = torch.bmm(M_inv,r_all.unsqueeze(2)); # perform batch matrix-vector multiplications
342 | c_star = c_star.squeeze(2); # convert to list of vectors
343 |
344 | output = c_star;
345 | output = output.float(); # Map to float type for GPU / PyTorch Module compatibilities.
346 |
347 | return output, pts_x2;
348 |
349 | @staticmethod
350 | def forward(ctx, input, porder, weight_func, weight_func_params,
351 | pts_x1, epsilon = None, pts_x2 = None, cached_data=None,
352 | tree_points = None, device = None, flag_verbose = 0):
353 | r"""
354 |
355 | For a field u specified at points xj, performs the mapping to coefficients c at points xi, (uj,xj) :math:`\rightarrow` (ci,xi).
356 |
357 | Args:
358 | input (Tensor): The input field data uj.
359 | porder (int): Order of the basis to use (polynomial degree).
360 | weight_func (function): Weight function to use.
361 | weight_func_params (dict): Weight function parameters.
362 | pts_x1 (Tensor): The collection of domain points :math:`x_j`.
363 | epsilon (float): The :math:`\epsilon`-neighborhood size to use to sort points (should be compatible with choice of weight_func_params).
364 | pts_x2 (Tensor): The collection of target points :math:`x_i`.
365 |             cached_data (dict): Stored data to help speed up repeated calculations.
366 | tree_points (dict): Stored data to help speed up repeated calculations.
367 | device (torch.device): Device on which to perform calculations (GPU or other, default is CPU).
368 | flag_verbose (int): Level of reporting on progress during the calculations.
369 |
370 | Returns:
371 |             tuple: (ci,xi), the coefficient values ci at the target points xi, along with the target points xi.
372 |
373 | """
374 | if device is None:
375 | device = torch.device('cpu');
376 |
377 | ctx.atz_name = 'MapToPoly_Function';
378 |
379 | ctx.save_for_backward(input,pts_x1,pts_x2);
380 |
381 | ctx.atz_porder = porder;
382 | ctx.atz_weight_func = weight_func;
383 | ctx.atz_weight_func_params = weight_func_params;
384 |
385 | get_poly_1D_u = MapToPoly_Function.get_poly_1D_u;
386 | get_num_polys = MapToPoly_Function.get_num_polys;
387 |
388 | input_dim = input.dim();
389 | if input_dim >= 1: # compute c_star in batches
390 | pts_x1_numpy = None;
391 | pts_x2_numpy = None;
392 |
393 | if pts_x2 is None:
394 | pts_x2 = pts_x1;
395 |
396 | # reshape the data to handle as a batch [batch_size, uj_data_size]
397 | # We assume u is input in the form [I,k,xj], u_I(k,xj), the index I is arbitrary.
398 | u = input;
399 |
400 | if input_dim == 2: # need to unsqueeze, so 2D we are mapping
401 | # [k,xj] --> [I,k,xj] --> [II,c] --> [I,k,xi,c] --> [k,xi,c]
402 | u = u.unsqueeze(0); # u(k,xj) assumed in our calculations here
403 |
404 | if input_dim == 1: # need to unsqueeze, so 1D we are mapping
405 | # [xj] --> [I,k,xj] --> [II,c] --> [I,k,xi,c] --> [xi,c]
406 | u = u.unsqueeze(0); # u(k,xj) assumed in our calculations here
407 | u = u.unsqueeze(0); # u(k,xj) assumed in our calculations here
408 |
409 | u_num_dim = u.dim();
410 | size_nm1 = 1;
411 | for d in range(u_num_dim - 1):
412 | size_nm1 *= u.shape[d];
413 |
414 | uu = u.contiguous().view((size_nm1,u.shape[-1]));
415 |
416 | # compute the sizes of c_star and number of points
417 | num_dim = pts_x1.shape[1];
418 | num_polys = get_num_polys(porder,num_dim);
419 | num_pts2 = pts_x2.shape[0];
420 |
421 | # output needs to be of size [batch_size, xi_data_size, num_polys]
422 | output = torch.zeros((uu.shape[0],num_pts2,num_polys),device=device); # will reshape at the end
423 |
424 | # loop over the batches and compute the c_star in each case
425 | if cached_data is None:
426 | cached_data = {}; # create empty, which can be computed first time to store data.
427 |
428 | if tree_points is None:
429 | if pts_x1_numpy is None: pts_x1_numpy = pts_x1.cpu().numpy();
430 | tree_points = spatial.cKDTree(pts_x1_numpy);
431 |
432 | for k in range(uu.shape[0]):
433 | uuu = uu[k,:];
434 | out, pts_x2 = get_poly_1D_u(uuu,porder,weight_func,weight_func_params,
435 | pts_x1,epsilon,pts_x2,cached_data,
436 |                                         tree_points,device,flag_verbose);
437 | output[k,:,:] = out;
438 |
439 | # final output should be [*, xi_data_size, num_polys], where * is the original sizes
440 | # for indices [i1,i2,...in,k_channel,xi_data,c_poly_coeff].
441 | output = output.view(*u.shape[0:u_num_dim-1],num_pts2,num_polys);
442 |
443 | if input_dim == 2: # 2D special case we just return k, xi, c (otherwise feed input 3D [I,k,u(xj)] I=1,k=1).
444 | output = output.squeeze(0);
445 |
446 | if input_dim == 1: # 1D special case we just return xi, c (otherwise feed input 3D [I,k,u(xj)] I=1,k=1).
447 | output = output.squeeze(0);
448 | output = output.squeeze(0);
449 |
450 | else:
451 | print("input.dim = " + str(input.dim()));
452 | print("input.shape = " + str(input.shape));
453 |             raise Exception("input tensor dimension not supported; expected input with dim >= 1.");
454 |
455 | ctx.atz_cached_data = cached_data;
456 |
457 | pts_x2_clone = pts_x2.clone();
458 | return output, pts_x2_clone;
459 |
460 | @staticmethod
461 | def backward(ctx,grad_output,grad_pts_x2):
462 | r""" Consider a field u specified at points xj and the mapping to coefficients c at points xi, (uj,xj) --> (ci,xi).
463 | Computes the gradient of the mapping for backward propagation.
464 | """
465 |
466 | flag_time_it = False;
467 | if flag_time_it:
468 | time_1 = time.time();
469 |
470 | input,pts_x1,pts_x2 = ctx.saved_tensors;
471 |
472 | porder = ctx.atz_porder;
473 | weight_func = ctx.atz_weight_func;
474 | weight_func_params = ctx.atz_weight_func_params;
475 | cached_data = ctx.atz_cached_data;
476 |
477 | #grad_input = grad_weight_func = grad_weight_func_params = None;
478 | grad_uj = None;
479 |
480 | # we only compute the gradient in x_i, if it is requested (for efficiency)
481 | if ctx.needs_input_grad[0]: # derivative in uj
482 | map_data = cached_data['map_data']; # use cached data
483 |
484 | vec_rij = map_data['vec_rij'];
485 | M_inv = map_data['M_inv'];
486 |
487 | # c_i = M_{i}^{-1} r_i^T u
488 | # dF/du = dF/dc*dc/du,
489 | #
490 | # We can express this using dF/uj = sum_i dF/dci*dci/duj
491 | #
492 | # grad_output = dF/dc, grad_input = dF/du
493 | #
494 | # [grad_input]_j = sum_i dF/ci*dci/duj.
495 | #
496 | # In practice, we have both batch and channel indices so
497 | # grad_output.shape = [batchI,channelI,i,compK]
498 |         # grad_output[batchI,channelI,i,compK] = derivative of F(batchI,channelI) with respect to ci[compK](batchI,channelI).
499 |         #
500 |         # grad_input[batchI,channelI,j] = sum_{i,compK} grad_output[batchI,channelI,i,compK]*(dci[compK]/duj).
501 |         #
502 | # We use matrix broadcasting to get this outcome in practice.
503 | #
504 |
505 | # @optimize can optimize, since uj only contributes non-zero to a few ci's... and could try to use sparse matrix multiplications.
506 | A1 = torch.bmm(M_inv,vec_rij); # dci/du, grad = grad[i,compK,j]
507 | A2 = A1.unsqueeze(0).unsqueeze(0); # match grad_output tensor rank, for grad[batchI,channelI,i,compK,j]
508 | A3 = grad_output.unsqueeze(4); # grad_output[batchI,channelI,i,compK,j]
509 | A4 = A3*A2; # elementwise multiplication
510 | A5 = torch.sum(A4,3); # contract on index compK
511 | A6 = torch.sum(A5,2); # contract on index i
512 |
513 | grad_uj = A6;
514 |
515 | else:
516 | msg_str = "Requested a currently un-implemented gradient for this map: \n";
517 | msg_str += "ctx.needs_input_grad = \n" + str(ctx.needs_input_grad);
518 | raise Exception(msg_str);
519 |
520 | if flag_time_it:
521 | msg = 'MapToPoly_Function->backward():';
522 | msg += 'elapsed_time = %.4e'%(time.time() - time_1);
523 | print(msg);
524 |
525 | return grad_uj,None,None,None,None,None,None,None,None,None,None; # since no trainable parts for these components of map
526 |
527 |
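A minimal end-to-end sketch of the mapping (uj,xj) --> (ci,xi) using the function above (illustration only):

```python
# Map field samples u_j at points x_j to GMLS coefficients c_i (illustration only).
import math
import torch
pts = torch.linspace(0, 1, 21).unsqueeze(1)    # points x_j on [0,1]
u = torch.sin(2*math.pi*pts[:, 0])             # 1D input case: u has dim == 1
wf = MapToPoly_Function.weight_one_minus_r
wf_params = {'epsilon': 0.2, 'p': 4}
c, pts_xi = MapToPoly_Function.apply(u, 2, wf, wf_params, pts, 0.2)  # porder=2, epsilon=0.2
print(c.shape)  # torch.Size([21, 3]): three polynomial coefficients per target point
```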
528 | class MaxPoolOverPoints_Function(torch.autograd.Function):
529 | r"""Applies a max-pooling operation to obtain values :math:`v_i = \max_{j \in \mathcal{N}_i(\epsilon)} \{u_j\}.` """
530 | # @optimize: Should cache the points and neighbor lists.
531 | # Should avoid numpy calculations, maybe cache numpy copy of data if needed to avoid .cpu() transfer calls.
532 |     #  Use batching over points to do solves, then GPU parallelizable and faster.
533 | @staticmethod
534 | def forward(ctx,input,pts_x1,epsilon=None,pts_x2=None,
535 | indices_xj_i_cache=None,tree_points=None,
536 | flag_verbose=0):
537 | r"""Compute max pool operation from values at points (uj,xj) to obtain (vi,xi).
538 |
539 | Args:
540 | input (Tensor): The uj values at the location of points xj.
541 | pts_x1 (Tensor): The collection of domain points :math:`x_j`.
542 | epsilon (float): The :math:`\epsilon`-neighborhood size to use to sort points (should be compatible with choice of weight_func_params).
543 | pts_x2 (Tensor): The collection of target points :math:`x_i`.
544 | tree_points (dict): Stored data to help speed up repeated calculations.
545 | flag_verbose (int): Level of reporting on progress during the calculations.
546 |
547 | Returns:
548 | tuple: The collection ui at target points (same size as uj in the non-j indices). The collection xi of target points. Tuple of form (ui,xi).
549 |
550 | Note:
551 | We assume that all inputs are pytorch tensors with pts_x1.shape = [num_pts,num_dim] and similarly for pts_x2.
552 |
553 | """
554 |
555 | ctx.atz_name = 'MaxPoolOverPoints_Function';
556 |
557 | ctx.save_for_backward(input,pts_x1,pts_x2);
558 |
559 | u = input.clone(); # map input values u(xj) at xj to max value in epsilon neighborhood to u(xi) at xi points.
560 |
561 | # Assumes that input is of size [k1,k2,...,kn,j], where k1,...,kn are any indices.
562 | # We perform maxing over batch over all non-indices in j.
563 | # We reshape tensor to the form [*,j] where one index in *=index(k1,...,kn).
564 | u_num_dim = u.dim();
565 | size_nm1 = 1;
566 | for d in range(u_num_dim - 1):
567 | size_nm1 *= u.shape[d];
568 |
569 | uj = u.contiguous().view((size_nm1,u.shape[-1])); # reshape so indices --> [I,j], I = index(k1,...,kn).
570 |
571 | # reshaped
572 | if pts_x2 is None:
573 | pts_x2 = pts_x1;
574 |
575 | pts_x1_numpy = pts_x1.cpu().numpy(); pts_x2_numpy = pts_x2.cpu().numpy(); # move to cpu to get numpy data
576 | pts_x1 = pts_x1.to(input.device); pts_x2 = pts_x2.to(input.device); # push back to GPU [@optimize later]
577 | num_pts1 = pts_x1.size()[0]; num_pts2 = pts_x2.size()[0];
578 | if flag_verbose > 0:
579 | print("num_pts1 = " + str(num_pts1) + ", num_pts2 = " + str(num_pts2));
580 |
581 | if epsilon is None:
582 | raise Exception('The epsilon ball size to use around xi must be specified.');
583 |
584 | ctx.atz_epsilon = epsilon;
585 |
586 | if indices_xj_i_cache is None:
587 | flag_need_indices_xj_i = True;
588 | else:
589 | flag_need_indices_xj_i = False;
590 |
591 | if flag_need_indices_xj_i and tree_points is None: # build kd-tree of points for neighbor listing
592 | tree_points = spatial.cKDTree(pts_x1_numpy);
593 |
594 | ctx.atz_tree_points = tree_points;
595 | ctx.indices_xj_i_cache = indices_xj_i_cache;
596 |
597 | # Maps from u(x_j) on $x_j \in \mathcal{S}^1$ to a u(x_i) giving max values in epsilon neighborhoods.
598 | # @optimize by caching these data structure for re-use later
599 | ui = torch.zeros(size_nm1,num_pts2,requires_grad=False,device=input.device);
600 | ui_argmax_j = torch.zeros(size_nm1,num_pts2,dtype=torch.int64,requires_grad=False,device=input.device);
601 | # assumes array of form [*,num_pts2], will be reshaped to match uj, [*,num_pts2].
602 |
603 | for i in np.arange(0,num_pts2): # loop over the points $x_i$
604 | if flag_verbose > 1: print("i = " + str(i) + " : num_pts2 = " + str(num_pts2));
605 |
606 | # find all points distance epsilon from xi
607 | if flag_need_indices_xj_i:
608 | indices_xj_i = tree_points.query_ball_point(pts_x2_numpy[i,:], epsilon);
609 | indices_xj_i = torch.Tensor(indices_xj_i).long();
610 |                 indices_xj_i = indices_xj_i.to(uj.device); # note: .to() is not in-place
611 | else:
612 | indices_xj_i = indices_xj_i_cache[i,:]; # @optimize should consider replacing with better data structures
613 |
614 | # take max over neighborhood. Assumes for now that ui is scalar.
615 | uuj = uj[:,indices_xj_i];
616 | qq = torch.max(uuj,dim=-1,keepdim=True);
617 | ui[:,i] = qq[0].squeeze(-1); # store max value
618 | jj = qq[1].squeeze(-1); # store index of max value
619 | ui_argmax_j[:,i] = indices_xj_i[jj]; # store global index of the max value
620 |
621 | # reshape the tensor from ui[I,i] to the form uui[k1,k2,...kn,i]
622 | uui = ui.view(*u.shape[0:u_num_dim-1],num_pts2);
623 | uui_argmax_j = ui_argmax_j.view(*u.shape[0:u_num_dim-1],num_pts2);
624 | ctx.atz_uui_argmax_j = uui_argmax_j; # save for gradient calculation
625 |
626 |         output = uui; # for now, we assume ui is a scalar array of size [num_pts2]
627 | output = output.to(input.device);
628 |
629 | return output, pts_x2.clone();
630 |
631 | @staticmethod
632 | def backward(ctx,grad_output,grad_pts_x2):
633 | r"""Compute gradients of the max pool operations from values at points (uj,xj) --> (max_ui,xi). """
634 |
635 | flag_time_it = False;
636 | if flag_time_it:
637 | time_11 = time.time();
638 |
639 |         # Compute df/dx from df/dy using the Chain Rule df/dx = (df/dy)*(dy/dx).
640 | # Compute the gradient with respect to inputs, dz/dx.
641 | #
642 | # Consider z = f(g(x)), where we refer to x as the inputs and y = g(x) as outputs.
643 | # If we know dz/dy, we would like to compute dz/dx. This will follow from the chain-rule
644 | # as dz/dx = (dz/dy)*(dy/dx). We call dz/dy the gradient with respect to output and we call
645 | # dy/dx the gradient with respect to input.
646 | #
647 | # Note: the grad_output can be larger than the size of the input vector if we include in our
648 | # definition of gradient_input the derivatives with respect to weights. Should think of everything
649 | # input as tilde_x = [x,weights,bias,etc...], then grad_output = dz/dtilde_x.
650 |
651 | input,pts_x1,pts_x2 = ctx.saved_tensors;
652 |
653 | uui_argmax_j = ctx.atz_uui_argmax_j;
654 |
655 | #grad_input = grad_weight_func = grad_weight_func_params = None;
656 | grad_input = None;
657 |
658 | # We only compute the gradient in xi, if it is requested (for efficiency)
659 | # stubs for later possible use, but not needed for now
660 | if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
661 | msg_str = "Currently requested a non-trainable gradient for this map: \n";
662 | msg_str += "ctx.needs_input_grad = \n" + str(ctx.needs_input_grad);
663 | raise Exception(msg_str);
664 |
665 | if ctx.needs_input_grad[0]:
666 | # Compute dL/duj = (dL/dvi)*(dvi/duj), here vi = uui.
667 | # For the max-pool case, notice that dvi/duj is non-zero only when the index uj
668 | # was the maximum value in the neighborhood of vi. Notice subtle issue with
669 | # right and left derivatives being different, so max is not differentiable for ties.
670 | # We use the right derivative lim_h (q(x + h) - q(x))/h, here.
671 |
672 | # We assume that uj.size = [k1,k2,...,kn,j], ui.size = [k1,k2,...,kn,i].
673 | # These are reshaped so that uuj.size = [I,j] and uui.size = [I,i].
674 | input_dim = input.dim();
675 | size_uj = input.size();
676 | size_uj_nm1 = np.prod(size_uj[0:input_dim-1]); # exclude last index size
677 |
678 |             ss_grad_input = input.new_zeros(size_uj_nm1,size_uj[-1]); # to store dL/duj, [I,j] indexing (used by method2 below).
679 | ss_grad_output = grad_output.contiguous().view((size_uj_nm1,grad_output.shape[-1])); # reshape so index [I,i].
680 | ss_uui_argmax_j = uui_argmax_j.contiguous().view((size_uj_nm1,grad_output.shape[-1])); # reshape so index [I,i].
681 |
682 | # assign the entries k_i = argmax_{j in Omega_i} uj, reshaped so [*,j] = val[*,j].
683 | flag_method = 'method1';
684 | if flag_method == 'method1':
685 |
686 | flag_time_it = False;
687 | if flag_time_it:
688 | time_0 = time.time();
689 |
690 | I = torch.arange(0,size_uj_nm1,dtype=torch.int64,device=input.device);
691 | vec_ones = torch.ones(grad_output.shape[-1],dtype=torch.int64,device=input.device);
692 | II = torch.ger(I.float(),vec_ones.float()); # careful int --> float conv
693 | II = II.flatten();
694 | JJ = ss_uui_argmax_j.flatten();
695 | IJ_indices1 = torch.stack((II,JJ.float())).long();
696 |
697 | i_index = torch.arange(0,grad_output.shape[-1],dtype=torch.int64,device=input.device);
698 | vec_ones = torch.ones(size_uj_nm1,dtype=torch.int64,device=input.device);
699 | KK = torch.ger(vec_ones.float(),i_index.float()); # careful int --> float conv
700 | KK = KK.flatten();
701 | IJ_indices2 = torch.stack((II,KK)).long();
702 |
703 | # We aim to compute dL/duj = dL/d\bar{u}_i*d\bar{u}_i/duj.
704 | #
705 | # This is done efficiently by constructing a sparse matrix using how \bar{u}_i
706 | # depends on the uj. Sometimes the same uj contributes multiple times to
707 | # a given \bar{u}_i entry, so we add together those contributions, as would
708 | # occur in an explicit multiplication of the terms above for dL/duj.
709 |             # This is achieved efficiently using the .add() for sparse tensors in PyTorch.
710 | 
711 |             # We construct entries of the sparse matrix and coalesce them (add repeats).
712 | vals = ss_grad_output[IJ_indices2[0,:],IJ_indices2[1,:]]; # @optimize, maybe just flatten
713 | N1 = size_uj_nm1; N2 = size_uj[-1]; sz = torch.Size([N1,N2]);
714 | ss_grad_input = torch.sparse.FloatTensor(IJ_indices1,vals,sz).coalesce().to_dense();
715 |
716 | if flag_time_it:
717 | time_1 = time.time();
718 |
719 | print("time: backward(): compute ss_grad_input = %.4e sec"%(time_1 - time_0));
720 |
721 | elif flag_method == 'method2':
722 | II = torch.arange(0,size_uj_nm1,dtype=torch.int64);
723 | i_index = torch.arange(0,grad_output.shape[-1],dtype=torch.int64);
724 |
725 | # @optimize by vectorizing this calculation
726 | for I in II:
727 | for j in range(0,i_index.shape[0]):
728 | ss_grad_input[I,ss_uui_argmax_j[I,j]] += ss_grad_output[I,i_index[j]];
729 |
730 | else:
731 | raise Exception("flag_method type not recognized.\n flag_method = %s"%flag_method);
732 |
733 | # reshape
734 | grad_input = ss_grad_input.view(*size_uj[0:input_dim - 1],size_uj[-1]);
735 |
736 | if flag_time_it:
737 |             msg = 'MaxPoolOverPoints_Function->backward(): ';
738 | msg += 'elapsed_time = %.4e'%(time.time() - time_11);
739 | print(msg);
740 |
741 | return grad_input,None,None,None,None,None,None; # since no trainable parts for components of this map
742 |
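A small sketch of the epsilon-ball max-pool (illustration only):

```python
# Epsilon-ball max-pool sketch (illustration only).
import torch
pts = torch.linspace(0, 1, 5).unsqueeze(1)                  # 5 points with spacing 0.25
u = torch.tensor([[0., 3., 1., 2., 0.]])                    # values u_j, shape [1, num_pts]
vi, pts_xi = MaxPoolOverPoints_Function.apply(u, pts, 0.3)  # epsilon = 0.3 reaches nearest neighbors
print(vi)  # tensor([[3., 3., 3., 2., 2.]]): v_i = max of u_j within distance 0.3 of x_i
```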
743 | class ExtractFromTuple_Function(torch.autograd.Function):
744 | r"""Extracts from a tuple of outputs one of the components."""
745 |
746 | @staticmethod
747 | def forward(ctx,input,index):
748 | r"""Extracts tuple entry with the specified index."""
749 | ctx.atz_name = 'ExtractFromTuple_Function';
750 |
751 | extracted = input[index];
752 | output = extracted.clone(); # clone added for safety
753 |
754 | return output;
755 |
756 | @staticmethod
757 | def backward(ctx,grad_output): # number grad's needs to match outputs of forward
758 | r"""Computes gradient of the extraction."""
759 |
760 | raise Exception('This backward is not implemented, since PyTorch automatically handled this in the past.');
761 | return None,None;
762 |
763 | # ====================================
764 | # Custom Modules
765 | # ====================================
766 | class PdbSetTraceLayer(nn.Module):
767 | r"""Allows for placing break-points within the call sequence of layers using pdb.set_trace(). Helpful for debugging networks."""
768 |
769 | def __init__(self):
770 | r"""Initialization (currently nothing to do, but call super-class)."""
771 | super(PdbSetTraceLayer, self).__init__()
772 |
773 | def forward(self, input):
774 | r"""Executes a PDB breakpoint inside of a running network to help with debugging."""
775 | out = input.clone(); # added clone to avoid .grad_fn overwrite
776 | pdb.set_trace();
777 | return out;
778 |
779 | class ExtractFromTuple(nn.Module):
780 | r"""Extracts from a tuple of outputs one of the components."""
781 |
782 | def __init__(self,index=0):
783 | r"""Initializes the index to extract."""
784 | super(ExtractFromTuple, self).__init__()
785 | self.index = index;
786 |
787 | def forward(self, input):
788 | r"""Extracts the tuple entry with the specified index."""
789 | extracted = input[self.index];
790 | extracted_clone = extracted.clone(); # cloned to avoid overwrite of .grad_fn
791 | return extracted_clone;
792 |
793 | class ReshapeLayer(nn.Module):
794 | r"""Performs reshaping of a tensor output within a network."""
795 |
796 | def __init__(self,reshape,permute=None):
797 |         r"""Initializes the reshaping form to use, followed by the indexing permutation to apply."""
798 | super(ReshapeLayer, self).__init__()
799 | self.reshape = reshape;
800 | self.permute = permute;
801 |
802 | def forward(self, input):
803 | r"""Reshapes the tensor followed by applying a permutation to the indexing."""
804 | reshape = self.reshape;
805 | permute = self.permute;
806 | A = input.contiguous();
807 | out = A.view(*reshape);
808 | if permute is not None:
809 | out = out.permute(*permute);
810 | return out;
811 |
812 | class PermuteLayer(nn.Module):
813 | r"""Performs permutation of indices of a tensor output within a network."""
814 |
815 | def __init__(self,permute=None):
816 |         r"""Initializes the indexing permutation to apply to tensors."""
817 | super(PermuteLayer, self).__init__()
818 | self.permute = permute;
819 |
820 | def forward(self, input):
821 |         r"""Applies an indexing permutation to the input tensor."""
822 | permute = self.permute;
823 | input_clone = input.clone(); # adding clone to avoid .grad_fn overwrites
824 | out = input_clone.permute(*permute);
825 | return out;
826 |
827 | class MLP_Pointwise(nn.Module):
828 | r"""Creates a collection of multilayer perceptrons (MLPs) for each output channel.
829 | The MLPs are then applied at each target point xi.
830 | """
831 |
832 | def create_mlp_unit(self,layer_sizes,unit_name='',flag_bias=True):
833 | r"""Creates an instance of an MLP with specified layer sizes. """
834 | layer_dict = OrderedDict();
835 | NN = len(layer_sizes);
836 | for i in range(NN - 1):
837 | key_str = unit_name + ':hidden_layer_%.4d'%(i + 1);
838 | layer_dict[key_str] = nn.Linear(layer_sizes[i], layer_sizes[i+1],bias=flag_bias);
839 | if i < NN - 2: # last layer should be linear
840 | key_str = unit_name + ':relu_%.4d'%(i + 1);
841 | layer_dict[key_str] = nn.ReLU();
842 |
843 | mlp_unit = nn.Sequential(layer_dict); # uses ordered dictionary to create network
844 |
845 | return mlp_unit;
846 |
847 | def __init__(self,layer_sizes,channels_in=1,channels_out=1,flag_bias=True,flag_verbose=0):
848 |         r"""Initializes the structure of the pointwise MLP module with the layer sizes, number of input channels, and number of output channels.
849 |
850 | Args:
851 | layer_sizes (list): The number of hidden units in each layer.
852 | channels_in (int): The number of input channels.
853 | channels_out (int): The number of output channels.
854 |             flag_bias (bool): Whether the MLP layers should include an additive bias b.
855 | flag_verbose (int): The level of messages generated on progress of the calculation.
856 |
857 | """
858 | super(MLP_Pointwise, self).__init__();
859 |
860 | self.layer_sizes = layer_sizes;
861 | self.flag_bias = flag_bias;
862 | self.depth = len(layer_sizes);
863 |
864 | self.channels_in = channels_in;
865 | self.channels_out = channels_out;
866 |
867 | # create intermediate layers
868 | mlp_list = nn.ModuleList();
869 |         layer_sizes_unit = layer_sizes.copy(); # the unit MLPs take stacked inputs of size k*c to mix across channels
870 | layer_sizes_unit[0] = layer_sizes_unit[0]*channels_in; # modify the input to have proper size combined k*c
871 | for ell in range(channels_out):
872 | mlp_unit = self.create_mlp_unit(layer_sizes_unit,'unit_ell_%.4d'%ell,flag_bias=flag_bias);
873 | mlp_list.append(mlp_unit);
874 |
875 | self.mlp_list = mlp_list;
876 |
877 | def forward(self, input, params = None):
878 | r"""Applies the specified MLP pointwise to the collection of input data to produce pointwise entries of the output channels."""
879 | #
880 | # Assumes the tensor has the form [i1,i2,...in,k,c], the last two indices are the
881 | # channel index k, and the coefficient index c, combine for ease of use, but can reshape.
882 | # We collapse input tensor with indexing [i1,i2,...in,k,c] to a [I,k*c] tensor, where
883 | # I is general index, k is channel, and c are coefficient index.
884 | #
885 | s = input.shape;
886 | num_dim = input.dim();
887 |
888 | if (s[-2] != self.channels_in) or (s[-1] != self.layer_sizes[0]): # check correct sized inputs
889 | print("input.shape = " + str(input.shape));
890 | raise Exception("MLP assumes an input tensor of size [*,%d,%d]"%(self.channels_in,self.layer_sizes[0]));
891 |
892 | calc_size1 = 1.0;
893 | for d in range(num_dim-2):
894 | calc_size1 *= s[d];
895 | calc_size1 = int(calc_size1);
896 |
897 | x = input.contiguous().view(calc_size1,s[num_dim-2]*s[num_dim-1]); # shape input to have indexing [I,k*NN + c]
898 |
899 | if params is None:
900 | output = torch.zeros((self.channels_out,x.shape[0]),device=input.device); # shape [ell,*]
901 | for ell in range(self.channels_out):
902 | mlp_q = self.mlp_list[ell];
903 | output[ell,:] = mlp_q.forward(x).squeeze(-1); # reduce from [N,1] to [N]
904 |
905 | s = input.shape;
906 | output = output.view(self.channels_out,*s[0:num_dim-2]); # shape to have index [ell,i1,i2,...,in]
907 | nn = output.dim();
908 | p_ind = np.arange(nn) + 1;
909 | p_ind[nn-1] = 0;
910 | p_ind = tuple(p_ind);
911 | output = output.permute(p_ind); # [*,ell] indexing of final shape
912 | else:
913 | raise Exception("Not yet implemented for setting parameters.");
914 |
915 | return output; # [*,ell] indexing of final shape
916 |
917 | def to(self, device):
918 | r"""Moves data to GPU or other specified device."""
919 | super(MLP_Pointwise, self).to(device);
920 | for ell in range(self.channels_out):
921 | mlp_q = self.mlp_list[ell];
922 | mlp_q.to(device);
923 | return self;
924 |
925 |
926 | class MLP1(nn.Module):
927 | r"""Creates a multilayer perceptron (MLP). """
928 |
929 | def __init__(self, layer_sizes, flag_bias = True, flag_verbose=0):
930 | r"""Initializes MLP and specified layer sizes."""
931 | super(MLP1, self).__init__();
932 |
933 | self.layer_sizes = layer_sizes;
934 | self.flag_bias = flag_bias;
935 | self.depth = len(layer_sizes);
936 |
937 | # create intermediate layers
938 | layer_dict = OrderedDict();
939 | NN = len(layer_sizes);
940 | for i in range(NN - 1):
941 | key_str = 'hidden_layer_%.4d'%(i + 1);
942 | layer_dict[key_str] = nn.Linear(layer_sizes[i], layer_sizes[i+1],bias=flag_bias);
943 | if i < NN - 2: # last layer should be linear
944 | key_str = 'relu_%.4d'%(i + 1);
945 | layer_dict[key_str] = nn.ReLU();
946 |
947 | self.layers = nn.Sequential(layer_dict); # uses ordered dictionary to create network
948 |
949 | def forward(self, input, params = None):
950 | r"""Applies the MLP to the input data.
951 |
952 | Args:
953 | input (Tensor): The coefficient channel data organized as one stacked
954 | vector of size Nc*M, where Nc is number of channels and M is number of
955 | coefficients per channel.
956 |
957 | Returns:
958 |             Tensor: The evaluation of the network, a tensor of size [batch,layer_sizes[-1]] (typically [batch,1]).
959 |
960 | """
961 | # evaluate network with specified layers
962 |         if params is None:
963 |             out = self.layers.forward(input); # named 'out' to avoid shadowing the builtin eval
964 |         else:
965 |             raise Exception("Not yet implemented for setting parameters.");
966 | 
967 |         return out;
968 |
969 | def to(self, device):
970 | r"""Moves data to GPU or other specified device."""
971 | super(MLP1, self).to(device);
972 | self.layers = self.layers.to(device);
973 | return self;
974 |
975 | class MapToPoly(nn.Module):
976 | r"""
977 | This layer processes a collection of scattered data points consisting of a collection
978 | of values :math:`u_j` at points :math:`x_j`. For a collection of target points
979 |     :math:`x_i`, local least-squares problems are solved to obtain a local representation
980 | of the data over a polynomial space. The layer outputs a collection of polynomial
981 | coefficients :math:`c(x_i)` at each point and the collection of target points :math:`x_i`.
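    | 
    |     Example (a minimal sketch; the points, epsilon, and weight parameters are illustrative):
    | 
    |     >>> xj = torch.randn((100,2)); xi = torch.randn((50,2));
    |     >>> weight_func = MapToPoly_Function.weight_one_minus_r;
    |     >>> weight_func_params = {'epsilon':1e-3,'p':4};
    |     >>> map_layer = MapToPoly(porder=2,weight_func=weight_func,
    |     ...                       weight_func_params=weight_func_params,
    |     ...                       pts_x1=xj,epsilon=1e-3,pts_x2=xi);
    |     >>> ci, xi2 = map_layer(uj); # uj: field values at the points xj (assumed given)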
982 | """
983 | def __init__(self, porder, weight_func, weight_func_params, pts_x1,
984 | epsilon = None,pts_x2 = None,tree_points = None,
985 | device = None,flag_verbose = 0,**extra_params):
986 | r"""Initializes the layer for mapping between field data uj at points xj to the
987 | local polynomial reconstruction represented by coefficients ci at target points xi.
988 |
989 | Args:
990 |             porder (int): Order of the basis to use. For a polynomial basis, this is the degree.
991 | weight_func (func): Weight function to use.
992 | weight_func_params (dict): Weight function parameters.
993 | pts_x1 (Tensor): The collection of domain points :math:`x_j`.
994 |             epsilon (float): The :math:`\epsilon`-neighborhood size to use when selecting nearby points (should be compatible with the choice of weight_func_params).
995 | pts_x2 (Tensor): The collection of target points :math:`x_i`.
996 | tree_points (dict): Stored data to help speed up repeated calculations.
997 | device: Device on which to perform calculations (GPU or other, default is CPU).
998 | flag_verbose (int): Level of reporting on progress during the calculations.
999 | **extra_params: Extra parameters allowing for specifying layer name and caching mode.
1000 |
1001 | """
1002 | super(MapToPoly, self).__init__();
1003 |
1004 | self.flag_verbose = flag_verbose;
1005 |
1006 | if device is None:
1007 | device = torch.device('cpu');
1008 |
1009 | self.device = device;
1010 |
1011 | if 'name' in extra_params:
1012 | self.name = extra_params['name'];
1013 | else:
1014 | self.name = "default_name";
1015 |
1016 | if 'flag_cache_mode' in extra_params:
1017 | flag_cache_mode = extra_params['flag_cache_mode'];
1018 | else:
1019 | flag_cache_mode = 'generate1';
1020 |
1021 | if flag_cache_mode == 'generate1': # setup from scratch
1022 | self.porder = porder;
1023 | self.weight_func = weight_func;
1024 | self.weight_func_params = weight_func_params;
1025 |
1026 | self.pts_x1 = pts_x1;
1027 | self.pts_x2 = pts_x2;
1028 |
1029 | self.pts_x1_numpy = None;
1030 | self.pts_x2_numpy = None;
1031 |
1032 | if self.pts_x2 is None:
1033 | self.pts_x2 = pts_x1;
1034 |
1035 | self.epsilon = epsilon;
1036 |
1037 |             if tree_points is None: # build kd-tree of points for neighbor listing
1038 |                 if self.pts_x1_numpy is None: self.pts_x1_numpy = pts_x1.cpu().numpy();
1039 |                 self.tree_points = spatial.cKDTree(self.pts_x1_numpy);
1040 |             else: # use the kd-tree passed in by the caller
1041 |                 self.tree_points = tree_points;
1045 |
1046 | self.cached_data = {}; # create empty cache for storing data
1047 | generate_mapping = MapToPoly_Function.generate_mapping;
1048 | self.cached_data['map_data'] = generate_mapping(self.weight_func,self.weight_func_params,
1049 | self.porder,self.epsilon,
1050 | self.pts_x1,self.pts_x2,
1051 | self.tree_points,self.device,
1052 | self.flag_verbose);
1053 |
1054 | elif flag_cache_mode == 'load_from_file': # setup by loading data from cache file
1055 |
1056 | if 'cache_filename' in extra_params:
1057 | cache_filename = extra_params['cache_filename'];
1058 | else:
1059 | raise Exception('No cache_filename specified.');
1060 |
1061 | self.load_cache_data(cache_filename); # load data from file
1062 |
1063 | else:
1064 | print("flag_cache_mode = " + str(flag_cache_mode));
1065 | raise Exception('flag_cache_mode is invalid.');
1066 |
1067 | def save_cache_data(self,cache_filename):
1068 | r"""Save needed matrices and related data to .pickle for later cached use. (Warning: prototype codes here currently and not tested)."""
1069 | # collect the data to save
1070 | d = {};
1071 | d['porder'] = self.porder;
1072 | d['epsilon'] = self.epsilon;
1073 |         if self.pts_x1_numpy is None: self.pts_x1_numpy = self.pts_x1.cpu().numpy();
1074 |         d['pts_x1'] = self.pts_x1_numpy;
1075 |         if self.pts_x2_numpy is None: self.pts_x2_numpy = self.pts_x2.cpu().numpy();
1076 | d['pts_x2'] = self.pts_x2_numpy;
1077 | d['weight_func_str'] = str(self.weight_func);
1078 | d['weight_func_params'] = self.weight_func_params;
1079 | d['version'] = __version__; # Module version
1080 | d['cached_data'] = self.cached_data;
1081 |
1082 | # write the data to disk
1083 | f = open(cache_filename,'wb');
1084 |         p.dump(d,f); # save the data to file
1085 | f.close();
1086 |
1087 | def load_cache_data(self,cache_filename):
1088 | r"""Load the needed matrices and related data from .pickle. (Warning: prototype codes here currently and not tested)."""
1089 | f = open(cache_filename,'rb');
1090 | d = p.load(f); # load the data from file
1091 | f.close();
1092 |
1093 |         if self.flag_verbose > 0: print(d.keys());
1094 | self.porder = d['porder'];
1095 | self.epsilon = d['epsilon'];
1096 |
1097 | self.weight_func = d['weight_func_str'];
1098 | self.weight_func_params = d['weight_func_params'];
1099 |
1100 |         self.pts_x1 = torch.from_numpy(d['pts_x1']).to(self.device);
1101 |         self.pts_x2 = torch.from_numpy(d['pts_x2']).to(self.device);
1102 |
1103 | self.pts_x1_numpy = d['pts_x1'];
1104 | self.pts_x2_numpy = d['pts_x2'];
1105 |
1106 |         if self.pts_x2 is None:
1107 |             self.pts_x2 = self.pts_x1;
1108 | 
1109 |         # build kd-tree of points for neighbor listing
1110 |         if self.pts_x1_numpy is None: self.pts_x1_numpy = self.pts_x1.cpu().numpy();
1111 | self.tree_points = spatial.cKDTree(self.pts_x1_numpy);
1112 |
1113 | self.cached_data = d['cached_data'];
1114 |
1115 | def eval_poly(self,pts_x,pts_x2_i0,c_star_i0,porder=None,flag_verbose=None):
1116 | r"""Evaluates the polynomial reconstruction around a given target point pts_x2_i0."""
1117 | if porder is None:
1118 | porder = self.porder;
1119 | if flag_verbose is None:
1120 | flag_verbose = self.flag_verbose;
1121 |         return MapToPoly_Function.eval_poly(pts_x,pts_x2_i0,c_star_i0,porder,flag_verbose);
1122 |
1123 | def forward(self, input): # define the action of this layer
1124 | r"""For a field u specified at points xj, performs the mapping to coefficients c at points xi, (uj,xj) :math:`\rightarrow` (ci,xi)."""
1125 | flag_time_it = False;
1126 | if flag_time_it:
1127 | time_1 = time.time();
1128 |
1129 | # We evaluate the action of the function, backward will be called automatically when computing gradients.
1130 | uj = input;
1131 | output = MapToPoly_Function.apply(uj,self.porder,
1132 | self.weight_func,self.weight_func_params,
1133 | self.pts_x1,self.epsilon,self.pts_x2,
1134 | self.cached_data,self.tree_points,self.device,
1135 | self.flag_verbose);
1136 | if flag_time_it:
1137 | msg = 'MapToPoly->forward(): ';
1138 | msg += 'elapsed_time = %.4e'%(time.time() - time_1);
1139 | print(msg);
1140 |
1141 | return output;
1142 |
1143 | def extra_repr(self):
1144 | r"""Displays information associated with this module."""
1145 | # Display some extra information about this layer.
1146 | return 'porder={}, weight_func={}, weight_func_params={}, pts_x1={}, pts_x2={}'.format(
1147 | self.porder, self.weight_func, self.weight_func_params, self.pts_x1.shape, self.pts_x2.shape
1148 | );
1149 |
1150 | def to(self, device):
1151 | r"""Moves data to GPU or other specified device."""
1152 | super(MapToPoly,self).to(device);
1153 | self.pts_x1 = self.pts_x1.to(device);
1154 | self.pts_x2 = self.pts_x2.to(device);
1155 | return self;
1156 |
1157 | class MaxPoolOverPoints(nn.Module):
1158 | r"""Applies a max-pooling operation to obtain values :math:`v_i = \max_{j \in \mathcal{N}_i(\epsilon)} \{u_j\}.` """
1159 | def __init__(self,pts_x1,epsilon=None,pts_x2=None,
1160 | indices_xj_i_cache=None,tree_points=None,
1161 | device=None,flag_verbose=0,**extra_params):
1162 | r"""Setup of max-pooling operation.
1163 |
1164 | Args:
1165 | pts_x1 (Tensor): The collection of domain points :math:`x_j`. We assume size [num_pts,num_dim].
1166 |             epsilon (float): The :math:`\epsilon`-neighborhood size to use when selecting nearby points.
1167 | pts_x2 (Tensor): The collection of target points :math:`x_i`.
1168 | indices_xj_i_cache (dict): Stored data to help speed up repeated calculations.
1169 | tree_points (dict): Stored data to help speed up repeated calculations.
1170 | device: Device on which to perform calculations (GPU or other, default is CPU).
1171 | flag_verbose (int): Level of reporting on progress during the calculations.
1172 | **extra_params (dict): Extra parameters allowing for specifying layer name and caching mode.
1173 |
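    |         Example (a minimal sketch; the points and epsilon are illustrative):
    | 
    |         >>> xj = torch.randn((100,2)); xi = torch.randn((50,2));
    |         >>> max_pool = MaxPoolOverPoints(pts_x1=xj,epsilon=1e-1,pts_x2=xi);
    | 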
1174 | """
1175 | super(MaxPoolOverPoints,self).__init__();
1176 |
1177 | self.flag_verbose = flag_verbose;
1178 |
1179 | if device is None:
1180 | device = torch.device('cpu');
1181 |
1182 | self.device = device;
1183 |
1184 | if 'name' in extra_params:
1185 | self.name = extra_params['name'];
1186 | else:
1187 | self.name = "default_name";
1188 |
1189 | if 'flag_cache_mode' in extra_params:
1190 | flag_cache_mode = extra_params['flag_cache_mode'];
1191 | else:
1192 | flag_cache_mode = 'generate1';
1193 |
1194 | if flag_cache_mode == 'generate1': # setup from scratch
1195 | self.pts_x1 = pts_x1;
1196 | self.pts_x2 = pts_x2;
1197 |
1198 | self.pts_x1_numpy = None;
1199 | self.pts_x2_numpy = None;
1200 |
1201 | if self.pts_x2 is None:
1202 | self.pts_x2 = pts_x1;
1203 |
1204 | self.epsilon = epsilon;
1205 |
1206 |             if tree_points is None: # build kd-tree of points for neighbor listing
1207 |                 if self.pts_x1_numpy is None: self.pts_x1_numpy = pts_x1.cpu().numpy();
1208 |                 self.tree_points = spatial.cKDTree(self.pts_x1_numpy);
1209 |             else: # use the kd-tree passed in by the caller
1210 |                 self.tree_points = tree_points;
1211 | 
1212 |             self.indices_xj_i_cache = indices_xj_i_cache; # cached neighbor lists around each xi (None triggers recomputation)
1219 |
1220 | self.cached_data = {}; # create empty cache for storing data
1221 |
1222 | elif flag_cache_mode == 'load_from_file': # setup by loading data from cache file
1223 |
1224 | if 'cache_filename' in extra_params:
1225 | cache_filename = extra_params['cache_filename'];
1226 | else:
1227 | raise Exception('No cache_filename specified.');
1228 |
1229 | self.load_cache_data(cache_filename); # load data from file
1230 |
1231 | else:
1232 | print("flag_cache_mode = " + str(flag_cache_mode));
1233 | raise Exception('flag_cache_mode is invalid.');
1234 |
1235 | def save_cache_data(self,cache_filename):
1236 | r"""Save data to .pickle file for caching. (Warning: Prototype placeholder code.)"""
1237 | # collect the data to save
1238 | d = {};
1239 | d['epsilon'] = self.epsilon;
1240 |         if self.pts_x1_numpy is None: self.pts_x1_numpy = self.pts_x1.cpu().numpy();
1241 |         d['pts_x1'] = self.pts_x1_numpy;
1242 |         if self.pts_x2_numpy is None: self.pts_x2_numpy = self.pts_x2.cpu().numpy();
1243 | d['pts_x2'] = self.pts_x2_numpy;
1244 | d['version'] = __version__; # Module version
1245 | d['cached_data'] = self.cached_data;
1246 |
1247 | # write the data to disk
1248 | f = open(cache_filename,'wb');
1249 |         p.dump(d,f); # save the data to file
1250 | f.close();
1251 |
1252 | def load_cache_data(self,cache_filename):
1253 | r"""Load data to .pickle file for caching. (Warning: Prototype placeholder code.)"""
1254 | f = open(cache_filename,'rb');
1255 | d = p.load(f); # load the data from file
1256 | f.close();
1257 |
1258 |         if self.flag_verbose > 0: print(d.keys());
1259 | self.epsilon = d['epsilon'];
1260 |
1261 |         self.pts_x1 = torch.from_numpy(d['pts_x1']).to(self.device);
1262 |         self.pts_x2 = torch.from_numpy(d['pts_x2']).to(self.device);
1263 |
1264 | self.pts_x1_numpy = d['pts_x1'];
1265 | self.pts_x2_numpy = d['pts_x2'];
1266 |
1267 |         if self.pts_x2 is None:
1268 |             self.pts_x2 = self.pts_x1;
1269 | 
1270 |         # build kd-tree of points for neighbor listing
1271 |         if self.pts_x1_numpy is None: self.pts_x1_numpy = self.pts_x1.cpu().numpy();
1272 | self.tree_points = spatial.cKDTree(self.pts_x1_numpy);
1273 |
1274 | self.cached_data = d['cached_data'];
1275 |
1276 | def forward(self, input): # define the action of this layer
1277 | r"""Applies a max-pooling operation to obtain values :math:`v_i = \max_{j \in \mathcal{N}_i(\epsilon)} \{u_j\}.`
1278 |
1279 | Args:
1280 | input (Tensor): The collection uj of field values at the points xj.
1281 |
1282 | Returns:
1283 | Tensor: The collection of field values vi at the target points xi.
1284 |
1285 | """
1286 | flag_time_it = False;
1287 | if flag_time_it:
1288 | time_1 = time.time();
1289 |
1290 | uj = input;
1291 | output = MaxPoolOverPoints_Function.apply(uj,self.pts_x1,self.epsilon,self.pts_x2,
1292 | self.indices_xj_i_cache,self.tree_points,
1293 | self.flag_verbose);
1294 |
1295 | if flag_time_it:
1296 | msg = 'MaxPoolOverPoints->forward(): ';
1297 | msg += 'elapsed_time = %.4e'%(time.time() - time_1);
1298 | print(msg);
1299 |
1300 | return output;
1301 |
1302 | def extra_repr(self):
1303 | r"""Displays information associated with this module."""
1304 | return 'pts_x1={}, pts_x2={}'.format(self.pts_x1.shape, self.pts_x2.shape);
1305 |
1306 | def to(self, device):
1307 | r"""Moves data to GPU or other specified device."""
1308 | super(MaxPoolOverPoints, self).to(device);
1309 | self.pts_x1 = self.pts_x1.to(device);
1310 | self.pts_x2 = self.pts_x2.to(device);
1311 | return self;
1312 |
1313 | class GMLS_Layer(nn.Module):
1314 | r"""The GMLS-Layer processes scattered data by using Generalized Moving Least
1315 | Squares (GMLS) to construct a local reconstruction of the data (here polynomials).
1316 | This is represented by coefficients that are mapped to approximate the action of
1317 | linear or non-linear operators on the input field.
1318 |
1319 | As depicted above, the architecture processes a collection of input channels
1320 | into intermediate coefficient channels. The coefficient channels are
1321 | then collectively mapped to output channels. The mappings can be any unit
1322 | for which back-propagation can be performed. This includes linear
1323 | layers or non-linear maps based on multilayer perceptrons (MLPs).
1324 |
1325 | Examples:
1326 |
1327 | Here is a typical way to construct a GMLS-Layer. This is done in
1328 | the following stages.
1329 |
1330 | ``(i)`` Construct the scattered data locations xj, xi at which processing will occur. Here, we create points in 2D.
1331 |
1332 | >>> xj = torch.randn((100,2),device=device); xi = torch.randn((100,2),device=device);
1333 |
1334 | ``(ii)`` Construct the mapping unit that will be applied pointwise. Here we create an MLP
1335 | with Nc input coefficient channels and channels_out output channels.
1336 |
1337 | >>> layer_sizes = [];
1338 |     >>> num_input = Nc*num_polys; # number of channels (Nc) X number of polynomials (num_polys) (cross-channel coupling allowed)
1339 | >>> num_depth = 4; num_hidden = 100; channels_out = 16; # depth, width, number of output filters
1340 | >>> layer_sizes.append(num_polys);
1341 | >>> for k in range(num_depth):
1342 | >>> layer_sizes.append(num_hidden);
1343 |     >>> layer_sizes.append(1); # a single unit always gives a scalar output; we then use channels_out units.
1344 | >>> mlp_q_map1 = gmlsnets_pytorch.nn.MLP_Pointwise(layer_sizes,channels_out=channels_out);
1345 |
1346 | ``(iii)`` Create the GMLS-Layer using these components.
1347 |
1348 | >>> weight_func1 = gmlsnets_pytorch.nn.MapToPoly_Function.weight_one_minus_r;
1349 |     >>> weight_func1_params = {'epsilon':1e-3,'p':4};
1350 |     >>> gmls_layer_params = {
1351 |           'flag_case':'standard','porder':4,
1352 |           'mlp_q':mlp_q_map1,
1353 |           'pts_x1':xj,'pts_x2':xi,'epsilon':1e-3,
1354 |           'weight_func':weight_func1,'weight_func_params':weight_func1_params,
1355 |           'device':device,'flag_verbose':0
1356 |           };
1357 | >>> gmls_layer=gmlsnets_pytorch.nn.GMLS_Layer(**gmls_layer_params);
1358 |
1359 | Here is an example of how a GMLS-Layer and other modules in this package
1360 | can be used to process scattered data. This could be part of a larger
1361 | neural network in practice (see example codes for more information). For instance,
1362 |
1363 |     >>> layer1 = nn.Sequential(gmls_layer, # produces an output tuple of tensors (vi,xi).
1364 |                            #PdbSetTraceLayer(),
1365 |                            ExtractFromTuple(index=0), # from the output keep only the vi part and discard the xi part.
1366 | #PdbSetTraceLayer(),
1367 | PermuteLayer((0,2,1)) # organize indexing to be [batch,xi,ci], for further processing.
1368 | ).to(device);
1369 |
1370 | You can uncomment the PdbSetTraceLayer() to get breakpoints for state information and tensor shapes during processing.
1371 | The PermuteLayer() changes the order of the indexing. Also can use ReshapeLayer() to reshape the tensors, which is
1372 | especially useful for processing data related to CNNs.
1373 |
1374 | Much of the construction can be further simplified by writing a few wrapper classes for your most common use cases.
1375 |
1376 | More information also can be found in the example codes directory.
1377 | """
1378 | def __init__(self, flag_case, porder, pts_x1, epsilon, weight_func, weight_func_params,
1379 | mlp_q = None,pts_x2 = None, device = None, flag_verbose = 0):
1380 | r"""
1381 | Initializes the GMLS layer.
1382 |
1383 | Args:
1384 | flag_case (str): Flag for the type of architecture to use (default is 'standard').
1385 | porder (int): Order of the basis to use (polynomial degree).
1386 | pts_x1 (Tensor): The collection of domain points :math:`x_j`.
1387 |             epsilon (float): The :math:`\epsilon`-neighborhood size to use when selecting nearby points (should be compatible with the choice of weight_func_params).
1388 | weight_func (func): Weight function to use.
1389 | weight_func_params (dict): Weight function parameters.
1390 | mlp_q (module): Mapping q unit for computing :math:`q(c)`, where c are the coefficients.
1391 | pts_x2 (Tensor): The collection of target points :math:`x_i`.
1392 | device: Device on which to perform calculations (GPU or other, default is CPU).
1393 | flag_verbose (int): Level of reporting on progress during the calculations.
1394 | """
1395 | super(GMLS_Layer, self).__init__();
1396 |
1397 | if flag_case is None:
1398 | self.flag_case = 'standard';
1399 | else:
1400 | self.flag_case = flag_case;
1401 |
1402 | if device is None:
1403 | device = torch.device('cpu');
1404 |
1405 | self.device = device;
1406 |
1407 | if self.flag_case == 'standard':
1408 | tree_points = None;
1409 | self.MapToPoly_1 = MapToPoly(porder, weight_func, weight_func_params,
1410 | pts_x1, epsilon, pts_x2, tree_points,
1411 | device, flag_verbose);
1412 |
1413 |             if mlp_q is None: # the mapping unit must be specified by the caller
1414 | raise Exception("Need to specify the mlp_q module for mapping coefficients to output.");
1415 | else: # in this case initialized outside
1416 | self.mlp_q = mlp_q;
1417 |
1418 | else:
1419 | print("flag_case = " + str(flag_case));
1420 | print("self.flag_case = " + str(self.flag_case));
1421 | raise Exception('flag_case not valid.');
1422 |
1423 | def forward(self, input):
1424 | r"""Computes GMLS-Layer processing scattered data input field uj to obtain output field vi.
1425 |
1426 | Args:
1427 | input (Tensor): Input channels uj organized in the shape [batch,xj,uj].
1428 |
1429 | Returns:
1430 | tuple: The output channels and point locations (vi,xi). The field vi = q(ci).
1431 |
1432 | """
1433 | if self.flag_case == 'standard':
1434 |
1435 | map_output = self.MapToPoly_1.forward(input);
1436 | c_star_i = map_output[0];
1437 | pts_x2 = map_output[1];
1438 |
1439 |             # The MLP applies across all channels and coefficients (the coefficients capture spatial structure, like a convolution kernel).
1440 | fc_input = c_star_i.permute((0,2,1,3)); # we organize as [batchI,ptsI,channelsI,coeffI]
1441 |
1442 |             # We assume the MLP processes the flattened channel-coefficient inputs of size Nc*num_coeffs.
1443 |             # We assume the output has shape [batchI,ptsI,channelsNew].
1444 |             # Outside routines can reshape that into an nD array again for structured samples or use it over scattered samples.
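    |             # For example (illustrative shapes): with Nc=3 coefficient channels and num_polys=6
    |             # coefficients per channel, fc_input has shape [batch,num_pts,3,6] and
    |             # q_of_c_star_i below has shape [batch,num_pts,channels_out].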
1445 | q_of_c_star_i = self.mlp_q.forward(fc_input);
1446 |
1447 | pts_x2_p = None; # currently returns None to simplify back-prop and debugging, but could just return the pts_x2.
1448 |
1449 | return_vals = q_of_c_star_i, pts_x2_p;
1450 |
1451 | return return_vals;
1452 |
1453 | def to(self, device):
1454 | r"""Moves data to GPU or other specified device."""
1455 | super(GMLS_Layer, self).to(device);
1456 | self.MapToPoly_1 = self.MapToPoly_1.to(device);
1457 | self.mlp_q = self.mlp_q.to(device);
1458 | return self;
1459 |
1460 |
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Collection of utility routines.
3 | """
4 |
5 | # Authors: B.J. Gross and P.J. Atzberger
6 | # Website: http://atzberger.org/
7 |
8 | import torch;
9 | import torch.nn as nn;
10 | import numpy as np;
11 |
12 | import pdb
13 | import time
14 |
15 | #******************************************
16 | # Custom Functions
17 | #******************************************
18 | class PeriodicPad2dFunc(torch.autograd.Function):
19 | """Performs periodic padding/tiling of a 2d lattice."""
20 |
21 | @staticmethod
22 | def forward(ctx, input, pad=None,flag_coords=False):
23 | """
24 | Periodically tiles the input into a 2d array for 3x3 repeat pattern.
25 | This allows for a quick way for handling periodic boundary conditions
26 |         on the unit cell.
27 |
28 | Args:
29 | input (Tensor): A tensor of size [nbatch,nchannel,nx,ny] or [nx,ny].
30 | pad (float): The pad value to use.
31 |             flag_coords (bool): Whether the beginning components are coordinates (x,y,z,...,u).
32 |
33 | We then adjust (x,y,u) --> (x + i, y + j, u) for image (i,j).
34 |
35 | Returns:
36 |             output (Tensor): The tiled tensor of size [nbatch,nchannel,3*nx,3*ny] or [3*nx,3*ny].
37 |
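    |         Example (illustrative):
    | 
    |         >>> x = torch.arange(4.).reshape(2,2);
    |         >>> PeriodicPad2dFunc.apply(x).shape
    |         torch.Size([6, 6])
    | 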
38 | """
39 | # We use alternative by concatenating the arrays together to tile.
40 | # Process Tensor inputs of the shape with [nbatch,nchannel,nx,ny].
41 | # We also allow the case of [nx,ny] for single input.
42 | nd = input.dim();
43 | if (nd > 4):
44 | raise Exception('Expects tensor that has number of dimensions <= 4 (dim = {}).'.format(nd));
45 |
46 | if (nd < 2):
47 | raise Exception('Expects tensor that has at least number of dimensions >= 2 (dim = {}).'.format(nd));
48 |
49 | a = input;
50 | if (nd == 2):
51 | # add extra dimensions so we can process all tensors the same way
52 | a = input.unsqueeze(0).unsqueeze(0); # size --> [1,1,nx,ny]
53 |
54 | w1 = w2 = a;
55 | aa = torch.cat([w1,a,w2],dim=2);
56 | h1 = h2 = aa;
57 | output = torch.cat([h1,aa,h2],dim=3);
58 |
59 | if flag_coords: # indicates (x,y,u) input, so need
60 | # to adjust x + i-1 and y + j-1, i,j = 0,1,2 to extend.
61 | coordI1 = 0; coordI2 = 1; #x and y components
62 | N1 = output.shape[2]//3; # block size
63 | N2 = output.shape[3]//3;
64 | for j in range(0,3):
65 | blockI2 = np.arange(N2*j,N2*(j + 1),dtype=int);
66 | output[:,coordI1,:,blockI2] += (j - 1);
67 |
68 | for i in range(0,3):
69 | blockI1 = np.arange(N1*i,N1*(i + 1),dtype=int);
70 | output[:,coordI2,blockI1,:] += (i - 1);
71 |
72 | if (nd == 2): # if only [nx,ny] then squeeze our extra dimensions
73 | output = output.squeeze(0).squeeze(0);
74 |
75 | ctx.pad = pad;
76 | ctx.size = input.size();
77 | ctx.numel = input.numel();
78 | ctx.num_tiles = 3;
79 |
80 | return output;
81 |
82 | @staticmethod
83 | def backward(ctx, grad_output):
84 | r"""Compute df/dx from df/dy using the Chain Rule df/dx = df/dx*dy/dx.
85 | For periodic padding we use df/dx = sum_{y_i ~ x} df/dy_i, where
86 | the y_i ~ x are all points y_i that are equivalent to x under the periodic extension.
87 | """
88 |         num_tiles = ctx.num_tiles;
89 | 
90 |         nd = len(ctx.size); # dimensionality of the original input saved in forward()
91 |         if (nd > 4):
92 |             raise Exception('Expects tensor that has number of dimensions <= 4 (dim = {}).'.format(nd));
93 | 
94 |         if (nd < 2):
95 |             raise Exception('Expects tensor that has at least number of dimensions >= 2 (dim = {}).'.format(nd));
96 | 
97 |         b = grad_output; # short-hand for grad_output (size of the tiled output)
98 |         if (nd == 2):
99 |             # add extra dimensions so we can process all tensors the same way
100 |             b = b.unsqueeze(0).unsqueeze(0); # size --> [1,1,3*nx,3*ny]
101 | 
102 |         # Each input entry appears in num_tiles x num_tiles locations of the output,
103 |         # so we contract the periodic images by summing the gradient over all blocks.
104 |         nx = b.shape[2]//num_tiles; ny = b.shape[3]//num_tiles;
105 |         grad_input = b.new_zeros(b.shape[0],b.shape[1],nx,ny);
106 |         for i in range(num_tiles):
107 |             for j in range(num_tiles):
108 |                 grad_input += b[:,:,i*nx:(i + 1)*nx,j*ny:(j + 1)*ny];
109 | 
110 |         if (nd == 2): # if only [nx,ny] then squeeze out extra dimensions
111 |             grad_input = grad_input.squeeze(0).squeeze(0);
112 | 
113 |         return grad_input, None, None; # no gradients for the pad and flag_coords arguments
125 |
126 | #------------------
127 | class ExtractUnitCell2dFunc(torch.autograd.Function):
128 | r"""Extracts the 2d unit cell from periodic tiling."""
129 |
130 | @staticmethod
131 | def forward(ctx, input, pad=None):
132 | r"""Extracts the 2d unit cell from periodic tiling.
133 |
134 | Args:
135 | input (Tensor): Tensor of size [nbatch,nchannel,nx,ny] or [nx,ny].
136 |
137 | Returns:
138 | Tensor: The Tensor of size [nbatch,nchannel,nx/3,ny/3] or [nx/3,ny/3].
139 |
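    |         Example (illustrative; the input is a 3x3 tiling of a 2x2 unit cell):
    | 
    |         >>> y = torch.randn(6,6);
    |         >>> ExtractUnitCell2dFunc.apply(y).shape
    |         torch.Size([2, 2])
    | 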
140 | """
141 | num_tiles = 3;
142 | ctx.num_tiles = num_tiles;
143 |
144 | nd = input.dim();
145 | if (nd > 4):
146 | raise Exception('Expects tensor that has number of dimensions <= 4 (dim = {}).'.format(nd));
147 |
148 | if (nd < 2):
149 | raise Exception('Expects tensor that has at least number of dimensions >= 2 (dim = {}).'.format(nd));
150 |
151 | a = input; # short-hand
152 | if (nd == 2):
153 | # add extra dimensions so we can process all tensors the same way
154 | a = a.unsqueeze(0).unsqueeze(0); # size --> [1,1,nx,ny]
155 |
156 | chunks_col = torch.chunk(a,num_tiles,dim=2);
157 | aa = torch.chunk(chunks_col[1],num_tiles,dim=3); # extract middle of middle
158 |
159 | output = aa[1]; # choose middle one (assumes num_tiles==3)
160 |
161 | if (nd == 2): # if only [nx,ny] then squeeze out extra dimensions
162 | output = output.squeeze(0).squeeze(0);
163 |
164 | return output;
165 |
166 | @staticmethod
167 | def backward(ctx, grad_output):
168 | r""" Compute df/dx from df/dy using the Chain Rule df/dx = df/dx*dy/dx.
169 | For periodic padding we use df/dx = sum_{y_i ~ x} df/dy_i, where
170 | the y_i ~ x are all points y_i that are equivalent to x under the periodic extension.
171 |
172 |         For extracting the unit cell from the periodic tiling, the derivatives dy/dx are zero
173 |         unless x is within the unit cell. The gradient has the block structure
174 |         df/dx = [[Z,Z,Z],[Z,W,Z],[Z,Z,Z]], where W = df/dy is the gradient on the unit cell and Z is the zero matrix.
175 | """
176 | num_tiles = ctx.num_tiles;
177 |
178 | nd = grad_output.dim();
179 | if (nd > 4):
180 | raise Exception('Expects tensor that has number of dimensions <= 4 (dim = {}).'.format(nd));
181 |
182 | if (nd < 2):
183 | raise Exception('Expects tensor that has at least number of dimensions >= 2 (dim = {}).'.format(nd));
184 |
185 | W = grad_output; # short-hand for grad_output (size of output)
186 | if (nd == 2):
187 | # add extra dimensions so we can process all tensors the same way
188 | W = W.unsqueeze(0).unsqueeze(0); # size --> [1,1,nx,ny]
189 |
190 | s = W.size();
191 | Z = grad_output.new_zeros(s[0],s[1],s[2],s[3]);
192 | A1 = torch.cat([Z,Z,Z],dim=3);
193 | B1 = torch.cat([Z,W,Z],dim=3);
194 | bb = torch.cat([A1,B1,A1],dim=2);
195 |
196 | grad_input = bb;
197 |
198 | if (nd == 2): # if only [nx,ny] then squeeze out extra dimensions
199 | grad_input = grad_input.squeeze(0).squeeze(0);
200 |
201 |         return grad_input, None; # the extra None covers the optional pad argument
202 |
203 | #******************************************
204 | # Custom Modules
205 | #******************************************
206 | class PeriodicPad2d(nn.Module):
207 | def __init__(self,flag_coords=False):
208 | r"""Setup for computing the periodic tiling."""
209 | super(PeriodicPad2d, self).__init__()
210 | self.flag_coords = flag_coords;
211 |
212 | def forward(self, input):
213 | r"""Compute the periodic padding of the input. """
214 | return PeriodicPad2dFunc.apply(input,None,self.flag_coords);
215 |
216 | def extra_repr(self):
217 | r"""Displays some of the information associated with the module. """
218 | return 'PeriodicPad2d: (no internal parameters)';
219 |
220 | class ExtractUnitCell2d(nn.Module):
221 | def __init__(self):
222 | super(ExtractUnitCell2d, self).__init__()
223 |
224 | def forward(self, input):
225 | r"""Computes the periodic padding of the input."""
226 | return ExtractUnitCell2dFunc.apply(input);
227 |
228 | def extra_repr(self):
229 | r"""Displays some of the information associated with the module. """
230 | return 'ExtractUnitCell2d: (no internal parameters)';
231 |
232 |
--------------------------------------------------------------------------------
/src/vis.py:
--------------------------------------------------------------------------------
1 | """
2 | Collection of routines helpful for visualizing results and generating figures.
3 | """
4 |
5 | # Authors: B.J. Gross and P.J. Atzberger
6 | # Website: http://atzberger.org/
7 |
8 | import matplotlib;
9 | import matplotlib as mtpl;
10 | import matplotlib.pyplot as plt;
11 | import matplotlib.gridspec as gridspec;
12 |
13 | import numpy as np;
14 |
15 | #----------------------------
16 | def plot_samples_u_f_1d(u_list,f_list,np_xj,np_xi,rows=4,cols=6,
17 | figsize = (20*0.9,10*0.9),title="Data Samples: u, f=L[u]",
18 | xlabel='x',ylabel='',
19 | left=0.125,bottom=0.1,right=0.9, top=0.94,wspace=0.4,hspace=0.4,
20 | fontsize=16,y=1.00,flag_draw=True,**extra_params):
21 |
22 | r"""Plots a collection of data samples in a panel."""
23 | # --
24 | fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=figsize,sharey=False);
25 | fig.subplots_adjust(left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace);
26 | plt.suptitle(title, fontsize=fontsize,y=y);
27 |
28 | # --
29 | I1 = 0; I2 = 0;
30 | for i1 in range(0,rows):
31 | for i2 in range(0,cols):
32 | ax = axs[i1,i2];
33 |
34 | if i2 % 2 == 0:
35 | xx = np_xj[:,0]; yy = u_list[I1].numpy()[0,:];
36 | #yy = yy.squeeze(0);
37 | ax.plot(xx,yy,'m.-');
38 | if i1 == 0:
39 | ax.set_title('u');
40 | I1 += 1;
41 | else:
42 | xx = np_xi[:,0]; yy = f_list[I2].numpy()[0,:];
43 | #yy = yy.squeeze(0);
44 | ax.plot(xx,yy,'r.-');
45 | if i1 == 0:
46 | ax.set_title('f');
47 | I2 += 1;
48 |
49 | ax.set_xlabel(xlabel);
50 | ax.set_ylabel(ylabel);
51 |
52 | if flag_draw:
53 | plt.draw();
54 |
55 | def plot_samples_u_f_fp_1d(u_list,f_target_list,f_pred_list,np_xj,np_xi,rows=4,cols=6,
56 | figsize = (20*0.9,10*0.9),title="Data Samples: u, f=L[u]",
57 | xlabel='x',ylabel='',
58 | left=0.125,bottom=0.1,right=0.9, top=0.94,wspace=0.4,hspace=0.4,
59 | fontsize=16,y=1.00,flag_draw=True,**extra_params):
60 |
61 | r"""Plots a collection of data samples and predictions in a panel."""
62 | # --
63 | fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=figsize,sharey=False);
64 | fig.subplots_adjust(left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace);
65 | plt.suptitle(title, fontsize=fontsize,y=y);
66 |
67 | # --
68 | I1 = 0; I2 = 0;
69 | for i1 in range(0,rows):
70 | for i2 in range(0,cols):
71 | ax = axs[i1,i2];
72 |
73 | if i2 % 2 == 0:
74 | ax.plot(np_xj[:,0],u_list[I1].numpy()[0,:],'m.-');
75 | if i1 == 0:
76 | ax.set_title('u');
77 | I1 += 1;
78 | ax.set_xlabel(xlabel);
79 | ax.set_ylabel(ylabel);
80 | else:
81 |                 ax.plot(np_xi[:,0],f_target_list[I2].numpy()[0,:],'r.-');
82 |                 ax.plot(np_xi[:,0],f_pred_list[I2].numpy()[0,:],'b.-');
83 | if i1 == 0:
84 | ax.set_title('f');
85 | I2 += 1;
86 | ax.set_xlabel(xlabel);
87 | ax.set_ylabel(ylabel);
88 |
89 | if flag_draw:
90 | plt.draw();
91 |
92 | def plot_dataset_diffOp1(dataset,np_xj=None,np_xi=None,rows=4,cols=6,II=None,
93 | figsize=(20*0.9,10*0.9),title="Data Samples: u, f=L[u]",
94 | xlabel='x',ylabel='',
95 | left=0.125,bottom=0.1,right=0.9, top=0.94,wspace=0.4,hspace=0.4,
96 | fontsize=16,y=1.00,flag_draw=True,**extra_params):
97 |
98 | r"""Plots a collection of data samples in a panel."""
99 |     u_list = []; f_list = [];
100 | num_samples = len(dataset);
101 | for I in np.arange(0,min(num_samples,int(rows*cols/2))):
102 | if II is None:
103 | u_list.append(dataset[I][0].cpu()); # just make plain lists for convenience
104 | f_list.append(dataset[I][1].cpu());
105 | else:
106 | u_list.append(dataset[II[I]][0].cpu());
107 | f_list.append(dataset[II[I]][1].cpu());
108 |
109 | # --
110 | plot_samples_u_f_1d(u_list=u_list,f_list=f_list,np_xj=np_xj,np_xi=np_xi,
111 | rows=rows,cols=cols,
112 | figsize=figsize,title=title,
113 | xlabel=xlabel,ylabel=ylabel,
114 | left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace,
115 | fontsize=fontsize,y=y,flag_draw=flag_draw,**extra_params);
116 |
117 | #----------------------------
118 | def plot_samples_u_f_2d(u_list,f_list,np_xj,np_xi,channelI_u=0,channelI_f=0,rows=4,cols=6,
119 | figsize = (20*0.9,10*0.9),title="Data Samples: u, f=L[u]",
120 | xlabel='x',ylabel='',
121 | left=0.125, bottom=0.1, right=0.9, top=0.94, wspace=0.01, hspace=0.1,
122 | fontsize=16,y=1.00,flag_draw=True,**extra_params):
123 |
124 | r"""Plots a collection of data samples in a panel."""
125 | # --
126 | fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=figsize,sharey=False);
127 | fig.subplots_adjust(left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace);
128 | plt.suptitle(title, fontsize=fontsize,y=y);
129 |
130 | # --
131 | for ax in axs.flatten():
132 | ax.axis('off');
133 |
134 | # --
135 | I1 = 0; I2 = 0; Ic_u = channelI_u;Ic_f = channelI_f;
136 | for i1 in range(0,rows):
137 | for i2 in range(0,cols):
138 | ax = axs[i1,i2];
139 |
140 | if i2 % 2 == 0:
141 | uu = u_list[I1][Ic_u,:,:];
142 | ax.imshow(uu,cmap='Blues_r');
143 | if i1 == 0:
144 | ax.set_title('u');
145 | I1 += 1;
146 | else:
147 | ff = f_list[I2][Ic_f,:,:];
148 | ax.imshow(ff,cmap='Purples_r');
149 | if i1 == 0:
150 | ax.set_title('f');
151 | I2 += 1;
152 |
153 | if flag_draw:
154 | plt.draw();
155 |
156 | def plot_samples_u_f_fp_2d(u_list,f_target_list,f_pred_list,np_xj,np_xi,channelI_u=0,channelI_f=0,rows=4,cols=6,
157 | figsize = (20*0.9,10*0.9),title="Data Samples: u, f=L[u]",
158 | xlabel='x',ylabel='',
159 | left=0.125, bottom=0.1, right=0.9, top=0.94, wspace=0.01, hspace=0.1,
160 | fontsize=16,y=1.00,flag_draw=True,**extra_params):
161 |
162 | r"""Plots a collection of data samples and predictions in a panel."""
163 | # --
164 | fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=figsize,sharey=False);
165 | fig.subplots_adjust(left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace);
166 | plt.suptitle(title, fontsize=fontsize,y=y);
167 |
168 | # --
169 | for ax in axs.flatten():
170 | ax.axis('off');
171 |
172 | # --
173 | Ic_u = channelI_u; Ic_f = channelI_f;
174 | I1 = 0; I2 = 0; I3 = 0;
175 | for i1 in range(0,rows):
176 | for i2 in range(0,cols):
177 | ax = axs[i1,i2];
178 |
179 | if i2 % 3 == 0:
180 | uu = u_list[I1][Ic_u,:,:];
181 | ax.imshow(uu,cmap='Blues_r');
182 | if i1 == 0:
183 | ax.set_title('u');
184 | I1 += 1;
185 | elif i2 % 3 == 1:
186 | ff = f_pred_list[I2][Ic_f,:,:];
187 | ax.imshow(ff,cmap='Purples_r');
188 | if i1 == 0:
189 | ax.set_title('f:predicted');
190 | I2 += 1;
191 | elif i2 % 3 == 2:
192 | ff = f_target_list[I3][Ic_f,:,:];
193 | ax.imshow(ff,cmap='Purples_r');
194 | if i1 == 0:
195 | ax.set_title('f:target');
196 | I3 += 1;
197 |
198 | if flag_draw:
199 | plt.draw();
200 |
201 | def plot_dataset_diffOp2(dataset,np_xj=None,np_xi=None,channelI_u=0,channelI_f=0,rows=4,cols=6,II=None,
202 | figsize=(20*0.9,10*0.9),title="Data Samples: u, f=L[u]",
203 | xlabel='x',ylabel='',
204 | left=0.125, bottom=0.1, right=0.9, top=0.94, wspace=0.01, hspace=0.1,
205 | fontsize=16,y=1.00,flag_draw=True,**extra_params):
206 |
209 | r"""Plots a collection of data samples in a panel."""
210 |     u_list = []; f_list = [];
211 | num_samples = len(dataset);
212 | for I in np.arange(0,min(num_samples,int(rows*cols/2))):
213 | if II is None:
214 | u_list.append(dataset[I][0].cpu()); # just make plain lists for convenience
215 | f_list.append(dataset[I][1].cpu());
216 | else:
217 | u_list.append(dataset[II[I]][0].cpu());
218 | f_list.append(dataset[II[I]][1].cpu());
219 |
220 | # --
221 | plot_samples_u_f_2d(u_list=u_list,f_list=f_list,np_xj=np_xj,np_xi=np_xi,
222 | channelI_u=channelI_u,channelI_f=channelI_f,
223 | rows=rows,cols=cols,
224 | figsize=figsize,title=title,
225 | xlabel=xlabel,ylabel=ylabel,
226 | left=left,bottom=bottom,right=right,top=top,wspace=wspace,hspace=hspace,
227 | fontsize=fontsize,y=y,flag_draw=flag_draw,**extra_params);
228 |
229 | #----------------------------
230 | def plot_images_in_array(axs,img_arr,label_arr=None,cmap=None, **extra_params):
231 | r"""Plots an array of images as a collection of panels."""
232 |
233 | numSamples = len(img_arr);
234 | sqrtS = int(np.sqrt(numSamples));
235 |
236 | # Default values
237 | flag_plot_rect = False;
238 | list_correct_class = None;
239 |
240 | if 'list_correct_class' in extra_params:
241 | list_correct_class = extra_params['list_correct_class'];
242 | flag_plot_rect = True;
243 |
244 | if 'flag_plot_rect' in extra_params:
245 | flag_plot_rect = extra_params['flag_plot_rect'];
246 |
247 | I = 0;
248 | for i in range(0,sqrtS):
249 | for j in range(0,sqrtS):
250 | ax = axs[i][j];
251 | img = img_arr[I];
252 |
253 | if len(img.shape) >= 3:
254 | if img.shape[2] == 1: # For BW case of (Nx,Ny,1) --> (Nx,Ny), RGB has (Nx,Ny,3).
255 | img = img.squeeze(2);
256 |
257 | if cmap is not None:
258 | ax.imshow(img, cmap=cmap);
259 | else:
260 | ax.imshow(img);
261 |
262 | if label_arr is not None:
263 | ax.set_title("%s"%label_arr[I]);
264 |
265 | ax.set_xticks([]);
266 | ax.set_yticks([]);
267 |
268 | if flag_plot_rect:
269 |
270 | if list_correct_class[I]:
271 | edge_color = 'g';
272 | else:
273 | edge_color = 'r';
274 |
275 | # draw a rectangle
276 | Nx = img.shape[0]; Ny = img.shape[1];
277 | rectangle = mtpl.patches.Rectangle((0,0),Nx-1,Ny-1,
278 | linewidth=5,edgecolor=edge_color,
279 | facecolor='none');
280 | ax.add_patch(rectangle);
281 |
282 | I += 1;
283 |
284 | def plot_image_array(img_arr,label_arr=None,title=None,figSize=(18,18),
285 | title_yp=0.95,cmap="gray",**extra_params):
286 | r"""Plots an array of images as a collection of panels."""
287 |
288 | # determine number of images we need to plot
289 | numSamples = len(img_arr);
290 | sqrtS = int(np.sqrt(numSamples));
291 | rows = sqrtS;
292 | cols = sqrtS;
293 |
294 | fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=figSize);
295 |
296 | plot_images_in_array(axs,img_arr,label_arr,cmap=cmap,**extra_params);
297 |
298 | if title is None:
299 | plt.suptitle("Collection of Images", fontsize=18,y=title_yp);
300 | else:
301 | plt.suptitle(title, fontsize=18,y=title_yp);
302 |
303 | def plot_gridspec_image_array(outer_h,img_arr,label_arr=None,title=None,figSize=(18,18),title_yp=0.95,cmap="gray",title_x=0.0,title_y=1.0,title_fsize=14):
304 | r"""Plots an array of images as a collection of panels."""
305 | fig = plt.gcf();
306 |
307 | sqrtS = int(np.sqrt(len(img_arr)));
308 |
309 | inner = gridspec.GridSpecFromSubplotSpec(sqrtS, sqrtS,
310 | subplot_spec=outer_h, wspace=0.1, hspace=0.1);
311 |
312 | # collect the axes
313 | axs =[];
314 | I = 0;
315 | for i in range(0,sqrtS):
316 | axs_r = [];
317 | for j in range(0,sqrtS):
318 | ax = plt.Subplot(fig, inner[I]);
319 | fig.add_subplot(ax);
320 | axs_r.append(ax);
321 | I += 1;
322 |
323 | axs.append(axs_r);
324 |
325 | # plot the images
326 | plot_images_in_array(axs,img_arr,cmap="gray");
327 |
328 |     if title is not None:
329 |         axs[0][0].text(title_x,title_y,title,fontsize=title_fsize);
332 |
333 | #----------------------------
334 | def save_fig(baseFilename,extraLabel,flag_verbose=0,dpi_set=200,flag_pdf=False):
335 | r"""Saves figures to disk."""
336 |
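    |     # Example (illustrative): save_fig('fig_pred','_final',flag_verbose=1)
    |     # writes 'fig_pred_final.png' (and also a .pdf when flag_pdf=True).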
337 | fig = plt.gcf();
338 | fig.patch.set_alpha(1.0);
339 | fig.patch.set_facecolor((1.0,1.0,1.0,1.0));
340 |
341 | if flag_pdf:
342 | saveFilename = '%s%s.pdf'%(baseFilename,extraLabel);
343 | if flag_verbose > 0:
344 | print('saveFilename = %s'%saveFilename);
345 |         plt.savefig(saveFilename, format='pdf',dpi=dpi_set,facecolor=(1,1,1,1));
346 |
347 | saveFilename = '%s%s.png'%(baseFilename,extraLabel);
348 | if flag_verbose > 0:
349 | print('saveFilename = %s'%saveFilename);
350 |     plt.savefig(saveFilename, format='png',dpi=dpi_set,facecolor=(1,1,1,1));
351 |
352 |
353 |
--------------------------------------------------------------------------------