├── Cascaded_Network_LM_256.py
├── Label256Full
│   └── aachen_000000_000019_gtFine_color.png
├── Label256Fullval
│   ├── frankfurt_000000_000294_gtFine_color.png
│   ├── frankfurt_000000_000576_gtFine_color.png
│   ├── frankfurt_000000_011810_gtFine_color.png
│   └── frankfurt_000001_007285_gtFine_color.png
├── README.md
├── RGB256Full
│   └── aachen_000000_000019_leftImg8bit.png
├── RGB256FullVal
│   └── frankfurt_000000_000294_leftImg8bit.png
├── Screenshot_from.png
├── datasets
│   └── cityscapes.json
├── helper.py
└── val.jpg

/Cascaded_Network_LM_256.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Dec 13 11:45:23 2017

@author: soumya
"""

from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
import os, scipy.io
import torch.utils.model_zoo as model_zoo
import helper
from skimage import io
from random import shuffle
import scipy.misc


plt.ion()   # interactive mode


__all__ = ['VGG19', 'vgg19']


model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}

class LayerNorm(nn.Module):

    def __init__(self, num_features, eps=1e-12, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = nn.Parameter(torch.ones(num_features))
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # normalize each sample over all of its non-batch dimensions
        shape = [-1] + [1] * (x.dim() - 1)
        mean = x.view(x.size(0), -1).mean(1).view(*shape)
        std = x.view(x.size(0), -1).std(1).view(*shape)

        y = (x - mean) / (std + self.eps)
        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)
            y = self.gamma.view(*shape) * y + self.beta.view(*shape)
        return y

class cascaded_model(nn.Module):

    def __init__(self, D_m):
        super(cascaded_model, self).__init__()
        # Layer 1
        self.conv1 = nn.Conv2d(20, D_m[1], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv1.weight, gain=1)
        nn.init.constant(self.conv1.bias, 0)
        self.lay1 = LayerNorm(D_m[1], eps=1e-12, affine=True)
        self.relu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.conv11 = nn.Conv2d(D_m[1], D_m[1], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv11.weight, gain=1)
        nn.init.constant(self.conv11.bias, 0)
        self.lay11 = LayerNorm(D_m[1], eps=1e-12, affine=True)
        self.relu11 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Layer 2
        self.conv2 = nn.Conv2d(D_m[1]+20, D_m[2], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv2.weight, gain=1)
        nn.init.constant(self.conv2.bias, 0)
        self.lay2 = LayerNorm(D_m[2], eps=1e-12, affine=True)
        self.relu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.conv22 = nn.Conv2d(D_m[2], D_m[2], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv22.weight, gain=1)
        nn.init.constant(self.conv22.bias, 0)
        self.lay22 = LayerNorm(D_m[2], eps=1e-12, affine=True)
        self.relu22 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
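        # Each refinement level below repeats this pattern: two (3x3 conv ->
        # LayerNorm -> LeakyReLU) stacks. The first conv of a level takes the
        # previous level's D_m[k] feature maps concatenated with the 20
        # semantic-label channels (19 Cityscapes classes plus the catch-all
        # channel added in training()/testing()), hence in_channels = D_m[k]+20.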

        # Layer 3
        self.conv3 = nn.Conv2d(D_m[2]+20, D_m[3], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv3.weight, gain=1)
        nn.init.constant(self.conv3.bias, 0)
        self.lay3 = LayerNorm(D_m[3], eps=1e-12, affine=True)
        self.relu3 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.conv33 = nn.Conv2d(D_m[3], D_m[3], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv33.weight, gain=1)
        nn.init.constant(self.conv33.bias, 0)
        self.lay33 = LayerNorm(D_m[3], eps=1e-12, affine=True)
        self.relu33 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Layer 4
        self.conv4 = nn.Conv2d(D_m[3]+20, D_m[4], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv4.weight, gain=1)
        nn.init.constant(self.conv4.bias, 0)
        self.lay4 = LayerNorm(D_m[4], eps=1e-12, affine=True)
        self.relu4 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.conv44 = nn.Conv2d(D_m[4], D_m[4], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv44.weight, gain=1)
        nn.init.constant(self.conv44.bias, 0)
        self.lay44 = LayerNorm(D_m[4], eps=1e-12, affine=True)
        self.relu44 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Layer 5
        self.conv5 = nn.Conv2d(D_m[4]+20, D_m[5], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv5.weight, gain=1)
        nn.init.constant(self.conv5.bias, 0)
        self.lay5 = LayerNorm(D_m[5], eps=1e-12, affine=True)
        self.relu5 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.conv55 = nn.Conv2d(D_m[5], D_m[5], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv55.weight, gain=1)
        nn.init.constant(self.conv55.bias, 0)
        self.lay55 = LayerNorm(D_m[5], eps=1e-12, affine=True)
        self.relu55 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Layer 6
        self.conv6 = nn.Conv2d(D_m[5]+20, D_m[6], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv6.weight, gain=1)
        nn.init.constant(self.conv6.bias, 0)
        self.lay6 = LayerNorm(D_m[6], eps=1e-12, affine=True)
        self.relu6 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.conv66 = nn.Conv2d(D_m[6], D_m[6], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv66.weight, gain=1)
        nn.init.constant(self.conv66.bias, 0)
        self.lay66 = LayerNorm(D_m[6], eps=1e-12, affine=True)
        self.relu66 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Layer 7
        self.conv7 = nn.Conv2d(D_m[6]+20, D_m[6], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv7.weight, gain=1)
        nn.init.constant(self.conv7.bias, 0)
        self.lay7 = LayerNorm(D_m[6], eps=1e-12, affine=True)
        self.relu7 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.conv77 = nn.Conv2d(D_m[6], D_m[6], kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.xavier_uniform(self.conv77.weight, gain=1)
        nn.init.constant(self.conv77.bias, 0)
        self.lay77 = LayerNorm(D_m[6], eps=1e-12, affine=True)
        self.relu77 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Final 1x1 projection to 27 channels (9 candidate RGB images)
        self.conv8 = nn.Conv2d(D_m[6], 27, kernel_size=1, stride=1, padding=0, bias=True)
        nn.init.xavier_uniform(self.conv8.weight, gain=1)
        nn.init.constant(self.conv8.bias, 0)
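    # forward() runs the cascade coarse-to-fine: it starts from the 4x8 label
    # map D[1] (recursive_img stores the 4x8 map twice, as D[0] and D[1]; D[0]
    # is unused), and at every level bilinearly upsamples the features 2x and
    # concatenates the label pyramid at the matching resolution, ending with
    # the full-resolution label. The final 1x1 conv emits 27 channels, i.e. 9
    # candidate RGB images for the diversity loss, and (out+1)/2*255 rescales
    # them to [0, 255].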
    def forward(self, D, label):

        out1 = self.conv1(D[1])
        L1 = self.lay1(out1)
        out2 = self.relu1(L1)

        out11 = self.conv11(out2)
        L11 = self.lay11(out11)
        out22 = self.relu11(L11)

        # upsample to the next level's size; a level's width (size(3)) equals
        # the next level's height, since every map is twice as wide as it is tall
        m = nn.Upsample(size=(D[1].size(3), D[1].size(3)*2), mode='bilinear')

        img1 = torch.cat((m(out22), D[2]), 1)

        out3 = self.conv2(img1)
        L2 = self.lay2(out3)
        out4 = self.relu2(L2)

        out33 = self.conv22(out4)
        L22 = self.lay22(out33)
        out44 = self.relu22(L22)

        m = nn.Upsample(size=(D[2].size(3), D[2].size(3)*2), mode='bilinear')

        img2 = torch.cat((m(out44), D[3]), 1)

        out5 = self.conv3(img2)
        L3 = self.lay3(out5)
        out6 = self.relu3(L3)

        out55 = self.conv33(out6)
        L33 = self.lay33(out55)
        out66 = self.relu33(L33)

        m = nn.Upsample(size=(D[3].size(3), D[3].size(3)*2), mode='bilinear')

        img3 = torch.cat((m(out66), D[4]), 1)

        out7 = self.conv4(img3)
        L4 = self.lay4(out7)
        out8 = self.relu4(L4)

        out77 = self.conv44(out8)
        L44 = self.lay44(out77)
        out88 = self.relu44(L44)

        m = nn.Upsample(size=(D[4].size(3), D[4].size(3)*2), mode='bilinear')

        img4 = torch.cat((m(out88), D[5]), 1)

        out9 = self.conv5(img4)
        L5 = self.lay5(out9)
        out10 = self.relu5(L5)

        out99 = self.conv55(out10)
        L55 = self.lay55(out99)
        out110 = self.relu55(L55)

        m = nn.Upsample(size=(D[5].size(3), D[5].size(3)*2), mode='bilinear')

        img5 = torch.cat((m(out110), D[6]), 1)

        out11 = self.conv6(img5)
        L6 = self.lay6(out11)
        out12 = self.relu6(L6)

        out111 = self.conv66(out12)
        L66 = self.lay66(out111)
        out112 = self.relu66(L66)

        m = nn.Upsample(size=(D[6].size(3), D[6].size(3)*2), mode='bilinear')

        img6 = torch.cat((m(out112), label), 1)

        out13 = self.conv7(img6)
        L7 = self.lay7(out13)
        out14 = self.relu7(L7)

        out113 = self.conv77(out14)
        L77 = self.lay77(out113)
        out114 = self.relu77(L77)

        out15 = self.conv8(out114)

        out15 = (out15+1.0)/2.0*255.0

        # rearrange the 27 output maps into a batch of 9 RGB images
        out16, out17, out18 = torch.chunk(out15.permute(1,0,2,3), 3, 0)
        out = torch.cat((out16, out17, out18), 1)

        return out

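
# The VGG19 below is only a fixed feature extractor for the perceptual loss;
# its convolutions are filled from the MatConvNet 'imagenet-vgg-verydeep-19.mat'
# release further down. Note the pooling layers are AvgPool2d rather than
# VGG's usual MaxPool2d, presumably for smoother loss gradients.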

class VGG19(nn.Module):

    def __init__(self):
        super(VGG19, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.max1 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True)
        self.relu3 = nn.ReLU(inplace=True)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True)
        self.relu4 = nn.ReLU(inplace=True)
        self.max2 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True)
        self.relu5 = nn.ReLU(inplace=True)

        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
        self.relu6 = nn.ReLU(inplace=True)

        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
        self.relu7 = nn.ReLU(inplace=True)

        self.conv8 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
        self.relu8 = nn.ReLU(inplace=True)
        self.max3 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv9 = nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=True)
        self.relu9 = nn.ReLU(inplace=True)

        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu10 = nn.ReLU(inplace=True)

        self.conv11 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu11 = nn.ReLU(inplace=True)

        self.conv12 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu12 = nn.ReLU(inplace=True)
        self.max4 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv13 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu13 = nn.ReLU(inplace=True)

        self.conv14 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu14 = nn.ReLU(inplace=True)

        self.conv15 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu15 = nn.ReLU(inplace=True)

        self.conv16 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu16 = nn.ReLU(inplace=True)
        self.max5 = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):

        out1 = self.conv1(x)
        out2 = self.relu1(out1)

        out3 = self.conv2(out2)
        out4 = self.relu2(out3)
        out5 = self.max1(out4)

        out6 = self.conv3(out5)
        out7 = self.relu3(out6)
        out8 = self.conv4(out7)
        out9 = self.relu4(out8)
        out10 = self.max2(out9)
        out11 = self.conv5(out10)
        out12 = self.relu5(out11)
        out13 = self.conv6(out12)
        out14 = self.relu6(out13)
        out15 = self.conv7(out14)
        out16 = self.relu7(out15)
        out17 = self.conv8(out16)
        out18 = self.relu8(out17)
        out19 = self.max3(out18)
        out20 = self.conv9(out19)
        out21 = self.relu9(out20)
        out22 = self.conv10(out21)
        out23 = self.relu10(out22)
        out24 = self.conv11(out23)
        out25 = self.relu11(out24)
        out26 = self.conv12(out25)
        out27 = self.relu12(out26)
        out28 = self.max4(out27)
        out29 = self.conv13(out28)
        out30 = self.relu13(out29)
        out31 = self.conv14(out30)
        out32 = self.relu14(out31)
        out33 = self.conv15(out32)
        out34 = self.relu15(out33)
        out35 = self.conv16(out34)
        out36 = self.relu16(out35)
        out37 = self.max5(out36)
        # relu1_2, relu2_2, relu3_2, relu4_2 and relu5_2 features for the
        # perceptual loss (the trailing relu2_1 output is unused by the loss)
        return out4, out9, out14, out23, out32, out7


def vggnet(pretrained=False, model_root=None, **kwargs):
    model = VGG19(**kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['alexnet'], model_root))
    return model

Net = vggnet(pretrained=False, model_root=None)

Net = Net.cuda()

vgg_rawnet = scipy.io.loadmat('imagenet-vgg-verydeep-19.mat')

vgg_layers = vgg_rawnet['layers'][0]
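
# In the MatConvNet export, vgg_layers[i][0][0][2][0][0] stores each kernel as
# (H, W, in_ch, out_ch) -- hence the permute(3, 2, 0, 1) below to reach
# PyTorch's (out_ch, in_ch, H, W) layout -- and vgg_layers[i][0][0][2][0][1]
# stores the matching bias vector.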

# Initialize the conv layers from the pretrained VGG "very deep 19" weights

layers = [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34]

att = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'conv6', 'conv7', 'conv8',
       'conv9', 'conv10', 'conv11', 'conv12', 'conv13', 'conv14', 'conv15', 'conv16']

S = [64, 64, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512, 512, 512]

for L in range(16):
    getattr(Net, att[L]).weight = nn.Parameter(torch.from_numpy(vgg_layers[layers[L]][0][0][2][0][0]).permute(3,2,0,1).cuda())
    getattr(Net, att[L]).bias = nn.Parameter(torch.from_numpy(vgg_layers[layers[L]][0][0][2][0][1]).view(S[L]).cuda())

# The pretrained VGG19 network is now ready.

# The Cascaded Refinement Network starts here.
global D_m
global D
global count
D = []
D_m = []
count = 0

def recursive_img(label, res):   # res starts at the output resolution (256 for 256x512) and is halved each recursion
    dim = 512 if res >= 128 else 1024   # feature width for this level
    # the label pyramid runs from 4x8 up to res x 2*res
    if res == 4:
        downsampled = label
    else:
        max1 = nn.AvgPool2d(kernel_size=2, padding=0, stride=2)
        downsampled = max1(label)
        img = recursive_img(downsampled, res//2)

    global D
    global count
    global D_m
    D.insert(count, downsampled)
    D_m.insert(count, dim)
    count += 1
    return downsampled

# Loss function goes here

def compute_error(R, F, label_images):
    # per-candidate, per-semantic-channel L1 distance, masked by the label map
    E = torch.mean(torch.mean(label_images * torch.mean(torch.abs(R-F), 1).unsqueeze(1), 2), 2)
    return E

def loss_function(real, generator, label_images, D):

    aa = np.array([123.6800, 116.7790, 103.9390]).reshape((1,1,1,3))   # VGG mean (RGB order)
    bb = Variable(torch.from_numpy(aa).float().permute(0,3,1,2).cuda())
    out3_r, out8_r, out13_r, out22_r, out33_r, out7r = Net(real-bb)
    out3_f, out8_f, out13_f, out22_f, out33_f, out7f = Net(generator-bb)

    E0 = compute_error(real-bb, generator-bb, label_images)
    E1 = compute_error(out3_r, out3_f, label_images)/1.6
    E2 = compute_error(out8_r, out8_f, D[6])/2.3
    E3 = compute_error(out13_r, out13_f, D[5])/1.8
    E4 = compute_error(out22_r, out22_f, D[4])/2.8
    E5 = compute_error(out33_r, out33_f, D[3])*10/0.8
    Total_loss = E0+E1+E2+E3+E4+E5
    aa = torch.min(Total_loss, 0)
    G_loss = torch.sum(aa[0])*0.999 + torch.sum(torch.mean(Total_loss, 0))*0.001
    return G_loss
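
# A note on loss_function (a reading of the code, not part of the original
# comments): Total_loss has shape (9, 20) -- one masked L1 feature loss per
# generated candidate (9) and per semantic channel (20), with empirical
# per-layer weights (1/1.6, 1/2.3, ...). torch.min(Total_loss, 0) keeps, for
# each semantic channel, the candidate that renders it best -- the "diversity
# loss" of the CRN paper -- plus a small 0.001 * mean term so that all
# candidates keep receiving gradient.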

def training(M):
    res = 256

    label_dir = 'Label256Full'
    l = os.listdir(label_dir)

    for epoch in range(200):
        running_loss = 0
        c_t = 0
        for I in enumerate(l):
            c_t += 1
            global D_m
            global D
            global count
            D = []
            D_m = []
            count = 0
            J = str.replace(I[1], 'gtFine_color.png', 'leftImg8bit.png')

            label_images1 = Variable(torch.unsqueeze(torch.from_numpy(helper.get_semantic_map('Label256Full/'+I[1])).float().permute(2,0,1), dim=0))   # training label
            input_images = Variable(torch.unsqueeze(torch.from_numpy(io.imread("RGB256Full/"+J)).float(), dim=0).permute(0,3,1,2))
            label_images = torch.cat((label_images1, (1-label_images1.sum(1)).unsqueeze(1)), 1)   # append the catch-all 20th channel
            input_images = input_images.cuda()
            label_images = label_images.cuda()
            G_temp = recursive_img(label_images, res)
            if M == 0:
                model = cascaded_model(D_m)
                model = model.cuda()
                # model.load_state_dict(torch.load('mynet_updated.pth'))   # to resume training from a pretrained model, load its .pth file here
                optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

            optimizer.zero_grad()
            Generator = model(D, label_images)
            Loss = loss_function(input_images, Generator, label_images, D)

            Loss.backward()
            optimizer.step()
            M = 1
            running_loss += Loss.data[0]
            print(epoch, c_t, Loss.data[0])
            del D
            del D_m
            del count
            del Loss, label_images, G_temp, input_images
        shuffle(l)
        epoch_loss = running_loss / 2975.0   # replace 2975 with c_t to generalize
        print(epoch, epoch_loss)
        if epoch % 2 == 0:
            Generator = Generator.permute(0,2,3,1)
            Generator = Generator.cpu()
            Generator = Generator.data.numpy()
            output = np.minimum(np.maximum(Generator, 0.0), 255.0)
            scipy.misc.toimage(output[0,:,:,:], cmin=0, cmax=255).save("%06d_output_real.jpg" % epoch)

    best_model_wts = model.state_dict()
    model.load_state_dict(best_model_wts)

    return model

def testing(seman_in):
    label_images1 = Variable(torch.unsqueeze(torch.from_numpy(helper.get_semantic_map(seman_in)).float().permute(2,0,1), dim=0))
    global D_m
    global D
    global count
    D = []
    D_m = []
    count = 0

    label_images = torch.cat((label_images1, (1-label_images1.sum(1)).unsqueeze(1)), 1)
    res = 256
    G_temp = recursive_img(label_images, res)
    model = cascaded_model(D_m)
    model = model.cuda()
    model.load_state_dict(torch.load('mynet_200epoch_CRN.pth'))
    model = model.cpu().eval()
    G = model(D, label_images)
    Generator = G.permute(0,2,3,1)
    Generator = Generator.data.numpy()
    output = np.minimum(np.maximum(Generator, 0.0), 255.0)
    # of the 9 synthesized candidates, save the third one
    scipy.misc.toimage(output[2,:,:,:], cmin=0, cmax=255).save("val3.jpg")


mode = 'test'

if mode == 'train':
    M = 0
    model_ft = training(M)
    torch.save(model_ft.state_dict(), 'mynet_200epoch_CRN.pth')
else:
    file_name = '/home/soumya/Documents/cascaded_code_for_cluster/Label256Fullval/frankfurt_000000_000294_gtFine_color.png'
    testing(file_name)
--------------------------------------------------------------------------------
/Label256Full/aachen_000000_000019_gtFine_color.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/Blade6570/PhotographicImageSynthesiswithCascadedRefinementNetworks-Pytorch/fb42106a69d461ed9b2ed943a44fd687d86ed6f5/Label256Full/aachen_000000_000019_gtFine_color.png
--------------------------------------------------------------------------------
/Label256Fullval/frankfurt_000000_000294_gtFine_color.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/Blade6570/PhotographicImageSynthesiswithCascadedRefinementNetworks-Pytorch/fb42106a69d461ed9b2ed943a44fd687d86ed6f5/Label256Fullval/frankfurt_000000_000294_gtFine_color.png
--------------------------------------------------------------------------------
/Label256Fullval/frankfurt_000000_000576_gtFine_color.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/Blade6570/PhotographicImageSynthesiswithCascadedRefinementNetworks-Pytorch/fb42106a69d461ed9b2ed943a44fd687d86ed6f5/Label256Fullval/frankfurt_000000_000576_gtFine_color.png
--------------------------------------------------------------------------------
/Label256Fullval/frankfurt_000000_011810_gtFine_color.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/Blade6570/PhotographicImageSynthesiswithCascadedRefinementNetworks-Pytorch/fb42106a69d461ed9b2ed943a44fd687d86ed6f5/Label256Fullval/frankfurt_000000_011810_gtFine_color.png
--------------------------------------------------------------------------------
/Label256Fullval/frankfurt_000001_007285_gtFine_color.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/Blade6570/PhotographicImageSynthesiswithCascadedRefinementNetworks-Pytorch/fb42106a69d461ed9b2ed943a44fd687d86ed6f5/Label256Fullval/frankfurt_000001_007285_gtFine_color.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Photographic Image Synthesis with Cascaded Refinement Networks-Pytorch (https://arxiv.org/abs/1707.09405)
This is a Pytorch implementation of cascaded refinement networks for synthesizing photographic images from semantic layouts. The pretrained model and the code for training the network from scratch are available for 256x512 resolution. Thanks to [Qifeng Chen](https://github.com/CQFIO) for his TensorFlow implementation, which helped a lot in developing this Pytorch version.
![Output](https://github.com/Blade6570/Photographic-Image-Synthesis-with-Cascaded-Refinement-Networks--Pytorch-/blob/master/Screenshot_from.png?raw=true "Comparison with the original TensorFlow version")

**Testing**
1. Download this package and keep all the subsequently mentioned files in the same folder.
2. Download the pretrained VGG19 net: [VGG19](https://drive.google.com/open?id=1wkMhYoRdjZ7LC1OeTOIdzf5YcxNvR8vs)
3. Download the pretrained weights of the CRN network for 256x512: [CRN](https://drive.google.com/open?id=1WHPMDLkRvQMKRoHhV8-tqFhZgmOfoA3p)
4. Keep *mode='test'* and set the name of the semantic image to be tested in *Cascaded_Network_LM_256.py*.
5. The synthesized images will be saved in the current folder.

**Training**
1. Follow steps *1 to 3* of the testing instructions.
2. Resize all training images to 256x512. Keep the semantic segmentation training images in the *Label256Full* folder and the RGB training images in *RGB256Full* (without any subfolders).
3. Set *mode='train'* in *Cascaded_Network_LM_256.py* and run it for the desired number of epochs (default is 200).

**Future Work**
1. Pretrained weights for resolutions *512x1024* and *1024x2048* will be available soon, along with training scripts.

**Note**
1. All the code is written to run on GPU; suitable changes are needed to run it on CPU. Also feel free to customize it to your needs.
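
For CPU-only testing, a minimal sketch of the required changes (assuming the same checkpoint file): drop the `.cuda()` calls and map the stored weights onto the CPU when loading, after building the label pyramid with `recursive_img` as in `testing()`:

```python
model = cascaded_model(D_m)
model.load_state_dict(torch.load('mynet_200epoch_CRN.pth',
                                 map_location=lambda storage, loc: storage))
model = model.eval()
```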

--------------------------------------------------------------------------------
/RGB256Full/aachen_000000_000019_leftImg8bit.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/Blade6570/PhotographicImageSynthesiswithCascadedRefinementNetworks-Pytorch/fb42106a69d461ed9b2ed943a44fd687d86ed6f5/RGB256Full/aachen_000000_000019_leftImg8bit.png
--------------------------------------------------------------------------------
/RGB256FullVal/frankfurt_000000_000294_leftImg8bit.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/Blade6570/PhotographicImageSynthesiswithCascadedRefinementNetworks-Pytorch/fb42106a69d461ed9b2ed943a44fd687d86ed6f5/RGB256FullVal/frankfurt_000000_000294_leftImg8bit.png
--------------------------------------------------------------------------------
/Screenshot_from.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/Blade6570/PhotographicImageSynthesiswithCascadedRefinementNetworks-Pytorch/fb42106a69d461ed9b2ed943a44fd687d86ed6f5/Screenshot_from.png
--------------------------------------------------------------------------------
/datasets/cityscapes.json:
--------------------------------------------------------------------------------
{"palette":[[128,64,128],[244,35,232],[70,70,70],[102,102,156],[190,153,153],[153,153,153],[250,170,30],[220,220,0],[107,142,35],[152,251,152],[70,130,180],[220,20,60],[255,0,0],[0,0,142],[0,0,70],[0,60,100],[0,80,100],[0,0,230],[119,11,32]],"mean":[72.39,82.91,73.16],"dilation":10, "zoom":1}
--------------------------------------------------------------------------------
/helper.py:
--------------------------------------------------------------------------------
import os, numpy as np
from os.path import dirname, exists, join, splitext
import json, scipy, scipy.misc

class Dataset(object):
    def __init__(self, dataset_name):
        self.work_dir = dirname(os.path.realpath('__file__'))
        info_path = join(self.work_dir, 'datasets', dataset_name + '.json')
        with open(info_path, 'r') as fp:
            info = json.load(fp)
        self.palette = np.array(info['palette'], dtype=np.uint8)


def get_semantic_map(path):
    # map a color-coded semantic image to a one-hot (H, W, 19) array,
    # one channel per Cityscapes palette class
    dataset = Dataset('cityscapes')
    semantic = scipy.misc.imread(path)
    tmp = np.zeros((semantic.shape[0], semantic.shape[1], dataset.palette.shape[0]), dtype=np.float32)
    for k in range(dataset.palette.shape[0]):
        tmp[:,:,k] = np.float32((semantic[:,:,0]==dataset.palette[k,0])&(semantic[:,:,1]==dataset.palette[k,1])&(semantic[:,:,2]==dataset.palette[k,2]))
    return tmp

def print_semantic_map(semantic, path):
    dataset = Dataset('cityscapes')
    semantic = semantic.transpose([1,2,3,0])
    prediction = np.argmax(semantic, axis=2)
    color_image = dataset.palette[prediction.ravel()].reshape((prediction.shape[0], prediction.shape[1], 3))
    row, col, dump = np.where(np.sum(semantic, axis=2)==0)
    color_image[row, col, :] = 0
    scipy.misc.imsave(path, color_image)
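
# A minimal usage sketch (hypothetical paths; assumes the label images are
# already resized to 256x512):
#   sem = get_semantic_map('Label256Full/aachen_000000_000019_gtFine_color.png')
#   sem.shape  # -> (256, 512, 19); the training script permutes this to
#              # (1, 19, H, W) and appends a 20th catch-all channel.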
--------------------------------------------------------------------------------
/val.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/Blade6570/PhotographicImageSynthesiswithCascadedRefinementNetworks-Pytorch/fb42106a69d461ed9b2ed943a44fd687d86ed6f5/val.jpg
--------------------------------------------------------------------------------