├── .gitignore
├── README.md
├── images
│   └── model_arch.png
├── pix_dataloader.py
├── pix_network_1.py
├── pix_network_2.py
├── pix_train.py
└── pix_util.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PixColor PyTorch implementation

paper link: [here](https://arxiv.org/abs/1705.07208)

PixColor is a state-of-the-art colorization method. Given a single black-and-white input image, it can produce multiple plausible colorized versions.
The two main networks are trained separately. As the figure below suggests, one drawback is that the model is fairly heavy: the authors trained it with the aid of 8(!) GPUs.

**Note:**
This is not a complete implementation. The coloring network (PixelCNN) still needs to be added.

![network architecture](images/model_arch.png)

* There are four main networks in the architecture:

**pix_network_1.py**
1. Conditioning network:
The conditioning network is pretrained on COCO image segmentation.

2. Adaptation network:
Together, the conditioning and adaptation networks turn the brightness channel Y into a set of features used to condition the PixelCNN.

3. Coloring network (PixelCNN):
The PixelCNN is optimized jointly with the conditioning and adaptation networks. It predicts a low-resolution chrominance map for the image.

**pix_network_2.py**

4. Refinement network:
The low-resolution color image produced by the previous networks is fed into the refinement network, which produces the full-resolution colorization.
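As a rough sketch of how the two stages are meant to chain at inference time (hypothetical usage, since the PixelCNN stage is missing; `downsize` must match the conditioning feature size, i.e. input size / 16):

```python
import torch
from pix_network_1 import PixColor1
from pix_network_2 import refinement_network

pixcolor = PixColor1(downsize=4)        # conditioning + adaptation (+ PixelCNN, TODO)
refine = refinement_network(imsize=64)  # refines the low-resolution chrominance

y = torch.randn(1, 1, 64, 64)           # luminance (Y) channel
low_res_chroma = pixcolor(y)            # 2-channel low-resolution chrominance (stub output)
colorization = refine(y, low_res_chroma)
```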
18 | """ 19 | 20 | def __init__(self, root, transform=None): 21 | """Initializes image paths and preprocessing module.""" 22 | self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root))) 23 | self.transform = transform 24 | 25 | def __getitem__(self, index): 26 | """Reads an image from a file and preprocesses it and returns.""" 27 | image_path = self.image_paths[index] 28 | image = Image.open(image_path).convert('YCbCr') 29 | #grayscale_image = image.split()[0] 30 | #if self.transform is not None: 31 | # image = self.transform(image) 32 | return image 33 | 34 | def __len__(self): 35 | """Returns the total number of image files.""" 36 | return len(self.image_paths) 37 | 38 | 39 | def data_loader(dataset, batch_size, num_workers=2): 40 | """Builds and returns Dataloader.""" 41 | if dataset == 'imagenet': 42 | image_size = 64 43 | 44 | transform = transforms.Compose([ 45 | # transforms.Scale(image_size), 46 | transforms.ToTensor() 47 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], 48 | # std=[0.229, 0.224, 0.225]) 49 | ]) 50 | 51 | traindir = './data/tiny-imagenet-200/train/' 52 | valdir = './data/tiny-imagenet-200/val/' 53 | 54 | train_dataset = dsets.ImageFolder(traindir, transform=transform) 55 | val_dataset = dsets.ImageFolder(valdir, transform=transform) 56 | train_loader = data.DataLoader(dataset=train_dataset, 57 | batch_size=batch_size, 58 | shuffle=True, 59 | num_workers=num_workers) 60 | val_loader = data.DataLoader(dataset=val_dataset, 61 | batch_size=batch_size, 62 | shuffle=False, 63 | num_workers=num_workers) 64 | 65 | 66 | 67 | # shuffle=False, True when training. 68 | '''elif dataset == 'cifar': 69 | image_size =32 70 | 71 | transform = transforms.Compose([ 72 | transforms.Scale(image_size), 73 | transforms.ToTensor() 74 | ]) 75 | 76 | train_dataset = dsets.CIFAR10(root='./data/', 77 | train=True, 78 | transform=transform, 79 | download=True) 80 | val_dataset = dsets.CIFAR10(root='./data/', 81 | train=False, 82 | transform=transform) 83 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 84 | batch_size=batch_size, 85 | shuffle=True, 86 | num_workers=num_workers) 87 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 88 | batch_size=batch_size, 89 | shuffle=False, 90 | num_workers=num_workers)''' 91 | 92 | return train_loader, val_loader, image_size 93 | -------------------------------------------------------------------------------- /pix_network_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torchvision.transforms as transforms 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, stride=1, downsample=None): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 12 | self.bn1 = nn.BatchNorm2d(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 14 | padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 17 | self.bn3 = nn.BatchNorm2d(planes * 4) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = downsample 20 | self.stride = stride 21 | 22 | def forward(self, x): 23 | residual = x 24 | 25 | out = self.conv1(x) 26 | out = self.bn1(out) 27 | out = self.relu(out) 28 | 29 | out = self.conv2(out) 30 | out = self.bn2(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv3(out) 34 | out = 
--------------------------------------------------------------------------------
/pix_network_1.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        # single-channel input: the network consumes the Y (luminance) channel
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # layer4, avgpool and fc are kept for completeness but unused:
        # forward() returns the 1024-channel feature map after layer3.
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        return x


def conditioning_network(pretrained=False, **kwargs):
    """Constructs a ResNet-101 backbone that returns features after layer3.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    #if pretrained:
    #    model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
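

# Hypothetical helper (not part of the original repo): partially load the
# standard torchvision ImageNet ResNet-101 weights into the conditioning
# network, skipping parameters whose shapes differ (conv1 takes 1 channel
# here instead of 3) or that are unused (fc).
def load_pretrained_conditioning(model):
    from torch.utils import model_zoo
    pretrained = model_zoo.load_url(
        'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
    own = model.state_dict()
    compatible = {k: v for k, v in pretrained.items()
                  if k in own and v.size() == own[k].size()}
    own.update(compatible)
    model.load_state_dict(own)
    return model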


class adaptation_network(nn.Module):

    def __init__(self, in_size, out_size, kernel_size=1):
        super(adaptation_network, self).__init__()
        self.conv1 = nn.Conv2d(in_size, 512, kernel_size, bias=False)
        self.conv2 = nn.Conv2d(512, 256, kernel_size, bias=False)
        self.conv3 = nn.Conv2d(256, out_size, kernel_size, bias=False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        return out


class PixColor1(nn.Module):

    def __init__(self, downsize):
        super(PixColor1, self).__init__()
        self.conditioning_network = conditioning_network()
        self.adaptation_network = adaptation_network(1025, 2)
        # TODO: the PixelCNN coloring network is not implemented yet
        # (see the note in the README and the sketch below).
        self.pixelCNN = None
        self.downsize = downsize

    def forward(self, x):
        x_conditioned = self.conditioning_network(x)  # 64 x 1024 x 4 x 4 (4 if input imsize is 64)
        # transforms.Scale only works on PIL images, so the luminance tensor
        # is resized with bilinear upsampling instead; downsize must match
        # the conditioning feature size (input imsize / 16).
        x_resized = F.upsample(x, size=(self.downsize, self.downsize),
                               mode='bilinear')  # 64 x 1 x 4 x 4
        x_concat = torch.cat((x_resized, x_conditioned), 1)  # 64 x 1025 x 4 x 4

        out = self.adaptation_network(x_concat)
        # TODO: feed `out` (plus previously sampled pixels) into the PixelCNN
        # once it is implemented; for now the adaptation features are returned.
        return out
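

# ---------------------------------------------------------------------------
# Sketch of the missing coloring network's core building block. The paper
# uses a conditional PixelCNN; the masked convolution below is the standard
# ingredient (van den Oord et al., 2016). This is an illustrative stub under
# assumed conventions, not the full gated architecture from the paper.
# ---------------------------------------------------------------------------
class MaskedConv2d(nn.Conv2d):
    """Conv2d whose kernel is masked so that each output pixel only sees
    pixels above it and to its left (mask 'A' also hides the center pixel,
    as required for the first layer)."""

    def __init__(self, mask_type, *args, **kwargs):
        super(MaskedConv2d, self).__init__(*args, **kwargs)
        assert mask_type in ('A', 'B')
        self.register_buffer('mask', self.weight.data.clone())
        _, _, kH, kW = self.weight.size()
        self.mask.fill_(1)
        # zero out the "future" part of the center row, and all rows below it
        self.mask[:, :, kH // 2, kW // 2 + (mask_type == 'B'):] = 0
        self.mask[:, :, kH // 2 + 1:] = 0

    def forward(self, x):
        self.weight.data *= self.mask  # enforce the causal mask
        return super(MaskedConv2d, self).forward(x)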
--------------------------------------------------------------------------------
/pix_network_2.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


# input: full-resolution grayscale plus the low-resolution chrominance from stage 1

class UNetConvBlock1(nn.Module):
    # feature maps: 64
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock1, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, stride=2, padding=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv(out))
        out = self.batchnorm(out)
        return out


class UNetConvBlock2(nn.Module):
    # feature maps: 128
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock2, self).__init__()
        self.conv2 = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)
        self.conv3 = nn.Conv2d(out_size, out_size, 3, stride=2, groups=out_size, bias=False)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv2(out))
        out = self.batchnorm(out)
        out = self.conv3(out)
        return out


class UNetConvBlock3(nn.Module):
    # feature maps: 256
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock3, self).__init__()
        self.conv1 = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        # conv2 consumes conv1's output, so its input size is out_size
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, stride=2, padding=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv1(out))
        out = self.activation(self.conv2(out))
        out = self.batchnorm(out)
        return out


class UNetConvBlock4(nn.Module):
    # feature maps: 512
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock4, self).__init__()
        self.conv1 = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv1(out))
        out = self.activation(self.conv2(out))
        out = self.batchnorm(out)
        return out


class UNetConvBlock5(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock5, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv(out))
        out = self.batchnorm(out)
        return out


class UNetConvBlock6(nn.Module):
    # feature maps: 512
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock6, self).__init__()
        self.conv1 = nn.Conv2d(in_size, out_size, kernel_size, stride=2, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, stride=2, padding=1)
        # conv4 is defined but never used in forward()
        self.conv4 = nn.Conv2d(out_size, out_size, kernel_size, groups=out_size, bias=False)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv1(out))
        out = self.activation(self.conv2(out))
        out = self.activation(self.conv3(out))
        out = self.batchnorm(out)
        return out


class UNetConvBlock7(nn.Module):
    # feature maps: 1024
    def __init__(self, in_size, out_size, kernel_size=5, activation=F.relu):
        super(UNetConvBlock7, self).__init__()
        # padding=2 keeps the spatial size with the 5x5 kernel
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=2)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv(out))
        out = self.batchnorm(out)
        return out


class UNetConvBlock8(nn.Module):
    # feature maps: 512
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock8, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1, dilation=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv(out))
        out = self.batchnorm(out)
        return out


class UNetConvBlock9(nn.Module):
    # feature maps: 128
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock9, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1, dilation=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv(out))
        out = self.activation(self.conv2(out))
        out = self.batchnorm(out)
        return out


# bilinear upsampling
class UNetConvBlock10(nn.Module):
    # feature maps: 64
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock10, self).__init__()
        # UpsamplingBilinear2d takes a size or scale factor, not conv arguments
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1, dilation=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.up(x)
        out = self.activation(self.conv(out))
        out = self.activation(self.conv2(out))
        out = self.batchnorm(out)

        return out


# bilinear upsampling
class UNetConvBlock11(nn.Module):
    # feature maps: 32
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock11, self).__init__()
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1, dilation=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.up(x)
        out = self.activation(self.conv(out))
        out = self.activation(self.conv2(out))
        out = self.batchnorm(out)
        return out


class UNetConvBlock12(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetConvBlock12, self).__init__()
        # 1x1 output head mapping in_size features to the 2 chrominance channels
        self.conv = nn.Conv2d(in_size, out_size, kernel_size=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv(out))
        out = self.batchnorm(out)
        return out


class refinement_network(nn.Module):
    def __init__(self, imsize):
        super(refinement_network, self).__init__()
        self.imsize = imsize

        # the input is the grayscale image (1 channel) concatenated with the
        # upsampled low-resolution chrominance (2 channels), hence 3 channels
        self.convlayer1 = UNetConvBlock1(3, 64)
        self.convlayer2 = UNetConvBlock2(64, 128)
        self.convlayer3 = UNetConvBlock3(128, 256)
        self.convlayer4 = UNetConvBlock4(256, 512)
        self.convlayer5 = UNetConvBlock5(512, 256)
        self.convlayer6 = UNetConvBlock6(256, 512)
        self.convlayer7 = UNetConvBlock7(512, 1024)
        self.convlayer8 = UNetConvBlock8(1024, 512)
        self.convlayer9 = UNetConvBlock9(512, 128)
        self.convlayer10 = UNetConvBlock10(128, 64)
        self.convlayer11 = UNetConvBlock11(64, 32)
        self.convlayer12 = UNetConvBlock12(32, 2)

    def forward(self, x1, x2):
        # x1: full-res grayscale image, x2: low-res chrominance from stage 1.
        # transforms.Scale cannot resize tensors, so x2 is upsampled bilinearly.
        x2 = F.upsample(x2, size=(self.imsize, self.imsize), mode='bilinear')
        out = torch.cat((x1, x2), 1)
        layer1 = self.convlayer1(out)
        layer2 = self.convlayer2(layer1)
        layer3 = self.convlayer3(layer2)
        layer4 = self.convlayer4(layer3)
        layer5 = self.convlayer5(layer4)
        layer6 = self.convlayer6(layer5)
        layer7 = self.convlayer7(layer6)
        layer8 = self.convlayer8(layer7)
        layer9 = self.convlayer9(layer8)
        layer10 = self.convlayer10(layer9)
        layer11 = self.convlayer11(layer10)
        layer12 = self.convlayer12(layer11)
        out = layer12

        return out
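

if __name__ == '__main__':
    # Rough shape check under assumed sizes (64x64 grayscale input, 4x4
    # chrominance from stage 1). Note: the skip connections the U-Net naming
    # implies are not wired up yet, so the chain does not restore full
    # resolution; this only exercises the layer stack end to end.
    net = refinement_network(imsize=64)
    gray = torch.randn(1, 1, 64, 64)
    chroma = torch.randn(1, 2, 4, 4)
    print(net(gray, chroma).size())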
--------------------------------------------------------------------------------
/pix_train.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch
import torch.nn.functional as F
from torch import cuda
from torch.autograd import Variable

from pix_dataloader import *
from pix_network_1 import *
from pix_network_2 import *
from pix_util import *


# arguments parsed when initiating
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='imagenet', choices=['cifar', 'imagenet', 'celeba', 'mscoco'])
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument('--model_path', type=str, default='./models_pix')
    parser.add_argument('--log_path', type=str, default='./logs_pix')
    parser.add_argument('--model', type=str, default='pixcolor100.pkl')
    parser.add_argument('--image_save', type=str, default='./images_pix')
    parser.add_argument('--learning_rate', type=float, default=0.0003)
    parser.add_argument('--num_epochs', type=int, default=500)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--idx', type=int, default=1)
    parser.add_argument('--resume', action='store_true',
                        help='resume from the latest checkpoint')

    return parser.parse_args()


def main(args):
    dataset = args.data
    gpu = args.gpu
    batch_size = args.batch_size
    model_path = args.model_path
    log_path = args.log_path
    num_epochs = args.num_epochs
    learning_rate = args.learning_rate
    start_epoch = args.start_epoch

    # make directories for saved models if they do not exist yet
    make_folder(model_path, dataset)           # for sampled models
    make_folder(log_path, dataset)             # for logpoints
    make_folder(log_path, dataset + '/ckpt')   # for checkpoints

    # select the gpu
    print("Running on gpu : ", gpu)
    cuda.set_device(gpu)

    # set the data loaders
    train_loader, val_loader, imsize = data_loader(dataset, batch_size)

    # only the refinement network is trained here; stage 1 (PixColor1)
    # is trained separately once its PixelCNN is implemented
    RefNet = refinement_network(imsize)

    # make the network run on gpu
    RefNet.cuda()

    # Loss and Optimizer: the refinement network regresses chrominance
    # values, so an elementwise regression loss is used here (the exact
    # loss used in the paper may differ)
    optimizer = torch.optim.Adam(RefNet.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss()

    # optionally resume from a checkpoint
    if args.resume:
        ckpt_path = os.path.join(log_path, dataset, 'ckpt/model.ckpt')
        if os.path.isfile(ckpt_path):
            print("=> loading checkpoint")
            checkpoint = torch.load(ckpt_path)
            start_epoch = checkpoint['epoch']
            RefNet.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> Loaded checkpoint (epoch {})".format(checkpoint['epoch']))
            print("=> Meaning that training starts from (epoch {})".format(checkpoint['epoch'] + 1))
        else:
            print("=> Sorry, no checkpoint found at '{}'".format(ckpt_path))

    # record time
    tell_time = Timer()
    iter = 0

    # Train the Model
    for epoch in range(start_epoch, num_epochs):
        RefNet.train()
        for i, (images, _) in enumerate(train_loader):
            batch = images.size(0)
            # split each RGB batch into the Y (luminance) input and the
            # CbCr (chrominance) targets
            inputs, labels = data_process(images, batch, imsize)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()

            # Forward + Backward + Optimize
            optimizer.zero_grad()
            # stage 1 is not implemented yet, so ground-truth chrominance
            # downsampled to the stage-1 output scale stands in for it
            chroma_lowres = F.avg_pool2d(labels, kernel_size=16)
            outputs = RefNet(inputs, chroma_lowres)
            # the stub network does not yet restore full resolution, so the
            # targets are compared at the output's scale
            labels_resized = F.upsample(labels, size=outputs.size()[2:], mode='bilinear')
            loss = criterion(outputs, labels_resized)
            loss.backward()
            optimizer.step()

            if (i + 1) % 10 == 0:
                print('Epoch [%d/%d], Iter [%d/%d], Loss: %.10f, iter_time: %2.2f, aggregate_time: %6.2f'
                      % (epoch + 1, num_epochs, i + 1, len(train_loader.dataset) // batch_size, loss.data[0],
                         (tell_time.toc() - iter), tell_time.toc()))
                torch.save(RefNet.state_dict(), os.path.join(model_path, dataset, 'RefNet%d.pkl' % (epoch + 1)))

        # start evaluation
        print("-------------evaluation start------------")

        RefNet.eval()
        loss_val_all = Variable(torch.zeros(100), volatile=True).cuda()
        for i, (images, _) in enumerate(val_loader):
            batch = images.size(0)

            # split into Y input and CbCr targets (Variable + gpu)
            inputs, labels = data_process(images, batch, imsize)
            inputs = Variable(inputs, volatile=True).cuda()
            labels = Variable(labels, volatile=True).cuda()

            # same ground-truth stand-in for stage 1 as in training
            chroma_lowres = F.avg_pool2d(labels, kernel_size=16)
            outputs = RefNet(inputs, chroma_lowres)
            labels_resized = F.upsample(labels, size=outputs.size()[2:], mode='bilinear')
            loss_val = criterion(outputs, labels_resized)

            # logpoint is kept from the original script; currently unused
            logpoint = {
                'epoch': epoch + 1,
                'args': args,
            }
            checkpoint = {
                'epoch': epoch + 1,
                'args': args,
                'state_dict': RefNet.state_dict(),
                'optimizer': optimizer.state_dict(),
            }

            loss_val_all[i] = loss_val

            if i == 30:
                print('Epoch [%d/%d], Validation Loss: %.10f'
                      % (epoch + 1, num_epochs, torch.mean(loss_val_all).data[0]))
                torch.save(checkpoint, os.path.join(log_path, dataset, 'ckpt/model.ckpt'))
                break


if __name__ == '__main__':
    args = parse_args()
    main(args)
--------------------------------------------------------------------------------
/pix_util.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import time
import datetime
import torch
import skimage.color as sc


class Timer():
    def __init__(self):
        self.cur_t = time.time()

    def tic(self):
        self.cur_t = time.time()

    def toc(self):
        return time.time() - self.cur_t

    def tocStr(self, t=-1):
        if t == -1:
            return str(datetime.timedelta(seconds=np.round(time.time() - self.cur_t, 3)))[:-4]
        else:
            return str(datetime.timedelta(seconds=np.round(t, 3)))[:-4]


def make_folder(path, dataset):
    try:
        os.makedirs(os.path.join(path, dataset))
    except OSError:
        pass


def data_process(image_data, batch_size, imsize):
    """Splits a batch of RGB tensors into Y inputs and CbCr targets."""
    images_numpy = image_data.numpy()
    input = torch.zeros(batch_size, 1, imsize, imsize)
    labels = torch.zeros(batch_size, 2, imsize, imsize)
    for k in range(batch_size):
        rgb = images_numpy[k].transpose(1, 2, 0)  # CHW -> HWC for skimage
        yCbCr = sc.rgb2ycbcr(rgb) / 255
        img_y = yCbCr[:, :, 0]
        input[k] = torch.from_numpy(np.expand_dims(img_y, 0))
        img_CbCr = yCbCr[:, :, 1:3]
        # transpose HWC -> CHW so the 2 chrominance channels come first
        labels[k] = torch.from_numpy(img_CbCr.transpose(2, 0, 1))

    return input, labels
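

if __name__ == '__main__':
    # Quick demo of data_process on a random batch (assumed shapes only).
    fake_batch = torch.rand(4, 3, 64, 64)  # RGB images in [0, 1]
    y, cbcr = data_process(fake_batch, 4, 64)
    print(y.size(), cbcr.size())  # torch.Size([4, 1, 64, 64]) torch.Size([4, 2, 64, 64])
--------------------------------------------------------------------------------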