├── .gitignore
├── DSC00261.jpg
├── LICENSE
├── README.md
├── helpers.py
├── requirements.txt
├── vgg_loss.py
├── vgg_loss_demo.py
├── vgg_loss_demo_2.py
└── vgg_loss_demo_3.py

/.gitignore:
--------------------------------------------------------------------------------
venv*
__pycache__
.ipynb_checkpoints
out.png
target.png
demo.png
autoencoder.pth
--------------------------------------------------------------------------------
/DSC00261.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/vgg_loss/7c4d4c3926ebec06d0cbbb50f9ddff4a74f0e593/DSC00261.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2021 Katherine Crowson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
vgg_loss
========

A VGG-based perceptual loss function for PyTorch. See Johnson, Alahi, and Fei-Fei, "[Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)".

The module containing the code to import is `vgg_loss.py`. See the three demos for usage examples.
--------------------------------------------------------------------------------
/helpers.py:
--------------------------------------------------------------------------------
"""Helper functions for the VGG perceptual loss."""

import torch


def inspect_outputs(module):
    """Registers hooks on each submodule that print their outputs."""

    def make_hook(name):
        return lambda m, i, o: print(f'({name}) {type(m).__name__}: {o}')

    for name, mod in module.named_children():
        mod.register_forward_hook(make_hook(name))


def batchify_image(input):
    """Promotes the input tensor (an image or a batch of images) to a 4D
    tensor with three channels, if it is not already. Strips alpha channels
    if present.
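
    For example (the spatial sizes here are illustrative; any H and W behave
    the same way):

        >>> batchify_image(torch.rand(64, 64)).shape       # 2D grayscale
        torch.Size([1, 3, 64, 64])
        >>> batchify_image(torch.rand(4, 64, 64)).shape    # RGBA image
        torch.Size([1, 3, 64, 64])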
    """
    if input.ndim == 2:
        input = input[None]
    if input.ndim == 3:
        input = input[None]
    if input.ndim != 4:
        raise ValueError('input.ndim must be 2, 3, or 4')
    if input.shape[1] == 2 or input.shape[1] == 4:
        input = input[:, :-1]
    if input.shape[1] == 1:
        input = torch.cat([input] * 3, dim=1)
    if input.shape[1] != 3:
        raise ValueError('input must have 1-4 channels')
    return input
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.7.0
torchvision>=0.8.1
--------------------------------------------------------------------------------
/vgg_loss.py:
--------------------------------------------------------------------------------
"""A VGG-based perceptual loss function for PyTorch."""

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import models, transforms


class Lambda(nn.Module):
    """Wraps a callable in an :class:`nn.Module` without registering it."""

    def __init__(self, func):
        super().__init__()
        object.__setattr__(self, 'forward', func)

    def extra_repr(self):
        return getattr(self.forward, '__name__', type(self.forward).__name__) + '()'


class WeightedLoss(nn.ModuleList):
    """A weighted combination of multiple loss functions."""

    def __init__(self, losses, weights, verbose=False):
        super().__init__()
        for loss in losses:
            self.append(loss if isinstance(loss, nn.Module) else Lambda(loss))
        self.weights = weights
        self.verbose = verbose

    def _print_losses(self, losses):
        for i, loss in enumerate(losses):
            print(f'({i}) {type(self[i]).__name__}: {loss.item()}')

    def forward(self, *args, **kwargs):
        losses = []
        for loss, weight in zip(self, self.weights):
            losses.append(loss(*args, **kwargs) * weight)
        if self.verbose:
            self._print_losses(losses)
        return sum(losses)


class TVLoss(nn.Module):
    """Total variation loss (Lp penalty on image gradient magnitude).

    The input must be 4D. If a target (second parameter) is passed in, it is
    ignored.

    ``p=1`` yields the vectorial total variation norm. It is a generalization
    of the originally proposed (isotropic) 2D total variation norm (see
    https://en.wikipedia.org/wiki/Total_variation_denoising) to color images.
    On images with a single channel it is equal to the 2D TV norm.

    ``p=2`` yields a variant that is often used for smoothing out noise in
    reconstructions of images from neural network feature maps (see Mahendran
    and Vedaldi, "Understanding Deep Image Representations by Inverting
    Them", https://arxiv.org/abs/1412.0035).

    :attr:`reduction` can be set to ``'mean'``, ``'sum'``, or ``'none'``
    similarly to the loss functions in :mod:`torch.nn`. The default is
    ``'mean'``.
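
    A minimal usage sketch (the tensor shape is only illustrative; with the
    default ``'mean'`` reduction the result is a scalar):

        >>> tv = TVLoss(p=1)
        >>> loss = tv(torch.rand(1, 3, 64, 64))
        >>> loss.shape
        torch.Size([])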
    """

    def __init__(self, p, reduction='mean', eps=1e-8):
        super().__init__()
        if p not in {1, 2}:
            raise ValueError('p must be 1 or 2')
        if reduction not in {'mean', 'sum', 'none'}:
            raise ValueError("reduction must be 'mean', 'sum', or 'none'")
        self.p = p
        self.reduction = reduction
        self.eps = eps

    def forward(self, input, target=None):
        input = F.pad(input, (0, 1, 0, 1), 'replicate')
        x_diff = input[..., :-1, :-1] - input[..., :-1, 1:]
        y_diff = input[..., :-1, :-1] - input[..., 1:, :-1]
        diff = x_diff**2 + y_diff**2
        if self.p == 1:
            # eps keeps sqrt() from producing infinite gradients at zero.
            diff = (diff + self.eps).mean(dim=1, keepdim=True).sqrt()
        if self.reduction == 'mean':
            return diff.mean()
        if self.reduction == 'sum':
            return diff.sum()
        return diff


class VGGLoss(nn.Module):
    """Computes the VGG perceptual loss between two batches of images.

    The input and target must be 4D tensors with three channels
    ``(B, 3, H, W)`` and must have the same shape. Pixel values should be
    normalized to the range 0–1.

    The VGG perceptual loss is the mean squared difference between the
    features computed for the input and target at layer :attr:`layer`
    (default 8, or ``relu2_2``) of the pretrained model specified by
    :attr:`model` (either ``'vgg16'`` (default) or ``'vgg19'``).

    If :attr:`shift` is nonzero, a random shift of at most :attr:`shift`
    pixels in both height and width will be applied to all images in the
    input and target. The shift will only be applied when the loss function
    is in training mode, and will not be applied if a precomputed feature
    map is supplied as the target.

    :attr:`reduction` can be set to ``'mean'``, ``'sum'``, or ``'none'``
    similarly to the loss functions in :mod:`torch.nn`. The default is
    ``'mean'``.

    :meth:`get_features()` may be used to precompute the features for the
    target, to speed up the case where inputs are compared against the same
    target over and over. To use the precomputed features, pass them in as
    :attr:`target` and set :attr:`target_is_features` to :code:`True`.

    Instances of :class:`VGGLoss` must be manually converted to the same
    device and dtype as their inputs.
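
    A minimal usage sketch (here ``input`` and ``target`` stand for any 0–1
    normalized ``(B, 3, H, W)`` tensors and ``device`` for their device):

        >>> crit = VGGLoss(model='vgg16', layer=8).to(device)
        >>> loss = crit(input, target)

    or, precomputing the target features once and reusing them:

        >>> target_feats = crit.get_features(target)
        >>> loss = crit(input, target_feats, target_is_features=True)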
    """

    models = {'vgg16': models.vgg16, 'vgg19': models.vgg19}

    def __init__(self, model='vgg16', layer=8, shift=0, reduction='mean'):
        super().__init__()
        self.shift = shift
        self.reduction = reduction
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
        self.model = self.models[model](pretrained=True).features[:layer+1]
        self.model.eval()
        self.model.requires_grad_(False)

    def get_features(self, input):
        return self.model(self.normalize(input))

    def train(self, mode=True):
        # Overridden to set the training flag on this module only, so the
        # frozen VGG submodule always stays in eval mode.
        self.training = mode
        return self

    def forward(self, input, target, target_is_features=False):
        if target_is_features:
            input_feats = self.get_features(input)
            target_feats = target
        else:
            # Run the input and target through VGG as one batch so a random
            # shift, if enabled, is applied to both identically.
            sep = input.shape[0]
            batch = torch.cat([input, target])
            if self.shift and self.training:
                padded = F.pad(batch, [self.shift] * 4, mode='replicate')
                batch = transforms.RandomCrop(batch.shape[2:])(padded)
            feats = self.get_features(batch)
            input_feats, target_feats = feats[:sep], feats[sep:]
        return F.mse_loss(input_feats, target_feats, reduction=self.reduction)
--------------------------------------------------------------------------------
/vgg_loss_demo.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

"""Reconstruction of a target image from the VGG perceptual loss."""

import torch
from torch import optim
from torchvision import io as tio
from torchvision.transforms import functional as TF

import vgg_loss


def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    crit_vgg = vgg_loss.VGGLoss().to(device)
    crit_tv = vgg_loss.TVLoss(p=2)

    target = tio.read_image('DSC00261.jpg')[None] / 255
    target = TF.resize(target, (256, 256), 3).to(device)  # interpolation 3 = bicubic
    target_act = crit_vgg.get_features(target)

    # Start from near-uniform gray with a small amount of noise.
    input = torch.rand_like(target) / 255 + 0.5
    input.requires_grad_(True)

    opt = optim.Adam([input], lr=0.025)

    try:
        for i in range(1000):
            opt.zero_grad()
            loss = crit_vgg(input, target_act, target_is_features=True)
            loss += crit_tv(input) * 20
            print(i, loss.item())
            loss.backward()
            opt.step()
    except KeyboardInterrupt:
        pass

    TF.to_pil_image(input[0].clamp(0, 1)).save('out.png')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/vgg_loss_demo_2.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

"""Total variation image denoising using the VGG perceptual loss."""

import torch
from torch import nn, optim
from torchvision import io as tio
from torchvision.transforms import functional as TF

import vgg_loss


def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    crit_vgg = vgg_loss.VGGLoss().to(device)
    crit_l2 = nn.MSELoss()
    crit_tv = vgg_loss.TVLoss(p=1)

    target = tio.read_image('DSC00261.jpg')[None] / 255
    target = TF.resize(target, (256, 256), 3)  # interpolation 3 = bicubic
    # Corrupt the target with uniform noise in [-0.125, 0.125].
    target += (torch.rand_like(target) - 0.5) / 4
    target = target.clamp(0, 1)
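    # Save the corrupted image so it can be compared with the denoised
    # result written to out.png below.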
    TF.to_pil_image(target[0]).save('target.png')
    target = target.to(device)
    target_act = crit_vgg.get_features(target)

    input = target.clone()
    input.requires_grad_(True)

    opt = optim.Adam([input], lr=0.01)

    try:
        for i in range(250):
            opt.zero_grad()
            loss = crit_vgg(input, target_act, target_is_features=True)
            loss += crit_l2(input, target) * 1500
            loss += crit_tv(input) * 250
            print(i, loss.item())
            loss.backward()
            opt.step()
    except KeyboardInterrupt:
        pass

    TF.to_pil_image(input[0].clamp(0, 1)).save('out.png')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/vgg_loss_demo_3.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

"""An autoencoder using the VGG perceptual loss."""

import torch
from torch import optim, nn
from torch.utils import data
from torchvision import datasets, transforms

import vgg_loss

BATCH_SIZE = 100
EPOCHS = 100


def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    tf = transforms.ToTensor()
    train_set = datasets.CIFAR10('data/cifar10', download=True, transform=tf)
    train_dl = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                               num_workers=2, pin_memory=True)
    test_set = datasets.CIFAR10('data/cifar10', train=False, transform=tf)
    test_dl = data.DataLoader(test_set, batch_size=BATCH_SIZE, num_workers=2,
                              pin_memory=True)

    encoder = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
        nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
        nn.AvgPool2d(2, ceil_mode=True),
        nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
        nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
        nn.AvgPool2d(2, ceil_mode=True),
        nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
        nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
        nn.AvgPool2d(2, ceil_mode=True),
        nn.Flatten(),
        nn.Linear(1024, 128), nn.Tanh(),
    ).to(device)

    decoder = nn.Sequential(
        nn.Linear(128, 1024), nn.ReLU(),
        nn.Unflatten(-1, (64, 4, 4)),
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
        nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
        nn.Conv2d(64, 32, 3, padding=1), nn.ReLU(),
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
        nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
        nn.Conv2d(32, 16, 3, padding=1), nn.ReLU(),
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
        nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
        nn.Conv2d(16, 3, 3, padding=1), nn.Sigmoid(),
    ).to(device)

    model = nn.Sequential(encoder, decoder)
    print('Parameters:', sum(p.numel() for p in model.parameters()))

    # crit = nn.MSELoss()
    # A weighted combination of perceptual, L2, and total variation losses.
    crit = vgg_loss.WeightedLoss([vgg_loss.VGGLoss(shift=2),
                                  nn.MSELoss(),
                                  vgg_loss.TVLoss(p=1)],
                                 [1, 40, 10]).to(device)
    # helpers.inspect_outputs(crit)
    opt = optim.Adam(model.parameters())
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, patience=5,
                                                 verbose=True)

    def train():
        model.train()
        crit.train()
        for i, (batch, _) in enumerate(train_dl, 1):
            batch = batch.to(device, non_blocking=True)
            opt.zero_grad()
            out = model(batch)
            loss = crit(out, batch)
            if i % 50 == 0:
                print(i, loss.item())
            loss.backward()
            opt.step()

    @torch.no_grad()
    def test():
        model.eval()
        crit.eval()
        losses = []
        for batch, _ in test_dl:
            batch = batch.to(device, non_blocking=True)
            out = model(batch)
            losses.append(crit(out, batch))
        loss = sum(losses) / len(losses)
        print('Validation loss:', loss.item())
        sched.step(loss)

    @torch.no_grad()
    def demo():
        model.eval()
        batch = torch.cat([test_set[i][0][None] for i in range(10)])
        out = model(batch.to(device)).cpu()
        # Stack the originals and their reconstructions into two
        # side-by-side columns.
        col_l = torch.cat(list(batch), dim=1)
        col_r = torch.cat(list(out), dim=1)
        grid = torch.cat([col_l, col_r], dim=2)
        transforms.functional.to_pil_image(grid).save('demo.png')
        print('Wrote example grid to demo.png.')

    try:
        for epoch in range(EPOCHS):
            print('Epoch', epoch + 1)
            train()
            test()
            demo()
    except KeyboardInterrupt:
        pass

    torch.save(model.state_dict(), 'autoencoder.pth')
    print('Wrote trained model to autoencoder.pth.')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------