├── Model.png
├── Model_details.pdf
├── outputs
│   └── sample
│       ├── img-2.jpg
│       ├── img-43.jpg
│       ├── img-79.jpg
│       ├── Example.png
│       ├── img-219.jpg
│       ├── img-2417.jpg
│       ├── img-2584.jpg
│       ├── img-2796.jpg
│       ├── img-3050.jpg
│       ├── img-4202.jpg
│       ├── img-4515.jpg
│       ├── img-7038.jpg
│       └── img-7159.jpg
├── README.md
├── dataloader.py
└── mymodels.py
/Model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/Model.png
--------------------------------------------------------------------------------
/Model_details.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/Model_details.pdf
--------------------------------------------------------------------------------
/outputs/sample/img-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-2.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-43.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-43.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-79.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-79.jpg
--------------------------------------------------------------------------------
/outputs/sample/Example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/Example.png
--------------------------------------------------------------------------------
/outputs/sample/img-219.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-219.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-2417.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-2417.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-2584.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-2584.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-2796.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-2796.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-3050.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-3050.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-4202.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-4202.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-4515.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-4515.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-7038.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-7038.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-7159.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-7159.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Anime-Sketch-Colorizer
2 | 
3 | Automatic sketch colorization with a reference image
4 | 
5 | Prerequisites
6 | ------
7 | 
8 | `pytorch`
9 | 
10 | `torchvision`
11 | 
12 | `numpy`
13 | 
14 | `opencv-python`, `opencv_transforms`
15 | 
16 | `matplotlib`
17 | 
18 | Dataset
19 | ------
20 | 
21 | Taebum Kim, "Anime Sketch Colorization Pair", https://www.kaggle.com/ktaebum/anime-sketch-colorization-pair
22 | 
23 | Train
24 | ------
25 | 
26 | Please refer to `train.ipynb`
27 | 
28 | Test
29 | ------
30 | 
31 | Please refer to `test.ipynb`
32 | 
33 | * You can download the pretrained checkpoint from https://drive.google.com/open?id=1pIZCjubtyOUr7AXtGQMvzcbKczJ9CtQG (449 MB); `mymodels.py` loads it from `./checkpoint/color2edge/ckpt.pth` and `./checkpoint/edge2color/ckpt.pth`
34 | 
35 | Training details
36 | ------
37 | 
38 | | Parameter | Value |
39 | |:--------|:--------:|
40 | | Learning rate | 2e-4 |
41 | | Batch size | 2 |
42 | | Epochs | 25 |
43 | | Optimizer | Adam |
44 | | (beta1, beta2) | (0.5, 0.999) |
45 | | (lambda1, lambda2, lambda3) | (100, 1e-4, 1e-2) |
46 | | Data Augmentation | RandomResizedCrop(256)<br>RandomHorizontalFlip() |
47 | | Hardware | CPU: Intel i5-8400<br>RAM: 16 GB<br>GPU: NVIDIA GTX 1060 6 GB |
48 | | Training Time | About 0.93 s per iteration<br>(about 45 hours for 25 epochs) |
49 | 
50 | Model
51 | ------
52 | 
53 | ![ex_screenshot](./Model.png)
54 | 
55 | For more details, please refer to `Model_details.pdf`
56 | 
57 | Results
58 | -----
59 | 
60 | Reference / Sketch / Colorization Result / Ground Truth
61 | 
62 | ![ex_screenshot](./outputs/sample/img-2.jpg)
63 | ![ex_screenshot](./outputs/sample/img-43.jpg)
64 | ![ex_screenshot](./outputs/sample/img-79.jpg)
65 | ![ex_screenshot](./outputs/sample/img-219.jpg)
66 | ![ex_screenshot](./outputs/sample/img-2417.jpg)
67 | ![ex_screenshot](./outputs/sample/img-2584.jpg)
68 | ![ex_screenshot](./outputs/sample/img-2796.jpg)
69 | ![ex_screenshot](./outputs/sample/img-3050.jpg)
70 | ![ex_screenshot](./outputs/sample/img-4202.jpg)
71 | ![ex_screenshot](./outputs/sample/img-4515.jpg)
72 | ![ex_screenshot](./outputs/sample/img-7038.jpg)
73 | ![ex_screenshot](./outputs/sample/img-7159.jpg)
74 | 
75 | References
76 | ------
77 | 
78 | [1] Taebum Kim, "Anime Sketch Colorization Pair", https://www.kaggle.com/ktaebum/anime-sketch-colorization-pair, 2019, accessed 2020.1.13.
79 | 
80 | [2] Jim Bohnslav, "opencv_transforms", https://github.com/jbohnslav/opencv_transforms, accessed 2020.1.13.
81 | 
82 | [3] Takeru Miyato et al., "Spectral Normalization for Generative Adversarial Networks", ICLR 2018, 2018.2.18.
83 | 
84 | [4] Ozan Oktay et al., "Attention U-Net: Learning Where to Look for the Pancreas", MIDL 2018, 2018.5.20.
85 | 
86 | [5] Siyuan Qiao et al., "Weight Standardization", https://arxiv.org/abs/1903.10520, 2019.3.25, accessed 2020.1.19.
87 | 
88 | [6] Tero Karras, Samuli Laine, Timo Aila, "A Style-Based Generator Architecture for Generative Adversarial Networks", https://arxiv.org/abs/1812.04948, 2019.3.29, accessed 2020.1.22.
89 | 
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import torchvision
5 | import opencv_transforms.functional as FF
6 | from torchvision import datasets
7 | from PIL import Image
8 | 
9 | def color_cluster(img, nclusters=9):
10 |     """
11 |     Apply K-means clustering to the input image
12 | 
13 |     Args:
14 |         img: NumPy array with shape (H, W, C)
15 |         nclusters: Number of clusters (default: 9)
16 | 
17 |     Returns:
18 |         color_palette: List of 3D NumPy arrays, each with the same shape as the input image.
19 |             e.g. if the input image has shape (256, 256, 3) and nclusters is 4, the returned color_palette is [color1, color2, color3, color4],
20 |             where each element is a (256, 256, 3) NumPy array.
21 | 
22 |     Note:
23 |         The K-means clustering algorithm is quite computationally intensive.
24 |         Thus, before extracting dominant colors, the input image is resized to 0.25x its original size.
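    Example (illustrative; a random uint8 image stands in for a real color image):
        >>> import numpy as np
        >>> img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
        >>> palette = color_cluster(img, nclusters=4)
        >>> len(palette)
        4
        >>> palette[0].shape
        (256, 256, 3)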
25 | """ 26 | img_size = img.shape 27 | small_img = cv2.resize(img, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA) 28 | sample = small_img.reshape((-1, 3)) 29 | sample = np.float32(sample) 30 | criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0) 31 | flags = cv2.KMEANS_PP_CENTERS 32 | 33 | _, _, centers = cv2.kmeans(sample, nclusters, None, criteria, 10, flags) 34 | centers = np.uint8(centers) 35 | color_palette = [] 36 | 37 | for i in range(0, nclusters): 38 | dominant_color = np.zeros(img_size, dtype='uint8') 39 | dominant_color[:,:,:] = centers[i] 40 | color_palette.append(dominant_color) 41 | 42 | return color_palette 43 | 44 | class PairImageFolder(datasets.ImageFolder): 45 | """ 46 | A generic data loader where the images are arranged in this way: :: 47 | 48 | root/dog/xxx.png 49 | root/dog/xxy.png 50 | root/dog/xxz.png 51 | 52 | root/cat/123.png 53 | root/cat/nsdf3.png 54 | root/cat/asd932_.png 55 | 56 | This class works properly for paired image in form of [sketch, color_image] 57 | 58 | Args: 59 | root (string): Root directory path. 60 | transform (callable, optional): A function/transform that takes in an PIL image 61 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 62 | target_transform (callable, optional): A function/transform that takes in the 63 | target and transforms it. 64 | loader (callable, optional): A function to load an image given its path. 65 | is_valid_file (callable, optional): A function that takes path of an Image file 66 | and check if the file is a valid file (used to check of corrupt files) 67 | sketch_net: The network to convert color image to sketch image 68 | ncluster: Number of clusters when extracting color palette. 69 | 70 | Attributes: 71 | classes (list): List of the class names. 72 | class_to_idx (dict): Dict with items (class_name, class_index). 73 | imgs (list): List of (image path, class_index) tuples 74 | 75 | Getitem: 76 | img_edge: Edge image 77 | img: Color Image 78 | color_palette: Extracted color paltette 79 | """ 80 | def __init__(self, root, transform, sketch_net, ncluster): 81 | super(PairImageFolder, self).__init__(root, transform) 82 | self.ncluster = ncluster 83 | self.sketch_net = sketch_net 84 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 85 | 86 | def __getitem__(self, index): 87 | path, label = self.imgs[index] 88 | img = self.loader(path) 89 | img = np.asarray(img) 90 | img = img[:, 0:512, :] 91 | img = self.transform(img) 92 | color_palette = color_cluster(img, nclusters=self.ncluster) 93 | img = self.make_tensor(img) 94 | 95 | with torch.no_grad(): 96 | img_edge = self.sketch_net(img.unsqueeze(0).to(self.device)).squeeze().permute(1,2,0).cpu().numpy() 97 | img_edge = FF.to_grayscale(img_edge, num_output_channels=3) 98 | img_edge = FF.to_tensor(img_edge) 99 | 100 | for i in range(0, len(color_palette)): 101 | color = color_palette[i] 102 | color_palette[i] = self.make_tensor(color) 103 | 104 | return img_edge, img, color_palette 105 | 106 | def make_tensor(self, img): 107 | img = FF.to_tensor(img) 108 | img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 109 | return img 110 | 111 | class GetImageFolder(datasets.ImageFolder): 112 | """ 113 | A generic data loader where the images are arranged in this way: :: 114 | 115 | root/dog/xxx.png 116 | root/dog/xxy.png 117 | root/dog/xxz.png 118 | 119 | root/cat/123.png 120 | root/cat/nsdf3.png 121 | root/cat/asd932_.png 122 | 123 | Args: 124 | root (string): Root directory path. 
125 | transform (callable, optional): A function/transform that takes in an PIL image 126 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 127 | target_transform (callable, optional): A function/transform that takes in the 128 | target and transforms it. 129 | loader (callable, optional): A function to load an image given its path. 130 | is_valid_file (callable, optional): A function that takes path of an Image file 131 | and check if the file is a valid file (used to check of corrupt files) 132 | sketch_net: The network to convert color image to sketch image 133 | ncluster: Number of clusters when extracting color palette. 134 | 135 | Attributes: 136 | classes (list): List of the class names. 137 | class_to_idx (dict): Dict with items (class_name, class_index). 138 | imgs (list): List of (image path, class_index) tuples 139 | 140 | Getitem: 141 | img_edge: Edge image 142 | img: Color Image 143 | color_palette: Extracted color paltette 144 | """ 145 | def __init__(self, root, transform, sketch_net, ncluster): 146 | super(GetImageFolder, self).__init__(root, transform) 147 | self.ncluster = ncluster 148 | self.sketch_net = sketch_net 149 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 150 | 151 | def __getitem__(self, index): 152 | path, label = self.imgs[index] 153 | img = self.loader(path) 154 | img = np.asarray(img) 155 | img = self.transform(img) 156 | color_palette = color_cluster(img, nclusters=self.ncluster) 157 | img = self.make_tensor(img) 158 | 159 | with torch.no_grad(): 160 | img_edge = self.sketch_net(img.unsqueeze(0).to(self.device)).squeeze().permute(1,2,0).cpu().numpy() 161 | img_edge = FF.to_grayscale(img_edge, num_output_channels=3) 162 | img_edge = FF.to_tensor(img_edge) 163 | 164 | for i in range(0, len(color_palette)): 165 | color = color_palette[i] 166 | color_palette[i] = self.make_tensor(color) 167 | 168 | return img_edge, img, color_palette 169 | 170 | def make_tensor(self, img): 171 | img = FF.to_tensor(img) 172 | img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 173 | return img -------------------------------------------------------------------------------- /mymodels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | __all__ = [ 6 | 'Color2Sketch', 'Sketch2Color', 'Discriminator', 7 | ] 8 | 9 | class ApplyNoise(nn.Module): 10 | def __init__(self, channels): 11 | super().__init__() 12 | self.weight = nn.Parameter(torch.zeros(channels)) 13 | 14 | def forward(self, x, noise=None): 15 | if noise is None: 16 | noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype) 17 | return x + self.weight.view(1, -1, 1, 1) * noise.to(x.device) 18 | 19 | class Conv2d_WS(nn.Conv2d): 20 | def __init__(self, in_chan, out_chan, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): 21 | super().__init__(in_chan, out_chan, kernel_size, stride, padding, dilation, groups, bias) 22 | 23 | def forward(self, x): 24 | weight = self.weight 25 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True) 26 | weight = weight - weight_mean 27 | std = weight.view(weight.size(0), -1).std(dim=1).view(-1,1,1,1)+1e-5 28 | weight = weight / std.expand_as(weight) 29 | return torch.nn.functional.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 30 | 31 | class ResidualBlock(nn.Module): 32 | def __init__(self, in_channels, out_channels, stride=1, 
sample=None): 33 | super(ResidualBlock, self).__init__() 34 | self.ic = in_channels 35 | self.oc = out_channels 36 | self.conv1 = Conv2d_WS(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 37 | self.bn1 = nn.GroupNorm(32, out_channels) 38 | self.conv2 = Conv2d_WS(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 39 | self.bn2 = nn.GroupNorm(32, out_channels) 40 | self.convr = Conv2d_WS(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 41 | self.bnr = nn.GroupNorm(32, out_channels) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.sample = sample 44 | if self.sample == 'down': 45 | self.sampling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 46 | elif self.sample == 'up': 47 | self.sampling = nn.Upsample(scale_factor=2, mode='nearest') 48 | 49 | def forward(self, x): 50 | if self.ic != self.oc: 51 | residual = self.convr(x) 52 | residual = self.bnr(residual) 53 | else: 54 | residual = x 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | out += residual 61 | out = self.relu(out) 62 | if self.sample is not None: 63 | out = self.sampling(out) 64 | return out 65 | 66 | class Attention_block(nn.Module): 67 | def __init__(self,F_g,F_l,F_int): 68 | super(Attention_block,self).__init__() 69 | self.W_g = nn.Sequential( 70 | Conv2d_WS(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), 71 | nn.GroupNorm(32, F_int) 72 | ) 73 | 74 | self.W_x = nn.Sequential( 75 | Conv2d_WS(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), 76 | nn.GroupNorm(32, F_int) 77 | ) 78 | 79 | self.psi = nn.Sequential( 80 | Conv2d_WS(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), 81 | nn.InstanceNorm2d(1), 82 | nn.Sigmoid() 83 | ) 84 | 85 | self.relu = nn.ReLU(inplace=True) 86 | 87 | def forward(self,g,x): 88 | g1 = self.W_g(g) 89 | x1 = self.W_x(x) 90 | psi = self.relu(g1+x1) 91 | psi = self.psi(psi) 92 | 93 | return x*psi 94 | 95 | class Color2Sketch(nn.Module): 96 | def __init__(self, nc=3, pretrained=False): 97 | super(Color2Sketch, self).__init__() 98 | class Encoder(nn.Module): 99 | def __init__(self): 100 | super(Encoder, self).__init__() 101 | # Build ResNet and change first conv layer to accept single-channel input 102 | self.layer1 = ResidualBlock(nc, 64, sample='down') 103 | self.layer2 = ResidualBlock(64, 128, sample='down') 104 | self.layer3 = ResidualBlock(128, 256, sample='down') 105 | self.layer4 = ResidualBlock(256, 512, sample='down') 106 | self.layer5 = ResidualBlock(512, 512, sample='down') 107 | self.layer6 = ResidualBlock(512, 512, sample='down') 108 | self.layer7 = ResidualBlock(512, 512, sample='down') 109 | 110 | def forward(self, input_image): 111 | # Pass input through ResNet-gray to extract features 112 | x0 = input_image # nc * 256 * 256 113 | x1 = self.layer1(x0) # 64 * 128 * 128 114 | x2 = self.layer2(x1) # 128 * 64 * 64 115 | x3 = self.layer3(x2) # 256 * 32 * 32 116 | x4 = self.layer4(x3) # 512 * 16 * 16 117 | x5 = self.layer5(x4) # 512 * 8 * 8 118 | x6 = self.layer6(x5) # 512 * 4 * 4 119 | x7 = self.layer7(x6) # 512 * 2 * 2 120 | 121 | return x1, x2, x3, x4, x5, x6, x7 122 | 123 | class Decoder(nn.Module): 124 | def __init__(self): 125 | super(Decoder, self).__init__() 126 | # Convolutional layers and upsampling 127 | self.noise7 = ApplyNoise(512) 128 | self.layer7_up = ResidualBlock(512, 512, sample='up') 129 | 130 | self.Att6 = Attention_block(F_g=512,F_l=512,F_int=256) 131 | self.layer6 = 
ResidualBlock(1024, 512, sample=None) 132 | self.noise6 = ApplyNoise(512) 133 | self.layer6_up = ResidualBlock(512, 512, sample='up') 134 | 135 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256) 136 | self.layer5 = ResidualBlock(1024, 512, sample=None) 137 | self.noise5 = ApplyNoise(512) 138 | self.layer5_up = ResidualBlock(512, 512, sample='up') 139 | 140 | self.Att4 = Attention_block(F_g=512,F_l=512,F_int=256) 141 | self.layer4 = ResidualBlock(1024, 512, sample=None) 142 | self.noise4 = ApplyNoise(512) 143 | self.layer4_up = ResidualBlock(512, 256, sample='up') 144 | 145 | self.Att3 = Attention_block(F_g=256,F_l=256,F_int=128) 146 | self.layer3 = ResidualBlock(512, 256, sample=None) 147 | self.noise3 = ApplyNoise(256) 148 | self.layer3_up = ResidualBlock(256, 128, sample='up') 149 | 150 | self.Att2 = Attention_block(F_g=128,F_l=128,F_int=64) 151 | self.layer2 = ResidualBlock(256, 128, sample=None) 152 | self.noise2 = ApplyNoise(128) 153 | self.layer2_up = ResidualBlock(128, 64, sample='up') 154 | 155 | self.Att1 = Attention_block(F_g=64,F_l=64,F_int=32) 156 | self.layer1 = ResidualBlock(128, 64, sample=None) 157 | self.noise1 = ApplyNoise(64) 158 | self.layer1_up = ResidualBlock(64, 32, sample='up') 159 | 160 | self.noise0 = ApplyNoise(32) 161 | self.layer0 = Conv2d_WS(32, 3, kernel_size=3, stride=1, padding=1) 162 | self.activation = nn.ReLU(inplace=True) 163 | self.tanh = nn.Tanh() 164 | 165 | def forward(self, midlevel_input): #, global_input): 166 | x1, x2, x3, x4, x5, x6, x7 = midlevel_input 167 | 168 | x = self.noise7(x7) 169 | x = self.layer7_up(x) # 512 * 4 * 4 170 | 171 | x6 = self.Att6(g=x,x=x6) 172 | x = torch.cat((x, x6), dim=1) # 1024 * 4 * 4 173 | x = self.layer6(x) # 512 * 4 * 4 174 | x = self.noise6(x) 175 | x = self.layer6_up(x) # 512 * 8 * 8 176 | 177 | x5 = self.Att5(g=x,x=x5) 178 | x = torch.cat((x, x5), dim=1) # 1024 * 8 * 8 179 | x = self.layer5(x) # 512 * 8 * 8 180 | x = self.noise5(x) 181 | x = self.layer5_up(x) # 512 * 16 * 16 182 | 183 | x4 = self.Att4(g=x,x=x4) 184 | x = torch.cat((x, x4), dim=1) # 1024 * 16 * 16 185 | x = self.layer4(x) # 512 * 16 * 16 186 | x = self.noise4(x) 187 | x = self.layer4_up(x) # 256 * 32 * 32 188 | 189 | x3 = self.Att3(g=x,x=x3) 190 | x = torch.cat((x, x3), dim=1) # 512 * 32 * 32 191 | x = self.layer3(x) # 256 * 32 * 32 192 | x = self.noise3(x) 193 | x = self.layer3_up(x) # 128 * 64 * 64 194 | 195 | x2 = self.Att2(g=x,x=x2) 196 | x = torch.cat((x, x2), dim=1) # 256 * 64 * 64 197 | x = self.layer2(x) # 128 * 64 * 64 198 | x = self.noise2(x) 199 | x = self.layer2_up(x) # 64 * 128 * 128 200 | 201 | x1 = self.Att1(g=x,x=x1) 202 | x = torch.cat((x, x1), dim=1) # 128 * 128 * 128 203 | x = self.layer1(x) # 64 * 128 * 128 204 | x = self.noise1(x) 205 | x = self.layer1_up(x) # 32 * 256 * 256 206 | 207 | x = self.noise0(x) 208 | x = self.layer0(x) # 3 * 256 * 256 209 | x = self.tanh(x) 210 | 211 | return x 212 | 213 | self.encoder = Encoder() 214 | self.decoder = Decoder() 215 | if pretrained: 216 | print('Loading pretrained {0} model...'.format('Color2Sketch'), end=' ') 217 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
218 | checkpoint = torch.load('./checkpoint/color2edge/ckpt.pth') 219 | self.load_state_dict(checkpoint['netG'], strict=True) 220 | print("Done!") 221 | else: 222 | self.apply(weights_init) 223 | print('Weights of {0} model are initialized'.format('Color2Sketch')) 224 | 225 | def forward(self, inputs): 226 | encode = self.encoder(inputs) 227 | output = self.decoder(encode) 228 | 229 | return output 230 | 231 | class Sketch2Color(nn.Module): 232 | def __init__(self, nc=3, pretrained=False): 233 | super(Sketch2Color, self).__init__() 234 | class Encoder(nn.Module): 235 | def __init__(self): 236 | super(Encoder, self).__init__() 237 | # Build ResNet and change first conv layer to accept single-channel input 238 | self.layer1 = ResidualBlock(nc, 64, sample='down') 239 | self.layer2 = ResidualBlock(64, 128, sample='down') 240 | self.layer3 = ResidualBlock(128, 256, sample='down') 241 | self.layer4 = ResidualBlock(256, 512, sample='down') 242 | self.layer5 = ResidualBlock(512, 512, sample='down') 243 | self.layer6 = ResidualBlock(512, 512, sample='down') 244 | self.layer7 = ResidualBlock(512, 512, sample='down') 245 | 246 | def forward(self, input_image): 247 | # Pass input through ResNet-gray to extract features 248 | x0 = input_image # nc * 256 * 256 249 | x1 = self.layer1(x0) # 64 * 128 * 128 250 | x2 = self.layer2(x1) # 128 * 64 * 64 251 | x3 = self.layer3(x2) # 256 * 32 * 32 252 | x4 = self.layer4(x3) # 512 * 16 * 16 253 | x5 = self.layer5(x4) # 512 * 8 * 8 254 | x6 = self.layer6(x5) # 512 * 4 * 4 255 | x7 = self.layer7(x6) # 512 * 2 * 2 256 | 257 | return x1, x2, x3, x4, x5, x6, x7 258 | 259 | class Decoder(nn.Module): 260 | def __init__(self): 261 | super(Decoder, self).__init__() 262 | # Convolutional layers and upsampling 263 | self.noise7 = ApplyNoise(512) 264 | self.layer7_up = ResidualBlock(512, 512, sample='up') 265 | 266 | self.Att6 = Attention_block(F_g=512,F_l=512,F_int=256) 267 | self.layer6 = ResidualBlock(1024, 512, sample=None) 268 | self.noise6 = ApplyNoise(512) 269 | self.layer6_up = ResidualBlock(512, 512, sample='up') 270 | 271 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256) 272 | self.layer5 = ResidualBlock(1024, 512, sample=None) 273 | self.noise5 = ApplyNoise(512) 274 | self.layer5_up = ResidualBlock(512, 512, sample='up') 275 | 276 | self.Att4 = Attention_block(F_g=512,F_l=512,F_int=256) 277 | self.layer4 = ResidualBlock(1024, 512, sample=None) 278 | self.noise4 = ApplyNoise(512) 279 | self.layer4_up = ResidualBlock(512, 256, sample='up') 280 | 281 | self.Att3 = Attention_block(F_g=256,F_l=256,F_int=128) 282 | self.layer3 = ResidualBlock(512, 256, sample=None) 283 | self.noise3 = ApplyNoise(256) 284 | self.layer3_up = ResidualBlock(256, 128, sample='up') 285 | 286 | self.Att2 = Attention_block(F_g=128,F_l=128,F_int=64) 287 | self.layer2 = ResidualBlock(256, 128, sample=None) 288 | self.noise2 = ApplyNoise(128) 289 | self.layer2_up = ResidualBlock(128, 64, sample='up') 290 | 291 | self.Att1 = Attention_block(F_g=64,F_l=64,F_int=32) 292 | self.layer1 = ResidualBlock(128, 64, sample=None) 293 | self.noise1 = ApplyNoise(64) 294 | self.layer1_up = ResidualBlock(64, 32, sample='up') 295 | 296 | self.noise0 = ApplyNoise(32) 297 | self.layer0 = Conv2d_WS(32, 3, kernel_size=3, stride=1, padding=1) 298 | self.activation = nn.ReLU(inplace=True) 299 | self.tanh = nn.Tanh() 300 | 301 | def forward(self, midlevel_input): #, global_input): 302 | x1, x2, x3, x4, x5, x6, x7 = midlevel_input 303 | 304 | x = self.noise7(x7) 305 | x = self.layer7_up(x) # 512 * 4 * 4 306 | 307 | 
x6 = self.Att6(g=x,x=x6) 308 | x = torch.cat((x, x6), dim=1) # 1024 * 4 * 4 309 | x = self.layer6(x) # 512 * 4 * 4 310 | x = self.noise6(x) 311 | x = self.layer6_up(x) # 512 * 8 * 8 312 | 313 | x5 = self.Att5(g=x,x=x5) 314 | x = torch.cat((x, x5), dim=1) # 1024 * 8 * 8 315 | x = self.layer5(x) # 512 * 8 * 8 316 | x = self.noise5(x) 317 | x = self.layer5_up(x) # 512 * 16 * 16 318 | 319 | x4 = self.Att4(g=x,x=x4) 320 | x = torch.cat((x, x4), dim=1) # 1024 * 16 * 16 321 | x = self.layer4(x) # 512 * 16 * 16 322 | x = self.noise4(x) 323 | x = self.layer4_up(x) # 256 * 32 * 32 324 | 325 | x3 = self.Att3(g=x,x=x3) 326 | x = torch.cat((x, x3), dim=1) # 512 * 32 * 32 327 | x = self.layer3(x) # 256 * 32 * 32 328 | x = self.noise3(x) 329 | x = self.layer3_up(x) # 128 * 64 * 64 330 | 331 | x2 = self.Att2(g=x,x=x2) 332 | x = torch.cat((x, x2), dim=1) # 256 * 64 * 64 333 | x = self.layer2(x) # 128 * 64 * 64 334 | x = self.noise2(x) 335 | x = self.layer2_up(x) # 64 * 128 * 128 336 | 337 | x1 = self.Att1(g=x,x=x1) 338 | x = torch.cat((x, x1), dim=1) # 128 * 128 * 128 339 | x = self.layer1(x) # 64 * 128 * 128 340 | x = self.noise1(x) 341 | x = self.layer1_up(x) # 32 * 256 * 256 342 | 343 | x = self.noise0(x) 344 | x = self.layer0(x) # 3 * 256 * 256 345 | x = self.tanh(x) 346 | 347 | return x 348 | 349 | self.encoder = Encoder() 350 | self.decoder = Decoder() 351 | if pretrained: 352 | print('Loading pretrained {0} model...'.format('Sketch2Color'), end=' ') 353 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 354 | checkpoint = torch.load('./checkpoint/edge2color/ckpt.pth') 355 | self.load_state_dict(checkpoint['netG'], strict=True) 356 | print("Done!") 357 | else: 358 | self.apply(weights_init) 359 | print('Weights of {0} model are initialized'.format('Sketch2Color')) 360 | 361 | def forward(self, inputs): 362 | encode = self.encoder(inputs) 363 | output = self.decoder(encode) 364 | 365 | return output 366 | 367 | class Discriminator(nn.Module): 368 | def __init__(self, nc=6, pretrained=False): 369 | super(Discriminator, self).__init__() 370 | self.conv1 = torch.nn.utils.spectral_norm(nn.Conv2d(nc, 64, kernel_size=4, stride=2, padding=1)) 371 | self.bn1 = nn.GroupNorm(32, 64) 372 | self.conv2 = torch.nn.utils.spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)) 373 | self.bn2 = nn.GroupNorm(32,128) 374 | self.conv3 = torch.nn.utils.spectral_norm(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)) 375 | self.bn3 = nn.GroupNorm(32, 256) 376 | self.conv4 = torch.nn.utils.spectral_norm(nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1)) 377 | self.bn4 = nn.GroupNorm(32, 512) 378 | self.conv5 = torch.nn.utils.spectral_norm(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)) 379 | self.activation = nn.LeakyReLU(0.2, inplace=True) 380 | self.sigmoid = nn.Sigmoid() 381 | 382 | if pretrained: 383 | pass 384 | else: 385 | self.apply(weights_init) 386 | print('Weights of {0} model are initialized'.format('Discriminator')) 387 | 388 | def forward(self, base, unknown): 389 | input = torch.cat((base, unknown), dim=1) 390 | x = self.activation(self.conv1(input)) 391 | x = self.activation(self.bn2(self.conv2(x))) 392 | x = self.activation(self.bn3(self.conv3(x))) 393 | x = self.activation(self.bn4(self.conv4(x))) 394 | x = self.sigmoid(self.conv5(x)) 395 | 396 | return x.mean((2,3)) 397 | 398 | # To initialize model weights 399 | def weights_init(model): 400 | classname = model.__class__.__name__ 401 | if classname.find('Conv') != -1: 402 | 
nn.init.normal_(model.weight.data, 0.0, 0.02) 403 | elif classname.find('Conv2d_WS') != -1: 404 | nn.init.normal_(model.weight.data, 0.0, 0.02) 405 | elif classname.find('BatchNorm') != -1: 406 | nn.init.normal_(model.weight.data, 1.0, 0.02) 407 | nn.init.constant_(model.bias.data, 0) 408 | elif classname.find('GroupNorm') != -1: 409 | nn.init.normal_(model.weight.data, 1.0, 0.02) 410 | nn.init.constant_(model.bias.data, 0) 411 | else: 412 | pass --------------------------------------------------------------------------------
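Below is a minimal smoke test showing how the pieces above fit together: `PairImageFolder` (from `dataloader.py`) uses a `Color2Sketch` network to synthesize an edge image and `color_cluster` to extract a reference palette, while `Sketch2Color` (from `mymodels.py`) maps a 3x256x256 sketch tensor to a 3x256x256 color image in [-1, 1]. This is only a shape-level sanity check, not the training or test procedure from `train.ipynb`/`test.ipynb` (which are not included here); the dataset path `./data/train`, the class-subfolder layout required by torchvision's `ImageFolder`, and the use of `opencv_transforms.transforms.Resize` in place of the training-time `RandomResizedCrop(256)` are all assumptions.

```python
# Illustrative smoke test for dataloader.py and mymodels.py.
# Assumptions: the Kaggle pair images live under ./data/train/<class>/ as
# torchvision's ImageFolder expects, opencv_transforms (reference [2]) is
# installed, and untrained networks are used, so only shapes are meaningful.
import torch
from opencv_transforms import transforms
from dataloader import PairImageFolder
from mymodels import Color2Sketch, Sketch2Color

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Color -> sketch network; the dataset uses it to synthesize edge images on the fly.
netC2S = Color2Sketch(nc=3, pretrained=False).eval().to(device)
# Sketch -> color generator; set pretrained=True once ./checkpoint/edge2color/ckpt.pth exists.
netS2C = Sketch2Color(nc=3, pretrained=False).eval().to(device)

# Deterministic stand-in for the RandomResizedCrop(256) used during training.
transform = transforms.Compose([transforms.Resize((256, 256))])
dataset = PairImageFolder('./data/train', transform, netC2S, ncluster=9)

# Each item is (edge image, color image, list of ncluster palette images).
img_edge, img, palette = dataset[0]
print(img_edge.shape, img.shape, len(palette))  # [3, 256, 256], [3, 256, 256], 9

# Sanity-check a generator forward pass on the edge image.
with torch.no_grad():
    fake = netS2C(img_edge.unsqueeze(0).to(device))
print(fake.shape)  # [1, 3, 256, 256]; tanh output in [-1, 1]
```

The actual training loop, loss weighting (lambda1, lambda2, lambda3), and evaluation procedure live in `train.ipynb` and `test.ipynb`, which are not part of this dump.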