├── .gitignore
├── LICENSE
├── README.md
├── _config.yml
├── config.py
├── data
│   ├── __init__.py
│   ├── snufilm.py
│   ├── video.py
│   └── vimeo90k.py
├── eval.sh
├── figures
│   ├── CAIN_AAAI20_poster.pdf
│   ├── CAIN_paper_thumb.jpg
│   ├── CAIN_spotlight_thumb.jpg
│   ├── overall_architecture.png
│   └── qualitative_vimeo.png
├── generate.py
├── loss.py
├── main.py
├── model
│   ├── __init__.py
│   ├── cain.py
│   ├── cain_encdec.py
│   ├── cain_noca.py
│   └── common.py
├── pytorch_msssim
│   └── __init__.py
├── run.sh
├── run_noca.sh
├── test_custom.sh
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore Git here
2 | .git
3 |
4 | # But not these files...
5 | # !.gitignore
6 |
7 | checkpoint/*
8 | logs/*
9 | data/vimeo_triplet
10 |
11 |
12 | # Created by .ignore support plugin (hsz.mobi)
13 | ### Python template
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | env/
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | .hypothesis/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # IPython Notebook
83 | .ipynb_checkpoints
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # dotenv
92 | .env
93 |
94 | # virtualenv
95 | venv/
96 | ENV/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 | ### VirtualEnv template
104 | # Virtualenv
105 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
106 | .Python
107 | [Bb]in
108 | [Ii]nclude
109 | [Ll]ib
110 | [Ll]ib64
111 | [Ll]ocal
112 | [Ss]cripts
113 | pyvenv.cfg
114 | .venv
115 | pip-selfcheck.json
116 | ### JetBrains template
117 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
118 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
119 |
120 | # User-specific stuff:
121 | .idea/workspace.xml
122 | .idea/tasks.xml
123 | .idea/dictionaries
124 | .idea/vcs.xml
125 | .idea/jsLibraryMappings.xml
126 |
127 | # Sensitive or high-churn files:
128 | .idea/dataSources.ids
129 | .idea/dataSources.xml
130 | .idea/dataSources.local.xml
131 | .idea/sqlDataSources.xml
132 | .idea/dynamic.xml
133 | .idea/uiDesigner.xml
134 |
135 | # Gradle:
136 | .idea/gradle.xml
137 | .idea/libraries
138 |
139 | # Mongo Explorer plugin:
140 | .idea/mongoSettings.xml
141 |
142 | .idea/
143 |
144 | ## File-based project format:
145 | *.iws
146 |
147 | ## Plugin-specific files:
148 |
149 | # IntelliJ
150 | /out/
151 |
152 | # mpeltonen/sbt-idea plugin
153 | .idea_modules/
154 |
155 | # JIRA plugin
156 | atlassian-ide-plugin.xml
157 |
158 | # Crashlytics plugin (for Android Studio and IntelliJ)
159 | com_crashlytics_export_strings.xml
160 | crashlytics.properties
161 | crashlytics-build.properties
162 | fabric.properties
163 |
164 | *.swp
165 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Myungsub Choi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Channel Attention Is All You Need for Video Frame Interpolation
2 |
3 | #### Myungsub Choi, Heewon Kim, Bohyung Han, Ning Xu, Kyoung Mu Lee
4 |
5 | #### 2nd place in [[AIM 2019 ICCV Workshop](http://www.vision.ee.ethz.ch/aim19/)] - Video Temporal Super-Resolution Challenge
6 |
7 | [Project](https://myungsub.github.io/CAIN) | [Paper-AAAI](https://aaai.org/ojs/index.php/AAAI/article/view/6693/6547) (Download the paper [[here](https://www.dropbox.com/s/b62wnroqdd5lhfc/AAAI-ChoiM.4773.pdf?dl=0)] in case the AAAI link is broken) | [Poster](https://www.dropbox.com/s/7lxwka16qkuacvh/AAAI-ChoiM.4773.pdf)
8 |
9 |
10 |
11 |
12 |
13 | ## Directory Structure
14 |
15 | ``` text
16 | project
17 | │   README.md
18 | │   run.sh - main script to train the CAIN model
19 | │   run_noca.sh - script to train the CAIN_NoCA model
20 | │   test_custom.sh - script to run interpolation on a custom dataset
21 | │   eval.sh - script to evaluate on the SNU-FILM benchmark
22 | │   main.py - main file to run train/val
23 | │   config.py - check & change training/testing configurations here
24 | │   loss.py - defines different loss functions
25 | │   utils.py - misc. utility functions
26 | └───model
27 | │   │   common.py
28 | │   │   cain.py - main model
29 | │   │   cain_noca.py - model without channel attention
30 | │   │   cain_encdec.py - model with an additional encoder-decoder
31 | └───data - implements dataloaders for each dataset
32 | │   │   vimeo90k.py - main training / testing dataset
33 | │   │   video.py - custom data for testing
34 | │   └───symbolic links to each dataset
35 | │   │   ...
36 | ```
37 |
38 | ## Dependencies
39 |
40 | The current version is tested on:
41 |
42 | - Ubuntu 18.04
43 | - Python==3.7.5
44 | - numpy==1.17
45 | - [PyTorch](http://pytorch.org/)==1.3.1, torchvision==0.4.2, cudatoolkit==10.1
46 | - tensorboard==2.0.0 (If you want training logs)
47 | - opencv==3.4.2
48 | - tqdm==4.39.0
49 |
50 | ``` text
51 | # Easy installation (using Anaconda environment)
52 | conda create -n cain
53 | conda activate cain
54 | conda install python=3.7
55 | conda install pip numpy
56 | conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
57 | conda install tqdm opencv tensorboard
58 | ```
59 |
60 | ## Model
61 |
62 |
63 |
64 | ## Dataset Preparation
65 |
66 | - We use **[Vimeo90K Triplet dataset](http://toflow.csail.mit.edu/)** for training + testing.
67 | - After downloading the full dataset, make a symbolic link in the `data/` folder:
68 |   - `ln -s /path/to/vimeo_triplet_data/ ./data/vimeo_triplet`
69 | - Then you're done!
70 | - For more thorough evaluation, we built the **[SNU-FILM (SNU Frame Interpolation with Large Motion)](https://myungsub.github.io/CAIN)** benchmark.
71 |   - Download links can be found on the [project page](https://myungsub.github.io/CAIN).
72 |   - Likewise, make a symbolic link after downloading:
73 |   - `ln -s /path/to/SNU-FILM_data/ ./data/SNU-FILM`
74 |   - Done! (The expected layout under `data/` is sketched below.)
75 |
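
For reference, the dataloaders expect the layout below, sketched here from the paths hard-coded in `data/vimeo90k.py` and `data/snufilm.py`:

``` text
data/
├── vimeo_triplet/                  # symlink -> Vimeo90K triplet data
│   ├── sequences/<seq>/<clip>/im{1,2,3}.png
│   ├── tri_trainlist.txt           # training split read by VimeoTriplet
│   └── tri_testlist.txt            # test split read by VimeoTriplet
└── SNU-FILM/                       # symlink -> SNU-FILM data
    ├── test/...                    # frame triplets
    └── test-{easy,medium,hard,extreme}.txt   # triplet lists read by SNUFILM
```
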
76 | ## Usage
77 |
78 | #### Training / Testing with Vimeo90K dataset
79 | - First, make a symbolic link in the `data/` folder: `ln -s /path/to/vimeo_triplet_data/ ./data/vimeo_triplet`
80 | - [Vimeo90K dataset](http://toflow.csail.mit.edu/)
81 | - For training: `CUDA_VISIBLE_DEVICES=0 python main.py --exp_name EXPNAME --batch_size 16 --test_batch_size 16 --dataset vimeo90k --model cain --loss 1*L1 --max_epoch 200 --lr 0.0002`
82 | - Or, just run `./run.sh`
83 | - For testing performance on Vimeo90K dataset, just add `--mode test` option
84 | - For testing on SNU-FILM dataset, run `./eval.sh`
85 |   - The test difficulty (one of ['easy', 'medium', 'hard', 'extreme']) is set via the `--test_mode` option in `eval.sh`.
86 |
87 | #### Interpolating with custom video
88 | - Download the pretrained model from [[here](https://www.dropbox.com/s/y1xf46m2cbwk7yf/pretrained_cain.pth?dl=0)]
89 | - Prepare frame sequences in `data/frame_seq`
90 | - Run `./test_custom.sh` (a sketch of a possible invocation follows below)
91 |
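
`test_custom.sh` itself is not reproduced in this listing; based on `generate.py` (which loads `pretrained_cain.pth` when `--resume` is set) and the options defined in `config.py`, a plausible sketch of the invocation is the following (the exact flags are assumptions, check the actual script):

``` text
CUDA_VISIBLE_DEVICES=0 python generate.py \
    --exp_name CAIN_custom --model cain \
    --dataset video --data_root data/frame_seq --img_fmt png \
    --batch_size 1 --test_batch_size 1 --mode test --resume
```
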
92 | ## Results
93 |
94 |
95 |
96 | ### Video
97 |
98 |
99 |
100 | ## Citation
101 |
102 | If you find this code useful for your research, please consider citing the following paper:
103 |
104 | ``` text
105 | @inproceedings{choi2020cain,
106 | author = {Choi, Myungsub and Kim, Heewon and Han, Bohyung and Xu, Ning and Lee, Kyoung Mu},
107 | title = {Channel Attention Is All You Need for Video Frame Interpolation},
108 | booktitle = {AAAI},
109 | year = {2020}
110 | }
111 | ```
112 |
113 | ## Acknowledgement
114 |
115 | Many parts of this code are adapted from:
116 |
117 | - [EDSR-Pytorch](https://github.com/thstkdgus35/EDSR-PyTorch)
118 | - [RCAN](https://github.com/yulunzhang/RCAN)
119 |
120 | We thank the authors for sharing the code for their great work.
121 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | arg_lists = []
4 | parser = argparse.ArgumentParser()
5 |
6 | def str2bool(v):
7 |     return v.lower() in ('true',)  # tuple: with a bare string, 'in' would do substring matching
8 |
9 | def add_argument_group(name):
10 | arg = parser.add_argument_group(name)
11 | arg_lists.append(arg)
12 | return arg
13 |
14 | # Dataset
15 | data_arg = add_argument_group('Dataset')
16 | data_arg.add_argument('--dataset', type=str, default='vimeo90k')
17 | data_arg.add_argument('--num_frames', type=int, default=3)
18 | data_arg.add_argument('--data_root', type=str, default='data/vimeo_triplet')
19 | data_arg.add_argument('--img_fmt', type=str, default='png')
20 |
21 | # Model
22 | model_arg = add_argument_group('Model')
23 | model_arg.add_argument('--model', type=str, default='CAIN')
24 | model_arg.add_argument('--depth', type=int, default=3, help='# of pooling')
25 | model_arg.add_argument('--n_resblocks', type=int, default=12)
26 | model_arg.add_argument('--up_mode', type=str, default='shuffle')
27 |
28 | # Training / test parameters
29 | learn_arg = add_argument_group('Learning')
30 | learn_arg.add_argument('--mode', type=str, default='train',
31 | choices=['train', 'test', 'test-multi', 'gen-multi'])
32 | learn_arg.add_argument('--loss', type=str, default='1*L1')
33 | learn_arg.add_argument('--lr', type=float, default=1e-4)
34 | learn_arg.add_argument('--beta1', type=float, default=0.9)
35 | learn_arg.add_argument('--beta2', type=float, default=0.99)
36 | learn_arg.add_argument('--batch_size', type=int, default=16)
37 | learn_arg.add_argument('--val_batch_size', type=int, default=4)
38 | learn_arg.add_argument('--test_batch_size', type=int, default=1)
39 | learn_arg.add_argument('--test_mode', type=str, default='hard', help='Test mode to evaluate on SNU-FILM dataset')
40 | learn_arg.add_argument('--start_epoch', type=int, default=0)
41 | learn_arg.add_argument('--max_epoch', type=int, default=200)
42 | learn_arg.add_argument('--resume', action='store_true')
43 | learn_arg.add_argument('--resume_exp', type=str, default=None)
44 | learn_arg.add_argument('--fix_loaded', action='store_true', help='freeze the loaded parts of the model (do not update them)')
45 |
46 | # Misc
47 | misc_arg = add_argument_group('Misc')
48 | misc_arg.add_argument('--exp_name', type=str, default='exp')
49 | misc_arg.add_argument('--log_iter', type=int, default=20)
50 | misc_arg.add_argument('--log_dir', type=str, default='logs')
51 | misc_arg.add_argument('--data_dir', type=str, default='data')
52 | misc_arg.add_argument('--num_gpu', type=int, default=1)
53 | misc_arg.add_argument('--random_seed', type=int, default=12345)
54 | misc_arg.add_argument('--num_workers', type=int, default=5)
55 | misc_arg.add_argument('--use_tensorboard', action='store_true')
56 | misc_arg.add_argument('--viz', action='store_true', help='whether to save images')
57 | misc_arg.add_argument('--lpips', action='store_true', help='evaluates LPIPS if set true')
58 |
59 | def get_args():
60 | """Parses all of the arguments above
61 | """
62 | args, unparsed = parser.parse_known_args()
63 | if args.num_gpu > 0:
64 | setattr(args, 'cuda', True)
65 | else:
66 | setattr(args, 'cuda', False)
67 |     if len(unparsed) > 0:
68 | print("Unparsed args: {}".format(unparsed))
69 | return args, unparsed
70 |
--------------------------------------------------------------------------------
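
A minimal sketch of how `config.get_args()` is consumed by the entry scripts (this mirrors the top of `main.py` and `generate.py`; the print line is illustrative only):

``` python
import config

# Unrecognized flags end up in `unparsed` and are reported, not rejected
args, unparsed = config.get_args()

# args.cuda is derived from --num_gpu inside get_args()
device = 'cuda' if args.cuda else 'cpu'
print(args.model, args.loss, device)
```
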
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/data/__init__.py
--------------------------------------------------------------------------------
/data/snufilm.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | from torchvision import transforms
6 | from PIL import Image
7 |
8 | class SNUFILM(Dataset):
9 | def __init__(self, data_root, mode='hard'):
10 | '''
11 | :param data_root: ./data/SNU-FILM
12 | :param mode: ['easy', 'medium', 'hard', 'extreme']
13 | '''
14 | test_root = os.path.join(data_root, 'test')
15 | test_fn = os.path.join(data_root, 'test-%s.txt' % mode)
16 | with open(test_fn, 'r') as f:
17 | self.frame_list = f.read().splitlines()
18 | self.frame_list = [v.split(' ') for v in self.frame_list]
19 |
20 | self.transforms = transforms.Compose([
21 | transforms.ToTensor()
22 | ])
23 |
24 | print("[%s] Test dataset has %d triplets" % (mode, len(self.frame_list)))
25 |
26 |
27 | def __getitem__(self, index):
28 |
29 |         # Paths of the three frames (first input, ground-truth middle, second input)
30 | imgpaths = self.frame_list[index]
31 |
32 | img1 = Image.open(imgpaths[0])
33 | img2 = Image.open(imgpaths[1])
34 | img3 = Image.open(imgpaths[2])
35 |
36 | img1 = self.transforms(img1)
37 | img2 = self.transforms(img2)
38 | img3 = self.transforms(img3)
39 |
40 | imgs = [img1, img2, img3]
41 |
42 | return imgs, imgpaths
43 |
44 | def __len__(self):
45 | return len(self.frame_list)
46 |
47 |
48 | def check_already_extracted(vid):
49 | return bool(os.path.exists(vid + '/0001.png'))
50 |
51 |
52 | def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode='hard'):
53 | # data_root = 'data/SNUFILM'
54 | dataset = SNUFILM(data_root, mode=test_mode)
55 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
56 |
--------------------------------------------------------------------------------
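
`SNUFILM` expects `test-<mode>.txt` to contain one triplet per line: three space-separated frame paths (first input, ground-truth middle, second input). The paths are opened as-is, so they are typically given relative to the repo root. A schematic example (filenames illustrative):

``` text
data/SNU-FILM/test/<clip>/frame_0001.png data/SNU-FILM/test/<clip>/frame_0009.png data/SNU-FILM/test/<clip>/frame_0017.png
data/SNU-FILM/test/<clip>/frame_0017.png data/SNU-FILM/test/<clip>/frame_0025.png data/SNU-FILM/test/<clip>/frame_0033.png
```
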
/data/video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import Dataset, DataLoader
6 | from torchvision import transforms
7 | from PIL import Image
8 |
9 | class Video(Dataset):
10 | def __init__(self, data_root, fmt='png'):
11 | images = sorted(glob.glob(os.path.join(data_root, '*.%s' % fmt)))
12 | for im in images:
13 | try:
14 | float_ind = float(im.split('_')[-1][:-4])
15 | except ValueError:
16 | os.rename(im, '%s_%.06f.%s' % (im[:-4], 0.0, fmt))
17 |         # Re-glob: files without a parseable float suffix were renamed above
18 | images = sorted(glob.glob(os.path.join(data_root, '*.%s' % fmt)))
19 | self.imglist = [[images[i], images[i+1]] for i in range(len(images)-1)]
20 | print('[%d] images ready to be loaded' % len(self.imglist))
21 |
22 |
23 | def __getitem__(self, index):
24 | imgpaths = self.imglist[index]
25 |
26 | # Load images
27 | img1 = Image.open(imgpaths[0])
28 | img2 = Image.open(imgpaths[1])
29 |
30 | T = transforms.ToTensor()
31 | img1 = T(img1)
32 | img2 = T(img2)
33 |
34 | imgs = [img1, img2]
35 | meta = {'imgpath': imgpaths}
36 | return imgs, meta
37 |
38 | def __len__(self):
39 | return len(self.imglist)
40 |
41 |
42 | def get_loader(mode, data_root, batch_size, img_fmt='png', shuffle=False, num_workers=0, n_frames=1):
43 | if mode == 'train':
44 | is_training = True
45 | else:
46 | is_training = False
47 | dataset = Video(data_root, fmt=img_fmt)
48 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
49 |
--------------------------------------------------------------------------------
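
`Video` encodes each frame's temporal position as a float suffix, `<name>_<t>.<fmt>`; files without a parseable suffix are renamed with `t = 0.000000`. `generate.py` then saves each interpolated frame at the midpoint index `(i1 + i2) / 2`, treating `i2` as `1.0` when the second frame's index is `0.0` (i.e. the start of the next original pair). Filenames below are illustrative:

``` text
frame_0.000000.png  +  frame_1.000000.png   ->  frame_0.500000.png
frame_0.000000.png  +  frame_0.500000.png   ->  frame_0.250000.png   (second pass)
```
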
/data/vimeo90k.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | from torchvision import transforms
6 | from PIL import Image
7 | import random
8 |
9 | class VimeoTriplet(Dataset):
10 | def __init__(self, data_root, is_training):
11 | self.data_root = data_root
12 | self.image_root = os.path.join(self.data_root, 'sequences')
13 | self.training = is_training
14 |
15 | train_fn = os.path.join(self.data_root, 'tri_trainlist.txt')
16 | test_fn = os.path.join(self.data_root, 'tri_testlist.txt')
17 | with open(train_fn, 'r') as f:
18 | self.trainlist = f.read().splitlines()
19 | with open(test_fn, 'r') as f:
20 | self.testlist = f.read().splitlines()
21 |
22 | self.transforms = transforms.Compose([
23 | transforms.RandomCrop(256),
24 | transforms.RandomHorizontalFlip(0.5),
25 | transforms.RandomVerticalFlip(0.5),
26 | transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
27 | transforms.ToTensor()
28 | ])
29 |
30 |
31 | def __getitem__(self, index):
32 | if self.training:
33 | imgpath = os.path.join(self.image_root, self.trainlist[index])
34 | else:
35 | imgpath = os.path.join(self.image_root, self.testlist[index])
36 | imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png']
37 |
38 | # Load images
39 | img1 = Image.open(imgpaths[0])
40 | img2 = Image.open(imgpaths[1])
41 | img3 = Image.open(imgpaths[2])
42 |
43 | # Data augmentation
44 | if self.training:
45 | seed = random.randint(0, 2**32)
46 | random.seed(seed)
47 | img1 = self.transforms(img1)
48 | random.seed(seed)
49 | img2 = self.transforms(img2)
50 | random.seed(seed)
51 | img3 = self.transforms(img3)
52 | # Random Temporal Flip
53 | if random.random() >= 0.5:
54 | img1, img3 = img3, img1
55 | imgpaths[0], imgpaths[2] = imgpaths[2], imgpaths[0]
56 | else:
57 | T = transforms.ToTensor()
58 | img1 = T(img1)
59 | img2 = T(img2)
60 | img3 = T(img3)
61 |
62 | imgs = [img1, img2, img3]
63 |
64 | return imgs, imgpaths
65 |
66 | def __len__(self):
67 | if self.training:
68 | return len(self.trainlist)
69 | else:
70 | return len(self.testlist)
71 |
72 |
73 |
74 | def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode=None):
75 | if mode == 'train':
76 | is_training = True
77 | else:
78 | is_training = False
79 | dataset = VimeoTriplet(data_root, is_training=is_training)
80 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
81 |
--------------------------------------------------------------------------------
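
`VimeoTriplet` applies the same random crop/flip/jitter to all three frames by reseeding Python's `random` module before each `self.transforms` call. A minimal standalone sketch of this pattern (not part of the repo; note it relies on torchvision 0.4-era transforms drawing their randomness from `random`, whereas newer torchvision uses torch's RNG):

``` python
import random
from torchvision import transforms

t = transforms.Compose([
    transforms.RandomCrop(256),
    transforms.RandomHorizontalFlip(0.5),
])

def paired_transform(img1, img2):
    seed = random.randint(0, 2**32)
    random.seed(seed)   # same seed before each call ...
    out1 = t(img1)
    random.seed(seed)   # ... so both frames see identical crop offsets and flips
    out2 = t(img2)
    return out1, out2
```
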
/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=1 python main.py \
4 | --exp_name CAIN_eval \
5 | --dataset snufilm \
6 | --data_root data/SNU-FILM \
7 | --test_batch_size 1 \
8 | --model cain \
9 | --depth 3 \
10 | --mode test \
11 | --resume \
12 | --resume_exp CAIN_train \
13 | --test_mode hard
--------------------------------------------------------------------------------
/figures/CAIN_AAAI20_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/figures/CAIN_AAAI20_poster.pdf
--------------------------------------------------------------------------------
/figures/CAIN_paper_thumb.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/figures/CAIN_paper_thumb.jpg
--------------------------------------------------------------------------------
/figures/CAIN_spotlight_thumb.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/figures/CAIN_spotlight_thumb.jpg
--------------------------------------------------------------------------------
/figures/overall_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/figures/overall_architecture.png
--------------------------------------------------------------------------------
/figures/qualitative_vimeo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/figures/qualitative_vimeo.png
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import copy
5 | import shutil
6 | import random
7 |
8 | import torch
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 | import config
13 | import utils
14 |
15 |
16 | ##### Parse CmdLine Arguments #####
17 | args, unparsed = config.get_args()
18 | cwd = os.getcwd()
19 | print(args)
20 |
21 |
22 | device = torch.device('cuda' if args.cuda else 'cpu')
23 | torch.backends.cudnn.enabled = True
24 | torch.backends.cudnn.benchmark = True
25 |
26 | torch.manual_seed(args.random_seed)
27 | if args.cuda:
28 | torch.cuda.manual_seed(args.random_seed)
29 |
30 |
31 |
32 |
33 | ##### Build Model #####
34 | if args.model.lower() == 'cain_encdec':
35 | from model.cain_encdec import CAIN_EncDec
36 | print('Building model: CAIN_EncDec')
37 | model = CAIN_EncDec(depth=args.depth, start_filts=32)
38 | elif args.model.lower() == 'cain':
39 | from model.cain import CAIN
40 | print("Building model: CAIN")
41 | model = CAIN(depth=args.depth)
42 | elif args.model.lower() == 'cain_noca':
43 | from model.cain_noca import CAIN_NoCA
44 | print("Building model: CAIN_NoCA")
45 | model = CAIN_NoCA(depth=args.depth)
46 | else:
47 | raise NotImplementedError("Unknown model!")
48 | # Just make every model to DataParallel
49 | model = torch.nn.DataParallel(model).to(device)
50 | #print(model)
51 |
52 | print('# of parameters: %d' % sum(p.numel() for p in model.parameters()))
53 |
54 |
55 | # If resume, load checkpoint: model
56 | if args.resume:
57 | #utils.load_checkpoint(args, model, optimizer=None)
58 | checkpoint = torch.load('pretrained_cain.pth')
59 | args.start_epoch = checkpoint['epoch'] + 1
60 | model.load_state_dict(checkpoint['state_dict'])
61 | del checkpoint
62 |
63 |
64 |
65 | def test(args, epoch):
66 | print('Evaluating for epoch = %d' % epoch)
67 | ##### Load Dataset #####
68 | test_loader = utils.load_dataset(
69 | args.dataset, args.data_root, args.batch_size, args.test_batch_size, args.num_workers, img_fmt=args.img_fmt)
70 | model.eval()
71 |
72 | t = time.time()
73 | with torch.no_grad():
74 | for i, (images, meta) in enumerate(tqdm(test_loader)):
75 |
76 | # Build input batch
77 | im1, im2 = images[0].to(device), images[1].to(device)
78 |
79 | # Forward
80 | out, _ = model(im1, im2)
81 |
82 | # Save result images
83 | if args.mode == 'test':
84 | for b in range(images[0].size(0)):
85 | paths = meta['imgpath'][0][b].split('/')
86 | fp = args.data_root
87 | fp = os.path.join(fp, paths[-1][:-4]) # remove '.png' extension
88 |
89 | # Decide float index
90 | i1_str = paths[-1][:-4]
91 | i2_str = meta['imgpath'][1][b].split('/')[-1][:-4]
92 | try:
93 | i1 = float(i1_str.split('_')[-1])
94 | except ValueError:
95 | i1 = 0.0
96 | try:
97 | i2 = float(i2_str.split('_')[-1])
98 | if i2 == 0.0:
99 | i2 = 1.0
100 | except ValueError:
101 | i2 = 1.0
102 | fpos = max(0, fp.rfind('_'))
103 | fInd = (i1 + i2) / 2
104 | savepath = "%s_%06f.%s" % (fp[:fpos], fInd, args.img_fmt)
105 | utils.save_image(out[b], savepath)
106 |
107 | # Print progress
108 | print('im_processed: {:d}/{:d} {:.3f}s \r'.format(i + 1, len(test_loader), time.time() - t))
109 |
110 | return
111 |
112 |
113 | """ Entry Point """
114 | def main(args):
115 |
116 |     num_iter = 2  # each pass doubles the frame rate => 2**num_iter x interpolation
117 | for _ in range(num_iter):
118 |
119 | # run test
120 | test(args, args.start_epoch)
121 |
122 |
123 | if __name__ == "__main__":
124 | main(args)
125 |
--------------------------------------------------------------------------------
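
Each pass of `test()` writes one midpoint frame for every adjacent pair back into `args.data_root`, so the next iteration of the loop in `main` re-globs the enlarged sequence. With `num_iter = 2`:

``` text
pass 1: N frames      ->  2N - 1 frames   (one midpoint per adjacent pair)
pass 2: 2N - 1 frames ->  4N - 3 frames   (~4x the original frame rate)
```
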
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.models as models
5 | import pytorch_msssim
6 | from model.common import sub_mean, InOutPaddings, meanShift, PixelShuffle, ResidualGroup, conv
7 |
8 | class MeanShift(nn.Conv2d):
9 | def __init__(self, rgb_mean, rgb_std, sign=-1):
10 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
11 | std = torch.Tensor(rgb_std)
12 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
13 | self.weight.data.div_(std.view(3, 1, 1, 1))
14 | self.bias.data = sign * torch.Tensor(rgb_mean)
15 | self.bias.data.div_(std)
16 |         for p in self.parameters(): p.requires_grad = False  # freeze; assigning self.requires_grad has no effect on parameters
17 |
18 |
19 | class VGG(nn.Module):
20 | def __init__(self, loss_type):
21 | super(VGG, self).__init__()
22 | vgg_features = models.vgg19(pretrained=True).features
23 | modules = [m for m in vgg_features]
24 | conv_index = loss_type[-2:]
25 | if conv_index == '22':
26 | self.vgg = nn.Sequential(*modules[:8])
27 | elif conv_index == '33':
28 | self.vgg = nn.Sequential(*modules[:16])
29 | elif conv_index == '44':
30 | self.vgg = nn.Sequential(*modules[:26])
31 | elif conv_index == '54':
32 | self.vgg = nn.Sequential(*modules[:35])
33 | elif conv_index == 'P':
34 | self.vgg = nn.ModuleList([
35 | nn.Sequential(*modules[:8]),
36 | nn.Sequential(*modules[8:16]),
37 | nn.Sequential(*modules[16:26]),
38 | nn.Sequential(*modules[26:35])
39 | ])
40 | self.vgg = nn.DataParallel(self.vgg).cuda()
41 |
42 | vgg_mean = (0.485, 0.456, 0.406)
43 | vgg_std = (0.229, 0.224, 0.225)
44 | self.sub_mean = MeanShift(vgg_mean, vgg_std)
45 |         for p in self.vgg.parameters(): p.requires_grad = False  # freeze VGG; attribute assignment alone would not do this
46 | # self.criterion = nn.L1Loss()
47 | self.conv_index = conv_index
48 |
49 | def forward(self, sr, hr):
50 | def _forward(x):
51 | x = self.sub_mean(x)
52 | x = self.vgg(x)
53 | return x
54 | def _forward_all(x):
55 | feats = []
56 | x = self.sub_mean(x)
57 | for module in self.vgg.module:
58 | x = module(x)
59 | feats.append(x)
60 | return feats
61 |
62 | if self.conv_index == 'P':
63 | vgg_sr_feats = _forward_all(sr)
64 | with torch.no_grad():
65 | vgg_hr_feats = _forward_all(hr.detach())
66 | loss = 0
67 | for i in range(len(vgg_sr_feats)):
68 | loss_f = F.mse_loss(vgg_sr_feats[i], vgg_hr_feats[i])
69 | #print(loss_f)
70 | loss += loss_f
71 | #print()
72 | else:
73 | vgg_sr = _forward(sr)
74 | with torch.no_grad():
75 | vgg_hr = _forward(hr.detach())
76 | loss = F.mse_loss(vgg_sr, vgg_hr)
77 |
78 | return loss
79 |
80 |
81 | # For Adversarial loss
82 | class BasicBlock(nn.Sequential):
83 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True)):
84 | m = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), stride=stride, bias=bias)]
85 | if bn: m.append(nn.BatchNorm2d(out_channels))
86 | if act is not None: m.append(act)
87 | super(BasicBlock, self).__init__(*m)
88 |
89 | class Discriminator(nn.Module):
90 | def __init__(self, args, gan_type='GAN'):
91 | super(Discriminator, self).__init__()
92 |
93 | in_channels = 3
94 | out_channels = 64
95 | depth = 7
96 | #bn = not gan_type == 'WGAN_GP'
97 | bn = True
98 | act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
99 |
100 | m_features = [
101 | BasicBlock(in_channels, out_channels, 3, bn=bn, act=act)
102 | ]
103 | for i in range(depth):
104 | in_channels = out_channels
105 | if i % 2 == 1:
106 | stride = 1
107 | out_channels *= 2
108 | else:
109 | stride = 2
110 | m_features.append(BasicBlock(
111 | in_channels, out_channels, 3, stride=stride, bn=bn, act=act
112 | ))
113 |
114 | self.features = nn.Sequential(*m_features)
115 |
116 |         self.patch_size = args.patch_size  # NOTE: config.py does not define --patch_size; add it before using GAN losses
117 | feature_patch_size = self.patch_size // (2**((depth + 1) // 2))
118 | #patch_size = 256 // (2**((depth + 1) // 2))
119 | m_classifier = [
120 | nn.Linear(out_channels * feature_patch_size**2, 1024),
121 | act,
122 | nn.Linear(1024, 1)
123 | ]
124 | self.classifier = nn.Sequential(*m_classifier)
125 |
126 | def forward(self, x):
127 | if x.size(2) != self.patch_size or x.size(3) != self.patch_size:
128 | midH, midW = x.size(2) // 2, x.size(3) // 2
129 | p = self.patch_size // 2
130 | x = x[:, :, (midH - p):(midH - p + self.patch_size), (midW - p):(midW - p + self.patch_size)]
131 | features = self.features(x)
132 | output = self.classifier(features.view(features.size(0), -1))
133 |
134 | return output
135 |
136 |
137 | import torch.optim as optim
138 | class Adversarial(nn.Module):
139 | def __init__(self, args, gan_type):
140 | super(Adversarial, self).__init__()
141 | self.gan_type = gan_type
142 | self.gan_k = 1 #args.gan_k
143 | self.discriminator = torch.nn.DataParallel(Discriminator(args, gan_type))
144 | if gan_type != 'WGAN_GP':
145 | self.optimizer = optim.Adam(
146 | self.discriminator.parameters(),
147 | betas=(0.9, 0.99), eps=1e-8, lr=1e-4
148 | )
149 | else:
150 | self.optimizer = optim.Adam(
151 | self.discriminator.parameters(),
152 | betas=(0, 0.9), eps=1e-8, lr=1e-5
153 | )
154 | # self.scheduler = utility.make_scheduler(args, self.optimizer)
155 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
156 | self.optimizer, mode='min', factor=0.5, patience=3, verbose=True)
157 |
158 | def forward(self, fake, real, fake_input0=None, fake_input1=None, fake_input_mean=None):
159 | # def forward(self, fake, real):
160 | fake_detach = fake.detach()
161 | if fake_input0 is not None:
162 | fake0, fake1 = fake_input0.detach(), fake_input1.detach()
163 | if fake_input_mean is not None:
164 | fake_m = fake_input_mean.detach()
165 | # print(fake.size(), fake_input0.size(), fake_input1.size(), fake_input_mean.size())
166 |
167 | self.loss = 0
168 | for _ in range(self.gan_k):
169 | self.optimizer.zero_grad()
170 | d_fake = self.discriminator(fake_detach)
171 |
172 | if fake_input0 is not None and fake_input1 is not None:
173 | d_fake0 = self.discriminator(fake0)
174 | d_fake1 = self.discriminator(fake1)
175 | if fake_input_mean is not None:
176 | d_fake_m = self.discriminator(fake_m)
177 |
178 | # print(d_fake.size(), d_fake0.size(), d_fake1.size(), d_fake_m.size())
179 |
180 | d_real = self.discriminator(real)
181 | if self.gan_type == 'GAN':
182 | label_fake = torch.zeros_like(d_fake)
183 | label_real = torch.ones_like(d_real)
184 | loss_d \
185 | = F.binary_cross_entropy_with_logits(d_fake, label_fake) \
186 | + F.binary_cross_entropy_with_logits(d_real, label_real)
187 | if fake_input0 is not None and fake_input1 is not None:
188 | loss_d += F.binary_cross_entropy_with_logits(d_fake0, label_fake) \
189 | + F.binary_cross_entropy_with_logits(d_fake1, label_fake)
190 | if fake_input_mean is not None:
191 | loss_d += F.binary_cross_entropy_with_logits(d_fake_m, label_fake)
192 |
193 | elif self.gan_type.find('WGAN') >= 0:
194 | loss_d = (d_fake - d_real).mean()
195 | if self.gan_type.find('GP') >= 0:
196 |                     epsilon = torch.rand(fake.size(0), 1, 1, 1, device=fake.device)  # one mixing weight per sample
197 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
198 | hat.requires_grad = True
199 | d_hat = self.discriminator(hat)
200 | gradients = torch.autograd.grad(
201 | outputs=d_hat.sum(), inputs=hat,
202 | retain_graph=True, create_graph=True, only_inputs=True
203 | )[0]
204 | gradients = gradients.view(gradients.size(0), -1)
205 | gradient_norm = gradients.norm(2, dim=1)
206 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
207 | loss_d += gradient_penalty
208 |
209 | # Discriminator update
210 | self.loss += loss_d.item()
211 | if self.training:
212 | loss_d.backward()
213 | self.optimizer.step()
214 |
215 | if self.gan_type == 'WGAN':
216 | for p in self.discriminator.parameters():
217 | p.data.clamp_(-1, 1)
218 |
219 | self.loss /= self.gan_k
220 |
221 | d_fake_for_g = self.discriminator(fake)
222 | if self.gan_type == 'GAN':
223 | loss_g = F.binary_cross_entropy_with_logits(
224 | d_fake_for_g, label_real
225 | )
226 | elif self.gan_type.find('WGAN') >= 0:
227 | loss_g = -d_fake_for_g.mean()
228 |
229 | # Generator loss
230 | return loss_g
231 |
232 | def state_dict(self, *args, **kwargs):
233 | state_discriminator = self.discriminator.state_dict(*args, **kwargs)
234 | state_optimizer = self.optimizer.state_dict()
235 |
236 | return dict(**state_discriminator, **state_optimizer)
237 |
238 |
239 | # Some references
240 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
241 | # OR
242 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
243 |
244 |
245 | # Wrapper of loss functions
246 | class Loss(nn.modules.loss._Loss):
247 | def __init__(self, args):
248 | super(Loss, self).__init__()
249 | print('Preparing loss function:')
250 |
251 | self.loss = []
252 | self.loss_module = nn.ModuleList()
253 | for loss in args.loss.split('+'):
254 | weight, loss_type = loss.split('*')
255 | if loss_type == 'MSE':
256 | loss_function = nn.MSELoss()
257 | elif loss_type == 'L1':
258 | loss_function = nn.L1Loss()
259 | elif loss_type.find('VGG') >= 0:
260 | loss_function = VGG(loss_type[3:])
261 | elif loss_type == 'SSIM':
262 | loss_function = pytorch_msssim.SSIM(val_range=1.)
263 | elif loss_type.find('GAN') >= 0:
264 | loss_function = Adversarial(args, loss_type)
265 |
266 | self.loss.append({
267 | 'type': loss_type,
268 | 'weight': float(weight),
269 | 'function': loss_function}
270 | )
271 |             if loss_type.find('GAN') >= 0:
272 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
273 |
274 | if len(self.loss) > 1:
275 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
276 |
277 | for l in self.loss:
278 | if l['function'] is not None:
279 | print('{:.3f} * {}'.format(l['weight'], l['type']))
280 | self.loss_module.append(l['function'])
281 |
282 | device = torch.device('cuda' if args.cuda else 'cpu')
283 | self.loss_module.to(device)
284 | #if args.precision == 'half': self.loss_module.half()
285 | if args.cuda:# and args.n_GPUs > 1:
286 | self.loss_module = nn.DataParallel(self.loss_module)
287 |
288 |
289 | def forward(self, sr, hr, model_enc=None, feats=None, fake_imgs=None):
290 | loss = 0
291 | losses = {}
292 | for i, l in enumerate(self.loss):
293 | if l['function'] is not None:
294 | if l['type'] == 'GAN':
295 | if fake_imgs is None:
296 | fake_imgs = [None, None, None]
297 | _loss = l['function'](sr, hr, fake_imgs[0], fake_imgs[1], fake_imgs[2])
298 | else:
299 | _loss = l['function'](sr, hr)
300 | effective_loss = l['weight'] * _loss
301 | losses[l['type']] = effective_loss
302 | loss += effective_loss
303 | elif l['type'] == 'DIS':
304 | losses[l['type']] = self.loss[i - 1]['function'].loss
305 |
306 | #loss_sum = sum(losses)
307 | #if len(self.loss) > 1:
308 | # self.log[-1, -1] += loss_sum.item()
309 |
310 | return loss, losses
311 |
--------------------------------------------------------------------------------
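
`Loss` parses `args.loss` as `+`-separated `weight*type` terms, where `type` is one of `MSE`, `L1`, `VGG<layer>`, `SSIM`, or a `GAN` variant. Some example specs (the weights here are illustrative, only `1*L1` is the repo default):

``` text
--loss 1*L1                    # plain L1 (default)
--loss 1*L1+0.1*VGG22          # L1 plus a VGG relu2_2 perceptual term
--loss 1*L1+0.05*GAN           # L1 plus adversarial loss (also logs a 'DIS' entry)
```
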
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import copy
5 | import shutil
6 | import random
7 |
8 | import torch
9 | import numpy as np
10 | from tqdm import tqdm
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 | import config
14 | import utils
15 | from loss import Loss
16 |
17 |
18 | ##### Parse CmdLine Arguments #####
19 | args, unparsed = config.get_args()
20 | cwd = os.getcwd()
21 | print(args)
22 |
23 |
24 | ##### TensorBoard & Misc Setup #####
25 | if args.mode != 'test':
26 | writer = SummaryWriter('logs/%s' % args.exp_name)
27 |
28 | device = torch.device('cuda' if args.cuda else 'cpu')
29 | torch.backends.cudnn.enabled = True
30 | torch.backends.cudnn.benchmark = True
31 |
32 | torch.manual_seed(args.random_seed)
33 | if args.cuda:
34 | torch.cuda.manual_seed(args.random_seed)
35 |
36 |
37 | ##### Load Dataset #####
38 | train_loader, test_loader = utils.load_dataset(
39 | args.dataset, args.data_root, args.batch_size, args.test_batch_size, args.num_workers, args.test_mode)
40 |
41 |
42 | ##### Build Model #####
43 | if args.model.lower() == 'cain_encdec':
44 | from model.cain_encdec import CAIN_EncDec
45 | print('Building model: CAIN_EncDec')
46 | model = CAIN_EncDec(depth=args.depth, start_filts=32)
47 | elif args.model.lower() == 'cain':
48 | from model.cain import CAIN
49 | print("Building model: CAIN")
50 | model = CAIN(depth=args.depth)
51 | elif args.model.lower() == 'cain_noca':
52 | from model.cain_noca import CAIN_NoCA
53 | print("Building model: CAIN_NoCA")
54 | model = CAIN_NoCA(depth=args.depth)
55 | else:
56 | raise NotImplementedError("Unknown model!")
57 | # Just make every model to DataParallel
58 | model = torch.nn.DataParallel(model).to(device)
59 | #print(model)
60 |
61 | ##### Define Loss & Optimizer #####
62 | criterion = Loss(args)
63 |
64 | args.radam = False  # set True to use RAdam (requires the radam package)
65 | if args.radam:
66 | from radam import RAdam
67 | optimizer = RAdam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
68 | else:
69 | from torch.optim import Adam
70 | optimizer = Adam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
71 | print('# of parameters: %d' % sum(p.numel() for p in model.parameters()))
72 |
73 |
74 | # If resume, load checkpoint: model + optimizer
75 | if args.resume:
76 | utils.load_checkpoint(args, model, optimizer)
77 |
78 | # Learning Rate Scheduler
79 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
80 | optimizer, mode='min', factor=0.5, patience=5, verbose=True)
81 |
82 |
83 | # Initialize LPIPS model if used for evaluation
84 | # lpips_model = utils.init_lpips_eval() if args.lpips else None
85 | lpips_model = None
86 |
87 | LOSS_0 = 0
88 |
89 |
90 | def train(args, epoch):
91 | global LOSS_0
92 | losses, psnrs, ssims, lpips = utils.init_meters(args.loss)
93 | model.train()
94 | criterion.train()
95 |
96 | t = time.time()
97 | for i, (images, imgpaths) in enumerate(train_loader):
98 |
99 | # Build input batch
100 | im1, im2, gt = utils.build_input(images, imgpaths)
101 |
102 | # Forward
103 | optimizer.zero_grad()
104 | out, feats = model(im1, im2)
105 | loss, loss_specific = criterion(out, gt, None, feats)
106 |
107 | # Save loss values
108 | for k, v in losses.items():
109 | if k != 'total':
110 | v.update(loss_specific[k].item())
111 | if LOSS_0 == 0:
112 | LOSS_0 = loss.data.item()
113 | losses['total'].update(loss.item())
114 |
115 | # Backward (+ grad clip) - if loss explodes, skip current iteration
116 | loss.backward()
117 | if loss.data.item() > 10.0 * LOSS_0:
118 | print(max(p.grad.data.abs().max() for p in model.parameters()))
119 | continue
120 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
121 | optimizer.step()
122 |
123 | # Calc metrics & print logs
124 | if i % args.log_iter == 0:
125 | utils.eval_metrics(out, gt, psnrs, ssims, lpips, lpips_model)
126 |
127 | print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}\tPSNR: {:.4f}\tTime({:.2f})'.format(
128 | epoch, i, len(train_loader), losses['total'].avg, psnrs.avg, time.time() - t))
129 |
130 | # Log to TensorBoard
131 | utils.log_tensorboard(writer, losses, psnrs.avg, ssims.avg, lpips.avg,
132 | optimizer.param_groups[-1]['lr'], epoch * len(train_loader) + i)
133 |
134 | # Reset metrics
135 | losses, psnrs, ssims, lpips = utils.init_meters(args.loss)
136 | t = time.time()
137 |
138 |
139 | def test(args, epoch, eval_alpha=0.5):
140 | print('Evaluating for epoch = %d' % epoch)
141 | losses, psnrs, ssims, lpips = utils.init_meters(args.loss)
142 | model.eval()
143 | criterion.eval()
144 |
145 | save_folder = 'test%03d' % epoch
146 | if args.dataset == 'snufilm':
147 | save_folder = os.path.join(save_folder, args.dataset, args.test_mode)
148 | else:
149 | save_folder = os.path.join(save_folder, args.dataset)
150 | save_dir = os.path.join('checkpoint', args.exp_name, save_folder)
151 | utils.makedirs(save_dir)
152 | save_fn = os.path.join(save_dir, 'results.txt')
153 | if not os.path.exists(save_fn):
154 | with open(save_fn, 'w') as f:
155 | f.write('For epoch=%d\n' % epoch)
156 |
157 | t = time.time()
158 | with torch.no_grad():
159 | for i, (images, imgpaths) in enumerate(tqdm(test_loader)):
160 |
161 | # Build input batch
162 | im1, im2, gt = utils.build_input(images, imgpaths, is_training=False)
163 |
164 | # Forward
165 | out, feats = model(im1, im2)
166 |
167 | # Save loss values
168 | loss, loss_specific = criterion(out, gt, None, feats)
169 | for k, v in losses.items():
170 | if k != 'total':
171 | v.update(loss_specific[k].item())
172 | losses['total'].update(loss.item())
173 |
174 | # Evaluate metrics
175 | utils.eval_metrics(out, gt, psnrs, ssims, lpips)
176 |
177 | # Log examples that have bad performance
178 | if (ssims.val < 0.9 or psnrs.val < 25) and epoch > 50:
179 | print(imgpaths)
180 | print("\nLoss: %f, PSNR: %f, SSIM: %f, LPIPS: %f" %
181 | (losses['total'].val, psnrs.val, ssims.val, lpips.val))
182 | print(imgpaths[1][-1])
183 |
184 | # Save result images
185 | if ((epoch + 1) % 1 == 0 and i < 20) or args.mode == 'test':
186 | savepath = os.path.join('checkpoint', args.exp_name, save_folder)
187 |
188 | for b in range(images[0].size(0)):
189 | paths = imgpaths[1][b].split('/')
190 | fp = os.path.join(savepath, paths[-3], paths[-2])
191 | if not os.path.exists(fp):
192 | os.makedirs(fp)
193 | # remove '.png' extension
194 | fp = os.path.join(fp, paths[-1][:-4])
195 | utils.save_image(out[b], "%s.png" % fp)
196 |
197 | # Print progress
198 | print('im_processed: {:d}/{:d} {:.3f}s \r'.format(i + 1, len(test_loader), time.time() - t))
199 | print("Loss: %f, PSNR: %f, SSIM: %f, LPIPS: %f\n" %
200 | (losses['total'].avg, psnrs.avg, ssims.avg, lpips.avg))
201 |
202 | # Save psnr & ssim
203 | save_fn = os.path.join('checkpoint', args.exp_name, save_folder, 'results.txt')
204 | with open(save_fn, 'a') as f:
205 | f.write("PSNR: %f, SSIM: %f, LPIPS: %f\n" %
206 | (psnrs.avg, ssims.avg, lpips.avg))
207 |
208 | # Log to TensorBoard
209 | if args.mode != 'test':
210 | utils.log_tensorboard(writer, losses, psnrs.avg, ssims.avg, lpips.avg,
211 | optimizer.param_groups[-1]['lr'], epoch * len(train_loader) + i, mode='test')
212 |
213 | return losses['total'].avg, psnrs.avg, ssims.avg, lpips.avg
214 |
215 |
216 | """ Entry Point """
217 | def main(args):
218 | if args.mode == 'test':
219 | _, _, _, _ = test(args, args.start_epoch)
220 | return
221 |
222 | best_psnr = 0
223 | for epoch in range(args.start_epoch, args.max_epoch):
224 |
225 | # run training
226 | train(args, epoch)
227 |
228 | # run test
229 | test_loss, psnr, _, _ = test(args, epoch)
230 |
231 | # save checkpoint
232 | is_best = psnr > best_psnr
233 | best_psnr = max(psnr, best_psnr)
234 | utils.save_checkpoint({
235 | 'epoch': epoch,
236 | 'state_dict': model.state_dict(),
237 | 'optimizer': optimizer.state_dict(),
238 | 'best_psnr': best_psnr
239 | }, is_best, args.exp_name)
240 |
241 | # update optimizer policy
242 | scheduler.step(test_loss)
243 |
244 | if __name__ == "__main__":
245 | main(args)
246 |
--------------------------------------------------------------------------------
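
Checkpoints are saved per `--exp_name`; to continue training or evaluate a saved experiment, pass `--resume` (and `--resume_exp` when the checkpoint lives under a different experiment name), as `eval.sh` does. A sketch, assuming a prior `CAIN_train` run:

``` text
python main.py --exp_name CAIN_train --model cain --resume                               # resume the same experiment
python main.py --exp_name CAIN_eval  --model cain --resume --resume_exp CAIN_train --mode test
```
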
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myungsub/CAIN/2e727d2a07d3f1061f17e2edaa47a7fb3f7e62c5/model/__init__.py
--------------------------------------------------------------------------------
/model/cain.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .common import *
8 |
9 |
10 | class Encoder(nn.Module):
11 | def __init__(self, in_channels=3, depth=3):
12 | super(Encoder, self).__init__()
13 |
14 | # Shuffle pixels to expand in channel dimension
15 | # shuffler_list = [PixelShuffle(0.5) for i in range(depth)]
16 | # self.shuffler = nn.Sequential(*shuffler_list)
17 | self.shuffler = PixelShuffle(1 / 2**depth)
18 |
19 | relu = nn.LeakyReLU(0.2, True)
20 |
21 | # FF_RCAN or FF_Resblocks
22 | self.interpolate = Interpolation(5, 12, in_channels * (4**depth), act=relu)
23 |
24 | def forward(self, x1, x2):
25 | """
26 | Encoder: Shuffle-spread --> Feature Fusion --> Return fused features
27 | """
28 | feats1 = self.shuffler(x1)
29 | feats2 = self.shuffler(x2)
30 |
31 | feats = self.interpolate(feats1, feats2)
32 |
33 | return feats
34 |
35 |
36 | class Decoder(nn.Module):
37 | def __init__(self, depth=3):
38 | super(Decoder, self).__init__()
39 |
40 | # shuffler_list = [PixelShuffle(2) for i in range(depth)]
41 | # self.shuffler = nn.Sequential(*shuffler_list)
42 | self.shuffler = PixelShuffle(2**depth)
43 |
44 | def forward(self, feats):
45 | out = self.shuffler(feats)
46 | return out
47 |
48 |
49 | class CAIN(nn.Module):
50 | def __init__(self, depth=3):
51 | super(CAIN, self).__init__()
52 |
53 | self.encoder = Encoder(in_channels=3, depth=depth)
54 | self.decoder = Decoder(depth=depth)
55 |
56 | def forward(self, x1, x2):
57 | x1, m1 = sub_mean(x1)
58 | x2, m2 = sub_mean(x2)
59 |
60 | if not self.training:
61 | paddingInput, paddingOutput = InOutPaddings(x1)
62 | x1 = paddingInput(x1)
63 | x2 = paddingInput(x2)
64 |
65 | feats = self.encoder(x1, x2)
66 | out = self.decoder(feats)
67 |
68 | if not self.training:
69 | out = paddingOutput(out)
70 |
71 | mi = (m1 + m2) / 2
72 | out += mi
73 |
74 | return out, feats
75 |
--------------------------------------------------------------------------------
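
With `depth=3`, the encoder's `PixelShuffle(1/8)` trades spatial resolution for channels, so the fusion/attention blocks operate on a compact representation and the decoder simply inverts the packing. A worked shape example for RGB input:

``` text
input    x1: (B,   3, H,   W  )
shuffle   ->: (B, 192, H/8, W/8)   # 3 * 4**3 channels, matching in_channels * (4**depth)
decoder   ->: (B,   3, H,   W  )   # PixelShuffle(2**depth) restores the resolution
```
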
/model/cain_encdec.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .common import *
8 |
9 |
10 | class Encoder(nn.Module):
11 | def __init__(self, in_channels=3, depth=3, nf_start=32, norm=False):
12 | super(Encoder, self).__init__()
13 | self.device = torch.device('cuda')
14 |
15 | nf = nf_start
16 | relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
17 |
18 | self.body = nn.Sequential(
19 | ConvNorm(in_channels, nf * 1, 7, stride=1, norm=norm),
20 | relu,
21 | ConvNorm(nf * 1, nf * 2, 5, stride=2, norm=norm),
22 | relu,
23 | ConvNorm(nf * 2, nf * 4, 5, stride=2, norm=norm),
24 | relu,
25 | ConvNorm(nf * 4, nf * 6, 5, stride=2, norm=norm)
26 | )
27 |
28 | self.interpolate = Interpolation(5, 12, nf * 6, reduction=16, act=relu)
29 |
30 | def forward(self, x1, x2):
31 | """
32 | Encoder: Feature Extraction --> Feature Fusion --> Return
33 | """
34 | feats1 = self.body(x1)
35 | feats2 = self.body(x2)
36 |
37 | feats = self.interpolate(feats1, feats2)
38 |
39 | return feats
40 |
41 |
42 | class Decoder(nn.Module):
43 | def __init__(self, in_channels=192, out_channels=3, depth=3, norm=False, up_mode='shuffle'):
44 | super(Decoder, self).__init__()
45 | self.device = torch.device('cuda')
46 |
47 | relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
48 |
49 | nf = [in_channels, (in_channels*2)//3, in_channels//3, in_channels//6]
50 | #nf = [192, 128, 64, 32]
51 | #nf = [186, 124, 62, 31]
52 | self.body = nn.Sequential(
53 | UpConvNorm(nf[0], nf[1], mode=up_mode, norm=norm),
54 | ResBlock(nf[1], nf[1], norm=norm, act=relu),
55 | UpConvNorm(nf[1], nf[2], mode=up_mode, norm=norm),
56 | ResBlock(nf[2], nf[2], norm=norm, act=relu),
57 | UpConvNorm(nf[2], nf[3], mode=up_mode, norm=norm),
58 | ResBlock(nf[3], nf[3], norm=norm, act=relu),
59 | conv7x7(nf[3], out_channels)
60 | )
61 |
62 | def forward(self, feats):
63 | out = self.body(feats)
64 | #out = self.conv_final(out)
65 |
66 | return out
67 |
68 |
69 | class CAIN_EncDec(nn.Module):
70 | def __init__(self, depth=3, n_resblocks=3, start_filts=32, up_mode='shuffle'):
71 | super(CAIN_EncDec, self).__init__()
72 | self.depth = depth
73 |
74 | self.encoder = Encoder(in_channels=3, depth=depth, norm=False)
75 | self.decoder = Decoder(in_channels=start_filts*6, depth=depth, norm=False, up_mode=up_mode)
76 |
77 | def forward(self, x1, x2):
78 | x1, m1 = sub_mean(x1)
79 | x2, m2 = sub_mean(x2)
80 |
81 | if not self.training:
82 | paddingInput, paddingOutput = InOutPaddings(x1)
83 | x1 = paddingInput(x1)
84 | x2 = paddingInput(x2)
85 |
86 | feats = self.encoder(x1, x2)
87 | out = self.decoder(feats)
88 |
89 | if not self.training:
90 | out = paddingOutput(out)
91 |
92 | mi = (m1 + m2)/2
93 | out += mi
94 |
95 | return out, feats
96 |
97 |
98 |
--------------------------------------------------------------------------------
/model/cain_noca.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .common import *
8 |
9 | class Encoder(nn.Module):
10 | def __init__(self, in_channels=3, depth=3):
11 | super(Encoder, self).__init__()
12 | self.device = torch.device('cuda')
13 |
14 | self.shuffler = PixelShuffle(1/2**depth)
15 | # self.shuffler = nn.Sequential(
16 | # PixelShuffle(1/2),
17 | # PixelShuffle(1/2),
18 | # PixelShuffle(1/2))
19 | self.interpolate = Interpolation_res(5, 12, in_channels * (4**depth))
20 |
21 | def forward(self, x1, x2):
22 | feats1 = self.shuffler(x1)
23 | feats2 = self.shuffler(x2)
24 |
25 | feats = self.interpolate(feats1, feats2)
26 |
27 | return feats
28 |
29 |
30 | class Decoder(nn.Module):
31 | def __init__(self, depth=3):
32 | super(Decoder, self).__init__()
33 | self.device = torch.device('cuda')
34 |
35 | self.shuffler = PixelShuffle(2**depth)
36 | # self.shuffler = nn.Sequential(
37 | # PixelShuffle(2),
38 | # PixelShuffle(2),
39 | # PixelShuffle(2))
40 |
41 | def forward(self, feats):
42 | out = self.shuffler(feats)
43 | return out
44 |
45 |
46 | class CAIN_NoCA(nn.Module):
47 | def __init__(self, depth=3):
48 | super(CAIN_NoCA, self).__init__()
49 | self.depth = depth
50 |
51 | self.encoder = Encoder(in_channels=3, depth=depth)
52 | self.decoder = Decoder(depth=depth)
53 |
54 | def forward(self, x1, x2):
55 | x1, m1 = sub_mean(x1)
56 | x2, m2 = sub_mean(x2)
57 |
58 | if not self.training:
59 | paddingInput, paddingOutput = InOutPaddings(x1)
60 | x1 = paddingInput(x1)
61 | x2 = paddingInput(x2)
62 |
63 | feats = self.encoder(x1, x2)
64 | out = self.decoder(feats)
65 |
66 | if not self.training:
67 | out = paddingOutput(out)
68 |
69 | mi = (m1 + m2) / 2
70 | out += mi
71 |
72 | return out, feats
73 |
--------------------------------------------------------------------------------
/model/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | def sub_mean(x):
8 | mean = x.mean(2, keepdim=True).mean(3, keepdim=True)
9 | x -= mean
10 | return x, mean
11 |
12 | def InOutPaddings(x):  # pad H and W up to the next multiple of 128; the second pad crops back
13 | w, h = x.size(3), x.size(2)
14 | padding_width, padding_height = 0, 0
15 | if w != ((w >> 7) << 7):
16 | padding_width = (((w >> 7) + 1) << 7) - w
17 | if h != ((h >> 7) << 7):
18 | padding_height = (((h >> 7) + 1) << 7) - h
19 | paddingInput = nn.ReflectionPad2d(padding=[padding_width // 2, padding_width - padding_width // 2,
20 | padding_height // 2, padding_height - padding_height // 2])
21 | paddingOutput = nn.ReflectionPad2d(padding=[0 - padding_width // 2, padding_width // 2 - padding_width,
22 | 0 - padding_height // 2, padding_height // 2 - padding_height])
23 | return paddingInput, paddingOutput
24 |
25 |
26 | class ConvNorm(nn.Module):
27 | def __init__(self, in_feat, out_feat, kernel_size, stride=1, norm=False):
28 | super(ConvNorm, self).__init__()
29 |
30 | reflection_padding = kernel_size // 2
31 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
32 | self.conv = nn.Conv2d(in_feat, out_feat, stride=stride, kernel_size=kernel_size, bias=True)
33 |
34 | self.norm = norm
35 | if norm == 'IN':
36 | self.norm = nn.InstanceNorm2d(out_feat, track_running_stats=True)
37 | elif norm == 'BN':
38 | self.norm = nn.BatchNorm2d(out_feat)
39 |
40 | def forward(self, x):
41 | out = self.reflection_pad(x)
42 | out = self.conv(out)
43 | if self.norm:
44 | out = self.norm(out)
45 | return out
46 |
47 |
48 | class UpConvNorm(nn.Module):
49 | def __init__(self, in_channels, out_channels, mode='transpose', norm=False):
50 | super(UpConvNorm, self).__init__()
51 |
52 | if mode == 'transpose':
53 | self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
54 | elif mode == 'shuffle':
55 | self.upconv = nn.Sequential(
56 | ConvNorm(in_channels, 4*out_channels, kernel_size=3, stride=1, norm=norm),
57 | PixelShuffle(2))
58 | else:
59 | # out_channels is always going to be the same as in_channels
60 | self.upconv = nn.Sequential(
61 | nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
62 | ConvNorm(in_channels, out_channels, kernel_size=1, stride=1, norm=norm))
63 |
64 | def forward(self, x):
65 | out = self.upconv(x)
66 | return out
67 |
68 |
69 |
70 | class meanShift(nn.Module):
71 | def __init__(self, rgbRange, rgbMean, sign, nChannel=3):
72 | super(meanShift, self).__init__()
73 | if nChannel == 1:
74 | l = rgbMean[0] * rgbRange * float(sign)
75 |
76 | self.shifter = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0)
77 | self.shifter.weight.data = torch.eye(1).view(1, 1, 1, 1)
78 | self.shifter.bias.data = torch.Tensor([l])
79 | elif nChannel == 3:
80 | r = rgbMean[0] * rgbRange * float(sign)
81 | g = rgbMean[1] * rgbRange * float(sign)
82 | b = rgbMean[2] * rgbRange * float(sign)
83 |
84 | self.shifter = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
85 | self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
86 | self.shifter.bias.data = torch.Tensor([r, g, b])
87 | else:
88 | r = rgbMean[0] * rgbRange * float(sign)
89 | g = rgbMean[1] * rgbRange * float(sign)
90 | b = rgbMean[2] * rgbRange * float(sign)
91 | self.shifter = nn.Conv2d(6, 6, kernel_size=1, stride=1, padding=0)
92 | self.shifter.weight.data = torch.eye(6).view(6, 6, 1, 1)
93 | self.shifter.bias.data = torch.Tensor([r, g, b, r, g, b])
94 |
95 | # Freeze the meanShift layer
96 | for params in self.shifter.parameters():
97 | params.requires_grad = False
98 |
99 | def forward(self, x):
100 | x = self.shifter(x)
101 |
102 | return x
103 |
104 |
105 | """ CONV - (BN) - RELU - CONV - (BN) """
106 | class ResBlock(nn.Module):
107 | def __init__(self, in_feat, out_feat, kernel_size=3, reduction=False, bias=True, # 'reduction' is just for placeholder
108 | norm=False, act=nn.ReLU(True), downscale=False):
109 | super(ResBlock, self).__init__()
110 |
111 | self.body = nn.Sequential(
112 | ConvNorm(in_feat, out_feat, kernel_size=kernel_size, stride=2 if downscale else 1),
113 | act,
114 | ConvNorm(out_feat, out_feat, kernel_size=kernel_size, stride=1)
115 | )
116 |
117 | self.downscale = None
118 | if downscale:
119 | self.downscale = nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=2)
120 |
121 | def forward(self, x):
122 | res = x
123 | out = self.body(x)
124 | if self.downscale is not None:
125 | res = self.downscale(res)
126 | out += res
127 |
128 | return out
129 |
130 |
131 | ## Channel Attention (CA) Layer
132 | class CALayer(nn.Module):
133 | def __init__(self, channel, reduction=16):
134 | super(CALayer, self).__init__()
135 | # global average pooling: feature --> point
136 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
137 | # feature channel downscale and upscale --> channel weight
138 | self.conv_du = nn.Sequential(
139 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
140 | nn.ReLU(inplace=True),
141 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
142 | nn.Sigmoid()
143 | )
144 |
145 | def forward(self, x):
146 | y = self.avg_pool(x)
147 | y = self.conv_du(y)
148 | return x * y, y
149 |
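# Shape sketch: for x of size (B, C, H, W), avg_pool squeezes to (B, C, 1, 1)
# and conv_du bottlenecks C -> C//reduction -> C with a sigmoid, so y is a
# per-channel gate in (0, 1) broadcast over all spatial positions.
# >>> ca = CALayer(channel=64, reduction=16)
# >>> out, y = ca(torch.randn(2, 64, 32, 32))
# >>> out.shape, y.shape
# (torch.Size([2, 64, 32, 32]), torch.Size([2, 64, 1, 1]))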
150 |
151 | ## Residual Channel Attention Block (RCAB)
152 | class RCAB(nn.Module):
153 | def __init__(self, in_feat, out_feat, kernel_size, reduction, bias=True,
154 | norm=False, act=nn.ReLU(True), downscale=False, return_ca=False):
155 | super(RCAB, self).__init__()
156 |
157 | self.body = nn.Sequential(
158 | ConvNorm(in_feat, out_feat, kernel_size, stride=2 if downscale else 1, norm=norm),
159 | act,
160 | ConvNorm(out_feat, out_feat, kernel_size, stride=1, norm=norm),
161 | CALayer(out_feat, reduction)
162 | )
163 | self.downscale = downscale
164 | if downscale:
165 | self.downConv = nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=2, padding=1)
166 | self.return_ca = return_ca
167 |
168 | def forward(self, x):
169 | res = x
170 | out, ca = self.body(x)
171 | if self.downscale:
172 | res = self.downConv(res)
173 | out += res
174 |
175 | if self.return_ca:
176 | return out, ca
177 | else:
178 | return out
179 |
180 |
181 | ## Residual Group (RG)
182 | class ResidualGroup(nn.Module):
183 | def __init__(self, Block, n_resblocks, n_feat, kernel_size, reduction, act, norm=False):
184 | super(ResidualGroup, self).__init__()
185 |
186 | modules_body = [Block(n_feat, n_feat, kernel_size, reduction, bias=True, norm=norm, act=act)
187 | for _ in range(n_resblocks)]
188 | modules_body.append(ConvNorm(n_feat, n_feat, kernel_size, stride=1, norm=norm))
189 | self.body = nn.Sequential(*modules_body)
190 |
191 | def forward(self, x):
192 | res = self.body(x)
193 | res += x
194 | return res
195 |
196 |
197 | def pixel_shuffle(input, scale_factor):
198 | batch_size, channels, in_height, in_width = input.size()
199 |
200 | out_channels = int(int(channels / scale_factor) / scale_factor)
201 | out_height = int(in_height * scale_factor)
202 | out_width = int(in_width * scale_factor)
203 |
204 | if scale_factor >= 1:
205 | input_view = input.contiguous().view(batch_size, out_channels, scale_factor, scale_factor, in_height, in_width)
206 | shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
207 | else:
208 | block_size = int(1 / scale_factor)
209 | input_view = input.contiguous().view(batch_size, channels, out_height, block_size, out_width, block_size)
210 | shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()
211 |
212 | return shuffle_out.view(batch_size, out_channels, out_height, out_width)
213 |
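# Shape sketch: scale_factor >= 1 trades channels for resolution (as in
# nn.PixelShuffle); a fractional scale_factor does the inverse space-to-depth
# rearrangement.
# >>> pixel_shuffle(torch.randn(1, 4, 8, 8), 2).shape
# torch.Size([1, 1, 16, 16])
# >>> pixel_shuffle(torch.randn(1, 1, 16, 16), 0.5).shape
# torch.Size([1, 4, 8, 8])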
214 |
215 | class PixelShuffle(nn.Module):
216 | def __init__(self, scale_factor):
217 | super(PixelShuffle, self).__init__()
218 | self.scale_factor = scale_factor
219 |
220 | def forward(self, x):
221 | return pixel_shuffle(x, self.scale_factor)
222 | def extra_repr(self):
223 | return 'scale_factor={}'.format(self.scale_factor)
224 |
225 |
226 | def conv(in_channels, out_channels, kernel_size,
227 | stride=1, bias=True, groups=1):
228 | return nn.Conv2d(
229 | in_channels,
230 | out_channels,
231 | kernel_size=kernel_size,
232 | padding=kernel_size//2,  # 'same' output size for odd kernel sizes
233 | stride=stride,
234 | bias=bias,
235 | groups=groups)
236 |
237 |
238 | def conv1x1(in_channels, out_channels, stride=1, bias=True, groups=1):
239 | return nn.Conv2d(
240 | in_channels,
241 | out_channels,
242 | kernel_size=1,
243 | stride=stride,
244 | bias=bias,
245 | groups=groups)
246 |
247 | def conv3x3(in_channels, out_channels, stride=1,
248 | padding=1, bias=True, groups=1):
249 | return nn.Conv2d(
250 | in_channels,
251 | out_channels,
252 | kernel_size=3,
253 | stride=stride,
254 | padding=padding,
255 | bias=bias,
256 | groups=groups)
257 |
258 | def conv5x5(in_channels, out_channels, stride=1,
259 | padding=2, bias=True, groups=1):
260 | return nn.Conv2d(
261 | in_channels,
262 | out_channels,
263 | kernel_size=5,
264 | stride=stride,
265 | padding=padding,
266 | bias=bias,
267 | groups=groups)
268 |
269 | def conv7x7(in_channels, out_channels, stride=1,
270 | padding=3, bias=True, groups=1):
271 | return nn.Conv2d(
272 | in_channels,
273 | out_channels,
274 | kernel_size=7,
275 | stride=stride,
276 | padding=padding,
277 | bias=bias,
278 | groups=groups)
279 |
280 | def upconv2x2(in_channels, out_channels, mode='shuffle'):
281 | if mode == 'transpose':
282 | return nn.ConvTranspose2d(
283 | in_channels,
284 | out_channels,
285 | kernel_size=4,
286 | stride=2,
287 | padding=1)
288 | elif mode == 'shuffle':
289 | return nn.Sequential(
290 | conv3x3(in_channels, 4*out_channels),
291 | PixelShuffle(2))
292 | else:
293 | # bilinear mode: callers are expected to pass out_channels == in_channels here
294 | return nn.Sequential(
295 | nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
296 | conv1x1(in_channels, out_channels))
297 |
298 |
299 |
300 | class Interpolation(nn.Module):
301 | def __init__(self, n_resgroups, n_resblocks, n_feats,
302 | reduction=16, act=nn.LeakyReLU(0.2, True), norm=False):
303 | super(Interpolation, self).__init__()
304 |
305 | # define modules: head, body, tail
306 | self.headConv = conv3x3(n_feats * 2, n_feats)
307 |
308 | modules_body = [
309 | ResidualGroup(
310 | RCAB,
311 | n_resblocks=n_resblocks,
312 | n_feat=n_feats,
313 | kernel_size=3,
314 | reduction=reduction,
315 | act=act,
316 | norm=norm)
317 | for _ in range(n_resgroups)]
318 | self.body = nn.Sequential(*modules_body)
319 |
320 | self.tailConv = conv3x3(n_feats, n_feats)
321 |
322 | def forward(self, x0, x1):
323 | # Build input tensor
324 | x = torch.cat([x0, x1], dim=1)
325 | x = self.headConv(x)
326 |
327 | res = self.body(x)
328 | res += x
329 |
330 | out = self.tailConv(res)
331 | return out
332 |
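# Usage sketch (small hypothetical sizes): Interpolation fuses two
# n_feats-channel feature maps (e.g. the encodings of the two input frames)
# into a single n_feats-channel map at the same resolution.
# >>> interp = Interpolation(n_resgroups=2, n_resblocks=2, n_feats=32)
# >>> f = interp(torch.randn(1, 32, 16, 16), torch.randn(1, 32, 16, 16))
# >>> f.shape
# torch.Size([1, 32, 16, 16])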
333 |
334 | class Interpolation_res(nn.Module):
335 | def __init__(self, n_resgroups, n_resblocks, n_feats,
336 | act=nn.LeakyReLU(0.2, True), norm=False):
337 | super(Interpolation_res, self).__init__()
338 |
339 | # define modules: head, body, tail (reduces concatenated inputs to n_feat)
340 | self.headConv = conv3x3(n_feats * 2, n_feats)
341 |
342 | modules_body = [ResidualGroup(ResBlock, n_resblocks=n_resblocks, n_feat=n_feats, kernel_size=3,
343 | reduction=0, act=act, norm=norm)
344 | for _ in range(n_resgroups)]
345 | self.body = nn.Sequential(*modules_body)
346 |
347 | self.tailConv = conv3x3(n_feats, n_feats)
348 |
349 | def forward(self, x0, x1):
350 | # Build input tensor
351 | x = torch.cat([x0, x1], dim=1)
352 | x = self.headConv(x)
353 |
354 | res = x
355 | for m in self.body:
356 | res = m(res)
357 | res += x
358 |
359 | x = self.tailConv(res)
360 |
361 | return x
362 |
--------------------------------------------------------------------------------
/pytorch_msssim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from math import exp
4 | import numpy as np
5 |
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 |
12 | def create_window(window_size, channel=1):
13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)  # keep device-agnostic; callers move it via .to()
15 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
16 | return window
17 |
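# Sketch: the window is a separable Gaussian (sigma=1.5), one copy per
# channel, laid out for a grouped conv2d; each 2D kernel sums to ~1, so the
# filter computes a local weighted mean.
# >>> w = create_window(11, channel=3)
# >>> w.shape
# torch.Size([3, 1, 11, 11])
# >>> bool(torch.allclose(w.sum(dim=(2, 3)), torch.ones(3, 1)))
# True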
18 | def create_window_3d(window_size, channel=1):
19 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
20 | _2D_window = _1D_window.mm(_1D_window.t())
21 | _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
22 | window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous()  # device-agnostic; moved to the input's device by the caller
23 | return window
24 |
25 |
26 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
27 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
28 | if val_range is None:
29 | if torch.max(img1) > 128:
30 | max_val = 255
31 | else:
32 | max_val = 1
33 |
34 | if torch.min(img1) < -0.5:
35 | min_val = -1
36 | else:
37 | min_val = 0
38 | L = max_val - min_val
39 | else:
40 | L = val_range
41 |
42 | padd = 0
43 | (_, channel, height, width) = img1.size()
44 | if window is None:
45 | real_size = min(window_size, height, width)
46 | window = create_window(real_size, channel=channel).to(img1.device)
47 |
48 | # mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
49 | # mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
50 | mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)  # pad of 5 assumes the default 11x11 window
51 | mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
52 |
53 | mu1_sq = mu1.pow(2)
54 | mu2_sq = mu2.pow(2)
55 | mu1_mu2 = mu1 * mu2
56 |
57 | # sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
58 | # sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
59 | # sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
60 |
61 | sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
62 | sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
63 | sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2
64 |
65 | C1 = (0.01 * L) ** 2
66 | C2 = (0.03 * L) ** 2
67 |
68 | v1 = 2.0 * sigma12 + C2
69 | v2 = sigma1_sq + sigma2_sq + C2
70 | cs = torch.mean(v1 / v2) # contrast sensitivity
71 |
72 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
73 |
74 | if size_average:
75 | ret = ssim_map.mean()
76 | else:
77 | ret = ssim_map.mean(1).mean(1).mean(1)
78 |
79 | if full:
80 | return ret, cs
81 | return ret
82 |
83 |
84 | def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
85 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
86 | if val_range is None:
87 | if torch.max(img1) > 128:
88 | max_val = 255
89 | else:
90 | max_val = 1
91 |
92 | if torch.min(img1) < -0.5:
93 | min_val = -1
94 | else:
95 | min_val = 0
96 | L = max_val - min_val
97 | else:
98 | L = val_range
99 |
100 | padd = 0
101 | (_, _, height, width) = img1.size()
102 | if window is None:
103 | real_size = min(window_size, height, width)
104 | window = create_window_3d(real_size, channel=1).to(img1.device)
105 | # Channel is set to 1 since we consider color images as volumetric images
106 |
107 | img1 = img1.unsqueeze(1)
108 | img2 = img2.unsqueeze(1)
109 |
110 | mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
111 | mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
112 |
113 | mu1_sq = mu1.pow(2)
114 | mu2_sq = mu2.pow(2)
115 | mu1_mu2 = mu1 * mu2
116 |
117 | sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
118 | sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
119 | sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
120 |
121 | C1 = (0.01 * L) ** 2
122 | C2 = (0.03 * L) ** 2
123 |
124 | v1 = 2.0 * sigma12 + C2
125 | v2 = sigma1_sq + sigma2_sq + C2
126 | cs = torch.mean(v1 / v2) # contrast sensitivity
127 |
128 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
129 |
130 | if size_average:
131 | ret = ssim_map.mean()
132 | else:
133 | ret = ssim_map.mean(1).mean(1).mean(1)
134 |
135 | if full:
136 | return ret, cs
137 | return ret
138 |
139 |
140 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
141 | device = img1.device
142 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
143 | levels = weights.size()[0]
144 | mssim = []
145 | mcs = []
146 | for _ in range(levels):
147 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
148 | mssim.append(sim)
149 | mcs.append(cs)
150 |
151 | img1 = F.avg_pool2d(img1, (2, 2))
152 | img2 = F.avg_pool2d(img2, (2, 2))
153 |
154 | mssim = torch.stack(mssim)
155 | mcs = torch.stack(mcs)
156 |
157 | # Normalize (to avoid NaNs when training unstable models; not compliant with the original definition)
158 | if normalize:
159 | mssim = (mssim + 1) / 2
160 | mcs = (mcs + 1) / 2
161 |
162 | pow1 = mcs ** weights
163 | pow2 = mssim ** weights
164 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
165 | output = torch.prod(pow1[:-1]) * pow2[-1]  # cs terms at the finer scales times the ssim term at the coarsest scale
166 | return output
167 |
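# Usage sketch: five pyramid levels mean the inputs are halved four times, so
# they should be comfortably larger than the 11x11 window at the coarsest
# scale; identical inputs score 1.0.
# >>> a = torch.rand(1, 3, 256, 256)
# >>> msssim(a, a).item()
# 1.0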
168 |
169 | # Classes to re-use window
170 | class SSIM(torch.nn.Module):
171 | def __init__(self, window_size=11, size_average=True, val_range=None):
172 | super(SSIM, self).__init__()
173 | self.window_size = window_size
174 | self.size_average = size_average
175 | self.val_range = val_range
176 |
177 | # Assume 3 channels for SSIM
178 | self.channel = 3
179 | self.window = create_window(window_size, channel=self.channel)
180 |
181 | def forward(self, img1, img2):
182 | (_, channel, _, _) = img1.size()
183 |
184 | if channel == self.channel and self.window.dtype == img1.dtype and self.window.device == img1.device:
185 | window = self.window
186 | else:
187 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
188 | self.window = window
189 | self.channel = channel
190 |
191 | _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
192 | dssim = (1 - _ssim) / 2
193 | return dssim  # note: returns DSSIM = (1 - SSIM) / 2, usable directly as a loss
194 |
195 | class MSSSIM(torch.nn.Module):
196 | def __init__(self, window_size=11, size_average=True, channel=3):
197 | super(MSSSIM, self).__init__()
198 | self.window_size = window_size
199 | self.size_average = size_average
200 | self.channel = channel
201 |
202 | def forward(self, img1, img2):
203 | # TODO: store window between calls if possible
204 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
205 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py \
4 | --exp_name CAIN_train \
5 | --dataset vimeo90k \
6 | --batch_size 16 \
7 | --test_batch_size 16 \
8 | --model cain \
9 | --depth 3 \
10 | --loss 1*L1 \
11 | --max_epoch 200 \
12 | --lr 0.0002 \
13 | --log_iter 100 \
14 | # --mode test
--------------------------------------------------------------------------------
/run_noca.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=1 python main.py \
4 | --exp_name CAIN_test_noca \
5 | --dataset vimeo90k \
6 | --batch_size 16 \
7 | --test_batch_size 16 \
8 | --model cain_noca \
9 | --depth 3 \
10 | --loss 1*L1 \
11 | --max_epoch 200 \
12 | --lr 0.0002 \
13 | --log_iter 100 \
14 | # --mode test
15 | # --resume True \
16 | # --resume_exp SH_5_12
17 | # --fix_encoder
--------------------------------------------------------------------------------
/test_custom.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=0 python generate.py \
4 | --exp_name CAIN_fin \
5 | --dataset custom \
6 | --data_root data/frame_seq \
7 | --img_fmt png \
8 | --batch_size 32 \
9 | --test_batch_size 16 \
10 | --model cain \
11 | --depth 3 \
12 | --loss 1*L1 \
13 | --resume \
14 | --mode test
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from datetime import datetime
3 | import os
4 | import sys
5 | import math
6 | import random
7 | import json
8 | import glob
9 | import logging
10 | import shutil
11 |
12 | import numpy as np
13 | import torch
14 | from torchvision import transforms
15 |
16 | from PIL import Image, ImageFont, ImageDraw
17 | #from skimage.measure import compare_psnr, compare_ssim
18 |
19 | try:
20 | from StringIO import StringIO # Python 2.7
21 | except ImportError:
22 | from io import BytesIO # Python 3.x
23 |
24 | import cv2
25 |
26 | from pytorch_msssim import ssim_matlab as ssim_pth
27 | # from pytorch_msssim import ssim as ssim_pth
28 |
29 | ##########################
30 | # Training helper functions to keep main.py clean
31 | ##########################
32 |
33 | def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_workers, test_mode='medium', img_fmt='png'):
34 |
35 | if dataset_str == 'snufilm':
36 | from data.snufilm import get_loader
37 | test_loader = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers, test_mode=test_mode)
38 | return None, test_loader
39 | elif dataset_str == 'vimeo90k':
40 | from data.vimeo90k import get_loader
41 | elif dataset_str == 'aim':
42 | from data.aim import get_loader
43 | elif dataset_str == 'custom':
44 | from data.video import get_loader
45 | test_loader = get_loader('test', data_root, test_batch_size, img_fmt=img_fmt, shuffle=False, num_workers=num_workers, n_frames=1)
46 | return test_loader  # note: a single loader, unlike the (train, test) pairs returned elsewhere in this function
47 | else:
48 | raise NotImplementedError('Training / Testing for this dataset is not implemented.')
49 |
50 | train_loader = get_loader('train', data_root, batch_size, shuffle=True, num_workers=num_workers)
51 | if dataset_str == 'aim':
52 | test_loader = get_loader('val', data_root, test_batch_size, shuffle=False, num_workers=num_workers)
53 | else:
54 | test_loader = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers)
55 |
56 | return train_loader, test_loader
57 |
58 |
59 | def build_input(images, imgpaths, is_training=True, include_edge=False, device=torch.device('cuda')):
60 | if isinstance(images[0], list):
61 | images_gathered = [None, None, None]
62 | for j in range(len(images[0])): # 3
63 | _images = [images[k][j] for k in range(len(images))]
64 | images_gathered[j] = torch.cat(_images, 0)
65 | imgpaths = [p for _ in images for p in imgpaths]
66 | images = images_gathered
67 |
68 | im1, im2 = images[0].to(device), images[2].to(device)
69 | gt = images[1].to(device)
70 |
71 | return im1, im2, gt
72 |
73 |
74 | def load_checkpoint(args, model, optimizer, fix_loaded=False):
75 | if args.resume_exp is None:
76 | args.resume_exp = args.exp_name
77 | if args.mode == 'test':
78 | load_name = os.path.join('checkpoint', args.resume_exp, 'model_best.pth')
79 | else:
80 | #load_name = os.path.join('checkpoint', args.resume_exp, 'model_best.pth')
81 | load_name = os.path.join('checkpoint', args.resume_exp, 'checkpoint.pth')
82 | print("loading checkpoint %s" % load_name)
83 | checkpoint = torch.load(load_name)
84 | args.start_epoch = checkpoint['epoch'] + 1
85 | if args.resume_exp != args.exp_name:
86 | args.start_epoch = 0
87 |
88 | # filter out different keys or those with size mismatch
89 | model_dict = model.state_dict()
90 | ckpt_dict = {}
91 | mismatch = False
92 | for k, v in checkpoint['state_dict'].items():
93 | if k in model_dict:
94 | if model_dict[k].size() == v.size():
95 | ckpt_dict[k] = v
96 | else:
97 | print('Size mismatch while loading! %s != %s Skipping %s...'
98 | % (str(model_dict[k].size()), str(v.size()), k))
99 | mismatch = True
100 | else:
101 | mismatch = True
102 | if len(model.state_dict().keys()) > len(ckpt_dict.keys()):
103 | mismatch = True
104 | # Overwrite parameters to model_dict
105 | model_dict.update(ckpt_dict)
106 | # Load to model
107 | model.load_state_dict(model_dict)
108 | # if there was a size mismatch, give up on loading the optimizer; when resuming from another experiment, also skip it
109 | if (not mismatch) and (optimizer is not None) and (args.resume_exp == args.exp_name):
110 | optimizer.load_state_dict(checkpoint['optimizer'])
111 | update_lr(optimizer, args.lr)
112 | if fix_loaded:
113 | for k, param in model.named_parameters():
114 | if k in ckpt_dict.keys():
115 | print(k)
116 | param.requires_grad = False
117 | print("loaded checkpoint %s" % load_name)
118 | del checkpoint, ckpt_dict, model_dict
119 |
120 |
121 | def save_checkpoint(state, is_best, exp_name, filename='checkpoint.pth'):
122 | """Saves checkpoint to disk"""
123 | directory = "checkpoint/%s/" % (exp_name)
124 | if not os.path.exists(directory):
125 | os.makedirs(directory)
126 | filename = directory + filename
127 | torch.save(state, filename)
128 | if is_best:
129 | shutil.copyfile(filename, 'checkpoint/%s/' % (exp_name) + 'model_best.pth')
130 |
131 |
132 | def init_lpips_eval():
133 | LPIPS_dir = "../PerceptualSimilarity"
134 | LPIPS_net = "squeeze"
135 | sys.path.append(LPIPS_dir)
136 | from models import dist_model as dm
137 | print("Initialize Distance model from %s" % LPIPS_net)
138 | lpips_model = dm.DistModel()
139 | lpips_model.initialize(model='net-lin',net='squeeze', use_gpu=True,
140 | model_path=os.path.join(LPIPS_dir, 'weights/v0.1/%s.pth' % LPIPS_net))
141 | return lpips_model
142 |
143 |
144 | ##########################
145 | # Evaluations
146 | ##########################
147 |
148 | class AverageMeter(object):
149 | """Computes and stores the average and current value"""
150 | def __init__(self):
151 | self.reset()
152 |
153 | def reset(self):
154 | self.val = 0
155 | self.avg = 0
156 | self.sum = 0
157 | self.count = 0
158 |
159 | def update(self, val, n=1):
160 | self.val = val
161 | self.sum += val * n
162 | self.count += n
163 | self.avg = self.sum / self.count
164 |
165 |
166 | def init_losses(loss_str):
167 | loss_specifics = {}
168 | loss_list = loss_str.split('+')
169 | for l in loss_list:
170 | _, loss_type = l.split('*')
171 | loss_specifics[loss_type] = AverageMeter()
172 | loss_specifics['total'] = AverageMeter()
173 | return loss_specifics
174 |
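# Loss-string format sketch: the CLI passes 'weight*NAME' terms joined by '+'
# (run.sh uses '1*L1'); a weighted sum such as '0.5*L1+0.5*SSIM' is a
# hypothetical example. Only the names are used to key the meters here.
# >>> sorted(init_losses('0.5*L1+0.5*SSIM').keys())
# ['L1', 'SSIM', 'total']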
175 |
176 | def init_meters(loss_str):
177 | losses = init_losses(loss_str)
178 | psnrs = AverageMeter()
179 | ssims = AverageMeter()
180 | lpips = AverageMeter()
181 | return losses, psnrs, ssims, lpips
182 |
183 |
184 | def quantize(img, rgb_range=255):
185 | return img.mul(255 / rgb_range).clamp(0, 255).round()
186 |
187 |
188 | def calc_psnr(pred, gt, mask=None):
189 | '''
190 | Here we assume quantized (0-255) arguments.
191 | '''
192 | diff = (pred - gt).div(255)
193 |
194 | if mask is not None:
195 | mse = diff.pow(2).sum() / (3 * mask.sum())
196 | else:
197 | mse = diff.pow(2).mean() + 1e-8 # mse can (surprisingly!) reach 0, which results in math domain error
198 |
199 | return -10 * math.log10(mse)
200 |
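# Worked example: with quantized inputs, PSNR = -10 * log10(mean((d/255)^2)).
# A uniform error of 16/255 per pixel gives -10 * log10((16/255)**2) ~ 24.05 dB.
# >>> gt = torch.zeros(3, 8, 8); pred = torch.full((3, 8, 8), 16.)
# >>> round(calc_psnr(pred, gt), 2)
# 24.05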
201 |
202 | def calc_ssim(img1, img2, datarange=255.):  # unused by default; needs the commented-out skimage import above
203 | im1 = img1.numpy().transpose(1, 2, 0).astype(np.uint8)
204 | im2 = img2.numpy().transpose(1, 2, 0).astype(np.uint8)
205 | return compare_ssim(im1, im2, data_range=datarange, multichannel=True, gaussian_weights=True)
206 |
207 |
208 | def calc_metrics(im_pred, im_gt, mask=None):
209 | q_im_pred = quantize(im_pred.data, rgb_range=1.)
210 | q_im_gt = quantize(im_gt.data, rgb_range=1.)
211 | if mask is not None:
212 | q_im_pred = q_im_pred * mask
213 | q_im_gt = q_im_gt * mask
214 | psnr = calc_psnr(q_im_pred, q_im_gt, mask=mask)
215 | # ssim = calc_ssim(q_im_pred.cpu(), q_im_gt.cpu())
216 | ssim = ssim_pth(q_im_pred.unsqueeze(0), q_im_gt.unsqueeze(0), val_range=255)
217 | return psnr, ssim
218 |
219 |
220 | def eval_LPIPS(model, im_pred, im_gt):
221 | im_pred = 2.0 * im_pred - 1
222 | im_gt = 2.0 * im_gt - 1
223 | dist = model.forward(im_pred, im_gt)[0]
224 | return dist
225 |
226 |
227 | def eval_metrics(output, gt, psnrs, ssims, lpips, lpips_model=None, mask=None, psnrs_masked=None, ssims_masked=None):
228 | # PSNR should be calculated for each image
229 | for b in range(gt.size(0)):
230 | psnr, ssim = calc_metrics(output[b], gt[b], None)
231 | psnrs.update(psnr)
232 | ssims.update(ssim)
233 | if mask is not None:
234 | psnr_masked, ssim_masked = calc_metrics(output[b], gt[b], mask[b])
235 | psnrs_masked.update(psnr_masked)
236 | ssims_masked.update(ssim_masked)
237 | if lpips_model is not None:
238 | _lpips = eval_LPIPS(lpips_model, output[b].unsqueeze(0), gt[b].unsqueeze(0))
239 | lpips.update(_lpips)
240 |
241 |
242 | ##########################
243 | # ETC
244 | ##########################
245 |
246 | def get_time():
247 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
248 |
249 | def makedirs(path):
250 | if not os.path.exists(path):
251 | print("[*] Make directories : {}".format(path))
252 | os.makedirs(path)
253 |
254 | def remove_file(path):
255 | if os.path.exists(path):
256 | print("[*] Removed: {}".format(path))
257 | os.remove(path)
258 |
259 | def backup_file(path):
260 | root, ext = os.path.splitext(path)
261 | new_path = "{}.backup_{}{}".format(root, get_time(), ext)
262 |
263 | os.rename(path, new_path)
264 | print("[*] {} has backup: {}".format(path, new_path))
265 |
266 | def update_lr(optimizer, lr):
267 | for param_group in optimizer.param_groups:
268 | param_group['lr'] = lr
269 |
270 |
271 | # TensorBoard
272 | def log_tensorboard(writer, losses, psnr, ssim, lpips, lr, timestep, mode='train'):
273 | for k, v in losses.items():
274 | writer.add_scalar('Loss/%s/%s' % (mode, k), v.avg, timestep)
275 | writer.add_scalar('PSNR/%s' % mode, psnr, timestep)
276 | writer.add_scalar('SSIM/%s' % mode, ssim, timestep)
277 | if lpips is not None:
278 | writer.add_scalar('LPIPS/%s' % mode, lpips, timestep)
279 | if mode == 'train':
280 | writer.add_scalar('lr', lr, timestep)
281 |
282 |
283 | ###########################
284 | ###### VISUALIZATIONS #####
285 | ###########################
286 |
287 | def save_image(img, path):
288 | # img : torch Tensor of size (C, H, W)
289 | q_im = quantize(img.data.mul(255))
290 | if len(img.size()) == 2: # grayscale image
291 | im = Image.fromarray(q_im.cpu().numpy().astype(np.uint8), 'L')
292 | elif len(img.size()) == 3:
293 | im = Image.fromarray(q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8), 'RGB')
294 | else:
295 | raise ValueError('save_image expects a 2D (H, W) or 3D (C, H, W) tensor')
296 | im.save(path)
297 |
298 | def save_batch_images(output, imgpath, save_dir, alpha=0.5):
299 | GEN = save_dir.find('-gen') >= 0 or save_dir.find('stereo') >= 0
300 | q_im_output = [quantize(o.data, rgb_range=1.) for o in output]
301 | for b in range(output[0].size(0)):
302 | paths = imgpath[0][b].split('/')
303 | if GEN:
304 | save_path = save_dir
305 | else:
306 | save_path = os.path.join(save_dir, paths[-3], paths[-2])
307 | makedirs(save_path)
308 | for o in range(len(output)):
309 | if o % 2 == 1 or len(output) == 1:
310 | output_img = Image.fromarray(q_im_output[o][b].permute(1, 2, 0).cpu().numpy().astype(np.uint8), 'RGB')
311 | if GEN:
312 | _imgname = imgpath[o//2][b].split('/')[-1]
313 | imgname = "%s-%.04f.png" % (_imgname, alpha)
314 | else:
315 | imgname = imgpath[o//2][b].split('/')[-1]
316 |
317 | if save_dir.find('voxelflow') >= 0:
318 | #imgname = imgname.replace('gt', 'ours')
319 | imgname = 'frame_01_ours.png'
320 | elif save_dir.find('middlebury') >= 0:
321 | imgname = 'frame10i11.png'
322 |
323 | output_img.save(os.path.join(save_path, imgname))
324 |
325 |
326 | def save_batch_images_test(output, imgpath, save_dir, alpha=0.5):
327 | GEN = save_dir.find('-gen') >= 0 or save_dir.find('stereo') >= 0
328 | q_im_output = [quantize(o.data, rgb_range=1.) for o in output]
329 | for b in range(output[0].size(0)):
330 | paths = imgpath[0][b].split('/')
331 | if GEN:
332 | save_path = save_dir
333 | else:
334 | save_path = os.path.join(save_dir, paths[-3], paths[-2])
335 | makedirs(save_path)
336 | for o in range(len(output)):
337 | # if o % 2 == 1 or len(output) == 1:
338 | # print(" ", o, b, imgpath[o][b])
339 | output_img = Image.fromarray(q_im_output[o][b].permute(1, 2, 0).cpu().numpy().astype(np.uint8), 'RGB')
340 | if GEN:
341 | _imgname = imgpath[o][b].split('/')[-1]
342 | imgname = "%s-%.04f.png" % (_imgname, alpha)
343 | else:
344 | imgname = imgpath[o][b].split('/')[-1]
345 |
346 | if save_dir.find('voxelflow') >= 0:
347 | #imgname = imgname.replace('gt', 'ours')
348 | imgname = 'frame_01_ours.png'
349 | elif save_dir.find('middlebury') >= 0:
350 | imgname = 'frame10i11.png'
351 |
352 | output_img.save(os.path.join(save_path, imgname))
353 |
354 |
355 | def save_images_test(output, imgpath, save_dir, alpha=0.5):
356 | q_im_output = [quantize(o.data, rgb_range=1.) for o in output]
357 | for b in range(output[0].size(0)):
358 | paths = imgpath[1][b].split('/')
359 | save_path = os.path.join(save_dir, paths[-3], paths[-2])
360 | makedirs(save_path)
361 | # Output length is one
362 | output_img = Image.fromarray(q_im_output[0][b].permute(1, 2, 0).cpu().numpy().astype(np.uint8), 'RGB')
363 | imgname = imgpath[1][b].split('/')[-1]
364 |
365 | # if save_dir.find('voxelflow') >= 0:
366 | # imgname = 'frame_01_ours.png'
367 | # elif save_dir.find('middlebury') >= 0:
368 | # imgname = 'frame10i11.png'
369 |
370 | output_img.save(os.path.join(save_path, imgname))
371 |
372 |
373 | def save_images_multi(output, imgpath, save_dir, idx=1):
374 | q_im_output = [quantize(o.data, rgb_range=1.) for o in output]
375 | for b in range(output[0].size(0)):
376 | paths = imgpath[0][b].split('/')
377 | # save_path = os.path.join(save_dir, paths[-3], paths[-2])
378 | # makedirs(save_path)
379 | # Output length is one
380 | output_img = Image.fromarray(q_im_output[0][b].permute(1, 2, 0).cpu().numpy().astype(np.uint8), 'RGB')
381 | # imgname = imgpath[idx][b].split('/')[-1]
382 | imgname = '%s_%03d.png' % (paths[-1], idx)
383 |
384 | output_img.save(os.path.join(save_dir, imgname))
385 |
386 |
387 | def make_video(out_dir, gt_dir, gt_first=False):
388 | gt_ext = '/*.png'
389 | frames_all = sorted(glob.glob(out_dir + '/*.png') + glob.glob(gt_dir + gt_ext), \
390 | key=lambda frame: frame.split('/')[-1])
391 | print("# of total frames : %d" % len(frames_all))
392 | if gt_first:
393 | print("Appending GT in front..")
394 | frames_all = sorted(glob.glob(gt_dir + gt_ext)) + frames_all
395 | print("# of total frames : %d" % len(frames_all))
396 |
397 | # Read the first image to determine height and width
398 | frame = cv2.imread(frames_all[0])
399 | h, w, _ = frame.shape
400 |
401 | # Write video
402 | fourcc = cv2.VideoWriter_fourcc(*'XVID')  # note: XVID in an .mp4 container is unusual; 'mp4v' is a safer fit for .mp4
403 | out = cv2.VideoWriter(out_dir + '/slomo.mp4', fourcc, 30, (w, h))
404 | for p in frames_all:
405 | #print(p)
406 | # TODO: add captions (e.g. 'GT', 'slow motion x4')
407 | frame = cv2.imread(p)
408 | fh, fw = frame.shape[:2]
409 | #print(fh, fw, h, w)
410 | if fh != h or fw != w:
411 | frame = cv2.resize(frame, (w, h), interpolation=cv2.INTER_LINEAR)
412 | out.write(frame)
413 | out.release()  # finalize the file; without this the video may be left corrupt
414 | def check_already_extracted(vid):
415 | return bool(os.path.exists(vid + '/00001.png'))
416 |
--------------------------------------------------------------------------------