├── checkpoints
│   └── imagenet
│       └── README.md
├── imgs
│   ├── img1.png
│   ├── img2.png
│   ├── img3.png
│   ├── img4.png
│   ├── img5.png
│   ├── img6.png
│   ├── img7.png
│   └── img8.png
├── requirements.txt
├── util
│   ├── __init__.py
│   ├── html.py
│   ├── util.py
│   └── visualizer.py
├── data
│   ├── colorization_dataset.py
│   ├── __init__.py
│   └── base_dataset.py
├── models
│   ├── main_model.py
│   ├── colorization_model.py
│   ├── __init__.py
│   ├── base_model.py
│   └── networks.py
├── options
│   ├── __init__.py
│   ├── test_options.py
│   └── base_options.py
├── test.sh
├── test.py
└── README.md
--------------------------------------------------------------------------------
/checkpoints/imagenet/README.md:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/imgs/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CV-xueba/Gray2ColorNet/master/imgs/img1.png
--------------------------------------------------------------------------------
/imgs/img2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CV-xueba/Gray2ColorNet/master/imgs/img2.png
--------------------------------------------------------------------------------
/imgs/img3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CV-xueba/Gray2ColorNet/master/imgs/img3.png
--------------------------------------------------------------------------------
/imgs/img4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CV-xueba/Gray2ColorNet/master/imgs/img4.png
--------------------------------------------------------------------------------
/imgs/img5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CV-xueba/Gray2ColorNet/master/imgs/img5.png
--------------------------------------------------------------------------------
/imgs/img6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CV-xueba/Gray2ColorNet/master/imgs/img6.png
--------------------------------------------------------------------------------
/imgs/img7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CV-xueba/Gray2ColorNet/master/imgs/img7.png
--------------------------------------------------------------------------------
/imgs/img8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CV-xueba/Gray2ColorNet/master/imgs/img8.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=0.4.1
2 | torchvision>=0.2.1
3 | dominate>=2.3.1
4 | visdom>=0.1.8.3
5 | 
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 | 
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes the option modules: training options, test options, and basic options (used in both training and test)."""
2 | 
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | #! /bin/sh
2 | 
3 | if [ -f "./test.out" ]
4 | then
5 |     rm "./test.out"
6 | fi
7 | nohup python3 -u test.py --dataroot ./datasets/imagenet \
8 |     --use_D \
9 |     --preprocess none \
10 |     --gpu_ids 0 >test.out 2>&1 &
11 | 
12 | 
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TestOptions(BaseOptions):
5 |     """This class includes test options.
6 | 
7 |     It also includes shared options defined in BaseOptions.
8 |     """
9 | 
10 |     def initialize(self, parser):
11 |         parser = BaseOptions.initialize(self, parser)  # define shared options
12 |         parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 |         parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 |         parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 |         parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 |         # Dropout and BatchNorm have different behavior during training and test.
17 |         parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 |         parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 |         # rewrite default values
20 |         parser.set_defaults(model='test')
21 |         # To avoid cropping, the load_size should be the same as crop_size
22 |         parser.set_defaults(load_size=parser.get_default('crop_size'))
23 |         self.isTrain = False
24 |         return parser
25 | 
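The two `set_defaults` calls above are plain argparse mechanics: `get_default('crop_size')` reads the default that `BaseOptions.initialize` registered (128), and pinning `load_size` to it means `get_transform` never needs to crop at test time. A minimal, self-contained sketch of the mechanism (toy option names, not part of this repository):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--crop_size', type=int, default=128)
parser.add_argument('--load_size', type=int, default=144)

# Pin load_size to whatever crop_size currently defaults to,
# without overriding a value the user passes explicitly.
parser.set_defaults(load_size=parser.get_default('crop_size'))

opt = parser.parse_args([])
assert opt.load_size == opt.crop_size == 128
```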
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from options.test_options import TestOptions
3 | from data import create_dataset
4 | from models import create_model
5 | from util.visualizer import save_images
6 | from util import html
7 | import numpy as np
8 | 
9 | 
10 | if __name__ == '__main__':
11 |     opt = TestOptions().parse()
12 |     opt.num_threads = 0    # the test code only supports num_threads = 0
13 |     opt.batch_size = 1     # the test code only supports batch_size = 1
14 |     opt.serial_batches = True   # disable data shuffling
15 |     opt.no_flip = True     # no flip augmentation at test time
16 |     opt.display_id = -1    # no visdom display; results are written to an HTML page instead
17 |     dataset = create_dataset(opt)
18 |     model = create_model(opt)
19 |     model.setup(opt)
20 |     web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch))
21 |     webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
22 |     scores = []
23 |     if opt.eval:
24 |         model.eval()
25 |     for i, data in enumerate(dataset):
26 |         model.set_input(data)
27 |         model.test()
28 |         visuals = model.get_current_visuals()
29 |         img_path = model.get_image_paths()
30 |         metrics = model.compute_scores()
31 |         scores.extend(metrics)
32 |         if i % 5 == 0:
33 |             print('processing (%04d)-th image... %s' % (i, img_path))
34 |         save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
35 |     webpage.save()
36 |     print('Histogram Intersection: %.4f' % np.mean(scores))
37 | 
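The "Histogram Intersection" that test.py reports is the mean of the per-image scores from `ColorizationModel.compute_scores` (further down, in models/colorization_model.py), which compares the reference and generated ab-histograms with OpenCV. A self-contained sketch of that metric on toy histograms (values invented for illustration):

```python
import cv2
import numpy as np

# Two toy ab-histograms that each sum to 1; cv2.compareHist wants float32.
h_ref = np.array([0.2, 0.5, 0.3], dtype=np.float32)
h_gen = np.array([0.1, 0.6, 0.3], dtype=np.float32)

# HISTCMP_INTERSECT sums min(h_ref[i], h_gen[i]) over the bins: higher
# means more overlap, and 1.0 means the normalized histograms are identical.
score = cv2.compareHist(h_ref, h_gen, cv2.HISTCMP_INTERSECT)
print('%.4f' % score)  # 0.1 + 0.5 + 0.3 = 0.9000
```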
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Gray2ColorNet
2 | 
3 | This is a release of [Gray2ColorNet: Transfer More Colors from Reference Image](https://dl.acm.org/doi/10.1145/3394171.3413594).
4 | 
5 | ## Pretrained Models
6 | 
7 | You can download the pretrained model from [https://drive.google.com/file/d/1fgBjqrWLEGiAV60BqMV-LrGKVv_N2uqW/view?usp=sharing](https://drive.google.com/file/d/1fgBjqrWLEGiAV60BqMV-LrGKVv_N2uqW/view?usp=sharing).
8 | 
9 | Unzip the model weight files and move them to ./checkpoints/imagenet/.
10 | 
11 | ### How do I cite Gray2ColorNet?
12 | ```
13 | @inproceedings{lu2020gray2colornet,
14 |   title={Gray2ColorNet: Transfer More Colors from Reference Image},
15 |   author={Lu, Peng and Yu, Jinbei and Peng, Xujun and Zhao, Zhaoran and Wang, Xiaojie},
16 |   booktitle={Proceedings of the 28th ACM International Conference on Multimedia},
17 |   pages={3210--3218},
18 |   year={2020}
19 | }
20 | ```
21 | 
22 | ### Contact information
23 | 
24 | For help or issues using Gray2ColorNet, please submit a GitHub issue.
25 | 
26 | For personal communication related to Gray2ColorNet, please contact Peng Lu (lupeng@bupt.edu.cn).
27 | 
28 | ### Paper pictures
29 | 
30 | <img src='imgs/img1.png'/>
31 | <img src='imgs/img2.png'/>
32 | <img src='imgs/img3.png'/>
33 | <img src='imgs/img4.png'/>
34 | <img src='imgs/img5.png'/>
35 | <img src='imgs/img6.png'/>
36 | <img src='imgs/img7.png'/>
37 | <img src='imgs/img8.png'/>
38 | 
--------------------------------------------------------------------------------
/models/main_model.py:
--------------------------------------------------------------------------------
1 | from .base_model import BaseModel
2 | from . import networks
3 | from util import util
4 | 
5 | 
6 | class MainModel(BaseModel):
7 | 
8 |     @staticmethod
9 |     def modify_commandline_options(parser, is_train=True):
10 |         parser.set_defaults(norm='instance', dataset_mode='aligned')
11 |         return parser
12 | 
13 |     def __init__(self, opt):
14 |         BaseModel.__init__(self, opt)
15 |         self.visual_names = ['real_A', 'fake_B', 'real_B']
16 |         self.model_names = ['G']
17 |         self.netG = networks.define_G(opt.input_nc, opt.bias_input_nc, opt.output_nc, opt.norm,
18 |                                       opt.init_type, opt.init_gain, self.gpu_ids)
19 |         self.convert = util.Convert(self.device)
20 | 
21 |     def set_input(self, input):
22 |         self.image_paths = input['A_paths']
23 |         self.labels = input['labels'].squeeze(-1).to(self.device)
24 | 
25 |         self.real_A_l = []
26 |         self.real_A_ab = []
27 |         self.real_R_l = []
28 |         self.real_R_ab = []
29 |         self.real_R_histogram = []
30 | 
31 |         for i in range(3):  # three resolutions: quarter, half, full (see colorization_dataset.py)
32 |             self.real_A_l += [input['A_l'][i].to(self.device)]
33 |             self.real_A_ab += [input['A_ab'][i].to(self.device)]
34 |             self.real_R_l += [input['R_l'][i].to(self.device)]
35 |             self.real_R_ab += [input['R_ab'][i].to(self.device)]
36 |             self.real_R_histogram += [util.calc_hist(input['hist_ab'][i].to(self.device), self.device)]
37 | 
38 |     def forward(self):
39 |         self.fake_imgs, self.predicts, self.mask = self.netG(self.real_A_l[-1], self.real_R_histogram[-1], self.real_R_l[-1], self.real_R_ab[-1])  # G(A)
40 |         self.fake_R_histogram = []
41 |         for i in range(3):
42 |             self.fake_R_histogram += [util.calc_hist(self.fake_imgs[i], self.device)]
43 | 
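`set_input` keeps three resolutions of every image (built in data/colorization_dataset.py) and, for each level, a soft ab-color histogram from `util.calc_hist`: a 21 x 21 grid over the normalized ab plane, flattened to 441 values, which matches the `--bias_input_nc` default of 441 in options/base_options.py. A minimal shape check (toy tensor, CPU, run from the repository root):

```python
import torch
from util.util import calc_hist

device = torch.device('cpu')
ab = torch.zeros(2, 2, 32, 32)   # (N, 2, H, W), normalized ab values in [-1, 1]
hist = calc_hist(ab, device)
print(hist.shape)       # torch.Size([2, 441])  -> 21 x 21 bins, flattened
print(hist.sum(dim=1))  # each soft histogram sums to ~1 for in-range values
```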
--------------------------------------------------------------------------------
/models/colorization_model.py:
--------------------------------------------------------------------------------
1 | from .main_model import MainModel
2 | import torch
3 | from skimage import color  # used for lab2rgb
4 | import numpy as np
5 | import cv2
6 | 
7 | 
8 | class ColorizationModel(MainModel):
9 | 
10 |     @staticmethod
11 |     def modify_commandline_options(parser, is_train=True):
12 |         MainModel.modify_commandline_options(parser, is_train)
13 |         parser.set_defaults(dataset_mode='colorization')
14 |         return parser
15 | 
16 |     def __init__(self, opt):
17 |         MainModel.__init__(self, opt)
18 |         self.visual_names = ['real_A_l_0', 'real_A_rgb', 'real_R_rgb', 'fake_R_rgb', 'mask_gray']
19 | 
20 |     def lab2rgb(self, L, AB):
21 |         AB2 = AB * 110.0         # ab back to [-110, 110]
22 |         L2 = (L + 1.0) * 50.0    # L back to [0, 100]
23 |         Lab = torch.cat([L2, AB2], dim=1)
24 |         Lab = Lab[0].data.cpu().float().numpy()
25 |         Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
26 |         rgb = color.lab2rgb(Lab) * 255
27 |         return rgb
28 | 
29 |     def tensor2gray(self, im):
30 |         im = im[0].data.cpu().float().numpy()
31 |         im = np.transpose(im.astype(np.float64), (1, 2, 0))
32 |         im = np.repeat(im, 3, axis=-1) * 255
33 |         return im
34 | 
35 |     def compute_visuals(self):
36 |         self.real_A_l_0 = self.real_A_l[-1]
37 |         self.real_A_rgb = self.lab2rgb(self.real_A_l[-1], self.real_A_ab[-1])
38 |         self.real_R_rgb = self.lab2rgb(self.real_R_l[-1], self.real_R_ab[-1])
39 |         self.real_R_rgb = cv2.resize(self.real_R_rgb, (self.real_A_rgb.shape[1], self.real_A_rgb.shape[0]))
40 |         self.fake_R_rgb = []
41 |         for i in range(3):
42 |             self.fake_R_rgb += [self.lab2rgb(self.real_A_l[i], self.fake_imgs[i])]
43 |             if i != 2:  # upsample the two coarser levels to full resolution for display
44 |                 self.fake_R_rgb[i] = cv2.resize(self.fake_R_rgb[i], (self.real_A_rgb.shape[1], self.real_A_rgb.shape[0]))
45 |         self.mask_gray = cv2.resize(self.tensor2gray(self.mask), (self.real_A_rgb.shape[1], self.real_A_rgb.shape[0]), interpolation=cv2.INTER_LINEAR)
46 | 
47 |     def compute_scores(self):
48 |         metrics = []
49 |         hr = self.real_R_histogram[-1].data.cpu().float().numpy().flatten()
50 |         hg = self.fake_R_histogram[-1].data.cpu().float().numpy().flatten()
51 |         intersect = cv2.compareHist(hr, hg, cv2.HISTCMP_INTERSECT)
52 |         metrics.append(intersect)
53 | 
54 |         return metrics
55 | 
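`lab2rgb` above undoes the dataset normalization (L/50 - 1 and ab/110, applied in data/colorization_dataset.py) before handing a float64 HWC array to scikit-image. A quick sanity check of those constants on toy values:

```python
import numpy as np
from skimage import color

rgb = np.random.rand(4, 4, 3)            # toy image in [0, 1]
lab = color.rgb2lab(rgb).astype(np.float32)

# Normalize the way colorization_dataset.process_img does ...
l_norm = lab[..., 0] / 50.0 - 1.0        # L:  [0, 100]     -> [-1, 1]
ab_norm = lab[..., 1:] / 110.0           # ab: ~[-110, 110] -> ~[-1, 1]

# ... and invert it the way ColorizationModel.lab2rgb does.
lab_back = np.dstack([(l_norm + 1.0) * 50.0, ab_norm * 110.0])
assert np.allclose(lab, lab_back, atol=1e-3)
```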
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 | 
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 |     -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 |     -- <set_input>: unpack data from dataset and apply preprocessing.
7 |     -- <forward>: produce intermediate results.
8 |     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 | 
11 | In the function <__init__>, you need to define four lists:
12 |     -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 |     -- self.model_names (str list): define networks used in our training.
14 |     -- self.visual_names (str list): specify the images that you want to display and save.
15 |     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 | 
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 | 
21 | import importlib
22 | from models.base_model import BaseModel
23 | 
24 | 
25 | def find_model_using_name(model_name):
26 |     """Import the module "models/[model_name]_model.py".
27 | 
28 |     In the file, the class called ModelNameModel() will
29 |     be instantiated. It has to be a subclass of BaseModel,
30 |     and it is case-insensitive.
31 |     """
32 |     model_filename = "models." + model_name + "_model"
33 |     modellib = importlib.import_module(model_filename)
34 |     model = None
35 |     target_model_name = model_name.replace('_', '') + 'model'
36 |     for name, cls in modellib.__dict__.items():
37 |         if name.lower() == target_model_name.lower() \
38 |            and issubclass(cls, BaseModel):
39 |             model = cls
40 | 
41 |     if model is None:
42 |         print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 |         exit(1)  # exit with a non-zero status on failure
44 | 
45 |     return model
46 | 
47 | 
48 | def get_option_setter(model_name):
49 |     """Return the static method <modify_commandline_options> of the model class."""
50 |     model_class = find_model_using_name(model_name)
51 |     return model_class.modify_commandline_options
52 | 
53 | 
54 | def create_model(opt):
55 |     """Create a model given the option.
56 | 
57 |     This is the main interface between this package and 'train.py'/'test.py'.
58 | 
59 |     Example:
60 |         >>> from models import create_model
61 |         >>> model = create_model(opt)
62 |     """
63 |     model = find_model_using_name('colorization')
64 |     instance = model(opt)
65 |     print("model [%s] was created" % type(instance).__name__)
66 |     return instance
67 | 
--------------------------------------------------------------------------------
/data/colorization_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_params, get_transform
3 | from skimage import color  # requires scikit-image
4 | from PIL import Image
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | import torch
8 | import cv2
9 | 
10 | 
11 | class ColorizationDataset(BaseDataset):
12 |     """This dataset class can load a set of natural images in RGB, and convert RGB format into (L, ab) pairs in Lab color space."""
13 |     @staticmethod
14 |     def modify_commandline_options(parser, is_train):
15 |         """Add new dataset-specific options, and rewrite default values for existing options.
16 | 
17 |         Parameters:
18 |             parser          -- original option parser
19 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
20 | 
21 |         Returns:
22 |             the modified parser.
23 | 
24 |         By default, the number of channels for input image is 1 (L) and
25 |         the number of channels for output image is 2 (ab).
26 |         """
27 |         parser.set_defaults(input_nc=1, output_nc=2)
28 |         return parser
29 | 
30 |     def __init__(self, opt):
31 |         """Initialize this dataset class.
32 | 
33 |         Parameters:
34 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
35 |         """
36 |         BaseDataset.__init__(self, opt)
37 |         self.dir = os.path.join(opt.dataroot, opt.phase)
38 |         self.AB_paths = sorted(self.make_dataset())
39 |         self.transform_A = get_transform(self.opt, convert=False)
40 |         self.transform_R = get_transform(self.opt, convert=False, must_resize=True)
41 |         assert(opt.input_nc == 1 and opt.output_nc == 2)
42 | 
43 |     def __getitem__(self, index):
44 |         path_A, path_R = self.AB_paths[index]
45 |         im_A_l, im_A_ab = self.process_img(path_A, self.transform_A)
46 |         im_R_l, im_R_ab = self.process_img(path_R, self.transform_R)
47 |         hist_ab = im_R_ab
48 |         label = torch.Tensor([0]).long()
49 | 
50 |         im_dict = {
51 |             'A_l': im_A_l,
52 |             'A_ab': im_A_ab,
53 |             'R_l': im_R_l,
54 |             'R_ab': im_R_ab,
55 |             'hist_ab': hist_ab,
56 |             'labels': label,
57 |             'A_paths': path_A
58 |         }
59 |         return im_dict
60 | 
61 |     def make_dataset(self, max_dataset_size=float("inf")):
62 |         images = []
63 |         assert os.path.isdir(self.dir), '%s is not a valid directory' % self.dir
64 | 
65 |         with open(os.path.join(self.dir, self.opt.paired_file)) as f:
66 |             for line in f:
67 |                 line = line.strip().split('\t')
68 |                 line = [os.path.join(self.dir, i) for i in line]
69 |                 images.append(tuple(line))
70 |         return images[:min(max_dataset_size, len(images))]
71 | 
72 |     def process_img(self, im_path, transform):
73 |         im = Image.open(im_path).convert('RGB')
74 |         im = transform(im)
75 |         im = np.array(im)
76 |         ims = [im]  # build a three-level pyramid: [quarter, half, full] resolution
77 |         for i in [0.5, 0.25]:
78 |             ims = [cv2.resize(im, None, fx=i, fy=i, interpolation=cv2.INTER_AREA)] + ims
79 |         l_ts, ab_ts = [], []
80 |         for im in ims:
81 |             lab = color.rgb2lab(im).astype(np.float32)
82 |             lab_t = transforms.ToTensor()(lab)
83 |             l_ts.append(lab_t[[0], ...] / 50.0 - 1.0)  # L: [0, 100] -> [-1, 1]
84 |             ab_ts.append(lab_t[[1, 2], ...] / 110.0)   # ab: ~[-110, 110] -> ~[-1, 1]
85 |         return l_ts, ab_ts
86 | 
87 |     def __len__(self):
88 |         """Return the total number of images in the dataset."""
89 |         return len(self.AB_paths)
90 | 
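One detail worth keeping in mind when reading the models: `process_img` returns its pyramids ordered smallest first, so index -1 (used throughout main_model.py and colorization_model.py) is the full resolution. A toy check of the ordering:

```python
import numpy as np
import cv2

im = np.zeros((128, 128, 3), dtype=np.uint8)   # toy full-resolution image
ims = [im]
for f in [0.5, 0.25]:
    ims = [cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_AREA)] + ims
print([x.shape[:2] for x in ims])   # [(32, 32), (64, 64), (128, 128)]
```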
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing.
2 | 
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 |     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 |     -- <__len__>: return the size of dataset.
7 |     -- <__getitem__>: get a data point from data loader.
8 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 | 
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 | 
17 | 
18 | def find_dataset_using_name(dataset_name):
19 |     """Import the module "data/[dataset_name]_dataset.py".
20 | 
21 |     In the file, the class called DatasetNameDataset() will
22 |     be instantiated. It has to be a subclass of BaseDataset,
23 |     and it is case-insensitive.
24 |     """
25 |     dataset_filename = "data." + dataset_name + "_dataset"
26 |     datasetlib = importlib.import_module(dataset_filename)
27 | 
28 |     dataset = None
29 |     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30 |     for name, cls in datasetlib.__dict__.items():
31 |         if name.lower() == target_dataset_name.lower() \
32 |            and issubclass(cls, BaseDataset):
33 |             dataset = cls
34 | 
35 |     if dataset is None:
36 |         raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37 | 
38 |     return dataset
39 | 
40 | 
41 | def get_option_setter(dataset_name):
42 |     """Return the static method <modify_commandline_options> of the dataset class."""
43 |     dataset_class = find_dataset_using_name(dataset_name)
44 |     return dataset_class.modify_commandline_options
45 | 
46 | 
47 | def create_dataset(opt):
48 |     """Create a dataset given the option.
49 | 
50 |     This function wraps the class CustomDatasetDataLoader.
51 |     This is the main interface between this package and 'train.py'/'test.py'.
52 | 
53 |     Example:
54 |         >>> from data import create_dataset
55 |         >>> dataset = create_dataset(opt)
56 |     """
57 |     data_loader = CustomDatasetDataLoader(opt)
58 |     dataset = data_loader.load_data()
59 |     return dataset
60 | 
61 | 
62 | class CustomDatasetDataLoader():
63 |     """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 | 
65 |     def __init__(self, opt):
66 |         """Initialize this class
67 | 
68 |         Step 1: create a dataset instance given the name [dataset_mode]
69 |         Step 2: create a multi-threaded data loader.
70 |         """
71 |         self.opt = opt
72 |         dataset_class = find_dataset_using_name('colorization')
73 |         self.dataset = dataset_class(opt)
74 |         print("dataset [%s] was created" % type(self.dataset).__name__)
75 |         self.dataloader = torch.utils.data.DataLoader(
76 |             self.dataset,
77 |             batch_size=opt.batch_size,
78 |             shuffle=not opt.serial_batches,
79 |             num_workers=int(opt.num_threads))
80 | 
81 |     def load_data(self):
82 |         return self
83 | 
84 |     def __len__(self):
85 |         """Return the number of data in the dataset"""
86 |         return min(len(self.dataset), self.opt.max_dataset_size)
87 | 
88 |     def __iter__(self):
89 |         """Return a batch of data"""
90 |         for i, data in enumerate(self.dataloader):
91 |             if i * self.opt.batch_size >= self.opt.max_dataset_size:
92 |                 break
93 |             yield data
94 | 
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 | 
5 | 
6 | class HTML:
7 |     """This HTML class allows us to save images and write texts into a single HTML file.
8 | 
9 |     It consists of functions such as <add_header> (add a text header to the HTML file),
10 |     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 |     It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 |     """
13 | 
14 |     def __init__(self, web_dir, title, refresh=0):
15 |         """Initialize the HTML class
16 | 
17 |         Parameters:
18 |             web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str)   -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 | 
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 |             with self.doc.head:
33 |                 meta(http_equiv="refresh", content=str(refresh))
34 | 
35 |     def get_image_dir(self):
36 |         """Return the directory that stores images"""
37 |         return self.img_dir
38 | 
39 |     def add_header(self, text):
40 |         """Insert a header to the HTML file
41 | 
42 |         Parameters:
43 |             text (str) -- the header text
44 |         """
45 |         with self.doc:
46 |             h3(text)
47 | 
48 |     def add_images(self, ims, txts, links, width=400):
49 |         """add images to the HTML file
50 | 
51 |         Parameters:
52 |             ims (str list)   -- a list of image paths
53 |             txts (str list)  -- a list of image names shown on the website
54 |             links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 |         """
56 |         self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
57 |         self.doc.add(self.t)
58 |         with self.t:
59 |             with tr():
60 |                 for im, txt, link in zip(ims, txts, links):
61 |                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 |                         with p():
63 |                             with a(href=os.path.join('images', link)):
64 |                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
65 |                             br()
66 |                             p(txt)
67 | 
68 |     def add_more_images(self, ims, links, width=400):
69 |         """add two rows of images to the HTML file
70 | 
71 |         Parameters:
72 |             ims (str list)   -- two lists of image paths (one per row)
73 |             links (str list) -- two lists of hyperref links; when you click an image, it will redirect you to a new page
74 |         """
75 |         self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
76 |         self.doc.add(self.t)
77 |         with self.t:
78 |             for i in range(2):
79 |                 with tr():
80 |                     for im, link in zip(ims[i], links[i]):
81 |                         with td(style="word-wrap: break-word;", halign="center", valign="top"):
82 |                             with p():
83 |                                 with a(href=os.path.join('images', link)):
84 |                                     img(style="width:%dpx" % width, src=os.path.join('images', im))
85 |                                 br()
86 | 
87 |     def save(self):
88 |         """save the current content to the HTML file"""
89 |         html_file = '%s/index.html' % self.web_dir
90 |         f = open(html_file, 'wt')
91 |         f.write(self.doc.render())
92 |         f.close()
93 | 
94 | 
95 | if __name__ == '__main__':  # we show an example usage here.
96 |     html = HTML('web/', 'test_html')
97 |     html.add_header('hello world')
98 | 
99 |     ims, txts, links = [], [], []
100 |     for n in range(4):
101 |         ims.append('image_%d.png' % n)
102 |         txts.append('text_%d' % n)
103 |         links.append('image_%d.png' % n)
104 |     html.add_images(ims, txts, links)
105 |     html.save()
106 | 
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions."""
2 | from __future__ import print_function
3 | import torch
4 | import numpy as np
5 | from scipy import linalg
6 | from PIL import Image
7 | from skimage import color
8 | import os
9 | 
10 | 
11 | def tensor2im(input_image, imtype=np.uint8):
12 |     """Converts a Tensor array into a numpy image array.
13 | 
14 |     Parameters:
15 |         input_image (tensor) -- the input image tensor array
16 |         imtype (type)        -- the desired type of the converted numpy array
17 |     """
18 |     if not isinstance(input_image, np.ndarray):
19 |         if isinstance(input_image, torch.Tensor):  # get the data from a variable
20 |             image_tensor = input_image.data
21 |         else:
22 |             return input_image
23 |         image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
24 |         if image_numpy.shape[0] == 1:  # grayscale to RGB
25 |             image_numpy = np.tile(image_numpy, (3, 1, 1))
26 |         image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
27 |     else:  # if it is a numpy array, do nothing
28 |         image_numpy = input_image
29 |     return image_numpy.astype(imtype)
30 | 
31 | 
32 | def save_image(image_numpy, image_path):
33 |     """Save a numpy image to the disk
34 | 
35 |     Parameters:
36 |         image_numpy (numpy array) -- input numpy array
37 |         image_path (str)          -- the path of the image
38 |     """
39 |     image_pil = Image.fromarray(image_numpy)
40 |     image_pil.save(image_path)
41 | 
42 | 
43 | def mkdirs(paths):
44 |     """create empty directories if they don't exist
45 | 
46 |     Parameters:
47 |         paths (str list) -- a list of directory paths
48 |     """
49 |     if isinstance(paths, list) and not isinstance(paths, str):
50 |         for path in paths:
51 |             mkdir(path)
52 |     else:
53 |         mkdir(paths)
54 | 
55 | 
56 | def mkdir(path):
57 |     """create a single empty directory if it didn't exist
58 | 
59 |     Parameters:
60 |         path (str) -- a single directory path
61 |     """
62 |     if not os.path.exists(path):
63 |         os.makedirs(path)
64 | 
65 | 
66 | def calc_hist(data_ab, device):
67 |     """Compute a differentiable soft histogram of normalized ab values over a 21 x 21 grid (441 bins)."""
68 |     N, C, H, W = data_ab.shape
69 |     grid_a = torch.linspace(-1, 1, 21).view(1, 21, 1, 1, 1).expand(N, 21, 21, H, W).to(device)
70 |     grid_b = torch.linspace(-1, 1, 21).view(1, 1, 21, 1, 1).expand(N, 21, 21, H, W).to(device)
71 |     hist_a = torch.max(0.1 - torch.abs(grid_a - data_ab[:, 0, :, :].view(N, 1, 1, H, W)), torch.Tensor([0]).to(device)) * 10
72 |     hist_b = torch.max(0.1 - torch.abs(grid_b - data_ab[:, 1, :, :].view(N, 1, 1, H, W)), torch.Tensor([0]).to(device)) * 10
73 |     hist = (hist_a * hist_b).mean(dim=(3, 4)).view(N, -1)
74 |     return hist
75 | 
76 | 
77 | class Convert(object):
78 |     """Differentiable Lab -> RGB conversion (a torch re-implementation of the skimage pipeline)."""
79 |     def __init__(self, device):
80 |         xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423],
81 |                                  [0.212671, 0.715160, 0.072169],
82 |                                  [0.019334, 0.119193, 0.950227]])
83 |         rgb_from_xyz = linalg.inv(xyz_from_rgb)
84 |         self.rgb_from_xyz = torch.Tensor(rgb_from_xyz).to(device)
85 |         self.channel_mask = torch.Tensor([1, 0, 0]).to(device)
86 |         self.xyz_weight = torch.Tensor([0.95047, 1., 1.08883]).to(device)
87 |         self.mean = torch.Tensor([0.485, 0.456, 0.406]).to(device)
88 |         self.std = torch.Tensor([0.229, 0.224, 0.225]).to(device)
89 |         self.zero = torch.Tensor([0]).to(device)
90 |         self.one = torch.Tensor([1]).to(device)
91 | 
92 |     def lab2rgb(self, img):
93 |         img = img.permute(0, 2, 3, 1)
94 |         img1 = (img + 1.0) * 50.0 * self.channel_mask
95 |         img2 = img * 110.0 * (1 - self.channel_mask)
96 |         img = img1 + img2
97 |         return self.xyz2rgb(self.lab2xyz(img))
98 | 
99 |     def lab2xyz(self, img):
100 |         L, a, b = img[:, :, :, 0], img[:, :, :, 1], img[:, :, :, 2]
101 |         y = (L + 16.) / 116.
102 |         x = (a / 500.) + y
103 |         z = y - (b / 200.)
104 |         z = torch.max(z, self.zero)
105 |         out = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1), z.unsqueeze(-1)], dim=-1)
106 |         mask = (out > 0.2068966).float()
107 |         out1 = torch.pow(out, 3) * mask
108 |         out2 = (out - 16.0 / 116.) / 7.787 * (1 - mask)
109 |         out = out1 + out2
110 |         out *= self.xyz_weight
111 |         return out
112 | 
113 |     def xyz2rgb(self, img):
114 |         arr = img.matmul(self.rgb_from_xyz.t())
115 |         mask = (arr > 0.0031308).float()
116 |         arr1 = (1.055 * torch.pow(torch.max(arr, self.zero), 1 / 2.4) - 0.055) * mask
117 |         arr2 = arr * 12.92 * (1 - mask)
118 |         arr = arr1 + arr2
119 |         arr = torch.min(torch.max(arr, self.zero), self.one)
120 |         return arr
121 | 
122 |     def rgb_norm(self, img):
123 |         img = (img - self.mean) / self.std
124 |         return img.permute(0, 3, 1, 2)
125 | 
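`Convert` re-derives skimage's Lab-to-sRGB pipeline in torch so the conversion stays differentiable on the GPU. The sketch below cross-checks it against `skimage.color.lab2rgb` on a toy tensor; close agreement is the expectation here rather than a guarantee, since the torch path runs in float32:

```python
import numpy as np
import torch
from skimage import color
from util.util import Convert

conv = Convert(torch.device('cpu'))
x = torch.rand(1, 3, 8, 8) * 0.4 - 0.2        # mild normalized Lab, (N, 3, H, W)
rgb_torch = conv.lab2rgb(x)[0].numpy()        # differentiable path, (H, W, 3)

lab = x[0].permute(1, 2, 0).numpy().astype(np.float64)
lab[..., 0] = (lab[..., 0] + 1.0) * 50.0      # undo the same normalization
lab[..., 1:] *= 110.0
rgb_skimage = color.lab2rgb(lab)
print(np.abs(rgb_torch - rgb_skimage).max())  # expect a tiny float32/float64 gap
```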
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 | 
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 | 
12 | 
13 | class BaseDataset(data.Dataset, ABC):
14 |     """This class is an abstract base class (ABC) for datasets.
15 | 
16 |     To create a subclass, you need to implement the following four functions:
17 |     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 |     -- <__len__>: return the size of dataset.
19 |     -- <__getitem__>: get a data point.
20 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 |     """
22 | 
23 |     def __init__(self, opt):
24 |         """Initialize the class; save the options in the class
25 | 
26 |         Parameters:
27 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
28 |         """
29 |         self.opt = opt
30 |         self.root = opt.dataroot
31 | 
32 |     @staticmethod
33 |     def modify_commandline_options(parser, is_train):
34 |         """Add new dataset-specific options, and rewrite default values for existing options.
35 | 
36 |         Parameters:
37 |             parser          -- original option parser
38 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
39 | 
40 |         Returns:
41 |             the modified parser.
42 |         """
43 |         return parser
44 | 
45 |     @abstractmethod
46 |     def __len__(self):
47 |         """Return the total number of images in the dataset."""
48 |         return 0
49 | 
50 |     @abstractmethod
51 |     def __getitem__(self, index):
52 |         """Return a data point and its metadata information.
53 | 
54 |         Parameters:
55 |             index -- a random integer for data indexing
56 | 
57 |         Returns:
58 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
59 |         """
60 |         pass
61 | 
62 | 
63 | def get_params(opt, size):
64 |     w, h = size
65 |     new_h = h
66 |     new_w = w
67 |     if opt.preprocess == 'resize_and_crop':
68 |         new_h = new_w = opt.load_size
69 |     elif opt.preprocess == 'scale_width_and_crop':
70 |         new_w = opt.load_size
71 |         new_h = opt.load_size * h // w
72 | 
73 |     x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
74 |     y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
75 | 
76 |     flip = random.random() > 0.5
77 | 
78 |     return {'crop_pos': (x, y), 'flip': flip}
79 | 
80 | 
81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True, must_resize=False):
82 |     transform_list = []
83 |     if grayscale:
84 |         transform_list.append(transforms.Grayscale(1))
85 |     if 'resize' in opt.preprocess or must_resize:
86 |         osize = [opt.crop_size, opt.crop_size]
87 |         transform_list.append(transforms.Resize(osize, method))
88 |     elif 'scale_width' in opt.preprocess:
89 |         transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
90 | 
91 |     if 'crop' in opt.preprocess:
92 |         if params is None:
93 |             transform_list.append(transforms.CenterCrop(opt.crop_size))
94 |         else:
95 |             transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
96 | 
97 |     if 'none' in opt.preprocess:
98 |         transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=8, method=method)))
99 | 
100 |     if not opt.no_flip:
101 |         if params is None:
102 |             transform_list.append(transforms.RandomHorizontalFlip())
103 |         elif params['flip']:
104 |             transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
105 | 
106 |     if convert:
107 |         transform_list += [transforms.ToTensor()]
108 |         if grayscale:
109 |             transform_list += [transforms.Normalize((0.5,), (0.5,))]
110 |         else:
111 |             transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
112 |     return transforms.Compose(transform_list)
113 | 
114 | 
115 | def __make_power_2(img, base, method=Image.BICUBIC):
116 |     ow, oh = img.size
117 |     h = int(round(oh / base) * base)
118 |     w = int(round(ow / base) * base)
119 |     if (h == oh) and (w == ow):
120 |         return img
121 | 
122 |     __print_size_warning(ow, oh, w, h, base)
123 |     return img.resize((w, h), method)
124 | 
125 | 
126 | def __scale_width(img, target_width, method=Image.BICUBIC):
127 |     ow, oh = img.size
128 |     if ow <= oh:
129 |         if (ow == target_width):
130 |             return img
131 |         w = target_width
132 |         h = int(target_width * oh / ow)
133 |     else:
134 |         if (oh == target_width):
135 |             return img
136 |         h = target_width
137 |         w = int(target_width * ow / oh)
138 |     return img.resize((w, h), method)
139 | 
140 | 
141 | def __crop(img, pos, size):
142 |     ow, oh = img.size
143 |     x1, y1 = pos
144 |     tw = th = size
145 |     if (ow > tw or oh > th):
146 |         return img.crop((x1, y1, x1 + tw, y1 + th))
147 |     return img
148 | 
149 | 
150 | def __flip(img, flip):
151 |     if flip:
152 |         return img.transpose(Image.FLIP_LEFT_RIGHT)
153 |     return img
154 | 
155 | 
156 | def __print_size_warning(ow, oh, w, h, base):
157 |     """Print warning information about image size (only print once)"""
158 |     if not hasattr(__print_size_warning, 'has_printed'):
159 |         print("The image size needs to be a multiple of %d. "
160 |               "The loaded image size was (%d, %d), so it was adjusted to "
161 |               "(%d, %d). This adjustment will be done to all images "
162 |               "whose sizes are not multiples of %d" % (base, ow, oh, w, h, base))
163 |         __print_size_warning.has_printed = True
164 | 
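With `--preprocess none`, as in test.sh, the only geometric op `get_transform` installs is `__make_power_2`, which snaps both sides to a multiple of 8 so the network's downsampling divides evenly. A minimal sketch with a toy options namespace (run from the repository root):

```python
import argparse
from PIL import Image
from data.base_dataset import get_transform

opt = argparse.Namespace(preprocess='none', no_flip=True,
                         load_size=144, crop_size=128)
tf = get_transform(opt, convert=False)   # PIL in, PIL out, as the dataset uses it
out = tf(Image.new('RGB', (250, 181)))
print(out.size)                          # (248, 184): both sides rounded to x8
```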
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import models
6 | import data
7 | 
8 | 
9 | class BaseOptions():
10 |     """This class defines options used during both training and test time.
11 | 
12 |     It also implements several helper functions such as parsing, printing, and saving the options.
13 |     It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
14 |     """
15 | 
16 |     def __init__(self):
17 |         """Reset the class; indicates the class hasn't been initialized"""
18 |         self.initialized = False
19 | 
20 |     def initialize(self, parser):
21 |         """Define the common options that are used in both training and test."""
22 |         # basic parameters
23 |         parser.add_argument('--dataroot', required=True, help='path to images')
24 |         parser.add_argument('--name', type=str, default='imagenet', help='name of the experiment. It decides where to store samples and models')
25 |         parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
26 |         parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
27 |         # model parameters
28 |         parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels: 3 for RGB and 1 for grayscale')
29 |         parser.add_argument('--bias_input_nc', type=int, default=441, help='# of reference image histogram bins')
30 |         parser.add_argument('--output_nc', type=int, default=2, help='# of output image channels: 3 for RGB and 1 for grayscale')
31 |         parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
32 |         parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
33 |         parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
34 |         parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
35 |         # dataset parameters
36 |         parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
37 |         parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
38 |         parser.add_argument('--batch_size', type=int, default=16, help='input batch size')
39 |         parser.add_argument('--load_size', type=int, default=144, help='scale images to this size')
40 |         parser.add_argument('--crop_size', type=int, default=128, help='then crop to this size')
41 |         parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
42 |         parser.add_argument('--preprocess', type=str, default='scale_width_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
43 |         parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
44 |         parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
45 |         parser.add_argument('--paired_file', type=str, default='shuffle_same.txt')
46 |         # additional parameters
47 |         parser.add_argument('--use_D', action='store_true', help='whether to use discriminator or not')
48 |         parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
49 |         parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
50 |         parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
51 |         parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
52 |         self.initialized = True
53 |         return parser
54 | 
55 |     def gather_options(self):
56 |         """Initialize our parser with basic options (only once).
57 |         Add additional model-specific and dataset-specific options.
58 |         These options are defined in the <modify_commandline_options> function
59 |         in model and dataset classes.
60 |         """
61 |         if not self.initialized:  # check if it has been initialized
62 |             parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
63 |             parser = self.initialize(parser)
64 | 
65 |         # get the basic options
66 |         opt, _ = parser.parse_known_args()
67 | 
68 |         # modify model-related parser options
69 |         model_name = 'colorization'
70 |         model_option_setter = models.get_option_setter(model_name)
71 |         parser = model_option_setter(parser, self.isTrain)
72 |         opt, _ = parser.parse_known_args()  # parse again with new defaults
73 | 
74 |         # modify dataset-related parser options
75 |         dataset_name = 'colorization'
76 |         dataset_option_setter = data.get_option_setter(dataset_name)
77 |         parser = dataset_option_setter(parser, self.isTrain)
78 | 
79 |         # save and return the parser
80 |         self.parser = parser
81 |         return parser.parse_args()
82 | 
83 |     def print_options(self, opt):
84 |         """Print and save options
85 | 
86 |         It will print both current options and default values (if different).
87 |         It will save options into a text file: [checkpoints_dir] / opt.txt
88 |         """
89 |         message = ''
90 |         message += '----------------- Options ---------------\n'
91 |         for k, v in sorted(vars(opt).items()):
92 |             comment = ''
93 |             default = self.parser.get_default(k)
94 |             if v != default:
95 |                 comment = '\t[default: %s]' % str(default)
96 |             message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
97 |         message += '----------------- End -------------------'
98 |         print(message)
99 | 
100 |         # save to the disk
101 |         expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
102 |         util.mkdirs(expr_dir)
103 |         file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
104 |         with open(file_name, 'wt') as opt_file:
105 |             opt_file.write(message)
106 |             opt_file.write('\n')
107 | 
108 |     def parse(self):
109 |         """Parse our options, create checkpoints directory suffix, and set up gpu device."""
110 |         opt = self.gather_options()
111 |         opt.isTrain = self.isTrain  # train or test
112 | 
113 |         # process opt.suffix
114 |         if opt.suffix:
115 |             suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
116 |             opt.name = opt.name + suffix
117 | 
118 |         self.print_options(opt)
119 | 
120 |         # set gpu ids
121 |         str_ids = opt.gpu_ids.split(',')
122 |         opt.gpu_ids = []
123 |         for str_id in str_ids:
124 |             id = int(str_id)
125 |             if id >= 0:
126 |                 opt.gpu_ids.append(id)
127 |         if len(opt.gpu_ids) > 0:
128 |             torch.cuda.set_device(opt.gpu_ids[0])
129 | 
130 |         self.opt = opt
131 |         return self.opt
132 | 
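`parse` treats `--suffix` as a `str.format` template expanded over every parsed option before appending it to `opt.name`. A tiny sketch of the expansion with a toy namespace:

```python
import argparse

opt = argparse.Namespace(model='test', load_size=128, name='imagenet')
suffix = '{model}_size{load_size}'.format(**vars(opt))
print(opt.name + '_' + suffix)   # imagenet_test_size128
```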
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import ntpath
5 | import time
6 | from . import util, html
7 | from subprocess import Popen, PIPE
8 | from scipy.misc import imresize  # NOTE: removed in SciPy >= 1.3; this module needs an older SciPy
9 | 
10 | if sys.version_info[0] == 2:
11 |     VisdomExceptionBase = Exception
12 | else:
13 |     VisdomExceptionBase = ConnectionError
14 | 
15 | 
16 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
17 |     """Save images to the disk.
18 | 
19 |     Parameters:
20 |         webpage (the HTML class)  -- the HTML webpage class that stores these images (see html.py for more details)
21 |         visuals (OrderedDict)     -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
22 |         image_path (str)          -- the string is used to create image paths
23 |         aspect_ratio (float)      -- the aspect ratio of saved images
24 |         width (int)               -- the images will be resized to width x width
25 | 
26 |     This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
27 |     """
28 |     image_dir = webpage.get_image_dir()
29 |     # short_path = ntpath.basename(image_path[0])
30 |     short_path = image_path[0].split('/')
31 |     short_path = short_path[-2] + '_' + short_path[-1]
32 |     name = os.path.splitext(short_path)[0]
33 | 
34 |     webpage.add_header(name)
35 |     ims, txts, links = [], [], []
36 | 
37 |     for label, im_data in visuals.items():
38 |         im = util.tensor2im(im_data)
39 |         image_name = '%s_%s.png' % (name, label)
40 |         save_path = os.path.join(image_dir, image_name)
41 |         h, w, _ = im.shape
42 |         if aspect_ratio > 1.0:
43 |             im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
44 |         if aspect_ratio < 1.0:
45 |             im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
46 |         util.save_image(im, save_path)
47 | 
48 |         ims.append(image_name)
49 |         txts.append(label)
50 |         links.append(image_name)
51 |     webpage.add_images(ims, txts, links, width=width)
52 | 
53 | 
54 | def save_more_images(webpage, name, sources, targets, aspect_ratio=1.0, width=256):
55 |     image_dir = webpage.get_image_dir()
56 | 
57 |     webpage.add_header('%04d' % name)
58 |     ims, links = [], []
59 | 
60 |     for dnum, d in enumerate([sources, targets]):
61 |         ims.append([])
62 |         links.append([])
63 |         for idx, im_data in enumerate(d):
64 |             im = util.tensor2im(im_data)
65 |             image_name = '%04d_%02d_%s.png' % (name, idx, ['S', 'T'][dnum])
66 |             save_path = os.path.join(image_dir, image_name)
67 |             h, w, _ = im.shape
68 |             if aspect_ratio > 1.0:
69 |                 im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
70 |             if aspect_ratio < 1.0:
71 |                 im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
72 |             util.save_image(im, save_path)
73 |             ims[dnum].append(image_name)
74 |             links[dnum].append(image_name)
75 |     webpage.add_more_images(ims, links, width=width)
76 | 
77 | 
78 | class Visualizer():
79 |     """This class includes several functions that can display/save images and print/save logging information.
80 | 
81 |     It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
82 |     """
83 | 
84 |     def __init__(self, opt):
85 |         """Initialize the Visualizer class
86 | 
87 |         Parameters:
88 |             opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
89 |         Step 1: Cache the training/test options
90 |         Step 2: connect to a visdom server
91 |         Step 3: create an HTML object for saving HTML files
92 |         Step 4: create a logging file to store training losses
93 |         """
94 |         self.opt = opt  # cache the option
95 |         self.display_id = opt.display_id
96 |         self.use_html = opt.isTrain and not opt.no_html
97 |         self.win_size = opt.display_winsize
98 |         self.name = opt.name
99 |         self.port = opt.display_port
100 |         self.saved = False
101 |         if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
102 |             import visdom
103 |             self.ncols = opt.display_ncols
104 |             self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
105 |             if not self.vis.check_connection():
106 |                 self.create_visdom_connections()
107 | 
108 |         if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
109 |             self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
110 |             self.img_dir = os.path.join(self.web_dir, 'images')
111 |             print('create web directory %s...' % self.web_dir)
112 |             util.mkdirs([self.web_dir, self.img_dir])
113 |         # create a logging file to store training losses
114 |         self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
115 |         with open(self.log_name, "a") as log_file:
116 |             now = time.strftime("%c")
117 |             log_file.write('================ Training Loss (%s) ================\n' % now)
118 | 
119 |     def reset(self):
120 |         """Reset the self.saved status"""
121 |         self.saved = False
122 | 
123 |     def create_visdom_connections(self):
124 |         """If the program could not connect to Visdom server, this function will start a new server at port <self.port>"""
125 |         cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
126 |         print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
127 |         print('Command: %s' % cmd)
128 |         Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
129 | 
130 |     def display_current_results(self, visuals, epoch, save_result):
131 |         """Display current results on visdom; save current results to an HTML file.
132 | 
133 |         Parameters:
134 |             visuals (OrderedDict) -- dictionary of images to display or save
135 |             epoch (int)           -- the current epoch
136 |             save_result (bool)    -- if save the current results to an HTML file
137 |         """
138 |         if self.display_id > 0:  # show images in the browser using visdom
139 |             ncols = self.ncols
140 |             if ncols > 0:  # show all the images in one visdom panel
141 |                 ncols = min(ncols, len(visuals))
142 |                 h, w = next(iter(visuals.values())).shape[:2]
143 |                 table_css = """<style>
144 |                     table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
145 |                     table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
146 |                     </style>""" % (w, h)  # create a table css
147 |                 # create a table of images.
148 |                 title = self.name
149 |                 label_html = ''
150 |                 label_html_row = ''
151 |                 images = []
152 |                 idx = 0
153 |                 for label, image in visuals.items():
154 |                     image_numpy = util.tensor2im(image)
155 |                     label_html_row += '<td>%s</td>' % label
156 |                     images.append(image_numpy.transpose([2, 0, 1]))
157 |                     idx += 1
158 |                     if idx % ncols == 0:
159 |                         label_html += '<tr>%s</tr>' % label_html_row
160 |                         label_html_row = ''
161 |                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
162 |                 while idx % ncols != 0:
163 |                     images.append(white_image)
164 |                     label_html_row += '<td></td>'
165 |                     idx += 1
166 |                 if label_html_row != '':
167 |                     label_html += '<tr>%s</tr>' % label_html_row
168 |                 try:
169 |                     self.vis.images(images, nrow=ncols, win=self.display_id + 1,
170 |                                     padding=2, opts=dict(title=title + ' images'))
171 |                     label_html = '<table>%s</table>' % label_html
172 |                     self.vis.text(table_css + label_html, win=self.display_id + 2,
173 |                                   opts=dict(title=title + ' labels'))
174 |                 except VisdomExceptionBase:
175 |                     self.create_visdom_connections()
176 | 
177 |             else:  # show each image in a separate visdom panel;
178 |                 idx = 1
179 |                 try:
180 |                     for label, image in visuals.items():
181 |                         image_numpy = util.tensor2im(image)
182 |                         self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
183 |                                        win=self.display_id + idx)
184 |                         idx += 1
185 |                 except VisdomExceptionBase:
186 |                     self.create_visdom_connections()
187 | 
130 |     def display_current_results(self, visuals, epoch, save_result):
131 |         """Display current results on visdom; save current results to an HTML file.
132 | 
133 |         Parameters:
134 |             visuals (OrderedDict) - - dictionary of images to display or save
135 |             epoch (int) - - the current epoch
136 |             save_result (bool) - - whether to save the current results to an HTML file
137 |         """
138 |         if self.display_id > 0:  # show images in the browser using visdom
139 |             ncols = self.ncols
140 |             if ncols > 0:  # show all the images in one visdom panel
141 |                 ncols = min(ncols, len(visuals))
142 |                 h, w = next(iter(visuals.values())).shape[:2]
143 |                 table_css = """<style>
144 |                     table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
145 |                     table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
146 |                     </style>""" % (w, h)  # create a table css
147 |                 # create a table of images.
148 |                 title = self.name
149 |                 label_html = ''
150 |                 label_html_row = ''
151 |                 images = []
152 |                 idx = 0
153 |                 for label, image in visuals.items():
154 |                     image_numpy = util.tensor2im(image)
155 |                     label_html_row += '<td>%s</td>' % label
156 |                     images.append(image_numpy.transpose([2, 0, 1]))
157 |                     idx += 1
158 |                     if idx % ncols == 0:
159 |                         label_html += '<tr>%s</tr>' % label_html_row
160 |                         label_html_row = ''
161 |                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
162 |                 while idx % ncols != 0:
163 |                     images.append(white_image)
164 |                     label_html_row += '<td></td>'
165 |                     idx += 1
166 |                 if label_html_row != '':
167 |                     label_html += '<tr>%s</tr>' % label_html_row
168 |                 try:
169 |                     self.vis.images(images, nrow=ncols, win=self.display_id + 1,
170 |                                     padding=2, opts=dict(title=title + ' images'))
171 |                     label_html = '<table>%s</table>' % label_html
172 |                     self.vis.text(table_css + label_html, win=self.display_id + 2,
173 |                                   opts=dict(title=title + ' labels'))
174 |                 except VisdomExceptionBase:
175 |                     self.create_visdom_connections()
176 | 
177 |             else:  # show each image in a separate visdom panel
178 |                 idx = 1
179 |                 try:
180 |                     for label, image in visuals.items():
181 |                         image_numpy = util.tensor2im(image)
182 |                         self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
183 |                                        win=self.display_id + idx)
184 |                         idx += 1
185 |                 except VisdomExceptionBase:
186 |                     self.create_visdom_connections()
187 | 
188 |         if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
189 |             self.saved = True
190 |             # save images to the disk
191 |             for label, image in visuals.items():
192 |                 image_numpy = util.tensor2im(image)
193 |                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
194 |                 util.save_image(image_numpy, img_path)
195 | 
196 |             # update website
197 |             webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
198 |             for n in range(epoch, 0, -1):
199 |                 webpage.add_header('epoch [%d]' % n)
200 |                 ims, txts, links = [], [], []
201 | 
202 |                 for label, image_numpy in visuals.items():
203 |                     image_numpy = util.tensor2im(image_numpy)  # fixed: previously converted the stale loop variable `image`
204 |                     img_path = 'epoch%.3d_%s.png' % (n, label)
205 |                     ims.append(img_path)
206 |                     txts.append(label)
207 |                     links.append(img_path)
208 |                 webpage.add_images(ims, txts, links, width=self.win_size)
209 |             webpage.save()
210 | 
211 |     def plot_current_losses(self, epoch, counter_ratio, losses):
212 |         """Display the current losses on the visdom display.
213 | 
214 |         Parameters:
215 |             epoch (int)           -- current epoch
216 |             counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
217 |             losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
218 |         """
219 |         if not hasattr(self, 'plot_data'):
220 |             self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
221 |         self.plot_data['X'].append(epoch + counter_ratio)
222 |         self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
223 |         try:
224 |             self.vis.line(
225 |                 X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
226 |                 Y=np.array(self.plot_data['Y']),
227 |                 opts={
228 |                     'title': self.name + ' loss over time',
229 |                     'legend': self.plot_data['legend'],
230 |                     'xlabel': 'epoch',
231 |                     'ylabel': 'loss'},
232 |                 win=self.display_id)
233 |         except VisdomExceptionBase:
234 |             self.create_visdom_connections()
235 | 
236 |     # losses: same format as |losses| of plot_current_losses
237 |     def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
238 |         """Print current losses on the console; also save the losses to the disk
239 | 
240 |         Parameters:
241 |             epoch (int) -- current epoch
242 |             iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
243 |             losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
244 |             t_comp (float) -- computational time per data point (normalized by batch_size)
245 |             t_data (float) -- data loading time per data point (normalized by batch_size)
246 |         """
247 |         message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
248 |         for k, v in losses.items():
249 |             message += '%s: %.3f ' % (k, v)
250 | 
251 |         print(message)  # print the message
252 |         with open(self.log_name, "a") as log_file:
253 |             log_file.write('%s\n' % message)  # save the message
254 | 
--------------------------------------------------------------------------------
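That closes visualizer.py. For orientation, a minimal editorial sketch of how a training loop typically drives it; the option values, loss names, and the visuals tensor layout are illustrative assumptions, not repo defaults (in particular, util.tensor2im must accept whatever tensor format the model produces):

    # Minimal driving sketch for Visualizer (assumed flags; display_id=0 skips visdom)
    from collections import OrderedDict
    from types import SimpleNamespace
    import torch
    from util.visualizer import Visualizer

    opt = SimpleNamespace(display_id=0, isTrain=True, no_html=False,
                          display_winsize=256, name='demo',
                          display_port=8097, checkpoints_dir='./checkpoints')
    vis = Visualizer(opt)  # creates ./checkpoints/demo/web/ and loss_log.txt
    losses = OrderedDict([('G_L1', 0.42), ('G_class', 1.70)])
    vis.print_current_losses(epoch=1, iters=100, losses=losses, t_comp=0.05, t_data=0.01)
    visuals = OrderedDict([('fake_img', torch.rand(1, 3, 64, 64))])  # layout assumed compatible with util.tensor2im
    vis.display_current_results(visuals, epoch=1, save_result=True)  # writes epoch001_fake_img.png and the HTML index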
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 | from . import networks
6 | 
7 | 
8 | class BaseModel(ABC):
9 |     """This class is an abstract base class (ABC) for models.
10 |     To create a subclass, you need to implement the following five functions:
11 |         -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12 |         -- <set_input>: unpack data from dataset and apply preprocessing.
13 |         -- <forward>: produce intermediate results.
14 |         -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15 |         -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16 |     """
17 | 
18 |     def __init__(self, opt):
19 |         """Initialize the BaseModel class.
20 | 
21 |         Parameters:
22 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
23 | 
24 |         When creating your custom class, you need to implement your own initialization.
25 |         In this function, you should first call <BaseModel.__init__(self, opt)>
26 |         Then, you need to define four lists:
27 |             -- self.loss_names (str list): specify the training losses that you want to plot and save.
28 |             -- self.model_names (str list): define networks used in our training.
29 |             -- self.visual_names (str list): specify the images that you want to display and save.
30 |             -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them.
31 |         """
32 |         self.opt = opt
33 |         self.gpu_ids = opt.gpu_ids
34 |         self.isTrain = opt.isTrain
35 |         self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
36 |         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
37 |         if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38 |             torch.backends.cudnn.benchmark = True
39 |         self.loss_names = []
40 |         self.log_names = []
41 |         self.model_names = []
42 |         self.visual_names = []
43 |         self.optimizers = []
44 |         self.image_paths = []
45 |         self.metric = 0  # used for learning rate policy 'plateau'
46 | 
47 |     @staticmethod
48 |     def modify_commandline_options(parser, is_train):
49 |         """Add new model-specific options, and rewrite default values for existing options.
50 | 
51 |         Parameters:
52 |             parser -- original option parser
53 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
54 | 
55 |         Returns:
56 |             the modified parser.
57 |         """
58 |         return parser
59 | 
60 |     @abstractmethod
61 |     def set_input(self, input):
62 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
63 | 
64 |         Parameters:
65 |             input (dict): includes the data itself and its metadata information.
66 | """ 67 | pass 68 | 69 | @abstractmethod 70 | def forward(self): 71 | """Run forward pass; called by both functions and .""" 72 | pass 73 | 74 | def setup(self, opt): 75 | """Load and print networks; create schedulers 76 | 77 | Parameters: 78 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 79 | """ 80 | if self.isTrain: 81 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 82 | if not self.isTrain or opt.continue_train: 83 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 84 | self.load_networks(load_suffix) 85 | self.print_networks(opt.verbose) 86 | 87 | def eval(self): 88 | """Make models eval mode during test time""" 89 | for name in self.model_names: 90 | if isinstance(name, str): 91 | net = getattr(self, 'net' + name) 92 | net.eval() 93 | 94 | def test(self): 95 | """Forward function used in test time. 96 | 97 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 98 | It also calls to produce additional visualization results 99 | """ 100 | with torch.no_grad(): 101 | self.forward() 102 | self.compute_visuals() 103 | 104 | def compute_visuals(self): 105 | """Calculate additional output images for visdom and HTML visualization""" 106 | pass 107 | 108 | def get_image_paths(self): 109 | """ Return image paths that are used to load current data""" 110 | return self.image_paths 111 | 112 | def update_learning_rate(self): 113 | """Update learning rates for all the networks; called at the end of every epoch""" 114 | for scheduler in self.schedulers: 115 | if self.opt.lr_policy == 'plateau': 116 | scheduler.step(self.metric) 117 | else: 118 | scheduler.step() 119 | 120 | lr = self.optimizers[0].param_groups[0]['lr'] 121 | print('learning rate = %.7f' % lr) 122 | 123 | def get_current_visuals(self): 124 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 125 | visual_ret = OrderedDict() 126 | for name in self.visual_names: 127 | if isinstance(name, str): 128 | vis = getattr(self, name) 129 | if isinstance(vis, list): 130 | for i in range(len(vis)): 131 | visual_ret[name + '_' + str(i+1)] = vis[i] 132 | else: 133 | visual_ret[name] = vis 134 | return visual_ret 135 | 136 | def get_current_losses(self): 137 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 138 | errors_ret = OrderedDict() 139 | for name in self.loss_names: 140 | if isinstance(name, str): 141 | loss = getattr(self, 'loss_' + name) 142 | if isinstance(loss, list): 143 | for i in range(len(loss)): 144 | errors_ret[name + '_' + str(i+1)] = float(loss[i]) 145 | else: 146 | errors_ret[name] = float(loss) # float(...) works for both scalar tensor and float number 147 | return errors_ret 148 | 149 | def get_current_log(self): 150 | ret = self.get_current_losses() 151 | for name in self.log_names: 152 | if isinstance(name, str): 153 | ret[name] = float(getattr(self, name)) # float(...) works for both scalar tensor and float number 154 | return ret 155 | 156 | def save_networks(self, epoch): 157 | """Save all the networks to the disk. 
158 | 
159 |         Parameters:
160 |             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
161 |         """
162 |         for name in self.model_names:
163 |             if isinstance(name, str):
164 |                 net = getattr(self, 'net' + name)
165 |                 if isinstance(net, list):
166 |                     for i in range(len(net)):
167 |                         save_filename = '%s_net_%s_%d.pth' % (epoch, name, i+1)
168 |                         save_path = os.path.join(self.save_dir, save_filename)
169 |                         if len(self.gpu_ids) > 0 and torch.cuda.is_available():
170 |                             torch.save(net[i].module.cpu().state_dict(), save_path)
171 |                             net[i].cuda(self.gpu_ids[0])
172 |                         else:
173 |                             torch.save(net[i].cpu().state_dict(), save_path)
174 |                 else:
175 |                     save_filename = '%s_net_%s.pth' % (epoch, name)
176 |                     save_path = os.path.join(self.save_dir, save_filename)
177 |                     if len(self.gpu_ids) > 0 and torch.cuda.is_available():
178 |                         torch.save(net.module.cpu().state_dict(), save_path)
179 |                         net.cuda(self.gpu_ids[0])
180 |                     else:
181 |                         torch.save(net.cpu().state_dict(), save_path)
182 | 
183 |     def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
184 |         """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
185 |         key = keys[i]
186 |         if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
187 |             if module.__class__.__name__.startswith('InstanceNorm') and \
188 |                     (key == 'running_mean' or key == 'running_var'):
189 |                 if getattr(module, key) is None:
190 |                     state_dict.pop('.'.join(keys))
191 |             if module.__class__.__name__.startswith('InstanceNorm') and \
192 |                     (key == 'num_batches_tracked'):
193 |                 state_dict.pop('.'.join(keys))
194 |         else:
195 |             self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
196 | 
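For context: checkpoints written by PyTorch < 0.4 store running statistics for InstanceNorm layers that newer versions no longer register, so a strict load_state_dict would fail on unexpected keys; the helper above walks the module tree and pops them. A self-contained editorial illustration of the same repair on a toy net (not code from this repo):

    # Toy illustration: stale InstanceNorm keys break strict loading
    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.InstanceNorm2d(8))
    state = net.state_dict()
    state['1.running_mean'] = torch.zeros(8)  # keys a pre-0.4 checkpoint would contain
    state['1.running_var'] = torch.ones(8)
    for key in [k for k in state if k.endswith(('running_mean', 'running_var'))]:
        state.pop(key)  # the flat equivalent of __patch_instance_norm_state_dict
    net.load_state_dict(state)  # loads cleanly again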
197 |     def load_networks(self, epoch):
198 |         """Load all the networks from the disk.
199 | 
200 |         Parameters:
201 |             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
202 |         """
203 |         for name in self.model_names:
204 |             if isinstance(name, str):
205 |                 net = getattr(self, 'net' + name)
206 |                 if isinstance(net, list):
207 |                     for i in range(len(net)):
208 |                         load_filename = '%s_net_%s_%d.pth' % (epoch, name, i+1)
209 |                         load_path = os.path.join(self.save_dir, load_filename)
210 |                         net_i = net[i]
211 |                         if isinstance(net_i, torch.nn.DataParallel):
212 |                             net_i = net_i.module
213 |                         print('loading the model from %s' % load_path)
214 |                         # if you are using PyTorch newer than 0.4 (e.g., built from
215 |                         # GitHub source), you can remove str() on self.device
216 |                         state_dict = torch.load(load_path, map_location=str(self.device))
217 |                         if hasattr(state_dict, '_metadata'):
218 |                             del state_dict._metadata
219 | 
220 |                         # patch InstanceNorm checkpoints prior to 0.4
221 |                         for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
222 |                             self.__patch_instance_norm_state_dict(state_dict, net_i, key.split('.'))
223 |                         net_i.load_state_dict(state_dict)
224 |                 else:
225 |                     load_filename = '%s_net_%s.pth' % (epoch, name)
226 |                     load_path = os.path.join(self.save_dir, load_filename)
227 |                     if isinstance(net, torch.nn.DataParallel):
228 |                         net = net.module
229 |                     print('loading the model from %s' % load_path)
230 |                     # if you are using PyTorch newer than 0.4 (e.g., built from
231 |                     # GitHub source), you can remove str() on self.device
232 |                     state_dict = torch.load(load_path, map_location=str(self.device))
233 |                     if hasattr(state_dict, '_metadata'):
234 |                         del state_dict._metadata
235 | 
236 |                     # patch InstanceNorm checkpoints prior to 0.4
237 |                     for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
238 |                         self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
239 |                     net.load_state_dict(state_dict)
240 | 
241 |     def print_networks(self, verbose):
242 |         """Print the total number of parameters in the network and (if verbose) network architecture
243 | 
244 |         Parameters:
245 |             verbose (bool) -- if verbose: print the network architecture
246 |         """
247 |         print('---------- Networks initialized -------------')
248 |         for name in self.model_names:
249 |             if isinstance(name, str):
250 |                 net = getattr(self, 'net' + name)
251 |                 if isinstance(net, list):
252 |                     for i in range(len(net)):
253 |                         num_params = 0
254 |                         for param in net[i].parameters():
255 |                             num_params += param.numel()
256 |                         if verbose:
257 |                             print(net[i])
258 |                         print('[Network %s_%d] Total number of parameters : %.3f M' % (name, i+1, num_params / 1e6))
259 |                 else:
260 |                     num_params = 0
261 |                     for param in net.parameters():
262 |                         num_params += param.numel()
263 |                     if verbose:
264 |                         print(net)
265 |                     print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
266 |         print('-----------------------------------------------')
267 | 
268 |     def set_requires_grad(self, nets, requires_grad=False):
269 |         """Set requires_grad=False for all the networks to avoid unnecessary computations
270 |         Parameters:
271 |             nets (network list) -- a list of networks
272 |             requires_grad (bool) -- whether the networks require gradients or not
273 |         """
274 |         if not isinstance(nets, list):
275 |             nets = [nets]
276 |         for net in nets:
277 |             if net is not None:
278 |                 for param in net.parameters():
279 |                     param.requires_grad = requires_grad
280 | 
--------------------------------------------------------------------------------
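That is the whole of base_model.py. To make the subclassing contract in the class docstring concrete, here is a minimal editorial sketch; the option fields, the one-layer network, and the L1 loss are invented for illustration and are not part of this repository:

    # Smallest BaseModel subclass consistent with the docstring above (editorial)
    import torch
    import torch.nn as nn
    from types import SimpleNamespace
    from models.base_model import BaseModel

    class ToyModel(BaseModel):
        def __init__(self, opt):
            BaseModel.__init__(self, opt)      # step 1 from the docstring
            self.loss_names = ['L1']           # plotted/saved as self.loss_L1
            self.model_names = ['G']           # networks: resolved as self.netG
            self.visual_names = ['fake']       # images: resolved as self.fake
            self.netG = nn.Conv2d(1, 2, 1).to(self.device)
            if self.isTrain:
                self.optimizers = [torch.optim.Adam(self.netG.parameters(), lr=2e-4)]

        def set_input(self, input):
            self.real = input['A'].to(self.device)
            self.target = input['B'].to(self.device)

        def forward(self):
            self.fake = self.netG(self.real)

        def optimize_parameters(self):
            self.forward()
            self.loss_L1 = nn.functional.l1_loss(self.fake, self.target)
            self.optimizers[0].zero_grad()
            self.loss_L1.backward()
            self.optimizers[0].step()

    opt = SimpleNamespace(gpu_ids=[], isTrain=True, checkpoints_dir='./checkpoints',
                          name='toy', preprocess='resize')
    model = ToyModel(opt)
    model.set_input({'A': torch.rand(1, 1, 8, 8), 'B': torch.rand(1, 2, 8, 8)})
    model.optimize_parameters()
    print(model.get_current_losses())          # OrderedDict([('L1', ...)])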
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 | import torch.nn.functional as F
7 | 
8 | 
9 | class Identity(nn.Module):
10 |     def forward(self, x):
11 |         return x
12 | 
13 | 
14 | def get_norm_layer(norm_type='instance'):
15 |     if norm_type == 'batch':
16 |         norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
17 |     elif norm_type == 'instance':
18 |         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
19 |     elif norm_type == 'none':
20 |         norm_layer = lambda x: Identity()
21 |     else:
22 |         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
23 |     return norm_layer
24 | 
25 | 
26 | def get_scheduler(optimizer, opt):
27 |     if opt.lr_policy == 'linear':
28 |         def lambda_rule(epoch):
29 |             lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
30 |             return lr_l
31 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
32 |     elif opt.lr_policy == 'step':
33 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
34 |     elif opt.lr_policy == 'plateau':
35 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
36 |     elif opt.lr_policy == 'cosine':
37 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
38 |     else:
39 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)  # fixed: was `return NotImplementedError(...)`, which silently returned the exception instead of raising it
40 |     return scheduler
41 | 
42 | 
43 | def init_weights(net, init_type='normal', init_gain=0.02):
44 |     def init_func(m):  # define the initialization function
45 |         classname = m.__class__.__name__
46 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
47 |             if init_type == 'normal':
48 |                 init.normal_(m.weight.data, 0.0, init_gain)
49 |             elif init_type == 'xavier':
50 |                 init.xavier_normal_(m.weight.data, gain=init_gain)
51 |             elif init_type == 'kaiming':
52 |                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
53 |             elif init_type == 'orthogonal':
54 |                 init.orthogonal_(m.weight.data, gain=init_gain)
55 |             else:
56 |                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
57 |             if hasattr(m, 'bias') and m.bias is not None:
58 |                 init.constant_(m.bias.data, 0.0)
59 |         elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
60 |             init.normal_(m.weight.data, 1.0, init_gain)
61 |             init.constant_(m.bias.data, 0.0)
62 | 
63 |     print('initialize network with %s' % init_type)
64 |     net.apply(init_func)  # apply the initialization function
65 | 
66 | 
67 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
68 |     if len(gpu_ids) > 0:
69 |         assert(torch.cuda.is_available())
70 |         net.to(gpu_ids[0])
71 |         net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
72 |     init_weights(net, init_type, init_gain=init_gain)
73 |     return net
74 | 
75 | 
76 | def define_G(input_nc, bias_input_nc, output_nc, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
77 |     norm_layer = get_norm_layer(norm_type=norm)
78 |     net = Gray2ColorNet(input_nc, bias_input_nc, output_nc, norm_layer=norm_layer)
79 | 
80 |     return init_net(net, init_type, init_gain, gpu_ids)
81 | 
82 | 
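A hedged usage sketch of the two factories above. The signature of define_G is exactly as defined; the argument values (1-channel L input, 2-channel ab output, a 1000-dim bias vector matching the classifier head below) are editorial assumptions:

    # Building the generator through the factories above (CPU-only sketch)
    import torch.nn as nn
    from models.networks import define_G, init_net

    netG = define_G(input_nc=1, bias_input_nc=1000, output_nc=2,
                    norm='instance', init_type='xavier', gpu_ids=[])

    # init_net also works on any custom module: with a non-empty gpu_ids it moves the
    # net to gpu_ids[0], wraps it in DataParallel, then applies init_weights.
    toy = init_net(nn.Sequential(nn.Conv2d(1, 8, 3), nn.BatchNorm2d(8)), init_type='normal')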
83 | class ResBlock(nn.Module):  # two 3x3 conv layers with a residual skip connection
84 |     def __init__(self, dim, norm_layer, use_dropout, use_bias):
85 |         super(ResBlock, self).__init__()
86 |         conv_block = [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
87 |         if use_dropout:
88 |             conv_block += [nn.Dropout(0.5)]
89 |         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias), norm_layer(dim)]
90 |         self.conv_block = nn.Sequential(*conv_block)
91 | 
92 |     def forward(self, x):
93 |         out = x + self.conv_block(x)  # add skip connections
94 |         return out
95 | 
96 | 
97 | class global_network(nn.Module):  # projects the global bias vector to the three decoder widths (512/256/128)
98 |     def __init__(self, in_dim):
99 |         super(global_network, self).__init__()
100 |         model = [nn.Conv2d(in_dim, 512, kernel_size=1, padding=0), nn.ReLU(True)]
101 |         model += [nn.Conv2d(512, 512, kernel_size=1, padding=0), nn.ReLU(True)]
102 |         model += [nn.Conv2d(512, 512, kernel_size=1, padding=0), nn.ReLU(True)]
103 |         self.model = nn.Sequential(*model)
104 | 
105 |         self.model_1 = nn.Sequential(*[nn.Conv2d(512, 512, kernel_size=1, padding=0), nn.ReLU(True)])
106 |         self.model_2 = nn.Sequential(*[nn.Conv2d(512, 256, kernel_size=1, padding=0), nn.ReLU(True)])
107 |         self.model_3 = nn.Sequential(*[nn.Conv2d(512, 128, kernel_size=1, padding=0), nn.ReLU(True)])
108 | 
109 |     def forward(self, x):
110 |         x = self.model(x)
111 |         x1 = self.model_1(x)
112 |         x2 = self.model_2(x)
113 |         x3 = self.model_3(x)
114 | 
115 |         return x1, x2, x3
116 | 
117 | 
118 | class ref_network(nn.Module):  # encodes the reference ab channels and warps them onto the target grid via the correlation matrix
119 |     def __init__(self, norm_layer):
120 |         super(ref_network, self).__init__()
121 |         model1 = [nn.Conv2d(2, 64, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
122 |         model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(64)]
123 |         self.model1 = nn.Sequential(*model1)
124 |         model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
125 |         model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(128)]
126 |         self.model2 = nn.Sequential(*model2)
127 |         model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
128 |         model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(256)]  # norm width fixed to match the 256-channel conv output (was norm_layer(128), which errors under batch norm)
129 |         self.model3 = nn.Sequential(*model3)
130 |         model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
131 |         model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(512)]  # likewise fixed to 512 (was norm_layer(256))
132 |         self.model4 = nn.Sequential(*model4)
133 | 
134 |     def forward(self, color, corr, H, W):
135 |         conv1 = self.model1(color)
136 |         conv2 = self.model2(conv1[:, :, ::2, ::2])
137 |         conv2_flatten = conv2.view(conv2.shape[0], conv2.shape[1], -1)
138 | 
139 |         align = torch.bmm(conv2_flatten, corr)  # warp reference features onto the target grid
140 |         align_1 = align.view(align.shape[0], align.shape[1], H, W)
141 |         align_2 = self.model3(align_1[:, :, ::2, ::2])
142 |         align_3 = self.model4(align_2[:, :, ::2, ::2])
143 | 
144 |         return align_1, align_2, align_3
145 | 
146 | 
147 | class conf_feature(nn.Module):  # predicts per-location matching confidence from the correlation matrix (assumes a 64x64 = 4096-position grid)
148 |     def __init__(self):
149 |         super(conf_feature, self).__init__()
150 |         self.fc1 = nn.Sequential(*[nn.Conv1d(4096, 1024, kernel_size=1, stride=1, padding=0, bias=True), nn.ReLU(True)])
151 |         self.fc2 = nn.Sequential(*[nn.Conv1d(1024, 1, kernel_size=1, stride=1, padding=0, bias=True), nn.Sigmoid()])
152 | 
153 |     def forward(self, x):
154 |         x = self.fc1(x)
155 |         x = self.fc2(x)
156 |         return x
157 | 
158 | 
159 | class classify_network(nn.Module):  # global-pooled 1000-way classification head on the deepest features
160 |     def __init__(self):
161 |         super(classify_network, self).__init__()
162 |         self.maxpool = nn.AdaptiveMaxPool2d((1, 1))
163 |         self.fc = nn.Linear(512, 1000)
164 | 
165 |     def forward(self, x):
166 |         x = self.maxpool(x)
167 |         x = x.squeeze(-1).squeeze(-1)
168 |         x = self.fc(x)
169 |         return x
170 | 
171 | 
172 | class Gray2ColorNet(nn.Module):
173 |     def __init__(self, input_nc, bias_input_nc, output_nc, norm_layer=nn.BatchNorm2d):
174 |         super(Gray2ColorNet, self).__init__()
175 |         self.input_nc = input_nc
176 |         self.output_nc = output_nc
177 |         use_bias = True
178 | 
179 |         downmodel1 = [nn.Conv2d(input_nc, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True)]
180 |         downmodel1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(64)]
181 | 
182 |         downmodel2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True)]
183 |         downmodel2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(128)]
184 | 
185 |         downmodel3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True)]
186 |         downmodel3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True)]
187 |         downmodel3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(256)]
188 | 
189 |         downmodel4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True)]
190 |         downmodel4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True)]
191 |         downmodel4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(512)]
192 | 
193 |         downmodel5 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), nn.ReLU(True)]
194 |         downmodel5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), nn.ReLU(True)]
195 |         downmodel5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), nn.ReLU(True), norm_layer(512)]
196 | 
197 |         downmodel6 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), nn.ReLU(True)]
198 |         downmodel6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), nn.ReLU(True)]
199 |         downmodel6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), nn.ReLU(True), norm_layer(512)]
200 | 
201 |         resblock0_1 = [nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=use_bias), norm_layer(512), nn.ReLU(True)]
202 |         self.resblock0_2 = ResBlock(512, norm_layer, False, use_bias)
203 |         self.resblock0_3 = ResBlock(512, norm_layer, False, use_bias)
204 | 
205 |         upmodel1up = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias)]
206 |         upmodel1short = [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias)]
207 |         upmodel1 = [nn.ReLU(True), nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True)]
208 |         upmodel1 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(256)]
209 | 
210 |         resblock1_1 = [nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=use_bias), norm_layer(256), nn.ReLU(True)]
211 |         self.resblock1_2 = ResBlock(256, norm_layer, False, use_bias)
212 |         self.resblock1_3 = ResBlock(256, norm_layer, False, use_bias)
213 | 
214 |         upmodel2up = [nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias)]
215 |         upmodel2short = [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias)]
216 |         upmodel2 = [nn.ReLU(True), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(128)]
217 | 
218 |         resblock2_1 = [nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=use_bias), norm_layer(128), nn.ReLU(True)]
219 |         self.resblock2_2 = ResBlock(128, norm_layer, False, use_bias)
220 |         self.resblock2_3 = ResBlock(128, norm_layer, False, use_bias)
221 | 
222 |         upmodel3up = [nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias)]
223 |         upmodel3short = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias)]
224 |         upmodel3 = [nn.ReLU(True), nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=use_bias), nn.LeakyReLU(negative_slope=.2)]
225 | 
226 |         self.global_network = global_network(bias_input_nc)
227 |         self.ref_network = ref_network(norm_layer)
228 |         self.conf_feature = conf_feature()
229 |         self.classify_network = classify_network()
230 | 
231 |         model_out1 = [nn.Conv2d(256, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), nn.Tanh()]
232 |         model_out2 = [nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), nn.Tanh()]
233 |         model_out3 = [nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), nn.Tanh()]
234 | 
235 |         self.model1 = nn.Sequential(*downmodel1)
236 |         self.model2 = nn.Sequential(*downmodel2)
237 |         self.model3 = nn.Sequential(*downmodel3)
238 |         self.model4 = nn.Sequential(*downmodel4)
239 |         self.model5 = nn.Sequential(*downmodel5)
240 |         self.model6 = nn.Sequential(*downmodel6)
241 |         self.model8up = nn.Sequential(*upmodel1up)
242 |         self.model8 = nn.Sequential(*upmodel1)
243 |         self.model9up = nn.Sequential(*upmodel2up)
244 |         self.model9 = nn.Sequential(*upmodel2)
245 |         self.model10up = nn.Sequential(*upmodel3up)
246 |         self.model10 = nn.Sequential(*upmodel3)
247 |         self.model3short8 = nn.Sequential(*upmodel1short)
248 |         self.model2short9 = nn.Sequential(*upmodel2short)
249 |         self.model1short10 = nn.Sequential(*upmodel3short)
250 |         self.resblock0_1 = nn.Sequential(*resblock0_1)
251 |         self.resblock1_1 = nn.Sequential(*resblock1_1)
252 |         self.resblock2_1 = nn.Sequential(*resblock2_1)
253 |         self.model_out1 = nn.Sequential(*model_out1)
254 |         self.model_out2 = nn.Sequential(*model_out2)
255 |         self.model_out3 = nn.Sequential(*model_out3)
256 | 
257 |     def forward(self, input, bias_input, ref_input, ref_color):
258 |         bias_input = bias_input.view(input.shape[0], -1, 1, 1)
259 |         in_1 = self.model1(input)  # encoder pyramid of the grayscale target
260 |         in_2 = self.model2(in_1[:, :, ::2, ::2])
261 |         in_3 = self.model3(in_2[:, :, ::2, ::2])
262 |         in_4 = self.model4(in_3[:, :, ::2, ::2])
263 |         in_5 = self.model5(in_4)
264 |         in_6 = self.model6(in_5)
265 | 
266 |         ref_1 = self.model1(ref_input)  # the same (shared-weight) encoder applied to the grayscale reference
267 |         ref_2 = self.model2(ref_1[:, :, ::2, ::2])
268 |         ref_3 = self.model3(ref_2[:, :, ::2, ::2])
269 |         ref_4 = self.model4(ref_3[:, :, ::2, ::2])
270 |         ref_5 = self.model5(ref_4)
271 |         ref_6 = self.model6(ref_5)
272 | 
273 |         t1 = F.interpolate(in_1, scale_factor=0.5, mode='bilinear')  # resample all target scales to the in_2 grid
274 |         t2 = in_2
275 |         t3 = F.interpolate(in_3, scale_factor=2, mode='bilinear')
276 |         t4 = F.interpolate(in_4, scale_factor=4, mode='bilinear')
277 |         t5 = F.interpolate(in_5, scale_factor=4, mode='bilinear')
278 |         t6 = F.interpolate(in_6, scale_factor=4, mode='bilinear')
279 |         t = torch.cat((t1, t2, t3, t4, t5, t6), dim=1)
280 | 
281 |         r1 = F.interpolate(ref_1, scale_factor=0.5, mode='bilinear')  # and likewise for the reference scales
282 |         r2 = ref_2
283 |         r3 = F.interpolate(ref_3, scale_factor=2, mode='bilinear')
284 |         r4 = F.interpolate(ref_4, scale_factor=4, mode='bilinear')
285 |         r5 = F.interpolate(ref_5, scale_factor=4, mode='bilinear')
286 |         r6 = F.interpolate(ref_6, scale_factor=4, mode='bilinear')
287 |         r = torch.cat((r1, r2, r3, r4, r5, r6), dim=1)
288 | 
289 |         input_T_flatten = t.view(t.shape[0], t.shape[1], -1).permute(0, 2, 1)
290 |         input_R_flatten = r.view(r.shape[0], r.shape[1], -1).permute(0, 2, 1)
291 |         input_T_flatten = input_T_flatten / torch.norm(input_T_flatten, p=2, dim=-1, keepdim=True)
292 |         input_R_flatten = input_R_flatten / torch.norm(input_R_flatten, p=2, dim=-1, keepdim=True)
293 |         corr = torch.bmm(input_R_flatten, input_T_flatten.permute(0, 2, 1))  # cosine similarity between every reference and target position
294 | 
295 |         conf = self.conf_feature(corr)  # per-position confidence of the reference match
296 |         conf = conf.view(conf.shape[0], 1, t2.shape[2], t2.shape[3])
297 |         conf_1 = conf
298 |         conf_2 = conf_1[:, :, ::2, ::2]
299 |         conf_3 = conf_2[:, :, ::2, ::2]
300 | 
301 |         corr = F.softmax(corr / 0.01, dim=1)  # low temperature -> sharply peaked attention over reference positions
302 |         align_1, align_2, align_3 = self.ref_network(ref_color, corr, t2.shape[2], t2.shape[3])
303 |         conv_global1, conv_global2, conv_global3 = self.global_network(bias_input)
304 | 
305 |         conv1_2 = self.model1(input)  # note: this re-runs the encoder on the same input as in_1..in_6 above
306 |         conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
307 |         conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
308 |         conv4_3 = self.model4(conv3_3[:, :, ::2, ::2])
309 |         conv5_3 = self.model5(conv4_3)
310 |         conv6_3 = self.model6(conv5_3)
311 | 
312 |         class_output = self.classify_network(conv6_3)
313 | 
314 |         conv_global1_repeat = conv_global1.expand_as(conv6_3)
315 |         conv6_3_global = conv6_3 + align_3 * conf_3 + conv_global1_repeat * (1 - conf_3)  # warped reference colors where confident, global prior elsewhere
316 |         conv7_resblock1 = self.resblock0_1(conv6_3_global)
317 |         conv7_resblock2 = self.resblock0_2(conv7_resblock1)
318 |         conv7_resblock3 = self.resblock0_3(conv7_resblock2)
319 |         conv8_up = self.model8up(conv7_resblock3) + self.model3short8(conv3_3)
320 |         conv8_3 = self.model8(conv8_up)
321 |         fake_img1 = self.model_out1(conv8_3)  # coarsest ab prediction
322 | 
323 |         conv_global2_repeat = conv_global2.expand_as(conv8_3)
324 |         conv8_3_global = conv8_3 + align_2 * conf_2 + conv_global2_repeat * (1 - conf_2)
325 |         conv8_resblock1 = self.resblock1_1(conv8_3_global)
326 |         conv8_resblock2 = self.resblock1_2(conv8_resblock1)
327 |         conv8_resblock3 = self.resblock1_3(conv8_resblock2)
328 |         conv9_up = self.model9up(conv8_resblock3) + self.model2short9(conv2_2)
329 |         conv9_3 = self.model9(conv9_up)
330 |         fake_img2 = self.model_out2(conv9_3)  # mid-resolution ab prediction
331 | 
332 |         conv_global3_repeat = conv_global3.expand_as(conv9_3)
333 |         conv9_3_global = conv9_3 + align_1 * conf_1 + conv_global3_repeat * (1 - conf_1)
334 |         conv9_resblock1 = self.resblock2_1(conv9_3_global)
335 |         conv9_resblock2 = self.resblock2_2(conv9_resblock1)
336 |         conv9_resblock3 = self.resblock2_3(conv9_resblock2)
337 |         conv10_up = self.model10up(conv9_resblock3) + self.model1short10(conv1_2)
338 |         conv10_2 = self.model10(conv10_up)
339 |         fake_img3 = self.model_out3(conv10_2)  # full-resolution ab prediction
340 | 
341 |         return [fake_img1, fake_img2, fake_img3], class_output, conf
342 | 
--------------------------------------------------------------------------------
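The forward pass above is, at its core, cosine-similarity attention: the target and reference feature pyramids are L2-normalized, a (HW x HW) correlation matrix is built with bmm, sharpened by a softmax with temperature 0.01, and then used both to warp the reference ab channels onto the target grid (ref_network) and to predict a per-pixel matching confidence (note that conf_feature's Conv1d(4096, ...) appears to pin the correlation grid to 64x64, i.e., HW = 4096 at the t2 scale). A toy, self-contained editorial restatement of that matching step, with made-up sizes:

    # Correlation/attention step of the forward pass, restated with toy sizes
    import torch
    import torch.nn.functional as F

    B, C, H, W = 1, 8, 4, 4
    t = torch.rand(B, C, H, W)                    # target (grayscale) features
    r = torch.rand(B, C, H, W)                    # reference features
    tf = F.normalize(t.view(B, C, -1).permute(0, 2, 1), dim=-1)  # (B, HW, C), unit rows
    rf = F.normalize(r.view(B, C, -1).permute(0, 2, 1), dim=-1)
    corr = torch.bmm(rf, tf.permute(0, 2, 1))     # (B, HW_ref, HW_tgt) cosine similarities
    attn = F.softmax(corr / 0.01, dim=1)          # low temperature -> near-hard matches per target column
    ref_ab = torch.rand(B, 2, H, W)               # reference chrominance to warp
    warped = torch.bmm(ref_ab.view(B, 2, -1), attn).view(B, 2, H, W)  # same bmm as ref_network

Each target position thus receives a confidence-weighted blend of the best-matching reference color and the global prior from global_network, at all three decoder scales.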