├── README.md
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   └── colorization_dataset.py
├── doc
│   ├── ab_constant_filter.npy
│   └── weight_index.npy
├── imgs
│   ├── reference.JPEG
│   ├── target.JPEG
│   └── visual.jpg
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── colorization_model.py
│   ├── main_model.py
│   └── networks.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   └── test_options.py
├── requirements.txt
├── test.py
├── test.sh
└── util
    ├── __init__.py
    ├── html.py
    ├── util.py
    └── visualizer.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer for Image Colorization (PyTorch Implementation)
2 | 
3 | 
4 | 
5 | ### [Paper](https://dl.acm.org/doi/10.1145/3474085.3475385) | [Pretrained Model](https://drive.google.com/file/d/11FM-2v4iVH8Dvowo-7bQG56Z_ey8kjOa/view?usp=sharing)
6 | 
7 | **Yes, "Attention Is All You Need", for Exemplar based Colorization, ACMMM2021**
8 | 
9 | Wang Yin<sup>1</sup>,
10 | Peng Lu<sup>1</sup>,
11 | Zhaoran Zhao<sup>1</sup>,
12 | Xujun Peng<sup>2</sup><br>
13 | <sup>1</sup>Beijing University of Posts and Telecommunications, <sup>2</sup>USC
14 | ## Table of Contents
15 | 
16 | - [Prerequisites](#Prerequisites)
17 | - [Getting Started](#Getting-Started)
18 | - [Citation](#Citation)
19 | 
20 | ## Prerequisites
21 | - Ubuntu 16.04
22 | - Python 3.6.10
23 | - PyTorch 1.5.1
24 | - CPU or NVIDIA GPU + CUDA 10.2 + cuDNN
25 | 
26 | ## Getting Started
27 | 
28 | ### Installation
29 | - Clone this repo:
30 | ```bash
31 | git clone https://github.com/wangyins/transformer-for-image-colorization
32 | cd transformer-for-image-colorization
33 | pip install -r requirements.txt
34 | ```
35 | - Download "checkpoints_acmmm2021.zip" from the [Pretrained Model](https://drive.google.com/file/d/11FM-2v4iVH8Dvowo-7bQG56Z_ey8kjOa/view?usp=sharing) link above, place it under `checkpoints/imagenet/`, and unzip it:
36 | ```bash
37 | mkdir -p checkpoints/imagenet/
38 | cd checkpoints/imagenet/
39 | unzip checkpoints_acmmm2021.zip
40 | ```
41 | ### Testing
42 | ```bash
43 | sh test.sh
44 | ```
45 | ## Citation
46 | If you use this code for your research, please cite our paper.
47 | ```
48 | @inproceedings{yin_mm2021,
49 |   title={Yes, "Attention Is All You Need", for Exemplar based Colorization},
50 |   author={Yin, Wang and Lu, Peng and Zhao, ZhaoRan and Peng, XuJun},
51 |   booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
52 |   year={2021}
53 | }
54 | ```
55 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 | 
3 |  To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 |  You need to implement four functions:
5 |     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 |     -- <__len__>: return the size of dataset.
7 |     -- <__getitem__>: get a data point from data loader.
8 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 | 
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See 'colorization_dataset.py' for a concrete example.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 | 
17 | 
18 | def find_dataset_using_name(dataset_name):
19 |     """Import the module "data/[dataset_name]_dataset.py".
20 | 
21 |     In the file, the class called DatasetNameDataset() will
22 |     be instantiated. It has to be a subclass of BaseDataset,
23 |     and it is case-insensitive.
24 |     """
25 |     dataset_filename = "data." + dataset_name + "_dataset"
26 |     datasetlib = importlib.import_module(dataset_filename)
27 | 
28 |     dataset = None
29 |     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30 |     for name, cls in datasetlib.__dict__.items():
31 |         if name.lower() == target_dataset_name.lower() \
32 |            and issubclass(cls, BaseDataset):
33 |             dataset = cls
34 | 
35 |     if dataset is None:
36 |         raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37 | 
38 |     return dataset
39 | 
40 | 
41 | def get_option_setter(dataset_name):
42 |     """Return the static method <modify_commandline_options> of the dataset class."""
43 |     dataset_class = find_dataset_using_name(dataset_name)
44 |     return dataset_class.modify_commandline_options
45 | 
46 | 
47 | def create_dataset(opt):
48 |     """Create a dataset given the option.
49 | 
50 |     This function wraps the class CustomDatasetDataLoader.
51 |     This is the main interface between this package and 'train.py'/'test.py'
52 | 
53 |     Example:
54 |         >>> from data import create_dataset
55 |         >>> dataset = create_dataset(opt)
56 |     """
57 |     data_loader = CustomDatasetDataLoader(opt)
58 |     dataset = data_loader.load_data()
59 |     return dataset
60 | 
61 | 
62 | class CustomDatasetDataLoader():
63 |     """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 | 
65 |     def __init__(self, opt):
66 |         """Initialize this class
67 | 
68 |         Step 1: create a dataset instance given the name [dataset_mode]
69 |         Step 2: create a multi-threaded data loader.
70 |         """
71 |         self.opt = opt
72 |         dataset_class = find_dataset_using_name(opt.dataset_mode)
73 |         self.dataset = dataset_class(opt)
74 |         print("dataset [%s] was created" % type(self.dataset).__name__)
75 |         self.dataloader = torch.utils.data.DataLoader(
76 |             self.dataset,
77 |             batch_size=opt.batch_size,
78 |             shuffle=not opt.serial_batches,
79 |             num_workers=int(opt.num_threads))
80 | 
81 |     def load_data(self):
82 |         return self
83 | 
84 |     def __len__(self):
85 |         """Return the number of data in the dataset"""
86 |         return min(len(self.dataset), self.opt.max_dataset_size)
87 | 
88 |     def __iter__(self):
89 |         """Return a batch of data"""
90 |         for i, data in enumerate(self.dataloader):
91 |             if i * self.opt.batch_size >= self.opt.max_dataset_size:
92 |                 break
93 |             yield data
--------------------------------------------------------------------------------
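The registration convention described in the docstring above is worth making concrete. Below is a minimal, hypothetical `data/dummy_dataset.py` (no such file ships with this repo) showing exactly what `find_dataset_using_name('dummy')` looks for: the module name `dummy_dataset` and a class whose lowercased name, with underscores removed, equals `dummydataset`. The option defaults and the reuse of `--targetImage_path` are illustrative assumptions only.

```python
# data/dummy_dataset.py -- illustrative sketch; this file does not exist in the repo.
from PIL import Image

from data.base_dataset import BaseDataset, get_transform


class DummyDataset(BaseDataset):
    """Discovered by find_dataset_using_name('dummy') via importlib."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(preprocess='none')  # illustrative default
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.paths = [opt.targetImage_path]      # reuse an existing option for the demo
        self.transform = get_transform(opt, convert=True)

    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert('RGB')
        return {'A': self.transform(img), 'A_paths': self.paths[index]}

    def __len__(self):
        return len(self.paths)
```

With that file in place, `--dataset_mode dummy` would route through `create_dataset(opt)` and the `CustomDatasetDataLoader` above without further changes.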
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 | 
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 | 
12 | 
13 | class BaseDataset(data.Dataset, ABC):
14 |     """This class is an abstract base class (ABC) for datasets.
15 | 
16 |     To create a subclass, you need to implement the following four functions:
17 |     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 |     -- <__len__>: return the size of dataset.
19 |     -- <__getitem__>: get a data point.
20 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 |     """
22 | 
23 |     def __init__(self, opt):
24 |         """Initialize the class; save the options in the class
25 | 
26 |         Parameters:
27 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
28 |         """
29 |         self.opt = opt
30 |         self.root = opt.dataroot
31 | 
32 |     @staticmethod
33 |     def modify_commandline_options(parser, is_train):
34 |         """Add new dataset-specific options, and rewrite default values for existing options.
35 | 
36 |         Parameters:
37 |             parser          -- original option parser
38 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
39 | 
40 |         Returns:
41 |             the modified parser.
42 |         """
43 |         return parser
44 | 
45 |     @abstractmethod
46 |     def __len__(self):
47 |         """Return the total number of images in the dataset."""
48 |         return 0
49 | 
50 |     @abstractmethod
51 |     def __getitem__(self, index):
52 |         """Return a data point and its metadata information.
53 | 
54 |         Parameters:
55 |             index -- a random integer for data indexing
56 | 
57 |         Returns:
58 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
59 |         """
60 |         pass
61 | 
62 | 
63 | def get_params(opt, size):
64 |     w, h = size
65 |     new_h = h
66 |     new_w = w
67 |     if opt.preprocess == 'resize_and_crop':
68 |         new_h = new_w = opt.load_size
69 |     elif opt.preprocess == 'scale_width_and_crop':
70 |         new_w = opt.load_size
71 |         new_h = opt.load_size * h // w
72 | 
73 |     x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
74 |     y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
75 | 
76 |     flip = random.random() > 0.5
77 | 
78 |     return {'crop_pos': (x, y), 'flip': flip}
79 | 
80 | 
81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True, must_crop=False):
82 |     transform_list = []
83 |     if grayscale:
84 |         transform_list.append(transforms.Grayscale(1))
85 |     if 'resize' in opt.preprocess or must_crop:
86 |         osize = [opt.crop_size, opt.crop_size]
87 |         transform_list.append(transforms.Resize(osize, method))
88 |     elif 'scale_width' in opt.preprocess:
89 |         transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
90 | 
91 |     if 'crop' in opt.preprocess:
92 |         if params is None:
93 |             transform_list.append(transforms.CenterCrop(opt.crop_size))
94 |         else:
95 |             transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
96 | 
97 |     if 'none' in opt.preprocess:
98 |         transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=8, method=method)))
99 | 
100 |     if not opt.no_flip:
101 |         if params is None:
102 |             transform_list.append(transforms.RandomHorizontalFlip())
103 |         elif params['flip']:
104 |             transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
105 | 
106 |     if convert:
107 |         transform_list += [transforms.ToTensor()]
108 |         if grayscale:
109 |             transform_list += [transforms.Normalize((0.5,), (0.5,))]
110 |         else:
111 |             transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
112 |     return transforms.Compose(transform_list)
113 | 
114 | 
115 | def __make_power_2(img, base, method=Image.BICUBIC):
116 |     ow, oh = img.size
117 |     h = int(round(oh / base) * base)
118 |     w = int(round(ow / base) * base)
119 |     if (h == oh) and (w == ow):
120 |         return img
121 | 
122 |     __print_size_warning(ow, oh, w, h)
123 |     return img.resize((w, h), method)
124 | 
125 | 
126 | def __scale_width(img, target_width, method=Image.BICUBIC):
127 |     ow, oh = img.size
128 |     if ow <= oh:
129 |         if (ow == target_width):
130 |             return img
131 |         w = target_width
132 |         h = int(target_width * oh / ow)
133 |     else:
134 |         if (oh == target_width):
135 |             return img
136 |         h = target_width
137 |         w = int(target_width * ow / oh)
138 |     return img.resize((w, h), method)
139 | 
140 | 
141 | def __crop(img, pos, size):
142 |     ow, oh = img.size
143 |     x1, y1 = pos
144 |     tw = th = size
145 |     if (ow > tw or oh > th):
146 |         return img.crop((x1, y1, x1 + tw, y1 + th))
147 |     return img
148 | 
149 | 
150 | def __flip(img, flip):
151 |     if flip:
152 |         return img.transpose(Image.FLIP_LEFT_RIGHT)
153 |     return img
154 | 
155 | 
156 | def __print_size_warning(ow, oh, w, h):
157 |     """Print warning information about image size (only print once)"""
158 |     if not hasattr(__print_size_warning, 'has_printed'):
159 |         print("The image size needs to be a multiple of 8. "
160 |               "The loaded image size was (%d, %d), so it was adjusted to "
161 |               "(%d, %d). This adjustment will be done to all images "
162 |               "whose sizes are not multiples of 8" % (ow, oh, w, h))
163 |         __print_size_warning.has_printed = True
164 | 
--------------------------------------------------------------------------------
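To see how the helpers above compose, here is a minimal sketch of the test-time pipeline this repo effectively builds with its default `--preprocess none`: the only geometric op is `__make_power_2`, which snaps each side to a multiple of 8 before `ToTensor`/`Normalize`. The `Namespace` stands in for the real option object and carries only the fields `get_transform` actually reads (an assumption of this sketch).

```python
from argparse import Namespace

from PIL import Image

from data.base_dataset import get_transform

# Minimal stand-in for the parsed options (fields assumed, not the full option set).
opt = Namespace(preprocess='none', no_flip=True, load_size=288, crop_size=256)

transform = get_transform(opt, convert=True)   # multiple-of-8 resize + ToTensor + Normalize
img = Image.open('./imgs/target.JPEG').convert('RGB')
tensor = transform(img)                        # float tensor in [-1, 1], shape C x H x W
print(tensor.shape)                            # H and W are now multiples of 8
```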
/data/colorization_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_transform
3 | from skimage import color  # requires scikit-image
4 | from PIL import Image
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | import cv2
8 | from collections import Counter
9 | from tqdm import tqdm
10 | 
11 | 
12 | class ColorizationDataset(BaseDataset):
13 |     """This dataset class can load a set of natural images in RGB, and convert RGB format into (L, ab) pairs in Lab color space."""
14 |     @staticmethod
15 |     def modify_commandline_options(parser, is_train):
16 |         """Add new dataset-specific options, and rewrite default values for existing options.
17 | 
18 |         Parameters:
19 |             parser          -- original option parser
20 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
21 | 
22 |         Returns:
23 |             the modified parser.
24 | 
25 |         By default, the number of channels for input image is 1 (L) and
26 |         the number of channels for output image is 2 (ab). The direction is from A to B.
27 |         """
28 |         parser.set_defaults(input_nc=1, output_nc=2, direction='AtoB')
29 |         return parser
30 | 
31 |     def __init__(self, opt):
32 |         """Initialize this dataset class.
33 | 
34 |         Parameters:
35 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
36 |         """
37 |         BaseDataset.__init__(self, opt)
38 |         self.dir = os.path.join(opt.dataroot, opt.phase)
39 |         self.AB_paths = [[self.opt.targetImage_path, self.opt.referenceImage_path]]
40 |         self.ab_constant = np.load('./doc/ab_constant_filter.npy')
41 |         self.transform_A = get_transform(self.opt, convert=False)
42 |         self.transform_R = get_transform(self.opt, convert=False, must_crop=True)
43 |         assert(opt.input_nc == 1 and opt.output_nc == 2)
44 | 
45 |     def __getitem__(self, index):
46 |         path_A, path_R = self.AB_paths[index]
47 |         im_A_l, im_A_ab, _ = self.process_img(path_A, self.transform_A)
48 |         im_R_l, im_R_ab, hist = self.process_img(path_R, self.transform_R)
49 | 
50 |         im_dict = {
51 |             'A_l': im_A_l,
52 |             'A_ab': im_A_ab,
53 |             'R_l': im_R_l,
54 |             'R_ab': im_R_ab,
55 |             'ab': self.ab_constant,
56 |             'hist': hist,
57 |             'A_paths': path_A
58 |         }
59 |         return im_dict
60 | 
61 | 
62 |     def process_img(self, im_path, transform):
63 | 
64 |         weights_index = np.load('./doc/weight_index.npy')  # indices of the ab bins kept for the histogram
65 | 
66 |         im = Image.open(im_path).convert('RGB')
67 |         im = transform(im)
68 |         im = self.__scale_width(im, 256)
69 |         im = np.array(im)
70 |         im = im[:16 * int(im.shape[0] / 16.0), :16 * int(im.shape[1] / 16.0), :]
71 |         l_ts, ab_ts, gt_keys = [], [], []
72 |         hist_total_new = np.zeros((441,), dtype=np.float32)
73 |         for ratio in [0.25, 0.5, 1]:
74 |             if ratio == 1:
75 |                 im_ratio = im
76 |             else:
77 |                 im_ratio = cv2.resize(im, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_AREA)
78 |             lab = color.rgb2lab(im_ratio).astype(np.float32)
79 | 
80 |             if ratio == 1:
81 |                 ab_index_1 = np.round(lab[:, :, 1:] / 110.0 / 0.1) + 10.0
82 |                 keys_t = ab_index_1[:, :, 0] * 21 + ab_index_1[:, :, 1]
83 |                 keys_t_flatten = keys_t.flatten().astype(np.int32)
84 |                 dict_counter = dict(Counter(keys_t_flatten))
85 |                 for k, v in dict_counter.items():
86 |                     hist_total_new[k] += v
87 | 
88 |                 hist = hist_total_new[weights_index]
89 |                 hist = hist / np.sum(hist)
90 | 
91 |             lab_t = transforms.ToTensor()(lab)
92 |             l_t = lab_t[[0], ...] / 50.0 - 1.0
93 |             ab_t = lab_t[[1, 2], ...] / 110.0
94 |             l_ts.append(l_t)
95 |             ab_ts.append(ab_t)
96 | 
97 |         return l_ts, ab_ts, hist
98 | 
99 | 
100 |     def __scale_width(self, img, target_width, method=Image.BICUBIC):
101 |         ow, oh = img.size
102 |         if ow <= oh:
103 |             if (ow == target_width):
104 |                 return img
105 |             w = target_width
106 |             h = int(target_width * oh / ow)
107 |         else:
108 |             if (oh == target_width):
109 |                 return img
110 |             h = target_width
111 |             w = int(target_width * ow / oh)
112 |         return img.resize((w, h), method)
113 | 
114 |     def __len__(self):
115 |         """Return the total number of images in the dataset."""
116 |         return len(self.AB_paths)
--------------------------------------------------------------------------------
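The 441-bin histogram above comes from quantizing the ab plane into a 21 x 21 grid: each channel, roughly in [-110, 110], is divided by 110, binned at 0.1 resolution, and shifted by +10, so an (a, b) pair maps to bin `a_idx * 21 + b_idx`. A self-contained sketch of the same arithmetic, independent of the dataset class:

```python
import numpy as np

def ab_to_bin(a, b):
    """Map a Lab ab pair (roughly in [-110, 110]) to one of 441 = 21 * 21 bins,
    mirroring the quantization in ColorizationDataset.process_img."""
    a_idx = np.round(a / 110.0 / 0.1) + 10.0   # 0..20
    b_idx = np.round(b / 110.0 / 0.1) + 10.0   # 0..20
    return int(a_idx * 21 + b_idx)

print(ab_to_bin(0.0, 0.0))       # 220 -> the neutral (gray) bin at the grid center
print(ab_to_bin(-110.0, 110.0))  # 20  -> an extreme corner of the ab plane
```

The dense 441-bin count is then reindexed through `doc/weight_index.npy` (presumably selecting the 198 in-gamut bins that `--bias_input_nc` expects) and normalized to sum to 1 before being fed to the network as `hist`.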
/doc/ab_constant_filter.npy:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/wangyins/transformer-for-image-colorization/38772b248e56e0449fc367addc5cd82d927cf160/doc/ab_constant_filter.npy
--------------------------------------------------------------------------------
/doc/weight_index.npy:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/wangyins/transformer-for-image-colorization/38772b248e56e0449fc367addc5cd82d927cf160/doc/weight_index.npy
--------------------------------------------------------------------------------
/imgs/reference.JPEG:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/wangyins/transformer-for-image-colorization/38772b248e56e0449fc367addc5cd82d927cf160/imgs/reference.JPEG
--------------------------------------------------------------------------------
/imgs/target.JPEG:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/wangyins/transformer-for-image-colorization/38772b248e56e0449fc367addc5cd82d927cf160/imgs/target.JPEG
--------------------------------------------------------------------------------
/imgs/visual.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/wangyins/transformer-for-image-colorization/38772b248e56e0449fc367addc5cd82d927cf160/imgs/visual.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 | 
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 |     -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 |     -- <set_input>: unpack data from dataset and apply preprocessing.
7 |     -- <forward>: produce intermediate results.
8 |     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 | 
11 | In the function <__init__>, you need to define four lists:
12 |     -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 |     -- self.model_names (str list): define networks used in our training.
14 |     -- self.visual_names (str list): specify the images that you want to display and save.
15 |     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them.
16 | 
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See 'main_model.py' and 'colorization_model.py' for concrete examples.
19 | """
20 | 
21 | import importlib
22 | from models.base_model import BaseModel
23 | 
24 | 
25 | def find_model_using_name(model_name):
26 |     """Import the module "models/[model_name]_model.py".
27 | 
28 |     In the file, the class called ModelNameModel() will
29 |     be instantiated. It has to be a subclass of BaseModel,
30 |     and it is case-insensitive.
31 |     """
32 |     model_filename = "models." + model_name + "_model"
33 |     modellib = importlib.import_module(model_filename)
34 |     model = None
35 |     target_model_name = model_name.replace('_', '') + 'model'
36 |     for name, cls in modellib.__dict__.items():
37 |         if name.lower() == target_model_name.lower() \
38 |            and issubclass(cls, BaseModel):
39 |             model = cls
40 | 
41 |     if model is None:
42 |         print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 |         exit(0)
44 | 
45 |     return model
46 | 
47 | 
48 | def get_option_setter(model_name):
49 |     """Return the static method <modify_commandline_options> of the model class."""
50 |     model_class = find_model_using_name(model_name)
51 |     return model_class.modify_commandline_options
52 | 
53 | 
54 | def create_model(opt):
55 |     """Create a model given the option.
56 | 
57 |     This function wraps the model class found by find_model_using_name.
58 |     This is the main interface between this package and 'train.py'/'test.py'
59 | 
60 |     Example:
61 |         >>> from models import create_model
62 |         >>> model = create_model(opt)
63 |     """
64 |     model = find_model_using_name(opt.model)
65 |     instance = model(opt)
66 |     print("model [%s] was created" % type(instance).__name__)
67 |     return instance
--------------------------------------------------------------------------------
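Putting the two factories together, this is roughly how test.py presumably drives the `data` and `models` packages at test time; the exact wiring of test.py is not shown in this dump, so treat the sequence below as a hedged sketch (`TestOptions` comes from options/test_options.py).

```python
from options.test_options import TestOptions
from data import create_dataset
from models import create_model

opt = TestOptions().parse()     # parses flags, fills in model/dataset defaults
dataset = create_dataset(opt)   # CustomDatasetDataLoader behind the scenes
model = create_model(opt)       # e.g. ColorizationModel for --model colorization
model.setup(opt)                # loads '<epoch>_net_G.pth' from checkpoints_dir/name
model.eval()

for data in dataset:
    model.set_input(data)
    model.test()                # forward() under torch.no_grad() + compute_visuals()
    visuals = model.get_current_visuals()
```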
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 | from . import networks
6 | 
7 | 
8 | class BaseModel(ABC):
9 |     """This class is an abstract base class (ABC) for models.
10 |     To create a subclass, you need to implement the following five functions:
11 |         -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12 |         -- <set_input>: unpack data from dataset and apply preprocessing.
13 |         -- <forward>: produce intermediate results.
14 |         -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15 |         -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16 |     """
17 | 
18 |     def __init__(self, opt):
19 |         """Initialize the BaseModel class.
20 | 
21 |         Parameters:
22 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
23 | 
24 |         When creating your custom class, you need to implement your own initialization.
25 |         In this function, you should first call <BaseModel.__init__(self, opt)>
26 |         Then, you need to define four lists:
27 |             -- self.loss_names (str list): specify the training losses that you want to plot and save.
28 |             -- self.model_names (str list): define networks used in our training.
29 |             -- self.visual_names (str list): specify the images that you want to display and save.
30 |             -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them.
31 |         """
32 |         self.opt = opt
33 |         self.gpu_ids = opt.gpu_ids
34 |         self.isTrain = opt.isTrain
35 |         self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
36 |         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
37 |         if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38 |             torch.backends.cudnn.benchmark = True
39 |         self.loss_names = []
40 |         self.log_names = []
41 |         self.model_names = []
42 |         self.visual_names = []
43 |         self.optimizers = []
44 |         self.image_paths = []
45 |         self.metric = 0  # used for learning rate policy 'plateau'
46 | 
47 |     @staticmethod
48 |     def modify_commandline_options(parser, is_train):
49 |         """Add new model-specific options, and rewrite default values for existing options.
50 | 
51 |         Parameters:
52 |             parser          -- original option parser
53 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
54 | 
55 |         Returns:
56 |             the modified parser.
57 |         """
58 |         return parser
59 | 
60 |     @abstractmethod
61 |     def set_input(self, input):
62 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
63 | 
64 |         Parameters:
65 |             input (dict): includes the data itself and its metadata information.
66 |         """
67 |         pass
68 | 
69 |     @abstractmethod
70 |     def forward(self):
71 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
72 |         pass
73 | 
74 |     def setup(self, opt):
75 |         """Load and print networks; create schedulers
76 | 
77 |         Parameters:
78 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
79 |         """
80 |         if self.isTrain:
81 |             self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
82 |         if not self.isTrain or opt.continue_train:
83 |             load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
84 |             self.load_networks(load_suffix)
85 |         self.print_networks(opt.verbose)
86 | 
87 |     def eval(self):
88 |         """Make models eval mode during test time"""
89 |         for name in self.model_names:
90 |             if isinstance(name, str):
91 |                 net = getattr(self, 'net' + name)
92 |                 net.eval()
93 | 
94 |     def test(self):
95 |         """Forward function used in test time.
96 | 
97 |         This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
98 |         It also calls <compute_visuals> to produce additional visualization results
99 |         """
100 |         with torch.no_grad():
101 |             self.forward()
102 |             self.compute_visuals()
103 | 
104 |     def compute_visuals(self):
105 |         """Calculate additional output images for visdom and HTML visualization"""
106 |         pass
107 | 
108 |     def get_image_paths(self):
109 |         """Return image paths that are used to load current data"""
110 |         return self.image_paths
111 | 
112 |     def update_learning_rate(self):
113 |         """Update learning rates for all the networks; called at the end of every epoch"""
114 |         for scheduler in self.schedulers:
115 |             if self.opt.lr_policy == 'plateau':
116 |                 scheduler.step(self.metric)
117 |             else:
118 |                 scheduler.step()
119 | 
120 |         lr = self.optimizers[0].param_groups[0]['lr']
121 |         print('learning rate = %.7f' % lr)
122 | 
123 |     def get_current_visuals(self):
124 |         """Return visualization images. train.py will display these images with visdom, and save the images to an HTML file"""
125 |         visual_ret = OrderedDict()
126 |         for name in self.visual_names:
127 |             if isinstance(name, str):
128 |                 vis = getattr(self, name)
129 |                 if isinstance(vis, list):
130 |                     for i in range(len(vis)):
131 |                         visual_ret[name + '_' + str(i+1)] = vis[i]
132 |                 else:
133 |                     visual_ret[name] = vis
134 |         return visual_ret
135 | 
136 |     def get_current_losses(self):
137 |         """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
138 |         errors_ret = OrderedDict()
139 |         for name in self.loss_names:
140 |             if isinstance(name, str):
141 |                 loss = getattr(self, 'loss_' + name)
142 |                 if isinstance(loss, list):
143 |                     for i in range(len(loss)):
144 |                         errors_ret[name + '_' + str(i+1)] = float(loss[i])
145 |                 else:
146 |                     errors_ret[name] = float(loss)  # float(...) works for both scalar tensor and float number
147 |         return errors_ret
148 | 
149 |     def get_current_log(self):
150 |         ret = self.get_current_losses()
151 |         for name in self.log_names:
152 |             if isinstance(name, str):
153 |                 ret[name] = float(getattr(self, name))  # float(...) works for both scalar tensor and float number
154 |         return ret
155 | 
156 |     def save_networks(self, epoch):
157 |         """Save all the networks to the disk.
158 | 159 | Parameters: 160 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 161 | """ 162 | for name in self.model_names: 163 | if isinstance(name, str): 164 | net = getattr(self, 'net' + name) 165 | if isinstance(net, list): 166 | for i in range(len(net)): 167 | save_filename = '%s_net_%s_%d.pth' % (epoch, name, i+1) 168 | save_path = os.path.join(self.save_dir, save_filename) 169 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 170 | torch.save(net[i].module.cpu().state_dict(), save_path) 171 | net[i].cuda(self.gpu_ids[0]) 172 | else: 173 | torch.save(net[i].cpu().state_dict(), save_path) 174 | else: 175 | save_filename = '%s_net_%s.pth' % (epoch, name) 176 | save_path = os.path.join(self.save_dir, save_filename) 177 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 178 | torch.save(net.module.cpu().state_dict(), save_path) 179 | net.cuda(self.gpu_ids[0]) 180 | else: 181 | torch.save(net.cpu().state_dict(), save_path) 182 | 183 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 184 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 185 | key = keys[i] 186 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 187 | if module.__class__.__name__.startswith('InstanceNorm') and \ 188 | (key == 'running_mean' or key == 'running_var'): 189 | if getattr(module, key) is None: 190 | state_dict.pop('.'.join(keys)) 191 | if module.__class__.__name__.startswith('InstanceNorm') and \ 192 | (key == 'num_batches_tracked'): 193 | state_dict.pop('.'.join(keys)) 194 | else: 195 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 196 | 197 | def load_networks(self, epoch): 198 | """Load all the networks from the disk. 199 | 200 | Parameters: 201 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 202 | """ 203 | for name in self.model_names: 204 | if isinstance(name, str): 205 | net = getattr(self, 'net' + name) 206 | if isinstance(net, list): 207 | for i in range(len(net)): 208 | load_filename = '%s_net_%s_%d.pth' % (epoch, name, i+1) 209 | load_path = os.path.join(self.save_dir, load_filename) 210 | net_i = net[i] 211 | if isinstance(net_i, torch.nn.DataParallel): 212 | net_i = net_i.module 213 | print('loading the model from %s' % load_path) 214 | # if you are using PyTorch newer than 0.4 (e.g., built from 215 | # GitHub source), you can remove str() on self.device 216 | state_dict = torch.load(load_path, map_location=str(self.device)) 217 | if hasattr(state_dict, '_metadata'): 218 | del state_dict._metadata 219 | 220 | # patch InstanceNorm checkpoints prior to 0.4 221 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 222 | self.__patch_instance_norm_state_dict(state_dict, net_i, key.split('.')) 223 | net_i.load_state_dict(state_dict) 224 | else: 225 | load_filename = '%s_net_%s.pth' % (epoch, name) 226 | load_path = os.path.join(self.save_dir, load_filename) 227 | if isinstance(net, torch.nn.DataParallel): 228 | net = net.module 229 | print('loading the model from %s' % load_path) 230 | # if you are using PyTorch newer than 0.4 (e.g., built from 231 | # GitHub source), you can remove str() on self.device 232 | state_dict = torch.load(load_path, map_location=str(self.device)) 233 | if hasattr(state_dict, '_metadata'): 234 | del state_dict._metadata 235 | 236 | # patch InstanceNorm checkpoints prior to 0.4 237 | for key in list(state_dict.keys()): # need to copy keys here because we 
mutate in loop
238 |                     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
239 |                 net.load_state_dict(state_dict)
240 | 
241 |     def print_networks(self, verbose):
242 |         """Print the total number of parameters in the network and (if verbose) network architecture
243 | 
244 |         Parameters:
245 |             verbose (bool) -- if verbose: print the network architecture
246 |         """
247 |         print('---------- Networks initialized -------------')
248 |         for name in self.model_names:
249 |             if isinstance(name, str):
250 |                 net = getattr(self, 'net' + name)
251 |                 if isinstance(net, list):
252 |                     for i in range(len(net)):
253 |                         num_params = 0
254 |                         for param in net[i].parameters():
255 |                             num_params += param.numel()
256 |                         if verbose:
257 |                             print(net[i])
258 |                         print('[Network %s_%d] Total number of parameters : %.3f M' % (name, i+1, num_params / 1e6))
259 |                 else:
260 |                     num_params = 0
261 |                     for param in net.parameters():
262 |                         num_params += param.numel()
263 |                     if verbose:
264 |                         print(net)
265 |                     print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
266 |         print('-----------------------------------------------')
267 | 
268 |     def set_requires_grad(self, nets, requires_grad=False):
269 |         """Set requires_grad=False for all the networks to avoid unnecessary computations
270 |         Parameters:
271 |             nets (network list)  -- a list of networks
272 |             requires_grad (bool) -- whether the networks require gradients or not
273 |         """
274 |         if not isinstance(nets, list):
275 |             nets = [nets]
276 |         for net in nets:
277 |             if net is not None:
278 |                 for param in net.parameters():
279 |                     param.requires_grad = requires_grad
--------------------------------------------------------------------------------
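Before the concrete models below, here is a hedged sketch of the smallest test-only subclass the `create_model` machinery could load: a hypothetical `models/identity_model.py` that does not ship with this repo, with illustrative names throughout.

```python
# models/identity_model.py -- illustrative sketch; this file does not exist in the repo.
from .base_model import BaseModel


class IdentityModel(BaseModel):
    """Found by find_model_using_name('identity'); passes the input through unchanged."""

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = []                     # nothing to plot at test time
        self.model_names = []                    # no networks to save/load
        self.visual_names = ['real', 'fake']     # picked up by get_current_visuals()

    def set_input(self, input):
        self.real = input['A'].to(self.device)   # 'A' is an assumed dataset key
        self.image_paths = input['A_paths']

    def forward(self):
        self.fake = self.real                    # identity "colorization"
```

Running `model = create_model(opt)`, `model.setup(opt)`, `model.test()` would then exercise the same eval/no_grad path that ColorizationModel below uses.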
/models/colorization_model.py:
--------------------------------------------------------------------------------
1 | from .main_model import MainModel
2 | import torch
3 | from skimage import color
4 | import numpy as np
5 | import cv2
6 | 
7 | 
8 | class ColorizationModel(MainModel):
9 | 
10 |     @staticmethod
11 |     def modify_commandline_options(parser, is_train=True):
12 |         MainModel.modify_commandline_options(parser, is_train)
13 |         parser.set_defaults(dataset_mode='colorization')
14 |         return parser
15 | 
16 |     def __init__(self, opt):
17 |         MainModel.__init__(self, opt)
18 |         self.visual_names = ['real_A_l_0', 'real_A_rgb', 'real_R_rgb', 'fake_R_rgb']
19 | 
20 |     def lab2rgb(self, L, AB):
21 |         AB2 = AB * 110.0
22 |         L2 = (L + 1.0) * 50.0
23 |         Lab = torch.cat([L2, AB2], dim=1)
24 |         Lab = Lab[0].data.cpu().float().numpy()
25 |         Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
26 |         rgb = color.lab2rgb(Lab) * 255
27 |         return rgb
28 | 
29 |     def tensor2gray(self, im):
30 |         im = im[0].data.cpu().float().numpy()
31 |         im = np.transpose(im.astype(np.float64), (1, 2, 0))
32 |         im = np.repeat(im, 3, axis=-1) * 255
33 |         return im
34 | 
35 |     def compute_visuals(self):
36 |         self.real_A_l_0 = self.real_A_l[-1]
37 |         self.real_A_rgb = self.lab2rgb(self.real_A_l[-1], self.real_A_ab[-1])
38 |         self.real_R_rgb = self.lab2rgb(self.real_R_l[-1], self.real_R_ab[-1])
39 |         self.real_R_rgb = cv2.resize(self.real_R_rgb, (self.real_A_rgb.shape[1], self.real_A_rgb.shape[0]))
40 |         self.fake_R_rgb = []
41 |         for i in range(3):
42 |             self.fake_R_rgb += [self.lab2rgb(self.real_A_l[i], self.fake_imgs[i])]
43 |             if i != 2:
44 |                 self.fake_R_rgb[i] = cv2.resize(self.fake_R_rgb[i], (self.real_A_rgb.shape[1], self.real_A_rgb.shape[0]))
45 | 
46 |     def compute_scores(self):
47 |         metrics = []
48 |         hr = self.real_R_histogram[-1].data.cpu().float().numpy().flatten()
49 |         hg = self.fake_R_histogram[-1].data.cpu().float().numpy().flatten()
50 |         intersect = cv2.compareHist(hr, hg, cv2.HISTCMP_INTERSECT)
51 |         metrics.append(intersect)
52 | 
53 |         return metrics
54 | 
--------------------------------------------------------------------------------
/models/main_model.py:
--------------------------------------------------------------------------------
1 | from .base_model import BaseModel
2 | from . import networks
3 | from util import util
4 | 
5 | 
6 | class MainModel(BaseModel):
7 |     @staticmethod
8 |     def modify_commandline_options(parser, is_train=True):
9 |         parser.set_defaults(norm='instance', dataset_mode='aligned')
10 |         return parser
11 | 
12 |     def __init__(self, opt):
13 | 
14 |         BaseModel.__init__(self, opt)
15 |         self.visual_names = ['real_A', 'fake_B', 'real_B']
16 |         self.model_names = ['G']
17 |         self.netG = networks.define_G(opt.input_nc, opt.bias_input_nc, opt.output_nc, opt.norm, opt.init_type,
18 |                                       opt.init_gain, self.gpu_ids)
19 |         self.convert = util.Convert(self.device)
20 | 
21 |     def set_input(self, input):
22 |         self.image_paths = input['A_paths']
23 |         self.ab_constant = input['ab'].to(self.device)
24 |         self.hist = input['hist'].to(self.device)
25 | 
26 |         self.real_A_l, self.real_A_ab, self.real_R_l, self.real_R_ab, self.real_R_histogram = [], [], [], [], []
27 |         for i in range(3):
28 |             self.real_A_l += input['A_l'][i].to(self.device).unsqueeze(0)  # unsqueeze(0) so that += appends the whole tensor
29 |             self.real_A_ab += input['A_ab'][i].to(self.device).unsqueeze(0)
30 |             self.real_R_l += input['R_l'][i].to(self.device).unsqueeze(0)
31 |             self.real_R_ab += input['R_ab'][i].to(self.device).unsqueeze(0)
32 |             self.real_R_histogram += [util.calc_hist(input['A_ab'][i].to(self.device), self.device)]
33 | 
34 |     def forward(self):
35 |         self.fake_imgs = self.netG(self.real_A_l[-1], self.real_R_l[-1], self.real_R_ab[0], self.hist,
36 |                                    self.ab_constant, self.device)
37 |         self.fake_R_histogram = []
38 |         for i in range(3):
39 |             self.fake_R_histogram += [util.calc_hist(self.fake_imgs[i], self.device)]
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 | import torch.nn.functional as F
7 | 
8 | 
9 | class Identity(nn.Module):
10 |     def forward(self, x):
11 |         return x
12 | 
13 | 
14 | def get_norm_layer(norm_type='instance'):
15 |     if norm_type == 'batch':
16 |         norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
17 |     elif norm_type == 'instance':
18 |         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
19 |     elif norm_type == 'none':
20 |         norm_layer = lambda x: Identity()
21 |     else:
22 |         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
23 |     return norm_layer
24 | 
25 | 
26 | def get_scheduler(optimizer, opt):
27 |     if opt.lr_policy == 'linear':
28 |         def lambda_rule(epoch):
29 |             lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
30 |             return lr_l
31 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
32 |     elif opt.lr_policy == 'step':
33 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
34 |     elif opt.lr_policy == 'plateau':
35 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
36 |     elif opt.lr_policy == 'cosine':
37 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
38 |     else:
39 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
40 |     return scheduler
41 | 
42 | 
43 | def init_weights(net, init_type='normal', init_gain=0.02):
44 |     def init_func(m):  # define the initialization function
45 |         classname = m.__class__.__name__
46 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
47 |             if init_type == 'normal':
48 |                 init.normal_(m.weight.data, 0.0, init_gain)
49 |             elif init_type == 'xavier':
50 |                 init.xavier_normal_(m.weight.data, gain=init_gain)
51 |             elif init_type == 'kaiming':
52 |                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
53 |             elif init_type == 'orthogonal':
54 |                 init.orthogonal_(m.weight.data, gain=init_gain)
55 |             else:
56 |                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
57 |             if hasattr(m, 'bias') and m.bias is not None:
58 |                 init.constant_(m.bias.data, 0.0)
59 |         elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
60 |             init.normal_(m.weight.data, 1.0, init_gain)
61 |             init.constant_(m.bias.data, 0.0)
62 | 
63 |     print('initialize network with %s' % init_type)
64 |     net.apply(init_func)  # apply the initialization function
65 | 
66 | 
67 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
68 |     if len(gpu_ids) > 0:
69 |         assert(torch.cuda.is_available())
70 |         net.to(gpu_ids[0])
71 |         net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
72 |     init_weights(net, init_type, init_gain=init_gain)
73 |     return net
74 | 
75 | 
76 | def define_G(input_nc, bias_input_nc, output_nc, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
77 |     norm_layer = get_norm_layer(norm_type=norm)
78 |     net = ColorNet(input_nc, bias_input_nc, output_nc, norm_layer=norm_layer)
79 | 
80 |     return init_net(net, init_type, init_gain, gpu_ids)
81 | 
82 | 
83 | class ResBlock(nn.Module):
84 |     def __init__(self, dim, norm_layer, use_dropout, use_bias):
85 |         super(ResBlock, self).__init__()
86 |         conv_block = [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
87 |         if use_dropout:
88 |             conv_block += [nn.Dropout(0.5)]
89 |         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias), norm_layer(dim)]
90 |         self.conv_block = nn.Sequential(*conv_block)
91 | 
92 |     def forward(self, x):
93 |         out = x + self.conv_block(x)  # add skip connections
94 |         return out
95 | 
96 | 
97 | class global_network(nn.Module):
98 |     def __init__(self, in_dim):
99 |         super(global_network, self).__init__()
100 |         model = [nn.Conv2d(in_dim, 512, kernel_size=1, padding=0), nn.ReLU(True)]
101 |         model += [nn.Conv2d(512, 512, kernel_size=1, padding=0), nn.ReLU(True)]
102 |         model += [nn.Conv2d(512, 512, kernel_size=1, padding=0), nn.ReLU(True)]
103 |         self.model = nn.Sequential(*model)
104 | 
105 |         self.model_1 = nn.Sequential(*[nn.Conv2d(512, 512, kernel_size=1, padding=0), nn.Sigmoid()])
106 | 
107 |     def forward(self, x):
108 |         x = self.model(x)
109 |         x1 = self.model_1(x)
110 | 
111 |         return x1
112 | 
113 | 
114 | class ref_network_align(nn.Module):
115 |     def __init__(self, norm_layer):
116 |         super(ref_network_align, self).__init__()
117 |         model1 = [nn.Conv2d(2, 64, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
118 |         model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(64)]
119 |         self.model1 = nn.Sequential(*model1)
120 |         model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
121 |         model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(128)]
122 |         self.model2 = nn.Sequential(*model2)
123 |         model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
124 |         model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(256)]  # fixed: was norm_layer(128), mismatching the 256 output channels
125 |         self.model3 = nn.Sequential(*model3)
126 |         model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
127 |         model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(512)]  # fixed: was norm_layer(256)
128 |         self.model4 = nn.Sequential(*model4)
129 | 
130 |     def forward(self, color, corr, H, W):
131 | 
132 |         color_flatten = color.view(color.shape[0], color.shape[1], -1)
133 |         align_color = torch.bmm(color_flatten, corr)
134 |         align_color_output = align_color.view(align_color.shape[0], align_color.shape[1], H, W)
135 | 
136 |         conv1 = self.model1(align_color_output)
137 |         align_color1 = self.model2(conv1)
138 |         align_color2 = self.model3(align_color1[:,:,::2,::2])
139 |         align_color3 = self.model4(align_color2[:,:,::2,::2])
140 | 
141 |         return align_color1, align_color2, align_color3
142 | 
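The `torch.bmm` in `ref_network_align.forward` above is the color-transfer step: `corr` is the (HW_ref x HW_target) attention matrix computed later in `ColorNet.forward` (cosine similarities, softmax over the reference axis at temperature 0.01), so each target position receives a convex combination of reference ab values. A standalone shape sketch with toy sizes and illustrative names:

```python
import torch

B, H, W = 1, 8, 8                     # target feature-map size (toy)
Hr, Wr = 8, 8                         # reference feature-map size (toy)

ref_ab = torch.randn(B, 2, Hr, Wr)    # reference chrominance, as passed to ref_network_align
corr = torch.softmax(torch.randn(B, Hr * Wr, H * W) / 0.01, dim=1)  # columns sum to 1

ab_flat = ref_ab.view(B, 2, -1)       # B x 2 x (Hr*Wr)
warped = torch.bmm(ab_flat, corr)     # B x 2 x (H*W): per-target-pixel color transfer
warped = warped.view(B, 2, H, W)      # aligned ab map fed into the conv stack above
```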
143 | 
144 | class ref_network_hist(nn.Module):
145 |     def __init__(self, norm_layer):
146 |         super(ref_network_hist, self).__init__()
147 |         model1 = [nn.Conv2d(2, 64, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
148 |         model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(64)]
149 |         self.model1 = nn.Sequential(*model1)
150 |         model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
151 |         model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(128)]
152 |         self.model2 = nn.Sequential(*model2)
153 |         model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
154 |         model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(256)]  # fixed: was norm_layer(128)
155 |         self.model3 = nn.Sequential(*model3)
156 |         model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
157 |         model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(512)]  # fixed: was norm_layer(256)
158 |         self.model4 = nn.Sequential(*model4)
159 | 
160 |     def forward(self, color):
161 | 
162 |         conv1 = self.model1(color)
163 |         align_color1 = self.model2(conv1)
164 |         align_color2 = self.model3(align_color1[:,:,::2,::2])
165 |         align_color3 = self.model4(align_color2[:,:,::2,::2])
166 | 
167 |         return align_color1, align_color2, align_color3
168 | 
169 | 
170 | class conf_feature_align(nn.Module):
171 |     def __init__(self):
172 |         super(conf_feature_align, self).__init__()
173 |         self.fc1 = nn.Sequential(*[nn.Conv1d(4096, 1024, kernel_size=1, stride=1, padding=0, bias=True), nn.ReLU(True)])
174 |         self.fc2 = nn.Sequential(*[nn.Conv1d(1024, 1, kernel_size=1, stride=1, padding=0, bias=True), nn.Sigmoid()])
175 |         self.dropout1 = nn.Dropout(0.1)
176 | 
177 |     def forward(self, x):
178 |         x1 = self.fc1(x)
179 |         x2 = self.dropout1(x1)
180 |         x3 = self.fc2(x2)
181 | 
182 |         return x3
183 | 
184 | 
185 | class conf_feature_hist(nn.Module):
186 |     def __init__(self):
187 |         super(conf_feature_hist, self).__init__()
188 |         self.fc1 = nn.Sequential(*[nn.Conv1d(4096, 1024, kernel_size=1, stride=1, padding=0, bias=True),
nn.ReLU(True)]) 189 | self.fc2 = nn.Sequential(*[nn.Conv1d(1024, 1, kernel_size=1, stride=1, padding=0, bias=True), nn.Sigmoid()]) 190 | self.dropout1 = nn.Dropout(0.1) 191 | 192 | def forward(self, x): 193 | x1 = self.fc1(x) 194 | x2 = self.dropout1(x1) 195 | x3 = self.fc2(x2) 196 | 197 | return x3 198 | 199 | 200 | class classify_network(nn.Module): 201 | def __init__(self): 202 | super(classify_network, self).__init__() 203 | self.maxpool = nn.AdaptiveMaxPool2d((1, 1)) 204 | self.fc = nn.Linear(512, 1000) 205 | 206 | def forward(self, x): 207 | x = self.maxpool(x) 208 | x = x.squeeze(-1).squeeze(-1) 209 | x = self.fc(x) 210 | return x 211 | 212 | 213 | class ColorNet(nn.Module): 214 | def __init__(self, input_nc, bias_input_nc, output_nc, norm_layer=nn.BatchNorm2d): 215 | super(ColorNet, self).__init__() 216 | self.input_nc = input_nc 217 | self.output_nc = output_nc 218 | use_bias = True 219 | 220 | model_head = [nn.Conv2d(input_nc, 32, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), 221 | norm_layer(32)] 222 | 223 | # Conv1 224 | model1=[nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),] 225 | model1+=[nn.ReLU(True),] 226 | model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),] 227 | model1+=[nn.ReLU(True),] 228 | model1+=[norm_layer(64),] 229 | 230 | # Conv2 231 | model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),] 232 | model2+=[nn.ReLU(True),] 233 | model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),] 234 | model2+=[nn.ReLU(True),] 235 | model2+=[norm_layer(128),] 236 | 237 | # Conv3 238 | model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),] 239 | model3+=[nn.ReLU(True),] 240 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),] 241 | model3+=[nn.ReLU(True),] 242 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),] 243 | model3+=[nn.ReLU(True),] 244 | model3+=[norm_layer(256),] 245 | 246 | # Conv4 247 | model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),] 248 | model4+=[nn.ReLU(True),] 249 | model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),] 250 | model4+=[nn.ReLU(True),] 251 | model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),] 252 | model4+=[nn.ReLU(True),] 253 | model4+=[norm_layer(512),] 254 | 255 | # Conv5 256 | model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),] 257 | model5+=[nn.ReLU(True),] 258 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),] 259 | model5+=[nn.ReLU(True),] 260 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),] 261 | model5+=[nn.ReLU(True),] 262 | model5+=[norm_layer(512),] 263 | 264 | # Conv6 265 | model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),] 266 | model6+=[nn.ReLU(True),] 267 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),] 268 | model6+=[nn.ReLU(True),] 269 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),] 270 | model6+=[nn.ReLU(True),] 271 | model6+=[norm_layer(512),] 272 | 273 | # Conv7 274 | model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 275 | model7+=[nn.ReLU(True),] 276 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 277 | 
model7+=[nn.ReLU(True),] 278 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 279 | model7+=[nn.ReLU(True),] 280 | model7+=[norm_layer(512),] 281 | 282 | model_hist=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),] 283 | model_hist+=[nn.ReLU(True),] 284 | model_hist+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] 285 | model_hist+=[nn.ReLU(True),] 286 | model_hist+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] 287 | model_hist+=[nn.ReLU(True),] 288 | 289 | model_hist+=[nn.Conv2d(256, 198, kernel_size=1, stride=1, padding=0, bias=True),] 290 | 291 | # ResBlock0 292 | resblock0_1 = [nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=use_bias), norm_layer(512), nn.ReLU(True)] 293 | self.resblock0_2 = ResBlock(512, norm_layer, False, use_bias) 294 | self.resblock0_3 = ResBlock(512, norm_layer, False, use_bias) 295 | 296 | # Conv8 297 | model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias)] 298 | 299 | model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),] 300 | 301 | model8=[nn.ReLU(True),] 302 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),] 303 | model8+=[nn.ReLU(True),] 304 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),] 305 | model8+=[nn.ReLU(True),] 306 | model8+=[norm_layer(256),] 307 | 308 | # ResBlock1 309 | resblock1_1 = [nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=use_bias), norm_layer(256), nn.ReLU(True)] 310 | self.resblock1_2 = ResBlock(256, norm_layer, False, use_bias) 311 | self.resblock1_3 = ResBlock(256, norm_layer, False, use_bias) 312 | 313 | # Conv9 314 | model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias),] 315 | 316 | model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),] 317 | 318 | model9=[nn.ReLU(True),] 319 | model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),] 320 | model9+=[nn.ReLU(True),] 321 | model9+=[norm_layer(128),] 322 | 323 | # ResBlock2 324 | resblock2_1 = [nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=use_bias), norm_layer(128), nn.ReLU(True)] 325 | self.resblock2_2 = ResBlock(128, norm_layer, False, use_bias) 326 | self.resblock2_3 = ResBlock(128, norm_layer, False, use_bias) 327 | 328 | # Conv10 329 | model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias),] 330 | 331 | model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),] 332 | 333 | model10=[nn.ReLU(True),] 334 | model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=use_bias),] 335 | model10+=[nn.LeakyReLU(negative_slope=.2),] 336 | 337 | # Conv Global 338 | self.global_network = global_network(bias_input_nc) 339 | 340 | # conf feature 341 | self.conf_feature_align = conf_feature_align() 342 | self.conf_feature_hist = conf_feature_hist() 343 | 344 | # Conv Ref 345 | self.ref_network_align = ref_network_align(norm_layer) 346 | self.ref_network_hist = ref_network_hist(norm_layer) 347 | 348 | # classification 349 | self.classify_network = classify_network() 350 | 351 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') 352 | self.softmax_gate = nn.Softmax(dim=1) 353 | self.softmax = nn.Softmax(dim=-1) 354 | self.key_dataset = torch.eye(bias_input_nc) 355 | 356 | model_tail_1 = [nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, 
bias=use_bias), nn.LeakyReLU(negative_slope=.2)] 357 | model_tail_2 = [nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=use_bias), nn.LeakyReLU(negative_slope=.2)] 358 | model_tail_3 = [nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=use_bias), nn.LeakyReLU(negative_slope=.2)] 359 | 360 | model_out1 = [nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), nn.Tanh()] 361 | model_out2 = [nn.Conv2d(64, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), nn.Tanh()] 362 | model_out3 = [nn.Conv2d(64, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), nn.Tanh()] 363 | 364 | self.model1 = nn.Sequential(*model1) 365 | self.model2 = nn.Sequential(*model2) 366 | self.model3 = nn.Sequential(*model3) 367 | self.model4 = nn.Sequential(*model4) 368 | self.model5 = nn.Sequential(*model5) 369 | self.model6 = nn.Sequential(*model6) 370 | self.model7 = nn.Sequential(*model7) 371 | self.model_hist = nn.Sequential(*model_hist) 372 | self.model8up = nn.Sequential(*model8up) 373 | self.model8 = nn.Sequential(*model8) 374 | self.model9up = nn.Sequential(*model9up) 375 | self.model9 = nn.Sequential(*model9) 376 | self.model10up = nn.Sequential(*model10up) 377 | self.model10 = nn.Sequential(*model10) 378 | self.model3short8 = nn.Sequential(*model3short8) 379 | self.model2short9 = nn.Sequential(*model2short9) 380 | self.model1short10 = nn.Sequential(*model1short10) 381 | self.resblock0_1 = nn.Sequential(*resblock0_1) 382 | self.resblock1_1 = nn.Sequential(*resblock1_1) 383 | self.resblock2_1 = nn.Sequential(*resblock2_1) 384 | self.model_out1 = nn.Sequential(*model_out1) 385 | self.model_out2 = nn.Sequential(*model_out2) 386 | self.model_out3 = nn.Sequential(*model_out3) 387 | self.model_head = nn.Sequential(*model_head) 388 | self.model_tail_1 = nn.Sequential(*model_tail_1) 389 | self.model_tail_2 = nn.Sequential(*model_tail_2) 390 | self.model_tail_3 = nn.Sequential(*model_tail_3) 391 | 392 | 393 | def forward(self, input, ref_input, ref_color, bias_input, ab_constant, device): 394 | 395 | # align branch 396 | in_conv = self.model_head(input) 397 | 398 | in_1 = self.model1(in_conv[:, :, ::2, ::2]) 399 | in_2 = self.model2(in_1[:, :, ::2, ::2]) 400 | in_3 = self.model3(in_2[:, :, ::2, ::2]) 401 | in_4 = self.model4(in_3[:, :, ::2, ::2]) 402 | in_5 = self.model5(in_4) 403 | in_6 = self.model6(in_5) 404 | 405 | ref_conv_head = self.model_head(ref_input) 406 | ref_1 = self.model1(ref_conv_head[:,:,::2,::2]) 407 | ref_2 = self.model2(ref_1[:, :, ::2, ::2]) 408 | ref_3 = self.model3(ref_2[:, :, ::2, ::2]) 409 | ref_4 = self.model4(ref_3[:, :, ::2, ::2]) 410 | ref_5 = self.model5(ref_4) 411 | ref_6 = self.model6(ref_5) 412 | 413 | t1 = F.interpolate(in_1, scale_factor=0.5, mode='bilinear') 414 | t2 = in_2 415 | t3 = F.interpolate(in_3, scale_factor=2, mode='bilinear') 416 | t4 = F.interpolate(in_4, scale_factor=4, mode='bilinear') 417 | t5 = F.interpolate(in_5, scale_factor=4, mode='bilinear') 418 | t6 = F.interpolate(in_6, scale_factor=4, mode='bilinear') 419 | t = torch.cat((t1, t2, t3, t4, t5, t6), dim=1) 420 | 421 | r1 = F.interpolate(ref_1, scale_factor=0.5, mode='bilinear') 422 | r2 = ref_2 423 | r3 = F.interpolate(ref_3, scale_factor=2, mode='bilinear') 424 | r4 = F.interpolate(ref_4, scale_factor=4, mode='bilinear') 425 | r5 = F.interpolate(ref_5, scale_factor=4, mode='bilinear') 426 | r6 = F.interpolate(ref_6, scale_factor=4, mode='bilinear') 427 | r = torch.cat((r1, r2, r3, r4, r5, r6), dim=1) 428 | 429 | 
input_T_flatten = t.view(t.shape[0], t.shape[1], -1).permute(0, 2, 1) 430 | input_R_flatten = r.view(r.shape[0], r.shape[1], -1).permute(0, 2, 1) 431 | input_T_flatten = input_T_flatten / torch.norm(input_T_flatten, p=2, dim=-1, keepdim=True) 432 | input_R_flatten = input_R_flatten / torch.norm(input_R_flatten, p=2, dim=-1, keepdim=True) 433 | corr = torch.bmm(input_R_flatten, input_T_flatten.permute(0, 2, 1)) 434 | 435 | corr = F.softmax(corr / 0.01, dim=1) 436 | 437 | # Align branch confidence map learning 438 | align_1, align_2, align_3 = self.ref_network_align(ref_color, corr, t2.shape[2], t2.shape[3]) 439 | conf_align = self.conf_feature_align(corr) 440 | conf_align = conf_align.view(conf_align.shape[0], 1, t2.shape[2], t2.shape[3]) 441 | conf_aligns = 5.0 * conf_align 442 | 443 | # Histogram branch confidence map learning 444 | conf_hist = self.conf_feature_hist(corr) 445 | conf_hist = conf_hist.view(conf_hist.shape[0], 1, t2.shape[2], t2.shape[3]) 446 | conf_hists = 5.0 * conf_hist 447 | 448 | # Gate softmax operation on confidence map 449 | conf_total = torch.cat((conf_aligns, conf_hists), dim=1) 450 | conf_softmax = self.softmax_gate(conf_total) 451 | 452 | conf_1_align = conf_softmax[:, :1, :, :] 453 | conf_1_hist = conf_softmax[:, 1:, :, :] 454 | conf_2_align = conf_1_align[:,:,::2,::2] 455 | conf_3_align = conf_2_align[:,:,::2,::2] 456 | conf_2_hist = conf_1_hist[:,:,::2,::2] 457 | conf_3_hist = conf_2_hist[:,:,::2,::2] 458 | 459 | # hist branch 460 | bias_input = bias_input.view(input.shape[0], -1, 1, 1) 461 | 462 | conv_head = self.model_head(input) 463 | conv1_2 = self.model1(conv_head[:, :, ::2, ::2]) 464 | conv2_2 = self.model2(conv1_2[:,:,::2,::2]) 465 | conv3_3 = self.model3(conv2_2[:,:,::2,::2]) 466 | conv4_3 = self.model4(conv3_3[:,:,::2,::2]) 467 | conv5_3 = self.model5(conv4_3) 468 | conv6_3 = self.model6(conv5_3) 469 | 470 | class_output = self.classify_network(conv6_3) 471 | 472 | # hist align 473 | conv_global1 = self.global_network(bias_input) 474 | conv_global1_repeat = conv_global1.expand_as(conv6_3) 475 | conv_global1_add = conv6_3 * conv_global1_repeat 476 | conv7_3 = self.model7(conv_global1_add) 477 | color_reg = self.model_hist(conv7_3) 478 | 479 | # calculate attention matrix for histogram branch 480 | key_datasets = self.key_dataset.unsqueeze(0).to(device) 481 | attn_weights = torch.bmm(color_reg.flatten(2).permute(0, 2, 1), key_datasets) 482 | value = ab_constant.type_as(color_reg) 483 | attn_weights_softmax = self.softmax(attn_weights * 100.0) 484 | conv_total_out = torch.bmm(attn_weights_softmax, value).permute(0, 2, 1) 485 | conv_total_out_re = conv_total_out.view(color_reg.shape[0], -1, color_reg.shape[2], color_reg.shape[3]) 486 | conv_total_out_up = self.upsample(conv_total_out_re) 487 | 488 | hist_1, hist_2, hist_3 = self.ref_network_hist(conv_total_out_up) 489 | 490 | # encoder1 491 | conv6_3_global = conv6_3 + align_3 * conf_3_align + hist_3 * conf_3_hist 492 | conv7_resblock1 = self.resblock0_1(conv6_3_global) 493 | conv7_resblock2 = self.resblock0_2(conv7_resblock1) 494 | conv7_resblock3 = self.resblock0_3(conv7_resblock2) 495 | conv8_up = self.model8up(conv7_resblock3) + self.model3short8(conv3_3) 496 | conv8_3 = self.model8(conv8_up) 497 | conv_tail_1 = self.model_tail_1(conv8_3) 498 | fake_img1 = self.model_out1(conv_tail_1) 499 | 500 | # encoder2 501 | conv8_3_global = conv8_3 + align_2 * conf_2_align + hist_2 * conf_2_hist 502 | conv8_resblock1 = self.resblock1_1(conv8_3_global) 503 | conv8_resblock2 = self.resblock1_2(conv8_resblock1) 
504 |         conv8_resblock3 = self.resblock1_3(conv8_resblock2)
505 |         conv9_up = self.model9up(conv8_resblock3) + self.model2short9(conv2_2)
506 |         conv9_3 = self.model9(conv9_up)
507 |         conv_tail_2 = self.model_tail_2(conv9_3)
508 |         fake_img2 = self.model_out2(conv_tail_2)
509 | 
510 |         # decoder stage 3 (finest scale)
511 |         conv9_3_global = conv9_3 + align_1 * conf_1_align + hist_1 * conf_1_hist
512 |         conv9_resblock1 = self.resblock2_1(conv9_3_global)
513 |         conv9_resblock2 = self.resblock2_2(conv9_resblock1)
514 |         conv9_resblock3 = self.resblock2_3(conv9_resblock2)
515 |         conv10_up = self.model10up(conv9_resblock3) + self.model1short10(conv1_2)
516 |         conv10_2 = self.model10(conv10_up)
517 |         conv_tail_3 = self.model_tail_3(conv10_2)
518 |         fake_img3 = self.model_out3(conv_tail_3)
519 | 
520 |         return [fake_img1, fake_img2, fake_img3]
521 | 
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes the option modules: test options, and base options (shared by test)."""
2 | 
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import models
6 | import data
7 | 
8 | 
9 | class BaseOptions():
10 |     """This class defines options used during both training and test time.
11 | 
12 |     It also implements several helper functions such as parsing, printing, and saving the options.
13 |     It also gathers additional options defined in the <modify_commandline_options> functions of both the dataset class and the model class.
14 |     """
15 | 
16 |     def __init__(self):
17 |         """Reset the class; indicates the class hasn't been initialized"""
18 |         self.initialized = False
19 | 
20 |     def initialize(self, parser):
21 |         """Define the common options that are used in both training and test."""
22 |         # basic parameters
23 |         parser.add_argument('--dataroot', type=str, default='./dataset/',
24 |                             help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
25 |         parser.add_argument('--name', type=str, default='imagenet', help='name of the experiment. It decides where to store samples and models')
26 |         parser.add_argument('--gpu_ids', type=str, default='1', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
27 |         parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints/', help='models are saved here')
28 |         # model parameters
29 |         parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels: 3 for RGB and 1 for grayscale')
30 |         parser.add_argument('--bias_input_nc', type=int, default=198, help='# of reference image histogram bins')
31 |         parser.add_argument('--output_nc', type=int, default=2, help='# of output image channels: 2 for the ab color channels')
32 |         parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
33 |         parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
34 |         parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
35 |         parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
36 |         # dataset parameters
37 |         parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
38 |         parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
39 |         parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
40 |         parser.add_argument('--load_size', type=int, default=288, help='scale images to this size')
41 |         parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
42 |         parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
43 |         parser.add_argument('--preprocess', type=str, default='none', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
44 |         parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
45 |         parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
46 |         parser.add_argument('--targetImage_path', type=str, default='./imgs/target.JPEG', help='path to the grayscale target image')
47 |         parser.add_argument('--referenceImage_path', type=str, default='./imgs/reference.JPEG', help='path to the color reference image')
48 |         # additional parameters
49 |         parser.add_argument('--use_D', action='store_true', help='whether to use a discriminator or not')
50 |         parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
51 |         parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
52 |         parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
53 |         parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
54 |         self.initialized = True
55 |         return parser
56 | 
57 |     def gather_options(self):
58 |         """Initialize our parser with the basic options (only once).
59 |         Add additional model-specific and dataset-specific options.
60 |         These options are defined in the <modify_commandline_options> function
61 |         in model and dataset classes.
62 | """ 63 | if not self.initialized: # check if it has been initialized 64 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 65 | parser = self.initialize(parser) 66 | 67 | # get the basic options 68 | opt, _ = parser.parse_known_args() 69 | 70 | # modify model-related parser options 71 | model_name = opt.model 72 | model_option_setter = models.get_option_setter(model_name) 73 | parser = model_option_setter(parser, self.isTrain) 74 | opt, _ = parser.parse_known_args() # parse again with new defaults 75 | 76 | # modify dataset-related parser options 77 | dataset_name = opt.dataset_mode 78 | dataset_option_setter = data.get_option_setter(dataset_name) 79 | parser = dataset_option_setter(parser, self.isTrain) 80 | 81 | # save and return the parser 82 | self.parser = parser 83 | return parser.parse_args() 84 | 85 | def print_options(self, opt): 86 | """Print and save options 87 | 88 | It will print both current options and default values(if different). 89 | It will save options into a text file / [checkpoints_dir] / opt.txt 90 | """ 91 | message = '' 92 | message += '----------------- Options ---------------\n' 93 | for k, v in sorted(vars(opt).items()): 94 | comment = '' 95 | default = self.parser.get_default(k) 96 | if v != default: 97 | comment = '\t[default: %s]' % str(default) 98 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 99 | message += '----------------- End -------------------' 100 | print(message) 101 | 102 | # save to the disk 103 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 104 | util.mkdirs(expr_dir) 105 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 106 | with open(file_name, 'wt') as opt_file: 107 | opt_file.write(message) 108 | opt_file.write('\n') 109 | 110 | def parse(self): 111 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 112 | opt = self.gather_options() 113 | opt.isTrain = self.isTrain # train or test 114 | 115 | # process opt.suffix 116 | if opt.suffix: 117 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 118 | opt.name = opt.name + suffix 119 | 120 | self.print_options(opt) 121 | 122 | # set gpu ids 123 | str_ids = opt.gpu_ids.split(',') 124 | opt.gpu_ids = [] 125 | for str_id in str_ids: 126 | id = int(str_id) 127 | if id >= 0: 128 | opt.gpu_ids.append(id) 129 | if len(opt.gpu_ids) > 0: 130 | torch.cuda.set_device(opt.gpu_ids[0]) 131 | opt.A = 2 * 110.0 / 10.0 + 1 132 | 133 | self.opt = opt 134 | return self.opt 135 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 16 | # Dropout and Batchnorm has different behavioir during training and test. 
17 |         parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 |         parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 |         # rewrite default values
20 |         parser.set_defaults(model='colorization')
21 |         # To avoid cropping, the load_size should be the same as crop_size
22 |         # parser.set_defaults(load_size=parser.get_default('crop_size'))
23 |         self.isTrain = False
24 |         return parser
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.10.0
2 | astor==0.8.1
3 | axial-positional-embedding==0.2.1
4 | certifi==2020.6.20
5 | chardet==3.0.4
6 | cycler==0.10.0
7 | decorator==4.4.2
8 | dominate==2.5.1
9 | future==0.18.2
10 | gast==0.4.0
11 | google-pasta==0.2.0
12 | graphviz==0.14.2
13 | grpcio==1.33.1
14 | h5py==2.10.0
15 | idna==2.10
16 | imageio==2.9.0
17 | importlib-metadata==2.0.0
18 | jsonpatch==1.26
19 | jsonpointer==2.0
20 | Keras==2.2.4
21 | Keras-Applications==1.0.8
22 | Keras-Preprocessing==1.1.2
23 | kiwisolver==1.2.0
24 | local-attention==1.0.2
25 | Markdown==3.3.2
26 | matplotlib==3.3.0
27 | mkl-fft==1.1.0
28 | mkl-random==1.1.1
29 | mkl-service==2.3.0
30 | networkx==2.4
31 | numpy==1.18.5
32 | olefile==0.46
33 | opencv-contrib-python==3.4.2.16
34 | opencv-python==3.4.2.16
35 | Pillow==7.2.0
36 | product-key-memory==0.1.10
37 | protobuf==3.13.0
38 | pyheatmap==0.1.12
39 | pyparsing==2.4.7
40 | python-dateutil==2.8.1
41 | PyWavelets==1.1.1
42 | PyYAML==5.3.1
43 | pyzmq==19.0.1
44 | reformer-pytorch==1.1.3
45 | requests==2.24.0
46 | scikit-image==0.17.2
47 | scipy==1.2.1
48 | six==1.15.0
49 | spatial-correlation-sampler==0.3.0
50 | tensorboard==1.14.0
51 | tensorflow-estimator==1.14.0
52 | termcolor==1.1.0
53 | tifffile==2020.7.24
54 | torch==1.5.1
55 | torchfile==0.1.0
56 | torchvision==0.6.0a0+35d732a
57 | torchviz @ git+https://github.com/szagoruyko/pytorchviz@46add7f2c071b6d29fc3d56e9d2d21e1c0a3af1d
58 | tornado==6.0.4
59 | tqdm==4.48.0
60 | urllib3==1.25.10
61 | visdom==0.1.8.9
62 | websocket-client==0.57.0
63 | Werkzeug==1.0.1
64 | wrapt==1.12.1
65 | zipp==3.3.1
66 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from options.test_options import TestOptions
3 | from data import create_dataset
4 | from models import create_model
5 | from util.visualizer import save_images
6 | from util import html
7 | import numpy as np
8 | 
9 | 
10 | if __name__ == '__main__':
11 |     opt = TestOptions().parse()
12 |     opt.num_threads = 0        # test code only supports num_threads = 0
13 |     opt.batch_size = 1         # test code only supports batch_size = 1
14 |     opt.serial_batches = True  # disable data shuffling
15 |     opt.no_flip = True         # no flip augmentation at test time
16 |     opt.display_id = -1        # no visdom display; results are written to an HTML file instead
17 | 
18 |     dataset = create_dataset(opt)
19 |     model = create_model(opt)
20 |     model.setup(opt)
21 |     web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch))
22 |     webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
23 |     scores = []
24 |     if opt.eval:
25 |         model.eval()
26 |     for i, data in enumerate(dataset):
27 |         model.set_input(data)
28 |         model.test()
29 |         visuals = model.get_current_visuals()
30 |         img_path = model.get_image_paths()
31 |         metrics = model.compute_scores()
32 |         scores.extend(metrics)
33 |         print('processing (%04d)-th image... %s' % (i, img_path))
34 |         save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
35 |     webpage.save()
36 |     print('Histogram Intersection: %.4f' % np.mean(scores))
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | python3 -u test.py --targetImage_path ./imgs/target.JPEG --referenceImage_path ./imgs/reference.JPEG --gpu_ids 1
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 | 
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 | 
5 | 
6 | class HTML:
7 |     """This HTML class allows us to save images and write texts into a single HTML file.
8 | 
9 |     It consists of functions such as <add_header> (add a text header to the HTML file),
10 |     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 |     It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 |     """
13 | 
14 |     def __init__(self, web_dir, title, refresh=0):
15 |         """Initialize the HTML classes
16 | 
17 |         Parameters:
18 |             web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str)   -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 | 
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 |             with self.doc.head:
33 |                 meta(http_equiv="refresh", content=str(refresh))
34 | 
35 |     def get_image_dir(self):
36 |         """Return the directory that stores images"""
37 |         return self.img_dir
38 | 
39 |     def add_header(self, text):
40 |         """Insert a header to the HTML file
41 | 
42 |         Parameters:
43 |             text (str) -- the header text
44 |         """
45 |         with self.doc:
46 |             h3(text)
47 | 
48 |     def add_images(self, ims, txts, links, width=400):
49 |         """add images to the HTML file
50 | 
51 |         Parameters:
52 |             ims (str list)   -- a list of image paths
53 |             txts (str list)  -- a list of image names shown on the website
54 |             links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 |         """
56 |         self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
57 |         self.doc.add(self.t)
58 |         with self.t:
59 |             with tr():
60 |                 for im, txt, link in zip(ims, txts, links):
61 |                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 |                         with p():
63 |                             with a(href=os.path.join('images', link)):
64 |                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
65 |                             br()
66 |                             p(txt)
67 | 
68 |     def add_more_images(self, ims, links, width=400):
69 |         """add two rows of images to the HTML file
70 | 
71 |         Parameters:
72 |             ims (str list)   -- a list of image paths
73 |             links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
74 |             (unlike add_images, this variant shows no per-image captions)
75 |         """
76 |         self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
77 |         self.doc.add(self.t)
78 |         with self.t:
79 |             for i in range(2):
80 |                 with tr():
81 |                     for im, link in zip(ims[i], links[i]):
82 |                         with td(style="word-wrap: break-word;", halign="center", valign="top"):
83 |                             with p():
84 |                                 with a(href=os.path.join('images', link)):
85 |                                     img(style="width:%dpx" %
width, src=os.path.join('images', im))
86 |                                 br()
87 | 
88 |     def save(self):
89 |         """save the current content to the HTML file"""
90 |         html_file = '%s/index.html' % self.web_dir
91 |         f = open(html_file, 'wt')
92 |         f.write(self.doc.render())
93 |         f.close()
94 | 
95 | 
96 | if __name__ == '__main__':  # we show an example usage here.
97 |     html = HTML('web/', 'test_html')
98 |     html.add_header('hello world')
99 | 
100 |     ims, txts, links = [], [], []
101 |     for n in range(4):
102 |         ims.append('image_%d.png' % n)
103 |         txts.append('text_%d' % n)
104 |         links.append('image_%d.png' % n)
105 |     html.add_images(ims, txts, links)
106 |     html.save()
107 | 
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions."""
2 | from __future__ import print_function
3 | import torch
4 | import numpy as np
5 | from scipy import linalg
6 | from PIL import Image
7 | from skimage import color
8 | import os
9 | 
10 | 
11 | def tensor2im(input_image, imtype=np.uint8):
12 |     """Converts a Tensor array into a numpy image array.
13 | 
14 |     Parameters:
15 |         input_image (tensor) -- the input image tensor array
16 |         imtype (type)        -- the desired type of the converted numpy array
17 |     """
18 |     if not isinstance(input_image, np.ndarray):
19 |         if isinstance(input_image, torch.Tensor):  # get the data from a variable
20 |             image_tensor = input_image.data
21 |         else:
22 |             return input_image
23 |         image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
24 |         if image_numpy.shape[0] == 1:  # grayscale to RGB
25 |             image_numpy = np.tile(image_numpy, (3, 1, 1))
26 |         image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
27 |     else:  # if it is a numpy array, do nothing
28 |         image_numpy = input_image
29 |     return image_numpy.astype(imtype)
30 | 
31 | 
32 | def save_image(image_numpy, image_path):
33 |     """Save a numpy image to the disk
34 | 
35 |     Parameters:
36 |         image_numpy (numpy array) -- input numpy array
37 |         image_path (str)          -- the path of the image
38 |     """
39 |     image_pil = Image.fromarray(image_numpy)
40 |     image_pil.save(image_path)
41 | 
42 | 
43 | def mkdirs(paths):
44 |     """create empty directories if they don't exist
45 | 
46 |     Parameters:
47 |         paths (str list) -- a list of directory paths
48 |     """
49 |     if isinstance(paths, list) and not isinstance(paths, str):
50 |         for path in paths:
51 |             mkdir(path)
52 |     else:
53 |         mkdir(paths)
54 | 
55 | 
56 | def mkdir(path):
57 |     """create a single empty directory if it didn't exist
58 | 
59 |     Parameters:
60 |         path (str) -- a single directory path
61 |     """
62 |     if not os.path.exists(path):
63 |         os.makedirs(path)
64 | 
65 | 
66 | def calc_hist(data_ab, device):  # soft (differentiable) 21x21 joint histogram of the ab channels
67 |     N, C, H, W = data_ab.shape
68 |     grid_a = torch.linspace(-1, 1, 21).view(1, 21, 1, 1, 1).expand(N, 21, 21, H, W).to(device)
69 |     grid_b = torch.linspace(-1, 1, 21).view(1, 1, 21, 1, 1).expand(N, 21, 21, H, W).to(device)
70 |     hist_a = torch.max(0.1 - torch.abs(grid_a - data_ab[:, 0, :, :].view(N, 1, 1, H, W)), torch.Tensor([0]).to(device)) * 10
71 |     hist_b = torch.max(0.1 - torch.abs(grid_b - data_ab[:, 1, :, :].view(N, 1, 1, H, W)), torch.Tensor([0]).to(device)) * 10
72 |     hist = (hist_a * hist_b).mean(dim=(3, 4)).view(N, -1)
73 |     return hist
74 | 
75 | 
76 | class Convert(object):
77 |     def __init__(self, device):
78 |         xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423],
79 |                                  [0.212671, 0.715160, 0.072169],
80 |                                  [0.019334, 0.119193, 0.950227]])
81 |         rgb_from_xyz = linalg.inv(xyz_from_rgb)
82 |         self.rgb_from_xyz = torch.Tensor(rgb_from_xyz).to(device)
83 |         self.channel_mask = torch.Tensor([1, 0, 0]).to(device)
84 |         self.xyz_weight = torch.Tensor([0.95047, 1., 1.08883]).to(device)
85 |         self.mean = torch.Tensor([0.485, 0.456, 0.406]).to(device)
86 |         self.std = torch.Tensor([0.229, 0.224, 0.225]).to(device)
87 |         self.zero = torch.Tensor([0]).to(device)
88 |         self.one = torch.Tensor([1]).to(device)
89 | 
90 |     def lab2rgb(self, img):
91 |         img = img.permute(0, 2, 3, 1)
92 |         img1 = (img + 1.0) * 50.0 * self.channel_mask  # L channel: [-1, 1] -> [0, 100]
93 |         img2 = img * 110.0 * (1 - self.channel_mask)   # ab channels: [-1, 1] -> [-110, 110]
94 |         img = img1 + img2
95 |         return self.xyz2rgb(self.lab2xyz(img))
96 | 
97 |     def lab2xyz(self, img):
98 |         L, a, b = img[:, :, :, 0], img[:, :, :, 1], img[:, :, :, 2]
99 |         y = (L + 16.) / 116.
100 |         x = (a / 500.) + y
101 |         z = y - (b / 200.)
102 |         z = torch.max(z, self.zero)
103 |         out = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1), z.unsqueeze(-1)], dim=-1)
104 |         mask = (out > 0.2068966).float()
105 |         out1 = torch.pow(out, 3) * mask
106 |         out2 = (out - 16.0 / 116.) / 7.787 * (1 - mask)
107 |         out = out1 + out2
108 |         out *= self.xyz_weight
109 |         return out
110 | 
111 |     def xyz2rgb(self, img):
112 |         arr = img.matmul(self.rgb_from_xyz.t())
113 |         mask = (arr > 0.0031308).float()
114 |         arr1 = (1.055 * torch.pow(torch.max(arr, self.zero), 1 / 2.4) - 0.055) * mask
115 |         arr2 = arr * 12.92 * (1 - mask)
116 |         arr = arr1 + arr2
117 |         arr = torch.min(torch.max(arr, self.zero), self.one)
118 |         return arr
119 | 
120 |     def rgb_norm(self, img):
121 |         img = (img - self.mean) / self.std
122 |         return img.permute(0, 3, 1, 2)
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import ntpath
5 | import time
6 | from . import util, html
7 | from subprocess import Popen, PIPE
8 | from scipy.misc import imresize  # requires scipy <= 1.2.x; imresize was removed in later SciPy releases
9 | 
10 | if sys.version_info[0] == 2:
11 |     VisdomExceptionBase = Exception
12 | else:
13 |     VisdomExceptionBase = ConnectionError
14 | 
15 | 
16 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
17 |     """Save images to the disk.
18 | 
19 |     Parameters:
20 |         webpage (the HTML class)  -- the HTML webpage class that stores these images (see html.py for more details)
21 |         visuals (OrderedDict)     -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
22 |         image_path (str)          -- the string is used to create image paths
23 |         aspect_ratio (float)      -- the aspect ratio of saved images
24 |         width (int)               -- the images will be resized to width x width
25 | 
26 |     This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
27 | """ 28 | image_dir = webpage.get_image_dir() 29 | # short_path = ntpath.basename(image_path[0]) 30 | short_path = image_path[0].split('/') 31 | short_path = short_path[-2] + '_' + short_path[-1] 32 | name = os.path.splitext(short_path)[0] 33 | 34 | webpage.add_header(name) 35 | ims, txts, links = [], [], [] 36 | 37 | for label, im_data in visuals.items(): 38 | im = util.tensor2im(im_data) 39 | image_name = '%s_%s.png' % (name, label) 40 | save_path = os.path.join(image_dir, image_name) 41 | h, w, _ = im.shape 42 | if aspect_ratio > 1.0: 43 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') 44 | if aspect_ratio < 1.0: 45 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') 46 | util.save_image(im, save_path) 47 | 48 | ims.append(image_name) 49 | txts.append(label) 50 | links.append(image_name) 51 | webpage.add_images(ims, txts, links, width=width) 52 | 53 | 54 | def save_more_images(webpage, name, sources, targets, aspect_ratio=1.0, width=256): 55 | image_dir = webpage.get_image_dir() 56 | 57 | webpage.add_header('%04d' % name) 58 | ims, links = [], [] 59 | 60 | for dnum, d in enumerate([sources, targets]): 61 | ims.append([]) 62 | links.append([]) 63 | for idx, im_data in enumerate(d): 64 | im = util.tensor2im(im_data) 65 | image_name = '%04d_%02d_%s.png' % (name, idx, ['S', 'T'][dnum]) 66 | save_path = os.path.join(image_dir, image_name) 67 | h, w, _ = im.shape 68 | if aspect_ratio > 1.0: 69 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') 70 | if aspect_ratio < 1.0: 71 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') 72 | util.save_image(im, save_path) 73 | ims[dnum].append(image_name) 74 | links[dnum].append(image_name) 75 | webpage.add_more_images(ims, links, width=width) 76 | 77 | 78 | class Visualizer(): 79 | """This class includes several functions that can display/save images and print/save logging information. 80 | 81 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. 82 | """ 83 | 84 | def __init__(self, opt): 85 | """Initialize the Visualizer class 86 | 87 | Parameters: 88 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 89 | Step 1: Cache the training/test options 90 | Step 2: connect to a visdom server 91 | Step 3: create an HTML object for saveing HTML filters 92 | Step 4: create a logging file to store training losses 93 | """ 94 | self.opt = opt # cache the option 95 | self.display_id = opt.display_id 96 | self.use_html = opt.isTrain and not opt.no_html 97 | self.win_size = opt.display_winsize 98 | self.name = opt.name 99 | self.port = opt.display_port 100 | self.saved = False 101 | if self.display_id > 0: # connect to a visdom server given and 102 | import visdom 103 | self.ncols = opt.display_ncols 104 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) 105 | if not self.vis.check_connection(): 106 | self.create_visdom_connections() 107 | 108 | if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ 109 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 110 | self.img_dir = os.path.join(self.web_dir, 'images') 111 | print('create web directory %s...' 
% self.web_dir) 112 | util.mkdirs([self.web_dir, self.img_dir]) 113 | # create a logging file to store training losses 114 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 115 | with open(self.log_name, "a") as log_file: 116 | now = time.strftime("%c") 117 | log_file.write('================ Training Loss (%s) ================\n' % now) 118 | 119 | def reset(self): 120 | """Reset the self.saved status""" 121 | self.saved = False 122 | 123 | def create_visdom_connections(self): 124 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ 125 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port 126 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....') 127 | print('Command: %s' % cmd) 128 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) 129 | 130 | def display_current_results(self, visuals, epoch, save_result): 131 | """Display current results on visdom; save current results to an HTML file. 132 | 133 | Parameters: 134 | visuals (OrderedDict) - - dictionary of images to display or save 135 | epoch (int) - - the current epoch 136 | save_result (bool) - - if save the current results to an HTML file 137 | """ 138 | if self.display_id > 0: # show images in the browser using visdom 139 | ncols = self.ncols 140 | if ncols > 0: # show all the images in one visdom panel 141 | ncols = min(ncols, len(visuals)) 142 | h, w = next(iter(visuals.values())).shape[:2] 143 | table_css = """""" % (w, h) # create a table css 147 | # create a table of images. 148 | title = self.name 149 | label_html = '' 150 | label_html_row = '' 151 | images = [] 152 | idx = 0 153 | for label, image in visuals.items(): 154 | image_numpy = util.tensor2im(image) 155 | label_html_row += '%s' % label 156 | images.append(image_numpy.transpose([2, 0, 1])) 157 | idx += 1 158 | if idx % ncols == 0: 159 | label_html += '%s' % label_html_row 160 | label_html_row = '' 161 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 162 | while idx % ncols != 0: 163 | images.append(white_image) 164 | label_html_row += '' 165 | idx += 1 166 | if label_html_row != '': 167 | label_html += '%s' % label_html_row 168 | try: 169 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 170 | padding=2, opts=dict(title=title + ' images')) 171 | label_html = '%s
' % label_html 172 | self.vis.text(table_css + label_html, win=self.display_id + 2, 173 | opts=dict(title=title + ' labels')) 174 | except VisdomExceptionBase: 175 | self.create_visdom_connections() 176 | 177 | else: # show each image in a separate visdom panel; 178 | idx = 1 179 | try: 180 | for label, image in visuals.items(): 181 | image_numpy = util.tensor2im(image) 182 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 183 | win=self.display_id + idx) 184 | idx += 1 185 | except VisdomExceptionBase: 186 | self.create_visdom_connections() 187 | 188 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 189 | self.saved = True 190 | # save images to the disk 191 | for label, image in visuals.items(): 192 | image_numpy = util.tensor2im(image) 193 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 194 | util.save_image(image_numpy, img_path) 195 | 196 | # update website 197 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) 198 | for n in range(epoch, 0, -1): 199 | webpage.add_header('epoch [%d]' % n) 200 | ims, txts, links = [], [], [] 201 | 202 | for label, image_numpy in visuals.items(): 203 | image_numpy = util.tensor2im(image) 204 | img_path = 'epoch%.3d_%s.png' % (n, label) 205 | ims.append(img_path) 206 | txts.append(label) 207 | links.append(img_path) 208 | webpage.add_images(ims, txts, links, width=self.win_size) 209 | webpage.save() 210 | 211 | def plot_current_losses(self, epoch, counter_ratio, losses): 212 | """display the current losses on visdom display: dictionary of error labels and values 213 | 214 | Parameters: 215 | epoch (int) -- current epoch 216 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 217 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 218 | """ 219 | if not hasattr(self, 'plot_data'): 220 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 221 | self.plot_data['X'].append(epoch + counter_ratio) 222 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 223 | try: 224 | self.vis.line( 225 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 226 | Y=np.array(self.plot_data['Y']), 227 | opts={ 228 | 'title': self.name + ' loss over time', 229 | 'legend': self.plot_data['legend'], 230 | 'xlabel': 'epoch', 231 | 'ylabel': 'loss'}, 232 | win=self.display_id) 233 | except VisdomExceptionBase: 234 | self.create_visdom_connections() 235 | 236 | # losses: same format as |losses| of plot_current_losses 237 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data): 238 | """print current losses on console; also save the losses to the disk 239 | 240 | Parameters: 241 | epoch (int) -- current epoch 242 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) 243 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 244 | t_comp (float) -- computational time per data point (normalized by batch_size) 245 | t_data (float) -- data loading time per data point (normalized by batch_size) 246 | """ 247 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) 248 | for k, v in losses.items(): 249 | message += '%s: %.3f ' % (k, v) 250 | 251 | print(message) # print the message 252 | with open(self.log_name, "a") as log_file: 253 | log_file.write('%s\n' % message) # save the 
message 254 | --------------------------------------------------------------------------------
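
Note: `calc_hist` in `util/util.py` soft-bins the ab channels (scaled to [-1, 1]) onto a 21x21 grid with triangular kernels, which makes the histogram differentiable. The snippet below is an illustrative sketch, not a file in this repository; the shapes follow `calc_hist`'s signature, and the near-unit total mass follows from the kernels forming a partition of unity over the grid.

```python
# Illustrative sketch (not part of the repository): probing util.calc_hist.
import torch
from util.util import calc_hist

device = torch.device('cpu')
ab = torch.rand(1, 2, 64, 64) * 2.0 - 1.0  # synthetic ab plane scaled to [-1, 1]
hist = calc_hist(ab, device)               # (1, 441): flattened 21x21 joint histogram

print(hist.shape)         # torch.Size([1, 441])
print(float(hist.sum()))  # ~1.0: each pixel spreads exactly unit mass over the bins
```

Because the binning is built from piecewise-linear kernels rather than hard assignment, gradients flow through the histogram, which is what allows histogram-based objectives and scores to be computed directly on network outputs.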