├── .gitignore
├── README.md
├── denoising
│   ├── data
│   │   ├── __init__.py
│   │   └── datasets.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── dncnn.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── activations.py
│   │   │   ├── losses.py
│   │   │   └── misc.py
│   │   ├── vgg.py
│   │   ├── xdense.py
│   │   └── xdncnn.py
│   ├── trainer.py
│   └── utils
│       ├── __init__.py
│       ├── misc.py
│       ├── optim.py
│       └── recorderx.py
├── figures
│   ├── activations.png
│   ├── figure1.png
│   └── tradeoff.png
├── requirements.txt
└── super-resolution
    ├── data
    │   ├── __init__.py
    │   └── datasets.py
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   ├── modules
    │   │   ├── __init__.py
    │   │   ├── activations.py
    │   │   ├── losses.py
    │   │   └── misc.py
    │   ├── srgan.py
    │   ├── vanilla.py
    │   ├── vgg.py
    │   ├── xdense.py
    │   ├── xsrgan.py
    │   └── xvanilla.py
    ├── scripts
    │   ├── matlab
    │   │   └── bicubic_subsample.m
    │   └── python
    │       ├── extract_images.py
    │       └── remove_images.py
    ├── trainer.py
    └── utils
        ├── __init__.py
        ├── core.py
        ├── misc.py
        ├── optim.py
        ├── recorderx.py
        └── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | results/
2 | tmp/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # IPython Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # dotenv
82 | .env
83 |
84 | # virtualenv
85 | venv/
86 | ENV/
87 |
88 | # Spyder project settings
89 | .spyderproject
90 |
91 | # Rope project settings
92 | .ropeproject
93 |
94 | # Pycharm
95 | .idea/
96 |
97 | # Vscode
98 | .vscode/
99 |
100 | # jobs
101 | *.job
102 |
103 | # tar
104 | *.tar.gz
105 | *.pb
106 |
107 | *.pth
108 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # xUnit
2 | Learning a Spatial Activation Function for Efficient Image Restoration.
3 |
4 |
5 |
6 |
7 |
8 | Please refer to our [paper](https://arxiv.org/abs/1711.06445) for more details.
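
At its core, an xUnit replaces a pointwise nonlinearity with a learned spatial gate: each feature map is multiplied by a sigmoid of a depthwise convolution computed from the features themselves (see `denoising/models/modules/activations.py`). A minimal sketch of the idea, omitting the optional batch normalization:

```
import torch
import torch.nn as nn

class xUnit(nn.Module):
    # gate = sigmoid(depthwise_conv(relu(x))); output = x * gate
    def __init__(self, num_features=64, kernel_size=7):
        super(xUnit, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(num_features, num_features, kernel_size,
                      padding=kernel_size // 2, groups=num_features),
            nn.Sigmoid())

    def forward(self, x):
        return x * self.features(x)
```

The depthwise convolution keeps the gate cheap while giving every activation a spatial receptive field, which is what enables the smaller networks reported in the paper.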
9 |
10 |
11 | ## Citation
12 | If you use this code for your research, please cite our papers:
13 |
14 | ```
15 | @inproceedings{kligvasser2018xunit,
16 | title={xunit: Learning a spatial activation function for efficient image restoration},
17 | author={Kligvasser, Idan and Rott Shaham, Tamar and Michaeli, Tomer},
18 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
19 | pages={2433--2442},
20 | year={2018}
21 | }
22 | ```
23 |
24 | ```
25 | @article{kligvasser2018dense,
26 | title={Dense xUnit Networks},
27 | author={Kligvasser, Idan and Michaeli, Tomer},
28 | journal={arXiv preprint arXiv:1811.11051},
29 | year={2018}
30 | }
31 | ```
32 |
33 | ## Code
34 |
35 | ### Clone repository
36 |
37 | Clone this repository into any place you want.
38 |
39 | ```
40 | git clone https://github.com/kligvasser/xUnit
41 | cd xUnit
42 | ```
43 |
44 | ### Install dependencies
45 |
46 | ```
47 | python -m pip install -r requirements.txt
48 | ```
49 |
50 | This code requires PyTorch 1.0+ and Python 3+.
51 |
52 | ### Super-resolution
53 | Pretrained models are available at: [LINK](https://www.dropbox.com/s/hq1n5yrl5hjsh34/sr_pretrained.zip?dl=0).
54 |
55 |
56 |
57 |
58 |
59 | #### Dataset preparation
60 | For the super-resolution task, the dataset should contain low- and high-resolution image pairs, in a folder structure of:
61 |
62 | ```txt
63 | train
64 | ├── img
65 | ├── img_x2
66 | ├── img_x4
67 | val
68 | ├── img
69 | ├── img_x2
70 | ├── img_x4
71 | ```
72 |
73 | You may prepare your own data using the Matlab script:
74 |
75 | ```
76 | ./super-resolution/scripts/matlab/bicubic_subsample.m
77 | ```
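
If Matlab is unavailable, here is a minimal Python sketch of the same bicubic downsampling (illustrative only: `make_lr_pairs` is not part of this repository, and Pillow is assumed):

```
import os
from PIL import Image

def make_lr_pairs(img_dir, scales=(2, 4)):
    # writes img_x2/img_x4 siblings of img_dir, bicubically downsampled
    for scale in scales:
        out_dir = '{}_x{}'.format(img_dir, scale)
        os.makedirs(out_dir, exist_ok=True)
        for name in os.listdir(img_dir):
            img = Image.open(os.path.join(img_dir, name)).convert('RGB')
            w, h = img.size
            # crop so the dimensions divide evenly by the scale factor
            img = img.crop((0, 0, w - w % scale, h - h % scale))
            lr = img.resize((w // scale, h // scale), Image.BICUBIC)
            lr.save(os.path.join(out_dir, name))
```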
78 |
79 | Or download a prepared dataset based on the [BSD](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/) and [VOC](http://host.robots.ox.ac.uk/pascal/VOC/) datasets from [LINK](https://www.dropbox.com/s/o1nzpr9q7vup8b7/bsdvoc.zip?dl=0).
80 |
81 | #### Train xSRGAN x4 PSNR model
82 | ```
83 | python3 main.py --root --g-model g_xsrgan --d-model d_xsrgan --model-config "{'scale':4, 'gen_blocks':10, 'dis_blocks':5}" --scale 4 --reconstruction-weight 1.0 --perceptual-weight 0 --adversarial-weight 0 --crop-size 40
84 | ```
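
Note that `--model-config` takes a Python dict literal: the trainer parses it with `ast.literal_eval` and forwards the keys to the chosen model factory (see `_init_model` in `trainer.py`), so any constructor default can be overridden from the command line:

```
from ast import literal_eval

config = literal_eval("{'scale':4, 'gen_blocks':10, 'dis_blocks':5}")
print(config['gen_blocks'])  # 10
```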
85 |
86 | #### Train xSRGAN x4 WGAN-GP model
87 | ```
88 | python3 main.py --root --g-model g_xsrgan --d-model d_xsrgan_ad --model-config "{'scale':4, 'gen_blocks':10, 'dis_blocks':5}" --scale 4 --reconstruction-weight 1.0 --perceptual-weight 1.0 --adversarial-weight 0.005 --crop-size 64 --epochs 1200 --step-size 900 --gen-to-load --wgan --penalty-weight 10
89 | ```
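
The `--*-weight` flags scale the terms of the generator objective, which are summed linearly in `_generator_iteration` (see `trainer.py`); schematically, with the weights from the command above:

```
import torch

# dummy stand-ins for the L1, VGG-perceptual and adversarial terms
loss_recon, loss_perc, loss_adv = torch.rand(3)
loss = 1.0 * loss_recon + 1.0 * loss_perc + 0.005 * loss_adv
```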
90 |
91 |
92 | #### Train xSRGAN x4 with SN-discriminator model
93 | ```
94 | python3 main.py --root --g-model g_xsrgan --d-model d_xsrgan --model-config "{'scale':4, 'gen_blocks':10, 'dis_blocks':5, 'spectral':True}" --scale 4 --reconstruction-weight 1.0 --perceptual-weight 1.0 --adversarial-weight 0.01 --crop-size 40 --epochs 2000 --step-size 800 --gen-to-load --dis-betas 0 0.9
95 | ```
96 |
97 | #### Eval xSRGAN x4 model
98 | ```
99 | python3 main.py --root --g-model g_xsrgan --d-model d_xsrgan --model-config "{'scale':4, 'gen_blocks':10, 'dis_blocks':5}" --scale 4 --evaluation --gen-to-load
100 | ```
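
Evaluation reports PSNR, computed by `compute_psnr` from `utils/misc.py` (not shown in this excerpt). As a reminder, a standard definition for images in the range [0, 1] is:

```
import torch

def psnr(output, target, max_value=1.0):
    # peak signal-to-noise ratio in dB
    mse = torch.mean((output - target) ** 2)
    return 10 * torch.log10(max_value ** 2 / mse)
```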
101 |
102 | ### Gaussian denoising
103 | Pretrained models are available at: [LINK](https://www.dropbox.com/s/zychmfzx52y8tvq/denoising_pretrained.zip?dl=0).
104 |
105 | #### Dataset preparation
106 | For the denoising task, the dataset should contain only clean images, in a folder structure of:
107 |
108 | ```txt
109 | train
110 | ├── img
111 | val
112 | ├── img
113 | ```
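
No noisy images are stored on disk; the loader synthesizes inputs on the fly by adding white Gaussian noise to each clean crop (see `denoising/data/datasets.py`):

```
import torch

# target: a clean image tensor in [0, 1]; noise_sigma is on a 0-255 scale
target = torch.rand(3, 64, 64)
noise_sigma = 50  # with --blind, drawn uniformly from [0, --noise-sigma]
noisy = target + (noise_sigma / 255.) * torch.randn_like(target)
```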
114 |
115 | #### Train xDNCNN Grayscale 50 sigma PSNR model
116 | ```
117 | python3 main.py --root --g-model g_xdncnn --d-model d_xdncnn --model-config "{'gen_blocks':10, 'dis_blocks':4, 'in_channels':1}" --reconstruction-weight 1.0 --perceptual-weight 0 --adversarial-weight 0 --crop-size 50 --gray-scale --noise-sigma 50 --epochs 500 --step-size 150
118 | ```
119 |
120 | #### Train xDNCNN 75 sigma PSNR model
121 | ```
122 | python3 main.py --root --g-model g_xdncnn --d-model d_xdncnn --model-config "{'gen_blocks':10, 'dis_blocks':4, 'in_channels':3}" --reconstruction-weight 1.0 --perceptual-weight 0 --adversarial-weight 0 --crop-size 64 --noise-sigma 75 --epochs 1000 --step-size 300
123 | ```
124 |
125 | #### Train xDNCNN 75 sigma WGAN-GP model
126 | ```
127 | python3 main.py --root --g-model g_xdncnn --d-model d_xdncnn --model-config "{'gen_blocks':10, 'dis_blocks':4, 'in_channels':3}" --reconstruction-weight 1.0 --perceptual-weight 1.0 --adversarial-weight 0.01 --crop-size 72 --noise-sigma 75 --epochs 1000 --step-size 300 --gen-to-load --wgan --penalty-weight 10
128 | ```
129 |
130 | #### Train xDNCNN Grayscale blind PSNR model
131 | ```
132 | python3 main.py --root --g-model g_xdncnn --d-model d_xdncnn --model-config "{'gen_blocks':10, 'dis_blocks':5, 'in_channels':1}" --reconstruction-weight 1.0 --perceptual-weight 0 --adversarial-weight 0 --crop-size 50 --gray-scale --noise-sigma 50 --blind --epochs 500 --step-size 150
133 | ```
--------------------------------------------------------------------------------
/denoising/data/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .datasets import DatasetNoise
3 | from torch.utils.data import DataLoader
4 |
5 | def get_loaders(args):
6 | # datasets
7 | dataset_train = DatasetNoise(root=os.path.join(args.root, 'train'), noise_sigma=args.noise_sigma, training=True, crop_size=args.crop_size, blind_denoising=args.blind, gray_scale=args.gray_scale)
8 | dataset_val = DatasetNoise(root=os.path.join(args.root, 'val'), noise_sigma=args.noise_sigma, training=False, crop_size=args.crop_size, blind_denoising=args.blind, gray_scale=args.gray_scale, max_size=args.max_size)
9 |
10 | # loaders
11 | loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
12 | loader_eval = DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=1)
13 | loaders = {'train': loader_train, 'eval': loader_eval}
14 |
15 | return loaders
16 |
--------------------------------------------------------------------------------
/denoising/data/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import glob
3 | import os
4 | import random
5 | from torchvision import transforms
6 | from PIL import Image
7 |
8 | class DatasetNoise(torch.utils.data.dataset.Dataset):
9 | def __init__(self, root='', noise_sigma=50., training=True, crop_size=60, blind_denoising=False, gray_scale=False, max_size=None):
10 | self.root = root
11 | self.noise_sigma = noise_sigma
12 | self.training = training
13 | self.crop_size = crop_size
14 | self.blind_denoising = blind_denoising
15 | self.gray_scale = gray_scale
16 | self.max_size = max_size
17 |
18 | self._init()
19 |
20 | def _init(self):
21 | # data paths
22 | targets = glob.glob(os.path.join(self.root, 'img', '*.*'))[:self.max_size]
23 | self.paths = {'target' : targets}
24 |
25 | # transforms
26 | t_list = [transforms.ToTensor()]
27 | self.image_transform = transforms.Compose(t_list)
28 |
29 | def _get_augment_params(self, size):
30 | random.seed(random.randint(0, 12345))
31 |
32 | # position
33 | w_size, h_size = size
34 | x = random.randint(0, max(0, w_size - self.crop_size))
35 | y = random.randint(0, max(0, h_size - self.crop_size))
36 |
37 | # flip
38 | flip = random.random() > 0.5
39 | return {'crop_pos': (x, y), 'flip': flip}
40 |
41 | def _augment(self, image, aug_params):
42 | x, y = aug_params['crop_pos']
43 | image = image.crop((x, y, x + self.crop_size, y + self.crop_size))
44 | if aug_params['flip']:
45 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
46 | return image
47 |
48 | def __getitem__(self, index):
49 | # target image
50 | if self.gray_scale:
51 | target = Image.open(self.paths['target'][index]).convert('L')
52 | else:
53 | target = Image.open(self.paths['target'][index]).convert('RGB')
54 |
55 | # transform
56 | if self.training:
57 | aug_params = self._get_augment_params(target.size)
58 | target = self._augment(target, aug_params)
59 | target = self.image_transform(target)
60 |
61 | # add noise
62 | if self.blind_denoising:
63 | noise_sigma = random.randint(0, self.noise_sigma)
64 | else:
65 | noise_sigma = self.noise_sigma
66 | input = target + (noise_sigma / 255.) * torch.randn_like(target)
67 |
68 | return {'input': input, 'target': target, 'path': self.paths['target'][index]}
69 |
70 | def __len__(self):
71 | return len(self.paths['target'])
--------------------------------------------------------------------------------
/denoising/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import logging
4 | import signal
5 | import sys
6 | import torch.backends.cudnn as cudnn
7 | from trainer import Trainer
8 | from datetime import datetime
9 | from os import path
10 | from utils import misc
11 |
12 | # torch.autograd.set_detect_anomaly(True)
13 |
14 | def get_arguments():
15 |     parser = argparse.ArgumentParser(description='denoising')
16 | parser.add_argument('--device', default='cuda', help='device assignment ("cpu" or "cuda")')
17 |     parser.add_argument('--device-ids', default=[0], type=int, nargs='+', help='device ids assignment (e.g. 0 1 2 3)')
18 | parser.add_argument('--d-model', default='d_dncnn', help='discriminator architecture (default: dncnn)')
19 | parser.add_argument('--g-model', default='g_dncnn', help='generator architecture (default: dncnn)')
20 | parser.add_argument('--model-config', default='', help='additional architecture configuration')
21 | parser.add_argument('--dis-to-load', default='', help='resume training from file (default: None)')
22 | parser.add_argument('--gen-to-load', default='', help='resume training from file (default: None)')
23 | parser.add_argument('--root', default='', help='root dataset folder')
24 | parser.add_argument('--noise-sigma', default=50, type=int, help='noise-sigma (default: 50)')
25 |     parser.add_argument('--crop-size', default=64, type=int, help='training crop size (default: 64)')
26 | parser.add_argument('--gray-scale', default=False, action='store_true', help='gray-scale images (default: false)')
27 | parser.add_argument('--blind', default=False, action='store_true', help='blind-denoising images (default: false)')
28 | parser.add_argument('--max-size', default=None, type=int, help='validation set max-size (default: None)')
29 | parser.add_argument('--num-workers', default=2, type=int, help='number of workers (default: 2)')
30 | parser.add_argument('--batch-size', default=16, type=int, help='batch-size (default: 16)')
31 | parser.add_argument('--epochs', default=1000, type=int, help='epochs (default: 1000)')
32 | parser.add_argument('--lr', default=2e-4, type=float, help='lr (default: 2e-4)')
33 |     parser.add_argument('--gen-betas', default=[0.9, 0.99], nargs=2, type=float, help='generator Adam betas (default: 0.9, 0.99)')
34 |     parser.add_argument('--dis-betas', default=[0.9, 0.99], nargs=2, type=float, help='discriminator Adam betas (default: 0.9, 0.99)')
35 | parser.add_argument('--num-critic', default=1, type=int, help='critic iterations (default: 1)')
36 | parser.add_argument('--wgan', default=False, action='store_true', help='critic wgan loss (default: false)')
37 | parser.add_argument('--relativistic', default=False, action='store_true', help='relativistic wgan loss (default: false)')
38 | parser.add_argument('--step-size', default=300, type=int, help='scheduler step size (default: 300)')
39 | parser.add_argument('--gamma', default=0.5, type=float, help='scheduler gamma (default: 0.5)')
40 | parser.add_argument('--penalty-weight', default=0, type=float, help='gradient penalty weight (default: 0)')
41 | parser.add_argument('--range-weight', default=0, type=float, help='pixel-weight (default: 0)')
42 | parser.add_argument('--reconstruction-weight', default=1.0, type=float, help='reconstruction-weight (default: 1.0)')
43 | parser.add_argument('--perceptual-weight', default=0, type=float, help='perceptual-weight (default: 0)')
44 | parser.add_argument('--adversarial-weight', default=0.01, type=float, help='adversarial-weight (default: 0.01)')
45 | parser.add_argument('--style-weight', default=0, type=float, help='style-weight (default: 0)')
46 | parser.add_argument('--print-every', default=20, type=int, help='print-every (default: 20)')
47 | parser.add_argument('--eval-every', default=50, type=int, help='eval-every (default: 50)')
48 | parser.add_argument('--results-dir', metavar='RESULTS_DIR', default='./results', help='results dir')
49 | parser.add_argument('--save', metavar='SAVE', default='', help='saved folder')
50 | parser.add_argument('--evaluation', default=False, action='store_true', help='evaluate a model (default: false)')
51 | parser.add_argument('--use-tb', default=False, action='store_true', help='use tensorboardx (default: false)')
52 | args = parser.parse_args()
53 |
54 | time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
55 | if args.save == '':
56 | args.save = time_stamp
57 | args.save_path = path.join(args.results_dir, args.save)
58 | return args
59 |
60 | def main():
61 | # arguments
62 | args = get_arguments()
63 |
64 | # cuda
65 | if 'cuda' in args.device and torch.cuda.is_available():
66 | torch.cuda.set_device(args.device_ids[0])
67 | cudnn.benchmark = True
68 | else:
69 | args.device_ids = None
70 |
71 | # set logs
72 | misc.mkdir(args.save_path)
73 | misc.mkdir(path.join(args.save_path, 'images'))
74 | misc.setup_logging(path.join(args.save_path, 'log.txt'))
75 |
76 | # print logs
77 | logging.info(args)
78 |
79 | # trainer
80 | trainer = Trainer(args)
81 |
82 | if args.evaluation:
83 | trainer.eval()
84 | else:
85 | trainer.train()
86 |
87 | if __name__ == '__main__':
88 | # enables a ctrl-c without triggering errors
89 | signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))
90 | main()
91 |
--------------------------------------------------------------------------------
/denoising/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .dncnn import *
2 | from .xdncnn import *
3 | from .xdense import *
--------------------------------------------------------------------------------
/denoising/models/dncnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
5 |
6 | __all__ = ['g_dncnn', 'd_dncnn']
7 |
8 | def initialize_weights(net, scale=1.):
9 | if not isinstance(net, list):
10 | net = [net]
11 | for layer in net:
12 | for m in layer.modules():
13 | if isinstance(m, nn.Conv2d):
14 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
15 | m.weight.data *= scale # for residual block
16 | if m.bias is not None:
17 | m.bias.data.zero_()
18 | elif isinstance(m, nn.Linear):
19 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
20 | m.weight.data *= scale
21 | if m.bias is not None:
22 | m.bias.data.zero_()
23 | elif isinstance(m, nn.BatchNorm2d):
24 | nn.init.constant_(m.weight, 1)
25 | nn.init.constant_(m.bias.data, 0.0)
26 |
27 | class GenBlock(nn.Module):
28 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, bias=True):
29 | super(GenBlock, self).__init__()
30 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias)
31 | self.bn = nn.BatchNorm2d(num_features=out_channels)
32 | self.relu = nn.ReLU(inplace=True)
33 |
34 | initialize_weights([self.conv, self.bn], 0.02)
35 |
36 | def forward(self, x):
37 | x = self.relu(self.bn(self.conv(x)))
38 | return x
39 |
40 | class DisBlock(nn.Module):
41 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False):
42 | super(DisBlock, self).__init__()
43 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
44 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias)
45 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias)
46 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True)
47 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True)
48 |
49 | initialize_weights([self.conv1, self.conv2], 0.1)
50 |
51 | if normalization:
52 | self.conv1 = SpectralNorm(self.conv1)
53 | self.conv2 = SpectralNorm(self.conv2)
54 |
55 | def forward(self, x):
56 | x = self.lrelu(self.bn1(self.conv1(x)))
57 | x = self.lrelu(self.bn2(self.conv2(x)))
58 | return x
59 |
60 | class Generator(nn.Module):
61 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks):
62 | super(Generator, self).__init__()
63 |
64 | # image to features
65 | self.image_to_features = GenBlock(in_channels=in_channels, out_channels=num_features)
66 |
67 | # features
68 | blocks = []
69 | for _ in range(gen_blocks):
70 | blocks.append(GenBlock(in_channels=num_features, out_channels=num_features, bias=False))
71 | self.features = nn.Sequential(*blocks)
72 |
73 | # features to image
74 | self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2)
75 | initialize_weights([self.features_to_image], 0.02)
76 |
77 | def forward(self, x):
78 | r = x
79 | x = self.image_to_features(x)
80 | x = self.features(x)
81 | x = self.features_to_image(x)
82 | x += r
83 | return x
84 |
85 | class Discriminator(nn.Module):
86 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks):
87 | super(Discriminator, self).__init__()
88 |
89 | # image to features
90 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False)
91 |
92 | # features
93 | blocks = []
94 | for i in range(0, dis_blocks - 1):
95 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False))
96 | self.features = nn.Sequential(*blocks)
97 |
98 | # classifier
99 | self.classifier = nn.Conv2d(in_channels=num_features * min(pow(2, dis_blocks - 1), 8), out_channels=1, kernel_size=4, padding=0)
100 |
101 | def forward(self, x):
102 | x = self.image_to_features(x)
103 | x = self.features(x)
104 | x = self.classifier(x)
105 | x = x.flatten(start_dim=1).mean(dim=-1)
106 | return x
107 |
108 | class SNDiscriminator(nn.Module):
109 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks):
110 | super(SNDiscriminator, self).__init__()
111 |
112 | # image to features
113 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True)
114 |
115 | # features
116 | blocks = []
117 | for i in range(0, dis_blocks - 1):
118 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True))
119 | self.features = nn.Sequential(*blocks)
120 |
121 | # classifier
122 | self.classifier = SpectralNorm(nn.Conv2d(in_channels=num_features * min(pow(2, dis_blocks - 1), 8), out_channels=1, kernel_size=4, padding=0))
123 |
124 | def forward(self, x):
125 | x = self.image_to_features(x)
126 | x = self.features(x)
127 | x = self.classifier(x)
128 | x = x.flatten(start_dim=1).mean(dim=-1)
129 | return x
130 |
131 | def g_dncnn(**config):
132 |     config.setdefault('in_channels', 3)
133 |     config.setdefault('num_features', 64)
134 |     config.setdefault('gen_blocks', 8)
135 |     config.setdefault('dis_blocks', 5)
136 | 
137 |     # accept and discard a 'spectral' key so the same --model-config can
138 |     # be shared between the generator and discriminator factories
139 |     _ = config.pop('spectral', False)
140 | 
141 |     return Generator(**config)
142 | 
143 | def d_dncnn(**config):
144 |     config.setdefault('in_channels', 3)
145 |     config.setdefault('num_features', 64)
146 |     config.setdefault('gen_blocks', 8)
147 |     config.setdefault('dis_blocks', 5)
148 | 
149 |     # return the spectrally-normalized variant when requested, matching
150 |     # the 'spectral' handling in xdncnn.py and xdense.py
151 |     sn = config.pop('spectral', False)
152 | 
153 |     if sn:
154 |         return SNDiscriminator(**config)
155 |     else:
156 |         return Discriminator(**config)
--------------------------------------------------------------------------------
/denoising/models/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/denoising/models/modules/__init__.py
--------------------------------------------------------------------------------
/denoising/models/modules/activations.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class xUnit(nn.Module):
4 | def __init__(self, num_features=64, kernel_size=7, batch_norm=False):
5 | super(xUnit, self).__init__()
6 | # xUnit
7 | self.features = nn.Sequential(
8 |             nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
9 | nn.ReLU(),
10 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
11 |             nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
12 | nn.Sigmoid()
13 | )
14 |
15 | def forward(self, x):
16 | a = self.features(x)
17 | r = x * a
18 | return r
19 |
20 | class xUnitS(nn.Module):
21 | def __init__(self, num_features=64, kernel_size=7, batch_norm=False):
22 | super(xUnitS, self).__init__()
23 | # slim xUnit
24 | self.features = nn.Sequential(
25 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
26 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
27 | nn.Sigmoid()
28 | )
29 |
30 | def forward(self, x):
31 | a = self.features(x)
32 | r = x * a
33 | return r
34 |
35 | class xUnitD(nn.Module):
36 | def __init__(self, num_features=64, kernel_size=7, batch_norm=False):
37 | super(xUnitD, self).__init__()
38 | # dense xUnit
39 | self.features = nn.Sequential(
40 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=1, padding=0),
41 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
42 | nn.ReLU(),
43 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
44 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
45 | nn.Sigmoid()
46 | )
47 |
48 | def forward(self, x):
49 | a = self.features(x)
50 | r = x * a
51 | return r
52 |
53 | class Identity(nn.Module):
54 | def __init__(self,):
55 | super(Identity, self).__init__()
56 |
57 | def forward(self, x):
58 | return x
--------------------------------------------------------------------------------
/denoising/models/modules/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .misc import shave_edge
5 | from ..vgg import MultiVGGFeaturesExtractor
6 |
7 | class RangeLoss(nn.Module):
8 | def __init__(self, min_value=0., max_value=1., invalidity_margins=None):
9 | super(RangeLoss, self).__init__()
10 | self.min_value = min_value
11 | self.max_value = max_value
12 | self.invalidity_margins = invalidity_margins
13 |
14 | def forward(self, inputs):
15 | if self.invalidity_margins:
16 | inputs = shave_edge(inputs, self.invalidity_margins, self.invalidity_margins)
17 | loss = (F.relu(self.min_value - inputs) + F.relu(inputs - self.max_value)).mean()
18 | return loss
19 |
20 | class PerceptualLoss(nn.Module):
21 | def __init__(self, features_to_compute, criterion=torch.nn.L1Loss(), shave_edge=None):
22 | super(PerceptualLoss, self).__init__()
23 | self.criterion = criterion
24 | self.features_extractor = MultiVGGFeaturesExtractor(target_features=features_to_compute, shave_edge=shave_edge).eval()
25 |
26 | def forward(self, inputs, targets):
27 | inputs_fea = self.features_extractor(inputs)
28 | with torch.no_grad():
29 | targets_fea = self.features_extractor(targets)
30 |
31 | loss = 0
32 | for key in inputs_fea.keys():
33 | loss += self.criterion(inputs_fea[key], targets_fea[key].detach())
34 |
35 | return loss
36 |
37 | class StyleLoss(nn.Module):
38 | def __init__(self, features_to_compute, criterion=torch.nn.L1Loss(), shave_edge=None):
39 | super(StyleLoss, self).__init__()
40 | self.criterion = criterion
41 | self.features_extractor = MultiVGGFeaturesExtractor(target_features=features_to_compute, use_input_norm=True, shave_edge=shave_edge).eval()
42 |
43 | def forward(self, inputs, targets):
44 | inputs_fea = self.features_extractor(inputs)
45 | with torch.no_grad():
46 | targets_fea = self.features_extractor(targets)
47 |
48 | loss = 0
49 | for key in inputs_fea.keys():
50 | inputs_gram = self._gram_matrix(inputs_fea[key])
51 | with torch.no_grad():
52 | targets_gram = self._gram_matrix(targets_fea[key]).detach()
53 |
54 | loss += self.criterion(inputs_gram, targets_gram)
55 |
56 | return loss
57 |
58 | def _gram_matrix(self, x):
59 | a, b, c, d = x.size()
60 | features = x.view(a, b, c * d)
61 | gram = features.bmm(features.transpose(1, 2))
62 | return gram.div(b * c * d)
--------------------------------------------------------------------------------
/denoising/models/modules/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class UpsampleX2(nn.Module):
6 | def __init__(self, in_channels, out_channels, kernel_size=3):
7 | super(UpsampleX2, self).__init__()
8 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=(out_channels * 4), kernel_size=kernel_size, padding=(kernel_size // 2))
9 | self.shuffler = nn.PixelShuffle(2)
10 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
11 |
12 | def forward(self, x):
13 | return self.lrelu(self.shuffler(self.conv(x)))
14 |
15 | def center_crop(x, height, width):
16 | crop_h = torch.FloatTensor([x.size()[2]]).sub(height).div(-2)
17 | crop_w = torch.FloatTensor([x.size()[3]]).sub(width).div(-2)
18 |
19 | return F.pad(x, [
20 | crop_w.ceil().int()[0], crop_w.floor().int()[0],
21 | crop_h.ceil().int()[0], crop_h.floor().int()[0],
22 | ])
23 |
24 | def shave_edge(x, shave_h, shave_w):
25 | return F.pad(x, [-shave_w, -shave_w, -shave_h, -shave_h])
26 |
27 | def shave_modulo(x, factor):
28 | shave_w = x.size(-1) % factor
29 | shave_h = x.size(-2) % factor
30 | return F.pad(x, [0, -shave_w, 0, -shave_h])
31 |
32 | if __name__ == "__main__":
33 | x = torch.randn(1, 2, 4, 6)
34 | y = shave_edge(x, 1, 2)
35 | print(x)
36 | print(y)
--------------------------------------------------------------------------------
/denoising/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn as nn
4 | from models.modules.misc import shave_edge
5 | from collections import OrderedDict
6 |
7 | names = {'vgg19': ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
8 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
9 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
10 | 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
11 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2',
12 | 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
13 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
14 | 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'],
15 |
16 | 'vgg19_bn': ['conv1_1', 'bn1_1', 'relu1_1', 'conv1_2', 'bn1_2', 'relu1_2', 'pool1',
17 | 'conv2_1', 'bn2_1', 'relu2_1', 'conv2_2', 'bn2_2', 'relu2_2', 'pool2',
18 | 'conv3_1', 'bn3_1', 'relu3_1', 'conv3_2', 'bn3_2', 'relu3_2',
19 | 'conv3_3', 'bn3_3', 'relu3_3', 'conv3_4', 'bn3_4', 'relu3_4', 'pool3',
20 | 'conv4_1', 'bn4_1', 'relu4_1', 'conv4_2', 'bn4_2', 'relu4_2',
21 | 'conv4_3', 'bn4_3', 'relu4_3', 'conv4_4', 'bn4_4', 'relu4_4', 'pool4',
22 | 'conv5_1', 'bn5_1', 'relu5_1', 'conv5_2', 'bn5_2', 'relu5_2',
23 | 'conv5_3', 'bn5_3', 'relu5_3', 'conv5_4', 'bn5_4', 'relu5_4', 'pool5']
24 | }
25 |
26 |
27 | class VGGFeaturesExtractor(nn.Module):
28 | def __init__(self, feature_layer='conv5_4', use_bn=False, use_input_norm=True, requires_grad=False):
29 | super(VGGFeaturesExtractor, self).__init__()
30 | self.use_input_norm = use_input_norm
31 |
32 | if use_bn:
33 | model = torchvision.models.vgg19_bn(pretrained=True)
34 | else:
35 | model = torchvision.models.vgg19(pretrained=True)
36 |
37 | if self.use_input_norm:
38 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
39 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
40 | self.register_buffer('mean', mean)
41 | self.register_buffer('std', std)
42 |
43 | layer_index = names['vgg19_bn'].index(feature_layer) if use_bn else names['vgg19'].index(feature_layer)
44 | self.features = nn.Sequential(*list(model.features.children())[:(layer_index + 1)])
45 |
46 | if not requires_grad:
47 | for k, v in self.features.named_parameters():
48 | v.requires_grad = False
49 | self.features.eval()
50 |
51 | def forward(self, x):
52 | # Assume input range is [0, 1]
53 | if self.use_input_norm:
54 | x = (x - self.mean) / self.std
55 | output = self.features(x)
56 | return output
57 |
58 | class MultiVGGFeaturesExtractor(nn.Module):
59 | def __init__(self, target_features=('relu1_1', 'relu2_1', 'relu3_1'), use_bn=False, use_input_norm=True, requires_grad=False, shave_edge=None):
60 | super(MultiVGGFeaturesExtractor, self).__init__()
61 | self.use_input_norm = use_input_norm
62 | self.target_features = target_features
63 | self.shave_edge = shave_edge
64 |
65 | if use_bn:
66 | model = torchvision.models.vgg19_bn(pretrained=True)
67 | names_key = 'vgg19_bn'
68 | else:
69 | model = torchvision.models.vgg19(pretrained=True)
70 | names_key = 'vgg19'
71 |
72 | if self.use_input_norm:
73 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
74 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
75 | self.register_buffer('mean', mean)
76 | self.register_buffer('std', std)
77 |
78 | self.target_indexes = [names[names_key].index(k) for k in self.target_features]
79 | self.features = nn.Sequential(*list(model.features.children())[:(max(self.target_indexes) + 1)])
80 |
81 | if not requires_grad:
82 | for k, v in self.features.named_parameters():
83 | v.requires_grad = False
84 | self.features.eval()
85 |
86 | def forward(self, x):
87 | if self.shave_edge:
88 | x = shave_edge(x, self.shave_edge, self.shave_edge)
89 |
90 | # assume input range is [0, 1]
91 | if self.use_input_norm:
92 | x = (x - self.mean) / self.std
93 |
94 | output = OrderedDict()
95 | for key, layer in self.features._modules.items():
96 | x = layer(x)
97 | if int(key) in self.target_indexes:
98 | output.update({self.target_features[self.target_indexes.index(int(key))]: x})
99 | return output
--------------------------------------------------------------------------------
/denoising/models/xdense.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
5 | from .modules.activations import xUnitD
6 | from .modules.misc import center_crop
7 |
8 | __all__ = ['g_xdense', 'd_xdense']
9 |
10 | def initialize_weights(net, scale=1.):
11 | if not isinstance(net, list):
12 | net = [net]
13 | for layer in net:
14 | for m in layer.modules():
15 | if isinstance(m, nn.Conv2d):
16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
17 | m.weight.data *= scale # for residual block
18 | if m.bias is not None:
19 | m.bias.data.zero_()
20 | elif isinstance(m, nn.Linear):
21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
22 | m.weight.data *= scale
23 | if m.bias is not None:
24 | m.bias.data.zero_()
25 | elif isinstance(m, nn.BatchNorm2d):
26 | nn.init.constant_(m.weight, 1)
27 | nn.init.constant_(m.bias.data, 0.0)
28 |
29 | class xModule(nn.Module):
30 | def __init__(self, in_channels, out_channels):
31 | super(xModule, self).__init__()
32 | # features
33 | self.features = nn.Sequential(
34 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
35 | xUnitD(num_features=out_channels, batch_norm=True),
36 | nn.BatchNorm2d(num_features=out_channels)
37 | )
38 |
39 | def forward(self, x):
40 | x = self.features(x)
41 | return x
42 |
43 | class xDenseLayer(nn.Module):
44 | def __init__(self, in_channels, growth_rate, bn_size):
45 | super(xDenseLayer, self).__init__()
46 | # features
47 | self.features = nn.Sequential(
48 | nn.BatchNorm2d(num_features=in_channels),
49 | nn.ReLU(inplace=True),
50 | nn.Conv2d(in_channels=in_channels,out_channels=bn_size * growth_rate, kernel_size=1, stride=1, bias=False),
51 | nn.BatchNorm2d(num_features=bn_size * growth_rate),
52 | nn.ReLU(inplace=True),
53 | xModule(in_channels=bn_size * growth_rate, out_channels=growth_rate)
54 | )
55 |
56 | def forward(self, x):
57 | f = self.features(x)
58 | return torch.cat([x, f], dim=1)
59 |
60 | class xDenseBlock(nn.Module):
61 | def __init__(self, in_channels, num_layers, growth_rate, bn_size):
62 | super(xDenseBlock, self).__init__()
63 | # features
64 | blocks = []
65 | for i in range(num_layers):
66 | blocks.append(xDenseLayer(in_channels=in_channels + growth_rate*i, growth_rate=growth_rate, bn_size=bn_size))
67 | self.features = nn.Sequential(*blocks)
68 |
69 | def forward(self, x):
70 | x = self.features(x)
71 | return x
72 |
73 | class Transition(nn.Module):
74 | def __init__(self, in_channels, out_channels):
75 | super(Transition, self).__init__()
76 | # features
77 | self.features = nn.Sequential(
78 | nn.BatchNorm2d(num_features=in_channels),
79 | nn.ReLU(inplace=True),
80 | nn.Conv2d(in_channels=in_channels,out_channels=out_channels, kernel_size=1, stride=1, bias=False),
81 | )
82 |
83 | def forward(self, x):
84 | x = self.features(x)
85 | return x
86 |
87 | class DisBlock(nn.Module):
88 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False):
89 | super(DisBlock, self).__init__()
90 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
91 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias)
92 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias)
93 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True)
94 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True)
95 |
96 | initialize_weights([self.conv1, self.conv2], 0.1)
97 |
98 | if normalization:
99 | self.conv1 = SpectralNorm(self.conv1)
100 | self.conv2 = SpectralNorm(self.conv2)
101 |
102 | def forward(self, x):
103 | x = self.lrelu(self.bn1(self.conv1(x)))
104 | x = self.lrelu(self.bn2(self.conv2(x)))
105 | return x
106 |
107 | class Generator(nn.Module):
108 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size):
109 | super(Generator, self).__init__()
110 |
111 | # image to features
112 | self.image_to_features = xModule(in_channels=in_channels, out_channels=num_features)
113 |
114 | # features
115 | blocks = []
116 | self.num_features = num_features
117 | for i, num_layers in enumerate(gen_blocks):
118 | blocks.append(xDenseBlock(in_channels=self.num_features, num_layers=num_layers, growth_rate=growth_rate, bn_size=bn_size))
119 | self.num_features += num_layers * growth_rate
120 |
121 | if i != len(gen_blocks) - 1:
122 | blocks.append(Transition(in_channels=self.num_features, out_channels=self.num_features // 2))
123 | self.num_features = self.num_features // 2
124 |
125 | self.features = nn.Sequential(*blocks)
126 |
127 | # features to image
128 | self.features_to_image = nn.Conv2d(in_channels=self.num_features, out_channels=in_channels, kernel_size=5, padding=2)
129 |
130 | def forward(self, x):
131 | r = x
132 | x = self.image_to_features(x)
133 | x = self.features(x)
134 | x = self.features_to_image(x)
135 | x += r
136 | return x
137 |
138 | class Discriminator(nn.Module):
139 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size):
140 | super(Discriminator, self).__init__()
141 | self.crop_size = 4 * pow(2, dis_blocks)
142 |
143 | # image to features
144 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False)
145 |
146 | # features
147 | blocks = []
148 | for i in range(0, dis_blocks - 1):
149 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False))
150 | self.features = nn.Sequential(*blocks)
151 |
152 | # classifier
153 | self.classifier = nn.Sequential(
154 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100),
155 | nn.LeakyReLU(negative_slope=0.1),
156 | nn.Linear(100, 1)
157 | )
158 |
159 | def forward(self, x):
160 | x = center_crop(x, self.crop_size, self.crop_size)
161 | x = self.image_to_features(x)
162 | x = self.features(x)
163 | x = x.flatten(start_dim=1)
164 | x = self.classifier(x)
165 | return x
166 |
167 | class SNDiscriminator(nn.Module):
168 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size):
169 | super(SNDiscriminator, self).__init__()
170 | self.crop_size = 4 * pow(2, dis_blocks)
171 |
172 | # image to features
173 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True)
174 |
175 | # features
176 | blocks = []
177 | for i in range(0, dis_blocks - 1):
178 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True))
179 | self.features = nn.Sequential(*blocks)
180 |
181 | # classifier
182 | self.classifier = nn.Sequential(
183 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)),
184 | nn.LeakyReLU(negative_slope=0.1),
185 | SpectralNorm(nn.Linear(100, 1))
186 | )
187 |
188 | def forward(self, x):
189 | x = center_crop(x, self.crop_size, self.crop_size)
190 | x = self.image_to_features(x)
191 | x = self.features(x)
192 | x = x.flatten(start_dim=1)
193 | x = self.classifier(x)
194 | return x
195 |
196 | def g_xdense(**config):
197 | config.setdefault('in_channels', 3)
198 | config.setdefault('num_features', 64)
199 | config.setdefault('gen_blocks', [4, 6, 8])
200 | config.setdefault('dis_blocks', 4)
201 | config.setdefault('growth_rate', 16)
202 | config.setdefault('bn_size', 2)
203 |
204 | _ = config.pop('spectral', False)
205 |
206 | return Generator(**config)
207 |
208 | def d_xdense(**config):
209 | config.setdefault('in_channels', 3)
210 | config.setdefault('num_features', 64)
211 | config.setdefault('gen_blocks', [4, 6, 8])
212 | config.setdefault('dis_blocks', 4)
213 | config.setdefault('growth_rate', 16)
214 | config.setdefault('bn_size', 2)
215 |
216 | sn = config.pop('spectral', False)
217 |
218 | if sn:
219 | return SNDiscriminator(**config)
220 | else:
221 | return Discriminator(**config)
222 |
--------------------------------------------------------------------------------
/denoising/models/xdncnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
5 | from .modules.activations import xUnitS
6 | from .modules.misc import center_crop
7 |
8 | __all__ = ['g_xdncnn', 'd_xdncnn']
9 |
10 | def initialize_weights(net, scale=1.):
11 | if not isinstance(net, list):
12 | net = [net]
13 | for layer in net:
14 | for m in layer.modules():
15 | if isinstance(m, nn.Conv2d):
16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
17 | m.weight.data *= scale # for residual block
18 | if m.bias is not None:
19 | m.bias.data.zero_()
20 | elif isinstance(m, nn.Linear):
21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
22 | m.weight.data *= scale
23 | if m.bias is not None:
24 | m.bias.data.zero_()
25 | elif isinstance(m, nn.BatchNorm2d):
26 | nn.init.constant_(m.weight, 1)
27 | nn.init.constant_(m.bias.data, 0.0)
28 |
29 | class GenBlock(nn.Module):
30 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, bias=True):
31 | super(GenBlock, self).__init__()
32 | self.bn = nn.BatchNorm2d(num_features=in_channels)
33 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias)
34 | self.xunit = xUnitS(num_features=out_channels, batch_norm=True)
35 |
36 | initialize_weights([self.conv, self.bn], 0.02)
37 |
38 | def forward(self, x):
39 | x = self.xunit(self.conv(self.bn(x)))
40 | return x
41 |
42 | class DisBlock(nn.Module):
43 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False):
44 | super(DisBlock, self).__init__()
45 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
46 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias)
47 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias)
48 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True)
49 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True)
50 |
51 | initialize_weights([self.conv1, self.conv2], 0.1)
52 |
53 | if normalization:
54 | self.conv1 = SpectralNorm(self.conv1)
55 | self.conv2 = SpectralNorm(self.conv2)
56 |
57 | def forward(self, x):
58 | x = self.lrelu(self.bn1(self.conv1(x)))
59 | x = self.lrelu(self.bn2(self.conv2(x)))
60 | return x
61 |
62 | class Generator(nn.Module):
63 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks):
64 | super(Generator, self).__init__()
65 |
66 | # image to features
67 | self.image_to_features = GenBlock(in_channels=in_channels, out_channels=num_features)
68 |
69 | # features
70 | blocks = []
71 | for _ in range(gen_blocks):
72 | blocks.append(GenBlock(in_channels=num_features, out_channels=num_features))
73 | self.features = nn.Sequential(*blocks)
74 |
75 | # features to image
76 | self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2)
77 | initialize_weights([self.features_to_image], 0.02)
78 |
79 | def forward(self, x):
80 | r = x
81 | x = self.image_to_features(x)
82 | x = self.features(x)
83 | x = self.features_to_image(x)
84 | x += r
85 | return x
86 |
87 | class Discriminator(nn.Module):
88 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks):
89 | super(Discriminator, self).__init__()
90 | self.crop_size = 4 * pow(2, dis_blocks)
91 |
92 | # image to features
93 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False)
94 |
95 | # features
96 | blocks = []
97 | for i in range(0, dis_blocks - 1):
98 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False))
99 | self.features = nn.Sequential(*blocks)
100 |
101 | # classifier
102 | self.classifier = nn.Sequential(
103 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100),
104 | nn.LeakyReLU(negative_slope=0.1),
105 | nn.Linear(100, 1)
106 | )
107 |
108 | def forward(self, x):
109 | x = center_crop(x, self.crop_size, self.crop_size)
110 | x = self.image_to_features(x)
111 | x = self.features(x)
112 | x = x.flatten(start_dim=1)
113 | x = self.classifier(x)
114 | return x
115 |
116 | class SNDiscriminator(nn.Module):
117 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks):
118 | super(SNDiscriminator, self).__init__()
119 | self.crop_size = 4 * pow(2, dis_blocks)
120 |
121 | # image to features
122 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True)
123 |
124 | # features
125 | blocks = []
126 | for i in range(0, dis_blocks - 1):
127 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True))
128 | self.features = nn.Sequential(*blocks)
129 |
130 | # classifier
131 | self.classifier = nn.Sequential(
132 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)),
133 | nn.LeakyReLU(negative_slope=0.1),
134 | SpectralNorm(nn.Linear(100, 1))
135 | )
136 |
137 | def forward(self, x):
138 | x = center_crop(x, self.crop_size, self.crop_size)
139 | x = self.image_to_features(x)
140 | x = self.features(x)
141 | x = x.flatten(start_dim=1)
142 | x = self.classifier(x)
143 | return x
144 |
145 | def g_xdncnn(**config):
146 | config.setdefault('in_channels', 3)
147 | config.setdefault('num_features', 64)
148 | config.setdefault('gen_blocks', 8)
149 | config.setdefault('dis_blocks', 4)
150 |
151 | _ = config.pop('spectral', False)
152 |
153 | return Generator(**config)
154 |
155 | def d_xdncnn(**config):
156 | config.setdefault('in_channels', 3)
157 | config.setdefault('num_features', 64)
158 | config.setdefault('gen_blocks', 8)
159 | config.setdefault('dis_blocks', 4)
160 |
161 | sn = config.pop('spectral', False)
162 |
163 | if sn:
164 | return SNDiscriminator(**config)
165 | else:
166 | return Discriminator(**config)
167 |
--------------------------------------------------------------------------------
/denoising/trainer.py:
--------------------------------------------------------------------------------
1 | import random
2 | import logging
3 | import models
4 | import os
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.optim.lr_scheduler import StepLR
8 | from torch.autograd import grad as torch_grad, Variable
9 | from data import get_loaders
10 | from ast import literal_eval
11 | from utils.recorderx import RecoderX
12 | from utils.misc import save_image, average, mkdir, compute_psnr
13 | from models.modules.losses import RangeLoss, PerceptualLoss, StyleLoss
14 |
15 | class Trainer():
16 | def __init__(self, args):
17 | # parameters
18 | self.args = args
19 | self.print_model = True
20 | self.invalidity_margins = None
21 |
22 | if self.args.use_tb:
23 | self.tb = RecoderX(log_dir=args.save_path)
24 |
25 | # initialize
26 | self._init()
27 |
28 | def _init_model(self):
29 | # initialize model
30 | if self.args.model_config != '':
31 | model_config = dict({}, **literal_eval(self.args.model_config))
32 | else:
33 | model_config = {}
34 |
35 | g_model = models.__dict__[self.args.g_model]
36 | d_model = models.__dict__[self.args.d_model]
37 | self.g_model = g_model(**model_config)
38 | self.d_model = d_model(**model_config)
39 |
40 | # loading weights
41 | if self.args.gen_to_load != '':
42 | logging.info('\nLoading g-model...')
43 | self.g_model.load_state_dict(torch.load(self.args.gen_to_load, map_location='cpu'))
44 | if self.args.dis_to_load != '':
45 | logging.info('\nLoading d-model...')
46 | self.d_model.load_state_dict(torch.load(self.args.dis_to_load, map_location='cpu'))
47 |
48 | # to cuda
49 | self.g_model = self.g_model.to(self.args.device)
50 | self.d_model = self.d_model.to(self.args.device)
51 |
52 | # parallel
53 | if self.args.device_ids and len(self.args.device_ids) > 1:
54 | self.g_model = torch.nn.DataParallel(self.g_model, self.args.device_ids)
55 | self.d_model = torch.nn.DataParallel(self.d_model, self.args.device_ids)
56 |
57 | # print model
58 | if self.print_model:
59 | logging.info(self.g_model)
60 | logging.info('Number of parameters in generator: {}\n'.format(sum([l.nelement() for l in self.g_model.parameters()])))
61 | logging.info(self.d_model)
62 | logging.info('Number of parameters in discriminator: {}\n'.format(sum([l.nelement() for l in self.d_model.parameters()])))
63 | self.print_model = False
64 |
65 | def _init_optim(self):
66 | # initialize optimizer
67 | self.g_optimizer = torch.optim.Adam(self.g_model.parameters(), lr=self.args.lr, betas=self.args.gen_betas)
68 | self.d_optimizer = torch.optim.Adam(self.d_model.parameters(), lr=self.args.lr, betas=self.args.dis_betas)
69 |
70 | # initialize scheduler
71 | self.g_scheduler = StepLR(self.g_optimizer, step_size=self.args.step_size, gamma=self.args.gamma)
72 | self.d_scheduler = StepLR(self.d_optimizer, step_size=self.args.step_size, gamma=self.args.gamma)
73 |
74 | # initialize criterion
75 | if self.args.reconstruction_weight:
76 | self.reconstruction = torch.nn.L1Loss().to(self.args.device)
77 | if self.args.perceptual_weight > 0.:
78 | self.perceptual = PerceptualLoss(features_to_compute=['conv5_4'], criterion=torch.nn.L1Loss(), shave_edge=self.invalidity_margins).to(self.args.device)
79 | if self.args.style_weight > 0.:
80 | self.style = StyleLoss(features_to_compute=['relu3_1', 'relu2_1'], shave_edge=self.invalidity_margins).to(self.args.device)
81 | if self.args.range_weight > 0.:
82 | self.range = RangeLoss(invalidity_margins=self.invalidity_margins).to(self.args.device)
83 |
84 | def _init(self):
85 | # init parameters
86 | self.step = 0
87 | self.losses = {'D': [], 'D_r': [], 'D_gp': [], 'D_f': [], 'G': [], 'G_recon': [], 'G_rng': [], 'G_perc': [], 'G_sty': [], 'G_adv': [], 'psnr': []}
88 |
89 | # initialize model
90 | self._init_model()
91 |
92 | # initialize optimizer
93 | self._init_optim()
94 |
95 | def _save_model(self, epoch):
96 | # save models
97 | torch.save(self.g_model.state_dict(), os.path.join(self.args.save_path, '{}_e{}.pt'.format(self.args.g_model, epoch + 1)))
98 | torch.save(self.d_model.state_dict(), os.path.join(self.args.save_path, '{}_e{}.pt'.format(self.args.d_model, epoch + 1)))
99 |
100 | def _set_require_grads(self, model, require_grad):
101 | for p in model.parameters():
102 | p.requires_grad_(require_grad)
103 |
104 | def _critic_hinge_iteration(self, inputs, targets):
105 | # require grads
106 | self._set_require_grads(self.d_model, True)
107 |
108 | # get generated data
109 | generated_data = self.g_model(inputs)
110 |
111 | # zero grads
112 | self.d_optimizer.zero_grad()
113 |
114 | # calculate probabilities on real and generated data
115 | d_real = self.d_model(targets)
116 | d_generated = self.d_model(generated_data.detach())
117 |
118 | # create total loss and optimize
119 | loss_r = F.relu(1.0 - d_real).mean()
120 | loss_f = F.relu(1.0 + d_generated).mean()
121 | loss = loss_r + loss_f
122 |
123 | # get gradient penalty
124 | if self.args.penalty_weight > 0.:
125 | gradient_penalty = self._gradient_penalty(targets, generated_data)
126 | loss += gradient_penalty
127 |
128 | loss.backward()
129 |
130 | self.d_optimizer.step()
131 |
132 | # record loss
133 | self.losses['D'].append(loss.data.item())
134 | self.losses['D_r'].append(loss_r.data.item())
135 | self.losses['D_f'].append(loss_f.data.item())
136 | if self.args.penalty_weight > 0.:
137 | self.losses['D_gp'].append(gradient_penalty.data.item())
138 |
139 | # require grads
140 | self._set_require_grads(self.d_model, False)
141 |
142 | def _critic_wgan_iteration(self, inputs, targets):
143 | # require grads
144 | self._set_require_grads(self.d_model, True)
145 |
146 | # get generated data
147 | generated_data = self.g_model(inputs)
148 |
149 | # zero grads
150 | self.d_optimizer.zero_grad()
151 |
152 | # calculate probabilities on real and generated data
153 | d_real = self.d_model(targets)
154 | d_generated = self.d_model(generated_data.detach())
155 |
156 | # create total loss and optimize
157 | if self.args.relativistic:
158 | loss_r = -(d_real - d_generated.mean()).mean()
159 | loss_f = (d_generated - d_real.mean()).mean()
160 | else:
161 | loss_r = -d_real.mean()
162 | loss_f = d_generated.mean()
163 | loss = loss_f + loss_r
164 |
165 | # get gradient penalty
166 | if self.args.penalty_weight > 0.:
167 | gradient_penalty = self._gradient_penalty(targets, generated_data)
168 | loss += gradient_penalty
169 |
170 | loss.backward()
171 |
172 | self.d_optimizer.step()
173 |
174 | # record loss
175 | self.losses['D'].append(loss.data.item())
176 | self.losses['D_r'].append(loss_r.data.item())
177 | self.losses['D_f'].append(loss_f.data.item())
178 | if self.args.penalty_weight > 0.:
179 | self.losses['D_gp'].append(gradient_penalty.data.item())
180 |
181 | # require grads
182 | self._set_require_grads(self.d_model, False)
183 |
184 | def _gradient_penalty(self, real_data, generated_data):
185 | batch_size = real_data.size()[0]
186 |
187 | # calculate interpolation
188 | alpha = torch.rand(batch_size, 1, 1, 1)
189 | alpha = alpha.expand_as(real_data)
190 | alpha = alpha.to(self.args.device)
191 | interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
192 | interpolated = Variable(interpolated, requires_grad=True)
193 | interpolated = interpolated.to(self.args.device)
194 |
195 | # calculate probability of interpolated examples
196 | prob_interpolated = self.d_model(interpolated)
197 |
198 | # calculate gradients of probabilities with respect to examples
199 | gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
200 | grad_outputs=torch.ones(prob_interpolated.size()).to(self.args.device),
201 | create_graph=True, retain_graph=True)[0]
202 |
203 | # gradients have shape (batch_size, num_channels, img_width, img_height),
204 | # so flatten to easily take norm per example in batch
205 | gradients = gradients.view(batch_size, -1)
206 |
207 |         # derivatives of the gradient close to 0 can cause problems because of
208 |         # the square root, so manually calculate the norm and add an epsilon
209 |         gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
210 | 
211 |         # return the weighted gradient penalty
212 |         return self.args.penalty_weight * ((gradients_norm - 1) ** 2).mean()
213 |
214 | def _generator_iteration(self, inputs, targets):
215 | # zero grads
216 | self.g_optimizer.zero_grad()
217 |
218 | # get generated data
219 | generated_data = self.g_model(inputs)
220 | loss = 0.
221 |
222 | # reconstruction loss
223 | if self.args.reconstruction_weight > 0.:
224 | loss_recon = self.reconstruction(generated_data, targets)
225 | loss += loss_recon * self.args.reconstruction_weight
226 | self.losses['G_recon'].append(loss_recon.data.item())
227 |
228 | # range loss
229 | if self.args.range_weight > 0.:
230 | loss_rng = self.range(generated_data)
231 | loss += loss_rng * self.args.range_weight
232 | self.losses['G_rng'].append(loss_rng.data.item())
233 |
234 | # adversarial loss
235 | if self.args.adversarial_weight > 0.:
236 | d_generated = self.d_model(generated_data)
237 | if self.args.relativistic:
238 | d_real = self.d_model(targets)
239 | loss_adv = (d_real - d_generated.mean()).mean() - (d_generated - d_real.mean()).mean()
240 | else:
241 | loss_adv = -d_generated.mean()
242 | loss += loss_adv * self.args.adversarial_weight
243 | self.losses['G_adv'].append(loss_adv.data.item())
244 |
245 | # perceptual loss
246 | if self.args.perceptual_weight > 0.:
247 | loss_perc = self.perceptual(generated_data, targets)
248 | loss += loss_perc * self.args.perceptual_weight
249 | self.losses['G_perc'].append(loss_perc.data.item())
250 |
251 | # style loss
252 | if self.args.style_weight > 0.:
253 | loss_sty = self.style(generated_data, targets)
254 | loss += loss_sty * self.args.style_weight
255 | self.losses['G_sty'].append(loss_sty.data.item())
256 |
257 | # backward loss
258 | loss.backward()
259 | self.g_optimizer.step()
260 |
261 | # record loss
262 | self.losses['G'].append(loss.data.item())
263 |
264 | def _train_iteration(self, data):
265 | # set inputs
266 | inputs = data['input'].to(self.args.device)
267 | targets = data['target'].to(self.args.device)
268 |
269 | # critic iteration
270 | if self.args.adversarial_weight > 0.:
271 | if self.args.wgan:
272 | self._critic_wgan_iteration(inputs, targets)
273 | else:
274 | self._critic_hinge_iteration(inputs, targets)
275 |
276 |         # only update the generator every num_critic iterations
277 | if self.step % self.args.num_critic == 0:
278 | self._generator_iteration(inputs, targets)
279 |
280 | # logging
281 | if self.step % self.args.print_every == 0:
282 | line2print = 'Iteration {}'.format(self.step)
283 | if self.args.adversarial_weight > 0.:
284 | line2print += ', D: {:.6f}, D_r: {:.6f}, D_f: {:.6f}'.format(self.losses['D'][-1], self.losses['D_r'][-1], self.losses['D_f'][-1])
285 | if self.args.penalty_weight > 0.:
286 | line2print += ', D_gp: {:.6f}'.format(self.losses['D_gp'][-1])
287 | if self.step > self.args.num_critic:
288 |                 line2print += ', G: {:.6f}'.format(self.losses['G'][-1])
289 | if self.args.reconstruction_weight:
290 | line2print += ', G_recon: {:.6f}'.format(self.losses['G_recon'][-1])
291 | if self.args.range_weight:
292 | line2print += ', G_rng: {:.6f}'.format(self.losses['G_rng'][-1])
293 | if self.args.perceptual_weight:
294 | line2print += ', G_perc: {:.6f}'.format(self.losses['G_perc'][-1])
295 | if self.args.style_weight:
296 | line2print += ', G_sty: {:.8f}'.format(self.losses['G_sty'][-1])
297 | if self.args.adversarial_weight:
298 |                 line2print += ', G_adv: {:.6f}'.format(self.losses['G_adv'][-1])
299 | logging.info(line2print)
300 |
301 | # plots for tensorboard
302 | if self.args.use_tb:
303 | if self.args.adversarial_weight > 0.:
304 | self.tb.add_scalar('data/loss_d', self.losses['D'][-1], self.step)
305 | if self.step > self.args.num_critic:
306 | self.tb.add_scalar('data/loss_g', self.losses['G'][-1], self.step)
307 |
308 | def _eval_iteration(self, data, epoch):
309 | # set inputs
310 | inputs = data['input'].to(self.args.device)
311 | targets = data['target']
312 | paths = data['path']
313 |
314 | # evaluation
315 | with torch.no_grad():
316 | outputs = self.g_model(inputs)
317 |
318 | # save image and compute psnr
319 | self._save_image(outputs, paths[0], epoch + 1)
320 | psnr = compute_psnr(outputs, targets)
321 |
322 | return psnr
323 |
324 | def _train_epoch(self, loader):
325 | self.g_model.train()
326 | self.d_model.train()
327 |
328 | # train over epochs
329 | for _, data in enumerate(loader):
330 | self._train_iteration(data)
331 | self.step += 1
332 |
333 | def _eval_epoch(self, loader, epoch):
334 | self.g_model.eval()
335 | psnrs = []
336 |
337 | # eval over epoch
338 | for _, data in enumerate(loader):
339 | psnr = self._eval_iteration(data, epoch)
340 | psnrs.append(psnr)
341 |
342 | # record psnr
343 | self.losses['psnr'].append(average(psnrs))
344 | logging.info('Evaluation: {:.3f}'.format(self.losses['psnr'][-1]))
345 | if self.args.use_tb:
346 | self.tb.add_scalar('data/psnr', self.losses['psnr'][-1], epoch)
347 |
348 | def _save_image(self, image, path, epoch):
349 | directory = os.path.join(self.args.save_path, 'images', 'epoch_{}'.format(epoch))
350 | save_path = os.path.join(directory, os.path.basename(path))
351 | mkdir(directory)
352 | save_image(image.data.cpu(), save_path)
353 |
354 | def _train(self, loaders):
355 | # run epoch iterations
356 | for epoch in range(self.args.epochs):
357 | # random seed
358 | torch.manual_seed(random.randint(1, 123456789))
359 |
360 | logging.info('\nEpoch {}'.format(epoch + 1))
361 |
362 | # train
363 | self._train_epoch(loaders['train'])
364 |
365 | # scheduler
366 | self.g_scheduler.step(epoch=epoch)
367 | self.d_scheduler.step(epoch=epoch)
368 |
369 | # evaluation
370 | if ((epoch + 1) % self.args.eval_every == 0) or ((epoch + 1) == self.args.epochs):
371 | self._eval_epoch(loaders['eval'], epoch)
372 | self._save_model(epoch)
373 |
374 | # best score
375 | logging.info('Best PSNR Score: {:.2f}\n'.format(max(self.losses['psnr'])))
376 |
377 | def train(self):
378 | # get loader
379 | loaders = get_loaders(self.args)
380 |
381 | # run training
382 | self._train(loaders)
383 |
384 | # close tensorboard
385 | if self.args.use_tb:
386 | self.tb.close()
387 |
388 | def eval(self):
389 | # get loader
390 | loaders = get_loaders(self.args)
391 |
392 | # evaluation
393 | logging.info('\nEvaluating...')
394 | self._eval_epoch(loaders['eval'], 0)
395 |
396 | # close tensorboard
397 | if self.args.use_tb:
398 | self.tb.close()
399 |
--------------------------------------------------------------------------------
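Note: a minimal, self-contained sketch of the interpolation-and-penalty pattern that `_gradient_penalty` in `trainer.py` implements; the toy critic and tensor sizes below are illustrative assumptions, not part of the repository.

```python
import torch
import torch.nn as nn
from torch.autograd import grad as torch_grad

critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))  # toy critic (assumption)
real, fake = torch.rand(4, 3, 8, 8), torch.rand(4, 3, 8, 8)

# sample a random point on the segment between each real/fake pair
alpha = torch.rand(4, 1, 1, 1).expand_as(real)
interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

# differentiate the critic score w.r.t. the interpolated input
scores = critic(interp)
grads = torch_grad(outputs=scores, inputs=interp,
                   grad_outputs=torch.ones_like(scores),
                   create_graph=True, retain_graph=True)[0]

# penalize deviations of the per-example gradient norm from 1
penalty = ((grads.view(4, -1).norm(p=2, dim=1) - 1) ** 2).mean()
print(penalty.item())
```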
/denoising/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/denoising/utils/__init__.py
--------------------------------------------------------------------------------
/denoising/utils/misc.py:
--------------------------------------------------------------------------------
1 | import logging.config
2 | import os
3 | import matplotlib.pyplot as plt
4 | import torch
5 | import math
6 | import numpy as np
7 | from PIL import Image
8 | from torchvision.utils import make_grid
9 | from os import path, makedirs
10 |
11 | def mkdir(save_path):
12 | if not path.exists(save_path):
13 | makedirs(save_path)
14 |
15 | def make_image_grid(x, nrow, padding=0, pad_value=0):
16 | x = x.clone().cpu().data
17 | grid = make_grid(x, nrow=nrow, padding=padding, normalize=True, scale_each=False, pad_value=pad_value)
18 | return grid
19 |
20 | def tensor_to_image(x):
21 | ndarr = x.squeeze(dim=0).mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).squeeze().to('cpu', torch.uint8).numpy()
22 | image = Image.fromarray(ndarr)
23 | return image
24 |
25 | def plot_image_grid(x, nrow, padding=0):
26 | grid = make_image_grid(x=x, nrow=nrow, padding=padding).permute(1, 2, 0).numpy()
27 | plt.imshow(grid)
28 | plt.show()
29 |
30 | def save_image(x, path, size=None):
31 | image = tensor_to_image(x)
32 | if size:
33 | image = image.resize((size, size), Image.NEAREST)
34 | image.save(path)
35 |
36 | def save_image_grid(x, path, nrow=8, size=None):
37 | grid = make_image_grid(x, nrow)
38 | save_image(grid, path, size=size)
39 |
40 | def setup_logging(log_file='log.txt', resume=False, dummy=False):
41 | if dummy:
42 | logging.getLogger('dummy')
43 | else:
44 | if os.path.isfile(log_file) and resume:
45 | file_mode = 'a'
46 | else:
47 | file_mode = 'w'
48 |
49 | root_logger = logging.getLogger()
50 | if root_logger.handlers:
51 | root_logger.removeHandler(root_logger.handlers[0])
52 | logging.basicConfig(level=logging.INFO,
53 | format="%(asctime)s - %(levelname)s - %(message)s",
54 | datefmt="%Y-%m-%d %H:%M:%S",
55 | filename=log_file,
56 | filemode=file_mode)
57 | console = logging.StreamHandler()
58 | console.setLevel(logging.INFO)
59 | formatter = logging.Formatter('%(message)s')
60 | console.setFormatter(formatter)
61 | logging.getLogger('').addHandler(console)
62 |
63 | def average(lst):
64 | return sum(lst) / len(lst)
65 |
66 | def rgb2yc(img):
67 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
68 | rlt = rlt.round()
69 | return rlt
70 |
71 | def compute_psnr(x, y):
72 | x = np.array(tensor_to_image(x)).astype(np.float64)
73 | y = np.array(tensor_to_image(y)).astype(np.float64)
74 |
75 |     mse = np.mean((x - y) ** 2)
76 |     # guard against identical images, where mse == 0 would divide by zero
77 |     if mse == 0:
78 |         return float('inf')
79 |     return 20 * math.log10(255.0 / math.sqrt(mse))
80 | 
81 | if __name__ == "__main__":
82 |     print('None')
83 | 
84 | 
--------------------------------------------------------------------------------
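Note: the PSNR computed by `compute_psnr` above reduces to `20 * log10(255 / sqrt(mse))` on uint8 images; a quick numeric check (the arrays are made-up placeholders, not repository data):

```python
import math
import numpy as np

x = np.zeros((4, 4), dtype=np.float64)
y = np.full((4, 4), 10.0)          # every pixel off by 10 gray levels

mse = np.mean((x - y) ** 2)        # 100.0
print(round(20 * math.log10(255.0 / math.sqrt(mse)), 2))  # 28.13 dB
```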
/denoising/utils/optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | def get_exp_scheduler_with_warmup(optimizer, rampup_steps=5, sustain_steps=5):
5 | def lr_lambda(step):
6 | if step < rampup_steps:
7 | return min(1., 1.8 ** ((step - rampup_steps)))
8 | elif step < rampup_steps + sustain_steps:
9 | return 1.
10 | else:
11 | return max(0.1, 0.85 ** (step - rampup_steps - sustain_steps))
12 |
13 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
14 |
15 | def get_cosine_scheduler_with_warmup(optimizer, rampup_steps=3, sustain_steps=2, period=5):
16 | def lr_lambda(step):
17 | if step < rampup_steps:
18 | return min(1., 1.5 ** ((step - rampup_steps)))
19 | elif step < rampup_steps + sustain_steps:
20 | return 1.
21 | else:
22 |         return max(0.1, (1 + math.cos(math.pi * (step - rampup_steps - sustain_steps) / period)) / 2)
23 |
24 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
25 |
26 | if __name__ == "__main__":
27 | import matplotlib.pyplot as plt
28 |
29 | model = torch.nn.Linear(2, 1)
30 | optimizer = torch.optim.SGD(model.parameters(), lr=2.5e-4)
31 | lr_scheduler = get_exp_scheduler_with_warmup(optimizer)
32 |
33 | lrs = []
34 | for i in range(25):
35 | lr_scheduler.step()
36 | lrs.append(optimizer.param_groups[0]["lr"])
37 |
38 | plt.plot(lrs)
39 | plt.show()
40 | print(lrs)
41 |
--------------------------------------------------------------------------------
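Note: an illustrative trace of the multiplier produced by `get_exp_scheduler_with_warmup` with its defaults (rampup_steps=5, sustain_steps=5), recomputed here outside the scheduler:

```python
multipliers = []
for step in range(15):
    if step < 5:
        m = min(1., 1.8 ** (step - 5))     # warmup: 1.8**-5 up to 1.8**-1
    elif step < 10:
        m = 1.                             # sustain at the base lr
    else:
        m = max(0.1, 0.85 ** (step - 10))  # exponential decay, floored at 0.1
    multipliers.append(round(m, 4))
print(multipliers)  # begins near 0.0529, holds at 1.0, then decays by 0.85 per step
```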
/denoising/utils/recorderx.py:
--------------------------------------------------------------------------------
1 | import utils.misc as misc
2 | from tensorboardX import SummaryWriter
3 |
4 | class RecoderX():
5 | def __init__(self, log_dir):
6 | self.log_dir = log_dir
7 | self.writer = SummaryWriter(logdir=log_dir)
8 | self.log = ''
9 |
10 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
11 | self.writer.add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step, walltime=walltime)
12 |
13 | def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
14 | self.writer.add_scalars(main_tag=main_tag, tag_scalar_dict=tag_scalar_dict, global_step=global_step, walltime=walltime)
15 |
16 | def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
17 | self.writer.add_image(tag=tag, img_tensor=img_tensor, global_step=global_step, walltime=walltime, dataformats=dataformats)
18 |
19 |     def add_image_grid(self, tag, img_tensor, nrow=8, global_step=None, padding=0, pad_value=0, walltime=None, dataformats='CHW'):
20 | grid = misc.make_image_grid(img_tensor, nrow, padding=padding, pad_value=pad_value)
21 | self.writer.add_image(tag=tag, img_tensor=grid, global_step=global_step, walltime=walltime, dataformats=dataformats)
22 |
23 | def add_graph(self, graph_profile, walltime=None):
24 | self.writer.add_graph(graph_profile, walltime=walltime)
25 |
26 | def add_histogram(self, tag, values, global_step=None):
27 | self.writer.add_histogram(tag, values, global_step)
28 |
29 | def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
30 | self.writer.add_figure(tag, figure, global_step=global_step, close=close, walltime=walltime)
31 |
32 | def export_json(self, out_file):
33 | self.writer.export_scalars_to_json(out_file)
34 |
35 | def close(self):
36 | self.writer.close()
37 |
38 | if __name__ == "__main__":
39 | print('None')
40 |
--------------------------------------------------------------------------------
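Note: a minimal usage sketch for `RecoderX`; the log directory is an assumed placeholder, and the script is expected to run from `denoising/` so that `utils.recorderx` resolves.

```python
from utils.recorderx import RecoderX

tb = RecoderX(log_dir='results/tb_demo')  # placeholder path
for step in range(10):
    tb.add_scalar('demo/loss', 1.0 / (step + 1), global_step=step)
tb.close()
```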
/figures/activations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/figures/activations.png
--------------------------------------------------------------------------------
/figures/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/figures/figure1.png
--------------------------------------------------------------------------------
/figures/tradeoff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/figures/tradeoff.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | scikit-image
3 | scikit-learn
4 | scipy
5 | numpy
6 | torch
7 | torchvision
8 | tensorboardx
9 |
--------------------------------------------------------------------------------
/super-resolution/data/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .datasets import DatasetSR
3 | from torch.utils.data import DataLoader
4 |
5 | def get_loaders(args):
6 | # datasets
7 | dataset_train = DatasetSR(root=os.path.join(args.root, 'train'), scale=args.scale, training=True, crop_size=args.crop_size)
8 | dataset_val = DatasetSR(root=os.path.join(args.root, 'val'), scale=args.scale, training=False, max_size=args.max_size)
9 |
10 | # loaders
11 | loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
12 | loader_eval = DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=1)
13 | loaders = {'train': loader_train, 'eval': loader_eval}
14 |
15 | return loaders
16 |
--------------------------------------------------------------------------------
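Note: a sketch of the `argparse.Namespace` fields `get_loaders` reads, assuming a DIV2K-style layout with `img` and `img_x{scale}` folders under `root/train` and `root/val` (the path and sizes are placeholders):

```python
from argparse import Namespace
from data import get_loaders

args = Namespace(root='datasets/div2k', scale=4, crop_size=40,
                 max_size=None, batch_size=16, num_workers=2)
loaders = get_loaders(args)
batch = next(iter(loaders['train']))
print(batch['input'].shape, batch['target'].shape)  # e.g. (16, 3, 40, 40) and (16, 3, 160, 160)
```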
/super-resolution/data/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import glob
3 | import os
4 | import random
5 | from torchvision import transforms
6 | from PIL import Image
7 |
8 | class DatasetSR(torch.utils.data.dataset.Dataset):
9 | def __init__(self, root='', scale=4, training=True, crop_size=60, max_size=None):
10 | self.root = root
11 | self.scale = scale if (scale % 1) else int(scale)
12 | self.training = training
13 | self.crop_size = crop_size
14 | self.max_size = max_size
15 |
16 | self._init()
17 |
18 | def _init(self):
19 | # data paths
20 | inputs = glob.glob(os.path.join(self.root, 'img_x{}'.format(self.scale), '*.*'))[:self.max_size]
21 | targets = [x.replace('img_x{}'.format(self.scale), 'img') for x in inputs]
22 | self.paths = {'input' : inputs, 'target' : targets}
23 |
24 | # transforms
25 | t_list = [transforms.ToTensor()]
26 | if self.training:
27 |             t_list.append(lambda x: ((255. * x) + torch.zeros_like(x).uniform_(0., 1.)) / 256.)
28 | self.image_transform = transforms.Compose(t_list)
29 |
30 | def _get_augment_params(self, size):
31 | random.seed(random.randint(0, 12345))
32 |
33 | # position
34 | w_size, h_size = size
35 | x = random.randint(0, max(0, w_size - self.crop_size))
36 | y = random.randint(0, max(0, h_size - self.crop_size))
37 |
38 | # flip
39 | flip = random.random() > 0.5
40 | return {'crop_pos': (x, y), 'flip': flip}
41 |
42 | def _augment(self, image, aug_params, scale=1):
43 | x, y = aug_params['crop_pos']
44 | image = image.crop((x * scale, y * scale, x * scale + self.crop_size * scale, y * scale + self.crop_size * scale))
45 | if aug_params['flip']:
46 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
47 | return image
48 |
49 | def __getitem__(self, index):
50 | # input image
51 | input = Image.open(self.paths['input'][index]).convert('RGB')
52 |
53 | # target image
54 | target = Image.open(self.paths['target'][index]).convert('RGB')
55 |
56 | if self.training:
57 | aug_params = self._get_augment_params(input.size)
58 | input = self._augment(input, aug_params)
59 | target = self._augment(target, aug_params, self.scale)
60 |
61 | input = self.image_transform(input)
62 | target = self.image_transform(target)
63 |
64 | return {'input': input, 'target': target, 'path': self.paths['target'][index]}
65 |
66 | def __len__(self):
67 | return len(self.paths['input'])
--------------------------------------------------------------------------------
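Note: an illustrative check of the training-time transform above, which dequantizes 8-bit levels: `ToTensor` yields `k / 255`, and the lambda maps it to `(k + u) / 256` with `u ~ U(0, 1)`, so each level occupies an equal-width bin of `[0, 1)`:

```python
import torch

x = torch.tensor([0., 128 / 255., 1.])   # what transforms.ToTensor() produces
u = torch.zeros_like(x).uniform_(0., 1.)
deq = ((255. * x) + u) / 256.
print(deq)  # lands in [0, 1/256), [128/256, 129/256) and [255/256, 1) respectively
```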
/super-resolution/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import logging
4 | import signal
5 | import sys
6 | import torch.backends.cudnn as cudnn
7 | from trainer import Trainer
8 | from datetime import datetime
9 | from os import path
10 | from utils import misc
11 | from random import randint
12 |
13 | # torch.autograd.set_detect_anomaly(True)
14 |
15 | def get_arguments():
16 | parser = argparse.ArgumentParser(description='super-resolution')
17 | parser.add_argument('--device', default='cuda', help='device assignment ("cpu" or "cuda")')
18 |     parser.add_argument('--device-ids', default=[0], type=int, nargs='+', help='device ids assignment (e.g. 0 1 2 3)')
19 |     parser.add_argument('--d-model', default='d_srgan', help='discriminator architecture (default: d_srgan)')
20 |     parser.add_argument('--g-model', default='g_srgan', help='generator architecture (default: g_srgan)')
21 | parser.add_argument('--model-config', default='', help='additional architecture configuration')
22 | parser.add_argument('--dis-to-load', default='', help='resume training from file (default: None)')
23 | parser.add_argument('--gen-to-load', default='', help='resume training from file (default: None)')
24 | parser.add_argument('--root', default='', help='root dataset folder')
25 | parser.add_argument('--scale', default=4, type=float, help='super-resolution scale (default: 4)')
26 | parser.add_argument('--crop-size', default=40, type=int, help='low resolution cropping size (default: 40)')
27 | parser.add_argument('--max-size', default=None, type=int, help='validation set max-size (default: None)')
28 | parser.add_argument('--num-workers', default=2, type=int, help='number of workers (default: 2)')
29 | parser.add_argument('--batch-size', default=16, type=int, help='batch-size (default: 16)')
30 | parser.add_argument('--epochs', default=1000, type=int, help='epochs (default: 1000)')
31 | parser.add_argument('--lr', default=2e-4, type=float, help='lr (default: 2e-4)')
32 |     parser.add_argument('--gen-betas', default=[0.9, 0.99], nargs=2, type=float, help='generator Adam betas (default: 0.9 0.99)')
33 |     parser.add_argument('--dis-betas', default=[0.9, 0.99], nargs=2, type=float, help='discriminator Adam betas (default: 0.9 0.99)')
34 | parser.add_argument('--num-critic', default=1, type=int, help='critic iterations (default: 1)')
35 | parser.add_argument('--wgan', default=False, action='store_true', help='critic wgan loss (default: false)')
36 | parser.add_argument('--relativistic', default=False, action='store_true', help='relativistic wgan loss (default: false)')
37 | parser.add_argument('--step-size', default=300, type=int, help='scheduler step size (default: 300)')
38 | parser.add_argument('--gamma', default=0.5, type=float, help='scheduler gamma (default: 0.5)')
39 | parser.add_argument('--penalty-weight', default=0, type=float, help='gradient penalty weight (default: 0)')
40 |     parser.add_argument('--range-weight', default=0, type=float, help='range-weight (default: 0)')
41 | parser.add_argument('--reconstruction-weight', default=1.0, type=float, help='reconstruction-weight (default: 1.0)')
42 | parser.add_argument('--perceptual-weight', default=0, type=float, help='perceptual-weight (default: 0)')
43 | parser.add_argument('--adversarial-weight', default=0.01, type=float, help='adversarial-weight (default: 0.01)')
44 | parser.add_argument('--style-weight', default=0, type=float, help='style-weight (default: 0)')
45 | parser.add_argument('--seed', default=-1, type=int, help='random seed (default: random)')
46 | parser.add_argument('--print-every', default=20, type=int, help='print-every (default: 20)')
47 | parser.add_argument('--eval-every', default=50, type=int, help='eval-every (default: 50)')
48 | parser.add_argument('--results-dir', metavar='RESULTS_DIR', default='./results', help='results dir')
49 | parser.add_argument('--save', metavar='SAVE', default='', help='saved folder')
50 | parser.add_argument('--evaluation', default=False, action='store_true', help='evaluate a model (default: false)')
51 | parser.add_argument('--use-tb', default=False, action='store_true', help='use tensorboardx (default: false)')
52 | args = parser.parse_args()
53 |
54 | time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
55 | if args.save == '':
56 | args.save = time_stamp
57 | args.save_path = path.join(args.results_dir, args.save)
58 | if args.seed == -1:
59 | args.seed = randint(0, 12345)
60 | return args
61 |
62 | def main():
63 | # arguments
64 | args = get_arguments()
65 |
66 | torch.manual_seed(args.seed)
67 |
68 | # cuda
69 | if 'cuda' in args.device and torch.cuda.is_available():
70 | torch.cuda.manual_seed_all(args.seed)
71 | torch.cuda.set_device(args.device_ids[0])
72 | cudnn.benchmark = True
73 | else:
74 | args.device_ids = None
75 |
76 | # set logs
77 | misc.mkdir(args.save_path)
78 | misc.mkdir(path.join(args.save_path, 'images'))
79 | misc.setup_logging(path.join(args.save_path, 'log.txt'))
80 |
81 | # print logs
82 | logging.info(args)
83 |
84 | # trainer
85 | trainer = Trainer(args)
86 |
87 | if args.evaluation:
88 | trainer.eval()
89 | else:
90 | trainer.train()
91 |
92 | if __name__ == '__main__':
93 | # enables a ctrl-c without triggering errors
94 | signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))
95 | main()
--------------------------------------------------------------------------------
/super-resolution/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .srgan import *
2 | from .xsrgan import *
3 | from .vanilla import *
4 | from .xvanilla import *
5 | from .xdense import *
--------------------------------------------------------------------------------
/super-resolution/models/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/super-resolution/models/modules/__init__.py
--------------------------------------------------------------------------------
/super-resolution/models/modules/activations.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class xUnit(nn.Module):
4 | def __init__(self, num_features=64, kernel_size=7, batch_norm=False):
5 | super(xUnit, self).__init__()
6 | # xUnit
7 | self.features = nn.Sequential(
8 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
9 | nn.ReLU(),
10 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
11 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
12 | nn.Sigmoid()
13 | )
14 |
15 | def forward(self, x):
16 | a = self.features(x)
17 | r = x * a
18 | return r
19 |
20 | class xUnitS(nn.Module):
21 | def __init__(self, num_features=64, kernel_size=7, batch_norm=False):
22 | super(xUnitS, self).__init__()
23 | # slim xUnit
24 | self.features = nn.Sequential(
25 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
26 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
27 | nn.Sigmoid()
28 | )
29 |
30 | def forward(self, x):
31 | a = self.features(x)
32 | r = x * a
33 | return r
34 |
35 | class xUnitD(nn.Module):
36 | def __init__(self, num_features=64, kernel_size=7, batch_norm=False):
37 | super(xUnitD, self).__init__()
38 | # dense xUnit
39 | self.features = nn.Sequential(
40 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=1, padding=0),
41 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
42 | nn.ReLU(),
43 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
44 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
45 | nn.Sigmoid()
46 | )
47 |
48 | def forward(self, x):
49 | a = self.features(x)
50 | r = x * a
51 | return r
52 |
53 | class Identity(nn.Module):
54 | def __init__(self,):
55 | super(Identity, self).__init__()
56 |
57 | def forward(self, x):
58 | return x
59 |
--------------------------------------------------------------------------------
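Note: a minimal forward-pass sketch for the `xUnit` gate defined above; the tensor sizes are arbitrary, and the import assumes the script runs from `super-resolution/`:

```python
import torch
from models.modules.activations import xUnit

act = xUnit(num_features=64, kernel_size=7)
x = torch.randn(1, 64, 32, 32)
y = act(x)                 # y = x * sigmoid(depthwise-conv features), same shape as x
print(y.shape)             # torch.Size([1, 64, 32, 32])
```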
/super-resolution/models/modules/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .misc import shave_edge
5 | from ..vgg import MultiVGGFeaturesExtractor
6 |
7 | class RangeLoss(nn.Module):
8 | def __init__(self, min_value=0., max_value=1., invalidity_margins=None):
9 | super(RangeLoss, self).__init__()
10 | self.min_value = min_value
11 | self.max_value = max_value
12 | self.invalidity_margins = invalidity_margins
13 |
14 | def forward(self, inputs):
15 | if self.invalidity_margins:
16 | inputs = shave_edge(inputs, self.invalidity_margins, self.invalidity_margins)
17 | loss = (F.relu(self.min_value - inputs) + F.relu(inputs - self.max_value)).mean()
18 | return loss
19 |
20 | class PerceptualLoss(nn.Module):
21 | def __init__(self, features_to_compute, criterion=torch.nn.L1Loss(), shave_edge=None):
22 | super(PerceptualLoss, self).__init__()
23 | self.criterion = criterion
24 | self.features_extractor = MultiVGGFeaturesExtractor(target_features=features_to_compute, shave_edge=shave_edge).eval()
25 |
26 | def forward(self, inputs, targets):
27 | inputs_fea = self.features_extractor(inputs)
28 | with torch.no_grad():
29 | targets_fea = self.features_extractor(targets)
30 |
31 | loss = 0
32 | for key in inputs_fea.keys():
33 | loss += self.criterion(inputs_fea[key], targets_fea[key].detach())
34 |
35 | return loss
36 |
37 | class StyleLoss(nn.Module):
38 | def __init__(self, features_to_compute, criterion=torch.nn.L1Loss(), shave_edge=None):
39 | super(StyleLoss, self).__init__()
40 | self.criterion = criterion
41 | self.features_extractor = MultiVGGFeaturesExtractor(target_features=features_to_compute, use_input_norm=True, shave_edge=shave_edge).eval()
42 |
43 | def forward(self, inputs, targets):
44 | inputs_fea = self.features_extractor(inputs)
45 | with torch.no_grad():
46 | targets_fea = self.features_extractor(targets)
47 |
48 | loss = 0
49 | for key in inputs_fea.keys():
50 | inputs_gram = self._gram_matrix(inputs_fea[key])
51 | with torch.no_grad():
52 | targets_gram = self._gram_matrix(targets_fea[key]).detach()
53 |
54 | loss += self.criterion(inputs_gram, targets_gram)
55 |
56 | return loss
57 |
58 | def _gram_matrix(self, x):
59 | a, b, c, d = x.size()
60 | features = x.view(a, b, c * d)
61 | gram = features.bmm(features.transpose(1, 2))
62 | return gram.div(b * c * d)
--------------------------------------------------------------------------------
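Note: a shape check for `StyleLoss._gram_matrix` above; for features of shape `(batch, channels, h, w)` the Gram matrix is `(batch, channels, channels)`, normalized by `channels * h * w`:

```python
import torch

fea = torch.randn(2, 64, 16, 16)                    # placeholder feature map
flat = fea.view(2, 64, 16 * 16)
gram = flat.bmm(flat.transpose(1, 2)) / (64 * 16 * 16)
print(gram.shape)                                   # torch.Size([2, 64, 64])
```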
/super-resolution/models/modules/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class UpsampleX2(nn.Module):
6 | def __init__(self, in_channels, out_channels, kernel_size=3):
7 | super(UpsampleX2, self).__init__()
8 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=(out_channels * 4), kernel_size=kernel_size, padding=(kernel_size // 2))
9 | self.shuffler = nn.PixelShuffle(2)
10 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
11 |
12 | def forward(self, x):
13 | return self.lrelu(self.shuffler(self.conv(x)))
14 |
15 | def center_crop(x, height, width):
16 | crop_h = torch.FloatTensor([x.size()[2]]).sub(height).div(-2)
17 | crop_w = torch.FloatTensor([x.size()[3]]).sub(width).div(-2)
18 |
19 | return F.pad(x, [
20 | crop_w.ceil().int()[0], crop_w.floor().int()[0],
21 | crop_h.ceil().int()[0], crop_h.floor().int()[0],
22 | ])
23 |
24 | def shave_edge(x, shave_h, shave_w):
25 | return F.pad(x, [-shave_w, -shave_w, -shave_h, -shave_h])
26 |
27 | def shave_modulo(x, factor):
28 | shave_w = x.size(-1) % factor
29 | shave_h = x.size(-2) % factor
30 | return F.pad(x, [0, -shave_w, 0, -shave_h])
31 |
32 | if __name__ == "__main__":
33 | x = torch.randn(1, 2, 4, 6)
34 | y = shave_edge(x, 1, 2)
35 | print(x)
36 | print(y)
--------------------------------------------------------------------------------
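Note: `center_crop` above crops rather than pads by handing negative values to `F.pad`; a quick check with arbitrary sizes (the import assumes running from `super-resolution/`):

```python
import torch
from models.modules.misc import center_crop

x = torch.randn(1, 3, 10, 12)
y = center_crop(x, 8, 8)   # trims 1 row top/bottom and 2 columns left/right
print(y.shape)             # torch.Size([1, 3, 8, 8])
```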
/super-resolution/models/srgan.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from math import log
3 | from torch.nn import functional as F
4 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
5 | from .modules.misc import UpsampleX2, center_crop
6 |
7 | __all__ = ['g_srgan', 'd_srgan']
8 |
9 | def initialize_weights(net, scale=1.):
10 | if not isinstance(net, list):
11 | net = [net]
12 | for layer in net:
13 | for m in layer.modules():
14 | if isinstance(m, nn.Conv2d):
15 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
16 | m.weight.data *= scale # for residual block
17 | if m.bias is not None:
18 | m.bias.data.zero_()
19 | elif isinstance(m, nn.Linear):
20 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
21 | m.weight.data *= scale
22 | if m.bias is not None:
23 | m.bias.data.zero_()
24 | elif isinstance(m, nn.BatchNorm2d):
25 | nn.init.constant_(m.weight, 1)
26 | nn.init.constant_(m.bias.data, 0.0)
27 |
28 | class GenBlock(nn.Module):
29 | def __init__(self, num_features=64, kernel_size=3, bias=False):
30 | super(GenBlock, self).__init__()
31 | self.relu = nn.ReLU(inplace=True)
32 | self.conv1 = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias)
33 | self.conv2 = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias)
34 |
35 | initialize_weights([self.conv1, self.conv2], 0.1)
36 |
37 | def forward(self, x):
38 | residual = x
39 | x = self.conv2(self.relu(self.conv1(x)))
40 | x = residual + x
41 | return x
42 |
43 | class DisBlock(nn.Module):
44 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False):
45 | super(DisBlock, self).__init__()
46 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
47 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias)
48 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias)
49 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True)
50 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True)
51 |
52 | initialize_weights([self.conv1, self.conv2], 0.1)
53 |
54 | if normalization:
55 | self.conv1 = SpectralNorm(self.conv1)
56 | self.conv2 = SpectralNorm(self.conv2)
57 |
58 | def forward(self, x):
59 | x = self.lrelu(self.bn1(self.conv1(x)))
60 | x = self.lrelu(self.bn2(self.conv2(x)))
61 | return x
62 |
63 | class Generator(nn.Module):
64 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
65 | super(Generator, self).__init__()
66 | self.scale = scale
67 | # image to features
68 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
69 | self.image_to_features = nn.Conv2d(in_channels=in_channels, out_channels=num_features, kernel_size=5, padding=2)
70 |
71 | # features
72 | blocks = []
73 | for _ in range(gen_blocks):
74 | blocks.append(GenBlock(num_features=num_features))
75 | self.features = nn.Sequential(*blocks)
76 |
77 | # upsampling
78 | blocks = []
79 | for _ in range(int(log(scale, 2))):
80 | block = UpsampleX2(in_channels=num_features, out_channels=num_features)
81 | blocks.append(block)
82 | self.usample = nn.Sequential(*blocks)
83 |
84 | # features to image
85 | self.hrconv = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=3, padding=1)
86 | self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2)
87 |
88 | # init weights
89 | initialize_weights([self.image_to_features, self.hrconv, self.features_to_image], 0.1)
90 |
91 | def forward(self, x):
92 | r = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
93 | x = self.lrelu(self.image_to_features(x))
94 | x = self.features(x)
95 | x = self.usample(x)
96 | x = self.features_to_image(self.lrelu(self.hrconv(x)))
97 | x = x + r
98 | return x
99 |
100 | class Discriminator(nn.Module):
101 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
102 | super(Discriminator, self).__init__()
103 | self.crop_size = 4 * pow(2, dis_blocks)
104 |
105 | # image to features
106 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False)
107 |
108 | # features
109 | blocks = []
110 | for i in range(0, dis_blocks - 1):
111 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False))
112 | self.features = nn.Sequential(*blocks)
113 |
114 | # classifier
115 | self.classifier = nn.Sequential(
116 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100),
117 | nn.LeakyReLU(negative_slope=0.1),
118 | nn.Linear(100, 1)
119 | )
120 |
121 | def forward(self, x):
122 | x = center_crop(x, self.crop_size, self.crop_size)
123 | x = self.image_to_features(x)
124 | x = self.features(x)
125 | x = x.flatten(start_dim=1)
126 | x = self.classifier(x)
127 | return x
128 |
129 | class SNDiscriminator(nn.Module):
130 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
131 | super(SNDiscriminator, self).__init__()
132 | self.crop_size = 4 * pow(2, dis_blocks)
133 |
134 | # image to features
135 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True)
136 |
137 | # features
138 | blocks = []
139 | for i in range(0, dis_blocks - 1):
140 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True))
141 | self.features = nn.Sequential(*blocks)
142 |
143 | # classifier
144 | self.classifier = nn.Sequential(
145 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)),
146 | nn.LeakyReLU(negative_slope=0.1),
147 | SpectralNorm(nn.Linear(100, 1))
148 | )
149 |
150 | def forward(self, x):
151 | x = center_crop(x, self.crop_size, self.crop_size)
152 | x = self.image_to_features(x)
153 | x = self.features(x)
154 | x = x.flatten(start_dim=1)
155 | x = self.classifier(x)
156 | return x
157 |
158 | def g_srgan(**config):
159 | config.setdefault('in_channels', 3)
160 | config.setdefault('num_features', 64)
161 | config.setdefault('gen_blocks', 16)
162 | config.setdefault('dis_blocks', 5)
163 | config.setdefault('scale', 4)
164 |
165 | _ = config.pop('spectral', False)
166 |
167 | return Generator(**config)
168 |
169 | def d_srgan(**config):
170 | config.setdefault('in_channels', 3)
171 | config.setdefault('num_features', 64)
172 | config.setdefault('gen_blocks', 16)
173 | config.setdefault('dis_blocks', 5)
174 | config.setdefault('scale', 4)
175 |
176 | sn = config.pop('spectral', False)
177 |
178 | if sn:
179 | return SNDiscriminator(**config)
180 | else:
181 | return Discriminator(**config)
182 |
--------------------------------------------------------------------------------
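Note: a smoke test for the SRGAN generator above; the input size is arbitrary, and with `scale=4` the two `UpsampleX2` stages plus the bilinear skip produce a 4x larger output:

```python
import torch
from models.srgan import g_srgan

g = g_srgan(scale=4)
lr = torch.rand(1, 3, 40, 40)   # a 40x40 low-resolution crop (placeholder)
with torch.no_grad():
    sr = g(lr)
print(sr.shape)                 # torch.Size([1, 3, 160, 160])
```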
/super-resolution/models/vanilla.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from math import log
5 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
6 | from .modules.misc import center_crop
7 |
8 | __all__ = ['g_vanilla', 'd_vanilla']
9 |
10 | def initialize_weights(net, scale=1.):
11 | if not isinstance(net, list):
12 | net = [net]
13 | for layer in net:
14 | for m in layer.modules():
15 | if isinstance(m, nn.Conv2d):
16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
17 | m.weight.data *= scale # for residual block
18 | if m.bias is not None:
19 | m.bias.data.zero_()
20 | elif isinstance(m, nn.Linear):
21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
22 | m.weight.data *= scale
23 | if m.bias is not None:
24 | m.bias.data.zero_()
25 | elif isinstance(m, nn.BatchNorm2d):
26 | nn.init.constant_(m.weight, 1)
27 | nn.init.constant_(m.bias.data, 0.0)
28 |
29 | class GenBlock(nn.Module):
30 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, bias=True):
31 | super(GenBlock, self).__init__()
32 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias)
33 | self.bn = nn.BatchNorm2d(num_features=out_channels)
34 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
35 |
36 |         initialize_weights([self.conv, self.bn], 0.02)
37 |
38 | def forward(self, x):
39 | x = self.lrelu(self.bn(self.conv(x)))
40 | return x
41 |
42 | class DisBlock(nn.Module):
43 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False):
44 | super(DisBlock, self).__init__()
45 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
46 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias)
47 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias)
48 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True)
49 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True)
50 |
51 | initialize_weights([self.conv1, self.conv2], 0.1)
52 |
53 | if normalization:
54 | self.conv1 = SpectralNorm(self.conv1)
55 | self.conv2 = SpectralNorm(self.conv2)
56 |
57 | def forward(self, x):
58 | x = self.lrelu(self.bn1(self.conv1(x)))
59 | x = self.lrelu(self.bn2(self.conv2(x)))
60 | return x
61 |
62 | class Generator(nn.Module):
63 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
64 | super(Generator, self).__init__()
65 | self.scale_factor = scale
66 |
67 | # image to features
68 | self.image_to_features = GenBlock(in_channels=in_channels, out_channels=num_features)
69 |
70 | # features
71 | blocks = []
72 | for _ in range(gen_blocks):
73 | blocks.append(GenBlock(in_channels=num_features, out_channels=num_features))
74 | self.features = nn.Sequential(*blocks)
75 |
76 | # features to image
77 | self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2)
78 | initialize_weights([self.features_to_image], 0.02)
79 |
80 | def forward(self, x):
81 | x = F.interpolate(x, size=(self.scale_factor * x.size(-2), self.scale_factor * x.size(-1)), mode='bicubic', align_corners=True)
82 | r = x
83 | x = self.image_to_features(x)
84 | x = self.features(x)
85 | x = self.features_to_image(x)
86 | x += r
87 | return x
88 |
89 | class Discriminator(nn.Module):
90 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
91 | super(Discriminator, self).__init__()
92 | self.crop_size = 4 * pow(2, dis_blocks)
93 |
94 | # image to features
95 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False)
96 |
97 | # features
98 | blocks = []
99 | for i in range(0, dis_blocks - 1):
100 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False))
101 | self.features = nn.Sequential(*blocks)
102 |
103 | # classifier
104 | self.classifier = nn.Sequential(
105 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100),
106 | nn.LeakyReLU(negative_slope=0.1),
107 | nn.Linear(100, 1)
108 | )
109 |
110 | def forward(self, x):
111 | x = center_crop(x, self.crop_size, self.crop_size)
112 | x = self.image_to_features(x)
113 | x = self.features(x)
114 | x = x.flatten(start_dim=1)
115 | x = self.classifier(x)
116 | return x
117 |
118 | class SNDiscriminator(nn.Module):
119 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
120 | super(SNDiscriminator, self).__init__()
121 | self.crop_size = 4 * pow(2, dis_blocks)
122 |
123 | # image to features
124 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True)
125 |
126 | # features
127 | blocks = []
128 | for i in range(0, dis_blocks - 1):
129 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True))
130 | self.features = nn.Sequential(*blocks)
131 |
132 | # classifier
133 | self.classifier = nn.Sequential(
134 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)),
135 | nn.LeakyReLU(negative_slope=0.1),
136 | SpectralNorm(nn.Linear(100, 1))
137 | )
138 |
139 | def forward(self, x):
140 | x = center_crop(x, self.crop_size, self.crop_size)
141 | x = self.image_to_features(x)
142 | x = self.features(x)
143 | x = x.flatten(start_dim=1)
144 | x = self.classifier(x)
145 | return x
146 |
147 | def g_vanilla(**config):
148 | config.setdefault('in_channels', 3)
149 | config.setdefault('num_features', 64)
150 | config.setdefault('gen_blocks', 8)
151 | config.setdefault('dis_blocks', 5)
152 | config.setdefault('scale', 4)
153 |
154 | return Generator(**config)
155 |
156 | def d_vanilla(**config):
157 | config.setdefault('in_channels', 3)
158 | config.setdefault('num_features', 64)
159 | config.setdefault('gen_blocks', 8)
160 | config.setdefault('dis_blocks', 5)
161 | config.setdefault('scale', 4)
162 |
163 | return Discriminator(**config)
--------------------------------------------------------------------------------
/super-resolution/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn as nn
4 | from models.modules.misc import shave_edge
5 | from collections import OrderedDict
6 |
7 | names = {'vgg19': ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
8 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
9 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
10 | 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
11 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2',
12 | 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
13 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
14 | 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'],
15 |
16 | 'vgg19_bn': ['conv1_1', 'bn1_1', 'relu1_1', 'conv1_2', 'bn1_2', 'relu1_2', 'pool1',
17 | 'conv2_1', 'bn2_1', 'relu2_1', 'conv2_2', 'bn2_2', 'relu2_2', 'pool2',
18 | 'conv3_1', 'bn3_1', 'relu3_1', 'conv3_2', 'bn3_2', 'relu3_2',
19 | 'conv3_3', 'bn3_3', 'relu3_3', 'conv3_4', 'bn3_4', 'relu3_4', 'pool3',
20 | 'conv4_1', 'bn4_1', 'relu4_1', 'conv4_2', 'bn4_2', 'relu4_2',
21 | 'conv4_3', 'bn4_3', 'relu4_3', 'conv4_4', 'bn4_4', 'relu4_4', 'pool4',
22 | 'conv5_1', 'bn5_1', 'relu5_1', 'conv5_2', 'bn5_2', 'relu5_2',
23 | 'conv5_3', 'bn5_3', 'relu5_3', 'conv5_4', 'bn5_4', 'relu5_4', 'pool5']
24 | }
25 |
26 |
27 | class VGGFeaturesExtractor(nn.Module):
28 | def __init__(self, feature_layer='conv5_4', use_bn=False, use_input_norm=True, requires_grad=False):
29 | super(VGGFeaturesExtractor, self).__init__()
30 | self.use_input_norm = use_input_norm
31 |
32 | if use_bn:
33 | model = torchvision.models.vgg19_bn(pretrained=True)
34 | else:
35 | model = torchvision.models.vgg19(pretrained=True)
36 |
37 | if self.use_input_norm:
38 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
39 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
40 | self.register_buffer('mean', mean)
41 | self.register_buffer('std', std)
42 |
43 | layer_index = names['vgg19_bn'].index(feature_layer) if use_bn else names['vgg19'].index(feature_layer)
44 | self.features = nn.Sequential(*list(model.features.children())[:(layer_index + 1)])
45 |
46 | if not requires_grad:
47 | for k, v in self.features.named_parameters():
48 | v.requires_grad = False
49 | self.features.eval()
50 |
51 | def forward(self, x):
52 | # Assume input range is [0, 1]
53 | if self.use_input_norm:
54 | x = (x - self.mean) / self.std
55 | output = self.features(x)
56 | return output
57 |
58 | class MultiVGGFeaturesExtractor(nn.Module):
59 | def __init__(self, target_features=('relu1_1', 'relu2_1', 'relu3_1'), use_bn=False, use_input_norm=True, requires_grad=False, shave_edge=None):
60 | super(MultiVGGFeaturesExtractor, self).__init__()
61 | self.use_input_norm = use_input_norm
62 | self.target_features = target_features
63 | self.shave_edge = shave_edge
64 |
65 | if use_bn:
66 | model = torchvision.models.vgg19_bn(pretrained=True)
67 | names_key = 'vgg19_bn'
68 | else:
69 | model = torchvision.models.vgg19(pretrained=True)
70 | names_key = 'vgg19'
71 |
72 | if self.use_input_norm:
73 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
74 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
75 | self.register_buffer('mean', mean)
76 | self.register_buffer('std', std)
77 |
78 | self.target_indexes = [names[names_key].index(k) for k in self.target_features]
79 | self.features = nn.Sequential(*list(model.features.children())[:(max(self.target_indexes) + 1)])
80 |
81 | if not requires_grad:
82 | for k, v in self.features.named_parameters():
83 | v.requires_grad = False
84 | self.features.eval()
85 |
86 | def forward(self, x):
87 | if self.shave_edge:
88 | x = shave_edge(x, self.shave_edge, self.shave_edge)
89 |
90 | # assume input range is [0, 1]
91 | if self.use_input_norm:
92 | x = (x - self.mean) / self.std
93 |
94 | output = OrderedDict()
95 | for key, layer in self.features._modules.items():
96 | x = layer(x)
97 | if int(key) in self.target_indexes:
98 | output.update({self.target_features[self.target_indexes.index(int(key))]: x})
99 | return output
--------------------------------------------------------------------------------
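Note: a minimal usage sketch for `MultiVGGFeaturesExtractor` above; it downloads the pretrained VGG19 weights on first use, and the input here is a random placeholder in `[0, 1]`:

```python
import torch
from models.vgg import MultiVGGFeaturesExtractor

extractor = MultiVGGFeaturesExtractor(target_features=('relu1_1', 'relu2_1')).eval()
x = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    feas = extractor(x)
for name, fea in feas.items():
    print(name, tuple(fea.shape))  # relu1_1 (1, 64, 64, 64), relu2_1 (1, 128, 32, 32)
```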
/super-resolution/models/xdense.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from math import log
4 | from torch.nn import functional as F
5 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
6 | from .modules.activations import xUnitD
7 | from .modules.misc import UpsampleX2, center_crop
8 |
9 | __all__ = ['g_xdense', 'd_xdense']
10 |
11 | def initialize_weights(net, scale=1.):
12 | if not isinstance(net, list):
13 | net = [net]
14 | for layer in net:
15 | for m in layer.modules():
16 | if isinstance(m, nn.Conv2d):
17 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
18 | m.weight.data *= scale # for residual block
19 | if m.bias is not None:
20 | m.bias.data.zero_()
21 | elif isinstance(m, nn.Linear):
22 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
23 | m.weight.data *= scale
24 | if m.bias is not None:
25 | m.bias.data.zero_()
26 | elif isinstance(m, nn.BatchNorm2d):
27 | nn.init.constant_(m.weight, 1)
28 | nn.init.constant_(m.bias.data, 0.0)
29 |
30 | class xModule(nn.Module):
31 | def __init__(self, in_channels, out_channels):
32 | super(xModule, self).__init__()
33 | # features
34 | self.features = nn.Sequential(
35 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
36 | xUnitD(num_features=out_channels, batch_norm=False),
37 | )
38 |
39 | def forward(self, x):
40 | x = self.features(x)
41 | return x
42 |
43 | class xDenseLayer(nn.Module):
44 | def __init__(self, in_channels, growth_rate, bn_size):
45 | super(xDenseLayer, self).__init__()
46 | # features
47 | self.features = nn.Sequential(
48 | nn.PReLU(),
49 | nn.Conv2d(in_channels=in_channels,out_channels=bn_size * growth_rate, kernel_size=1, stride=1, bias=False),
50 | nn.PReLU(),
51 | xModule(in_channels=bn_size * growth_rate, out_channels=growth_rate)
52 | )
53 |
54 | def forward(self, x):
55 | f = self.features(x)
56 | return torch.cat([x, f], dim=1)
57 |
58 | class xDenseBlock(nn.Module):
59 | def __init__(self, in_channels, num_layers, growth_rate, bn_size):
60 | super(xDenseBlock, self).__init__()
61 | # features
62 | blocks = []
63 | for i in range(num_layers):
64 | blocks.append(xDenseLayer(in_channels=in_channels + growth_rate*i, growth_rate=growth_rate, bn_size=bn_size))
65 | self.features = nn.Sequential(*blocks)
66 |
67 | def forward(self, x):
68 | x = self.features(x)
69 | return x
70 |
71 | class Transition(nn.Module):
72 | def __init__(self, in_channels, out_channels):
73 | super(Transition, self).__init__()
74 | # features
75 | self.features = nn.Sequential(
76 | nn.PReLU(),
77 | nn.Conv2d(in_channels=in_channels,out_channels=out_channels, kernel_size=1, stride=1, bias=False),
78 | )
79 |
80 | def forward(self, x):
81 | x = self.features(x)
82 | return x
83 |
84 | class DisBlock(nn.Module):
85 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False):
86 | super(DisBlock, self).__init__()
87 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
88 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias)
89 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias)
90 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True)
91 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True)
92 |
93 | initialize_weights([self.conv1, self.conv2], 0.1)
94 |
95 | if normalization:
96 | self.conv1 = SpectralNorm(self.conv1)
97 | self.conv2 = SpectralNorm(self.conv2)
98 |
99 | def forward(self, x):
100 | x = self.lrelu(self.bn1(self.conv1(x)))
101 | x = self.lrelu(self.bn2(self.conv2(x)))
102 | return x
103 |
104 | class Generator(nn.Module):
105 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size, scale):
106 | super(Generator, self).__init__()
107 | self.scale = scale
108 | # image to features
109 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
110 | self.image_to_features = xModule(in_channels=in_channels, out_channels=num_features)
111 |
112 | # features
113 | blocks = []
114 | self.num_features = num_features
115 | for i, num_layers in enumerate(gen_blocks):
116 | blocks.append(xDenseBlock(in_channels=self.num_features, num_layers=num_layers, growth_rate=growth_rate, bn_size=bn_size))
117 | self.num_features += num_layers * growth_rate
118 |
119 | if i != len(gen_blocks) - 1:
120 | blocks.append(Transition(in_channels=self.num_features, out_channels=self.num_features // 2))
121 | self.num_features = self.num_features // 2
122 |
123 | self.features = nn.Sequential(*blocks)
124 |
125 | # upsampling
126 | blocks = []
127 | for _ in range(int(log(scale, 2))):
128 | block = UpsampleX2(in_channels=self.num_features, out_channels=self.num_features)
129 | blocks.append(block)
130 | self.usample = nn.Sequential(*blocks)
131 |
132 | # features to image
133 | self.hrconv = nn.Conv2d(in_channels=self.num_features, out_channels=self.num_features, kernel_size=3, padding=1)
134 | self.features_to_image = nn.Conv2d(in_channels=self.num_features, out_channels=in_channels, kernel_size=5, padding=2)
135 |
136 | # init weights
137 | initialize_weights([self.image_to_features, self.hrconv, self.features_to_image], 0.1)
138 |
139 | def forward(self, x):
140 | r = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
141 | x = self.image_to_features(x)
142 | x = self.features(x)
143 | x = self.usample(x)
144 | x = self.features_to_image(self.lrelu(self.hrconv(x)))
145 | x = x + r
146 | return x
147 |
148 | class Discriminator(nn.Module):
149 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size, scale):
150 | super(Discriminator, self).__init__()
151 | self.crop_size = 4 * pow(2, dis_blocks)
152 |
153 | # image to features
154 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False)
155 |
156 | # features
157 | blocks = []
158 | for i in range(0, dis_blocks - 1):
159 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False))
160 | self.features = nn.Sequential(*blocks)
161 |
162 | # classifier
163 | self.classifier = nn.Sequential(
164 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100),
165 | nn.LeakyReLU(negative_slope=0.1),
166 | nn.Linear(100, 1)
167 | )
168 |
169 | def forward(self, x):
170 | x = center_crop(x, self.crop_size, self.crop_size)
171 | x = self.image_to_features(x)
172 | x = self.features(x)
173 | x = x.flatten(start_dim=1)
174 | x = self.classifier(x)
175 | return x
176 |
177 | class SNDiscriminator(nn.Module):
178 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size, scale):
179 | super(SNDiscriminator, self).__init__()
180 | self.crop_size = 4 * pow(2, dis_blocks)
181 |
182 | # image to features
183 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True)
184 |
185 | # features
186 | blocks = []
187 | for i in range(0, dis_blocks - 1):
188 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True))
189 | self.features = nn.Sequential(*blocks)
190 |
191 | # classifier
192 | self.classifier = nn.Sequential(
193 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)),
194 | nn.LeakyReLU(negative_slope=0.1),
195 | SpectralNorm(nn.Linear(100, 1))
196 | )
197 |
198 | def forward(self, x):
199 | x = center_crop(x, self.crop_size, self.crop_size)
200 | x = self.image_to_features(x)
201 | x = self.features(x)
202 | x = x.flatten(start_dim=1)
203 | x = self.classifier(x)
204 | return x
205 |
206 | def g_xdense(**config):
207 | config.setdefault('in_channels', 3)
208 | config.setdefault('num_features', 64)
209 | config.setdefault('gen_blocks', [4, 6, 8])
210 | config.setdefault('dis_blocks', 5)
211 | config.setdefault('growth_rate', 8)
212 | config.setdefault('bn_size', 2)
213 | config.setdefault('scale', 4)
214 |
215 | _ = config.pop('spectral', False)
216 |
217 | return Generator(**config)
218 |
219 | def d_xdense(**config):
220 | config.setdefault('in_channels', 3)
221 | config.setdefault('num_features', 64)
222 | config.setdefault('gen_blocks', [4, 6, 8])
223 | config.setdefault('dis_blocks', 5)
224 | config.setdefault('growth_rate', 8)
225 | config.setdefault('bn_size', 2)
226 | config.setdefault('scale', 4)
227 |
228 | sn = config.pop('spectral', False)
229 |
230 | if sn:
231 | return SNDiscriminator(**config)
232 | else:
233 | return Discriminator(**config)
234 |
--------------------------------------------------------------------------------
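Note: an illustrative trace of the channel bookkeeping in the xdense `Generator` above for the default config (`num_features=64`, `gen_blocks=[4, 6, 8]`, `growth_rate=8`); each dense block adds `num_layers * growth_rate` channels and each `Transition` halves them:

```python
num_features = 64
for i, num_layers in enumerate([4, 6, 8]):
    num_features += num_layers * 8     # xDenseBlock growth
    if i != 2:                         # no Transition after the last block
        num_features //= 2
    print(i, num_features)             # 0: 48, 1: 48, 2: 112
```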
/super-resolution/models/xsrgan.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from math import log
3 | from torch.nn import functional as F
4 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
5 | from .modules.activations import xUnitS
6 | from .modules.misc import UpsampleX2, center_crop
7 |
8 | __all__ = ['g_xsrgan', 'd_xsrgan', 'd_xsrgan_ad']
9 |
10 | def initialize_weights(net, scale=1.):
11 | if not isinstance(net, list):
12 | net = [net]
13 | for layer in net:
14 | for m in layer.modules():
15 | if isinstance(m, nn.Conv2d):
16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
17 | m.weight.data *= scale # for residual block
18 | if m.bias is not None:
19 | m.bias.data.zero_()
20 | elif isinstance(m, nn.Linear):
21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
22 | m.weight.data *= scale
23 | if m.bias is not None:
24 | m.bias.data.zero_()
25 | elif isinstance(m, nn.BatchNorm2d):
26 | nn.init.constant_(m.weight, 1)
27 | nn.init.constant_(m.bias.data, 0.0)
28 |
29 | class GenBlock(nn.Module):
30 | def __init__(self, num_features=64, kernel_size=3, bias=False):
31 | super(GenBlock, self).__init__()
32 | self.xunit = xUnitS(num_features=num_features)
33 | self.conv1 = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias)
34 | self.conv2 = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias)
35 |
36 | initialize_weights([self.conv1, self.conv2], 0.1)
37 |
38 | def forward(self, x):
39 | residual = x
40 | x = self.conv2(self.xunit(self.conv1(x)))
41 | x = residual + x
42 | return x
43 |
44 | class DisBlock(nn.Module):
45 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False):
46 | super(DisBlock, self).__init__()
47 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
48 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias)
49 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias)
50 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True)
51 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True)
52 |
53 | initialize_weights([self.conv1, self.conv2], 0.1)
54 |
55 | if normalization:
56 | self.conv1 = SpectralNorm(self.conv1)
57 | self.conv2 = SpectralNorm(self.conv2)
58 |
59 | def forward(self, x):
60 | x = self.lrelu(self.bn1(self.conv1(x)))
61 | x = self.lrelu(self.bn2(self.conv2(x)))
62 | return x
63 |
64 | class Generator(nn.Module):
65 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
66 | super(Generator, self).__init__()
67 | self.scale = scale
68 | # image to features
69 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
70 | self.image_to_features = nn.Conv2d(in_channels=in_channels, out_channels=num_features, kernel_size=5, padding=2)
71 |
72 | # features
73 | blocks = []
74 | for _ in range(gen_blocks):
75 | blocks.append(GenBlock(num_features=num_features))
76 | self.features = nn.Sequential(*blocks)
77 |
78 | # upsampling
79 | blocks = []
80 | for _ in range(int(log(scale, 2))):
81 | block = UpsampleX2(in_channels=num_features, out_channels=num_features)
82 | blocks.append(block)
83 | self.usample = nn.Sequential(*blocks)
84 |
85 | # features to image
86 | self.hrconv = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=3, padding=1)
87 | self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2)
88 |
89 | # init weights
90 | initialize_weights([self.image_to_features, self.hrconv, self.features_to_image], 0.1)
91 |
92 | def forward(self, x):
93 | r = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
94 | x = self.lrelu(self.image_to_features(x))
95 | x = self.features(x)
96 | x = self.usample(x)
97 | x = self.features_to_image(self.lrelu(self.hrconv(x)))
98 | x = x + r
99 | return x
100 |
101 | class Discriminator(nn.Module):
102 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
103 | super(Discriminator, self).__init__()
104 | self.crop_size = 4 * pow(2, dis_blocks)
105 |
106 | # image to features
107 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False)
108 |
109 | # features
110 | blocks = []
111 | for i in range(0, dis_blocks - 1):
112 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False))
113 | self.features = nn.Sequential(*blocks)
114 |
115 | # classifier
116 | self.classifier = nn.Sequential(
117 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100),
118 | nn.LeakyReLU(negative_slope=0.1),
119 | nn.Linear(100, 1)
120 | )
121 |
122 | def forward(self, x):
123 | x = center_crop(x, self.crop_size, self.crop_size)
124 | x = self.image_to_features(x)
125 | x = self.features(x)
126 | x = x.flatten(start_dim=1)
127 | x = self.classifier(x)
128 | return x
129 |
130 | class DiscriminatorAdaptive(nn.Module):
131 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
132 | super(DiscriminatorAdaptive, self).__init__()
133 | self.crop_size = 4 * pow(2, dis_blocks)
134 |
135 | # image to features
136 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False)
137 |
138 | # features
139 | blocks = []
140 | for i in range(0, dis_blocks - 1):
141 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False))
142 | self.features = nn.Sequential(*blocks)
143 |
144 | # pooling
145 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
146 |
147 | # classifier
148 | self.classifier = nn.Sequential(
149 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8), num_features),
150 | nn.LeakyReLU(negative_slope=0.1),
151 | nn.Linear(num_features, 1)
152 | )
153 |
154 | def forward(self, x):
155 | x = center_crop(x, self.crop_size, self.crop_size)
156 | x = self.image_to_features(x)
157 | x = self.features(x)
158 | x = self.avg_pool(x)
159 | x = x.flatten(start_dim=1)
160 | x = self.classifier(x)
161 | return x
162 |
163 | class SNDiscriminator(nn.Module):
164 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
165 | super(SNDiscriminator, self).__init__()
166 | self.crop_size = 4 * pow(2, dis_blocks)
167 |
168 | # image to features
169 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True)
170 |
171 | # features
172 | blocks = []
173 | for i in range(0, dis_blocks - 1):
174 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True))
175 | self.features = nn.Sequential(*blocks)
176 |
177 |         # pooling (defined here but unused: forward flattens the 4x4 map directly)
178 |         self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
179 |
180 | # classifier
181 | self.classifier = nn.Sequential(
182 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)),
183 | nn.LeakyReLU(negative_slope=0.1),
184 | SpectralNorm(nn.Linear(100, 1))
185 | )
186 |
187 | def forward(self, x):
188 | x = center_crop(x, self.crop_size, self.crop_size)
189 | x = self.image_to_features(x)
190 | x = self.features(x)
191 | x = x.flatten(start_dim=1)
192 | x = self.classifier(x)
193 | return x
194 |
195 | class SNDiscriminatorAdaptive(nn.Module):
196 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
197 | super(SNDiscriminatorAdaptive, self).__init__()
198 | self.crop_size = 4 * pow(2, dis_blocks)
199 |
200 | # image to features
201 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True)
202 |
203 | # features
204 | blocks = []
205 | for i in range(0, dis_blocks - 1):
206 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True))
207 | self.features = nn.Sequential(*blocks)
208 |
209 | # classifier
210 | self.classifier = nn.Sequential(
211 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8), num_features)),
212 | nn.LeakyReLU(negative_slope=0.1),
213 | SpectralNorm(nn.Linear(num_features, 1))
214 | )
215 |
216 | def forward(self, x):
217 | x = center_crop(x, self.crop_size, self.crop_size)
218 | x = self.image_to_features(x)
219 | x = self.features(x)
220 |         x = F.adaptive_avg_pool2d(x, output_size=1)  # fix: __init__ never defines self.avg_pool
221 | x = x.flatten(start_dim=1)
222 | x = self.classifier(x)
223 | return x
224 |
225 | def g_xsrgan(**config):
226 | config.setdefault('in_channels', 3)
227 | config.setdefault('num_features', 64)
228 | config.setdefault('gen_blocks', 10)
229 | config.setdefault('dis_blocks', 5)
230 | config.setdefault('scale', 4)
231 |
232 | _ = config.pop('spectral', False)
233 |
234 | return Generator(**config)
235 |
236 | def d_xsrgan(**config):
237 | config.setdefault('in_channels', 3)
238 | config.setdefault('num_features', 64)
239 | config.setdefault('gen_blocks', 10)
240 | config.setdefault('dis_blocks', 5)
241 | config.setdefault('scale', 4)
242 |
243 | sn = config.pop('spectral', False)
244 |
245 | if sn:
246 | return SNDiscriminator(**config)
247 | else:
248 | return Discriminator(**config)
249 |
250 | def d_xsrgan_ad(**config):
251 | config.setdefault('in_channels', 3)
252 | config.setdefault('num_features', 64)
253 | config.setdefault('gen_blocks', 10)
254 | config.setdefault('dis_blocks', 5)
255 | config.setdefault('scale', 4)
256 |
257 | sn = config.pop('spectral', False)
258 |
259 | if sn:
260 | return SNDiscriminatorAdaptive(**config)
261 | else:
262 | return DiscriminatorAdaptive(**config)
263 |
--------------------------------------------------------------------------------
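The `4 * 4` factor in the first `nn.Linear` of the discriminators above follows from the crop arithmetic: the input is center-cropped to `4 * 2**dis_blocks` pixels and then passes through `dis_blocks` stride-2 `DisBlock`s (one in `image_to_features` plus `dis_blocks - 1` in `features`), halving the spatial size each time down to 4x4. A quick shape check, sketched with average pooling standing in for the stride-2 convolutions (assumes only PyTorch):

```
import torch
import torch.nn.functional as F

dis_blocks = 5
crop = 4 * 2 ** dis_blocks              # 128, the center_crop size
x = torch.randn(1, 64, crop, crop)
for _ in range(dis_blocks):             # each DisBlock halves H and W once
    x = F.avg_pool2d(x, 2)              # stand-in for the stride-2 conv
print(x.shape)                          # torch.Size([1, 64, 4, 4]) -> C * 4 * 4 features
```
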
/super-resolution/models/xvanilla.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from math import log
5 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
6 | from .modules.activations import xUnitS
7 | from .modules.misc import center_crop
8 |
9 | __all__ = ['g_xvanilla', 'd_xvanilla']
10 |
11 | def initialize_weights(net, scale=1.):
12 | if not isinstance(net, list):
13 | net = [net]
14 | for layer in net:
15 | for m in layer.modules():
16 | if isinstance(m, nn.Conv2d):
17 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
18 | m.weight.data *= scale # for residual block
19 | if m.bias is not None:
20 | m.bias.data.zero_()
21 | elif isinstance(m, nn.Linear):
22 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
23 | m.weight.data *= scale
24 | if m.bias is not None:
25 | m.bias.data.zero_()
26 | elif isinstance(m, nn.BatchNorm2d):
27 | nn.init.constant_(m.weight, 1)
28 | nn.init.constant_(m.bias.data, 0.0)
29 |
30 | class GenBlock(nn.Module):
31 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, bias=True):
32 | super(GenBlock, self).__init__()
33 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias)
34 | self.bn = nn.BatchNorm2d(num_features=out_channels)
35 | self.xunit = xUnitS(num_features=out_channels, batch_norm=True)
36 |
37 |         initialize_weights([self.conv, self.bn], 0.02)  # fix: was self.norm, which does not exist
38 |
39 | def forward(self, x):
40 | x = self.xunit(self.bn(self.conv(x)))
41 | return x
42 |
43 | class DisBlock(nn.Module):
44 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False):
45 | super(DisBlock, self).__init__()
46 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
47 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias)
48 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias)
49 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True)
50 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True)
51 |
52 | initialize_weights([self.conv1, self.conv2], 0.1)
53 |
54 | if normalization:
55 | self.conv1 = SpectralNorm(self.conv1)
56 | self.conv2 = SpectralNorm(self.conv2)
57 |
58 | def forward(self, x):
59 | x = self.lrelu(self.bn1(self.conv1(x)))
60 | x = self.lrelu(self.bn2(self.conv2(x)))
61 | return x
62 |
63 | class Generator(nn.Module):
64 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
65 | super(Generator, self).__init__()
66 | self.scale_factor = scale
67 |
68 | # image to features
69 | self.image_to_features = GenBlock(in_channels=in_channels, out_channels=num_features)
70 |
71 | # features
72 | blocks = []
73 | for _ in range(gen_blocks):
74 | blocks.append(GenBlock(in_channels=num_features, out_channels=num_features))
75 | self.features = nn.Sequential(*blocks)
76 |
77 | # features to image
78 | self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2)
79 | initialize_weights([self.features_to_image], 0.02)
80 |
81 | def forward(self, x):
82 | x = F.interpolate(x, size=(self.scale_factor * x.size(-2), self.scale_factor * x.size(-1)), mode='bicubic', align_corners=True)
83 | r = x
84 | x = self.image_to_features(x)
85 | x = self.features(x)
86 | x = self.features_to_image(x)
87 | x += r
88 | return x
89 |
90 | class Discriminator(nn.Module):
91 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
92 | super(Discriminator, self).__init__()
93 | self.crop_size = 4 * pow(2, dis_blocks)
94 |
95 | # image to features
96 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False)
97 |
98 | # features
99 | blocks = []
100 | for i in range(0, dis_blocks - 1):
101 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False))
102 | self.features = nn.Sequential(*blocks)
103 |
104 | # classifier
105 | self.classifier = nn.Sequential(
106 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100),
107 | nn.LeakyReLU(negative_slope=0.1),
108 | nn.Linear(100, 1)
109 | )
110 |
111 | def forward(self, x):
112 | x = center_crop(x, self.crop_size, self.crop_size)
113 | x = self.image_to_features(x)
114 | x = self.features(x)
115 | x = x.flatten(start_dim=1)
116 | x = self.classifier(x)
117 | return x
118 |
119 | class SNDiscriminator(nn.Module):
120 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
121 | super(SNDiscriminator, self).__init__()
122 | self.crop_size = 4 * pow(2, dis_blocks)
123 |
124 | # image to features
125 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True)
126 |
127 | # features
128 | blocks = []
129 | for i in range(0, dis_blocks - 1):
130 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True))
131 | self.features = nn.Sequential(*blocks)
132 |
133 | # classifier
134 | self.classifier = nn.Sequential(
135 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)),
136 | nn.LeakyReLU(negative_slope=0.1),
137 | SpectralNorm(nn.Linear(100, 1))
138 | )
139 |
140 | def forward(self, x):
141 | x = center_crop(x, self.crop_size, self.crop_size)
142 | x = self.image_to_features(x)
143 | x = self.features(x)
144 | x = x.flatten(start_dim=1)
145 | x = self.classifier(x)
146 | return x
147 |
148 | def g_xvanilla(**config):
149 | config.setdefault('in_channels', 3)
150 | config.setdefault('num_features', 64)
151 | config.setdefault('gen_blocks', 8)
152 | config.setdefault('dis_blocks', 5)
153 | config.setdefault('scale', 4)
154 |
155 | return Generator(**config)
156 |
157 | def d_xvanilla(**config):
158 | config.setdefault('in_channels', 3)
159 | config.setdefault('num_features', 64)
160 | config.setdefault('gen_blocks', 8)
161 | config.setdefault('dis_blocks', 5)
162 | config.setdefault('scale', 4)
163 |
164 | return Discriminator(**config)
--------------------------------------------------------------------------------
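Unlike the xsrgan generator, which upsamples with learned `UpsampleX2` blocks after its residual stack, the xvanilla generator interpolates to the target resolution up front and then refines the result, so its output size is fixed by the first `F.interpolate` call. A standalone sketch of that step (assumes only PyTorch):

```
import torch
import torch.nn.functional as F

scale = 4
x = torch.randn(1, 3, 32, 48)            # low-resolution input
up = F.interpolate(x, size=(scale * x.size(-2), scale * x.size(-1)),
                   mode='bicubic', align_corners=True)
print(up.shape)                           # torch.Size([1, 3, 128, 192])
```
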
/super-resolution/scripts/matlab/bicubic_subsample.m:
--------------------------------------------------------------------------------
1 | function bicubic_subsample()
2 | %% MATLAB code to generate mod-cropped images, bicubic-downsampled LR images, and bicubic-upsampled images.
3 |
4 | %% set parameters
5 | % comment the unnecessary line
6 | input_folder = '';
7 | save_mod_folder = '';
8 | save_LR_folder = '';
9 | % save_bic_folder = '';
10 |
11 | up_scale = 3;
12 | mod_scale = up_scale;
13 |
14 | if exist('save_mod_folder', 'var')
15 | if exist(save_mod_folder, 'dir')
16 |         disp(['It will overwrite ', save_mod_folder]);
17 | else
18 | mkdir(save_mod_folder);
19 | end
20 | end
21 | if exist('save_LR_folder', 'var')
22 | if exist(save_LR_folder, 'dir')
23 |         disp(['It will overwrite ', save_LR_folder]);
24 | else
25 | mkdir(save_LR_folder);
26 | end
27 | end
28 | if exist('save_bic_folder', 'var')
29 | if exist(save_bic_folder, 'dir')
30 |         disp(['It will overwrite ', save_bic_folder]);
31 | else
32 | mkdir(save_bic_folder);
33 | end
34 | end
35 |
36 | idx = 0;
37 | filepaths = dir(fullfile(input_folder,'*.*'));
38 | for i = 1 : length(filepaths)
39 | [paths,imname,ext] = fileparts(filepaths(i).name);
40 | if isempty(imname)
41 | disp('Ignore . folder.');
42 | elseif strcmp(imname, '.')
43 | disp('Ignore .. folder.');
44 | else
45 | idx = idx + 1;
46 | str_rlt = sprintf('%d\t%s.\n', idx, imname);
47 | fprintf(str_rlt);
48 | % read image
49 | img = imread(fullfile(input_folder, [imname, ext]));
50 | img = im2double(img);
51 | % modcrop
52 | img = modcrop(img, mod_scale);
53 | if exist('save_mod_folder', 'var')
54 | imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
55 | end
56 | % LR
57 | im_LR = imresize(img, 1/up_scale, 'bicubic');
58 | if exist('save_LR_folder', 'var')
59 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
60 | end
61 | % Bicubic
62 | if exist('save_bic_folder', 'var')
63 | im_B = imresize(im_LR, up_scale, 'bicubic');
64 |             imwrite(im_B, fullfile(save_bic_folder, [imname, sprintf('_bicx%d.png', up_scale)]));
65 | end
66 | end
67 | end
68 | end
69 |
70 | %% modcrop
71 | function img = modcrop(img, modulo)
72 | if size(img,3) == 1
73 | sz = size(img);
74 | sz = sz - mod(sz, modulo);
75 | if mod(modulo, 1) == 0
76 | img = img(1:sz(1), 1:sz(2));
77 | end
78 | else
79 | tmpsz = size(img);
80 | sz = tmpsz(1:2);
81 | sz = sz - mod(sz, modulo);
82 | if mod(modulo, 1) == 0
83 | img = img(1:sz(1), 1:sz(2),:);
84 | end
85 | end
86 |
87 | end
88 |
--------------------------------------------------------------------------------
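For a pure-Python alternative to this MATLAB preprocessing, `utils/core.py` (reproduced later in this repository) replicates MATLAB's bicubic `imresize`, so the modcrop-and-downsample step can be sketched as below. The input path is hypothetical, and the sketch assumes `core.py` is importable on the path:

```
import numpy as np
import torch
from PIL import Image
import core  # assumption: utils/core.py is on the path as core

up_scale = 3
img = Image.open('input.png')  # hypothetical input path
x = torch.from_numpy(np.asarray(img).copy()).permute(2, 0, 1).float() / 255.0

# modcrop: trim so both spatial sides are divisible by the scale
h = x.size(-2) - x.size(-2) % up_scale
w = x.size(-1) - x.size(-1) % up_scale
x = x[..., :h, :w]

# bicubic downsample, matching MATLAB imresize(img, 1/up_scale, 'bicubic')
lr = core.imresize(x, scale=1 / up_scale)
print(lr.shape)
```
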
/super-resolution/scripts/python/extract_images.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import numpy as np
5 | from PIL import Image
6 | from multiprocessing import Pool
7 |
8 | def get_arguments():
9 | parser = argparse.ArgumentParser(description='Extract sub images for faster data loading.')
10 | parser.add_argument('--root', default='', required=True, help='source folder')
11 | parser.add_argument('--crop-size', default=256, type=int, help='crop size (default: 256)')
12 | parser.add_argument('--step-size', default=128, type=int, help='step size (default: 128)')
13 | parser.add_argument('--scale', default=4, type=int, help='super-resolution scale (default: 4)')
14 | parser.add_argument('--num-threads', default=8, type=int, help='num of threads (default: 8)')
15 | parser.add_argument('--only-once', default=False, action='store_true', help='center sub image')
16 | args = parser.parse_args()
17 | return args
18 |
19 | def mkdir(save_path):
20 | if not os.path.exists(save_path):
21 | os.makedirs(save_path)
22 |
23 | def mkdirs(args):
24 | mkdir(os.path.join(args.root,'img_sub'))
25 | mkdir(os.path.join(args.root,'img_sub_x{}'.format(args.scale)))
26 |
27 | def get_image_paths(args):
28 | paths = glob.glob(os.path.join(args.root, 'img_x{}'.format(args.scale), '*.*'))
29 | return paths
30 |
31 | def save_images(args, i, path, input, target):
32 | save_path = path.replace('img_x', 'img_sub_x').replace('.png', '_{}.png'.format(i))
33 | input.save(save_path)
34 | save_path = path.replace('img_x{}'.format(args.scale), 'img_sub').replace('.png', '_{}.png'.format(i))
35 | target.save(save_path)
36 |
37 | def load_images(path, args):
38 | input = Image.open(path)
39 | target = Image.open(path.replace('img_x{}'.format(args.scale), 'img'))
40 | return input, target
41 |
42 | def worker(path, args):
43 | input, target = load_images(path, args)
44 |
45 |     h, w = input.size  # PIL's size is (width, height); the names are swapped but used consistently below
46 | hs = np.arange(0, h - args.crop_size + 1, args.step_size)
47 | ws = np.arange(0, w - args.crop_size + 1, args.step_size)
48 |
49 |     hs = np.append(hs, h - args.crop_size)  # final window flush with the edge (was off by one)
50 |     ws = np.append(ws, w - args.crop_size)
51 |
52 | counts = 0
53 |
54 | if args.only_once:
55 | hs = [hs[len(hs) // 2]]
56 | ws = [ws[len(ws) // 2]]
57 |
58 | for x in hs:
59 | for y in ws:
60 | cropped_input = input.crop((x, y, (x + args.crop_size), (y + args.crop_size)))
61 | cropped_target = target.crop((x * args.scale, y * args.scale, (x + args.crop_size) * args.scale, (y + args.crop_size) * args.scale))
62 | save_images(args, counts, path, cropped_input, cropped_target)
63 | counts += 1
64 |
65 | print('Processed {:s}'.format(os.path.basename(path)))
66 |
67 | def main():
68 | args = get_arguments()
69 |
70 | mkdirs(args)
71 | paths = get_image_paths(args)
72 | pool = Pool(args.num_threads)
73 |
74 | for path in paths:
75 | pool.apply_async(worker, args=(path, args))
76 |
77 | pool.close()
78 | pool.join()
79 | print('All subprocesses done.')
80 |
81 | if __name__ == "__main__":
82 | main()
83 |
--------------------------------------------------------------------------------
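`worker` above tiles each LR image with a sliding window of `--crop-size` taken every `--step-size` pixels, then appends one extra edge-flush position so the far border is always covered (the `--only-once` flag keeps just the central position instead). The resulting grid for one axis, as a standalone sketch:

```
import numpy as np

length, crop_size, step_size = 500, 256, 128
starts = np.arange(0, length - crop_size + 1, step_size)
starts = np.append(starts, length - crop_size)  # edge-flush final window
print(starts)  # [  0 128 244]
```
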
/super-resolution/scripts/python/remove_images.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | from PIL import Image, ImageStat
5 | from multiprocessing import Pool
6 |
7 | def get_arguments():
8 | parser = argparse.ArgumentParser(description='Remove images in dataset')
9 | parser.add_argument('--root', default='', required=True, help='source folder')
10 | parser.add_argument('--scale', default=4, type=int, help='super-resolution scale (default: 4)')
11 | parser.add_argument('--limit-size', default=50, type=int, help='limit size for input (default: 50)')
12 | parser.add_argument('--num-threads', default=8, type=int, help='num of threads (default: 8)')
13 | args = parser.parse_args()
14 | return args
15 |
16 | def is_grayscale(image):
17 | stat = ImageStat.Stat(image)
18 |
19 | if sum(stat.sum) / 3 == stat.sum[0]:
20 | return True
21 | else:
22 | return False
23 |
24 | def is_small(image, args):
25 | h, w = image.size
26 |
27 | if h < args.limit_size or w < args.limit_size:
28 | return True
29 | else:
30 | return False
31 |
32 | def remove(path):
33 | if os.path.isfile(path):
34 | os.remove(path)
35 |
36 | def get_image_paths(args):
37 | paths = glob.glob(os.path.join(args.root, 'img_x{}'.format(args.scale), '*.*'))
38 | return paths
39 |
40 | def get_paths(path, args):
41 | input = path
42 | target = path.replace('img_x{}'.format(args.scale), 'img')
43 | return input, target
44 |
45 | def worker(path, args):
46 | input, target = get_paths(path, args)
47 |
48 | image = Image.open(input).convert("RGB")
49 |
50 | if is_grayscale(image) or is_small(image, args):
51 | remove(input)
52 | remove(target)
53 |         # report only when a pair was actually removed
54 |         print('Removed {:s}'.format(os.path.basename(path)))
55 |
56 | def main():
57 | args = get_arguments()
58 |
59 | paths = get_image_paths(args)
60 | pool = Pool(args.num_threads)
61 |
62 | for path in paths:
63 | pool.apply_async(worker, args=(path, args))
64 |
65 | pool.close()
66 | pool.join()
67 | print('All subprocesses done.')
68 |
69 | if __name__ == "__main__":
70 | main()
--------------------------------------------------------------------------------
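`is_grayscale` above relies on a simple aggregate heuristic: for an RGB image, `ImageStat.Stat(image).sum` gives one per-channel pixel sum, and if the mean of the three sums equals the first one, the channels are identical in aggregate. A quick check of the heuristic (assumes only Pillow):

```
from PIL import Image, ImageStat

gray_rgb = Image.new('L', (8, 8), 128).convert('RGB')   # gray saved as RGB
stat = ImageStat.Stat(gray_rgb)
print(sum(stat.sum) / 3 == stat.sum[0])                 # True

color = Image.new('RGB', (8, 8), (200, 30, 60))
stat = ImageStat.Stat(color)
print(sum(stat.sum) / 3 == stat.sum[0])                 # False
```
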
/super-resolution/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import models
3 | import os
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.optim.lr_scheduler import StepLR
7 | from torch.autograd import grad as torch_grad
8 | from data import get_loaders
9 | from ast import literal_eval
10 | from utils.recorderx import RecoderX
11 | from utils.misc import save_image, average, mkdir, compute_psnr
12 | from models.modules.losses import RangeLoss, PerceptualLoss, StyleLoss
13 |
14 | class Trainer():
15 | def __init__(self, args):
16 | # parameters
17 | self.args = args
18 | self.print_model = True
19 | self.invalidity_margins = None
20 |
21 | if self.args.use_tb:
22 | self.tb = RecoderX(log_dir=args.save_path)
23 |
24 | # initialize
25 | self._init()
26 |
27 | def _init_model(self):
28 | # initialize model
29 | if self.args.model_config != '':
30 | model_config = dict({}, **literal_eval(self.args.model_config))
31 | else:
32 | model_config = {}
33 |
34 | g_model = models.__dict__[self.args.g_model]
35 | d_model = models.__dict__[self.args.d_model]
36 | self.g_model = g_model(**model_config)
37 | self.d_model = d_model(**model_config)
38 |
39 | # loading weights
40 | if self.args.gen_to_load != '':
41 | logging.info('\nLoading g-model...')
42 | self.g_model.load_state_dict(torch.load(self.args.gen_to_load, map_location='cpu'))
43 | if self.args.dis_to_load != '':
44 | logging.info('\nLoading d-model...')
45 | self.d_model.load_state_dict(torch.load(self.args.dis_to_load, map_location='cpu'))
46 |
47 | # to cuda
48 | self.g_model = self.g_model.to(self.args.device)
49 | self.d_model = self.d_model.to(self.args.device)
50 |
51 | # parallel
52 | if self.args.device_ids and len(self.args.device_ids) > 1:
53 | self.g_model = torch.nn.DataParallel(self.g_model, self.args.device_ids)
54 | self.d_model = torch.nn.DataParallel(self.d_model, self.args.device_ids)
55 |
56 | # print model
57 | if self.print_model:
58 | logging.info(self.g_model)
59 | logging.info('Number of parameters in generator: {}\n'.format(sum([l.nelement() for l in self.g_model.parameters()])))
60 | logging.info(self.d_model)
61 | logging.info('Number of parameters in discriminator: {}\n'.format(sum([l.nelement() for l in self.d_model.parameters()])))
62 | self.print_model = False
63 |
64 | def _init_optim(self):
65 | # initialize optimizer
66 | self.g_optimizer = torch.optim.Adam(self.g_model.parameters(), lr=self.args.lr, betas=self.args.gen_betas)
67 | self.d_optimizer = torch.optim.Adam(self.d_model.parameters(), lr=self.args.lr, betas=self.args.dis_betas)
68 |
69 | # initialize scheduler
70 | self.g_scheduler = StepLR(self.g_optimizer, step_size=self.args.step_size, gamma=self.args.gamma)
71 | self.d_scheduler = StepLR(self.d_optimizer, step_size=self.args.step_size, gamma=self.args.gamma)
72 |
73 | # initialize criterion
74 |         if self.args.reconstruction_weight > 0.:
75 | self.reconstruction = torch.nn.L1Loss().to(self.args.device)
76 | if self.args.perceptual_weight > 0.:
77 | self.perceptual = PerceptualLoss(features_to_compute=['conv5_4'], criterion=torch.nn.L1Loss(), shave_edge=self.invalidity_margins).to(self.args.device)
78 | if self.args.style_weight > 0.:
79 | self.style = StyleLoss(features_to_compute=['relu3_1', 'relu2_1']).to(self.args.device)
80 | if self.args.range_weight > 0.:
81 | self.range = RangeLoss(invalidity_margins=self.invalidity_margins).to(self.args.device)
82 |
83 | def _init(self):
84 | # init parameters
85 | self.step = 0
86 | self.losses = {'D': [], 'D_r': [], 'D_gp': [], 'D_f': [], 'G': [], 'G_recon': [], 'G_rng': [], 'G_perc': [], 'G_sty': [], 'G_adv': [], 'psnr': []}
87 |
88 | # initialize model
89 | self._init_model()
90 |
91 | # initialize optimizer
92 | self._init_optim()
93 |
94 | def _save_model(self, epoch):
95 | # save models
96 | torch.save(self.g_model.state_dict(), os.path.join(self.args.save_path, '{}_e{}.pt'.format(self.args.g_model, epoch + 1)))
97 | torch.save(self.d_model.state_dict(), os.path.join(self.args.save_path, '{}_e{}.pt'.format(self.args.d_model, epoch + 1)))
98 |
99 | def _set_require_grads(self, model, require_grad):
100 | for p in model.parameters():
101 | p.requires_grad_(require_grad)
102 |
103 | def _critic_hinge_iteration(self, inputs, targets):
104 | # require grads
105 | self._set_require_grads(self.d_model, True)
106 |
107 | # get generated data
108 | generated_data = self.g_model(inputs)
109 |
110 | # zero grads
111 | self.d_optimizer.zero_grad()
112 |
113 | # calculate probabilities on real and generated data
114 | d_real = self.d_model(targets)
115 | d_generated = self.d_model(generated_data.detach())
116 |
117 | # create total loss and optimize
118 | loss_r = F.relu(1.0 - d_real).mean()
119 | loss_f = F.relu(1.0 + d_generated).mean()
120 | loss = loss_r + loss_f
121 |
122 | # get gradient penalty
123 | if self.args.penalty_weight > 0.:
124 | gradient_penalty = self._gradient_penalty(targets, generated_data)
125 | loss += gradient_penalty
126 |
127 | loss.backward()
128 |
129 | self.d_optimizer.step()
130 |
131 | # record loss
132 | self.losses['D'].append(loss.data.item())
133 | self.losses['D_r'].append(loss_r.data.item())
134 | self.losses['D_f'].append(loss_f.data.item())
135 | if self.args.penalty_weight > 0.:
136 | self.losses['D_gp'].append(gradient_penalty.data.item())
137 |
138 | # require grads
139 | self._set_require_grads(self.d_model, False)
140 |
141 | def _critic_wgan_iteration(self, inputs, targets):
142 | # require grads
143 | self._set_require_grads(self.d_model, True)
144 |
145 | # get generated data
146 | generated_data = self.g_model(inputs)
147 |
148 | # zero grads
149 | self.d_optimizer.zero_grad()
150 |
151 | # calculate probabilities on real and generated data
152 | d_real = self.d_model(targets)
153 | d_generated = self.d_model(generated_data.detach())
154 |
155 | # create total loss and optimize
156 | if self.args.relativistic:
157 | loss_r = -(d_real - d_generated.mean()).mean()
158 | loss_f = (d_generated - d_real.mean()).mean()
159 | else:
160 | loss_r = -d_real.mean()
161 | loss_f = d_generated.mean()
162 | loss = loss_f + loss_r
163 |
164 | # get gradient penalty
165 | if self.args.penalty_weight > 0.:
166 | gradient_penalty = self._gradient_penalty(targets, generated_data)
167 | loss += gradient_penalty
168 |
169 | loss.backward()
170 |
171 | self.d_optimizer.step()
172 |
173 | # record loss
174 | self.losses['D'].append(loss.data.item())
175 | self.losses['D_r'].append(loss_r.data.item())
176 | self.losses['D_f'].append(loss_f.data.item())
177 | if self.args.penalty_weight > 0.:
178 | self.losses['D_gp'].append(gradient_penalty.data.item())
179 |
180 | # require grads
181 | self._set_require_grads(self.d_model, False)
182 |
183 | def _gradient_penalty(self, real_data, generated_data):
184 | batch_size = real_data.size(0)
185 |
186 | # calculate interpolation
187 | alpha = torch.rand(batch_size, 1, 1, 1)
188 | alpha = alpha.expand_as(real_data)
189 | alpha = alpha.to(self.args.device)
190 |         interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
191 |         interpolated = interpolated.to(self.args.device)
192 |         interpolated.requires_grad_(True)  # Variable is deprecated; request grads directly
193 |
194 | # calculate probability of interpolated examples
195 | prob_interpolated = self.d_model(interpolated)
196 |
197 | # calculate gradients of probabilities with respect to examples
198 | gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
199 | grad_outputs=torch.ones(prob_interpolated.size()).to(self.args.device),
200 | create_graph=True, retain_graph=True)[0]
201 |
202 | # gradients have shape (batch_size, num_channels, img_width, img_height),
203 | # so flatten to easily take norm per example in batch
204 | gradients = gradients.view(batch_size, -1)
205 |
206 |         # derivatives of the gradient close to 0 can cause problems because of
207 |         # the square root, so manually calculate the norm and add an epsilon
208 |         gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
209 |
210 | # return gradient penalty
211 | return ((gradients_norm - 1) ** 2).mean()
212 |
213 | def _generator_iteration(self, inputs, targets):
214 | # zero grads
215 | self.g_optimizer.zero_grad()
216 |
217 | # get generated data
218 | generated_data = self.g_model(inputs)
219 | loss = 0.
220 |
221 | # reconstruction loss
222 | if self.args.reconstruction_weight > 0.:
223 | loss_recon = self.reconstruction(generated_data, targets)
224 | loss += loss_recon * self.args.reconstruction_weight
225 | self.losses['G_recon'].append(loss_recon.data.item())
226 |
227 | # range loss
228 | if self.args.range_weight > 0.:
229 | loss_rng = self.range(generated_data)
230 | loss += loss_rng * self.args.range_weight
231 | self.losses['G_rng'].append(loss_rng.data.item())
232 |
233 | # adversarial loss
234 | if self.args.adversarial_weight > 0.:
235 | d_generated = self.d_model(generated_data)
236 | if self.args.relativistic:
237 | d_real = self.d_model(targets)
238 | loss_adv = (d_real - d_generated.mean()).mean() - (d_generated - d_real.mean()).mean()
239 | else:
240 | loss_adv = -d_generated.mean()
241 | loss += loss_adv * self.args.adversarial_weight
242 | self.losses['G_adv'].append(loss_adv.data.item())
243 |
244 | # perceptual loss
245 | if self.args.perceptual_weight > 0.:
246 | loss_perc = self.perceptual(generated_data, targets)
247 | loss += loss_perc * self.args.perceptual_weight
248 | self.losses['G_perc'].append(loss_perc.data.item())
249 |
250 | # style loss
251 | if self.args.style_weight > 0.:
252 | loss_sty = self.style(generated_data, targets)
253 | loss += loss_sty * self.args.style_weight
254 | self.losses['G_sty'].append(loss_sty.data.item())
255 |
256 | # backward loss
257 | loss.backward()
258 | self.g_optimizer.step()
259 |
260 | # record loss
261 | self.losses['G'].append(loss.data.item())
262 |
263 | def _train_iteration(self, data):
264 | # set inputs
265 | inputs = data['input'].to(self.args.device)
266 | targets = data['target'].to(self.args.device)
267 |
268 | # critic iteration
269 | if self.args.adversarial_weight > 0.:
270 | if self.args.wgan or self.args.relativistic:
271 | self._critic_wgan_iteration(inputs, targets)
272 | else:
273 | self._critic_hinge_iteration(inputs, targets)
274 |
275 | # only update generator every |critic_iterations| iterations
276 | if self.step % self.args.num_critic == 0:
277 | self._generator_iteration(inputs, targets)
278 |
279 | # logging
280 | if self.step % self.args.print_every == 0:
281 | line2print = 'Iteration {}'.format(self.step)
282 | if self.args.adversarial_weight > 0.:
283 | line2print += ', D: {:.6f}, D_r: {:.6f}, D_f: {:.6f}'.format(self.losses['D'][-1], self.losses['D_r'][-1], self.losses['D_f'][-1])
284 | if self.args.penalty_weight > 0.:
285 | line2print += ', D_gp: {:.6f}'.format(self.losses['D_gp'][-1])
286 | if self.step > self.args.num_critic:
287 | line2print += ', G: {:.5f}'.format(self.losses['G'][-1])
288 | if self.args.reconstruction_weight:
289 | line2print += ', G_recon: {:.6f}'.format(self.losses['G_recon'][-1])
290 | if self.args.range_weight:
291 | line2print += ', G_rng: {:.6f}'.format(self.losses['G_rng'][-1])
292 | if self.args.perceptual_weight:
293 | line2print += ', G_perc: {:.6f}'.format(self.losses['G_perc'][-1])
294 | if self.args.style_weight:
295 | line2print += ', G_sty: {:.8f}'.format(self.losses['G_sty'][-1])
296 | if self.args.adversarial_weight:
297 | line2print += ', G_adv: {:.6f},'.format(self.losses['G_adv'][-1])
298 | logging.info(line2print)
299 |
300 | # plots for tensorboard
301 | if self.args.use_tb:
302 | if self.args.adversarial_weight > 0.:
303 | self.tb.add_scalar('data/loss_d', self.losses['D'][-1], self.step)
304 | if self.step > self.args.num_critic:
305 | self.tb.add_scalar('data/loss_g', self.losses['G'][-1], self.step)
306 |
307 | def _eval_iteration(self, data, epoch):
308 | # set inputs
309 | inputs = data['input'].to(self.args.device)
310 | targets = data['target']
311 | paths = data['path']
312 |
313 | # evaluation
314 | with torch.no_grad():
315 | outputs = self.g_model(inputs)
316 |
317 | # save image and compute psnr
318 | self._save_image(outputs, paths[0], epoch + 1)
319 | psnr = compute_psnr(outputs, targets, self.args.scale)
320 |
321 | return psnr
322 |
323 | def _train_epoch(self, loader):
324 | self.g_model.train()
325 | self.d_model.train()
326 |
327 | # train over epochs
328 | for _, data in enumerate(loader):
329 | self._train_iteration(data)
330 | self.step += 1
331 |
332 | def _eval_epoch(self, loader, epoch):
333 | self.g_model.eval()
334 | psnrs = []
335 |
336 | # eval over epoch
337 | for _, data in enumerate(loader):
338 | psnr = self._eval_iteration(data, epoch)
339 | psnrs.append(psnr)
340 |
341 | # record psnr
342 | self.losses['psnr'].append(average(psnrs))
343 | logging.info('Evaluation: {:.3f}'.format(self.losses['psnr'][-1]))
344 | if self.args.use_tb:
345 | self.tb.add_scalar('data/psnr', self.losses['psnr'][-1], epoch)
346 |
347 | def _save_image(self, image, path, epoch):
348 | directory = os.path.join(self.args.save_path, 'images', 'epoch_{}'.format(epoch))
349 | save_path = os.path.join(directory, os.path.basename(path))
350 | mkdir(directory)
351 | save_image(image.data.cpu(), save_path)
352 |
353 | def _train(self, loaders):
354 | # run epoch iterations
355 | for epoch in range(self.args.epochs):
356 | logging.info('\nEpoch {}'.format(epoch + 1))
357 |
358 | # train
359 | self._train_epoch(loaders['train'])
360 |
361 | # scheduler
362 |             self.g_scheduler.step()
363 |             self.d_scheduler.step()
364 |
365 | # evaluation
366 | if ((epoch + 1) % self.args.eval_every == 0) or ((epoch + 1) == self.args.epochs):
367 | self._eval_epoch(loaders['eval'], epoch)
368 | self._save_model(epoch)
369 |
370 | # best score
371 | logging.info('Best PSNR Score: {:.2f}\n'.format(max(self.losses['psnr'])))
372 |
373 | def train(self):
374 | # get loader
375 | loaders = get_loaders(self.args)
376 |
377 | # run training
378 | self._train(loaders)
379 |
380 | # close tensorboard
381 | if self.args.use_tb:
382 | self.tb.close()
383 |
384 | def eval(self):
385 | # get loader
386 | loaders = get_loaders(self.args)
387 |
388 | # evaluation
389 | logging.info('\nEvaluating...')
390 | self._eval_epoch(loaders['eval'], 0)
391 |
392 | # close tensorboard
393 | if self.args.use_tb:
394 | self.tb.close()
--------------------------------------------------------------------------------
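`_gradient_penalty` above is the standard WGAN-GP regularizer, E[(||grad D(x_hat)||_2 - 1)^2], evaluated on random interpolates between real and generated samples. A self-contained toy version with a linear critic (assumes only PyTorch):

```
import torch

d = torch.nn.Linear(8, 1)                        # toy critic
real, fake = torch.randn(4, 8), torch.randn(4, 8)

alpha = torch.rand(4, 1)                         # per-sample mixing weights
interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
out = d(interp)

# differentiate the critic output w.r.t. the interpolates
grads = torch.autograd.grad(outputs=out, inputs=interp,
                            grad_outputs=torch.ones_like(out),
                            create_graph=True)[0]
penalty = ((grads.norm(p=2, dim=1) - 1) ** 2).mean()
print(penalty.item())
```
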
/super-resolution/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/super-resolution/utils/__init__.py
--------------------------------------------------------------------------------
/super-resolution/utils/core.py:
--------------------------------------------------------------------------------
1 | '''
2 | https://github.com/thstkdgus35/bicubic_pytorch/blob/master/core.py
3 |
4 | A standalone PyTorch implementation for fast and efficient bicubic resampling.
5 | The resulting values are the same as those of the MATLAB function imresize with the 'bicubic' option.
6 | ## Author: Sanghyun Son
7 | ## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary)
8 | ## Version: 1.1.0
9 | ## Last update: July 9th, 2020 (KST)
10 | Dependency: torch
11 | Example::
12 | >>> import torch
13 | >>> import core
14 | >>> x = torch.arange(16).float().view(1, 1, 4, 4)
15 | >>> y = core.imresize(x, sides=(3, 3))
16 | >>> print(y)
17 | tensor([[[[ 0.7506, 2.1004, 3.4503],
18 | [ 6.1505, 7.5000, 8.8499],
19 | [11.5497, 12.8996, 14.2494]]]])
20 | '''
21 |
22 | import math
23 | import typing
24 |
25 | import torch
26 | from torch.nn import functional as F
27 |
28 | __all__ = ['imresize']
29 |
30 | K = typing.TypeVar('K', str, torch.Tensor)
31 |
32 | def cubic_contribution(x: torch.Tensor, a: float=-0.5) -> torch.Tensor:
33 | ax = x.abs()
34 | ax2 = ax * ax
35 | ax3 = ax * ax2
36 |
37 | range_01 = (ax <= 1)
38 | range_12 = (ax > 1) * (ax <= 2)
39 |
40 | cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1
41 | cont_01 = cont_01 * range_01.to(dtype=x.dtype)
42 |
43 | cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)
44 | cont_12 = cont_12 * range_12.to(dtype=x.dtype)
45 |
46 | cont = cont_01 + cont_12
47 | cont = cont / cont.sum()
48 | return cont
49 |
50 | def gaussian_contribution(x: torch.Tensor, sigma: float=2.0) -> torch.Tensor:
51 | range_3sigma = (x.abs() <= 3 * sigma + 1)
52 | # Normalization will be done after
53 | cont = torch.exp(-x.pow(2) / (2 * sigma**2))
54 | cont = cont * range_3sigma.to(dtype=x.dtype)
55 | return cont
56 |
57 | def discrete_kernel(
58 | kernel: str, scale: float, antialiasing: bool=True) -> torch.Tensor:
59 |
60 | '''
61 | For downsampling with integer scale only.
62 | '''
63 | downsampling_factor = int(1 / scale)
64 | if kernel == 'cubic':
65 | kernel_size_orig = 4
66 | else:
67 | raise ValueError('Pass!')
68 |
69 | if antialiasing:
70 | kernel_size = kernel_size_orig * downsampling_factor
71 | else:
72 | kernel_size = kernel_size_orig
73 |
74 | if downsampling_factor % 2 == 0:
75 | a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size))
76 | else:
77 | kernel_size -= 1
78 | a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1))
79 |
80 | with torch.no_grad():
81 | r = torch.linspace(-a, a, steps=kernel_size)
82 | k = cubic_contribution(r).view(-1, 1)
83 | k = torch.matmul(k, k.t())
84 | k /= k.sum()
85 |
86 | return k
87 |
88 | def reflect_padding(
89 | x: torch.Tensor,
90 | dim: int,
91 | pad_pre: int,
92 | pad_post: int) -> torch.Tensor:
93 |
94 | '''
95 | Apply reflect padding to the given Tensor.
96 | Note that it is slightly different from the PyTorch functional.pad,
97 | where boundary elements are used only once.
98 | Instead, we follow the MATLAB implementation
99 | which uses boundary elements twice.
100 | For example,
101 | [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
102 | while our implementation yields [a, a, b, c, d, d].
103 | '''
104 | b, c, h, w = x.size()
105 | if dim == 2 or dim == -2:
106 | padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w)
107 | padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x)
108 | for p in range(pad_pre):
109 | padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :])
110 | for p in range(pad_post):
111 | padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :])
112 | else:
113 | padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post)
114 | padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x)
115 | for p in range(pad_pre):
116 | padding_buffer[..., pad_pre - p - 1].copy_(x[..., p])
117 | for p in range(pad_post):
118 | padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)])
119 |
120 | return padding_buffer
121 |
122 | def padding(
123 | x: torch.Tensor,
124 | dim: int,
125 | pad_pre: int,
126 | pad_post: int,
127 | padding_type: str='reflect') -> torch.Tensor:
128 |
129 | if padding_type == 'reflect':
130 | x_pad = reflect_padding(x, dim, pad_pre, pad_post)
131 | else:
132 | raise ValueError('{} padding is not supported!'.format(padding_type))
133 |
134 | return x_pad
135 |
136 | def get_padding(
137 | base: torch.Tensor,
138 | kernel_size: int,
139 | x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
140 |
141 | base = base.long()
142 | r_min = base.min()
143 | r_max = base.max() + kernel_size - 1
144 |
145 | if r_min <= 0:
146 | pad_pre = -r_min
147 | pad_pre = pad_pre.item()
148 | base += pad_pre
149 | else:
150 | pad_pre = 0
151 |
152 | if r_max >= x_size:
153 | pad_post = r_max - x_size + 1
154 | pad_post = pad_post.item()
155 | else:
156 | pad_post = 0
157 |
158 | return pad_pre, pad_post, base
159 |
160 | def get_weight(
161 | dist: torch.Tensor,
162 | kernel_size: int,
163 | kernel: str='cubic',
164 | sigma: float=2.0,
165 | antialiasing_factor: float=1) -> torch.Tensor:
166 |
167 | buffer_pos = dist.new_zeros(kernel_size, len(dist))
168 | for idx, buffer_sub in enumerate(buffer_pos):
169 | buffer_sub.copy_(dist - idx)
170 |
171 | # Expand (downsampling) / Shrink (upsampling) the receptive field.
172 | buffer_pos *= antialiasing_factor
173 | if kernel == 'cubic':
174 | weight = cubic_contribution(buffer_pos)
175 | elif kernel == 'gaussian':
176 | weight = gaussian_contribution(buffer_pos, sigma=sigma)
177 | else:
178 | raise ValueError('{} kernel is not supported!'.format(kernel))
179 |
180 | weight /= weight.sum(dim=0, keepdim=True)
181 | return weight
182 |
183 | def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
184 | # Resize height
185 | if dim == 2 or dim == -2:
186 | k = (kernel_size, 1)
187 | h_out = x.size(-2) - kernel_size + 1
188 | w_out = x.size(-1)
189 | # Resize width
190 | else:
191 | k = (1, kernel_size)
192 | h_out = x.size(-2)
193 | w_out = x.size(-1) - kernel_size + 1
194 |
195 | unfold = F.unfold(x, k)
196 | unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
197 | return unfold
198 |
199 | def resize_1d(
200 | x: torch.Tensor,
201 | dim: int,
202 | side: int=None,
203 | kernel: str='cubic',
204 | sigma: float=2.0,
205 | padding_type: str='reflect',
206 | antialiasing: bool=True) -> torch.Tensor:
207 |
208 |     '''
209 |     Args:
210 |         x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
211 |         dim (int): The dimension to resize (-2 for height, -1 for width).
212 |         side (int): The target length along dim (the effective scale is side / x.size(dim)).
213 |     Return:
214 |         torch.Tensor: x resized along dim.
215 |     '''
216 | scale = side / x.size(dim)
217 | # Identity case
218 | if scale == 1:
219 | return x
220 |
221 | # Default bicubic kernel with antialiasing (only when downsampling)
222 | if kernel == 'cubic':
223 | kernel_size = 4
224 | else:
225 | kernel_size = math.floor(6 * sigma)
226 |
227 | if antialiasing and (scale < 1):
228 | antialiasing_factor = scale
229 | kernel_size = math.ceil(kernel_size / antialiasing_factor)
230 | else:
231 | antialiasing_factor = 1
232 |
233 | # We allow margin to both sides
234 | kernel_size += 2
235 |
236 | # Weights only depend on the shape of input and output,
237 | # so we do not calculate gradients here.
238 | with torch.no_grad():
239 | d = 1 / (2 * side)
240 | pos = torch.linspace(
241 | start=d,
242 | end=(1 - d),
243 | steps=side,
244 | dtype=x.dtype,
245 | device=x.device,
246 | )
247 | pos = x.size(dim) * pos - 0.5
248 | base = pos.floor() - (kernel_size // 2) + 1
249 | dist = pos - base
250 | weight = get_weight(
251 | dist,
252 | kernel_size,
253 | kernel=kernel,
254 | sigma=sigma,
255 | antialiasing_factor=antialiasing_factor,
256 | )
257 | pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim))
258 |
259 | # To backpropagate through x
260 | x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)
261 | unfold = reshape_tensor(x_pad, dim, kernel_size)
262 | # Subsampling first
263 | if dim == 2 or dim == -2:
264 | sample = unfold[..., base, :]
265 | weight = weight.view(1, kernel_size, sample.size(2), 1)
266 | else:
267 | sample = unfold[..., base]
268 | weight = weight.view(1, kernel_size, 1, sample.size(3))
269 |
270 | # Apply the kernel
271 | down = sample * weight
272 | down = down.sum(dim=1, keepdim=True)
273 | return down
274 |
275 | def downsampling_2d(
276 | x: torch.Tensor,
277 | k: torch.Tensor,
278 | scale: int,
279 | padding_type: str='reflect') -> torch.Tensor:
280 |
281 | c = x.size(1)
282 | k_h = k.size(-2)
283 | k_w = k.size(-1)
284 |
285 | k = k.to(dtype=x.dtype, device=x.device)
286 | k = k.view(1, 1, k_h, k_w)
287 | k = k.repeat(c, c, 1, 1)
288 | e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)
289 | e = e.view(c, c, 1, 1)
290 | k = k * e
291 |
292 | pad_h = (k_h - scale) // 2
293 | pad_w = (k_w - scale) // 2
294 | x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)
295 | x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)
296 | y = F.conv2d(x, k, padding=0, stride=scale)
297 | return y
298 |
299 | def imresize(
300 | x: torch.Tensor,
301 | scale: float=None,
302 | sides: typing.Tuple[int, int]=None,
303 | kernel: K='cubic',
304 | sigma: float=2,
305 | rotation_degree: float=0,
306 | padding_type: str='reflect',
307 | antialiasing: bool=True) -> torch.Tensor:
308 |
309 | '''
310 | Args:
311 | x (torch.Tensor):
312 | scale (float):
313 | sides (tuple(int, int)):
314 | kernel (str, default='cubic'):
315 | sigma (float, default=2):
316 | rotation_degree (float, default=0):
317 | padding_type (str, default='reflect'):
318 | antialiasing (bool, default=True):
319 | Return:
320 | torch.Tensor:
321 | '''
322 |
323 | if scale is None and sides is None:
324 | raise ValueError('One of scale or sides must be specified!')
325 | if scale is not None and sides is not None:
326 | raise ValueError('Please specify scale or sides to avoid conflict!')
327 |
328 | if x.dim() == 4:
329 | b, c, h, w = x.size()
330 | elif x.dim() == 3:
331 | c, h, w = x.size()
332 | b = None
333 | elif x.dim() == 2:
334 | h, w = x.size()
335 | b = c = None
336 | else:
337 | raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))
338 |
339 | x = x.view(-1, 1, h, w)
340 |
341 | if sides is None:
342 | # Determine output size
343 | sides = (math.ceil(h * scale), math.ceil(w * scale))
344 | scale_inv = 1 / scale
345 | if isinstance(kernel, str) and scale_inv.is_integer():
346 | kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)
347 | elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():
348 | raise ValueError(
349 | 'An integer downsampling factor '
350 | 'should be used with a predefined kernel!'
351 | )
352 |
353 |     if x.dtype != torch.float32 and x.dtype != torch.float64:  # fix: 'or' was always true and downcast float64
354 | dtype = x.dtype
355 | x = x.float()
356 | else:
357 | dtype = None
358 |
359 | if isinstance(kernel, str):
360 | # Shared keyword arguments across dimensions
361 | kwargs = {
362 | 'kernel': kernel,
363 | 'sigma': sigma,
364 | 'padding_type': padding_type,
365 | 'antialiasing': antialiasing,
366 | }
367 | # Core resizing module
368 | x = resize_1d(x, -2, side=sides[0], **kwargs)
369 | x = resize_1d(x, -1, side=sides[1], **kwargs)
370 | elif isinstance(kernel, torch.Tensor):
371 | x = downsampling_2d(x, kernel, scale=int(1 / scale))
372 |
373 | rh = x.size(-2)
374 | rw = x.size(-1)
375 | # Back to the original dimension
376 | if b is not None:
377 | x = x.view(b, c, rh, rw) # 4-dim
378 | else:
379 | if c is not None:
380 | x = x.view(c, rh, rw) # 3-dim
381 | else:
382 | x = x.view(rh, rw) # 2-dim
383 |
384 | if dtype is not None:
385 | if not dtype.is_floating_point:
386 | x = x.round()
387 | # To prevent over/underflow when converting types
388 | if dtype is torch.uint8:
389 | x = x.clamp(0, 255)
390 |
391 | x = x.to(dtype=dtype)
392 |
393 | return x
394 |
395 | if __name__ == '__main__':
396 | # Just for debugging
397 | torch.set_printoptions(precision=4, sci_mode=False, edgeitems=16, linewidth=200)
398 | a = torch.arange(64).float().view(1, 1, 8, 8)
399 | z = imresize(a, 0.5)
400 | print(z)
401 | #a = torch.arange(16).float().view(1, 1, 4, 4)
402 | '''
403 | a = torch.zeros(1, 1, 4, 4)
404 | a[..., 0, 0] = 100
405 | a[..., 1, 0] = 10
406 | a[..., 0, 1] = 1
407 | a[..., 0, -1] = 100
408 | a = torch.zeros(1, 1, 4, 4)
409 | a[..., -1, -1] = 100
410 | a[..., -2, -1] = 10
411 | a[..., -1, -2] = 1
412 | a[..., -1, 0] = 100
413 | '''
414 | #b = imresize(a, sides=(3, 8), antialiasing=False)
415 | #c = imresize(a, sides=(11, 13), antialiasing=True)
416 | #c = imresize(a, sides=(4, 4), antialiasing=False, kernel='gaussian', sigma=1)
417 | #print(a)
418 | #print(b)
419 | #print(c)
420 |
421 | #r = discrete_kernel('cubic', 1 / 3)
422 | #print(r)
423 | '''
424 | a = torch.arange(225).float().view(1, 1, 15, 15)
425 | imresize(a, sides=[5, 5])
426 | '''
--------------------------------------------------------------------------------
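`imresize` above accepts either a fractional `scale` or explicit output `sides`, and when downsampling by an integer factor with antialiasing it switches to a precomputed discrete kernel. A short usage sketch (assumes this file is importable as `core`):

```
import torch
import core  # assumption: this file is on the path as core.py

x = torch.arange(64).float().view(1, 1, 8, 8)

y_down = core.imresize(x, scale=0.5)                     # 4x4, antialiased
y_up = core.imresize(x, sides=(16, 16))                  # explicit output size
y_raw = core.imresize(x, scale=0.5, antialiasing=False)  # 4x4, no prefilter

print(y_down.shape, y_up.shape)        # (1, 1, 4, 4) (1, 1, 16, 16)
print(torch.allclose(y_down, y_raw))   # generally False: prefiltering differs
```
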
/super-resolution/utils/misc.py:
--------------------------------------------------------------------------------
1 | import logging.config
2 | import os
3 | import matplotlib.pyplot as plt
4 | import torch
5 | import math
6 | import numpy as np
7 | from PIL import Image
8 | from torchvision.utils import make_grid
9 | from os import path, makedirs
10 |
11 | def mkdir(save_path):
12 | if not path.exists(save_path):
13 | makedirs(save_path)
14 |
15 | def make_image_grid(x, nrow, padding=0, pad_value=0):
16 | x = x.clone().cpu().data
17 | grid = make_grid(x, nrow=nrow, padding=padding, normalize=True, scale_each=False, pad_value=pad_value)
18 | return grid
19 |
20 | def tensor_to_image(x):
21 | ndarr = x.squeeze().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
22 | image = Image.fromarray(ndarr)
23 | return image
24 |
25 | def plot_image_grid(x, nrow, padding=0):
26 | grid = make_image_grid(x=x, nrow=nrow, padding=padding).permute(1, 2, 0).numpy()
27 | plt.imshow(grid)
28 | plt.show()
29 |
30 | def save_image(x, path, size=None):
31 | image = tensor_to_image(x)
32 | if size:
33 | image = image.resize((size, size), Image.NEAREST)
34 | image.save(path)
35 |
36 | def save_image_grid(x, path, nrow=8, size=None):
37 | grid = make_image_grid(x, nrow)
38 | save_image(grid, path, size=size)
39 |
40 | def setup_logging(log_file='log.txt', resume=False, dummy=False):
41 | if dummy:
42 | logging.getLogger('dummy')
43 | else:
44 | if os.path.isfile(log_file) and resume:
45 | file_mode = 'a'
46 | else:
47 | file_mode = 'w'
48 |
49 | root_logger = logging.getLogger()
50 | if root_logger.handlers:
51 | root_logger.removeHandler(root_logger.handlers[0])
52 | logging.basicConfig(level=logging.INFO,
53 | format="%(asctime)s - %(levelname)s - %(message)s",
54 | datefmt="%Y-%m-%d %H:%M:%S",
55 | filename=log_file,
56 | filemode=file_mode)
57 | console = logging.StreamHandler()
58 | console.setLevel(logging.INFO)
59 | formatter = logging.Formatter('%(message)s')
60 | console.setFormatter(formatter)
61 | logging.getLogger('').addHandler(console)
62 |
63 | def average(lst):
64 | return sum(lst) / len(lst)
65 |
66 | def rgb2yc(img):
67 |     rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0  # ITU-R BT.601 RGB -> Y (luma)
68 | rlt = rlt.round()
69 | return rlt
70 |
71 | def compute_psnr(x, y, scale):
72 | x = rgb2yc(tensor_to_image(x)).astype(np.float64)
73 | y = rgb2yc(tensor_to_image(y)).astype(np.float64)
74 |
75 | x = x[round(scale):-round(scale), round(scale):-round(scale)]
76 | y = y[round(scale):-round(scale), round(scale):-round(scale)]
77 |
78 | mse = np.mean((x - y) ** 2)
79 | return 20 * math.log10(255.0 / math.sqrt(mse))
80 |
81 | if __name__ == "__main__":
82 | print('None')
83 |
84 |
--------------------------------------------------------------------------------
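`compute_psnr` above follows the common super-resolution convention: convert to the Y channel (`rgb2yc` uses the ITU-R BT.601 weights), shave a `scale`-pixel border, and report PSNR = 20 * log10(255 / sqrt(MSE)) on 8-bit data. The arithmetic on a toy pair (assumes only NumPy):

```
import math
import numpy as np

x = np.full((16, 16), 120.0)
y = np.full((16, 16), 125.0)            # constant error of 5 gray levels

mse = np.mean((x - y) ** 2)             # 25.0
psnr = 20 * math.log10(255.0 / math.sqrt(mse))
print(round(psnr, 2))                   # 34.15
```
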
/super-resolution/utils/optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | def get_exp_scheduler_with_warmup(optimizer, rampup_steps=5, sustain_steps=5):
5 | def lr_lambda(step):
6 | if step < rampup_steps:
7 | return min(1., 1.8 ** ((step - rampup_steps)))
8 | elif step < rampup_steps + sustain_steps:
9 | return 1.
10 | else:
11 | return max(0.1, 0.85 ** (step - rampup_steps - sustain_steps))
12 |
13 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
14 |
15 | def get_cosine_scheduler_with_warmup(optimizer, rampup_steps=3, sustain_steps=2, period=5):
16 | def lr_lambda(step):
17 | if step < rampup_steps:
18 | return min(1., 1.5 ** ((step - rampup_steps)))
19 | elif step < rampup_steps + sustain_steps:
20 | return 1.
21 | else:
22 |             return max(0.1, (1 + math.cos(math.pi * (step - rampup_steps - sustain_steps) / period)) / 2)
23 |
24 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
25 |
26 | if __name__ == "__main__":
27 | import matplotlib.pyplot as plt
28 |
29 | model = torch.nn.Linear(2, 1)
30 | optimizer = torch.optim.SGD(model.parameters(), lr=2.5e-4)
31 | lr_scheduler = get_exp_scheduler_with_warmup(optimizer)
32 |
33 | lrs = []
34 | for i in range(25):
35 | lr_scheduler.step()
36 | lrs.append(optimizer.param_groups[0]["lr"])
37 |
38 | plt.plot(lrs)
39 | plt.show()
40 | print(lrs)
41 |
--------------------------------------------------------------------------------
/super-resolution/utils/recorderx.py:
--------------------------------------------------------------------------------
1 | import utils.misc as misc
2 | from tensorboardX import SummaryWriter
3 |
4 | class RecoderX():
5 | def __init__(self, log_dir):
6 | self.log_dir = log_dir
7 | self.writer = SummaryWriter(logdir=log_dir)
8 | self.log = ''
9 |
10 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
11 | self.writer.add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step, walltime=walltime)
12 |
13 | def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
14 | self.writer.add_scalars(main_tag=main_tag, tag_scalar_dict=tag_scalar_dict, global_step=global_step, walltime=walltime)
15 |
16 | def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
17 | self.writer.add_image(tag=tag, img_tensor=img_tensor, global_step=global_step, walltime=walltime, dataformats=dataformats)
18 |
19 |     def add_image_grid(self, tag, img_tensor, nrow=8, global_step=None, padding=0, pad_value=0, walltime=None, dataformats='CHW'):
20 | grid = misc.make_image_grid(img_tensor, nrow, padding=padding, pad_value=pad_value)
21 | self.writer.add_image(tag=tag, img_tensor=grid, global_step=global_step, walltime=walltime, dataformats=dataformats)
22 |
23 | def add_graph(self, graph_profile, walltime=None):
24 | self.writer.add_graph(graph_profile, walltime=walltime)
25 |
26 | def add_histogram(self, tag, values, global_step=None):
27 | self.writer.add_histogram(tag, values, global_step)
28 |
29 | def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
30 | self.writer.add_figure(tag, figure, global_step=global_step, close=close, walltime=walltime)
31 |
32 | def export_json(self, out_file):
33 | self.writer.export_scalars_to_json(out_file)
34 |
35 | def close(self):
36 | self.writer.close()
37 |
38 | if __name__ == "__main__":
39 | print('None')
40 |
--------------------------------------------------------------------------------
/super-resolution/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | # module-level state shared by the hook helpers below
3 |
4 | visualization = {}
5 | hooks = []
6 |
7 | def hook_act(m, i, o):
8 | min, _ = o.flatten(start_dim=2).min(dim=-1)
9 | max, _ = o.flatten(start_dim=2).max(dim=-1)
10 | visualization[m] = (o - min.view(o.size(0), -1, 1, 1)) / (max.view(o.size(0), -1, 1, 1) - min.view(o.size(0), -1, 1, 1) + 1e-8)
11 |
12 | def hook_forward_output(m, i, o):
13 | visualization[m] = o
14 |
15 | def hook_forward_norm(m, i, o):
16 | visualization[m] = o.flatten(start_dim=1).norm(p=2, dim=1).mean()
17 |
18 | def hook_backward_norm(m, i, o):
19 | visualization[m] = i[1].flatten(start_dim=1).norm(p=2, dim=1).mean() # i[1]: weights' grad
20 |
21 | def backward_norms_hook(model, instance=nn.Conv2d):
22 | for name, layer in model._modules.items():
23 | if isinstance(layer, nn.Sequential):
24 | backward_norms_hook(layer, instance)
25 | elif isinstance(layer, instance):
26 | hooks.append(layer.register_backward_hook(hook_backward_norm))
27 |
28 | def forward_output_hook(model, instance=(nn.ReLU, nn.LeakyReLU, nn.Linear)):
29 | for name, layer in model._modules.items():
30 | if isinstance(layer, (nn.Sequential)):
31 | forward_output_hook(layer, instance)
32 | elif isinstance(layer, instance):
33 |             hooks.append(layer.register_forward_hook(hook_forward_output))
34 |
35 | def forward_activations_hook(model, instance=(nn.ReLU, nn.LeakyReLU)):
36 | for name, layer in model._modules.items():
37 | if isinstance(layer, (nn.Sequential)):
38 | forward_activations_hook(layer, instance)
39 | elif isinstance(layer, instance):
40 |             hooks.append(layer.register_forward_hook(hook_act))
41 |
42 | def get_visualization():
43 | return visualization
44 |
45 | def remove_hooks():
46 |     for hook in hooks:
47 |         hook.remove()
48 |     hooks.clear()  # drop stale handles so hooks can be registered again
49 |
--------------------------------------------------------------------------------
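The helpers above populate the module-level `visualization` dict, keyed by layer, each time a hooked module runs forward. A typical usage sketch (assumes this file is importable; in the repository layout it would be `utils.visualize`):

```
import torch
import torch.nn as nn
import visualize  # assumption: this file on the path; utils.visualize in-repo

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())

# hook every ReLU/LeakyReLU, run a forward pass, then read the
# min-max-normalized activation maps per layer
visualize.forward_activations_hook(model)
model(torch.randn(1, 3, 16, 16))

for layer, act in visualize.get_visualization().items():
    print(type(layer).__name__, tuple(act.shape))

visualize.remove_hooks()  # detach when done
```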