├── Model.png
├── Model_details.pdf
├── outputs
│   └── sample
│       ├── img-2.jpg
│       ├── img-43.jpg
│       ├── img-79.jpg
│       ├── Example.png
│       ├── img-219.jpg
│       ├── img-2417.jpg
│       ├── img-2584.jpg
│       ├── img-2796.jpg
│       ├── img-3050.jpg
│       ├── img-4202.jpg
│       ├── img-4515.jpg
│       ├── img-7038.jpg
│       └── img-7159.jpg
├── README.md
├── dataloader.py
└── mymodels.py
/Model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/Model.png
--------------------------------------------------------------------------------
/Model_details.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/Model_details.pdf
--------------------------------------------------------------------------------
/outputs/sample/img-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-2.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-43.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-43.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-79.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-79.jpg
--------------------------------------------------------------------------------
/outputs/sample/Example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/Example.png
--------------------------------------------------------------------------------
/outputs/sample/img-219.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-219.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-2417.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-2417.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-2584.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-2584.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-2796.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-2796.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-3050.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-3050.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-4202.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-4202.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-4515.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-4515.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-7038.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-7038.jpg
--------------------------------------------------------------------------------
/outputs/sample/img-7159.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta6189/Anime-Sketch-Colorizer/HEAD/outputs/sample/img-7159.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Anime-Sketch-Colorizer
2 |
3 | Automatic Sketch Colorization with a reference image
4 |
5 | Prerequisites
6 | ------
7 |
8 | `pytorch`
9 |
10 | `torchvision`
11 |
12 | `numpy`
13 |
14 | `opencv-python` (`cv2`)
15 |
16 | `matplotlib`
17 |
18 | Dataset
19 | ------
20 |
21 | Taebum Kim, "Anime Sketch Colorization Pair", https://www.kaggle.com/ktaebum/anime-sketch-colorization-pair
22 |
23 | Train
24 | ------
25 |
26 | Please refer to `train.ipynb`.
27 |
28 | Test
29 | ------
30 |
31 | Please refer to `test.ipynb`.
32 |
33 | * You can download the pretrained checkpoint from https://drive.google.com/open?id=1pIZCjubtyOUr7AXtGQMvzcbKczJ9CtQG (449 MB); a minimal loading sketch is shown below.
34 |
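A minimal loading sketch, assuming the downloaded checkpoint is extracted so that `./checkpoint/color2edge/ckpt.pth` and `./checkpoint/edge2color/ckpt.pth` exist (the layout `mymodels.py` expects) and that the default `nc=3` matches the saved weights; the exact constructor arguments and the full colorization pipeline are in `test.ipynb`:

```python
import torch
from mymodels import Color2Sketch, Sketch2Color

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# pretrained=True makes each network load its checkpoint from ./checkpoint/...
netC2S = Color2Sketch(nc=3, pretrained=True).to(device).eval()  # color image -> sketch
netS2C = Sketch2Color(nc=3, pretrained=True).to(device).eval()  # sketch -> color image
```
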
35 | Training details
36 | ------
37 |
38 | | Parameter | Value |
39 | |:--------|:--------:|
40 | | Learning rate | 2e-4 |
41 | | Batch size | 2 |
42 | | Epochs | 25 |
43 | | Optimizer | Adam |
44 | | (beta1, beta2) | (0.5, 0.999) |
45 | | (lambda1, lambda2, lambda3) | (100, 1e-4, 1e-2) |
46 | | Data augmentation | RandomResizedCrop(256)<br>RandomHorizontalFlip() |
47 | | Hardware | CPU: Intel i5-8400<br>RAM: 16 GB<br>GPU: NVIDIA GTX 1060 6 GB |
48 | | Training time | About 0.93 s per iteration<br>(about 45 hours for 25 epochs) |
49 |
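The optimizer settings above can be set up as in the sketch below; how lambda1, lambda2 and lambda3 weight the individual loss terms is defined in `train.ipynb` and is not reproduced here, and the `nc` values are the module defaults from `mymodels.py`:

```python
import torch
from mymodels import Sketch2Color, Discriminator

netG = Sketch2Color(nc=3)   # generator
netD = Discriminator(nc=6)  # discriminator over two concatenated 3-channel images

# Hyperparameters from the table above.
lr, betas = 2e-4, (0.5, 0.999)
lambda1, lambda2, lambda3 = 100, 1e-4, 1e-2  # loss weights, applied in train.ipynb

optimizer_G = torch.optim.Adam(netG.parameters(), lr=lr, betas=betas)
optimizer_D = torch.optim.Adam(netD.parameters(), lr=lr, betas=betas)
```
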
50 | Model
51 | ------
52 |
53 | 
54 |
55 | For more details, please refer to `Model_details.pdf`.
56 |
57 | Results
58 | -----
59 |
60 | Reference / Sketch / Colorization Result / Ground Truth
61 |
62 | 
63 | 
64 | 
65 | 
66 | 
67 | 
68 | 
69 | 
70 | 
71 | 
72 | 
73 | 
74 |
75 | References
76 | ------
77 |
78 | [1] Taebum Kim, "Anime Sketch Colorization Pair", https://www.kaggle.com/ktaebum/anime-sketch-colorization-pair, 2019 (accessed 2020.1.13).
79 |
80 | [2] Jim Bohnslav, "opencv_transforms", https://github.com/jbohnslav/opencv_transforms (accessed 2020.1.13).
81 |
82 | [3] Takeru Miyato et al., "Spectral Normalization for Generative Adversarial Networks", ICLR 2018, 2018.2.18.
83 |
84 | [4] Ozan Oktay et al., "Attention U-Net: Learning Where to Look for the Pancreas", MIDL 2018, 2018.5.20.
85 |
86 | [5] Siyuan Qiao et al., "Weight Standardization", https://arxiv.org/abs/1903.10520, 2019.3.25 (accessed 2020.1.19).
87 |
88 | [6] Tero Karras, Samuli Laine, Timo Aila, "A Style-Based Generator Architecture for Generative Adversarial Networks", https://arxiv.org/abs/1812.04948, 2019.3.29 (accessed 2020.1.22).
89 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import torchvision
5 | import opencv_transforms.functional as FF
6 | from torchvision import datasets
7 | from PIL import Image
8 |
9 | def color_cluster(img, nclusters=9):
10 | """
11 | Apply K-means clustering to the input image
12 |
13 | Args:
14 | img: Numpy array which has shape of (H, W, C)
15 | nclusters: # of clusters (default = 9)
16 |
17 | Returns:
18 |         color_palette: list of 3D numpy arrays, each with the same shape as the input image.
19 |         e.g. if the input image has shape (256, 256, 3) and nclusters is 4, the returned color_palette is
20 |         [color1, color2, color3, color4], where each element is a (256, 256, 3) numpy array.
21 |
22 | Note:
23 |         The K-means clustering algorithm is computationally intensive.
24 |         Thus, before extracting dominant colors, the input image is resized to 0.25x of its original size.
25 | """
26 | img_size = img.shape
27 | small_img = cv2.resize(img, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)
28 | sample = small_img.reshape((-1, 3))
29 | sample = np.float32(sample)
30 | criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
31 | flags = cv2.KMEANS_PP_CENTERS
32 |
33 | _, _, centers = cv2.kmeans(sample, nclusters, None, criteria, 10, flags)
34 | centers = np.uint8(centers)
35 | color_palette = []
36 |
37 | for i in range(0, nclusters):
38 | dominant_color = np.zeros(img_size, dtype='uint8')
39 | dominant_color[:,:,:] = centers[i]
40 | color_palette.append(dominant_color)
41 |
42 | return color_palette
43 |
44 | class PairImageFolder(datasets.ImageFolder):
45 | """
46 | A generic data loader where the images are arranged in this way: ::
47 |
48 | root/dog/xxx.png
49 | root/dog/xxy.png
50 | root/dog/xxz.png
51 |
52 | root/cat/123.png
53 | root/cat/nsdf3.png
54 | root/cat/asd932_.png
55 |
56 |     This class works properly for paired images in the form [sketch, color_image].
57 |
58 | Args:
59 | root (string): Root directory path.
60 |         transform (callable, optional): A function/transform that takes in a PIL image
61 | and returns a transformed version. E.g, ``transforms.RandomCrop``
62 | target_transform (callable, optional): A function/transform that takes in the
63 | target and transforms it.
64 | loader (callable, optional): A function to load an image given its path.
65 |         is_valid_file (callable, optional): A function that takes the path of an image file
66 |             and checks whether the file is valid (used to filter out corrupt files).
67 | sketch_net: The network to convert color image to sketch image
68 | ncluster: Number of clusters when extracting color palette.
69 |
70 | Attributes:
71 | classes (list): List of the class names.
72 | class_to_idx (dict): Dict with items (class_name, class_index).
73 | imgs (list): List of (image path, class_index) tuples
74 |
75 | Getitem:
76 | img_edge: Edge image
77 | img: Color Image
78 |         color_palette: Extracted color palette
79 | """
80 | def __init__(self, root, transform, sketch_net, ncluster):
81 | super(PairImageFolder, self).__init__(root, transform)
82 | self.ncluster = ncluster
83 | self.sketch_net = sketch_net
84 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
85 |
86 | def __getitem__(self, index):
87 | path, label = self.imgs[index]
88 | img = self.loader(path)
89 | img = np.asarray(img)
90 |         img = img[:, 0:512, :]  # keep the left half (first 512 columns) of the side-by-side pair
91 | img = self.transform(img)
92 | color_palette = color_cluster(img, nclusters=self.ncluster)
93 | img = self.make_tensor(img)
94 |
95 | with torch.no_grad():
96 | img_edge = self.sketch_net(img.unsqueeze(0).to(self.device)).squeeze().permute(1,2,0).cpu().numpy()
97 | img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
98 | img_edge = FF.to_tensor(img_edge)
99 |
100 | for i in range(0, len(color_palette)):
101 | color = color_palette[i]
102 | color_palette[i] = self.make_tensor(color)
103 |
104 | return img_edge, img, color_palette
105 |
106 | def make_tensor(self, img):
107 | img = FF.to_tensor(img)
108 | img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
109 | return img
110 |
111 | class GetImageFolder(datasets.ImageFolder):
112 | """
113 | A generic data loader where the images are arranged in this way: ::
114 |
115 | root/dog/xxx.png
116 | root/dog/xxy.png
117 | root/dog/xxz.png
118 |
119 | root/cat/123.png
120 | root/cat/nsdf3.png
121 | root/cat/asd932_.png
122 |
123 | Args:
124 | root (string): Root directory path.
125 |         transform (callable, optional): A function/transform that takes in a PIL image
126 | and returns a transformed version. E.g, ``transforms.RandomCrop``
127 | target_transform (callable, optional): A function/transform that takes in the
128 | target and transforms it.
129 | loader (callable, optional): A function to load an image given its path.
130 |         is_valid_file (callable, optional): A function that takes the path of an image file
131 |             and checks whether the file is valid (used to filter out corrupt files).
132 | sketch_net: The network to convert color image to sketch image
133 | ncluster: Number of clusters when extracting color palette.
134 |
135 | Attributes:
136 | classes (list): List of the class names.
137 | class_to_idx (dict): Dict with items (class_name, class_index).
138 | imgs (list): List of (image path, class_index) tuples
139 |
140 | Getitem:
141 | img_edge: Edge image
142 | img: Color Image
143 |         color_palette: Extracted color palette
144 | """
145 | def __init__(self, root, transform, sketch_net, ncluster):
146 | super(GetImageFolder, self).__init__(root, transform)
147 | self.ncluster = ncluster
148 | self.sketch_net = sketch_net
149 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
150 |
151 | def __getitem__(self, index):
152 | path, label = self.imgs[index]
153 | img = self.loader(path)
154 | img = np.asarray(img)
155 | img = self.transform(img)
156 | color_palette = color_cluster(img, nclusters=self.ncluster)
157 | img = self.make_tensor(img)
158 |
159 | with torch.no_grad():
160 | img_edge = self.sketch_net(img.unsqueeze(0).to(self.device)).squeeze().permute(1,2,0).cpu().numpy()
161 | img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
162 | img_edge = FF.to_tensor(img_edge)
163 |
164 | for i in range(0, len(color_palette)):
165 | color = color_palette[i]
166 | color_palette[i] = self.make_tensor(color)
167 |
168 | return img_edge, img, color_palette
169 |
170 | def make_tensor(self, img):
171 | img = FF.to_tensor(img)
172 | img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
173 | return img
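174 |
175 | # A minimal usage sketch, assuming the Kaggle dataset from the README is extracted
176 | # under ./data/train (class subfolders as required by ImageFolder) and a pretrained
177 | # Color2Sketch network from mymodels.py is used as the sketch network; the exact
178 | # transform pipeline and paths live in train.ipynb.
179 | if __name__ == '__main__':
180 |     from opencv_transforms import transforms
181 |     from mymodels import Color2Sketch
182 |
183 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
184 |     transform = transforms.Compose([transforms.RandomResizedCrop(256),
185 |                                     transforms.RandomHorizontalFlip()])
186 |     sketch_net = Color2Sketch(nc=3, pretrained=True).to(device).eval()
187 |     dataset = PairImageFolder('./data/train', transform, sketch_net, ncluster=9)
188 |     img_edge, img, color_palette = dataset[0]  # edge tensor, color tensor, list of palette tensors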
--------------------------------------------------------------------------------
/mymodels.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 |
5 | __all__ = [
6 | 'Color2Sketch', 'Sketch2Color', 'Discriminator',
7 | ]
8 |
9 | class ApplyNoise(nn.Module):
10 | def __init__(self, channels):
11 | super().__init__()
12 | self.weight = nn.Parameter(torch.zeros(channels))
13 |
14 | def forward(self, x, noise=None):
15 | if noise is None:
16 | noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
17 | return x + self.weight.view(1, -1, 1, 1) * noise.to(x.device)
18 |
19 | class Conv2d_WS(nn.Conv2d):
20 | def __init__(self, in_chan, out_chan, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
21 | super().__init__(in_chan, out_chan, kernel_size, stride, padding, dilation, groups, bias)
22 |
23 | def forward(self, x):
24 | weight = self.weight
25 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True)
26 | weight = weight - weight_mean
27 | std = weight.view(weight.size(0), -1).std(dim=1).view(-1,1,1,1)+1e-5
28 | weight = weight / std.expand_as(weight)
29 | return torch.nn.functional.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
30 |
31 | class ResidualBlock(nn.Module):
32 | def __init__(self, in_channels, out_channels, stride=1, sample=None):
33 | super(ResidualBlock, self).__init__()
34 | self.ic = in_channels
35 | self.oc = out_channels
36 | self.conv1 = Conv2d_WS(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
37 | self.bn1 = nn.GroupNorm(32, out_channels)
38 | self.conv2 = Conv2d_WS(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
39 | self.bn2 = nn.GroupNorm(32, out_channels)
40 | self.convr = Conv2d_WS(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
41 | self.bnr = nn.GroupNorm(32, out_channels)
42 | self.relu = nn.ReLU(inplace=True)
43 | self.sample = sample
44 | if self.sample == 'down':
45 | self.sampling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
46 | elif self.sample == 'up':
47 | self.sampling = nn.Upsample(scale_factor=2, mode='nearest')
48 |
49 | def forward(self, x):
50 | if self.ic != self.oc:
51 | residual = self.convr(x)
52 | residual = self.bnr(residual)
53 | else:
54 | residual = x
55 | out = self.conv1(x)
56 | out = self.bn1(out)
57 | out = self.relu(out)
58 | out = self.conv2(out)
59 | out = self.bn2(out)
60 | out += residual
61 | out = self.relu(out)
62 | if self.sample is not None:
63 | out = self.sampling(out)
64 | return out
65 |
66 | class Attention_block(nn.Module):
67 | def __init__(self,F_g,F_l,F_int):
68 | super(Attention_block,self).__init__()
69 | self.W_g = nn.Sequential(
70 | Conv2d_WS(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
71 | nn.GroupNorm(32, F_int)
72 | )
73 |
74 | self.W_x = nn.Sequential(
75 | Conv2d_WS(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
76 | nn.GroupNorm(32, F_int)
77 | )
78 |
79 | self.psi = nn.Sequential(
80 | Conv2d_WS(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
81 | nn.InstanceNorm2d(1),
82 | nn.Sigmoid()
83 | )
84 |
85 | self.relu = nn.ReLU(inplace=True)
86 |
87 | def forward(self,g,x):
88 | g1 = self.W_g(g)
89 | x1 = self.W_x(x)
90 | psi = self.relu(g1+x1)
91 | psi = self.psi(psi)
92 |
93 | return x*psi
94 |
95 | class Color2Sketch(nn.Module):
96 | def __init__(self, nc=3, pretrained=False):
97 | super(Color2Sketch, self).__init__()
98 | class Encoder(nn.Module):
99 | def __init__(self):
100 | super(Encoder, self).__init__()
101 |                 # Encoder: a stack of downsampling residual blocks
102 | self.layer1 = ResidualBlock(nc, 64, sample='down')
103 | self.layer2 = ResidualBlock(64, 128, sample='down')
104 | self.layer3 = ResidualBlock(128, 256, sample='down')
105 | self.layer4 = ResidualBlock(256, 512, sample='down')
106 | self.layer5 = ResidualBlock(512, 512, sample='down')
107 | self.layer6 = ResidualBlock(512, 512, sample='down')
108 | self.layer7 = ResidualBlock(512, 512, sample='down')
109 |
110 | def forward(self, input_image):
111 |                 # Extract multi-scale features from the input image
112 | x0 = input_image # nc * 256 * 256
113 | x1 = self.layer1(x0) # 64 * 128 * 128
114 | x2 = self.layer2(x1) # 128 * 64 * 64
115 | x3 = self.layer3(x2) # 256 * 32 * 32
116 | x4 = self.layer4(x3) # 512 * 16 * 16
117 | x5 = self.layer5(x4) # 512 * 8 * 8
118 | x6 = self.layer6(x5) # 512 * 4 * 4
119 | x7 = self.layer7(x6) # 512 * 2 * 2
120 |
121 | return x1, x2, x3, x4, x5, x6, x7
122 |
123 | class Decoder(nn.Module):
124 | def __init__(self):
125 | super(Decoder, self).__init__()
126 | # Convolutional layers and upsampling
127 | self.noise7 = ApplyNoise(512)
128 | self.layer7_up = ResidualBlock(512, 512, sample='up')
129 |
130 | self.Att6 = Attention_block(F_g=512,F_l=512,F_int=256)
131 | self.layer6 = ResidualBlock(1024, 512, sample=None)
132 | self.noise6 = ApplyNoise(512)
133 | self.layer6_up = ResidualBlock(512, 512, sample='up')
134 |
135 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
136 | self.layer5 = ResidualBlock(1024, 512, sample=None)
137 | self.noise5 = ApplyNoise(512)
138 | self.layer5_up = ResidualBlock(512, 512, sample='up')
139 |
140 | self.Att4 = Attention_block(F_g=512,F_l=512,F_int=256)
141 | self.layer4 = ResidualBlock(1024, 512, sample=None)
142 | self.noise4 = ApplyNoise(512)
143 | self.layer4_up = ResidualBlock(512, 256, sample='up')
144 |
145 | self.Att3 = Attention_block(F_g=256,F_l=256,F_int=128)
146 | self.layer3 = ResidualBlock(512, 256, sample=None)
147 | self.noise3 = ApplyNoise(256)
148 | self.layer3_up = ResidualBlock(256, 128, sample='up')
149 |
150 | self.Att2 = Attention_block(F_g=128,F_l=128,F_int=64)
151 | self.layer2 = ResidualBlock(256, 128, sample=None)
152 | self.noise2 = ApplyNoise(128)
153 | self.layer2_up = ResidualBlock(128, 64, sample='up')
154 |
155 | self.Att1 = Attention_block(F_g=64,F_l=64,F_int=32)
156 | self.layer1 = ResidualBlock(128, 64, sample=None)
157 | self.noise1 = ApplyNoise(64)
158 | self.layer1_up = ResidualBlock(64, 32, sample='up')
159 |
160 | self.noise0 = ApplyNoise(32)
161 | self.layer0 = Conv2d_WS(32, 3, kernel_size=3, stride=1, padding=1)
162 | self.activation = nn.ReLU(inplace=True)
163 | self.tanh = nn.Tanh()
164 |
165 | def forward(self, midlevel_input): #, global_input):
166 | x1, x2, x3, x4, x5, x6, x7 = midlevel_input
167 |
168 | x = self.noise7(x7)
169 | x = self.layer7_up(x) # 512 * 4 * 4
170 |
171 | x6 = self.Att6(g=x,x=x6)
172 | x = torch.cat((x, x6), dim=1) # 1024 * 4 * 4
173 | x = self.layer6(x) # 512 * 4 * 4
174 | x = self.noise6(x)
175 | x = self.layer6_up(x) # 512 * 8 * 8
176 |
177 | x5 = self.Att5(g=x,x=x5)
178 | x = torch.cat((x, x5), dim=1) # 1024 * 8 * 8
179 | x = self.layer5(x) # 512 * 8 * 8
180 | x = self.noise5(x)
181 | x = self.layer5_up(x) # 512 * 16 * 16
182 |
183 | x4 = self.Att4(g=x,x=x4)
184 | x = torch.cat((x, x4), dim=1) # 1024 * 16 * 16
185 | x = self.layer4(x) # 512 * 16 * 16
186 | x = self.noise4(x)
187 | x = self.layer4_up(x) # 256 * 32 * 32
188 |
189 | x3 = self.Att3(g=x,x=x3)
190 | x = torch.cat((x, x3), dim=1) # 512 * 32 * 32
191 | x = self.layer3(x) # 256 * 32 * 32
192 | x = self.noise3(x)
193 | x = self.layer3_up(x) # 128 * 64 * 64
194 |
195 | x2 = self.Att2(g=x,x=x2)
196 | x = torch.cat((x, x2), dim=1) # 256 * 64 * 64
197 | x = self.layer2(x) # 128 * 64 * 64
198 | x = self.noise2(x)
199 | x = self.layer2_up(x) # 64 * 128 * 128
200 |
201 | x1 = self.Att1(g=x,x=x1)
202 | x = torch.cat((x, x1), dim=1) # 128 * 128 * 128
203 | x = self.layer1(x) # 64 * 128 * 128
204 | x = self.noise1(x)
205 | x = self.layer1_up(x) # 32 * 256 * 256
206 |
207 | x = self.noise0(x)
208 | x = self.layer0(x) # 3 * 256 * 256
209 | x = self.tanh(x)
210 |
211 | return x
212 |
213 | self.encoder = Encoder()
214 | self.decoder = Decoder()
215 | if pretrained:
216 | print('Loading pretrained {0} model...'.format('Color2Sketch'), end=' ')
217 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
218 | checkpoint = torch.load('./checkpoint/color2edge/ckpt.pth')
219 | self.load_state_dict(checkpoint['netG'], strict=True)
220 | print("Done!")
221 | else:
222 | self.apply(weights_init)
223 | print('Weights of {0} model are initialized'.format('Color2Sketch'))
224 |
225 | def forward(self, inputs):
226 | encode = self.encoder(inputs)
227 | output = self.decoder(encode)
228 |
229 | return output
230 |
231 | class Sketch2Color(nn.Module):
232 | def __init__(self, nc=3, pretrained=False):
233 | super(Sketch2Color, self).__init__()
234 | class Encoder(nn.Module):
235 | def __init__(self):
236 | super(Encoder, self).__init__()
237 |                 # Encoder: a stack of downsampling residual blocks
238 | self.layer1 = ResidualBlock(nc, 64, sample='down')
239 | self.layer2 = ResidualBlock(64, 128, sample='down')
240 | self.layer3 = ResidualBlock(128, 256, sample='down')
241 | self.layer4 = ResidualBlock(256, 512, sample='down')
242 | self.layer5 = ResidualBlock(512, 512, sample='down')
243 | self.layer6 = ResidualBlock(512, 512, sample='down')
244 | self.layer7 = ResidualBlock(512, 512, sample='down')
245 |
246 | def forward(self, input_image):
247 | # Pass input through ResNet-gray to extract features
248 | x0 = input_image # nc * 256 * 256
249 | x1 = self.layer1(x0) # 64 * 128 * 128
250 | x2 = self.layer2(x1) # 128 * 64 * 64
251 | x3 = self.layer3(x2) # 256 * 32 * 32
252 | x4 = self.layer4(x3) # 512 * 16 * 16
253 | x5 = self.layer5(x4) # 512 * 8 * 8
254 | x6 = self.layer6(x5) # 512 * 4 * 4
255 | x7 = self.layer7(x6) # 512 * 2 * 2
256 |
257 | return x1, x2, x3, x4, x5, x6, x7
258 |
259 | class Decoder(nn.Module):
260 | def __init__(self):
261 | super(Decoder, self).__init__()
262 | # Convolutional layers and upsampling
263 | self.noise7 = ApplyNoise(512)
264 | self.layer7_up = ResidualBlock(512, 512, sample='up')
265 |
266 | self.Att6 = Attention_block(F_g=512,F_l=512,F_int=256)
267 | self.layer6 = ResidualBlock(1024, 512, sample=None)
268 | self.noise6 = ApplyNoise(512)
269 | self.layer6_up = ResidualBlock(512, 512, sample='up')
270 |
271 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
272 | self.layer5 = ResidualBlock(1024, 512, sample=None)
273 | self.noise5 = ApplyNoise(512)
274 | self.layer5_up = ResidualBlock(512, 512, sample='up')
275 |
276 | self.Att4 = Attention_block(F_g=512,F_l=512,F_int=256)
277 | self.layer4 = ResidualBlock(1024, 512, sample=None)
278 | self.noise4 = ApplyNoise(512)
279 | self.layer4_up = ResidualBlock(512, 256, sample='up')
280 |
281 | self.Att3 = Attention_block(F_g=256,F_l=256,F_int=128)
282 | self.layer3 = ResidualBlock(512, 256, sample=None)
283 | self.noise3 = ApplyNoise(256)
284 | self.layer3_up = ResidualBlock(256, 128, sample='up')
285 |
286 | self.Att2 = Attention_block(F_g=128,F_l=128,F_int=64)
287 | self.layer2 = ResidualBlock(256, 128, sample=None)
288 | self.noise2 = ApplyNoise(128)
289 | self.layer2_up = ResidualBlock(128, 64, sample='up')
290 |
291 | self.Att1 = Attention_block(F_g=64,F_l=64,F_int=32)
292 | self.layer1 = ResidualBlock(128, 64, sample=None)
293 | self.noise1 = ApplyNoise(64)
294 | self.layer1_up = ResidualBlock(64, 32, sample='up')
295 |
296 | self.noise0 = ApplyNoise(32)
297 | self.layer0 = Conv2d_WS(32, 3, kernel_size=3, stride=1, padding=1)
298 | self.activation = nn.ReLU(inplace=True)
299 | self.tanh = nn.Tanh()
300 |
301 | def forward(self, midlevel_input): #, global_input):
302 | x1, x2, x3, x4, x5, x6, x7 = midlevel_input
303 |
304 | x = self.noise7(x7)
305 | x = self.layer7_up(x) # 512 * 4 * 4
306 |
307 | x6 = self.Att6(g=x,x=x6)
308 | x = torch.cat((x, x6), dim=1) # 1024 * 4 * 4
309 | x = self.layer6(x) # 512 * 4 * 4
310 | x = self.noise6(x)
311 | x = self.layer6_up(x) # 512 * 8 * 8
312 |
313 | x5 = self.Att5(g=x,x=x5)
314 | x = torch.cat((x, x5), dim=1) # 1024 * 8 * 8
315 | x = self.layer5(x) # 512 * 8 * 8
316 | x = self.noise5(x)
317 | x = self.layer5_up(x) # 512 * 16 * 16
318 |
319 | x4 = self.Att4(g=x,x=x4)
320 | x = torch.cat((x, x4), dim=1) # 1024 * 16 * 16
321 | x = self.layer4(x) # 512 * 16 * 16
322 | x = self.noise4(x)
323 | x = self.layer4_up(x) # 256 * 32 * 32
324 |
325 | x3 = self.Att3(g=x,x=x3)
326 | x = torch.cat((x, x3), dim=1) # 512 * 32 * 32
327 | x = self.layer3(x) # 256 * 32 * 32
328 | x = self.noise3(x)
329 | x = self.layer3_up(x) # 128 * 64 * 64
330 |
331 | x2 = self.Att2(g=x,x=x2)
332 | x = torch.cat((x, x2), dim=1) # 256 * 64 * 64
333 | x = self.layer2(x) # 128 * 64 * 64
334 | x = self.noise2(x)
335 | x = self.layer2_up(x) # 64 * 128 * 128
336 |
337 | x1 = self.Att1(g=x,x=x1)
338 | x = torch.cat((x, x1), dim=1) # 128 * 128 * 128
339 | x = self.layer1(x) # 64 * 128 * 128
340 | x = self.noise1(x)
341 | x = self.layer1_up(x) # 32 * 256 * 256
342 |
343 | x = self.noise0(x)
344 | x = self.layer0(x) # 3 * 256 * 256
345 | x = self.tanh(x)
346 |
347 | return x
348 |
349 | self.encoder = Encoder()
350 | self.decoder = Decoder()
351 | if pretrained:
352 | print('Loading pretrained {0} model...'.format('Sketch2Color'), end=' ')
353 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
354 | checkpoint = torch.load('./checkpoint/edge2color/ckpt.pth')
355 | self.load_state_dict(checkpoint['netG'], strict=True)
356 | print("Done!")
357 | else:
358 | self.apply(weights_init)
359 | print('Weights of {0} model are initialized'.format('Sketch2Color'))
360 |
361 | def forward(self, inputs):
362 | encode = self.encoder(inputs)
363 | output = self.decoder(encode)
364 |
365 | return output
366 |
367 | class Discriminator(nn.Module):
368 | def __init__(self, nc=6, pretrained=False):
369 | super(Discriminator, self).__init__()
370 | self.conv1 = torch.nn.utils.spectral_norm(nn.Conv2d(nc, 64, kernel_size=4, stride=2, padding=1))
371 | self.bn1 = nn.GroupNorm(32, 64)
372 | self.conv2 = torch.nn.utils.spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1))
373 | self.bn2 = nn.GroupNorm(32,128)
374 | self.conv3 = torch.nn.utils.spectral_norm(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1))
375 | self.bn3 = nn.GroupNorm(32, 256)
376 | self.conv4 = torch.nn.utils.spectral_norm(nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1))
377 | self.bn4 = nn.GroupNorm(32, 512)
378 | self.conv5 = torch.nn.utils.spectral_norm(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
379 | self.activation = nn.LeakyReLU(0.2, inplace=True)
380 | self.sigmoid = nn.Sigmoid()
381 |
382 | if pretrained:
383 | pass
384 | else:
385 | self.apply(weights_init)
386 | print('Weights of {0} model are initialized'.format('Discriminator'))
387 |
388 | def forward(self, base, unknown):
389 | input = torch.cat((base, unknown), dim=1)
390 | x = self.activation(self.conv1(input))
391 | x = self.activation(self.bn2(self.conv2(x)))
392 | x = self.activation(self.bn3(self.conv3(x)))
393 | x = self.activation(self.bn4(self.conv4(x)))
394 | x = self.sigmoid(self.conv5(x))
395 |
396 | return x.mean((2,3))
397 |
398 | # To initialize model weights
399 | def weights_init(model):
400 | classname = model.__class__.__name__
401 | if classname.find('Conv') != -1:
402 | nn.init.normal_(model.weight.data, 0.0, 0.02)
403 |     elif classname.find('Conv2d_WS') != -1:  # already matched by the 'Conv' branch above
404 | nn.init.normal_(model.weight.data, 0.0, 0.02)
405 | elif classname.find('BatchNorm') != -1:
406 | nn.init.normal_(model.weight.data, 1.0, 0.02)
407 | nn.init.constant_(model.bias.data, 0)
408 | elif classname.find('GroupNorm') != -1:
409 | nn.init.normal_(model.weight.data, 1.0, 0.02)
410 | nn.init.constant_(model.bias.data, 0)
411 | else:
412 | pass
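413 |
414 | # A minimal shape-check sketch, assuming 256x256 inputs as in the comments above;
415 | # the real training and inference loops live in train.ipynb and test.ipynb.
416 | if __name__ == '__main__':
417 |     netG = Sketch2Color(nc=3)
418 |     netD = Discriminator(nc=6)
419 |     sketch = torch.randn(1, 3, 256, 256)
420 |     color = netG(sketch)         # -> (1, 3, 256, 256), values in [-1, 1] from tanh
421 |     score = netD(sketch, color)  # -> (1, 1), spatial mean of the patch outputs
422 |     print(color.shape, score.shape)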
--------------------------------------------------------------------------------