├── figures
├── scheme.png
├── te_result.png
└── tm_result.png
├── examples
├── spheric_te
│ ├── sample.npz
│ ├── train.npz
│ └── specs_maxwell.json
└── spheric_tm
│ ├── sample.npz
│ ├── train.npz
│ └── specs_maxwell.json
├── environments.yml
├── Dataset.py
├── UNet.py
├── README.md
├── solution_maxwellnet.py
├── train_maxwellnet.py
└── MaxwellNet.py
/figures/scheme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/figures/scheme.png
--------------------------------------------------------------------------------
/figures/te_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/figures/te_result.png
--------------------------------------------------------------------------------
/figures/tm_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/figures/tm_result.png
--------------------------------------------------------------------------------
/examples/spheric_te/sample.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/examples/spheric_te/sample.npz
--------------------------------------------------------------------------------
/examples/spheric_te/train.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/examples/spheric_te/train.npz
--------------------------------------------------------------------------------
/examples/spheric_tm/sample.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/examples/spheric_tm/sample.npz
--------------------------------------------------------------------------------
/examples/spheric_tm/train.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/examples/spheric_tm/train.npz
--------------------------------------------------------------------------------
/environments.yml:
--------------------------------------------------------------------------------
1 | name: maxwellnet
2 | channels:
3 | - defaults
4 | - pytorch
5 | - nvidia
6 | dependencies:
7 | - python=3.7.0
8 | - nvidia::cudatoolkit=11.0
9 | - pytorch::pytorch=1.7.1
10 | - torchvision=0.8.2
11 | - scikit-image=0.19.2
12 | - jupyter=1.0.0
13 | - matplotlib=3.5.1
14 | - tensorboard=2.0.0
15 | - pip:
16 | - plyfile==0.7.4
--------------------------------------------------------------------------------
/Dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Joowon Lim, limjoowon@gmail.com
2 |
3 | import torch
4 | import numpy as np
5 | import os
6 |
7 |
class LensDataset(torch.utils.data.Dataset):
    """Dataset backed by the archive ``<data_path>/<mode>.npz``.

    The archive must contain two arrays:

    * ``'sample'`` -- one item per index along its first axis; items are
      read with four-axis indexing, so the array is expected to be 4-D.
    * ``'n'`` -- returned unchanged alongside every item.

    ``__getitem__(idx)`` returns the tuple ``(sample, n, idx)``.
    """

    def __init__(self, data_path, mode):
        path = os.path.join(data_path, mode + ".npz")
        # Open the archive once and close it when done: np.load on an .npz
        # returns a lazily-read NpzFile that otherwise keeps the file open.
        with np.load(path) as archive:
            self.dataset = archive['sample']
            self.n = archive['n']

    def __len__(self):
        return self.dataset.shape[0]

    def __getitem__(self, idx):
        sample = self.dataset[idx, :, :, :]
        return sample, self.n, idx
