├── 06d0c0f6a101f2ace438291a3745974e.webp ├── 16dfa3999cf665921cfffcc0d311b0ad.webp ├── 48e0e364bee5feb07b5b09727e811912.webp ├── 5d2f5c3c4e09d388427855efbe849af3.webp ├── 6d53f4d4053639eb90d46eeb6c99d212.webp ├── 70250ca6fc0100fe018ad486281c32c5.webp ├── 7e1423a84f48a0a0209934843023cb6c.webp ├── README.md ├── a8bfc27f0f734d6fa35534be40302f66.webp ├── aligned_dataset.py ├── b5ded66e02a178aa35e5812cf1844cea.webp ├── base_dataset.py ├── cae60fe4035e7646068a5891faccd13d.webp ├── cd7d4fb60475341e71ee2c1fb681eb52.webp ├── f26fa8698155e8becf1158db712af304.webp ├── model.py ├── test.py └── train.py /06d0c0f6a101f2ace438291a3745974e.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/06d0c0f6a101f2ace438291a3745974e.webp -------------------------------------------------------------------------------- /16dfa3999cf665921cfffcc0d311b0ad.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/16dfa3999cf665921cfffcc0d311b0ad.webp -------------------------------------------------------------------------------- /48e0e364bee5feb07b5b09727e811912.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/48e0e364bee5feb07b5b09727e811912.webp -------------------------------------------------------------------------------- /5d2f5c3c4e09d388427855efbe849af3.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/5d2f5c3c4e09d388427855efbe849af3.webp -------------------------------------------------------------------------------- /6d53f4d4053639eb90d46eeb6c99d212.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/6d53f4d4053639eb90d46eeb6c99d212.webp -------------------------------------------------------------------------------- /70250ca6fc0100fe018ad486281c32c5.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/70250ca6fc0100fe018ad486281c32c5.webp -------------------------------------------------------------------------------- /7e1423a84f48a0a0209934843023cb6c.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/7e1423a84f48a0a0209934843023cb6c.webp -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |  2 | # 1.研究背景与意义 3 | 4 | 5 | 6 | 随着计算机技术的不断发展,图像处理领域也取得了长足的进步。其中,图像上色是一个重要的研究方向,它可以将黑白图像转化为彩色图像,使得图像更加真实、生动。图像上色在许多领域都有广泛的应用,比如电影、游戏、艺术创作等。然而,传统的图像上色方法往往需要人工干预,耗时且效果不佳。因此,研究基于改进CycleGAN&pix2pix的黑白图像上色系统具有重要的意义。 7 | 8 | 首先,基于改进CycleGAN&pix2pix的黑白图像上色系统可以提高图像上色的自动化程度。传统的图像上色方法通常需要人工干预,例如手动选择颜色、调整参数等。而基于改进CycleGAN&pix2pix的系统可以通过学习大量的彩色图像和对应的黑白图像,自动学习到图像的颜色分布和特征,从而实现自动上色。这不仅可以节省人力成本,还可以提高上色的效率和准确性。 9 | 10 | 其次,基于改进CycleGAN&pix2pix的黑白图像上色系统可以提供更好的上色效果。传统的图像上色方法往往存在一些问题,比如颜色失真、边缘模糊等。而基于改进CycleGAN&pix2pix的系统可以通过深度学习的方法,学习到更准确的颜色分布和特征,从而提供更真实、生动的上色效果。这对于电影、游戏等领域的图像处理来说,具有重要的实际应用价值。 11 | 12 | 此外,基于改进CycleGAN&pix2pix的黑白图像上色系统还可以为艺术创作提供更多的可能性。艺术家们常常使用黑白图像作为创作的基础,然后通过手工上色来增加图像的表现力和情感。而基于改进CycleGAN&pix2pix的系统可以为艺术家们提供一个自动上色的工具,帮助他们更好地表达自己的创意和想法。这对于艺术创作的发展来说,具有重要的推动作用。 13 | 14 | 综上所述,基于改进CycleGAN&pix2pix的黑白图像上色系统在图像处理领域具有重要的研究意义和实际应用价值。它可以提高图像上色的自动化程度,提供更好的上色效果,同时为艺术创作提供更多的可能性。相信随着深度学习技术的不断发展,基于改进CycleGAN&pix2pix的系统将在未来取得更加广泛的应用和进一步的研究突破。 15 | 16 | # 2.图片演示 17 | ![2.png](06d0c0f6a101f2ace438291a3745974e.webp) 18 | 19 | ![3.png](b5ded66e02a178aa35e5812cf1844cea.webp) 20 | 21 | ![4.png](cae60fe4035e7646068a5891faccd13d.webp) 22 | 23 | # 3.视频演示 24 | [基于改进CycleGAN&pix2pix的黑白图像上色系统_哔哩哔哩_bilibili](https://www.bilibili.com/video/BV19w411q72T/?spm_id_from=333.999.0.0&vd_source=ff015de2d29cbe2a9cdbfa7064407a08) 25 | 26 | # 4.模型框架 27 | 模型框架如图所示,采用双网络结构,包含记忆增强网络和着色网络。 28 | ![image.png](cd7d4fb60475341e71ee2c1fb681eb52.webp) 29 | 图中,G为着色网络生成器;D为着色网络判别器;ResNet为预训练好的残差网络,用于提取图像信息。在训练过程中,记忆增强网络一方面存储训练集中彩色图像的空间特征、颜色特征等,另一方面学习如何高效地进行记忆查询,即快速检索与查询图像最匹配的颜色特征。着色网络生成器G则学习如何高效地将颜色特征注入灰度图像中,判别器D学习如何快速区分真实图像与生成图像。 30 | ![image.png](16dfa3999cf665921cfffcc0d311b0ad.webp) 31 | 32 | 33 | 34 | # 5.核心代码讲解 35 | 36 | #### 5.1 model.py 37 | 38 | 封装的类为ImageColorizationGAN,包含了生成器(Generator)、判别器(Discriminator)和组合模型(combined)的构建方法,以及生成彩色图像的方法(generate_colorized_images)。 39 | 40 | ```python 41 | 42 | class Generator(nn.Module): 43 | def __init__(self, input_shape): 44 | super(Generator, self).__init__() 45 | self.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=4, stride=2, padding=1) 46 | self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1) 47 | self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1) 48 | # Add more layers as needed for your U-Net generator architecture 49 | 50 | def forward(self, x): 51 | x = self.conv1(x) 52 | x = nn.functional.relu(x) 53 | x = self.conv2(x) 54 | x = nn.functional.relu(x) 55 | x = self.conv3(x) 56 | x = nn.functional.relu(x) 57 | # Add forward pass for the rest of your generator layers 58 | return x 59 | 60 | class Discriminator(nn.Module): 61 | def __init__(self, input_shape): 62 | super(Discriminator, self).__init__() 63 | self.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=4, stride=2, padding=1) 64 | self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1) 65 | self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1) 66 | # Add more layers as needed for your PatchGAN discriminator architecture 67 | self.fc = nn.Linear(final_size, 1) # Define final_size correctly 68 | 69 | def forward(self, x): 70 | x = self.conv1(x) 71 | x = nn.functional.leaky_relu(x, negative_slope=0.2) 72 | x = self.conv2(x) 73 | x = nn.functional.leaky_relu(x, negative_slope=0.2) 74 | x = self.conv3(x) 75 | x = nn.functional.leaky_relu(x, negative_slope=0.2) 76 | # Add forward pass for the rest of your discriminator layers 77 | x = x.view(x.size(0), -1) 78 | x = self.fc(x) 79 | return x 80 | 81 | class ImageColorizationGAN: 82 | def __init__(self, image_shape): 83 | self.image_shape = image_shape 84 | self.generator = self.build_generator() 85 | self.discriminator = self.build_discriminator() 86 | self.combined = self.build_combined() 87 | 88 | def build_generator(self): 89 | return Generator(self.image_shape) 90 | 91 | def build_discriminator(self): 92 | return Discriminator(self.image_shape) 93 | 94 | def build_combined(self): 95 | self.discriminator.trainable = False 96 | combined_model = nn.Sequential(self.generator, self.discriminator) 97 | return combined_model 98 | 99 | ...... 100 | 101 | ``` 102 | 103 | 该程序文件名为model.py,主要包含了三个类:Generator、Discriminator和ImageColorizationGAN。 104 | 105 | Generator类是一个继承自nn.Module的模型,用于生成彩色图像。它包含了三个卷积层,分别是conv1、conv2和conv3,用于提取图像特征。在forward方法中,通过对输入数据进行卷积和激活函数操作,得到生成的图像。 106 | 107 | Discriminator类也是一个继承自nn.Module的模型,用于判别图像的真实性。它也包含了三个卷积层,与Generator类相同,用于提取图像特征。在forward方法中,通过对输入数据进行卷积和LeakyReLU激活函数操作,得到判别结果。 108 | 109 | ImageColorizationGAN类是一个用于图像着色的生成对抗网络。它包含了一个Generator和一个Discriminator,并且通过build_generator、build_discriminator和build_combined方法来构建这两个模型。在generate_colorized_images方法中,通过调用Generator的forward方法来生成彩色图像。 110 | 111 | 整个程序文件的目的是构建一个用于图像着色的生成对抗网络模型,其中Generator用于生成彩色图像,Discriminator用于判别图像的真实性,ImageColorizationGAN用于整合这两个模型,并提供生成彩色图像的功能。 112 | 113 | #### 5.2 test.py 114 | 115 | ```python 116 | 117 | 118 | class ImageTranslator: 119 | def __init__(self): 120 | self.opt = TestOptions().parse() # get test options 121 | # hard-code some parameters for test 122 | self.opt.num_threads = 0 # test code only supports num_threads = 0 123 | self.opt.batch_size = 1 # test code only supports batch_size = 1 124 | self.opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 125 | self.opt.no_flip = True # no flip; comment this line if results on flipped images are needed. 126 | self.opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 127 | self.dataset = create_dataset(self.opt) # create a dataset given opt.dataset_mode and other options 128 | self.model = create_model(self.opt) # create a model given opt.model and other options 129 | self.model.setup(self.opt) # regular setup: load and print networks; create schedulers 130 | 131 | def translate_images(self): 132 | # create a website 133 | web_dir = os.path.join(self.opt.results_dir, self.opt.name, '{}_{}'.format(self.opt.phase, self.opt.epoch)) # define the website directory 134 | if self.opt.load_iter > 0: # load_iter is 0 by default 135 | web_dir = '{:s}_iter{:d}'.format(web_dir, self.opt.load_iter) 136 | print('creating web directory', web_dir) 137 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (self.opt.name, self.opt.phase, self.opt.epoch)) 138 | # test with eval mode. This only affects layers like batchnorm and dropout. 139 | # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. 140 | # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. 141 | if self.opt.eval: 142 | self.model.eval() 143 | for i, data in enumerate(self.dataset): 144 | if i >= self.opt.num_test: # only apply our model to opt.num_test images. 145 | break 146 | self.model.set_input(data) # unpack data from data loader 147 | self.model.test() # run inference 148 | visuals = self.model.get_current_visuals() # get image results 149 | img_path = self.model.get_image_paths() # get image paths 150 | if i % 5 == 0: # save images to an HTML file 151 | print('processing (%04d)-th image... %s' % (i, img_path)) 152 | save_images(webpage, visuals, img_path, aspect_ratio=self.opt.aspect_ratio, width=self.opt.display_winsize, use_wandb=self.opt.use_wandb) 153 | webpage.save() # save the HTML 154 | 155 | if __name__ == '__main__': 156 | translator = ImageTranslator() 157 | translator.translate_images() 158 | ``` 159 | 160 | 这是一个用于图像到图像转换的通用测试脚本。该脚本用于加载已保存的模型并将结果保存到指定目录。它首先根据选项创建模型和数据集,然后运行推理并将结果保存到一个HTML文件中。 161 | 162 | 该脚本的用法示例包括: 163 | - 测试CycleGAN模型(双向转换):`python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan` 164 | - 测试CycleGAN模型(单向转换):`python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout` 165 | - 测试pix2pix模型:`python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA` 166 | 167 | 该脚本还支持其他测试选项,可以参考`options/base_options.py`和`options/test_options.py`文件获取更多信息。 168 | 169 | 该脚本依赖于一些其他模块和库,如`options.test_options`、`data`、`models`、`util.visualizer`和`util.html`。如果缺少`wandb`库,会输出警告信息。 170 | 171 | 在`__main__`函数中,首先解析测试选项,然后硬编码一些测试参数。接着创建数据集和模型,并进行初始化设置。如果使用了`wandb`库,还会初始化日志记录器。然后创建一个网页目录,并根据选项设置模型的评估模式。最后,遍历数据集中的图像,对每张图像进行推理,并保存结果到HTML文件中。 172 | 173 | 总之,这个脚本用于加载模型并对图像进行测试,将结果保存到HTML文件中。 174 | 175 | #### 5.3 train.py 176 | 177 | ```python 178 | 179 | 180 | class ImageToImageTranslationTrainer: 181 | def __init__(self): 182 | self.opt = TrainOptions().parse() 183 | self.dataset = create_dataset(self.opt) 184 | self.dataset_size = len(self.dataset) 185 | self.model = create_model(self.opt) 186 | self.visualizer = Visualizer(self.opt) 187 | self.total_iters = 0 188 | 189 | def train(self): 190 | for epoch in range(self.opt.epoch_count, self.opt.n_epochs + self.opt.n_epochs_decay + 1): 191 | epoch_start_time = time.time() 192 | iter_data_time = time.time() 193 | epoch_iter = 0 194 | self.visualizer.reset() 195 | self.model.update_learning_rate() 196 | for i, data in enumerate(self.dataset): 197 | iter_start_time = time.time() 198 | if self.total_iters % self.opt.print_freq == 0: 199 | t_data = iter_start_time - iter_data_time 200 | 201 | self.total_iters += self.opt.batch_size 202 | epoch_iter += self.opt.batch_size 203 | self.model.set_input(data) 204 | self.model.optimize_parameters() 205 | 206 | if self.total_iters % self.opt.display_freq == 0: 207 | save_result = self.total_iters % self.opt.update_html_freq == 0 208 | self.model.compute_visuals() 209 | self.visualizer.display_current_results(self.model.get_current_visuals(), epoch, save_result) 210 | 211 | if self.total_iters % self.opt.print_freq == 0: 212 | losses = self.model.get_current_losses() 213 | t_comp = (time.time() - iter_start_time) / self.opt.batch_size 214 | self.visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) 215 | if self.opt.display_id > 0: 216 | self.visualizer.plot_current_losses(epoch, float(epoch_iter) / self.dataset_size, losses) 217 | 218 | if self.total_iters % self.opt.save_latest_freq == 0: 219 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, self.total_iters)) 220 | save_suffix = 'iter_%d' % self.total_iters if self.opt.save_by_iter else 'latest' 221 | self.model.save_networks(save_suffix) 222 | 223 | iter_data_time = time.time() 224 | if epoch % self.opt.save_epoch_freq == 0: 225 | print('saving the model at the end of epoch %d, iters %d' % (epoch, self.total_iters)) 226 | self.model.save_networks('latest') 227 | self.model.save_networks(epoch) 228 | 229 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, self.opt.n_epochs + self.opt.n_epochs_decay, time.time() - epoch_start_time)) 230 | ...... 231 | ``` 232 | 233 | 这个程序文件是一个通用的用于图像到图像转换的训练脚本。它适用于不同的模型(使用选项'--model',例如pix2pix、cyclegan、colorization)和不同的数据集(使用选项'--dataset_mode',例如aligned、unaligned、single、colorization)。你需要指定数据集('--dataroot')、实验名称('--name')和模型('--model')。 234 | 235 | 它首先根据选项创建模型、数据集和可视化器。然后进行标准的网络训练。在训练过程中,它还会可视化/保存图像,打印/保存损失图表,并保存模型。该脚本支持继续/恢复训练。使用'--continue_train'来恢复之前的训练。 236 | 237 | 238 | #### 5.4 data\aligned_dataset.py 239 | 240 | ```python 241 | 242 | class AlignedDataset(BaseDataset): 243 | def __init__(self, opt): 244 | BaseDataset.__init__(self, opt) 245 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) 246 | self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) 247 | assert(self.opt.load_size >= self.opt.crop_size) 248 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 249 | self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc 250 | 251 | def __getitem__(self, index): 252 | AB_path = self.AB_paths[index] 253 | AB = Image.open(AB_path).convert('RGB') 254 | w, h = AB.size 255 | w2 = int(w / 2) 256 | A = AB.crop((0, 0, w2, h)) 257 | B = AB.crop((w2, 0, w, h)) 258 | transform_params = get_params(self.opt, A.size) 259 | A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 260 | B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1)) 261 | A = A_transform(A) 262 | B = B_transform(B) 263 | return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} 264 | 265 | def __len__(self): 266 | return len(self.AB_paths) 267 | ``` 268 | 269 | 这个程序文件是一个用于处理配对图像数据集的数据集类。它假设目录'/path/to/data/train'中包含以{A,B}形式的图像对。在测试时,需要准备一个目录'/path/to/data/test'。 270 | 271 | 这个类继承自BaseDataset类,并重写了其中的一些方法。在初始化方法中,它通过opt参数获取数据集的根目录和阶段,并将图像目录设置为根目录和阶段的组合。然后,它使用make_dataset函数获取图像路径,并对路径进行排序。接下来,它根据opt中的一些参数设置输入通道数和输出通道数。 272 | 273 | 在getitem方法中,它根据给定的索引读取一个图像,并将其分割为A和B两部分。然后,它应用相同的变换方法对A和B进行处理,并返回一个包含A、B、A_paths和B_paths的字典。 274 | 275 | 在len方法中,它返回数据集中图像的总数。 276 | 277 | 这个程序文件依赖于其他几个模块和类,如os、BaseDataset、get_params、get_transform、make_dataset和Image。 278 | 279 | #### 5.5 data\base_dataset.py 280 | 281 | ```python 282 | 283 | 284 | class BaseDataset(data.Dataset, ABC): 285 | def __init__(self, opt): 286 | self.opt = opt 287 | self.root = opt.dataroot 288 | 289 | @staticmethod 290 | def modify_commandline_options(parser, is_train): 291 | return parser 292 | 293 | @abstractmethod 294 | def __len__(self): 295 | return 0 296 | 297 | @abstractmethod 298 | def __getitem__(self, index): 299 | pass 300 | 301 | 302 | def get_params(opt, size): 303 | w, h = size 304 | new_h = h 305 | new_w = w 306 | if opt.preprocess == 'resize_and_crop': 307 | new_h = new_w = opt.load_size 308 | elif opt.preprocess == 'scale_width_and_crop': 309 | new_w = opt.load_size 310 | new_h = opt.load_size * h // w 311 | 312 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 313 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 314 | 315 | flip = random.random() > 0.5 316 | 317 | return {'crop_pos': (x, y), 'flip': flip} 318 | 319 | 320 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 321 | transform_list = [] 322 | if grayscale: 323 | transform_list.append(transforms.Grayscale(1)) 324 | if 'resize' in opt.preprocess: 325 | osize = [opt.load_size, opt.load_size] 326 | transform_list.append(transforms.Resize(osize, method)) 327 | elif 'scale_width' in opt.preprocess: 328 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) 329 | 330 | if 'crop' in opt.preprocess: 331 | if params is None: 332 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 333 | else: 334 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 335 | 336 | if opt.preprocess == 'none': 337 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 338 | 339 | if not opt.no_flip: 340 | if params is None: 341 | transform_list.append(transforms.RandomHorizontalFlip()) 342 | elif params['flip']: 343 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 344 | 345 | if convert: 346 | transform_list += [transforms.ToTensor()] 347 | if grayscale: 348 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 349 | else: 350 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 351 | return transforms.Compose(transform_list) 352 | ``` 353 | 354 | 这个程序文件是一个抽象基类(ABC)'BaseDataset'的实现,用于数据集。 355 | 356 | 它还包括常见的转换函数(例如,get_transform,__scale_width),可以在子类中使用。 357 | 358 | 这个类有以下几个函数需要实现: 359 | - <__init__>: 初始化类,首先调用BaseDataset.__init__(self, opt)。 360 | - <__len__>: 返回数据集的大小。 361 | - <__getitem__>: 获取一个数据点。 362 | - : (可选)添加特定于数据集的选项并设置默认选项。 363 | 364 | 这个类还定义了一些静态方法和辅助函数,用于数据预处理和转换。 365 | 366 | # 6.系统整体结构 367 | 368 | 整体功能和构架概述: 369 | 370 | 该项目是一个基于改进CycleGAN&pix2pix的黑白图像上色系统。它使用生成对抗网络(GAN)来实现将灰度图像转换为彩色图像的功能。项目包含了训练、测试和应用程序部分。 371 | 372 | 在训练部分,train.py文件用于训练模型。它根据选项指定的数据集和模型类型,创建相应的数据集对象和模型对象,并进行训练。训练过程中,会进行图像处理、损失计算、模型更新等操作,并保存训练好的模型。 373 | 374 | 在测试部分,test.py文件用于加载已保存的模型,并对图像进行转换。它根据选项指定的数据集和模型类型,创建相应的数据集对象和模型对象,并进行图像转换。转换结果会保存到HTML文件中。 375 | 376 | 在应用程序部分,ui.py文件实现了一个基于PyQt5的图形界面,用于用户交互。用户可以选择图像文件,然后点击按钮进行图像处理。处理结果会显示在界面上,并计算与原始图像之间的指标。 377 | 378 | 下面是每个文件的功能整理: 379 | 380 | | 文件路径 | 功能 | 381 | | -------- | ---- | 382 | | model.py | 定义了生成器和判别器的模型类 | 383 | | test.py | 测试脚本,加载模型并对图像进行转换 | 384 | | train.py | 训练脚本,用于训练模型 | 385 | | ui.py | 图形界面应用程序,用于用户交互 | 386 | | data/aligned_dataset.py | 处理配对图像数据集的数据集类 | 387 | | data/base_dataset.py | 数据集的抽象基类 | 388 | | data/colorization_dataset.py | 处理着色图像数据集的数据集类 | 389 | | data/image_folder.py | 处理图像文件夹数据集的数据集类 | 390 | | data/single_dataset.py | 处理单个图像数据集的数据集类 | 391 | | data/template_dataset.py | 数据集的模板类 | 392 | | data/unaligned_dataset.py | 处理不配对图像数据集的数据集类 | 393 | | data/__init__.py | 数据集模块的初始化文件 | 394 | | datasets/combine_A_and_B.py | 将两个数据集合并为一个的脚本 | 395 | | datasets/make_dataset_aligned.py | 创建配对图像数据集的脚本 | 396 | | datasets/resize_pic.py | 调整图像大小的脚本 | 397 | | models/base_model.py | 模型的抽象基类 | 398 | | models/colorization_model.py | 着色模型的实现 | 399 | | models/cycle_gan_model.py | CycleGAN模型的实现 | 400 | | models/networks.py | 网络模型的实现 | 401 | | models/pix2pix_model.py | pix2pix模型的实现 | 402 | | models/template_model.py | 模型的模板类 | 403 | | models/test_model.py | 测试模型的实现 | 404 | | models/__init__.py | 模型模块的初始化文件 | 405 | | options/base_options.py | 基础选项类,定义了通用的训练和测试选项 | 406 | | options/test_options.py | 测试选项类,定义了测试时的选项 | 407 | | options/train_options.py | 训练选项类,定义了训练时的选项 | 408 | | options/__init__.py | 选项模块的初始化文件 | 409 | | scripts/test_before_push.py | 在推送之前进行测试的脚本 | 410 | | scripts/edges/batch_hed.py | 边缘检测的脚本 | 411 | | scripts/eval_cityscapes/cityscapes.py | Cityscapes数据集的评估脚本 | 412 | | scripts/eval_cityscapes/evaluate.py | Cityscapes数据集的评估脚本 | 413 | | scripts/eval_cityscapes/util.py | Cityscapes数据集评估的辅助函数 | 414 | | util/get_data.py | 获取数据的辅助函数 | 415 | | util/html.py | HTML文件处理的辅助函数 | 416 | | util/image_pool.py | 图像缓存池的实现 | 417 | | util/util.py | 通用的辅助函数 | 418 | | util/visualizer.py | 可视化工具类 | 419 | | util/__init__.py | 工具模块的初始化文件 | 420 | 421 | # 7.CycleGAN简介 422 | Cycle-GAN是一个2017年推出的直击产业痛点的模型。众所周知,在一系列视觉问题上是很难以找到匹配的高质量图像作为target来供模型学习的,比如在超分辨领域内对于一个低分辨率的物体图像,未必能找到同样场景的高分辨率图像,这使得一系列深度学习模型的适应性有限。上述的困难总结起来就是:由于模型训练时必须依赖匹配的图像,而除非有目的的去产生这样的图像否则无法训练,并且很容易造成数据有偏。 423 | 424 | Cycle-GAN训练的目的则避开了上述困难;该模型的思路是旨在形成一个由数据域A到数据域B的普适性映射,学习的目标是数据域A和B的风格之间的变换而非具体的数据a和b之间的一一映射关系。从这样的思路出发Cycle-GAN对于数据一一匹配的依赖性就不存在了,可以解决一系列问题,因此该模型的设计思路与具体做法十分值得学习。 425 | 426 | 总的来说,基于Cycle-GAN的模型具有较强的适应性,能够适应一系列的视觉问题场合,比如超分辨,风格变换,图像增强等等场合。 427 | 428 | 下面附一张匹配和非匹配图像的说明 429 | ![image.png](7e1423a84f48a0a0209934843023cb6c.webp) 430 | 431 | 通常的GAN的设计思路从信息流的角度出发是单向的,如下图所示:使用Generator从a产生一个假的b,然后使用Determinator判断这个假的b是否属于B集合,并将这个信息反馈至Generator,然后通过逐次分别提高Generator与Discriminator的能力以期达到使Generator能以假乱真的能力,这样的设计思路在一般有匹配图像的情况下是合理的。 432 | ![image.png](a8bfc27f0f734d6fa35534be40302f66.webp) 433 | 434 | # 8.pix2pix简介 435 | 给定一个输入数据和噪声数据生成目标图像,在pix2pix中判别器的输入是生成图像和源图像,而生成器的输入是源图像和随机噪声(使生成模型具有一定的随机性),pix2pix是通过在生成器的模型层加入Dropout来引入随机噪声,但是其带来输出内容的随机性并没有很大。同时在损失函数的使用上采用的是L1正则而非CGAN使用的L2正则用来使图像更清晰。 436 | 437 | 条件生成对抗网络为基础,用于图像翻译的通用模型框架。(图像翻译:将一个物体的图像表征转化为该物体的另一个表征,即找到两不同域的对应关系,从而实现图像的跨域转化) 438 | 439 | (条件生成对抗网络:相较于传统GAN的生成内容仅由生成器参数和噪音来决定,CGAN中向生成器和判别器添加了一个条件信息y) 440 | 441 | 采用CNN卷积+BN+ReLU的模型结构 442 | 443 | #### 生成器 444 | 445 | 以U-Net作为基础结构增加跳跃连接(下降通道256->64)压缩路径中每次为4*4的same卷积+BN+ReLU,根据是否降采样来控制卷积的步长。同时压缩路径和扩张路径使用的是拼接操作进行特征融合。 446 | ![image.png](6d53f4d4053639eb90d46eeb6c99d212.webp) 447 | 448 | #### 判别器 449 | 450 | 传统GAN生成图像比较模糊(由于采用整图作为判别输入,pix2pix则分成N*N的Patch【大概将256的图N=7效果最好,但是N越大生成的图像质量越高1*1的被称为PixelGAN,不过一般自己调整感受野选择参数】) 451 | 452 | # 9.改进CycleGAN 453 | 454 | #### 批量归一化 BN 455 | 文中算法在CycleGAN基础上进行改进,原始的CycleGAN为了能够加快训练网络时的收敛速度以及归纳统一样本的统计分布性,采用批量归一化 BN (batchnormalization) 它能将数据强行归一化为均值为О、方差为1的正态分布上,这样有利于数据分布一致,也避免了梯度消失。BN的缺点也很明显,首先,它对样本的数量的大小比较敏感,因为每次计算的方差和均值都在一个批量上,所以如果批量太小,这会导致方差和均值不足以代表整个数据分布。其次BN只对固定深度的前向神经网络很方便。 456 | 基于此,文中采用了实例归一化IN(instance nor-malization)[1]替换 BN,这是因为在无监督上色中,输出图像主要依赖于某个图像实例,IN是对一个批次中单张图片进行归一化,而不是像BN对整批图片进行归一化然后提取平均值,所以IN更适合于无监督上色,提高输出图片上色的合理性和更好地保留底层信息,IN的计算公式如下: 457 | ![image.png](70250ca6fc0100fe018ad486281c32c5.webp) 458 | 其次,该算法还在生成器中引入了自注意力Self-Attention[3]机制,传统的生成对抗网络的问题表现在卷积核的大小选取上,小的卷积核很难发现图像中的依赖关系,但是大的卷积核又会降低卷积网络的计算效率。 459 | 为了提升上色效果,该算法将Self-Attention加载到生成器网络中,具体网络如图所示,即在前一层的特征图上加入 Self-Attention机制,使得生成对抗网络在生成时能区分不同的特征图。 460 | ![image.png](f26fa8698155e8becf1158db712af304.webp) 461 | 462 | #### 网络结构 463 | 在对图像彩色化的问题上,为了保持原图像的底层轮廓信息不变和上色的合理性,一般是采用 U-Net结构,如图所示,其好处是不需要所有的信息都通过生成器的所有层,这样就能使得输入和输出图像的底层信息和突出边缘位置信息能够被共享。 U-Net 型网络对于提升输出图像的细节起到了良好的效果。 464 | ![image.png](5d2f5c3c4e09d388427855efbe849af3.webp) 465 | 466 | # 10.系统整合 467 | 468 | 下图[完整源码&数据集&环境部署视频教程&自定义UI界面](https://s.xiaocichang.com/s/6b559d) 469 | ![1.png](48e0e364bee5feb07b5b09727e811912.webp) 470 | 471 | 472 | 参考博客[《基于改进CycleGAN&pix2pix的黑白图像上色系统》](https://mbd.pub/o/qunshan/work) 473 | 474 | # 11.参考文献 475 | --- 476 | [1][商露兮](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E5%95%86%E9%9C%B2%E5%85%AE%22),[方建安](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%96%B9%E5%BB%BA%E5%AE%89%22),[谷小婧](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E8%B0%B7%E5%B0%8F%E5%A9%A7%22),等.[夜视图像自动彩色化源图选择算法](https://d.wanfangdata.com.cn/periodical/jgyhw200902029)[J].[激光与红外](https://sns.wanfangdata.com.cn/perio/jgyhw).2009,(2).DOI:10.3969/j.issn.1001-5078.2009.02.029 . 477 | 478 | [2][孟敏](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E5%AD%9F%E6%95%8F%22),[刘利刚](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E5%88%98%E5%88%A9%E5%88%9A%22).[勾画式局部颜色迁移](https://d.wanfangdata.com.cn/periodical/jsjfzsjytxxxb200807003)[J].[计算机辅助设计与图形学学报](https://sns.wanfangdata.com.cn/perio/jsjfzsjytxxxb).2008,(7). 479 | 480 | [3][李苏梅](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%9D%8E%E8%8B%8F%E6%A2%85%22),[韩国强](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E9%9F%A9%E5%9B%BD%E5%BC%BA%22).[基于K-均值聚类算法的图像区域分割方法](https://d.wanfangdata.com.cn/periodical/jsjgcyyy200816050)[J].[计算机工程与应用](https://sns.wanfangdata.com.cn/perio/jsjgcyyy).2008,(16).DOI:10.3778/j.issn.1002-8331.2008.16.050 . 481 | 482 | [4][孙吉贵](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E5%AD%99%E5%90%89%E8%B4%B5%22),[刘杰](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E5%88%98%E6%9D%B0%22),[赵连宇](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E8%B5%B5%E8%BF%9E%E5%AE%87%22).[聚类算法研究](https://d.wanfangdata.com.cn/periodical/rjxb200801006)[J].[软件学报](https://sns.wanfangdata.com.cn/perio/rjxb).2008,(1).DOI:10.3724/SP.J.1001.2008.00048 . 483 | 484 | [5][李建明](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%9D%8E%E5%BB%BA%E6%98%8E%22),[叶飞](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E5%8F%B6%E9%A3%9E%22),[于守秋](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E4%BA%8E%E5%AE%88%E7%A7%8B%22),等.[一种快速灰度图像彩色化算法](https://d.wanfangdata.com.cn/periodical/zgtxtxxb-a200703026)[J].[中国图象图形学报](https://sns.wanfangdata.com.cn/perio/zgtxtxxb-a).2007,(3).DOI:10.3969/j.issn.1006-8961.2007.03.026 . 485 | 486 | [6][段立娟](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%AE%B5%E7%AB%8B%E5%A8%9F%22).[形状特征的编码描述研究综述](https://d.wanfangdata.com.cn/periodical/jsjkx200708059)[J].[计算机科学](https://sns.wanfangdata.com.cn/perio/jsjkx).2007,(8).DOI:10.3969/j.issn.1002-137X.2007.08.059 . 487 | 488 | [7][王常亮](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E7%8E%8B%E5%B8%B8%E4%BA%AE%22).[基于聚类的自动颜色传输](https://d.wanfangdata.com.cn/periodical/jsjgcyyy200725020)[J].[计算机工程与应用](https://sns.wanfangdata.com.cn/perio/jsjgcyyy).2007,(25).DOI:10.3321/j.issn:1002-8331.2007.25.020 . 489 | 490 | [8][朱为](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%9C%B1%E4%B8%BA%22),[李国辉](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%9D%8E%E5%9B%BD%E8%BE%89%22),[涂丹](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%B6%82%E4%B8%B9%22).[纹理合成技术在旧照片修补中的应用](https://d.wanfangdata.com.cn/periodical/jsjgcyyy200728067)[J].[计算机工程与应用](https://sns.wanfangdata.com.cn/perio/jsjgcyyy).2007,(28).DOI:10.3321/j.issn:1002-8331.2007.28.067 . 491 | 492 | [9][赵源萌](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E8%B5%B5%E6%BA%90%E8%90%8C%22),[王岭雪](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E7%8E%8B%E5%B2%AD%E9%9B%AA%22),[金伟其](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E9%87%91%E4%BC%9F%E5%85%B6%22),等.[基于色彩传递的生物医学图像彩色化](https://d.wanfangdata.com.cn/periodical/smkxyq200711006)[J].[生命科学仪器](https://sns.wanfangdata.com.cn/perio/smkxyq).2007,(11).DOI:10.3969/j.issn.1671-7929.2007.11.006 . 493 | 494 | [10][杨春玲](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%9D%A8%E6%98%A5%E7%8E%B2%22),[旷开智](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%97%B7%E5%BC%80%E6%99%BA%22),[陈冠豪](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E9%99%88%E5%86%A0%E8%B1%AA%22),等.[基于梯度的结构相似度的图像质量评价方法](https://d.wanfangdata.com.cn/periodical/hnlgdxxb200609005)[J].[华南理工大学学报(自然科学版)](https://sns.wanfangdata.com.cn/perio/hnlgdxxb).2006,(9).DOI:10.3321/j.issn:1000-565X.2006.09.005 . 495 | 496 | 497 | 498 | --- 499 | #### 如果您需要更详细的【源码和环境部署教程】,除了通过【系统整合】小节的链接获取之外,还可以通过邮箱以下途径获取: 500 | #### 1.请先在GitHub上为该项目点赞(Star),编辑一封邮件,附上点赞的截图、项目的中文描述概述(About)以及您的用途需求,发送到我们的邮箱 501 | #### sharecode@yeah.net 502 | #### 2.我们收到邮件后会定期根据邮件的接收顺序将【完整源码和环境部署教程】发送到您的邮箱。 503 | #### 【免责声明】本文来源于用户投稿,如果侵犯任何第三方的合法权益,可通过邮箱联系删除。 -------------------------------------------------------------------------------- /a8bfc27f0f734d6fa35534be40302f66.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/a8bfc27f0f734d6fa35534be40302f66.webp -------------------------------------------------------------------------------- /aligned_dataset.py: -------------------------------------------------------------------------------- 1 | python 2 | 3 | class AlignedDataset(BaseDataset): 4 | def __init__(self, opt): 5 | BaseDataset.__init__(self, opt) 6 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) 7 | self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) 8 | assert(self.opt.load_size >= self.opt.crop_size) 9 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 10 | self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc 11 | 12 | def __getitem__(self, index): 13 | AB_path = self.AB_paths[index] 14 | AB = Image.open(AB_path).convert('RGB') 15 | w, h = AB.size 16 | w2 = int(w / 2) 17 | A = AB.crop((0, 0, w2, h)) 18 | B = AB.crop((w2, 0, w, h)) 19 | transform_params = get_params(self.opt, A.size) 20 | A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 21 | B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1)) 22 | A = A_transform(A) 23 | B = B_transform(B) 24 | return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} 25 | 26 | def __len__(self): 27 | return len(self.AB_paths) 28 | -------------------------------------------------------------------------------- /b5ded66e02a178aa35e5812cf1844cea.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/b5ded66e02a178aa35e5812cf1844cea.webp -------------------------------------------------------------------------------- /base_dataset.py: -------------------------------------------------------------------------------- 1 | python 2 | 3 | 4 | class BaseDataset(data.Dataset, ABC): 5 | def __init__(self, opt): 6 | self.opt = opt 7 | self.root = opt.dataroot 8 | 9 | @staticmethod 10 | def modify_commandline_options(parser, is_train): 11 | return parser 12 | 13 | @abstractmethod 14 | def __len__(self): 15 | return 0 16 | 17 | @abstractmethod 18 | def __getitem__(self, index): 19 | pass 20 | 21 | 22 | def get_params(opt, size): 23 | w, h = size 24 | new_h = h 25 | new_w = w 26 | if opt.preprocess == 'resize_and_crop': 27 | new_h = new_w = opt.load_size 28 | elif opt.preprocess == 'scale_width_and_crop': 29 | new_w = opt.load_size 30 | new_h = opt.load_size * h // w 31 | 32 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 33 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 34 | 35 | flip = random.random() > 0.5 36 | 37 | return {'crop_pos': (x, y), 'flip': flip} 38 | 39 | 40 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 41 | transform_list = [] 42 | if grayscale: 43 | transform_list.append(transforms.Grayscale(1)) 44 | if 'resize' in opt.preprocess: 45 | osize = [opt.load_size, opt.load_size] 46 | transform_list.append(transforms.Resize(osize, method)) 47 | elif 'scale_width' in opt.preprocess: 48 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) 49 | 50 | if 'crop' in opt.preprocess: 51 | if params is None: 52 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 53 | else: 54 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 55 | 56 | if opt.preprocess == 'none': 57 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 58 | 59 | if not opt.no_flip: 60 | if params is None: 61 | transform_list.append(transforms.RandomHorizontalFlip()) 62 | elif params['flip']: 63 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 64 | 65 | if convert: 66 | transform_list += [transforms.ToTensor()] 67 | if grayscale: 68 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 69 | else: 70 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 71 | return transforms.Compose(transform_list) 72 | -------------------------------------------------------------------------------- /cae60fe4035e7646068a5891faccd13d.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/cae60fe4035e7646068a5891faccd13d.webp -------------------------------------------------------------------------------- /cd7d4fb60475341e71ee2c1fb681eb52.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/cd7d4fb60475341e71ee2c1fb681eb52.webp -------------------------------------------------------------------------------- /f26fa8698155e8becf1158db712af304.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-CycleGAN-pix2pix-Monochrome-Colorization/83f44f801b8e1ffc42db9040bef9ac2a3772ea5b/f26fa8698155e8becf1158db712af304.webp -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | python 2 | 3 | class Generator(nn.Module): 4 | def __init__(self, input_shape): 5 | super(Generator, self).__init__() 6 | self.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=4, stride=2, padding=1) 7 | self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1) 8 | self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1) 9 | # Add more layers as needed for your U-Net generator architecture 10 | 11 | def forward(self, x): 12 | x = self.conv1(x) 13 | x = nn.functional.relu(x) 14 | x = self.conv2(x) 15 | x = nn.functional.relu(x) 16 | x = self.conv3(x) 17 | x = nn.functional.relu(x) 18 | # Add forward pass for the rest of your generator layers 19 | return x 20 | 21 | class Discriminator(nn.Module): 22 | def __init__(self, input_shape): 23 | super(Discriminator, self).__init__() 24 | self.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=4, stride=2, padding=1) 25 | self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1) 26 | self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1) 27 | # Add more layers as needed for your PatchGAN discriminator architecture 28 | self.fc = nn.Linear(final_size, 1) # Define final_size correctly 29 | 30 | def forward(self, x): 31 | x = self.conv1(x) 32 | x = nn.functional.leaky_relu(x, negative_slope=0.2) 33 | x = self.conv2(x) 34 | x = nn.functional.leaky_relu(x, negative_slope=0.2) 35 | x = self.conv3(x) 36 | x = nn.functional.leaky_relu(x, negative_slope=0.2) 37 | # Add forward pass for the rest of your discriminator layers 38 | x = x.view(x.size(0), -1) 39 | x = self.fc(x) 40 | return x 41 | 42 | class ImageColorizationGAN: 43 | def __init__(self, image_shape): 44 | self.image_shape = image_shape 45 | self.generator = self.build_generator() 46 | self.discriminator = self.build_discriminator() 47 | self.combined = self.build_combined() 48 | 49 | def build_generator(self): 50 | return Generator(self.image_shape) 51 | 52 | def build_discriminator(self): 53 | return Discriminator(self.image_shape) 54 | 55 | def build_combined(self): 56 | self.discriminator.trainable = False 57 | combined_model = nn.Sequential(self.generator, self.discriminator) 58 | return combined_model 59 | 60 | ...... 61 | 62 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | python 2 | 3 | 4 | class ImageTranslator: 5 | def __init__(self): 6 | self.opt = TestOptions().parse() # get test options 7 | # hard-code some parameters for test 8 | self.opt.num_threads = 0 # test code only supports num_threads = 0 9 | self.opt.batch_size = 1 # test code only supports batch_size = 1 10 | self.opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 11 | self.opt.no_flip = True # no flip; comment this line if results on flipped images are needed. 12 | self.opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 13 | self.dataset = create_dataset(self.opt) # create a dataset given opt.dataset_mode and other options 14 | self.model = create_model(self.opt) # create a model given opt.model and other options 15 | self.model.setup(self.opt) # regular setup: load and print networks; create schedulers 16 | 17 | def translate_images(self): 18 | # create a website 19 | web_dir = os.path.join(self.opt.results_dir, self.opt.name, '{}_{}'.format(self.opt.phase, self.opt.epoch)) # define the website directory 20 | if self.opt.load_iter > 0: # load_iter is 0 by default 21 | web_dir = '{:s}_iter{:d}'.format(web_dir, self.opt.load_iter) 22 | print('creating web directory', web_dir) 23 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (self.opt.name, self.opt.phase, self.opt.epoch)) 24 | # test with eval mode. This only affects layers like batchnorm and dropout. 25 | # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. 26 | # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. 27 | if self.opt.eval: 28 | self.model.eval() 29 | for i, data in enumerate(self.dataset): 30 | if i >= self.opt.num_test: # only apply our model to opt.num_test images. 31 | break 32 | self.model.set_input(data) # unpack data from data loader 33 | self.model.test() # run inference 34 | visuals = self.model.get_current_visuals() # get image results 35 | img_path = self.model.get_image_paths() # get image paths 36 | if i % 5 == 0: # save images to an HTML file 37 | print('processing (%04d)-th image... %s' % (i, img_path)) 38 | save_images(webpage, visuals, img_path, aspect_ratio=self.opt.aspect_ratio, width=self.opt.display_winsize, use_wandb=self.opt.use_wandb) 39 | webpage.save() # save the HTML 40 | 41 | if __name__ == '__main__': 42 | translator = ImageTranslator() 43 | translator.translate_images() 44 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | python 2 | 3 | 4 | class ImageToImageTranslationTrainer: 5 | def __init__(self): 6 | self.opt = TrainOptions().parse() 7 | self.dataset = create_dataset(self.opt) 8 | self.dataset_size = len(self.dataset) 9 | self.model = create_model(self.opt) 10 | self.visualizer = Visualizer(self.opt) 11 | self.total_iters = 0 12 | 13 | def train(self): 14 | for epoch in range(self.opt.epoch_count, self.opt.n_epochs + self.opt.n_epochs_decay + 1): 15 | epoch_start_time = time.time() 16 | iter_data_time = time.time() 17 | epoch_iter = 0 18 | self.visualizer.reset() 19 | self.model.update_learning_rate() 20 | for i, data in enumerate(self.dataset): 21 | iter_start_time = time.time() 22 | if self.total_iters % self.opt.print_freq == 0: 23 | t_data = iter_start_time - iter_data_time 24 | 25 | self.total_iters += self.opt.batch_size 26 | epoch_iter += self.opt.batch_size 27 | self.model.set_input(data) 28 | self.model.optimize_parameters() 29 | 30 | if self.total_iters % self.opt.display_freq == 0: 31 | save_result = self.total_iters % self.opt.update_html_freq == 0 32 | self.model.compute_visuals() 33 | self.visualizer.display_current_results(self.model.get_current_visuals(), epoch, save_result) 34 | 35 | if self.total_iters % self.opt.print_freq == 0: 36 | losses = self.model.get_current_losses() 37 | t_comp = (time.time() - iter_start_time) / self.opt.batch_size 38 | self.visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) 39 | if self.opt.display_id > 0: 40 | self.visualizer.plot_current_losses(epoch, float(epoch_iter) / self.dataset_size, losses) 41 | 42 | if self.total_iters % self.opt.save_latest_freq == 0: 43 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, self.total_iters)) 44 | save_suffix = 'iter_%d' % self.total_iters if self.opt.save_by_iter else 'latest' 45 | self.model.save_networks(save_suffix) 46 | 47 | iter_data_time = time.time() 48 | if epoch % self.opt.save_epoch_freq == 0: 49 | print('saving the model at the end of epoch %d, iters %d' % (epoch, self.total_iters)) 50 | self.model.save_networks('latest') 51 | self.model.save_networks(epoch) 52 | 53 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, self.opt.n_epochs + self.opt.n_epochs_decay, time.time() - epoch_start_time)) 54 | ...... 55 | --------------------------------------------------------------------------------