├── README.md
├── col1.png
├── col2.png
├── col3.png
├── data
│   └── pts_in_hull.npy
├── data_process.py
├── deep_color.py
├── download.sh
├── global_hint.py
├── sampling.py
├── unet.py
└── util.py
/README.md:
--------------------------------------------------------------------------------
# "Real-Time User-Guided Colorization with Learned Deep Priors" implemented in PyTorch

This is a PyTorch implementation of ["Real-Time User-Guided Image Colorization with Learned Deep Priors"](https://arxiv.org/abs/1705.02999) by Zhang et al.

## Getting Started

### Prerequisites

torch==0.2.0.post4, torchvision==0.1.9

The code assumes a GPU is available; CPU mode is not recommended with this repository.

### Installing and running the tests

Make sure you have CIFAR-10 or CelebA downloaded under ./data; the provided "download.sh" script can fetch the datasets for you. The expected layout is:

```
./data/CelebA
./data/Cifar10
./data/pts_in_hull.npy
```
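
For example, fetching CelebA with the provided script might look like this (assuming bash, wget, and unzip are available):

```
bash download.sh CelebA
```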

First, clone this repository:

```
git clone https://github.com/sjooyoo/real-time-user-guided-colorization_pytorch.git
```

Then run training:

```
python deep_color.py
```
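
Other training options are exposed through parse_args in deep_color.py; for example, to train the local-hints variant on CelebA with a smaller batch (a sketch, not an exhaustive list of flags):

```
python deep_color.py --data celeba --batch_size 32 --islocal
```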

To sample results, first run deep_color.py, which automatically saves models under a `models` folder created in your root directory.
Pretrained models are not included in this repository. `--model unet100.pkl` below refers to a model saved after 100 epochs; point the command at whichever model you want to sample from.

```
python sampling.py --model unet100.pkl
```
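
To condition sampling on a global histogram reference (see process_global_sampling in data_process.py), something like the following should work, assuming a reference image has been placed under ./data/sample/hist:

```
python sampling.py --model unet100.pkl --global_hist --hist_ref_idx 1
```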

### Results

Input black and white image

Predicted colorization output

Ground truth image

### Note

This is not a complete implementation: the global hints network is implemented but has yet to be incorporated into the main network.

### Further work

* incorporate the global hints network into the main network

## Acknowledgments

Original paper: ["Real-Time User-Guided Image Colorization with Learned Deep Priors"](https://arxiv.org/abs/1705.02999)
--------------------------------------------------------------------------------
/col1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inkImage/real-time-user-guided-colorization_pytorch/d1fd1e0a0e31dc296989bc3bbbac4d6279e43156/col1.png
--------------------------------------------------------------------------------
/col2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inkImage/real-time-user-guided-colorization_pytorch/d1fd1e0a0e31dc296989bc3bbbac4d6279e43156/col2.png
--------------------------------------------------------------------------------
/col3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inkImage/real-time-user-guided-colorization_pytorch/d1fd1e0a0e31dc296989bc3bbbac4d6279e43156/col3.png
--------------------------------------------------------------------------------
/data/pts_in_hull.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inkImage/real-time-user-guided-colorization_pytorch/d1fd1e0a0e31dc296989bc3bbbac4d6279e43156/data/pts_in_hull.npy
--------------------------------------------------------------------------------
/data_process.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.datasets as dsets
4 | import torchvision.transforms as transforms
5 | from skimage.color import rgb2lab
6 |
7 | from global_hint import *
8 |
9 |
10 | def Color_Dataloader(dataset, batch_size):
11 | if dataset == 'cifar':
12 | transform = transforms.Compose([
13 | transforms.ToTensor()
14 | ])
15 | train_dataset = dsets.CIFAR10(root='./data/',
16 | train=True,
17 | transform=transform,
18 | download=True)
19 |
20 | val_dataset = dsets.CIFAR10(root='./data/',
21 | train=False,
22 | transform=transform)
23 | # Data Loader-> it will hand in dataset by size batch
24 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
25 | batch_size=batch_size,
26 | shuffle=True)
27 |
28 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
29 | batch_size=batch_size,
30 | shuffle=False)
31 | imsize = 32
32 |
33 | elif dataset == 'imagenet':
34 |
35 | traindir = './data/tiny-imagenet-200/train/'
36 | valdir = './data/tiny-imagenet-200/val/'
37 | transform = transforms.Compose([
38 | transforms.ToTensor()
39 | ])
40 |
41 | train_dataset = dsets.ImageFolder(traindir, transform)
42 | val_dataset = dsets.ImageFolder(valdir, transform)
43 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
44 | batch_size=batch_size,
45 | shuffle=True,
46 | num_workers=2)
47 |
48 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
49 | batch_size=batch_size,
50 | shuffle=True,
51 | num_workers=2)
52 | imsize = 64
53 |
54 |
55 | elif dataset == 'celeba':
56 |
57 | traindir = './data/CelebA/trainimages/images'
58 | valdir= './data/CelebA/valimages'
59 | transform = transforms.Compose([
60 | transforms.ToTensor()
61 | ])
62 |
63 | train_dataset = dsets.ImageFolder(traindir, transform=transform)
64 | val_dataset = dsets.ImageFolder(valdir, transform=transform)
65 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
66 | batch_size=batch_size,
67 | shuffle=True,
68 | num_workers=2)
69 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
70 | batch_size=batch_size,
71 | shuffle=True,
72 | num_workers=2)
73 | imsize = 128
74 |
75 | elif dataset == 'mscoco':
76 |
77 | traindir = './data/mscoco/trainimages_resized'
78 | valdir = './data/mscoco/valimages_resized'
79 | # Load mscoco data
80 | transform = transforms.Compose([
81 | transforms.ToTensor()
82 | ])
83 |
84 | train_dataset = dsets.ImageFolder(traindir, transform=transform)
85 | val_dataset = dsets.ImageFolder(valdir, transform=transform)
86 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
87 | batch_size=batch_size,
88 | shuffle=True,
89 | num_workers=2)
90 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
91 | batch_size=batch_size,
92 | shuffle=True,
93 | num_workers=2)
94 | imsize = 32
95 |
96 | return train_dataset, train_loader, val_loader, imsize
97 |
98 |
99 | def process_data(image_data, batch_size, imsize, islocal):
100 | input = torch.zeros(batch_size, 1, imsize, imsize)
101 | labels = torch.zeros(batch_size, 2, imsize, imsize)
102 | images_np = image_data.numpy().transpose((0, 2, 3, 1))
103 |
104 | if islocal == False:
105 | ab_for_global = torch.zeros(batch_size, 2, imsize, imsize)
106 |
107 | for k in range(batch_size):
108 | img_lab = rgb2lab(images_np[k])
109 |
110 | img_l = img_lab[:, :, 0] / 100
111 | input[k] = torch.from_numpy(np.expand_dims(img_l, 0))
112 |
113 | img_ab_scale = (img_lab[:, :, 1:3] + 100) / 200
114 | labels[k] = torch.from_numpy(img_ab_scale.transpose((2, 0, 1)))
115 |
116 | img_ab_unscale = img_lab[:, :, 1:3]
117 | ab_for_global[k] = torch.from_numpy(img_ab_unscale.transpose((2, 0, 1)))
118 |
119 | if islocal == True:
120 | for k in range(batch_size):
121 | img_lab = rgb2lab(images_np[k])
122 |
123 | img_l = img_lab[:, :, 0] / 100
124 | input[k] = torch.from_numpy(np.expand_dims(img_l, 0))
125 |
126 | img_ab_scale = (img_lab[:, :, 1:3] + 100) / 200
127 | labels[k] = torch.from_numpy(img_ab_scale.transpose((2, 0, 1)))
128 |
129 |         ab_for_global = 0 # placeholder; global hints are not used by the local network
130 |
131 | return input, labels, ab_for_global
132 |
133 |
134 | def process_global(images, input_ab, batch_size, imsize, hist_mean, hist_std):
135 | glob_quant = Global_Quant(batch_size, imsize)
136 | X_hist = glob_quant.global_histogram(input_ab) # batch x 313 x imsize x imsize
137 | X_sat = glob_quant.global_saturation(images).unsqueeze(1) # batch x 1
138 | B_hist, B_sat = glob_quant.global_masks(batch_size) # if masks are 0, put uniform random(0~1) value in it
139 |
140 | for l in range(batch_size):
141 | if B_sat[l].numpy() == 0:
142 | X_sat[l] = torch.normal(torch.FloatTensor([hist_mean]), std=torch.FloatTensor([hist_std]))
143 | if B_hist[l].numpy() == 0:
144 | tmp = torch.rand(313)
145 | X_hist[l] = torch.div(tmp, torch.sum(tmp))
146 | global_input = torch.cat([X_hist, B_hist, X_sat, B_sat], 1).unsqueeze(2).unsqueeze(2)
147 |     # batch x (q+3) = batch x 316 x 1 x 1
148 |
149 | return global_input
150 |
151 | def process_local(input_ab, batch_size, imsize):
152 | num_points = torch.zeros(batch_size).geometric_(0.125).long() # number of points to give as hints
153 | block_size = torch.zeros(batch_size, 1).uniform_(-0.5, 2.49).round().clamp(0, 2).long() # size of blocks to average
154 | local_ab = torch.zeros(batch_size, 2, imsize, imsize) # output local hint (ab channel)
155 | local_mask = torch.zeros(batch_size, 1, imsize, imsize).long() # output local hint (mask)
156 |
157 | for i in range(batch_size): # for all batches and
158 | for j in range(num_points[i]):
159 | gaussian_points = torch.zeros(2).normal_(mean=imsize/2, std=imsize/4).round().clamp(0, imsize-1).long()
160 | local_ab[i], local_mask[i] = \
161 | local_get_average_value(local_ab[i], input_ab[i], local_mask[i], gaussian_points, block_size[i], imsize)
162 |
163 | return local_ab, local_mask.float()
164 |
165 | # get average value in local_ab for random sized box at certain points.
166 | def local_get_average_value(local_ab, input_ab, local_mask, loc, p, imsize): # p in 0~2 -> box width 2p+1
167 |
168 | low_v = loc[0]-p[0] #lower bound 0
169 | if low_v<0:
170 | low_v=0
171 | high_v = loc[0]+p[0]+1 #higher bound imsize-1
172 | if high_v>=imsize:
173 | high_v=imsize
174 | low_h = loc[1]-p[0] #lower bound 0
175 | if low_h<0:
176 | low_h=0
177 | high_h = loc[1]+p[0]+1 #higher bound imsize-1
178 | if high_h>=imsize:
179 | high_h=imsize
180 |
181 |
182 | local_mask[:, low_v:high_v, low_h:high_h] = 1
183 | local_ab = torch.mul(local_mask.repeat(2, 1, 1).float(), input_ab)
184 |     local_mean_a = torch.sum(local_ab[0,:,:]) / len(torch.nonzero(local_ab[0,:,:])) # assumes the box holds nonzero ab values
185 | local_mean_b = torch.sum(local_ab[1,:,:]) / len(torch.nonzero(local_ab[1,:,:]))
186 | local_a = local_mask.float() * local_mean_a # 1 x 32 x 32
187 | local_b = local_mask.float() * local_mean_b
188 | local_ab = torch.cat([local_a, local_b], dim=0)
189 | return local_ab, local_mask
190 |
191 |
192 |
193 | def process_global_sampling(batch_size, imsize, hist_mean, hist_std,
194 | HIST=False, SAT=False, hist_ref_idx=1, sat_ref_idx=1):
195 | glob_quant = Global_Quant(batch_size, imsize)
196 |
197 | if HIST==True:
198 | input_ab_for_hist = hist_ref(batch_size, imsize, hist_ref_idx)
199 | X_hist = glob_quant.global_histogram(input_ab_for_hist) # batch x 313 x imsize x imsize
200 | B_hist = torch.ones(batch_size, 1)
201 |
202 | else:
203 | tmp = torch.rand(batch_size, 313)
204 | X_hist = torch.div(tmp, torch.sum(tmp, dim=1).unsqueeze(1).repeat(1, 313))
205 | B_hist = torch.zeros(batch_size, 1)
206 |
207 | if SAT==True:
208 |         image_for_sat = sat_ref(batch_size, imsize, sat_ref_idx)
209 | X_sat = glob_quant.global_saturation(image_for_sat).unsqueeze(1) # batch x 1
210 | B_sat = torch.ones(batch_size, 1) # if masks are 0, put uniform random(0~1) value in it
211 |
212 | else:
213 | X_sat = torch.randn(batch_size, 1)
214 | for l in range(batch_size):
215 | X_sat[l] = torch.normal(torch.FloatTensor([hist_mean]), std=torch.FloatTensor([hist_std]))
216 | B_sat = torch.zeros(batch_size, 1)
217 |
218 | global_input = torch.cat([X_hist, B_hist, X_sat, B_sat], 1).unsqueeze(2).unsqueeze(2)
219 |     # batch x (q+3) = batch x 316 x 1 x 1
220 |
221 | return global_input
222 |
223 | def process_local_sampling(batch_size, imsize, p):
224 |
225 | ab_input = torch.FloatTensor([0,0]).unsqueeze(0)
226 | xy_input = torch.LongTensor([0,0]).unsqueeze(0)
227 | q=0
228 |     while q != -1:
229 | ab_list = []
230 | xy_list = []
231 | x = int(input("Enter a number for x: "))
232 | y = int(input("Enter a number for y: "))
233 |         a = int(input("Enter the a channel value (between -100 and 100): "))
234 |         b = int(input("Enter the b channel value (between -100 and 100): "))
235 |         a = (a + 100) / 200.0 # scale ab hints to [0, 1], matching process_data
236 |         b = (b + 100) / 200.0
237 | xy_list.append(x)
238 | xy_list.append(y)
239 | ab_list.append(a)
240 | ab_list.append(b)
241 | xy_list = torch.LongTensor([xy_list])
242 | ab_list = torch.FloatTensor([ab_list])
243 | xy_input = torch.cat([xy_input, xy_list], dim=0) # n x 2 with 1 x 2 all zeros
244 | ab_input = torch.cat([ab_input, ab_list], dim=0) # n x 2 with 1 x 2 all zeros
245 | q = int(input("Enter -1 to finish: "))
246 |
247 | local_ab = torch.zeros(batch_size, 2, imsize, imsize) # output local hint (ab channel)
248 | local_mask = torch.zeros(batch_size, 1, imsize, imsize).long() # output local hint (mask)
249 | # print(torch.sum(local_ab))
250 | # print(torch.sum(local_mask))
251 | for i in range(batch_size): # for all batches and
252 | for j in range(ab_input.size(0)-1):
253 | # print(ab_input.size(0)-1)
254 | # print(ab_input[j+1])
255 |
256 | low_v = xy_input[j+1][0] - p # lower bound 0
257 | if low_v < 0:
258 | low_v = 0
259 | high_v = xy_input[j+1][0] + p + 1 # higher bound imsize-1
260 | if high_v >= imsize:
261 | high_v = imsize
262 | low_h = xy_input[j+1][1] - p # lower bound 0
263 | if low_h < 0:
264 | low_h = 0
265 | high_h = xy_input[j+1][1] + p + 1 # higher bound imsize-1
266 | if high_h >= imsize:
267 | high_h = imsize
268 |
269 | local_ab[i,0, low_v:high_v, low_h:high_h] = ab_input[j + 1][0]
270 | local_ab[i,1, low_v:high_v, low_h:high_h] = ab_input[j + 1][1]
271 | local_mask[i,:,low_v:high_v, low_h:high_h] = 1
272 | print(len(torch.nonzero(local_ab[i])), len(torch.nonzero(local_mask[i])))
273 |
274 | return local_ab, local_mask.float()
275 |
276 | def hist_ref(batch, imsize, idx=1):
277 | valdir = './data/sample/hist'
278 | transform = transforms.Compose([
279 | transforms.Scale((imsize,imsize)),
280 | transforms.ToTensor(),
281 |
282 | ])
283 |
284 | val_dataset = dsets.ImageFolder(valdir, transform)
285 |
286 |
287 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
288 | batch_size=1,
289 | shuffle=False,
290 | num_workers=2)
291 |
292 | for i, (image, _) in enumerate(val_loader):
293 | if i==(idx-1):
294 | ref_image = image
295 | print('%dth image chosen as reference for histogram'%(idx))
296 | break
297 |
298 | ref_image = ref_image.numpy().transpose((0, 2, 3, 1))
299 | img_lab = rgb2lab(ref_image)
300 | img_ab = img_lab[:, :, :, 1:3]
301 |
302 | pick_ref = torch.from_numpy(img_ab.transpose((0, 3, 1, 2))).repeat(batch,1,1,1).float()
303 |
304 | return pick_ref
305 |
306 | def sat_ref(batch, imsize, idx=1):
307 | valdir = './data/sample/sat'
308 | transform = transforms.Compose([
309 | transforms.Scale((imsize,imsize)),
310 | transforms.ToTensor(),
311 |
312 | ])
313 |
314 | val_dataset = dsets.ImageFolder(valdir, transform)
315 |
316 |
317 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
318 | batch_size=1,
319 | shuffle=False,
320 | num_workers=2)
321 |
322 | for i, (image, _) in enumerate(val_loader):
323 | if i==(idx-1):
324 | ref_image = image
325 | print('%dth image chosen as reference for saturation'%(idx))
326 | break
327 |
328 | print(ref_image.size())
329 | pick_ref = ref_image.repeat(batch, 1, 1, 1).float()
330 |
331 | return pick_ref
--------------------------------------------------------------------------------
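
A minimal smoke test for the Lab conversion pipeline above (an illustrative sketch, assuming CIFAR-10 is available under ./data so Color_Dataloader can build its loaders):

```
import torch
from data_process import Color_Dataloader, process_data

# grab one CIFAR-10 batch and split it into network input and targets
_, train_loader, _, imsize = Color_Dataloader('cifar', batch_size=4)
images, _ = next(iter(train_loader))
input, labels, ab_for_global = process_data(images, images.size(0), imsize, islocal=False)

print(input.size())   # 4 x 1 x 32 x 32, L channel scaled to [0, 1]
print(labels.size())  # 4 x 2 x 32 x 32, ab channels scaled to [0, 1]
```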
/deep_color.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import argparse
5 | import numpy as np
6 | import torch.nn as nn
7 | from torch import cuda
8 | from torch.autograd import Variable
9 |
10 | from unet import *
11 | from util import *
12 | from global_hint import *
13 | from data_process import *
14 |
15 |
16 | # Hyper Parameters
17 |
18 |
19 | # arguments parsed when initiating
20 | def parse_args():
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--data', type=str, default='cifar', choices=['cifar', 'imagenet', 'celeba', 'mscoco'])
23 | parser.add_argument('--gpu', type=int, default=1)
24 | parser.add_argument('--model_path', type=str, default='./models')
25 | parser.add_argument('--log_path', type=str, default='./logs')
26 | parser.add_argument('--model', type=str, default='unet100.pkl')
27 | parser.add_argument('--image_save', type=str, default='./images')
28 |     parser.add_argument('--learning_rate', type=float, default=0.0002)
29 | parser.add_argument('--num_epochs', type=int, default=500)
30 | parser.add_argument('--start_epoch', type=int, default=0)
31 | parser.add_argument('--batch_size', type=int, default=64)
32 | parser.add_argument('--idx', type=int, default=1)
33 |     parser.add_argument('--resume', action='store_true',
34 |                         help='resume from the latest checkpoint')
35 |     parser.add_argument('--islocal', action='store_true')
36 |
37 | return parser.parse_args()
38 |
39 |
40 | def main(args):
41 | dataset = args.data
42 | gpu = args.gpu
43 | batch_size = args.batch_size
44 | model_path = args.model_path
45 | log_path = args.log_path
46 | num_epochs = args.num_epochs
47 | learning_rate = args.learning_rate
48 | start_epoch = args.start_epoch
49 | islocal = args.islocal
50 |
51 | # make directory for models saved when there is not.
52 | make_folder(model_path, dataset) # for sampling model
53 | make_folder(log_path, dataset) # for logpoint model
54 | make_folder(log_path, dataset +'/ckpt') # for checkpoint model
55 |
56 | # see if gpu is on
57 | print("Running on gpu : ", gpu)
58 | cuda.set_device(gpu)
59 |
60 | # set the data-loaders
61 | train_dataset, train_loader, val_loader, imsize = Color_Dataloader(dataset, batch_size)
62 |
63 | # declare unet class
64 | unet = UNet(imsize, islocal)
65 |
66 | # make the class run on gpu
67 | unet.cuda()
68 |
69 | # Loss and Optimizer
70 | optimizer = torch.optim.Adam(unet.parameters(), lr=learning_rate)
71 | criterion = torch.nn.SmoothL1Loss()
72 |
73 | # optionally resume from a checkpoint
74 | if args.resume:
75 |         ckpt_path = os.path.join(log_path, dataset, 'ckpt/model.ckpt')
76 | if os.path.isfile(ckpt_path):
77 | print("=> loading checkpoint")
78 | checkpoint = torch.load(ckpt_path)
79 | start_epoch = checkpoint['epoch']
80 | unet.load_state_dict(checkpoint['state_dict'])
81 | optimizer.load_state_dict(checkpoint['optimizer'])
82 | print("=> Loaded checkpoint (epoch {})".format(checkpoint['epoch']))
83 |             print("=> Training will resume from epoch {}".format(checkpoint['epoch']+1))
84 | else:
85 |             print("=> Sorry, no checkpoint found at '{}'".format(ckpt_path))
86 |
87 | # record time
88 | tell_time = Timer()
89 | iter = 0
90 | # Train the Model
91 | for epoch in range(start_epoch, num_epochs):
92 |
93 | unet.train()
94 | for i, (images, _) in enumerate(train_loader):
95 |
96 | batch = images.size(0)
97 |             '''
98 |             Prepare additional variables for later use and
99 |             convert images from RGB to CIE Lab; see
100 |             process_data and process_global in data_process.py.
101 |             '''
102 | if islocal:
103 | input, labels, _ = process_data(images, batch, imsize, islocal)
104 | local_ab, local_mask = process_local(labels, batch, imsize)
105 | side_input = torch.cat([local_ab, local_mask], 1) # concat([batch x 2 x imsize x imsize , batch x 1 x imsize x imsize], 1) = batch x 3 x imsize x imsize
106 | random_expose = random.randrange(1, 101)
107 | if random_expose == 100:
108 | print("Jackpot! expose the whole!")
109 |                     local_mask = torch.ones(batch, 1, imsize, imsize)
110 | side_input = torch.cat([labels, local_mask], 1)
111 |             else: # global hints branch
112 | input, labels, ab_for_global = process_data(images, batch, imsize, islocal)
113 | side_input = process_global(images, ab_for_global, batch, imsize, hist_mean=0.03, hist_std=0.13)
114 |
115 |
116 |             # make them all Variable + GPU available
117 |
118 | input = Variable(input).cuda()
119 | labels = Variable(labels).cuda()
120 | side_input = Variable(side_input).cuda()
121 |
122 | # initialize gradients
123 | optimizer.zero_grad()
124 | outputs = unet(input, side_input)
125 |
126 | # make outputs and labels as a matrix for loss calculation
127 |             outputs = outputs.view(batch, -1) # batch x (2*imsize*imsize), e.g. 100 x 2048 for imsize 32
128 |             labels = labels.contiguous().view(batch, -1)
129 |
130 | loss_train = criterion(outputs, labels)
131 | loss_train.backward()
132 | optimizer.step()
133 |
134 | if (i + 1) % 10 == 0:
135 | print('Epoch [%d/%d], Iter [%d/%d], Loss: %.10f, iter_time: %2.2f, aggregate_time: %6.2f'
136 | % (epoch + 1, num_epochs, i + 1, (len(train_dataset) // batch_size), loss_train.data[0],
137 | (tell_time.toc() - iter), tell_time.toc()))
138 | iter = tell_time.toc()
139 |
140 | torch.save(unet.state_dict(), os.path.join(model_path, dataset, 'unet%d.pkl' % (epoch + 1)))
141 |
142 | # start evaluation
143 | print("-------------evaluation start------------")
144 |
145 | unet.eval()
146 |         loss_val_all = Variable(torch.zeros(31), volatile=True).cuda() # averaged over the first 31 val batches
147 | for i, (images, _) in enumerate(val_loader):
148 |
149 | # change the picture type from rgb to CIE Lab
150 | batch = images.size(0)
151 |
152 | if islocal:
153 | input, labels, _ = process_data(images, batch, imsize, islocal)
154 | local_ab, local_mask = process_local(labels, batch, imsize)
155 | side_input = torch.cat([local_ab, local_mask], 1)
156 | random_expose = random.randrange(1, 101)
157 | if random_expose == 100:
158 | print("Jackpot! expose the whole!")
159 |                     local_mask = torch.ones(batch, 1, imsize, imsize)
160 | side_input = torch.cat([labels, local_mask], 1)
161 |             else: # global hints branch
162 | input, labels, ab_for_global = process_data(images, batch, imsize, islocal)
163 | side_input = process_global(images, ab_for_global, batch, imsize, hist_mean=0.03, hist_std=0.13)
164 |
165 |             # make them all Variable + GPU available
166 |
167 | input = Variable(input).cuda()
168 | labels = Variable(labels).cuda()
169 | side_input = Variable(side_input).cuda()
170 |
171 | # initialize gradients
172 | optimizer.zero_grad()
173 | outputs = unet(input, side_input)
174 |
175 | # make outputs and labels as a matrix for loss calculation
176 |             outputs = outputs.view(batch, -1) # batch x (2*imsize*imsize)
177 |             labels = labels.contiguous().view(batch, -1)
178 |
179 | loss_val = criterion(outputs, labels)
180 |
181 | logpoint = {
182 | 'epoch': epoch + 1,
183 | 'args': args,
184 | }
185 | checkpoint = {
186 | 'epoch': epoch + 1,
187 | 'args': args,
188 | 'state_dict': unet.state_dict(),
189 | 'optimizer': optimizer.state_dict(),
190 | }
191 |
192 | loss_val_all[i] = loss_val
193 |
194 | if i == 30:
195 | print('Epoch [%d/%d], Validation Loss: %.10f'
196 | % (epoch + 1, num_epochs, torch.mean(loss_val_all).data[0]))
197 | torch.save(logpoint, os.path.join(log_path, dataset, 'Model_e%d_train_%.4f_val_%.4f.pt' %
198 | (epoch + 1, torch.mean(loss_train).data[0],
199 | torch.mean(loss_val_all).data[0])))
200 | torch.save(checkpoint, os.path.join(log_path, dataset, 'ckpt/model.ckpt'))
201 | break
202 |
203 |
204 | if __name__ == '__main__':
205 | args = parse_args()
206 | main(args)
--------------------------------------------------------------------------------
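
Training can be resumed from the checkpoint written during validation; with the --resume flag above, a restart on the same dataset might look like this (a sketch, assuming a checkpoint already exists under ./logs/<dataset>/ckpt):

```
python deep_color.py --data cifar --resume
```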
/download.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
3 | if [ $FILE == 'CelebA_FD' ]
4 | then
5 | URL=https://www.dropbox.com/s/e0ig4nf1v94hyj8/CelebA.zip?dl=0
6 | ZIP_FILE=./data/CelebA_FD.zip
7 | elif [ $FILE == 'CelebA' ]
8 | then
9 | URL=https://www.dropbox.com/s/3e5cmqgplchz85o/CelebA_nocrop.zip?dl=0
10 | ZIP_FILE=./data/CelebA.zip
11 | elif [ $FILE == 'LSUN' ]
12 | then
13 | URL=https://www.dropbox.com/s/zt7d2hchrw7cp9p/church_outdoor_train_lmdb.zip?dl=0
14 | ZIP_FILE=./data/church_outdoor_train_lmdb.zip
15 | else
16 | echo "Available datasets are: CelebA, CelebA_FD and LSUN"
17 | exit 1
18 | fi
19 |
20 | mkdir -p ./data/
21 | wget -N $URL -O $ZIP_FILE
22 | unzip $ZIP_FILE -d ./data/
23 |
24 | if [ $FILE == 'CelebA' ]
25 | then
26 | mv ./data/CelebA_nocrop ./data/CelebA
27 | elif [ $FILE == 'CelebA_FD' ]
28 | then
29 | mv ./data/CelebA ./data/CelebA_FD
30 | fi
31 |
32 | rm $ZIP_FILE
33 |
--------------------------------------------------------------------------------
/global_hint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import sklearn.neighbors as neigh
4 | from skimage.color import rgb2hsv
5 |
6 | from unet import *
7 | import util
8 |
9 |
10 | class Global_Quant():
11 | ''' Layer which encodes ab map into Q colors
12 | '''
13 | def __init__(self, batch, imsize):
14 | self.quantization = Quantization(batch, imsize, km_filepath='./data/pts_in_hull.npy')
15 |
16 | def global_histogram(self, input):
17 | out = self.quantization.encode_nn(input) # batch x 313 x imsize x imsize
18 |         out = out.type(torch.FloatTensor) # cast the one-hot encoding to FloatTensor
19 | X_onehotsum = torch.sum(torch.sum(out, dim=3), dim=2) # sum it up to batch x 313
20 |         X_hist = torch.div(X_onehotsum, util.expand(torch.sum(X_onehotsum, dim=1).unsqueeze(1), X_onehotsum)) # normalize to a 313-bin probability vector
21 | return X_hist
22 |
23 | def global_saturation(self, images): # input: tensor images batch x 3 x imsize x imsize (rgb)
24 | images_np = images.numpy().transpose((0, 2, 3, 1)) # numpy: batch x imsize x imsize x 3
25 | images_h = torch.zeros(images.size(0), 1, images.size(2),images.size(2))
26 | for k in range(images.size(0)):
27 | img_hsv = rgb2hsv(images_np[k])
28 |             img_h = img_hsv[:, :, 1] # channel 1 of HSV is saturation
29 | images_h[k] = torch.from_numpy(img_h).unsqueeze(0) # batch x 1 x imsize x imsize
30 | avgs = torch.mean(images_h.view(images.size(0), -1),dim=1) # batch x 1
31 | return avgs
32 |
33 | def global_masks(self, batch_size): # both for histogram and saturation
34 | B_hist = torch.round(torch.rand(batch_size, 1))
35 | B_sat = torch.round(torch.rand(batch_size, 1))
36 | return B_hist, B_sat
37 |
38 | class Quantization():
39 | # Encode points as a linear combination of unordered points
40 | # using NN search and RBF kernel
41 | def __init__(self,batch, imsize, km_filepath='./data/pts_in_hull.npy' ):
42 |
43 | self.cc = torch.from_numpy(np.load(km_filepath)).type(torch.FloatTensor) # 313 x 2
44 | self.K = self.cc.shape[0]
45 | self.batch = batch
46 | self.imsize = imsize
47 |
48 | def encode_nn(self,images): # batch x imsize x imsize x 2
49 |
50 | images = images.permute(0,2,3,1) # batch x 2 x imsize x imsize -> batch x imsize x imsize x 2
51 | images_flt = images.contiguous().view(-1, 2)
52 | P = images_flt.shape[0]
53 | inds = self.nearest_inds(images_flt, self.cc).unsqueeze(1) # P x 1
54 | images_encoded = torch.zeros(P,self.K)
55 | images_encoded.scatter_(1, inds, 1)
56 | images_encoded = images_encoded.view(self.batch, self.imsize, self.imsize, 313)
57 | images_encoded = images_encoded.permute(0,3,1,2)
58 | return images_encoded
59 |
60 |     def nearest_inds(self, x, y): # x: n x 2, y: 313 x 2; squared distances give n x 313
61 | inner = torch.matmul(x, y.t())
62 | normX = torch.sum(torch.mul(x, x), 1).unsqueeze(1).expand_as(inner)
63 | normY = torch.sum(torch.mul(y, y), 1).unsqueeze(1).t().expand_as(inner) # n x 313
64 | P = normX - 2 * inner + normY
65 | nearest_idx = torch.min(P, dim=1)[1]
66 | return nearest_idx
67 |
68 |
69 |
70 | # def decode_points_mtx_nd(self,pts_enc_nd,axis=1):
71 | # pts_enc_flt = util.flatten_nd_array(pts_enc_nd,axis=axis)
72 | # pts_dec_flt = np.dot(pts_enc_flt,self.cc)
73 | # pts_dec_nd = util.unflatten_2d_array(pts_dec_flt,pts_enc_nd,axis=axis)
74 | # return pts_dec_nd
75 | #
76 | # def decode_1hot_mtx_nd(self,pts_enc_nd,axis=1,returnEncode=False):
77 | # pts_1hot_nd = nd_argmax_1hot(pts_enc_nd,axis=axis)
78 | # pts_dec_nd = self.decode_points_mtx_nd(pts_1hot_nd,axis=axis)
79 | # if(returnEncode):
80 | # return (pts_dec_nd,pts_1hot_nd)
81 | # else:
82 | # return pts_dec_nd
--------------------------------------------------------------------------------
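
A quick sanity check of the quantization encoder (illustrative only; assumes ./data/pts_in_hull.npy is present, as in the repository layout above):

```
import torch
from global_hint import Quantization

quant = Quantization(batch=1, imsize=4)
ab = torch.zeros(1, 2, 4, 4)   # a uniform gray ab map
onehot = quant.encode_nn(ab)   # 1 x 313 x 4 x 4 one-hot encoding
print(onehot.sum())            # 16.0: exactly one active color bin per pixel
```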
/sampling.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import torchvision
5 | import numpy as np
6 | import torch.nn as nn
7 | from torch import cuda
8 | from torch.autograd import Variable
9 | from skimage.color import rgb2lab, lab2rgb, rgb2gray
10 |
11 | from unet import *
12 | from util import *
13 | from global_hint import *
14 | from data_process import *
15 |
16 |
17 |
18 | # Hyper Parameters
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--data', type=str, default='cifar', choices=['cifar', 'imagenet', 'celeba', 'mscoco'])
23 | parser.add_argument('--gpu', type=int, default=1)
24 | parser.add_argument('--model_path', type=str, default='./models')
25 | parser.add_argument('--model', type=str, default='unet100.pkl')
26 | parser.add_argument('--image_save', type=str, default='./images')
27 |     parser.add_argument('--learning_rate', type=float, default=0.001)
28 | parser.add_argument('--num_epochs', type=int, default=100)
29 | parser.add_argument('--batch_size', type=int, default=64)
30 | parser.add_argument('--idx', type=int, default=1)
31 |     parser.add_argument('--global_hist', action='store_true')
32 |     parser.add_argument('--global_sat', action='store_true')
33 |     parser.add_argument('--hist_ref_idx', type=int, default=1)
34 |     parser.add_argument('--sat_ref_idx', type=int, default=1)
35 |     parser.add_argument('--islocal', action='store_true')
36 |     parser.add_argument('--nohint', action='store_true')
37 |
38 |
39 | return parser.parse_args()
40 |
41 |
42 |
43 | def main(args):
44 | dataset = args.data
45 | gpu = args.gpu
46 | batch_size = args.batch_size
47 | model_path = args.model_path
48 | image_save = args.image_save
49 | model = args.model
50 | idx = args.idx
51 | global_hist = args.global_hist
52 | global_sat = args.global_sat
53 | hist_ref_idx = args.hist_ref_idx
54 |     sat_ref_idx = args.sat_ref_idx
55 | islocal = args.islocal
56 | nohint = args.nohint
57 |
58 | make_folder(image_save, dataset)
59 |
60 | print("Running on gpu : ", gpu)
61 | cuda.set_device(gpu)
62 |
63 | _, _, test_loader, imsize = Color_Dataloader(dataset, batch_size)
64 |
65 | unet = UNet(imsize, islocal)
66 |
67 | unet.cuda()
68 |
69 |     unet.load_state_dict(torch.load(os.path.join(model_path, dataset, model)))
70 |     unet.eval()
71 |
72 |
73 | for i, (images, _) in enumerate(test_loader):
74 |
75 | batch = images.size(0)
76 |         '''
77 |         Prepare additional variables for later use and
78 |         convert images from RGB to CIE Lab; see
79 |         process_data and process_global_sampling in data_process.py.
80 |         '''
81 | if islocal:
82 | input, labels, _ = process_data(images, batch, imsize, islocal)
83 |             local_ab, local_mask = process_local_sampling(batch, imsize, p=1)
84 |             if nohint:
85 |                 local_ab = torch.zeros(batch, 2, imsize, imsize)
86 |                 local_mask = torch.zeros(batch, 1, imsize, imsize)
87 |
88 | side_input = torch.cat([local_ab, local_mask], 1)
89 |
90 |
91 |         else: # global hints branch
92 | input, labels, ab_for_global = process_data(images, batch, imsize, islocal)
93 |
94 | print('global hint for histogram : ', global_hist)
95 | print('global hint for saturation : ', global_sat)
96 |
97 | side_input = process_global_sampling(batch, imsize, 0.03, 0.13,
98 | global_hist, global_sat, hist_ref_idx, sat_ref_idx)
99 |
100 |         # make them all Variable + GPU available
101 |
102 | input = Variable(input).cuda()
103 | labels = Variable(labels).cuda()
104 | side_input = Variable(side_input).cuda()
105 |
106 | outputs = unet(input, side_input)
107 |
108 | criterion = torch.nn.SmoothL1Loss()
109 | loss = criterion(outputs, labels)
110 | print('loss for test data: %2.4f'%(loss.cpu().data[0]))
111 |
112 |
113 |         colored_images = torch.cat([input, outputs], 1).data # batch x 3 x imsize x imsize
114 |         gray_images = torch.zeros(batch, 3, imsize, imsize)
115 |         img_gray = np.zeros((imsize, imsize, 3))
116 |
117 | colored_images_np = colored_images.cpu().numpy().transpose((0,2,3,1))
118 |
119 | j = 0
120 | # make sample images back to rgb
121 | for img in colored_images_np:
122 |
123 | img[:,:,0] = img[:,:,0]*100
124 | img[:, :, 1:3] = img[:, :, 1:3] * 200 - 100
125 | img = img.astype(np.float64)
126 | img_RGB = lab2rgb(img)
127 | img_gray[:,:,0] = img[:,:,0]
128 | img_gray_RGB = lab2rgb(img_gray)
129 |
130 | colored_images[j] = torch.from_numpy(img_RGB.transpose((2,0,1)))
131 | gray_images[j] = torch.from_numpy(img_gray_RGB.transpose((2,0,1)))
132 | j+=1
133 |
134 | #
135 | torchvision.utils.save_image(images,
136 | os.path.join(image_save, dataset, '{}_real_samples.png'.format(idx)))
137 | torchvision.utils.save_image(colored_images,
138 | os.path.join(image_save, dataset, '{}_colored_samples.png'.format(idx)))
139 | torchvision.utils.save_image(gray_images,
140 | os.path.join(image_save, dataset, '{}_input_samples.png'.format(idx)))
141 |
142 |
143 | print('-----------images sampled!------------')
144 | break
145 |
146 |
147 | if __name__ == '__main__':
148 | args = parse_args()
149 | main(args)
--------------------------------------------------------------------------------
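
When sampling the local-hints variant, process_local_sampling prompts on stdin for hint coordinates and ab values; an invocation might look like this (a sketch, assuming a model trained with --islocal):

```
python sampling.py --model unet100.pkl --islocal
```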
/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class UNetConvBlock1_1(nn.Module):
7 | def __init__(self, in_size, out_size, kernel_size=3):
8 | super(UNetConvBlock1_1, self).__init__()
9 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
10 |
11 | def forward(self, x):
12 | out = self.conv(x)
13 | return out
14 |
15 | class UNetConvBlock1_2(nn.Module):
16 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
17 | super(UNetConvBlock1_2, self).__init__()
18 | self.conv2 = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
19 | self.activation = activation
20 | self.batchnorm = nn.BatchNorm2d(out_size)
21 |         self.conv3 = nn.Conv2d(out_size, out_size, 1, stride=2, groups=out_size, bias=False) # depthwise 1x1, stride-2 downsampling
22 | #self.conv3.weight.data.fill_(1)
23 |
24 | def forward(self, x):
25 | out = self.activation(x)
26 | out = self.activation(self.conv2(out))
27 | out = self.batchnorm(out)
28 | out = self.conv3(out)
29 | return out
30 |
31 | class UNetConvBlock1_2_2(nn.Module):
32 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
33 | super(UNetConvBlock1_2_2, self).__init__()
34 | self.conv2 = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
35 | self.activation = activation
36 | self.batchnorm = nn.BatchNorm2d(out_size)
37 |
38 | def forward(self, x):
39 | out = self.activation(x)
40 | out = self.activation(self.conv2(out))
41 | out = self.batchnorm(out)
42 | return out
43 |
44 | class UNetConvBlock2(nn.Module):
45 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
46 | super(UNetConvBlock2, self).__init__()
47 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
48 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
49 | self.activation = activation
50 | self.batchnorm = nn.BatchNorm2d(out_size)
51 | self.conv3 = nn.Conv2d(out_size, out_size, 1, stride=2, groups=out_size, bias=False)
52 | #self.conv3.weight.data.fill_(1)
53 |
54 | def forward(self, x):
55 | out = self.activation(self.conv(x))
56 | out = self.activation(self.conv2(out))
57 | out = self.batchnorm(out)
58 | out = self.conv3(out)
59 | return out
60 |
61 | class UNetConvBlock2_2(nn.Module):
62 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
63 | super(UNetConvBlock2_2, self).__init__()
64 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
65 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
66 | self.activation = activation
67 | self.batchnorm = nn.BatchNorm2d(out_size)
68 |
69 | def forward(self, x):
70 | out = self.activation(self.conv(x))
71 | out = self.activation(self.conv2(out))
72 | out = self.batchnorm(out)
73 | return out
74 |
75 | class UNetConvBlock3(nn.Module):
76 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
77 | super(UNetConvBlock3, self).__init__()
78 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
79 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
80 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
81 | self.activation = activation
82 | self.batchnorm = nn.BatchNorm2d(out_size)
83 | self.conv4 = nn.Conv2d(out_size, out_size, 1, stride=2, groups=out_size, bias=False)
84 | #self.conv4.weight.data.fill_(1)
85 |
86 | def forward(self, x):
87 | out = self.activation(self.conv(x))
88 | out = self.activation(self.conv2(out))
89 | out = self.activation(self.conv3(out))
90 | out = self.batchnorm(out)
91 | out = self.conv4(out)
92 | return out
93 |
94 | class UNetConvBlock3_2(nn.Module):
95 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
96 | super(UNetConvBlock3_2, self).__init__()
97 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
98 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
99 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
100 | self.activation = activation
101 | self.batchnorm = nn.BatchNorm2d(out_size)
102 |
103 | def forward(self, x):
104 | out = self.activation(self.conv(x))
105 | out = self.activation(self.conv2(out))
106 | out = self.activation(self.conv3(out))
107 | out = self.batchnorm(out)
108 | return out
109 |
110 | class UNetConvBlock4(nn.Module):
111 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
112 | super(UNetConvBlock4, self).__init__()
113 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1, dilation=1)
114 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
115 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
116 | self.activation = activation
117 | self.batchnorm = nn.BatchNorm2d(out_size)
118 |
119 | def forward(self, x):
120 | out = self.activation(self.conv(x))
121 | out = self.activation(self.conv2(out))
122 | out = self.activation(self.conv3(out))
123 | out = self.batchnorm(out)
124 | return out
125 |
126 | class UNetConvBlock5(nn.Module):
127 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
128 | super(UNetConvBlock5, self).__init__()
129 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=2, dilation=2)
130 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2)
131 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2)
132 | self.activation = activation
133 | self.batchnorm = nn.BatchNorm2d(out_size)
134 |
135 | def forward(self, x):
136 | out = self.activation(self.conv(x))
137 | out = self.activation(self.conv2(out))
138 | out = self.activation(self.conv3(out))
139 | out = self.batchnorm(out)
140 | return out
141 |
142 | class UNetConvBlock6(nn.Module):
143 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
144 | super(UNetConvBlock6, self).__init__()
145 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=2, dilation=2)
146 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2)
147 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2)
148 | self.activation = activation
149 | self.batchnorm = nn.BatchNorm2d(out_size)
150 |
151 | def forward(self, x):
152 | out = self.activation(self.conv(x))
153 | out = self.activation(self.conv2(out))
154 | out = self.activation(self.conv3(out))
155 | out = self.batchnorm(out)
156 | return out
157 |
158 | class UNetConvBlock7(nn.Module):
159 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
160 | super(UNetConvBlock7, self).__init__()
161 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1, dilation=1)
162 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
163 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
164 | self.activation = activation
165 | self.batchnorm = nn.BatchNorm2d(out_size)
166 |
167 | def forward(self, x):
168 | out = self.activation(self.conv(x))
169 | out = self.activation(self.conv2(out))
170 | out = self.activation(self.conv3(out))
171 | out = self.batchnorm(out)
172 | return out
173 |
174 | class UNetConvBlock8(nn.Module):
175 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
176 | super(UNetConvBlock8, self).__init__()
177 | self.up = nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, dilation=1)
178 | self.bridge = nn.Conv2d(256, 256, kernel_size, padding=1)
179 | #self.bridge.weight.data.normal_(0, 0.01)
180 | #self.bridge.bias.data.fill_(1)
181 | self.conv = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
182 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
183 | self.activation = activation
184 | self.batchnorm = nn.BatchNorm2d(out_size)
185 | # def center_crop(self, layer, target_size):
186 | # batch_size, n_channels, layer_width, layer_height = layer.size()
187 | # xy1 = (layer_width - target_size) // 2
188 | # return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)]
189 | def forward(self, x, bridge):
190 | up = self.up(x)
191 | out = self.activation(self.bridge(bridge) + up)
192 | out = self.activation(self.conv(out))
193 | out = self.activation(self.conv2(out))
194 | out = self.batchnorm(out)
195 | return out
196 |
197 | class UNetConvBlock9(nn.Module):
198 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
199 | super(UNetConvBlock9, self).__init__()
200 | self.up = nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, dilation=1)
201 | #self.up.weight.data.normal_(0, 0.01)
202 | #self.up.bias.data.fill_(1)
203 | self.bridge = nn.Conv2d(128, 128, kernel_size, padding=1)
204 | #self.bridge.weight.data.normal_(0, 0.01)
205 | #self.bridge.bias.data.fill_(1)
206 | self.conv = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
207 | #self.conv.weight.data.normal_(0, 0.01)
208 | #self.conv.bias.data.fill_(1)
209 | self.activation = activation
210 | self.batchnorm = nn.BatchNorm2d(out_size)
211 |
212 | def forward(self, x, bridge):
213 | up = self.up(x)
214 | out = self.activation(self.bridge(bridge) + up)
215 | out = self.activation(self.conv(out))
216 | out = self.batchnorm(out)
217 |
218 | return out
219 |
220 | class UNetConvBlock10(nn.Module):
221 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
222 | super(UNetConvBlock10, self).__init__()
223 | self.up = nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, dilation=1)
224 | #self.up.weight.data.normal_(0, 0.01)
225 | #self.up.bias.data.fill_(1)
226 | self.bridge = nn.Conv2d(64, 128, kernel_size, padding=1)
227 | #self.bridge.weight.data.normal_(0, 0.01)
228 | #self.bridge.bias.data.fill_(1)
229 | self.conv = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
230 | #self.conv.weight.data.normal_(0, 0.01)
231 | #self.conv.bias.data.fill_(1)
232 | self.activation = activation
233 | self.activation2 = nn.LeakyReLU(negative_slope=0.02)
234 |
235 | def forward(self, x, bridge):
236 | up = self.up(x)
237 | out = self.activation(self.bridge(bridge) + up)
238 | out = self.activation2(self.conv(out))
239 | return out
240 |
241 | class prediction(nn.Module):
242 | def __init__(self, in_size, out_size, kernel_size=1, activation=F.tanh, space_dropout=False):
243 | super(prediction, self).__init__()
244 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, dilation=1)
245 | self.activation = activation
246 |
247 | def forward(self, x):
248 | out = self.activation(self.conv(x))
249 | out = out * 100
250 | return out
251 |
252 | class convrelu(nn.Module):
253 |
254 | def __init__(self, in_size, out_size, kernel_size=1, activation=F.relu, space_dropout=False):
255 | super(convrelu, self).__init__()
256 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=0)
257 | self.activation = activation
258 |
259 | def forward(self, x):
260 | out = self.activation(self.conv(x))
261 | return out
262 |
263 | class global_network(nn.Module):
264 | def __init__(self, image_size):
265 | super(global_network, self).__init__()
266 | self.oneD = convrelu(316, 512)
267 | self.twoD = convrelu(512, 512)
268 | self.threeD = convrelu(512, 512)
269 | self.fourD = convrelu(512, 512)
270 | self.image_size = image_size
271 |
272 |     def forward(self, x): # 4 conv+relu layers with 1 x 1 kernels, 512 channels each
273 |         tmp = self.oneD(x) # input: batch x 316 x 1 x 1 (313-bin histogram + mask + saturation + mask)
274 | tmp = self.twoD(tmp)
275 | tmp = self.threeD(tmp)
276 |         out = self.fourD(tmp) # batch x 512 x 1 x 1
277 |
278 |         out = out.repeat(1,1, int(self.image_size/8), int(self.image_size/8)) # tile to batch x 512 x h/8 x w/8
279 | return out
280 |
281 | class local_network(nn.Module):
282 | def __init__(self, in_size, out_size, imsize):
283 | super(local_network, self).__init__()
284 | self.imsize = imsize
285 | self.conv = nn.Conv2d(in_size, out_size, 3, padding=1)
286 |
287 | def forward(self, ab_input):
288 |         out = self.conv(ab_input) # depth-64 features fed into the main network
289 | return out
290 |
291 | class UNet(nn.Module):
292 | def __init__(self, imsize, islocal):
293 | super(UNet, self).__init__()
294 | self.imsize = imsize
295 | self.islocal = islocal
296 |
297 | if self.islocal==True:
298 | self.localnet = local_network(3, 64, self.imsize)
299 |         else: # global hints network
300 | self.globalnet = global_network(self.imsize)
301 |
302 | self.convlayer1_1 = UNetConvBlock1_1(1, 64)
303 | self.convlayer1_2 = UNetConvBlock1_2(64, 64)
304 | self.convlayer1_2_2 = UNetConvBlock1_2_2(64, 64)
305 | self.convlayer2 = UNetConvBlock2(64, 128)
306 | self.convlayer2_2 = UNetConvBlock2_2(64, 128)
307 | self.convlayer3 = UNetConvBlock3(128, 256)
308 | self.convlayer3_2 = UNetConvBlock3_2(128, 256)
309 | self.convlayer4 = UNetConvBlock4(256, 512)
310 | self.convlayer5 = UNetConvBlock5(512, 512) # Dilated Convolution
311 | self.convlayer6 = UNetConvBlock6(512, 512) # Dilated Convolution
312 | self.convlayer7 = UNetConvBlock7(512, 512)
313 | self.convlayer8 = UNetConvBlock8(512, 256)
314 | self.convlayer9 = UNetConvBlock9(256, 128)
315 | self.convlayer10 = UNetConvBlock10(128, 128)
316 |
317 | self.prediction = prediction(128, 2)
318 |
319 | #self.last = nn.Conv2d(128, 2, 1)
320 |
321 | def forward(self, x, side_input):
322 | layer1_1 = self.convlayer1_1(x)
323 |
324 | if self.islocal == True:
325 | local_net = self.localnet(side_input)
326 | layer1_1 = layer1_1 + local_net
327 |
328 | layer1_2 = self.convlayer1_2(layer1_1)
329 | layer1_2_2 = self.convlayer1_2_2(layer1_1)
330 | layer2 = self.convlayer2(layer1_2)
331 | layer2_2 = self.convlayer2_2(layer1_2)
332 | layer3 = self.convlayer3(layer2)
333 | layer3_2 = self.convlayer3_2(layer2)
334 | layer4 = self.convlayer4(layer3)
335 |
336 | if self.islocal == False:
337 | global_net = self.globalnet(side_input)
338 | layer4 = layer4 + global_net
339 |
340 | layer5 = self.convlayer5(layer4)
341 | layer6 = self.convlayer6(layer5)
342 | layer7 = self.convlayer7(layer6)
343 | layer8 = self.convlayer8(layer7, layer3_2)
344 | layer9 = self.convlayer9(layer8, layer2_2)
345 | layer10 = self.convlayer10(layer9, layer1_2_2)
346 |
347 | prediction = self.prediction(layer10)
348 |
349 | return prediction
--------------------------------------------------------------------------------
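
A forward-pass shape check for the network above (a sketch under the global-hints configuration; uses the torch 0.2-era Variable API that the rest of the code assumes):

```
import torch
from torch.autograd import Variable
from unet import UNet

net = UNet(imsize=32, islocal=False)
gray = Variable(torch.randn(2, 1, 32, 32))   # L channel input
side = Variable(torch.randn(2, 316, 1, 1))   # 313-bin histogram + mask + saturation + mask
ab = net(gray, side)
print(ab.size())                             # 2 x 2 x 32 x 32: predicted ab channels
```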
/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import time
4 | import datetime
5 | import torch
6 |
7 |
8 | def check_value(inds, val):
9 | # Check to see if an array is a single element equaling a particular value
10 | # Good for pre-processing inputs in a function
11 | if (np.array(inds).size == 1):
12 | if (inds == val):
13 | return True
14 | return False
15 |
16 |
17 | def flatten_nd_array(pts_nd, axis=1):
18 | # Flatten an nd array into a 2d array with a certain axis
19 | # INPUTS
20 | # pts_nd N0xN1x...xNd array
21 | # axis integer
22 | # OUTPUTS
23 | # pts_flt prod(N \ N_axis) x N_axis array
24 | NDIM = pts_nd.ndim
25 | SHP = np.array(pts_nd.shape)
26 | nax = np.setdiff1d(np.arange(0, NDIM), np.array((axis))) # non axis indices
27 | NPTS = np.prod(SHP[nax])
28 | axorder = np.concatenate((nax, np.array(axis).flatten()), axis=0)
29 | pts_flt = pts_nd.transpose((axorder))
30 | pts_flt = pts_flt.reshape(NPTS, SHP[axis])
31 | return pts_flt
32 |
33 |
34 | def unflatten_2d_array(pts_flt, pts_nd, axis=1, squeeze=False):
35 | # Unflatten a 2d array with a certain axis
36 | # INPUTS
37 | # pts_flt prod(N \ N_axis) x M array
38 | # pts_nd N0xN1x...xNd array
39 | # axis integer
40 | # squeeze bool if true, M=1, squeeze it out
41 | # OUTPUTS
42 | # pts_out N0xN1x...xNd array
43 | NDIM = pts_nd.ndim
44 | SHP = np.array(pts_nd.shape)
45 | nax = np.setdiff1d(np.arange(0, NDIM), np.array((axis))) # non axis indices
46 | NPTS = np.prod(SHP[nax])
47 |
48 | if (squeeze):
49 | axorder = nax
50 | axorder_rev = np.argsort(axorder)
51 | M = pts_flt.shape[1]
52 | NEW_SHP = SHP[nax].tolist()
53 | pts_out = pts_flt.reshape(NEW_SHP)
54 | pts_out = pts_out.transpose(axorder_rev)
55 | else:
56 | axorder = np.concatenate((nax, np.array(axis).flatten()), axis=0)
57 | axorder_rev = np.argsort(axorder)
58 | M = pts_flt.shape[1]
59 | NEW_SHP = SHP[nax].tolist()
60 | NEW_SHP.append(M)
61 | pts_out = pts_flt.reshape(NEW_SHP)
62 | pts_out = pts_out.transpose(axorder_rev)
63 |
64 | return pts_out
65 |
66 |
67 | def na():
68 | return np.newaxis
69 |
70 |
71 | class Timer():
72 | def __init__(self):
73 | self.cur_t = time.time()
74 |
75 | def tic(self):
76 | self.cur_t = time.time()
77 |
78 | def toc(self):
79 | return time.time() - self.cur_t
80 |
81 | def tocStr(self, t=-1):
82 | if (t == -1):
83 | return str(datetime.timedelta(seconds=np.round(time.time() - self.cur_t, 3)))[:-4]
84 | else:
85 | return str(datetime.timedelta(seconds=np.round(t, 3)))[:-4]
86 |
87 | def distribution(tensor):
88 |
89 | tensor = torch.div(tensor, expand(tensor.sum(dim=1).unsqueeze(-1), tensor))
90 |     if (tensor.sum(dim=1).data.cpu().numpy() == 0).any():
91 |         print("")
92 |         print("")
93 |         print("division by zero")
94 |         print("")
95 |         print("")
96 | return tensor.unsqueeze(-1)
97 |
98 | def expand(tensor, target):
99 | return tensor.expand_as(target)
100 |
101 |
102 | def make_folder(path, dataset):
103 | try:
104 | os.makedirs(os.path.join(path, dataset))
105 | except OSError:
106 | pass
107 |
--------------------------------------------------------------------------------
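
An illustrative round-trip check for the flatten/unflatten helpers above:

```
import numpy as np
from util import flatten_nd_array, unflatten_2d_array

pts = np.random.rand(2, 3, 4, 4)                  # batch x channel x h x w
flat = flatten_nd_array(pts, axis=1)              # (2*4*4) x 3
restored = unflatten_2d_array(flat, pts, axis=1)  # back to 2 x 3 x 4 x 4
assert np.allclose(pts, restored)
```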