├── LICENSE
├── README.md
├── codes
│   ├── data
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── paired_folder_dataset.py
│   │   ├── paired_lmdb_dataset.py
│   │   ├── unpaired_folder_dataset.py
│   │   └── unpaired_lmdb_dataset.py
│   ├── main.py
│   ├── metrics
│   │   ├── LPIPS
│   │   │   ├── LICENSE
│   │   │   ├── __init__.py
│   │   │   └── models
│   │   │       ├── __init__.py
│   │   │       ├── base_model.py
│   │   │       ├── dist_model.py
│   │   │       ├── networks_basic.py
│   │   │       ├── pretrained_networks.py
│   │   │       └── weights
│   │   │           └── v0.1
│   │   │               ├── alex.pth
│   │   │               ├── squeeze.pth
│   │   │               └── vgg.pth
│   │   ├── __init__.py
│   │   ├── metric_calculator.py
│   │   └── model_summary.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── networks
│   │   │   ├── __init__.py
│   │   │   ├── base_nets.py
│   │   │   ├── tecogan_nets.py
│   │   │   └── vgg_nets.py
│   │   ├── optim
│   │   │   ├── __init__.py
│   │   │   ├── losses.py
│   │   │   └── lr_schedules.py
│   │   ├── vsr_model.py
│   │   └── vsrgan_model.py
│   ├── official_metrics
│   │   ├── LPIPSmodels
│   │   │   ├── LPIPSsource.txt
│   │   │   ├── base_model.py
│   │   │   ├── dist_model.py
│   │   │   ├── networks_basic.py
│   │   │   ├── pretrained_networks.py
│   │   │   ├── util.py
│   │   │   └── v0.1
│   │   │       └── alex.pth
│   │   ├── evaluate.py
│   │   └── metrics.py
│   └── utils
│       ├── __init__.py
│       ├── base_utils.py
│       ├── data_utils.py
│       ├── dist_utils.py
│       └── net_utils.py
├── data
│   ├── meta
│   │   └── REDS
│   │       ├── test_list.txt
│   │       └── train_list.txt
│   └── put_data_here
├── experiments_BD
│   ├── FRVSR
│   │   ├── FRVSR_REDS_2xSR_2GPU
│   │   │   ├── test.yml
│   │   │   └── train.yml
│   │   ├── FRVSR_REDS_4xSR_2GPU
│   │   │   ├── test.yml
│   │   │   └── train.yml
│   │   └── FRVSR_VimeoTecoGAN_4xSR_2GPU
│   │       ├── test.yml
│   │       └── train.yml
│   └── TecoGAN
│       ├── TecoGAN_REDS_2xSR_2GPU
│       │   ├── test.yml
│       │   └── train.yml
│       ├── TecoGAN_REDS_4xSR_2GPU
│       │   ├── test.yml
│       │   └── train.yml
│       └── TecoGAN_VimeoTecoGAN_4xSR_2GPU
│           ├── test.yml
│           └── train.yml
├── experiments_BI
│   ├── FRVSR
│   │   └── FRVSR_VimeoTecoGAN_4xSR_2GPU
│   │       ├── test.yml
│   │       └── train.yml
│   └── TecoGAN
│       └── TecoGAN_VimeoTecoGAN_4xSR_2GPU
│           ├── test.yml
│           └── train.yml
├── pretrained_models
│   └── put_the_pretrained_models_here
├── profile.sh
├── resources
│   ├── benchmark.png
│   ├── bridge.gif
│   ├── fire.gif
│   ├── foliage.gif
│   ├── losses.png
│   ├── metrics.png
│   └── pond.gif
├── scripts
│   ├── create_lmdb.py
│   ├── download
│   │   ├── download_datasets.sh
│   │   └── download_models.sh
│   ├── generate_lr_bi.m
│   ├── monitor_training.py
│   └── resize_bd.py
├── test.sh
└── train.sh
/README.md:
--------------------------------------------------------------------------------
1 | # TecoGAN-PyTorch
2 |
3 | ### Introduction
4 | This is a PyTorch reimplementation of **TecoGAN**: **Te**mporally **Co**herent **GAN** for Video Super-Resolution (VSR). Please refer to the official TensorFlow implementation [TecoGAN-TensorFlow](https://github.com/thunil/TecoGAN) for more information.
5 |
 6 |
 7 | *(Demo GIF animations of the results are provided under `./resources`: `bridge.gif`, `pond.gif`, `foliage.gif`, `fire.gif`.)*
 8 |
18 | ### Updates
19 | - 11/2021: Supported 2x SR.
20 | - 10/2021: Supported model training/testing on the [REDS](https://seungjunnah.github.io/Datasets/reds.html) dataset.
21 | - 07/2021: Upgraded codebase to support multi-GPU training & testing.
22 |
23 |
24 |
25 | ### Features
 26 | - **Better Performance**: This repo provides models that are smaller in size yet deliver better performance than the official repo. See our [Benchmark](https://github.com/skycrapers/TecoGAN-PyTorch#benchmark).
 27 | - **Multiple Degradations**: This repo supports two types of degradation: BI (Matlab's `imresize` with the `bicubic` option) and BD (Gaussian blurring + down-sampling).
28 | - **Unified Framework**: This repo provides a unified framework for distortion-based and perception-based VSR methods.
29 |
30 |
31 |
32 | ### Contents
33 | 1. [Dependencies](#dependencies)
34 | 1. [Testing](#testing)
35 | 1. [Training](#training)
36 | 1. [Benchmark](#benchmark)
37 | 1. [License & Citation](#license--citation)
38 | 1. [Acknowledgements](#acknowledgements)
39 |
40 |
41 |
42 | ## Dependencies
43 | - Ubuntu >= 16.04
44 | - NVIDIA GPU + CUDA
45 | - Python >= 3.7
46 | - PyTorch >= 1.4.0
47 | - Python packages: numpy, matplotlib, opencv-python, pyyaml, lmdb
48 | - (Optional) Matlab >= R2016b
49 |
50 |
51 |
52 | ## Testing
53 |
 54 | **Note:** We apply different models according to the degradation type. The following steps are for `4xSR` under `BD` degradation. You can switch to `2xSR` or `BI` degradation by replacing `4x` with `2x` and `BD` with `BI` in the commands below.
55 |
 56 | 1. Download the official Vid4 and ToS3 datasets. In `BD` mode, only the ground-truth data is needed, since the LR input is generated from it on the fly.
57 | ```bash
58 | bash ./scripts/download/download_datasets.sh BD
59 | ```
60 | > You can manually download these datasets from Google Drive, and unzip them under `./data`.
61 | > * Vid4 Dataset [[Ground-Truth](https://drive.google.com/file/d/1T8TuyyOxEUfXzCanH5kvNH2iA8nI06Wj/view?usp=sharing)] [[Low Resolution (BD)](https://drive.google.com/file/d/1-5NFW6fEPUczmRqKHtBVyhn2Wge6j3ma/view?usp=sharing)] [[Low Resolution (BI)](https://drive.google.com/file/d/1Kg0VBgk1r9I1c4f5ZVZ4sbfqtVRYub91/view?usp=sharing)]
62 | > * ToS3 Dataset [[Ground-Truth](https://drive.google.com/file/d/1XoR_NVBR-LbZOA8fXh7d4oPV0M8fRi8a/view?usp=sharing)] [[Low Resolution (BD)](https://drive.google.com/file/d/1rDCe61kR-OykLyCo2Ornd2YgPnul2ffM/view?usp=sharing)] [[Low Resolution (BI)](https://drive.google.com/file/d/1FNuC0jajEjH9ycqDkH4cZQ3_eUqjxzzf/view?usp=sharing)]
63 |
 64 | The dataset structure is shown below.
65 | ```tex
66 | data
67 | ├─ Vid4
68 | ├─ GT # Ground-Truth (GT) sequences
69 | └─ calendar
70 | └─ ***.png
71 | ├─ Gaussian4xLR # Low Resolution (LR) sequences in BD degradation
72 | └─ calendar
73 | └─ ***.png
74 | └─ Bicubic4xLR # Low Resolution (LR) sequences in BI degradation
75 | └─ calendar
76 | └─ ***.png
77 | └─ ToS3
78 | ├─ GT
79 | ├─ Gaussian4xLR
80 | └─ Bicubic4xLR
81 | ```
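
> Below is a small sanity-check sketch (not part of this repo) for verifying the layout after unzipping. It mirrors how `PairedFolderDataset` (`codes/data/paired_folder_dataset.py`) pairs GT and LR sequences by folder name, and assumes the default paths shown above.

```python
import os
import os.path as osp

# hypothetical paths; adjust to the dataset/degradation you downloaded
gt_dir, lr_dir = './data/Vid4/GT', './data/Vid4/Gaussian4xLR'

# a sequence is used only if it exists in both the GT and LR folders
for seq in sorted(set(os.listdir(gt_dir)) & set(os.listdir(lr_dir))):
    n_gt = len(os.listdir(osp.join(gt_dir, seq)))
    n_lr = len(os.listdir(osp.join(lr_dir, seq)))
    assert n_gt == n_lr, f'{seq}: {n_gt} GT vs. {n_lr} LR frames'
    print(f'{seq}: {n_gt} frames')
```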
82 |
83 | 2. Download our pre-trained TecoGAN model.
84 | ```bash
85 | bash ./scripts/download/download_models.sh BD TecoGAN
86 | ```
87 | > You can download the model from [[BD-4x-Vimeo](https://drive.google.com/file/d/13FPxKE6q7tuRrfhTE7GB040jBeURBj58/view?usp=sharing)][[BI-4x-Vimeo](https://drive.google.com/file/d/1ie1F7wJcO4mhNWK8nPX7F0LgOoPzCwEu/view?usp=sharing)][[BD-4x-REDS](https://drive.google.com/file/d/1vMvMbv_BvC2G-qCcaOBkNnkMh_gLNe6q/view?usp=sharing)][[BD-2x-REDS](https://drive.google.com/file/d/1XN5D4hjNvitO9Kb3OrYiKGjwNU0b43ZI/view?usp=sharing)], and put it under `./pretrained_models`.
88 |
 89 | 3. Run TecoGAN for 4x SR. The results will be saved to `./results`. You can specify which model to load and how many GPUs to use in `test.sh`.
90 | ```bash
91 | bash ./test.sh BD TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU
92 | ```
93 |
 94 | 4. Evaluate the upsampled results using the official metrics. This code is borrowed from [TecoGAN-TensorFlow](https://github.com/thunil/TecoGAN), with minor modifications to support the BI degradation.
95 | ```bash
96 | python ./codes/official_metrics/evaluate.py -m TecoGAN_4x_BD_Vimeo_iter500K
97 | ```
98 |
 99 | 5. Profile the model (FLOPs, parameters, and speed). You can modify the last argument to specify the size of the LR video; see the note below the command for how it is interpreted.
100 | ```bash
101 | bash ./profile.sh BD TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU 3x134x320
102 | ```
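> The last argument is read as `channels x height x width` of the LR input. A minimal sketch of how `profile()` in `codes/main.py` derives the corresponding HR size (assuming 4x SR):

```python
scale = 4                          # 4x SR
lr_size = '3x134x320'              # channels x height x width of the LR video
c, h, w = map(int, lr_size.split('x'))
hr_size = f'{c}x{h * scale}x{w * scale}'
print(f'{lr_size} -> {hr_size}')   # 3x134x320 -> 3x536x1280
```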
103 |
104 | ## Training
105 | **Note:** Due to the inaccessibility of the VimeoTecoGAN dataset, we recommend using other public datasets, e.g., REDS, for model training. To use REDS as the training dataset, download it from [here](https://seungjunnah.github.io/Datasets/reds.html) and replace `VimeoTecoGAN` with `REDS` in the following commands.
106 |
107 | 1. Download the official training dataset according to the instructions in [TecoGAN-TensorFlow](https://github.com/thunil/TecoGAN), rename it to `VimeoTecoGAN/Raw`, and place it under `./data`.
108 |
109 | 2. Generate an LMDB for the GT data to accelerate IO. The LR counterpart will be generated on the fly during training.
110 | ```bash
111 | python ./scripts/create_lmdb.py --dataset VimeoTecoGAN --raw_dir ./data/VimeoTecoGAN/Raw --lmdb_dir ./data/VimeoTecoGAN/GT.lmdb
112 | ```
113 |
114 | The following shows the dataset structure after finishing the above two steps.
115 | ```tex
116 | data
117 | ├─ VimeoTecoGAN
118 | ├─ Raw # Raw dataset
119 | ├─ scene_2000
120 | └─ ***.png
121 | ├─ scene_2001
122 | └─ ***.png
123 | └─ ...
124 | └─ GT.lmdb # LMDB dataset
125 | ├─ data.mdb
126 | ├─ lock.mdb
127 | └─ meta_info.pkl # each key has format: [vid]_[total_frame]x[h]x[w]_[i-th_frame]
128 | ```
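
> For reference, a minimal sketch of how such a key is decoded (mirroring `parse_lmdb_key` in `codes/data/base_dataset.py`); the key value below is a made-up example.

```python
key = 'scene_2000_120x1080x1920_0003'          # [vid]_[total_frame]x[h]x[w]_[i-th_frame]

parts = key.split('_')
vid = '_'.join(parts[:-2])                     # 'scene_2000' (video ids may contain '_')
n_frm, h, w = map(int, parts[-2].split('x'))   # 120 frames, each of size 1080x1920
frm_idx = int(parts[-1])                       # 3, i.e., the 4th frame (zero-based)
print(vid, (n_frm, h, w), frm_idx)
```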
129 |
130 | 3. **(Optional, only required for BI degradation)** Manually generate the LR sequences with Matlab's `imresize` function, and then create an LMDB for them.
131 | ```bash
132 | # Generate the raw LR video sequences. Results will be saved at ./data/VimeoTecoGAN/Bicubic4xLR
133 | matlab -nodesktop -nosplash -r "cd ./scripts; generate_lr_bi"
134 |
135 | # Create LMDB for the LR video sequences
136 | python ./scripts/create_lmdb.py --dataset VimeoTecoGAN --raw_dir ./data/VimeoTecoGAN/Bicubic4xLR --lmdb_dir ./data/VimeoTecoGAN/Bicubic4xLR.lmdb
137 | ```
138 |
139 | 4. Train an FRVSR model first, which provides a better initialization for the subsequent TecoGAN training. FRVSR has the same generator as TecoGAN, but is trained without the GAN and perceptual losses.
140 | ```bash
141 | bash ./train.sh BD FRVSR/FRVSR_VimeoTecoGAN_4xSR_2GPU
142 | ```
143 | > You can download and use our pre-trained FRVSR models instead of training from scratch. [[BD-4x-Vimeo](https://drive.google.com/file/d/11kPVS04a3B3k0SD-mKEpY_Q8WL7KrTIA/view?usp=sharing)] [[BI-4x-Vimeo](https://drive.google.com/file/d/1wejMAFwIBde_7sz-H7zwlOCbCvjt3G9L/view?usp=sharing)] [[BD-4x-REDS](https://drive.google.com/file/d/1YyTwBFF6P9xy6b9UBILF4ornCdmWbDLY/view?usp=sharing)][[BD-2x-REDS](https://drive.google.com/file/d/1ibsr3td1rYeKsDc2d-J9-8jFURBFc_ST/view?usp=sharing)]
144 |
145 |    When training is complete, set the generator's `load_path` in `experiments_BD/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/train.yml` to the path of the latest FRVSR checkpoint.
146 |
147 | 5. Train a TecoGAN model. You can specify which GPUs to use in `train.sh`. By default, training runs in the background and its output is logged to `./experiments_BD/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/train/train.log`.
148 | ```bash
149 | bash ./train.sh BD TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU
150 | ```
151 |
152 | 6. Run the following script to monitor the training process and visualize the validation performance.
153 | ```bash
154 | python ./scripts/monitor_training.py -dg BD -m TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU -ds Vid4
155 | ```
156 | > Note that the validation results are NOT exactly the same as the testing results mentioned above, due to differences in the metric implementations (e.g., cropping policy, LPIPS version, and a few other details).
157 |
158 |
159 | *(Example outputs of the monitoring script: see `resources/losses.png` and `resources/metrics.png`.)*
160 |
164 |
165 | ## Benchmark
166 |
167 | *(See `resources/benchmark.png` for the benchmark comparison figure.)*
168 |
171 | > [1] FLOPs & speed are computed on an RGB sequence of resolution 134x320 on a single NVIDIA 1080Ti GPU. \
172 | > [2] Both FRVSR & TecoGAN use 10 residual blocks, while TecoGAN+ has 16 residual blocks.
173 |
174 |
175 |
176 | ## License & Citation
177 | If you use this code for your research, please cite the following paper and our project.
178 | ```tex
179 | @article{tecogan2020,
180 | title={Learning temporal coherence via self-supervision for GAN-based video generation},
181 | author={Chu, Mengyu and Xie, You and Mayer, Jonas and Leal-Taix{\'e}, Laura and Thuerey, Nils},
182 | journal={ACM Transactions on Graphics (TOG)},
183 | volume={39},
184 | number={4},
185 | pages={75--1},
186 | year={2020},
187 | publisher={ACM New York, NY, USA}
188 | }
189 | ```
190 | ```tex
191 | @misc{tecogan_pytorch,
192 | author={Deng, Jianing and Zhuo, Cheng},
193 | title={PyTorch Implementation of Temporally Coherent GAN (TecoGAN) for Video Super-Resolution},
194 | howpublished="\url{https://github.com/skycrapers/TecoGAN-PyTorch}",
195 | year={2020},
196 | }
197 | ```
198 |
199 |
200 |
201 | ## Acknowledgements
202 | This code is built on [TecoGAN-TensorFlow](https://github.com/thunil/TecoGAN), [BasicSR](https://github.com/xinntao/BasicSR) and [LPIPS](https://github.com/richzhang/PerceptualSimilarity). We thank the authors for sharing their codes.
203 |
204 | If you have any questions, feel free to email me at `jn.deng@foxmail.com`.
205 |
--------------------------------------------------------------------------------
/codes/data/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.distributed import DistributedSampler
4 |
5 | from .paired_lmdb_dataset import PairedLMDBDataset
6 | from .unpaired_lmdb_dataset import UnpairedLMDBDataset
7 | from .paired_folder_dataset import PairedFolderDataset
8 | from .unpaired_folder_dataset import UnpairedFolderDataset
9 |
10 |
11 | def create_dataloader(opt, phase, idx):
12 | # set params
13 | data_opt = opt['dataset'].get(idx)
14 | degradation_type = opt['dataset']['degradation']['type']
15 |
16 | # === create loader for training === #
17 | if phase == 'train':
18 | # check dataset
19 | assert data_opt['name'] in ('VimeoTecoGAN', 'REDS'), \
20 | f'Unknown Dataset: {data_opt["name"]}'
21 |
22 | if degradation_type == 'BI':
23 | # create dataset
24 | dataset = PairedLMDBDataset(
25 | data_opt,
26 | scale=opt['scale'],
27 | tempo_extent=opt['train']['tempo_extent'],
28 | moving_first_frame=opt['train'].get('moving_first_frame', False),
29 | moving_factor=opt['train'].get('moving_factor', 1.0))
30 |
31 | elif degradation_type == 'BD':
32 | # enlarge the crop size to incorporate border
33 | sigma = opt['dataset']['degradation']['sigma']
34 | enlarged_crop_size = data_opt['crop_size'] + 2 * int(sigma * 3.0)
35 |
36 | # create dataset
37 | dataset = UnpairedLMDBDataset(
38 | data_opt,
39 | crop_size=enlarged_crop_size, # override
40 | tempo_extent=opt['train']['tempo_extent'],
41 | moving_first_frame=opt['train'].get('moving_first_frame', False),
42 | moving_factor=opt['train'].get('moving_factor', 1.0))
43 |
44 | else:
45 | raise ValueError(f'Unrecognized degradation type: {degradation_type}')
46 |
47 | # create data loader
48 | if opt['dist']:
49 | batch_size = data_opt['batch_size_per_gpu']
50 | shuffle = False
51 | sampler = DistributedSampler(dataset)
52 | else:
53 | batch_size = data_opt['batch_size_per_gpu']
54 | shuffle = True
55 | sampler = None
56 |
57 | loader = DataLoader(
58 | dataset=dataset,
59 | batch_size=batch_size,
60 | shuffle=shuffle,
61 | drop_last=True,
62 | sampler=sampler,
63 | num_workers=data_opt['num_worker_per_gpu'],
64 | pin_memory=data_opt['pin_memory'])
65 |
66 | # === create loader for testing === #
67 | elif phase == 'test':
68 | # create data loader
69 | if 'lr_seq_dir' in data_opt and data_opt['lr_seq_dir']:
70 | loader = DataLoader(
71 | dataset=PairedFolderDataset(data_opt),
72 | batch_size=1,
73 | shuffle=False,
74 | num_workers=data_opt['num_worker_per_gpu'],
75 | pin_memory=data_opt['pin_memory'])
76 |
77 | else:
78 | assert degradation_type == 'BD', \
79 | '"lr_seq_dir" is required for BI mode'
80 |
81 | sigma = opt['dataset']['degradation']['sigma']
82 | ksize = 2 * int(sigma * 3.0) + 1
83 |
84 | loader = DataLoader(
85 | dataset=UnpairedFolderDataset(
86 | data_opt, scale=opt['scale'], sigma=sigma, ksize=ksize),
87 | batch_size=1,
88 | shuffle=False,
89 | num_workers=data_opt['num_worker_per_gpu'],
90 | pin_memory=data_opt['pin_memory'])
91 |
92 | else:
93 | raise ValueError(f'Unrecognized phase: {phase}')
94 |
95 | return loader
96 |
--------------------------------------------------------------------------------
/codes/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import lmdb
2 |
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class BaseDataset(Dataset):
8 | def __init__(self, data_opt, **kwargs):
9 | # dict to attr
10 | for kw, args in data_opt.items():
11 | setattr(self, kw, args)
12 |
13 | # can be used to override options defined in data_opt
14 | for kw, args in kwargs.items():
15 | setattr(self, kw, args)
16 |
17 | def __len__(self):
18 | pass
19 |
20 | def __getitem__(self, item):
21 | pass
22 |
23 | def check_info(self, gt_keys, lr_keys):
24 | if len(gt_keys) != len(lr_keys):
25 | raise ValueError(
26 | f'GT & LR contain different numbers of images ({len(gt_keys)} vs. {len(lr_keys)})')
27 |
28 | for i, (gt_key, lr_key) in enumerate(zip(gt_keys, lr_keys)):
29 | gt_info = self.parse_lmdb_key(gt_key)
30 | lr_info = self.parse_lmdb_key(lr_key)
31 |
32 | if gt_info[0] != lr_info[0]:
33 | raise ValueError(
 34 |                     f'video index mismatch ({gt_info[0]} vs. {lr_info[0]} for the {i}-th key)')
35 |
36 | gt_num, gt_h, gt_w = gt_info[1]
37 | lr_num, lr_h, lr_w = lr_info[1]
38 | s = self.scale
39 | if (gt_num != lr_num) or (gt_h != lr_h * s) or (gt_w != lr_w * s):
40 | raise ValueError(
 41 |                     f'video size mismatch ({gt_info[1]} vs. {lr_info[1]} for the {i}-th key)')
42 |
43 | if gt_info[2] != lr_info[2]:
44 | raise ValueError(
 45 |                     f'frame mismatch ({gt_info[2]} vs. {lr_info[2]} for the {i}-th key)')
46 |
47 | @staticmethod
48 | def init_lmdb(seq_dir):
49 | env = lmdb.open(
50 | seq_dir, readonly=True, lock=False, readahead=False, meminit=False)
51 | return env
52 |
53 | @staticmethod
54 | def parse_lmdb_key(key):
55 | key_lst = key.split('_')
56 | idx, size, frm = key_lst[:-2], key_lst[-2], int(key_lst[-1])
57 | idx = '_'.join(idx)
58 | size = tuple(map(int, size.split('x'))) # n_frm, h, w
59 | return idx, size, frm
60 |
61 | @staticmethod
62 | def read_lmdb_frame(env, key, size):
63 | with env.begin(write=False) as txn:
64 | buf = txn.get(key.encode('ascii'))
65 | frm = np.frombuffer(buf, dtype=np.uint8).reshape(*size)
66 | return frm
67 |
68 | def crop_sequence(self, **kwargs):
69 | pass
70 |
71 | @staticmethod
72 | def augment_sequence(**kwargs):
73 | pass
74 |
--------------------------------------------------------------------------------
/codes/data/paired_folder_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 |
8 | from .base_dataset import BaseDataset
9 | from utils.base_utils import retrieve_files
10 |
11 |
12 | class PairedFolderDataset(BaseDataset):
13 | """ Folder dataset for paired data. It supports both BI & BD degradation.
14 | """
15 |
16 | def __init__(self, data_opt, **kwargs):
17 | super(PairedFolderDataset, self).__init__(data_opt, **kwargs)
18 |
19 | # get keys
20 | gt_keys = sorted(os.listdir(self.gt_seq_dir))
21 | lr_keys = sorted(os.listdir(self.lr_seq_dir))
22 | self.keys = sorted(list(set(gt_keys) & set(lr_keys)))
23 |
24 | # filter keys
25 | sel_keys = set(self.keys)
26 | if hasattr(self, 'filter_file') and self.filter_file is not None:
27 | with open(self.filter_file, 'r') as f:
28 | sel_keys = {line.strip() for line in f}
29 | elif hasattr(self, 'filter_list') and self.filter_list is not None:
30 | sel_keys = set(self.filter_list)
31 | self.keys = sorted(list(sel_keys & set(self.keys)))
32 |
33 | def __len__(self):
34 | return len(self.keys)
35 |
36 | def __getitem__(self, item):
37 | key = self.keys[item]
38 |
39 | # load gt frames
40 | gt_seq = []
41 | for frm_path in retrieve_files(osp.join(self.gt_seq_dir, key)):
42 | frm = cv2.imread(frm_path)[..., ::-1]
43 | gt_seq.append(frm)
44 | gt_seq = np.stack(gt_seq) # thwc|rgb|uint8
45 |
46 | # load lr frames
47 | lr_seq = []
48 | for frm_path in retrieve_files(osp.join(self.lr_seq_dir, key)):
49 | frm = cv2.imread(frm_path)[..., ::-1].astype(np.float32) / 255.0
50 | lr_seq.append(frm)
51 | lr_seq = np.stack(lr_seq) # thwc|rgb|float32
52 |
53 | # convert to tensor
54 | gt_tsr = torch.from_numpy(np.ascontiguousarray(gt_seq)) # uint8
55 | lr_tsr = torch.from_numpy(np.ascontiguousarray(lr_seq)) # float32
56 |
57 | # gt: thwc|rgb|uint8 | lr: thwc|rgb|float32
58 | return {
59 | 'gt': gt_tsr,
60 | 'lr': lr_tsr,
61 | 'seq_idx': key,
62 | 'frm_idx': sorted(os.listdir(osp.join(self.gt_seq_dir, key)))
63 | }
64 |
--------------------------------------------------------------------------------
/codes/data/paired_lmdb_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pickle
3 | import random
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 |
9 | from .base_dataset import BaseDataset
10 |
11 |
12 | class PairedLMDBDataset(BaseDataset):
13 | """ LMDB dataset for paired data (for BI degradation)
14 | """
15 |
16 | def __init__(self, data_opt, **kwargs):
17 | super(PairedLMDBDataset, self).__init__(data_opt, **kwargs)
18 |
19 | # load meta info
20 | gt_meta = pickle.load(
21 | open(osp.join(self.gt_seq_dir, 'meta_info.pkl'), 'rb'))
22 | lr_meta = pickle.load(
23 | open(osp.join(self.lr_seq_dir, 'meta_info.pkl'), 'rb'))
24 | gt_keys = sorted(gt_meta['keys'])
25 | lr_keys = sorted(lr_meta['keys'])
26 |
27 | self.check_info(gt_keys, lr_keys)
28 | self.gt_lr_keys = list(zip(gt_keys, lr_keys))
29 |
30 | # use partial videos
31 | if hasattr(self, 'filter_file') and self.filter_file is not None:
32 | with open(self.filter_file, 'r') as f:
33 | sel_seqs = { line.strip() for line in f }
34 | self.gt_lr_keys = list(filter(
35 | lambda x: self.parse_lmdb_key(x[0])[0] in sel_seqs, self.gt_lr_keys))
36 |
37 | # register parameters
38 | self.gt_env = None
39 | self.lr_env = None
40 |
41 | def __len__(self):
42 | return len(self.gt_lr_keys)
43 |
44 | def __getitem__(self, item):
45 | if self.gt_env is None:
46 | self.gt_env = self.init_lmdb(self.gt_seq_dir)
47 | if self.lr_env is None:
48 | self.lr_env = self.init_lmdb(self.lr_seq_dir)
49 |
50 | # parse info
51 | gt_key, lr_key = self.gt_lr_keys[item]
52 | idx, (tot_frm, gt_h, gt_w), cur_frm = self.parse_lmdb_key(gt_key)
53 | _, (_, lr_h, lr_w), _ = self.parse_lmdb_key(lr_key)
54 |
55 | c = 3 if self.data_type.lower() == 'rgb' else 1
56 | s = self.scale
57 | assert (gt_h == lr_h * s) and (gt_w == lr_w * s)
58 |
59 | # get frames
60 | gt_frms, lr_frms = [], []
61 | if self.moving_first_frame and (random.uniform(0, 1) > self.moving_factor):
62 | # load the first gt&lr frame
63 | gt_frm = self.read_lmdb_frame(
64 | self.gt_env, gt_key, size=(gt_h, gt_w, c))
65 | gt_frm = gt_frm.transpose(2, 0, 1) # chw|rgb|uint8
66 |
67 | lr_frm = self.read_lmdb_frame(
68 | self.lr_env, lr_key, size=(lr_h, lr_w, c))
69 | lr_frm = lr_frm.transpose(2, 0, 1) # chw|rgb|uint8
70 |
71 | # generate random moving parameters
72 | offsets = np.floor(
73 | np.random.uniform(-1.5, 1.5, size=(self.tempo_extent, 2)))
74 | offsets = offsets.astype(np.int32)
75 | pos = np.cumsum(offsets, axis=0)
76 | min_pos = np.min(pos, axis=0)
77 | topleft_pos = pos - min_pos
78 | range_pos = np.max(pos, axis=0) - min_pos
79 | c_h, c_w = lr_h - range_pos[0], lr_w - range_pos[1]
80 |
81 | # generate frames
82 | for i in range(self.tempo_extent):
83 | lr_top, lr_left = topleft_pos[i]
84 | lr_frms.append(lr_frm[
85 | :, lr_top: lr_top + c_h, lr_left: lr_left + c_w].copy())
86 |
87 | gt_top, gt_left = lr_top * s, lr_left * s
88 | gt_frms.append(gt_frm[
89 | :, gt_top: gt_top + c_h * s, gt_left: gt_left + c_w * s].copy())
90 | else:
91 | # read frames
92 | for i in range(cur_frm, cur_frm + self.tempo_extent):
93 | if i >= tot_frm:
 94 |                     # reflect temporal padding, e.g., (0,1,2) -> (0,1,2,1,0)
95 | gt_key = '{}_{}x{}x{}_{:04d}'.format(
96 | idx, tot_frm, gt_h, gt_w, 2 * tot_frm - i - 2)
97 | lr_key = '{}_{}x{}x{}_{:04d}'.format(
98 | idx, tot_frm, lr_h, lr_w, 2 * tot_frm - i - 2)
99 | else:
100 | gt_key = '{}_{}x{}x{}_{:04d}'.format(
101 | idx, tot_frm, gt_h, gt_w, i)
102 | lr_key = '{}_{}x{}x{}_{:04d}'.format(
103 | idx, tot_frm, lr_h, lr_w, i)
104 |
105 | gt_frm = self.read_lmdb_frame(
106 | self.gt_env, gt_key, size=(gt_h, gt_w, c))
107 | gt_frm = gt_frm.transpose(2, 0, 1) # chw|rgb|uint8
108 | gt_frms.append(gt_frm)
109 |
110 | lr_frm = self.read_lmdb_frame(
111 | self.lr_env, lr_key, size=(lr_h, lr_w, c))
112 | lr_frm = lr_frm.transpose(2, 0, 1)
113 | lr_frms.append(lr_frm)
114 |
115 | gt_frms = np.stack(gt_frms) # tchw|rgb|uint8
116 | lr_frms = np.stack(lr_frms)
117 |
118 | # crop randomly
119 | gt_pats, lr_pats = self.crop_sequence(gt_frms, lr_frms)
120 |
121 | # augment patches
122 | gt_pats, lr_pats = self.augment_sequence(gt_pats, lr_pats)
123 |
124 | # convert to tensor and normalize to range [0, 1]
125 | gt_tsr = torch.FloatTensor(np.ascontiguousarray(gt_pats)) / 255.0
126 | lr_tsr = torch.FloatTensor(np.ascontiguousarray(lr_pats)) / 255.0
127 |
128 | # tchw|rgb|float32
129 | return {'gt': gt_tsr, 'lr': lr_tsr}
130 |
131 | def crop_sequence(self, gt_frms, lr_frms):
132 | gt_csz = self.gt_crop_size
133 | lr_csz = self.gt_crop_size // self.scale
134 |
135 | lr_h, lr_w = lr_frms.shape[-2:]
136 | assert (lr_csz <= lr_h) and (lr_csz <= lr_w), \
137 | 'the crop size is larger than the image size'
138 |
139 | # crop lr
140 | lr_top = random.randint(0, lr_h - lr_csz)
141 | lr_left = random.randint(0, lr_w - lr_csz)
142 | lr_pats = lr_frms[
143 | ..., lr_top: lr_top + lr_csz, lr_left: lr_left + lr_csz]
144 |
145 | # crop gt
146 | gt_top = lr_top * self.scale
147 | gt_left = lr_left * self.scale
148 | gt_pats = gt_frms[
149 | ..., gt_top: gt_top + gt_csz, gt_left: gt_left + gt_csz]
150 |
151 | return gt_pats, lr_pats
152 |
153 | @staticmethod
154 | def augment_sequence(gt_pats, lr_pats):
155 | # flip
156 | axis = random.randint(1, 3)
157 | if axis > 1:
158 | gt_pats = np.flip(gt_pats, axis)
159 | lr_pats = np.flip(lr_pats, axis)
160 |
161 | # rotate 90 degree
162 | k = random.randint(0, 3)
163 | gt_pats = np.rot90(gt_pats, k, (2, 3))
164 | lr_pats = np.rot90(lr_pats, k, (2, 3))
165 |
166 | return gt_pats, lr_pats
167 |
--------------------------------------------------------------------------------
/codes/data/unpaired_folder_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 |
8 | from .base_dataset import BaseDataset
9 | from utils.base_utils import retrieve_files
10 |
11 |
12 | class UnpairedFolderDataset(BaseDataset):
13 | """ Folder dataset for unpaired data (for BD degradation)
14 | """
15 |
16 | def __init__(self, data_opt, **kwargs):
17 | super(UnpairedFolderDataset, self).__init__(data_opt, **kwargs)
18 |
19 | # get keys
20 | self.keys = sorted(os.listdir(self.gt_seq_dir))
21 |
22 | # filter keys
23 | sel_keys = set(self.keys)
24 | if hasattr(self, 'filter_file') and self.filter_file is not None:
25 | with open(self.filter_file, 'r') as f:
26 | sel_keys = {line.strip() for line in f}
27 | elif hasattr(self, 'filter_list') and self.filter_list is not None:
28 | sel_keys = set(self.filter_list)
29 | self.keys = sorted(list(sel_keys & set(self.keys)))
30 |
31 | def __len__(self):
32 | return len(self.keys)
33 |
34 | def __getitem__(self, item):
35 | key = self.keys[item]
36 |
37 | # load gt frames
38 | gt_seq = []
39 | for frm_path in retrieve_files(osp.join(self.gt_seq_dir, key)):
40 | gt_frm = cv2.imread(frm_path)[..., ::-1]
41 | gt_seq.append(gt_frm)
42 | gt_seq = np.stack(gt_seq) # thwc|rgb|uint8
43 |
44 | # convert to tensor
45 | gt_tsr = torch.from_numpy(np.ascontiguousarray(gt_seq)) # uint8
46 |
47 | # gt: thwc|rgb|uint8
48 | return {
49 | 'gt': gt_tsr,
50 | 'seq_idx': key,
51 | 'frm_idx': sorted(os.listdir(osp.join(self.gt_seq_dir, key)))
52 | }
53 |
--------------------------------------------------------------------------------
/codes/data/unpaired_lmdb_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pickle
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from .base_dataset import BaseDataset
9 |
10 |
11 | class UnpairedLMDBDataset(BaseDataset):
12 | """ LMDB dataset for unpaired data (for BD degradation)
13 | """
14 |
15 | def __init__(self, data_opt, **kwargs):
16 | super(UnpairedLMDBDataset, self).__init__(data_opt, **kwargs)
17 |
18 | # load meta info
19 | meta = pickle.load(
20 | open(osp.join(self.seq_dir, 'meta_info.pkl'), 'rb'))
21 | self.keys = sorted(meta['keys'])
22 |
23 | # use partial videos
24 | if hasattr(self, 'filter_file') and self.filter_file is not None:
25 | with open(self.filter_file, 'r') as f:
26 | sel_seqs = { line.strip() for line in f }
27 | self.keys = list(filter(
28 | lambda x: self.parse_lmdb_key(x)[0] in sel_seqs, self.keys))
29 |
30 | # register parameters
31 | self.env = None
32 |
33 | def __len__(self):
34 | return len(self.keys)
35 |
36 | def __getitem__(self, item):
37 | if self.env is None:
38 | self.env = self.init_lmdb(self.seq_dir)
39 |
40 | # parse info
41 | key = self.keys[item]
42 | idx, (tot_frm, h, w), cur_frm = self.parse_lmdb_key(key)
43 | c = 3 if self.data_type.lower() == 'rgb' else 1
44 |
45 | # get frames
46 | frms = []
47 | if self.moving_first_frame and (random.uniform(0, 1) > self.moving_factor):
48 | # load data
49 | frm = self.read_lmdb_frame(self.env, key, size=(h, w, c))
50 | frm = frm.transpose(2, 0, 1) # chw|rgb|uint8
51 |
52 | # generate random moving parameters
53 | offsets = np.floor(
54 | np.random.uniform(-3.5, 4.5, size=(self.tempo_extent, 2)))
55 | offsets = offsets.astype(np.int32)
56 | pos = np.cumsum(offsets, axis=0)
57 | min_pos = np.min(pos, axis=0)
58 | topleft_pos = pos - min_pos
59 | range_pos = np.max(pos, axis=0) - min_pos
60 | c_h, c_w = h - range_pos[0], w - range_pos[1]
61 |
62 | # generate frames
63 | for i in range(self.tempo_extent):
64 | top, left = topleft_pos[i]
65 | frms.append(frm[:, top: top + c_h, left: left + c_w].copy())
66 | else:
67 | # read frames
68 | for i in range(cur_frm, cur_frm + self.tempo_extent):
69 | if i >= tot_frm:
 70 |                     # reflect temporal padding, e.g., (0,1,2) -> (0,1,2,1,0)
71 | key = '{}_{}x{}x{}_{:04d}'.format(
72 | idx, tot_frm, h, w, 2 * tot_frm - i - 2)
73 | else:
74 | key = '{}_{}x{}x{}_{:04d}'.format(
75 | idx, tot_frm, h, w, i)
76 |
77 | frm = self.read_lmdb_frame(self.env, key, size=(h, w, c))
78 | frm = frm.transpose(2, 0, 1) # chw|rgb|uint8
79 | frms.append(frm)
80 |
81 | frms = np.stack(frms) # tchw|rgb|uint8
82 |
83 | # crop randomly
84 | pats = self.crop_sequence(frms)
85 |
86 | # augment patches
87 | pats = self.augment_sequence(pats)
88 |
89 | # convert to tensor and normalize to range [0, 1]
90 | tsr = torch.FloatTensor(np.ascontiguousarray(pats)) / 255.0
91 |
92 | # tchw|rgb|float32
93 | return {'gt': tsr}
94 |
95 | def crop_sequence(self, frms):
96 | csz = self.crop_size
97 |
98 | h, w = frms.shape[-2:]
99 | assert (csz <= h) and (csz <= w), \
100 | f'The crop size is larger than the image size ({csz} vs. h{h}/w{w})'
101 |
102 | # crop
103 | top = random.randint(0, h - csz)
104 | left = random.randint(0, w - csz)
105 | pats = frms[..., top: top + csz, left: left + csz]
106 |
107 | return pats
108 |
109 | @staticmethod
110 | def augment_sequence(pats):
111 | # flip spatially
112 | axis = random.randint(1, 3)
113 | if axis > 1:
114 | pats = np.flip(pats, axis)
115 |
116 | # flip temporally
117 | axis = random.randint(0, 1)
118 | if axis < 1:
119 | pats = np.flip(pats, axis)
120 |
121 | # rotate
122 | k = random.randint(0, 3)
123 | pats = np.rot90(pats, k, (2, 3))
124 |
125 | return pats
126 |
--------------------------------------------------------------------------------
/codes/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import math
4 | import time
5 |
6 | import torch
7 |
8 | from data import create_dataloader
9 | from models import define_model
10 | from metrics import create_metric_calculator
11 | from utils import dist_utils, base_utils, data_utils
12 |
13 |
14 | def train(opt):
15 | # print configurations
16 | base_utils.log_info(f'{20*"-"} Configurations {20*"-"}')
17 | base_utils.print_options(opt)
18 |
19 | # create data loader
20 | train_loader = create_dataloader(opt, phase='train', idx='train')
21 |
22 | # build model
23 | model = define_model(opt)
24 |
25 | # set training params
26 | total_sample, iter_per_epoch = len(train_loader.dataset), len(train_loader)
27 | total_iter = opt['train']['total_iter']
28 | total_epoch = int(math.ceil(total_iter / iter_per_epoch))
29 | start_iter, iter = opt['train']['start_iter'], 0
30 | test_freq = opt['test']['test_freq']
31 | log_freq = opt['logger']['log_freq']
32 | ckpt_freq = opt['logger']['ckpt_freq']
33 |
34 | base_utils.log_info(f'Number of the training samples: {total_sample}')
35 | base_utils.log_info(f'{total_epoch} epochs needed for {total_iter} iterations')
36 |
37 | # train
38 | for epoch in range(total_epoch):
39 | if opt['dist']:
40 | train_loader.sampler.set_epoch(epoch)
41 |
42 | for data in train_loader:
43 | # update iter
44 | iter += 1
45 | curr_iter = start_iter + iter
46 | if iter > total_iter: break
47 |
48 | # prepare data
49 | model.prepare_training_data(data)
50 |
51 | # train a mini-batch
52 | model.train()
53 |
54 | # update running log
55 | model.update_running_log()
56 |
57 | # update learning rate
58 | model.update_learning_rate()
59 |
60 | # print messages
61 | if log_freq > 0 and curr_iter % log_freq == 0:
62 | msg = model.get_format_msg(epoch, curr_iter)
63 | base_utils.log_info(msg)
64 |
65 | # save model
66 | if ckpt_freq > 0 and curr_iter % ckpt_freq == 0:
67 | model.save(curr_iter)
68 |
69 | # evaluate model
70 | if test_freq > 0 and curr_iter % test_freq == 0:
71 | # set model index
72 | model_idx = f'G_iter{curr_iter}'
73 |
74 | # for each testset
75 | for dataset_idx in sorted(opt['dataset'].keys()):
76 | # select test dataset
77 | if 'test' not in dataset_idx: continue
78 |
79 | ds_name = opt['dataset'][dataset_idx]['name']
80 | base_utils.log_info(f'Testing on {ds_name} dataset')
81 |
82 | # create data loader
83 | test_loader = create_dataloader(
84 | opt, phase='test', idx=dataset_idx)
85 | test_dataset = test_loader.dataset
86 | num_seq = len(test_dataset)
87 |
88 | # create metric calculator
89 | metric_calculator = create_metric_calculator(opt)
90 |
91 | # infer a sequence
92 | rank, world_size = dist_utils.get_dist_info()
93 | for idx in range(rank, num_seq, world_size):
94 | # fetch data
95 | data = test_dataset[idx]
96 |
97 | # prepare data
98 | model.prepare_inference_data(data)
99 |
100 | # infer
101 | hr_seq = model.infer()
102 |
103 | # save hr results
104 | if opt['test']['save_res']:
105 | res_dir = osp.join(
106 | opt['test']['res_dir'], ds_name, model_idx)
107 | res_seq_dir = osp.join(res_dir, data['seq_idx'])
108 | data_utils.save_sequence(
109 | res_seq_dir, hr_seq, data['frm_idx'], to_bgr=True)
110 |
111 | # compute metrics for the current sequence
112 | if metric_calculator is not None:
113 | gt_seq = data['gt'].numpy()
114 | metric_calculator.compute_sequence_metrics(
115 | data['seq_idx'], gt_seq, hr_seq)
116 |
117 | # save/print results
118 | if metric_calculator is not None:
119 | seq_idx_lst = [data['seq_idx'] for data in test_dataset]
120 | metric_calculator.gather(seq_idx_lst)
121 |
122 | if opt['test'].get('save_json'):
123 | # write results to a json file
124 | json_path = osp.join(
125 | opt['test']['json_dir'], f'{ds_name}_avg.json')
126 | metric_calculator.save(model_idx, json_path, override=True)
127 | else:
128 | # print directly
129 | metric_calculator.display()
130 |
131 |
132 | def test(opt):
133 | # logging
134 | base_utils.print_options(opt)
135 |
136 | # infer and evaluate performance for each model
137 | for load_path in opt['model']['generator']['load_path_lst']:
138 | # set model index
139 | model_idx = osp.splitext(osp.split(load_path)[-1])[0]
140 |
141 | # log
142 | base_utils.log_info(f'{"=" * 40}')
143 | base_utils.log_info(f'Testing model: {model_idx}')
144 | base_utils.log_info(f'{"=" * 40}')
145 |
146 | # create model
147 | opt['model']['generator']['load_path'] = load_path
148 | model = define_model(opt)
149 |
150 | # for each test dataset
151 | for dataset_idx in sorted(opt['dataset'].keys()):
152 | # select testing dataset
153 | if 'test' not in dataset_idx:
154 | continue
155 |
156 | ds_name = opt['dataset'][dataset_idx]['name']
157 | base_utils.log_info(f'Testing on {ds_name} dataset')
158 |
159 | # create data loader
160 | test_loader = create_dataloader(opt, phase='test', idx=dataset_idx)
161 | test_dataset = test_loader.dataset
162 | num_seq = len(test_dataset)
163 |
164 | # create metric calculator
165 | metric_calculator = create_metric_calculator(opt)
166 |
167 | # infer a sequence
168 | rank, world_size = dist_utils.get_dist_info()
169 | for idx in range(rank, num_seq, world_size):
170 | # fetch data
171 | data = test_dataset[idx]
172 |
173 | # prepare data
174 | model.prepare_inference_data(data)
175 |
176 | # infer
177 | hr_seq = model.infer()
178 |
179 | # save hr results
180 | if opt['test']['save_res']:
181 | res_dir = osp.join(
182 | opt['test']['res_dir'], ds_name, model_idx)
183 | res_seq_dir = osp.join(res_dir, data['seq_idx'])
184 | data_utils.save_sequence(
185 | res_seq_dir, hr_seq, data['frm_idx'], to_bgr=True)
186 |
187 | # compute metrics for the current sequence
188 | if metric_calculator is not None:
189 | gt_seq = data['gt'].numpy()
190 | metric_calculator.compute_sequence_metrics(
191 | data['seq_idx'], gt_seq, hr_seq)
192 |
193 | # save/print results
194 | if metric_calculator is not None:
195 | seq_idx_lst = [data['seq_idx'] for data in test_dataset]
196 | metric_calculator.gather(seq_idx_lst)
197 |
198 | if opt['test'].get('save_json'):
199 | # write results to a json file
200 | json_path = osp.join(
201 | opt['test']['json_dir'], f'{ds_name}_avg.json')
202 | metric_calculator.save(model_idx, json_path, override=True)
203 | else:
204 | # print directly
205 | metric_calculator.display()
206 |
207 | base_utils.log_info('-' * 40)
208 |
209 |
210 | def profile(opt, lr_size, test_speed=False):
211 | # basic configs
212 | scale = opt['scale']
213 | device = torch.device(opt['device'])
214 | msg = '\n'
215 |
216 | torch.backends.cudnn.benchmark = True
217 | # torch.backends.cudnn.deterministic = False
218 |
219 | # logging
220 | base_utils.print_options(opt['model']['generator'])
221 |
222 | lr_size_lst = tuple(map(int, lr_size.split('x')))
223 | hr_size = f'{lr_size_lst[0]}x{lr_size_lst[1]*scale}x{lr_size_lst[2]*scale}'
224 | msg += f'{"*"*40}\nResolution: {lr_size} -> {hr_size} ({scale}x SR)'
225 |
226 | # create model
227 | from models.networks import define_generator
228 | net_G = define_generator(opt).to(device)
229 | # base_utils.log_info(f'\n{net_G.__str__()}')
230 |
231 | # profile
232 | lr_size = tuple(map(int, lr_size.split('x')))
233 | gflops_dict, params_dict = net_G.profile(lr_size, device)
234 |
235 | gflops_all, params_all = 0, 0
236 | for module_name in gflops_dict.keys():
237 | gflops, params = gflops_dict[module_name], params_dict[module_name]
238 | msg += f'\n{"-"*40}\nModule: [{module_name}]'
239 | msg += f'\n FLOPs (10^9): {gflops:.3f}'
240 | msg += f'\n Parameters (10^6): {params/1e6:.3f}'
241 | gflops_all += gflops
242 | params_all += params
243 | msg += f'\n{"-"*40}\nOverall'
244 | msg += f'\n FLOPs (10^9): {gflops_all:.3f}'
245 | msg += f'\n Parameters (10^6): {params_all/1e6:.3f}\n{"*"*40}'
246 |
247 | # test running speed
248 | if test_speed:
249 | n_test, tot_time = 30, 0
250 | for i in range(n_test):
251 | dummy_input_list = net_G.generate_dummy_data(lr_size, device)
252 |
253 | start_time = time.time()
254 | # ---
255 | net_G.eval()
256 | with torch.no_grad():
257 | _ = net_G.step(*dummy_input_list)
258 | torch.cuda.synchronize()
259 | # ---
260 | end_time = time.time()
261 | tot_time += end_time - start_time
262 | msg += f'\nSpeed: {n_test/tot_time:.3f} FPS (averaged over {n_test} runs)\n{"*"*40}'
263 |
264 | base_utils.log_info(msg)
265 |
266 |
267 | if __name__ == '__main__':
268 | # === parse arguments === #
269 | args = base_utils.parse_agrs()
270 |
271 | # === generic settings === #
272 |     # parse configs, set device, set random seed
273 | opt = base_utils.parse_configs(args)
274 | # set logger
275 | base_utils.setup_logger('base')
276 | # set paths
277 | base_utils.setup_paths(opt, args.mode)
278 |
279 | # === train === #
280 | if args.mode == 'train':
281 | train(opt)
282 |
283 | # === test === #
284 | elif args.mode == 'test':
285 | test(opt)
286 |
287 | # === profile === #
288 | elif args.mode == 'profile':
289 | profile(opt, args.lr_size, args.test_speed)
290 |
291 | else:
292 | raise ValueError(f'Unrecognized mode: {args.mode} (train|test|profile)')
293 |
--------------------------------------------------------------------------------
/codes/metrics/LPIPS/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
--------------------------------------------------------------------------------
/codes/metrics/LPIPS/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/metrics/LPIPS/__init__.py
--------------------------------------------------------------------------------
/codes/metrics/LPIPS/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | from skimage.measure import compare_ssim
8 | import torch
9 | from torch.autograd import Variable
10 |
11 | from metrics.LPIPS.models import dist_model
12 |
13 |
14 | class PerceptualLoss(torch.nn.Module):
15 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric)
16 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
17 | super(PerceptualLoss, self).__init__()
18 | print('Setting up Perceptual loss...')
19 | self.use_gpu = use_gpu
20 | self.spatial = spatial
21 | self.gpu_ids = gpu_ids
22 | self.model = dist_model.DistModel()
23 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids, version=version)
24 | print('...[%s] initialized'%self.model.name())
25 | print('...Done')
26 |
27 | def forward(self, pred, target, normalize=False):
28 | """
29 | Pred and target are Variables.
30 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
31 | If normalize is False, assumes the images are already between [-1,+1]
32 |
33 | Inputs pred and target are Nx3xHxW
34 | Output pytorch Variable N long
35 | """
36 |
37 | if normalize:
38 | target = 2 * target - 1
39 | pred = 2 * pred - 1
40 |
41 | return self.model.forward(target, pred)
42 |
43 | def normalize_tensor(in_feat,eps=1e-10):
44 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
45 | return in_feat/(norm_factor+eps)
46 |
47 | def l2(p0, p1, range=255.):
48 | return .5*np.mean((p0 / range - p1 / range)**2)
49 |
50 | def psnr(p0, p1, peak=255.):
51 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
52 |
53 | def dssim(p0, p1, range=255.):
54 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
55 |
56 | def rgb2lab(in_img,mean_cent=False):
57 | from skimage import color
58 | img_lab = color.rgb2lab(in_img)
59 | if(mean_cent):
60 | img_lab[:,:,0] = img_lab[:,:,0]-50
61 | return img_lab
62 |
63 | def tensor2np(tensor_obj):
64 | # change dimension of a tensor object into a numpy array
65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
66 |
67 | def np2tensor(np_obj):
68 | # change dimenion of np array into tensor array
69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
70 |
71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
72 | # image tensor to lab tensor
73 | from skimage import color
74 |
75 | img = tensor2im(image_tensor)
76 | img_lab = color.rgb2lab(img)
77 | if(mc_only):
78 | img_lab[:,:,0] = img_lab[:,:,0]-50
79 | if(to_norm and not mc_only):
80 | img_lab[:,:,0] = img_lab[:,:,0]-50
81 | img_lab = img_lab/100.
82 |
83 | return np2tensor(img_lab)
84 |
85 | def tensorlab2tensor(lab_tensor,return_inbnd=False):
86 | from skimage import color
87 | import warnings
88 | warnings.filterwarnings("ignore")
89 |
90 | lab = tensor2np(lab_tensor)*100.
91 | lab[:,:,0] = lab[:,:,0]+50
92 |
93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
94 | if(return_inbnd):
95 | # convert back to lab, see if we match
96 | lab_back = color.rgb2lab(rgb_back.astype('uint8'))
97 | mask = 1.*np.isclose(lab_back,lab,atol=2.)
98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
99 | return (im2tensor(rgb_back),mask)
100 | else:
101 | return im2tensor(rgb_back)
102 |
103 | def rgb2lab(input):
104 | from skimage import color
105 | return color.rgb2lab(input / 255.)
106 |
107 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
108 | image_numpy = image_tensor[0].cpu().float().numpy()
109 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
110 | return image_numpy.astype(imtype)
111 |
112 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
113 | return torch.Tensor((image / factor - cent)
114 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
115 |
116 | def tensor2vec(vector_tensor):
117 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
118 |
119 | def voc_ap(rec, prec, use_07_metric=False):
120 | """ ap = voc_ap(rec, prec, [use_07_metric])
121 | Compute VOC AP given precision and recall.
122 | If use_07_metric is true, uses the
123 | VOC 07 11 point method (default:False).
124 | """
125 | if use_07_metric:
126 | # 11 point metric
127 | ap = 0.
128 | for t in np.arange(0., 1.1, 0.1):
129 | if np.sum(rec >= t) == 0:
130 | p = 0
131 | else:
132 | p = np.max(prec[rec >= t])
133 | ap = ap + p / 11.
134 | else:
135 | # correct AP calculation
136 | # first append sentinel values at the end
137 | mrec = np.concatenate(([0.], rec, [1.]))
138 | mpre = np.concatenate(([0.], prec, [0.]))
139 |
140 | # compute the precision envelope
141 | for i in range(mpre.size - 1, 0, -1):
142 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
143 |
144 | # to calculate area under PR curve, look for points
145 | # where X axis (recall) changes value
146 | i = np.where(mrec[1:] != mrec[:-1])[0]
147 |
148 | # and sum (\Delta recall) * prec
149 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
150 | return ap
151 |
152 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
153 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
154 | image_numpy = image_tensor[0].cpu().float().numpy()
155 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
156 | return image_numpy.astype(imtype)
157 |
158 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
159 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
160 | return torch.Tensor((image / factor - cent)
161 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
162 |
--------------------------------------------------------------------------------
/codes/metrics/LPIPS/models/base_model.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import numpy as np
  3 | import torch
  4 | from torch.autograd import Variable
5 |
6 |
7 | class BaseModel():
8 | def __init__(self):
  9 |         pass
10 |
11 | def name(self):
12 | return 'BaseModel'
13 |
14 | def initialize(self, use_gpu=True, gpu_ids=[0]):
15 | self.use_gpu = use_gpu
16 | self.gpu_ids = gpu_ids
17 |
18 | def forward(self):
19 | pass
20 |
21 | def get_image_paths(self):
22 | pass
23 |
24 | def optimize_parameters(self):
25 | pass
26 |
27 | def get_current_visuals(self):
28 | return self.input
29 |
30 | def get_current_errors(self):
31 | return {}
32 |
33 | def save(self, label):
34 | pass
35 |
36 | # helper saving function that can be used by subclasses
37 | def save_network(self, network, path, network_label, epoch_label):
38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
39 | save_path = os.path.join(path, save_filename)
40 | torch.save(network.state_dict(), save_path)
41 |
42 | # helper loading function that can be used by subclasses
43 | def load_network(self, network, network_label, epoch_label):
44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
45 | save_path = os.path.join(self.save_dir, save_filename)
46 | print('Loading network from %s'%save_path)
47 | network.load_state_dict(torch.load(save_path))
48 |
 49 |     def update_learning_rate(self):
50 | pass
51 |
52 | def get_image_paths(self):
53 | return self.image_paths
54 |
55 | def save_done(self, flag=False):
56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag)
57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
58 |
59 |
--------------------------------------------------------------------------------
/codes/metrics/LPIPS/models/networks_basic.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import sys
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.init as init
8 | from torch.autograd import Variable
9 | import numpy as np
10 | from pdb import set_trace as st
11 | from skimage import color
12 | from IPython import embed
13 | from . import pretrained_networks as pn
14 |
15 | import metrics.LPIPS.models as util
16 |
17 | def spatial_average(in_tens, keepdim=True):
18 | return in_tens.mean([2,3],keepdim=keepdim)
19 |
20 | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W
21 | in_H, in_W = in_tens.shape[2], in_tens.shape[3]
22 | scale_factor_H, scale_factor_W = 1.*out_HW[0]/in_H, 1.*out_HW[1]/in_W
23 |
24 | return nn.Upsample(scale_factor=(scale_factor_H, scale_factor_W), mode='bilinear', align_corners=False)(in_tens)
25 |
26 | # Learned perceptual metric
27 | class PNetLin(nn.Module):
28 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
29 | super(PNetLin, self).__init__()
30 |
31 | self.pnet_type = pnet_type
32 | self.pnet_tune = pnet_tune
33 | self.pnet_rand = pnet_rand
34 | self.spatial = spatial
35 | self.lpips = lpips
36 | self.version = version
37 | self.scaling_layer = ScalingLayer()
38 |
39 | if(self.pnet_type in ['vgg','vgg16']):
40 | net_type = pn.vgg16
41 | self.chns = [64,128,256,512,512]
42 | elif(self.pnet_type=='alex'):
43 | net_type = pn.alexnet
44 | self.chns = [64,192,384,256,256]
45 | elif(self.pnet_type=='squeeze'):
46 | net_type = pn.squeezenet
47 | self.chns = [64,128,256,384,384,512,512]
48 | self.L = len(self.chns)
49 |
50 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
51 |
52 | if(lpips):
53 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
54 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
55 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
56 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
57 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
58 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
59 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
60 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
61 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
62 | self.lins+=[self.lin5,self.lin6]
63 |
64 | def forward(self, in0, in1, retPerLayer=False):
65 | # v0.0 - original release had a bug, where input was not scaled
66 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
67 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
68 | feats0, feats1, diffs = {}, {}, {}
69 |
70 | for kk in range(self.L):
71 | feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
72 | diffs[kk] = (feats0[kk]-feats1[kk])**2
73 |
74 | if(self.lpips):
75 | if(self.spatial):
76 | res = [upsample(self.lins[kk].model(diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
77 | else:
78 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
79 | else:
80 | if(self.spatial):
81 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
82 | else:
83 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
84 |
85 | val = res[0]
86 | for l in range(1,self.L):
87 | val += res[l]
88 |
89 | if(retPerLayer):
90 | return (val, res)
91 | else:
92 | return val
93 |
94 | class ScalingLayer(nn.Module):
95 | def __init__(self):
96 | super(ScalingLayer, self).__init__()
97 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
98 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
99 |
100 | def forward(self, inp):
101 | return (inp - self.shift) / self.scale
102 |
103 |
104 | class NetLinLayer(nn.Module):
105 | ''' A single linear layer which does a 1x1 conv '''
106 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
107 | super(NetLinLayer, self).__init__()
108 |
109 | layers = [nn.Dropout(),] if(use_dropout) else []
110 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
111 | self.model = nn.Sequential(*layers)
112 |
113 |
114 | class Dist2LogitLayer(nn.Module):
115 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
116 | def __init__(self, chn_mid=32, use_sigmoid=True):
117 | super(Dist2LogitLayer, self).__init__()
118 |
119 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
120 | layers += [nn.LeakyReLU(0.2,True),]
121 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
122 | layers += [nn.LeakyReLU(0.2,True),]
123 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
124 | if(use_sigmoid):
125 | layers += [nn.Sigmoid(),]
126 | self.model = nn.Sequential(*layers)
127 |
128 | def forward(self,d0,d1,eps=0.1):
129 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
130 |
131 | class BCERankingLoss(nn.Module):
132 | def __init__(self, chn_mid=32):
133 | super(BCERankingLoss, self).__init__()
134 | self.net = Dist2LogitLayer(chn_mid=chn_mid)
135 | # self.parameters = list(self.net.parameters())
136 | self.loss = torch.nn.BCELoss()
137 |
138 | def forward(self, d0, d1, judge):
139 | per = (judge+1.)/2.
140 | self.logit = self.net.forward(d0,d1)
141 | return self.loss(self.logit, per)
142 |
143 | # L2, DSSIM evaluation
144 | class FakeNet(nn.Module):
145 | def __init__(self, use_gpu=True, colorspace='Lab'):
146 | super(FakeNet, self).__init__()
147 | self.use_gpu = use_gpu
148 | self.colorspace=colorspace
149 |
150 | class L2(FakeNet):
151 |
152 | def forward(self, in0, in1, retPerLayer=None):
153 | assert(in0.size()[0]==1) # currently only supports batchSize 1
154 |
155 | if(self.colorspace=='RGB'):
156 | (N,C,X,Y) = in0.size()
157 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
158 | return value
159 | elif(self.colorspace=='Lab'):
160 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
161 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
162 | ret_var = Variable( torch.Tensor((value,) ) )
163 | if(self.use_gpu):
164 | ret_var = ret_var.cuda()
165 | return ret_var
166 |
167 | class DSSIM(FakeNet):
168 |
169 | def forward(self, in0, in1, retPerLayer=None):
170 | assert(in0.size()[0]==1) # currently only supports batchSize 1
171 |
172 | if(self.colorspace=='RGB'):
173 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
174 | elif(self.colorspace=='Lab'):
175 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
176 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
177 | ret_var = Variable( torch.Tensor((value,) ) )
178 | if(self.use_gpu):
179 | ret_var = ret_var.cuda()
180 | return ret_var
181 |
182 | def print_network(net):
183 | num_params = 0
184 | for param in net.parameters():
185 | num_params += param.numel()
186 | print('Network',net)
187 | print('Total number of parameters: %d' % num_params)
188 |
--------------------------------------------------------------------------------
/codes/metrics/LPIPS/models/pretrained_networks.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torchvision import models as tv
4 |
5 |
6 | class squeezenet(torch.nn.Module):
7 | def __init__(self, requires_grad=False, pretrained=True):
8 | super(squeezenet, self).__init__()
9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10 | self.slice1 = torch.nn.Sequential()
11 | self.slice2 = torch.nn.Sequential()
12 | self.slice3 = torch.nn.Sequential()
13 | self.slice4 = torch.nn.Sequential()
14 | self.slice5 = torch.nn.Sequential()
15 | self.slice6 = torch.nn.Sequential()
16 | self.slice7 = torch.nn.Sequential()
17 | self.N_slices = 7
18 | for x in range(2):
19 | self.slice1.add_module(str(x), pretrained_features[x])
20 | for x in range(2,5):
21 | self.slice2.add_module(str(x), pretrained_features[x])
22 | for x in range(5, 8):
23 | self.slice3.add_module(str(x), pretrained_features[x])
24 | for x in range(8, 10):
25 | self.slice4.add_module(str(x), pretrained_features[x])
26 | for x in range(10, 11):
27 | self.slice5.add_module(str(x), pretrained_features[x])
28 | for x in range(11, 12):
29 | self.slice6.add_module(str(x), pretrained_features[x])
30 | for x in range(12, 13):
31 | self.slice7.add_module(str(x), pretrained_features[x])
32 | if not requires_grad:
33 | for param in self.parameters():
34 | param.requires_grad = False
35 |
36 | def forward(self, X):
37 | h = self.slice1(X)
38 | h_relu1 = h
39 | h = self.slice2(h)
40 | h_relu2 = h
41 | h = self.slice3(h)
42 | h_relu3 = h
43 | h = self.slice4(h)
44 | h_relu4 = h
45 | h = self.slice5(h)
46 | h_relu5 = h
47 | h = self.slice6(h)
48 | h_relu6 = h
49 | h = self.slice7(h)
50 | h_relu7 = h
51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53 |
54 | return out
55 |
56 |
57 | class alexnet(torch.nn.Module):
58 | def __init__(self, requires_grad=False, pretrained=True):
59 | super(alexnet, self).__init__()
60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
61 | self.slice1 = torch.nn.Sequential()
62 | self.slice2 = torch.nn.Sequential()
63 | self.slice3 = torch.nn.Sequential()
64 | self.slice4 = torch.nn.Sequential()
65 | self.slice5 = torch.nn.Sequential()
66 | self.N_slices = 5
67 | for x in range(2):
68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69 | for x in range(2, 5):
70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71 | for x in range(5, 8):
72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73 | for x in range(8, 10):
74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75 | for x in range(10, 12):
76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77 | if not requires_grad:
78 | for param in self.parameters():
79 | param.requires_grad = False
80 |
81 | def forward(self, X):
82 | h = self.slice1(X)
83 | h_relu1 = h
84 | h = self.slice2(h)
85 | h_relu2 = h
86 | h = self.slice3(h)
87 | h_relu3 = h
88 | h = self.slice4(h)
89 | h_relu4 = h
90 | h = self.slice5(h)
91 | h_relu5 = h
92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94 |
95 | return out
96 |
97 | class vgg16(torch.nn.Module):
98 | def __init__(self, requires_grad=False, pretrained=True):
99 | super(vgg16, self).__init__()
100 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
101 | self.slice1 = torch.nn.Sequential()
102 | self.slice2 = torch.nn.Sequential()
103 | self.slice3 = torch.nn.Sequential()
104 | self.slice4 = torch.nn.Sequential()
105 | self.slice5 = torch.nn.Sequential()
106 | self.N_slices = 5
107 | for x in range(4):
108 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
109 | for x in range(4, 9):
110 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
111 | for x in range(9, 16):
112 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
113 | for x in range(16, 23):
114 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
115 | for x in range(23, 30):
116 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
117 | if not requires_grad:
118 | for param in self.parameters():
119 | param.requires_grad = False
120 |
121 | def forward(self, X):
122 | h = self.slice1(X)
123 | h_relu1_2 = h
124 | h = self.slice2(h)
125 | h_relu2_2 = h
126 | h = self.slice3(h)
127 | h_relu3_3 = h
128 | h = self.slice4(h)
129 | h_relu4_3 = h
130 | h = self.slice5(h)
131 | h_relu5_3 = h
132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134 |
135 | return out
136 |
137 |
138 |
139 | class resnet(torch.nn.Module):
140 | def __init__(self, requires_grad=False, pretrained=True, num=18):
141 | super(resnet, self).__init__()
142 | if(num==18):
143 | self.net = tv.resnet18(pretrained=pretrained)
144 | elif(num==34):
145 | self.net = tv.resnet34(pretrained=pretrained)
146 | elif(num==50):
147 | self.net = tv.resnet50(pretrained=pretrained)
148 | elif(num==101):
149 | self.net = tv.resnet101(pretrained=pretrained)
150 | elif(num==152):
151 | self.net = tv.resnet152(pretrained=pretrained)
152 | self.N_slices = 5
153 |
154 | self.conv1 = self.net.conv1
155 | self.bn1 = self.net.bn1
156 | self.relu = self.net.relu
157 | self.maxpool = self.net.maxpool
158 | self.layer1 = self.net.layer1
159 | self.layer2 = self.net.layer2
160 | self.layer3 = self.net.layer3
161 | self.layer4 = self.net.layer4
162 |
163 | def forward(self, X):
164 | h = self.conv1(X)
165 | h = self.bn1(h)
166 | h = self.relu(h)
167 | h_relu1 = h
168 | h = self.maxpool(h)
169 | h = self.layer1(h)
170 | h_conv2 = h
171 | h = self.layer2(h)
172 | h_conv3 = h
173 | h = self.layer3(h)
174 | h_conv4 = h
175 | h = self.layer4(h)
176 | h_conv5 = h
177 |
178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180 |
181 | return out
182 |
--------------------------------------------------------------------------------
/codes/metrics/LPIPS/models/weights/v0.1/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/metrics/LPIPS/models/weights/v0.1/alex.pth
--------------------------------------------------------------------------------
/codes/metrics/LPIPS/models/weights/v0.1/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/metrics/LPIPS/models/weights/v0.1/squeeze.pth
--------------------------------------------------------------------------------
/codes/metrics/LPIPS/models/weights/v0.1/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/metrics/LPIPS/models/weights/v0.1/vgg.pth
--------------------------------------------------------------------------------
/codes/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .metric_calculator import MetricCalculator
2 |
3 |
4 | def create_metric_calculator(opt):
5 | if 'metric' in opt and opt['metric']:
6 | return MetricCalculator(opt)
7 | else:
8 | return None
9 |
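   | # A minimal usage sketch (`opt` is the parsed option dict used throughout this repo):
   | #   calculator = create_metric_calculator(opt)  # returns None when no metric is configured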
--------------------------------------------------------------------------------
/codes/metrics/metric_calculator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import json
4 | from collections import OrderedDict
5 |
6 | import numpy as np
7 | import cv2
8 | import torch
9 | import torch.distributed as dist
10 |
11 | from utils import base_utils, data_utils, net_utils
12 | from utils.dist_utils import master_only
13 | from .LPIPS.models.dist_model import DistModel
14 |
15 |
16 | class MetricCalculator():
17 | """ Metric calculator for model evaluation
18 |
19 | Currently supported metrics:
20 | * PSNR (RGB and Y)
21 | * LPIPS
22 |         * tOF as described in the TecoGAN paper
23 |
24 | TODO:
25 | * save/print metrics in a fixed order
26 | """
27 |
28 | def __init__(self, opt):
29 | # initialize
30 | self.metric_opt = opt['metric']
31 | self.device = torch.device(opt['device'])
32 | self.dist = opt['dist']
33 | self.rank = opt['rank']
34 |
35 | self.psnr_colorspace = ''
36 | self.dm = None
37 |
38 | self.reset()
39 |
40 | # update configs for each metric
41 | for metric_type, cfg in self.metric_opt.items():
42 | if metric_type.lower() == 'psnr':
43 | self.psnr_colorspace = cfg['colorspace']
44 |
45 | if metric_type.lower() == 'lpips':
46 | self.dm = DistModel()
47 | self.dm.initialize(
48 | model=cfg['model'],
49 | net=cfg['net'],
50 | colorspace=cfg['colorspace'],
51 | spatial=cfg['spatial'],
52 | use_gpu=(opt['device'] == 'cuda'),
53 | gpu_ids=[0 if not self.dist else opt['local_rank']],
54 | version=cfg['version'])
55 |
56 | def reset(self):
57 | self.reset_per_sequence()
58 | self.metric_dict = OrderedDict()
59 | self.avg_metric_dict = OrderedDict()
60 |
61 | def reset_per_sequence(self):
62 | self.seq_idx_curr = ''
63 | self.true_img_cur = None
64 | self.pred_img_cur = None
65 | self.true_img_pre = None
66 | self.pred_img_pre = None
67 |
68 | def gather(self, seq_idx_lst):
69 | """ Gather results from all devices.
70 |             Results will be updated into self.avg_metric_dict on device 0
71 | """
72 |
73 | # mdict
74 | # {
75 | # 'seq_idx': {
76 | # 'metric1': [frm1, frm2, ...],
77 | # 'metric2': [frm1, frm2, ...]
78 | # }
79 | # }
80 | mdict = self.metric_dict
81 |
82 | mtype_lst = self.metric_opt.keys()
83 | n_metric = len(mtype_lst)
84 |
85 | # avg_mdict
86 | # {
87 | # 'seq_idx': torch.tensor([metric1_avg, metric2_avg, ...])
88 | # }
89 | avg_mdict = {
90 | seq_idx: torch.zeros(n_metric, dtype=torch.float32, device=self.device)
91 | for seq_idx in seq_idx_lst
92 | }
93 |
94 | # average metric results for each sequence
95 | for i, mtype in enumerate(mtype_lst):
96 | for seq_idx, mdict_per_seq in mdict.items(): # ordered
97 | avg_mdict[seq_idx][i] += np.mean(mdict_per_seq[mtype])
98 |
99 | if self.dist:
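   |             # assuming each sequence is evaluated by exactly one rank (the other ranks keep
   |             # the zero-initialized entries), a sum-reduce onto rank 0 collects every average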
100 | for seq_idx, tensor in avg_mdict.items():
101 | dist.reduce(tensor, dst=0)
102 | dist.barrier()
103 |
104 | if self.rank == 0:
105 | # avg_metric_dict
106 | # {
107 | # 'seq_idx': {
108 | # 'metric1': avg,
109 | # 'metric2': avg
110 | # }
111 | # }
112 |
113 | for seq_idx in seq_idx_lst:
114 | self.avg_metric_dict[seq_idx] = OrderedDict([
115 | (mtype, avg_mdict[seq_idx][i].item())
116 | for i, mtype in enumerate(mtype_lst)
117 | ])
118 |
119 | def average(self):
120 |         """ Return a dict including metric results averaged over all sequences
121 | """
122 |
123 | metric_avg_dict = OrderedDict()
124 | for mtype in self.metric_opt.keys():
125 | metric_all_seq = []
126 |         for seq_idx, mdict_per_seq in self.avg_metric_dict.items():
127 | metric_all_seq.append(mdict_per_seq[mtype])
128 |
129 | metric_avg_dict[mtype] = np.mean(metric_all_seq)
130 |
131 | return metric_avg_dict
132 |
133 | @master_only
134 | def display(self):
135 | # per sequence results
136 | for seq_idx, mdict_per_seq in self.avg_metric_dict.items():
137 | base_utils.log_info(f'Sequence: {seq_idx}')
138 | for mtype in self.metric_opt.keys():
139 | base_utils.log_info(f'\t{mtype}: {mdict_per_seq[mtype]:.6f}')
140 |
141 | # average results
142 | base_utils.log_info('Average')
143 | metric_avg_dict = self.average()
144 | for mtype, value in metric_avg_dict.items():
145 | base_utils.log_info(f'\t{mtype}: {value:.6f}')
146 |
147 | @master_only
148 | def save(self, model_idx, save_path, average=True, override=False):
149 |         # load previous results if they exist
150 | if osp.exists(save_path):
151 | with open(save_path, 'r') as f:
152 | json_dict = json.load(f)
153 | else:
154 | json_dict = dict()
155 |
156 | # update
157 | if model_idx not in json_dict:
158 | json_dict[model_idx] = OrderedDict()
159 |
160 | if average:
161 | metric_avg_dict = self.average()
162 | for mtype, value in metric_avg_dict.items():
163 | # override or skip
164 | if mtype in json_dict[model_idx] and not override:
165 | continue
166 | json_dict[model_idx][mtype] = f'{value:.6f}'
167 | else:
168 | # TODO: save results of each sequence
169 | raise NotImplementedError()
170 |
171 | # sort
172 | json_dict = OrderedDict(sorted(
173 | json_dict.items(), key=lambda x: int(x[0].replace('G_iter', ''))))
174 |
175 | # save results
176 | with open(save_path, 'w') as f:
177 | json.dump(json_dict, f, sort_keys=False, indent=4)
178 |
179 | def compute_sequence_metrics(self, seq_idx, true_seq, pred_seq):
180 | # clear
181 | self.reset_per_sequence()
182 |
183 | # initialize metric_dict for the current sequence
184 | self.seq_idx_curr = seq_idx
185 | self.metric_dict[self.seq_idx_curr] = OrderedDict({
186 | metric: [] for metric in self.metric_opt.keys()})
187 |
188 | # compute metrics for each frame
189 | tot_frm = true_seq.shape[0]
190 | for i in range(tot_frm):
191 | self.true_img_cur = true_seq[i] # hwc|rgb/y|uint8
192 | self.pred_img_cur = pred_seq[i]
193 |
194 | # pred_img and true_img may have different sizes
195 | # crop the larger one to match the smaller one
196 | true_h, true_w = self.true_img_cur.shape[:-1]
197 | pred_h, pred_w = self.pred_img_cur.shape[:-1]
198 | min_h, min_w = min(true_h, pred_h), min(true_w, pred_w)
199 | self.true_img_cur = self.true_img_cur[:min_h, :min_w, :]
200 | self.pred_img_cur = self.pred_img_cur[:min_h, :min_w, :]
201 |
202 | # compute metrics for the current frame
203 | self.compute_frame_metrics()
204 |
205 | # update
206 | self.true_img_pre = self.true_img_cur
207 | self.pred_img_pre = self.pred_img_cur
208 |
209 | def compute_frame_metrics(self):
210 | metric_dict = self.metric_dict[self.seq_idx_curr]
211 |
212 |         # compute each configured metric for the current frame
213 | for metric_type, opt in self.metric_opt.items():
214 | if metric_type == 'PSNR':
215 | PSNR = self.compute_PSNR()
216 | metric_dict['PSNR'].append(PSNR)
217 |
218 | elif metric_type == 'LPIPS':
219 | LPIPS = self.compute_LPIPS()[0, 0, 0, 0].item()
220 | metric_dict['LPIPS'].append(LPIPS)
221 |
222 | elif metric_type == 'tOF':
223 | # skip the first frame
224 | if self.pred_img_pre is not None:
225 | tOF = self.compute_tOF()
226 | metric_dict['tOF'].append(tOF)
227 |
228 | def compute_PSNR(self):
229 | if self.psnr_colorspace == 'rgb':
230 | true_img = self.true_img_cur
231 | pred_img = self.pred_img_cur
232 | else:
233 | # convert to ycbcr, and keep the y channel
234 | true_img = data_utils.rgb_to_ycbcr(self.true_img_cur)[..., 0]
235 | pred_img = data_utils.rgb_to_ycbcr(self.pred_img_cur)[..., 0]
236 |
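   |         # PSNR in dB: 20 * log10(255 / RMSE); identical frames (RMSE == 0) are reported as inf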
237 | diff = true_img.astype(np.float64) - pred_img.astype(np.float64)
238 | RMSE = np.sqrt(np.mean(np.power(diff, 2)))
239 |
240 | if RMSE == 0:
241 | return np.inf
242 |
243 | PSNR = 20 * np.log10(255.0 / RMSE)
244 | return PSNR
245 |
246 | def compute_LPIPS(self):
247 | true_img = np.ascontiguousarray(self.true_img_cur)
248 | pred_img = np.ascontiguousarray(self.pred_img_cur)
249 |
250 | # to tensor
251 | true_img = torch.FloatTensor(true_img).unsqueeze(0).permute(0, 3, 1, 2)
252 | pred_img = torch.FloatTensor(pred_img).unsqueeze(0).permute(0, 3, 1, 2)
253 |
254 | # normalize to [-1, 1]
255 | true_img = true_img.to(self.device) * 2.0 / 255.0 - 1.0
256 | pred_img = pred_img.to(self.device) * 2.0 / 255.0 - 1.0
257 |
258 | with torch.no_grad():
259 | LPIPS = self.dm.forward(true_img, pred_img)
260 |
261 | return LPIPS
262 |
263 | def compute_tOF(self):
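   |         # tOF: mean end-point error between the Farneback optical flow of consecutive GT
   |         # frames and that of consecutive predicted frames, computed on grayscale images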
264 | true_img_cur = cv2.cvtColor(self.true_img_cur, cv2.COLOR_RGB2GRAY)
265 | pred_img_cur = cv2.cvtColor(self.pred_img_cur, cv2.COLOR_RGB2GRAY)
266 | true_img_pre = cv2.cvtColor(self.true_img_pre, cv2.COLOR_RGB2GRAY)
267 | pred_img_pre = cv2.cvtColor(self.pred_img_pre, cv2.COLOR_RGB2GRAY)
268 |
269 | # forward flow
270 | true_OF = cv2.calcOpticalFlowFarneback(
271 | true_img_pre, true_img_cur, None, 0.5, 3, 15, 3, 5, 1.2, 0)
272 | pred_OF = cv2.calcOpticalFlowFarneback(
273 | pred_img_pre, pred_img_cur, None, 0.5, 3, 15, 3, 5, 1.2, 0)
274 |
275 | # EPE
276 | diff_OF = true_OF - pred_OF
277 | tOF = np.mean(np.sqrt(np.sum(diff_OF**2, axis=-1)))
278 |
279 | return tOF
280 |
--------------------------------------------------------------------------------
/codes/metrics/model_summary.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # define which modules to include in the FLOP count
6 | registered_module = [
7 | nn.Conv2d,
8 | nn.ConvTranspose2d,
9 | nn.Conv3d
10 | ]
11 |
12 | # initialize
13 | registered_hooks, model_info_lst = [], []
14 |
15 |
16 | def calc_2d_gflops_per_batch(module, out_h, out_w):
17 |     """ Calculate FLOPs of conv weights (supports grouped & dilated conv)
18 | """
19 | gflops = 0
20 | if hasattr(module, 'weight'):
21 | # Note: in_c is already divided by groups while out_c is not
22 |         bias = 0 if module.bias is not None else -1
23 | out_c, in_c, k_h, k_w = module.weight.shape
24 |
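   |         # each output element costs in_c*k_h*k_w multiplies and in_c*k_h*k_w - 1 adds
   |         # (plus one add when a bias is present), i.e. roughly 2*in_c*k_h*k_w FLOPs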
25 | gflops += (2*in_c*k_h*k_w + bias)*out_c*out_h*out_w/1e9
26 | return gflops
27 |
28 |
29 | def calc_3d_gflops_per_batch(module, out_d, out_h, out_w):
30 |     """ Calculate FLOPs of conv weights (supports grouped & dilated conv)
31 | """
32 | gflops = 0
33 | if hasattr(module, 'weight'):
34 | # Note: in_c is already divided by groups while out_c is not
35 |         bias = 0 if module.bias is not None else -1
36 | out_c, in_c, k_d, k_h, k_w = module.weight.shape
37 |
38 | gflops += (2*in_c*k_d*k_h*k_w + bias)*out_c*out_d*out_h*out_w/1e9
39 | return gflops
40 |
41 |
42 | def hook_fn_forward(module, input, output):
43 | if isinstance(module, nn.Conv3d):
44 | batch_size, _, out_d, out_h, out_w = output.size()
45 | gflops = batch_size*calc_3d_gflops_per_batch(module, out_d, out_h, out_w)
46 | else:
47 | if isinstance(module, nn.ConvTranspose2d):
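   |             # a transposed convolution applies the kernel once per *input* location,
   |             # so its cost scales with the input spatial size rather than the output size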
48 | batch_size, _, out_h, out_w = input[0].size()
49 | else:
50 | batch_size, _, out_h, out_w = output.size()
51 | gflops = batch_size*calc_2d_gflops_per_batch(module, out_h, out_w)
52 |
53 | model_info_lst.append({'gflops': gflops})
54 |
55 |
56 | def register_hook(module):
57 | if isinstance(module, tuple(registered_module)):
58 | registered_hooks.append(module.register_forward_hook(hook_fn_forward))
59 |
60 |
61 | def register(model, dummy_input_list):
62 | # reset params
63 | global registered_hooks, model_info_lst
64 | registered_hooks, model_info_lst = [], []
65 |
66 | # register hook
67 | model.apply(register_hook)
68 |
69 | # forward
70 | with torch.no_grad():
71 | model.eval()
72 | out = model(*dummy_input_list)
73 |
74 | # remove hooks
75 | for hook in registered_hooks:
76 | hook.remove()
77 |
78 | return out
79 |
80 |
81 | def parse_model_info(model):
82 | tot_gflops = 0
83 | for module_info in model_info_lst:
84 | if module_info['gflops']:
85 | tot_gflops += module_info['gflops']
86 |
87 | tot_params = 0
88 | for param in model.parameters():
89 | tot_params += torch.prod(torch.tensor(param.size())).item()
90 |
91 | return tot_gflops, tot_params
92 |
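   | # Typical usage, as a sketch (`net` and `dummy_lr` are placeholder names):
   | #   out = register(net, [dummy_lr])          # dummy forward pass with FLOP hooks attached
   | #   gflops, params = parse_model_info(net)   # aggregate FLOPs and parameter count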
--------------------------------------------------------------------------------
/codes/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vsr_model import VSRModel
2 | from .vsrgan_model import VSRGANModel
3 |
4 |
5 | # register vsr model
6 | vsr_model_lst = [
7 | 'frvsr',
8 | ]
9 |
10 | # register vsrgan model
11 | vsrgan_model_lst = [
12 | 'tecogan',
13 | ]
14 |
15 |
16 | def define_model(opt):
17 | if opt['model']['name'].lower() in vsr_model_lst:
18 | model = VSRModel(opt)
19 |
20 | elif opt['model']['name'].lower() in vsrgan_model_lst:
21 | model = VSRGANModel(opt)
22 |
23 | else:
24 | raise ValueError(f'Unrecognized model: {opt["model"]["name"]}')
25 |
26 | return model
27 |
--------------------------------------------------------------------------------
/codes/models/base_model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import os.path as osp
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.distributed as dist
8 | from torch.nn.parallel import DistributedDataParallel
9 |
10 | from utils.data_utils import create_kernel, downsample_bd
11 | from utils.dist_utils import master_only
12 |
13 |
14 | class BaseModel():
15 | def __init__(self, opt):
16 | self.opt = opt
17 | self.scale = opt['scale']
18 | self.device = torch.device(opt['device'])
19 | self.blur_kernel = None
20 | self.dist = opt['dist']
21 | self.is_train = opt['is_train']
22 |
23 | if self.is_train:
24 | self.lr_data, self.gt_data = None, None
25 | self.ckpt_dir = opt['train']['ckpt_dir']
26 | self.log_decay = opt['logger'].get('decay', 0.99)
27 | self.log_dict = OrderedDict()
28 | self.running_log_dict = OrderedDict()
29 |
30 | def set_networks(self):
31 | pass
32 |
33 | def set_criterions(self):
34 | pass
35 |
36 | def set_optimizers(self):
37 | pass
38 |
39 | def set_lr_schedules(self):
40 | pass
41 |
42 | def prepare_training_data(self, data):
43 | """ prepare gt, lr data for training
44 |
45 | for BD degradation, generate lr data and remove the border of gt data
46 | for BI degradation, use input data directly
47 | """
48 |
49 | degradation_type = self.opt['dataset']['degradation']['type']
50 |
51 | if degradation_type == 'BI':
52 | self.gt_data = data['gt'].to(self.device)
53 | self.lr_data = data['lr'].to(self.device)
54 |
55 | elif degradation_type == 'BD':
56 | # generate lr data on the fly (on gpu)
57 |
58 | # set params
59 | scale = self.opt['scale']
60 | sigma = self.opt['dataset']['degradation'].get('sigma', 1.5)
61 | border_size = int(sigma * 3.0)
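   |             # the un-padded Gaussian blur below shrinks each side by roughly 3*sigma pixels,
   |             # so the same border is later cropped from the GT to keep it aligned with the LR data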
62 |
63 | gt_data = data['gt'].to(self.device) # with border
64 | n, t, c, gt_h, gt_w = gt_data.size()
65 | lr_h = (gt_h - 2*border_size)//scale
66 | lr_w = (gt_w - 2*border_size)//scale
67 |
68 | # create blurring kernel
69 | if self.blur_kernel is None:
70 | self.blur_kernel = create_kernel(sigma).to(self.device)
71 | blur_kernel = self.blur_kernel
72 |
73 | # generate lr data
74 | gt_data = gt_data.view(n*t, c, gt_h, gt_w)
75 | lr_data = downsample_bd(gt_data, blur_kernel, scale, pad_data=False)
76 | lr_data = lr_data.view(n, t, c, lr_h, lr_w)
77 |
78 | # remove gt border
79 | gt_data = gt_data[
80 | ...,
81 | border_size: border_size + scale*lr_h,
82 | border_size: border_size + scale*lr_w]
83 | gt_data = gt_data.view(n, t, c, scale*lr_h, scale*lr_w)
84 |
85 | self.gt_data, self.lr_data = gt_data, lr_data # tchw|float32
86 |
87 | def prepare_inference_data(self, data):
88 |         """ Prepare lr data for inference (w/o loading it onto the device)
89 | """
90 |
91 | degradation_type = self.opt['dataset']['degradation']['type']
92 |
93 | if degradation_type == 'BI':
94 | self.lr_data = data['lr']
95 |
96 | elif degradation_type == 'BD':
97 | if 'lr' in data:
98 | self.lr_data = data['lr']
99 | else:
100 | # generate lr data on the fly (on cpu)
101 | # TODO: do frame-wise downsampling on gpu for acceleration?
102 | gt_data = data['gt'] # thwc|uint8
103 |
104 | # set params
105 | scale = self.opt['scale']
106 | sigma = self.opt['dataset']['degradation'].get('sigma', 1.5)
107 |
108 | # create blurring kernel
109 | if self.blur_kernel is None:
110 | self.blur_kernel = create_kernel(sigma)
111 | blur_kernel = self.blur_kernel.cpu()
112 |
113 | # generate lr data
114 | gt_data = gt_data.permute(0, 3, 1, 2).float() / 255.0 # tchw|float32
115 | lr_data = downsample_bd(
116 | gt_data, blur_kernel, scale, pad_data=True)
117 | lr_data = lr_data.permute(0, 2, 3, 1) # thwc|float32
118 |
119 | self.lr_data = lr_data
120 |
121 | # thwc to tchw
122 | self.lr_data = self.lr_data.permute(0, 3, 1, 2) # tchw|float32
123 |
124 | def train(self):
125 | pass
126 |
127 | def infer(self):
128 | pass
129 |
130 | def model_to_device(self, net):
131 | net = net.to(self.device)
132 | if self.dist:
133 | net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
134 | net = DistributedDataParallel(
135 | net, device_ids=[torch.cuda.current_device()])
136 | return net
137 |
138 | def update_learning_rate(self):
139 | if hasattr(self, 'sched_G') and self.sched_G is not None:
140 | self.sched_G.step()
141 |
142 | if hasattr(self, 'sched_D') and self.sched_D is not None:
143 | self.sched_D.step()
144 |
145 | def get_learning_rate(self):
146 | lr_dict = OrderedDict()
147 |
148 | if hasattr(self, 'optim_G'):
149 | lr_dict['lr_G'] = self.optim_G.param_groups[0]['lr']
150 |
151 | if hasattr(self, 'optim_D'):
152 | lr_dict['lr_D'] = self.optim_D.param_groups[0]['lr']
153 |
154 | return lr_dict
155 |
156 | def reduce_log(self):
157 | if self.dist:
158 | rank, world_size = self.opt['rank'], self.opt['world_size']
159 | with torch.no_grad():
160 | keys, vals = [], []
161 | for key, val in self.log_dict.items():
162 | keys.append(key)
163 | vals.append(val)
164 | vals = torch.FloatTensor(vals).to(self.device)
165 | dist.reduce(vals, dst=0)
166 | if rank == 0: # average
167 | vals /= world_size
168 | self.log_dict = {key: val.item() for key, val in zip(keys, vals)}
169 |
170 | def update_running_log(self):
171 | self.reduce_log() # for distributed training
172 |
173 | d = self.log_decay
174 | for k in self.log_dict.keys():
175 | current_val = self.log_dict[k]
176 | running_val = self.running_log_dict.get(k)
177 |
178 | if running_val is None:
179 | running_val = current_val
180 | else:
181 | running_val = d * running_val + (1.0 - d) * current_val
182 |
183 | self.running_log_dict[k] = running_val
184 |
185 | def get_current_log(self):
186 | return self.log_dict
187 |
188 | def get_running_log(self):
189 | return self.running_log_dict
190 |
191 | def get_format_msg(self, epoch, iter):
192 | # generic info
193 | msg = f'[epoch: {epoch} | iter: {iter}'
194 | for lr_type, lr in self.get_learning_rate().items():
195 | msg += f' | {lr_type}: {lr:.2e}'
196 | msg += '] '
197 |
198 | # loss info
199 | log_dict = self.get_running_log()
200 | msg += ', '.join([f'{k}: {v:.3e}' for k, v in log_dict.items()])
201 |
202 | return msg
203 |
204 | def save(self, current_iter):
205 | pass
206 |
207 | @staticmethod
208 | def get_bare_model(net):
209 | if isinstance(net, DistributedDataParallel):
210 | net = net.module
211 | return net
212 |
213 | @master_only
214 | def save_network(self, net, net_label, current_iter):
215 | filename = f'{net_label}_iter{current_iter}.pth'
216 | save_path = osp.join(self.ckpt_dir, filename)
217 | net = self.get_bare_model(net)
218 | torch.save(net.state_dict(), save_path)
219 |
220 | def save_training_state(self, current_epoch, current_iter):
221 | # TODO
222 | pass
223 |
224 | def load_network(self, net, load_path):
225 | state_dict = torch.load(
226 | load_path, map_location=lambda storage, loc: storage)
227 | net = self.get_bare_model(net)
228 | net.load_state_dict(state_dict)
229 |
230 | def pad_sequence(self, lr_data):
231 | """
232 | Parameters:
233 | :param lr_data: tensor in shape tchw
234 | """
235 | padding_mode = self.opt['test'].get('padding_mode', 'reflect')
236 | n_pad_front = self.opt['test'].get('num_pad_front', 0)
237 | assert n_pad_front < lr_data.size(0)
238 |
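   |         # temporal padding gives the recurrent generator a short warm-up: 'reflect' mirrors
   |         # frames 1..n_pad_front in front of the sequence, 'replicate' repeats the first frame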
239 | # pad
240 | if padding_mode == 'reflect':
241 | lr_data = torch.cat(
242 | [lr_data[1: 1 + n_pad_front, ...].flip(0), lr_data], dim=0)
243 |
244 | elif padding_mode == 'replicate':
245 | lr_data = torch.cat(
246 | [lr_data[:1, ...].expand(n_pad_front, -1, -1, -1), lr_data], dim=0)
247 |
248 | else:
249 | raise ValueError(f'Unrecognized padding mode: {padding_mode}')
250 |
251 | return lr_data, n_pad_front
252 |
--------------------------------------------------------------------------------
/codes/models/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .tecogan_nets import FRNet, SpatioTemporalDiscriminator, SpatialDiscriminator
2 |
3 |
4 | def define_generator(opt):
5 | net_G_opt = opt['model']['generator']
6 |
7 | if net_G_opt['name'].lower() == 'frnet': # frame-recurrent generator
8 | net_G = FRNet(
9 | in_nc=net_G_opt['in_nc'],
10 | out_nc=net_G_opt['out_nc'],
11 | nf=net_G_opt['nf'],
12 | nb=net_G_opt['nb'],
13 | degradation=opt['dataset']['degradation']['type'],
14 | scale=opt['scale'])
15 |
16 | else:
17 | raise ValueError(f'Unrecognized generator: {net_G_opt["name"]}')
18 |
19 | return net_G
20 |
21 |
22 | def define_discriminator(opt):
23 | net_D_opt = opt['model']['discriminator']
24 |
25 | if opt['dataset']['degradation']['type'] == 'BD':
26 | spatial_size = opt['dataset']['train']['crop_size']
27 | else: # BI
28 | spatial_size = opt['dataset']['train']['gt_crop_size']
29 |
30 | if net_D_opt['name'].lower() == 'stnet': # spatio-temporal discriminator
31 | net_D = SpatioTemporalDiscriminator(
32 | in_nc=net_D_opt['in_nc'],
33 | spatial_size=spatial_size,
34 | tempo_range=net_D_opt['tempo_range'],
35 | degradation=opt['dataset']['degradation']['type'],
36 | scale=opt['scale'])
37 |
38 | elif net_D_opt['name'].lower() == 'snet': # spatial discriminator
39 | net_D = SpatialDiscriminator(
40 | in_nc=net_D_opt['in_nc'],
41 | spatial_size=spatial_size,
42 | use_cond=net_D_opt['use_cond'])
43 |
44 | else:
45 | raise ValueError(f'Unrecognized discriminator: {net_D_opt["name"]}')
46 |
47 | return net_D
48 |
--------------------------------------------------------------------------------
/codes/models/networks/base_nets.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BaseSequenceGenerator(nn.Module):
5 | def __init__(self):
6 | super(BaseSequenceGenerator, self).__init__()
7 |
8 | def generate_dummy_data(self, lr_size):
9 | """ Generate random input tensors for function `step`
10 | """
11 | return None
12 |
13 | def profile(self, *args, **kwargs):
14 | pass
15 |
16 | def forward(self, *args, **kwargs):
17 | """ Interface (support DDP)
18 | """
19 | pass
20 |
21 | def forward_sequence(self, lr_data):
22 | """ Forward a whole sequence (for training)
23 | """
24 | pass
25 |
26 | def step(self, *args, **kwargs):
27 | """ Forward a single frame
28 | """
29 | pass
30 |
31 | def infer_sequence(self, lr_data, device):
32 | """ Infer a whole sequence (for inference)
33 | """
34 | pass
35 |
36 |
37 | class BaseSequenceDiscriminator(nn.Module):
38 | def __init__(self):
39 | super(BaseSequenceDiscriminator, self).__init__()
40 |
41 | def forward(self, *args, **kwargs):
42 | """ Interface (support DDP)
43 | """
44 | pass
45 |
46 | def step(self, *args, **kwargs):
47 |         """ Forward a single frame
48 | """
49 | pass
50 |
51 | def forward_sequence(self, data, args_dict):
52 | """ Forward a whole sequence (for training)
53 | """
54 | pass
55 |
--------------------------------------------------------------------------------
/codes/models/networks/vgg_nets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 |
6 | class VGGFeatureExtractor(nn.Module):
7 | def __init__(self, feature_indexs=(8, 17, 26, 35)):
8 | super(VGGFeatureExtractor, self).__init__()
9 |
10 | # init feature layers
11 | self.features = torchvision.models.vgg19(pretrained=True).features
12 | for param in self.features.parameters():
13 | param.requires_grad = False
14 |
15 | # Notes:
16 | # 1. default feature layers are 8(conv2_2), 17(conv3_4), 26(conv4_4),
17 | # 35(conv5_4)
18 | # 2. features are extracted after ReLU activation
19 | self.feature_indexs = sorted(feature_indexs)
20 |
21 | # register normalization params
22 | mean = torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) # RGB
23 | std = torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
24 | self.register_buffer('mean', mean)
25 | self.register_buffer('std', std)
26 |
27 | def forward(self, x):
28 | # assume input ranges in [0, 1]
29 | out = (x - self.mean) / self.std
30 |
31 | feature_list = []
32 | for i in range(len(self.features)):
33 | out = self.features[i](out)
34 | if i in self.feature_indexs:
35 | # clone to prevent overlapping by inplaced ReLU
36 | feature_list.append(out.clone())
37 |
38 | return feature_list
39 |
--------------------------------------------------------------------------------
/codes/models/optim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.optim as optim
3 |
4 |
5 | def define_criterion(criterion_opt):
6 | if criterion_opt is None:
7 | return None
8 |
9 | # parse
10 | if criterion_opt['type'] == 'MSE':
11 | criterion = nn.MSELoss(reduction=criterion_opt['reduction'])
12 |
13 | elif criterion_opt['type'] == 'L1':
14 | criterion = nn.L1Loss(reduction=criterion_opt['reduction'])
15 |
16 | elif criterion_opt['type'] == 'CB':
17 | from .losses import CharbonnierLoss
18 | criterion = CharbonnierLoss(reduction=criterion_opt['reduction'])
19 |
20 | elif criterion_opt['type'] == 'CosineSimilarity':
21 | from .losses import CosineSimilarityLoss
22 | criterion = CosineSimilarityLoss()
23 |
24 | elif criterion_opt['type'] == 'GAN':
25 | from .losses import VanillaGANLoss
26 | criterion = VanillaGANLoss(reduction=criterion_opt['reduction'])
27 |
28 | elif criterion_opt['type'] == 'LSGAN':
29 | from .losses import LSGANLoss
30 | criterion = LSGANLoss(reduction=criterion_opt['reduction'])
31 |
32 | else:
33 | raise ValueError(f'Unrecognized criterion: {criterion_opt["type"]}')
34 |
35 | return criterion
36 |
37 |
38 | def define_lr_schedule(schedule_opt, optimizer):
39 | if schedule_opt is None:
40 | return None
41 |
42 | # parse
43 | if schedule_opt['type'] == 'FixedLR':
44 | schedule = None
45 |
46 | elif schedule_opt['type'] == 'MultiStepLR':
47 | schedule = optim.lr_scheduler.MultiStepLR(
48 | optimizer,
49 | milestones=schedule_opt['milestones'],
50 | gamma=schedule_opt['gamma'])
51 |
52 | elif schedule_opt['type'] == 'CosineAnnealingRestartLR':
53 | from .lr_schedules import CosineAnnealingRestartLR
54 | schedule = CosineAnnealingRestartLR(
55 | optimizer,
56 | periods=schedule_opt['periods'],
57 | restart_weights=schedule_opt['restart_weights'],
58 | eta_min=schedule_opt['eta_min'])
59 |
60 | else:
61 | raise ValueError(f'Unrecognized lr schedule: {schedule_opt["type"]}')
62 |
63 | return schedule
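   | # A minimal sketch of how these factories are typically wired together (cf. vsr_model.py):
   | #   crit = define_criterion(opt['train'].get('pixel_crit'))
   | #   sched = define_lr_schedule(opt['train']['generator'].get('lr_schedule'), optim_G)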
--------------------------------------------------------------------------------
/codes/models/optim/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class VanillaGANLoss(nn.Module):
7 | def __init__(self, reduction='mean'):
8 | super(VanillaGANLoss, self).__init__()
9 | self.crit = nn.BCEWithLogitsLoss(reduction=reduction)
10 |
11 | def forward(self, input, status):
12 | target = torch.empty_like(input).fill_(int(status))
13 | loss = self.crit(input, target)
14 | return loss
15 |
16 |
17 | class LSGANLoss(nn.Module):
18 | def __init__(self, reduction='mean'):
19 | super(LSGANLoss, self).__init__()
20 | self.crit = nn.MSELoss(reduction=reduction)
21 |
22 | def forward(self, input, status):
23 | """
24 | :param status: boolean, True/False
25 | """
26 | target = torch.empty_like(input).fill_(int(status))
27 | loss = self.crit(input, target)
28 | return loss
29 |
30 |
31 | class CharbonnierLoss(nn.Module):
32 | """ Charbonnier Loss (robust L1)
33 | """
34 |
35 | def __init__(self, eps=1e-6, reduction='sum'):
36 | super(CharbonnierLoss, self).__init__()
37 | self.eps = eps
38 | self.reduction = reduction
39 |
40 | def forward(self, x, y):
41 | diff = x - y
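   |         # element-wise Charbonnier penalty sqrt(diff^2 + eps): a smooth, differentiable
   |         # approximation of the L1 loss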
42 | loss = torch.sqrt(diff * diff + self.eps)
43 |
44 | if self.reduction == 'sum':
45 | loss = torch.sum(loss)
46 | elif self.reduction == 'mean':
47 | loss = torch.mean(loss)
48 | else:
49 | raise NotImplementedError
50 | return loss
51 |
52 |
53 | class CosineSimilarityLoss(nn.Module):
54 | def __init__(self, eps=1e-8):
55 | super(CosineSimilarityLoss, self).__init__()
56 | self.eps = eps
57 |
58 | def forward(self, input, target):
59 | diff = F.cosine_similarity(input, target, dim=1, eps=self.eps)
60 | loss = 1.0 - diff.mean()
61 |
62 | return loss
63 |
--------------------------------------------------------------------------------
/codes/models/optim/lr_schedules.py:
--------------------------------------------------------------------------------
1 | """ Code adapted from BasicSR: https://github.com/xinntao/BasicSR/blob/master/basicsr/models/lr_scheduler.py
2 | """
3 |
4 | import math
5 |
6 | from torch.optim.lr_scheduler import _LRScheduler
7 |
8 |
9 | def get_position_from_periods(iteration, cumulative_period):
10 | """Get the position from a period list.
11 |
12 | It will return the index of the right-closest number in the period list.
13 | For example, the cumulative_period = [100, 200, 300, 400],
14 | if iteration == 50, return 0;
15 | if iteration == 210, return 2;
16 | if iteration == 300, return 2.
17 |
18 | Args:
19 | iteration (int): Current iteration.
20 | cumulative_period (list[int]): Cumulative period list.
21 |
22 | Returns:
23 | int: The position of the right-closest number in the period list.
24 | """
25 | for i, period in enumerate(cumulative_period):
26 | if iteration <= period:
27 | return i
28 |
29 |
30 | class CosineAnnealingRestartLR(_LRScheduler):
31 | """ Cosine annealing with restarts learning rate scheme.
32 |
33 | An example of config:
34 | periods = [10, 10, 10, 10]
35 | restart_weights = [1, 0.5, 0.5, 0.5]
36 | eta_min=1e-7
37 |
38 |     It has four cycles, each with 10 iterations. At the 10th, 20th and 30th
39 |     iterations, the scheduler restarts with the weights in restart_weights.
40 |
41 | Args:
42 | optimizer (torch.nn.optimizer): Torch optimizer.
43 |         periods (list): Period of each cosine annealing cycle.
44 | restart_weights (list): Restart weights at each restart iteration.
45 | Default: [1].
46 |         eta_min (float): The minimum lr. Default: 0.
47 | last_epoch (int): Used in _LRScheduler. Default: -1.
48 | """
49 |
50 | def __init__(self,
51 | optimizer,
52 | periods,
53 | restart_weights=(1, ),
54 | eta_min=0,
55 | last_epoch=-1):
56 | self.periods = periods
57 | self.restart_weights = restart_weights
58 | self.eta_min = eta_min
59 | assert (len(self.periods) == len(self.restart_weights)
60 | ), 'periods and restart_weights should have the same length.'
61 | self.cumulative_period = [
62 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
63 | ]
64 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
65 |
66 | def get_lr(self):
67 | idx = get_position_from_periods(self.last_epoch,
68 | self.cumulative_period)
69 | current_weight = self.restart_weights[idx]
70 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
71 | current_period = self.periods[idx]
72 |
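   |         # cosine annealing within the current period, scaled by its restart weight:
   |         # lr = eta_min + w * 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T)),
   |         # with t = last_epoch - nearest_restart and T = current_period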
73 | return [
74 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
75 | (1 + math.cos(math.pi * (
76 | (self.last_epoch - nearest_restart) / current_period)))
77 | for base_lr in self.base_lrs
78 | ]
79 |
--------------------------------------------------------------------------------
/codes/models/vsr_model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.optim as optim
5 |
6 | from .base_model import BaseModel
7 | from .networks import define_generator
8 | from .optim import define_criterion, define_lr_schedule
9 | from utils import base_utils, net_utils, data_utils
10 |
11 |
12 | class VSRModel(BaseModel):
13 |     """ A model wrapper for objective (distortion-oriented, non-GAN) video super-resolution
14 | """
15 |
16 | def __init__(self, opt):
17 | super(VSRModel, self).__init__(opt)
18 |
19 | # define network
20 | self.set_networks()
21 |
22 | # config training
23 | if self.is_train:
24 | self.set_criterions()
25 | self.set_optimizers()
26 | self.set_lr_schedules()
27 |
28 | def set_networks(self):
29 | # define generator
30 | self.net_G = define_generator(self.opt)
31 | self.net_G = self.model_to_device(self.net_G)
32 | base_utils.log_info('Generator: {}\n{}'.format(
33 | self.opt['model']['generator']['name'], self.net_G.__str__()))
34 |
35 | # load generator
36 | load_path_G = self.opt['model']['generator'].get('load_path')
37 | if load_path_G is not None:
38 | self.load_network(self.net_G, load_path_G)
39 | base_utils.log_info(f'Load generator from: {load_path_G}')
40 |
41 | def set_criterions(self):
42 | # pixel criterion
43 | self.pix_crit = define_criterion(
44 | self.opt['train'].get('pixel_crit'))
45 |
46 | # warping criterion
47 | self.warp_crit = define_criterion(
48 | self.opt['train'].get('warping_crit'))
49 |
50 | def set_optimizers(self):
51 | self.optim_G = optim.Adam(
52 | self.net_G.parameters(),
53 | lr=self.opt['train']['generator']['lr'],
54 | weight_decay=self.opt['train']['generator'].get('weight_decay', 0),
55 | betas=self.opt['train']['generator'].get('betas', (0.9, 0.999)))
56 |
57 | def set_lr_schedules(self):
58 | self.sched_G = define_lr_schedule(
59 | self.opt['train']['generator'].get('lr_schedule'), self.optim_G)
60 |
61 | def train(self):
62 | # === initialize === #
63 | self.net_G.train()
64 | self.optim_G.zero_grad()
65 |
66 | # === forward net_G === #
67 | net_G_output_dict = self.net_G(self.lr_data)
68 | self.hr_data = net_G_output_dict['hr_data']
69 |
70 | # === optimize net_G === #
71 | loss_G = 0
72 | self.log_dict = OrderedDict()
73 |
74 | # pixel loss
75 | pix_w = self.opt['train']['pixel_crit'].get('weight', 1.0)
76 | loss_pix_G = pix_w * self.pix_crit(self.hr_data, self.gt_data)
77 | loss_G += loss_pix_G
78 | self.log_dict['l_pix_G'] = loss_pix_G.item()
79 |
80 | # warping loss
81 | if self.warp_crit is not None:
82 | # warp lr_prev according to lr_flow
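   |             # the warping loss compares the flow-warped previous LR frame with the current
   |             # LR frame, supervising the flow estimator without ground-truth flow (as in FRVSR)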
83 | lr_curr = net_G_output_dict['lr_curr']
84 | lr_prev = net_G_output_dict['lr_prev']
85 | lr_flow = net_G_output_dict['lr_flow']
86 | lr_warp = net_utils.backward_warp(lr_prev, lr_flow)
87 |
88 | warp_w = self.opt['train']['warping_crit'].get('weight', 1.0)
89 | loss_warp_G = warp_w * self.warp_crit(lr_warp, lr_curr)
90 | loss_G += loss_warp_G
91 | self.log_dict['l_warp_G'] = loss_warp_G.item()
92 |
93 | # optimize
94 | loss_G.backward()
95 | self.optim_G.step()
96 |
97 | def infer(self):
98 | """ Infer the `lr_data` sequence
99 |
100 | :return: np.ndarray sequence in type [uint8] and shape [thwc]
101 | """
102 |
103 | lr_data = self.lr_data
104 |
105 | # temporal padding
106 | lr_data, n_pad_front = self.pad_sequence(lr_data)
107 |
108 | # infer
109 | self.net_G.eval()
110 | hr_seq = self.net_G(lr_data, self.device)
111 | hr_seq = hr_seq[n_pad_front:]
112 |
113 | return hr_seq
114 |
115 | def save(self, current_iter):
116 | self.save_network(self.net_G, 'G', current_iter)
117 |
--------------------------------------------------------------------------------
/codes/official_metrics/LPIPSmodels/LPIPSsource.txt:
--------------------------------------------------------------------------------
1 | originally from https://github.com/richzhang/PerceptualSimilarity,
2 | @inproceedings{zhang2018perceptual,
3 | title={The Unreasonable Effectiveness of Deep Features as a Perceptual Metric},
4 | author={Zhang, Richard and Isola, Phillip and Efros, Alexei A and Shechtman, Eli and Wang, Oliver},
5 | booktitle={CVPR},
6 | year={2018}
7 | }
--------------------------------------------------------------------------------
/codes/official_metrics/LPIPSmodels/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import LPIPSmodels.util as util
5 | from torch.autograd import Variable
6 | from pdb import set_trace as st
7 | from IPython import embed
8 |
9 | class BaseModel():
10 | def __init__(self):
11 |         pass
12 |
13 | def name(self):
14 | return 'BaseModel'
15 |
16 | def initialize(self, use_gpu=True):
17 | self.use_gpu = use_gpu
18 | self.Tensor = torch.cuda.FloatTensor if self.use_gpu else torch.Tensor
19 | # self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
20 |
21 | def forward(self):
22 | pass
23 |
24 | def get_image_paths(self):
25 | pass
26 |
27 | def optimize_parameters(self):
28 | pass
29 |
30 | def get_current_visuals(self):
31 | return self.input
32 |
33 | def get_current_errors(self):
34 | return {}
35 |
36 | def save(self, label):
37 | pass
38 |
39 | # helper saving function that can be used by subclasses
40 | def save_network(self, network, path, network_label, epoch_label):
41 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
42 | save_path = os.path.join(path, save_filename)
43 | torch.save(network.state_dict(), save_path)
44 |
45 | # helper loading function that can be used by subclasses
46 | def load_network(self, network, network_label, epoch_label):
47 | # embed()
48 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
49 | save_path = os.path.join(self.save_dir, save_filename)
50 | print('Loading network from %s'%save_path)
51 | network.load_state_dict(torch.load(save_path))
52 |
53 |     def update_learning_rate(self):
54 | pass
55 |
56 | def get_image_paths(self):
57 | return self.image_paths
58 |
59 | def save_done(self, flag=False):
60 | np.save(os.path.join(self.save_dir, 'done_flag'),flag)
61 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
62 |
63 |
--------------------------------------------------------------------------------
/codes/official_metrics/LPIPSmodels/networks_basic.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import sys
5 | sys.path.append('..')
6 | sys.path.append('.')
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.init as init
10 | from torch.autograd import Variable
11 | import numpy as np
12 | from pdb import set_trace as st
13 | from skimage import color
14 | from IPython import embed
15 | from . import pretrained_networks as pn
16 |
17 | from . import util
18 |
19 | # Off-the-shelf deep network
20 | class PNet(nn.Module):
21 | '''Pre-trained network with all channels equally weighted by default'''
22 | def __init__(self, pnet_type='vgg', pnet_rand=False, use_gpu=True):
23 | super(PNet, self).__init__()
24 |
25 | self.use_gpu = use_gpu
26 |
27 | self.pnet_type = pnet_type
28 | self.pnet_rand = pnet_rand
29 |
30 | self.shift = torch.autograd.Variable(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1))
31 | self.scale = torch.autograd.Variable(torch.Tensor([.458, .448, .450]).view(1,3,1,1))
32 |
33 | if(self.pnet_type in ['vgg','vgg16']):
34 | self.net = pn.vgg16(pretrained=not self.pnet_rand,requires_grad=False)
35 | elif(self.pnet_type=='alex'):
36 | self.net = pn.alexnet(pretrained=not self.pnet_rand,requires_grad=False)
37 | elif(self.pnet_type[:-2]=='resnet'):
38 | self.net = pn.resnet(pretrained=not self.pnet_rand,requires_grad=False, num=int(self.pnet_type[-2:]))
39 | elif(self.pnet_type=='squeeze'):
40 | self.net = pn.squeezenet(pretrained=not self.pnet_rand,requires_grad=False)
41 |
42 | self.L = self.net.N_slices
43 |
44 | if(use_gpu):
45 | self.net.cuda()
46 | self.shift = self.shift.cuda()
47 | self.scale = self.scale.cuda()
48 |
49 | def forward(self, in0, in1, retPerLayer=False):
50 | in0_sc = (in0 - self.shift.expand_as(in0))/self.scale.expand_as(in0)
51 | in1_sc = (in1 - self.shift.expand_as(in0))/self.scale.expand_as(in0)
52 |
53 | outs0 = self.net.forward(in0_sc)
54 | outs1 = self.net.forward(in1_sc)
55 |
56 | if(retPerLayer):
57 | all_scores = []
58 | for (kk,out0) in enumerate(outs0):
59 | cur_score = (1.-util.cos_sim(outs0[kk],outs1[kk]))
60 | if(kk==0):
61 | val = 1.*cur_score
62 | else:
63 | # val = val + self.lambda_feat_layers[kk]*cur_score
64 | val = val + cur_score
65 | if(retPerLayer):
66 | all_scores+=[cur_score]
67 |
68 | if(retPerLayer):
69 | return (val, all_scores)
70 | else:
71 | return val
72 |
73 | # Learned perceptual metric
74 | class PNetLin(nn.Module):
75 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, use_gpu=True, spatial=False, version='0.1'):
76 | super(PNetLin, self).__init__()
77 |
78 | self.use_gpu = use_gpu
79 | self.pnet_type = pnet_type
80 | self.pnet_tune = pnet_tune
81 | self.pnet_rand = pnet_rand
82 | self.spatial = spatial
83 | self.version = version
84 |
85 | if(self.pnet_type in ['vgg','vgg16']):
86 | net_type = pn.vgg16
87 | self.chns = [64,128,256,512,512]
88 | elif(self.pnet_type=='alex'):
89 | net_type = pn.alexnet
90 | self.chns = [64,192,384,256,256]
91 | elif(self.pnet_type=='squeeze'):
92 | net_type = pn.squeezenet
93 | self.chns = [64,128,256,384,384,512,512]
94 |
95 | if(self.pnet_tune):
96 | self.net = net_type(pretrained=not self.pnet_rand,requires_grad=True)
97 | else:
98 | self.net = [net_type(pretrained=not self.pnet_rand,requires_grad=False),]
99 |
100 | self.lin0 = NetLinLayer(self.chns[0],use_dropout=use_dropout)
101 | self.lin1 = NetLinLayer(self.chns[1],use_dropout=use_dropout)
102 | self.lin2 = NetLinLayer(self.chns[2],use_dropout=use_dropout)
103 | self.lin3 = NetLinLayer(self.chns[3],use_dropout=use_dropout)
104 | self.lin4 = NetLinLayer(self.chns[4],use_dropout=use_dropout)
105 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
106 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
107 | self.lin5 = NetLinLayer(self.chns[5],use_dropout=use_dropout)
108 | self.lin6 = NetLinLayer(self.chns[6],use_dropout=use_dropout)
109 | self.lins+=[self.lin5,self.lin6]
110 |
111 | self.shift = torch.autograd.Variable(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1))
112 | self.scale = torch.autograd.Variable(torch.Tensor([.458, .448, .450]).view(1,3,1,1))
113 |
114 | if(use_gpu):
115 | if(self.pnet_tune):
116 | self.net.cuda()
117 | else:
118 | self.net[0].cuda()
119 | self.shift = self.shift.cuda()
120 | self.scale = self.scale.cuda()
121 | self.lin0.cuda()
122 | self.lin1.cuda()
123 | self.lin2.cuda()
124 | self.lin3.cuda()
125 | self.lin4.cuda()
126 | if(self.pnet_type=='squeeze'):
127 | self.lin5.cuda()
128 | self.lin6.cuda()
129 |
130 | def forward(self, in0, in1):
131 | in0_sc = (in0 - self.shift.expand_as(in0))/self.scale.expand_as(in0)
132 | in1_sc = (in1 - self.shift.expand_as(in0))/self.scale.expand_as(in0)
133 |
134 | if(self.version=='0.0'):
135 | # v0.0 - original release had a bug, where input was not scaled
136 | in0_input = in0
137 | in1_input = in1
138 | else:
139 | # v0.1
140 | in0_input = in0_sc
141 | in1_input = in1_sc
142 |
143 | if(self.pnet_tune):
144 | outs0 = self.net.forward(in0_input)
145 | outs1 = self.net.forward(in1_input)
146 | else:
147 | outs0 = self.net[0].forward(in0_input)
148 | outs1 = self.net[0].forward(in1_input)
149 |
150 | feats0 = {}
151 | feats1 = {}
152 | diffs = [0]*len(outs0)
153 |
154 | for (kk,out0) in enumerate(outs0):
155 | feats0[kk] = util.normalize_tensor(outs0[kk])
156 | feats1[kk] = util.normalize_tensor(outs1[kk])
157 | diffs[kk] = (feats0[kk]-feats1[kk])**2
158 |
159 | if self.spatial:
160 | lin_models = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
161 | if(self.pnet_type=='squeeze'):
162 | lin_models.extend([self.lin5, self.lin6])
163 | res = [lin_models[kk].model(diffs[kk]) for kk in range(len(diffs))]
164 | return res
165 |
166 | val = torch.mean(torch.mean(self.lin0.model(diffs[0]),dim=3),dim=2)
167 | val = val + torch.mean(torch.mean(self.lin1.model(diffs[1]),dim=3),dim=2)
168 | val = val + torch.mean(torch.mean(self.lin2.model(diffs[2]),dim=3),dim=2)
169 | val = val + torch.mean(torch.mean(self.lin3.model(diffs[3]),dim=3),dim=2)
170 | val = val + torch.mean(torch.mean(self.lin4.model(diffs[4]),dim=3),dim=2)
171 | if(self.pnet_type=='squeeze'):
172 | val = val + torch.mean(torch.mean(self.lin5.model(diffs[5]),dim=3),dim=2)
173 | val = val + torch.mean(torch.mean(self.lin6.model(diffs[6]),dim=3),dim=2)
174 |
175 | val = val.view(val.size()[0],val.size()[1],1,1)
176 |
177 | return val
178 |
179 | class Dist2LogitLayer(nn.Module):
180 |     ''' Takes two distances, passes them through 1x1 conv (fc-like) layers, and outputs a value in [0, 1] (if use_sigmoid is True) '''
181 | def __init__(self, chn_mid=32,use_sigmoid=True):
182 | super(Dist2LogitLayer, self).__init__()
183 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
184 | layers += [nn.LeakyReLU(0.2,True),]
185 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
186 | layers += [nn.LeakyReLU(0.2,True),]
187 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
188 | if(use_sigmoid):
189 | layers += [nn.Sigmoid(),]
190 | self.model = nn.Sequential(*layers)
191 |
192 | def forward(self,d0,d1,eps=0.1):
193 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
194 |
195 | class BCERankingLoss(nn.Module):
196 | def __init__(self, use_gpu=True, chn_mid=32):
197 | super(BCERankingLoss, self).__init__()
198 | self.use_gpu = use_gpu
199 | self.net = Dist2LogitLayer(chn_mid=chn_mid)
200 | self.parameters = list(self.net.parameters())
201 | self.loss = torch.nn.BCELoss()
202 | self.model = nn.Sequential(*[self.net])
203 |
204 | if(self.use_gpu):
205 | self.net.cuda()
206 |
207 | def forward(self, d0, d1, judge):
208 | per = (judge+1.)/2.
209 | if(self.use_gpu):
210 | per = per.cuda()
211 | self.logit = self.net.forward(d0,d1)
212 | return self.loss(self.logit, per)
213 |
214 | class NetLinLayer(nn.Module):
215 | ''' A single linear layer which does a 1x1 conv '''
216 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
217 | super(NetLinLayer, self).__init__()
218 |
219 | layers = [nn.Dropout(),] if(use_dropout) else []
220 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
221 | self.model = nn.Sequential(*layers)
222 |
223 |
224 | # L2, DSSIM metrics
225 | class FakeNet(nn.Module):
226 | def __init__(self, use_gpu=True, colorspace='Lab'):
227 | super(FakeNet, self).__init__()
228 | self.use_gpu = use_gpu
229 | self.colorspace=colorspace
230 |
231 | class L2(FakeNet):
232 |
233 | def forward(self, in0, in1):
234 | assert(in0.size()[0]==1) # currently only supports batchSize 1
235 |
236 | if(self.colorspace=='RGB'):
237 | (N,C,X,Y) = in0.size()
238 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
239 | return value
240 | elif(self.colorspace=='Lab'):
241 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
242 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
243 | ret_var = Variable( torch.Tensor((value,) ) )
244 | if(self.use_gpu):
245 | ret_var = ret_var.cuda()
246 | return ret_var
247 |
248 | class DSSIM(FakeNet):
249 |
250 | def forward(self, in0, in1):
251 | assert(in0.size()[0]==1) # currently only supports batchSize 1
252 |
253 | if(self.colorspace=='RGB'):
254 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
255 | elif(self.colorspace=='Lab'):
256 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
257 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
258 | ret_var = Variable( torch.Tensor((value,) ) )
259 | if(self.use_gpu):
260 | ret_var = ret_var.cuda()
261 | return ret_var
262 |
263 | def print_network(net):
264 | num_params = 0
265 | for param in net.parameters():
266 | num_params += param.numel()
267 | print('Network',net)
268 | print('Total number of parameters: %d' % num_params)
269 |
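A minimal usage sketch (not part of this file) of the 1x1 linear layers defined above: a `NetLinLayer` re-weights the channels of a per-layer feature difference, and the result is averaged over the spatial dimensions, mirroring the forward pass at the top of this file. Shapes and the channel count are illustrative.

    import torch

    lin = NetLinLayer(chn_in=64)                  # 1x1 conv, 64 -> 1 channel, no bias
    diff = torch.rand(1, 64, 32, 32)              # squared feature difference of one layer
    val = torch.mean(torch.mean(lin.model(diff), dim=3), dim=2)   # shape (1, 1)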
--------------------------------------------------------------------------------
/codes/official_metrics/LPIPSmodels/pretrained_networks.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torchvision import models
4 | from IPython import embed
5 |
6 | class squeezenet(torch.nn.Module):
7 | def __init__(self, requires_grad=False, pretrained=True):
8 | super(squeezenet, self).__init__()
9 | pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
10 | self.slice1 = torch.nn.Sequential()
11 | self.slice2 = torch.nn.Sequential()
12 | self.slice3 = torch.nn.Sequential()
13 | self.slice4 = torch.nn.Sequential()
14 | self.slice5 = torch.nn.Sequential()
15 | self.slice6 = torch.nn.Sequential()
16 | self.slice7 = torch.nn.Sequential()
17 | self.N_slices = 7
18 | for x in range(2):
19 | self.slice1.add_module(str(x), pretrained_features[x])
20 | for x in range(2,5):
21 | self.slice2.add_module(str(x), pretrained_features[x])
22 | for x in range(5, 8):
23 | self.slice3.add_module(str(x), pretrained_features[x])
24 | for x in range(8, 10):
25 | self.slice4.add_module(str(x), pretrained_features[x])
26 | for x in range(10, 11):
27 | self.slice5.add_module(str(x), pretrained_features[x])
28 | for x in range(11, 12):
29 | self.slice6.add_module(str(x), pretrained_features[x])
30 | for x in range(12, 13):
31 | self.slice7.add_module(str(x), pretrained_features[x])
32 | if not requires_grad:
33 | for param in self.parameters():
34 | param.requires_grad = False
35 |
36 | def forward(self, X):
37 | h = self.slice1(X)
38 | h_relu1 = h
39 | h = self.slice2(h)
40 | h_relu2 = h
41 | h = self.slice3(h)
42 | h_relu3 = h
43 | h = self.slice4(h)
44 | h_relu4 = h
45 | h = self.slice5(h)
46 | h_relu5 = h
47 | h = self.slice6(h)
48 | h_relu6 = h
49 | h = self.slice7(h)
50 | h_relu7 = h
51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53 |
54 | return out
55 |
56 |
57 | class alexnet(torch.nn.Module):
58 | def __init__(self, requires_grad=False, pretrained=True):
59 | super(alexnet, self).__init__()
60 | alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
61 | self.slice1 = torch.nn.Sequential()
62 | self.slice2 = torch.nn.Sequential()
63 | self.slice3 = torch.nn.Sequential()
64 | self.slice4 = torch.nn.Sequential()
65 | self.slice5 = torch.nn.Sequential()
66 | self.N_slices = 5
67 | for x in range(2):
68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69 | for x in range(2, 5):
70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71 | for x in range(5, 8):
72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73 | for x in range(8, 10):
74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75 | for x in range(10, 12):
76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77 | if not requires_grad:
78 | for param in self.parameters():
79 | param.requires_grad = False
80 |
81 | def forward(self, X):
82 | h = self.slice1(X)
83 | h_relu1 = h
84 | h = self.slice2(h)
85 | h_relu2 = h
86 | h = self.slice3(h)
87 | h_relu3 = h
88 | h = self.slice4(h)
89 | h_relu4 = h
90 | h = self.slice5(h)
91 | h_relu5 = h
92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94 |
95 | return out
96 |
97 | class vgg16(torch.nn.Module):
98 | def __init__(self, requires_grad=False, pretrained=True):
99 | super(vgg16, self).__init__()
100 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
101 | self.slice1 = torch.nn.Sequential()
102 | self.slice2 = torch.nn.Sequential()
103 | self.slice3 = torch.nn.Sequential()
104 | self.slice4 = torch.nn.Sequential()
105 | self.slice5 = torch.nn.Sequential()
106 | self.N_slices = 5
107 | for x in range(4):
108 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
109 | for x in range(4, 9):
110 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
111 | for x in range(9, 16):
112 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
113 | for x in range(16, 23):
114 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
115 | for x in range(23, 30):
116 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
117 | if not requires_grad:
118 | for param in self.parameters():
119 | param.requires_grad = False
120 |
121 | def forward(self, X):
122 | h = self.slice1(X)
123 | h_relu1_2 = h
124 | h = self.slice2(h)
125 | h_relu2_2 = h
126 | h = self.slice3(h)
127 | h_relu3_3 = h
128 | h = self.slice4(h)
129 | h_relu4_3 = h
130 | h = self.slice5(h)
131 | h_relu5_3 = h
132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134 |
135 | return out
136 |
137 |
138 |
139 | class resnet(torch.nn.Module):
140 | def __init__(self, requires_grad=False, pretrained=True, num=18):
141 | super(resnet, self).__init__()
142 | if(num==18):
143 | self.net = models.resnet18(pretrained=pretrained)
144 | elif(num==34):
145 | self.net = models.resnet34(pretrained=pretrained)
146 | elif(num==50):
147 | self.net = models.resnet50(pretrained=pretrained)
148 | elif(num==101):
149 | self.net = models.resnet101(pretrained=pretrained)
150 | elif(num==152):
151 | self.net = models.resnet152(pretrained=pretrained)
152 | self.N_slices = 5
153 |
154 | self.conv1 = self.net.conv1
155 | self.bn1 = self.net.bn1
156 | self.relu = self.net.relu
157 | self.maxpool = self.net.maxpool
158 | self.layer1 = self.net.layer1
159 | self.layer2 = self.net.layer2
160 | self.layer3 = self.net.layer3
161 | self.layer4 = self.net.layer4
162 |
163 | def forward(self, X):
164 | h = self.conv1(X)
165 | h = self.bn1(h)
166 | h = self.relu(h)
167 | h_relu1 = h
168 | h = self.maxpool(h)
169 | h = self.layer1(h)
170 | h_conv2 = h
171 | h = self.layer2(h)
172 | h_conv3 = h
173 | h = self.layer3(h)
174 | h_conv4 = h
175 | h = self.layer4(h)
176 | h_conv5 = h
177 |
178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180 |
181 | return out
182 |
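A minimal usage sketch (assuming the torchvision weights can be loaded locally): each wrapper above splits a pretrained backbone into slices and returns the intermediate activations as a namedtuple.

    import torch

    net = vgg16(requires_grad=False, pretrained=True)
    x = torch.rand(1, 3, 64, 64)       # input already normalized as LPIPS expects
    feats = net(x)                     # VggOutputs namedtuple
    print(feats.relu3_3.shape)         # torch.Size([1, 256, 16, 16])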
--------------------------------------------------------------------------------
/codes/official_metrics/LPIPSmodels/v0.1/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/official_metrics/LPIPSmodels/v0.1/alex.pth
--------------------------------------------------------------------------------
/codes/official_metrics/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import argparse
4 |
5 |
6 | if __name__ == '__main__':
7 |     # get args
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--model', '-m', type=str, required=True)
10 | args = parser.parse_args()
11 |
12 | keys = args.model.split('_')
13 | assert keys[0] in ('TecoGAN', 'FRVSR')
14 | assert keys[1] in ('BD', 'BI')
15 |
16 | # set dirs
17 | Vid4_GT_dir = 'data/Vid4/GT'
18 | Vid4_SR_dir = 'results/Vid4/{}'.format(args.model)
19 | Vid4_vids = ['calendar', 'city', 'foliage', 'walk']
20 |
21 | ToS3_GT_dir = 'data/ToS3/GT'
22 | ToS3_SR_dir = 'results/ToS3/{}'.format(args.model)
23 | ToS3_vids = ['bridge', 'face', 'room']
24 |
25 | # evaluate Vid4
26 | if osp.exists(Vid4_SR_dir):
27 | Vid4_GT_lst = [
28 | osp.join(Vid4_GT_dir, vid) for vid in Vid4_vids]
29 | Vid4_SR_lst = [
30 | osp.join(Vid4_SR_dir, vid) for vid in Vid4_vids]
31 | os.system('python codes/official_metrics/metrics.py --output {} --results {} --targets {}'.format(
32 | osp.join(Vid4_SR_dir, 'metric_log'),
33 | ','.join(Vid4_SR_lst),
34 | ','.join(Vid4_GT_lst)))
35 |
36 | # evaluate ToS3
37 | if osp.exists(ToS3_SR_dir):
38 | ToS3_GT_lst = [
39 | osp.join(ToS3_GT_dir, vid) for vid in ToS3_vids]
40 | ToS3_SR_lst = [
41 | osp.join(ToS3_SR_dir, vid) for vid in ToS3_vids]
42 | os.system('python codes/official_metrics/metrics.py --output {} --results {} --targets {}'.format(
43 | osp.join(ToS3_SR_dir, 'metric_log'),
44 | ','.join(ToS3_SR_lst),
45 | ','.join(ToS3_GT_lst)))
46 |
47 |
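Usage note: the first two underscore-separated fields of `--model` must be the method (`TecoGAN` or `FRVSR`) and the degradation (`BD` or `BI`), and the SR results must already exist under `results/Vid4/<model>` and/or `results/ToS3/<model>`. A hypothetical invocation would be `python codes/official_metrics/evaluate.py -m TecoGAN_BD_iter500K`; the exact model tag depends on how the results were saved.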
--------------------------------------------------------------------------------
/codes/official_metrics/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os, sys
4 | import pandas as pd
5 | from LPIPSmodels import util
6 | import LPIPSmodels.dist_model as dm
7 | from skimage.measure import compare_ssim
8 |
9 | from absl import flags
10 | flags.DEFINE_string('output', None, 'the path of output directory')
11 | flags.DEFINE_string('results', None, 'the list of paths of result directory')
12 | flags.DEFINE_string('targets', None, 'the list of paths of target directory')
13 |
14 | FLAGS = flags.FLAGS
15 | FLAGS(sys.argv)
16 |
17 | if(not os.path.exists(FLAGS.output)):
18 | os.mkdir(FLAGS.output)
19 |
20 | # The operation used to print out the configuration
21 | def print_configuration_op(FLAGS):
22 | print('[Configurations]:')
23 | for name, value in FLAGS.flag_values_dict().items():
24 | print('\t%s: %s'%(name, str(value)))
25 | print('End of configuration')
26 | # a custom Logger (defined further below) duplicates stdout into a log file
27 |
28 | def listPNGinDir(dirpath):
29 | filelist = os.listdir(dirpath)
30 | filelist = [_ for _ in filelist if _.endswith(".png")]
31 | filelist = [_ for _ in filelist if not _.startswith("IB")]
32 | filelist = sorted(filelist)
33 | filelist.sort(key=lambda f: int(''.join(list(filter(str.isdigit, f))) or -1))
34 | result = [os.path.join(dirpath,_) for _ in filelist if _.endswith(".png")]
35 | return result
36 |
37 | def _rgb2ycbcr(img, maxVal=255):
38 | ##### color space transform, originally from https://github.com/yhjo09/VSR-DUF #####
39 | O = np.array([[16],
40 | [128],
41 | [128]])
42 | T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
43 | [-0.148223529411765, -0.290992156862745, 0.439215686274510],
44 | [0.439215686274510, -0.367788235294118, -0.071427450980392]])
45 |
46 | if maxVal == 1:
47 | O = O / 255.0
48 |
49 | t = np.reshape(img, (img.shape[0]*img.shape[1], img.shape[2]))
50 | t = np.dot(t, np.transpose(T))
51 | t[:, 0] += O[0]
52 | t[:, 1] += O[1]
53 | t[:, 2] += O[2]
54 | ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
55 |
56 | return ycbcr
57 |
58 | def to_uint8(x, vmin, vmax):
59 | ##### color space transform, originally from https://github.com/yhjo09/VSR-DUF #####
60 | x = x.astype('float32')
61 | x = (x-vmin)/(vmax-vmin)*255 # 0~255
62 | return np.clip(np.round(x), 0, 255)
63 |
64 | def psnr(img_true, img_pred):
65 | ##### PSNR with color space transform, originally from https://github.com/yhjo09/VSR-DUF #####
66 | Y_true = _rgb2ycbcr(to_uint8(img_true, 0, 255), 255)[:,:,0]
67 | Y_pred = _rgb2ycbcr(to_uint8(img_pred, 0, 255), 255)[:,:,0]
68 | diff = Y_true - Y_pred
69 | rmse = np.sqrt(np.mean(np.power(diff,2)))
70 | return 20*np.log10(255./rmse)
71 |
72 | def ssim(img_true, img_pred): ##### SSIM #####
73 | Y_true = _rgb2ycbcr(to_uint8(img_true, 0, 255), 255)[:,:,0]
74 | Y_pred = _rgb2ycbcr(to_uint8(img_pred, 0, 255), 255)[:,:,0]
75 | return compare_ssim(Y_true, Y_pred, data_range=Y_pred.max() - Y_pred.min())
76 |
77 | def crop_8x8( img ):  # despite the name, crops a centered region whose sides are multiples of 32
78 | ori_h = img.shape[0]
79 | ori_w = img.shape[1]
80 |
81 | h = (ori_h//32) * 32
82 | w = (ori_w//32) * 32
83 |
84 | while(h > ori_h - 16):
85 | h = h - 32
86 | while(w > ori_w - 16):
87 | w = w - 32
88 |
89 | y = (ori_h - h) // 2
90 | x = (ori_w - w) // 2
91 | crop_img = img[y:y+h, x:x+w]
92 | return crop_img, y, x
93 |
94 | class Logger(object):
95 | def __init__(self):
96 | self.terminal = sys.stdout
97 | filename = "metricsfile.txt"
98 | self.log = open(os.path.join(FLAGS.output, filename), "a")
99 | def write(self, message):
100 | self.terminal.write(message)
101 | self.log.write(message)
102 | def flush(self):
103 | self.log.flush()
104 |
105 | sys.stdout = Logger()
106 |
107 | print_configuration_op(FLAGS)
108 |
109 | result_list = FLAGS.results.split(',')
110 | target_list = FLAGS.targets.split(',')
111 | folder_n = len(result_list)
112 |
113 |
114 | model = dm.DistModel()
115 | model.initialize(model='net-lin',net='alex',use_gpu=True)
116 |
117 | cutfr = 2
118 | # set maxV = 0.4 before enabling the motion-visualization block at lines 154-166
119 |
120 | keys = ["PSNR", "SSIM", "LPIPS", "tOF", "tLP100"] # keys = ["LPIPS"]
121 | sum_dict = dict.fromkeys(["FrameAvg_"+_ for _ in keys], 0)
122 | len_dict = dict.fromkeys(keys, 0)
123 | avg_dict = dict.fromkeys(["Avg_"+_ for _ in keys], 0)
124 | folder_dict = dict.fromkeys(["FolderAvg_"+_ for _ in keys], 0)
125 |
126 | for folder_i in range(folder_n):
127 | result = listPNGinDir(result_list[folder_i])
128 | target = listPNGinDir(target_list[folder_i])
129 | image_no = len(target)
130 |
131 | list_dict = {}
132 | for key_i in keys:
133 | list_dict[key_i] = []
134 |
135 | for i in range(cutfr, image_no-cutfr):
136 | output_img = cv2.imread(result[i])[:,:,::-1]
137 | target_img = cv2.imread(target[i])[:,:,::-1]
138 | msg = "frame %d, tar %s, out %s, "%(i, str(target_img.shape), str(output_img.shape))
139 |         if( target_img.shape[0] < output_img.shape[0]) or ( target_img.shape[1] < output_img.shape[1]): # output is larger than target (target size not divisible by the scale): crop output
140 | output_img = output_img[:target_img.shape[0],:target_img.shape[1]]
141 |         if( target_img.shape[0] > output_img.shape[0]) or ( target_img.shape[1] > output_img.shape[1]): # target is larger than output: crop target
142 | target_img = target_img[:output_img.shape[0],:output_img.shape[1]]
143 | #print(result[i])
144 |
145 | if "tOF" in keys:# tOF
146 | output_grey = cv2.cvtColor(output_img, cv2.COLOR_RGB2GRAY)
147 | target_grey = cv2.cvtColor(target_img, cv2.COLOR_RGB2GRAY)
148 | if (i > cutfr): # temporal metrics
149 | target_OF=cv2.calcOpticalFlowFarneback(pre_tar_grey, target_grey, None, 0.5, 3, 15, 3, 5, 1.2, 0)
150 | output_OF=cv2.calcOpticalFlowFarneback(pre_out_grey, output_grey, None, 0.5, 3, 15, 3, 5, 1.2, 0)
151 | target_OF, ofy, ofx = crop_8x8(target_OF)
152 | output_OF, ofy, ofx = crop_8x8(output_OF)
153 | OF_diff = np.absolute(target_OF - output_OF)
154 | if False: # for motion visualization
155 | tOFpath = os.path.join(FLAGS.output,"%03d_tOF"%folder_i)
156 | if(not os.path.exists(tOFpath)): os.mkdir(tOFpath)
157 | hsv = np.zeros_like(output_img)
158 | hsv[...,1] = 255
159 | out_path = os.path.join(tOFpath, "flow_%04d.jpg" %i)
160 | mag, ang = cv2.cartToPolar(OF_diff[...,0], OF_diff[...,1])
161 | # print("tar max %02.6f, min %02.6f, avg %02.6f" % (mag.max(), mag.min(), mag.mean()))
162 | mag = np.clip(mag, 0.0, maxV)/maxV
163 | hsv[...,0] = ang*180/np.pi/2
164 | hsv[...,2] = mag * 255.0 #
165 | bgr = cv2.cvtColor(hsv,cv2.COLOR_HSV2BGR)
166 | cv2.imwrite(out_path, bgr)
167 |
168 |                 OF_diff = np.sqrt(np.sum(OF_diff * OF_diff, axis = -1)) # L2 (end-point) norm of the flow difference
169 | # OF_diff, ofy, ofx = crop_8x8(OF_diff)
170 | list_dict["tOF"].append( OF_diff.mean() )
171 | msg += "tOF %02.2f, " %(list_dict["tOF"][-1])
172 |
173 | pre_out_grey = output_grey
174 | pre_tar_grey = target_grey
175 |
176 | target_img, ofy, ofx = crop_8x8(target_img)
177 | output_img, ofy, ofx = crop_8x8(output_img)
178 |
179 | if "PSNR" in keys:# psnr
180 | list_dict["PSNR"].append( psnr(target_img, output_img) )
181 | msg +="psnr %02.2f" %(list_dict["PSNR"][-1])
182 |
183 | if "SSIM" in keys:# ssim
184 | list_dict["SSIM"].append( ssim(target_img, output_img) )
185 | msg +=", ssim %02.2f" %(list_dict["SSIM"][-1])
186 |
187 | if "LPIPS" in keys or "tLP100" in keys:
188 | img0 = util.im2tensor(target_img) # RGB image from [-1,1]
189 | img1 = util.im2tensor(output_img)
190 |
191 | if "LPIPS" in keys: # LPIPS
192 | dist01 = model.forward(img0,img1)
193 | list_dict["LPIPS"].append( dist01[0] )
194 | msg +=", lpips %02.2f" %(dist01[0])
195 |
196 | if "tLP100" in keys and (i > cutfr):# tLP, temporal metrics
197 | dist0t = model.forward(pre_img0, img0)
198 | dist1t = model.forward(pre_img1, img1)
199 | # print ("tardis %f, outdis %f" %(dist0t, dist1t))
200 |                 dist01t = np.absolute(dist0t - dist1t) * 100.0 # scaled by 100 (tLP100)
201 | list_dict["tLP100"].append( dist01t[0] )
202 | msg += ", tLPx100 %02.2f" %(dist01t[0])
203 | pre_img0 = img0
204 | pre_img1 = img1
205 |
206 | msg +=", crop (%d, %d)" %(ofy, ofx)
207 | #print(msg)
208 | mode = 'w' if folder_i==0 else 'a'
209 |
210 | pd_dict = {}
211 | for cur_num_data in keys:
212 | num_data = cur_num_data+"_%02d" % folder_i
213 | cur_list = np.float32(list_dict[cur_num_data])
214 | pd_dict[num_data] = pd.Series(cur_list)
215 |
216 | num_data_sum = cur_list.sum()
217 | num_data_len = cur_list.shape[0]
218 | num_data_mean = num_data_sum / num_data_len
219 | #print("%s, max %02.4f, min %02.4f, avg %02.4f" %
220 | # (num_data, cur_list.max(), cur_list.min(), num_data_mean))
221 |
222 | if folder_i == 0:
223 | avg_dict["Avg_"+cur_num_data] = [num_data_mean]
224 | else:
225 | avg_dict["Avg_"+cur_num_data] += [num_data_mean]
226 |
227 | sum_dict["FrameAvg_"+cur_num_data] += num_data_sum
228 | len_dict[cur_num_data] += num_data_len
229 | folder_dict["FolderAvg_"+cur_num_data] += num_data_mean
230 |
231 | pd.DataFrame(pd_dict).to_csv(os.path.join(FLAGS.output,"metrics.csv"), mode=mode)
232 |
233 | for num_data in keys:
234 | sum_dict["FrameAvg_"+num_data] = pd.Series([sum_dict["FrameAvg_"+num_data] / len_dict[num_data]])
235 | folder_dict["FolderAvg_"+num_data] = pd.Series([folder_dict["FolderAvg_"+num_data] / folder_n])
236 | avg_dict["Avg_"+num_data] = pd.Series(np.float32(avg_dict["Avg_"+num_data]))
237 | print("%s, total frame %d, total avg %02.4f, folder avg %02.4f" %
238 | (num_data, len_dict[num_data], sum_dict["FrameAvg_"+num_data][0], folder_dict["FolderAvg_"+num_data][0]))
239 | pd.DataFrame(avg_dict).to_csv(os.path.join(FLAGS.output,"metrics.csv"), mode='a')
240 | pd.DataFrame(folder_dict).to_csv(os.path.join(FLAGS.output,"metrics.csv"), mode='a')
241 | pd.DataFrame(sum_dict).to_csv(os.path.join(FLAGS.output,"metrics.csv"), mode='a')
242 | print("Finished.")
243 |
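This script is normally launched by evaluate.py with --output/--results/--targets. A minimal sketch (using the same Farneback parameters as above) of the tOF metric for a single pair of consecutive frames: the mean magnitude of the difference between the optical flow of the ground-truth pair and that of the super-resolved pair. Inputs are assumed to be uint8 RGB arrays of identical size.

    import cv2
    import numpy as np

    def tof_single_pair(gt_prev, gt_cur, sr_prev, sr_cur):
        grey = lambda im: cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
        flow_gt = cv2.calcOpticalFlowFarneback(grey(gt_prev), grey(gt_cur), None,
                                               0.5, 3, 15, 3, 5, 1.2, 0)
        flow_sr = cv2.calcOpticalFlowFarneback(grey(sr_prev), grey(sr_cur), None,
                                               0.5, 3, 15, 3, 5, 1.2, 0)
        diff = flow_gt - flow_sr                             # (H, W, 2) flow difference
        return np.sqrt(np.sum(diff * diff, axis=-1)).mean()  # mean end-point error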
--------------------------------------------------------------------------------
/codes/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/utils/__init__.py
--------------------------------------------------------------------------------
/codes/utils/base_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import random
4 | import logging
5 | import argparse
6 | import yaml
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from .dist_utils import init_dist, master_only
12 |
13 |
14 | def parse_agrs():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--exp_dir', type=str, required=True,
17 | help='directory of the current experiment')
18 | parser.add_argument('--mode', type=str, required=True,
19 | help='which mode to use (train|test|profile)')
20 | parser.add_argument('--opt', type=str, required=True,
21 | help='path to the config yaml file')
22 | parser.add_argument('--gpu_ids', type=str, default='-1',
23 | help='GPU index (set -1 to use CPU)')
24 | parser.add_argument('--lr_size', type=str, default='3x256x256',
25 | help='size of the input frame')
26 | parser.add_argument('--test_speed', action='store_true',
27 | help='whether to test speed')
28 | parser.add_argument('--local_rank', default=-1, type=int,
29 | help='local gpu index')
30 | return parser.parse_args()
31 |
32 |
33 | def parse_configs(args):
34 | # load option file
35 | with open(osp.join(args.exp_dir, args.opt), 'r') as f:
36 | opt = yaml.load(f.read(), Loader=yaml.FullLoader)
37 |
38 | opt['exp_dir'] = args.exp_dir
39 | opt['gpu_ids'] = args.gpu_ids
40 | opt['is_train'] = (args.mode == 'train')
41 |
42 | # setup device
43 | setup_device(opt, args.gpu_ids, args.local_rank)
44 |
45 | # setup random seed
46 | setup_random_seed(opt.get('manual_seed', 2021) + opt['rank'])
47 |
48 | return opt
49 |
50 |
51 | def setup_device(opt, gpu_ids, local_rank):
52 | gpu_ids = tuple(map(int, gpu_ids.split(',')))
53 | if gpu_ids[0] < 0 or not torch.cuda.is_available():
54 | # cpu settings
55 | opt.update({
56 | 'dist': False,
57 | 'device': 'cpu',
58 | 'rank': 0
59 | })
60 | else:
61 | # gpu settings
62 | if len(gpu_ids) == 1:
63 | # single gpu
64 | torch.cuda.set_device(0)
65 | opt.update({
66 | 'dist': False,
67 | 'device': 'cuda',
68 | 'rank': 0
69 | })
70 | else:
71 | # multiple gpus
72 | init_dist(opt, local_rank)
73 |
74 | torch.backends.cudnn.benchmark = True
75 | # torch.backends.cudnn.deterministic = True
76 |
77 |
78 | def setup_random_seed(seed):
79 | random.seed(seed)
80 | np.random.seed(seed)
81 | torch.manual_seed(seed)
82 | torch.cuda.manual_seed(seed)
83 | torch.cuda.manual_seed_all(seed)
84 |
85 |
86 | def setup_logger(name):
87 | # create a logger
88 | base_logger = logging.getLogger(name=name)
89 | base_logger.setLevel(logging.INFO)
90 | # create a logging format
91 | formatter = logging.Formatter(fmt='%(asctime)s [%(levelname)s]: %(message)s')
92 | # create a stream handler & set format
93 | sh = logging.StreamHandler()
94 | sh.setFormatter(formatter)
95 | # add handlers
96 | base_logger.addHandler(sh)
97 |
98 |
99 | @master_only
100 | def log_info(msg, logger_name='base'):
101 | logger = logging.getLogger(logger_name)
102 | logger.info(msg)
103 |
104 |
105 | def print_options(opt, logger_name='base', tab=''):
106 | for key, val in opt.items():
107 | if isinstance(val, dict):
108 | log_info('{}{}:'.format(tab, key), logger_name)
109 | print_options(val, logger_name, tab + ' ')
110 | else:
111 | log_info('{}{}: {}'.format(tab, key, val), logger_name)
112 |
113 |
114 | def retrieve_files(dir, suffix='png|jpg'):
115 | """ retrive files with specific suffix under dir and sub-dirs recursively
116 | """
117 |
118 | def retrieve_files_recursively(dir, file_lst):
119 | for d in sorted(os.listdir(dir)):
120 | dd = osp.join(dir, d)
121 |
122 | if osp.isdir(dd):
123 | retrieve_files_recursively(dd, file_lst)
124 | else:
125 | if osp.splitext(d)[-1].lower() in ['.' + s for s in suffix]:
126 | file_lst.append(dd)
127 |
128 | if not dir:
129 | return []
130 |
131 | if isinstance(suffix, str):
132 | suffix = suffix.split('|')
133 |
134 | file_lst = []
135 | retrieve_files_recursively(dir, file_lst)
136 | file_lst.sort()
137 |
138 | return file_lst
139 |
140 |
141 | def setup_paths(opt, mode):
142 |
143 | def setup_ckpt_dir():
144 | ckpt_dir = opt['train'].get('ckpt_dir', '')
145 | if not ckpt_dir: # default dir
146 | ckpt_dir = osp.join(opt['exp_dir'], 'train', 'ckpt')
147 | opt['train']['ckpt_dir'] = ckpt_dir
148 | os.makedirs(ckpt_dir, exist_ok=True)
149 |
150 | def setup_res_dir():
151 | res_dir = opt['test'].get('res_dir', '')
152 | if not res_dir: # default dir
153 | res_dir = osp.join(opt['exp_dir'], 'test', 'results')
154 | opt['test']['res_dir'] = res_dir
155 | os.makedirs(res_dir, exist_ok=True)
156 |
157 | def setup_json_path():
158 | json_dir = opt['test'].get('json_dir', '')
159 | if not json_dir: # default dir
160 | json_dir = osp.join(opt['exp_dir'], 'test', 'metrics')
161 | opt['test']['json_dir'] = json_dir
162 | os.makedirs(json_dir, exist_ok=True)
163 |
164 | def setup_model_path():
165 | load_path = opt['model']['generator'].get('load_path', '')
166 | if not load_path:
167 | raise ValueError(
168 | 'Pretrained generator model is needed for testing')
169 |
170 | # parse models
171 | ckpt_dir, model_idx = osp.split(load_path)
172 | model_idx = osp.splitext(model_idx)[0]
173 | if model_idx == '*':
174 |             # test a series of models  TODO: check validity
175 | start_iter = opt['test']['start_iter']
176 | end_iter = opt['test']['end_iter']
177 | freq = opt['test']['test_freq']
178 | opt['model']['generator']['load_path_lst'] = [
179 | osp.join(ckpt_dir, f'G_iter{iter}.pth')
180 | for iter in range(start_iter, end_iter + 1, freq)]
181 | else:
182 | # test a single model
183 | opt['model']['generator']['load_path_lst'] = [
184 | osp.join(ckpt_dir, f'{model_idx}.pth')]
185 |
186 | if mode == 'train':
187 | setup_ckpt_dir()
188 |
189 | # for validation purpose
190 | for dataset_idx in opt['dataset'].keys():
191 | if 'test' not in dataset_idx:
192 | continue
193 |
194 | if opt['test'].get('save_res', False):
195 | setup_res_dir()
196 |
197 | if opt['test'].get('save_json', False):
198 | setup_json_path()
199 |
200 | elif mode == 'test':
201 | setup_model_path()
202 |
203 | for dataset_idx in opt['dataset'].keys():
204 | if 'test' not in dataset_idx:
205 | continue
206 |
207 | if opt['test'].get('save_res', False):
208 | setup_res_dir()
209 |
210 | if opt['test'].get('save_json', False):
211 | setup_json_path()
212 |
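A minimal sketch of retrieve_files(), assuming a local folder of extracted frames (the path is illustrative):

    frames = retrieve_files('data/Vid4/GT/calendar', suffix='png|jpg')
    # -> a sorted list such as ['data/Vid4/GT/calendar/0001.png', ...]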
--------------------------------------------------------------------------------
/codes/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | from scipy import signal
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 |
10 |
11 | def create_kernel(sigma, ksize=None):
12 | if ksize is None:
13 | ksize = 1 + 2 * int(sigma * 3.0)
14 |
15 | gkern1d = signal.gaussian(ksize, std=sigma).reshape(ksize, 1)
16 | gkern2d = np.outer(gkern1d, gkern1d)
17 | gaussian_kernel = gkern2d / gkern2d.sum()
18 | zero_kernel = np.zeros_like(gaussian_kernel)
19 |
20 | kernel = np.float32([
21 | [gaussian_kernel, zero_kernel, zero_kernel],
22 | [zero_kernel, gaussian_kernel, zero_kernel],
23 | [zero_kernel, zero_kernel, gaussian_kernel]])
24 |
25 | kernel = torch.from_numpy(kernel)
26 |
27 | return kernel
28 |
29 |
30 | def downsample_bd(data, kernel, scale, pad_data):
31 | """
32 | Note:
33 | 1. `data` should be torch.FloatTensor (data range 0~1) in shape [nchw]
34 | 2. `pad_data` should be enabled in model testing
35 | 3. This function is device agnostic, i.e., data/kernel could be on cpu or gpu
36 | """
37 |
38 | if pad_data:
39 | # compute padding params
40 | kernel_h, kernel_w = kernel.shape[-2:]
41 | pad_h, pad_w = kernel_h - 1, kernel_w - 1
42 | pad_t = pad_h // 2
43 | pad_b = pad_h - pad_t
44 | pad_l = pad_w // 2
45 | pad_r = pad_w - pad_l
46 |
47 | # pad data
48 | data = F.pad(data, (pad_l, pad_r, pad_t, pad_b), 'reflect')
49 |
50 | # blur + down sample
51 | data = F.conv2d(data, kernel, stride=scale, bias=None, padding=0)
52 |
53 | return data
54 |
55 |
56 | def rgb_to_ycbcr(img):
57 | """ Coefficients are taken from the official codes of DUF-VSR
58 | This conversion is also the same as that in BasicSR
59 |
60 | Parameters:
61 | :param img: rgb image in type np.uint8
62 | :return: ycbcr image in type np.uint8
63 | """
64 |
65 | T = np.array([
66 | [0.256788235294118, -0.148223529411765, 0.439215686274510],
67 | [0.504129411764706, -0.290992156862745, -0.367788235294118],
68 | [0.097905882352941, 0.439215686274510, -0.071427450980392],
69 | ], dtype=np.float64)
70 |
71 | O = np.array([16, 128, 128], dtype=np.float64)
72 |
73 | img = img.astype(np.float64)
74 | res = np.matmul(img, T) + O
75 | res = res.clip(0, 255).round().astype(np.uint8)
76 |
77 | return res
78 |
79 |
80 | def float32_to_uint8(inputs):
81 | """ Convert np.float32 array to np.uint8
82 |
83 | Parameters:
84 | :param input: np.float32, (NT)CHW, [0, 1]
85 | :return: np.uint8, (NT)CHW, [0, 255]
86 | """
87 | return np.uint8(np.clip(np.round(inputs * 255), 0, 255))
88 |
89 |
90 | def save_sequence(seq_dir, seq_data, frm_idx_lst=None, to_bgr=False):
91 | """ Save each frame of a sequence to .png image in seq_dir
92 |
93 | Parameters:
94 | :param seq_dir: dir to save results
95 | :param seq_data: sequence with shape thwc|uint8
96 | :param frm_idx_lst: specify filename for each frame to be saved
97 | :param to_bgr: whether to flip color channels
98 | """
99 |
100 | if to_bgr:
101 | seq_data = seq_data[..., ::-1] # rgb2bgr
102 |
103 |     # use default frm_idx_lst if not specified
104 | tot_frm = len(seq_data)
105 | if frm_idx_lst is None:
106 | frm_idx_lst = ['{:04d}.png'.format(i) for i in range(tot_frm)]
107 |
108 | # save for each frame
109 | os.makedirs(seq_dir, exist_ok=True)
110 | for i in range(tot_frm):
111 | cv2.imwrite(osp.join(seq_dir, frm_idx_lst[i]), seq_data[i])
112 |
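A minimal sketch of the BD degradation implemented above (Gaussian blur followed by a strided convolution), using the same sigma as the BD configs in this repo (sigma = 1.5) at 4x scale:

    import torch

    kernel = create_kernel(sigma=1.5)            # shape (3, 3, 9, 9)
    hr = torch.rand(1, 3, 128, 128)              # HR frames in [0, 1], nchw
    lr = downsample_bd(hr, kernel, scale=4, pad_data=True)
    print(lr.shape)                              # torch.Size([1, 3, 32, 32])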
--------------------------------------------------------------------------------
/codes/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | import torch.distributed as dist
5 | import torch.multiprocessing as mp
6 |
7 |
8 | def init_dist(opt, local_rank):
9 | """ Adopted from BasicSR
10 | """
11 | if mp.get_start_method(allow_none=True) is None:
12 | mp.set_start_method('spawn')
13 | torch.cuda.set_device(local_rank)
14 | dist.init_process_group(backend='nccl')
15 |
16 | rank, world_size = get_dist_info()
17 |
18 | opt.update({
19 | 'dist': True,
20 | 'device': 'cuda',
21 | 'local_rank': local_rank,
22 | 'world_size': world_size,
23 | 'rank': rank
24 | })
25 |
26 |
27 | def get_dist_info():
28 | """ Adopted from BasicSR
29 | """
30 | if dist.is_available():
31 | initialized = dist.is_initialized()
32 | else:
33 | initialized = False
34 |
35 | if initialized:
36 | rank = dist.get_rank()
37 | world_size = dist.get_world_size()
38 | else:
39 | rank = 0
40 | world_size = 1
41 |
42 | return rank, world_size
43 |
44 |
45 | def master_only(func):
46 | """ Adopted from BasicSR
47 | """
48 | @functools.wraps(func)
49 | def wrapper(*args, **kwargs):
50 | rank, _ = get_dist_info()
51 | if rank == 0:
52 | return func(*args, **kwargs)
53 |
54 | return wrapper
55 |
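A minimal sketch of the master_only decorator defined above: the wrapped function runs only in the rank-0 process, which is how logging and checkpointing are guarded elsewhere in this repo.

    @master_only
    def log_once(msg):
        print(msg)   # executed only on rank 0; all other ranks return None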
--------------------------------------------------------------------------------
/codes/utils/net_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | # -------------------- utility functions -------------------- #
9 | def initialize_weights(net_l, init_type='kaiming', scale=1):
10 | """ Modify from BasicSR/MMSR
11 | """
12 |
13 | if not isinstance(net_l, list):
14 | net_l = [net_l]
15 |
16 | for net in net_l:
17 | for m in net.modules():
18 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
19 | if init_type == 'xavier':
20 | nn.init.xavier_uniform_(m.weight)
21 | elif init_type == 'kaiming':
22 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
23 | else:
24 | raise NotImplementedError(init_type)
25 |
26 | m.weight.data *= scale # to stabilize training
27 |
28 | if m.bias is not None:
29 | nn.init.constant_(m.bias.data, 0)
30 |
31 | elif isinstance(m, nn.BatchNorm2d):
32 | nn.init.constant_(m.weight.data, 1)
33 | nn.init.constant_(m.bias.data, 0)
34 |
35 |
36 | def space_to_depth(x, scale):
37 | """ Equivalent to tf.space_to_depth()
38 | """
39 |
40 | n, c, in_h, in_w = x.size()
41 | out_h, out_w = in_h // scale, in_w // scale
42 |
43 | x_reshaped = x.reshape(n, c, out_h, scale, out_w, scale)
44 | x_reshaped = x_reshaped.permute(0, 3, 5, 1, 2, 4)
45 | output = x_reshaped.reshape(n, scale * scale * c, out_h, out_w)
46 |
47 | return output
48 |
49 |
50 | def backward_warp(x, flow, mode='bilinear', padding_mode='border'):
51 | """ Backward warp `x` according to `flow`
52 |
53 | Both x and flow are pytorch tensor in shape `nchw` and `n2hw`
54 |
55 | Reference:
56 | https://github.com/sniklaus/pytorch-spynet/blob/master/run.py#L41
57 | """
58 |
59 | n, c, h, w = x.size()
60 |
61 | # create mesh grid
62 | iu = torch.linspace(-1.0, 1.0, w).view(1, 1, 1, w).expand(n, -1, h, -1)
63 | iv = torch.linspace(-1.0, 1.0, h).view(1, 1, h, 1).expand(n, -1, -1, w)
64 | grid = torch.cat([iu, iv], 1).to(flow.device)
65 |
66 | # normalize flow to [-1, 1]
67 | flow = torch.cat([
68 | flow[:, 0:1, ...] / ((w - 1.0) / 2.0),
69 | flow[:, 1:2, ...] / ((h - 1.0) / 2.0)], dim=1)
70 |
71 | # add flow to grid and reshape to nhw2
72 | grid = (grid + flow).permute(0, 2, 3, 1)
73 |
74 | # bilinear sampling
75 | # Note: `align_corners` is set to `True` by default for PyTorch version < 1.4.0
76 | if int(''.join(torch.__version__.split('.')[:2])) >= 14:
77 | output = F.grid_sample(
78 | x, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
79 | else:
80 | output = F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode)
81 |
82 | return output
83 |
84 |
85 | def get_upsampling_func(scale=4, degradation='BI'):
86 | if degradation == 'BI':
87 | upsample_func = functools.partial(
88 | F.interpolate, scale_factor=scale, mode='bilinear',
89 | align_corners=False)
90 |
91 | elif degradation == 'BD':
92 | upsample_func = BicubicUpsampler(scale_factor=scale)
93 |
94 | else:
95 | raise ValueError(f'Unrecognized degradation type: {degradation}')
96 |
97 | return upsample_func
98 |
99 |
100 | # --------------------- utility classes --------------------- #
101 | class BicubicUpsampler(nn.Module):
102 | """ Bicubic upsampling function with similar behavior to that in TecoGAN-Tensorflow
103 |
104 | Note:
105 | This function is different from torch.nn.functional.interpolate and matlab's imresize
106 | in terms of the bicubic kernel and the sampling strategy
107 |
108 | References:
109 | http://verona.fi-p.unam.mx/boris/practicas/CubConvInterp.pdf
110 | https://stackoverflow.com/questions/26823140/imresize-trying-to-understand-the-bicubic-interpolation
111 | """
112 |
113 | def __init__(self, scale_factor, a=-0.75):
114 | super(BicubicUpsampler, self).__init__()
115 |
116 | # calculate weights (according to Eq.(6) in the reference paper)
117 | cubic = torch.FloatTensor([
118 | [0, a, -2*a, a],
119 | [1, 0, -(a + 3), a + 2],
120 | [0, -a, (2*a + 3), -(a + 2)],
121 | [0, 0, a, -a]
122 | ])
123 |
124 | kernels = [
125 | torch.matmul(cubic, torch.FloatTensor([1, s, s**2, s**3]))
126 | for s in [1.0*d/scale_factor for d in range(scale_factor)]
127 | ] # s = x - floor(x)
128 |
129 | # register parameters
130 | self.scale_factor = scale_factor
131 | self.register_buffer('kernels', torch.stack(kernels)) # size: (f, 4)
132 |
133 | def forward(self, input):
134 | n, c, h, w = input.size()
135 | f = self.scale_factor
136 |
137 | # merge n&c
138 | input = input.reshape(n*c, 1, h, w)
139 |
140 | # pad input (left, right, top, bottom)
141 | input = F.pad(input, (1, 2, 1, 2), mode='replicate')
142 |
143 | # calculate output (vertical expansion)
144 | kernel_h = self.kernels.view(f, 1, 4, 1)
145 | output = F.conv2d(input, kernel_h, stride=1, padding=0)
146 | output = output.permute(0, 2, 1, 3).reshape(n*c, 1, f*h, w + 3)
147 |
148 | # calculate output (horizontal expansion)
149 | kernel_w = self.kernels.view(f, 1, 1, 4)
150 | output = F.conv2d(output, kernel_w, stride=1, padding=0)
151 | output = output.permute(0, 2, 3, 1).reshape(n*c, 1, f*h, f*w)
152 |
153 | # split n&c
154 | output = output.reshape(n, c, f*h, f*w)
155 |
156 | return output
157 |
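A minimal sanity-check sketch for backward_warp(): with an all-zero flow the warp reduces to an (approximately) identity mapping.

    import torch

    x = torch.rand(1, 3, 32, 32)
    flow = torch.zeros(1, 2, 32, 32)
    out = backward_warp(x, flow)
    print(torch.allclose(out, x, atol=1e-5))   # True, up to interpolation error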
--------------------------------------------------------------------------------
/data/meta/REDS/test_list.txt:
--------------------------------------------------------------------------------
1 | 000
2 | 011
3 | 015
4 | 020
5 |
--------------------------------------------------------------------------------
/data/meta/REDS/train_list.txt:
--------------------------------------------------------------------------------
1 | 001
2 | 002
3 | 003
4 | 004
5 | 005
6 | 006
7 | 007
8 | 008
9 | 009
10 | 010
11 | 012
12 | 013
13 | 014
14 | 016
15 | 017
16 | 018
17 | 019
18 | 021
19 | 022
20 | 023
21 | 024
22 | 025
23 | 026
24 | 027
25 | 028
26 | 029
27 | 030
28 | 031
29 | 032
30 | 033
31 | 034
32 | 035
33 | 036
34 | 037
35 | 038
36 | 039
37 | 040
38 | 041
39 | 042
40 | 043
41 | 044
42 | 045
43 | 046
44 | 047
45 | 048
46 | 049
47 | 050
48 | 051
49 | 052
50 | 053
51 | 054
52 | 055
53 | 056
54 | 057
55 | 058
56 | 059
57 | 060
58 | 061
59 | 062
60 | 063
61 | 064
62 | 065
63 | 066
64 | 067
65 | 068
66 | 069
67 | 070
68 | 071
69 | 072
70 | 073
71 | 074
72 | 075
73 | 076
74 | 077
75 | 078
76 | 079
77 | 080
78 | 081
79 | 082
80 | 083
81 | 084
82 | 085
83 | 086
84 | 087
85 | 088
86 | 089
87 | 090
88 | 091
89 | 092
90 | 093
91 | 094
92 | 095
93 | 096
94 | 097
95 | 098
96 | 099
97 | 100
98 | 101
99 | 102
100 | 103
101 | 104
102 | 105
103 | 106
104 | 107
105 | 108
106 | 109
107 | 110
108 | 111
109 | 112
110 | 113
111 | 114
112 | 115
113 | 116
114 | 117
115 | 118
116 | 119
117 | 120
118 | 121
119 | 122
120 | 123
121 | 124
122 | 125
123 | 126
124 | 127
125 | 128
126 | 129
127 | 130
128 | 131
129 | 132
130 | 133
131 | 134
132 | 135
133 | 136
134 | 137
135 | 138
136 | 139
137 | 140
138 | 141
139 | 142
140 | 143
141 | 144
142 | 145
143 | 146
144 | 147
145 | 148
146 | 149
147 | 150
148 | 151
149 | 152
150 | 153
151 | 154
152 | 155
153 | 156
154 | 157
155 | 158
156 | 159
157 | 160
158 | 161
159 | 162
160 | 163
161 | 164
162 | 165
163 | 166
164 | 167
165 | 168
166 | 169
167 | 170
168 | 171
169 | 172
170 | 173
171 | 174
172 | 175
173 | 176
174 | 177
175 | 178
176 | 179
177 | 180
178 | 181
179 | 182
180 | 183
181 | 184
182 | 185
183 | 186
184 | 187
185 | 188
186 | 189
187 | 190
188 | 191
189 | 192
190 | 193
191 | 194
192 | 195
193 | 196
194 | 197
195 | 198
196 | 199
197 | 200
198 | 201
199 | 202
200 | 203
201 | 204
202 | 205
203 | 206
204 | 207
205 | 208
206 | 209
207 | 210
208 | 211
209 | 212
210 | 213
211 | 214
212 | 215
213 | 216
214 | 217
215 | 218
216 | 219
217 | 220
218 | 221
219 | 222
220 | 223
221 | 224
222 | 225
223 | 226
224 | 227
225 | 228
226 | 229
227 | 230
228 | 231
229 | 232
230 | 233
231 | 234
232 | 235
233 | 236
234 | 237
235 | 238
236 | 239
237 | 240
238 | 241
239 | 242
240 | 243
241 | 244
242 | 245
243 | 246
244 | 247
245 | 248
246 | 249
247 | 250
248 | 251
249 | 252
250 | 253
251 | 254
252 | 255
253 | 256
254 | 257
255 | 258
256 | 259
257 | 260
258 | 261
259 | 262
260 | 263
261 | 264
262 | 265
263 | 266
264 | 267
265 | 268
266 | 269
267 |
--------------------------------------------------------------------------------
/data/put_data_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/data/put_data_here
--------------------------------------------------------------------------------
/experiments_BD/FRVSR/FRVSR_REDS_2xSR_2GPU/test.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 2
3 | manual_seed: 0
4 | verbose: false
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | test:
14 | name: REDS
15 | gt_seq_dir: data/REDS/Raw
16 | lr_seq_dir: ~
17 | filter_list: ['000', '011', '015', '020']
18 | num_worker_per_gpu: 4
19 | pin_memory: true
20 |
21 |
22 | # model configs
23 | model:
24 | name: FRVSR
25 |
26 | generator:
27 | name: FRNet # frame-recurrent network
28 | in_nc: 3
29 | out_nc: 3
30 | nf: 64
31 | nb: 10
32 |
33 | load_path: pretrained_models/FRVSR_2x_BD_REDS_iter400K.pth
34 |
35 |
36 | # validation configs
37 | test:
38 | # whether to save the generated SR results
39 | save_res: false
40 | res_dir: ~ # use default dir
41 |
42 | # whether to save the test results in a json file
43 | save_json: false
44 | json_dir: ~ # use default dir
45 |
46 | padding_mode: reflect
47 | num_pad_front: 5
48 |
49 |
50 | # metric configs
51 | metric:
52 | PSNR:
53 | colorspace: y
54 |
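These YAML files are read by parse_configs() in codes/utils/base_utils.py; a minimal sketch of what the config above becomes in memory:

    import yaml

    with open('experiments_BD/FRVSR/FRVSR_REDS_2xSR_2GPU/test.yml', 'r') as f:
        opt = yaml.load(f.read(), Loader=yaml.FullLoader)

    print(opt['scale'])                            # 2
    print(opt['dataset']['degradation']['sigma'])  # 1.5
    print(opt['model']['generator']['name'])       # 'FRNet'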
--------------------------------------------------------------------------------
/experiments_BD/FRVSR/FRVSR_REDS_2xSR_2GPU/train.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 2
3 | manual_seed: 0
4 | verbose: true
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | train:
14 | name: REDS
15 | seq_dir: data/REDS/GT.lmdb
16 | filter_file: data/meta/REDS/train_list.txt
17 | data_type: rgb
18 | crop_size: 128
19 | batch_size_per_gpu: 2
20 | num_worker_per_gpu: 3
21 | pin_memory: true
22 |
23 | test:
24 | name: REDS
25 | gt_seq_dir: data/REDS/Raw
26 | lr_seq_dir: ~
27 | filter_list: ['000', '011', '015', '020']
28 | num_worker_per_gpu: 4
29 | pin_memory: true
30 |
31 |
32 | # model configs
33 | model:
34 | name: FRVSR
35 |
36 | generator:
37 | name: FRNet # frame-recurrent network
38 | in_nc: 3
39 | out_nc: 3
40 | nf: 64
41 | nb: 10
42 |
43 | load_path: ~
44 |
45 |
46 | # training settings
47 | train:
48 | tempo_extent: 10
49 |
50 | start_iter: 0
51 | total_iter: 400000
52 |
53 | # configs for generator
54 | generator:
55 | lr: !!float 1e-4
56 | lr_schedule:
57 | type: MultiStepLR
58 | milestones: [150000, 300000]
59 | gamma: 0.5
60 | betas: [0.9, 0.999]
61 |
62 | # other settings
63 | moving_first_frame: true
64 | moving_factor: 0.7
65 |
66 | # criterions
67 | pixel_crit:
68 | type: CB
69 | weight: 1
70 | reduction: mean
71 |
72 | warping_crit:
73 | type: CB
74 | weight: 1
75 | reduction: mean
76 |
77 |
78 | # validation configs
79 | test:
80 | test_freq: 10000
81 |
82 | # whether to save the generated SR results
83 | save_res: false
84 | res_dir: ~ # use default dir
85 |
86 | # whether to save the test results in a json file
87 | save_json: true
88 | json_dir: ~ # use default dir
89 |
90 | padding_mode: reflect
91 | num_pad_front: 5
92 |
93 |
94 | # metric configs
95 | metric:
96 | PSNR:
97 | colorspace: y
98 |
99 |
100 | # logger configs
101 | logger:
102 | log_freq: 100
103 | decay: 0.99
104 | ckpt_freq: 10000
105 |
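A minimal sketch of the generator learning-rate schedule declared above (MultiStepLR, lr 1e-4, milestones at 150k/300k, gamma 0.5), assuming one scheduler step per iteration and using torch's built-in MultiStepLR for illustration (the repo's own implementation lives in codes/models/optim/lr_schedules.py):

    import torch

    net = torch.nn.Conv2d(3, 3, 3)
    optim = torch.optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999))
    sched = torch.optim.lr_scheduler.MultiStepLR(
        optim, milestones=[150000, 300000], gamma=0.5)
    # iterations [0, 150k): lr = 1e-4;  [150k, 300k): 5e-5;  [300k, 400k]: 2.5e-5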
--------------------------------------------------------------------------------
/experiments_BD/FRVSR/FRVSR_REDS_4xSR_2GPU/test.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: false
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | test:
14 | name: REDS
15 | gt_seq_dir: data/REDS/Raw
16 | lr_seq_dir: ~
17 | filter_list: ['000', '011', '015', '020']
18 | num_worker_per_gpu: 4
19 | pin_memory: true
20 |
21 |
22 | # model configs
23 | model:
24 | name: FRVSR
25 |
26 | generator:
27 | name: FRNet # frame-recurrent network
28 | in_nc: 3
29 | out_nc: 3
30 | nf: 64
31 | nb: 10
32 |
33 | load_path: pretrained_models/FRVSR_4x_BD_REDS_iter400K.pth
34 |
35 |
36 | # validation configs
37 | test:
38 | # whether to save the generated SR results
39 | save_res: true
40 | res_dir: ~ # use default dir
41 |
42 | padding_mode: reflect
43 | num_pad_front: 5
44 |
--------------------------------------------------------------------------------
/experiments_BD/FRVSR/FRVSR_REDS_4xSR_2GPU/train.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: true
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | train:
14 | name: REDS
15 | seq_dir: data/REDS/GT.lmdb
16 | filter_file: data/meta/REDS/train_list.txt
17 | data_type: rgb
18 | crop_size: 128
19 | batch_size_per_gpu: 2
20 | num_worker_per_gpu: 3
21 | pin_memory: true
22 |
23 | test:
24 | name: REDS
25 | gt_seq_dir: data/REDS/Raw
26 | lr_seq_dir: ~
27 | filter_list: ['000', '011', '015', '020']
28 | num_worker_per_gpu: 4
29 | pin_memory: true
30 |
31 |
32 | # model configs
33 | model:
34 | name: FRVSR
35 |
36 | generator:
37 | name: FRNet # frame-recurrent network
38 | in_nc: 3
39 | out_nc: 3
40 | nf: 64
41 | nb: 10
42 |
43 | load_path: ~
44 |
45 |
46 | # training settings
47 | train:
48 | tempo_extent: 10
49 |
50 | start_iter: 0
51 | total_iter: 400000
52 |
53 | # configs for generator
54 | generator:
55 | lr: !!float 1e-4
56 | lr_schedule:
57 | type: MultiStepLR
58 | milestones: [150000, 300000]
59 | gamma: 0.5
60 | betas: [0.9, 0.999]
61 |
62 | # other settings
63 | moving_first_frame: true
64 | moving_factor: 0.7
65 |
66 | # criterions
67 | pixel_crit:
68 | type: CB
69 | weight: 1
70 | reduction: mean
71 |
72 | warping_crit:
73 | type: CB
74 | weight: 1
75 | reduction: mean
76 |
77 |
78 | # validation configs
79 | test:
80 | test_freq: 10000
81 |
82 | # whether to save the generated SR results
83 | save_res: false
84 | res_dir: ~ # use default dir
85 |
86 | # whether to save the test results in a json file
87 | save_json: true
88 | json_dir: ~ # use default dir
89 |
90 | padding_mode: reflect
91 | num_pad_front: 5
92 |
93 |
94 | # metric configs
95 | metric:
96 | PSNR:
97 | colorspace: y
98 |
99 |
100 | # logger configs
101 | logger:
102 | log_freq: 100
103 | decay: 0.99
104 | ckpt_freq: 20000
105 |
--------------------------------------------------------------------------------
/experiments_BD/FRVSR/FRVSR_VimeoTecoGAN_4xSR_2GPU/test.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: false
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | test1:
14 | name: Vid4
15 | gt_seq_dir: data/Vid4/GT
16 | lr_seq_dir: data/Vid4/Gaussian4xLR
17 | num_worker_per_gpu: 3
18 | pin_memory: true
19 |
20 | test2:
21 | name: ToS3
22 | gt_seq_dir: data/ToS3/GT
23 | lr_seq_dir: data/ToS3/Gaussian4xLR
24 | num_worker_per_gpu: 3
25 | pin_memory: true
26 |
27 |
28 | # model configs
29 | model:
30 | name: FRVSR
31 |
32 | generator:
33 | name: FRNet # frame-recurrent network
34 | in_nc: 3
35 | out_nc: 3
36 | nf: 64
37 | nb: 10
38 | load_path: pretrained_models/FRVSR_4x_BD_Vimeo_iter400K.pth
39 |
40 |
41 | # test configs
42 | test:
43 | # whether to save the SR results
44 | save_res: true
45 | res_dir: results
46 |
47 | # temporal padding
48 | padding_mode: reflect
49 | num_pad_front: 5
50 |
--------------------------------------------------------------------------------
/experiments_BD/FRVSR/FRVSR_VimeoTecoGAN_4xSR_2GPU/train.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: true
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | train:
14 | name: VimeoTecoGAN
15 | seq_dir: data/VimeoTecoGAN/GT.lmdb
16 | filter_file: ~
17 | data_type: rgb
18 | crop_size: 128
19 | batch_size_per_gpu: 2
20 | num_worker_per_gpu: 3
21 | pin_memory: true
22 |
23 | test:
24 | name: Vid4
25 | gt_seq_dir: data/Vid4/GT
26 | lr_seq_dir: data/Vid4/Gaussian4xLR
27 | filter_file: ~
28 | num_worker_per_gpu: 4
29 | pin_memory: true
30 |
31 |
32 | # model configs
33 | model:
34 | name: FRVSR
35 |
36 | generator:
37 | name: FRNet # frame-recurrent network
38 | in_nc: 3
39 | out_nc: 3
40 | nf: 64
41 | nb: 10
42 | load_path: ~
43 |
44 |
45 | # training settings
46 | train:
47 | tempo_extent: 10
48 |
49 | start_iter: 0
50 | total_iter: 400000
51 |
52 | # configs for generator
53 | generator:
54 | lr: !!float 1e-4
55 | lr_schedule:
56 | type: MultiStepLR
57 | milestones: [150000, 300000]
58 | gamma: 0.5
59 | betas: [0.9, 0.999]
60 |
61 | # other settings
62 | moving_first_frame: true
63 | moving_factor: 0.7
64 |
65 | # criterions
66 | pixel_crit:
67 | type: CB
68 | weight: 1
69 | reduction: mean
70 |
71 | warping_crit:
72 | type: CB
73 | weight: 1
74 | reduction: mean
75 |
76 |
77 | # validation configs
78 | test:
79 | test_freq: 10000
80 |
81 | # whether to save the generated SR results
82 | save_res: false
83 | res_dir: ~ # use default dir
84 |
85 | # whether to save the test results in a json file
86 | save_json: true
87 | json_dir: ~ # use default dir
88 |
89 | padding_mode: reflect
90 | num_pad_front: 5
91 |
92 |
93 | # metric configs
94 | metric:
95 | PSNR:
96 | colorspace: y
97 |
98 |
99 | # logger configs
100 | logger:
101 | log_freq: 100
102 | decay: 0.99
103 | ckpt_freq: 20000
104 |
--------------------------------------------------------------------------------
/experiments_BD/TecoGAN/TecoGAN_REDS_2xSR_2GPU/test.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 2
3 | manual_seed: 0
4 | verbose: true
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | test:
14 | name: REDS
15 | gt_seq_dir: data/REDS/Raw
16 | lr_seq_dir: ~
17 | filter_list: ['000', '011', '015', '020']
18 | num_worker_per_gpu: 4
19 | pin_memory: true
20 |
21 |
22 | # model configs
23 | model:
24 | name: TecoGAN
25 |
26 | generator:
27 | name: FRNet # frame-recurrent network
28 | in_nc: 3
29 | out_nc: 3
30 | nf: 64
31 | nb: 10
32 |
33 | load_path: pretrained_models/TecoGAN_2x_BD_REDS_iter500K.pth
34 |
35 |
36 | # validation configs
37 | test:
38 | test_freq: 10000
39 |
40 | # whether to save the generated SR results
41 | save_res: true
42 | res_dir: ~ # use default dir
43 |
44 | # whether to save the test results in a json file
45 | save_json: false
46 | json_dir: ~ # use default dir
47 |
48 | padding_mode: reflect
49 | num_pad_front: 5
50 |
51 |
52 | # metric configs
53 | metric:
54 | PSNR:
55 | colorspace: y
56 |
57 | LPIPS:
58 | model: net-lin
59 | net: alex
60 | colorspace: rgb
61 | spatial: false
62 | version: 0.1
63 |
64 | tOF:
65 | colorspace: y
66 |
--------------------------------------------------------------------------------
/experiments_BD/TecoGAN/TecoGAN_REDS_2xSR_2GPU/train.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 2
3 | manual_seed: 0
4 | verbose: true
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | train:
14 | name: REDS
15 | seq_dir: data/REDS/GT.lmdb
16 | filter_file: data/meta/REDS/train_list.txt
17 | data_type: rgb
18 | crop_size: 128
19 | batch_size_per_gpu: 2
20 | num_worker_per_gpu: 3
21 | pin_memory: true
22 |
23 | test:
24 | name: REDS
25 | gt_seq_dir: data/REDS/Raw
26 | lr_seq_dir: ~
27 | filter_list: ['000', '011', '015', '020']
28 | num_worker_per_gpu: 4
29 | pin_memory: true
30 |
31 |
32 | # model configs
33 | model:
34 | name: TecoGAN
35 |
36 | generator:
37 | name: FRNet # frame-recurrent network
38 | in_nc: 3
39 | out_nc: 3
40 | nf: 64
41 | nb: 10
42 |
43 | load_path: experiments_BD/FRVSR/FRVSR_REDS_2xSR_2GPU/train/ckpt/G_iter400000.pth
44 |
45 | discriminator:
46 | name: STNet # spatio-temporal network
47 | in_nc: 3
48 | tempo_range: 3
49 |
50 | load_path: ~
51 |
52 |
53 | # training configs
54 | train:
55 | tempo_extent: 10
56 |
57 | start_iter: 0
58 | total_iter: 500000
59 |
60 | # configs for generator
61 | generator:
62 | lr: !!float 5e-5
63 | lr_schedule:
64 | type: FixedLR
65 | betas: [0.9, 0.999]
66 |
67 | # configs for discriminator
68 | discriminator:
69 | update_policy: adaptive
70 | update_threshold: 0.4
71 | crop_border_ratio: 0.75
72 | lr: !!float 5e-5
73 | lr_schedule:
74 | type: FixedLR
75 | betas: [0.9, 0.999]
76 |
77 | # other configs
78 | moving_first_frame: true
79 | moving_factor: 0.7
80 |
81 | # criterions
82 | pixel_crit:
83 | type: CB
84 | weight: 1
85 | reduction: mean
86 |
87 | warping_crit:
88 | type: CB
89 | weight: 1
90 | reduction: mean
91 |
92 | feature_crit:
93 | type: CosineSimilarity
94 | weight: 0.2
95 | reduction: mean
96 | feature_layers: [8, 17, 26, 35]
97 |
98 | pingpong_crit:
99 | type: CB
100 | weight: 0.5
101 | reduction: mean
102 |
103 | gan_crit:
104 | type: GAN
105 | weight: 0.01
106 | reduction: mean
107 |
108 |
109 | # validation configs
110 | test:
111 | test_freq: 10000
112 |
113 | # whether to save the generated SR results
114 | save_res: false
115 | res_dir: ~ # use default dir
116 |
117 | # whether to save the test results in a json file
118 | save_json: true
119 | json_dir: ~ # use default dir
120 |
121 | padding_mode: reflect
122 | num_pad_front: 5
123 |
124 |
125 | # metric configs
126 | metric:
127 | PSNR:
128 | colorspace: y
129 |
130 | LPIPS:
131 | model: net-lin
132 | net: alex
133 | colorspace: rgb
134 | spatial: false
135 | version: 0.1
136 |
137 | tOF:
138 | colorspace: y
139 |
140 |
141 | # logger configs
142 | logger:
143 | log_freq: 100
144 | decay: 0.99
145 | ckpt_freq: 20000
146 |
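As suggested by the criterion weights above, the overall generator objective is presumably a weighted sum of the individual losses, roughly loss_G = 1.0 * pixel + 1.0 * warping + 0.2 * feature + 0.5 * pingpong + 0.01 * gan; the exact combination is implemented in codes/models/vsrgan_model.py and codes/models/optim/losses.py.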
--------------------------------------------------------------------------------
/experiments_BD/TecoGAN/TecoGAN_REDS_4xSR_2GPU/test.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: false
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | test:
14 | name: REDS
15 | gt_seq_dir: data/REDS/Raw
16 | lr_seq_dir: ~
17 | filter_list: ['000', '011', '015', '020']
18 | num_worker_per_gpu: 4
19 | pin_memory: true
20 |
21 |
22 | # model configs
23 | model:
24 | name: TecoGAN
25 |
26 | generator:
27 | name: FRNet # frame-recurrent network
28 | in_nc: 3
29 | out_nc: 3
30 | nf: 64
31 | nb: 10
32 |
33 | load_path: pretrained_models/TecoGAN_4x_BD_REDS_iter500K.pth
34 |
35 | discriminator:
36 | name: STNet # spatio-temporal network
37 | in_nc: 3
38 | tempo_range: 3
39 |
40 | load_path: ~
41 |
42 |
43 | # validation configs
44 | test:
45 | # whether to save the generated SR results
46 | save_res: true
47 | res_dir: ~ # use default dir
48 |
49 | padding_mode: reflect
50 | num_pad_front: 5
51 |
--------------------------------------------------------------------------------
/experiments_BD/TecoGAN/TecoGAN_REDS_4xSR_2GPU/train.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: true
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | train:
14 | name: REDS
15 | seq_dir: data/REDS/GT.lmdb
16 | filter_file: data/meta/REDS/train_list.txt
17 | data_type: rgb
18 | crop_size: 128
19 | batch_size_per_gpu: 2
20 | num_worker_per_gpu: 3
21 | pin_memory: true
22 |
23 | test:
24 | name: REDS
25 | gt_seq_dir: data/REDS/Raw
26 | lr_seq_dir: ~
27 | filter_list: ['000', '011', '015', '020']
28 | num_worker_per_gpu: 4
29 | pin_memory: true
30 |
31 |
32 | # model configs
33 | model:
34 | name: TecoGAN
35 |
36 | generator:
37 | name: FRNet # frame-recurrent network
38 | in_nc: 3
39 | out_nc: 3
40 | nf: 64
41 | nb: 10
42 |
43 | load_path: experiments_BD/FRVSR/FRVSR_REDS_4xSR_2GPU/train/ckpt/G_iter400000.pth
44 |
45 | discriminator:
46 | name: STNet # spatio-temporal network
47 | in_nc: 3
48 | tempo_range: 3
49 |
50 | load_path: ~
51 |
52 |
53 | # training configs
54 | train:
55 | tempo_extent: 10
56 |
57 | start_iter: 0
58 | total_iter: 500000
59 |
60 | # configs for generator
61 | generator:
62 | lr: !!float 5e-5
63 | lr_schedule:
64 | type: FixedLR
65 | betas: [0.9, 0.999]
66 |
67 | # configs for discriminator
68 | discriminator:
69 | update_policy: adaptive
70 | update_threshold: 0.4
71 | crop_border_ratio: 0.75
72 | lr: !!float 5e-5
73 | lr_schedule:
74 | type: FixedLR
75 | betas: [0.9, 0.999]
76 |
77 | # other configs
78 | moving_first_frame: true
79 | moving_factor: 0.7
80 |
81 | # criterions
82 | pixel_crit:
83 | type: CB
84 | weight: 1
85 | reduction: mean
86 |
87 | warping_crit:
88 | type: CB
89 | weight: 1
90 | reduction: mean
91 |
92 | feature_crit:
93 | type: CosineSimilarity
94 | weight: 0.2
95 | reduction: mean
96 | feature_layers: [8, 17, 26, 35]
97 |
98 | pingpong_crit:
99 | type: CB
100 | weight: 0.5
101 | reduction: mean
102 |
103 | gan_crit:
104 | type: GAN
105 | weight: 0.01
106 | reduction: mean
107 |
108 |
109 | # validation configs
110 | test:
111 | test_freq: 10000
112 |
113 | # whether to save the generated SR results
114 | save_res: false
115 | res_dir: ~ # use default dir
116 |
117 | # whether to save the test results in a json file
118 | save_json: true
119 | json_dir: ~ # use default dir
120 |
121 | padding_mode: reflect
122 | num_pad_front: 5
123 |
124 |
125 | # metric configs
126 | metric:
127 | PSNR:
128 | colorspace: y
129 |
130 | LPIPS:
131 | model: net-lin
132 | net: alex
133 | colorspace: rgb
134 | spatial: false
135 | version: 0.1
136 |
137 | tOF:
138 | colorspace: y
139 |
140 |
141 | # logger configs
142 | logger:
143 | log_freq: 100
144 | decay: 0.99
145 | ckpt_freq: 20000
146 |
--------------------------------------------------------------------------------
/experiments_BD/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/test.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: false
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | test1:
14 | name: Vid4
15 | gt_seq_dir: data/Vid4/GT
16 | lr_seq_dir: data/Vid4/Gaussian4xLR
17 | num_worker_per_gpu: 3
18 | pin_memory: true
19 |
20 | test2:
21 | name: ToS3
22 | gt_seq_dir: data/ToS3/GT
23 | lr_seq_dir: data/ToS3/Gaussian4xLR
24 | num_worker_per_gpu: 3
25 | pin_memory: true
26 |
27 |
28 | # model configs
29 | model:
30 | name: TecoGAN
31 |
32 | generator:
33 | name: FRNet # frame-recurrent network
34 | in_nc: 3
35 | out_nc: 3
36 | nf: 64
37 | nb: 10
38 | load_path: pretrained_models/TecoGAN_4x_BD_Vimeo_iter500K.pth
39 |
40 |
41 | # test configs
42 | test:
43 | # whether to save the SR results
44 | save_res: true
45 | res_dir: results
46 |
47 | # temporal padding
48 | padding_mode: reflect
49 | num_pad_front: 5
50 |
--------------------------------------------------------------------------------
/experiments_BD/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/train.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: true
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BD
11 | sigma: 1.5
12 |
13 | train:
14 | name: VimeoTecoGAN
15 | seq_dir: data/VimeoTecoGAN/GT.lmdb
16 | filter_file: ~
17 | data_type: rgb
18 | crop_size: 128
19 | batch_size_per_gpu: 2
20 | num_worker_per_gpu: 3
21 | pin_memory: true
22 |
23 | test:
24 | name: Vid4
25 | gt_seq_dir: data/Vid4/GT
26 | lr_seq_dir: data/Vid4/Gaussian4xLR
27 | filter_file: ~
28 | num_worker_per_gpu: 3
29 | pin_memory: true
30 |
31 |
32 | # model configs
33 | model:
34 | name: TecoGAN
35 |
36 | generator:
37 | name: FRNet # frame-recurrent network
38 | in_nc: 3
39 | out_nc: 3
40 | nf: 64
41 | nb: 10
42 | load_path: pretrained_models/FRVSR_BD_iter400000.pth
43 |
44 | discriminator:
45 | name: STNet # spatio-temporal network
46 | in_nc: 3
47 | tempo_range: 3
48 | load_path: ~
49 |
50 |
51 | # training configs
52 | train:
53 | tempo_extent: 10
54 |
55 | start_iter: 0
56 | total_iter: 500000
57 |
58 | # configs for generator
59 | generator:
60 | lr: !!float 5e-5
61 | lr_schedule:
62 | type: FixedLR
63 | betas: [0.9, 0.999]
64 |
65 | # configs for discriminator
66 | discriminator:
67 | update_policy: adaptive
68 | update_threshold: 0.4
69 | crop_border_ratio: 0.75
70 | lr: !!float 5e-5
71 | lr_schedule:
72 | type: FixedLR
73 | betas: [0.9, 0.999]
74 |
75 | # other configs
76 | moving_first_frame: true
77 | moving_factor: 0.7
78 |
79 | # criterions
80 | pixel_crit:
81 | type: CB
82 | weight: 1
83 | reduction: mean
84 |
85 | warping_crit:
86 | type: CB
87 | weight: 1
88 | reduction: mean
89 |
90 | feature_crit:
91 | type: CosineSimilarity
92 | weight: 0.2
93 | reduction: mean
94 | feature_layers: [8, 17, 26, 35]
95 |
96 | pingpong_crit:
97 | type: CB
98 | weight: 0.5
99 | reduction: mean
100 |
101 | gan_crit:
102 | type: GAN
103 | weight: 0.01
104 | reduction: mean
105 |
106 |
107 | # validation configs
108 | test:
109 | test_freq: 10000
110 |
111 | # whether to save the generated SR results
112 | save_res: false
113 | res_dir: ~ # use default dir
114 |
115 | # whether to save the test results in a json file
116 | save_json: true
117 | json_dir: ~ # use default dir
118 |
119 | padding_mode: reflect
120 | num_pad_front: 5
121 |
122 |
123 | # metric configs
124 | metric:
125 | PSNR:
126 | colorspace: y
127 |
128 | LPIPS:
129 | model: net-lin
130 | net: alex
131 | colorspace: rgb
132 | spatial: false
133 | version: 0.1
134 |
135 | tOF:
136 | colorspace: y
137 |
138 |
139 | # logger configs
140 | logger:
141 | log_freq: 100
142 | decay: 0.99
143 | ckpt_freq: 20000
144 |
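
The `pingpong_crit` above is the TecoGAN ping-pong term: each training clip is played forward and then backward (f1..fT, f_{T-1}..f1), and the outputs produced for the same input frame in the two passes are pushed to agree, which suppresses drifting artifacts. A minimal sketch of the idea, assuming the SR outputs are stacked along dim 1 in ping-pong order (the exact indexing and penalty used by the repo may differ):

```python
import torch

def pingpong_loss(sr_seq, eps=1e-6):
    # sr_seq: (N, 2T-1, C, H, W), frames ordered f1..fT, f_{T-1}..f1
    t = (sr_seq.size(1) + 1) // 2
    forward = sr_seq[:, :t - 1]        # outputs for f1..f_{T-1} (forward pass)
    backward = sr_seq[:, t:].flip(1)   # outputs for f1..f_{T-1} (backward pass)
    diff = forward - backward
    return torch.sqrt(diff * diff + eps).mean()
```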
--------------------------------------------------------------------------------
/experiments_BI/FRVSR/FRVSR_VimeoTecoGAN_4xSR_2GPU/test.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: false
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BI
11 |
12 | test1:
13 | name: Vid4
14 | gt_seq_dir: data/Vid4/GT
15 | lr_seq_dir: data/Vid4/Bicubic4xLR
16 | num_worker_per_gpu: 3
17 | pin_memory: true
18 |
19 | test2:
20 | name: ToS3
21 | gt_seq_dir: data/ToS3/GT
22 | lr_seq_dir: data/ToS3/Bicubic4xLR
23 | num_worker_per_gpu: 3
24 | pin_memory: true
25 |
26 |
27 | # model configs
28 | model:
29 | name: FRVSR
30 |
31 | generator:
32 | name: FRNet # frame-recurrent network
33 | in_nc: 3
34 | out_nc: 3
35 | nf: 64
36 | nb: 10
37 | load_path: pretrained_models/FRVSR_4x_BI_Vimeo_iter400K.pth
38 |
39 |
40 | # test configs
41 | test:
42 | # whether to save the SR results
43 | save_res: true
44 | res_dir: results
45 |
46 | # temporal padding
47 | padding_mode: reflect
48 | num_pad_front: 5
49 |
--------------------------------------------------------------------------------
/experiments_BI/FRVSR/FRVSR_VimeoTecoGAN_4xSR_2GPU/train.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: true
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BI
11 |
12 | train:
13 | name: VimeoTecoGAN
14 | gt_seq_dir: data/VimeoTecoGAN/GT.lmdb
15 | lr_seq_dir: data/VimeoTecoGAN/Bicubic4xLR.lmdb
16 | filter_file: ~
17 | data_type: rgb
18 | gt_crop_size: 128
19 | batch_size_per_gpu: 2
20 | num_worker_per_gpu: 3
21 | pin_memory: true
22 |
23 | test:
24 | name: Vid4
25 | gt_seq_dir: data/Vid4/GT
26 | lr_seq_dir: data/Vid4/Bicubic4xLR
27 | filter_file: ~
28 | num_worker_per_gpu: 3
29 | pin_memory: true
30 |
31 |
32 | # model configs
33 | model:
34 | name: FRVSR
35 |
36 | generator:
37 | name: FRNet # frame-recurrent network
38 | in_nc: 3
39 | out_nc: 3
40 | nf: 64
41 | nb: 10
42 | load_path: ~
43 |
44 |
45 | # training settings
46 | train:
47 | tempo_extent: 10
48 |
49 | start_iter: 0
50 | total_iter: 400000
51 |
52 | # configs for generator
53 | generator:
54 | lr: !!float 1e-4
55 | lr_schedule:
56 | type: MultiStepLR
57 | milestones: [150000, 300000]
58 | gamma: 0.5
59 | betas: [0.9, 0.999]
60 |
61 | # other settings
62 | moving_first_frame: true
63 | moving_factor: 0.7
64 |
65 | # criterions
66 | pixel_crit:
67 | type: CB
68 | weight: 1
69 | reduction: mean
70 |
71 | warping_crit:
72 | type: CB
73 | weight: 1
74 | reduction: mean
75 |
76 |
77 | # validation configs
78 | test:
79 | test_freq: 10000
80 |
81 | # whether to save the generated SR results
82 | save_res: false
83 | res_dir: ~ # use default dir
84 |
85 | # whether to save the test results in a json file
86 | save_json: true
87 | json_dir: ~ # use default dir
88 |
89 | padding_mode: reflect
90 | num_pad_front: 5
91 |
92 |
93 | # metric configs
94 | metric:
95 | PSNR:
96 | colorspace: y
97 |
98 |
99 | # logger configs
100 | logger:
101 | log_freq: 100
102 | decay: 0.99
103 | ckpt_freq: 20000
104 |
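
The generator's `lr_schedule` block above follows the semantics of PyTorch's `MultiStepLR`: the learning rate starts at 1e-4 and is multiplied by `gamma` at each milestone, i.e. 5e-5 after 150k iterations and 2.5e-5 after 300k. A minimal sketch, assuming the repo maps this block directly onto the PyTorch scheduler (the network below is just a stand-in):

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR

net = torch.nn.Conv2d(3, 64, 3)  # stand-in for the FRNet generator
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999))
scheduler = MultiStepLR(optimizer, milestones=[150000, 300000], gamma=0.5)

for it in range(400000):
    # ... forward/backward + optimizer.step() would go here ...
    scheduler.step()  # lr: 1e-4 -> 5e-5 at 150k -> 2.5e-5 at 300k
```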
--------------------------------------------------------------------------------
/experiments_BI/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/test.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: false
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BI
11 |
12 | test1:
13 | name: Vid4
14 | gt_seq_dir: data/Vid4/GT
15 | lr_seq_dir: data/Vid4/Bicubic4xLR
16 | num_worker_per_gpu: 3
17 | pin_memory: true
18 |
19 | test2:
20 | name: ToS3
21 | gt_seq_dir: data/ToS3/GT
22 | lr_seq_dir: data/ToS3/Bicubic4xLR
23 | num_worker_per_gpu: 3
24 | pin_memory: true
25 |
26 |
27 | # model configs
28 | model:
29 | name: TecoGAN
30 |
31 | generator:
32 | name: FRNet # frame-recurrent network
33 | in_nc: 3
34 | out_nc: 3
35 | nf: 64
36 | nb: 10
37 | load_path: pretrained_models/TecoGAN_4x_BI_Vimeo_iter500K.pth
38 |
39 |
40 | # test configs
41 | test:
42 | # whether to save the SR results
43 | save_res: true
44 | res_dir: results
45 |
46 | # temporal padding
47 | padding_mode: reflect
48 | num_pad_front: 5
49 |
--------------------------------------------------------------------------------
/experiments_BI/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/train.yml:
--------------------------------------------------------------------------------
1 | # basic configs
2 | scale: 4
3 | manual_seed: 0
4 | verbose: true
5 |
6 |
7 | # dataset configs
8 | dataset:
9 | degradation:
10 | type: BI
11 |
12 | train:
13 | name: VimeoTecoGAN
14 | gt_seq_dir: data/VimeoTecoGAN/GT.lmdb
15 | lr_seq_dir: data/VimeoTecoGAN/Bicubic4xLR.lmdb
16 | filter_file: ~
17 | data_type: rgb
18 | gt_crop_size: 128
19 | batch_size_per_gpu: 2
20 | num_worker_per_gpu: 3
21 | pin_memory: true
22 |
23 | test:
24 | name: Vid4
25 | gt_seq_dir: data/Vid4/GT
26 | lr_seq_dir: data/Vid4/Bicubic4xLR
27 | filter_file: ~
28 | num_worker_per_gpu: 4
29 | pin_memory: true
30 |
31 |
32 | # model configs
33 | model:
34 | name: TecoGAN
35 |
36 | generator:
37 | name: FRNet # frame-recurrent network
38 | in_nc: 3
39 | out_nc: 3
40 | nf: 64
41 | nb: 10
42 | load_path: pretrained_models/FRVSR_BI_iter400000.pth
43 |
44 | discriminator:
45 | name: STNet # spatio-temporal network
46 | in_nc: 3
47 | tempo_range: 3
48 | load_path: ~
49 |
50 |
51 | # training configs
52 | train:
53 | tempo_extent: 10
54 |
55 | start_iter: 0
56 | total_iter: 500000
57 |
58 | # configs for generator
59 | generator:
60 | lr: !!float 5e-5
61 | lr_schedule:
62 | type: FixedLR
63 | betas: [0.9, 0.999]
64 |
65 | # configs for discriminator
66 | discriminator:
67 | update_policy: adaptive
68 | update_threshold: 0.4
69 | crop_border_ratio: 0.75
70 | lr: !!float 5e-5
71 | lr_schedule:
72 | type: FixedLR
73 | betas: [0.9, 0.999]
74 |
75 | # other configs
76 | moving_first_frame: true
77 | moving_factor: 0.7
78 |
79 | # criterions
80 | pixel_crit:
81 | type: CB
82 | weight: 1
83 | reduction: mean
84 |
85 | warping_crit:
86 | type: CB
87 | weight: 1
88 | reduction: mean
89 |
90 | feature_crit:
91 | type: CosineSimilarity
92 | weight: 0.2
93 | reduction: mean
94 | feature_layers: [8, 17, 26, 35]
95 |
96 | pingpong_crit:
97 | type: CB
98 | weight: 0.5
99 | reduction: mean
100 |
101 | gan_crit:
102 | type: GAN
103 | weight: 0.01
104 | reduction: mean
105 |
106 |
107 | # validation configs
108 | test:
109 | test_freq: 10000
110 |
111 | # whether to save the generated SR results
112 | save_res: false
113 | res_dir: ~ # use default dir
114 |
115 | # whether to save the test results in a json file
116 | save_json: true
117 | json_dir: ~ # use default dir
118 |
119 | padding_mode: reflect
120 | num_pad_front: 5
121 |
122 |
123 | # metric configs
124 | metric:
125 | PSNR:
126 | colorspace: y
127 |
128 | LPIPS:
129 | model: net-lin
130 | net: alex
131 | colorspace: rgb
132 | spatial: false
133 | version: 0.1
134 |
135 | tOF:
136 | colorspace: y
137 |
138 |
139 | # logger configs
140 | logger:
141 | log_freq: 100
142 | decay: 0.99
143 | ckpt_freq: 20000
144 |
145 |
--------------------------------------------------------------------------------
/pretrained_models/put_the_pretrained_models_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/pretrained_models/put_the_pretrained_models_here
--------------------------------------------------------------------------------
/profile.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # This script is used to profile a model.
4 |
5 | # basic settings
6 | root_dir=.
7 | degradation=$1
8 | model=$2
9 | gpu_ids=0
10 | # specify the size of input data in the format of [color]x[height]x[width]
11 | lr_size=$3
12 |
13 |
14 | # run
15 | python ${root_dir}/codes/main.py \
16 | --exp_dir ${root_dir}/experiments_${degradation}/${model} \
17 | --mode profile \
18 | --opt test.yml \
19 | --gpu_ids ${gpu_ids} \
20 | --lr_size ${lr_size} \
21 | --test_speed
22 |
23 |
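
For example, profiling the BD-degradation TecoGAN model on a 3x134x320 input (the experiment directory and input size are illustrative; pass any `[color]x[height]x[width]` triple that matches your LR data):

```bash
bash ./profile.sh BD TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU 3x134x320
```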
--------------------------------------------------------------------------------
/resources/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/benchmark.png
--------------------------------------------------------------------------------
/resources/bridge.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/bridge.gif
--------------------------------------------------------------------------------
/resources/fire.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/fire.gif
--------------------------------------------------------------------------------
/resources/foliage.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/foliage.gif
--------------------------------------------------------------------------------
/resources/losses.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/losses.png
--------------------------------------------------------------------------------
/resources/metrics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/metrics.png
--------------------------------------------------------------------------------
/resources/pond.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/pond.gif
--------------------------------------------------------------------------------
/scripts/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import argparse
4 | import glob
5 | import lmdb
6 | import pickle
7 | import random
8 |
9 | import numpy as np
10 | import cv2
11 |
12 |
13 | def create_lmdb(dataset, raw_dir, lmdb_dir, filter_file=''):
14 | assert dataset in ('VimeoTecoGAN', 'REDS'), f'Unknown Dataset: {dataset}'
15 | print(f'>> Start to create lmdb for {dataset}')
16 |
17 | # scan dir
18 | if filter_file: # use sequences specified by the filter_file
19 | with open(filter_file, 'r') as f:
20 | seq_idx_lst = sorted([line.strip() for line in f])
21 | else: # use all found sequences
22 | seq_idx_lst = sorted(os.listdir(raw_dir))
23 |
24 | num_seq = len(seq_idx_lst)
25 | print(f'>> Number of sequences: {num_seq}')
26 |
27 | # compute space to be allocated
28 | nbytes = 0
29 | for seq_idx in seq_idx_lst:
30 | frm_path_lst = sorted(glob.glob(osp.join(raw_dir, seq_idx, '*.png')))
31 | nbytes_per_frm = cv2.imread(frm_path_lst[0], cv2.IMREAD_UNCHANGED).nbytes
32 | nbytes += len(frm_path_lst) * nbytes_per_frm
33 | alloc_size = round(2 * nbytes)
34 | print(f'>> Space required for lmdb generation: {alloc_size / (1 << 30):.2f} GB')
35 |
36 | # create lmdb environment
37 | env = lmdb.open(lmdb_dir, map_size=alloc_size)
38 |
39 | # write data to lmdb
40 | commit_freq = 5
41 | keys = []
42 | txn = env.begin(write=True)
43 | for b, seq_idx in enumerate(seq_idx_lst):
44 | # log
45 | print(f' Processing sequence: {seq_idx} ({b + 1}/{num_seq})\r', end='')
46 |
47 | # get info
48 | frm_path_lst = sorted(glob.glob(osp.join(raw_dir, seq_idx, '*.png')))
49 | n_frm = len(frm_path_lst)
50 |
51 | # read frames
52 | for i in range(n_frm):
53 | frm = cv2.imread(frm_path_lst[i], cv2.IMREAD_UNCHANGED)
54 | frm = np.ascontiguousarray(frm[..., ::-1]) # hwc|rgb|uint8
55 |
56 | h, w, c = frm.shape
57 | key = f'{seq_idx}_{n_frm}x{h}x{w}_{i:04d}'
58 |
59 | txn.put(key.encode('ascii'), frm)
60 | keys.append(key)
61 |
62 | # commit
63 | if b % commit_freq == 0:
64 | txn.commit()
65 | txn = env.begin(write=True)
66 |
67 | txn.commit()
68 | env.close()
69 |
70 | # create meta information
71 | meta_info = {
72 | 'name': dataset,
73 | 'color': 'RGB',
74 | 'keys': keys
75 | }
76 | pickle.dump(meta_info, open(osp.join(lmdb_dir, 'meta_info.pkl'), 'wb'))
77 |
78 | print(f'>> Finished lmdb generation for {dataset}')
79 |
80 |
81 | def check_lmdb(dataset, lmdb_dir):
82 |
83 | def visualize(win, img):
84 | cv2.namedWindow(win, 0)
85 | cv2.resizeWindow(win, img.shape[-2], img.shape[-3])
86 | cv2.imshow(win, img[..., ::-1])
87 | cv2.waitKey(0)
88 | cv2.destroyAllWindows()
89 |
90 | assert dataset in ('VimeoTecoGAN', 'REDS'), f'Unknown Dataset: {dataset}'
91 | print(f'>> Start to check lmdb dataset: {dataset}.lmdb')
92 |
93 | # load keys
94 | meta_info = pickle.load(open(osp.join(lmdb_dir, 'meta_info.pkl'), 'rb'))
95 | keys = meta_info['keys']
96 | print(f'>> Number of keys: {len(keys)}')
97 |
98 | # randomly select frames for visualization
99 | with lmdb.open(lmdb_dir) as env:
100 | for i in range(3):  # visualize 3 random frames (change as needed)
101 | idx = random.randint(0, len(keys) - 1)
102 | key = keys[idx]
103 |
104 | # parse key
105 | key_lst = key.split('_')
106 | vid, sz, frm = '_'.join(key_lst[:-2]), key_lst[-2], key_lst[-1]
107 | sz = tuple(map(int, sz.split('x')))
108 | sz = (*sz[1:], 3)
109 | print(f' Visualizing frame: #{frm} from sequence: {vid} (size: {sz})')
110 |
111 | with env.begin() as txn:
112 | buf = txn.get(key.encode('ascii'))
113 | val = np.frombuffer(buf, dtype=np.uint8).reshape(*sz) # hwc
114 |
115 | visualize(key, val)
116 |
117 | print(f'>> Finished lmdb checking for {dataset}')
118 |
119 |
120 | if __name__ == '__main__':
121 | # parse args
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument('--dataset', type=str, required=True,
124 | help='VimeoTecoGAN | REDS')
125 | parser.add_argument('--raw_dir', type=str, required=True,
126 | help='Dir to the raw data')
127 | parser.add_argument('--lmdb_dir', type=str, required=True,
128 | help='Dir to the lmdb data')
129 | parser.add_argument('--filter_file', type=str, default='',
130 | help='File used to select sequences')
131 | args = parser.parse_args()
132 |
133 | # run
134 | if osp.exists(args.lmdb_dir):
135 | print(f'>> Dataset [{args.dataset}] already exists.')
136 | check_lmdb(args.dataset, args.lmdb_dir)
137 | else:
138 | create_lmdb(args.dataset, args.raw_dir, args.lmdb_dir, args.filter_file)
139 | check_lmdb(args.dataset, args.lmdb_dir)
140 |
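
For example, building the REDS training LMDB at the location referenced by the training configs above (paths follow those configs; adjust them if your raw frames live elsewhere):

```bash
python scripts/create_lmdb.py \
    --dataset REDS \
    --raw_dir data/REDS/Raw \
    --lmdb_dir data/REDS/GT.lmdb \
    --filter_file data/meta/REDS/train_list.txt
```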
--------------------------------------------------------------------------------
/scripts/download/download_datasets.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # This script is used to download datasets for model evaluation.
4 | # Usage:
5 | #   cd <TecoGAN-PyTorch root>
6 | #   bash ./scripts/download/download_datasets.sh <degradation: BD | BI>
7 |
8 |
9 |
10 | function download_large_file() {
11 | local DATA_DIR=$1
12 | local FID=$2
13 | local FNAME=$3
14 |
15 | wget --load-cookies ${DATA_DIR}/cookies.txt -O ${DATA_DIR}/${FNAME}.zip "https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ${DATA_DIR}/cookies.txt --keep-session-cookies --no-check-certificate "https://drive.google.com/uc?export=download&id=${FID}" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FID}" && rm -rf ${DATA_DIR}/cookies.txt
16 | }
17 |
18 | function download_small_file() {
19 | local DATA_DIR=$1
20 | local FID=$2
21 | local FNAME=$3
22 |
23 | wget --no-check-certificate -O ${DATA_DIR}/${FNAME}.zip "https://drive.google.com/uc?export=download&id=${FID}"
24 | }
25 |
26 | function check_md5() {
27 | local FPATH=$1
28 | local MD5=$2
29 |
30 | if [ ${MD5} != $(md5sum ${FPATH} | cut -d " " -f1) ]; then
31 | echo "!!! Fail to match MD5 sum for: ${FPATH}"
32 | echo "!!! Please try downloading it again"
33 | exit 1
34 | fi
35 | }
36 |
37 |
38 | # Vid4 GT
39 | if [ ! -d "./data/Vid4/GT" ]; then
40 | DATA_DIR="./data/Vid4"
41 | FID="1T8TuyyOxEUfXzCanH5kvNH2iA8nI06Wj"
42 | MD5="d2850eccf30092418f15afe4a7ea27e5"
43 |
44 | echo ">>> Start to download [Vid4 GT] dataset"
45 |
46 | mkdir -p ${DATA_DIR}
47 | download_large_file ${DATA_DIR} ${FID} GT
48 | check_md5 ${DATA_DIR}/GT.zip ${MD5}
49 | unzip ${DATA_DIR}/GT.zip -d ${DATA_DIR} && rm ${DATA_DIR}/GT.zip
50 | fi
51 |
52 | sleep 1s
53 |
54 | # ToS3 GT
55 | if [ ! -d "./data/ToS3/GT" ]; then
56 | DATA_DIR="./data/ToS3"
57 | FID="1XoR_NVBR-LbZOA8fXh7d4oPV0M8fRi8a"
58 | MD5="56eb9e8298a4e955d618c1658dfc89c9"
59 |
60 | echo ">>> Start to download [ToS3 GT] dataset"
61 |
62 | mkdir -p ${DATA_DIR}
63 | download_large_file ${DATA_DIR} ${FID} GT
64 | check_md5 ${DATA_DIR}/GT.zip ${MD5}
65 | unzip ${DATA_DIR}/GT.zip -d ${DATA_DIR} && rm ${DATA_DIR}/GT.zip
66 | fi
67 |
68 |
69 | if [ $1 == BD ]; then
70 | # Vid4 LR BD
71 | if [ ! -d "./data/Vid4/Gaussian4xLR" ]; then
72 | DATA_DIR="./data/Vid4"
73 | FID="1-5NFW6fEPUczmRqKHtBVyhn2Wge6j3ma"
74 | MD5="3b525cb0f10286743c76950d9949a255"
75 |
76 | echo ">>> Start to download [Vid4 LR] dataset (BD degradation)"
77 |
78 | download_small_file ${DATA_DIR} ${FID} Gaussian4xLR
79 | check_md5 ${DATA_DIR}/Gaussian4xLR.zip ${MD5}
80 | unzip ${DATA_DIR}/Gaussian4xLR.zip -d ${DATA_DIR} && rm ${DATA_DIR}/Gaussian4xLR.zip
81 | fi
82 |
83 | sleep 1s
84 |
85 | # ToS3 LR BD
86 | if [ ! -d "./data/ToS3/Gaussian4xLR" ]; then
87 | DATA_DIR="./data/ToS3"
88 | FID="1rDCe61kR-OykLyCo2Ornd2YgPnul2ffM"
89 | MD5="803609a12453a267eb9c78b68e073e81"
90 |
91 | echo ">>> Start to download [ToS3 LR] dataset (BD degradation)"
92 |
93 | download_large_file ${DATA_DIR} ${FID} Gaussian4xLR
94 | check_md5 ${DATA_DIR}/Gaussian4xLR.zip ${MD5}
95 | unzip ${DATA_DIR}/Gaussian4xLR.zip -d ${DATA_DIR} && rm ${DATA_DIR}/Gaussian4xLR.zip
96 | fi
97 |
98 | elif [ $1 == BI ]; then
99 | # Vid4 LR BI
100 | if [ ! -d "./data/Vid4/Bicubic4xLR" ]; then
101 | DATA_DIR="./data/Vid4"
102 | FID="1Kg0VBgk1r9I1c4f5ZVZ4sbfqtVRYub91"
103 | MD5="35666bd16ce582ae74fa935b3732ae1a"
104 |
105 | echo ">>> Start to download [Vid4 LR] dataset (BI degradation)"
106 |
107 | download_small_file ${DATA_DIR} ${FID} Bicubic4xLR
108 | check_md5 ${DATA_DIR}/Bicubic4xLR.zip ${MD5}
109 | unzip ${DATA_DIR}/Bicubic4xLR.zip -d ${DATA_DIR} && rm ${DATA_DIR}/Bicubic4xLR.zip
110 | fi
111 |
112 | sleep 1s
113 |
114 | # ToS3 LR BI
115 | if [ ! -d "./data/ToS3/Bicubic4xLR" ]; then
116 | DATA_DIR="./data/ToS3"
117 | FID="1FNuC0jajEjH9ycqDkH4cZQ3_eUqjxzzf"
118 | MD5="3b165ffc8819d695500cf565bf3a9ca2"
119 |
120 | echo ">>> Start to download [ToS3 LR] dataset (BI degradation)"
121 |
122 | download_large_file ${DATA_DIR} ${FID} Bicubic4xLR
123 | check_md5 ${DATA_DIR}/Bicubic4xLR.zip ${MD5}
124 | unzip ${DATA_DIR}/Bicubic4xLR.zip -d ${DATA_DIR} && rm ${DATA_DIR}/Bicubic4xLR.zip
125 | fi
126 |
127 | else
128 | echo Unknown Degradation Type: $1 \(Currently supported: \"BD\" or \"BI\"\)
129 | fi
130 |
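
For example, fetching the Vid4/ToS3 ground truth together with the BD-degraded LR inputs:

```bash
bash ./scripts/download/download_datasets.sh BD
```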
--------------------------------------------------------------------------------
/scripts/download/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # This script is used to download pretrained models
4 | # Usage:
5 | #   cd <TecoGAN-PyTorch root>
6 | #   bash ./scripts/download/download_models.sh <degradation: BD | BI> <model: TecoGAN | FRVSR>
7 |
8 |
9 | function download_small_file() {
10 | local FPATH=$1
11 | local FID=$2
12 | local MD5=$3
13 |
14 | wget --no-check-certificate -O ${FPATH} "https://drive.google.com/uc?export=download&id=${FID}"
15 |
16 | if [ ${MD5} != $(md5sum ${FPATH} | cut -d " " -f1) ]; then
17 | echo "!!! Fail to match MD5 sum for: ${FPATH}"
18 | echo "!!! Please try downloading it again"
19 | exit 1
20 | fi
21 | }
22 |
23 |
24 |
25 | if [ $1 == BD -a $2 == TecoGAN ]; then
26 | FPATH="./pretrained_models/TecoGAN_BD_iter500000.pth"
27 | if [ ! -f $FPATH ]; then
28 | FID="13FPxKE6q7tuRrfhTE7GB040jBeURBj58"
29 | MD5="13d826c9f066538aea9340e8d3387289"
30 |
31 | echo "Start to download model [TecoGAN BD]"
32 | download_small_file ${FPATH} ${FID} ${MD5}
33 | fi
34 |
35 | elif [ $1 == BD -a $2 == FRVSR ]; then
36 | FPATH="./pretrained_models/FRVSR_BD_iter400000.pth"
37 | if [ ! -f $FPATH ]; then
38 | FID="11kPVS04a3B3k0SD-mKEpY_Q8WL7KrTIA"
39 | MD5="77d33c58b5cbf1fc68a1887be80ed18f"
40 |
41 | echo "Start to download model [FRVSR BD]"
42 | download_small_file ${FPATH} ${FID} ${MD5}
43 | fi
44 |
45 | elif [ $1 == BI -a $2 == TecoGAN ]; then
46 | FPATH="./pretrained_models/TecoGAN_BI_iter500000.pth"
47 | if [ ! -f $FPATH ]; then
48 | FID="1ie1F7wJcO4mhNWK8nPX7F0LgOoPzCwEu"
49 | MD5="4955b65b80f88456e94443d9d042d1e6"
50 |
51 | echo "Start to download model [TecoGAN BI]"
52 | download_small_file ${FPATH} ${FID} ${MD5}
53 | fi
54 |
55 | elif [ $1 == BI -a $2 == FRVSR ]; then
56 | FPATH="./pretrained_models/FRVSR_BI_iter400000.pth"
57 | if [ ! -f $FPATH ]; then
58 | FID="1wejMAFwIBde_7sz-H7zwlOCbCvjt3G9L"
59 | MD5="ad6337d934ec7ca72441082acd80c4ae"
60 |
61 | echo "Start to download model [FRVSR BI]"
62 | download_small_file ${FPATH} ${FID} ${MD5}
63 | fi
64 |
65 | else
66 | echo Unknown combination: $1, $2
67 | fi
68 |
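
For example, fetching the BD-degradation TecoGAN checkpoint into `./pretrained_models`:

```bash
bash ./scripts/download/download_models.sh BD TecoGAN
```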
--------------------------------------------------------------------------------
/scripts/generate_lr_bi.m:
--------------------------------------------------------------------------------
1 | function generate_lr_bi()
2 |
3 | up_scale = 4;
4 | mod_scale = 4;
5 | idx = 0;
6 | filepaths = dir('./data/VimeoTecoGAN/Raw/*/*.png');
7 |
8 | for i = 1 : length(filepaths)
9 | [~,imname,ext] = fileparts(filepaths(i).name);
10 | folder_path = filepaths(i).folder;
11 | save_lr_folder = strrep(folder_path, 'Raw', 'Bicubic4xLR')
12 | save_bi_folder = strrep(folder_path, 'Raw', 'Bicubic4xBI')
13 | if ~exist(save_lr_folder, 'dir')
14 | mkdir(save_lr_folder);
15 | end
16 | if ~exist(save_bi_folder, 'dir')
17 | mkdir(save_bi_folder);
18 | end
19 | if isempty(imname)
20 | disp('Ignore . folder.');
21 | elseif strcmp(imname, '.')
22 | disp('Ignore .. folder.');
23 | else
24 | idx = idx + 1;
25 | str_rlt = sprintf('%d\t%s.\n', idx, imname);
26 | fprintf(str_rlt);
27 | % read image
28 | img = imread(fullfile(folder_path, [imname, ext]));
29 | img = im2double(img);
30 | % modcrop
31 | img = modcrop(img, mod_scale);
32 | % LR
33 | im_LR = imresize(img, 1/up_scale, 'bicubic');
34 | im_BI = imresize(im_LR, up_scale, 'bicubic');
35 | if exist('save_lr_folder', 'var')
36 | imwrite(im_LR, fullfile(save_lr_folder, [imname, '.png']));
37 | end
38 | if exist('save_bi_folder', 'var')
39 | imwrite(im_BI, fullfile(save_bi_folder, [imname, '.png']));
40 | end
41 | end
42 | end
43 | end
44 |
45 | %% modcrop
46 | function img = modcrop(img, modulo)
47 | if size(img,3) == 1
48 | sz = size(img);
49 | sz = sz - mod(sz, modulo);
50 | img = img(1:sz(1), 1:sz(2));
51 | else
52 | tmpsz = size(img);
53 | sz = tmpsz(1:2);
54 | sz = sz - mod(sz, modulo);
55 | img = img(1:sz(1), 1:sz(2),:);
56 | end
57 | end
58 |
--------------------------------------------------------------------------------
/scripts/monitor_training.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import argparse
3 | import json
4 | import math
5 | import re
6 | import bisect
7 |
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | # -------------------- utility functions -------------------- #
12 | def append(loss_dict, loss_name, loss_value):
13 | if loss_name not in loss_dict:
14 | loss_dict[loss_name] = [loss_value]
15 | else:
16 | loss_dict[loss_name].append(loss_value)
17 |
18 |
19 | def split(pattern, string):
20 | return re.split(r'\s*{}\s*'.format(pattern), string)
21 |
22 |
23 | def parse_log(log_file):
24 | # define loss patterns
25 | loss_pattern = r'.*\[epoch:.*\| iter: (\d+).*\] (.*)'
26 |
27 | # load log file
28 | with open(log_file, 'r') as f:
29 | lines = [line.strip() for line in f]
30 |
31 | # parse log file
32 | loss_dict = {} # {'iter': [], 'loss1': [], 'loss2':[], ...}
33 | for line in lines:
34 | loss_match = re.match(loss_pattern, line)
35 | if loss_match:
36 | iter = int(loss_match.group(1))
37 | append(loss_dict, 'iter', iter)
38 | for s in split(',', loss_match.group(2)):
39 | if s:
40 | k, v = split(':', s)
41 | append(loss_dict, k, float(v))
42 |
43 | return loss_dict
44 |
45 |
46 | def parse_json(json_file):
47 | with open(json_file, 'r') as f:
48 | json_dict = json.load(f)
49 |
50 | metric_dict = {}
51 | for model_idx, metrics in json_dict.items():
52 | append(metric_dict, 'iter', int(model_idx.replace('G_iter', '')))
53 | for metric, val in metrics.items():
54 | append(metric_dict, metric, float(val))
55 |
56 | return metric_dict
57 |
58 |
59 | def plot_curve(ax, iter, value, style='-', alpha=1.0, label='',
60 | start_iter=0, end_iter=-1, smooth=0, linewidth=1.0):
61 |
62 | assert len(iter) == len(value), \
63 | 'mismatch in iter and value ({} vs {})'.format(
64 | len(iter), len(value))
65 | l = len(iter)
66 |
67 | if smooth:
68 | for i in range(1, l):
69 | value[i] = smooth * value[i - 1] + (1 - smooth) * value[i]
70 |
71 | start_index = bisect.bisect_left(iter, start_iter)
72 | end_index = l if end_iter < 0 else bisect.bisect_right(iter, end_iter)
73 | ax.plot(
74 | iter[start_index:end_index], value[start_index:end_index],
75 | style, alpha=alpha, label=label, linewidth=linewidth)
76 |
77 |
78 | def plot_loss_curves(loss_dict, ax, loss_type, start_iter=0, end_iter=-1,
79 | smooth=0):
80 |
81 | for model_idx, model_loss_dict in loss_dict.items():
82 | if loss_type in model_loss_dict:
83 | plot_curve(
84 | ax, model_loss_dict['iter'], model_loss_dict[loss_type],
85 | alpha=1.0, label=model_idx, start_iter=start_iter,
86 | end_iter=end_iter, smooth=smooth)
87 | ax.legend(loc='best', fontsize='small')
88 | ax.set_ylabel(loss_type)
89 | ax.set_xlabel('iteration')
90 | plt.grid(True)
91 |
92 |
93 | def plot_metric_curves(metric_dict, ax, metric_type, start_iter=0, end_iter=-1):
94 | """ currently can only plot average results
95 | """
96 |
97 | for model_idx, model_metric_dict in metric_dict.items():
98 | if metric_type in model_metric_dict:
99 | plot_curve(
100 | ax, model_metric_dict['iter'], model_metric_dict[metric_type],
101 | alpha=1.0, label=model_idx, start_iter=start_iter,
102 | end_iter=end_iter)
103 | ax.legend(loc='best', fontsize='small')
104 | ax.set_ylabel(metric_type)
105 | ax.set_xlabel('iteration')
106 | plt.grid(True)
107 |
108 |
109 | # -------------------- monitor -------------------- #
110 | def monitor(root_dir, testset, exp_id_lst, loss_lst, metric_lst):
111 | # ================ basic settings ================#
112 | start_iter = 0
113 | loss_smooth = 0
114 |
115 | # ================ parse logs ================#
116 | loss_dict = {} # {'model1': {'loss1': x1, ...}, ...}
117 | metric_dict = {} # {'model1': {'metric1': x1, ...}, ...}
118 | for exp_id in exp_id_lst:
119 | # parse log
120 | log_file = osp.join(root_dir, exp_id, 'train', 'train.log')
121 | if osp.exists(log_file):
122 | loss_dict[exp_id] = parse_log(log_file)
123 |
124 | # parse json
125 | json_file = osp.join(
126 | root_dir, exp_id, 'test', 'metrics', f'{testset}_avg.json')
127 | if osp.exists(json_file):
128 | metric_dict[exp_id] = parse_json(json_file)
129 |
130 | # ================ plot loss curves ================#
131 | n_loss = len(loss_lst)
132 | base_figsize = (12, 2 * math.ceil(n_loss / 2))
133 | fig = plt.figure(1, figsize=base_figsize)
134 | for i in range(n_loss):
135 | ax = fig.add_subplot(math.ceil(n_loss / 2), 2, i + 1)
136 | plot_loss_curves(
137 | loss_dict, ax, loss_lst[i], start_iter=start_iter, smooth=loss_smooth)
138 |
139 | # ================ plot metric curves ================#
140 | n_metric = len(metric_lst)
141 | base_figsize = (12, 2 * math.ceil(n_metric / 2))
142 | fig = plt.figure(2, figsize=base_figsize)
143 | for i in range(n_metric):
144 | ax = fig.add_subplot(math.ceil(n_metric / 2), 2, i + 1)
145 | plot_metric_curves(
146 | metric_dict, ax, metric_lst[i], start_iter=start_iter)
147 |
148 | plt.show()
149 |
150 |
151 | if __name__ == '__main__':
152 | # parse args
153 | parser = argparse.ArgumentParser()
154 | parser.add_argument('--degradation', '-dg', type=str, required=True)
155 | parser.add_argument('--model', '-m', type=str, required=True)
156 | parser.add_argument('--dataset', '-ds', type=str, required=True)
157 | args = parser.parse_args()
158 |
159 | # select model
160 | root_dir = '.'
161 | if 'FRVSR' in args.model:
162 | # select experiments
163 | exp_id_lst = [
164 | # experiment dirs
165 | f'experiments_{args.degradation}/{args.model}',
166 | ]
167 |
168 | # select losses
169 | loss_lst = [
170 | 'l_pix_G', # pixel loss
171 | 'l_warp_G', # warping loss
172 | ]
173 |
174 | # select metrics
175 | metric_lst = [
176 | 'PSNR',
177 | ]
178 |
179 | elif 'TecoGAN' in args.model:
180 | # select experiments
181 | exp_id_lst = [
182 | # experiment dirs
183 | f'experiments_{args.degradation}/{args.model}',
184 | ]
185 |
186 | # select losses
187 | loss_lst = [
188 | 'l_pix_G', # pixel loss
189 | 'l_warp_G', # warping loss
190 | 'l_feat_G', # perceptual loss
191 | 'l_gan_G', # generator loss
192 | 'l_gan_D', # discriminator loss
193 | 'p_real_D',
194 | 'p_fake_D',
195 | ]
196 |
197 | # select metrics
198 | metric_lst = [
199 | 'PSNR',
200 | 'LPIPS',
201 | 'tOF'
202 | ]
203 |
204 | else:
205 | raise ValueError(f'Unrecognized model: {args.model}')
206 |
207 | monitor(root_dir, args.dataset, exp_id_lst, loss_lst, metric_lst)
208 |
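
For example, monitoring the BD TecoGAN run while it trains, using its Vid4 validation metrics (the experiment directory follows the layout used by the configs above):

```bash
python scripts/monitor_training.py -dg BD -m TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU -ds Vid4
```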
--------------------------------------------------------------------------------
/scripts/resize_bd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import glob
4 | from multiprocessing import Pool
5 |
6 | import numpy as np
7 | import cv2
8 | import torch
9 | import torch.nn.functional as F
10 |
11 | from codes.utils.data_utils import float32_to_uint8
12 |
13 |
14 | """ Note:
15 | There are two implementations (opencv's GaussianBlur and PyTorch's conv2d)
16 | for generating BD degraded LR data. Although there's only a slight numerical
17 | difference between these two methods, it's recommended to use the latter one,
18 | since it is the one adopted in model training.
19 | """
20 |
21 | # setup params
22 | # default settings
23 | scale = 4 # downsampling scale
24 | sigma = 1.5
25 | ksize = 1 + 2 * int(sigma * 3.0)
26 | # to be modified
27 | n_process = 16 # the number of processes to use for downsampling
28 | filepaths = glob.glob('../data/Vid4/GT/*/*.png')
29 | gt_dir_idx, lr_dir_idx = 'GT', 'Gaussian4xLR'
30 |
31 |
32 | def down_opencv(img, sigma, ksize, scale):
33 | blur_img = cv2.GaussianBlur(img, (ksize, ksize), sigmaX=sigma) # hwc|uint8
34 | lr_img = blur_img[::scale, ::scale].astype(np.float32) / 255.0
35 | return lr_img # hwc|float32
36 |
37 |
38 | def down_pytorch(img, sigma, ksize, scale):
39 | img = np.ascontiguousarray(img)
40 | img = torch.FloatTensor(img).unsqueeze(0).permute(0, 3, 1, 2) / 255.0 # nchw
41 |
42 | gaussian_filters = create_kernel(sigma, ksize)
43 |
44 | filters_h, filters_w = gaussian_filters.shape[-2:]
45 | pad_h, pad_w = filters_h - 1, filters_w - 1
46 |
47 | pad_t = pad_h // 2
48 | pad_b = pad_h - pad_t
49 | pad_l = pad_w // 2
50 | pad_r = pad_w - pad_l
51 |
52 | img = F.pad(img, (pad_l, pad_r, pad_t, pad_b), 'reflect')
53 |
54 | lr_img = F.conv2d(
55 | img, gaussian_filters, stride=scale, bias=None, padding=0)
56 |
57 | return lr_img[0].permute(1, 2, 0).numpy() # hwc|float32
58 |
59 |
60 | def downsample_worker(filepath):
61 | # log
62 | print('Processing {}'.format(filepath))
63 |
64 | # setup dirs
65 | gt_folder, img_idx = osp.split(filepath)
66 | lr_folder = gt_folder.replace(gt_dir_idx, lr_dir_idx)
67 |
68 | # read image
69 | img = cv2.imread(filepath) # hwc|bgr|uint8
70 |
71 | # downsample
72 | # img_lr = down_opencv(img, sigma, ksize, scale)
73 | img_lr = down_pytorch(img, sigma, ksize, scale)
74 | img_lr = float32_to_uint8(img_lr)
75 |
76 | # save image
77 | cv2.imwrite(osp.join(lr_folder, img_idx), img_lr)
78 |
79 |
80 | if __name__ == '__main__':
81 | # setup dirs
82 | print('# of images: {}'.format(len(filepaths)))
83 | for filepath in filepaths:
84 | gt_folder, _ = osp.split(filepath)
85 | lr_folder = gt_folder.replace(gt_dir_idx, lr_dir_idx)
86 | if not osp.exists(lr_folder):
87 | os.makedirs(lr_folder)
88 |
89 | # for each image
90 | pool = Pool(n_process)
91 | for filepath in sorted(filepaths):
92 | pool.apply_async(downsample_worker, args=(filepath,))
93 | pool.close()
94 | pool.join()
95 |
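
Note that `down_pytorch` calls `create_kernel`, which is neither defined in this script nor included in the import from `codes.utils.data_utils` above. A minimal stand-in consistent with the ungrouped `F.conv2d` call (a sketch of the expected shape and normalization, not the repo's actual helper) would place the same 2-D Gaussian on the channel "diagonal" so each RGB channel is blurred independently:

```python
import torch

def create_kernel(sigma, ksize):
    # 1-D Gaussian, outer product -> normalized 2-D kernel
    ax = torch.arange(ksize, dtype=torch.float32) - (ksize - 1) / 2.0
    g1d = torch.exp(-(ax ** 2) / (2 * sigma ** 2))
    g2d = g1d[:, None] * g1d[None, :]
    g2d = g2d / g2d.sum()

    # weight shape (3, 3, k, k): kernel on the diagonal, zeros elsewhere,
    # so a plain (ungrouped) conv2d blurs R, G and B separately
    filters = torch.zeros(3, 3, ksize, ksize)
    for c in range(3):
        filters[c, c] = g2d
    return filters
```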
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # This script is used to evaluate a pretrained model.
4 |
5 | # basic settings
6 | root_dir=.
7 | degradation=$1
8 | model=$2
9 | gpu_ids=2,3
10 | master_port=4322
11 |
12 |
13 | # run
14 | num_gpus=`echo ${gpu_ids} | awk -F\, '{print NF}'`
15 | if [[ ${num_gpus} > 1 ]]; then
16 | dist_args="-m torch.distributed.launch --nproc_per_node ${num_gpus} --master_port ${master_port}"
17 | fi
18 |
19 | CUDA_VISIBLE_DEVICES=${gpu_ids} \
20 | python ${dist_args} ${root_dir}/codes/main.py \
21 | --exp_dir ${root_dir}/experiments_${degradation}/${model} \
22 | --mode test \
23 | --opt test.yml \
24 | --gpu_ids ${gpu_ids}
25 |
26 |
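
For example, evaluating the pretrained BD TecoGAN model with the `test.yml` shown earlier (edit `gpu_ids` in the script to match your machine):

```bash
bash ./test.sh BD TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU
```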
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # This script is used to train a model.
4 |
5 | # basic settings
6 | root_dir=.
7 | degradation=$1
8 | model=$2
9 | gpu_ids=0,1 # set to -1 to use cpu
10 | master_port=4321
11 |
12 | debug=0
13 |
14 |
15 | # resume training or train from scratch
16 | start_iter=0
17 | if [[ ${start_iter} > 0 ]]; then
18 | suffix=_iter${start_iter}
19 | else
20 | suffix=''
21 | fi
22 |
23 |
24 | exp_dir=${root_dir}/experiments_${degradation}/${model}
25 | # check
26 | if [ -d "$exp_dir/train" ]; then
27 | echo ">> Experiment dir already exists: $exp_dir/train"
28 | echo ">> Please delete it for retraining"
29 | exit 1
30 | fi
31 | # make dir
32 | mkdir -p ${exp_dir}/train
33 |
34 |
35 | # backup codes
36 | if [[ ${debug} > 0 ]]; then
37 | cp -r ${root_dir}/codes ${exp_dir}/train/codes_backup${suffix}
38 | fi
39 |
40 |
41 | # run
42 | num_gpus=`echo ${gpu_ids} | awk -F\, '{print NF}'`
43 | if [[ ${num_gpus} > 1 ]]; then
44 | dist_args="-m torch.distributed.launch --nproc_per_node ${num_gpus} --master_port ${master_port}"
45 | fi
46 |
47 | CUDA_VISIBLE_DEVICES=${gpu_ids} \
48 | python ${dist_args} ${root_dir}/codes/main.py \
49 | --exp_dir ${exp_dir} \
50 | --mode train \
51 | --opt train${suffix}.yml \
52 | --gpu_ids ${gpu_ids} \
53 | > ${exp_dir}/train/train${suffix}.log 2>&1 &
54 |
55 |
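
For example, launching BD TecoGAN training on the REDS configs above; the run detaches into the background and logs to `<exp_dir>/train/train.log`:

```bash
bash ./train.sh BD TecoGAN/TecoGAN_REDS_4xSR_2GPU
```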
--------------------------------------------------------------------------------