├── config.yml ├── requirements.txt ├── parameters.py ├── dataset_generation.py ├── train.py ├── README.md ├── DLPU.py └── main.py /config.yml: -------------------------------------------------------------------------------- 1 | name: 'Super_model' 2 | batch_size: 15 3 | total_epochs: 100 4 | lr: 1e-3 5 | loss_freq: 1 6 | metric_freq: 10 7 | lr_freq: 40 8 | save_freq: 10 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | GPUtil==1.4.0 2 | torch==1.8.1 3 | numpy==1.19.5 4 | matplotlib==3.2.2 5 | opencv-python==4.1.2.30 6 | scipy==1.6.3 7 | scikit-learn==0.22.2.post1 8 | hydra-core==1.0 9 | omegaconf 10 | prettytable==2.1.0 -------------------------------------------------------------------------------- /parameters.py: -------------------------------------------------------------------------------- 1 | from prettytable import PrettyTable 2 | from DLPU import * 3 | 4 | 5 | def count_parameters(model): 6 | table = PrettyTable(["Modules", "Parameters"]) 7 | total_params = 0 8 | for name, parameter in model.named_parameters(): 9 | if not parameter.requires_grad: continue 10 | param = parameter.numel() 11 | table.add_row([name, param]) 12 | total_params += param 13 | print(table) 14 | print(f"Total Trainable Params: {total_params}") 15 | return total_params 16 | 17 | 18 | if __name__ == "__main__": 19 | model = DLPU() 20 | count_parameters(model) 21 | -------------------------------------------------------------------------------- /dataset_generation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import cv2 4 | import torch 5 | 6 | 7 | def create_dataset_element(base_size, end_size, magnitude_min, magnituge_max): 8 | array = np.random.rand(base_size, base_size) 9 | coef = np.random.permutation(np.arange(magnitude_min, magnituge_max, 
0.1))[0] 10 | element = cv2.resize(array, dsize=(end_size, end_size), interpolation=cv2.INTER_CUBIC) 11 | element = element * coef 12 | if np.min(element) >= 0: 13 | min_value = np.min(element) 14 | element = element - min_value 15 | else: 16 | min_value = np.min(element) 17 | element = element + abs(min_value) 18 | return element 19 | 20 | 21 | def make_gaussian(number_of_gaussians, sigma_min, sigma_max, shift_max, magnitude_min, magnitude_max): 22 | element = np.zeros([256, 256]) 23 | x = np.arange(-3.14, 3.14, 0.0246) 24 | y = np.arange(-3.14, 3.14, 0.0246) 25 | xx, yy = np.meshgrid(x, y); 26 | 27 | for i in range(number_of_gaussians): 28 | sigma = np.random.permutation(np.arange(sigma_min, sigma_max, .5))[0] 29 | shift_x = np.random.permutation(np.arange(-shift_max, shift_max, .1))[0] 30 | shift_y = np.random.permutation(np.arange(-shift_max, shift_max, .1))[0] 31 | magnitude = np.random.permutation(np.arange(magnitude_min, magnitude_max, .5))[0] 32 | 33 | d = np.sqrt((xx - shift_x) ** 2 + (yy - shift_y) ** 2) 34 | element += np.exp(-((d) ** 2 / (2.0 * sigma ** 2))) * (1 / sigma * np.sqrt(6.28)) 35 | 36 | element = element / np.max(element) 37 | element = element * magnitude 38 | 39 | if np.min(element) >= 0: 40 | min_value = np.min(element) 41 | element = element - min_value 42 | else: 43 | min_value = np.min(element) 44 | element = element + abs(min_value) 45 | 46 | return element 47 | 48 | 49 | def wraptopi(input_image): 50 | pi = 3.1415926535897932384626433 51 | output = input_image - 2 * pi * np.floor((input_image + pi) / (2 * pi)) 52 | return output 53 | 54 | 55 | if __name__ == "__main__": 56 | n = 5 57 | dataset = np.empty([n, 256, 256]) 58 | for i in range(n): 59 | if i % 2 == 0: 60 | size = np.random.permutation(np.arange(2, 15, 1))[0] 61 | dataset[i] = create_dataset_element(size, 256, 4, 20) 62 | else: 63 | num_gauss = np.random.permutation(np.arange(1, 7, 1))[0] 64 | dataset[i] = make_gaussian( 65 | num_gauss, 66 | sigma_min=1, 67 | sigma_max=4, 68 
| shift_max=4, 69 | magnitude_min=2, 70 | magnitude_max=20) 71 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | random.seed(0) 8 | np.random.seed(0) 9 | torch.manual_seed(0) 10 | torch.cuda.manual_seed(0) 11 | torch.backends.cudnn.deterministic = True # работает медленнее, но зато воспроизводимость! 12 | 13 | 14 | def au_and_bem_torch(nn_output, ground_truth, calc_bem: bool): 15 | """ 16 | difference from "au_and_bem' is converting to np.ndarray and abs() 17 | 18 | calculates Binary Error Map (BEM) and Accuracy of Unwrapping (AU) 19 | for batch [batch_images,0,width,heidth] and returns mean AU of a batch 20 | with list of AU for every image and may be with BEM (optionally) 21 | 22 | function returns: 23 | au_mean - float, mean AU for batch 24 | au_list - list, info about AU for every image in batch 25 | bem - 3d boolean tensor, shows BEM in format [images_in_batch,width,height] 26 | 27 | args: 28 | nn_output - ndarray or torch.tensor - tensor that goes forward the net 29 | ground_truth - ndarray or tensor - ground truth image (original phase) 30 | calc_bem - boolean, if needed, will calculate BEM 31 | 32 | 33 | with input as np.ndarray runs 10 times faster 34 | """ 35 | nn_output = nn_output.numpy() 36 | ground_truth = ground_truth.numpy() 37 | 38 | au_list = [] 39 | bem = np.empty([ 40 | len(nn_output[:, 0, 0, 0]), 41 | len(nn_output[0, 0, :, 0]), 42 | len(nn_output[0, 0, 0, :]) 43 | ]) 44 | 45 | for k in range(len(nn_output[:, 0, 0, 0])): 46 | min_height = 0 47 | cnt = 0 48 | for i in range(len(nn_output[0, 0, :, 0])): 49 | for j in range(len(nn_output[0, 0, 0, :])): 50 | x = abs(nn_output[k, 0, i, j] - ground_truth[k, 0, i, j]) 51 | 52 | if calc_bem: 53 | 54 | if x <= (ground_truth[k, 0, i, j] - min_height) * 
0.05: 55 | bem[k, i, j] = 1 56 | cnt += 1 57 | else: 58 | bem[k, i, j] = 0 59 | 60 | else: 61 | if x <= (ground_truth[k, 0, i, j] - min_height) * 0.05: 62 | cnt += 1 63 | 64 | au = cnt / (len(nn_output[0, 0, :, 0]) * len(nn_output[0, 0, 0, :])) 65 | # print(k,'au:',au) 66 | au_list.append(au) 67 | 68 | au_mean = sum(au_list) / len(au_list) 69 | 70 | if calc_bem: 71 | return au_mean, au_list, bem 72 | else: 73 | return au_mean, au_list 74 | 75 | 76 | def update_lr(optimizer, lr): 77 | for param_group in optimizer.param_groups: 78 | param_group['lr'] = lr 79 | 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DLPU 2 | PyTorch model DLPU for phase unwrapping 3 | 4 | This is a PyTorch realisation of deep convolutional Unet-like network, described in arcticle [1]. 5 | 6 | Original network was designed in TensorFlow framework, and this is the PyTorch version of it. 7 | 8 | # Usage 9 | 10 | `` 11 | 12 | # Changes 13 | I've added following moments to the structure: 14 | 15 | 1. Replication padding mode in conv3x3 blocks, because experiments have shown that it's important at the edges of phase maps, 16 | otherwise unwrapping quality will be low 17 | 2. In article there are some unclear moments: neural net structure contains of "five repeated uses of two 3×3 convolution operations (each followed by a BN and a ReLU), a residual block between the two convolution operations,..." 18 | So I made residual connections only for contracting path 19 | So, according to the article it should be CONV3x3->BN->ReLU -> Residual Block(???) -> CONV3x3->BN->ReLU and it's not clear. In contracting path (down) it's possible to make "good" residual connection, as shown below 20 | 21 | 22 | 23 | But autors write, that in expansive path (up) there is similar structure CONV3x3->BN->ReLU -> Residual Block(???) 
-> CONV3x3->BN->ReLU, and it's impossible to use the residual connection shown below (figure from the article): the first CONV3x3 halves the number of channels and the second CONV3x3 halves it again, so a residual connection like the one in the contracting path makes no sense here (and is not even possible, because the channel counts don't match).


