├── .gitignore ├── README.md ├── denoising ├── data │ ├── __init__.py │ └── datasets.py ├── main.py ├── models │ ├── __init__.py │ ├── dncnn.py │ ├── modules │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── losses.py │ │ └── misc.py │ ├── vgg.py │ ├── xdense.py │ └── xdncnn.py ├── trainer.py └── utils │ ├── __init__.py │ ├── misc.py │ ├── optim.py │ └── recorderx.py ├── figures ├── activations.png ├── figure1.png └── tradeoff.png ├── requirements.txt └── super-resolution ├── data ├── __init__.py └── datasets.py ├── main.py ├── models ├── __init__.py ├── modules │ ├── __init__.py │ ├── activations.py │ ├── losses.py │ └── misc.py ├── srgan.py ├── vanilla.py ├── vgg.py ├── xdense.py ├── xsrgan.py └── xvanilla.py ├── scripts ├── matlab │ └── bicubic_subsample.m └── python │ ├── extract_images.py │ └── remove_images.py ├── trainer.py └── utils ├── __init__.py ├── core.py ├── misc.py ├── optim.py ├── recorderx.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | tmp/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *,cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # IPython Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # dotenv 82 | .env 83 | 84 | # virtualenv 85 | venv/ 86 | ENV/ 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | 91 | # Rope project settings 92 | .ropeproject 93 | 94 | # Pycharm 95 | .idea/ 96 | 97 | # Vscode 98 | .vscode/ 99 | 100 | # jobs 101 | *.job 102 | 103 | # tar 104 | *.tar.gz 105 | *.pb 106 | 107 | *.pth 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # xUnit 2 | Learning a Spatial Activation Function for Efficient Image Restoration. 3 | 4 |
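In brief: an xUnit replaces a per-pixel nonlinearity (e.g. ReLU) with a learnable spatial gate. Below is a minimal sketch of the idea, condensed from the full implementations in `denoising/models/modules/activations.py` further down this listing; the class name here is illustrative, not part of the repo.

```
import torch.nn as nn

class xUnitSketch(nn.Module):
    # a depthwise convolution learns a per-pixel attention map,
    # which multiplicatively gates the incoming feature maps
    def __init__(self, num_features=64, kernel_size=7):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Conv2d(num_features, num_features, kernel_size,
                      padding=kernel_size // 2, groups=num_features),  # depthwise
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.gate(x)
```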


7 | 8 | Please refer to our [paper](https://arxiv.org/abs/1711.06445) for more details. 9 | 10 | 11 | ## Citation 12 | If you use this code for your research, please cite our papers: 13 | 14 | ``` 15 | @inproceedings{kligvasser2018xunit, 16 | title={xunit: Learning a spatial activation function for efficient image restoration}, 17 | author={Kligvasser, Idan and Rott Shaham, Tamar and Michaeli, Tomer}, 18 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 19 | pages={2433--2442}, 20 | year={2018} 21 | } 22 | ``` 23 | 24 | ``` 25 | @article{kligvasser2018dense, 26 | title={Dense xUnit Networks}, 27 | author={Kligvasser, Idan and Michaeli, Tomer}, 28 | journal={arXiv preprint arXiv:1811.11051}, 29 | year={2018} 30 | } 31 | ``` 32 | 33 | ## Code 34 | 35 | ### Clone repository 36 | 37 | Clone this repository into any directory you like. 38 | 39 | ``` 40 | git clone https://github.com/kligvasser/xUnit 41 | cd xUnit 42 | ``` 43 | 44 | ### Install dependencies 45 | 46 | ``` 47 | python -m pip install -r requirements.txt 48 | ``` 49 | 50 | This code requires PyTorch 1.0+ and Python 3+. 51 | 52 | ### Super-resolution 53 | Pretrained models are available at: [LINK](https://www.dropbox.com/s/hq1n5yrl5hjsh34/sr_pretrained.zip?dl=0). 54 | 55 | 
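The checkpoints are plain PyTorch `state_dict` files. A minimal inference sketch along the lines of what `trainer.py` does, assuming the `g_xsrgan` factory accepts the same kwargs used by the training commands below; the checkpoint file name is an assumption — use whichever `.pt` file the zip contains:

```
import torch
import models  # run from within ./super-resolution

g = models.g_xsrgan(scale=4, gen_blocks=10, dis_blocks=5)  # must match the training config
g.load_state_dict(torch.load('g_xsrgan.pt', map_location='cpu'))
g.eval()

lr = torch.rand(1, 3, 64, 64)  # stand-in for a low-resolution image in [0, 1]
with torch.no_grad():
    sr = g(lr)  # expected x4 output: 1 x 3 x 256 x 256
```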


58 | 59 | #### Dataset preparation 60 | For the super-resolution task, the dataset should contain low- and high-resolution image pairs, in the following folder structure: 61 | 62 | ```txt 63 | train 64 | ├── img 65 | ├── img_x2 66 | ├── img_x4 67 | val 68 | ├── img 69 | ├── img_x2 70 | ├── img_x4 71 | ``` 72 | 73 | You may prepare your own data using the MATLAB script: 74 | 75 | ``` 76 | ./super-resolution/scripts/matlab/bicubic_subsample.m 77 | ``` 78 | 79 | Or download a prepared dataset based on the [BSD](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/) and [VOC](http://host.robots.ox.ac.uk/pascal/VOC/) datasets from [LINK](https://www.dropbox.com/s/o1nzpr9q7vup8b7/bsdvoc.zip?dl=0). 80 | 81 | #### Train xSRGAN x4 PSNR model 82 | ``` 83 | python3 main.py --root --g-model g_xsrgan --d-model d_xsrgan --model-config "{'scale':4, 'gen_blocks':10, 'dis_blocks':5}" --scale 4 --reconstruction-weight 1.0 --perceptual-weight 0 --adversarial-weight 0 --crop-size 40 84 | ``` 85 | 86 | #### Train xSRGAN x4 WGAN-GP model 87 | ``` 88 | python3 main.py --root --g-model g_xsrgan --d-model d_xsrgan_ad --model-config "{'scale':4, 'gen_blocks':10, 'dis_blocks':5}" --scale 4 --reconstruction-weight 1.0 --perceptual-weight 1.0 --adversarial-weight 0.005 --crop-size 64 --epochs 1200 --step-size 900 --gen-to-load --wgan --penalty-weight 10 89 | ``` 90 | 91 | 92 | #### Train xSRGAN x4 with SN-discriminator model 93 | ``` 94 | python3 main.py --root --g-model g_xsrgan --d-model d_xsrgan --model-config "{'scale':4, 'gen_blocks':10, 'dis_blocks':5, 'spectral':True}" --scale 4 --reconstruction-weight 1.0 --perceptual-weight 1.0 --adversarial-weight 0.01 --crop-size 40 --epochs 2000 --step-size 800 --gen-to-load --dis-betas 0 0.9 95 | ``` 96 | 97 | #### Eval xSRGAN x4 model 98 | ``` 99 | python3 main.py --root --g-model g_xsrgan --d-model d_xsrgan --model-config "{'scale':4, 'gen_blocks':10, 'dis_blocks':5}" --scale 4 --evaluation --gen-to-load 100 | ``` 101 | 102 | ### Gaussian denoising 103 | Pretrained models are available at: [LINK](https://www.dropbox.com/s/zychmfzx52y8tvq/denoising_pretrained.zip?dl=0). 
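The denoising dataloader never reads noisy images from disk; it corrupts each clean image on the fly, which is why the dataset described below contains only clean images. A self-contained sketch of the corruption model implemented in `denoising/data/datasets.py` (shown in full later in this listing):

```
import random
import torch

def add_gaussian_noise(target, noise_sigma=50., blind=False):
    # sigma is given on the 0-255 scale; image tensors live in [0, 1]
    if blind:  # blind training draws a fresh sigma in [0, noise_sigma] per image
        noise_sigma = random.randint(0, int(noise_sigma))
    return target + (noise_sigma / 255.) * torch.randn_like(target)

noisy = add_gaussian_noise(torch.rand(3, 64, 64), noise_sigma=75.)
```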
104 | 105 | #### Dataset preparation 106 | For the denoising task, the dataset should contain only clean images, in the following folder structure: 107 | 108 | ```txt 109 | train 110 | ├── img 111 | val 112 | ├── img 113 | ``` 114 | 115 | #### Train xDNCNN Grayscale 50 sigma PSNR model 116 | ``` 117 | python3 main.py --root --g-model g_xdncnn --d-model d_xdncnn --model-config "{'gen_blocks':10, 'dis_blocks':4, 'in_channels':1}" --reconstruction-weight 1.0 --perceptual-weight 0 --adversarial-weight 0 --crop-size 50 --gray-scale --noise-sigma 50 --epochs 500 --step-size 150 118 | ``` 119 | 120 | #### Train xDNCNN 75 sigma PSNR model 121 | ``` 122 | python3 main.py --root --g-model g_xdncnn --d-model d_xdncnn --model-config "{'gen_blocks':10, 'dis_blocks':4, 'in_channels':3}" --reconstruction-weight 1.0 --perceptual-weight 0 --adversarial-weight 0 --crop-size 64 --noise-sigma 75 --epochs 1000 --step-size 300 123 | ``` 124 | 125 | #### Train xDNCNN 75 sigma WGAN-GP model 126 | ``` 127 | python3 main.py --root --g-model g_xdncnn --d-model d_xdncnn --model-config "{'gen_blocks':10, 'dis_blocks':4, 'in_channels':3}" --reconstruction-weight 1.0 --perceptual-weight 1.0 --adversarial-weight 0.01 --crop-size 72 --noise-sigma 75 --epochs 1000 --step-size 300 --gen-to-load --wgan --penalty-weight 10 128 | ``` 129 | 130 | #### Train xDNCNN Grayscale blind PSNR model 131 | ``` 132 | python3 main.py --root --g-model g_xdncnn --d-model d_xdncnn --model-config "{'gen_blocks':10, 'dis_blocks':5, 'in_channels':1}" --reconstruction-weight 1.0 --perceptual-weight 0 --adversarial-weight 0 --crop-size 50 --gray-scale --noise-sigma 50 --blind --epochs 500 --step-size 150 133 | 134 | ``` -------------------------------------------------------------------------------- /denoising/data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .datasets import DatasetNoise 3 | from torch.utils.data import DataLoader 4 | 5 | def get_loaders(args): 6 | # datasets 7 | dataset_train = DatasetNoise(root=os.path.join(args.root, 'train'), noise_sigma=args.noise_sigma, training=True, crop_size=args.crop_size, blind_denoising=args.blind, gray_scale=args.gray_scale) 8 | dataset_val = DatasetNoise(root=os.path.join(args.root, 'val'), noise_sigma=args.noise_sigma, training=False, crop_size=args.crop_size, blind_denoising=args.blind, gray_scale=args.gray_scale, max_size=args.max_size) 9 | 10 | # loaders 11 | loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 12 | loader_eval = DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=1) 13 | loaders = {'train': loader_train, 'eval': loader_eval} 14 | 15 | return loaders 16 | -------------------------------------------------------------------------------- /denoising/data/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import os 4 | import random 5 | from torchvision import transforms 6 | from PIL import Image 7 | 8 | class DatasetNoise(torch.utils.data.dataset.Dataset): 9 | def __init__(self, root='', noise_sigma=50., training=True, crop_size=60, blind_denoising=False, gray_scale=False, max_size=None): 10 | self.root = root 11 | self.noise_sigma = noise_sigma 12 | self.training = training 13 | self.crop_size = crop_size 14 | self.blind_denoising = blind_denoising 15 | self.gray_scale = gray_scale 16 | self.max_size = max_size 17 | 18 | self._init() 19 | 20 | def _init(self): 21 | # 
data paths 22 | targets = glob.glob(os.path.join(self.root, 'img', '*.*'))[:self.max_size] 23 | self.paths = {'target': targets} 24 | 25 | # transforms 26 | t_list = [transforms.ToTensor()] 27 | self.image_transform = transforms.Compose(t_list) 28 | 29 | def _get_augment_params(self, size): 30 | random.seed(random.randint(0, 12345)) 31 | 32 | # position 33 | w_size, h_size = size 34 | x = random.randint(0, max(0, w_size - self.crop_size)) 35 | y = random.randint(0, max(0, h_size - self.crop_size)) 36 | 37 | # flip 38 | flip = random.random() > 0.5 39 | return {'crop_pos': (x, y), 'flip': flip} 40 | 41 | def _augment(self, image, aug_params): 42 | x, y = aug_params['crop_pos'] 43 | image = image.crop((x, y, x + self.crop_size, y + self.crop_size)) 44 | if aug_params['flip']: 45 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 46 | return image 47 | 48 | def __getitem__(self, index): 49 | # target image 50 | if self.gray_scale: 51 | target = Image.open(self.paths['target'][index]).convert('L') 52 | else: 53 | target = Image.open(self.paths['target'][index]).convert('RGB') 54 | 55 | # transform 56 | if self.training: 57 | aug_params = self._get_augment_params(target.size) 58 | target = self._augment(target, aug_params) 59 | target = self.image_transform(target) 60 | 61 | # add noise 62 | if self.blind_denoising: 63 | noise_sigma = random.randint(0, self.noise_sigma) 64 | else: 65 | noise_sigma = self.noise_sigma 66 | input = target + (noise_sigma / 255.) * torch.randn_like(target) 67 | 68 | return {'input': input, 'target': target, 'path': self.paths['target'][index]} 69 | 70 | def __len__(self): 71 | return len(self.paths['target']) -------------------------------------------------------------------------------- /denoising/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import logging 4 | import signal 5 | import sys 6 | import torch.backends.cudnn as cudnn 7 | from trainer import Trainer 8 | from datetime import datetime 9 | from os import path 10 | from utils import misc 11 | 12 | # torch.autograd.set_detect_anomaly(True) 13 | 14 | def get_arguments(): 15 | parser = argparse.ArgumentParser(description='denoising') 16 | parser.add_argument('--device', default='cuda', help='device assignment ("cpu" or "cuda")') 17 | parser.add_argument('--device-ids', default=[0], type=int, nargs='+', help='device ids assignment (e.g. 0 1 2 3)') 18 | parser.add_argument('--d-model', default='d_dncnn', help='discriminator architecture (default: dncnn)') 19 | parser.add_argument('--g-model', default='g_dncnn', help='generator architecture (default: dncnn)') 20 | parser.add_argument('--model-config', default='', help='additional architecture configuration') 21 | parser.add_argument('--dis-to-load', default='', help='resume training from file (default: None)') 22 | parser.add_argument('--gen-to-load', default='', help='resume training from file (default: None)') 23 | parser.add_argument('--root', default='', help='root dataset folder') 24 | parser.add_argument('--noise-sigma', default=50, type=int, help='noise-sigma (default: 50)') 25 | parser.add_argument('--crop-size', default=64, type=int, help='training crop size (default: 64)') 26 | parser.add_argument('--gray-scale', default=False, action='store_true', help='gray-scale images (default: false)') 27 | parser.add_argument('--blind', default=False, action='store_true', help='blind-denoising images (default: false)') 28 | parser.add_argument('--max-size', 
default=None, type=int, help='validation set max-size (default: None)') 29 | parser.add_argument('--num-workers', default=2, type=int, help='number of workers (default: 2)') 30 | parser.add_argument('--batch-size', default=16, type=int, help='batch-size (default: 16)') 31 | parser.add_argument('--epochs', default=1000, type=int, help='epochs (default: 1000)') 32 | parser.add_argument('--lr', default=2e-4, type=float, help='lr (default: 2e-4)') 33 | parser.add_argument('--gen-betas', default=[0.9, 0.99], nargs=2, type=float, help='generator adam betas (default: 0.9 0.99)') 34 | parser.add_argument('--dis-betas', default=[0.9, 0.99], nargs=2, type=float, help='discriminator adam betas (default: 0.9 0.99)') 35 | parser.add_argument('--num-critic', default=1, type=int, help='critic iterations (default: 1)') 36 | parser.add_argument('--wgan', default=False, action='store_true', help='critic wgan loss (default: false)') 37 | parser.add_argument('--relativistic', default=False, action='store_true', help='relativistic wgan loss (default: false)') 38 | parser.add_argument('--step-size', default=300, type=int, help='scheduler step size (default: 300)') 39 | parser.add_argument('--gamma', default=0.5, type=float, help='scheduler gamma (default: 0.5)') 40 | parser.add_argument('--penalty-weight', default=0, type=float, help='gradient penalty weight (default: 0)') 41 | parser.add_argument('--range-weight', default=0, type=float, help='range-weight (default: 0)') 42 | parser.add_argument('--reconstruction-weight', default=1.0, type=float, help='reconstruction-weight (default: 1.0)') 43 | parser.add_argument('--perceptual-weight', default=0, type=float, help='perceptual-weight (default: 0)') 44 | parser.add_argument('--adversarial-weight', default=0.01, type=float, help='adversarial-weight (default: 0.01)') 45 | parser.add_argument('--style-weight', default=0, type=float, help='style-weight (default: 0)') 46 | parser.add_argument('--print-every', default=20, type=int, help='print-every (default: 20)') 47 | parser.add_argument('--eval-every', default=50, type=int, help='eval-every (default: 50)') 48 | parser.add_argument('--results-dir', metavar='RESULTS_DIR', default='./results', help='results dir') 49 | parser.add_argument('--save', metavar='SAVE', default='', help='saved folder') 50 | parser.add_argument('--evaluation', default=False, action='store_true', help='evaluate a model (default: false)') 51 | parser.add_argument('--use-tb', default=False, action='store_true', help='use tensorboardx (default: false)') 52 | args = parser.parse_args() 53 | 54 | time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 55 | if args.save == '': 56 | args.save = time_stamp 57 | args.save_path = path.join(args.results_dir, args.save) 58 | return args 59 | 60 | def main(): 61 | # arguments 62 | args = get_arguments() 63 | 64 | # cuda 65 | if 'cuda' in args.device and torch.cuda.is_available(): 66 | torch.cuda.set_device(args.device_ids[0]) 67 | cudnn.benchmark = True 68 | else: 69 | args.device_ids = None 70 | 71 | # set logs 72 | misc.mkdir(args.save_path) 73 | misc.mkdir(path.join(args.save_path, 'images')) 74 | misc.setup_logging(path.join(args.save_path, 'log.txt')) 75 | 76 | # print logs 77 | logging.info(args) 78 | 79 | # trainer 80 | trainer = Trainer(args) 81 | 82 | if args.evaluation: 83 | trainer.eval() 84 | else: 85 | trainer.train() 86 | 87 | if __name__ == '__main__': 88 | # enables a ctrl-c without triggering errors 89 | signal.signal(signal.SIGINT, lambda x, y: sys.exit(0)) 90 | main() 91 | 
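A note on the `--model-config` flag defined above: `trainer.py` (later in this listing) parses the string with `ast.literal_eval` and splats the resulting dict into the chosen model factory, so every key shown in the README commands becomes a constructor keyword argument. A condensed sketch of that logic:

```python
from ast import literal_eval
import models  # run from within ./denoising

config_str = "{'gen_blocks':10, 'dis_blocks':4, 'in_channels':1}"  # as passed on the CLI
model_config = dict({}, **literal_eval(config_str))  # safe parsing, no eval()
g_model = models.__dict__['g_xdncnn'](**model_config)  # --g-model selects the factory
```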
-------------------------------------------------------------------------------- /denoising/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .dncnn import * 2 | from .xdncnn import * 3 | from .xdense import * -------------------------------------------------------------------------------- /denoising/models/dncnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm 5 | 6 | __all__ = ['g_dncnn', 'd_dncnn'] 7 | 8 | def initialize_weights(net, scale=1.): 9 | if not isinstance(net, list): 10 | net = [net] 11 | for layer in net: 12 | for m in layer.modules(): 13 | if isinstance(m, nn.Conv2d): 14 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 15 | m.weight.data *= scale # for residual block 16 | if m.bias is not None: 17 | m.bias.data.zero_() 18 | elif isinstance(m, nn.Linear): 19 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 20 | m.weight.data *= scale 21 | if m.bias is not None: 22 | m.bias.data.zero_() 23 | elif isinstance(m, nn.BatchNorm2d): 24 | nn.init.constant_(m.weight, 1) 25 | nn.init.constant_(m.bias.data, 0.0) 26 | 27 | class GenBlock(nn.Module): 28 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, bias=True): 29 | super(GenBlock, self).__init__() 30 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias) 31 | self.bn = nn.BatchNorm2d(num_features=out_channels) 32 | self.relu = nn.ReLU(inplace=True) 33 | 34 | initialize_weights([self.conv, self.bn], 0.02) 35 | 36 | def forward(self, x): 37 | x = self.relu(self.bn(self.conv(x))) 38 | return x 39 | 40 | class DisBlock(nn.Module): 41 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False): 42 | super(DisBlock, self).__init__() 43 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 44 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias) 45 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias) 46 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True) 47 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True) 48 | 49 | initialize_weights([self.conv1, self.conv2], 0.1) 50 | 51 | if normalization: 52 | self.conv1 = SpectralNorm(self.conv1) 53 | self.conv2 = SpectralNorm(self.conv2) 54 | 55 | def forward(self, x): 56 | x = self.lrelu(self.bn1(self.conv1(x))) 57 | x = self.lrelu(self.bn2(self.conv2(x))) 58 | return x 59 | 60 | class Generator(nn.Module): 61 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks): 62 | super(Generator, self).__init__() 63 | 64 | # image to features 65 | self.image_to_features = GenBlock(in_channels=in_channels, out_channels=num_features) 66 | 67 | # features 68 | blocks = [] 69 | for _ in range(gen_blocks): 70 | blocks.append(GenBlock(in_channels=num_features, out_channels=num_features, bias=False)) 71 | self.features = nn.Sequential(*blocks) 72 | 73 | # features to image 74 | self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2) 75 | initialize_weights([self.features_to_image], 0.02) 76 | 77 | def forward(self, x): 78 | r = x 79 | x = self.image_to_features(x) 80 | x = self.features(x) 81 | x = 
self.features_to_image(x) 82 | x += r 83 | return x 84 | 85 | class Discriminator(nn.Module): 86 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks): 87 | super(Discriminator, self).__init__() 88 | 89 | # image to features 90 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False) 91 | 92 | # features 93 | blocks = [] 94 | for i in range(0, dis_blocks - 1): 95 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False)) 96 | self.features = nn.Sequential(*blocks) 97 | 98 | # classifier 99 | self.classifier = nn.Conv2d(in_channels=num_features * min(pow(2, dis_blocks - 1), 8), out_channels=1, kernel_size=4, padding=0) 100 | 101 | def forward(self, x): 102 | x = self.image_to_features(x) 103 | x = self.features(x) 104 | x = self.classifier(x) 105 | x = x.flatten(start_dim=1).mean(dim=-1) 106 | return x 107 | 108 | class SNDiscriminator(nn.Module): 109 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks): 110 | super(SNDiscriminator, self).__init__() 111 | 112 | # image to features 113 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True) 114 | 115 | # features 116 | blocks = [] 117 | for i in range(0, dis_blocks - 1): 118 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True)) 119 | self.features = nn.Sequential(*blocks) 120 | 121 | # classifier 122 | self.classifier = SpectralNorm(nn.Conv2d(in_channels=num_features * min(pow(2, dis_blocks - 1), 8), out_channels=1, kernel_size=4, padding=0)) 123 | 124 | def forward(self, x): 125 | x = self.image_to_features(x) 126 | x = self.features(x) 127 | x = self.classifier(x) 128 | x = x.flatten(start_dim=1).mean(dim=-1) 129 | return x 130 | 131 | def g_dncnn(**config): 132 | config.setdefault('in_channels', 3) 133 | config.setdefault('num_features', 64) 134 | config.setdefault('gen_blocks', 8) 135 | config.setdefault('dis_blocks', 5) 136 | 137 | return Generator(**config) 138 | 139 | def d_dncnn(**config): 140 | config.setdefault('in_channels', 3) 141 | config.setdefault('num_features', 64) 142 | config.setdefault('gen_blocks', 8) 143 | config.setdefault('dis_blocks', 5) 144 | 145 | return Discriminator(**config) -------------------------------------------------------------------------------- /denoising/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/denoising/models/modules/__init__.py -------------------------------------------------------------------------------- /denoising/models/modules/activations.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class xUnit(nn.Module): 4 | def __init__(self, num_features=64, kernel_size=7, batch_norm=False): 5 | super(xUnit, self).__init__() 6 | # xUnit 7 | self.features = nn.Sequential( 8 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(), 9 | nn.ReLU(), 10 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features), 11 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def 
forward(self, x): 16 | a = self.features(x) 17 | r = x * a 18 | return r 19 | 20 | class xUnitS(nn.Module): 21 | def __init__(self, num_features=64, kernel_size=7, batch_norm=False): 22 | super(xUnitS, self).__init__() 23 | # slim xUnit 24 | self.features = nn.Sequential( 25 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features), 26 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(), 27 | nn.Sigmoid() 28 | ) 29 | 30 | def forward(self, x): 31 | a = self.features(x) 32 | r = x * a 33 | return r 34 | 35 | class xUnitD(nn.Module): 36 | def __init__(self, num_features=64, kernel_size=7, batch_norm=False): 37 | super(xUnitD, self).__init__() 38 | # dense xUnit 39 | self.features = nn.Sequential( 40 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=1, padding=0), 41 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(), 42 | nn.ReLU(), 43 | nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features), 44 | nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(), 45 | nn.Sigmoid() 46 | ) 47 | 48 | def forward(self, x): 49 | a = self.features(x) 50 | r = x * a 51 | return r 52 | 53 | class Identity(nn.Module): 54 | def __init__(self,): 55 | super(Identity, self).__init__() 56 | 57 | def forward(self, x): 58 | return x -------------------------------------------------------------------------------- /denoising/models/modules/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .misc import shave_edge 5 | from ..vgg import MultiVGGFeaturesExtractor 6 | 7 | class RangeLoss(nn.Module): 8 | def __init__(self, min_value=0., max_value=1., invalidity_margins=None): 9 | super(RangeLoss, self).__init__() 10 | self.min_value = min_value 11 | self.max_value = max_value 12 | self.invalidity_margins = invalidity_margins 13 | 14 | def forward(self, inputs): 15 | if self.invalidity_margins: 16 | inputs = shave_edge(inputs, self.invalidity_margins, self.invalidity_margins) 17 | loss = (F.relu(self.min_value - inputs) + F.relu(inputs - self.max_value)).mean() 18 | return loss 19 | 20 | class PerceptualLoss(nn.Module): 21 | def __init__(self, features_to_compute, criterion=torch.nn.L1Loss(), shave_edge=None): 22 | super(PerceptualLoss, self).__init__() 23 | self.criterion = criterion 24 | self.features_extractor = MultiVGGFeaturesExtractor(target_features=features_to_compute, shave_edge=shave_edge).eval() 25 | 26 | def forward(self, inputs, targets): 27 | inputs_fea = self.features_extractor(inputs) 28 | with torch.no_grad(): 29 | targets_fea = self.features_extractor(targets) 30 | 31 | loss = 0 32 | for key in inputs_fea.keys(): 33 | loss += self.criterion(inputs_fea[key], targets_fea[key].detach()) 34 | 35 | return loss 36 | 37 | class StyleLoss(nn.Module): 38 | def __init__(self, features_to_compute, criterion=torch.nn.L1Loss(), shave_edge=None): 39 | super(StyleLoss, self).__init__() 40 | self.criterion = criterion 41 | self.features_extractor = MultiVGGFeaturesExtractor(target_features=features_to_compute, use_input_norm=True, shave_edge=shave_edge).eval() 42 | 43 | def forward(self, inputs, targets): 44 | inputs_fea = self.features_extractor(inputs) 45 | with torch.no_grad(): 46 | targets_fea = self.features_extractor(targets) 47 | 48 | loss = 0 
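# accumulate, per selected VGG layer, the criterion (L1 by default) between Gram matrices of generated and target features; _gram_matrix below computes features @ features^T per sample, normalized by (channels * height * width)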
49 | for key in inputs_fea.keys(): 50 | inputs_gram = self._gram_matrix(inputs_fea[key]) 51 | with torch.no_grad(): 52 | targets_gram = self._gram_matrix(targets_fea[key]).detach() 53 | 54 | loss += self.criterion(inputs_gram, targets_gram) 55 | 56 | return loss 57 | 58 | def _gram_matrix(self, x): 59 | a, b, c, d = x.size() 60 | features = x.view(a, b, c * d) 61 | gram = features.bmm(features.transpose(1, 2)) 62 | return gram.div(b * c * d) -------------------------------------------------------------------------------- /denoising/models/modules/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class UpsampleX2(nn.Module): 6 | def __init__(self, in_channels, out_channels, kernel_size=3): 7 | super(UpsampleX2, self).__init__() 8 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=(out_channels * 4), kernel_size=kernel_size, padding=(kernel_size // 2)) 9 | self.shuffler = nn.PixelShuffle(2) 10 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 11 | 12 | def forward(self, x): 13 | return self.lrelu(self.shuffler(self.conv(x))) 14 | 15 | def center_crop(x, height, width): 16 | crop_h = torch.FloatTensor([x.size()[2]]).sub(height).div(-2) 17 | crop_w = torch.FloatTensor([x.size()[3]]).sub(width).div(-2) 18 | 19 | return F.pad(x, [ 20 | crop_w.ceil().int()[0], crop_w.floor().int()[0], 21 | crop_h.ceil().int()[0], crop_h.floor().int()[0], 22 | ]) 23 | 24 | def shave_edge(x, shave_h, shave_w): 25 | return F.pad(x, [-shave_w, -shave_w, -shave_h, -shave_h]) 26 | 27 | def shave_modulo(x, factor): 28 | shave_w = x.size(-1) % factor 29 | shave_h = x.size(-2) % factor 30 | return F.pad(x, [0, -shave_w, 0, -shave_h]) 31 | 32 | if __name__ == "__main__": 33 | x = torch.randn(1, 2, 4, 6) 34 | y = shave_edge(x, 1, 2) 35 | print(x) 36 | print(y) -------------------------------------------------------------------------------- /denoising/models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | from models.modules.misc import shave_edge 5 | from collections import OrderedDict 6 | 7 | names = {'vgg19': ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 8 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 9 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 10 | 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 11 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 12 | 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 13 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 14 | 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'], 15 | 16 | 'vgg19_bn': ['conv1_1', 'bn1_1', 'relu1_1', 'conv1_2', 'bn1_2', 'relu1_2', 'pool1', 17 | 'conv2_1', 'bn2_1', 'relu2_1', 'conv2_2', 'bn2_2', 'relu2_2', 'pool2', 18 | 'conv3_1', 'bn3_1', 'relu3_1', 'conv3_2', 'bn3_2', 'relu3_2', 19 | 'conv3_3', 'bn3_3', 'relu3_3', 'conv3_4', 'bn3_4', 'relu3_4', 'pool3', 20 | 'conv4_1', 'bn4_1', 'relu4_1', 'conv4_2', 'bn4_2', 'relu4_2', 21 | 'conv4_3', 'bn4_3', 'relu4_3', 'conv4_4', 'bn4_4', 'relu4_4', 'pool4', 22 | 'conv5_1', 'bn5_1', 'relu5_1', 'conv5_2', 'bn5_2', 'relu5_2', 23 | 'conv5_3', 'bn5_3', 'relu5_3', 'conv5_4', 'bn5_4', 'relu5_4', 'pool5'] 24 | } 25 | 26 | 27 | class VGGFeaturesExtractor(nn.Module): 28 | def __init__(self, feature_layer='conv5_4', use_bn=False, use_input_norm=True, requires_grad=False): 29 | super(VGGFeaturesExtractor, self).__init__() 30 | self.use_input_norm = 
use_input_norm 31 | 32 | if use_bn: 33 | model = torchvision.models.vgg19_bn(pretrained=True) 34 | else: 35 | model = torchvision.models.vgg19(pretrained=True) 36 | 37 | if self.use_input_norm: 38 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 39 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 40 | self.register_buffer('mean', mean) 41 | self.register_buffer('std', std) 42 | 43 | layer_index = names['vgg19_bn'].index(feature_layer) if use_bn else names['vgg19'].index(feature_layer) 44 | self.features = nn.Sequential(*list(model.features.children())[:(layer_index + 1)]) 45 | 46 | if not requires_grad: 47 | for k, v in self.features.named_parameters(): 48 | v.requires_grad = False 49 | self.features.eval() 50 | 51 | def forward(self, x): 52 | # Assume input range is [0, 1] 53 | if self.use_input_norm: 54 | x = (x - self.mean) / self.std 55 | output = self.features(x) 56 | return output 57 | 58 | class MultiVGGFeaturesExtractor(nn.Module): 59 | def __init__(self, target_features=('relu1_1', 'relu2_1', 'relu3_1'), use_bn=False, use_input_norm=True, requires_grad=False, shave_edge=None): 60 | super(MultiVGGFeaturesExtractor, self).__init__() 61 | self.use_input_norm = use_input_norm 62 | self.target_features = target_features 63 | self.shave_edge = shave_edge 64 | 65 | if use_bn: 66 | model = torchvision.models.vgg19_bn(pretrained=True) 67 | names_key = 'vgg19_bn' 68 | else: 69 | model = torchvision.models.vgg19(pretrained=True) 70 | names_key = 'vgg19' 71 | 72 | if self.use_input_norm: 73 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 74 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 75 | self.register_buffer('mean', mean) 76 | self.register_buffer('std', std) 77 | 78 | self.target_indexes = [names[names_key].index(k) for k in self.target_features] 79 | self.features = nn.Sequential(*list(model.features.children())[:(max(self.target_indexes) + 1)]) 80 | 81 | if not requires_grad: 82 | for k, v in self.features.named_parameters(): 83 | v.requires_grad = False 84 | self.features.eval() 85 | 86 | def forward(self, x): 87 | if self.shave_edge: 88 | x = shave_edge(x, self.shave_edge, self.shave_edge) 89 | 90 | # assume input range is [0, 1] 91 | if self.use_input_norm: 92 | x = (x - self.mean) / self.std 93 | 94 | output = OrderedDict() 95 | for key, layer in self.features._modules.items(): 96 | x = layer(x) 97 | if int(key) in self.target_indexes: 98 | output.update({self.target_features[self.target_indexes.index(int(key))]: x}) 99 | return output -------------------------------------------------------------------------------- /denoising/models/xdense.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm 5 | from .modules.activations import xUnitD 6 | from .modules.misc import center_crop 7 | 8 | __all__ = ['g_xdense', 'd_xdense'] 9 | 10 | def initialize_weights(net, scale=1.): 11 | if not isinstance(net, list): 12 | net = [net] 13 | for layer in net: 14 | for m in layer.modules(): 15 | if isinstance(m, nn.Conv2d): 16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 17 | m.weight.data *= scale # for residual block 18 | if m.bias is not None: 19 | m.bias.data.zero_() 20 | elif isinstance(m, nn.Linear): 21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 22 | m.weight.data *= scale 23 | if m.bias is not None: 24 | m.bias.data.zero_() 25 | elif 
isinstance(m, nn.BatchNorm2d): 26 | nn.init.constant_(m.weight, 1) 27 | nn.init.constant_(m.bias.data, 0.0) 28 | 29 | class xModule(nn.Module): 30 | def __init__(self, in_channels, out_channels): 31 | super(xModule, self).__init__() 32 | # features 33 | self.features = nn.Sequential( 34 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1), 35 | xUnitD(num_features=out_channels, batch_norm=True), 36 | nn.BatchNorm2d(num_features=out_channels) 37 | ) 38 | 39 | def forward(self, x): 40 | x = self.features(x) 41 | return x 42 | 43 | class xDenseLayer(nn.Module): 44 | def __init__(self, in_channels, growth_rate, bn_size): 45 | super(xDenseLayer, self).__init__() 46 | # features 47 | self.features = nn.Sequential( 48 | nn.BatchNorm2d(num_features=in_channels), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(in_channels=in_channels,out_channels=bn_size * growth_rate, kernel_size=1, stride=1, bias=False), 51 | nn.BatchNorm2d(num_features=bn_size * growth_rate), 52 | nn.ReLU(inplace=True), 53 | xModule(in_channels=bn_size * growth_rate, out_channels=growth_rate) 54 | ) 55 | 56 | def forward(self, x): 57 | f = self.features(x) 58 | return torch.cat([x, f], dim=1) 59 | 60 | class xDenseBlock(nn.Module): 61 | def __init__(self, in_channels, num_layers, growth_rate, bn_size): 62 | super(xDenseBlock, self).__init__() 63 | # features 64 | blocks = [] 65 | for i in range(num_layers): 66 | blocks.append(xDenseLayer(in_channels=in_channels + growth_rate*i, growth_rate=growth_rate, bn_size=bn_size)) 67 | self.features = nn.Sequential(*blocks) 68 | 69 | def forward(self, x): 70 | x = self.features(x) 71 | return x 72 | 73 | class Transition(nn.Module): 74 | def __init__(self, in_channels, out_channels): 75 | super(Transition, self).__init__() 76 | # features 77 | self.features = nn.Sequential( 78 | nn.BatchNorm2d(num_features=in_channels), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(in_channels=in_channels,out_channels=out_channels, kernel_size=1, stride=1, bias=False), 81 | ) 82 | 83 | def forward(self, x): 84 | x = self.features(x) 85 | return x 86 | 87 | class DisBlock(nn.Module): 88 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False): 89 | super(DisBlock, self).__init__() 90 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 91 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias) 92 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias) 93 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True) 94 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True) 95 | 96 | initialize_weights([self.conv1, self.conv2], 0.1) 97 | 98 | if normalization: 99 | self.conv1 = SpectralNorm(self.conv1) 100 | self.conv2 = SpectralNorm(self.conv2) 101 | 102 | def forward(self, x): 103 | x = self.lrelu(self.bn1(self.conv1(x))) 104 | x = self.lrelu(self.bn2(self.conv2(x))) 105 | return x 106 | 107 | class Generator(nn.Module): 108 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size): 109 | super(Generator, self).__init__() 110 | 111 | # image to features 112 | self.image_to_features = xModule(in_channels=in_channels, out_channels=num_features) 113 | 114 | # features 115 | blocks = [] 116 | self.num_features = num_features 117 | for i, num_layers in enumerate(gen_blocks): 118 | blocks.append(xDenseBlock(in_channels=self.num_features, num_layers=num_layers, growth_rate=growth_rate, 
bn_size=bn_size)) 119 | self.num_features += num_layers * growth_rate 120 | 121 | if i != len(gen_blocks) - 1: 122 | blocks.append(Transition(in_channels=self.num_features, out_channels=self.num_features // 2)) 123 | self.num_features = self.num_features // 2 124 | 125 | self.features = nn.Sequential(*blocks) 126 | 127 | # features to image 128 | self.features_to_image = nn.Conv2d(in_channels=self.num_features, out_channels=in_channels, kernel_size=5, padding=2) 129 | 130 | def forward(self, x): 131 | r = x 132 | x = self.image_to_features(x) 133 | x = self.features(x) 134 | x = self.features_to_image(x) 135 | x += r 136 | return x 137 | 138 | class Discriminator(nn.Module): 139 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size): 140 | super(Discriminator, self).__init__() 141 | self.crop_size = 4 * pow(2, dis_blocks) 142 | 143 | # image to features 144 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False) 145 | 146 | # features 147 | blocks = [] 148 | for i in range(0, dis_blocks - 1): 149 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False)) 150 | self.features = nn.Sequential(*blocks) 151 | 152 | # classifier 153 | self.classifier = nn.Sequential( 154 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100), 155 | nn.LeakyReLU(negative_slope=0.1), 156 | nn.Linear(100, 1) 157 | ) 158 | 159 | def forward(self, x): 160 | x = center_crop(x, self.crop_size, self.crop_size) 161 | x = self.image_to_features(x) 162 | x = self.features(x) 163 | x = x.flatten(start_dim=1) 164 | x = self.classifier(x) 165 | return x 166 | 167 | class SNDiscriminator(nn.Module): 168 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size): 169 | super(SNDiscriminator, self).__init__() 170 | self.crop_size = 4 * pow(2, dis_blocks) 171 | 172 | # image to features 173 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True) 174 | 175 | # features 176 | blocks = [] 177 | for i in range(0, dis_blocks - 1): 178 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True)) 179 | self.features = nn.Sequential(*blocks) 180 | 181 | # classifier 182 | self.classifier = nn.Sequential( 183 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)), 184 | nn.LeakyReLU(negative_slope=0.1), 185 | SpectralNorm(nn.Linear(100, 1)) 186 | ) 187 | 188 | def forward(self, x): 189 | x = center_crop(x, self.crop_size, self.crop_size) 190 | x = self.image_to_features(x) 191 | x = self.features(x) 192 | x = x.flatten(start_dim=1) 193 | x = self.classifier(x) 194 | return x 195 | 196 | def g_xdense(**config): 197 | config.setdefault('in_channels', 3) 198 | config.setdefault('num_features', 64) 199 | config.setdefault('gen_blocks', [4, 6, 8]) 200 | config.setdefault('dis_blocks', 4) 201 | config.setdefault('growth_rate', 16) 202 | config.setdefault('bn_size', 2) 203 | 204 | _ = config.pop('spectral', False) 205 | 206 | return Generator(**config) 207 | 208 | def d_xdense(**config): 209 | config.setdefault('in_channels', 3) 210 | config.setdefault('num_features', 64) 211 | config.setdefault('gen_blocks', [4, 6, 8]) 212 | config.setdefault('dis_blocks', 4) 213 | config.setdefault('growth_rate', 16) 214 | 
config.setdefault('bn_size', 2) 215 | 216 | sn = config.pop('spectral', False) 217 | 218 | if sn: 219 | return SNDiscriminator(**config) 220 | else: 221 | return Discriminator(**config) 222 | -------------------------------------------------------------------------------- /denoising/models/xdncnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm 5 | from .modules.activations import xUnitS 6 | from .modules.misc import center_crop 7 | 8 | __all__ = ['g_xdncnn', 'd_xdncnn'] 9 | 10 | def initialize_weights(net, scale=1.): 11 | if not isinstance(net, list): 12 | net = [net] 13 | for layer in net: 14 | for m in layer.modules(): 15 | if isinstance(m, nn.Conv2d): 16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 17 | m.weight.data *= scale # for residual block 18 | if m.bias is not None: 19 | m.bias.data.zero_() 20 | elif isinstance(m, nn.Linear): 21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 22 | m.weight.data *= scale 23 | if m.bias is not None: 24 | m.bias.data.zero_() 25 | elif isinstance(m, nn.BatchNorm2d): 26 | nn.init.constant_(m.weight, 1) 27 | nn.init.constant_(m.bias.data, 0.0) 28 | 29 | class GenBlock(nn.Module): 30 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, bias=True): 31 | super(GenBlock, self).__init__() 32 | self.bn = nn.BatchNorm2d(num_features=in_channels) 33 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias) 34 | self.xunit = xUnitS(num_features=out_channels, batch_norm=True) 35 | 36 | initialize_weights([self.conv, self.bn], 0.02) 37 | 38 | def forward(self, x): 39 | x = self.xunit(self.conv(self.bn(x))) 40 | return x 41 | 42 | class DisBlock(nn.Module): 43 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False): 44 | super(DisBlock, self).__init__() 45 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 46 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias) 47 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias) 48 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True) 49 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True) 50 | 51 | initialize_weights([self.conv1, self.conv2], 0.1) 52 | 53 | if normalization: 54 | self.conv1 = SpectralNorm(self.conv1) 55 | self.conv2 = SpectralNorm(self.conv2) 56 | 57 | def forward(self, x): 58 | x = self.lrelu(self.bn1(self.conv1(x))) 59 | x = self.lrelu(self.bn2(self.conv2(x))) 60 | return x 61 | 62 | class Generator(nn.Module): 63 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks): 64 | super(Generator, self).__init__() 65 | 66 | # image to features 67 | self.image_to_features = GenBlock(in_channels=in_channels, out_channels=num_features) 68 | 69 | # features 70 | blocks = [] 71 | for _ in range(gen_blocks): 72 | blocks.append(GenBlock(in_channels=num_features, out_channels=num_features)) 73 | self.features = nn.Sequential(*blocks) 74 | 75 | # features to image 76 | self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2) 77 | initialize_weights([self.features_to_image], 0.02) 78 | 79 | def forward(self, x): 80 | r = x 81 | x = self.image_to_features(x) 82 | x = 
self.features(x) 83 | x = self.features_to_image(x) 84 | x += r 85 | return x 86 | 87 | class Discriminator(nn.Module): 88 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks): 89 | super(Discriminator, self).__init__() 90 | self.crop_size = 4 * pow(2, dis_blocks) 91 | 92 | # image to features 93 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False) 94 | 95 | # features 96 | blocks = [] 97 | for i in range(0, dis_blocks - 1): 98 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False)) 99 | self.features = nn.Sequential(*blocks) 100 | 101 | # classifier 102 | self.classifier = nn.Sequential( 103 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100), 104 | nn.LeakyReLU(negative_slope=0.1), 105 | nn.Linear(100, 1) 106 | ) 107 | 108 | def forward(self, x): 109 | x = center_crop(x, self.crop_size, self.crop_size) 110 | x = self.image_to_features(x) 111 | x = self.features(x) 112 | x = x.flatten(start_dim=1) 113 | x = self.classifier(x) 114 | return x 115 | 116 | class SNDiscriminator(nn.Module): 117 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks): 118 | super(SNDiscriminator, self).__init__() 119 | self.crop_size = 4 * pow(2, dis_blocks) 120 | 121 | # image to features 122 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True) 123 | 124 | # features 125 | blocks = [] 126 | for i in range(0, dis_blocks - 1): 127 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True)) 128 | self.features = nn.Sequential(*blocks) 129 | 130 | # classifier 131 | self.classifier = nn.Sequential( 132 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)), 133 | nn.LeakyReLU(negative_slope=0.1), 134 | SpectralNorm(nn.Linear(100, 1)) 135 | ) 136 | 137 | def forward(self, x): 138 | x = center_crop(x, self.crop_size, self.crop_size) 139 | x = self.image_to_features(x) 140 | x = self.features(x) 141 | x = x.flatten(start_dim=1) 142 | x = self.classifier(x) 143 | return x 144 | 145 | def g_xdncnn(**config): 146 | config.setdefault('in_channels', 3) 147 | config.setdefault('num_features', 64) 148 | config.setdefault('gen_blocks', 8) 149 | config.setdefault('dis_blocks', 4) 150 | 151 | _ = config.pop('spectral', False) 152 | 153 | return Generator(**config) 154 | 155 | def d_xdncnn(**config): 156 | config.setdefault('in_channels', 3) 157 | config.setdefault('num_features', 64) 158 | config.setdefault('gen_blocks', 8) 159 | config.setdefault('dis_blocks', 4) 160 | 161 | sn = config.pop('spectral', False) 162 | 163 | if sn: 164 | return SNDiscriminator(**config) 165 | else: 166 | return Discriminator(**config) 167 | -------------------------------------------------------------------------------- /denoising/trainer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import logging 3 | import models 4 | import os 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.optim.lr_scheduler import StepLR 8 | from torch.autograd import grad as torch_grad, Variable 9 | from data import get_loaders 10 | from ast import literal_eval 11 | from utils.recorderx import RecoderX 12 | from utils.misc import save_image, average, mkdir, compute_psnr 13 | from 
models.modules.losses import RangeLoss, PerceptualLoss, StyleLoss 14 | 15 | class Trainer(): 16 | def __init__(self, args): 17 | # parameters 18 | self.args = args 19 | self.print_model = True 20 | self.invalidity_margins = None 21 | 22 | if self.args.use_tb: 23 | self.tb = RecoderX(log_dir=args.save_path) 24 | 25 | # initialize 26 | self._init() 27 | 28 | def _init_model(self): 29 | # initialize model 30 | if self.args.model_config != '': 31 | model_config = dict({}, **literal_eval(self.args.model_config)) 32 | else: 33 | model_config = {} 34 | 35 | g_model = models.__dict__[self.args.g_model] 36 | d_model = models.__dict__[self.args.d_model] 37 | self.g_model = g_model(**model_config) 38 | self.d_model = d_model(**model_config) 39 | 40 | # loading weights 41 | if self.args.gen_to_load != '': 42 | logging.info('\nLoading g-model...') 43 | self.g_model.load_state_dict(torch.load(self.args.gen_to_load, map_location='cpu')) 44 | if self.args.dis_to_load != '': 45 | logging.info('\nLoading d-model...') 46 | self.d_model.load_state_dict(torch.load(self.args.dis_to_load, map_location='cpu')) 47 | 48 | # to cuda 49 | self.g_model = self.g_model.to(self.args.device) 50 | self.d_model = self.d_model.to(self.args.device) 51 | 52 | # parallel 53 | if self.args.device_ids and len(self.args.device_ids) > 1: 54 | self.g_model = torch.nn.DataParallel(self.g_model, self.args.device_ids) 55 | self.d_model = torch.nn.DataParallel(self.d_model, self.args.device_ids) 56 | 57 | # print model 58 | if self.print_model: 59 | logging.info(self.g_model) 60 | logging.info('Number of parameters in generator: {}\n'.format(sum([l.nelement() for l in self.g_model.parameters()]))) 61 | logging.info(self.d_model) 62 | logging.info('Number of parameters in discriminator: {}\n'.format(sum([l.nelement() for l in self.d_model.parameters()]))) 63 | self.print_model = False 64 | 65 | def _init_optim(self): 66 | # initialize optimizer 67 | self.g_optimizer = torch.optim.Adam(self.g_model.parameters(), lr=self.args.lr, betas=self.args.gen_betas) 68 | self.d_optimizer = torch.optim.Adam(self.d_model.parameters(), lr=self.args.lr, betas=self.args.dis_betas) 69 | 70 | # initialize scheduler 71 | self.g_scheduler = StepLR(self.g_optimizer, step_size=self.args.step_size, gamma=self.args.gamma) 72 | self.d_scheduler = StepLR(self.d_optimizer, step_size=self.args.step_size, gamma=self.args.gamma) 73 | 74 | # initialize criterion 75 | if self.args.reconstruction_weight: 76 | self.reconstruction = torch.nn.L1Loss().to(self.args.device) 77 | if self.args.perceptual_weight > 0.: 78 | self.perceptual = PerceptualLoss(features_to_compute=['conv5_4'], criterion=torch.nn.L1Loss(), shave_edge=self.invalidity_margins).to(self.args.device) 79 | if self.args.style_weight > 0.: 80 | self.style = StyleLoss(features_to_compute=['relu3_1', 'relu2_1'], shave_edge=self.invalidity_margins).to(self.args.device) 81 | if self.args.range_weight > 0.: 82 | self.range = RangeLoss(invalidity_margins=self.invalidity_margins).to(self.args.device) 83 | 84 | def _init(self): 85 | # init parameters 86 | self.step = 0 87 | self.losses = {'D': [], 'D_r': [], 'D_gp': [], 'D_f': [], 'G': [], 'G_recon': [], 'G_rng': [], 'G_perc': [], 'G_sty': [], 'G_adv': [], 'psnr': []} 88 | 89 | # initialize model 90 | self._init_model() 91 | 92 | # initialize optimizer 93 | self._init_optim() 94 | 95 | def _save_model(self, epoch): 96 | # save models 97 | torch.save(self.g_model.state_dict(), os.path.join(self.args.save_path, '{}_e{}.pt'.format(self.args.g_model, epoch + 1))) 98 | 
torch.save(self.d_model.state_dict(), os.path.join(self.args.save_path, '{}_e{}.pt'.format(self.args.d_model, epoch + 1))) 99 | 100 | def _set_require_grads(self, model, require_grad): 101 | for p in model.parameters(): 102 | p.requires_grad_(require_grad) 103 | 104 | def _critic_hinge_iteration(self, inputs, targets): 105 | # require grads 106 | self._set_require_grads(self.d_model, True) 107 | 108 | # get generated data 109 | generated_data = self.g_model(inputs) 110 | 111 | # zero grads 112 | self.d_optimizer.zero_grad() 113 | 114 | # calculate probabilities on real and generated data 115 | d_real = self.d_model(targets) 116 | d_generated = self.d_model(generated_data.detach()) 117 | 118 | # create total loss and optimize 119 | loss_r = F.relu(1.0 - d_real).mean() 120 | loss_f = F.relu(1.0 + d_generated).mean() 121 | loss = loss_r + loss_f 122 | 123 | # get gradient penalty 124 | if self.args.penalty_weight > 0.: 125 | gradient_penalty = self._gradient_penalty(targets, generated_data) 126 | loss += gradient_penalty 127 | 128 | loss.backward() 129 | 130 | self.d_optimizer.step() 131 | 132 | # record loss 133 | self.losses['D'].append(loss.data.item()) 134 | self.losses['D_r'].append(loss_r.data.item()) 135 | self.losses['D_f'].append(loss_f.data.item()) 136 | if self.args.penalty_weight > 0.: 137 | self.losses['D_gp'].append(gradient_penalty.data.item()) 138 | 139 | # require grads 140 | self._set_require_grads(self.d_model, False) 141 | 142 | def _critic_wgan_iteration(self, inputs, targets): 143 | # require grads 144 | self._set_require_grads(self.d_model, True) 145 | 146 | # get generated data 147 | generated_data = self.g_model(inputs) 148 | 149 | # zero grads 150 | self.d_optimizer.zero_grad() 151 | 152 | # calculate probabilities on real and generated data 153 | d_real = self.d_model(targets) 154 | d_generated = self.d_model(generated_data.detach()) 155 | 156 | # create total loss and optimize 157 | if self.args.relativistic: 158 | loss_r = -(d_real - d_generated.mean()).mean() 159 | loss_f = (d_generated - d_real.mean()).mean() 160 | else: 161 | loss_r = -d_real.mean() 162 | loss_f = d_generated.mean() 163 | loss = loss_f + loss_r 164 | 165 | # get gradient penalty 166 | if self.args.penalty_weight > 0.: 167 | gradient_penalty = self._gradient_penalty(targets, generated_data) 168 | loss += gradient_penalty 169 | 170 | loss.backward() 171 | 172 | self.d_optimizer.step() 173 | 174 | # record loss 175 | self.losses['D'].append(loss.data.item()) 176 | self.losses['D_r'].append(loss_r.data.item()) 177 | self.losses['D_f'].append(loss_f.data.item()) 178 | if self.args.penalty_weight > 0.: 179 | self.losses['D_gp'].append(gradient_penalty.data.item()) 180 | 181 | # require grads 182 | self._set_require_grads(self.d_model, False) 183 | 184 | def _gradient_penalty(self, real_data, generated_data): 185 | batch_size = real_data.size()[0] 186 | 187 | # calculate interpolation 188 | alpha = torch.rand(batch_size, 1, 1, 1) 189 | alpha = alpha.expand_as(real_data) 190 | alpha = alpha.to(self.args.device) 191 | interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data 192 | interpolated = Variable(interpolated, requires_grad=True) 193 | interpolated = interpolated.to(self.args.device) 194 | 195 | # calculate probability of interpolated examples 196 | prob_interpolated = self.d_model(interpolated) 197 | 198 | # calculate gradients of probabilities with respect to examples 199 | gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated, 200 | 
grad_outputs=torch.ones(prob_interpolated.size()).to(self.args.device), 201 | create_graph=True, retain_graph=True)[0] 202 | 203 | # gradients have shape (batch_size, num_channels, img_width, img_height), 204 | # so flatten to easily take norm per example in batch 205 | gradients = gradients.view(batch_size, -1) 206 | 207 | # derivatives of the gradient close to 0 can cause problems because of 208 | # the square root, so manually calculate norm and add epsilon 209 | gradients_norm = gradients.norm(p=2, dim=1) 210 | 211 | # return gradient penalty 212 | return ((gradients_norm - 1) ** 2).mean() 213 | 214 | def _generator_iteration(self, inputs, targets): 215 | # zero grads 216 | self.g_optimizer.zero_grad() 217 | 218 | # get generated data 219 | generated_data = self.g_model(inputs) 220 | loss = 0. 221 | 222 | # reconstruction loss 223 | if self.args.reconstruction_weight > 0.: 224 | loss_recon = self.reconstruction(generated_data, targets) 225 | loss += loss_recon * self.args.reconstruction_weight 226 | self.losses['G_recon'].append(loss_recon.data.item()) 227 | 228 | # range loss 229 | if self.args.range_weight > 0.: 230 | loss_rng = self.range(generated_data) 231 | loss += loss_rng * self.args.range_weight 232 | self.losses['G_rng'].append(loss_rng.data.item()) 233 | 234 | # adversarial loss 235 | if self.args.adversarial_weight > 0.: 236 | d_generated = self.d_model(generated_data) 237 | if self.args.relativistic: 238 | d_real = self.d_model(targets) 239 | loss_adv = (d_real - d_generated.mean()).mean() - (d_generated - d_real.mean()).mean() 240 | else: 241 | loss_adv = -d_generated.mean() 242 | loss += loss_adv * self.args.adversarial_weight 243 | self.losses['G_adv'].append(loss_adv.data.item()) 244 | 245 | # perceptual loss 246 | if self.args.perceptual_weight > 0.: 247 | loss_perc = self.perceptual(generated_data, targets) 248 | loss += loss_perc * self.args.perceptual_weight 249 | self.losses['G_perc'].append(loss_perc.data.item()) 250 | 251 | # style loss 252 | if self.args.style_weight > 0.: 253 | loss_sty = self.style(generated_data, targets) 254 | loss += loss_sty * self.args.style_weight 255 | self.losses['G_sty'].append(loss_sty.data.item()) 256 | 257 | # backward loss 258 | loss.backward() 259 | self.g_optimizer.step() 260 | 261 | # record loss 262 | self.losses['G'].append(loss.data.item()) 263 | 264 | def _train_iteration(self, data): 265 | # set inputs 266 | inputs = data['input'].to(self.args.device) 267 | targets = data['target'].to(self.args.device) 268 | 269 | # critic iteration 270 | if self.args.adversarial_weight > 0.: 271 | if self.args.wgan: 272 | self._critic_wgan_iteration(inputs, targets) 273 | else: 274 | self._critic_hinge_iteration(inputs, targets) 275 | 276 | # only update generator every |critic_iterations| iterations 277 | if self.step % self.args.num_critic == 0: 278 | self._generator_iteration(inputs, targets) 279 | 280 | # logging 281 | if self.step % self.args.print_every == 0: 282 | line2print = 'Iteration {}'.format(self.step) 283 | if self.args.adversarial_weight > 0.: 284 | line2print += ', D: {:.6f}, D_r: {:.6f}, D_f: {:.6f}'.format(self.losses['D'][-1], self.losses['D_r'][-1], self.losses['D_f'][-1]) 285 | if self.args.penalty_weight > 0.: 286 | line2print += ', D_gp: {:.6f}'.format(self.losses['D_gp'][-1]) 287 | if self.step > self.args.num_critic: 288 | line2print += ', G: {:.5f}'.format(self.losses['G'][-1]) 289 | if self.args.reconstruction_weight: 290 | line2print += ', G_recon: {:.6f}'.format(self.losses['G_recon'][-1]) 291 | if 
self.args.range_weight:
292 |                     line2print += ', G_rng: {:.6f}'.format(self.losses['G_rng'][-1])
293 |                 if self.args.perceptual_weight:
294 |                     line2print += ', G_perc: {:.6f}'.format(self.losses['G_perc'][-1])
295 |                 if self.args.style_weight:
296 |                     line2print += ', G_sty: {:.8f}'.format(self.losses['G_sty'][-1])
297 |                 if self.args.adversarial_weight:
298 |                     line2print += ', G_adv: {:.6f}'.format(self.losses['G_adv'][-1])
299 |             logging.info(line2print)
300 | 
301 |         # plots for tensorboard
302 |         if self.args.use_tb:
303 |             if self.args.adversarial_weight > 0.:
304 |                 self.tb.add_scalar('data/loss_d', self.losses['D'][-1], self.step)
305 |             if self.step > self.args.num_critic:
306 |                 self.tb.add_scalar('data/loss_g', self.losses['G'][-1], self.step)
307 | 
308 |     def _eval_iteration(self, data, epoch):
309 |         # set inputs
310 |         inputs = data['input'].to(self.args.device)
311 |         targets = data['target']
312 |         paths = data['path']
313 | 
314 |         # evaluation
315 |         with torch.no_grad():
316 |             outputs = self.g_model(inputs)
317 | 
318 |         # save image and compute psnr
319 |         self._save_image(outputs, paths[0], epoch + 1)
320 |         psnr = compute_psnr(outputs, targets)
321 | 
322 |         return psnr
323 | 
324 |     def _train_epoch(self, loader):
325 |         self.g_model.train()
326 |         self.d_model.train()
327 | 
328 |         # train over batches
329 |         for _, data in enumerate(loader):
330 |             self._train_iteration(data)
331 |             self.step += 1
332 | 
333 |     def _eval_epoch(self, loader, epoch):
334 |         self.g_model.eval()
335 |         psnrs = []
336 | 
337 |         # eval over batches
338 |         for _, data in enumerate(loader):
339 |             psnr = self._eval_iteration(data, epoch)
340 |             psnrs.append(psnr)
341 | 
342 |         # record psnr
343 |         self.losses['psnr'].append(average(psnrs))
344 |         logging.info('Evaluation: {:.3f}'.format(self.losses['psnr'][-1]))
345 |         if self.args.use_tb:
346 |             self.tb.add_scalar('data/psnr', self.losses['psnr'][-1], epoch)
347 | 
348 |     def _save_image(self, image, path, epoch):
349 |         directory = os.path.join(self.args.save_path, 'images', 'epoch_{}'.format(epoch))
350 |         save_path = os.path.join(directory, os.path.basename(path))
351 |         mkdir(directory)
352 |         save_image(image.data.cpu(), save_path)
353 | 
354 |     def _train(self, loaders):
355 |         # run epoch iterations
356 |         for epoch in range(self.args.epochs):
357 |             # random seed
358 |             torch.manual_seed(random.randint(1, 123456789))
359 | 
360 |             logging.info('\nEpoch {}'.format(epoch + 1))
361 | 
362 |             # train
363 |             self._train_epoch(loaders['train'])
364 | 
365 |             # scheduler
366 |             self.g_scheduler.step(epoch=epoch)
367 |             self.d_scheduler.step(epoch=epoch)
368 | 
369 |             # evaluation
370 |             if ((epoch + 1) % self.args.eval_every == 0) or ((epoch + 1) == self.args.epochs):
371 |                 self._eval_epoch(loaders['eval'], epoch)
372 |                 self._save_model(epoch)
373 | 
374 |         # best score
375 |         logging.info('Best PSNR Score: {:.2f}\n'.format(max(self.losses['psnr'])))
376 | 
377 |     def train(self):
378 |         # get loader
379 |         loaders = get_loaders(self.args)
380 | 
381 |         # run training
382 |         self._train(loaders)
383 | 
384 |         # close tensorboard
385 |         if self.args.use_tb:
386 |             self.tb.close()
387 | 
388 |     def eval(self):
389 |         # get loader
390 |         loaders = get_loaders(self.args)
391 | 
392 |         # evaluation
393 |         logging.info('\nEvaluating...')
394 |         self._eval_epoch(loaders['eval'], 0)
395 | 
396 |         # close tensorboard
397 |         if self.args.use_tb:
398 |             self.tb.close()
399 | 
--------------------------------------------------------------------------------
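A minimal, self-contained sketch of the two-sided gradient penalty computed by `_gradient_penalty` above, using a toy critic; every name below is illustrative and not part of the repository:

    import torch
    from torch.autograd import grad as torch_grad

    critic = torch.nn.Linear(3 * 8 * 8, 1)              # stand-in for d_model
    real = torch.rand(4, 3, 8, 8)
    fake = torch.rand(4, 3, 8, 8)
    alpha = torch.rand(4, 1, 1, 1).expand_as(real)      # per-sample mixing weights
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interpolated.view(4, -1))
    gradients = torch_grad(outputs=scores, inputs=interpolated,
                           grad_outputs=torch.ones_like(scores),
                           create_graph=True, retain_graph=True)[0]
    # drive each per-sample gradient norm toward 1 (WGAN-GP style)
    penalty = ((gradients.view(4, -1).norm(2, dim=1) - 1) ** 2).mean()
    print(penalty.item())

/denoising/utils/__init__.py: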
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/denoising/utils/__init__.py
--------------------------------------------------------------------------------
/denoising/utils/misc.py:
--------------------------------------------------------------------------------
1 | import logging.config
2 | import os
3 | import matplotlib.pyplot as plt
4 | import torch
5 | import math
6 | import numpy as np
7 | from PIL import Image
8 | from torchvision.utils import make_grid
9 | from os import path, makedirs
10 | 
11 | def mkdir(save_path):
12 |     if not path.exists(save_path):
13 |         makedirs(save_path)
14 | 
15 | def make_image_grid(x, nrow, padding=0, pad_value=0):
16 |     x = x.clone().cpu().data
17 |     grid = make_grid(x, nrow=nrow, padding=padding, normalize=True, scale_each=False, pad_value=pad_value)
18 |     return grid
19 | 
20 | def tensor_to_image(x):
21 |     ndarr = x.squeeze(dim=0).mul(255).add(0.5).clamp(0, 255).permute(1, 2, 0).squeeze().to('cpu', torch.uint8).numpy()
22 |     image = Image.fromarray(ndarr)
23 |     return image
24 | 
25 | def plot_image_grid(x, nrow, padding=0):
26 |     grid = make_image_grid(x=x, nrow=nrow, padding=padding).permute(1, 2, 0).numpy()
27 |     plt.imshow(grid)
28 |     plt.show()
29 | 
30 | def save_image(x, path, size=None):
31 |     image = tensor_to_image(x)
32 |     if size:
33 |         image = image.resize((size, size), Image.NEAREST)
34 |     image.save(path)
35 | 
36 | def save_image_grid(x, path, nrow=8, size=None):
37 |     grid = make_image_grid(x, nrow)
38 |     save_image(grid, path, size=size)
39 | 
40 | def setup_logging(log_file='log.txt', resume=False, dummy=False):
41 |     if dummy:
42 |         logging.getLogger('dummy')
43 |     else:
44 |         if os.path.isfile(log_file) and resume:
45 |             file_mode = 'a'
46 |         else:
47 |             file_mode = 'w'
48 | 
49 |         root_logger = logging.getLogger()
50 |         if root_logger.handlers:
51 |             root_logger.removeHandler(root_logger.handlers[0])
52 |         logging.basicConfig(level=logging.INFO,
53 |                             format="%(asctime)s - %(levelname)s - %(message)s",
54 |                             datefmt="%Y-%m-%d %H:%M:%S",
55 |                             filename=log_file,
56 |                             filemode=file_mode)
57 |         console = logging.StreamHandler()
58 |         console.setLevel(logging.INFO)
59 |         formatter = logging.Formatter('%(message)s')
60 |         console.setFormatter(formatter)
61 |         logging.getLogger('').addHandler(console)
62 | 
63 | def average(lst):
64 |     return sum(lst) / len(lst)
65 | 
66 | def rgb2yc(img):
67 |     rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
68 |     rlt = rlt.round()
69 |     return rlt
70 | 
71 | def compute_psnr(x, y):
72 |     x = np.array(tensor_to_image(x)).astype(np.float64)
73 |     y = np.array(tensor_to_image(y)).astype(np.float64)
74 | 
75 |     mse = np.mean((x - y) ** 2)
76 |     if mse == 0:
77 |         return float('inf')
78 |     return 20 * math.log10(255.0 / math.sqrt(mse))
79 | 
80 | if __name__ == "__main__":
81 |     print('None')
82 | 
--------------------------------------------------------------------------------
/denoising/utils/optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | 
4 | def get_exp_scheduler_with_warmup(optimizer, rampup_steps=5, sustain_steps=5):
5 |     def lr_lambda(step):
6 |         if step < rampup_steps:
7 |             return min(1., 1.8 ** (step - rampup_steps))
8 |         elif step < rampup_steps + sustain_steps:
9 |             return 1.
10 |         else:
11 |             return max(0.1, 0.85 ** (step - rampup_steps - sustain_steps))
12 | 
13 |     return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
14 | 
15 | def get_cosine_scheduler_with_warmup(optimizer, rampup_steps=3, sustain_steps=2, period=5):
16 |     def lr_lambda(step):
17 |         if step < rampup_steps:
18 |             return min(1., 1.5 ** (step - rampup_steps))
19 |         elif step < rampup_steps + sustain_steps:
20 |             return 1.
21 |         else:
22 |             return max(0.1, (1 + math.cos(math.pi * (step - rampup_steps - sustain_steps) / period)) / 2)
23 | 
24 |     return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
25 | 
26 | if __name__ == "__main__":
27 |     import matplotlib.pyplot as plt
28 | 
29 |     model = torch.nn.Linear(2, 1)
30 |     optimizer = torch.optim.SGD(model.parameters(), lr=2.5e-4)
31 |     lr_scheduler = get_exp_scheduler_with_warmup(optimizer)
32 | 
33 |     lrs = []
34 |     for i in range(25):
35 |         lr_scheduler.step()
36 |         lrs.append(optimizer.param_groups[0]["lr"])
37 | 
38 |     plt.plot(lrs)
39 |     plt.show()
40 |     print(lrs)
41 | 
--------------------------------------------------------------------------------
/denoising/utils/recorderx.py:
--------------------------------------------------------------------------------
1 | import utils.misc as misc
2 | from tensorboardX import SummaryWriter
3 | 
4 | class RecoderX():
5 |     def __init__(self, log_dir):
6 |         self.log_dir = log_dir
7 |         self.writer = SummaryWriter(logdir=log_dir)
8 |         self.log = ''
9 | 
10 |     def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
11 |         self.writer.add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step, walltime=walltime)
12 | 
13 |     def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
14 |         self.writer.add_scalars(main_tag=main_tag, tag_scalar_dict=tag_scalar_dict, global_step=global_step, walltime=walltime)
15 | 
16 |     def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
17 |         self.writer.add_image(tag=tag, img_tensor=img_tensor, global_step=global_step, walltime=walltime, dataformats=dataformats)
18 | 
19 |     def add_image_grid(self, tag, img_tensor, nrow=8, global_step=None, padding=0, pad_value=0, walltime=None, dataformats='CHW'):
20 |         grid = misc.make_image_grid(img_tensor, nrow, padding=padding, pad_value=pad_value)
21 |         self.writer.add_image(tag=tag, img_tensor=grid, global_step=global_step, walltime=walltime, dataformats=dataformats)
22 | 
23 |     def add_graph(self, graph_profile, walltime=None):
24 |         self.writer.add_graph(graph_profile, walltime=walltime)
25 | 
26 |     def add_histogram(self, tag, values, global_step=None):
27 |         self.writer.add_histogram(tag, values, global_step)
28 | 
29 |     def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
30 |         self.writer.add_figure(tag, figure, global_step=global_step, close=close, walltime=walltime)
31 | 
32 |     def export_json(self, out_file):
33 |         self.writer.export_scalars_to_json(out_file)
34 | 
35 |     def close(self):
36 |         self.writer.close()
37 | 
38 | if __name__ == "__main__":
39 |     print('None')
40 | 
--------------------------------------------------------------------------------
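A quick usage sketch for the RecoderX wrapper above (requires the tensorboardX package from requirements.txt and should run from the denoising folder so `utils` resolves; the log path is hypothetical):

    from utils.recorderx import RecoderX

    tb = RecoderX(log_dir='./results/tb')
    for step, loss in enumerate([0.9, 0.7, 0.55]):
        tb.add_scalar('data/loss_d', loss, global_step=step)
    tb.close()

/figures/activations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/figures/activations.png
--------------------------------------------------------------------------------
/figures/figure1.png: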
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/figures/figure1.png -------------------------------------------------------------------------------- /figures/tradeoff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/figures/tradeoff.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | scikit-image 3 | scikit-learn 4 | scipy 5 | numpy 6 | torch 7 | torchvision 8 | tensorboardx 9 | -------------------------------------------------------------------------------- /super-resolution/data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .datasets import DatasetSR 3 | from torch.utils.data import DataLoader 4 | 5 | def get_loaders(args): 6 | # datasets 7 | dataset_train = DatasetSR(root=os.path.join(args.root, 'train'), scale=args.scale, training=True, crop_size=args.crop_size) 8 | dataset_val = DatasetSR(root=os.path.join(args.root, 'val'), scale=args.scale, training=False, max_size=args.max_size) 9 | 10 | # loaders 11 | loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 12 | loader_eval = DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=1) 13 | loaders = {'train': loader_train, 'eval': loader_eval} 14 | 15 | return loaders 16 | -------------------------------------------------------------------------------- /super-resolution/data/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import os 4 | import random 5 | from torchvision import transforms 6 | from PIL import Image 7 | 8 | class DatasetSR(torch.utils.data.dataset.Dataset): 9 | def __init__(self, root='', scale=4, training=True, crop_size=60, max_size=None): 10 | self.root = root 11 | self.scale = scale if (scale % 1) else int(scale) 12 | self.training = training 13 | self.crop_size = crop_size 14 | self.max_size = max_size 15 | 16 | self._init() 17 | 18 | def _init(self): 19 | # data paths 20 | inputs = glob.glob(os.path.join(self.root, 'img_x{}'.format(self.scale), '*.*'))[:self.max_size] 21 | targets = [x.replace('img_x{}'.format(self.scale), 'img') for x in inputs] 22 | self.paths = {'input' : inputs, 'target' : targets} 23 | 24 | # transforms 25 | t_list = [transforms.ToTensor()] 26 | if self.training: 27 | t_list.append(lambda x: ((255. 
* x) + torch.zeros_like(x).uniform_(0., 1.)) / 256.)
28 |         self.image_transform = transforms.Compose(t_list)
29 | 
30 |     def _get_augment_params(self, size):
31 |         random.seed(random.randint(0, 12345))
32 | 
33 |         # position
34 |         w_size, h_size = size
35 |         x = random.randint(0, max(0, w_size - self.crop_size))
36 |         y = random.randint(0, max(0, h_size - self.crop_size))
37 | 
38 |         # flip
39 |         flip = random.random() > 0.5
40 |         return {'crop_pos': (x, y), 'flip': flip}
41 | 
42 |     def _augment(self, image, aug_params, scale=1):
43 |         x, y = aug_params['crop_pos']
44 |         image = image.crop((x * scale, y * scale, x * scale + self.crop_size * scale, y * scale + self.crop_size * scale))
45 |         if aug_params['flip']:
46 |             image = image.transpose(Image.FLIP_LEFT_RIGHT)
47 |         return image
48 | 
49 |     def __getitem__(self, index):
50 |         # input image
51 |         input = Image.open(self.paths['input'][index]).convert('RGB')
52 | 
53 |         # target image
54 |         target = Image.open(self.paths['target'][index]).convert('RGB')
55 | 
56 |         if self.training:
57 |             aug_params = self._get_augment_params(input.size)
58 |             input = self._augment(input, aug_params)
59 |             target = self._augment(target, aug_params, self.scale)
60 | 
61 |         input = self.image_transform(input)
62 |         target = self.image_transform(target)
63 | 
64 |         return {'input': input, 'target': target, 'path': self.paths['target'][index]}
65 | 
66 |     def __len__(self):
67 |         return len(self.paths['input'])
--------------------------------------------------------------------------------
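A minimal sketch of driving DatasetSR above directly; the dataset root below is hypothetical, and per `_init()` it must contain an `img` folder with HR images and a matching `img_x4` folder with the downscaled inputs:

    from torch.utils.data import DataLoader
    from data.datasets import DatasetSR

    dataset = DatasetSR(root='./datasets/div2k/train', scale=4, training=True, crop_size=40)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    batch = next(iter(loader))
    # 40x40 LR crops and the matching 160x160 HR crops
    print(batch['input'].shape, batch['target'].shape)

/super-resolution/main.py:
--------------------------------------------------------------------------------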
1 | import argparse
2 | import torch
3 | import logging
4 | import signal
5 | import sys
6 | import torch.backends.cudnn as cudnn
7 | from trainer import Trainer
8 | from datetime import datetime
9 | from os import path
10 | from utils import misc
11 | from random import randint
12 | 
13 | # torch.autograd.set_detect_anomaly(True)
14 | 
15 | def get_arguments():
16 |     parser = argparse.ArgumentParser(description='super-resolution')
17 |     parser.add_argument('--device', default='cuda', help='device assignment ("cpu" or "cuda")')
18 |     parser.add_argument('--device-ids', default=[0], type=int, nargs='+', help='device ids assignment (e.g. 0 1 2 3)')
19 |     parser.add_argument('--d-model', default='d_srgan', help='discriminator architecture (default: d_srgan)')
20 |     parser.add_argument('--g-model', default='g_srgan', help='generator architecture (default: g_srgan)')
21 |     parser.add_argument('--model-config', default='', help='additional architecture configuration')
22 |     parser.add_argument('--dis-to-load', default='', help='discriminator checkpoint to resume from (default: none)')
23 |     parser.add_argument('--gen-to-load', default='', help='generator checkpoint to resume from (default: none)')
24 |     parser.add_argument('--root', default='', help='root dataset folder')
25 |     parser.add_argument('--scale', default=4, type=float, help='super-resolution scale (default: 4)')
26 |     parser.add_argument('--crop-size', default=40, type=int, help='low resolution cropping size (default: 40)')
27 |     parser.add_argument('--max-size', default=None, type=int, help='validation set max-size (default: None)')
28 |     parser.add_argument('--num-workers', default=2, type=int, help='number of workers (default: 2)')
29 |     parser.add_argument('--batch-size', default=16, type=int, help='batch-size (default: 16)')
30 |     parser.add_argument('--epochs', default=1000, type=int, help='epochs (default: 1000)')
31 |     parser.add_argument('--lr', default=2e-4, type=float, help='lr (default: 2e-4)')
32 |     parser.add_argument('--gen-betas', default=[0.9, 0.99], nargs=2, type=float, help='generator adam betas (default: 0.9 0.99)')
33 |     parser.add_argument('--dis-betas', default=[0.9, 0.99], nargs=2, type=float, help='discriminator adam betas (default: 0.9 0.99)')
34 |     parser.add_argument('--num-critic', default=1, type=int, help='critic iterations (default: 1)')
35 |     parser.add_argument('--wgan', default=False, action='store_true', help='critic wgan loss (default: false)')
36 |     parser.add_argument('--relativistic', default=False, action='store_true', help='relativistic wgan loss (default: false)')
37 |     parser.add_argument('--step-size', default=300, type=int, help='scheduler step size (default: 300)')
38 |     parser.add_argument('--gamma', default=0.5, type=float, help='scheduler gamma (default: 0.5)')
39 |     parser.add_argument('--penalty-weight', default=0, type=float, help='gradient penalty weight (default: 0)')
40 |     parser.add_argument('--range-weight', default=0, type=float, help='range-weight (default: 0)')
41 |     parser.add_argument('--reconstruction-weight', default=1.0, type=float, help='reconstruction-weight (default: 1.0)')
42 |     parser.add_argument('--perceptual-weight', default=0, type=float, help='perceptual-weight (default: 0)')
43 |     parser.add_argument('--adversarial-weight', default=0.01, type=float, help='adversarial-weight (default: 0.01)')
44 |     parser.add_argument('--style-weight', default=0, type=float, help='style-weight (default: 0)')
45 |     parser.add_argument('--seed', default=-1, type=int, help='random seed (default: random)')
46 |     parser.add_argument('--print-every', default=20, type=int, help='print-every (default: 20)')
47 |     parser.add_argument('--eval-every', default=50, type=int, help='eval-every (default: 50)')
48 |     parser.add_argument('--results-dir', metavar='RESULTS_DIR', default='./results', help='results dir')
49 |     parser.add_argument('--save', metavar='SAVE', default='', help='saved folder')
50 |     parser.add_argument('--evaluation', default=False, action='store_true', help='evaluate a model (default: false)')
51 |     parser.add_argument('--use-tb', default=False, action='store_true', help='use tensorboardx (default: false)')
52 |     args = parser.parse_args()
53 | 
54 |     time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
55 |     if args.save == '':
56 |         args.save = time_stamp
57 |     args.save_path = path.join(args.results_dir, args.save)
58 |     if args.seed == -1:
59 |         args.seed = randint(0, 12345)
60 |     return args
61 | 
62 | def main():
63 |     # arguments
64 |     args = get_arguments()
65 | 
66 |     torch.manual_seed(args.seed)
67 | 
68 |     # cuda
69 |     if 'cuda' in args.device and torch.cuda.is_available():
70 |         torch.cuda.manual_seed_all(args.seed)
71 |         torch.cuda.set_device(args.device_ids[0])
72 |         cudnn.benchmark = True
73 |     else:
74 |         args.device_ids = None
75 | 
76 |     # set logs
77 |     misc.mkdir(args.save_path)
78 |     misc.mkdir(path.join(args.save_path, 'images'))
79 |     misc.setup_logging(path.join(args.save_path, 'log.txt'))
80 | 
81 |     # print logs
82 |     logging.info(args)
83 | 
84 |     # trainer
85 |     trainer = Trainer(args)
86 | 
87 |     if args.evaluation:
88 |         trainer.eval()
89 |     else:
90 |         trainer.train()
91 | 
92 | if __name__ == '__main__':
93 |     # enables a ctrl-c without triggering errors
94 |     signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))
95 |     main()
--------------------------------------------------------------------------------
/super-resolution/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .srgan import *
2 | from .xsrgan import *
3 | from .vanilla import *
4 | from .xvanilla import *
5 | from .xdense import *
--------------------------------------------------------------------------------
/super-resolution/models/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/super-resolution/models/modules/__init__.py
--------------------------------------------------------------------------------
/super-resolution/models/modules/activations.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | class xUnit(nn.Module):
4 |     def __init__(self, num_features=64, kernel_size=7, batch_norm=False):
5 |         super(xUnit, self).__init__()
6 |         # xUnit
7 |         self.features = nn.Sequential(
8 |             nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
9 |             nn.ReLU(),
10 |             nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
11 |             nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
12 |             nn.Sigmoid()
13 |         )
14 | 
15 |     def forward(self, x):
16 |         a = self.features(x)
17 |         r = x * a
18 |         return r
19 | 
20 | class xUnitS(nn.Module):
21 |     def __init__(self, num_features=64, kernel_size=7, batch_norm=False):
22 |         super(xUnitS, self).__init__()
23 |         # slim xUnit
24 |         self.features = nn.Sequential(
25 |             nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
26 |             nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
27 |             nn.Sigmoid()
28 |         )
29 | 
30 |     def forward(self, x):
31 |         a = self.features(x)
32 |         r = x * a
33 |         return r
34 | 
35 | class xUnitD(nn.Module):
36 |     def __init__(self, num_features=64, kernel_size=7, batch_norm=False):
37 |         super(xUnitD, self).__init__()
38 |         # dense xUnit
39 |         self.features = nn.Sequential(
40 |             nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=1, padding=0),
41 |             nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
42 |             nn.ReLU(),
43 |             nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
44 |             nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
45 |             nn.Sigmoid()
46 |         )
47 | 
48 |     def forward(self, x):
49 |         a = self.features(x)
50 |         r = x * a
51 |         return r
52 | 
53 | class Identity(nn.Module):
54 |     def __init__(self):
55 |         super(Identity, self).__init__()
56 | 
57 |     def forward(self, x):
58 |         return x
59 | 
--------------------------------------------------------------------------------
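The three xUnit variants above all learn a spatial, per-pixel gate: a depthwise convolution (`groups=num_features`) followed by a sigmoid yields an attention map that multiplies the input element-wise. A minimal sketch of dropping one into a feature map (illustrative; run from the super-resolution folder so the `models` package resolves):

    import torch
    from models.modules.activations import xUnit

    act = xUnit(num_features=64, kernel_size=7)
    x = torch.randn(8, 64, 32, 32)
    y = act(x)          # y = x * sigmoid(depthwise spatial response)
    print(y.shape)      # torch.Size([8, 64, 32, 32]), same shape as the input

/super-resolution/models/modules/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .misc import shave_edge
5 | from ..vgg import MultiVGGFeaturesExtractor
6 | 
7 | class RangeLoss(nn.Module):
8 |     def __init__(self, min_value=0., max_value=1., invalidity_margins=None):
9 |         super(RangeLoss, self).__init__()
10 |         self.min_value = min_value
11 |         self.max_value = max_value
12 |         self.invalidity_margins = invalidity_margins
13 | 
14 |     def forward(self, inputs):
15 |         if self.invalidity_margins:
16 |             inputs = shave_edge(inputs, self.invalidity_margins, self.invalidity_margins)
17 |         loss = (F.relu(self.min_value - inputs) + F.relu(inputs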
- self.max_value)).mean()
18 |         return loss
19 | 
20 | class PerceptualLoss(nn.Module):
21 |     def __init__(self, features_to_compute, criterion=torch.nn.L1Loss(), shave_edge=None):
22 |         super(PerceptualLoss, self).__init__()
23 |         self.criterion = criterion
24 |         self.features_extractor = MultiVGGFeaturesExtractor(target_features=features_to_compute, shave_edge=shave_edge).eval()
25 | 
26 |     def forward(self, inputs, targets):
27 |         inputs_fea = self.features_extractor(inputs)
28 |         with torch.no_grad():
29 |             targets_fea = self.features_extractor(targets)
30 | 
31 |         loss = 0
32 |         for key in inputs_fea.keys():
33 |             loss += self.criterion(inputs_fea[key], targets_fea[key].detach())
34 | 
35 |         return loss
36 | 
37 | class StyleLoss(nn.Module):
38 |     def __init__(self, features_to_compute, criterion=torch.nn.L1Loss(), shave_edge=None):
39 |         super(StyleLoss, self).__init__()
40 |         self.criterion = criterion
41 |         self.features_extractor = MultiVGGFeaturesExtractor(target_features=features_to_compute, use_input_norm=True, shave_edge=shave_edge).eval()
42 | 
43 |     def forward(self, inputs, targets):
44 |         inputs_fea = self.features_extractor(inputs)
45 |         with torch.no_grad():
46 |             targets_fea = self.features_extractor(targets)
47 | 
48 |         loss = 0
49 |         for key in inputs_fea.keys():
50 |             inputs_gram = self._gram_matrix(inputs_fea[key])
51 |             with torch.no_grad():
52 |                 targets_gram = self._gram_matrix(targets_fea[key]).detach()
53 | 
54 |             loss += self.criterion(inputs_gram, targets_gram)
55 | 
56 |         return loss
57 | 
58 |     def _gram_matrix(self, x):
59 |         a, b, c, d = x.size()
60 |         features = x.view(a, b, c * d)
61 |         gram = features.bmm(features.transpose(1, 2))
62 |         return gram.div(b * c * d)
--------------------------------------------------------------------------------
/super-resolution/models/modules/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class UpsampleX2(nn.Module):
6 |     def __init__(self, in_channels, out_channels, kernel_size=3):
7 |         super(UpsampleX2, self).__init__()
8 |         self.conv = nn.Conv2d(in_channels=in_channels, out_channels=(out_channels * 4), kernel_size=kernel_size, padding=(kernel_size // 2))
9 |         self.shuffler = nn.PixelShuffle(2)
10 |         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
11 | 
12 |     def forward(self, x):
13 |         return self.lrelu(self.shuffler(self.conv(x)))
14 | 
15 | def center_crop(x, height, width):
16 |     # crops x to (height, width) around its center via negative padding
17 |     crop_h = torch.FloatTensor([x.size()[2]]).sub(height).div(-2)
18 |     crop_w = torch.FloatTensor([x.size()[3]]).sub(width).div(-2)
19 | 
20 |     return F.pad(x, [
21 |         crop_w.ceil().int()[0], crop_w.floor().int()[0],
22 |         crop_h.ceil().int()[0], crop_h.floor().int()[0],
23 |     ])
24 | 
25 | def shave_edge(x, shave_h, shave_w):
26 |     # trims shave_h rows and shave_w columns from every border
27 |     return F.pad(x, [-shave_w, -shave_w, -shave_h, -shave_h])
28 | 
29 | def shave_modulo(x, factor):
30 |     shave_w = x.size(-1) % factor
31 |     shave_h = x.size(-2) % factor
32 |     return F.pad(x, [0, -shave_w, 0, -shave_h])
33 | 
34 | if __name__ == "__main__":
35 |     x = torch.randn(1, 2, 4, 6)
36 |     y = shave_edge(x, 1, 2)
37 |     print(x)
38 |     print(y)
--------------------------------------------------------------------------------
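A small sketch of using the PerceptualLoss defined above (illustrative; instantiating it builds a MultiVGGFeaturesExtractor, which downloads torchvision's pretrained VGG19 on first use, and inputs are assumed to lie in [0, 1]):

    import torch
    from models.modules.losses import PerceptualLoss

    criterion = PerceptualLoss(features_to_compute=('relu3_1', 'relu5_4'))
    sr = torch.rand(2, 3, 64, 64)
    hr = torch.rand(2, 3, 64, 64)
    loss = criterion(sr, hr)    # L1 distance between VGG feature maps
    print(loss.item())

/super-resolution/models/srgan.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from math import log
3 | from torch.nn import functional as F
4 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
5 | from .modules.misc import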
UpsampleX2, center_crop 6 | 7 | __all__ = ['g_srgan', 'd_srgan'] 8 | 9 | def initialize_weights(net, scale=1.): 10 | if not isinstance(net, list): 11 | net = [net] 12 | for layer in net: 13 | for m in layer.modules(): 14 | if isinstance(m, nn.Conv2d): 15 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 16 | m.weight.data *= scale # for residual block 17 | if m.bias is not None: 18 | m.bias.data.zero_() 19 | elif isinstance(m, nn.Linear): 20 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 21 | m.weight.data *= scale 22 | if m.bias is not None: 23 | m.bias.data.zero_() 24 | elif isinstance(m, nn.BatchNorm2d): 25 | nn.init.constant_(m.weight, 1) 26 | nn.init.constant_(m.bias.data, 0.0) 27 | 28 | class GenBlock(nn.Module): 29 | def __init__(self, num_features=64, kernel_size=3, bias=False): 30 | super(GenBlock, self).__init__() 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv1 = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias) 33 | self.conv2 = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias) 34 | 35 | initialize_weights([self.conv1, self.conv2], 0.1) 36 | 37 | def forward(self, x): 38 | residual = x 39 | x = self.conv2(self.relu(self.conv1(x))) 40 | x = residual + x 41 | return x 42 | 43 | class DisBlock(nn.Module): 44 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False): 45 | super(DisBlock, self).__init__() 46 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 47 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias) 48 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias) 49 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True) 50 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True) 51 | 52 | initialize_weights([self.conv1, self.conv2], 0.1) 53 | 54 | if normalization: 55 | self.conv1 = SpectralNorm(self.conv1) 56 | self.conv2 = SpectralNorm(self.conv2) 57 | 58 | def forward(self, x): 59 | x = self.lrelu(self.bn1(self.conv1(x))) 60 | x = self.lrelu(self.bn2(self.conv2(x))) 61 | return x 62 | 63 | class Generator(nn.Module): 64 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale): 65 | super(Generator, self).__init__() 66 | self.scale = scale 67 | # image to features 68 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 69 | self.image_to_features = nn.Conv2d(in_channels=in_channels, out_channels=num_features, kernel_size=5, padding=2) 70 | 71 | # features 72 | blocks = [] 73 | for _ in range(gen_blocks): 74 | blocks.append(GenBlock(num_features=num_features)) 75 | self.features = nn.Sequential(*blocks) 76 | 77 | # upsampling 78 | blocks = [] 79 | for _ in range(int(log(scale, 2))): 80 | block = UpsampleX2(in_channels=num_features, out_channels=num_features) 81 | blocks.append(block) 82 | self.usample = nn.Sequential(*blocks) 83 | 84 | # features to image 85 | self.hrconv = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=3, padding=1) 86 | self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2) 87 | 88 | # init weights 89 | initialize_weights([self.image_to_features, self.hrconv, self.features_to_image], 0.1) 90 | 91 | def forward(self, x): 92 | r = F.interpolate(x, scale_factor=self.scale, mode='bilinear', 
align_corners=False) 93 | x = self.lrelu(self.image_to_features(x)) 94 | x = self.features(x) 95 | x = self.usample(x) 96 | x = self.features_to_image(self.lrelu(self.hrconv(x))) 97 | x = x + r 98 | return x 99 | 100 | class Discriminator(nn.Module): 101 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale): 102 | super(Discriminator, self).__init__() 103 | self.crop_size = 4 * pow(2, dis_blocks) 104 | 105 | # image to features 106 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False) 107 | 108 | # features 109 | blocks = [] 110 | for i in range(0, dis_blocks - 1): 111 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False)) 112 | self.features = nn.Sequential(*blocks) 113 | 114 | # classifier 115 | self.classifier = nn.Sequential( 116 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100), 117 | nn.LeakyReLU(negative_slope=0.1), 118 | nn.Linear(100, 1) 119 | ) 120 | 121 | def forward(self, x): 122 | x = center_crop(x, self.crop_size, self.crop_size) 123 | x = self.image_to_features(x) 124 | x = self.features(x) 125 | x = x.flatten(start_dim=1) 126 | x = self.classifier(x) 127 | return x 128 | 129 | class SNDiscriminator(nn.Module): 130 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale): 131 | super(SNDiscriminator, self).__init__() 132 | self.crop_size = 4 * pow(2, dis_blocks) 133 | 134 | # image to features 135 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True) 136 | 137 | # features 138 | blocks = [] 139 | for i in range(0, dis_blocks - 1): 140 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True)) 141 | self.features = nn.Sequential(*blocks) 142 | 143 | # classifier 144 | self.classifier = nn.Sequential( 145 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)), 146 | nn.LeakyReLU(negative_slope=0.1), 147 | SpectralNorm(nn.Linear(100, 1)) 148 | ) 149 | 150 | def forward(self, x): 151 | x = center_crop(x, self.crop_size, self.crop_size) 152 | x = self.image_to_features(x) 153 | x = self.features(x) 154 | x = x.flatten(start_dim=1) 155 | x = self.classifier(x) 156 | return x 157 | 158 | def g_srgan(**config): 159 | config.setdefault('in_channels', 3) 160 | config.setdefault('num_features', 64) 161 | config.setdefault('gen_blocks', 16) 162 | config.setdefault('dis_blocks', 5) 163 | config.setdefault('scale', 4) 164 | 165 | _ = config.pop('spectral', False) 166 | 167 | return Generator(**config) 168 | 169 | def d_srgan(**config): 170 | config.setdefault('in_channels', 3) 171 | config.setdefault('num_features', 64) 172 | config.setdefault('gen_blocks', 16) 173 | config.setdefault('dis_blocks', 5) 174 | config.setdefault('scale', 4) 175 | 176 | sn = config.pop('spectral', False) 177 | 178 | if sn: 179 | return SNDiscriminator(**config) 180 | else: 181 | return Discriminator(**config) 182 | -------------------------------------------------------------------------------- /super-resolution/models/vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from math import log 5 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm 6 | 
from .modules.misc import center_crop
7 | 
8 | __all__ = ['g_vanilla', 'd_vanilla']
9 | 
10 | def initialize_weights(net, scale=1.):
11 |     if not isinstance(net, list):
12 |         net = [net]
13 |     for layer in net:
14 |         for m in layer.modules():
15 |             if isinstance(m, nn.Conv2d):
16 |                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
17 |                 m.weight.data *= scale # for residual block
18 |                 if m.bias is not None:
19 |                     m.bias.data.zero_()
20 |             elif isinstance(m, nn.Linear):
21 |                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
22 |                 m.weight.data *= scale
23 |                 if m.bias is not None:
24 |                     m.bias.data.zero_()
25 |             elif isinstance(m, nn.BatchNorm2d):
26 |                 nn.init.constant_(m.weight, 1)
27 |                 nn.init.constant_(m.bias.data, 0.0)
28 | 
29 | class GenBlock(nn.Module):
30 |     def __init__(self, in_channels=64, out_channels=64, kernel_size=3, bias=True):
31 |         super(GenBlock, self).__init__()
32 |         self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias)
33 |         self.bn = nn.BatchNorm2d(num_features=out_channels)
34 |         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
35 | 
36 |         initialize_weights([self.conv, self.bn], 0.02)
37 | 
38 |     def forward(self, x):
39 |         x = self.lrelu(self.bn(self.conv(x)))
40 |         return x
41 | 
42 | class DisBlock(nn.Module):
43 |     def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False):
44 |         super(DisBlock, self).__init__()
45 |         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
46 |         self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias)
47 |         self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias)
48 |         self.bn1 = nn.BatchNorm2d(out_channels, affine=True)
49 |         self.bn2 = nn.BatchNorm2d(out_channels, affine=True)
50 | 
51 |         initialize_weights([self.conv1, self.conv2], 0.1)
52 | 
53 |         if normalization:
54 |             self.conv1 = SpectralNorm(self.conv1)
55 |             self.conv2 = SpectralNorm(self.conv2)
56 | 
57 |     def forward(self, x):
58 |         x = self.lrelu(self.bn1(self.conv1(x)))
59 |         x = self.lrelu(self.bn2(self.conv2(x)))
60 |         return x
61 | 
62 | class Generator(nn.Module):
63 |     def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
64 |         super(Generator, self).__init__()
65 |         self.scale_factor = scale
66 | 
67 |         # image to features
68 |         self.image_to_features = GenBlock(in_channels=in_channels, out_channels=num_features)
69 | 
70 |         # features
71 |         blocks = []
72 |         for _ in range(gen_blocks):
73 |             blocks.append(GenBlock(in_channels=num_features, out_channels=num_features))
74 |         self.features = nn.Sequential(*blocks)
75 | 
76 |         # features to image
77 |         self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2)
78 |         initialize_weights([self.features_to_image], 0.02)
79 | 
80 |     def forward(self, x):
81 |         x = F.interpolate(x, size=(self.scale_factor * x.size(-2), self.scale_factor * x.size(-1)), mode='bicubic', align_corners=True)
82 |         r = x
83 |         x = self.image_to_features(x)
84 |         x = self.features(x)
85 |         x = self.features_to_image(x)
86 |         x += r
87 |         return x
88 | 
89 | class Discriminator(nn.Module):
90 |     def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
91 |         super(Discriminator, self).__init__()
92 |         self.crop_size = 4 * pow(2, dis_blocks)
93 | 
94 |         # image to features
95 |         self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True,
normalization=False)
96 | 
97 |         # features
98 |         blocks = []
99 |         for i in range(0, dis_blocks - 1):
100 |             blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False))
101 |         self.features = nn.Sequential(*blocks)
102 | 
103 |         # classifier
104 |         self.classifier = nn.Sequential(
105 |             nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100),
106 |             nn.LeakyReLU(negative_slope=0.1),
107 |             nn.Linear(100, 1)
108 |         )
109 | 
110 |     def forward(self, x):
111 |         x = center_crop(x, self.crop_size, self.crop_size)
112 |         x = self.image_to_features(x)
113 |         x = self.features(x)
114 |         x = x.flatten(start_dim=1)
115 |         x = self.classifier(x)
116 |         return x
117 | 
118 | class SNDiscriminator(nn.Module):
119 |     def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
120 |         super(SNDiscriminator, self).__init__()
121 |         self.crop_size = 4 * pow(2, dis_blocks)
122 | 
123 |         # image to features
124 |         self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True)
125 | 
126 |         # features
127 |         blocks = []
128 |         for i in range(0, dis_blocks - 1):
129 |             blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True))
130 |         self.features = nn.Sequential(*blocks)
131 | 
132 |         # classifier
133 |         self.classifier = nn.Sequential(
134 |             SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)),
135 |             nn.LeakyReLU(negative_slope=0.1),
136 |             SpectralNorm(nn.Linear(100, 1))
137 |         )
138 | 
139 |     def forward(self, x):
140 |         x = center_crop(x, self.crop_size, self.crop_size)
141 |         x = self.image_to_features(x)
142 |         x = self.features(x)
143 |         x = x.flatten(start_dim=1)
144 |         x = self.classifier(x)
145 |         return x
146 | 
147 | def g_vanilla(**config):
148 |     config.setdefault('in_channels', 3)
149 |     config.setdefault('num_features', 64)
150 |     config.setdefault('gen_blocks', 8)
151 |     config.setdefault('dis_blocks', 5)
152 |     config.setdefault('scale', 4)
153 | 
154 |     _ = config.pop('spectral', False)
155 | 
156 |     return Generator(**config)
157 | 
158 | def d_vanilla(**config):
159 |     config.setdefault('in_channels', 3)
160 |     config.setdefault('num_features', 64)
161 |     config.setdefault('gen_blocks', 8)
162 |     config.setdefault('dis_blocks', 5)
163 |     config.setdefault('scale', 4)
164 | 
165 |     sn = config.pop('spectral', False)
166 | 
167 |     if sn:
168 |         return SNDiscriminator(**config)
169 |     else:
170 |         return Discriminator(**config)
--------------------------------------------------------------------------------
/super-resolution/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn as nn
4 | from models.modules.misc import shave_edge
5 | from collections import OrderedDict
6 | 
7 | names = {'vgg19': ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
8 |                    'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
9 |                    'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
10 |                    'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
11 |                    'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2',
12 |                    'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
13 |                    'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
14 |                    'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'],
15 | 
16 |          'vgg19_bn': ['conv1_1', 'bn1_1', 'relu1_1', 'conv1_2', 'bn1_2', 'relu1_2', 'pool1',
17 |                       'conv2_1', 'bn2_1', 'relu2_1', 'conv2_2', 'bn2_2', 'relu2_2', 'pool2',
18 |                       'conv3_1', 'bn3_1', 'relu3_1', 'conv3_2', 'bn3_2', 'relu3_2',
19 |                       'conv3_3', 'bn3_3', 'relu3_3', 'conv3_4', 'bn3_4',
'relu3_4', 'pool3', 20 | 'conv4_1', 'bn4_1', 'relu4_1', 'conv4_2', 'bn4_2', 'relu4_2', 21 | 'conv4_3', 'bn4_3', 'relu4_3', 'conv4_4', 'bn4_4', 'relu4_4', 'pool4', 22 | 'conv5_1', 'bn5_1', 'relu5_1', 'conv5_2', 'bn5_2', 'relu5_2', 23 | 'conv5_3', 'bn5_3', 'relu5_3', 'conv5_4', 'bn5_4', 'relu5_4', 'pool5'] 24 | } 25 | 26 | 27 | class VGGFeaturesExtractor(nn.Module): 28 | def __init__(self, feature_layer='conv5_4', use_bn=False, use_input_norm=True, requires_grad=False): 29 | super(VGGFeaturesExtractor, self).__init__() 30 | self.use_input_norm = use_input_norm 31 | 32 | if use_bn: 33 | model = torchvision.models.vgg19_bn(pretrained=True) 34 | else: 35 | model = torchvision.models.vgg19(pretrained=True) 36 | 37 | if self.use_input_norm: 38 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 39 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 40 | self.register_buffer('mean', mean) 41 | self.register_buffer('std', std) 42 | 43 | layer_index = names['vgg19_bn'].index(feature_layer) if use_bn else names['vgg19'].index(feature_layer) 44 | self.features = nn.Sequential(*list(model.features.children())[:(layer_index + 1)]) 45 | 46 | if not requires_grad: 47 | for k, v in self.features.named_parameters(): 48 | v.requires_grad = False 49 | self.features.eval() 50 | 51 | def forward(self, x): 52 | # Assume input range is [0, 1] 53 | if self.use_input_norm: 54 | x = (x - self.mean) / self.std 55 | output = self.features(x) 56 | return output 57 | 58 | class MultiVGGFeaturesExtractor(nn.Module): 59 | def __init__(self, target_features=('relu1_1', 'relu2_1', 'relu3_1'), use_bn=False, use_input_norm=True, requires_grad=False, shave_edge=None): 60 | super(MultiVGGFeaturesExtractor, self).__init__() 61 | self.use_input_norm = use_input_norm 62 | self.target_features = target_features 63 | self.shave_edge = shave_edge 64 | 65 | if use_bn: 66 | model = torchvision.models.vgg19_bn(pretrained=True) 67 | names_key = 'vgg19_bn' 68 | else: 69 | model = torchvision.models.vgg19(pretrained=True) 70 | names_key = 'vgg19' 71 | 72 | if self.use_input_norm: 73 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 74 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 75 | self.register_buffer('mean', mean) 76 | self.register_buffer('std', std) 77 | 78 | self.target_indexes = [names[names_key].index(k) for k in self.target_features] 79 | self.features = nn.Sequential(*list(model.features.children())[:(max(self.target_indexes) + 1)]) 80 | 81 | if not requires_grad: 82 | for k, v in self.features.named_parameters(): 83 | v.requires_grad = False 84 | self.features.eval() 85 | 86 | def forward(self, x): 87 | if self.shave_edge: 88 | x = shave_edge(x, self.shave_edge, self.shave_edge) 89 | 90 | # assume input range is [0, 1] 91 | if self.use_input_norm: 92 | x = (x - self.mean) / self.std 93 | 94 | output = OrderedDict() 95 | for key, layer in self.features._modules.items(): 96 | x = layer(x) 97 | if int(key) in self.target_indexes: 98 | output.update({self.target_features[self.target_indexes.index(int(key))]: x}) 99 | return output -------------------------------------------------------------------------------- /super-resolution/models/xdense.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from math import log 4 | from torch.nn import functional as F 5 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm 6 | from .modules.activations import xUnitD 7 | from .modules.misc import 
UpsampleX2, center_crop 8 | 9 | __all__ = ['g_xdense', 'd_xdense'] 10 | 11 | def initialize_weights(net, scale=1.): 12 | if not isinstance(net, list): 13 | net = [net] 14 | for layer in net: 15 | for m in layer.modules(): 16 | if isinstance(m, nn.Conv2d): 17 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 18 | m.weight.data *= scale # for residual block 19 | if m.bias is not None: 20 | m.bias.data.zero_() 21 | elif isinstance(m, nn.Linear): 22 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 23 | m.weight.data *= scale 24 | if m.bias is not None: 25 | m.bias.data.zero_() 26 | elif isinstance(m, nn.BatchNorm2d): 27 | nn.init.constant_(m.weight, 1) 28 | nn.init.constant_(m.bias.data, 0.0) 29 | 30 | class xModule(nn.Module): 31 | def __init__(self, in_channels, out_channels): 32 | super(xModule, self).__init__() 33 | # features 34 | self.features = nn.Sequential( 35 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1), 36 | xUnitD(num_features=out_channels, batch_norm=False), 37 | ) 38 | 39 | def forward(self, x): 40 | x = self.features(x) 41 | return x 42 | 43 | class xDenseLayer(nn.Module): 44 | def __init__(self, in_channels, growth_rate, bn_size): 45 | super(xDenseLayer, self).__init__() 46 | # features 47 | self.features = nn.Sequential( 48 | nn.PReLU(), 49 | nn.Conv2d(in_channels=in_channels,out_channels=bn_size * growth_rate, kernel_size=1, stride=1, bias=False), 50 | nn.PReLU(), 51 | xModule(in_channels=bn_size * growth_rate, out_channels=growth_rate) 52 | ) 53 | 54 | def forward(self, x): 55 | f = self.features(x) 56 | return torch.cat([x, f], dim=1) 57 | 58 | class xDenseBlock(nn.Module): 59 | def __init__(self, in_channels, num_layers, growth_rate, bn_size): 60 | super(xDenseBlock, self).__init__() 61 | # features 62 | blocks = [] 63 | for i in range(num_layers): 64 | blocks.append(xDenseLayer(in_channels=in_channels + growth_rate*i, growth_rate=growth_rate, bn_size=bn_size)) 65 | self.features = nn.Sequential(*blocks) 66 | 67 | def forward(self, x): 68 | x = self.features(x) 69 | return x 70 | 71 | class Transition(nn.Module): 72 | def __init__(self, in_channels, out_channels): 73 | super(Transition, self).__init__() 74 | # features 75 | self.features = nn.Sequential( 76 | nn.PReLU(), 77 | nn.Conv2d(in_channels=in_channels,out_channels=out_channels, kernel_size=1, stride=1, bias=False), 78 | ) 79 | 80 | def forward(self, x): 81 | x = self.features(x) 82 | return x 83 | 84 | class DisBlock(nn.Module): 85 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False): 86 | super(DisBlock, self).__init__() 87 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 88 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias) 89 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias) 90 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True) 91 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True) 92 | 93 | initialize_weights([self.conv1, self.conv2], 0.1) 94 | 95 | if normalization: 96 | self.conv1 = SpectralNorm(self.conv1) 97 | self.conv2 = SpectralNorm(self.conv2) 98 | 99 | def forward(self, x): 100 | x = self.lrelu(self.bn1(self.conv1(x))) 101 | x = self.lrelu(self.bn2(self.conv2(x))) 102 | return x 103 | 104 | class Generator(nn.Module): 105 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size, scale): 106 | super(Generator, 
self).__init__() 107 | self.scale = scale 108 | # image to features 109 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 110 | self.image_to_features = xModule(in_channels=in_channels, out_channels=num_features) 111 | 112 | # features 113 | blocks = [] 114 | self.num_features = num_features 115 | for i, num_layers in enumerate(gen_blocks): 116 | blocks.append(xDenseBlock(in_channels=self.num_features, num_layers=num_layers, growth_rate=growth_rate, bn_size=bn_size)) 117 | self.num_features += num_layers * growth_rate 118 | 119 | if i != len(gen_blocks) - 1: 120 | blocks.append(Transition(in_channels=self.num_features, out_channels=self.num_features // 2)) 121 | self.num_features = self.num_features // 2 122 | 123 | self.features = nn.Sequential(*blocks) 124 | 125 | # upsampling 126 | blocks = [] 127 | for _ in range(int(log(scale, 2))): 128 | block = UpsampleX2(in_channels=self.num_features, out_channels=self.num_features) 129 | blocks.append(block) 130 | self.usample = nn.Sequential(*blocks) 131 | 132 | # features to image 133 | self.hrconv = nn.Conv2d(in_channels=self.num_features, out_channels=self.num_features, kernel_size=3, padding=1) 134 | self.features_to_image = nn.Conv2d(in_channels=self.num_features, out_channels=in_channels, kernel_size=5, padding=2) 135 | 136 | # init weights 137 | initialize_weights([self.image_to_features, self.hrconv, self.features_to_image], 0.1) 138 | 139 | def forward(self, x): 140 | r = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False) 141 | x = self.image_to_features(x) 142 | x = self.features(x) 143 | x = self.usample(x) 144 | x = self.features_to_image(self.lrelu(self.hrconv(x))) 145 | x = x + r 146 | return x 147 | 148 | class Discriminator(nn.Module): 149 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size, scale): 150 | super(Discriminator, self).__init__() 151 | self.crop_size = 4 * pow(2, dis_blocks) 152 | 153 | # image to features 154 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False) 155 | 156 | # features 157 | blocks = [] 158 | for i in range(0, dis_blocks - 1): 159 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False)) 160 | self.features = nn.Sequential(*blocks) 161 | 162 | # classifier 163 | self.classifier = nn.Sequential( 164 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100), 165 | nn.LeakyReLU(negative_slope=0.1), 166 | nn.Linear(100, 1) 167 | ) 168 | 169 | def forward(self, x): 170 | x = center_crop(x, self.crop_size, self.crop_size) 171 | x = self.image_to_features(x) 172 | x = self.features(x) 173 | x = x.flatten(start_dim=1) 174 | x = self.classifier(x) 175 | return x 176 | 177 | class SNDiscriminator(nn.Module): 178 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, growth_rate, bn_size, scale): 179 | super(SNDiscriminator, self).__init__() 180 | self.crop_size = 4 * pow(2, dis_blocks) 181 | 182 | # image to features 183 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True) 184 | 185 | # features 186 | blocks = [] 187 | for i in range(0, dis_blocks - 1): 188 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True)) 189 | self.features = nn.Sequential(*blocks) 190 | 191 | # 
classifier 192 | self.classifier = nn.Sequential( 193 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)), 194 | nn.LeakyReLU(negative_slope=0.1), 195 | SpectralNorm(nn.Linear(100, 1)) 196 | ) 197 | 198 | def forward(self, x): 199 | x = center_crop(x, self.crop_size, self.crop_size) 200 | x = self.image_to_features(x) 201 | x = self.features(x) 202 | x = x.flatten(start_dim=1) 203 | x = self.classifier(x) 204 | return x 205 | 206 | def g_xdense(**config): 207 | config.setdefault('in_channels', 3) 208 | config.setdefault('num_features', 64) 209 | config.setdefault('gen_blocks', [4, 6, 8]) 210 | config.setdefault('dis_blocks', 5) 211 | config.setdefault('growth_rate', 8) 212 | config.setdefault('bn_size', 2) 213 | config.setdefault('scale', 4) 214 | 215 | _ = config.pop('spectral', False) 216 | 217 | return Generator(**config) 218 | 219 | def d_xdense(**config): 220 | config.setdefault('in_channels', 3) 221 | config.setdefault('num_features', 64) 222 | config.setdefault('gen_blocks', [4, 6, 8]) 223 | config.setdefault('dis_blocks', 5) 224 | config.setdefault('growth_rate', 8) 225 | config.setdefault('bn_size', 2) 226 | config.setdefault('scale', 4) 227 | 228 | sn = config.pop('spectral', False) 229 | 230 | if sn: 231 | return SNDiscriminator(**config) 232 | else: 233 | return Discriminator(**config) 234 | -------------------------------------------------------------------------------- /super-resolution/models/xsrgan.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from math import log 3 | from torch.nn import functional as F 4 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm 5 | from .modules.activations import xUnitS 6 | from .modules.misc import UpsampleX2, center_crop 7 | 8 | __all__ = ['g_xsrgan', 'd_xsrgan', 'd_xsrgan_ad'] 9 | 10 | def initialize_weights(net, scale=1.): 11 | if not isinstance(net, list): 12 | net = [net] 13 | for layer in net: 14 | for m in layer.modules(): 15 | if isinstance(m, nn.Conv2d): 16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 17 | m.weight.data *= scale # for residual block 18 | if m.bias is not None: 19 | m.bias.data.zero_() 20 | elif isinstance(m, nn.Linear): 21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 22 | m.weight.data *= scale 23 | if m.bias is not None: 24 | m.bias.data.zero_() 25 | elif isinstance(m, nn.BatchNorm2d): 26 | nn.init.constant_(m.weight, 1) 27 | nn.init.constant_(m.bias.data, 0.0) 28 | 29 | class GenBlock(nn.Module): 30 | def __init__(self, num_features=64, kernel_size=3, bias=False): 31 | super(GenBlock, self).__init__() 32 | self.xunit = xUnitS(num_features=num_features) 33 | self.conv1 = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias) 34 | self.conv2 = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias) 35 | 36 | initialize_weights([self.conv1, self.conv2], 0.1) 37 | 38 | def forward(self, x): 39 | residual = x 40 | x = self.conv2(self.xunit(self.conv1(x))) 41 | x = residual + x 42 | return x 43 | 44 | class DisBlock(nn.Module): 45 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False): 46 | super(DisBlock, self).__init__() 47 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 48 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias) 49 | 
self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias) 50 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True) 51 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True) 52 | 53 | initialize_weights([self.conv1, self.conv2], 0.1) 54 | 55 | if normalization: 56 | self.conv1 = SpectralNorm(self.conv1) 57 | self.conv2 = SpectralNorm(self.conv2) 58 | 59 | def forward(self, x): 60 | x = self.lrelu(self.bn1(self.conv1(x))) 61 | x = self.lrelu(self.bn2(self.conv2(x))) 62 | return x 63 | 64 | class Generator(nn.Module): 65 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale): 66 | super(Generator, self).__init__() 67 | self.scale = scale 68 | # image to features 69 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 70 | self.image_to_features = nn.Conv2d(in_channels=in_channels, out_channels=num_features, kernel_size=5, padding=2) 71 | 72 | # features 73 | blocks = [] 74 | for _ in range(gen_blocks): 75 | blocks.append(GenBlock(num_features=num_features)) 76 | self.features = nn.Sequential(*blocks) 77 | 78 | # upsampling 79 | blocks = [] 80 | for _ in range(int(log(scale, 2))): 81 | block = UpsampleX2(in_channels=num_features, out_channels=num_features) 82 | blocks.append(block) 83 | self.usample = nn.Sequential(*blocks) 84 | 85 | # features to image 86 | self.hrconv = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=3, padding=1) 87 | self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2) 88 | 89 | # init weights 90 | initialize_weights([self.image_to_features, self.hrconv, self.features_to_image], 0.1) 91 | 92 | def forward(self, x): 93 | r = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False) 94 | x = self.lrelu(self.image_to_features(x)) 95 | x = self.features(x) 96 | x = self.usample(x) 97 | x = self.features_to_image(self.lrelu(self.hrconv(x))) 98 | x = x + r 99 | return x 100 | 101 | class Discriminator(nn.Module): 102 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale): 103 | super(Discriminator, self).__init__() 104 | self.crop_size = 4 * pow(2, dis_blocks) 105 | 106 | # image to features 107 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False) 108 | 109 | # features 110 | blocks = [] 111 | for i in range(0, dis_blocks - 1): 112 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False)) 113 | self.features = nn.Sequential(*blocks) 114 | 115 | # classifier 116 | self.classifier = nn.Sequential( 117 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100), 118 | nn.LeakyReLU(negative_slope=0.1), 119 | nn.Linear(100, 1) 120 | ) 121 | 122 | def forward(self, x): 123 | x = center_crop(x, self.crop_size, self.crop_size) 124 | x = self.image_to_features(x) 125 | x = self.features(x) 126 | x = x.flatten(start_dim=1) 127 | x = self.classifier(x) 128 | return x 129 | 130 | class DiscriminatorAdaptive(nn.Module): 131 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale): 132 | super(DiscriminatorAdaptive, self).__init__() 133 | self.crop_size = 4 * pow(2, dis_blocks) 134 | 135 | # image to features 136 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False) 137 | 138 | # features 139 | blocks = 
[] 140 | for i in range(0, dis_blocks - 1): 141 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False)) 142 | self.features = nn.Sequential(*blocks) 143 | 144 | # pooling 145 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1) 146 | 147 | # classifier 148 | self.classifier = nn.Sequential( 149 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8), num_features), 150 | nn.LeakyReLU(negative_slope=0.1), 151 | nn.Linear(num_features, 1) 152 | ) 153 | 154 | def forward(self, x): 155 | x = center_crop(x, self.crop_size, self.crop_size) 156 | x = self.image_to_features(x) 157 | x = self.features(x) 158 | x = self.avg_pool(x) 159 | x = x.flatten(start_dim=1) 160 | x = self.classifier(x) 161 | return x 162 | 163 | class SNDiscriminator(nn.Module): 164 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale): 165 | super(SNDiscriminator, self).__init__() 166 | self.crop_size = 4 * pow(2, dis_blocks) 167 | 168 | # image to features 169 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True) 170 | 171 | # features 172 | blocks = [] 173 | for i in range(0, dis_blocks - 1): 174 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True)) 175 | self.features = nn.Sequential(*blocks) 176 | 177 | # pooling 178 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1) 179 | 180 | # classifier 181 | self.classifier = nn.Sequential( 182 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)), 183 | nn.LeakyReLU(negative_slope=0.1), 184 | SpectralNorm(nn.Linear(100, 1)) 185 | ) 186 | 187 | def forward(self, x): 188 | x = center_crop(x, self.crop_size, self.crop_size) 189 | x = self.image_to_features(x) 190 | x = self.features(x) 191 | x = x.flatten(start_dim=1) 192 | x = self.classifier(x) 193 | return x 194 | 195 | class SNDiscriminatorAdaptive(nn.Module): 196 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale): 197 | super(SNDiscriminatorAdaptive, self).__init__() 198 | self.crop_size = 4 * pow(2, dis_blocks) 199 | 200 | # image to features 201 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True) 202 | 203 | # features 204 | blocks = [] 205 | for i in range(0, dis_blocks - 1): 206 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True)) 207 | self.features = nn.Sequential(*blocks) 208 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1) # pooling, used by forward() below 209 | # classifier 210 | self.classifier = nn.Sequential( 211 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8), num_features)), 212 | nn.LeakyReLU(negative_slope=0.1), 213 | SpectralNorm(nn.Linear(num_features, 1)) 214 | ) 215 | 216 | def forward(self, x): 217 | x = center_crop(x, self.crop_size, self.crop_size) 218 | x = self.image_to_features(x) 219 | x = self.features(x) 220 | x = self.avg_pool(x) 221 | x = x.flatten(start_dim=1) 222 | x = self.classifier(x) 223 | return x 224 | 225 | def g_xsrgan(**config): 226 | config.setdefault('in_channels', 3) 227 | config.setdefault('num_features', 64) 228 | config.setdefault('gen_blocks', 10) 229 | config.setdefault('dis_blocks', 5) 230 | config.setdefault('scale', 4) 231 | 232 | _ = config.pop('spectral', False) 233 | 234 | return Generator(**config) 235 | 236 | def
d_xsrgan(**config): 237 | config.setdefault('in_channels', 3) 238 | config.setdefault('num_features', 64) 239 | config.setdefault('gen_blocks', 10) 240 | config.setdefault('dis_blocks', 5) 241 | config.setdefault('scale', 4) 242 | 243 | sn = config.pop('spectral', False) 244 | 245 | if sn: 246 | return SNDiscriminator(**config) 247 | else: 248 | return Discriminator(**config) 249 | 250 | def d_xsrgan_ad(**config): 251 | config.setdefault('in_channels', 3) 252 | config.setdefault('num_features', 64) 253 | config.setdefault('gen_blocks', 10) 254 | config.setdefault('dis_blocks', 5) 255 | config.setdefault('scale', 4) 256 | 257 | sn = config.pop('spectral', False) 258 | 259 | if sn: 260 | return SNDiscriminatorAdaptive(**config) 261 | else: 262 | return DiscriminatorAdaptive(**config) 263 | -------------------------------------------------------------------------------- /super-resolution/models/xvanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from math import log 5 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm 6 | from .modules.activations import xUnitS 7 | from .modules.misc import center_crop 8 | 9 | __all__ = ['g_xvanilla', 'd_xvanilla'] 10 | 11 | def initialize_weights(net, scale=1.): 12 | if not isinstance(net, list): 13 | net = [net] 14 | for layer in net: 15 | for m in layer.modules(): 16 | if isinstance(m, nn.Conv2d): 17 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 18 | m.weight.data *= scale # for residual block 19 | if m.bias is not None: 20 | m.bias.data.zero_() 21 | elif isinstance(m, nn.Linear): 22 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 23 | m.weight.data *= scale 24 | if m.bias is not None: 25 | m.bias.data.zero_() 26 | elif isinstance(m, nn.BatchNorm2d): 27 | nn.init.constant_(m.weight, 1) 28 | nn.init.constant_(m.bias.data, 0.0) 29 | 30 | class GenBlock(nn.Module): 31 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, bias=True): 32 | super(GenBlock, self).__init__() 33 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=(kernel_size // 2), bias=bias) 34 | self.bn = nn.BatchNorm2d(num_features=out_channels) 35 | self.xunit = xUnitS(num_features=out_channels, batch_norm=True) 36 | 37 | initialize_weights([self.conv, self.bn], 0.02) 38 | 39 | def forward(self, x): 40 | x = self.xunit(self.bn(self.conv(x))) 41 | return x 42 | 43 | class DisBlock(nn.Module): 44 | def __init__(self, in_channels=64, out_channels=64, bias=True, normalization=False): 45 | super(DisBlock, self).__init__() 46 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 47 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=bias) 48 | self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=bias) 49 | self.bn1 = nn.BatchNorm2d(out_channels, affine=True) 50 | self.bn2 = nn.BatchNorm2d(out_channels, affine=True) 51 | 52 | initialize_weights([self.conv1, self.conv2], 0.1) 53 | 54 | if normalization: 55 | self.conv1 = SpectralNorm(self.conv1) 56 | self.conv2 = SpectralNorm(self.conv2) 57 | 58 | def forward(self, x): 59 | x = self.lrelu(self.bn1(self.conv1(x))) 60 | x = self.lrelu(self.bn2(self.conv2(x))) 61 | return x 62 | 63 | class Generator(nn.Module): 64 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale):
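# pre-upsampling design: forward() first interpolates the input to the target size
# (bicubic), the conv + xUnitS trunk then refines it, and the interpolated image is
# added back as a global residual, so the trunk only has to learn the correction;
# dis_blocks is unused here and is accepted only to keep the g_/d_ factory signatures uniform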
65 | super(Generator, self).__init__() 66 | self.scale_factor = scale 67 | 68 | # image to features 69 | self.image_to_features = GenBlock(in_channels=in_channels, out_channels=num_features) 70 | 71 | # features 72 | blocks = [] 73 | for _ in range(gen_blocks): 74 | blocks.append(GenBlock(in_channels=num_features, out_channels=num_features)) 75 | self.features = nn.Sequential(*blocks) 76 | 77 | # features to image 78 | self.features_to_image = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=5, padding=2) 79 | initialize_weights([self.features_to_image], 0.02) 80 | 81 | def forward(self, x): 82 | x = F.interpolate(x, size=(self.scale_factor * x.size(-2), self.scale_factor * x.size(-1)), mode='bicubic', align_corners=True) 83 | r = x 84 | x = self.image_to_features(x) 85 | x = self.features(x) 86 | x = self.features_to_image(x) 87 | x += r 88 | return x 89 | 90 | class Discriminator(nn.Module): 91 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale): 92 | super(Discriminator, self).__init__() 93 | self.crop_size = 4 * pow(2, dis_blocks) 94 | 95 | # image to features 96 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=False) 97 | 98 | # features 99 | blocks = [] 100 | for i in range(0, dis_blocks - 1): 101 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=False)) 102 | self.features = nn.Sequential(*blocks) 103 | 104 | # classifier 105 | self.classifier = nn.Sequential( 106 | nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100), 107 | nn.LeakyReLU(negative_slope=0.1), 108 | nn.Linear(100, 1) 109 | ) 110 | 111 | def forward(self, x): 112 | x = center_crop(x, self.crop_size, self.crop_size) 113 | x = self.image_to_features(x) 114 | x = self.features(x) 115 | x = x.flatten(start_dim=1) 116 | x = self.classifier(x) 117 | return x 118 | 119 | class SNDiscriminator(nn.Module): 120 | def __init__(self, in_channels, num_features, gen_blocks, dis_blocks, scale): 121 | super(SNDiscriminator, self).__init__() 122 | self.crop_size = 4 * pow(2, dis_blocks) 123 | 124 | # image to features 125 | self.image_to_features = DisBlock(in_channels=in_channels, out_channels=num_features, bias=True, normalization=True) 126 | 127 | # features 128 | blocks = [] 129 | for i in range(0, dis_blocks - 1): 130 | blocks.append(DisBlock(in_channels=num_features * min(pow(2, i), 8), out_channels=num_features * min(pow(2, i + 1), 8), bias=False, normalization=True)) 131 | self.features = nn.Sequential(*blocks) 132 | 133 | # classifier 134 | self.classifier = nn.Sequential( 135 | SpectralNorm(nn.Linear(num_features * min(pow(2, dis_blocks - 1), 8) * 4 * 4, 100)), 136 | nn.LeakyReLU(negative_slope=0.1), 137 | SpectralNorm(nn.Linear(100, 1)) 138 | ) 139 | 140 | def forward(self, x): 141 | x = center_crop(x, self.crop_size, self.crop_size) 142 | x = self.image_to_features(x) 143 | x = self.features(x) 144 | x = x.flatten(start_dim=1) 145 | x = self.classifier(x) 146 | return x 147 | 148 | def g_xvanilla(**config): 149 | config.setdefault('in_channels', 3) 150 | config.setdefault('num_features', 64) 151 | config.setdefault('gen_blocks', 8) 152 | config.setdefault('dis_blocks', 5) 153 | config.setdefault('scale', 4) 154 | 155 | return Generator(**config) 156 | 157 | def d_xvanilla(**config): 158 | config.setdefault('in_channels', 3) 159 | config.setdefault('num_features', 64) 160 | 
config.setdefault('gen_blocks', 8) 161 | config.setdefault('dis_blocks', 5) 162 | config.setdefault('scale', 4) 163 | 164 | return Discriminator(**config) -------------------------------------------------------------------------------- /super-resolution/scripts/matlab/bicubic_subsample.m: -------------------------------------------------------------------------------- 1 | function bicubic_subsample() 2 | %% MATLAB code to generate mod images, bicubic-downsampled LR images, and bicubic-upsampled images. 3 | 4 | %% set parameters 5 | % comment out the folders that are not needed 6 | input_folder = ''; 7 | save_mod_folder = ''; 8 | save_LR_folder = ''; 9 | % save_bic_folder = ''; 10 | 11 | up_scale = 3; 12 | mod_scale = up_scale; 13 | 14 | if exist('save_mod_folder', 'var') 15 | if exist(save_mod_folder, 'dir') 16 | disp(['It will cover ', save_mod_folder]); 17 | else 18 | mkdir(save_mod_folder); 19 | end 20 | end 21 | if exist('save_LR_folder', 'var') 22 | if exist(save_LR_folder, 'dir') 23 | disp(['It will cover ', save_LR_folder]); 24 | else 25 | mkdir(save_LR_folder); 26 | end 27 | end 28 | if exist('save_bic_folder', 'var') 29 | if exist(save_bic_folder, 'dir') 30 | disp(['It will cover ', save_bic_folder]); 31 | else 32 | mkdir(save_bic_folder); 33 | end 34 | end 35 | 36 | idx = 0; 37 | filepaths = dir(fullfile(input_folder,'*.*')); 38 | for i = 1 : length(filepaths) 39 | [paths,imname,ext] = fileparts(filepaths(i).name); 40 | if isempty(imname) 41 | disp('Ignore . folder.'); 42 | elseif strcmp(imname, '.') 43 | disp('Ignore .. folder.'); 44 | else 45 | idx = idx + 1; 46 | str_rlt = sprintf('%d\t%s.\n', idx, imname); 47 | fprintf(str_rlt); 48 | % read image 49 | img = imread(fullfile(input_folder, [imname, ext])); 50 | img = im2double(img); 51 | % modcrop 52 | img = modcrop(img, mod_scale); 53 | if exist('save_mod_folder', 'var') 54 | imwrite(img, fullfile(save_mod_folder, [imname, '.png'])); 55 | end 56 | % LR 57 | im_LR = imresize(img, 1/up_scale, 'bicubic'); 58 | if exist('save_LR_folder', 'var') 59 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); 60 | end 61 | % Bicubic 62 | if exist('save_bic_folder', 'var') 63 | im_B = imresize(im_LR, up_scale, 'bicubic'); 64 | imwrite(im_B, fullfile(save_bic_folder, [imname, sprintf('_bicx%d.png', up_scale)])); 65 | end 66 | end 67 | end 68 | end 69 | 70 | %% modcrop 71 | function img = modcrop(img, modulo) 72 | if size(img,3) == 1 73 | sz = size(img); 74 | sz = sz - mod(sz, modulo); 75 | if mod(modulo, 1) == 0 76 | img = img(1:sz(1), 1:sz(2)); 77 | end 78 | else 79 | tmpsz = size(img); 80 | sz = tmpsz(1:2); 81 | sz = sz - mod(sz, modulo); 82 | if mod(modulo, 1) == 0 83 | img = img(1:sz(1), 1:sz(2),:); 84 | end 85 | end 86 | 87 | end 88 | -------------------------------------------------------------------------------- /super-resolution/scripts/python/extract_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | from multiprocessing import Pool 7 | 8 | def get_arguments(): 9 | parser = argparse.ArgumentParser(description='Extract sub images for faster data loading.') 10 | parser.add_argument('--root', default='', required=True, help='source folder') 11 | parser.add_argument('--crop-size', default=256, type=int, help='crop size (default: 256)') 12 | parser.add_argument('--step-size', default=128, type=int, help='step size (default: 128)') 13 | parser.add_argument('--scale', default=4, type=int, help='super-resolution scale (default: 4)') 14 |
parser.add_argument('--num-threads', default=8, type=int, help='num of threads (default: 8)') 15 | parser.add_argument('--only-once', default=False, action='store_true', help='center sub image') 16 | args = parser.parse_args() 17 | return args 18 | 19 | def mkdir(save_path): 20 | if not os.path.exists(save_path): 21 | os.makedirs(save_path) 22 | 23 | def mkdirs(args): 24 | mkdir(os.path.join(args.root,'img_sub')) 25 | mkdir(os.path.join(args.root,'img_sub_x{}'.format(args.scale))) 26 | 27 | def get_image_paths(args): 28 | paths = glob.glob(os.path.join(args.root, 'img_x{}'.format(args.scale), '*.*')) 29 | return paths 30 | 31 | def save_images(args, i, path, input, target): 32 | save_path = path.replace('img_x', 'img_sub_x').replace('.png', '_{}.png'.format(i)) 33 | input.save(save_path) 34 | save_path = path.replace('img_x{}'.format(args.scale), 'img_sub').replace('.png', '_{}.png'.format(i)) 35 | target.save(save_path) 36 | 37 | def load_images(path, args): 38 | input = Image.open(path) 39 | target = Image.open(path.replace('img_x{}'.format(args.scale), 'img')) 40 | return input, target 41 | 42 | def worker(path, args): 43 | input, target = load_images(path, args) 44 | 45 | w, h = input.size # PIL size is (width, height) 46 | ws = np.arange(0, w - args.crop_size + 1, args.step_size) 47 | hs = np.arange(0, h - args.crop_size + 1, args.step_size) 48 | 49 | ws = np.append(ws, w - args.crop_size) # make sure the right edge is covered 50 | hs = np.append(hs, h - args.crop_size) # make sure the bottom edge is covered 51 | 52 | counts = 0 53 | 54 | if args.only_once: 55 | ws = [ws[len(ws) // 2]] 56 | hs = [hs[len(hs) // 2]] 57 | 58 | for x in ws: 59 | for y in hs: 60 | cropped_input = input.crop((x, y, (x + args.crop_size), (y + args.crop_size))) 61 | cropped_target = target.crop((x * args.scale, y * args.scale, (x + args.crop_size) * args.scale, (y + args.crop_size) * args.scale)) 62 | save_images(args, counts, path, cropped_input, cropped_target) 63 | counts += 1 64 | 65 | print('Processed {:s}'.format(os.path.basename(path))) 66 | 67 | def main(): 68 | args = get_arguments() 69 | 70 | mkdirs(args) 71 | paths = get_image_paths(args) 72 | pool = Pool(args.num_threads) 73 | 74 | for path in paths: 75 | pool.apply_async(worker, args=(path, args)) 76 | 77 | pool.close() 78 | pool.join() 79 | print('All subprocesses done.') 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /super-resolution/scripts/python/remove_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | from PIL import Image, ImageStat 5 | from multiprocessing import Pool 6 | 7 | def get_arguments(): 8 | parser = argparse.ArgumentParser(description='Remove images in dataset') 9 | parser.add_argument('--root', default='', required=True, help='source folder') 10 | parser.add_argument('--scale', default=4, type=int, help='super-resolution scale (default: 4)') 11 | parser.add_argument('--limit-size', default=50, type=int, help='limit size for input (default: 50)') 12 | parser.add_argument('--num-threads', default=8, type=int, help='num of threads (default: 8)') 13 | args = parser.parse_args() 14 | return args 15 | 16 | def is_grayscale(image): 17 | stat = ImageStat.Stat(image) 18 | 19 | if sum(stat.sum) / 3 == stat.sum[0]: 20 | return True 21 | else: 22 | return False 23 | 24 | def is_small(image, args): 25 | w, h = image.size # PIL size is (width, height) 26 | 27 | if w < args.limit_size or h < args.limit_size: 28 | return True 29 | else: 30 | return False 31 | 32 | def remove(path): 33
| if os.path.isfile(path): 34 | os.remove(path) 35 | 36 | def get_image_paths(args): 37 | paths = glob.glob(os.path.join(args.root, 'img_x{}'.format(args.scale), '*.*')) 38 | return paths 39 | 40 | def get_paths(path, args): 41 | input = path 42 | target = path.replace('img_x{}'.format(args.scale), 'img') 43 | return input, target 44 | 45 | def worker(path, args): 46 | input, target = get_paths(path, args) 47 | 48 | image = Image.open(input).convert("RGB") 49 | 50 | if is_grayscale(image) or is_small(image, args): 51 | remove(input) 52 | remove(target) 53 | 54 | print('Removed {:s}'.format(os.path.basename(path))) 55 | 56 | def main(): 57 | args = get_arguments() 58 | 59 | paths = get_image_paths(args) 60 | pool = Pool(args.num_threads) 61 | 62 | for path in paths: 63 | pool.apply_async(worker, args=(path, args)) 64 | 65 | pool.close() 66 | pool.join() 67 | print('All subprocesses done.') 68 | 69 | if __name__ == "__main__": 70 | main() -------------------------------------------------------------------------------- /super-resolution/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import models 3 | import os 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.optim.lr_scheduler import StepLR 7 | from torch.autograd import grad as torch_grad, Variable 8 | from data import get_loaders 9 | from ast import literal_eval 10 | from utils.recorderx import RecoderX 11 | from utils.misc import save_image, average, mkdir, compute_psnr 12 | from models.modules.losses import RangeLoss, PerceptualLoss, StyleLoss 13 | 14 | class Trainer(): 15 | def __init__(self, args): 16 | # parameters 17 | self.args = args 18 | self.print_model = True 19 | self.invalidity_margins = None 20 | 21 | if self.args.use_tb: 22 | self.tb = RecoderX(log_dir=args.save_path) 23 | 24 | # initialize 25 | self._init() 26 | 27 | def _init_model(self): 28 | # initialize model 29 | if self.args.model_config != '': 30 | model_config = dict({}, **literal_eval(self.args.model_config)) 31 | else: 32 | model_config = {} 33 | 34 | g_model = models.__dict__[self.args.g_model] 35 | d_model = models.__dict__[self.args.d_model] 36 | self.g_model = g_model(**model_config) 37 | self.d_model = d_model(**model_config) 38 | 39 | # loading weights 40 | if self.args.gen_to_load != '': 41 | logging.info('\nLoading g-model...') 42 | self.g_model.load_state_dict(torch.load(self.args.gen_to_load, map_location='cpu')) 43 | if self.args.dis_to_load != '': 44 | logging.info('\nLoading d-model...') 45 | self.d_model.load_state_dict(torch.load(self.args.dis_to_load, map_location='cpu')) 46 | 47 | # to cuda 48 | self.g_model = self.g_model.to(self.args.device) 49 | self.d_model = self.d_model.to(self.args.device) 50 | 51 | # parallel 52 | if self.args.device_ids and len(self.args.device_ids) > 1: 53 | self.g_model = torch.nn.DataParallel(self.g_model, self.args.device_ids) 54 | self.d_model = torch.nn.DataParallel(self.d_model, self.args.device_ids) 55 | 56 | # print model 57 | if self.print_model: 58 | logging.info(self.g_model) 59 | logging.info('Number of parameters in generator: {}\n'.format(sum([l.nelement() for l in self.g_model.parameters()]))) 60 | logging.info(self.d_model) 61 | logging.info('Number of parameters in discriminator: {}\n'.format(sum([l.nelement() for l in self.d_model.parameters()]))) 62 | self.print_model = False 63 | 64 | def _init_optim(self): 65 | # initialize optimizer 66 | self.g_optimizer = torch.optim.Adam(self.g_model.parameters(), 
lr=self.args.lr, betas=self.args.gen_betas) 67 | self.d_optimizer = torch.optim.Adam(self.d_model.parameters(), lr=self.args.lr, betas=self.args.dis_betas) 68 | 69 | # initialize scheduler 70 | self.g_scheduler = StepLR(self.g_optimizer, step_size=self.args.step_size, gamma=self.args.gamma) 71 | self.d_scheduler = StepLR(self.d_optimizer, step_size=self.args.step_size, gamma=self.args.gamma) 72 | 73 | # initialize criterion 74 | if self.args.reconstruction_weight: 75 | self.reconstruction = torch.nn.L1Loss().to(self.args.device) 76 | if self.args.perceptual_weight > 0.: 77 | self.perceptual = PerceptualLoss(features_to_compute=['conv5_4'], criterion=torch.nn.L1Loss(), shave_edge=self.invalidity_margins).to(self.args.device) 78 | if self.args.style_weight > 0.: 79 | self.style = StyleLoss(features_to_compute=['relu3_1', 'relu2_1']).to(self.args.device) 80 | if self.args.range_weight > 0.: 81 | self.range = RangeLoss(invalidity_margins=self.invalidity_margins).to(self.args.device) 82 | 83 | def _init(self): 84 | # init parameters 85 | self.step = 0 86 | self.losses = {'D': [], 'D_r': [], 'D_gp': [], 'D_f': [], 'G': [], 'G_recon': [], 'G_rng': [], 'G_perc': [], 'G_sty': [], 'G_adv': [], 'psnr': []} 87 | 88 | # initialize model 89 | self._init_model() 90 | 91 | # initialize optimizer 92 | self._init_optim() 93 | 94 | def _save_model(self, epoch): 95 | # save models 96 | torch.save(self.g_model.state_dict(), os.path.join(self.args.save_path, '{}_e{}.pt'.format(self.args.g_model, epoch + 1))) 97 | torch.save(self.d_model.state_dict(), os.path.join(self.args.save_path, '{}_e{}.pt'.format(self.args.d_model, epoch + 1))) 98 | 99 | def _set_require_grads(self, model, require_grad): 100 | for p in model.parameters(): 101 | p.requires_grad_(require_grad) 102 | 103 | def _critic_hinge_iteration(self, inputs, targets): 104 | # require grads 105 | self._set_require_grads(self.d_model, True) 106 | 107 | # get generated data 108 | generated_data = self.g_model(inputs) 109 | 110 | # zero grads 111 | self.d_optimizer.zero_grad() 112 | 113 | # calculate probabilities on real and generated data 114 | d_real = self.d_model(targets) 115 | d_generated = self.d_model(generated_data.detach()) 116 | 117 | # create total loss and optimize 118 | loss_r = F.relu(1.0 - d_real).mean() 119 | loss_f = F.relu(1.0 + d_generated).mean() 120 | loss = loss_r + loss_f 121 | 122 | # get gradient penalty 123 | if self.args.penalty_weight > 0.: 124 | gradient_penalty = self._gradient_penalty(targets, generated_data) 125 | loss += gradient_penalty 126 | 127 | loss.backward() 128 | 129 | self.d_optimizer.step() 130 | 131 | # record loss 132 | self.losses['D'].append(loss.data.item()) 133 | self.losses['D_r'].append(loss_r.data.item()) 134 | self.losses['D_f'].append(loss_f.data.item()) 135 | if self.args.penalty_weight > 0.: 136 | self.losses['D_gp'].append(gradient_penalty.data.item()) 137 | 138 | # require grads 139 | self._set_require_grads(self.d_model, False) 140 | 141 | def _critic_wgan_iteration(self, inputs, targets): 142 | # require grads 143 | self._set_require_grads(self.d_model, True) 144 | 145 | # get generated data 146 | generated_data = self.g_model(inputs) 147 | 148 | # zero grads 149 | self.d_optimizer.zero_grad() 150 | 151 | # calculate probabilities on real and generated data 152 | d_real = self.d_model(targets) 153 | d_generated = self.d_model(generated_data.detach()) 154 | 155 | # create total loss and optimize 156 | if self.args.relativistic: 157 | loss_r = -(d_real - d_generated.mean()).mean() 158 | loss_f = 
(d_generated - d_real.mean()).mean() 159 | else: 160 | loss_r = -d_real.mean() 161 | loss_f = d_generated.mean() 162 | loss = loss_f + loss_r 163 | 164 | # get gradient penalty 165 | if self.args.penalty_weight > 0.: 166 | gradient_penalty = self._gradient_penalty(targets, generated_data) 167 | loss += gradient_penalty 168 | 169 | loss.backward() 170 | 171 | self.d_optimizer.step() 172 | 173 | # record loss 174 | self.losses['D'].append(loss.data.item()) 175 | self.losses['D_r'].append(loss_r.data.item()) 176 | self.losses['D_f'].append(loss_f.data.item()) 177 | if self.args.penalty_weight > 0.: 178 | self.losses['D_gp'].append(gradient_penalty.data.item()) 179 | 180 | # require grads 181 | self._set_require_grads(self.d_model, False) 182 | 183 | def _gradient_penalty(self, real_data, generated_data): 184 | batch_size = real_data.size(0) 185 | 186 | # calculate interpolation 187 | alpha = torch.rand(batch_size, 1, 1, 1) 188 | alpha = alpha.expand_as(real_data) 189 | alpha = alpha.to(self.args.device) 190 | interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data 191 | interpolated = Variable(interpolated, requires_grad=True) 192 | interpolated = interpolated.to(self.args.device) 193 | 194 | # calculate probability of interpolated examples 195 | prob_interpolated = self.d_model(interpolated) 196 | 197 | # calculate gradients of probabilities with respect to examples 198 | gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated, 199 | grad_outputs=torch.ones(prob_interpolated.size()).to(self.args.device), 200 | create_graph=True, retain_graph=True)[0] 201 | 202 | # gradients have shape (batch_size, num_channels, img_width, img_height), 203 | # so flatten to easily take norm per example in batch 204 | gradients = gradients.view(batch_size, -1) 205 | 206 | # derivatives of the gradient close to 0 can cause problems because of 207 | # the square root, so manually calculate norm and add epsilon 208 | gradients_norm = gradients.norm(p=2, dim=1) 209 | 210 | # return gradient penalty 211 | return ((gradients_norm - 1) ** 2).mean() 212 | 213 | def _generator_iteration(self, inputs, targets): 214 | # zero grads 215 | self.g_optimizer.zero_grad() 216 | 217 | # get generated data 218 | generated_data = self.g_model(inputs) 219 | loss = 0. 
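# the total objective accumulated below is a weighted sum,
#   loss = reconstruction_weight * L1 + range_weight * range + adversarial_weight * adv
#          + perceptual_weight * perceptual + style_weight * style,
# where each term is computed and logged only when its weight is positive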
220 | 221 | # reconstruction loss 222 | if self.args.reconstruction_weight > 0.: 223 | loss_recon = self.reconstruction(generated_data, targets) 224 | loss += loss_recon * self.args.reconstruction_weight 225 | self.losses['G_recon'].append(loss_recon.data.item()) 226 | 227 | # range loss 228 | if self.args.range_weight > 0.: 229 | loss_rng = self.range(generated_data) 230 | loss += loss_rng * self.args.range_weight 231 | self.losses['G_rng'].append(loss_rng.data.item()) 232 | 233 | # adversarial loss 234 | if self.args.adversarial_weight > 0.: 235 | d_generated = self.d_model(generated_data) 236 | if self.args.relativistic: 237 | d_real = self.d_model(targets) 238 | loss_adv = (d_real - d_generated.mean()).mean() - (d_generated - d_real.mean()).mean() 239 | else: 240 | loss_adv = -d_generated.mean() 241 | loss += loss_adv * self.args.adversarial_weight 242 | self.losses['G_adv'].append(loss_adv.data.item()) 243 | 244 | # perceptual loss 245 | if self.args.perceptual_weight > 0.: 246 | loss_perc = self.perceptual(generated_data, targets) 247 | loss += loss_perc * self.args.perceptual_weight 248 | self.losses['G_perc'].append(loss_perc.data.item()) 249 | 250 | # style loss 251 | if self.args.style_weight > 0.: 252 | loss_sty = self.style(generated_data, targets) 253 | loss += loss_sty * self.args.style_weight 254 | self.losses['G_sty'].append(loss_sty.data.item()) 255 | 256 | # backward loss 257 | loss.backward() 258 | self.g_optimizer.step() 259 | 260 | # record loss 261 | self.losses['G'].append(loss.data.item()) 262 | 263 | def _train_iteration(self, data): 264 | # set inputs 265 | inputs = data['input'].to(self.args.device) 266 | targets = data['target'].to(self.args.device) 267 | 268 | # critic iteration 269 | if self.args.adversarial_weight > 0.: 270 | if self.args.wgan or self.args.relativistic: 271 | self._critic_wgan_iteration(inputs, targets) 272 | else: 273 | self._critic_hinge_iteration(inputs, targets) 274 | 275 | # only update generator every |critic_iterations| iterations 276 | if self.step % self.args.num_critic == 0: 277 | self._generator_iteration(inputs, targets) 278 | 279 | # logging 280 | if self.step % self.args.print_every == 0: 281 | line2print = 'Iteration {}'.format(self.step) 282 | if self.args.adversarial_weight > 0.: 283 | line2print += ', D: {:.6f}, D_r: {:.6f}, D_f: {:.6f}'.format(self.losses['D'][-1], self.losses['D_r'][-1], self.losses['D_f'][-1]) 284 | if self.args.penalty_weight > 0.: 285 | line2print += ', D_gp: {:.6f}'.format(self.losses['D_gp'][-1]) 286 | if self.step > self.args.num_critic: 287 | line2print += ', G: {:.5f}'.format(self.losses['G'][-1]) 288 | if self.args.reconstruction_weight: 289 | line2print += ', G_recon: {:.6f}'.format(self.losses['G_recon'][-1]) 290 | if self.args.range_weight: 291 | line2print += ', G_rng: {:.6f}'.format(self.losses['G_rng'][-1]) 292 | if self.args.perceptual_weight: 293 | line2print += ', G_perc: {:.6f}'.format(self.losses['G_perc'][-1]) 294 | if self.args.style_weight: 295 | line2print += ', G_sty: {:.8f}'.format(self.losses['G_sty'][-1]) 296 | if self.args.adversarial_weight: 297 | line2print += ', G_adv: {:.6f},'.format(self.losses['G_adv'][-1]) 298 | logging.info(line2print) 299 | 300 | # plots for tensorboard 301 | if self.args.use_tb: 302 | if self.args.adversarial_weight > 0.: 303 | self.tb.add_scalar('data/loss_d', self.losses['D'][-1], self.step) 304 | if self.step > self.args.num_critic: 305 | self.tb.add_scalar('data/loss_g', self.losses['G'][-1], self.step) 306 | 307 | def _eval_iteration(self, data, 
epoch): 308 | # set inputs 309 | inputs = data['input'].to(self.args.device) 310 | targets = data['target'] 311 | paths = data['path'] 312 | 313 | # evaluation 314 | with torch.no_grad(): 315 | outputs = self.g_model(inputs) 316 | 317 | # save image and compute psnr 318 | self._save_image(outputs, paths[0], epoch + 1) 319 | psnr = compute_psnr(outputs, targets, self.args.scale) 320 | 321 | return psnr 322 | 323 | def _train_epoch(self, loader): 324 | self.g_model.train() 325 | self.d_model.train() 326 | 327 | # train over epochs 328 | for _, data in enumerate(loader): 329 | self._train_iteration(data) 330 | self.step += 1 331 | 332 | def _eval_epoch(self, loader, epoch): 333 | self.g_model.eval() 334 | psnrs = [] 335 | 336 | # eval over epoch 337 | for _, data in enumerate(loader): 338 | psnr = self._eval_iteration(data, epoch) 339 | psnrs.append(psnr) 340 | 341 | # record psnr 342 | self.losses['psnr'].append(average(psnrs)) 343 | logging.info('Evaluation: {:.3f}'.format(self.losses['psnr'][-1])) 344 | if self.args.use_tb: 345 | self.tb.add_scalar('data/psnr', self.losses['psnr'][-1], epoch) 346 | 347 | def _save_image(self, image, path, epoch): 348 | directory = os.path.join(self.args.save_path, 'images', 'epoch_{}'.format(epoch)) 349 | save_path = os.path.join(directory, os.path.basename(path)) 350 | mkdir(directory) 351 | save_image(image.data.cpu(), save_path) 352 | 353 | def _train(self, loaders): 354 | # run epoch iterations 355 | for epoch in range(self.args.epochs): 356 | logging.info('\nEpoch {}'.format(epoch + 1)) 357 | 358 | # train 359 | self._train_epoch(loaders['train']) 360 | 361 | # scheduler 362 | self.g_scheduler.step(epoch=epoch) 363 | self.d_scheduler.step(epoch=epoch) 364 | 365 | # evaluation 366 | if ((epoch + 1) % self.args.eval_every == 0) or ((epoch + 1) == self.args.epochs): 367 | self._eval_epoch(loaders['eval'], epoch) 368 | self._save_model(epoch) 369 | 370 | # best score 371 | logging.info('Best PSNR Score: {:.2f}\n'.format(max(self.losses['psnr']))) 372 | 373 | def train(self): 374 | # get loader 375 | loaders = get_loaders(self.args) 376 | 377 | # run training 378 | self._train(loaders) 379 | 380 | # close tensorboard 381 | if self.args.use_tb: 382 | self.tb.close() 383 | 384 | def eval(self): 385 | # get loader 386 | loaders = get_loaders(self.args) 387 | 388 | # evaluation 389 | logging.info('\nEvaluating...') 390 | self._eval_epoch(loaders['eval'], 0) 391 | 392 | # close tensorboard 393 | if self.args.use_tb: 394 | self.tb.close() -------------------------------------------------------------------------------- /super-resolution/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kligvasser/xUnit/f8773f9f73a8990b03a09b8590d9d195c2104d53/super-resolution/utils/__init__.py -------------------------------------------------------------------------------- /super-resolution/utils/core.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/thstkdgus35/bicubic_pytorch/blob/master/core.py 3 | 4 | A standalone PyTorch implementation for fast and efficient bicubic resampling. 5 | The resulting values are the same to MATLAB function imresize('bicubic'). 
6 | ## Author: Sanghyun Son 7 | ## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary) 8 | ## Version: 1.1.0 9 | ## Last update: July 9th, 2020 (KST) 10 | Depencency: torch 11 | Example:: 12 | >>> import torch 13 | >>> import core 14 | >>> x = torch.arange(16).float().view(1, 1, 4, 4) 15 | >>> y = core.imresize(x, sides=(3, 3)) 16 | >>> print(y) 17 | tensor([[[[ 0.7506, 2.1004, 3.4503], 18 | [ 6.1505, 7.5000, 8.8499], 19 | [11.5497, 12.8996, 14.2494]]]]) 20 | ''' 21 | 22 | import math 23 | import typing 24 | 25 | import torch 26 | from torch.nn import functional as F 27 | 28 | __all__ = ['imresize'] 29 | 30 | K = typing.TypeVar('K', str, torch.Tensor) 31 | 32 | def cubic_contribution(x: torch.Tensor, a: float=-0.5) -> torch.Tensor: 33 | ax = x.abs() 34 | ax2 = ax * ax 35 | ax3 = ax * ax2 36 | 37 | range_01 = (ax <= 1) 38 | range_12 = (ax > 1) * (ax <= 2) 39 | 40 | cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1 41 | cont_01 = cont_01 * range_01.to(dtype=x.dtype) 42 | 43 | cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a) 44 | cont_12 = cont_12 * range_12.to(dtype=x.dtype) 45 | 46 | cont = cont_01 + cont_12 47 | cont = cont / cont.sum() 48 | return cont 49 | 50 | def gaussian_contribution(x: torch.Tensor, sigma: float=2.0) -> torch.Tensor: 51 | range_3sigma = (x.abs() <= 3 * sigma + 1) 52 | # Normalization will be done after 53 | cont = torch.exp(-x.pow(2) / (2 * sigma**2)) 54 | cont = cont * range_3sigma.to(dtype=x.dtype) 55 | return cont 56 | 57 | def discrete_kernel( 58 | kernel: str, scale: float, antialiasing: bool=True) -> torch.Tensor: 59 | 60 | ''' 61 | For downsampling with integer scale only. 62 | ''' 63 | downsampling_factor = int(1 / scale) 64 | if kernel == 'cubic': 65 | kernel_size_orig = 4 66 | else: 67 | raise ValueError('Pass!') 68 | 69 | if antialiasing: 70 | kernel_size = kernel_size_orig * downsampling_factor 71 | else: 72 | kernel_size = kernel_size_orig 73 | 74 | if downsampling_factor % 2 == 0: 75 | a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size)) 76 | else: 77 | kernel_size -= 1 78 | a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1)) 79 | 80 | with torch.no_grad(): 81 | r = torch.linspace(-a, a, steps=kernel_size) 82 | k = cubic_contribution(r).view(-1, 1) 83 | k = torch.matmul(k, k.t()) 84 | k /= k.sum() 85 | 86 | return k 87 | 88 | def reflect_padding( 89 | x: torch.Tensor, 90 | dim: int, 91 | pad_pre: int, 92 | pad_post: int) -> torch.Tensor: 93 | 94 | ''' 95 | Apply reflect padding to the given Tensor. 96 | Note that it is slightly different from the PyTorch functional.pad, 97 | where boundary elements are used only once. 98 | Instead, we follow the MATLAB implementation 99 | which uses boundary elements twice. 100 | For example, 101 | [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation, 102 | while our implementation yields [a, a, b, c, d, d]. 
103 | ''' 104 | b, c, h, w = x.size() 105 | if dim == 2 or dim == -2: 106 | padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w) 107 | padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x) 108 | for p in range(pad_pre): 109 | padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :]) 110 | for p in range(pad_post): 111 | padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :]) 112 | else: 113 | padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post) 114 | padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x) 115 | for p in range(pad_pre): 116 | padding_buffer[..., pad_pre - p - 1].copy_(x[..., p]) 117 | for p in range(pad_post): 118 | padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)]) 119 | 120 | return padding_buffer 121 | 122 | def padding( 123 | x: torch.Tensor, 124 | dim: int, 125 | pad_pre: int, 126 | pad_post: int, 127 | padding_type: str='reflect') -> torch.Tensor: 128 | 129 | if padding_type == 'reflect': 130 | x_pad = reflect_padding(x, dim, pad_pre, pad_post) 131 | else: 132 | raise ValueError('{} padding is not supported!'.format(padding_type)) 133 | 134 | return x_pad 135 | 136 | def get_padding( 137 | base: torch.Tensor, 138 | kernel_size: int, 139 | x_size: int) -> typing.Tuple[int, int, torch.Tensor]: 140 | 141 | base = base.long() 142 | r_min = base.min() 143 | r_max = base.max() + kernel_size - 1 144 | 145 | if r_min <= 0: 146 | pad_pre = -r_min 147 | pad_pre = pad_pre.item() 148 | base += pad_pre 149 | else: 150 | pad_pre = 0 151 | 152 | if r_max >= x_size: 153 | pad_post = r_max - x_size + 1 154 | pad_post = pad_post.item() 155 | else: 156 | pad_post = 0 157 | 158 | return pad_pre, pad_post, base 159 | 160 | def get_weight( 161 | dist: torch.Tensor, 162 | kernel_size: int, 163 | kernel: str='cubic', 164 | sigma: float=2.0, 165 | antialiasing_factor: float=1) -> torch.Tensor: 166 | 167 | buffer_pos = dist.new_zeros(kernel_size, len(dist)) 168 | for idx, buffer_sub in enumerate(buffer_pos): 169 | buffer_sub.copy_(dist - idx) 170 | 171 | # Expand (downsampling) / Shrink (upsampling) the receptive field. 172 | buffer_pos *= antialiasing_factor 173 | if kernel == 'cubic': 174 | weight = cubic_contribution(buffer_pos) 175 | elif kernel == 'gaussian': 176 | weight = gaussian_contribution(buffer_pos, sigma=sigma) 177 | else: 178 | raise ValueError('{} kernel is not supported!'.format(kernel)) 179 | 180 | weight /= weight.sum(dim=0, keepdim=True) 181 | return weight 182 | 183 | def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor: 184 | # Resize height 185 | if dim == 2 or dim == -2: 186 | k = (kernel_size, 1) 187 | h_out = x.size(-2) - kernel_size + 1 188 | w_out = x.size(-1) 189 | # Resize width 190 | else: 191 | k = (1, kernel_size) 192 | h_out = x.size(-2) 193 | w_out = x.size(-1) - kernel_size + 1 194 | 195 | unfold = F.unfold(x, k) 196 | unfold = unfold.view(unfold.size(0), -1, h_out, w_out) 197 | return unfold 198 | 199 | def resize_1d( 200 | x: torch.Tensor, 201 | dim: int, 202 | side: int=None, 203 | kernel: str='cubic', 204 | sigma: float=2.0, 205 | padding_type: str='reflect', 206 | antialiasing: bool=True) -> torch.Tensor: 207 | 208 | ''' 209 | Args: 210 | x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W). 
211 | dim (int): 212 | scale (float): 213 | side (int): 214 | Return: 215 | ''' 216 | scale = side / x.size(dim) 217 | # Identity case 218 | if scale == 1: 219 | return x 220 | 221 | # Default bicubic kernel with antialiasing (only when downsampling) 222 | if kernel == 'cubic': 223 | kernel_size = 4 224 | else: 225 | kernel_size = math.floor(6 * sigma) 226 | 227 | if antialiasing and (scale < 1): 228 | antialiasing_factor = scale 229 | kernel_size = math.ceil(kernel_size / antialiasing_factor) 230 | else: 231 | antialiasing_factor = 1 232 | 233 | # We allow margin to both sides 234 | kernel_size += 2 235 | 236 | # Weights only depend on the shape of input and output, 237 | # so we do not calculate gradients here. 238 | with torch.no_grad(): 239 | d = 1 / (2 * side) 240 | pos = torch.linspace( 241 | start=d, 242 | end=(1 - d), 243 | steps=side, 244 | dtype=x.dtype, 245 | device=x.device, 246 | ) 247 | pos = x.size(dim) * pos - 0.5 248 | base = pos.floor() - (kernel_size // 2) + 1 249 | dist = pos - base 250 | weight = get_weight( 251 | dist, 252 | kernel_size, 253 | kernel=kernel, 254 | sigma=sigma, 255 | antialiasing_factor=antialiasing_factor, 256 | ) 257 | pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim)) 258 | 259 | # To backpropagate through x 260 | x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type) 261 | unfold = reshape_tensor(x_pad, dim, kernel_size) 262 | # Subsampling first 263 | if dim == 2 or dim == -2: 264 | sample = unfold[..., base, :] 265 | weight = weight.view(1, kernel_size, sample.size(2), 1) 266 | else: 267 | sample = unfold[..., base] 268 | weight = weight.view(1, kernel_size, 1, sample.size(3)) 269 | 270 | # Apply the kernel 271 | down = sample * weight 272 | down = down.sum(dim=1, keepdim=True) 273 | return down 274 | 275 | def downsampling_2d( 276 | x: torch.Tensor, 277 | k: torch.Tensor, 278 | scale: int, 279 | padding_type: str='reflect') -> torch.Tensor: 280 | 281 | c = x.size(1) 282 | k_h = k.size(-2) 283 | k_w = k.size(-1) 284 | 285 | k = k.to(dtype=x.dtype, device=x.device) 286 | k = k.view(1, 1, k_h, k_w) 287 | k = k.repeat(c, c, 1, 1) 288 | e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False) 289 | e = e.view(c, c, 1, 1) 290 | k = k * e 291 | 292 | pad_h = (k_h - scale) // 2 293 | pad_w = (k_w - scale) // 2 294 | x = padding(x, -2, pad_h, pad_h, padding_type=padding_type) 295 | x = padding(x, -1, pad_w, pad_w, padding_type=padding_type) 296 | y = F.conv2d(x, k, padding=0, stride=scale) 297 | return y 298 | 299 | def imresize( 300 | x: torch.Tensor, 301 | scale: float=None, 302 | sides: typing.Tuple[int, int]=None, 303 | kernel: K='cubic', 304 | sigma: float=2, 305 | rotation_degree: float=0, 306 | padding_type: str='reflect', 307 | antialiasing: bool=True) -> torch.Tensor: 308 | 309 | ''' 310 | Args: 311 | x (torch.Tensor): 312 | scale (float): 313 | sides (tuple(int, int)): 314 | kernel (str, default='cubic'): 315 | sigma (float, default=2): 316 | rotation_degree (float, default=0): 317 | padding_type (str, default='reflect'): 318 | antialiasing (bool, default=True): 319 | Return: 320 | torch.Tensor: 321 | ''' 322 | 323 | if scale is None and sides is None: 324 | raise ValueError('One of scale or sides must be specified!') 325 | if scale is not None and sides is not None: 326 | raise ValueError('Please specify scale or sides to avoid conflict!') 327 | 328 | if x.dim() == 4: 329 | b, c, h, w = x.size() 330 | elif x.dim() == 3: 331 | c, h, w = x.size() 332 | b = None 333 | elif x.dim() == 2: 334 | h, w = 
x.size() 335 | b = c = None 336 | else: 337 | raise ValueError('{}-dim Tensor is not supported!'.format(x.dim())) 338 | 339 | x = x.view(-1, 1, h, w) 340 | 341 | if sides is None: 342 | # Determine output size 343 | sides = (math.ceil(h * scale), math.ceil(w * scale)) 344 | scale_inv = 1 / scale 345 | if isinstance(kernel, str) and scale_inv.is_integer(): 346 | kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing) 347 | elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer(): 348 | raise ValueError( 349 | 'An integer downsampling factor ' 350 | 'should be used with a predefined kernel!' 351 | ) 352 | 353 | if x.dtype != torch.float32 and x.dtype != torch.float64: # promote non-float inputs to float32 354 | dtype = x.dtype 355 | x = x.float() 356 | else: 357 | dtype = None 358 | 359 | if isinstance(kernel, str): 360 | # Shared keyword arguments across dimensions 361 | kwargs = { 362 | 'kernel': kernel, 363 | 'sigma': sigma, 364 | 'padding_type': padding_type, 365 | 'antialiasing': antialiasing, 366 | } 367 | # Core resizing module 368 | x = resize_1d(x, -2, side=sides[0], **kwargs) 369 | x = resize_1d(x, -1, side=sides[1], **kwargs) 370 | elif isinstance(kernel, torch.Tensor): 371 | x = downsampling_2d(x, kernel, scale=int(1 / scale)) 372 | 373 | rh = x.size(-2) 374 | rw = x.size(-1) 375 | # Back to the original dimension 376 | if b is not None: 377 | x = x.view(b, c, rh, rw) # 4-dim 378 | else: 379 | if c is not None: 380 | x = x.view(c, rh, rw) # 3-dim 381 | else: 382 | x = x.view(rh, rw) # 2-dim 383 | 384 | if dtype is not None: 385 | if not dtype.is_floating_point: 386 | x = x.round() 387 | # To prevent over/underflow when converting types 388 | if dtype is torch.uint8: 389 | x = x.clamp(0, 255) 390 | 391 | x = x.to(dtype=dtype) 392 | 393 | return x 394 | 395 | if __name__ == '__main__': 396 | # Just for debugging 397 | torch.set_printoptions(precision=4, sci_mode=False, edgeitems=16, linewidth=200) 398 | a = torch.arange(64).float().view(1, 1, 8, 8) 399 | z = imresize(a, 0.5) 400 | print(z) 401 | #a = torch.arange(16).float().view(1, 1, 4, 4) 402 | ''' 403 | a = torch.zeros(1, 1, 4, 4) 404 | a[..., 0, 0] = 100 405 | a[..., 1, 0] = 10 406 | a[..., 0, 1] = 1 407 | a[..., 0, -1] = 100 408 | a = torch.zeros(1, 1, 4, 4) 409 | a[..., -1, -1] = 100 410 | a[..., -2, -1] = 10 411 | a[..., -1, -2] = 1 412 | a[..., -1, 0] = 100 413 | ''' 414 | #b = imresize(a, sides=(3, 8), antialiasing=False) 415 | #c = imresize(a, sides=(11, 13), antialiasing=True) 416 | #c = imresize(a, sides=(4, 4), antialiasing=False, kernel='gaussian', sigma=1) 417 | #print(a) 418 | #print(b) 419 | #print(c) 420 | 421 | #r = discrete_kernel('cubic', 1 / 3) 422 | #print(r) 423 | ''' 424 | a = torch.arange(225).float().view(1, 1, 15, 15) 425 | imresize(a, sides=[5, 5]) 426 | ''' -------------------------------------------------------------------------------- /super-resolution/utils/misc.py: -------------------------------------------------------------------------------- 1 | import logging.config 2 | import os 3 | import matplotlib.pyplot as plt 4 | import torch 5 | import math 6 | import numpy as np 7 | from PIL import Image 8 | from torchvision.utils import make_grid 9 | from os import path, makedirs 10 | 11 | def mkdir(save_path): 12 | if not path.exists(save_path): 13 | makedirs(save_path) 14 | 15 | def make_image_grid(x, nrow, padding=0, pad_value=0): 16 | x = x.clone().cpu().data 17 | grid = make_grid(x, nrow=nrow, padding=padding, normalize=True, scale_each=False, pad_value=pad_value) 18 | return grid 19 | 20 | def
tensor_to_image(x): 21 | ndarr = x.squeeze().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 22 | image = Image.fromarray(ndarr) 23 | return image 24 | 25 | def plot_image_grid(x, nrow, padding=0): 26 | grid = make_image_grid(x=x, nrow=nrow, padding=padding).permute(1, 2, 0).numpy() 27 | plt.imshow(grid) 28 | plt.show() 29 | 30 | def save_image(x, path, size=None): 31 | image = tensor_to_image(x) 32 | if size: 33 | image = image.resize((size, size), Image.NEAREST) 34 | image.save(path) 35 | 36 | def save_image_grid(x, path, nrow=8, size=None): 37 | grid = make_image_grid(x, nrow) 38 | save_image(grid, path, size=size) 39 | 40 | def setup_logging(log_file='log.txt', resume=False, dummy=False): 41 | if dummy: 42 | logging.getLogger('dummy') 43 | else: 44 | if os.path.isfile(log_file) and resume: 45 | file_mode = 'a' 46 | else: 47 | file_mode = 'w' 48 | 49 | root_logger = logging.getLogger() 50 | if root_logger.handlers: 51 | root_logger.removeHandler(root_logger.handlers[0]) 52 | logging.basicConfig(level=logging.INFO, 53 | format="%(asctime)s - %(levelname)s - %(message)s", 54 | datefmt="%Y-%m-%d %H:%M:%S", 55 | filename=log_file, 56 | filemode=file_mode) 57 | console = logging.StreamHandler() 58 | console.setLevel(logging.INFO) 59 | formatter = logging.Formatter('%(message)s') 60 | console.setFormatter(formatter) 61 | logging.getLogger('').addHandler(console) 62 | 63 | def average(lst): 64 | return sum(lst) / len(lst) 65 | 66 | def rgb2yc(img): 67 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 68 | rlt = rlt.round() 69 | return rlt 70 | 71 | def compute_psnr(x, y, scale): 72 | x = rgb2yc(tensor_to_image(x)).astype(np.float64) 73 | y = rgb2yc(tensor_to_image(y)).astype(np.float64) 74 | 75 | x = x[round(scale):-round(scale), round(scale):-round(scale)] 76 | y = y[round(scale):-round(scale), round(scale):-round(scale)] 77 | 78 | mse = np.mean((x - y) ** 2) 79 | return 20 * math.log10(255.0 / math.sqrt(mse)) 80 | 81 | if __name__ == "__main__": 82 | print('None') 83 | 84 | -------------------------------------------------------------------------------- /super-resolution/utils/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | def get_exp_scheduler_with_warmup(optimizer, rampup_steps=5, sustain_steps=5): 5 | def lr_lambda(step): 6 | if step < rampup_steps: 7 | return min(1., 1.8 ** ((step - rampup_steps))) 8 | elif step < rampup_steps + sustain_steps: 9 | return 1. 10 | else: 11 | return max(0.1, 0.85 ** (step - rampup_steps - sustain_steps)) 12 | 13 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 14 | 15 | def get_cosine_scheduler_with_warmup(optimizer, rampup_steps=3, sustain_steps=2, priod=5): 16 | def lr_lambda(step): 17 | if step < rampup_steps: 18 | return min(1., 1.5 ** ((step - rampup_steps))) 19 | elif step < rampup_steps + sustain_steps: 20 | return 1. 
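# three phases: exponential ramp-up over rampup_steps, a constant plateau over
# sustain_steps, then a cosine decay (1 + cos(pi * t / priod)) / 2 with
# t = step - rampup_steps - sustain_steps, floored at 10% of the base lr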
21 | else: 22 | return max(0.1, (1 + math.cos(math.pi * (step - rampup_steps - sustain_steps) / priod)) / 2) 23 | 24 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 25 | 26 | if __name__ == "__main__": 27 | import matplotlib.pyplot as plt 28 | 29 | model = torch.nn.Linear(2, 1) 30 | optimizer = torch.optim.SGD(model.parameters(), lr=2.5e-4) 31 | lr_scheduler = get_exp_scheduler_with_warmup(optimizer) 32 | 33 | lrs = [] 34 | for i in range(25): 35 | lr_scheduler.step() 36 | lrs.append(optimizer.param_groups[0]["lr"]) 37 | 38 | plt.plot(lrs) 39 | plt.show() 40 | print(lrs) 41 | -------------------------------------------------------------------------------- /super-resolution/utils/recorderx.py: -------------------------------------------------------------------------------- 1 | import utils.misc as misc 2 | from tensorboardX import SummaryWriter 3 | 4 | class RecoderX(): 5 | def __init__(self, log_dir): 6 | self.log_dir = log_dir 7 | self.writer = SummaryWriter(logdir=log_dir) 8 | self.log = '' 9 | 10 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None): 11 | self.writer.add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step, walltime=walltime) 12 | 13 | def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None): 14 | self.writer.add_scalars(main_tag=main_tag, tag_scalar_dict=tag_scalar_dict, global_step=global_step, walltime=walltime) 15 | 16 | def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'): 17 | self.writer.add_image(tag=tag, img_tensor=img_tensor, global_step=global_step, walltime=walltime, dataformats=dataformats) 18 | 19 | def add_image_grid(self, tag, img_tensor, nrow=8, global_step=None, padding=0, pad_value=0,walltime=None, dataformats='CHW'): 20 | grid = misc.make_image_grid(img_tensor, nrow, padding=padding, pad_value=pad_value) 21 | self.writer.add_image(tag=tag, img_tensor=grid, global_step=global_step, walltime=walltime, dataformats=dataformats) 22 | 23 | def add_graph(self, graph_profile, walltime=None): 24 | self.writer.add_graph(graph_profile, walltime=walltime) 25 | 26 | def add_histogram(self, tag, values, global_step=None): 27 | self.writer.add_histogram(tag, values, global_step) 28 | 29 | def add_figure(self, tag, figure, global_step=None, close=True, walltime=None): 30 | self.writer.add_figure(tag, figure, global_step=global_step, close=close, walltime=walltime) 31 | 32 | def export_json(self, out_file): 33 | self.writer.export_scalars_to_json(out_file) 34 | 35 | def close(self): 36 | self.writer.close() 37 | 38 | if __name__ == "__main__": 39 | print('None') 40 | -------------------------------------------------------------------------------- /super-resolution/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | global visualization, hooks 3 | 4 | visualization = {} 5 | hooks = [] 6 | 7 | def hook_act(m, i, o): 8 | min, _ = o.flatten(start_dim=2).min(dim=-1) 9 | max, _ = o.flatten(start_dim=2).max(dim=-1) 10 | visualization[m] = (o - min.view(o.size(0), -1, 1, 1)) / (max.view(o.size(0), -1, 1, 1) - min.view(o.size(0), -1, 1, 1) + 1e-8) 11 | 12 | def hook_forward_output(m, i, o): 13 | visualization[m] = o 14 | 15 | def hook_forward_norm(m, i, o): 16 | visualization[m] = o.flatten(start_dim=1).norm(p=2, dim=1).mean() 17 | 18 | def hook_backward_norm(m, i, o): 19 | visualization[m] = i[1].flatten(start_dim=1).norm(p=2, dim=1).mean() # i[1]: weights' grad 20 | 21 | def 
backward_norms_hook(model, instance=nn.Conv2d): 22 | for name, layer in model._modules.items(): 23 | if isinstance(layer, nn.Sequential): 24 | backward_norms_hook(layer, instance) 25 | elif isinstance(layer, instance): 26 | hooks.append(layer.register_backward_hook(hook_backward_norm)) 27 | 28 | def forward_output_hook(model, instance=(nn.ReLU, nn.LeakyReLU, nn.Linear)): 29 | for name, layer in model._modules.items(): 30 | if isinstance(layer, (nn.Sequential)): 31 | forward_output_hook(layer, instance) 32 | elif isinstance(layer, instance): 33 | hooks.append(layer.register_forward_hook(hook_forward_output)) 34 | 35 | def forward_activations_hook(model, instance=(nn.ReLU, nn.LeakyReLU)): 36 | for name, layer in model._modules.items(): 37 | if isinstance(layer, (nn.Sequential)): 38 | forward_activations_hook(layer, instance) 39 | elif isinstance(layer, instance): 40 | hooks.append(layer.register_forward_hook(hook_act)) 41 | 42 | def get_visualization(): 43 | return visualization 44 | 45 | def remove_hooks(): 46 | for hook in hooks: 47 | hook.remove() 48 | hooks.clear() 49 | --------------------------------------------------------------------------------
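The xUnitS activation that the GenBlock modules rely on is defined in models/modules/activations.py, which this listing does not include. For orientation only, the sketch below shows a minimal spatial-gating activation in the spirit of the xUnit papers (batch norm, ReLU, depthwise convolution, batch norm, then a Gaussian gate applied multiplicatively); the class name, default sizes, and exact layer order are illustrative assumptions, not the repository's actual implementation:

```
import torch
import torch.nn as nn

class SpatialGate(nn.Module):
    '''Minimal xUnit-style spatial activation sketch (illustrative, not the
    repository's xUnitS): features are multiplied by a learned spatial gate
    instead of being thresholded element-wise.'''
    def __init__(self, num_features=64, kernel_size=9):
        super(SpatialGate, self).__init__()
        self.gate = nn.Sequential(
            nn.BatchNorm2d(num_features),
            nn.ReLU(inplace=True),
            # depthwise convolution: each channel learns its own spatial support
            nn.Conv2d(num_features, num_features, kernel_size, padding=kernel_size // 2, groups=num_features),
            nn.BatchNorm2d(num_features),
        )

    def forward(self, x):
        # gaussian squashing maps the gate into (0, 1] before the product
        return x * torch.exp(-self.gate(x) ** 2)

if __name__ == '__main__':
    x = torch.randn(2, 64, 32, 32)
    print(SpatialGate(num_features=64)(x).size())  # torch.Size([2, 64, 32, 32])
```

The multiplicative spatial gate is what separates the x-variants (xsrgan, xvanilla, xdense) from their plain counterparts: each feature map is modulated by a learned spatial map rather than a per-element threshold, the trade-off the paper argues allows similar restoration quality with smaller networks.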