├── .idea
│   ├── colorNet-pytorch.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── colornet.py
├── myimgfolder.py
├── pt1
│   ├── README.md
│   ├── colornet.py
│   ├── dataset.py
│   ├── test.py
│   └── train.py
├── readme images
│   ├── bad-result.png
│   ├── good-result.png
│   └── model.png
├── removegray.py
├── train.py
└── val.py
/README.md:
--------------------------------------------------------------------------------
# colorNet-pytorch
A Neural Network For Automatic Image Colorization

This project is a PyTorch version of [ColorNet](http://hi.cs.waseda.ac.jp/~iizuka/projects/colorization/en/), published at SIGGRAPH 2016. Please check out the original website for the details.

## Overview
* Net model
![model](readme%20images/model.png)

* DataSet
[MIT Places205](http://places.csail.mit.edu/user/index.php)
> Hint: Because the dataset contains some grayscale images, I wrote a script (`removegray.py`) to remove them.

* Development Environment
Python 3.5.1
CUDA 8.0
## Result
I only trained this model for 3 epochs, versus the 11 epochs used in the paper, so it should work better if you train it longer.

* Good results
![good results](readme%20images/good-result.png)
* Bad results
![bad results](readme%20images/bad-result.png)

Because this network was trained on a landscape image database, it works well on scenery pictures. If you use it to colorize other kinds of images, you may not get a satisfying output.
## Pretrained model
You can download the model from https://drive.google.com/file/d/0B6WuMuYfgb4XblE4c3N2RUJQcFU/view?usp=sharing
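A minimal loading sketch (assuming the downloaded weights are saved as `colornet_params.pkl` in the repo root, the file name `val.py` loads; `map_location` lets it run without a GPU):

```python
import torch
from colornet import ColorNet

model = ColorNet()
model.load_state_dict(torch.load('colornet_params.pkl', map_location='cpu'))
model.eval()  # see val.py for the full colorization pipeline
```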
## Todo
Implement with PyTorch 1.0 (see the `pt1` folder)
--------------------------------------------------------------------------------
/colornet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch


class LowLevelFeatNet(nn.Module):
    def __init__(self):
        super(LowLevelFeatNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(512)

    def forward(self, x1, x2):
        x1 = F.relu(self.bn1(self.conv1(x1)))
        x1 = F.relu(self.bn2(self.conv2(x1)))
        x1 = F.relu(self.bn3(self.conv3(x1)))
        x1 = F.relu(self.bn4(self.conv4(x1)))
        x1 = F.relu(self.bn5(self.conv5(x1)))
        x1 = F.relu(self.bn6(self.conv6(x1)))
        if self.training:
            # during training both branches see the same 224x224 crop,
            # so the global branch can reuse the low-level features directly
            x2 = x1.clone()
        else:
            x2 = F.relu(self.bn1(self.conv1(x2)))
            x2 = F.relu(self.bn2(self.conv2(x2)))
            x2 = F.relu(self.bn3(self.conv3(x2)))
            x2 = F.relu(self.bn4(self.conv4(x2)))
            x2 = F.relu(self.bn5(self.conv5(x2)))
            x2 = F.relu(self.bn6(self.conv6(x2)))
        return x1, x2


class MidLevelFeatNet(nn.Module):
    def __init__(self):
        super(MidLevelFeatNet, self).__init__()
        self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.conv2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(256)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        return x


class GlobalFeatNet(nn.Module):
    def __init__(self):
        super(GlobalFeatNet, self).__init__()
        self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(512)
        self.conv3 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(512)
        self.conv4 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.fc1 = nn.Linear(25088, 1024)  # 25088 = 512 * 7 * 7 for a 224x224 input
        self.bn5 = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 512)
        self.bn6 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = x.view(-1, 25088)
        x = F.relu(self.bn5(self.fc1(x)))
        output_512 = F.relu(self.bn6(self.fc2(x)))
        output_256 = F.relu(self.bn7(self.fc3(output_512)))
        return output_512, output_256


class ClassificationNet(nn.Module):
    def __init__(self):
        super(ClassificationNet, self).__init__()
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 205)  # 205 scene classes in Places205
        self.bn2 = nn.BatchNorm1d(205)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.log_softmax(self.bn2(self.fc2(x)), dim=1)
        return x


class ColorizationNet(nn.Module):
    def __init__(self):
        super(ColorizationNet, self).__init__()
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.conv5 = nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)
        self.upsample = nn.UpsamplingNearest2d(scale_factor=2)

    def forward(self, mid_input, global_input):
        w = mid_input.size()[2]
        h = mid_input.size()[3]
        # fusion layer: broadcast the 256-d global feature vector to every
        # spatial position of the mid-level feature map, then mix the
        # concatenated 512-d vectors with a linear layer shared across positions
        global_input = global_input.unsqueeze(2).unsqueeze(2).expand_as(mid_input)
        fusion_layer = torch.cat((mid_input, global_input), 1)
        fusion_layer = fusion_layer.permute(2, 3, 0, 1).contiguous()
        fusion_layer = fusion_layer.view(-1, 512)
        fusion_layer = self.bn1(self.fc1(fusion_layer))
        fusion_layer = fusion_layer.view(w, h, -1, 256)

        x = fusion_layer.permute(2, 3, 0, 1).contiguous()
        x = F.relu(self.bn2(self.conv1(x)))
        x = self.upsample(x)
        x = F.relu(self.bn3(self.conv2(x)))
        x = F.relu(self.bn4(self.conv3(x)))
        x = self.upsample(x)
        x = F.sigmoid(self.bn5(self.conv4(x)))
        x = self.upsample(self.conv5(x))
        return x


class ColorNet(nn.Module):
    def __init__(self):
        super(ColorNet, self).__init__()
        self.low_lv_feat_net = LowLevelFeatNet()
        self.mid_lv_feat_net = MidLevelFeatNet()
        self.global_feat_net = GlobalFeatNet()
        self.class_net = ClassificationNet()
        self.upsample_col_net = ColorizationNet()

    def forward(self, x1, x2):
        x1, x2 = self.low_lv_feat_net(x1, x2)
        #print('after low_lv, mid_input is:{}, global_input is:{}'.format(x1.size(), x2.size()))
        x1 = self.mid_lv_feat_net(x1)
        #print('after mid_lv, mid2fusion_input is:{}'.format(x1.size()))
        class_input, x2 = self.global_feat_net(x2)
        #print('after global_lv, class_input is:{}, global2fusion_input is:{}'.format(class_input.size(), x2.size()))
        class_output = self.class_net(class_input)
        #print('after class_lv, class_output is:{}'.format(class_output.size()))
        output = self.upsample_col_net(x1, x2)
        #print('after upsample_lv, output is:{}'.format(output.size()))
        return class_output, output
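

if __name__ == '__main__':
    # Editor's sketch, not part of the original file: a quick shape check with
    # a random batch of 224x224 grayscale images (requires PyTorch >= 0.4 for
    # torch.no_grad). The global branch needs 224x224 input, since GlobalFeatNet
    # flattens its feature map to 512 * 7 * 7 = 25088.
    net = ColorNet()
    net.eval()
    x = torch.randn(2, 1, 224, 224)
    with torch.no_grad():
        class_out, ab_out = net(x, x)
    print(class_out.size())  # (2, 205) scene log-probabilities
    print(ab_out.size())     # (2, 2, 224, 224) predicted ab channels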
--------------------------------------------------------------------------------
/myimgfolder.py:
--------------------------------------------------------------------------------
from torchvision import datasets, transforms
from skimage.color import rgb2lab, rgb2gray
import torch
import numpy as np
#import matplotlib.pyplot as plt

scale_transform = transforms.Compose([
    transforms.Scale(256),  # transforms.Scale is the pre-0.3 torchvision name of transforms.Resize
    transforms.RandomCrop(224),
    #transforms.ToTensor()
])


class TrainImageFolder(datasets.ImageFolder):
    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img_original = self.transform(img)
            img_original = np.asarray(img_original)

            img_lab = rgb2lab(img_original)
            img_lab = (img_lab + 128) / 255  # shift the ab channels (and L) into [0, 1]
            img_ab = img_lab[:, :, 1:3]
            img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1)))
            img_original = rgb2gray(img_original)
            img_original = torch.from_numpy(img_original)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return (img_original, img_ab), target


class ValImageFolder(datasets.ImageFolder):
    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)

        img_scale = img.copy()
        img_original = img
        img_scale = scale_transform(img_scale)

        img_scale = np.asarray(img_scale)
        img_original = np.asarray(img_original)

        img_scale = rgb2gray(img_scale)
        img_scale = torch.from_numpy(img_scale)
        img_original = rgb2gray(img_original)
        img_original = torch.from_numpy(img_original)
        return (img_original, img_scale), target
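
if __name__ == '__main__':
    # Editor's sketch, not part of the original file: round-trip check for the
    # normalization used above. The ab channels are encoded as (ab + 128) / 255,
    # and val.py decodes them with ab * 255 - 128. The network's grayscale input
    # is rgb2gray output in [0, 1], which val.py rescales to L in [0, 100] with * 100.
    rgb = np.random.rand(8, 8, 3)
    ab = rgb2lab(rgb)[:, :, 1:3]
    encoded = (ab + 128) / 255
    assert np.allclose(encoded * 255 - 128, ab)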
--------------------------------------------------------------------------------
/pt1/README.md:
--------------------------------------------------------------------------------
# colorNet-pytorch V1.0

## training data organization
Put all the training images in the `images/` subfolder of the dataset root (`root_dir` in dataset.py), and prepare a `labels/train.txt` under the same root; a hypothetical example is shown below.
> the format of train.txt should be:
imagename1 label1
imagename2 label2
.....
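For example, a `train.txt` might look like this (image names and labels made up for illustration):

    forest_00001.jpg 0
    beach_00042.jpg 1
    mountain_00117.jpg 2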
## test it
If you want to test this model, you also have to prepare a `labels/test.txt` under the dataset root.
Its format is like train.txt but without the labels: each line is just an image name (see the test branch of dataset.py).
test.py saves the grayscale inputs to `./grayimg/` and the colorized outputs to `./colorimg/`, so create those folders before running it.
--------------------------------------------------------------------------------
/pt1/colornet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch


class LowLevelFeatNet(nn.Module):
    def __init__(self):
        super(LowLevelFeatNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(512)

    def forward(self, x1, x2):
        x1 = F.relu(self.bn1(self.conv1(x1)))
        x1 = F.relu(self.bn2(self.conv2(x1)))
        x1 = F.relu(self.bn3(self.conv3(x1)))
        x1 = F.relu(self.bn4(self.conv4(x1)))
        x1 = F.relu(self.bn5(self.conv5(x1)))
        x1 = F.relu(self.bn6(self.conv6(x1)))
        if self.training:
            x2 = x1.clone()
        else:
            x2 = F.relu(self.bn1(self.conv1(x2)))
            x2 = F.relu(self.bn2(self.conv2(x2)))
            x2 = F.relu(self.bn3(self.conv3(x2)))
            x2 = F.relu(self.bn4(self.conv4(x2)))
            x2 = F.relu(self.bn5(self.conv5(x2)))
            x2 = F.relu(self.bn6(self.conv6(x2)))
        return x1, x2


class MidLevelFeatNet(nn.Module):
    def __init__(self):
        super(MidLevelFeatNet, self).__init__()
        self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.conv2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(256)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        return x


class GlobalFeatNet(nn.Module):
    def __init__(self):
        super(GlobalFeatNet, self).__init__()
        self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(512)
        self.conv3 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(512)
        self.conv4 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.fc1 = nn.Linear(25088, 1024)  # 25088 = 512 * 7 * 7 for a 224x224 input
        self.bn5 = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 512)
        self.bn6 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = x.view(-1, 25088)
        x = F.relu(self.bn5(self.fc1(x)))
        output_512 = F.relu(self.bn6(self.fc2(x)))
        output_256 = F.relu(self.bn7(self.fc3(output_512)))
        return output_512, output_256


class ClassificationNet(nn.Module):
    def __init__(self):
        super(ClassificationNet, self).__init__()
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 205)
        self.bn2 = nn.BatchNorm1d(205)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.softmax(self.bn2(self.fc2(x)), dim=1)
        return x


class ColorizationNet(nn.Module):
    def __init__(self):
        super(ColorizationNet, self).__init__()
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.conv5 = nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)
        self.upsample = nn.UpsamplingNearest2d(scale_factor=2)

    def forward(self, mid_input, global_input):
        w = mid_input.size()[2]
        h = mid_input.size()[3]
        global_input = global_input.unsqueeze(2).unsqueeze(2).expand_as(mid_input)
        fusion_layer = torch.cat((mid_input, global_input), 1)
        fusion_layer = fusion_layer.permute(2, 3, 0, 1).contiguous()
        fusion_layer = fusion_layer.view(-1, 512)
        fusion_layer = self.bn1(self.fc1(fusion_layer))
        fusion_layer = fusion_layer.view(w, h, -1, 256)

        x = fusion_layer.permute(2, 3, 0, 1).contiguous()
        x = F.relu(self.bn2(self.conv1(x)))
        x = self.upsample(x)
        x = F.relu(self.bn3(self.conv2(x)))
        x = F.relu(self.bn4(self.conv3(x)))
        x = self.upsample(x)
        x = F.sigmoid(self.bn5(self.conv4(x)))
        x = self.upsample(self.conv5(x))
        return x


class ColorNet(nn.Module):
    def __init__(self):
        super(ColorNet, self).__init__()
        self.low_lv_feat_net = LowLevelFeatNet()
        self.mid_lv_feat_net = MidLevelFeatNet()
        self.global_feat_net = GlobalFeatNet()
        self.class_net = ClassificationNet()
        self.upsample_col_net = ColorizationNet()

    def forward(self, x1, x2):
        x1, x2 = self.low_lv_feat_net(x1, x2)
        #print('after low_lv, mid_input is:{}, global_input is:{}'.format(x1.size(), x2.size()))
        x1 = self.mid_lv_feat_net(x1)
        #print('after mid_lv, mid2fusion_input is:{}'.format(x1.size()))
        class_input, x2 = self.global_feat_net(x2)
        #print('after global_lv, class_input is:{}, global2fusion_input is:{}'.format(class_input.size(), x2.size()))
        class_output = self.class_net(class_input)
        #print('after class_lv, class_output is:{}'.format(class_output.size()))
        output = self.upsample_col_net(x1, x2)
        #print('after upsample_lv, output is:{}'.format(output.size()))
        return class_output, output
--------------------------------------------------------------------------------
/pt1/dataset.py:
--------------------------------------------------------------------------------
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from skimage.color import rgb2lab, rgb2gray

data_augmentation = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(256),
    transforms.Resize(224)
])


class ColorDataset(Dataset):
    def __init__(self, phase):
        assert (phase in ['train', 'val', 'test'])
        self.phase = phase
        self.root_dir = '/home/wsf/Pictures/Wallpapers'
        self.samples = None
        with open('{}/labels/{}.txt'.format(self.root_dir, phase), 'r') as f:
            self.samples = f.readlines()[:3]  # NOTE: [:3] keeps only the first 3 lines, apparently a debugging leftover

        print('[+] dataset `{}` loaded {} images'.format(self.phase, len(self.samples)))

    def __getitem__(self, idx):
        if self.phase == 'train' or self.phase == 'val':
            image_path, label = self.samples[idx].strip().split()
            label = np.array(int(label))
            image = Image.open('{}/images/{}'.format(self.root_dir, image_path)).convert('RGB')
            image = data_augmentation(image)
            image = np.asarray(image)
            img_lab = rgb2lab(image)
            img_lab = (img_lab + 128) / 255  # shift the ab channels (and L) into [0, 1]
            img_ab = img_lab[:, :, 1:3].astype('float32')
            img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1)))
            img_gray = rgb2gray(image).astype('float32')
            img_gray = img_gray[np.newaxis, :]
            img_gray = torch.from_numpy(img_gray)
            return (img_gray, img_ab), label

        else:
            # test samples carry no label: each line of test.txt is just an image name
            image_path = self.samples[idx].strip()
            img_gray = Image.open('{}/images/{}'.format(self.root_dir, image_path)).convert('L')
            img_gray_scale = img_gray.copy()
            img_gray_scale = img_gray_scale.resize((224, 224))
            img_gray = transforms.ToTensor()(img_gray)
            img_gray_scale = transforms.ToTensor()(img_gray_scale)
            return img_gray, img_gray_scale

    def __len__(self):
        return len(self.samples)


if __name__ == '__main__':
    pass
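# Editor's sketch, not in the original file: minimal usage, assuming root_dir
# above points at a folder containing images/ and labels/ as described in the README:
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(ColorDataset('train'), batch_size=32, shuffle=True)
#     (img_gray, img_ab), label = next(iter(loader))
#     # img_gray: (B, 1, 224, 224) grayscale input; img_ab: (B, 2, 224, 224) ab targets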
--------------------------------------------------------------------------------
/pt1/test.py:
--------------------------------------------------------------------------------
import torch
from torchvision.utils import make_grid, save_image
from skimage.color import lab2rgb
import numpy as np
import matplotlib.pyplot as plt
from colornet import ColorNet
from pt1.dataset import ColorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BZ = 1
test_set = ColorDataset('test')
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BZ, shuffle=False, num_workers=4)
color_model = ColorNet()
color_model.load_state_dict(torch.load('/home/wsf/colornet_params.pkl'))
color_model.to(device)


def test():
    color_model.eval()

    for idx, (imgs, imgs_scale) in enumerate(test_loader):
        imgs = imgs.to(device)
        imgs_scale = imgs_scale.to(device)
        gray_name = test_set.samples[idx].strip().split('/')[-1]
        for img in imgs:
            pic = img.cpu().squeeze().numpy()
            pic = pic.astype(np.float64)
            plt.imsave('./{}/{}'.format('grayimg', gray_name), pic, cmap='gray')
        w = imgs.size(2)
        h = imgs.size(3)

        _, output = color_model(imgs, imgs_scale)
        color_img = torch.cat((imgs, output[:, :, 0:w, 0:h]), 1)
        color_img = color_img.data.cpu().numpy().transpose((0, 2, 3, 1))
        for img in color_img:
            img[:, :, 0:1] = img[:, :, 0:1] * 100        # grayscale input in [0, 1] -> L in [0, 100]
            img[:, :, 1:3] = img[:, :, 1:3] * 255 - 128  # undo the (ab + 128) / 255 encoding
            img = img.astype(np.float64)
            img = lab2rgb(img)
            color_name = './colorimg/{}'.format(gray_name)
            plt.imsave(color_name, img)
        # using the following method doesn't produce the right image, but I don't know why
        # color_img = torch.from_numpy(color_img.transpose((0, 3, 1, 2)))
        # sprite_img = make_grid(color_img)
        # color_name = './colorimg/'+str(i)+'.jpg'
        # save_image(sprite_img, color_name)
        # i += 1


if __name__ == '__main__':
    test()
--------------------------------------------------------------------------------
/pt1/train.py:
--------------------------------------------------------------------------------
import os
import traceback

import torch
import torch.nn.functional as F
import torch.optim as optim
from pt1.dataset import ColorDataset
import numpy as np
from torch.utils.data import DataLoader
from colornet import ColorNet


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_set = ColorDataset('train')
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
color_model = ColorNet()
if os.path.exists('/home/wsf/colornet_params.pkl'):
    color_model.load_state_dict(torch.load('/home/wsf/colornet_params.pkl'))
color_model.to(device)
optimizer = optim.Adadelta(color_model.parameters())


def train():
    color_model.train()
    try:
        for epoch in range(20):
            for batch_idx, (data, label) in enumerate(train_loader):
                messagefile = open('./message.txt', 'a')
                img_gray, img_ab = data
                img_gray = img_gray.to(device)
                img_ab = img_ab.to(device)
                label = label.to(device)
                optimizer.zero_grad()
                class_output, output = color_model(img_gray, img_gray)
                # mean squared error over all predicted ab values
                ems_loss = torch.pow((img_ab - output), 2).sum() / torch.from_numpy(np.array(list(output.size()))).prod()
                # note: class_output has already been through softmax in ClassificationNet,
                # while F.cross_entropy expects raw logits, so this applies softmax twice
                cross_entropy_loss = 1 / 300 * F.cross_entropy(class_output, label)
                loss = ems_loss + cross_entropy_loss
                lossmsg = 'loss: %.9f\n' % (loss.item())
                messagefile.write(lossmsg)
                loss.backward()
                # ems_loss.backward(retain_variables=True)
                # cross_entropy_loss.backward()
                optimizer.step()
                if batch_idx % 500 == 0:
                    message = 'Train Epoch:%d\tPercent:[%d/%d (%.0f%%)]\tLoss:%.9f\n' % (
                        epoch, batch_idx * img_gray.size(0), len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss.item())
                    messagefile.write(message)
                    torch.save(color_model.state_dict(), 'colornet_params.pkl')
                messagefile.close()
                # print('Train Epoch: {}[{}/{}({:.0f}%)]\tLoss: {:.9f}\n'.format(
                #     epoch, batch_idx * len(data), len(train_loader.dataset),
                #     100. * batch_idx / len(train_loader), loss.data[0]))
    except Exception:
        logfile = open('log.txt', 'w')
        logfile.write(traceback.format_exc())
        logfile.close()
    finally:
        torch.save(color_model.state_dict(), 'colornet_params.pkl')


if __name__ == '__main__':
    train()
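
# Editor's addition, not in the original file: the objective above factored into
# a helper. F.mse_loss averages over all elements, which matches the hand-rolled
# sum / prod(size) expression in train().
def colorization_loss(output, img_ab, class_output, label, alpha=1 / 300):
    return F.mse_loss(output, img_ab) + alpha * F.cross_entropy(class_output, label)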
--------------------------------------------------------------------------------
/readme images/bad-result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shufanwu/colorNet-pytorch/6139a7f9de1a3627d4aa02992f477f78f0335068/readme images/bad-result.png
--------------------------------------------------------------------------------
/readme images/good-result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shufanwu/colorNet-pytorch/6139a7f9de1a3627d4aa02992f477f78f0335068/readme images/good-result.png
--------------------------------------------------------------------------------
/readme images/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shufanwu/colorNet-pytorch/6139a7f9de1a3627d4aa02992f477f78f0335068/readme images/model.png
--------------------------------------------------------------------------------
/removegray.py:
--------------------------------------------------------------------------------
import os.path
import shutil

import skimage.io as io
import numpy as np
from skimage.color import rgb2lab

rootdir = '../images256/'  # the folder to walk
newrootdir = '../datadelete/'

for parent, dirnames, filenames in os.walk(rootdir):  # yields the parent directory, the folder names (without path), and the file names
    # for dirname in dirnames:  # print folder info
    #     print("parent is:" + parent)
    #     print("dirname is:" + dirname)
    messagefile = open('./deletemessage.txt', 'a')
    messagefile.write(os.path.split(parent)[1])
    messagefile.close()

    for filename in filenames:  # process each file
        # print("parent is:" + parent)
        # print("filename is:" + filename)
        # print("the full name of the file is:" + os.path.join(parent, filename))  # print the full file path
        path = os.path.join(parent, filename)
        img = io.imread(path)

        # remove 1-channel grayscale images
        if len(img.shape) == 2:
            newpath = os.path.join(newrootdir, os.path.split(parent)[1])
            if not os.path.exists(newpath):
                os.makedirs(newpath)
            shutil.move(path, newpath)

        else:
            # cast to int so the channel differences below don't wrap around in uint8
            r = img[:, :, 0].astype(int)
            g = img[:, :, 1].astype(int)
            b = img[:, :, 2].astype(int)
            num = r.size
            num1 = img.size

            # remove 3-channel images that are effectively grayscale
            if np.sum(abs(r - g) < 30) / num > 0.9 and np.sum(abs(r - b) < 30) / num > 0.9:
                newpath = os.path.join(newrootdir, os.path.split(parent)[1])
                if not os.path.exists(newpath):
                    os.makedirs(newpath)
                shutil.move(path, newpath)

            else:
                # remove images with small color variance
                img = rgb2lab(img)
                ab = [img[:, :, 1], img[:, :, 2]]
                variance = np.sqrt(np.sum(np.power(ab[0] - np.sum(ab[0]) / num, 2)) / num) + \
                           np.sqrt(np.sum(np.power(ab[1] - np.sum(ab[1]) / num, 2)) / num)
                if variance < 6:
                    newpath = os.path.join(newrootdir, os.path.split(parent)[1])
                    if not os.path.exists(newpath):
                        os.makedirs(newpath)
                    shutil.move(path, newpath)
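# Editor's note, not in the original file: the hand-rolled "variance" above is
# simply the sum of the per-channel standard deviations, e.g. with NumPy:
#
#     colorfulness = img[:, :, 1].std() + img[:, :, 2].std()  # img already in Lab
#     if colorfulness < 6:
#         ...  # treat as near-grayscale and move it out of the dataset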
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import traceback

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms
import numpy as np

from myimgfolder import TrainImageFolder
from colornet import ColorNet

# note: this script uses the pre-0.4 PyTorch API (Variable, loss.data[0],
# retain_variables); see pt1/train.py for the PyTorch 1.0 version

original_transform = transforms.Compose([
    transforms.Scale(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    #transforms.ToTensor()
])

have_cuda = torch.cuda.is_available()
epochs = 1

data_dir = "../images256/"
train_set = TrainImageFolder(data_dir, original_transform)
train_set_size = len(train_set)
train_set_classes = train_set.classes
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
color_model = ColorNet()
if os.path.exists('./colornet_params.pkl'):
    color_model.load_state_dict(torch.load('colornet_params.pkl'))
if have_cuda:
    color_model.cuda()
optimizer = optim.Adadelta(color_model.parameters())


def train(epoch):
    color_model.train()

    try:
        for batch_idx, (data, classes) in enumerate(train_loader):
            messagefile = open('./message.txt', 'a')
            original_img = data[0].unsqueeze(1).float()
            img_ab = data[1].float()
            if have_cuda:
                original_img = original_img.cuda()
                img_ab = img_ab.cuda()
                classes = classes.cuda()
            original_img = Variable(original_img)
            img_ab = Variable(img_ab)
            classes = Variable(classes)
            optimizer.zero_grad()
            class_output, output = color_model(original_img, original_img)
            # mean squared error over all predicted ab values
            ems_loss = torch.pow((img_ab - output), 2).sum() / torch.from_numpy(np.array(list(output.size()))).prod()
            # note: class_output is already log-probabilities (log_softmax in
            # ClassificationNet), while F.cross_entropy expects raw logits
            cross_entropy_loss = 1 / 300 * F.cross_entropy(class_output, classes)
            loss = ems_loss + cross_entropy_loss
            lossmsg = 'loss: %.9f\n' % (loss.data[0])
            messagefile.write(lossmsg)
            # the two backward passes accumulate gradients, equivalent to loss.backward()
            ems_loss.backward(retain_variables=True)
            cross_entropy_loss.backward()
            optimizer.step()
            if batch_idx % 500 == 0:
                message = 'Train Epoch:%d\tPercent:[%d/%d (%.0f%%)]\tLoss:%.9f\n' % (
                    epoch, batch_idx * original_img.size(0), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.data[0])
                messagefile.write(message)
                torch.save(color_model.state_dict(), 'colornet_params.pkl')
            messagefile.close()
            # print('Train Epoch: {}[{}/{}({:.0f}%)]\tLoss: {:.9f}\n'.format(
            #     epoch, batch_idx * len(data), len(train_loader.dataset),
            #     100. * batch_idx / len(train_loader), loss.data[0]))
    except Exception:
        logfile = open('log.txt', 'w')
        logfile.write(traceback.format_exc())
        logfile.close()
    finally:
        torch.save(color_model.state_dict(), 'colornet_params.pkl')


for epoch in range(1, epochs + 1):
    train(epoch)
--------------------------------------------------------------------------------
/val.py:
--------------------------------------------------------------------------------
import os

import torch
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image
from skimage.color import lab2rgb
from skimage import io
from colornet import ColorNet
from myimgfolder import ValImageFolder
import numpy as np
import matplotlib.pyplot as plt


data_dir = "../places205"
have_cuda = torch.cuda.is_available()

val_set = ValImageFolder(data_dir)
val_set_size = len(val_set)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)

color_model = ColorNet()
color_model.load_state_dict(torch.load('colornet_params.pkl'))
if have_cuda:
    color_model.cuda()


def val():
    color_model.eval()

    i = 0
    for data, _ in val_loader:
        original_img = data[0].unsqueeze(1).float()
        gray_name = './gray/' + str(i) + '.jpg'
        for img in original_img:
            pic = img.squeeze().numpy()
            pic = pic.astype(np.float64)
            plt.imsave(gray_name, pic, cmap='gray')
        w = original_img.size()[2]
        h = original_img.size()[3]
        scale_img = data[1].unsqueeze(1).float()
        if have_cuda:
            original_img, scale_img = original_img.cuda(), scale_img.cuda()

        original_img, scale_img = Variable(original_img, volatile=True), Variable(scale_img)
        _, output = color_model(original_img, scale_img)
        color_img = torch.cat((original_img, output[:, :, 0:w, 0:h]), 1)
        color_img = color_img.data.cpu().numpy().transpose((0, 2, 3, 1))
        for img in color_img:
            img[:, :, 0:1] = img[:, :, 0:1] * 100        # grayscale input in [0, 1] -> L in [0, 100]
            img[:, :, 1:3] = img[:, :, 1:3] * 255 - 128  # undo the (ab + 128) / 255 encoding
            img = img.astype(np.float64)
            img = lab2rgb(img)
            color_name = './colorimg/' + str(i) + '.jpg'
            plt.imsave(color_name, img)
            i += 1
        # using the following method doesn't produce the right image, but I don't know why
        # color_img = torch.from_numpy(color_img.transpose((0, 3, 1, 2)))
        # sprite_img = make_grid(color_img)
        # color_name = './colorimg/'+str(i)+'.jpg'
        # save_image(sprite_img, color_name)
        # i += 1


val()
--------------------------------------------------------------------------------