├── LICENSE ├── README.md ├── codes ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── paired_folder_dataset.py │ ├── paired_lmdb_dataset.py │ ├── unpaired_folder_dataset.py │ └── unpaired_lmdb_dataset.py ├── main.py ├── metrics │ ├── LPIPS │ │ ├── LICENSE │ │ ├── __init__.py │ │ └── models │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── dist_model.py │ │ │ ├── networks_basic.py │ │ │ ├── pretrained_networks.py │ │ │ └── weights │ │ │ └── v0.1 │ │ │ ├── alex.pth │ │ │ ├── squeeze.pth │ │ │ └── vgg.pth │ ├── __init__.py │ ├── metric_calculator.py │ └── model_summary.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── networks │ │ ├── __init__.py │ │ ├── base_nets.py │ │ ├── tecogan_nets.py │ │ └── vgg_nets.py │ ├── optim │ │ ├── __init__.py │ │ ├── losses.py │ │ └── lr_schedules.py │ ├── vsr_model.py │ └── vsrgan_model.py ├── official_metrics │ ├── LPIPSmodels │ │ ├── LPIPSsource.txt │ │ ├── base_model.py │ │ ├── dist_model.py │ │ ├── networks_basic.py │ │ ├── pretrained_networks.py │ │ ├── util.py │ │ └── v0.1 │ │ │ └── alex.pth │ ├── evaluate.py │ └── metrics.py └── utils │ ├── __init__.py │ ├── base_utils.py │ ├── data_utils.py │ ├── dist_utils.py │ └── net_utils.py ├── data ├── meta │ └── REDS │ │ ├── test_list.txt │ │ └── train_list.txt └── put_data_here ├── experiments_BD ├── FRVSR │ ├── FRVSR_REDS_2xSR_2GPU │ │ ├── test.yml │ │ └── train.yml │ ├── FRVSR_REDS_4xSR_2GPU │ │ ├── test.yml │ │ └── train.yml │ └── FRVSR_VimeoTecoGAN_4xSR_2GPU │ │ ├── test.yml │ │ └── train.yml └── TecoGAN │ ├── TecoGAN_REDS_2xSR_2GPU │ ├── test.yml │ └── train.yml │ ├── TecoGAN_REDS_4xSR_2GPU │ ├── test.yml │ └── train.yml │ └── TecoGAN_VimeoTecoGAN_4xSR_2GPU │ ├── test.yml │ └── train.yml ├── experiments_BI ├── FRVSR │ └── FRVSR_VimeoTecoGAN_4xSR_2GPU │ │ ├── test.yml │ │ └── train.yml └── TecoGAN │ └── TecoGAN_VimeoTecoGAN_4xSR_2GPU │ ├── test.yml │ └── train.yml ├── pretrained_models └── put_the_pretrained_models_here ├── profile.sh ├── resources ├── benchmark.png ├── bridge.gif ├── fire.gif ├── foliage.gif ├── losses.png ├── metrics.png └── pond.gif ├── scripts ├── create_lmdb.py ├── download │ ├── download_datasets.sh │ └── download_models.sh ├── generate_lr_bi.m ├── monitor_training.py └── resize_bd.py ├── test.sh └── train.sh /README.md: -------------------------------------------------------------------------------- 1 | # TecoGAN-PyTorch 2 | 3 | ### Introduction 4 | This is a PyTorch reimplementation of **TecoGAN**: **Te**mporally **Co**herent **GAN** for Video Super-Resolution (VSR). Please refer to the official TensorFlow implementation [TecoGAN-TensorFlow](https://github.com/thunil/TecoGAN) for more information. 5 | 6 |

> Demo clips: see the GIF files under `resources/` (`bridge.gif`, `fire.gif`, `foliage.gif`, `pond.gif`).
15 | 16 | 17 | 18 | ### Updates 19 | - 11/2021: Supported 2x SR. 20 | - 10/2021: Supported model training/testing on the [REDS](https://seungjunnah.github.io/Datasets/reds.html) dataset. 21 | - 07/2021: Upgraded codebase to support multi-GPU training & testing. 22 | 23 | 24 | 25 | ### Features 26 | - **Better Performance**: This repo provides model with smaller size yet better performance than the official repo. See our [Benchmark](https://github.com/skycrapers/TecoGAN-PyTorch#benchmark). 27 | - **Multiple Degradations**: This repo supports two types of degradation, BI (Matlab's imresize with the option bicubic) & BD (Gaussian Blurring + Down-sampling). 28 | - **Unified Framework**: This repo provides a unified framework for distortion-based and perception-based VSR methods. 29 | 30 | 31 | 32 | ### Contents 33 | 1. [Dependencies](#dependencies) 34 | 1. [Testing](#testing) 35 | 1. [Training](#training) 36 | 1. [Benchmark](#benchmark) 37 | 1. [License & Citation](#license--citation) 38 | 1. [Acknowledgements](#acknowledgements) 39 | 40 | 41 | 42 | ## Dependencies 43 | - Ubuntu >= 16.04 44 | - NVIDIA GPU + CUDA 45 | - Python >= 3.7 46 | - PyTorch >= 1.4.0 47 | - Python packages: numpy, matplotlib, opencv-python, pyyaml, lmdb 48 | - (Optional) Matlab >= R2016b 49 | 50 | 51 | 52 | ## Testing 53 | 54 | **Note:** We apply different models according to the degradation type. The following steps are for `4xSR` under `BD` degradation. You can switch to `2xSR` or `BI` degradation by replacing all `4x` to `2x` and `BD` to `BI` below. 55 | 56 | 1. Download the official Vid4 and ToS3 datasets. In `BD` mode, only ground-truth data is needed. 57 | ```bash 58 | bash ./scripts/download/download_datasets.sh BD 59 | ``` 60 | > You can manually download these datasets from Google Drive, and unzip them under `./data`. 61 | > * Vid4 Dataset [[Ground-Truth](https://drive.google.com/file/d/1T8TuyyOxEUfXzCanH5kvNH2iA8nI06Wj/view?usp=sharing)] [[Low Resolution (BD)](https://drive.google.com/file/d/1-5NFW6fEPUczmRqKHtBVyhn2Wge6j3ma/view?usp=sharing)] [[Low Resolution (BI)](https://drive.google.com/file/d/1Kg0VBgk1r9I1c4f5ZVZ4sbfqtVRYub91/view?usp=sharing)] 62 | > * ToS3 Dataset [[Ground-Truth](https://drive.google.com/file/d/1XoR_NVBR-LbZOA8fXh7d4oPV0M8fRi8a/view?usp=sharing)] [[Low Resolution (BD)](https://drive.google.com/file/d/1rDCe61kR-OykLyCo2Ornd2YgPnul2ffM/view?usp=sharing)] [[Low Resolution (BI)](https://drive.google.com/file/d/1FNuC0jajEjH9ycqDkH4cZQ3_eUqjxzzf/view?usp=sharing)] 63 | 64 | The dataset structure is shown as below. 65 | ```tex 66 | data 67 | ├─ Vid4 68 | ├─ GT # Ground-Truth (GT) sequences 69 | └─ calendar 70 | └─ ***.png 71 | ├─ Gaussian4xLR # Low Resolution (LR) sequences in BD degradation 72 | └─ calendar 73 | └─ ***.png 74 | └─ Bicubic4xLR # Low Resolution (LR) sequences in BI degradation 75 | └─ calendar 76 | └─ ***.png 77 | └─ ToS3 78 | ├─ GT 79 | ├─ Gaussian4xLR 80 | └─ Bicubic4xLR 81 | ``` 82 | 83 | 2. Download our pre-trained TecoGAN model. 84 | ```bash 85 | bash ./scripts/download/download_models.sh BD TecoGAN 86 | ``` 87 | > You can download the model from [[BD-4x-Vimeo](https://drive.google.com/file/d/13FPxKE6q7tuRrfhTE7GB040jBeURBj58/view?usp=sharing)][[BI-4x-Vimeo](https://drive.google.com/file/d/1ie1F7wJcO4mhNWK8nPX7F0LgOoPzCwEu/view?usp=sharing)][[BD-4x-REDS](https://drive.google.com/file/d/1vMvMbv_BvC2G-qCcaOBkNnkMh_gLNe6q/view?usp=sharing)][[BD-2x-REDS](https://drive.google.com/file/d/1XN5D4hjNvitO9Kb3OrYiKGjwNU0b43ZI/view?usp=sharing)], and put it under `./pretrained_models`. 
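> (Optional) To confirm the datasets were unpacked into the layout shown above, you can run a quick check like the sketch below. This helper is **not** part of this repo — the name `check_pairs` and the default paths are only illustrative — and the LR comparison only applies if you also downloaded the LR sequences (in `BD` mode only the GT data is required).
```python
import os
import os.path as osp

def check_pairs(root='./data/Vid4', gt_dir='GT', lr_dir='Bicubic4xLR'):
    """Verify each GT sequence has an LR counterpart with the same frame count."""
    gt_root, lr_root = osp.join(root, gt_dir), osp.join(root, lr_dir)
    for seq in sorted(os.listdir(gt_root)):
        n_gt = len(os.listdir(osp.join(gt_root, seq)))
        n_lr = len(os.listdir(osp.join(lr_root, seq)))
        assert n_gt == n_lr, f'{seq}: {n_gt} GT frames vs. {n_lr} LR frames'
        print(f'{seq}: {n_gt} frames - OK')

if __name__ == '__main__':
    check_pairs()                         # BI data for Vid4
    # check_pairs(lr_dir='Gaussian4xLR')  # BD data, if downloaded
```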
88 | 89 | 3. Run TecoGAN for 4x SR. The results will be saved in `./results`. You can specify which model and how many gpus to be used in `test.sh`. 90 | ```bash 91 | bash ./test.sh BD TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU 92 | ``` 93 | 94 | 4. Evaluate the upsampled results using the official metrics. These codes are borrowed from [TecoGAN-TensorFlow](https://github.com/thunil/TecoGAN), with minor modifications to adapt to the BI degradation. 95 | ```bash 96 | python ./codes/official_metrics/evaluate.py -m TecoGAN_4x_BD_Vimeo_iter500K 97 | ``` 98 | 99 | 5. Profile model (FLOPs, parameters and speed). You can modify the last argument to specify the size of the LR video. 100 | ```bash 101 | bash ./profile.sh BD TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU 3x134x320 102 | ``` 103 | 104 | ## Training 105 | **Note:** Due to the inaccessibility of the VimeoTecoGAN dataset, we recommend using other public datasets, e.g., REDS, for model training. To use REDS as the training dataset, just download it from [here](https://seungjunnah.github.io/Datasets/reds.html) and replace the following `VimeoTecoGAN` to `REDS`. 106 | 107 | 1. Download the official training dataset according to the instructions in [TecoGAN-TensorFlow](https://github.com/thunil/TecoGAN), rename to `VimeoTecoGAN/Raw`, and place under `./data`. 108 | 109 | 2. Generate LMDB for GT data to accelerate IO. The LR counterpart will then be generated on the fly during training. 110 | ```bash 111 | python ./scripts/create_lmdb.py --dataset VimeoTecoGAN --raw_dir ./data/VimeoTecoGAN/Raw --lmdb_dir ./data/VimeoTecoGAN/GT.lmdb 112 | ``` 113 | 114 | The following shows the dataset structure after finishing the above two steps. 115 | ```tex 116 | data 117 | ├─ VimeoTecoGAN 118 | ├─ Raw # Raw dataset 119 | ├─ scene_2000 120 | └─ ***.png 121 | ├─ scene_2001 122 | └─ ***.png 123 | └─ ... 124 | └─ GT.lmdb # LMDB dataset 125 | ├─ data.mdb 126 | ├─ lock.mdb 127 | └─ meta_info.pkl # each key has format: [vid]_[total_frame]x[h]x[w]_[i-th_frame] 128 | ``` 129 | 130 | 3. **(Optional, this step is only required for BI degradation)** Manually generate the LR sequences with the Matlab's imresize function, and then create LMDB for them. 131 | ```bash 132 | # Generate the raw LR video sequences. Results will be saved at ./data/VimeoTecoGAN/Bicubic4xLR 133 | matlab -nodesktop -nosplash -r "cd ./scripts; generate_lr_bi" 134 | 135 | # Create LMDB for the LR video sequences 136 | python ./scripts/create_lmdb.py --dataset VimeoTecoGAN --raw_dir ./data/VimeoTecoGAN/Bicubic4xLR --lmdb_dir ./data/VimeoTecoGAN/Bicubic4xLR.lmdb 137 | ``` 138 | 139 | 4. Train a FRVSR model first, which can provide a better initialization for the subsequent TecoGAN training. FRVSR has the same generator as TecoGAN, but without perceptual training (GAN and perceptual losses). 140 | ```bash 141 | bash ./train.sh BD FRVSR/FRVSR_VimeoTecoGAN_4xSR_2GPU 142 | ``` 143 | > You can download and use our pre-trained FRVSR models instead of training from scratch. 
[[BD-4x-Vimeo](https://drive.google.com/file/d/11kPVS04a3B3k0SD-mKEpY_Q8WL7KrTIA/view?usp=sharing)] [[BI-4x-Vimeo](https://drive.google.com/file/d/1wejMAFwIBde_7sz-H7zwlOCbCvjt3G9L/view?usp=sharing)] [[BD-4x-REDS](https://drive.google.com/file/d/1YyTwBFF6P9xy6b9UBILF4ornCdmWbDLY/view?usp=sharing)][[BD-2x-REDS](https://drive.google.com/file/d/1ibsr3td1rYeKsDc2d-J9-8jFURBFc_ST/view?usp=sharing)] 144 | 145 | When the training is complete, set the generator's `load_path` in `experiments_BD/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/train.yml` to the latest checkpoint weight of the FRVSR model. 146 | 147 | 5. Train a TecoGAN model. You can specify which gpu to be used in `train.sh`. By default, the training is conducted in the background and the output info will be logged in `./experiments_BD/TecoGAN/TecoGAN_VimeoTecoGAN/train/train.log`. 148 | ```bash 149 | bash ./train.sh BD TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU 150 | ``` 151 | 152 | 6. Run the following script to monitor the training process and visualize the validation performance. 153 | ```bash 154 | python ./scripts/monitor_training.py -dg BD -m TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU -ds Vid4 155 | ``` 156 | > Note that the validation results are NOT exactly the same as the testing results mentioned above due to different implementation of the metrics. The differences are caused by croping policy, LPIPS version and some other issues. 157 | 158 |
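> As noted in step 2, under `BD` degradation the LR training frames are generated on the fly from the GT data by Gaussian blurring followed by down-sampling. The sketch below only illustrates that degradation: it borrows the kernel-size rule `ksize = 2 * int(sigma * 3) + 1` from `codes/data/__init__.py`, while the actual training-time implementation may differ in details such as border handling (this is also why the training crop is enlarged by `2 * int(sigma * 3)` pixels before blurring).
```python
import cv2  # opencv-python

def bd_degrade(gt_frame, scale=4, sigma=1.5):
    """Illustrative BD degradation: Gaussian blur, then sub-sample by `scale`."""
    # kernel-size rule mirrored from codes/data/__init__.py;
    # `sigma` is set under dataset.degradation in the yml configs
    ksize = 2 * int(sigma * 3.0) + 1
    blurred = cv2.GaussianBlur(gt_frame, (ksize, ksize), sigma)
    return blurred[::scale, ::scale, ...]  # keep every `scale`-th pixel
```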

> Training curves: see `resources/losses.png` and `resources/metrics.png`.
162 | 163 | 164 | 165 | ## Benchmark 166 | 167 |

> Benchmark table: see `resources/benchmark.png`.
170 | 171 | > [1] FLOPs & speed are computed on RGB sequence with resolution 134\*320 on a single NVIDIA 1080Ti GPU. \ 172 | > [2] Both FRVSR & TecoGAN use 10 residual blocks, while TecoGAN+ has 16 residual blocks. 173 | 174 | 175 | 176 | ## License & Citation 177 | If you use this code for your research, please cite the following paper and our project. 178 | ```tex 179 | @article{tecogan2020, 180 | title={Learning temporal coherence via self-supervision for GAN-based video generation}, 181 | author={Chu, Mengyu and Xie, You and Mayer, Jonas and Leal-Taix{\'e}, Laura and Thuerey, Nils}, 182 | journal={ACM Transactions on Graphics (TOG)}, 183 | volume={39}, 184 | number={4}, 185 | pages={75--1}, 186 | year={2020}, 187 | publisher={ACM New York, NY, USA} 188 | } 189 | ``` 190 | ```tex 191 | @misc{tecogan_pytorch, 192 | author={Deng, Jianing and Zhuo, Cheng}, 193 | title={PyTorch Implementation of Temporally Coherent GAN (TecoGAN) for Video Super-Resolution}, 194 | howpublished="\url{https://github.com/skycrapers/TecoGAN-PyTorch}", 195 | year={2020}, 196 | } 197 | ``` 198 | 199 | 200 | 201 | ## Acknowledgements 202 | This code is built on [TecoGAN-TensorFlow](https://github.com/thunil/TecoGAN), [BasicSR](https://github.com/xinntao/BasicSR) and [LPIPS](https://github.com/richzhang/PerceptualSimilarity). We thank the authors for sharing their codes. 203 | 204 | If you have any questions, feel free to email me `jn.deng@foxmail.com` 205 | -------------------------------------------------------------------------------- /codes/data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.distributed import DistributedSampler 4 | 5 | from .paired_lmdb_dataset import PairedLMDBDataset 6 | from .unpaired_lmdb_dataset import UnpairedLMDBDataset 7 | from .paired_folder_dataset import PairedFolderDataset 8 | from .unpaired_folder_dataset import UnpairedFolderDataset 9 | 10 | 11 | def create_dataloader(opt, phase, idx): 12 | # set params 13 | data_opt = opt['dataset'].get(idx) 14 | degradation_type = opt['dataset']['degradation']['type'] 15 | 16 | # === create loader for training === # 17 | if phase == 'train': 18 | # check dataset 19 | assert data_opt['name'] in ('VimeoTecoGAN', 'REDS'), \ 20 | f'Unknown Dataset: {data_opt["name"]}' 21 | 22 | if degradation_type == 'BI': 23 | # create dataset 24 | dataset = PairedLMDBDataset( 25 | data_opt, 26 | scale=opt['scale'], 27 | tempo_extent=opt['train']['tempo_extent'], 28 | moving_first_frame=opt['train'].get('moving_first_frame', False), 29 | moving_factor=opt['train'].get('moving_factor', 1.0)) 30 | 31 | elif degradation_type == 'BD': 32 | # enlarge the crop size to incorporate border 33 | sigma = opt['dataset']['degradation']['sigma'] 34 | enlarged_crop_size = data_opt['crop_size'] + 2 * int(sigma * 3.0) 35 | 36 | # create dataset 37 | dataset = UnpairedLMDBDataset( 38 | data_opt, 39 | crop_size=enlarged_crop_size, # override 40 | tempo_extent=opt['train']['tempo_extent'], 41 | moving_first_frame=opt['train'].get('moving_first_frame', False), 42 | moving_factor=opt['train'].get('moving_factor', 1.0)) 43 | 44 | else: 45 | raise ValueError(f'Unrecognized degradation type: {degradation_type}') 46 | 47 | # create data loader 48 | if opt['dist']: 49 | batch_size = data_opt['batch_size_per_gpu'] 50 | shuffle = False 51 | sampler = DistributedSampler(dataset) 52 | else: 53 | batch_size = data_opt['batch_size_per_gpu'] 54 | shuffle = True 55 | 
sampler = None 56 | 57 | loader = DataLoader( 58 | dataset=dataset, 59 | batch_size=batch_size, 60 | shuffle=shuffle, 61 | drop_last=True, 62 | sampler=sampler, 63 | num_workers=data_opt['num_worker_per_gpu'], 64 | pin_memory=data_opt['pin_memory']) 65 | 66 | # === create loader for testing === # 67 | elif phase == 'test': 68 | # create data loader 69 | if 'lr_seq_dir' in data_opt and data_opt['lr_seq_dir']: 70 | loader = DataLoader( 71 | dataset=PairedFolderDataset(data_opt), 72 | batch_size=1, 73 | shuffle=False, 74 | num_workers=data_opt['num_worker_per_gpu'], 75 | pin_memory=data_opt['pin_memory']) 76 | 77 | else: 78 | assert degradation_type == 'BD', \ 79 | '"lr_seq_dir" is required for BI mode' 80 | 81 | sigma = opt['dataset']['degradation']['sigma'] 82 | ksize = 2 * int(sigma * 3.0) + 1 83 | 84 | loader = DataLoader( 85 | dataset=UnpairedFolderDataset( 86 | data_opt, scale=opt['scale'], sigma=sigma, ksize=ksize), 87 | batch_size=1, 88 | shuffle=False, 89 | num_workers=data_opt['num_worker_per_gpu'], 90 | pin_memory=data_opt['pin_memory']) 91 | 92 | else: 93 | raise ValueError(f'Unrecognized phase: {phase}') 94 | 95 | return loader 96 | -------------------------------------------------------------------------------- /codes/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class BaseDataset(Dataset): 8 | def __init__(self, data_opt, **kwargs): 9 | # dict to attr 10 | for kw, args in data_opt.items(): 11 | setattr(self, kw, args) 12 | 13 | # can be used to override options defined in data_opt 14 | for kw, args in kwargs.items(): 15 | setattr(self, kw, args) 16 | 17 | def __len__(self): 18 | pass 19 | 20 | def __getitem__(self, item): 21 | pass 22 | 23 | def check_info(self, gt_keys, lr_keys): 24 | if len(gt_keys) != len(lr_keys): 25 | raise ValueError( 26 | f'GT & LR contain different numbers of images ({len(gt_keys)} vs. {len(lr_keys)})') 27 | 28 | for i, (gt_key, lr_key) in enumerate(zip(gt_keys, lr_keys)): 29 | gt_info = self.parse_lmdb_key(gt_key) 30 | lr_info = self.parse_lmdb_key(lr_key) 31 | 32 | if gt_info[0] != lr_info[0]: 33 | raise ValueError( 34 | f'video index mismatch ({gt_info[0]} vs. {lr_info[0]} for the {i} key)') 35 | 36 | gt_num, gt_h, gt_w = gt_info[1] 37 | lr_num, lr_h, lr_w = lr_info[1] 38 | s = self.scale 39 | if (gt_num != lr_num) or (gt_h != lr_h * s) or (gt_w != lr_w * s): 40 | raise ValueError( 41 | f'video size mismatch ({gt_info[1]} vs. {lr_info[1]} for the {i} key)') 42 | 43 | if gt_info[2] != lr_info[2]: 44 | raise ValueError( 45 | f'frame mismatch ({gt_info[2]} vs. 
{lr_info[2]} for the {i} key)') 46 | 47 | @staticmethod 48 | def init_lmdb(seq_dir): 49 | env = lmdb.open( 50 | seq_dir, readonly=True, lock=False, readahead=False, meminit=False) 51 | return env 52 | 53 | @staticmethod 54 | def parse_lmdb_key(key): 55 | key_lst = key.split('_') 56 | idx, size, frm = key_lst[:-2], key_lst[-2], int(key_lst[-1]) 57 | idx = '_'.join(idx) 58 | size = tuple(map(int, size.split('x'))) # n_frm, h, w 59 | return idx, size, frm 60 | 61 | @staticmethod 62 | def read_lmdb_frame(env, key, size): 63 | with env.begin(write=False) as txn: 64 | buf = txn.get(key.encode('ascii')) 65 | frm = np.frombuffer(buf, dtype=np.uint8).reshape(*size) 66 | return frm 67 | 68 | def crop_sequence(self, **kwargs): 69 | pass 70 | 71 | @staticmethod 72 | def augment_sequence(**kwargs): 73 | pass 74 | -------------------------------------------------------------------------------- /codes/data/paired_folder_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | from .base_dataset import BaseDataset 9 | from utils.base_utils import retrieve_files 10 | 11 | 12 | class PairedFolderDataset(BaseDataset): 13 | """ Folder dataset for paired data. It supports both BI & BD degradation. 14 | """ 15 | 16 | def __init__(self, data_opt, **kwargs): 17 | super(PairedFolderDataset, self).__init__(data_opt, **kwargs) 18 | 19 | # get keys 20 | gt_keys = sorted(os.listdir(self.gt_seq_dir)) 21 | lr_keys = sorted(os.listdir(self.lr_seq_dir)) 22 | self.keys = sorted(list(set(gt_keys) & set(lr_keys))) 23 | 24 | # filter keys 25 | sel_keys = set(self.keys) 26 | if hasattr(self, 'filter_file') and self.filter_file is not None: 27 | with open(self.filter_file, 'r') as f: 28 | sel_keys = {line.strip() for line in f} 29 | elif hasattr(self, 'filter_list') and self.filter_list is not None: 30 | sel_keys = set(self.filter_list) 31 | self.keys = sorted(list(sel_keys & set(self.keys))) 32 | 33 | def __len__(self): 34 | return len(self.keys) 35 | 36 | def __getitem__(self, item): 37 | key = self.keys[item] 38 | 39 | # load gt frames 40 | gt_seq = [] 41 | for frm_path in retrieve_files(osp.join(self.gt_seq_dir, key)): 42 | frm = cv2.imread(frm_path)[..., ::-1] 43 | gt_seq.append(frm) 44 | gt_seq = np.stack(gt_seq) # thwc|rgb|uint8 45 | 46 | # load lr frames 47 | lr_seq = [] 48 | for frm_path in retrieve_files(osp.join(self.lr_seq_dir, key)): 49 | frm = cv2.imread(frm_path)[..., ::-1].astype(np.float32) / 255.0 50 | lr_seq.append(frm) 51 | lr_seq = np.stack(lr_seq) # thwc|rgb|float32 52 | 53 | # convert to tensor 54 | gt_tsr = torch.from_numpy(np.ascontiguousarray(gt_seq)) # uint8 55 | lr_tsr = torch.from_numpy(np.ascontiguousarray(lr_seq)) # float32 56 | 57 | # gt: thwc|rgb|uint8 | lr: thwc|rgb|float32 58 | return { 59 | 'gt': gt_tsr, 60 | 'lr': lr_tsr, 61 | 'seq_idx': key, 62 | 'frm_idx': sorted(os.listdir(osp.join(self.gt_seq_dir, key))) 63 | } 64 | -------------------------------------------------------------------------------- /codes/data/paired_lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pickle 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | 9 | from .base_dataset import BaseDataset 10 | 11 | 12 | class PairedLMDBDataset(BaseDataset): 13 | """ LMDB dataset for paired data (for BI degradation) 14 | """ 15 | 16 | def __init__(self, data_opt, **kwargs): 17 | 
super(PairedLMDBDataset, self).__init__(data_opt, **kwargs) 18 | 19 | # load meta info 20 | gt_meta = pickle.load( 21 | open(osp.join(self.gt_seq_dir, 'meta_info.pkl'), 'rb')) 22 | lr_meta = pickle.load( 23 | open(osp.join(self.lr_seq_dir, 'meta_info.pkl'), 'rb')) 24 | gt_keys = sorted(gt_meta['keys']) 25 | lr_keys = sorted(lr_meta['keys']) 26 | 27 | self.check_info(gt_keys, lr_keys) 28 | self.gt_lr_keys = list(zip(gt_keys, lr_keys)) 29 | 30 | # use partial videos 31 | if hasattr(self, 'filter_file') and self.filter_file is not None: 32 | with open(self.filter_file, 'r') as f: 33 | sel_seqs = { line.strip() for line in f } 34 | self.gt_lr_keys = list(filter( 35 | lambda x: self.parse_lmdb_key(x[0])[0] in sel_seqs, self.gt_lr_keys)) 36 | 37 | # register parameters 38 | self.gt_env = None 39 | self.lr_env = None 40 | 41 | def __len__(self): 42 | return len(self.gt_lr_keys) 43 | 44 | def __getitem__(self, item): 45 | if self.gt_env is None: 46 | self.gt_env = self.init_lmdb(self.gt_seq_dir) 47 | if self.lr_env is None: 48 | self.lr_env = self.init_lmdb(self.lr_seq_dir) 49 | 50 | # parse info 51 | gt_key, lr_key = self.gt_lr_keys[item] 52 | idx, (tot_frm, gt_h, gt_w), cur_frm = self.parse_lmdb_key(gt_key) 53 | _, (_, lr_h, lr_w), _ = self.parse_lmdb_key(lr_key) 54 | 55 | c = 3 if self.data_type.lower() == 'rgb' else 1 56 | s = self.scale 57 | assert (gt_h == lr_h * s) and (gt_w == lr_w * s) 58 | 59 | # get frames 60 | gt_frms, lr_frms = [], [] 61 | if self.moving_first_frame and (random.uniform(0, 1) > self.moving_factor): 62 | # load the first gt&lr frame 63 | gt_frm = self.read_lmdb_frame( 64 | self.gt_env, gt_key, size=(gt_h, gt_w, c)) 65 | gt_frm = gt_frm.transpose(2, 0, 1) # chw|rgb|uint8 66 | 67 | lr_frm = self.read_lmdb_frame( 68 | self.lr_env, lr_key, size=(lr_h, lr_w, c)) 69 | lr_frm = lr_frm.transpose(2, 0, 1) # chw|rgb|uint8 70 | 71 | # generate random moving parameters 72 | offsets = np.floor( 73 | np.random.uniform(-1.5, 1.5, size=(self.tempo_extent, 2))) 74 | offsets = offsets.astype(np.int32) 75 | pos = np.cumsum(offsets, axis=0) 76 | min_pos = np.min(pos, axis=0) 77 | topleft_pos = pos - min_pos 78 | range_pos = np.max(pos, axis=0) - min_pos 79 | c_h, c_w = lr_h - range_pos[0], lr_w - range_pos[1] 80 | 81 | # generate frames 82 | for i in range(self.tempo_extent): 83 | lr_top, lr_left = topleft_pos[i] 84 | lr_frms.append(lr_frm[ 85 | :, lr_top: lr_top + c_h, lr_left: lr_left + c_w].copy()) 86 | 87 | gt_top, gt_left = lr_top * s, lr_left * s 88 | gt_frms.append(gt_frm[ 89 | :, gt_top: gt_top + c_h * s, gt_left: gt_left + c_w * s].copy()) 90 | else: 91 | # read frames 92 | for i in range(cur_frm, cur_frm + self.tempo_extent): 93 | if i >= tot_frm: 94 | # reflect temporal paddding, e.g., (0,1,2) -> (0,1,2,1,0) 95 | gt_key = '{}_{}x{}x{}_{:04d}'.format( 96 | idx, tot_frm, gt_h, gt_w, 2 * tot_frm - i - 2) 97 | lr_key = '{}_{}x{}x{}_{:04d}'.format( 98 | idx, tot_frm, lr_h, lr_w, 2 * tot_frm - i - 2) 99 | else: 100 | gt_key = '{}_{}x{}x{}_{:04d}'.format( 101 | idx, tot_frm, gt_h, gt_w, i) 102 | lr_key = '{}_{}x{}x{}_{:04d}'.format( 103 | idx, tot_frm, lr_h, lr_w, i) 104 | 105 | gt_frm = self.read_lmdb_frame( 106 | self.gt_env, gt_key, size=(gt_h, gt_w, c)) 107 | gt_frm = gt_frm.transpose(2, 0, 1) # chw|rgb|uint8 108 | gt_frms.append(gt_frm) 109 | 110 | lr_frm = self.read_lmdb_frame( 111 | self.lr_env, lr_key, size=(lr_h, lr_w, c)) 112 | lr_frm = lr_frm.transpose(2, 0, 1) 113 | lr_frms.append(lr_frm) 114 | 115 | gt_frms = np.stack(gt_frms) # tchw|rgb|uint8 116 | lr_frms = 
np.stack(lr_frms) 117 | 118 | # crop randomly 119 | gt_pats, lr_pats = self.crop_sequence(gt_frms, lr_frms) 120 | 121 | # augment patches 122 | gt_pats, lr_pats = self.augment_sequence(gt_pats, lr_pats) 123 | 124 | # convert to tensor and normalize to range [0, 1] 125 | gt_tsr = torch.FloatTensor(np.ascontiguousarray(gt_pats)) / 255.0 126 | lr_tsr = torch.FloatTensor(np.ascontiguousarray(lr_pats)) / 255.0 127 | 128 | # tchw|rgb|float32 129 | return {'gt': gt_tsr, 'lr': lr_tsr} 130 | 131 | def crop_sequence(self, gt_frms, lr_frms): 132 | gt_csz = self.gt_crop_size 133 | lr_csz = self.gt_crop_size // self.scale 134 | 135 | lr_h, lr_w = lr_frms.shape[-2:] 136 | assert (lr_csz <= lr_h) and (lr_csz <= lr_w), \ 137 | 'the crop size is larger than the image size' 138 | 139 | # crop lr 140 | lr_top = random.randint(0, lr_h - lr_csz) 141 | lr_left = random.randint(0, lr_w - lr_csz) 142 | lr_pats = lr_frms[ 143 | ..., lr_top: lr_top + lr_csz, lr_left: lr_left + lr_csz] 144 | 145 | # crop gt 146 | gt_top = lr_top * self.scale 147 | gt_left = lr_left * self.scale 148 | gt_pats = gt_frms[ 149 | ..., gt_top: gt_top + gt_csz, gt_left: gt_left + gt_csz] 150 | 151 | return gt_pats, lr_pats 152 | 153 | @staticmethod 154 | def augment_sequence(gt_pats, lr_pats): 155 | # flip 156 | axis = random.randint(1, 3) 157 | if axis > 1: 158 | gt_pats = np.flip(gt_pats, axis) 159 | lr_pats = np.flip(lr_pats, axis) 160 | 161 | # rotate 90 degree 162 | k = random.randint(0, 3) 163 | gt_pats = np.rot90(gt_pats, k, (2, 3)) 164 | lr_pats = np.rot90(lr_pats, k, (2, 3)) 165 | 166 | return gt_pats, lr_pats 167 | -------------------------------------------------------------------------------- /codes/data/unpaired_folder_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | from .base_dataset import BaseDataset 9 | from utils.base_utils import retrieve_files 10 | 11 | 12 | class UnpairedFolderDataset(BaseDataset): 13 | """ Folder dataset for unpaired data (for BD degradation) 14 | """ 15 | 16 | def __init__(self, data_opt, **kwargs): 17 | super(UnpairedFolderDataset, self).__init__(data_opt, **kwargs) 18 | 19 | # get keys 20 | self.keys = sorted(os.listdir(self.gt_seq_dir)) 21 | 22 | # filter keys 23 | sel_keys = set(self.keys) 24 | if hasattr(self, 'filter_file') and self.filter_file is not None: 25 | with open(self.filter_file, 'r') as f: 26 | sel_keys = {line.strip() for line in f} 27 | elif hasattr(self, 'filter_list') and self.filter_list is not None: 28 | sel_keys = set(self.filter_list) 29 | self.keys = sorted(list(sel_keys & set(self.keys))) 30 | 31 | def __len__(self): 32 | return len(self.keys) 33 | 34 | def __getitem__(self, item): 35 | key = self.keys[item] 36 | 37 | # load gt frames 38 | gt_seq = [] 39 | for frm_path in retrieve_files(osp.join(self.gt_seq_dir, key)): 40 | gt_frm = cv2.imread(frm_path)[..., ::-1] 41 | gt_seq.append(gt_frm) 42 | gt_seq = np.stack(gt_seq) # thwc|rgb|uint8 43 | 44 | # convert to tensor 45 | gt_tsr = torch.from_numpy(np.ascontiguousarray(gt_seq)) # uint8 46 | 47 | # gt: thwc|rgb|uint8 48 | return { 49 | 'gt': gt_tsr, 50 | 'seq_idx': key, 51 | 'frm_idx': sorted(os.listdir(osp.join(self.gt_seq_dir, key))) 52 | } 53 | -------------------------------------------------------------------------------- /codes/data/unpaired_lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 
import pickle 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from .base_dataset import BaseDataset 9 | 10 | 11 | class UnpairedLMDBDataset(BaseDataset): 12 | """ LMDB dataset for unpaired data (for BD degradation) 13 | """ 14 | 15 | def __init__(self, data_opt, **kwargs): 16 | super(UnpairedLMDBDataset, self).__init__(data_opt, **kwargs) 17 | 18 | # load meta info 19 | meta = pickle.load( 20 | open(osp.join(self.seq_dir, 'meta_info.pkl'), 'rb')) 21 | self.keys = sorted(meta['keys']) 22 | 23 | # use partial videos 24 | if hasattr(self, 'filter_file') and self.filter_file is not None: 25 | with open(self.filter_file, 'r') as f: 26 | sel_seqs = { line.strip() for line in f } 27 | self.keys = list(filter( 28 | lambda x: self.parse_lmdb_key(x)[0] in sel_seqs, self.keys)) 29 | 30 | # register parameters 31 | self.env = None 32 | 33 | def __len__(self): 34 | return len(self.keys) 35 | 36 | def __getitem__(self, item): 37 | if self.env is None: 38 | self.env = self.init_lmdb(self.seq_dir) 39 | 40 | # parse info 41 | key = self.keys[item] 42 | idx, (tot_frm, h, w), cur_frm = self.parse_lmdb_key(key) 43 | c = 3 if self.data_type.lower() == 'rgb' else 1 44 | 45 | # get frames 46 | frms = [] 47 | if self.moving_first_frame and (random.uniform(0, 1) > self.moving_factor): 48 | # load data 49 | frm = self.read_lmdb_frame(self.env, key, size=(h, w, c)) 50 | frm = frm.transpose(2, 0, 1) # chw|rgb|uint8 51 | 52 | # generate random moving parameters 53 | offsets = np.floor( 54 | np.random.uniform(-3.5, 4.5, size=(self.tempo_extent, 2))) 55 | offsets = offsets.astype(np.int32) 56 | pos = np.cumsum(offsets, axis=0) 57 | min_pos = np.min(pos, axis=0) 58 | topleft_pos = pos - min_pos 59 | range_pos = np.max(pos, axis=0) - min_pos 60 | c_h, c_w = h - range_pos[0], w - range_pos[1] 61 | 62 | # generate frames 63 | for i in range(self.tempo_extent): 64 | top, left = topleft_pos[i] 65 | frms.append(frm[:, top: top + c_h, left: left + c_w].copy()) 66 | else: 67 | # read frames 68 | for i in range(cur_frm, cur_frm + self.tempo_extent): 69 | if i >= tot_frm: 70 | # reflect temporal paddding, e.g., (0,1,2) -> (0,1,2,1,0) 71 | key = '{}_{}x{}x{}_{:04d}'.format( 72 | idx, tot_frm, h, w, 2 * tot_frm - i - 2) 73 | else: 74 | key = '{}_{}x{}x{}_{:04d}'.format( 75 | idx, tot_frm, h, w, i) 76 | 77 | frm = self.read_lmdb_frame(self.env, key, size=(h, w, c)) 78 | frm = frm.transpose(2, 0, 1) # chw|rgb|uint8 79 | frms.append(frm) 80 | 81 | frms = np.stack(frms) # tchw|rgb|uint8 82 | 83 | # crop randomly 84 | pats = self.crop_sequence(frms) 85 | 86 | # augment patches 87 | pats = self.augment_sequence(pats) 88 | 89 | # convert to tensor and normalize to range [0, 1] 90 | tsr = torch.FloatTensor(np.ascontiguousarray(pats)) / 255.0 91 | 92 | # tchw|rgb|float32 93 | return {'gt': tsr} 94 | 95 | def crop_sequence(self, frms): 96 | csz = self.crop_size 97 | 98 | h, w = frms.shape[-2:] 99 | assert (csz <= h) and (csz <= w), \ 100 | f'The crop size is larger than the image size ({csz} vs. 
h{h}/w{w})' 101 | 102 | # crop 103 | top = random.randint(0, h - csz) 104 | left = random.randint(0, w - csz) 105 | pats = frms[..., top: top + csz, left: left + csz] 106 | 107 | return pats 108 | 109 | @staticmethod 110 | def augment_sequence(pats): 111 | # flip spatially 112 | axis = random.randint(1, 3) 113 | if axis > 1: 114 | pats = np.flip(pats, axis) 115 | 116 | # flip temporally 117 | axis = random.randint(0, 1) 118 | if axis < 1: 119 | pats = np.flip(pats, axis) 120 | 121 | # rotate 122 | k = random.randint(0, 3) 123 | pats = np.rot90(pats, k, (2, 3)) 124 | 125 | return pats 126 | -------------------------------------------------------------------------------- /codes/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import math 4 | import time 5 | 6 | import torch 7 | 8 | from data import create_dataloader 9 | from models import define_model 10 | from metrics import create_metric_calculator 11 | from utils import dist_utils, base_utils, data_utils 12 | 13 | 14 | def train(opt): 15 | # print configurations 16 | base_utils.log_info(f'{20*"-"} Configurations {20*"-"}') 17 | base_utils.print_options(opt) 18 | 19 | # create data loader 20 | train_loader = create_dataloader(opt, phase='train', idx='train') 21 | 22 | # build model 23 | model = define_model(opt) 24 | 25 | # set training params 26 | total_sample, iter_per_epoch = len(train_loader.dataset), len(train_loader) 27 | total_iter = opt['train']['total_iter'] 28 | total_epoch = int(math.ceil(total_iter / iter_per_epoch)) 29 | start_iter, iter = opt['train']['start_iter'], 0 30 | test_freq = opt['test']['test_freq'] 31 | log_freq = opt['logger']['log_freq'] 32 | ckpt_freq = opt['logger']['ckpt_freq'] 33 | 34 | base_utils.log_info(f'Number of the training samples: {total_sample}') 35 | base_utils.log_info(f'{total_epoch} epochs needed for {total_iter} iterations') 36 | 37 | # train 38 | for epoch in range(total_epoch): 39 | if opt['dist']: 40 | train_loader.sampler.set_epoch(epoch) 41 | 42 | for data in train_loader: 43 | # update iter 44 | iter += 1 45 | curr_iter = start_iter + iter 46 | if iter > total_iter: break 47 | 48 | # prepare data 49 | model.prepare_training_data(data) 50 | 51 | # train a mini-batch 52 | model.train() 53 | 54 | # update running log 55 | model.update_running_log() 56 | 57 | # update learning rate 58 | model.update_learning_rate() 59 | 60 | # print messages 61 | if log_freq > 0 and curr_iter % log_freq == 0: 62 | msg = model.get_format_msg(epoch, curr_iter) 63 | base_utils.log_info(msg) 64 | 65 | # save model 66 | if ckpt_freq > 0 and curr_iter % ckpt_freq == 0: 67 | model.save(curr_iter) 68 | 69 | # evaluate model 70 | if test_freq > 0 and curr_iter % test_freq == 0: 71 | # set model index 72 | model_idx = f'G_iter{curr_iter}' 73 | 74 | # for each testset 75 | for dataset_idx in sorted(opt['dataset'].keys()): 76 | # select test dataset 77 | if 'test' not in dataset_idx: continue 78 | 79 | ds_name = opt['dataset'][dataset_idx]['name'] 80 | base_utils.log_info(f'Testing on {ds_name} dataset') 81 | 82 | # create data loader 83 | test_loader = create_dataloader( 84 | opt, phase='test', idx=dataset_idx) 85 | test_dataset = test_loader.dataset 86 | num_seq = len(test_dataset) 87 | 88 | # create metric calculator 89 | metric_calculator = create_metric_calculator(opt) 90 | 91 | # infer a sequence 92 | rank, world_size = dist_utils.get_dist_info() 93 | for idx in range(rank, num_seq, world_size): 94 | # fetch data 95 | data = 
test_dataset[idx] 96 | 97 | # prepare data 98 | model.prepare_inference_data(data) 99 | 100 | # infer 101 | hr_seq = model.infer() 102 | 103 | # save hr results 104 | if opt['test']['save_res']: 105 | res_dir = osp.join( 106 | opt['test']['res_dir'], ds_name, model_idx) 107 | res_seq_dir = osp.join(res_dir, data['seq_idx']) 108 | data_utils.save_sequence( 109 | res_seq_dir, hr_seq, data['frm_idx'], to_bgr=True) 110 | 111 | # compute metrics for the current sequence 112 | if metric_calculator is not None: 113 | gt_seq = data['gt'].numpy() 114 | metric_calculator.compute_sequence_metrics( 115 | data['seq_idx'], gt_seq, hr_seq) 116 | 117 | # save/print results 118 | if metric_calculator is not None: 119 | seq_idx_lst = [data['seq_idx'] for data in test_dataset] 120 | metric_calculator.gather(seq_idx_lst) 121 | 122 | if opt['test'].get('save_json'): 123 | # write results to a json file 124 | json_path = osp.join( 125 | opt['test']['json_dir'], f'{ds_name}_avg.json') 126 | metric_calculator.save(model_idx, json_path, override=True) 127 | else: 128 | # print directly 129 | metric_calculator.display() 130 | 131 | 132 | def test(opt): 133 | # logging 134 | base_utils.print_options(opt) 135 | 136 | # infer and evaluate performance for each model 137 | for load_path in opt['model']['generator']['load_path_lst']: 138 | # set model index 139 | model_idx = osp.splitext(osp.split(load_path)[-1])[0] 140 | 141 | # log 142 | base_utils.log_info(f'{"=" * 40}') 143 | base_utils.log_info(f'Testing model: {model_idx}') 144 | base_utils.log_info(f'{"=" * 40}') 145 | 146 | # create model 147 | opt['model']['generator']['load_path'] = load_path 148 | model = define_model(opt) 149 | 150 | # for each test dataset 151 | for dataset_idx in sorted(opt['dataset'].keys()): 152 | # select testing dataset 153 | if 'test' not in dataset_idx: 154 | continue 155 | 156 | ds_name = opt['dataset'][dataset_idx]['name'] 157 | base_utils.log_info(f'Testing on {ds_name} dataset') 158 | 159 | # create data loader 160 | test_loader = create_dataloader(opt, phase='test', idx=dataset_idx) 161 | test_dataset = test_loader.dataset 162 | num_seq = len(test_dataset) 163 | 164 | # create metric calculator 165 | metric_calculator = create_metric_calculator(opt) 166 | 167 | # infer a sequence 168 | rank, world_size = dist_utils.get_dist_info() 169 | for idx in range(rank, num_seq, world_size): 170 | # fetch data 171 | data = test_dataset[idx] 172 | 173 | # prepare data 174 | model.prepare_inference_data(data) 175 | 176 | # infer 177 | hr_seq = model.infer() 178 | 179 | # save hr results 180 | if opt['test']['save_res']: 181 | res_dir = osp.join( 182 | opt['test']['res_dir'], ds_name, model_idx) 183 | res_seq_dir = osp.join(res_dir, data['seq_idx']) 184 | data_utils.save_sequence( 185 | res_seq_dir, hr_seq, data['frm_idx'], to_bgr=True) 186 | 187 | # compute metrics for the current sequence 188 | if metric_calculator is not None: 189 | gt_seq = data['gt'].numpy() 190 | metric_calculator.compute_sequence_metrics( 191 | data['seq_idx'], gt_seq, hr_seq) 192 | 193 | # save/print results 194 | if metric_calculator is not None: 195 | seq_idx_lst = [data['seq_idx'] for data in test_dataset] 196 | metric_calculator.gather(seq_idx_lst) 197 | 198 | if opt['test'].get('save_json'): 199 | # write results to a json file 200 | json_path = osp.join( 201 | opt['test']['json_dir'], f'{ds_name}_avg.json') 202 | metric_calculator.save(model_idx, json_path, override=True) 203 | else: 204 | # print directly 205 | metric_calculator.display() 206 | 207 | 
base_utils.log_info('-' * 40) 208 | 209 | 210 | def profile(opt, lr_size, test_speed=False): 211 | # basic configs 212 | scale = opt['scale'] 213 | device = torch.device(opt['device']) 214 | msg = '\n' 215 | 216 | torch.backends.cudnn.benchmark = True 217 | # torch.backends.cudnn.deterministic = False 218 | 219 | # logging 220 | base_utils.print_options(opt['model']['generator']) 221 | 222 | lr_size_lst = tuple(map(int, lr_size.split('x'))) 223 | hr_size = f'{lr_size_lst[0]}x{lr_size_lst[1]*scale}x{lr_size_lst[2]*scale}' 224 | msg += f'{"*"*40}\nResolution: {lr_size} -> {hr_size} ({scale}x SR)' 225 | 226 | # create model 227 | from models.networks import define_generator 228 | net_G = define_generator(opt).to(device) 229 | # base_utils.log_info(f'\n{net_G.__str__()}') 230 | 231 | # profile 232 | lr_size = tuple(map(int, lr_size.split('x'))) 233 | gflops_dict, params_dict = net_G.profile(lr_size, device) 234 | 235 | gflops_all, params_all = 0, 0 236 | for module_name in gflops_dict.keys(): 237 | gflops, params = gflops_dict[module_name], params_dict[module_name] 238 | msg += f'\n{"-"*40}\nModule: [{module_name}]' 239 | msg += f'\n FLOPs (10^9): {gflops:.3f}' 240 | msg += f'\n Parameters (10^6): {params/1e6:.3f}' 241 | gflops_all += gflops 242 | params_all += params 243 | msg += f'\n{"-"*40}\nOverall' 244 | msg += f'\n FLOPs (10^9): {gflops_all:.3f}' 245 | msg += f'\n Parameters (10^6): {params_all/1e6:.3f}\n{"*"*40}' 246 | 247 | # test running speed 248 | if test_speed: 249 | n_test, tot_time = 30, 0 250 | for i in range(n_test): 251 | dummy_input_list = net_G.generate_dummy_data(lr_size, device) 252 | 253 | start_time = time.time() 254 | # --- 255 | net_G.eval() 256 | with torch.no_grad(): 257 | _ = net_G.step(*dummy_input_list) 258 | torch.cuda.synchronize() 259 | # --- 260 | end_time = time.time() 261 | tot_time += end_time - start_time 262 | msg += f'\nSpeed: {n_test/tot_time:.3f} FPS (averaged over {n_test} runs)\n{"*"*40}' 263 | 264 | base_utils.log_info(msg) 265 | 266 | 267 | if __name__ == '__main__': 268 | # === parse arguments === # 269 | args = base_utils.parse_agrs() 270 | 271 | # === generic settings === # 272 | # parse configs, set device, set ramdom seed 273 | opt = base_utils.parse_configs(args) 274 | # set logger 275 | base_utils.setup_logger('base') 276 | # set paths 277 | base_utils.setup_paths(opt, args.mode) 278 | 279 | # === train === # 280 | if args.mode == 'train': 281 | train(opt) 282 | 283 | # === test === # 284 | elif args.mode == 'test': 285 | test(opt) 286 | 287 | # === profile === # 288 | elif args.mode == 'profile': 289 | profile(opt, args.lr_size, args.test_speed) 290 | 291 | else: 292 | raise ValueError(f'Unrecognized mode: {args.mode} (train|test|profile)') 293 | -------------------------------------------------------------------------------- /codes/metrics/LPIPS/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 
9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | -------------------------------------------------------------------------------- /codes/metrics/LPIPS/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/metrics/LPIPS/__init__.py -------------------------------------------------------------------------------- /codes/metrics/LPIPS/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.measure import compare_ssim 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from metrics.LPIPS.models import dist_model 12 | 13 | 14 | class PerceptualLoss(torch.nn.Module): 15 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) 16 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 17 | super(PerceptualLoss, self).__init__() 18 | print('Setting up Perceptual loss...') 19 | self.use_gpu = use_gpu 20 | self.spatial = spatial 21 | self.gpu_ids = gpu_ids 22 | self.model = dist_model.DistModel() 23 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids, version=version) 24 | print('...[%s] initialized'%self.model.name()) 25 | print('...Done') 26 | 27 | def forward(self, pred, target, normalize=False): 28 | """ 29 | Pred and target are Variables. 
30 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 31 | If normalize is False, assumes the images are already between [-1,+1] 32 | 33 | Inputs pred and target are Nx3xHxW 34 | Output pytorch Variable N long 35 | """ 36 | 37 | if normalize: 38 | target = 2 * target - 1 39 | pred = 2 * pred - 1 40 | 41 | return self.model.forward(target, pred) 42 | 43 | def normalize_tensor(in_feat,eps=1e-10): 44 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 45 | return in_feat/(norm_factor+eps) 46 | 47 | def l2(p0, p1, range=255.): 48 | return .5*np.mean((p0 / range - p1 / range)**2) 49 | 50 | def psnr(p0, p1, peak=255.): 51 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 52 | 53 | def dssim(p0, p1, range=255.): 54 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 55 | 56 | def rgb2lab(in_img,mean_cent=False): 57 | from skimage import color 58 | img_lab = color.rgb2lab(in_img) 59 | if(mean_cent): 60 | img_lab[:,:,0] = img_lab[:,:,0]-50 61 | return img_lab 62 | 63 | def tensor2np(tensor_obj): 64 | # change dimension of a tensor object into a numpy array 65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 66 | 67 | def np2tensor(np_obj): 68 | # change dimenion of np array into tensor array 69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 70 | 71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 72 | # image tensor to lab tensor 73 | from skimage import color 74 | 75 | img = tensor2im(image_tensor) 76 | img_lab = color.rgb2lab(img) 77 | if(mc_only): 78 | img_lab[:,:,0] = img_lab[:,:,0]-50 79 | if(to_norm and not mc_only): 80 | img_lab[:,:,0] = img_lab[:,:,0]-50 81 | img_lab = img_lab/100. 82 | 83 | return np2tensor(img_lab) 84 | 85 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 86 | from skimage import color 87 | import warnings 88 | warnings.filterwarnings("ignore") 89 | 90 | lab = tensor2np(lab_tensor)*100. 91 | lab[:,:,0] = lab[:,:,0]+50 92 | 93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 94 | if(return_inbnd): 95 | # convert back to lab, see if we match 96 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 97 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 99 | return (im2tensor(rgb_back),mask) 100 | else: 101 | return im2tensor(rgb_back) 102 | 103 | def rgb2lab(input): 104 | from skimage import color 105 | return color.rgb2lab(input / 255.) 106 | 107 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 108 | image_numpy = image_tensor[0].cpu().float().numpy() 109 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 110 | return image_numpy.astype(imtype) 111 | 112 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 113 | return torch.Tensor((image / factor - cent) 114 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 115 | 116 | def tensor2vec(vector_tensor): 117 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 118 | 119 | def voc_ap(rec, prec, use_07_metric=False): 120 | """ ap = voc_ap(rec, prec, [use_07_metric]) 121 | Compute VOC AP given precision and recall. 122 | If use_07_metric is true, uses the 123 | VOC 07 11 point method (default:False). 124 | """ 125 | if use_07_metric: 126 | # 11 point metric 127 | ap = 0. 128 | for t in np.arange(0., 1.1, 0.1): 129 | if np.sum(rec >= t) == 0: 130 | p = 0 131 | else: 132 | p = np.max(prec[rec >= t]) 133 | ap = ap + p / 11. 
134 | else: 135 | # correct AP calculation 136 | # first append sentinel values at the end 137 | mrec = np.concatenate(([0.], rec, [1.])) 138 | mpre = np.concatenate(([0.], prec, [0.])) 139 | 140 | # compute the precision envelope 141 | for i in range(mpre.size - 1, 0, -1): 142 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 143 | 144 | # to calculate area under PR curve, look for points 145 | # where X axis (recall) changes value 146 | i = np.where(mrec[1:] != mrec[:-1])[0] 147 | 148 | # and sum (\Delta recall) * prec 149 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 150 | return ap 151 | 152 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 153 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 154 | image_numpy = image_tensor[0].cpu().float().numpy() 155 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 156 | return image_numpy.astype(imtype) 157 | 158 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 159 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 160 | return torch.Tensor((image / factor - cent) 161 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 162 | -------------------------------------------------------------------------------- /codes/metrics/LPIPS/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /codes/metrics/LPIPS/models/networks_basic.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | from torch.autograd import Variable 9 | import numpy as np 10 | from pdb import set_trace as st 11 | from skimage import color 12 | from IPython import embed 13 | from . 
import pretrained_networks as pn 14 | 15 | import metrics.LPIPS.models as util 16 | 17 | def spatial_average(in_tens, keepdim=True): 18 | return in_tens.mean([2,3],keepdim=keepdim) 19 | 20 | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W 21 | in_H, in_W = in_tens.shape[2], in_tens.shape[3] 22 | scale_factor_H, scale_factor_W = 1.*out_HW[0]/in_H, 1.*out_HW[1]/in_W 23 | 24 | return nn.Upsample(scale_factor=(scale_factor_H, scale_factor_W), mode='bilinear', align_corners=False)(in_tens) 25 | 26 | # Learned perceptual metric 27 | class PNetLin(nn.Module): 28 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): 29 | super(PNetLin, self).__init__() 30 | 31 | self.pnet_type = pnet_type 32 | self.pnet_tune = pnet_tune 33 | self.pnet_rand = pnet_rand 34 | self.spatial = spatial 35 | self.lpips = lpips 36 | self.version = version 37 | self.scaling_layer = ScalingLayer() 38 | 39 | if(self.pnet_type in ['vgg','vgg16']): 40 | net_type = pn.vgg16 41 | self.chns = [64,128,256,512,512] 42 | elif(self.pnet_type=='alex'): 43 | net_type = pn.alexnet 44 | self.chns = [64,192,384,256,256] 45 | elif(self.pnet_type=='squeeze'): 46 | net_type = pn.squeezenet 47 | self.chns = [64,128,256,384,384,512,512] 48 | self.L = len(self.chns) 49 | 50 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 51 | 52 | if(lpips): 53 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 54 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 55 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 56 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 57 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 58 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 59 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 60 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 61 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 62 | self.lins+=[self.lin5,self.lin6] 63 | 64 | def forward(self, in0, in1, retPerLayer=False): 65 | # v0.0 - original release had a bug, where input was not scaled 66 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 67 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 68 | feats0, feats1, diffs = {}, {}, {} 69 | 70 | for kk in range(self.L): 71 | feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) 72 | diffs[kk] = (feats0[kk]-feats1[kk])**2 73 | 74 | if(self.lpips): 75 | if(self.spatial): 76 | res = [upsample(self.lins[kk].model(diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] 77 | else: 78 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] 79 | else: 80 | if(self.spatial): 81 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] 82 | else: 83 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 84 | 85 | val = res[0] 86 | for l in range(1,self.L): 87 | val += res[l] 88 | 89 | if(retPerLayer): 90 | return (val, res) 91 | else: 92 | return val 93 | 94 | class ScalingLayer(nn.Module): 95 | def __init__(self): 96 | super(ScalingLayer, self).__init__() 97 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 98 | self.register_buffer('scale', 
torch.Tensor([.458,.448,.450])[None,:,None,None]) 99 | 100 | def forward(self, inp): 101 | return (inp - self.shift) / self.scale 102 | 103 | 104 | class NetLinLayer(nn.Module): 105 | ''' A single linear layer which does a 1x1 conv ''' 106 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 107 | super(NetLinLayer, self).__init__() 108 | 109 | layers = [nn.Dropout(),] if(use_dropout) else [] 110 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 111 | self.model = nn.Sequential(*layers) 112 | 113 | 114 | class Dist2LogitLayer(nn.Module): 115 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 116 | def __init__(self, chn_mid=32, use_sigmoid=True): 117 | super(Dist2LogitLayer, self).__init__() 118 | 119 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 120 | layers += [nn.LeakyReLU(0.2,True),] 121 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 122 | layers += [nn.LeakyReLU(0.2,True),] 123 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 124 | if(use_sigmoid): 125 | layers += [nn.Sigmoid(),] 126 | self.model = nn.Sequential(*layers) 127 | 128 | def forward(self,d0,d1,eps=0.1): 129 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 130 | 131 | class BCERankingLoss(nn.Module): 132 | def __init__(self, chn_mid=32): 133 | super(BCERankingLoss, self).__init__() 134 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 135 | # self.parameters = list(self.net.parameters()) 136 | self.loss = torch.nn.BCELoss() 137 | 138 | def forward(self, d0, d1, judge): 139 | per = (judge+1.)/2. 140 | self.logit = self.net.forward(d0,d1) 141 | return self.loss(self.logit, per) 142 | 143 | # L2, DSSIM evaluation 144 | class FakeNet(nn.Module): 145 | def __init__(self, use_gpu=True, colorspace='Lab'): 146 | super(FakeNet, self).__init__() 147 | self.use_gpu = use_gpu 148 | self.colorspace=colorspace 149 | 150 | class L2(FakeNet): 151 | 152 | def forward(self, in0, in1, retPerLayer=None): 153 | assert(in0.size()[0]==1) # currently only supports batchSize 1 154 | 155 | if(self.colorspace=='RGB'): 156 | (N,C,X,Y) = in0.size() 157 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 158 | return value 159 | elif(self.colorspace=='Lab'): 160 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 161 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 162 | ret_var = Variable( torch.Tensor((value,) ) ) 163 | if(self.use_gpu): 164 | ret_var = ret_var.cuda() 165 | return ret_var 166 | 167 | class DSSIM(FakeNet): 168 | 169 | def forward(self, in0, in1, retPerLayer=None): 170 | assert(in0.size()[0]==1) # currently only supports batchSize 1 171 | 172 | if(self.colorspace=='RGB'): 173 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') 174 | elif(self.colorspace=='Lab'): 175 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 176 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 177 | ret_var = Variable( torch.Tensor((value,) ) ) 178 | if(self.use_gpu): 179 | ret_var = ret_var.cuda() 180 | return ret_var 181 | 182 | def print_network(net): 183 | num_params = 0 184 | for param in net.parameters(): 185 | num_params += param.numel() 186 | print('Network',net) 187 | print('Total number of 
parameters: %d' % num_params) 188 | -------------------------------------------------------------------------------- /codes/metrics/LPIPS/models/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2,5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 
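# the namedtuple built above groups the ReLU activation taken after each of AlexNet's
# five conv stages, one entry per slice assembled in __init__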
95 | return out 96 | 97 | class vgg16(torch.nn.Module): 98 | def __init__(self, requires_grad=False, pretrained=True): 99 | super(vgg16, self).__init__() 100 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 101 | self.slice1 = torch.nn.Sequential() 102 | self.slice2 = torch.nn.Sequential() 103 | self.slice3 = torch.nn.Sequential() 104 | self.slice4 = torch.nn.Sequential() 105 | self.slice5 = torch.nn.Sequential() 106 | self.N_slices = 5 107 | for x in range(4): 108 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(4, 9): 110 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(9, 16): 112 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(16, 23): 114 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(23, 30): 116 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 117 | if not requires_grad: 118 | for param in self.parameters(): 119 | param.requires_grad = False 120 | 121 | def forward(self, X): 122 | h = self.slice1(X) 123 | h_relu1_2 = h 124 | h = self.slice2(h) 125 | h_relu2_2 = h 126 | h = self.slice3(h) 127 | h_relu3_3 = h 128 | h = self.slice4(h) 129 | h_relu4_3 = h 130 | h = self.slice5(h) 131 | h_relu5_3 = h 132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 134 | 135 | return out 136 | 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if(num==18): 143 | self.net = tv.resnet18(pretrained=pretrained) 144 | elif(num==34): 145 | self.net = tv.resnet34(pretrained=pretrained) 146 | elif(num==50): 147 | self.net = tv.resnet50(pretrained=pretrained) 148 | elif(num==101): 149 | self.net = tv.resnet101(pretrained=pretrained) 150 | elif(num==152): 151 | self.net = tv.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /codes/metrics/LPIPS/models/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/metrics/LPIPS/models/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /codes/metrics/LPIPS/models/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/metrics/LPIPS/models/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /codes/metrics/LPIPS/models/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/metrics/LPIPS/models/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /codes/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric_calculator import MetricCalculator 2 | 3 | 4 | def create_metric_calculator(opt): 5 | if 'metric' in opt and opt['metric']: 6 | return MetricCalculator(opt) 7 | else: 8 | return None 9 | -------------------------------------------------------------------------------- /codes/metrics/metric_calculator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | import cv2 8 | import torch 9 | import torch.distributed as dist 10 | 11 | from utils import base_utils, data_utils, net_utils 12 | from utils.dist_utils import master_only 13 | from .LPIPS.models.dist_model import DistModel 14 | 15 | 16 | class MetricCalculator(): 17 | """ Metric calculator for model evaluation 18 | 19 | Currently supported metrics: 20 | * PSNR (RGB and Y) 21 | * LPIPS 22 | * tOF as described in TecoGAN paper 23 | 24 | TODO: 25 | * save/print metrics in a fixed order 26 | """ 27 | 28 | def __init__(self, opt): 29 | # initialize 30 | self.metric_opt = opt['metric'] 31 | self.device = torch.device(opt['device']) 32 | self.dist = opt['dist'] 33 | self.rank = opt['rank'] 34 | 35 | self.psnr_colorspace = '' 36 | self.dm = None 37 | 38 | self.reset() 39 | 40 | # update configs for each metric 41 | for metric_type, cfg in self.metric_opt.items(): 42 | if metric_type.lower() == 'psnr': 43 | self.psnr_colorspace = cfg['colorspace'] 44 | 45 | if metric_type.lower() == 'lpips': 46 | self.dm = DistModel() 47 | self.dm.initialize( 48 | model=cfg['model'], 49 | net=cfg['net'], 50 | colorspace=cfg['colorspace'], 51 | spatial=cfg['spatial'], 52 | use_gpu=(opt['device'] == 'cuda'), 53 | gpu_ids=[0 if not self.dist else opt['local_rank']], 54 | version=cfg['version']) 55 | 56 | def reset(self): 57 | self.reset_per_sequence() 58 | self.metric_dict = OrderedDict() 59 | self.avg_metric_dict = OrderedDict() 60 | 61 | def reset_per_sequence(self): 62 | self.seq_idx_curr = '' 63 | self.true_img_cur = None 64 | self.pred_img_cur = None 65 | self.true_img_pre = None 66 | self.pred_img_pre = None 67 | 68 | def gather(self, seq_idx_lst): 69 | """ Gather results from all devices. 70 | Results will be updated into self.metric_dict on device 0 71 | """ 72 | 73 | # mdict 74 | # { 75 | # 'seq_idx': { 76 | # 'metric1': [frm1, frm2, ...], 77 | # 'metric2': [frm1, frm2, ...] 
78 | # } 79 | # } 80 | mdict = self.metric_dict 81 | 82 | mtype_lst = self.metric_opt.keys() 83 | n_metric = len(mtype_lst) 84 | 85 | # avg_mdict 86 | # { 87 | # 'seq_idx': torch.tensor([metric1_avg, metric2_avg, ...]) 88 | # } 89 | avg_mdict = { 90 | seq_idx: torch.zeros(n_metric, dtype=torch.float32, device=self.device) 91 | for seq_idx in seq_idx_lst 92 | } 93 | 94 | # average metric results for each sequence 95 | for i, mtype in enumerate(mtype_lst): 96 | for seq_idx, mdict_per_seq in mdict.items(): # ordered 97 | avg_mdict[seq_idx][i] += np.mean(mdict_per_seq[mtype]) 98 | 99 | if self.dist: 100 | for seq_idx, tensor in avg_mdict.items(): 101 | dist.reduce(tensor, dst=0) 102 | dist.barrier() 103 | 104 | if self.rank == 0: 105 | # avg_metric_dict 106 | # { 107 | # 'seq_idx': { 108 | # 'metric1': avg, 109 | # 'metric2': avg 110 | # } 111 | # } 112 | 113 | for seq_idx in seq_idx_lst: 114 | self.avg_metric_dict[seq_idx] = OrderedDict([ 115 | (mtype, avg_mdict[seq_idx][i].item()) 116 | for i, mtype in enumerate(mtype_lst) 117 | ]) 118 | 119 | def average(self): 120 | """ Return a dict including metric results averaged over all sequence 121 | """ 122 | 123 | metric_avg_dict = OrderedDict() 124 | for mtype in self.metric_opt.keys(): 125 | metric_all_seq = [] 126 | for sqe_idx, mdict_per_seq in self.avg_metric_dict.items(): 127 | metric_all_seq.append(mdict_per_seq[mtype]) 128 | 129 | metric_avg_dict[mtype] = np.mean(metric_all_seq) 130 | 131 | return metric_avg_dict 132 | 133 | @master_only 134 | def display(self): 135 | # per sequence results 136 | for seq_idx, mdict_per_seq in self.avg_metric_dict.items(): 137 | base_utils.log_info(f'Sequence: {seq_idx}') 138 | for mtype in self.metric_opt.keys(): 139 | base_utils.log_info(f'\t{mtype}: {mdict_per_seq[mtype]:.6f}') 140 | 141 | # average results 142 | base_utils.log_info('Average') 143 | metric_avg_dict = self.average() 144 | for mtype, value in metric_avg_dict.items(): 145 | base_utils.log_info(f'\t{mtype}: {value:.6f}') 146 | 147 | @master_only 148 | def save(self, model_idx, save_path, average=True, override=False): 149 | # load previous results if existed 150 | if osp.exists(save_path): 151 | with open(save_path, 'r') as f: 152 | json_dict = json.load(f) 153 | else: 154 | json_dict = dict() 155 | 156 | # update 157 | if model_idx not in json_dict: 158 | json_dict[model_idx] = OrderedDict() 159 | 160 | if average: 161 | metric_avg_dict = self.average() 162 | for mtype, value in metric_avg_dict.items(): 163 | # override or skip 164 | if mtype in json_dict[model_idx] and not override: 165 | continue 166 | json_dict[model_idx][mtype] = f'{value:.6f}' 167 | else: 168 | # TODO: save results of each sequence 169 | raise NotImplementedError() 170 | 171 | # sort 172 | json_dict = OrderedDict(sorted( 173 | json_dict.items(), key=lambda x: int(x[0].replace('G_iter', '')))) 174 | 175 | # save results 176 | with open(save_path, 'w') as f: 177 | json.dump(json_dict, f, sort_keys=False, indent=4) 178 | 179 | def compute_sequence_metrics(self, seq_idx, true_seq, pred_seq): 180 | # clear 181 | self.reset_per_sequence() 182 | 183 | # initialize metric_dict for the current sequence 184 | self.seq_idx_curr = seq_idx 185 | self.metric_dict[self.seq_idx_curr] = OrderedDict({ 186 | metric: [] for metric in self.metric_opt.keys()}) 187 | 188 | # compute metrics for each frame 189 | tot_frm = true_seq.shape[0] 190 | for i in range(tot_frm): 191 | self.true_img_cur = true_seq[i] # hwc|rgb/y|uint8 192 | self.pred_img_cur = pred_seq[i] 193 | 194 | # pred_img and 
true_img may have different sizes 195 | # crop the larger one to match the smaller one 196 | true_h, true_w = self.true_img_cur.shape[:-1] 197 | pred_h, pred_w = self.pred_img_cur.shape[:-1] 198 | min_h, min_w = min(true_h, pred_h), min(true_w, pred_w) 199 | self.true_img_cur = self.true_img_cur[:min_h, :min_w, :] 200 | self.pred_img_cur = self.pred_img_cur[:min_h, :min_w, :] 201 | 202 | # compute metrics for the current frame 203 | self.compute_frame_metrics() 204 | 205 | # update 206 | self.true_img_pre = self.true_img_cur 207 | self.pred_img_pre = self.pred_img_cur 208 | 209 | def compute_frame_metrics(self): 210 | metric_dict = self.metric_dict[self.seq_idx_curr] 211 | 212 | # compute evaluation 213 | for metric_type, opt in self.metric_opt.items(): 214 | if metric_type == 'PSNR': 215 | PSNR = self.compute_PSNR() 216 | metric_dict['PSNR'].append(PSNR) 217 | 218 | elif metric_type == 'LPIPS': 219 | LPIPS = self.compute_LPIPS()[0, 0, 0, 0].item() 220 | metric_dict['LPIPS'].append(LPIPS) 221 | 222 | elif metric_type == 'tOF': 223 | # skip the first frame 224 | if self.pred_img_pre is not None: 225 | tOF = self.compute_tOF() 226 | metric_dict['tOF'].append(tOF) 227 | 228 | def compute_PSNR(self): 229 | if self.psnr_colorspace == 'rgb': 230 | true_img = self.true_img_cur 231 | pred_img = self.pred_img_cur 232 | else: 233 | # convert to ycbcr, and keep the y channel 234 | true_img = data_utils.rgb_to_ycbcr(self.true_img_cur)[..., 0] 235 | pred_img = data_utils.rgb_to_ycbcr(self.pred_img_cur)[..., 0] 236 | 237 | diff = true_img.astype(np.float64) - pred_img.astype(np.float64) 238 | RMSE = np.sqrt(np.mean(np.power(diff, 2))) 239 | 240 | if RMSE == 0: 241 | return np.inf 242 | 243 | PSNR = 20 * np.log10(255.0 / RMSE) 244 | return PSNR 245 | 246 | def compute_LPIPS(self): 247 | true_img = np.ascontiguousarray(self.true_img_cur) 248 | pred_img = np.ascontiguousarray(self.pred_img_cur) 249 | 250 | # to tensor 251 | true_img = torch.FloatTensor(true_img).unsqueeze(0).permute(0, 3, 1, 2) 252 | pred_img = torch.FloatTensor(pred_img).unsqueeze(0).permute(0, 3, 1, 2) 253 | 254 | # normalize to [-1, 1] 255 | true_img = true_img.to(self.device) * 2.0 / 255.0 - 1.0 256 | pred_img = pred_img.to(self.device) * 2.0 / 255.0 - 1.0 257 | 258 | with torch.no_grad(): 259 | LPIPS = self.dm.forward(true_img, pred_img) 260 | 261 | return LPIPS 262 | 263 | def compute_tOF(self): 264 | true_img_cur = cv2.cvtColor(self.true_img_cur, cv2.COLOR_RGB2GRAY) 265 | pred_img_cur = cv2.cvtColor(self.pred_img_cur, cv2.COLOR_RGB2GRAY) 266 | true_img_pre = cv2.cvtColor(self.true_img_pre, cv2.COLOR_RGB2GRAY) 267 | pred_img_pre = cv2.cvtColor(self.pred_img_pre, cv2.COLOR_RGB2GRAY) 268 | 269 | # forward flow 270 | true_OF = cv2.calcOpticalFlowFarneback( 271 | true_img_pre, true_img_cur, None, 0.5, 3, 15, 3, 5, 1.2, 0) 272 | pred_OF = cv2.calcOpticalFlowFarneback( 273 | pred_img_pre, pred_img_cur, None, 0.5, 3, 15, 3, 5, 1.2, 0) 274 | 275 | # EPE 276 | diff_OF = true_OF - pred_OF 277 | tOF = np.mean(np.sqrt(np.sum(diff_OF**2, axis=-1))) 278 | 279 | return tOF 280 | -------------------------------------------------------------------------------- /codes/metrics/model_summary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # define which modules to be incorporated 6 | registered_module = [ 7 | nn.Conv2d, 8 | nn.ConvTranspose2d, 9 | nn.Conv3d 10 | ] 11 | 12 | # initialize 13 | registered_hooks, model_info_lst = [], [] 14 | 15 | 16 | def 
calc_2d_gflops_per_batch(module, out_h, out_w): 17 | """ Calculate flops of conv weights (support groups_conv & dilated_conv) 18 | """ 19 | gflops = 0 20 | if hasattr(module, 'weight'): 21 | # Note: in_c is already divided by groups while out_c is not 22 | bias = 0 if hasattr(module, 'bias') else -1 23 | out_c, in_c, k_h, k_w = module.weight.shape 24 | 25 | gflops += (2*in_c*k_h*k_w + bias)*out_c*out_h*out_w/1e9 26 | return gflops 27 | 28 | 29 | def calc_3d_gflops_per_batch(module, out_d, out_h, out_w): 30 | """ Calculate flops of conv weights (support groups_conv & dilated_conv) 31 | """ 32 | gflops = 0 33 | if hasattr(module, 'weight'): 34 | # Note: in_c is already divided by groups while out_c is not 35 | bias = 0 if hasattr(module, 'bias') else -1 36 | out_c, in_c, k_d, k_h, k_w = module.weight.shape 37 | 38 | gflops += (2*in_c*k_d*k_h*k_w + bias)*out_c*out_d*out_h*out_w/1e9 39 | return gflops 40 | 41 | 42 | def hook_fn_forward(module, input, output): 43 | if isinstance(module, nn.Conv3d): 44 | batch_size, _, out_d, out_h, out_w = output.size() 45 | gflops = batch_size*calc_3d_gflops_per_batch(module, out_d, out_h, out_w) 46 | else: 47 | if isinstance(module, nn.ConvTranspose2d): 48 | batch_size, _, out_h, out_w = input[0].size() 49 | else: 50 | batch_size, _, out_h, out_w = output.size() 51 | gflops = batch_size*calc_2d_gflops_per_batch(module, out_h, out_w) 52 | 53 | model_info_lst.append({'gflops': gflops}) 54 | 55 | 56 | def register_hook(module): 57 | if isinstance(module, tuple(registered_module)): 58 | registered_hooks.append(module.register_forward_hook(hook_fn_forward)) 59 | 60 | 61 | def register(model, dummy_input_list): 62 | # reset params 63 | global registered_hooks, model_info_lst 64 | registered_hooks, model_info_lst = [], [] 65 | 66 | # register hook 67 | model.apply(register_hook) 68 | 69 | # forward 70 | with torch.no_grad(): 71 | model.eval() 72 | out = model(*dummy_input_list) 73 | 74 | # remove hooks 75 | for hook in registered_hooks: 76 | hook.remove() 77 | 78 | return out 79 | 80 | 81 | def parse_model_info(model): 82 | tot_gflops = 0 83 | for module_info in model_info_lst: 84 | if module_info['gflops']: 85 | tot_gflops += module_info['gflops'] 86 | 87 | tot_params = 0 88 | for param in model.parameters(): 89 | tot_params += torch.prod(torch.tensor(param.size())).item() 90 | 91 | return tot_gflops, tot_params 92 | -------------------------------------------------------------------------------- /codes/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vsr_model import VSRModel 2 | from .vsrgan_model import VSRGANModel 3 | 4 | 5 | # register vsr model 6 | vsr_model_lst = [ 7 | 'frvsr', 8 | ] 9 | 10 | # register vsrgan model 11 | vsrgan_model_lst = [ 12 | 'tecogan', 13 | ] 14 | 15 | 16 | def define_model(opt): 17 | if opt['model']['name'].lower() in vsr_model_lst: 18 | model = VSRModel(opt) 19 | 20 | elif opt['model']['name'].lower() in vsrgan_model_lst: 21 | model = VSRGANModel(opt) 22 | 23 | else: 24 | raise ValueError(f'Unrecognized model: {opt["model"]["name"]}') 25 | 26 | return model 27 | -------------------------------------------------------------------------------- /codes/models/base_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os.path as osp 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.distributed as dist 8 | from torch.nn.parallel import 
DistributedDataParallel 9 | 10 | from utils.data_utils import create_kernel, downsample_bd 11 | from utils.dist_utils import master_only 12 | 13 | 14 | class BaseModel(): 15 | def __init__(self, opt): 16 | self.opt = opt 17 | self.scale = opt['scale'] 18 | self.device = torch.device(opt['device']) 19 | self.blur_kernel = None 20 | self.dist = opt['dist'] 21 | self.is_train = opt['is_train'] 22 | 23 | if self.is_train: 24 | self.lr_data, self.gt_data = None, None 25 | self.ckpt_dir = opt['train']['ckpt_dir'] 26 | self.log_decay = opt['logger'].get('decay', 0.99) 27 | self.log_dict = OrderedDict() 28 | self.running_log_dict = OrderedDict() 29 | 30 | def set_networks(self): 31 | pass 32 | 33 | def set_criterions(self): 34 | pass 35 | 36 | def set_optimizers(self): 37 | pass 38 | 39 | def set_lr_schedules(self): 40 | pass 41 | 42 | def prepare_training_data(self, data): 43 | """ prepare gt, lr data for training 44 | 45 | for BD degradation, generate lr data and remove the border of gt data 46 | for BI degradation, use input data directly 47 | """ 48 | 49 | degradation_type = self.opt['dataset']['degradation']['type'] 50 | 51 | if degradation_type == 'BI': 52 | self.gt_data = data['gt'].to(self.device) 53 | self.lr_data = data['lr'].to(self.device) 54 | 55 | elif degradation_type == 'BD': 56 | # generate lr data on the fly (on gpu) 57 | 58 | # set params 59 | scale = self.opt['scale'] 60 | sigma = self.opt['dataset']['degradation'].get('sigma', 1.5) 61 | border_size = int(sigma * 3.0) 62 | 63 | gt_data = data['gt'].to(self.device) # with border 64 | n, t, c, gt_h, gt_w = gt_data.size() 65 | lr_h = (gt_h - 2*border_size)//scale 66 | lr_w = (gt_w - 2*border_size)//scale 67 | 68 | # create blurring kernel 69 | if self.blur_kernel is None: 70 | self.blur_kernel = create_kernel(sigma).to(self.device) 71 | blur_kernel = self.blur_kernel 72 | 73 | # generate lr data 74 | gt_data = gt_data.view(n*t, c, gt_h, gt_w) 75 | lr_data = downsample_bd(gt_data, blur_kernel, scale, pad_data=False) 76 | lr_data = lr_data.view(n, t, c, lr_h, lr_w) 77 | 78 | # remove gt border 79 | gt_data = gt_data[ 80 | ..., 81 | border_size: border_size + scale*lr_h, 82 | border_size: border_size + scale*lr_w] 83 | gt_data = gt_data.view(n, t, c, scale*lr_h, scale*lr_w) 84 | 85 | self.gt_data, self.lr_data = gt_data, lr_data # tchw|float32 86 | 87 | def prepare_inference_data(self, data): 88 | """ Prepare lr data for training (w/o loading on device) 89 | """ 90 | 91 | degradation_type = self.opt['dataset']['degradation']['type'] 92 | 93 | if degradation_type == 'BI': 94 | self.lr_data = data['lr'] 95 | 96 | elif degradation_type == 'BD': 97 | if 'lr' in data: 98 | self.lr_data = data['lr'] 99 | else: 100 | # generate lr data on the fly (on cpu) 101 | # TODO: do frame-wise downsampling on gpu for acceleration? 
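# the block below reproduces the BD degradation on the fly for inference: each GT frame
# is blurred with a Gaussian kernel of width `sigma` and subsampled by `scale`
# (done on CPU here, with padding enabled, unlike the GPU path used during training)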
102 | gt_data = data['gt'] # thwc|uint8 103 | 104 | # set params 105 | scale = self.opt['scale'] 106 | sigma = self.opt['dataset']['degradation'].get('sigma', 1.5) 107 | 108 | # create blurring kernel 109 | if self.blur_kernel is None: 110 | self.blur_kernel = create_kernel(sigma) 111 | blur_kernel = self.blur_kernel.cpu() 112 | 113 | # generate lr data 114 | gt_data = gt_data.permute(0, 3, 1, 2).float() / 255.0 # tchw|float32 115 | lr_data = downsample_bd( 116 | gt_data, blur_kernel, scale, pad_data=True) 117 | lr_data = lr_data.permute(0, 2, 3, 1) # thwc|float32 118 | 119 | self.lr_data = lr_data 120 | 121 | # thwc to tchw 122 | self.lr_data = self.lr_data.permute(0, 3, 1, 2) # tchw|float32 123 | 124 | def train(self): 125 | pass 126 | 127 | def infer(self): 128 | pass 129 | 130 | def model_to_device(self, net): 131 | net = net.to(self.device) 132 | if self.dist: 133 | net = nn.SyncBatchNorm.convert_sync_batchnorm(net) 134 | net = DistributedDataParallel( 135 | net, device_ids=[torch.cuda.current_device()]) 136 | return net 137 | 138 | def update_learning_rate(self): 139 | if hasattr(self, 'sched_G') and self.sched_G is not None: 140 | self.sched_G.step() 141 | 142 | if hasattr(self, 'sched_D') and self.sched_D is not None: 143 | self.sched_D.step() 144 | 145 | def get_learning_rate(self): 146 | lr_dict = OrderedDict() 147 | 148 | if hasattr(self, 'optim_G'): 149 | lr_dict['lr_G'] = self.optim_G.param_groups[0]['lr'] 150 | 151 | if hasattr(self, 'optim_D'): 152 | lr_dict['lr_D'] = self.optim_D.param_groups[0]['lr'] 153 | 154 | return lr_dict 155 | 156 | def reduce_log(self): 157 | if self.dist: 158 | rank, world_size = self.opt['rank'], self.opt['world_size'] 159 | with torch.no_grad(): 160 | keys, vals = [], [] 161 | for key, val in self.log_dict.items(): 162 | keys.append(key) 163 | vals.append(val) 164 | vals = torch.FloatTensor(vals).to(self.device) 165 | dist.reduce(vals, dst=0) 166 | if rank == 0: # average 167 | vals /= world_size 168 | self.log_dict = {key: val.item() for key, val in zip(keys, vals)} 169 | 170 | def update_running_log(self): 171 | self.reduce_log() # for distributed training 172 | 173 | d = self.log_decay 174 | for k in self.log_dict.keys(): 175 | current_val = self.log_dict[k] 176 | running_val = self.running_log_dict.get(k) 177 | 178 | if running_val is None: 179 | running_val = current_val 180 | else: 181 | running_val = d * running_val + (1.0 - d) * current_val 182 | 183 | self.running_log_dict[k] = running_val 184 | 185 | def get_current_log(self): 186 | return self.log_dict 187 | 188 | def get_running_log(self): 189 | return self.running_log_dict 190 | 191 | def get_format_msg(self, epoch, iter): 192 | # generic info 193 | msg = f'[epoch: {epoch} | iter: {iter}' 194 | for lr_type, lr in self.get_learning_rate().items(): 195 | msg += f' | {lr_type}: {lr:.2e}' 196 | msg += '] ' 197 | 198 | # loss info 199 | log_dict = self.get_running_log() 200 | msg += ', '.join([f'{k}: {v:.3e}' for k, v in log_dict.items()]) 201 | 202 | return msg 203 | 204 | def save(self, current_iter): 205 | pass 206 | 207 | @staticmethod 208 | def get_bare_model(net): 209 | if isinstance(net, DistributedDataParallel): 210 | net = net.module 211 | return net 212 | 213 | @master_only 214 | def save_network(self, net, net_label, current_iter): 215 | filename = f'{net_label}_iter{current_iter}.pth' 216 | save_path = osp.join(self.ckpt_dir, filename) 217 | net = self.get_bare_model(net) 218 | torch.save(net.state_dict(), save_path) 219 | 220 | def save_training_state(self, current_epoch, 
current_iter): 221 | # TODO 222 | pass 223 | 224 | def load_network(self, net, load_path): 225 | state_dict = torch.load( 226 | load_path, map_location=lambda storage, loc: storage) 227 | net = self.get_bare_model(net) 228 | net.load_state_dict(state_dict) 229 | 230 | def pad_sequence(self, lr_data): 231 | """ 232 | Parameters: 233 | :param lr_data: tensor in shape tchw 234 | """ 235 | padding_mode = self.opt['test'].get('padding_mode', 'reflect') 236 | n_pad_front = self.opt['test'].get('num_pad_front', 0) 237 | assert n_pad_front < lr_data.size(0) 238 | 239 | # pad 240 | if padding_mode == 'reflect': 241 | lr_data = torch.cat( 242 | [lr_data[1: 1 + n_pad_front, ...].flip(0), lr_data], dim=0) 243 | 244 | elif padding_mode == 'replicate': 245 | lr_data = torch.cat( 246 | [lr_data[:1, ...].expand(n_pad_front, -1, -1, -1), lr_data], dim=0) 247 | 248 | else: 249 | raise ValueError(f'Unrecognized padding mode: {padding_mode}') 250 | 251 | return lr_data, n_pad_front 252 | -------------------------------------------------------------------------------- /codes/models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .tecogan_nets import FRNet, SpatioTemporalDiscriminator, SpatialDiscriminator 2 | 3 | 4 | def define_generator(opt): 5 | net_G_opt = opt['model']['generator'] 6 | 7 | if net_G_opt['name'].lower() == 'frnet': # frame-recurrent generator 8 | net_G = FRNet( 9 | in_nc=net_G_opt['in_nc'], 10 | out_nc=net_G_opt['out_nc'], 11 | nf=net_G_opt['nf'], 12 | nb=net_G_opt['nb'], 13 | degradation=opt['dataset']['degradation']['type'], 14 | scale=opt['scale']) 15 | 16 | else: 17 | raise ValueError(f'Unrecognized generator: {net_G_opt["name"]}') 18 | 19 | return net_G 20 | 21 | 22 | def define_discriminator(opt): 23 | net_D_opt = opt['model']['discriminator'] 24 | 25 | if opt['dataset']['degradation']['type'] == 'BD': 26 | spatial_size = opt['dataset']['train']['crop_size'] 27 | else: # BI 28 | spatial_size = opt['dataset']['train']['gt_crop_size'] 29 | 30 | if net_D_opt['name'].lower() == 'stnet': # spatio-temporal discriminator 31 | net_D = SpatioTemporalDiscriminator( 32 | in_nc=net_D_opt['in_nc'], 33 | spatial_size=spatial_size, 34 | tempo_range=net_D_opt['tempo_range'], 35 | degradation=opt['dataset']['degradation']['type'], 36 | scale=opt['scale']) 37 | 38 | elif net_D_opt['name'].lower() == 'snet': # spatial discriminator 39 | net_D = SpatialDiscriminator( 40 | in_nc=net_D_opt['in_nc'], 41 | spatial_size=spatial_size, 42 | use_cond=net_D_opt['use_cond']) 43 | 44 | else: 45 | raise ValueError(f'Unrecognized discriminator: {net_D_opt["name"]}') 46 | 47 | return net_D 48 | -------------------------------------------------------------------------------- /codes/models/networks/base_nets.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BaseSequenceGenerator(nn.Module): 5 | def __init__(self): 6 | super(BaseSequenceGenerator, self).__init__() 7 | 8 | def generate_dummy_data(self, lr_size): 9 | """ Generate random input tensors for function `step` 10 | """ 11 | return None 12 | 13 | def profile(self, *args, **kwargs): 14 | pass 15 | 16 | def forward(self, *args, **kwargs): 17 | """ Interface (support DDP) 18 | """ 19 | pass 20 | 21 | def forward_sequence(self, lr_data): 22 | """ Forward a whole sequence (for training) 23 | """ 24 | pass 25 | 26 | def step(self, *args, **kwargs): 27 | """ Forward a single frame 28 | """ 29 | pass 30 | 31 | def 
infer_sequence(self, lr_data, device): 32 | """ Infer a whole sequence (for inference) 33 | """ 34 | pass 35 | 36 | 37 | class BaseSequenceDiscriminator(nn.Module): 38 | def __init__(self): 39 | super(BaseSequenceDiscriminator, self).__init__() 40 | 41 | def forward(self, *args, **kwargs): 42 | """ Interface (support DDP) 43 | """ 44 | pass 45 | 46 | def step(self, *args, **kwargs): 47 | """ Forward a singe frame 48 | """ 49 | pass 50 | 51 | def forward_sequence(self, data, args_dict): 52 | """ Forward a whole sequence (for training) 53 | """ 54 | pass 55 | -------------------------------------------------------------------------------- /codes/models/networks/vgg_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | class VGGFeatureExtractor(nn.Module): 7 | def __init__(self, feature_indexs=(8, 17, 26, 35)): 8 | super(VGGFeatureExtractor, self).__init__() 9 | 10 | # init feature layers 11 | self.features = torchvision.models.vgg19(pretrained=True).features 12 | for param in self.features.parameters(): 13 | param.requires_grad = False 14 | 15 | # Notes: 16 | # 1. default feature layers are 8(conv2_2), 17(conv3_4), 26(conv4_4), 17 | # 35(conv5_4) 18 | # 2. features are extracted after ReLU activation 19 | self.feature_indexs = sorted(feature_indexs) 20 | 21 | # register normalization params 22 | mean = torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) # RGB 23 | std = torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 24 | self.register_buffer('mean', mean) 25 | self.register_buffer('std', std) 26 | 27 | def forward(self, x): 28 | # assume input ranges in [0, 1] 29 | out = (x - self.mean) / self.std 30 | 31 | feature_list = [] 32 | for i in range(len(self.features)): 33 | out = self.features[i](out) 34 | if i in self.feature_indexs: 35 | # clone to prevent overlapping by inplaced ReLU 36 | feature_list.append(out.clone()) 37 | 38 | return feature_list 39 | -------------------------------------------------------------------------------- /codes/models/optim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.optim as optim 3 | 4 | 5 | def define_criterion(criterion_opt): 6 | if criterion_opt is None: 7 | return None 8 | 9 | # parse 10 | if criterion_opt['type'] == 'MSE': 11 | criterion = nn.MSELoss(reduction=criterion_opt['reduction']) 12 | 13 | elif criterion_opt['type'] == 'L1': 14 | criterion = nn.L1Loss(reduction=criterion_opt['reduction']) 15 | 16 | elif criterion_opt['type'] == 'CB': 17 | from .losses import CharbonnierLoss 18 | criterion = CharbonnierLoss(reduction=criterion_opt['reduction']) 19 | 20 | elif criterion_opt['type'] == 'CosineSimilarity': 21 | from .losses import CosineSimilarityLoss 22 | criterion = CosineSimilarityLoss() 23 | 24 | elif criterion_opt['type'] == 'GAN': 25 | from .losses import VanillaGANLoss 26 | criterion = VanillaGANLoss(reduction=criterion_opt['reduction']) 27 | 28 | elif criterion_opt['type'] == 'LSGAN': 29 | from .losses import LSGANLoss 30 | criterion = LSGANLoss(reduction=criterion_opt['reduction']) 31 | 32 | else: 33 | raise ValueError(f'Unrecognized criterion: {criterion_opt["type"]}') 34 | 35 | return criterion 36 | 37 | 38 | def define_lr_schedule(schedule_opt, optimizer): 39 | if schedule_opt is None: 40 | return None 41 | 42 | # parse 43 | if schedule_opt['type'] == 'FixedLR': 44 | schedule = None 45 | 46 | elif schedule_opt['type'] 
== 'MultiStepLR': 47 | schedule = optim.lr_scheduler.MultiStepLR( 48 | optimizer, 49 | milestones=schedule_opt['milestones'], 50 | gamma=schedule_opt['gamma']) 51 | 52 | elif schedule_opt['type'] == 'CosineAnnealingRestartLR': 53 | from .lr_schedules import CosineAnnealingRestartLR 54 | schedule = CosineAnnealingRestartLR( 55 | optimizer, 56 | periods=schedule_opt['periods'], 57 | restart_weights=schedule_opt['restart_weights'], 58 | eta_min=schedule_opt['eta_min']) 59 | 60 | else: 61 | raise ValueError(f'Unrecognized lr schedule: {schedule_opt["type"]}') 62 | 63 | return schedule -------------------------------------------------------------------------------- /codes/models/optim/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class VanillaGANLoss(nn.Module): 7 | def __init__(self, reduction='mean'): 8 | super(VanillaGANLoss, self).__init__() 9 | self.crit = nn.BCEWithLogitsLoss(reduction=reduction) 10 | 11 | def forward(self, input, status): 12 | target = torch.empty_like(input).fill_(int(status)) 13 | loss = self.crit(input, target) 14 | return loss 15 | 16 | 17 | class LSGANLoss(nn.Module): 18 | def __init__(self, reduction='mean'): 19 | super(LSGANLoss, self).__init__() 20 | self.crit = nn.MSELoss(reduction=reduction) 21 | 22 | def forward(self, input, status): 23 | """ 24 | :param status: boolean, True/False 25 | """ 26 | target = torch.empty_like(input).fill_(int(status)) 27 | loss = self.crit(input, target) 28 | return loss 29 | 30 | 31 | class CharbonnierLoss(nn.Module): 32 | """ Charbonnier Loss (robust L1) 33 | """ 34 | 35 | def __init__(self, eps=1e-6, reduction='sum'): 36 | super(CharbonnierLoss, self).__init__() 37 | self.eps = eps 38 | self.reduction = reduction 39 | 40 | def forward(self, x, y): 41 | diff = x - y 42 | loss = torch.sqrt(diff * diff + self.eps) 43 | 44 | if self.reduction == 'sum': 45 | loss = torch.sum(loss) 46 | elif self.reduction == 'mean': 47 | loss = torch.mean(loss) 48 | else: 49 | raise NotImplementedError 50 | return loss 51 | 52 | 53 | class CosineSimilarityLoss(nn.Module): 54 | def __init__(self, eps=1e-8): 55 | super(CosineSimilarityLoss, self).__init__() 56 | self.eps = eps 57 | 58 | def forward(self, input, target): 59 | diff = F.cosine_similarity(input, target, dim=1, eps=self.eps) 60 | loss = 1.0 - diff.mean() 61 | 62 | return loss 63 | -------------------------------------------------------------------------------- /codes/models/optim/lr_schedules.py: -------------------------------------------------------------------------------- 1 | """ Code adopted from BasicSR: https://github.com/xinntao/BasicSR/blob/master/basicsr/models/lr_scheduler.py 2 | """ 3 | 4 | import math 5 | 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | 8 | 9 | def get_position_from_periods(iteration, cumulative_period): 10 | """Get the position from a period list. 11 | 12 | It will return the index of the right-closest number in the period list. 13 | For example, the cumulative_period = [100, 200, 300, 400], 14 | if iteration == 50, return 0; 15 | if iteration == 210, return 2; 16 | if iteration == 300, return 2. 17 | 18 | Args: 19 | iteration (int): Current iteration. 20 | cumulative_period (list[int]): Cumulative period list. 21 | 22 | Returns: 23 | int: The position of the right-closest number in the period list. 
24 | """ 25 | for i, period in enumerate(cumulative_period): 26 | if iteration <= period: 27 | return i 28 | 29 | 30 | class CosineAnnealingRestartLR(_LRScheduler): 31 | """ Cosine annealing with restarts learning rate scheme. 32 | 33 | An example of config: 34 | periods = [10, 10, 10, 10] 35 | restart_weights = [1, 0.5, 0.5, 0.5] 36 | eta_min=1e-7 37 | 38 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 39 | scheduler will restart with the weights in restart_weights. 40 | 41 | Args: 42 | optimizer (torch.nn.optimizer): Torch optimizer. 43 | periods (list): Period for each cosine anneling cycle. 44 | restart_weights (list): Restart weights at each restart iteration. 45 | Default: [1]. 46 | eta_min (float): The mimimum lr. Default: 0. 47 | last_epoch (int): Used in _LRScheduler. Default: -1. 48 | """ 49 | 50 | def __init__(self, 51 | optimizer, 52 | periods, 53 | restart_weights=(1, ), 54 | eta_min=0, 55 | last_epoch=-1): 56 | self.periods = periods 57 | self.restart_weights = restart_weights 58 | self.eta_min = eta_min 59 | assert (len(self.periods) == len(self.restart_weights) 60 | ), 'periods and restart_weights should have the same length.' 61 | self.cumulative_period = [ 62 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) 63 | ] 64 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 65 | 66 | def get_lr(self): 67 | idx = get_position_from_periods(self.last_epoch, 68 | self.cumulative_period) 69 | current_weight = self.restart_weights[idx] 70 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 71 | current_period = self.periods[idx] 72 | 73 | return [ 74 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 75 | (1 + math.cos(math.pi * ( 76 | (self.last_epoch - nearest_restart) / current_period))) 77 | for base_lr in self.base_lrs 78 | ] 79 | -------------------------------------------------------------------------------- /codes/models/vsr_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.optim as optim 5 | 6 | from .base_model import BaseModel 7 | from .networks import define_generator 8 | from .optim import define_criterion, define_lr_schedule 9 | from utils import base_utils, net_utils, data_utils 10 | 11 | 12 | class VSRModel(BaseModel): 13 | """ A model wrapper for objective video super-resolution 14 | """ 15 | 16 | def __init__(self, opt): 17 | super(VSRModel, self).__init__(opt) 18 | 19 | # define network 20 | self.set_networks() 21 | 22 | # config training 23 | if self.is_train: 24 | self.set_criterions() 25 | self.set_optimizers() 26 | self.set_lr_schedules() 27 | 28 | def set_networks(self): 29 | # define generator 30 | self.net_G = define_generator(self.opt) 31 | self.net_G = self.model_to_device(self.net_G) 32 | base_utils.log_info('Generator: {}\n{}'.format( 33 | self.opt['model']['generator']['name'], self.net_G.__str__())) 34 | 35 | # load generator 36 | load_path_G = self.opt['model']['generator'].get('load_path') 37 | if load_path_G is not None: 38 | self.load_network(self.net_G, load_path_G) 39 | base_utils.log_info(f'Load generator from: {load_path_G}') 40 | 41 | def set_criterions(self): 42 | # pixel criterion 43 | self.pix_crit = define_criterion( 44 | self.opt['train'].get('pixel_crit')) 45 | 46 | # warping criterion 47 | self.warp_crit = define_criterion( 48 | self.opt['train'].get('warping_crit')) 49 | 50 | def set_optimizers(self): 51 | self.optim_G = 
optim.Adam( 52 | self.net_G.parameters(), 53 | lr=self.opt['train']['generator']['lr'], 54 | weight_decay=self.opt['train']['generator'].get('weight_decay', 0), 55 | betas=self.opt['train']['generator'].get('betas', (0.9, 0.999))) 56 | 57 | def set_lr_schedules(self): 58 | self.sched_G = define_lr_schedule( 59 | self.opt['train']['generator'].get('lr_schedule'), self.optim_G) 60 | 61 | def train(self): 62 | # === initialize === # 63 | self.net_G.train() 64 | self.optim_G.zero_grad() 65 | 66 | # === forward net_G === # 67 | net_G_output_dict = self.net_G(self.lr_data) 68 | self.hr_data = net_G_output_dict['hr_data'] 69 | 70 | # === optimize net_G === # 71 | loss_G = 0 72 | self.log_dict = OrderedDict() 73 | 74 | # pixel loss 75 | pix_w = self.opt['train']['pixel_crit'].get('weight', 1.0) 76 | loss_pix_G = pix_w * self.pix_crit(self.hr_data, self.gt_data) 77 | loss_G += loss_pix_G 78 | self.log_dict['l_pix_G'] = loss_pix_G.item() 79 | 80 | # warping loss 81 | if self.warp_crit is not None: 82 | # warp lr_prev according to lr_flow 83 | lr_curr = net_G_output_dict['lr_curr'] 84 | lr_prev = net_G_output_dict['lr_prev'] 85 | lr_flow = net_G_output_dict['lr_flow'] 86 | lr_warp = net_utils.backward_warp(lr_prev, lr_flow) 87 | 88 | warp_w = self.opt['train']['warping_crit'].get('weight', 1.0) 89 | loss_warp_G = warp_w * self.warp_crit(lr_warp, lr_curr) 90 | loss_G += loss_warp_G 91 | self.log_dict['l_warp_G'] = loss_warp_G.item() 92 | 93 | # optimize 94 | loss_G.backward() 95 | self.optim_G.step() 96 | 97 | def infer(self): 98 | """ Infer the `lr_data` sequence 99 | 100 | :return: np.ndarray sequence in type [uint8] and shape [thwc] 101 | """ 102 | 103 | lr_data = self.lr_data 104 | 105 | # temporal padding 106 | lr_data, n_pad_front = self.pad_sequence(lr_data) 107 | 108 | # infer 109 | self.net_G.eval() 110 | hr_seq = self.net_G(lr_data, self.device) 111 | hr_seq = hr_seq[n_pad_front:] 112 | 113 | return hr_seq 114 | 115 | def save(self, current_iter): 116 | self.save_network(self.net_G, 'G', current_iter) 117 | -------------------------------------------------------------------------------- /codes/official_metrics/LPIPSmodels/LPIPSsource.txt: -------------------------------------------------------------------------------- 1 | originally from https://github.com/richzhang/PerceptualSimilarity, 2 | @inproceedings{zhang2018perceptual, 3 | title={The Unreasonable Effectiveness of Deep Features as a Perceptual Metric}, 4 | author={Zhang, Richard and Isola, Phillip and Efros, Alexei A and Shechtman, Eli and Wang, Oliver}, 5 | booktitle={CVPR}, 6 | year={2018} 7 | } -------------------------------------------------------------------------------- /codes/official_metrics/LPIPSmodels/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import LPIPSmodels.util as util 5 | from torch.autograd import Variable 6 | from pdb import set_trace as st 7 | from IPython import embed 8 | 9 | class BaseModel(): 10 | def __init__(self): 11 | pass; 12 | 13 | def name(self): 14 | return 'BaseModel' 15 | 16 | def initialize(self, use_gpu=True): 17 | self.use_gpu = use_gpu 18 | self.Tensor = torch.cuda.FloatTensor if self.use_gpu else torch.Tensor 19 | # self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 20 | 21 | def forward(self): 22 | pass 23 | 24 | def get_image_paths(self): 25 | pass 26 | 27 | def optimize_parameters(self): 28 | pass 29 | 30 | def get_current_visuals(self): 31 | return self.input 32 | 33 | def 
get_current_errors(self): 34 | return {} 35 | 36 | def save(self, label): 37 | pass 38 | 39 | # helper saving function that can be used by subclasses 40 | def save_network(self, network, path, network_label, epoch_label): 41 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 42 | save_path = os.path.join(path, save_filename) 43 | torch.save(network.state_dict(), save_path) 44 | 45 | # helper loading function that can be used by subclasses 46 | def load_network(self, network, network_label, epoch_label): 47 | # embed() 48 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 49 | save_path = os.path.join(self.save_dir, save_filename) 50 | print('Loading network from %s'%save_path) 51 | network.load_state_dict(torch.load(save_path)) 52 | 53 | def update_learning_rate(): 54 | pass 55 | 56 | def get_image_paths(self): 57 | return self.image_paths 58 | 59 | def save_done(self, flag=False): 60 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 61 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 62 | 63 | -------------------------------------------------------------------------------- /codes/official_metrics/LPIPSmodels/networks_basic.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | sys.path.append('..') 6 | sys.path.append('.') 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.init as init 10 | from torch.autograd import Variable 11 | import numpy as np 12 | from pdb import set_trace as st 13 | from skimage import color 14 | from IPython import embed 15 | from . import pretrained_networks as pn 16 | 17 | from . import util 18 | 19 | # Off-the-shelf deep network 20 | class PNet(nn.Module): 21 | '''Pre-trained network with all channels equally weighted by default''' 22 | def __init__(self, pnet_type='vgg', pnet_rand=False, use_gpu=True): 23 | super(PNet, self).__init__() 24 | 25 | self.use_gpu = use_gpu 26 | 27 | self.pnet_type = pnet_type 28 | self.pnet_rand = pnet_rand 29 | 30 | self.shift = torch.autograd.Variable(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1)) 31 | self.scale = torch.autograd.Variable(torch.Tensor([.458, .448, .450]).view(1,3,1,1)) 32 | 33 | if(self.pnet_type in ['vgg','vgg16']): 34 | self.net = pn.vgg16(pretrained=not self.pnet_rand,requires_grad=False) 35 | elif(self.pnet_type=='alex'): 36 | self.net = pn.alexnet(pretrained=not self.pnet_rand,requires_grad=False) 37 | elif(self.pnet_type[:-2]=='resnet'): 38 | self.net = pn.resnet(pretrained=not self.pnet_rand,requires_grad=False, num=int(self.pnet_type[-2:])) 39 | elif(self.pnet_type=='squeeze'): 40 | self.net = pn.squeezenet(pretrained=not self.pnet_rand,requires_grad=False) 41 | 42 | self.L = self.net.N_slices 43 | 44 | if(use_gpu): 45 | self.net.cuda() 46 | self.shift = self.shift.cuda() 47 | self.scale = self.scale.cuda() 48 | 49 | def forward(self, in0, in1, retPerLayer=False): 50 | in0_sc = (in0 - self.shift.expand_as(in0))/self.scale.expand_as(in0) 51 | in1_sc = (in1 - self.shift.expand_as(in0))/self.scale.expand_as(in0) 52 | 53 | outs0 = self.net.forward(in0_sc) 54 | outs1 = self.net.forward(in1_sc) 55 | 56 | if(retPerLayer): 57 | all_scores = [] 58 | for (kk,out0) in enumerate(outs0): 59 | cur_score = (1.-util.cos_sim(outs0[kk],outs1[kk])) 60 | if(kk==0): 61 | val = 1.*cur_score 62 | else: 63 | # val = val + self.lambda_feat_layers[kk]*cur_score 64 | val = val + cur_score 65 | if(retPerLayer): 66 | all_scores+=[cur_score] 67 | 68 | 
if(retPerLayer): 69 | return (val, all_scores) 70 | else: 71 | return val 72 | 73 | # Learned perceptual metric 74 | class PNetLin(nn.Module): 75 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, use_gpu=True, spatial=False, version='0.1'): 76 | super(PNetLin, self).__init__() 77 | 78 | self.use_gpu = use_gpu 79 | self.pnet_type = pnet_type 80 | self.pnet_tune = pnet_tune 81 | self.pnet_rand = pnet_rand 82 | self.spatial = spatial 83 | self.version = version 84 | 85 | if(self.pnet_type in ['vgg','vgg16']): 86 | net_type = pn.vgg16 87 | self.chns = [64,128,256,512,512] 88 | elif(self.pnet_type=='alex'): 89 | net_type = pn.alexnet 90 | self.chns = [64,192,384,256,256] 91 | elif(self.pnet_type=='squeeze'): 92 | net_type = pn.squeezenet 93 | self.chns = [64,128,256,384,384,512,512] 94 | 95 | if(self.pnet_tune): 96 | self.net = net_type(pretrained=not self.pnet_rand,requires_grad=True) 97 | else: 98 | self.net = [net_type(pretrained=not self.pnet_rand,requires_grad=False),] 99 | 100 | self.lin0 = NetLinLayer(self.chns[0],use_dropout=use_dropout) 101 | self.lin1 = NetLinLayer(self.chns[1],use_dropout=use_dropout) 102 | self.lin2 = NetLinLayer(self.chns[2],use_dropout=use_dropout) 103 | self.lin3 = NetLinLayer(self.chns[3],use_dropout=use_dropout) 104 | self.lin4 = NetLinLayer(self.chns[4],use_dropout=use_dropout) 105 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 106 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 107 | self.lin5 = NetLinLayer(self.chns[5],use_dropout=use_dropout) 108 | self.lin6 = NetLinLayer(self.chns[6],use_dropout=use_dropout) 109 | self.lins+=[self.lin5,self.lin6] 110 | 111 | self.shift = torch.autograd.Variable(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1)) 112 | self.scale = torch.autograd.Variable(torch.Tensor([.458, .448, .450]).view(1,3,1,1)) 113 | 114 | if(use_gpu): 115 | if(self.pnet_tune): 116 | self.net.cuda() 117 | else: 118 | self.net[0].cuda() 119 | self.shift = self.shift.cuda() 120 | self.scale = self.scale.cuda() 121 | self.lin0.cuda() 122 | self.lin1.cuda() 123 | self.lin2.cuda() 124 | self.lin3.cuda() 125 | self.lin4.cuda() 126 | if(self.pnet_type=='squeeze'): 127 | self.lin5.cuda() 128 | self.lin6.cuda() 129 | 130 | def forward(self, in0, in1): 131 | in0_sc = (in0 - self.shift.expand_as(in0))/self.scale.expand_as(in0) 132 | in1_sc = (in1 - self.shift.expand_as(in0))/self.scale.expand_as(in0) 133 | 134 | if(self.version=='0.0'): 135 | # v0.0 - original release had a bug, where input was not scaled 136 | in0_input = in0 137 | in1_input = in1 138 | else: 139 | # v0.1 140 | in0_input = in0_sc 141 | in1_input = in1_sc 142 | 143 | if(self.pnet_tune): 144 | outs0 = self.net.forward(in0_input) 145 | outs1 = self.net.forward(in1_input) 146 | else: 147 | outs0 = self.net[0].forward(in0_input) 148 | outs1 = self.net[0].forward(in1_input) 149 | 150 | feats0 = {} 151 | feats1 = {} 152 | diffs = [0]*len(outs0) 153 | 154 | for (kk,out0) in enumerate(outs0): 155 | feats0[kk] = util.normalize_tensor(outs0[kk]) 156 | feats1[kk] = util.normalize_tensor(outs1[kk]) 157 | diffs[kk] = (feats0[kk]-feats1[kk])**2 158 | 159 | if self.spatial: 160 | lin_models = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 161 | if(self.pnet_type=='squeeze'): 162 | lin_models.extend([self.lin5, self.lin6]) 163 | res = [lin_models[kk].model(diffs[kk]) for kk in range(len(diffs))] 164 | return res 165 | 166 | val = torch.mean(torch.mean(self.lin0.model(diffs[0]),dim=3),dim=2) 167 | val = val + 
torch.mean(torch.mean(self.lin1.model(diffs[1]),dim=3),dim=2) 168 | val = val + torch.mean(torch.mean(self.lin2.model(diffs[2]),dim=3),dim=2) 169 | val = val + torch.mean(torch.mean(self.lin3.model(diffs[3]),dim=3),dim=2) 170 | val = val + torch.mean(torch.mean(self.lin4.model(diffs[4]),dim=3),dim=2) 171 | if(self.pnet_type=='squeeze'): 172 | val = val + torch.mean(torch.mean(self.lin5.model(diffs[5]),dim=3),dim=2) 173 | val = val + torch.mean(torch.mean(self.lin6.model(diffs[6]),dim=3),dim=2) 174 | 175 | val = val.view(val.size()[0],val.size()[1],1,1) 176 | 177 | return val 178 | 179 | class Dist2LogitLayer(nn.Module): 180 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 181 | def __init__(self, chn_mid=32,use_sigmoid=True): 182 | super(Dist2LogitLayer, self).__init__() 183 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 184 | layers += [nn.LeakyReLU(0.2,True),] 185 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 186 | layers += [nn.LeakyReLU(0.2,True),] 187 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 188 | if(use_sigmoid): 189 | layers += [nn.Sigmoid(),] 190 | self.model = nn.Sequential(*layers) 191 | 192 | def forward(self,d0,d1,eps=0.1): 193 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 194 | 195 | class BCERankingLoss(nn.Module): 196 | def __init__(self, use_gpu=True, chn_mid=32): 197 | super(BCERankingLoss, self).__init__() 198 | self.use_gpu = use_gpu 199 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 200 | self.parameters = list(self.net.parameters()) 201 | self.loss = torch.nn.BCELoss() 202 | self.model = nn.Sequential(*[self.net]) 203 | 204 | if(self.use_gpu): 205 | self.net.cuda() 206 | 207 | def forward(self, d0, d1, judge): 208 | per = (judge+1.)/2. 
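# `per` maps the raw judgement from [-1, 1] into [0, 1] so it can serve as the BCE target below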
209 | if(self.use_gpu): 210 | per = per.cuda() 211 | self.logit = self.net.forward(d0,d1) 212 | return self.loss(self.logit, per) 213 | 214 | class NetLinLayer(nn.Module): 215 | ''' A single linear layer which does a 1x1 conv ''' 216 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 217 | super(NetLinLayer, self).__init__() 218 | 219 | layers = [nn.Dropout(),] if(use_dropout) else [] 220 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 221 | self.model = nn.Sequential(*layers) 222 | 223 | 224 | # L2, DSSIM metrics 225 | class FakeNet(nn.Module): 226 | def __init__(self, use_gpu=True, colorspace='Lab'): 227 | super(FakeNet, self).__init__() 228 | self.use_gpu = use_gpu 229 | self.colorspace=colorspace 230 | 231 | class L2(FakeNet): 232 | 233 | def forward(self, in0, in1): 234 | assert(in0.size()[0]==1) # currently only supports batchSize 1 235 | 236 | if(self.colorspace=='RGB'): 237 | (N,C,X,Y) = in0.size() 238 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 239 | return value 240 | elif(self.colorspace=='Lab'): 241 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 242 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 243 | ret_var = Variable( torch.Tensor((value,) ) ) 244 | if(self.use_gpu): 245 | ret_var = ret_var.cuda() 246 | return ret_var 247 | 248 | class DSSIM(FakeNet): 249 | 250 | def forward(self, in0, in1): 251 | assert(in0.size()[0]==1) # currently only supports batchSize 1 252 | 253 | if(self.colorspace=='RGB'): 254 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') 255 | elif(self.colorspace=='Lab'): 256 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 257 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 258 | ret_var = Variable( torch.Tensor((value,) ) ) 259 | if(self.use_gpu): 260 | ret_var = ret_var.cuda() 261 | return ret_var 262 | 263 | def print_network(net): 264 | num_params = 0 265 | for param in net.parameters(): 266 | num_params += param.numel() 267 | print('Network',net) 268 | print('Total number of parameters: %d' % num_params) 269 | -------------------------------------------------------------------------------- /codes/official_metrics/LPIPSmodels/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models 4 | from IPython import embed 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = models.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2,5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 
27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | class vgg16(torch.nn.Module): 98 | def __init__(self, requires_grad=False, pretrained=True): 99 | super(vgg16, self).__init__() 100 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 101 | self.slice1 = torch.nn.Sequential() 102 | self.slice2 = torch.nn.Sequential() 103 | self.slice3 = torch.nn.Sequential() 104 | self.slice4 = torch.nn.Sequential() 105 | self.slice5 = torch.nn.Sequential() 106 | self.N_slices = 5 107 | for x in range(4): 108 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(4, 9): 110 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(9, 16): 112 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(16, 23): 114 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(23, 30): 116 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 117 | if not requires_grad: 118 | for param in self.parameters(): 119 | param.requires_grad = False 120 | 121 | def forward(self, X): 122 | h = self.slice1(X) 123 | h_relu1_2 = h 124 | h = self.slice2(h) 125 | h_relu2_2 = h 126 | h = self.slice3(h) 127 | h_relu3_3 = h 128 | h = self.slice4(h) 129 | h_relu4_3 = h 130 
| h = self.slice5(h) 131 | h_relu5_3 = h 132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 134 | 135 | return out 136 | 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if(num==18): 143 | self.net = models.resnet18(pretrained=pretrained) 144 | elif(num==34): 145 | self.net = models.resnet34(pretrained=pretrained) 146 | elif(num==50): 147 | self.net = models.resnet50(pretrained=pretrained) 148 | elif(num==101): 149 | self.net = models.resnet101(pretrained=pretrained) 150 | elif(num==152): 151 | self.net = models.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /codes/official_metrics/LPIPSmodels/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/official_metrics/LPIPSmodels/v0.1/alex.pth -------------------------------------------------------------------------------- /codes/official_metrics/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | 5 | 6 | if __name__ == '__main__': 7 | # get agrs 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--model', '-m', type=str, required=True) 10 | args = parser.parse_args() 11 | 12 | keys = args.model.split('_') 13 | assert keys[0] in ('TecoGAN', 'FRVSR') 14 | assert keys[1] in ('BD', 'BI') 15 | 16 | # set dirs 17 | Vid4_GT_dir = 'data/Vid4/GT' 18 | Vid4_SR_dir = 'results/Vid4/{}'.format(args.model) 19 | Vid4_vids = ['calendar', 'city', 'foliage', 'walk'] 20 | 21 | ToS3_GT_dir = 'data/ToS3/GT' 22 | ToS3_SR_dir = 'results/ToS3/{}'.format(args.model) 23 | ToS3_vids = ['bridge', 'face', 'room'] 24 | 25 | # evaluate Vid4 26 | if osp.exists(Vid4_SR_dir): 27 | Vid4_GT_lst = [ 28 | osp.join(Vid4_GT_dir, vid) for vid in Vid4_vids] 29 | Vid4_SR_lst = [ 30 | osp.join(Vid4_SR_dir, vid) for vid in Vid4_vids] 31 | os.system('python codes/official_metrics/metrics.py --output {} --results {} --targets {}'.format( 32 | osp.join(Vid4_SR_dir, 'metric_log'), 33 | ','.join(Vid4_SR_lst), 34 | ','.join(Vid4_GT_lst))) 35 | 36 | # evaluate ToS3 37 | if osp.exists(ToS3_SR_dir): 38 | ToS3_GT_lst = [ 39 | osp.join(ToS3_GT_dir, vid) for vid in ToS3_vids] 40 | ToS3_SR_lst = [ 41 | osp.join(ToS3_SR_dir, vid) for vid in ToS3_vids] 42 | os.system('python codes/official_metrics/metrics.py --output {} --results {} 
--targets {}'.format( 43 | osp.join(ToS3_SR_dir, 'metric_log'), 44 | ','.join(ToS3_SR_lst), 45 | ','.join(ToS3_GT_lst))) 46 | 47 | -------------------------------------------------------------------------------- /codes/official_metrics/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os, sys 4 | import pandas as pd 5 | from LPIPSmodels import util 6 | import LPIPSmodels.dist_model as dm 7 | from skimage.measure import compare_ssim 8 | 9 | from absl import flags 10 | flags.DEFINE_string('output', None, 'the path of output directory') 11 | flags.DEFINE_string('results', None, 'the list of paths of result directory') 12 | flags.DEFINE_string('targets', None, 'the list of paths of target directory') 13 | 14 | FLAGS = flags.FLAGS 15 | FLAGS(sys.argv) 16 | 17 | if(not os.path.exists(FLAGS.output)): 18 | os.mkdir(FLAGS.output) 19 | 20 | # The operation used to print out the configuration 21 | def print_configuration_op(FLAGS): 22 | print('[Configurations]:') 23 | for name, value in FLAGS.flag_values_dict().items(): 24 | print('\t%s: %s'%(name, str(value))) 25 | print('End of configuration') 26 | # custom Logger to write Log to file 27 | 28 | def listPNGinDir(dirpath): 29 | filelist = os.listdir(dirpath) 30 | filelist = [_ for _ in filelist if _.endswith(".png")] 31 | filelist = [_ for _ in filelist if not _.startswith("IB")] 32 | filelist = sorted(filelist) 33 | filelist.sort(key=lambda f: int(''.join(list(filter(str.isdigit, f))) or -1)) 34 | result = [os.path.join(dirpath,_) for _ in filelist if _.endswith(".png")] 35 | return result 36 | 37 | def _rgb2ycbcr(img, maxVal=255): 38 | ##### color space transform, originally from https://github.com/yhjo09/VSR-DUF ##### 39 | O = np.array([[16], 40 | [128], 41 | [128]]) 42 | T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941], 43 | [-0.148223529411765, -0.290992156862745, 0.439215686274510], 44 | [0.439215686274510, -0.367788235294118, -0.071427450980392]]) 45 | 46 | if maxVal == 1: 47 | O = O / 255.0 48 | 49 | t = np.reshape(img, (img.shape[0]*img.shape[1], img.shape[2])) 50 | t = np.dot(t, np.transpose(T)) 51 | t[:, 0] += O[0] 52 | t[:, 1] += O[1] 53 | t[:, 2] += O[2] 54 | ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]]) 55 | 56 | return ycbcr 57 | 58 | def to_uint8(x, vmin, vmax): 59 | ##### color space transform, originally from https://github.com/yhjo09/VSR-DUF ##### 60 | x = x.astype('float32') 61 | x = (x-vmin)/(vmax-vmin)*255 # 0~255 62 | return np.clip(np.round(x), 0, 255) 63 | 64 | def psnr(img_true, img_pred): 65 | ##### PSNR with color space transform, originally from https://github.com/yhjo09/VSR-DUF ##### 66 | Y_true = _rgb2ycbcr(to_uint8(img_true, 0, 255), 255)[:,:,0] 67 | Y_pred = _rgb2ycbcr(to_uint8(img_pred, 0, 255), 255)[:,:,0] 68 | diff = Y_true - Y_pred 69 | rmse = np.sqrt(np.mean(np.power(diff,2))) 70 | return 20*np.log10(255./rmse) 71 | 72 | def ssim(img_true, img_pred): ##### SSIM ##### 73 | Y_true = _rgb2ycbcr(to_uint8(img_true, 0, 255), 255)[:,:,0] 74 | Y_pred = _rgb2ycbcr(to_uint8(img_pred, 0, 255), 255)[:,:,0] 75 | return compare_ssim(Y_true, Y_pred, data_range=Y_pred.max() - Y_pred.min()) 76 | 77 | def crop_8x8( img ): 78 | ori_h = img.shape[0] 79 | ori_w = img.shape[1] 80 | 81 | h = (ori_h//32) * 32 82 | w = (ori_w//32) * 32 83 | 84 | while(h > ori_h - 16): 85 | h = h - 32 86 | while(w > ori_w - 16): 87 | w = w - 32 88 | 89 | y = (ori_h - h) // 2 90 | x = (ori_w - w) // 2 91 | crop_img = 
img[y:y+h, x:x+w] 92 | return crop_img, y, x 93 | 94 | class Logger(object): 95 | def __init__(self): 96 | self.terminal = sys.stdout 97 | filename = "metricsfile.txt" 98 | self.log = open(os.path.join(FLAGS.output, filename), "a") 99 | def write(self, message): 100 | self.terminal.write(message) 101 | self.log.write(message) 102 | def flush(self): 103 | self.log.flush() 104 | 105 | sys.stdout = Logger() 106 | 107 | print_configuration_op(FLAGS) 108 | 109 | result_list = FLAGS.results.split(',') 110 | target_list = FLAGS.targets.split(',') 111 | folder_n = len(result_list) 112 | 113 | 114 | model = dm.DistModel() 115 | model.initialize(model='net-lin',net='alex',use_gpu=True) 116 | 117 | cutfr = 2 118 | # maxV = 0.4, for line 154-166 119 | 120 | keys = ["PSNR", "SSIM", "LPIPS", "tOF", "tLP100"] # keys = ["LPIPS"] 121 | sum_dict = dict.fromkeys(["FrameAvg_"+_ for _ in keys], 0) 122 | len_dict = dict.fromkeys(keys, 0) 123 | avg_dict = dict.fromkeys(["Avg_"+_ for _ in keys], 0) 124 | folder_dict = dict.fromkeys(["FolderAvg_"+_ for _ in keys], 0) 125 | 126 | for folder_i in range(folder_n): 127 | result = listPNGinDir(result_list[folder_i]) 128 | target = listPNGinDir(target_list[folder_i]) 129 | image_no = len(target) 130 | 131 | list_dict = {} 132 | for key_i in keys: 133 | list_dict[key_i] = [] 134 | 135 | for i in range(cutfr, image_no-cutfr): 136 | output_img = cv2.imread(result[i])[:,:,::-1] 137 | target_img = cv2.imread(target[i])[:,:,::-1] 138 | msg = "frame %d, tar %s, out %s, "%(i, str(target_img.shape), str(output_img.shape)) 139 | if( target_img.shape[0] < output_img.shape[0]) or ( target_img.shape[1] < output_img.shape[1]): # target is not dividable by 4 140 | output_img = output_img[:target_img.shape[0],:target_img.shape[1]] 141 | if( target_img.shape[0] > output_img.shape[0]) or ( target_img.shape[1] > output_img.shape[1]): # target is not dividable by 4 142 | target_img = target_img[:output_img.shape[0],:output_img.shape[1]] 143 | #print(result[i]) 144 | 145 | if "tOF" in keys:# tOF 146 | output_grey = cv2.cvtColor(output_img, cv2.COLOR_RGB2GRAY) 147 | target_grey = cv2.cvtColor(target_img, cv2.COLOR_RGB2GRAY) 148 | if (i > cutfr): # temporal metrics 149 | target_OF=cv2.calcOpticalFlowFarneback(pre_tar_grey, target_grey, None, 0.5, 3, 15, 3, 5, 1.2, 0) 150 | output_OF=cv2.calcOpticalFlowFarneback(pre_out_grey, output_grey, None, 0.5, 3, 15, 3, 5, 1.2, 0) 151 | target_OF, ofy, ofx = crop_8x8(target_OF) 152 | output_OF, ofy, ofx = crop_8x8(output_OF) 153 | OF_diff = np.absolute(target_OF - output_OF) 154 | if False: # for motion visualization 155 | tOFpath = os.path.join(FLAGS.output,"%03d_tOF"%folder_i) 156 | if(not os.path.exists(tOFpath)): os.mkdir(tOFpath) 157 | hsv = np.zeros_like(output_img) 158 | hsv[...,1] = 255 159 | out_path = os.path.join(tOFpath, "flow_%04d.jpg" %i) 160 | mag, ang = cv2.cartToPolar(OF_diff[...,0], OF_diff[...,1]) 161 | # print("tar max %02.6f, min %02.6f, avg %02.6f" % (mag.max(), mag.min(), mag.mean())) 162 | mag = np.clip(mag, 0.0, maxV)/maxV 163 | hsv[...,0] = ang*180/np.pi/2 164 | hsv[...,2] = mag * 255.0 # 165 | bgr = cv2.cvtColor(hsv,cv2.COLOR_HSV2BGR) 166 | cv2.imwrite(out_path, bgr) 167 | 168 | OF_diff = np.sqrt(np.sum(OF_diff * OF_diff, axis = -1)) # l1 vector norm 169 | # OF_diff, ofy, ofx = crop_8x8(OF_diff) 170 | list_dict["tOF"].append( OF_diff.mean() ) 171 | msg += "tOF %02.2f, " %(list_dict["tOF"][-1]) 172 | 173 | pre_out_grey = output_grey 174 | pre_tar_grey = target_grey 175 | 176 | target_img, ofy, ofx = crop_8x8(target_img) 177 | 
output_img, ofy, ofx = crop_8x8(output_img) 178 | 179 | if "PSNR" in keys:# psnr 180 | list_dict["PSNR"].append( psnr(target_img, output_img) ) 181 | msg +="psnr %02.2f" %(list_dict["PSNR"][-1]) 182 | 183 | if "SSIM" in keys:# ssim 184 | list_dict["SSIM"].append( ssim(target_img, output_img) ) 185 | msg +=", ssim %02.2f" %(list_dict["SSIM"][-1]) 186 | 187 | if "LPIPS" in keys or "tLP100" in keys: 188 | img0 = util.im2tensor(target_img) # RGB image from [-1,1] 189 | img1 = util.im2tensor(output_img) 190 | 191 | if "LPIPS" in keys: # LPIPS 192 | dist01 = model.forward(img0,img1) 193 | list_dict["LPIPS"].append( dist01[0] ) 194 | msg +=", lpips %02.2f" %(dist01[0]) 195 | 196 | if "tLP100" in keys and (i > cutfr):# tLP, temporal metrics 197 | dist0t = model.forward(pre_img0, img0) 198 | dist1t = model.forward(pre_img1, img1) 199 | # print ("tardis %f, outdis %f" %(dist0t, dist1t)) 200 | dist01t = np.absolute(dist0t - dist1t) * 100.0 ##########!!!!! 201 | list_dict["tLP100"].append( dist01t[0] ) 202 | msg += ", tLPx100 %02.2f" %(dist01t[0]) 203 | pre_img0 = img0 204 | pre_img1 = img1 205 | 206 | msg +=", crop (%d, %d)" %(ofy, ofx) 207 | #print(msg) 208 | mode = 'w' if folder_i==0 else 'a' 209 | 210 | pd_dict = {} 211 | for cur_num_data in keys: 212 | num_data = cur_num_data+"_%02d" % folder_i 213 | cur_list = np.float32(list_dict[cur_num_data]) 214 | pd_dict[num_data] = pd.Series(cur_list) 215 | 216 | num_data_sum = cur_list.sum() 217 | num_data_len = cur_list.shape[0] 218 | num_data_mean = num_data_sum / num_data_len 219 | #print("%s, max %02.4f, min %02.4f, avg %02.4f" % 220 | # (num_data, cur_list.max(), cur_list.min(), num_data_mean)) 221 | 222 | if folder_i == 0: 223 | avg_dict["Avg_"+cur_num_data] = [num_data_mean] 224 | else: 225 | avg_dict["Avg_"+cur_num_data] += [num_data_mean] 226 | 227 | sum_dict["FrameAvg_"+cur_num_data] += num_data_sum 228 | len_dict[cur_num_data] += num_data_len 229 | folder_dict["FolderAvg_"+cur_num_data] += num_data_mean 230 | 231 | pd.DataFrame(pd_dict).to_csv(os.path.join(FLAGS.output,"metrics.csv"), mode=mode) 232 | 233 | for num_data in keys: 234 | sum_dict["FrameAvg_"+num_data] = pd.Series([sum_dict["FrameAvg_"+num_data] / len_dict[num_data]]) 235 | folder_dict["FolderAvg_"+num_data] = pd.Series([folder_dict["FolderAvg_"+num_data] / folder_n]) 236 | avg_dict["Avg_"+num_data] = pd.Series(np.float32(avg_dict["Avg_"+num_data])) 237 | print("%s, total frame %d, total avg %02.4f, folder avg %02.4f" % 238 | (num_data, len_dict[num_data], sum_dict["FrameAvg_"+num_data][0], folder_dict["FolderAvg_"+num_data][0])) 239 | pd.DataFrame(avg_dict).to_csv(os.path.join(FLAGS.output,"metrics.csv"), mode='a') 240 | pd.DataFrame(folder_dict).to_csv(os.path.join(FLAGS.output,"metrics.csv"), mode='a') 241 | pd.DataFrame(sum_dict).to_csv(os.path.join(FLAGS.output,"metrics.csv"), mode='a') 242 | print("Finished.") 243 | -------------------------------------------------------------------------------- /codes/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/codes/utils/__init__.py -------------------------------------------------------------------------------- /codes/utils/base_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import random 4 | import logging 5 | import argparse 6 | import yaml 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 
from .dist_utils import init_dist, master_only 12 | 13 | 14 | def parse_agrs(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--exp_dir', type=str, required=True, 17 | help='directory of the current experiment') 18 | parser.add_argument('--mode', type=str, required=True, 19 | help='which mode to use (train|test|profile)') 20 | parser.add_argument('--opt', type=str, required=True, 21 | help='path to the config yaml file') 22 | parser.add_argument('--gpu_ids', type=str, default='-1', 23 | help='GPU index (set -1 to use CPU)') 24 | parser.add_argument('--lr_size', type=str, default='3x256x256', 25 | help='size of the input frame') 26 | parser.add_argument('--test_speed', action='store_true', 27 | help='whether to test speed') 28 | parser.add_argument('--local_rank', default=-1, type=int, 29 | help='local gpu index') 30 | return parser.parse_args() 31 | 32 | 33 | def parse_configs(args): 34 | # load option file 35 | with open(osp.join(args.exp_dir, args.opt), 'r') as f: 36 | opt = yaml.load(f.read(), Loader=yaml.FullLoader) 37 | 38 | opt['exp_dir'] = args.exp_dir 39 | opt['gpu_ids'] = args.gpu_ids 40 | opt['is_train'] = (args.mode == 'train') 41 | 42 | # setup device 43 | setup_device(opt, args.gpu_ids, args.local_rank) 44 | 45 | # setup random seed 46 | setup_random_seed(opt.get('manual_seed', 2021) + opt['rank']) 47 | 48 | return opt 49 | 50 | 51 | def setup_device(opt, gpu_ids, local_rank): 52 | gpu_ids = tuple(map(int, gpu_ids.split(','))) 53 | if gpu_ids[0] < 0 or not torch.cuda.is_available(): 54 | # cpu settings 55 | opt.update({ 56 | 'dist': False, 57 | 'device': 'cpu', 58 | 'rank': 0 59 | }) 60 | else: 61 | # gpu settings 62 | if len(gpu_ids) == 1: 63 | # single gpu 64 | torch.cuda.set_device(0) 65 | opt.update({ 66 | 'dist': False, 67 | 'device': 'cuda', 68 | 'rank': 0 69 | }) 70 | else: 71 | # multiple gpus 72 | init_dist(opt, local_rank) 73 | 74 | torch.backends.cudnn.benchmark = True 75 | # torch.backends.cudnn.deterministic = True 76 | 77 | 78 | def setup_random_seed(seed): 79 | random.seed(seed) 80 | np.random.seed(seed) 81 | torch.manual_seed(seed) 82 | torch.cuda.manual_seed(seed) 83 | torch.cuda.manual_seed_all(seed) 84 | 85 | 86 | def setup_logger(name): 87 | # create a logger 88 | base_logger = logging.getLogger(name=name) 89 | base_logger.setLevel(logging.INFO) 90 | # create a logging format 91 | formatter = logging.Formatter(fmt='%(asctime)s [%(levelname)s]: %(message)s') 92 | # create a stream handler & set format 93 | sh = logging.StreamHandler() 94 | sh.setFormatter(formatter) 95 | # add handlers 96 | base_logger.addHandler(sh) 97 | 98 | 99 | @master_only 100 | def log_info(msg, logger_name='base'): 101 | logger = logging.getLogger(logger_name) 102 | logger.info(msg) 103 | 104 | 105 | def print_options(opt, logger_name='base', tab=''): 106 | for key, val in opt.items(): 107 | if isinstance(val, dict): 108 | log_info('{}{}:'.format(tab, key), logger_name) 109 | print_options(val, logger_name, tab + ' ') 110 | else: 111 | log_info('{}{}: {}'.format(tab, key, val), logger_name) 112 | 113 | 114 | def retrieve_files(dir, suffix='png|jpg'): 115 | """ retrive files with specific suffix under dir and sub-dirs recursively 116 | """ 117 | 118 | def retrieve_files_recursively(dir, file_lst): 119 | for d in sorted(os.listdir(dir)): 120 | dd = osp.join(dir, d) 121 | 122 | if osp.isdir(dd): 123 | retrieve_files_recursively(dd, file_lst) 124 | else: 125 | if osp.splitext(d)[-1].lower() in ['.' 
+ s for s in suffix]: 126 | file_lst.append(dd) 127 | 128 | if not dir: 129 | return [] 130 | 131 | if isinstance(suffix, str): 132 | suffix = suffix.split('|') 133 | 134 | file_lst = [] 135 | retrieve_files_recursively(dir, file_lst) 136 | file_lst.sort() 137 | 138 | return file_lst 139 | 140 | 141 | def setup_paths(opt, mode): 142 | 143 | def setup_ckpt_dir(): 144 | ckpt_dir = opt['train'].get('ckpt_dir', '') 145 | if not ckpt_dir: # default dir 146 | ckpt_dir = osp.join(opt['exp_dir'], 'train', 'ckpt') 147 | opt['train']['ckpt_dir'] = ckpt_dir 148 | os.makedirs(ckpt_dir, exist_ok=True) 149 | 150 | def setup_res_dir(): 151 | res_dir = opt['test'].get('res_dir', '') 152 | if not res_dir: # default dir 153 | res_dir = osp.join(opt['exp_dir'], 'test', 'results') 154 | opt['test']['res_dir'] = res_dir 155 | os.makedirs(res_dir, exist_ok=True) 156 | 157 | def setup_json_path(): 158 | json_dir = opt['test'].get('json_dir', '') 159 | if not json_dir: # default dir 160 | json_dir = osp.join(opt['exp_dir'], 'test', 'metrics') 161 | opt['test']['json_dir'] = json_dir 162 | os.makedirs(json_dir, exist_ok=True) 163 | 164 | def setup_model_path(): 165 | load_path = opt['model']['generator'].get('load_path', '') 166 | if not load_path: 167 | raise ValueError( 168 | 'Pretrained generator model is needed for testing') 169 | 170 | # parse models 171 | ckpt_dir, model_idx = osp.split(load_path) 172 | model_idx = osp.splitext(model_idx)[0] 173 | if model_idx == '*': 174 | # test a serial of models TODO: check validity 175 | start_iter = opt['test']['start_iter'] 176 | end_iter = opt['test']['end_iter'] 177 | freq = opt['test']['test_freq'] 178 | opt['model']['generator']['load_path_lst'] = [ 179 | osp.join(ckpt_dir, f'G_iter{iter}.pth') 180 | for iter in range(start_iter, end_iter + 1, freq)] 181 | else: 182 | # test a single model 183 | opt['model']['generator']['load_path_lst'] = [ 184 | osp.join(ckpt_dir, f'{model_idx}.pth')] 185 | 186 | if mode == 'train': 187 | setup_ckpt_dir() 188 | 189 | # for validation purpose 190 | for dataset_idx in opt['dataset'].keys(): 191 | if 'test' not in dataset_idx: 192 | continue 193 | 194 | if opt['test'].get('save_res', False): 195 | setup_res_dir() 196 | 197 | if opt['test'].get('save_json', False): 198 | setup_json_path() 199 | 200 | elif mode == 'test': 201 | setup_model_path() 202 | 203 | for dataset_idx in opt['dataset'].keys(): 204 | if 'test' not in dataset_idx: 205 | continue 206 | 207 | if opt['test'].get('save_res', False): 208 | setup_res_dir() 209 | 210 | if opt['test'].get('save_json', False): 211 | setup_json_path() 212 | -------------------------------------------------------------------------------- /codes/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | from scipy import signal 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | 11 | def create_kernel(sigma, ksize=None): 12 | if ksize is None: 13 | ksize = 1 + 2 * int(sigma * 3.0) 14 | 15 | gkern1d = signal.gaussian(ksize, std=sigma).reshape(ksize, 1) 16 | gkern2d = np.outer(gkern1d, gkern1d) 17 | gaussian_kernel = gkern2d / gkern2d.sum() 18 | zero_kernel = np.zeros_like(gaussian_kernel) 19 | 20 | kernel = np.float32([ 21 | [gaussian_kernel, zero_kernel, zero_kernel], 22 | [zero_kernel, gaussian_kernel, zero_kernel], 23 | [zero_kernel, zero_kernel, gaussian_kernel]]) 24 | 25 | kernel = torch.from_numpy(kernel) 26 | 27 | return kernel 28 | 29 | 30 | def 
downsample_bd(data, kernel, scale, pad_data): 31 | """ 32 | Note: 33 | 1. `data` should be torch.FloatTensor (data range 0~1) in shape [nchw] 34 | 2. `pad_data` should be enabled in model testing 35 | 3. This function is device agnostic, i.e., data/kernel could be on cpu or gpu 36 | """ 37 | 38 | if pad_data: 39 | # compute padding params 40 | kernel_h, kernel_w = kernel.shape[-2:] 41 | pad_h, pad_w = kernel_h - 1, kernel_w - 1 42 | pad_t = pad_h // 2 43 | pad_b = pad_h - pad_t 44 | pad_l = pad_w // 2 45 | pad_r = pad_w - pad_l 46 | 47 | # pad data 48 | data = F.pad(data, (pad_l, pad_r, pad_t, pad_b), 'reflect') 49 | 50 | # blur + down sample 51 | data = F.conv2d(data, kernel, stride=scale, bias=None, padding=0) 52 | 53 | return data 54 | 55 | 56 | def rgb_to_ycbcr(img): 57 | """ Coefficients are taken from the official codes of DUF-VSR 58 | This conversion is also the same as that in BasicSR 59 | 60 | Parameters: 61 | :param img: rgb image in type np.uint8 62 | :return: ycbcr image in type np.uint8 63 | """ 64 | 65 | T = np.array([ 66 | [0.256788235294118, -0.148223529411765, 0.439215686274510], 67 | [0.504129411764706, -0.290992156862745, -0.367788235294118], 68 | [0.097905882352941, 0.439215686274510, -0.071427450980392], 69 | ], dtype=np.float64) 70 | 71 | O = np.array([16, 128, 128], dtype=np.float64) 72 | 73 | img = img.astype(np.float64) 74 | res = np.matmul(img, T) + O 75 | res = res.clip(0, 255).round().astype(np.uint8) 76 | 77 | return res 78 | 79 | 80 | def float32_to_uint8(inputs): 81 | """ Convert np.float32 array to np.uint8 82 | 83 | Parameters: 84 | :param input: np.float32, (NT)CHW, [0, 1] 85 | :return: np.uint8, (NT)CHW, [0, 255] 86 | """ 87 | return np.uint8(np.clip(np.round(inputs * 255), 0, 255)) 88 | 89 | 90 | def save_sequence(seq_dir, seq_data, frm_idx_lst=None, to_bgr=False): 91 | """ Save each frame of a sequence to .png image in seq_dir 92 | 93 | Parameters: 94 | :param seq_dir: dir to save results 95 | :param seq_data: sequence with shape thwc|uint8 96 | :param frm_idx_lst: specify filename for each frame to be saved 97 | :param to_bgr: whether to flip color channels 98 | """ 99 | 100 | if to_bgr: 101 | seq_data = seq_data[..., ::-1] # rgb2bgr 102 | 103 | # use default frm_idx_lst is not specified 104 | tot_frm = len(seq_data) 105 | if frm_idx_lst is None: 106 | frm_idx_lst = ['{:04d}.png'.format(i) for i in range(tot_frm)] 107 | 108 | # save for each frame 109 | os.makedirs(seq_dir, exist_ok=True) 110 | for i in range(tot_frm): 111 | cv2.imwrite(osp.join(seq_dir, frm_idx_lst[i]), seq_data[i]) 112 | -------------------------------------------------------------------------------- /codes/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torch.multiprocessing as mp 6 | 7 | 8 | def init_dist(opt, local_rank): 9 | """ Adopted from BasicSR 10 | """ 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | torch.cuda.set_device(local_rank) 14 | dist.init_process_group(backend='nccl') 15 | 16 | rank, world_size = get_dist_info() 17 | 18 | opt.update({ 19 | 'dist': True, 20 | 'device': 'cuda', 21 | 'local_rank': local_rank, 22 | 'world_size': world_size, 23 | 'rank': rank 24 | }) 25 | 26 | 27 | def get_dist_info(): 28 | """ Adopted from BasicSR 29 | """ 30 | if dist.is_available(): 31 | initialized = dist.is_initialized() 32 | else: 33 | initialized = False 34 | 35 | if initialized: 36 | rank = 
dist.get_rank() 37 | world_size = dist.get_world_size() 38 | else: 39 | rank = 0 40 | world_size = 1 41 | 42 | return rank, world_size 43 | 44 | 45 | def master_only(func): 46 | """ Adopted from BasicSR 47 | """ 48 | @functools.wraps(func) 49 | def wrapper(*args, **kwargs): 50 | rank, _ = get_dist_info() 51 | if rank == 0: 52 | return func(*args, **kwargs) 53 | 54 | return wrapper 55 | -------------------------------------------------------------------------------- /codes/utils/net_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | # ===----------------- utility functions -------------------- # 9 | def initialize_weights(net_l, init_type='kaiming', scale=1): 10 | """ Modify from BasicSR/MMSR 11 | """ 12 | 13 | if not isinstance(net_l, list): 14 | net_l = [net_l] 15 | 16 | for net in net_l: 17 | for m in net.modules(): 18 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): 19 | if init_type == 'xavier': 20 | nn.init.xavier_uniform_(m.weight) 21 | elif init_type == 'kaiming': 22 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 23 | else: 24 | raise NotImplementedError(init_type) 25 | 26 | m.weight.data *= scale # to stabilize training 27 | 28 | if m.bias is not None: 29 | nn.init.constant_(m.bias.data, 0) 30 | 31 | elif isinstance(m, nn.BatchNorm2d): 32 | nn.init.constant_(m.weight.data, 1) 33 | nn.init.constant_(m.bias.data, 0) 34 | 35 | 36 | def space_to_depth(x, scale): 37 | """ Equivalent to tf.space_to_depth() 38 | """ 39 | 40 | n, c, in_h, in_w = x.size() 41 | out_h, out_w = in_h // scale, in_w // scale 42 | 43 | x_reshaped = x.reshape(n, c, out_h, scale, out_w, scale) 44 | x_reshaped = x_reshaped.permute(0, 3, 5, 1, 2, 4) 45 | output = x_reshaped.reshape(n, scale * scale * c, out_h, out_w) 46 | 47 | return output 48 | 49 | 50 | def backward_warp(x, flow, mode='bilinear', padding_mode='border'): 51 | """ Backward warp `x` according to `flow` 52 | 53 | Both x and flow are pytorch tensor in shape `nchw` and `n2hw` 54 | 55 | Reference: 56 | https://github.com/sniklaus/pytorch-spynet/blob/master/run.py#L41 57 | """ 58 | 59 | n, c, h, w = x.size() 60 | 61 | # create mesh grid 62 | iu = torch.linspace(-1.0, 1.0, w).view(1, 1, 1, w).expand(n, -1, h, -1) 63 | iv = torch.linspace(-1.0, 1.0, h).view(1, 1, h, 1).expand(n, -1, -1, w) 64 | grid = torch.cat([iu, iv], 1).to(flow.device) 65 | 66 | # normalize flow to [-1, 1] 67 | flow = torch.cat([ 68 | flow[:, 0:1, ...] / ((w - 1.0) / 2.0), 69 | flow[:, 1:2, ...] 
/ ((h - 1.0) / 2.0)], dim=1) 70 | 71 | # add flow to grid and reshape to nhw2 72 | grid = (grid + flow).permute(0, 2, 3, 1) 73 | 74 | # bilinear sampling 75 | # Note: `align_corners` is set to `True` by default for PyTorch version < 1.4.0 76 | if int(''.join(torch.__version__.split('.')[:2])) >= 14: 77 | output = F.grid_sample( 78 | x, grid, mode=mode, padding_mode=padding_mode, align_corners=True) 79 | else: 80 | output = F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode) 81 | 82 | return output 83 | 84 | 85 | def get_upsampling_func(scale=4, degradation='BI'): 86 | if degradation == 'BI': 87 | upsample_func = functools.partial( 88 | F.interpolate, scale_factor=scale, mode='bilinear', 89 | align_corners=False) 90 | 91 | elif degradation == 'BD': 92 | upsample_func = BicubicUpsampler(scale_factor=scale) 93 | 94 | else: 95 | raise ValueError(f'Unrecognized degradation type: {degradation}') 96 | 97 | return upsample_func 98 | 99 | 100 | # --------------------- utility classes --------------------- # 101 | class BicubicUpsampler(nn.Module): 102 | """ Bicubic upsampling function with similar behavior to that in TecoGAN-Tensorflow 103 | 104 | Note: 105 | This function is different from torch.nn.functional.interpolate and matlab's imresize 106 | in terms of the bicubic kernel and the sampling strategy 107 | 108 | References: 109 | http://verona.fi-p.unam.mx/boris/practicas/CubConvInterp.pdf 110 | https://stackoverflow.com/questions/26823140/imresize-trying-to-understand-the-bicubic-interpolation 111 | """ 112 | 113 | def __init__(self, scale_factor, a=-0.75): 114 | super(BicubicUpsampler, self).__init__() 115 | 116 | # calculate weights (according to Eq.(6) in the reference paper) 117 | cubic = torch.FloatTensor([ 118 | [0, a, -2*a, a], 119 | [1, 0, -(a + 3), a + 2], 120 | [0, -a, (2*a + 3), -(a + 2)], 121 | [0, 0, a, -a] 122 | ]) 123 | 124 | kernels = [ 125 | torch.matmul(cubic, torch.FloatTensor([1, s, s**2, s**3])) 126 | for s in [1.0*d/scale_factor for d in range(scale_factor)] 127 | ] # s = x - floor(x) 128 | 129 | # register parameters 130 | self.scale_factor = scale_factor 131 | self.register_buffer('kernels', torch.stack(kernels)) # size: (f, 4) 132 | 133 | def forward(self, input): 134 | n, c, h, w = input.size() 135 | f = self.scale_factor 136 | 137 | # merge n&c 138 | input = input.reshape(n*c, 1, h, w) 139 | 140 | # pad input (left, right, top, bottom) 141 | input = F.pad(input, (1, 2, 1, 2), mode='replicate') 142 | 143 | # calculate output (vertical expansion) 144 | kernel_h = self.kernels.view(f, 1, 4, 1) 145 | output = F.conv2d(input, kernel_h, stride=1, padding=0) 146 | output = output.permute(0, 2, 1, 3).reshape(n*c, 1, f*h, w + 3) 147 | 148 | # calculate output (horizontal expansion) 149 | kernel_w = self.kernels.view(f, 1, 1, 4) 150 | output = F.conv2d(output, kernel_w, stride=1, padding=0) 151 | output = output.permute(0, 2, 3, 1).reshape(n*c, 1, f*h, f*w) 152 | 153 | # split n&c 154 | output = output.reshape(n, c, f*h, f*w) 155 | 156 | return output 157 | -------------------------------------------------------------------------------- /data/meta/REDS/test_list.txt: -------------------------------------------------------------------------------- 1 | 000 2 | 011 3 | 015 4 | 020 5 | -------------------------------------------------------------------------------- /data/meta/REDS/train_list.txt: -------------------------------------------------------------------------------- 1 | 001 2 | 002 3 | 003 4 | 004 5 | 005 6 | 006 7 | 007 8 | 008 9 | 009 10 | 010 11 | 012 12 | 
013 13 | 014 14 | 016 15 | 017 16 | 018 17 | 019 18 | 021 19 | 022 20 | 023 21 | 024 22 | 025 23 | 026 24 | 027 25 | 028 26 | 029 27 | 030 28 | 031 29 | 032 30 | 033 31 | 034 32 | 035 33 | 036 34 | 037 35 | 038 36 | 039 37 | 040 38 | 041 39 | 042 40 | 043 41 | 044 42 | 045 43 | 046 44 | 047 45 | 048 46 | 049 47 | 050 48 | 051 49 | 052 50 | 053 51 | 054 52 | 055 53 | 056 54 | 057 55 | 058 56 | 059 57 | 060 58 | 061 59 | 062 60 | 063 61 | 064 62 | 065 63 | 066 64 | 067 65 | 068 66 | 069 67 | 070 68 | 071 69 | 072 70 | 073 71 | 074 72 | 075 73 | 076 74 | 077 75 | 078 76 | 079 77 | 080 78 | 081 79 | 082 80 | 083 81 | 084 82 | 085 83 | 086 84 | 087 85 | 088 86 | 089 87 | 090 88 | 091 89 | 092 90 | 093 91 | 094 92 | 095 93 | 096 94 | 097 95 | 098 96 | 099 97 | 100 98 | 101 99 | 102 100 | 103 101 | 104 102 | 105 103 | 106 104 | 107 105 | 108 106 | 109 107 | 110 108 | 111 109 | 112 110 | 113 111 | 114 112 | 115 113 | 116 114 | 117 115 | 118 116 | 119 117 | 120 118 | 121 119 | 122 120 | 123 121 | 124 122 | 125 123 | 126 124 | 127 125 | 128 126 | 129 127 | 130 128 | 131 129 | 132 130 | 133 131 | 134 132 | 135 133 | 136 134 | 137 135 | 138 136 | 139 137 | 140 138 | 141 139 | 142 140 | 143 141 | 144 142 | 145 143 | 146 144 | 147 145 | 148 146 | 149 147 | 150 148 | 151 149 | 152 150 | 153 151 | 154 152 | 155 153 | 156 154 | 157 155 | 158 156 | 159 157 | 160 158 | 161 159 | 162 160 | 163 161 | 164 162 | 165 163 | 166 164 | 167 165 | 168 166 | 169 167 | 170 168 | 171 169 | 172 170 | 173 171 | 174 172 | 175 173 | 176 174 | 177 175 | 178 176 | 179 177 | 180 178 | 181 179 | 182 180 | 183 181 | 184 182 | 185 183 | 186 184 | 187 185 | 188 186 | 189 187 | 190 188 | 191 189 | 192 190 | 193 191 | 194 192 | 195 193 | 196 194 | 197 195 | 198 196 | 199 197 | 200 198 | 201 199 | 202 200 | 203 201 | 204 202 | 205 203 | 206 204 | 207 205 | 208 206 | 209 207 | 210 208 | 211 209 | 212 210 | 213 211 | 214 212 | 215 213 | 216 214 | 217 215 | 218 216 | 219 217 | 220 218 | 221 219 | 222 220 | 223 221 | 224 222 | 225 223 | 226 224 | 227 225 | 228 226 | 229 227 | 230 228 | 231 229 | 232 230 | 233 231 | 234 232 | 235 233 | 236 234 | 237 235 | 238 236 | 239 237 | 240 238 | 241 239 | 242 240 | 243 241 | 244 242 | 245 243 | 246 244 | 247 245 | 248 246 | 249 247 | 250 248 | 251 249 | 252 250 | 253 251 | 254 252 | 255 253 | 256 254 | 257 255 | 258 256 | 259 257 | 260 258 | 261 259 | 262 260 | 263 261 | 264 262 | 265 263 | 266 264 | 267 265 | 268 266 | 269 267 | -------------------------------------------------------------------------------- /data/put_data_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/data/put_data_here -------------------------------------------------------------------------------- /experiments_BD/FRVSR/FRVSR_REDS_2xSR_2GPU/test.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 2 3 | manual_seed: 0 4 | verbose: false 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | test: 14 | name: REDS 15 | gt_seq_dir: data/REDS/Raw 16 | lr_seq_dir: ~ 17 | filter_list: ['000', '011', '015', '020'] 18 | num_worker_per_gpu: 4 19 | pin_memory: true 20 | 21 | 22 | # model configs 23 | model: 24 | name: FRVSR 25 | 26 | generator: 27 | name: FRNet # frame-recurrent network 28 | in_nc: 3 29 | out_nc: 3 30 | nf: 64 31 | nb: 10 32 | 33 | load_path: 
pretrained_models/FRVSR_2x_BD_REDS_iter400K.pth 34 | 35 | 36 | # validation configs 37 | test: 38 | # whether to save the generated SR results 39 | save_res: false 40 | res_dir: ~ # use default dir 41 | 42 | # whether to save the test results in a json file 43 | save_json: false 44 | json_dir: ~ # use default dir 45 | 46 | padding_mode: reflect 47 | num_pad_front: 5 48 | 49 | 50 | # metric configs 51 | metric: 52 | PSNR: 53 | colorspace: y 54 | -------------------------------------------------------------------------------- /experiments_BD/FRVSR/FRVSR_REDS_2xSR_2GPU/train.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 2 3 | manual_seed: 0 4 | verbose: true 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | train: 14 | name: REDS 15 | seq_dir: data/REDS/GT.lmdb 16 | filter_file: data/meta/REDS/train_list.txt 17 | data_type: rgb 18 | crop_size: 128 19 | batch_size_per_gpu: 2 20 | num_worker_per_gpu: 3 21 | pin_memory: true 22 | 23 | test: 24 | name: REDS 25 | gt_seq_dir: data/REDS/Raw 26 | lr_seq_dir: ~ 27 | filter_list: ['000', '011', '015', '020'] 28 | num_worker_per_gpu: 4 29 | pin_memory: true 30 | 31 | 32 | # model configs 33 | model: 34 | name: FRVSR 35 | 36 | generator: 37 | name: FRNet # frame-recurrent network 38 | in_nc: 3 39 | out_nc: 3 40 | nf: 64 41 | nb: 10 42 | 43 | load_path: ~ 44 | 45 | 46 | # training settings 47 | train: 48 | tempo_extent: 10 49 | 50 | start_iter: 0 51 | total_iter: 400000 52 | 53 | # configs for generator 54 | generator: 55 | lr: !!float 1e-4 56 | lr_schedule: 57 | type: MultiStepLR 58 | milestones: [150000, 300000] 59 | gamma: 0.5 60 | betas: [0.9, 0.999] 61 | 62 | # other settings 63 | moving_first_frame: true 64 | moving_factor: 0.7 65 | 66 | # criterions 67 | pixel_crit: 68 | type: CB 69 | weight: 1 70 | reduction: mean 71 | 72 | warping_crit: 73 | type: CB 74 | weight: 1 75 | reduction: mean 76 | 77 | 78 | # validation configs 79 | test: 80 | test_freq: 10000 81 | 82 | # whether to save the generated SR results 83 | save_res: false 84 | res_dir: ~ # use default dir 85 | 86 | # whether to save the test results in a json file 87 | save_json: true 88 | json_dir: ~ # use default dir 89 | 90 | padding_mode: reflect 91 | num_pad_front: 5 92 | 93 | 94 | # metric configs 95 | metric: 96 | PSNR: 97 | colorspace: y 98 | 99 | 100 | # logger configs 101 | logger: 102 | log_freq: 100 103 | decay: 0.99 104 | ckpt_freq: 10000 105 | -------------------------------------------------------------------------------- /experiments_BD/FRVSR/FRVSR_REDS_4xSR_2GPU/test.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: false 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | test: 14 | name: REDS 15 | gt_seq_dir: data/REDS/Raw 16 | lr_seq_dir: ~ 17 | filter_list: ['000', '011', '015', '020'] 18 | num_worker_per_gpu: 4 19 | pin_memory: true 20 | 21 | 22 | # model configs 23 | model: 24 | name: FRVSR 25 | 26 | generator: 27 | name: FRNet # frame-recurrent network 28 | in_nc: 3 29 | out_nc: 3 30 | nf: 64 31 | nb: 10 32 | 33 | load_path: pretrained_models/FRVSR_4x_BD_REDS_iter400K.pth 34 | 35 | 36 | # validation configs 37 | test: 38 | # whether to save the generated SR results 39 | save_res: true 40 | res_dir: ~ # use default dir 41 | 42 | padding_mode: reflect 43 | num_pad_front: 5 44 | 
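The test configs above are not consumed directly by the network code; `parse_configs()` in `codes/utils/base_utils.py` reads the YAML file relative to `--exp_dir` and injects runtime fields (`exp_dir`, `gpu_ids`, `is_train`) before the options reach the model. Below is a minimal, standalone sketch of that loading step for the `FRVSR_REDS_4xSR_2GPU` test config; it assumes only PyYAML and the directory layout shown above, and the variable names are illustrative rather than part of the repo.

```python
import os.path as osp
import yaml

# Illustrative values; in the repo these come from the --exp_dir / --opt CLI flags.
exp_dir = 'experiments_BD/FRVSR/FRVSR_REDS_4xSR_2GPU'
opt_file = 'test.yml'

# Same loading call as parse_configs() in codes/utils/base_utils.py.
with open(osp.join(exp_dir, opt_file), 'r') as f:
    opt = yaml.load(f.read(), Loader=yaml.FullLoader)

# parse_configs() adds these runtime fields on top of the YAML contents.
opt['exp_dir'] = exp_dir
opt['is_train'] = False

print(opt['scale'])                            # 4
print(opt['dataset']['degradation']['type'])   # BD
print(opt['model']['generator']['load_path'])  # pretrained_models/FRVSR_4x_BD_REDS_iter400K.pth
```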
-------------------------------------------------------------------------------- /experiments_BD/FRVSR/FRVSR_REDS_4xSR_2GPU/train.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: true 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | train: 14 | name: REDS 15 | seq_dir: data/REDS/GT.lmdb 16 | filter_file: data/meta/REDS/train_list.txt 17 | data_type: rgb 18 | crop_size: 128 19 | batch_size_per_gpu: 2 20 | num_worker_per_gpu: 3 21 | pin_memory: true 22 | 23 | test: 24 | name: REDS 25 | gt_seq_dir: data/REDS/Raw 26 | lr_seq_dir: ~ 27 | filter_list: ['000', '011', '015', '020'] 28 | num_worker_per_gpu: 4 29 | pin_memory: true 30 | 31 | 32 | # model configs 33 | model: 34 | name: FRVSR 35 | 36 | generator: 37 | name: FRNet # frame-recurrent network 38 | in_nc: 3 39 | out_nc: 3 40 | nf: 64 41 | nb: 10 42 | 43 | load_path: ~ 44 | 45 | 46 | # training settings 47 | train: 48 | tempo_extent: 10 49 | 50 | start_iter: 0 51 | total_iter: 400000 52 | 53 | # configs for generator 54 | generator: 55 | lr: !!float 1e-4 56 | lr_schedule: 57 | type: MultiStepLR 58 | milestones: [150000, 300000] 59 | gamma: 0.5 60 | betas: [0.9, 0.999] 61 | 62 | # other settings 63 | moving_first_frame: true 64 | moving_factor: 0.7 65 | 66 | # criterions 67 | pixel_crit: 68 | type: CB 69 | weight: 1 70 | reduction: mean 71 | 72 | warping_crit: 73 | type: CB 74 | weight: 1 75 | reduction: mean 76 | 77 | 78 | # validation configs 79 | test: 80 | test_freq: 10000 81 | 82 | # whether to save the generated SR results 83 | save_res: false 84 | res_dir: ~ # use default dir 85 | 86 | # whether to save the test results in a json file 87 | save_json: true 88 | json_dir: ~ # use default dir 89 | 90 | padding_mode: reflect 91 | num_pad_front: 5 92 | 93 | 94 | # metric configs 95 | metric: 96 | PSNR: 97 | colorspace: y 98 | 99 | 100 | # logger configs 101 | logger: 102 | log_freq: 100 103 | decay: 0.99 104 | ckpt_freq: 20000 105 | -------------------------------------------------------------------------------- /experiments_BD/FRVSR/FRVSR_VimeoTecoGAN_4xSR_2GPU/test.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: false 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | test1: 14 | name: Vid4 15 | gt_seq_dir: data/Vid4/GT 16 | lr_seq_dir: data/Vid4/Gaussian4xLR 17 | num_worker_per_gpu: 3 18 | pin_memory: true 19 | 20 | test2: 21 | name: ToS3 22 | gt_seq_dir: data/ToS3/GT 23 | lr_seq_dir: data/ToS3/Gaussian4xLR 24 | num_worker_per_gpu: 3 25 | pin_memory: true 26 | 27 | 28 | # model configs 29 | model: 30 | name: FRVSR 31 | 32 | generator: 33 | name: FRNet # frame-recurrent network 34 | in_nc: 3 35 | out_nc: 3 36 | nf: 64 37 | nb: 10 38 | load_path: pretrained_models/FRVSR_4x_BD_Vimeo_iter400K.pth 39 | 40 | 41 | # test configs 42 | test: 43 | # whether to save the SR results 44 | save_res: true 45 | res_dir: results 46 | 47 | # temporal padding 48 | padding_mode: reflect 49 | num_pad_front: 5 50 | -------------------------------------------------------------------------------- /experiments_BD/FRVSR/FRVSR_VimeoTecoGAN_4xSR_2GPU/train.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: true 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | 
degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | train: 14 | name: VimeoTecoGAN 15 | seq_dir: data/VimeoTecoGAN/GT.lmdb 16 | filter_file: ~ 17 | data_type: rgb 18 | crop_size: 128 19 | batch_size_per_gpu: 2 20 | num_worker_per_gpu: 3 21 | pin_memory: true 22 | 23 | test: 24 | name: Vid4 25 | gt_seq_dir: data/Vid4/GT 26 | lr_seq_dir: data/Vid4/Gaussian4xLR 27 | filter_file: ~ 28 | num_worker_per_gpu: 4 29 | pin_memory: true 30 | 31 | 32 | # model configs 33 | model: 34 | name: FRVSR 35 | 36 | generator: 37 | name: FRNet # frame-recurrent network 38 | in_nc: 3 39 | out_nc: 3 40 | nf: 64 41 | nb: 10 42 | load_path: ~ 43 | 44 | 45 | # training settings 46 | train: 47 | tempo_extent: 10 48 | 49 | start_iter: 0 50 | total_iter: 400000 51 | 52 | # configs for generator 53 | generator: 54 | lr: !!float 1e-4 55 | lr_schedule: 56 | type: MultiStepLR 57 | milestones: [150000, 300000] 58 | gamma: 0.5 59 | betas: [0.9, 0.999] 60 | 61 | # other settings 62 | moving_first_frame: true 63 | moving_factor: 0.7 64 | 65 | # criterions 66 | pixel_crit: 67 | type: CB 68 | weight: 1 69 | reduction: mean 70 | 71 | warping_crit: 72 | type: CB 73 | weight: 1 74 | reduction: mean 75 | 76 | 77 | # validation configs 78 | test: 79 | test_freq: 10000 80 | 81 | # whether to save the generated SR results 82 | save_res: false 83 | res_dir: ~ # use default dir 84 | 85 | # whether to save the test results in a json file 86 | save_json: true 87 | json_dir: ~ # use default dir 88 | 89 | padding_mode: reflect 90 | num_pad_front: 5 91 | 92 | 93 | # metric configs 94 | metric: 95 | PSNR: 96 | colorspace: y 97 | 98 | 99 | # logger configs 100 | logger: 101 | log_freq: 100 102 | decay: 0.99 103 | ckpt_freq: 20000 104 | -------------------------------------------------------------------------------- /experiments_BD/TecoGAN/TecoGAN_REDS_2xSR_2GPU/test.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 2 3 | manual_seed: 0 4 | verbose: true 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | test: 14 | name: REDS 15 | gt_seq_dir: data/REDS/Raw 16 | lr_seq_dir: ~ 17 | filter_list: ['000', '011', '015', '020'] 18 | num_worker_per_gpu: 4 19 | pin_memory: true 20 | 21 | 22 | # model configs 23 | model: 24 | name: TecoGAN 25 | 26 | generator: 27 | name: FRNet # frame-recurrent network 28 | in_nc: 3 29 | out_nc: 3 30 | nf: 64 31 | nb: 10 32 | 33 | load_path: pretrained_models/TecoGAN_2x_BD_REDS_iter500K.pth 34 | 35 | 36 | # validation configs 37 | test: 38 | test_freq: 10000 39 | 40 | # whether to save the generated SR results 41 | save_res: true 42 | res_dir: ~ # use default dir 43 | 44 | # whether to save the test results in a json file 45 | save_json: false 46 | json_dir: ~ # use default dir 47 | 48 | padding_mode: reflect 49 | num_pad_front: 5 50 | 51 | 52 | # metric configs 53 | metric: 54 | PSNR: 55 | colorspace: y 56 | 57 | LPIPS: 58 | model: net-lin 59 | net: alex 60 | colorspace: rgb 61 | spatial: false 62 | version: 0.1 63 | 64 | tOF: 65 | colorspace: y 66 | -------------------------------------------------------------------------------- /experiments_BD/TecoGAN/TecoGAN_REDS_2xSR_2GPU/train.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 2 3 | manual_seed: 0 4 | verbose: true 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | train: 14 | name: REDS 15 | seq_dir: data/REDS/GT.lmdb 16 | filter_file: 
data/meta/REDS/train_list.txt 17 | data_type: rgb 18 | crop_size: 128 19 | batch_size_per_gpu: 2 20 | num_worker_per_gpu: 3 21 | pin_memory: true 22 | 23 | test: 24 | name: REDS 25 | gt_seq_dir: data/REDS/Raw 26 | lr_seq_dir: ~ 27 | filter_list: ['000', '011', '015', '020'] 28 | num_worker_per_gpu: 4 29 | pin_memory: true 30 | 31 | 32 | # model configs 33 | model: 34 | name: TecoGAN 35 | 36 | generator: 37 | name: FRNet # frame-recurrent network 38 | in_nc: 3 39 | out_nc: 3 40 | nf: 64 41 | nb: 10 42 | 43 | load_path: experiments_BD/FRVSR/FRVSR_REDS_2xSR_2GPU/train/ckpt/G_iter400000.pth 44 | 45 | discriminator: 46 | name: STNet # spatio-temporal network 47 | in_nc: 3 48 | tempo_range: 3 49 | 50 | load_path: ~ 51 | 52 | 53 | # training configs 54 | train: 55 | tempo_extent: 10 56 | 57 | start_iter: 0 58 | total_iter: 500000 59 | 60 | # configs for generator 61 | generator: 62 | lr: !!float 5e-5 63 | lr_schedule: 64 | type: FixedLR 65 | betas: [0.9, 0.999] 66 | 67 | # configs for discriminator 68 | discriminator: 69 | update_policy: adaptive 70 | update_threshold: 0.4 71 | crop_border_ratio: 0.75 72 | lr: !!float 5e-5 73 | lr_schedule: 74 | type: FixedLR 75 | betas: [0.9, 0.999] 76 | 77 | # other configs 78 | moving_first_frame: true 79 | moving_factor: 0.7 80 | 81 | # criterions 82 | pixel_crit: 83 | type: CB 84 | weight: 1 85 | reduction: mean 86 | 87 | warping_crit: 88 | type: CB 89 | weight: 1 90 | reduction: mean 91 | 92 | feature_crit: 93 | type: CosineSimilarity 94 | weight: 0.2 95 | reduction: mean 96 | feature_layers: [8, 17, 26, 35] 97 | 98 | pingpong_crit: 99 | type: CB 100 | weight: 0.5 101 | reduction: mean 102 | 103 | gan_crit: 104 | type: GAN 105 | weight: 0.01 106 | reduction: mean 107 | 108 | 109 | # validation configs 110 | test: 111 | test_freq: 10000 112 | 113 | # whether to save the generated SR results 114 | save_res: false 115 | res_dir: ~ # use default dir 116 | 117 | # whether to save the test results in a json file 118 | save_json: true 119 | json_dir: ~ # use default dir 120 | 121 | padding_mode: reflect 122 | num_pad_front: 5 123 | 124 | 125 | # metric configs 126 | metric: 127 | PSNR: 128 | colorspace: y 129 | 130 | LPIPS: 131 | model: net-lin 132 | net: alex 133 | colorspace: rgb 134 | spatial: false 135 | version: 0.1 136 | 137 | tOF: 138 | colorspace: y 139 | 140 | 141 | # logger configs 142 | logger: 143 | log_freq: 100 144 | decay: 0.99 145 | ckpt_freq: 20000 146 | -------------------------------------------------------------------------------- /experiments_BD/TecoGAN/TecoGAN_REDS_4xSR_2GPU/test.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: false 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | test: 14 | name: REDS 15 | gt_seq_dir: data/REDS/Raw 16 | lr_seq_dir: ~ 17 | filter_list: ['000', '011', '015', '020'] 18 | num_worker_per_gpu: 4 19 | pin_memory: true 20 | 21 | 22 | # model configs 23 | model: 24 | name: TecoGAN 25 | 26 | generator: 27 | name: FRNet # frame-recurrent network 28 | in_nc: 3 29 | out_nc: 3 30 | nf: 64 31 | nb: 10 32 | 33 | load_path: pretrained_models/TecoGAN_4x_BD_REDS_iter500K.pth 34 | 35 | discriminator: 36 | name: STNet # spatio-temporal network 37 | in_nc: 3 38 | tempo_range: 3 39 | 40 | load_path: ~ 41 | 42 | 43 | # validation configs 44 | test: 45 | # whether to save the generated SR results 46 | save_res: true 47 | res_dir: ~ # use default dir 48 | 49 | padding_mode: reflect 
50 | num_pad_front: 5 51 | -------------------------------------------------------------------------------- /experiments_BD/TecoGAN/TecoGAN_REDS_4xSR_2GPU/train.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: true 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | train: 14 | name: REDS 15 | seq_dir: data/REDS/GT.lmdb 16 | filter_file: data/meta/REDS/train_list.txt 17 | data_type: rgb 18 | crop_size: 128 19 | batch_size_per_gpu: 2 20 | num_worker_per_gpu: 3 21 | pin_memory: true 22 | 23 | test: 24 | name: REDS 25 | gt_seq_dir: data/REDS/Raw 26 | lr_seq_dir: ~ 27 | filter_list: ['000', '011', '015', '020'] 28 | num_worker_per_gpu: 4 29 | pin_memory: true 30 | 31 | 32 | # model configs 33 | model: 34 | name: TecoGAN 35 | 36 | generator: 37 | name: FRNet # frame-recurrent network 38 | in_nc: 3 39 | out_nc: 3 40 | nf: 64 41 | nb: 10 42 | 43 | load_path: experiments_BD/FRVSR/FRVSR_REDS_4xSR_2GPU/train/ckpt/G_iter400000.pth 44 | 45 | discriminator: 46 | name: STNet # spatio-temporal network 47 | in_nc: 3 48 | tempo_range: 3 49 | 50 | load_path: ~ 51 | 52 | 53 | # training configs 54 | train: 55 | tempo_extent: 10 56 | 57 | start_iter: 0 58 | total_iter: 500000 59 | 60 | # configs for generator 61 | generator: 62 | lr: !!float 5e-5 63 | lr_schedule: 64 | type: FixedLR 65 | betas: [0.9, 0.999] 66 | 67 | # configs for discriminator 68 | discriminator: 69 | update_policy: adaptive 70 | update_threshold: 0.4 71 | crop_border_ratio: 0.75 72 | lr: !!float 5e-5 73 | lr_schedule: 74 | type: FixedLR 75 | betas: [0.9, 0.999] 76 | 77 | # other configs 78 | moving_first_frame: true 79 | moving_factor: 0.7 80 | 81 | # criterions 82 | pixel_crit: 83 | type: CB 84 | weight: 1 85 | reduction: mean 86 | 87 | warping_crit: 88 | type: CB 89 | weight: 1 90 | reduction: mean 91 | 92 | feature_crit: 93 | type: CosineSimilarity 94 | weight: 0.2 95 | reduction: mean 96 | feature_layers: [8, 17, 26, 35] 97 | 98 | pingpong_crit: 99 | type: CB 100 | weight: 0.5 101 | reduction: mean 102 | 103 | gan_crit: 104 | type: GAN 105 | weight: 0.01 106 | reduction: mean 107 | 108 | 109 | # validation configs 110 | test: 111 | test_freq: 10000 112 | 113 | # whether to save the generated SR results 114 | save_res: false 115 | res_dir: ~ # use default dir 116 | 117 | # whether to save the test results in a json file 118 | save_json: true 119 | json_dir: ~ # use default dir 120 | 121 | padding_mode: reflect 122 | num_pad_front: 5 123 | 124 | 125 | # metric configs 126 | metric: 127 | PSNR: 128 | colorspace: y 129 | 130 | LPIPS: 131 | model: net-lin 132 | net: alex 133 | colorspace: rgb 134 | spatial: false 135 | version: 0.1 136 | 137 | tOF: 138 | colorspace: y 139 | 140 | 141 | # logger configs 142 | logger: 143 | log_freq: 100 144 | decay: 0.99 145 | ckpt_freq: 20000 146 | -------------------------------------------------------------------------------- /experiments_BD/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/test.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: false 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | test1: 14 | name: Vid4 15 | gt_seq_dir: data/Vid4/GT 16 | lr_seq_dir: data/Vid4/Gaussian4xLR 17 | num_worker_per_gpu: 3 18 | pin_memory: true 19 | 20 | test2: 21 | name: ToS3 22 | gt_seq_dir: data/ToS3/GT 23 | lr_seq_dir: 
data/ToS3/Gaussian4xLR 24 | num_worker_per_gpu: 3 25 | pin_memory: true 26 | 27 | 28 | # model configs 29 | model: 30 | name: TecoGAN 31 | 32 | generator: 33 | name: FRNet # frame-recurrent network 34 | in_nc: 3 35 | out_nc: 3 36 | nf: 64 37 | nb: 10 38 | load_path: pretrained_models/TecoGAN_4x_BD_Vimeo_iter500K.pth 39 | 40 | 41 | # test configs 42 | test: 43 | # whether to save the SR results 44 | save_res: true 45 | res_dir: results 46 | 47 | # temporal padding 48 | padding_mode: reflect 49 | num_pad_front: 5 50 | -------------------------------------------------------------------------------- /experiments_BD/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/train.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: true 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BD 11 | sigma: 1.5 12 | 13 | train: 14 | name: VimeoTecoGAN 15 | seq_dir: data/VimeoTecoGAN/GT.lmdb 16 | filter_file: ~ 17 | data_type: rgb 18 | crop_size: 128 19 | batch_size_per_gpu: 2 20 | num_worker_per_gpu: 3 21 | pin_memory: true 22 | 23 | test: 24 | name: Vid4 25 | gt_seq_dir: data/Vid4/GT 26 | lr_seq_dir: data/Vid4/Gaussian4xLR 27 | filter_file: ~ 28 | num_worker_per_gpu: 3 29 | pin_memory: true 30 | 31 | 32 | # model configs 33 | model: 34 | name: TecoGAN 35 | 36 | generator: 37 | name: FRNet # frame-recurrent network 38 | in_nc: 3 39 | out_nc: 3 40 | nf: 64 41 | nb: 10 42 | load_path: pretrained_models/FRVSR_BD_iter400000.pth 43 | 44 | discriminator: 45 | name: STNet # spatio-temporal network 46 | in_nc: 3 47 | tempo_range: 3 48 | load_path: ~ 49 | 50 | 51 | # training configs 52 | train: 53 | tempo_extent: 10 54 | 55 | start_iter: 0 56 | total_iter: 500000 57 | 58 | # configs for generator 59 | generator: 60 | lr: !!float 5e-5 61 | lr_schedule: 62 | type: FixedLR 63 | betas: [0.9, 0.999] 64 | 65 | # configs for discriminator 66 | discriminator: 67 | update_policy: adaptive 68 | update_threshold: 0.4 69 | crop_border_ratio: 0.75 70 | lr: !!float 5e-5 71 | lr_schedule: 72 | type: FixedLR 73 | betas: [0.9, 0.999] 74 | 75 | # other configs 76 | moving_first_frame: true 77 | moving_factor: 0.7 78 | 79 | # criterions 80 | pixel_crit: 81 | type: CB 82 | weight: 1 83 | reduction: mean 84 | 85 | warping_crit: 86 | type: CB 87 | weight: 1 88 | reduction: mean 89 | 90 | feature_crit: 91 | type: CosineSimilarity 92 | weight: 0.2 93 | reduction: mean 94 | feature_layers: [8, 17, 26, 35] 95 | 96 | pingpong_crit: 97 | type: CB 98 | weight: 0.5 99 | reduction: mean 100 | 101 | gan_crit: 102 | type: GAN 103 | weight: 0.01 104 | reduction: mean 105 | 106 | 107 | # validation configs 108 | test: 109 | test_freq: 10000 110 | 111 | # whether to save the generated SR results 112 | save_res: false 113 | res_dir: ~ # use default dir 114 | 115 | # whether to save the test results in a json file 116 | save_json: true 117 | json_dir: ~ # use default dir 118 | 119 | padding_mode: reflect 120 | num_pad_front: 5 121 | 122 | 123 | # metric configs 124 | metric: 125 | PSNR: 126 | colorspace: y 127 | 128 | LPIPS: 129 | model: net-lin 130 | net: alex 131 | colorspace: rgb 132 | spatial: false 133 | version: 0.1 134 | 135 | tOF: 136 | colorspace: y 137 | 138 | 139 | # logger configs 140 | logger: 141 | log_freq: 100 142 | decay: 0.99 143 | ckpt_freq: 20000 144 | -------------------------------------------------------------------------------- /experiments_BI/FRVSR/FRVSR_VimeoTecoGAN_4xSR_2GPU/test.yml: 
-------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: false 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BI 11 | 12 | test1: 13 | name: Vid4 14 | gt_seq_dir: data/Vid4/GT 15 | lr_seq_dir: data/Vid4/Bicubic4xLR 16 | num_worker_per_gpu: 3 17 | pin_memory: true 18 | 19 | test2: 20 | name: ToS3 21 | gt_seq_dir: data/ToS3/GT 22 | lr_seq_dir: data/ToS3/Bicubic4xLR 23 | num_worker_per_gpu: 3 24 | pin_memory: true 25 | 26 | 27 | # model configs 28 | model: 29 | name: FRVSR 30 | 31 | generator: 32 | name: FRNet # frame-recurrent network 33 | in_nc: 3 34 | out_nc: 3 35 | nf: 64 36 | nb: 10 37 | load_path: pretrained_models/FRVSR_4x_BI_Vimeo_iter400K.pth 38 | 39 | 40 | # test configs 41 | test: 42 | # whether to save the SR results 43 | save_res: true 44 | res_dir: results 45 | 46 | # temporal padding 47 | padding_mode: reflect 48 | num_pad_front: 5 49 | -------------------------------------------------------------------------------- /experiments_BI/FRVSR/FRVSR_VimeoTecoGAN_4xSR_2GPU/train.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: true 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BI 11 | 12 | train: 13 | name: VimeoTecoGAN 14 | gt_seq_dir: data/VimeoTecoGAN/GT.lmdb 15 | lr_seq_dir: data/VimeoTecoGAN/Bicubic4xLR.lmdb 16 | filter_file: ~ 17 | data_type: rgb 18 | gt_crop_size: 128 19 | batch_size_per_gpu: 2 20 | num_worker_per_gpu: 3 21 | pin_memory: true 22 | 23 | test: 24 | name: Vid4 25 | gt_seq_dir: data/Vid4/GT 26 | lr_seq_dir: data/Vid4/Bicubic4xLR 27 | filter_file: ~ 28 | num_worker_per_gpu: 3 29 | pin_memory: true 30 | 31 | 32 | # model configs 33 | model: 34 | name: FRVSR 35 | 36 | generator: 37 | name: FRNet # frame-recurrent network 38 | in_nc: 3 39 | out_nc: 3 40 | nf: 64 41 | nb: 10 42 | load_path: ~ 43 | 44 | 45 | # training settings 46 | train: 47 | tempo_extent: 10 48 | 49 | start_iter: 0 50 | total_iter: 400000 51 | 52 | # configs for generator 53 | generator: 54 | lr: !!float 1e-4 55 | lr_schedule: 56 | type: MultiStepLR 57 | milestones: [150000, 300000] 58 | gamma: 0.5 59 | betas: [0.9, 0.999] 60 | 61 | # other settings 62 | moving_first_frame: true 63 | moving_factor: 0.7 64 | 65 | # criterions 66 | pixel_crit: 67 | type: CB 68 | weight: 1 69 | reduction: mean 70 | 71 | warping_crit: 72 | type: CB 73 | weight: 1 74 | reduction: mean 75 | 76 | 77 | # validation configs 78 | test: 79 | test_freq: 10000 80 | 81 | # whether to save the generated SR results 82 | save_res: false 83 | res_dir: ~ # use default dir 84 | 85 | # whether to save the test results in a json file 86 | save_json: true 87 | json_dir: ~ # use default dir 88 | 89 | padding_mode: reflect 90 | num_pad_front: 5 91 | 92 | 93 | # metric configs 94 | metric: 95 | PSNR: 96 | colorspace: y 97 | 98 | 99 | # logger configs 100 | logger: 101 | log_freq: 100 102 | decay: 0.99 103 | ckpt_freq: 20000 104 | -------------------------------------------------------------------------------- /experiments_BI/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/test.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: false 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BI 11 | 12 | test1: 13 | name: Vid4 14 | gt_seq_dir: data/Vid4/GT 15 | lr_seq_dir: data/Vid4/Bicubic4xLR 16 | 
num_worker_per_gpu: 3 17 | pin_memory: true 18 | 19 | test2: 20 | name: ToS3 21 | gt_seq_dir: data/ToS3/GT 22 | lr_seq_dir: data/ToS3/Bicubic4xLR 23 | num_worker_per_gpu: 3 24 | pin_memory: true 25 | 26 | 27 | # model configs 28 | model: 29 | name: TecoGAN 30 | 31 | generator: 32 | name: FRNet # frame-recurrent network 33 | in_nc: 3 34 | out_nc: 3 35 | nf: 64 36 | nb: 10 37 | load_path: pretrained_models/TecoGAN_4x_BI_Vimeo_iter500K.pth 38 | 39 | 40 | # test configs 41 | test: 42 | # whether to save the SR results 43 | save_res: true 44 | res_dir: results 45 | 46 | # temporal padding 47 | padding_mode: reflect 48 | num_pad_front: 5 49 | -------------------------------------------------------------------------------- /experiments_BI/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/train.yml: -------------------------------------------------------------------------------- 1 | # basic configs 2 | scale: 4 3 | manual_seed: 0 4 | verbose: true 5 | 6 | 7 | # dataset configs 8 | dataset: 9 | degradation: 10 | type: BI 11 | 12 | train: 13 | name: VimeoTecoGAN 14 | gt_seq_dir: data/VimeoTecoGAN/GT.lmdb 15 | lr_seq_dir: data/VimeoTecoGAN/Bicubic4xLR.lmdb 16 | filter_file: ~ 17 | data_type: rgb 18 | gt_crop_size: 128 19 | batch_size_per_gpu: 2 20 | num_worker_per_gpu: 3 21 | pin_memory: true 22 | 23 | test: 24 | name: Vid4 25 | gt_seq_dir: data/Vid4/GT 26 | lr_seq_dir: data/Vid4/Bicubic4xLR 27 | filter_file: ~ 28 | num_worker_per_gpu: 4 29 | pin_memory: true 30 | 31 | 32 | # model configs 33 | model: 34 | name: TecoGAN 35 | 36 | generator: 37 | name: FRNet # frame-recurrent network 38 | in_nc: 3 39 | out_nc: 3 40 | nf: 64 41 | nb: 10 42 | load_path: pretrained_models/FRVSR_BI_iter400000.pth 43 | 44 | discriminator: 45 | name: STNet # spatio-temporal network 46 | in_nc: 3 47 | tempo_range: 3 48 | load_path: ~ 49 | 50 | 51 | # training configs 52 | train: 53 | tempo_extent: 10 54 | 55 | start_iter: 0 56 | total_iter: 500000 57 | 58 | # configs for generator 59 | generator: 60 | lr: !!float 5e-5 61 | lr_schedule: 62 | type: FixedLR 63 | betas: [0.9, 0.999] 64 | 65 | # configs for discriminator 66 | discriminator: 67 | update_policy: adaptive 68 | update_threshold: 0.4 69 | crop_border_ratio: 0.75 70 | lr: !!float 5e-5 71 | lr_schedule: 72 | type: FixedLR 73 | betas: [0.9, 0.999] 74 | 75 | # other configs 76 | moving_first_frame: true 77 | moving_factor: 0.7 78 | 79 | # criterions 80 | pixel_crit: 81 | type: CB 82 | weight: 1 83 | reduction: mean 84 | 85 | warping_crit: 86 | type: CB 87 | weight: 1 88 | reduction: mean 89 | 90 | feature_crit: 91 | type: CosineSimilarity 92 | weight: 0.2 93 | reduction: mean 94 | feature_layers: [8, 17, 26, 35] 95 | 96 | pingpong_crit: 97 | type: CB 98 | weight: 0.5 99 | reduction: mean 100 | 101 | gan_crit: 102 | type: GAN 103 | weight: 0.01 104 | reduction: mean 105 | 106 | 107 | # validation configs 108 | test: 109 | test_freq: 10000 110 | 111 | # whether to save the generated SR results 112 | save_res: false 113 | res_dir: ~ # use default dir 114 | 115 | # whether to save the test results in a json file 116 | save_json: true 117 | json_dir: ~ # use default dir 118 | 119 | padding_mode: reflect 120 | num_pad_front: 5 121 | 122 | 123 | # metric configs 124 | metric: 125 | PSNR: 126 | colorspace: y 127 | 128 | LPIPS: 129 | model: net-lin 130 | net: alex 131 | colorspace: rgb 132 | spatial: false 133 | version: 0.1 134 | 135 | tOF: 136 | colorspace: y 137 | 138 | 139 | # logger configs 140 | logger: 141 | log_freq: 100 142 | decay: 0.99 143 | ckpt_freq: 20000 144 | 145 | 
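**Note on the experiment configs.** The `train.yml` / `test.yml` files above are plain YAML: `~` denotes a null value (e.g. "use the default directory"), and learning rates are written with an explicit tag such as `!!float 5e-5` because PyYAML's default YAML 1.1 resolver would otherwise read a bare `5e-5` as a string. The snippet below is illustrative only (it is not part of the codebase) and simply shows how one of the configs listed above can be inspected with PyYAML:

```python
import yaml

# Load one of the experiment configs listed above. safe_load resolves the
# standard '!!float' tag, so 'lr' comes back as a Python float, and '~'
# becomes None.
with open('experiments_BI/TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU/train.yml') as f:
    opt = yaml.safe_load(f)

print(opt['scale'])                                 # 4
print(type(opt['train']['generator']['lr']))        # <class 'float'>  (5e-05)
print(opt['model']['discriminator']['load_path'])   # None  ('~' in YAML)
```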
-------------------------------------------------------------------------------- /pretrained_models/put_the_pretrained_models_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/pretrained_models/put_the_pretrained_models_here -------------------------------------------------------------------------------- /profile.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script is used to profile a model. 4 | 5 | # basic settings 6 | root_dir=. 7 | degradation=$1 8 | model=$2 9 | gpu_ids=0 10 | # specify the size of input data in the format of [color]x[height]x[width] 11 | lr_size=$3 12 | 13 | 14 | # run 15 | python ${root_dir}/codes/main.py \ 16 | --exp_dir ${root_dir}/experiments_${degradation}/${model} \ 17 | --mode profile \ 18 | --opt test.yml \ 19 | --gpu_ids ${gpu_ids} \ 20 | --lr_size ${lr_size} \ 21 | --test_speed 22 | 23 | -------------------------------------------------------------------------------- /resources/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/benchmark.png -------------------------------------------------------------------------------- /resources/bridge.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/bridge.gif -------------------------------------------------------------------------------- /resources/fire.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/fire.gif -------------------------------------------------------------------------------- /resources/foliage.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/foliage.gif -------------------------------------------------------------------------------- /resources/losses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/losses.png -------------------------------------------------------------------------------- /resources/metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/metrics.png -------------------------------------------------------------------------------- /resources/pond.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skycrapers/TecoGAN-PyTorch/903b070bd7dda27fb29111e39af837d589506f95/resources/pond.gif -------------------------------------------------------------------------------- /scripts/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | import glob 5 | import lmdb 6 | import pickle 7 | import random 8 | 9 | import numpy as np 10 | 
import cv2 11 | 12 | 13 | def create_lmdb(dataset, raw_dir, lmdb_dir, filter_file=''): 14 | assert dataset in ('VimeoTecoGAN', 'REDS'), f'Unknown Dataset: {dataset}' 15 | print(f'>> Start to create lmdb for {dataset}') 16 | 17 | # scan dir 18 | if filter_file: # use sequences specified by the filter_file 19 | with open(filter_file, 'r') as f: 20 | seq_idx_lst = sorted([line.strip() for line in f]) 21 | else: # use all found sequences 22 | seq_idx_lst = sorted(os.listdir(raw_dir)) 23 | 24 | num_seq = len(seq_idx_lst) 25 | print(f'>> Number of sequences: {num_seq}') 26 | 27 | # compute space to be allocated 28 | nbytes = 0 29 | for seq_idx in seq_idx_lst: 30 | frm_path_lst = sorted(glob.glob(osp.join(raw_dir, seq_idx, '*.png'))) 31 | nbytes_per_frm = cv2.imread(frm_path_lst[0], cv2.IMREAD_UNCHANGED).nbytes 32 | nbytes += len(frm_path_lst) * nbytes_per_frm 33 | alloc_size = round(2 * nbytes) 34 | print(f'>> Space required for lmdb generation: {alloc_size / (1 << 30):.2f} GB') 35 | 36 | # create lmdb environment 37 | env = lmdb.open(lmdb_dir, map_size=alloc_size) 38 | 39 | # write data to lmdb 40 | commit_freq = 5 41 | keys = [] 42 | txn = env.begin(write=True) 43 | for b, seq_idx in enumerate(seq_idx_lst): 44 | # log 45 | print(f' Processing sequence: {seq_idx} ({b + 1}/{num_seq})\r', end='') 46 | 47 | # get info 48 | frm_path_lst = sorted(glob.glob(osp.join(raw_dir, seq_idx, '*.png'))) 49 | n_frm = len(frm_path_lst) 50 | 51 | # read frames 52 | for i in range(n_frm): 53 | frm = cv2.imread(frm_path_lst[i], cv2.IMREAD_UNCHANGED) 54 | frm = np.ascontiguousarray(frm[..., ::-1]) # hwc|rgb|uint8 55 | 56 | h, w, c = frm.shape 57 | key = f'{seq_idx}_{n_frm}x{h}x{w}_{i:04d}' 58 | 59 | txn.put(key.encode('ascii'), frm) 60 | keys.append(key) 61 | 62 | # commit 63 | if b % commit_freq == 0: 64 | txn.commit() 65 | txn = env.begin(write=True) 66 | 67 | txn.commit() 68 | env.close() 69 | 70 | # create meta information 71 | meta_info = { 72 | 'name': dataset, 73 | 'color': 'RGB', 74 | 'keys': keys 75 | } 76 | pickle.dump(meta_info, open(osp.join(lmdb_dir, 'meta_info.pkl'), 'wb')) 77 | 78 | print(f'>> Finished lmdb generation for {dataset}') 79 | 80 | 81 | def check_lmdb(dataset, lmdb_dir): 82 | 83 | def visualize(win, img): 84 | cv2.namedWindow(win, 0) 85 | cv2.resizeWindow(win, img.shape[-2], img.shape[-3]) 86 | cv2.imshow(win, img[..., ::-1]) 87 | cv2.waitKey(0) 88 | cv2.destroyAllWindows() 89 | 90 | assert dataset in ('VimeoTecoGAN', 'REDS'), f'Unknown Dataset: {dataset}' 91 | print(f'>> Start to check lmdb dataset: {dataset}.lmdb') 92 | 93 | # load keys 94 | meta_info = pickle.load(open(osp.join(lmdb_dir, 'meta_info.pkl'), 'rb')) 95 | keys = meta_info['keys'] 96 | print(f'>> Number of keys: {len(keys)}') 97 | 98 | # randomly select frames for visualization 99 | with lmdb.open(lmdb_dir) as env: 100 | for i in range(3): # can be replaced to any number 101 | idx = random.randint(0, len(keys) - 1) 102 | key = keys[idx] 103 | 104 | # parse key 105 | key_lst = key.split('_') 106 | vid, sz, frm = '_'.join(key_lst[:-2]), key_lst[-2], key_lst[-1] 107 | sz = tuple(map(int, sz.split('x'))) 108 | sz = (*sz[1:], 3) 109 | print(f' Visualizing frame: #{frm} from sequence: {vid} (size: {sz})') 110 | 111 | with env.begin() as txn: 112 | buf = txn.get(key.encode('ascii')) 113 | val = np.frombuffer(buf, dtype=np.uint8).reshape(*sz) # hwc 114 | 115 | visualize(key, val) 116 | 117 | print(f'>> Finished lmdb checking for {dataset}') 118 | 119 | 120 | if __name__ == '__main__': 121 | # parse args 122 | parser = 
argparse.ArgumentParser() 123 | parser.add_argument('--dataset', type=str, required=True, 124 | help='VimeoTecoGAN | REDS') 125 | parser.add_argument('--raw_dir', type=str, required=True, 126 | help='Dir to the raw data') 127 | parser.add_argument('--lmdb_dir', type=str, required=True, 128 | help='Dir to the lmdb data') 129 | parser.add_argument('--filter_file', type=str, default='', 130 | help='File used to select sequences') 131 | args = parser.parse_args() 132 | 133 | # run 134 | if osp.exists(args.lmdb_dir): 135 | print(f'>> Dataset [{args.dataset}] already exists.') 136 | check_lmdb(args.dataset, args.lmdb_dir) 137 | else: 138 | create_lmdb(args.dataset, args.raw_dir, args.lmdb_dir, args.filter_file) 139 | check_lmdb(args.dataset, args.lmdb_dir) 140 | -------------------------------------------------------------------------------- /scripts/download/download_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script is used to download datasets for model evaluation. 4 | # Usage: 5 | # cd 6 | # bash ./scripts/download/download_datasets 7 | 8 | 9 | 10 | function download_large_file() { 11 | local DATA_DIR=$1 12 | local FID=$2 13 | local FNAME=$3 14 | 15 | wget --load-cookies ${DATA_DIR}/cookies.txt -O ${DATA_DIR}/${FNAME}.zip "https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ${DATA_DIR}/cookies.txt --keep-session-cookies --no-check-certificate "https://drive.google.com/uc?export=download&id=${FID}" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FID}" && rm -rf ${DATA_DIR}/cookies.txt 16 | } 17 | 18 | function download_small_file() { 19 | local DATA_DIR=$1 20 | local FID=$2 21 | local FNAME=$3 22 | 23 | wget --no-check-certificate -O ${DATA_DIR}/${FNAME}.zip "https://drive.google.com/uc?export=download&id=${FID}" 24 | } 25 | 26 | function check_md5() { 27 | local FPATH=$1 28 | local MD5=$2 29 | 30 | if [ ${MD5} != $(md5sum ${FPATH} | cut -d " " -f1) ]; then 31 | echo "!!! Fail to match MD5 sum for: ${FPATH}" 32 | echo "!!! Please try downloading it again" 33 | exit 1 34 | fi 35 | } 36 | 37 | 38 | # Vid4 GT 39 | if [ ! -d "./data/Vid4/GT" ]; then 40 | DATA_DIR="./data/Vid4" 41 | FID="1T8TuyyOxEUfXzCanH5kvNH2iA8nI06Wj" 42 | MD5="d2850eccf30092418f15afe4a7ea27e5" 43 | 44 | echo ">>> Start to download [Vid4 GT] dataset" 45 | 46 | mkdir -p ${DATA_DIR} 47 | download_large_file ${DATA_DIR} ${FID} GT 48 | check_md5 ${DATA_DIR}/GT.zip ${MD5} 49 | unzip ${DATA_DIR}/GT.zip -d ${DATA_DIR} && rm ${DATA_DIR}/GT.zip 50 | fi 51 | 52 | sleep 1s 53 | 54 | # ToS3 GT 55 | if [ ! -d "./data/ToS3/GT" ]; then 56 | DATA_DIR="./data/ToS3" 57 | FID="1XoR_NVBR-LbZOA8fXh7d4oPV0M8fRi8a" 58 | MD5="56eb9e8298a4e955d618c1658dfc89c9" 59 | 60 | echo ">>> Start to download [ToS3 GT] dataset" 61 | 62 | mkdir -p ${DATA_DIR} 63 | download_large_file ${DATA_DIR} ${FID} GT 64 | check_md5 ${DATA_DIR}/GT.zip ${MD5} 65 | unzip ${DATA_DIR}/GT.zip -d ${DATA_DIR} && rm ${DATA_DIR}/GT.zip 66 | fi 67 | 68 | 69 | if [ $1 == BD ]; then 70 | # Vid4 LR BD 71 | if [ ! 
-d "./data/Vid4/Gaussian4xLR" ]; then 72 | DATA_DIR="./data/Vid4" 73 | FID="1-5NFW6fEPUczmRqKHtBVyhn2Wge6j3ma" 74 | MD5="3b525cb0f10286743c76950d9949a255" 75 | 76 | echo ">>> Start to download [Vid4 LR] dataset (BD degradation)" 77 | 78 | download_small_file ${DATA_DIR} ${FID} Gaussian4xLR 79 | check_md5 ${DATA_DIR}/Gaussian4xLR.zip ${MD5} 80 | unzip ${DATA_DIR}/Gaussian4xLR.zip -d ${DATA_DIR} && rm ${DATA_DIR}/Gaussian4xLR.zip 81 | fi 82 | 83 | sleep 1s 84 | 85 | # ToS3 LR BD 86 | if [ ! -d "./data/ToS3/Gaussian4xLR" ]; then 87 | DATA_DIR="./data/ToS3" 88 | FID="1rDCe61kR-OykLyCo2Ornd2YgPnul2ffM" 89 | MD5="803609a12453a267eb9c78b68e073e81" 90 | 91 | echo ">>> Start to download [ToS3 LR] dataset (BD degradation)" 92 | 93 | download_large_file ${DATA_DIR} ${FID} Gaussian4xLR 94 | check_md5 ${DATA_DIR}/Gaussian4xLR.zip ${MD5} 95 | unzip ${DATA_DIR}/Gaussian4xLR.zip -d ${DATA_DIR} && rm ${DATA_DIR}/Gaussian4xLR.zip 96 | fi 97 | 98 | elif [ $1 == BI ]; then 99 | # Vid4 LR BI 100 | if [ ! -d "./data/Vid4/Bicubic4xLR" ]; then 101 | DATA_DIR="./data/Vid4" 102 | FID="1Kg0VBgk1r9I1c4f5ZVZ4sbfqtVRYub91" 103 | MD5="35666bd16ce582ae74fa935b3732ae1a" 104 | 105 | echo ">>> Start to download [Vid4 LR] dataset (BI degradation)" 106 | 107 | download_small_file ${DATA_DIR} ${FID} Bicubic4xLR 108 | check_md5 ${DATA_DIR}/Bicubic4xLR.zip ${MD5} 109 | unzip ${DATA_DIR}/Bicubic4xLR.zip -d ${DATA_DIR} && rm ${DATA_DIR}/Bicubic4xLR.zip 110 | fi 111 | 112 | sleep 1s 113 | 114 | # ToS3 LR BI 115 | if [ ! -d "./data/ToS3/Bicubic4xLR" ]; then 116 | DATA_DIR="./data/ToS3" 117 | FID="1FNuC0jajEjH9ycqDkH4cZQ3_eUqjxzzf" 118 | MD5="3b165ffc8819d695500cf565bf3a9ca2" 119 | 120 | echo ">>> Start to download [ToS3 LR] dataset (BI degradation)" 121 | 122 | download_large_file ${DATA_DIR} ${FID} Bicubic4xLR 123 | check_md5 ${DATA_DIR}/Bicubic4xLR.zip ${MD5} 124 | unzip ${DATA_DIR}/Bicubic4xLR.zip -d ${DATA_DIR} && rm ${DATA_DIR}/Bicubic4xLR.zip 125 | fi 126 | 127 | else 128 | echo Unknown Degradation Type: $1 \(Currently supported: \"BD\" or \"BI\"\) 129 | fi 130 | -------------------------------------------------------------------------------- /scripts/download/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script is used to download pretrained models 4 | # Usage: 5 | # cd 6 | # bash ./scripts/download/download_models 7 | 8 | 9 | function download_small_file() { 10 | local FPATH=$1 11 | local FID=$2 12 | local MD5=$3 13 | 14 | wget --no-check-certificate -O ${FPATH} "https://drive.google.com/uc?export=download&id=${FID}" 15 | 16 | if [ ${MD5} != $(md5sum ${FPATH} | cut -d " " -f1) ]; then 17 | echo "!!! Fail to match MD5 sum for: ${FPATH}" 18 | echo "!!! Please try downloading it again" 19 | exit 1 20 | fi 21 | } 22 | 23 | 24 | 25 | if [ $1 == BD -a $2 == TecoGAN ]; then 26 | FPATH="./pretrained_models/TecoGAN_BD_iter500000.pth" 27 | if [ ! -f $FPATH ]; then 28 | FID="13FPxKE6q7tuRrfhTE7GB040jBeURBj58" 29 | MD5="13d826c9f066538aea9340e8d3387289" 30 | 31 | echo "Start to download model [TecoGAN BD]" 32 | download_small_file ${FPATH} ${FID} ${MD5} 33 | fi 34 | 35 | elif [ $1 == BD -a $2 == FRVSR ]; then 36 | FPATH="./pretrained_models/FRVSR_BD_iter400000.pth" 37 | if [ ! 
-f $FPATH ]; then 38 | FID="11kPVS04a3B3k0SD-mKEpY_Q8WL7KrTIA" 39 | MD5="77d33c58b5cbf1fc68a1887be80ed18f" 40 | 41 | echo "Start to download model [FRVSR BD]" 42 | download_small_file ${FPATH} ${FID} ${MD5} 43 | fi 44 | 45 | elif [ $1 == BI -a $2 == TecoGAN ]; then 46 | FPATH="./pretrained_models/TecoGAN_BI_iter500000.pth" 47 | if [ ! -f $FPATH ]; then 48 | FID="1ie1F7wJcO4mhNWK8nPX7F0LgOoPzCwEu" 49 | MD5="4955b65b80f88456e94443d9d042d1e6" 50 | 51 | echo "Start to download model [TecoGAN BI]" 52 | download_small_file ${FPATH} ${FID} ${MD5} 53 | fi 54 | 55 | elif [ $1 == BI -a $2 == FRVSR ]; then 56 | FPATH="./pretrained_models/FRVSR_BI_iter400000.pth" 57 | if [ ! -f $FPATH ]; then 58 | FID="1wejMAFwIBde_7sz-H7zwlOCbCvjt3G9L" 59 | MD5="ad6337d934ec7ca72441082acd80c4ae" 60 | 61 | echo "Start to download model [FRVSR BI]" 62 | download_small_file ${FPATH} ${FID} ${MD5} 63 | fi 64 | 65 | else 66 | echo Unknown combination: $1, $2 67 | fi 68 | -------------------------------------------------------------------------------- /scripts/generate_lr_bi.m: -------------------------------------------------------------------------------- 1 | function generate_lr_bi() 2 | 3 | up_scale = 4; 4 | mod_scale = 4; 5 | idx = 0; 6 | filepaths = dir('./data/VimeoTecoGAN/Raw/*/*.png'); 7 | 8 | for i = 1 : length(filepaths) 9 | [~,imname,ext] = fileparts(filepaths(i).name); 10 | folder_path = filepaths(i).folder; 11 | save_lr_folder = strrep(folder_path, 'Raw', 'Bicubic4xLR') 12 | save_bi_folder = strrep(folder_path, 'Raw', 'Bicubic4xBI') 13 | if ~exist(save_lr_folder, 'dir') 14 | mkdir(save_lr_folder); 15 | end 16 | if ~exist(save_bi_folder, 'dir') 17 | mkdir(save_bi_folder); 18 | end 19 | if isempty(imname) 20 | disp('Ignore . folder.'); 21 | elseif strcmp(imname, '.') 22 | disp('Ignore .. 
folder.'); 23 | else 24 | idx = idx + 1; 25 | str_rlt = sprintf('%d\t%s.\n', idx, imname); 26 | fprintf(str_rlt); 27 | % read image 28 | img = imread(fullfile(folder_path, [imname, ext])); 29 | img = im2double(img); 30 | % modcrop 31 | img = modcrop(img, mod_scale); 32 | % LR 33 | im_LR = imresize(img, 1/up_scale, 'bicubic'); 34 | im_BI = imresize(im_LR, up_scale, 'bicubic'); 35 | if exist('save_lr_folder', 'var') 36 | imwrite(im_LR, fullfile(save_lr_folder, [imname, '.png'])); 37 | end 38 | if exist('save_bi_folder', 'var') 39 | imwrite(im_BI, fullfile(save_bi_folder, [imname, '.png'])); 40 | end 41 | end 42 | end 43 | end 44 | 45 | %% modcrop 46 | function img = modcrop(img, modulo) 47 | if size(img,3) == 1 48 | sz = size(img); 49 | sz = sz - mod(sz, modulo); 50 | img = img(1:sz(1), 1:sz(2)); 51 | else 52 | tmpsz = size(img); 53 | sz = tmpsz(1:2); 54 | sz = sz - mod(sz, modulo); 55 | img = img(1:sz(1), 1:sz(2),:); 56 | end 57 | end 58 | -------------------------------------------------------------------------------- /scripts/monitor_training.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import argparse 3 | import json 4 | import math 5 | import re 6 | import bisect 7 | 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | # -------------------- utility functions -------------------- # 12 | def append(loss_dict, loss_name, loss_value): 13 | if loss_name not in loss_dict: 14 | loss_dict[loss_name] = [loss_value] 15 | else: 16 | loss_dict[loss_name].append(loss_value) 17 | 18 | 19 | def split(pattern, string): 20 | return re.split(r'\s*{}\s*'.format(pattern), string) 21 | 22 | 23 | def parse_log(log_file): 24 | # define loss patterns 25 | loss_pattern = r'.*\[epoch:.*\| iter: (\d+).*\] (.*)' 26 | 27 | # load log file 28 | with open(log_file, 'r') as f: 29 | lines = [line.strip() for line in f] 30 | 31 | # parse log file 32 | loss_dict = {} # {'iter': [], 'loss1': [], 'loss2':[], ...} 33 | for line in lines: 34 | loss_match = re.match(loss_pattern, line) 35 | if loss_match: 36 | iter = int(loss_match.group(1)) 37 | append(loss_dict, 'iter', iter) 38 | for s in split(',', loss_match.group(2)): 39 | if s: 40 | k, v = split(':', s) 41 | append(loss_dict, k, float(v)) 42 | 43 | return loss_dict 44 | 45 | 46 | def parse_json(json_file): 47 | with open(json_file, 'r') as f: 48 | json_dict = json.load(f) 49 | 50 | metric_dict = {} 51 | for model_idx, metrics in json_dict.items(): 52 | append(metric_dict, 'iter', int(model_idx.replace('G_iter', ''))) 53 | for metric, val in metrics.items(): 54 | append(metric_dict, metric, float(val)) 55 | 56 | return metric_dict 57 | 58 | 59 | def plot_curve(ax, iter, value, style='-', alpha=1.0, label='', 60 | start_iter=0, end_iter=-1, smooth=0, linewidth=1.0): 61 | 62 | assert len(iter) == len(value), \ 63 | 'mismatch in and ({} vs {})'.format( 64 | len(iter), len(value)) 65 | l = len(iter) 66 | 67 | if smooth: 68 | for i in range(1, l): 69 | value[i] = smooth * value[i - 1] + (1 - smooth) * value[i] 70 | 71 | start_index = bisect.bisect_left(iter, start_iter) 72 | end_index = l if end_iter < 0 else bisect.bisect_right(iter, end_iter) 73 | ax.plot( 74 | iter[start_index:end_index], value[start_index:end_index], 75 | style, alpha=alpha, label=label, linewidth=linewidth) 76 | 77 | 78 | def plot_loss_curves(loss_dict, ax, loss_type, start_iter=0, end_iter=-1, 79 | smooth=0): 80 | 81 | for model_idx, model_loss_dict in loss_dict.items(): 82 | if loss_type in model_loss_dict: 83 | plot_curve( 84 | ax, 
model_loss_dict['iter'], model_loss_dict[loss_type], 85 | alpha=1.0, label=model_idx, start_iter=start_iter, 86 | end_iter=end_iter, smooth=smooth) 87 | ax.legend(loc='best', fontsize='small') 88 | ax.set_ylabel(loss_type) 89 | ax.set_xlabel('iteration') 90 | plt.grid(True) 91 | 92 | 93 | def plot_metric_curves(metric_dict, ax, metric_type, start_iter=0, end_iter=-1): 94 | """ currently can only plot average results 95 | """ 96 | 97 | for model_idx, model_metric_dict in metric_dict.items(): 98 | if metric_type in model_metric_dict: 99 | plot_curve( 100 | ax, model_metric_dict['iter'], model_metric_dict[metric_type], 101 | alpha=1.0, label=model_idx, start_iter=start_iter, 102 | end_iter=end_iter) 103 | ax.legend(loc='best', fontsize='small') 104 | ax.set_ylabel(metric_type) 105 | ax.set_xlabel('iteration') 106 | plt.grid(True) 107 | 108 | 109 | # -------------------- monitor -------------------- # 110 | def monitor(root_dir, testset, exp_id_lst, loss_lst, metric_lst): 111 | # ================ basic settings ================# 112 | start_iter = 0 113 | loss_smooth = 0 114 | 115 | # ================ parse logs ================# 116 | loss_dict = {} # {'model1': {'loss1': x1, ...}, ...} 117 | metric_dict = {} # {'model1': {'metric1': x1, ...}, ...} 118 | for exp_id in exp_id_lst: 119 | # parse log 120 | log_file = osp.join(root_dir, exp_id, 'train', 'train.log') 121 | if osp.exists(log_file): 122 | loss_dict[exp_id] = parse_log(log_file) 123 | 124 | # parse json 125 | json_file = osp.join( 126 | root_dir, exp_id, 'test', 'metrics', f'{testset}_avg.json') 127 | if osp.exists(json_file): 128 | metric_dict[exp_id] = parse_json(json_file) 129 | 130 | # ================ plot loss curves ================# 131 | n_loss = len(loss_lst) 132 | base_figsize = (12, 2 * math.ceil(n_loss / 2)) 133 | fig = plt.figure(1, figsize=base_figsize) 134 | for i in range(n_loss): 135 | ax = fig.add_subplot('{}{}{}'.format(math.ceil(n_loss / 2), 2, i + 1)) 136 | plot_loss_curves( 137 | loss_dict, ax, loss_lst[i], start_iter=start_iter, smooth=loss_smooth) 138 | 139 | # ================ plot metric curves ================# 140 | n_metric = len(metric_lst) 141 | base_figsize = (12, 2 * math.ceil(n_metric / 2)) 142 | fig = plt.figure(2, figsize=base_figsize) 143 | for i in range(n_metric): 144 | ax = fig.add_subplot('{}{}{}'.format(math.ceil(n_metric / 2), 2, i + 1)) 145 | plot_metric_curves( 146 | metric_dict, ax, metric_lst[i], start_iter=start_iter) 147 | 148 | plt.show() 149 | 150 | 151 | if __name__ == '__main__': 152 | # parse args 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument('--degradation', '-dg', type=str, required=True) 155 | parser.add_argument('--model', '-m', type=str, required=True) 156 | parser.add_argument('--dataset', '-ds', type=str, required=True) 157 | args = parser.parse_args() 158 | 159 | # select model 160 | root_dir = '.' 
161 | if 'FRVSR' in args.model: 162 | # select experiments 163 | exp_id_lst = [ 164 | # experiment dirs 165 | f'experiments_{args.degradation}/{args.model}', 166 | ] 167 | 168 | # select losses 169 | loss_lst = [ 170 | 'l_pix_G', # pixel loss 171 | 'l_warp_G', # warping loss 172 | ] 173 | 174 | # select metrics 175 | metric_lst = [ 176 | 'PSNR', 177 | ] 178 | 179 | elif 'TecoGAN' in args.model: 180 | # select experiments 181 | exp_id_lst = [ 182 | # experiment dirs 183 | f'experiments_{args.degradation}/{args.model}', 184 | ] 185 | 186 | # select losses 187 | loss_lst = [ 188 | 'l_pix_G', # pixel loss 189 | 'l_warp_G', # warping loss 190 | 'l_feat_G', # perceptual loss 191 | 'l_gan_G', # generator loss 192 | 'l_gan_D', # discriminator loss 193 | 'p_real_D', 194 | 'p_fake_D', 195 | ] 196 | 197 | # select metrics 198 | metric_lst = [ 199 | 'PSNR', 200 | 'LPIPS', 201 | 'tOF' 202 | ] 203 | 204 | else: 205 | raise ValueError(f'Unrecognized model: {args.model}') 206 | 207 | monitor(root_dir, args.dataset, exp_id_lst, loss_lst, metric_lst) 208 | -------------------------------------------------------------------------------- /scripts/resize_bd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import glob 4 | from multiprocessing import Pool 5 | 6 | import numpy as np 7 | import cv2 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from codes.utils.data_utils import float32_to_uint8 12 | 13 | 14 | """ Note: 15 | There are two implementations (OpenCV's GaussianBlur and PyTorch's conv2d) 16 | for generating BD degraded LR data. Although there's only a slight numerical 17 | difference between these two methods, it's recommended to use the latter one 18 | since it is adopted in model training. 
19 | """ 20 | 21 | # setup params 22 | # default settings 23 | scale = 4 # downsampling scale 24 | sigma = 1.5 25 | ksize = 1 + 2 * int(sigma * 3.0) 26 | # to be modified 27 | n_process = 16 # the number of process to be used for downsampling 28 | filepaths = glob.glob('../data/Vid4/GT/*/*.png') 29 | gt_dir_idx, lr_dir_idx = 'GT', 'Gaussian4xLR' 30 | 31 | 32 | def down_opencv(img, sigma, ksize, scale): 33 | blur_img = cv2.GaussianBlur(img, (ksize, ksize), sigmaX=sigma) # hwc|uint8 34 | lr_img = blur_img[::scale, ::scale].astype(np.float32) / 255.0 35 | return lr_img # hwc|float32 36 | 37 | 38 | def down_pytorch(img, sigma, ksize, scale): 39 | img = np.ascontiguousarray(img) 40 | img = torch.FloatTensor(img).unsqueeze(0).permute(0, 3, 1, 2) / 255.0 # nchw 41 | 42 | gaussian_filters = create_kernel(sigma, ksize) 43 | 44 | filters_h, filters_w = gaussian_filters.shape[-2:] 45 | pad_h, pad_w = filters_h - 1, filters_w - 1 46 | 47 | pad_t = pad_h // 2 48 | pad_b = pad_h - pad_t 49 | pad_l = pad_w // 2 50 | pad_r = pad_w - pad_l 51 | 52 | img = F.pad(img, (pad_l, pad_r, pad_t, pad_b), 'reflect') 53 | 54 | lr_img = F.conv2d( 55 | img, gaussian_filters, stride=scale, bias=None, padding=0) 56 | 57 | return lr_img[0].permute(1, 2, 0).numpy() # hwc|float32 58 | 59 | 60 | def downsample_worker(filepath): 61 | # log 62 | print('Processing {}'.format(filepath)) 63 | 64 | # setup dirs 65 | gt_folder, img_idx = osp.split(filepath) 66 | lr_folder = gt_folder.replace(gt_dir_idx, lr_dir_idx) 67 | 68 | # read image 69 | img = cv2.imread(filepath) # hwc|bgr|uint8 70 | 71 | # dowmsample 72 | # img_lr = down_opencv(img, sigma, ksize, scale) 73 | img_lr = down_pytorch(img, sigma, ksize, scale) 74 | img_lr = float32_to_uint8(img_lr) 75 | 76 | # save image 77 | cv2.imwrite(osp.join(lr_folder, img_idx), img_lr) 78 | 79 | 80 | if __name__ == '__main__': 81 | # setup dirs 82 | print('# of images: {}'.format(len(filepaths))) 83 | for filepath in filepaths: 84 | gt_folder, _ = osp.split(filepath) 85 | lr_folder = gt_folder.replace(gt_dir_idx, lr_dir_idx) 86 | if not osp.exists(lr_folder): 87 | os.makedirs(lr_folder) 88 | 89 | # for each image 90 | pool = Pool(n_process) 91 | for filepath in sorted(filepaths): 92 | pool.apply_async(downsample_worker, args=(filepath,)) 93 | pool.close() 94 | pool.join() 95 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script is used to evaluate a pretrained model. 4 | 5 | # basic settings 6 | root_dir=. 7 | degradation=$1 8 | model=$2 9 | gpu_ids=2,3 10 | master_port=4322 11 | 12 | 13 | # run 14 | num_gpus=`echo ${gpu_ids} | awk -F\, '{print NF}'` 15 | if [[ ${num_gpus} > 1 ]]; then 16 | dist_args="-m torch.distributed.launch --nproc_per_node ${num_gpus} --master_port ${master_port}" 17 | fi 18 | 19 | CUDA_VISIBLE_DEVICES=${gpu_ids} \ 20 | python ${dist_args} ${root_dir}/codes/main.py \ 21 | --exp_dir ${root_dir}/experiments_${degradation}/${model} \ 22 | --mode test \ 23 | --opt test.yml \ 24 | --gpu_ids ${gpu_ids} 25 | 26 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script is used to train a model. 4 | 5 | # basic settings 6 | root_dir=. 
7 | degradation=$1 8 | model=$2 9 | gpu_ids=0,1 # set to -1 to use cpu 10 | master_port=4321 11 | 12 | debug=0 13 | 14 | 15 | # resume training or train from scratch 16 | start_iter=0 17 | if [[ ${start_iter} > 0 ]]; then 18 | suffix=_iter${start_iter} 19 | else 20 | suffix='' 21 | fi 22 | 23 | 24 | exp_dir=${root_dir}/experiments_${degradation}/${model} 25 | # check 26 | if [ -d "$exp_dir/train" ]; then 27 | echo ">> Experiment dir already exists: $exp_dir/train" 28 | echo ">> Please delete it for retraining" 29 | exit 1 30 | fi 31 | # make dir 32 | mkdir -p ${exp_dir}/train 33 | 34 | 35 | # backup codes 36 | if [[ ${debug} > 0 ]]; then 37 | cp -r ${root_dir}/codes ${exp_dir}/train/codes_backup${suffix} 38 | fi 39 | 40 | 41 | # run 42 | num_gpus=`echo ${gpu_ids} | awk -F\, '{print NF}'` 43 | if [[ ${num_gpus} > 1 ]]; then 44 | dist_args="-m torch.distributed.launch --nproc_per_node ${num_gpus} --master_port ${master_port}" 45 | fi 46 | 47 | CUDA_VISIBLE_DEVICES=${gpu_ids} \ 48 | python ${dist_args} ${root_dir}/codes/main.py \ 49 | --exp_dir ${exp_dir} \ 50 | --mode train \ 51 | --opt train${suffix}.yml \ 52 | --gpu_ids ${gpu_ids} \ 53 | > ${exp_dir}/train/train${suffix}.log 2>&1 & 54 | 55 | --------------------------------------------------------------------------------
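**Usage note and BD degradation sketch.** Judging from how `train.sh` and `test.sh` build `exp_dir=${root_dir}/experiments_${degradation}/${model}`, a typical invocation presumably looks like `bash ./train.sh BD TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU` (and `bash ./test.sh BD TecoGAN/TecoGAN_VimeoTecoGAN_4xSR_2GPU` for evaluation); single- vs. multi-GPU launching is derived automatically from the `gpu_ids` variable inside the scripts. As a companion to `scripts/resize_bd.py`, the sketch below shows a minimal, self-contained BD degradation (Gaussian blur followed by down-sampling with `sigma: 1.5` and scale 4, matching the BD configs above). It builds its own isotropic Gaussian kernel because the repo's `create_kernel` helper is not shown in this dump, so the kernel construction and the depthwise-convolution formulation are assumptions for illustration, not the exact training-time code.

```python
import numpy as np
import torch
import torch.nn.functional as F


def gaussian_kernel(sigma=1.5, ksize=None):
    # Isotropic Gaussian kernel; kernel size follows the convention used in
    # resize_bd.py above: ksize = 1 + 2 * int(sigma * 3.0).
    if ksize is None:
        ksize = 1 + 2 * int(sigma * 3.0)
    ax = np.arange(ksize, dtype=np.float32) - (ksize - 1) / 2.0
    g1d = np.exp(-0.5 * (ax / sigma) ** 2)
    k2d = np.outer(g1d, g1d)
    return torch.from_numpy(k2d / k2d.sum()).float()  # (ksize, ksize), sums to 1


def bd_downsample(img_uint8, scale=4, sigma=1.5):
    # img_uint8: HxWxC uint8 image -> BD-degraded LR image (uint8, H/scale x W/scale x C).
    img = torch.from_numpy(np.ascontiguousarray(img_uint8)).float() / 255.0
    img = img.permute(2, 0, 1).unsqueeze(0)                  # 1xCxHxW
    k = gaussian_kernel(sigma)
    c = img.shape[1]
    weight = k.view(1, 1, *k.shape).repeat(c, 1, 1, 1)       # one depthwise filter per channel
    pad = (k.shape[-1] - 1) // 2
    img = F.pad(img, (pad, pad, pad, pad), mode='reflect')   # same padding scheme as down_pytorch
    lr = F.conv2d(img, weight, stride=scale, groups=c)       # blur + downsample in one strided conv
    lr = lr.clamp(0.0, 1.0)[0].permute(1, 2, 0).numpy()      # HxWxC, float32 in [0, 1]
    return (lr * 255.0).round().astype(np.uint8)


# Hypothetical usage (channel order does not matter for an isotropic blur):
#   lr = bd_downsample(cv2.imread('frame.png'))  # BGR in, BGR out
```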