20 |
--------------------------------------------------------------------------------
/examples/spheric_te/specs_maxwell.json:
--------------------------------------------------------------------------------
1 | {
2 | "Description": [
3 | "This experiment learns to solve the Maxwell's equations",
4 | "for a spheric lens data."
5 | ],
6 | "NetworkArch": "maxwellnet",
7 | "NetworkSpecs": {
8 | "depth": 6,
9 | "filter": 16,
10 | "norm": "weight",
11 | "up_mode": "upconv"
12 | },
13 | "PhysicalSpecs": {
14 | "wavelength": 1,
15 | "dpl": 20,
16 | "Nx": 160,
17 | "Nz": 192,
18 | "pml_thickness": 30,
19 | "symmetry_x": true,
20 | "mode": "te",
21 | "high_order": "fourth"
22 | },
23 | "Seed": 2,
24 | "LearningRate": 0.0005,
25 | "LearningRateDecay": 0.5,
26 | "LearningRateDecayStep": 50000,
27 | "Epochs": 250000,
28 | "BatchSize": 1,
29 | "SnapshotFrequency": 50000,
30 | "TensorboardFrequency": 100
31 | }
32 |
33 |
--------------------------------------------------------------------------------
/examples/spheric_tm/specs_maxwell.json:
--------------------------------------------------------------------------------
1 | {
2 | "Description": [
3 | "This experiment learns to solve the Maxwell's equations",
4 | "for a spheric lens data."
5 | ],
6 | "NetworkArch": "maxwellnet",
7 | "NetworkSpecs": {
8 | "depth": 6,
9 | "filter": 32,
10 | "norm": "weight",
11 | "up_mode": "upconv"
12 | },
13 | "PhysicalSpecs": {
14 | "wavelength": 1,
15 | "dpl": 20,
16 | "Nx": 160,
17 | "Nz": 192,
18 | "pml_thickness": 30,
19 | "symmetry_x": true,
20 | "mode": "tm",
21 | "high_order": "fourth"
22 | },
23 | "Seed": 2,
24 | "LearningRate": 0.0005,
25 | "LearningRateDecay": 0.5,
26 | "LearningRateDecayStep": 50000,
27 | "Epochs": 250000,
28 | "BatchSize": 1,
29 | "SnapshotFrequency": 50000,
30 | "TensorboardFrequency": 100
31 | }
32 |
33 |
--------------------------------------------------------------------------------
/UNet.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted and modified from https://github.com/jvanvugt/pytorch-unet
3 |
4 | Modified parts:
5 | Copyright (c) 2022 Joowon Lim, limjoowon@gmail.com
6 |
7 | Original parts:
8 | MIT License
9 |
10 | Copyright (c) 2018 Joris
11 |
12 | Permission is hereby granted, free of charge, to any person obtaining a copy
13 | of this software and associated documentation files (the "Software"), to deal
14 | in the Software without restriction, including without limitation the rights
15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 | copies of the Software, and to permit persons to whom the Software is
17 | furnished to do so, subject to the following conditions:
18 |
19 | The above copyright notice and this permission notice shall be included in all
20 | copies or substantial portions of the Software.
21 |
22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 | SOFTWARE.
29 | """
30 |
31 | import torch
32 | from torch import nn
33 |
34 |
class UNet(nn.Module):
    """U-Net with ``depth`` resolution levels.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels of the final 1x1 convolution.
        depth: number of resolution levels (``depth - 1`` average-pooling
            stages plus the bottleneck).
        wf: channel count of the first level; level ``i`` uses ``wf * 2**i``.
        norm: normalization passed to every ``UNetConvBlock``
            ('weight', 'batch' or 'no').
        up_mode: 'upconv' (transposed convolution) or 'upsample'
            (bilinear upsampling followed by a 1x1 convolution).

    Raises:
        ValueError: if ``up_mode`` is not one of the supported values.
    """

    def __init__(self, in_channels=1, out_channels=2, depth=5, wf=6, norm='weight', up_mode='upconv'):
        super(UNet, self).__init__()
        # Explicit check rather than ``assert`` so it also runs under ``python -O``.
        if up_mode not in ('upconv', 'upsample'):
            raise ValueError(
                "up_mode should be either 'upconv' or 'upsample', got {!r}".format(up_mode))
        self.down_path = nn.ModuleList()
        self.up_path = nn.ModuleList()

        prev_channels = int(in_channels)

        for i in range(depth):
            if i != depth - 1:
                # Encoder level: two convolutions, then 2x average pooling.
                self.down_path.append(UNetConvBlock(
                    prev_channels, [wf * (2 ** i), wf * (2 ** i)], 3, 0, norm))
                prev_channels = int(wf * (2 ** i))
                self.down_path.append(nn.AvgPool2d(2))
            else:
                # Bottleneck: its second convolution halves the channels so the
                # output matches the skip connection of the level above.
                self.down_path.append(UNetConvBlock(
                    prev_channels, [wf * (2 ** i), wf * (2 ** (i - 1))], 3, 0, norm))
                prev_channels = int(wf * (2 ** (i - 1)))

        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, [wf * (2 ** i), int(wf * (2 ** (i - 1)))], up_mode, 3, 0, norm))
            prev_channels = int(wf * (2 ** (i - 1)))

        self.last_conv = nn.Conv2d(
            prev_channels, out_channels, kernel_size=1, padding=0, bias=False)

    def forward(self, scat_pot):
        blocks = []
        x = scat_pot
        for i, down in enumerate(self.down_path):
            x = down(x)
            # Even indices are conv blocks (odd ones are pooling layers); every
            # conv-block output except the bottleneck's is kept as a skip connection.
            if i % 2 == 0 and i != (len(self.down_path) - 1):
                blocks.append(x)
        for i, up in enumerate(self.up_path):
            # Decoder levels consume the skip connections deepest-first.
            x = up(x, blocks[-i - 1])

        return self.last_conv(x)
78 |
79 |
class UNetConvBlock(nn.Module):
    """Two padded convolutions: the building block of every UNet level.

    Each convolution is preceded by a 1-pixel padding layer. With
    ``kersize=3`` (and ``padding=0`` for the 'batch' variant), the spatial
    size of the input is therefore preserved.

    Args:
        in_size: number of input channels.
        out_size: ``[c1, c2]``, the output channels of the first and second
            convolution.
        kersize: convolution kernel size.
        padding: ``padding`` argument of the convolutions; only used by the
            'batch' variant (the other variants always use 0).
        norm: one of the following. In every variant a CELU follows the
            first convolution (after its normalization), and the second
            convolution has no activation.

            * 'weight' -- weight-normalized convolutions with replication padding.
            * 'batch' -- convolutions followed by BatchNorm2d, with reflection padding.
            * 'no' -- plain convolutions with replication padding.

    Raises:
        ValueError: if ``norm`` is not one of the values above.
    """

    def __init__(self, in_size, out_size, kersize, padding, norm):
        super(UNetConvBlock, self).__init__()
        kersize = int(kersize)
        # Layer order is kept identical across edits so state-dict keys
        # (block.0, block.1, ...) still match saved checkpoints.
        if norm == 'weight':
            block = [
                nn.ReplicationPad2d(1),
                nn.utils.weight_norm(
                    nn.Conv2d(in_size, out_size[0], kernel_size=kersize, padding=0, bias=True),
                    name='weight'),
                nn.CELU(),
                nn.ReplicationPad2d(1),
                nn.utils.weight_norm(
                    nn.Conv2d(out_size[0], out_size[1], kernel_size=kersize, padding=0, bias=True),
                    name='weight'),
            ]
        elif norm == 'batch':
            block = [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_size, out_size[0], kernel_size=kersize,
                          padding=int(padding), bias=True),
                nn.BatchNorm2d(out_size[0]),
                nn.CELU(),
                nn.ReflectionPad2d(1),
                nn.Conv2d(out_size[0], out_size[1], kernel_size=kersize,
                          padding=int(padding), bias=True),
                nn.BatchNorm2d(out_size[1]),
            ]
        elif norm == 'no':
            block = [
                nn.ReplicationPad2d(1),
                nn.Conv2d(in_size, out_size[0], kernel_size=kersize, padding=0, bias=True),
                nn.CELU(),
                nn.ReplicationPad2d(1),
                nn.Conv2d(out_size[0], out_size[1], kernel_size=kersize, padding=0, bias=True),
            ]
        else:
            # An unknown value would otherwise yield an empty block that
            # silently passes its input through unchanged.
            raise ValueError(
                "norm should be 'weight', 'batch' or 'no', got {!r}".format(norm))

        self.block = nn.Sequential(*block)

    def forward(self, x):
        return self.block(x)
118 |
119 |
class UNetUpBlock(nn.Module):
    """Decoder level: upsample by 2, concatenate the skip connection, convolve.

    Args:
        in_size: channels of the incoming (coarser) feature map. The
            upsampling keeps this channel count, and the following
            ``UNetConvBlock`` expects ``2 * in_size`` inputs, so the skip
            connection passed to ``forward`` must also have ``in_size``
            channels.
        out_size: ``[c1, c2]``, the output channels of the two convolutions.
        up_mode: 'upconv' (2x2 transposed convolution with stride 2) or
            'upsample' (bilinear upsampling followed by a 1x1 convolution).
        kersize: forwarded to ``UNetConvBlock``.
        padding: forwarded to ``UNetConvBlock``.
        norm: forwarded to ``UNetConvBlock``.

    Raises:
        ValueError: if ``up_mode`` is not one of the supported values.
    """

    def __init__(self, in_size, out_size, up_mode, kersize, padding, norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            block = [nn.ConvTranspose2d(in_size, in_size,
                                        kernel_size=2, stride=2, bias=False)]
        elif up_mode == 'upsample':
            block = [nn.Upsample(mode='bilinear', scale_factor=2),
                     nn.Conv2d(in_size, in_size, kernel_size=1, bias=False)]
        else:
            # An unknown value would otherwise build an empty upsampler and
            # fail later with a confusing shape mismatch in torch.cat.
            raise ValueError(
                "up_mode should be either 'upconv' or 'upsample', got {!r}".format(up_mode))

        self.block = nn.Sequential(*block)
        self.conv_block = UNetConvBlock(
            in_size * 2, out_size, kersize, padding, norm)

    def forward(self, x, bridge):
        up = self.block(x)
        # Channel-wise concatenation with the encoder feature map.
        out = torch.cat([up, bridge], 1)
        return self.conv_block(out)
141 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **MaxwellNet**
2 | This repository is the official code implementation of the paper, "MaxwellNet: Physics-driven deep neural network training based on Maxwell’s equations" by [Joowon Lim](https://www.linkedin.com/in/joowon-lim/) and [Demetri Psaltis](https://scholar.google.com/citations?&user=-CVR2h8AAAAJ). You can refer to the following materials for the details of implementation,
3 | - [Main article](https://aip.scitation.org/doi/10.1063/5.0071616 "Main article")
4 | - [Supplementary material](https://aip.scitation.org/doi/suppl/10.1063/5.0071616 "Supplementary material")
5 |
6 | Also, we had an interview on this work,
7 | - [Scilight interview](https://aip.scitation.org/doi/full/10.1063/10.0009285 "Scilight interview")
8 |
9 | ### **Overall scheme and idea**
10 | The novelty of this work is to train a deep neural network, MaxwellNet, that solves Maxwell's equations using a physics-driven loss. In other words, we use the residual of Maxwell's equations as the loss function to train MaxwellNet; therefore, no ground-truth solutions are required for training. Furthermore, we utilized MaxwellNet in a novel inverse design scheme; we encourage you to refer to the [main article](https://aip.scitation.org/doi/10.1063/5.0071616 "Main article") for details.
11 |
12 |
13 | ![Overall scheme](figures/scheme.png)
14 |
15 |
16 |
17 |
18 |
19 | ## **Installation**
20 | Our code is based on Windows 10, pytorch 1.7.1, CUDA 11.0, and python 3.7.
21 | We recommend using conda for installation.
22 |
23 | ```
24 | conda env create --file environments.yml
25 | conda activate maxwellnet
26 | ```
27 |
28 | ## **Run**
29 |
30 | ### **1. MaxwellNet Training**
31 | ```
32 | python train_maxwellnet.py --directory <directory>
33 | ```
34 | In `<directory>`, you need to have 'train.npz', which contains the training dataset, and 'specs_maxwell.json', where you specify the training parameters. If 'train.npz' contains more than one sample, a validation dataset 'valid.npz' is also required in the same directory. A brief description of the parameters can be found below. We encourage you to read the [supplementary material](https://aip.scitation.org/doi/suppl/10.1063/5.0071616 "supplementary material") to understand the parameters.
35 |
36 | | NetworkSpecs | Description |
37 | | :---: | :--- |
38 | | depth [int] | Depth of UNet. |
39 | | filter [int] | Channel numbers in the first layer of UNet. |
40 | | norm [str] | Type of normalization ('weight' for weight normalization, 'batch' for batch normalization, and 'no' for no normalization). |
41 | | up_mode [str] | Upsample mode of UNet (either 'upconv' for transposed convolution or 'upsample' for bilinear upsampling). |
42 |
43 |
44 | | PhysicalSpecs | Description |
45 | | :---: | :--- |
46 | | wavelength [float] | Wavelength in [um]. |
47 | | dpl [int] | One pixel size is 'wavelength / dpl' [um]. |
48 | | Nx [int] | Pixel number along the x-axis. This is equivalent to the pixel number along the x-axis of your scattering sample.|
49 | | Nz [int] | Pixel number along the z-axis (light propagation direction). This is equivalent to the pixel number along the z-axis of your scattering sample. |
50 | | pml_thickness [int] | Perfectly-matched-layer (PML) thickness in pixel number. 'pml_thickness * wavelength / dpl' is the actual thickness of PML layer in micrometers. |
51 | | symmetry_x [bool] | If this is True, MaxwellNet will assume your input scattering sample is symmetric along the x-axis. For example, when given a sample whose Nx and Nz are 100 and 200, respectively, if this sample is symmetric along the x-axis, you can save only half of it (Nx=50, Nz=200) in your train file (train.npz) and set 'symmetry_x' as True. |
52 | | mode [str] | 'te' or 'tm' (Transverse Electric or Transverse Magnetic). |
53 | | high_order [str] | 'second' or 'fourth'. It decides which order (second or fourth order) to calculate the gradient. 'fourth' is more accurate than 'second'. |
54 |
55 |
56 | #### **Examples**
57 | *Training for a single spheric lens.*
58 |
59 | If you just want to train a model for a single lens (which is a good exercise, as it runs for a short time), you can train MaxwellNet for a single spheric lens as follows:
60 | * TE mode.
61 | ```
62 | python train_maxwellnet.py --directory examples\spheric_te
63 | ```
64 | * TM mode.
65 | ```
66 | python train_maxwellnet.py --directory examples\spheric_tm
67 | ```
68 |
69 | *Training for multiple lenses.*
70 |
71 | You can download the datasets of multiple lenses [here](https://drive.google.com/drive/folders/1ZXPKntdBQUOyMYvmKM7Ol6woCN2Rsrqj?usp=sharing). Download and place 'lens_te' and 'lens_tm' folders under 'examples' folder.
72 | * Transverse Electric (TE) mode.
73 | ```
74 | python train_maxwellnet.py --directory examples\lens_te
75 | ```
76 | * Transverse Magnetic (TM) mode.
77 | ```
78 | python train_maxwellnet.py --directory examples\lens_tm
79 | ```
80 | The above training cases take about 37 (TE mode) and 63 (TM mode) hours on V100, respectively.
81 |
82 |
83 |
84 | ### **2. MaxwellNet Solution**
85 | If you want to check the solution found by MaxwellNet,
86 |
87 | ```
88 | python solution_maxwellnet.py --directory <directory> --model_filename <model_filename> --sample_filename <sample_filename>
89 | ```
90 | It will provide the sample (`<sample_filename>` in `<directory>`) to the saved model (`<model_filename>` in `<directory>/model`) and return the solution found by MaxwellNet. This output will be saved as an image in `<directory>`, as shown in the examples below.
91 |
92 | #### **Examples**
93 | If you want to calculate the solution found by MaxwellNet for the single spheric lenses (as trained above),
94 | * TE mode.
95 | ```
96 | python solution_maxwellnet.py --directory examples\spheric_te --model_filename 250000_te_fourth.pt --sample_filename sample.npz
97 | ```
98 | * TM mode.
99 | ```
100 | python solution_maxwellnet.py --directory examples\spheric_tm --model_filename 250000_tm_fourth.pt --sample_filename sample.npz
101 | ```
102 |
103 | | Mode | Result |
104 | | :---: | :---: |
105 | | TE mode | ![TE result](figures/te_result.png) |
106 | | TM mode | ![TM result](figures/tm_result.png) |
107 | You can find the solutions for the multiple lens training cases similarly.
108 |
109 | ## **Citation**
110 |
111 | If you find our work useful in your research, please consider citing our paper:
112 | ```
113 | @article{lim2022maxwellnet,
114 | title={MaxwellNet: Physics-driven deep neural network training based on Maxwell’s equations},
115 | author={Lim, Joowon and Psaltis, Demetri},
116 | journal={APL Photonics},
117 | volume={7},
118 | number={1},
119 | pages={011301},
120 | year={2022},
121 | publisher={AIP Publishing LLC}
122 | }
123 | ```
124 |
125 | ## **Acknowledgments**
126 | We referred to the code from the following repo, [UNet](https://github.com/jvanvugt/pytorch-unet). We thank the authors for sharing their code.
--------------------------------------------------------------------------------
/solution_maxwellnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Joowon Lim, limjoowon@gmail.com
2 |
3 | import torch
4 | from MaxwellNet import MaxwellNet
5 |
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from mpl_toolkits.axes_grid1 import make_axes_locatable
9 |
10 | import os
11 | import json
12 | import argparse
13 |
14 |
def _mirror_about_axis(half, negate=False):
    """Unfold a half-domain 2-D array stored with ``symmetry_x``.

    Rows ``half[1:]`` are flipped along axis 0 (and negated when ``negate``
    is set) and placed in front of ``half``. One zero row is then prepended,
    so the result has ``2 * half.shape[0]`` rows.
    """
    mirrored = np.flip(half[1:, :], 0)
    if negate:
        mirrored = -mirrored
    return np.pad(np.concatenate((mirrored, half), 0), ((1, 0), (0, 0)))


def _mirror_full(half):
    """Unfold a half-domain 2-D array by mirroring all of its rows (2x rows)."""
    return np.concatenate((np.flip(half, 0), half), 0)


def _add_panel(ax, image, title, extent, fontsize, **imshow_kwargs):
    """Show ``image`` on ``ax`` with a title and a colorbar to its right."""
    img = ax.imshow(image, extent=extent, **imshow_kwargs)
    ax.set_title(title, fontsize=fontsize)
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(img, cax=cax)


def main(args):
    """Run a trained MaxwellNet on one sample and save a figure of the result.

    Reads ``specs_maxwell.json`` and ``args.sample_filename`` from
    ``args.directory``, and the model from
    ``args.directory/model/args.model_filename``. Writes ``te_result.png`` or
    ``tm_result.png`` into ``args.directory``.

    Raises:
        KeyError: if the specs' ``mode`` is neither 'te' nor 'tm'.
    """
    specs_filename = os.path.join(
        os.getcwd(), args.directory, 'specs_maxwell.json')
    with open(specs_filename) as specs_file:
        specs = json.load(specs_file)

    physical_specs = specs['PhysicalSpecs']
    Nx = physical_specs['Nx']
    Nz = physical_specs['Nz']
    dpl = physical_specs['dpl']
    wavelength = physical_specs['wavelength']
    symmetry_x = physical_specs['symmetry_x']
    mode = physical_specs['mode']

    delta = wavelength / dpl
    # With x-symmetry only half of the domain is stored; the plots show both halves.
    Nx = Nx * (symmetry_x + 1)

    with np.load(os.path.join(args.directory, args.sample_filename)) as sample_file:
        scat_pot_np = sample_file['sample']
        ri = sample_file['n']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_directory = os.path.join(args.directory, 'model')
    # map_location lets a checkpoint saved on a GPU be evaluated on a CPU-only machine.
    model_dict = torch.load(os.path.join(
        model_directory, args.model_filename), map_location=device)
    model = MaxwellNet(**specs["NetworkSpecs"], **
                       specs["PhysicalSpecs"]).to(device)
    model.load_state_dict(model_dict['state_dict'])
    model.eval()

    scat_pot_torch = torch.from_numpy(np.float32(scat_pot_np)).to(device)
    ri_torch = torch.tensor([np.float32(ri)]).to(device)
    _, total = model(scat_pot_torch, ri_torch)

    total_np = total.cpu().detach().numpy()
    # Output channels alternate (real, imaginary) for each field component.
    total_np = total_np[0, 0::2, :, :] + 1j * total_np[0, 1::2, :, :]
    scat_pot_np = scat_pot_np[0, :, :, :] * (ri - 1) + 1

    # Here, I use the following min max values to present the data just for visualization. This is approximately correct.
    # Please note that the output values from MaxwellNet are defined on Yee grid.
    # So, if you want to quantitatively compare the MaxwellNet outputs with solutions from another solver, you should compare two solutions at the same Yee grid points.
    x_min = -(Nx // 2) * delta
    x_max = (Nx // 2 - 1) * delta
    z_min = -(Nz // 2) * delta
    z_max = (Nz // 2 - 1) * delta
    extent = [z_min, z_max, x_min, x_max]
    fontsize = 20

    if mode == 'te':
        if symmetry_x:
            ri_map = _mirror_about_axis(scat_pot_np[0])
            field_y = _mirror_about_axis(total_np[0])
        else:
            # Full domain stored: plot the 2-D arrays directly.
            ri_map = scat_pot_np[0]
            field_y = total_np[0]

        fig, axs = plt.subplots(1, 2, figsize=(8, 5))
        fig.suptitle('TE mode - Spherical Lens', fontsize=fontsize)
        _add_panel(axs[0], ri_map, 'RI distribution', extent, fontsize, vmin=1, vmax=ri)
        _add_panel(axs[1], np.abs(field_y), 'Ey (envelope)', extent, fontsize)

        plt.tight_layout()
        plt.savefig(os.path.join(os.getcwd(), args.directory, 'te_result.png'))
        plt.close(fig)

    elif mode == 'tm':
        if symmetry_x:
            # The x-component is mirrored without an on-axis row; the
            # z-component keeps an on-axis row and its mirrored half is
            # sign-flipped (presumably reflecting the Yee-grid placement
            # and field parity -- see the note above).
            ri_map_x = _mirror_full(scat_pot_np[0])
            ri_map_z = _mirror_about_axis(scat_pot_np[1])
            field_x = _mirror_full(total_np[0])
            field_z = _mirror_about_axis(total_np[1], negate=True)
        else:
            # Full domain stored: plot the 2-D arrays directly.
            ri_map_x = scat_pot_np[0]
            ri_map_z = scat_pot_np[1]
            field_x = total_np[0]
            field_z = total_np[1]

        fig, axs = plt.subplots(2, 2, figsize=(8, 10))
        fig.suptitle('TM mode - Spherical Lens', fontsize=fontsize)
        _add_panel(axs[0, 0], ri_map_x, 'RI distribution', extent, fontsize, vmin=1, vmax=ri)
        _add_panel(axs[0, 1], np.abs(field_x), 'Ex (envelope)', extent, fontsize)
        _add_panel(axs[1, 0], ri_map_z, 'RI distribution', extent, fontsize, vmin=1, vmax=ri)
        _add_panel(axs[1, 1], np.abs(field_z), 'Ez (envelope)', extent, fontsize)

        plt.tight_layout()
        plt.savefig(os.path.join(os.getcwd(), args.directory, 'tm_result.png'))
        plt.close(fig)

    else:
        raise KeyError("'mode' should be either 'te' or 'tm'.")
141 |
142 |
if __name__ == '__main__':
    # Command-line entry point: compute and plot the solution of a trained model.
    arg_parser = argparse.ArgumentParser(
        description="Compute the solution found by a trained MaxwellNet for a given sample")
    # `required=True` makes any default unreachable, so none is given
    # (the old default also contained an invalid '\s' escape sequence).
    arg_parser.add_argument(
        "--directory",
        "-d",
        required=True,
        help="This directory should include "
        + "'sample.npz' (input to MaxwellNet) and "
        + "'model' folder where trained 'MaxwellNet' parameters are saved and "
        + "'specs_maxwell.json' used during training."
    )
    arg_parser.add_argument(
        "--model_filename",
        required=True,
        help="This filename indicates the saved model file name within '<directory>/model/'."
    )
    arg_parser.add_argument(
        "--sample_filename",
        required=True,
        help="This filename indicates a .npz file to be provided to MaxwellNet to calculate the solution, and it should be located in '<directory>'."
    )

    args = arg_parser.parse_args()
    main(args)
168 |
--------------------------------------------------------------------------------
/train_maxwellnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Joowon Lim, limjoowon@gmail.com
2 |
3 | import torch
4 | from Dataset import LensDataset
5 | import torch.backends.cudnn as cudnn
6 | from torch.optim.lr_scheduler import StepLR
7 | from torch.utils.tensorboard import SummaryWriter
8 | from MaxwellNet import MaxwellNet
9 |
10 | import numpy as np
11 | import random
12 | import logging
13 | import argparse
14 | import os
15 | import json
16 | from datetime import datetime
17 |
18 |
def main(directory, load_ckpt):
    """Train MaxwellNet with the settings in ``<directory>/specs_maxwell.json``.

    Args:
        directory: experiment directory. It must contain
            ``specs_maxwell.json`` and ``train.npz``, plus ``valid.npz`` when
            the training set has more than one sample. The log file and
            TensorBoard files are written into it, and checkpoints are saved
            through ``save_checkpoint(state, directory, name)``.
        load_ckpt: name (without ``.pt``) of a checkpoint in
            ``<directory>/model`` to resume from, or ``None`` to start fresh.

    Raises:
        Exception: if ``specs_maxwell.json`` is missing.
    """
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                        datefmt='%a, %d %b %Y %H:%M:%S',
                        filename=os.path.join(
                            os.getcwd(), directory, f"maxwellnet_{datetime.now():%Y-%m-%d %H-%M-%S}.log"),
                        filemode='w')

    logging.info("training " + directory)

    specs_filename = os.path.join(directory, 'specs_maxwell.json')

    if not os.path.isfile(specs_filename):
        raise Exception(
            'The experiment directory does not include specifications file "specs_maxwell.json"'
        )

    # Context manager so the specs file handle is closed right away.
    with open(specs_filename) as specs_file:
        specs = json.load(specs_file)

    seed_number = get_spec_with_default(specs, "Seed", None)
    if seed_number is not None:
        fix_seed(seed_number, torch.cuda.is_available())

    # Single-process training: this process acts as rank 0 and does all
    # logging and checkpointing.
    rank = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logging.info("Experiment description: \n" +
                 ' '.join([str(elem) for elem in specs["Description"]]))
    logging.info("Training with " + str(device))

    model = MaxwellNet(**specs["NetworkSpecs"], **specs["PhysicalSpecs"])
    if torch.cuda.device_count() > 1:
        logging.info("Multiple GPUs: " + str(torch.cuda.device_count()))
    if load_ckpt is not None:
        load_path = os.path.join(os.getcwd(), directory, 'model', load_ckpt)
        # map_location lets a checkpoint written on a GPU be resumed on a
        # CPU-only machine.
        ckpt_dict = torch.load(load_path + '.pt', map_location=device)
        ckpt_epoch = ckpt_dict['epoch']
        logging.info("Checkpoint loaded from {}-epoch".format(ckpt_epoch))
        # Checkpoints store model.module.state_dict() (the unwrapped model),
        # so the weights are loaded before wrapping in DataParallel.
        model.load_state_dict(ckpt_dict['state_dict'])

    model = torch.nn.DataParallel(model)
    model.train()
    model = model.to(device)

    logging.info("Number of network parameters: {}".format(
        sum(p.data.nelement() for p in model.parameters())))
    logging.debug(specs["NetworkSpecs"])
    logging.debug(specs["PhysicalSpecs"])

    optimizer = torch.optim.Adam(model.parameters(), lr=get_spec_with_default(
        specs, "LearningRate", 0.0001), weight_decay=0)
    scheduler = StepLR(optimizer, step_size=get_spec_with_default(
        specs, "LearningRateDecayStep", 10000), gamma=get_spec_with_default(specs, "LearningRateDecay", 1.0))

    batch_size = get_spec_with_default(specs, "BatchSize", 1)
    epochs = get_spec_with_default(specs, "Epochs", 1)
    snapshot_freq = specs["SnapshotFrequency"]
    physical_specs = specs["PhysicalSpecs"]
    symmetry_x = physical_specs['symmetry_x']
    mode = physical_specs['mode']
    high_order = physical_specs['high_order']

    # Epochs at which a numbered snapshot is written.
    checkpoints = set(range(snapshot_freq, epochs + 1, snapshot_freq))

    filename = 'maxwellnet_' + mode + '_' + high_order
    writer = SummaryWriter(os.path.join(directory, 'tensorboard_' + filename))
    writer_freq = get_spec_with_default(specs, "TensorboardFrequency", None)

    train_dataset = LensDataset(directory, 'train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, pin_memory=True, sampler=None)
    logging.info("Train Dataset length: {}".format(len(train_dataset)))
    # One loss slot per epoch (index epoch - 1).
    loss_train = torch.zeros(
        (int(epochs),), dtype=torch.float32, requires_grad=False)

    # A separate 'valid.npz' is only expected when there is more than one training sample.
    perform_valid = len(train_dataset) > 1

    if perform_valid:
        valid_dataset = LensDataset(directory, 'valid')
        valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size,
                                                   shuffle=True, pin_memory=True, sampler=None)
        logging.info("Valid Dataset length: {}".format(len(valid_dataset)))
        loss_valid = torch.zeros(
            (int(epochs),), dtype=torch.float32, requires_grad=False)

    if load_ckpt is not None:
        optimizer.load_state_dict(ckpt_dict['optimizer'])
        scheduler.load_state_dict(ckpt_dict['scheduler'])
        loss_train[:ckpt_epoch] = ckpt_dict['loss_train'][:ckpt_epoch]
        logging.info("Check point loaded from {}-epoch".format(ckpt_epoch))

        start_epoch = ckpt_epoch
    else:
        start_epoch = 0

    def checkpoint_state(epoch):
        # Build a fresh dict for every save.
        return {
            'epoch': epoch,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss_train': loss_train,
            'scheduler': scheduler.state_dict(),
        }

    logging.info("Training start")

    for epoch in range(start_epoch + 1, epochs + 1):
        train(train_loader, model, optimizer, epoch, loss_train,
              device, mode, symmetry_x, writer, writer_freq)
        if rank == 0:
            logging.info("[Train] {} epoch. Loss: {:.5f}".format(
                epoch, loss_train[epoch-1].item()))
        if perform_valid:
            valid(valid_loader, model, epoch, loss_valid,
                  device, mode, symmetry_x, writer, writer_freq)
            if rank == 0:
                logging.info("[Valid] {} epoch. Loss: {:.5f}".format(
                    epoch, loss_valid[epoch-1].item()))

        if rank == 0 and epoch in checkpoints:
            logging.info("Checkpoint saved at {} epoch.".format(epoch))
            save_checkpoint(checkpoint_state(epoch), directory,
                            str(epoch) + '_' + mode + '_' + high_order)

        # Additionally save a checkpoint named 'latest' every 200 epochs.
        if rank == 0 and epoch % 200 == 0:
            logging.info("'latest' checkpoint saved at {} epoch.".format(epoch))
            save_checkpoint(checkpoint_state(epoch), directory, 'latest')

        scheduler.step()

    if rank == 0:
        writer.close()
157 |
158 |
def train(train_loader, model, optimizer, epoch, loss_train, device, mode, symmetry, writer, writer_freq):
    """Run one training epoch and store the per-sample mean loss in ``loss_train[epoch-1]``.

    Each batch's loss is ``mean(diff**2)``, where ``diff`` is the model's
    first output. Gradients are clipped to a total norm of 1e-3 before
    every optimizer step.

    When ``writer`` is not ``None`` and ``writer_freq`` is a positive
    integer, the first item of the last batch's second output is sent to
    ``to_tensorboard`` every ``writer_freq`` epochs. ``writer_freq`` may be
    ``None`` (TensorBoard logging disabled).

    An empty ``train_loader`` leaves ``loss_train[epoch-1]`` unchanged.
    """
    model.train()
    count = 0
    total = None
    with torch.set_grad_enabled(True):
        for data in train_loader:
            scat_pot_torch = data[0].to(device)
            ri_value_torch = data[1].to(device)

            diff, total = model(scat_pot_torch, ri_value_torch)

            loss = torch.mean(diff.pow(2))
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-3)
            optimizer.step()

            # Weight by batch size so the epoch value is a per-sample mean.
            loss_train[epoch-1] += loss.item() * diff.size(0)
            count += diff.size(0)

    if count == 0:
        # Nothing was processed: avoid 0/0 and logging an undefined output.
        return
    loss_train[epoch-1] = loss_train[epoch-1] / count

    # Test writer_freq before the modulo: it may be None when
    # "TensorboardFrequency" is not configured.
    if writer is not None and writer_freq and epoch % writer_freq == 0:
        # permute + [..., 0] turns a 4-D output into [dim2, dim3, dim1] for
        # its first item (presumably (N, C, H, W) -> (H, W, C); confirm).
        to_tensorboard(total.permute(2, 3, 1, 0)[:, :, :, 0].clone().detach().cpu(), loss_train[epoch-1].numpy(), epoch,
                       mode, symmetry, writer, 'train')
185 |
186 |
def valid(valid_loader, model, epoch, loss_valid, device, mode, symmetry, writer, writer_freq):
    """Evaluate one epoch without gradients; store the per-sample mean loss in ``loss_valid[epoch-1]``.

    Each batch's loss is ``mean(diff**2)``, where ``diff`` is the model's
    first output, weighted by batch size when averaging over the epoch.

    When ``writer`` is not ``None`` and ``writer_freq`` is a positive
    integer, the first item of the last batch's second output is sent to
    ``to_tensorboard`` every ``writer_freq`` epochs. ``writer_freq`` may be
    ``None`` (TensorBoard logging disabled).

    An empty ``valid_loader`` leaves ``loss_valid[epoch-1]`` unchanged.
    """
    model.eval()
    count = 0
    total = None
    with torch.set_grad_enabled(False):
        for data in valid_loader:
            scat_pot_torch = data[0].to(device)
            ri_value_torch = data[1].to(device)

            diff, total = model(scat_pot_torch, ri_value_torch)

            loss = torch.mean(diff.pow(2))

            # Weight by batch size so the epoch value is a per-sample mean.
            loss_valid[epoch-1] += loss.item() * diff.size(0)
            count += diff.size(0)

    if count == 0:
        # Nothing was processed: avoid 0/0 and logging an undefined output.
        return
    loss_valid[epoch-1] = loss_valid[epoch-1] / count

    # Test writer_freq before the modulo: it may be None when
    # "TensorboardFrequency" is not configured.
    if writer is not None and writer_freq and epoch % writer_freq == 0:
        # permute + [..., 0] turns a 4-D output into [dim2, dim3, dim1] for
        # its first item (presumably (N, C, H, W) -> (H, W, C); confirm).
        to_tensorboard(total.permute(2, 3, 1, 0)[:, :, :, 0].clone().detach().cpu(), loss_valid[epoch-1].numpy(), epoch,
                       mode, symmetry, writer, 'valid')
210 |
211 |
def _rescale_to_unit(values):
    """Shift and scale ``values`` so that they span [0, 1].

    A constant input (zero range) is mapped to zeros instead of the NaNs that dividing
    by a zero maximum would produce.
    """
    shifted = values - torch.min(values)
    peak = torch.max(shifted)
    if peak == 0:
        return torch.zeros_like(shifted)
    return shifted / peak


def to_tensorboard(image, losses, epoch, mode, symmetry, writer, train_valid):
    """Log field images and the loss scalar to TensorBoard.

    Args:
        image: 3-D tensor indexed ``[dim0, dim1, channel]`` whose channels are
            ``[real, imag]`` pairs, one pair per polarization component.
        losses: scalar loss value written with ``add_scalar``.
        epoch: global step for all logged values.
        mode: 'te' (component y) or 'tm' (components x and z).
        symmetry: if truthy, ``image`` holds one half of the domain and is mirrored
            along dim 0 before logging (z in TM mode with a sign flip).
        writer: object exposing ``add_image`` and ``add_scalar``.
        train_valid: tag prefix, e.g. 'train' or 'valid'.

    Raises:
        ValueError: if ``mode`` is neither 'te' nor 'tm'.
    """
    if mode == 'te':
        polarization = ['y']
    elif mode == 'tm':
        polarization = ['x', 'z']
    else:
        raise ValueError("mode must be 'te' or 'tm', got {!r}".format(mode))

    # One symmetry test for both the mirroring and the lookup, so truthy non-bool values
    # (e.g. 1) cannot leave a mirrored image undefined.
    mirror = bool(symmetry)
    mirrored = {}
    if mirror:
        flipped = torch.flip(image, [0])
        if mode == 'te':
            mirrored['y'] = torch.cat((flipped[0:-1, :, :], image), 0)
        else:
            mirrored['x'] = torch.cat((flipped, image), 0)
            mirrored['z'] = torch.cat((-flipped[0:-1, :, :], image), 0)

    for idx, pol in enumerate(polarization):
        field = mirrored[pol] if mirror else image
        prefix = train_valid + '/' + mode + '/'

        amplitude = torch.sum(field[:, :, idx * 2:(idx + 1) * 2].pow(2), 2).pow(1 / 2)
        writer.add_image(prefix + 'amplitude_' + pol,
                         _rescale_to_unit(amplitude).unsqueeze(0), epoch)

        writer.add_image(prefix + 'real_' + pol,
                         _rescale_to_unit(field[:, :, idx * 2]).unsqueeze(0), epoch)

        writer.add_image(prefix + 'imaginary_' + pol,
                         _rescale_to_unit(field[:, :, idx * 2 + 1]).unsqueeze(0), epoch)

    writer.add_scalar(train_valid + '/' + mode, losses, epoch)
254 |
255 |
def save_checkpoint(state, directory, filename):
    """Save ``state`` with ``torch.save`` to ``<directory>/model/<filename>.pt``.

    The ``model`` sub-directory is created if it does not exist yet.
    """
    model_directory = os.path.join(directory, 'model')
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(model_directory, exist_ok=True)
    torch.save(state, os.path.join(model_directory, filename + '.pt'))
261 |
262 |
def fix_seed(seed, is_cuda):
    """Seed the Python, NumPy and PyTorch generators for reproducible runs.

    When ``is_cuda`` is true the CUDA generators are seeded too. cuDNN is put into
    deterministic, non-benchmarking mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if is_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
