├── README.md ├── col1.png ├── col2.png ├── col3.png ├── data ├── .DS_Store └── pts_in_hull.npy ├── data_process.py ├── deep_color.py ├── download.sh ├── global_hint.py ├── sampling.py ├── unet.py └── util.py /README.md: -------------------------------------------------------------------------------- 1 | # "Real Time User-guided Colorization with Learned Deep Priors" implemented in pytorch 2 | 3 | This is a pytorch implementation of ["Real-Time User-Guided Image Colorization with Learned Deep Priors"](https://arxiv.org/abs/1705.02999) by Zhang et.al. 4 | 5 | ## Getting Started 6 | 7 | ### Prerequisites 8 | 9 | torch==0.2.0.post4, torchvision==0.1.9 10 | The code is written with the default setting that you have gpu. Cpu mode is not recommended when using this repository. 11 | 12 | ### Installing and running the tests 13 | 14 | Make sure you have cifar10 or CelebA downloaded in ./data. 15 | You can download it through by taking a look at my "download.sh" file 16 | ``` 17 | ./data/CelebA 18 | ./data/Cifar10 19 | ./data/pts_in_hull.npy 20 | ``` 21 | 22 | first clone this repository 23 | 24 | ``` 25 | git clone https://github.com/sjooyoo/https://github.com/sjooyoo/real-time-user-guided-colorization_pytorch.git 26 | ``` 27 | then run train 28 | 29 | ``` 30 | python deep_color.py 31 | ``` 32 | 33 | to sample results you first need to run deep_color.py, which will automatically save models under a models folder that will be made in your root directory. 34 | I did not include pretrained models in this repository. The --model unet100.pkl below is a sample after 100 epochs. Change the command according to your model that you want to sample. 35 | ``` 36 | python sampling.py --model unet100.pkl 37 | ``` 38 | 39 | 40 | ### Results 41 | 42 | Input black and white image 43 | 44 | 45 | 46 | Predicted colorization output 47 | 48 | 49 | 50 | Ground truth image 51 | 52 | 53 | 54 | 55 | ### Note 56 | This is not a complete implementation. I have implemented the global hints network but have yet to incorporate it into the main network. 57 | 58 | 59 | ### Further work 60 | * global hints network 61 | 62 | 63 | ## Acknowledgments 64 | Original paper ["Real-Time User-Guided Image Colorization with Learned Deep Priors"](https://arxiv.org/abs/1705.02999) 65 | -------------------------------------------------------------------------------- /col1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inkImage/real-time-user-guided-colorization_pytorch/d1fd1e0a0e31dc296989bc3bbbac4d6279e43156/col1.png -------------------------------------------------------------------------------- /col2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inkImage/real-time-user-guided-colorization_pytorch/d1fd1e0a0e31dc296989bc3bbbac4d6279e43156/col2.png -------------------------------------------------------------------------------- /col3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inkImage/real-time-user-guided-colorization_pytorch/d1fd1e0a0e31dc296989bc3bbbac4d6279e43156/col3.png -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inkImage/real-time-user-guided-colorization_pytorch/d1fd1e0a0e31dc296989bc3bbbac4d6279e43156/data/.DS_Store -------------------------------------------------------------------------------- /data/pts_in_hull.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inkImage/real-time-user-guided-colorization_pytorch/d1fd1e0a0e31dc296989bc3bbbac4d6279e43156/data/pts_in_hull.npy -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision.datasets as dsets 4 | import torchvision.transforms as transforms 5 | from skimage.color import rgb2lab 6 | 7 | from global_hint import * 8 | 9 | 10 | def Color_Dataloader(dataset, batch_size): 11 | if dataset == 'cifar': 12 | transform = transforms.Compose([ 13 | transforms.ToTensor() 14 | ]) 15 | train_dataset = dsets.CIFAR10(root='./data/', 16 | train=True, 17 | transform=transform, 18 | download=True) 19 | 20 | val_dataset = dsets.CIFAR10(root='./data/', 21 | train=False, 22 | transform=transform) 23 | # Data Loader-> it will hand in dataset by size batch 24 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 25 | batch_size=batch_size, 26 | shuffle=True) 27 | 28 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 29 | batch_size=batch_size, 30 | shuffle=False) 31 | imsize = 32 32 | 33 | elif dataset == 'imagenet': 34 | 35 | traindir = './data/tiny-imagenet-200/train/' 36 | valdir = './data/tiny-imagenet-200/val/' 37 | transform = transforms.Compose([ 38 | transforms.ToTensor() 39 | ]) 40 | 41 | train_dataset = dsets.ImageFolder(traindir, transform) 42 | val_dataset = dsets.ImageFolder(valdir, transform) 43 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 44 | batch_size=batch_size, 45 | shuffle=True, 46 | num_workers=2) 47 | 48 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 49 | batch_size=batch_size, 50 | shuffle=True, 51 | num_workers=2) 52 | imsize = 64 53 | 54 | 55 | elif dataset == 'celeba': 56 | 57 | traindir = './data/CelebA/trainimages/images' 58 | valdir= './data/CelebA/valimages' 59 | transform = transforms.Compose([ 60 | transforms.ToTensor() 61 | ]) 62 | 63 | train_dataset = dsets.ImageFolder(traindir, transform=transform) 64 | val_dataset = dsets.ImageFolder(valdir, transform=transform) 65 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 66 | batch_size=batch_size, 67 | shuffle=True, 68 | num_workers=2) 69 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 70 | batch_size=batch_size, 71 | shuffle=True, 72 | num_workers=2) 73 | imsize = 128 74 | 75 | elif dataset == 'mscoco': 76 | 77 | traindir = './data/mscoco/trainimages_resized' 78 | valdir = './data/mscoco/valimages_resized' 79 | # Load mscoco data 80 | transform = transforms.Compose([ 81 | transforms.ToTensor() 82 | ]) 83 | 84 | train_dataset = dsets.ImageFolder(traindir, transform=transform) 85 | val_dataset = dsets.ImageFolder(valdir, transform=transform) 86 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 87 | batch_size=batch_size, 88 | shuffle=True, 89 | num_workers=2) 90 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 91 | batch_size=batch_size, 92 | shuffle=True, 93 | num_workers=2) 94 | imsize = 32 95 | 96 | return train_dataset, train_loader, val_loader, imsize 97 | 98 | 99 | def process_data(image_data, batch_size, imsize, islocal): 100 | input = torch.zeros(batch_size, 1, imsize, imsize) 101 | labels = torch.zeros(batch_size, 2, imsize, imsize) 102 | images_np = image_data.numpy().transpose((0, 2, 3, 1)) 103 | 104 | if islocal == False: 105 | ab_for_global = torch.zeros(batch_size, 2, imsize, imsize) 106 | 107 | for k in range(batch_size): 108 | img_lab = rgb2lab(images_np[k]) 109 | 110 | img_l = img_lab[:, :, 0] / 100 111 | input[k] = torch.from_numpy(np.expand_dims(img_l, 0)) 112 | 113 | img_ab_scale = (img_lab[:, :, 1:3] + 100) / 200 114 | labels[k] = torch.from_numpy(img_ab_scale.transpose((2, 0, 1))) 115 | 116 | img_ab_unscale = img_lab[:, :, 1:3] 117 | ab_for_global[k] = torch.from_numpy(img_ab_unscale.transpose((2, 0, 1))) 118 | 119 | if islocal == True: 120 | for k in range(batch_size): 121 | img_lab = rgb2lab(images_np[k]) 122 | 123 | img_l = img_lab[:, :, 0] / 100 124 | input[k] = torch.from_numpy(np.expand_dims(img_l, 0)) 125 | 126 | img_ab_scale = (img_lab[:, :, 1:3] + 100) / 200 127 | labels[k] = torch.from_numpy(img_ab_scale.transpose((2, 0, 1))) 128 | 129 | ab_for_global = 0 # just to make the room. don't need it in local net 130 | 131 | return input, labels, ab_for_global 132 | 133 | 134 | def process_global(images, input_ab, batch_size, imsize, hist_mean, hist_std): 135 | glob_quant = Global_Quant(batch_size, imsize) 136 | X_hist = glob_quant.global_histogram(input_ab) # batch x 313 x imsize x imsize 137 | X_sat = glob_quant.global_saturation(images).unsqueeze(1) # batch x 1 138 | B_hist, B_sat = glob_quant.global_masks(batch_size) # if masks are 0, put uniform random(0~1) value in it 139 | 140 | for l in range(batch_size): 141 | if B_sat[l].numpy() == 0: 142 | X_sat[l] = torch.normal(torch.FloatTensor([hist_mean]), std=torch.FloatTensor([hist_std])) 143 | if B_hist[l].numpy() == 0: 144 | tmp = torch.rand(313) 145 | X_hist[l] = torch.div(tmp, torch.sum(tmp)) 146 | global_input = torch.cat([X_hist, B_hist, X_sat, B_sat], 1).unsqueeze(2).unsqueeze(2) 147 | # batch x (q+1) = batch x 316 x 1 x 1 148 | 149 | return global_input 150 | 151 | def process_local(input_ab, batch_size, imsize): 152 | num_points = torch.zeros(batch_size).geometric_(0.125).long() # number of points to give as hints 153 | block_size = torch.zeros(batch_size, 1).uniform_(-0.5, 2.49).round().clamp(0, 2).long() # size of blocks to average 154 | local_ab = torch.zeros(batch_size, 2, imsize, imsize) # output local hint (ab channel) 155 | local_mask = torch.zeros(batch_size, 1, imsize, imsize).long() # output local hint (mask) 156 | 157 | for i in range(batch_size): # for all batches and 158 | for j in range(num_points[i]): 159 | gaussian_points = torch.zeros(2).normal_(mean=imsize/2, std=imsize/4).round().clamp(0, imsize-1).long() 160 | local_ab[i], local_mask[i] = \ 161 | local_get_average_value(local_ab[i], input_ab[i], local_mask[i], gaussian_points, block_size[i], imsize) 162 | 163 | return local_ab, local_mask.float() 164 | 165 | # get average value in local_ab for random sized box at certain points. 166 | def local_get_average_value(local_ab, input_ab, local_mask, loc, p, imsize): # width 0~4 167 | 168 | low_v = loc[0]-p[0] #lower bound 0 169 | if low_v<0: 170 | low_v=0 171 | high_v = loc[0]+p[0]+1 #higher bound imsize-1 172 | if high_v>=imsize: 173 | high_v=imsize 174 | low_h = loc[1]-p[0] #lower bound 0 175 | if low_h<0: 176 | low_h=0 177 | high_h = loc[1]+p[0]+1 #higher bound imsize-1 178 | if high_h>=imsize: 179 | high_h=imsize 180 | 181 | 182 | local_mask[:, low_v:high_v, low_h:high_h] = 1 183 | local_ab = torch.mul(local_mask.repeat(2, 1, 1).float(), input_ab) 184 | local_mean_a = torch.sum(local_ab[0,:,:]) / len(torch.nonzero(local_ab[0,:,:])) 185 | local_mean_b = torch.sum(local_ab[1,:,:]) / len(torch.nonzero(local_ab[1,:,:])) 186 | local_a = local_mask.float() * local_mean_a # 1 x 32 x 32 187 | local_b = local_mask.float() * local_mean_b 188 | local_ab = torch.cat([local_a, local_b], dim=0) 189 | return local_ab, local_mask 190 | 191 | 192 | 193 | def process_global_sampling(batch_size, imsize, hist_mean, hist_std, 194 | HIST=False, SAT=False, hist_ref_idx=1, sat_ref_idx=1): 195 | glob_quant = Global_Quant(batch_size, imsize) 196 | 197 | if HIST==True: 198 | input_ab_for_hist = hist_ref(batch_size, imsize, hist_ref_idx) 199 | X_hist = glob_quant.global_histogram(input_ab_for_hist) # batch x 313 x imsize x imsize 200 | B_hist = torch.ones(batch_size, 1) 201 | 202 | else: 203 | tmp = torch.rand(batch_size, 313) 204 | X_hist = torch.div(tmp, torch.sum(tmp, dim=1).unsqueeze(1).repeat(1, 313)) 205 | B_hist = torch.zeros(batch_size, 1) 206 | 207 | if SAT==True: 208 | image_for_sat = (batch_size, imsize, sat_ref_idx) 209 | X_sat = glob_quant.global_saturation(image_for_sat).unsqueeze(1) # batch x 1 210 | B_sat = torch.ones(batch_size, 1) # if masks are 0, put uniform random(0~1) value in it 211 | 212 | else: 213 | X_sat = torch.randn(batch_size, 1) 214 | for l in range(batch_size): 215 | X_sat[l] = torch.normal(torch.FloatTensor([hist_mean]), std=torch.FloatTensor([hist_std])) 216 | B_sat = torch.zeros(batch_size, 1) 217 | 218 | global_input = torch.cat([X_hist, B_hist, X_sat, B_sat], 1).unsqueeze(2).unsqueeze(2) 219 | # batch x (q+1) = batch x 316 x 1 x 1 220 | 221 | return global_input 222 | 223 | def process_local_sampling(batch_size, imsize, p): 224 | 225 | ab_input = torch.FloatTensor([0,0]).unsqueeze(0) 226 | xy_input = torch.LongTensor([0,0]).unsqueeze(0) 227 | q=0 228 | while q is not -1: 229 | ab_list = [] 230 | xy_list = [] 231 | x = int(input("Enter a number for x: ")) 232 | y = int(input("Enter a number for y: ")) 233 | a = int(input("For which color you want to apply?: (between -100 and 100)")) 234 | b = int(input("For which color you want to apply?: (between -100 and 100)")) 235 | a = ((a+100)/200) 236 | b = ((b+100)/200) 237 | xy_list.append(x) 238 | xy_list.append(y) 239 | ab_list.append(a) 240 | ab_list.append(b) 241 | xy_list = torch.LongTensor([xy_list]) 242 | ab_list = torch.FloatTensor([ab_list]) 243 | xy_input = torch.cat([xy_input, xy_list], dim=0) # n x 2 with 1 x 2 all zeros 244 | ab_input = torch.cat([ab_input, ab_list], dim=0) # n x 2 with 1 x 2 all zeros 245 | q = int(input("Enter -1 to finish: ")) 246 | 247 | local_ab = torch.zeros(batch_size, 2, imsize, imsize) # output local hint (ab channel) 248 | local_mask = torch.zeros(batch_size, 1, imsize, imsize).long() # output local hint (mask) 249 | # print(torch.sum(local_ab)) 250 | # print(torch.sum(local_mask)) 251 | for i in range(batch_size): # for all batches and 252 | for j in range(ab_input.size(0)-1): 253 | # print(ab_input.size(0)-1) 254 | # print(ab_input[j+1]) 255 | 256 | low_v = xy_input[j+1][0] - p # lower bound 0 257 | if low_v < 0: 258 | low_v = 0 259 | high_v = xy_input[j+1][0] + p + 1 # higher bound imsize-1 260 | if high_v >= imsize: 261 | high_v = imsize 262 | low_h = xy_input[j+1][1] - p # lower bound 0 263 | if low_h < 0: 264 | low_h = 0 265 | high_h = xy_input[j+1][1] + p + 1 # higher bound imsize-1 266 | if high_h >= imsize: 267 | high_h = imsize 268 | 269 | local_ab[i,0, low_v:high_v, low_h:high_h] = ab_input[j + 1][0] 270 | local_ab[i,1, low_v:high_v, low_h:high_h] = ab_input[j + 1][1] 271 | local_mask[i,:,low_v:high_v, low_h:high_h] = 1 272 | print(len(torch.nonzero(local_ab[i])), len(torch.nonzero(local_mask[i]))) 273 | 274 | return local_ab, local_mask.float() 275 | 276 | def hist_ref(batch, imsize, idx=1): 277 | valdir = './data/sample/hist' 278 | transform = transforms.Compose([ 279 | transforms.Scale((imsize,imsize)), 280 | transforms.ToTensor(), 281 | 282 | ]) 283 | 284 | val_dataset = dsets.ImageFolder(valdir, transform) 285 | 286 | 287 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 288 | batch_size=1, 289 | shuffle=False, 290 | num_workers=2) 291 | 292 | for i, (image, _) in enumerate(val_loader): 293 | if i==(idx-1): 294 | ref_image = image 295 | print('%dth image chosen as reference for histogram'%(idx)) 296 | break 297 | 298 | ref_image = ref_image.numpy().transpose((0, 2, 3, 1)) 299 | img_lab = rgb2lab(ref_image) 300 | img_ab = img_lab[:, :, :, 1:3] 301 | 302 | pick_ref = torch.from_numpy(img_ab.transpose((0, 3, 1, 2))).repeat(batch,1,1,1).float() 303 | 304 | return pick_ref 305 | 306 | def sat_ref(batch, imsize, idx=1): 307 | valdir = './data/sample/sat' 308 | transform = transforms.Compose([ 309 | transforms.Scale((imsize,imsize)), 310 | transforms.ToTensor(), 311 | 312 | ]) 313 | 314 | val_dataset = dsets.ImageFolder(valdir, transform) 315 | 316 | 317 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 318 | batch_size=1, 319 | shuffle=False, 320 | num_workers=2) 321 | 322 | for i, (image, _) in enumerate(val_loader): 323 | if i==(idx-1): 324 | ref_image = image 325 | print('%dth image chosen as reference for saturation'%(idx)) 326 | break 327 | 328 | print(ref_image.size()) 329 | pick_ref = ref_image.repeat(batch, 1, 1, 1).float() 330 | 331 | return pick_ref -------------------------------------------------------------------------------- /deep_color.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import argparse 5 | import numpy as np 6 | import torch.nn as nn 7 | from torch import cuda 8 | from torch.autograd import Variable 9 | 10 | from unet import * 11 | from util import * 12 | from global_hint import * 13 | from data_process import * 14 | 15 | 16 | # Hyper Parameters 17 | 18 | 19 | # arguments parsed when initiating 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--data', type=str, default='cifar', choices=['cifar', 'imagenet', 'celeba', 'mscoco']) 23 | parser.add_argument('--gpu', type=int, default=1) 24 | parser.add_argument('--model_path', type=str, default='./models') 25 | parser.add_argument('--log_path', type=str, default='./logs') 26 | parser.add_argument('--model', type=str, default='unet100.pkl') 27 | parser.add_argument('--image_save', type=str, default='./images') 28 | parser.add_argument('--learning_rate', type=int, default=0.0002) 29 | parser.add_argument('--num_epochs', type=int, default=500) 30 | parser.add_argument('--start_epoch', type=int, default=0) 31 | parser.add_argument('--batch_size', type=int, default=64) 32 | parser.add_argument('--idx', type=int, default=1) 33 | parser.add_argument('--resume', type=bool, default=False, 34 | help='path to latest checkpoint (default: none)') 35 | parser.add_argument('--islocal', type=bool, default=False) 36 | 37 | return parser.parse_args() 38 | 39 | 40 | def main(args): 41 | dataset = args.data 42 | gpu = args.gpu 43 | batch_size = args.batch_size 44 | model_path = args.model_path 45 | log_path = args.log_path 46 | num_epochs = args.num_epochs 47 | learning_rate = args.learning_rate 48 | start_epoch = args.start_epoch 49 | islocal = args.islocal 50 | 51 | # make directory for models saved when there is not. 52 | make_folder(model_path, dataset) # for sampling model 53 | make_folder(log_path, dataset) # for logpoint model 54 | make_folder(log_path, dataset +'/ckpt') # for checkpoint model 55 | 56 | # see if gpu is on 57 | print("Running on gpu : ", gpu) 58 | cuda.set_device(gpu) 59 | 60 | # set the data-loaders 61 | train_dataset, train_loader, val_loader, imsize = Color_Dataloader(dataset, batch_size) 62 | 63 | # declare unet class 64 | unet = UNet(imsize, islocal) 65 | 66 | # make the class run on gpu 67 | unet.cuda() 68 | 69 | # Loss and Optimizer 70 | optimizer = torch.optim.Adam(unet.parameters(), lr=learning_rate) 71 | criterion = torch.nn.SmoothL1Loss() 72 | 73 | # optionally resume from a checkpoint 74 | if args.resume: 75 | ckpt_path = os.path.join(log_path, dataset, 'ckpt/local/model.ckpt') 76 | if os.path.isfile(ckpt_path): 77 | print("=> loading checkpoint") 78 | checkpoint = torch.load(ckpt_path) 79 | start_epoch = checkpoint['epoch'] 80 | unet.load_state_dict(checkpoint['state_dict']) 81 | optimizer.load_state_dict(checkpoint['optimizer']) 82 | print("=> Loaded checkpoint (epoch {})".format(checkpoint['epoch'])) 83 | print("=> Meaning that start training from (epoch {})".format(checkpoint['epoch']+1)) 84 | else: 85 | print("=> Sorry, no checkpoint found at '{}'".format(args.resume)) 86 | 87 | # record time 88 | tell_time = Timer() 89 | iter = 0 90 | # Train the Model 91 | for epoch in range(start_epoch, num_epochs): 92 | 93 | unet.train() 94 | for i, (images, _) in enumerate(train_loader): 95 | 96 | batch = images.size(0) 97 | ''' 98 | additional variables for later use. 99 | change the picture type from rgb to CIE Lab. 100 | def process_data, def process_global in util file 101 | ''' 102 | if islocal: 103 | input, labels, _ = process_data(images, batch, imsize, islocal) 104 | local_ab, local_mask = process_local(labels, batch, imsize) 105 | side_input = torch.cat([local_ab, local_mask], 1) # concat([batch x 2 x imsize x imsize , batch x 1 x imsize x imsize], 1) = batch x 3 x imsize x imsize 106 | random_expose = random.randrange(1, 101) 107 | if random_expose == 100: 108 | print("Jackpot! expose the whole!") 109 | local_mask = torch.ones(batch_size, 1, imsize, imsize) 110 | side_input = torch.cat([labels, local_mask], 1) 111 | else: # if is local 112 | input, labels, ab_for_global = process_data(images, batch, imsize, islocal) 113 | side_input = process_global(images, ab_for_global, batch, imsize, hist_mean=0.03, hist_std=0.13) 114 | 115 | 116 | # make them all variable + gpu avialable 117 | 118 | input = Variable(input).cuda() 119 | labels = Variable(labels).cuda() 120 | side_input = Variable(side_input).cuda() 121 | 122 | # initialize gradients 123 | optimizer.zero_grad() 124 | outputs = unet(input, side_input) 125 | 126 | # make outputs and labels as a matrix for loss calculation 127 | outputs = outputs.view(batch, -1) # 100 x 32*32*3(2048) 128 | labels = labels.contiguous().view(batch, -1) # 100 x 32*32*3 129 | 130 | loss_train = criterion(outputs, labels) 131 | loss_train.backward() 132 | optimizer.step() 133 | 134 | if (i + 1) % 10 == 0: 135 | print('Epoch [%d/%d], Iter [%d/%d], Loss: %.10f, iter_time: %2.2f, aggregate_time: %6.2f' 136 | % (epoch + 1, num_epochs, i + 1, (len(train_dataset) // batch_size), loss_train.data[0], 137 | (tell_time.toc() - iter), tell_time.toc())) 138 | iter = tell_time.toc() 139 | 140 | torch.save(unet.state_dict(), os.path.join(model_path, dataset, 'unet%d.pkl' % (epoch + 1))) 141 | 142 | # start evaluation 143 | print("-------------evaluation start------------") 144 | 145 | unet.eval() 146 | loss_val_all = Variable(torch.zeros(100), volatile=True).cuda() 147 | for i, (images, _) in enumerate(val_loader): 148 | 149 | # change the picture type from rgb to CIE Lab 150 | batch = images.size(0) 151 | 152 | if islocal: 153 | input, labels, _ = process_data(images, batch, imsize, islocal) 154 | local_ab, local_mask = process_local(labels, batch, imsize) 155 | side_input = torch.cat([local_ab, local_mask], 1) 156 | random_expose = random.randrange(1, 101) 157 | if random_expose == 100: 158 | print("Jackpot! expose the whole!") 159 | local_mask = torch.ones(batch_size, 1, imsize, imsize) 160 | side_input = torch.cat([labels, local_mask], 1) 161 | else: # if is local 162 | input, labels, ab_for_global = process_data(images, batch, imsize, islocal) 163 | side_input = process_global(images, ab_for_global, batch, imsize, hist_mean=0.03, hist_std=0.13) 164 | 165 | # make them all variable + gpu avialable 166 | 167 | input = Variable(input).cuda() 168 | labels = Variable(labels).cuda() 169 | side_input = Variable(side_input).cuda() 170 | 171 | # initialize gradients 172 | optimizer.zero_grad() 173 | outputs = unet(input, side_input) 174 | 175 | # make outputs and labels as a matrix for loss calculation 176 | outputs = outputs.view(batch, -1) # 100 x 32*32*3(2048) 177 | labels = labels.contiguous().view(batch, -1) # 100 x 32*32*3 178 | 179 | loss_val = criterion(outputs, labels) 180 | 181 | logpoint = { 182 | 'epoch': epoch + 1, 183 | 'args': args, 184 | } 185 | checkpoint = { 186 | 'epoch': epoch + 1, 187 | 'args': args, 188 | 'state_dict': unet.state_dict(), 189 | 'optimizer': optimizer.state_dict(), 190 | } 191 | 192 | loss_val_all[i] = loss_val 193 | 194 | if i == 30: 195 | print('Epoch [%d/%d], Validation Loss: %.10f' 196 | % (epoch + 1, num_epochs, torch.mean(loss_val_all).data[0])) 197 | torch.save(logpoint, os.path.join(log_path, dataset, 'Model_e%d_train_%.4f_val_%.4f.pt' % 198 | (epoch + 1, torch.mean(loss_train).data[0], 199 | torch.mean(loss_val_all).data[0]))) 200 | torch.save(checkpoint, os.path.join(log_path, dataset, 'ckpt/model.ckpt')) 201 | break 202 | 203 | 204 | if __name__ == '__main__': 205 | args = parse_args() 206 | main(args) -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | if [ $FILE == 'CelebA_FD' ] 4 | then 5 | URL=https://www.dropbox.com/s/e0ig4nf1v94hyj8/CelebA.zip?dl=0 6 | ZIP_FILE=./data/CelebA_FD.zip 7 | elif [ $FILE == 'CelebA' ] 8 | then 9 | URL=https://www.dropbox.com/s/3e5cmqgplchz85o/CelebA_nocrop.zip?dl=0 10 | ZIP_FILE=./data/CelebA.zip 11 | elif [ $FILE == 'LSUN' ] 12 | then 13 | URL=https://www.dropbox.com/s/zt7d2hchrw7cp9p/church_outdoor_train_lmdb.zip?dl=0 14 | ZIP_FILE=./data/church_outdoor_train_lmdb.zip 15 | else 16 | echo "Available datasets are: CelebA, CelebA_FD and LSUN" 17 | exit 1 18 | fi 19 | 20 | mkdir -p ./data/ 21 | wget -N $URL -O $ZIP_FILE 22 | unzip $ZIP_FILE -d ./data/ 23 | 24 | if [ $FILE == 'CelebA' ] 25 | then 26 | mv ./data/CelebA_nocrop ./data/CelebA 27 | elif [ $FILE == 'CelebA_FD' ] 28 | then 29 | mv ./data/CelebA ./data/CelebA_FD 30 | fi 31 | 32 | rm $ZIP_FILE 33 | -------------------------------------------------------------------------------- /global_hint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import sklearn.neighbors as neigh 4 | from skimage.color import rgb2hsv 5 | 6 | from unet import * 7 | import util 8 | 9 | 10 | class Global_Quant(): 11 | ''' Layer which encodes ab map into Q colors 12 | ''' 13 | def __init__(self, batch, imsize): 14 | self.quantization = Quantization(batch, imsize, km_filepath='./data/pts_in_hull.npy') 15 | 16 | def global_histogram(self, input): 17 | out = self.quantization.encode_nn(input) # batch x 313 x imsize x imsize 18 | out = out.type(torch.FloatTensor) # change it to tensor 19 | X_onehotsum = torch.sum(torch.sum(out, dim=3), dim=2) # sum it up to batch x 313 20 | X_hist = torch.div(X_onehotsum, util.expand(torch.sum(X_onehotsum, dim=1).unsqueeze(1), X_onehotsum)) # make 313 probability 21 | return X_hist 22 | 23 | def global_saturation(self, images): # input: tensor images batch x 3 x imsize x imsize (rgb) 24 | images_np = images.numpy().transpose((0, 2, 3, 1)) # numpy: batch x imsize x imsize x 3 25 | images_h = torch.zeros(images.size(0), 1, images.size(2),images.size(2)) 26 | for k in range(images.size(0)): 27 | img_hsv = rgb2hsv(images_np[k]) 28 | img_h = img_hsv[:, :, 1] 29 | images_h[k] = torch.from_numpy(img_h).unsqueeze(0) # batch x 1 x imsize x imsize 30 | avgs = torch.mean(images_h.view(images.size(0), -1),dim=1) # batch x 1 31 | return avgs 32 | 33 | def global_masks(self, batch_size): # both for histogram and saturation 34 | B_hist = torch.round(torch.rand(batch_size, 1)) 35 | B_sat = torch.round(torch.rand(batch_size, 1)) 36 | return B_hist, B_sat 37 | 38 | class Quantization(): 39 | # Encode points as a linear combination of unordered points 40 | # using NN search and RBF kernel 41 | def __init__(self,batch, imsize, km_filepath='./data/pts_in_hull.npy' ): 42 | 43 | self.cc = torch.from_numpy(np.load(km_filepath)).type(torch.FloatTensor) # 313 x 2 44 | self.K = self.cc.shape[0] 45 | self.batch = batch 46 | self.imsize = imsize 47 | 48 | def encode_nn(self,images): # batch x imsize x imsize x 2 49 | 50 | images = images.permute(0,2,3,1) # batch x 2 x imsize x imsize -> batch x imsize x imsize x 2 51 | images_flt = images.contiguous().view(-1, 2) 52 | P = images_flt.shape[0] 53 | inds = self.nearest_inds(images_flt, self.cc).unsqueeze(1) # P x 1 54 | images_encoded = torch.zeros(P,self.K) 55 | images_encoded.scatter_(1, inds, 1) 56 | images_encoded = images_encoded.view(self.batch, self.imsize, self.imsize, 313) 57 | images_encoded = images_encoded.permute(0,3,1,2) 58 | return images_encoded 59 | 60 | def nearest_inds(self, x, y): # x= n x 2, y= 313 x 2 n x 2, 2 x 313 = n x 313 61 | inner = torch.matmul(x, y.t()) 62 | normX = torch.sum(torch.mul(x, x), 1).unsqueeze(1).expand_as(inner) 63 | normY = torch.sum(torch.mul(y, y), 1).unsqueeze(1).t().expand_as(inner) # n x 313 64 | P = normX - 2 * inner + normY 65 | nearest_idx = torch.min(P, dim=1)[1] 66 | return nearest_idx 67 | 68 | 69 | 70 | # def decode_points_mtx_nd(self,pts_enc_nd,axis=1): 71 | # pts_enc_flt = util.flatten_nd_array(pts_enc_nd,axis=axis) 72 | # pts_dec_flt = np.dot(pts_enc_flt,self.cc) 73 | # pts_dec_nd = util.unflatten_2d_array(pts_dec_flt,pts_enc_nd,axis=axis) 74 | # return pts_dec_nd 75 | # 76 | # def decode_1hot_mtx_nd(self,pts_enc_nd,axis=1,returnEncode=False): 77 | # pts_1hot_nd = nd_argmax_1hot(pts_enc_nd,axis=axis) 78 | # pts_dec_nd = self.decode_points_mtx_nd(pts_1hot_nd,axis=axis) 79 | # if(returnEncode): 80 | # return (pts_dec_nd,pts_1hot_nd) 81 | # else: 82 | # return pts_dec_nd -------------------------------------------------------------------------------- /sampling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import torchvision 5 | import numpy as np 6 | import torch.nn as nn 7 | from torch import cuda 8 | from torch.autograd import Variable 9 | from skimage.color import rgb2lab, lab2rgb, rgb2gray 10 | 11 | from unet import * 12 | from util import * 13 | from global_hint import * 14 | from data_process import * 15 | 16 | 17 | 18 | # Hyper Parameters 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--data', type=str, default='cifar', choices=['cifar', 'imagenet', 'celeba', 'mscoco']) 23 | parser.add_argument('--gpu', type=int, default=1) 24 | parser.add_argument('--model_path', type=str, default='./models') 25 | parser.add_argument('--model', type=str, default='unet100.pkl') 26 | parser.add_argument('--image_save', type=str, default='./images') 27 | parser.add_argument('--learning_rate', type=int, default=0.001) 28 | parser.add_argument('--num_epochs', type=int, default=100) 29 | parser.add_argument('--batch_size', type=int, default=64) 30 | parser.add_argument('--idx', type=int, default=1) 31 | parser.add_argument('--global_hist', type=bool, default=False) 32 | parser.add_argument('--global_sat', type=bool, default=False) 33 | parser.add_argument('--hist_ref_idx', type=int, default=1) 34 | parser.add_argument('--sat_ref_idx', type=int, default=1) 35 | parser.add_argument('--islocal', type=bool, default=False) 36 | parser.add_argument('--nohint', type=bool, default=False) 37 | 38 | 39 | return parser.parse_args() 40 | 41 | 42 | 43 | def main(args): 44 | dataset = args.data 45 | gpu = args.gpu 46 | batch_size = args.batch_size 47 | model_path = args.model_path 48 | image_save = args.image_save 49 | model = args.model 50 | idx = args.idx 51 | global_hist = args.global_hist 52 | global_sat = args.global_sat 53 | hist_ref_idx = args.hist_ref_idx 54 | sat_ref_idx = args.hist_ref_idx 55 | islocal = args.islocal 56 | nohint = args.nohint 57 | 58 | make_folder(image_save, dataset) 59 | 60 | print("Running on gpu : ", gpu) 61 | cuda.set_device(gpu) 62 | 63 | _, _, test_loader, imsize = Color_Dataloader(dataset, batch_size) 64 | 65 | unet = UNet(imsize, islocal) 66 | 67 | unet.cuda() 68 | 69 | unet.eval() 70 | unet.load_state_dict(torch.load(os.path.join(model_path, dataset, model))) 71 | 72 | 73 | for i, (images, _) in enumerate(test_loader): 74 | 75 | batch = images.size(0) 76 | ''' 77 | additional variables for later use. 78 | change the picture type from rgb to CIE Lab. 79 | def process_data, def process_global in util file 80 | ''' 81 | if islocal: 82 | input, labels, _ = process_data(images, batch, imsize, islocal) 83 | local_ab, local_mask = process_local_sampling(batch_size, imsize, p=1) 84 | if nohint: 85 | local_ab = torch.zeros(batch_size, 2, imsize, imsize) 86 | local_mask = torch.zeros(batch_size, 1, imsize, imsize) 87 | 88 | side_input = torch.cat([local_ab, local_mask], 1) 89 | 90 | 91 | else: 92 | input, labels, ab_for_global = process_data(images, batch, imsize, islocal) 93 | 94 | print('global hint for histogram : ', global_hist) 95 | print('global hint for saturation : ', global_sat) 96 | 97 | side_input = process_global_sampling(batch, imsize, 0.03, 0.13, 98 | global_hist, global_sat, hist_ref_idx, sat_ref_idx) 99 | 100 | # make them all variable + gpu avialable 101 | 102 | input = Variable(input).cuda() 103 | labels = Variable(labels).cuda() 104 | side_input = Variable(side_input).cuda() 105 | 106 | outputs = unet(input, side_input) 107 | 108 | criterion = torch.nn.SmoothL1Loss() 109 | loss = criterion(outputs, labels) 110 | print('loss for test data: %2.4f'%(loss.cpu().data[0])) 111 | 112 | 113 | colored_images = torch.cat([input,outputs],1).data # 100 x 3 x 32 x 32 114 | gray_images = torch.zeros(batch_size, 3, imsize, imsize) 115 | img_gray =np.zeros((imsize, imsize,3)) 116 | 117 | colored_images_np = colored_images.cpu().numpy().transpose((0,2,3,1)) 118 | 119 | j = 0 120 | # make sample images back to rgb 121 | for img in colored_images_np: 122 | 123 | img[:,:,0] = img[:,:,0]*100 124 | img[:, :, 1:3] = img[:, :, 1:3] * 200 - 100 125 | img = img.astype(np.float64) 126 | img_RGB = lab2rgb(img) 127 | img_gray[:,:,0] = img[:,:,0] 128 | img_gray_RGB = lab2rgb(img_gray) 129 | 130 | colored_images[j] = torch.from_numpy(img_RGB.transpose((2,0,1))) 131 | gray_images[j] = torch.from_numpy(img_gray_RGB.transpose((2,0,1))) 132 | j+=1 133 | 134 | # 135 | torchvision.utils.save_image(images, 136 | os.path.join(image_save, dataset, '{}_real_samples.png'.format(idx))) 137 | torchvision.utils.save_image(colored_images, 138 | os.path.join(image_save, dataset, '{}_colored_samples.png'.format(idx))) 139 | torchvision.utils.save_image(gray_images, 140 | os.path.join(image_save, dataset, '{}_input_samples.png'.format(idx))) 141 | 142 | 143 | print('-----------images sampled!------------') 144 | break 145 | 146 | 147 | if __name__ == '__main__': 148 | args = parse_args() 149 | main(args) -------------------------------------------------------------------------------- /unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class UNetConvBlock1_1(nn.Module): 7 | def __init__(self, in_size, out_size, kernel_size=3): 8 | super(UNetConvBlock1_1, self).__init__() 9 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1) 10 | 11 | def forward(self, x): 12 | out = self.conv(x) 13 | return out 14 | 15 | class UNetConvBlock1_2(nn.Module): 16 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu): 17 | super(UNetConvBlock1_2, self).__init__() 18 | self.conv2 = nn.Conv2d(in_size, out_size, kernel_size, padding=1) 19 | self.activation = activation 20 | self.batchnorm = nn.BatchNorm2d(out_size) 21 | self.conv3 = nn.Conv2d(out_size, out_size, 1, stride=2, groups=out_size, bias=False) 22 | #self.conv3.weight.data.fill_(1) 23 | 24 | def forward(self, x): 25 | out = self.activation(x) 26 | out = self.activation(self.conv2(out)) 27 | out = self.batchnorm(out) 28 | out = self.conv3(out) 29 | return out 30 | 31 | class UNetConvBlock1_2_2(nn.Module): 32 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu): 33 | super(UNetConvBlock1_2_2, self).__init__() 34 | self.conv2 = nn.Conv2d(in_size, out_size, kernel_size, padding=1) 35 | self.activation = activation 36 | self.batchnorm = nn.BatchNorm2d(out_size) 37 | 38 | def forward(self, x): 39 | out = self.activation(x) 40 | out = self.activation(self.conv2(out)) 41 | out = self.batchnorm(out) 42 | return out 43 | 44 | class UNetConvBlock2(nn.Module): 45 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu): 46 | super(UNetConvBlock2, self).__init__() 47 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1) 48 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1) 49 | self.activation = activation 50 | self.batchnorm = nn.BatchNorm2d(out_size) 51 | self.conv3 = nn.Conv2d(out_size, out_size, 1, stride=2, groups=out_size, bias=False) 52 | #self.conv3.weight.data.fill_(1) 53 | 54 | def forward(self, x): 55 | out = self.activation(self.conv(x)) 56 | out = self.activation(self.conv2(out)) 57 | out = self.batchnorm(out) 58 | out = self.conv3(out) 59 | return out 60 | 61 | class UNetConvBlock2_2(nn.Module): 62 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu): 63 | super(UNetConvBlock2_2, self).__init__() 64 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1) 65 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1) 66 | self.activation = activation 67 | self.batchnorm = nn.BatchNorm2d(out_size) 68 | 69 | def forward(self, x): 70 | out = self.activation(self.conv(x)) 71 | out = self.activation(self.conv2(out)) 72 | out = self.batchnorm(out) 73 | return out 74 | 75 | class UNetConvBlock3(nn.Module): 76 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu): 77 | super(UNetConvBlock3, self).__init__() 78 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1) 79 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1) 80 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1) 81 | self.activation = activation 82 | self.batchnorm = nn.BatchNorm2d(out_size) 83 | self.conv4 = nn.Conv2d(out_size, out_size, 1, stride=2, groups=out_size, bias=False) 84 | #self.conv4.weight.data.fill_(1) 85 | 86 | def forward(self, x): 87 | out = self.activation(self.conv(x)) 88 | out = self.activation(self.conv2(out)) 89 | out = self.activation(self.conv3(out)) 90 | out = self.batchnorm(out) 91 | out = self.conv4(out) 92 | return out 93 | 94 | class UNetConvBlock3_2(nn.Module): 95 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu): 96 | super(UNetConvBlock3_2, self).__init__() 97 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1) 98 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1) 99 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1) 100 | self.activation = activation 101 | self.batchnorm = nn.BatchNorm2d(out_size) 102 | 103 | def forward(self, x): 104 | out = self.activation(self.conv(x)) 105 | out = self.activation(self.conv2(out)) 106 | out = self.activation(self.conv3(out)) 107 | out = self.batchnorm(out) 108 | return out 109 | 110 | class UNetConvBlock4(nn.Module): 111 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu): 112 | super(UNetConvBlock4, self).__init__() 113 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1, dilation=1) 114 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1) 115 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1) 116 | self.activation = activation 117 | self.batchnorm = nn.BatchNorm2d(out_size) 118 | 119 | def forward(self, x): 120 | out = self.activation(self.conv(x)) 121 | out = self.activation(self.conv2(out)) 122 | out = self.activation(self.conv3(out)) 123 | out = self.batchnorm(out) 124 | return out 125 | 126 | class UNetConvBlock5(nn.Module): 127 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu): 128 | super(UNetConvBlock5, self).__init__() 129 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=2, dilation=2) 130 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2) 131 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2) 132 | self.activation = activation 133 | self.batchnorm = nn.BatchNorm2d(out_size) 134 | 135 | def forward(self, x): 136 | out = self.activation(self.conv(x)) 137 | out = self.activation(self.conv2(out)) 138 | out = self.activation(self.conv3(out)) 139 | out = self.batchnorm(out) 140 | return out 141 | 142 | class UNetConvBlock6(nn.Module): 143 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu): 144 | super(UNetConvBlock6, self).__init__() 145 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=2, dilation=2) 146 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2) 147 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2) 148 | self.activation = activation 149 | self.batchnorm = nn.BatchNorm2d(out_size) 150 | 151 | def forward(self, x): 152 | out = self.activation(self.conv(x)) 153 | out = self.activation(self.conv2(out)) 154 | out = self.activation(self.conv3(out)) 155 | out = self.batchnorm(out) 156 | return out 157 | 158 | class UNetConvBlock7(nn.Module): 159 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu): 160 | super(UNetConvBlock7, self).__init__() 161 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1, dilation=1) 162 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1) 163 | self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1) 164 | self.activation = activation 165 | self.batchnorm = nn.BatchNorm2d(out_size) 166 | 167 | def forward(self, x): 168 | out = self.activation(self.conv(x)) 169 | out = self.activation(self.conv2(out)) 170 | out = self.activation(self.conv3(out)) 171 | out = self.batchnorm(out) 172 | return out 173 | 174 | class UNetConvBlock8(nn.Module): 175 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False): 176 | super(UNetConvBlock8, self).__init__() 177 | self.up = nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, dilation=1) 178 | self.bridge = nn.Conv2d(256, 256, kernel_size, padding=1) 179 | #self.bridge.weight.data.normal_(0, 0.01) 180 | #self.bridge.bias.data.fill_(1) 181 | self.conv = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1) 182 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1) 183 | self.activation = activation 184 | self.batchnorm = nn.BatchNorm2d(out_size) 185 | # def center_crop(self, layer, target_size): 186 | # batch_size, n_channels, layer_width, layer_height = layer.size() 187 | # xy1 = (layer_width - target_size) // 2 188 | # return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)] 189 | def forward(self, x, bridge): 190 | up = self.up(x) 191 | out = self.activation(self.bridge(bridge) + up) 192 | out = self.activation(self.conv(out)) 193 | out = self.activation(self.conv2(out)) 194 | out = self.batchnorm(out) 195 | return out 196 | 197 | class UNetConvBlock9(nn.Module): 198 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False): 199 | super(UNetConvBlock9, self).__init__() 200 | self.up = nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, dilation=1) 201 | #self.up.weight.data.normal_(0, 0.01) 202 | #self.up.bias.data.fill_(1) 203 | self.bridge = nn.Conv2d(128, 128, kernel_size, padding=1) 204 | #self.bridge.weight.data.normal_(0, 0.01) 205 | #self.bridge.bias.data.fill_(1) 206 | self.conv = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1) 207 | #self.conv.weight.data.normal_(0, 0.01) 208 | #self.conv.bias.data.fill_(1) 209 | self.activation = activation 210 | self.batchnorm = nn.BatchNorm2d(out_size) 211 | 212 | def forward(self, x, bridge): 213 | up = self.up(x) 214 | out = self.activation(self.bridge(bridge) + up) 215 | out = self.activation(self.conv(out)) 216 | out = self.batchnorm(out) 217 | 218 | return out 219 | 220 | class UNetConvBlock10(nn.Module): 221 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False): 222 | super(UNetConvBlock10, self).__init__() 223 | self.up = nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, dilation=1) 224 | #self.up.weight.data.normal_(0, 0.01) 225 | #self.up.bias.data.fill_(1) 226 | self.bridge = nn.Conv2d(64, 128, kernel_size, padding=1) 227 | #self.bridge.weight.data.normal_(0, 0.01) 228 | #self.bridge.bias.data.fill_(1) 229 | self.conv = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1) 230 | #self.conv.weight.data.normal_(0, 0.01) 231 | #self.conv.bias.data.fill_(1) 232 | self.activation = activation 233 | self.activation2 = nn.LeakyReLU(negative_slope=0.02) 234 | 235 | def forward(self, x, bridge): 236 | up = self.up(x) 237 | out = self.activation(self.bridge(bridge) + up) 238 | out = self.activation2(self.conv(out)) 239 | return out 240 | 241 | class prediction(nn.Module): 242 | def __init__(self, in_size, out_size, kernel_size=1, activation=F.tanh, space_dropout=False): 243 | super(prediction, self).__init__() 244 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, dilation=1) 245 | self.activation = activation 246 | 247 | def forward(self, x): 248 | out = self.activation(self.conv(x)) 249 | out = out * 100 250 | return out 251 | 252 | class convrelu(nn.Module): 253 | 254 | def __init__(self, in_size, out_size, kernel_size=1, activation=F.relu, space_dropout=False): 255 | super(convrelu, self).__init__() 256 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=0) 257 | self.activation = activation 258 | 259 | def forward(self, x): 260 | out = self.activation(self.conv(x)) 261 | return out 262 | 263 | class global_network(nn.Module): 264 | def __init__(self, image_size): 265 | super(global_network, self).__init__() 266 | self.oneD = convrelu(316, 512) 267 | self.twoD = convrelu(512, 512) 268 | self.threeD = convrelu(512, 512) 269 | self.fourD = convrelu(512, 512) 270 | self.image_size = image_size 271 | 272 | def forward(self, x): # 4 conv+relu layers with 1 x 1 kernel size with 512 depth, 273 | tmp = self.oneD(x) # made into h/8 x w/8 x 512 # input: 1 x 1 x 313+3 dimension tensor 274 | tmp = self.twoD(tmp) 275 | tmp = self.threeD(tmp) 276 | out = self.fourD(tmp) # batch x 1 x 1 x 512 277 | 278 | out = out.repeat(1,1, int(self.image_size/8), int(self.image_size/8)) 279 | return out 280 | 281 | class local_network(nn.Module): 282 | def __init__(self, in_size, out_size, imsize): 283 | super(local_network, self).__init__() 284 | self.imsize = imsize 285 | self.conv = nn.Conv2d(in_size, out_size, 3, padding=1) 286 | 287 | def forward(self, ab_input): 288 | out=self.conv(ab_input) # depth 64 red feed for the network 289 | return out 290 | 291 | class UNet(nn.Module): 292 | def __init__(self, imsize, islocal): 293 | super(UNet, self).__init__() 294 | self.imsize = imsize 295 | self.islocal = islocal 296 | 297 | if self.islocal==True: 298 | self.localnet = local_network(3, 64, self.imsize) 299 | else: # if local 300 | self.globalnet = global_network(self.imsize) 301 | 302 | self.convlayer1_1 = UNetConvBlock1_1(1, 64) 303 | self.convlayer1_2 = UNetConvBlock1_2(64, 64) 304 | self.convlayer1_2_2 = UNetConvBlock1_2_2(64, 64) 305 | self.convlayer2 = UNetConvBlock2(64, 128) 306 | self.convlayer2_2 = UNetConvBlock2_2(64, 128) 307 | self.convlayer3 = UNetConvBlock3(128, 256) 308 | self.convlayer3_2 = UNetConvBlock3_2(128, 256) 309 | self.convlayer4 = UNetConvBlock4(256, 512) 310 | self.convlayer5 = UNetConvBlock5(512, 512) # Dilated Convolution 311 | self.convlayer6 = UNetConvBlock6(512, 512) # Dilated Convolution 312 | self.convlayer7 = UNetConvBlock7(512, 512) 313 | self.convlayer8 = UNetConvBlock8(512, 256) 314 | self.convlayer9 = UNetConvBlock9(256, 128) 315 | self.convlayer10 = UNetConvBlock10(128, 128) 316 | 317 | self.prediction = prediction(128, 2) 318 | 319 | #self.last = nn.Conv2d(128, 2, 1) 320 | 321 | def forward(self, x, side_input): 322 | layer1_1 = self.convlayer1_1(x) 323 | 324 | if self.islocal == True: 325 | local_net = self.localnet(side_input) 326 | layer1_1 = layer1_1 + local_net 327 | 328 | layer1_2 = self.convlayer1_2(layer1_1) 329 | layer1_2_2 = self.convlayer1_2_2(layer1_1) 330 | layer2 = self.convlayer2(layer1_2) 331 | layer2_2 = self.convlayer2_2(layer1_2) 332 | layer3 = self.convlayer3(layer2) 333 | layer3_2 = self.convlayer3_2(layer2) 334 | layer4 = self.convlayer4(layer3) 335 | 336 | if self.islocal == False: 337 | global_net = self.globalnet(side_input) 338 | layer4 = layer4 + global_net 339 | 340 | layer5 = self.convlayer5(layer4) 341 | layer6 = self.convlayer6(layer5) 342 | layer7 = self.convlayer7(layer6) 343 | layer8 = self.convlayer8(layer7, layer3_2) 344 | layer9 = self.convlayer9(layer8, layer2_2) 345 | layer10 = self.convlayer10(layer9, layer1_2_2) 346 | 347 | prediction = self.prediction(layer10) 348 | 349 | return prediction -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import datetime 5 | import torch 6 | 7 | 8 | def check_value(inds, val): 9 | # Check to see if an array is a single element equaling a particular value 10 | # Good for pre-processing inputs in a function 11 | if (np.array(inds).size == 1): 12 | if (inds == val): 13 | return True 14 | return False 15 | 16 | 17 | def flatten_nd_array(pts_nd, axis=1): 18 | # Flatten an nd array into a 2d array with a certain axis 19 | # INPUTS 20 | # pts_nd N0xN1x...xNd array 21 | # axis integer 22 | # OUTPUTS 23 | # pts_flt prod(N \ N_axis) x N_axis array 24 | NDIM = pts_nd.ndim 25 | SHP = np.array(pts_nd.shape) 26 | nax = np.setdiff1d(np.arange(0, NDIM), np.array((axis))) # non axis indices 27 | NPTS = np.prod(SHP[nax]) 28 | axorder = np.concatenate((nax, np.array(axis).flatten()), axis=0) 29 | pts_flt = pts_nd.transpose((axorder)) 30 | pts_flt = pts_flt.reshape(NPTS, SHP[axis]) 31 | return pts_flt 32 | 33 | 34 | def unflatten_2d_array(pts_flt, pts_nd, axis=1, squeeze=False): 35 | # Unflatten a 2d array with a certain axis 36 | # INPUTS 37 | # pts_flt prod(N \ N_axis) x M array 38 | # pts_nd N0xN1x...xNd array 39 | # axis integer 40 | # squeeze bool if true, M=1, squeeze it out 41 | # OUTPUTS 42 | # pts_out N0xN1x...xNd array 43 | NDIM = pts_nd.ndim 44 | SHP = np.array(pts_nd.shape) 45 | nax = np.setdiff1d(np.arange(0, NDIM), np.array((axis))) # non axis indices 46 | NPTS = np.prod(SHP[nax]) 47 | 48 | if (squeeze): 49 | axorder = nax 50 | axorder_rev = np.argsort(axorder) 51 | M = pts_flt.shape[1] 52 | NEW_SHP = SHP[nax].tolist() 53 | pts_out = pts_flt.reshape(NEW_SHP) 54 | pts_out = pts_out.transpose(axorder_rev) 55 | else: 56 | axorder = np.concatenate((nax, np.array(axis).flatten()), axis=0) 57 | axorder_rev = np.argsort(axorder) 58 | M = pts_flt.shape[1] 59 | NEW_SHP = SHP[nax].tolist() 60 | NEW_SHP.append(M) 61 | pts_out = pts_flt.reshape(NEW_SHP) 62 | pts_out = pts_out.transpose(axorder_rev) 63 | 64 | return pts_out 65 | 66 | 67 | def na(): 68 | return np.newaxis 69 | 70 | 71 | class Timer(): 72 | def __init__(self): 73 | self.cur_t = time.time() 74 | 75 | def tic(self): 76 | self.cur_t = time.time() 77 | 78 | def toc(self): 79 | return time.time() - self.cur_t 80 | 81 | def tocStr(self, t=-1): 82 | if (t == -1): 83 | return str(datetime.timedelta(seconds=np.round(time.time() - self.cur_t, 3)))[:-4] 84 | else: 85 | return str(datetime.timedelta(seconds=np.round(t, 3)))[:-4] 86 | 87 | def distribution(tensor): 88 | 89 | tensor = torch.div(tensor, expand(tensor.sum(dim=1).unsqueeze(-1), tensor)) 90 | if (tensor.sum(dim=1).data.cpu().numpy()==0).any(): 91 | print ("") 92 | print ("") 93 | print ("division by zero") 94 | print ("") 95 | print ("") 96 | return tensor.unsqueeze(-1) 97 | 98 | def expand(tensor, target): 99 | return tensor.expand_as(target) 100 | 101 | 102 | def make_folder(path, dataset): 103 | try: 104 | os.makedirs(os.path.join(path, dataset)) 105 | except OSError: 106 | pass 107 | --------------------------------------------------------------------------------