├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   └── siggraph_pretrained
│       └── sweep_reference.png
├── data
│   ├── __init__.py
│   ├── aligned_dataset.py
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── color_dataset.py
│   ├── image_folder.py
│   └── single_dataset.py
├── environment.yml
├── imgs
│   └── demo.gif
├── make_ilsvrc_dataset.py
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── networks.py
│   └── pix2pix_model.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   └── train_options.py
├── pretrained_models
│   └── download_siggraph_model.sh
├── requirements.txt
├── resources
│   ├── ilsvrclin12_val_inds.npy
│   └── psnrs_siggraph.npy
├── scripts
│   ├── check_all.sh
│   ├── conda_deps.sh
│   ├── install_deps.sh
│   └── train_siggraph.sh
├── test.py
├── test_sweep.py
├── train.py
└── util
    ├── __init__.py
    ├── get_data.py
    ├── html.py
    ├── image_pool.py
    ├── util.py
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | debug*
3 | datasets/
4 | checkpoints/
5 | results/
6 | build/
7 | dist/
8 | log.txt
9 | *.png
10 | torch.egg-info/
11 | */**/__pycache__
12 | torch/version.py
13 | torch/csrc/generic/TensorMethods.cpp
14 | torch/lib/*.so*
15 | torch/lib/*.dylib*
16 | torch/lib/*.h
17 | torch/lib/build
18 | torch/lib/tmp_install
19 | torch/lib/include
20 | torch/lib/torch_shm_manager
21 | torch/csrc/cudnn/cuDNN.cpp
22 | torch/csrc/nn/THNN.cwrap
23 | torch/csrc/nn/THNN.cpp
24 | torch/csrc/nn/THCUNN.cwrap
25 | torch/csrc/nn/THCUNN.cpp
26 | torch/csrc/nn/THNN_generic.cwrap
27 | torch/csrc/nn/THNN_generic.cpp
28 | torch/csrc/nn/THNN_generic.h
29 | docs/src/**/*
30 | test/data/legacy_modules.t7
31 | test/data/gpu_tensors.pt
32 | test/htmlcov
33 | test/.coverage
34 | */*.pyc
35 | */**/*.pyc
36 | */**/**/*.pyc
37 | */**/**/**/*.pyc
38 | */**/**/**/**/*.pyc
39 | */*.so*
40 | */**/*.so*
41 | */**/*.dylib*
42 | test/data/legacy_serialized.pt
43 | *~
44 | .idea
45 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Richard Zhang, Jun-Yan Zhu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Interactive Deep Colorization in PyTorch
2 | [Project Page](https://richzhang.github.io/ideepcolor/) | [Paper](https://arxiv.org/abs/1705.02999) | [Video](https://youtu.be/eL5ilZgM89Q) | [Talk](https://www.youtube.com/watch?v=rp5LUSbdsys) | [UI code](https://github.com/junyanz/interactive-deep-colorization/)
3 |
4 |
5 |
6 | Real-Time User-Guided Image Colorization with Learned Deep Priors.
7 | [Richard Zhang](https://richzhang.github.io/)\*, [Jun-Yan Zhu](http://people.csail.mit.edu/junyanz/)\*, [Phillip Isola](http://people.eecs.berkeley.edu/~isola/), [Xinyang Geng](http://young-geng.xyz/), Angela S. Lin, Tianhe Yu, and [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/).
8 | In ACM Transactions on Graphics (SIGGRAPH 2017).
9 |
10 | This is our PyTorch reimplementation for interactive image colorization, written by [Richard Zhang](https://github.com/richzhang) and [Jun-Yan Zhu](https://github.com/junyanz).
11 |
12 | This repository contains the training code. The original, official GitHub repo (with an interactive GUI, originally with a Caffe backend) is [here](https://richzhang.github.io/ideepcolor/). The official repo has since been updated to support PyTorch models on the backend, which can be trained with this repository.
13 |
14 | ## Prerequisites
15 | - Linux or macOS
16 | - Python 2 or 3
17 | - CPU or NVIDIA GPU + CUDA CuDNN
18 |
19 | ## Getting Started
20 | ### Installation
21 | - Install PyTorch 0.4+ and torchvision from http://pytorch.org and other dependencies (e.g., [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate)). You can install all the dependencies by
22 | ```bash
23 | pip install -r requirements.txt
24 | ```
25 | - Clone this repo:
26 | ```bash
27 | git clone https://github.com/richzhang/colorization-pytorch
28 | cd colorization-pytorch
29 | ```
30 |
31 | ### Dataset preparation
32 | - Download the ILSVRC 2012 dataset and run the following script to prepare the data:
33 | `python make_ilsvrc_dataset.py --in_path /PATH/TO/ILSVRC12`. This makes symlinks into the training set and divides the ILSVRC validation set into validation and test splits for colorization; the resulting layout is sketched below.
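
The script populates `./dataset/ilsvrc2012/` by default (override with `--out_path`); roughly, the resulting layout is:

```
dataset/ilsvrc2012/
├── train         # symlink to the full ILSVRC12 training set
├── train_small/  # symlinks to the first 10 training classes (used for initialization runs)
├── val/imgs/     # symlinks to the first 1000 ILSVRC12 validation images
└── test/imgs/    # symlinks to the validation images indexed by resources/ilsvrclin12_val_inds.npy
```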
34 |
35 | ### Training interactive colorization
36 | - Train a model: ```bash ./scripts/train_siggraph.sh```. This is a two-stage training process, sketched below. First, the network is trained for automatic colorization using a classification loss; results are in `./checkpoints/siggraph_class`. Then, the network is fine-tuned for interactive colorization using a regression loss; final results are in `./checkpoints/siggraph_reg2`.
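
Conceptually, the two stages optimize different output heads of the same generator (a simplified, self-contained sketch; the tensor names are illustrative and not taken from the training code):

```python
import torch
import torch.nn.functional as F

# stand-ins for the two generator outputs (see SIGGRAPHGenerator in models/networks.py)
out_class = torch.randn(1, 529, 64, 64)           # logits over 529 quantized ab bins (model_class head)
out_reg = torch.tanh(torch.randn(1, 2, 64, 64))   # continuous ab prediction in [-1, 1] (model_out head)
ab_bins_gt = torch.randint(0, 529, (1, 64, 64))   # ground-truth bin index per pixel
ab_gt = torch.rand(1, 2, 64, 64) * 2 - 1          # ground-truth ab values, normalized

# stage 1 (siggraph_class): cross-entropy on the predicted color distribution (monitored as G_CE)
loss_stage1 = F.cross_entropy(out_class, ab_bins_gt)
# stage 2 (siggraph_reg2): per-pixel regression on ab values (monitored as G_L1_reg;
# the actual training loss may be a Huber-style variant, see HuberLoss in models/networks.py)
loss_stage2 = torch.abs(out_reg - ab_gt).sum(dim=1).mean()
```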
37 |
38 | - To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. The following values are monitored (a small sketch of how these ab-space distances can be computed follows the list):
39 | * `G_CE` is a cross-entropy loss between predicted color distribution and ground truth color.
40 | * `G_entr` is the entropy of the predicted distribution.
41 | * `G_entr_hint` is the entropy of the predicted distribution at points where a color hint is given.
42 | * `G_L1_max` is the L1 distance between the ground truth color and argmax of the predicted color distribution.
43 | * `G_L1_mean` is the L1 distance between the ground truth color and mean of the predicted color distribution.
44 | * `G_L1_reg` is the L1 distance between the ground truth color and the predicted color.
45 | * `G_fake_real` is the L1 distance between the predicted color and the ground truth color (in locations where a hint is given).
46 | * `G_fake_hint` is the L1 distance between the predicted color and the input hint color (in locations where a hint is given). It's a measure of how much the network "trusts" the input hint.
47 | * `G_real_hint` is the L1 distance between the ground truth color and the input hint color (in locations where a hint is given).
48 |
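The distance-based values above are L1 distances in ab color space (the network predicts the two ab channels of Lab color), some restricted to hint locations. A minimal sketch of that computation, for orientation only (not the repository's exact metric code):

```python
import torch

def l1_ab(pred_ab, gt_ab, hint_mask=None):
    # pred_ab, gt_ab: (N, 2, H, W) ab channels; hint_mask: (N, 1, H, W), 1 where a hint is given
    dist = torch.abs(pred_ab - gt_ab).sum(dim=1, keepdim=True)  # per-pixel L1 over the ab channels
    if hint_mask is not None:
        return (dist * hint_mask).sum() / hint_mask.sum().clamp(min=1)  # average over hint locations only
    return dist.mean()  # average over all pixels
```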
49 |
50 | ### Testing interactive colorization
51 | - Get a model. Either:
52 | * (1) download the pretrained model by running ```bash pretrained_models/download_siggraph_model.sh```, which will give you a few models.
53 |     * [Recommended] Original Caffe weights: `./checkpoints/siggraph_caffemodel/latest_net_G.pth` contains the original caffemodel weights, converted to PyTorch. Be sure to set `--mask_cent 0` when running it.
54 |     * Retrained model: `./checkpoints/siggraph_retrained/latest_net_G.pth`. This model achieves better PSNR but performs qualitatively differently. Note that this repository is an approximate reimplementation of the SIGGRAPH paper.
55 | * (2) train your own model (as described in the section above), which will leave a model in `./checkpoints/siggraph_reg2/latest_net_G.pth`
56 |
57 | - Test the model on validation data:
58 | * ```python test.py --name siggraph_caffemodel --mask_cent 0``` for original caffemodel weights
59 | * ```python test.py --name siggraph_retrained ``` for retrained weights.
60 | * ```python test.py --name siggraph_reg2 ``` if you retrained your own model
61 | The test results will be saved to an HTML file in `./results/[[NAME]]/latest_val/index.html`. For each image in the validation set, it will test (1) automatic colorization, (2) interactive colorization with a few random hints, and (3) interactive colorization with lots of random hints.
62 |
63 | - Test the model by making a PSNR vs. number-of-hints plot: ```python test_sweep.py --name [[NAME]]```. This plot was used in Figure 6 of the [paper](https://arxiv.org/abs/1705.02999). The test randomly reveals 6x6 color hint patches to the network and measures how accurate the colorization is with respect to the ground truth.
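
PSNR here is the standard peak signal-to-noise ratio computed over 8-bit RGB images; a minimal reference sketch (not the repository's exact evaluation code):

```python
import numpy as np

def psnr(img1, img2, max_val=255.0):
    # img1, img2: uint8 RGB arrays of the same shape
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 10.0 * np.log10(max_val ** 2 / mse)
```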
64 |
65 |
66 |
67 | - Test the model interactively with the original official [repository](https://github.com/junyanz/interactive-deep-colorization). Follow installation instructions in that repo and run `python ideepcolor.py --backend pytorch --color_model [[PTH/TO/MODEL]] --dist_model [[PTH/TO/MODEL]]`.
68 |
69 |
70 | ### Citation
71 | If you use this code for your research, please cite our paper:
72 | ```
73 | @article{zhang2017real,
74 | title={Real-Time User-Guided Image Colorization with Learned Deep Priors},
75 | author={Zhang, Richard and Zhu, Jun-Yan and Isola, Phillip and Geng, Xinyang and Lin, Angela S and Yu, Tianhe and Efros, Alexei A},
76 | journal={ACM Transactions on Graphics (TOG)},
77 |   volume={36},
78 | number={4},
79 | year={2017},
80 | publisher={ACM}
81 | }
82 | ```
83 |
84 | ## Acknowledgments
85 | This code borrows heavily from the [pytorch-CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repository.
86 |
--------------------------------------------------------------------------------
/checkpoints/siggraph_pretrained/sweep_reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/checkpoints/siggraph_pretrained/sweep_reference.png
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.utils.data
3 | from data.base_data_loader import BaseDataLoader
4 | from data.base_dataset import BaseDataset
5 |
6 |
7 | def find_dataset_using_name(dataset_name):
8 | # Given the option --dataset_mode [datasetname],
9 | # the file "data/datasetname_dataset.py"
10 | # will be imported.
11 | dataset_filename = "data." + dataset_name + "_dataset"
12 | datasetlib = importlib.import_module(dataset_filename)
13 |
14 | # In the file, the class called DatasetNameDataset() will
15 | # be instantiated. It has to be a subclass of BaseDataset,
16 | # and it is case-insensitive.
17 | dataset = None
18 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
19 | for name, cls in datasetlib.__dict__.items():
20 | if name.lower() == target_dataset_name.lower() \
21 | and issubclass(cls, BaseDataset):
22 | dataset = cls
23 |
24 | if dataset is None:
25 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
26 | exit(0)
27 |
28 | return dataset
29 |
30 |
31 | def get_option_setter(dataset_name):
32 | dataset_class = find_dataset_using_name(dataset_name)
33 | return dataset_class.modify_commandline_options
34 |
35 |
36 | def create_dataset(opt):
37 | dataset = find_dataset_using_name(opt.dataset_mode)
38 | instance = dataset()
39 | instance.initialize(opt)
40 | print("dataset [%s] was created" % (instance.name()))
41 | return instance
42 |
43 |
44 | def CreateDataLoader(opt):
45 | data_loader = CustomDatasetDataLoader()
46 | data_loader.initialize(opt)
47 | return data_loader
48 |
49 |
50 | # Wrapper class of Dataset class that performs
51 | # multi-threaded data loading
52 | class CustomDatasetDataLoader(BaseDataLoader):
53 | def name(self):
54 | return 'CustomDatasetDataLoader'
55 |
56 | def initialize(self, opt):
57 | BaseDataLoader.initialize(self, opt)
58 | self.dataset = create_dataset(opt)
59 | self.dataloader = torch.utils.data.DataLoader(
60 | self.dataset,
61 | batch_size=opt.batch_size,
62 | shuffle=not opt.serial_batches,
63 | num_workers=int(opt.num_threads))
64 |
65 | def load_data(self):
66 | return self
67 |
68 | def __len__(self):
69 | return min(len(self.dataset), self.opt.max_dataset_size)
70 |
71 | def __iter__(self):
72 | for i, data in enumerate(self.dataloader):
73 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
74 | break
75 | yield data
76 |
--------------------------------------------------------------------------------
/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | import torchvision.transforms as transforms
4 | import torch
5 | from data.base_dataset import BaseDataset
6 | from data.image_folder import make_dataset
7 | from PIL import Image
8 |
9 |
10 | class AlignedDataset(BaseDataset):
11 | @staticmethod
12 | def modify_commandline_options(parser, is_train):
13 | return parser
14 |
15 | def initialize(self, opt):
16 | self.opt = opt
17 | self.root = opt.dataroot
18 | self.dir_AB = os.path.join(opt.dataroot, opt.phase)
19 | self.AB_paths = sorted(make_dataset(self.dir_AB))
20 | assert(opt.resize_or_crop == 'resize_and_crop')
21 |
22 | def __getitem__(self, index):
23 | AB_path = self.AB_paths[index]
24 | AB = Image.open(AB_path).convert('RGB')
25 | w, h = AB.size
26 | w2 = int(w / 2)
27 | A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
28 | B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
29 | A = transforms.ToTensor()(A)
30 | B = transforms.ToTensor()(B)
31 | w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
32 | h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
33 |
34 | A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
35 | B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
36 |
37 | A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
38 | B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)
39 |
40 | if self.opt.which_direction == 'BtoA':
41 | input_nc = self.opt.output_nc
42 | output_nc = self.opt.input_nc
43 | else:
44 | input_nc = self.opt.input_nc
45 | output_nc = self.opt.output_nc
46 |
47 | if (not self.opt.no_flip) and random.random() < 0.5:
48 | idx = [i for i in range(A.size(2) - 1, -1, -1)]
49 | idx = torch.LongTensor(idx)
50 | A = A.index_select(2, idx)
51 | B = B.index_select(2, idx)
52 |
53 | if input_nc == 1: # RGB to gray
54 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
55 | A = tmp.unsqueeze(0)
56 |
57 | if output_nc == 1: # RGB to gray
58 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
59 | B = tmp.unsqueeze(0)
60 |
61 | return {'A': A, 'B': B,
62 | 'A_paths': AB_path, 'B_paths': AB_path}
63 |
64 | def __len__(self):
65 | return len(self.AB_paths)
66 |
67 | def name(self):
68 | return 'AlignedDataset'
69 |
--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 | class BaseDataLoader():
2 | def __init__(self):
3 | pass
4 |
5 | def initialize(self, opt):
6 | self.opt = opt
7 | pass
8 |
 9 |     def load_data(self):
10 | return None
11 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 |
5 |
6 | class BaseDataset(data.Dataset):
7 | def __init__(self):
8 | super(BaseDataset, self).__init__()
9 |
10 | def name(self):
11 | return 'BaseDataset'
12 |
13 | @staticmethod
14 | def modify_commandline_options(parser, is_train):
15 | return parser
16 |
17 | def initialize(self, opt):
18 | pass
19 |
20 | def __len__(self):
21 | return 0
22 |
23 |
24 | def get_transform(opt):
25 | transform_list = []
26 | if opt.resize_or_crop == 'resize_and_crop':
27 | osize = [opt.loadSize, opt.loadSize]
28 | transform_list.append(transforms.Resize(osize, Image.BICUBIC))
29 | transform_list.append(transforms.RandomCrop(opt.fineSize))
30 | elif opt.resize_or_crop == 'crop':
31 | transform_list.append(transforms.RandomCrop(opt.fineSize))
32 | elif opt.resize_or_crop == 'scale_width':
33 | transform_list.append(transforms.Lambda(
34 | lambda img: __scale_width(img, opt.fineSize)))
35 | elif opt.resize_or_crop == 'scale_width_and_crop':
36 | transform_list.append(transforms.Lambda(
37 | lambda img: __scale_width(img, opt.loadSize)))
38 | transform_list.append(transforms.RandomCrop(opt.fineSize))
39 | elif opt.resize_or_crop == 'none':
40 | transform_list.append(transforms.Lambda(
41 | lambda img: __adjust(img)))
42 | else:
43 | raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)
44 |
45 | if opt.isTrain and not opt.no_flip:
46 | transform_list.append(transforms.RandomHorizontalFlip())
47 |
48 | transform_list += [transforms.ToTensor(),
49 | transforms.Normalize((0.5, 0.5, 0.5),
50 | (0.5, 0.5, 0.5))]
51 | return transforms.Compose(transform_list)
52 |
53 | # just modify the width and height to be a multiple of 4
54 |
55 |
56 | def __adjust(img):
57 | ow, oh = img.size
58 |
59 | # the size needs to be a multiple of this number,
60 | # because going through generator network may change img size
61 | # and eventually cause size mismatch error
62 | mult = 4
63 | if ow % mult == 0 and oh % mult == 0:
64 | return img
65 | w = (ow - 1) // mult
66 | w = (w + 1) * mult
67 | h = (oh - 1) // mult
68 | h = (h + 1) * mult
69 |
70 | if ow != w or oh != h:
71 | __print_size_warning(ow, oh, w, h)
72 |
73 | return img.resize((w, h), Image.BICUBIC)
74 |
75 |
76 | def __scale_width(img, target_width):
77 | ow, oh = img.size
78 |
79 | # the size needs to be a multiple of this number,
80 | # because going through generator network may change img size
81 | # and eventually cause size mismatch error
82 | mult = 4
83 | assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
84 | if (ow == target_width and oh % mult == 0):
85 | return img
86 | w = target_width
87 | target_height = int(target_width * oh / ow)
88 | m = (target_height - 1) // mult
89 | h = (m + 1) * mult
90 |
91 | if target_height != h:
92 | __print_size_warning(target_width, target_height, w, h)
93 |
94 | return img.resize((w, h), Image.BICUBIC)
95 |
96 |
97 | def __print_size_warning(ow, oh, w, h):
98 | if not hasattr(__print_size_warning, 'has_printed'):
99 | print("The image size needs to be a multiple of 4. "
100 | "The loaded image size was (%d, %d), so it was adjusted to "
101 | "(%d, %d). This adjustment will be done to all images "
102 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
103 | __print_size_warning.has_printed = True
104 |
--------------------------------------------------------------------------------
/data/color_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_transform
3 | from data.image_folder import make_dataset
4 | from PIL import Image
5 |
6 |
7 | class ColorDataset(BaseDataset):
8 | @staticmethod
9 | def modify_commandline_options(parser, is_train):
10 | return parser
11 |
12 | def initialize(self, opt):
13 | self.opt = opt
14 | self.root = opt.dataroot
15 | self.dir_A = os.path.join(opt.dataroot)
16 |
17 | self.A_paths = make_dataset(self.dir_A)
18 |
19 | self.A_paths = sorted(self.A_paths)
20 |
21 | self.transform = get_transform(opt)
22 |
23 | def __getitem__(self, index):
24 | A_path = self.A_paths[index]
25 | A_img = Image.open(A_path).convert('RGB')
26 | A = self.transform(A_img)
27 | if self.opt.which_direction == 'BtoA':
28 | input_nc = self.opt.output_nc
29 | else:
30 | input_nc = self.opt.input_nc
31 |
32 | # convert to Lab
33 | # rgb2lab(A_img)
34 |
35 | if input_nc == 1: # RGB to gray
36 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
37 | A = tmp.unsqueeze(0)
38 |
39 | return {'A': A, 'A_paths': A_path}
40 |
41 | def __len__(self):
42 | return len(self.A_paths)
43 |
44 | def name(self):
45 | return 'ColorImageDataset'
46 |
--------------------------------------------------------------------------------
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it loads images from the current
5 | # directory as well as its subdirectories
6 | ###############################################################################
7 |
8 | import torch.utils.data as data
9 |
10 | from PIL import Image
11 | import os
12 | import os.path
13 |
14 | IMG_EXTENSIONS = [
15 | '.jpg', '.JPG', '.jpeg', '.JPEG',
16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir):
25 | images = []
26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
27 |
28 | for root, _, fnames in sorted(os.walk(dir)):
29 | for fname in fnames:
30 | if is_image_file(fname):
31 | path = os.path.join(root, fname)
32 | images.append(path)
33 |
34 | return images
35 |
36 |
37 | def default_loader(path):
38 | return Image.open(path).convert('RGB')
39 |
40 |
41 | class ImageFolder(data.Dataset):
42 |
43 | def __init__(self, root, transform=None, return_paths=False,
44 | loader=default_loader):
45 | imgs = make_dataset(root)
46 | if len(imgs) == 0:
47 | raise(RuntimeError("Found 0 images in: " + root + "\n"
48 | "Supported image extensions are: " +
49 | ",".join(IMG_EXTENSIONS)))
50 |
51 | self.root = root
52 | self.imgs = imgs
53 | self.transform = transform
54 | self.return_paths = return_paths
55 | self.loader = loader
56 |
57 | def __getitem__(self, index):
58 | path = self.imgs[index]
59 | img = self.loader(path)
60 | if self.transform is not None:
61 | img = self.transform(img)
62 | if self.return_paths:
63 | return img, path
64 | else:
65 | return img
66 |
67 | def __len__(self):
68 | return len(self.imgs)
69 |
--------------------------------------------------------------------------------
/data/single_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_transform
3 | from data.image_folder import make_dataset
4 | from PIL import Image
5 |
6 |
7 | class SingleDataset(BaseDataset):
8 | @staticmethod
9 | def modify_commandline_options(parser, is_train):
10 | return parser
11 |
12 | def initialize(self, opt):
13 | self.opt = opt
14 | self.root = opt.dataroot
15 | self.dir_A = os.path.join(opt.dataroot)
16 |
17 | self.A_paths = make_dataset(self.dir_A)
18 |
19 | self.A_paths = sorted(self.A_paths)
20 |
21 | self.transform = get_transform(opt)
22 |
23 | def __getitem__(self, index):
24 | A_path = self.A_paths[index]
25 | A_img = Image.open(A_path).convert('RGB')
26 | A = self.transform(A_img)
27 | if self.opt.which_direction == 'BtoA':
28 | input_nc = self.opt.output_nc
29 | else:
30 | input_nc = self.opt.input_nc
31 |
32 | if input_nc == 1: # RGB to gray
33 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
34 | A = tmp.unsqueeze(0)
35 |
36 | return {'A': A, 'A_paths': A_path}
37 |
38 | def __len__(self):
39 | return len(self.A_paths)
40 |
41 | def name(self):
42 | return 'SingleImageDataset'
43 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: colorization-pytorch
2 | channels:
3 | - peterjc123
4 | - defaults
5 | dependencies:
6 | - python=3.5.5
7 | - pytorch=0.4.1
8 | - scipy
9 | - pip:
10 | - dominate==2.3.1
11 | - git+https://github.com/pytorch/vision.git
12 | - Pillow==5.0.0
13 | - numpy==1.14.1
14 | - visdom==0.1.7
15 |
--------------------------------------------------------------------------------
/imgs/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/imgs/demo.gif
--------------------------------------------------------------------------------
/make_ilsvrc_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from util import util
4 | import numpy as np
5 | import argparse
6 |
7 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
8 | parser.add_argument('--in_path', type=str, default='/data/big/dataset/ILSVRC2012')
9 | parser.add_argument('--out_path', type=str, default='./dataset/ilsvrc2012/')
10 |
11 | opt = parser.parse_args()
12 | orig_path = opt.in_path
13 | print('Copying ILSVRC from...[%s]' % orig_path)
14 |
15 | # Copy over part of training set (for initializer)
16 | trn_small_path = os.path.join(opt.out_path, 'train_small')
17 | util.mkdirs(opt.out_path)
18 | util.mkdirs(trn_small_path)
19 | train_subdirs = os.listdir(os.path.join(opt.in_path, 'train'))
20 | for train_subdir in train_subdirs[:10]:
21 | os.symlink(os.path.join(opt.in_path, 'train', train_subdir), os.path.join(trn_small_path, train_subdir))
22 | print('Making small training set in...[%s]' % trn_small_path)
23 |
24 | # Copy over whole training set
25 | trn_path = os.path.join(opt.out_path, 'train')
26 | util.mkdirs(opt.out_path)
27 | os.symlink(os.path.join(opt.in_path, 'train'), trn_path)
28 | print('Making training set in...[%s]' % trn_path)
29 |
30 | # Copy over subset of ILSVRC12 val set for colorization val set
31 | val_path = os.path.join(opt.out_path, 'val/imgs')
32 | util.mkdirs(val_path)
33 | print('Making validation set in...[%s]' % val_path)
34 | for val_ind in range(1000):
35 | os.system('ln -s %s/val/ILSVRC2012_val_%08d.JPEG %s/ILSVRC2012_val_%08d.JPEG' % (orig_path, val_ind + 1, val_path, val_ind + 1))
36 | # os.system('cp %s/val/ILSVRC2012_val_%08d.JPEG %s/ILSVRC2012_val_%08d.JPEG'%(orig_path,val_ind+1,val_path,val_ind+1))
37 |
38 | # Copy over subset of ILSVRC12 val set for colorization test set
39 | test_path = os.path.join(opt.out_path, 'test/imgs')
40 | util.mkdirs(test_path)
41 | val_inds = np.load('./resources/ilsvrclin12_val_inds.npy')
42 | print('Making test set in...[%s]' % test_path)
43 | for val_ind in val_inds:
44 | os.system('ln -s %s/val/ILSVRC2012_val_%08d.JPEG %s/ILSVRC2012_val_%08d.JPEG' % (orig_path, val_ind + 1, test_path, val_ind + 1))
45 | # os.system('cp %s/val/ILSVRC2012_val_%08d.JPEG %s/ILSVRC2012_val_%08d.JPEG'%(orig_path,val_ind+1,test_path,val_ind+1))
46 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from models.base_model import BaseModel
3 |
4 |
5 | def find_model_using_name(model_name):
6 | # Given the option --model [modelname],
7 | # the file "models/modelname_model.py"
8 | # will be imported.
9 | model_filename = "models." + model_name + "_model"
10 | modellib = importlib.import_module(model_filename)
11 |
12 | # In the file, the class called ModelNameModel() will
13 | # be instantiated. It has to be a subclass of BaseModel,
14 | # and it is case-insensitive.
15 | model = None
16 | target_model_name = model_name.replace('_', '') + 'model'
17 | for name, cls in modellib.__dict__.items():
18 | if name.lower() == target_model_name.lower() \
19 | and issubclass(cls, BaseModel):
20 | model = cls
21 |
22 | if model is None:
23 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
24 | exit(0)
25 |
26 | return model
27 |
28 |
29 | def get_option_setter(model_name):
30 | model_class = find_model_using_name(model_name)
31 | return model_class.modify_commandline_options
32 |
33 |
34 | def create_model(opt):
35 | model = find_model_using_name(opt.model)
36 | instance = model()
37 | instance.initialize(opt)
38 | print("model [%s] was created" % (instance.name()))
39 | return instance
40 |
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from . import networks
5 |
6 |
7 | class BaseModel():
8 | # modify parser to add command line options,
9 | # and also change the default values if needed
10 | @staticmethod
11 | def modify_commandline_options(parser, is_train):
12 | return parser
13 |
14 | def name(self):
15 | return 'BaseModel'
16 |
17 | def initialize(self, opt):
18 | self.opt = opt
19 | self.gpu_ids = opt.gpu_ids
20 | self.isTrain = opt.isTrain
21 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
22 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
23 | if opt.resize_or_crop != 'scale_width':
24 | torch.backends.cudnn.benchmark = True
25 | self.loss_names = []
26 | self.model_names = []
27 | self.visual_names = []
28 | self.image_paths = []
29 |
30 | def set_input(self, input):
31 | self.input = input
32 |
33 | def forward(self):
34 | pass
35 |
36 | # load and print networks; create schedulers
37 | def setup(self, opt, parser=None):
38 | if self.isTrain:
39 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
40 |
41 | if not self.isTrain or opt.load_model:
42 | self.load_networks(opt.which_epoch)
43 | self.print_networks(opt.verbose)
44 |
45 | # make models eval mode during test time
46 | def eval(self):
47 | for name in self.model_names:
48 | if isinstance(name, str):
49 | net = getattr(self, 'net' + name)
50 | net.eval()
51 |
52 | # used in test time, wrapping `forward` in no_grad() so we don't save
53 | # intermediate steps for backprop
54 | def test(self, compute_losses=False):
55 | with torch.no_grad():
56 | self.forward()
57 | if(compute_losses):
58 | self.compute_losses_G()
59 |
60 | # get image paths
61 | def get_image_paths(self):
62 | return self.image_paths
63 |
64 | def optimize_parameters(self):
65 | pass
66 |
67 | # update learning rate (called once every epoch)
68 | def update_learning_rate(self):
69 | for scheduler in self.schedulers:
70 | scheduler.step()
71 | lr = self.optimizers[0].param_groups[0]['lr']
72 | print('learning rate = %.7f' % lr)
73 |
74 |     # return visualization images. train.py will display these images and save them to an HTML file
75 | def get_current_visuals(self):
76 | visual_ret = OrderedDict()
77 | for name in self.visual_names:
78 | if isinstance(name, str):
79 | visual_ret[name] = getattr(self, name)
80 | return visual_ret
81 |
82 |     # return training losses/errors. train.py will print out these errors as debugging information
83 | def get_current_losses(self):
84 | errors_ret = OrderedDict()
85 | for name in self.loss_names:
86 | if isinstance(name, str):
87 | # float(...) works for both scalar tensor and float number
88 | errors_ret[name] = float(getattr(self, 'loss_' + name))
89 | return errors_ret
90 |
91 | # save models to the disk
92 | def save_networks(self, which_epoch):
93 | for name in self.model_names:
94 | if isinstance(name, str):
95 | save_filename = '%s_net_%s.pth' % (which_epoch, name)
96 | save_path = os.path.join(self.save_dir, save_filename)
97 | net = getattr(self, 'net' + name)
98 |
99 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
100 | torch.save(net.module.cpu().state_dict(), save_path)
101 | net.cuda(self.gpu_ids[0])
102 | else:
103 | torch.save(net.cpu().state_dict(), save_path)
104 |
105 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
106 | key = keys[i]
107 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
108 | if module.__class__.__name__.startswith('InstanceNorm') and \
109 | (key == 'running_mean' or key == 'running_var'):
110 | if getattr(module, key) is None:
111 | state_dict.pop('.'.join(keys))
112 | if module.__class__.__name__.startswith('InstanceNorm') and \
113 | (key == 'num_batches_tracked'):
114 | state_dict.pop('.'.join(keys))
115 | else:
116 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
117 |
118 | # load models from the disk
119 | def load_networks(self, which_epoch):
120 | for name in self.model_names:
121 | if isinstance(name, str):
122 | load_filename = '%s_net_%s.pth' % (which_epoch, name)
123 | load_path = os.path.join(self.save_dir, load_filename)
124 | net = getattr(self, 'net' + name)
125 | if isinstance(net, torch.nn.DataParallel):
126 | net = net.module
127 | print('loading the model from %s' % load_path)
128 | # if you are using PyTorch newer than 0.4 (e.g., built from
129 | # GitHub source), you can remove str() on self.device
130 | state_dict = torch.load(load_path, map_location=str(self.device))
131 | if hasattr(state_dict, '_metadata'):
132 | del state_dict._metadata
133 |
134 | # patch InstanceNorm checkpoints prior to 0.4
135 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
136 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
137 | net.load_state_dict(state_dict)
138 |
139 | # print network information
140 | def print_networks(self, verbose):
141 | print('---------- Networks initialized -------------')
142 | for name in self.model_names:
143 | if isinstance(name, str):
144 | net = getattr(self, 'net' + name)
145 | num_params = 0
146 | for param in net.parameters():
147 | num_params += param.numel()
148 | if verbose:
149 | print(net)
150 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
151 | print('-----------------------------------------------')
152 |
153 |     # set requires_grad=False to avoid unnecessary computation
154 | def set_requires_grad(self, nets, requires_grad=False):
155 | if not isinstance(nets, list):
156 | nets = [nets]
157 | for net in nets:
158 | if net is not None:
159 | for param in net.parameters():
160 | param.requires_grad = requires_grad
161 |
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 |
7 | ###############################################################################
8 | # Helper Functions
9 | ###############################################################################
10 |
11 |
12 | def get_norm_layer(norm_type='instance'):
13 | if norm_type == 'batch':
14 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
15 | elif norm_type == 'instance':
16 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
17 | elif norm_type == 'none':
18 | norm_layer = None
19 | else:
20 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
21 | return norm_layer
22 |
23 |
24 | def get_scheduler(optimizer, opt):
25 | if opt.lr_policy == 'lambda':
26 | def lambda_rule(epoch):
27 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
28 | return lr_l
29 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
30 | elif opt.lr_policy == 'step':
31 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
32 | elif opt.lr_policy == 'plateau':
33 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
34 | else:
35 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
36 | return scheduler
37 |
38 |
39 | def init_weights(net, init_type='xavier', gain=0.02):
40 | def init_func(m):
41 | classname = m.__class__.__name__
42 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
43 | if init_type == 'normal':
44 | init.normal_(m.weight.data, 0.0, gain)
45 | elif init_type == 'xavier':
46 | init.xavier_normal_(m.weight.data, gain=gain)
47 | elif init_type == 'kaiming':
48 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
49 | elif init_type == 'orthogonal':
50 | init.orthogonal_(m.weight.data, gain=gain)
51 | else:
52 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
53 | if hasattr(m, 'bias') and m.bias is not None:
54 | init.constant_(m.bias.data, 0.0)
55 | elif classname.find('BatchNorm2d') != -1:
56 | init.normal_(m.weight.data, 1.0, gain)
57 | init.constant_(m.bias.data, 0.0)
58 |
59 | print('initialize network with %s' % init_type)
60 | net.apply(init_func)
61 |
62 |
63 | def init_net(net, init_type='xavier', gpu_ids=[]):
64 | if len(gpu_ids) > 0:
65 | assert(torch.cuda.is_available())
66 | net.to(gpu_ids[0])
67 | net = torch.nn.DataParallel(net, gpu_ids)
68 | init_weights(net, init_type)
69 | return net
70 |
71 |
72 | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='xavier', gpu_ids=[], use_tanh=True, classification=True):
73 | netG = None
74 | norm_layer = get_norm_layer(norm_type=norm)
75 |
76 | if which_model_netG == 'resnet_9blocks':
77 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
78 | elif which_model_netG == 'resnet_6blocks':
79 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
80 | elif which_model_netG == 'unet_128':
81 | netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
82 | elif which_model_netG == 'unet_256':
83 | netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
84 | elif which_model_netG == 'siggraph':
85 | netG = SIGGRAPHGenerator(input_nc, output_nc, norm_layer=norm_layer, use_tanh=use_tanh, classification=classification)
86 | else:
87 | raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
88 | return init_net(netG, init_type, gpu_ids)
89 |
90 |
91 | def define_D(input_nc, ndf, which_model_netD,
92 | n_layers_D=3, norm='batch', use_sigmoid=False, init_type='xavier', gpu_ids=[]):
93 | netD = None
94 | norm_layer = get_norm_layer(norm_type=norm)
95 |
96 | if which_model_netD == 'basic':
97 | netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
98 | elif which_model_netD == 'n_layers':
99 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
100 | elif which_model_netD == 'pixel':
101 | netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
102 | else:
103 | raise NotImplementedError('Discriminator model name [%s] is not recognized' %
104 | which_model_netD)
105 | return init_net(netD, init_type, gpu_ids)
106 |
107 |
108 | ##############################################################################
109 | # Classes
110 | ##############################################################################
111 |
112 |
113 | class HuberLoss(nn.Module):
114 | def __init__(self, delta=.01):
115 | super(HuberLoss, self).__init__()
116 | self.delta = delta
117 |
118 | def __call__(self, in0, in1):
119 | mask = torch.zeros_like(in0)
120 | mann = torch.abs(in0 - in1)
121 | eucl = .5 * (mann**2)
122 | mask[...] = mann < self.delta
123 |
124 | # loss = eucl*mask + self.delta*(mann-.5*self.delta)*(1-mask)
125 | loss = eucl * mask / self.delta + (mann - .5 * self.delta) * (1 - mask)
126 | return torch.sum(loss, dim=1, keepdim=True)
127 |
128 |
129 | class L1Loss(nn.Module):
130 | def __init__(self):
131 | super(L1Loss, self).__init__()
132 |
133 | def __call__(self, in0, in1):
134 | return torch.sum(torch.abs(in0 - in1), dim=1, keepdim=True)
135 |
136 |
137 | class L2Loss(nn.Module):
138 | def __init__(self):
139 | super(L2Loss, self).__init__()
140 |
141 | def __call__(self, in0, in1):
142 | return torch.sum((in0 - in1)**2, dim=1, keepdim=True)
143 |
144 |
145 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
146 | # When LSGAN is used, it is basically same as MSELoss,
147 | # but it abstracts away the need to create the target label tensor
148 | # that has the same size as the input
149 | class GANLoss(nn.Module):
150 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
151 | super(GANLoss, self).__init__()
152 | self.register_buffer('real_label', torch.tensor(target_real_label))
153 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
154 | if use_lsgan:
155 | self.loss = nn.MSELoss()
156 | else:
157 | self.loss = nn.BCELoss()
158 |
159 | def get_target_tensor(self, input, target_is_real):
160 | if target_is_real:
161 | target_tensor = self.real_label
162 | else:
163 | target_tensor = self.fake_label
164 | return target_tensor.expand_as(input)
165 |
166 | def __call__(self, input, target_is_real):
167 | target_tensor = self.get_target_tensor(input, target_is_real)
168 | return self.loss(input, target_tensor)
169 |
170 |
171 | class SIGGRAPHGenerator(nn.Module):
172 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, use_tanh=True, classification=True):
173 | super(SIGGRAPHGenerator, self).__init__()
174 | self.input_nc = input_nc
175 | self.output_nc = output_nc
176 | self.classification = classification
177 | use_bias = True
178 |
179 | # Conv1
180 | # model1=[nn.ReflectionPad2d(1),]
181 | model1 = [nn.Conv2d(input_nc, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
182 | # model1+=[norm_layer(64),]
183 | model1 += [nn.ReLU(True), ]
184 | # model1+=[nn.ReflectionPad2d(1),]
185 | model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
186 | model1 += [nn.ReLU(True), ]
187 | model1 += [norm_layer(64), ]
188 | # add a subsampling operation
189 |
190 | # Conv2
191 | # model2=[nn.ReflectionPad2d(1),]
192 | model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
193 | # model2+=[norm_layer(128),]
194 | model2 += [nn.ReLU(True), ]
195 | # model2+=[nn.ReflectionPad2d(1),]
196 | model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
197 | model2 += [nn.ReLU(True), ]
198 | model2 += [norm_layer(128), ]
199 | # add a subsampling layer operation
200 |
201 | # Conv3
202 | # model3=[nn.ReflectionPad2d(1),]
203 | model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
204 | # model3+=[norm_layer(256),]
205 | model3 += [nn.ReLU(True), ]
206 | # model3+=[nn.ReflectionPad2d(1),]
207 | model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
208 | # model3+=[norm_layer(256),]
209 | model3 += [nn.ReLU(True), ]
210 | # model3+=[nn.ReflectionPad2d(1),]
211 | model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
212 | model3 += [nn.ReLU(True), ]
213 | model3 += [norm_layer(256), ]
214 | # add a subsampling layer operation
215 |
216 | # Conv4
217 | # model47=[nn.ReflectionPad2d(1),]
218 | model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
219 | # model4+=[norm_layer(512),]
220 | model4 += [nn.ReLU(True), ]
221 | # model4+=[nn.ReflectionPad2d(1),]
222 | model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
223 | # model4+=[norm_layer(512),]
224 | model4 += [nn.ReLU(True), ]
225 | # model4+=[nn.ReflectionPad2d(1),]
226 | model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
227 | model4 += [nn.ReLU(True), ]
228 | model4 += [norm_layer(512), ]
229 |
230 | # Conv5
231 | # model47+=[nn.ReflectionPad2d(2),]
232 | model5 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
233 | # model5+=[norm_layer(512),]
234 | model5 += [nn.ReLU(True), ]
235 | # model5+=[nn.ReflectionPad2d(2),]
236 | model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
237 | # model5+=[norm_layer(512),]
238 | model5 += [nn.ReLU(True), ]
239 | # model5+=[nn.ReflectionPad2d(2),]
240 | model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
241 | model5 += [nn.ReLU(True), ]
242 | model5 += [norm_layer(512), ]
243 |
244 | # Conv6
245 | # model6+=[nn.ReflectionPad2d(2),]
246 | model6 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
247 | # model6+=[norm_layer(512),]
248 | model6 += [nn.ReLU(True), ]
249 | # model6+=[nn.ReflectionPad2d(2),]
250 | model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
251 | # model6+=[norm_layer(512),]
252 | model6 += [nn.ReLU(True), ]
253 | # model6+=[nn.ReflectionPad2d(2),]
254 | model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
255 | model6 += [nn.ReLU(True), ]
256 | model6 += [norm_layer(512), ]
257 |
258 | # Conv7
259 | # model47+=[nn.ReflectionPad2d(1),]
260 | model7 = [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
261 | # model7+=[norm_layer(512),]
262 | model7 += [nn.ReLU(True), ]
263 | # model7+=[nn.ReflectionPad2d(1),]
264 | model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
265 | # model7+=[norm_layer(512),]
266 | model7 += [nn.ReLU(True), ]
267 | # model7+=[nn.ReflectionPad2d(1),]
268 | model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
269 | model7 += [nn.ReLU(True), ]
270 | model7 += [norm_layer(512), ]
271 |
272 |         # Conv8
273 | model8up = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias)]
274 |
275 | # model3short8=[nn.ReflectionPad2d(1),]
276 | model3short8 = [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
277 |
278 | # model47+=[norm_layer(256),]
279 | model8 = [nn.ReLU(True), ]
280 | # model8+=[nn.ReflectionPad2d(1),]
281 | model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
282 | # model8+=[norm_layer(256),]
283 | model8 += [nn.ReLU(True), ]
284 | # model8+=[nn.ReflectionPad2d(1),]
285 | model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
286 | model8 += [nn.ReLU(True), ]
287 | model8 += [norm_layer(256), ]
288 |
289 | # Conv9
290 | model9up = [nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ]
291 |
292 | # model2short9=[nn.ReflectionPad2d(1),]
293 | model2short9 = [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
294 | # add the two feature maps above
295 |
296 | # model9=[norm_layer(128),]
297 | model9 = [nn.ReLU(True), ]
298 | # model9+=[nn.ReflectionPad2d(1),]
299 | model9 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
300 | model9 += [nn.ReLU(True), ]
301 | model9 += [norm_layer(128), ]
302 |
303 | # Conv10
304 | model10up = [nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ]
305 |
306 | # model1short10=[nn.ReflectionPad2d(1),]
307 | model1short10 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
308 | # add the two feature maps above
309 |
310 | # model10=[norm_layer(128),]
311 | model10 = [nn.ReLU(True), ]
312 | # model10+=[nn.ReflectionPad2d(1),]
313 | model10 += [nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=use_bias), ]
314 | model10 += [nn.LeakyReLU(negative_slope=.2), ]
315 |
316 | # classification output
317 | model_class = [nn.Conv2d(256, 529, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), ]
318 |
319 | # regression output
320 | model_out = [nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), ]
321 | if(use_tanh):
322 | model_out += [nn.Tanh()]
323 |
324 | self.model1 = nn.Sequential(*model1)
325 | self.model2 = nn.Sequential(*model2)
326 | self.model3 = nn.Sequential(*model3)
327 | self.model4 = nn.Sequential(*model4)
328 | self.model5 = nn.Sequential(*model5)
329 | self.model6 = nn.Sequential(*model6)
330 | self.model7 = nn.Sequential(*model7)
331 | self.model8up = nn.Sequential(*model8up)
332 | self.model8 = nn.Sequential(*model8)
333 | self.model9up = nn.Sequential(*model9up)
334 | self.model9 = nn.Sequential(*model9)
335 | self.model10up = nn.Sequential(*model10up)
336 | self.model10 = nn.Sequential(*model10)
337 | self.model3short8 = nn.Sequential(*model3short8)
338 | self.model2short9 = nn.Sequential(*model2short9)
339 | self.model1short10 = nn.Sequential(*model1short10)
340 |
341 | self.model_class = nn.Sequential(*model_class)
342 | self.model_out = nn.Sequential(*model_out)
343 |
344 | self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='nearest'), ])
345 | self.softmax = nn.Sequential(*[nn.Softmax(dim=1), ])
346 |
347 | def forward(self, input_A, input_B, mask_B):
348 | conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1))
349 | conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
350 | conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
351 | conv4_3 = self.model4(conv3_3[:, :, ::2, ::2])
352 | conv5_3 = self.model5(conv4_3)
353 | conv6_3 = self.model6(conv5_3)
354 | conv7_3 = self.model7(conv6_3)
355 | conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
356 | conv8_3 = self.model8(conv8_up)
357 |
358 |         if(self.classification):  # stage 1: train the classification head; the regression branch below is detached from the trunk
359 | out_class = self.model_class(conv8_3)
360 |
361 |             conv9_up = self.model9up(conv8_3.detach()) + self.model2short9(conv2_2.detach())  # detach() keeps regression gradients out of the shared trunk
362 | conv9_3 = self.model9(conv9_up)
363 | conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2.detach())
364 | conv10_2 = self.model10(conv10_up)
365 | out_reg = self.model_out(conv10_2)
366 |         else:  # stage 2: fine-tune with the regression loss; the trunk now receives regression gradients
367 |             out_class = self.model_class(conv8_3.detach())  # detach() keeps classification gradients out of the trunk
368 |
369 | conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
370 | conv9_3 = self.model9(conv9_up)
371 | conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
372 | conv10_2 = self.model10(conv10_up)
373 | out_reg = self.model_out(conv10_2)
374 |
375 | return (out_class, out_reg)
376 |
377 | # Defines the generator that consists of Resnet blocks between a few
378 | # downsampling/upsampling operations.
379 | # Code and idea originally from Justin Johnson's architecture.
380 | # https://github.com/jcjohnson/fast-neural-style/
381 |
382 |
383 | class ResnetGenerator(nn.Module):
384 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
385 | assert(n_blocks >= 0)
386 | super(ResnetGenerator, self).__init__()
387 | self.input_nc = input_nc
388 | self.output_nc = output_nc
389 | self.ngf = ngf
390 | if type(norm_layer) == functools.partial:
391 | use_bias = norm_layer.func == nn.InstanceNorm2d
392 | else:
393 | use_bias = norm_layer == nn.InstanceNorm2d
394 |
395 | model = [nn.ReflectionPad2d(3),
396 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
397 | bias=use_bias),
398 | norm_layer(ngf),
399 | nn.ReLU(True)]
400 |
401 | n_downsampling = 2
402 | for i in range(n_downsampling):
403 | mult = 2**i
404 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
405 | stride=2, padding=1, bias=use_bias),
406 | norm_layer(ngf * mult * 2),
407 | nn.ReLU(True)]
408 |
409 | mult = 2**n_downsampling
410 | for i in range(n_blocks):
411 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
412 |
413 | for i in range(n_downsampling):
414 | mult = 2**(n_downsampling - i)
415 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
416 | kernel_size=3, stride=2,
417 | padding=1, output_padding=1,
418 | bias=use_bias),
419 | norm_layer(int(ngf * mult / 2)),
420 | nn.ReLU(True)]
421 | model += [nn.ReflectionPad2d(3)]
422 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
423 | model += [nn.Tanh()]
424 |
425 | self.model = nn.Sequential(*model)
426 |
427 | def forward(self, input):
428 | return self.model(input)
429 |
430 |
431 | # Define a resnet block
432 | class ResnetBlock(nn.Module):
433 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
434 | super(ResnetBlock, self).__init__()
435 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
436 |
437 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
438 | conv_block = []
439 | p = 0
440 | if padding_type == 'reflect':
441 | conv_block += [nn.ReflectionPad2d(1)]
442 | elif padding_type == 'replicate':
443 | conv_block += [nn.ReplicationPad2d(1)]
444 | elif padding_type == 'zero':
445 | p = 1
446 | else:
447 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
448 |
449 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
450 | norm_layer(dim),
451 | nn.ReLU(True)]
452 | if use_dropout:
453 | conv_block += [nn.Dropout(0.5)]
454 |
455 | p = 0
456 | if padding_type == 'reflect':
457 | conv_block += [nn.ReflectionPad2d(1)]
458 | elif padding_type == 'replicate':
459 | conv_block += [nn.ReplicationPad2d(1)]
460 | elif padding_type == 'zero':
461 | p = 1
462 | else:
463 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
464 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
465 | norm_layer(dim)]
466 |
467 | return nn.Sequential(*conv_block)
468 |
469 | def forward(self, x):
470 | out = x + self.conv_block(x)
471 | return out
472 |
473 |
474 | # Defines the Unet generator.
475 | # |num_downs|: number of downsamplings in UNet. For example,
476 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
477 | # at the bottleneck
478 | class UnetGenerator(nn.Module):
479 | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
480 | norm_layer=nn.BatchNorm2d, use_dropout=False):
481 | super(UnetGenerator, self).__init__()
482 |
483 | # construct unet structure
484 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
485 | for i in range(num_downs - 5):
486 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
487 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
488 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
489 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
490 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
491 |
492 | self.model = unet_block
493 |
494 | def forward(self, input_A, input_B, mask_B):
495 | # embed()
496 | return self.model(torch.cat((input_A, input_B, mask_B), dim=1))
497 |
498 |
499 | # Defines the submodule with skip connection.
500 | # X -------------------identity---------------------- X
501 | # |-- downsampling -- |submodule| -- upsampling --|
502 | class UnetSkipConnectionBlock(nn.Module):
503 | def __init__(self, outer_nc, inner_nc, input_nc=None,
504 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
505 | super(UnetSkipConnectionBlock, self).__init__()
506 | self.outermost = outermost
507 | if type(norm_layer) == functools.partial:
508 | use_bias = norm_layer.func == nn.InstanceNorm2d
509 | else:
510 | use_bias = norm_layer == nn.InstanceNorm2d
511 | if input_nc is None:
512 | input_nc = outer_nc
513 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
514 | stride=2, padding=1, bias=use_bias)
515 | downrelu = nn.LeakyReLU(0.2, True)
516 | downnorm = norm_layer(inner_nc)
517 | uprelu = nn.ReLU(True)
518 | upnorm = norm_layer(outer_nc)
519 |
520 | if outermost:
521 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
522 | kernel_size=4, stride=2,
523 | padding=1)
524 | down = [downconv]
525 | up = [uprelu, upconv, nn.Tanh()]
526 | model = down + [submodule] + up
527 | elif innermost:
528 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
529 | kernel_size=4, stride=2,
530 | padding=1, bias=use_bias)
531 | down = [downrelu, downconv]
532 | up = [uprelu, upconv, upnorm]
533 | model = down + up
534 | else:
535 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
536 | kernel_size=4, stride=2,
537 | padding=1, bias=use_bias)
538 | down = [downrelu, downconv, downnorm]
539 | up = [uprelu, upconv, upnorm]
540 |
541 | if use_dropout:
542 | model = down + [submodule] + up + [nn.Dropout(0.5)]
543 | else:
544 | model = down + [submodule] + up
545 |
546 | self.model = nn.Sequential(*model)
547 |
548 | def forward(self, x):
549 | if self.outermost:
550 | return self.model(x)
551 | else:
552 | return torch.cat([x, self.model(x)], 1)
553 |
554 |
555 | # Defines the PatchGAN discriminator with the specified arguments.
556 | class NLayerDiscriminator(nn.Module):
557 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
558 | super(NLayerDiscriminator, self).__init__()
559 | if type(norm_layer) == functools.partial:
560 | use_bias = norm_layer.func == nn.InstanceNorm2d
561 | else:
562 | use_bias = norm_layer == nn.InstanceNorm2d
563 |
564 | kw = 4
565 | padw = 1
566 | sequence = [
567 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
568 | nn.LeakyReLU(0.2, True)
569 | ]
570 |
571 | nf_mult = 1
572 | nf_mult_prev = 1
573 | for n in range(1, n_layers):
574 | nf_mult_prev = nf_mult
575 | nf_mult = min(2**n, 8)
576 | sequence += [
577 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
578 | kernel_size=kw, stride=2, padding=padw, bias=use_bias),
579 | norm_layer(ndf * nf_mult),
580 | nn.LeakyReLU(0.2, True)
581 | ]
582 |
583 | nf_mult_prev = nf_mult
584 | nf_mult = min(2**n_layers, 8)
585 | sequence += [
586 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
587 | kernel_size=kw, stride=1, padding=padw, bias=use_bias),
588 | norm_layer(ndf * nf_mult),
589 | nn.LeakyReLU(0.2, True)
590 | ]
591 |
592 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
593 |
594 | if use_sigmoid:
595 | sequence += [nn.Sigmoid()]
596 |
597 | self.model = nn.Sequential(*sequence)
598 |
599 | def forward(self, input):
600 | return self.model(input)
601 |
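# With the defaults kw=4 and n_layers=3 the discriminator stacks three stride-2
# convolutions followed by two stride-1 convolutions, so each element of its
# output map scores roughly a 70x70 patch of the input (the usual 70x70 PatchGAN).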
602 |
603 | class PixelDiscriminator(nn.Module):
604 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
605 | super(PixelDiscriminator, self).__init__()
606 | if type(norm_layer) == functools.partial:
607 | use_bias = norm_layer.func == nn.InstanceNorm2d
608 | else:
609 | use_bias = norm_layer == nn.InstanceNorm2d
610 |
611 | self.net = [
612 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
613 | nn.LeakyReLU(0.2, True),
614 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
615 | norm_layer(ndf * 2),
616 | nn.LeakyReLU(0.2, True),
617 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
618 |
619 | if use_sigmoid:
620 | self.net.append(nn.Sigmoid())
621 |
622 | self.net = nn.Sequential(*self.net)
623 |
624 | def forward(self, input):
625 | return self.net(input)
626 |
--------------------------------------------------------------------------------
/models/pix2pix_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 | from util.image_pool import ImagePool
4 | from util import util
5 | from .base_model import BaseModel
6 | from . import networks
7 | import numpy as np
8 |
9 |
10 | class Pix2PixModel(BaseModel):
11 | def name(self):
12 | return 'Pix2PixModel'
13 |
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train=True):
16 | return parser
17 |
18 | def initialize(self, opt):
19 | BaseModel.initialize(self, opt)
20 | self.isTrain = opt.isTrain
21 | self.half = opt.half
22 |
23 | self.use_D = self.opt.lambda_GAN > 0
24 |
25 | # specify the training losses you want to print out. The program will call base_model.get_current_losses
26 |
27 | if(self.use_D):
28 | self.loss_names = ['G_GAN', ]
29 | else:
30 | self.loss_names = []
31 |
32 |         self.loss_names += ['G_CE', 'G_entr', 'G_entr_hint', ]
33 |         self.loss_names += ['G_L1_max', 'G_L1_mean', 'G_L1_reg', ]
34 | self.loss_names += ['G_fake_real', 'G_fake_hint', 'G_real_hint', ]
35 | self.loss_names += ['0', ]
36 |
37 | # specify the images you want to save/display. The program will call base_model.get_current_visuals
38 | self.visual_names = ['real_A', 'fake_B', 'real_B']
39 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
40 |
41 | if self.isTrain:
42 | if(self.use_D):
43 | self.model_names = ['G', 'D']
44 | else:
45 | self.model_names = ['G', ]
46 | else: # during test time, only load Gs
47 | self.model_names = ['G']
48 |
49 | # load/define networks
50 | num_in = opt.input_nc + opt.output_nc + 1
51 | self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf,
52 | opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
53 | use_tanh=True, classification=opt.classification)
54 |
55 | if self.isTrain:
56 | use_sigmoid = opt.no_lsgan
57 | if self.use_D:
58 | self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
59 | opt.which_model_netD,
60 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
61 |
62 | if self.isTrain:
63 | self.fake_AB_pool = ImagePool(opt.pool_size)
64 | # define loss functions
65 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
66 | # self.criterionL1 = torch.nn.L1Loss()
67 | self.criterionL1 = networks.L1Loss()
68 | self.criterionHuber = networks.HuberLoss(delta=1. / opt.ab_norm)
69 |
70 | # if(opt.classification):
71 | self.criterionCE = torch.nn.CrossEntropyLoss()
72 |
73 | # initialize optimizers
74 | self.optimizers = []
75 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
76 | lr=opt.lr, betas=(opt.beta1, 0.999))
77 | self.optimizers.append(self.optimizer_G)
78 |
79 | if self.use_D:
80 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
81 | lr=opt.lr, betas=(opt.beta1, 0.999))
82 | self.optimizers.append(self.optimizer_D)
83 |
84 | if self.half:
85 | for model_name in self.model_names:
86 | net = getattr(self, 'net' + model_name)
87 | net.half()
88 | for layer in net.modules():
89 | if(isinstance(layer, torch.nn.BatchNorm2d)):
90 | layer.float()
91 | print('Net %s half precision' % model_name)
92 |
93 | # initialize average loss values
94 | self.avg_losses = OrderedDict()
95 | self.avg_loss_alpha = opt.avg_loss_alpha
96 | self.error_cnt = 0
97 |
98 | # self.avg_loss_alpha = 0.9993 # half-life of 1000 iterations
99 | # self.avg_loss_alpha = 0.9965 # half-life of 200 iterations
100 | # self.avg_loss_alpha = 0.986 # half-life of 50 iterations
101 | # self.avg_loss_alpha = 0. # no averaging
102 | for loss_name in self.loss_names:
103 | self.avg_losses[loss_name] = 0
104 |
105 | def set_input(self, input):
106 | if(self.half):
107 | for key in input.keys():
108 | input[key] = input[key].half()
109 |
110 | AtoB = self.opt.which_direction == 'AtoB'
111 | self.real_A = input['A' if AtoB else 'B'].to(self.device)
112 | self.real_B = input['B' if AtoB else 'A'].to(self.device)
113 | # self.image_paths = input['A_paths' if AtoB else 'B_paths']
114 | self.hint_B = input['hint_B'].to(self.device)
115 | self.mask_B = input['mask_B'].to(self.device)
116 | self.mask_B_nc = self.mask_B + self.opt.mask_cent
117 |
118 | self.real_B_enc = util.encode_ab_ind(self.real_B[:, :, ::4, ::4], self.opt)
119 |
120 | def forward(self):
121 | (self.fake_B_class, self.fake_B_reg) = self.netG(self.real_A, self.hint_B, self.mask_B)
122 | # if(self.opt.classification):
123 | self.fake_B_dec_max = self.netG.module.upsample4(util.decode_max_ab(self.fake_B_class, self.opt))
124 | self.fake_B_distr = self.netG.module.softmax(self.fake_B_class)
125 |
126 | self.fake_B_dec_mean = self.netG.module.upsample4(util.decode_mean(self.fake_B_distr, self.opt))
127 |
128 | self.fake_B_entr = self.netG.module.upsample4(-torch.sum(self.fake_B_distr * torch.log(self.fake_B_distr + 1.e-10), dim=1, keepdim=True))
129 | # embed()
130 |
131 | def backward_D(self):
132 | # Fake
133 | # stop backprop to the generator by detaching fake_B
134 |         fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B_reg), 1))
135 | pred_fake = self.netD(fake_AB.detach())
136 | self.loss_D_fake = self.criterionGAN(pred_fake, False)
137 | # self.loss_D_fake = 0
138 |
139 | # Real
140 | real_AB = torch.cat((self.real_A, self.real_B), 1)
141 | pred_real = self.netD(real_AB)
142 | self.loss_D_real = self.criterionGAN(pred_real, True)
143 | # self.loss_D_real = 0
144 |
145 | # Combined loss
146 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
147 |
148 | self.loss_D.backward()
149 |
150 | def compute_losses_G(self):
151 | mask_avg = torch.mean(self.mask_B_nc.type(torch.cuda.FloatTensor)) + .000001
152 |
153 | self.loss_0 = 0 # 0 for plot
154 |
155 | # classification statistics
156 | self.loss_G_CE = self.criterionCE(self.fake_B_class.type(torch.cuda.FloatTensor),
157 | self.real_B_enc[:, 0, :, :].type(torch.cuda.LongTensor)) # cross-entropy loss
158 | self.loss_G_entr = torch.mean(self.fake_B_entr.type(torch.cuda.FloatTensor)) # entropy of predicted distribution
159 | self.loss_G_entr_hint = torch.mean(self.fake_B_entr.type(torch.cuda.FloatTensor) * self.mask_B_nc.type(torch.cuda.FloatTensor)) / mask_avg # entropy of predicted distribution at hint points
160 |
161 | # regression statistics
162 | self.loss_G_L1_max = 10 * torch.mean(self.criterionL1(self.fake_B_dec_max.type(torch.cuda.FloatTensor),
163 | self.real_B.type(torch.cuda.FloatTensor)))
164 | self.loss_G_L1_mean = 10 * torch.mean(self.criterionL1(self.fake_B_dec_mean.type(torch.cuda.FloatTensor),
165 | self.real_B.type(torch.cuda.FloatTensor)))
166 | self.loss_G_L1_reg = 10 * torch.mean(self.criterionL1(self.fake_B_reg.type(torch.cuda.FloatTensor),
167 | self.real_B.type(torch.cuda.FloatTensor)))
168 |
169 | # L1 loss at given points
170 | self.loss_G_fake_real = 10 * torch.mean(self.criterionL1(self.fake_B_reg * self.mask_B_nc, self.real_B * self.mask_B_nc).type(torch.cuda.FloatTensor)) / mask_avg
171 | self.loss_G_fake_hint = 10 * torch.mean(self.criterionL1(self.fake_B_reg * self.mask_B_nc, self.hint_B * self.mask_B_nc).type(torch.cuda.FloatTensor)) / mask_avg
172 | self.loss_G_real_hint = 10 * torch.mean(self.criterionL1(self.real_B * self.mask_B_nc, self.hint_B * self.mask_B_nc).type(torch.cuda.FloatTensor)) / mask_avg
173 |
174 | # self.loss_G_L1 = torch.mean(self.criterionL1(self.fake_B, self.real_B))
175 | # self.loss_G_Huber = torch.mean(self.criterionHuber(self.fake_B, self.real_B))
176 | # self.loss_G_fake_real = torch.mean(self.criterionHuber(self.fake_B*self.mask_B_nc, self.real_B*self.mask_B_nc)) / mask_avg
177 | # self.loss_G_fake_hint = torch.mean(self.criterionHuber(self.fake_B*self.mask_B_nc, self.hint_B*self.mask_B_nc)) / mask_avg
178 | # self.loss_G_real_hint = torch.mean(self.criterionHuber(self.real_B*self.mask_B_nc, self.hint_B*self.mask_B_nc)) / mask_avg
179 |
180 |         # total generator loss: classification + regression terms, plus GAN term if enabled
181 |         self.loss_G = self.loss_G_CE * self.opt.lambda_A + self.loss_G_L1_reg
182 |         if self.use_D:
183 |             fake_AB = torch.cat((self.real_A, self.fake_B_reg), 1)
184 |             pred_fake = self.netD(fake_AB)
185 |             self.loss_G_GAN = self.criterionGAN(pred_fake, True)
186 |             self.loss_G = self.loss_G + self.loss_G_GAN * self.opt.lambda_GAN
187 | 
188 | def backward_G(self):
189 | self.compute_losses_G()
190 | self.loss_G.backward()
191 |
192 | def optimize_parameters(self):
193 | self.forward()
194 |
195 | if(self.use_D):
196 | # update D
197 | self.set_requires_grad(self.netD, True)
198 | self.optimizer_D.zero_grad()
199 | self.backward_D()
200 | self.optimizer_D.step()
201 |
202 | self.set_requires_grad(self.netD, False)
203 |
204 | # update G
205 | self.optimizer_G.zero_grad()
206 | self.backward_G()
207 | self.optimizer_G.step()
208 |
209 | def get_current_visuals(self):
210 | from collections import OrderedDict
211 | visual_ret = OrderedDict()
212 |
213 | visual_ret['gray'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), torch.zeros_like(self.real_B).type(torch.cuda.FloatTensor)), dim=1), self.opt)
214 | visual_ret['real'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.real_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
215 |
216 | visual_ret['fake_max'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.fake_B_dec_max.type(torch.cuda.FloatTensor)), dim=1), self.opt)
217 | visual_ret['fake_mean'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.fake_B_dec_mean.type(torch.cuda.FloatTensor)), dim=1), self.opt)
218 | visual_ret['fake_reg'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt)
219 |
220 | visual_ret['hint'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.hint_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
221 |
222 | visual_ret['real_ab'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.real_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
223 |
224 | visual_ret['fake_ab_max'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.fake_B_dec_max.type(torch.cuda.FloatTensor)), dim=1), self.opt)
225 | visual_ret['fake_ab_mean'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.fake_B_dec_mean.type(torch.cuda.FloatTensor)), dim=1), self.opt)
226 | visual_ret['fake_ab_reg'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt)
227 |
228 | visual_ret['mask'] = self.mask_B_nc.expand(-1, 3, -1, -1).type(torch.cuda.FloatTensor)
229 | visual_ret['hint_ab'] = visual_ret['mask'] * util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.hint_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
230 |
231 | C = self.fake_B_distr.shape[1]
232 | # scale to [-1, 2], then clamped to [-1, 1]
233 | visual_ret['fake_entr'] = torch.clamp(3 * self.fake_B_entr.expand(-1, 3, -1, -1) / np.log(C) - 1, -1, 1)
234 |
235 | return visual_ret
236 |
237 | # return training losses/errors. train.py will print out these errors as debugging information
238 | def get_current_losses(self):
239 | self.error_cnt += 1
240 | errors_ret = OrderedDict()
241 | for name in self.loss_names:
242 | if isinstance(name, str):
243 | # float(...) works for both scalar tensor and float number
244 | self.avg_losses[name] = float(getattr(self, 'loss_' + name)) + self.avg_loss_alpha * self.avg_losses[name]
245 | errors_ret[name] = (1 - self.avg_loss_alpha) / (1 - self.avg_loss_alpha**self.error_cnt) * self.avg_losses[name]
246 |
247 | # errors_ret['|ab|_gt'] = float(torch.mean(torch.abs(self.real_B[:,1:,:,:])).cpu())
248 | # errors_ret['|ab|_pr'] = float(torch.mean(torch.abs(self.fake_B[:,1:,:,:])).cpu())
249 |
250 | return errors_ret
251 |
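# Note on the smoothing in get_current_losses: it keeps a running sum
# avg_t = x_t + alpha * avg_{t-1} and reports (1 - alpha) / (1 - alpha**t) * avg_t,
# i.e. a bias-corrected exponential moving average (analogous to Adam's bias
# correction), so the first few iterations are not dragged toward zero.
# A minimal standalone sketch of the same update, assuming a stream of floats xs:
#   alpha, avg = 0.986, 0.0
#   for t, x in enumerate(xs, 1):
#       avg = x + alpha * avg
#       reported = (1 - alpha) / (1 - alpha ** t) * avg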
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import models
6 | import data
7 |
8 |
9 | class BaseOptions():
10 | def __init__(self):
11 | self.initialized = False
12 |
13 | def initialize(self, parser):
14 | parser.add_argument('--batch_size', type=int, default=25, help='input batch size')
15 | parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size')
16 | parser.add_argument('--fineSize', type=int, default=176, help='then crop to this size')
17 | parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels')
18 | parser.add_argument('--output_nc', type=int, default=2, help='# of output image channels')
19 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
20 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
21 | parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
22 | parser.add_argument('--which_model_netG', type=str, default='siggraph', help='selects model to use for netG')
23 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
24 |         parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
25 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
26 | parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single]')
27 | parser.add_argument('--model', type=str, default='pix2pix',
28 | help='chooses which model to use. cycle_gan, pix2pix, test')
29 | parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
30 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
31 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
32 | parser.add_argument('--norm', type=str, default='batch', help='instance normalization or batch normalization')
33 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
34 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
35 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
36 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
37 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
38 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
39 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
40 | help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
41 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
42 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
43 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
44 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
45 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}')
46 | parser.add_argument('--ab_norm', type=float, default=110., help='colorization normalization factor')
47 |         parser.add_argument('--ab_max', type=float, default=110., help='maximum ab value')
48 | parser.add_argument('--ab_quant', type=float, default=10., help='quantization factor')
49 | parser.add_argument('--l_norm', type=float, default=100., help='colorization normalization factor')
50 | parser.add_argument('--l_cent', type=float, default=50., help='colorization centering factor')
51 | parser.add_argument('--mask_cent', type=float, default=.5, help='mask centering factor')
52 |         parser.add_argument('--sample_p', type=float, default=1.0, help='parameter p of the geometric distribution used to sample color hints; 1.0 means no hints')
53 | parser.add_argument('--sample_Ps', type=int, nargs='+', default=[1, 2, 3, 4, 5, 6, 7, 8, 9, ], help='patch sizes')
54 |
55 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
56 | parser.add_argument('--classification', action='store_true', help='backprop trunk using classification, otherwise use regression')
57 | parser.add_argument('--phase', type=str, default='val', help='train_small, train, val, test, etc')
58 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
59 | parser.add_argument('--how_many', type=int, default=200, help='how many test images to run')
60 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
61 |
62 | parser.add_argument('--load_model', action='store_true', help='load the latest model')
63 | parser.add_argument('--half', action='store_true', help='half precision model')
64 |
65 | self.initialized = True
66 | return parser
67 |
68 | def gather_options(self):
69 | # initialize parser with basic options
70 | if not self.initialized:
71 | parser = argparse.ArgumentParser(
72 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
73 | parser = self.initialize(parser)
74 |
75 | # get the basic options
76 | opt, _ = parser.parse_known_args()
77 |
78 | # modify model-related parser options
79 | model_name = opt.model
80 | model_option_setter = models.get_option_setter(model_name)
81 | parser = model_option_setter(parser, self.isTrain)
82 | opt, _ = parser.parse_known_args() # parse again with the new defaults
83 |
84 | # modify dataset-related parser options
85 | dataset_name = opt.dataset_mode
86 | dataset_option_setter = data.get_option_setter(dataset_name)
87 | parser = dataset_option_setter(parser, self.isTrain)
88 |
89 | self.parser = parser
90 |
91 | return parser.parse_args()
92 |
93 | def print_options(self, opt):
94 | message = ''
95 | message += '----------------- Options ---------------\n'
96 | for k, v in sorted(vars(opt).items()):
97 | comment = ''
98 | default = self.parser.get_default(k)
99 | if v != default:
100 | comment = '\t[default: %s]' % str(default)
101 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
102 | message += '----------------- End -------------------'
103 | print(message)
104 |
105 | # save to the disk
106 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
107 | util.mkdirs(expr_dir)
108 | file_name = os.path.join(expr_dir, 'opt.txt')
109 | with open(file_name, 'wt') as opt_file:
110 | opt_file.write(message)
111 | opt_file.write('\n')
112 |
113 | def parse(self):
114 |
115 | opt = self.gather_options()
116 | opt.isTrain = self.isTrain # train or test
117 |
118 | # process opt.suffix
119 | if opt.suffix:
120 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
121 | opt.name = opt.name + suffix
122 |
123 | self.print_options(opt)
124 |
125 | # set gpu ids
126 | str_ids = opt.gpu_ids.split(',')
127 | opt.gpu_ids = []
128 | for str_id in str_ids:
129 | id = int(str_id)
130 | if id >= 0:
131 | opt.gpu_ids.append(id)
132 | if len(opt.gpu_ids) > 0:
133 | torch.cuda.set_device(opt.gpu_ids[0])
134 | opt.A = 2 * opt.ab_max / opt.ab_quant + 1
135 | opt.B = opt.A
136 |
137 | self.opt = opt
138 | return self.opt
139 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | def initialize(self, parser):
6 | BaseOptions.initialize(self, parser)
7 | parser.add_argument('--display_freq', type=int, default=10000, help='frequency of showing training results on screen')
8 | parser.add_argument('--display_ncols', type=int, default=5, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
9 | parser.add_argument('--update_html_freq', type=int, default=10000, help='frequency of saving training results to html')
10 | parser.add_argument('--print_freq', type=int, default=200, help='frequency of showing training results on console')
11 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
12 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
13 |         parser.add_argument('--epoch_count', type=int, default=0, help='the starting epoch count; models are saved as <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
14 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
15 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
16 | parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam')
17 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
18 |         parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least squares GAN; if set, use vanilla GAN instead')
19 | parser.add_argument('--lambda_GAN', type=float, default=0., help='weight for GAN loss')
20 |         parser.add_argument('--lambda_A', type=float, default=1., help='weight on the classification (cross-entropy) loss in the generator objective')
21 | parser.add_argument('--lambda_B', type=float, default=1., help='weight for cycle loss (B -> A -> B)')
22 | parser.add_argument('--lambda_identity', type=float, default=0.5,
23 | help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss.'
24 | 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
25 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
26 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
27 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
28 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
29 | parser.add_argument('--avg_loss_alpha', type=float, default=.986, help='exponential averaging weight for displaying loss')
30 | self.isTrain = True
31 | return parser
32 |
--------------------------------------------------------------------------------
/pretrained_models/download_siggraph_model.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ./checkpoints/siggraph_retrained
2 | MODEL_FILE=./checkpoints/siggraph_retrained/latest_net_G.pth
3 | URL=http://colorization.eecs.berkeley.edu/siggraph/models/pytorch.pth
4 |
5 | wget -N $URL -O $MODEL_FILE
6 |
7 | mkdir -p ./checkpoints/siggraph_caffemodel
8 | MODEL_FILE=./checkpoints/siggraph_caffemodel/latest_net_G.pth
9 | URL=http://colorization.eecs.berkeley.edu/siggraph/models/caffemodel.pth
10 |
11 | wget -N $URL -O $MODEL_FILE
12 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=0.4.0
2 | torchvision>=0.2.1
3 | dominate>=2.3.1
4 | visdom>=0.1.8.3
5 |
--------------------------------------------------------------------------------
/resources/ilsvrclin12_val_inds.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/resources/ilsvrclin12_val_inds.npy
--------------------------------------------------------------------------------
/resources/psnrs_siggraph.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/resources/psnrs_siggraph.npy
--------------------------------------------------------------------------------
/scripts/check_all.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | DOWNLOAD=${1}
3 | echo 'apply a pretrained cyclegan model'
4 | if [ ${DOWNLOAD} -eq 1 ]
5 | then
6 | bash pretrained_models/download_cyclegan_model.sh horse2zebra
7 | bash ./datasets/download_cyclegan_dataset.sh horse2zebra
8 | fi
9 | python test.py --dataroot datasets/horse2zebra/testA --checkpoints_dir ./checkpoints/ --name horse2zebra_pretrained --no_dropout --model test --dataset_mode single --loadSize 256
10 |
11 | echo 'apply a pretrained pix2pix model'
12 | if [ ${DOWNLOAD} -eq 1 ]
13 | then
14 | bash pretrained_models/download_pix2pix_model.sh facades_label2photo
15 | bash ./datasets/download_pix2pix_dataset.sh facades
16 | fi
17 | python test.py --dataroot ./datasets/facades/ --which_direction BtoA --model pix2pix --name facades_label2photo_pretrained --dataset_mode aligned --which_model_netG unet_256 --norm batch
18 |
19 |
20 | echo 'cyclegan train (1 epoch) and test'
21 | if [ ${DOWNLOAD} -eq 1 ]
22 | then
23 | bash ./datasets/download_cyclegan_dataset.sh maps
24 | fi
25 | python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --no_dropout --niter 1 --niter_decay 0 --max_dataset_size 100 --save_latest_freq 100
26 | python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --phase test --no_dropout
27 |
28 |
29 | echo 'pix2pix train (1 epoch) and test'
30 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --dataset_mode aligned --no_lsgan --norm batch --pool_size 0 --niter 1 --niter_decay 0 --save_latest_freq 400
31 | python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned --norm batch
32 |
--------------------------------------------------------------------------------
/scripts/conda_deps.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing
3 | conda install -c pytorch magma-cuda80 # or magma-cuda90 if CUDA 9
4 | conda install pytorch torchvision -c pytorch # install pytorch; if you want to use cuda90, add cuda90
5 | conda install -c conda-forge dominate # install dominate
6 | conda install -c conda-forge visdom # install visdom
7 |
--------------------------------------------------------------------------------
/scripts/install_deps.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | pip install visdom
3 | pip install dominate
4 |
--------------------------------------------------------------------------------
/scripts/train_siggraph.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Train classification network on small training set first
4 | python train.py --name siggraph_class_small --sample_p 1.0 --niter 100 --niter_decay 0 --classification --phase train_small
5 |
6 | # Train classification network first
7 | mkdir ./checkpoints/siggraph_class
8 | cp ./checkpoints/siggraph_class_small/latest_net_G.pth ./checkpoints/siggraph_class/
9 | python train.py --name siggraph_class --sample_p 1.0 --niter 15 --niter_decay 0 --classification --load_model --phase train
10 |
11 | # Train regression model (with color hints)
12 | mkdir ./checkpoints/siggraph_reg
13 | cp ./checkpoints/siggraph_class/latest_net_G.pth ./checkpoints/siggraph_reg/
14 | python train.py --name siggraph_reg --sample_p .125 --niter 10 --niter_decay 0 --lr 0.00001 --load_model --phase train
15 |
16 | # Turn down learning rate to 1e-6
17 | mkdir ./checkpoints/siggraph_reg2
18 | cp ./checkpoints/siggraph_reg/latest_net_G.pth ./checkpoints/siggraph_reg2/
19 | python train.py --name siggraph_reg2 --sample_p .125 --niter 5 --niter_decay 0 --lr 0.000001 --load_model --phase train
20 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from options.train_options import TrainOptions
4 | from models import create_model
5 | from util.visualizer import save_images
6 | from util import html
7 |
8 | import string
9 | import torch
10 | import torchvision
11 | import torchvision.transforms as transforms
12 |
13 | from util import util
14 | import numpy as np
15 |
16 |
17 | if __name__ == '__main__':
18 | sample_ps = [1., .125, .03125]
19 | to_visualize = ['gray', 'hint', 'hint_ab', 'fake_entr', 'real', 'fake_reg', 'real_ab', 'fake_ab_reg', ]
20 | S = len(sample_ps)
21 |
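    # three hint densities are evaluated: p=1.0 means no hints (fully automatic
    # colorization), while p=.125 and p=.03125 sample increasingly many random
    # ground-truth color hints (see util.add_color_patches_rand_gt)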
22 | opt = TrainOptions().parse()
23 | opt.load_model = True
24 | opt.num_threads = 1 # test code only supports num_threads = 1
25 | opt.batch_size = 1 # test code only supports batch_size = 1
26 | opt.display_id = -1 # no visdom display
27 | opt.phase = 'val'
28 | opt.dataroot = './dataset/ilsvrc2012/%s/' % opt.phase
29 | opt.serial_batches = True
30 | opt.aspect_ratio = 1.
31 |
32 | dataset = torchvision.datasets.ImageFolder(opt.dataroot,
33 | transform=transforms.Compose([
34 | transforms.Resize((opt.loadSize, opt.loadSize)),
35 | transforms.ToTensor()]))
36 | dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=not opt.serial_batches)
37 |
38 | model = create_model(opt)
39 | model.setup(opt)
40 | model.eval()
41 |
42 | # create website
43 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
44 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
45 |
46 | # statistics
47 | psnrs = np.zeros((opt.how_many, S))
48 | entrs = np.zeros((opt.how_many, S))
49 |
50 | for i, data_raw in enumerate(dataset_loader):
51 | data_raw[0] = data_raw[0].cuda()
52 | data_raw[0] = util.crop_mult(data_raw[0], mult=8)
53 |
54 | # with no points
55 | for (pp, sample_p) in enumerate(sample_ps):
56 |             img_path = [('%08d_%.3f' % (i, sample_p)).replace('.', 'p')]
57 | data = util.get_colorization_data(data_raw, opt, ab_thresh=0., p=sample_p)
58 |
59 | model.set_input(data)
60 | model.test(True) # True means that losses will be computed
61 | visuals = util.get_subset_dict(model.get_current_visuals(), to_visualize)
62 |
63 | psnrs[i, pp] = util.calculate_psnr_np(util.tensor2im(visuals['real']), util.tensor2im(visuals['fake_reg']))
64 | entrs[i, pp] = model.get_current_losses()['G_entr']
65 |
66 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
67 |
68 | if i % 5 == 0:
69 | print('processing (%04d)-th image... %s' % (i, img_path))
70 |
71 | if i == opt.how_many - 1:
72 | break
73 |
74 | webpage.save()
75 |
76 | # Compute and print some summary statistics
77 | psnrs_mean = np.mean(psnrs, axis=0)
78 | psnrs_std = np.std(psnrs, axis=0) / np.sqrt(opt.how_many)
79 |
80 | entrs_mean = np.mean(entrs, axis=0)
81 | entrs_std = np.std(entrs, axis=0) / np.sqrt(opt.how_many)
82 |
83 | for (pp, sample_p) in enumerate(sample_ps):
84 | print('p=%.3f: %.2f+/-%.2f' % (sample_p, psnrs_mean[pp], psnrs_std[pp]))
85 |
--------------------------------------------------------------------------------
/test_sweep.py:
--------------------------------------------------------------------------------
1 | from options.train_options import TrainOptions
2 | from models import create_model
3 |
4 | import torch
5 | import torchvision
6 | import torchvision.transforms as transforms
7 |
8 | from util import util
9 | import numpy as np
10 | import progressbar as pb
11 | import shutil
12 |
13 | import datetime as dt
14 | import matplotlib.pyplot as plt
15 |
16 | if __name__ == '__main__':
17 | opt = TrainOptions().parse()
18 | opt.load_model = True
19 | opt.num_threads = 1 # test code only supports num_threads = 1
20 | opt.batch_size = 1 # test code only supports batch_size = 1
21 | opt.display_id = -1 # no visdom display
22 | opt.phase = 'test'
23 | opt.dataroot = './dataset/ilsvrc2012/%s/' % opt.phase
24 | opt.loadSize = 256
25 | opt.how_many = 1000
26 | opt.aspect_ratio = 1.0
27 | opt.sample_Ps = [6, ]
28 | opt.load_model = True
29 |
30 | # number of random points to assign
31 | num_points = np.round(10**np.arange(-.1, 2.8, .1))
32 | num_points[0] = 0
33 | num_points = np.unique(num_points.astype('int'))
34 | N = len(num_points)
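    # log-spaced sweep of hint counts from 0 up to about 500 points per image
    # (10**2.7 ~ 501); duplicate values from rounding are removed by np.unique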
35 |
36 | dataset = torchvision.datasets.ImageFolder(opt.dataroot,
37 | transform=transforms.Compose([
38 | transforms.Resize((opt.loadSize, opt.loadSize)),
39 | transforms.ToTensor()]))
40 | dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=not opt.serial_batches)
41 |
42 | model = create_model(opt)
43 | model.setup(opt)
44 | model.eval()
45 |
46 | time = dt.datetime.now()
47 | str_now = '%02d_%02d_%02d%02d' % (time.month, time.day, time.hour, time.minute)
48 |
49 | shutil.copyfile('./checkpoints/%s/latest_net_G.pth' % opt.name, './checkpoints/%s/%s_net_G.pth' % (opt.name, str_now))
50 |
51 | psnrs = np.zeros((opt.how_many, N))
52 |
53 | bar = pb.ProgressBar(max_value=opt.how_many)
54 | for i, data_raw in enumerate(dataset_loader):
55 | data_raw[0] = data_raw[0].cuda()
56 | data_raw[0] = util.crop_mult(data_raw[0], mult=8)
57 |
58 | for nn in range(N):
59 | # embed()
60 | data = util.get_colorization_data(data_raw, opt, ab_thresh=0., num_points=num_points[nn])
61 |
62 | model.set_input(data)
63 | model.test()
64 | visuals = model.get_current_visuals()
65 |
66 | psnrs[i, nn] = util.calculate_psnr_np(util.tensor2im(visuals['real']), util.tensor2im(visuals['fake_reg']))
67 |
68 | if i == opt.how_many - 1:
69 | break
70 |
71 | bar.update(i)
72 |
73 | # Save results
74 | psnrs_mean = np.mean(psnrs, axis=0)
75 | psnrs_std = np.std(psnrs, axis=0) / np.sqrt(opt.how_many)
76 |
77 | np.save('./checkpoints/%s/psnrs_mean_%s' % (opt.name, str_now), psnrs_mean)
78 | np.save('./checkpoints/%s/psnrs_std_%s' % (opt.name, str_now), psnrs_std)
79 | np.save('./checkpoints/%s/psnrs_%s' % (opt.name, str_now), psnrs)
80 |     print(', '.join(['%.2f' % psnr for psnr in psnrs_mean]))
81 |
82 | old_results = np.load('./resources/psnrs_siggraph.npy')
83 | old_mean = np.mean(old_results, axis=0)
84 | old_std = np.std(old_results, axis=0) / np.sqrt(old_results.shape[0])
85 |     print(', '.join(['%.2f' % psnr for psnr in old_mean]))
86 |
87 | num_points_hack = 1. * num_points
88 | num_points_hack[0] = .4
89 |
90 | plt.plot(num_points_hack, psnrs_mean, 'bo-', label=str_now)
91 | plt.plot(num_points_hack, psnrs_mean + psnrs_std, 'b--')
92 | plt.plot(num_points_hack, psnrs_mean - psnrs_std, 'b--')
93 | plt.plot(num_points_hack, old_mean, 'ro-', label='siggraph17')
94 | plt.plot(num_points_hack, old_mean + old_std, 'r--')
95 | plt.plot(num_points_hack, old_mean - old_std, 'r--')
96 |
97 | plt.xscale('log')
98 | plt.xticks([.4, 1, 2, 5, 10, 20, 50, 100, 200, 500],
99 | ['Auto', '1', '2', '5', '10', '20', '50', '100', '200', '500'])
100 | plt.xlabel('Number of points')
101 |     plt.ylabel('PSNR [dB]')
102 | plt.legend(loc=0)
103 | plt.xlim((num_points_hack[0], num_points_hack[-1]))
104 | plt.savefig('./checkpoints/%s/sweep_%s.png' % (opt.name, str_now))
105 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from models import create_model
4 | from util.visualizer import Visualizer
5 |
6 | import torch
7 | import torchvision
8 | import torchvision.transforms as transforms
9 |
10 | from util import util
11 |
12 | if __name__ == '__main__':
13 | opt = TrainOptions().parse()
14 |
15 | opt.dataroot = './dataset/ilsvrc2012/%s/' % opt.phase
16 | dataset = torchvision.datasets.ImageFolder(opt.dataroot,
17 | transform=transforms.Compose([
18 | transforms.RandomChoice([transforms.Resize(opt.loadSize, interpolation=1),
19 | transforms.Resize(opt.loadSize, interpolation=2),
20 | transforms.Resize(opt.loadSize, interpolation=3),
21 | transforms.Resize((opt.loadSize, opt.loadSize), interpolation=1),
22 | transforms.Resize((opt.loadSize, opt.loadSize), interpolation=2),
23 | transforms.Resize((opt.loadSize, opt.loadSize), interpolation=3)]),
24 | transforms.RandomChoice([transforms.RandomResizedCrop(opt.fineSize, interpolation=1),
25 | transforms.RandomResizedCrop(opt.fineSize, interpolation=2),
26 | transforms.RandomResizedCrop(opt.fineSize, interpolation=3)]),
27 | transforms.RandomHorizontalFlip(),
28 | transforms.ToTensor()]))
29 | dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.num_threads))
30 |
31 | dataset_size = len(dataset)
32 | print('#training images = %d' % dataset_size)
33 |
34 | model = create_model(opt)
35 | model.setup(opt)
36 | model.print_networks(True)
37 |
38 | visualizer = Visualizer(opt)
39 | total_steps = 0
40 |
41 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay):
42 | epoch_start_time = time.time()
43 | iter_data_time = time.time()
44 | epoch_iter = 0
45 |
46 | # for i, data in enumerate(dataset):
47 | for i, data_raw in enumerate(dataset_loader):
48 | data_raw[0] = data_raw[0].cuda()
49 | data = util.get_colorization_data(data_raw, opt, p=opt.sample_p)
50 | if(data is None):
51 | continue
52 |
53 | iter_start_time = time.time()
54 | if total_steps % opt.print_freq == 0:
55 | # time to load data
56 | t_data = iter_start_time - iter_data_time
57 | visualizer.reset()
58 | total_steps += opt.batch_size
59 | epoch_iter += opt.batch_size
60 | model.set_input(data)
61 | model.optimize_parameters()
62 |
63 | if total_steps % opt.display_freq == 0:
64 | save_result = total_steps % opt.update_html_freq == 0
65 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
66 |
67 | if total_steps % opt.print_freq == 0:
68 | losses = model.get_current_losses()
69 | # time to do forward&backward
70 | t = time.time() - iter_start_time
71 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
72 | if opt.display_id > 0:
73 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses)
74 |
75 | if total_steps % opt.save_latest_freq == 0:
76 | print('saving the latest model (epoch %d, total_steps %d)' %
77 | (epoch, total_steps))
78 | model.save_networks('latest')
79 |
80 | iter_data_time = time.time()
81 |
82 | if epoch % opt.save_epoch_freq == 0:
83 | print('saving the model at the end of epoch %d, iters %d' %
84 | (epoch, total_steps))
85 | model.save_networks('latest')
86 | model.save_networks(epoch)
87 |
88 | print('End of epoch %d / %d \t Time Taken: %d sec' %
89 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
90 | model.update_learning_rate()
91 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/util/__init__.py
--------------------------------------------------------------------------------
/util/get_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import tarfile
4 | import requests
5 | from warnings import warn
6 | from zipfile import ZipFile
7 | from bs4 import BeautifulSoup
8 | from os.path import abspath, isdir, join, basename
9 |
10 |
11 | class GetData(object):
12 | """
13 |
14 | Download CycleGAN or Pix2Pix Data.
15 |
16 | Args:
17 | technique : str
18 | One of: 'cyclegan' or 'pix2pix'.
19 | verbose : bool
20 | If True, print additional information.
21 |
22 | Examples:
23 | >>> from util.get_data import GetData
24 | >>> gd = GetData(technique='cyclegan')
25 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
26 |
27 | """
28 |
29 | def __init__(self, technique='cyclegan', verbose=True):
30 | url_dict = {
31 | 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
32 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
33 | }
34 | self.url = url_dict.get(technique.lower())
35 | self._verbose = verbose
36 |
37 | def _print(self, text):
38 | if self._verbose:
39 | print(text)
40 |
41 | @staticmethod
42 | def _get_options(r):
43 | soup = BeautifulSoup(r.text, 'lxml')
44 | options = [h.text for h in soup.find_all('a', href=True)
45 | if h.text.endswith(('.zip', 'tar.gz'))]
46 | return options
47 |
48 | def _present_options(self):
49 | r = requests.get(self.url)
50 | options = self._get_options(r)
51 | print('Options:\n')
52 | for i, o in enumerate(options):
53 | print("{0}: {1}".format(i, o))
54 | choice = input("\nPlease enter the number of the "
55 | "dataset above you wish to download:")
56 | return options[int(choice)]
57 |
58 | def _download_data(self, dataset_url, save_path):
59 | if not isdir(save_path):
60 | os.makedirs(save_path)
61 |
62 | base = basename(dataset_url)
63 | temp_save_path = join(save_path, base)
64 |
65 | with open(temp_save_path, "wb") as f:
66 | r = requests.get(dataset_url)
67 | f.write(r.content)
68 |
69 | if base.endswith('.tar.gz'):
70 | obj = tarfile.open(temp_save_path)
71 | elif base.endswith('.zip'):
72 | obj = ZipFile(temp_save_path, 'r')
73 | else:
74 | raise ValueError("Unknown File Type: {0}.".format(base))
75 |
76 | self._print("Unpacking Data...")
77 | obj.extractall(save_path)
78 | obj.close()
79 | os.remove(temp_save_path)
80 |
81 | def get(self, save_path, dataset=None):
82 | """
83 |
84 | Download a dataset.
85 |
86 | Args:
87 | save_path : str
88 | A directory to save the data to.
89 | dataset : str, optional
90 | A specific dataset to download.
91 | Note: this must include the file extension.
92 | If None, options will be presented for you
93 | to choose from.
94 |
95 | Returns:
96 | save_path_full : str
97 | The absolute path to the downloaded data.
98 |
99 | """
100 | if dataset is None:
101 | selected_dataset = self._present_options()
102 | else:
103 | selected_dataset = dataset
104 |
105 | save_path_full = join(save_path, selected_dataset.split('.')[0])
106 |
107 | if isdir(save_path_full):
108 |             warn("\n'{0}' already exists. Skipping download.".format(
109 | save_path_full))
110 | else:
111 | self._print('Downloading Data...')
112 | url = "{0}/{1}".format(self.url, selected_dataset)
113 | self._download_data(url, save_path=save_path)
114 |
115 | return abspath(save_path_full)
116 |
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, reflesh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 | # print(self.img_dir)
16 |
17 | self.doc = dominate.document(title=title)
18 | if reflesh > 0:
19 | with self.doc.head:
20 |                 meta(http_equiv="refresh", content=str(reflesh))
21 |
22 | def get_image_dir(self):
23 | return self.img_dir
24 |
25 | def add_header(self, str):
26 | with self.doc:
27 | h3(str)
28 |
29 | def add_table(self, border=1):
30 | self.t = table(border=border, style="table-layout: fixed;")
31 | self.doc.add(self.t)
32 |
33 | def add_images(self, ims, txts, links, width=400):
34 | self.add_table()
35 | with self.t:
36 | with tr():
37 | for im, txt, link in zip(ims, txts, links):
38 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
39 | with p():
40 | with a(href=os.path.join('images', link)):
41 | img(style="width:%dpx" % width, src=os.path.join('images', im))
42 | br()
43 | p(txt)
44 |
45 | def save(self):
46 | html_file = '%s/index.html' % self.web_dir
47 | f = open(html_file, 'wt')
48 | f.write(self.doc.render())
49 | f.close()
50 |
51 |
52 | if __name__ == '__main__':
53 | html = HTML('web/', 'test_html')
54 | html.add_header('hello world')
55 |
56 | ims = []
57 | txts = []
58 | links = []
59 | for n in range(4):
60 | ims.append('image_%d.png' % n)
61 | txts.append('text_%d' % n)
62 | links.append('image_%d.png' % n)
63 | html.add_images(ims, txts, links)
64 | html.save()
65 |
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 |
5 | class ImagePool():
6 | def __init__(self, pool_size):
7 | self.pool_size = pool_size
8 | if self.pool_size > 0:
9 | self.num_imgs = 0
10 | self.images = []
11 |
12 | def query(self, images):
13 | if self.pool_size == 0:
14 | return images
15 | return_images = []
16 | for image in images:
17 | image = torch.unsqueeze(image.data, 0)
18 | if self.num_imgs < self.pool_size:
19 | self.num_imgs = self.num_imgs + 1
20 | self.images.append(image)
21 | return_images.append(image)
22 | else:
23 | p = random.uniform(0, 1)
24 | if p > 0.5:
25 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
26 | tmp = self.images[random_id].clone()
27 | self.images[random_id] = image
28 | return_images.append(tmp)
29 | else:
30 | return_images.append(image)
31 | return_images = torch.cat(return_images, 0)
32 | return return_images
33 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | from collections import OrderedDict
7 | from IPython import embed
8 |
9 | # Converts a Tensor into an image array (numpy)
10 | # |imtype|: the desired type of the converted numpy array
11 | def tensor2im(input_image, imtype=np.uint8):
12 | if isinstance(input_image, torch.Tensor):
13 | image_tensor = input_image.data
14 | else:
15 | return input_image
16 | image_numpy = image_tensor[0].cpu().float().numpy()
17 | if image_numpy.shape[0] == 1:
18 | image_numpy = np.tile(image_numpy, (3, 1, 1))
19 |     image_numpy = np.clip(np.transpose(image_numpy, (1, 2, 0)), 0, 1) * 255.0
20 | return image_numpy.astype(imtype)
21 |
22 |
23 | def diagnose_network(net, name='network'):
24 | mean = 0.0
25 | count = 0
26 | for param in net.parameters():
27 | if param.grad is not None:
28 | mean += torch.mean(torch.abs(param.grad.data))
29 | count += 1
30 | if count > 0:
31 | mean = mean / count
32 | print(name)
33 | print(mean)
34 |
35 |
36 | def save_image(image_numpy, image_path):
37 | image_pil = Image.fromarray(image_numpy)
38 | image_pil.save(image_path)
39 |
40 |
41 | def print_numpy(x, val=True, shp=False):
42 | x = x.astype(np.float64)
43 | if shp:
44 | print('shape,', x.shape)
45 | if val:
46 | x = x.flatten()
47 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
48 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
49 |
50 |
51 | def mkdirs(paths):
52 | if isinstance(paths, list) and not isinstance(paths, str):
53 | for path in paths:
54 | mkdir(path)
55 | else:
56 | mkdir(paths)
57 |
58 |
59 | def mkdir(path):
60 | if not os.path.exists(path):
61 | os.makedirs(path)
62 |
63 |
64 | def get_subset_dict(in_dict,keys):
65 | if(len(keys)):
66 | subset = OrderedDict()
67 | for key in keys:
68 | subset[key] = in_dict[key]
69 | else:
70 | subset = in_dict
71 | return subset
72 |
73 |
74 |
75 | # Color conversion code
76 | def rgb2xyz(rgb): # rgb from [0,1]
77 | # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423],
78 | # [0.212671, 0.715160, 0.072169],
79 | # [0.019334, 0.119193, 0.950227]])
80 |
81 | mask = (rgb > .04045).type(torch.FloatTensor)
82 | if(rgb.is_cuda):
83 | mask = mask.cuda()
84 |
85 | rgb = (((rgb+.055)/1.055)**2.4)*mask + rgb/12.92*(1-mask)
86 |
87 | x = .412453*rgb[:,0,:,:]+.357580*rgb[:,1,:,:]+.180423*rgb[:,2,:,:]
88 | y = .212671*rgb[:,0,:,:]+.715160*rgb[:,1,:,:]+.072169*rgb[:,2,:,:]
89 | z = .019334*rgb[:,0,:,:]+.119193*rgb[:,1,:,:]+.950227*rgb[:,2,:,:]
90 | out = torch.cat((x[:,None,:,:],y[:,None,:,:],z[:,None,:,:]),dim=1)
91 |
92 | # if(torch.sum(torch.isnan(out))>0):
93 | # print('rgb2xyz')
94 | # embed()
95 | return out
96 |
97 | def xyz2rgb(xyz):
98 | # array([[ 3.24048134, -1.53715152, -0.49853633],
99 | # [-0.96925495, 1.87599 , 0.04155593],
100 | # [ 0.05564664, -0.20404134, 1.05731107]])
101 |
102 | r = 3.24048134*xyz[:,0,:,:]-1.53715152*xyz[:,1,:,:]-0.49853633*xyz[:,2,:,:]
103 | g = -0.96925495*xyz[:,0,:,:]+1.87599*xyz[:,1,:,:]+.04155593*xyz[:,2,:,:]
104 | b = .05564664*xyz[:,0,:,:]-.20404134*xyz[:,1,:,:]+1.05731107*xyz[:,2,:,:]
105 |
106 | rgb = torch.cat((r[:,None,:,:],g[:,None,:,:],b[:,None,:,:]),dim=1)
107 | rgb = torch.max(rgb,torch.zeros_like(rgb)) # sometimes reaches a small negative number, which causes NaNs
108 |
109 | mask = (rgb > .0031308).type(torch.FloatTensor)
110 | if(rgb.is_cuda):
111 | mask = mask.cuda()
112 |
113 | rgb = (1.055*(rgb**(1./2.4)) - 0.055)*mask + 12.92*rgb*(1-mask)
114 |
115 | # if(torch.sum(torch.isnan(rgb))>0):
116 | # print('xyz2rgb')
117 | # embed()
118 | return rgb
119 |
120 | def xyz2lab(xyz):
121 | # 0.95047, 1., 1.08883 # white
122 | sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
123 | if(xyz.is_cuda):
124 | sc = sc.cuda()
125 |
126 | xyz_scale = xyz/sc
127 |
128 | mask = (xyz_scale > .008856).type(torch.FloatTensor)
129 | if(xyz_scale.is_cuda):
130 | mask = mask.cuda()
131 |
132 | xyz_int = xyz_scale**(1/3.)*mask + (7.787*xyz_scale + 16./116.)*(1-mask)
133 |
134 | L = 116.*xyz_int[:,1,:,:]-16.
135 | a = 500.*(xyz_int[:,0,:,:]-xyz_int[:,1,:,:])
136 | b = 200.*(xyz_int[:,1,:,:]-xyz_int[:,2,:,:])
137 | out = torch.cat((L[:,None,:,:],a[:,None,:,:],b[:,None,:,:]),dim=1)
138 |
139 | # if(torch.sum(torch.isnan(out))>0):
140 | # print('xyz2lab')
141 | # embed()
142 |
143 | return out
144 |
145 | def lab2xyz(lab):
146 | y_int = (lab[:,0,:,:]+16.)/116.
147 | x_int = (lab[:,1,:,:]/500.) + y_int
148 | z_int = y_int - (lab[:,2,:,:]/200.)
149 | if(z_int.is_cuda):
150 | z_int = torch.max(torch.Tensor((0,)).cuda(), z_int)
151 | else:
152 | z_int = torch.max(torch.Tensor((0,)), z_int)
153 |
154 | out = torch.cat((x_int[:,None,:,:],y_int[:,None,:,:],z_int[:,None,:,:]),dim=1)
155 | mask = (out > .2068966).type(torch.FloatTensor)
156 | if(out.is_cuda):
157 | mask = mask.cuda()
158 |
159 | out = (out**3.)*mask + (out - 16./116.)/7.787*(1-mask)
160 |
161 | sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
162 | sc = sc.to(out.device)
163 |
164 | out = out*sc
165 |
166 | # if(torch.sum(torch.isnan(out))>0):
167 | # print('lab2xyz')
168 | # embed()
169 |
170 | return out
171 |
172 | def rgb2lab(rgb, opt):
173 | lab = xyz2lab(rgb2xyz(rgb))
174 | l_rs = (lab[:,[0],:,:]-opt.l_cent)/opt.l_norm
175 | ab_rs = lab[:,1:,:,:]/opt.ab_norm
176 | out = torch.cat((l_rs,ab_rs),dim=1)
177 | # if(torch.sum(torch.isnan(out))>0):
178 | # print('rgb2lab')
179 | # embed()
180 | return out
181 |
182 | def lab2rgb(lab_rs, opt):
183 | l = lab_rs[:,[0],:,:]*opt.l_norm + opt.l_cent
184 | ab = lab_rs[:,1:,:,:]*opt.ab_norm
185 | lab = torch.cat((l,ab),dim=1)
186 | out = xyz2rgb(lab2xyz(lab))
187 | # if(torch.sum(torch.isnan(out))>0):
188 | # print('lab2rgb')
189 | # embed()
190 | return out
191 |
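# With the default options (l_cent=50, l_norm=100, ab_norm=110) the normalized L
# channel lies roughly in [-0.5, 0.5] and the normalized ab channels in [-1, 1],
# which is the range the network operates on. A quick round-trip sanity check
# (a sketch, assuming `img` is an Nx3xHxW RGB tensor in [0, 1] and `opt` holds
# the defaults):
#   lab = rgb2lab(img, opt)
#   rgb = lab2rgb(lab, opt)  # should match img up to clipping/rounding error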
192 | def get_colorization_data(data_raw, opt, ab_thresh=5., p=.125, num_points=None):
193 | data = {}
194 |
195 | data_lab = rgb2lab(data_raw[0], opt)
196 | data['A'] = data_lab[:,[0,],:,:]
197 | data['B'] = data_lab[:,1:,:,:]
198 |
199 | if(ab_thresh > 0): # mask out grayscale images
200 | thresh = 1.*ab_thresh/opt.ab_norm
201 | mask = torch.sum(torch.abs(torch.max(torch.max(data['B'],dim=3)[0],dim=2)[0]-torch.min(torch.min(data['B'],dim=3)[0],dim=2)[0]),dim=1) >= thresh
202 | data['A'] = data['A'][mask,:,:,:]
203 | data['B'] = data['B'][mask,:,:,:]
204 | # print('Removed %i points'%torch.sum(mask==0).numpy())
205 | if(torch.sum(mask)==0):
206 | return None
207 |
208 | return add_color_patches_rand_gt(data, opt, p=p, num_points=num_points)
209 |
210 | def add_color_patches_rand_gt(data,opt,p=.125,num_points=None,use_avg=True,samp='normal'):
211 | # Add random color points sampled from ground truth, based on (see the illustrative sketch after this function):
212 | # Number of points
213 | # - if num_points is None, sample the count from a geometric distribution with success probability p
214 | # - if num_points > 0, then sample that number of points
215 | # Location of points
216 | # - if samp is 'normal', draw from N(0.5, 0.25) of image
217 | # - otherwise, draw from U[0, 1] of image
218 | N,C,H,W = data['B'].shape
219 |
220 | data['hint_B'] = torch.zeros_like(data['B'])
221 | data['mask_B'] = torch.zeros_like(data['A'])
222 |
223 | for nn in range(N):
224 | pp = 0
225 | cont_cond = True
226 | while(cont_cond):
227 | if(num_points is None): # draw from geometric
228 | # embed()
229 | cont_cond = np.random.rand() < (1-p)
230 | else: # add certain number of points
231 | cont_cond = pp < num_points
232 | if(not cont_cond): # skip out of loop if condition not met
233 | continue
234 |
235 | P = np.random.choice(opt.sample_Ps) # patch size
236 |
237 | # sample location
238 | if(samp=='normal'): # draw location from a clipped normal (Gaussian) distribution
239 | h = int(np.clip(np.random.normal( (H-P+1)/2., (H-P+1)/4.), 0, H-P))
240 | w = int(np.clip(np.random.normal( (W-P+1)/2., (W-P+1)/4.), 0, W-P))
241 | else: # uniform distribution
242 | h = np.random.randint(H-P+1)
243 | w = np.random.randint(W-P+1)
244 |
245 | # add color point
246 | if(use_avg):
247 | # embed()
248 | data['hint_B'][nn,:,h:h+P,w:w+P] = torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=1,keepdim=True).view(1,C,1,1)
249 | else:
250 | data['hint_B'][nn,:,h:h+P,w:w+P] = data['B'][nn,:,h:h+P,w:w+P]
251 |
252 | data['mask_B'][nn,:,h:h+P,w:w+P] = 1
253 |
254 | # increment counter
255 | pp+=1
256 |
257 | data['mask_B']-=opt.mask_cent
258 |
259 | return data
260 |
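# Illustrative sketch (not part of the original file): typical use of the hint simulation
# above during training. data_raw[0] is assumed to be an Nx3xHxW RGB batch in [0, 1], and
# opt is assumed to carry the sampling options used above (sample_Ps, mask_cent, ab_norm, ...);
# the helper name is hypothetical.
def _example_simulate_hints(data_raw, opt):
    data = get_colorization_data(data_raw, opt, ab_thresh=5., p=.125)
    if data is None:  # the whole batch was (near-)grayscale and got masked out
        return None
    # data['A']: L channel, data['B']: ground-truth ab,
    # data['hint_B']: sparse ab hints, data['mask_B']: hint mask (shifted by mask_cent)
    return data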
261 | def add_color_patch(data,mask,opt,P=1,hw=[128,128],ab=[0,0]):
262 | # Add a color patch at (h,w) with color (a,b)
263 | data[:,0,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1.*ab[0]/opt.ab_norm
264 | data[:,1,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1.*ab[1]/opt.ab_norm
265 | mask[:,:,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1-opt.mask_cent
266 |
267 | return (data,mask)
268 |
269 | def crop_mult(data,mult=16,HWmax=[800,1200]):
270 | # center-crop to the largest region whose sides are multiples of mult (capped at HWmax)
271 | H,W = data.shape[2:]
272 | Hnew = int(min(H//mult*mult,HWmax[0])) # floor to a multiple of mult (H/mult*mult is a no-op under Python 3 division)
273 | Wnew = int(min(W//mult*mult,HWmax[1]))
274 | h = (H-Hnew)//2 # integer offsets; float offsets would break the slicing below
275 | w = (W-Wnew)//2
276 |
277 | return data[:,:,h:h+Hnew,w:w+Wnew]
278 |
279 | def encode_ab_ind(data_ab, opt):
280 | # Encode ab value into an index
281 | # INPUTS
282 | # data_ab Nx2xHxW \in [-1,1]
283 | # OUTPUTS
284 | # data_q Nx1xHxW \in [0,Q)
285 |
286 | data_ab_rs = torch.round((data_ab*opt.ab_norm + opt.ab_max)/opt.ab_quant) # normalized bin number
287 | data_q = data_ab_rs[:,[0],:,:]*opt.A + data_ab_rs[:,[1],:,:]
288 | return data_q
289 |
290 | def decode_ind_ab(data_q, opt):
291 | # Decode index into ab value
292 | # INPUTS
293 | # data_q Nx1xHxW \in [0,Q)
294 | # OUTPUTS
295 | # data_ab Nx2xHxW \in [-1,1]
296 |
297 | data_a = data_q//opt.A # floor division to recover the a-bin index (plain / is true division on newer PyTorch)
298 | data_b = data_q - data_a*opt.A
299 | data_ab = torch.cat((data_a,data_b),dim=1)
300 |
301 | if(data_q.is_cuda):
302 | type_out = torch.cuda.FloatTensor
303 | else:
304 | type_out = torch.FloatTensor
305 | data_ab = ((data_ab.type(type_out)*opt.ab_quant) - opt.ab_max)/opt.ab_norm
306 |
307 | return data_ab
308 |
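# Worked example (added for clarity, assuming the repo's default quantization of
# ab_max=110 and ab_quant=10, so A = 2*110/10 + 1 = 23 bins per axis and Q = 23*23 = 529):
#   an ab value of (50, -30) in raw Lab units, i.e. data_ab = (50, -30)/ab_norm,
#   encodes in encode_ab_ind to bin indices round((50+110)/10) = 16 and round((-30+110)/10) = 8,
#   giving q = 16*23 + 8 = 376;
#   decode_ind_ab then recovers a = 376 // 23 = 16 -> 16*10 - 110 = 50 and
#   b = 376 - 16*23 = 8 -> 8*10 - 110 = -30, each divided by ab_norm again.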
309 | def decode_max_ab(data_ab_quant, opt):
310 | # Decode probability distribution by using bin with highest probability
311 | # INPUTS
312 | # data_ab_quant NxQxHxW \in [0,1]
313 | # OUTPUTS
314 | # data_ab Nx2xHxW \in [-1,1]
315 |
316 | data_q = torch.argmax(data_ab_quant,dim=1)[:,None,:,:]
317 | return decode_ind_ab(data_q, opt)
318 |
319 | def decode_mean(data_ab_quant, opt):
320 | # Decode probability distribution by taking mean over all bins
321 | # INPUTS
322 | # data_ab_quant NxQxHxW \in [0,1]
323 | # OUTPUTS
324 | # data_ab_inf Nx2xHxW \in [-1,1]
325 |
326 | (N,Q,H,W) = data_ab_quant.shape
327 | a_range = torch.range(-opt.ab_max, opt.ab_max, step=opt.ab_quant).to(data_ab_quant.device)[None,:,None,None]
328 | a_range = a_range.type(data_ab_quant.type())
329 |
330 | # reshape to AB space
331 | data_ab_quant = data_ab_quant.view((N,int(opt.A),int(opt.A),H,W))
332 | data_a_total = torch.sum(data_ab_quant,dim=2)
333 | data_b_total = torch.sum(data_ab_quant,dim=1)
334 |
335 | # matrix multiply
336 | data_a_inf = torch.sum(data_a_total * a_range,dim=1,keepdim=True)
337 | data_b_inf = torch.sum(data_b_total * a_range,dim=1,keepdim=True)
338 |
339 | data_ab_inf = torch.cat((data_a_inf,data_b_inf),dim=1)/opt.ab_norm
340 |
341 | return data_ab_inf
342 |
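# In other words (comment added for clarity): decode_mean takes the expectation of the
# quantized distribution rather than its mode. After reshaping NxQxHxW to NxAxAxHxW,
# the b axis (dim=2) is summed out to get the a-marginal and the a axis (dim=1) to get
# the b-marginal, then
#   a_hat = sum_i P(a = a_i) * a_i,   b_hat = sum_j P(b = b_j) * b_j,
# with bin centers running from -ab_max to ab_max in steps of ab_quant (torch.range is
# endpoint-inclusive, unlike torch.arange). The result is divided by ab_norm to match
# the network's normalized ab range.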
343 | def calculate_psnr_np(img1, img2):
344 | import numpy as np
345 | SE_map = (1.*img1-img2)**2
346 | cur_MSE = np.mean(SE_map)
347 | return 20*np.log10(255./np.sqrt(cur_MSE))
348 |
349 | def calculate_psnr_torch(img1, img2):
350 | SE_map = (1.*img1-img2)**2
351 | cur_MSE = torch.mean(SE_map)
352 | return 20*torch.log10(1./torch.sqrt(cur_MSE))
353 |
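# Note on the two PSNR helpers (added for clarity): the numpy version assumes 8-bit images
# with a peak value of 255, while the torch version assumes inputs normalized to [0, 1].
# Worked example for the numpy version: an MSE of 100 gives
#   20 * log10(255 / sqrt(100)) = 20 * log10(25.5) ~= 28.1 dB.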
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 | import time
5 | from . import util
6 | from . import html
7 | from scipy.misc import imresize
8 |
9 |
10 | # save image to the disk
11 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
12 | image_dir = webpage.get_image_dir()
13 | short_path = ntpath.basename(image_path[0])
14 | name = os.path.splitext(short_path)[0]
15 |
16 | webpage.add_header(name)
17 | ims, txts, links = [], [], []
18 |
19 | for label, im_data in visuals.items():
20 | im = util.tensor2im(im_data)
21 | image_name = '%s_%s.png' % (name, label)
22 | save_path = os.path.join(image_dir, image_name)
23 | h, w, _ = im.shape
24 | if aspect_ratio > 1.0:
25 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
26 | if aspect_ratio < 1.0:
27 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
28 | util.save_image(im, save_path)
29 |
30 | ims.append(image_name)
31 | txts.append(label)
32 | links.append(image_name)
33 | webpage.add_images(ims, txts, links, width=width)
34 |
35 |
36 | class Visualizer():
37 | def __init__(self, opt):
38 | self.display_id = opt.display_id
39 | self.use_html = opt.isTrain and not opt.no_html
40 | self.win_size = opt.display_winsize
41 | self.name = opt.name
42 | self.opt = opt
43 | self.saved = False
44 | if self.display_id > 0:
45 | import visdom
46 | self.ncols = opt.display_ncols
47 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port)
48 |
49 | if self.use_html:
50 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
51 | self.img_dir = os.path.join(self.web_dir, 'images')
52 | print('create web directory %s...' % self.web_dir)
53 | util.mkdirs([self.web_dir, self.img_dir])
54 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
55 | with open(self.log_name, "a") as log_file:
56 | now = time.strftime("%c")
57 | log_file.write('================ Training Loss (%s) ================\n' % now)
58 |
59 | def reset(self):
60 | self.saved = False
61 |
62 | # |visuals|: dictionary of images to display or save
63 | def display_current_results(self, visuals, epoch, save_result):
64 | if self.display_id > 0: # show images in the browser
65 | ncols = self.ncols
66 | if ncols > 0:
67 | ncols = min(ncols, len(visuals))
68 | h, w = next(iter(visuals.values())).shape[:2]
69 | table_css = """<style>
70 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
71 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
72 | </style>""" % (w, h)  # CSS for the visdom label table
73 | title = self.name
74 | label_html = ''
75 | label_html_row = ''
76 | images = []
77 | idx = 0
78 | for label, image in visuals.items():
79 | image_numpy = util.tensor2im(image)
80 | label_html_row += '<td>%s</td>' % label
81 | images.append(image_numpy.transpose([2, 0, 1]))
82 | idx += 1
83 | if idx % ncols == 0:
84 | label_html += '<tr>%s</tr>' % label_html_row
85 | label_html_row = ''
86 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
87 | while idx % ncols != 0:
88 | images.append(white_image)
89 | label_html_row += '<td></td>'
90 | idx += 1
91 | if label_html_row != '':
92 | label_html += '<tr>%s</tr>' % label_html_row
93 | # pane col = image row
94 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
95 | padding=2, opts=dict(title=title + ' images'))
96 | label_html = '<table>%s</table>' % label_html
97 | self.vis.text(table_css + label_html, win=self.display_id + 2,
98 | opts=dict(title=title + ' labels'))
99 | else:
100 | idx = 1
101 | for label, image in visuals.items():
102 | image_numpy = util.tensor2im(image)
103 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
104 | win=self.display_id + idx)
105 | idx += 1
106 |
107 | if self.use_html and (save_result or not self.saved): # save images to a html file
108 | self.saved = True
109 | for label, image in visuals.items():
110 | image_numpy = util.tensor2im(image)
111 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
112 | util.save_image(image_numpy, img_path)
113 | # update website
114 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
115 | for n in range(epoch, 0, -1):
116 | webpage.add_header('epoch [%d]' % n)
117 | ims, txts, links = [], [], []
118 |
119 | for label, image in visuals.items():
120 | image_numpy = util.tensor2im(image)
121 | img_path = 'epoch%.3d_%s.png' % (n, label)
122 | ims.append(img_path)
123 | txts.append(label)
124 | links.append(img_path)
125 | webpage.add_images(ims, txts, links, width=self.win_size)
126 | webpage.save()
127 |
128 | # losses: dictionary of error labels and values
129 | def plot_current_losses(self, epoch, counter_ratio, opt, losses):
130 | if not hasattr(self, 'plot_data'):
131 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
132 | self.plot_data['X'].append(epoch + counter_ratio)
133 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
134 | self.vis.line(
135 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
136 | Y=np.array(self.plot_data['Y']),
137 | opts={
138 | 'title': self.name + ' loss over time',
139 | 'legend': self.plot_data['legend'],
140 | 'xlabel': 'epoch',
141 | 'ylabel': 'loss'},
142 | win=self.display_id)
143 |
144 | # losses: same format as |losses| of plot_current_losses
145 | def print_current_losses(self, epoch, i, losses, t, t_data):
146 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
147 | for k, v in losses.items():
148 | message += '%s: %.3f, ' % (k, v)
149 |
150 | print(message)
151 | with open(self.log_name, "a") as log_file:
152 | log_file.write('%s\n' % message)
153 |
--------------------------------------------------------------------------------