├── OpenGAN_logo.png
├── README.md
├── utils
│   ├── dataset_tinyimagenet.py
│   ├── dataset_tinyimagenet_3sets.py
│   ├── layers.py
│   ├── eval_funcs.py
│   ├── dataset_cifar10.py
│   ├── dataset_cityscapes.py
│   ├── dataset_cityscapes4OpenGAN.py
│   └── network_arch_tinyimagenet.py
└── demo_OpenSetSegmentation_training.ipynb


/OpenGAN_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimerykong/OpenGAN/HEAD/OpenGAN_logo.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## OpenGAN: Open-Set Recognition via Open Data Generation
2 | 
3 | ICCV 2021 ([best paper honorable mention](https://www.cs.cmu.edu/~shuk/OpenGAN.html))
4 | 
5 | ![alt text](https://github.com/aimerykong/OpenGAN/raw/main/OpenGAN_logo.png "video demo")
6 | 
7 | [[website](https://www.cs.cmu.edu/~shuk/OpenGAN.html)]
8 | [[poster](http://www.cs.cmu.edu/~shuk/img/OpenGAN_poster.pdf)]
9 | [[slides](http://www.cs.cmu.edu/~shuk/img/OpenGAN_slides.pdf)]
10 | [[oral presentation](https://youtu.be/CNYqYXyUHn0)]
11 | [[paper](https://arxiv.org/abs/2104.02939)]
12 | [[PAMI Version](https://github.com/aimerykong/aimerykong.github.io/raw/main/OpenGAN_files/PAMI_OpenGAN_accepted_version.pdf) 18MB]
13 | 
14 | Real-world machine learning systems need to analyze novel testing data that differs from the training data. In K-way classification, this is crisply formulated as open-set recognition, core to which is the ability to discriminate open-set data outside the K closed-set classes. Two conceptually elegant ideas for open-set discrimination are: 1) discriminatively learning an open-vs-closed binary discriminator by exploiting some outlier data as the open-set, and 2) unsupervised learning of the closed-set data distribution with a GAN, using its discriminator as the open-set likelihood function. However, the former generalizes poorly to diverse open test data because it overfits to the training outliers, which are unlikely to exhaustively span the open world. The latter does not work well, presumably due to the unstable training of GANs. Motivated by the above, we propose OpenGAN, which addresses the limitation of each approach by combining them with several technical insights. First, we show that a carefully selected GAN-discriminator on some real outlier data already achieves state-of-the-art performance. Second, we augment the available set of real open training examples with adversarially synthesized "fake" data.
15 | Third and most importantly, we build the discriminator over the features computed by the closed-world K-way networks.
16 | Extensive experiments show that OpenGAN significantly outperforms prior open-set methods.
17 | 
18 | 
19 | **keywords**: out-of-distribution detection, anomaly detection, open-set recognition, novelty detection, density estimation, generative model, discriminative model, adversarial learning, image classification, semantic segmentation. 
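For intuition, below is a minimal, hypothetical PyTorch sketch of the idea described above. It is not the released training code; the names (`train_step`, `closed_feats`, `outlier_feats`), network sizes, learning rates, and the `feat_dim`/`noise_dim` values are placeholders. It shows a binary open-vs-closed discriminator trained on closed-set features, features of a few real outlier images, and "fake" open-set features synthesized by a generator.

```python
import torch
import torch.nn as nn

# Hypothetical building blocks: a tiny MLP discriminator D over classifier features and a
# generator G that synthesizes "fake" open-set features from noise (dimensions are assumptions).
feat_dim, noise_dim = 512, 64
D = nn.Sequential(nn.Linear(feat_dim, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1), nn.Sigmoid())
G = nn.Sequential(nn.Linear(noise_dim, 128), nn.ReLU(), nn.Linear(128, feat_dim))
bce = nn.BCELoss()
optD = torch.optim.Adam(D.parameters(), lr=1e-4)
optG = torch.optim.Adam(G.parameters(), lr=1e-4)

def train_step(closed_feats, outlier_feats):
    """closed_feats / outlier_feats: features from the frozen K-way classifier (assumed given)."""
    ones = lambda n: torch.ones(n, 1)
    zeros = lambda n: torch.zeros(n, 1)
    fake_feats = G(torch.randn(closed_feats.size(0), noise_dim))

    # Discriminator: closed-set features -> 1; real outlier and generated features -> 0.
    d_loss = (bce(D(closed_feats), ones(closed_feats.size(0)))
              + bce(D(outlier_feats), zeros(outlier_feats.size(0)))
              + bce(D(fake_feats.detach()), zeros(fake_feats.size(0))))
    optD.zero_grad()
    d_loss.backward()
    optD.step()

    # Generator: synthesize features that D mistakes for closed-set data.
    g_loss = bce(D(fake_feats), ones(fake_feats.size(0)))
    optG.zero_grad()
    g_loss.backward()
    optG.step()
    return d_loss.item(), g_loss.item()

# At test time, D(feature) serves as the open-vs-closed score for a test image's feature.
```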
20 | 21 | 22 | If you find our model/method/dataset useful, please cite our work ([ICCV version on arxiv](https://arxiv.org/abs/2104.02939), [PAMI version](https://github.com/aimerykong/aimerykong.github.io/raw/main/OpenGAN_files/PAMI_OpenGAN_accepted_version.pdf)): 23 | 24 | @inproceedings{OpenGAN, 25 | title={OpenGAN: Open-Set Recognition via Open Data Generation}, 26 | author={Kong, Shu and Ramanan, Deva}, 27 | booktitle={ICCV}, 28 | year={2021} 29 | } 30 | 31 | @inproceedings{OpenGAN_PAMI, 32 | title={OpenGAN: Open-Set Recognition via Open Data Generation}, 33 | author={Kong, Shu and Ramanan, Deva}, 34 | booktitle={IEEE PAMI}, 35 | year={2022} 36 | } 37 | 38 | 39 | 40 | last update: July, 2021 41 | 42 | Shu Kong 43 | 44 | aimerykong At g-m-a-i-l dot com 45 | -------------------------------------------------------------------------------- /utils/dataset_tinyimagenet.py: -------------------------------------------------------------------------------- 1 | import os, random, time, copy 2 | from skimage import io, transform 3 | import numpy as np 4 | import os.path as path 5 | import scipy.io as sio 6 | from scipy import misc 7 | import matplotlib.pyplot as plt 8 | import PIL.Image 9 | import pickle 10 | import skimage.transform 11 | import csv 12 | import torch 13 | from torch.utils.data import Dataset, DataLoader 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | from torch.optim import lr_scheduler 17 | import torch.nn.functional as F 18 | from torch.autograd import Variable 19 | 20 | import torchvision 21 | from torchvision import datasets, models, transforms 22 | 23 | 24 | 25 | 26 | class TINYIMAGENET(Dataset): 27 | def __init__(self, size=(64,64), set_name='train', 28 | path_to_data='/scratch/shuk/dataset/tiny-imagenet-200', 29 | isAugment=True): 30 | 31 | self.path_to_data = path_to_data 32 | self.mapping_name2id = {} 33 | self.mapping_id2name = {} 34 | with open(path.join(self.path_to_data, 'wnids.txt')) as csv_file: 35 | csv_reader = csv.reader(csv_file, delimiter=' ') 36 | idx = 0 37 | for row in csv_reader: 38 | self.mapping_id2name[idx] = row[0] 39 | self.mapping_name2id[row[0]] = idx 40 | idx += 1 41 | 42 | 43 | if set_name=='test': set_name = 'val' 44 | 45 | self.size = size 46 | self.set_name = set_name 47 | self.path_to_data = path_to_data 48 | self.isAugment = isAugment 49 | 50 | self.imageNameList = [] 51 | self.className = [] 52 | self.labelList = [] 53 | self.mappingLabel2Name = dict() 54 | curLabel = 0 55 | 56 | 57 | if self.set_name == 'val': 58 | with open(path.join(self.path_to_data, 'val', 'val_annotations.txt')) as csv_file: 59 | csv_reader = csv.reader(csv_file, delimiter='\t') 60 | line_count = 0 61 | for row in csv_reader: 62 | self.imageNameList += [path.join(self.path_to_data, 'val', 'images', row[0])] 63 | self.labelList += [self.mapping_name2id[row[1]]] 64 | else: # 'train' 65 | self.current_class_dir = path.join(self.path_to_data, self.set_name) 66 | for curClass in os.listdir(self.current_class_dir): 67 | if curClass[0]=='.': continue 68 | 69 | curLabel = self.mapping_name2id[curClass] 70 | for curImg in os.listdir(path.join(self.current_class_dir, curClass, 'images')): 71 | if curImg[0]=='.': continue 72 | self.labelList += [curLabel] 73 | self.imageNameList += [path.join(self.path_to_data, self.set_name, curClass, 'images', curImg)] 74 | 75 | 76 | self.current_set_len = len(self.labelList) 77 | 78 | if self.set_name=='test' or self.set_name=='val' or not self.isAugment: 79 | self.transform = transforms.Compose([ 80 | transforms.ToTensor(), 
81 | transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)), 82 | ]) # ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 83 | else: 84 | self.transform = transforms.Compose([ 85 | transforms.RandomCrop(self.size[0], padding=4), 86 | transforms.RandomHorizontalFlip(), 87 | transforms.ToTensor(), 88 | transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)), 89 | ]) 90 | 91 | def __len__(self): 92 | return self.current_set_len 93 | 94 | def __getitem__(self, idx): 95 | curLabel = np.asarray(self.labelList[idx]) 96 | curImage = self.imageNameList[idx] 97 | curImage = PIL.Image.open(curImage).convert('RGB') 98 | curImage = self.transform(curImage) 99 | 100 | #print(idx, curLabel) 101 | 102 | #curLabel = torch.tensor([curLabel]).unsqueeze(0).unsqueeze(0) 103 | 104 | return curImage, curLabel -------------------------------------------------------------------------------- /utils/dataset_tinyimagenet_3sets.py: -------------------------------------------------------------------------------- 1 | import os, random, time, copy 2 | from skimage import io, transform 3 | import numpy as np 4 | import os.path as path 5 | import scipy.io as sio 6 | from scipy import misc 7 | import matplotlib.pyplot as plt 8 | import PIL.Image 9 | import pickle 10 | import skimage.transform 11 | import csv 12 | import torch 13 | from torch.utils.data import Dataset, DataLoader 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | from torch.optim import lr_scheduler 17 | import torch.nn.functional as F 18 | from torch.autograd import Variable 19 | 20 | import torchvision 21 | from torchvision import datasets, models, transforms 22 | 23 | 24 | 25 | 26 | class TINYIMAGENET(Dataset): 27 | def __init__(self, size=(64,64), set_name='train', 28 | path_to_data='/scratch/shuk/dataset/tiny-imagenet-200', 29 | isAugment=True): 30 | 31 | self.path_to_data = path_to_data 32 | self.mapping_name2id = {} 33 | self.mapping_id2name = {} 34 | with open(path.join(self.path_to_data, 'wnids.txt')) as csv_file: 35 | csv_reader = csv.reader(csv_file, delimiter=' ') 36 | idx = 0 37 | for row in csv_reader: 38 | self.mapping_id2name[idx] = row[0] 39 | self.mapping_name2id[row[0]] = idx 40 | idx += 1 41 | 42 | 43 | #if set_name=='test': set_name = 'val' 44 | 45 | self.size = size 46 | self.set_name = set_name 47 | self.path_to_data = path_to_data 48 | self.isAugment = isAugment 49 | 50 | self.imageNameList = [] 51 | self.className = [] 52 | self.labelList = [] 53 | self.mappingLabel2Name = dict() 54 | curLabel = 0 55 | 56 | if self.set_name == 'test': 57 | img_dir = os.path.join(self.path_to_data, 'val', 'images') 58 | for file_name in os.listdir(img_dir): 59 | if file_name[-4:] == 'JPEG': 60 | self.imageNameList += [path.join(self.path_to_data, 'val', 'images', file_name)] 61 | self.labelList += [0] 62 | 63 | elif self.set_name == 'val': 64 | with open(path.join(self.path_to_data, 'val', 'val_annotations.txt')) as csv_file: 65 | csv_reader = csv.reader(csv_file, delimiter='\t') 66 | line_count = 0 67 | for row in csv_reader: 68 | self.imageNameList += [path.join(self.path_to_data, 'val', 'images', row[0])] 69 | self.labelList += [self.mapping_name2id[row[1]]] 70 | #with open(path.join(self.path_to_data, 'val', 'val_annotations.txt')) as csv_file: 71 | # csv_reader = csv.reader(csv_file, delimiter='\t') 72 | # line_count = 0 73 | # for row in csv_reader: 74 | # self.imageNameList += [path.join(self.path_to_data, 'val', 'images', row[0])] 75 | # self.labelList += [self.mapping_name2id[row[1]]] 76 | 
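        # The 'train' branch below walks each wnid class folder under <path_to_data>/train/,
        # reading every file in its 'images/' subfolder and assigning the integer label from
        # mapping_name2id; 'test' above returns unlabeled val images (label 0), while 'val'
        # reads labels from val_annotations.txt.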
else: # 'train' 77 | self.current_class_dir = path.join(self.path_to_data, self.set_name) 78 | for curClass in os.listdir(self.current_class_dir): 79 | if curClass[0]=='.': continue 80 | 81 | curLabel = self.mapping_name2id[curClass] 82 | for curImg in os.listdir(path.join(self.current_class_dir, curClass, 'images')): 83 | if curImg[0]=='.': continue 84 | self.labelList += [curLabel] 85 | self.imageNameList += [path.join(self.path_to_data, self.set_name, curClass, 'images', curImg)] 86 | 87 | 88 | self.current_set_len = len(self.labelList) 89 | 90 | if self.set_name=='test' or self.set_name=='val' or not self.isAugment: 91 | self.transform = transforms.Compose([ 92 | transforms.ToTensor(), 93 | transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)), 94 | ]) # ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 95 | else: 96 | self.transform = transforms.Compose([ 97 | transforms.RandomCrop(self.size[0], padding=4), 98 | transforms.RandomHorizontalFlip(), 99 | transforms.ToTensor(), 100 | transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)), 101 | ]) 102 | 103 | def __len__(self): 104 | return self.current_set_len 105 | 106 | def __getitem__(self, idx): 107 | curLabel = np.asarray(self.labelList[idx]) 108 | curImage = self.imageNameList[idx] 109 | curImage = PIL.Image.open(curImage).convert('RGB') 110 | curImage = self.transform(curImage) 111 | 112 | return curImage, curLabel -------------------------------------------------------------------------------- /utils/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | from __future__ import absolute_import, division, print_function 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | def focus_to_intrinsics(opt, fxy): # 2-dim output fxy 14 | """Convert Nx2 focal length fxy into camera intrinsics 15 | for homogeneous representation 16 | Nx2 --> Nx4x4 17 | """ 18 | N, d = fxy.size() 19 | 20 | dummy_leftMat = np.zeros((4, 2), dtype=np.float32) 21 | dummy_leftMat[0, 0] = 1 22 | dummy_leftMat[1, 1] = 1 23 | dummy_rightMat = np.zeros((1, 4), dtype=np.float32) 24 | dummy_rightMat[0, 0] = 1 25 | dummy_rightMat[0, 1] = 1 26 | dummy_residual = np.zeros((4, 4), dtype=np.float32) 27 | dummy_residual[0, 2] = 0.5 28 | dummy_residual[1, 2] = 0.5 29 | dummy_residual[2, 2] = 1 30 | dummy_residual[3, 3] = 1 31 | dummy_identity = torch.eye(4).unsqueeze(0).expand(N, -1, -1) 32 | 33 | dummy_leftMat = torch.from_numpy(dummy_leftMat).unsqueeze(0).expand(N, -1, -1) 34 | dummy_rightMat = torch.from_numpy(dummy_rightMat) 35 | dummy_residual = torch.from_numpy(dummy_residual).unsqueeze(0).expand(N, -1, -1) 36 | 37 | if not opt.no_cuda: 38 | dummy_identity = dummy_identity.cuda() 39 | dummy_leftMat = dummy_leftMat.cuda() 40 | dummy_rightMat = dummy_rightMat.cuda() 41 | dummy_residual = dummy_residual.cuda() 42 | 43 | #print(dummy_leftMat.shape, fxy.shape) 44 | fxy = torch.matmul(dummy_leftMat, fxy.unsqueeze(-1)) # Nxd dxp --> Nxp 45 | fxy = torch.matmul(fxy, dummy_rightMat) # Nxd dxp --> Nxp 46 | 47 | fxy = fxy * dummy_identity + dummy_residual 48 | #print(fxy.shape) 49 | return fxy 50 | 51 | 52 | def disp_to_depth(disp, min_depth, max_depth): 53 | """Convert network's sigmoid output into depth predictcion 54 | The formula for this conversion is given in the 'additional considerations' 55 | section of the paper. 
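    Concretely, as implemented below:
    scaled_disp = 1/max_depth + (1/min_depth - 1/max_depth) * disp, and depth = 1 / scaled_disp.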
56 | """ 57 | min_disp = 1 / max_depth 58 | max_disp = 1 / min_depth 59 | scaled_disp = min_disp + (max_disp - min_disp) * disp 60 | depth = 1 / scaled_disp 61 | return scaled_disp, depth 62 | 63 | 64 | def transformation_from_parameters(axisangle, translation, invert=False): 65 | """Convert the network's (axisangle, translation) output into a 4x4 matrix 66 | """ 67 | R = rot_from_axisangle(axisangle) 68 | t = translation.clone() 69 | 70 | if invert: 71 | R = R.transpose(1, 2) 72 | t *= -1 73 | 74 | T = get_translation_matrix(t) 75 | 76 | if invert: 77 | M = torch.matmul(R, T) 78 | else: 79 | M = torch.matmul(T, R) 80 | 81 | return M 82 | 83 | 84 | def get_translation_matrix(translation_vector): 85 | """Convert a translation vector into a 4x4 transformation matrix 86 | """ 87 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device) 88 | 89 | t = translation_vector.contiguous().view(-1, 3, 1) 90 | 91 | T[:, 0, 0] = 1 92 | T[:, 1, 1] = 1 93 | T[:, 2, 2] = 1 94 | T[:, 3, 3] = 1 95 | T[:, :3, 3, None] = t 96 | 97 | return T 98 | 99 | 100 | def rot_from_axisangle(vec): 101 | """Convert an axisangle rotation into a 4x4 transformation matrix 102 | (adapted from https://github.com/Wallacoloo/printipi) 103 | Input 'vec' has to be Bx1x3 104 | """ 105 | angle = torch.norm(vec, 2, 2, True) # p=2, dim=2, keepdim=True 106 | axis = vec / (angle + 1e-7) 107 | 108 | ca = torch.cos(angle) 109 | sa = torch.sin(angle) 110 | C = 1 - ca 111 | 112 | x = axis[..., 0].unsqueeze(1) 113 | y = axis[..., 1].unsqueeze(1) 114 | z = axis[..., 2].unsqueeze(1) 115 | 116 | xs = x * sa 117 | ys = y * sa 118 | zs = z * sa 119 | xC = x * C 120 | yC = y * C 121 | zC = z * C 122 | xyC = x * yC 123 | yzC = y * zC 124 | zxC = z * xC 125 | 126 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device) 127 | 128 | rot[:, 0, 0] = torch.squeeze(x * xC + ca) 129 | rot[:, 0, 1] = torch.squeeze(xyC - zs) 130 | rot[:, 0, 2] = torch.squeeze(zxC + ys) 131 | rot[:, 1, 0] = torch.squeeze(xyC + zs) 132 | rot[:, 1, 1] = torch.squeeze(y * yC + ca) 133 | rot[:, 1, 2] = torch.squeeze(yzC - xs) 134 | rot[:, 2, 0] = torch.squeeze(zxC - ys) 135 | rot[:, 2, 1] = torch.squeeze(yzC + xs) 136 | rot[:, 2, 2] = torch.squeeze(z * zC + ca) 137 | rot[:, 3, 3] = 1 138 | 139 | return rot 140 | 141 | 142 | class ConvBlock(nn.Module): 143 | """Layer to perform a convolution followed by ELU 144 | """ 145 | def __init__(self, in_channels, out_channels): 146 | super(ConvBlock, self).__init__() 147 | 148 | self.conv = Conv3x3(in_channels, out_channels) 149 | self.nonlin = nn.ELU(inplace=True) 150 | #self.nonlin = nn.ReLU() 151 | 152 | def forward(self, x): 153 | out = self.conv(x) 154 | out = self.nonlin(out) 155 | return out 156 | 157 | 158 | class Conv3x3(nn.Module): 159 | """Layer to pad and convolve input 160 | """ 161 | def __init__(self, in_channels, out_channels, use_refl=True): 162 | super(Conv3x3, self).__init__() 163 | 164 | if use_refl: 165 | self.pad = nn.ReflectionPad2d(1) 166 | else: 167 | self.pad = nn.ZeroPad2d(1) 168 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) 169 | 170 | def forward(self, x): 171 | out = self.pad(x) 172 | out = self.conv(out) 173 | return out 174 | 175 | 176 | class BackprojectDepth(nn.Module): 177 | """Layer to transform a depth image into a point cloud 178 | """ 179 | def __init__(self, batch_size, height, width): 180 | super(BackprojectDepth, self).__init__() 181 | 182 | self.batch_size = batch_size 183 | self.height = height 184 | self.width = width 185 | 186 | meshgrid = 
np.meshgrid(range(self.width), range(self.height), indexing='xy') 187 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32) 188 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords)) 189 | 190 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width)) 191 | 192 | self.pix_coords = torch.unsqueeze(torch.stack( 193 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0) 194 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1) 195 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1)) 196 | 197 | def forward(self, depth, inv_K): 198 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords) 199 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points 200 | cam_points = torch.cat([cam_points, self.ones], 1) 201 | 202 | return cam_points 203 | 204 | 205 | class Project3D(nn.Module): 206 | """Layer which projects 3D points into a camera with intrinsics K and at position T 207 | """ 208 | def __init__(self, batch_size, height, width, eps=1e-7): 209 | super(Project3D, self).__init__() 210 | 211 | self.batch_size = batch_size 212 | self.height = height 213 | self.width = width 214 | self.eps = eps 215 | 216 | def forward(self, points, K, T): 217 | P = torch.matmul(K, T)[:, :3, :] 218 | 219 | cam_points = torch.matmul(P, points) 220 | 221 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps) 222 | pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width) 223 | pix_coords = pix_coords.permute(0, 2, 3, 1) 224 | pix_coords[..., 0] /= self.width - 1 225 | pix_coords[..., 1] /= self.height - 1 226 | pix_coords = (pix_coords - 0.5) * 2 227 | return pix_coords 228 | 229 | 230 | def upsample(x): 231 | """Upsample input tensor by a factor of 2 232 | """ 233 | return F.interpolate(x, scale_factor=2, mode="nearest") 234 | 235 | 236 | def get_smooth_loss(disp, img): 237 | """Computes the smoothness loss for a disparity image 238 | The color image is used for edge-aware smoothness 239 | """ 240 | grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:]) 241 | grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :]) 242 | 243 | grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True) 244 | grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True) 245 | 246 | grad_disp_x *= torch.exp(-grad_img_x) 247 | grad_disp_y *= torch.exp(-grad_img_y) 248 | 249 | return grad_disp_x.mean() + grad_disp_y.mean() 250 | 251 | 252 | class SSIM(nn.Module): 253 | """Layer to compute the SSIM loss between a pair of images 254 | """ 255 | def __init__(self): 256 | super(SSIM, self).__init__() 257 | self.mu_x_pool = nn.AvgPool2d(3, 1) 258 | self.mu_y_pool = nn.AvgPool2d(3, 1) 259 | self.sig_x_pool = nn.AvgPool2d(3, 1) 260 | self.sig_y_pool = nn.AvgPool2d(3, 1) 261 | self.sig_xy_pool = nn.AvgPool2d(3, 1) 262 | 263 | self.refl = nn.ReflectionPad2d(1) 264 | 265 | self.C1 = 0.01 ** 2 266 | self.C2 = 0.03 ** 2 267 | 268 | def forward(self, x, y): 269 | x = self.refl(x) 270 | y = self.refl(y) 271 | 272 | mu_x = self.mu_x_pool(x) 273 | mu_y = self.mu_y_pool(y) 274 | 275 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 276 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 277 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y 278 | 279 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) 280 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) 281 | 282 | return torch.clamp((1 - SSIM_n / 
SSIM_d) / 2, 0, 1) 283 | 284 | 285 | def compute_depth_errors(gt, pred): 286 | """Computation of error metrics between predicted and ground truth depths 287 | """ 288 | thresh = torch.max((gt / pred), (pred / gt)) 289 | a1 = (thresh < 1.25).float().mean() 290 | a2 = (thresh < 1.25 ** 2).float().mean() 291 | a3 = (thresh < 1.25 ** 3).float().mean() 292 | 293 | rmse = (gt - pred) ** 2 294 | rmse = torch.sqrt(rmse.mean()) 295 | 296 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2 297 | rmse_log = torch.sqrt(rmse_log.mean()) 298 | 299 | abs_rel = torch.mean(torch.abs(gt - pred) / gt) 300 | 301 | sq_rel = torch.mean((gt - pred) ** 2 / gt) 302 | 303 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 304 | -------------------------------------------------------------------------------- /utils/eval_funcs.py: -------------------------------------------------------------------------------- 1 | import os, random, time, copy 2 | from skimage import io, transform 3 | import numpy as np 4 | import os.path as path 5 | import scipy.io as sio 6 | import matplotlib.pyplot as plt 7 | from PIL import Image 8 | import PIL.Image 9 | from sklearn.metrics import roc_curve, roc_auc_score, f1_score 10 | import pandas as pd 11 | 12 | import torch 13 | from torch.utils.data import Dataset, DataLoader 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | from torch.optim import lr_scheduler 17 | import torch.nn.functional as F 18 | from torch.autograd import Variable 19 | 20 | import torchvision 21 | from torchvision import models, transforms 22 | 23 | import sklearn.metrics 24 | 25 | def F_measure(preds, labels, openset=False, theta=None): 26 | if openset: 27 | # f1 score for openset evaluation 28 | true_pos = 0. 29 | false_pos = 0. 30 | false_neg = 0. 31 | for i in range(len(labels)): 32 | true_pos += 1 if preds[i] == labels[i] and labels[i] != -1 else 0 33 | false_pos += 1 if preds[i] != labels[i] and labels[i] != -1 else 0 34 | false_neg += 1 if preds[i] != labels[i] and labels[i] == -1 else 0 35 | 36 | precision = true_pos / (true_pos + false_pos) 37 | recall = true_pos / (true_pos + false_neg) 38 | return 2 * ((precision * recall) / (precision + recall + 1e-12)) 39 | else: # Regular f1 score 40 | return f1_score(labels, preds, average='macro') 41 | 42 | # 43 | # ref: https://github.com/lwneal/counterfactual-open-set/blob/master/generativeopenset/evaluation.py 44 | class ClassCentroids(nn.Module): 45 | def __init__(self, num_classes=10, feat_dim=2, device='cpu'): 46 | super(ClassCentroids, self).__init__() 47 | self.num_classes = num_classes 48 | self.feat_dim = feat_dim 49 | self.centers = torch.randn(self.num_classes, self.feat_dim) 50 | #self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 51 | self.device = device 52 | if self.device!='cpu': 53 | self.centers.to(self.device) 54 | 55 | def forward(self, x, labels): 56 | batch_size = x.size(0) 57 | # ||x-y||_2 = (x-y)^2 = x^2 + y^2 - 2xy 58 | # This part of the calculation is “x^2+y^2” 59 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 60 | # This part is "x^2+y^2 - 2xy" 61 | distmat.addmm_(1, -2, x, self.centers.t()) 62 | 63 | classes = torch.arange(self.num_classes).long().to(self.device) 64 | if self.device!='cpu': 65 | classes = classes.to(self.device) 66 | 67 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 68 | mask = labels.eq(classes.expand(batch_size, 
self.num_classes)) 69 | 70 | self.curDistMat = distmat 71 | 72 | dist = distmat * mask.float() 73 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size 74 | 75 | return loss 76 | 77 | class CosCentroid(nn.Module): 78 | def __init__(self, num_classes=10, feat_dim=2, device='cpu'): 79 | super(CosCentroid, self).__init__() 80 | self.num_classes = num_classes 81 | self.feat_dim = feat_dim 82 | self.centers = torch.randn(self.num_classes, self.feat_dim) 83 | self.device = device 84 | #self.centers = F.normalize(self.centers, p=2, dim=1) 85 | if self.device!='cpu': 86 | self.centers.to(self.device) 87 | 88 | def forward(self, x, label=0): 89 | x = F.normalize(x, p=2, dim=1) 90 | distmat = torch.zeros((x.shape[0], self.centers.shape[0])).to(self.device) 91 | distmat.addmm_(0, -1, x, self.centers.t()) 92 | self.curDistMat = distmat 93 | return self.curDistMat 94 | 95 | 96 | 97 | def pca(X=np.array([]), no_dims=50): 98 | """ 99 | Runs PCA on the NxD array X in order to reduce its dimensionality to 100 | no_dims dimensions. 101 | """ 102 | 103 | print("Preprocessing the data using PCA...") 104 | (n, d) = X.shape 105 | m = np.mean(X, 0) 106 | X = X - np.tile(m, (n, 1)) 107 | (l, M) = np.linalg.eig(np.dot(X.T, X)) 108 | P = M[:, 0:no_dims] 109 | Y = np.dot(X, P) 110 | return Y, m, P 111 | 112 | 113 | 114 | 115 | 116 | def FetchFromSingleImage(curImg, cropSize=64, scaleList=[64, 78, 96, 128]): 117 | imgBatchList = [] 118 | 119 | for curSize in scaleList: 120 | curImg = curImg.resize((curSize, curSize)) 121 | curTransform = transforms.Compose([ 122 | transforms.TenCrop(cropSize, vertical_flip=False), 123 | transforms.Lambda(lambda crops: torch.stack([ 124 | transforms.ToTensor()(crop) for crop in crops])), # returns a 4D tensor 125 | transforms.Lambda(lambda crops: torch.stack([ 126 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))(crop) for crop in crops])), 127 | ]) 128 | imgBatchList += list(curTransform(curImg).unsqueeze(0)) 129 | 130 | 131 | curImg = curImg.resize((cropSize,cropSize)) 132 | curTransform = transforms.Compose([ 133 | transforms.ToTensor(), 134 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 135 | ]) 136 | 137 | imgBatchList += [curTransform(curImg).unsqueeze(0).clone()] 138 | imgBatchList += [curTransform(curImg.transpose(PIL.Image.FLIP_LEFT_RIGHT)).unsqueeze(0).clone()] 139 | imgBatchList = torch.cat(imgBatchList, 0) 140 | return imgBatchList 141 | 142 | 143 | 144 | class CustomizedPoolList(nn.Module): 145 | def __init__(self, poolSizeList=[32,32,16,8,4], poolType='max'): 146 | super(CustomizedPoolList, self).__init__() 147 | 148 | self.poolSizeList = poolSizeList 149 | self.poolType = poolType 150 | #self.linearLayers = OrderedDict() 151 | self.relu = nn.ReLU() 152 | #self.mnist_clsnet = nn.ModuleList(list(self.linearLayers.values())) 153 | 154 | def forward(self, feaList): 155 | x = [] 156 | if self.poolType=='max': 157 | for i in range(len(self.poolSizeList)): 158 | if self.poolSizeList[i]>0: 159 | x += [F.max_pool2d(feaList[i], self.poolSizeList[i])] 160 | elif self.poolType=='avg': 161 | for i in range(len(self.poolSizeList)): 162 | if self.poolSizeList[i]>0: 163 | x += [F.avg_pool2d(feaList[i], self.poolSizeList[i])] 164 | 165 | x = torch.cat(x, 1) 166 | x = x.view(x.shape[0], -1) 167 | return x 168 | 169 | 170 | 171 | class weightedL1Loss(nn.Module): 172 | def __init__(self, weight=1): 173 | # mean over all 174 | super(weightedL1Loss, self).__init__() 175 | self.loss = nn.L1Loss() 176 | self.weight = weight 177 | 178 
| def forward(self, inputs, target): 179 | lossValue = self.weight * self.loss(inputs, target) 180 | return lossValue 181 | 182 | 183 | 184 | 185 | class MetricLoss(nn.Module): 186 | """inner-class compactness, aka Center loss. 187 | 188 | Reference: 189 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 190 | 191 | Args: 192 | num_classes (int): number of classes. 193 | feat_dim (int): feature dimension. 194 | """ 195 | def __init__(self, num_classes=10, feat_dim=2, 196 | weightCompactness=0.2, 197 | weightInner=1, 198 | weightInter=1., 199 | marginAlpha=0.2, 200 | sepMultiplier=3, 201 | device='cpu'): 202 | super(MetricLoss, self).__init__() 203 | self.num_classes = num_classes 204 | self.feat_dim = feat_dim 205 | self.weightCompactness = weightCompactness 206 | self.weightInner = weightInner 207 | self.weightInter = weightInter 208 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 209 | self.device = device 210 | if self.device!='cpu': 211 | self.centers.to(self.device) 212 | self.curDistMat = 0 213 | self.lossInner = 0 214 | self.lossInter = 0 215 | self.marginAlpha = marginAlpha 216 | self.sepMultiplier = sepMultiplier 217 | self.classes = torch.arange(self.num_classes).long().to(self.device) 218 | 219 | def forward(self, x, labels): 220 | """ 221 | Args: 222 | x: feature matrix with shape (batch_size, feat_dim). 223 | labels: ground truth labels with shape (batch_size). 224 | """ 225 | batch_size = x.size(0) 226 | # ||x-y||_2 = (x-y)^2 = x^2 + y^2 - 2xy 227 | # This part of the calculation is “x^2+y^2” 228 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 229 | # This part is "x^2+y^2 - 2xy" 230 | distmat.addmm_(1, -2, x, self.centers.t()) 231 | 232 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 233 | mask = labels.eq(self.classes.expand(batch_size, self.num_classes)) 234 | 235 | self.curDistMat = distmat 236 | #print('self.curDistMat: ', self.curDistMat.shape) 237 | 238 | # inner loss 239 | dist = distmat * mask.float() 240 | self.lossInner = (dist-self.marginAlpha).clamp(min=0) 241 | self.lossInner = self.lossInner.mean()*self.weightInner # / batch_size 242 | 243 | # compactness loss 244 | loss = dist.clamp(min=1e-12, max=1e+12).mean() / batch_size 245 | 246 | # inter loss 247 | # distance between centroids should be at least three times larger than the defined margin alpha 248 | self.lossInter = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, self.num_classes) + torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, self.num_classes).t() 249 | self.lossInter.addmm_(1, -2, self.centers, self.centers.t()) 250 | tmpMask = 1-torch.eye(self.num_classes).float().to(self.device) 251 | #tmpMask = tmpMask.reshape((1, self.num_classes, self.num_classes)) 252 | #tmpMask = tmpMask.repeat(batch_size, 1, 1).to(self.device) 253 | self.lossInter = (self.marginAlpha*self.sepMultiplier-self.lossInter).clamp(min=0) 254 | self.lossInter = self.lossInter*tmpMask 255 | self.lossInter = self.lossInter.sum()*self.weightInter 256 | 257 | return loss*self.weightCompactness 258 | 259 | 260 | 261 | 262 | def evaluate_openset(scores_closeset, scores_openset): 263 | y_true = np.array([0] * len(scores_closeset) + [1] * len(scores_openset)) 264 | y_discriminator = np.concatenate([scores_closeset, scores_openset]) 265 | auc_d, roc_to_plot = 
plot_roc(y_true, y_discriminator, 'Discriminator ROC') 266 | return auc_d, roc_to_plot 267 | 268 | 269 | def plot_roc(y_true, y_score, title="Receiver Operating Characteristic", **options): 270 | fpr, tpr, thresholds = roc_curve(y_true, y_score) 271 | auc_score = roc_auc_score(y_true, y_score) 272 | roc_to_plot = {'tp':tpr, 'fp':fpr, 'thresh':thresholds, 'auc_score':auc_score} 273 | #plot = plot_xy(fpr, tpr, x_axis="False Positive Rate", y_axis="True Positive Rate", title=title) 274 | #if options.get('roc_output'): 275 | # print("Saving ROC scores to file") 276 | # np.save(options['roc_output'], (fpr, tpr)) 277 | #return auc_score, plot, roc_to_plot 278 | return auc_score, roc_to_plot 279 | 280 | 281 | def plot_xy(x, y, x_axis="X", y_axis="Y", title="Plot"): 282 | df = pd.DataFrame({'x': x, 'y': y}) 283 | plot = df.plot(x='x', y='y') 284 | 285 | plot.grid(b=True, which='major') 286 | plot.grid(b=True, which='minor') 287 | 288 | plot.set_title(title) 289 | plot.set_ylabel(y_axis) 290 | plot.set_xlabel(x_axis) 291 | return plot 292 | 293 | 294 | def backup_Weibull(): 295 | print("Weibull: computing features for all correctly-classified training data") 296 | activation_vectors = {} 297 | for images, labels in dataloader_train_closeset: 298 | images = images.to(device) 299 | labels = labels.type(torch.long).view(-1).to(device) 300 | 301 | embFeature = encoder(images) 302 | logits = clsModel(embFeature) 303 | #logits = F.softmax(logits, dim=1) 304 | 305 | correctly_labeled = (logits.data.max(1)[1] == labels) 306 | labels_np = labels.cpu().numpy() 307 | logits_np = logits.data.cpu().numpy() 308 | for i, label in enumerate(labels_np): 309 | if not correctly_labeled[i]: 310 | continue 311 | if label not in activation_vectors: 312 | activation_vectors[label] = [] 313 | activation_vectors[label].append(logits_np[i]) 314 | 315 | print("Computed activation_vectors for {} known classes".format(len(activation_vectors))) 316 | for class_idx in activation_vectors: 317 | print("Class {}: {} images".format(class_idx, len(activation_vectors[class_idx]))) 318 | 319 | # Compute a mean activation vector for each class 320 | print("Weibull computing mean activation vectors...") 321 | mean_activation_vectors = {} 322 | for class_idx in activation_vectors: 323 | mean_activation_vectors[class_idx] = np.array(activation_vectors[class_idx]).mean(axis=0) 324 | 325 | WEIBULL_TAIL_SIZE = 20 326 | # Initialize one libMR Wiebull object for each class 327 | print("Fitting Weibull to distance distribution of each class") 328 | weibulls = {} 329 | for class_idx in activation_vectors: 330 | distances = [] 331 | mav = mean_activation_vectors[class_idx] 332 | for v in activation_vectors[class_idx]: 333 | distances.append(np.linalg.norm(v - mav)) 334 | mr = libmr.MR() 335 | tail_size = min(len(distances), WEIBULL_TAIL_SIZE) 336 | mr.fit_high(distances, tail_size) 337 | weibulls[class_idx] = mr 338 | print("Weibull params for class {}: {}".format(class_idx, mr.get_params())) 339 | 340 | 341 | # Apply Weibull score to every logit 342 | weibull_scores_closeset = [] 343 | logits_closeset = [] 344 | classes = activation_vectors.keys() 345 | for images, labels in dataloader_test_closeset: 346 | images = images.to(device) 347 | labels = labels.type(torch.long).view(-1).to(device) 348 | embFeature = encoder(images) 349 | batch_logits = clsModel(embFeature).data.cpu().numpy() 350 | batch_weibull = np.zeros(shape=batch_logits.shape) 351 | for activation_vector in batch_logits: 352 | weibull_row = np.ones(len(classes)) 353 | for class_idx 
in classes: 354 | mav = mean_activation_vectors[class_idx] 355 | dist = np.linalg.norm(activation_vector - mav) 356 | weibull_row[class_idx] = 1 - weibulls[class_idx].w_score(dist) 357 | weibull_scores_closeset.append(weibull_row) 358 | logits_closeset.append(activation_vector) 359 | 360 | weibull_scores_closeset = np.array(weibull_scores_closeset) 361 | logits_closeset = np.array(logits_closeset) 362 | openmax_scores_closeset = -np.log(np.sum(np.exp(logits_closeset * weibull_scores_closeset), axis=1)) 363 | 364 | 365 | # Apply Weibull score to every logit 366 | weibull_scores_openset = [] 367 | logits_openset = [] 368 | classes = activation_vectors.keys() 369 | for images, labels in dataloader_test_openset: 370 | images = images.to(device) 371 | labels = labels.type(torch.long).view(-1).to(device) 372 | embFeature = encoder(images) 373 | batch_logits = clsModel(embFeature).data.cpu().numpy() 374 | batch_weibull = np.zeros(shape=batch_logits.shape) 375 | for activation_vector in batch_logits: 376 | weibull_row = np.ones(len(classes)) 377 | for class_idx in classes: 378 | mav = mean_activation_vectors[class_idx] 379 | dist = np.linalg.norm(activation_vector - mav) 380 | weibull_row[class_idx] = 1 - weibulls[class_idx].w_score(dist) 381 | weibull_scores_openset.append(weibull_row) 382 | logits_openset.append(activation_vector) 383 | 384 | weibull_scores_openset = np.array(weibull_scores_openset) 385 | logits_openset = np.array(logits_openset) 386 | openmax_scores_openset = -np.log(np.sum(np.exp(logits_openset * weibull_scores_openset), axis=1)) -------------------------------------------------------------------------------- /utils/dataset_cifar10.py: -------------------------------------------------------------------------------- 1 | import os, random, time, copy 2 | from skimage import io, transform 3 | import numpy as np 4 | import os.path as path 5 | import scipy.io as sio 6 | from scipy import misc 7 | import matplotlib.pyplot as plt 8 | import PIL.Image 9 | import pickle 10 | import skimage.transform 11 | 12 | import torch 13 | from torch.utils.data import Dataset, DataLoader 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | from torch.optim import lr_scheduler 17 | import torch.nn.functional as F 18 | from torch.autograd import Variable 19 | 20 | import torchvision 21 | from torchvision import datasets, models, transforms 22 | 23 | 24 | 25 | 26 | 27 | 28 | class CIFAR_OneClass4Train(Dataset): 29 | def __init__(self, size=(32,32), set_name='train', 30 | numKnown=6, numTotal=10, runIdx=0, 31 | classLabelIndex=0, 32 | path_to_data='/scratch/shuk/dataset/cifar10/cifar-10-batches-py', isOpenset=True, 33 | isAugment=True): 34 | self.classLabelIndex = classLabelIndex 35 | self.isAugment = isAugment 36 | self.set_name = set_name 37 | self.size = size 38 | self.numTotal = numTotal 39 | self.numKnown = numKnown 40 | self.runIdx = runIdx 41 | self.isOpenset = isOpenset 42 | self.path_to_data = path_to_data 43 | 44 | ######### get the data 45 | # train set 46 | curpath = path.join(self.path_to_data, 'data_batch_1') 47 | with open(curpath, 'rb') as fo: 48 | curpath = pickle.load(fo, encoding='bytes') 49 | 50 | self.imgList = curpath[b'data'].copy() 51 | self.labelList = curpath[b'labels'].copy() 52 | 53 | for i in range(2, 6): 54 | curpath = path.join(path_to_data, 'data_batch_{}'.format(i)) 55 | with open(curpath, 'rb') as fo: 56 | curpath = pickle.load(fo, encoding='bytes') 57 | self.imgList = np.concatenate((self.imgList, curpath[b'data'].copy())) 58 | self.labelList += 
curpath[b'labels'].copy() 59 | del curpath 60 | 61 | ####### set pre-processing operations 62 | self.transform = transforms.Compose([ 63 | transforms.ToTensor(), 64 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 65 | ]) 66 | #self.transform = transforms.Compose([ 67 | # transforms.RandomCrop(32, padding=4), 68 | # transforms.RandomHorizontalFlip(), 69 | # transforms.ToTensor(), 70 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 71 | #]) 72 | 73 | self.imgList = np.reshape(self.imgList, (self.imgList.shape[0], 3, 32, 32)) 74 | self.size = size 75 | self.labelList = np.asarray(self.labelList).astype(np.float32).reshape((-1, 1)) 76 | self.current_set_len = len(self.labelList) 77 | 78 | 79 | ########### shuffle for openset train-test data 80 | random.seed(0) 81 | 82 | self.randShuffleIndexSets = [] 83 | self.OpenSetSplit = [ 84 | [3, 6, 7, 8], 85 | [1, 2, 4, 6], 86 | [2, 3, 4, 9], 87 | [0, 1, 2, 6], 88 | [4, 5, 6, 9], 89 | [0, 2, 4, 6, 8, 9], # tinyImageNet 90 | ] 91 | for i in range(6): 92 | tmp = list(range(10)) 93 | tmpCloseset = list(set(tmp)-set(self.OpenSetSplit[i])) 94 | self.randShuffleIndexSets += [tmpCloseset+self.OpenSetSplit[i]] 95 | 96 | 97 | self.curShuffleSet = self.randShuffleIndexSets[runIdx] 98 | self.closesetActualLabels = self.curShuffleSet[:self.numKnown] 99 | self.opensetActualLabels = self.curShuffleSet[self.numKnown:] 100 | self.labelmapping = {} 101 | self.labelmapping_open = {} 102 | 103 | for i in range(len(self.closesetActualLabels)): 104 | self.labelmapping[self.closesetActualLabels[i]] = i 105 | for j in range(len(self.opensetActualLabels)): 106 | self.labelmapping_open[self.opensetActualLabels[j]] = self.numKnown + j 107 | 108 | self.validList = [] 109 | self.newLabel = [] 110 | for i in range(len(self.labelList)): 111 | if self.isOpenset: 112 | if self.labelList[i][0] in self.opensetActualLabels: 113 | self.validList += [i] 114 | self.newLabel += [self.labelmapping_open[self.labelList[i][0]]] 115 | else: 116 | if self.labelList[i][0] in self.closesetActualLabels: 117 | tmp_new_label = self.labelmapping[self.labelList[i][0]] 118 | if tmp_new_label==self.classLabelIndex: 119 | self.validList += [i] 120 | self.newLabel += [tmp_new_label] 121 | 122 | self.imgList = self.imgList[self.validList, :] 123 | self.labelList = np.asarray(self.newLabel).reshape((len(self.newLabel),1)) 124 | self.current_set_len = len(self.labelList) 125 | 126 | def __len__(self): 127 | return self.current_set_len 128 | 129 | def __getitem__(self, idx): 130 | curImage = self.imgList[idx,:] 131 | curLabel = self.labelList[idx].astype(np.float32) 132 | 133 | curImage = PIL.Image.fromarray(curImage.transpose(1,2,0)) 134 | curImage = self.transform(curImage) 135 | curLabel = torch.from_numpy(curLabel).unsqueeze(0).unsqueeze(0) 136 | 137 | return curImage, curLabel 138 | 139 | 140 | 141 | 142 | 143 | 144 | class CIFAR_OPENSET_CLS(Dataset): 145 | def __init__(self, size=(32,32), set_name='train', 146 | numKnown=6, numTotal=10, runIdx=0, 147 | path_to_data='/scratch/shuk/dataset/cifar10/cifar-10-batches-py', isOpenset=True, 148 | isAugment=True): 149 | 150 | if set_name=='val': 151 | set_name = 'test' 152 | 153 | self.isAugment = isAugment 154 | self.set_name = set_name 155 | self.size = size 156 | self.numTotal = numTotal 157 | self.numKnown = numKnown 158 | self.runIdx = runIdx 159 | self.isOpenset = isOpenset 160 | self.path_to_data = path_to_data 161 | 162 | ######### get the data 163 | if self.set_name=='test': 164 | self.imgList = 
path.join(self.path_to_data, 'test_batch') 165 | with open(self.imgList, 'rb') as fo: 166 | self.imgList = pickle.load(fo, encoding='bytes') 167 | self.labelList = self.imgList[b'labels'].copy() 168 | self.imgList = self.imgList[b'data'] 169 | else: # train set 170 | curpath = path.join(self.path_to_data, 'data_batch_1') 171 | with open(curpath, 'rb') as fo: 172 | curpath = pickle.load(fo, encoding='bytes') 173 | 174 | self.imgList = curpath[b'data'].copy() 175 | self.labelList = curpath[b'labels'].copy() 176 | 177 | for i in range(2, 6): 178 | curpath = path.join(path_to_data, 'data_batch_{}'.format(i)) 179 | with open(curpath, 'rb') as fo: 180 | curpath = pickle.load(fo, encoding='bytes') 181 | self.imgList = np.concatenate((self.imgList, curpath[b'data'].copy())) 182 | self.labelList += curpath[b'labels'].copy() 183 | del curpath 184 | 185 | 186 | ####### set pre-processing operations 187 | if self.set_name=='test' or not self.isAugment: 188 | self.transform = transforms.Compose([ 189 | transforms.ToTensor(), 190 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 191 | ]) 192 | else: 193 | self.transform = transforms.Compose([ 194 | transforms.RandomCrop(32, padding=4), 195 | transforms.RandomHorizontalFlip(), 196 | transforms.ToTensor(), 197 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 198 | ]) 199 | 200 | self.imgList = np.reshape(self.imgList, (self.imgList.shape[0], 3, 32, 32)) 201 | self.size = size 202 | self.labelList = np.asarray(self.labelList).astype(np.float32).reshape((-1, 1)) 203 | self.current_set_len = len(self.labelList) 204 | 205 | 206 | ########### shuffle for openset train-test data 207 | random.seed(0) 208 | 209 | self.randShuffleIndexSets = [] 210 | self.OpenSetSplit = [ 211 | [3, 6, 7, 8], 212 | [1, 2, 4, 6], 213 | [2, 3, 4, 9], 214 | [0, 1, 2, 6], 215 | [4, 5, 6, 9], 216 | [0, 2, 4, 7, 8, 9], # tinyImageNet 217 | ] 218 | for i in range(6): 219 | tmp = list(range(10)) 220 | tmpCloseset = list(set(tmp)-set(self.OpenSetSplit[i])) 221 | self.randShuffleIndexSets += [tmpCloseset+self.OpenSetSplit[i]] 222 | 223 | #for i in range(10): 224 | # a = list(range(10)) 225 | # random.shuffle(a) 226 | # self.randShuffleIndexSets += [a] 227 | 228 | 229 | self.curShuffleSet = self.randShuffleIndexSets[runIdx] 230 | self.closesetActualLabels = self.curShuffleSet[:self.numKnown] 231 | self.opensetActualLabels = self.curShuffleSet[self.numKnown:] 232 | self.labelmapping = {} 233 | self.labelmapping_open = {} 234 | 235 | for i in range(len(self.closesetActualLabels)): 236 | self.labelmapping[self.closesetActualLabels[i]] = i 237 | for j in range(len(self.opensetActualLabels)): 238 | self.labelmapping_open[self.opensetActualLabels[j]] = self.numKnown + j 239 | 240 | 241 | #self.imgList = np.loadtxt(self.path_to_csv, delimiter=",") 242 | #self.labelList = np.asfarray(self.imgList[:, :1]) 243 | #self.imgList = np.asfarray(self.imgList[:, 1:]) * self.fac + 0.01 244 | 245 | self.validList = [] 246 | self.newLabel = [] 247 | for i in range(len(self.labelList)): 248 | if self.isOpenset: 249 | if self.labelList[i][0] in self.opensetActualLabels: 250 | self.validList += [i] 251 | self.newLabel += [self.labelmapping_open[self.labelList[i][0]]] 252 | else: 253 | if self.labelList[i][0] in self.closesetActualLabels: 254 | self.validList += [i] 255 | self.newLabel += [self.labelmapping[self.labelList[i][0]]] 256 | 257 | self.imgList = self.imgList[self.validList, :] 258 | self.labelList = 
np.asarray(self.newLabel).reshape((len(self.newLabel),1)) 259 | self.current_set_len = len(self.labelList) 260 | 261 | def __len__(self): 262 | return self.current_set_len 263 | 264 | def __getitem__(self, idx): 265 | curImage = self.imgList[idx,:] 266 | curLabel = self.labelList[idx].astype(np.float32) 267 | 268 | #if self.isAugment: 269 | # curImage = PIL.Image.fromarray(curImage.transpose(1,2,0)) 270 | # curImage = self.transform(curImage) 271 | #else: 272 | # curImage = torch.from_numpy(curImage.astype(np.float32)) 273 | 274 | curImage = PIL.Image.fromarray(curImage.transpose(1,2,0)) 275 | curImage = self.transform(curImage) 276 | curLabel = torch.from_numpy(curLabel).unsqueeze(0).unsqueeze(0) 277 | 278 | ''' 279 | curImage = curImage.astype(np.float32) 280 | curLabel = curLabel.astype(np.float32) 281 | 282 | curImage = torch.from_numpy(curImage) 283 | curLabel = torch.from_numpy(curLabel).unsqueeze(0).unsqueeze(0) 284 | ''' 285 | return curImage, curLabel 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | class CIFAR10_CLS_full_aug(Dataset): 301 | def __init__(self, size=(32,32), set_name='train', path_to_data='/scratch/shuk/dataset/cifar10/cifar-10-batches-py', isAugment=True): 302 | if set_name=='val': 303 | set_name = 'test' 304 | self.set_name = set_name 305 | self.path_to_data = path_to_data 306 | self.isAugment = isAugment 307 | 308 | if self.set_name=='test': 309 | self.imgList = path.join(self.path_to_data, 'test_batch') 310 | with open(self.imgList, 'rb') as fo: 311 | self.imgList = pickle.load(fo, encoding='bytes') 312 | self.labelList = self.imgList[b'labels'].copy() 313 | self.imgList = self.imgList[b'data'] 314 | else: # train set 315 | curpath = path.join(self.path_to_data, 'data_batch_1') 316 | with open(curpath, 'rb') as fo: 317 | curpath = pickle.load(fo, encoding='bytes') 318 | 319 | self.imgList = curpath[b'data'].copy() 320 | self.labelList = curpath[b'labels'].copy() 321 | 322 | for i in range(2, 6): 323 | curpath = path.join(path_to_data, 'data_batch_{}'.format(i)) 324 | with open(curpath, 'rb') as fo: 325 | curpath = pickle.load(fo, encoding='bytes') 326 | self.imgList = np.concatenate((self.imgList, curpath[b'data'].copy())) 327 | self.labelList += curpath[b'labels'].copy() 328 | del curpath 329 | 330 | if self.set_name=='test': 331 | self.transform = transforms.Compose([ 332 | transforms.ToTensor(), 333 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 334 | ]) 335 | else: 336 | self.transform = transforms.Compose([ 337 | transforms.RandomCrop(32, padding=4), 338 | transforms.RandomHorizontalFlip(), 339 | transforms.ToTensor(), 340 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 341 | ]) 342 | 343 | self.imgList = np.reshape(self.imgList, (self.imgList.shape[0], 3, 32, 32)) 344 | self.size = size 345 | self.labelList = np.asarray(self.labelList).astype(np.float32) 346 | self.current_set_len = len(self.labelList) 347 | 348 | 349 | def __len__(self): 350 | return self.current_set_len 351 | 352 | def __getitem__(self, idx): 353 | curImage = self.imgList[idx] 354 | curLabel = np.asarray(self.labelList[idx]) 355 | 356 | if self.isAugment: 357 | curImage = PIL.Image.fromarray(curImage.transpose(1,2,0)) 358 | curImage = self.transform(curImage) 359 | else: 360 | curImage = torch.from_numpy(curImage) 361 | 362 | curLabel = torch.from_numpy(curLabel).unsqueeze(0).unsqueeze(0) 363 | 364 | return curImage, curLabel 365 | 366 | 367 | 368 | class CIFAR10_CLS_full(Dataset): 369 | 
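    # Plain full-CIFAR-10 loader (all 10 classes) with no torchvision transforms: raw uint8 pixels
    # are scaled by `fac = 0.99 / 255` and shifted by 0.01 (values roughly in [0.01, 1.0]), and
    # __getitem__ returns a (3, 32, 32) float tensor together with its label.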
def __init__(self, size=(32,32), set_name='train', path_to_data='/scratch/shuk/dataset/cifar10/cifar-10-batches-py'): 370 | if set_name=='val': 371 | set_name = 'test' 372 | self.set_name = set_name 373 | self.path_to_data = path_to_data 374 | 375 | if self.set_name=='test': 376 | self.imgList = path.join(self.path_to_data, 'test_batch') 377 | with open(self.imgList, 'rb') as fo: 378 | self.imgList = pickle.load(fo, encoding='bytes') 379 | self.labelList = self.imgList[b'labels'].copy() 380 | self.imgList = self.imgList[b'data'] 381 | else: # train set 382 | curpath = path.join(self.path_to_data, 'data_batch_1') 383 | with open(curpath, 'rb') as fo: 384 | curpath = pickle.load(fo, encoding='bytes') 385 | 386 | self.imgList = curpath[b'data'].copy() 387 | self.labelList = curpath[b'labels'].copy() 388 | 389 | for i in range(2, 6): 390 | curpath = path.join(path_to_data, 'data_batch_{}'.format(i)) 391 | with open(curpath, 'rb') as fo: 392 | curpath = pickle.load(fo, encoding='bytes') 393 | self.imgList = np.concatenate((self.imgList, curpath[b'data'].copy())) 394 | self.labelList += curpath[b'labels'].copy() 395 | del curpath 396 | 397 | self.imgList = np.reshape(self.imgList, (self.imgList.shape[0], 3, 32, 32)) 398 | self.size = size 399 | self.fac = 0.99 / 255 400 | self.labelList = np.asarray(self.labelList).astype(np.float32) 401 | self.imgList = self.imgList.astype(np.float32) * self.fac + 0.01 402 | self.current_set_len = len(self.labelList) 403 | 404 | def __len__(self): 405 | return self.current_set_len 406 | 407 | def __getitem__(self, idx): 408 | curImage = self.imgList[idx].astype(np.float32) 409 | curLabel = np.asarray(self.labelList[idx]).astype(np.float32) 410 | 411 | curImage = torch.from_numpy(curImage) 412 | curLabel = torch.from_numpy(curLabel).unsqueeze(0).unsqueeze(0) 413 | 414 | return curImage, curLabel 415 | 416 | 417 | -------------------------------------------------------------------------------- /utils/dataset_cityscapes.py: -------------------------------------------------------------------------------- 1 | import os, random, time, copy 2 | from skimage import io, transform 3 | import json 4 | import numpy as np 5 | from subprocess import check_output 6 | import numpy as np 7 | import os.path as path 8 | import scipy.io as sio 9 | from scipy import misc 10 | import matplotlib.pyplot as plt 11 | from PIL import Image 12 | from tqdm import tqdm 13 | import pickle 14 | import skimage.transform 15 | import csv 16 | import torch 17 | from torch.utils.data import Dataset, DataLoader 18 | import torch.nn as nn 19 | import torch.optim as optim 20 | from torch.optim import lr_scheduler 21 | import torch.nn.functional as F 22 | from torch.autograd import Variable 23 | import torchvision 24 | from torchvision import datasets, models, transforms 25 | from collections import namedtuple 26 | 27 | 28 | class Cityscapes(Dataset): 29 | """`Cityscapes `_ Dataset. 30 | 31 | Args: 32 | root (string): Root directory of dataset where directory ``leftImg8bit`` 33 | and ``gtFine`` or ``gtCoarse`` are located. 34 | split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine" 35 | otherwise ``train``, ``train_extra`` or ``val`` 36 | mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse`` 37 | target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon`` 38 | or ``color``. Can also be a list to output a tuple with all specified target types. 
39 | transform (callable, optional): A function/transform that takes in a PIL image 40 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 41 | target_transform (callable, optional): A function/transform that takes in the 42 | target and transforms it. 43 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 44 | and returns a transformed version. 45 | 46 | Examples: 47 | 48 | Get semantic segmentation target 49 | 50 | .. code-block:: python 51 | 52 | dataset = Cityscapes('./data/cityscapes', split='train', mode='fine', 53 | target_type='semantic') 54 | 55 | img, smnt = dataset[0] 56 | 57 | Get multiple targets 58 | 59 | .. code-block:: python 60 | 61 | dataset = Cityscapes('./data/cityscapes', split='train', mode='fine', 62 | target_type=['instance', 'color', 'polygon']) 63 | 64 | img, (inst, col, poly) = dataset[0] 65 | 66 | Validate on the "coarse" set 67 | 68 | .. code-block:: python 69 | 70 | dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse', 71 | target_type='semantic') 72 | 73 | img, smnt = dataset[0] 74 | """ 75 | 76 | # Based on https://github.com/mcordts/cityscapesScripts 77 | CityscapesClass = namedtuple('CityscapesClass', 78 | ['name', 'id', 'train_id', 'category', 'category_id', 79 | 'has_instances', 'ignore_in_eval', 'color']) 80 | 81 | classes = [ 82 | CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), 83 | CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), 84 | CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), 85 | CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), 86 | CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), 87 | CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), 88 | CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), 89 | CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), 90 | CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), 91 | CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), 92 | CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), 93 | CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), 94 | CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), 95 | CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), 96 | CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), 97 | CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), 98 | CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), 99 | CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), 100 | CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), 101 | CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), 102 | CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), 103 | CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), 104 | CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), 105 | CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), 106 | CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), 107 | CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), 108 | 
CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), 109 | CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), 110 | CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), 111 | CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), 112 | CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), 113 | CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), 114 | CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), 115 | CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), 116 | CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), 117 | ] 118 | 119 | def __init__(self, root='/home/skong2/restore/dataset/Cityscapes', 120 | newsize=(256, 256), 121 | split='train', 122 | mode='fine', 123 | target_type='semantic', 124 | transform=None, 125 | target_transform=None, 126 | transforms=None): 127 | 128 | #super(Cityscapes, self).__init__(root, transforms, transform, target_transform) 129 | self.newsize = newsize 130 | self.flagResize = True 131 | if newsize[0]<0 or newsize[1]<0: 132 | self.flagResize = False 133 | 134 | self.root = root 135 | self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' 136 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split) 137 | self.targets_dir = os.path.join(self.root, self.mode, split) 138 | self.target_type = target_type 139 | self.split = split 140 | self.images = [] 141 | self.targets = [] 142 | if self.split=='test': 143 | self.split = 'val' 144 | self.transform = transform 145 | self.transforms = transforms 146 | self.target_transform = target_transform 147 | 148 | #verify_str_arg(mode, "mode", ("fine", "coarse")) 149 | if mode == "fine": 150 | valid_modes = ("train", "test", "val") 151 | else: 152 | valid_modes = ("train", "train_extra", "val") 153 | 154 | msg = ("Unknown value '{}' for argument split if mode is '{}'. " 155 | "Valid values are {{{}}}.") 156 | 157 | #msg = msg.format(split, mode, iterable_to_str(valid_modes)) 158 | #verify_str_arg(split, "split", valid_modes, msg) 159 | 160 | if not isinstance(target_type, list): 161 | self.target_type = [target_type] 162 | 163 | #[verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color")) for value in self.target_type] 164 | 165 | if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): 166 | if split == 'train_extra': 167 | image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip')) 168 | else: 169 | image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip')) 170 | 171 | if self.mode == 'gtFine': 172 | target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip')) 173 | elif self.mode == 'gtCoarse': 174 | target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip')) 175 | 176 | if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip): 177 | extract_archive(from_path=image_dir_zip, to_path=self.root) 178 | extract_archive(from_path=target_dir_zip, to_path=self.root) 179 | else: 180 | raise RuntimeError('Dataset not found or incomplete. 
Please make sure all required folders for the' 181 | ' specified "split" and "mode" are inside the "root" directory') 182 | 183 | for city in os.listdir(self.images_dir): 184 | img_dir = os.path.join(self.images_dir, city) 185 | target_dir = os.path.join(self.targets_dir, city) 186 | for file_name in os.listdir(img_dir): 187 | target_types = [] 188 | for t in self.target_type: 189 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], 190 | self._get_target_suffix(self.mode, t)) 191 | target_types.append(os.path.join(target_dir, target_name)) 192 | 193 | self.images.append(os.path.join(img_dir, file_name)) 194 | self.targets.append(target_types) 195 | 196 | def __getitem__(self, index): 197 | """ 198 | Args: 199 | index (int): Index 200 | Returns: 201 | tuple: (image, target) where target is a tuple of all target types if target_type is a list with more 202 | than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. 203 | """ 204 | 205 | image = Image.open(self.images[index]).convert('RGB') 206 | #b, g, r = image.split() 207 | #image = Image.merge("RGB", (r, g, b)) 208 | 209 | if self.flagResize: 210 | image = image.resize(self.newsize, resample=Image.BILINEAR) 211 | #print(self.targets[index], self.target_type) 212 | 213 | targets = [] 214 | for i, t in enumerate(self.target_type): 215 | if t == 'polygon': 216 | target = self._load_json(self.targets[index][i]) 217 | else: 218 | target = Image.open(self.targets[index][i]) 219 | 220 | if self.flagResize: 221 | target = target.resize(self.newsize, resample=Image.NEAREST) 222 | 223 | targets.append(target) 224 | 225 | target = tuple(targets) if len(targets) > 1 else targets[0] 226 | 227 | image = self.transform(image) 228 | #print('image', type(image), image.shape) 229 | 230 | target = np.asarray(target).astype(np.float32) 231 | target = torch.from_numpy(target) 232 | #target = self.target_transform(target) 233 | #print('target', type(target), target.shape) 234 | 235 | if self.transforms is not None: 236 | image, target = self.transforms(image, target) 237 | 238 | return image, target 239 | 240 | 241 | def __len__(self): 242 | return len(self.images) 243 | 244 | def extra_repr(self): 245 | lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"] 246 | return '\n'.join(lines).format(**self.__dict__) 247 | 248 | def _load_json(self, path): 249 | with open(path, 'r') as file: 250 | data = json.load(file) 251 | return data 252 | 253 | def _get_target_suffix(self, mode, target_type): 254 | if target_type == 'instance': 255 | return '{}_instanceIds.png'.format(mode) 256 | elif target_type == 'semantic': 257 | return '{}_labelIds.png'.format(mode) 258 | elif target_type == 'color': 259 | return '{}_color.png'.format(mode) 260 | else: 261 | return '{}_polygons.json'.format(mode) 262 | 263 | 264 | 265 | 266 | 267 | ''' 268 | 269 | Label = namedtuple( 'Label' , [ 270 | 271 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 272 | # We use them to uniquely name a class 273 | 274 | 'id' , # An integer ID that is associated with this label. 275 | # The IDs are used to represent the label in ground truth images 276 | # An ID of -1 means that this label does not have an ID and thus 277 | # is ignored when creating ground truth images (e.g. license plate). 278 | # Do not modify these IDs, since exactly these IDs are expected by the 279 | # evaluation server. 280 | 281 | 'trainId' , # Feel free to modify these IDs as suitable for your method. 
Then create 282 | # ground truth images with train IDs, using the tools provided in the 283 | # 'preparation' folder. However, make sure to validate or submit results 284 | # to our evaluation server using the regular IDs above! 285 | # For trainIds, multiple labels might have the same ID. Then, these labels 286 | # are mapped to the same class in the ground truth images. For the inverse 287 | # mapping, we use the label that is defined first in the list below. 288 | # For example, mapping all void-type classes to the same ID in training, 289 | # might make sense for some approaches. 290 | # Max value is 255! 291 | 292 | 'category' , # The name of the category that this label belongs to 293 | 294 | 'categoryId' , # The ID of this category. Used to create ground truth images 295 | # on category level. 296 | 297 | 'hasInstances', # Whether this label distinguishes between single instances or not 298 | 299 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 300 | # during evaluations or not 301 | 302 | 'color' , # The color of this label 303 | ] ) 304 | 305 | 306 | 307 | labels = [ 308 | # name id trainId category catId hasInstances ignoreInEval color 309 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 310 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 311 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 312 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 313 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 314 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 315 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 316 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 317 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 318 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 319 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 320 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 321 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 322 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 323 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 324 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 325 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 326 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 327 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 328 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 329 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 330 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 331 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 332 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 333 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 334 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 335 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 336 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 337 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , 
True , False , ( 0, 60,100) ), 338 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 339 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 340 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 341 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 342 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 343 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 344 | ] 345 | 346 | 347 | ''' -------------------------------------------------------------------------------- /utils/dataset_cityscapes4OpenGAN.py: -------------------------------------------------------------------------------- 1 | import os, random, time, copy 2 | from skimage import io, transform 3 | import json 4 | import numpy as np 5 | from subprocess import check_output 6 | import numpy as np 7 | import os.path as path 8 | import scipy.io as sio 9 | from scipy import misc 10 | import matplotlib.pyplot as plt 11 | from PIL import Image 12 | from tqdm import tqdm 13 | import pickle 14 | import skimage.transform 15 | import csv 16 | import torch 17 | from torch.utils.data import Dataset, DataLoader 18 | import torch.nn as nn 19 | import torch.optim as optim 20 | from torch.optim import lr_scheduler 21 | import torch.nn.functional as F 22 | from torch.autograd import Variable 23 | import torchvision 24 | from torchvision import datasets, models, transforms 25 | from collections import namedtuple 26 | 27 | 28 | class Cityscapes(Dataset): 29 | """`Cityscapes `_ Dataset. 30 | 31 | Args: 32 | root (string): Root directory of dataset where directory ``leftImg8bit`` 33 | and ``gtFine`` or ``gtCoarse`` are located. 34 | split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine" 35 | otherwise ``train``, ``train_extra`` or ``val`` 36 | mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse`` 37 | target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon`` 38 | or ``color``. Can also be a list to output a tuple with all specified target types. 39 | transform (callable, optional): A function/transform that takes in a PIL image 40 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 41 | target_transform (callable, optional): A function/transform that takes in the 42 | target and transforms it. 43 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 44 | and returns a transformed version. 45 | 46 | Examples: 47 | 48 | Get semantic segmentation target 49 | 50 | .. code-block:: python 51 | 52 | dataset = Cityscapes('./data/cityscapes', split='train', mode='fine', 53 | target_type='semantic') 54 | 55 | img, smnt = dataset[0] 56 | 57 | Get multiple targets 58 | 59 | .. code-block:: python 60 | 61 | dataset = Cityscapes('./data/cityscapes', split='train', mode='fine', 62 | target_type=['instance', 'color', 'polygon']) 63 | 64 | img, (inst, col, poly) = dataset[0] 65 | 66 | Validate on the "coarse" set 67 | 68 | .. 
code-block:: python 69 | 70 | dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse', 71 | target_type='semantic') 72 | 73 | img, smnt = dataset[0] 74 | """ 75 | 76 | # Based on https://github.com/mcordts/cityscapesScripts 77 | CityscapesClass = namedtuple('CityscapesClass', 78 | ['name', 'id', 'train_id', 'category', 'category_id', 79 | 'has_instances', 'ignore_in_eval', 'color']) 80 | 81 | classes = [ 82 | CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), 83 | CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), 84 | CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), 85 | CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), 86 | CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), 87 | CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), 88 | CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), 89 | CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), 90 | CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), 91 | CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), 92 | CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), 93 | CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), 94 | CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), 95 | CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), 96 | CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), 97 | CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), 98 | CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), 99 | CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), 100 | CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), 101 | CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), 102 | CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), 103 | CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), 104 | CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), 105 | CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), 106 | CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), 107 | CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), 108 | CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), 109 | CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), 110 | CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), 111 | CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), 112 | CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), 113 | CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), 114 | CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), 115 | CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), 116 | CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), 117 | ] 118 | 119 | def __init__(self, root='/home/skong2/restore/dataset/Cityscapes', 120 | newsize=(256, 256), 121 | split='train', 122 | mode='fine', 123 | trainnum=10, 124 | target_type='semantic', 125 | 
transform=None, 126 | target_transform=None, 127 | transforms=None): 128 | 129 | #super(Cityscapes, self).__init__(root, transforms, transform, target_transform) 130 | self.newsize = newsize 131 | self.trainnum = trainnum 132 | self.flagResize = True 133 | if newsize[0]<0 or newsize[1]<0: 134 | self.flagResize = False 135 | 136 | self.root = root 137 | self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' 138 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split) 139 | self.targets_dir = os.path.join(self.root, self.mode, split) 140 | self.target_type = target_type 141 | self.split = split 142 | self.images = [] 143 | self.targets = [] 144 | if self.split=='test': 145 | self.split = 'val' 146 | self.transform = transform 147 | self.transforms = transforms 148 | self.target_transform = target_transform 149 | 150 | #verify_str_arg(mode, "mode", ("fine", "coarse")) 151 | if mode == "fine": 152 | valid_modes = ("train", "test", "val") 153 | else: 154 | valid_modes = ("train", "train_extra", "val") 155 | 156 | msg = ("Unknown value '{}' for argument split if mode is '{}'. " 157 | "Valid values are {{{}}}.") 158 | 159 | #msg = msg.format(split, mode, iterable_to_str(valid_modes)) 160 | #verify_str_arg(split, "split", valid_modes, msg) 161 | 162 | if not isinstance(target_type, list): 163 | self.target_type = [target_type] 164 | 165 | #[verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color")) for value in self.target_type] 166 | 167 | if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): 168 | if split == 'train_extra': 169 | image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip')) 170 | else: 171 | image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip')) 172 | 173 | if self.mode == 'gtFine': 174 | target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip')) 175 | elif self.mode == 'gtCoarse': 176 | target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip')) 177 | 178 | if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip): 179 | extract_archive(from_path=image_dir_zip, to_path=self.root) 180 | extract_archive(from_path=target_dir_zip, to_path=self.root) 181 | else: 182 | raise RuntimeError('Dataset not found or incomplete. 
Please make sure all required folders for the' 183 | ' specified "split" and "mode" are inside the "root" directory') 184 | 185 | for city in os.listdir(self.images_dir): 186 | img_dir = os.path.join(self.images_dir, city) 187 | target_dir = os.path.join(self.targets_dir, city) 188 | for file_name in os.listdir(img_dir): 189 | target_types = [] 190 | for t in self.target_type: 191 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], 192 | self._get_target_suffix(self.mode, t)) 193 | target_types.append(os.path.join(target_dir, target_name)) 194 | 195 | self.images.append(os.path.join(img_dir, file_name)) 196 | self.targets.append(target_types) 197 | 198 | if self.split=='train' and self.trainnum>0: 199 | self.images = self.images[:self.trainnum] 200 | self.targets = self.targets[:self.trainnum] 201 | elif self.split=='train' and self.trainnum<-5: 202 | self.images = self.images[self.trainnum:] 203 | self.targets = self.targets[self.trainnum:] 204 | else: 205 | self.trainnum = -1 206 | 207 | 208 | def __getitem__(self, index): 209 | """ 210 | Args: 211 | index (int): Index 212 | Returns: 213 | tuple: (image, target) where target is a tuple of all target types if target_type is a list with more 214 | than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. 215 | """ 216 | 217 | image = Image.open(self.images[index]).convert('RGB') 218 | #b, g, r = image.split() 219 | #image = Image.merge("RGB", (r, g, b)) 220 | 221 | if self.flagResize: 222 | image = image.resize(self.newsize, resample=Image.BILINEAR) 223 | #print(self.targets[index], self.target_type) 224 | 225 | targets = [] 226 | for i, t in enumerate(self.target_type): 227 | if t == 'polygon': 228 | target = self._load_json(self.targets[index][i]) 229 | else: 230 | target = Image.open(self.targets[index][i]) 231 | 232 | if self.flagResize: 233 | target = target.resize(self.newsize, resample=Image.NEAREST) 234 | 235 | targets.append(target) 236 | 237 | target = tuple(targets) if len(targets) > 1 else targets[0] 238 | 239 | image = self.transform(image) 240 | #print('image', type(image), image.shape) 241 | 242 | target = np.asarray(target).astype(np.float32) 243 | target = torch.from_numpy(target) 244 | #target = self.target_transform(target) 245 | #print('target', type(target), target.shape) 246 | 247 | if self.transforms is not None: 248 | image, target = self.transforms(image, target) 249 | 250 | return image, target 251 | 252 | 253 | def __len__(self): 254 | return len(self.images) 255 | 256 | def extra_repr(self): 257 | lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"] 258 | return '\n'.join(lines).format(**self.__dict__) 259 | 260 | def _load_json(self, path): 261 | with open(path, 'r') as file: 262 | data = json.load(file) 263 | return data 264 | 265 | def _get_target_suffix(self, mode, target_type): 266 | if target_type == 'instance': 267 | return '{}_instanceIds.png'.format(mode) 268 | elif target_type == 'semantic': 269 | return '{}_labelIds.png'.format(mode) 270 | elif target_type == 'color': 271 | return '{}_color.png'.format(mode) 272 | else: 273 | return '{}_polygons.json'.format(mode) 274 | 275 | 276 | 277 | 278 | 279 | ''' 280 | 281 | Label = namedtuple( 'Label' , [ 282 | 283 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 284 | # We use them to uniquely name a class 285 | 286 | 'id' , # An integer ID that is associated with this label. 
287 | # The IDs are used to represent the label in ground truth images 288 | # An ID of -1 means that this label does not have an ID and thus 289 | # is ignored when creating ground truth images (e.g. license plate). 290 | # Do not modify these IDs, since exactly these IDs are expected by the 291 | # evaluation server. 292 | 293 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 294 | # ground truth images with train IDs, using the tools provided in the 295 | # 'preparation' folder. However, make sure to validate or submit results 296 | # to our evaluation server using the regular IDs above! 297 | # For trainIds, multiple labels might have the same ID. Then, these labels 298 | # are mapped to the same class in the ground truth images. For the inverse 299 | # mapping, we use the label that is defined first in the list below. 300 | # For example, mapping all void-type classes to the same ID in training, 301 | # might make sense for some approaches. 302 | # Max value is 255! 303 | 304 | 'category' , # The name of the category that this label belongs to 305 | 306 | 'categoryId' , # The ID of this category. Used to create ground truth images 307 | # on category level. 308 | 309 | 'hasInstances', # Whether this label distinguishes between single instances or not 310 | 311 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 312 | # during evaluations or not 313 | 314 | 'color' , # The color of this label 315 | ] ) 316 | 317 | 318 | 319 | labels = [ 320 | # name id trainId category catId hasInstances ignoreInEval color 321 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 322 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 323 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 324 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 325 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 326 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 327 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 328 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 329 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 330 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 331 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 332 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 333 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 334 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 335 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 336 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 337 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 338 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 339 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 340 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 341 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 342 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 343 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 344 | Label( 'sky' , 23 , 10 , 
'sky' , 5 , False , False , ( 70,130,180) ), 345 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 346 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 347 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 348 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 349 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 350 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 351 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 352 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 353 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 354 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 355 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 356 | ] 357 | 358 | 359 | ''' -------------------------------------------------------------------------------- /utils/network_arch_tinyimagenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import numpy as np 3 | import torchvision 4 | from torchvision import datasets, models, transforms 5 | import torch 6 | import torch.nn as nn 7 | from collections import OrderedDict 8 | from utils.layers import * 9 | import torchvision.models as models 10 | import torch.utils.model_zoo as model_zoo 11 | import numpy as np 12 | import os, math 13 | from torch.utils.data import Dataset, DataLoader 14 | import torch.nn as nn 15 | 16 | 17 | 18 | 19 | class Discriminator80x80InstNorm(nn.Module): 20 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], frameStackNumber=3): 21 | super(Discriminator80x80InstNorm, self).__init__() 22 | self.device = device 23 | self.frameStackNumber = frameStackNumber 24 | self.patchSize = patchSize 25 | self.outputSize = [patchSize[0]/16, patchSize[1]/16] 26 | 27 | self.discriminator = nn.Sequential( 28 | # 128-->60 29 | nn.Conv2d(self.frameStackNumber, 64, kernel_size=5, padding=0, stride=2, bias=True), 30 | nn.LeakyReLU(0.2, inplace=True), 31 | 32 | # 60-->33 33 | nn.Conv2d(64, 128, kernel_size=5, padding=0, stride=2, bias=False), 34 | nn.InstanceNorm2d(128, momentum=0.001, affine=False, track_running_stats=False), 35 | nn.LeakyReLU(0.2, inplace=True), 36 | # 33-> 37 | nn.Conv2d(128, 256, kernel_size=3, padding=0, stride=2, bias=False), 38 | nn.InstanceNorm2d(256, momentum=0.001, affine=False, track_running_stats=False), 39 | nn.LeakyReLU(0.2, inplace=True), 40 | # 41 | nn.Conv2d(256, 512, kernel_size=3, padding=0, stride=2, bias=False), 42 | nn.InstanceNorm2d(512, momentum=0.001, affine=False, track_running_stats=False), 43 | nn.LeakyReLU(0.2, inplace=True), 44 | # final classification for 'real(1) vs. 
fake(0)' 45 | nn.Conv2d(512, 1, kernel_size=2, padding=0, stride=2, bias=True), 46 | nn.Sigmoid() 47 | ) 48 | 49 | def forward(self, X): 50 | return self.discriminator(X) 51 | 52 | 53 | 54 | class Discriminator80x80(nn.Module): 55 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], frameStackNumber=3): 56 | super(Discriminator80x80, self).__init__() 57 | self.device = device 58 | self.frameStackNumber = frameStackNumber 59 | self.patchSize = patchSize 60 | self.outputSize = [patchSize[0]/16, patchSize[1]/16] 61 | 62 | self.discriminator = nn.Sequential( 63 | # 128-->60 64 | nn.Conv2d(self.frameStackNumber, 64, kernel_size=5, padding=0, stride=2, bias=False), 65 | nn.LeakyReLU(0.2, inplace=True), 66 | 67 | # 60-->33 68 | nn.Conv2d(64, 128, kernel_size=5, padding=0, stride=2, bias=False), 69 | nn.BatchNorm2d(128), 70 | nn.LeakyReLU(0.2, inplace=True), 71 | # 33-> 72 | nn.Conv2d(128, 256, kernel_size=3, padding=0, stride=2, bias=False), 73 | nn.BatchNorm2d(256), 74 | nn.LeakyReLU(0.2, inplace=True), 75 | # 76 | nn.Conv2d(256, 512, kernel_size=3, padding=0, stride=2, bias=False), 77 | nn.BatchNorm2d(512), 78 | nn.LeakyReLU(0.2, inplace=True), 79 | # final classification for 'real(1) vs. fake(0)' 80 | nn.Conv2d(512, 1, kernel_size=2, padding=0, stride=2, bias=True), 81 | nn.Sigmoid() 82 | ) 83 | 84 | def forward(self, X): 85 | return self.discriminator(X) 86 | 87 | 88 | 89 | class Discriminator70x70(nn.Module): 90 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], frameStackNumber=3): 91 | super(Discriminator70x70, self).__init__() 92 | self.device = device 93 | self.frameStackNumber = frameStackNumber 94 | self.patchSize = patchSize 95 | self.outputSize = [patchSize[0]/16, patchSize[1]/16] 96 | 97 | self.discriminator = nn.Sequential( 98 | # 128-->60 99 | nn.Conv2d(self.frameStackNumber, 64, kernel_size=4, padding=0, stride=2, bias=False), 100 | nn.LeakyReLU(0.2, inplace=True), 101 | 102 | # 60-->33 103 | nn.Conv2d(64, 128, kernel_size=4, padding=0, stride=2, bias=False), 104 | nn.BatchNorm2d(128), 105 | nn.LeakyReLU(0.2, inplace=True), 106 | # 33-> 107 | nn.Conv2d(128, 256, kernel_size=4, padding=0, stride=2, bias=False), 108 | nn.BatchNorm2d(256), 109 | nn.LeakyReLU(0.2, inplace=True), 110 | # 111 | nn.Conv2d(256, 512, kernel_size=4, padding=0, stride=2, bias=False), 112 | nn.BatchNorm2d(512), 113 | nn.LeakyReLU(0.2, inplace=True), 114 | # final classification for 'real(1) vs. 
fake(0)' 115 | nn.Conv2d(512, 1, kernel_size=2, padding=0, stride=2, bias=True), 116 | nn.Sigmoid() 117 | ) 118 | 119 | def forward(self, X): 120 | return self.discriminator(X) 121 | 122 | 123 | class Discriminator(nn.Module): 124 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], frameStackNumber=3): 125 | super(Discriminator, self).__init__() 126 | self.device = device 127 | self.frameStackNumber = frameStackNumber 128 | self.patchSize = patchSize 129 | self.outputSize = [patchSize[0]/16, patchSize[1]/16] 130 | 131 | self.discriminator = nn.Sequential( 132 | # 128-->60 133 | nn.Conv2d(self.frameStackNumber, 64, kernel_size=9, padding=0, stride=2, bias=False), 134 | nn.LeakyReLU(0.2, inplace=True), 135 | 136 | # 60-->33 137 | nn.Conv2d(64, 128, kernel_size=5, padding=0, stride=2, bias=False), 138 | nn.BatchNorm2d(128), 139 | nn.LeakyReLU(0.2, inplace=True), 140 | # 33-> 141 | nn.Conv2d(128, 256, kernel_size=3, padding=0, stride=2, bias=False), 142 | nn.BatchNorm2d(256), 143 | nn.LeakyReLU(0.2, inplace=True), 144 | 145 | nn.Conv2d(256, 256, kernel_size=3, padding=0, stride=2, bias=False), 146 | nn.BatchNorm2d(256), 147 | nn.LeakyReLU(0.2, inplace=True), 148 | # dropout 149 | nn.Dropout(0.7), 150 | # final classification for 'real(1) vs. fake(0)' 151 | nn.Conv2d(256, 1, kernel_size=2, padding=0, stride=2, bias=True), 152 | nn.Sigmoid() 153 | ) 154 | 155 | def forward(self, X): 156 | return self.discriminator(X) 157 | 158 | 159 | 160 | 161 | 162 | class GAN_Encoder(nn.Module): 163 | def __init__(self, embDimension=512): 164 | super(self.__class__, self).__init__() 165 | 166 | self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) 167 | self.conv2 = nn.Conv2d(64, 128, 1, 2, 0, bias=False) 168 | self.conv3 = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 169 | self.conv4 = nn.Conv2d(128, 256, 1, 2, 0, bias=False) 170 | self.conv5 = nn.Conv2d(256, 256, 3, 1, 1, bias=False) 171 | self.conv6 = nn.Conv2d(256, 512, 1, 2, 0, bias=False) 172 | self.conv7 = nn.Conv2d(512, 512, 3, 1, 1, bias=False) 173 | self.conv8 = nn.Conv2d(512, 512, 1, 2, 0, bias=False) 174 | self.conv9 = nn.Conv2d(512, 512, 3, 1, 1, bias=False) 175 | self.conv10 = nn.Conv2d(512, 512, 1, 2, 0, bias=False) 176 | self.conv11 = nn.Conv2d(512, embDimension, 3, 1, 1, bias=False) 177 | 178 | self.bn1 = nn.BatchNorm2d(64) 179 | self.bn2 = nn.BatchNorm2d(128) 180 | self.bn3 = nn.BatchNorm2d(128) 181 | self.bn4 = nn.BatchNorm2d(256) 182 | self.bn5 = nn.BatchNorm2d(256) 183 | self.bn6 = nn.BatchNorm2d(512) 184 | self.bn7 = nn.BatchNorm2d(512) 185 | self.bn8 = nn.BatchNorm2d(512) 186 | self.bn9 = nn.BatchNorm2d(512) 187 | self.bn10 = nn.BatchNorm2d(512) 188 | self.bn11 = nn.BatchNorm2d(embDimension) 189 | 190 | self.apply(weights_init) 191 | 192 | 193 | def forward(self, x, output_scale=1): 194 | batch_size = len(x) 195 | 196 | x = self.conv1(x) 197 | x = self.bn1(x) 198 | x = nn.LeakyReLU(0.2)(x) 199 | x = self.conv2(x) 200 | x = self.bn2(x) 201 | x = nn.LeakyReLU(0.2)(x) 202 | x = self.conv3(x) 203 | x = self.bn3(x) 204 | x = nn.LeakyReLU(0.2)(x) 205 | 206 | x = self.conv4(x) 207 | x = self.bn4(x) 208 | x = nn.LeakyReLU(0.2)(x) 209 | x = self.conv5(x) 210 | x = self.bn5(x) 211 | x = nn.LeakyReLU(0.2)(x) 212 | x = self.conv6(x) 213 | x = self.bn6(x) 214 | x = nn.LeakyReLU(0.2)(x) 215 | 216 | x = self.conv7(x) 217 | x = self.bn7(x) 218 | x = nn.LeakyReLU(0.2)(x) 219 | x = self.conv8(x) 220 | x = self.bn8(x) 221 | x = nn.LeakyReLU(0.2)(x) 222 | x = self.conv9(x) 223 | x = self.bn9(x) 224 | x = nn.LeakyReLU(0.2)(x) 225 | 226 | x = 
self.conv10(x) 227 | x = self.bn10(x) 228 | x = nn.LeakyReLU(0.2)(x) 229 | 230 | return x 231 | 232 | 233 | class GAN_Decoder(nn.Module): 234 | def __init__(self, nz=64, ngf=64, nc=3): 235 | super(GAN_Decoder, self).__init__() 236 | 237 | # torch.nn.ConvTranspose2d( 238 | # in_channels, out_channels, kernel_size, 239 | # stride=1, padding=0, output_padding=0, groups=1, 240 | # bias=True, dilation=1, padding_mode='zeros') 241 | 242 | self.main = nn.Sequential( 243 | # input is Z, going into a convolution 244 | nn.ConvTranspose2d(nz, ngf*4, 4, 2, 1, bias=False), 245 | nn.BatchNorm2d(ngf * 4), 246 | nn.ReLU(True), 247 | # state size. (ngf*8) x 2 x 2 248 | nn.ConvTranspose2d(ngf*4, ngf*4, 4, 2, 1, bias=False), 249 | nn.BatchNorm2d(ngf * 4), 250 | nn.ReLU(True), 251 | # state size. (ngf*4) x 4 x 4 252 | nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False), 253 | nn.BatchNorm2d(ngf * 2), 254 | nn.ReLU(True), 255 | # state size. (ngf*2) x 8 x 8 256 | nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False), 257 | nn.BatchNorm2d(ngf), 258 | nn.ReLU(True), 259 | # state size. (ngf) x 16 x 16 260 | nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=True) 261 | #nn.Tanh() 262 | # state size. (nc) x 32 x 32 263 | ) 264 | 265 | def forward(self, x): 266 | return self.main(x) 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | def weights_init(m): 284 | classname = m.__class__.__name__ 285 | # TODO: what about fully-connected layers? 286 | if classname.find('Conv') != -1: 287 | m.weight.data.normal_(0.0, 0.05) 288 | elif classname.find('BatchNorm') != -1: 289 | m.weight.data.normal_(1.0, 0.02) 290 | m.bias.data.fill_(0) 291 | 292 | 293 | 294 | 295 | 296 | 297 | class MyDecoder(nn.Module): 298 | def __init__(self, latent_size=512, input_scale=4, insertConv=False): 299 | super(self.__class__, self).__init__() 300 | self.latent_size = latent_size 301 | self.input_scale = input_scale 302 | self.fc1 = nn.Linear(latent_size, 512*2*2, bias=False) 303 | self.insertConv = insertConv 304 | 305 | self.conv2_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=False) 306 | self.conv2 = nn.ConvTranspose2d( 512, 512, 4, stride=2, padding=1, bias=False) 307 | self.conv2_mid = nn.Conv2d(152, 512, 3, 1, 1, bias=False) 308 | 309 | self.conv3_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=False) 310 | self.conv3 = nn.ConvTranspose2d( 512, 256, 4, stride=2, padding=1, bias=False) 311 | self.conv3_mid = nn.Conv2d(256, 256, 3, 1, 1, bias=False) 312 | 313 | self.conv4_in = nn.ConvTranspose2d(latent_size, 256, 1, stride=1, padding=0, bias=False) 314 | self.conv4 = nn.ConvTranspose2d( 256, 128, 4, stride=2, padding=1, bias=False) 315 | self.conv4_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 316 | 317 | self.conv5 = nn.ConvTranspose2d( 128, 128, 4, stride=2, padding=1, bias=False) 318 | self.conv5_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 319 | 320 | 321 | self.conv6 = nn.ConvTranspose2d( 128, 3, 4, stride=2, padding=1, bias=True) 322 | 323 | 324 | self.bn1 = nn.BatchNorm2d(512) 325 | self.bn2 = nn.BatchNorm2d(512) 326 | self.bn2_mid = nn.BatchNorm2d(512) 327 | self.bn3 = nn.BatchNorm2d(256) 328 | self.bn3_mid = nn.BatchNorm2d(256) 329 | self.bn4 = nn.BatchNorm2d(128) 330 | self.bn4_mid = nn.BatchNorm2d(128) 331 | self.bn5 = nn.BatchNorm2d(128) 332 | self.bn5_mid = nn.BatchNorm2d(128) 333 | 334 | self.apply(weights_init) 335 | self.cuda() 336 | 337 | 338 | def forward(self, x): 339 | input_scale=self.input_scale 340 | batch_size = x.shape[0] 
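        # NOTE (added descriptive comment; not in the original source): the branching
        # below lets MyDecoder accept latent features at several spatial scales,
        # selected by `input_scale`:
        #   - input_scale <= 1: x is a flat (batch, latent_size) vector; fc1 expands it
        #     to 512*2*2 values, which are then reshaped into a 512 x 2 x 2 map
        #     (the reshape uses Tensor.resize here; .view/.reshape is the usual idiom
        #     in current PyTorch).
        #   - input_scale == 2, 4, or 8: x is reshaped to (batch, latent_size, s, s)
        #     and a 1x1 projection (conv2_in / conv3_in / conv4_in) maps it to the
        #     channel width expected by the next block.
        # From that scale onward, each stride-2 ConvTranspose2d (conv2 through conv6)
        # doubles the spatial resolution, with conv6 emitting the final 3-channel
        # output; the conv2_mid/conv3_mid/conv4_mid 3x3 layers are applied only when
        # insertConv=True.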
341 | 342 | if input_scale <= 1: 343 | x = self.fc1(x) 344 | x = x.resize(batch_size, 512, 2, 2) 345 | 346 | # 512 x 2 x 2 347 | if input_scale == 2: 348 | x = x.view(batch_size, self.latent_size, 2, 2) 349 | x = self.conv2_in(x) 350 | if input_scale <= 2: 351 | x = self.conv2(x) 352 | x = nn.LeakyReLU()(x) 353 | x = self.bn2(x) 354 | if self.insertConv: 355 | x = self.conv2_mid(x) 356 | x = nn.LeakyReLU()(x) 357 | x = self.bn2_mid(x) 358 | 359 | # 512 x 4 x 4 360 | if input_scale == 4: 361 | x = x.view(batch_size, self.latent_size, 4, 4) 362 | x = self.conv3_in(x) 363 | if input_scale <= 4: 364 | x = self.conv3(x) 365 | x = nn.LeakyReLU()(x) 366 | x = self.bn3(x) 367 | if self.insertConv: 368 | x = self.conv3_mid(x) 369 | x = nn.LeakyReLU()(x) 370 | x = self.bn3_mid(x) 371 | 372 | 373 | # 256 x 8 x 8 374 | if input_scale == 8: 375 | x = x.view(batch_size, self.latent_size, 8, 8) 376 | x = self.conv4_in(x) 377 | if input_scale <= 8: 378 | x = self.conv4(x) 379 | x = nn.LeakyReLU()(x) 380 | x = self.bn4(x) 381 | if self.insertConv: 382 | x = self.conv4_mid(x) 383 | x = nn.LeakyReLU()(x) 384 | x = self.bn4_mid(x) 385 | 386 | 387 | # 128 x 16 x 16 388 | x = self.conv5(x) 389 | x = nn.LeakyReLU()(x) 390 | x = self.bn5_mid(x) 391 | 392 | # 3 x 32 x 32 393 | #x = nn.Sigmoid()(x) 394 | 395 | x = self.conv6(x) 396 | return x 397 | 398 | 399 | 400 | 401 | 402 | class MySingleBigDecoder(nn.Module): 403 | def __init__(self, latent_size=512, input_scale=4, insertConv=False, nClasses=200): 404 | super(self.__class__, self).__init__() 405 | self.latent_size = latent_size 406 | self.input_scale = input_scale 407 | self.fc1 = nn.Linear(latent_size, 512*2*2, bias=False) 408 | self.insertConv = insertConv 409 | self.nClasses = nClasses 410 | 411 | self.conv2_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=False) 412 | self.conv2 = nn.ConvTranspose2d( 512, 512, 4, stride=2, padding=1, bias=False) 413 | self.conv2_mid = nn.Conv2d(152, 512, 3, 1, 1, bias=False) 414 | 415 | self.conv3_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=False) 416 | self.conv3 = nn.ConvTranspose2d( 512, 256, 4, stride=2, padding=1, bias=False) 417 | self.conv3_mid = nn.Conv2d(256, 256, 3, 1, 1, bias=False) 418 | 419 | self.conv4_in = nn.ConvTranspose2d(latent_size, 256, 1, stride=1, padding=0, bias=False) 420 | self.conv4 = nn.ConvTranspose2d( 256, 128, 4, stride=2, padding=1, bias=False) 421 | self.conv4_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 422 | 423 | self.conv5 = nn.ConvTranspose2d( 128, 128, 4, stride=2, padding=1, bias=False) 424 | self.conv5_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 425 | 426 | 427 | self.conv6 = nn.ConvTranspose2d( 128, 3*nClasses, 4, stride=2, padding=1, bias=True) 428 | 429 | 430 | self.bn1 = nn.BatchNorm2d(512) 431 | self.bn2 = nn.BatchNorm2d(512) 432 | self.bn2_mid = nn.BatchNorm2d(512) 433 | self.bn3 = nn.BatchNorm2d(256) 434 | self.bn3_mid = nn.BatchNorm2d(256) 435 | self.bn4 = nn.BatchNorm2d(128) 436 | self.bn4_mid = nn.BatchNorm2d(128) 437 | self.bn5 = nn.BatchNorm2d(128) 438 | self.bn5_mid = nn.BatchNorm2d(128) 439 | 440 | self.apply(weights_init) 441 | self.cuda() 442 | 443 | 444 | def forward(self, x): 445 | input_scale=self.input_scale 446 | batch_size = x.shape[0] 447 | 448 | if input_scale <= 1: 449 | x = self.fc1(x) 450 | x = x.resize(batch_size, 512, 2, 2) 451 | 452 | # 512 x 2 x 2 453 | if input_scale == 2: 454 | x = x.view(batch_size, self.latent_size, 2, 2) 455 | x = self.conv2_in(x) 456 | if input_scale <= 2: 457 | x = 
self.conv2(x) 458 | x = nn.LeakyReLU()(x) 459 | x = self.bn2(x) 460 | if self.insertConv: 461 | x = self.conv2_mid(x) 462 | x = nn.LeakyReLU()(x) 463 | x = self.bn2_mid(x) 464 | 465 | # 512 x 4 x 4 466 | if input_scale == 4: 467 | x = x.view(batch_size, self.latent_size, 4, 4) 468 | x = self.conv3_in(x) 469 | if input_scale <= 4: 470 | x = self.conv3(x) 471 | x = nn.LeakyReLU()(x) 472 | x = self.bn3(x) 473 | if self.insertConv: 474 | x = self.conv3_mid(x) 475 | x = nn.LeakyReLU()(x) 476 | x = self.bn3_mid(x) 477 | 478 | 479 | # 256 x 8 x 8 480 | if input_scale == 8: 481 | x = x.view(batch_size, self.latent_size, 8, 8) 482 | x = self.conv4_in(x) 483 | if input_scale <= 8: 484 | x = self.conv4(x) 485 | x = nn.LeakyReLU()(x) 486 | x = self.bn4(x) 487 | if self.insertConv: 488 | x = self.conv4_mid(x) 489 | x = nn.LeakyReLU()(x) 490 | x = self.bn4_mid(x) 491 | 492 | 493 | # 128 x 16 x 16 494 | x = self.conv5(x) 495 | x = nn.LeakyReLU()(x) 496 | x = self.bn5_mid(x) 497 | 498 | # 3 x 32 x 32 499 | #x = nn.Sigmoid()(x) 500 | 501 | x = self.conv6(x) 502 | return x 503 | 504 | 505 | 506 | 507 | 508 | 509 | class MyDecoder_noBN(nn.Module): 510 | def __init__(self, latent_size=512, input_scale=4, insertConv=False): 511 | super(self.__class__, self).__init__() 512 | self.latent_size = latent_size 513 | self.input_scale = input_scale 514 | self.fc1 = nn.Linear(latent_size, 512*2*2, bias=False) 515 | self.insertConv = insertConv 516 | 517 | self.conv2_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=True) 518 | self.conv2 = nn.ConvTranspose2d( 512, 512, 4, stride=2, padding=1, bias=True) 519 | self.conv2_mid = nn.Conv2d(152, 512, 3, 1, 1, bias=True) 520 | 521 | self.conv3_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=True) 522 | self.conv3 = nn.ConvTranspose2d( 512, 256, 4, stride=2, padding=1, bias=True) 523 | self.conv3_mid = nn.Conv2d(256, 256, 3, 1, 1, bias=True) 524 | 525 | self.conv4_in = nn.ConvTranspose2d(latent_size, 256, 1, stride=1, padding=0, bias=True) 526 | self.conv4 = nn.ConvTranspose2d( 256, 128, 4, stride=2, padding=1, bias=True) 527 | self.conv4_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=True) 528 | 529 | self.conv5 = nn.ConvTranspose2d( 128, 3, 4, stride=2, padding=1, bias=True) 530 | self.conv5_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 531 | 532 | self.conv6 = nn.ConvTranspose2d( 128, 3, 4, stride=2, padding=1, bias=True) 533 | 534 | #self.bn1 = nn.BatchNorm2d(512) 535 | #self.bn2 = nn.BatchNorm2d(512) 536 | #self.bn2_mid = nn.BatchNorm2d(512) 537 | #self.bn3 = nn.BatchNorm2d(256) 538 | #self.bn3_mid = nn.BatchNorm2d(256) 539 | #self.bn4 = nn.BatchNorm2d(128) 540 | #self.bn4_mid = nn.BatchNorm2d(128) 541 | 542 | self.apply(weights_init) 543 | self.cuda() 544 | 545 | def forward(self, x): 546 | input_scale=self.input_scale 547 | batch_size = x.shape[0] 548 | 549 | if input_scale <= 1: 550 | x = self.fc1(x) 551 | x = x.resize(batch_size, 512, 2, 2) 552 | 553 | # 512 x 2 x 2 554 | if input_scale == 2: 555 | x = x.view(batch_size, self.latent_size, 2, 2) 556 | x = self.conv2_in(x) 557 | if input_scale <= 2: 558 | x = self.conv2(x) 559 | x = nn.LeakyReLU()(x) 560 | #x = self.bn2(x) 561 | if self.insertConv: 562 | x = self.conv2_mid(x) 563 | x = nn.LeakyReLU()(x) 564 | #x = self.bn2_mid(x) 565 | 566 | # 512 x 4 x 4 567 | if input_scale == 4: 568 | x = x.view(batch_size, self.latent_size, 4, 4) 569 | x = self.conv3_in(x) 570 | x = nn.LeakyReLU()(x) 571 | if input_scale <= 4: 572 | x = self.conv3(x) 573 | x = nn.LeakyReLU()(x) 574 | #x = 
self.bn3(x) 575 | if self.insertConv: 576 | x = self.conv3_mid(x) 577 | x = nn.LeakyReLU()(x) 578 | #x = self.bn3_mid(x) 579 | 580 | 581 | # 256 x 8 x 8 582 | if input_scale == 8: 583 | x = x.view(batch_size, self.latent_size, 8, 8) 584 | x = self.conv4_in(x) 585 | if input_scale <= 8: 586 | x = self.conv4(x) 587 | x = nn.LeakyReLU()(x) 588 | #x = self.bn4(x) 589 | if self.insertConv: 590 | x = self.conv4_mid(x) 591 | x = nn.LeakyReLU()(x) 592 | #x = self.bn4_mid(x) 593 | 594 | 595 | # 128 x 16 x 16 596 | x = self.conv5(x) 597 | x = nn.LeakyReLU()(x) 598 | #x = self.bn5_mid(x) 599 | 600 | # 3 x 32 x 32 601 | #x = nn.Sigmoid()(x) 602 | 603 | x = self.conv6(x) 604 | return x 605 | 606 | 607 | 608 | 609 | 610 | 611 | 612 | 613 | class classifier32(nn.Module): 614 | def __init__(self, latent_size=100, num_classes=2, batch_size=64, return_feat=True): 615 | super(self.__class__, self).__init__() 616 | self.return_feat = return_feat 617 | 618 | self.batch_size = batch_size 619 | self.num_classes = num_classes 620 | self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) 621 | self.conv2 = nn.Conv2d(64, 64, 3, 1, 1, bias=False) 622 | self.conv3 = nn.Conv2d(64, 128, 3, 2, 1, bias=False) 623 | 624 | self.conv4 = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 625 | self.conv5 = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 626 | self.conv6 = nn.Conv2d(128, 128, 3, 2, 1, bias=False) 627 | 628 | self.conv7 = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 629 | self.conv8 = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 630 | self.conv9 = nn.Conv2d(128, 128, 3, 2, 1, bias=False) 631 | 632 | self.bn1 = nn.BatchNorm2d(64) 633 | self.bn2 = nn.BatchNorm2d(64) 634 | self.bn3 = nn.BatchNorm2d(128) 635 | 636 | self.bn4 = nn.BatchNorm2d(128) 637 | self.bn5 = nn.BatchNorm2d(128) 638 | self.bn6 = nn.BatchNorm2d(128) 639 | 640 | self.bn7 = nn.BatchNorm2d(128) 641 | self.bn8 = nn.BatchNorm2d(128) 642 | self.bn9 = nn.BatchNorm2d(128) 643 | 644 | self.fc1 = nn.Linear(128*4*4, num_classes) 645 | self.dr1 = nn.Dropout2d(0.2) 646 | self.dr2 = nn.Dropout2d(0.2) 647 | self.dr3 = nn.Dropout2d(0.2) 648 | 649 | self.apply(weights_init) 650 | self.cuda() 651 | 652 | def forward(self, x, return_features=False): 653 | batch_size = len(x) 654 | 655 | x = self.dr1(x) 656 | x = self.conv1(x) 657 | x = self.bn1(x) 658 | x = nn.LeakyReLU(0.2)(x) 659 | x = self.conv2(x) 660 | x = self.bn2(x) 661 | x = nn.LeakyReLU(0.2)(x) 662 | x = self.conv3(x) 663 | x = self.bn3(x) 664 | x = nn.LeakyReLU(0.2)(x) 665 | 666 | x = self.dr2(x) 667 | x = self.conv4(x) 668 | x = self.bn4(x) 669 | x = nn.LeakyReLU(0.2)(x) 670 | x = self.conv5(x) 671 | x = self.bn5(x) 672 | x = nn.LeakyReLU(0.2)(x) 673 | x = self.conv6(x) 674 | x = self.bn6(x) 675 | x = nn.LeakyReLU(0.2)(x) 676 | 677 | x = self.dr3(x) 678 | x = self.conv7(x) 679 | x = self.bn7(x) 680 | x = nn.LeakyReLU(0.2)(x) 681 | x = self.conv8(x) 682 | x = self.bn8(x) 683 | x = nn.LeakyReLU(0.2)(x) 684 | x = self.conv9(x) 685 | x = self.bn9(x) 686 | x = nn.LeakyReLU(0.2)(x) 687 | 688 | x = x.view(batch_size, -1) 689 | if self.return_feat: 690 | return x 691 | x = self.fc1(x) 692 | return x 693 | 694 | 695 | 696 | class ResnetEncoder(nn.Module): 697 | """Pytorch module for a resnet encoder 698 | """ 699 | def __init__(self, num_layers=18, isPretrained=False, isGrayscale=False, embDimension=128, poolSize=4): 700 | super(ResnetEncoder, self).__init__() 701 | self.path_to_model = '../models' 702 | self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 703 | self.isGrayscale = isGrayscale 704 | self.isPretrained = isPretrained 705 | 
self.embDimension = embDimension 706 | self.poolSize = poolSize 707 | self.featListName = ['layer0', 'layer1', 'layer2', 'layer3', 'layer4'] 708 | 709 | resnets = { 710 | 18: models.resnet18, 711 | 34: models.resnet34, 712 | 50: models.resnet50, 713 | 101: models.resnet101, 714 | 152: models.resnet152} 715 | 716 | resnets_pretrained_path = { 717 | 18: 'resnet18-5c106cde.pth', 718 | 34: 'resnet34.pth', 719 | 50: 'resnet50.pth', 720 | 101: 'resnet101.pth', 721 | 152: 'resnet152.pth'} 722 | 723 | if num_layers not in resnets: 724 | raise ValueError("{} is not a valid number of resnet layers".format( 725 | num_layers)) 726 | 727 | self.encoder = resnets[num_layers]() 728 | 729 | if self.embDimension>0: 730 | self.encoder.linear = nn.Linear(self.num_ch_enc[-1], self.embDimension) 731 | 732 | if self.isPretrained: 733 | print("using pretrained model") 734 | self.encoder.load_state_dict( 735 | torch.load(os.path.join(self.path_to_model, resnets_pretrained_path[num_layers]))) 736 | 737 | #if self.isGrayscale: 738 | # self.encoder.conv1 = nn.Conv2d( 739 | # 1, 64, kernel_size=3, stride=1, padding=1, bias=False) 740 | #else: 741 | # self.encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 742 | 743 | if num_layers > 34: 744 | self.num_ch_enc[1:] *= 4 745 | 746 | def forward(self, input_image): 747 | self.features = [] 748 | 749 | x = self.encoder.conv1(input_image) 750 | x = self.encoder.bn1(x) 751 | x = self.encoder.relu(x) 752 | self.features.append(x) 753 | 754 | #x = self.encoder.layer1(self.encoder.maxpool(x)) # 755 | x = self.encoder.layer1(x) # self.encoder.maxpool(x) 756 | self.features.append(x) 757 | #print('layer1: ', x.shape) 758 | 759 | x = self.encoder.layer2(x) 760 | self.features.append(x) 761 | #print('layer2: ', x.shape) 762 | 763 | x = self.encoder.layer3(x) 764 | self.features.append(x) 765 | #print('layer3: ', x.shape) 766 | 767 | x = self.encoder.layer4(x) 768 | self.features.append(x) 769 | #print('layer4: ', x.shape) 770 | 771 | x = F.avg_pool2d(x, self.poolSize) 772 | #print('global pool: ', x.shape) 773 | 774 | x = x.view(x.size(0), -1) 775 | #print('reshape: ', x.shape) 776 | 777 | if self.embDimension>0: 778 | x = self.encoder.linear(x) 779 | #print('final: ', x.shape) 780 | return x 781 | 782 | 783 | 784 | class TinyImageNet_ClsNet(nn.Module): 785 | def __init__(self, nClass=10, layerList=(64, 32)): 786 | super(TinyImageNet_ClsNet, self).__init__() 787 | 788 | self.nClass = nClass 789 | self.layerList = layerList 790 | self.linearLayers = OrderedDict() 791 | self.relu = nn.ReLU() 792 | i=-1 793 | for i in range(len(layerList)-1): 794 | self.linearLayers[i] = nn.Linear(self.layerList[i], self.layerList[i+1]) 795 | self.linearLayers[i+1] = nn.Linear(self.layerList[-1], self.nClass) 796 | self.mnist_clsnet = nn.ModuleList(list(self.linearLayers.values())) 797 | 798 | def forward(self, x): 799 | i = -1 800 | for i in range(len(self.layerList)-1): 801 | x = self.linearLayers[i](x) 802 | x = self.relu(x) 803 | x = self.linearLayers[i+1](x) 804 | return x 805 | 806 | 807 | 808 | 809 | 810 | class TinyImageNet_Decoder(nn.Module): 811 | def __init__(self, embDimension=128, layerList=(256, 512, 3*1024*1024), imgSize=[3,32,32], 812 | isReshapeBack=True, reluFirst=False): 813 | super(TinyImageNet_Decoder, self).__init__() 814 | 815 | self.imgSize = imgSize 816 | self.embDimension = embDimension 817 | self.layerList = layerList 818 | self.linearLayers = OrderedDict() 819 | self.relu = nn.ReLU() 820 | self.isReshapeBack = isReshapeBack 821 | self.reluFirst 
= reluFirst 822 | 823 | self.linearLayers[0] = nn.Linear(self.embDimension, self.layerList[0]) 824 | for i in range(1, len(layerList)): 825 | self.linearLayers[i] = nn.Linear(self.layerList[i-1], self.layerList[i]) 826 | 827 | self.mnist_decoder = nn.ModuleList(list(self.linearLayers.values())) 828 | 829 | def forward(self, x): 830 | self.featList = [] 831 | 832 | if self.reluFirst: 833 | x = self.relu(x) 834 | x = self.linearLayers[0](x) 835 | self.featList.append(x) 836 | 837 | for i in range(1, len(self.layerList)): 838 | x = self.relu(x) 839 | x = self.linearLayers[i](x) 840 | self.featList.append(x) 841 | 842 | if self.isReshapeBack: 843 | x = x.view(x.size(0), self.imgSize[0], self.imgSize[1], self.imgSize[2]) 844 | 845 | return x 846 | 847 | 848 | 849 | class CondEncoder(nn.Module): 850 | def __init__(self, num_classes=200, dimension=128, device='cpu'): 851 | super(self.__class__, self).__init__() 852 | self.num_classes = num_classes 853 | self.dimension = dimension 854 | 855 | self.fc1 = nn.Linear(num_classes, num_classes) 856 | self.fc2 = nn.Linear(num_classes, dimension) 857 | self.fc3 = nn.Linear(dimension, dimension) 858 | self.device = device 859 | 860 | def forward(self, input, indicator): 861 | batch_size = len(input) 862 | x = torch.zeros(batch_size, self.num_classes).to(self.device) 863 | x[:, indicator] = 1 864 | x = x.to(self.device) 865 | 866 | x = self.fc1(x) 867 | x = nn.LeakyReLU(0.2)(x) 868 | x = self.fc2(x) 869 | x = nn.LeakyReLU(0.2)(x) 870 | x = self.fc3(x) 871 | return x -------------------------------------------------------------------------------- /demo_OpenSetSegmentation_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "OpenGAN: Open-Set Recognition via Open Data Generation\n", 8 | "================\n", 9 | "**Supplemental Material for ICCV2021 Submission**\n", 10 | "\n", 11 | "\n", 12 | "In this notebook is for demonstrating open-set semantic segmentation, especially for training in this task." 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "import packages\n", 20 | "------------------\n", 21 | "\n", 22 | "Some packages are installed automatically through Anaconda. PyTorch should be also installed." 
23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "3.7.4 (default, Aug 13 2019, 20:35:49) \n", 35 | "[GCC 7.3.0]\n", 36 | "1.4.0+cu92\n" 37 | ] 38 | } 39 | ], 40 | "source": [ 41 | "from __future__ import print_function, division\n", 42 | "import os, random, time, copy, scipy, pickle, sys, math, json, pickle\n", 43 | "\n", 44 | "import argparse, pprint, shutil, logging, time, timeit\n", 45 | "from pathlib import Path\n", 46 | "\n", 47 | "from skimage import io, transform\n", 48 | "import numpy as np\n", 49 | "import os.path as path\n", 50 | "import scipy.io as sio\n", 51 | "from scipy import misc\n", 52 | "from scipy import ndimage, signal\n", 53 | "import matplotlib.pyplot as plt\n", 54 | "# import PIL.Image\n", 55 | "from PIL import Image\n", 56 | "from io import BytesIO\n", 57 | "from skimage import data, img_as_float\n", 58 | "from skimage.measure import compare_ssim as ssim\n", 59 | "from skimage.measure import compare_psnr as psnr\n", 60 | "\n", 61 | "import torch, torchvision\n", 62 | "from torch.utils.data import Dataset, DataLoader\n", 63 | "import torch.nn as nn\n", 64 | "import torch.optim as optim\n", 65 | "from torch.optim import lr_scheduler \n", 66 | "import torch.nn.functional as F\n", 67 | "from torch.autograd import Variable\n", 68 | "from torchvision import datasets, models, transforms\n", 69 | "import torchvision.utils as vutils\n", 70 | "from collections import namedtuple\n", 71 | "\n", 72 | "from config_HRNet import models\n", 73 | "from config_HRNet import seg_hrnet\n", 74 | "from config_HRNet import config\n", 75 | "from config_HRNet import update_config\n", 76 | "from config_HRNet.modelsummary import *\n", 77 | "from config_HRNet.utils import *\n", 78 | "\n", 79 | "\n", 80 | "from utils.dataset_tinyimagenet import *\n", 81 | "from utils.dataset_cityscapes import *\n", 82 | "from utils.eval_funcs import *\n", 83 | "\n", 84 | "\n", 85 | "import warnings # ignore warnings\n", 86 | "warnings.filterwarnings(\"ignore\")\n", 87 | "print(sys.version)\n", 88 | "print(torch.__version__)\n", 89 | "\n", 90 | "# %load_ext autoreload\n", 91 | "# %autoreload 2" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "Setup config parameters\n", 99 | " -----------------\n", 100 | " \n", 101 | " There are several things to setup, like which GPU to use, where to read images and save files, etc. Please read and understand this. By default, you should be able to run this script smoothly by changing nothing." 
102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 2, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "./exp/demo_step030_OpenGAN_num1000_w0.20\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "# set the random seed\n", 119 | "torch.manual_seed(0)\n", 120 | "\n", 121 | "\n", 122 | "################## set attributes for this project/experiment ##################\n", 123 | "# config result folder\n", 124 | "exp_dir = './exp' # experiment directory, used for reading the init model\n", 125 | "\n", 126 | "num_open_training_images = 1000\n", 127 | "weight_adversarialLoss = 0.2\n", 128 | "project_name = 'demo_step030_OpenGAN_num{}_w{:.2f}'.format(num_open_training_images, weight_adversarialLoss)\n", 129 | "\n", 130 | "\n", 131 | "\n", 132 | "\n", 133 | "device ='cpu'\n", 134 | "if torch.cuda.is_available(): \n", 135 | " device='cuda:3'\n", 136 | " \n", 137 | "\n", 138 | "\n", 139 | "ganBatchSize = 640\n", 140 | "batch_size = 1\n", 141 | "newsize = (-1,-1)\n", 142 | "\n", 143 | "total_epoch_num = 50 # total number of epoch in training\n", 144 | "insertConv = False \n", 145 | "embDimension = 64\n", 146 | "#isPretrained = False\n", 147 | "#encoder_num_layers = 18\n", 148 | "\n", 149 | "\n", 150 | "# Number of channels in the training images. For color images this is 3\n", 151 | "nc = 720\n", 152 | "# Size of z latent vector (i.e. size of generator input)\n", 153 | "nz = 64\n", 154 | "# Size of feature maps in generator\n", 155 | "ngf = 64\n", 156 | "# Size of feature maps in discriminator\n", 157 | "ndf = 64\n", 158 | "# Beta1 hyperparam for Adam optimizers\n", 159 | "beta1 = 0.5\n", 160 | "# Number of GPUs available. Use 0 for CPU mode.\n", 161 | "ngpu = 1\n", 162 | "\n", 163 | "\n", 164 | "\n", 165 | "save_dir = os.path.join(exp_dir, project_name)\n", 166 | "if not os.path.exists(exp_dir): os.makedirs(exp_dir)\n", 167 | "\n", 168 | "lr = 0.0001 # base learning rate\n", 169 | "\n", 170 | "num_epochs = total_epoch_num\n", 171 | "torch.cuda.device_count()\n", 172 | "torch.cuda.empty_cache()\n", 173 | "\n", 174 | "save_dir = os.path.join(exp_dir, project_name)\n", 175 | "print(save_dir) \n", 176 | "if not os.path.exists(save_dir): os.makedirs(save_dir)\n", 177 | "\n", 178 | "log_filename = os.path.join(save_dir, 'train.log')" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "Define model architecture\n", 186 | "---------\n", 187 | "\n", 188 | "Here is the definition of the model architecture. 
" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 3, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "class CityscapesOpenPixelFeat4(Dataset):\n", 198 | " def __init__(self, set_name='train',\n", 199 | " numImgs=500,\n", 200 | " path_to_data='/scratch/dataset/Cityscapes_feat4'): \n", 201 | " \n", 202 | " self.imgList = []\n", 203 | " self.current_set_len = numImgs # 2975\n", 204 | " if set_name=='test': \n", 205 | " set_name = 'val'\n", 206 | " self.current_set_len = 500\n", 207 | " \n", 208 | " self.set_name = set_name\n", 209 | " self.path_to_data = path_to_data\n", 210 | " for i in range(self.current_set_len):\n", 211 | " self.imgList += ['{}_openpixel.pkl'.format(i)] \n", 212 | " \n", 213 | " def __len__(self): \n", 214 | " return self.current_set_len\n", 215 | " \n", 216 | " def __getitem__(self, idx): \n", 217 | " filename = path.join(self.path_to_data, self.set_name, self.imgList[idx])\n", 218 | " with open(filename, \"rb\") as fn:\n", 219 | " openPixFeat = pickle.load(fn)\n", 220 | " openPixFeat = openPixFeat['feat4open_percls']\n", 221 | " openPixFeat = torch.cat(openPixFeat, 0).detach()\n", 222 | " #print(openPixFeat.shape)\n", 223 | " return openPixFeat" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 4, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "parser = argparse.ArgumentParser(description='Train segmentation network') \n", 233 | "parser.add_argument('--cfg',\n", 234 | " help='experiment configure file name',\n", 235 | " default='./config_HRNet/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml',\n", 236 | " type=str)\n", 237 | "parser.add_argument('opts',\n", 238 | " help=\"Modify config options using the command-line\",\n", 239 | " default=None,\n", 240 | " nargs=argparse.REMAINDER)\n", 241 | "\n", 242 | "\n", 243 | "args = parser.parse_args(r'--cfg ./config_HRNet/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml ')\n", 244 | "args.opts = []\n", 245 | "update_config(config, args)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 5, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "model = eval(config.MODEL.NAME + '.get_seg_model_myModel')(config)\n", 255 | "model_dict = model.state_dict()\n", 256 | "\n", 257 | "\n", 258 | "model_state_file = '../openset/models/hrnet_w48_cityscapes_cls19_1024x2048_ohem_trainset.pth'\n", 259 | "pretrained_dict = torch.load(model_state_file, map_location=lambda storage, loc: storage)\n", 260 | "\n", 261 | "\n", 262 | "suppl_dict = {}\n", 263 | "suppl_dict['last_1_conv.weight'] = pretrained_dict['model.last_layer.0.weight'].clone()\n", 264 | "suppl_dict['last_1_conv.bias'] = pretrained_dict['model.last_layer.0.bias'].clone()\n", 265 | "\n", 266 | "suppl_dict['last_2_BN.running_mean'] = pretrained_dict['model.last_layer.1.running_mean'].clone()\n", 267 | "suppl_dict['last_2_BN.running_var'] = pretrained_dict['model.last_layer.1.running_var'].clone()\n", 268 | "# suppl_dict['last_2_BN.num_batches_tracked'] = pretrained_dict['model.last_layer.1.num_batches_tracked']\n", 269 | "suppl_dict['last_2_BN.weight'] = pretrained_dict['model.last_layer.1.weight'].clone()\n", 270 | "suppl_dict['last_2_BN.bias'] = pretrained_dict['model.last_layer.1.bias'].clone()\n", 271 | "\n", 272 | "suppl_dict['last_4_conv.weight'] = pretrained_dict['model.last_layer.3.weight'].clone()\n", 273 | "suppl_dict['last_4_conv.bias'] = pretrained_dict['model.last_layer.3.bias'].clone()\n", 274 | "\n", 
275 | "\n", 276 | "pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()\n", 277 | " if k[6:] in model_dict.keys()}\n", 278 | "\n", 279 | "\n", 280 | "model_dict.update(pretrained_dict)\n", 281 | "model_dict.update(suppl_dict)\n", 282 | "model.load_state_dict(model_dict)\n", 283 | "\n", 284 | "\n", 285 | "model.eval();\n", 286 | "model.to(device);" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 6, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "def weights_init(m):\n", 296 | " classname = m.__class__.__name__\n", 297 | " if classname.find('Conv') != -1:\n", 298 | " nn.init.normal_(m.weight.data, 0.0, 0.02)\n", 299 | " elif classname.find('BatchNorm') != -1:\n", 300 | " nn.init.normal_(m.weight.data, 1.0, 0.02)\n", 301 | " nn.init.constant_(m.bias.data, 0) \n", 302 | " \n", 303 | "\n", 304 | "class Generator(nn.Module):\n", 305 | " def __init__(self, ngpu=1, nz=100, ngf=64, nc=512):\n", 306 | " super(Generator, self).__init__()\n", 307 | " self.ngpu = ngpu\n", 308 | " self.nz = nz\n", 309 | " self.ngf = ngf\n", 310 | " self.nc = nc\n", 311 | " \n", 312 | " self.main = nn.Sequential(\n", 313 | " # input is Z, going into a convolution\n", 314 | " nn.Conv2d( self.nz, self.ngf * 8, 1, 1, 0, bias=True),\n", 315 | " nn.BatchNorm2d(self.ngf * 8),\n", 316 | " nn.ReLU(True),\n", 317 | " # state size. (self.ngf*8) x 4 x 4\n", 318 | " nn.Conv2d(self.ngf * 8, self.ngf * 4, 1, 1, 0, bias=True),\n", 319 | " nn.BatchNorm2d(self.ngf * 4),\n", 320 | " nn.ReLU(True),\n", 321 | " # state size. (self.ngf*4) x 8 x 8\n", 322 | " nn.Conv2d( self.ngf * 4, self.ngf * 2, 1, 1, 0, bias=True),\n", 323 | " nn.BatchNorm2d(self.ngf * 2),\n", 324 | " nn.ReLU(True),\n", 325 | " # state size. (self.ngf*2) x 16 x 16\n", 326 | " nn.Conv2d( self.ngf * 2, self.ngf*4, 1, 1, 0, bias=True),\n", 327 | " nn.BatchNorm2d(self.ngf*4),\n", 328 | " nn.ReLU(True),\n", 329 | " # state size. (self.ngf) x 32 x 32\n", 330 | " nn.Conv2d( self.ngf*4, self.nc, 1, 1, 0, bias=True),\n", 331 | " #nn.Tanh()\n", 332 | " # state size. 
(self.nc) x 64 x 64\n", 333 | " )\n", 334 | "\n", 335 | " def forward(self, input):\n", 336 | " return self.main(input)\n", 337 | "\n", 338 | " \n", 339 | "class Discriminator(nn.Module):\n", 340 | " def __init__(self, ngpu=1, nc=512, ndf=64):\n", 341 | " super(Discriminator, self).__init__()\n", 342 | " self.ngpu = ngpu\n", 343 | " self.nc = nc\n", 344 | " self.ndf = ndf\n", 345 | " self.main = nn.Sequential(\n", 346 | " nn.Conv2d(self.nc, self.ndf*8, 1, 1, 0, bias=True),\n", 347 | " nn.LeakyReLU(0.2, inplace=True),\n", 348 | " nn.Conv2d(self.ndf*8, self.ndf*4, 1, 1, 0, bias=True),\n", 349 | " nn.BatchNorm2d(self.ndf*4),\n", 350 | " nn.LeakyReLU(0.2, inplace=True),\n", 351 | " nn.Conv2d(self.ndf*4, self.ndf*2, 1, 1, 0, bias=True),\n", 352 | " nn.BatchNorm2d(self.ndf*2),\n", 353 | " nn.LeakyReLU(0.2, inplace=True),\n", 354 | " nn.Conv2d(self.ndf*2, self.ndf, 1, 1, 0, bias=True),\n", 355 | " nn.BatchNorm2d(self.ndf),\n", 356 | " nn.LeakyReLU(0.2, inplace=True),\n", 357 | " nn.Conv2d(self.ndf, 1, 1, 1, 0, bias=True),\n", 358 | " nn.Sigmoid()\n", 359 | " )\n", 360 | "\n", 361 | " def forward(self, input):\n", 362 | " return self.main(input)" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 7, 368 | "metadata": {}, 369 | "outputs": [ 370 | { 371 | "name": "stdout", 372 | "output_type": "stream", 373 | "text": [ 374 | "cuda:3\n" 375 | ] 376 | } 377 | ], 378 | "source": [ 379 | "netG = Generator(ngpu=ngpu, nz=nz, ngf=ngf, nc=nc).to(device)\n", 380 | "netD = Discriminator(ngpu=ngpu, nc=nc, ndf=ndf).to(device)\n", 381 | "\n", 382 | "\n", 383 | "# Handle multi-gpu if desired\n", 384 | "if ('cuda' in device) and (ngpu > 1): \n", 385 | " netD = nn.DataParallel(netD, list(range(ngpu)))\n", 386 | "\n", 387 | "# Apply the weights_init function to randomly initialize all weights\n", 388 | "# to mean=0, stdev=0.2.\n", 389 | "netD.apply(weights_init)\n", 390 | "\n", 391 | "\n", 392 | "if ('cuda' in device) and (ngpu > 1):\n", 393 | " netG = nn.DataParallel(netG, list(range(ngpu)))\n", 394 | "netG.apply(weights_init)\n", 395 | "\n", 396 | "print(device)" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 8, 402 | "metadata": {}, 403 | "outputs": [ 404 | { 405 | "name": "stdout", 406 | "output_type": "stream", 407 | "text": [ 408 | "torch.Size([5, 64, 1, 1]) torch.Size([5, 720, 1, 1]) torch.Size([5, 1, 1, 1])\n" 409 | ] 410 | } 411 | ], 412 | "source": [ 413 | "noise = torch.randn(batch_size*5, nz, 1, 1, device=device)\n", 414 | "# Generate fake image batch with G\n", 415 | "fake = netG(noise)\n", 416 | "predLabel = netD(fake)\n", 417 | "\n", 418 | "print(noise.shape, fake.shape, predLabel.shape)" 419 | ] 420 | }, 421 | { 422 | "cell_type": "markdown", 423 | "metadata": {}, 424 | "source": [ 425 | "setup dataset\n", 426 | "-----------" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 9, 432 | "metadata": {}, 433 | "outputs": [ 434 | { 435 | "name": "stdout", 436 | "output_type": "stream", 437 | "text": [ 438 | "2975 500\n" 439 | ] 440 | } 441 | ], 442 | "source": [ 443 | "# torchvision.transforms.Normalize(mean, std, inplace=False)\n", 444 | "imgTransformList = transforms.Compose([\n", 445 | " transforms.ToTensor(),\n", 446 | " transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n", 447 | "])\n", 448 | "\n", 449 | "targetTransformList = transforms.Compose([\n", 450 | " transforms.ToTensor(), \n", 451 | "])\n", 452 | "\n", 453 | "cls_datasets = {set_name: Cityscapes(root='/scratch/dataset/Cityscapes',\n", 454 | " 
newsize=newsize,\n", 455 | " split=set_name,\n", 456 | " mode='fine',\n", 457 | " target_type='semantic',\n", 458 | " transform=imgTransformList,\n", 459 | " target_transform=targetTransformList,\n", 460 | " transforms=None)\n", 461 | " for set_name in ['train', 'val']} # 'train', \n", 462 | "\n", 463 | "dataloaders = {set_name: DataLoader(cls_datasets[set_name],\n", 464 | " batch_size=batch_size,\n", 465 | " shuffle=set_name=='train', \n", 466 | " num_workers=4) # num_work can be set to batch_size\n", 467 | " for set_name in ['train', 'val']} # 'train',\n", 468 | "\n", 469 | "\n", 470 | "print(len(cls_datasets['train']), len(cls_datasets['val']))\n", 471 | "classDictionary = cls_datasets['val'].classes" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": 10, 477 | "metadata": {}, 478 | "outputs": [ 479 | { 480 | "name": "stdout", 481 | "output_type": "stream", 482 | "text": [ 483 | "0 unlabeled\n", 484 | "1 ego vehicle\n", 485 | "2 rectification border\n", 486 | "3 out of roi\n", 487 | "4 static\n", 488 | "5 dynamic\n", 489 | "6 ground\n", 490 | "9 parking\n", 491 | "10 rail track\n", 492 | "14 guard rail\n", 493 | "15 bridge\n", 494 | "16 tunnel\n", 495 | "18 polegroup\n", 496 | "29 caravan\n", 497 | "30 trailer\n", 498 | "34 license plate\n", 499 | "total# 16\n" 500 | ] 501 | } 502 | ], 503 | "source": [ 504 | "id2trainID = {}\n", 505 | "id2color = {}\n", 506 | "trainID2color = {}\n", 507 | "id2name = {}\n", 508 | "opensetIDlist = []\n", 509 | "for i in range(len(classDictionary)):\n", 510 | " id2trainID[i] = classDictionary[i][2]\n", 511 | " id2color[i] = classDictionary[i][-1]\n", 512 | " trainID2color[classDictionary[i][2]] = classDictionary[i][-1]\n", 513 | " id2name[i] = classDictionary[i][0]\n", 514 | " if classDictionary[i][-2]:\n", 515 | " opensetIDlist += [i]\n", 516 | "\n", 517 | "id2trainID_list = []\n", 518 | "for i in range(len(id2trainID)):\n", 519 | " id2trainID_list.append(id2trainID[i])\n", 520 | "id2trainID_np = np.asarray(id2trainID_list) \n", 521 | " \n", 522 | "for elm in opensetIDlist:\n", 523 | " print(elm, id2name[elm])\n", 524 | "print('total# {}'.format(len(opensetIDlist)))" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": 11, 530 | "metadata": {}, 531 | "outputs": [], 532 | "source": [ 533 | "data_sampler = iter(dataloaders['train'])\n", 534 | "data = next(data_sampler)\n", 535 | "imageList, labelList = data[0], data[1]\n", 536 | "\n", 537 | "imageList = imageList.to(device)\n", 538 | "labelList = labelList.to(device)" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": 12, 544 | "metadata": {}, 545 | "outputs": [ 546 | { 547 | "data": { 548 | "text/plain": [ 549 | "(torch.Size([1, 3, 1024, 2048]), torch.Size([1, 1024, 2048]))" 550 | ] 551 | }, 552 | "execution_count": 12, 553 | "metadata": {}, 554 | "output_type": "execute_result" 555 | } 556 | ], 557 | "source": [ 558 | "imageList.shape, labelList.shape" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": {}, 564 | "source": [ 565 | "setup training\n", 566 | "-----------" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": 13, 572 | "metadata": {}, 573 | "outputs": [], 574 | "source": [ 575 | "# Initialize BCELoss function\n", 576 | "criterion = nn.BCELoss()\n", 577 | "\n", 578 | "# Create batch of latent vectors that we will use to visualize\n", 579 | "# the progression of the generator\n", 580 | "fixed_noise = torch.randn(64, nz, 1, 1, device=device)\n", 581 | "\n", 582 | "# 
Establish open and close labels\n", 583 | "close_label = 1\n", 584 | "open_label = 0\n", 585 | "\n", 586 | "# Establish convention for real and fake labels during training\n", 587 | "real_label = 1\n", 588 | "fake_label = 0\n", 589 | "\n", 590 | "\n", 591 | "# Setup Adam optimizers for both G and D\n", 592 | "optimizerD = optim.Adam(netD.parameters(), lr=lr/1.5, betas=(beta1, 0.999))\n", 593 | "optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))" 594 | ] 595 | }, 596 | { 597 | "cell_type": "markdown", 598 | "metadata": {}, 599 | "source": [ 600 | "testing a single image\n", 601 | "-----------" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": 14, 607 | "metadata": {}, 608 | "outputs": [], 609 | "source": [ 610 | "labelList = labelList.unsqueeze(1)\n", 611 | "labelList = F.interpolate(labelList, scale_factor=0.25, mode='nearest')\n", 612 | "labelList = labelList.squeeze()\n", 613 | "H, W = labelList.squeeze().shape\n", 614 | "trainlabelList = id2trainID_np[labelList.cpu().numpy().reshape(-1,).astype(np.int32)]\n", 615 | "trainlabelList = trainlabelList.reshape((1,H,W))\n", 616 | "trainlabelList = torch.from_numpy(trainlabelList)\n", 617 | "\n", 618 | "\n", 619 | "\n", 620 | "upsampleFunc = nn.UpsamplingBilinear2d(scale_factor=4)\n", 621 | "with torch.no_grad():\n", 622 | " imageList = imageList.to(device)\n", 623 | " logitsTensor = model(imageList).detach().cpu()\n", 624 | " #logitsTensor = upsampleFunc(logitsTensor)\n", 625 | " softmaxTensor = F.softmax(logitsTensor, dim=1)\n", 626 | " \n", 627 | " feat1Tensor = model.feat1.detach()\n", 628 | " feat2Tensor = model.feat2.detach()\n", 629 | " feat3Tensor = model.feat3.detach()\n", 630 | " feat4Tensor = model.feat4.detach()\n", 631 | " feat5Tensor = model.feat5.detach()\n", 632 | " \n", 633 | " torch.cuda.empty_cache()" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": 15, 639 | "metadata": { 640 | "scrolled": false 641 | }, 642 | "outputs": [ 643 | { 644 | "data": { 645 | "text/plain": [ 646 | "(torch.Size([1, 720, 256, 512]), torch.Size([1, 256, 512]), 131072)" 647 | ] 648 | }, 649 | "execution_count": 15, 650 | "metadata": {}, 651 | "output_type": "execute_result" 652 | } 653 | ], 654 | "source": [ 655 | "feat4Tensor.shape, trainlabelList.shape, trainlabelList.shape[1]*trainlabelList.shape[2]" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 16, 661 | "metadata": {}, 662 | "outputs": [], 663 | "source": [ 664 | "validList = trainlabelList.reshape(-1,1)\n", 665 | "validList = ((validList>=0) & (validList<=18)).nonzero()\n", 666 | "validList = validList[:,0]\n", 667 | "validList = validList[torch.randperm(validList.size()[0])]\n", 668 | "validList = validList[:ganBatchSize]" 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": 17, 674 | "metadata": {}, 675 | "outputs": [], 676 | "source": [ 677 | "label = torch.full((ganBatchSize,), close_label, device=device)" 678 | ] 679 | }, 680 | { 681 | "cell_type": "code", 682 | "execution_count": 18, 683 | "metadata": {}, 684 | "outputs": [], 685 | "source": [ 686 | "real_cpu = feat4Tensor.squeeze()\n", 687 | "real_cpu = real_cpu.reshape(real_cpu.shape[0], -1).permute(1,0)\n", 688 | "real_cpu = real_cpu[validList,:].unsqueeze(-1).unsqueeze(-1).to(device)\n", 689 | "\n", 690 | "output = netD(real_cpu).view(-1)\n", 691 | "# Calculate loss on all-real batch\n", 692 | "errD_real = criterion(output, label)" 693 | ] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "execution_count": 
19, 698 | "metadata": {}, 699 | "outputs": [], 700 | "source": [ 701 | "noise = torch.randn(ganBatchSize, nz, 1, 1, device=device)\n", 702 | "# Generate fake image batch with G\n", 703 | "fake = netG(noise)\n", 704 | "label.fill_(fake_label)\n", 705 | "# Classify all fake batch with D\n", 706 | "output = netD(fake.detach()).view(-1)\n", 707 | "# Calculate D's loss on the all-fake batch\n", 708 | "errD_fake = criterion(output, label)" 709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | "execution_count": 20, 714 | "metadata": { 715 | "scrolled": true 716 | }, 717 | "outputs": [ 718 | { 719 | "data": { 720 | "text/plain": [ 721 | "(torch.Size([640, 64, 1, 1]), torch.Size([640]), torch.Size([640, 720, 1, 1]))" 722 | ] 723 | }, 724 | "execution_count": 20, 725 | "metadata": {}, 726 | "output_type": "execute_result" 727 | } 728 | ], 729 | "source": [ 730 | "noise.shape, label.shape, fake.shape" 731 | ] 732 | }, 733 | { 734 | "cell_type": "markdown", 735 | "metadata": {}, 736 | "source": [ 737 | "training GAN\n", 738 | "-----------" 739 | ] 740 | }, 741 | { 742 | "cell_type": "code", 743 | "execution_count": 21, 744 | "metadata": {}, 745 | "outputs": [ 746 | { 747 | "name": "stdout", 748 | "output_type": "stream", 749 | "text": [ 750 | "torch.Size([640, 720, 1, 1])\n" 751 | ] 752 | } 753 | ], 754 | "source": [ 755 | "openPix_datasets = CityscapesOpenPixelFeat4(set_name='train', numImgs=num_open_training_images)\n", 756 | "openPix_dataloader = DataLoader(openPix_datasets, batch_size=1, shuffle=True, num_workers=4) \n", 757 | "\n", 758 | "openPix_sampler = iter(openPix_dataloader)\n", 759 | "\n", 760 | "openPixFeat = next(openPix_sampler)\n", 761 | "openPixFeat = openPixFeat.squeeze(0)\n", 762 | "\n", 763 | "openPixIdxList = torch.randperm(openPixFeat.size()[0])\n", 764 | "openPixIdxList = openPixIdxList[:ganBatchSize]\n", 765 | "openPixFeat = openPixFeat[openPixIdxList].to(device)\n", 766 | "\n", 767 | "print(openPixFeat.shape)" 768 | ] 769 | }, 770 | { 771 | "cell_type": "code", 772 | "execution_count": null, 773 | "metadata": { 774 | "scrolled": true 775 | }, 776 | "outputs": [ 777 | { 778 | "name": "stdout", 779 | "output_type": "stream", 780 | "text": [ 781 | "Starting Training Loop...\n", 782 | "[0/50][0/2975]\t\tlossG: 0.6536, lossD: 0.5840\n", 783 | "[0/50][100/2975]\t\tlossG: 0.6569, lossD: 0.4096\n", 784 | "[0/50][200/2975]\t\tlossG: 0.6236, lossD: 0.3967\n", 785 | "[0/50][300/2975]\t\tlossG: 0.5869, lossD: 0.2977\n", 786 | "[0/50][400/2975]\t\tlossG: 0.5548, lossD: 0.2332\n", 787 | "[0/50][500/2975]\t\tlossG: 0.5504, lossD: 0.3024\n" 788 | ] 789 | } 790 | ], 791 | "source": [ 792 | "# Training Loop\n", 793 | "\n", 794 | "# Lists to keep track of progress\n", 795 | "lossList = []\n", 796 | "G_losses = []\n", 797 | "D_losses = []\n", 798 | "\n", 799 | "fake_BatchSize = int(ganBatchSize/2)\n", 800 | "open_BatchSize = ganBatchSize\n", 801 | "\n", 802 | "\n", 803 | "\n", 804 | "tmp_weights = torch.full((ganBatchSize+open_BatchSize+fake_BatchSize,), 1, device=device)\n", 805 | "tmp_weights[-fake_BatchSize:] *= weight_adversarialLoss\n", 806 | "criterionD = nn.BCELoss(weight=tmp_weights)\n", 807 | "\n", 808 | "\n", 809 | "\n", 810 | "print(\"Starting Training Loop...\")\n", 811 | "# For each epoch\n", 812 | "openPixImgCount = 0\n", 813 | "openPix_sampler = iter(openPix_dataloader)\n", 814 | "for epoch in range(num_epochs):\n", 815 | " # For each batch in the dataloader\n", 816 | " for i, sample in enumerate(dataloaders['train'], 0):\n", 817 | " imageList, labelList = sample\n", 818 | " 
imageList = imageList.to(device)\n", 819 | " labelList = labelList.to(device)\n", 820 | "\n", 821 | " labelList = labelList.unsqueeze(1)\n", 822 | " labelList = F.interpolate(labelList, scale_factor=0.25, mode='nearest')\n", 823 | " labelList = labelList.squeeze()\n", 824 | " H, W = labelList.squeeze().shape\n", 825 | " trainlabelList = id2trainID_np[labelList.cpu().numpy().reshape(-1,).astype(np.int32)]\n", 826 | " trainlabelList = trainlabelList.reshape((1,H,W))\n", 827 | " trainlabelList = torch.from_numpy(trainlabelList)\n", 828 | " \n", 829 | " \n", 830 | " #upsampleFunc = nn.UpsamplingBilinear2d(scale_factor=4)\n", 831 | " with torch.no_grad():\n", 832 | " imageList = imageList.to(device)\n", 833 | " logitsTensor = model(imageList).detach().cpu()\n", 834 | " featTensor = model.feat4.detach()\n", 835 | " \n", 836 | " validList = trainlabelList.reshape(-1,1)\n", 837 | " validList = ((validList>=0) & (validList<=18)).nonzero()\n", 838 | " validList = validList[:,0]\n", 839 | " tmp = torch.randperm(validList.size()[0]) \n", 840 | " validList = validList[tmp[:ganBatchSize]]\n", 841 | " \n", 842 | "\n", 843 | " \n", 844 | " label_closeset = torch.full((ganBatchSize,), close_label, device=device)\n", 845 | " feat_closeset = featTensor.squeeze()\n", 846 | " feat_closeset = feat_closeset.reshape(feat_closeset.shape[0], -1).permute(1,0)\n", 847 | " feat_closeset = feat_closeset[validList,:].unsqueeze(-1).unsqueeze(-1) \n", 848 | " label_open = torch.full((open_BatchSize,), open_label, device=device)\n", 849 | " \n", 850 | " openPixImgCount += 1\n", 851 | " feat_openset = next(openPix_sampler)\n", 852 | " feat_openset = feat_openset.squeeze(0)\n", 853 | " openPixIdxList = torch.randperm(feat_openset.size()[0])\n", 854 | " openPixIdxList = openPixIdxList[:open_BatchSize]\n", 855 | " feat_openset = feat_openset[openPixIdxList].to(device)\n", 856 | "\n", 857 | " if openPixImgCount==num_open_training_images:\n", 858 | " openPixImgCount = 0\n", 859 | " openPix_sampler = iter(openPix_dataloader)\n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " # generate fake images \n", 864 | " noise = torch.randn(fake_BatchSize, nz, 1, 1, device=device)\n", 865 | " # Generate fake image batch with G\n", 866 | " label_fake = torch.full((fake_BatchSize,), fake_label, device=device)\n", 867 | " feat_fakeset = netG(noise) \n", 868 | " \n", 869 | " ############################\n", 870 | " # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))\n", 871 | " ###########################\n", 872 | " # using close&open&fake data to update D\n", 873 | " netD.zero_grad()\n", 874 | " X = torch.cat((feat_closeset, feat_openset.to(device), feat_fakeset.detach()),0)\n", 875 | " label_total = torch.cat((label_closeset, label_open, label_fake),0)\n", 876 | " \n", 877 | " output = netD(X).view(-1)\n", 878 | " lossD = criterionD(output, label_total)\n", 879 | " lossD.backward()\n", 880 | " optimizerD.step()\n", 881 | " errD = lossD.mean().item() \n", 882 | " \n", 883 | " \n", 884 | " ############################\n", 885 | " # (2) Update G network: maximize log(D(G(z)))\n", 886 | " ###########################\n", 887 | " netG.zero_grad()\n", 888 | " label_fakeclose = torch.full((fake_BatchSize,), close_label, device=device) \n", 889 | " # Since we just updated D, perform another forward pass of all-fake batch through D\n", 890 | " output = netD(feat_fakeset).view(-1)\n", 891 | " # Calculate G's loss based on this output\n", 892 | " lossG = criterion(output, label_fakeclose)\n", 893 | " # Calculate gradients for G\n", 894 | " 
lossG.backward()\n", 895 | "        errG = lossG.mean().item()\n", 896 | "        # Update G\n", 897 | "        optimizerG.step()\n", 898 | "        \n", 899 | "        \n", 900 | "        # Save Losses for plotting later\n", 901 | "        G_losses.append(errG)\n", 902 | "        D_losses.append(errD)\n", 903 | "        \n", 904 | "        \n", 905 | "        # Output training stats\n", 906 | "        if i % 100 == 0:\n", 907 | "            print('[%d/%d][%d/%d]\\t\\tlossG: %.4f, lossD: %.4f'\n", 908 | "                  % (epoch, num_epochs, i, len(dataloaders['train']), \n", 909 | "                     errG, errD))\n", 910 | "        \n", 911 | "    \n", 912 | "    cur_model_wts = copy.deepcopy(netD.state_dict())\n", 913 | "    path_to_save_paramOnly = os.path.join(save_dir, 'epoch-{}.classifier'.format(epoch+1))\n", 914 | "    torch.save(cur_model_wts, path_to_save_paramOnly)\n", 915 | "    cur_model_wts = copy.deepcopy(netG.state_dict())\n", 916 | "    path_to_save_paramOnly = os.path.join(save_dir, 'epoch-{}.GNet'.format(epoch+1))\n", 917 | "    torch.save(cur_model_wts, path_to_save_paramOnly)" 918 | ] 919 | }, 920 | { 921 | "cell_type": "markdown", 922 | "metadata": {}, 923 | "source": [ 924 | "validating results\n", 925 | "-----------" 926 | ] 927 | }, 928 | { 929 | "cell_type": "code", 930 | "execution_count": null, 931 | "metadata": {}, 932 | "outputs": [], 933 | "source": [ 934 | "plt.figure(figsize=(10,5))\n", 935 | "plt.title(\"binary cross-entropy losses in training\")\n", 936 | "# only the aggregate discriminator and generator losses are tracked in the training loop above\n", 937 | "plt.plot(D_losses, label=\"D_losses\")\n", 938 | "plt.plot(G_losses, label=\"G_losses\")\n", 939 | "\n", 940 | "plt.xlabel(\"iterations\")\n", 941 | "plt.ylabel(\"Loss\")\n", 942 | "plt.legend()\n", 943 | "# plt.savefig('learningCurves_{}.png'.format(modelFlag), bbox_inches='tight',transparent=True)\n", 944 | "# plt.show()" 945 | ] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "execution_count": null, 950 | "metadata": {}, 951 | "outputs": [], 952 | "source": [] 953 | } 954 | ], 955 | "metadata": { 956 | "kernelspec": { 957 | "display_name": "Python 3", 958 | "language": "python", 959 | "name": "python3" 960 | }, 961 | "language_info": { 962 | "codemirror_mode": { 963 | "name": "ipython", 964 | "version": 3 965 | }, 966 | "file_extension": ".py", 967 | "mimetype": "text/x-python", 968 | "name": "python", 969 | "nbconvert_exporter": "python", 970 | "pygments_lexer": "ipython3", 971 | "version": "3.7.4" 972 | } 973 | }, 974 | "nbformat": 4, 975 | "nbformat_minor": 2 976 | } 977 | --------------------------------------------------------------------------------
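The training loop in the notebook above turns each Cityscapes image into a batch of closed-set pixel features: the ground-truth label map is downsampled to the resolution of HRNet's feat4, raw label IDs are mapped to train IDs, and only pixels whose train ID falls in the closed set (0-18) are kept; the matching columns of feat4 become 720-d inputs to the discriminator. The sketch below replays that gathering step on toy tensors; the image size, the `id2trainID_np` lookup, and `n_sample` here are placeholders, not values from the notebook.

```python
# Hedged sketch of the closed-set pixel-feature gathering used in the training loop.
# Toy shapes and a toy id->trainID lookup stand in for the real Cityscapes data.
import numpy as np
import torch
import torch.nn.functional as F

H_img, W_img, C, n_sample = 64, 128, 720, 16                   # toy sizes, not 1024x2048
feat4 = torch.randn(1, C, H_img // 4, W_img // 4)              # stand-in for model.feat4
raw_labels = torch.randint(0, 34, (1, H_img, W_img)).float()   # stand-in raw Cityscapes IDs
id2trainID_np = np.random.permutation(34)                      # toy id -> trainID lookup table

# 1) bring the label map to feat4's resolution with nearest-neighbour sampling
lab = F.interpolate(raw_labels.unsqueeze(1), scale_factor=0.25, mode='nearest')
lab = lab.squeeze().long()                                     # [H/4, W/4]

# 2) map raw IDs to train IDs and keep only closed-set pixels (train IDs 0..18)
train_ids = torch.from_numpy(id2trainID_np[lab.numpy().reshape(-1)])
valid = ((train_ids >= 0) & (train_ids <= 18)).nonzero()[:, 0]
valid = valid[torch.randperm(valid.numel())[:n_sample]]        # random subset per image

# 3) gather the matching feat4 columns as [n_sample, 720, 1, 1] discriminator inputs
pix = feat4.squeeze(0).reshape(C, -1).permute(1, 0)            # [H/4 * W/4, 720]
feat_closeset = pix[valid].unsqueeze(-1).unsqueeze(-1)
print(feat_closeset.shape)                                     # e.g. torch.Size([16, 720, 1, 1])
```

The same reshape-and-permute pattern is what lets the 1x1-convolution discriminator treat every pixel as an independent training sample.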
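The discriminator update in the loop above mixes three kinds of samples in a single binary cross-entropy loss: closed-set pixel features labeled 1, real open-set pixel features labeled 0, and generator outputs also labeled 0 but down-weighted by `weight_adversarialLoss` (0.2 in this demo) via the `weight` argument of `nn.BCELoss`. The following is a minimal, self-contained sketch of that objective; the random tensors and the tiny stand-in discriminator are illustrative only and do not reproduce `netD`.

```python
# Hedged sketch of the weighted closed/open/fake discriminator loss.
import torch
import torch.nn as nn

nc, n_close, n_open, n_fake, w_adv = 720, 640, 640, 320, 0.2

# toy stand-in for the notebook's netD (same 1x1-conv, sigmoid-output pattern)
toyD = nn.Sequential(
    nn.Conv2d(nc, 64, 1), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(64, 1, 1), nn.Sigmoid())

feat_close = torch.randn(n_close, nc, 1, 1)   # closed-set pixel features  -> label 1
feat_open  = torch.randn(n_open,  nc, 1, 1)   # real open-set pixel features -> label 0
feat_fake  = torch.randn(n_fake,  nc, 1, 1)   # generated features (detached) -> label 0

X = torch.cat((feat_close, feat_open, feat_fake), 0)
labels = torch.cat((torch.ones(n_close), torch.zeros(n_open), torch.zeros(n_fake)), 0)

# per-sample weights: real closed/open pixels count fully,
# GAN-generated samples are down-weighted by weight_adversarialLoss
weights = torch.ones(n_close + n_open + n_fake)
weights[-n_fake:] = w_adv
criterionD = nn.BCELoss(weight=weights)

lossD = criterionD(toyD(X).view(-1), labels)
lossD.backward()
print(float(lossD))
```

Down-weighting the synthesized "fake open" samples keeps the adversarially generated data from drowning out the real open-set examples, which is the role `weight_adversarialLoss` plays in the loop above.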
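At test time the trained discriminator can serve as a per-pixel open-set score. This step is not shown in the notebook above; because the `Discriminator` uses only 1x1 convolutions, one convenient option is to apply it fully convolutionally to the whole feat4 map and read off one closed-set confidence per pixel. A minimal sketch, assuming a trained `netD` as defined above and a feat4 tensor shaped `[1, 720, H, W]` as in the single-image test:

```python
# Hedged sketch: per-pixel open-set scoring with a trained discriminator.
import torch

@torch.no_grad()
def openset_score_map(netD, feat4):
    """Return an [H, W] open-set score map from a [1, 720, H, W] feature map."""
    netD.eval()
    # All discriminator layers are 1x1 convolutions, so feeding the full feature
    # map yields a [1, 1, H, W] grid of closed-set confidences D(x) in [0, 1].
    conf_closed = netD(feat4).squeeze(0).squeeze(0)   # [H, W]
    return 1.0 - conf_closed                          # high value => likely open-set pixel

# usage with the variables from the single-image test above (names from the notebook):
# score = openset_score_map(netD, feat4Tensor)        # [H, W]
# open_mask = score > 0.5                             # illustrative threshold only
```

The 0.5 threshold is only illustrative; in practice one would pick an operating point on a validation set or report threshold-free metrics such as AUROC.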