├── LICENSE
├── Config.py
├── DataSplit.py
├── readme.md
├── style_blending.py
├── train.py
├── model.py
├── test_video.py
├── test.py
├── vgg19.py
├── networks.py
├── blocks.py
└── Video_NST.ipynb
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Sooyoung Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/Config.py:
--------------------------------------------------------------------------------
1 | class Config:
2 |     phase = 'train' # set the phase to 'train' / 'test' / 'style_blending'
3 | train_continue = 'off' # on / off
4 |
5 |     data_num = 60000 # maximum number of training images
6 |
7 | content_dir = './COCO'
8 | style_dir = './WikiArt'
9 |
10 | file_n = 'main'
11 | log_dir = './log/' + file_n
12 | ckpt_dir = './ckpt/' + file_n
13 | img_dir = './Generated_images/' + file_n
14 |
15 | if phase == 'test':
16 | multi_to_multi = True
17 | test_content_size = 256
18 | test_style_size = 256
19 | content_dir = './testDataset'
20 | style_dir = './testDataset'
21 | img_dir = './output/'+file_n+'/'+str(test_content_size)
22 |
23 | elif phase == 'style_blending':
24 | blend_load_size = 256
25 | blend_dir = './blendingDataset/'
26 | content_img = blend_dir + str(blend_load_size) + '/A.jpg'
27 | style_high_img = blend_dir + str(blend_load_size) + '/B.jpg'
28 | style_low_img = blend_dir + str(blend_load_size) + '/C.jpg'
29 | img_dir = './output/'+file_n+'_blending_' + str(blend_load_size)
30 |
31 | # VGG pre-trained model
32 | vgg_model = './vgg_normalised.pth'
33 |
34 | ## basic parameters
35 | n_iter = 160000
36 | batch_size = 8
37 | lr = 0.0001
38 | lr_policy = 'step'
39 | lr_decay_iters = 50
40 | beta1 = 0.0
41 |
42 | # preprocess parameters
43 | load_size = 512
44 | crop_size = 256
45 |
46 | # model parameters
47 |     input_nc = 3 # number of input image channels
48 |     nf = 64 # number of feature map channels after the Encoder's first layer
49 |     output_nc = 3 # number of output image channels
50 | style_kernel = 3 # size of style kernel
51 |
52 | # Octave Convolution parameters
53 | alpha_in = 0.5 # input ratio of low-frequency channel
54 | alpha_out = 0.5 # output ratio of low-frequency channel
55 | freq_ratio = [1, 1] # [high, low] ratio at the last layer
56 |
57 | # Loss ratio
58 | lambda_percept = 1.0
59 | lambda_perc_cont = 1.0
60 | lambda_perc_style = 10.0
61 | lambda_const_style = 5.0
62 |
63 | # Else
64 | norm = 'instance'
65 | init_type = 'normal'
66 | init_gain = 0.02
67 | no_dropout = 'store_true'
68 | num_workers = 4
69 |
--------------------------------------------------------------------------------
/DataSplit.py:
--------------------------------------------------------------------------------
1 | from path import Path
2 | import glob
3 | # import torch
4 | import torch.nn as nn
5 | # import pandas as pd
6 | # import numpy as np
7 | from PIL import Image
8 | from torchvision.transforms import ToTensor, Compose, Resize, Normalize, RandomCrop
9 | import random
10 |
11 | Image.MAX_IMAGE_PIXELS = 1000000000
12 |
13 | class DataSplit(nn.Module):
14 | def __init__(self, config, phase='train'):
15 | super(DataSplit, self).__init__()
16 |
17 | self.transform = Compose([Resize(size=[config.load_size, config.load_size]),
18 | RandomCrop(size=(config.crop_size, config.crop_size)),
19 | ToTensor(),
20 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
21 |
22 | if phase == 'train':
23 | # Content image data
24 | img_dir = Path(config.content_dir+'/train')
25 | self.images = self.get_data(img_dir)
26 | if config.data_num < len(self.images):
27 | self.images = random.sample(self.images, config.data_num)
28 |
29 | # Style image data
30 | sty_dir = Path(config.style_dir+'/train')
31 | self.style_images = self.get_data(sty_dir)
32 | if len(self.images) < len(self.style_images):
33 | self.style_images = random.sample(self.style_images, len(self.images))
34 | elif len(self.images) > len(self.style_images):
35 | ratio = len(self.images) // len(self.style_images)
36 | bias = len(self.images) - ratio * len(self.style_images)
37 | self.style_images = self.style_images * ratio
38 | self.style_images += random.sample(self.style_images, bias)
39 | assert len(self.images) == len(self.style_images)
40 |
41 | elif phase == 'test':
42 | img_dir = Path(config.content_dir)
43 | self.images = self.get_data(img_dir)[:config.data_num]
44 |
45 | sty_dir = Path(config.style_dir)
46 | self.style_images = self.get_data(sty_dir)[:config.data_num]
47 |
48 | print('content dir:', img_dir)
49 | print('style dir:', sty_dir)
50 |
51 | def __len__(self):
52 | return len(self.images)
53 |
54 | def get_data(self, img_dir):
55 | file_type = ['*.jpg', '*.png', '*.jpeg', '*.tif']
56 | imgs = []
57 | for ft in file_type:
58 | imgs += sorted(img_dir.glob(ft))
59 | images = sorted(imgs)
60 | return images
61 |
62 | def __getitem__(self, index):
63 | cont_img = self.images[index]
64 | cont_img = Image.open(cont_img).convert('RGB')
65 | cont_img = self.transform(cont_img)
66 |
67 | sty_img = self.style_images[index]
68 | sty_img = Image.open(sty_img).convert('RGB')
69 | sty_img = self.transform(sty_img)
70 |
71 | return {'content_img': cont_img, 'style_img': sty_img}
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # AesFA: An Aesthetic Feature-Aware Arbitrary Neural Style Transfer (AAAI 2024)
2 | Official Pytorch code for "AesFA: An Aesthetic Feature-Aware Arbitrary Neural Style Transfer"
3 |
4 | - Project page: [The Official Website for AesFA.](https://aesfa-nst.github.io/AesFA/)
5 | - arXiv preprint:
6 |
7 | Co-first authors:
8 | - Joonwoo Kwon (joonkwon96@gmail.com, **pioneers@snu.ac.kr**)
9 | - Sooyoung Kim (sooyyoungg513@gmail.com, **rlatndud0513@snu.ac.kr**)
10 | If one of us doesn't reply, please contact the other :)
11 |
12 | ## Introduction
13 | 
14 | 
15 | Neural style transfer (NST) has evolved significantly in recent years. Yet, despite its rapid progress, existing NST methods either struggle to transfer aesthetic information from a style effectively or suffer from high computational costs and inefficient feature disentanglement due to their reliance on pre-trained models. This work proposes a lightweight but effective model, AesFA (Aesthetic Feature-Aware NST). The primary idea is to decompose the image via its frequencies to better disentangle aesthetic styles from the reference image, while the entire model is trained end to end so that pre-trained models can be excluded at inference completely. To improve the network's ability to extract more distinct representations and further enhance the stylization quality, this work introduces a new aesthetic feature-aware contrastive loss. Extensive experiments and ablations show that the approach not only outperforms recent NST methods in terms of stylization quality, but also achieves faster inference.
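
Under the hood, the frequency decomposition is implemented with octave convolutions (see blocks.py). As a minimal sketch of the idea (a simplified, hypothetical illustration rather than the repository's exact module), the first octave layer keeps a high-frequency branch at full resolution and computes a low-frequency branch at a reduced resolution:

```python
import torch.nn as nn

class FreqSplit(nn.Module):
    """Toy version of OctConv(type='first'): split features into high-frequency
    (fine texture) and low-frequency (coarse structure, color) branches."""
    def __init__(self, in_ch=3, out_ch=64, alpha_out=0.5):
        super().__init__()
        lf_ch = int(out_ch * alpha_out)   # share of channels given to the low-frequency branch
        hf_ch = out_ch - lf_ch
        self.conv_h = nn.Conv2d(in_ch, hf_ch, kernel_size=3, padding=1)
        self.conv_l = nn.Conv2d(in_ch, lf_ch, kernel_size=3, padding=1)
        self.pool = nn.AvgPool2d(2)       # low-frequency branch runs at half resolution

    def forward(self, x):
        hf = self.conv_h(x)               # high frequencies: edges, brushstrokes, texture
        lf = self.conv_l(self.pool(x))    # low frequencies: global color and structure
        return hf, lf
```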
16 |
17 |
18 | ## Environment:
19 | - python 3.7
20 | - pytorch 1.13.1
21 |
22 | ## Getting Started:
23 | **Clone this repo:**
24 | ```
25 | git clone https://github.com/Sooyyoungg/AesFA
26 | cd AesFA
27 | ```
28 |
29 | **Train:**
30 | - Download dataset [MS-COCO](https://cocodataset.org/#download) for content images and [WikiArt](https://www.kaggle.com/c/painter-by-numbers) for style images.
31 | - Download the pre-trained [vgg_normalised.pth](https://github.com/naoto0804/pytorch-AdaIN/releases/tag/v0.0.0).
32 | - Change the training options in the Config.py file (a reference sketch of the relevant settings follows the command below).
33 |     - 'phase' must be 'train'.
34 |     - Set 'train_continue' to 'on' to resume training from a previously saved checkpoint.
35 | ```python train.py```
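
For reference, a training run with this repository's defaults roughly corresponds to the Config.py settings below (paths are the repo defaults; point them at wherever you placed the datasets and the VGG weights):

```python
# Config.py -- options relevant for training (repo defaults)
phase = 'train'
train_continue = 'off'              # 'on' resumes from the latest checkpoint in ckpt_dir
content_dir = './COCO'              # MS-COCO content images (a 'train' subfolder is expected)
style_dir = './WikiArt'             # WikiArt style images (a 'train' subfolder is expected)
vgg_model = './vgg_normalised.pth'  # pre-trained VGG, used only for the training losses
n_iter = 160000
batch_size = 8
lr = 0.0001
```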
36 |
37 | **Test:**
38 | - Download pre-trained AesFA model [main.pth](https://drive.google.com/file/d/1Y3OutPAsmPmJcnZs07ZVbDFf6nn3RzxR/view?usp=drive_link)
39 | - Change the testing options in the Config.py file (a reference sketch of the relevant settings follows the command below).
40 |     - Set 'phase' to 'test' and adjust the other options, e.g., the data info (number, directories) and the image load / crop sizes.
41 |     - If you want to use content and style images of different sizes, set test_content_size and test_style_size to different values.
42 |     - With multi_to_multi you can choose whether every content image is stylized with every style image, or each content image is translated with a single corresponding style image.
43 | ```python test.py```
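
For reference, testing with this repository's defaults roughly corresponds to the Config.py settings below (the checkpoint path assumes the downloaded main.pth is placed under ./ckpt/main):

```python
# Config.py -- options relevant for testing (repo defaults)
phase = 'test'
multi_to_multi = True         # True: stylize every content image with every style image
test_content_size = 256       # content and style images may use different sizes
test_style_size = 256
content_dir = './testDataset'
style_dir = './testDataset'
file_n = 'main'               # so ckpt_dir = './ckpt/main', where main.pth is loaded from
```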
44 |
45 |
--------------------------------------------------------------------------------
/style_blending.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from torchvision.transforms import ToTensor, Compose, Resize, CenterCrop, Normalize, RandomCrop
6 |
7 | from Config import Config
8 | from model import AesFA_test
9 | from blocks import test_model_load
10 |
11 | def im_convert(tensor):
12 | image = tensor.to("cpu").clone().detach().numpy()
13 | image = image.transpose(0, 2, 3, 1)
14 | image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
15 | image = image.clip(0, 1)
16 | return image
17 |
18 | def do_transform(img, osize):
19 | transform = Compose([Resize(size=osize),
20 | CenterCrop(size=osize),
21 | ToTensor(),
22 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
23 | return transform(img).unsqueeze(0)
24 |
25 | def save_img(config, cont_name, sty_h_name, sty_l_name, content, style_h, style_l, stylized):
26 | real_A = im_convert(content)
27 | real_B_1 = im_convert(style_h)
28 | real_B_2 = im_convert(style_l)
29 | trs_AtoB = im_convert(stylized)
30 |
31 | A_image = Image.fromarray((real_A[0] * 255.0).astype(np.uint8))
32 | B1_image = Image.fromarray((real_B_1[0] * 255.0).astype(np.uint8))
33 | B2_image = Image.fromarray((real_B_2[0] * 255.0).astype(np.uint8))
34 | trs_image = Image.fromarray((trs_AtoB[0] * 255.0).astype(np.uint8))
35 |
36 | cont_name = cont_name.split('/')[-1].split('.')[0]
37 | sty_h_name = sty_h_name.split('/')[-1].split('.')[0]
38 | sty_l_name = sty_l_name.split('/')[-1].split('.')[0]
39 |
40 | A_image.save('{}/{:s}_content.jpg'.format(config.img_dir, cont_name))
41 | B1_image.save('{}/{:s}_high_style.jpg'.format(config.img_dir, sty_h_name))
42 | B2_image.save('{}/{:s}_low_style.jpg'.format(config.img_dir, sty_l_name))
43 | trs_image.save('{}/stylized_{:s}_{:s}_{:s}.jpg'.format(config.img_dir, cont_name, sty_h_name, sty_l_name))
44 |
45 | def main():
46 | config = Config()
47 | if not os.path.exists(config.img_dir):
48 | os.makedirs(config.img_dir)
49 |
50 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
51 | print('Version:', config.file_n)
52 | print(device)
53 |
54 | with torch.no_grad():
55 | ## Model load
56 | model = AesFA_test(config)
57 |
58 | ## Load saved model
59 | ckpt = config.ckpt_dir + '/main.pth'
60 | print("checkpoint: ", ckpt)
61 | model = test_model_load(checkpoint=ckpt, model=model)
62 | model.to(device)
63 |
64 | ## Style Blending
65 | real_A = Image.open(config.content_img).convert('RGB')
66 | style_high = Image.open(config.style_high_img).convert('RGB')
67 | style_low = Image.open(config.style_low_img).convert('RGB')
68 |
69 | real_A = do_transform(real_A, config.blend_load_size).to(device)
70 | style_high = do_transform(style_high, config.blend_load_size).to(device)
71 | style_low = do_transform(style_low, config.blend_load_size).to(device)
72 |
73 | stylized, during = model.style_blending(real_A, style_high, style_low)
74 | save_img(config, config.content_img, config.style_high_img, config.style_low_img, real_A, style_high, style_low, stylized)
75 | print("Time:", during)
76 |
77 | if __name__ == '__main__':
78 | main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import tensorboardX
5 |
6 | from Config import Config
7 | from DataSplit import DataSplit
8 | from model import AesFA
9 | from blocks import model_save, model_load, update_learning_rate
10 |
11 | from torch.utils.data import RandomSampler
12 |
13 | def mkoutput_dir(config):
14 | if not os.path.exists(config.log_dir):
15 | os.makedirs(config.log_dir)
16 | if not os.path.exists(config.ckpt_dir):
17 | os.makedirs(config.ckpt_dir)
18 |
19 | def get_n_params(model):
20 | total_params=0
21 | net_params = {'netE':0, 'netS':0, 'netG':0, 'vgg_loss':0}
22 |
23 | for name, param in model.named_parameters():
24 | net = name.split('.')[0]
25 | nn=1
26 | for s in list(param.size()):
27 | nn = nn*s
28 | net_params[net] += nn
29 | total_params += nn
30 | return total_params, net_params
31 |
32 | def im_convert(tensor):
33 | image = tensor.to("cpu").clone().detach().numpy()
34 | image = image.transpose(0, 2, 3, 1)
35 | image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
36 | image = image.clip(0, 1)
37 | return image
38 |
39 | def main():
40 | config = Config()
41 | mkoutput_dir(config)
42 |
43 | config.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
44 | print('cuda:', config.device)
45 | print('Version:', config.file_n)
46 |
47 | ########## Data Loader ##########
48 | train_data = DataSplit(config=config, phase='train')
49 | train_sampler = RandomSampler(train_data)
50 | data_loader_train = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, num_workers=config.num_workers, pin_memory=False, sampler=train_sampler)
51 | print("Train: ", train_data.__len__(), "images: ", len(data_loader_train), "x", config.batch_size,"(batch size) =", train_data.__len__())
52 |
53 | ########## load model ##########
54 | model = AesFA(config)
55 | model.to(config.device)
56 |
57 | # # of parameter
58 | param_num, net_params = get_n_params(model)
59 | print("# of parameter:", param_num)
60 | print("parameters of networks:", net_params)
61 |
62 | ########## load saved model - to continue previous learning ##########
63 | if config.train_continue == 'on':
64 | model, model.optimizer_E, model.optimizer_S, model.optimizer_G, epoch_start, tot_itr = model_load(checkpoint=None, ckpt_dir=config.ckpt_dir, model=model,
65 | optim_E=model.optimizer_E,
66 | optim_S=model.optimizer_S,
67 | optim_G=model.optimizer_G)
68 | print(epoch_start, "th epoch ", tot_itr, "th iteration model load")
69 | else:
70 | epoch_start = 0
71 | tot_itr = 0
72 |
73 | train_writer = tensorboardX.SummaryWriter(config.log_dir)
74 |
75 | ########## Training ##########
76 | # to save ckpt file starting with epoch and iteration 1
77 | epoch = epoch_start - 1
78 | tot_itr = tot_itr - 1
79 | while tot_itr < config.n_iter:
80 | epoch += 1
81 |
82 | for i, data in enumerate(data_loader_train):
83 | tot_itr += 1
84 | train_dict = model.train_step(data)
85 |
86 | real_A = im_convert(data['content_img'])
87 | real_B = im_convert(train_dict['style_img'])
88 | fake_B = im_convert(train_dict['fake_AtoB'])
89 | trs_high = im_convert(train_dict['fake_AtoB_high'])
90 | trs_low = im_convert(train_dict['fake_AtoB_low'])
91 |
92 | ## Tensorboard ##
93 | # tensorboard - loss
94 | train_writer.add_scalar('Loss_G', train_dict['G_loss'], tot_itr)
95 | train_writer.add_scalar('Loss_G_Percept', train_dict['G_Percept'], tot_itr)
96 | train_writer.add_scalar('Loss_G_Contrast', train_dict['G_Contrast'], tot_itr)
97 |
98 | # tensorboard - images
99 | train_writer.add_image('Content_Image_A', real_A, tot_itr, dataformats='NHWC')
100 | train_writer.add_image('Style_Image_B', real_B, tot_itr, dataformats='NHWC')
101 | train_writer.add_image('Generated_Image_AtoB', fake_B, tot_itr, dataformats='NHWC')
102 | train_writer.add_image('Translation_AtoB_high', trs_high, tot_itr, dataformats='NHWC')
103 | train_writer.add_image('Translation_AtoB_low', trs_low, tot_itr, dataformats='NHWC')
104 |
105 | print("Tot_itrs: %d/%d | Epoch: %d | itr: %d/%d | Loss_G: %.5f"%(tot_itr+1, config.n_iter, epoch+1, (i+1), len(data_loader_train), train_dict['G_loss']))
106 |
107 | if (tot_itr + 1) % 10000 == 0:
108 | model_save(ckpt_dir=config.ckpt_dir, model=model, optim_E=model.optimizer_E, optim_S=model.optimizer_S, optim_G=model.optimizer_G, epoch=epoch, itr=tot_itr)
109 | print(tot_itr+1, "th iteration model save")
110 |
111 | update_learning_rate(model.E_scheduler, model.optimizer_E)
112 | update_learning_rate(model.S_scheduler, model.optimizer_S)
113 | update_learning_rate(model.G_scheduler, model.optimizer_G)
114 |
115 | if __name__ == '__main__':
116 | main()
117 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import networks
4 | import blocks
5 | import time
6 |
7 | from vgg19 import vgg, VGG_loss
8 | from networks import EFDM_loss
9 |
10 | class AesFA(nn.Module):
11 | def __init__(self, config):
12 | super(AesFA, self).__init__()
13 |
14 | self.config = config
15 | self.device = self.config.device
16 |
17 | self.lr = config.lr
18 | self.lambda_percept = config.lambda_percept
19 | self.lambda_const_style = config.lambda_const_style
20 |
21 | self.netE = networks.define_network(net_type='Encoder', config = config) # Content Encoder
22 | self.netS = networks.define_network(net_type='Encoder', config = config) # Style Encoder
23 | self.netG = networks.define_network(net_type='Generator', config = config)
24 |
25 | self.vgg_loss = VGG_loss(config, vgg)
26 | self.efdm_loss = EFDM_loss()
27 |
28 | self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=self.config.lr, betas=(self.config.beta1, 0.99))
29 | self.optimizer_S = torch.optim.Adam(self.netS.parameters(), lr=self.config.lr, betas=(self.config.beta1, 0.99))
30 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.config.lr, betas=(self.config.beta1, 0.99))
31 |
32 | self.E_scheduler = blocks.get_scheduler(self.optimizer_E, config)
33 | self.S_scheduler = blocks.get_scheduler(self.optimizer_S, config)
34 | self.G_scheduler = blocks.get_scheduler(self.optimizer_G, config)
35 |
36 |
37 | def forward(self, data):
38 | self.real_A = data['content_img'].to(self.device)
39 | self.real_B = data['style_img'].to(self.device)
40 |
41 | self.content_A, _, _ = self.netE(self.real_A)
42 | _, self.style_B, self.content_B_feat = self.netS(self.real_B)
43 | self.style_B_feat = self.content_B_feat.copy()
44 | self.style_B_feat.append(self.style_B)
45 |
46 | self.trs_AtoB, self.trs_AtoB_high, self.trs_AtoB_low = self.netG(self.content_A, self.style_B)
47 |
48 | self.trs_AtoB_content, _, self.content_trs_AtoB_feat = self.netE(self.trs_AtoB)
49 | _, self.trs_AtoB_style, self.style_trs_AtoB_feat = self.netS(self.trs_AtoB)
50 | self.style_trs_AtoB_feat.append(self.trs_AtoB_style)
51 |
52 |
53 | def calc_G_loss(self):
54 | self.G_percept, self.neg_idx = self.vgg_loss.perceptual_loss(self.real_A, self.real_B, self.trs_AtoB)
55 | self.G_percept *= self.lambda_percept
56 |
57 | self.G_contrast = self.efdm_loss(self.content_B_feat, self.style_B_feat, self.content_trs_AtoB_feat, self.style_trs_AtoB_feat, self.neg_idx) * self.lambda_const_style
58 |
59 | self.G_loss = self.G_percept + self.G_contrast
60 |
61 |
62 | def train_step(self, data):
63 | self.set_requires_grad([self.netE, self.netS, self.netG], True)
64 |
65 | self.forward(data)
66 | self.calc_G_loss()
67 |
68 | self.optimizer_E.zero_grad()
69 | self.optimizer_S.zero_grad()
70 | self.optimizer_G.zero_grad()
71 | self.G_loss.backward()
72 | self.optimizer_E.step()
73 | self.optimizer_S.step()
74 | self.optimizer_G.step()
75 |
76 | train_dict = {}
77 | train_dict['G_loss'] = self.G_loss
78 | train_dict['G_Percept'] = self.G_percept
79 | train_dict['G_Contrast'] = self.G_contrast
80 |
81 | train_dict['style_img'] = self.real_B
82 | train_dict['fake_AtoB'] = self.trs_AtoB
83 | train_dict['fake_AtoB_high'] = self.trs_AtoB_high
84 | train_dict['fake_AtoB_low'] = self.trs_AtoB_low
85 |
86 | return train_dict
87 |
88 | def set_requires_grad(self, nets, requires_grad=False):
89 | if not isinstance(nets, list):
90 | nets = [nets]
91 | for net in nets:
92 | if net is not None:
93 | for param in net.parameters():
94 | param.requires_grad = requires_grad
95 |
96 |
97 | class AesFA_test(nn.Module):
98 | def __init__(self, config):
99 | super(AesFA_test, self).__init__()
100 |
101 | self.netE = networks.define_network(net_type='Encoder', config=config)
102 | self.netS = networks.define_network(net_type='Encoder', config=config)
103 | self.netG = networks.define_network(net_type='Generator', config=config)
104 |
105 | def forward(self, real_A, real_B, freq):
106 | with torch.no_grad():
107 | start = time.time()
108 | content_A = self.netE.forward_test(real_A, 'content')
109 | style_B = self.netS.forward_test(real_B, 'style')
110 | if freq:
111 | trs_AtoB, trs_AtoB_high, trs_AtoB_low = self.netG(content_A, style_B)
112 | end = time.time()
113 | during = end - start
114 | return trs_AtoB, trs_AtoB_high, trs_AtoB_low, during
115 | else:
116 | trs_AtoB = self.netG.forward_test(content_A, style_B)
117 | end = time.time()
118 | during = end - start
119 | return trs_AtoB, during
120 |
121 | def style_blending(self, real_A, real_B_1, real_B_2):
122 | with torch.no_grad():
123 | start = time.time()
124 | content_A = self.netE.forward_test(real_A, 'content')
125 | style_B1_h = self.netS.forward_test(real_B_1, 'style')[0]
126 | style_B2_l = self.netS.forward_test(real_B_2, 'style')[1]
127 | style_B = style_B1_h, style_B2_l
128 |
129 | trs_AtoB = self.netG.forward_test(content_A, style_B)
130 | end = time.time()
131 | during = end - start
132 |
133 | return trs_AtoB, during
--------------------------------------------------------------------------------
/test_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import glob
6 | from torchvision.transforms import ToTensor, Compose, Resize, CenterCrop, Normalize, RandomCrop
7 |
8 | from Config import Config
9 | from DataSplit import DataSplit
10 | from model import AesFA_test
11 | from blocks import test_model_load
12 |
13 |
14 | def load_img(img_name, img_size, device):
15 | img = Image.open(img_name).convert('RGB')
16 | img = do_transform(img, img_size).to(device)
17 | if len(img.shape) == 3:
18 | img = img.unsqueeze(0)
19 | return img
20 |
21 | def im_convert(tensor):
22 | image = tensor.to("cpu").clone().detach().numpy()
23 | image = image.transpose(0, 2, 3, 1)
24 | image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
25 | image = image.clip(0, 1)
26 |
27 | return image
28 |
29 | def do_transform(img, osize):
30 | # if config.phase == 'test':
31 | # osize = config.test_load_size
32 | # elif config.phase == 'style_blending':
33 | # osize = config.blend_load_size
34 | transform = Compose([Resize(size=osize),
35 | CenterCrop(size=osize),
36 | ToTensor(),
37 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
38 | return transform(img)
39 |
40 | def save_img(config, cont_name, sty_name, content, style, stylized, freq=False, high=None, low=None):
41 | real_A = im_convert(content)
42 | real_B = im_convert(style)
43 | trs_AtoB = im_convert(stylized)
44 |
45 | A_image = Image.fromarray((real_A[0] * 255.0).astype(np.uint8))
46 | B_image = Image.fromarray((real_B[0] * 255.0).astype(np.uint8))
47 | trs_image = Image.fromarray((trs_AtoB[0] * 255.0).astype(np.uint8))
48 |
49 | if config.phase == 'test':
50 | A_image.save('{}/content/{:s}_content.jpg'.format(config.img_dir, cont_name.stem))
51 | B_image.save('{}/style/{:s}_style.jpg'.format(config.img_dir, sty_name.stem))
52 | trs_image.save('{}/stylized/{:s}_stylized_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem))
53 | else:
54 | A_image.save('{}/content/{:s}_content.jpg'.format(config.img_dir, cont_name))
55 | B_image.save('{}/style/{:s}_style.jpg'.format(config.img_dir, sty_name))
56 | trs_image.save('{}/stylized/{:s}_stylized_{:s}.jpg'.format(config.img_dir, cont_name, sty_name))
57 |
58 | if freq:
59 | trs_AtoB_high = im_convert(high)
60 | trs_AtoB_low = im_convert(low)
61 |
62 | trsh_image = Image.fromarray((trs_AtoB_high[0] * 255.0).astype(np.uint8))
63 | trsl_image = Image.fromarray((trs_AtoB_low[0] * 255.0).astype(np.uint8))
64 |
65 | trsh_image.save('{}/{:s}_stylizing_high_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem))
66 | trsl_image.save('{}/{:s}_stylizing_low_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem))
67 |
68 |
69 | def main():
70 | config = Config()
71 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
72 | print('Version:', config.file_n)
73 | print(device)
74 |
75 | with torch.no_grad():
76 | ## Model load
77 | ckpt = config.ckpt_dir + '/main.pth'
78 | print("checkpoint: ", ckpt)
79 | model = AesFA_test(config)
80 | model = test_model_load(checkpoint=ckpt, model=model)
81 | model.to(device)
82 |
83 | if not os.path.exists(config.img_dir):
84 | os.makedirs(config.img_dir)
85 | os.makedirs(config.img_dir+'/content')
86 | os.makedirs(config.img_dir+'/style')
87 | os.makedirs(config.img_dir+'/stylized')
88 |
89 | ## Start Testing
90 | count = 0
91 | t_during = 0
92 | if config.phase == 'test':
93 | ## Data Loader
94 | test_data = DataSplit(config=config, phase='test')
95 | contents = test_data.images
96 | styles = test_data.style_images
97 | print("# of contents:", len(contents))
98 | print("# of styles:", len(styles))
99 |
100 | for idx in range(len(contents)):
101 | cont_name = contents[idx]
102 | content = load_img(cont_name, config.test_content_size, device)
103 |
104 | for i in range(len(styles)):
105 | sty_name = styles[i]
106 | style = load_img(sty_name, config.test_style_size, device)
107 |
108 | freq = False
109 | if freq:
110 | stylized, stylized_high, stylized_low, during = model(content, style, freq)
111 | save_img(config, cont_name, sty_name, content, style, stylized, freq, stylized_high, stylized_low)
112 | else:
113 | stylized, during = model(content, style, freq)
114 | save_img(config, cont_name, sty_name, content, style, stylized)
115 |
116 | count += 1
117 | print(count, idx+1, i+1, during)
118 | t_during += during
119 |
120 | elif config.phase == 'style_blending':
121 | contents = sorted(glob.glob(config.blend_dir+'/content/*.jpg'))
122 | for content in contents:
123 |                 cont_name = content.split('/')[-1].split('.')[0]
124 |                 content = load_img(content, config.blend_load_size, device)  # load from the full path; keep cont_name for saving
125 | 
126 |                 style_h = config.style_high_img
127 |                 style_l = config.style_low_img
128 | 
129 |                 sty_name = style_h.split('/')[-1].split('.')[0]
130 |                 style_h = Image.open(style_h).convert('RGB')
131 |                 style_h = do_transform(style_h, config.blend_load_size).to(device)
132 |                 if len(style_h.shape) == 3:
133 |                     style_h = style_h.unsqueeze(0)
134 | 
135 |                 style_l = Image.open(style_l).convert('RGB')
136 |                 style_l = do_transform(style_l, config.blend_load_size).to(device)
137 |                 if len(style_l.shape) == 3:
138 |                     style_l = style_l.unsqueeze(0)
139 | 
140 |                 stylized, during = model.style_blending(content, style_h, style_l)
141 |                 save_img(config, cont_name, sty_name, content, style_h, stylized)
142 |                 count += 1; t_during += during
143 |     t_during = float(t_during / max(count, 1))
144 |     print("[AesFA] Total images:", count, "Avg Testing time:", t_during)
145 |
146 |
147 | if __name__ == '__main__':
148 | main()
149 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import thop
5 | from PIL import Image
6 | from torchvision.transforms import ToTensor, Compose, Resize, CenterCrop, Normalize, RandomCrop
7 |
8 | from Config import Config
9 | from DataSplit import DataSplit
10 | from model import AesFA_test
11 | from blocks import test_model_load
12 |
13 |
14 | def load_img(img_name, img_size, device):
15 | img = Image.open(img_name).convert('RGB')
16 | img = do_transform(img, img_size).to(device)
17 | if len(img.shape) == 3:
18 | img = img.unsqueeze(0) # make batch dimension
19 | return img
20 |
21 | def im_convert(tensor):
22 | image = tensor.to("cpu").clone().detach().numpy()
23 | image = image.transpose(0, 2, 3, 1)
24 | image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
25 | image = image.clip(0, 1)
26 | return image
27 |
28 | def do_transform(img, osize):
29 | transform = Compose([Resize(size=osize), # Resize to keep aspect ratio
30 | CenterCrop(size=osize),
31 | ToTensor(),
32 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
33 | return transform(img)
34 |
35 | def save_img(config, cont_name, sty_name, content, style, stylized, freq=False, high=None, low=None):
36 | real_A = im_convert(content)
37 | real_B = im_convert(style)
38 | trs_AtoB = im_convert(stylized)
39 |
40 | A_image = Image.fromarray((real_A[0] * 255.0).astype(np.uint8))
41 | B_image = Image.fromarray((real_B[0] * 255.0).astype(np.uint8))
42 | trs_image = Image.fromarray((trs_AtoB[0] * 255.0).astype(np.uint8))
43 |
44 | A_image.save('{}/{:s}_content_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem))
45 | B_image.save('{}/{:s}_style_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem))
46 | trs_image.save('{}/{:s}_stylized_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem))
47 |
48 | if freq:
49 | trs_AtoB_high = im_convert(high)
50 | trs_AtoB_low = im_convert(low)
51 |
52 | trsh_image = Image.fromarray((trs_AtoB_high[0] * 255.0).astype(np.uint8))
53 | trsl_image = Image.fromarray((trs_AtoB_low[0] * 255.0).astype(np.uint8))
54 |
55 | trsh_image.save('{}/{:s}_stylizing_high_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem))
56 | trsl_image.save('{}/{:s}_stylizing_low_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem))
57 |
58 |
59 | def main():
60 | config = Config()
61 |
62 | config.gpu = 0
63 | device = torch.device('cuda:'+str(config.gpu) if torch.cuda.is_available() else 'cpu')
64 | print('Version:', config.file_n)
65 | print(device)
66 |
67 | with torch.no_grad():
68 | ## Data Loader
69 | test_bs = 1
70 | test_data = DataSplit(config=config, phase='test')
71 | data_loader_test = torch.utils.data.DataLoader(test_data, batch_size=test_bs, shuffle=False, num_workers=16, pin_memory=False)
72 | print("Test: ", test_data.__len__(), "images: ", len(data_loader_test), "x", test_bs, "(batch size) =", test_data.__len__())
73 |
74 | ## Model load
75 | ckpt = config.ckpt_dir + '/main.pth'
76 | print("checkpoint: ", ckpt)
77 | model = AesFA_test(config)
78 | model = test_model_load(checkpoint=ckpt, model=model)
79 | model.to(device)
80 |
81 | if not os.path.exists(config.img_dir):
82 | os.makedirs(config.img_dir)
83 |
84 | ## Start Testing
85 |         freq = False # whether to also save the high- and low-frequency outputs
86 | count = 0
87 | t_during = 0
88 |
89 | contents = test_data.images
90 | styles = test_data.style_images
91 | if config.multi_to_multi: # one content image, N style image
92 | tot_imgs = len(contents) * len(styles)
93 | for idx in range(len(contents)):
94 | cont_name = contents[idx] # path of content image
95 | content = load_img(cont_name, config.test_content_size, device)
96 |
97 | for i in range(len(styles)):
98 | sty_name = styles[i] # path of style image
99 | style = load_img(sty_name, config.test_style_size, device)
100 |
101 | if freq:
102 | stylized, stylized_high, stylized_low, during = model(content, style, freq)
103 | save_img(config, cont_name, sty_name, content, style, stylized, freq, stylized_high, stylized_low)
104 | else:
105 | stylized, during = model(content, style, freq)
106 | save_img(config, cont_name, sty_name, content, style, stylized)
107 |
108 | count += 1
109 | print(count, idx+1, i+1, during)
110 | t_during += during
111 | flops, params = thop.profile(model, inputs=(content, style, freq))
112 | print("GFLOPS: %.4f, Params: %.4f"% (flops/1e9, params/1e6))
113 | print("Max GPU memory allocated: %.4f GB" % (torch.cuda.max_memory_allocated(device=config.gpu) / 1024. / 1024. / 1024.))
114 |
115 | else:
116 | tot_imgs = len(contents)
117 | for idx in range(len(contents)):
118 | cont_name = contents[idx]
119 | content = load_img(cont_name, config.test_content_size, device)
120 |
121 | sty_name = styles[idx]
122 | style = load_img(sty_name, config.test_style_size, device)
123 |
124 | if freq:
125 | stylized, stylized_high, stylized_low, during = model(content, style, freq)
126 | save_img(config, cont_name, sty_name, content, style, stylized, freq, stylized_high, stylized_low)
127 | else:
128 | stylized, during = model(content, style, freq)
129 | save_img(config, cont_name, sty_name, content, style, stylized)
130 |
131 | t_during += during
132 | flops, params = thop.profile(model, inputs=(content, style, freq))
133 | print("GFLOPS: %.4f, Params: %.4f" % (flops / 1e9, params / 1e6))
134 | print("Max GPU memory allocated: %.4f GB" % (torch.cuda.max_memory_allocated(device=config.gpu) / 1024. / 1024. / 1024.))
135 |
136 |
137 |         t_during = float(t_during / tot_imgs)
138 | print("[AesFA] Content size:", config.test_content_size, "Style size:", config.test_style_size,
139 | " Total images:", tot_imgs, "Avg Testing time:", t_during)
140 |
141 |
142 | if __name__ == '__main__':
143 | main()
144 |
--------------------------------------------------------------------------------
/vgg19.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | vgg = nn.Sequential(
6 | nn.Conv2d(3, 3, (1, 1)),
7 | nn.ReflectionPad2d((1, 1, 1, 1)),
8 | nn.Conv2d(3, 64, (3, 3)),
9 |
10 | nn.ReLU(), # relu1-1
11 |
12 | nn.ReflectionPad2d((1, 1, 1, 1)),
13 | nn.Conv2d(64, 64, (3, 3)),
14 | nn.ReLU(), # relu1-2
15 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
16 | nn.ReflectionPad2d((1, 1, 1, 1)),
17 | nn.Conv2d(64, 128, (3, 3)),
18 |
19 | nn.ReLU(), # relu2-1
20 |
21 | nn.ReflectionPad2d((1, 1, 1, 1)),
22 | nn.Conv2d(128, 128, (3, 3)),
23 | nn.ReLU(), # relu2-2
24 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
25 | nn.ReflectionPad2d((1, 1, 1, 1)),
26 | nn.Conv2d(128, 256, (3, 3)),
27 |
28 | nn.ReLU(), # relu3-1
29 |
30 | nn.ReflectionPad2d((1, 1, 1, 1)),
31 | nn.Conv2d(256, 256, (3, 3)),
32 | nn.ReLU(), # relu3-2
33 | nn.ReflectionPad2d((1, 1, 1, 1)),
34 | nn.Conv2d(256, 256, (3, 3)),
35 | nn.ReLU(), # relu3-3
36 | nn.ReflectionPad2d((1, 1, 1, 1)),
37 | nn.Conv2d(256, 256, (3, 3)),
38 | nn.ReLU(), # relu3-4
39 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
40 | nn.ReflectionPad2d((1, 1, 1, 1)),
41 | nn.Conv2d(256, 512, (3, 3)),
42 |
43 | nn.ReLU(), # relu4-1, this is the last layer used
44 |
45 | nn.ReflectionPad2d((1, 1, 1, 1)),
46 | nn.Conv2d(512, 512, (3, 3)),
47 | nn.ReLU(), # relu4-2
48 | nn.ReflectionPad2d((1, 1, 1, 1)),
49 | nn.Conv2d(512, 512, (3, 3)),
50 | nn.ReLU(), # relu4-3
51 | nn.ReflectionPad2d((1, 1, 1, 1)),
52 | nn.Conv2d(512, 512, (3, 3)),
53 | nn.ReLU(), # relu4-4
54 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
55 | nn.ReflectionPad2d((1, 1, 1, 1)),
56 | nn.Conv2d(512, 512, (3, 3)),
57 |
58 | nn.ReLU(), # relu5-1
59 |
60 | nn.ReflectionPad2d((1, 1, 1, 1)),
61 | nn.Conv2d(512, 512, (3, 3)),
62 | nn.ReLU(), # relu5-2
63 | nn.ReflectionPad2d((1, 1, 1, 1)),
64 | nn.Conv2d(512, 512, (3, 3)),
65 | nn.ReLU(), # relu5-3
66 | nn.ReflectionPad2d((1, 1, 1, 1)),
67 | nn.Conv2d(512, 512, (3, 3)),
68 | nn.ReLU() # relu5-4
69 | )
70 |
71 |
72 | class VGG_loss(nn.Module):
73 | def __init__(self, config, vgg):
74 | super(VGG_loss, self).__init__()
75 |
76 | self.config = config
77 |
78 | vgg_pretrained = config.vgg_model
79 | vgg.load_state_dict(torch.load(vgg_pretrained))
80 | vgg = nn.Sequential(*list(vgg.children())[:43]) # depends on what layers you want to load
81 | vgg_enc_layers = list(vgg.children())
82 |
83 | self.n_layers = 4
84 | self.vgg_enc_1 = nn.Sequential(*vgg_enc_layers[:3]) # ~ conv1_1
85 | self.vgg_enc_2 = nn.Sequential(*vgg_enc_layers[3:10]) # conv1_1 ~ conv2_1
86 | self.vgg_enc_3 = nn.Sequential(*vgg_enc_layers[10:17]) # conv2_1 ~ conv3_1
87 | self.vgg_enc_4 = nn.Sequential(*vgg_enc_layers[17:30]) # conv3_1 ~ conv4_1
88 |
89 | self.mse_loss = nn.MSELoss()
90 |
91 | for name in ['vgg_enc_1', 'vgg_enc_2', 'vgg_enc_3', 'vgg_enc_4']:
92 | for param in getattr(self, name).parameters():
93 | param.requires_grad = False
94 |
95 | # extract relu1_1, relu2_1, relu3_1, relu4_1
96 | def encode_with_vgg_intermediate(self, input):
97 | results = [input]
98 | for i in range(self.n_layers):
99 | func = getattr(self, 'vgg_enc_{:d}'.format(i + 1))
100 | results.append(func(results[-1]))
101 | return results[1:]
102 |
103 | # extract relu3_1
104 | def encode_vgg_content(self, input):
105 | for i in range(3):
106 | input = getattr(self, 'vgg_enc_{:d}'.format(i + 1))(input)
107 | return input
108 |
109 | def calc_content_loss(self, input, target):
110 | assert (input.size() == target.size())
111 | return self.mse_loss(input, target)
112 |
113 | def efdm_single(self, style, trans):
114 | B, C, W, H = style.size(0), style.size(1), style.size(2), style.size(3)
115 |
116 | value_style, index_style = torch.sort(style.view(B, C, -1))
117 | value_trans, index_trans = torch.sort(trans.view(B, C, -1))
118 | inverse_index = index_trans.argsort(-1)
119 |
120 | return self.mse_loss(trans.view(B, C,-1), value_style.gather(-1, inverse_index))
121 |
122 | def perceptual_loss(self, content, style, trs_img):
123 |         # undo the ImageNet normalization so the images are back in [0, 1] before being passed to VGG
124 | content = content.permute(0, 2, 3, 1)
125 | style = style.permute(0, 2, 3, 1)
126 | trs_img = trs_img.permute(0, 2, 3, 1)
127 |
128 | content = content * torch.from_numpy(np.array((0.229, 0.224, 0.225))).to(content.device) + torch.from_numpy(np.array((0.485, 0.456, 0.406))).to(content.device)
129 | style = style * torch.from_numpy(np.array((0.229, 0.224, 0.225))).to(style.device) + torch.from_numpy(np.array((0.485, 0.456, 0.406))).to(style.device)
130 | trs_img = trs_img * torch.from_numpy(np.array((0.229, 0.224, 0.225))).to(trs_img.device) + torch.from_numpy(np.array((0.485, 0.456, 0.406))).to(trs_img.device)
131 |
132 | content = content.permute(0, 3, 1, 2).float()
133 | style = style.permute(0, 3, 1, 2).float()
134 | trs_img = trs_img.permute(0, 3, 1, 2).float()
135 |
136 | # loss
137 | content_feats_vgg = self.encode_vgg_content(content)
138 | style_feats_vgg = self.encode_with_vgg_intermediate(style)
139 | trs_feats_vgg = self.encode_with_vgg_intermediate(trs_img)
140 |
141 | loss_c = self.calc_content_loss(trs_feats_vgg[-2], content_feats_vgg)
142 | loss_s = self.efdm_single(trs_feats_vgg[0], style_feats_vgg[0])
143 | for i in range(1, self.n_layers):
144 | loss_s = loss_s + self.efdm_single(trs_feats_vgg[i], style_feats_vgg[i])
145 |
146 | loss = loss_c * self.config.lambda_perc_cont + loss_s * self.config.lambda_perc_style
147 |
148 | # EFDM negative pair
149 | neg_idx = []
150 | batch = content.shape[0]
151 | for a in range(batch):
152 | neg_lst = {}
153 | for b in range(batch): # for each image pair
154 | if a != b:
155 | loss_s_single = 0
156 | for i in range(0, self.n_layers): # for each vgg layer
157 | loss_s_single += self.efdm_single(trs_feats_vgg[i][a].unsqueeze(0), style_feats_vgg[i][b].unsqueeze(0))
158 | neg_lst[b] = loss_s_single
159 | neg_lst = sorted(neg_lst, key=neg_lst.get)
160 | # neg_idx.append(neg_lst[:3])
161 | neg_idx.append([neg_lst[0]])
162 |
163 | return loss, neg_idx
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from blocks import *
6 |
7 | def define_network(net_type, config = None):
8 | net = None
9 | alpha_in = config.alpha_in
10 | alpha_out = config.alpha_out
11 | sk = config.style_kernel
12 |
13 | if net_type == 'Encoder':
14 | net = Encoder(in_dim=config.input_nc, nf=config.nf, style_kernel=[sk, sk], alpha_in=alpha_in, alpha_out=alpha_out)
15 | elif net_type == 'Generator':
16 | net = Decoder(nf=config.nf, out_dim=config.output_nc, style_channel=256, style_kernel=[sk, sk, 3], alpha_in=alpha_in, freq_ratio=config.freq_ratio, alpha_out=alpha_out)
17 | return net
18 |
19 | class Encoder(nn.Module):
20 | def __init__(self, in_dim, nf=64, style_kernel=[3, 3], alpha_in=0.5, alpha_out=0.5):
21 | super(Encoder, self).__init__()
22 |
23 | self.conv = nn.Conv2d(in_channels=in_dim, out_channels=nf, kernel_size=7, stride=1, padding=3)
24 |
25 | self.OctConv1_1 = OctConv(in_channels=nf, out_channels=nf, kernel_size=3, stride=2, padding=1, groups=64, alpha_in=alpha_in, alpha_out=alpha_out, type="first")
26 | self.OctConv1_2 = OctConv(in_channels=nf, out_channels=2*nf, kernel_size=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal")
27 | self.OctConv1_3 = OctConv(in_channels=2*nf, out_channels=2*nf, kernel_size=3, stride=1, padding=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal")
28 |
29 | self.OctConv2_1 = OctConv(in_channels=2*nf, out_channels=2*nf, kernel_size=3, stride=2, padding=1, groups=128, alpha_in=alpha_in, alpha_out=alpha_out, type="normal")
30 | self.OctConv2_2 = OctConv(in_channels=2*nf, out_channels=4*nf, kernel_size=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal")
31 | self.OctConv2_3 = OctConv(in_channels=4*nf, out_channels=4*nf, kernel_size=3, stride=1, padding=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal")
32 |
33 | self.pool_h = nn.AdaptiveAvgPool2d((style_kernel[0], style_kernel[0]))
34 | self.pool_l = nn.AdaptiveAvgPool2d((style_kernel[1], style_kernel[1]))
35 |
36 | self.relu = Oct_conv_lreLU()
37 |
38 | def forward(self, x):
39 | enc_feat = []
40 | out = self.conv(x)
41 |
42 | out = self.OctConv1_1(out)
43 | out = self.relu(out)
44 | out = self.OctConv1_2(out)
45 | out = self.relu(out)
46 | out = self.OctConv1_3(out)
47 | out = self.relu(out)
48 | enc_feat.append(out)
49 |
50 | out = self.OctConv2_1(out)
51 | out = self.relu(out)
52 | out = self.OctConv2_2(out)
53 | out = self.relu(out)
54 | out = self.OctConv2_3(out)
55 | out = self.relu(out)
56 | enc_feat.append(out)
57 |
58 | out_high, out_low = out
59 | out_sty_h = self.pool_h(out_high)
60 | out_sty_l = self.pool_l(out_low)
61 | out_sty = out_sty_h, out_sty_l
62 |
63 | return out, out_sty, enc_feat
64 |
65 | def forward_test(self, x, cond):
66 | out = self.conv(x)
67 |
68 | out = self.OctConv1_1(out)
69 | out = self.relu(out)
70 | out = self.OctConv1_2(out)
71 | out = self.relu(out)
72 | out = self.OctConv1_3(out)
73 | out = self.relu(out)
74 |
75 | out = self.OctConv2_1(out)
76 | out = self.relu(out)
77 | out = self.OctConv2_2(out)
78 | out = self.relu(out)
79 | out = self.OctConv2_3(out)
80 | out = self.relu(out)
81 |
82 | if cond == 'style':
83 | out_high, out_low = out
84 | out_sty_h = self.pool_h(out_high)
85 | out_sty_l = self.pool_l(out_low)
86 | return out_sty_h, out_sty_l
87 | else:
88 | return out
89 |
90 | class Decoder(nn.Module):
91 | def __init__(self, nf=64, out_dim=3, style_channel=512, style_kernel=[3, 3, 3], alpha_in=0.5, alpha_out=0.5, freq_ratio=[1,1], pad_type='reflect'):
92 | super(Decoder, self).__init__()
93 |
94 | group_div = [1, 2, 4, 8]
95 | self.up_oct = Oct_conv_up(scale_factor=2)
96 |
97 | self.AdaOctConv1_1 = AdaOctConv(in_channels=4*nf, out_channels=4*nf, group_div=group_div[0], style_channels=style_channel, kernel_size=style_kernel, stride=1, padding=1, oct_groups=4*nf, alpha_in=alpha_in, alpha_out=alpha_out, type="normal")
98 | self.OctConv1_2 = OctConv(in_channels=4*nf, out_channels=2*nf, kernel_size=1, stride=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal")
99 | self.oct_conv_aftup_1 = Oct_Conv_aftup(in_channels=2*nf, out_channels=2*nf, kernel_size=3, stride=1, padding=1, pad_type=pad_type, alpha_in=alpha_in, alpha_out=alpha_out)
100 |
101 | self.AdaOctConv2_1 = AdaOctConv(in_channels=2*nf, out_channels=2*nf, group_div=group_div[1], style_channels=style_channel, kernel_size=style_kernel, stride=1, padding=1, oct_groups=2*nf, alpha_in=alpha_in, alpha_out=alpha_out, type="normal")
102 | self.OctConv2_2 = OctConv(in_channels=2*nf, out_channels=nf, kernel_size=1, stride=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal")
103 | self.oct_conv_aftup_2 = Oct_Conv_aftup(nf, nf, 3, 1, 1, pad_type, alpha_in, alpha_out)
104 |
105 | self.AdaOctConv3_1 = AdaOctConv(in_channels=nf, out_channels=nf, group_div=group_div[2], style_channels=style_channel, kernel_size=style_kernel, stride=1, padding=1, oct_groups=nf, alpha_in=alpha_in, alpha_out=alpha_out, type="normal")
106 | self.OctConv3_2 = OctConv(in_channels=nf, out_channels=nf//2, kernel_size=1, stride=1, alpha_in=alpha_in, alpha_out=alpha_out, type="last", freq_ratio=freq_ratio)
107 |
108 | self.conv4 = nn.Conv2d(in_channels=nf//2, out_channels=out_dim, kernel_size=1)
109 |
110 | def forward(self, content, style):
111 | out = self.AdaOctConv1_1(content, style)
112 | out = self.OctConv1_2(out)
113 | out = self.up_oct(out)
114 | out = self.oct_conv_aftup_1(out)
115 |
116 | out = self.AdaOctConv2_1(out, style)
117 | out = self.OctConv2_2(out)
118 | out = self.up_oct(out)
119 | out = self.oct_conv_aftup_2(out)
120 |
121 | out = self.AdaOctConv3_1(out, style)
122 | out = self.OctConv3_2(out)
123 | out, out_high, out_low = out
124 |
125 | out = self.conv4(out)
126 | out_high = self.conv4(out_high)
127 | out_low = self.conv4(out_low)
128 |
129 | return out, out_high, out_low
130 |
131 | def forward_test(self, content, style):
132 | out = self.AdaOctConv1_1(content, style, 'test')
133 | out = self.OctConv1_2(out)
134 | out = self.up_oct(out)
135 | out = self.oct_conv_aftup_1(out)
136 |
137 | out = self.AdaOctConv2_1(out, style, 'test')
138 | out = self.OctConv2_2(out)
139 | out = self.up_oct(out)
140 | out = self.oct_conv_aftup_2(out)
141 |
142 | out = self.AdaOctConv3_1(out, style, 'test')
143 | out = self.OctConv3_2(out)
144 |
145 | out = self.conv4(out[0])
146 | return out
147 |
148 |
149 | ############## Contrastive Loss function ##############
150 | def calc_mean_std(feat, eps=1e-5):
151 | # eps is a small value added to the variance to avoid divide-by-zero.
152 | size = feat.size()
153 | assert (len(size) == 4)
154 | N, C = size[:2]
155 | feat_var = feat.view(N, C, -1).var(dim=2) + eps
156 | feat_std = feat_var.sqrt().view(N, C, 1, 1)
157 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
158 | return feat_mean, feat_std
159 |
160 | def calc_content_loss(input, target):
161 | assert (input.size() == target.size())
162 | mse_loss = nn.MSELoss()
163 | return mse_loss(input, target)
164 |
165 | def calc_style_loss(input, target):
166 | assert (input.size() == target.size())
167 | mse_loss = nn.MSELoss()
168 | input_mean, input_std = calc_mean_std(input)
169 | target_mean, target_std = calc_mean_std(target)
170 |
171 | loss = mse_loss(input_mean, target_mean) + \
172 | mse_loss(input_std, target_std)
173 | return loss
174 |
175 | class EFDM_loss(nn.Module):
176 | def __init__(self):
177 | super(EFDM_loss, self).__init__()
178 | self.mse_loss = nn.MSELoss()
179 |
180 | def efdm_single(self, style, trans):
181 | B, C, W, H = style.size(0), style.size(1), style.size(2), style.size(3)
182 |
183 | value_style, index_style = torch.sort(style.view(B, C, -1))
184 | value_trans, index_trans = torch.sort(trans.view(B, C, -1))
185 | inverse_index = index_trans.argsort(-1)
186 |
187 | return self.mse_loss(trans.view(B, C,-1), value_style.gather(-1, inverse_index))
188 |
189 | def forward(self, style_E, style_S, translate_E, translate_S, neg_idx):
190 | loss = 0.
191 | batch = style_E[0][0].shape[0]
192 | for b in range(batch):
193 | poss_loss = 0.
194 | neg_loss = 0.
195 |
196 | # Positive loss
197 | for i in range(len(style_E)):
198 | poss_loss += self.efdm_single(style_E[i][0][b].unsqueeze(0), translate_E[i][0][b].unsqueeze(0)) + \
199 | self.efdm_single(style_E[i][1][b].unsqueeze(0), translate_E[i][1][b].unsqueeze(0))
200 | for i in range(len(style_S)):
201 | poss_loss += self.efdm_single(style_S[i][0][b].unsqueeze(0), translate_S[i][0][b].unsqueeze(0)) + \
202 | self.efdm_single(style_S[i][1][b].unsqueeze(0), translate_S[i][1][b].unsqueeze(0))
203 |
204 | # Negative loss
205 | for nb in neg_idx[b]:
206 | for i in range(len(style_E)):
207 | neg_loss += self.efdm_single(style_E[i][0][nb].unsqueeze(0), translate_E[i][0][b].unsqueeze(0)) + \
208 | self.efdm_single(style_E[i][1][nb].unsqueeze(0), translate_E[i][1][b].unsqueeze(0))
209 | for i in range(len(style_S)):
210 | neg_loss += self.efdm_single(style_S[i][0][nb].unsqueeze(0), translate_S[i][0][b].unsqueeze(0)) + \
211 | self.efdm_single(style_S[i][1][nb].unsqueeze(0), translate_S[i][1][b].unsqueeze(0))
212 |
213 | loss += poss_loss / neg_loss
214 |
215 | return loss
--------------------------------------------------------------------------------
/blocks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from path import Path
4 | import math
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from torch.optim import lr_scheduler
9 | 
10 |
11 | def model_save(ckpt_dir, model, optim_E, optim_S, optim_G, epoch, itr=None):
12 | if not os.path.exists(ckpt_dir):
13 | os.makedirs(ckpt_dir)
14 |
15 | torch.save({'netE': model.netE.state_dict(),
16 | 'netS': model.netS.state_dict(),
17 | 'netG': model.netG.state_dict(),
18 | 'optim_E': optim_E.state_dict(),
19 | 'optim_S': optim_S.state_dict(),
20 | 'optim_G': optim_G.state_dict()},
21 | '%s/model_iter_%d_epoch_%d.pth' % (ckpt_dir, itr+1, epoch+1))
22 |
23 | def model_load(checkpoint, ckpt_dir, model, optim_E, optim_S, optim_G):
24 |     if not os.path.exists(ckpt_dir):  # nothing to resume from: also return a starting iteration count
25 |         epoch = itr = 0
26 |         return model, optim_E, optim_S, optim_G, epoch, itr
27 |
28 | ckpt_path = Path(ckpt_dir)
29 | if checkpoint:
30 | model_ckpt = ckpt_path + '/' + checkpoint
31 | else:
32 | ckpt_lst = ckpt_path.glob('model_iter_*')
33 | ckpt_lst.sort(key=lambda x: int(x.split('iter_')[1].split('_epoch')[0]))
34 | model_ckpt = ckpt_lst[-1]
35 | itr = int(model_ckpt.split('iter_')[1].split('_epoch_')[0])
36 | epoch = int(model_ckpt.split('iter_')[1].split('_epoch_')[1].split('.')[0])
37 | print(model_ckpt)
38 |
39 | dict_model = torch.load(model_ckpt)
40 |
41 | model.netE.load_state_dict(dict_model['netE'])
42 | model.netS.load_state_dict(dict_model['netS'])
43 | model.netG.load_state_dict(dict_model['netG'])
44 | optim_E.load_state_dict(dict_model['optim_E'])
45 | optim_S.load_state_dict(dict_model['optim_S'])
46 | optim_G.load_state_dict(dict_model['optim_G'])
47 |
48 | return model, optim_E, optim_S, optim_G, epoch, itr
49 |
50 | def test_model_load(checkpoint, model):
51 | dict_model = torch.load(checkpoint)
52 | model.netE.load_state_dict(dict_model['netE'])
53 | model.netS.load_state_dict(dict_model['netS'])
54 | model.netG.load_state_dict(dict_model['netG'])
55 | return model
56 |
57 | def get_scheduler(optimizer, config):
58 | if config.lr_policy == 'lambda':
59 | def lambda_rule(epoch):
60 | lr_l = 1.0 - max(0, epoch + config.n_epoch - config.n_iter) / float(config.n_iter_decay + 1)
61 | return lr_l
62 |
63 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
64 | elif config.lr_policy == 'step':
65 | scheduler = lr_scheduler.StepLR(optimizer, step_size=config.lr_decay_iters, gamma=0.1)
66 | elif config.lr_policy == 'plateau':
67 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
68 | elif config.lr_policy == 'cosine':
69 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.n_iter, eta_min=0)
70 | else:
71 | return NotImplementedError('learning rate policy [%s] is not implemented', config.lr_policy)
72 | return scheduler
73 |
74 | def update_learning_rate(scheduler, optimizer):
75 | scheduler.step()
76 | lr = optimizer.param_groups[0]['lr']
77 | print('learning rate = %.7f' % lr)
78 |
79 | class Oct_Conv_aftup(nn.Module):
80 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, pad_type, alpha_in, alpha_out):
81 | super(Oct_Conv_aftup, self).__init__()
82 | lf_in = int(in_channels*alpha_in)
83 | lf_out = int(out_channels*alpha_out)
84 | hf_in = in_channels - lf_in
85 | hf_out = out_channels - lf_out
86 |
87 | self.conv_h = nn.Conv2d(in_channels=hf_in, out_channels=hf_out, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type)
88 | self.conv_l = nn.Conv2d(in_channels=lf_in, out_channels=lf_out, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type)
89 |
90 | def forward(self, x):
91 | hf, lf = x
92 | hf = self.conv_h(hf)
93 | lf = self.conv_l(lf)
94 | return hf, lf
95 |
96 | class Oct_conv_reLU(nn.ReLU):
97 | def forward(self, x):
98 | hf, lf = x
99 | hf = super(Oct_conv_reLU, self).forward(hf)
100 | lf = super(Oct_conv_reLU, self).forward(lf)
101 | return hf, lf
102 |
103 | class Oct_conv_lreLU(nn.LeakyReLU):
104 | def forward(self, x):
105 | hf, lf = x
106 | hf = super(Oct_conv_lreLU, self).forward(hf)
107 | lf = super(Oct_conv_lreLU, self).forward(lf)
108 | return hf, lf
109 |
110 | class Oct_conv_up(nn.Upsample):
111 | def forward(self, x):
112 | hf, lf = x
113 | hf = super(Oct_conv_up, self).forward(hf)
114 | lf = super(Oct_conv_up, self).forward(lf)
115 | return hf, lf
116 |
117 |
118 | ############## Encoder ##############
119 | class OctConv(nn.Module):
120 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
121 | padding=0, groups=1, pad_type='reflect', alpha_in=0.5, alpha_out=0.5, type='normal', freq_ratio = [1, 1]):
122 | super(OctConv, self).__init__()
123 | self.kernel_size = kernel_size
124 | self.stride = stride
125 | self.type = type
126 | self.alpha_in = alpha_in
127 | self.alpha_out = alpha_out
128 | self.freq_ratio = freq_ratio
129 |
130 | hf_ch_in = int(in_channels * (1 - self.alpha_in))
131 | hf_ch_out = int(out_channels * (1 -self. alpha_out))
132 | lf_ch_in = in_channels - hf_ch_in
133 | lf_ch_out = out_channels - hf_ch_out
134 |
135 | self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
136 | self.upsample = nn.Upsample(scale_factor=2)
137 |
138 | self.is_dw = groups == in_channels
139 |
140 | if type == 'first':
141 | self.convh = nn.Conv2d(in_channels, hf_ch_out, kernel_size=kernel_size,
142 | stride=stride, padding=padding, padding_mode=pad_type, bias = False)
143 | self.convl = nn.Conv2d(in_channels, lf_ch_out,
144 | kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type, bias=False)
145 | elif type == 'last':
146 | self.convh = nn.Conv2d(hf_ch_in, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type, bias=False)
147 | self.convl = nn.Conv2d(lf_ch_in, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type, bias=False)
148 | else:
149 | self.L2L = nn.Conv2d(
150 | lf_ch_in, lf_ch_out,
151 | kernel_size=kernel_size, stride=stride, padding=padding, groups=math.ceil(alpha_in * groups), padding_mode=pad_type, bias=False
152 | )
153 | if self.is_dw:
154 | self.L2H = None
155 | self.H2L = None
156 | else:
157 | self.L2H = nn.Conv2d(
158 | lf_ch_in, hf_ch_out,
159 | kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, padding_mode=pad_type, bias=False
160 | )
161 | self.H2L = nn.Conv2d(
162 | hf_ch_in, lf_ch_out,
163 | kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, padding_mode=pad_type, bias=False
164 | )
165 | self.H2H = nn.Conv2d(
166 | hf_ch_in, hf_ch_out,
167 | kernel_size=kernel_size, stride=stride, padding=padding, groups=math.ceil(groups - alpha_in * groups), padding_mode=pad_type, bias=False
168 | )
169 |
170 | def forward(self, x):
171 | if self.type == 'first':
172 | hf = self.convh(x)
173 | lf = self.avg_pool(x)
174 | lf = self.convl(lf)
175 | return hf, lf
176 | elif self.type == 'last':
177 | hf, lf = x
178 | out_h = self.convh(hf)
179 | out_l = self.convl(self.upsample(lf))
180 | output = out_h * self.freq_ratio[0] + out_l * self.freq_ratio[1]
181 | return output, out_h, out_l
182 | else:
183 | hf, lf = x
184 | if self.is_dw:
185 | hf, lf = self.H2H(hf), self.L2L(lf)
186 | else:
187 | hf, lf = self.H2H(hf) + self.L2H(self.upsample(lf)), self.L2L(lf) + self.H2L(self.avg_pool(hf))
188 | return hf, lf
189 |
190 |
191 | ############## Decoder ##############
192 | class AdaOctConv(nn.Module):
193 | def __init__(self, in_channels, out_channels, group_div, style_channels, kernel_size,
194 | stride, padding, oct_groups, alpha_in, alpha_out, type='normal'):
195 | super(AdaOctConv, self).__init__()
196 | self.in_channels = in_channels
197 | self.alpha_in = alpha_in
198 | self.alpha_out = alpha_out
199 | self.type = type
200 |
201 | h_in = int(in_channels * (1 - self.alpha_in))
202 | l_in = in_channels - h_in
203 |
204 | n_groups_h = h_in // group_div
205 | n_groups_l = l_in // group_div
206 |
207 | style_channels_h = int(style_channels * (1 - self.alpha_in))
208 | style_channels_l = int(style_channels - style_channels_h)
209 |
210 | kernel_size_h = kernel_size[0]
211 | kernel_size_l = kernel_size[1]
212 | kernel_size_A = kernel_size[2]
213 |
214 | self.kernelPredictor_h = KernelPredictor(in_channels=h_in,
215 | out_channels=h_in,
216 | n_groups=n_groups_h,
217 | style_channels=style_channels_h,
218 | kernel_size=kernel_size_h)
219 | self.kernelPredictor_l = KernelPredictor(in_channels=l_in,
220 | out_channels=l_in,
221 | n_groups=n_groups_l,
222 | style_channels=style_channels_l,
223 | kernel_size=kernel_size_l)
224 |
225 | self.AdaConv_h = AdaConv2d(in_channels=h_in, out_channels=h_in, n_groups=n_groups_h)
226 | self.AdaConv_l = AdaConv2d(in_channels=l_in, out_channels=l_in, n_groups=n_groups_l)
227 |
228 | self.OctConv = OctConv(in_channels=in_channels,
229 | out_channels=out_channels,
230 | kernel_size=kernel_size_A, stride=stride, padding=padding, groups=oct_groups,
231 | alpha_in=alpha_in, alpha_out=alpha_out, type=type)
232 |
233 | self.relu = Oct_conv_lreLU()
234 |
235 | def forward(self, content, style, cond='train'):
236 | c_hf, c_lf = content
237 | s_hf, s_lf = style
238 | h_w_spatial, h_w_pointwise, h_bias = self.kernelPredictor_h(s_hf)
239 | l_w_spatial, l_w_pointwise, l_bias = self.kernelPredictor_l(s_lf)
240 |
241 |         # 'train' and 'test' share the same path: apply the style-predicted
242 |         # spatial/pointwise kernels to each band, then mix the bands with OctConv.
243 |         output_h = self.AdaConv_h(c_hf, h_w_spatial, h_w_pointwise, h_bias)
244 |         output_l = self.AdaConv_l(c_lf, l_w_spatial, l_w_pointwise, l_bias)
245 |         output = output_h, output_l
246 |         output = self.relu(output)
247 | 
248 |         output = self.OctConv(output)
249 |         if self.type != 'last':
250 |             output = self.relu(output)
251 |         return output
262 |
263 | class KernelPredictor(nn.Module):
264 | def __init__(self, in_channels, out_channels, n_groups, style_channels, kernel_size):
265 | super(KernelPredictor, self).__init__()
266 |
267 | self.in_channels = in_channels
268 | self.out_channels = out_channels
269 | self.n_groups = n_groups
270 | self.w_channels = style_channels
271 | self.kernel_size = kernel_size
272 |
273 | padding = (kernel_size - 1) / 2
274 | self.spatial = nn.Conv2d(style_channels,
275 | in_channels * out_channels // n_groups,
276 | kernel_size=kernel_size,
277 | padding=(math.ceil(padding), math.ceil(padding)),
278 | padding_mode='reflect')
279 | self.pointwise = nn.Sequential(
280 | nn.AdaptiveAvgPool2d((1, 1)),
281 | nn.Conv2d(style_channels,
282 | out_channels * out_channels // n_groups,
283 | kernel_size=1)
284 | )
285 | self.bias = nn.Sequential(
286 | nn.AdaptiveAvgPool2d((1, 1)),
287 | nn.Conv2d(style_channels,
288 | out_channels,
289 | kernel_size=1)
290 | )
291 |
292 | def forward(self, w):
293 | w_spatial = self.spatial(w)
294 | w_spatial = w_spatial.reshape(len(w),
295 | self.out_channels,
296 | self.in_channels // self.n_groups,
297 | self.kernel_size, self.kernel_size)
298 |
299 | w_pointwise = self.pointwise(w)
300 | w_pointwise = w_pointwise.reshape(len(w),
301 | self.out_channels,
302 | self.out_channels // self.n_groups,
303 | 1, 1)
304 | bias = self.bias(w)
305 | bias = bias.reshape(len(w), self.out_channels)
306 | return w_spatial, w_pointwise, bias
307 |
308 | class AdaConv2d(nn.Module):
309 | def __init__(self, in_channels, out_channels, kernel_size=3, n_groups=None):
310 | super(AdaConv2d, self).__init__()
311 | self.n_groups = in_channels if n_groups is None else n_groups
312 | self.in_channels = in_channels
313 | self.out_channels = out_channels
314 |
315 | padding = (kernel_size - 1) / 2
316 | self.conv = nn.Conv2d(in_channels=in_channels,
317 | out_channels=out_channels,
318 | kernel_size=(kernel_size, kernel_size),
319 | padding=(math.ceil(padding), math.floor(padding)),
320 | padding_mode='reflect')
321 |
322 | def forward(self, x, w_spatial, w_pointwise, bias):
323 | assert len(x) == len(w_spatial) == len(w_pointwise) == len(bias)
324 | x = F.instance_norm(x)
325 |
326 | ys = []
327 | for i in range(len(x)):
328 | y = self.forward_single(x[i:i+1], w_spatial[i], w_pointwise[i], bias[i])
329 | ys.append(y)
330 | ys = torch.cat(ys, dim=0)
331 |
332 | ys = self.conv(ys)
333 | return ys
334 |
335 | def forward_single(self, x, w_spatial, w_pointwise, bias):
336 | assert w_spatial.size(-1) == w_spatial.size(-2)
337 | padding = (w_spatial.size(-1) - 1) / 2
338 | pad = (math.ceil(padding), math.floor(padding), math.ceil(padding), math.floor(padding))
339 |
340 | x = F.pad(x, pad=pad, mode='reflect')
341 | x = F.conv2d(x, w_spatial, groups=self.n_groups)
342 | x = F.conv2d(x, w_pointwise, groups=self.n_groups, bias=bias)
343 | return x
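344 | 
345 | 
346 | # Minimal shape-check sketch for the blocks above (illustrative only; not used
347 | # by the training/testing scripts). The channel counts, kernel sizes, and the
348 | # 3x3 style descriptors are assumptions: KernelPredictor expects style
349 | # features already reduced to kernel_size x kernel_size.
350 | if __name__ == '__main__':
351 |     content = (torch.randn(1, 32, 64, 64), torch.randn(1, 32, 32, 32))  # (high, low) bands
352 |     style = (torch.randn(1, 32, 3, 3), torch.randn(1, 32, 3, 3))        # kernel-sized style codes
353 |     block = AdaOctConv(in_channels=64, out_channels=64, group_div=1, style_channels=64,
354 |                        kernel_size=[3, 3, 3], stride=1, padding=1, oct_groups=64,
355 |                        alpha_in=0.5, alpha_out=0.5, type='normal')
356 |     hf, lf = block(content, style)
357 |     print(hf.shape, lf.shape)  # torch.Size([1, 32, 64, 64]) torch.Size([1, 32, 32, 32])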
--------------------------------------------------------------------------------
/Video_NST.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "9a6eab3c-20b0-414e-91b5-4bab2b625f02",
6 | "metadata": {},
7 | "source": [
8 | "# Video to Image\n",
9 | "Before style transfer using AesFA"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "id": "ad2a36df-e065-4739-abb8-3f8bc9755ddb",
16 | "metadata": {},
17 | "outputs": [
18 | {
19 | "name": "stdout",
20 | "output_type": "stream",
21 | "text": [
22 | "pexels-kelly-7165765-3840x2160-24fps\n",
23 | "319 24.0\n"
24 | ]
25 | }
26 | ],
27 | "source": [
28 | "import os\n",
29 | "import cv2\n",
30 | "import numpy as np\n",
31 | "import matplotlib.pyplot as plt\n",
32 | "from PIL import Image\n",
33 | "import glob\n",
34 | "\n",
35 | "resolution = '2048'\n",
36 | "content_dir = './video/{}/'.format(resolution)\n",
37 | "contents_in = sorted(glob.glob(content_dir+'*.mp4'))\n",
38 | "content_img = './video_img/{}/content/'.format(resolution)\n",
39 | "if not os.path.isdir(content_img):\n",
40 | " os.makedirs(content_img)\n",
41 | "\n",
42 | "contents = []\n",
43 | "for content in contents_in:\n",
44 | " content_name = content.split('/')[-1].split('.')[0]\n",
45 | " contents.append(content_name)\n",
46 | " print(content_name)\n",
47 | " \n",
48 | " # Video Capture Sanity Check\n",
49 | " cap = cv2.VideoCapture(content)\n",
50 | " length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
51 | " fps = cap.get(cv2.CAP_PROP_FPS)\n",
52 | " print(length, fps)\n",
53 | "\n",
54 | " i = 0\n",
55 | "    # Dump every frame of the clip to a numbered JPEG\n",
56 | "    while cap.isOpened():\n",
57 | "        ret, frame = cap.read()\n",
58 | "        if not ret:\n",
59 | "            break\n",
60 | "\n",
61 | "        frame = frame[:, :, ::-1]  # OpenCV reads BGR; flip to RGB for PIL\n",
62 | "        iframe = Image.fromarray(frame)\n",
63 | "        iframe.save(content_img+'{:s}_{:d}.jpg'.format(content_name, i+1))\n",
64 | "        i += 1\n",
65 | "\n",
66 | " cap.release()"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "id": "53bb10ab-08f2-4a2e-b8e6-4caa5ab1e065",
72 | "metadata": {},
73 | "source": [
74 | "# Transfer the style of each frame using the AesFA model"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 2,
80 | "id": "ab7255c1-1256-4c2b-a29d-a9438ebd4c74",
81 | "metadata": {
82 | "scrolled": true,
83 | "tags": []
84 | },
85 | "outputs": [
86 | {
87 | "name": "stdout",
88 | "output_type": "stream",
89 | "text": [
90 | "Version: main\n",
91 | "cuda:0\n",
92 | "checkpoint: ./ckpt/main/main.pth\n",
93 | "content dir: ./video_img/2048/content/\n",
94 | "style dir: /global/cfs/cdirs/m3898/GANBERT/st_to_diff/Code/LabServer/ResolutionDataset/2048/style\n",
95 | "# of contents: 319\n",
96 | "# of styles: 2\n",
735 | "[Ours] Total images: 638 Avg Testing time: 0.015323991312128624\n"
736 | ]
737 | }
738 | ],
739 | "source": [
740 | "# Before running, edit Config.py for this video test, e.g.:\n",
741 | "#   content_dir = './video_img/2048/content/'\n",
742 | "#   img_dir = './output/video_output/2048'\n",
743 | "import test_video\n",
744 | "test_video.main()"
745 | ]
746 | },
747 | {
748 | "cell_type": "markdown",
749 | "id": "9fa62836-3cc3-49f0-ae9f-d8c007c1483a",
750 | "metadata": {},
751 | "source": [
752 | "# Image to Video - Content Video\n",
753 | "After style transfer"
754 | ]
755 | },
756 | {
757 | "cell_type": "code",
758 | "execution_count": 3,
759 | "id": "147d187f-1f05-4f74-808b-9a8d3dd2e284",
760 | "metadata": {},
761 | "outputs": [
762 | {
763 | "name": "stdout",
764 | "output_type": "stream",
765 | "text": [
766 | "319 24.0\n"
767 | ]
768 | },
769 | {
770 | "name": "stderr",
771 | "output_type": "stream",
772 | "text": [
773 | "OpenCV: FFMPEG: tag 0x5634504d/'MP4V' is not supported with codec id 12 and format 'mp4 / MP4 (MPEG-4 Part 14)'\n",
774 | "OpenCV: FFMPEG: fallback to use tag 0x7634706d/'mp4v'\n"
775 | ]
776 | }
777 | ],
778 | "source": [
779 | "import cv2\n",
780 | "import numpy as np\n",
781 | "import glob\n",
782 | "from Config import Config\n",
783 | "\n",
784 | "config = Config()\n",
785 | "out_dir = config.img_dir\n",
786 | "\n",
787 | "for content_n in contents:\n",
788 | " length = len(glob.glob(out_dir+'/content/'+content_n+'*_content*'))\n",
789 | " cap = cv2.VideoCapture(content_dir+content_n+'.mp4')\n",
790 | " fps = cap.get(cv2.CAP_PROP_FPS)\n",
791 | " print(length, fps)\n",
792 | " \n",
793 | " img_array = []\n",
794 | " for idx in range(length):\n",
795 | " filename = out_dir+'/content/'+content_n+'_'+str(idx+1)+'_content.jpg'\n",
796 | " img = cv2.imread(filename)\n",
797 | " img_array.append(img)\n",
798 | " height, width, layers = img.shape\n",
799 | " size = (width,height)\n",
800 | "\n",
801 | "    out = cv2.VideoWriter(out_dir+'/CONTENT_'+content_n+'.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, size)  # lowercase 'mp4v' matches the mp4 container\n",
802 | "\n",
803 | " for i in range(len(img_array)):\n",
804 | " out.write(img_array[i])\n",
805 | " out.release()"
806 | ]
807 | },
808 | {
809 | "cell_type": "markdown",
810 | "id": "0f777fe8-b10a-4b2a-94cc-ef6f8e8dabb1",
811 | "metadata": {},
812 | "source": [
813 | "# Image to Video - Style Transferred Video"
814 | ]
815 | },
816 | {
817 | "cell_type": "code",
818 | "execution_count": 4,
819 | "id": "2fe49999-ffee-45b7-b0a4-4a008739046f",
820 | "metadata": {},
821 | "outputs": [
822 | {
823 | "name": "stdout",
824 | "output_type": "stream",
825 | "text": [
826 | "pexels-kelly-7165765-3840x2160-24fps 319 24.0\n"
827 | ]
828 | }
829 | ],
830 | "source": [
831 | "style_imgs = sorted(glob.glob('./video_img/'+resolution+'/style/*.jpg'))\n",
832 | "\n",
833 | "for content in contents:\n",
834 | " length = len(glob.glob(out_dir+'/content/'+content+'*_content*'))\n",
835 | " cap = cv2.VideoCapture(content_dir+content+'.mp4')\n",
836 | " fps = cap.get(cv2.CAP_PROP_FPS)\n",
837 | " print(content, length, fps)\n",
838 | " \n",
839 | " for style_dir in style_imgs:\n",
840 | " style = style_dir.split('/')[-1].split('.')[0]\n",
841 | " img_array = []\n",
842 | "\n",
843 | " for idx in range(length):\n",
844 | " filename = out_dir+'/stylized/'+content+'_'+str(idx+1)+'_stylized_'+style+'.jpg'\n",
845 | " img = cv2.imread(filename)\n",
846 | " img_array.append(img)\n",
847 | " height, width, layers = img.shape\n",
848 | " size = (width,height)\n",
849 | " print(\"-\", style, len(img_array))\n",
850 | "\n",
851 | "        out = cv2.VideoWriter(out_dir+'/STYLE_'+content+'_'+style+'.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, size)\n",
852 | "\n",
853 | " for i in range(len(img_array)):\n",
854 | " out.write(img_array[i])\n",
855 | " out.release()"
856 | ]
857 | },
858 | {
859 | "cell_type": "code",
860 | "execution_count": null,
861 | "id": "db973afc-7937-4216-b80a-a2df475a15f1",
862 | "metadata": {},
863 | "outputs": [],
864 | "source": []
865 | }
866 | ],
867 | "metadata": {
868 | "kernelspec": {
869 | "display_name": "pytorch-1.13.1",
870 | "language": "python",
871 | "name": "pytorch-1.13.1"
872 | },
873 | "language_info": {
874 | "codemirror_mode": {
875 | "name": "ipython",
876 | "version": 3
877 | },
878 | "file_extension": ".py",
879 | "mimetype": "text/x-python",
880 | "name": "python",
881 | "nbconvert_exporter": "python",
882 | "pygments_lexer": "ipython3",
883 | "version": "3.9.15"
884 | }
885 | },
886 | "nbformat": 4,
887 | "nbformat_minor": 5
888 | }
889 |
--------------------------------------------------------------------------------