But I've tried to make the following residual connection.




# Dataset
The dataset was generated synthetically following articles [1,2].

The data was generated using two methods (in equal proportions):

1. Bicubic interpolation of square matrices (with uniformly distributed elements) of different sizes (2x2 to 14x14) to 256x256, multiplied by a random value, so the magnitude is between 0 and 22 rad
2. Randomly generated Gaussians on a 256x256 field with a random number of functions, means, and STDs, multiplied by a random value, so the magnitude is between 2 and 20 rad

![Example1](https://user-images.githubusercontent.com/73649419/116145971-9fe1db00-a6e6-11eb-9ff3-7afc4982f8a3.png)
![Example2](https://user-images.githubusercontent.com/73649419/116145975-a1130800-a6e6-11eb-8b57-5cbf2e168ac9.png)

# Model
The model is shown below.
In the original paper there is an unclear point: "residual block (see Ref.
45 | 20 for details) between the two convolution operations" 46 | 47 | 48 | 49 | # Metrics 50 | I've implemented BEM (Binary Error Map), described in [3] with threshold 5%, according to formula 51 | 52 | ![render](https://user-images.githubusercontent.com/73649419/116073854-a5650400-a699-11eb-9dbd-30510f355bb6.png) 53 | 54 | # Training info 55 | In original paper authors describe train hyperparameters as follows: 56 | 57 | loss: pixelwise MSE 58 | 59 | optimizer: Adam 60 | 61 | learning rate: 10e-3 62 | 63 | My training: 64 | Training with MSE converges ~10x times faster than with MAE 65 | SGD with momentum=0.9 converges ~10x times faster than with adam (in both variations learning rate was 10e-4) 66 | 67 | 68 | (!) Succeed train to zero cost (0.025) SGR m=0.9, lr=0.0001 (really slow) 69 | 70 | # Parameters counting 71 | 72 | 73 | Total Trainable Params: 1824937 74 | 75 | # Todo's 76 | 77 | 1. code refactoring 78 | 79 | # References 80 | 1. K. Wang, Y. Li, K. Qian, J. Di, and J. Zhao, “One-step robust deep 81 | learning phase unwrapping,” Opt. Express 27, 15100–15115 (2019). 82 | 2. Spoorthi, G. E. et al. “PhaseNet 2.0: Phase Unwrapping of Noisy Data Based on Deep Learning Approach.” IEEE Transactions on Image Processing 29 (2020): 4862-4872. 83 | 3. Qin, Y., Wan, S., Wan, Y., Weng, J., Liu, W., & Gong, Q. (2020). Direct and accurate phase unwrapping with deep neural network. Applied optics, 59 24, 7258-7267 . 84 | -------------------------------------------------------------------------------- /DLPU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | random.seed(0) 8 | np.random.seed(0) 9 | torch.manual_seed(0) 10 | torch.cuda.manual_seed(0) 11 | torch.backends.cudnn.deterministic = True # работает медленнее, но зато воспроизводимость! 
12 | 13 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 14 | 15 | 16 | def conv3x3(in_channels, out_channels, stride=1): 17 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, 18 | stride=stride, padding=1, bias=False, padding_mode='replicate') 19 | 20 | 21 | class ResidualBlock(nn.Module): 22 | def __init__(self, in_channels, out_channels, stride=1, downsample=None): 23 | super(ResidualBlock, self).__init__() 24 | 25 | self.conv1 = conv3x3(in_channels, out_channels, stride) 26 | self.bn1 = nn.BatchNorm2d(out_channels) 27 | self.relu = nn.ReLU(inplace=False) 28 | self.conv2 = conv3x3(out_channels, out_channels) 29 | self.bn2 = nn.BatchNorm2d(out_channels) 30 | 31 | # self.downsample = downsample 32 | 33 | def forward(self, x): 34 | out = self.conv1(x) 35 | residual = out 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | out = out + residual 41 | out = self.relu(out) 42 | return out 43 | 44 | 45 | class ResidualBlockUp(nn.Module): 46 | 47 | def __init__(self, in_channels, out_channels, stride=1): 48 | super(ResidualBlockUp, self).__init__() 49 | 50 | self.conv1 = conv3x3(in_channels, 2 * out_channels, stride) 51 | self.bn1 = nn.BatchNorm2d(2 * out_channels) 52 | self.relu = nn.ReLU(inplace=False) 53 | self.conv2 = conv3x3(2 * out_channels, out_channels) 54 | self.bn2 = nn.BatchNorm2d(out_channels) 55 | 56 | def forward(self, x): 57 | out = self.conv1(x) 58 | residual = out 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | out += residual 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | out = self.relu(out) 65 | return out 66 | 67 | 68 | class DLPU(torch.nn.Module): 69 | 70 | def __init__(self): 71 | super(DLPU, self).__init__() 72 | 73 | self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2) 74 | 75 | self.block1 = ResidualBlock(1, 8) 76 | self.block2 = ResidualBlock(8, 16) 77 | self.block3 = ResidualBlock(16, 32) 78 | self.block4 = ResidualBlock(32, 64) 79 
| self.block5 = ResidualBlock(64, 128) 80 | self.block6 = ResidualBlock(128, 256) 81 | 82 | self.block_up1 = ResidualBlockUp(256, 64) 83 | self.block_up2 = ResidualBlockUp(128, 32) 84 | self.block_up3 = ResidualBlockUp(64, 16) 85 | self.block_up4 = ResidualBlockUp(32, 8) 86 | self.block_up5 = ResidualBlockUp(16, 1) 87 | 88 | self.up_trans_1 = nn.ConvTranspose2d( 89 | in_channels=256, 90 | out_channels=128, 91 | kernel_size=2, 92 | stride=2) 93 | 94 | self.up_trans_2 = nn.ConvTranspose2d( 95 | in_channels=64, 96 | out_channels=64, 97 | kernel_size=2, 98 | stride=2) 99 | 100 | self.up_trans_3 = nn.ConvTranspose2d( 101 | in_channels=32, 102 | out_channels=32, 103 | kernel_size=2, 104 | stride=2) 105 | 106 | self.up_trans_4 = nn.ConvTranspose2d( 107 | in_channels=16, 108 | out_channels=16, 109 | kernel_size=2, 110 | stride=2) 111 | 112 | self.up_trans_5 = nn.ConvTranspose2d( 113 | in_channels=8, 114 | out_channels=8, 115 | kernel_size=2, 116 | stride=2) 117 | 118 | self.out = nn.Conv2d( 119 | in_channels=64, 120 | out_channels=1, 121 | kernel_size=1 122 | ) 123 | 124 | def forward(self, image): 125 | # encoder 126 | x1 = self.block1(image) 127 | x2 = self.max_pool_2x2(x1) 128 | 129 | x3 = self.block2(x2) 130 | x4 = self.max_pool_2x2(x3) 131 | 132 | x5 = self.block3(x4) 133 | x6 = self.max_pool_2x2(x5) 134 | 135 | x7 = self.block4(x6) 136 | x8 = self.max_pool_2x2(x7) 137 | 138 | x9 = self.block5(x8) 139 | x10 = self.max_pool_2x2(x9) 140 | 141 | # нижняя часть 142 | x11 = self.block6(x10) 143 | 144 | # decoder 145 | x = self.up_trans_1(x11) 146 | x = torch.cat([x, x9], 1) 147 | x = self.block_up1(x) 148 | 149 | x = self.up_trans_2(x) 150 | x = torch.cat([x, x7], 1) 151 | x = self.block_up2(x) 152 | 153 | x = self.up_trans_3(x) 154 | x = torch.cat([x, x5], 1) 155 | x = self.block_up3(x) 156 | 157 | x = self.up_trans_4(x) 158 | x = torch.cat([x, x3], 1) 159 | x = self.block_up4(x) 160 | 161 | x = self.up_trans_5(x) 162 | x = torch.cat([x, x1], 1) 163 | x = 
self.block_up5(x) 164 | return x 165 | 166 | # print(x.size(),'мой вывод после "линии"') 167 | 168 | 169 | if __name__ == "__main__": 170 | image = torch.rand((1, 1, 256, 256)) 171 | model = DLPU() 172 | print(model(image).size()) 173 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from DLPU import * 2 | from dataset_generation import * 3 | from train import * 4 | import csv 5 | import os 6 | import hydra 7 | from omegaconf import DictConfig 8 | 9 | n = 5000 10 | dataset = np.empty([n, 256, 256]) 11 | for i in range(n): 12 | if i % 2 == 0: 13 | size = np.random.permutation(np.arange(2, 15, 1))[0] 14 | dataset[i] = create_dataset_element(size, 256, 4, 20) 15 | else: 16 | num_gauss = np.random.permutation(np.arange(1, 7, 1))[0] 17 | dataset[i] = make_gaussian( 18 | num_gauss, 19 | sigma_min=1, 20 | sigma_max=4, 21 | shift_max=4, 22 | magnitude_min=2, 23 | magnitude_max=20) 24 | 25 | dataset_torch = torch.from_numpy(dataset) 26 | dataset_unsqueezed = dataset_torch.unsqueeze(1).float() 27 | X = wraptopi(dataset_unsqueezed); 28 | 29 | from sklearn.model_selection import train_test_split 30 | 31 | X_train, X_test, Y_train, Y_test = train_test_split( 32 | X[:, :, :, :], 33 | dataset_unsqueezed[:, :, :, :], 34 | test_size=0.3, 35 | shuffle=True) 36 | 37 | print(X_train.shape, 'Размерность тренировочных картинок "wrapped phase"') 38 | print(X_test.shape, 'Размерность тестовых картинок "wrapped phase"') 39 | print(Y_train.shape, 'Размерность тренировочных картинок ground truth') 40 | print(Y_test.shape, 'Размерность тестовых картинок ground truth') 41 | 42 | print(X_test.shape) 43 | 44 | model_DLPU = DLPU() 45 | 46 | 47 | def model_train( 48 | model, 49 | name, 50 | batch_size, 51 | total_epochs, 52 | learning_rate, 53 | loss_freq, 54 | metric_freq, 55 | lr_freq, 56 | save_freq): 57 | """ 58 | That function makes train process easier, only 
optimizer hyperparameters 59 | should be defined in function manually 60 | 61 | function returns: 62 | 1. trained model 63 | 2. list of metric history for every "metric_freq" epoch 64 | 3. list of losses history for every "loss_freq" epoch 65 | 4. list of train losses history for every "loss_freq" epoch 66 | 67 | args: 68 | model - torch.nn.Module object - defined model 69 | name - string, model checkpoints will be saved with this name 70 | batch size - integer, defines number of images in one batch 71 | total epoch - integer, defines number of epochs for learning 72 | learning rate - float, learning rate of an optimizer 73 | loss_freq - integer, loss function will be computed every "loss_freq" epochs 74 | metric_freq - integer, metric (AU) -||- 75 | lr_freq - integer, learning rate will be decreased -||- 76 | save_freq - integer, model checkpoints for train and validation will 77 | be saved -||- 78 | 79 | *time computing supports only GPU calculations 80 | """ 81 | 82 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 83 | 84 | if device.type == 'cuda': 85 | start = torch.cuda.Event(enable_timing=True) 86 | end = torch.cuda.Event(enable_timing=True) 87 | 88 | model = model.to(device) 89 | print('[INFO] Model will be learned on {}'.format(device)) 90 | 91 | metric_history = [] 92 | test_loss_history = [] 93 | train_loss_history = [] 94 | train_loss_epoch = 0 95 | 96 | loss = torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean') 97 | # loss = torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean') 98 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 99 | # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9) 100 | 101 | if device.type == 'cuda': 102 | start.record() 103 | 104 | for epoch in np.arange(0, total_epochs, 1): 105 | 106 | print('>> Epoch: {}/{} Learning rate: {}'.format(epoch, total_epochs, learning_rate)) 107 | 108 | order = np.random.permutation(len(X_train)) 109 
| 110 | for start_index in range(0, len(X_train), batch_size): 111 | optimizer.zero_grad() 112 | model.train() 113 | batch_indexes = order[start_index:start_index + batch_size] 114 | 115 | X_batch = X_train[batch_indexes].to(device) 116 | Y_batch = Y_train[batch_indexes].to(device) 117 | 118 | preds = model.forward(X_batch) 119 | 120 | loss_value = loss(preds, Y_batch) 121 | loss_value.backward() 122 | 123 | train_loss_epoch += loss_value.item() 124 | 125 | optimizer.step() 126 | ##### memory optimization start ##### 127 | # GPUtil.showUtilization() 128 | 129 | del X_batch, Y_batch 130 | torch.cuda.empty_cache() 131 | 132 | # GPUtil.showUtilization() 133 | ##### memory optimization end ##### 134 | 135 | train_loss_history.append(train_loss_epoch) 136 | print('[LOSS TRAIN] mean value of MSE {:.4f} on train set at epoch number {}'.format(train_loss_epoch, epoch)) 137 | train_loss_epoch = 0 138 | 139 | if epoch % loss_freq == 0: 140 | test_per_batch = [] 141 | print('[INFO] beginning to calculate loss') 142 | model.eval() 143 | order_test = np.random.permutation(len(X_test)) 144 | 145 | for start_index_test in range(0, len(X_test), batch_size): 146 | test_per_batch = [] 147 | 148 | batch_indexes_test = order_test[start_index_test:start_index_test + batch_size] 149 | 150 | with torch.no_grad(): 151 | X_batch_test = X_test[batch_indexes_test].to(device) 152 | Y_batch_test = Y_train[batch_indexes_test].to(device) 153 | 154 | test_preds = model.forward(X_batch_test) 155 | metric_loss = torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean') 156 | 157 | test_loss = metric_loss(test_preds, Y_batch_test) 158 | test_per_batch.append(test_loss.data.cpu()) 159 | 160 | ##### memory optimization start ##### 161 | del X_batch_test, Y_batch_test 162 | torch.cuda.empty_cache() 163 | ##### memory optimization end ##### 164 | 165 | test_loss_epoch = sum(test_per_batch) / len(test_per_batch) 166 | test_loss_history.append(test_loss_epoch.tolist()) 167 | 168 | print('[LOSS 
TEST] mean value of MSE {:.4f} on test set at epoch number {}'.format(test_loss_epoch, epoch)) 169 | 170 | if epoch % metric_freq == 0: 171 | model.eval() 172 | 173 | order_metric = np.random.permutation(len(X_test)) 174 | 175 | for start_index_metric in range(0, len(X_test), batch_size): 176 | metric_per_batch = [] 177 | 178 | batch_indexes_metric = order_metric[start_index_metric:start_index_metric + batch_size] 179 | 180 | with torch.no_grad(): 181 | X_batch_metric = X_test[batch_indexes_metric].to(device) 182 | 183 | Y_batch_metric = Y_test[batch_indexes_metric] 184 | 185 | metric_preds = model.forward(X_batch_metric) 186 | 187 | # mean_au,_ = au_and_bem_torch(Y_batch_metric,metric_preds.detach().to('cpu'),calc_bem=False) 188 | mean_au_batch, _ = au_and_bem_torch(metric_preds.detach().to('cpu'), Y_batch_metric, calc_bem=False) 189 | 190 | metric_per_batch.append(mean_au_batch) 191 | # metric_per_batch.append(mean_au_batch.data.cpu()) 192 | 193 | ##### memory optimization start ##### 194 | # GPUtil.showUtilization() 195 | del X_batch_metric, Y_batch_metric, metric_preds 196 | torch.cuda.empty_cache() 197 | # GPUtil.showUtilization() 198 | ##### memory optimization end ##### 199 | 200 | test_metric_epoch = sum(metric_per_batch) / len(metric_per_batch) 201 | metric_history.append(test_metric_epoch) 202 | print('[METRIC] Accuracy of unwrapping on test images is {:.4f} %,'.format(test_metric_epoch * 100)) 203 | 204 | if epoch % save_freq == 0: 205 | torch.save({ 206 | 'epoch': epoch, 207 | 'model_state_dict': model.state_dict(), 208 | 'optimizer_state_dict': optimizer.state_dict(), 209 | 'loss': loss 210 | }, '{}/{}_checkpoint_{}'.format(path, name, epoch)) 211 | print('[SAVE] {}/{}_checkpoint_{} was saved'.format(path, name, epoch), ) 212 | 213 | if (epoch + 1) % lr_freq == 0: 214 | learning_rate /= 2 215 | update_lr(optimizer, learning_rate) 216 | print('[lr]New learning rate: {}'.format(learning_rate)) 217 | 218 | print('[END]Learning is done') 219 | torch.save({ 
220 | 'epoch': epoch, 221 | 'model_state_dict': model.state_dict(), 222 | 'optimizer_state_dict': optimizer.state_dict(), 223 | 'loss': loss 224 | # ,'lr': learning_rate 225 | }, '{}/{}_checkpoint_end'.format(path, name)) 226 | print('[END]{}/{}_checkpoint_end was saved'.format(path, name)) 227 | 228 | if device.type == 'cuda': 229 | end.record() 230 | torch.cuda.synchronize() 231 | print('Learning time is {:.1f} min'.format(start.elapsed_time(end) / (1000 * 60))) 232 | 233 | with open('{}/metric_{}.csv'.format(path, name), 'w', newline='') as myfile: 234 | wr = csv.writer(myfile, quoting=csv.QUOTE_NONE) 235 | wr.writerow(metric_history) 236 | print('Metric was saved') 237 | 238 | with open('{}/test_loss_{}.csv'.format(path, name), 'w', newline='') as myfile: 239 | wr = csv.writer(myfile, quoting=csv.QUOTE_NONE) 240 | wr.writerow(test_loss_history) 241 | print('Test loss was saved') 242 | 243 | with open('{}/train_loss_{}.csv'.format(path, name), 'w', newline='') as myfile: 244 | wr = csv.writer(myfile, quoting=csv.QUOTE_NONE) 245 | wr.writerow(train_loss_history) 246 | print('Train loss was saved') 247 | 248 | return model, metric_history, test_loss_history, train_loss_history 249 | 250 | 251 | working_dir = os.getcwd() 252 | path = os.path.join(working_dir, "model") 253 | 254 | try: 255 | os.mkdir(path) 256 | except OSError as error: 257 | print('directory exists') 258 | 259 | print(f"The current base directory is {working_dir}") 260 | 261 | 262 | @hydra.main(config_path=os.path.join(working_dir, "config.yml")) 263 | def train(cfg: DictConfig): 264 | working_dir = os.getcwd() 265 | print(f"The current working directory is {working_dir}") 266 | 267 | # To access elements of the config 268 | print(f"The batch size is {cfg.batch_size}") 269 | print(f"The learning rate is {cfg.lr}") 270 | print(f"Total epochs: {cfg['total_epochs']}") 271 | 272 | trained_model, list_metric, list_test_loss, list_train_loss = model_train( 273 | model=model_DLPU, 274 | name=cfg.name, 275 
| batch_size=cfg.batch_size, 276 | total_epochs=cfg.total_epochs, 277 | learning_rate=cfg.lr, 278 | loss_freq=cfg.loss_freq, 279 | metric_freq=cfg.metric_freq, 280 | lr_freq=cfg.lr_freq, 281 | save_freq=cfg.save_freq) 282 | 283 | 284 | if __name__ == "__main__": 285 | train() 286 | --------------------------------------------------------------------------------