├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   └── siggraph_pretrained
│       └── sweep_reference.png
├── data
│   ├── __init__.py
│   ├── aligned_dataset.py
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── color_dataset.py
│   ├── image_folder.py
│   └── single_dataset.py
├── environment.yml
├── imgs
│   └── demo.gif
├── make_ilsvrc_dataset.py
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── networks.py
│   └── pix2pix_model.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   └── train_options.py
├── pretrained_models
│   └── download_siggraph_model.sh
├── requirements.txt
├── resources
│   ├── ilsvrclin12_val_inds.npy
│   └── psnrs_siggraph.npy
├── scripts
│   ├── check_all.sh
│   ├── conda_deps.sh
│   ├── install_deps.sh
│   └── train_siggraph.sh
├── test.py
├── test_sweep.py
├── train.py
└── util
    ├── __init__.py
    ├── get_data.py
    ├── html.py
    ├── image_pool.py
    ├── util.py
    └── visualizer.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | debug* 3 | datasets/ 4 | checkpoints/ 5 | results/ 6 | build/ 7 | dist/ 8 | log.txt 9 | *.png 10 | torch.egg-info/ 11 | */**/__pycache__ 12 | torch/version.py 13 | torch/csrc/generic/TensorMethods.cpp 14 | torch/lib/*.so* 15 | torch/lib/*.dylib* 16 | torch/lib/*.h 17 | torch/lib/build 18 | torch/lib/tmp_install 19 | torch/lib/include 20 | torch/lib/torch_shm_manager 21 | torch/csrc/cudnn/cuDNN.cpp 22 | torch/csrc/nn/THNN.cwrap 23 | torch/csrc/nn/THNN.cpp 24 | torch/csrc/nn/THCUNN.cwrap 25 | torch/csrc/nn/THCUNN.cpp 26 | torch/csrc/nn/THNN_generic.cwrap 27 | torch/csrc/nn/THNN_generic.cpp 28 | torch/csrc/nn/THNN_generic.h 29 | docs/src/**/* 30 | test/data/legacy_modules.t7 31 | test/data/gpu_tensors.pt 32 | test/htmlcov 33 | test/.coverage 34 | */*.pyc 35 | */**/*.pyc 36 | */**/**/*.pyc 37 | */**/**/**/*.pyc 38 | */**/**/**/**/*.pyc 39 | */*.so* 40 | */**/*.so* 41 | */**/*.dylib* 42 | test/data/legacy_serialized.pt 43 | *~ 44 | .idea 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Richard Zhang, Jun-Yan Zhu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Interactive Deep Colorization in PyTorch 2 | [Project Page](https://richzhang.github.io/ideepcolor/) | [Paper](https://arxiv.org/abs/1705.02999) | [Video](https://youtu.be/eL5ilZgM89Q) | [Talk](https://www.youtube.com/watch?v=rp5LUSbdsys) | [UI code](https://github.com/junyanz/interactive-deep-colorization/) 3 | 4 | 5 | 6 | Real-Time User-Guided Image Colorization with Learned Deep Priors. 7 | [Richard Zhang](https://richzhang.github.io/)\*, [Jun-Yan Zhu](http://people.csail.mit.edu/junyanz/)\*, [Phillip Isola](http://people.eecs.berkeley.edu/~isola/), [Xinyang Geng](http://young-geng.xyz/), Angela S. Lin, Tianhe Yu, and [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/). 8 | In ACM Transactions on Graphics (SIGGRAPH 2017). 9 | 10 | This is our PyTorch reimplementation of interactive image colorization, written by [Richard Zhang](https://github.com/richzhang) and [Jun-Yan Zhu](https://github.com/junyanz). 11 | 12 | This repository contains the training code. The original, official GitHub repo (with an interactive GUI, and originally a Caffe backend) is [here](https://richzhang.github.io/ideepcolor/). The official repo has been updated to support PyTorch models on the backend, which can be trained in this repository. 13 | 14 | ## Prerequisites 15 | - Linux or macOS 16 | - Python 2 or 3 17 | - CPU or NVIDIA GPU + CUDA CuDNN 18 | 19 | ## Getting Started 20 | ### Installation 21 | - Install PyTorch 0.4+ and torchvision from http://pytorch.org, along with the other dependencies (e.g., [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate)). You can install all of the dependencies with 22 | ```bash 23 | pip install -r requirements.txt 24 | ``` 25 | - Clone this repo: 26 | ```bash 27 | git clone https://github.com/richzhang/colorization-pytorch 28 | cd colorization-pytorch 29 | ``` 30 | 31 | ### Dataset preparation 32 | - Download the ILSVRC 2012 dataset and run the following script to prepare the data: 33 | `python make_ilsvrc_dataset.py --in_path /PATH/TO/ILSVRC12`. This will make symlinks into the training set and divide the ILSVRC validation set into validation and test splits for colorization. 34 | 35 | ### Training interactive colorization 36 | - Train a model: ```bash ./scripts/train_siggraph.sh```. This is a two-stage training process. First, the network is trained for automatic colorization using a classification loss. Results are in `./checkpoints/siggraph_class`. Then, the network is fine-tuned for interactive colorization using a regression loss. Final results are in `./checkpoints/siggraph_reg2`. 37 | 38 | - To view training results and loss plots, run `python -m visdom.server` and open the URL http://localhost:8097. The following values are monitored: 39 | * `G_CE` is a cross-entropy loss between the predicted color distribution and the ground truth color. 40 | * `G_entr` is the entropy of the predicted distribution. 41 | * `G_entr_hint` is the entropy of the predicted distribution at points where a color hint is given. 42 | * `G_L1_max` is the L1 distance between the ground truth color and the argmax of the predicted color distribution. 43 | * `G_L1_mean` is the L1 distance between the ground truth color and the mean of the predicted color distribution. 44 | * `G_L1_reg` is the L1 distance between the ground truth color and the predicted color.
45 | * `G_fake_real` is the L1 distance between the predicted color and the ground truth color (in locations where a hint is given). 46 | * `G_fake_hint` is the L1 distance between the predicted color and the input hint color (in locations where a hint is given). It is a measure of how much the network "trusts" the input hint. 47 | * `G_real_hint` is the L1 distance between the ground truth color and the input hint color (in locations where a hint is given). 48 | 49 | 50 | ### Testing interactive colorization 51 | - Get a model. Either: 52 | * (1) download the pretrained models by running ```bash pretrained_models/download_siggraph_model.sh```, which will give you a few models. 53 | * Original Caffe weights [Recommended]: `./checkpoints/siggraph_caffemodel/latest_net_G.pth` contains the original caffemodel weights, converted to PyTorch. Be sure to set `--mask_cent 0` when running it. 54 | * Retrained model: `./checkpoints/siggraph_retrained/latest_net_G.pth`. This model achieves better PSNR but performs qualitatively differently. Note that this repository is an approximate reimplementation of the SIGGRAPH paper. 55 | * (2) train your own model (as described in the section above), which will leave a model in `./checkpoints/siggraph_reg2/latest_net_G.pth`. 56 | 57 | - Test the model on validation data: 58 | * ```python test.py --name siggraph_caffemodel --mask_cent 0``` for the original caffemodel weights 59 | * ```python test.py --name siggraph_retrained ``` for the retrained weights 60 | * ```python test.py --name siggraph_reg2 ``` if you trained your own model 61 | The test results will be saved to an HTML file in `./results/[[NAME]]/latest_val/index.html`. For each image in the validation set, it will test (1) automatic colorization, (2) interactive colorization with a few random hints, and (3) interactive colorization with lots of random hints. 62 | 63 | - Test the model by plotting PSNR vs. the number of hints: ```python test_sweep.py --name [[NAME]] ```. This plot was used in Figure 6 of the [paper](https://arxiv.org/abs/1705.02999). This test randomly reveals 6x6 color hint patches to the network and measures how accurate the colorization is with respect to the ground truth. 64 | 65 | 66 | 67 | - Test the model interactively with the original official [repository](https://github.com/junyanz/interactive-deep-colorization). Follow the installation instructions in that repo and run `python ideepcolor.py --backend pytorch --color_model [[PTH/TO/MODEL]] --dist_model [[PTH/TO/MODEL]]`. A minimal sketch for loading a trained generator directly in Python is included at the end of this README. 68 | 69 | 70 | ### Citation 71 | If you use this code for your research, please cite our paper: 72 | ``` 73 | @article{zhang2017real, 74 | title={Real-Time User-Guided Image Colorization with Learned Deep Priors}, 75 | author={Zhang, Richard and Zhu, Jun-Yan and Isola, Phillip and Geng, Xinyang and Lin, Angela S and Yu, Tianhe and Efros, Alexei A}, 76 | journal={ACM Transactions on Graphics (TOG)}, 77 | volume={36}, 78 | number={4}, 79 | year={2017}, 80 | publisher={ACM} 81 | } 82 | ``` 83 | 84 | ## Acknowledgments 85 | This code borrows heavily from the [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repository.
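## Loading the generator directly (sketch)

The supported entry points are `train.py`, `test.py`, and `test_sweep.py`, but the generator can also be used programmatically. The snippet below is an illustrative sketch, not an official API: it assumes the default Lab-space configuration (a 1-channel grayscale L input, 2 ab output channels, and a 1-channel hint mask, so the network takes 4 input channels) and one of the checkpoint paths listed above.

```python
import torch
from models import networks

# Build the SIGGRAPH generator. 4 = input_nc + output_nc + 1 (L + ab hints + mask);
# the ngf argument (64) is ignored by the 'siggraph' architecture.
netG = networks.define_G(4, 2, 64, 'siggraph', norm='batch', use_tanh=True)
netG.load_state_dict(torch.load('./checkpoints/siggraph_retrained/latest_net_G.pth',
                                map_location='cpu'))
netG.eval()

# Dummy inputs: a normalized L channel, empty ab hints, and an all-zero hint mask.
L = torch.zeros(1, 1, 256, 256)
hints = torch.zeros(1, 2, 256, 256)
mask = torch.zeros(1, 1, 256, 256)
with torch.no_grad():
    out_class, out_reg = netG(L, hints, mask)  # (class logits, regressed ab channels)
```

Note that the real pipeline also offsets the hint mask via `--mask_cent` (the caffemodel weights expect `--mask_cent 0`); this sketch skips that preprocessing, so treat it as a starting point only.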
86 | -------------------------------------------------------------------------------- /checkpoints/siggraph_pretrained/sweep_reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/checkpoints/siggraph_pretrained/sweep_reference.png -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.utils.data 3 | from data.base_data_loader import BaseDataLoader 4 | from data.base_dataset import BaseDataset 5 | 6 | 7 | def find_dataset_using_name(dataset_name): 8 | # Given the option --dataset_mode [datasetname], 9 | # the file "data/datasetname_dataset.py" 10 | # will be imported. 11 | dataset_filename = "data." + dataset_name + "_dataset" 12 | datasetlib = importlib.import_module(dataset_filename) 13 | 14 | # In the file, the class called DatasetNameDataset() will 15 | # be instantiated. It has to be a subclass of BaseDataset, 16 | # and it is case-insensitive. 17 | dataset = None 18 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 19 | for name, cls in datasetlib.__dict__.items(): 20 | if name.lower() == target_dataset_name.lower() \ 21 | and issubclass(cls, BaseDataset): 22 | dataset = cls 23 | 24 | if dataset is None: 25 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 26 | exit(0) 27 | 28 | return dataset 29 | 30 | 31 | def get_option_setter(dataset_name): 32 | dataset_class = find_dataset_using_name(dataset_name) 33 | return dataset_class.modify_commandline_options 34 | 35 | 36 | def create_dataset(opt): 37 | dataset = find_dataset_using_name(opt.dataset_mode) 38 | instance = dataset() 39 | instance.initialize(opt) 40 | print("dataset [%s] was created" % (instance.name())) 41 | return instance 42 | 43 | 44 | def CreateDataLoader(opt): 45 | data_loader = CustomDatasetDataLoader() 46 | data_loader.initialize(opt) 47 | return data_loader 48 | 49 | 50 | # Wrapper class of Dataset class that performs 51 | # multi-threaded data loading 52 | class CustomDatasetDataLoader(BaseDataLoader): 53 | def name(self): 54 | return 'CustomDatasetDataLoader' 55 | 56 | def initialize(self, opt): 57 | BaseDataLoader.initialize(self, opt) 58 | self.dataset = create_dataset(opt) 59 | self.dataloader = torch.utils.data.DataLoader( 60 | self.dataset, 61 | batch_size=opt.batch_size, 62 | shuffle=not opt.serial_batches, 63 | num_workers=int(opt.num_threads)) 64 | 65 | def load_data(self): 66 | return self 67 | 68 | def __len__(self): 69 | return min(len(self.dataset), self.opt.max_dataset_size) 70 | 71 | def __iter__(self): 72 | for i, data in enumerate(self.dataloader): 73 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 74 | break 75 | yield data 76 | -------------------------------------------------------------------------------- /data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import torchvision.transforms as transforms 4 | import torch 5 | from data.base_dataset import BaseDataset 6 | from data.image_folder import make_dataset 7 | from PIL import Image 8 | 9 | 10 | class AlignedDataset(BaseDataset): 11 | @staticmethod 12 | def modify_commandline_options(parser, is_train): 13 | return parser 
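# (AlignedDataset expects side-by-side {A|B} image pairs stored in one file:
# each half is resized to opt.loadSize, then a shared random opt.fineSize crop
# and an optional horizontal flip are applied, so A and B stay aligned)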
14 | 15 | def initialize(self, opt): 16 | self.opt = opt 17 | self.root = opt.dataroot 18 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) 19 | self.AB_paths = sorted(make_dataset(self.dir_AB)) 20 | assert(opt.resize_or_crop == 'resize_and_crop') 21 | 22 | def __getitem__(self, index): 23 | AB_path = self.AB_paths[index] 24 | AB = Image.open(AB_path).convert('RGB') 25 | w, h = AB.size 26 | w2 = int(w / 2) 27 | A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 28 | B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 29 | A = transforms.ToTensor()(A) 30 | B = transforms.ToTensor()(B) 31 | w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) 32 | h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) 33 | 34 | A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] 35 | B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] 36 | 37 | A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A) 38 | B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B) 39 | 40 | if self.opt.which_direction == 'BtoA': 41 | input_nc = self.opt.output_nc 42 | output_nc = self.opt.input_nc 43 | else: 44 | input_nc = self.opt.input_nc 45 | output_nc = self.opt.output_nc 46 | 47 | if (not self.opt.no_flip) and random.random() < 0.5: 48 | idx = [i for i in range(A.size(2) - 1, -1, -1)] 49 | idx = torch.LongTensor(idx) 50 | A = A.index_select(2, idx) 51 | B = B.index_select(2, idx) 52 | 53 | if input_nc == 1: # RGB to gray 54 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 55 | A = tmp.unsqueeze(0) 56 | 57 | if output_nc == 1: # RGB to gray 58 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] 
* 0.114 59 | B = tmp.unsqueeze(0) 60 | 61 | return {'A': A, 'B': B, 62 | 'A_paths': AB_path, 'B_paths': AB_path} 63 | 64 | def __len__(self): 65 | return len(self.AB_paths) 66 | 67 | def name(self): 68 | return 'AlignedDataset' 69 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | class BaseDataLoader(): 2 | def __init__(self): 3 | pass 4 | 5 | def initialize(self, opt): 6 | self.opt = opt 7 | pass 8 | 9 | def load_data(self): 10 | return None 11 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | 6 | class BaseDataset(data.Dataset): 7 | def __init__(self): 8 | super(BaseDataset, self).__init__() 9 | 10 | def name(self): 11 | return 'BaseDataset' 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | return parser 16 | 17 | def initialize(self, opt): 18 | pass 19 | 20 | def __len__(self): 21 | return 0 22 | 23 | 24 | def get_transform(opt): 25 | transform_list = [] 26 | if opt.resize_or_crop == 'resize_and_crop': 27 | osize = [opt.loadSize, opt.loadSize] 28 | transform_list.append(transforms.Resize(osize, Image.BICUBIC)) 29 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 30 | elif opt.resize_or_crop == 'crop': 31 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 32 | elif opt.resize_or_crop == 'scale_width': 33 | transform_list.append(transforms.Lambda( 34 | lambda img: __scale_width(img, opt.fineSize))) 35 | elif opt.resize_or_crop == 'scale_width_and_crop': 36 | transform_list.append(transforms.Lambda( 37 | lambda img: __scale_width(img, opt.loadSize))) 38 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 39 | elif opt.resize_or_crop == 'none': 40 | transform_list.append(transforms.Lambda( 41 | lambda img: __adjust(img))) 42 | else: 43 | raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop) 44 | 45 | if opt.isTrain and not opt.no_flip: 46 | transform_list.append(transforms.RandomHorizontalFlip()) 47 | 48 | transform_list += [transforms.ToTensor(), 49 | transforms.Normalize((0.5, 0.5, 0.5), 50 | (0.5, 0.5, 0.5))] 51 | return transforms.Compose(transform_list) 52 | 53 | # just modify the width and height to be a multiple of 4 54 | 55 | 56 | def __adjust(img): 57 | ow, oh = img.size 58 | 59 | # the size needs to be a multiple of this number, 60 | # because going through the generator network may change the image size 61 | # and eventually cause a size mismatch error 62 | mult = 4 63 | if ow % mult == 0 and oh % mult == 0: 64 | return img 65 | w = (ow - 1) // mult 66 | w = (w + 1) * mult 67 | h = (oh - 1) // mult 68 | h = (h + 1) * mult 69 | 70 | if ow != w or oh != h: 71 | __print_size_warning(ow, oh, w, h) 72 | 73 | return img.resize((w, h), Image.BICUBIC) 74 | 75 | 76 | def __scale_width(img, target_width): 77 | ow, oh = img.size 78 | 79 | # the size needs to be a multiple of this number, 80 | # because going through the generator network may change the image size 81 | # and eventually cause a size mismatch error 82 | mult = 4 83 | assert target_width % mult == 0, "the target width needs to be a multiple of %d."
% mult 84 | if (ow == target_width and oh % mult == 0): 85 | return img 86 | w = target_width 87 | target_height = int(target_width * oh / ow) 88 | m = (target_height - 1) // mult 89 | h = (m + 1) * mult 90 | 91 | if target_height != h: 92 | __print_size_warning(target_width, target_height, w, h) 93 | 94 | return img.resize((w, h), Image.BICUBIC) 95 | 96 | 97 | def __print_size_warning(ow, oh, w, h): 98 | if not hasattr(__print_size_warning, 'has_printed'): 99 | print("The image size needs to be a multiple of 4. " 100 | "The loaded image size was (%d, %d), so it was adjusted to " 101 | "(%d, %d). This adjustment will be done to all images " 102 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 103 | __print_size_warning.has_printed = True 104 | -------------------------------------------------------------------------------- /data/color_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | 6 | 7 | class ColorDataset(BaseDataset): 8 | @staticmethod 9 | def modify_commandline_options(parser, is_train): 10 | return parser 11 | 12 | def initialize(self, opt): 13 | self.opt = opt 14 | self.root = opt.dataroot 15 | self.dir_A = os.path.join(opt.dataroot) 16 | 17 | self.A_paths = make_dataset(self.dir_A) 18 | 19 | self.A_paths = sorted(self.A_paths) 20 | 21 | self.transform = get_transform(opt) 22 | 23 | def __getitem__(self, index): 24 | A_path = self.A_paths[index] 25 | A_img = Image.open(A_path).convert('RGB') 26 | A = self.transform(A_img) 27 | if self.opt.which_direction == 'BtoA': 28 | input_nc = self.opt.output_nc 29 | else: 30 | input_nc = self.opt.input_nc 31 | 32 | # convert to Lab 33 | # rgb2lab(A_img) 34 | 35 | if input_nc == 1: # RGB to gray 36 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] 
* 0.114 37 | A = tmp.unsqueeze(0) 38 | 39 | return {'A': A, 'A_paths': A_path} 40 | 41 | def __len__(self): 42 | return len(self.A_paths) 43 | 44 | def name(self): 45 | return 'ColorImageDataset' 46 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /data/single_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | 6 | 7 | class SingleDataset(BaseDataset): 8 | @staticmethod 9 | def modify_commandline_options(parser, is_train): 10 | return parser 11 | 12 | def initialize(self, opt): 13 | self.opt = opt 14 | self.root = opt.dataroot 15 | self.dir_A = os.path.join(opt.dataroot) 16 | 17 | self.A_paths = make_dataset(self.dir_A) 18 | 19 | self.A_paths = sorted(self.A_paths) 20 | 21 | self.transform = get_transform(opt) 22 | 23 | def __getitem__(self, index): 24 | A_path = self.A_paths[index] 25 | A_img = Image.open(A_path).convert('RGB') 26 | A = self.transform(A_img) 27 | if self.opt.which_direction == 'BtoA': 28 | input_nc = self.opt.output_nc 29 | else: 30 | input_nc = self.opt.input_nc 31 | 32 | if input_nc == 1: # RGB to gray 33 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] 
* 0.114 34 | A = tmp.unsqueeze(0) 35 | 36 | return {'A': A, 'A_paths': A_path} 37 | 38 | def __len__(self): 39 | return len(self.A_paths) 40 | 41 | def name(self): 42 | return 'SingleImageDataset' 43 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: colorization-pytorch 2 | channels: 3 | - peterjc123 4 | - defaults 5 | dependencies: 6 | - python=3.5.5 7 | - pytorch=0.4.1 8 | - scipy 9 | - pip: 10 | - dominate==2.3.1 11 | - git+https://github.com/pytorch/vision.git 12 | - Pillow==5.0.0 13 | - numpy==1.14.1 14 | - visdom==0.1.7 15 | -------------------------------------------------------------------------------- /imgs/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/imgs/demo.gif -------------------------------------------------------------------------------- /make_ilsvrc_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from util import util 4 | import numpy as np 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 8 | parser.add_argument('--in_path', type=str, default='/data/big/dataset/ILSVRC2012') 9 | parser.add_argument('--out_path', type=str, default='./dataset/ilsvrc2012/') 10 | 11 | opt = parser.parse_args() 12 | orig_path = opt.in_path 13 | print('Copying ILSVRC from...[%s]' % orig_path) 14 | 15 | # Copy over part of training set (for initializer) 16 | trn_small_path = os.path.join(opt.out_path, 'train_small') 17 | util.mkdirs(opt.out_path) 18 | util.mkdirs(trn_small_path) 19 | train_subdirs = os.listdir(os.path.join(opt.in_path, 'train')) 20 | for train_subdir in train_subdirs[:10]: 21 | os.symlink(os.path.join(opt.in_path, 'train', train_subdir), os.path.join(trn_small_path, train_subdir)) 22 | print('Making small training set in...[%s]' % trn_small_path) 23 | 24 | # Copy over whole training set 25 | trn_path = os.path.join(opt.out_path, 'train') 26 | util.mkdirs(opt.out_path) 27 | os.symlink(os.path.join(opt.in_path, 'train'), trn_path) 28 | print('Making training set in...[%s]' % trn_path) 29 | 30 | # Copy over subset of ILSVRC12 val set for colorization val set 31 | val_path = os.path.join(opt.out_path, 'val/imgs') 32 | util.mkdirs(val_path) 33 | print('Making validation set in...[%s]' % val_path) 34 | for val_ind in range(1000): 35 | os.system('ln -s %s/val/ILSVRC2012_val_%08d.JPEG %s/ILSVRC2012_val_%08d.JPEG' % (orig_path, val_ind + 1, val_path, val_ind + 1)) 36 | # os.system('cp %s/val/ILSVRC2012_val_%08d.JPEG %s/ILSVRC2012_val_%08d.JPEG'%(orig_path,val_ind+1,val_path,val_ind+1)) 37 | 38 | # Copy over subset of ILSVRC12 val set for colorization test set 39 | test_path = os.path.join(opt.out_path, 'test/imgs') 40 | util.mkdirs(test_path) 41 | val_inds = np.load('./resources/ilsvrclin12_val_inds.npy') 42 | print('Making test set in...[%s]' % test_path) 43 | for val_ind in val_inds: 44 | os.system('ln -s %s/val/ILSVRC2012_val_%08d.JPEG %s/ILSVRC2012_val_%08d.JPEG' % (orig_path, val_ind + 1, test_path, val_ind + 1)) 45 | # os.system('cp %s/val/ILSVRC2012_val_%08d.JPEG %s/ILSVRC2012_val_%08d.JPEG'%(orig_path,val_ind+1,test_path,val_ind+1)) 46 | -------------------------------------------------------------------------------- /models/__init__.py: 
-------------------------------------------------------------------------------- 1 | import importlib 2 | from models.base_model import BaseModel 3 | 4 | 5 | def find_model_using_name(model_name): 6 | # Given the option --model [modelname], 7 | # the file "models/modelname_model.py" 8 | # will be imported. 9 | model_filename = "models." + model_name + "_model" 10 | modellib = importlib.import_module(model_filename) 11 | 12 | # In the file, the class called ModelNameModel() will 13 | # be instantiated. It has to be a subclass of BaseModel, 14 | # and it is case-insensitive. 15 | model = None 16 | target_model_name = model_name.replace('_', '') + 'model' 17 | for name, cls in modellib.__dict__.items(): 18 | if name.lower() == target_model_name.lower() \ 19 | and issubclass(cls, BaseModel): 20 | model = cls 21 | 22 | if model is None: 23 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 24 | exit(0) 25 | 26 | return model 27 | 28 | 29 | def get_option_setter(model_name): 30 | model_class = find_model_using_name(model_name) 31 | return model_class.modify_commandline_options 32 | 33 | 34 | def create_model(opt): 35 | model = find_model_using_name(opt.model) 36 | instance = model() 37 | instance.initialize(opt) 38 | print("model [%s] was created" % (instance.name())) 39 | return instance 40 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from . import networks 5 | 6 | 7 | class BaseModel(): 8 | # modify parser to add command line options, 9 | # and also change the default values if needed 10 | @staticmethod 11 | def modify_commandline_options(parser, is_train): 12 | return parser 13 | 14 | def name(self): 15 | return 'BaseModel' 16 | 17 | def initialize(self, opt): 18 | self.opt = opt 19 | self.gpu_ids = opt.gpu_ids 20 | self.isTrain = opt.isTrain 21 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 22 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 23 | if opt.resize_or_crop != 'scale_width': 24 | torch.backends.cudnn.benchmark = True 25 | self.loss_names = [] 26 | self.model_names = [] 27 | self.visual_names = [] 28 | self.image_paths = [] 29 | 30 | def set_input(self, input): 31 | self.input = input 32 | 33 | def forward(self): 34 | pass 35 | 36 | # load and print networks; create schedulers 37 | def setup(self, opt, parser=None): 38 | if self.isTrain: 39 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 40 | 41 | if not self.isTrain or opt.load_model: 42 | self.load_networks(opt.which_epoch) 43 | self.print_networks(opt.verbose) 44 | 45 | # make models eval mode during test time 46 | def eval(self): 47 | for name in self.model_names: 48 | if isinstance(name, str): 49 | net = getattr(self, 'net' + name) 50 | net.eval() 51 | 52 | # used in test time, wrapping `forward` in no_grad() so we don't save 53 | # intermediate steps for backprop 54 | def test(self, compute_losses=False): 55 | with torch.no_grad(): 56 | self.forward() 57 | if(compute_losses): 58 | self.compute_losses_G() 59 | 60 | # get image paths 61 | def get_image_paths(self): 62 | return self.image_paths 63 | 64 | def optimize_parameters(self): 65 | pass 66 | 67 | # update learning rate (called once every epoch) 68 | 
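    # (the schedulers come from networks.get_scheduler; under the 'lambda'
    # policy the learning rate is held for roughly opt.niter epochs, then
    # decays linearly to zero over the following opt.niter_decay epochs)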
def update_learning_rate(self): 69 | for scheduler in self.schedulers: 70 | scheduler.step() 71 | lr = self.optimizers[0].param_groups[0]['lr'] 72 | print('learning rate = %.7f' % lr) 73 | 74 | # return visualization images. train.py will display these images, and save the images to an HTML file 75 | def get_current_visuals(self): 76 | visual_ret = OrderedDict() 77 | for name in self.visual_names: 78 | if isinstance(name, str): 79 | visual_ret[name] = getattr(self, name) 80 | return visual_ret 81 | 82 | # return training losses/errors. train.py will print out these errors as debugging information 83 | def get_current_losses(self): 84 | errors_ret = OrderedDict() 85 | for name in self.loss_names: 86 | if isinstance(name, str): 87 | # float(...) works for both scalar tensor and float number 88 | errors_ret[name] = float(getattr(self, 'loss_' + name)) 89 | return errors_ret 90 | 91 | # save models to the disk 92 | def save_networks(self, which_epoch): 93 | for name in self.model_names: 94 | if isinstance(name, str): 95 | save_filename = '%s_net_%s.pth' % (which_epoch, name) 96 | save_path = os.path.join(self.save_dir, save_filename) 97 | net = getattr(self, 'net' + name) 98 | 99 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 100 | torch.save(net.module.cpu().state_dict(), save_path) 101 | net.cuda(self.gpu_ids[0]) 102 | else: 103 | torch.save(net.cpu().state_dict(), save_path) 104 | 105 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 106 | key = keys[i] 107 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 108 | if module.__class__.__name__.startswith('InstanceNorm') and \ 109 | (key == 'running_mean' or key == 'running_var'): 110 | if getattr(module, key) is None: 111 | state_dict.pop('.'.join(keys)) 112 | if module.__class__.__name__.startswith('InstanceNorm') and \ 113 | (key == 'num_batches_tracked'): 114 | state_dict.pop('.'.join(keys)) 115 | else: 116 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 117 | 118 | # load models from the disk 119 | def load_networks(self, which_epoch): 120 | for name in self.model_names: 121 | if isinstance(name, str): 122 | load_filename = '%s_net_%s.pth' % (which_epoch, name) 123 | load_path = os.path.join(self.save_dir, load_filename) 124 | net = getattr(self, 'net' + name) 125 | if isinstance(net, torch.nn.DataParallel): 126 | net = net.module 127 | print('loading the model from %s' % load_path) 128 | # if you are using PyTorch newer than 0.4 (e.g., built from 129 | # GitHub source), you can remove str() on self.device 130 | state_dict = torch.load(load_path, map_location=str(self.device)) 131 | if hasattr(state_dict, '_metadata'): 132 | del state_dict._metadata 133 | 134 | # patch InstanceNorm checkpoints prior to 0.4 135 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 136 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 137 | net.load_state_dict(state_dict) 138 | 139 | # print network information 140 | def print_networks(self, verbose): 141 | print('---------- Networks initialized -------------') 142 | for name in self.model_names: 143 | if isinstance(name, str): 144 | net = getattr(self, 'net' + name) 145 | num_params = 0 146 | for param in net.parameters(): 147 | num_params += param.numel() 148 | if verbose: 149 | print(net) 150 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 151 | print('-----------------------------------------------') 152 | 153 | # set requires_grad=False to avoid computation 154 | def set_requires_grad(self, nets, requires_grad=False): 155 | if not isinstance(nets, list): 156 | nets = [nets] 157 | for net in nets: 158 | if net is not None: 159 | for param in net.parameters(): 160 | param.requires_grad = requires_grad 161 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | 7 | ############################################################################### 8 | # Helper Functions 9 | ############################################################################### 10 | 11 | 12 | def get_norm_layer(norm_type='instance'): 13 | if norm_type == 'batch': 14 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 15 | elif norm_type == 'instance': 16 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) 17 | elif norm_type == 'none': 18 | norm_layer = None 19 | else: 20 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 21 | return norm_layer 22 | 23 | 24 | def get_scheduler(optimizer, opt): 25 | if opt.lr_policy == 'lambda': 26 | def lambda_rule(epoch): 27 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 28 | return lr_l 29 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 30 | elif opt.lr_policy == 'step': 31 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 32 | elif opt.lr_policy == 'plateau': 33 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 34 | else: 35 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) 36 | return scheduler 37 | 38 | 39 | def init_weights(net, init_type='xavier', gain=0.02): 40 | def init_func(m): 41 | classname = m.__class__.__name__ 42 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 43 | if init_type == 'normal': 44 | init.normal_(m.weight.data, 0.0, gain) 45 | elif init_type == 'xavier': 46 | init.xavier_normal_(m.weight.data, gain=gain) 47 | elif init_type == 'kaiming': 48 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 49 | elif init_type == 'orthogonal': 50 | init.orthogonal_(m.weight.data, gain=gain) 51 | else: 52 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 53 | if hasattr(m, 'bias') and m.bias is not None: 54 | init.constant_(m.bias.data, 0.0) 55 | elif classname.find('BatchNorm2d') != -1: 56 | init.normal_(m.weight.data, 1.0, gain) 57 | init.constant_(m.bias.data, 0.0) 58 | 59 | print('initialize network with %s' % init_type) 60 | net.apply(init_func) 61 | 62 | 63 | def init_net(net, init_type='xavier', gpu_ids=[]): 64 | if len(gpu_ids) > 0: 65 | assert(torch.cuda.is_available()) 66 | net.to(gpu_ids[0]) 67 | net = torch.nn.DataParallel(net, gpu_ids) 68 | init_weights(net, init_type) 69 | return net 70 | 71 | 72 | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='xavier', gpu_ids=[], use_tanh=True, classification=True): 73 | netG = None 74 | norm_layer = get_norm_layer(norm_type=norm) 75 | 76 | if which_model_netG == 'resnet_9blocks': 77 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, 
n_blocks=9) 78 | elif which_model_netG == 'resnet_6blocks': 79 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) 80 | elif which_model_netG == 'unet_128': 81 | netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 82 | elif which_model_netG == 'unet_256': 83 | netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 84 | elif which_model_netG == 'siggraph': 85 | netG = SIGGRAPHGenerator(input_nc, output_nc, norm_layer=norm_layer, use_tanh=use_tanh, classification=classification) 86 | else: 87 | raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) 88 | return init_net(netG, init_type, gpu_ids) 89 | 90 | 91 | def define_D(input_nc, ndf, which_model_netD, 92 | n_layers_D=3, norm='batch', use_sigmoid=False, init_type='xavier', gpu_ids=[]): 93 | netD = None 94 | norm_layer = get_norm_layer(norm_type=norm) 95 | 96 | if which_model_netD == 'basic': 97 | netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid) 98 | elif which_model_netD == 'n_layers': 99 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid) 100 | elif which_model_netD == 'pixel': 101 | netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid) 102 | else: 103 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % 104 | which_model_netD) 105 | return init_net(netD, init_type, gpu_ids) 106 | 107 | 108 | ############################################################################## 109 | # Classes 110 | ############################################################################## 111 | 112 | 113 | class HuberLoss(nn.Module): 114 | def __init__(self, delta=.01): 115 | super(HuberLoss, self).__init__() 116 | self.delta = delta 117 | 118 | def __call__(self, in0, in1): 119 | mask = torch.zeros_like(in0) 120 | mann = torch.abs(in0 - in1) 121 | eucl = .5 * (mann**2) 122 | mask[...] = mann < self.delta 123 | 124 | # loss = eucl*mask + self.delta*(mann-.5*self.delta)*(1-mask) 125 | loss = eucl * mask / self.delta + (mann - .5 * self.delta) * (1 - mask) 126 | return torch.sum(loss, dim=1, keepdim=True) 127 | 128 | 129 | class L1Loss(nn.Module): 130 | def __init__(self): 131 | super(L1Loss, self).__init__() 132 | 133 | def __call__(self, in0, in1): 134 | return torch.sum(torch.abs(in0 - in1), dim=1, keepdim=True) 135 | 136 | 137 | class L2Loss(nn.Module): 138 | def __init__(self): 139 | super(L2Loss, self).__init__() 140 | 141 | def __call__(self, in0, in1): 142 | return torch.sum((in0 - in1)**2, dim=1, keepdim=True) 143 | 144 | 145 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 
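# (use_lsgan=True pairs the loss with a discriminator that emits raw scores,
# while use_lsgan=False expects sigmoid probabilities -- matching the
# use_sigmoid = opt.no_lsgan wiring in pix2pix_model.py)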
146 | # When LSGAN is used, it is basically same as MSELoss, 147 | # but it abstracts away the need to create the target label tensor 148 | # that has the same size as the input 149 | class GANLoss(nn.Module): 150 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): 151 | super(GANLoss, self).__init__() 152 | self.register_buffer('real_label', torch.tensor(target_real_label)) 153 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 154 | if use_lsgan: 155 | self.loss = nn.MSELoss() 156 | else: 157 | self.loss = nn.BCELoss() 158 | 159 | def get_target_tensor(self, input, target_is_real): 160 | if target_is_real: 161 | target_tensor = self.real_label 162 | else: 163 | target_tensor = self.fake_label 164 | return target_tensor.expand_as(input) 165 | 166 | def __call__(self, input, target_is_real): 167 | target_tensor = self.get_target_tensor(input, target_is_real) 168 | return self.loss(input, target_tensor) 169 | 170 | 171 | class SIGGRAPHGenerator(nn.Module): 172 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, use_tanh=True, classification=True): 173 | super(SIGGRAPHGenerator, self).__init__() 174 | self.input_nc = input_nc 175 | self.output_nc = output_nc 176 | self.classification = classification 177 | use_bias = True 178 | 179 | # Conv1 180 | # model1=[nn.ReflectionPad2d(1),] 181 | model1 = [nn.Conv2d(input_nc, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 182 | # model1+=[norm_layer(64),] 183 | model1 += [nn.ReLU(True), ] 184 | # model1+=[nn.ReflectionPad2d(1),] 185 | model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 186 | model1 += [nn.ReLU(True), ] 187 | model1 += [norm_layer(64), ] 188 | # add a subsampling operation 189 | 190 | # Conv2 191 | # model2=[nn.ReflectionPad2d(1),] 192 | model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 193 | # model2+=[norm_layer(128),] 194 | model2 += [nn.ReLU(True), ] 195 | # model2+=[nn.ReflectionPad2d(1),] 196 | model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 197 | model2 += [nn.ReLU(True), ] 198 | model2 += [norm_layer(128), ] 199 | # add a subsampling layer operation 200 | 201 | # Conv3 202 | # model3=[nn.ReflectionPad2d(1),] 203 | model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 204 | # model3+=[norm_layer(256),] 205 | model3 += [nn.ReLU(True), ] 206 | # model3+=[nn.ReflectionPad2d(1),] 207 | model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 208 | # model3+=[norm_layer(256),] 209 | model3 += [nn.ReLU(True), ] 210 | # model3+=[nn.ReflectionPad2d(1),] 211 | model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 212 | model3 += [nn.ReLU(True), ] 213 | model3 += [norm_layer(256), ] 214 | # add a subsampling layer operation 215 | 216 | # Conv4 217 | # model47=[nn.ReflectionPad2d(1),] 218 | model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 219 | # model4+=[norm_layer(512),] 220 | model4 += [nn.ReLU(True), ] 221 | # model4+=[nn.ReflectionPad2d(1),] 222 | model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 223 | # model4+=[norm_layer(512),] 224 | model4 += [nn.ReLU(True), ] 225 | # model4+=[nn.ReflectionPad2d(1),] 226 | model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 227 | model4 += [nn.ReLU(True), ] 228 | model4 += [norm_layer(512), ] 229 | 230 | # Conv5 231 | # 
model47+=[nn.ReflectionPad2d(2),] 232 | model5 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ] 233 | # model5+=[norm_layer(512),] 234 | model5 += [nn.ReLU(True), ] 235 | # model5+=[nn.ReflectionPad2d(2),] 236 | model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ] 237 | # model5+=[norm_layer(512),] 238 | model5 += [nn.ReLU(True), ] 239 | # model5+=[nn.ReflectionPad2d(2),] 240 | model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ] 241 | model5 += [nn.ReLU(True), ] 242 | model5 += [norm_layer(512), ] 243 | 244 | # Conv6 245 | # model6+=[nn.ReflectionPad2d(2),] 246 | model6 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ] 247 | # model6+=[norm_layer(512),] 248 | model6 += [nn.ReLU(True), ] 249 | # model6+=[nn.ReflectionPad2d(2),] 250 | model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ] 251 | # model6+=[norm_layer(512),] 252 | model6 += [nn.ReLU(True), ] 253 | # model6+=[nn.ReflectionPad2d(2),] 254 | model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ] 255 | model6 += [nn.ReLU(True), ] 256 | model6 += [norm_layer(512), ] 257 | 258 | # Conv7 259 | # model47+=[nn.ReflectionPad2d(1),] 260 | model7 = [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 261 | # model7+=[norm_layer(512),] 262 | model7 += [nn.ReLU(True), ] 263 | # model7+=[nn.ReflectionPad2d(1),] 264 | model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 265 | # model7+=[norm_layer(512),] 266 | model7 += [nn.ReLU(True), ] 267 | # model7+=[nn.ReflectionPad2d(1),] 268 | model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 269 | model7 += [nn.ReLU(True), ] 270 | model7 += [norm_layer(512), ] 271 | 272 | # Conv8 273 | model8up = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias)] 274 | 275 | # model3short8=[nn.ReflectionPad2d(1),] 276 | model3short8 = [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 277 | 278 | # model47+=[norm_layer(256),] 279 | model8 = [nn.ReLU(True), ] 280 | # model8+=[nn.ReflectionPad2d(1),] 281 | model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 282 | # model8+=[norm_layer(256),] 283 | model8 += [nn.ReLU(True), ] 284 | # model8+=[nn.ReflectionPad2d(1),] 285 | model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 286 | model8 += [nn.ReLU(True), ] 287 | model8 += [norm_layer(256), ] 288 | 289 | # Conv9 290 | model9up = [nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ] 291 | 292 | # model2short9=[nn.ReflectionPad2d(1),] 293 | model2short9 = [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 294 | # add the two feature maps above 295 | 296 | # model9=[norm_layer(128),] 297 | model9 = [nn.ReLU(True), ] 298 | # model9+=[nn.ReflectionPad2d(1),] 299 | model9 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 300 | model9 += [nn.ReLU(True), ] 301 | model9 += [norm_layer(128), ] 302 | 303 | # Conv10 304 | model10up = [nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ] 305 | 306 | # model1short10=[nn.ReflectionPad2d(1),] 307 | model1short10 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ] 308 | # add the two feature maps above 
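# (the *short* layers above are encoder-to-decoder shortcut convolutions;
# forward() sums each one into the upsampling path, e.g.
# conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2))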
309 | 310 | # model10=[norm_layer(128),] 311 | model10 = [nn.ReLU(True), ] 312 | # model10+=[nn.ReflectionPad2d(1),] 313 | model10 += [nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=use_bias), ] 314 | model10 += [nn.LeakyReLU(negative_slope=.2), ] 315 | 316 | # classification output 317 | model_class = [nn.Conv2d(256, 529, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), ] 318 | 319 | # regression output 320 | model_out = [nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), ] 321 | if(use_tanh): 322 | model_out += [nn.Tanh()] 323 | 324 | self.model1 = nn.Sequential(*model1) 325 | self.model2 = nn.Sequential(*model2) 326 | self.model3 = nn.Sequential(*model3) 327 | self.model4 = nn.Sequential(*model4) 328 | self.model5 = nn.Sequential(*model5) 329 | self.model6 = nn.Sequential(*model6) 330 | self.model7 = nn.Sequential(*model7) 331 | self.model8up = nn.Sequential(*model8up) 332 | self.model8 = nn.Sequential(*model8) 333 | self.model9up = nn.Sequential(*model9up) 334 | self.model9 = nn.Sequential(*model9) 335 | self.model10up = nn.Sequential(*model10up) 336 | self.model10 = nn.Sequential(*model10) 337 | self.model3short8 = nn.Sequential(*model3short8) 338 | self.model2short9 = nn.Sequential(*model2short9) 339 | self.model1short10 = nn.Sequential(*model1short10) 340 | 341 | self.model_class = nn.Sequential(*model_class) 342 | self.model_out = nn.Sequential(*model_out) 343 | 344 | self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='nearest'), ]) 345 | self.softmax = nn.Sequential(*[nn.Softmax(dim=1), ]) 346 | 347 | def forward(self, input_A, input_B, mask_B): 348 | conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) 349 | conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) 350 | conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) 351 | conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) 352 | conv5_3 = self.model5(conv4_3) 353 | conv6_3 = self.model6(conv5_3) 354 | conv7_3 = self.model7(conv6_3) 355 | conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) 356 | conv8_3 = self.model8(conv8_up) 357 | 358 | if(self.classification): 359 | out_class = self.model_class(conv8_3) 360 | 361 | conv9_up = self.model9up(conv8_3.detach()) + self.model2short9(conv2_2.detach()) 362 | conv9_3 = self.model9(conv9_up) 363 | conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2.detach()) 364 | conv10_2 = self.model10(conv10_up) 365 | out_reg = self.model_out(conv10_2) 366 | else: 367 | out_class = self.model_class(conv8_3.detach()) 368 | 369 | conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) 370 | conv9_3 = self.model9(conv9_up) 371 | conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) 372 | conv10_2 = self.model10(conv10_up) 373 | out_reg = self.model_out(conv10_2) 374 | 375 | return (out_class, out_reg) 376 | 377 | # Defines the generator that consists of Resnet blocks between a few 378 | # downsampling/upsampling operations. 379 | # Code and idea originally from Justin Johnson's architecture. 
380 | # https://github.com/jcjohnson/fast-neural-style/ 381 | 382 | 383 | class ResnetGenerator(nn.Module): 384 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): 385 | assert(n_blocks >= 0) 386 | super(ResnetGenerator, self).__init__() 387 | self.input_nc = input_nc 388 | self.output_nc = output_nc 389 | self.ngf = ngf 390 | if type(norm_layer) == functools.partial: 391 | use_bias = norm_layer.func == nn.InstanceNorm2d 392 | else: 393 | use_bias = norm_layer == nn.InstanceNorm2d 394 | 395 | model = [nn.ReflectionPad2d(3), 396 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, 397 | bias=use_bias), 398 | norm_layer(ngf), 399 | nn.ReLU(True)] 400 | 401 | n_downsampling = 2 402 | for i in range(n_downsampling): 403 | mult = 2**i 404 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, 405 | stride=2, padding=1, bias=use_bias), 406 | norm_layer(ngf * mult * 2), 407 | nn.ReLU(True)] 408 | 409 | mult = 2**n_downsampling 410 | for i in range(n_blocks): 411 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 412 | 413 | for i in range(n_downsampling): 414 | mult = 2**(n_downsampling - i) 415 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 416 | kernel_size=3, stride=2, 417 | padding=1, output_padding=1, 418 | bias=use_bias), 419 | norm_layer(int(ngf * mult / 2)), 420 | nn.ReLU(True)] 421 | model += [nn.ReflectionPad2d(3)] 422 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 423 | model += [nn.Tanh()] 424 | 425 | self.model = nn.Sequential(*model) 426 | 427 | def forward(self, input): 428 | return self.model(input) 429 | 430 | 431 | # Define a resnet block 432 | class ResnetBlock(nn.Module): 433 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 434 | super(ResnetBlock, self).__init__() 435 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 436 | 437 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 438 | conv_block = [] 439 | p = 0 440 | if padding_type == 'reflect': 441 | conv_block += [nn.ReflectionPad2d(1)] 442 | elif padding_type == 'replicate': 443 | conv_block += [nn.ReplicationPad2d(1)] 444 | elif padding_type == 'zero': 445 | p = 1 446 | else: 447 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 448 | 449 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 450 | norm_layer(dim), 451 | nn.ReLU(True)] 452 | if use_dropout: 453 | conv_block += [nn.Dropout(0.5)] 454 | 455 | p = 0 456 | if padding_type == 'reflect': 457 | conv_block += [nn.ReflectionPad2d(1)] 458 | elif padding_type == 'replicate': 459 | conv_block += [nn.ReplicationPad2d(1)] 460 | elif padding_type == 'zero': 461 | p = 1 462 | else: 463 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 464 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 465 | norm_layer(dim)] 466 | 467 | return nn.Sequential(*conv_block) 468 | 469 | def forward(self, x): 470 | out = x + self.conv_block(x) 471 | return out 472 | 473 | 474 | # Defines the Unet generator. 475 | # |num_downs|: number of downsamplings in UNet. 
For example, 476 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 477 | # at the bottleneck 478 | class UnetGenerator(nn.Module): 479 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 480 | norm_layer=nn.BatchNorm2d, use_dropout=False): 481 | super(UnetGenerator, self).__init__() 482 | 483 | # construct unet structure 484 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) 485 | for i in range(num_downs - 5): 486 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 487 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 488 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 489 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 490 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 491 | 492 | self.model = unet_block 493 | 494 | def forward(self, input_A, input_B, mask_B): 495 | # embed() 496 | return self.model(torch.cat((input_A, input_B, mask_B), dim=1)) 497 | 498 | 499 | # Defines the submodule with skip connection. 500 | # X -------------------identity---------------------- X 501 | # |-- downsampling -- |submodule| -- upsampling --| 502 | class UnetSkipConnectionBlock(nn.Module): 503 | def __init__(self, outer_nc, inner_nc, input_nc=None, 504 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 505 | super(UnetSkipConnectionBlock, self).__init__() 506 | self.outermost = outermost 507 | if type(norm_layer) == functools.partial: 508 | use_bias = norm_layer.func == nn.InstanceNorm2d 509 | else: 510 | use_bias = norm_layer == nn.InstanceNorm2d 511 | if input_nc is None: 512 | input_nc = outer_nc 513 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 514 | stride=2, padding=1, bias=use_bias) 515 | downrelu = nn.LeakyReLU(0.2, True) 516 | downnorm = norm_layer(inner_nc) 517 | uprelu = nn.ReLU(True) 518 | upnorm = norm_layer(outer_nc) 519 | 520 | if outermost: 521 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 522 | kernel_size=4, stride=2, 523 | padding=1) 524 | down = [downconv] 525 | up = [uprelu, upconv, nn.Tanh()] 526 | model = down + [submodule] + up 527 | elif innermost: 528 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 529 | kernel_size=4, stride=2, 530 | padding=1, bias=use_bias) 531 | down = [downrelu, downconv] 532 | up = [uprelu, upconv, upnorm] 533 | model = down + up 534 | else: 535 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 536 | kernel_size=4, stride=2, 537 | padding=1, bias=use_bias) 538 | down = [downrelu, downconv, downnorm] 539 | up = [uprelu, upconv, upnorm] 540 | 541 | if use_dropout: 542 | model = down + [submodule] + up + [nn.Dropout(0.5)] 543 | else: 544 | model = down + [submodule] + up 545 | 546 | self.model = nn.Sequential(*model) 547 | 548 | def forward(self, x): 549 | if self.outermost: 550 | return self.model(x) 551 | else: 552 | return torch.cat([x, self.model(x)], 1) 553 | 554 | 555 | # Defines the PatchGAN discriminator with the specified arguments. 
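# Each spatial output value scores one overlapping patch of the input (about
# a 70x70 receptive field with the default kw=4, n_layers=3 configuration),
# so real/fake decisions are made per patch rather than per image.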
556 | class NLayerDiscriminator(nn.Module): 557 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 558 | super(NLayerDiscriminator, self).__init__() 559 | if type(norm_layer) == functools.partial: 560 | use_bias = norm_layer.func == nn.InstanceNorm2d 561 | else: 562 | use_bias = norm_layer == nn.InstanceNorm2d 563 | 564 | kw = 4 565 | padw = 1 566 | sequence = [ 567 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 568 | nn.LeakyReLU(0.2, True) 569 | ] 570 | 571 | nf_mult = 1 572 | nf_mult_prev = 1 573 | for n in range(1, n_layers): 574 | nf_mult_prev = nf_mult 575 | nf_mult = min(2**n, 8) 576 | sequence += [ 577 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 578 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 579 | norm_layer(ndf * nf_mult), 580 | nn.LeakyReLU(0.2, True) 581 | ] 582 | 583 | nf_mult_prev = nf_mult 584 | nf_mult = min(2**n_layers, 8) 585 | sequence += [ 586 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 587 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 588 | norm_layer(ndf * nf_mult), 589 | nn.LeakyReLU(0.2, True) 590 | ] 591 | 592 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 593 | 594 | if use_sigmoid: 595 | sequence += [nn.Sigmoid()] 596 | 597 | self.model = nn.Sequential(*sequence) 598 | 599 | def forward(self, input): 600 | return self.model(input) 601 | 602 | 603 | class PixelDiscriminator(nn.Module): 604 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 605 | super(PixelDiscriminator, self).__init__() 606 | if type(norm_layer) == functools.partial: 607 | use_bias = norm_layer.func == nn.InstanceNorm2d 608 | else: 609 | use_bias = norm_layer == nn.InstanceNorm2d 610 | 611 | self.net = [ 612 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 613 | nn.LeakyReLU(0.2, True), 614 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 615 | norm_layer(ndf * 2), 616 | nn.LeakyReLU(0.2, True), 617 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 618 | 619 | if use_sigmoid: 620 | self.net.append(nn.Sigmoid()) 621 | 622 | self.net = nn.Sequential(*self.net) 623 | 624 | def forward(self, input): 625 | return self.net(input) 626 | -------------------------------------------------------------------------------- /models/pix2pix_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | from util.image_pool import ImagePool 4 | from util import util 5 | from .base_model import BaseModel 6 | from . import networks 7 | import numpy as np 8 | 9 | 10 | class Pix2PixModel(BaseModel): 11 | def name(self): 12 | return 'Pix2PixModel' 13 | 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train=True): 16 | return parser 17 | 18 | def initialize(self, opt): 19 | BaseModel.initialize(self, opt) 20 | self.isTrain = opt.isTrain 21 | self.half = opt.half 22 | 23 | self.use_D = self.opt.lambda_GAN > 0 24 | 25 | # specify the training losses you want to print out. 
The program will call base_model.get_current_losses
26 |
27 | if(self.use_D):
28 | self.loss_names = ['G_GAN', ]
29 | else:
30 | self.loss_names = []
31 |
32 | self.loss_names += ['G_CE', 'G_entr', 'G_entr_hint', ]
33 | self.loss_names += ['G_L1_max', 'G_L1_mean', 'G_L1_reg', ]
34 | self.loss_names += ['G_fake_real', 'G_fake_hint', 'G_real_hint', ]
35 | self.loss_names += ['0', ]
36 |
37 | # specify the images you want to save/display. The program will call base_model.get_current_visuals
38 | self.visual_names = ['real_A', 'fake_B', 'real_B']
39 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
40 |
41 | if self.isTrain:
42 | if(self.use_D):
43 | self.model_names = ['G', 'D']
44 | else:
45 | self.model_names = ['G', ]
46 | else: # during test time, only load Gs
47 | self.model_names = ['G']
48 |
49 | # load/define networks
50 | num_in = opt.input_nc + opt.output_nc + 1
51 | self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf,
52 | opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
53 | use_tanh=True, classification=opt.classification)
54 |
55 | if self.isTrain:
56 | use_sigmoid = opt.no_lsgan
57 | if self.use_D:
58 | self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
59 | opt.which_model_netD,
60 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
61 |
62 | if self.isTrain:
63 | self.fake_AB_pool = ImagePool(opt.pool_size)
64 | # define loss functions
65 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
66 | # self.criterionL1 = torch.nn.L1Loss()
67 | self.criterionL1 = networks.L1Loss()
68 | self.criterionHuber = networks.HuberLoss(delta=1. / opt.ab_norm)
69 |
70 | # if(opt.classification):
71 | self.criterionCE = torch.nn.CrossEntropyLoss()
72 |
73 | # initialize optimizers
74 | self.optimizers = []
75 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
76 | lr=opt.lr, betas=(opt.beta1, 0.999))
77 | self.optimizers.append(self.optimizer_G)
78 |
79 | if self.use_D:
80 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
81 | lr=opt.lr, betas=(opt.beta1, 0.999))
82 | self.optimizers.append(self.optimizer_D)
83 |
84 | if self.half:
85 | for model_name in self.model_names:
86 | net = getattr(self, 'net' + model_name)
87 | net.half()
88 | for layer in net.modules():
89 | if(isinstance(layer, torch.nn.BatchNorm2d)):
90 | layer.float()
91 | print('Net %s half precision' % model_name)
92 |
93 | # initialize average loss values
94 | self.avg_losses = OrderedDict()
95 | self.avg_loss_alpha = opt.avg_loss_alpha
96 | self.error_cnt = 0
97 |
98 | # self.avg_loss_alpha = 0.9993 # half-life of 1000 iterations
99 | # self.avg_loss_alpha = 0.9965 # half-life of 200 iterations
100 | # self.avg_loss_alpha = 0.986 # half-life of 50 iterations
101 | # self.avg_loss_alpha = 0.
# no averaging
102 | for loss_name in self.loss_names:
103 | self.avg_losses[loss_name] = 0
104 |
105 | def set_input(self, input):
106 | if(self.half):
107 | for key in input.keys():
108 | input[key] = input[key].half()
109 |
110 | AtoB = self.opt.which_direction == 'AtoB'
111 | self.real_A = input['A' if AtoB else 'B'].to(self.device)
112 | self.real_B = input['B' if AtoB else 'A'].to(self.device)
113 | # self.image_paths = input['A_paths' if AtoB else 'B_paths']
114 | self.hint_B = input['hint_B'].to(self.device)
115 | self.mask_B = input['mask_B'].to(self.device)
116 | self.mask_B_nc = self.mask_B + self.opt.mask_cent
117 |
118 | self.real_B_enc = util.encode_ab_ind(self.real_B[:, :, ::4, ::4], self.opt)
119 |
120 | def forward(self):
121 | (self.fake_B_class, self.fake_B_reg) = self.netG(self.real_A, self.hint_B, self.mask_B)
122 | # if(self.opt.classification):
123 | self.fake_B_dec_max = self.netG.module.upsample4(util.decode_max_ab(self.fake_B_class, self.opt))
124 | self.fake_B_distr = self.netG.module.softmax(self.fake_B_class)
125 |
126 | self.fake_B_dec_mean = self.netG.module.upsample4(util.decode_mean(self.fake_B_distr, self.opt))
127 |
128 | self.fake_B_entr = self.netG.module.upsample4(-torch.sum(self.fake_B_distr * torch.log(self.fake_B_distr + 1.e-10), dim=1, keepdim=True))
129 | # embed()
130 |
131 | def backward_D(self):
132 | # Fake
133 | # stop backprop to the generator by detaching the fake sample; the regression output fake_B_reg is the colorization shown to D
134 | fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B_reg), 1))
135 | pred_fake = self.netD(fake_AB.detach())
136 | self.loss_D_fake = self.criterionGAN(pred_fake, False)
137 | # self.loss_D_fake = 0
138 |
139 | # Real
140 | real_AB = torch.cat((self.real_A, self.real_B), 1)
141 | pred_real = self.netD(real_AB)
142 | self.loss_D_real = self.criterionGAN(pred_real, True)
143 | # self.loss_D_real = 0
144 |
145 | # Combined loss
146 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
147 |
148 | self.loss_D.backward()
149 |
150 | def compute_losses_G(self):
151 | mask_avg = torch.mean(self.mask_B_nc.type(torch.cuda.FloatTensor)) + .000001
152 |
153 | self.loss_0 = 0 # 0 for plot
154 |
155 | # classification statistics
156 | self.loss_G_CE = self.criterionCE(self.fake_B_class.type(torch.cuda.FloatTensor),
157 | self.real_B_enc[:, 0, :, :].type(torch.cuda.LongTensor)) # cross-entropy loss
158 | self.loss_G_entr = torch.mean(self.fake_B_entr.type(torch.cuda.FloatTensor)) # entropy of predicted distribution
159 | self.loss_G_entr_hint = torch.mean(self.fake_B_entr.type(torch.cuda.FloatTensor) * self.mask_B_nc.type(torch.cuda.FloatTensor)) / mask_avg # entropy of predicted distribution at hint points
160 |
161 | # regression statistics
162 | self.loss_G_L1_max = 10 * torch.mean(self.criterionL1(self.fake_B_dec_max.type(torch.cuda.FloatTensor),
163 | self.real_B.type(torch.cuda.FloatTensor)))
164 | self.loss_G_L1_mean = 10 * torch.mean(self.criterionL1(self.fake_B_dec_mean.type(torch.cuda.FloatTensor),
165 | self.real_B.type(torch.cuda.FloatTensor)))
166 | self.loss_G_L1_reg = 10 * torch.mean(self.criterionL1(self.fake_B_reg.type(torch.cuda.FloatTensor),
167 | self.real_B.type(torch.cuda.FloatTensor)))
168 |
169 | # L1 loss at given points
170 | self.loss_G_fake_real = 10 * torch.mean(self.criterionL1(self.fake_B_reg * self.mask_B_nc, self.real_B * self.mask_B_nc).type(torch.cuda.FloatTensor)) / mask_avg
171 | self.loss_G_fake_hint = 10 * torch.mean(self.criterionL1(self.fake_B_reg * self.mask_B_nc, self.hint_B * self.mask_B_nc).type(torch.cuda.FloatTensor)) / mask_avg
172 | self.loss_G_real_hint = 10 * torch.mean(self.criterionL1(self.real_B * self.mask_B_nc, self.hint_B * self.mask_B_nc).type(torch.cuda.FloatTensor)) / mask_avg
173 |
174 | # self.loss_G_L1 = torch.mean(self.criterionL1(self.fake_B, self.real_B))
175 | # self.loss_G_Huber = torch.mean(self.criterionHuber(self.fake_B, self.real_B))
176 | # self.loss_G_fake_real = torch.mean(self.criterionHuber(self.fake_B*self.mask_B_nc, self.real_B*self.mask_B_nc)) / mask_avg
177 | # self.loss_G_fake_hint = torch.mean(self.criterionHuber(self.fake_B*self.mask_B_nc, self.hint_B*self.mask_B_nc)) / mask_avg
178 | # self.loss_G_real_hint = torch.mean(self.criterionHuber(self.real_B*self.mask_B_nc, self.hint_B*self.mask_B_nc)) / mask_avg
179 |
180 | if self.use_D:
181 | fake_AB = torch.cat((self.real_A, self.fake_B_reg), 1)
182 | pred_fake = self.netD(fake_AB)
183 | self.loss_G_GAN = self.criterionGAN(pred_fake, True); self.loss_G = self.loss_G_GAN * self.opt.lambda_GAN + self.loss_G_CE * self.opt.lambda_A + self.loss_G_L1_reg # assumed combination: loss_G must also be defined on this branch, with the GAN term weighted by lambda_GAN
184 | else:
185 | self.loss_G = self.loss_G_CE * self.opt.lambda_A + self.loss_G_L1_reg
186 | # self.loss_G = self.loss_G_Huber*self.opt.lambda_A
187 |
188 | def backward_G(self):
189 | self.compute_losses_G()
190 | self.loss_G.backward()
191 |
192 | def optimize_parameters(self):
193 | self.forward()
194 |
195 | if(self.use_D):
196 | # update D
197 | self.set_requires_grad(self.netD, True)
198 | self.optimizer_D.zero_grad()
199 | self.backward_D()
200 | self.optimizer_D.step()
201 |
202 | self.set_requires_grad(self.netD, False)
203 |
204 | # update G
205 | self.optimizer_G.zero_grad()
206 | self.backward_G()
207 | self.optimizer_G.step()
208 |
209 | def get_current_visuals(self):
210 | from collections import OrderedDict
211 | visual_ret = OrderedDict()
212 |
213 | visual_ret['gray'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), torch.zeros_like(self.real_B).type(torch.cuda.FloatTensor)), dim=1), self.opt)
214 | visual_ret['real'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.real_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
215 |
216 | visual_ret['fake_max'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.fake_B_dec_max.type(torch.cuda.FloatTensor)), dim=1), self.opt)
217 | visual_ret['fake_mean'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.fake_B_dec_mean.type(torch.cuda.FloatTensor)), dim=1), self.opt)
218 | visual_ret['fake_reg'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt)
219 |
220 | visual_ret['hint'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.hint_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
221 |
222 | visual_ret['real_ab'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.real_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
223 |
224 | visual_ret['fake_ab_max'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.fake_B_dec_max.type(torch.cuda.FloatTensor)), dim=1), self.opt)
225 | visual_ret['fake_ab_mean'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.fake_B_dec_mean.type(torch.cuda.FloatTensor)), dim=1), self.opt)
226 | visual_ret['fake_ab_reg'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt)
227 |
228 | visual_ret['mask'] = self.mask_B_nc.expand(-1, 3, -1, -1).type(torch.cuda.FloatTensor)
229 | visual_ret['hint_ab'] =
visual_ret['mask'] * util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.hint_B.type(torch.cuda.FloatTensor)), dim=1), self.opt) 230 | 231 | C = self.fake_B_distr.shape[1] 232 | # scale to [-1, 2], then clamped to [-1, 1] 233 | visual_ret['fake_entr'] = torch.clamp(3 * self.fake_B_entr.expand(-1, 3, -1, -1) / np.log(C) - 1, -1, 1) 234 | 235 | return visual_ret 236 | 237 | # return training losses/errors. train.py will print out these errors as debugging information 238 | def get_current_losses(self): 239 | self.error_cnt += 1 240 | errors_ret = OrderedDict() 241 | for name in self.loss_names: 242 | if isinstance(name, str): 243 | # float(...) works for both scalar tensor and float number 244 | self.avg_losses[name] = float(getattr(self, 'loss_' + name)) + self.avg_loss_alpha * self.avg_losses[name] 245 | errors_ret[name] = (1 - self.avg_loss_alpha) / (1 - self.avg_loss_alpha**self.error_cnt) * self.avg_losses[name] 246 | 247 | # errors_ret['|ab|_gt'] = float(torch.mean(torch.abs(self.real_B[:,1:,:,:])).cpu()) 248 | # errors_ret['|ab|_pr'] = float(torch.mean(torch.abs(self.fake_B[:,1:,:,:])).cpu()) 249 | 250 | return errors_ret 251 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | def __init__(self): 11 | self.initialized = False 12 | 13 | def initialize(self, parser): 14 | parser.add_argument('--batch_size', type=int, default=25, help='input batch size') 15 | parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size') 16 | parser.add_argument('--fineSize', type=int, default=176, help='then crop to this size') 17 | parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels') 18 | parser.add_argument('--output_nc', type=int, default=2, help='# of output image channels') 19 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 20 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 21 | parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD') 22 | parser.add_argument('--which_model_netG', type=str, default='siggraph', help='selects model to use for netG') 23 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') 24 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 25 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 26 | parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single]') 27 | parser.add_argument('--model', type=str, default='pix2pix', 28 | help='chooses which model to use. 
cycle_gan, pix2pix, test')
29 | parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
30 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
31 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
32 | parser.add_argument('--norm', type=str, default='batch', help='instance normalization or batch normalization')
33 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
34 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
35 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
36 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
37 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
38 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
39 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
40 | help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
41 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
42 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
43 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
44 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
45 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}')
46 | parser.add_argument('--ab_norm', type=float, default=110., help='colorization normalization factor')
47 | parser.add_argument('--ab_max', type=float, default=110., help='maximum ab value')
48 | parser.add_argument('--ab_quant', type=float, default=10., help='quantization factor')
49 | parser.add_argument('--l_norm', type=float, default=100., help='colorization normalization factor')
50 | parser.add_argument('--l_cent', type=float, default=50., help='colorization centering factor')
51 | parser.add_argument('--mask_cent', type=float, default=.5, help='mask centering factor')
52 | parser.add_argument('--sample_p', type=float, default=1.0, help='parameter p of the geometric distribution used to sample the number of hint points; 1.0 means no hints')
53 | parser.add_argument('--sample_Ps', type=int, nargs='+', default=[1, 2, 3, 4, 5, 6, 7, 8, 9, ], help='patch sizes')
54 |
55 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
56 | parser.add_argument('--classification', action='store_true', help='backprop trunk using classification, otherwise use regression')
57 | parser.add_argument('--phase', type=str, default='val', help='train_small, train, val, test, etc')
58 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load?
set to latest to use latest cached model') 59 | parser.add_argument('--how_many', type=int, default=200, help='how many test images to run') 60 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 61 | 62 | parser.add_argument('--load_model', action='store_true', help='load the latest model') 63 | parser.add_argument('--half', action='store_true', help='half precision model') 64 | 65 | self.initialized = True 66 | return parser 67 | 68 | def gather_options(self): 69 | # initialize parser with basic options 70 | if not self.initialized: 71 | parser = argparse.ArgumentParser( 72 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 73 | parser = self.initialize(parser) 74 | 75 | # get the basic options 76 | opt, _ = parser.parse_known_args() 77 | 78 | # modify model-related parser options 79 | model_name = opt.model 80 | model_option_setter = models.get_option_setter(model_name) 81 | parser = model_option_setter(parser, self.isTrain) 82 | opt, _ = parser.parse_known_args() # parse again with the new defaults 83 | 84 | # modify dataset-related parser options 85 | dataset_name = opt.dataset_mode 86 | dataset_option_setter = data.get_option_setter(dataset_name) 87 | parser = dataset_option_setter(parser, self.isTrain) 88 | 89 | self.parser = parser 90 | 91 | return parser.parse_args() 92 | 93 | def print_options(self, opt): 94 | message = '' 95 | message += '----------------- Options ---------------\n' 96 | for k, v in sorted(vars(opt).items()): 97 | comment = '' 98 | default = self.parser.get_default(k) 99 | if v != default: 100 | comment = '\t[default: %s]' % str(default) 101 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 102 | message += '----------------- End -------------------' 103 | print(message) 104 | 105 | # save to the disk 106 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 107 | util.mkdirs(expr_dir) 108 | file_name = os.path.join(expr_dir, 'opt.txt') 109 | with open(file_name, 'wt') as opt_file: 110 | opt_file.write(message) 111 | opt_file.write('\n') 112 | 113 | def parse(self): 114 | 115 | opt = self.gather_options() 116 | opt.isTrain = self.isTrain # train or test 117 | 118 | # process opt.suffix 119 | if opt.suffix: 120 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 121 | opt.name = opt.name + suffix 122 | 123 | self.print_options(opt) 124 | 125 | # set gpu ids 126 | str_ids = opt.gpu_ids.split(',') 127 | opt.gpu_ids = [] 128 | for str_id in str_ids: 129 | id = int(str_id) 130 | if id >= 0: 131 | opt.gpu_ids.append(id) 132 | if len(opt.gpu_ids) > 0: 133 | torch.cuda.set_device(opt.gpu_ids[0]) 134 | opt.A = 2 * opt.ab_max / opt.ab_quant + 1 135 | opt.B = opt.A 136 | 137 | self.opt = opt 138 | return self.opt 139 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | BaseOptions.initialize(self, parser) 7 | parser.add_argument('--display_freq', type=int, default=10000, help='frequency of showing training results on screen') 8 | parser.add_argument('--display_ncols', type=int, default=5, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 9 | parser.add_argument('--update_html_freq', type=int, default=10000, help='frequency of saving training results to 
html')
10 | parser.add_argument('--print_freq', type=int, default=200, help='frequency of showing training results on console')
11 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
12 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
13 | parser.add_argument('--epoch_count', type=int, default=0, help='the starting epoch count; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
14 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
15 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
16 | parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam')
17 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
18 | parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least squares GAN; if set, use vanilla GAN')
19 | parser.add_argument('--lambda_GAN', type=float, default=0., help='weight for GAN loss')
20 | parser.add_argument('--lambda_A', type=float, default=1., help='weight for cycle loss (A -> B -> A)')
21 | parser.add_argument('--lambda_B', type=float, default=1., help='weight for cycle loss (B -> A -> B)')
22 | parser.add_argument('--lambda_identity', type=float, default=0.5,
23 | help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss.'
24 | 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
25 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
26 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
27 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
28 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
29 | parser.add_argument('--avg_loss_alpha', type=float, default=.986, help='exponential averaging weight for displaying loss')
30 | self.isTrain = True
31 | return parser
32 | -------------------------------------------------------------------------------- /pretrained_models/download_siggraph_model.sh: --------------------------------------------------------------------------------
1 | mkdir -p ./checkpoints/siggraph_retrained
2 | MODEL_FILE=./checkpoints/siggraph_retrained/latest_net_G.pth
3 | URL=http://colorization.eecs.berkeley.edu/siggraph/models/pytorch.pth
4 |
5 | wget -N $URL -O $MODEL_FILE
6 |
7 | mkdir -p ./checkpoints/siggraph_caffemodel
8 | MODEL_FILE=./checkpoints/siggraph_caffemodel/latest_net_G.pth
9 | URL=http://colorization.eecs.berkeley.edu/siggraph/models/caffemodel.pth
10 |
11 | wget -N $URL -O $MODEL_FILE
12 | -------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | torch>=0.4.0
2 | torchvision>=0.2.1
3 | dominate>=2.3.1
4 | visdom>=0.1.8.3
5 | -------------------------------------------------------------------------------- /resources/ilsvrclin12_val_inds.npy:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/resources/ilsvrclin12_val_inds.npy -------------------------------------------------------------------------------- /resources/psnrs_siggraph.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/resources/psnrs_siggraph.npy -------------------------------------------------------------------------------- /scripts/check_all.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | DOWNLOAD=${1} 3 | echo 'apply a pretrained cyclegan model' 4 | if [ ${DOWNLOAD} -eq 1 ] 5 | then 6 | bash pretrained_models/download_cyclegan_model.sh horse2zebra 7 | bash ./datasets/download_cyclegan_dataset.sh horse2zebra 8 | fi 9 | python test.py --dataroot datasets/horse2zebra/testA --checkpoints_dir ./checkpoints/ --name horse2zebra_pretrained --no_dropout --model test --dataset_mode single --loadSize 256 10 | 11 | echo 'apply a pretrained pix2pix model' 12 | if [ ${DOWNLOAD} -eq 1 ] 13 | then 14 | bash pretrained_models/download_pix2pix_model.sh facades_label2photo 15 | bash ./datasets/download_pix2pix_dataset.sh facades 16 | fi 17 | python test.py --dataroot ./datasets/facades/ --which_direction BtoA --model pix2pix --name facades_label2photo_pretrained --dataset_mode aligned --which_model_netG unet_256 --norm batch 18 | 19 | 20 | echo 'cyclegan train (1 epoch) and test' 21 | if [ ${DOWNLOAD} -eq 1 ] 22 | then 23 | bash ./datasets/download_cyclegan_dataset.sh maps 24 | fi 25 | python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --no_dropout --niter 1 --niter_decay 0 --max_dataset_size 100 --save_latest_freq 100 26 | python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --phase test --no_dropout 27 | 28 | 29 | echo 'pix2pix train (1 epoch) and test' 30 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --dataset_mode aligned --no_lsgan --norm batch --pool_size 0 --niter 1 --niter_decay 0 --save_latest_freq 400 31 | python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned --norm batch 32 | -------------------------------------------------------------------------------- /scripts/conda_deps.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing 3 | conda install -c pytorch magma-cuda80 # or magma-cuda90 if CUDA 9 4 | conda install pytorch torchvision -c pytorch # install pytorch; if you want to use cuda90, add cuda90 5 | conda install -c conda-forge dominate # install dominate 6 | conda install -c conda-forge visdom # install visdom 7 | -------------------------------------------------------------------------------- /scripts/install_deps.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | pip install visdom 3 | pip install dominate 4 | -------------------------------------------------------------------------------- /scripts/train_siggraph.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 
| # Train classification network on small training set first 4 | python train.py --name siggraph_class_small --sample_p 1.0 --niter 100 --niter_decay 0 --classification --phase train_small 5 | 6 | # Train classification network first 7 | mkdir ./checkpoints/siggraph_class 8 | cp ./checkpoints/siggraph_class_small/latest_net_G.pth ./checkpoints/siggraph_class/ 9 | python train.py --name siggraph_class --sample_p 1.0 --niter 15 --niter_decay 0 --classification --load_model --phase train 10 | 11 | # Train regression model (with color hints) 12 | mkdir ./checkpoints/siggraph_reg 13 | cp ./checkpoints/siggraph_class/latest_net_G.pth ./checkpoints/siggraph_reg/ 14 | python train.py --name siggraph_reg --sample_p .125 --niter 10 --niter_decay 0 --lr 0.00001 --load_model --phase train 15 | 16 | # Turn down learning rate to 1e-6 17 | mkdir ./checkpoints/siggraph_reg2 18 | cp ./checkpoints/siggraph_reg/latest_net_G.pth ./checkpoints/siggraph_reg2/ 19 | python train.py --name siggraph_reg2 --sample_p .125 --niter 5 --niter_decay 0 --lr 0.000001 --load_model --phase train 20 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from options.train_options import TrainOptions 4 | from models import create_model 5 | from util.visualizer import save_images 6 | from util import html 7 | 8 | import string 9 | import torch 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | 13 | from util import util 14 | import numpy as np 15 | 16 | 17 | if __name__ == '__main__': 18 | sample_ps = [1., .125, .03125] 19 | to_visualize = ['gray', 'hint', 'hint_ab', 'fake_entr', 'real', 'fake_reg', 'real_ab', 'fake_ab_reg', ] 20 | S = len(sample_ps) 21 | 22 | opt = TrainOptions().parse() 23 | opt.load_model = True 24 | opt.num_threads = 1 # test code only supports num_threads = 1 25 | opt.batch_size = 1 # test code only supports batch_size = 1 26 | opt.display_id = -1 # no visdom display 27 | opt.phase = 'val' 28 | opt.dataroot = './dataset/ilsvrc2012/%s/' % opt.phase 29 | opt.serial_batches = True 30 | opt.aspect_ratio = 1. 
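# Each image is evaluated once per hint density in sample_ps: p=1.0 is fully
# automatic colorization (no hints), while smaller p draws more hint patches;
# the number of points is geometrically distributed with mean (1-p)/p, so
# p=.125 gives ~7 points and p=.03125 gives ~31 on average.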
31 |
32 | dataset = torchvision.datasets.ImageFolder(opt.dataroot,
33 | transform=transforms.Compose([
34 | transforms.Resize((opt.loadSize, opt.loadSize)),
35 | transforms.ToTensor()]))
36 | dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=not opt.serial_batches)
37 |
38 | model = create_model(opt)
39 | model.setup(opt)
40 | model.eval()
41 |
42 | # create website
43 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
44 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
45 |
46 | # statistics
47 | psnrs = np.zeros((opt.how_many, S))
48 | entrs = np.zeros((opt.how_many, S))
49 |
50 | for i, data_raw in enumerate(dataset_loader):
51 | data_raw[0] = data_raw[0].cuda()
52 | data_raw[0] = util.crop_mult(data_raw[0], mult=8)
53 |
54 | # with no points
55 | for (pp, sample_p) in enumerate(sample_ps):
56 | img_path = [('%08d_%.3f' % (i, sample_p)).replace('.', 'p')]
57 | data = util.get_colorization_data(data_raw, opt, ab_thresh=0., p=sample_p)
58 |
59 | model.set_input(data)
60 | model.test(True) # True means that losses will be computed
61 | visuals = util.get_subset_dict(model.get_current_visuals(), to_visualize)
62 |
63 | psnrs[i, pp] = util.calculate_psnr_np(util.tensor2im(visuals['real']), util.tensor2im(visuals['fake_reg']))
64 | entrs[i, pp] = model.get_current_losses()['G_entr']
65 |
66 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
67 |
68 | if i % 5 == 0:
69 | print('processing (%04d)-th image... %s' % (i, img_path))
70 |
71 | if i == opt.how_many - 1:
72 | break
73 |
74 | webpage.save()
75 |
76 | # Compute and print some summary statistics
77 | psnrs_mean = np.mean(psnrs, axis=0)
78 | psnrs_std = np.std(psnrs, axis=0) / np.sqrt(opt.how_many)
79 |
80 | entrs_mean = np.mean(entrs, axis=0)
81 | entrs_std = np.std(entrs, axis=0) / np.sqrt(opt.how_many)
82 |
83 | for (pp, sample_p) in enumerate(sample_ps):
84 | print('p=%.3f: %.2f+/-%.2f' % (sample_p, psnrs_mean[pp], psnrs_std[pp]))
85 | -------------------------------------------------------------------------------- /test_sweep.py: --------------------------------------------------------------------------------
1 | from options.train_options import TrainOptions
2 | from models import create_model
3 |
4 | import torch
5 | import torchvision
6 | import torchvision.transforms as transforms
7 |
8 | from util import util
9 | import numpy as np
10 | import progressbar as pb
11 | import shutil
12 |
13 | import datetime as dt
14 | import matplotlib.pyplot as plt
15 |
16 | if __name__ == '__main__':
17 | opt = TrainOptions().parse()
18 | opt.load_model = True
19 | opt.num_threads = 1 # test code only supports num_threads = 1
20 | opt.batch_size = 1 # test code only supports batch_size = 1
21 | opt.display_id = -1 # no visdom display
22 | opt.phase = 'test'
23 | opt.dataroot = './dataset/ilsvrc2012/%s/' % opt.phase
24 | opt.loadSize = 256
25 | opt.how_many = 1000
26 | opt.aspect_ratio = 1.0
27 | opt.sample_Ps = [6, ]
28 | opt.load_model = True
29 |
30 | # number of random points to assign
31 | num_points = np.round(10**np.arange(-.1, 2.8, .1))
32 | num_points[0] = 0
33 | num_points = np.unique(num_points.astype('int'))
34 | N = len(num_points)
35 |
36 | dataset = torchvision.datasets.ImageFolder(opt.dataroot,
37 | transform=transforms.Compose([
38 | transforms.Resize((opt.loadSize, opt.loadSize)),
39 | transforms.ToTensor()]))
40 | dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=not opt.serial_batches)
41 |
42 | model = create_model(opt)
43 | model.setup(opt)
44 | model.eval()
45 |
46 | time = dt.datetime.now()
47 | str_now = '%02d_%02d_%02d%02d' % (time.month, time.day, time.hour, time.minute)
48 |
49 | shutil.copyfile('./checkpoints/%s/latest_net_G.pth' % opt.name, './checkpoints/%s/%s_net_G.pth' % (opt.name, str_now))
50 |
51 | psnrs = np.zeros((opt.how_many, N))
52 |
53 | bar = pb.ProgressBar(max_value=opt.how_many)
54 | for i, data_raw in enumerate(dataset_loader):
55 | data_raw[0] = data_raw[0].cuda()
56 | data_raw[0] = util.crop_mult(data_raw[0], mult=8)
57 |
58 | for nn in range(N):
59 | # embed()
60 | data = util.get_colorization_data(data_raw, opt, ab_thresh=0., num_points=num_points[nn])
61 |
62 | model.set_input(data)
63 | model.test()
64 | visuals = model.get_current_visuals()
65 |
66 | psnrs[i, nn] = util.calculate_psnr_np(util.tensor2im(visuals['real']), util.tensor2im(visuals['fake_reg']))
67 |
68 | if i == opt.how_many - 1:
69 | break
70 |
71 | bar.update(i)
72 |
73 | # Save results
74 | psnrs_mean = np.mean(psnrs, axis=0)
75 | psnrs_std = np.std(psnrs, axis=0) / np.sqrt(opt.how_many)
76 |
77 | np.save('./checkpoints/%s/psnrs_mean_%s' % (opt.name, str_now), psnrs_mean)
78 | np.save('./checkpoints/%s/psnrs_std_%s' % (opt.name, str_now), psnrs_std)
79 | np.save('./checkpoints/%s/psnrs_%s' % (opt.name, str_now), psnrs)
80 | print(', '.join(['%.2f' % psnr for psnr in psnrs_mean]))
81 |
82 | old_results = np.load('./resources/psnrs_siggraph.npy')
83 | old_mean = np.mean(old_results, axis=0)
84 | old_std = np.std(old_results, axis=0) / np.sqrt(old_results.shape[0])
85 | print(', '.join(['%.2f' % psnr for psnr in old_mean]))
86 |
87 | num_points_hack = 1. * num_points
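# zero points cannot be placed on the log-scaled x-axis below, so the
# automatic-colorization result is plotted at x=0.4 and labeled 'Auto'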
88 | num_points_hack[0] = .4
89 |
90 | plt.plot(num_points_hack, psnrs_mean, 'bo-', label=str_now)
91 | plt.plot(num_points_hack, psnrs_mean + psnrs_std, 'b--')
92 | plt.plot(num_points_hack, psnrs_mean - psnrs_std, 'b--')
93 | plt.plot(num_points_hack, old_mean, 'ro-', label='siggraph17')
94 | plt.plot(num_points_hack, old_mean + old_std, 'r--')
95 | plt.plot(num_points_hack, old_mean - old_std, 'r--')
96 |
97 | plt.xscale('log')
98 | plt.xticks([.4, 1, 2, 5, 10, 20, 50, 100, 200, 500],
99 | ['Auto', '1', '2', '5', '10', '20', '50', '100', '200', '500'])
100 | plt.xlabel('Number of points')
101 | plt.ylabel('PSNR [dB]')
102 | plt.legend(loc=0)
103 | plt.xlim((num_points_hack[0], num_points_hack[-1]))
104 | plt.savefig('./checkpoints/%s/sweep_%s.png' % (opt.name, str_now))
105 | -------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from models import create_model
4 | from util.visualizer import Visualizer
5 |
6 | import torch
7 | import torchvision
8 | import torchvision.transforms as transforms
9 |
10 | from util import util
11 |
12 | if __name__ == '__main__':
13 | opt = TrainOptions().parse()
14 |
15 | opt.dataroot = './dataset/ilsvrc2012/%s/' % opt.phase
16 | dataset = torchvision.datasets.ImageFolder(opt.dataroot,
17 | transform=transforms.Compose([
18 | transforms.RandomChoice([transforms.Resize(opt.loadSize, interpolation=1),
19 | transforms.Resize(opt.loadSize, interpolation=2),
20 | transforms.Resize(opt.loadSize, interpolation=3),
21 | transforms.Resize((opt.loadSize, opt.loadSize), interpolation=1),
22 | transforms.Resize((opt.loadSize, opt.loadSize), interpolation=2),
23 | transforms.Resize((opt.loadSize, opt.loadSize), interpolation=3)]),
24 | transforms.RandomChoice([transforms.RandomResizedCrop(opt.fineSize, interpolation=1),
25 | transforms.RandomResizedCrop(opt.fineSize, interpolation=2),
26 | transforms.RandomResizedCrop(opt.fineSize, interpolation=3)]),
27 | transforms.RandomHorizontalFlip(),
28 | transforms.ToTensor()]))
29 | dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.num_threads))
30 |
31 | dataset_size = len(dataset)
32 | print('#training images = %d' % dataset_size)
33 |
34 | model = create_model(opt)
35 | model.setup(opt)
36 | model.print_networks(True)
37 |
38 | visualizer = Visualizer(opt)
39 | total_steps = 0
40 |
41 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay):
42 | epoch_start_time = time.time()
43 | iter_data_time = time.time()
44 | epoch_iter = 0
45 |
46 | # for i, data in enumerate(dataset):
47 | for i, data_raw in enumerate(dataset_loader):
48 | data_raw[0] = data_raw[0].cuda()
49 | data = util.get_colorization_data(data_raw, opt, p=opt.sample_p)
50 | if(data is None):
51 | continue
52 |
53 | iter_start_time = time.time()
54 | if total_steps % opt.print_freq == 0:
55 | # time to load data
56 | t_data = iter_start_time - iter_data_time
57 | visualizer.reset()
58 | total_steps += opt.batch_size
59 | epoch_iter += opt.batch_size
60 | model.set_input(data)
61 | model.optimize_parameters()
62 |
63 | if total_steps % opt.display_freq == 0:
64 | save_result = total_steps % opt.update_html_freq == 0
65 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
66 |
67 | if total_steps % opt.print_freq == 0:
68 | losses = model.get_current_losses()
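# note: these values are exponentially smoothed with opt.avg_loss_alpha
# inside Pix2PixModel.get_current_losses(), not raw per-iteration losses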
69 | # time to do forward&backward 70 | t = time.time() - iter_start_time 71 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) 72 | if opt.display_id > 0: 73 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses) 74 | 75 | if total_steps % opt.save_latest_freq == 0: 76 | print('saving the latest model (epoch %d, total_steps %d)' % 77 | (epoch, total_steps)) 78 | model.save_networks('latest') 79 | 80 | iter_data_time = time.time() 81 | 82 | if epoch % opt.save_epoch_freq == 0: 83 | print('saving the model at the end of epoch %d, iters %d' % 84 | (epoch, total_steps)) 85 | model.save_networks('latest') 86 | model.save_networks(epoch) 87 | 88 | print('End of epoch %d / %d \t Time Taken: %d sec' % 89 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 90 | model.update_learning_rate() 91 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richzhang/colorization-pytorch/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/util/__init__.py -------------------------------------------------------------------------------- /util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """ 13 | 14 | Download CycleGAN or Pix2Pix Data. 15 | 16 | Args: 17 | technique : str 18 | One of: 'cyclegan' or 'pix2pix'. 19 | verbose : bool 20 | If True, print additional information. 21 | 22 | Examples: 23 | >>> from util.get_data import GetData 24 | >>> gd = GetData(technique='cyclegan') 25 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 
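>>> gd.get(save_path='./datasets', dataset='facades.zip')  # pass a file name (with extension) to skip the prompt; 'facades.zip' is just an example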
26 | 27 | """ 28 | 29 | def __init__(self, technique='cyclegan', verbose=True): 30 | url_dict = { 31 | 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', 32 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 33 | } 34 | self.url = url_dict.get(technique.lower()) 35 | self._verbose = verbose 36 | 37 | def _print(self, text): 38 | if self._verbose: 39 | print(text) 40 | 41 | @staticmethod 42 | def _get_options(r): 43 | soup = BeautifulSoup(r.text, 'lxml') 44 | options = [h.text for h in soup.find_all('a', href=True) 45 | if h.text.endswith(('.zip', 'tar.gz'))] 46 | return options 47 | 48 | def _present_options(self): 49 | r = requests.get(self.url) 50 | options = self._get_options(r) 51 | print('Options:\n') 52 | for i, o in enumerate(options): 53 | print("{0}: {1}".format(i, o)) 54 | choice = input("\nPlease enter the number of the " 55 | "dataset above you wish to download:") 56 | return options[int(choice)] 57 | 58 | def _download_data(self, dataset_url, save_path): 59 | if not isdir(save_path): 60 | os.makedirs(save_path) 61 | 62 | base = basename(dataset_url) 63 | temp_save_path = join(save_path, base) 64 | 65 | with open(temp_save_path, "wb") as f: 66 | r = requests.get(dataset_url) 67 | f.write(r.content) 68 | 69 | if base.endswith('.tar.gz'): 70 | obj = tarfile.open(temp_save_path) 71 | elif base.endswith('.zip'): 72 | obj = ZipFile(temp_save_path, 'r') 73 | else: 74 | raise ValueError("Unknown File Type: {0}.".format(base)) 75 | 76 | self._print("Unpacking Data...") 77 | obj.extractall(save_path) 78 | obj.close() 79 | os.remove(temp_save_path) 80 | 81 | def get(self, save_path, dataset=None): 82 | """ 83 | 84 | Download a dataset. 85 | 86 | Args: 87 | save_path : str 88 | A directory to save the data to. 89 | dataset : str, optional 90 | A specific dataset to download. 91 | Note: this must include the file extension. 92 | If None, options will be presented for you 93 | to choose from. 94 | 95 | Returns: 96 | save_path_full : str 97 | The absolute path to the downloaded data. 98 | 99 | """ 100 | if dataset is None: 101 | selected_dataset = self._present_options() 102 | else: 103 | selected_dataset = dataset 104 | 105 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 106 | 107 | if isdir(save_path_full): 108 | warn("\n'{0}' already exists. 
Voiding Download.".format(
109 | save_path_full))
110 | else:
111 | self._print('Downloading Data...')
112 | url = "{0}/{1}".format(self.url, selected_dataset)
113 | self._download_data(url, save_path=save_path)
114 |
115 | return abspath(save_path_full)
116 | -------------------------------------------------------------------------------- /util/html.py: --------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, refresh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 | # print(self.img_dir)
16 |
17 | self.doc = dominate.document(title=title)
18 | if refresh > 0:
19 | with self.doc.head:
20 | meta(http_equiv="refresh", content=str(refresh))
21 |
22 | def get_image_dir(self):
23 | return self.img_dir
24 |
25 | def add_header(self, str):
26 | with self.doc:
27 | h3(str)
28 |
29 | def add_table(self, border=1):
30 | self.t = table(border=border, style="table-layout: fixed;")
31 | self.doc.add(self.t)
32 |
33 | def add_images(self, ims, txts, links, width=400):
34 | self.add_table()
35 | with self.t:
36 | with tr():
37 | for im, txt, link in zip(ims, txts, links):
38 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
39 | with p():
40 | with a(href=os.path.join('images', link)):
41 | img(style="width:%dpx" % width, src=os.path.join('images', im))
42 | br()
43 | p(txt)
44 |
45 | def save(self):
46 | html_file = '%s/index.html' % self.web_dir
47 | f = open(html_file, 'wt')
48 | f.write(self.doc.render())
49 | f.close()
50 |
51 |
52 | if __name__ == '__main__':
53 | html = HTML('web/', 'test_html')
54 | html.add_header('hello world')
55 |
56 | ims = []
57 | txts = []
58 | links = []
59 | for n in range(4):
60 | ims.append('image_%d.png' % n)
61 | txts.append('text_%d' % n)
62 | links.append('image_%d.png' % n)
63 | html.add_images(ims, txts, links)
64 | html.save()
65 | -------------------------------------------------------------------------------- /util/image_pool.py: --------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 |
5 | class ImagePool():
6 | def __init__(self, pool_size):
7 | self.pool_size = pool_size
8 | if self.pool_size > 0:
9 | self.num_imgs = 0
10 | self.images = []
11 |
12 | def query(self, images):
13 | if self.pool_size == 0:
14 | return images
15 | return_images = []
16 | for image in images:
17 | image = torch.unsqueeze(image.data, 0)
18 | if self.num_imgs < self.pool_size:
19 | self.num_imgs = self.num_imgs + 1
20 | self.images.append(image)
21 | return_images.append(image)
22 | else:
23 | p = random.uniform(0, 1)
24 | if p > 0.5:
25 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
26 | tmp = self.images[random_id].clone()
27 | self.images[random_id] = image
28 | return_images.append(tmp)
29 | else:
30 | return_images.append(image)
31 | return_images = torch.cat(return_images, 0)
32 | return return_images
33 | -------------------------------------------------------------------------------- /util/util.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import os
6
| from collections import OrderedDict 7 | from IPython import embed 8 | 9 | # Converts a Tensor into an image array (numpy) 10 | # |imtype|: the desired type of the converted numpy array 11 | def tensor2im(input_image, imtype=np.uint8): 12 | if isinstance(input_image, torch.Tensor): 13 | image_tensor = input_image.data 14 | else: 15 | return input_image 16 | image_numpy = image_tensor[0].cpu().float().numpy() 17 | if image_numpy.shape[0] == 1: 18 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 19 | image_numpy = np.clip((np.transpose(image_numpy, (1, 2, 0)) ),0, 1) * 255.0 20 | return image_numpy.astype(imtype) 21 | 22 | 23 | def diagnose_network(net, name='network'): 24 | mean = 0.0 25 | count = 0 26 | for param in net.parameters(): 27 | if param.grad is not None: 28 | mean += torch.mean(torch.abs(param.grad.data)) 29 | count += 1 30 | if count > 0: 31 | mean = mean / count 32 | print(name) 33 | print(mean) 34 | 35 | 36 | def save_image(image_numpy, image_path): 37 | image_pil = Image.fromarray(image_numpy) 38 | image_pil.save(image_path) 39 | 40 | 41 | def print_numpy(x, val=True, shp=False): 42 | x = x.astype(np.float64) 43 | if shp: 44 | print('shape,', x.shape) 45 | if val: 46 | x = x.flatten() 47 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 48 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 49 | 50 | 51 | def mkdirs(paths): 52 | if isinstance(paths, list) and not isinstance(paths, str): 53 | for path in paths: 54 | mkdir(path) 55 | else: 56 | mkdir(paths) 57 | 58 | 59 | def mkdir(path): 60 | if not os.path.exists(path): 61 | os.makedirs(path) 62 | 63 | 64 | def get_subset_dict(in_dict,keys): 65 | if(len(keys)): 66 | subset = OrderedDict() 67 | for key in keys: 68 | subset[key] = in_dict[key] 69 | else: 70 | subset = in_dict 71 | return subset 72 | 73 | 74 | 75 | # Color conversion code 76 | def rgb2xyz(rgb): # rgb from [0,1] 77 | # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], 78 | # [0.212671, 0.715160, 0.072169], 79 | # [0.019334, 0.119193, 0.950227]]) 80 | 81 | mask = (rgb > .04045).type(torch.FloatTensor) 82 | if(rgb.is_cuda): 83 | mask = mask.cuda() 84 | 85 | rgb = (((rgb+.055)/1.055)**2.4)*mask + rgb/12.92*(1-mask) 86 | 87 | x = .412453*rgb[:,0,:,:]+.357580*rgb[:,1,:,:]+.180423*rgb[:,2,:,:] 88 | y = .212671*rgb[:,0,:,:]+.715160*rgb[:,1,:,:]+.072169*rgb[:,2,:,:] 89 | z = .019334*rgb[:,0,:,:]+.119193*rgb[:,1,:,:]+.950227*rgb[:,2,:,:] 90 | out = torch.cat((x[:,None,:,:],y[:,None,:,:],z[:,None,:,:]),dim=1) 91 | 92 | # if(torch.sum(torch.isnan(out))>0): 93 | # print('rgb2xyz') 94 | # embed() 95 | return out 96 | 97 | def xyz2rgb(xyz): 98 | # array([[ 3.24048134, -1.53715152, -0.49853633], 99 | # [-0.96925495, 1.87599 , 0.04155593], 100 | # [ 0.05564664, -0.20404134, 1.05731107]]) 101 | 102 | r = 3.24048134*xyz[:,0,:,:]-1.53715152*xyz[:,1,:,:]-0.49853633*xyz[:,2,:,:] 103 | g = -0.96925495*xyz[:,0,:,:]+1.87599*xyz[:,1,:,:]+.04155593*xyz[:,2,:,:] 104 | b = .05564664*xyz[:,0,:,:]-.20404134*xyz[:,1,:,:]+1.05731107*xyz[:,2,:,:] 105 | 106 | rgb = torch.cat((r[:,None,:,:],g[:,None,:,:],b[:,None,:,:]),dim=1) 107 | rgb = torch.max(rgb,torch.zeros_like(rgb)) # sometimes reaches a small negative number, which causes NaNs 108 | 109 | mask = (rgb > .0031308).type(torch.FloatTensor) 110 | if(rgb.is_cuda): 111 | mask = mask.cuda() 112 | 113 | rgb = (1.055*(rgb**(1./2.4)) - 0.055)*mask + 12.92*rgb*(1-mask) 114 | 115 | # if(torch.sum(torch.isnan(rgb))>0): 116 | # print('xyz2rgb') 117 | # embed() 118 | return rgb 119 | 120 | def xyz2lab(xyz): 
121 | # 0.95047, 1., 1.08883 # white 122 | sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None] 123 | if(xyz.is_cuda): 124 | sc = sc.cuda() 125 | 126 | xyz_scale = xyz/sc 127 | 128 | mask = (xyz_scale > .008856).type(torch.FloatTensor) 129 | if(xyz_scale.is_cuda): 130 | mask = mask.cuda() 131 | 132 | xyz_int = xyz_scale**(1/3.)*mask + (7.787*xyz_scale + 16./116.)*(1-mask) 133 | 134 | L = 116.*xyz_int[:,1,:,:]-16. 135 | a = 500.*(xyz_int[:,0,:,:]-xyz_int[:,1,:,:]) 136 | b = 200.*(xyz_int[:,1,:,:]-xyz_int[:,2,:,:]) 137 | out = torch.cat((L[:,None,:,:],a[:,None,:,:],b[:,None,:,:]),dim=1) 138 | 139 | # if(torch.sum(torch.isnan(out))>0): 140 | # print('xyz2lab') 141 | # embed() 142 | 143 | return out 144 | 145 | def lab2xyz(lab): 146 | y_int = (lab[:,0,:,:]+16.)/116. 147 | x_int = (lab[:,1,:,:]/500.) + y_int 148 | z_int = y_int - (lab[:,2,:,:]/200.) 149 | if(z_int.is_cuda): 150 | z_int = torch.max(torch.Tensor((0,)).cuda(), z_int) 151 | else: 152 | z_int = torch.max(torch.Tensor((0,)), z_int) 153 | 154 | out = torch.cat((x_int[:,None,:,:],y_int[:,None,:,:],z_int[:,None,:,:]),dim=1) 155 | mask = (out > .2068966).type(torch.FloatTensor) 156 | if(out.is_cuda): 157 | mask = mask.cuda() 158 | 159 | out = (out**3.)*mask + (out - 16./116.)/7.787*(1-mask) 160 | 161 | sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None] 162 | sc = sc.to(out.device) 163 | 164 | out = out*sc 165 | 166 | # if(torch.sum(torch.isnan(out))>0): 167 | # print('lab2xyz') 168 | # embed() 169 | 170 | return out 171 | 172 | def rgb2lab(rgb, opt): 173 | lab = xyz2lab(rgb2xyz(rgb)) 174 | l_rs = (lab[:,[0],:,:]-opt.l_cent)/opt.l_norm 175 | ab_rs = lab[:,1:,:,:]/opt.ab_norm 176 | out = torch.cat((l_rs,ab_rs),dim=1) 177 | # if(torch.sum(torch.isnan(out))>0): 178 | # print('rgb2lab') 179 | # embed() 180 | return out 181 | 182 | def lab2rgb(lab_rs, opt): 183 | l = lab_rs[:,[0],:,:]*opt.l_norm + opt.l_cent 184 | ab = lab_rs[:,1:,:,:]*opt.ab_norm 185 | lab = torch.cat((l,ab),dim=1) 186 | out = xyz2rgb(lab2xyz(lab)) 187 | # if(torch.sum(torch.isnan(out))>0): 188 | # print('lab2rgb') 189 | # embed() 190 | return out 191 | 192 | def get_colorization_data(data_raw, opt, ab_thresh=5., p=.125, num_points=None): 193 | data = {} 194 | 195 | data_lab = rgb2lab(data_raw[0], opt) 196 | data['A'] = data_lab[:,[0,],:,:] 197 | data['B'] = data_lab[:,1:,:,:] 198 | 199 | if(ab_thresh > 0): # mask out grayscale images 200 | thresh = 1.*ab_thresh/opt.ab_norm 201 | mask = torch.sum(torch.abs(torch.max(torch.max(data['B'],dim=3)[0],dim=2)[0]-torch.min(torch.min(data['B'],dim=3)[0],dim=2)[0]),dim=1) >= thresh 202 | data['A'] = data['A'][mask,:,:,:] 203 | data['B'] = data['B'][mask,:,:,:] 204 | # print('Removed %i points'%torch.sum(mask==0).numpy()) 205 | if(torch.sum(mask)==0): 206 | return None 207 | 208 | return add_color_patches_rand_gt(data, opt, p=p, num_points=num_points) 209 | 210 | def add_color_patches_rand_gt(data,opt,p=.125,num_points=None,use_avg=True,samp='normal'): 211 | # Add random color points sampled from ground truth based on: 212 | # Number of points 213 | # - if num_points is 0, then sample from geometric distribution, drawn from probability p 214 | # - if num_points > 0, then sample that number of points 215 | # Location of points 216 | # - if samp is 'normal', draw from N(0.5, 0.25) of image 217 | # - otherwise, draw from U[0, 1] of image 218 | N,C,H,W = data['B'].shape 219 | 220 | data['hint_B'] = torch.zeros_like(data['B']) 221 | data['mask_B'] = torch.zeros_like(data['A']) 222 | 223 | for nn in range(N): 224 | pp = 0 225 
| cont_cond = True
226 | while(cont_cond):
227 | if(num_points is None): # draw from geometric
228 | # embed()
229 | cont_cond = np.random.rand() < (1-p)
230 | else: # add certain number of points
231 | cont_cond = pp < num_points
232 | if(not cont_cond): # skip out of loop if condition not met
233 | continue
234 |
235 | P = np.random.choice(opt.sample_Ps) # patch size
236 |
237 | # sample location
238 | if(samp=='normal'): # normal distribution, centered on the image
239 | h = int(np.clip(np.random.normal( (H-P+1)/2., (H-P+1)/4.), 0, H-P))
240 | w = int(np.clip(np.random.normal( (W-P+1)/2., (W-P+1)/4.), 0, W-P))
241 | else: # uniform distribution
242 | h = np.random.randint(H-P+1)
243 | w = np.random.randint(W-P+1)
244 |
245 | # add color point
246 | if(use_avg):
247 | # embed()
248 | data['hint_B'][nn,:,h:h+P,w:w+P] = torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=1,keepdim=True).view(1,C,1,1)
249 | else:
250 | data['hint_B'][nn,:,h:h+P,w:w+P] = data['B'][nn,:,h:h+P,w:w+P]
251 |
252 | data['mask_B'][nn,:,h:h+P,w:w+P] = 1
253 |
254 | # increment counter
255 | pp+=1
256 |
257 | data['mask_B']-=opt.mask_cent
258 |
259 | return data
260 |
261 | def add_color_patch(data,mask,opt,P=1,hw=[128,128],ab=[0,0]):
262 | # Add a color patch at (h,w) with color (a,b)
263 | data[:,0,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1.*ab[0]/opt.ab_norm
264 | data[:,1,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1.*ab[1]/opt.ab_norm
265 | mask[:,:,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1-opt.mask_cent
266 |
267 | return (data,mask)
268 |
269 | def crop_mult(data,mult=16,HWmax=[800,1200]):
270 | # center-crop image so height/width are multiples of mult (integer arithmetic keeps valid slice indices)
271 | H,W = data.shape[2:]
272 | Hnew = int(min(H//mult*mult,HWmax[0]))
273 | Wnew = int(min(W//mult*mult,HWmax[1]))
274 | h = (H-Hnew)//2
275 | w = (W-Wnew)//2
276 |
277 | return data[:,:,h:h+Hnew,w:w+Wnew]
278 |
279 | def encode_ab_ind(data_ab, opt):
280 | # Encode ab value into an index
281 | # INPUTS
282 | # data_ab Nx2xHxW \in [-1,1]
283 | # OUTPUTS
284 | # data_q Nx1xHxW \in [0,Q)
285 |
286 | data_ab_rs = torch.round((data_ab*opt.ab_norm + opt.ab_max)/opt.ab_quant) # normalized bin number
287 | data_q = data_ab_rs[:,[0],:,:]*opt.A + data_ab_rs[:,[1],:,:]
288 | return data_q
289 |
290 | def decode_ind_ab(data_q, opt):
291 | # Decode index into ab value
292 | # INPUTS
293 | # data_q Nx1xHxW \in [0,Q)
294 | # OUTPUTS
295 | # data_ab Nx2xHxW \in [-1,1]
296 |
297 | data_a = torch.floor(data_q.float()/opt.A) # floor division recovers the a-bin index from the joint index
298 | data_b = data_q.float() - data_a*opt.A # remainder is the b-bin index
299 | data_ab = torch.cat((data_a,data_b),dim=1)
300 |
301 | if(data_q.is_cuda):
302 | type_out = torch.cuda.FloatTensor
303 | else:
304 | type_out = torch.FloatTensor
305 | data_ab = ((data_ab.type(type_out)*opt.ab_quant) - opt.ab_max)/opt.ab_norm
306 |
307 | return data_ab
308 |
309 | def decode_max_ab(data_ab_quant, opt):
310 | # Decode probability distribution by using bin with highest probability
311 | # INPUTS
312 | # data_ab_quant NxQxHxW \in [0,1]
313 | # OUTPUTS
314 | # data_ab Nx2xHxW \in [-1,1]
315 |
316 | data_q = torch.argmax(data_ab_quant,dim=1)[:,None,:,:]
317 | return decode_ind_ab(data_q, opt)
318 |
319 | def decode_mean(data_ab_quant, opt):
320 | # Decode probability distribution by taking mean over all bins
321 | # INPUTS
322 | # data_ab_quant NxQxHxW \in [0,1]
323 | # OUTPUTS
324 | # data_ab_inf Nx2xHxW \in [-1,1]
325 |
326 | (N,Q,H,W) = data_ab_quant.shape
327 | a_range = torch.range(-opt.ab_max, opt.ab_max, step=opt.ab_quant).to(data_ab_quant.device)[None,:,None,None] # torch.range is endpoint-inclusive (opt.A values); torch.arange would drop the last bin
328 | a_range = a_range.type(data_ab_quant.type())
329 |
330 | # reshape to AB space
331 | data_ab_quant =
309 | def decode_max_ab(data_ab_quant, opt):
310 |     # Decode probability distribution by using the bin with the highest probability
311 |     # INPUTS
312 |     #     data_ab_quant    NxQxHxW \in [0,1]
313 |     # OUTPUTS
314 |     #     data_ab          Nx2xHxW \in [-1,1]
315 | 
316 |     data_q = torch.argmax(data_ab_quant,dim=1)[:,None,:,:]
317 |     return decode_ind_ab(data_q, opt)
318 | 
319 | def decode_mean(data_ab_quant, opt):
320 |     # Decode probability distribution by taking its mean over all bins
321 |     # INPUTS
322 |     #     data_ab_quant    NxQxHxW \in [0,1]
323 |     # OUTPUTS
324 |     #     data_ab_inf      Nx2xHxW \in [-1,1]
325 | 
326 |     (N,Q,H,W) = data_ab_quant.shape
327 |     a_range = torch.arange(-opt.ab_max, opt.ab_max+opt.ab_quant, step=opt.ab_quant).to(data_ab_quant.device)[None,:,None,None]  # arange excludes the endpoint, so extend by one step (torch.range is deprecated)
328 |     a_range = a_range.type(data_ab_quant.type())
329 | 
330 |     # reshape to AB space
331 |     data_ab_quant = data_ab_quant.view((N,int(opt.A),int(opt.A),H,W))
332 |     data_a_total = torch.sum(data_ab_quant,dim=2)  # marginal over b bins
333 |     data_b_total = torch.sum(data_ab_quant,dim=1)  # marginal over a bins
334 | 
335 |     # expected value of each marginal
336 |     data_a_inf = torch.sum(data_a_total * a_range,dim=1,keepdim=True)
337 |     data_b_inf = torch.sum(data_b_total * a_range,dim=1,keepdim=True)
338 | 
339 |     data_ab_inf = torch.cat((data_a_inf,data_b_inf),dim=1)/opt.ab_norm
340 | 
341 |     return data_ab_inf
342 | 
343 | def calculate_psnr_np(img1, img2):
344 |     # PSNR for 8-bit numpy images with values in [0, 255]
345 |     SE_map = (1.*img1-img2)**2
346 |     cur_MSE = np.mean(SE_map)
347 |     return 20*np.log10(255./np.sqrt(cur_MSE))
348 | 
349 | def calculate_psnr_torch(img1, img2):
350 |     # PSNR for torch images with values in [0, 1]
351 |     SE_map = (1.*img1-img2)**2
352 |     cur_MSE = torch.mean(SE_map)
353 |     return 20*torch.log10(1./torch.sqrt(cur_MSE))
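# --- Illustrative example (not part of the original file) ----------------------
# The two PSNR helpers assume different dynamic ranges: calculate_psnr_np
# expects 8-bit images in [0, 255], calculate_psnr_torch floats in [0, 1].
# On matching inputs they agree, since 20*log10(255/sqrt(MSE)) equals
# 20*log10(1/sqrt(MSE/255**2)).
if __name__ == '__main__':
    img = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
    noisy = np.clip(img + np.random.randn(64, 64, 3)*5., 0., 255.)
    print(calculate_psnr_np(img, noisy))       # peak signal = 255
    print(calculate_psnr_torch(torch.from_numpy(img/255.), torch.from_numpy(noisy/255.)))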
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 | import time
5 | from . import util
6 | from . import html
7 | from scipy.misc import imresize  # note: removed in SciPy >= 1.3; pin an older SciPy or port to PIL
8 | 
9 | 
10 | # save image to the disk
11 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
12 |     image_dir = webpage.get_image_dir()
13 |     short_path = ntpath.basename(image_path[0])
14 |     name = os.path.splitext(short_path)[0]
15 | 
16 |     webpage.add_header(name)
17 |     ims, txts, links = [], [], []
18 | 
19 |     for label, im_data in visuals.items():
20 |         im = util.tensor2im(im_data)
21 |         image_name = '%s_%s.png' % (name, label)
22 |         save_path = os.path.join(image_dir, image_name)
23 |         h, w, _ = im.shape
24 |         if aspect_ratio > 1.0:
25 |             im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
26 |         if aspect_ratio < 1.0:
27 |             im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
28 |         util.save_image(im, save_path)
29 | 
30 |         ims.append(image_name)
31 |         txts.append(label)
32 |         links.append(image_name)
33 |     webpage.add_images(ims, txts, links, width=width)
34 | 
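# --- Illustrative example (not part of the original file) ----------------------
# Minimal save_images usage. `visuals` maps labels to image tensors, which
# save_images converts via util.tensor2im. The html.HTML constructor arguments
# mirror the call inside Visualizer below; treat the exact paths as assumptions.
if __name__ == '__main__':
    import torch
    webpage = html.HTML('./results/demo', 'demo experiment')
    visuals = {'real': torch.rand(1, 3, 256, 256), 'fake': torch.rand(1, 3, 256, 256)}
    save_images(webpage, visuals, ['demo_image.png'])  # image_path is a list; only [0] is used
    webpage.save()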
35 | 
36 | class Visualizer():
37 |     def __init__(self, opt):
38 |         self.display_id = opt.display_id
39 |         self.use_html = opt.isTrain and not opt.no_html
40 |         self.win_size = opt.display_winsize
41 |         self.name = opt.name
42 |         self.opt = opt
43 |         self.saved = False
44 |         if self.display_id > 0:
45 |             import visdom
46 |             self.ncols = opt.display_ncols
47 |             self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port)
48 | 
49 |         if self.use_html:
50 |             self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
51 |             self.img_dir = os.path.join(self.web_dir, 'images')
52 |             print('create web directory %s...' % self.web_dir)
53 |             util.mkdirs([self.web_dir, self.img_dir])
54 |         self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
55 |         with open(self.log_name, "a") as log_file:
56 |             now = time.strftime("%c")
57 |             log_file.write('================ Training Loss (%s) ================\n' % now)
58 | 
59 |     def reset(self):
60 |         self.saved = False
61 | 
62 |     # |visuals|: dictionary of images to display or save
63 |     def display_current_results(self, visuals, epoch, save_result):
64 |         if self.display_id > 0:  # show images in the browser
65 |             ncols = self.ncols
66 |             if ncols > 0:
67 |                 ncols = min(ncols, len(visuals))
68 |                 h, w = next(iter(visuals.values())).shape[:2]
69 |                 table_css = """<style>
70 |                         table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
71 |                         table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
72 |                         </style>""" % (w, h)
73 |                 title = self.name
74 |                 label_html = ''
75 |                 label_html_row = ''
76 |                 images = []
77 |                 idx = 0
78 |                 for label, image in visuals.items():
79 |                     image_numpy = util.tensor2im(image)
80 |                     label_html_row += '<td>%s</td>' % label
81 |                     images.append(image_numpy.transpose([2, 0, 1]))
82 |                     idx += 1
83 |                     if idx % ncols == 0:
84 |                         label_html += '<tr>%s</tr>' % label_html_row
85 |                         label_html_row = ''
86 |                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
87 |                 while idx % ncols != 0:
88 |                     images.append(white_image)
89 |                     label_html_row += '<td></td>'
90 |                     idx += 1
91 |                 if label_html_row != '':
92 |                     label_html += '<tr>%s</tr>' % label_html_row
93 |                 # pane col = image row
94 |                 self.vis.images(images, nrow=ncols, win=self.display_id + 1,
95 |                                 padding=2, opts=dict(title=title + ' images'))
96 |                 label_html = '<table>%s</table>' % label_html
97 |                 self.vis.text(table_css + label_html, win=self.display_id + 2,
98 |                               opts=dict(title=title + ' labels'))
99 |             else:
100 |                 idx = 1
101 |                 for label, image in visuals.items():
102 |                     image_numpy = util.tensor2im(image)
103 |                     self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
104 |                                    win=self.display_id + idx)
105 |                     idx += 1
106 | 
107 |         if self.use_html and (save_result or not self.saved):  # save images to an HTML file
108 |             self.saved = True
109 |             for label, image in visuals.items():
110 |                 image_numpy = util.tensor2im(image)
111 |                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
112 |                 util.save_image(image_numpy, img_path)
113 |             # update website
114 |             webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
115 |             for n in range(epoch, 0, -1):
116 |                 webpage.add_header('epoch [%d]' % n)
117 |                 ims, txts, links = [], [], []
118 | 
119 |                 for label, image_numpy in visuals.items():
120 |                     image_numpy = util.tensor2im(image_numpy)  # convert the tensor itself, not a stale `image` reference
121 |                     img_path = 'epoch%.3d_%s.png' % (n, label)
122 |                     ims.append(img_path)
123 |                     txts.append(label)
124 |                     links.append(img_path)
125 |                 webpage.add_images(ims, txts, links, width=self.win_size)
126 |             webpage.save()
127 | 
128 |     # losses: dictionary of error labels and values
129 |     def plot_current_losses(self, epoch, counter_ratio, opt, losses):
130 |         if not hasattr(self, 'plot_data'):
131 |             self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
132 |         self.plot_data['X'].append(epoch + counter_ratio)
133 |         self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
134 |         self.vis.line(
135 |             X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
136 |             Y=np.array(self.plot_data['Y']),
137 |             opts={
138 |                 'title': self.name + ' loss over time',
139 |                 'legend': self.plot_data['legend'],
140 |                 'xlabel': 'epoch',
141 |                 'ylabel': 'loss'},
142 |             win=self.display_id)
143 | 
144 |     # losses: same format as |losses| of plot_current_losses
145 |     def print_current_losses(self, epoch, i, losses, t, t_data):
146 |         message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
147 |         for k, v in losses.items():
148 |             message += '%s: %.3f, ' % (k, v)
149 | 
150 |         print(message)
151 |         with open(self.log_name, "a") as log_file:
152 |             log_file.write('%s\n' % message)
153 | 
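# --- Illustrative example (not part of the original file) ----------------------
# Minimal standalone use of print_current_losses. The opt fields below are
# exactly the ones __init__ reads; display_id=0 and no_html=True skip the
# visdom and HTML setup. The loss names and values are placeholders.
if __name__ == '__main__':
    from argparse import Namespace
    os.makedirs('./checkpoints/demo', exist_ok=True)  # log file lives under checkpoints_dir/name
    opt = Namespace(display_id=0, isTrain=True, no_html=True, display_winsize=256,
                    name='demo', checkpoints_dir='./checkpoints')
    vis = Visualizer(opt)
    vis.print_current_losses(epoch=1, i=100, losses={'G_CE': 2.31, 'G_L1_reg': 0.04},
                             t=0.12, t_data=0.01)
--------------------------------------------------------------------------------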