├── .DS_Store
├── .gitignore
├── .gitmodules
├── .idea
│   ├── Face-SPARNet.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── workspace.xml
├── Network.png
├── README.md
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── celeba_dataset.py
│   ├── ffhq_dataset.py
│   ├── image_folder.py
│   └── single_dataset.py
├── img
│   ├── Snipaste_2023-05-09_22-20-55.png
│   └── compare_CelebA_0213.png
├── list.txt
├── log_test
│   ├── .DS_Store
│   └── log.txt
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── blocks.py
│   ├── common.py
│   ├── common_ESTR.py
│   ├── ctcnet.py
│   ├── ctcnet_model.py
│   ├── loss.py
│   ├── networks.py
│   ├── rlutrans.py
│   └── utils
│       ├── logger.py
│       ├── rlutrans.py
│       ├── timer.py
│       ├── tools.py
│       └── utils.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── psnr_ssim.py
├── psnr_ssim_log.py
├── read list.py
├── requirement.txt
├── test.py
├── test.sh
├── train.py
├── train.sh
├── util
│   ├── .DS_Store
│   ├── __init__.py
│   ├── rlutrans.py
│   └── tools.py
└── utils
    ├── __init__.py
    ├── logger.py
    ├── timer.py
    └── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CTCNet/4406888c1f8d01a612b993334bce835899483a97/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | check_points/
2 | pretrain_models/
3 | results*
4 | test_dirs/
5 | test_hzgg_results/
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | env/
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # dotenv
89 | .env
90 |
91 | # virtualenv
92 | .venv
93 | venv/
94 | ENV/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "metrics/face-alignment"]
2 | path = metrics/face-alignment
3 | url = https://github.com/1adrianb/face-alignment.git
4 |
--------------------------------------------------------------------------------
/.idea/Face-SPARNet.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CTCNet/4406888c1f8d01a612b993334bce835899483a97/Network.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CTCNet: A CNN-Transformer Cooperation Network for Face Image Super-Resolution in PyTorch
2 |
3 | [A CNN-Transformer Cooperation Network for Face Image Super-Resolution](https://arxiv.org/abs/2204.08696v2)
4 | [**Guangwei Gao**](https://guangweigao.github.io/), [Zixiang Xu](https://github.com/wsxuzixiang)
5 |
6 | 
7 |
8 | ## Comparisons for ×8 SR on the CelebA test set
9 | 
10 |
11 |
12 | 
13 |
14 |
15 | ## Installation and Requirements
16 |
17 | Clone this repository
18 | ```
19 | git clone https://github.com/IVIPLab/CTCNet
20 | cd CTCNet
21 | ```
22 |
23 | Install the required packages:
24 | - `pip install -r requirement.txt`
25 |
26 |
27 | ### Test with Pretrained Models
28 |
29 | We provide example test commands for CTCNet in the script `test.sh`. Pretrained models can be downloaded from the link in the [section below](#pretrained-models). Here are some test tips:
30 |
31 | - CTCNet upsamples a 16x16 bicubically downsampled face image to 128x128, and there is **no need to align the LR face**.
32 | - Please specify the test input directory with the `--dataroot` option.
33 | - Please specify the save path with `--save_as_dir`; otherwise, the results will be saved to the predefined directory `results/exp_name/test_latest`.
34 |
35 | ### Train the Model
36 |
37 | The commands used to train the released models are provided in the script `train.sh`. Here are some training tips:
38 |
39 | - Download [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) to train CTCNet and CTCGAN. Please change `--dataroot` to the path where your training images are stored.
40 | - To train CTCNet, we simply crop out faces from CelebA without pre-alignment, because for ultra-low-resolution face SR it is difficult to pre-align the LR images.
41 | - Please change the `--name` option for different experiments. Tensorboard records with the same name will be moved to `check_points/log_archive`, and the weight directory only stores the weight history of the latest experiment with the same name.
42 | - `--gpus` specifies the number of GPUs used for training. The script uses the GPUs with the most available memory first. To specify GPU indices manually, uncomment the `export CUDA_VISIBLE_DEVICES=` line.
43 |
44 | ### Pretrained models
45 |
46 | The **pretrained models** and **test results** can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1sJs2JYqddSk1o4hksOrO2Fk2ciRelXUQ) .
47 |
48 | ## Citation
49 | ```
50 | @article{gao2023ctcnet,
51 | title={CTCNet: A CNN-Transformer Cooperation Network for Face Image Super-Resolution},
52 | author={Gao, Guangwei and Xu, Zixiang and Li, Juncheng and Yang, Jian and Zeng, Tieyong and Qi, Guo-Jun},
53 | journal={IEEE Transactions on Image Processing},
54 | year={2023},
55 | publisher={IEEE}
56 | }
57 | ```
58 |
59 | ## License
60 |
61 | This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
62 |
63 | ## Acknowledgement
64 |
65 | The codes are based on [SPARNet](https://github.com/chaofengc/Face-SPARNet).
66 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_name dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 |
17 |
18 | def find_dataset_using_name(dataset_name):
19 | """Import the module "data/[dataset_name]_dataset.py".
20 |
21 | In the file, the class called DatasetNameDataset() will
22 | be instantiated. It has to be a subclass of BaseDataset,
23 | and it is case-insensitive.
24 | """
25 | dataset_filename = "data." + dataset_name + "_dataset"
26 | datasetlib = importlib.import_module(dataset_filename)
27 |
28 | dataset = None
29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30 | for name, cls in datasetlib.__dict__.items():
31 | if name.lower() == target_dataset_name.lower() \
32 | and issubclass(cls, BaseDataset):
33 | dataset = cls
34 |
35 | if dataset is None:
36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37 |
38 | return dataset
39 |
40 |
41 | def get_option_setter(dataset_name):
42 | """Return the static method of the dataset class."""
43 | dataset_class = find_dataset_using_name(dataset_name)
44 | return dataset_class.modify_commandline_options
45 |
46 |
47 | def create_dataset(opt):
48 | """Create a dataset given the option.
49 |
50 | This function wraps the class CustomDatasetDataLoader.
51 | This is the main interface between this package and 'train.py'/'test.py'
52 |
53 | Example:
54 | >>> from data import create_dataset
55 | >>> dataset = create_dataset(opt)
56 | """
57 | data_loader = CustomDatasetDataLoader(opt)
58 | dataset = data_loader.load_data()
59 | return dataset
60 |
61 |
62 | class CustomDatasetDataLoader():
63 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 |
65 | def __init__(self, opt):
66 | """Initialize this class
67 |
68 | Step 1: create a dataset instance given the name [dataset_mode]
69 | Step 2: create a multi-threaded data loader.
70 | """
71 | self.opt = opt
72 | dataset_class = find_dataset_using_name(opt.dataset_name)
73 | self.dataset = dataset_class(opt)
74 | print("dataset [%s] was created" % type(self.dataset).__name__)
75 | drop_last = True if opt.isTrain else False
76 | self.dataloader = torch.utils.data.DataLoader(
77 | self.dataset,
78 | batch_size=opt.batch_size,
79 | shuffle=not opt.serial_batches,
80 | num_workers=int(opt.num_threads), drop_last=drop_last)
81 |
82 | def load_data(self):
83 | return self
84 |
85 | def __len__(self):
86 | """Return the number of data in the dataset"""
87 | return min(len(self.dataset), self.opt.max_dataset_size)
88 |
89 | def __iter__(self):
90 | """Return a batch of data"""
91 | for i, data in enumerate(self.dataloader):
92 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
93 | break
94 | yield data
95 |
--------------------------------------------------------------------------------
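The package docstring above spells out the methods a new dataset must implement. As a minimal, hedged sketch (a hypothetical `data/dummy_dataset.py`, not part of this repo), it could look like this:

```
"""Hypothetical data/dummy_dataset.py -- a minimal sketch, not part of this repo."""
import torch

from data.base_dataset import BaseDataset


class DummyDataset(BaseDataset):
    """Picked up by find_dataset_using_name('dummy'): the class name matches
    'dummydataset' (case-insensitive) and it subclasses BaseDataset."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        # (optional) add dataset-specific options or change defaults here
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)   # stores opt and opt.dataroot
        self.num_samples = 16             # placeholder size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # Return a dict of tensors; the 'HR'/'LR' keys mirror celeba_dataset.py
        # so the rest of the pipeline can consume the batch unchanged.
        hr = torch.rand(3, 128, 128)
        lr = torch.rand(3, 128, 128)
        return {'HR': hr, 'LR': lr, 'HR_paths': 'dummy_%d.png' % index}
```

With such a file in place, `create_dataset(opt)` above would load it when `opt.dataset_name` is set to `dummy`.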
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 |
12 | import imgaug as ia
13 | import imgaug.augmenters as iaa
14 |
15 | class BaseDataset(data.Dataset, ABC):
16 | """This class is an abstract base class (ABC) for datasets.
17 |
18 | To create a subclass, you need to implement the following four functions:
19 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
20 | -- <__len__>: return the size of dataset.
21 | -- <__getitem__>: get a data point.
22 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
23 | """
24 |
25 | def __init__(self, opt):
26 | """Initialize the class; save the options in the class
27 |
28 | Parameters:
29 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
30 | """
31 | self.opt = opt
32 | self.root = opt.dataroot
33 |
34 | @staticmethod
35 | def modify_commandline_options(parser, is_train):
36 | """Add new dataset-specific options, and rewrite default values for existing options.
37 |
38 | Parameters:
39 | parser -- original option parser
40 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
41 |
42 | Returns:
43 | the modified parser.
44 | """
45 | return parser
46 |
47 | @abstractmethod
48 | def __len__(self):
49 | """Return the total number of images in the dataset."""
50 | return 0
51 |
52 | @abstractmethod
53 | def __getitem__(self, index):
54 | """Return a data point and its metadata information.
55 |
56 | Parameters:
57 | index - - a random integer for data indexing
58 |
59 | Returns:
60 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
61 | """
62 | pass
63 |
64 |
65 | def get_params(opt, size):
66 | w, h = size
67 | new_h = h
68 | new_w = w
69 | if opt.preprocess == 'resize_and_crop':
70 | new_h = new_w = opt.load_size
71 | elif opt.preprocess == 'scale_width_and_crop':
72 | new_w = opt.load_size
73 | new_h = opt.load_size * h // w
74 |
75 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
76 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
77 |
78 | flip = random.random() > 0.5
79 |
80 | return {'crop_pos': (x, y), 'flip': flip}
81 |
82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
83 | transform_list = []
84 | if grayscale:
85 | # transform_list.append(transforms.Grayscale(1))
86 | from util import util
87 | transform_list.append(util.RGBtoY)
88 | if 'resize' in opt.preprocess:
89 | osize = [opt.load_size, opt.load_size]
90 | transform_list.append(transforms.Resize(osize, method))
91 | elif 'scale_width' in opt.preprocess:
92 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
93 |
94 | if 'crop' in opt.preprocess:
95 | if params is None:
96 | transform_list.append(transforms.RandomCrop(opt.crop_size))
97 | else:
98 | if 'crop_size' in params:
99 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], params['crop_size'])))
100 | else:
101 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
102 |
103 | if opt.preprocess == 'none':
104 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
105 |
106 | if not opt.no_flip:
107 | if params is None:
108 | transform_list.append(transforms.RandomHorizontalFlip())
109 | elif params['flip']:
110 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
111 |
112 | if convert:
113 | transform_list += [transforms.ToTensor()]
114 | if grayscale:
115 | transform_list += [transforms.Normalize((0.5,), (0.5,))]
116 | else:
117 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
118 | return transforms.Compose(transform_list)
119 |
120 |
121 | def __make_power_2(img, base, method=Image.BICUBIC):
122 | ow, oh = img.size
123 | h = int(round(oh / base) * base)
124 | w = int(round(ow / base) * base)
125 | if (h == oh) and (w == ow):
126 | return img
127 |
128 | __print_size_warning(ow, oh, w, h)
129 | return img.resize((w, h), method)
130 |
131 |
132 | def __scale_width(img, target_width, method=Image.BICUBIC):
133 | ow, oh = img.size
134 | if (ow == target_width):
135 | return img
136 | w = target_width
137 | h = int(target_width * oh / ow)
138 | return img.resize((w, h), method)
139 |
140 |
141 | def __crop(img, pos, size):
142 | ow, oh = img.size
143 | x1, y1 = pos
144 | tw = th = size
145 | if (ow > tw or oh > th):
146 | return img.crop((x1, y1, x1 + tw, y1 + th))
147 | return img
148 |
149 |
150 | def __flip(img, flip):
151 | if flip:
152 | return img.transpose(Image.FLIP_LEFT_RIGHT)
153 | return img
154 |
155 |
156 | def __print_size_warning(ow, oh, w, h):
157 | """Print warning information about image size(only print once)"""
158 | if not hasattr(__print_size_warning, 'has_printed'):
159 | print("The image size needs to be a multiple of 4. "
160 | "The loaded image size was (%d, %d), so it was adjusted to "
161 | "(%d, %d). This adjustment will be done to all images "
162 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
163 | __print_size_warning.has_printed = True
164 |
165 |
--------------------------------------------------------------------------------
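For reference, a rough usage sketch of `get_params`/`get_transform` defined above; the concrete option values (`load_size=144`, `crop_size=128`, etc.) are assumptions for illustration, not defaults taken from this repo:

```
# Minimal sketch: build an option namespace and run the transform pipeline once.
from argparse import Namespace

from PIL import Image

from data.base_dataset import get_params, get_transform

opt = Namespace(preprocess='resize_and_crop', load_size=144,
                crop_size=128, no_flip=False)

img = Image.new('RGB', (178, 218))        # stand-in for a CelebA crop
params = get_params(opt, img.size)        # shared random crop position / flip flag
transform = get_transform(opt, params)    # Resize -> crop -> (flip) -> ToTensor -> Normalize
tensor = transform(img)                   # shape (3, 128, 128), values in [-1, 1]
```

Reusing the same `params` for paired HR/LR images keeps their random crop and flip consistent.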
/data/celeba_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from PIL import Image
5 | import imgaug as ia
6 | import imgaug.augmenters as iaa
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import transforms
11 | import torchvision.transforms.functional as tf
12 |
13 | from data.base_dataset import BaseDataset
14 |
15 |
16 | class CelebADataset(BaseDataset):
17 | def __init__(self, opt):
18 | BaseDataset.__init__(self, opt)
19 |
20 | self.shuffle = True if opt.isTrain else False
21 | self.lr_size = opt.load_size // opt.scale_factor
22 | self.hr_size = opt.load_size
23 |
24 | self.img_dir = opt.dataroot
25 | self.img_names = self.get_img_names()
26 |
27 | self.aug = transforms.Compose([
28 | transforms.RandomHorizontalFlip(),
29 | Scale((1.0, 1.3), opt.load_size)
30 | ])
31 |
32 | self.to_tensor = transforms.Compose([
33 | transforms.ToTensor(),
34 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
35 | ])
36 |
37 | def get_img_names(self,):
38 | img_names = [x for x in os.listdir(self.img_dir)]
39 | if self.shuffle:
40 | random.shuffle(img_names)
41 | return img_names
42 |
43 | def __len__(self,):
44 | return len(self.img_names)
45 |
46 | def __getitem__(self, idx):
47 | img_path = os.path.join(self.img_dir, self.img_names[idx])
48 |
49 | hr_img = Image.open(img_path).convert('RGB')
50 | hr_img = self.aug(hr_img)
51 |
52 | # downsample and upsample to get the LR image
53 | lr_img = hr_img.resize((self.lr_size, self.lr_size), Image.BICUBIC)
54 | lr_img_up = lr_img.resize((self.hr_size, self.hr_size), Image.BICUBIC)
55 |
56 | hr_tensor = self.to_tensor( hr_img)
57 | lr_tensor = self.to_tensor(lr_img_up)
58 |
59 | return {'HR': hr_tensor, 'LR': lr_tensor, 'HR_paths': img_path}
60 |
61 |
62 | class Scale():
63 | """
64 | Random scale the image and pad to the same size if needed.
65 | ---------------
66 | # Args:
67 | factor: tuple input, min and max scale factor.
68 | """
69 | def __init__(self, factor, size):
70 | self.factor = factor
71 | rc_scale = (2 - factor[1], 1)
72 | self.size = (size, size)
73 | self.rc_scale = rc_scale
74 | self.ratio = (3. / 4., 4. / 3.)
75 | self.resize_crop = transforms.RandomResizedCrop(size, rc_scale)
76 |
77 | def __call__(self, img):
78 | scale_factor = random.random() * (self.factor[1] - self.factor[0]) + self.factor[0]
79 | w, h = img.size
80 | sw, sh = int(w*scale_factor), int(h*scale_factor)
81 | scaled_img = tf.resize(img, (sh, sw))
82 | if sw > w:
83 | i, j, h, w = self.resize_crop.get_params(img, self.rc_scale, self.ratio)
84 | scaled_img = tf.resized_crop(img, i, j, h, w, self.size, Image.BICUBIC)
85 | elif sw < w:
86 | lp = (w - sw) // 2
87 | tp = (h - sh) // 2
88 | padding = (lp, tp, w - sw - lp, h - sh - tp)
89 | scaled_img = tf.pad(scaled_img, padding)
90 | return scaled_img
91 |
92 |
--------------------------------------------------------------------------------
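Under the ×8 CelebA setting from the README, the degradation in `__getitem__` above reduces to a bicubic down/up round trip; the sizes 128 and 16 below assume `load_size=128` and `scale_factor=8`:

```
# Sketch of the HR/LR pair generation above for the x8 setting (assumed sizes).
from PIL import Image

hr_img = Image.new('RGB', (128, 128))               # stand-in for a cropped CelebA face
lr_img = hr_img.resize((16, 16), Image.BICUBIC)     # 16x16 LR face
lr_up = lr_img.resize((128, 128), Image.BICUBIC)    # bicubic-upsampled network input
```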
/data/ffhq_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from PIL import Image
5 | import imgaug as ia
6 | import imgaug.augmenters as iaa
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import transforms
11 |
12 | from data.base_dataset import BaseDataset
13 |
14 |
15 | class FFHQDataset(BaseDataset):
16 | def __init__(self, opt):
17 | BaseDataset.__init__(self, opt)
18 | self.img_size = opt.load_size
19 | self.shuffle = True if opt.isTrain else False
20 |
21 | self.img_dir = opt.dataroot
22 | self.img_names = self.get_img_names()
23 |
24 | self.to_tensor = transforms.Compose([
25 | transforms.ToTensor(),
26 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
27 | ])
28 |
29 | def get_img_names(self,):
30 | img_names = [x for x in os.listdir(self.img_dir)]
31 | if self.shuffle:
32 | random.shuffle(img_names)
33 | return img_names
34 |
35 | def __len__(self,):
36 | return len(self.img_names)
37 |
38 | def __getitem__(self, idx):
39 | sample = {}
40 | img_path = os.path.join(self.img_dir, self.img_names[idx])
41 |
42 | hr_img = Image.open(img_path).convert('RGB')
43 | hr_img = hr_img.resize((self.img_size, self.img_size))
44 | hr_img = random_gray(hr_img, p=0.3)
45 | scale_size = np.random.randint(32, 128)
46 | lr_img = complex_imgaug(hr_img, self.img_size, scale_size)
47 |
48 | hr_tensor = self.to_tensor(hr_img)
49 | lr_tensor = self.to_tensor(lr_img)
50 |
51 | return {'HR': hr_tensor, 'LR': lr_tensor, 'HR_paths': img_path}
52 |
53 |
54 | def complex_imgaug(x, org_size, scale_size):
55 | """input single RGB PIL Image instance"""
56 | x = np.array(x)
57 | x = x[np.newaxis, :, :, :]
58 | aug_seq = iaa.Sequential([
59 | iaa.Sometimes(0.5, iaa.OneOf([
60 | iaa.GaussianBlur((3, 15)),
61 | iaa.AverageBlur(k=(3, 15)),
62 | iaa.MedianBlur(k=(3, 15)),
63 | iaa.MotionBlur((5, 25))
64 | ])),
65 | iaa.Resize(scale_size, interpolation=ia.ALL),
66 | iaa.Sometimes(0.2, iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.1*255), per_channel=0.5)),
67 | iaa.Sometimes(0.7, iaa.JpegCompression(compression=(10, 65))),
68 | iaa.Resize(org_size),
69 | ])
70 |
71 | aug_img = aug_seq(images=x)
72 | return aug_img[0]
73 |
74 |
75 | def random_gray(x, p=0.5):
76 | """input single RGB PIL Image instance"""
77 | x = np.array(x)
78 | x = x[np.newaxis, :, :, :]
79 | aug = iaa.Sometimes(p, iaa.Grayscale(alpha=1.0))
80 | aug_img = aug(images=x)
81 | return aug_img[0]
82 |
83 |
--------------------------------------------------------------------------------
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 |
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 | import os.path
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | '.tif', '.TIF', '.tiff', '.TIFF',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir, max_dataset_size=float("inf")):
25 | images = []
26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
27 |
28 | for root, _, fnames in sorted(os.walk(dir)):
29 | for fname in fnames:
30 | if is_image_file(fname):
31 | path = os.path.join(root, fname)
32 | images.append(path)
33 | return images[:min(max_dataset_size, len(images))]
34 |
35 |
36 | def default_loader(path):
37 | return Image.open(path).convert('RGB')
38 |
39 |
40 | class ImageFolder(data.Dataset):
41 |
42 | def __init__(self, root, transform=None, return_paths=False,
43 | loader=default_loader):
44 | imgs = make_dataset(root)
45 | if len(imgs) == 0:
46 | raise(RuntimeError("Found 0 images in: " + root + "\n"
47 | "Supported image extensions are: " +
48 | ",".join(IMG_EXTENSIONS)))
49 |
50 | self.root = root
51 | self.imgs = imgs
52 | self.transform = transform
53 | self.return_paths = return_paths
54 | self.loader = loader
55 |
56 | def __getitem__(self, index):
57 | path = self.imgs[index]
58 | img = self.loader(path)
59 | if self.transform is not None:
60 | img = self.transform(img)
61 | if self.return_paths:
62 | return img, path
63 | else:
64 | return img
65 |
66 | def __len__(self):
67 | return len(self.imgs)
68 |
--------------------------------------------------------------------------------
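A short usage sketch of the helpers above (the directory path is a placeholder):

```
# Minimal sketch: list images recursively, then wrap them in an ImageFolder.
from torchvision import transforms

from data.image_folder import ImageFolder, make_dataset

paths = make_dataset('/path/to/images')          # recursive listing filtered by IMG_EXTENSIONS
dataset = ImageFolder('/path/to/images',
                      transform=transforms.ToTensor(),
                      return_paths=True)
img, path = dataset[0]                           # (C, H, W) tensor and its file path
```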
/data/single_dataset.py:
--------------------------------------------------------------------------------
1 | from data.base_dataset import BaseDataset, get_transform
2 | from data.image_folder import make_dataset
3 | from PIL import Image
4 |
5 | class SingleDataset(BaseDataset):
6 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data.
7 | """
8 |
9 | def __init__(self, opt):
10 | """Initialize this dataset class.
11 |
12 | Parameters:
13 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
14 | """
15 | BaseDataset.__init__(self, opt)
16 | self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
17 | input_nc = self.opt.output_nc
18 | self.transform = get_transform(opt, grayscale=(input_nc == 1))
19 |
20 | def __getitem__(self, index):
21 | """Return a data point and its metadata information.
22 |
23 | Parameters:
24 | index - - a random integer for data indexing
25 |
26 | Returns a dictionary that contains A and A_paths
27 | A(tensor) - - an image in one domain
28 | A_paths(str) - - the path of the image
29 | """
30 | A_path = self.A_paths[index]
31 | A_img = Image.open(A_path).convert('RGB')
32 | A_img = A_img.resize((self.opt.load_size//8, self.opt.load_size//8), Image.BICUBIC)
33 | A_img = A_img.resize((self.opt.load_size , self.opt.load_size ), Image.BICUBIC)
34 | A = self.transform(A_img)
35 | return {'LR': A, 'LR_paths': A_path}
36 |
37 | def __len__(self):
38 | """Return the total number of images in the dataset."""
39 | return len(self.A_paths)
40 |
--------------------------------------------------------------------------------
/img/Snipaste_2023-05-09_22-20-55.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CTCNet/4406888c1f8d01a612b993334bce835899483a97/img/Snipaste_2023-05-09_22-20-55.png
--------------------------------------------------------------------------------
/img/compare_CelebA_0213.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CTCNet/4406888c1f8d01a612b993334bce835899483a97/img/compare_CelebA_0213.png
--------------------------------------------------------------------------------
/list.txt:
--------------------------------------------------------------------------------
1 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/30/
2 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/120300/
3 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/120600/
4 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/120900/
5 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/121200/
6 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/121500/
7 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/121800/
8 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/122100/
9 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/122400/
10 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/122700/
11 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/123000/
12 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/123300/
13 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/123600/
14 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/123900/
15 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/124200/
16 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/124500/
17 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/124800/
18 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/125100/
19 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/125400/
20 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/125700/
21 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/126000/
22 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/126300/
23 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/126600/
24 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/126900/
25 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/127200/
26 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/127500/
27 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/127800/
28 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/128100/
29 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/128400/
30 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/128700/
31 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/129000/
32 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/129300/
33 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/129600/
34 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/129900/
35 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/130200/
36 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/130500/
37 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/130800/
38 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/131100/
39 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/131400/
40 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/131700/
41 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/132000/
42 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/132300/
43 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/132600/
44 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/132900/
45 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/133200/
46 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/133500/
47 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/133800/
48 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/134100/
49 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/134400/
50 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/134700/
51 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/135000/
52 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/135300/
53 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/135600/
54 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/135900/
55 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/136200/
56 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/136500/
57 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/136800/
58 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/137100/
59 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/137400/
60 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/137700/
61 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/138000/
62 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/138300/
63 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/138600/
64 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/138900/
65 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/139200/
66 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/139500/
67 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/139800/
68 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/140100/
69 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/140400/
70 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/140700/
71 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/141000/
72 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/141300/
73 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/141600/
74 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/141900/
75 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/142200/
76 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/142500/
77 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/142800/
78 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/143100/
79 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/143400/
80 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/143700/
81 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/144000/
82 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/144300/
83 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/144600/
84 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/144900/
85 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/145200/
86 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/145500/
87 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/145800/
88 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/146100/
89 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/146400/
90 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/146700/
91 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/147000/
92 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/147300/
93 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/147600/
94 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/147900/
95 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/148200/
96 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/148500/
97 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/148800/
98 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/149100/
99 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/149400/
100 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/149700/
101 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/150000/
102 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/150300/
103 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/150600/
104 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/150900/
105 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/151200/
106 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/151500/
107 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/151800/
108 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/152100/
109 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/152400/
110 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/152700/
111 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/153000/
112 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/153300/
113 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/153600/
114 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/153900/
115 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/154200/
116 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/154500/
117 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/154800/
118 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/155100/
119 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/155400/
120 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/155700/
121 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/156000/
122 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/156300/
123 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/156600/
124 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/156900/
125 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/157200/
126 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/157500/
127 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/157800/
128 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/158100/
129 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/158400/
130 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/158700/
131 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/159000/
132 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/159300/
133 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/159600/
134 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/159900/
135 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/160200/
136 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/160500/
137 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/160800/
138 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/161100/
139 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/161400/
140 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/161700/
141 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/162000/
142 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/162300/
143 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/162600/
144 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/162900/
145 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/163200/
146 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/163500/
147 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/163800/
148 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/164100/
149 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/164400/
150 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/164700/
151 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/165000/
152 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/165300/
153 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/165600/
154 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/165900/
155 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/166200/
156 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/166500/
157 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/166800/
158 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/167100/
159 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/167400/
160 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/167700/
161 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/168000/
162 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/168300/
163 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/168600/
164 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/168900/
165 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/169200/
166 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/169500/
167 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/169800/
168 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/170100/
169 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/170400/
170 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/170700/
171 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/171000/
172 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/171300/
173 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/171600/
174 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/171900/
175 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/172200/
176 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/172500/
177 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/172800/
178 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/173100/
179 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/173400/
180 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/173700/
181 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/174000/
182 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/174300/
183 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/174600/
184 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/174900/
185 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/175200/
186 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/175500/
187 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/175800/
188 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/176100/
189 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/176400/
190 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/176700/
191 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/177000/
192 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/177300/
193 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/177600/
194 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/177900/
195 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/178200/
196 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/178500/
197 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/178800/
198 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/179100/
199 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/179400/
200 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/179700/
201 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/180000/
202 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/180300/
203 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/180600/
204 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/180900/
205 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/181200/
206 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/181500/
207 | /home2/ZiXiangXu/Last_ding/res4_jiu/result/181800/
208 |
--------------------------------------------------------------------------------
/log_test/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CTCNet/4406888c1f8d01a612b993334bce835899483a97/log_test/.DS_Store
--------------------------------------------------------------------------------
/log_test/log.txt:
--------------------------------------------------------------------------------
1 | 12.0783
2 | 28.3220
3 | 28.3708
4 | 28.3204
5 | 28.3169
6 | 28.3368
7 | 28.3303
8 | 28.3376
9 | 28.3043
10 | 28.3455
11 | 28.3427
12 | 28.3043
13 | 28.3420
14 | 28.3322
15 | 28.3534
16 | 28.3355
17 | 28.3511
18 | 28.3273
19 | 28.3317
20 | 28.3479
21 | 28.3109
22 | 28.3295
23 | 28.3343
24 | 28.3320
25 | 28.3320
26 | 28.3173
27 | 28.3268
28 | 28.3168
29 | 28.3313
30 | 28.3388
31 | 28.3352
32 | 28.3314
33 | 28.2833
34 | 28.3357
35 | 28.3168
36 | 28.3110
37 | 28.3277
38 | 28.2818
39 | 28.3360
40 | 28.3478
41 | 28.3134
42 | 28.3129
43 | 28.3344
44 | 28.3179
45 | 28.3556
46 | 28.3111
47 | 28.3330
48 | 28.3280
49 | 28.3254
50 | 28.3229
51 | 28.3637
52 | 28.3306
53 | 28.3196
54 | 28.3532
55 | 28.3433
56 | 28.3104
57 | 28.3350
58 | 28.3331
59 | 28.3263
60 | 28.3453
61 | 28.3390
62 | 28.3335
63 | 28.3625
64 | 28.3183
65 | 28.3324
66 | 28.3158
67 | 28.3430
68 | 28.3575
69 | 28.3496
70 | 28.3030
71 | 28.3165
72 | 28.3404
73 | 28.3460
74 | 28.3220
75 | 28.3133
76 | 28.3043
77 | 28.3441
78 | 28.3283
79 | 28.3305
80 | 28.3587
81 | 28.3322
82 | 28.3368
83 | 28.3249
84 | 28.3503
85 | 28.3056
86 | 28.3166
87 | 28.3644
88 | 28.3331
89 | 28.3308
90 | 28.3278
91 | 28.3275
92 | 28.3343
93 | 28.2769
94 | 28.2736
95 | 28.3392
96 | 28.2909
97 | 28.3435
98 | 28.3099
99 | 28.3412
100 | 28.3531
101 | 28.3014
102 | 28.3428
103 | 28.3202
104 | 28.3460
105 | 28.3500
106 | 28.3497
107 | 28.3392
108 | 28.3316
109 | 28.3482
110 | 28.3585
111 | 28.3302
112 | 28.3519
113 | 28.3341
114 | 28.3564
115 | 28.3363
116 | 28.3425
117 | 28.3388
118 | 28.3116
119 | 28.3341
120 | 28.3073
121 | 28.3334
122 | 28.3283
123 | 28.2865
124 | 28.3707
125 | 28.3330
126 | 28.3388
127 | 28.3180
128 | 28.3550
129 | 28.3425
130 | 28.3092
131 | 28.3180
132 | 28.3432
133 | 28.3047
134 | 28.3071
135 | 28.3340
136 | 28.3279
137 | 28.3185
138 | 28.3296
139 | 28.3274
140 | 28.3160
141 | 28.3480
142 | 28.3357
143 | 28.3403
144 | 28.3289
145 | 28.3628
146 | 28.3637
147 | 28.3303
148 | 28.3436
149 | 28.3389
150 | 28.3285
151 | 28.3487
152 | 28.3357
153 | 28.2994
154 | 28.3394
155 | 28.3265
156 | 28.3439
157 | 28.3053
158 | 28.3387
159 | 28.3095
160 | 28.3175
161 | 28.3418
162 | 28.3320
163 | 28.3314
164 | 28.3302
165 | 28.3545
166 | 28.3434
167 | 28.3423
168 | 28.3424
169 | 28.3365
170 | 28.3671
171 | 28.2795
172 | 28.3650
173 | 28.3424
174 | 28.3152
175 | 28.3450
176 | 28.3658
177 | 28.3301
178 | 28.3419
179 | 28.3350
180 | 28.3159
181 | 28.3191
182 | 28.3227
183 | 28.3352
184 | 28.3256
185 | 28.3616
186 | 28.3221
187 | 28.3580
188 | 28.3277
189 | 28.3325
190 | 28.3275
191 | 28.3465
192 | 28.3416
193 | 28.3519
194 | 28.3454
195 | 28.3446
196 | 28.3256
197 | 28.3471
198 | 28.3252
199 | 28.3131
200 | 28.3357
201 | 28.3301
202 | 28.3096
203 | 28.2849
204 | 28.3513
205 | 28.2897
206 | 28.3352
207 | 28.3462
208 | 12.0783
209 | 28.3220
210 | 12.0783
211 | 28.3220
212 | 28.3708
213 | 28.3204
214 | 28.3169
215 | 28.3368
216 | 28.3303
217 | 28.3376
218 | 28.3043
219 | 28.3455
220 | 28.3427
221 | 28.3043
222 | 28.3420
223 | 28.3322
224 | 28.3534
225 | 28.3355
226 | 28.3511
227 | 28.3273
228 | 28.3317
229 | 28.3479
230 | 28.3109
231 | 28.3295
232 | 28.3343
233 | 28.3320
234 | 28.3320
235 | 28.3173
236 | 28.3268
237 | 28.3168
238 | 28.3313
239 | 28.3388
240 | 28.3352
241 | 28.3314
242 | 28.2833
243 | 28.3357
244 | 28.3168
245 | 28.3110
246 | 28.3277
247 | 28.2818
248 | 28.3360
249 | 28.3478
250 | 28.3134
251 | 28.3129
252 | 28.3344
253 | 28.3179
254 | 28.3556
255 | 28.3111
256 | 28.3330
257 | 28.3280
258 | 28.3254
259 | 28.3229
260 | 28.3637
261 | 28.3306
262 | 28.3196
263 | 28.3532
264 | 28.3433
265 | 28.3104
266 | 28.3350
267 | 28.3331
268 | 28.3263
269 | 28.3453
270 | 28.3390
271 | 28.3335
272 | 28.3625
273 | 28.3183
274 | 28.3324
275 | 28.3158
276 | 28.3430
277 | 28.3575
278 | 28.3496
279 | 28.3030
280 | 28.3165
281 | 28.3404
282 | 28.3460
283 | 28.3220
284 | 28.3133
285 | 28.3043
286 | 28.3441
287 | 28.3283
288 | 28.3305
289 | 28.3587
290 | 28.3322
291 | 28.3368
292 | 28.3249
293 | 28.3503
294 | 28.3056
295 | 28.3166
296 | 28.3644
297 | 28.3331
298 | 28.3308
299 | 28.3278
300 | 28.3275
301 | 28.3343
302 | 28.2769
303 | 28.2736
304 | 28.3392
305 | 28.2909
306 | 28.3435
307 | 28.3099
308 | 28.3412
309 | 28.3531
310 | 28.3014
311 | 28.3428
312 | 28.3202
313 | 28.3460
314 | 28.3500
315 | 28.3497
316 | 28.3392
317 | 28.3316
318 | 28.3482
319 | 28.3585
320 | 28.3302
321 | 28.3519
322 | 28.3341
323 | 28.3564
324 | 28.3363
325 | 28.3425
326 | 28.3388
327 | 28.3116
328 | 28.3341
329 | 28.3073
330 | 28.3334
331 | 28.3283
332 | 28.2865
333 | 28.3707
334 | 28.3330
335 | 28.3388
336 | 28.3180
337 | 28.3550
338 | 28.3425
339 | 28.3092
340 | 28.3180
341 | 28.3432
342 | 28.3047
343 | 28.3071
344 | 28.3340
345 | 28.3279
346 | 28.3185
347 | 28.3296
348 | 28.3274
349 | 28.3160
350 | 28.3480
351 | 28.3357
352 | 28.3403
353 | 28.3289
354 | 28.3628
355 | 28.3637
356 | 28.3303
357 | 28.3436
358 | 28.3389
359 | 28.3285
360 | 28.3487
361 | 28.3357
362 | 28.2994
363 | 28.3394
364 | 28.3265
365 | 28.3439
366 | 28.3053
367 | 28.3387
368 | 28.3095
369 | 28.3175
370 | 28.3418
371 | 28.3320
372 | 28.3314
373 | 28.3302
374 | 28.3545
375 | 28.3434
376 | 28.3423
377 | 28.3424
378 | 28.3365
379 | 28.3671
380 | 28.2795
381 | 28.3650
382 | 28.3424
383 | 28.3152
384 | 28.3450
385 | 28.3658
386 | 28.3301
387 | 28.3419
388 | 28.3350
389 | 28.3159
390 | 28.3191
391 | 28.3227
392 | 28.3352
393 | 28.3256
394 | 28.3616
395 | 28.3221
396 | 28.3580
397 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 | -- <set_input>: unpack data from dataset and apply preprocessing.
7 | -- <forward>: produce intermediate results.
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 | In the file, the class called DatasetNameModel() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(0)
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 | This function creates a model instance given the options.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
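Mirroring the dataset example earlier, the package docstring above translates into roughly the following skeleton; this is a hypothetical `models/dummy_model.py` for illustration only (the real model in this repo is `ctcnet_model.py`):

```
"""Hypothetical models/dummy_model.py -- a minimal sketch, not part of this repo."""
import torch

from models.base_model import BaseModel


class DummyModel(BaseModel):
    """Selected with '--model dummy' via find_model_using_name."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser                                # (optional) model-specific options

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['Pix']                    # printed/saved as 'Loss_Pix'
        self.model_names = ['G']                     # expects an attribute named self.netG
        self.visual_names = ['img_LR', 'img_SR']     # images to display and save
        self.netG = torch.nn.Conv2d(3, 3, 3, padding=1).to(self.device)
        self.optimizers = [torch.optim.Adam(self.netG.parameters(), lr=1e-4)]
        self.criterion = torch.nn.L1Loss()

    def set_input(self, input):
        self.img_LR = input['LR'].to(self.device)
        self.img_HR = input['HR'].to(self.device)

    def forward(self):
        self.img_SR = self.netG(self.img_LR)

    def optimize_parameters(self):
        self.forward()
        self.loss_Pix = self.criterion(self.img_SR, self.img_HR)
        self.optimizers[0].zero_grad()
        self.loss_Pix.backward()
        self.optimizers[0].step()
```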
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 | from . import networks
6 |
7 | class BaseModel(ABC):
8 | """This class is an abstract base class (ABC) for models.
9 | To create a subclass, you need to implement the following five functions:
10 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
11 | -- <set_input>: unpack data from dataset and apply preprocessing.
12 | -- <forward>: produce intermediate results.
13 | -- <optimize_parameters>: calculate losses, gradients, and update network weights.
14 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
15 | """
16 |
17 | def __init__(self, opt):
18 | """Initialize the BaseModel class.
19 |
20 | Parameters:
21 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
22 |
23 | When creating your custom class, you need to implement your own initialization.
24 | In this function, you should first call <BaseModel.__init__(self, opt)>.
25 | Then, you need to define four lists:
26 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
27 | -- self.model_names (str list): define networks used in our training.
28 | -- self.visual_names (str list): specify the images that you want to display and save.
29 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
30 | """
31 | self.opt = opt
32 | self.gpu_ids = opt.gpu_ids
33 | self.isTrain = opt.isTrain
34 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
35 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
36 |
37 | self.loss_names = []
38 | self.model_names = []
39 | self.visual_names = []
40 | self.optimizers = []
41 | self.image_paths = []
42 | self.metric = 0 # used for learning rate policy 'plateau'
43 |
44 | @staticmethod
45 | def modify_commandline_options(parser, is_train):
46 | """Add new model-specific options, and rewrite default values for existing options.
47 |
48 | Parameters:
49 | parser -- original option parser
50 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
51 |
52 | Returns:
53 | the modified parser.
54 | """
55 | return parser
56 |
57 | @abstractmethod
58 | def set_input(self, input):
59 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
60 |
61 | Parameters:
62 | input (dict): includes the data itself and its metadata information.
63 | """
64 | pass
65 |
66 | @abstractmethod
67 | def forward(self):
68 | """Run forward pass; called by both functions and ."""
69 | pass
70 |
71 | @abstractmethod
72 | def optimize_parameters(self):
73 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
74 | pass
75 |
76 | def setup(self, opt):
77 | """Load and print networks; create schedulers
78 |
79 | Parameters:
80 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
81 | """
82 | if self.isTrain:
83 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
84 | if not self.isTrain or opt.continue_train:
85 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
86 | self.load_networks(load_suffix)
87 | self.print_networks(opt.verbose)
88 |
89 | def eval(self):
90 | """Make models eval mode during test time"""
91 | for name in self.model_names:
92 | if isinstance(name, str):
93 | net = getattr(self, 'net' + name)
94 | net.eval()
95 |
96 | def accumulate(self, model1, model2, decay=0.999):
97 | par1 = model1.state_dict()
98 | par2 = model2.state_dict()
99 |
100 | for k in par1.keys():
101 | par1[k].data = par1[k].data * decay + (1 - decay) * par2[k].data
102 | model1.load_state_dict(par1)
103 |
104 | def test(self):
105 | """Forward function used in test time.
106 |
107 | This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop.
108 | It also calls <compute_visuals> to produce additional visualization results.
109 | """
110 | with torch.no_grad():
111 | self.forward()
112 | self.compute_visuals()
113 |
114 | def compute_visuals(self):
115 | """Calculate additional output images for visdom and HTML visualization"""
116 | pass
117 |
118 | def get_image_paths(self):
119 | """ Return image paths that are used to load current data"""
120 | return self.image_paths
121 |
122 | def update_learning_rate(self):
123 | """Update learning rates for all the networks; called at the end of every epoch"""
124 | for scheduler in self.schedulers:
125 | if self.opt.lr_policy == 'plateau':
126 | scheduler.step(self.metric)
127 | else:
128 | scheduler.step()
129 |
130 | lr = self.optimizers[0].param_groups[0]['lr']
131 | print('learning rate = %.7f' % lr)
132 |
133 | def get_lr(self,):
134 | lrs = {}
135 | for idx, p in enumerate(self.optimizers):
136 | lrs['LR{}'.format(idx)] = p.param_groups[0]['lr']
137 | return lrs
138 |
139 | def get_current_visuals(self):
140 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
141 | visual_ret = OrderedDict()
142 | for name in self.visual_names:
143 | if isinstance(name, str):
144 | visual_ret[name] = getattr(self, name)
145 | return visual_ret
146 |
147 | def get_current_losses(self):
148 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
149 | errors_ret = OrderedDict()
150 | for name in self.loss_names:
151 | if isinstance(name, str):
152 | errors_ret['Loss_' + name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
153 | return errors_ret
154 |
155 | def save_networks(self, epoch, info=None):
156 | """Save all the networks to the disk.
157 |
158 | Parameters:
159 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
160 | """
161 | for name in self.model_names:
162 | if isinstance(name, str):
163 | save_filename = '%s_net_%s.pth' % (epoch, name)
164 | save_path = os.path.join(self.save_dir, save_filename)
165 | net = getattr(self, 'net' + name)
166 |
167 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
168 | torch.save(net.module.cpu().state_dict(), save_path)
169 | net.cuda(self.gpu_ids[0])
170 | else:
171 | torch.save(net.cpu().state_dict(), save_path)
172 | opts = []
173 | for opt in self.optimizers:
174 | opts.append(opt.state_dict())
175 | torch.save(opts, os.path.join(self.save_dir, '%s_opts.pth' % epoch))
176 |
177 | if info is not None:
178 | torch.save(info, os.path.join(self.save_dir, '%s.info' % epoch))
179 |
180 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
181 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
182 | key = keys[i]
183 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
184 | if module.__class__.__name__.startswith('InstanceNorm') and \
185 | (key == 'running_mean' or key == 'running_var'):
186 | if getattr(module, key) is None:
187 | state_dict.pop('.'.join(keys))
188 | if module.__class__.__name__.startswith('InstanceNorm') and \
189 | (key == 'num_batches_tracked'):
190 | state_dict.pop('.'.join(keys))
191 | else:
192 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
193 |
194 | def load_networks(self, epoch):
195 | """Load all the networks from the disk.
196 |
197 | Parameters:
198 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
199 | """
200 | # Load optimizers
201 | saved_opts = torch.load(os.path.join(self.save_dir, '%s_opts.pth' % epoch))
202 | for sopt, opt in zip(saved_opts, self.optimizers):
203 | opt.load_state_dict(sopt)
204 |
205 | # Load model weights
206 | for name in self.load_model_names:
207 | if isinstance(name, str):
208 | load_filename = '%s_net_%s.pth' % (epoch, name)
209 | load_path = os.path.join(self.save_dir, load_filename)
210 | net = getattr(self, 'net' + name)
211 | if isinstance(net, torch.nn.DataParallel):
212 | net = net.module
213 | print('loading the model from %s' % load_path)
214 | # if you are using PyTorch newer than 0.4 (e.g., built from
215 | # GitHub source), you can remove str() on self.device
216 | state_dict = torch.load(load_path, map_location=str(self.device))
217 | # if hasattr(state_dict, '_metadata'):
218 | # del state_dict._metadata
219 |
220 | # patch InstanceNorm checkpoints prior to 0.4
221 | # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
222 | # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
223 | # net.load_state_dict(state_dict)
224 | # Load partial weights
225 | model_dict = net.state_dict()
226 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
227 | model_dict.update(pretrained_dict)
228 | net.load_state_dict(model_dict, strict=False)
229 |
230 | info_path = os.path.join(self.save_dir, '%s.info' % epoch)
231 | if os.path.exists(info_path):
232 | info_dict = torch.load(info_path)
233 | for k, v in info_dict.items():
234 | setattr(self.opt, k, v)
235 |
236 | def print_networks(self, verbose):
237 | """Print the total number of parameters in the network and (if verbose) network architecture
238 |
239 | Parameters:
240 | verbose (bool) -- if verbose: print the network architecture
241 | """
242 | print('---------- Networks initialized -------------')
243 | for name in self.model_names:
244 | if isinstance(name, str):
245 | net = getattr(self, 'net' + name)
246 | num_params = 0
247 | for param in net.parameters():
248 | num_params += param.numel()
249 | if verbose:
250 | print(net)
251 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
252 | print('-----------------------------------------------')
253 |
254 | def set_requires_grad(self, nets, requires_grad=False):
255 |         """Set requires_grad=False for all the networks to avoid unnecessary computations
256 | Parameters:
257 | nets (network list) -- a list of networks
258 | requires_grad (bool) -- whether the networks require gradients or not
259 | """
260 | if not isinstance(nets, list):
261 | nets = [nets]
262 | for net in nets:
263 | if net is not None:
264 | for param in net.parameters():
265 | param.requires_grad = requires_grad
266 |
--------------------------------------------------------------------------------
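The `accumulate` method of `BaseModel` above keeps an exponential moving average (EMA) of one network's weights in another. A minimal standalone sketch of the same update rule, applied to parameters only (the `ema_net`/`net` names are illustrative, not part of the repository):

import torch
import torch.nn as nn

def ema_update(ema_net: nn.Module, net: nn.Module, decay: float = 0.999) -> None:
    # ema_param <- decay * ema_param + (1 - decay) * param, as in BaseModel.accumulate
    with torch.no_grad():
        for ema_p, p in zip(ema_net.parameters(), net.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1 - decay)

ema_net, net = nn.Linear(4, 4), nn.Linear(4, 4)
ema_net.load_state_dict(net.state_dict())  # start both copies from the same weights
ema_update(ema_net, net)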
/models/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.parameter import Parameter
4 | from torch.nn import functional as F
5 | import numpy as np
6 | from IPython import embed
7 |
8 | class NormLayer(nn.Module):
9 | """Normalization Layers.
10 | ------------
11 | # Arguments
12 | - channels: input channels, for batch norm and instance norm.
13 |     - normalize_shape: input shape without batch size, for layer norm.
14 | """
15 | def __init__(self, channels, normalize_shape=None, norm_type='bn'):
16 | super(NormLayer, self).__init__()
17 | norm_type = norm_type.lower()
18 | if norm_type == 'bn':
19 | self.norm = nn.BatchNorm2d(channels)
20 | elif norm_type == 'in':
21 | self.norm = nn.InstanceNorm2d(channels, affine=True)
22 | elif norm_type == 'gn':
23 | self.norm = nn.GroupNorm(32, channels, affine=True)
24 | elif norm_type == 'pixel':
25 | self.norm = lambda x: F.normalize(x, p=2, dim=1)
26 | elif norm_type == 'layer':
27 | self.norm = nn.LayerNorm(normalize_shape)
28 | elif norm_type == 'none':
29 | self.norm = lambda x: x
30 | else:
31 |             raise NotImplementedError('Norm type {} not supported.'.format(norm_type))
32 |
33 | def forward(self, x):
34 | return self.norm(x)
35 |
36 |
37 | class ReluLayer(nn.Module):
38 | """Relu Layer.
39 | ------------
40 | # Arguments
41 | - relu type: type of relu layer, candidates are
42 | - ReLU
43 | - LeakyReLU: default relu slope 0.2
44 | - PRelu
45 | - SELU
46 | - none: direct pass
47 | """
48 | def __init__(self, channels, relu_type='relu'):
49 | super(ReluLayer, self).__init__()
50 | relu_type = relu_type.lower()
51 | if relu_type == 'relu':
52 | self.func = nn.ReLU(True)
53 | elif relu_type == 'leakyrelu':
54 | self.func = nn.LeakyReLU(0.2, inplace=True)
55 | elif relu_type == 'prelu':
56 | self.func = nn.PReLU(channels)
57 | elif relu_type == 'selu':
58 | self.func = nn.SELU(True)
59 | elif relu_type == 'none':
60 | self.func = lambda x: x
61 | else:
62 |             raise NotImplementedError('Relu type {} not supported.'.format(relu_type))
63 |
64 | def forward(self, x):
65 | return self.func(x)
66 |
67 |
68 | class ConvLayer(nn.Module):
69 | def __init__(self, in_channels, out_channels, kernel_size=3, scale='none', norm_type='none', relu_type='none', use_pad=True):
70 | super(ConvLayer, self).__init__()
71 | self.use_pad = use_pad
72 |
73 | bias = True if norm_type in ['pixel', 'none'] else False
74 | stride = 2 if scale == 'down' else 1
75 |
76 | self.scale_func = lambda x: x
77 | if scale == 'up':
78 | self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
79 |
80 | self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
81 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
82 |
83 | self.relu = ReluLayer(out_channels, relu_type)
84 | self.norm = NormLayer(out_channels, norm_type=norm_type)
85 |
86 | def forward(self, x):
87 | out = self.scale_func(x)
88 | if self.use_pad:
89 | out = self.reflection_pad(out)
90 | out = self.conv2d(out)
91 | out = self.norm(out)
92 | out = self.relu(out)
93 | return out
94 |
95 |
96 | class ResidualBlock(nn.Module):
97 | """
98 | Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
99 | ------------------
100 | # Args
101 | - hg_depth: depth of HourGlassBlock. 0: don't use attention map.
102 |     - use_pmask: whether to use the previous mask as the HourGlassBlock input.
103 | """
104 | def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none', hg_depth=2, att_name='spar'):
105 | super(ResidualBlock, self).__init__()
106 | self.c_in = c_in
107 | self.c_out = c_out
108 | self.norm_type = norm_type
109 | self.relu_type = relu_type
110 | self.hg_depth = hg_depth
111 |
112 | kwargs = {'norm_type': norm_type, 'relu_type': relu_type}
113 |
114 | if scale == 'none' and c_in == c_out:
115 | self.shortcut_func = lambda x: x
116 | else:
117 | self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
118 |
119 | self.preact_func = nn.Sequential(
120 | NormLayer(c_in, norm_type=self.norm_type),
121 | ReluLayer(c_in, self.relu_type),
122 | )
123 |
124 | if scale == 'down':
125 | scales = ['none', 'down']
126 | elif scale == 'up':
127 | scales = ['up', 'none']
128 | elif scale == 'none':
129 | scales = ['none', 'none']
130 |
131 | self.conv1 = ConvLayer(c_in, c_out, 3, scales[0], **kwargs)
132 | self.conv2 = ConvLayer(c_out, c_out, 3, scales[1], norm_type=norm_type, relu_type='none')
133 |
134 | if att_name.lower() == 'spar':
135 | c_attn = 1
136 | elif att_name.lower() == 'spar3d':
137 | c_attn = c_out
138 | else:
139 | raise Exception("Attention type {} not implemented".format(att_name))
140 |
141 | self.att_func = HourGlassBlock(self.hg_depth, c_out, c_attn, **kwargs)
142 |
143 | def forward(self, x):
144 | identity = self.shortcut_func(x)
145 | out = self.preact_func(x)
146 | out = self.conv1(out)
147 | out = self.conv2(out)
148 | out = identity + self.att_func(out)
149 | return out
150 |
151 |
152 | class HourGlassBlock(nn.Module):
153 | """Simplified HourGlass block.
154 | Reference: https://github.com/1adrianb/face-alignment
155 | --------------------------
156 | """
157 | def __init__(self, depth, c_in, c_out,
158 | c_mid=64,
159 | norm_type='bn',
160 | relu_type='prelu',
161 | ):
162 | super(HourGlassBlock, self).__init__()
163 | self.depth = depth
164 | self.c_in = c_in
165 | self.c_mid = c_mid
166 | self.c_out = c_out
167 | self.kwargs = {'norm_type': norm_type, 'relu_type': relu_type}
168 |
169 | if self.depth:
170 | self._generate_network(self.depth)
171 | self.out_block = nn.Sequential(
172 | ConvLayer(self.c_mid, self.c_out, norm_type='none', relu_type='none'),
173 | nn.Sigmoid()
174 | )
175 |
176 | def _generate_network(self, level):
177 | if level == self.depth:
178 | c1, c2 = self.c_in, self.c_mid
179 | else:
180 | c1, c2 = self.c_mid, self.c_mid
181 |
182 | self.add_module('b1_' + str(level), ConvLayer(c1, c2, **self.kwargs))
183 | self.add_module('b2_' + str(level), ConvLayer(c1, c2, scale='down', **self.kwargs))
184 | if level > 1:
185 | self._generate_network(level - 1)
186 | else:
187 | self.add_module('b2_plus_' + str(level), ConvLayer(self.c_mid, self.c_mid, **self.kwargs))
188 |
189 | self.add_module('b3_' + str(level), ConvLayer(self.c_mid, self.c_mid, scale='up', **self.kwargs))
190 |
191 | def _forward(self, level, in_x):
192 | up1 = self._modules['b1_' + str(level)](in_x)
193 | low1 = self._modules['b2_' + str(level)](in_x)
194 | if level > 1:
195 | low2 = self._forward(level - 1, low1)
196 | else:
197 | low2 = self._modules['b2_plus_' + str(level)](low1)
198 |
199 | up2 = self._modules['b3_' + str(level)](low2)
200 | if up1.shape[2:] != up2.shape[2:]:
201 | up2 = nn.functional.interpolate(up2, up1.shape[2:])
202 |
203 | return up1 + up2
204 |
205 | def forward(self, x, pmask=None):
206 | if self.depth == 0: return x
207 | input_x = x # (1,64,64,64)
208 | x = self._forward(self.depth, x) # (1,64,64,64)
209 | #embed()
210 | self.att_map = self.out_block(x) #(1,1,64,64)
211 | #embed()
212 | x = input_x * self.att_map #(1,64,64,64)
213 | return x
214 |
215 |
216 |
217 |
--------------------------------------------------------------------------------
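A minimal sketch of how the blocks above fit together, assuming the repository root is on PYTHONPATH: a `ResidualBlock` with `att_name='spar'` attaches a depth-2 `HourGlassBlock` that predicts a single-channel attention map and rescales the block output before the residual addition.

import torch
from models.blocks import ResidualBlock

feat = torch.randn(1, 64, 64, 64)  # a 64-channel feature map

block = ResidualBlock(c_in=64, c_out=64, relu_type='prelu', norm_type='bn',
                      scale='none', hg_depth=2, att_name='spar')
out = block(feat)               # shape preserved: (1, 64, 64, 64)
att = block.att_func.att_map    # hourglass attention map: (1, 1, 64, 64)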
/models/common.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
8 | return nn.Conv2d(
9 | in_channels, out_channels, kernel_size,
10 | padding=(kernel_size//2), bias=bias)
11 |
12 |
13 | class MeanShift(nn.Conv2d):
14 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
15 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
16 | std = torch.Tensor(rgb_std)
17 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
18 | self.weight.data.div_(std.view(3, 1, 1, 1))
19 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
20 | self.bias.data.div_(std)
21 |         for p in self.parameters(): p.requires_grad = False  # freeze the fixed mean-shift conv
22 |
23 |
24 | class Upsampler(nn.Sequential):
25 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
26 | m = []
27 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
28 | for _ in range(int(math.log(scale, 2))):
29 | m.append(conv(n_feats, 4 * n_feats, 3, bias))
30 | m.append(nn.PixelShuffle(2))
31 | if bn: m.append(nn.BatchNorm2d(n_feats))
32 |
33 | if act == 'relu':
34 | m.append(nn.ReLU(True))
35 | elif act == 'prelu':
36 | m.append(nn.PReLU(n_feats))
37 |
38 | elif scale == 3:
39 | m.append(conv(n_feats, 9 * n_feats, 3, bias))
40 | m.append(nn.PixelShuffle(3))
41 | if bn: m.append(nn.BatchNorm2d(n_feats))
42 |
43 | if act == 'relu':
44 | m.append(nn.ReLU(True))
45 | elif act == 'prelu':
46 | m.append(nn.PReLU(n_feats))
47 | else:
48 | raise NotImplementedError
49 |
50 | super(Upsampler, self).__init__(*m)
51 |
52 |
53 | class DownBlock(nn.Module):
54 | def __init__(self, scale, nFeat=None, in_channels=None, out_channels=None):
55 | super(DownBlock, self).__init__()
56 |
57 | if nFeat is None:
58 | nFeat = 20
59 |
60 | if in_channels is None:
61 | in_channels = 3
62 |
63 | if out_channels is None:
64 | out_channels = 3
65 |
66 |
67 | dual_block = [
68 | nn.Sequential(
69 | nn.Conv2d(in_channels, nFeat, kernel_size=3, stride=2, padding=1, bias=False),
70 | nn.ReLU(inplace=True)
71 | )
72 | ]
73 |
74 |         # This loop has no effect in this code, since scale always defaults to 2.
75 |         for _ in range(1, int(np.log2(scale))):  # when scale=2 the body never runs and dual_block is conv-ReLU-conv; when scale=4 it adds one more conv-ReLU stage
76 | dual_block.append(
77 | nn.Sequential(
78 | nn.Conv2d(nFeat, nFeat, kernel_size=3, stride=2, padding=1, bias=False),
79 | nn.ReLU(inplace=True)
80 | )
81 | )
82 |
83 | dual_block.append(nn.Conv2d(nFeat, out_channels, kernel_size=3, stride=1, padding=1, bias=False))
84 |
85 | self.dual_module = nn.Sequential(*dual_block)
86 |
87 | def forward(self, x):
88 | x = self.dual_module(x)
89 | return x
90 |
91 |
92 |
93 | class Scale(nn.Module):
94 |
95 | def __init__(self, init_value=1e-3):
96 | super().__init__()
97 | self.scale = nn.Parameter(torch.FloatTensor([init_value]))
98 |
99 | def forward(self, input):
100 | return input * self.scale
101 |
102 | ## Channel Attention (CA) Layer
103 | class CALayer(nn.Module):
104 | def __init__(self, channel, reduction=16):
105 | super(CALayer, self).__init__()
106 | # global average pooling: feature --> point
107 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
108 | # feature channel downscale and upscale --> channel weight
109 | self.conv_du = nn.Sequential(
110 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
111 | nn.ReLU(inplace=True),
112 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
113 | nn.Sigmoid()
114 | )
115 |
116 | def forward(self, x):
117 | y = self.avg_pool(x)
118 | y = self.conv_du(y)
119 | return x * y
120 |
121 |
122 | ## Residual Channel Attention Block (RCAB)
123 | class RCAB(nn.Module):
124 | def __init__(self, conv, n_feat, kernel_size, reduction=16, bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
125 | super(RCAB, self).__init__()
126 | modules_body = []
127 | for i in range(2):
128 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
129 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
130 | if i == 0: modules_body.append(act)
131 | modules_body.append(CALayer(n_feat, reduction))
132 | self.body = nn.Sequential(*modules_body)
133 | self.res_scale = res_scale
134 |
135 | def forward(self, x):
136 | res = self.body(x)
137 | res += x
138 | return res
139 |
140 | class EcaLayer(nn.Module):
141 |
142 | def __init__(self, channels, gamma=2, b=1):
143 | super(EcaLayer, self).__init__()
144 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
145 |
146 | t = int(abs((math.log(channels, 2) + b) / gamma))
147 | k_size = t if t % 2 else t + 1
148 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
149 |
150 | self.sigmoid = nn.Sigmoid()
151 |
152 | def forward(self, x):
153 | # feature descriptor on the global spatial information
154 | y = self.avg_pool(x)
155 |
156 | # Two different branches of ECA module
157 | y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
158 |
159 | # Multi-scale information fusion
160 | y = self.sigmoid(y)
161 |
162 | return x * y.expand_as(x)
163 |
164 | class MSRB(nn.Module):
165 | def __init__(self, conv, n_feat):
166 | super(MSRB, self).__init__()
167 |
168 | kernel_size_1 = 3
169 | kernel_size_2 = 5
170 | #self.ca1 = EcaLayer(n_feat)
171 | #self.ca2 = EcaLayer(n_feat*2)
172 | #self.ca3 = EcaLayer(n_feat * 4)
173 |
174 | self.ca1 = CALayer(channel=n_feat)
175 | self.ca2 = CALayer(channel=n_feat*2)
176 | self.ca3 = CALayer(channel=n_feat * 4)
177 |
178 | self.conv_3_1 = conv(n_feat, n_feat, kernel_size_1)
179 | self.conv_3_2 = conv(n_feat * 2, n_feat * 2, kernel_size_1)
180 | self.conv_5_1 = conv(n_feat, n_feat, kernel_size_2)
181 | self.conv_5_2 = conv(n_feat * 2, n_feat * 2, kernel_size_2)
182 | self.confusion = nn.Conv2d(n_feat * 4, n_feat, 1, padding=0, stride=1)
183 | self.confusion1 = nn.Conv2d(n_feat * 4, n_feat * 4, 1, padding=0, stride=1)
184 | self.relu = nn.ReLU(inplace=True)
185 |
186 | def forward(self, x):
187 | input_1 = x
188 | output_3_1 = self.conv_3_1(self.relu(self.conv_3_1(input_1)))
189 | output_5_1 = self.conv_5_1(self.relu(self.conv_5_1(input_1)))
190 |
191 | output_ca_31 = self.ca1(output_3_1)
192 | output_ca_51 = self.ca1(output_5_1)
193 |
194 | output_ca_31_1 = input_1+output_ca_31
195 | output_ca_51_1 = input_1+output_ca_51
196 |
197 | input_2 = torch.cat([output_ca_31_1, output_ca_51_1], 1)
198 |
199 | output_3_2 = self.conv_3_2(self.relu(self.conv_3_2(input_2)))
200 | output_5_2 = self.conv_5_2(self.relu(self.conv_5_2(input_2)))
201 |
202 | output_ca_32 = self.ca2(output_3_2)
203 | output_ca_52 = self.ca2(output_5_2)
204 |
205 | output_ca_32_2 = input_2 + output_ca_32
206 | output_ca_52_2 = input_2 + output_ca_52
207 |
208 |
209 | input_3 = torch.cat([output_ca_32_2, output_ca_52_2], 1)
210 |
211 | output = self.confusion1(self.relu(self.confusion1(input_3)))
212 | output_4 = self.ca3(output)
213 | output5 = input_3 + output_4
214 |
215 | output6 = self.confusion(output5)
216 | output6 += x
217 | return output6
218 |
219 | class RCAB_ECA(nn.Module):
220 | def __init__(self, conv, n_feat, kernel_size, reduction=16, bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
221 | super(RCAB_ECA, self).__init__()
222 | modules_body = []
223 | for i in range(2):
224 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
225 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
226 | if i == 0: modules_body.append(act)
227 | modules_body.append(EcaLayer(channels=n_feat))
228 | self.body = nn.Sequential(*modules_body)
229 | self.res_scale = res_scale
230 |
231 | def forward(self, x):
232 | res = self.body(x)
233 | res += x
234 | return res
--------------------------------------------------------------------------------
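A short shape sketch for the building blocks in common.py, assuming the repository root is on PYTHONPATH: `CALayer` rescales channels with squeeze-and-excitation style weights, and `Upsampler` stacks 3x3 conv + PixelShuffle stages for power-of-two (or x3) scales.

import torch
from models import common

feat = torch.randn(2, 64, 32, 32)

ca = common.CALayer(channel=64, reduction=16)
print(ca(feat).shape)   # torch.Size([2, 64, 32, 32]) -- same shape, channel-reweighted

up = common.Upsampler(common.default_conv, scale=4, n_feats=64)
print(up(feat).shape)   # torch.Size([2, 64, 128, 128]) -- two conv + PixelShuffle(2) stages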
/models/common_ESTR.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from torch.autograd import Variable
8 |
9 | def default_conv(in_channels, out_channels, kernel_size, bias=True, groups = 1):
10 | wn = lambda x:torch.nn.utils.weight_norm(x)
11 | return nn.Conv2d(
12 | in_channels, out_channels, kernel_size,
13 | padding=(kernel_size//2), bias=bias, groups = groups)
14 |
15 | class Scale(nn.Module):
16 |
17 | def __init__(self, init_value=1e-3):
18 | super().__init__()
19 | self.scale = nn.Parameter(torch.FloatTensor([init_value]))
20 |
21 | def forward(self, input):
22 | return input * self.scale
23 |
24 | class MeanShift(nn.Conv2d):
25 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
26 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
27 | std = torch.Tensor(rgb_std)
28 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
29 | self.weight.data.div_(std.view(3, 1, 1, 1))
30 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
31 | self.bias.data.div_(std)
32 |         for p in self.parameters(): p.requires_grad = False  # freeze the fixed mean-shift conv
33 |
34 | class BasicBlock(nn.Sequential):
35 | def __init__(
36 | self, in_channels, out_channels, kernel_size, stride=1, bias=False,
37 | bn=True, act=nn.ReLU(True)):
38 |
39 | m = [nn.Conv2d(
40 | in_channels, out_channels, kernel_size,
41 | padding=(kernel_size//2), stride=stride, bias=bias)
42 | ]
43 | if bn: m.append(nn.BatchNorm2d(out_channels))
44 | if act is not None: m.append(act)
45 | super(BasicBlock, self).__init__(*m)
46 |
47 | class ResBlock(nn.Module):
48 | def __init__(
49 | self, conv, n_feats, kernel_size,
50 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
51 |
52 | super(ResBlock, self).__init__()
53 | m = []
54 | for i in range(2):
55 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
56 | if bn: m.append(nn.BatchNorm2d(n_feats))
57 | if i == 0: m.append(act)
58 |
59 | self.body = nn.Sequential(*m)
60 | self.res_scale = res_scale
61 |
62 | def forward(self, x):
63 | res = self.body(x).mul(self.res_scale)
64 | res += x
65 |
66 | return res
67 |
68 | class LuConv(nn.Module):
69 | def __init__(
70 | self, conv, n_feats, kernel_size,
71 | bias=True, bn=False, act=nn.LeakyReLU(0.05), res_scale=1):
72 | super(LuConv, self).__init__()
73 | #self.scale1 = Scale(1)
74 | #self.scale2 = Scale(1)
75 | m = []
76 | for i in range(2):
77 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
78 | if bn: m.append(nn.BatchNorm2d(n_feats))
79 | if i == 0: m.append(act)
80 |
81 | self.body = nn.Sequential(*m)
82 | self.res_scale = res_scale
83 |
84 | def forward(self, x):
85 | res = self.body(x)
86 | return res
87 |
88 | class Upsampler(nn.Sequential):
89 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
90 |
91 | m = []
92 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
93 | for _ in range(int(math.log(scale, 2))):
94 | m.append(conv(n_feats, 4 * n_feats, 3, bias))
95 | m.append(nn.PixelShuffle(2))
96 | if bn: m.append(nn.BatchNorm2d(n_feats))
97 |
98 | if act == 'relu':
99 | m.append(nn.ReLU(True))
100 | elif act == 'prelu':
101 | m.append(nn.PReLU(n_feats))
102 |
103 | elif scale == 3:
104 | m.append(conv(n_feats, 9 * n_feats, 3, bias))
105 | m.append(nn.PixelShuffle(3))
106 | if bn: m.append(nn.BatchNorm2d(n_feats))
107 |
108 | if act == 'relu':
109 | m.append(nn.ReLU(True))
110 | elif act == 'prelu':
111 | m.append(nn.PReLU(n_feats))
112 | else:
113 | raise NotImplementedError
114 |
115 | super(Upsampler, self).__init__(*m)
116 |
--------------------------------------------------------------------------------
/models/ctcnet.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CTCNet/4406888c1f8d01a612b993334bce835899483a97/models/ctcnet.py
--------------------------------------------------------------------------------
/models/ctcnet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from models import loss
5 | from models import networks
6 | from .base_model import BaseModel
7 | from utils import utils
8 | from models.ctcnet import CTCNet
9 |
10 |
11 | class CTCNetModel(BaseModel):
12 |
13 | def modify_commandline_options(parser, is_train):
14 | parser.add_argument('--scale_factor', type=int, default=8, help='upscale factor for CTCNet')
15 | parser.add_argument('--lambda_pix', type=float, default=1.0, help='weight for pixel loss')
16 | return parser
17 |
18 | def __init__(self, opt):
19 | BaseModel.__init__(self, opt)
20 |
21 | self.netG = CTCNet()
22 | self.netG = networks.define_network(opt, self.netG)
23 |
24 | self.model_names = ['G']
25 | self.load_model_names = ['G']
26 | self.loss_names = ['Pix']
27 | self.visual_names = ['img_LR', 'img_SR', 'img_HR']
28 |
29 | if self.isTrain:
30 | self.criterionL1 = nn.L1Loss()
31 |
32 | self.optimizer_G = optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.99))
33 | self.optimizers = [self.optimizer_G]
34 |
35 | def load_pretrain_model(self,):
36 | print('Loading pretrained model', self.opt.pretrain_model_path)
37 | weight = torch.load(self.opt.pretrain_model_path)
38 | self.netG.module.load_state_dict(weight)
39 |
40 | def set_input(self, input, cur_iters=None):
41 | self.cur_iters = cur_iters
42 | self.img_LR = input['LR'].to(self.opt.data_device)
43 | self.img_HR = input['HR'].to(self.opt.data_device)
44 |
45 | def forward(self):
46 | self.img_SR = self.netG(self.img_LR)
47 |
48 | def backward_G(self):
49 | # Pix loss
50 | #self.loss_Pix1 = self.criterionL1(self.out1, self.img_HR) * self.opt.lambda_pix
51 | self.loss_Pix = self.criterionL1(self.img_SR, self.img_HR) * self.opt.lambda_pix
52 | #self.loss_Pix = self.loss_Pix + self.loss_Pix1
53 | self.loss_Pix.backward()
54 |
55 | def optimize_parameters(self, ):
56 | # ---- Update G ------------
57 | self.optimizer_G.zero_grad()
58 | self.backward_G()
59 | self.optimizer_G.step()
60 |
61 | def get_current_visuals(self, size=128):
62 | out = []
63 | out.append(utils.tensor_to_numpy(self.img_LR))
64 | out.append(utils.tensor_to_numpy(self.img_SR))
65 | out.append(utils.tensor_to_numpy(self.img_HR))
66 | visual_imgs = [utils.batch_numpy_to_image(x, size) for x in out]
67 |
68 | return visual_imgs
69 |
70 |
71 |
72 |
73 |
74 |
--------------------------------------------------------------------------------
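CTCNetModel reduces training to a single L1 pixel loss between the super-resolved and ground-truth images. A rough sketch of the same step outside the framework; the 16x16 -> 128x128 sizes assume the default `--scale_factor 8` and are illustrative only, since ctcnet.py itself is not reproduced in this listing:

import torch
import torch.nn as nn
import torch.optim as optim
from models.ctcnet import CTCNet

net = CTCNet()                         # built with its defaults, as in CTCNetModel.__init__
criterion = nn.L1Loss()
optimizer = optim.Adam(net.parameters(), lr=2e-4, betas=(0.9, 0.99))

img_LR = torch.randn(1, 3, 16, 16)     # low-resolution face (assumed size)
img_HR = torch.randn(1, 3, 128, 128)   # high-resolution target (x8, assumed size)

optimizer.zero_grad()
loss_pix = criterion(net(img_LR), img_HR)   # lambda_pix defaults to 1.0
loss_pix.backward()
optimizer.step()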
/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import models
3 | from utils import utils
4 | from torch import nn, autograd
5 | from torch.nn import functional as F
6 |
7 | class PCPFeat(torch.nn.Module):
8 | """
9 |     Feature extractor used to calculate the Perceptual Loss, based on VGG-19 or ResNet-50 features.
10 |     Input: (B, C, H, W), RGB, [-1, 1]
11 | """
12 | def __init__(self, weight_path, model='vgg'):
13 | super(PCPFeat, self).__init__()
14 | if model == 'vgg':
15 | self.model = models.vgg19(pretrained=False)
16 | self.build_vgg_layers()
17 | elif model == 'resnet':
18 | self.model = models.resnet50(pretrained=False)
19 | self.build_resnet_layers()
20 |
21 | self.model.load_state_dict(torch.load(weight_path))
22 | self.model.eval()
23 | for param in self.model.parameters():
24 | param.requires_grad = False
25 |
26 | def build_resnet_layers(self):
27 | self.layer1 = torch.nn.Sequential(
28 | self.model.conv1,
29 | self.model.bn1,
30 | self.model.relu,
31 | self.model.maxpool,
32 | self.model.layer1
33 | )
34 | self.layer2 = self.model.layer2
35 | self.layer3 = self.model.layer3
36 | self.layer4 = self.model.layer4
37 | self.features = torch.nn.ModuleList(
38 | [self.layer1, self.layer2, self.layer3, self.layer4]
39 | )
40 |
41 | def build_vgg_layers(self):
42 | vgg_pretrained_features = self.model.features
43 | self.features = []
44 | feature_layers = [0, 3, 8, 17, 26, 35]
45 | for i in range(len(feature_layers)-1):
46 | module_layers = torch.nn.Sequential()
47 | for j in range(feature_layers[i], feature_layers[i+1]):
48 | module_layers.add_module(str(j), vgg_pretrained_features[j])
49 | self.features.append(module_layers)
50 | self.features = torch.nn.ModuleList(self.features)
51 |
52 | def preprocess(self, x):
53 | x = (x + 1) / 2
54 | mean = torch.Tensor([0.485, 0.456, 0.406]).to(x)
55 | std = torch.Tensor([0.229, 0.224, 0.225]).to(x)
56 | mean = mean.view(1, 3, 1, 1)
57 | std = std.view(1, 3, 1, 1)
58 | x = (x - mean) / std
59 | if x.shape[3] < 224:
60 | x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
61 | return x
62 |
63 | def forward(self, x):
64 | x = self.preprocess(x)
65 |
66 | features = []
67 | for m in self.features:
68 | x = m(x)
69 | features.append(x)
70 | return features
71 |
72 |
73 | class PCPLoss(torch.nn.Module):
74 | """Perceptual Loss.
75 | """
76 | def __init__(self,
77 | opt,
78 | layer=5,
79 | model='vgg',
80 | ):
81 | super(PCPLoss, self).__init__()
82 |
83 | self.crit = torch.nn.L1Loss()
84 | # self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
85 | self.weights = [1, 1, 1, 1, 1]
86 |
87 | def forward(self, x_feats, y_feats):
88 | loss = 0
89 | for xf, yf, w in zip(x_feats, y_feats, self.weights):
90 | loss = loss + self.crit(xf, yf.detach()) * w
91 | return loss
92 |
93 |
94 | class FMLoss(nn.Module):
95 | def __init__(self):
96 | super().__init__()
97 | self.crit = torch.nn.L1Loss()
98 |
99 | def forward(self, x_feats, y_feats):
100 | loss = 0
101 | for xf, yf in zip(x_feats, y_feats):
102 | loss = loss + self.crit(xf, yf.detach())
103 | return loss
104 |
105 |
106 | class GANLoss(nn.Module):
107 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
108 | """ Initialize the GANLoss class.
109 | Parameters:
110 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
111 | target_real_label (bool) - - label for a real image
112 | target_fake_label (bool) - - label of a fake image
113 | Note: Do not use sigmoid as the last layer of Discriminator.
114 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
115 | """
116 | super(GANLoss, self).__init__()
117 | self.register_buffer('real_label', torch.tensor(target_real_label))
118 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
119 | self.gan_mode = gan_mode
120 | if gan_mode == 'lsgan':
121 | self.loss = nn.MSELoss()
122 | elif gan_mode == 'vanilla':
123 | self.loss = nn.BCEWithLogitsLoss()
124 | elif gan_mode == 'hinge':
125 | pass
126 | elif gan_mode in ['wgangp']:
127 | self.loss = None
128 | elif gan_mode in ['softwgan']:
129 | self.loss = None
130 | else:
131 | raise NotImplementedError('gan mode %s not implemented' % gan_mode)
132 |
133 | def get_target_tensor(self, prediction, target_is_real):
134 | if target_is_real:
135 | target_tensor = self.real_label
136 | else:
137 | target_tensor = self.fake_label
138 | return target_tensor.expand_as(prediction)
139 |
140 | def __call__(self, prediction, target_is_real, for_discriminator=True):
141 |         """Calculate loss given Discriminator's output and ground truth labels.
142 |         Parameters:
143 |             prediction (tensor) - - typically the prediction output from a discriminator
144 | target_is_real (bool) - - if the ground truth label is for real images or fake images
145 | Returns:
146 | the calculated loss.
147 | """
148 | if self.gan_mode in ['lsgan', 'vanilla']:
149 | target_tensor = self.get_target_tensor(prediction, target_is_real)
150 | loss = self.loss(prediction, target_tensor)
151 | elif self.gan_mode == 'hinge':
152 | if for_discriminator:
153 | if target_is_real:
154 | loss = nn.ReLU()(1 - prediction).mean()
155 | else:
156 | loss = nn.ReLU()(1 + prediction).mean()
157 | else:
158 | assert target_is_real, "The generator's hinge loss must be aiming for real"
159 | loss = - prediction.mean()
160 | return loss
161 |
162 | elif self.gan_mode == 'wgangp':
163 | if target_is_real:
164 | loss = -prediction.mean()
165 | else:
166 | loss = prediction.mean()
167 | elif self.gan_mode == 'softwgan':
168 | if target_is_real:
169 | loss = F.softplus(-prediction).mean()
170 | else:
171 | loss = F.softplus(prediction).mean()
172 | return loss
173 |
174 |
175 |
176 |
--------------------------------------------------------------------------------
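A minimal sketch of how `GANLoss` is meant to be called on a discriminator's raw (pre-sigmoid) scores, here with the hinge objective:

import torch
from models.loss import GANLoss

crit = GANLoss('hinge')
real_score = torch.randn(4, 1, 8, 8)   # D output on real images
fake_score = torch.randn(4, 1, 8, 8)   # D output on generated images

loss_D = crit(real_score, True) + crit(fake_score, False)   # discriminator update
loss_G = crit(fake_score, True, for_discriminator=False)    # generator update: -D(fake).mean()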
/models/networks.py:
--------------------------------------------------------------------------------
1 | from models.blocks import *
2 | import torch
3 | from torch import nn
4 | from torch.nn import init
5 | from torch.optim import lr_scheduler
6 | import torch.nn.utils as tutils
7 |
8 |
9 | def apply_norm(net, weight_norm_type):
10 | for m in net.modules():
11 | if isinstance(m, nn.Conv2d):
12 | if weight_norm_type.lower() == 'spectral_norm':
13 | tutils.spectral_norm(m)
14 | elif weight_norm_type.lower() == 'weight_norm':
15 | tutils.weight_norm(m)
16 | else:
17 | pass
18 |
19 |
20 | def init_weights(net, init_type='normal', init_gain=0.02):
21 | """Initialize network weights.
22 | Parameters:
23 | net (network) -- network to be initialized
24 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
25 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
26 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
27 | work better for some applications. Feel free to try yourself.
28 | """
29 | def init_func(m): # define the initialization function
30 | classname = m.__class__.__name__
31 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
32 | if init_type == 'normal':
33 | init.normal_(m.weight.data, 0.0, init_gain)
34 | elif init_type == 'xavier':
35 | init.xavier_normal_(m.weight.data, gain=init_gain)
36 | elif init_type == 'kaiming':
37 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
38 | elif init_type == 'orthogonal':
39 | init.orthogonal_(m.weight.data, gain=init_gain)
40 | else:
41 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
42 | if hasattr(m, 'bias') and m.bias is not None:
43 | init.constant_(m.bias.data, 0.0)
44 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
45 | init.normal_(m.weight.data, 1.0, init_gain)
46 | init.constant_(m.bias.data, 0.0)
47 | if isinstance(net, nn.DataParallel):
48 | network_name = net.module.__class__.__name__
49 | else:
50 | network_name = net.__class__.__name__
51 |
52 | print('initialize network %s with %s' % (network_name, init_type))
53 | net.apply(init_func) # apply the initialization function
54 |
55 |
56 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
57 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
58 | Parameters:
59 | net (network) -- the network to be initialized
60 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
61 | gain (float) -- scaling factor for normal, xavier and orthogonal.
62 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
63 | Return an initialized network.
64 | """
65 | if len(gpu_ids) > 0:
66 | assert(torch.cuda.is_available())
67 | net.to(gpu_ids[0])
68 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
69 | init_weights(net, init_type, init_gain=init_gain)
70 | return net
71 |
72 |
73 | def get_scheduler(optimizer, opt):
74 | """Return a learning rate scheduler
75 | Parameters:
76 | optimizer -- the optimizer of the network
77 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
78 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
79 | For 'linear', we keep the same learning rate for the first epochs
80 | and linearly decay the rate to zero over the next epochs.
81 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
82 | See https://pytorch.org/docs/stable/optim.html for more details.
83 | """
84 | if opt.lr_policy == 'linear':
85 | def lambda_rule(epoch):
86 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
87 | return lr_l
88 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
89 | elif opt.lr_policy == 'step':
90 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
91 | elif opt.lr_policy == 'plateau':
92 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
93 | elif opt.lr_policy == 'cosine':
94 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
95 | else:
96 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
97 | return scheduler
98 |
99 |
100 | def define_network(opt, net, isTrain=True, use_norm='none', init_network=True):
101 | apply_norm(net, use_norm)
102 | if not isTrain:
103 | net.eval()
104 | if len(opt.gpu_ids) > 0:
105 | assert(torch.cuda.is_available())
106 | net.to(opt.device)
107 | net = torch.nn.DataParallel(net, opt.gpu_ids, output_device=opt.data_device)
108 | if init_network:
109 | init_weights(net, init_type='normal', init_gain=0.02)
110 | return net
111 |
112 |
113 | class MultiScaleDiscriminator(nn.Module):
114 | def __init__(self, input_ch, base_ch=64, n_layers=4, norm_type='none', relu_type='LeakyReLU', num_D=4):
115 | super().__init__()
116 |
117 | self.D_pool = nn.ModuleList()
118 | for i in range(num_D):
119 | netD = NLayerDiscriminator(input_ch, base_ch, depth=n_layers, norm_type=norm_type, relu_type=relu_type)
120 | self.D_pool.append(netD)
121 |
122 | def downsample(x):
123 | return nn.functional.interpolate(x, scale_factor=0.5, mode='bicubic', align_corners=False)
124 | self.downsample = downsample
125 |
126 | def forward(self, input, return_feat=False):
127 | results = []
128 | for netd in self.D_pool:
129 | output = netd(input, return_feat)
130 | results.append(output)
131 | input = self.downsample(input)
132 | return results
133 |
134 |
135 | class NLayerDiscriminator(nn.Module):
136 | def __init__(self,
137 | input_ch = 3,
138 | base_ch = 64,
139 | max_ch = 512,
140 | depth = 4,
141 | norm_type = 'none',
142 | relu_type = 'LeakyReLU',
143 | ):
144 | super().__init__()
145 |
146 | nargs = {'norm_type': norm_type, 'relu_type': relu_type}
147 |
148 | self.model = []
149 | self.model.append(ConvLayer(input_ch, base_ch, norm_type='none', relu_type=relu_type))
150 | for i in range(depth):
151 | cin = min(base_ch * 2**(i), max_ch)
152 | cout = min(base_ch * 2**(i+1), max_ch)
153 | self.model.append(ConvLayer(cin, cout, scale='down', **nargs))
154 | self.model = nn.Sequential(*self.model)
155 | self.score_out = ConvLayer(cout, 1, use_pad=False)
156 |
157 | def forward(self, x, return_feat=False):
158 | ret_feats = []
159 | for idx, m in enumerate(self.model):
160 | x = m(x)
161 | ret_feats.append(x)
162 | x = self.score_out(x)
163 | if return_feat:
164 | return x, ret_feats
165 | else:
166 | return x
167 |
168 |
--------------------------------------------------------------------------------
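A shape sketch for the discriminators defined above: `NLayerDiscriminator` halves the resolution `depth` times before a 1-channel score head, and `MultiScaleDiscriminator` runs one such discriminator per bicubically downsampled scale.

import torch
from models.networks import NLayerDiscriminator, MultiScaleDiscriminator

x = torch.randn(1, 3, 128, 128)

netD = NLayerDiscriminator(input_ch=3, base_ch=64, depth=4)
score, feats = netD(x, return_feat=True)   # patch scores + intermediate features (e.g. for FMLoss)

msD = MultiScaleDiscriminator(input_ch=3, num_D=3)
scores = msD(x)                            # a list with one score map per scale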
/models/rlutrans.py:
--------------------------------------------------------------------------------
1 |
2 | import common_ESTR as common
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from utils.tools import extract_image_patches, reduce_mean, reduce_sum, same_padding, reverse_patches
7 | import pdb
8 | import math
9 |
10 | def drop_path(x, drop_prob: float = 0., training: bool = False):
11 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
16 | 'survival rate' as the argument.
17 | """
18 | if drop_prob == 0. or not training:
19 | return x
20 | keep_prob = 1 - drop_prob
21 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
22 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
23 | random_tensor.floor_() # binarize
24 | output = x.div(keep_prob) * random_tensor
25 | return output
26 |
27 | # MLP in the paper
28 | class Mlp(nn.Module):
29 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.):
30 | super().__init__()
31 | out_features = out_features or in_features
32 | hidden_features = hidden_features or in_features//4
33 | self.fc1 = nn.Linear(in_features, hidden_features)
34 | self.act = act_layer()
35 | self.fc2 = nn.Linear(hidden_features, out_features)
36 | self.drop = nn.Dropout(drop)
37 |
38 | def forward(self, x):
39 | x = self.fc1(x)
40 | x = self.act(x)
41 | x = self.drop(x)
42 | x = self.fc2(x)
43 | x = self.drop(x)
44 | return x
45 |
46 |
47 | # Efficient Multi-Head Attention in the paper
48 | class EffAttention(nn.Module):
49 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
50 | super().__init__()
51 | self.num_heads = num_heads
52 | head_dim = dim // num_heads
53 | self.scale = qk_scale or head_dim ** -0.5
54 |
55 | self.reduce = nn.Linear(dim, dim//2, bias=qkv_bias)
56 | self.qkv = nn.Linear(dim//2, dim//2 * 3, bias=qkv_bias)
57 | self.proj = nn.Linear(dim//2, dim)
58 | self.attn_drop = nn.Dropout(attn_drop)
59 |
60 | def forward(self, x):
61 | x = self.reduce(x)
62 | B, N, C = x.shape
63 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
64 | q, k, v = qkv[0], qkv[1], qkv[2]
65 |
66 | q_all = torch.split(q, math.ceil(N//4), dim=-2)
67 | k_all = torch.split(k, math.ceil(N//4), dim=-2)
68 | v_all = torch.split(v, math.ceil(N//4), dim=-2)
69 |
70 | output = []
71 | for q,k,v in zip(q_all, k_all, v_all):
72 | attn = (q @ k.transpose(-2, -1)) * self.scale #16*8*37*37
73 | attn = attn.softmax(dim=-1)
74 | attn = self.attn_drop(attn)
75 | trans_x = (attn @ v).transpose(1, 2) #.reshape(B, N, C)
76 | output.append(trans_x)
77 | x = torch.cat(output,dim=1)
78 | x = x.reshape(B,N,C)
79 | x = self.proj(x)
80 | return x
81 |
82 |
83 | ## Key Module: Efficient Transformer (ET) in the paper
84 | class TransBlock(nn.Module):
85 | def __init__(
86 | self, n_feat = 64,dim=768, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
87 | drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm):
88 | super(TransBlock, self).__init__()
89 | self.dim = dim
90 | self.atten = EffAttention(self.dim, num_heads=8, qkv_bias=False, qk_scale=None, \
91 | attn_drop=0., proj_drop=0.)
92 | self.norm1 = nn.LayerNorm(self.dim)
93 | self.mlp = Mlp(in_features=dim, hidden_features=dim//4, act_layer=act_layer, drop=drop)
94 | self.norm2 = nn.LayerNorm(self.dim)
95 |
96 | def forward(self, x):
97 | B = x.shape[0]
98 | x = extract_image_patches(x, ksizes=[3, 3],
99 | strides=[1,1],
100 | rates=[1, 1],
101 | padding='same') # 16*2304*576
102 | x = x.permute(0,2,1)
103 |
104 | x = x + self.atten(self.norm1(x))
105 | x = x + self.mlp(self.norm2(x))
106 | return x
107 |
--------------------------------------------------------------------------------
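`TransBlock` first unfolds the feature map into overlapping 3x3 patch tokens, so `dim` must equal channels*3*3. A sketch under that assumption, using `reverse_patches` from models/utils/tools.py to fold the tokens back, and assuming the module's `common_ESTR`/`tools` imports resolve on your PYTHONPATH:

import torch
from models.rlutrans import TransBlock
from models.utils.tools import reverse_patches

feat = torch.randn(1, 64, 32, 32)          # (B, C, H, W)

block = TransBlock(n_feat=64, dim=64 * 9)  # dim = C * 3 * 3 = 576
tokens = block(feat)                       # (1, 32*32, 576) patch tokens after attention + MLP

out = reverse_patches(tokens.permute(0, 2, 1), (32, 32), (3, 3), 1, 1)
print(out.shape)                           # torch.Size([1, 64, 32, 32]); overlaps are summed by nn.Fold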
/models/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import numpy as np
4 | from .utils import mkdirs
5 | from tensorboardX import SummaryWriter
6 | from datetime import datetime
7 | import socket
8 | import shutil
9 |
10 | class Logger():
11 | def __init__(self, opts):
12 | time_stamp = '_{}'.format(datetime.now().strftime('%Y-%m-%d_%H:%M'))
13 | self.opts = opts
14 | self.log_dir = os.path.join(opts.log_dir, opts.name+time_stamp)
15 | self.phase_keys = ['train', 'val', 'test']
16 | self.iter_log = []
17 | self.epoch_log = OrderedDict()
18 | self.set_mode(opts.phase)
19 |
20 |         # check whether a previous log with the same experiment name exists; if so, move it to the archive
21 | exist_log = None
22 | for log_name in os.listdir(opts.log_dir):
23 | if opts.name in log_name:
24 | exist_log = log_name
25 | if exist_log is not None:
26 | old_dir = os.path.join(opts.log_dir, exist_log)
27 | archive_dir = os.path.join(opts.log_archive, exist_log)
28 | shutil.move(old_dir, archive_dir)
29 |
30 | self.mk_log_file()
31 |
32 | self.writer = SummaryWriter(self.log_dir)
33 |
34 | def mk_log_file(self):
35 | mkdirs(self.log_dir)
36 | self.txt_files = OrderedDict()
37 | for i in self.phase_keys:
38 | self.txt_files[i] = os.path.join(self.log_dir, 'log_{}'.format(i))
39 |
40 | def set_mode(self, mode):
41 | self.mode = mode
42 | self.epoch_log[mode] = []
43 |
44 | def set_current_iter(self, cur_iter):
45 | self.cur_iter = cur_iter
46 |
47 | def record_losses(self, items):
48 | """
49 | iteration log: [iter][{key: value}]
50 | """
51 | self.iter_log.append(items)
52 | for k, v in items.items():
53 | if 'loss' in k.lower():
54 | self.writer.add_scalar('loss/{}'.format(k), v, self.cur_iter)
55 |
56 | def record_scalar(self, items):
57 | """
58 | Add scalar records. item, {key: value}
59 | """
60 | for i in items.keys():
61 | self.writer.add_scalar('{}'.format(i), items[i], self.cur_iter)
62 |
63 | def record_image(self, visual_img, tag='ckpt_image'):
64 | self.writer.add_image(tag, visual_img, self.cur_iter, dataformats='HWC')
65 |
66 | def record_images(self, visuals, nrow=6, tag='ckpt_image'):
67 | imgs = []
68 | nrow = min(nrow, visuals[0].shape[0])
69 | for i in range(nrow):
70 | tmp_imgs = [x[i] for x in visuals]
71 | imgs.append(np.hstack(tmp_imgs))
72 | imgs = np.vstack(imgs).astype(np.uint8)
73 | self.writer.add_image(tag, imgs, self.cur_iter, dataformats='HWC')
74 |
75 | def record_text(self, tag, text):
76 | self.writer.add_text(tag, text)
77 |
78 | def printIterSummary(self, epoch, cur_iters, total_it, timer):
79 | msg = '{}\nIter: [{}]{:03d}/{:03d}\t\t'.format(
80 | timer.to_string(total_it - cur_iters), epoch, cur_iters, total_it)
81 | for k, v in self.iter_log[-1].items():
82 | msg += '{}: {:.6f}\t'.format(k, v)
83 | print(msg + '\n')
84 | with open(self.txt_files[self.mode], 'a+') as f:
85 | f.write(msg + '\n')
86 |
87 | def close(self):
88 | self.writer.export_scalars_to_json(os.path.join(self.log_dir, 'all_scalars.json'))
89 | self.writer.close()
90 |
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
/models/utils/rlutrans.py:
--------------------------------------------------------------------------------
1 |
2 | import common_ESTR as common
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from utils.tools import extract_image_patches, reduce_mean, reduce_sum, same_padding, reverse_patches
7 | import pdb
8 | import math
9 |
10 | def drop_path(x, drop_prob: float = 0., training: bool = False):
11 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
16 | 'survival rate' as the argument.
17 | """
18 | if drop_prob == 0. or not training:
19 | return x
20 | keep_prob = 1 - drop_prob
21 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
22 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
23 | random_tensor.floor_() # binarize
24 | output = x.div(keep_prob) * random_tensor
25 | return output
26 |
27 | # MLP in the paper
28 | class Mlp(nn.Module):
29 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.):
30 | super().__init__()
31 | out_features = out_features or in_features
32 | hidden_features = hidden_features or in_features//4
33 | self.fc1 = nn.Linear(in_features, hidden_features)
34 | self.act = act_layer()
35 | self.fc2 = nn.Linear(hidden_features, out_features)
36 | self.drop = nn.Dropout(drop)
37 |
38 | def forward(self, x):
39 | x = self.fc1(x)
40 | x = self.act(x)
41 | x = self.drop(x)
42 | x = self.fc2(x)
43 | x = self.drop(x)
44 | return x
45 |
46 |
47 | # Efficient Multi-Head Attention in the paper
48 | class EffAttention(nn.Module):
49 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
50 | super().__init__()
51 | self.num_heads = num_heads
52 | head_dim = dim // num_heads
53 | self.scale = qk_scale or head_dim ** -0.5
54 |
55 | self.reduce = nn.Linear(dim, dim//2, bias=qkv_bias)
56 | self.qkv = nn.Linear(dim//2, dim//2 * 3, bias=qkv_bias)
57 | self.proj = nn.Linear(dim//2, dim)
58 | self.attn_drop = nn.Dropout(attn_drop)
59 |
60 | def forward(self, x):
61 | x = self.reduce(x)
62 | B, N, C = x.shape
63 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
64 | q, k, v = qkv[0], qkv[1], qkv[2]
65 |
66 | q_all = torch.split(q, math.ceil(N//4), dim=-2)
67 | k_all = torch.split(k, math.ceil(N//4), dim=-2)
68 | v_all = torch.split(v, math.ceil(N//4), dim=-2)
69 |
70 | output = []
71 | for q,k,v in zip(q_all, k_all, v_all):
72 | attn = (q @ k.transpose(-2, -1)) * self.scale #16*8*37*37
73 | attn = attn.softmax(dim=-1)
74 | attn = self.attn_drop(attn)
75 | trans_x = (attn @ v).transpose(1, 2) #.reshape(B, N, C)
76 | output.append(trans_x)
77 | x = torch.cat(output,dim=1)
78 | x = x.reshape(B,N,C)
79 | x = self.proj(x)
80 | return x
81 |
82 |
83 | ## Key Module: Efficient Transformer (ET) in the paper
84 | class TransBlock(nn.Module):
85 | def __init__(
86 | self, n_feat = 64,dim=768, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
87 | drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm):
88 | super(TransBlock, self).__init__()
89 | self.dim = dim
90 | self.atten = EffAttention(self.dim, num_heads=8, qkv_bias=False, qk_scale=None, \
91 | attn_drop=0., proj_drop=0.)
92 | self.norm1 = nn.LayerNorm(self.dim)
93 | self.mlp = Mlp(in_features=dim, hidden_features=dim//4, act_layer=act_layer, drop=drop)
94 | self.norm2 = nn.LayerNorm(self.dim)
95 |
96 | def forward(self, x):
97 | B = x.shape[0]
98 | x = extract_image_patches(x, ksizes=[3, 3],
99 | strides=[1,1],
100 | rates=[1, 1],
101 | padding='same') # 16*2304*576
102 | x = x.permute(0,2,1)
103 |
104 | x = x + self.atten(self.norm1(x))
105 | x = x + self.mlp(self.norm2(x))
106 | return x
107 |
--------------------------------------------------------------------------------
/models/utils/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import datetime
3 | from collections import OrderedDict
4 |
5 | class Timer():
6 | def __init__(self):
7 | self.reset_timer()
8 | self.start = time.time()
9 |
10 | def reset_timer(self):
11 | self.before = time.time()
12 | self.timer = OrderedDict()
13 |
14 | def update_time(self, key):
15 | self.timer[key] = time.time() - self.before
16 | self.before = time.time()
17 |
18 | def to_string(self, iters_left, short=False):
19 | iter_total = sum(self.timer.values())
20 | msg = "{:%Y-%m-%d %H:%M:%S}\tElapse: {}\tTimeLeft: {}\t".format(
21 | datetime.datetime.now(),
22 | datetime.timedelta(seconds=round(time.time() - self.start)),
23 | datetime.timedelta(seconds=round(iter_total*iters_left))
24 | )
25 | if short:
26 | msg += '{}: {:.2f}s'.format('|'.join(self.timer.keys()), iter_total)
27 | else:
28 | msg += '\tIterTotal: {:.2f}s\t{}: {} '.format(iter_total,
29 | '|'.join(self.timer.keys()), ' '.join('{:.2f}s'.format(x) for x in self.timer.values()))
30 | return msg
31 |
32 |
--------------------------------------------------------------------------------
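A small usage sketch for `Timer`: each `update_time` call records the time since the previous checkpoint under a named key, and `to_string` turns the per-iteration totals into the elapsed/remaining-time line printed by the logger.

import time
from models.utils.timer import Timer

timer = Timer()
time.sleep(0.05)
timer.update_time('DataTime')      # time spent loading the batch
time.sleep(0.10)
timer.update_time('TrainTime')     # time spent in forward/backward
print(timer.to_string(iters_left=1000, short=True))
timer.reset_timer()                # start timing the next iteration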
/models/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import torch.nn.functional as F
7 |
8 | def normalize(x):
9 | return x.mul_(2).add_(-1)
10 |
11 | def same_padding(images, ksizes, strides, rates):
12 | assert len(images.size()) == 4
13 | batch_size, channel, rows, cols = images.size()
14 | out_rows = (rows + strides[0] - 1) // strides[0]
15 | out_cols = (cols + strides[1] - 1) // strides[1]
16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1
17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1
18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
20 | # Pad the input
21 | padding_top = int(padding_rows / 2.)
22 | padding_left = int(padding_cols / 2.)
23 | padding_bottom = padding_rows - padding_top
24 | padding_right = padding_cols - padding_left
25 | paddings = (padding_left, padding_right, padding_top, padding_bottom)
26 | images = torch.nn.ZeroPad2d(paddings)(images)
27 | return images
28 |
29 |
30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'):
31 | """
32 | Extract patches from images and put them in the C output dimension.
33 |     :param padding: 'same' or 'valid'
34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
36 | each dimension of images
37 | :param strides: [stride_rows, stride_cols]
38 | :param rates: [dilation_rows, dilation_cols]
39 |     :return: A Tensor of shape [batch, channels*k*k, L], where L is the number of patches
40 | """
41 | assert len(images.size()) == 4
42 | assert padding in ['same', 'valid']
43 | batch_size, channel, height, width = images.size()
44 |
45 | if padding == 'same':
46 | images = same_padding(images, ksizes, strides, rates)
47 | elif padding == 'valid':
48 | pass
49 | else:
50 | raise NotImplementedError('Unsupported padding type: {}.\
51 | Only "same" or "valid" are supported.'.format(padding))
52 |
53 | unfold = torch.nn.Unfold(kernel_size=ksizes,
54 | dilation=rates,
55 | padding=0,
56 | stride=strides)
57 | patches = unfold(images)
58 | return patches # [N, C*k*k, L], L is the total number of such blocks
59 |
60 | def reverse_patches(images, out_size, ksizes, strides, padding):
61 |     """
62 |     Inverse of extract_image_patches: fold image patches back into an image.
63 |     :param images: [batch, channels*k*k, L]. Patch tensor as produced by
64 |         extract_image_patches, where L is the number of patches
65 |     :param out_size: (out_rows, out_cols). Spatial size of the reconstructed image
66 |     :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window
67 |     :param strides: [stride_rows, stride_cols]
68 |     :param padding: zero padding assumed on each side when folding
69 |     :return: A 4-D Tensor of shape [batch, channels, out_rows, out_cols]
70 |     """
71 | unfold = torch.nn.Fold(output_size = out_size,
72 | kernel_size=ksizes,
73 | dilation=1,
74 | padding=padding,
75 | stride=strides)
76 | patches = unfold(images)
77 |     return patches # [N, C, out_rows, out_cols] reconstructed image
78 | def reduce_mean(x, axis=None, keepdim=False):
79 | if not axis:
80 | axis = range(len(x.shape))
81 | for i in sorted(axis, reverse=True):
82 | x = torch.mean(x, dim=i, keepdim=keepdim)
83 | return x
84 |
85 |
86 | def reduce_std(x, axis=None, keepdim=False):
87 | if not axis:
88 | axis = range(len(x.shape))
89 | for i in sorted(axis, reverse=True):
90 | x = torch.std(x, dim=i, keepdim=keepdim)
91 | return x
92 |
93 |
94 | def reduce_sum(x, axis=None, keepdim=False):
95 | if not axis:
96 | axis = range(len(x.shape))
97 | for i in sorted(axis, reverse=True):
98 | x = torch.sum(x, dim=i, keepdim=keepdim)
99 | return x
--------------------------------------------------------------------------------
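The patch helpers above wrap `nn.Unfold`/`nn.Fold`. A shape sketch: with `'same'` padding and stride 1 there is one 3x3 patch per spatial position, so the patch tensor is `[N, C*k*k, H*W]`.

import torch
from models.utils.tools import extract_image_patches

x = torch.randn(2, 16, 24, 24)
patches = extract_image_patches(x, ksizes=[3, 3], strides=[1, 1], rates=[1, 1], padding='same')
print(patches.shape)   # torch.Size([2, 144, 576]) -> [N, 16*3*3, 24*24]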
/models/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2 as cv
4 | from skimage import io
5 | from PIL import Image
6 | import os
7 | import subprocess
8 |
9 |
10 | def img_to_tensor(img_path, device, size=None, mode='rgb'):
11 | """
12 |     Read image from img_path, and convert to a (1, C, H, W) float tensor in range [-1, 1]
13 | """
14 | img = Image.open(img_path).convert('RGB')
15 | img = np.array(img)
16 | if mode=='bgr':
17 | img = img[..., ::-1]
18 | if size:
19 | img = cv.resize(img, size)
20 | img = img / 255 * 2 - 1
21 | img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
22 | return img_tensor.float()
23 |
24 |
25 | def tensor_to_img(tensor, save_path=None, size=None, mode='rgb', normal=False):
26 | img_array = tensor.squeeze().data.cpu().numpy()
27 | img_array = img_array.transpose(1, 2, 0)
28 | if size is not None:
29 | img_array = cv.resize(img_array, size, interpolation=cv.INTER_LINEAR)
30 | if normal:
31 | # img_array = (img_array - img_array.min()) / (img_array.max() - img_array.min())
32 | img_array = (img_array + 1.) / 2. * 255.
33 | img_array = img_array.clip(0, 255)
34 | if save_path:
35 | if img_array.max() <= 1:
36 | img_array = (img_array * 255).astype(np.uint8)
37 | io.imsave(save_path, img_array)
38 |
39 | return img_array.astype(np.uint8)
40 |
41 |
42 | def tensor_to_numpy(tensor):
43 | return tensor.data.cpu().numpy()
44 |
45 |
46 | def batch_numpy_to_image(array, size=None):
47 | """
48 | Input: numpy array (B, C, H, W) in [-1, 1]
49 | """
50 | if isinstance(size, int):
51 | size = (size, size)
52 |
53 | out_imgs = []
54 | array = np.clip((array + 1)/2 * 255, 0, 255)
55 | array = np.transpose(array, (0, 2, 3, 1))
56 | for i in range(array.shape[0]):
57 | if size is not None:
58 | tmp_array = cv.resize(array[i], size)
59 | else:
60 | tmp_array = array[i]
61 | out_imgs.append(tmp_array)
62 | return np.array(out_imgs)
63 |
64 |
65 | def batch_tensor_to_img(tensor, size=None):
66 | """
67 | Input: (B, C, H, W)
68 | Return: RGB image, [0, 255]
69 | """
70 | arrays = tensor_to_numpy(tensor)
71 | out_imgs = batch_numpy_to_image(arrays, size)
72 | return out_imgs
73 |
74 |
75 | def mkdirs(paths):
76 | if isinstance(paths, list) and not isinstance(paths, str):
77 | for path in paths:
78 | if not os.path.exists(path):
79 | os.makedirs(path)
80 | else:
81 | if not os.path.exists(paths):
82 | os.makedirs(paths)
83 |
84 |
85 | def get_gpu_memory_map():
86 | """Get the current gpu usage within visible cuda devices.
87 |
88 | Returns
89 | -------
90 | Memory Map: dict
91 | Keys are device ids as integers.
92 | Values are memory usage as integers in MB.
93 | Device Ids: gpu ids sorted in descending order according to the available memory.
94 | """
95 | result = subprocess.check_output(
96 | [
97 | 'nvidia-smi', '--query-gpu=memory.used',
98 | '--format=csv,nounits,noheader'
99 | ]).decode('utf-8')
100 | # Convert lines into a dictionary
101 | gpu_memory = np.array([int(x) for x in result.strip().split('\n')])
102 | if 'CUDA_VISIBLE_DEVICES' in os.environ:
103 | visible_devices = sorted([int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(',')])
104 | else:
105 | visible_devices = range(len(gpu_memory))
106 | gpu_memory_map = dict(zip(range(len(visible_devices)), gpu_memory[visible_devices]))
107 | return gpu_memory_map, sorted(gpu_memory_map, key=gpu_memory_map.get)
108 |
109 |
110 |
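Usage note: a minimal round-trip sketch for img_to_tensor and tensor_to_img above, mirroring how test.py saves network outputs. The file paths are hypothetical and the functions defined above are assumed to be in scope:

    import torch
    from PIL import Image

    device = torch.device('cpu')
    lr = img_to_tensor('examples/input.png', device, size=(128, 128))  # [1, 3, 128, 128] in [-1, 1]
    rgb = tensor_to_img(lr, normal=True)                               # HxWx3 uint8 array in [0, 255]
    Image.fromarray(rgb).save('examples/input_copy.png')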
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
2 |
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import random
5 | from utils import utils
6 | import torch
7 | import models
8 | import data
9 |
10 |
11 |
12 | class BaseOptions():
13 | """This class defines options used during both training and test time.
14 |
15 | It also implements several helper functions such as parsing, printing, and saving the options.
16 | It also gathers additional options defined in functions in both dataset class and model class.
17 | """
18 |
19 | def __init__(self):
20 | """Reset the class; indicates the class hasn't been initialized"""
21 | self.initialized = False
22 |
23 | def initialize(self, parser):
24 | """Define the common options that are used in both training and test."""
25 | # basic parameters
26 | parser.add_argument('--dataroot', required=False, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
27 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
28 | parser.add_argument('--gpus', type=int, default=1, help='how many gpus to use')
29 | parser.add_argument('--seed', type=int, default=123, help='Random seed for training')
30 | parser.add_argument('--checkpoints_dir', type=str, default='./epoch150_18000', help='models are saved here')
31 | parser.add_argument('--debug', action='store_true', help='if specified, set to debug mode')
32 | # model parameters
33 | parser.add_argument('--model', type=str, default='drn', help='chooses which model to train [parse|enhance]')
34 | parser.add_argument('--att_name', type=str, default='drn', help='attention type [drn|spar3d]')
35 | parser.add_argument('--res_depth', type=int, default=10, help='depth of residual layers')
36 | parser.add_argument('--bottleneck_size', type=int, default=4, help='bottleneck feature size in hourglass block')
37 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
38 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
39 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
40 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
41 | parser.add_argument('--n_layers_D', type=int, default=4, help='downsampling layers in discriminator')
42 | parser.add_argument('--num_D', type=int, default=3, help='numbers of discriminators')
43 | parser.add_argument('--Gnorm', type=str, default='bn', help='generator norm [in | bn | none]')
44 | parser.add_argument('--Dnorm', type=str, default='none', help='discriminator norm [in | bn | none]')
45 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
46 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
47 | # dataset parameters
48 | parser.add_argument('--dataset_name', type=str, default='celeba', help='dataset name')
49 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
50 | parser.add_argument('--num_threads', default=8, type=int, help='# threads for loading data')
51 | parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
52 | parser.add_argument('--load_size', type=int, default=512, help='scale images to this size')
53 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
54 | parser.add_argument('--preprocess', type=str, default='none', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
55 | # additional parameters
56 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
57 | parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
58 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
59 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
60 | self.initialized = True
61 | return parser
62 |
63 | def gather_options(self):
64 | """Initialize our parser with basic options (only once).
65 | Add additional model-specific and dataset-specific options.
66 | These options are defined in the option-setter functions
67 | of the model and dataset classes.
68 | """
69 | if not self.initialized: # check if it has been initialized
70 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
71 | parser = self.initialize(parser)
72 |
73 | # get the basic options
74 | opt, _ = parser.parse_known_args()
75 |
76 | # modify model-related parser options
77 | model_name = opt.model
78 | model_option_setter = models.get_option_setter(model_name)
79 | parser = model_option_setter(parser, self.isTrain)
80 | opt, _ = parser.parse_known_args() # parse again with new defaults
81 |
82 | # modify dataset-related parser options
83 | dataset_name = opt.dataset_name
84 | dataset_option_setter = data.get_option_setter(dataset_name)
85 | parser = dataset_option_setter(parser, self.isTrain)
86 |
87 | # save and return the parser
88 | self.parser = parser
89 | return parser.parse_args()
90 |
91 | def print_options(self, opt):
92 | """Print and save options
93 |
94 | It will print both current options and default values (if different).
95 | It will save options into a text file / [checkpoints_dir] / opt.txt
96 | """
97 | message = ''
98 | message += '----------------- Options ---------------\n'
99 | for k, v in sorted(vars(opt).items()):
100 | comment = ''
101 | default = self.parser.get_default(k)
102 | if v != default:
103 | comment = '\t[default: %s]' % str(default)
104 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
105 | message += '----------------- End -------------------'
106 | print(message)
107 |
108 | # save to the disk
109 | opt.expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
110 | utils.mkdirs(opt.expr_dir)
111 | file_name = os.path.join(opt.expr_dir, '{}_opt.txt'.format(opt.phase))
112 | with open(file_name, 'wt') as opt_file:
113 | opt_file.write(message)
114 | opt_file.write('\n')
115 |
116 | opt.log_dir = os.path.join(opt.checkpoints_dir, 'log_dir')
117 | utils.mkdirs(opt.log_dir)
118 | opt.log_archive = os.path.join(opt.checkpoints_dir, 'log_archive')
119 | utils.mkdirs(opt.log_archive)
120 |
121 | def parse(self):
122 | """Parse our options, create checkpoints directory suffix, and set up gpu device."""
123 | opt = self.gather_options()
124 | opt.isTrain = self.isTrain # train or test
125 |
126 | # Find available GPUs automatically
127 | if opt.gpus > 0:
128 | opt.gpu_ids = utils.get_gpu_memory_map()[1][:opt.gpus]
129 | if not isinstance(opt.gpu_ids, list):
130 | opt.gpu_ids = [opt.gpu_ids]
131 | torch.cuda.set_device(opt.gpu_ids[0])
132 | opt.device = torch.device('cuda:{}'.format(opt.gpu_ids[0 % opt.gpus]))
133 | opt.data_device = torch.device('cuda:{}'.format(opt.gpu_ids[2 % opt.gpus]))
134 | else:
135 | opt.gpu_ids = []
136 | opt.device = torch.device('cpu')
137 |
138 | # set random seed for reproducibility
139 | np.random.seed(opt.seed)
140 | random.seed(opt.seed)
141 | torch.manual_seed(opt.seed)
142 | torch.cuda.manual_seed_all(opt.seed)
143 |
144 | # process opt.suffix
145 | if opt.suffix:
146 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
147 | opt.name = opt.name + suffix
148 |
149 | self.print_options(opt)
150 |
151 | self.opt = opt
152 | return self.opt
153 |
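Usage note: a small sketch of how the --suffix template is expanded in parse() above; any option name in braces is filled from the parsed values and appended to opt.name. The option values here are hypothetical:

    from types import SimpleNamespace

    opt = SimpleNamespace(model='ctcnet', load_size=128,
                          name='experiment_name', suffix='{model}_size{load_size}')
    if opt.suffix:
        opt.name = opt.name + '_' + opt.suffix.format(**vars(opt))
    print(opt.name)  # experiment_name_ctcnet_size128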
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | """This class includes test options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 | parser.add_argument('--test_img_path', type=str, default='', help='path of single test image.')
14 | parser.add_argument('--test_upscale', type=int, default=1, help='upscale single test image.')
15 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
16 | parser.add_argument('--save_as_dir', type=str, default='', help='save results in different dir.')
17 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
18 | parser.add_argument('--pretrain_model_path', type=str, default='', help='load pretrain model path if specified')
19 | # Dropout and BatchNorm have different behavior during training and test.
20 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
21 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
22 | # override default values
23 | parser.set_defaults(model='test')
24 | # To avoid cropping, the load_size should be the same as crop_size
25 | parser.set_defaults(load_size=parser.get_default('crop_size'))
26 | self.isTrain = False
27 | return parser
28 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TrainOptions(BaseOptions):
4 | """This class includes training options.
5 |
6 | It also includes shared options defined in BaseOptions.
7 | """
8 |
9 | def initialize(self, parser):
10 | parser = BaseOptions.initialize(self, parser)
11 | # visdom and HTML visualization parameters
12 | parser.add_argument('--visual_freq', type=int, default=400, help='frequency of showing training images in tensorboard')
13 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
14 | # network saving and loading parameters
15 | parser.add_argument('--save_iter_freq', type=int, default=1000, help='frequency of saving the models')
16 | parser.add_argument('--save_latest_freq', type=int, default=1000, help='save latest freq')
17 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
18 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
19 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
20 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
21 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
22 | # training parameters
23 | parser.add_argument('--resume_epoch', type=int, default=0, help='training resume epoch')
24 | parser.add_argument('--resume_iter', type=int, default=0, help='training resume iter')
25 | parser.add_argument('--n_epochs', type=int, default=90, help='number of epochs with the initial learning rate')
26 | parser.add_argument('--n_epochs_decay', type=int, default=10, help='number of epochs to linearly decay learning rate to zero')
27 | parser.add_argument('--total_epochs', type=int, default=100, help='# of epochs to train')
28 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
29 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
30 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
31 | parser.add_argument('--g_lr', type=float, default=0.0001, help='initial learning rate for adam')
32 | parser.add_argument('--d_lr', type=float, default=0.0004, help='initial learning rate for adam')
33 | parser.add_argument('--gan_mode', type=str, default='hinge', help='the type of GAN objective. [vanilla| lsgan | hinge]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
34 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
35 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
36 |
37 | self.isTrain = True
38 |
39 | return parser
40 |
41 |
--------------------------------------------------------------------------------
/psnr_ssim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.metrics import structural_similarity as compare_ssim
3 | from PIL import Image
4 | import os
5 | from IPython import embed
6 |
7 | def rgb2y_matlab(x):
8 | """Convert RGB image to illumination Y in Ycbcr space in matlab way.
9 | -------------
10 | # Args
11 | - Input: x, byte RGB image, value range [0, 255]
12 | - Output: byte gray image, value range [16, 235]
13 |
14 | # Shape
15 | - Input: (H, W, C)
16 | - Output: (H, W)
17 | """
18 | K = np.array([65.481, 128.553, 24.966]) / 255.0
19 | Y = 16 + np.matmul(x, K)
20 | return Y.astype(np.uint8)
21 |
22 |
23 | def PSNR(im1, im2, use_y_channel=True):
24 | """Calculate PSNR score between im1 and im2
25 | --------------
26 | # Args
27 | - im1, im2: input byte RGB image, value range [0, 255]
28 | - use_y_channel: if convert im1 and im2 to illumination channel first
29 | """
30 | if use_y_channel:
31 | im1 = rgb2y_matlab(im1)
32 | im2 = rgb2y_matlab(im2)
33 | im1 = im1.astype(np.float64)
34 | im2 = im2.astype(np.float64)
35 | mse = np.mean(np.square(im1 - im2))
36 | return 10 * np.log10(255**2 / mse)
37 |
38 |
39 | def SSIM(gt_img, noise_img):
40 | """Calculate SSIM score between im1 and im2 in Y space
41 | -------------
42 | # Args
43 | - gt_img: ground truth image, byte RGB image
44 | - noise_img: image with noise, byte RGB image
45 | """
46 | gt_img = rgb2y_matlab(gt_img)
47 | noise_img = rgb2y_matlab(noise_img)
48 |
49 | ssim_score = compare_ssim(gt_img, noise_img, gaussian_weights=True,
50 | sigma=1.5, use_sample_covariance=False)
51 | return ssim_score
52 |
53 | def psnr_ssim_dir(gt_dir, test_dir):
54 | gt_img_list = sorted([x for x in sorted(os.listdir(gt_dir))])
55 | test_img_list = sorted([x for x in sorted(os.listdir(test_dir))])
56 | # assert gt_img_list == test_img_list, 'Test image names are different from gt images.'
57 |
58 | psnr_score = 0
59 | ssim_score = 0
60 | count = 0
61 | print(gt_img_list)
62 | for gt_name, test_name in zip(gt_img_list, test_img_list):
63 | count += 1
64 | gt_img = Image.open(os.path.join(gt_dir, gt_name))
65 | test_img = Image.open(os.path.join(test_dir, test_name))
66 | gt_img = np.array(gt_img)
67 | test_img = np.array(test_img)
68 | cur_psnr = PSNR(gt_img, test_img)
69 | psnr_score += cur_psnr
70 | ssim_score += SSIM(gt_img, test_img)
71 | if cur_psnr > 30:
72 | print(cur_psnr, gt_name)
73 | return psnr_score / len(gt_img_list), ssim_score / len(gt_img_list)
74 |
75 | if __name__ == '__main__':
76 | '''
77 | gt_dir = "/home2/ZiXiangXu/Last_ding/spation_test_1000/"
78 | test_dirs = [
79 | "/home2/ZiXiangXu/Last_ding/res4/epoch_CelebA/"
80 | ]
81 | '''
82 |
83 | gt_dir = "/home2/ZiXiangXu/DataSets/spation_test_1000/"
84 | test_dirs = [
85 | "/home2/ZiXiangXu/Last_ding/res4/epoch_CelebA/"
86 | ]
87 |
88 | for td in test_dirs:
89 | result = psnr_ssim_dir(td, gt_dir)
90 | print(td, result)
91 |
92 |
93 |
94 |
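Usage note: a quick self-check sketch for PSNR above, assuming the functions in this file are in scope. For 8-bit images with a uniform error of 5 gray levels, MSE = 25 and PSNR = 10 * log10(255^2 / 25) ≈ 34.15 dB:

    import numpy as np

    gt = np.full((128, 128, 3), 120, dtype=np.uint8)
    noisy = np.full((128, 128, 3), 125, dtype=np.uint8)
    print(PSNR(gt, noisy, use_y_channel=False))  # ~34.15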
--------------------------------------------------------------------------------
/psnr_ssim_log.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.metrics import structural_similarity as compare_ssim
3 | from PIL import Image
4 | import os
5 | from IPython import embed
6 |
7 | def rgb2y_matlab(x):
8 | """Convert RGB image to illumination Y in Ycbcr space in matlab way.
9 | -------------
10 | # Args
11 | - Input: x, byte RGB image, value range [0, 255]
12 | - Output: byte gray image, value range [16, 235]
13 |
14 | # Shape
15 | - Input: (H, W, C)
16 | - Output: (H, W)
17 | """
18 | K = np.array([65.481, 128.553, 24.966]) / 255.0
19 | Y = 16 + np.matmul(x, K)
20 | return Y.astype(np.uint8)
21 |
22 |
23 | def PSNR(im1, im2, use_y_channel=True):
24 | """Calculate PSNR score between im1 and im2
25 | --------------
26 | # Args
27 | - im1, im2: input byte RGB image, value range [0, 255]
28 | - use_y_channel: if convert im1 and im2 to illumination channel first
29 | """
30 | if use_y_channel:
31 | im1 = rgb2y_matlab(im1)
32 | im2 = rgb2y_matlab(im2)
33 | im1 = im1.astype(np.float64)
34 | im2 = im2.astype(np.float64)
35 | mse = np.mean(np.square(im1 - im2))
36 | return 10 * np.log10(255**2 / mse)
37 |
38 |
39 | def write_log(log_file, log_str):
40 | f = open(log_file, 'a')
41 | f.write(log_str + '\n')
42 | f.close()
43 |
44 | def SSIM(gt_img, noise_img):
45 | """Calculate SSIM score between im1 and im2 in Y space
46 | -------------
47 | # Args
48 | - gt_img: ground truth image, byte RGB image
49 | - noise_img: image with noise, byte RGB image
50 | """
51 | gt_img = rgb2y_matlab(gt_img)
52 | noise_img = rgb2y_matlab(noise_img)
53 |
54 | ssim_score = compare_ssim(gt_img, noise_img, gaussian_weights=True,
55 | sigma=1.5, use_sample_covariance=False)
56 | return ssim_score
57 |
58 | def psnr_ssim_dir(gt_dir, test_dir):
59 | gt_img_list = sorted([x for x in sorted(os.listdir(gt_dir))])
60 | test_img_list = sorted([x for x in sorted(os.listdir(test_dir))])
61 | # assert gt_img_list == test_img_list, 'Test image names are different from gt images.'
62 |
63 | psnr_score = 0
64 | ssim_score = 0
65 | for gt_name, test_name in zip(gt_img_list, test_img_list):
66 | gt_img = Image.open(os.path.join(gt_dir, gt_name))
67 | test_img = Image.open(os.path.join(test_dir, test_name))
68 | gt_img = np.array(gt_img)
69 | test_img = np.array(test_img)
70 | psnr_score += PSNR(gt_img, test_img)
71 | ssim_score += SSIM(gt_img, test_img)
72 | return psnr_score / len(gt_img_list), ssim_score / len(gt_img_list)
73 |
74 | if __name__ == '__main__':
75 |
76 | gt_dir = "/home2/ZiXiangXu/DataSets/spation_test_1000/"
77 |
78 | test_dirs = []
79 | input_txt = "/home2/ZiXiangXu/Last_ding/res4_jiu/list.txt"
80 | f = open(input_txt)
81 | for line in f:
82 | line = line.strip('\n')
83 | line = line.split(' ')
84 | test_dirs.append(line[0])
85 | f.close()
86 |
87 | output_dir = "/home2/ZiXiangXu/Last_ding/res4_jiu/log_test/"
88 | logtxt_dir = os.path.join(output_dir, 'log.txt')
89 |
90 | for td in test_dirs:
91 | result = psnr_ssim_dir(td, gt_dir)
92 |
93 | log_str = '%.4f' % (result[0])
94 | write_log(logtxt_dir, log_str)
95 |
96 | print(td, result)
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
/read list.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CTCNet/4406888c1f8d01a612b993334bce835899483a97/read list.py
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | addict==2.4.0
3 | alabaster==0.7.12
4 | anaconda-client==1.9.0
5 | anaconda-navigator==2.1.1
6 | anaconda-project==0.10.1
7 | anyio==2.2.0
8 | appdirs==1.4.4
9 | argh==0.26.2
10 | argon2-cffi==20.1.0
11 | arrow==0.13.1
12 | asn1crypto==1.4.0
13 | astroid==2.6.6
14 | astropy==4.3.1
15 | async-generator==1.10
16 | atomicwrites==1.4.0
17 | attrs==21.2.0
18 | autopep8==1.5.7
19 | Babel==2.9.1
20 | backcall==0.2.0
21 | backports.functools-lru-cache==1.6.4
22 | backports.shutil-get-terminal-size==1.0.0
23 | backports.tempfile==1.0
24 | backports.weakref==1.0.post1
25 | basicsr==1.4.2
26 | beautifulsoup4==4.10.0
27 | binaryornot==0.4.4
28 | bitarray==2.3.0
29 | bkcharts==0.2
30 | black==19.10b0
31 | bleach==4.0.0
32 | bokeh==2.4.1
33 | boto==2.49.0
34 | Bottleneck==1.3.2
35 | brotlipy==0.7.0
36 | cached-property==1.5.2
37 | cachetools==5.3.0
38 | certifi==2021.10.8
39 | cffi==1.14.6
40 | chardet==4.0.0
41 | charset-normalizer==2.0.4
42 | click==8.0.3
43 | cloudpickle==2.0.0
44 | clyent==1.2.2
45 | colorama==0.4.4
46 | conda==4.10.3
47 | conda-build==3.21.5
48 | conda-content-trust==0+unknown
49 | conda-pack==0.6.0
50 | conda-package-handling==1.7.3
51 | conda-repo-cli==1.0.4
52 | conda-token==0.3.0
53 | conda-verify==3.4.2
54 | contextlib2==0.6.0.post1
55 | cookiecutter==1.7.2
56 | cryptography==3.4.8
57 | cycler==0.10.0
58 | Cython==0.29.24
59 | cytoolz==0.11.0
60 | daal4py==2021.3.0
61 | dask==2021.10.0
62 | debugpy==1.4.1
63 | decorator==5.1.0
64 | defusedxml==0.7.1
65 | diff-match-patch==20200713
66 | dill==0.3.6
67 | distributed==2021.10.0
68 | dlib==19.24.0
69 | docutils==0.17.1
70 | dominate==2.7.0
71 | einops==0.3.2
72 | entrypoints==0.3
73 | et-xmlfile==1.1.0
74 | fastcache==1.1.0
75 | filelock==3.3.1
76 | flake8==3.9.2
77 | Flask==1.1.2
78 | fonttools==4.25.0
79 | fsspec==2021.8.1
80 | future==0.18.2
81 | gdown==4.6.4
82 | gevent==21.8.0
83 | glob2==0.7
84 | gmpy2==2.0.8
85 | google-auth==2.16.2
86 | google-auth-oauthlib==1.0.0
87 | greenlet==1.1.1
88 | grpcio==1.51.3
89 | h5py==3.3.0
90 | HeapDict==1.0.1
91 | html5lib==1.1
92 | idna==3.2
93 | imagecodecs==2021.8.26
94 | imageio==2.10.1
95 | imagesize==1.2.0
96 | imgaug==0.4.0
97 | importlib-metadata==4.8.1
98 | inflection==0.5.1
99 | iniconfig==1.1.1
100 | intervaltree==3.1.0
101 | ipykernel==6.4.1
102 | ipython==7.29.0
103 | ipython-genutils==0.2.0
104 | ipywidgets==7.6.5
105 | isort==5.9.3
106 | itsdangerous==2.0.1
107 | jdcal==1.4.1
108 | jedi==0.18.0
109 | jeepney==0.7.1
110 | Jinja2==2.11.3
111 | jinja2-time==0.2.0
112 | joblib==1.1.0
113 | json5==0.9.6
114 | jsonschema==3.2.0
115 | jupyter==1.0.0
116 | jupyter-client==6.1.12
117 | jupyter-console==6.4.0
118 | jupyter-core==4.8.1
119 | jupyter-server==1.4.1
120 | jupyterlab==3.2.1
121 | jupyterlab-pygments==0.1.2
122 | jupyterlab-server==2.8.2
123 | jupyterlab-widgets==1.0.0
124 | keyring==23.1.0
125 | kiwisolver==1.3.1
126 | lazy-object-proxy==1.6.0
127 | libarchive-c==2.9
128 | llvmlite==0.37.0
129 | lmdb==1.4.0
130 | locket==0.2.1
131 | lpips==0.1.3
132 | lxml==4.6.3
133 | Markdown==3.4.1
134 | MarkupSafe==1.1.1
135 | matplotlib==3.4.3
136 | matplotlib-inline==0.1.2
137 | mccabe==0.6.1
138 | mistune==0.8.4
139 | mkl-fft==1.3.1
140 | mkl-random==1.2.2
141 | mkl-service==2.4.0
142 | mock==4.0.3
143 | more-itertools==8.10.0
144 | mpmath==1.2.1
145 | msgpack==1.0.2
146 | multipledispatch==0.6.0
147 | munkres==1.1.4
148 | mypy-extensions==0.4.3
149 | navigator-updater==0.2.1
150 | nbclassic==0.2.6
151 | nbclient==0.5.3
152 | nbconvert==6.1.0
153 | nbformat==5.1.3
154 | nest-asyncio==1.5.1
155 | networkx==2.6.3
156 | nltk==3.6.5
157 | nose==1.3.7
158 | notebook==6.4.5
159 | numba==0.54.1
160 | numexpr==2.7.3
161 | numpy==1.20.3
162 | numpydoc==1.1.0
163 | oauthlib==3.2.2
164 | olefile==0.46
165 | opencv-contrib-python==4.6.0.66
166 | opencv-python==4.6.0.66
167 | openpyxl==3.0.9
168 | packaging==21.0
169 | pandas==1.3.4
170 | pandocfilters==1.4.3
171 | parso==0.8.2
172 | partd==1.2.0
173 | path==16.0.0
174 | pathlib==1.0.1
175 | pathlib2==2.3.6
176 | pathspec==0.7.0
177 | patsy==0.5.2
178 | pep8==1.7.1
179 | pexpect==4.8.0
180 | pickleshare==0.7.5
181 | Pillow==8.4.0
182 | pip==21.2.4
183 | pkginfo==1.7.1
184 | pluggy==0.13.1
185 | ply==3.11
186 | poyo==0.5.0
187 | prometheus-client==0.11.0
188 | prompt-toolkit==3.0.20
189 | protobuf==4.22.1
190 | protocol==0.1.0
191 | psutil==5.8.0
192 | ptyprocess==0.7.0
193 | py==1.10.0
194 | pyasn1==0.4.8
195 | pyasn1-modules==0.2.8
196 | pycodestyle==2.7.0
197 | pycosat==0.6.3
198 | pycparser==2.20
199 | pycurl==7.44.1
200 | pydocstyle==6.1.1
201 | pyerfa==2.0.0
202 | pyflakes==2.3.1
203 | Pygments==2.10.0
204 | PyJWT==2.1.0
205 | pylint==2.9.6
206 | pyls-spyder==0.4.0
207 | pyodbc==4.0.0-unsupported
208 | pyOpenSSL==21.0.0
209 | pyparsing==3.0.4
210 | pyrsistent==0.18.0
211 | pyrtools==1.0.1
212 | PySocks==1.7.1
213 | pytest==6.2.4
214 | python-dateutil==2.8.2
215 | python-lsp-black==1.0.0
216 | python-lsp-jsonrpc==1.0.0
217 | python-lsp-server==1.2.4
218 | python-slugify==5.0.2
219 | pytorch-fid==0.3.0
220 | pytz==2021.3
221 | PyWavelets==1.1.1
222 | pyxdg==0.27
223 | PyYAML==6.0
224 | pyzmq==22.2.1
225 | QDarkStyle==3.0.2
226 | qstylizer==0.1.10
227 | QtAwesome==1.0.2
228 | qtconsole==5.1.1
229 | QtPy==1.10.0
230 | regex==2021.8.3
231 | requests==2.26.0
232 | requests-oauthlib==1.3.1
233 | rope==0.19.0
234 | rsa==4.9
235 | Rtree==0.9.7
236 | ruamel-yaml-conda==0.15.100
237 | scikit-image==0.18.3
238 | scikit-learn==0.24.2
239 | scikit-learn-intelex==2021.20210714.170444
240 | scipy==1.6.1
241 | seaborn==0.11.2
242 | SecretStorage==3.3.1
243 | Send2Trash==1.8.0
244 | setuptools==58.0.4
245 | shapely==2.0.1
246 | simplegeneric==0.8.1
247 | singledispatch==3.7.0
248 | sip==4.19.13
249 | six==1.16.0
250 | sniffio==1.2.0
251 | snowballstemmer==2.1.0
252 | sortedcollections==2.1.0
253 | sortedcontainers==2.4.0
254 | soupsieve==2.2.1
255 | Sphinx==4.2.0
256 | sphinxcontrib-applehelp==1.0.2
257 | sphinxcontrib-devhelp==1.0.2
258 | sphinxcontrib-htmlhelp==2.0.0
259 | sphinxcontrib-jsmath==1.0.1
260 | sphinxcontrib-qthelp==1.0.3
261 | sphinxcontrib-serializinghtml==1.1.5
262 | sphinxcontrib-websupport==1.2.4
263 | spyder==5.1.5
264 | spyder-kernels==2.1.3
265 | SQLAlchemy==1.4.22
266 | statsmodels==0.12.2
267 | sympy==1.9
268 | tables==3.6.1
269 | tb-nightly==2.13.0a20230330
270 | TBB==0.2
271 | tblib==1.7.0
272 | tensorboard==2.7.0
273 | tensorboard-data-server==0.7.0
274 | tensorboard-plugin-wit==1.8.1
275 | tensorboardX==2.4
276 | terminado==0.9.4
277 | testpath==0.5.0
278 | text-unidecode==1.3
279 | textdistance==4.2.1
280 | thop==0.1.1.post2209072238
281 | threadpoolctl==2.2.0
282 | three-merge==0.1.1
283 | tifffile==2021.7.2
284 | timm==0.4.12
285 | tinycss==0.4
286 | toml==0.10.2
287 | toolz==0.12.0
288 | torch==1.7.1+cu110
289 | torchaudio==0.7.2
290 | torchsummary==1.5.1
291 | torchsummaryX==1.3.0
292 | torchvision==0.8.2+cu110
293 | tornado==6.1
294 | tqdm==4.62.3
295 | traitlets==5.1.0
296 | typed-ast==1.4.3
297 | typing-extensions==3.10.0.2
298 | ujson==4.0.2
299 | unicodecsv==0.14.1
300 | Unidecode==1.2.0
301 | urllib3==1.26.7
302 | watchdog==2.1.3
303 | wcwidth==0.2.5
304 | webencodings==0.5.1
305 | Werkzeug==2.0.2
306 | wheel==0.37.0
307 | whichcraft==0.6.1
308 | widgetsnbextension==3.5.1
309 | wrapt==1.12.1
310 | wurlitzer==2.1.1
311 | xlrd==2.0.1
312 | XlsxWriter==3.0.1
313 | xlwt==1.3.0
314 | xmltodict==0.12.0
315 | yacs==0.1.8
316 | yapf==0.32.0
317 | zict==2.0.0
318 | zipp==3.6.0
319 | zope.event==4.5.0
320 | zope.interface==5.4.0
321 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from options.test_options import TestOptions
3 | from data import create_dataset
4 | from models import create_model
5 | from utils import utils
6 | from PIL import Image
7 | from tqdm import tqdm
8 | import torch
9 | import torch.nn.functional as F
10 | import time
11 | from IPython import embed
12 | if __name__ == '__main__':
13 | opt = TestOptions().parse() # get test options
14 | opt.num_threads = 0   # test code only supports num_threads = 0
15 | opt.batch_size = 1 # test code only supports batch_size = 1
16 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
17 | opt.no_flip = True
18 |
19 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
20 | model = create_model(opt) # create a model given opt.model and other options
21 | if len(opt.pretrain_model_path):
22 | model.load_pretrain_model()
23 | else:
24 | model.setup(opt) # regular setup: load and print networks; create schedulers
25 |
26 | if len(opt.save_as_dir):
27 | save_dir = opt.save_as_dir
28 | else:
29 | save_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch))
30 | if opt.load_iter > 0: # load_iter is 0 by default
31 | save_dir = '{:s}_iter{:d}'.format(save_dir, opt.load_iter)
32 | os.makedirs(save_dir, exist_ok=True)
33 |
34 | print('creating result directory', save_dir)
35 |
36 | network = model.netG
37 | network.eval()
38 | elapsed_time = 0
39 | for i, data in tqdm(enumerate(dataset), total=len(dataset)):
40 | #embed()
41 | inp = data['LR']
42 |
43 | print(inp.shape)
44 | # inp = F.interpolate(inp, size=[16, 16], mode="bicubic")
45 |
46 | # print(inp.shape)
47 | with torch.no_grad():
48 | start_time = time.time()
49 | output_SR = network(inp)
50 | elapsed_time += time.time() - start_time
51 | # print(output_SR.shape)
52 | img_path = data['LR_paths'] # get image paths
53 | output_sr_img = utils.tensor_to_img(output_SR, normal=True)
54 |
55 | save_path = os.path.join(save_dir, img_path[0].split('/')[-1])
56 | save_img = Image.fromarray(output_sr_img)
57 | save_img.save(save_path)
58 | print(elapsed_time / 1000)  # average inference time per image (assumes a 1000-image test set)
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test.py --gpus 1 --model drn --name SPARNet_S16_V4_Attn2D \
2 | --load_size 128 --dataset_name single --dataroot "/home2/ZiXiangXu/test_HR" \
3 | --pretrain_model_path "/home2/ZiXiangXu/best_pth" \
4 | --save_as_dir "/home2/ZiXiangXu/test_results"
5 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import os
4 | import numpy as np
5 | from math import sqrt
6 | import math
7 | from math import log10
8 | import torch
9 | import torch.nn as nn
10 | from PIL import Image
11 | import torch.optim as optim
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader
14 | import torchvision.transforms as transforms
15 | import matplotlib.pyplot as pyplot
16 | from datetime import datetime
17 |
18 | import torch.backends.cudnn as cudnn
19 | from IPython import embed
20 | from utils.timer import Timer
21 | from utils.logger import Logger
22 | from utils import utils
23 | from IPython import embed
24 | from options.train_options import TrainOptions
25 | from data import create_dataset
26 | from models import create_model
27 | import os
28 | import torchvision.transforms as transforms
29 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
30 | if __name__ == '__main__':
31 |
32 | def is_image_file(filename):
33 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
34 |
35 | def load_img(filepath):
36 | image = Image.open(filepath).convert('RGB')
37 | return image
38 |
39 | def rgb2y_matlab(x):
40 | K = np.array([65.481, 128.553, 24.966]) / 255.0
41 | Y = 16 + np.matmul(x, K)
42 | return Y.astype(np.uint8)
43 |
44 |
45 | opt = TrainOptions().parse()
46 |
47 | dataset = create_dataset(opt)
48 | dataset_size = len(dataset)
49 | print('The number of training images = %d' % dataset_size)
50 |
51 | model = create_model(opt)
52 | model.setup(opt)
53 |
54 | logger = Logger(opt)
55 | timer = Timer()
56 | to_tensor = transforms.Compose([
57 | transforms.ToTensor(),
58 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
59 | ])
60 |
61 | single_epoch_iters = (dataset_size // opt.batch_size)
62 | total_iters = opt.total_epochs * single_epoch_iters
63 | cur_iters = opt.resume_iter + opt.resume_epoch * single_epoch_iters
64 | start_iter = opt.resume_iter
65 | print('Start training from epoch: {:05d}; iter: {:07d}'.format(opt.resume_epoch, opt.resume_iter))
66 | for epoch in range(opt.resume_epoch, opt.total_epochs + 1):
67 | for i, data in enumerate(dataset, start=start_iter):
68 | cur_iters += 1
69 | logger.set_current_iter(cur_iters)
70 | # =================== load data ===============
71 | model.set_input(data, cur_iters)
72 | timer.update_time('DataTime')
73 |
74 | # =================== model train ===============
75 | model.forward(), timer.update_time('Forward')
76 | model.optimize_parameters(), timer.update_time('Backward')
77 | loss = model.get_current_losses()
78 | loss.update(model.get_lr())
79 | logger.record_losses(loss)
80 |
81 | # =================== save model and visualize ===============
82 | if cur_iters % opt.print_freq == 0:
83 | print('Model log directory: {}'.format(opt.expr_dir))
84 | epoch_progress = '{:03d}|{:05d}/{:05d}'.format(epoch, i, single_epoch_iters)
85 | logger.printIterSummary(epoch_progress, cur_iters, total_iters, timer)
86 |
87 | if cur_iters % opt.visual_freq == 0:
88 | visual_imgs = model.get_current_visuals()
89 | logger.record_images(visual_imgs)
90 |
91 | info = {'resume_epoch': epoch, 'resume_iter': i+1}
92 | if cur_iters % opt.save_iter_freq == 0 and cur_iters>120000:
93 | #if cur_iters % opt.save_iter_freq == 0:
94 | print('saving current model (epoch %d, iters %d)' % (epoch, cur_iters))
95 | save_suffix = 'iter_%d' % cur_iters
96 | model.save_networks(save_suffix, info)
97 | avg_psnr = 0
98 | image_ldir = "/home2/ZiXiangXu/Last_ding/spation_test_1000/"
99 | image_hdir = "/home2/ZiXiangXu/Last_ding/spation_test_1000/"
100 | image_filenames = [x for x in os.listdir(image_hdir) if is_image_file(x)]
101 | transform_list = [transforms.ToTensor()]
102 | transform = transforms.Compose(transform_list)
103 | for image_name in image_filenames:
104 | imgg = load_img(image_ldir + image_name)
105 | img_h = load_img(image_hdir + image_name)
106 | imgg = imgg.resize((16,16),Image.BICUBIC)
107 | img = imgg.resize((128, 128), Image.BICUBIC)
108 |
109 | input = to_tensor(img)
110 | input = input.view(1, -1, 128, 128)  # volatile Variables are deprecated; inference runs under torch.no_grad() below
111 | #embed()
112 | network = model.netG
113 | network.eval()
114 | with torch.no_grad(): out = network(input)
115 |
116 | output_sr_img = utils.tensor_to_img(out, normal=True)
117 | save_img = Image.fromarray(output_sr_img)
118 |
119 | if not os.path.exists(os.path.join("result", format(cur_iters))):
120 | os.makedirs(os.path.join("result", format(cur_iters)))
121 | save_img.save("result/{}/{}".format(cur_iters, image_name))
122 | print("Image saved as {}".format("result/{}/{}".format(cur_iters, image_name)))
123 | network.train()  # switch the generator back to training mode after validation
124 | if cur_iters % opt.save_latest_freq == 0:
125 | print('saving the latest model (epoch %d, iters %d)' % (epoch, cur_iters))
126 | model.save_networks('latest', info)
127 |
128 | if opt.debug: break
129 | if opt.debug and epoch > 5: exit()
130 |
131 |
132 |
133 |
134 | logger.close()
135 |
136 |
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | # export CUDA_VISIBLE_DEVICES=$1
2 | # =================================================================================
3 | # Train CTCNet
4 | # =================================================================================
5 |
6 | python train.py --gpus 1 --name CTCNet_S16_V4_Attn2D --model ctcnet \
7 | --Gnorm "bn" --lr 0.0002 --beta1 0.9 --scale_factor 8 --load_size 128 \
8 | --dataroot "/home/ggw/ZiXiangXu/DataSets/spation_18000/" --dataset_name celeba --batch_size 10 --total_epochs 100 \
9 | --visual_freq 100 --print_freq 50 --save_latest_freq 1000 #--continue_train
10 |
--------------------------------------------------------------------------------
/util/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CTCNet/4406888c1f8d01a612b993334bce835899483a97/util/.DS_Store
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CTCNet/4406888c1f8d01a612b993334bce835899483a97/util/__init__.py
--------------------------------------------------------------------------------
/util/rlutrans.py:
--------------------------------------------------------------------------------
1 | from models import common_ESTR as common
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | # from thop import profile
6 | from util.tools import extract_image_patches, reduce_mean, reduce_sum, same_padding, reverse_patches
7 | import pdb
8 | import math
9 | from IPython import embed
10 |
11 | def drop_path(x, drop_prob: float = 0., training: bool = False):
12 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
13 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
14 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
15 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
16 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
17 | 'survival rate' as the argument.
18 | """
19 | if drop_prob == 0. or not training:
20 | return x
21 | keep_prob = 1 - drop_prob
22 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
23 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
24 | random_tensor.floor_() # binarize
25 | output = x.div(keep_prob) * random_tensor
26 | return output
27 |
28 | # MLP in the paper
29 | class Mlp(nn.Module):
30 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.):
31 | super().__init__()
32 | out_features = out_features or in_features
33 | hidden_features = hidden_features or in_features//4
34 | self.fc1 = nn.Linear(in_features, hidden_features)
35 | self.act = act_layer()
36 | self.fc2 = nn.Linear(hidden_features, out_features)
37 | self.drop = nn.Dropout(drop)
38 |
39 | def forward(self, x):
40 | x = self.fc1(x)
41 | x = self.act(x)
42 | x = self.drop(x)
43 | x = self.fc2(x)
44 | x = self.drop(x)
45 | return x
46 |
47 |
48 | # Efficient Multi-Head Attention in the paper
49 | class EffAttention(nn.Module):
50 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
51 | super().__init__()
52 | self.num_heads = num_heads
53 | head_dim = dim // num_heads
54 | self.scale = qk_scale or 1/4  # note: integer division (1//4) would evaluate to 0 and zero out the attention logits; head_dim ** -0.5 is the usual scaling
55 |
56 | self.reduce = nn.Linear(dim, dim, bias=qkv_bias)  # nn.Linear expects 2-D (flattened) inputs/outputs, so 4-D feature maps must be reshaped to token sequences before this layer
57 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
58 | self.proj = nn.Linear(dim, dim)
59 | self.attn_drop = nn.Dropout(attn_drop)
60 | self.proj_drop = nn.Dropout(proj_drop)
61 |
62 | def forward(self, x):
63 | x = self.reduce(x)
64 | B, N, C = x.shape
65 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
66 | q, k, v = qkv[0], qkv[1], qkv[2]
67 | '''
68 | q_all = torch.split(q, math.ceil(N//4), dim=-2)
69 | k_all = torch.split(k, math.ceil(N//4), dim=-2)
70 | v_all = torch.split(v, math.ceil(N//4), dim=-2)
71 |
72 | output = []
73 | for q,k,v in zip(q_all, k_all, v_all):
74 | attn = (q @ k.transpose(-2, -1)) * self.scale #16*8*37*37
75 | attn = attn.softmax(dim=-1)
76 | attn = self.attn_drop(attn)
77 | trans_x = (attn @ v).transpose(1, 2) #.reshape(B, N, C)
78 | output.append(trans_x)
79 | x = torch.cat(output,dim=1)
80 | '''
81 | attn = (q @ k.transpose(-2, -1)) * self.scale
82 | attn = attn.softmax(dim=-1)
83 | attn = self.attn_drop(attn)
84 |
85 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
86 | x = self.proj(x)
87 | x = self.proj_drop(x)
88 | return x
89 | '''
90 | class Attention(nn.Module):
91 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
92 | super().__init__()
93 | self.num_heads = num_heads
94 | head_dim = dim // num_heads
95 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
96 | self.scale = qk_scale or head_dim ** -0.5
97 |
98 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
99 | self.attn_drop = nn.Dropout(attn_drop)
100 | self.proj = nn.Linear(dim, dim)
101 | self.proj_drop = nn.Dropout(proj_drop)
102 |
103 | def forward(self, x):
104 | B, N, C = x.shape
105 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
106 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
107 |
108 | attn = (q @ k.transpose(-2, -1)) * self.scale
109 | attn = attn.softmax(dim=-1)
110 | attn = self.attn_drop(attn)
111 |
112 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
113 | x = self.proj(x)
114 | x = self.proj_drop(x)
115 | return x
116 | '''
117 | ## Key Module: Efficient Transformer (ET) in the paper
118 | class TransBlock(nn.Module):
119 | def __init__(
120 | self, n_feat = 64,dim=64, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
121 | drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm):
122 | super(TransBlock, self).__init__()
123 | self.dim = dim
124 | self.atten = EffAttention(self.dim, num_heads=8, qkv_bias=False, qk_scale=None, \
125 | attn_drop=0., proj_drop=0.)
126 | self.norm1 = nn.LayerNorm(self.dim)
127 | self.mlp = Mlp(in_features=dim, hidden_features=dim//4, act_layer=act_layer, drop=drop)
128 | self.norm2 = nn.LayerNorm(self.dim)
129 |
130 | def forward(self, x):
131 | #B = x.shape[0]
132 | #x = extract_image_patches(x, ksizes=[3, 3],
133 | # strides=[1,1],
134 | # rates=[1, 1],
135 | # padding='same') # 16*2304*576
136 | #x = x.permute(0,2,1)
137 |
138 | x = x + self.atten(self.norm1(x))
139 | x = x + self.mlp(self.norm2(x))
140 | return x
141 |
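Usage note: a minimal sketch of how TransBlock above is typically driven, following the commented-out code in its forward: feature maps are unfolded into 3x3 patch tokens with the util.tools helpers imported at the top of this file, processed as a [B, N, C*k*k] sequence, and folded back. Shapes are illustrative only:

    import torch

    feat = torch.randn(2, 64, 32, 32)                                   # [B, C, H, W]
    block = TransBlock(n_feat=64, dim=64 * 3 * 3)                       # token dim = C*k*k = 576

    tokens = extract_image_patches(feat, ksizes=[3, 3], strides=[1, 1],
                                   rates=[1, 1], padding='same')        # [2, 576, 1024]
    tokens = tokens.permute(0, 2, 1)                                    # [2, 1024, 576]
    tokens = block(tokens)                                              # [2, 1024, 576]
    out = reverse_patches(tokens.permute(0, 2, 1), (32, 32),
                          (3, 3), 1, 1)                                 # [2, 64, 32, 32]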
--------------------------------------------------------------------------------
/util/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import torch.nn.functional as F
7 |
8 | def normalize(x):
9 | return x.mul_(2).add_(-1)
10 |
11 | def same_padding(images, ksizes, strides, rates):
12 | assert len(images.size()) == 4
13 | batch_size, channel, rows, cols = images.size()
14 | out_rows = (rows + strides[0] - 1) // strides[0]
15 | out_cols = (cols + strides[1] - 1) // strides[1]
16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1
17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1
18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
20 | # Pad the input
21 | padding_top = int(padding_rows / 2.)
22 | padding_left = int(padding_cols / 2.)
23 | padding_bottom = padding_rows - padding_top
24 | padding_right = padding_cols - padding_left
25 | paddings = (padding_left, padding_right, padding_top, padding_bottom)
26 | images = torch.nn.ZeroPad2d(paddings)(images)  # zero-pad the tensor borders; the paddings tuple gives the amount added on the (left, right, top, bottom) sides, e.g. paddings=(1, 2, 3, 4)
27 | return images
28 |
29 |
30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'):
31 | """
32 | Extract patches from images and put them in the C output dimension.
33 | :param padding:
34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
36 | each dimension of images
37 | :param strides: [stride_rows, stride_cols]
38 | :param rates: [dilation_rows, dilation_cols]
39 | :return: A Tensor
40 | """
41 | assert len(images.size()) == 4
42 | assert padding in ['same', 'valid']
43 | batch_size, channel, height, width = images.size()
44 |
45 | if padding == 'same':
46 | images = same_padding(images, ksizes, strides, rates)
47 | elif padding == 'valid':
48 | pass
49 | else:
50 | raise NotImplementedError('Unsupported padding type: {}.\
51 | Only "same" or "valid" are supported.'.format(padding))
52 |
53 | unfold = torch.nn.Unfold(kernel_size=ksizes,
54 | dilation=rates,
55 | padding=0,
56 | stride=strides)
57 | patches = unfold(images)
58 | return patches # [N, C*k*k, L], L is the total number of such blocks
59 |
60 | def reverse_patches(images, out_size, ksizes, strides, padding):
61 | """
62 | Fold patches (the output of extract_image_patches) back into an image.
63 | :param images: [batch, C*k*k, L]. A 3-D Tensor of flattened patch columns
64 | :param out_size: [out_rows, out_cols]. Spatial size of the folded output
65 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
66 | each dimension of images
67 | :param strides: [stride_rows, stride_cols]
68 | :param padding: implicit zero padding assumed around the output (should match the padding used when extracting patches)
69 | :return: A 4-D Tensor of shape [batch, channels, out_rows, out_cols]
70 | """
71 | fold = torch.nn.Fold(output_size=out_size,
72 | kernel_size=ksizes,
73 | dilation=1,
74 | padding=padding,
75 | stride=strides)
76 | images = fold(images)
77 | return images  # [N, C, out_rows, out_cols]
78 | def reduce_mean(x, axis=None, keepdim=False):
79 | if axis is None:
80 | axis = range(len(x.shape))
81 | for i in sorted(axis, reverse=True):
82 | x = torch.mean(x, dim=i, keepdim=keepdim)
83 | return x
84 |
85 |
86 | def reduce_std(x, axis=None, keepdim=False):
87 | if axis is None:
88 | axis = range(len(x.shape))
89 | for i in sorted(axis, reverse=True):
90 | x = torch.std(x, dim=i, keepdim=keepdim)
91 | return x
92 |
93 |
94 | def reduce_sum(x, axis=None, keepdim=False):
95 | if axis is None:
96 | axis = range(len(x.shape))
97 | for i in sorted(axis, reverse=True):
98 | x = torch.sum(x, dim=i, keepdim=keepdim)
99 | return x
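
Usage note: with padding='same', extract_image_patches above pads so that the number of patch positions is ceil(H / stride) * ceil(W / stride); reverse_patches then folds the [N, C*k*k, L] columns back onto the output canvas, summing values at overlapping pixels. A small shape-check sketch, assuming the helpers above are in scope:

    import torch

    x = torch.randn(1, 8, 17, 17)                                 # odd spatial size on purpose
    cols = extract_image_patches(x, ksizes=[3, 3], strides=[2, 2],
                                 rates=[1, 1], padding='same')
    print(cols.shape)                                             # torch.Size([1, 72, 81]) -> 9x9 positions
    y = reverse_patches(cols, (17, 17), (3, 3), 2, 1)
    print(y.shape)                                                # torch.Size([1, 8, 17, 17])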
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CTCNet/4406888c1f8d01a612b993334bce835899483a97/utils/__init__.py
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import numpy as np
4 | from .utils import mkdirs
5 | from tensorboardX import SummaryWriter
6 | from datetime import datetime
7 | import socket
8 | import shutil
9 |
10 | class Logger():
11 | def __init__(self, opts):
12 | time_stamp = '_{}'.format(datetime.now().strftime('%Y-%m-%d_%H:%M'))
13 | self.opts = opts
14 | self.log_dir = os.path.join(opts.log_dir, opts.name+time_stamp)
15 | self.phase_keys = ['train', 'val', 'test']
16 | self.iter_log = []
17 | self.epoch_log = OrderedDict()
18 | self.set_mode(opts.phase)
19 |
20 | # check whether a previous log with the same experiment name exists
21 | exist_log = None
22 | for log_name in os.listdir(opts.log_dir):
23 | if opts.name in log_name:
24 | exist_log = log_name
25 | if exist_log is not None:
26 | old_dir = os.path.join(opts.log_dir, exist_log)
27 | archive_dir = os.path.join(opts.log_archive, exist_log)
28 | shutil.move(old_dir, archive_dir)
29 |
30 | self.mk_log_file()
31 |
32 | self.writer = SummaryWriter(self.log_dir)
33 |
34 | def mk_log_file(self):
35 | mkdirs(self.log_dir)
36 | self.txt_files = OrderedDict()
37 | for i in self.phase_keys:
38 | self.txt_files[i] = os.path.join(self.log_dir, 'log_{}'.format(i))
39 |
40 | def set_mode(self, mode):
41 | self.mode = mode
42 | self.epoch_log[mode] = []
43 |
44 | def set_current_iter(self, cur_iter):
45 | self.cur_iter = cur_iter
46 |
47 | def record_losses(self, items):
48 | """
49 | iteration log: [iter][{key: value}]
50 | """
51 | self.iter_log.append(items)
52 | for k, v in items.items():
53 | if 'loss' in k.lower():
54 | self.writer.add_scalar('loss/{}'.format(k), v, self.cur_iter)
55 |
56 | def record_scalar(self, items):
57 | """
58 | Add scalar records. item, {key: value}
59 | """
60 | for i in items.keys():
61 | self.writer.add_scalar('{}'.format(i), items[i], self.cur_iter)
62 |
63 | def record_image(self, visual_img, tag='ckpt_image'):
64 | self.writer.add_image(tag, visual_img, self.cur_iter, dataformats='HWC')
65 |
66 | def record_images(self, visuals, nrow=6, tag='ckpt_image'):
67 | imgs = []
68 | nrow = min(nrow, visuals[0].shape[0])
69 | for i in range(nrow):
70 | tmp_imgs = [x[i] for x in visuals]
71 | imgs.append(np.hstack(tmp_imgs))
72 | imgs = np.vstack(imgs).astype(np.uint8)
73 | self.writer.add_image(tag, imgs, self.cur_iter, dataformats='HWC')
74 |
75 | def record_text(self, tag, text):
76 | self.writer.add_text(tag, text)
77 |
78 | def printIterSummary(self, epoch, cur_iters, total_it, timer):
79 | msg = '{}\nIter: [{}]{:03d}/{:03d}\t\t'.format(
80 | timer.to_string(total_it - cur_iters), epoch, cur_iters, total_it)
81 | for k, v in self.iter_log[-1].items():
82 | msg += '{}: {:.6f}\t'.format(k, v)
83 | print(msg + '\n')
84 | with open(self.txt_files[self.mode], 'a+') as f:
85 | f.write(msg + '\n')
86 |
87 | def close(self):
88 | self.writer.export_scalars_to_json(os.path.join(self.log_dir, 'all_scalars.json'))
89 | self.writer.close()
90 |
91 |
92 |
93 |
94 |
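Usage note: a minimal sketch of how train.py drives Logger above (hypothetical option values; requires tensorboardX). Any existing log with the same experiment name is moved to log_archive before a new writer is opened:

    from types import SimpleNamespace

    opts = SimpleNamespace(name='demo_run', phase='train',
                           log_dir='./check_points/log_dir',
                           log_archive='./check_points/log_archive')
    mkdirs([opts.log_dir, opts.log_archive])   # Logger expects both directories to exist

    logger = Logger(opts)
    logger.set_current_iter(1)
    logger.record_losses({'Pix_Loss': 0.123, 'lr': 2e-4})
    logger.close()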
--------------------------------------------------------------------------------
/utils/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import datetime
3 | from collections import OrderedDict
4 |
5 | class Timer():
6 | def __init__(self):
7 | self.reset_timer()
8 | self.start = time.time()
9 |
10 | def reset_timer(self):
11 | self.before = time.time()
12 | self.timer = OrderedDict()
13 |
14 | def update_time(self, key):
15 | self.timer[key] = time.time() - self.before
16 | self.before = time.time()
17 |
18 | def to_string(self, iters_left, short=False):
19 | iter_total = sum(self.timer.values())
20 | msg = "{:%Y-%m-%d %H:%M:%S}\tElapse: {}\tTimeLeft: {}\t".format(
21 | datetime.datetime.now(),
22 | datetime.timedelta(seconds=round(time.time() - self.start)),
23 | datetime.timedelta(seconds=round(iter_total*iters_left))
24 | )
25 | if short:
26 | msg += '{}: {:.2f}s'.format('|'.join(self.timer.keys()), iter_total)
27 | else:
28 | msg += '\tIterTotal: {:.2f}s\t{}: {} '.format(iter_total,
29 | '|'.join(self.timer.keys()), ' '.join('{:.2f}s'.format(x) for x in self.timer.values()))
30 | return msg
31 |
32 |
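Usage note: a small sketch of the Timer above as used in the training loop; each update_time call records the time since the previous call under the given key, and to_string estimates the time left from the per-iteration total:

    import time

    timer = Timer()
    time.sleep(0.01); timer.update_time('DataTime')
    time.sleep(0.02); timer.update_time('Forward')
    time.sleep(0.03); timer.update_time('Backward')
    print(timer.to_string(1000, short=True))
    # e.g. "<timestamp>  Elapse: 0:00:00  TimeLeft: 0:01:00  DataTime|Forward|Backward: 0.06s"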
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2 as cv
4 | from skimage import io
5 | from PIL import Image
6 | import os
7 | import subprocess
8 |
9 |
10 | def img_to_tensor(img_path, device, size=None, mode='rgb'):
11 | """
12 | Read image from img_path, and convert to (C, H, W) tensor in range [-1, 1]
13 | """
14 | img = Image.open(img_path).convert('RGB')
15 | img = np.array(img)
16 | if mode=='bgr':
17 | img = img[..., ::-1]
18 | if size:
19 | img = cv.resize(img, size)
20 | img = img / 255 * 2 - 1
21 | img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
22 | return img_tensor.float()
23 |
24 |
25 | def tensor_to_img(tensor, save_path=None, size=None, mode='rgb', normal=False):
26 | img_array = tensor.squeeze().data.cpu().numpy()
27 | img_array = img_array.transpose(1, 2, 0)
28 | if size is not None:
29 | img_array = cv.resize(img_array, size, interpolation=cv.INTER_LINEAR)
30 | if normal:
31 | # img_array = (img_array - img_array.min()) / (img_array.max() - img_array.min())
32 | img_array = (img_array + 1.) / 2. * 255.
33 | img_array = img_array.clip(0, 255)
34 | if save_path:
35 | if img_array.max() <= 1:
36 | img_array = (img_array * 255).astype(np.uint8)
37 | io.imsave(save_path, img_array)
38 |
39 | return img_array.astype(np.uint8)
40 |
41 |
42 | def tensor_to_numpy(tensor):
43 | return tensor.data.cpu().numpy()
44 |
45 |
46 | def batch_numpy_to_image(array, size=None):
47 | """
48 | Input: numpy array (B, C, H, W) in [-1, 1]
49 | """
50 | if isinstance(size, int):
51 | size = (size, size)
52 |
53 | out_imgs = []
54 | array = np.clip((array + 1)/2 * 255, 0, 255)
55 | array = np.transpose(array, (0, 2, 3, 1))
56 | for i in range(array.shape[0]):
57 | if size is not None:
58 | tmp_array = cv.resize(array[i], size)
59 | else:
60 | tmp_array = array[i]
61 | out_imgs.append(tmp_array)
62 | return np.array(out_imgs)
63 |
64 |
65 | def batch_tensor_to_img(tensor, size=None):
66 | """
67 | Input: (B, C, H, W)
68 | Return: RGB image, [0, 255]
69 | """
70 | arrays = tensor_to_numpy(tensor)
71 | out_imgs = batch_numpy_to_image(arrays, size)
72 | return out_imgs
73 |
74 |
75 | def mkdirs(paths):
76 | if isinstance(paths, list) and not isinstance(paths, str):
77 | for path in paths:
78 | if not os.path.exists(path):
79 | os.makedirs(path)
80 | else:
81 | if not os.path.exists(paths):
82 | os.makedirs(paths)
83 |
84 |
85 | def get_gpu_memory_map():
86 | """Get the current gpu usage within visible cuda devices.
87 |
88 | Returns
89 | -------
90 | Memory Map: dict
91 | Keys are device ids as integers.
92 | Values are memory usage as integers in MB.
93 | Device Ids: gpu ids sorted in descending order according to the available memory.
94 | """
95 | result = subprocess.check_output(
96 | [
97 | 'nvidia-smi', '--query-gpu=memory.used',
98 | '--format=csv,nounits,noheader'
99 | ]).decode('utf-8')
100 | # Convert lines into a dictionary
101 | gpu_memory = np.array([int(x) for x in result.strip().split('\n')])
102 | if 'CUDA_VISIBLE_DEVICES' in os.environ:
103 | visible_devices = sorted([int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(',')])
104 | else:
105 | visible_devices = range(len(gpu_memory))
106 | gpu_memory_map = dict(zip(range(len(visible_devices)), gpu_memory[visible_devices]))
107 | return gpu_memory_map, sorted(gpu_memory_map, key=gpu_memory_map.get)
108 |
109 |
110 |
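Usage note: get_gpu_memory_map above shells out to nvidia-smi, so it only works on a machine with NVIDIA drivers installed. BaseOptions.parse uses the second return value to pick the least-loaded visible GPUs; a minimal sketch:

    mem_map, ids_by_usage = get_gpu_memory_map()
    print(mem_map)           # e.g. {0: 10500, 1: 350} -> MB currently used per visible device
    print(ids_by_usage[:1])  # id(s) of the device(s) with the least memory in use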
--------------------------------------------------------------------------------