├── README.md
├── data_utils.py
├── fast_neural_style.py
├── images
│   ├── content
│   │   └── get_dataset.sh
│   ├── input
│   │   └── hoovertowernight.jpg
│   └── style
│       ├── candy.jpg
│       ├── starry_night.jpg
│       └── wave.jpg
├── loss.py
├── models.py
└── train.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch implementation of "Perceptual Losses for Real-Time Style Transfer and Super-Resolution"

## Requirements
- [PyTorch](http://pytorch.org/)
```
$ conda install pytorch torchvision -c soumith
```

## Train
One image transformation network has to be trained per style target.
Following the paper, the models are trained on the [Microsoft COCO dataset](http://mscoco.org/dataset/#download).
```
python train.py --style_image "images/style/xyz.jpg" --dataset_path "images/content" --cuda
```

## Generate
```
python fast_neural_style.py --input_image "images/input/xyz.jpg" --model "model_epoch_2.pth" --output_name "output.jpg" --cuda
```
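
A trained checkpoint can also be used programmatically. A minimal sketch built on this repo's helpers (`fast_neural_style.py` implements the full pipeline, including the BGR -> RGB conversion when saving):
```
import torch
from torch.autograd import Variable
from models import ImageTransformNet
from data_utils import load_image, batch_rgb_to_bgr

model = ImageTransformNet()
model.load_state_dict(torch.load('model_epoch_2.pth'))
model.eval()

img = load_image('images/input/hoovertowernight.jpg')  # RGB, CxHxW, values in [0, 255]
img = batch_rgb_to_bgr(img.unsqueeze(0))               # the networks operate on BGR batches
out = model(Variable(img, volatile=True))              # stylised BGR batch, roughly [0, 255]
```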

## Reference
- [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](http://arxiv.org/abs/1603.08155)
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
import numpy as np
from PIL import Image


# subtract the per-channel VGG means (BGR order, Caffe convention) in place
def vgg_mean_subtraction(batch):
    tensortype = type(batch.data)
    mean = tensortype(batch.data.size())
    mean[:, 0, :, :] = 103.939
    mean[:, 1, :, :] = 116.779
    mean[:, 2, :, :] = 123.680
    batch -= Variable(mean)


# batch : BxCxHxW; swap the channel order (RGB <-> BGR)
def batch_rgb_to_bgr(batch):
    batch = batch.transpose(0, 1)
    (r, g, b) = torch.chunk(batch, 3)
    batch = torch.cat((b, g, r))
    batch = batch.transpose(0, 1)
    return batch


# load an image as an RGB CxHxW float tensor in [0, 255]
def load_image(filename, size=None):
    img = Image.open(filename).convert('RGB')  # force three channels
    if size is not None:
        img = img.resize((size, size), Image.ANTIALIAS)
    img = np.array(img).transpose(2, 0, 1)
    img = torch.from_numpy(img).float()
    return img
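
# Quick self-check (a sketch, not used by training): swapping the channel
# order twice must give back the original batch exactly.
if __name__ == '__main__':
    x = torch.rand(2, 3, 8, 8).mul(255)
    assert torch.equal(batch_rgb_to_bgr(batch_rgb_to_bgr(x)), x)
    print('batch_rgb_to_bgr round-trip OK:', x.size())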
--------------------------------------------------------------------------------
/fast_neural_style.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import numpy as np
import torch
from torch.autograd import Variable
from models import ImageTransformNet
from PIL import Image

parser = argparse.ArgumentParser(description='PyTorch Fast Style Transfer')
parser.add_argument('--input_image', type=str, required=True, help='input image to use')
parser.add_argument('--model', type=str, required=True, help='model file to use')
parser.add_argument('--output_name', default='styleTransfer.jpg', type=str, help='location to save the output image')
parser.add_argument('--cuda', action='store_true', help='use cuda')
args = parser.parse_args()

cuda = args.cuda
if cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")

model = ImageTransformNet()
model.load_state_dict(torch.load(args.model))
if cuda:
    model.cuda()

# load image
img = Image.open(args.input_image).convert('RGB')
img = np.array(img)
img = np.array(img[..., ::-1])        # RGB -> BGR
img = img.transpose(2, 0, 1)          # (H, W, C) -> (C, H, W)
img = img.reshape((1, ) + img.shape)  # (C, H, W) -> (B, C, H, W)
img = torch.from_numpy(img).float()
img = Variable(img, volatile=True)
if cuda:
    img = img.cuda()

model.eval()
output_img = model(img)

# save output
output_img = output_img.data.cpu().clamp(0, 255).byte().numpy()
output_img = output_img[0].transpose((1, 2, 0))
output_img = np.ascontiguousarray(output_img[..., ::-1])  # BGR -> RGB, contiguous for PIL
output_img = Image.fromarray(output_img)
output_img.save(args.output_name)
--------------------------------------------------------------------------------
/images/content/get_dataset.sh:
--------------------------------------------------------------------------------
wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip
unzip train2014.zip
rm train2014.zip
--------------------------------------------------------------------------------
/images/input/hoovertowernight.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vishal1796/pytorch-fast-neural-style/78b5703093234071a57268ed528a52e68329239f/images/input/hoovertowernight.jpg
--------------------------------------------------------------------------------
/images/style/candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vishal1796/pytorch-fast-neural-style/78b5703093234071a57268ed528a52e68329239f/images/style/candy.jpg
--------------------------------------------------------------------------------
/images/style/starry_night.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vishal1796/pytorch-fast-neural-style/78b5703093234071a57268ed528a52e68329239f/images/style/starry_night.jpg
--------------------------------------------------------------------------------
/images/style/wave.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vishal1796/pytorch-fast-neural-style/78b5703093234071a57268ed528a52e68329239f/images/style/wave.jpg
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import os
import torch
from torch.utils.serialization import load_lua
from torch.autograd import Variable
import models
from data_utils import vgg_mean_subtraction


# download the Torch VGG16 model once and copy its conv weights and biases,
# in order, into models.VGGFeature
def vgg16_model():
    if not os.path.exists('vgg16feature.pth'):
        if not os.path.exists('vgg16.t7'):
            os.system('wget http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7')
        vgglua = load_lua('vgg16.t7')
        vgg = models.VGGFeature()
        for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
            dst[:] = src[:]
        torch.save(vgg.state_dict(), 'vgg16feature.pth')


def gram_matrix(y):
    B, C, H, W = y.size()
    features = y.view(B, C, W * H)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (C * H * W)
    return gram


# note: vgg_mean_subtraction works in place, so yc, ys and y_hat are mutated
def loss_function(content_weight, style_weight, yc, ys, y_hat, cuda):
    vgg16_model()
    vgg = models.VGGFeature()
    vgg.load_state_dict(torch.load('vgg16feature.pth'))
    criterion = torch.nn.MSELoss()

    if cuda:
        vgg = vgg.cuda()
        criterion = criterion.cuda()

    vgg_mean_subtraction(yc)
    vgg_mean_subtraction(ys)
    vgg_mean_subtraction(y_hat)

    # content loss: MSE between the relu3_3 activations of output and content
    feature_c = vgg(yc)
    feature_hat = vgg(y_hat)
    feat_loss = content_weight * criterion(feature_hat[2], Variable(feature_c[2].data, requires_grad=False))

    # style loss: MSE between Gram matrices at all four feature levels
    feature_s = vgg(ys)
    gram_s = [gram_matrix(y) for y in feature_s]
    gram_hat = [gram_matrix(y) for y in feature_hat]
    style_loss = 0
    for m in range(len(feature_hat)):
        style_loss += style_weight * criterion(gram_hat[m], Variable(gram_s[m].data, requires_grad=False))

    return style_loss + feat_loss
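
# Sanity check for gram_matrix (a sketch, not used by training): for a BxCxHxW
# input the result is BxCxC and symmetric, since G = F F^T / (C*H*W).
if __name__ == '__main__':
    y = torch.rand(2, 4, 5, 5)
    g = gram_matrix(y)
    print(g.size())                             # torch.Size([2, 4, 4])
    print((g - g.transpose(1, 2)).abs().max())  # ~0: Gram matrices are symmetric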
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import Module
import torch.nn.functional as F


# VGG16 convolutional trunk; forward() returns the activations used by the
# perceptual losses: [relu1_2, relu2_2, relu3_3, relu4_3]
class VGGFeature(Module):
    def __init__(self):
        super(VGGFeature, self).__init__()
        self.conv1_1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_1 = torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_1 = torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_1 = torch.nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        # the conv5_* layers mirror the full VGG16 conv stack for weight
        # conversion but are never reached in forward()
        self.conv5_1 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_3 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        relu1_1 = F.relu(self.conv1_1(x))
        relu1_2 = F.relu(self.conv1_2(relu1_1))
        maxpool_1 = F.max_pool2d(relu1_2, kernel_size=2, stride=2)
        relu2_1 = F.relu(self.conv2_1(maxpool_1))
        relu2_2 = F.relu(self.conv2_2(relu2_1))
        maxpool_2 = F.max_pool2d(relu2_2, kernel_size=2, stride=2)
        relu3_1 = F.relu(self.conv3_1(maxpool_2))
        relu3_2 = F.relu(self.conv3_2(relu3_1))
        relu3_3 = F.relu(self.conv3_3(relu3_2))
        maxpool_3 = F.max_pool2d(relu3_3, kernel_size=2, stride=2)
        relu4_1 = F.relu(self.conv4_1(maxpool_3))
        relu4_2 = F.relu(self.conv4_2(relu4_1))
        relu4_3 = F.relu(self.conv4_3(relu4_2))

        return [relu1_2, relu2_2, relu3_3, relu4_3]


class ResidualBlock(Module):
    def __init__(self, num):
        super(ResidualBlock, self).__init__()
        self.c1 = torch.nn.Conv2d(num, num, kernel_size=3, stride=1, padding=1)
        self.c2 = torch.nn.Conv2d(num, num, kernel_size=3, stride=1, padding=1)
        self.b1 = torch.nn.BatchNorm2d(num)
        self.b2 = torch.nn.BatchNorm2d(num)

    def forward(self, x):
        h = F.relu(self.b1(self.c1(x)))
        h = self.b2(self.c2(h))
        return h + x


# image transformation network: downsample by 4 with two strided convolutions,
# apply five residual blocks, then upsample back with two transposed convolutions
class ImageTransformNet(Module):
    def __init__(self):
        super(ImageTransformNet, self).__init__()
        self.c1 = torch.nn.Conv2d(3, 32, kernel_size=9, stride=1, padding=4)
        self.c2 = torch.nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.c3 = torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.r1 = ResidualBlock(128)
        self.r2 = ResidualBlock(128)
        self.r3 = ResidualBlock(128)
        self.r4 = ResidualBlock(128)
        self.r5 = ResidualBlock(128)
        self.d1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.d2 = torch.nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.d3 = torch.nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)
        self.b1 = torch.nn.BatchNorm2d(32)
        self.b2 = torch.nn.BatchNorm2d(64)
        self.b3 = torch.nn.BatchNorm2d(128)
        self.b4 = torch.nn.BatchNorm2d(64)
        self.b5 = torch.nn.BatchNorm2d(32)

    def forward(self, x):
        h = F.relu(self.b1(self.c1(x)))
        h = F.relu(self.b2(self.c2(h)))
        h = F.relu(self.b3(self.c3(h)))
        h = self.r1(h)
        h = self.r2(h)
        h = self.r3(h)
        h = self.r4(h)
        h = self.r5(h)
        h = F.relu(self.b4(self.d1(h)))
        h = F.relu(self.b5(self.d2(h)))
        y = self.d3(h)
        return y
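
# Shape sanity check (a sketch, assuming the Variable-era API used throughout
# this repo): the transform net is fully convolutional, so the output has the
# same spatial size as the input for any H and W divisible by 4.
if __name__ == '__main__':
    from torch.autograd import Variable
    net = ImageTransformNet()
    x = Variable(torch.randn(1, 3, 256, 256), volatile=True)
    print(net(x).size())  # torch.Size([1, 3, 256, 256])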
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from data_utils import batch_rgb_to_bgr, load_image
from loss import loss_function
import models

# Training settings
parser = argparse.ArgumentParser(description='Fast neural style transfer using PyTorch.')
parser.add_argument('--style_image', metavar='ref', type=str, help='Path to the style reference image.')
parser.add_argument('--dataset_path', type=str, help='Path to training images')
parser.add_argument('--content_weight', type=float, default=1.0, help='Content weight')
parser.add_argument('--style_weight', type=float, default=5.0, help='Style weight')
parser.add_argument('--image_size', default=256, type=int, help='Training image size')
parser.add_argument('--epochs', default=2, type=int, help='Number of epochs')
parser.add_argument('--threads', type=int, default=4, help='Number of threads for the data loader to use')
parser.add_argument('--batchSize', default=4, type=int, help='Training batch size')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate of the optimizer')
parser.add_argument('--cuda', action='store_true', help='use cuda?')
args = parser.parse_args()

cuda = args.cuda
if cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")


print('===> Loading datasets')
transform = transforms.Compose([transforms.Scale(args.image_size),
                                transforms.CenterCrop(args.image_size),
                                transforms.ToTensor(),
                                transforms.Lambda(lambda x: x.mul(255))])
train_set = datasets.ImageFolder(args.dataset_path, transform)
# drop_last keeps every batch at batchSize, so the Gram matrices of the style
# batch and of the training batch always have matching leading dimensions
data_loader = DataLoader(dataset=train_set, num_workers=args.threads, batch_size=args.batchSize, shuffle=True, drop_last=True)


print('===> Building model')
model = models.ImageTransformNet()
if cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=args.lr)
model.train()

print('===> Loading style image')
style_image = load_image(args.style_image, args.image_size)
style_image_batch = style_image.repeat(args.batchSize, 1, 1, 1)
style_image_batch = batch_rgb_to_bgr(style_image_batch)
if cuda:
    style_image_batch = style_image_batch.cuda()

print('===> Training model')


def train(epoch):
    epoch_loss = 0
    for iteration, batch in enumerate(data_loader, 1):
        x = Variable(batch[0])
        x = batch_rgb_to_bgr(x)
        if cuda:
            x = x.cuda()

        y_hat = model(x)
        # loss_function subtracts the VGG mean in place, so hand it fresh
        # copies of the content and style targets on every iteration
        xc = Variable(x.data.clone(), volatile=True)
        xs = Variable(style_image_batch.clone(), volatile=True)
        optimizer.zero_grad()
        loss = loss_function(args.content_weight, args.style_weight, xc, xs, y_hat, cuda)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.data[0]
        print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(data_loader), loss.data[0]))

    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(data_loader)))


def checkpoint(epoch):
    model_out_path = "model_epoch_{}.pth".format(epoch)
    torch.save(model.state_dict(), model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))


for epoch in range(1, args.epochs + 1):
    train(epoch)
    checkpoint(epoch)
--------------------------------------------------------------------------------