├── figures ├── scheme.png ├── te_result.png └── tm_result.png ├── examples ├── spheric_te │ ├── sample.npz │ ├── train.npz │ └── specs_maxwell.json └── spheric_tm │ ├── sample.npz │ ├── train.npz │ └── specs_maxwell.json ├── environments.yml ├── Dataset.py ├── UNet.py ├── README.md ├── solution_maxwellnet.py ├── train_maxwellnet.py └── MaxwellNet.py /figures/scheme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/figures/scheme.png -------------------------------------------------------------------------------- /figures/te_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/figures/te_result.png -------------------------------------------------------------------------------- /figures/tm_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/figures/tm_result.png -------------------------------------------------------------------------------- /examples/spheric_te/sample.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/examples/spheric_te/sample.npz -------------------------------------------------------------------------------- /examples/spheric_te/train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/examples/spheric_te/train.npz -------------------------------------------------------------------------------- /examples/spheric_tm/sample.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/examples/spheric_tm/sample.npz 
-------------------------------------------------------------------------------- /examples/spheric_tm/train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limjoowon/maxwellnet/HEAD/examples/spheric_tm/train.npz -------------------------------------------------------------------------------- /environments.yml: -------------------------------------------------------------------------------- 1 | name: maxwellnet 2 | channels: 3 | - defaults 4 | - pytorch 5 | - nvidia 6 | dependencies: 7 | - python=3.7.0 8 | - nvidia::cudatoolkit=11.0 9 | - pytorch::pytorch=1.7.1 10 | - torchvision=0.8.2 11 | - scikit-image=0.19.2 12 | - jupyter=1.0.0 13 | - matplotlib=3.5.1 14 | - tensorboard=2.0.0 15 | - pip: 16 | - plyfile==0.7.4 -------------------------------------------------------------------------------- /Dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Joowon Lim, limjoowon@gmail.com 2 | 3 | import torch 4 | import numpy as np 5 | import os 6 | 7 | 8 | class LensDataset(torch.utils.data.Dataset): 9 | def __init__(self, data_path, mode): 10 | self.dataset = np.load(os.path.join( 11 | data_path, mode + ".npz"))['sample'] 12 | self.n = np.load(os.path.join(data_path, mode + ".npz"))['n'] 13 | 14 | def __len__(self): 15 | return self.dataset.shape[0] 16 | 17 | def __getitem__(self, idx): 18 | sample = self.dataset[idx, :, :, :] 19 | return sample, self.n, idx 20 | -------------------------------------------------------------------------------- /examples/spheric_te/specs_maxwell.json: -------------------------------------------------------------------------------- 1 | { 2 | "Description": [ 3 | "This experiment learns to solve the Maxwell's equations", 4 | "for a spheric lens data." 
5 | ], 6 | "NetworkArch": "maxwellnet", 7 | "NetworkSpecs": { 8 | "depth": 6, 9 | "filter": 16, 10 | "norm": "weight", 11 | "up_mode": "upconv" 12 | }, 13 | "PhysicalSpecs": { 14 | "wavelength": 1, 15 | "dpl": 20, 16 | "Nx": 160, 17 | "Nz": 192, 18 | "pml_thickness": 30, 19 | "symmetry_x": true, 20 | "mode": "te", 21 | "high_order": "fourth" 22 | }, 23 | "Seed": 2, 24 | "LearningRate": 0.0005, 25 | "LearningRateDecay": 0.5, 26 | "LearningRateDecayStep": 50000, 27 | "Epochs": 250000, 28 | "BatchSize": 1, 29 | "SnapshotFrequency": 50000, 30 | "TensorboardFrequency": 100 31 | } 32 | 33 | -------------------------------------------------------------------------------- /examples/spheric_tm/specs_maxwell.json: -------------------------------------------------------------------------------- 1 | { 2 | "Description": [ 3 | "This experiment learns to solve the Maxwell's equations", 4 | "for a spheric lens data." 5 | ], 6 | "NetworkArch": "maxwellnet", 7 | "NetworkSpecs": { 8 | "depth": 6, 9 | "filter": 32, 10 | "norm": "weight", 11 | "up_mode": "upconv" 12 | }, 13 | "PhysicalSpecs": { 14 | "wavelength": 1, 15 | "dpl": 20, 16 | "Nx": 160, 17 | "Nz": 192, 18 | "pml_thickness": 30, 19 | "symmetry_x": true, 20 | "mode": "tm", 21 | "high_order": "fourth" 22 | }, 23 | "Seed": 2, 24 | "LearningRate": 0.0005, 25 | "LearningRateDecay": 0.5, 26 | "LearningRateDecayStep": 50000, 27 | "Epochs": 250000, 28 | "BatchSize": 1, 29 | "SnapshotFrequency": 50000, 30 | "TensorboardFrequency": 100 31 | } 32 | 33 | -------------------------------------------------------------------------------- /UNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted and modified from https://github.com/jvanvugt/pytorch-unet 3 | 4 | Modified parts: 5 | Copyright (c) 2022 Joowon Lim, limjoowon@gmail.com 6 | 7 | Original parts: 8 | MIT License 9 | 10 | Copyright (c) 2018 Joris 11 | 12 | Permission is hereby granted, free of charge, to any person obtaining a copy 
13 | of this software and associated documentation files (the "Software"), to deal 14 | in the Software without restriction, including without limitation the rights 15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | copies of the Software, and to permit persons to whom the Software is 17 | furnished to do so, subject to the following conditions: 18 | 19 | The above copyright notice and this permission notice shall be included in all 20 | copies or substantial portions of the Software. 21 | 22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 28 | SOFTWARE. 
29 | """ 30 | 31 | import torch 32 | from torch import nn 33 | 34 | 35 | class UNet(nn.Module): # U-Net: encoder of UNetConvBlock + AvgPool2d(2) stages, decoder of UNetUpBlock stages joined by skip connections, then a 1x1 conv head. 36 | def __init__(self, in_channels=1, out_channels=2, depth=5, wf=6, norm='weight', up_mode='upconv'): # wf = channels of the first encoder stage (doubled per level); norm and up_mode are forwarded to the blocks. 37 | super(UNet, self).__init__() 38 | assert up_mode in ('upconv', 'upsample') 39 | self.down_path = nn.ModuleList() 40 | self.up_path = nn.ModuleList() 41 | 42 | prev_channels = int(in_channels) 43 | 44 | for i in range(depth): # encoder: every level except the last is a conv block followed by pooling; the last level is the bottleneck 45 | if i != depth - 1: 46 | if i == 0: # NOTE(review): both branches append an identical block, so this split looks redundant 47 | self.down_path.append(UNetConvBlock( 48 | prev_channels, [wf * (2 ** i), wf * (2 ** i)], 3, 0, norm)) 49 | else: 50 | self.down_path.append(UNetConvBlock( 51 | prev_channels, [wf * (2 ** i), wf * (2 ** i)], 3, 0, norm)) 52 | prev_channels = int(wf * (2 ** i)) 53 | self.down_path.append(nn.AvgPool2d(2)) 54 | else: 55 | self.down_path.append(UNetConvBlock( 56 | prev_channels, [wf * (2 ** i), wf * (2 ** (i - 1))], 3, 0, norm)) # bottleneck returns wf * 2**(depth - 2) channels, matching the deepest skip 57 | prev_channels = int(wf * (2 ** (i - 1))) 58 | 59 | for i in reversed(range(depth - 1)): # decoder: one up block per pooled encoder level, deepest first 60 | self.up_path.append( 61 | UNetUpBlock(prev_channels, [wf * (2 ** i), int(wf * (2 ** (i - 1)))], up_mode, 3, 0, norm)) # for i == 0 the output width is int(wf / 2) 62 | prev_channels = int(wf * (2 ** (i - 1))) 63 | 64 | self.last_conv = nn.Conv2d( 65 | prev_channels, out_channels, kernel_size=1, padding=0, bias=False) # 1x1 projection to out_channels 66 | 67 | def forward(self, scat_pot): # assumes scat_pot is (N, in_channels, H, W) with H and W divisible by 2**(depth - 1) -- TODO confirm 68 | blocks = [] 69 | x = scat_pot 70 | for i, down in enumerate(self.down_path): 71 | x = down(x) 72 | if i % 2 == 0 and i != (len(self.down_path) - 1): # even indices are conv blocks (odd ones are pooling); keep every conv output except the bottleneck as a skip 73 | blocks.append(x) 74 | for i, up in enumerate(self.up_path): 75 | x = up(x, blocks[-i - 1]) # pair each up block with the matching skip, deepest first 76 | 77 | return self.last_conv(x) 78 | 79 | 80 | class UNetConvBlock(nn.Module): # two convolutions (kernel size kersize), each preceded by a 1-pixel padding layer, with CELU between them; `norm` picks the variant 81 | def __init__(self, in_size, out_size, kersize, padding, norm): 82 | super(UNetConvBlock, self).__init__() 83 | block = [] 84 | if norm == 'weight': # weight-normalized convs with replication padding 85 | block.append(nn.ReplicationPad2d(1)) 86 | block.append(nn.utils.weight_norm((nn.Conv2d(in_size, out_size[0], kernel_size=int(kersize), 87 | padding=int(0), bias=True)), name='weight')) 88 |
block.append(nn.CELU()) 89 | block.append(nn.ReplicationPad2d(1)) 90 | block.append(nn.utils.weight_norm((nn.Conv2d(out_size[0], out_size[1], kernel_size=int(kersize), 91 | padding=int(0), bias=True)), name='weight')) 92 | elif norm == 'batch': # reflection padding and BatchNorm2d after each conv; unlike the other branches the Conv2d also receives `padding` 93 | block.append(nn.ReflectionPad2d(1)) 94 | block.append(nn.Conv2d(in_size, out_size[0], kernel_size=int(kersize), 95 | padding=int(padding), bias=True)) 96 | block.append(nn.BatchNorm2d(out_size[0])) 97 | block.append(nn.CELU()) 98 | 99 | block.append(nn.ReflectionPad2d(1)) 100 | block.append(nn.Conv2d(out_size[0], out_size[1], kernel_size=int(kersize), 101 | padding=int(padding), bias=True)) 102 | block.append(nn.BatchNorm2d(out_size[1])) 103 | 104 | elif norm == 'no': # plain convs with replication padding 105 | block.append(nn.ReplicationPad2d(1)) 106 | block.append((nn.Conv2d(in_size, out_size[0], kernel_size=int(kersize), 107 | padding=int(0), bias=True))) 108 | block.append(nn.CELU()) 109 | block.append(nn.ReplicationPad2d(1)) 110 | block.append((nn.Conv2d(out_size[0], out_size[1], kernel_size=int(kersize), 111 | padding=int(0), bias=True))) 112 | 113 | self.block = nn.Sequential(*block) # NOTE(review): any other `norm` value leaves `block` empty, so this block acts as an identity 114 | 115 | def forward(self, x): 116 | out = self.block(x) 117 | return out 118 | 119 | 120 | class UNetUpBlock(nn.Module): # upsample x by 2, concatenate the skip tensor along channels, then apply a UNetConvBlock 121 | def __init__(self, in_size, out_size, up_mode, kersize, padding, norm): 122 | super(UNetUpBlock, self).__init__() 123 | block = [] 124 | if up_mode == 'upconv': # learned 2x upsampling via a stride-2 transposed conv 125 | block.append(nn.ConvTranspose2d(in_size, in_size, 126 | kernel_size=2, stride=2, bias=False)) 127 | elif up_mode == 'upsample': # bilinear 2x upsampling followed by a 1x1 conv 128 | block.append(nn.Upsample(mode='bilinear', scale_factor=2)) 129 | block.append(nn.Conv2d(in_size, in_size, 130 | kernel_size=1, bias=False)) 131 | 132 | self.block = nn.Sequential(*block) 133 | self.conv_block = UNetConvBlock( 134 | in_size * 2, out_size, kersize, padding, norm) # assumes `bridge` also carries in_size channels 135 | 136 | def forward(self, x, bridge): 137 | up = self.block(x) 138 | out = torch.cat([up, bridge], 1) # join along dim 1 (channels) 139 | out = self.conv_block(out) 140 | return out 141 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **MaxwellNet** 2 | This repository is the official code implementation of the paper, "MaxwellNet: Physics-driven deep neural network training based on Maxwell’s equations" by [Joowon Lim](https://www.linkedin.com/in/joowon-lim/) and [Demetri Psaltis](https://scholar.google.com/citations?&user=-CVR2h8AAAAJ). You can refer to the following materials for the details of implementation, 3 | - [Main article](https://aip.scitation.org/doi/10.1063/5.0071616 "Main article") 4 | - [Supplementary material](https://aip.scitation.org/doi/suppl/10.1063/5.0071616 "Supplementary material") 5 | 6 | Also, we had an interview on this work, 7 | - [Scilight interview](https://aip.scitation.org/doi/full/10.1063/10.0009285 "Scilight interview") 8 | 9 | ### **Overall scheme and idea** 10 | The novelty of this work is to train a deep neural network, MaxwellNet, which solves Maxwell's equations using physics-driven loss. In other words, we are using the residual of Maxwell's equations as a loss function to train MaxwellNet, therefore, it does not require ground truth solutions to train it. Furthermore, we utilized MaxwellNet in a novel inverse design scheme, and we encourage you to refer to the [main article](https://aip.scitation.org/doi/10.1063/5.0071616 "Main article") for details. 11 |
12 | 13 | ![Scheme](/figures/scheme.png) 14 | 15 |
16 | 17 | 18 | 19 | ## **Installation** 20 | Our code was developed with Windows 10, PyTorch 1.7.1, CUDA 11.0, and Python 3.7. 21 | We recommend using conda for installation. 22 | 23 | ``` 24 | conda env create --file environments.yml 25 | conda activate maxwellnet 26 | ``` 27 | 28 | ## **Run** 29 | 30 | ### **1. MaxwellNet Training** 31 | ``` 32 | python train_maxwellnet.py --directory <directory> 33 | ``` 34 | In `<directory>`, you need to have 'train.npz', which contains the training dataset, and 'specs_maxwell.json', where you specify the training parameters. If 'train.npz' contains more than one sample, a validation set 'valid.npz' is also required. A brief description of the parameters can be found below. We encourage you to read the [supplementary material](https://aip.scitation.org/doi/suppl/10.1063/5.0071616 "supplementary material") to understand the parameters. 35 | 36 | | NetworkSpecs | Description | 37 | | :---: | :--- | 38 | | depth [int] | Depth of UNet. | 39 | | filter [int] | Number of channels in the first layer of UNet. | 40 | | norm [str] | Type of normalization ('weight' for weight normalization, 'batch' for batch normalization, and 'no' for no normalization). | 41 | | up_mode [str] | Upsample mode of UNet (either 'upconv' for transpose convolution or 'upsample' for upsampling). | 42 | 43 | 44 | | PhysicalSpecs | Description | 45 | | :---: | :--- | 46 | | wavelength [float] | Wavelength in [um]. | 47 | | dpl [int] | One pixel size is 'wavelength / dpl' [um]. | 48 | | Nx [int] | Number of pixels along the x-axis. This equals the number of pixels along the x-axis of your scattering sample. | 49 | | Nz [int] | Number of pixels along the z-axis (light propagation direction). This equals the number of pixels along the z-axis of your scattering sample. | 50 | | pml_thickness [int] | Perfectly matched layer (PML) thickness in pixels. 'pml_thickness * wavelength / dpl' is the actual thickness of the PML in micrometers. | 51 | | symmetry_x [bool] | If this is True, MaxwellNet will assume your input scattering sample is symmetric along the x-axis.
For example, given a sample whose Nx and Nz are 100 and 200, respectively, if this sample is symmetric along the x-axis, you can save only half of it (Nx=50, Nz=200) in your train file (train.npz) and set 'symmetry_x' to True. | 52 | | mode [str] | 'te' or 'tm' (Transverse Electric or Transverse Magnetic). | 53 | | high_order [str] | 'second' or 'fourth'. It selects the order (second or fourth) used to calculate the gradient; 'fourth' is more accurate than 'second'. | 54 | 55 | 56 | #### **Examples** 57 | *Training for a single spheric lens.* 58 | 59 | If you just want to train a model for a single lens (a good exercise, as it runs for a short time), you can train MaxwellNet for a single spheric lens as follows: 60 | * TE mode. 61 | ``` 62 | python train_maxwellnet.py --directory examples\spheric_te 63 | ``` 64 | * TM mode. 65 | ``` 66 | python train_maxwellnet.py --directory examples\spheric_tm 67 | ``` 68 | 69 | *Training for multiple lenses.* 70 | 71 | You can download the datasets of multiple lenses [here](https://drive.google.com/drive/folders/1ZXPKntdBQUOyMYvmKM7Ol6woCN2Rsrqj?usp=sharing). Download them and place the 'lens_te' and 'lens_tm' folders under the 'examples' folder. 72 | * Transverse Electric (TE) mode. 73 | ``` 74 | python train_maxwellnet.py --directory examples\lens_te 75 | ``` 76 | * Transverse Magnetic (TM) mode. 77 | ``` 78 | python train_maxwellnet.py --directory examples\lens_tm 79 | ``` 80 | The above training cases take about 37 (TE mode) and 63 (TM mode) hours, respectively, on a V100 GPU. 81 | 82 |
83 | 84 | ### **2. MaxwellNet Solution** 85 | To check the solution found by MaxwellNet, run 86 | 87 | ``` 88 | python solution_maxwellnet.py --directory <directory> --model_filename <model_filename> --sample_filename <sample_filename> 89 | ``` 90 | This feeds the sample (`<sample_filename>` in `<directory>`) to the saved model (`<model_filename>` in `<directory>/model`), computes the solution found by MaxwellNet, and saves the output as an image in `<directory>`, as shown in the examples below. 91 | 92 | #### **Examples** 93 | To calculate the solution found by MaxwellNet for the single spheric lenses trained above: 94 | * TE mode. 95 | ``` 96 | python solution_maxwellnet.py --directory examples\spheric_te --model_filename 250000_te_fourth.pt --sample_filename sample.npz 97 | ``` 98 | * TM mode. 99 | ``` 100 | python solution_maxwellnet.py --directory examples\spheric_tm --model_filename 250000_tm_fourth.pt --sample_filename sample.npz 101 | ``` 102 | 103 | | Mode | Result | 104 | | :---: | :---: | 105 | | TE mode |![TE result](/figures/te_result.png) | 106 | | TM mode |![TM result](/figures/tm_result.png) | 107 | You can find the solutions for the multiple-lens training cases in the same way. 108 | 109 | ## **Citation** 110 | 111 | If you find our work useful in your research, please consider citing our paper: 112 | ``` 113 | @article{lim2022maxwellnet, 114 | title={MaxwellNet: Physics-driven deep neural network training based on Maxwell’s equations}, 115 | author={Lim, Joowon and Psaltis, Demetri}, 116 | journal={APL Photonics}, 117 | volume={7}, 118 | number={1}, 119 | pages={011301}, 120 | year={2022}, 121 | publisher={AIP Publishing LLC} 122 | } 123 | ``` 124 | 125 | ## **Acknowledgments** 126 | We adapted code from the following repository: [UNet](https://github.com/jvanvugt/pytorch-unet). We thank the authors for sharing their code.
-------------------------------------------------------------------------------- /solution_maxwellnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Joowon Lim, limjoowon@gmail.com 2 | 3 | import torch 4 | from MaxwellNet import MaxwellNet 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from mpl_toolkits.axes_grid1 import make_axes_locatable 9 | 10 | import os 11 | import json 12 | import argparse 13 | 14 | 15 | def main(args): 16 | 17 | specs_filename = os.path.join( 18 | os.getcwd(), args.directory, 'specs_maxwell.json') 19 | specs = json.load(open(specs_filename)) 20 | 21 | physical_specs = specs['PhysicalSpecs'] 22 | Nx = physical_specs['Nx'] 23 | Nz = physical_specs['Nz'] 24 | dpl = physical_specs['dpl'] 25 | wavelength = physical_specs['wavelength'] 26 | symmetry_x = physical_specs['symmetry_x'] 27 | mode = physical_specs['mode'] 28 | 29 | delta = wavelength / dpl 30 | Nx = Nx * (symmetry_x + 1) 31 | 32 | scat_pot_np = np.load(os.path.join( 33 | args.directory, args.sample_filename))['sample'] 34 | ri = np.load(os.path.join(args.directory, 35 | args.sample_filename))['n'] 36 | 37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | model_directory = os.path.join(args.directory, 'model') 39 | model_dict = torch.load(os.path.join( 40 | model_directory, args.model_filename)) 41 | model = MaxwellNet(**specs["NetworkSpecs"], ** 42 | specs["PhysicalSpecs"]).to(device) 43 | model.load_state_dict(model_dict['state_dict']) 44 | model.eval() 45 | 46 | scat_pot_torch = torch.from_numpy(np.float32(scat_pot_np)).to(device) 47 | ri_torch = torch.tensor([np.float32(ri)]).to(device) 48 | (diff, total) = model(scat_pot_torch, ri_torch) 49 | 50 | total_np = total.cpu().detach().numpy() 51 | diff_np = diff.cpu().detach().numpy() 52 | total_np = total_np[0, 0::2, :, :] + 1j*total_np[0, 1::2, :, :] 53 | diff_np = diff_np[0, 0::2, :, :] + 1j*diff_np[0, 1::2, :, :] 54 | 
scat_pot_np = scat_pot_np[0, :, :, :] * (ri-1) + 1 55 | 56 | # Here, I use the following min max values to present the data just for visualization. This is approximately correct. 57 | # Please note that the output values from MaxwellNet are defined on Yee grid. 58 | # So, if you want to quantitatively compare the MaxwellNet outputs with solutions from another solver, you should compare two solutions at the same Yee grid points. 59 | x_min = -(Nx//2) * delta 60 | x_max = (Nx//2-1) * delta 61 | z_min = -(Nz//2) * delta 62 | z_max = (Nz//2-1) * delta 63 | fontsize = 20 64 | 65 | if mode == 'te': 66 | if symmetry_x == True: 67 | scat_pot_np = np.pad(np.concatenate( 68 | (np.flip(scat_pot_np[0, 1::, :], 0), scat_pot_np[0, :, :]), 0), ((1, 0), (0, 0))) 69 | total_np = np.pad(np.concatenate( 70 | (np.flip(total_np[0, 1::, :], 0), total_np[0, :, :]), 0), ((1, 0), (0, 0))) 71 | diff_np = np.pad(np.concatenate( 72 | (np.flip(diff_np[0, 1::, :], 0), diff_np[0, :, :]), 0), ((1, 0), (0, 0))) 73 | 74 | fig, axs = plt.subplots(1, 2, figsize=(8, 5)) 75 | fig.suptitle('TE mode - Sherical Lens', fontsize=fontsize) 76 | 77 | img0 = axs[0].imshow(scat_pot_np, extent=[ 78 | z_min, z_max, x_min, x_max], vmin=1, vmax=ri) 79 | axs[0].set_title('RI distribution', fontsize=fontsize) 80 | divider0 = make_axes_locatable(axs[0]) 81 | cax0 = divider0.append_axes("right", size="5%", pad=0.05) 82 | plt.colorbar(img0, cax=cax0) 83 | 84 | img1 = axs[1].imshow(np.abs(total_np), extent=[ 85 | z_min, z_max, x_min, x_max]) 86 | axs[1].set_title('Ey (envelop)', fontsize=fontsize) 87 | divider1 = make_axes_locatable(axs[1]) 88 | cax1 = divider1.append_axes("right", size="5%", pad=0.05) 89 | plt.colorbar(img1, cax=cax1) 90 | 91 | plt.tight_layout() 92 | plt.savefig(os.path.join(os.getcwd(), args.directory, 'te_result.png')) 93 | 94 | elif mode == 'tm': 95 | if symmetry_x == True: 96 | scat_pot_x_np = np.concatenate( 97 | (np.flip(scat_pot_np[0, :, :], 0), scat_pot_np[0, :, :]), 0) 98 | scat_pot_z_np = 
np.pad(np.concatenate( 99 | (np.flip(scat_pot_np[1, 1::, :], 0), scat_pot_np[1, :, :]), 0), ((1, 0), (0, 0))) 100 | total_x_np = np.concatenate( 101 | (np.flip(total_np[0, :, :], 0), total_np[0, :, :]), 0) 102 | total_z_np = np.pad(np.concatenate( 103 | (-np.flip(total_np[1, 1::, :], 0), total_np[1, :, :]), 0), ((1, 0), (0, 0))) 104 | 105 | fig, axs = plt.subplots(2, 2, figsize=(8, 10)) 106 | fig.suptitle('TM mode - Sherical Lens', fontsize=fontsize) 107 | print(scat_pot_x_np.shape) 108 | img00 = axs[0, 0].imshow(scat_pot_x_np, extent=[ 109 | z_min, z_max, x_min, x_max], vmin=1, vmax=ri) 110 | axs[0, 0].set_title('RI distribution', fontsize=fontsize) 111 | divider00 = make_axes_locatable(axs[0, 0]) 112 | cax00 = divider00.append_axes("right", size="5%", pad=0.05) 113 | plt.colorbar(img00, cax=cax00) 114 | 115 | img01 = axs[0, 1].imshow(np.abs(total_x_np), extent=[ 116 | z_min, z_max, x_min, x_max]) 117 | axs[0, 1].set_title('Ex (envelop)', fontsize=fontsize) 118 | divider01 = make_axes_locatable(axs[0, 1]) 119 | cax01 = divider01.append_axes("right", size="5%", pad=0.05) 120 | plt.colorbar(img01, cax=cax01) 121 | 122 | img10 = axs[1, 0].imshow(scat_pot_z_np, extent=[ 123 | z_min, z_max, x_min, x_max], vmin=1, vmax=ri) 124 | axs[1, 0].set_title('RI distribution', fontsize=fontsize) 125 | divider10 = make_axes_locatable(axs[1, 0]) 126 | cax10 = divider10.append_axes("right", size="5%", pad=0.05) 127 | plt.colorbar(img10, cax=cax10) 128 | 129 | img11 = axs[1, 1].imshow(np.abs(total_z_np), extent=[ 130 | z_min, z_max, x_min, x_max]) 131 | axs[1, 1].set_title('Ez (envelop)', fontsize=fontsize) 132 | divider11 = make_axes_locatable(axs[1, 1]) 133 | cax11 = divider11.append_axes("right", size="5%", pad=0.05) 134 | plt.colorbar(img11, cax=cax11) 135 | 136 | plt.tight_layout() 137 | plt.savefig(os.path.join(os.getcwd(), args.directory, 'tm_result.png')) 138 | 139 | else: 140 | raise KeyError("'mode' should me either 'te' or 'tm'.") 141 | 142 | 143 | if __name__ == 
'__main__': 144 | arg_parser = argparse.ArgumentParser(description="Train a MaxwellNet") 145 | arg_parser.add_argument( 146 | "--directory", 147 | "-d", 148 | required=True, 149 | default='examples\spheric_te', 150 | help="This directory should include " 151 | + "'sample.npz' (input to MaxwellNet) and " 152 | + "'model' folder where trained 'MaxwellNet' parameters are saved and " 153 | + "'specs_maxwell.json' used during training." 154 | ) 155 | arg_parser.add_argument( 156 | "--model_filename", 157 | required=True, 158 | help="This filename indicates the saved model file name within 'directory\model\'." 159 | ) 160 | arg_parser.add_argument( 161 | "--sample_filename", 162 | required=True, 163 | help="This filename indicates a .npz file to be provied to MaxwellNet to calculate the solution, and it should be located in 'directory\'." 164 | ) 165 | 166 | args = arg_parser.parse_args() 167 | main(args) 168 | -------------------------------------------------------------------------------- /train_maxwellnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Joowon Lim, limjoowon@gmail.com 2 | 3 | import torch 4 | from Dataset import LensDataset 5 | import torch.backends.cudnn as cudnn 6 | from torch.optim.lr_scheduler import StepLR 7 | from torch.utils.tensorboard import SummaryWriter 8 | from MaxwellNet import MaxwellNet 9 | 10 | import numpy as np 11 | import random 12 | import logging 13 | import argparse 14 | import os 15 | import json 16 | from datetime import datetime 17 | 18 | 19 | def main(directory, load_ckpt): 20 | logging.basicConfig(level=logging.DEBUG, 21 | format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s', 22 | datefmt='%a, %d %b %Y %H:%M:%S', 23 | filename=os.path.join( 24 | os.getcwd(), directory, f"maxwellnet_{datetime.now():%Y-%m-%d %H-%M-%S}.log"), 25 | filemode='w') 26 | 27 | logging.info("training " + directory) 28 | 29 | specs_filename = os.path.join(directory, 
'specs_maxwell.json') 30 | 31 | if not os.path.isfile(specs_filename): 32 | raise Exception( 33 | 'The experiment directory does not include specifications file "specs_maxwell.json"' 34 | ) 35 | 36 | specs = json.load(open(specs_filename)) 37 | 38 | seed_number = get_spec_with_default(specs, "Seed", None) 39 | if seed_number != None: 40 | fix_seed(seed_number, torch.cuda.is_available()) 41 | 42 | rank = 0 43 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 44 | 45 | logging.info("Experiment description: \n" + 46 | ' '.join([str(elem) for elem in specs["Description"]])) 47 | logging.info("Training with " + str(device)) 48 | 49 | model = MaxwellNet(**specs["NetworkSpecs"], **specs["PhysicalSpecs"]) 50 | if torch.cuda.device_count() > 1: 51 | logging.info("Multiple GPUs: " + str(torch.cuda.device_count())) 52 | if load_ckpt is not None: 53 | load_path = os.path.join(os.getcwd(), directory, 'model', load_ckpt) 54 | ckpt_dict = torch.load(load_path + '.pt') 55 | ckpt_epoch = ckpt_dict['epoch'] 56 | logging.info("Checkpoint loaded from {}-epoch".format(ckpt_epoch)) 57 | model.load_state_dict(ckpt_dict['state_dict']) 58 | 59 | model = torch.nn.DataParallel(model) 60 | model.train() 61 | model = model.to(device) 62 | 63 | logging.info("Number of network parameters: {}".format( 64 | sum(p.data.nelement() for p in model.parameters()))) 65 | logging.debug(specs["NetworkSpecs"]) 66 | logging.debug(specs["PhysicalSpecs"]) 67 | 68 | optimizer = torch.optim.Adam(model.parameters(), lr=get_spec_with_default( 69 | specs, "LearningRate", 0.0001), weight_decay=0) 70 | scheduler = StepLR(optimizer, step_size=get_spec_with_default( 71 | specs, "LearningRateDecayStep", 10000), gamma=get_spec_with_default(specs, "LearningRateDecay", 1.0)) 72 | 73 | batch_size = get_spec_with_default(specs, "BatchSize", 1) 74 | epochs = get_spec_with_default(specs, "Epochs", 1) 75 | snapshot_freq = specs["SnapshotFrequency"] 76 | physical_specs = specs["PhysicalSpecs"] 77 | 
symmetry_x = physical_specs['symmetry_x'] 78 | mode = physical_specs['mode'] 79 | high_order = physical_specs['high_order'] 80 | 81 | checkpoints = list(range(snapshot_freq, epochs + 1, snapshot_freq)) 82 | 83 | filename = 'maxwellnet_' + mode + '_' + high_order 84 | writer = SummaryWriter(os.path.join(directory, 'tensorboard_' + filename)) 85 | writer_freq = get_spec_with_default(specs, "TensorboardFrequency", None) 86 | 87 | train_dataset = LensDataset(directory, 'train') 88 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, 89 | shuffle=True, pin_memory=True, sampler=None) 90 | logging.info("Train Dataset length: {}".format(len(train_dataset))) 91 | loss_train = torch.zeros( 92 | (int(epochs),), dtype=torch.float32, requires_grad=False) 93 | 94 | if len(train_dataset) > 1: 95 | perform_valid = True 96 | else: 97 | perform_valid = False 98 | 99 | if perform_valid == True: 100 | valid_dataset = LensDataset(directory, 'valid') 101 | valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, 102 | shuffle=True, pin_memory=True, sampler=None) 103 | logging.info("Valid Dataset length: {}".format(len(valid_dataset))) 104 | loss_valid = torch.zeros( 105 | (int(epochs),), dtype=torch.float32, requires_grad=False) 106 | 107 | if load_ckpt is not None: 108 | optimizer.load_state_dict(ckpt_dict['optimizer']) 109 | scheduler.load_state_dict(ckpt_dict['scheduler']) 110 | loss_train[:ckpt_epoch:] = ckpt_dict['loss_train'][:ckpt_epoch:] 111 | logging.info("Check point loaded from {}-epoch".format(ckpt_epoch)) 112 | 113 | start_epoch = ckpt_epoch 114 | else: 115 | start_epoch = 0 116 | 117 | logging.info("Training start") 118 | 119 | for epoch in range(start_epoch + 1, epochs + 1): 120 | train(train_loader, model, optimizer, epoch, loss_train, 121 | device, mode, symmetry_x, writer, writer_freq) 122 | logging.info("[Train] {} epoch. 
Loss: {:.5f}".format( 123 | epoch, loss_train[epoch-1].item())) if rank == 0 else None 124 | if perform_valid: 125 | valid(valid_loader, model, epoch, loss_valid, 126 | device, mode, symmetry_x, writer, writer_freq) 127 | logging.info("[Valid] {} epoch. Loss: {:.5f}".format( 128 | epoch, loss_valid[epoch-1].item())) if rank == 0 else None 129 | 130 | if epoch in checkpoints: 131 | logging.info("Checkpoint saved at {} epoch.".format( 132 | epoch)) if rank == 0 else None 133 | if rank == 0: 134 | save_checkpoint({ 135 | 'epoch': epoch, 136 | 'state_dict': model.module.state_dict(), 137 | 'optimizer': optimizer.state_dict(), 138 | 'loss_train': loss_train, 139 | 'scheduler': scheduler.state_dict(), 140 | }, directory, str(epoch) + '_' + mode + '_' + high_order) 141 | 142 | if epoch % 200 == 0: 143 | logging.info("'latest' checkpoint saved at {} epoch.".format( 144 | epoch)) if rank == 0 else None 145 | if rank == 0: 146 | save_checkpoint({ 147 | 'epoch': epoch, 148 | 'state_dict': model.module.state_dict(), 149 | 'optimizer': optimizer.state_dict(), 150 | 'loss_train': loss_train, 151 | 'scheduler': scheduler.state_dict(), 152 | }, directory, 'latest') 153 | 154 | scheduler.step() 155 | 156 | writer.close() if rank == 0 else None 157 | 158 | 159 | def train(train_loader, model, optimizer, epoch, loss_train, device, mode, symmetry, writer, writer_freq): 160 | model.train() 161 | with torch.set_grad_enabled(True): 162 | count = 0 163 | 164 | for data in train_loader: 165 | scat_pot_torch = data[0].to(device) 166 | ri_value_torch = data[1].to(device) 167 | 168 | (diff, total) = model(scat_pot_torch, ri_value_torch) 169 | 170 | l2 = diff.pow(2) 171 | loss = torch.mean(l2) 172 | optimizer.zero_grad() 173 | loss.backward() 174 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-3) 175 | optimizer.step() 176 | 177 | loss_train[epoch-1] += loss.item() * diff.size(0) 178 | count += diff.size(0) 179 | 180 | loss_train[epoch-1] = loss_train[epoch-1] / count 181 | 182 | if 
epoch % writer_freq == 0 and writer != None: 183 | to_tensorboard(total.permute(2, 3, 1, 0)[:, :, :, 0].clone().detach().cpu(), loss_train[epoch-1].numpy(), epoch, 184 | mode, symmetry, writer, 'train') 185 | 186 | 187 | def valid(valid_loader, model, epoch, loss_valid, device, mode, symmetry, writer, writer_freq): 188 | model.eval() 189 | with torch.set_grad_enabled(False): 190 | count = 0 191 | 192 | for data in valid_loader: 193 | scat_pot_torch = data[0].to(device) 194 | ri_value_torch = data[1].to(device) 195 | 196 | (diff, total) = model(scat_pot_torch, 197 | ri_value_torch) # [N, 1, H, W, D] 198 | 199 | l2 = diff.pow(2) 200 | loss = torch.mean(l2) 201 | 202 | loss_valid[epoch-1] += loss.item() * diff.size(0) 203 | count += diff.size(0) 204 | 205 | loss_valid[epoch-1] = loss_valid[epoch-1] / count 206 | 207 | if epoch % writer_freq == 0 and writer != None: 208 | to_tensorboard(total.permute(2, 3, 1, 0)[:, :, :, 0].clone().detach().cpu(), loss_valid[epoch-1].numpy(), epoch, 209 | mode, symmetry, writer, 'valid') 210 | 211 | 212 | def to_tensorboard(image, losses, epoch, mode, symmetry, writer, train_valid): 213 | if symmetry is True: 214 | if mode == 'te': 215 | y_pol = torch.cat((torch.flip(image, [0])[0:-1:, :, :], image), 0) 216 | elif mode == 'tm': 217 | x_pol = torch.cat((torch.flip(image, [0]), image), 0) 218 | z_pol = torch.cat((-torch.flip(image, [0])[0:-1, :, :], image), 0) 219 | 220 | if mode == 'te': 221 | polarization = ['y'] 222 | elif mode == 'tm': 223 | polarization = ['x', 'z'] 224 | 225 | for idx in range(len(polarization)): 226 | if symmetry == True: 227 | if polarization[idx] == 'y': 228 | image = y_pol 229 | elif polarization[idx] == 'x': 230 | image = x_pol 231 | elif polarization[idx] == 'z': 232 | image = z_pol 233 | 234 | amplitude = torch.sum( 235 | image[:, :, idx*2:(idx+1)*2].pow(2), 2).pow(1 / 2) 236 | amplitude = amplitude - torch.min(amplitude) 237 | amplitude = amplitude / torch.max(amplitude) 238 | writer.add_image(train_valid + 
'/' + mode + '/amplitude_' + 239 | polarization[idx], amplitude.unsqueeze(0), epoch) 240 | 241 | real = image[:, :, idx*2] 242 | real = real - torch.min(real) 243 | real = real / torch.max(real) 244 | writer.add_image(train_valid + '/' + mode + '/real_' + 245 | polarization[idx], real.unsqueeze(0), epoch) 246 | 247 | imaginary = image[:, :, idx*2+1] 248 | imaginary = imaginary - torch.min(imaginary) 249 | imaginary = imaginary / torch.max(imaginary) 250 | writer.add_image(train_valid + '/' + mode + '/imaginary_' + 251 | polarization[idx], imaginary.unsqueeze(0), epoch) 252 | 253 | writer.add_scalar(train_valid + '/' + mode, losses, epoch) 254 | 255 | 256 | def save_checkpoint(state, directory, filename): 257 | model_directory = os.path.join(directory, 'model') 258 | if os.path.exists(model_directory) == False: 259 | os.makedirs(model_directory) 260 | torch.save(state, os.path.join(model_directory, filename + '.pt')) 261 | 262 | 263 | def fix_seed(seed, is_cuda): 264 | random.seed(seed) 265 | np.random.seed(seed) 266 | torch.manual_seed(seed) 267 | if is_cuda: 268 | torch.cuda.manual_seed(seed) 269 | torch.cuda.manual_seed_all(seed) 270 | cudnn.benchmark = False 271 | cudnn.deterministic = True 272 | 273 | 274 | def get_spec_with_default(specs, key, default): 275 | try: 276 | return specs[key] 277 | except KeyError: 278 | return default 279 | 280 | 281 | if __name__ == '__main__': 282 | arg_parser = argparse.ArgumentParser(description="Train a MaxwellNet") 283 | arg_parser.add_argument( 284 | "--directory", 285 | "-d", 286 | required=True, 287 | default='examples\spheric_te', 288 | help="This directory should include " 289 | + "all the training and network parameters in 'specs_maxwell.json', and logging will be " 290 | + "done in this directory as well.", 291 | ) 292 | arg_parser.add_argument( 293 | "--load_ckpt", 294 | "-l", 295 | default=None, 296 | help="This should specify a filename of your checkpoint within 'directory'\model if you want to continue your 
training from the checkpoint.", 297 | ) 298 | 299 | args = arg_parser.parse_args() 300 | main(args.directory, args.load_ckpt) 301 | -------------------------------------------------------------------------------- /MaxwellNet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Joowon Lim, limjoowon@gmail.com 2 | 3 | import torch 4 | from torch import nn 5 | from UNet import UNet 6 | import torch.nn.functional as F 7 | import math 8 | import numpy as np 9 | 10 | 11 | class MaxwellNet(nn.Module): 12 | def __init__(self, depth=6, filter=16, norm='weight', up_mode='upconv', 13 | wavelength=1, dpl=20, Nx=256, Nz=256, pml_thickness=16, symmetry_x=False, mode='te', high_order='fourth'): 14 | 15 | super(MaxwellNet, self).__init__() 16 | self.mode = mode 17 | if mode == 'te': 18 | in_channels = 1 19 | out_channels = 2 20 | elif mode == 'tm': 21 | in_channels = 2 22 | out_channels = 4 23 | 24 | self.high_order = high_order 25 | self.model = UNet(in_channels, out_channels, 26 | depth, filter, norm, up_mode) 27 | 28 | # pixel size [um / pixel] 29 | delta = wavelength / dpl 30 | # wave-number [1 / um] 31 | k = 2 * math.pi / wavelength 32 | self.register_buffer('delta', torch.tensor( 33 | delta, dtype=torch.float32, requires_grad=False)) 34 | self.register_buffer('k', torch.tensor( 35 | k, dtype=torch.float32, requires_grad=False)) 36 | 37 | self.symmetry_x = symmetry_x 38 | 39 | if self.high_order == 'second': 40 | pad = 2 41 | self.pad = pad 42 | elif self.high_order == 'fourth': 43 | pad = 4 44 | self.pad = pad 45 | 46 | self.padding_ref = nn.Sequential(nn.ReflectionPad2d( 47 | (0, 0, pad, 0)), nn.ZeroPad2d((pad, pad, 0, pad))) 48 | self.padding_zero = nn.Sequential(nn.ZeroPad2d((pad, pad, pad, pad))) 49 | 50 | if symmetry_x == True: 51 | x = np.linspace(-pad, Nx + pad - 1, Nx + 2 * pad) * delta 52 | else: 53 | x = np.linspace(-Nx // 2 - pad, Nx // 2 + 54 | pad - 1, Nx + 2 * pad) * delta 55 | z = np.linspace(-Nz // 2 - 
pad, Nz // 2 + 56 | pad - 1, Nz + 2 * pad) * delta 57 | 58 | # Coordinate set-up 59 | zz, xx = np.meshgrid(z, x) 60 | self.Nx = zz.shape[0] 61 | self.Nz = zz.shape[1] 62 | 63 | # incident electric and magnetic fields definition on the Yee grid 64 | fast = np.exp(1j * (k * zz)) 65 | fast_z = np.exp(1j * (k * (zz + delta / 2))) 66 | 67 | self.register_buffer('fast', torch.zeros((1, 2, fast.shape[0], fast.shape[1]), dtype=torch.float32, 68 | requires_grad=False)) 69 | self.fast[0, 0, :, :] = torch.from_numpy(np.real(fast)) 70 | self.fast[0, 1, :, :] = torch.from_numpy(np.imag(fast)) 71 | 72 | self.register_buffer('fast_z', torch.zeros((1, 2, fast_z.shape[0], fast_z.shape[1]), dtype=torch.float32, 73 | requires_grad=False)) 74 | self.fast_z[0, 0, :, :] = torch.from_numpy(np.real(fast_z)) 75 | self.fast_z[0, 1, :, :] = torch.from_numpy(np.imag(fast_z)) 76 | 77 | # perfectly-matched-layer set up 78 | m = 4 79 | const = 5 80 | rx_p = 1 + 1j * const * (xx - x[-1] + pml_thickness * delta) ** m 81 | rx_p[0:-pml_thickness, :] = 0 82 | rx_n = 1 + 1j * const * (xx - x[0] - pml_thickness * delta) ** m 83 | rx_n[pml_thickness::, :] = 0 84 | rx = rx_p + rx_n 85 | if symmetry_x == True: 86 | rx[0:-pml_thickness:, :] = 1 87 | else: 88 | rx[pml_thickness:-pml_thickness, :] = 1 89 | 90 | rz_p = 1 + 1j * const * (zz - z[-1] + pml_thickness * delta) ** m 91 | rz_p[:, 0:-pml_thickness] = 0 92 | rz_n = 1 + 1j * const * (zz - z[0] - pml_thickness * delta) ** m 93 | rz_n[:, pml_thickness::] = 0 94 | rz = rz_p + rz_n 95 | rz[:, pml_thickness:-pml_thickness] = 1 96 | 97 | rx_inverse = 1 / rx 98 | rz_inverse = 1 / rz 99 | 100 | self.register_buffer('rx_inverse', torch.zeros((1, 2, rx_inverse.shape[0], rx_inverse.shape[1]), dtype=torch.float32, 101 | requires_grad=False)) 102 | self.rx_inverse[0, 0, :, :] = torch.from_numpy(np.real(rx_inverse)) 103 | self.rx_inverse[0, 1, :, :] = torch.from_numpy(np.imag(rx_inverse)) 104 | 105 | self.register_buffer('rz_inverse', torch.zeros((1, 2, 
rz_inverse.shape[0], rz_inverse.shape[1]), dtype=torch.float32, 106 | requires_grad=False)) 107 | self.rz_inverse[0, 0, :, :] = torch.from_numpy(np.real(rz_inverse)) 108 | self.rz_inverse[0, 1, :, :] = torch.from_numpy(np.imag(rz_inverse)) 109 | 110 | # Gradient and laplacian kernels set up 111 | self.register_buffer('gradient_h_z', torch.zeros( 112 | (2, 1, 1, 3), dtype=torch.float32, requires_grad=False)) 113 | self.gradient_h_z[:, :, 0, :] = torch.tensor( 114 | [-1 / delta, +1 / delta, 0]) 115 | self.register_buffer('gradient_h_x', torch.zeros( 116 | (2, 1, 3, 1), dtype=torch.float32, requires_grad=False)) 117 | self.gradient_h_x = self.gradient_h_z.permute(0, 1, 3, 2) 118 | self.register_buffer('gradient_h_z_ho', torch.zeros( 119 | (2, 1, 1, 5), dtype=torch.float32, requires_grad=False)) 120 | self.gradient_h_z_ho[:, :, 0, :] = torch.tensor( 121 | [1 / 24 / delta, -9 / 8 / delta, +9 / 8 / delta, -1 / 24 / delta, 0]) 122 | self.register_buffer('gradient_h_x_ho', torch.zeros( 123 | (2, 1, 5, 1), dtype=torch.float32, requires_grad=False)) 124 | self.gradient_h_x_ho = self.gradient_h_z_ho.permute(0, 1, 3, 2) 125 | 126 | self.register_buffer('gradient_e_z', torch.zeros( 127 | (2, 1, 1, 3), dtype=torch.float32, requires_grad=False)) 128 | self.gradient_e_z[:, :, 0, :] = torch.tensor( 129 | [0, -1 / delta, +1 / delta]) 130 | self.register_buffer('gradient_e_x', torch.zeros( 131 | (2, 1, 3, 1), dtype=torch.float32, requires_grad=False)) 132 | self.gradient_e_x = self.gradient_e_z.permute(0, 1, 3, 2) 133 | self.register_buffer('gradient_e_z_ho', torch.zeros( 134 | (2, 1, 1, 5), dtype=torch.float32, requires_grad=False)) 135 | self.gradient_e_z_ho[:, :, 0, :] = torch.tensor( 136 | [0, 1 / 24 / delta, -9 / 8 / delta, +9 / 8 / delta, -1 / 24 / delta]) 137 | self.register_buffer('gradient_e_x_ho', torch.zeros( 138 | (2, 1, 5, 1), dtype=torch.float32, requires_grad=False)) 139 | self.gradient_e_x_ho = self.gradient_e_z_ho.permute(0, 1, 3, 2) 140 | 141 | 
self.register_buffer('dd_z_fast', torch.zeros( 142 | (1, 2, Nx, Nz), dtype=torch.float32, requires_grad=False)) 143 | self.dd_z_fast = self.dd_z(self.fast)[:, :, self.pad:-self.pad:, :] 144 | self.register_buffer('dd_z_ho_fast', torch.zeros( 145 | (1, 2, Nx, Nz), dtype=torch.float32, requires_grad=False)) 146 | self.dd_z_ho_fast = self.dd_z_ho( 147 | self.fast)[:, :, self.pad:-self.pad:, :] 148 | 149 | def forward(self, scat_pot, ri_value): 150 | if self.mode == 'te': 151 | epsillon = scat_pot * \ 152 | (ri_value ** 2).unsqueeze(1).unsqueeze(2).unsqueeze(3) 153 | epsillon = torch.where(epsillon > 1.0, epsillon, torch.tensor( 154 | [1], dtype=torch.float32).to(ri_value.device)) 155 | 156 | x = self.model(scat_pot) 157 | total = torch.cat((x[:, 0:1, :, :] + 1, x[:, 1:2, :, :]), 1) 158 | 159 | ey = self.complex_multiplication(total[:, 0:2, :, :], 160 | self.fast[:, :, self.pad:-self.pad:, self.pad:-self.pad:]) 161 | ey_i = self.fast 162 | ey_s = ey - ey_i[:, :, self.pad:-self.pad:, self.pad:-self.pad:] 163 | 164 | if self.symmetry_x == True: 165 | ey_s = self.padding_ref(ey_s) 166 | else: 167 | ey_s = self.padding_zero(ey_s) 168 | 169 | if self.high_order == 'second': 170 | diff = self.dd_x_pml(ey_s)[:, :, :, self.pad:-self.pad] \ 171 | + self.dd_z_pml(ey_s)[:, :, self.pad:-self.pad, :] \ 172 | + self.dd_z_fast \ 173 | + self.k ** 2 * (epsillon * ey) 174 | 175 | elif self.high_order == 'fourth': 176 | diff = self.dd_x_ho_pml(ey_s)[:, :, :, self.pad:-self.pad] \ 177 | + self.dd_z_ho_pml(ey_s)[:, :, self.pad:-self.pad, :] \ 178 | + self.dd_z_ho_fast \ 179 | + self.k ** 2 * (epsillon * ey) 180 | 181 | elif self.mode == 'tm': 182 | epsillon = scat_pot * \ 183 | (ri_value ** 2).unsqueeze(1).unsqueeze(2).unsqueeze(3) 184 | epsillon_x = torch.where(epsillon[:, 0:1, :, :] > 1.0, epsillon[:, 0:1, :, :], 185 | torch.tensor([1], dtype=torch.float32).to(ri_value.device)) 186 | epsillon_z = torch.where(epsillon[:, 1:2, :, :] > 1.0, epsillon[:, 1:2, :, :], 187 | torch.tensor([1], 
dtype=torch.float32).to(ri_value.device)) 188 | 189 | x = self.model(scat_pot) 190 | total = torch.cat((x[:, 0:1, :, :] + 1, x[:, 1:4, :, :]), 1) 191 | 192 | ex = self.complex_multiplication( 193 | total[:, 0:2, :, :], self.fast[:, :, self.pad:-self.pad:, self.pad:-self.pad:]) 194 | ex_i = self.fast 195 | ex_s = ex - ex_i[:, :, self.pad:-self.pad:, self.pad:-self.pad:] 196 | 197 | ez_s = self.complex_multiplication( 198 | total[:, 2:4, :, :], self.fast_z[:, :, self.pad:-self.pad:, self.pad:-self.pad:]) 199 | 200 | if self.symmetry_x == True: 201 | ex_s = self.padding_zero(ex_s) 202 | ez_s = self.padding_ref(ez_s) 203 | ex_s[:, :, 0:self.pad, :] = torch.flip( 204 | ex_s[:, :, self.pad:2 * self.pad, :], [2]) 205 | ez_s[:, :, 0:self.pad, :] = -ez_s[:, :, 0:self.pad, :] 206 | else: 207 | ex_s = self.padding_zero(ex_s) 208 | ez_s = self.padding_zero(ez_s) 209 | 210 | if self.high_order == 'second': 211 | diff_x = self.dd_z_pml(ex_s)[:, :, self.pad:-self.pad:, :] \ 212 | + self.dd_z_fast \ 213 | - self.dd_zx(ez_s)[:, :, self.pad // 2:-self.pad // 2:, self.pad // 2:-self.pad // 2] \ 214 | + self.k ** 2 * (epsillon_x * ex) \ 215 | 216 | diff_z = self.dd_x_pml(ez_s)[:, :, :, self.pad:-self.pad] \ 217 | - self.dd_xz(ex_s)[:, :, self.pad // 2:-self.pad // 2:, self.pad // 2:-self.pad // 2] \ 218 | + self.k ** 2 * (epsillon_z * ez_s) \ 219 | 220 | elif self.high_order == 'fourth': 221 | diff_x = self.dd_z_ho_pml(ex_s)[:, :, self.pad:-self.pad:, :] \ 222 | + self.dd_z_ho_fast \ 223 | - self.dd_zx_ho_pml(ez_s)[:, :, self.pad//2:-self.pad//2:, self.pad//2:-self.pad//2] \ 224 | + self.k ** 2 * (epsillon_x * ex) 225 | 226 | diff_z = self.dd_x_ho_pml(ez_s)[:, :, :, self.pad:-self.pad] \ 227 | - self.dd_xz_ho_pml(ex_s)[:, :, self.pad//2:-self.pad//2:, self.pad//2:-self.pad//2] \ 228 | + self.k ** 2 * \ 229 | (epsillon_z * ez_s[:, :, self.pad:- 230 | self.pad:, self.pad:-self.pad:]) 231 | 232 | diff = torch.cat((diff_x, diff_z), 1) 233 | 234 | return diff, total 235 | 236 | def 
complex_multiplication(self, a, b): 237 | r_p = torch.mul(a[:, 0:1, :, :], b[:, 0:1, :, :]) - \ 238 | torch.mul(a[:, 1:2, :, :], b[:, 1:2, :, :]) 239 | i_p = torch.mul(a[:, 0:1, :, :], b[:, 1:2, :, :]) + \ 240 | torch.mul(a[:, 1:2, :, :], b[:, 0:1, :, :]) 241 | return torch.cat((r_p, i_p), 1) 242 | 243 | def complex_conjugate(self, a): 244 | return torch.cat((-a[:, 1:2, :, :], a[:, 0:1, :, :]), 1) 245 | 246 | def d_e_x(self, x): 247 | return F.conv2d(x, self.gradient_e_x, padding=0, groups=2) 248 | 249 | def d_e_x_ho(self, x): 250 | return F.conv2d(x, self.gradient_e_x_ho, padding=0, groups=2) 251 | 252 | def d_h_x(self, x): 253 | return F.conv2d(x, self.gradient_h_x, padding=0, groups=2) 254 | 255 | def d_h_x_ho(self, x): 256 | return F.conv2d(x, self.gradient_h_x_ho, padding=0, groups=2) 257 | 258 | def d_e_z(self, x): 259 | return F.conv2d(x, self.gradient_e_z, padding=0, groups=2) 260 | 261 | def d_e_z_ho(self, x): 262 | return F.conv2d(x, self.gradient_e_z_ho, padding=0, groups=2) 263 | 264 | def d_h_z(self, x): 265 | return F.conv2d(x, self.gradient_h_z, padding=0, groups=2) 266 | 267 | def d_h_z_ho(self, x): 268 | return F.conv2d(x, self.gradient_h_z_ho, padding=0, groups=2) 269 | 270 | def dd_x(self, x): 271 | return self.d_h_x(self.d_e_x(x)) 272 | 273 | def dd_x_ho(self, x): 274 | return self.d_h_x_ho(self.d_e_x_ho(x)) 275 | 276 | def dd_x_pml(self, x): 277 | return self.complex_multiplication(self.rx_inverse[:, :, 2:-2, :], self.d_h_x( 278 | self.complex_multiplication(self.rx_inverse[:, :, 1:-1, :], self.d_e_x(x)))) 279 | 280 | def dd_x_ho_pml(self, x): 281 | return self.complex_multiplication(self.rx_inverse[:, :, 4:-4, :], self.d_h_x_ho( 282 | self.complex_multiplication(self.rx_inverse[:, :, 2:-2, :], self.d_e_x_ho(x)))) 283 | 284 | def dd_z(self, x): 285 | return self.d_h_z(self.d_e_z(x)) 286 | 287 | def dd_z_ho(self, x): 288 | return self.d_h_z_ho(self.d_e_z_ho(x)) 289 | 290 | def dd_z_pml(self, x): 291 | return 
self.complex_multiplication(self.rz_inverse[:, :, :, 2:-2], self.d_h_z( 292 | self.complex_multiplication(self.rz_inverse[:, :, :, 1:-1], self.d_e_z(x)))) 293 | 294 | def dd_z_ho_pml(self, x): 295 | return self.complex_multiplication(self.rz_inverse[:, :, :, 4:-4], self.d_h_z_ho( 296 | self.complex_multiplication(self.rz_inverse[:, :, :, 2:-2], self.d_e_z_ho(x)))) 297 | 298 | def dd_zx(self, x): 299 | return self.d_h_z(self.d_e_x(x)) 300 | 301 | def dd_zx_ho(self, x): 302 | return self.d_h_z_ho(self.d_e_x_ho(x)) 303 | 304 | def dd_zx_pml(self, x): 305 | return self.complex_multiplication(self.rz_inverse[:, :, 1:-1, 1:-1], self.d_h_z( 306 | self.complex_multiplication(self.rx_inverse[:, :, 1:-1, :], self.d_e_x(x)))) 307 | 308 | def dd_zx_ho_pml(self, x): 309 | return self.complex_multiplication(self.rz_inverse[:, :, 2:-2, 2:-2], self.d_h_z_ho( 310 | self.complex_multiplication(self.rx_inverse[:, :, 2:-2, :], self.d_e_x_ho(x)))) 311 | 312 | def dd_xz(self, x): 313 | return self.d_h_x(self.d_e_z(x)) 314 | 315 | def dd_xz_ho(self, x): 316 | return self.d_h_x_ho(self.d_e_z_ho(x)) 317 | 318 | def dd_xz_pml(self, x): 319 | return self.complex_multiplication(self.rx_inverse[:, :, 1:-1, 1:-1], self.d_h_x( 320 | self.complex_multiplication(self.rz_inverse[:, :, :, 1:-1], self.d_e_z(x)))) 321 | 322 | def dd_xz_ho_pml(self, x): 323 | return self.complex_multiplication(self.rx_inverse[:, :, 2:-2, 2:-2], self.d_h_x_ho( 324 | self.complex_multiplication(self.rz_inverse[:, :, :, 2:-2], self.d_e_z_ho(x)))) 325 | --------------------------------------------------------------------------------