├── .gitignore ├── LICENSE ├── README.md ├── _config.yml ├── config.py ├── data ├── __init__.py ├── snufilm.py ├── video.py └── vimeo90k.py ├── eval.sh ├── figures ├── CAIN_AAAI20_poster.pdf ├── CAIN_paper_thumb.jpg ├── CAIN_spotlight_thumb.jpg ├── overall_architecture.png └── qualitative_vimeo.png ├── generate.py ├── loss.py ├── main.py ├── model ├── __init__.py ├── cain.py ├── cain_encdec.py ├── cain_noca.py └── common.py ├── pytorch_msssim └── __init__.py ├── run.sh ├── run_noca.sh ├── test_custom.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore Git here 2 | .git 3 | 4 | # But not these files... 5 | # !.gitignore 6 | 7 | checkpoint/* 8 | logs/* 9 | data/vimeo_triplet 10 | 11 | 12 | # Created by .ignore support plugin (hsz.mobi) 13 | ### Python template 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *,cover 59 | .hypothesis/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # IPython Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # dotenv 92 | .env 93 | 94 | # virtualenv 95 | venv/ 96 | ENV/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | ### VirtualEnv template 104 | # Virtualenv 105 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 106 | .Python 107 | [Bb]in 108 | [Ii]nclude 109 | [Ll]ib 110 | [Ll]ib64 111 | [Ll]ocal 112 | [Ss]cripts 113 | pyvenv.cfg 114 | .venv 115 | pip-selfcheck.json 116 | ### JetBrains template 117 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 118 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 119 | 120 | # User-specific stuff: 121 | .idea/workspace.xml 122 | .idea/tasks.xml 123 | .idea/dictionaries 124 | .idea/vcs.xml 125 | .idea/jsLibraryMappings.xml 126 | 127 | # Sensitive or high-churn files: 128 | .idea/dataSources.ids 129 | .idea/dataSources.xml 130 | .idea/dataSources.local.xml 131 | .idea/sqlDataSources.xml 132 | .idea/dynamic.xml 133 | .idea/uiDesigner.xml 134 | 135 | # Gradle: 136 | .idea/gradle.xml 137 | .idea/libraries 138 | 139 | # Mongo Explorer plugin: 140 | .idea/mongoSettings.xml 141 | 142 | .idea/ 143 | 144 | ## File-based project format: 145 | *.iws 146 | 147 | ## Plugin-specific files: 148 | 149 | # IntelliJ 150 | /out/ 151 | 152 | # mpeltonen/sbt-idea plugin 153 | .idea_modules/ 154 | 155 | # JIRA 
plugin 156 | atlassian-ide-plugin.xml 157 | 158 | # Crashlytics plugin (for Android Studio and IntelliJ) 159 | com_crashlytics_export_strings.xml 160 | crashlytics.properties 161 | crashlytics-build.properties 162 | fabric.properties 163 | 164 | *.swp 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Myungsub Choi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Channel Attention Is All You Need for Video Frame Interpolation 2 | 3 | #### Myungsub Choi, Heewon Kim, Bohyung Han, Ning Xu, Kyoung Mu Lee 4 | 5 | #### 2nd place in [[AIM 2019 ICCV Workshop](http://www.vision.ee.ethz.ch/aim19/)] - Video Temporal Super-Resolution Challenge 6 | 7 | [Project](https://myungsub.github.io/CAIN) | [Paper-AAAI](https://aaai.org/ojs/index.php/AAAI/article/view/6693/6547) (Download the paper [[here](https://www.dropbox.com/s/b62wnroqdd5lhfc/AAAI-ChoiM.4773.pdf?dl=0)] in case the AAAI link is broken) | [Poster](https://www.dropbox.com/s/7lxwka16qkuacvh/AAAI-ChoiM.4773.pdf) 8 | 9 | Paper 10 | 11 | 12 | 13 | ## Directory Structure 14 | 15 | ``` text 16 | project 17 | │ README.md 18 | | run.sh - main script to train CAIN model 19 | | run_noca.sh - script to train CAIN_NoCA model 20 | | test_custom.sh - script to run interpolation on custom dataset 21 | | eval.sh - script to evaluate on SNU-FILM benchmark 22 | | main.py - main file to run train/val 23 | | config.py - check & change training/testing configurations here 24 | | loss.py - defines different loss functions 25 | | utils.py - misc. 26 | └───model 27 | │ │ common.py 28 | │ │ cain.py - main model 29 | | | cain_noca.py - model without channel attention 30 | | | cain_encdec.py - model with additional encoder-decoder 31 | └───data - implements dataloaders for each dataset 32 | │ | vimeo90k.py - main training / testing dataset 33 | | | video.py - custom data for testing 34 | │ └───symbolic links to each dataset 35 | | | ... 
36 | ``` 37 | 38 | ## Dependencies 39 | 40 | Current version is tested on: 41 | 42 | - Ubuntu 18.04 43 | - Python==3.7.5 44 | - numpy==1.17 45 | - [PyTorch](http://pytorch.org/)==1.3.1, torchvision==0.4.2, cudatoolkit==10.1 46 | - tensorboard==2.0.0 (If you want training logs) 47 | - opencv==3.4.2 48 | - tqdm==4.39.0 49 | 50 | ``` text 51 | # Easy installation (using Anaconda environment) 52 | conda create -n cain 53 | conda activate cain 54 | conda install python=3.7 55 | conda install pip numpy 56 | conda install pytorch torchvision cudatoolkit=10.1 -c pytorch 57 | conda install tqdm opencv tensorboard 58 | ``` 59 | 60 | ## Model 61 | 62 | <center><img src="./figures/overall_architecture.png"></center>
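For a quick sanity check of the model interface, here is a minimal usage sketch (not part of the original scripts; it assumes the repo root is on `PYTHONPATH`, a CUDA device, and the pretrained checkpoint from the Usage section below saved as `pretrained_cain.pth`):

``` python
# Minimal sketch of the CAIN interface (assumptions: repo root on PYTHONPATH,
# CUDA available, pretrained checkpoint downloaded as pretrained_cain.pth).
import torch
from model.cain import CAIN

# Checkpoints are saved from a DataParallel-wrapped model, so wrap before loading.
model = torch.nn.DataParallel(CAIN(depth=3)).cuda()
ckpt = torch.load('pretrained_cain.pth')
model.load_state_dict(ckpt['state_dict'])
model.eval()

# Two consecutive frames, scaled to [0, 1]; at test time the model reflection-pads
# inputs internally to a multiple of 128, so arbitrary sizes work.
im1 = torch.rand(1, 3, 256, 448).cuda()
im2 = torch.rand(1, 3, 256, 448).cuda()
with torch.no_grad():
    mid, _ = model(im1, im2)  # the frame interpolated halfway between im1 and im2
```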
63 | 64 | ## Dataset Preparation 65 | 66 | - We use the **[Vimeo90K Triplet dataset](http://toflow.csail.mit.edu/)** for training + testing. 67 | - After downloading the full dataset, make symbolic links in the `data/` folder: 68 | - `ln -s /path/to/vimeo_triplet_data/ ./data/vimeo_triplet` 69 | - Then you're done! 70 | - For more thorough evaluation, we built the **[SNU-FILM (SNU Frame Interpolation with Large Motion)](https://myungsub.github.io/CAIN)** benchmark. 71 | - Download links can be found on the [project page](https://myungsub.github.io/CAIN). 72 | - Make symbolic links after downloading as well: 73 | - `ln -s /path/to/SNU-FILM_data/ ./data/SNU-FILM` 74 | - Done! 75 | 76 | ## Usage 77 | 78 | #### Training / Testing with Vimeo90K dataset 79 | - First, make symbolic links in the `data/` folder: `ln -s /path/to/vimeo_triplet_data/ ./data/vimeo_triplet` 80 | - [Vimeo90K dataset](http://toflow.csail.mit.edu/) 81 | - For training: `CUDA_VISIBLE_DEVICES=0 python main.py --exp_name EXPNAME --batch_size 16 --test_batch_size 16 --dataset vimeo90k --model cain --loss 1*L1 --max_epoch 200 --lr 0.0002` 82 | - Or, just run `./run.sh` 83 | - For testing performance on the Vimeo90K dataset, just add the `--mode test` option 84 | - For testing on the SNU-FILM dataset, run `./eval.sh` 85 | - The test mode (choose from ['easy', 'medium', 'hard', 'extreme']) can be modified by changing the `--test_mode` option in `eval.sh`. 86 | 87 | #### Interpolating with custom video 88 | - Download the pretrained model from [[Here](https://www.dropbox.com/s/y1xf46m2cbwk7yf/pretrained_cain.pth?dl=0)] 89 | - Prepare frame sequences in `data/frame_seq` 90 | - Run `test_custom.sh` 91 | 92 | ## Results 93 | 94 | <center><img src="./figures/qualitative_vimeo.png"></center>
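PSNR and SSIM are the metrics computed by `utils.eval_metrics` during testing; a minimal standalone sketch using this repo's `pytorch_msssim` module is below (assumptions: prediction and ground-truth tensors of shape (B, 3, H, W) in [0, 1], on GPU, since `pytorch_msssim` builds its Gaussian window with `.cuda()`):

``` python
# Standalone metric sketch (assumption: out/gt are (B, 3, H, W) tensors in [0, 1] on GPU).
import torch
import torch.nn.functional as F
from pytorch_msssim import ssim

def psnr(out, gt):
    # Peak signal-to-noise ratio for images in [0, 1]: -10 * log10(MSE)
    return -10 * torch.log10(F.mse_loss(out, gt))

# Example, reusing `mid` from the sketch in the Model section and a ground-truth frame `gt`:
# print('PSNR: %.2f, SSIM: %.4f' % (psnr(mid, gt).item(), ssim(mid, gt, val_range=1.).item()))
```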
95 | 96 | ### Video 97 | 98 | Video 99 | 100 | ## Citation 101 | 102 | If you find this code useful for your research, please consider citing the following paper: 103 | 104 | ``` text 105 | @inproceedings{choi2020cain, 106 | author = {Choi, Myungsub and Kim, Heewon and Han, Bohyung and Xu, Ning and Lee, Kyoung Mu}, 107 | title = {Channel Attention Is All You Need for Video Frame Interpolation}, 108 | booktitle = {AAAI}, 109 | year = {2020} 110 | } 111 | ``` 112 | 113 | ## Acknowledgement 114 | 115 | Many parts of this code are adapted from: 116 | 117 | - [EDSR-Pytorch](https://github.com/thstkdgus35/EDSR-PyTorch) 118 | - [RCAN](https://github.com/yulunzhang/RCAN) 119 | 120 | We thank the authors for sharing the code for their great work. 121 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | arg_lists = [] 4 | parser = argparse.ArgumentParser() 5 | 6 | def str2bool(v): 7 | return v.lower() in ('true',) 8 | 9 | def add_argument_group(name): 10 | arg = parser.add_argument_group(name) 11 | arg_lists.append(arg) 12 | return arg 13 | 14 | # Dataset 15 | data_arg = add_argument_group('Dataset') 16 | data_arg.add_argument('--dataset', type=str, default='vimeo90k') 17 | data_arg.add_argument('--num_frames', type=int, default=3) 18 | data_arg.add_argument('--data_root', type=str, default='data/vimeo_triplet') 19 | data_arg.add_argument('--img_fmt', type=str, default='png') 20 | 21 | # Model 22 | model_arg = add_argument_group('Model') 23 | model_arg.add_argument('--model', type=str, default='CAIN') 24 | model_arg.add_argument('--depth', type=int, default=3, help='# of pooling') 25 | model_arg.add_argument('--n_resblocks', type=int, default=12) 26 | model_arg.add_argument('--up_mode', type=str, default='shuffle') 27 | 28 | # Training / test parameters 29 | learn_arg = add_argument_group('Learning') 30 | learn_arg.add_argument('--mode', type=str, default='train', 31 | choices=['train', 'test', 'test-multi', 'gen-multi']) 32 | learn_arg.add_argument('--loss', type=str, default='1*L1') 33 | learn_arg.add_argument('--lr', type=float, default=1e-4) 34 | learn_arg.add_argument('--beta1', type=float, default=0.9) 35 | learn_arg.add_argument('--beta2', type=float, default=0.99) 36 | learn_arg.add_argument('--batch_size', type=int, default=16) 37 | learn_arg.add_argument('--val_batch_size', type=int, default=4) 38 | learn_arg.add_argument('--test_batch_size', type=int, default=1) 39 | learn_arg.add_argument('--test_mode', type=str, default='hard', help='Test mode to evaluate on SNU-FILM dataset') 40 | learn_arg.add_argument('--start_epoch', type=int, default=0) 41 | learn_arg.add_argument('--max_epoch', type=int, default=200) 42 | learn_arg.add_argument('--resume', action='store_true') 43 | learn_arg.add_argument('--resume_exp', type=str, default=None) 44 | learn_arg.add_argument('--fix_loaded', action='store_true', help='whether to freeze (stop updating) the loaded parts of the model') 45 | 46 | # Misc 47 | misc_arg = add_argument_group('Misc') 48 | misc_arg.add_argument('--exp_name', type=str, default='exp') 49 | misc_arg.add_argument('--log_iter', type=int, default=20) 50 | misc_arg.add_argument('--log_dir', type=str, default='logs') 51 | 
misc_arg.add_argument('--data_dir', type=str, default='data') 52 | misc_arg.add_argument('--num_gpu', type=int, default=1) 53 | misc_arg.add_argument('--random_seed', type=int, default=12345) 54 | misc_arg.add_argument('--num_workers', type=int, default=5) 55 | misc_arg.add_argument('--use_tensorboard', action='store_true') 56 | misc_arg.add_argument('--viz', action='store_true', help='whether to save images') 57 | misc_arg.add_argument('--lpips', action='store_true', help='evaluates LPIPS if set true') 58 | 59 | def get_args(): 60 | """Parses all of the arguments above 61 | """ 62 | args, unparsed = parser.parse_known_args() 63 | if args.num_gpu > 0: 64 | setattr(args, 'cuda', True) 65 | else: 66 | setattr(args, 'cuda', False) 67 | if len(unparsed) > 1: 68 | print("Unparsed args: {}".format(unparsed)) 69 | return args, unparsed 70 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/data/__init__.py -------------------------------------------------------------------------------- /data/snufilm.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import transforms 6 | from PIL import Image 7 | 8 | class SNUFILM(Dataset): 9 | def __init__(self, data_root, mode='hard'): 10 | ''' 11 | :param data_root: ./data/SNU-FILM 12 | :param mode: ['easy', 'medium', 'hard', 'extreme'] 13 | ''' 14 | test_root = os.path.join(data_root, 'test') 15 | test_fn = os.path.join(data_root, 'test-%s.txt' % mode) 16 | with open(test_fn, 'r') as f: 17 | self.frame_list = f.read().splitlines() 18 | self.frame_list = [v.split(' ') for v in self.frame_list] 19 | 20 | self.transforms = transforms.Compose([ 21 | transforms.ToTensor() 22 | ]) 23 | 24 | print("[%s] Test dataset has %d triplets" % (mode, len(self.frame_list))) 25 | 26 | 27 | def __getitem__(self, index): 28 | 29 | # Use self.test_all_images: 30 | imgpaths = self.frame_list[index] 31 | 32 | img1 = Image.open(imgpaths[0]) 33 | img2 = Image.open(imgpaths[1]) 34 | img3 = Image.open(imgpaths[2]) 35 | 36 | img1 = self.transforms(img1) 37 | img2 = self.transforms(img2) 38 | img3 = self.transforms(img3) 39 | 40 | imgs = [img1, img2, img3] 41 | 42 | return imgs, imgpaths 43 | 44 | def __len__(self): 45 | return len(self.frame_list) 46 | 47 | 48 | def check_already_extracted(vid): 49 | return bool(os.path.exists(vid + '/0001.png')) 50 | 51 | 52 | def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode='hard'): 53 | # data_root = 'data/SNUFILM' 54 | dataset = SNUFILM(data_root, mode=test_mode) 55 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) 56 | -------------------------------------------------------------------------------- /data/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | from torchvision import transforms 7 | from PIL import Image 8 | 9 | class Video(Dataset): 10 | def __init__(self, data_root, fmt='png'): 11 | images = sorted(glob.glob(os.path.join(data_root, '*.%s' % fmt))) 12 | for im in images: 13 | try: 14 | float_ind = float(im.split('_')[-1][:-4]) 15 | 
except ValueError: 16 | os.rename(im, '%s_%.06f.%s' % (im[:-4], 0.0, fmt)) 17 | # re-list the images after renaming 18 | images = sorted(glob.glob(os.path.join(data_root, '*.%s' % fmt))) 19 | self.imglist = [[images[i], images[i+1]] for i in range(len(images)-1)] 20 | print('[%d] images ready to be loaded' % len(self.imglist)) 21 | 22 | 23 | def __getitem__(self, index): 24 | imgpaths = self.imglist[index] 25 | 26 | # Load images 27 | img1 = Image.open(imgpaths[0]) 28 | img2 = Image.open(imgpaths[1]) 29 | 30 | T = transforms.ToTensor() 31 | img1 = T(img1) 32 | img2 = T(img2) 33 | 34 | imgs = [img1, img2] 35 | meta = {'imgpath': imgpaths} 36 | return imgs, meta 37 | 38 | def __len__(self): 39 | return len(self.imglist) 40 | 41 | 42 | def get_loader(mode, data_root, batch_size, img_fmt='png', shuffle=False, num_workers=0, n_frames=1): 43 | if mode == 'train': 44 | is_training = True 45 | else: 46 | is_training = False 47 | dataset = Video(data_root, fmt=img_fmt) 48 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) 49 | -------------------------------------------------------------------------------- /data/vimeo90k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import transforms 6 | from PIL import Image 7 | import random 8 | 9 | class VimeoTriplet(Dataset): 10 | def __init__(self, data_root, is_training): 11 | self.data_root = data_root 12 | self.image_root = os.path.join(self.data_root, 'sequences') 13 | self.training = is_training 14 | 15 | train_fn = os.path.join(self.data_root, 'tri_trainlist.txt') 16 | test_fn = os.path.join(self.data_root, 'tri_testlist.txt') 17 | with open(train_fn, 'r') as f: 18 | self.trainlist = f.read().splitlines() 19 | with open(test_fn, 'r') as f: 20 | self.testlist = f.read().splitlines() 21 | 22 | self.transforms = transforms.Compose([ 23 | transforms.RandomCrop(256), 24 | transforms.RandomHorizontalFlip(0.5), 25 | transforms.RandomVerticalFlip(0.5), 26 | transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), 27 | transforms.ToTensor() 28 | ]) 29 | 30 | 31 | def __getitem__(self, index): 32 | if self.training: 33 | imgpath = os.path.join(self.image_root, self.trainlist[index]) 34 | else: 35 | imgpath = os.path.join(self.image_root, self.testlist[index]) 36 | imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png'] 37 | 38 | # Load images 39 | img1 = Image.open(imgpaths[0]) 40 | img2 = Image.open(imgpaths[1]) 41 | img3 = Image.open(imgpaths[2]) 42 | 43 | # Data augmentation 44 | if self.training: 45 | seed = random.randint(0, 2**32) 46 | random.seed(seed) 47 | img1 = self.transforms(img1) 48 | random.seed(seed) 49 | img2 = self.transforms(img2) 50 | random.seed(seed) 51 | img3 = self.transforms(img3) 52 | # Random Temporal Flip 53 | if random.random() >= 0.5: 54 | img1, img3 = img3, img1 55 | imgpaths[0], imgpaths[2] = imgpaths[2], imgpaths[0] 56 | else: 57 | T = transforms.ToTensor() 58 | img1 = T(img1) 59 | img2 = T(img2) 60 | img3 = T(img3) 61 | 62 | imgs = [img1, img2, img3] 63 | 64 | return imgs, imgpaths 65 | 66 | def __len__(self): 67 | if self.training: 68 | return len(self.trainlist) 69 | else: 70 | return len(self.testlist) 71 | 72 | 73 | 74 | def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode=None): 75 | if mode == 'train': 76 | is_training = True 77 | else: 78 | is_training = False 79 | dataset = 
VimeoTriplet(data_root, is_training=is_training) 80 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) 81 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=1 python main.py \ 4 | --exp_name CAIN_eval \ 5 | --dataset snufilm \ 6 | --data_root data/SNU-FILM \ 7 | --test_batch_size 1 \ 8 | --model cain \ 9 | --depth 3 \ 10 | --mode test \ 11 | --resume \ 12 | --resume_exp CAIN_train \ 13 | --test_mode hard -------------------------------------------------------------------------------- /figures/CAIN_AAAI20_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/figures/CAIN_AAAI20_poster.pdf -------------------------------------------------------------------------------- /figures/CAIN_paper_thumb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/figures/CAIN_paper_thumb.jpg -------------------------------------------------------------------------------- /figures/CAIN_spotlight_thumb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/figures/CAIN_spotlight_thumb.jpg -------------------------------------------------------------------------------- /figures/overall_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/figures/overall_architecture.png -------------------------------------------------------------------------------- /figures/qualitative_vimeo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/figures/qualitative_vimeo.png -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import copy 5 | import shutil 6 | import random 7 | 8 | import torch 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | import config 13 | import utils 14 | 15 | 16 | ##### Parse CmdLine Arguments ##### 17 | args, unparsed = config.get_args() 18 | cwd = os.getcwd() 19 | print(args) 20 | 21 | 22 | device = torch.device('cuda' if args.cuda else 'cpu') 23 | torch.backends.cudnn.enabled = True 24 | torch.backends.cudnn.benchmark = True 25 | 26 | torch.manual_seed(args.random_seed) 27 | if args.cuda: 28 | torch.cuda.manual_seed(args.random_seed) 29 | 30 | 31 | 32 | 33 | ##### Build Model ##### 34 | if args.model.lower() == 'cain_encdec': 35 | from model.cain_encdec import CAIN_EncDec 36 | print('Building model: CAIN_EncDec') 37 | model = CAIN_EncDec(depth=args.depth, start_filts=32) 38 | elif args.model.lower() == 'cain': 39 | from model.cain import CAIN 40 | print("Building model: CAIN") 41 | model = CAIN(depth=args.depth) 42 | elif args.model.lower() == 'cain_noca': 43 | from model.cain_noca import CAIN_NoCA 44 | print("Building model: CAIN_NoCA") 45 | model = CAIN_NoCA(depth=args.depth) 46 | 
else: 47 | raise NotImplementedError("Unknown model!") 48 | # Just make every model to DataParallel 49 | model = torch.nn.DataParallel(model).to(device) 50 | #print(model) 51 | 52 | print('# of parameters: %d' % sum(p.numel() for p in model.parameters())) 53 | 54 | 55 | # If resume, load checkpoint: model 56 | if args.resume: 57 | #utils.load_checkpoint(args, model, optimizer=None) 58 | checkpoint = torch.load('pretrained_cain.pth') 59 | args.start_epoch = checkpoint['epoch'] + 1 60 | model.load_state_dict(checkpoint['state_dict']) 61 | del checkpoint 62 | 63 | 64 | 65 | def test(args, epoch): 66 | print('Evaluating for epoch = %d' % epoch) 67 | ##### Load Dataset ##### 68 | test_loader = utils.load_dataset( 69 | args.dataset, args.data_root, args.batch_size, args.test_batch_size, args.num_workers, img_fmt=args.img_fmt) 70 | model.eval() 71 | 72 | t = time.time() 73 | with torch.no_grad(): 74 | for i, (images, meta) in enumerate(tqdm(test_loader)): 75 | 76 | # Build input batch 77 | im1, im2 = images[0].to(device), images[1].to(device) 78 | 79 | # Forward 80 | out, _ = model(im1, im2) 81 | 82 | # Save result images 83 | if args.mode == 'test': 84 | for b in range(images[0].size(0)): 85 | paths = meta['imgpath'][0][b].split('/') 86 | fp = args.data_root 87 | fp = os.path.join(fp, paths[-1][:-4]) # remove '.png' extension 88 | 89 | # Decide float index 90 | i1_str = paths[-1][:-4] 91 | i2_str = meta['imgpath'][1][b].split('/')[-1][:-4] 92 | try: 93 | i1 = float(i1_str.split('_')[-1]) 94 | except ValueError: 95 | i1 = 0.0 96 | try: 97 | i2 = float(i2_str.split('_')[-1]) 98 | if i2 == 0.0: 99 | i2 = 1.0 100 | except ValueError: 101 | i2 = 1.0 102 | fpos = max(0, fp.rfind('_')) 103 | fInd = (i1 + i2) / 2 104 | savepath = "%s_%06f.%s" % (fp[:fpos], fInd, args.img_fmt) 105 | utils.save_image(out[b], savepath) 106 | 107 | # Print progress 108 | print('im_processed: {:d}/{:d} {:.3f}s \r'.format(i + 1, len(test_loader), time.time() - t)) 109 | 110 | return 111 | 112 | 113 | """ Entry Point """ 114 | def main(args): 115 | 116 | num_iter = 2 # x2**num_iter interpolation 117 | for _ in range(num_iter): 118 | 119 | # run test 120 | test(args, args.start_epoch) 121 | 122 | 123 | if __name__ == "__main__": 124 | main(args) 125 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | import pytorch_msssim 6 | from model.common import sub_mean, InOutPaddings, meanShift, PixelShuffle, ResidualGroup, conv 7 | 8 | class MeanShift(nn.Conv2d): 9 | def __init__(self, rgb_mean, rgb_std, sign=-1): 10 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 11 | std = torch.Tensor(rgb_std) 12 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 13 | self.weight.data.div_(std.view(3, 1, 1, 1)) 14 | self.bias.data = sign * torch.Tensor(rgb_mean) 15 | self.bias.data.div_(std) 16 | self.requires_grad = False 17 | 18 | 19 | class VGG(nn.Module): 20 | def __init__(self, loss_type): 21 | super(VGG, self).__init__() 22 | vgg_features = models.vgg19(pretrained=True).features 23 | modules = [m for m in vgg_features] 24 | conv_index = loss_type[-2:] 25 | if conv_index == '22': 26 | self.vgg = nn.Sequential(*modules[:8]) 27 | elif conv_index == '33': 28 | self.vgg = nn.Sequential(*modules[:16]) 29 | elif conv_index == '44': 30 | self.vgg = nn.Sequential(*modules[:26]) 31 | elif conv_index == 
'54': 32 | self.vgg = nn.Sequential(*modules[:35]) 33 | elif conv_index == 'P': 34 | self.vgg = nn.ModuleList([ 35 | nn.Sequential(*modules[:8]), 36 | nn.Sequential(*modules[8:16]), 37 | nn.Sequential(*modules[16:26]), 38 | nn.Sequential(*modules[26:35]) 39 | ]) 40 | self.vgg = nn.DataParallel(self.vgg).cuda() 41 | 42 | vgg_mean = (0.485, 0.456, 0.406) 43 | vgg_std = (0.229, 0.224, 0.225) 44 | self.sub_mean = MeanShift(vgg_mean, vgg_std) 45 | self.vgg.requires_grad = False 46 | # self.criterion = nn.L1Loss() 47 | self.conv_index = conv_index 48 | 49 | def forward(self, sr, hr): 50 | def _forward(x): 51 | x = self.sub_mean(x) 52 | x = self.vgg(x) 53 | return x 54 | def _forward_all(x): 55 | feats = [] 56 | x = self.sub_mean(x) 57 | for module in self.vgg.module: 58 | x = module(x) 59 | feats.append(x) 60 | return feats 61 | 62 | if self.conv_index == 'P': 63 | vgg_sr_feats = _forward_all(sr) 64 | with torch.no_grad(): 65 | vgg_hr_feats = _forward_all(hr.detach()) 66 | loss = 0 67 | for i in range(len(vgg_sr_feats)): 68 | loss_f = F.mse_loss(vgg_sr_feats[i], vgg_hr_feats[i]) 69 | #print(loss_f) 70 | loss += loss_f 71 | #print() 72 | else: 73 | vgg_sr = _forward(sr) 74 | with torch.no_grad(): 75 | vgg_hr = _forward(hr.detach()) 76 | loss = F.mse_loss(vgg_sr, vgg_hr) 77 | 78 | return loss 79 | 80 | 81 | # For Adversarial loss 82 | class BasicBlock(nn.Sequential): 83 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True)): 84 | m = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), stride=stride, bias=bias)] 85 | if bn: m.append(nn.BatchNorm2d(out_channels)) 86 | if act is not None: m.append(act) 87 | super(BasicBlock, self).__init__(*m) 88 | 89 | class Discriminator(nn.Module): 90 | def __init__(self, args, gan_type='GAN'): 91 | super(Discriminator, self).__init__() 92 | 93 | in_channels = 3 94 | out_channels = 64 95 | depth = 7 96 | #bn = not gan_type == 'WGAN_GP' 97 | bn = True 98 | act = nn.LeakyReLU(negative_slope=0.2, inplace=True) 99 | 100 | m_features = [ 101 | BasicBlock(in_channels, out_channels, 3, bn=bn, act=act) 102 | ] 103 | for i in range(depth): 104 | in_channels = out_channels 105 | if i % 2 == 1: 106 | stride = 1 107 | out_channels *= 2 108 | else: 109 | stride = 2 110 | m_features.append(BasicBlock( 111 | in_channels, out_channels, 3, stride=stride, bn=bn, act=act 112 | )) 113 | 114 | self.features = nn.Sequential(*m_features) 115 | 116 | self.patch_size = args.patch_size 117 | feature_patch_size = self.patch_size // (2**((depth + 1) // 2)) 118 | #patch_size = 256 // (2**((depth + 1) // 2)) 119 | m_classifier = [ 120 | nn.Linear(out_channels * feature_patch_size**2, 1024), 121 | act, 122 | nn.Linear(1024, 1) 123 | ] 124 | self.classifier = nn.Sequential(*m_classifier) 125 | 126 | def forward(self, x): 127 | if x.size(2) != self.patch_size or x.size(3) != self.patch_size: 128 | midH, midW = x.size(2) // 2, x.size(3) // 2 129 | p = self.patch_size // 2 130 | x = x[:, :, (midH - p):(midH - p + self.patch_size), (midW - p):(midW - p + self.patch_size)] 131 | features = self.features(x) 132 | output = self.classifier(features.view(features.size(0), -1)) 133 | 134 | return output 135 | 136 | 137 | import torch.optim as optim 138 | class Adversarial(nn.Module): 139 | def __init__(self, args, gan_type): 140 | super(Adversarial, self).__init__() 141 | self.gan_type = gan_type 142 | self.gan_k = 1 #args.gan_k 143 | self.discriminator = torch.nn.DataParallel(Discriminator(args, gan_type)) 144 | if gan_type 
!= 'WGAN_GP': 145 | self.optimizer = optim.Adam( 146 | self.discriminator.parameters(), 147 | betas=(0.9, 0.99), eps=1e-8, lr=1e-4 148 | ) 149 | else: 150 | self.optimizer = optim.Adam( 151 | self.discriminator.parameters(), 152 | betas=(0, 0.9), eps=1e-8, lr=1e-5 153 | ) 154 | # self.scheduler = utility.make_scheduler(args, self.optimizer) 155 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( 156 | self.optimizer, mode='min', factor=0.5, patience=3, verbose=True) 157 | 158 | def forward(self, fake, real, fake_input0=None, fake_input1=None, fake_input_mean=None): 159 | # def forward(self, fake, real): 160 | fake_detach = fake.detach() 161 | if fake_input0 is not None: 162 | fake0, fake1 = fake_input0.detach(), fake_input1.detach() 163 | if fake_input_mean is not None: 164 | fake_m = fake_input_mean.detach() 165 | # print(fake.size(), fake_input0.size(), fake_input1.size(), fake_input_mean.size()) 166 | 167 | self.loss = 0 168 | for _ in range(self.gan_k): 169 | self.optimizer.zero_grad() 170 | d_fake = self.discriminator(fake_detach) 171 | 172 | if fake_input0 is not None and fake_input1 is not None: 173 | d_fake0 = self.discriminator(fake0) 174 | d_fake1 = self.discriminator(fake1) 175 | if fake_input_mean is not None: 176 | d_fake_m = self.discriminator(fake_m) 177 | 178 | # print(d_fake.size(), d_fake0.size(), d_fake1.size(), d_fake_m.size()) 179 | 180 | d_real = self.discriminator(real) 181 | if self.gan_type == 'GAN': 182 | label_fake = torch.zeros_like(d_fake) 183 | label_real = torch.ones_like(d_real) 184 | loss_d \ 185 | = F.binary_cross_entropy_with_logits(d_fake, label_fake) \ 186 | + F.binary_cross_entropy_with_logits(d_real, label_real) 187 | if fake_input0 is not None and fake_input1 is not None: 188 | loss_d += F.binary_cross_entropy_with_logits(d_fake0, label_fake) \ 189 | + F.binary_cross_entropy_with_logits(d_fake1, label_fake) 190 | if fake_input_mean is not None: 191 | loss_d += F.binary_cross_entropy_with_logits(d_fake_m, label_fake) 192 | 193 | elif self.gan_type.find('WGAN') >= 0: 194 | loss_d = (d_fake - d_real).mean() 195 | if self.gan_type.find('GP') >= 0: 196 | epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) 197 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) 198 | hat.requires_grad = True 199 | d_hat = self.discriminator(hat) 200 | gradients = torch.autograd.grad( 201 | outputs=d_hat.sum(), inputs=hat, 202 | retain_graph=True, create_graph=True, only_inputs=True 203 | )[0] 204 | gradients = gradients.view(gradients.size(0), -1) 205 | gradient_norm = gradients.norm(2, dim=1) 206 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() 207 | loss_d += gradient_penalty 208 | 209 | # Discriminator update 210 | self.loss += loss_d.item() 211 | if self.training: 212 | loss_d.backward() 213 | self.optimizer.step() 214 | 215 | if self.gan_type == 'WGAN': 216 | for p in self.discriminator.parameters(): 217 | p.data.clamp_(-1, 1) 218 | 219 | self.loss /= self.gan_k 220 | 221 | d_fake_for_g = self.discriminator(fake) 222 | if self.gan_type == 'GAN': 223 | loss_g = F.binary_cross_entropy_with_logits( 224 | d_fake_for_g, label_real 225 | ) 226 | elif self.gan_type.find('WGAN') >= 0: 227 | loss_g = -d_fake_for_g.mean() 228 | 229 | # Generator loss 230 | return loss_g 231 | 232 | def state_dict(self, *args, **kwargs): 233 | state_discriminator = self.discriminator.state_dict(*args, **kwargs) 234 | state_optimizer = self.optimizer.state_dict() 235 | 236 | return dict(**state_discriminator, **state_optimizer) 237 | 238 | 239 | # Some references 240 | # 
https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py 241 | # OR 242 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py 243 | 244 | 245 | # Wrapper of loss functions 246 | class Loss(nn.modules.loss._Loss): 247 | def __init__(self, args): 248 | super(Loss, self).__init__() 249 | print('Preparing loss function:') 250 | 251 | self.loss = [] 252 | self.loss_module = nn.ModuleList() 253 | for loss in args.loss.split('+'): 254 | weight, loss_type = loss.split('*') 255 | if loss_type == 'MSE': 256 | loss_function = nn.MSELoss() 257 | elif loss_type == 'L1': 258 | loss_function = nn.L1Loss() 259 | elif loss_type.find('VGG') >= 0: 260 | loss_function = VGG(loss_type[3:]) 261 | elif loss_type == 'SSIM': 262 | loss_function = pytorch_msssim.SSIM(val_range=1.) 263 | elif loss_type.find('GAN') >= 0: 264 | loss_function = Adversarial(args, loss_type) 265 | 266 | self.loss.append({ 267 | 'type': loss_type, 268 | 'weight': float(weight), 269 | 'function': loss_function} 270 | ) 271 | if loss_type.find('GAN') >= 0: 272 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) 273 | 274 | if len(self.loss) > 1: 275 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 276 | 277 | for l in self.loss: 278 | if l['function'] is not None: 279 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 280 | self.loss_module.append(l['function']) 281 | 282 | device = torch.device('cuda' if args.cuda else 'cpu') 283 | self.loss_module.to(device) 284 | #if args.precision == 'half': self.loss_module.half() 285 | if args.cuda:# and args.n_GPUs > 1: 286 | self.loss_module = nn.DataParallel(self.loss_module) 287 | 288 | 289 | def forward(self, sr, hr, model_enc=None, feats=None, fake_imgs=None): 290 | loss = 0 291 | losses = {} 292 | for i, l in enumerate(self.loss): 293 | if l['function'] is not None: 294 | if l['type'] == 'GAN': 295 | if fake_imgs is None: 296 | fake_imgs = [None, None, None] 297 | _loss = l['function'](sr, hr, fake_imgs[0], fake_imgs[1], fake_imgs[2]) 298 | else: 299 | _loss = l['function'](sr, hr) 300 | effective_loss = l['weight'] * _loss 301 | losses[l['type']] = effective_loss 302 | loss += effective_loss 303 | elif l['type'] == 'DIS': 304 | losses[l['type']] = self.loss[i - 1]['function'].loss 305 | 306 | #loss_sum = sum(losses) 307 | #if len(self.loss) > 1: 308 | # self.log[-1, -1] += loss_sum.item() 309 | 310 | return loss, losses 311 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import copy 5 | import shutil 6 | import random 7 | 8 | import torch 9 | import numpy as np 10 | from tqdm import tqdm 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | import config 14 | import utils 15 | from loss import Loss 16 | 17 | 18 | ##### Parse CmdLine Arguments ##### 19 | args, unparsed = config.get_args() 20 | cwd = os.getcwd() 21 | print(args) 22 | 23 | 24 | ##### TensorBoard & Misc Setup ##### 25 | if args.mode != 'test': 26 | writer = SummaryWriter('logs/%s' % args.exp_name) 27 | 28 | device = torch.device('cuda' if args.cuda else 'cpu') 29 | torch.backends.cudnn.enabled = True 30 | torch.backends.cudnn.benchmark = True 31 | 32 | torch.manual_seed(args.random_seed) 33 | if args.cuda: 34 | torch.cuda.manual_seed(args.random_seed) 35 | 36 | 37 | ##### Load Dataset ##### 38 | train_loader, test_loader = utils.load_dataset( 39 | args.dataset, args.data_root, 
args.batch_size, args.test_batch_size, args.num_workers, args.test_mode) 40 | 41 | 42 | ##### Build Model ##### 43 | if args.model.lower() == 'cain_encdec': 44 | from model.cain_encdec import CAIN_EncDec 45 | print('Building model: CAIN_EncDec') 46 | model = CAIN_EncDec(depth=args.depth, start_filts=32) 47 | elif args.model.lower() == 'cain': 48 | from model.cain import CAIN 49 | print("Building model: CAIN") 50 | model = CAIN(depth=args.depth) 51 | elif args.model.lower() == 'cain_noca': 52 | from model.cain_noca import CAIN_NoCA 53 | print("Building model: CAIN_NoCA") 54 | model = CAIN_NoCA(depth=args.depth) 55 | else: 56 | raise NotImplementedError("Unknown model!") 57 | # Just make every model to DataParallel 58 | model = torch.nn.DataParallel(model).to(device) 59 | #print(model) 60 | 61 | ##### Define Loss & Optimizer ##### 62 | criterion = Loss(args) 63 | 64 | args.radam = False 65 | if args.radam: 66 | from radam import RAdam 67 | optimizer = RAdam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) 68 | else: 69 | from torch.optim import Adam 70 | optimizer = Adam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) 71 | print('# of parameters: %d' % sum(p.numel() for p in model.parameters())) 72 | 73 | 74 | # If resume, load checkpoint: model + optimizer 75 | if args.resume: 76 | utils.load_checkpoint(args, model, optimizer) 77 | 78 | # Learning Rate Scheduler 79 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 80 | optimizer, mode='min', factor=0.5, patience=5, verbose=True) 81 | 82 | 83 | # Initialize LPIPS model if used for evaluation 84 | # lpips_model = utils.init_lpips_eval() if args.lpips else None 85 | lpips_model = None 86 | 87 | LOSS_0 = 0 88 | 89 | 90 | def train(args, epoch): 91 | global LOSS_0 92 | losses, psnrs, ssims, lpips = utils.init_meters(args.loss) 93 | model.train() 94 | criterion.train() 95 | 96 | t = time.time() 97 | for i, (images, imgpaths) in enumerate(train_loader): 98 | 99 | # Build input batch 100 | im1, im2, gt = utils.build_input(images, imgpaths) 101 | 102 | # Forward 103 | optimizer.zero_grad() 104 | out, feats = model(im1, im2) 105 | loss, loss_specific = criterion(out, gt, None, feats) 106 | 107 | # Save loss values 108 | for k, v in losses.items(): 109 | if k != 'total': 110 | v.update(loss_specific[k].item()) 111 | if LOSS_0 == 0: 112 | LOSS_0 = loss.data.item() 113 | losses['total'].update(loss.item()) 114 | 115 | # Backward (+ grad clip) - if loss explodes, skip current iteration 116 | loss.backward() 117 | if loss.data.item() > 10.0 * LOSS_0: 118 | print(max(p.grad.data.abs().max() for p in model.parameters())) 119 | continue 120 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) 121 | optimizer.step() 122 | 123 | # Calc metrics & print logs 124 | if i % args.log_iter == 0: 125 | utils.eval_metrics(out, gt, psnrs, ssims, lpips, lpips_model) 126 | 127 | print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}\tPSNR: {:.4f}\tTime({:.2f})'.format( 128 | epoch, i, len(train_loader), losses['total'].avg, psnrs.avg, time.time() - t)) 129 | 130 | # Log to TensorBoard 131 | utils.log_tensorboard(writer, losses, psnrs.avg, ssims.avg, lpips.avg, 132 | optimizer.param_groups[-1]['lr'], epoch * len(train_loader) + i) 133 | 134 | # Reset metrics 135 | losses, psnrs, ssims, lpips = utils.init_meters(args.loss) 136 | t = time.time() 137 | 138 | 139 | def test(args, epoch, eval_alpha=0.5): 140 | print('Evaluating for epoch = %d' % epoch) 141 | losses, psnrs, ssims, lpips = utils.init_meters(args.loss) 142 | model.eval() 143 | 
criterion.eval() 144 | 145 | save_folder = 'test%03d' % epoch 146 | if args.dataset == 'snufilm': 147 | save_folder = os.path.join(save_folder, args.dataset, args.test_mode) 148 | else: 149 | save_folder = os.path.join(save_folder, args.dataset) 150 | save_dir = os.path.join('checkpoint', args.exp_name, save_folder) 151 | utils.makedirs(save_dir) 152 | save_fn = os.path.join(save_dir, 'results.txt') 153 | if not os.path.exists(save_fn): 154 | with open(save_fn, 'w') as f: 155 | f.write('For epoch=%d\n' % epoch) 156 | 157 | t = time.time() 158 | with torch.no_grad(): 159 | for i, (images, imgpaths) in enumerate(tqdm(test_loader)): 160 | 161 | # Build input batch 162 | im1, im2, gt = utils.build_input(images, imgpaths, is_training=False) 163 | 164 | # Forward 165 | out, feats = model(im1, im2) 166 | 167 | # Save loss values 168 | loss, loss_specific = criterion(out, gt, None, feats) 169 | for k, v in losses.items(): 170 | if k != 'total': 171 | v.update(loss_specific[k].item()) 172 | losses['total'].update(loss.item()) 173 | 174 | # Evaluate metrics 175 | utils.eval_metrics(out, gt, psnrs, ssims, lpips) 176 | 177 | # Log examples that have bad performance 178 | if (ssims.val < 0.9 or psnrs.val < 25) and epoch > 50: 179 | print(imgpaths) 180 | print("\nLoss: %f, PSNR: %f, SSIM: %f, LPIPS: %f" % 181 | (losses['total'].val, psnrs.val, ssims.val, lpips.val)) 182 | print(imgpaths[1][-1]) 183 | 184 | # Save result images 185 | if ((epoch + 1) % 1 == 0 and i < 20) or args.mode == 'test': 186 | savepath = os.path.join('checkpoint', args.exp_name, save_folder) 187 | 188 | for b in range(images[0].size(0)): 189 | paths = imgpaths[1][b].split('/') 190 | fp = os.path.join(savepath, paths[-3], paths[-2]) 191 | if not os.path.exists(fp): 192 | os.makedirs(fp) 193 | # remove '.png' extension 194 | fp = os.path.join(fp, paths[-1][:-4]) 195 | utils.save_image(out[b], "%s.png" % fp) 196 | 197 | # Print progress 198 | print('im_processed: {:d}/{:d} {:.3f}s \r'.format(i + 1, len(test_loader), time.time() - t)) 199 | print("Loss: %f, PSNR: %f, SSIM: %f, LPIPS: %f\n" % 200 | (losses['total'].avg, psnrs.avg, ssims.avg, lpips.avg)) 201 | 202 | # Save psnr & ssim 203 | save_fn = os.path.join('checkpoint', args.exp_name, save_folder, 'results.txt') 204 | with open(save_fn, 'a') as f: 205 | f.write("PSNR: %f, SSIM: %f, LPIPS: %f\n" % 206 | (psnrs.avg, ssims.avg, lpips.avg)) 207 | 208 | # Log to TensorBoard 209 | if args.mode != 'test': 210 | utils.log_tensorboard(writer, losses, psnrs.avg, ssims.avg, lpips.avg, 211 | optimizer.param_groups[-1]['lr'], epoch * len(train_loader) + i, mode='test') 212 | 213 | return losses['total'].avg, psnrs.avg, ssims.avg, lpips.avg 214 | 215 | 216 | """ Entry Point """ 217 | def main(args): 218 | if args.mode == 'test': 219 | _, _, _, _ = test(args, args.start_epoch) 220 | return 221 | 222 | best_psnr = 0 223 | for epoch in range(args.start_epoch, args.max_epoch): 224 | 225 | # run training 226 | train(args, epoch) 227 | 228 | # run test 229 | test_loss, psnr, _, _ = test(args, epoch) 230 | 231 | # save checkpoint 232 | is_best = psnr > best_psnr 233 | best_psnr = max(psnr, best_psnr) 234 | utils.save_checkpoint({ 235 | 'epoch': epoch, 236 | 'state_dict': model.state_dict(), 237 | 'optimizer': optimizer.state_dict(), 238 | 'best_psnr': best_psnr 239 | }, is_best, args.exp_name) 240 | 241 | # update optimizer policy 242 | scheduler.step(test_loss) 243 | 244 | if __name__ == "__main__": 245 | main(args) 246 | 
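For reference, the checkpoint dictionary written by `utils.save_checkpoint` in `main()` above has four keys; the sketch below shows how one might inspect or manually resume from such a checkpoint (the file path is an assumed example, since `utils.py` and its on-disk naming scheme are not shown here):

``` python
# Sketch: loading a checkpoint saved by main() above.
# The dict keys ('epoch', 'state_dict', 'optimizer', 'best_psnr') come from the
# utils.save_checkpoint call in main(); the path below is an assumed example.
import torch

ckpt = torch.load('checkpoint/CAIN_train/checkpoint.pth', map_location='cpu')
print('epoch: %d, best PSNR: %.4f' % (ckpt['epoch'], ckpt['best_psnr']))

# model and optimizer must be constructed as in main.py (DataParallel-wrapped
# model, Adam optimizer) before these calls:
# model.load_state_dict(ckpt['state_dict'])
# optimizer.load_state_dict(ckpt['optimizer'])
```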
-------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/model/__init__.py -------------------------------------------------------------------------------- /model/cain.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .common import * 8 | 9 | 10 | class Encoder(nn.Module): 11 | def __init__(self, in_channels=3, depth=3): 12 | super(Encoder, self).__init__() 13 | 14 | # Shuffle pixels to expand in channel dimension 15 | # shuffler_list = [PixelShuffle(0.5) for i in range(depth)] 16 | # self.shuffler = nn.Sequential(*shuffler_list) 17 | self.shuffler = PixelShuffle(1 / 2**depth) 18 | 19 | relu = nn.LeakyReLU(0.2, True) 20 | 21 | # FF_RCAN or FF_Resblocks 22 | self.interpolate = Interpolation(5, 12, in_channels * (4**depth), act=relu) 23 | 24 | def forward(self, x1, x2): 25 | """ 26 | Encoder: Shuffle-spread --> Feature Fusion --> Return fused features 27 | """ 28 | feats1 = self.shuffler(x1) 29 | feats2 = self.shuffler(x2) 30 | 31 | feats = self.interpolate(feats1, feats2) 32 | 33 | return feats 34 | 35 | 36 | class Decoder(nn.Module): 37 | def __init__(self, depth=3): 38 | super(Decoder, self).__init__() 39 | 40 | # shuffler_list = [PixelShuffle(2) for i in range(depth)] 41 | # self.shuffler = nn.Sequential(*shuffler_list) 42 | self.shuffler = PixelShuffle(2**depth) 43 | 44 | def forward(self, feats): 45 | out = self.shuffler(feats) 46 | return out 47 | 48 | 49 | class CAIN(nn.Module): 50 | def __init__(self, depth=3): 51 | super(CAIN, self).__init__() 52 | 53 | self.encoder = Encoder(in_channels=3, depth=depth) 54 | self.decoder = Decoder(depth=depth) 55 | 56 | def forward(self, x1, x2): 57 | x1, m1 = sub_mean(x1) 58 | x2, m2 = sub_mean(x2) 59 | 60 | if not self.training: 61 | paddingInput, paddingOutput = InOutPaddings(x1) 62 | x1 = paddingInput(x1) 63 | x2 = paddingInput(x2) 64 | 65 | feats = self.encoder(x1, x2) 66 | out = self.decoder(feats) 67 | 68 | if not self.training: 69 | out = paddingOutput(out) 70 | 71 | mi = (m1 + m2) / 2 72 | out += mi 73 | 74 | return out, feats 75 | -------------------------------------------------------------------------------- /model/cain_encdec.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .common import * 8 | 9 | 10 | class Encoder(nn.Module): 11 | def __init__(self, in_channels=3, depth=3, nf_start=32, norm=False): 12 | super(Encoder, self).__init__() 13 | self.device = torch.device('cuda') 14 | 15 | nf = nf_start 16 | relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 17 | 18 | self.body = nn.Sequential( 19 | ConvNorm(in_channels, nf * 1, 7, stride=1, norm=norm), 20 | relu, 21 | ConvNorm(nf * 1, nf * 2, 5, stride=2, norm=norm), 22 | relu, 23 | ConvNorm(nf * 2, nf * 4, 5, stride=2, norm=norm), 24 | relu, 25 | ConvNorm(nf * 4, nf * 6, 5, stride=2, norm=norm) 26 | ) 27 | 28 | self.interpolate = Interpolation(5, 12, nf * 6, reduction=16, act=relu) 29 | 30 | def forward(self, x1, x2): 31 | """ 32 | Encoder: Feature Extraction --> Feature Fusion --> Return 33 | """ 34 | feats1 = self.body(x1) 35 | feats2 = self.body(x2) 36 | 37 | feats = self.interpolate(feats1, feats2) 38 | 39 
| return feats 40 | 41 | 42 | class Decoder(nn.Module): 43 | def __init__(self, in_channels=192, out_channels=3, depth=3, norm=False, up_mode='shuffle'): 44 | super(Decoder, self).__init__() 45 | self.device = torch.device('cuda') 46 | 47 | relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 48 | 49 | nf = [in_channels, (in_channels*2)//3, in_channels//3, in_channels//6] 50 | #nf = [192, 128, 64, 32] 51 | #nf = [186, 124, 62, 31] 52 | self.body = nn.Sequential( 53 | UpConvNorm(nf[0], nf[1], mode=up_mode, norm=norm), 54 | ResBlock(nf[1], nf[1], norm=norm, act=relu), 55 | UpConvNorm(nf[1], nf[2], mode=up_mode, norm=norm), 56 | ResBlock(nf[2], nf[2], norm=norm, act=relu), 57 | UpConvNorm(nf[2], nf[3], mode=up_mode, norm=norm), 58 | ResBlock(nf[3], nf[3], norm=norm, act=relu), 59 | conv7x7(nf[3], out_channels) 60 | ) 61 | 62 | def forward(self, feats): 63 | out = self.body(feats) 64 | #out = self.conv_final(out) 65 | 66 | return out 67 | 68 | 69 | class CAIN_EncDec(nn.Module): 70 | def __init__(self, depth=3, n_resblocks=3, start_filts=32, up_mode='shuffle'): 71 | super(CAIN_EncDec, self).__init__() 72 | self.depth = depth 73 | 74 | self.encoder = Encoder(in_channels=3, depth=depth, norm=False) 75 | self.decoder = Decoder(in_channels=start_filts*6, depth=depth, norm=False, up_mode=up_mode) 76 | 77 | def forward(self, x1, x2): 78 | x1, m1 = sub_mean(x1) 79 | x2, m2 = sub_mean(x2) 80 | 81 | if not self.training: 82 | paddingInput, paddingOutput = InOutPaddings(x1) 83 | x1 = paddingInput(x1) 84 | x2 = paddingInput(x2) 85 | 86 | feats = self.encoder(x1, x2) 87 | out = self.decoder(feats) 88 | 89 | if not self.training: 90 | out = paddingOutput(out) 91 | 92 | mi = (m1 + m2)/2 93 | out += mi 94 | 95 | return out, feats 96 | 97 | 98 | -------------------------------------------------------------------------------- /model/cain_noca.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .common import * 8 | 9 | class Encoder(nn.Module): 10 | def __init__(self, in_channels=3, depth=3): 11 | super(Encoder, self).__init__() 12 | self.device = torch.device('cuda') 13 | 14 | self.shuffler = PixelShuffle(1/2**depth) 15 | # self.shuffler = nn.Sequential( 16 | # PixelShuffle(1/2), 17 | # PixelShuffle(1/2), 18 | # PixelShuffle(1/2)) 19 | self.interpolate = Interpolation_res(5, 12, in_channels * (4**depth)) 20 | 21 | def forward(self, x1, x2): 22 | feats1 = self.shuffler(x1) 23 | feats2 = self.shuffler(x2) 24 | 25 | feats = self.interpolate(feats1, feats2) 26 | 27 | return feats 28 | 29 | 30 | class Decoder(nn.Module): 31 | def __init__(self, depth=3): 32 | super(Decoder, self).__init__() 33 | self.device = torch.device('cuda') 34 | 35 | self.shuffler = PixelShuffle(2**depth) 36 | # self.shuffler = nn.Sequential( 37 | # PixelShuffle(2), 38 | # PixelShuffle(2), 39 | # PixelShuffle(2)) 40 | 41 | def forward(self, feats): 42 | out = self.shuffler(feats) 43 | return out 44 | 45 | 46 | class CAIN_NoCA(nn.Module): 47 | def __init__(self, depth=3): 48 | super(CAIN_NoCA, self).__init__() 49 | self.depth = depth 50 | 51 | self.encoder = Encoder(in_channels=3, depth=depth) 52 | self.decoder = Decoder(depth=depth) 53 | 54 | def forward(self, x1, x2): 55 | x1, m1 = sub_mean(x1) 56 | x2, m2 = sub_mean(x2) 57 | 58 | if not self.training: 59 | paddingInput, paddingOutput = InOutPaddings(x1) 60 | x1 = paddingInput(x1) 61 | x2 = paddingInput(x2) 62 | 63 | feats = self.encoder(x1, x2) 64 | out = 
self.decoder(feats) 65 | 66 | if not self.training: 67 | out = paddingOutput(out) 68 | 69 | mi = (m1 + m2) / 2 70 | out += mi 71 | 72 | return out, feats 73 | -------------------------------------------------------------------------------- /model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def sub_mean(x): 8 | mean = x.mean(2, keepdim=True).mean(3, keepdim=True) 9 | x -= mean 10 | return x, mean 11 | 12 | def InOutPaddings(x): 13 | w, h = x.size(3), x.size(2) 14 | padding_width, padding_height = 0, 0 15 | if w != ((w >> 7) << 7): 16 | padding_width = (((w >> 7) + 1) << 7) - w 17 | if h != ((h >> 7) << 7): 18 | padding_height = (((h >> 7) + 1) << 7) - h 19 | paddingInput = nn.ReflectionPad2d(padding=[padding_width // 2, padding_width - padding_width // 2, 20 | padding_height // 2, padding_height - padding_height // 2]) 21 | paddingOutput = nn.ReflectionPad2d(padding=[0 - padding_width // 2, padding_width // 2 - padding_width, 22 | 0 - padding_height // 2, padding_height // 2 - padding_height]) 23 | return paddingInput, paddingOutput 24 | 25 | 26 | class ConvNorm(nn.Module): 27 | def __init__(self, in_feat, out_feat, kernel_size, stride=1, norm=False): 28 | super(ConvNorm, self).__init__() 29 | 30 | reflection_padding = kernel_size // 2 31 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 32 | self.conv = nn.Conv2d(in_feat, out_feat, stride=stride, kernel_size=kernel_size, bias=True) 33 | 34 | self.norm = norm 35 | if norm == 'IN': 36 | self.norm = nn.InstanceNorm2d(out_feat, track_running_stats=True) 37 | elif norm == 'BN': 38 | self.norm = nn.BatchNorm2d(out_feat) 39 | 40 | def forward(self, x): 41 | out = self.reflection_pad(x) 42 | out = self.conv(out) 43 | if self.norm: 44 | out = self.norm(out) 45 | return out 46 | 47 | 48 | class UpConvNorm(nn.Module): 49 | def __init__(self, in_channels, out_channels, mode='transpose', norm=False): 50 | super(UpConvNorm, self).__init__() 51 | 52 | if mode == 'transpose': 53 | self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1) 54 | elif mode == 'shuffle': 55 | self.upconv = nn.Sequential( 56 | ConvNorm(in_channels, 4*out_channels, kernel_size=3, stride=1, norm=norm), 57 | PixelShuffle(2)) 58 | else: 59 | # out_channels is always going to be the same as in_channels 60 | self.upconv = nn.Sequential( 61 | nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False), 62 | ConvNorm(in_channels, out_channels, kernel_size=1, stride=1, norm=norm)) 63 | 64 | def forward(self, x): 65 | out = self.upconv(x) 66 | return out 67 | 68 | 69 | 70 | class meanShift(nn.Module): 71 | def __init__(self, rgbRange, rgbMean, sign, nChannel=3): 72 | super(meanShift, self).__init__() 73 | if nChannel == 1: 74 | l = rgbMean[0] * rgbRange * float(sign) 75 | 76 | self.shifter = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0) 77 | self.shifter.weight.data = torch.eye(1).view(1, 1, 1, 1) 78 | self.shifter.bias.data = torch.Tensor([l]) 79 | elif nChannel == 3: 80 | r = rgbMean[0] * rgbRange * float(sign) 81 | g = rgbMean[1] * rgbRange * float(sign) 82 | b = rgbMean[2] * rgbRange * float(sign) 83 | 84 | self.shifter = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0) 85 | self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1) 86 | self.shifter.bias.data = torch.Tensor([r, g, b]) 87 | else: 88 | r = rgbMean[0] * rgbRange * float(sign) 89 | g = rgbMean[1] * 
rgbRange * float(sign) 90 | b = rgbMean[2] * rgbRange * float(sign) 91 | self.shifter = nn.Conv2d(6, 6, kernel_size=1, stride=1, padding=0) 92 | self.shifter.weight.data = torch.eye(6).view(6, 6, 1, 1) 93 | self.shifter.bias.data = torch.Tensor([r, g, b, r, g, b]) 94 | 95 | # Freeze the meanShift layer 96 | for params in self.shifter.parameters(): 97 | params.requires_grad = False 98 | 99 | def forward(self, x): 100 | x = self.shifter(x) 101 | 102 | return x 103 | 104 | 105 | """ CONV - (BN) - RELU - CONV - (BN) """ 106 | class ResBlock(nn.Module): 107 | def __init__(self, in_feat, out_feat, kernel_size=3, reduction=False, bias=True, # 'reduction' is just for placeholder 108 | norm=False, act=nn.ReLU(True), downscale=False): 109 | super(ResBlock, self).__init__() 110 | 111 | self.body = nn.Sequential( 112 | ConvNorm(in_feat, out_feat, kernel_size=kernel_size, stride=2 if downscale else 1), 113 | act, 114 | ConvNorm(out_feat, out_feat, kernel_size=kernel_size, stride=1) 115 | ) 116 | 117 | self.downscale = None 118 | if downscale: 119 | self.downscale = nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=2) 120 | 121 | def forward(self, x): 122 | res = x 123 | out = self.body(x) 124 | if self.downscale is not None: 125 | res = self.downscale(res) 126 | out += res 127 | 128 | return out 129 | 130 | 131 | ## Channel Attention (CA) Layer 132 | class CALayer(nn.Module): 133 | def __init__(self, channel, reduction=16): 134 | super(CALayer, self).__init__() 135 | # global average pooling: feature --> point 136 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 137 | # feature channel downscale and upscale --> channel weight 138 | self.conv_du = nn.Sequential( 139 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 140 | nn.ReLU(inplace=True), 141 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 142 | nn.Sigmoid() 143 | ) 144 | 145 | def forward(self, x): 146 | y = self.avg_pool(x) 147 | y = self.conv_du(y) 148 | return x * y, y 149 | 150 | 151 | ## Residual Channel Attention Block (RCAB) 152 | class RCAB(nn.Module): 153 | def __init__(self, in_feat, out_feat, kernel_size, reduction, bias=True, 154 | norm=False, act=nn.ReLU(True), downscale=False, return_ca=False): 155 | super(RCAB, self).__init__() 156 | 157 | self.body = nn.Sequential( 158 | ConvNorm(in_feat, out_feat, kernel_size, stride=2 if downscale else 1, norm=norm), 159 | act, 160 | ConvNorm(out_feat, out_feat, kernel_size, stride=1, norm=norm), 161 | CALayer(out_feat, reduction) 162 | ) 163 | self.downscale = downscale 164 | if downscale: 165 | self.downConv = nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=2, padding=1) 166 | self.return_ca = return_ca 167 | 168 | def forward(self, x): 169 | res = x 170 | out, ca = self.body(x) 171 | if self.downscale: 172 | res = self.downConv(res) 173 | out += res 174 | 175 | if self.return_ca: 176 | return out, ca 177 | else: 178 | return out 179 | 180 | 181 | ## Residual Group (RG) 182 | class ResidualGroup(nn.Module): 183 | def __init__(self, Block, n_resblocks, n_feat, kernel_size, reduction, act, norm=False): 184 | super(ResidualGroup, self).__init__() 185 | 186 | modules_body = [Block(n_feat, n_feat, kernel_size, reduction, bias=True, norm=norm, act=act) 187 | for _ in range(n_resblocks)] 188 | modules_body.append(ConvNorm(n_feat, n_feat, kernel_size, stride=1, norm=norm)) 189 | self.body = nn.Sequential(*modules_body) 190 | 191 | def forward(self, x): 192 | res = self.body(x) 193 | res += x 194 | return res 195 | 196 | 197 | def pixel_shuffle(input, scale_factor): 
198 |     batch_size, channels, in_height, in_width = input.size()
199 | 
200 |     out_channels = int(int(channels / scale_factor) / scale_factor)
201 |     out_height = int(in_height * scale_factor)
202 |     out_width = int(in_width * scale_factor)
203 | 
204 |     if scale_factor >= 1:
205 |         input_view = input.contiguous().view(batch_size, out_channels, scale_factor, scale_factor, in_height, in_width)
206 |         shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
207 |     else:
208 |         block_size = int(1 / scale_factor)
209 |         input_view = input.contiguous().view(batch_size, channels, out_height, block_size, out_width, block_size)
210 |         shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()
211 | 
212 |     return shuffle_out.view(batch_size, out_channels, out_height, out_width)
213 | 
214 | 
215 | class PixelShuffle(nn.Module):
216 |     def __init__(self, scale_factor):
217 |         super(PixelShuffle, self).__init__()
218 |         self.scale_factor = scale_factor
219 | 
220 |     def forward(self, x):
221 |         return pixel_shuffle(x, self.scale_factor)
222 |     def extra_repr(self):
223 |         return 'scale_factor={}'.format(self.scale_factor)
224 | 
225 | 
226 | def conv(in_channels, out_channels, kernel_size,
227 |          stride=1, bias=True, groups=1):
228 |     return nn.Conv2d(
229 |         in_channels,
230 |         out_channels,
231 |         kernel_size=kernel_size,
232 |         padding=kernel_size//2,
233 |         stride=stride,
234 |         bias=bias,
235 |         groups=groups)
236 | 
237 | 
238 | def conv1x1(in_channels, out_channels, stride=1, bias=True, groups=1):
239 |     return nn.Conv2d(
240 |         in_channels,
241 |         out_channels,
242 |         kernel_size=1,
243 |         stride=stride,
244 |         bias=bias,
245 |         groups=groups)
246 | 
247 | def conv3x3(in_channels, out_channels, stride=1,
248 |             padding=1, bias=True, groups=1):
249 |     return nn.Conv2d(
250 |         in_channels,
251 |         out_channels,
252 |         kernel_size=3,
253 |         stride=stride,
254 |         padding=padding,
255 |         bias=bias,
256 |         groups=groups)
257 | 
258 | def conv5x5(in_channels, out_channels, stride=1,
259 |             padding=2, bias=True, groups=1):
260 |     return nn.Conv2d(
261 |         in_channels,
262 |         out_channels,
263 |         kernel_size=5,
264 |         stride=stride,
265 |         padding=padding,
266 |         bias=bias,
267 |         groups=groups)
268 | 
269 | def conv7x7(in_channels, out_channels, stride=1,
270 |             padding=3, bias=True, groups=1):
271 |     return nn.Conv2d(
272 |         in_channels,
273 |         out_channels,
274 |         kernel_size=7,
275 |         stride=stride,
276 |         padding=padding,
277 |         bias=bias,
278 |         groups=groups)
279 | 
280 | def upconv2x2(in_channels, out_channels, mode='shuffle'):
281 |     if mode == 'transpose':
282 |         return nn.ConvTranspose2d(
283 |             in_channels,
284 |             out_channels,
285 |             kernel_size=4,
286 |             stride=2,
287 |             padding=1)
288 |     elif mode == 'shuffle':
289 |         return nn.Sequential(
290 |             conv3x3(in_channels, 4*out_channels),
291 |             PixelShuffle(2))
292 |     else:
293 |         # out_channels is always going to be the same as in_channels
294 |         return nn.Sequential(
295 |             nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
296 |             conv1x1(in_channels, out_channels))
297 | 
298 | 
299 | 
300 | class Interpolation(nn.Module):
301 |     def __init__(self, n_resgroups, n_resblocks, n_feats,
302 |                  reduction=16, act=nn.LeakyReLU(0.2, True), norm=False):
303 |         super(Interpolation, self).__init__()
304 | 
305 |         # define modules: head, body, tail
306 |         self.headConv = conv3x3(n_feats * 2, n_feats)
307 | 
308 |         modules_body = [
309 |             ResidualGroup(
310 |                 RCAB,
311 |                 n_resblocks=n_resblocks,
312 |                 n_feat=n_feats,
313 |                 kernel_size=3,
314 |                 reduction=reduction,
315 |                 act=act,
316 |                 norm=norm)
317 |             for _ in range(n_resgroups)]
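        # Each ResidualGroup stacks n_resblocks RCABs plus a trailing conv and
        # adds its own local skip connection; the global skip over the whole
        # body is added in forward() below.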
318 |         self.body = nn.Sequential(*modules_body)
319 | 
320 |         self.tailConv = conv3x3(n_feats, n_feats)
321 | 
322 |     def forward(self, x0, x1):
323 |         # Build input tensor
324 |         x = torch.cat([x0, x1], dim=1)
325 |         x = self.headConv(x)
326 | 
327 |         res = self.body(x)
328 |         res += x
329 | 
330 |         out = self.tailConv(res)
331 |         return out
332 | 
333 | 
334 | class Interpolation_res(nn.Module):
335 |     def __init__(self, n_resgroups, n_resblocks, n_feats,
336 |                  act=nn.LeakyReLU(0.2, True), norm=False):
337 |         super(Interpolation_res, self).__init__()
338 | 
339 |         # define modules: head, body, tail (reduces concatenated inputs to n_feat)
340 |         self.headConv = conv3x3(n_feats * 2, n_feats)
341 | 
342 |         modules_body = [ResidualGroup(ResBlock, n_resblocks=n_resblocks, n_feat=n_feats, kernel_size=3,
343 |                                       reduction=0, act=act, norm=norm)
344 |                         for _ in range(n_resgroups)]
345 |         self.body = nn.Sequential(*modules_body)
346 | 
347 |         self.tailConv = conv3x3(n_feats, n_feats)
348 | 
349 |     def forward(self, x0, x1):
350 |         # Build input tensor
351 |         x = torch.cat([x0, x1], dim=1)
352 |         x = self.headConv(x)
353 | 
354 |         res = x
355 |         for m in self.body:
356 |             res = m(res)
357 |         res += x
358 | 
359 |         x = self.tailConv(res)
360 | 
361 |         return x
362 | 
--------------------------------------------------------------------------------
/pytorch_msssim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from math import exp
4 | import numpy as np
5 | 
6 | 
7 | def gaussian(window_size, sigma):
8 |     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 |     return gauss/gauss.sum()
10 | 
11 | 
12 | def create_window(window_size, channel=1):
13 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)  # created on CPU; callers move it with .to(img.device)
15 |     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
16 |     return window
17 | 
18 | def create_window_3d(window_size, channel=1):
19 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
20 |     _2D_window = _1D_window.mm(_1D_window.t())
21 |     _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
22 |     window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous()  # created on CPU; callers move it with .to(img.device)
23 |     return window
24 | 
25 | 
26 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
27 |     # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
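    # When val_range is None, the dynamic range L is inferred from the tensor
    # values below, assuming inputs roughly in [0, 255], [0, 1] or [-1, 1].
    # Note that the replicate padding of 5 used further down matches the
    # default 11x11 window (window_size // 2) and is not adjusted if a smaller
    # window is created for small images.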
28 | if val_range is None: 29 | if torch.max(img1) > 128: 30 | max_val = 255 31 | else: 32 | max_val = 1 33 | 34 | if torch.min(img1) < -0.5: 35 | min_val = -1 36 | else: 37 | min_val = 0 38 | L = max_val - min_val 39 | else: 40 | L = val_range 41 | 42 | padd = 0 43 | (_, channel, height, width) = img1.size() 44 | if window is None: 45 | real_size = min(window_size, height, width) 46 | window = create_window(real_size, channel=channel).to(img1.device) 47 | 48 | # mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 49 | # mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 50 | mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) 51 | mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) 52 | 53 | mu1_sq = mu1.pow(2) 54 | mu2_sq = mu2.pow(2) 55 | mu1_mu2 = mu1 * mu2 56 | 57 | # sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 58 | # sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 59 | # sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 60 | 61 | sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq 62 | sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq 63 | sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2 64 | 65 | C1 = (0.01 * L) ** 2 66 | C2 = (0.03 * L) ** 2 67 | 68 | v1 = 2.0 * sigma12 + C2 69 | v2 = sigma1_sq + sigma2_sq + C2 70 | cs = torch.mean(v1 / v2) # contrast sensitivity 71 | 72 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 73 | 74 | if size_average: 75 | ret = ssim_map.mean() 76 | else: 77 | ret = ssim_map.mean(1).mean(1).mean(1) 78 | 79 | if full: 80 | return ret, cs 81 | return ret 82 | 83 | 84 | def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 85 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
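    # Same dynamic-range inference as in ssim() above. The difference in this
    # variant is below: color images are treated as single-channel 3-D volumes
    # and convolved with one 3-D Gaussian window, mimicking MATLAB's SSIM.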
86 |     if val_range is None:
87 |         if torch.max(img1) > 128:
88 |             max_val = 255
89 |         else:
90 |             max_val = 1
91 | 
92 |         if torch.min(img1) < -0.5:
93 |             min_val = -1
94 |         else:
95 |             min_val = 0
96 |         L = max_val - min_val
97 |     else:
98 |         L = val_range
99 | 
100 |     padd = 0
101 |     (_, _, height, width) = img1.size()
102 |     if window is None:
103 |         real_size = min(window_size, height, width)
104 |         window = create_window_3d(real_size, channel=1).to(img1.device)
105 |         # Channel is set to 1 since we consider color images as volumetric images
106 | 
107 |     img1 = img1.unsqueeze(1)
108 |     img2 = img2.unsqueeze(1)
109 | 
110 |     mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
111 |     mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
112 | 
113 |     mu1_sq = mu1.pow(2)
114 |     mu2_sq = mu2.pow(2)
115 |     mu1_mu2 = mu1 * mu2
116 | 
117 |     sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
118 |     sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
119 |     sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
120 | 
121 |     C1 = (0.01 * L) ** 2
122 |     C2 = (0.03 * L) ** 2
123 | 
124 |     v1 = 2.0 * sigma12 + C2
125 |     v2 = sigma1_sq + sigma2_sq + C2
126 |     cs = torch.mean(v1 / v2)  # contrast sensitivity
127 | 
128 |     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
129 | 
130 |     if size_average:
131 |         ret = ssim_map.mean()
132 |     else:
133 |         ret = ssim_map.mean(1).mean(1).mean(1)
134 | 
135 |     if full:
136 |         return ret, cs
137 |     return ret
138 | 
139 | 
140 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
141 |     device = img1.device
142 |     weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
143 |     levels = weights.size()[0]
144 |     mssim = []
145 |     mcs = []
146 |     for _ in range(levels):
147 |         sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
148 |         mssim.append(sim)
149 |         mcs.append(cs)
150 | 
151 |         img1 = F.avg_pool2d(img1, (2, 2))
152 |         img2 = F.avg_pool2d(img2, (2, 2))
153 | 
154 |     mssim = torch.stack(mssim)
155 |     mcs = torch.stack(mcs)
156 | 
157 |     # Normalize to avoid NaNs when training unstable models (not compliant with the original MS-SSIM definition)
158 |     if normalize:
159 |         mssim = (mssim + 1) / 2
160 |         mcs = (mcs + 1) / 2
161 | 
162 |     pow1 = mcs ** weights
163 |     pow2 = mssim ** weights
164 |     # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
165 |     output = torch.prod(pow1[:-1] * pow2[-1])
166 |     return output
167 | 
168 | 
169 | # Classes to re-use window
170 | class SSIM(torch.nn.Module):
171 |     def __init__(self, window_size=11, size_average=True, val_range=None):
172 |         super(SSIM, self).__init__()
173 |         self.window_size = window_size
174 |         self.size_average = size_average
175 |         self.val_range = val_range
176 | 
177 |         # Assume 3 channel for SSIM
178 |         self.channel = 3
179 |         self.window = create_window(window_size, channel=self.channel)
180 | 
181 |     def forward(self, img1, img2):
182 |         (_, channel, _, _) = img1.size()
183 | 
184 |         if channel == self.channel and self.window.dtype == img1.dtype and self.window.device == img1.device:
185 |             window = self.window
186 |         else:
187 |             window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
188 |             self.window = window
189 |             self.channel = channel
190 | 
191 |         _ssim = ssim(img1,
img2, window=window, window_size=self.window_size, size_average=self.size_average) 192 | dssim = (1 - _ssim) / 2 193 | return dssim 194 | 195 | class MSSSIM(torch.nn.Module): 196 | def __init__(self, window_size=11, size_average=True, channel=3): 197 | super(MSSSIM, self).__init__() 198 | self.window_size = window_size 199 | self.size_average = size_average 200 | self.channel = channel 201 | 202 | def forward(self, img1, img2): 203 | # TODO: store window between calls if possible 204 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) 205 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py \ 4 | --exp_name CAIN_train \ 5 | --dataset vimeo90k \ 6 | --batch_size 16 \ 7 | --test_batch_size 16 \ 8 | --model cain \ 9 | --depth 3 \ 10 | --loss 1*L1 \ 11 | --max_epoch 200 \ 12 | --lr 0.0002 \ 13 | --log_iter 100 \ 14 | # --mode test -------------------------------------------------------------------------------- /run_noca.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=1 python main.py \ 4 | --exp_name CAIN_test_noca \ 5 | --dataset vimeo90k \ 6 | --batch_size 16 \ 7 | --test_batch_size 16 \ 8 | --model cain_noca \ 9 | --depth 3 \ 10 | --loss 1*L1 \ 11 | --max_epoch 200 \ 12 | --lr 0.0002 \ 13 | --log_iter 100 \ 14 | # --mode test 15 | # --resume True \ 16 | # --resume_exp SH_5_12 17 | # --fix_encoder -------------------------------------------------------------------------------- /test_custom.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=0 python generate.py \ 4 | --exp_name CAIN_fin \ 5 | --dataset custom \ 6 | --data_root data/frame_seq \ 7 | --img_fmt png \ 8 | --batch_size 32 \ 9 | --test_batch_size 16 \ 10 | --model cain \ 11 | --depth 3 \ 12 | --loss 1*L1 \ 13 | --resume \ 14 | --mode test -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from datetime import datetime 3 | import os 4 | import sys 5 | import math 6 | import random 7 | import json 8 | import glob 9 | import logging 10 | import shutil 11 | 12 | import numpy as np 13 | import torch 14 | from torchvision import transforms 15 | 16 | from PIL import Image, ImageFont, ImageDraw 17 | #from skimage.measure import compare_psnr, compare_ssim 18 | 19 | try: 20 | from StringIO import StringIO # Python 2.7 21 | except ImportError: 22 | from io import BytesIO # Python 3.x 23 | 24 | import cv2 25 | 26 | from pytorch_msssim import ssim_matlab as ssim_pth 27 | # from pytorch_msssim import ssim as ssim_pth 28 | 29 | ########################## 30 | # Training Helper Functions for making main.py clean 31 | ########################## 32 | 33 | def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_workers, test_mode='medium', img_fmt='png'): 34 | 35 | if dataset_str == 'snufilm': 36 | from data.snufilm import get_loader 37 | test_loader = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers, test_mode=test_mode) 38 | return None, test_loader 39 | elif dataset_str == 'vimeo90k': 40 | from data.vimeo90k import get_loader 41 | elif dataset_str == 'aim': 
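        # data/aim.py is not part of the directory tree shown in the README;
        # this branch assumes an AIM dataset module exposing the same
        # get_loader interface as data/vimeo90k.py.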
42 | from data.aim import get_loader 43 | elif dataset_str == 'custom': 44 | from data.video import get_loader 45 | test_loader = get_loader('test', data_root, test_batch_size, img_fmt=img_fmt, shuffle=False, num_workers=num_workers, n_frames=1) 46 | return test_loader 47 | else: 48 | raise NotImplementedError('Training / Testing for this dataset is not implemented.') 49 | 50 | train_loader = get_loader('train', data_root, batch_size, shuffle=True, num_workers=num_workers) 51 | if dataset_str == 'aim': 52 | test_loader = get_loader('val', data_root, test_batch_size, shuffle=False, num_workers=num_workers) 53 | else: 54 | test_loader = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers) 55 | 56 | return train_loader, test_loader 57 | 58 | 59 | def build_input(images, imgpaths, is_training=True, include_edge=False, device=torch.device('cuda')): 60 | if isinstance(images[0], list): 61 | images_gathered = [None, None, None] 62 | for j in range(len(images[0])): # 3 63 | _images = [images[k][j] for k in range(len(images))] 64 | images_gathered[j] = torch.cat(_images, 0) 65 | imgpaths = [p for _ in images for p in imgpaths] 66 | images = images_gathered 67 | 68 | im1, im2 = images[0].to(device), images[2].to(device) 69 | gt = images[1].to(device) 70 | 71 | return im1, im2, gt 72 | 73 | 74 | def load_checkpoint(args, model, optimizer, fix_loaded=False): 75 | if args.resume_exp is None: 76 | args.resume_exp = args.exp_name 77 | if args.mode == 'test': 78 | load_name = os.path.join('checkpoint', args.resume_exp, 'model_best.pth') 79 | else: 80 | #load_name = os.path.join('checkpoint', args.resume_exp, 'model_best.pth') 81 | load_name = os.path.join('checkpoint', args.resume_exp, 'checkpoint.pth') 82 | print("loading checkpoint %s" % load_name) 83 | checkpoint = torch.load(load_name) 84 | args.start_epoch = checkpoint['epoch'] + 1 85 | if args.resume_exp != args.exp_name: 86 | args.start_epoch = 0 87 | 88 | # filter out different keys or those with size mismatch 89 | model_dict = model.state_dict() 90 | ckpt_dict = {} 91 | mismatch = False 92 | for k, v in checkpoint['state_dict'].items(): 93 | if k in model_dict: 94 | if model_dict[k].size() == v.size(): 95 | ckpt_dict[k] = v 96 | else: 97 | print('Size mismatch while loading! %s != %s Skipping %s...' 
98 | % (str(model_dict[k].size()), str(v.size()), k)) 99 | mismatch = True 100 | else: 101 | mismatch = True 102 | if len(model.state_dict().keys()) > len(ckpt_dict.keys()): 103 | mismatch = True 104 | # Overwrite parameters to model_dict 105 | model_dict.update(ckpt_dict) 106 | # Load to model 107 | model.load_state_dict(model_dict) 108 | # if size mismatch, give up on loading optimizer; if resuming from other experiment, also don't load optimizer 109 | if (not mismatch) and (optimizer is not None) and (args.resume_exp is not None): 110 | optimizer.load_state_dict(checkpoint['optimizer']) 111 | update_lr(optimizer, args.lr) 112 | if fix_loaded: 113 | for k, param in model.named_parameters(): 114 | if k in ckpt_dict.keys(): 115 | print(k) 116 | param.requires_grad = False 117 | print("loaded checkpoint %s" % load_name) 118 | del checkpoint, ckpt_dict, model_dict 119 | 120 | 121 | def save_checkpoint(state, is_best, exp_name, filename='checkpoint.pth'): 122 | """Saves checkpoint to disk""" 123 | directory = "checkpoint/%s/" % (exp_name) 124 | if not os.path.exists(directory): 125 | os.makedirs(directory) 126 | filename = directory + filename 127 | torch.save(state, filename) 128 | if is_best: 129 | shutil.copyfile(filename, 'checkpoint/%s/' % (exp_name) + 'model_best.pth') 130 | 131 | 132 | def init_lpips_eval(): 133 | LPIPS_dir = "../PerceptualSimilarity" 134 | LPIPS_net = "squeeze" 135 | sys.path.append(LPIPS_dir) 136 | from models import dist_model as dm 137 | print("Initialize Distance model from %s" % LPIPS_net) 138 | lpips_model = dm.DistModel() 139 | lpips_model.initialize(model='net-lin',net='squeeze', use_gpu=True, 140 | model_path=os.path.join(LPIPS_dir, 'weights/v0.1/%s.pth' % LPIPS_net)) 141 | return lpips_model 142 | 143 | 144 | ########################## 145 | # Evaluations 146 | ########################## 147 | 148 | class AverageMeter(object): 149 | """Computes and stores the average and current value""" 150 | def __init__(self): 151 | self.reset() 152 | 153 | def reset(self): 154 | self.val = 0 155 | self.avg = 0 156 | self.sum = 0 157 | self.count = 0 158 | 159 | def update(self, val, n=1): 160 | self.val = val 161 | self.sum += val * n 162 | self.count += n 163 | self.avg = self.sum / self.count 164 | 165 | 166 | def init_losses(loss_str): 167 | loss_specifics = {} 168 | loss_list = loss_str.split('+') 169 | for l in loss_list: 170 | _, loss_type = l.split('*') 171 | loss_specifics[loss_type] = AverageMeter() 172 | loss_specifics['total'] = AverageMeter() 173 | return loss_specifics 174 | 175 | 176 | def init_meters(loss_str): 177 | losses = init_losses(loss_str) 178 | psnrs = AverageMeter() 179 | ssims = AverageMeter() 180 | lpips = AverageMeter() 181 | return losses, psnrs, ssims, lpips 182 | 183 | 184 | def quantize(img, rgb_range=255): 185 | return img.mul(255 / rgb_range).clamp(0, 255).round() 186 | 187 | 188 | def calc_psnr(pred, gt, mask=None): 189 | ''' 190 | Here we assume quantized(0-255) arguments. 191 | ''' 192 | diff = (pred - gt).div(255) 193 | 194 | if mask is not None: 195 | mse = diff.pow(2).sum() / (3 * mask.sum()) 196 | else: 197 | mse = diff.pow(2).mean() + 1e-8 # mse can (surprisingly!) 
reach 0, which results in math domain error
198 | 
199 |     return -10 * math.log10(mse)
200 | 
201 | 
202 | def calc_ssim(img1, img2, datarange=255.):
203 |     from skimage.metrics import structural_similarity  # local import; successor of the commented-out skimage.measure.compare_ssim
204 |     im1, im2 = (x.numpy().transpose(1, 2, 0).astype(np.uint8) for x in (img1, img2))
205 |     return structural_similarity(im1, im2, data_range=datarange, multichannel=True, gaussian_weights=True)
206 | 
207 | 
208 | def calc_metrics(im_pred, im_gt, mask=None):
209 |     q_im_pred = quantize(im_pred.data, rgb_range=1.)
210 |     q_im_gt = quantize(im_gt.data, rgb_range=1.)
211 |     if mask is not None:
212 |         q_im_pred = q_im_pred * mask
213 |         q_im_gt = q_im_gt * mask
214 |     psnr = calc_psnr(q_im_pred, q_im_gt, mask=mask)
215 |     # ssim = calc_ssim(q_im_pred.cpu(), q_im_gt.cpu())
216 |     ssim = ssim_pth(q_im_pred.unsqueeze(0), q_im_gt.unsqueeze(0), val_range=255)
217 |     return psnr, ssim
218 | 
219 | 
220 | def eval_LPIPS(model, im_pred, im_gt):
221 |     im_pred = 2.0 * im_pred - 1
222 |     im_gt = 2.0 * im_gt - 1
223 |     dist = model.forward(im_pred, im_gt)[0]
224 |     return dist
225 | 
226 | 
227 | def eval_metrics(output, gt, psnrs, ssims, lpips, lpips_model=None, mask=None, psnrs_masked=None, ssims_masked=None):
228 |     # PSNR should be calculated for each image
229 |     for b in range(gt.size(0)):
230 |         psnr, ssim = calc_metrics(output[b], gt[b], None)
231 |         psnrs.update(psnr)
232 |         ssims.update(ssim)
233 |         if mask is not None:
234 |             psnr_masked, ssim_masked = calc_metrics(output[b], gt[b], mask[b])
235 |             psnrs_masked.update(psnr_masked)
236 |             ssims_masked.update(ssim_masked)
237 |         if lpips_model is not None:
238 |             _lpips = eval_LPIPS(lpips_model, output[b].unsqueeze(0), gt[b].unsqueeze(0))
239 |             lpips.update(_lpips)
240 | 
241 | 
242 | ##########################
243 | # ETC
244 | ##########################
245 | 
246 | def get_time():
247 |     return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
248 | 
249 | def makedirs(path):
250 |     if not os.path.exists(path):
251 |         print("[*] Make directories : {}".format(path))
252 |         os.makedirs(path)
253 | 
254 | def remove_file(path):
255 |     if os.path.exists(path):
256 |         print("[*] Removed: {}".format(path))
257 |         os.remove(path)
258 | 
259 | def backup_file(path):
260 |     root, ext = os.path.splitext(path)
261 |     new_path = "{}.backup_{}{}".format(root, get_time(), ext)
262 | 
263 |     os.rename(path, new_path)
264 |     print("[*] {} has backup: {}".format(path, new_path))
265 | 
266 | def update_lr(optimizer, lr):
267 |     for param_group in optimizer.param_groups:
268 |         param_group['lr'] = lr
269 | 
270 | 
271 | # TensorBoard
272 | def log_tensorboard(writer, losses, psnr, ssim, lpips, lr, timestep, mode='train'):
273 |     for k, v in losses.items():
274 |         writer.add_scalar('Loss/%s/%s' % (mode, k), v.avg, timestep)
275 |     writer.add_scalar('PSNR/%s' % mode, psnr, timestep)
276 |     writer.add_scalar('SSIM/%s' % mode, ssim, timestep)
277 |     if lpips is not None:
278 |         writer.add_scalar('LPIPS/%s' % mode, lpips, timestep)
279 |     if mode == 'train':
280 |         writer.add_scalar('lr', lr, timestep)
281 | 
282 | 
283 | ###########################
284 | ###### VISUALIZATIONS #####
285 | ###########################
286 | 
287 | def save_image(img, path):
288 |     # img : torch Tensor of size (C, H, W)
289 |     q_im = quantize(img.data.mul(255))
290 |     if len(img.size()) == 2:  # grayscale image
291 |         im = Image.fromarray(q_im.cpu().numpy().astype(np.uint8), 'L')
292 |     elif len(img.size()) == 3:
293 |         im = Image.fromarray(q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8), 'RGB')
294 |     else:
295 |         raise ValueError('save_image expects a 2-D or 3-D tensor, got size %s' % str(img.size()))
296 |     im.save(path)
297 | 
298 | def save_batch_images(output,
imgpath, save_dir, alpha=0.5): 299 | GEN = save_dir.find('-gen') >= 0 or save_dir.find('stereo') >= 0 300 | q_im_output = [quantize(o.data, rgb_range=1.) for o in output] 301 | for b in range(output[0].size(0)): 302 | paths = imgpath[0][b].split('/') 303 | if GEN: 304 | save_path = save_dir 305 | else: 306 | save_path = os.path.join(save_dir, paths[-3], paths[-2]) 307 | makedirs(save_path) 308 | for o in range(len(output)): 309 | if o % 2 == 1 or len(output) == 1: 310 | output_img = Image.fromarray(q_im_output[o][b].permute(1, 2, 0).cpu().numpy().astype(np.uint8), 'RGB') 311 | if GEN: 312 | _imgname = imgpath[o//2][b].split('/')[-1] 313 | imgname = "%s-%.04f.png" % (_imgname, alpha) 314 | else: 315 | imgname = imgpath[o//2][b].split('/')[-1] 316 | 317 | if save_dir.find('voxelflow') >= 0: 318 | #imgname = imgname.replace('gt', 'ours') 319 | imgname = 'frame_01_ours.png' 320 | elif save_dir.find('middlebury') >= 0: 321 | imgname = 'frame10i11.png' 322 | 323 | output_img.save(os.path.join(save_path, imgname)) 324 | 325 | 326 | def save_batch_images_test(output, imgpath, save_dir, alpha=0.5): 327 | GEN = save_dir.find('-gen') >= 0 or save_dir.find('stereo') >= 0 328 | q_im_output = [quantize(o.data, rgb_range=1.) for o in output] 329 | for b in range(output[0].size(0)): 330 | paths = imgpath[0][b].split('/') 331 | if GEN: 332 | save_path = save_dir 333 | else: 334 | save_path = os.path.join(save_dir, paths[-3], paths[-2]) 335 | makedirs(save_path) 336 | for o in range(len(output)): 337 | # if o % 2 == 1 or len(output) == 1: 338 | # print(" ", o, b, imgpath[o][b]) 339 | output_img = Image.fromarray(q_im_output[o][b].permute(1, 2, 0).cpu().numpy().astype(np.uint8), 'RGB') 340 | if GEN: 341 | _imgname = imgpath[o][b].split('/')[-1] 342 | imgname = "%s-%.04f.png" % (_imgname, alpha) 343 | else: 344 | imgname = imgpath[o][b].split('/')[-1] 345 | 346 | if save_dir.find('voxelflow') >= 0: 347 | #imgname = imgname.replace('gt', 'ours') 348 | imgname = 'frame_01_ours.png' 349 | elif save_dir.find('middlebury') >= 0: 350 | imgname = 'frame10i11.png' 351 | 352 | output_img.save(os.path.join(save_path, imgname)) 353 | 354 | 355 | def save_images_test(output, imgpath, save_dir, alpha=0.5): 356 | q_im_output = [quantize(o.data, rgb_range=1.) for o in output] 357 | for b in range(output[0].size(0)): 358 | paths = imgpath[1][b].split('/') 359 | save_path = os.path.join(save_dir, paths[-3], paths[-2]) 360 | makedirs(save_path) 361 | # Output length is one 362 | output_img = Image.fromarray(q_im_output[0][b].permute(1, 2, 0).cpu().numpy().astype(np.uint8), 'RGB') 363 | imgname = imgpath[1][b].split('/')[-1] 364 | 365 | # if save_dir.find('voxelflow') >= 0: 366 | # imgname = 'frame_01_ours.png' 367 | # elif save_dir.find('middlebury') >= 0: 368 | # imgname = 'frame10i11.png' 369 | 370 | output_img.save(os.path.join(save_path, imgname)) 371 | 372 | 373 | def save_images_multi(output, imgpath, save_dir, idx=1): 374 | q_im_output = [quantize(o.data, rgb_range=1.) 
for o in output]
375 |     for b in range(output[0].size(0)):
376 |         paths = imgpath[0][b].split('/')
377 |         # save_path = os.path.join(save_dir, paths[-3], paths[-2])
378 |         # makedirs(save_path)
379 |         # Output length is one
380 |         output_img = Image.fromarray(q_im_output[0][b].permute(1, 2, 0).cpu().numpy().astype(np.uint8), 'RGB')
381 |         # imgname = imgpath[idx][b].split('/')[-1]
382 |         imgname = '%s_%03d.png' % (paths[-1], idx)
383 | 
384 |         output_img.save(os.path.join(save_dir, imgname))
385 | 
386 | 
387 | def make_video(out_dir, gt_dir, gt_first=False):
388 |     gt_ext = '/*.png'
389 |     frames_all = sorted(glob.glob(out_dir + '/*.png') + glob.glob(gt_dir + gt_ext), \
390 |                         key=lambda frame: frame.split('/')[-1])
391 |     print("# of total frames : %d" % len(frames_all))
392 |     if gt_first:
393 |         print("Appending GT in front..")
394 |         frames_all = sorted(glob.glob(gt_dir + gt_ext)) + frames_all
395 |         print("# of total frames : %d" % len(frames_all))
396 | 
397 |     # Read the first image to determine height and width
398 |     frame = cv2.imread(frames_all[0])
399 |     h, w, _ = frame.shape
400 | 
401 |     # Write video
402 |     fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 'mp4v' is the FourCC matching the .mp4 container
403 |     out = cv2.VideoWriter(out_dir + '/slomo.mp4', fourcc, 30, (w, h))
404 |     for p in frames_all:
405 |         #print(p)
406 |         # TODO: add captions (e.g. 'GT', 'slow motion x4')
407 |         frame = cv2.imread(p)
408 |         fh, fw = frame.shape[:2]
409 |         #print(fh, fw, h, w)
410 |         if fh != h or fw != w:
411 |             frame = cv2.resize(frame, (w, h), interpolation=cv2.INTER_LINEAR)
412 |         out.write(frame)
413 |     out.release()  # flush and finalize the video file
414 | 
415 | def check_already_extracted(vid):
416 |     return bool(os.path.exists(vid + '/00001.png'))
--------------------------------------------------------------------------------
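A minimal usage sketch for the evaluation helpers in utils.py (an illustration, not a file from the repository). The tensors below are hypothetical stand-ins for a model's output and its ground truth, and '1*L1' matches the loss string used in run.sh:

``` python
import torch
from utils import init_meters, eval_metrics

# Hypothetical batch: 4 predicted frames and their ground truth, in [0, 1]
pred = torch.rand(4, 3, 128, 128)
gt = torch.rand(4, 3, 128, 128)

losses, psnrs, ssims, lpips = init_meters('1*L1')
eval_metrics(pred, gt, psnrs, ssims, lpips)  # lpips_model=None, so LPIPS is skipped
print('PSNR: %.2f dB / SSIM: %.4f' % (psnrs.avg, ssims.avg))
```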