├── OpenGAN_logo.png
├── README.md
├── utils
│   ├── dataset_tinyimagenet.py
│   ├── dataset_tinyimagenet_3sets.py
│   ├── layers.py
│   ├── eval_funcs.py
│   ├── dataset_cifar10.py
│   ├── dataset_cityscapes.py
│   ├── dataset_cityscapes4OpenGAN.py
│   └── network_arch_tinyimagenet.py
└── demo_OpenSetSegmentation_training.ipynb
/OpenGAN_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimerykong/OpenGAN/HEAD/OpenGAN_logo.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## OpenGAN: Open-Set Recognition via Open Data Generation
2 |
3 | ICCV 2021 ([best paper honorable mention](https://www.cs.cmu.edu/~shuk/OpenGAN.html))
4 |
5 | 
6 |
7 | [[website](https://www.cs.cmu.edu/~shuk/OpenGAN.html)]
8 | [[poster](http://www.cs.cmu.edu/~shuk/img/OpenGAN_poster.pdf)]
9 | [[slides](http://www.cs.cmu.edu/~shuk/img/OpenGAN_slides.pdf)]
10 | [[oral presentation](https://youtu.be/CNYqYXyUHn0)]
11 | [[paper](https://arxiv.org/abs/2104.02939)]
12 | [[PAMI Version](https://github.com/aimerykong/aimerykong.github.io/raw/main/OpenGAN_files/PAMI_OpenGAN_accepted_version.pdf) 18MB]
13 |
14 | Real-world machine learning systems need to analyze novel testing data that differs from the training data. In K-way classification, this is crisply formulated as open-set recognition, core to which is the ability to discriminate open-set data outside the K closed-set classes. Two conceptually elegant ideas for open-set discrimination are: 1) discriminatively learning an open-vs-closed binary discriminator by exploiting some outlier data as the open-set, and 2) learning the closed-set data distribution with a GAN in an unsupervised fashion and using its discriminator as the open-set likelihood function. However, the former generalizes poorly to diverse open test data because it overfits to the training outliers, which are unlikely to exhaustively span the open world. The latter does not work well, presumably due to the unstable training of GANs. Motivated by the above, we propose OpenGAN, which addresses the limitation of each approach by combining them with several technical insights. First, we show that a carefully selected GAN-discriminator trained on some real outlier data already achieves state-of-the-art performance. Second, we augment the available set of real open training examples with adversarially synthesized "fake" data.
15 | Third, and most importantly, we build the discriminator over the features computed by the closed-world K-way networks.
16 | Extensive experiments show that OpenGAN significantly outperforms prior open-set methods.
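
In code, the core recipe looks roughly like the minimal sketch below. It is only an illustration, not the released training code: the feature dimension, the network shapes, and the random tensors standing in for features extracted by a frozen K-way classifier are all assumptions made for the example.

```python
import torch
import torch.nn as nn

feat_dim = 512    # assumed dimension of the K-way network's features
noise_dim = 64

# Discriminator over off-the-shelf features (not raw pixels).
netD = nn.Sequential(
    nn.Linear(feat_dim, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
)
# Generator that synthesizes "fake" open-set features from noise.
netG = nn.Sequential(
    nn.Linear(noise_dim, 256), nn.ReLU(),
    nn.Linear(256, feat_dim),
)

optD = torch.optim.Adam(netD.parameters(), lr=1e-4)
optG = torch.optim.Adam(netG.parameters(), lr=1e-4)
bce = nn.BCELoss()

for step in range(100):
    # Stand-ins for features computed by the frozen closed-world classifier.
    closed_feats = torch.randn(32, feat_dim)    # closed-set training features
    outlier_feats = torch.randn(32, feat_dim)   # real open-set (outlier) features
    fake_feats = netG(torch.randn(32, noise_dim))

    # Discriminator: closed-set -> 1, real outliers and generated fakes -> 0.
    d_loss = (bce(netD(closed_feats), torch.ones(32, 1))
              + bce(netD(outlier_feats), torch.zeros(32, 1))
              + bce(netD(fake_feats.detach()), torch.zeros(32, 1)))
    optD.zero_grad()
    d_loss.backward()
    optD.step()

    # Generator: try to fool the discriminator.
    g_loss = bce(netD(fake_feats), torch.ones(32, 1))
    optG.zero_grad()
    g_loss.backward()
    optG.step()

# At test time, netD(feature) serves directly as the open-vs-closed score.
```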
17 |
18 |
19 | **keywords**: out-of-distribution detection, anomaly detection, open-set recognition, novelty detection, density estimation, generative model, discriminative model, adversarial learning, image classification, semantic segmentation.
20 |
21 |
22 | If you find our model/method/dataset useful, please cite our work ([ICCV version on arxiv](https://arxiv.org/abs/2104.02939), [PAMI version](https://github.com/aimerykong/aimerykong.github.io/raw/main/OpenGAN_files/PAMI_OpenGAN_accepted_version.pdf)):
23 |
24 | @inproceedings{OpenGAN,
25 | title={OpenGAN: Open-Set Recognition via Open Data Generation},
26 | author={Kong, Shu and Ramanan, Deva},
27 | booktitle={ICCV},
28 | year={2021}
29 | }
30 |
31 | @article{OpenGAN_PAMI,
32 |   title={OpenGAN: Open-Set Recognition via Open Data Generation},
33 |   author={Kong, Shu and Ramanan, Deva},
34 |   journal={IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
35 |   year={2022}
36 | }
37 |
38 |
39 |
40 | Last update: July 2021
41 |
42 | Shu Kong
43 |
44 | aimerykong At g-m-a-i-l dot com
45 |
--------------------------------------------------------------------------------
/utils/dataset_tinyimagenet.py:
--------------------------------------------------------------------------------
1 | import os, random, time, copy
2 | from skimage import io, transform
3 | import numpy as np
4 | import os.path as path
5 | import scipy.io as sio
6 | from scipy import misc
7 | import matplotlib.pyplot as plt
8 | import PIL.Image
9 | import pickle
10 | import skimage.transform
11 | import csv
12 | import torch
13 | from torch.utils.data import Dataset, DataLoader
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | from torch.optim import lr_scheduler
17 | import torch.nn.functional as F
18 | from torch.autograd import Variable
19 |
20 | import torchvision
21 | from torchvision import datasets, models, transforms
22 |
23 |
24 |
25 |
26 | class TINYIMAGENET(Dataset):
27 | def __init__(self, size=(64,64), set_name='train',
28 | path_to_data='/scratch/shuk/dataset/tiny-imagenet-200',
29 | isAugment=True):
30 |
31 | self.path_to_data = path_to_data
32 | self.mapping_name2id = {}
33 | self.mapping_id2name = {}
34 | with open(path.join(self.path_to_data, 'wnids.txt')) as csv_file:
35 | csv_reader = csv.reader(csv_file, delimiter=' ')
36 | idx = 0
37 | for row in csv_reader:
38 | self.mapping_id2name[idx] = row[0]
39 | self.mapping_name2id[row[0]] = idx
40 | idx += 1
41 |
42 |
43 | if set_name=='test': set_name = 'val'
44 |
45 | self.size = size
46 | self.set_name = set_name
47 | self.path_to_data = path_to_data
48 | self.isAugment = isAugment
49 |
50 | self.imageNameList = []
51 | self.className = []
52 | self.labelList = []
53 | self.mappingLabel2Name = dict()
54 | curLabel = 0
55 |
56 |
57 | if self.set_name == 'val':
58 | with open(path.join(self.path_to_data, 'val', 'val_annotations.txt')) as csv_file:
59 | csv_reader = csv.reader(csv_file, delimiter='\t')
60 | line_count = 0
61 | for row in csv_reader:
62 | self.imageNameList += [path.join(self.path_to_data, 'val', 'images', row[0])]
63 | self.labelList += [self.mapping_name2id[row[1]]]
64 | else: # 'train'
65 | self.current_class_dir = path.join(self.path_to_data, self.set_name)
66 | for curClass in os.listdir(self.current_class_dir):
67 | if curClass[0]=='.': continue
68 |
69 | curLabel = self.mapping_name2id[curClass]
70 | for curImg in os.listdir(path.join(self.current_class_dir, curClass, 'images')):
71 | if curImg[0]=='.': continue
72 | self.labelList += [curLabel]
73 | self.imageNameList += [path.join(self.path_to_data, self.set_name, curClass, 'images', curImg)]
74 |
75 |
76 | self.current_set_len = len(self.labelList)
77 |
78 | if self.set_name=='test' or self.set_name=='val' or not self.isAugment:
79 | self.transform = transforms.Compose([
80 | transforms.ToTensor(),
81 | transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)),
82 | ]) # ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
83 | else:
84 | self.transform = transforms.Compose([
85 | transforms.RandomCrop(self.size[0], padding=4),
86 | transforms.RandomHorizontalFlip(),
87 | transforms.ToTensor(),
88 | transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)),
89 | ])
90 |
91 | def __len__(self):
92 | return self.current_set_len
93 |
94 | def __getitem__(self, idx):
95 | curLabel = np.asarray(self.labelList[idx])
96 | curImage = self.imageNameList[idx]
97 | curImage = PIL.Image.open(curImage).convert('RGB')
98 | curImage = self.transform(curImage)
99 |
100 | #print(idx, curLabel)
101 |
102 | #curLabel = torch.tensor([curLabel]).unsqueeze(0).unsqueeze(0)
103 |
104 | return curImage, curLabel
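
# -----------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file). The
# path below is a placeholder and must point at a local TinyImageNet copy.
if __name__ == '__main__':
    train_set = TINYIMAGENET(set_name='train',
                             path_to_data='/path/to/tiny-imagenet-200',
                             isAugment=True)
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)  # expected: [128, 3, 64, 64] and [128]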
--------------------------------------------------------------------------------
/utils/dataset_tinyimagenet_3sets.py:
--------------------------------------------------------------------------------
1 | import os, random, time, copy
2 | from skimage import io, transform
3 | import numpy as np
4 | import os.path as path
5 | import scipy.io as sio
6 | from scipy import misc
7 | import matplotlib.pyplot as plt
8 | import PIL.Image
9 | import pickle
10 | import skimage.transform
11 | import csv
12 | import torch
13 | from torch.utils.data import Dataset, DataLoader
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | from torch.optim import lr_scheduler
17 | import torch.nn.functional as F
18 | from torch.autograd import Variable
19 |
20 | import torchvision
21 | from torchvision import datasets, models, transforms
22 |
23 |
24 |
25 |
26 | class TINYIMAGENET(Dataset):
27 | def __init__(self, size=(64,64), set_name='train',
28 | path_to_data='/scratch/shuk/dataset/tiny-imagenet-200',
29 | isAugment=True):
30 |
31 | self.path_to_data = path_to_data
32 | self.mapping_name2id = {}
33 | self.mapping_id2name = {}
34 | with open(path.join(self.path_to_data, 'wnids.txt')) as csv_file:
35 | csv_reader = csv.reader(csv_file, delimiter=' ')
36 | idx = 0
37 | for row in csv_reader:
38 | self.mapping_id2name[idx] = row[0]
39 | self.mapping_name2id[row[0]] = idx
40 | idx += 1
41 |
42 |
43 | #if set_name=='test': set_name = 'val'
44 |
45 | self.size = size
46 | self.set_name = set_name
47 | self.path_to_data = path_to_data
48 | self.isAugment = isAugment
49 |
50 | self.imageNameList = []
51 | self.className = []
52 | self.labelList = []
53 | self.mappingLabel2Name = dict()
54 | curLabel = 0
55 |
56 | if self.set_name == 'test':
57 | img_dir = os.path.join(self.path_to_data, 'val', 'images')
58 | for file_name in os.listdir(img_dir):
59 | if file_name[-4:] == 'JPEG':
60 | self.imageNameList += [path.join(self.path_to_data, 'val', 'images', file_name)]
61 | self.labelList += [0]
62 |
63 | elif self.set_name == 'val':
64 | with open(path.join(self.path_to_data, 'val', 'val_annotations.txt')) as csv_file:
65 | csv_reader = csv.reader(csv_file, delimiter='\t')
66 | line_count = 0
67 | for row in csv_reader:
68 | self.imageNameList += [path.join(self.path_to_data, 'val', 'images', row[0])]
69 | self.labelList += [self.mapping_name2id[row[1]]]
70 | #with open(path.join(self.path_to_data, 'val', 'val_annotations.txt')) as csv_file:
71 | # csv_reader = csv.reader(csv_file, delimiter='\t')
72 | # line_count = 0
73 | # for row in csv_reader:
74 | # self.imageNameList += [path.join(self.path_to_data, 'val', 'images', row[0])]
75 | # self.labelList += [self.mapping_name2id[row[1]]]
76 | else: # 'train'
77 | self.current_class_dir = path.join(self.path_to_data, self.set_name)
78 | for curClass in os.listdir(self.current_class_dir):
79 | if curClass[0]=='.': continue
80 |
81 | curLabel = self.mapping_name2id[curClass]
82 | for curImg in os.listdir(path.join(self.current_class_dir, curClass, 'images')):
83 | if curImg[0]=='.': continue
84 | self.labelList += [curLabel]
85 | self.imageNameList += [path.join(self.path_to_data, self.set_name, curClass, 'images', curImg)]
86 |
87 |
88 | self.current_set_len = len(self.labelList)
89 |
90 | if self.set_name=='test' or self.set_name=='val' or not self.isAugment:
91 | self.transform = transforms.Compose([
92 | transforms.ToTensor(),
93 | transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)),
94 | ]) # ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
95 | else:
96 | self.transform = transforms.Compose([
97 | transforms.RandomCrop(self.size[0], padding=4),
98 | transforms.RandomHorizontalFlip(),
99 | transforms.ToTensor(),
100 | transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)),
101 | ])
102 |
103 | def __len__(self):
104 | return self.current_set_len
105 |
106 | def __getitem__(self, idx):
107 | curLabel = np.asarray(self.labelList[idx])
108 | curImage = self.imageNameList[idx]
109 | curImage = PIL.Image.open(curImage).convert('RGB')
110 | curImage = self.transform(curImage)
111 |
112 | return curImage, curLabel
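
# -----------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file). Unlike
# dataset_tinyimagenet.py, this variant keeps 'test' as a separate split that
# returns every validation image with a dummy label of 0, while 'val' returns
# the annotated labels. The path is a placeholder.
if __name__ == '__main__':
    splits = {name: TINYIMAGENET(set_name=name,
                                 path_to_data='/path/to/tiny-imagenet-200')
              for name in ['train', 'val', 'test']}
    for name, dset in splits.items():
        print(name, len(dset))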
--------------------------------------------------------------------------------
/utils/layers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | from __future__ import absolute_import, division, print_function
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
13 | def focus_to_intrinsics(opt, fxy): # fxy: Nx2 focal lengths
14 | """Convert Nx2 focal length fxy into camera intrinsics
15 | for homogeneous representation
16 | Nx2 --> Nx4x4
17 | """
18 | N, d = fxy.size()
19 |
20 | dummy_leftMat = np.zeros((4, 2), dtype=np.float32)
21 | dummy_leftMat[0, 0] = 1
22 | dummy_leftMat[1, 1] = 1
23 | dummy_rightMat = np.zeros((1, 4), dtype=np.float32)
24 | dummy_rightMat[0, 0] = 1
25 | dummy_rightMat[0, 1] = 1
26 | dummy_residual = np.zeros((4, 4), dtype=np.float32)
27 | dummy_residual[0, 2] = 0.5
28 | dummy_residual[1, 2] = 0.5
29 | dummy_residual[2, 2] = 1
30 | dummy_residual[3, 3] = 1
31 | dummy_identity = torch.eye(4).unsqueeze(0).expand(N, -1, -1)
32 |
33 | dummy_leftMat = torch.from_numpy(dummy_leftMat).unsqueeze(0).expand(N, -1, -1)
34 | dummy_rightMat = torch.from_numpy(dummy_rightMat)
35 | dummy_residual = torch.from_numpy(dummy_residual).unsqueeze(0).expand(N, -1, -1)
36 |
37 | if not opt.no_cuda:
38 | dummy_identity = dummy_identity.cuda()
39 | dummy_leftMat = dummy_leftMat.cuda()
40 | dummy_rightMat = dummy_rightMat.cuda()
41 | dummy_residual = dummy_residual.cuda()
42 |
43 | #print(dummy_leftMat.shape, fxy.shape)
44 | fxy = torch.matmul(dummy_leftMat, fxy.unsqueeze(-1)) # Nxd dxp --> Nxp
45 | fxy = torch.matmul(fxy, dummy_rightMat) # Nxd dxp --> Nxp
46 |
47 | fxy = fxy * dummy_identity + dummy_residual
48 | #print(fxy.shape)
49 | return fxy
50 |
51 |
52 | def disp_to_depth(disp, min_depth, max_depth):
53 | """Convert network's sigmoid output into depth predictcion
54 | The formula for this conversion is given in the 'additional considerations'
55 | section of the paper.
56 | """
57 | min_disp = 1 / max_depth
58 | max_disp = 1 / min_depth
59 | scaled_disp = min_disp + (max_disp - min_disp) * disp
60 | depth = 1 / scaled_disp
61 | return scaled_disp, depth
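# Example (illustrative numbers, not from the original code): with min_depth=0.1
# and max_depth=100, a sigmoid output of 0.5 gives
#   scaled_disp = 1/100 + (1/0.1 - 1/100) * 0.5 ~= 5.005  and  depth ~= 0.2,
# e.g. scaled_disp, depth = disp_to_depth(torch.tensor([0.5]), 0.1, 100.0)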
62 |
63 |
64 | def transformation_from_parameters(axisangle, translation, invert=False):
65 | """Convert the network's (axisangle, translation) output into a 4x4 matrix
66 | """
67 | R = rot_from_axisangle(axisangle)
68 | t = translation.clone()
69 |
70 | if invert:
71 | R = R.transpose(1, 2)
72 | t *= -1
73 |
74 | T = get_translation_matrix(t)
75 |
76 | if invert:
77 | M = torch.matmul(R, T)
78 | else:
79 | M = torch.matmul(T, R)
80 |
81 | return M
82 |
83 |
84 | def get_translation_matrix(translation_vector):
85 | """Convert a translation vector into a 4x4 transformation matrix
86 | """
87 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
88 |
89 | t = translation_vector.contiguous().view(-1, 3, 1)
90 |
91 | T[:, 0, 0] = 1
92 | T[:, 1, 1] = 1
93 | T[:, 2, 2] = 1
94 | T[:, 3, 3] = 1
95 | T[:, :3, 3, None] = t
96 |
97 | return T
98 |
99 |
100 | def rot_from_axisangle(vec):
101 | """Convert an axisangle rotation into a 4x4 transformation matrix
102 | (adapted from https://github.com/Wallacoloo/printipi)
103 | Input 'vec' has to be Bx1x3
104 | """
105 | angle = torch.norm(vec, 2, 2, True) # p=2, dim=2, keepdim=True
106 | axis = vec / (angle + 1e-7)
107 |
108 | ca = torch.cos(angle)
109 | sa = torch.sin(angle)
110 | C = 1 - ca
111 |
112 | x = axis[..., 0].unsqueeze(1)
113 | y = axis[..., 1].unsqueeze(1)
114 | z = axis[..., 2].unsqueeze(1)
115 |
116 | xs = x * sa
117 | ys = y * sa
118 | zs = z * sa
119 | xC = x * C
120 | yC = y * C
121 | zC = z * C
122 | xyC = x * yC
123 | yzC = y * zC
124 | zxC = z * xC
125 |
126 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
127 |
128 | rot[:, 0, 0] = torch.squeeze(x * xC + ca)
129 | rot[:, 0, 1] = torch.squeeze(xyC - zs)
130 | rot[:, 0, 2] = torch.squeeze(zxC + ys)
131 | rot[:, 1, 0] = torch.squeeze(xyC + zs)
132 | rot[:, 1, 1] = torch.squeeze(y * yC + ca)
133 | rot[:, 1, 2] = torch.squeeze(yzC - xs)
134 | rot[:, 2, 0] = torch.squeeze(zxC - ys)
135 | rot[:, 2, 1] = torch.squeeze(yzC + xs)
136 | rot[:, 2, 2] = torch.squeeze(z * zC + ca)
137 | rot[:, 3, 3] = 1
138 |
139 | return rot
140 |
141 |
142 | class ConvBlock(nn.Module):
143 | """Layer to perform a convolution followed by ELU
144 | """
145 | def __init__(self, in_channels, out_channels):
146 | super(ConvBlock, self).__init__()
147 |
148 | self.conv = Conv3x3(in_channels, out_channels)
149 | self.nonlin = nn.ELU(inplace=True)
150 | #self.nonlin = nn.ReLU()
151 |
152 | def forward(self, x):
153 | out = self.conv(x)
154 | out = self.nonlin(out)
155 | return out
156 |
157 |
158 | class Conv3x3(nn.Module):
159 | """Layer to pad and convolve input
160 | """
161 | def __init__(self, in_channels, out_channels, use_refl=True):
162 | super(Conv3x3, self).__init__()
163 |
164 | if use_refl:
165 | self.pad = nn.ReflectionPad2d(1)
166 | else:
167 | self.pad = nn.ZeroPad2d(1)
168 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
169 |
170 | def forward(self, x):
171 | out = self.pad(x)
172 | out = self.conv(out)
173 | return out
174 |
175 |
176 | class BackprojectDepth(nn.Module):
177 | """Layer to transform a depth image into a point cloud
178 | """
179 | def __init__(self, batch_size, height, width):
180 | super(BackprojectDepth, self).__init__()
181 |
182 | self.batch_size = batch_size
183 | self.height = height
184 | self.width = width
185 |
186 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
187 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
188 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords))
189 |
190 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width))
191 |
192 | self.pix_coords = torch.unsqueeze(torch.stack(
193 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
194 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
195 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1))
196 |
197 | def forward(self, depth, inv_K):
198 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
199 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points
200 | cam_points = torch.cat([cam_points, self.ones], 1)
201 |
202 | return cam_points
203 |
204 |
205 | class Project3D(nn.Module):
206 | """Layer which projects 3D points into a camera with intrinsics K and at position T
207 | """
208 | def __init__(self, batch_size, height, width, eps=1e-7):
209 | super(Project3D, self).__init__()
210 |
211 | self.batch_size = batch_size
212 | self.height = height
213 | self.width = width
214 | self.eps = eps
215 |
216 | def forward(self, points, K, T):
217 | P = torch.matmul(K, T)[:, :3, :]
218 |
219 | cam_points = torch.matmul(P, points)
220 |
221 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
222 | pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
223 | pix_coords = pix_coords.permute(0, 2, 3, 1)
224 | pix_coords[..., 0] /= self.width - 1
225 | pix_coords[..., 1] /= self.height - 1
226 | pix_coords = (pix_coords - 0.5) * 2
227 | return pix_coords
228 |
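# Usage sketch (shapes are illustrative assumptions): BackprojectDepth lifts a depth
# map into a homogeneous point cloud using inv_K, and Project3D reprojects those
# points into another camera with intrinsics K and relative pose T; the resulting
# grid is typically passed to F.grid_sample to warp the source image.
#   backproject = BackprojectDepth(batch_size=2, height=192, width=640)
#   project = Project3D(batch_size=2, height=192, width=640)
#   cam_points = backproject(depth, inv_K)       # depth: 2x1x192x640, inv_K: 2x4x4
#   pix_coords = project(cam_points, K, T)       # K, T: 2x4x4 -> grid of 2x192x640x2
#   warped = F.grid_sample(src_img, pix_coords)  # src_img: 2x3x192x640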
229 |
230 | def upsample(x):
231 | """Upsample input tensor by a factor of 2
232 | """
233 | return F.interpolate(x, scale_factor=2, mode="nearest")
234 |
235 |
236 | def get_smooth_loss(disp, img):
237 | """Computes the smoothness loss for a disparity image
238 | The color image is used for edge-aware smoothness
239 | """
240 | grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
241 | grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
242 |
243 | grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
244 | grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
245 |
246 | grad_disp_x *= torch.exp(-grad_img_x)
247 | grad_disp_y *= torch.exp(-grad_img_y)
248 |
249 | return grad_disp_x.mean() + grad_disp_y.mean()
250 |
251 |
252 | class SSIM(nn.Module):
253 | """Layer to compute the SSIM loss between a pair of images
254 | """
255 | def __init__(self):
256 | super(SSIM, self).__init__()
257 | self.mu_x_pool = nn.AvgPool2d(3, 1)
258 | self.mu_y_pool = nn.AvgPool2d(3, 1)
259 | self.sig_x_pool = nn.AvgPool2d(3, 1)
260 | self.sig_y_pool = nn.AvgPool2d(3, 1)
261 | self.sig_xy_pool = nn.AvgPool2d(3, 1)
262 |
263 | self.refl = nn.ReflectionPad2d(1)
264 |
265 | self.C1 = 0.01 ** 2
266 | self.C2 = 0.03 ** 2
267 |
268 | def forward(self, x, y):
269 | x = self.refl(x)
270 | y = self.refl(y)
271 |
272 | mu_x = self.mu_x_pool(x)
273 | mu_y = self.mu_y_pool(y)
274 |
275 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
276 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
277 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
278 |
279 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
280 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
281 |
282 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
283 |
284 |
285 | def compute_depth_errors(gt, pred):
286 | """Computation of error metrics between predicted and ground truth depths
287 | """
288 | thresh = torch.max((gt / pred), (pred / gt))
289 | a1 = (thresh < 1.25).float().mean()
290 | a2 = (thresh < 1.25 ** 2).float().mean()
291 | a3 = (thresh < 1.25 ** 3).float().mean()
292 |
293 | rmse = (gt - pred) ** 2
294 | rmse = torch.sqrt(rmse.mean())
295 |
296 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
297 | rmse_log = torch.sqrt(rmse_log.mean())
298 |
299 | abs_rel = torch.mean(torch.abs(gt - pred) / gt)
300 |
301 | sq_rel = torch.mean((gt - pred) ** 2 / gt)
302 |
303 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
304 |
--------------------------------------------------------------------------------
/utils/eval_funcs.py:
--------------------------------------------------------------------------------
1 | import os, random, time, copy
2 | from skimage import io, transform
3 | import numpy as np
4 | import os.path as path
5 | import scipy.io as sio
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 | import PIL.Image
9 | from sklearn.metrics import roc_curve, roc_auc_score, f1_score
10 | import pandas as pd
11 |
12 | import torch
13 | from torch.utils.data import Dataset, DataLoader
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | from torch.optim import lr_scheduler
17 | import torch.nn.functional as F
18 | from torch.autograd import Variable
19 |
20 | import torchvision
21 | from torchvision import models, transforms
22 |
23 | import sklearn.metrics
24 |
25 | def F_measure(preds, labels, openset=False, theta=None):
26 | if openset:
27 | # f1 score for openset evaluation
28 | true_pos = 0.
29 | false_pos = 0.
30 | false_neg = 0.
31 | for i in range(len(labels)):
32 | true_pos += 1 if preds[i] == labels[i] and labels[i] != -1 else 0
33 | false_pos += 1 if preds[i] != labels[i] and labels[i] != -1 else 0
34 | false_neg += 1 if preds[i] != labels[i] and labels[i] == -1 else 0
35 |
36 | precision = true_pos / (true_pos + false_pos)
37 | recall = true_pos / (true_pos + false_neg)
38 | return 2 * ((precision * recall) / (precision + recall + 1e-12))
39 | else: # Regular f1 score
40 | return f1_score(labels, preds, average='macro')
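# Example (hypothetical arrays): in the open-set protocol, unknowns carry label -1
# and a prediction of -1 means the sample was rejected as open-set, e.g.
#   labels = np.array([0, 1, 2, -1, -1])
#   preds  = np.array([0, 1, 1, -1,  2])
#   F_measure(preds, labels, openset=True)  # open-set F1 over the known classes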
41 |
42 | #
43 | # ref: https://github.com/lwneal/counterfactual-open-set/blob/master/generativeopenset/evaluation.py
44 | class ClassCentroids(nn.Module):
45 | def __init__(self, num_classes=10, feat_dim=2, device='cpu'):
46 | super(ClassCentroids, self).__init__()
47 | self.num_classes = num_classes
48 | self.feat_dim = feat_dim
49 | self.centers = torch.randn(self.num_classes, self.feat_dim)
50 | #self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
51 | self.device = device
52 | if self.device!='cpu':
53 | self.centers = self.centers.to(self.device)
54 |
55 | def forward(self, x, labels):
56 | batch_size = x.size(0)
57 | # ||x-y||_2 = (x-y)^2 = x^2 + y^2 - 2xy
58 | # This part of the calculation is “x^2+y^2”
59 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
60 | # This part is "x^2+y^2 - 2xy"
61 | distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
62 |
63 | classes = torch.arange(self.num_classes).long().to(self.device)
64 | if self.device!='cpu':
65 | classes = classes.to(self.device)
66 |
67 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
68 | mask = labels.eq(classes.expand(batch_size, self.num_classes))
69 |
70 | self.curDistMat = distmat
71 |
72 | dist = distmat * mask.float()
73 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
74 |
75 | return loss
76 |
77 | class CosCentroid(nn.Module):
78 | def __init__(self, num_classes=10, feat_dim=2, device='cpu'):
79 | super(CosCentroid, self).__init__()
80 | self.num_classes = num_classes
81 | self.feat_dim = feat_dim
82 | self.centers = torch.randn(self.num_classes, self.feat_dim)
83 | self.device = device
84 | #self.centers = F.normalize(self.centers, p=2, dim=1)
85 | if self.device!='cpu':
86 | self.centers = self.centers.to(self.device)
87 |
88 | def forward(self, x, label=0):
89 | x = F.normalize(x, p=2, dim=1)
90 | distmat = torch.zeros((x.shape[0], self.centers.shape[0])).to(self.device)
91 | distmat.addmm_(x, self.centers.t(), beta=0, alpha=-1)
92 | self.curDistMat = distmat
93 | return self.curDistMat
94 |
95 |
96 |
97 | def pca(X=np.array([]), no_dims=50):
98 | """
99 | Runs PCA on the NxD array X in order to reduce its dimensionality to
100 | no_dims dimensions.
101 | """
102 |
103 | print("Preprocessing the data using PCA...")
104 | (n, d) = X.shape
105 | m = np.mean(X, 0)
106 | X = X - np.tile(m, (n, 1))
107 | (l, M) = np.linalg.eig(np.dot(X.T, X))
108 | P = M[:, 0:no_dims]
109 | Y = np.dot(X, P)
110 | return Y, m, P
111 |
112 |
113 |
114 |
115 |
116 | def FetchFromSingleImage(curImg, cropSize=64, scaleList=[64, 78, 96, 128]):
117 | imgBatchList = []
118 |
119 | for curSize in scaleList:
120 | curImg = curImg.resize((curSize, curSize))
121 | curTransform = transforms.Compose([
122 | transforms.TenCrop(cropSize, vertical_flip=False),
123 | transforms.Lambda(lambda crops: torch.stack([
124 | transforms.ToTensor()(crop) for crop in crops])), # returns a 4D tensor
125 | transforms.Lambda(lambda crops: torch.stack([
126 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))(crop) for crop in crops])),
127 | ])
128 | imgBatchList += list(curTransform(curImg).unsqueeze(0))
129 |
130 |
131 | curImg = curImg.resize((cropSize,cropSize))
132 | curTransform = transforms.Compose([
133 | transforms.ToTensor(),
134 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
135 | ])
136 |
137 | imgBatchList += [curTransform(curImg).unsqueeze(0).clone()]
138 | imgBatchList += [curTransform(curImg.transpose(PIL.Image.FLIP_LEFT_RIGHT)).unsqueeze(0).clone()]
139 | imgBatchList = torch.cat(imgBatchList, 0)
140 | return imgBatchList
141 |
142 |
143 |
144 | class CustomizedPoolList(nn.Module):
145 | def __init__(self, poolSizeList=[32,32,16,8,4], poolType='max'):
146 | super(CustomizedPoolList, self).__init__()
147 |
148 | self.poolSizeList = poolSizeList
149 | self.poolType = poolType
150 | #self.linearLayers = OrderedDict()
151 | self.relu = nn.ReLU()
152 | #self.mnist_clsnet = nn.ModuleList(list(self.linearLayers.values()))
153 |
154 | def forward(self, feaList):
155 | x = []
156 | if self.poolType=='max':
157 | for i in range(len(self.poolSizeList)):
158 | if self.poolSizeList[i]>0:
159 | x += [F.max_pool2d(feaList[i], self.poolSizeList[i])]
160 | elif self.poolType=='avg':
161 | for i in range(len(self.poolSizeList)):
162 | if self.poolSizeList[i]>0:
163 | x += [F.avg_pool2d(feaList[i], self.poolSizeList[i])]
164 |
165 | x = torch.cat(x, 1)
166 | x = x.view(x.shape[0], -1)
167 | return x
168 |
169 |
170 |
171 | class weightedL1Loss(nn.Module):
172 | def __init__(self, weight=1):
173 | # mean over all
174 | super(weightedL1Loss, self).__init__()
175 | self.loss = nn.L1Loss()
176 | self.weight = weight
177 |
178 | def forward(self, inputs, target):
179 | lossValue = self.weight * self.loss(inputs, target)
180 | return lossValue
181 |
182 |
183 |
184 |
185 | class MetricLoss(nn.Module):
186 | """inner-class compactness, aka Center loss.
187 |
188 | Reference:
189 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
190 |
191 | Args:
192 | num_classes (int): number of classes.
193 | feat_dim (int): feature dimension.
194 | """
195 | def __init__(self, num_classes=10, feat_dim=2,
196 | weightCompactness=0.2,
197 | weightInner=1,
198 | weightInter=1.,
199 | marginAlpha=0.2,
200 | sepMultiplier=3,
201 | device='cpu'):
202 | super(MetricLoss, self).__init__()
203 | self.num_classes = num_classes
204 | self.feat_dim = feat_dim
205 | self.weightCompactness = weightCompactness
206 | self.weightInner = weightInner
207 | self.weightInter = weightInter
208 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
209 | self.device = device
210 | if self.device!='cpu':
211 | self.centers.data = self.centers.data.to(self.device)
212 | self.curDistMat = 0
213 | self.lossInner = 0
214 | self.lossInter = 0
215 | self.marginAlpha = marginAlpha
216 | self.sepMultiplier = sepMultiplier
217 | self.classes = torch.arange(self.num_classes).long().to(self.device)
218 |
219 | def forward(self, x, labels):
220 | """
221 | Args:
222 | x: feature matrix with shape (batch_size, feat_dim).
223 | labels: ground truth labels with shape (batch_size).
224 | """
225 | batch_size = x.size(0)
226 | # ||x-y||_2 = (x-y)^2 = x^2 + y^2 - 2xy
227 | # This part of the calculation is “x^2+y^2”
228 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
229 | # This part is "x^2+y^2 - 2xy"
230 | distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
231 |
232 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
233 | mask = labels.eq(self.classes.expand(batch_size, self.num_classes))
234 |
235 | self.curDistMat = distmat
236 | #print('self.curDistMat: ', self.curDistMat.shape)
237 |
238 | # inner loss
239 | dist = distmat * mask.float()
240 | self.lossInner = (dist-self.marginAlpha).clamp(min=0)
241 | self.lossInner = self.lossInner.mean()*self.weightInner # / batch_size
242 |
243 | # compactness loss
244 | loss = dist.clamp(min=1e-12, max=1e+12).mean() / batch_size
245 |
246 | # inter loss
247 | # distance between centroids should be at least three times larger than the defined margin alpha
248 | self.lossInter = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, self.num_classes) + torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, self.num_classes).t()
249 | self.lossInter.addmm_(self.centers, self.centers.t(), beta=1, alpha=-2)
250 | tmpMask = 1-torch.eye(self.num_classes).float().to(self.device)
251 | #tmpMask = tmpMask.reshape((1, self.num_classes, self.num_classes))
252 | #tmpMask = tmpMask.repeat(batch_size, 1, 1).to(self.device)
253 | self.lossInter = (self.marginAlpha*self.sepMultiplier-self.lossInter).clamp(min=0)
254 | self.lossInter = self.lossInter*tmpMask
255 | self.lossInter = self.lossInter.sum()*self.weightInter
256 |
257 | return loss*self.weightCompactness
258 |
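# Usage sketch (illustrative sizes): the compactness term is returned directly,
# while the inner/inter terms are exposed as attributes to be added by the caller.
#   criterion = MetricLoss(num_classes=10, feat_dim=128)
#   feats, labels = torch.randn(32, 128), torch.randint(0, 10, (32,))
#   loss = criterion(feats, labels) + criterion.lossInner + criterion.lossInter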
259 |
260 |
261 |
262 | def evaluate_openset(scores_closeset, scores_openset):
263 | y_true = np.array([0] * len(scores_closeset) + [1] * len(scores_openset))
264 | y_discriminator = np.concatenate([scores_closeset, scores_openset])
265 | auc_d, roc_to_plot = plot_roc(y_true, y_discriminator, 'Discriminator ROC')
266 | return auc_d, roc_to_plot
267 |
268 |
269 | def plot_roc(y_true, y_score, title="Receiver Operating Characteristic", **options):
270 | fpr, tpr, thresholds = roc_curve(y_true, y_score)
271 | auc_score = roc_auc_score(y_true, y_score)
272 | roc_to_plot = {'tp':tpr, 'fp':fpr, 'thresh':thresholds, 'auc_score':auc_score}
273 | #plot = plot_xy(fpr, tpr, x_axis="False Positive Rate", y_axis="True Positive Rate", title=title)
274 | #if options.get('roc_output'):
275 | # print("Saving ROC scores to file")
276 | # np.save(options['roc_output'], (fpr, tpr))
277 | #return auc_score, plot, roc_to_plot
278 | return auc_score, roc_to_plot
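# Example (placeholder scores): closed-set samples are labelled 0 and open-set
# samples 1, so higher scores should indicate "more open-set" for a high AUROC.
#   scores_closeset = np.random.rand(100) * 0.4
#   scores_openset  = np.random.rand(100) * 0.4 + 0.6
#   auroc, roc = evaluate_openset(scores_closeset, scores_openset)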
279 |
280 |
281 | def plot_xy(x, y, x_axis="X", y_axis="Y", title="Plot"):
282 | df = pd.DataFrame({'x': x, 'y': y})
283 | plot = df.plot(x='x', y='y')
284 |
285 | plot.grid(visible=True, which='major')
286 | plot.grid(visible=True, which='minor')
287 |
288 | plot.set_title(title)
289 | plot.set_ylabel(y_axis)
290 | plot.set_xlabel(x_axis)
291 | return plot
292 |
293 |
294 | def backup_Weibull(encoder, clsModel, dataloader_train_closeset, dataloader_test_closeset, dataloader_test_openset, device):  # requires the libmr package
295 | print("Weibull: computing features for all correctly-classified training data")
296 | activation_vectors = {}
297 | for images, labels in dataloader_train_closeset:
298 | images = images.to(device)
299 | labels = labels.type(torch.long).view(-1).to(device)
300 |
301 | embFeature = encoder(images)
302 | logits = clsModel(embFeature)
303 | #logits = F.softmax(logits, dim=1)
304 |
305 | correctly_labeled = (logits.data.max(1)[1] == labels)
306 | labels_np = labels.cpu().numpy()
307 | logits_np = logits.data.cpu().numpy()
308 | for i, label in enumerate(labels_np):
309 | if not correctly_labeled[i]:
310 | continue
311 | if label not in activation_vectors:
312 | activation_vectors[label] = []
313 | activation_vectors[label].append(logits_np[i])
314 |
315 | print("Computed activation_vectors for {} known classes".format(len(activation_vectors)))
316 | for class_idx in activation_vectors:
317 | print("Class {}: {} images".format(class_idx, len(activation_vectors[class_idx])))
318 |
319 | # Compute a mean activation vector for each class
320 | print("Weibull computing mean activation vectors...")
321 | mean_activation_vectors = {}
322 | for class_idx in activation_vectors:
323 | mean_activation_vectors[class_idx] = np.array(activation_vectors[class_idx]).mean(axis=0)
324 |
325 | WEIBULL_TAIL_SIZE = 20
326 | # Initialize one libMR Weibull object for each class
327 | print("Fitting Weibull to distance distribution of each class")
328 | weibulls = {}
329 | for class_idx in activation_vectors:
330 | distances = []
331 | mav = mean_activation_vectors[class_idx]
332 | for v in activation_vectors[class_idx]:
333 | distances.append(np.linalg.norm(v - mav))
334 | mr = libmr.MR()
335 | tail_size = min(len(distances), WEIBULL_TAIL_SIZE)
336 | mr.fit_high(distances, tail_size)
337 | weibulls[class_idx] = mr
338 | print("Weibull params for class {}: {}".format(class_idx, mr.get_params()))
339 |
340 |
341 | # Apply Weibull score to every logit
342 | weibull_scores_closeset = []
343 | logits_closeset = []
344 | classes = activation_vectors.keys()
345 | for images, labels in dataloader_test_closeset:
346 | images = images.to(device)
347 | labels = labels.type(torch.long).view(-1).to(device)
348 | embFeature = encoder(images)
349 | batch_logits = clsModel(embFeature).data.cpu().numpy()
350 | batch_weibull = np.zeros(shape=batch_logits.shape)
351 | for activation_vector in batch_logits:
352 | weibull_row = np.ones(len(classes))
353 | for class_idx in classes:
354 | mav = mean_activation_vectors[class_idx]
355 | dist = np.linalg.norm(activation_vector - mav)
356 | weibull_row[class_idx] = 1 - weibulls[class_idx].w_score(dist)
357 | weibull_scores_closeset.append(weibull_row)
358 | logits_closeset.append(activation_vector)
359 |
360 | weibull_scores_closeset = np.array(weibull_scores_closeset)
361 | logits_closeset = np.array(logits_closeset)
362 | openmax_scores_closeset = -np.log(np.sum(np.exp(logits_closeset * weibull_scores_closeset), axis=1))
363 |
364 |
365 | # Apply Weibull score to every logit
366 | weibull_scores_openset = []
367 | logits_openset = []
368 | classes = activation_vectors.keys()
369 | for images, labels in dataloader_test_openset:
370 | images = images.to(device)
371 | labels = labels.type(torch.long).view(-1).to(device)
372 | embFeature = encoder(images)
373 | batch_logits = clsModel(embFeature).data.cpu().numpy()
374 | batch_weibull = np.zeros(shape=batch_logits.shape)
375 | for activation_vector in batch_logits:
376 | weibull_row = np.ones(len(classes))
377 | for class_idx in classes:
378 | mav = mean_activation_vectors[class_idx]
379 | dist = np.linalg.norm(activation_vector - mav)
380 | weibull_row[class_idx] = 1 - weibulls[class_idx].w_score(dist)
381 | weibull_scores_openset.append(weibull_row)
382 | logits_openset.append(activation_vector)
383 |
384 | weibull_scores_openset = np.array(weibull_scores_openset)
385 | logits_openset = np.array(logits_openset)
386 | openmax_scores_openset = -np.log(np.sum(np.exp(logits_openset * weibull_scores_openset), axis=1))
387 |
388 | return openmax_scores_closeset, openmax_scores_openset
--------------------------------------------------------------------------------
/utils/dataset_cifar10.py:
--------------------------------------------------------------------------------
1 | import os, random, time, copy
2 | from skimage import io, transform
3 | import numpy as np
4 | import os.path as path
5 | import scipy.io as sio
6 | from scipy import misc
7 | import matplotlib.pyplot as plt
8 | import PIL.Image
9 | import pickle
10 | import skimage.transform
11 |
12 | import torch
13 | from torch.utils.data import Dataset, DataLoader
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | from torch.optim import lr_scheduler
17 | import torch.nn.functional as F
18 | from torch.autograd import Variable
19 |
20 | import torchvision
21 | from torchvision import datasets, models, transforms
22 |
23 |
24 |
25 |
26 |
27 |
28 | class CIFAR_OneClass4Train(Dataset):
29 | def __init__(self, size=(32,32), set_name='train',
30 | numKnown=6, numTotal=10, runIdx=0,
31 | classLabelIndex=0,
32 | path_to_data='/scratch/shuk/dataset/cifar10/cifar-10-batches-py', isOpenset=True,
33 | isAugment=True):
34 | self.classLabelIndex = classLabelIndex
35 | self.isAugment = isAugment
36 | self.set_name = set_name
37 | self.size = size
38 | self.numTotal = numTotal
39 | self.numKnown = numKnown
40 | self.runIdx = runIdx
41 | self.isOpenset = isOpenset
42 | self.path_to_data = path_to_data
43 |
44 | ######### get the data
45 | # train set
46 | curpath = path.join(self.path_to_data, 'data_batch_1')
47 | with open(curpath, 'rb') as fo:
48 | curpath = pickle.load(fo, encoding='bytes')
49 |
50 | self.imgList = curpath[b'data'].copy()
51 | self.labelList = curpath[b'labels'].copy()
52 |
53 | for i in range(2, 6):
54 | curpath = path.join(path_to_data, 'data_batch_{}'.format(i))
55 | with open(curpath, 'rb') as fo:
56 | curpath = pickle.load(fo, encoding='bytes')
57 | self.imgList = np.concatenate((self.imgList, curpath[b'data'].copy()))
58 | self.labelList += curpath[b'labels'].copy()
59 | del curpath
60 |
61 | ####### set pre-processing operations
62 | self.transform = transforms.Compose([
63 | transforms.ToTensor(),
64 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
65 | ])
66 | #self.transform = transforms.Compose([
67 | # transforms.RandomCrop(32, padding=4),
68 | # transforms.RandomHorizontalFlip(),
69 | # transforms.ToTensor(),
70 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
71 | #])
72 |
73 | self.imgList = np.reshape(self.imgList, (self.imgList.shape[0], 3, 32, 32))
74 | self.size = size
75 | self.labelList = np.asarray(self.labelList).astype(np.float32).reshape((-1, 1))
76 | self.current_set_len = len(self.labelList)
77 |
78 |
79 | ########### shuffle for openset train-test data
80 | random.seed(0)
81 |
82 | self.randShuffleIndexSets = []
83 | self.OpenSetSplit = [
84 | [3, 6, 7, 8],
85 | [1, 2, 4, 6],
86 | [2, 3, 4, 9],
87 | [0, 1, 2, 6],
88 | [4, 5, 6, 9],
89 | [0, 2, 4, 6, 8, 9], # tinyImageNet
90 | ]
91 | for i in range(6):
92 | tmp = list(range(10))
93 | tmpCloseset = list(set(tmp)-set(self.OpenSetSplit[i]))
94 | self.randShuffleIndexSets += [tmpCloseset+self.OpenSetSplit[i]]
95 |
96 |
97 | self.curShuffleSet = self.randShuffleIndexSets[runIdx]
98 | self.closesetActualLabels = self.curShuffleSet[:self.numKnown]
99 | self.opensetActualLabels = self.curShuffleSet[self.numKnown:]
100 | self.labelmapping = {}
101 | self.labelmapping_open = {}
102 |
103 | for i in range(len(self.closesetActualLabels)):
104 | self.labelmapping[self.closesetActualLabels[i]] = i
105 | for j in range(len(self.opensetActualLabels)):
106 | self.labelmapping_open[self.opensetActualLabels[j]] = self.numKnown + j
107 |
108 | self.validList = []
109 | self.newLabel = []
110 | for i in range(len(self.labelList)):
111 | if self.isOpenset:
112 | if self.labelList[i][0] in self.opensetActualLabels:
113 | self.validList += [i]
114 | self.newLabel += [self.labelmapping_open[self.labelList[i][0]]]
115 | else:
116 | if self.labelList[i][0] in self.closesetActualLabels:
117 | tmp_new_label = self.labelmapping[self.labelList[i][0]]
118 | if tmp_new_label==self.classLabelIndex:
119 | self.validList += [i]
120 | self.newLabel += [tmp_new_label]
121 |
122 | self.imgList = self.imgList[self.validList, :]
123 | self.labelList = np.asarray(self.newLabel).reshape((len(self.newLabel),1))
124 | self.current_set_len = len(self.labelList)
125 |
126 | def __len__(self):
127 | return self.current_set_len
128 |
129 | def __getitem__(self, idx):
130 | curImage = self.imgList[idx,:]
131 | curLabel = self.labelList[idx].astype(np.float32)
132 |
133 | curImage = PIL.Image.fromarray(curImage.transpose(1,2,0))
134 | curImage = self.transform(curImage)
135 | curLabel = torch.from_numpy(curLabel).unsqueeze(0).unsqueeze(0)
136 |
137 | return curImage, curLabel
138 |
139 |
140 |
141 |
142 |
143 |
144 | class CIFAR_OPENSET_CLS(Dataset):
145 | def __init__(self, size=(32,32), set_name='train',
146 | numKnown=6, numTotal=10, runIdx=0,
147 | path_to_data='/scratch/shuk/dataset/cifar10/cifar-10-batches-py', isOpenset=True,
148 | isAugment=True):
149 |
150 | if set_name=='val':
151 | set_name = 'test'
152 |
153 | self.isAugment = isAugment
154 | self.set_name = set_name
155 | self.size = size
156 | self.numTotal = numTotal
157 | self.numKnown = numKnown
158 | self.runIdx = runIdx
159 | self.isOpenset = isOpenset
160 | self.path_to_data = path_to_data
161 |
162 | ######### get the data
163 | if self.set_name=='test':
164 | self.imgList = path.join(self.path_to_data, 'test_batch')
165 | with open(self.imgList, 'rb') as fo:
166 | self.imgList = pickle.load(fo, encoding='bytes')
167 | self.labelList = self.imgList[b'labels'].copy()
168 | self.imgList = self.imgList[b'data']
169 | else: # train set
170 | curpath = path.join(self.path_to_data, 'data_batch_1')
171 | with open(curpath, 'rb') as fo:
172 | curpath = pickle.load(fo, encoding='bytes')
173 |
174 | self.imgList = curpath[b'data'].copy()
175 | self.labelList = curpath[b'labels'].copy()
176 |
177 | for i in range(2, 6):
178 | curpath = path.join(path_to_data, 'data_batch_{}'.format(i))
179 | with open(curpath, 'rb') as fo:
180 | curpath = pickle.load(fo, encoding='bytes')
181 | self.imgList = np.concatenate((self.imgList, curpath[b'data'].copy()))
182 | self.labelList += curpath[b'labels'].copy()
183 | del curpath
184 |
185 |
186 | ####### set pre-processing operations
187 | if self.set_name=='test' or not self.isAugment:
188 | self.transform = transforms.Compose([
189 | transforms.ToTensor(),
190 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
191 | ])
192 | else:
193 | self.transform = transforms.Compose([
194 | transforms.RandomCrop(32, padding=4),
195 | transforms.RandomHorizontalFlip(),
196 | transforms.ToTensor(),
197 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
198 | ])
199 |
200 | self.imgList = np.reshape(self.imgList, (self.imgList.shape[0], 3, 32, 32))
201 | self.size = size
202 | self.labelList = np.asarray(self.labelList).astype(np.float32).reshape((-1, 1))
203 | self.current_set_len = len(self.labelList)
204 |
205 |
206 | ########### shuffle for openset train-test data
207 | random.seed(0)
208 |
209 | self.randShuffleIndexSets = []
210 | self.OpenSetSplit = [
211 | [3, 6, 7, 8],
212 | [1, 2, 4, 6],
213 | [2, 3, 4, 9],
214 | [0, 1, 2, 6],
215 | [4, 5, 6, 9],
216 | [0, 2, 4, 7, 8, 9], # tinyImageNet
217 | ]
218 | for i in range(6):
219 | tmp = list(range(10))
220 | tmpCloseset = list(set(tmp)-set(self.OpenSetSplit[i]))
221 | self.randShuffleIndexSets += [tmpCloseset+self.OpenSetSplit[i]]
222 |
223 | #for i in range(10):
224 | # a = list(range(10))
225 | # random.shuffle(a)
226 | # self.randShuffleIndexSets += [a]
227 |
228 |
229 | self.curShuffleSet = self.randShuffleIndexSets[runIdx]
230 | self.closesetActualLabels = self.curShuffleSet[:self.numKnown]
231 | self.opensetActualLabels = self.curShuffleSet[self.numKnown:]
232 | self.labelmapping = {}
233 | self.labelmapping_open = {}
234 |
235 | for i in range(len(self.closesetActualLabels)):
236 | self.labelmapping[self.closesetActualLabels[i]] = i
237 | for j in range(len(self.opensetActualLabels)):
238 | self.labelmapping_open[self.opensetActualLabels[j]] = self.numKnown + j
239 |
240 |
241 | #self.imgList = np.loadtxt(self.path_to_csv, delimiter=",")
242 | #self.labelList = np.asfarray(self.imgList[:, :1])
243 | #self.imgList = np.asfarray(self.imgList[:, 1:]) * self.fac + 0.01
244 |
245 | self.validList = []
246 | self.newLabel = []
247 | for i in range(len(self.labelList)):
248 | if self.isOpenset:
249 | if self.labelList[i][0] in self.opensetActualLabels:
250 | self.validList += [i]
251 | self.newLabel += [self.labelmapping_open[self.labelList[i][0]]]
252 | else:
253 | if self.labelList[i][0] in self.closesetActualLabels:
254 | self.validList += [i]
255 | self.newLabel += [self.labelmapping[self.labelList[i][0]]]
256 |
257 | self.imgList = self.imgList[self.validList, :]
258 | self.labelList = np.asarray(self.newLabel).reshape((len(self.newLabel),1))
259 | self.current_set_len = len(self.labelList)
260 |
261 | def __len__(self):
262 | return self.current_set_len
263 |
264 | def __getitem__(self, idx):
265 | curImage = self.imgList[idx,:]
266 | curLabel = self.labelList[idx].astype(np.float32)
267 |
268 | #if self.isAugment:
269 | # curImage = PIL.Image.fromarray(curImage.transpose(1,2,0))
270 | # curImage = self.transform(curImage)
271 | #else:
272 | # curImage = torch.from_numpy(curImage.astype(np.float32))
273 |
274 | curImage = PIL.Image.fromarray(curImage.transpose(1,2,0))
275 | curImage = self.transform(curImage)
276 | curLabel = torch.from_numpy(curLabel).unsqueeze(0).unsqueeze(0)
277 |
278 | '''
279 | curImage = curImage.astype(np.float32)
280 | curLabel = curLabel.astype(np.float32)
281 |
282 | curImage = torch.from_numpy(curImage)
283 | curLabel = torch.from_numpy(curLabel).unsqueeze(0).unsqueeze(0)
284 | '''
285 | return curImage, curLabel
286 |
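# Usage sketch (placeholder path; added for illustration): runIdx selects one of
# the predefined open-set splits above; isOpenset=False yields the numKnown=6
# closed-set classes relabelled 0..5, isOpenset=True the held-out open classes.
#   closeset_test = CIFAR_OPENSET_CLS(set_name='test', runIdx=0, isOpenset=False,
#                                     path_to_data='/path/to/cifar-10-batches-py')
#   openset_test  = CIFAR_OPENSET_CLS(set_name='test', runIdx=0, isOpenset=True,
#                                     path_to_data='/path/to/cifar-10-batches-py')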
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 | class CIFAR10_CLS_full_aug(Dataset):
301 | def __init__(self, size=(32,32), set_name='train', path_to_data='/scratch/shuk/dataset/cifar10/cifar-10-batches-py', isAugment=True):
302 | if set_name=='val':
303 | set_name = 'test'
304 | self.set_name = set_name
305 | self.path_to_data = path_to_data
306 | self.isAugment = isAugment
307 |
308 | if self.set_name=='test':
309 | self.imgList = path.join(self.path_to_data, 'test_batch')
310 | with open(self.imgList, 'rb') as fo:
311 | self.imgList = pickle.load(fo, encoding='bytes')
312 | self.labelList = self.imgList[b'labels'].copy()
313 | self.imgList = self.imgList[b'data']
314 | else: # train set
315 | curpath = path.join(self.path_to_data, 'data_batch_1')
316 | with open(curpath, 'rb') as fo:
317 | curpath = pickle.load(fo, encoding='bytes')
318 |
319 | self.imgList = curpath[b'data'].copy()
320 | self.labelList = curpath[b'labels'].copy()
321 |
322 | for i in range(2, 6):
323 | curpath = path.join(path_to_data, 'data_batch_{}'.format(i))
324 | with open(curpath, 'rb') as fo:
325 | curpath = pickle.load(fo, encoding='bytes')
326 | self.imgList = np.concatenate((self.imgList, curpath[b'data'].copy()))
327 | self.labelList += curpath[b'labels'].copy()
328 | del curpath
329 |
330 | if self.set_name=='test':
331 | self.transform = transforms.Compose([
332 | transforms.ToTensor(),
333 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
334 | ])
335 | else:
336 | self.transform = transforms.Compose([
337 | transforms.RandomCrop(32, padding=4),
338 | transforms.RandomHorizontalFlip(),
339 | transforms.ToTensor(),
340 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
341 | ])
342 |
343 | self.imgList = np.reshape(self.imgList, (self.imgList.shape[0], 3, 32, 32))
344 | self.size = size
345 | self.labelList = np.asarray(self.labelList).astype(np.float32)
346 | self.current_set_len = len(self.labelList)
347 |
348 |
349 | def __len__(self):
350 | return self.current_set_len
351 |
352 | def __getitem__(self, idx):
353 | curImage = self.imgList[idx]
354 | curLabel = np.asarray(self.labelList[idx])
355 |
356 | if self.isAugment:
357 | curImage = PIL.Image.fromarray(curImage.transpose(1,2,0))
358 | curImage = self.transform(curImage)
359 | else:
360 | curImage = torch.from_numpy(curImage)
361 |
362 | curLabel = torch.from_numpy(curLabel).unsqueeze(0).unsqueeze(0)
363 |
364 | return curImage, curLabel
365 |
366 |
367 |
368 | class CIFAR10_CLS_full(Dataset):
369 | def __init__(self, size=(32,32), set_name='train', path_to_data='/scratch/shuk/dataset/cifar10/cifar-10-batches-py'):
370 | if set_name=='val':
371 | set_name = 'test'
372 | self.set_name = set_name
373 | self.path_to_data = path_to_data
374 |
375 | if self.set_name=='test':
376 | self.imgList = path.join(self.path_to_data, 'test_batch')
377 | with open(self.imgList, 'rb') as fo:
378 | self.imgList = pickle.load(fo, encoding='bytes')
379 | self.labelList = self.imgList[b'labels'].copy()
380 | self.imgList = self.imgList[b'data']
381 | else: # train set
382 | curpath = path.join(self.path_to_data, 'data_batch_1')
383 | with open(curpath, 'rb') as fo:
384 | curpath = pickle.load(fo, encoding='bytes')
385 |
386 | self.imgList = curpath[b'data'].copy()
387 | self.labelList = curpath[b'labels'].copy()
388 |
389 | for i in range(2, 6):
390 | curpath = path.join(path_to_data, 'data_batch_{}'.format(i))
391 | with open(curpath, 'rb') as fo:
392 | curpath = pickle.load(fo, encoding='bytes')
393 | self.imgList = np.concatenate((self.imgList, curpath[b'data'].copy()))
394 | self.labelList += curpath[b'labels'].copy()
395 | del curpath
396 |
397 | self.imgList = np.reshape(self.imgList, (self.imgList.shape[0], 3, 32, 32))
398 | self.size = size
399 | self.fac = 0.99 / 255
400 | self.labelList = np.asarray(self.labelList).astype(np.float32)
401 | self.imgList = self.imgList.astype(np.float32) * self.fac + 0.01
402 | self.current_set_len = len(self.labelList)
403 |
404 | def __len__(self):
405 | return self.current_set_len
406 |
407 | def __getitem__(self, idx):
408 | curImage = self.imgList[idx].astype(np.float32)
409 | curLabel = np.asarray(self.labelList[idx]).astype(np.float32)
410 |
411 | curImage = torch.from_numpy(curImage)
412 | curLabel = torch.from_numpy(curLabel).unsqueeze(0).unsqueeze(0)
413 |
414 | return curImage, curLabel
415 |
416 |
417 |
--------------------------------------------------------------------------------
/utils/dataset_cityscapes.py:
--------------------------------------------------------------------------------
1 | import os, random, time, copy
2 | from skimage import io, transform
3 | import json
4 | import numpy as np
5 | from subprocess import check_output
6 | import numpy as np
7 | import os.path as path
8 | import scipy.io as sio
9 | from scipy import misc
10 | import matplotlib.pyplot as plt
11 | from PIL import Image
12 | from tqdm import tqdm
13 | import pickle
14 | import skimage.transform
15 | import csv
16 | import torch
17 | from torch.utils.data import Dataset, DataLoader
18 | import torch.nn as nn
19 | import torch.optim as optim
20 | from torch.optim import lr_scheduler
21 | import torch.nn.functional as F
22 | from torch.autograd import Variable
23 | import torchvision
24 | from torchvision import datasets, models, transforms
25 | from collections import namedtuple
26 |
27 |
28 | class Cityscapes(Dataset):
29 | """`Cityscapes `_ Dataset.
30 |
31 | Args:
32 | root (string): Root directory of dataset where directory ``leftImg8bit``
33 | and ``gtFine`` or ``gtCoarse`` are located.
34 | split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
35 | otherwise ``train``, ``train_extra`` or ``val``
36 | mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
37 | target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
38 | or ``color``. Can also be a list to output a tuple with all specified target types.
39 | transform (callable, optional): A function/transform that takes in a PIL image
40 | and returns a transformed version. E.g, ``transforms.RandomCrop``
41 | target_transform (callable, optional): A function/transform that takes in the
42 | target and transforms it.
43 | transforms (callable, optional): A function/transform that takes input sample and its target as entry
44 | and returns a transformed version.
45 |
46 | Examples:
47 |
48 | Get semantic segmentation target
49 |
50 | .. code-block:: python
51 |
52 | dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
53 | target_type='semantic')
54 |
55 | img, smnt = dataset[0]
56 |
57 | Get multiple targets
58 |
59 | .. code-block:: python
60 |
61 | dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
62 | target_type=['instance', 'color', 'polygon'])
63 |
64 | img, (inst, col, poly) = dataset[0]
65 |
66 | Validate on the "coarse" set
67 |
68 | .. code-block:: python
69 |
70 | dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
71 | target_type='semantic')
72 |
73 | img, smnt = dataset[0]
74 | """
75 |
76 | # Based on https://github.com/mcordts/cityscapesScripts
77 | CityscapesClass = namedtuple('CityscapesClass',
78 | ['name', 'id', 'train_id', 'category', 'category_id',
79 | 'has_instances', 'ignore_in_eval', 'color'])
80 |
81 | classes = [
82 | CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
83 | CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
84 | CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
85 | CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
86 | CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
87 | CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
88 | CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
89 | CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
90 | CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
91 | CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
92 | CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
93 | CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
94 | CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
95 | CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
96 | CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
97 | CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
98 | CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
99 | CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
100 | CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
101 | CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
102 | CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
103 | CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
104 | CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
105 | CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
106 | CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
107 | CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
108 | CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
109 | CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
110 | CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
111 | CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
112 | CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
113 | CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
114 | CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
115 | CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
116 | CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
117 | ]
118 |
119 | def __init__(self, root='/home/skong2/restore/dataset/Cityscapes',
120 | newsize=(256, 256),
121 | split='train',
122 | mode='fine',
123 | target_type='semantic',
124 | transform=None,
125 | target_transform=None,
126 | transforms=None):
127 |
128 | #super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
129 | self.newsize = newsize
130 | self.flagResize = True
131 | if newsize[0]<0 or newsize[1]<0:
132 | self.flagResize = False
133 |
134 | self.root = root
135 | self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
136 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
137 | self.targets_dir = os.path.join(self.root, self.mode, split)
138 | self.target_type = target_type
139 | self.split = split
140 | self.images = []
141 | self.targets = []
142 | if self.split=='test':
143 | self.split = 'val'
144 | self.transform = transform
145 | self.transforms = transforms
146 | self.target_transform = target_transform
147 |
148 | #verify_str_arg(mode, "mode", ("fine", "coarse"))
149 | if mode == "fine":
150 | valid_modes = ("train", "test", "val")
151 | else:
152 | valid_modes = ("train", "train_extra", "val")
153 |
154 | msg = ("Unknown value '{}' for argument split if mode is '{}'. "
155 | "Valid values are {{{}}}.")
156 |
157 | #msg = msg.format(split, mode, iterable_to_str(valid_modes))
158 | #verify_str_arg(split, "split", valid_modes, msg)
159 |
160 | if not isinstance(target_type, list):
161 | self.target_type = [target_type]
162 |
163 | #[verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color")) for value in self.target_type]
164 |
165 | if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
166 | if split == 'train_extra':
167 | image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip'))
168 | else:
169 | image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip'))
170 |
171 | if self.mode == 'gtFine':
172 | target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip'))
173 | elif self.mode == 'gtCoarse':
174 | target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip'))
175 |
176 | if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
177 | extract_archive(from_path=image_dir_zip, to_path=self.root)
178 | extract_archive(from_path=target_dir_zip, to_path=self.root)
179 | else:
180 | raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
181 | ' specified "split" and "mode" are inside the "root" directory')
182 |
183 | for city in os.listdir(self.images_dir):
184 | img_dir = os.path.join(self.images_dir, city)
185 | target_dir = os.path.join(self.targets_dir, city)
186 | for file_name in os.listdir(img_dir):
187 | target_types = []
188 | for t in self.target_type:
189 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
190 | self._get_target_suffix(self.mode, t))
191 | target_types.append(os.path.join(target_dir, target_name))
192 |
193 | self.images.append(os.path.join(img_dir, file_name))
194 | self.targets.append(target_types)
195 |
196 | def __getitem__(self, index):
197 | """
198 | Args:
199 | index (int): Index
200 | Returns:
201 | tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
202 | than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
203 | """
204 |
205 | image = Image.open(self.images[index]).convert('RGB')
206 | #b, g, r = image.split()
207 | #image = Image.merge("RGB", (r, g, b))
208 |
209 | if self.flagResize:
210 | image = image.resize(self.newsize, resample=Image.BILINEAR)
211 | #print(self.targets[index], self.target_type)
212 |
213 | targets = []
214 | for i, t in enumerate(self.target_type):
215 | if t == 'polygon':
216 | target = self._load_json(self.targets[index][i])
217 | else:
218 | target = Image.open(self.targets[index][i])
219 |
220 | if self.flagResize:
221 | target = target.resize(self.newsize, resample=Image.NEAREST)
222 |
223 | targets.append(target)
224 |
225 | target = tuple(targets) if len(targets) > 1 else targets[0]
226 |
227 | image = self.transform(image)
228 | #print('image', type(image), image.shape)
229 |
230 | target = np.asarray(target).astype(np.float32)
231 | target = torch.from_numpy(target)
232 | #target = self.target_transform(target)
233 | #print('target', type(target), target.shape)
234 |
235 | if self.transforms is not None:
236 | image, target = self.transforms(image, target)
237 |
238 | return image, target
239 |
240 |
241 | def __len__(self):
242 | return len(self.images)
243 |
244 | def extra_repr(self):
245 | lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
246 | return '\n'.join(lines).format(**self.__dict__)
247 |
248 | def _load_json(self, path):
249 | with open(path, 'r') as file:
250 | data = json.load(file)
251 | return data
252 |
253 | def _get_target_suffix(self, mode, target_type):
254 | if target_type == 'instance':
255 | return '{}_instanceIds.png'.format(mode)
256 | elif target_type == 'semantic':
257 | return '{}_labelIds.png'.format(mode)
258 | elif target_type == 'color':
259 | return '{}_color.png'.format(mode)
260 | else:
261 | return '{}_polygons.json'.format(mode)
262 |
263 |
264 |
265 |
266 |
267 | '''
268 |
269 | Label = namedtuple( 'Label' , [
270 |
271 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
272 | # We use them to uniquely name a class
273 |
274 | 'id' , # An integer ID that is associated with this label.
275 | # The IDs are used to represent the label in ground truth images
276 | # An ID of -1 means that this label does not have an ID and thus
277 | # is ignored when creating ground truth images (e.g. license plate).
278 | # Do not modify these IDs, since exactly these IDs are expected by the
279 | # evaluation server.
280 |
281 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
282 | # ground truth images with train IDs, using the tools provided in the
283 | # 'preparation' folder. However, make sure to validate or submit results
284 | # to our evaluation server using the regular IDs above!
285 | # For trainIds, multiple labels might have the same ID. Then, these labels
286 | # are mapped to the same class in the ground truth images. For the inverse
287 | # mapping, we use the label that is defined first in the list below.
288 | # For example, mapping all void-type classes to the same ID in training,
289 | # might make sense for some approaches.
290 | # Max value is 255!
291 |
292 | 'category' , # The name of the category that this label belongs to
293 |
294 | 'categoryId' , # The ID of this category. Used to create ground truth images
295 | # on category level.
296 |
297 | 'hasInstances', # Whether this label distinguishes between single instances or not
298 |
299 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
300 | # during evaluations or not
301 |
302 | 'color' , # The color of this label
303 | ] )
304 |
305 |
306 |
307 | labels = [
308 | # name id trainId category catId hasInstances ignoreInEval color
309 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
310 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
311 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
312 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
313 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
314 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
315 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
316 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
317 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
318 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
319 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
320 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
321 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
322 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
323 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
324 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
325 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
326 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
327 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
328 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
329 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
330 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
331 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
332 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
333 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
334 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
335 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
336 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
337 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
338 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
339 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
340 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
341 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
342 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
343 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
344 | ]
345 |
346 |
347 | '''
--------------------------------------------------------------------------------
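A minimal usage sketch for the `Cityscapes` dataset class defined in `utils/dataset_cityscapes.py` above. The `root` path is a placeholder, not a path from the repo; note that `__getitem__` applies `self.transform(image)` unconditionally, so a `transform` such as `transforms.ToTensor()` has to be passed in:

    import torch
    from torchvision import transforms
    from utils.dataset_cityscapes import Cityscapes

    # point root at your own Cityscapes download; newsize resizes both image and labels
    dataset = Cityscapes(root='/path/to/Cityscapes', newsize=(256, 256),
                         split='train', mode='fine', target_type='semantic',
                         transform=transforms.ToTensor())

    img, smnt = dataset[0]            # image tensor and label-id map, both resized to 256x256
    print(img.shape, smnt.shape)      # e.g. torch.Size([3, 256, 256]) torch.Size([256, 256])

    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
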
/utils/dataset_cityscapes4OpenGAN.py:
--------------------------------------------------------------------------------
1 | import os, random, time, copy
2 | from skimage import io, transform
3 | import json
4 | import numpy as np
5 | from subprocess import check_output
6 | from torchvision.datasets.utils import extract_archive  # used when extracting the dataset zip files below
7 | import os.path as path
8 | import scipy.io as sio
9 | from scipy import misc
10 | import matplotlib.pyplot as plt
11 | from PIL import Image
12 | from tqdm import tqdm
13 | import pickle
14 | import skimage.transform
15 | import csv
16 | import torch
17 | from torch.utils.data import Dataset, DataLoader
18 | import torch.nn as nn
19 | import torch.optim as optim
20 | from torch.optim import lr_scheduler
21 | import torch.nn.functional as F
22 | from torch.autograd import Variable
23 | import torchvision
24 | from torchvision import datasets, models, transforms
25 | from collections import namedtuple
26 |
27 |
28 | class Cityscapes(Dataset):
29 |     """`Cityscapes <https://www.cityscapes-dataset.com/>`_ Dataset.
30 |
31 | Args:
32 | root (string): Root directory of dataset where directory ``leftImg8bit``
33 | and ``gtFine`` or ``gtCoarse`` are located.
34 | split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
35 | otherwise ``train``, ``train_extra`` or ``val``
36 | mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
37 | target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
38 | or ``color``. Can also be a list to output a tuple with all specified target types.
39 | transform (callable, optional): A function/transform that takes in a PIL image
40 | and returns a transformed version. E.g, ``transforms.RandomCrop``
41 | target_transform (callable, optional): A function/transform that takes in the
42 | target and transforms it.
43 | transforms (callable, optional): A function/transform that takes input sample and its target as entry
44 | and returns a transformed version.
45 |
46 | Examples:
47 |
48 | Get semantic segmentation target
49 |
50 | .. code-block:: python
51 |
52 | dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
53 | target_type='semantic')
54 |
55 | img, smnt = dataset[0]
56 |
57 | Get multiple targets
58 |
59 | .. code-block:: python
60 |
61 | dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
62 | target_type=['instance', 'color', 'polygon'])
63 |
64 | img, (inst, col, poly) = dataset[0]
65 |
66 | Validate on the "coarse" set
67 |
68 | .. code-block:: python
69 |
70 | dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
71 | target_type='semantic')
72 |
73 | img, smnt = dataset[0]
74 | """
75 |
76 | # Based on https://github.com/mcordts/cityscapesScripts
77 | CityscapesClass = namedtuple('CityscapesClass',
78 | ['name', 'id', 'train_id', 'category', 'category_id',
79 | 'has_instances', 'ignore_in_eval', 'color'])
80 |
81 | classes = [
82 | CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
83 | CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
84 | CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
85 | CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
86 | CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
87 | CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
88 | CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
89 | CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
90 | CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
91 | CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
92 | CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
93 | CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
94 | CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
95 | CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
96 | CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
97 | CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
98 | CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
99 | CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
100 | CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
101 | CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
102 | CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
103 | CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
104 | CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
105 | CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
106 | CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
107 | CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
108 | CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
109 | CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
110 | CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
111 | CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
112 | CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
113 | CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
114 | CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
115 | CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
116 | CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
117 | ]
118 |
119 | def __init__(self, root='/home/skong2/restore/dataset/Cityscapes',
120 | newsize=(256, 256),
121 | split='train',
122 | mode='fine',
123 | trainnum=10,
124 | target_type='semantic',
125 | transform=None,
126 | target_transform=None,
127 | transforms=None):
128 |
129 | #super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
130 | self.newsize = newsize
131 | self.trainnum = trainnum
132 | self.flagResize = True
133 | if newsize[0]<0 or newsize[1]<0:
134 | self.flagResize = False
135 |
136 | self.root = root
137 | self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
138 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
139 | self.targets_dir = os.path.join(self.root, self.mode, split)
140 | self.target_type = target_type
141 | self.split = split
142 | self.images = []
143 | self.targets = []
144 | if self.split=='test':
145 | self.split = 'val'
146 | self.transform = transform
147 | self.transforms = transforms
148 | self.target_transform = target_transform
149 |
150 | #verify_str_arg(mode, "mode", ("fine", "coarse"))
151 | if mode == "fine":
152 | valid_modes = ("train", "test", "val")
153 | else:
154 | valid_modes = ("train", "train_extra", "val")
155 |
156 | msg = ("Unknown value '{}' for argument split if mode is '{}'. "
157 | "Valid values are {{{}}}.")
158 |
159 | #msg = msg.format(split, mode, iterable_to_str(valid_modes))
160 | #verify_str_arg(split, "split", valid_modes, msg)
161 |
162 | if not isinstance(target_type, list):
163 | self.target_type = [target_type]
164 |
165 | #[verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color")) for value in self.target_type]
166 |
167 | if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
168 | if split == 'train_extra':
169 | image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip'))
170 | else:
171 | image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip'))
172 |
173 | if self.mode == 'gtFine':
174 | target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip'))
175 | elif self.mode == 'gtCoarse':
176 | target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip'))
177 |
178 | if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
179 | extract_archive(from_path=image_dir_zip, to_path=self.root)
180 | extract_archive(from_path=target_dir_zip, to_path=self.root)
181 | else:
182 | raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
183 | ' specified "split" and "mode" are inside the "root" directory')
184 |
185 | for city in os.listdir(self.images_dir):
186 | img_dir = os.path.join(self.images_dir, city)
187 | target_dir = os.path.join(self.targets_dir, city)
188 | for file_name in os.listdir(img_dir):
189 | target_types = []
190 | for t in self.target_type:
191 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
192 | self._get_target_suffix(self.mode, t))
193 | target_types.append(os.path.join(target_dir, target_name))
194 |
195 | self.images.append(os.path.join(img_dir, file_name))
196 | self.targets.append(target_types)
197 |
198 | if self.split=='train' and self.trainnum>0:
199 | self.images = self.images[:self.trainnum]
200 | self.targets = self.targets[:self.trainnum]
201 | elif self.split=='train' and self.trainnum<-5:
202 | self.images = self.images[self.trainnum:]
203 | self.targets = self.targets[self.trainnum:]
204 | else:
205 | self.trainnum = -1
206 |
207 |
208 | def __getitem__(self, index):
209 | """
210 | Args:
211 | index (int): Index
212 | Returns:
213 | tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
214 | than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
215 | """
216 |
217 | image = Image.open(self.images[index]).convert('RGB')
218 | #b, g, r = image.split()
219 | #image = Image.merge("RGB", (r, g, b))
220 |
221 | if self.flagResize:
222 | image = image.resize(self.newsize, resample=Image.BILINEAR)
223 | #print(self.targets[index], self.target_type)
224 |
225 | targets = []
226 | for i, t in enumerate(self.target_type):
227 | if t == 'polygon':
228 | target = self._load_json(self.targets[index][i])
229 | else:
230 | target = Image.open(self.targets[index][i])
231 |
232 | if self.flagResize:
233 | target = target.resize(self.newsize, resample=Image.NEAREST)
234 |
235 | targets.append(target)
236 |
237 | target = tuple(targets) if len(targets) > 1 else targets[0]
238 |
239 | image = self.transform(image)
240 | #print('image', type(image), image.shape)
241 |
242 | target = np.asarray(target).astype(np.float32)
243 | target = torch.from_numpy(target)
244 | #target = self.target_transform(target)
245 | #print('target', type(target), target.shape)
246 |
247 | if self.transforms is not None:
248 | image, target = self.transforms(image, target)
249 |
250 | return image, target
251 |
252 |
253 | def __len__(self):
254 | return len(self.images)
255 |
256 | def extra_repr(self):
257 | lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
258 | return '\n'.join(lines).format(**self.__dict__)
259 |
260 | def _load_json(self, path):
261 | with open(path, 'r') as file:
262 | data = json.load(file)
263 | return data
264 |
265 | def _get_target_suffix(self, mode, target_type):
266 | if target_type == 'instance':
267 | return '{}_instanceIds.png'.format(mode)
268 | elif target_type == 'semantic':
269 | return '{}_labelIds.png'.format(mode)
270 | elif target_type == 'color':
271 | return '{}_color.png'.format(mode)
272 | else:
273 | return '{}_polygons.json'.format(mode)
274 |
275 |
276 |
277 |
278 |
279 | '''
280 |
281 | Label = namedtuple( 'Label' , [
282 |
283 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
284 | # We use them to uniquely name a class
285 |
286 | 'id' , # An integer ID that is associated with this label.
287 | # The IDs are used to represent the label in ground truth images
288 | # An ID of -1 means that this label does not have an ID and thus
289 | # is ignored when creating ground truth images (e.g. license plate).
290 | # Do not modify these IDs, since exactly these IDs are expected by the
291 | # evaluation server.
292 |
293 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
294 | # ground truth images with train IDs, using the tools provided in the
295 | # 'preparation' folder. However, make sure to validate or submit results
296 | # to our evaluation server using the regular IDs above!
297 | # For trainIds, multiple labels might have the same ID. Then, these labels
298 | # are mapped to the same class in the ground truth images. For the inverse
299 | # mapping, we use the label that is defined first in the list below.
300 | # For example, mapping all void-type classes to the same ID in training,
301 | # might make sense for some approaches.
302 | # Max value is 255!
303 |
304 | 'category' , # The name of the category that this label belongs to
305 |
306 | 'categoryId' , # The ID of this category. Used to create ground truth images
307 | # on category level.
308 |
309 | 'hasInstances', # Whether this label distinguishes between single instances or not
310 |
311 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
312 | # during evaluations or not
313 |
314 | 'color' , # The color of this label
315 | ] )
316 |
317 |
318 |
319 | labels = [
320 | # name id trainId category catId hasInstances ignoreInEval color
321 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
322 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
323 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
324 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
325 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
326 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
327 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
328 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
329 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
330 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
331 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
332 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
333 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
334 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
335 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
336 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
337 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
338 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
339 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
340 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
341 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
342 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
343 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
344 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
345 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
346 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
347 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
348 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
349 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
350 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
351 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
352 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
353 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
354 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
355 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
356 | ]
357 |
358 |
359 | '''
--------------------------------------------------------------------------------
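The only change from the previous `Cityscapes` class is the `trainnum` argument in `utils/dataset_cityscapes4OpenGAN.py` above: on the train split, a positive value keeps only the first `trainnum` (image, target) pairs, a value below -5 keeps the last `|trainnum|` pairs, and anything else leaves the split untouched. A sketch of how this might be used (the root path is again a placeholder):

    from torchvision import transforms
    from utils.dataset_cityscapes4OpenGAN import Cityscapes

    # keep only the first 10 training images (trainnum=10 is the default)
    closed_train = Cityscapes(root='/path/to/Cityscapes', newsize=(256, 256),
                              split='train', mode='fine', target_type='semantic',
                              trainnum=10, transform=transforms.ToTensor())
    print(len(closed_train))   # 10

    # trainnum < -5 keeps the last |trainnum| images instead, e.g. as a held-out tail
    tail_set = Cityscapes(root='/path/to/Cityscapes', split='train', mode='fine',
                          trainnum=-100, transform=transforms.ToTensor())
    print(len(tail_set))       # 100
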
/utils/network_arch_tinyimagenet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import numpy as np
3 | import torchvision
4 | from torchvision import datasets, models, transforms
5 | import torch
6 | import torch.nn as nn
7 | from collections import OrderedDict
8 | from utils.layers import *
9 | import torchvision.models as models
10 | import torch.utils.model_zoo as model_zoo
11 | import numpy as np
12 | import os, math
13 | from torch.utils.data import Dataset, DataLoader
14 | import torch.nn.functional as F  # F.avg_pool2d is used in ResnetEncoder.forward
15 |
16 |
17 |
18 |
19 | class Discriminator80x80InstNorm(nn.Module):
20 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], frameStackNumber=3):
21 | super(Discriminator80x80InstNorm, self).__init__()
22 | self.device = device
23 | self.frameStackNumber = frameStackNumber
24 | self.patchSize = patchSize
25 | self.outputSize = [patchSize[0]/16, patchSize[1]/16]
26 |
27 | self.discriminator = nn.Sequential(
28 | # 128-->60
29 | nn.Conv2d(self.frameStackNumber, 64, kernel_size=5, padding=0, stride=2, bias=True),
30 | nn.LeakyReLU(0.2, inplace=True),
31 |
32 | # 60-->33
33 | nn.Conv2d(64, 128, kernel_size=5, padding=0, stride=2, bias=False),
34 | nn.InstanceNorm2d(128, momentum=0.001, affine=False, track_running_stats=False),
35 | nn.LeakyReLU(0.2, inplace=True),
36 | # 33->
37 | nn.Conv2d(128, 256, kernel_size=3, padding=0, stride=2, bias=False),
38 | nn.InstanceNorm2d(256, momentum=0.001, affine=False, track_running_stats=False),
39 | nn.LeakyReLU(0.2, inplace=True),
40 | #
41 | nn.Conv2d(256, 512, kernel_size=3, padding=0, stride=2, bias=False),
42 | nn.InstanceNorm2d(512, momentum=0.001, affine=False, track_running_stats=False),
43 | nn.LeakyReLU(0.2, inplace=True),
44 | # final classification for 'real(1) vs. fake(0)'
45 | nn.Conv2d(512, 1, kernel_size=2, padding=0, stride=2, bias=True),
46 | nn.Sigmoid()
47 | )
48 |
49 | def forward(self, X):
50 | return self.discriminator(X)
51 |
52 |
53 |
54 | class Discriminator80x80(nn.Module):
55 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], frameStackNumber=3):
56 | super(Discriminator80x80, self).__init__()
57 | self.device = device
58 | self.frameStackNumber = frameStackNumber
59 | self.patchSize = patchSize
60 | self.outputSize = [patchSize[0]/16, patchSize[1]/16]
61 |
62 | self.discriminator = nn.Sequential(
63 | # 128-->60
64 | nn.Conv2d(self.frameStackNumber, 64, kernel_size=5, padding=0, stride=2, bias=False),
65 | nn.LeakyReLU(0.2, inplace=True),
66 |
67 | # 60-->33
68 | nn.Conv2d(64, 128, kernel_size=5, padding=0, stride=2, bias=False),
69 | nn.BatchNorm2d(128),
70 | nn.LeakyReLU(0.2, inplace=True),
71 | # 33->
72 | nn.Conv2d(128, 256, kernel_size=3, padding=0, stride=2, bias=False),
73 | nn.BatchNorm2d(256),
74 | nn.LeakyReLU(0.2, inplace=True),
75 | #
76 | nn.Conv2d(256, 512, kernel_size=3, padding=0, stride=2, bias=False),
77 | nn.BatchNorm2d(512),
78 | nn.LeakyReLU(0.2, inplace=True),
79 | # final classification for 'real(1) vs. fake(0)'
80 | nn.Conv2d(512, 1, kernel_size=2, padding=0, stride=2, bias=True),
81 | nn.Sigmoid()
82 | )
83 |
84 | def forward(self, X):
85 | return self.discriminator(X)
86 |
87 |
88 |
89 | class Discriminator70x70(nn.Module):
90 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], frameStackNumber=3):
91 | super(Discriminator70x70, self).__init__()
92 | self.device = device
93 | self.frameStackNumber = frameStackNumber
94 | self.patchSize = patchSize
95 | self.outputSize = [patchSize[0]/16, patchSize[1]/16]
96 |
97 | self.discriminator = nn.Sequential(
98 | # 128-->60
99 | nn.Conv2d(self.frameStackNumber, 64, kernel_size=4, padding=0, stride=2, bias=False),
100 | nn.LeakyReLU(0.2, inplace=True),
101 |
102 | # 60-->33
103 | nn.Conv2d(64, 128, kernel_size=4, padding=0, stride=2, bias=False),
104 | nn.BatchNorm2d(128),
105 | nn.LeakyReLU(0.2, inplace=True),
106 | # 33->
107 | nn.Conv2d(128, 256, kernel_size=4, padding=0, stride=2, bias=False),
108 | nn.BatchNorm2d(256),
109 | nn.LeakyReLU(0.2, inplace=True),
110 | #
111 | nn.Conv2d(256, 512, kernel_size=4, padding=0, stride=2, bias=False),
112 | nn.BatchNorm2d(512),
113 | nn.LeakyReLU(0.2, inplace=True),
114 | # final classification for 'real(1) vs. fake(0)'
115 | nn.Conv2d(512, 1, kernel_size=2, padding=0, stride=2, bias=True),
116 | nn.Sigmoid()
117 | )
118 |
119 | def forward(self, X):
120 | return self.discriminator(X)
121 |
122 |
123 | class Discriminator(nn.Module):
124 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], frameStackNumber=3):
125 | super(Discriminator, self).__init__()
126 | self.device = device
127 | self.frameStackNumber = frameStackNumber
128 | self.patchSize = patchSize
129 | self.outputSize = [patchSize[0]/16, patchSize[1]/16]
130 |
131 | self.discriminator = nn.Sequential(
132 | # 128-->60
133 | nn.Conv2d(self.frameStackNumber, 64, kernel_size=9, padding=0, stride=2, bias=False),
134 | nn.LeakyReLU(0.2, inplace=True),
135 |
136 | # 60-->33
137 | nn.Conv2d(64, 128, kernel_size=5, padding=0, stride=2, bias=False),
138 | nn.BatchNorm2d(128),
139 | nn.LeakyReLU(0.2, inplace=True),
140 | # 33->
141 | nn.Conv2d(128, 256, kernel_size=3, padding=0, stride=2, bias=False),
142 | nn.BatchNorm2d(256),
143 | nn.LeakyReLU(0.2, inplace=True),
144 |
145 | nn.Conv2d(256, 256, kernel_size=3, padding=0, stride=2, bias=False),
146 | nn.BatchNorm2d(256),
147 | nn.LeakyReLU(0.2, inplace=True),
148 | # dropout
149 | nn.Dropout(0.7),
150 | # final classification for 'real(1) vs. fake(0)'
151 | nn.Conv2d(256, 1, kernel_size=2, padding=0, stride=2, bias=True),
152 | nn.Sigmoid()
153 | )
154 |
155 | def forward(self, X):
156 | return self.discriminator(X)
157 |
158 |
159 |
160 |
161 |
162 | class GAN_Encoder(nn.Module):
163 | def __init__(self, embDimension=512):
164 | super(self.__class__, self).__init__()
165 |
166 | self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
167 | self.conv2 = nn.Conv2d(64, 128, 1, 2, 0, bias=False)
168 | self.conv3 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
169 | self.conv4 = nn.Conv2d(128, 256, 1, 2, 0, bias=False)
170 | self.conv5 = nn.Conv2d(256, 256, 3, 1, 1, bias=False)
171 | self.conv6 = nn.Conv2d(256, 512, 1, 2, 0, bias=False)
172 | self.conv7 = nn.Conv2d(512, 512, 3, 1, 1, bias=False)
173 | self.conv8 = nn.Conv2d(512, 512, 1, 2, 0, bias=False)
174 | self.conv9 = nn.Conv2d(512, 512, 3, 1, 1, bias=False)
175 | self.conv10 = nn.Conv2d(512, 512, 1, 2, 0, bias=False)
176 | self.conv11 = nn.Conv2d(512, embDimension, 3, 1, 1, bias=False)
177 |
178 | self.bn1 = nn.BatchNorm2d(64)
179 | self.bn2 = nn.BatchNorm2d(128)
180 | self.bn3 = nn.BatchNorm2d(128)
181 | self.bn4 = nn.BatchNorm2d(256)
182 | self.bn5 = nn.BatchNorm2d(256)
183 | self.bn6 = nn.BatchNorm2d(512)
184 | self.bn7 = nn.BatchNorm2d(512)
185 | self.bn8 = nn.BatchNorm2d(512)
186 | self.bn9 = nn.BatchNorm2d(512)
187 | self.bn10 = nn.BatchNorm2d(512)
188 | self.bn11 = nn.BatchNorm2d(embDimension)
189 |
190 | self.apply(weights_init)
191 |
192 |
193 | def forward(self, x, output_scale=1):
194 | batch_size = len(x)
195 |
196 | x = self.conv1(x)
197 | x = self.bn1(x)
198 | x = nn.LeakyReLU(0.2)(x)
199 | x = self.conv2(x)
200 | x = self.bn2(x)
201 | x = nn.LeakyReLU(0.2)(x)
202 | x = self.conv3(x)
203 | x = self.bn3(x)
204 | x = nn.LeakyReLU(0.2)(x)
205 |
206 | x = self.conv4(x)
207 | x = self.bn4(x)
208 | x = nn.LeakyReLU(0.2)(x)
209 | x = self.conv5(x)
210 | x = self.bn5(x)
211 | x = nn.LeakyReLU(0.2)(x)
212 | x = self.conv6(x)
213 | x = self.bn6(x)
214 | x = nn.LeakyReLU(0.2)(x)
215 |
216 | x = self.conv7(x)
217 | x = self.bn7(x)
218 | x = nn.LeakyReLU(0.2)(x)
219 | x = self.conv8(x)
220 | x = self.bn8(x)
221 | x = nn.LeakyReLU(0.2)(x)
222 | x = self.conv9(x)
223 | x = self.bn9(x)
224 | x = nn.LeakyReLU(0.2)(x)
225 |
226 | x = self.conv10(x)
227 | x = self.bn10(x)
228 | x = nn.LeakyReLU(0.2)(x)
229 |
230 | return x
231 |
232 |
233 | class GAN_Decoder(nn.Module):
234 | def __init__(self, nz=64, ngf=64, nc=3):
235 | super(GAN_Decoder, self).__init__()
236 |
237 | # torch.nn.ConvTranspose2d(
238 | # in_channels, out_channels, kernel_size,
239 | # stride=1, padding=0, output_padding=0, groups=1,
240 | # bias=True, dilation=1, padding_mode='zeros')
241 |
242 | self.main = nn.Sequential(
243 | # input is Z, going into a convolution
244 | nn.ConvTranspose2d(nz, ngf*4, 4, 2, 1, bias=False),
245 | nn.BatchNorm2d(ngf * 4),
246 | nn.ReLU(True),
247 |             # state size. (ngf*4) x 2 x 2
248 | nn.ConvTranspose2d(ngf*4, ngf*4, 4, 2, 1, bias=False),
249 | nn.BatchNorm2d(ngf * 4),
250 | nn.ReLU(True),
251 | # state size. (ngf*4) x 4 x 4
252 | nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
253 | nn.BatchNorm2d(ngf * 2),
254 | nn.ReLU(True),
255 | # state size. (ngf*2) x 8 x 8
256 | nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
257 | nn.BatchNorm2d(ngf),
258 | nn.ReLU(True),
259 | # state size. (ngf) x 16 x 16
260 | nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=True)
261 | #nn.Tanh()
262 | # state size. (nc) x 32 x 32
263 | )
264 |
265 | def forward(self, x):
266 | return self.main(x)
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 | def weights_init(m):
284 | classname = m.__class__.__name__
285 | # TODO: what about fully-connected layers?
286 | if classname.find('Conv') != -1:
287 | m.weight.data.normal_(0.0, 0.05)
288 | elif classname.find('BatchNorm') != -1:
289 | m.weight.data.normal_(1.0, 0.02)
290 | m.bias.data.fill_(0)
291 |
292 |
293 |
294 |
295 |
296 |
297 | class MyDecoder(nn.Module):
298 | def __init__(self, latent_size=512, input_scale=4, insertConv=False):
299 | super(self.__class__, self).__init__()
300 | self.latent_size = latent_size
301 | self.input_scale = input_scale
302 | self.fc1 = nn.Linear(latent_size, 512*2*2, bias=False)
303 | self.insertConv = insertConv
304 |
305 | self.conv2_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=False)
306 | self.conv2 = nn.ConvTranspose2d( 512, 512, 4, stride=2, padding=1, bias=False)
307 |         self.conv2_mid = nn.Conv2d(512, 512, 3, 1, 1, bias=False)  # takes conv2's 512-channel output when insertConv=True
308 |
309 | self.conv3_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=False)
310 | self.conv3 = nn.ConvTranspose2d( 512, 256, 4, stride=2, padding=1, bias=False)
311 | self.conv3_mid = nn.Conv2d(256, 256, 3, 1, 1, bias=False)
312 |
313 | self.conv4_in = nn.ConvTranspose2d(latent_size, 256, 1, stride=1, padding=0, bias=False)
314 | self.conv4 = nn.ConvTranspose2d( 256, 128, 4, stride=2, padding=1, bias=False)
315 | self.conv4_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
316 |
317 | self.conv5 = nn.ConvTranspose2d( 128, 128, 4, stride=2, padding=1, bias=False)
318 | self.conv5_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
319 |
320 |
321 | self.conv6 = nn.ConvTranspose2d( 128, 3, 4, stride=2, padding=1, bias=True)
322 |
323 |
324 | self.bn1 = nn.BatchNorm2d(512)
325 | self.bn2 = nn.BatchNorm2d(512)
326 | self.bn2_mid = nn.BatchNorm2d(512)
327 | self.bn3 = nn.BatchNorm2d(256)
328 | self.bn3_mid = nn.BatchNorm2d(256)
329 | self.bn4 = nn.BatchNorm2d(128)
330 | self.bn4_mid = nn.BatchNorm2d(128)
331 | self.bn5 = nn.BatchNorm2d(128)
332 | self.bn5_mid = nn.BatchNorm2d(128)
333 |
334 | self.apply(weights_init)
335 | self.cuda()
336 |
337 |
338 | def forward(self, x):
339 | input_scale=self.input_scale
340 | batch_size = x.shape[0]
341 |
342 | if input_scale <= 1:
343 | x = self.fc1(x)
344 | x = x.resize(batch_size, 512, 2, 2)
345 |
346 | # 512 x 2 x 2
347 | if input_scale == 2:
348 | x = x.view(batch_size, self.latent_size, 2, 2)
349 | x = self.conv2_in(x)
350 | if input_scale <= 2:
351 | x = self.conv2(x)
352 | x = nn.LeakyReLU()(x)
353 | x = self.bn2(x)
354 | if self.insertConv:
355 | x = self.conv2_mid(x)
356 | x = nn.LeakyReLU()(x)
357 | x = self.bn2_mid(x)
358 |
359 | # 512 x 4 x 4
360 | if input_scale == 4:
361 | x = x.view(batch_size, self.latent_size, 4, 4)
362 | x = self.conv3_in(x)
363 | if input_scale <= 4:
364 | x = self.conv3(x)
365 | x = nn.LeakyReLU()(x)
366 | x = self.bn3(x)
367 | if self.insertConv:
368 | x = self.conv3_mid(x)
369 | x = nn.LeakyReLU()(x)
370 | x = self.bn3_mid(x)
371 |
372 |
373 | # 256 x 8 x 8
374 | if input_scale == 8:
375 | x = x.view(batch_size, self.latent_size, 8, 8)
376 | x = self.conv4_in(x)
377 | if input_scale <= 8:
378 | x = self.conv4(x)
379 | x = nn.LeakyReLU()(x)
380 | x = self.bn4(x)
381 | if self.insertConv:
382 | x = self.conv4_mid(x)
383 | x = nn.LeakyReLU()(x)
384 | x = self.bn4_mid(x)
385 |
386 |
387 | # 128 x 16 x 16
388 | x = self.conv5(x)
389 | x = nn.LeakyReLU()(x)
390 | x = self.bn5_mid(x)
391 |
392 | # 3 x 32 x 32
393 | #x = nn.Sigmoid()(x)
394 |
395 | x = self.conv6(x)
396 | return x
397 |
398 |
399 |
400 |
401 |
402 | class MySingleBigDecoder(nn.Module):
403 | def __init__(self, latent_size=512, input_scale=4, insertConv=False, nClasses=200):
404 | super(self.__class__, self).__init__()
405 | self.latent_size = latent_size
406 | self.input_scale = input_scale
407 | self.fc1 = nn.Linear(latent_size, 512*2*2, bias=False)
408 | self.insertConv = insertConv
409 | self.nClasses = nClasses
410 |
411 | self.conv2_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=False)
412 | self.conv2 = nn.ConvTranspose2d( 512, 512, 4, stride=2, padding=1, bias=False)
413 |         self.conv2_mid = nn.Conv2d(512, 512, 3, 1, 1, bias=False)  # takes conv2's 512-channel output when insertConv=True
414 |
415 | self.conv3_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=False)
416 | self.conv3 = nn.ConvTranspose2d( 512, 256, 4, stride=2, padding=1, bias=False)
417 | self.conv3_mid = nn.Conv2d(256, 256, 3, 1, 1, bias=False)
418 |
419 | self.conv4_in = nn.ConvTranspose2d(latent_size, 256, 1, stride=1, padding=0, bias=False)
420 | self.conv4 = nn.ConvTranspose2d( 256, 128, 4, stride=2, padding=1, bias=False)
421 | self.conv4_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
422 |
423 | self.conv5 = nn.ConvTranspose2d( 128, 128, 4, stride=2, padding=1, bias=False)
424 | self.conv5_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
425 |
426 |
427 | self.conv6 = nn.ConvTranspose2d( 128, 3*nClasses, 4, stride=2, padding=1, bias=True)
428 |
429 |
430 | self.bn1 = nn.BatchNorm2d(512)
431 | self.bn2 = nn.BatchNorm2d(512)
432 | self.bn2_mid = nn.BatchNorm2d(512)
433 | self.bn3 = nn.BatchNorm2d(256)
434 | self.bn3_mid = nn.BatchNorm2d(256)
435 | self.bn4 = nn.BatchNorm2d(128)
436 | self.bn4_mid = nn.BatchNorm2d(128)
437 | self.bn5 = nn.BatchNorm2d(128)
438 | self.bn5_mid = nn.BatchNorm2d(128)
439 |
440 | self.apply(weights_init)
441 | self.cuda()
442 |
443 |
444 | def forward(self, x):
445 | input_scale=self.input_scale
446 | batch_size = x.shape[0]
447 |
448 | if input_scale <= 1:
449 | x = self.fc1(x)
450 | x = x.resize(batch_size, 512, 2, 2)
451 |
452 | # 512 x 2 x 2
453 | if input_scale == 2:
454 | x = x.view(batch_size, self.latent_size, 2, 2)
455 | x = self.conv2_in(x)
456 | if input_scale <= 2:
457 | x = self.conv2(x)
458 | x = nn.LeakyReLU()(x)
459 | x = self.bn2(x)
460 | if self.insertConv:
461 | x = self.conv2_mid(x)
462 | x = nn.LeakyReLU()(x)
463 | x = self.bn2_mid(x)
464 |
465 | # 512 x 4 x 4
466 | if input_scale == 4:
467 | x = x.view(batch_size, self.latent_size, 4, 4)
468 | x = self.conv3_in(x)
469 | if input_scale <= 4:
470 | x = self.conv3(x)
471 | x = nn.LeakyReLU()(x)
472 | x = self.bn3(x)
473 | if self.insertConv:
474 | x = self.conv3_mid(x)
475 | x = nn.LeakyReLU()(x)
476 | x = self.bn3_mid(x)
477 |
478 |
479 | # 256 x 8 x 8
480 | if input_scale == 8:
481 | x = x.view(batch_size, self.latent_size, 8, 8)
482 | x = self.conv4_in(x)
483 | if input_scale <= 8:
484 | x = self.conv4(x)
485 | x = nn.LeakyReLU()(x)
486 | x = self.bn4(x)
487 | if self.insertConv:
488 | x = self.conv4_mid(x)
489 | x = nn.LeakyReLU()(x)
490 | x = self.bn4_mid(x)
491 |
492 |
493 | # 128 x 16 x 16
494 | x = self.conv5(x)
495 | x = nn.LeakyReLU()(x)
496 | x = self.bn5_mid(x)
497 |
498 | # 3 x 32 x 32
499 | #x = nn.Sigmoid()(x)
500 |
501 | x = self.conv6(x)
502 | return x
503 |
504 |
505 |
506 |
507 |
508 |
509 | class MyDecoder_noBN(nn.Module):
510 | def __init__(self, latent_size=512, input_scale=4, insertConv=False):
511 | super(self.__class__, self).__init__()
512 | self.latent_size = latent_size
513 | self.input_scale = input_scale
514 | self.fc1 = nn.Linear(latent_size, 512*2*2, bias=False)
515 | self.insertConv = insertConv
516 |
517 | self.conv2_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=True)
518 | self.conv2 = nn.ConvTranspose2d( 512, 512, 4, stride=2, padding=1, bias=True)
519 |         self.conv2_mid = nn.Conv2d(512, 512, 3, 1, 1, bias=True)  # takes conv2's 512-channel output when insertConv=True
520 |
521 | self.conv3_in = nn.ConvTranspose2d(latent_size, 512, 1, stride=1, padding=0, bias=True)
522 | self.conv3 = nn.ConvTranspose2d( 512, 256, 4, stride=2, padding=1, bias=True)
523 | self.conv3_mid = nn.Conv2d(256, 256, 3, 1, 1, bias=True)
524 |
525 | self.conv4_in = nn.ConvTranspose2d(latent_size, 256, 1, stride=1, padding=0, bias=True)
526 | self.conv4 = nn.ConvTranspose2d( 256, 128, 4, stride=2, padding=1, bias=True)
527 | self.conv4_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=True)
528 |
529 |         self.conv5 = nn.ConvTranspose2d( 128, 128, 4, stride=2, padding=1, bias=True)  # outputs 128 channels so conv6 below receives the width it expects
530 | self.conv5_mid = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
531 |
532 | self.conv6 = nn.ConvTranspose2d( 128, 3, 4, stride=2, padding=1, bias=True)
533 |
534 | #self.bn1 = nn.BatchNorm2d(512)
535 | #self.bn2 = nn.BatchNorm2d(512)
536 | #self.bn2_mid = nn.BatchNorm2d(512)
537 | #self.bn3 = nn.BatchNorm2d(256)
538 | #self.bn3_mid = nn.BatchNorm2d(256)
539 | #self.bn4 = nn.BatchNorm2d(128)
540 | #self.bn4_mid = nn.BatchNorm2d(128)
541 |
542 | self.apply(weights_init)
543 | self.cuda()
544 |
545 | def forward(self, x):
546 | input_scale=self.input_scale
547 | batch_size = x.shape[0]
548 |
549 | if input_scale <= 1:
550 | x = self.fc1(x)
551 | x = x.resize(batch_size, 512, 2, 2)
552 |
553 | # 512 x 2 x 2
554 | if input_scale == 2:
555 | x = x.view(batch_size, self.latent_size, 2, 2)
556 | x = self.conv2_in(x)
557 | if input_scale <= 2:
558 | x = self.conv2(x)
559 | x = nn.LeakyReLU()(x)
560 | #x = self.bn2(x)
561 | if self.insertConv:
562 | x = self.conv2_mid(x)
563 | x = nn.LeakyReLU()(x)
564 | #x = self.bn2_mid(x)
565 |
566 | # 512 x 4 x 4
567 | if input_scale == 4:
568 | x = x.view(batch_size, self.latent_size, 4, 4)
569 | x = self.conv3_in(x)
570 | x = nn.LeakyReLU()(x)
571 | if input_scale <= 4:
572 | x = self.conv3(x)
573 | x = nn.LeakyReLU()(x)
574 | #x = self.bn3(x)
575 | if self.insertConv:
576 | x = self.conv3_mid(x)
577 | x = nn.LeakyReLU()(x)
578 | #x = self.bn3_mid(x)
579 |
580 |
581 | # 256 x 8 x 8
582 | if input_scale == 8:
583 | x = x.view(batch_size, self.latent_size, 8, 8)
584 | x = self.conv4_in(x)
585 | if input_scale <= 8:
586 | x = self.conv4(x)
587 | x = nn.LeakyReLU()(x)
588 | #x = self.bn4(x)
589 | if self.insertConv:
590 | x = self.conv4_mid(x)
591 | x = nn.LeakyReLU()(x)
592 | #x = self.bn4_mid(x)
593 |
594 |
595 | # 128 x 16 x 16
596 | x = self.conv5(x)
597 | x = nn.LeakyReLU()(x)
598 | #x = self.bn5_mid(x)
599 |
600 | # 3 x 32 x 32
601 | #x = nn.Sigmoid()(x)
602 |
603 | x = self.conv6(x)
604 | return x
605 |
606 |
607 |
608 |
609 |
610 |
611 |
612 |
613 | class classifier32(nn.Module):
614 | def __init__(self, latent_size=100, num_classes=2, batch_size=64, return_feat=True):
615 | super(self.__class__, self).__init__()
616 | self.return_feat = return_feat
617 |
618 | self.batch_size = batch_size
619 | self.num_classes = num_classes
620 | self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
621 | self.conv2 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)
622 | self.conv3 = nn.Conv2d(64, 128, 3, 2, 1, bias=False)
623 |
624 | self.conv4 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
625 | self.conv5 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
626 | self.conv6 = nn.Conv2d(128, 128, 3, 2, 1, bias=False)
627 |
628 | self.conv7 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
629 | self.conv8 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
630 | self.conv9 = nn.Conv2d(128, 128, 3, 2, 1, bias=False)
631 |
632 | self.bn1 = nn.BatchNorm2d(64)
633 | self.bn2 = nn.BatchNorm2d(64)
634 | self.bn3 = nn.BatchNorm2d(128)
635 |
636 | self.bn4 = nn.BatchNorm2d(128)
637 | self.bn5 = nn.BatchNorm2d(128)
638 | self.bn6 = nn.BatchNorm2d(128)
639 |
640 | self.bn7 = nn.BatchNorm2d(128)
641 | self.bn8 = nn.BatchNorm2d(128)
642 | self.bn9 = nn.BatchNorm2d(128)
643 |
644 | self.fc1 = nn.Linear(128*4*4, num_classes)
645 | self.dr1 = nn.Dropout2d(0.2)
646 | self.dr2 = nn.Dropout2d(0.2)
647 | self.dr3 = nn.Dropout2d(0.2)
648 |
649 | self.apply(weights_init)
650 | self.cuda()
651 |
652 | def forward(self, x, return_features=False):
653 | batch_size = len(x)
654 |
655 | x = self.dr1(x)
656 | x = self.conv1(x)
657 | x = self.bn1(x)
658 | x = nn.LeakyReLU(0.2)(x)
659 | x = self.conv2(x)
660 | x = self.bn2(x)
661 | x = nn.LeakyReLU(0.2)(x)
662 | x = self.conv3(x)
663 | x = self.bn3(x)
664 | x = nn.LeakyReLU(0.2)(x)
665 |
666 | x = self.dr2(x)
667 | x = self.conv4(x)
668 | x = self.bn4(x)
669 | x = nn.LeakyReLU(0.2)(x)
670 | x = self.conv5(x)
671 | x = self.bn5(x)
672 | x = nn.LeakyReLU(0.2)(x)
673 | x = self.conv6(x)
674 | x = self.bn6(x)
675 | x = nn.LeakyReLU(0.2)(x)
676 |
677 | x = self.dr3(x)
678 | x = self.conv7(x)
679 | x = self.bn7(x)
680 | x = nn.LeakyReLU(0.2)(x)
681 | x = self.conv8(x)
682 | x = self.bn8(x)
683 | x = nn.LeakyReLU(0.2)(x)
684 | x = self.conv9(x)
685 | x = self.bn9(x)
686 | x = nn.LeakyReLU(0.2)(x)
687 |
688 | x = x.view(batch_size, -1)
689 | if self.return_feat:
690 | return x
691 | x = self.fc1(x)
692 | return x
693 |
694 |
695 |
696 | class ResnetEncoder(nn.Module):
697 | """Pytorch module for a resnet encoder
698 | """
699 | def __init__(self, num_layers=18, isPretrained=False, isGrayscale=False, embDimension=128, poolSize=4):
700 | super(ResnetEncoder, self).__init__()
701 | self.path_to_model = '../models'
702 | self.num_ch_enc = np.array([64, 64, 128, 256, 512])
703 | self.isGrayscale = isGrayscale
704 | self.isPretrained = isPretrained
705 | self.embDimension = embDimension
706 | self.poolSize = poolSize
707 | self.featListName = ['layer0', 'layer1', 'layer2', 'layer3', 'layer4']
708 |
709 | resnets = {
710 | 18: models.resnet18,
711 | 34: models.resnet34,
712 | 50: models.resnet50,
713 | 101: models.resnet101,
714 | 152: models.resnet152}
715 |
716 | resnets_pretrained_path = {
717 | 18: 'resnet18-5c106cde.pth',
718 | 34: 'resnet34.pth',
719 | 50: 'resnet50.pth',
720 | 101: 'resnet101.pth',
721 | 152: 'resnet152.pth'}
722 |
723 | if num_layers not in resnets:
724 | raise ValueError("{} is not a valid number of resnet layers".format(
725 | num_layers))
726 |
727 | self.encoder = resnets[num_layers]()
728 |
729 | if self.embDimension>0:
730 | self.encoder.linear = nn.Linear(self.num_ch_enc[-1], self.embDimension)
731 |
732 | if self.isPretrained:
733 | print("using pretrained model")
734 | self.encoder.load_state_dict(
735 | torch.load(os.path.join(self.path_to_model, resnets_pretrained_path[num_layers])))
736 |
737 | #if self.isGrayscale:
738 | # self.encoder.conv1 = nn.Conv2d(
739 | # 1, 64, kernel_size=3, stride=1, padding=1, bias=False)
740 | #else:
741 | # self.encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
742 |
743 | if num_layers > 34:
744 | self.num_ch_enc[1:] *= 4
745 |
746 | def forward(self, input_image):
747 | self.features = []
748 |
749 | x = self.encoder.conv1(input_image)
750 | x = self.encoder.bn1(x)
751 | x = self.encoder.relu(x)
752 | self.features.append(x)
753 |
754 | #x = self.encoder.layer1(self.encoder.maxpool(x)) #
755 | x = self.encoder.layer1(x) # self.encoder.maxpool(x)
756 | self.features.append(x)
757 | #print('layer1: ', x.shape)
758 |
759 | x = self.encoder.layer2(x)
760 | self.features.append(x)
761 | #print('layer2: ', x.shape)
762 |
763 | x = self.encoder.layer3(x)
764 | self.features.append(x)
765 | #print('layer3: ', x.shape)
766 |
767 | x = self.encoder.layer4(x)
768 | self.features.append(x)
769 | #print('layer4: ', x.shape)
770 |
771 | x = F.avg_pool2d(x, self.poolSize)
772 | #print('global pool: ', x.shape)
773 |
774 | x = x.view(x.size(0), -1)
775 | #print('reshape: ', x.shape)
776 |
777 | if self.embDimension>0:
778 | x = self.encoder.linear(x)
779 | #print('final: ', x.shape)
780 | return x
781 |
782 |
783 |
784 | class TinyImageNet_ClsNet(nn.Module):
785 | def __init__(self, nClass=10, layerList=(64, 32)):
786 | super(TinyImageNet_ClsNet, self).__init__()
787 |
788 | self.nClass = nClass
789 | self.layerList = layerList
790 | self.linearLayers = OrderedDict()
791 | self.relu = nn.ReLU()
792 | i=-1
793 | for i in range(len(layerList)-1):
794 | self.linearLayers[i] = nn.Linear(self.layerList[i], self.layerList[i+1])
795 | self.linearLayers[i+1] = nn.Linear(self.layerList[-1], self.nClass)
796 | self.mnist_clsnet = nn.ModuleList(list(self.linearLayers.values()))
797 |
798 | def forward(self, x):
799 | i = -1
800 | for i in range(len(self.layerList)-1):
801 | x = self.linearLayers[i](x)
802 | x = self.relu(x)
803 | x = self.linearLayers[i+1](x)
804 | return x
805 |
806 |
807 |
808 |
809 |
810 | class TinyImageNet_Decoder(nn.Module):
811 | def __init__(self, embDimension=128, layerList=(256, 512, 3*1024*1024), imgSize=[3,32,32],
812 | isReshapeBack=True, reluFirst=False):
813 | super(TinyImageNet_Decoder, self).__init__()
814 |
815 | self.imgSize = imgSize
816 | self.embDimension = embDimension
817 | self.layerList = layerList
818 | self.linearLayers = OrderedDict()
819 | self.relu = nn.ReLU()
820 | self.isReshapeBack = isReshapeBack
821 | self.reluFirst = reluFirst
822 |
823 | self.linearLayers[0] = nn.Linear(self.embDimension, self.layerList[0])
824 | for i in range(1, len(layerList)):
825 | self.linearLayers[i] = nn.Linear(self.layerList[i-1], self.layerList[i])
826 |
827 | self.mnist_decoder = nn.ModuleList(list(self.linearLayers.values()))
828 |
829 | def forward(self, x):
830 | self.featList = []
831 |
832 | if self.reluFirst:
833 | x = self.relu(x)
834 | x = self.linearLayers[0](x)
835 | self.featList.append(x)
836 |
837 | for i in range(1, len(self.layerList)):
838 | x = self.relu(x)
839 | x = self.linearLayers[i](x)
840 | self.featList.append(x)
841 |
842 | if self.isReshapeBack:
843 | x = x.view(x.size(0), self.imgSize[0], self.imgSize[1], self.imgSize[2])
844 |
845 | return x
846 |
847 |
848 |
849 | class CondEncoder(nn.Module):
850 | def __init__(self, num_classes=200, dimension=128, device='cpu'):
851 | super(self.__class__, self).__init__()
852 | self.num_classes = num_classes
853 | self.dimension = dimension
854 |
855 | self.fc1 = nn.Linear(num_classes, num_classes)
856 | self.fc2 = nn.Linear(num_classes, dimension)
857 | self.fc3 = nn.Linear(dimension, dimension)
858 | self.device = device
859 |
860 | def forward(self, input, indicator):
861 | batch_size = len(input)
862 | x = torch.zeros(batch_size, self.num_classes).to(self.device)
863 | x[:, indicator] = 1
864 | x = x.to(self.device)
865 |
866 | x = self.fc1(x)
867 | x = nn.LeakyReLU(0.2)(x)
868 | x = self.fc2(x)
869 | x = nn.LeakyReLU(0.2)(x)
870 | x = self.fc3(x)
871 | return x
--------------------------------------------------------------------------------
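A quick shape-check sketch for the networks in `utils/network_arch_tinyimagenet.py` above, assuming the file is importable as a module. With the default 64x64 patch size, the convolutional stack of `Discriminator80x80InstNorm` reduces each 3-channel crop to a single sigmoid real-vs-fake score, and the DCGAN-style `GAN_Decoder` doubles the spatial size of a 1x1 latent code five times to produce a 32x32 output:

    import torch
    from utils.network_arch_tinyimagenet import Discriminator80x80InstNorm, GAN_Decoder

    D = Discriminator80x80InstNorm(patchSize=[64, 64], frameStackNumber=3)
    x = torch.randn(4, 3, 64, 64)     # a batch of 4 three-channel 64x64 crops
    print(D(x).shape)                 # torch.Size([4, 1, 1, 1]): one score in (0, 1) per crop

    G = GAN_Decoder(nz=64, ngf=64, nc=3)
    z = torch.randn(4, 64, 1, 1)      # latent codes
    print(G(z).shape)                 # torch.Size([4, 3, 32, 32])
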
/demo_OpenSetSegmentation_training.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "OpenGAN: Open-Set Recognition via Open Data Generation\n",
8 | "================\n",
9 | "**Supplemental Material for ICCV2021 Submission**\n",
10 | "\n",
11 | "\n",
12 |     "This notebook demonstrates open-set semantic segmentation, focusing on training for this task."
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {},
18 | "source": [
19 |     "Import packages\n",
20 |     "------------------\n",
21 |     "\n",
22 |     "Some packages are installed automatically with Anaconda. PyTorch should also be installed."
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "metadata": {},
29 | "outputs": [
30 | {
31 | "name": "stdout",
32 | "output_type": "stream",
33 | "text": [
34 | "3.7.4 (default, Aug 13 2019, 20:35:49) \n",
35 | "[GCC 7.3.0]\n",
36 | "1.4.0+cu92\n"
37 | ]
38 | }
39 | ],
40 | "source": [
41 | "from __future__ import print_function, division\n",
42 | "import os, random, time, copy, scipy, pickle, sys, math, json, pickle\n",
43 | "\n",
44 | "import argparse, pprint, shutil, logging, time, timeit\n",
45 | "from pathlib import Path\n",
46 | "\n",
47 | "from skimage import io, transform\n",
48 | "import numpy as np\n",
49 | "import os.path as path\n",
50 | "import scipy.io as sio\n",
51 | "from scipy import misc\n",
52 | "from scipy import ndimage, signal\n",
53 | "import matplotlib.pyplot as plt\n",
54 | "# import PIL.Image\n",
55 | "from PIL import Image\n",
56 | "from io import BytesIO\n",
57 | "from skimage import data, img_as_float\n",
58 | "from skimage.measure import compare_ssim as ssim\n",
59 | "from skimage.measure import compare_psnr as psnr\n",
60 | "\n",
61 | "import torch, torchvision\n",
62 | "from torch.utils.data import Dataset, DataLoader\n",
63 | "import torch.nn as nn\n",
64 | "import torch.optim as optim\n",
65 | "from torch.optim import lr_scheduler \n",
66 | "import torch.nn.functional as F\n",
67 | "from torch.autograd import Variable\n",
68 | "from torchvision import datasets, models, transforms\n",
69 | "import torchvision.utils as vutils\n",
70 | "from collections import namedtuple\n",
71 | "\n",
72 | "from config_HRNet import models\n",
73 | "from config_HRNet import seg_hrnet\n",
74 | "from config_HRNet import config\n",
75 | "from config_HRNet import update_config\n",
76 | "from config_HRNet.modelsummary import *\n",
77 | "from config_HRNet.utils import *\n",
78 | "\n",
79 | "\n",
80 | "from utils.dataset_tinyimagenet import *\n",
81 | "from utils.dataset_cityscapes import *\n",
82 | "from utils.eval_funcs import *\n",
83 | "\n",
84 | "\n",
85 | "import warnings # ignore warnings\n",
86 | "warnings.filterwarnings(\"ignore\")\n",
87 | "print(sys.version)\n",
88 | "print(torch.__version__)\n",
89 | "\n",
90 | "# %load_ext autoreload\n",
91 | "# %autoreload 2"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "metadata": {},
97 | "source": [
98 | "Setup config parameters\n",
99 | " -----------------\n",
100 | " \n",
101 | " There are several things to set up, such as which GPU to use and where to read images and save files. Please read through this cell; by default, the notebook should run without any changes."
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 2,
107 | "metadata": {},
108 | "outputs": [
109 | {
110 | "name": "stdout",
111 | "output_type": "stream",
112 | "text": [
113 | "./exp/demo_step030_OpenGAN_num1000_w0.20\n"
114 | ]
115 | }
116 | ],
117 | "source": [
118 | "# set the random seed\n",
119 | "torch.manual_seed(0)\n",
120 | "\n",
121 | "\n",
122 | "################## set attributes for this project/experiment ##################\n",
123 | "# config result folder\n",
124 | "exp_dir = './exp' # experiment directory, used for reading the init model\n",
125 | "\n",
126 | "num_open_training_images = 1000\n",
127 | "weight_adversarialLoss = 0.2\n",
128 | "project_name = 'demo_step030_OpenGAN_num{}_w{:.2f}'.format(num_open_training_images, weight_adversarialLoss)\n",
129 | "\n",
130 | "\n",
131 | "\n",
132 | "\n",
133 | "device ='cpu'\n",
134 | "if torch.cuda.is_available(): \n",
135 | " device='cuda:3'\n",
136 | " \n",
137 | "\n",
138 | "\n",
139 | "ganBatchSize = 640\n",
140 | "batch_size = 1\n",
141 | "newsize = (-1,-1)\n",
142 | "\n",
143 | "total_epoch_num = 50 # total number of epoch in training\n",
144 | "insertConv = False \n",
145 | "embDimension = 64\n",
146 | "#isPretrained = False\n",
147 | "#encoder_num_layers = 18\n",
148 | "\n",
149 | "\n",
150 | "# Dimensionality of the per-pixel features the GAN operates on (HRNet feat4 has 720 channels)\n",
151 | "nc = 720\n",
152 | "# Size of z latent vector (i.e. size of generator input)\n",
153 | "nz = 64\n",
154 | "# Size of feature maps in generator\n",
155 | "ngf = 64\n",
156 | "# Size of feature maps in discriminator\n",
157 | "ndf = 64\n",
158 | "# Beta1 hyperparam for Adam optimizers\n",
159 | "beta1 = 0.5\n",
160 | "# Number of GPUs available. Use 0 for CPU mode.\n",
161 | "ngpu = 1\n",
162 | "\n",
163 | "\n",
164 | "\n",
165 | "save_dir = os.path.join(exp_dir, project_name)\n",
166 | "if not os.path.exists(exp_dir): os.makedirs(exp_dir)\n",
167 | "\n",
168 | "lr = 0.0001 # base learning rate\n",
169 | "\n",
170 | "num_epochs = total_epoch_num\n",
171 | "torch.cuda.device_count()\n",
172 | "torch.cuda.empty_cache()\n",
173 | "\n",
174 | "save_dir = os.path.join(exp_dir, project_name)\n",
175 | "print(save_dir) \n",
176 | "if not os.path.exists(save_dir): os.makedirs(save_dir)\n",
177 | "\n",
178 | "log_filename = os.path.join(save_dir, 'train.log')"
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "metadata": {},
184 | "source": [
185 | "Define model architecture\n",
186 | "---------\n",
187 | "\n",
188 | "Below we define the open-pixel feature dataset, load the pretrained HRNet segmentation backbone, and define the GAN generator and discriminator that operate on per-pixel features."
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": 3,
194 | "metadata": {},
195 | "outputs": [],
196 | "source": [
197 | "class CityscapesOpenPixelFeat4(Dataset):\n",
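    | "    # Pre-extracted HRNet 'feat4' features for open-set pixels: one pickle file per\n",
    | "    # Cityscapes image (key 'feat4open_percls'), presumably written by an earlier\n",
    | "    # preprocessing step; path_to_data points to the authors' local copy\n",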
198 | " def __init__(self, set_name='train',\n",
199 | " numImgs=500,\n",
200 | " path_to_data='/scratch/dataset/Cityscapes_feat4'): \n",
201 | " \n",
202 | " self.imgList = []\n",
203 | " self.current_set_len = numImgs # 2975\n",
204 | " if set_name=='test': \n",
205 | " set_name = 'val'\n",
206 | " self.current_set_len = 500\n",
207 | " \n",
208 | " self.set_name = set_name\n",
209 | " self.path_to_data = path_to_data\n",
210 | " for i in range(self.current_set_len):\n",
211 | " self.imgList += ['{}_openpixel.pkl'.format(i)] \n",
212 | " \n",
213 | " def __len__(self): \n",
214 | " return self.current_set_len\n",
215 | " \n",
216 | " def __getitem__(self, idx): \n",
217 | " filename = path.join(self.path_to_data, self.set_name, self.imgList[idx])\n",
218 | " with open(filename, \"rb\") as fn:\n",
219 | " openPixFeat = pickle.load(fn)\n",
220 | " openPixFeat = openPixFeat['feat4open_percls']\n",
221 | " openPixFeat = torch.cat(openPixFeat, 0).detach()\n",
222 | " #print(openPixFeat.shape)\n",
223 | " return openPixFeat"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": 4,
229 | "metadata": {},
230 | "outputs": [],
231 | "source": [
232 | "parser = argparse.ArgumentParser(description='Train segmentation network') \n",
233 | "parser.add_argument('--cfg',\n",
234 | " help='experiment configure file name',\n",
235 | " default='./config_HRNet/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml',\n",
236 | " type=str)\n",
237 | "parser.add_argument('opts',\n",
238 | " help=\"Modify config options using the command-line\",\n",
239 | " default=None,\n",
240 | " nargs=argparse.REMAINDER)\n",
241 | "\n",
242 | "\n",
243 | "args = parser.parse_args('--cfg ./config_HRNet/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml'.split())  # parse_args expects a list of tokens\n",
244 | "args.opts = []\n",
245 | "update_config(config, args)"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 5,
251 | "metadata": {},
252 | "outputs": [],
253 | "source": [
254 | "model = eval(config.MODEL.NAME + '.get_seg_model_myModel')(config)\n",
255 | "model_dict = model.state_dict()\n",
256 | "\n",
257 | "\n",
258 | "model_state_file = '../openset/models/hrnet_w48_cityscapes_cls19_1024x2048_ohem_trainset.pth'\n",
259 | "pretrained_dict = torch.load(model_state_file, map_location=lambda storage, loc: storage)\n",
260 | "\n",
261 | "\n",
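    | "# copy the released checkpoint's final classification-layer weights into the renamed\n",
    | "# last_* parameters used by the modified HRNet defined in config_HRNet\n",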
262 | "suppl_dict = {}\n",
263 | "suppl_dict['last_1_conv.weight'] = pretrained_dict['model.last_layer.0.weight'].clone()\n",
264 | "suppl_dict['last_1_conv.bias'] = pretrained_dict['model.last_layer.0.bias'].clone()\n",
265 | "\n",
266 | "suppl_dict['last_2_BN.running_mean'] = pretrained_dict['model.last_layer.1.running_mean'].clone()\n",
267 | "suppl_dict['last_2_BN.running_var'] = pretrained_dict['model.last_layer.1.running_var'].clone()\n",
268 | "# suppl_dict['last_2_BN.num_batches_tracked'] = pretrained_dict['model.last_layer.1.num_batches_tracked']\n",
269 | "suppl_dict['last_2_BN.weight'] = pretrained_dict['model.last_layer.1.weight'].clone()\n",
270 | "suppl_dict['last_2_BN.bias'] = pretrained_dict['model.last_layer.1.bias'].clone()\n",
271 | "\n",
272 | "suppl_dict['last_4_conv.weight'] = pretrained_dict['model.last_layer.3.weight'].clone()\n",
273 | "suppl_dict['last_4_conv.bias'] = pretrained_dict['model.last_layer.3.bias'].clone()\n",
274 | "\n",
275 | "\n",
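    | "# strip the leading 'model.' prefix (6 characters) from the checkpoint keys so they\n",
    | "# match this model's state_dict naming, keeping only keys the model actually has\n",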
276 | "pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()\n",
277 | " if k[6:] in model_dict.keys()}\n",
278 | "\n",
279 | "\n",
280 | "model_dict.update(pretrained_dict)\n",
281 | "model_dict.update(suppl_dict)\n",
282 | "model.load_state_dict(model_dict)\n",
283 | "\n",
284 | "\n",
285 | "model.eval();\n",
286 | "model.to(device);"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": 6,
292 | "metadata": {},
293 | "outputs": [],
294 | "source": [
295 | "def weights_init(m):\n",
296 | " classname = m.__class__.__name__\n",
297 | " if classname.find('Conv') != -1:\n",
298 | " nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
299 | " elif classname.find('BatchNorm') != -1:\n",
300 | " nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
301 | " nn.init.constant_(m.bias.data, 0) \n",
302 | " \n",
303 | "\n",
304 | "class Generator(nn.Module):\n",
305 | " def __init__(self, ngpu=1, nz=100, ngf=64, nc=512):\n",
306 | " super(Generator, self).__init__()\n",
307 | " self.ngpu = ngpu\n",
308 | " self.nz = nz\n",
309 | " self.ngf = ngf\n",
310 | " self.nc = nc\n",
311 | " \n",
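    | "        # all convolutions below use 1x1 kernels with stride 1, so the generator maps\n",
    | "        # an nz-dim noise vector to an nc-dim feature vector (spatial size stays 1x1)\n",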
312 | " self.main = nn.Sequential(\n",
313 | " # input is Z, going into a convolution\n",
314 | " nn.Conv2d( self.nz, self.ngf * 8, 1, 1, 0, bias=True),\n",
315 | " nn.BatchNorm2d(self.ngf * 8),\n",
316 | " nn.ReLU(True),\n",
317 | "            # state size. (self.ngf*8) x 1 x 1\n",
318 | " nn.Conv2d(self.ngf * 8, self.ngf * 4, 1, 1, 0, bias=True),\n",
319 | " nn.BatchNorm2d(self.ngf * 4),\n",
320 | " nn.ReLU(True),\n",
321 | "            # state size. (self.ngf*4) x 1 x 1\n",
322 | " nn.Conv2d( self.ngf * 4, self.ngf * 2, 1, 1, 0, bias=True),\n",
323 | " nn.BatchNorm2d(self.ngf * 2),\n",
324 | " nn.ReLU(True),\n",
325 | "            # state size. (self.ngf*2) x 1 x 1\n",
326 | " nn.Conv2d( self.ngf * 2, self.ngf*4, 1, 1, 0, bias=True),\n",
327 | " nn.BatchNorm2d(self.ngf*4),\n",
328 | " nn.ReLU(True),\n",
329 | "            # state size. (self.ngf*4) x 1 x 1\n",
330 | " nn.Conv2d( self.ngf*4, self.nc, 1, 1, 0, bias=True),\n",
331 | " #nn.Tanh()\n",
332 | "            # state size. (self.nc) x 1 x 1\n",
333 | " )\n",
334 | "\n",
335 | " def forward(self, input):\n",
336 | " return self.main(input)\n",
337 | "\n",
338 | " \n",
339 | "class Discriminator(nn.Module):\n",
340 | " def __init__(self, ngpu=1, nc=512, ndf=64):\n",
341 | " super(Discriminator, self).__init__()\n",
342 | " self.ngpu = ngpu\n",
343 | " self.nc = nc\n",
344 | " self.ndf = ndf\n",
345 | " self.main = nn.Sequential(\n",
346 | " nn.Conv2d(self.nc, self.ndf*8, 1, 1, 0, bias=True),\n",
347 | " nn.LeakyReLU(0.2, inplace=True),\n",
348 | " nn.Conv2d(self.ndf*8, self.ndf*4, 1, 1, 0, bias=True),\n",
349 | " nn.BatchNorm2d(self.ndf*4),\n",
350 | " nn.LeakyReLU(0.2, inplace=True),\n",
351 | " nn.Conv2d(self.ndf*4, self.ndf*2, 1, 1, 0, bias=True),\n",
352 | " nn.BatchNorm2d(self.ndf*2),\n",
353 | " nn.LeakyReLU(0.2, inplace=True),\n",
354 | " nn.Conv2d(self.ndf*2, self.ndf, 1, 1, 0, bias=True),\n",
355 | " nn.BatchNorm2d(self.ndf),\n",
356 | " nn.LeakyReLU(0.2, inplace=True),\n",
357 | " nn.Conv2d(self.ndf, 1, 1, 1, 0, bias=True),\n",
358 | " nn.Sigmoid()\n",
359 | " )\n",
360 | "\n",
361 | " def forward(self, input):\n",
362 | " return self.main(input)"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "execution_count": 7,
368 | "metadata": {},
369 | "outputs": [
370 | {
371 | "name": "stdout",
372 | "output_type": "stream",
373 | "text": [
374 | "cuda:3\n"
375 | ]
376 | }
377 | ],
378 | "source": [
379 | "netG = Generator(ngpu=ngpu, nz=nz, ngf=ngf, nc=nc).to(device)\n",
380 | "netD = Discriminator(ngpu=ngpu, nc=nc, ndf=ndf).to(device)\n",
381 | "\n",
382 | "\n",
383 | "# Handle multi-gpu if desired\n",
384 | "if ('cuda' in device) and (ngpu > 1): \n",
385 | " netD = nn.DataParallel(netD, list(range(ngpu)))\n",
386 | "\n",
387 | "# Apply the weights_init function to randomly initialize all weights\n",
388 | "# to mean=0, stdev=0.02.\n",
389 | "netD.apply(weights_init)\n",
390 | "\n",
391 | "\n",
392 | "if ('cuda' in device) and (ngpu > 1):\n",
393 | " netG = nn.DataParallel(netG, list(range(ngpu)))\n",
394 | "netG.apply(weights_init)\n",
395 | "\n",
396 | "print(device)"
397 | ]
398 | },
399 | {
400 | "cell_type": "code",
401 | "execution_count": 8,
402 | "metadata": {},
403 | "outputs": [
404 | {
405 | "name": "stdout",
406 | "output_type": "stream",
407 | "text": [
408 | "torch.Size([5, 64, 1, 1]) torch.Size([5, 720, 1, 1]) torch.Size([5, 1, 1, 1])\n"
409 | ]
410 | }
411 | ],
412 | "source": [
413 | "noise = torch.randn(batch_size*5, nz, 1, 1, device=device)\n",
414 | "# Generate a batch of fake feature vectors with G\n",
415 | "fake = netG(noise)\n",
416 | "predLabel = netD(fake)\n",
417 | "\n",
418 | "print(noise.shape, fake.shape, predLabel.shape)"
419 | ]
420 | },
421 | {
422 | "cell_type": "markdown",
423 | "metadata": {},
424 | "source": [
425 | "setup dataset\n",
426 | "-----------"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": 9,
432 | "metadata": {},
433 | "outputs": [
434 | {
435 | "name": "stdout",
436 | "output_type": "stream",
437 | "text": [
438 | "2975 500\n"
439 | ]
440 | }
441 | ],
442 | "source": [
443 | "# torchvision.transforms.Normalize(mean, std, inplace=False)\n",
444 | "imgTransformList = transforms.Compose([\n",
445 | " transforms.ToTensor(),\n",
446 | " transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n",
447 | "])\n",
448 | "\n",
449 | "targetTransformList = transforms.Compose([\n",
450 | " transforms.ToTensor(), \n",
451 | "])\n",
452 | "\n",
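    | "# Cityscapes here is the wrapper imported from utils.dataset_cityscapes (note the\n",
    | "# newsize argument); it returns full-resolution 1024x2048 image/label pairs\n",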
453 | "cls_datasets = {set_name: Cityscapes(root='/scratch/dataset/Cityscapes',\n",
454 | " newsize=newsize,\n",
455 | " split=set_name,\n",
456 | " mode='fine',\n",
457 | " target_type='semantic',\n",
458 | " transform=imgTransformList,\n",
459 | " target_transform=targetTransformList,\n",
460 | " transforms=None)\n",
461 | " for set_name in ['train', 'val']} # 'train', \n",
462 | "\n",
463 | "dataloaders = {set_name: DataLoader(cls_datasets[set_name],\n",
464 | " batch_size=batch_size,\n",
465 | " shuffle=set_name=='train', \n",
466 | " num_workers=4) # num_work can be set to batch_size\n",
467 | " for set_name in ['train', 'val']} # 'train',\n",
468 | "\n",
469 | "\n",
470 | "print(len(cls_datasets['train']), len(cls_datasets['val']))\n",
471 | "classDictionary = cls_datasets['val'].classes"
472 | ]
473 | },
474 | {
475 | "cell_type": "code",
476 | "execution_count": 10,
477 | "metadata": {},
478 | "outputs": [
479 | {
480 | "name": "stdout",
481 | "output_type": "stream",
482 | "text": [
483 | "0 unlabeled\n",
484 | "1 ego vehicle\n",
485 | "2 rectification border\n",
486 | "3 out of roi\n",
487 | "4 static\n",
488 | "5 dynamic\n",
489 | "6 ground\n",
490 | "9 parking\n",
491 | "10 rail track\n",
492 | "14 guard rail\n",
493 | "15 bridge\n",
494 | "16 tunnel\n",
495 | "18 polegroup\n",
496 | "29 caravan\n",
497 | "30 trailer\n",
498 | "34 license plate\n",
499 | "total# 16\n"
500 | ]
501 | }
502 | ],
503 | "source": [
504 | "id2trainID = {}\n",
505 | "id2color = {}\n",
506 | "trainID2color = {}\n",
507 | "id2name = {}\n",
508 | "opensetIDlist = []\n",
509 | "for i in range(len(classDictionary)):\n",
510 | " id2trainID[i] = classDictionary[i][2]\n",
511 | " id2color[i] = classDictionary[i][-1]\n",
512 | " trainID2color[classDictionary[i][2]] = classDictionary[i][-1]\n",
513 | " id2name[i] = classDictionary[i][0]\n",
514 | " if classDictionary[i][-2]:\n",
515 | " opensetIDlist += [i]\n",
516 | "\n",
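    | "# classes flagged as ignore-in-eval (field [-2]) are collected as open-set classes;\n",
    | "# id2trainID_np below maps raw label ids to trainIds via a numpy lookup table\n",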
517 | "id2trainID_list = []\n",
518 | "for i in range(len(id2trainID)):\n",
519 | " id2trainID_list.append(id2trainID[i])\n",
520 | "id2trainID_np = np.asarray(id2trainID_list) \n",
521 | " \n",
522 | "for elm in opensetIDlist:\n",
523 | " print(elm, id2name[elm])\n",
524 | "print('total# {}'.format(len(opensetIDlist)))"
525 | ]
526 | },
527 | {
528 | "cell_type": "code",
529 | "execution_count": 11,
530 | "metadata": {},
531 | "outputs": [],
532 | "source": [
533 | "data_sampler = iter(dataloaders['train'])\n",
534 | "data = next(data_sampler)\n",
535 | "imageList, labelList = data[0], data[1]\n",
536 | "\n",
537 | "imageList = imageList.to(device)\n",
538 | "labelList = labelList.to(device)"
539 | ]
540 | },
541 | {
542 | "cell_type": "code",
543 | "execution_count": 12,
544 | "metadata": {},
545 | "outputs": [
546 | {
547 | "data": {
548 | "text/plain": [
549 | "(torch.Size([1, 3, 1024, 2048]), torch.Size([1, 1024, 2048]))"
550 | ]
551 | },
552 | "execution_count": 12,
553 | "metadata": {},
554 | "output_type": "execute_result"
555 | }
556 | ],
557 | "source": [
558 | "imageList.shape, labelList.shape"
559 | ]
560 | },
561 | {
562 | "cell_type": "markdown",
563 | "metadata": {},
564 | "source": [
565 | "setup training\n",
566 | "-----------"
567 | ]
568 | },
569 | {
570 | "cell_type": "code",
571 | "execution_count": 13,
572 | "metadata": {},
573 | "outputs": [],
574 | "source": [
575 | "# Initialize BCELoss function\n",
576 | "criterion = nn.BCELoss()\n",
577 | "\n",
578 | "# Create batch of latent vectors that we will use to visualize\n",
579 | "# the progression of the generator\n",
580 | "fixed_noise = torch.randn(64, nz, 1, 1, device=device)\n",
581 | "\n",
582 | "# Establish open and close labels\n",
583 | "close_label = 1\n",
584 | "open_label = 0\n",
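    | "# the discriminator is trained to predict 1 for closed-set features and 0 for\n",
    | "# real open-set features and generated (fake) features\n",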
585 | "\n",
586 | "# Establish convention for real and fake labels during training\n",
587 | "real_label = 1\n",
588 | "fake_label = 0\n",
589 | "\n",
590 | "\n",
591 | "# Setup Adam optimizers for both G and D\n",
592 | "optimizerD = optim.Adam(netD.parameters(), lr=lr/1.5, betas=(beta1, 0.999))\n",
593 | "optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))"
594 | ]
595 | },
596 | {
597 | "cell_type": "markdown",
598 | "metadata": {},
599 | "source": [
600 | "testing a single image\n",
601 | "-----------"
602 | ]
603 | },
604 | {
605 | "cell_type": "code",
606 | "execution_count": 14,
607 | "metadata": {},
608 | "outputs": [],
609 | "source": [
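    | "# downsample the labels 4x (nearest) and map raw ids to trainIds so they align\n",
    | "# spatially with the HRNet feat4 tensor (1/4 of the input resolution)\n",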
610 | "labelList = labelList.unsqueeze(1)\n",
611 | "labelList = F.interpolate(labelList, scale_factor=0.25, mode='nearest')\n",
612 | "labelList = labelList.squeeze()\n",
613 | "H, W = labelList.squeeze().shape\n",
614 | "trainlabelList = id2trainID_np[labelList.cpu().numpy().reshape(-1,).astype(np.int32)]\n",
615 | "trainlabelList = trainlabelList.reshape((1,H,W))\n",
616 | "trainlabelList = torch.from_numpy(trainlabelList)\n",
617 | "\n",
618 | "\n",
619 | "\n",
620 | "upsampleFunc = nn.UpsamplingBilinear2d(scale_factor=4)\n",
621 | "with torch.no_grad():\n",
622 | " imageList = imageList.to(device)\n",
623 | " logitsTensor = model(imageList).detach().cpu()\n",
624 | " #logitsTensor = upsampleFunc(logitsTensor)\n",
625 | " softmaxTensor = F.softmax(logitsTensor, dim=1)\n",
626 | " \n",
627 | " feat1Tensor = model.feat1.detach()\n",
628 | " feat2Tensor = model.feat2.detach()\n",
629 | " feat3Tensor = model.feat3.detach()\n",
630 | " feat4Tensor = model.feat4.detach()\n",
631 | " feat5Tensor = model.feat5.detach()\n",
632 | " \n",
633 | " torch.cuda.empty_cache()"
634 | ]
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": 15,
639 | "metadata": {
640 | "scrolled": false
641 | },
642 | "outputs": [
643 | {
644 | "data": {
645 | "text/plain": [
646 | "(torch.Size([1, 720, 256, 512]), torch.Size([1, 256, 512]), 131072)"
647 | ]
648 | },
649 | "execution_count": 15,
650 | "metadata": {},
651 | "output_type": "execute_result"
652 | }
653 | ],
654 | "source": [
655 | "feat4Tensor.shape, trainlabelList.shape, trainlabelList.shape[1]*trainlabelList.shape[2]"
656 | ]
657 | },
658 | {
659 | "cell_type": "code",
660 | "execution_count": 16,
661 | "metadata": {},
662 | "outputs": [],
663 | "source": [
664 | "validList = trainlabelList.reshape(-1,1)\n",
665 | "validList = ((validList>=0) & (validList<=18)).nonzero()\n",
666 | "validList = validList[:,0]\n",
667 | "validList = validList[torch.randperm(validList.size()[0])]\n",
668 | "validList = validList[:ganBatchSize]"
669 | ]
670 | },
671 | {
672 | "cell_type": "code",
673 | "execution_count": 17,
674 | "metadata": {},
675 | "outputs": [],
676 | "source": [
677 | "label = torch.full((ganBatchSize,), close_label, device=device)"
678 | ]
679 | },
680 | {
681 | "cell_type": "code",
682 | "execution_count": 18,
683 | "metadata": {},
684 | "outputs": [],
685 | "source": [
686 | "real_cpu = feat4Tensor.squeeze()\n",
687 | "real_cpu = real_cpu.reshape(real_cpu.shape[0], -1).permute(1,0)\n",
688 | "real_cpu = real_cpu[validList,:].unsqueeze(-1).unsqueeze(-1).to(device)\n",
689 | "\n",
690 | "output = netD(real_cpu).view(-1)\n",
691 | "# Calculate loss on all-real batch\n",
692 | "errD_real = criterion(output, label)"
693 | ]
694 | },
695 | {
696 | "cell_type": "code",
697 | "execution_count": 19,
698 | "metadata": {},
699 | "outputs": [],
700 | "source": [
701 | "noise = torch.randn(ganBatchSize, nz, 1, 1, device=device)\n",
702 | "# Generate a batch of fake feature vectors with G\n",
703 | "fake = netG(noise)\n",
704 | "label.fill_(fake_label)\n",
705 | "# Classify all fake batch with D\n",
706 | "output = netD(fake.detach()).view(-1)\n",
707 | "# Calculate D's loss on the all-fake batch\n",
708 | "errD_fake = criterion(output, label)"
709 | ]
710 | },
711 | {
712 | "cell_type": "code",
713 | "execution_count": 20,
714 | "metadata": {
715 | "scrolled": true
716 | },
717 | "outputs": [
718 | {
719 | "data": {
720 | "text/plain": [
721 | "(torch.Size([640, 64, 1, 1]), torch.Size([640]), torch.Size([640, 720, 1, 1]))"
722 | ]
723 | },
724 | "execution_count": 20,
725 | "metadata": {},
726 | "output_type": "execute_result"
727 | }
728 | ],
729 | "source": [
730 | "noise.shape, label.shape, fake.shape"
731 | ]
732 | },
733 | {
734 | "cell_type": "markdown",
735 | "metadata": {},
736 | "source": [
737 | "training GAN\n",
738 | "-----------"
739 | ]
740 | },
741 | {
742 | "cell_type": "code",
743 | "execution_count": 21,
744 | "metadata": {},
745 | "outputs": [
746 | {
747 | "name": "stdout",
748 | "output_type": "stream",
749 | "text": [
750 | "torch.Size([640, 720, 1, 1])\n"
751 | ]
752 | }
753 | ],
754 | "source": [
755 | "openPix_datasets = CityscapesOpenPixelFeat4(set_name='train', numImgs=num_open_training_images)\n",
756 | "openPix_dataloader = DataLoader(openPix_datasets, batch_size=1, shuffle=True, num_workers=4) \n",
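    | "# each dataloader item holds all open-pixel features of one image; below we draw\n",
    | "# a random subset of ganBatchSize feature vectors from it\n",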
757 | "\n",
758 | "openPix_sampler = iter(openPix_dataloader)\n",
759 | "\n",
760 | "openPixFeat = next(openPix_sampler)\n",
761 | "openPixFeat = openPixFeat.squeeze(0)\n",
762 | "\n",
763 | "openPixIdxList = torch.randperm(openPixFeat.size()[0])\n",
764 | "openPixIdxList = openPixIdxList[:ganBatchSize]\n",
765 | "openPixFeat = openPixFeat[openPixIdxList].to(device)\n",
766 | "\n",
767 | "print(openPixFeat.shape)"
768 | ]
769 | },
770 | {
771 | "cell_type": "code",
772 | "execution_count": null,
773 | "metadata": {
774 | "scrolled": true
775 | },
776 | "outputs": [
777 | {
778 | "name": "stdout",
779 | "output_type": "stream",
780 | "text": [
781 | "Starting Training Loop...\n",
782 | "[0/50][0/2975]\t\tlossG: 0.6536, lossD: 0.5840\n",
783 | "[0/50][100/2975]\t\tlossG: 0.6569, lossD: 0.4096\n",
784 | "[0/50][200/2975]\t\tlossG: 0.6236, lossD: 0.3967\n",
785 | "[0/50][300/2975]\t\tlossG: 0.5869, lossD: 0.2977\n",
786 | "[0/50][400/2975]\t\tlossG: 0.5548, lossD: 0.2332\n",
787 | "[0/50][500/2975]\t\tlossG: 0.5504, lossD: 0.3024\n"
788 | ]
789 | }
790 | ],
791 | "source": [
792 | "# Training Loop\n",
793 | "\n",
794 | "# Lists to keep track of progress\n",
795 | "lossList = []\n",
796 | "G_losses = []\n",
797 | "D_losses = []\n",
798 | "\n",
799 | "fake_BatchSize = int(ganBatchSize/2)\n",
800 | "open_BatchSize = ganBatchSize\n",
801 | "\n",
802 | "\n",
803 | "\n",
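    | "# per-sample weights for the discriminator loss: closed-set and real open-set\n",
    | "# features keep weight 1, while the fake_BatchSize generated features appended at\n",
    | "# the end of each batch are down-weighted by weight_adversarialLoss\n",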
804 | "tmp_weights = torch.full((ganBatchSize+open_BatchSize+fake_BatchSize,), 1, device=device)\n",
805 | "tmp_weights[-fake_BatchSize:] *= weight_adversarialLoss\n",
806 | "criterionD = nn.BCELoss(weight=tmp_weights)\n",
807 | "\n",
808 | "\n",
809 | "\n",
810 | "print(\"Starting Training Loop...\")\n",
811 | "# For each epoch\n",
812 | "openPixImgCount = 0\n",
813 | "openPix_sampler = iter(openPix_dataloader)\n",
814 | "for epoch in range(num_epochs):\n",
815 | " # For each batch in the dataloader\n",
816 | " for i, sample in enumerate(dataloaders['train'], 0):\n",
817 | " imageList, labelList = sample\n",
818 | " imageList = imageList.to(device)\n",
819 | " labelList = labelList.to(device)\n",
820 | "\n",
821 | " labelList = labelList.unsqueeze(1)\n",
822 | " labelList = F.interpolate(labelList, scale_factor=0.25, mode='nearest')\n",
823 | " labelList = labelList.squeeze()\n",
824 | " H, W = labelList.squeeze().shape\n",
825 | " trainlabelList = id2trainID_np[labelList.cpu().numpy().reshape(-1,).astype(np.int32)]\n",
826 | " trainlabelList = trainlabelList.reshape((1,H,W))\n",
827 | " trainlabelList = torch.from_numpy(trainlabelList)\n",
828 | " \n",
829 | " \n",
830 | " #upsampleFunc = nn.UpsamplingBilinear2d(scale_factor=4)\n",
831 | " with torch.no_grad():\n",
832 | " imageList = imageList.to(device)\n",
833 | " logitsTensor = model(imageList).detach().cpu()\n",
834 | " featTensor = model.feat4.detach()\n",
835 | " \n",
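    | "        # pick ganBatchSize random pixel locations whose trainIds fall in 0..18,\n",
    | "        # i.e. closed-set pixels that serve as real examples for the discriminator\n",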
836 | " validList = trainlabelList.reshape(-1,1)\n",
837 | " validList = ((validList>=0) & (validList<=18)).nonzero()\n",
838 | " validList = validList[:,0]\n",
839 | " tmp = torch.randperm(validList.size()[0]) \n",
840 | " validList = validList[tmp[:ganBatchSize]]\n",
841 | " \n",
842 | "\n",
843 | " \n",
844 | " label_closeset = torch.full((ganBatchSize,), close_label, device=device)\n",
845 | " feat_closeset = featTensor.squeeze()\n",
846 | " feat_closeset = feat_closeset.reshape(feat_closeset.shape[0], -1).permute(1,0)\n",
847 | " feat_closeset = feat_closeset[validList,:].unsqueeze(-1).unsqueeze(-1) \n",
848 | " label_open = torch.full((open_BatchSize,), open_label, device=device)\n",
849 | " \n",
850 | " openPixImgCount += 1\n",
851 | " feat_openset = next(openPix_sampler)\n",
852 | " feat_openset = feat_openset.squeeze(0)\n",
853 | " openPixIdxList = torch.randperm(feat_openset.size()[0])\n",
854 | " openPixIdxList = openPixIdxList[:open_BatchSize]\n",
855 | " feat_openset = feat_openset[openPixIdxList].to(device)\n",
856 | "\n",
857 | " if openPixImgCount==num_open_training_images:\n",
858 | " openPixImgCount = 0\n",
859 | " openPix_sampler = iter(openPix_dataloader)\n",
860 | " \n",
861 | " \n",
862 | " \n",
863 | "        # generate fake features\n",
864 | " noise = torch.randn(fake_BatchSize, nz, 1, 1, device=device)\n",
865 | "        # Generate a batch of fake feature vectors with G\n",
866 | " label_fake = torch.full((fake_BatchSize,), fake_label, device=device)\n",
867 | " feat_fakeset = netG(noise) \n",
868 | " \n",
869 | " ############################\n",
870 | " # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))\n",
871 | " ###########################\n",
872 | " # using close&open&fake data to update D\n",
873 | " netD.zero_grad()\n",
874 | " X = torch.cat((feat_closeset, feat_openset.to(device), feat_fakeset.detach()),0)\n",
875 | " label_total = torch.cat((label_closeset, label_open, label_fake),0)\n",
876 | " \n",
877 | " output = netD(X).view(-1)\n",
878 | " lossD = criterionD(output, label_total)\n",
879 | " lossD.backward()\n",
880 | " optimizerD.step()\n",
881 | " errD = lossD.mean().item() \n",
882 | " \n",
883 | " \n",
884 | " ############################\n",
885 | " # (2) Update G network: maximize log(D(G(z)))\n",
886 | " ###########################\n",
887 | " netG.zero_grad()\n",
888 | " label_fakeclose = torch.full((fake_BatchSize,), close_label, device=device) \n",
889 | " # Since we just updated D, perform another forward pass of all-fake batch through D\n",
890 | " output = netD(feat_fakeset).view(-1)\n",
891 | " # Calculate G's loss based on this output\n",
892 | " lossG = criterion(output, label_fakeclose)\n",
893 | " # Calculate gradients for G\n",
894 | " lossG.backward()\n",
895 | " errG = lossG.mean().item()\n",
896 | " # Update G\n",
897 | " optimizerG.step()\n",
898 | " \n",
899 | " \n",
900 | " # Save Losses for plotting later\n",
901 | " G_losses.append(errG)\n",
902 | " D_losses.append(errD)\n",
903 | " \n",
904 | " \n",
905 | " # Output training stats\n",
906 | " if i % 100 == 0:\n",
907 | " print('[%d/%d][%d/%d]\\t\\tlossG: %.4f, lossD: %.4f'\n",
908 | " % (epoch, num_epochs, i, len(dataloaders['train']), \n",
909 | " errG, errD))\n",
910 | " \n",
911 | " \n",
912 | " cur_model_wts = copy.deepcopy(netD.state_dict())\n",
913 | " path_to_save_paramOnly = os.path.join(save_dir, 'epoch-{}.classifier'.format(epoch+1))\n",
914 | " torch.save(cur_model_wts, path_to_save_paramOnly)\n",
915 | " cur_model_wts = copy.deepcopy(netG.state_dict())\n",
916 | " path_to_save_paramOnly = os.path.join(save_dir, 'epoch-{}.GNet'.format(epoch+1))\n",
917 | " torch.save(cur_model_wts, path_to_save_paramOnly)"
918 | ]
919 | },
920 | {
921 | "cell_type": "markdown",
922 | "metadata": {},
923 | "source": [
924 | "validating results\n",
925 | "-----------"
926 | ]
927 | },
928 | {
929 | "cell_type": "code",
930 | "execution_count": null,
931 | "metadata": {},
932 | "outputs": [],
933 | "source": [
934 | "plt.figure(figsize=(10,5))\n",
935 | "plt.title(\"binary cross-entropy loss in training\")\n",
936 | "# only the aggregate discriminator and generator losses are tracked in the\n",
937 | "# training loop above (D_losses and G_losses)\n",
938 | "plt.plot(D_losses, label=\"D_losses\")\n",
939 | "plt.plot(G_losses, label=\"G_losses\")\n",
940 | "plt.xlabel(\"iterations\")\n",
941 | "plt.ylabel(\"Loss\")\n",
942 | "plt.legend()\n",
943 | "# plt.savefig('learningCurves_{}.png'.format(modelFlag), bbox_inches='tight',transparent=True)\n",
944 | "# plt.show()"
945 | ]
946 | },
947 | {
948 | "cell_type": "code",
949 | "execution_count": null,
950 | "metadata": {},
951 | "outputs": [],
952 | "source": []
953 | }
954 | ],
955 | "metadata": {
956 | "kernelspec": {
957 | "display_name": "Python 3",
958 | "language": "python",
959 | "name": "python3"
960 | },
961 | "language_info": {
962 | "codemirror_mode": {
963 | "name": "ipython",
964 | "version": 3
965 | },
966 | "file_extension": ".py",
967 | "mimetype": "text/x-python",
968 | "name": "python",
969 | "nbconvert_exporter": "python",
970 | "pygments_lexer": "ipython3",
971 | "version": "3.7.4"
972 | }
973 | },
974 | "nbformat": 4,
975 | "nbformat_minor": 2
976 | }
977 |
--------------------------------------------------------------------------------