├── README.md
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   └── colorization_dataset.py
├── doc
│   ├── ab_constant_filter.npy
│   └── weight_index.npy
├── imgs
│   ├── reference.JPEG
│   ├── target.JPEG
│   └── visual.jpg
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── colorization_model.py
│   ├── main_model.py
│   └── networks.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   └── test_options.py
├── requirements.txt
├── test.py
├── test.sh
└── util
    ├── __init__.py
    ├── html.py
    ├── util.py
    └── visualizer.py
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer for Image Colorization (PyTorch Implementation)
2 |
3 |
4 |
5 | ### [Paper](https://dl.acm.org/doi/10.1145/3474085.3475385) | [Pretrained Model](https://drive.google.com/file/d/11FM-2v4iVH8Dvowo-7bQG56Z_ey8kjOa/view?usp=sharing)
6 |
7 | **Yes, "Attention Is All You Need", for Exemplar based Colorization, ACMMM2021**
8 |
9 | Wang Yin<sup>1</sup>,
10 | Peng Lu<sup>1</sup>,
11 | Zhaoran Zhao<sup>1</sup>,
12 | Xujun Peng<sup>2</sup>
13 | <sup>1</sup>Beijing University of Posts and Telecommunications, <sup>2</sup>USC
14 | ## Table of Contents
15 |
16 | - [Prerequisites](#Prerequisites)
17 | - [Getting Started](#Getting-Started)
18 | - [Citation](#Citation)
19 |
20 | ## Prerequisites
21 | - Ubuntu 16.04
22 | - Python 3.6.10
23 | - PyTorch 1.5.1
24 | - CPU or NVIDIA GPU + CUDA 10.2 + cuDNN
25 |
26 | ## Getting Started
27 |
28 | ### Installation
29 | - Clone this repo:
30 | ```bash
31 | git clone https://github.com/wangyins/transformer-for-image-colorization
32 | cd transformer-for-image-colorization
33 | pip install -r requirements.txt
34 | ```
35 | - Download the pretrained weights ("checkpoints_acmmm2021.zip") from the [Pretrained Model](https://drive.google.com/file/d/11FM-2v4iVH8Dvowo-7bQG56Z_ey8kjOa/view?usp=sharing) link above, place the archive under checkpoints/imagenet/, and unzip it:
36 | ```bash
37 | mkdir -p checkpoints/imagenet/
38 | cd checkpoints/imagenet/
39 | unzip checkpoints_acmmm2021.zip
40 | ```
41 | ### Testing
42 | ```bash
43 | sh test.sh
44 | ```
45 | ## Citation
46 | If you use this code for your research, please cite our paper.
47 | ```
48 | @inproceedings{yin_mm2021,
49 | title={Yes, "Attention Is All You Need", for Exemplar based Colorization},
50 | author={Yin, Wang and Lu, Peng and Zhao, ZhaoRan and Peng, XuJun},
51 | booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
52 | year={2021}
53 | }
54 | ```
55 |
--------------------------------------------------------------------------------
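A minimal sketch of the kind of invocation `test.sh` presumably wraps, assembled only from option names that appear in this repository (`--targetImage_path` and `--referenceImage_path` are read in `data/colorization_dataset.py`; `--model colorization` follows the `models/colorization_model.py` naming convention; `--name imagenet` matches the checkpoint directory). The authoritative flags are in `test.sh` itself:

```bash
# Hypothetical invocation; see test.sh for the actual command and flags.
python test.py \
    --name imagenet \
    --model colorization \
    --targetImage_path ./imgs/target.JPEG \
    --referenceImage_path ./imgs/reference.JPEG
```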
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 |
17 |
18 | def find_dataset_using_name(dataset_name):
19 | """Import the module "data/[dataset_name]_dataset.py".
20 |
21 | In the file, the class called DatasetNameDataset() will
22 | be instantiated. It has to be a subclass of BaseDataset,
23 | and it is case-insensitive.
24 | """
25 | dataset_filename = "data." + dataset_name + "_dataset"
26 | datasetlib = importlib.import_module(dataset_filename)
27 |
28 | dataset = None
29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30 | for name, cls in datasetlib.__dict__.items():
31 | if name.lower() == target_dataset_name.lower() \
32 | and issubclass(cls, BaseDataset):
33 | dataset = cls
34 |
35 | if dataset is None:
36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37 |
38 | return dataset
39 |
40 |
41 | def get_option_setter(dataset_name):
42 | """Return the static method of the dataset class."""
43 | dataset_class = find_dataset_using_name(dataset_name)
44 | return dataset_class.modify_commandline_options
45 |
46 |
47 | def create_dataset(opt):
48 | """Create a dataset given the option.
49 |
50 | This function wraps the class CustomDatasetDataLoader.
51 | This is the main interface between this package and 'train.py'/'test.py'
52 |
53 | Example:
54 | >>> from data import create_dataset
55 | >>> dataset = create_dataset(opt)
56 | """
57 | data_loader = CustomDatasetDataLoader(opt)
58 | dataset = data_loader.load_data()
59 | return dataset
60 |
61 |
62 | class CustomDatasetDataLoader():
63 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 |
65 | def __init__(self, opt):
66 | """Initialize this class
67 |
68 | Step 1: create a dataset instance given the name [dataset_mode]
69 | Step 2: create a multi-threaded data loader.
70 | """
71 | self.opt = opt
72 | dataset_class = find_dataset_using_name(opt.dataset_mode)
73 | self.dataset = dataset_class(opt)
74 | print("dataset [%s] was created" % type(self.dataset).__name__)
75 | self.dataloader = torch.utils.data.DataLoader(
76 | self.dataset,
77 | batch_size=opt.batch_size,
78 | shuffle=not opt.serial_batches,
79 | num_workers=int(opt.num_threads))
80 |
81 | def load_data(self):
82 | return self
83 |
84 | def __len__(self):
85 | """Return the number of data in the dataset"""
86 | return min(len(self.dataset), self.opt.max_dataset_size)
87 |
88 | def __iter__(self):
89 | """Return a batch of data"""
90 | for i, data in enumerate(self.dataloader):
91 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
92 | break
93 | yield data
94 |
--------------------------------------------------------------------------------
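As a concrete illustration of the name-matching convention above (models/__init__.py uses the same scheme for models), a minimal sketch of a hypothetical data/dummy_dataset.py:

```python
# data/dummy_dataset.py -- hypothetical example, not part of this repository.
from PIL import Image

from data.base_dataset import BaseDataset, get_transform


class DummyDataset(BaseDataset):
    """Found by find_dataset_using_name('dummy'), i.e. '--dataset_mode dummy'."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        # Optionally add dataset-specific flags or override defaults here.
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)        # stores opt and opt.dataroot
        self.paths = ['img0.png', 'img1.png']  # placeholder image list
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert('RGB')
        return {'A': self.transform(img), 'A_paths': self.paths[index]}

    def __len__(self):
        return len(self.paths)
```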
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 |
12 |
13 | class BaseDataset(data.Dataset, ABC):
14 | """This class is an abstract base class (ABC) for datasets.
15 |
16 | To create a subclass, you need to implement the following four functions:
17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 | -- <__len__>: return the size of dataset.
19 | -- <__getitem__>: get a data point.
20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 | """
22 |
23 | def __init__(self, opt):
24 | """Initialize the class; save the options in the class
25 |
26 | Parameters:
27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 | """
29 | self.opt = opt
30 | self.root = opt.dataroot
31 |
32 | @staticmethod
33 | def modify_commandline_options(parser, is_train):
34 | """Add new dataset-specific options, and rewrite default values for existing options.
35 |
36 | Parameters:
37 | parser -- original option parser
38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
39 |
40 | Returns:
41 | the modified parser.
42 | """
43 | return parser
44 |
45 | @abstractmethod
46 | def __len__(self):
47 | """Return the total number of images in the dataset."""
48 | return 0
49 |
50 | @abstractmethod
51 | def __getitem__(self, index):
52 | """Return a data point and its metadata information.
53 |
54 | Parameters:
55 | index - - a random integer for data indexing
56 |
57 | Returns:
58 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
59 | """
60 | pass
61 |
62 |
63 | def get_params(opt, size):
64 | w, h = size
65 | new_h = h
66 | new_w = w
67 | if opt.preprocess == 'resize_and_crop':
68 | new_h = new_w = opt.load_size
69 | elif opt.preprocess == 'scale_width_and_crop':
70 | new_w = opt.load_size
71 | new_h = opt.load_size * h // w
72 |
73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
75 |
76 | flip = random.random() > 0.5
77 |
78 | return {'crop_pos': (x, y), 'flip': flip}
79 |
80 |
81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True, must_crop=False):
82 | transform_list = []
83 | if grayscale:
84 | transform_list.append(transforms.Grayscale(1))
85 | if 'resize' in opt.preprocess or must_crop:
86 | osize = [opt.crop_size, opt.crop_size]
87 | transform_list.append(transforms.Resize(osize, method))
88 | elif 'scale_width' in opt.preprocess:
89 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
90 |
91 | if 'crop' in opt.preprocess:
92 | if params is None:
93 | transform_list.append(transforms.CenterCrop(opt.crop_size))
94 | else:
95 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
96 |
97 | if 'none' in opt.preprocess:
98 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=8, method=method)))
99 |
100 | if not opt.no_flip:
101 | if params is None:
102 | transform_list.append(transforms.RandomHorizontalFlip())
103 | elif params['flip']:
104 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
105 |
106 | if convert:
107 | transform_list += [transforms.ToTensor()]
108 | if grayscale:
109 | transform_list += [transforms.Normalize((0.5,), (0.5,))]
110 | else:
111 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
112 | return transforms.Compose(transform_list)
113 |
114 |
115 | def __make_power_2(img, base, method=Image.BICUBIC):
116 | ow, oh = img.size
117 | h = int(round(oh / base) * base)
118 | w = int(round(ow / base) * base)
119 | if (h == oh) and (w == ow):
120 | return img
121 |
122 | __print_size_warning(ow, oh, w, h)
123 | return img.resize((w, h), method)
124 |
125 |
126 | def __scale_width(img, target_width, method=Image.BICUBIC):  # resizes so the shorter side equals target_width
127 | ow, oh = img.size
128 | if ow <= oh:
129 | if (ow == target_width):
130 | return img
131 | w = target_width
132 | h = int(target_width * oh / ow)
133 | else:
134 | if (oh == target_width):
135 | return img
136 | h = target_width
137 | w = int(target_width * ow / oh)
138 | return img.resize((w, h), method)
139 |
140 |
141 | def __crop(img, pos, size):
142 | ow, oh = img.size
143 | x1, y1 = pos
144 | tw = th = size
145 | if (ow > tw or oh > th):
146 | return img.crop((x1, y1, x1 + tw, y1 + th))
147 | return img
148 |
149 |
150 | def __flip(img, flip):
151 | if flip:
152 | return img.transpose(Image.FLIP_LEFT_RIGHT)
153 | return img
154 |
155 |
156 | def __print_size_warning(ow, oh, w, h):
157 | """Print warning information about image size(only print once)"""
158 | if not hasattr(__print_size_warning, 'has_printed'):
159 | print("The image size needs to be a multiple of 4. "
160 | "The loaded image size was (%d, %d), so it was adjusted to "
161 | "(%d, %d). This adjustment will be done to all images "
162 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
163 | __print_size_warning.has_printed = True
164 |
--------------------------------------------------------------------------------
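One detail worth making explicit: get_params exists so that two related images can share a single draw of the random crop position and flip, which get_transform then applies deterministically. A small sketch of that pairing (hypothetical paths; opt is any parsed options object with the fields used above):

```python
# Sketch: apply the identical random crop/flip to an aligned image pair.
from PIL import Image

from data.base_dataset import get_params, get_transform

A = Image.open('A.png').convert('RGB')   # hypothetical paths
B = Image.open('B.png').convert('RGB')

params = get_params(opt, A.size)         # one draw of {'crop_pos', 'flip'}
transform = get_transform(opt, params)   # deterministic given params
A_t, B_t = transform(A), transform(B)    # identically cropped and flipped
```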
/data/colorization_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_transform
3 | from skimage import color # require skimage
4 | from PIL import Image
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | import cv2
8 | from collections import Counter
9 | from tqdm import tqdm
10 |
11 |
12 | class ColorizationDataset(BaseDataset):
13 | """This dataset class can load a set of natural images in RGB, and convert RGB format into (L, ab) pairs in Lab color space."""
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train):
16 | """Add new dataset-specific options, and rewrite default values for existing options.
17 |
18 | Parameters:
19 | parser -- original option parser
20 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
21 |
22 | Returns:
23 | the modified parser.
24 |
25 | By default, the number of channels for input image is 1 (L) and
26 | the number of channels for output image is 2 (ab). The direction is from A to B.
27 | """
28 | parser.set_defaults(input_nc=1, output_nc=2, direction='AtoB')
29 | return parser
30 |
31 | def __init__(self, opt):
32 | """Initialize this dataset class.
33 |
34 | Parameters:
35 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
36 | """
37 | BaseDataset.__init__(self, opt)
38 | self.dir = os.path.join(opt.dataroot, opt.phase)
39 | self.AB_paths = [[self.opt.targetImage_path, self.opt.referenceImage_path]]
40 | self.ab_constant = np.load('./doc/ab_constant_filter.npy')
41 | self.transform_A = get_transform(self.opt, convert=False)
42 | self.transform_R = get_transform(self.opt, convert=False, must_crop=True)
43 | assert(opt.input_nc == 1 and opt.output_nc == 2)
44 |
45 | def __getitem__(self, index):
46 | path_A, path_R = self.AB_paths[index]
47 | im_A_l, im_A_ab, _ = self.process_img(path_A, self.transform_A)
48 | im_R_l, im_R_ab, hist = self.process_img(path_R, self.transform_R)
49 |
50 | im_dict = {
51 | 'A_l': im_A_l,
52 | 'A_ab': im_A_ab,
53 | 'R_l': im_R_l,
54 | 'R_ab': im_R_ab,
55 | 'ab': self.ab_constant,
56 | 'hist': hist,
57 | 'A_paths': path_A
58 | }
59 | return im_dict
60 |
61 |
62 | def process_img(self, im_path, transform):
63 |
64 | weights_index = np.load('./doc/weight_index.npy')
65 |
66 | im = Image.open(im_path).convert('RGB')
67 | im = transform(im)
68 | im = self.__scale_width(im, 256)
69 | im = np.array(im)
70 | im = im[:16 * int(im.shape[0] / 16.0), :16 * int(im.shape[1] / 16.0), :]
71 | l_ts, ab_ts, gt_keys = [], [], []
72 | hist_total_new = np.zeros((441,), dtype=np.float32)  # joint ab histogram: 21 x 21 = 441 bins
73 | for ratio in [0.25, 0.5, 1]:
74 | if ratio == 1:
75 | im_ratio = im
76 | else:
77 | im_ratio = cv2.resize(im, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_AREA)
78 | lab = color.rgb2lab(im_ratio).astype(np.float32)
79 |
80 | if ratio == 1:
81 | ab_index_1 = np.round(lab[:, :, 1:] / 110.0 / 0.1) + 10.0  # quantize a and b (each in [-110, 110]) into 21 bins: indices 0..20
82 | keys_t = ab_index_1[:, :, 0] * 21 + ab_index_1[:, :, 1]  # flatten the (a, b) bin pair into a single key in [0, 440]
83 | keys_t_flatten = keys_t.flatten().astype(np.int32)
84 | dict_counter = dict(Counter(keys_t_flatten))
85 | for k, v in dict_counter.items():
86 | hist_total_new[k] += v
87 |
88 | hist = hist_total_new[weights_index]
89 | hist = hist / np.sum(hist)
90 |
91 | lab_t = transforms.ToTensor()(lab)
92 | l_t = lab_t[[0], ...] / 50.0 - 1.0
93 | ab_t = lab_t[[1, 2], ...] / 110.0
94 | l_ts.append(l_t)
95 | ab_ts.append(ab_t)
96 |
97 | return l_ts, ab_ts, hist
98 |
99 |
100 | def __scale_width(self, img, target_width, method=Image.BICUBIC):  # resizes so the shorter side equals target_width
101 | ow, oh = img.size
102 | if ow <= oh:
103 | if (ow == target_width):
104 | return img
105 | w = target_width
106 | h = int(target_width * oh / ow)
107 | else:
108 | if (oh == target_width):
109 | return img
110 | h = target_width
111 | w = int(target_width * ow / oh)
112 | return img.resize((w, h), method)
113 |
114 | def __len__(self):
115 | """Return the total number of images in the dataset."""
116 | return len(self.AB_paths)
117 |
--------------------------------------------------------------------------------
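The histogram bookkeeping in process_img works as follows: each of the a and b channels (bounded by Lab to roughly [-110, 110]) is quantized by round(x / 110 / 0.1) + 10 into an index in {0, ..., 20}, and the index pair is flattened to a single key a_idx * 21 + b_idx in {0, ..., 440}, hence the 441-bin histogram, which is then reindexed by weight_index.npy and normalized to sum to 1. A minimal numpy check of the mapping:

```python
import numpy as np

ab = np.array([[-110.0, -110.0], [0.0, 0.0], [110.0, 110.0]])  # sample (a, b) pairs
idx = np.round(ab / 110.0 / 0.1) + 10.0   # per-channel bin index in {0..20}
keys = idx[:, 0] * 21 + idx[:, 1]         # flattened key in {0..440}
print(keys)                               # [  0. 220. 440.] -> 21 * 21 = 441 bins
```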
/doc/ab_constant_filter.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangyins/transformer-for-image-colorization/38772b248e56e0449fc367addc5cd82d927cf160/doc/ab_constant_filter.npy
--------------------------------------------------------------------------------
/doc/weight_index.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangyins/transformer-for-image-colorization/38772b248e56e0449fc367addc5cd82d927cf160/doc/weight_index.npy
--------------------------------------------------------------------------------
/imgs/reference.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangyins/transformer-for-image-colorization/38772b248e56e0449fc367addc5cd82d927cf160/imgs/reference.JPEG
--------------------------------------------------------------------------------
/imgs/target.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangyins/transformer-for-image-colorization/38772b248e56e0449fc367addc5cd82d927cf160/imgs/target.JPEG
--------------------------------------------------------------------------------
/imgs/visual.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangyins/transformer-for-image-colorization/38772b248e56e0449fc367addc5cd82d927cf160/imgs/visual.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 | -- <set_input>: unpack data from dataset and apply preprocessing.
7 | -- <forward>: produce intermediate results.
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 | In the file, the class called ModelNameModel() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | raise NotImplementedError("In %s.py, there should be a subclass of BaseModel "
43 | "with class name that matches %s in lowercase." % (model_filename, target_model_name))
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 | This function instantiates the model class specified by opt.model.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 | from . import networks
6 |
7 |
8 | class BaseModel(ABC):
9 | """This class is an abstract base class (ABC) for models.
10 | To create a subclass, you need to implement the following five functions:
11 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12 | -- <set_input>: unpack data from dataset and apply preprocessing.
13 | -- <forward>: produce intermediate results.
14 | -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16 | """
17 |
18 | def __init__(self, opt):
19 | """Initialize the BaseModel class.
20 |
21 | Parameters:
22 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23 |
24 | When creating your custom class, you need to implement your own initialization.
25 | In this function, you should first call <BaseModel.__init__(self, opt)>.
26 | Then, you need to define four lists:
27 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
28 | -- self.model_names (str list): define networks used in our training.
29 | -- self.visual_names (str list): specify the images that you want to display and save.
30 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31 | """
32 | self.opt = opt
33 | self.gpu_ids = opt.gpu_ids
34 | self.isTrain = opt.isTrain
35 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
36 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
37 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38 | torch.backends.cudnn.benchmark = True
39 | self.loss_names = []
40 | self.log_names = []
41 | self.model_names = []
42 | self.visual_names = []
43 | self.optimizers = []
44 | self.image_paths = []
45 | self.metric = 0 # used for learning rate policy 'plateau'
46 |
47 | @staticmethod
48 | def modify_commandline_options(parser, is_train):
49 | """Add new model-specific options, and rewrite default values for existing options.
50 |
51 | Parameters:
52 | parser -- original option parser
53 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
54 |
55 | Returns:
56 | the modified parser.
57 | """
58 | return parser
59 |
60 | @abstractmethod
61 | def set_input(self, input):
62 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
63 |
64 | Parameters:
65 | input (dict): includes the data itself and its metadata information.
66 | """
67 | pass
68 |
69 | @abstractmethod
70 | def forward(self):
71 | """Run forward pass; called by both functions and ."""
72 | pass
73 |
74 | def setup(self, opt):
75 | """Load and print networks; create schedulers
76 |
77 | Parameters:
78 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
79 | """
80 | if self.isTrain:
81 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
82 | if not self.isTrain or opt.continue_train:
83 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
84 | self.load_networks(load_suffix)
85 | self.print_networks(opt.verbose)
86 |
87 | def eval(self):
88 | """Make models eval mode during test time"""
89 | for name in self.model_names:
90 | if isinstance(name, str):
91 | net = getattr(self, 'net' + name)
92 | net.eval()
93 |
94 | def test(self):
95 | """Forward function used in test time.
96 |
97 | This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop.
98 | It also calls <compute_visuals> to produce additional visualization results.
99 | """
100 | with torch.no_grad():
101 | self.forward()
102 | self.compute_visuals()
103 |
104 | def compute_visuals(self):
105 | """Calculate additional output images for visdom and HTML visualization"""
106 | pass
107 |
108 | def get_image_paths(self):
109 | """ Return image paths that are used to load current data"""
110 | return self.image_paths
111 |
112 | def update_learning_rate(self):
113 | """Update learning rates for all the networks; called at the end of every epoch"""
114 | for scheduler in self.schedulers:
115 | if self.opt.lr_policy == 'plateau':
116 | scheduler.step(self.metric)
117 | else:
118 | scheduler.step()
119 |
120 | lr = self.optimizers[0].param_groups[0]['lr']
121 | print('learning rate = %.7f' % lr)
122 |
123 | def get_current_visuals(self):
124 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
125 | visual_ret = OrderedDict()
126 | for name in self.visual_names:
127 | if isinstance(name, str):
128 | vis = getattr(self, name)
129 | if isinstance(vis, list):
130 | for i in range(len(vis)):
131 | visual_ret[name + '_' + str(i+1)] = vis[i]
132 | else:
133 | visual_ret[name] = vis
134 | return visual_ret
135 |
136 | def get_current_losses(self):
137 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
138 | errors_ret = OrderedDict()
139 | for name in self.loss_names:
140 | if isinstance(name, str):
141 | loss = getattr(self, 'loss_' + name)
142 | if isinstance(loss, list):
143 | for i in range(len(loss)):
144 | errors_ret[name + '_' + str(i+1)] = float(loss[i])
145 | else:
146 | errors_ret[name] = float(loss) # float(...) works for both scalar tensor and float number
147 | return errors_ret
148 |
149 | def get_current_log(self):
150 | ret = self.get_current_losses()
151 | for name in self.log_names:
152 | if isinstance(name, str):
153 | ret[name] = float(getattr(self, name)) # float(...) works for both scalar tensor and float number
154 | return ret
155 |
156 | def save_networks(self, epoch):
157 | """Save all the networks to the disk.
158 |
159 | Parameters:
160 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
161 | """
162 | for name in self.model_names:
163 | if isinstance(name, str):
164 | net = getattr(self, 'net' + name)
165 | if isinstance(net, list):
166 | for i in range(len(net)):
167 | save_filename = '%s_net_%s_%d.pth' % (epoch, name, i+1)
168 | save_path = os.path.join(self.save_dir, save_filename)
169 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
170 | torch.save(net[i].module.cpu().state_dict(), save_path)
171 | net[i].cuda(self.gpu_ids[0])
172 | else:
173 | torch.save(net[i].cpu().state_dict(), save_path)
174 | else:
175 | save_filename = '%s_net_%s.pth' % (epoch, name)
176 | save_path = os.path.join(self.save_dir, save_filename)
177 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
178 | torch.save(net.module.cpu().state_dict(), save_path)
179 | net.cuda(self.gpu_ids[0])
180 | else:
181 | torch.save(net.cpu().state_dict(), save_path)
182 |
183 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
184 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
185 | key = keys[i]
186 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
187 | if module.__class__.__name__.startswith('InstanceNorm') and \
188 | (key == 'running_mean' or key == 'running_var'):
189 | if getattr(module, key) is None:
190 | state_dict.pop('.'.join(keys))
191 | if module.__class__.__name__.startswith('InstanceNorm') and \
192 | (key == 'num_batches_tracked'):
193 | state_dict.pop('.'.join(keys))
194 | else:
195 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
196 |
197 | def load_networks(self, epoch):
198 | """Load all the networks from the disk.
199 |
200 | Parameters:
201 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
202 | """
203 | for name in self.model_names:
204 | if isinstance(name, str):
205 | net = getattr(self, 'net' + name)
206 | if isinstance(net, list):
207 | for i in range(len(net)):
208 | load_filename = '%s_net_%s_%d.pth' % (epoch, name, i+1)
209 | load_path = os.path.join(self.save_dir, load_filename)
210 | net_i = net[i]
211 | if isinstance(net_i, torch.nn.DataParallel):
212 | net_i = net_i.module
213 | print('loading the model from %s' % load_path)
214 | # if you are using PyTorch newer than 0.4 (e.g., built from
215 | # GitHub source), you can remove str() on self.device
216 | state_dict = torch.load(load_path, map_location=str(self.device))
217 | if hasattr(state_dict, '_metadata'):
218 | del state_dict._metadata
219 |
220 | # patch InstanceNorm checkpoints prior to 0.4
221 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
222 | self.__patch_instance_norm_state_dict(state_dict, net_i, key.split('.'))
223 | net_i.load_state_dict(state_dict)
224 | else:
225 | load_filename = '%s_net_%s.pth' % (epoch, name)
226 | load_path = os.path.join(self.save_dir, load_filename)
227 | if isinstance(net, torch.nn.DataParallel):
228 | net = net.module
229 | print('loading the model from %s' % load_path)
230 | # if you are using PyTorch newer than 0.4 (e.g., built from
231 | # GitHub source), you can remove str() on self.device
232 | state_dict = torch.load(load_path, map_location=str(self.device))
233 | if hasattr(state_dict, '_metadata'):
234 | del state_dict._metadata
235 |
236 | # patch InstanceNorm checkpoints prior to 0.4
237 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
238 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
239 | net.load_state_dict(state_dict)
240 |
241 | def print_networks(self, verbose):
242 | """Print the total number of parameters in the network and (if verbose) network architecture
243 |
244 | Parameters:
245 | verbose (bool) -- if verbose: print the network architecture
246 | """
247 | print('---------- Networks initialized -------------')
248 | for name in self.model_names:
249 | if isinstance(name, str):
250 | net = getattr(self, 'net' + name)
251 | if isinstance(net, list):
252 | for i in range(len(net)):
253 | num_params = 0
254 | for param in net[i].parameters():
255 | num_params += param.numel()
256 | if verbose:
257 | print(net[i])
258 | print('[Network %s_%d] Total number of parameters : %.3f M' % (name, i+1, num_params / 1e6))
259 | else:
260 | num_params = 0
261 | for param in net.parameters():
262 | num_params += param.numel()
263 | if verbose:
264 | print(net)
265 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
266 | print('-----------------------------------------------')
267 |
268 | def set_requires_grad(self, nets, requires_grad=False):
269 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
270 | Parameters:
271 | nets (network list) -- a list of networks
272 | requires_grad (bool) -- whether the networks require gradients or not
273 | """
274 | if not isinstance(nets, list):
275 | nets = [nets]
276 | for net in nets:
277 | if net is not None:
278 | for param in net.parameters():
279 | param.requires_grad = requires_grad
--------------------------------------------------------------------------------
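Taken together with the factories in data/__init__.py and models/__init__.py, these APIs imply the usual pix2pix-style test driver. A hedged sketch of what test.py (not shown in this section) presumably does, assuming the standard TestOptions interface from options/test_options.py:

```python
# Sketch of the test-time call sequence implied by BaseModel's API.
from data import create_dataset
from models import create_model
from options.test_options import TestOptions  # assumed pix2pix-style interface

opt = TestOptions().parse()    # parse command-line flags
dataset = create_dataset(opt)  # ColorizationDataset wrapped in a DataLoader
model = create_model(opt)      # ColorizationModel
model.setup(opt)               # loads the '%s_net_G.pth' weights from the checkpoint dir
model.eval()

for data in dataset:
    model.set_input(data)      # unpack the multi-scale L/ab tensors
    model.test()               # forward() under no_grad(), then compute_visuals()
    visuals = model.get_current_visuals()
```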
/models/colorization_model.py:
--------------------------------------------------------------------------------
1 | from .main_model import MainModel
2 | import torch
3 | from skimage import color
4 | import numpy as np
5 | import cv2
6 |
7 |
8 | class ColorizationModel(MainModel):
9 |
10 | @staticmethod
11 | def modify_commandline_options(parser, is_train=True):
12 | MainModel.modify_commandline_options(parser, is_train)
13 | parser.set_defaults(dataset_mode='colorization')
14 | return parser
15 |
16 | def __init__(self, opt):
17 | MainModel.__init__(self, opt)
18 | self.visual_names = ['real_A_l_0', 'real_A_rgb', 'real_R_rgb', 'fake_R_rgb']
19 |
20 | def lab2rgb(self, L, AB):
21 | AB2 = AB * 110.0
22 | L2 = (L + 1.0) * 50.0
23 | Lab = torch.cat([L2, AB2], dim=1)
24 | Lab = Lab[0].data.cpu().float().numpy()
25 | Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
26 | rgb = color.lab2rgb(Lab) * 255
27 | return rgb
28 |
29 | def tensor2gray(self, im):
30 | im = im[0].data.cpu().float().numpy()
31 | im = np.transpose(im.astype(np.float64), (1, 2, 0))
32 | im = np.repeat(im, 3, axis=-1) * 255
33 | return im
34 |
35 | def compute_visuals(self):
36 | self.real_A_l_0 = self.real_A_l[-1]
37 | self.real_A_rgb = self.lab2rgb(self.real_A_l[-1], self.real_A_ab[-1])
38 | self.real_R_rgb = self.lab2rgb(self.real_R_l[-1], self.real_R_ab[-1])
39 | self.real_R_rgb = cv2.resize(self.real_R_rgb, (self.real_A_rgb.shape[1], self.real_A_rgb.shape[0]))
40 | self.fake_R_rgb = []
41 | for i in range(3):
42 | self.fake_R_rgb += [self.lab2rgb(self.real_A_l[i], self.fake_imgs[i])]
43 | if i != 2:
44 | self.fake_R_rgb[i] = cv2.resize(self.fake_R_rgb[i], (self.real_A_rgb.shape[1], self.real_A_rgb.shape[0]))
45 |
46 | def compute_scores(self):
47 | metrics = []
48 | hr = self.real_R_histogram[-1].data.cpu().float().numpy().flatten()
49 | hg = self.fake_R_histogram[-1].data.cpu().float().numpy().flatten()
50 | intersect = cv2.compareHist(hr, hg, cv2.HISTCMP_INTERSECT)
51 | metrics.append(intersect)
52 |
53 | return metrics
54 |
--------------------------------------------------------------------------------
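lab2rgb above is the exact inverse of the normalization in data/colorization_dataset.py (l_t = L / 50 - 1, ab_t = ab / 110), so network outputs in [-1, 1] map back to L in [0, 100] and ab in [-110, 110] before skimage's Lab-to-RGB conversion. A quick consistency check:

```python
import numpy as np

L_lab, ab_lab = 50.0, np.array([30.0, -40.0])    # sample Lab values
l_t, ab_t = L_lab / 50.0 - 1.0, ab_lab / 110.0   # dataset normalization
assert (l_t + 1.0) * 50.0 == L_lab               # lab2rgb's L2 recovers L
assert np.allclose(ab_t * 110.0, ab_lab)         # lab2rgb's AB2 recovers ab
```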
/models/main_model.py:
--------------------------------------------------------------------------------
1 | from .base_model import BaseModel
2 | from . import networks
3 | from util import util
4 |
5 |
6 | class MainModel(BaseModel):
7 | @staticmethod
8 | def modify_commandline_options(parser, is_train=True):
9 | parser.set_defaults(norm='instance', dataset_mode='aligned')
10 | return parser
11 |
12 | def __init__(self, opt):
13 |
14 | BaseModel.__init__(self, opt)
15 | self.visual_names = ['real_A', 'fake_B', 'real_B']
16 | self.model_names = ['G']
17 | self.netG = networks.define_G(opt.input_nc, opt.bias_input_nc, opt.output_nc, opt.norm, opt.init_type,
18 | opt.init_gain, self.gpu_ids)
19 | self.convert = util.Convert(self.device)
20 |
21 | def set_input(self, input):
22 | self.image_paths = input['A_paths']
23 | self.ab_constant = input['ab'].to(self.device)
24 | self.hist = input['hist'].to(self.device)
25 |
26 | self.real_A_l, self.real_A_ab, self.real_R_l, self.real_R_ab, self.real_R_histogram = [], [], [], [], []
27 | for i in range(3):  # three scales (ratios 0.25, 0.5, 1.0) produced by ColorizationDataset.process_img
28 | self.real_A_l += input['A_l'][i].to(self.device).unsqueeze(0)
29 | self.real_A_ab += input['A_ab'][i].to(self.device).unsqueeze(0)
30 | self.real_R_l += input['R_l'][i].to(self.device).unsqueeze(0)
31 | self.real_R_ab += input['R_ab'][i].to(self.device).unsqueeze(0)
32 | self.real_R_histogram += [util.calc_hist(input['A_ab'][i].to(self.device), self.device)]
33 |
34 | def forward(self):
35 | self.fake_imgs = self.netG(self.real_A_l[-1], self.real_R_l[-1], self.real_R_ab[0], self.hist,
36 | self.ab_constant, self.device)
37 | self.fake_R_histogram = []
38 | for i in range(3):
39 | self.fake_R_histogram += [util.calc_hist(self.fake_imgs[i], self.device)]
40 |
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 | import torch.nn.functional as F
7 |
8 |
9 | class Identity(nn.Module):
10 | def forward(self, x):
11 | return x
12 |
13 |
14 | def get_norm_layer(norm_type='instance'):
15 | if norm_type == 'batch':
16 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
17 | elif norm_type == 'instance':
18 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
19 | elif norm_type == 'none':
20 | norm_layer = lambda x: Identity()
21 | else:
22 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
23 | return norm_layer
24 |
25 |
26 | def get_scheduler(optimizer, opt):
27 | if opt.lr_policy == 'linear':
28 | def lambda_rule(epoch):
29 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
30 | return lr_l
31 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
32 | elif opt.lr_policy == 'step':
33 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
34 | elif opt.lr_policy == 'plateau':
35 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
36 | elif opt.lr_policy == 'cosine':
37 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
38 | else:
39 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
40 | return scheduler
41 |
42 |
43 | def init_weights(net, init_type='normal', init_gain=0.02):
44 | def init_func(m): # define the initialization function
45 | classname = m.__class__.__name__
46 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
47 | if init_type == 'normal':
48 | init.normal_(m.weight.data, 0.0, init_gain)
49 | elif init_type == 'xavier':
50 | init.xavier_normal_(m.weight.data, gain=init_gain)
51 | elif init_type == 'kaiming':
52 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
53 | elif init_type == 'orthogonal':
54 | init.orthogonal_(m.weight.data, gain=init_gain)
55 | else:
56 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
57 | if hasattr(m, 'bias') and m.bias is not None:
58 | init.constant_(m.bias.data, 0.0)
59 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
60 | init.normal_(m.weight.data, 1.0, init_gain)
61 | init.constant_(m.bias.data, 0.0)
62 |
63 | print('initialize network with %s' % init_type)
64 | net.apply(init_func) # apply the initialization function
65 |
66 |
67 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
68 | if len(gpu_ids) > 0:
69 | assert(torch.cuda.is_available())
70 | net.to(gpu_ids[0])
71 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
72 | init_weights(net, init_type, init_gain=init_gain)
73 | return net
74 |
75 |
76 | def define_G(input_nc, bias_input_nc, output_nc, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
77 | norm_layer = get_norm_layer(norm_type=norm)
78 | net = ColorNet(input_nc, bias_input_nc, output_nc, norm_layer=norm_layer)
79 |
80 | return init_net(net, init_type, init_gain, gpu_ids)
81 |
82 |
83 | class ResBlock(nn.Module):
84 | def __init__(self, dim, norm_layer, use_dropout, use_bias):
85 | super(ResBlock, self).__init__()
86 | conv_block = [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
87 | if use_dropout:
88 | conv_block += [nn.Dropout(0.5)]
89 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias), norm_layer(dim)]
90 | self.conv_block = nn.Sequential(*conv_block)
91 |
92 | def forward(self, x):
93 | out = x + self.conv_block(x) # add skip connections
94 | return out
95 |
96 |
97 | class global_network(nn.Module):
98 | def __init__(self, in_dim):
99 | super(global_network, self).__init__()
100 | model = [nn.Conv2d(in_dim, 512, kernel_size=1, padding=0), nn.ReLU(True)]
101 | model += [nn.Conv2d(512, 512, kernel_size=1, padding=0), nn.ReLU(True)]
102 | model += [nn.Conv2d(512, 512, kernel_size=1, padding=0), nn.ReLU(True)]
103 | self.model = nn.Sequential(*model)
104 |
105 | self.model_1 = nn.Sequential(*[nn.Conv2d(512, 512, kernel_size=1, padding=0), nn.Sigmoid()])
106 |
107 | def forward(self, x):
108 | x = self.model(x)
109 | x1 = self.model_1(x)
110 |
111 | return x1
112 |
113 |
114 | class ref_network_align(nn.Module):
115 | def __init__(self, norm_layer):
116 | super(ref_network_align, self).__init__()
117 | model1 = [nn.Conv2d(2, 64, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
118 | model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(64)]
119 | self.model1 = nn.Sequential(*model1)
120 | model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
121 | model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(128)]
122 | self.model2 = nn.Sequential(*model2)
123 | model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
124 | model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(256)]
125 | self.model3 = nn.Sequential(*model3)
126 | model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
127 | model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(512)]
128 | self.model4 = nn.Sequential(*model4)
129 |
130 | def forward(self, color, corr, H, W):
131 |
132 | color_flatten = color.view(color.shape[0], color.shape[1], -1)
133 | align_color = torch.bmm(color_flatten, corr)
134 | align_color_output = align_color.view(align_color.shape[0], align_color.shape[1], H, W)
135 |
136 | conv1 = self.model1(align_color_output)
137 | align_color1 = self.model2(conv1)
138 | align_color2 = self.model3(align_color1[:,:,::2,::2])
139 | align_color3 = self.model4(align_color2[:,:,::2,::2])
140 |
141 | return align_color1, align_color2, align_color3
142 |
143 |
144 | class ref_network_hist(nn.Module):
145 | def __init__(self, norm_layer):
146 | super(ref_network_hist, self).__init__()
147 | model1 = [nn.Conv2d(2, 64, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
148 | model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(64)]
149 | self.model1 = nn.Sequential(*model1)
150 | model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
151 | model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(128)]
152 | self.model2 = nn.Sequential(*model2)
153 | model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
154 | model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(256)]
155 | self.model3 = nn.Sequential(*model3)
156 | model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True)]
157 | model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(True), norm_layer(512)]
158 | self.model4 = nn.Sequential(*model4)
159 |
160 | def forward(self, color):
161 |
162 | conv1 = self.model1(color)
163 | align_color1 = self.model2(conv1)
164 | align_color2 = self.model3(align_color1[:,:,::2,::2])
165 | align_color3 = self.model4(align_color2[:,:,::2,::2])
166 |
167 | return align_color1, align_color2, align_color3
168 |
169 |
170 | class conf_feature_align(nn.Module):
171 | def __init__(self):
172 | super(conf_feature_align, self).__init__()
173 | self.fc1 = nn.Sequential(*[nn.Conv1d(4096, 1024, kernel_size=1, stride=1, padding=0, bias=True), nn.ReLU(True)])
174 | self.fc2 = nn.Sequential(*[nn.Conv1d(1024, 1, kernel_size=1, stride=1, padding=0, bias=True), nn.Sigmoid()])
175 | self.dropout1 = nn.Dropout(0.1)
176 |
177 | def forward(self, x):
178 | x1 = self.fc1(x)
179 | x2 = self.dropout1(x1)
180 | x3 = self.fc2(x2)
181 |
182 | return x3
183 |
184 |
185 | class conf_feature_hist(nn.Module):
186 | def __init__(self):
187 | super(conf_feature_hist, self).__init__()
188 | self.fc1 = nn.Sequential(*[nn.Conv1d(4096, 1024, kernel_size=1, stride=1, padding=0, bias=True), nn.ReLU(True)])
189 | self.fc2 = nn.Sequential(*[nn.Conv1d(1024, 1, kernel_size=1, stride=1, padding=0, bias=True), nn.Sigmoid()])
190 | self.dropout1 = nn.Dropout(0.1)
191 |
192 | def forward(self, x):
193 | x1 = self.fc1(x)
194 | x2 = self.dropout1(x1)
195 | x3 = self.fc2(x2)
196 |
197 | return x3
198 |
199 |
200 | class classify_network(nn.Module):
201 | def __init__(self):
202 | super(classify_network, self).__init__()
203 | self.maxpool = nn.AdaptiveMaxPool2d((1, 1))
204 | self.fc = nn.Linear(512, 1000)
205 |
206 | def forward(self, x):
207 | x = self.maxpool(x)
208 | x = x.squeeze(-1).squeeze(-1)
209 | x = self.fc(x)
210 | return x
211 |
212 |
213 | class ColorNet(nn.Module):
214 | def __init__(self, input_nc, bias_input_nc, output_nc, norm_layer=nn.BatchNorm2d):
215 | super(ColorNet, self).__init__()
216 | self.input_nc = input_nc
217 | self.output_nc = output_nc
218 | use_bias = True
219 |
220 | model_head = [nn.Conv2d(input_nc, 32, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True),
221 | norm_layer(32)]
222 |
223 | # Conv1
224 | model1=[nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),]
225 | model1+=[nn.ReLU(True),]
226 | model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),]
227 | model1+=[nn.ReLU(True),]
228 | model1+=[norm_layer(64),]
229 |
230 | # Conv2
231 | model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),]
232 | model2+=[nn.ReLU(True),]
233 | model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),]
234 | model2+=[nn.ReLU(True),]
235 | model2+=[norm_layer(128),]
236 |
237 | # Conv3
238 | model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),]
239 | model3+=[nn.ReLU(True),]
240 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),]
241 | model3+=[nn.ReLU(True),]
242 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),]
243 | model3+=[nn.ReLU(True),]
244 | model3+=[norm_layer(256),]
245 |
246 | # Conv4
247 | model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),]
248 | model4+=[nn.ReLU(True),]
249 | model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),]
250 | model4+=[nn.ReLU(True),]
251 | model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),]
252 | model4+=[nn.ReLU(True),]
253 | model4+=[norm_layer(512),]
254 |
255 | # Conv5
256 | model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
257 | model5+=[nn.ReLU(True),]
258 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
259 | model5+=[nn.ReLU(True),]
260 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
261 | model5+=[nn.ReLU(True),]
262 | model5+=[norm_layer(512),]
263 |
264 | # Conv6
265 | model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
266 | model6+=[nn.ReLU(True),]
267 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
268 | model6+=[nn.ReLU(True),]
269 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
270 | model6+=[nn.ReLU(True),]
271 | model6+=[norm_layer(512),]
272 |
273 | # Conv7
274 | model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
275 | model7+=[nn.ReLU(True),]
276 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
277 | model7+=[nn.ReLU(True),]
278 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
279 | model7+=[nn.ReLU(True),]
280 | model7+=[norm_layer(512),]
281 |
282 | model_hist=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
283 | model_hist+=[nn.ReLU(True),]
284 | model_hist+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
285 | model_hist+=[nn.ReLU(True),]
286 | model_hist+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
287 | model_hist+=[nn.ReLU(True),]
288 |
289 | model_hist+=[nn.Conv2d(256, 198, kernel_size=1, stride=1, padding=0, bias=True),]
290 |
291 | # ResBlock0
292 | resblock0_1 = [nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=use_bias), norm_layer(512), nn.ReLU(True)]
293 | self.resblock0_2 = ResBlock(512, norm_layer, False, use_bias)
294 | self.resblock0_3 = ResBlock(512, norm_layer, False, use_bias)
295 |
296 | # Conv8
297 | model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias)]
298 |
299 | model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),]
300 |
301 | model8=[nn.ReLU(True),]
302 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),]
303 | model8+=[nn.ReLU(True),]
304 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),]
305 | model8+=[nn.ReLU(True),]
306 | model8+=[norm_layer(256),]
307 |
308 | # ResBlock1
309 | resblock1_1 = [nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=use_bias), norm_layer(256), nn.ReLU(True)]
310 | self.resblock1_2 = ResBlock(256, norm_layer, False, use_bias)
311 | self.resblock1_3 = ResBlock(256, norm_layer, False, use_bias)
312 |
313 | # Conv9
314 | model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias),]
315 |
316 | model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),]
317 |
318 | model9=[nn.ReLU(True),]
319 | model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),]
320 | model9+=[nn.ReLU(True),]
321 | model9+=[norm_layer(128),]
322 |
323 | # ResBlock2
324 | resblock2_1 = [nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=use_bias), norm_layer(128), nn.ReLU(True)]
325 | self.resblock2_2 = ResBlock(128, norm_layer, False, use_bias)
326 | self.resblock2_3 = ResBlock(128, norm_layer, False, use_bias)
327 |
328 | # Conv10
329 | model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias),]
330 |
331 | model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),]
332 |
333 | model10=[nn.ReLU(True),]
334 | model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=use_bias),]
335 | model10+=[nn.LeakyReLU(negative_slope=.2),]
336 |
337 | # Conv Global
338 | self.global_network = global_network(bias_input_nc)
339 |
340 | # conf feature
341 | self.conf_feature_align = conf_feature_align()
342 | self.conf_feature_hist = conf_feature_hist()
343 |
344 | # Conv Ref
345 | self.ref_network_align = ref_network_align(norm_layer)
346 | self.ref_network_hist = ref_network_hist(norm_layer)
347 |
348 | # classification
349 | self.classify_network = classify_network()
350 |
351 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
352 | self.softmax_gate = nn.Softmax(dim=1)
353 | self.softmax = nn.Softmax(dim=-1)
354 | self.key_dataset = torch.eye(bias_input_nc)
355 |
356 | model_tail_1 = [nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), nn.LeakyReLU(negative_slope=.2)]
357 | model_tail_2 = [nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=use_bias), nn.LeakyReLU(negative_slope=.2)]
358 | model_tail_3 = [nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=use_bias), nn.LeakyReLU(negative_slope=.2)]
359 |
360 | model_out1 = [nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), nn.Tanh()]
361 | model_out2 = [nn.Conv2d(64, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), nn.Tanh()]
362 | model_out3 = [nn.Conv2d(64, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), nn.Tanh()]
363 |
364 | self.model1 = nn.Sequential(*model1)
365 | self.model2 = nn.Sequential(*model2)
366 | self.model3 = nn.Sequential(*model3)
367 | self.model4 = nn.Sequential(*model4)
368 | self.model5 = nn.Sequential(*model5)
369 | self.model6 = nn.Sequential(*model6)
370 | self.model7 = nn.Sequential(*model7)
371 | self.model_hist = nn.Sequential(*model_hist)
372 | self.model8up = nn.Sequential(*model8up)
373 | self.model8 = nn.Sequential(*model8)
374 | self.model9up = nn.Sequential(*model9up)
375 | self.model9 = nn.Sequential(*model9)
376 | self.model10up = nn.Sequential(*model10up)
377 | self.model10 = nn.Sequential(*model10)
378 | self.model3short8 = nn.Sequential(*model3short8)
379 | self.model2short9 = nn.Sequential(*model2short9)
380 | self.model1short10 = nn.Sequential(*model1short10)
381 | self.resblock0_1 = nn.Sequential(*resblock0_1)
382 | self.resblock1_1 = nn.Sequential(*resblock1_1)
383 | self.resblock2_1 = nn.Sequential(*resblock2_1)
384 | self.model_out1 = nn.Sequential(*model_out1)
385 | self.model_out2 = nn.Sequential(*model_out2)
386 | self.model_out3 = nn.Sequential(*model_out3)
387 | self.model_head = nn.Sequential(*model_head)
388 | self.model_tail_1 = nn.Sequential(*model_tail_1)
389 | self.model_tail_2 = nn.Sequential(*model_tail_2)
390 | self.model_tail_3 = nn.Sequential(*model_tail_3)
391 |
392 |
393 | def forward(self, input, ref_input, ref_color, bias_input, ab_constant, device):
394 |
395 | # align branch
396 | in_conv = self.model_head(input)
397 |
398 | in_1 = self.model1(in_conv[:, :, ::2, ::2])
399 | in_2 = self.model2(in_1[:, :, ::2, ::2])
400 | in_3 = self.model3(in_2[:, :, ::2, ::2])
401 | in_4 = self.model4(in_3[:, :, ::2, ::2])
402 | in_5 = self.model5(in_4)
403 | in_6 = self.model6(in_5)
404 |
405 | ref_conv_head = self.model_head(ref_input)
406 | ref_1 = self.model1(ref_conv_head[:,:,::2,::2])
407 | ref_2 = self.model2(ref_1[:, :, ::2, ::2])
408 | ref_3 = self.model3(ref_2[:, :, ::2, ::2])
409 | ref_4 = self.model4(ref_3[:, :, ::2, ::2])
410 | ref_5 = self.model5(ref_4)
411 | ref_6 = self.model6(ref_5)
412 |
413 | t1 = F.interpolate(in_1, scale_factor=0.5, mode='bilinear')
414 | t2 = in_2
415 | t3 = F.interpolate(in_3, scale_factor=2, mode='bilinear')
416 | t4 = F.interpolate(in_4, scale_factor=4, mode='bilinear')
417 | t5 = F.interpolate(in_5, scale_factor=4, mode='bilinear')
418 | t6 = F.interpolate(in_6, scale_factor=4, mode='bilinear')
419 | t = torch.cat((t1, t2, t3, t4, t5, t6), dim=1)
420 |
421 | r1 = F.interpolate(ref_1, scale_factor=0.5, mode='bilinear')
422 | r2 = ref_2
423 | r3 = F.interpolate(ref_3, scale_factor=2, mode='bilinear')
424 | r4 = F.interpolate(ref_4, scale_factor=4, mode='bilinear')
425 | r5 = F.interpolate(ref_5, scale_factor=4, mode='bilinear')
426 | r6 = F.interpolate(ref_6, scale_factor=4, mode='bilinear')
427 | r = torch.cat((r1, r2, r3, r4, r5, r6), dim=1)
428 |
429 | input_T_flatten = t.view(t.shape[0], t.shape[1], -1).permute(0, 2, 1)
430 | input_R_flatten = r.view(r.shape[0], r.shape[1], -1).permute(0, 2, 1)
431 | input_T_flatten = input_T_flatten / torch.norm(input_T_flatten, p=2, dim=-1, keepdim=True)
432 | input_R_flatten = input_R_flatten / torch.norm(input_R_flatten, p=2, dim=-1, keepdim=True)
433 | corr = torch.bmm(input_R_flatten, input_T_flatten.permute(0, 2, 1))
434 |
435 | corr = F.softmax(corr / 0.01, dim=1)  # temperature 0.01 sharpens the match distribution over reference positions
436 |
437 | # Align branch confidence map learning
438 | align_1, align_2, align_3 = self.ref_network_align(ref_color, corr, t2.shape[2], t2.shape[3])
439 | conf_align = self.conf_feature_align(corr)
440 | conf_align = conf_align.view(conf_align.shape[0], 1, t2.shape[2], t2.shape[3])
441 | conf_aligns = 5.0 * conf_align
442 |
443 | # Histogram branch confidence map learning
444 | conf_hist = self.conf_feature_hist(corr)
445 | conf_hist = conf_hist.view(conf_hist.shape[0], 1, t2.shape[2], t2.shape[3])
446 | conf_hists = 5.0 * conf_hist
447 |
448 | # Gate softmax operation on confidence map
449 | conf_total = torch.cat((conf_aligns, conf_hists), dim=1)
450 | conf_softmax = self.softmax_gate(conf_total)
451 |
452 | conf_1_align = conf_softmax[:, :1, :, :]
453 | conf_1_hist = conf_softmax[:, 1:, :, :]
454 | conf_2_align = conf_1_align[:,:,::2,::2]
455 | conf_3_align = conf_2_align[:,:,::2,::2]
456 | conf_2_hist = conf_1_hist[:,:,::2,::2]
457 | conf_3_hist = conf_2_hist[:,:,::2,::2]
458 |
459 | # hist branch
460 | bias_input = bias_input.view(input.shape[0], -1, 1, 1)
461 |
462 | conv_head = self.model_head(input)
463 | conv1_2 = self.model1(conv_head[:, :, ::2, ::2])
464 | conv2_2 = self.model2(conv1_2[:,:,::2,::2])
465 | conv3_3 = self.model3(conv2_2[:,:,::2,::2])
466 | conv4_3 = self.model4(conv3_3[:,:,::2,::2])
467 | conv5_3 = self.model5(conv4_3)
468 | conv6_3 = self.model6(conv5_3)
469 |
470 | class_output = self.classify_network(conv6_3)
471 |
472 | # hist align
473 | conv_global1 = self.global_network(bias_input)
474 | conv_global1_repeat = conv_global1.expand_as(conv6_3)
475 | conv_global1_add = conv6_3 * conv_global1_repeat
476 | conv7_3 = self.model7(conv_global1_add)
477 | color_reg = self.model_hist(conv7_3)
478 |
479 | # calculate attention matrix for histogram branch
480 | key_datasets = self.key_dataset.unsqueeze(0).to(device)
481 | attn_weights = torch.bmm(color_reg.flatten(2).permute(0, 2, 1), key_datasets)
482 | value = ab_constant.type_as(color_reg)
483 |         attn_weights_softmax = self.softmax(attn_weights * 100.0)  # multiplying by 100 likewise sharpens the attention distribution
484 | conv_total_out = torch.bmm(attn_weights_softmax, value).permute(0, 2, 1)
485 | conv_total_out_re = conv_total_out.view(color_reg.shape[0], -1, color_reg.shape[2], color_reg.shape[3])
486 | conv_total_out_up = self.upsample(conv_total_out_re)
487 |
488 | hist_1, hist_2, hist_3 = self.ref_network_hist(conv_total_out_up)
489 |
490 | # encoder1
491 | conv6_3_global = conv6_3 + align_3 * conf_3_align + hist_3 * conf_3_hist
492 | conv7_resblock1 = self.resblock0_1(conv6_3_global)
493 | conv7_resblock2 = self.resblock0_2(conv7_resblock1)
494 | conv7_resblock3 = self.resblock0_3(conv7_resblock2)
495 | conv8_up = self.model8up(conv7_resblock3) + self.model3short8(conv3_3)
496 | conv8_3 = self.model8(conv8_up)
497 | conv_tail_1 = self.model_tail_1(conv8_3)
498 | fake_img1 = self.model_out1(conv_tail_1)
499 |
500 | # encoder2
501 | conv8_3_global = conv8_3 + align_2 * conf_2_align + hist_2 * conf_2_hist
502 | conv8_resblock1 = self.resblock1_1(conv8_3_global)
503 | conv8_resblock2 = self.resblock1_2(conv8_resblock1)
504 | conv8_resblock3 = self.resblock1_3(conv8_resblock2)
505 | conv9_up = self.model9up(conv8_resblock3) + self.model2short9(conv2_2)
506 | conv9_3 = self.model9(conv9_up)
507 | conv_tail_2 = self.model_tail_2(conv9_3)
508 | fake_img2 = self.model_out2(conv_tail_2)
509 |
510 | # encoder3
511 | conv9_3_global = conv9_3 + align_1 * conf_1_align + hist_1 * conf_1_hist
512 | conv9_resblock1 = self.resblock2_1(conv9_3_global)
513 | conv9_resblock2 = self.resblock2_2(conv9_resblock1)
514 | conv9_resblock3 = self.resblock2_3(conv9_resblock2)
515 | conv10_up = self.model10up(conv9_resblock3) + self.model1short10(conv1_2)
516 | conv10_2 = self.model10(conv10_up)
517 | conv_tail_3 = self.model_tail_3(conv10_2)
518 | fake_img3 = self.model_out3(conv_tail_3)
519 |
520 | return [fake_img1, fake_img2, fake_img3]
521 |
--------------------------------------------------------------------------------
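The core of the align branch above is a non-local feature match: target and reference pyramid features are flattened, L2-normalized, correlated with a batched matmul, and pushed through a sharp softmax (temperature 0.01) so that each target position attends to a near one-hot set of reference positions; the attention map then warps the reference colors onto the target grid. Below is a minimal sketch of that pattern with hypothetical feature shapes (the actual `ref_network_align` applies further convolutions on top of this):

```python
import torch
import torch.nn.functional as F

def correlation_attention(feat_t, feat_r, temperature=0.01):
    """Return attention weights mapping reference positions to target positions."""
    # flatten spatial dims: (N, C, H, W) -> (N, H*W, C)
    t = feat_t.flatten(2).permute(0, 2, 1)
    r = feat_r.flatten(2).permute(0, 2, 1)
    # L2-normalize so the matmul computes cosine similarity
    t = t / torch.norm(t, p=2, dim=-1, keepdim=True)
    r = r / torch.norm(r, p=2, dim=-1, keepdim=True)
    corr = torch.bmm(r, t.permute(0, 2, 1))  # (N, HW_ref, HW_tgt)
    # a small temperature sharpens each target column towards a one-hot match
    return F.softmax(corr / temperature, dim=1)

# warp reference colors with the attention map (toy shapes)
feat_t, feat_r = torch.randn(1, 64, 16, 16), torch.randn(1, 64, 16, 16)
attn = correlation_attention(feat_t, feat_r)            # (1, 256, 256)
ref_ab = torch.randn(1, 2, 16, 16).flatten(2)           # (1, 2, 256)
warped_ab = torch.bmm(ref_ab, attn).view(1, 2, 16, 16)  # reference colors on the target grid
```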
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules: test options, and basic options (used in test)."""
2 |
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import models
6 | import data
7 |
8 |
9 | class BaseOptions():
10 | """This class defines options used during both training and test time.
11 |
12 | It also implements several helper functions such as parsing, printing, and saving the options.
13 | It also gathers additional options defined in functions in both dataset class and model class.
14 | """
15 |
16 | def __init__(self):
17 | """Reset the class; indicates the class hasn't been initailized"""
18 | self.initialized = False
19 |
20 | def initialize(self, parser):
21 | """Define the common options that are used in both training and test."""
22 | # basic parameters
23 | parser.add_argument('--dataroot', type=str, default='./dataset/',
24 | help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
25 | parser.add_argument('--name', type=str, default='imagenet', help='name of the experiment. It decides where to store samples and models')
26 |         parser.add_argument('--gpu_ids', type=str, default='1', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
27 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints/', help='models are saved here')
28 | # model parameters
29 | parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels: 3 for RGB and 1 for grayscale')
30 | parser.add_argument('--bias_input_nc', type=int, default=198, help='# of reference image histogram bins')
31 | parser.add_argument('--output_nc', type=int, default=2, help='# of output image channels: 3 for RGB and 1 for grayscale')
32 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
33 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
34 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
35 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
36 | # dataset parameters
37 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
38 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
39 | parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
40 | parser.add_argument('--load_size', type=int, default=288, help='scale images to this size')
41 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
42 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
43 | parser.add_argument('--preprocess', type=str, default='none', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
44 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
45 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
46 |         parser.add_argument('--targetImage_path', type=str, default='./imgs/target.JPEG', help='path to the grayscale target image to colorize')
47 |         parser.add_argument('--referenceImage_path', type=str, default='./imgs/reference.JPEG', help='path to the color exemplar (reference) image')
48 | # additional parameters
49 | parser.add_argument('--use_D', action='store_true', help='whether to use discriminator or not')
50 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
51 |         parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
52 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
53 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
54 | self.initialized = True
55 | return parser
56 |
57 | def gather_options(self):
58 | """Initialize our parser with basic options(only once).
59 | Add additional model-specific and dataset-specific options.
60 | These options are defined in the function
61 | in model and dataset classes.
62 | """
63 | if not self.initialized: # check if it has been initialized
64 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
65 | parser = self.initialize(parser)
66 |
67 | # get the basic options
68 | opt, _ = parser.parse_known_args()
69 |
70 | # modify model-related parser options
71 | model_name = opt.model
72 | model_option_setter = models.get_option_setter(model_name)
73 | parser = model_option_setter(parser, self.isTrain)
74 | opt, _ = parser.parse_known_args() # parse again with new defaults
75 |
76 | # modify dataset-related parser options
77 | dataset_name = opt.dataset_mode
78 | dataset_option_setter = data.get_option_setter(dataset_name)
79 | parser = dataset_option_setter(parser, self.isTrain)
80 |
81 | # save and return the parser
82 | self.parser = parser
83 | return parser.parse_args()
84 |
85 | def print_options(self, opt):
86 | """Print and save options
87 |
88 |         It will print both the current options and the default values (if different).
89 |         It will save the options into a text file: [checkpoints_dir]/[name]/[phase]_opt.txt
90 | """
91 | message = ''
92 | message += '----------------- Options ---------------\n'
93 | for k, v in sorted(vars(opt).items()):
94 | comment = ''
95 | default = self.parser.get_default(k)
96 | if v != default:
97 | comment = '\t[default: %s]' % str(default)
98 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
99 | message += '----------------- End -------------------'
100 | print(message)
101 |
102 | # save to the disk
103 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
104 | util.mkdirs(expr_dir)
105 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
106 | with open(file_name, 'wt') as opt_file:
107 | opt_file.write(message)
108 | opt_file.write('\n')
109 |
110 | def parse(self):
111 | """Parse our options, create checkpoints directory suffix, and set up gpu device."""
112 | opt = self.gather_options()
113 | opt.isTrain = self.isTrain # train or test
114 |
115 | # process opt.suffix
116 | if opt.suffix:
117 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
118 | opt.name = opt.name + suffix
119 |
120 | self.print_options(opt)
121 |
122 | # set gpu ids
123 | str_ids = opt.gpu_ids.split(',')
124 | opt.gpu_ids = []
125 | for str_id in str_ids:
126 | id = int(str_id)
127 | if id >= 0:
128 | opt.gpu_ids.append(id)
129 | if len(opt.gpu_ids) > 0:
130 | torch.cuda.set_device(opt.gpu_ids[0])
131 |         opt.A = 2 * 110.0 / 10.0 + 1  # number of quantized bins per ab axis: ab in [-110, 110] with step 10 gives 23
132 |
133 | self.opt = opt
134 | return self.opt
135 |
--------------------------------------------------------------------------------
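gather_options above relies on argparse's two-stage pattern: parse_known_args first reads only the flags registered so far while tolerating the rest, the chosen model and dataset then register their own flags, and parse_args runs strictly once everything is in place. A stripped-down sketch of the same flow; `fake_model_options` and `--n_blocks` are hypothetical stand-ins for whatever models.get_option_setter returns:

```python
import argparse

def fake_model_options(parser, is_train):
    # hypothetical stand-in for the setter returned by models.get_option_setter(name)
    parser.add_argument('--n_blocks', type=int, default=9)
    return parser

argv = ['--model', 'colorization', '--n_blocks', '6']

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model', type=str, default='colorization')

# stage 1: parse only the basic flags; unknown flags like --n_blocks are tolerated
opt, _ = parser.parse_known_args(argv)

# stage 2: the selected model contributes its own flags, then a strict final parse
parser = fake_model_options(parser, is_train=False)
opt = parser.parse_args(argv)
print(opt.model, opt.n_blocks)  # -> colorization 6
```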
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | """This class includes test options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 |         # Dropout and BatchNorm behave differently during training and test.
17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 |         # rewrite default values
20 | parser.set_defaults(model='colorization')
21 | # To avoid cropping, the load_size should be the same as crop_size
22 | # parser.set_defaults(load_size=parser.get_default('crop_size'))
23 | self.isTrain = False
24 | return parser
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.10.0
2 | astor==0.8.1
3 | axial-positional-embedding==0.2.1
4 | certifi==2020.6.20
5 | chardet==3.0.4
6 | cycler==0.10.0
7 | decorator==4.4.2
8 | dominate==2.5.1
9 | future==0.18.2
10 | gast==0.4.0
11 | google-pasta==0.2.0
12 | graphviz==0.14.2
13 | grpcio==1.33.1
14 | h5py==2.10.0
15 | idna==2.10
16 | imageio==2.9.0
17 | importlib-metadata==2.0.0
18 | jsonpatch==1.26
19 | jsonpointer==2.0
20 | Keras==2.2.4
21 | Keras-Applications==1.0.8
22 | Keras-Preprocessing==1.1.2
23 | kiwisolver==1.2.0
24 | local-attention==1.0.2
25 | Markdown==3.3.2
26 | matplotlib==3.3.0
27 | mkl-fft==1.1.0
28 | mkl-random==1.1.1
29 | mkl-service==2.3.0
30 | networkx==2.4
31 | numpy==1.18.5
32 | olefile==0.46
33 | opencv-contrib-python==3.4.2.16
34 | opencv-python==3.4.2.16
35 | Pillow==7.2.0
36 | product-key-memory==0.1.10
37 | protobuf==3.13.0
38 | pyheatmap==0.1.12
39 | pyparsing==2.4.7
40 | python-dateutil==2.8.1
41 | PyWavelets==1.1.1
42 | PyYAML==5.3.1
43 | pyzmq==19.0.1
44 | reformer-pytorch==1.1.3
45 | requests==2.24.0
46 | scikit-image==0.17.2
47 | scipy==1.2.1
48 | six==1.15.0
49 | spatial-correlation-sampler==0.3.0
50 | tensorboard==1.14.0
51 | tensorflow-estimator==1.14.0
52 | termcolor==1.1.0
53 | tifffile==2020.7.24
54 | torch==1.5.1
55 | torchfile==0.1.0
56 | torchvision==0.6.0a0+35d732a
57 | torchviz @ git+https://github.com/szagoruyko/pytorchviz@46add7f2c071b6d29fc3d56e9d2d21e1c0a3af1d
58 | tornado==6.0.4
59 | tqdm==4.48.0
60 | urllib3==1.25.10
61 | visdom==0.1.8.9
62 | websocket-client==0.57.0
63 | Werkzeug==1.0.1
64 | wrapt==1.12.1
65 | zipp==3.3.1
66 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from options.test_options import TestOptions
3 | from data import create_dataset
4 | from models import create_model
5 | from util.visualizer import save_images
6 | from util import html
7 | import numpy as np
8 |
9 |
10 | if __name__ == '__main__':
11 | opt = TestOptions().parse()
12 |     opt.num_threads = 0        # test code only supports num_threads = 0
13 |     opt.batch_size = 1         # test code only supports batch_size = 1
14 |     opt.serial_batches = True  # disable data shuffling; process images in order
15 |     opt.no_flip = True         # no flip augmentation at test time
16 |     opt.display_id = -1        # no visdom display; results are saved to an HTML file instead
17 |
18 | dataset = create_dataset(opt)
19 | model = create_model(opt)
20 | model.setup(opt)
21 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch))
22 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
23 | scores = []
24 | if opt.eval:
25 | model.eval()
26 | for i, data in enumerate(dataset):
27 | model.set_input(data)
28 | model.test()
29 | visuals = model.get_current_visuals()
30 | img_path = model.get_image_paths()
31 | metrics = model.compute_scores()
32 | scores.extend(metrics)
33 | print('processing (%04d)-th image... %s' % (i, img_path))
34 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
35 | webpage.save()
36 | print('Histogram Intersection: %.4f' % np.mean(scores))
--------------------------------------------------------------------------------
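The score printed at the end is a histogram intersection between color histograms, gathered per image from model.compute_scores (defined with the model code earlier in this repo). As a reference point, the standard histogram intersection of two normalized histograms is sum_i min(h1_i, h2_i); the sketch below computes it for calc_hist-style (N, bins) tensors, though the model's own implementation may differ in detail:

```python
import torch

def histogram_intersection(h1, h2):
    # h1, h2: (N, bins) normalized histograms; 1.0 means identical color distributions
    return torch.min(h1, h2).sum(dim=1)

h = torch.rand(4, 441)
h = h / h.sum(dim=1, keepdim=True)
print(histogram_intersection(h, h))  # all ones: a histogram fully intersects itself
```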
/test.sh:
--------------------------------------------------------------------------------
1 | python3 -u test.py --targetImage_path ./imgs/target.JPEG --referenceImage_path ./imgs/reference.JPEG --gpu_ids 1
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 |
5 |
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 | It consists of functions such as (add a text header to the HTML file),
10 | (add a row of images to the HTML file), and (save the HTML to the disk).
11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 |
14 | def __init__(self, web_dir, title, refresh=0):
15 | """Initialize the HTML classes
16 |
17 | Parameters:
18 |             web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str)   -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 |
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 | with self.doc.head:
33 | meta(http_equiv="refresh", content=str(refresh))
34 |
35 | def get_image_dir(self):
36 | """Return the directory that stores images"""
37 | return self.img_dir
38 |
39 | def add_header(self, text):
40 | """Insert a header to the HTML file
41 |
42 | Parameters:
43 | text (str) -- the header text
44 | """
45 | with self.doc:
46 | h3(text)
47 |
48 | def add_images(self, ims, txts, links, width=400):
49 | """add images to the HTML file
50 |
51 | Parameters:
52 | ims (str list) -- a list of image paths
53 | txts (str list) -- a list of image names shown on the website
54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 | """
56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57 | self.doc.add(self.t)
58 | with self.t:
59 | with tr():
60 | for im, txt, link in zip(ims, txts, links):
61 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 | with p():
63 | with a(href=os.path.join('images', link)):
64 | img(style="width:%dpx" % width, src=os.path.join('images', im))
65 | br()
66 | p(txt)
67 |
68 | def add_more_images(self, ims, links, width=400):
69 | """add images to the HTML file
70 |
71 | Parameters:
72 | ims (str list) -- a list of image paths
73 | txts (str list) -- a list of image names shown on the website
74 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
75 | """
76 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
77 | self.doc.add(self.t)
78 | with self.t:
79 | for i in range(2):
80 | with tr():
81 | for im, link in zip(ims[i], links[i]):
82 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
83 | with p():
84 | with a(href=os.path.join('images', link)):
85 | img(style="width:%dpx" % width, src=os.path.join('images', im))
86 | br()
87 |
88 | def save(self):
89 | """save the current content to the HMTL file"""
90 | html_file = '%s/index.html' % self.web_dir
91 | f = open(html_file, 'wt')
92 | f.write(self.doc.render())
93 | f.close()
94 |
95 |
96 | if __name__ == '__main__': # we show an example usage here.
97 | html = HTML('web/', 'test_html')
98 | html.add_header('hello world')
99 |
100 | ims, txts, links = [], [], []
101 | for n in range(4):
102 | ims.append('image_%d.png' % n)
103 | txts.append('text_%d' % n)
104 | links.append('image_%d.png' % n)
105 | html.add_images(ims, txts, links)
106 | html.save()
107 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 | from __future__ import print_function
3 | import torch
4 | import numpy as np
5 | from scipy import linalg
6 | from PIL import Image
7 | from skimage import color
8 | import os
9 |
10 |
11 | def tensor2im(input_image, imtype=np.uint8):
12 | """"Converts a Tensor array into a numpy image array.
13 |
14 | Parameters:
15 | input_image (tensor) -- the input image tensor array
16 | imtype (type) -- the desired type of the converted numpy array
17 | """
18 | if not isinstance(input_image, np.ndarray):
19 | if isinstance(input_image, torch.Tensor): # get the data from a variable
20 | image_tensor = input_image.data
21 | else:
22 | return input_image
23 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
24 | if image_numpy.shape[0] == 1: # grayscale to RGB
25 | image_numpy = np.tile(image_numpy, (3, 1, 1))
26 |         image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
27 | else: # if it is a numpy array, do nothing
28 | image_numpy = input_image
29 | return image_numpy.astype(imtype)
30 |
31 |
32 | def save_image(image_numpy, image_path):
33 | """Save a numpy image to the disk
34 |
35 | Parameters:
36 | image_numpy (numpy array) -- input numpy array
37 | image_path (str) -- the path of the image
38 | """
39 | image_pil = Image.fromarray(image_numpy)
40 | image_pil.save(image_path)
41 |
42 |
43 | def mkdirs(paths):
44 | """create empty directories if they don't exist
45 |
46 | Parameters:
47 | paths (str list) -- a list of directory paths
48 | """
49 | if isinstance(paths, list) and not isinstance(paths, str):
50 | for path in paths:
51 | mkdir(path)
52 | else:
53 | mkdir(paths)
54 |
55 |
56 | def mkdir(path):
57 | """create a single empty directory if it didn't exist
58 |
59 | Parameters:
60 | path (str) -- a single directory path
61 | """
62 | if not os.path.exists(path):
63 | os.makedirs(path)
64 |
65 |
66 | def calc_hist(data_ab, device):  # differentiable 21x21 soft histogram over the normalized (a, b) plane
67 | N, C, H, W = data_ab.shape
68 | grid_a = torch.linspace(-1, 1, 21).view(1, 21, 1, 1, 1).expand(N, 21, 21, H, W).to(device)
69 | grid_b = torch.linspace(-1, 1, 21).view(1, 1, 21, 1, 1).expand(N, 21, 21, H, W).to(device)
70 | hist_a = torch.max(0.1 - torch.abs(grid_a - data_ab[:, 0, :, :].view(N, 1, 1, H, W)), torch.Tensor([0]).to(device)) * 10
71 | hist_b = torch.max(0.1 - torch.abs(grid_b - data_ab[:, 1, :, :].view(N, 1, 1, H, W)), torch.Tensor([0]).to(device)) * 10
72 | hist = (hist_a * hist_b).mean(dim=(3, 4)).view(N, -1)
73 | return hist
74 |
75 |
76 | class Convert(object):  # differentiable Lab <-> RGB conversion implemented as torch ops on a given device
77 | def __init__(self, device):
78 | xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423],
79 | [0.212671, 0.715160, 0.072169],
80 | [0.019334, 0.119193, 0.950227]])
81 | rgb_from_xyz = linalg.inv(xyz_from_rgb)
82 | self.rgb_from_xyz = torch.Tensor(rgb_from_xyz).to(device)
83 | self.channel_mask = torch.Tensor([1, 0, 0]).to(device)
84 | self.xyz_weight = torch.Tensor([0.95047, 1., 1.08883]).to(device)
85 | self.mean = torch.Tensor([0.485, 0.456, 0.406]).to(device)
86 | self.std = torch.Tensor([0.229, 0.224, 0.225]).to(device)
87 | self.zero = torch.Tensor([0]).to(device)
88 | self.one = torch.Tensor([1]).to(device)
89 |
90 | def lab2rgb(self, img):
91 | img = img.permute(0, 2, 3, 1)
92 | img1 = (img + 1.0) * 50.0 * self.channel_mask
93 | img2 = img * 110.0 * (1 - self.channel_mask)
94 | img = img1 + img2
95 | return self.xyz2rgb(self.lab2xyz(img))
96 |
97 | def lab2xyz(self, img):
98 | L, a, b = img[:, :, :, 0], img[:, :, :, 1], img[:, :, :, 2]
99 | y = (L + 16.) / 116.
100 | x = (a / 500.) + y
101 | z = y - (b / 200.)
102 | z = torch.max(z, self.zero)
103 | out = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1), z.unsqueeze(-1)], dim=-1)
104 | mask = (out > 0.2068966).float()
105 | out1 = torch.pow(out, 3) * mask
106 | out2 = (out - 16.0 / 116.) / 7.787 * (1 - mask)
107 | out = out1 + out2
108 | out *= self.xyz_weight
109 | return out
110 |
111 | def xyz2rgb(self, img):
112 | arr = img.matmul(self.rgb_from_xyz.t())
113 | mask = (arr > 0.0031308).float()
114 | arr1 = (1.055 * torch.pow(torch.max(arr, self.zero), 1 / 2.4) - 0.055) * mask
115 | arr2 = arr * 12.92 * (1 - mask)
116 | arr = arr1 + arr2
117 | arr = torch.min(torch.max(arr, self.zero), self.one)
118 | return arr
119 |
120 | def rgb_norm(self, img):
121 | img = (img - self.mean) / self.std
122 | return img.permute(0, 3, 1, 2)
--------------------------------------------------------------------------------
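calc_hist above is a differentiable 2D color histogram: a 21x21 grid of bin centers spans the normalized ab plane, and each pixel votes into its neighbouring bins through the triangular kernel max(0.1 - |center - value|, 0) * 10, so gradients flow and each pixel's votes sum to one. A quick CPU sanity check, assuming it is run from the repository root:

```python
import torch
from util.util import calc_hist

ab = torch.rand(2, 2, 32, 32) * 2.0 - 1.0  # N=2 images, channels (a, b), values in [-1, 1]
hist = calc_hist(ab, torch.device('cpu'))  # shape (2, 441), i.e. 21 * 21 bins
print(hist.shape)                          # torch.Size([2, 441])
print(hist.sum(dim=1))                     # ~1.0 per image: each pixel's votes sum to one
```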
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import ntpath
5 | import time
6 | from . import util, html
7 | from subprocess import Popen, PIPE
8 | from scipy.misc import imresize  # deprecated and removed in SciPy >= 1.3; kept working here by the scipy==1.2.1 pin in requirements.txt
9 |
10 | if sys.version_info[0] == 2:
11 | VisdomExceptionBase = Exception
12 | else:
13 | VisdomExceptionBase = ConnectionError
14 |
15 |
16 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
17 | """Save images to the disk.
18 |
19 | Parameters:
20 |         webpage (the HTML class)     -- the HTML webpage class that stores these images (see html.py for more details)
21 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
22 | image_path (str) -- the string is used to create image paths
23 | aspect_ratio (float) -- the aspect ratio of saved images
24 | width (int) -- the images will be resized to width x width
25 |
26 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
27 | """
28 | image_dir = webpage.get_image_dir()
29 | # short_path = ntpath.basename(image_path[0])
30 | short_path = image_path[0].split('/')
31 | short_path = short_path[-2] + '_' + short_path[-1]
32 | name = os.path.splitext(short_path)[0]
33 |
34 | webpage.add_header(name)
35 | ims, txts, links = [], [], []
36 |
37 | for label, im_data in visuals.items():
38 | im = util.tensor2im(im_data)
39 | image_name = '%s_%s.png' % (name, label)
40 | save_path = os.path.join(image_dir, image_name)
41 | h, w, _ = im.shape
42 | if aspect_ratio > 1.0:
43 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
44 | if aspect_ratio < 1.0:
45 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
46 | util.save_image(im, save_path)
47 |
48 | ims.append(image_name)
49 | txts.append(label)
50 | links.append(image_name)
51 | webpage.add_images(ims, txts, links, width=width)
52 |
53 |
54 | def save_more_images(webpage, name, sources, targets, aspect_ratio=1.0, width=256):
55 | image_dir = webpage.get_image_dir()
56 |
57 | webpage.add_header('%04d' % name)
58 | ims, links = [], []
59 |
60 | for dnum, d in enumerate([sources, targets]):
61 | ims.append([])
62 | links.append([])
63 | for idx, im_data in enumerate(d):
64 | im = util.tensor2im(im_data)
65 | image_name = '%04d_%02d_%s.png' % (name, idx, ['S', 'T'][dnum])
66 | save_path = os.path.join(image_dir, image_name)
67 | h, w, _ = im.shape
68 | if aspect_ratio > 1.0:
69 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
70 | if aspect_ratio < 1.0:
71 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
72 | util.save_image(im, save_path)
73 | ims[dnum].append(image_name)
74 | links[dnum].append(image_name)
75 | webpage.add_more_images(ims, links, width=width)
76 |
77 |
78 | class Visualizer():
79 | """This class includes several functions that can display/save images and print/save logging information.
80 |
81 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
82 | """
83 |
84 | def __init__(self, opt):
85 | """Initialize the Visualizer class
86 |
87 | Parameters:
88 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
89 | Step 1: Cache the training/test options
90 | Step 2: connect to a visdom server
91 |         Step 3: create an HTML object for saving HTML files
92 | Step 4: create a logging file to store training losses
93 | """
94 | self.opt = opt # cache the option
95 | self.display_id = opt.display_id
96 | self.use_html = opt.isTrain and not opt.no_html
97 | self.win_size = opt.display_winsize
98 | self.name = opt.name
99 | self.port = opt.display_port
100 | self.saved = False
101 |         if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
102 | import visdom
103 | self.ncols = opt.display_ncols
104 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
105 | if not self.vis.check_connection():
106 | self.create_visdom_connections()
107 |
108 | if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/
109 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
110 | self.img_dir = os.path.join(self.web_dir, 'images')
111 | print('create web directory %s...' % self.web_dir)
112 | util.mkdirs([self.web_dir, self.img_dir])
113 | # create a logging file to store training losses
114 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
115 | with open(self.log_name, "a") as log_file:
116 | now = time.strftime("%c")
117 | log_file.write('================ Training Loss (%s) ================\n' % now)
118 |
119 | def reset(self):
120 | """Reset the self.saved status"""
121 | self.saved = False
122 |
123 | def create_visdom_connections(self):
124 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
125 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
126 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
127 | print('Command: %s' % cmd)
128 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
129 |
130 | def display_current_results(self, visuals, epoch, save_result):
131 | """Display current results on visdom; save current results to an HTML file.
132 |
133 | Parameters:
134 | visuals (OrderedDict) - - dictionary of images to display or save
135 | epoch (int) - - the current epoch
136 | save_result (bool) - - if save the current results to an HTML file
137 | """
138 | if self.display_id > 0: # show images in the browser using visdom
139 | ncols = self.ncols
140 | if ncols > 0: # show all the images in one visdom panel
141 | ncols = min(ncols, len(visuals))
142 | h, w = next(iter(visuals.values())).shape[:2]
142 |                 table_css = """<style>
143 |                         table {border-collapse: separate; border-spacing: 4px;
144 |                                white-space: nowrap; text-align: center}
145 |                         table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
146 |                         </style>""" % (w, h)  # create a table css
147 | # create a table of images.
148 | title = self.name
149 | label_html = ''
150 | label_html_row = ''
151 | images = []
152 | idx = 0
153 | for label, image in visuals.items():
154 | image_numpy = util.tensor2im(image)
155 |                     label_html_row += '<td>%s</td>' % label
156 | images.append(image_numpy.transpose([2, 0, 1]))
157 | idx += 1
158 | if idx % ncols == 0:
159 |                         label_html += '<tr>%s</tr>' % label_html_row
160 | label_html_row = ''
161 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
162 | while idx % ncols != 0:
163 | images.append(white_image)
164 |                     label_html_row += '<td></td>'
165 | idx += 1
166 | if label_html_row != '':
167 |                     label_html += '<tr>%s</tr>' % label_html_row
168 | try:
169 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
170 | padding=2, opts=dict(title=title + ' images'))
171 |                     label_html = '<table>%s</table>' % label_html
172 | self.vis.text(table_css + label_html, win=self.display_id + 2,
173 | opts=dict(title=title + ' labels'))
174 | except VisdomExceptionBase:
175 | self.create_visdom_connections()
176 |
177 | else: # show each image in a separate visdom panel;
178 | idx = 1
179 | try:
180 | for label, image in visuals.items():
181 | image_numpy = util.tensor2im(image)
182 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
183 | win=self.display_id + idx)
184 | idx += 1
185 | except VisdomExceptionBase:
186 | self.create_visdom_connections()
187 |
188 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
189 | self.saved = True
190 | # save images to the disk
191 | for label, image in visuals.items():
192 | image_numpy = util.tensor2im(image)
193 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
194 | util.save_image(image_numpy, img_path)
195 |
196 | # update website
197 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
198 | for n in range(epoch, 0, -1):
199 | webpage.add_header('epoch [%d]' % n)
200 | ims, txts, links = [], [], []
201 |
202 |                 for label, image in visuals.items():
203 | image_numpy = util.tensor2im(image)
204 | img_path = 'epoch%.3d_%s.png' % (n, label)
205 | ims.append(img_path)
206 | txts.append(label)
207 | links.append(img_path)
208 | webpage.add_images(ims, txts, links, width=self.win_size)
209 | webpage.save()
210 |
211 | def plot_current_losses(self, epoch, counter_ratio, losses):
212 | """display the current losses on visdom display: dictionary of error labels and values
213 |
214 | Parameters:
215 | epoch (int) -- current epoch
216 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
217 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
218 | """
219 | if not hasattr(self, 'plot_data'):
220 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
221 | self.plot_data['X'].append(epoch + counter_ratio)
222 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
223 | try:
224 | self.vis.line(
225 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
226 | Y=np.array(self.plot_data['Y']),
227 | opts={
228 | 'title': self.name + ' loss over time',
229 | 'legend': self.plot_data['legend'],
230 | 'xlabel': 'epoch',
231 | 'ylabel': 'loss'},
232 | win=self.display_id)
233 | except VisdomExceptionBase:
234 | self.create_visdom_connections()
235 |
236 | # losses: same format as |losses| of plot_current_losses
237 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
238 | """print current losses on console; also save the losses to the disk
239 |
240 | Parameters:
241 | epoch (int) -- current epoch
242 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
243 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
244 | t_comp (float) -- computational time per data point (normalized by batch_size)
245 | t_data (float) -- data loading time per data point (normalized by batch_size)
246 | """
247 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
248 | for k, v in losses.items():
249 | message += '%s: %.3f ' % (k, v)
250 |
251 | print(message) # print the message
252 | with open(self.log_name, "a") as log_file:
253 | log_file.write('%s\n' % message) # save the message
254 |
--------------------------------------------------------------------------------
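To save a result page outside of test.py, save_images only needs an HTML page, an OrderedDict of name -> tensor visuals (in the [-1, 1] range that tensor2im expects), and an image path whose last two components name the saved files. A minimal offline sketch under those assumptions; the paths here are made up, aspect_ratio=1.0 skips the resize step, and importing util.visualizer still requires the pinned scipy==1.2.1 because of its module-level imresize import:

```python
import torch
from collections import OrderedDict
from util import html
from util.visualizer import save_images

webpage = html.HTML('./results/demo', 'demo page')   # creates ./results/demo/images/
visuals = OrderedDict([
    ('real', torch.rand(1, 3, 64, 64) * 2.0 - 1.0),  # tensor2im maps [-1, 1] -> [0, 255]
    ('fake', torch.rand(1, 3, 64, 64) * 2.0 - 1.0),
])
# save_images joins the last two path components, so this writes val_0001_real.png etc.
save_images(webpage, visuals, ['val/0001.JPEG'], aspect_ratio=1.0, width=256)
webpage.save()                                       # writes ./results/demo/index.html
```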