├── style
│   ├── 1.jpg
│   ├── wave.jpg
│   ├── Starry.jpg
│   ├── candy.jpg
│   ├── udnie.jpg
│   ├── wreck.jpg
│   ├── la_muse.jpg
│   ├── style11.jpg
│   ├── seated_nude.jpg
│   ├── rain_princess.jpg
│   └── Composition-VII.jpg
├── model
│   ├── __pycache__
│   │   ├── Net.cpython-36.pyc
│   │   ├── VGG.cpython-36.pyc
│   │   ├── SANet.cpython-36.pyc
│   │   ├── Decoder.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   └── Transform.cpython-36.pyc
│   ├── Transform.py
│   ├── Decoder.py
│   ├── VGG.py
│   ├── SANet.py
│   └── Net.py
├── dataset
│   ├── __pycache__
│   │   ├── MPI_seq.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   └── dataset.cpython-36.pyc
│   ├── video_dataset.py
│   └── dataset.py
├── weight
│   └── readme.md
├── README.md
├── image_transfer.py
├── viedo_transfer.py
├── image_train.py
└── video_train.py

/style/1.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/style/1.jpg
--------------------------------------------------------------------------------
/style/wave.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/style/wave.jpg
--------------------------------------------------------------------------------
/style/Starry.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/style/Starry.jpg
--------------------------------------------------------------------------------
/style/candy.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/style/candy.jpg
--------------------------------------------------------------------------------
/style/udnie.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/style/udnie.jpg
--------------------------------------------------------------------------------
/style/wreck.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/style/wreck.jpg
--------------------------------------------------------------------------------
/style/la_muse.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/style/la_muse.jpg
--------------------------------------------------------------------------------
/style/style11.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/style/style11.jpg
--------------------------------------------------------------------------------
/style/seated_nude.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/style/seated_nude.jpg
--------------------------------------------------------------------------------
/style/rain_princess.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/style/rain_princess.jpg
--------------------------------------------------------------------------------
/style/Composition-VII.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/style/Composition-VII.jpg
--------------------------------------------------------------------------------
/model/__pycache__/Net.cpython-36.pyc:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/model/__pycache__/Net.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/VGG.cpython-36.pyc:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/model/__pycache__/VGG.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/SANet.cpython-36.pyc:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/model/__pycache__/SANet.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/Decoder.cpython-36.pyc:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/model/__pycache__/Decoder.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/model/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/MPI_seq.cpython-36.pyc:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/dataset/__pycache__/MPI_seq.cpython-36.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/dataset/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/dataset/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/Transform.cpython-36.pyc:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/EnchanterXiao/video-style-transfer/HEAD/model/__pycache__/Transform.cpython-36.pyc
--------------------------------------------------------------------------------
/weight/readme.md:
--------------------------------------------------------------------------------
1 | ## Original weights:
2 | -decoder: https://pan.baidu.com/s/1q_tixQ84E4q9hvKfjOSJfA
3 | 
4 | -transformer: https://pan.baidu.com/s/1Q7XDJvAdedCwkOslNbYzpw
5 | 
6 | -vgg_normalised: https://pan.baidu.com/s/1jvS4oLSsV92jYKGoaLaKzQ
7 | 
8 | ## Fine-tuned weights (for video):
9 | -decoder: https://pan.baidu.com/s/1VhnHG0tkaoGO3vJAF7yXpA
10 | 
11 | -transformer: https://pan.baidu.com/s/1XfDq10PUXHR0bDk0031j6g
12 | 
--------------------------------------------------------------------------------
/model/Transform.py:
--------------------------------------------------------------------------------
1 | from model.SANet import *
2 | 
3 | class Transform(nn.Module):
4 |     def __init__(self, in_planes):
5 |         super(Transform, self).__init__()
6 |         self.sanet4_1 = SANet(in_planes = in_planes)
7 |         self.sanet5_1 = SANet(in_planes = in_planes)
8 |         self.upsample5_1 = nn.Upsample(scale_factor=2, mode='nearest')
9 |         self.merge_conv_pad = nn.ReflectionPad2d((1, 1, 1, 1))
10 |         self.merge_conv = nn.Conv2d(in_planes, in_planes, (3, 3))
11 |     def forward(self, content4_1, style4_1, content5_1, style5_1):
12 |         return self.merge_conv(self.merge_conv_pad(self.sanet4_1(content4_1, style4_1) + self.upsample5_1(self.sanet5_1(content5_1, style5_1))))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # video-style-transfer
2 | This is a PyTorch implementation of video style transfer based on "Arbitrary Style Transfer with Style-Attentional Networks".
3 | 
4 | Official paper: https://arxiv.org/abs/1812.02342v5
5 | 
6 | Reference implementation: https://github.com/GlebBrykin/SANET
7 | 
8 | ## Dataset:
9 | COCO
10 | WikiArt
11 | Video sequences (60 videos, from https://www.videvo.net/)
12 | 
13 | ## Modifications:
14 | Added a temporal loss and a spatial smoothing (total variation) loss for fine-tuning.
15 | Fine-tuning uses pairs of consecutive frames taken from the videos.
16 | 
17 | ## Train:
18 | image_train: COCO + WikiArt
19 | video_train: video sequences + WikiArt
20 | 
21 | ## Test:
22 | image_transfer: single-image transfer
23 | video_transfer: video transfer
24 | 
25 | ## Result:
26 | demo1: https://pan.baidu.com/s/1o40EPY7_6FnMKsnjGaD24Q
27 | demo2: https://pan.baidu.com/s/1ZMPegXQCBB35NimzmEg2fQ
28 | 
29 | 
--------------------------------------------------------------------------------
/model/Decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | cfg = {
5 |     'Decoder': [256, 'U', 256, 256, 256, 128, 'U', 128, 64, 'U', 64, 3]
6 | }
7 | 
8 | class Decoder(nn.Module):
9 |     def __init__(self, vgg_name):
10 |         super(Decoder, self).__init__()
11 |         self.features = self._make_layers(cfg[vgg_name])
12 | 
13 |     def forward(self, x):
14 |         out = self.features(x)
15 |         return out
16 | 
17 |     def _make_layers(self, cfg):
18 |         layers = []
19 |         in_channels = 512
20 |         for x in cfg:
21 |             if x == 'U':
22 |                 layers += [nn.Upsample(scale_factor=2, mode='nearest')]
23 |             elif x == 3:
24 |                 layers += [nn.ReflectionPad2d((1, 1, 1, 1)),
25 |                            nn.Conv2d(in_channels, x, kernel_size=3)]
26 |             else:
27 |                 layers += [nn.ReflectionPad2d((1, 1, 1, 1)),
28 |                            nn.Conv2d(in_channels, x, kernel_size=3),
29 |                            nn.ReLU(inplace=True)]
30 |             in_channels = x
31 |         return nn.Sequential(*layers)
32 | 
33 | if __name__ == "__main__":
34 |     decoder = Decoder('Decoder')
35 |     print(decoder)
--------------------------------------------------------------------------------
/model/VGG.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | cfg = {
5 |     'VGG11': [64, 'M', 128, 'M', 256, 'M', 512, 'M', 512, 'M'],
6 |     'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
7 |     'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
512, 512, 512, 'M'], 8 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 9 | } 10 | 11 | class VGG(nn.Module): 12 | def __init__(self, vgg_name): 13 | super(VGG, self).__init__() 14 | self.features = self._make_layers(cfg[vgg_name]) 15 | 16 | def forward(self, x): 17 | out = self.features(x) 18 | return out 19 | 20 | def _make_layers(self, cfg): 21 | layers = [] 22 | in_channels = 3 23 | layers += [nn.Conv2d(3, 3, (1, 1))] 24 | for x in cfg: 25 | if x == 'M': 26 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 27 | else: 28 | layers += [nn.ReflectionPad2d((1, 1, 1, 1)), 29 | nn.Conv2d(in_channels, x, kernel_size=3), 30 | nn.ReLU(inplace=True)] 31 | in_channels = x 32 | return nn.Sequential(*layers) 33 | 34 | 35 | if __name__ == "__main__": 36 | net = VGG('VGG19') 37 | print(net.features) -------------------------------------------------------------------------------- /dataset/video_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy as np 3 | from torch.utils import data 4 | from PIL import Image 5 | from PIL import ImageFile 6 | import torch.backends.cudnn as cudnn 7 | from torchvision import transforms 8 | import os 9 | 10 | 11 | 12 | class Video_dataset(data.Dataset): 13 | def __init__(self, root, transform): 14 | super(Video_dataset, self).__init__() 15 | self.root = root 16 | self.paths = os.listdir(self.root) 17 | self.transform = transform 18 | 19 | def __getitem__(self, index): 20 | path = self.paths[index] 21 | img = Image.open(os.path.join(self.root, path)).convert('RGB') 22 | img = self.transform(img) 23 | if index<1: 24 | path2 = self.paths[index] 25 | else: 26 | path2 = self.paths[index-1] 27 | img2 = Image.open(os.path.join(self.root, path2)).convert('RGB') 28 | img2 = self.transform(img2) 29 | return img, img2 30 | 31 | def __len__(self): 32 | return len(self.paths) 33 | 34 | def name(self): 35 | return 'MPISeq' 36 | 37 | 38 | if __name__ == '__main__': 39 | seqs_str = '''alley_1 40 | alley_2 41 | ambush_2 42 | ambush_4 43 | ambush_5''' 44 | seqs = [seq.strip() for seq in seqs_str.split()] 45 | 46 | 47 | -------------------------------------------------------------------------------- /model/SANet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def calc_mean_std(feat, eps=1e-5): 5 | # eps is a small value added to the variance to avoid divide-by-zero. 
6 | size = feat.size() 7 | assert (len(size) == 4) 8 | N, C = size[:2] 9 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 10 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 11 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 12 | return feat_mean, feat_std 13 | 14 | def mean_variance_norm(feat): 15 | size = feat.size() 16 | mean, std = calc_mean_std(feat) 17 | normalized_feat = (feat - mean.expand(size)) / std.expand(size) 18 | return normalized_feat 19 | 20 | class SANet(nn.Module): 21 | 22 | def __init__(self, in_planes): 23 | super(SANet, self).__init__() 24 | self.f = nn.Conv2d(in_planes, in_planes, (1, 1)) 25 | self.g = nn.Conv2d(in_planes, in_planes, (1, 1)) 26 | self.h = nn.Conv2d(in_planes, in_planes, (1, 1)) 27 | self.sm = nn.Softmax(dim=-1) 28 | self.out_conv = nn.Conv2d(in_planes, in_planes, (1, 1)) 29 | 30 | def forward(self, content, style): 31 | F = self.f(mean_variance_norm(content)) 32 | G = self.g(mean_variance_norm(style)) 33 | H = self.h(style) 34 | b, c, h, w = F.size() 35 | F = F.view(b, -1, w * h).permute(0, 2, 1) 36 | b, c, h, w = G.size() 37 | G = G.view(b, -1, w * h) 38 | S = torch.bmm(F, G) 39 | S = self.sm(S) 40 | b, c, h, w = H.size() 41 | H = H.view(b, -1, w * h) 42 | O = torch.bmm(H, S.permute(0, 2, 1)) 43 | b, c, h, w = content.size() 44 | O = O.view(b, c, h, w) 45 | O = self.out_conv(O) 46 | O += content 47 | return O 48 | 49 | 50 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils import data 3 | from PIL import Image 4 | from PIL import ImageFile 5 | import torch.backends.cudnn as cudnn 6 | from torchvision import transforms 7 | import os 8 | 9 | def InfiniteSampler(n): 10 | # i = 0 11 | i = n - 1 12 | order = np.random.permutation(n) 13 | while True: 14 | yield order[i] 15 | i += 1 16 | if i >= n: 17 | np.random.seed() 18 | order = np.random.permutation(n) 19 | i = 0 20 | 21 | 22 | class InfiniteSamplerWrapper(data.sampler.Sampler): 23 | def __init__(self, data_source): 24 | self.num_samples = len(data_source) 25 | 26 | def __iter__(self): 27 | return iter(InfiniteSampler(self.num_samples)) 28 | 29 | def __len__(self): 30 | return 2 ** 31 31 | 32 | cudnn.benchmark = True 33 | Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError 34 | ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated 35 | 36 | 37 | def train_transform(): 38 | transform_list = [ 39 | transforms.Resize(size=(512, 512)), 40 | transforms.RandomCrop(256), 41 | transforms.ToTensor() 42 | ] 43 | return transforms.Compose(transform_list) 44 | 45 | def train_transform2(): 46 | transform_list = [ 47 | transforms.Resize(size=(256, 256)), 48 | transforms.ToTensor() 49 | ] 50 | return transforms.Compose(transform_list) 51 | 52 | class FlatFolderDataset(data.Dataset): 53 | def __init__(self, root, transform): 54 | super(FlatFolderDataset, self).__init__() 55 | self.root = root 56 | self.paths = os.listdir(self.root) 57 | self.transform = transform 58 | 59 | def __getitem__(self, index): 60 | path = self.paths[index] 61 | img = Image.open(os.path.join(self.root, path)).convert('RGB') 62 | img = self.transform(img) 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.paths) 67 | 68 | def name(self): 69 | return 'FlatFolderDataset' -------------------------------------------------------------------------------- /image_transfer.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | from model.VGG import * 3 | from model.Decoder import * 4 | from model.Transform import * 5 | import os 6 | from PIL import Image 7 | from os.path import basename 8 | from os.path import splitext 9 | from torchvision import transforms 10 | from torchvision.utils import save_image 11 | 12 | def test_transform(): 13 | transform_list = [] 14 | transform_list.append(transforms.ToTensor()) 15 | transform = transforms.Compose(transform_list) 16 | return transform 17 | 18 | parser = argparse.ArgumentParser() 19 | 20 | # Basic options 21 | parser.add_argument('--content', type=str, default='input/chicago.jpg', 22 | help='File path to the content image') 23 | parser.add_argument('--style', type=str, default='style/1.jpg', 24 | help='File path to the style image, or multiple style \ 25 | images separated by commas if you want to do style \ 26 | interpolation or spatial control') 27 | parser.add_argument('--steps', type=str, default=1) 28 | parser.add_argument('--vgg', type=str, default='weight/vgg_normalised.pth') 29 | parser.add_argument('--decoder', type=str, default='decoder_iter_500000.pth') 30 | parser.add_argument('--transform', type=str, default='weight/transformer_iter_500000.pth') 31 | 32 | # Additional options 33 | parser.add_argument('--save_ext', default='.jpg', 34 | help='The extension name of the output image') 35 | parser.add_argument('--output', type=str, default='output', 36 | help='Directory to save the output image(s)') 37 | 38 | # Advanced options 39 | 40 | args = parser.parse_args('') 41 | 42 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 43 | 44 | if not os.path.exists(args.output): 45 | os.mkdir(args.output) 46 | 47 | decoder = Decoder('Decoder') 48 | transform = Transform(in_planes=512) 49 | vgg = VGG('VGG19') 50 | 51 | decoder.eval() 52 | transform.eval() 53 | vgg.eval() 54 | 55 | decoder.load_state_dict(torch.load(args.decoder)) 56 | transform.load_state_dict(torch.load(args.transform)) 57 | vgg.features.load_state_dict(torch.load(args.vgg)) 58 | 59 | # norm = nn.Sequential(*list(vgg.features())[:1]) 60 | enc_1 = nn.Sequential(*list(vgg.features.children())[:4]) # input -> relu1_1 61 | enc_2 = nn.Sequential(*list(vgg.features.children())[4:11]) # relu1_1 -> relu2_1 62 | enc_3 = nn.Sequential(*list(vgg.features.children())[11:18]) # relu2_1 -> relu3_1 63 | enc_4 = nn.Sequential(*list(vgg.features.children())[18:31]) # relu3_1 -> relu4_1 64 | enc_5 = nn.Sequential(*list(vgg.features.children())[31:44]) # relu4_1 -> relu5_1 65 | 66 | 67 | # norm.to(device) 68 | enc_1.to(device) 69 | enc_2.to(device) 70 | enc_3.to(device) 71 | enc_4.to(device) 72 | enc_5.to(device) 73 | transform.to(device) 74 | decoder.to(device) 75 | 76 | content_tf = test_transform() 77 | style_tf = test_transform() 78 | 79 | content = content_tf(Image.open(args.content)) 80 | style = style_tf(Image.open(args.style)) 81 | 82 | style = style.to(device).unsqueeze(0) 83 | content = content.to(device).unsqueeze(0) 84 | with torch.no_grad(): 85 | for x in range(args.steps): 86 | print('iteration ' + str(x)) 87 | 88 | Content4_1 = enc_4(enc_3(enc_2(enc_1(content)))) 89 | Content5_1 = enc_5(Content4_1) 90 | 91 | Style4_1 = enc_4(enc_3(enc_2(enc_1(style)))) 92 | Style5_1 = enc_5(Style4_1) 93 | 94 | content = decoder(transform(Content4_1, Style4_1, Content5_1, Style5_1)) 95 | 96 | content.clamp(0, 255) 97 | 98 | content = content.cpu() 99 | 100 | output_name = 
'{:s}/{:s}_stylized_{:s}{:s}'.format( 101 | args.output, splitext(basename(args.content))[0], 102 | splitext(basename(args.style))[0], args.save_ext 103 | ) 104 | save_image(content, output_name) -------------------------------------------------------------------------------- /viedo_transfer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from model.VGG import * 3 | from model.Decoder import * 4 | from model.Transform import * 5 | import os 6 | from PIL import Image 7 | from os.path import basename 8 | from os.path import splitext 9 | from torchvision import transforms 10 | from torchvision.utils import save_image 11 | import cv2 as cv 12 | import numpy as np 13 | import time 14 | 15 | def test_transform(): 16 | transform_list = [] 17 | transform_list.append(transforms.ToTensor()) 18 | transform = transforms.Compose(transform_list) 19 | return transform 20 | 21 | parser = argparse.ArgumentParser() 22 | 23 | # Basic options 24 | parser.add_argument('--content', type=str, default = 'input/test3.avi', 25 | help='File path to the content image') 26 | parser.add_argument('--style', type=str, default = 'style/style11.jpg', 27 | help='File path to the style image, or multiple style \ 28 | images separated by commas if you want to do style \ 29 | interpolation or spatial control') 30 | parser.add_argument('--steps', type=str, default = 1) 31 | parser.add_argument('--vgg', type=str, default = 'weight/vgg_normalised.pth') 32 | parser.add_argument('--decoder', type=str, default = 'experiments4/decoder_iter_600000.pth') 33 | parser.add_argument('--transform', type=str, default = 'experiments4/transformer_iter_600000.pth') 34 | # Additional options 35 | parser.add_argument('--save_ext', default = 'output+', 36 | help='The extension name of the output viedo') 37 | parser.add_argument('--output', type=str, default = 'output', 38 | help='Directory to save the output image(s)') 39 | 40 | # Advanced options 41 | args = parser.parse_args('') 42 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 43 | if not os.path.exists(args.output): 44 | os.mkdir(args.output) 45 | decoder = Decoder('Decoder') 46 | transform = Transform(in_planes=512) 47 | vgg = VGG('VGG19') 48 | 49 | decoder.eval() 50 | transform.eval() 51 | vgg.eval() 52 | 53 | # decoder.features.load_state_dict(torch.load(args.decoder)) 54 | decoder.load_state_dict(torch.load(args.decoder)) 55 | transform.load_state_dict(torch.load(args.transform)) 56 | vgg.features.load_state_dict(torch.load(args.vgg)) 57 | 58 | enc_1 = nn.Sequential(*list(vgg.features.children())[:4]) # input -> relu1_1 59 | enc_2 = nn.Sequential(*list(vgg.features.children())[4:11]) # relu1_1 -> relu2_1 60 | enc_3 = nn.Sequential(*list(vgg.features.children())[11:18]) # relu2_1 -> relu3_1 61 | enc_4 = nn.Sequential(*list(vgg.features.children())[18:31]) # relu3_1 -> relu4_1 62 | enc_5 = nn.Sequential(*list(vgg.features.children())[31:44]) # relu4_1 -> relu5_1 63 | 64 | 65 | enc_1.to(device) 66 | enc_2.to(device) 67 | enc_3.to(device) 68 | enc_4.to(device) 69 | enc_5.to(device) 70 | transform.to(device) 71 | decoder.to(device) 72 | 73 | content_tf = test_transform() 74 | style_tf = test_transform() 75 | 76 | cap = cv.VideoCapture(args.content) 77 | 78 | style = style_tf(Image.open(args.style)) 79 | 80 | style = style.to(device).unsqueeze(0) 81 | fourcc = cv.VideoWriter_fourcc(*'XVID') 82 | fps = cap.get(cv.CAP_PROP_FPS) 83 | 84 | out = cv.VideoWriter(args.save_ext, fourcc, fps, (512, 512)) 85 | 86 | fps = 
0 87 | while(1): 88 | fps += 1 89 | ret, frame = cap.read() 90 | if ret!=True: 91 | break 92 | frame = cv.resize(frame, (512, 512)) 93 | frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB) 94 | frame = frame.astype(np.float) / 255 95 | content = torch.from_numpy(frame) 96 | content = content.transpose(0, 2) 97 | content = content.transpose(1, 2) 98 | content = content.type(torch.FloatTensor) 99 | content = content.to(device).unsqueeze(0) 100 | 101 | with torch.no_grad(): 102 | start = time.time() 103 | 104 | Style4_1 = enc_4(enc_3(enc_2(enc_1(style)))) 105 | Style5_1 = enc_5(Style4_1) 106 | Content4_1 = enc_4(enc_3(enc_2(enc_1(content)))) 107 | Content5_1 = enc_5(Content4_1) 108 | 109 | content = decoder(transform(Content4_1, Style4_1, Content5_1, Style5_1)) 110 | 111 | end = time.time() 112 | content.clamp(0, 255) 113 | content = content.cpu() 114 | content = content[0] 115 | content = content.transpose(1, 2) 116 | content = content.transpose(0, 2) 117 | content = content.numpy()*255 118 | 119 | output_value = np.clip(content, 0, 255).astype(np.uint8) 120 | output_value = cv.cvtColor(output_value, cv.COLOR_RGB2BGR) 121 | 122 | out.write(output_value) 123 | print('frame(%s) transfer! fps:%.2f' % (fps,1/(end-start))) -------------------------------------------------------------------------------- /image_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import os 4 | import torch 5 | from PIL import Image 6 | from PIL import ImageFile 7 | from tensorboardX import SummaryWriter 8 | from torchvision import transforms 9 | from tqdm import tqdm 10 | from model.Decoder import * 11 | from model.VGG import * 12 | from model.SANet import * 13 | from model.Net import * 14 | from dataset.dataset import * 15 | 16 | import numpy as np 17 | from torch.utils import data 18 | 19 | def adjust_learning_rate(optimizer, iteration_count): 20 | """Imitating the original implementation""" 21 | lr = args.lr / (1.0 + args.lr_decay * iteration_count) 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] = lr 24 | 25 | 26 | parser = argparse.ArgumentParser() 27 | # Basic options 28 | parser.add_argument('--content_dir', type=str, default='/media/wwh/XIaoxin/Datasets/coco_2014/data/coco_2014/images/train2014/', 29 | help='Directory path to a batch of content images') 30 | parser.add_argument('--style_dir', type=str, default='/home/wwh/Desktop/wikiArt/', 31 | help='Directory path to a batch of style images') 32 | parser.add_argument('--vgg', type=str, default='weight/vgg_normalised.pth') 33 | 34 | # training options 35 | parser.add_argument('--save_dir', default='./experiments', 36 | help='Directory to save the model') 37 | parser.add_argument('--log_dir', default='./logs', 38 | help='Directory to save the log') 39 | parser.add_argument('--lr', type=float, default=1e-4) 40 | parser.add_argument('--lr_decay', type=float, default=5e-5) 41 | parser.add_argument('--max_iter', type=int, default=1600000) 42 | parser.add_argument('--batch_size', type=int, default=1) 43 | parser.add_argument('--style_weight', type=float, default=3.0) 44 | parser.add_argument('--content_weight', type=float, default=1.0) 45 | parser.add_argument('--n_threads', type=int, default=16) 46 | parser.add_argument('--save_model_interval', type=int, default=10000) 47 | parser.add_argument('--start_iter', type=float, default=500000) 48 | args = parser.parse_args('') 49 | 50 | device = torch.device('cuda') 51 | 52 | decoder = Decoder('Decoder') 53 | vgg = 
VGG('VGG19')
54 | 
55 | vgg.features.load_state_dict(torch.load(args.vgg))
56 | vgg = nn.Sequential(*list(vgg.features.children())[:44])
57 | network = Net(vgg, decoder, args.start_iter)
58 | 
59 | network.train()
60 | network.to(device)
61 | 
62 | content_tf = train_transform()
63 | style_tf = train_transform()
64 | 
65 | content_dataset = FlatFolderDataset(args.content_dir, content_tf)
66 | style_dataset = FlatFolderDataset(args.style_dir, style_tf)
67 | 
68 | content_iter = iter(data.DataLoader(
69 |     content_dataset, batch_size=args.batch_size,
70 |     sampler=InfiniteSamplerWrapper(content_dataset),
71 |     num_workers=args.n_threads))
72 | 
73 | style_iter = iter(data.DataLoader(
74 |     style_dataset, batch_size=args.batch_size,
75 |     sampler=InfiniteSamplerWrapper(style_dataset),
76 |     num_workers=args.n_threads))
77 | 
78 | optimizer = torch.optim.Adam([
79 |     {'params': network.decoder.parameters()},
80 |     {'params': network.transform.parameters()}], lr=args.lr)
81 | 
82 | # if(args.start_iter > 0):
83 | #     optimizer.load_state_dict(torch.load('optimizer_iter_' + str(args.start_iter) + '.pth'))
84 | 
85 | writer = SummaryWriter('runs/loss4')
86 | 
87 | for i in tqdm(range(int(args.start_iter), args.max_iter)):  # --start_iter is parsed as float, so cast to int for range()
88 |     adjust_learning_rate(optimizer, iteration_count=i)
89 |     content_images = next(content_iter).to(device)
90 |     style_images = next(style_iter).to(device)
91 |     # print(content_images.shape)
92 |     # print(style_images.shape)
93 |     loss_c, loss_s, l_identity1, l_identity2, T_loss, loss_v = network(content_images, style_images, content_images, True)  # Net also returns temporal and TV losses when video=True
94 |     loss_c = args.content_weight * loss_c
95 |     loss_s = args.style_weight * loss_s
96 |     loss = loss_c + loss_s + l_identity1 * 50 + l_identity2 * 1
97 |     writer.add_scalar('total loss', loss, global_step=i)
98 | 
99 |     optimizer.zero_grad()
100 |     loss.backward()
101 |     optimizer.step()
102 | 
103 |     if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
104 |         state_dict = decoder.state_dict()
105 |         for key in state_dict.keys():
106 |             state_dict[key] = state_dict[key].to(torch.device('cpu'))
107 |         torch.save(state_dict,
108 |                    '{:s}/decoder_iter_{:d}.pth'.format(args.save_dir,
109 |                                                        i + 1))
110 |         state_dict = network.transform.state_dict()
111 |         for key in state_dict.keys():
112 |             state_dict[key] = state_dict[key].to(torch.device('cpu'))
113 |         torch.save(state_dict,
114 |                    '{:s}/transformer_iter_{:d}.pth'.format(args.save_dir,
115 |                                                            i + 1))
116 |         state_dict = optimizer.state_dict()
117 |         torch.save(state_dict,
118 |                    '{:s}/optimizer_iter_{:d}.pth'.format(args.save_dir,
119 |                                                          i + 1))
120 | 
121 | writer.close()
--------------------------------------------------------------------------------
/video_train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import os
4 | import torch
5 | from torchvision import transforms
6 | from tqdm import tqdm
7 | from model.Decoder import *
8 | from model.VGG import *
9 | from model.SANet import *
10 | from model.Net import *
11 | import numpy as np
12 | from dataset.video_dataset import *
13 | from dataset.dataset import *
14 | from tensorboardX import SummaryWriter
15 | 
16 | def adjust_learning_rate(optimizer, iteration_count):
17 |     """Imitating the original implementation"""
18 |     lr = args.lr / (1.0 + args.lr_decay * iteration_count)
19 |     for param_group in optimizer.param_groups:
20 |         param_group['lr'] = lr
21 | 
22 | '''model_load'''
23 | parser = argparse.ArgumentParser()
24 | # Basic options
25 | parser.add_argument('--content_dir', type=str,
default='/media/wwh/XIaoxin/Datasets/video_frame/', 26 | help='Directory path to a batch of content images') 27 | parser.add_argument('--style_dir', type=str, default='/media/wwh/XIaoxin/Datasets/wikiArt/', 28 | help='Directory path to a batch of style images') 29 | parser.add_argument('--vgg', type=str, default='weight/vgg_normalised.pth') 30 | parser.add_argument('--decoder', type=str, default = 'weight/decoder_iter_500000.pth') 31 | parser.add_argument('--transform', type=str, default = 'weight/transformer_iter_500000.pth') 32 | # training options 33 | parser.add_argument('--save_dir', default='./experiments4', 34 | help='Directory to save the model') 35 | parser.add_argument('--log_dir', default='./logs', 36 | help='Directory to save the log') 37 | parser.add_argument('--lr', type=float, default=1e-4) 38 | parser.add_argument('--lr_decay', type=float, default=5e-5) 39 | parser.add_argument('--max_iter', type=int, default=600000) 40 | parser.add_argument('--batch_size', type=int, default=5) 41 | parser.add_argument('--style_weight', type=float, default=3.0) 42 | parser.add_argument('--content_weight', type=float, default=1.0) 43 | parser.add_argument('--temporal_weight', type=float, default=2.0) 44 | parser.add_argument('--v_weight', type=float, default=20.0) 45 | parser.add_argument('--n_threads', type=int, default=16) 46 | parser.add_argument('--save_model_interval', type=int, default=10000) 47 | parser.add_argument('--start_iter', type=float, default=500000) 48 | args = parser.parse_args('') 49 | 50 | 51 | device = torch.device('cuda') 52 | decoder = Decoder('Decoder') 53 | vgg = VGG('VGG19') 54 | 55 | vgg.features.load_state_dict(torch.load(args.vgg)) 56 | vgg = nn.Sequential(*list(vgg.features.children())[:44]) 57 | network = Net(vgg, decoder, args.start_iter) 58 | network.train() 59 | network.to(device) 60 | 61 | 62 | optimizer = torch.optim.Adam([ 63 | {'params': network.decoder.parameters()}, 64 | {'params': network.transform.parameters()}], lr=args.lr) 65 | 66 | 67 | style_tf = train_transform() 68 | content_tf = train_transform2() 69 | style_dataset = FlatFolderDataset(args.style_dir, style_tf) 70 | style_iter = iter(data.DataLoader( 71 | style_dataset, batch_size=args.batch_size, 72 | sampler=InfiniteSamplerWrapper(style_dataset), 73 | num_workers=args.n_threads)) 74 | content_dataset = Video_dataset(args.content_dir, content_tf) 75 | content_iter = iter(data.DataLoader( 76 | content_dataset, batch_size=args.batch_size, 77 | sampler=InfiniteSamplerWrapper(content_dataset), 78 | num_workers=args.n_threads)) 79 | 80 | writer = SummaryWriter('runs/loss4') 81 | 82 | for i in tqdm(range(args.start_iter, args.max_iter)): 83 | adjust_learning_rate(optimizer, iteration_count=i) 84 | style_images = next(style_iter).to(device) 85 | content_image1, content_image2 = next(content_iter) 86 | content_image1 = content_image1.to(device) 87 | content_image2 = content_image2.to(device) 88 | # print(content_image1.shape) 89 | loss_c, loss_s, l_identity1, l_identity2, loss_t, loss_v = network(content_image1, style_images, content_image2, True) 90 | loss_c = args.content_weight * loss_c 91 | loss_s = args.style_weight * loss_s 92 | loss_t = args.temporal_weight*loss_t 93 | loss_v = args.v_weight*loss_v 94 | loss = loss_c + loss_s + l_identity1 * 50 + l_identity2 * 1 + loss_t + loss_v 95 | # print(loss_t) 96 | # print(loss_v) 97 | writer.add_scalar('total loss', loss, global_step=i) 98 | writer.add_scalar('tempal_loss', loss_t, global_step=i) 99 | writer.add_scalar('variation_loss', loss_v, 
global_step=i) 100 | 101 | optimizer.zero_grad() 102 | loss.backward() 103 | optimizer.step() 104 | 105 | if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter: 106 | state_dict = decoder.state_dict() 107 | for key in state_dict.keys(): 108 | state_dict[key] = state_dict[key].to(torch.device('cpu')) 109 | torch.save(state_dict, 110 | '{:s}/decoder_iter_{:d}.pth'.format(args.save_dir, 111 | i + 1)) 112 | state_dict = network.transform.state_dict() 113 | for key in state_dict.keys(): 114 | state_dict[key] = state_dict[key].to(torch.device('cpu')) 115 | torch.save(state_dict, 116 | '{:s}/transformer_iter_{:d}.pth'.format(args.save_dir, 117 | i + 1)) 118 | state_dict = optimizer.state_dict() 119 | torch.save(state_dict, 120 | '{:s}/optimizer_iter_{:d}.pth'.format(args.save_dir, 121 | i + 1)) 122 | writer.close() -------------------------------------------------------------------------------- /model/Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import transforms 4 | from model.Transform import * 5 | from model.SANet import * 6 | from model.Decoder import * 7 | from model.VGG import * 8 | import numpy as np 9 | 10 | def calc_mean_std(feat, eps=1e-5): 11 | # eps is a small value added to the variance to avoid divide-by-zero. 12 | size = feat.size() 13 | assert (len(size) == 4) 14 | N, C = size[:2] 15 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 16 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 17 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 18 | return feat_mean, feat_std 19 | 20 | def mean_variance_norm(feat): 21 | size = feat.size() 22 | mean, std = calc_mean_std(feat) 23 | normalized_feat = (feat - mean.expand(size)) / std.expand(size) 24 | return normalized_feat 25 | 26 | def _calc_feat_flatten_mean_std(feat): 27 | # takes 3D feat (C, H, W), return mean and std of array within channels 28 | assert (feat.size()[0] == 3) 29 | assert (isinstance(feat, torch.FloatTensor)) 30 | feat_flatten = feat.view(3, -1) 31 | mean = feat_flatten.mean(dim=-1, keepdim=True) 32 | std = feat_flatten.std(dim=-1, keepdim=True) 33 | return feat_flatten, mean, std 34 | 35 | 36 | class Net(nn.Module): 37 | def __init__(self, encoder, decoder, start_iter): 38 | super(Net, self).__init__() 39 | vgg = encoder 40 | # self.enc_0 = nn.Sequential(*list(vgg.children())[:1]) 41 | # enc_layers = list(encoder.children()) 42 | self.enc_1 = nn.Sequential(*list(vgg.children())[:4]) # input -> relu1_1 43 | self.enc_2 = nn.Sequential(*list(vgg.children())[4:11]) # relu1_1 -> relu2_1 44 | self.enc_3 = nn.Sequential(*list(vgg.children())[11:18]) # relu2_1 -> relu3_1 45 | self.enc_4 = nn.Sequential(*list(vgg.children())[18:31]) # relu3_1 -> relu4_1 46 | self.enc_5 = nn.Sequential(*list(vgg.children())[31:44]) # relu4_1 -> relu5_1 47 | # transform 48 | self.transform = Transform(in_planes=512) 49 | self.decoder = decoder 50 | if (start_iter > 0): 51 | self.transform.load_state_dict(torch.load('weight/transformer_iter_' + str(start_iter) + '.pth')) 52 | self.decoder.load_state_dict(torch.load('weight/decoder_iter_' + str(start_iter) + '.pth')) 53 | self.mse_loss = nn.MSELoss() 54 | self.variation_loss = nn.L1Loss() 55 | # fix the encoder 56 | for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']: 57 | for param in getattr(self, name).parameters(): 58 | param.requires_grad = False 59 | 60 | self.dx_bias = np.zeros([256, 256]) 61 | self.dy_bias = np.zeros([256, 256]) 62 | for i in range(256): 
63 |             self.dx_bias[:, i] = i
64 |             self.dy_bias[i, :] = i
65 | 
66 |     # extract relu1_1, relu2_1, relu3_1, relu4_1, relu5_1 from input image
67 |     def encode_with_intermediate(self, input):
68 |         results = [input]
69 |         for i in range(5):
70 |             func = getattr(self, 'enc_{:d}'.format(i+1))
71 |             results.append(func(results[-1]))
72 |         return results[1:]
73 | 
74 |     def calc_content_loss(self, input, target, norm=False):
75 |         if (norm == False):
76 |             return self.mse_loss(input, target)
77 |         else:
78 |             return self.mse_loss(mean_variance_norm(input), mean_variance_norm(target))
79 | 
80 |     def calc_style_loss(self, input, target):
81 |         input_mean, input_std = calc_mean_std(input)
82 |         target_mean, target_std = calc_mean_std(target)
83 |         return self.mse_loss(input_mean, target_mean) + \
84 |                self.mse_loss(input_std, target_std)
85 | 
86 |     def calc_temporal_loss(self, x1, x2):  # temporal consistency: MSE between features of consecutive stylized frames
87 |         h = x1.shape[2]
88 |         w = x1.shape[3]
89 |         D = h*w
90 |         return self.mse_loss(x1, x2)
91 | 
92 |     def compute_total_variation_loss_l1(self, inputs):  # spatial smoothing: L1 distance between neighbouring pixels
93 |         h = inputs.shape[2]
94 |         w = inputs.shape[3]
95 |         h1 = inputs[:, :, 0:h-1, :]
96 |         h2 = inputs[:, :, 1:h, :]
97 |         w1 = inputs[:, :, :, 0:w-1]
98 |         w2 = inputs[:, :, :, 1:w]
99 |         return self.variation_loss(h1, h2)+self.variation_loss(w1, w2)
100 | 
101 |     def forward(self, content, style, content2=None, video=False):
102 |         style_feats = self.encode_with_intermediate(style)
103 |         content_feats = self.encode_with_intermediate(content)
104 |         stylized = self.transform(content_feats[3], style_feats[3], content_feats[4], style_feats[4])
105 |         g_t = self.decoder(stylized)
106 |         loss_v = self.compute_total_variation_loss_l1(g_t)
107 |         g_t_feats = self.encode_with_intermediate(g_t)
108 |         loss_c = self.calc_content_loss(g_t_feats[3], content_feats[3], norm=True) + self.calc_content_loss(
109 |             g_t_feats[4], content_feats[4], norm=True)
110 |         loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
111 |         for i in range(1, 5):
112 |             loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
113 |         """IDENTITY LOSSES"""
114 |         Icc = self.decoder(self.transform(content_feats[3], content_feats[3], content_feats[4], content_feats[4]))
115 |         Iss = self.decoder(self.transform(style_feats[3], style_feats[3], style_feats[4], style_feats[4]))
116 |         l_identity1 = self.calc_content_loss(Icc, content) + self.calc_content_loss(Iss, style)
117 |         Fcc = self.encode_with_intermediate(Icc)
118 |         Fss = self.encode_with_intermediate(Iss)
119 |         l_identity2 = self.calc_content_loss(Fcc[0], content_feats[0]) + self.calc_content_loss(Fss[0], style_feats[0])
120 |         for i in range(1, 5):
121 |             l_identity2 += self.calc_content_loss(Fcc[i], content_feats[i]) + self.calc_content_loss(Fss[i],
122 |                                                                                                      style_feats[i])
123 |         if video == False:
124 |             return loss_c, loss_s, l_identity1, l_identity2, loss_v
125 |         else:
126 |             content_feats2 = self.encode_with_intermediate(content2)
127 |             stylized2 = self.transform(content_feats2[3], style_feats[3], content_feats2[4], style_feats[4])
128 |             g_t2 = self.decoder(stylized2)
129 |             g_t2_feats = self.encode_with_intermediate(g_t2)
130 | 
131 |             temporal_loss = self.calc_temporal_loss(g_t_feats[0], g_t2_feats[0])
132 | 
133 | 
134 |             return loss_c, loss_s, l_identity1, l_identity2, temporal_loss, loss_v
135 | 
136 | 
137 | 
138 | 
--------------------------------------------------------------------------------
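For quick reference, the sketch below condenses the inference path that image_transfer.py and viedo_transfer.py share: the normalised VGG-19 encoder is sliced at relu4_1 and relu5_1, the two content/style feature pairs are fused by the style-attentional Transform module, and the Decoder maps the fused features back to an image. It is a minimal sketch, not the scripts themselves; the checkpoint file names and input paths are assumptions taken from the scripts' defaults and weight/readme.md, so adjust them to wherever the downloaded weights actually live.

```python
# Minimal single-image stylization sketch (condensed from image_transfer.py).
# Checkpoint and input paths below are assumptions based on the repo defaults.
import os
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

from model.VGG import VGG
from model.Decoder import Decoder
from model.Transform import Transform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vgg, decoder, transform = VGG('VGG19'), Decoder('Decoder'), Transform(in_planes=512)
vgg.features.load_state_dict(torch.load('weight/vgg_normalised.pth'))
decoder.load_state_dict(torch.load('weight/decoder_iter_500000.pth'))
transform.load_state_dict(torch.load('weight/transformer_iter_500000.pth'))

# Slice the fixed VGG encoder: layers [:31] map the input to relu4_1,
# layers [31:44] map relu4_1 to relu5_1 (the same split used in the transfer scripts).
enc_to_relu4_1 = nn.Sequential(*list(vgg.features.children())[:31]).to(device).eval()
enc_to_relu5_1 = nn.Sequential(*list(vgg.features.children())[31:44]).to(device).eval()
decoder, transform = decoder.to(device).eval(), transform.to(device).eval()

to_tensor = transforms.ToTensor()
content = to_tensor(Image.open('input/chicago.jpg').convert('RGB')).unsqueeze(0).to(device)
style = to_tensor(Image.open('style/1.jpg').convert('RGB')).unsqueeze(0).to(device)

with torch.no_grad():
    c4, s4 = enc_to_relu4_1(content), enc_to_relu4_1(style)   # relu4_1 features
    c5, s5 = enc_to_relu5_1(c4), enc_to_relu5_1(s4)            # relu5_1 features
    out = decoder(transform(c4, s4, c5, s5))                   # SANet fusion + decode

os.makedirs('output', exist_ok=True)
save_image(out.clamp(0, 1).cpu(), 'output/chicago_stylized_1.jpg')
```

The video path in viedo_transfer.py follows the same step per frame, with OpenCV handling the frame I/O and resizing to 512x512 before encoding.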