272 |
273 |
def get_spec_with_default(specs, key, default):
    """Look up ``key`` in ``specs``, falling back to ``default`` when it is absent.

    Only a missing key (``KeyError``) triggers the fallback; a key explicitly mapped
    to ``None`` still yields ``None``.
    """
    try:
        found = specs[key]
    except KeyError:
        return default
    else:
        return found
279 |
280 |
if __name__ == '__main__':
    # Command-line entry point: parse the experiment directory (and an optional
    # checkpoint name) and hand them to main().
    arg_parser = argparse.ArgumentParser(description="Train a MaxwellNet")
    arg_parser.add_argument(
        "--directory",
        "-d",
        # Required, so argparse never falls back to a default; the previous default also
        # contained an invalid escape sequence ('\s').
        required=True,
        help="This directory should include "
        + "all the training and network parameters in 'specs_maxwell.json', and logging will be "
        + "done in this directory as well.",
    )
    arg_parser.add_argument(
        "--load_ckpt",
        "-l",
        default=None,
        # The backslash is escaped explicitly: '\m' is an invalid escape sequence
        # (SyntaxWarning on Python 3.12+). The printed help text is unchanged.
        help="This should specify a filename of your checkpoint within 'directory'\\model if you want to continue your training from the checkpoint.",
    )

    args = arg_parser.parse_args()
    main(args.directory, args.load_ckpt)
