├── README.md ├── dataset.py ├── display_result.py ├── layer.py ├── main.py ├── model.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Noise2Noise 2 | 3 | #### Title 4 | [Noise2Noise: Learning Image Restoration without Clean Data](https://arxiv.org/abs/1803.04189) 5 | 6 | #### Abstract 7 | We apply basic statistical reasoning to signal reconstruction by machine learning -- learning to map corrupted observations to clean signals -- with a simple and powerful conclusion: it is possible to learn to restore images by only looking at corrupted examples, at performance at and sometimes exceeding training using clean data, without explicit image priors or likelihood models of the corruption. In practice, we show that a single model learns photographic noise removal, denoising synthetic Monte Carlo images, and reconstruction of undersampled MRI scans -- all corrupted by different processes -- based on noisy data only. 8 | 9 | ## Train 10 | $ python main.py --mode train \ 11 | --scope [scope name] \ 12 | --name_data [data name] \ 13 | --dir_data [data directory] \ 14 | --dir_log [log directory] \ 15 | --dir_checkpoint [checkpoint directory] 16 | --gpu_ids [gpu id; '-1': no gpu, '0, 1, ..., N-1': gpus] 17 | --- 18 | $ python main.py --mode train \ 19 | --scope resnet \ 20 | --name_data bsd500 \ 21 | --dir_data ./datasets \ 22 | --dir_log ./log \ 23 | --dir_checkpoint ./checkpoint 24 | --gpu_ids 0 25 | 26 | * Set **[scope name]** uniquely. 27 | * To understand hierarchy of directories based on their arguments, see **directories structure** below. 28 | * Hyperparameters were written to **arg.txt** under the **[log directory]**. 29 | 30 | 31 | ## Test 32 | $ python main.py --mode test \ 33 | --scope [scope name] \ 34 | --name_data [data name] \ 35 | --dir_data [data directory] \ 36 | --dir_log [log directory] \ 37 | --dir_checkpoint [checkpoint directory] \ 38 | --dir_result [result directory] 39 | --gpu_ids [gpu id; '-1': no gpu, '0, 1, ..., N-1': gpus] 40 | --- 41 | $ python main.py --mode test \ 42 | --scope resnet \ 43 | --name_data bsd500 \ 44 | --dir_data ./datasets \ 45 | --dir_log ./log \ 46 | --dir_checkpoint ./checkpoints \ 47 | --dir_result ./results 48 | --gpu_ids 0 49 | 50 | * To test using trained network, set **[scope name]** defined in the **train** phase. 51 | * Generated images are saved in the **images** subfolder along with **[result directory]** folder. 52 | * **index.html** is also generated to display the generated images. 53 | 54 | 55 | ## Tensorboard 56 | $ tensorboard --logdir [log directory]/[scope name]/[data name] \ 57 | --port [(optional) 4 digit port number] 58 | --- 59 | $ tensorboard --logdir ./log/resnet/bsd500 \ 60 | --port 6006 61 | 62 | After the above comment executes, go **http://localhost:6006** 63 | 64 | * You can change **[(optional) 4 digit port number]**. 65 | * Default 4 digit port number is **6006**. 66 | 67 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import skimage 4 | from skimage import transform 5 | import matplotlib.pyplot as plt 6 | import os 7 | 8 | 9 | class Dataset(torch.utils.data.Dataset): 10 | """ 11 | dataset of image files of the form 12 | stuff_trans.pt 13 | stuff_density.pt 14 | """ 15 | 16 | def __init__(self, data_dir, data_type='float32', transform=None, sgm=(25, 25)): 17 | self.data_dir = data_dir 18 | self.transform = transform 19 | self.data_type = data_type 20 | 21 | self.sgm_label = sgm[0] 22 | self.sgm_input = sgm[1] 23 | 24 | lst_data = os.listdir(data_dir) 25 | 26 | # lst_input = [f for f in lst_data if f.startswith('input')] 27 | # lst_label = [f for f in lst_data if f.startswith('label')] 28 | # 29 | # lst_input.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 30 | # lst_label.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 31 | # 32 | # self.lst_input = lst_input 33 | # self.lst_label = lst_label 34 | 35 | lst_data.sort(key=lambda f: (''.join(filter(str.isdigit, f)))) 36 | # lst_data.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 37 | 38 | self.lst_data = lst_data 39 | 40 | def __getitem__(self, index): 41 | # label = np.load(os.path.join(self.data_dir, self.lst_label[index])) 42 | # input = np.load(os.path.join(self.data_dir, self.lst_input[index])) 43 | # 44 | # if label.dtype == np.uint8: 45 | # label = label / 255.0 46 | # if input.dtype == np.uint8: 47 | # input = input / 255.0 48 | # 49 | # if label.ndim == 2: 50 | # label = np.expand_dims(label, axis=2) 51 | # if input.ndim == 2: 52 | # input = np.expand_dims(input, axis=2) 53 | # 54 | # if self.ny != label.shape[0]: 55 | # label = label.transpose((1, 0, 2)) 56 | # if self.ny != input.shape[0]: 57 | # input = input.transpose((1, 0, 2)) 58 | # 59 | # data = {'input': input, 'label': label} 60 | 61 | data = plt.imread(os.path.join(self.data_dir, self.lst_data[index])) 62 | 63 | if data.dtype == np.uint8: 64 | data = data / 255.0 65 | 66 | if data.ndim == 2: 67 | data = np.expand_dims(data, axis=2) 68 | 69 | if data.shape[0] > data.shape[1]: 70 | data = data.transpose((1, 0, 2)) 71 | 72 | sz = data.shape 73 | 74 | label = data + self.sgm_label / 255 * np.random.randn(sz[0], sz[1], sz[2]) 75 | input = data + self.sgm_input/255 * np.random.randn(sz[0], sz[1], sz[2]) 76 | 77 | data = {'input': input, 'label': label} 78 | 79 | if self.transform: 80 | data = self.transform(data) 81 | 82 | return data 83 | 84 | def __len__(self): 85 | return len(self.lst_data) 86 | 87 | 88 | class ToTensor(object): 89 | """Convert ndarrays in sample to Tensors.""" 90 | 91 | def __call__(self, data): 92 | # Swap color axis because numpy image: H x W x C 93 | # torch image: C x H x W 94 | 95 | # for key, value in data: 96 | # data[key] = torch.from_numpy(value.transpose((2, 0, 1))) 97 | # 98 | # return data 99 | 100 | input, label = data['input'], data['label'] 101 | 102 | input = input.transpose((2, 0, 1)).astype(np.float32) 103 | label = label.transpose((2, 0, 1)).astype(np.float32) 104 | return {'input': torch.from_numpy(input), 'label': torch.from_numpy(label)} 105 | 106 | 107 | class Normalize(object): 108 | def __init__(self, mean=0.5, std=0.5): 109 | self.mean = mean 110 | self.std = std 111 | 112 | def __call__(self, data): 113 | input, label = data['input'], data['label'] 114 | 115 | input = (input - self.mean) / self.std 116 | label = (label - self.mean) / self.std 117 | 118 | data = {'input': input, 'label': label} 119 | return data 120 | 121 | 122 | class RandomFlip(object): 123 | def __call__(self, data): 124 | # Random Left or Right Flip 125 | 126 | # for key, value in data: 127 | # data[key] = 2 * (value / 255) - 1 128 | # 129 | # return data 130 | input, label = data['input'], data['label'] 131 | 132 | if np.random.rand() > 0.5: 133 | input = np.fliplr(input) 134 | label = np.fliplr(label) 135 | 136 | if np.random.rand() > 0.5: 137 | input = np.flipud(input) 138 | label = np.flipud(label) 139 | 140 | return {'input': input, 'label': label} 141 | 142 | 143 | class Rescale(object): 144 | """Rescale the image in a sample to a given size 145 | 146 | Args: 147 | output_size (tuple or int): Desired output size. 148 | If tuple, output is matched to output_size. 149 | If int, smaller of image edges is matched 150 | to output_size keeping aspect ratio the same. 151 | """ 152 | 153 | def __init__(self, output_size): 154 | assert isinstance(output_size, (int, tuple)) 155 | self.output_size = output_size 156 | 157 | def __call__(self, data): 158 | input, label = data['input'], data['label'] 159 | 160 | h, w = input.shape[:2] 161 | 162 | if isinstance(self.output_size, int): 163 | if h > w: 164 | new_h, new_w = self.output_size * h / w, self.output_size 165 | else: 166 | new_h, new_w = self.output_size, self.output_size * w / h 167 | else: 168 | new_h, new_w = self.output_size 169 | 170 | new_h, new_w = int(new_h), int(new_w) 171 | 172 | input = transform.resize(input, (new_h, new_w)) 173 | label = transform.resize(label, (new_h, new_w)) 174 | 175 | return {'input': input, 'label': label} 176 | 177 | 178 | class RandomCrop(object): 179 | """Crop randomly the image in a sample 180 | 181 | Args: 182 | output_size (tuple or int): Desired output size. 183 | If int, square crop is made. 184 | """ 185 | 186 | def __init__(self, output_size): 187 | assert isinstance(output_size, (int, tuple)) 188 | if isinstance(output_size, int): 189 | self.output_size = (output_size, output_size) 190 | else: 191 | assert len(output_size) == 2 192 | self.output_size = output_size 193 | 194 | def __call__(self, data): 195 | input, label = data['input'], data['label'] 196 | 197 | h, w = input.shape[:2] 198 | new_h, new_w = self.output_size 199 | 200 | top = np.random.randint(0, h - new_h) 201 | left = np.random.randint(0, w - new_w) 202 | 203 | id_y = np.arange(top, top + new_h, 1)[:, np.newaxis].astype(np.int32) 204 | id_x = np.arange(left, left + new_w, 1).astype(np.int32) 205 | 206 | # input = input[top: top + new_h, left: left + new_w] 207 | # label = label[top: top + new_h, left: left + new_w] 208 | 209 | input = input[id_y, id_x] 210 | label = label[id_y, id_x] 211 | 212 | return {'input': input, 'label': label} 213 | 214 | 215 | class UnifromSample(object): 216 | """Crop randomly the image in a sample 217 | 218 | Args: 219 | output_size (tuple or int): Desired output size. 220 | If int, square crop is made. 221 | """ 222 | 223 | def __init__(self, stride): 224 | assert isinstance(stride, (int, tuple)) 225 | if isinstance(stride, int): 226 | self.stride = (stride, stride) 227 | else: 228 | assert len(stride) == 2 229 | self.stride = stride 230 | 231 | def __call__(self, data): 232 | input, label = data['input'], data['label'] 233 | 234 | h, w = input.shape[:2] 235 | stride_h, stride_w = self.stride 236 | new_h = h//stride_h 237 | new_w = w//stride_w 238 | 239 | top = np.random.randint(0, stride_h + (h - new_h * stride_h)) 240 | left = np.random.randint(0, stride_w + (w - new_w * stride_w)) 241 | 242 | id_h = np.arange(top, h, stride_h)[:, np.newaxis] 243 | id_w = np.arange(left, w, stride_w) 244 | 245 | input = input[id_h, id_w] 246 | label = label[id_h, id_w] 247 | 248 | return {'input': input, 'label': label} 249 | 250 | 251 | class ZeroPad(object): 252 | """Rescale the image in a sample to a given size 253 | 254 | Args: 255 | output_size (tuple or int): Desired output size. 256 | If tuple, output is matched to output_size. 257 | If int, smaller of image edges is matched 258 | to output_size keeping aspect ratio the same. 259 | """ 260 | 261 | def __init__(self, output_size): 262 | assert isinstance(output_size, (int, tuple)) 263 | self.output_size = output_size 264 | 265 | def __call__(self, data): 266 | input, label = data['input'], data['label'] 267 | 268 | h, w = input.shape[:2] 269 | 270 | if isinstance(self.output_size, int): 271 | if h > w: 272 | new_h, new_w = self.output_size * h / w, self.output_size 273 | else: 274 | new_h, new_w = self.output_size, self.output_size * w / h 275 | else: 276 | new_h, new_w = self.output_size 277 | 278 | new_h, new_w = int(new_h), int(new_w) 279 | 280 | l = (new_w - w)//2 281 | r = (new_w - w) - l 282 | 283 | u = (new_h - h)//2 284 | b = (new_h - h) - u 285 | 286 | input = np.pad(input, pad_width=((u, b), (l, r), (0, 0))) 287 | label = np.pad(label, pad_width=((u, b), (l, r), (0, 0))) 288 | 289 | return {'input': input, 'label': label} 290 | 291 | class ToNumpy(object): 292 | """Convert ndarrays in sample to Tensors.""" 293 | 294 | def __call__(self, data): 295 | # Swap color axis because numpy image: H x W x C 296 | # torch image: C x H x W 297 | 298 | # for key, value in data: 299 | # data[key] = value.transpose((2, 0, 1)).numpy() 300 | # 301 | # return data 302 | 303 | return data.to('cpu').detach().numpy().transpose(0, 2, 3, 1) 304 | 305 | # input, label = data['input'], data['label'] 306 | # input = input.transpose((2, 0, 1)) 307 | # label = label.transpose((2, 0, 1)) 308 | # return {'input': input.detach().numpy(), 'label': label.detach().numpy()} 309 | 310 | 311 | class Denormalize(object): 312 | def __init__(self, mean=0.5, std=0.5): 313 | self.mean = mean 314 | self.std = std 315 | 316 | def __call__(self, data): 317 | data = self.std * data + self.mean 318 | return data 319 | -------------------------------------------------------------------------------- /display_result.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torchvision.utils as vutils 5 | import matplotlib.pyplot as plt 6 | 7 | dir_result = './results/unet-bnorm/em/images' 8 | lst_result = os.listdir(dir_result) 9 | 10 | lst_input = [f for f in lst_result if f.endswith('input.png')] 11 | lst_label = [f for f in lst_result if f.endswith('label.png')] 12 | lst_output = [f for f in lst_result if f.endswith('output.png')] 13 | 14 | nx = 512 15 | ny = 512 16 | nch = 1 17 | 18 | n = 3 19 | m = 5 20 | 21 | inputs = torch.zeros((m, ny, nx, nch)) 22 | labels = torch.zeros((m, ny, nx, nch)) 23 | outputs = torch.zeros((m, ny, nx, nch)) 24 | 25 | for i in range(m): 26 | inputs[i, :, :, :] = torch.from_numpy(plt.imread(os.path.join(dir_result, lst_input[i]))[:, :, :nch]) 27 | labels[i, :, :, :] = torch.from_numpy(plt.imread(os.path.join(dir_result, lst_label[i]))[:, :, :nch]) 28 | outputs[i, :, :, :] = torch.from_numpy(plt.imread(os.path.join(dir_result, lst_output[i]))[:, :, :nch]) 29 | 30 | inputs = inputs.permute((0, 3, 1, 2)) 31 | labels = labels.permute((0, 3, 1, 2)) 32 | outputs = outputs.permute((0, 3, 1, 2)) 33 | outputs = 1.0*(outputs > 0.5) 34 | 35 | images = torch.cat([inputs, labels, outputs], axis=2) 36 | 37 | plt.figure(figsize=(n, m)) 38 | plt.axis("off") 39 | # plt.title("Generated Images") 40 | plt.imshow(np.transpose(vutils.make_grid(images, padding=2, normalize=False), (1, 2, 0))) 41 | 42 | plt.show() 43 | 44 | -------------------------------------------------------------------------------- /layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CNR2d(nn.Module): 7 | def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]): 8 | super().__init__() 9 | 10 | if bias == []: 11 | if norm == 'bnorm': 12 | bias = False 13 | else: 14 | bias = True 15 | 16 | layers = [] 17 | layers += [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)] 18 | 19 | if norm != []: 20 | layers += [Norm2d(nch_out, norm)] 21 | 22 | if relu != []: 23 | layers += [ReLU(relu)] 24 | 25 | if drop != []: 26 | layers += [nn.Dropout2d(drop)] 27 | 28 | self.cbr = nn.Sequential(*layers) 29 | 30 | def forward(self, x): 31 | return self.cbr(x) 32 | 33 | 34 | class DECNR2d(nn.Module): 35 | def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]): 36 | super().__init__() 37 | 38 | if bias == []: 39 | if norm == 'bnorm': 40 | bias = False 41 | else: 42 | bias = True 43 | 44 | layers = [] 45 | layers += [Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)] 46 | 47 | if norm != []: 48 | layers += [Norm2d(nch_out, norm)] 49 | 50 | if relu != []: 51 | layers += [ReLU(relu)] 52 | 53 | if drop != []: 54 | layers += [nn.Dropout2d(drop)] 55 | 56 | self.decbr = nn.Sequential(*layers) 57 | 58 | def forward(self, x): 59 | return self.decbr(x) 60 | 61 | 62 | class ResBlock(nn.Module): 63 | def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]): 64 | super().__init__() 65 | 66 | if bias == []: 67 | if norm == 'bnorm': 68 | bias = False 69 | else: 70 | bias = True 71 | 72 | layers = [] 73 | 74 | # 1st conv 75 | layers += [Padding(padding, padding_mode=padding_mode)] 76 | layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu)] 77 | 78 | if drop != []: 79 | layers += [nn.Dropout2d(drop)] 80 | 81 | # 2nd conv 82 | layers += [Padding(padding, padding_mode=padding_mode)] 83 | layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[])] 84 | 85 | self.resblk = nn.Sequential(*layers) 86 | 87 | def forward(self, x): 88 | return x + self.resblk(x) 89 | 90 | 91 | class CNR1d(nn.Module): 92 | def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]): 93 | super().__init__() 94 | 95 | if norm == 'bnorm': 96 | bias = False 97 | else: 98 | bias = True 99 | 100 | layers = [] 101 | layers += [nn.Linear(nch_in, nch_out, bias=bias)] 102 | 103 | if norm != []: 104 | layers += [Norm2d(nch_out, norm)] 105 | 106 | if relu != []: 107 | layers += [ReLU(relu)] 108 | 109 | if drop != []: 110 | layers += [nn.Dropout2d(drop)] 111 | 112 | self.cbr = nn.Sequential(*layers) 113 | 114 | def forward(self, x): 115 | return self.cbr(x) 116 | 117 | 118 | class Conv2d(nn.Module): 119 | def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True): 120 | super(Conv2d, self).__init__() 121 | self.conv = nn.Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) 122 | 123 | def forward(self, x): 124 | return self.conv(x) 125 | 126 | 127 | class Deconv2d(nn.Module): 128 | def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True): 129 | super(Deconv2d, self).__init__() 130 | self.deconv = nn.ConvTranspose2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias) 131 | 132 | # layers = [nn.Upsample(scale_factor=2, mode='bilinear'), 133 | # nn.ReflectionPad2d(1), 134 | # nn.Conv2d(nch_in , nch_out, kernel_size=3, stride=1, padding=0)] 135 | # 136 | # self.deconv = nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | return self.deconv(x) 140 | 141 | 142 | class Linear(nn.Module): 143 | def __init__(self, nch_in, nch_out): 144 | super(Linear, self).__init__() 145 | self.linear = nn.Linear(nch_in, nch_out) 146 | 147 | def forward(self, x): 148 | return self.linear(x) 149 | 150 | 151 | class Norm2d(nn.Module): 152 | def __init__(self, nch, norm_mode): 153 | super(Norm2d, self).__init__() 154 | if norm_mode == 'bnorm': 155 | self.norm = nn.BatchNorm2d(nch) 156 | elif norm_mode == 'inorm': 157 | self.norm = nn.InstanceNorm2d(nch) 158 | 159 | def forward(self, x): 160 | return self.norm(x) 161 | 162 | 163 | class ReLU(nn.Module): 164 | def __init__(self, relu): 165 | super(ReLU, self).__init__() 166 | if relu > 0: 167 | self.relu = nn.LeakyReLU(relu, True) 168 | elif relu == 0: 169 | self.relu = nn.ReLU(True) 170 | 171 | def forward(self, x): 172 | return self.relu(x) 173 | 174 | 175 | class Padding(nn.Module): 176 | def __init__(self, padding, padding_mode='zeros', value=0): 177 | super(Padding, self).__init__() 178 | if padding_mode == 'reflection': 179 | self. padding = nn.ReflectionPad2d(padding) 180 | elif padding_mode == 'replication': 181 | self.padding = nn.ReplicationPad2d(padding) 182 | elif padding_mode == 'constant': 183 | self.padding = nn.ConstantPad2d(padding, value) 184 | elif padding_mode == 'zeros': 185 | self.padding = nn.ZeroPad2d(padding) 186 | 187 | def forward(self, x): 188 | return self.padding(x) 189 | 190 | 191 | class Pooling2d(nn.Module): 192 | def __init__(self, nch=[], pool=2, type='avg'): 193 | super().__init__() 194 | 195 | if type == 'avg': 196 | self.pooling = nn.AvgPool2d(pool) 197 | elif type == 'max': 198 | self.pooling = nn.MaxPool2d(pool) 199 | elif type == 'conv': 200 | self.pooling = nn.Conv2d(nch, nch, kernel_size=pool, stride=pool) 201 | 202 | def forward(self, x): 203 | return self.pooling(x) 204 | 205 | 206 | class UnPooling2d(nn.Module): 207 | def __init__(self, nch=[], pool=2, type='nearest'): 208 | super().__init__() 209 | 210 | if type == 'nearest': 211 | self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest') 212 | elif type == 'bilinear': 213 | self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True) 214 | elif type == 'conv': 215 | self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool) 216 | 217 | def forward(self, x): 218 | return self.unpooling(x) 219 | 220 | 221 | class Concat(nn.Module): 222 | def __init__(self): 223 | super().__init__() 224 | 225 | def forward(self, x1, x2): 226 | diffy = x2.size()[2] - x1.size()[2] 227 | diffx = x2.size()[3] - x1.size()[3] 228 | 229 | x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2, 230 | diffy // 2, diffy - diffy // 2]) 231 | 232 | return torch.cat([x2, x1], dim=1) 233 | 234 | 235 | class TV1dLoss(nn.Module): 236 | def __init__(self): 237 | super(TV1dLoss, self).__init__() 238 | 239 | def forward(self, input): 240 | # loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \ 241 | # torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :])) 242 | loss = torch.mean(torch.abs(input[:, :-1] - input[:, 1:])) 243 | 244 | return loss 245 | 246 | 247 | class TV2dLoss(nn.Module): 248 | def __init__(self): 249 | super(TV2dLoss, self).__init__() 250 | 251 | def forward(self, input): 252 | loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \ 253 | torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :])) 254 | return loss 255 | 256 | 257 | class SSIM2dLoss(nn.Module): 258 | def __init__(self): 259 | super(SSIM2dLoss, self).__init__() 260 | 261 | def forward(self, input, targer): 262 | loss = 0 263 | return loss 264 | 265 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.backends.cudnn as cudnn 4 | from train import * 5 | from utils import * 6 | 7 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 8 | 9 | cudnn.benchmark = True 10 | cudnn.fastest = True 11 | 12 | ## setup parse 13 | parser = argparse.ArgumentParser(description='Train the unet network', 14 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | 16 | parser.add_argument('--gpu_ids', default='-1', dest='gpu_ids') 17 | 18 | parser.add_argument('--mode', default='train', choices=['train', 'test'], dest='mode') 19 | parser.add_argument('--train_continue', default='off', choices=['on', 'off'], dest='train_continue') 20 | 21 | parser.add_argument('--scope', default='resnet', dest='scope') 22 | parser.add_argument('--norm', type=str, default='inorm', dest='norm') 23 | 24 | parser.add_argument('--dir_checkpoint', default='./checkpoints', dest='dir_checkpoint') 25 | parser.add_argument('--dir_log', default='./log', dest='dir_log') 26 | 27 | parser.add_argument('--name_data', type=str, default='bsd500', dest='name_data') 28 | parser.add_argument('--dir_data', default='../datasets', dest='dir_data') 29 | parser.add_argument('--dir_result', default='./results', dest='dir_result') 30 | 31 | parser.add_argument('--num_epoch', type=int, default=300, dest='num_epoch') 32 | parser.add_argument('--batch_size', type=int, default=4, dest='batch_size') 33 | 34 | parser.add_argument('--lr_G', type=float, default=1e-4, dest='lr_G') 35 | 36 | parser.add_argument('--optim', default='adam', choices=['sgd', 'adam', 'rmsprop'], dest='optim') 37 | parser.add_argument('--beta1', default=0.5, dest='beta1') 38 | 39 | parser.add_argument('--ny_in', type=int, default=321, dest='ny_in') 40 | parser.add_argument('--nx_in', type=int, default=481, dest='nx_in') 41 | parser.add_argument('--nch_in', type=int, default=3, dest='nch_in') 42 | 43 | parser.add_argument('--ny_load', type=int, default=256, dest='ny_load') 44 | parser.add_argument('--nx_load', type=int, default=256, dest='nx_load') 45 | parser.add_argument('--nch_load', type=int, default=3, dest='nch_load') 46 | 47 | parser.add_argument('--ny_out', type=int, default=256, dest='ny_out') 48 | parser.add_argument('--nx_out', type=int, default=256, dest='nx_out') 49 | parser.add_argument('--nch_out', type=int, default=3, dest='nch_out') 50 | 51 | parser.add_argument('--nch_ker', type=int, default=64, dest='nch_ker') 52 | 53 | parser.add_argument('--data_type', default='float32', dest='data_type') 54 | 55 | parser.add_argument('--num_freq_disp', type=int, default=1, dest='num_freq_disp') 56 | parser.add_argument('--num_freq_save', type=int, default=1, dest='num_freq_save') 57 | 58 | PARSER = Parser(parser) 59 | 60 | def main(): 61 | ARGS = PARSER.get_arguments() 62 | PARSER.write_args() 63 | PARSER.print_args() 64 | 65 | TRAINER = Train(ARGS) 66 | 67 | if ARGS.mode == 'train': 68 | TRAINER.train() 69 | elif ARGS.mode == 'test': 70 | TRAINER.test() 71 | 72 | if __name__ == '__main__': 73 | main() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from layer import * 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | from torch.optim import lr_scheduler 7 | 8 | # U-Net: Convolutional Networks for Biomedical Image Segmentation 9 | # https://arxiv.org/abs/1505.04597 10 | class UNet(nn.Module): 11 | def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm'): 12 | super(UNet, self).__init__() 13 | 14 | self.nch_in = nch_in 15 | self.nch_out = nch_out 16 | self.nch_ker = nch_ker 17 | self.norm = norm 18 | 19 | if norm == 'bnorm': 20 | self.bias = False 21 | else: 22 | self.bias = True 23 | 24 | """ 25 | Encoder part 26 | """ 27 | 28 | self.enc1_1 = CNR2d(1 * self.nch_in, 1 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 29 | self.enc1_2 = CNR2d(1 * self.nch_ker, 1 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 30 | 31 | self.pool1 = Pooling2d(pool=2, type='avg') 32 | 33 | self.enc2_1 = CNR2d(1 * self.nch_ker, 2 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 34 | self.enc2_2 = CNR2d(2 * self.nch_ker, 2 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 35 | 36 | self.pool2 = Pooling2d(pool=2, type='avg') 37 | 38 | self.enc3_1 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 39 | self.enc3_2 = CNR2d(4 * self.nch_ker, 4 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 40 | 41 | self.pool3 = Pooling2d(pool=2, type='avg') 42 | 43 | self.enc4_1 = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 44 | self.enc4_2 = CNR2d(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 45 | 46 | self.pool4 = Pooling2d(pool=2, type='avg') 47 | 48 | self.enc5_1 = CNR2d(8 * self.nch_ker, 2 * 8 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 49 | 50 | """ 51 | Decoder part 52 | """ 53 | 54 | self.dec5_1 = DECNR2d(2 * 8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 55 | 56 | self.unpool4 = UnPooling2d(pool=2, type='nearest') 57 | 58 | self.dec4_2 = DECNR2d(2 * 8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 59 | self.dec4_1 = DECNR2d(8 * self.nch_ker, 4 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 60 | 61 | self.unpool3 = UnPooling2d(pool=2, type='nearest') 62 | 63 | self.dec3_2 = DECNR2d(2 * 4 * self.nch_ker, 4 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 64 | self.dec3_1 = DECNR2d(4 * self.nch_ker, 2 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 65 | 66 | self.unpool2 = UnPooling2d(pool=2, type='nearest') 67 | 68 | self.dec2_2 = DECNR2d(2 * 2 * self.nch_ker, 2 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 69 | self.dec2_1 = DECNR2d(2 * self.nch_ker, 1 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 70 | 71 | self.unpool1 = UnPooling2d(pool=2, type='nearest') 72 | 73 | self.dec1_2 = DECNR2d(2 * 1 * self.nch_ker, 1 * self.nch_ker, kernel_size=3, stride=1, norm=self.norm, relu=0.0, drop=[]) 74 | self.dec1_1 = DECNR2d(1 * self.nch_ker, 1 * self.nch_out, kernel_size=3, stride=1, norm=[], relu=[], drop=[], bias=False) 75 | 76 | def forward(self, x): 77 | 78 | """ 79 | Encoder part 80 | """ 81 | 82 | enc1 = self.enc1_2(self.enc1_1(x)) 83 | pool1 = self.pool1(enc1) 84 | 85 | enc2 = self.enc2_2(self.enc2_1(pool1)) 86 | pool2 = self.pool2(enc2) 87 | 88 | enc3 = self.enc3_2(self.enc3_1(pool2)) 89 | pool3 = self.pool3(enc3) 90 | 91 | enc4 = self.enc4_2(self.enc4_1(pool3)) 92 | pool4 = self.pool4(enc4) 93 | 94 | enc5 = self.enc5_1(pool4) 95 | 96 | """ 97 | Encoder part 98 | """ 99 | dec5 = self.dec5_1(enc5) 100 | 101 | unpool4 = self.unpool4(dec5) 102 | cat4 = torch.cat([enc4, unpool4], dim=1) 103 | dec4 = self.dec4_1(self.dec4_2(cat4)) 104 | 105 | unpool3 = self.unpool3(dec4) 106 | cat3 = torch.cat([enc3, unpool3], dim=1) 107 | dec3 = self.dec3_1(self.dec3_2(cat3)) 108 | 109 | unpool2 = self.unpool2(dec3) 110 | cat2 = torch.cat([enc2, unpool2], dim=1) 111 | dec2 = self.dec2_1(self.dec2_2(cat2)) 112 | 113 | unpool1 = self.unpool1(dec2) 114 | cat1 = torch.cat([enc1, unpool1], dim=1) 115 | dec1 = self.dec1_1(self.dec1_2(cat1)) 116 | 117 | x = dec1 118 | 119 | return x 120 | 121 | # Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network 122 | # https://arxiv.org/abs/1609.04802 123 | class ResNet(nn.Module): 124 | def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm', nblk=16): 125 | super(ResNet, self).__init__() 126 | 127 | self.nch_in = nch_in 128 | self.nch_out = nch_out 129 | self.nch_ker = nch_ker 130 | self.norm = norm 131 | self.nblk = nblk 132 | 133 | if norm == 'bnorm': 134 | self.bias = False 135 | else: 136 | self.bias = True 137 | 138 | self.enc1 = CNR2d(self.nch_in, self.nch_ker, kernel_size=3, stride=1, padding=1, norm=[], relu=0.0) 139 | 140 | res = [] 141 | for i in range(self.nblk): 142 | res += [ResBlock(self.nch_ker, self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')] 143 | self.res = nn.Sequential(*res) 144 | 145 | self.dec1 = CNR2d(self.nch_ker, self.nch_ker, kernel_size=3, stride=1, padding=1, norm=norm, relu=[]) 146 | 147 | self.conv1 = Conv2d(self.nch_ker, self.nch_out, kernel_size=3, stride=1, padding=1) 148 | 149 | def forward(self, x): 150 | x = self.enc1(x) 151 | x0 = x 152 | 153 | x = self.res(x) 154 | 155 | x = self.dec1(x) 156 | x = x + x0 157 | 158 | x = self.conv1(x) 159 | 160 | return x 161 | 162 | 163 | class Discriminator(nn.Module): 164 | def __init__(self, nch_in, nch_ker=64, norm='bnorm'): 165 | super(Discriminator, self).__init__() 166 | 167 | self.nch_in = nch_in 168 | self.nch_ker = nch_ker 169 | self.norm = norm 170 | 171 | if norm == 'bnorm': 172 | self.bias = False 173 | else: 174 | self.bias = True 175 | 176 | # dsc1 : 256 x 256 x 3 -> 128 x 128 x 64 177 | # dsc2 : 128 x 128 x 64 -> 64 x 64 x 128 178 | # dsc3 : 64 x 64 x 128 -> 32 x 32 x 256 179 | # dsc4 : 32 x 32 x 256 -> 16 x 16 x 512 180 | # dsc5 : 16 x 16 x 512 -> 16 x 16 x 1 181 | 182 | self.dsc1 = CNR2d(1 * self.nch_in, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2) 183 | self.dsc2 = CNR2d(1 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2) 184 | self.dsc3 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2) 185 | self.dsc4 = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2) 186 | self.dsc5 = CNR2d(8 * self.nch_ker, 1, kernel_size=4, stride=1, padding=1, norm=[], relu=[], bias=False) 187 | 188 | # self.dsc1 = CNR2d(1 * self.nch_in, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=[], relu=0.2) 189 | # self.dsc2 = CNR2d(1 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=[], relu=0.2) 190 | # self.dsc3 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=[], relu=0.2) 191 | # self.dsc4 = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=1, padding=1, norm=[], relu=0.2) 192 | # self.dsc5 = CNR2d(8 * self.nch_ker, 1, kernel_size=4, stride=1, padding=1, norm=[], relu=[], bias=False) 193 | 194 | def forward(self, x): 195 | 196 | x = self.dsc1(x) 197 | x = self.dsc2(x) 198 | x = self.dsc3(x) 199 | x = self.dsc4(x) 200 | x = self.dsc5(x) 201 | 202 | # x = torch.sigmoid(x) 203 | 204 | return x 205 | 206 | 207 | def init_weights(net, init_type='normal', init_gain=0.02): 208 | """Initialize network weights. 209 | 210 | Parameters: 211 | net (network) -- network to be initialized 212 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 213 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 214 | 215 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 216 | work better for some applications. Feel free to try yourself. 217 | """ 218 | def init_func(m): # define the initialization function 219 | classname = m.__class__.__name__ 220 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 221 | if init_type == 'normal': 222 | init.normal_(m.weight.data, 0.0, init_gain) 223 | elif init_type == 'xavier': 224 | init.xavier_normal_(m.weight.data, gain=init_gain) 225 | elif init_type == 'kaiming': 226 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 227 | elif init_type == 'orthogonal': 228 | init.orthogonal_(m.weight.data, gain=init_gain) 229 | else: 230 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 231 | if hasattr(m, 'bias') and m.bias is not None: 232 | init.constant_(m.bias.data, 0.0) 233 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 234 | init.normal_(m.weight.data, 1.0, init_gain) 235 | init.constant_(m.bias.data, 0.0) 236 | 237 | print('initialize network with %s' % init_type) 238 | net.apply(init_func) # apply the initialization function 239 | 240 | 241 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 242 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 243 | Parameters: 244 | net (network) -- the network to be initialized 245 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 246 | gain (float) -- scaling factor for normal, xavier and orthogonal. 247 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 248 | 249 | Return an initialized network. 250 | """ 251 | if gpu_ids: 252 | assert(torch.cuda.is_available()) 253 | net.to(gpu_ids[0]) 254 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 255 | init_weights(net, init_type, init_gain=init_gain) 256 | return net 257 | 258 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from dataset import * 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torchvision import transforms 8 | from torch.utils.tensorboard import SummaryWriter 9 | import matplotlib.pyplot as plt 10 | 11 | ## 12 | class Train: 13 | def __init__(self, args): 14 | self.mode = args.mode 15 | self.train_continue = args.train_continue 16 | 17 | self.scope = args.scope 18 | self.norm = args.norm 19 | 20 | self.dir_checkpoint = args.dir_checkpoint 21 | self.dir_log = args.dir_log 22 | 23 | self.name_data = args.name_data 24 | self.dir_data = args.dir_data 25 | self.dir_result = args.dir_result 26 | 27 | self.num_epoch = args.num_epoch 28 | self.batch_size = args.batch_size 29 | 30 | self.lr_G = args.lr_G 31 | 32 | self.optim = args.optim 33 | self.beta1 = args.beta1 34 | 35 | self.ny_in = args.ny_in 36 | self.nx_in = args.nx_in 37 | self.nch_in = args.nch_in 38 | 39 | self.ny_load = args.ny_load 40 | self.nx_load = args.nx_load 41 | self.nch_load = args.nch_load 42 | 43 | self.ny_out = args.ny_out 44 | self.nx_out = args.nx_out 45 | self.nch_out = args.nch_out 46 | 47 | self.nch_ker = args.nch_ker 48 | 49 | self.data_type = args.data_type 50 | 51 | self.num_freq_disp = args.num_freq_disp 52 | self.num_freq_save = args.num_freq_save 53 | 54 | self.gpu_ids = args.gpu_ids 55 | 56 | if self.gpu_ids and torch.cuda.is_available(): 57 | self.device = torch.device("cuda:%d" % self.gpu_ids[0]) 58 | torch.cuda.set_device(self.gpu_ids[0]) 59 | else: 60 | self.device = torch.device("cpu") 61 | 62 | def save(self, dir_chck, netG, optimG, epoch): 63 | if not os.path.exists(dir_chck): 64 | os.makedirs(dir_chck) 65 | 66 | torch.save({'netG': netG.state_dict(), 67 | 'optimG': optimG.state_dict()}, 68 | '%s/model_epoch%04d.pth' % (dir_chck, epoch)) 69 | 70 | def load(self, dir_chck, netG, optimG=[], epoch=[], mode='train'): 71 | 72 | if not os.path.exists(dir_chck): 73 | epoch = 0 74 | if mode == 'train': 75 | return netG, optimG, epoch 76 | elif mode == 'test': 77 | return netG, epoch 78 | 79 | if not epoch: 80 | ckpt = os.listdir(dir_chck) 81 | ckpt.sort() 82 | epoch = int(ckpt[-1].split('epoch')[1].split('.pth')[0]) 83 | 84 | dict_net = torch.load('%s/model_epoch%04d.pth' % (dir_chck, epoch)) 85 | 86 | print('Loaded %dth network' % epoch) 87 | 88 | if mode == 'train': 89 | netG.load_state_dict(dict_net['netG']) 90 | optimG.load_state_dict(dict_net['optimG']) 91 | 92 | return netG, optimG, epoch 93 | 94 | elif mode == 'test': 95 | netG.load_state_dict(dict_net['netG']) 96 | 97 | return netG, epoch 98 | 99 | def train(self): 100 | mode = self.mode 101 | 102 | train_continue = self.train_continue 103 | num_epoch = self.num_epoch 104 | 105 | lr_G = self.lr_G 106 | 107 | batch_size = self.batch_size 108 | device = self.device 109 | 110 | gpu_ids = self.gpu_ids 111 | 112 | nch_in = self.nch_in 113 | nch_out = self.nch_out 114 | nch_ker = self.nch_ker 115 | 116 | norm = self.norm 117 | name_data = self.name_data 118 | 119 | num_freq_disp = self.num_freq_disp 120 | num_freq_save = self.num_freq_save 121 | 122 | ## setup dataset 123 | dir_chck = os.path.join(self.dir_checkpoint, self.scope, name_data) 124 | 125 | dir_data_train = os.path.join(self.dir_data, name_data, 'train') 126 | dir_data_val = os.path.join(self.dir_data, name_data, 'val') 127 | 128 | dir_log_train = os.path.join(self.dir_log, self.scope, name_data, 'train') 129 | dir_log_val = os.path.join(self.dir_log, self.scope, name_data, 'val') 130 | 131 | dir_result_train = os.path.join(self.dir_result, self.scope, name_data, 'train') 132 | dir_result_val = os.path.join(self.dir_result, self.scope, name_data, 'val') 133 | if not os.path.exists(os.path.join(dir_result_train, 'images')): 134 | os.makedirs(os.path.join(dir_result_train, 'images')) 135 | if not os.path.exists(os.path.join(dir_result_val, 'images')): 136 | os.makedirs(os.path.join(dir_result_val, 'images')) 137 | 138 | transform_train = transforms.Compose([Normalize(mean=0.5, std=0.5), RandomFlip(), RandomCrop((self.ny_load, self.nx_load)), ToTensor()]) 139 | transform_val = transforms.Compose([Normalize(mean=0.5, std=0.5), RandomFlip(), RandomCrop((self.ny_load, self.nx_load)), ToTensor()]) 140 | 141 | transform_inv = transforms.Compose([ToNumpy(), Denormalize(mean=0.5, std=0.5)]) 142 | 143 | dataset_train = Dataset(dir_data_train, data_type=self.data_type, transform=transform_train, sgm=(25, 25)) 144 | dataset_val = Dataset(dir_data_val, data_type=self.data_type, transform=transform_val, sgm=(25, 25)) 145 | 146 | loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8) 147 | loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=True, num_workers=8) 148 | 149 | num_train = len(dataset_train) 150 | num_val = len(dataset_val) 151 | 152 | num_batch_train = int((num_train / batch_size) + ((num_train % batch_size) != 0)) 153 | num_batch_val = int((num_val / batch_size) + ((num_val % batch_size) != 0)) 154 | 155 | if nch_out == 1: 156 | cmap = 'gray' 157 | else: 158 | cmap = None 159 | 160 | ## setup network 161 | # netG = UNet(nch_in, nch_out, nch_ker, norm) 162 | netG = ResNet(nch_in, nch_out, nch_ker, norm) 163 | 164 | init_net(netG, init_type='normal', init_gain=0.02, gpu_ids=gpu_ids) 165 | 166 | ## setup loss & optimization 167 | fn_REG = nn.L1Loss().to(device) # Regression loss: L1 168 | # fn_REG = nn.MSELoss().to(device) # Regression loss: L2 169 | 170 | paramsG = netG.parameters() 171 | 172 | optimG = torch.optim.Adam(paramsG, lr=lr_G, betas=(self.beta1, 0.999)) 173 | 174 | ## load from checkpoints 175 | st_epoch = 0 176 | 177 | if train_continue == 'on': 178 | netG, optimG, st_epoch = self.load(dir_chck, netG, optimG, mode=mode) 179 | 180 | ## setup tensorboard 181 | writer_train = SummaryWriter(log_dir=dir_log_train) 182 | writer_val = SummaryWriter(log_dir=dir_log_val) 183 | 184 | for epoch in range(st_epoch + 1, num_epoch + 1): 185 | ## training phase 186 | netG.train() 187 | 188 | loss_G_train = [] 189 | 190 | for batch, data in enumerate(loader_train, 1): 191 | def should(freq): 192 | return freq > 0 and (batch % freq == 0 or batch == num_batch_train) 193 | 194 | label = data['label'].to(device) 195 | input = data['input'].to(device) 196 | 197 | # forward netG 198 | output = netG(input) 199 | 200 | # backward netG 201 | optimG.zero_grad() 202 | 203 | loss_G = fn_REG(output, label) 204 | 205 | loss_G.backward() 206 | optimG.step() 207 | 208 | # get losses 209 | loss_G_train += [loss_G.item()] 210 | 211 | print('TRAIN: EPOCH %d: BATCH %04d/%04d: LOSS: %.4f' 212 | % (epoch, batch, num_batch_train, np.mean(loss_G_train))) 213 | 214 | if should(num_freq_disp): 215 | ## show output 216 | input = transform_inv(input) 217 | label = transform_inv(label) 218 | output = transform_inv(output) 219 | 220 | input = np.clip(input, 0, 1) 221 | label = np.clip(label, 0, 1) 222 | output = np.clip(output, 0, 1) 223 | 224 | writer_train.add_images('input', input, num_batch_train * (epoch - 1) + batch, dataformats='NHWC') 225 | writer_train.add_images('output', output, num_batch_train * (epoch - 1) + batch, dataformats='NHWC') 226 | writer_train.add_images('label', label, num_batch_train * (epoch - 1) + batch, dataformats='NHWC') 227 | 228 | for j in range(label.shape[0]): 229 | # name = num_train * (epoch - 1) + num_batch_train * (batch - 1) + j 230 | name = num_batch_train * (batch - 1) + j 231 | fileset = {'name': name, 232 | 'input': "%04d-input.png" % name, 233 | 'output': "%04d-output.png" % name, 234 | 'label': "%04d-label.png" % name} 235 | 236 | plt.imsave(os.path.join(dir_result_train, 'images', fileset['input']), input[j, :, :, :].squeeze(), cmap=cmap) 237 | plt.imsave(os.path.join(dir_result_train, 'images', fileset['output']), output[j, :, :, :].squeeze(), cmap=cmap) 238 | plt.imsave(os.path.join(dir_result_train, 'images', fileset['label']), label[j, :, :, :].squeeze(), cmap=cmap) 239 | 240 | append_index(dir_result_train, fileset) 241 | 242 | writer_train.add_scalar('loss_G', np.mean(loss_G_train), epoch) 243 | 244 | ## validation phase 245 | with torch.no_grad(): 246 | netG.eval() 247 | 248 | loss_G_val = [] 249 | 250 | for batch, data in enumerate(loader_val, 1): 251 | def should(freq): 252 | return freq > 0 and (batch % freq == 0 or batch == num_batch_val) 253 | 254 | input = data['input'].to(device) 255 | label = data['label'].to(device) 256 | 257 | # forward netG 258 | output = netG(input) 259 | 260 | loss_G = fn_REG(output, label) 261 | 262 | loss_G_val += [loss_G.item()] 263 | 264 | print('VALID: EPOCH %d: BATCH %04d/%04d: LOSS: %.4f' 265 | % (epoch, batch, num_batch_val, np.mean(loss_G_val))) 266 | 267 | if should(num_freq_disp): 268 | ## show output 269 | input = transform_inv(input) 270 | label = transform_inv(label) 271 | output = transform_inv(output) 272 | 273 | input = np.clip(input, 0, 1) 274 | label = np.clip(label, 0, 1) 275 | output = np.clip(output, 0, 1) 276 | 277 | writer_val.add_images('input', input, num_batch_val * (epoch - 1) + batch, dataformats='NHWC') 278 | writer_val.add_images('output', output, num_batch_val * (epoch - 1) + batch, dataformats='NHWC') 279 | writer_val.add_images('label', label, num_batch_val * (epoch - 1) + batch, dataformats='NHWC') 280 | 281 | for j in range(label.shape[0]): 282 | # name = num_train * (epoch - 1) + num_batch_train * (batch - 1) + j 283 | name = num_batch_train * (batch - 1) + j 284 | fileset = {'name': name, 285 | 'input': "%04d-input.png" % name, 286 | 'output': "%04d-output.png" % name, 287 | 'label': "%04d-label.png" % name} 288 | 289 | plt.imsave(os.path.join(dir_result_val, 'images', fileset['input']), input[j, :, :, :].squeeze(), cmap=cmap) 290 | plt.imsave(os.path.join(dir_result_val, 'images', fileset['output']), output[j, :, :, :].squeeze(), cmap=cmap) 291 | plt.imsave(os.path.join(dir_result_val, 'images', fileset['label']), label[j, :, :, :].squeeze(), cmap=cmap) 292 | 293 | append_index(dir_result_val, fileset) 294 | 295 | writer_val.add_scalar('loss_G', np.mean(loss_G_val), epoch) 296 | 297 | # update schduler 298 | # schedG.step() 299 | # schedD.step() 300 | 301 | ## save 302 | if (epoch % num_freq_save) == 0: 303 | self.save(dir_chck, netG, optimG, epoch) 304 | 305 | writer_train.close() 306 | writer_val.close() 307 | 308 | def test(self): 309 | mode = self.mode 310 | 311 | batch_size = self.batch_size 312 | device = self.device 313 | gpu_ids = self.gpu_ids 314 | 315 | nch_in = self.nch_in 316 | nch_out = self.nch_out 317 | nch_ker = self.nch_ker 318 | 319 | norm = self.norm 320 | 321 | name_data = self.name_data 322 | 323 | if nch_out == 1: 324 | cmap = 'gray' 325 | else: 326 | cmap = None 327 | 328 | ## setup dataset 329 | dir_chck = os.path.join(self.dir_checkpoint, self.scope, name_data) 330 | 331 | dir_result_test = os.path.join(self.dir_result, self.scope, name_data, 'test') 332 | if not os.path.exists(os.path.join(dir_result_test, 'images')): 333 | os.makedirs(os.path.join(dir_result_test, 'images')) 334 | 335 | dir_data_test = os.path.join(self.dir_data, name_data, 'test') 336 | 337 | transform_test = transforms.Compose([Normalize(mean=0.5, std=0.5), ToTensor()]) 338 | transform_inv = transforms.Compose([ToNumpy(), Denormalize(mean=0.5, std=0.5)]) 339 | transform_ts2np = ToNumpy() 340 | 341 | dataset_test = Dataset(dir_data_test, data_type=self.data_type, transform=transform_test, sgm=(0, 25)) 342 | 343 | loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=0) 344 | 345 | num_test = len(dataset_test) 346 | 347 | num_batch_test = int((num_test / batch_size) + ((num_test % batch_size) != 0)) 348 | 349 | ## setup network 350 | # netG = UNet(nch_in, nch_out, nch_ker, norm) 351 | netG = ResNet(nch_in, nch_out, nch_ker, norm) 352 | init_net(netG, init_type='normal', init_gain=0.02, gpu_ids=gpu_ids) 353 | 354 | ## setup loss & optimization 355 | fn_REG = nn.L1Loss().to(device) # L1 356 | # fn_REG = nn.MSELoss().to(device) # L1 357 | 358 | ## load from checkpoints 359 | st_epoch = 0 360 | 361 | netG, st_epoch = self.load(dir_chck, netG, mode=mode) 362 | 363 | ## test phase 364 | with torch.no_grad(): 365 | netG.eval() 366 | # netG.train() 367 | 368 | loss_G_test = [] 369 | 370 | for i, data in enumerate(loader_test, 1): 371 | input = data['input'].to(device) 372 | label = data['label'].to(device) 373 | 374 | output = netG(input) 375 | 376 | loss_G = fn_REG(output, label) 377 | 378 | loss_G_test += [loss_G.item()] 379 | 380 | input = transform_inv(input) 381 | label = transform_inv(label) 382 | output = transform_inv(output) 383 | 384 | input = np.clip(input, 0, 1) 385 | label = np.clip(label, 0, 1) 386 | output = np.clip(output, 0, 1) 387 | 388 | for j in range(label.shape[0]): 389 | name = batch_size * (i - 1) + j 390 | fileset = {'name': name, 391 | 'input': "%04d-input.png" % name, 392 | 'output': "%04d-output.png" % name, 393 | 'label': "%04d-label.png" % name} 394 | 395 | plt.imsave(os.path.join(dir_result_test, 'images', fileset['input']), input[j, :, :, :].squeeze(), cmap=cmap) 396 | plt.imsave(os.path.join(dir_result_test, 'images', fileset['output']), output[j, :, :, :].squeeze(), cmap=cmap) 397 | plt.imsave(os.path.join(dir_result_test, 'images', fileset['label']), label[j, :, :, :].squeeze(), cmap=cmap) 398 | 399 | append_index(dir_result_test, fileset) 400 | 401 | print('TEST: %d/%d: LOSS: %.6f' % (i, num_batch_test, loss_G.item())) 402 | print('TEST: AVERAGE LOSS: %.6f' % (np.mean(loss_G_test))) 403 | 404 | 405 | def set_requires_grad(nets, requires_grad=False): 406 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 407 | Parameters: 408 | nets (network list) -- a list of networks 409 | requires_grad (bool) -- whether the networks require gradients or not 410 | """ 411 | if not isinstance(nets, list): 412 | nets = [nets] 413 | for net in nets: 414 | if net is not None: 415 | for param in net.parameters(): 416 | param.requires_grad = requires_grad 417 | 418 | 419 | def get_scheduler(optimizer, opt): 420 | """Return a learning rate scheduler 421 | 422 | Parameters: 423 | optimizer -- the optimizer of the network 424 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  425 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 426 | 427 | For 'linear', we keep the same learning rate for the first epochs 428 | and linearly decay the rate to zero over the next epochs. 429 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 430 | See https://pytorch.org/docs/stable/optim.html for more details. 431 | """ 432 | if opt.lr_policy == 'linear': 433 | def lambda_rule(epoch): 434 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) 435 | return lr_l 436 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 437 | elif opt.lr_policy == 'step': 438 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 439 | elif opt.lr_policy == 'plateau': 440 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 441 | elif opt.lr_policy == 'cosine': 442 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 443 | else: 444 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 445 | return scheduler 446 | 447 | 448 | def append_index(dir_result, fileset, step=False): 449 | index_path = os.path.join(dir_result, "index.html") 450 | if os.path.exists(index_path): 451 | index = open(index_path, "a") 452 | else: 453 | index = open(index_path, "w") 454 | index.write("") 455 | if step: 456 | index.write("") 457 | for key, value in fileset.items(): 458 | index.write("" % key) 459 | index.write('') 460 | 461 | # for fileset in filesets: 462 | index.write("") 463 | 464 | if step: 465 | index.write("" % fileset["step"]) 466 | index.write("" % fileset["name"]) 467 | 468 | del fileset['name'] 469 | 470 | for key, value in fileset.items(): 471 | index.write("" % value) 472 | 473 | index.write("") 474 | return index_path 475 | 476 | 477 | def add_plot(output, label, writer, epoch=[], ylabel='Density', xlabel='Radius', namescope=[]): 478 | fig, ax = plt.subplots() 479 | 480 | ax.plot(output.transpose(1, 0).detach().numpy(), '-') 481 | ax.plot(label.transpose(1, 0).detach().numpy(), '--') 482 | 483 | ax.set_xlim(0, 400) 484 | 485 | ax.grid(True) 486 | ax.set_ylabel(ylabel) 487 | ax.set_xlabel(xlabel) 488 | 489 | writer.add_figure(namescope, fig, epoch) 490 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import logging 5 | import torch 6 | # import argparse 7 | 8 | '''' 9 | class Logger: 10 | class Parser: 11 | ''' 12 | class Parser: 13 | def __init__(self, parser): 14 | self.__parser = parser 15 | self.__args = parser.parse_args() 16 | 17 | # set gpu ids 18 | str_ids = self.__args.gpu_ids.split(',') 19 | self.__args.gpu_ids = [] 20 | for str_id in str_ids: 21 | id = int(str_id) 22 | if id >= 0: 23 | self.__args.gpu_ids.append(id) 24 | # if len(self.__args.gpu_ids) > 0: 25 | # torch.cuda.set_device(self.__args.gpu_ids[0]) 26 | 27 | def get_parser(self): 28 | return self.__parser 29 | 30 | def get_arguments(self): 31 | return self.__args 32 | 33 | def write_args(self): 34 | params_dict = vars(self.__args) 35 | 36 | log_dir = os.path.join(params_dict['dir_log'], params_dict['scope'], params_dict['name_data']) 37 | args_name = os.path.join(log_dir, 'args.txt') 38 | 39 | if not os.path.exists(log_dir): 40 | os.makedirs(log_dir) 41 | 42 | with open(args_name, 'wt') as args_fid: 43 | args_fid.write('----' * 10 + '\n') 44 | args_fid.write('{0:^40}'.format('PARAMETER TABLES') + '\n') 45 | args_fid.write('----' * 10 + '\n') 46 | for k, v in sorted(params_dict.items()): 47 | args_fid.write('{}'.format(str(k)) + ' : ' + ('{0:>%d}' % (35 - len(str(k)))).format(str(v)) + '\n') 48 | args_fid.write('----' * 10 + '\n') 49 | 50 | def print_args(self, name='PARAMETER TABLES'): 51 | params_dict = vars(self.__args) 52 | 53 | print('----' * 10) 54 | print('{0:^40}'.format(name)) 55 | print('----' * 10) 56 | for k, v in sorted(params_dict.items()): 57 | if '__' not in str(k): 58 | print('{}'.format(str(k)) + ' : ' + ('{0:>%d}' % (35 - len(str(k)))).format(str(v))) 59 | print('----' * 10) 60 | 61 | 62 | class Logger: 63 | def __init__(self, info=logging.INFO, name=__name__): 64 | logger = logging.getLogger(name) 65 | logger.setLevel(info) 66 | 67 | self.__logger = logger 68 | 69 | def get_logger(self, handler_type='stream_handler'): 70 | if handler_type == 'stream_handler': 71 | handler = logging.StreamHandler() 72 | log_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 73 | handler.setFormatter(log_format) 74 | else: 75 | handler = logging.FileHandler('utils.log') 76 | 77 | self.__logger.addHandler(handler) 78 | 79 | return self.__logger 80 | --------------------------------------------------------------------------------
step%s
%d%s