301 |
--------------------------------------------------------------------------------
/MaxwellNet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Joowon Lim, limjoowon@gmail.com
2 |
3 | import torch
4 | from torch import nn
5 | from UNet import UNet
6 | import torch.nn.functional as F
7 | import math
8 | import numpy as np
9 |
10 |
class MaxwellNet(nn.Module):
    """U-Net field predictor trained on the residual of Maxwell's equations.

    The wrapped ``UNet`` predicts a complex field envelope. ``forward`` multiplies it by
    the incident plane wave ``exp(i k z)`` and applies finite-difference operators, with
    perfectly-matched-layer (PML) coordinate stretching near the borders. It returns the
    residual of the wave equation, whose squared mean is used as the training loss.

    Complex quantities are stored as real tensors whose channel axis holds
    ``[real, imag]`` pairs.

    Args:
        depth, filter, norm, up_mode: forwarded to ``UNet``.
        wavelength: wavelength in um.
        dpl: pixels per wavelength; the pixel size is ``wavelength / dpl`` um.
        Nx, Nz: number of grid points along x (rows) and z (columns), excluding the
            finite-difference padding.
        pml_thickness: PML thickness in pixels.
        symmetry_x: if truthy, the x grid starts at 0, the x = 0 border gets no PML and
            is filled by mirroring instead.
        mode: 'te' (single component Ey) or 'tm' (components Ex and Ez).
        high_order: 'second' or 'fourth' order finite differences.

    Raises:
        ValueError: if ``mode`` or ``high_order`` is not one of the values above.
    """

    # Padding (pixels per side) needed by the finite-difference stencils of each order.
    _STENCIL_PAD = {'second': 2, 'fourth': 4}

    def __init__(self, depth=6, filter=16, norm='weight', up_mode='upconv',
                 wavelength=1, dpl=20, Nx=256, Nz=256, pml_thickness=16, symmetry_x=False, mode='te',
                 high_order='fourth'):
        super(MaxwellNet, self).__init__()
        # One complex component (2 real channels) out for TE, two (4 channels) for TM.
        if mode == 'te':
            in_channels, out_channels = 1, 2
        elif mode == 'tm':
            in_channels, out_channels = 2, 4
        else:
            raise ValueError("mode must be 'te' or 'tm', got {!r}".format(mode))
        if high_order not in self._STENCIL_PAD:
            raise ValueError(
                "high_order must be 'second' or 'fourth', got {!r}".format(high_order))

        self.mode = mode
        self.model = UNet(in_channels, out_channels, depth, filter, norm, up_mode)
        self._setup_physics(wavelength, dpl, Nx, Nz, pml_thickness, symmetry_x, high_order)

    def _setup_physics(self, wavelength, dpl, Nx, Nz, pml_thickness, symmetry_x, high_order):
        """Register the grid-dependent buffers and padding modules.

        Covers the incident field, the PML and the finite-difference kernels. It is kept
        separate from ``__init__`` so that the physics set-up does not depend on the
        U-Net backbone.
        """
        pad = self._STENCIL_PAD[high_order]
        self.high_order = high_order
        self.pad = pad
        self.symmetry_x = symmetry_x

        # Pixel size [um / pixel] and wave-number [1 / um].
        delta = wavelength / dpl
        k = 2 * math.pi / wavelength
        self.register_buffer('delta', torch.tensor(delta, dtype=torch.float32))
        self.register_buffer('k', torch.tensor(k, dtype=torch.float32))

        # Padding applied to fields before differentiation. padding_ref reflects the top
        # rows (the x = 0 side in the symmetric set-up) and zero-pads the other sides.
        self.padding_ref = nn.Sequential(nn.ReflectionPad2d(
            (0, 0, pad, 0)), nn.ZeroPad2d((pad, pad, 0, pad)))
        self.padding_zero = nn.Sequential(nn.ZeroPad2d((pad, pad, pad, pad)))

        # Coordinates of the padded grid (rows follow x, columns follow z).
        if symmetry_x:
            x = np.linspace(-pad, Nx + pad - 1, Nx + 2 * pad) * delta
        else:
            x = np.linspace(-Nx // 2 - pad, Nx // 2 + pad - 1, Nx + 2 * pad) * delta
        z = np.linspace(-Nz // 2 - pad, Nz // 2 + pad - 1, Nz + 2 * pad) * delta
        zz, xx = np.meshgrid(z, x)
        # NOTE: these hold the *padded* grid size, unlike the constructor arguments.
        self.Nx = zz.shape[0]
        self.Nz = zz.shape[1]

        # Incident plane wave exp(i k z), on the grid and on the grid shifted by half a
        # pixel along z (the staggered Yee position used for Ez).
        self._register_complex_buffer('fast', np.exp(1j * (k * zz)))
        self._register_complex_buffer('fast_z', np.exp(1j * (k * (zz + delta / 2))))

        # PML: stretch factors 1 + i*const*d**m that grow with the distance d (um) into
        # each layer and equal 1 elsewhere; only their inverses are stored.
        m = 4
        const = 5
        rx_p = 1 + 1j * const * (xx - x[-1] + pml_thickness * delta) ** m
        rx_p[0:-pml_thickness, :] = 0
        rx_n = 1 + 1j * const * (xx - x[0] - pml_thickness * delta) ** m
        rx_n[pml_thickness::, :] = 0
        rx = rx_p + rx_n
        if symmetry_x:
            # No PML on the x = 0 (mirror) side.
            rx[0:-pml_thickness:, :] = 1
        else:
            rx[pml_thickness:-pml_thickness, :] = 1

        rz_p = 1 + 1j * const * (zz - z[-1] + pml_thickness * delta) ** m
        rz_p[:, 0:-pml_thickness] = 0
        rz_n = 1 + 1j * const * (zz - z[0] - pml_thickness * delta) ** m
        rz_n[:, pml_thickness::] = 0
        rz = rz_p + rz_n
        rz[:, pml_thickness:-pml_thickness] = 1

        self._register_complex_buffer('rx_inverse', 1 / rx)
        self._register_complex_buffer('rz_inverse', 1 / rz)

        # Staggered first-difference kernels, duplicated for the [real, imag] channels and
        # applied with groups=2. The "h" and "e" variants differ by a one-pixel offset
        # (trailing vs leading zero tap); "_ho" are the 4-point fourth-order stencils.
        # The x kernels are the z kernels transposed.
        h_z = self._z_stencil([-1 / delta, +1 / delta, 0])
        self.register_buffer('gradient_h_z', h_z)
        self.register_buffer('gradient_h_x', h_z.permute(0, 1, 3, 2).contiguous())
        h_z_ho = self._z_stencil(
            [1 / 24 / delta, -9 / 8 / delta, +9 / 8 / delta, -1 / 24 / delta, 0])
        self.register_buffer('gradient_h_z_ho', h_z_ho)
        self.register_buffer('gradient_h_x_ho', h_z_ho.permute(0, 1, 3, 2).contiguous())

        e_z = self._z_stencil([0, -1 / delta, +1 / delta])
        self.register_buffer('gradient_e_z', e_z)
        self.register_buffer('gradient_e_x', e_z.permute(0, 1, 3, 2).contiguous())
        e_z_ho = self._z_stencil(
            [0, 1 / 24 / delta, -9 / 8 / delta, +9 / 8 / delta, -1 / 24 / delta])
        self.register_buffer('gradient_e_z_ho', e_z_ho)
        self.register_buffer('gradient_e_x_ho', e_z_ho.permute(0, 1, 3, 2).contiguous())

        # Second z-derivative of the incident field (no PML), cropped along x. Only the
        # variant matching `high_order` has shape (1, 2, Nx, Nz) and is used in forward;
        # both are kept so existing checkpoints load with unchanged buffer shapes.
        self.register_buffer('dd_z_fast', self.dd_z(self.fast)[:, :, pad:-pad, :])
        self.register_buffer('dd_z_ho_fast', self.dd_z_ho(self.fast)[:, :, pad:-pad, :])

    def _register_complex_buffer(self, name, array):
        """Register complex 2-D ``array`` as a (1, 2, H, W) float32 buffer [real, imag]."""
        buffer = torch.zeros((1, 2) + array.shape, dtype=torch.float32)
        buffer[0, 0, :, :] = torch.from_numpy(np.real(array))
        buffer[0, 1, :, :] = torch.from_numpy(np.imag(array))
        self.register_buffer(name, buffer)

    @staticmethod
    def _z_stencil(coefficients):
        """Return a (2, 1, 1, K) float32 kernel holding ``coefficients`` for both channels."""
        kernel = torch.zeros((2, 1, 1, len(coefficients)), dtype=torch.float32)
        kernel[:, :, 0, :] = torch.tensor(coefficients)
        return kernel

    @staticmethod
    def _clamp_permittivity(epsillon):
        """Replace permittivity values <= 1 with 1."""
        return torch.where(epsillon > 1.0, epsillon, torch.ones_like(epsillon))

    def forward(self, scat_pot, ri_value):
        """Predict the field and return the wave-equation residual.

        Args:
            scat_pot: scattering-potential map fed to the U-Net. It has 1 channel in 'te'
                mode and 2 channels in 'tm' mode (x- and z-component permittivity maps).
                Its spatial size presumably matches the constructor's (Nx, Nz) -- TODO
                confirm.
            ri_value: one refractive index per sample, shape (N,). The permittivity is
                ``scat_pot * ri_value**2`` with values <= 1 replaced by 1.

        Returns:
            ``(diff, total)``: the finite-difference residual (2 channels per field
            component) and the network output with 1 added to its first channel.
        """
        pad = self.pad
        half = pad // 2
        epsillon = scat_pot * (ri_value ** 2).unsqueeze(1).unsqueeze(2).unsqueeze(3)
        incident = self.fast[:, :, pad:-pad, pad:-pad]

        if self.mode == 'te':
            epsillon = self._clamp_permittivity(epsillon)

            x = self.model(scat_pot)
            total = torch.cat((x[:, 0:1, :, :] + 1, x[:, 1:2, :, :]), 1)

            ey = self.complex_multiplication(total[:, 0:2, :, :], incident)
            ey_s = ey - incident
            ey_s = self.padding_ref(ey_s) if self.symmetry_x else self.padding_zero(ey_s)

            if self.high_order == 'second':
                diff = (self.dd_x_pml(ey_s)[:, :, :, pad:-pad]
                        + self.dd_z_pml(ey_s)[:, :, pad:-pad, :]
                        + self.dd_z_fast
                        + self.k ** 2 * (epsillon * ey))
            else:
                diff = (self.dd_x_ho_pml(ey_s)[:, :, :, pad:-pad]
                        + self.dd_z_ho_pml(ey_s)[:, :, pad:-pad, :]
                        + self.dd_z_ho_fast
                        + self.k ** 2 * (epsillon * ey))

        else:  # 'tm' (validated in __init__)
            epsillon_x = self._clamp_permittivity(epsillon[:, 0:1, :, :])
            epsillon_z = self._clamp_permittivity(epsillon[:, 1:2, :, :])

            x = self.model(scat_pot)
            total = torch.cat((x[:, 0:1, :, :] + 1, x[:, 1:4, :, :]), 1)

            ex = self.complex_multiplication(total[:, 0:2, :, :], incident)
            ex_s = ex - incident
            # No incident field is subtracted from Ez: all of it is treated as scattered.
            ez_s = self.complex_multiplication(
                total[:, 2:4, :, :], self.fast_z[:, :, pad:-pad, pad:-pad])

            if self.symmetry_x:
                ex_s = self.padding_zero(ex_s)
                ez_s = self.padding_ref(ez_s)
                # Fill the x < 0 padding rows by mirroring: Ex copies its first rows
                # (flipped), Ez takes its reflected rows with the sign flipped.
                ex_s[:, :, 0:pad, :] = torch.flip(ex_s[:, :, pad:2 * pad, :], [2])
                ez_s[:, :, 0:pad, :] = -ez_s[:, :, 0:pad, :]
            else:
                ex_s = self.padding_zero(ex_s)
                ez_s = self.padding_zero(ez_s)

            # Ez in the k^2 term must be cropped back to the simulation grid.
            ez_core = ez_s[:, :, pad:-pad, pad:-pad]
            if self.high_order == 'second':
                # Uses the PML-stretched cross derivatives, like the fourth-order branch.
                diff_x = (self.dd_z_pml(ex_s)[:, :, pad:-pad, :]
                          + self.dd_z_fast
                          - self.dd_zx_pml(ez_s)[:, :, half:-half, half:-half]
                          + self.k ** 2 * (epsillon_x * ex))
                diff_z = (self.dd_x_pml(ez_s)[:, :, :, pad:-pad]
                          - self.dd_xz_pml(ex_s)[:, :, half:-half, half:-half]
                          + self.k ** 2 * (epsillon_z * ez_core))
            else:
                diff_x = (self.dd_z_ho_pml(ex_s)[:, :, pad:-pad, :]
                          + self.dd_z_ho_fast
                          - self.dd_zx_ho_pml(ez_s)[:, :, half:-half, half:-half]
                          + self.k ** 2 * (epsillon_x * ex))
                diff_z = (self.dd_x_ho_pml(ez_s)[:, :, :, pad:-pad]
                          - self.dd_xz_ho_pml(ex_s)[:, :, half:-half, half:-half]
                          + self.k ** 2 * (epsillon_z * ez_core))

            diff = torch.cat((diff_x, diff_z), 1)

        return diff, total

    # ---- complex arithmetic on [real, imag] channel pairs -------------------------

    def complex_multiplication(self, a, b):
        """Return the complex product ``a * b`` of two [real, imag] channel tensors."""
        r_p = torch.mul(a[:, 0:1, :, :], b[:, 0:1, :, :]) - \
            torch.mul(a[:, 1:2, :, :], b[:, 1:2, :, :])
        i_p = torch.mul(a[:, 0:1, :, :], b[:, 1:2, :, :]) + \
            torch.mul(a[:, 1:2, :, :], b[:, 0:1, :, :])
        return torch.cat((r_p, i_p), 1)

    def complex_conjugate(self, a):
        """Return the channels ``(-imag, real)`` of ``a``.

        NOTE(review): despite its name this computes ``1j * a``, not the complex
        conjugate ``(real, -imag)``. It is not called in this file; confirm the intended
        semantics with any external caller before changing it.
        """
        return torch.cat((-a[:, 1:2, :, :], a[:, 0:1, :, :]), 1)

    # ---- first derivatives (valid convolution: each call shrinks one spatial axis) ----

    def d_e_x(self, x):
        return F.conv2d(x, self.gradient_e_x, padding=0, groups=2)

    def d_e_x_ho(self, x):
        return F.conv2d(x, self.gradient_e_x_ho, padding=0, groups=2)

    def d_h_x(self, x):
        return F.conv2d(x, self.gradient_h_x, padding=0, groups=2)

    def d_h_x_ho(self, x):
        return F.conv2d(x, self.gradient_h_x_ho, padding=0, groups=2)

    def d_e_z(self, x):
        return F.conv2d(x, self.gradient_e_z, padding=0, groups=2)

    def d_e_z_ho(self, x):
        return F.conv2d(x, self.gradient_e_z_ho, padding=0, groups=2)

    def d_h_z(self, x):
        return F.conv2d(x, self.gradient_h_z, padding=0, groups=2)

    def d_h_z_ho(self, x):
        return F.conv2d(x, self.gradient_h_z_ho, padding=0, groups=2)

    # ---- second derivatives; "_pml" variants scale each step by the inverse stretch ----

    def dd_x(self, x):
        return self.d_h_x(self.d_e_x(x))

    def dd_x_ho(self, x):
        return self.d_h_x_ho(self.d_e_x_ho(x))

    def dd_x_pml(self, x):
        return self.complex_multiplication(self.rx_inverse[:, :, 2:-2, :], self.d_h_x(
            self.complex_multiplication(self.rx_inverse[:, :, 1:-1, :], self.d_e_x(x))))

    def dd_x_ho_pml(self, x):
        return self.complex_multiplication(self.rx_inverse[:, :, 4:-4, :], self.d_h_x_ho(
            self.complex_multiplication(self.rx_inverse[:, :, 2:-2, :], self.d_e_x_ho(x))))

    def dd_z(self, x):
        return self.d_h_z(self.d_e_z(x))

    def dd_z_ho(self, x):
        return self.d_h_z_ho(self.d_e_z_ho(x))

    def dd_z_pml(self, x):
        return self.complex_multiplication(self.rz_inverse[:, :, :, 2:-2], self.d_h_z(
            self.complex_multiplication(self.rz_inverse[:, :, :, 1:-1], self.d_e_z(x))))

    def dd_z_ho_pml(self, x):
        return self.complex_multiplication(self.rz_inverse[:, :, :, 4:-4], self.d_h_z_ho(
            self.complex_multiplication(self.rz_inverse[:, :, :, 2:-2], self.d_e_z_ho(x))))

    def dd_zx(self, x):
        return self.d_h_z(self.d_e_x(x))

    def dd_zx_ho(self, x):
        return self.d_h_z_ho(self.d_e_x_ho(x))

    def dd_zx_pml(self, x):
        return self.complex_multiplication(self.rz_inverse[:, :, 1:-1, 1:-1], self.d_h_z(
            self.complex_multiplication(self.rx_inverse[:, :, 1:-1, :], self.d_e_x(x))))

    def dd_zx_ho_pml(self, x):
        return self.complex_multiplication(self.rz_inverse[:, :, 2:-2, 2:-2], self.d_h_z_ho(
            self.complex_multiplication(self.rx_inverse[:, :, 2:-2, :], self.d_e_x_ho(x))))

    def dd_xz(self, x):
        return self.d_h_x(self.d_e_z(x))

    def dd_xz_ho(self, x):
        return self.d_h_x_ho(self.d_e_z_ho(x))

    def dd_xz_pml(self, x):
        return self.complex_multiplication(self.rx_inverse[:, :, 1:-1, 1:-1], self.d_h_x(
            self.complex_multiplication(self.rz_inverse[:, :, :, 1:-1], self.d_e_z(x))))

    def dd_xz_ho_pml(self, x):
        return self.complex_multiplication(self.rx_inverse[:, :, 2:-2, 2:-2], self.d_h_x_ho(
            self.complex_multiplication(self.rz_inverse[:, :, :, 2:-2], self.d_e_z_ho(x))))
325 |
--------------------------------------------------------------------------------