├── LICENSE ├── Config.py ├── DataSplit.py ├── readme.md ├── style_blending.py ├── train.py ├── model.py ├── test_video.py ├── test.py ├── vgg19.py ├── networks.py ├── blocks.py └── Video_NST.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sooyoung Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /Config.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | phase = 'train' # You must change the phase into train/test/style_blending 3 | train_continue = 'off' # on / off 4 | 5 | data_num = 60000 # Maximum # of training data 6 | 7 | content_dir = './COCO' 8 | style_dir = './WikiArt' 9 | 10 | file_n = 'main' 11 | log_dir = './log/' + file_n 12 | ckpt_dir = './ckpt/' + file_n 13 | img_dir = './Generated_images/' + file_n 14 | 15 | if phase == 'test': 16 | multi_to_multi = True 17 | test_content_size = 256 18 | test_style_size = 256 19 | content_dir = './testDataset' 20 | style_dir = './testDataset' 21 | img_dir = './output/'+file_n+'/'+str(test_content_size) 22 | 23 | elif phase == 'style_blending': 24 | blend_load_size = 256 25 | blend_dir = './blendingDataset/' 26 | content_img = blend_dir + str(blend_load_size) + '/A.jpg' 27 | style_high_img = blend_dir + str(blend_load_size) + '/B.jpg' 28 | style_low_img = blend_dir + str(blend_load_size) + '/C.jpg' 29 | img_dir = './output/'+file_n+'_blending_' + str(blend_load_size) 30 | 31 | # VGG pre-trained model 32 | vgg_model = './vgg_normalised.pth' 33 | 34 | ## basic parameters 35 | n_iter = 160000 36 | batch_size = 8 37 | lr = 0.0001 38 | lr_policy = 'step' 39 | lr_decay_iters = 50 40 | beta1 = 0.0 41 | 42 | # preprocess parameters 43 | load_size = 512 44 | crop_size = 256 45 | 46 | # model parameters 47 | input_nc = 3 # of input image channel 48 | nf = 64 # of feature map channel after Encoder first layer 49 | output_nc = 3 # of output image channel 50 | style_kernel = 3 # size of style kernel 51 | 52 | # Octave Convolution parameters 53 | alpha_in = 0.5 # input ratio of low-frequency channel 54 | alpha_out = 0.5 # output ratio of low-frequency channel 55 | freq_ratio = [1, 1] # [high, low] ratio at the last layer 56 | 57 | # Loss ratio 58 | lambda_percept = 1.0 59 | lambda_perc_cont = 1.0 60 | lambda_perc_style = 10.0 61 | lambda_const_style = 5.0 62 | 63 | # Else 64 | norm = 
'instance' 65 | init_type = 'normal' 66 | init_gain = 0.02 67 | no_dropout = 'store_true' 68 | num_workers = 4 69 | -------------------------------------------------------------------------------- /DataSplit.py: -------------------------------------------------------------------------------- 1 | from path import Path 2 | import glob 3 | # import torch 4 | import torch.nn as nn 5 | # import pandas as pd 6 | # import numpy as np 7 | from PIL import Image 8 | from torchvision.transforms import ToTensor, Compose, Resize, Normalize, RandomCrop 9 | import random 10 | 11 | Image.MAX_IMAGE_PIXELS = 1000000000 12 | 13 | class DataSplit(nn.Module): 14 | def __init__(self, config, phase='train'): 15 | super(DataSplit, self).__init__() 16 | 17 | self.transform = Compose([Resize(size=[config.load_size, config.load_size]), 18 | RandomCrop(size=(config.crop_size, config.crop_size)), 19 | ToTensor(), 20 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 21 | 22 | if phase == 'train': 23 | # Content image data 24 | img_dir = Path(config.content_dir+'/train') 25 | self.images = self.get_data(img_dir) 26 | if config.data_num < len(self.images): 27 | self.images = random.sample(self.images, config.data_num) 28 | 29 | # Style image data 30 | sty_dir = Path(config.style_dir+'/train') 31 | self.style_images = self.get_data(sty_dir) 32 | if len(self.images) < len(self.style_images): 33 | self.style_images = random.sample(self.style_images, len(self.images)) 34 | elif len(self.images) > len(self.style_images): 35 | ratio = len(self.images) // len(self.style_images) 36 | bias = len(self.images) - ratio * len(self.style_images) 37 | self.style_images = self.style_images * ratio 38 | self.style_images += random.sample(self.style_images, bias) 39 | assert len(self.images) == len(self.style_images) 40 | 41 | elif phase == 'test': 42 | img_dir = Path(config.content_dir) 43 | self.images = self.get_data(img_dir)[:config.data_num] 44 | 45 | sty_dir = Path(config.style_dir) 46 | self.style_images = self.get_data(sty_dir)[:config.data_num] 47 | 48 | print('content dir:', img_dir) 49 | print('style dir:', sty_dir) 50 | 51 | def __len__(self): 52 | return len(self.images) 53 | 54 | def get_data(self, img_dir): 55 | file_type = ['*.jpg', '*.png', '*.jpeg', '*.tif'] 56 | imgs = [] 57 | for ft in file_type: 58 | imgs += sorted(img_dir.glob(ft)) 59 | images = sorted(imgs) 60 | return images 61 | 62 | def __getitem__(self, index): 63 | cont_img = self.images[index] 64 | cont_img = Image.open(cont_img).convert('RGB') 65 | cont_img = self.transform(cont_img) 66 | 67 | sty_img = self.style_images[index] 68 | sty_img = Image.open(sty_img).convert('RGB') 69 | sty_img = self.transform(sty_img) 70 | 71 | return {'content_img': cont_img, 'style_img': sty_img} -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # AesFA: An Aesthetic Feature-Aware Arbitrary Neural Style Transfer (AAAI 2024) 2 | Official Pytorch code for "AesFA: An Aesthetic Feature-Aware Arbitrary Neural Style Transfer"
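For reference, Config.py and DataSplit.py above are typically consumed through a standard PyTorch DataLoader; the minimal sketch below mirrors the wiring that train.py (included later in this repository) uses:
```
from torch.utils.data import DataLoader, RandomSampler

from Config import Config
from DataSplit import DataSplit

config = Config()                                    # Config.phase should be 'train' here
train_data = DataSplit(config=config, phase='train')
loader = DataLoader(train_data,
                    batch_size=config.batch_size,
                    num_workers=config.num_workers,
                    sampler=RandomSampler(train_data))
batch = next(iter(loader))                           # {'content_img': ..., 'style_img': ...}
```
Each batch is a dict of normalized content and style tensors of shape (batch_size, 3, crop_size, crop_size).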
3 | 4 | - Project page: [The Official Website for AesFA.](https://aesfa-nst.github.io/AesFA/) 5 | - arXiv preprint: 6 | 7 | First co-authors 8 | - Joonwoo Kwon (joonkwon96@gmail.com, **pioneers@snu.ac.kr**)
9 | - Sooyoung Kim (sooyyoungg513@gmail.com, **rlatndud0513@snu.ac.kr**)
10 | If one of us doesn't reply, please contact the other :) 11 | 12 | ## Introduction 13 | ![Figure1](https://github.com/Sooyyoungg/AesFA/assets/43199011/e9eca171-3bc6-49fc-9677-75020c2d596d) 14 | ![fig_eiffel](https://github.com/Sooyyoungg/AesFA/assets/43199011/d50e5142-1af3-4f3b-aeb7-2430c2aa7446) 15 | Neural style transfer (NST) has evolved significantly in recent years. Yet, despite this rapid progress, existing NST methods either struggle to transfer aesthetic information from a style image effectively or suffer from high computational costs and inefficient feature disentanglement because they rely on pre-trained models. This work proposes a lightweight but effective model, AesFA (Aesthetic Feature-Aware NST). The primary idea is to decompose the image by frequency to better disentangle aesthetic styles from the reference image, while training the entire model end to end so that no pre-trained model is needed at inference. To improve the network's ability to extract more distinct representations and further enhance stylization quality, this work introduces a new aesthetic feature contrastive loss. Extensive experiments and ablations show that the approach not only outperforms recent NST methods in stylization quality but also achieves faster inference. 16 | 17 | 18 | ## Environment: 19 | - python 3.7 20 | - pytorch 1.13.1 21 | 22 | ## Getting Started: 23 | **Clone this repo:** 24 | ``` 25 | git clone https://github.com/Sooyyoungg/AesFA 26 | cd AesFA 27 | ``` 28 | 29 | **Train:** 30 | - Download the [MS-COCO](https://cocodataset.org/#download) dataset for content images and [WikiArt](https://www.kaggle.com/c/painter-by-numbers) for style images. 31 | - Download the pre-trained [vgg_normalised.pth](https://github.com/naoto0804/pytorch-AdaIN/releases/tag/v0.0.0). 32 | - Change the training options in the Config.py file. 33 | - 'phase' must be 'train'. 34 | - Set 'train_continue' to 'on' to continue training from a previously saved model. 35 | ```python train.py``` 36 | 37 | **Test:** 38 | - Download the pre-trained AesFA model [main.pth](https://drive.google.com/file/d/1Y3OutPAsmPmJcnZs07ZVbDFf6nn3RzxR/view?usp=drive_link). 39 | - Change the test options in the Config.py file. 40 | - Set 'phase' to 'test' and adjust the other options, e.g. the data info (number of images, directories) and the image load and crop sizes. 41 | - If your content and style images have different sizes, you can set test_content_size and test_style_size to different values. 42 | - You can also choose whether to stylize with multi_to_multi (every content image with every style image) or to translate each content image with its corresponding style image only; a sample test configuration is sketched below.
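For example, the test-related fields in Config.py might be set as in the sketch below; it only uses fields already defined in Config.py, and the directories should point at your own test images:
```
phase = 'test'
multi_to_multi = True        # stylize every content image with every style image
test_content_size = 256      # content and style images may use different sizes
test_style_size = 256
content_dir = './testDataset'
style_dir = './testDataset'
```
With multi_to_multi set to False, each content image is instead paired only with the style image at the same index.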
43 | ```python test.py``` 44 | 45 | -------------------------------------------------------------------------------- /style_blending.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from torchvision.transforms import ToTensor, Compose, Resize, CenterCrop, Normalize, RandomCrop 6 | 7 | from Config import Config 8 | from model import AesFA_test 9 | from blocks import test_model_load 10 | 11 | def im_convert(tensor): 12 | image = tensor.to("cpu").clone().detach().numpy() 13 | image = image.transpose(0, 2, 3, 1) 14 | image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406)) 15 | image = image.clip(0, 1) 16 | return image 17 | 18 | def do_transform(img, osize): 19 | transform = Compose([Resize(size=osize), 20 | CenterCrop(size=osize), 21 | ToTensor(), 22 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 23 | return transform(img).unsqueeze(0) 24 | 25 | def save_img(config, cont_name, sty_h_name, sty_l_name, content, style_h, style_l, stylized): 26 | real_A = im_convert(content) 27 | real_B_1 = im_convert(style_h) 28 | real_B_2 = im_convert(style_l) 29 | trs_AtoB = im_convert(stylized) 30 | 31 | A_image = Image.fromarray((real_A[0] * 255.0).astype(np.uint8)) 32 | B1_image = Image.fromarray((real_B_1[0] * 255.0).astype(np.uint8)) 33 | B2_image = Image.fromarray((real_B_2[0] * 255.0).astype(np.uint8)) 34 | trs_image = Image.fromarray((trs_AtoB[0] * 255.0).astype(np.uint8)) 35 | 36 | cont_name = cont_name.split('/')[-1].split('.')[0] 37 | sty_h_name = sty_h_name.split('/')[-1].split('.')[0] 38 | sty_l_name = sty_l_name.split('/')[-1].split('.')[0] 39 | 40 | A_image.save('{}/{:s}_content.jpg'.format(config.img_dir, cont_name)) 41 | B1_image.save('{}/{:s}_high_style.jpg'.format(config.img_dir, sty_h_name)) 42 | B2_image.save('{}/{:s}_low_style.jpg'.format(config.img_dir, sty_l_name)) 43 | trs_image.save('{}/stylized_{:s}_{:s}_{:s}.jpg'.format(config.img_dir, cont_name, sty_h_name, sty_l_name)) 44 | 45 | def main(): 46 | config = Config() 47 | if not os.path.exists(config.img_dir): 48 | os.makedirs(config.img_dir) 49 | 50 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 51 | print('Version:', config.file_n) 52 | print(device) 53 | 54 | with torch.no_grad(): 55 | ## Model load 56 | model = AesFA_test(config) 57 | 58 | ## Load saved model 59 | ckpt = config.ckpt_dir + '/main.pth' 60 | print("checkpoint: ", ckpt) 61 | model = test_model_load(checkpoint=ckpt, model=model) 62 | model.to(device) 63 | 64 | ## Style Blending 65 | real_A = Image.open(config.content_img).convert('RGB') 66 | style_high = Image.open(config.style_high_img).convert('RGB') 67 | style_low = Image.open(config.style_low_img).convert('RGB') 68 | 69 | real_A = do_transform(real_A, config.blend_load_size).to(device) 70 | style_high = do_transform(style_high, config.blend_load_size).to(device) 71 | style_low = do_transform(style_low, config.blend_load_size).to(device) 72 | 73 | stylized, during = model.style_blending(real_A, style_high, style_low) 74 | save_img(config, config.content_img, config.style_high_img, config.style_low_img, real_A, style_high, style_low, stylized) 75 | print("Time:", during) 76 | 77 | if __name__ == '__main__': 78 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import 
numpy as np 4 | import tensorboardX 5 | 6 | from Config import Config 7 | from DataSplit import DataSplit 8 | from model import AesFA 9 | from blocks import model_save, model_load, update_learning_rate 10 | 11 | from torch.utils.data import RandomSampler 12 | 13 | def mkoutput_dir(config): 14 | if not os.path.exists(config.log_dir): 15 | os.makedirs(config.log_dir) 16 | if not os.path.exists(config.ckpt_dir): 17 | os.makedirs(config.ckpt_dir) 18 | 19 | def get_n_params(model): 20 | total_params=0 21 | net_params = {'netE':0, 'netS':0, 'netG':0, 'vgg_loss':0} 22 | 23 | for name, param in model.named_parameters(): 24 | net = name.split('.')[0] 25 | nn=1 26 | for s in list(param.size()): 27 | nn = nn*s 28 | net_params[net] += nn 29 | total_params += nn 30 | return total_params, net_params 31 | 32 | def im_convert(tensor): 33 | image = tensor.to("cpu").clone().detach().numpy() 34 | image = image.transpose(0, 2, 3, 1) 35 | image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406)) 36 | image = image.clip(0, 1) 37 | return image 38 | 39 | def main(): 40 | config = Config() 41 | mkoutput_dir(config) 42 | 43 | config.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 44 | print('cuda:', config.device) 45 | print('Version:', config.file_n) 46 | 47 | ########## Data Loader ########## 48 | train_data = DataSplit(config=config, phase='train') 49 | train_sampler = RandomSampler(train_data) 50 | data_loader_train = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, num_workers=config.num_workers, pin_memory=False, sampler=train_sampler) 51 | print("Train: ", train_data.__len__(), "images: ", len(data_loader_train), "x", config.batch_size,"(batch size) =", train_data.__len__()) 52 | 53 | ########## load model ########## 54 | model = AesFA(config) 55 | model.to(config.device) 56 | 57 | # # of parameter 58 | param_num, net_params = get_n_params(model) 59 | print("# of parameter:", param_num) 60 | print("parameters of networks:", net_params) 61 | 62 | ########## load saved model - to continue previous learning ########## 63 | if config.train_continue == 'on': 64 | model, model.optimizer_E, model.optimizer_S, model.optimizer_G, epoch_start, tot_itr = model_load(checkpoint=None, ckpt_dir=config.ckpt_dir, model=model, 65 | optim_E=model.optimizer_E, 66 | optim_S=model.optimizer_S, 67 | optim_G=model.optimizer_G) 68 | print(epoch_start, "th epoch ", tot_itr, "th iteration model load") 69 | else: 70 | epoch_start = 0 71 | tot_itr = 0 72 | 73 | train_writer = tensorboardX.SummaryWriter(config.log_dir) 74 | 75 | ########## Training ########## 76 | # to save ckpt file starting with epoch and iteration 1 77 | epoch = epoch_start - 1 78 | tot_itr = tot_itr - 1 79 | while tot_itr < config.n_iter: 80 | epoch += 1 81 | 82 | for i, data in enumerate(data_loader_train): 83 | tot_itr += 1 84 | train_dict = model.train_step(data) 85 | 86 | real_A = im_convert(data['content_img']) 87 | real_B = im_convert(train_dict['style_img']) 88 | fake_B = im_convert(train_dict['fake_AtoB']) 89 | trs_high = im_convert(train_dict['fake_AtoB_high']) 90 | trs_low = im_convert(train_dict['fake_AtoB_low']) 91 | 92 | ## Tensorboard ## 93 | # tensorboard - loss 94 | train_writer.add_scalar('Loss_G', train_dict['G_loss'], tot_itr) 95 | train_writer.add_scalar('Loss_G_Percept', train_dict['G_Percept'], tot_itr) 96 | train_writer.add_scalar('Loss_G_Contrast', train_dict['G_Contrast'], tot_itr) 97 | 98 | # tensorboard - images 99 | train_writer.add_image('Content_Image_A', real_A, 
tot_itr, dataformats='NHWC') 100 | train_writer.add_image('Style_Image_B', real_B, tot_itr, dataformats='NHWC') 101 | train_writer.add_image('Generated_Image_AtoB', fake_B, tot_itr, dataformats='NHWC') 102 | train_writer.add_image('Translation_AtoB_high', trs_high, tot_itr, dataformats='NHWC') 103 | train_writer.add_image('Translation_AtoB_low', trs_low, tot_itr, dataformats='NHWC') 104 | 105 | print("Tot_itrs: %d/%d | Epoch: %d | itr: %d/%d | Loss_G: %.5f"%(tot_itr+1, config.n_iter, epoch+1, (i+1), len(data_loader_train), train_dict['G_loss'])) 106 | 107 | if (tot_itr + 1) % 10000 == 0: 108 | model_save(ckpt_dir=config.ckpt_dir, model=model, optim_E=model.optimizer_E, optim_S=model.optimizer_S, optim_G=model.optimizer_G, epoch=epoch, itr=tot_itr) 109 | print(tot_itr+1, "th iteration model save") 110 | 111 | update_learning_rate(model.E_scheduler, model.optimizer_E) 112 | update_learning_rate(model.S_scheduler, model.optimizer_S) 113 | update_learning_rate(model.G_scheduler, model.optimizer_G) 114 | 115 | if __name__ == '__main__': 116 | main() 117 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import networks 4 | import blocks 5 | import time 6 | 7 | from vgg19 import vgg, VGG_loss 8 | from networks import EFDM_loss 9 | 10 | class AesFA(nn.Module): 11 | def __init__(self, config): 12 | super(AesFA, self).__init__() 13 | 14 | self.config = config 15 | self.device = self.config.device 16 | 17 | self.lr = config.lr 18 | self.lambda_percept = config.lambda_percept 19 | self.lambda_const_style = config.lambda_const_style 20 | 21 | self.netE = networks.define_network(net_type='Encoder', config = config) # Content Encoder 22 | self.netS = networks.define_network(net_type='Encoder', config = config) # Style Encoder 23 | self.netG = networks.define_network(net_type='Generator', config = config) 24 | 25 | self.vgg_loss = VGG_loss(config, vgg) 26 | self.efdm_loss = EFDM_loss() 27 | 28 | self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=self.config.lr, betas=(self.config.beta1, 0.99)) 29 | self.optimizer_S = torch.optim.Adam(self.netS.parameters(), lr=self.config.lr, betas=(self.config.beta1, 0.99)) 30 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.config.lr, betas=(self.config.beta1, 0.99)) 31 | 32 | self.E_scheduler = blocks.get_scheduler(self.optimizer_E, config) 33 | self.S_scheduler = blocks.get_scheduler(self.optimizer_S, config) 34 | self.G_scheduler = blocks.get_scheduler(self.optimizer_G, config) 35 | 36 | 37 | def forward(self, data): 38 | self.real_A = data['content_img'].to(self.device) 39 | self.real_B = data['style_img'].to(self.device) 40 | 41 | self.content_A, _, _ = self.netE(self.real_A) 42 | _, self.style_B, self.content_B_feat = self.netS(self.real_B) 43 | self.style_B_feat = self.content_B_feat.copy() 44 | self.style_B_feat.append(self.style_B) 45 | 46 | self.trs_AtoB, self.trs_AtoB_high, self.trs_AtoB_low = self.netG(self.content_A, self.style_B) 47 | 48 | self.trs_AtoB_content, _, self.content_trs_AtoB_feat = self.netE(self.trs_AtoB) 49 | _, self.trs_AtoB_style, self.style_trs_AtoB_feat = self.netS(self.trs_AtoB) 50 | self.style_trs_AtoB_feat.append(self.trs_AtoB_style) 51 | 52 | 53 | def calc_G_loss(self): 54 | self.G_percept, self.neg_idx = self.vgg_loss.perceptual_loss(self.real_A, self.real_B, self.trs_AtoB) 55 | self.G_percept *= self.lambda_percept 56 | 57 | 
self.G_contrast = self.efdm_loss(self.content_B_feat, self.style_B_feat, self.content_trs_AtoB_feat, self.style_trs_AtoB_feat, self.neg_idx) * self.lambda_const_style 58 | 59 | self.G_loss = self.G_percept + self.G_contrast 60 | 61 | 62 | def train_step(self, data): 63 | self.set_requires_grad([self.netE, self.netS, self.netG], True) 64 | 65 | self.forward(data) 66 | self.calc_G_loss() 67 | 68 | self.optimizer_E.zero_grad() 69 | self.optimizer_S.zero_grad() 70 | self.optimizer_G.zero_grad() 71 | self.G_loss.backward() 72 | self.optimizer_E.step() 73 | self.optimizer_S.step() 74 | self.optimizer_G.step() 75 | 76 | train_dict = {} 77 | train_dict['G_loss'] = self.G_loss 78 | train_dict['G_Percept'] = self.G_percept 79 | train_dict['G_Contrast'] = self.G_contrast 80 | 81 | train_dict['style_img'] = self.real_B 82 | train_dict['fake_AtoB'] = self.trs_AtoB 83 | train_dict['fake_AtoB_high'] = self.trs_AtoB_high 84 | train_dict['fake_AtoB_low'] = self.trs_AtoB_low 85 | 86 | return train_dict 87 | 88 | def set_requires_grad(self, nets, requires_grad=False): 89 | if not isinstance(nets, list): 90 | nets = [nets] 91 | for net in nets: 92 | if net is not None: 93 | for param in net.parameters(): 94 | param.requires_grad = requires_grad 95 | 96 | 97 | class AesFA_test(nn.Module): 98 | def __init__(self, config): 99 | super(AesFA_test, self).__init__() 100 | 101 | self.netE = networks.define_network(net_type='Encoder', config=config) 102 | self.netS = networks.define_network(net_type='Encoder', config=config) 103 | self.netG = networks.define_network(net_type='Generator', config=config) 104 | 105 | def forward(self, real_A, real_B, freq): 106 | with torch.no_grad(): 107 | start = time.time() 108 | content_A = self.netE.forward_test(real_A, 'content') 109 | style_B = self.netS.forward_test(real_B, 'style') 110 | if freq: 111 | trs_AtoB, trs_AtoB_high, trs_AtoB_low = self.netG(content_A, style_B) 112 | end = time.time() 113 | during = end - start 114 | return trs_AtoB, trs_AtoB_high, trs_AtoB_low, during 115 | else: 116 | trs_AtoB = self.netG.forward_test(content_A, style_B) 117 | end = time.time() 118 | during = end - start 119 | return trs_AtoB, during 120 | 121 | def style_blending(self, real_A, real_B_1, real_B_2): 122 | with torch.no_grad(): 123 | start = time.time() 124 | content_A = self.netE.forward_test(real_A, 'content') 125 | style_B1_h = self.netS.forward_test(real_B_1, 'style')[0] 126 | style_B2_l = self.netS.forward_test(real_B_2, 'style')[1] 127 | style_B = style_B1_h, style_B2_l 128 | 129 | trs_AtoB = self.netG.forward_test(content_A, style_B) 130 | end = time.time() 131 | during = end - start 132 | 133 | return trs_AtoB, during -------------------------------------------------------------------------------- /test_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import glob 6 | from torchvision.transforms import ToTensor, Compose, Resize, CenterCrop, Normalize, RandomCrop 7 | 8 | from Config import Config 9 | from DataSplit import DataSplit 10 | from model import AesFA_test 11 | from blocks import test_model_load 12 | 13 | 14 | def load_img(img_name, img_size, device): 15 | img = Image.open(img_name).convert('RGB') 16 | img = do_transform(img, img_size).to(device) 17 | if len(img.shape) == 3: 18 | img = img.unsqueeze(0) 19 | return img 20 | 21 | def im_convert(tensor): 22 | image = tensor.to("cpu").clone().detach().numpy() 23 | image = image.transpose(0, 2, 3, 1) 
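# do_transform normalized the inputs with ImageNet mean/std, so the next two lines
# undo that normalization (x * std + mean) and clip the result to the valid [0, 1] range
# before the array is converted to an image for saving.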
24 | image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406)) 25 | image = image.clip(0, 1) 26 | 27 | return image 28 | 29 | def do_transform(img, osize): 30 | # if config.phase == 'test': 31 | # osize = config.test_load_size 32 | # elif config.phase == 'style_blending': 33 | # osize = config.blend_load_size 34 | transform = Compose([Resize(size=osize), 35 | CenterCrop(size=osize), 36 | ToTensor(), 37 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 38 | return transform(img) 39 | 40 | def save_img(config, cont_name, sty_name, content, style, stylized, freq=False, high=None, low=None): 41 | real_A = im_convert(content) 42 | real_B = im_convert(style) 43 | trs_AtoB = im_convert(stylized) 44 | 45 | A_image = Image.fromarray((real_A[0] * 255.0).astype(np.uint8)) 46 | B_image = Image.fromarray((real_B[0] * 255.0).astype(np.uint8)) 47 | trs_image = Image.fromarray((trs_AtoB[0] * 255.0).astype(np.uint8)) 48 | 49 | if config.phase == 'test': 50 | A_image.save('{}/content/{:s}_content.jpg'.format(config.img_dir, cont_name.stem)) 51 | B_image.save('{}/style/{:s}_style.jpg'.format(config.img_dir, sty_name.stem)) 52 | trs_image.save('{}/stylized/{:s}_stylized_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem)) 53 | else: 54 | A_image.save('{}/content/{:s}_content.jpg'.format(config.img_dir, cont_name)) 55 | B_image.save('{}/style/{:s}_style.jpg'.format(config.img_dir, sty_name)) 56 | trs_image.save('{}/stylized/{:s}_stylized_{:s}.jpg'.format(config.img_dir, cont_name, sty_name)) 57 | 58 | if freq: 59 | trs_AtoB_high = im_convert(high) 60 | trs_AtoB_low = im_convert(low) 61 | 62 | trsh_image = Image.fromarray((trs_AtoB_high[0] * 255.0).astype(np.uint8)) 63 | trsl_image = Image.fromarray((trs_AtoB_low[0] * 255.0).astype(np.uint8)) 64 | 65 | trsh_image.save('{}/{:s}_stylizing_high_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem)) 66 | trsl_image.save('{}/{:s}_stylizing_low_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem)) 67 | 68 | 69 | def main(): 70 | config = Config() 71 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 72 | print('Version:', config.file_n) 73 | print(device) 74 | 75 | with torch.no_grad(): 76 | ## Model load 77 | ckpt = config.ckpt_dir + '/main.pth' 78 | print("checkpoint: ", ckpt) 79 | model = AesFA_test(config) 80 | model = test_model_load(checkpoint=ckpt, model=model) 81 | model.to(device) 82 | 83 | if not os.path.exists(config.img_dir): 84 | os.makedirs(config.img_dir) 85 | os.makedirs(config.img_dir+'/content') 86 | os.makedirs(config.img_dir+'/style') 87 | os.makedirs(config.img_dir+'/stylized') 88 | 89 | ## Start Testing 90 | count = 0 91 | t_during = 0 92 | if config.phase == 'test': 93 | ## Data Loader 94 | test_data = DataSplit(config=config, phase='test') 95 | contents = test_data.images 96 | styles = test_data.style_images 97 | print("# of contents:", len(contents)) 98 | print("# of styles:", len(styles)) 99 | 100 | for idx in range(len(contents)): 101 | cont_name = contents[idx] 102 | content = load_img(cont_name, config.test_content_size, device) 103 | 104 | for i in range(len(styles)): 105 | sty_name = styles[i] 106 | style = load_img(sty_name, config.test_style_size, device) 107 | 108 | freq = False 109 | if freq: 110 | stylized, stylized_high, stylized_low, during = model(content, style, freq) 111 | save_img(config, cont_name, sty_name, content, style, stylized, freq, stylized_high, stylized_low) 112 | else: 113 | stylized, during = model(content, style, freq) 114 
| save_img(config, cont_name, sty_name, content, style, stylized) 115 | 116 | count += 1 117 | print(count, idx+1, i+1, during) 118 | t_during += during 119 | 120 | elif config.phase == 'style_blending': 121 | contents = sorted(glob.glob(config.blend_dir+'/content/*.jpg')) 122 | for content in contents: 123 | cont_name = content.split('/')[-1].split('.')[0] 124 | content = load_img(cont_name, config.blend_load_size, device) 125 | 126 | style_h = config.style_high_img 127 | style_l = config.style_low_img 128 | 129 | sty_name = style_h.split('/')[-1].split('.')[0] 130 | style_h = Image.open(style_h).convert('RGB') 131 | style_h = do_transform(config, style_h).to(device) 132 | if len(style_h.shape) == 3: 133 | style_h = style_h.unsqueeze(0) 134 | 135 | style_l = Image.open(style_l).convert('RGB') 136 | style_l = do_transform(config, style_l).to(device) 137 | if len(style_l.shape) == 3: 138 | style_l = style_l.unsqueeze(0) 139 | 140 | stylized, during = model.style_blending(content, style_h, style_l) 141 | save_img(config, cont_name, sty_name, content, style_h, stylized) 142 | 143 | t_during = float(t_during / (len(contents) * len(styles))) 144 | print("[AesFA] Total images:", len(contents) * len(styles), "Avg Testing time:", t_during) 145 | 146 | 147 | if __name__ == '__main__': 148 | main() 149 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import thop 5 | from PIL import Image 6 | from torchvision.transforms import ToTensor, Compose, Resize, CenterCrop, Normalize, RandomCrop 7 | 8 | from Config import Config 9 | from DataSplit import DataSplit 10 | from model import AesFA_test 11 | from blocks import test_model_load 12 | 13 | 14 | def load_img(img_name, img_size, device): 15 | img = Image.open(img_name).convert('RGB') 16 | img = do_transform(img, img_size).to(device) 17 | if len(img.shape) == 3: 18 | img = img.unsqueeze(0) # make batch dimension 19 | return img 20 | 21 | def im_convert(tensor): 22 | image = tensor.to("cpu").clone().detach().numpy() 23 | image = image.transpose(0, 2, 3, 1) 24 | image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406)) 25 | image = image.clip(0, 1) 26 | return image 27 | 28 | def do_transform(img, osize): 29 | transform = Compose([Resize(size=osize), # Resize to keep aspect ratio 30 | CenterCrop(size=osize), 31 | ToTensor(), 32 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 33 | return transform(img) 34 | 35 | def save_img(config, cont_name, sty_name, content, style, stylized, freq=False, high=None, low=None): 36 | real_A = im_convert(content) 37 | real_B = im_convert(style) 38 | trs_AtoB = im_convert(stylized) 39 | 40 | A_image = Image.fromarray((real_A[0] * 255.0).astype(np.uint8)) 41 | B_image = Image.fromarray((real_B[0] * 255.0).astype(np.uint8)) 42 | trs_image = Image.fromarray((trs_AtoB[0] * 255.0).astype(np.uint8)) 43 | 44 | A_image.save('{}/{:s}_content_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem)) 45 | B_image.save('{}/{:s}_style_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem)) 46 | trs_image.save('{}/{:s}_stylized_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem)) 47 | 48 | if freq: 49 | trs_AtoB_high = im_convert(high) 50 | trs_AtoB_low = im_convert(low) 51 | 52 | trsh_image = Image.fromarray((trs_AtoB_high[0] * 255.0).astype(np.uint8)) 53 | trsl_image = 
Image.fromarray((trs_AtoB_low[0] * 255.0).astype(np.uint8)) 54 | 55 | trsh_image.save('{}/{:s}_stylizing_high_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem)) 56 | trsl_image.save('{}/{:s}_stylizing_low_{:s}.jpg'.format(config.img_dir, cont_name.stem, sty_name.stem)) 57 | 58 | 59 | def main(): 60 | config = Config() 61 | 62 | config.gpu = 0 63 | device = torch.device('cuda:'+str(config.gpu) if torch.cuda.is_available() else 'cpu') 64 | print('Version:', config.file_n) 65 | print(device) 66 | 67 | with torch.no_grad(): 68 | ## Data Loader 69 | test_bs = 1 70 | test_data = DataSplit(config=config, phase='test') 71 | data_loader_test = torch.utils.data.DataLoader(test_data, batch_size=test_bs, shuffle=False, num_workers=16, pin_memory=False) 72 | print("Test: ", test_data.__len__(), "images: ", len(data_loader_test), "x", test_bs, "(batch size) =", test_data.__len__()) 73 | 74 | ## Model load 75 | ckpt = config.ckpt_dir + '/main.pth' 76 | print("checkpoint: ", ckpt) 77 | model = AesFA_test(config) 78 | model = test_model_load(checkpoint=ckpt, model=model) 79 | model.to(device) 80 | 81 | if not os.path.exists(config.img_dir): 82 | os.makedirs(config.img_dir) 83 | 84 | ## Start Testing 85 | freq = False # whether save high, low frequency images or not 86 | count = 0 87 | t_during = 0 88 | 89 | contents = test_data.images 90 | styles = test_data.style_images 91 | if config.multi_to_multi: # one content image, N style image 92 | tot_imgs = len(contents) * len(styles) 93 | for idx in range(len(contents)): 94 | cont_name = contents[idx] # path of content image 95 | content = load_img(cont_name, config.test_content_size, device) 96 | 97 | for i in range(len(styles)): 98 | sty_name = styles[i] # path of style image 99 | style = load_img(sty_name, config.test_style_size, device) 100 | 101 | if freq: 102 | stylized, stylized_high, stylized_low, during = model(content, style, freq) 103 | save_img(config, cont_name, sty_name, content, style, stylized, freq, stylized_high, stylized_low) 104 | else: 105 | stylized, during = model(content, style, freq) 106 | save_img(config, cont_name, sty_name, content, style, stylized) 107 | 108 | count += 1 109 | print(count, idx+1, i+1, during) 110 | t_during += during 111 | flops, params = thop.profile(model, inputs=(content, style, freq)) 112 | print("GFLOPS: %.4f, Params: %.4f"% (flops/1e9, params/1e6)) 113 | print("Max GPU memory allocated: %.4f GB" % (torch.cuda.max_memory_allocated(device=config.gpu) / 1024. / 1024. / 1024.)) 114 | 115 | else: 116 | tot_imgs = len(contents) 117 | for idx in range(len(contents)): 118 | cont_name = contents[idx] 119 | content = load_img(cont_name, config.test_content_size, device) 120 | 121 | sty_name = styles[idx] 122 | style = load_img(sty_name, config.test_style_size, device) 123 | 124 | if freq: 125 | stylized, stylized_high, stylized_low, during = model(content, style, freq) 126 | save_img(config, cont_name, sty_name, content, style, stylized, freq, stylized_high, stylized_low) 127 | else: 128 | stylized, during = model(content, style, freq) 129 | save_img(config, cont_name, sty_name, content, style, stylized) 130 | 131 | t_during += during 132 | flops, params = thop.profile(model, inputs=(content, style, freq)) 133 | print("GFLOPS: %.4f, Params: %.4f" % (flops / 1e9, params / 1e6)) 134 | print("Max GPU memory allocated: %.4f GB" % (torch.cuda.max_memory_allocated(device=config.gpu) / 1024. / 1024. 
/ 1024.)) 135 | 136 | 137 | t_during = float(t_during / (len(contents) * len(styles))) 138 | print("[AesFA] Content size:", config.test_content_size, "Style size:", config.test_style_size, 139 | " Total images:", tot_imgs, "Avg Testing time:", t_during) 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /vgg19.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | vgg = nn.Sequential( 6 | nn.Conv2d(3, 3, (1, 1)), 7 | nn.ReflectionPad2d((1, 1, 1, 1)), 8 | nn.Conv2d(3, 64, (3, 3)), 9 | 10 | nn.ReLU(), # relu1-1 11 | 12 | nn.ReflectionPad2d((1, 1, 1, 1)), 13 | nn.Conv2d(64, 64, (3, 3)), 14 | nn.ReLU(), # relu1-2 15 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 16 | nn.ReflectionPad2d((1, 1, 1, 1)), 17 | nn.Conv2d(64, 128, (3, 3)), 18 | 19 | nn.ReLU(), # relu2-1 20 | 21 | nn.ReflectionPad2d((1, 1, 1, 1)), 22 | nn.Conv2d(128, 128, (3, 3)), 23 | nn.ReLU(), # relu2-2 24 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 25 | nn.ReflectionPad2d((1, 1, 1, 1)), 26 | nn.Conv2d(128, 256, (3, 3)), 27 | 28 | nn.ReLU(), # relu3-1 29 | 30 | nn.ReflectionPad2d((1, 1, 1, 1)), 31 | nn.Conv2d(256, 256, (3, 3)), 32 | nn.ReLU(), # relu3-2 33 | nn.ReflectionPad2d((1, 1, 1, 1)), 34 | nn.Conv2d(256, 256, (3, 3)), 35 | nn.ReLU(), # relu3-3 36 | nn.ReflectionPad2d((1, 1, 1, 1)), 37 | nn.Conv2d(256, 256, (3, 3)), 38 | nn.ReLU(), # relu3-4 39 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 40 | nn.ReflectionPad2d((1, 1, 1, 1)), 41 | nn.Conv2d(256, 512, (3, 3)), 42 | 43 | nn.ReLU(), # relu4-1, this is the last layer used 44 | 45 | nn.ReflectionPad2d((1, 1, 1, 1)), 46 | nn.Conv2d(512, 512, (3, 3)), 47 | nn.ReLU(), # relu4-2 48 | nn.ReflectionPad2d((1, 1, 1, 1)), 49 | nn.Conv2d(512, 512, (3, 3)), 50 | nn.ReLU(), # relu4-3 51 | nn.ReflectionPad2d((1, 1, 1, 1)), 52 | nn.Conv2d(512, 512, (3, 3)), 53 | nn.ReLU(), # relu4-4 54 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 55 | nn.ReflectionPad2d((1, 1, 1, 1)), 56 | nn.Conv2d(512, 512, (3, 3)), 57 | 58 | nn.ReLU(), # relu5-1 59 | 60 | nn.ReflectionPad2d((1, 1, 1, 1)), 61 | nn.Conv2d(512, 512, (3, 3)), 62 | nn.ReLU(), # relu5-2 63 | nn.ReflectionPad2d((1, 1, 1, 1)), 64 | nn.Conv2d(512, 512, (3, 3)), 65 | nn.ReLU(), # relu5-3 66 | nn.ReflectionPad2d((1, 1, 1, 1)), 67 | nn.Conv2d(512, 512, (3, 3)), 68 | nn.ReLU() # relu5-4 69 | ) 70 | 71 | 72 | class VGG_loss(nn.Module): 73 | def __init__(self, config, vgg): 74 | super(VGG_loss, self).__init__() 75 | 76 | self.config = config 77 | 78 | vgg_pretrained = config.vgg_model 79 | vgg.load_state_dict(torch.load(vgg_pretrained)) 80 | vgg = nn.Sequential(*list(vgg.children())[:43]) # depends on what layers you want to load 81 | vgg_enc_layers = list(vgg.children()) 82 | 83 | self.n_layers = 4 84 | self.vgg_enc_1 = nn.Sequential(*vgg_enc_layers[:3]) # ~ conv1_1 85 | self.vgg_enc_2 = nn.Sequential(*vgg_enc_layers[3:10]) # conv1_1 ~ conv2_1 86 | self.vgg_enc_3 = nn.Sequential(*vgg_enc_layers[10:17]) # conv2_1 ~ conv3_1 87 | self.vgg_enc_4 = nn.Sequential(*vgg_enc_layers[17:30]) # conv3_1 ~ conv4_1 88 | 89 | self.mse_loss = nn.MSELoss() 90 | 91 | for name in ['vgg_enc_1', 'vgg_enc_2', 'vgg_enc_3', 'vgg_enc_4']: 92 | for param in getattr(self, name).parameters(): 93 | param.requires_grad = False 94 | 95 | # extract relu1_1, relu2_1, relu3_1, relu4_1 96 | def encode_with_vgg_intermediate(self, input): 97 | results = 
[input] 98 | for i in range(self.n_layers): 99 | func = getattr(self, 'vgg_enc_{:d}'.format(i + 1)) 100 | results.append(func(results[-1])) 101 | return results[1:] 102 | 103 | # extract relu3_1 104 | def encode_vgg_content(self, input): 105 | for i in range(3): 106 | input = getattr(self, 'vgg_enc_{:d}'.format(i + 1))(input) 107 | return input 108 | 109 | def calc_content_loss(self, input, target): 110 | assert (input.size() == target.size()) 111 | return self.mse_loss(input, target) 112 | 113 | def efdm_single(self, style, trans): 114 | B, C, W, H = style.size(0), style.size(1), style.size(2), style.size(3) 115 | 116 | value_style, index_style = torch.sort(style.view(B, C, -1)) 117 | value_trans, index_trans = torch.sort(trans.view(B, C, -1)) 118 | inverse_index = index_trans.argsort(-1) 119 | 120 | return self.mse_loss(trans.view(B, C,-1), value_style.gather(-1, inverse_index)) 121 | 122 | def perceptual_loss(self, content, style, trs_img): 123 | # normalization for putting images as inputs to VGG 124 | content = content.permute(0, 2, 3, 1) 125 | style = style.permute(0, 2, 3, 1) 126 | trs_img = trs_img.permute(0, 2, 3, 1) 127 | 128 | content = content * torch.from_numpy(np.array((0.229, 0.224, 0.225))).to(content.device) + torch.from_numpy(np.array((0.485, 0.456, 0.406))).to(content.device) 129 | style = style * torch.from_numpy(np.array((0.229, 0.224, 0.225))).to(style.device) + torch.from_numpy(np.array((0.485, 0.456, 0.406))).to(style.device) 130 | trs_img = trs_img * torch.from_numpy(np.array((0.229, 0.224, 0.225))).to(trs_img.device) + torch.from_numpy(np.array((0.485, 0.456, 0.406))).to(trs_img.device) 131 | 132 | content = content.permute(0, 3, 1, 2).float() 133 | style = style.permute(0, 3, 1, 2).float() 134 | trs_img = trs_img.permute(0, 3, 1, 2).float() 135 | 136 | # loss 137 | content_feats_vgg = self.encode_vgg_content(content) 138 | style_feats_vgg = self.encode_with_vgg_intermediate(style) 139 | trs_feats_vgg = self.encode_with_vgg_intermediate(trs_img) 140 | 141 | loss_c = self.calc_content_loss(trs_feats_vgg[-2], content_feats_vgg) 142 | loss_s = self.efdm_single(trs_feats_vgg[0], style_feats_vgg[0]) 143 | for i in range(1, self.n_layers): 144 | loss_s = loss_s + self.efdm_single(trs_feats_vgg[i], style_feats_vgg[i]) 145 | 146 | loss = loss_c * self.config.lambda_perc_cont + loss_s * self.config.lambda_perc_style 147 | 148 | # EFDM negative pair 149 | neg_idx = [] 150 | batch = content.shape[0] 151 | for a in range(batch): 152 | neg_lst = {} 153 | for b in range(batch): # for each image pair 154 | if a != b: 155 | loss_s_single = 0 156 | for i in range(0, self.n_layers): # for each vgg layer 157 | loss_s_single += self.efdm_single(trs_feats_vgg[i][a].unsqueeze(0), style_feats_vgg[i][b].unsqueeze(0)) 158 | neg_lst[b] = loss_s_single 159 | neg_lst = sorted(neg_lst, key=neg_lst.get) 160 | # neg_idx.append(neg_lst[:3]) 161 | neg_idx.append([neg_lst[0]]) 162 | 163 | return loss, neg_idx -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from blocks import * 6 | 7 | def define_network(net_type, config = None): 8 | net = None 9 | alpha_in = config.alpha_in 10 | alpha_out = config.alpha_out 11 | sk = config.style_kernel 12 | 13 | if net_type == 'Encoder': 14 | net = Encoder(in_dim=config.input_nc, nf=config.nf, style_kernel=[sk, sk], alpha_in=alpha_in, 
alpha_out=alpha_out) 15 | elif net_type == 'Generator': 16 | net = Decoder(nf=config.nf, out_dim=config.output_nc, style_channel=256, style_kernel=[sk, sk, 3], alpha_in=alpha_in, freq_ratio=config.freq_ratio, alpha_out=alpha_out) 17 | return net 18 | 19 | class Encoder(nn.Module): 20 | def __init__(self, in_dim, nf=64, style_kernel=[3, 3], alpha_in=0.5, alpha_out=0.5): 21 | super(Encoder, self).__init__() 22 | 23 | self.conv = nn.Conv2d(in_channels=in_dim, out_channels=nf, kernel_size=7, stride=1, padding=3) 24 | 25 | self.OctConv1_1 = OctConv(in_channels=nf, out_channels=nf, kernel_size=3, stride=2, padding=1, groups=64, alpha_in=alpha_in, alpha_out=alpha_out, type="first") 26 | self.OctConv1_2 = OctConv(in_channels=nf, out_channels=2*nf, kernel_size=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal") 27 | self.OctConv1_3 = OctConv(in_channels=2*nf, out_channels=2*nf, kernel_size=3, stride=1, padding=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal") 28 | 29 | self.OctConv2_1 = OctConv(in_channels=2*nf, out_channels=2*nf, kernel_size=3, stride=2, padding=1, groups=128, alpha_in=alpha_in, alpha_out=alpha_out, type="normal") 30 | self.OctConv2_2 = OctConv(in_channels=2*nf, out_channels=4*nf, kernel_size=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal") 31 | self.OctConv2_3 = OctConv(in_channels=4*nf, out_channels=4*nf, kernel_size=3, stride=1, padding=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal") 32 | 33 | self.pool_h = nn.AdaptiveAvgPool2d((style_kernel[0], style_kernel[0])) 34 | self.pool_l = nn.AdaptiveAvgPool2d((style_kernel[1], style_kernel[1])) 35 | 36 | self.relu = Oct_conv_lreLU() 37 | 38 | def forward(self, x): 39 | enc_feat = [] 40 | out = self.conv(x) 41 | 42 | out = self.OctConv1_1(out) 43 | out = self.relu(out) 44 | out = self.OctConv1_2(out) 45 | out = self.relu(out) 46 | out = self.OctConv1_3(out) 47 | out = self.relu(out) 48 | enc_feat.append(out) 49 | 50 | out = self.OctConv2_1(out) 51 | out = self.relu(out) 52 | out = self.OctConv2_2(out) 53 | out = self.relu(out) 54 | out = self.OctConv2_3(out) 55 | out = self.relu(out) 56 | enc_feat.append(out) 57 | 58 | out_high, out_low = out 59 | out_sty_h = self.pool_h(out_high) 60 | out_sty_l = self.pool_l(out_low) 61 | out_sty = out_sty_h, out_sty_l 62 | 63 | return out, out_sty, enc_feat 64 | 65 | def forward_test(self, x, cond): 66 | out = self.conv(x) 67 | 68 | out = self.OctConv1_1(out) 69 | out = self.relu(out) 70 | out = self.OctConv1_2(out) 71 | out = self.relu(out) 72 | out = self.OctConv1_3(out) 73 | out = self.relu(out) 74 | 75 | out = self.OctConv2_1(out) 76 | out = self.relu(out) 77 | out = self.OctConv2_2(out) 78 | out = self.relu(out) 79 | out = self.OctConv2_3(out) 80 | out = self.relu(out) 81 | 82 | if cond == 'style': 83 | out_high, out_low = out 84 | out_sty_h = self.pool_h(out_high) 85 | out_sty_l = self.pool_l(out_low) 86 | return out_sty_h, out_sty_l 87 | else: 88 | return out 89 | 90 | class Decoder(nn.Module): 91 | def __init__(self, nf=64, out_dim=3, style_channel=512, style_kernel=[3, 3, 3], alpha_in=0.5, alpha_out=0.5, freq_ratio=[1,1], pad_type='reflect'): 92 | super(Decoder, self).__init__() 93 | 94 | group_div = [1, 2, 4, 8] 95 | self.up_oct = Oct_conv_up(scale_factor=2) 96 | 97 | self.AdaOctConv1_1 = AdaOctConv(in_channels=4*nf, out_channels=4*nf, group_div=group_div[0], style_channels=style_channel, kernel_size=style_kernel, stride=1, padding=1, oct_groups=4*nf, alpha_in=alpha_in, alpha_out=alpha_out, type="normal") 98 | self.OctConv1_2 = OctConv(in_channels=4*nf, 
out_channels=2*nf, kernel_size=1, stride=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal") 99 | self.oct_conv_aftup_1 = Oct_Conv_aftup(in_channels=2*nf, out_channels=2*nf, kernel_size=3, stride=1, padding=1, pad_type=pad_type, alpha_in=alpha_in, alpha_out=alpha_out) 100 | 101 | self.AdaOctConv2_1 = AdaOctConv(in_channels=2*nf, out_channels=2*nf, group_div=group_div[1], style_channels=style_channel, kernel_size=style_kernel, stride=1, padding=1, oct_groups=2*nf, alpha_in=alpha_in, alpha_out=alpha_out, type="normal") 102 | self.OctConv2_2 = OctConv(in_channels=2*nf, out_channels=nf, kernel_size=1, stride=1, alpha_in=alpha_in, alpha_out=alpha_out, type="normal") 103 | self.oct_conv_aftup_2 = Oct_Conv_aftup(nf, nf, 3, 1, 1, pad_type, alpha_in, alpha_out) 104 | 105 | self.AdaOctConv3_1 = AdaOctConv(in_channels=nf, out_channels=nf, group_div=group_div[2], style_channels=style_channel, kernel_size=style_kernel, stride=1, padding=1, oct_groups=nf, alpha_in=alpha_in, alpha_out=alpha_out, type="normal") 106 | self.OctConv3_2 = OctConv(in_channels=nf, out_channels=nf//2, kernel_size=1, stride=1, alpha_in=alpha_in, alpha_out=alpha_out, type="last", freq_ratio=freq_ratio) 107 | 108 | self.conv4 = nn.Conv2d(in_channels=nf//2, out_channels=out_dim, kernel_size=1) 109 | 110 | def forward(self, content, style): 111 | out = self.AdaOctConv1_1(content, style) 112 | out = self.OctConv1_2(out) 113 | out = self.up_oct(out) 114 | out = self.oct_conv_aftup_1(out) 115 | 116 | out = self.AdaOctConv2_1(out, style) 117 | out = self.OctConv2_2(out) 118 | out = self.up_oct(out) 119 | out = self.oct_conv_aftup_2(out) 120 | 121 | out = self.AdaOctConv3_1(out, style) 122 | out = self.OctConv3_2(out) 123 | out, out_high, out_low = out 124 | 125 | out = self.conv4(out) 126 | out_high = self.conv4(out_high) 127 | out_low = self.conv4(out_low) 128 | 129 | return out, out_high, out_low 130 | 131 | def forward_test(self, content, style): 132 | out = self.AdaOctConv1_1(content, style, 'test') 133 | out = self.OctConv1_2(out) 134 | out = self.up_oct(out) 135 | out = self.oct_conv_aftup_1(out) 136 | 137 | out = self.AdaOctConv2_1(out, style, 'test') 138 | out = self.OctConv2_2(out) 139 | out = self.up_oct(out) 140 | out = self.oct_conv_aftup_2(out) 141 | 142 | out = self.AdaOctConv3_1(out, style, 'test') 143 | out = self.OctConv3_2(out) 144 | 145 | out = self.conv4(out[0]) 146 | return out 147 | 148 | 149 | ############## Contrastive Loss function ############## 150 | def calc_mean_std(feat, eps=1e-5): 151 | # eps is a small value added to the variance to avoid divide-by-zero. 
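# feat is expected to have shape (N, C, H, W); the mean and std are computed per sample
# and per channel and returned as (N, C, 1, 1) tensors so they broadcast over the spatial dims.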
152 | size = feat.size() 153 | assert (len(size) == 4) 154 | N, C = size[:2] 155 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 156 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 157 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 158 | return feat_mean, feat_std 159 | 160 | def calc_content_loss(input, target): 161 | assert (input.size() == target.size()) 162 | mse_loss = nn.MSELoss() 163 | return mse_loss(input, target) 164 | 165 | def calc_style_loss(input, target): 166 | assert (input.size() == target.size()) 167 | mse_loss = nn.MSELoss() 168 | input_mean, input_std = calc_mean_std(input) 169 | target_mean, target_std = calc_mean_std(target) 170 | 171 | loss = mse_loss(input_mean, target_mean) + \ 172 | mse_loss(input_std, target_std) 173 | return loss 174 | 175 | class EFDM_loss(nn.Module): 176 | def __init__(self): 177 | super(EFDM_loss, self).__init__() 178 | self.mse_loss = nn.MSELoss() 179 | 180 | def efdm_single(self, style, trans): 181 | B, C, W, H = style.size(0), style.size(1), style.size(2), style.size(3) 182 | 183 | value_style, index_style = torch.sort(style.view(B, C, -1)) 184 | value_trans, index_trans = torch.sort(trans.view(B, C, -1)) 185 | inverse_index = index_trans.argsort(-1) 186 | 187 | return self.mse_loss(trans.view(B, C,-1), value_style.gather(-1, inverse_index)) 188 | 189 | def forward(self, style_E, style_S, translate_E, translate_S, neg_idx): 190 | loss = 0. 191 | batch = style_E[0][0].shape[0] 192 | for b in range(batch): 193 | poss_loss = 0. 194 | neg_loss = 0. 195 | 196 | # Positive loss 197 | for i in range(len(style_E)): 198 | poss_loss += self.efdm_single(style_E[i][0][b].unsqueeze(0), translate_E[i][0][b].unsqueeze(0)) + \ 199 | self.efdm_single(style_E[i][1][b].unsqueeze(0), translate_E[i][1][b].unsqueeze(0)) 200 | for i in range(len(style_S)): 201 | poss_loss += self.efdm_single(style_S[i][0][b].unsqueeze(0), translate_S[i][0][b].unsqueeze(0)) + \ 202 | self.efdm_single(style_S[i][1][b].unsqueeze(0), translate_S[i][1][b].unsqueeze(0)) 203 | 204 | # Negative loss 205 | for nb in neg_idx[b]: 206 | for i in range(len(style_E)): 207 | neg_loss += self.efdm_single(style_E[i][0][nb].unsqueeze(0), translate_E[i][0][b].unsqueeze(0)) + \ 208 | self.efdm_single(style_E[i][1][nb].unsqueeze(0), translate_E[i][1][b].unsqueeze(0)) 209 | for i in range(len(style_S)): 210 | neg_loss += self.efdm_single(style_S[i][0][nb].unsqueeze(0), translate_S[i][0][b].unsqueeze(0)) + \ 211 | self.efdm_single(style_S[i][1][nb].unsqueeze(0), translate_S[i][1][b].unsqueeze(0)) 212 | 213 | loss += poss_loss / neg_loss 214 | 215 | return loss -------------------------------------------------------------------------------- /blocks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from path import Path 4 | import math 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch.optim import lr_scheduler 9 | from blocks import * 10 | 11 | def model_save(ckpt_dir, model, optim_E, optim_S, optim_G, epoch, itr=None): 12 | if not os.path.exists(ckpt_dir): 13 | os.makedirs(ckpt_dir) 14 | 15 | torch.save({'netE': model.netE.state_dict(), 16 | 'netS': model.netS.state_dict(), 17 | 'netG': model.netG.state_dict(), 18 | 'optim_E': optim_E.state_dict(), 19 | 'optim_S': optim_S.state_dict(), 20 | 'optim_G': optim_G.state_dict()}, 21 | '%s/model_iter_%d_epoch_%d.pth' % (ckpt_dir, itr+1, epoch+1)) 22 | 23 | def model_load(checkpoint, ckpt_dir, model, optim_E, optim_S, optim_G): 24 | 
if not os.path.exists(ckpt_dir): 25 | epoch = -1 26 | return model, optim_E, optim_S, optim_G, epoch 27 | 28 | ckpt_path = Path(ckpt_dir) 29 | if checkpoint: 30 | model_ckpt = ckpt_path + '/' + checkpoint 31 | else: 32 | ckpt_lst = ckpt_path.glob('model_iter_*') 33 | ckpt_lst.sort(key=lambda x: int(x.split('iter_')[1].split('_epoch')[0])) 34 | model_ckpt = ckpt_lst[-1] 35 | itr = int(model_ckpt.split('iter_')[1].split('_epoch_')[0]) 36 | epoch = int(model_ckpt.split('iter_')[1].split('_epoch_')[1].split('.')[0]) 37 | print(model_ckpt) 38 | 39 | dict_model = torch.load(model_ckpt) 40 | 41 | model.netE.load_state_dict(dict_model['netE']) 42 | model.netS.load_state_dict(dict_model['netS']) 43 | model.netG.load_state_dict(dict_model['netG']) 44 | optim_E.load_state_dict(dict_model['optim_E']) 45 | optim_S.load_state_dict(dict_model['optim_S']) 46 | optim_G.load_state_dict(dict_model['optim_G']) 47 | 48 | return model, optim_E, optim_S, optim_G, epoch, itr 49 | 50 | def test_model_load(checkpoint, model): 51 | dict_model = torch.load(checkpoint) 52 | model.netE.load_state_dict(dict_model['netE']) 53 | model.netS.load_state_dict(dict_model['netS']) 54 | model.netG.load_state_dict(dict_model['netG']) 55 | return model 56 | 57 | def get_scheduler(optimizer, config): 58 | if config.lr_policy == 'lambda': 59 | def lambda_rule(epoch): 60 | lr_l = 1.0 - max(0, epoch + config.n_epoch - config.n_iter) / float(config.n_iter_decay + 1) 61 | return lr_l 62 | 63 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 64 | elif config.lr_policy == 'step': 65 | scheduler = lr_scheduler.StepLR(optimizer, step_size=config.lr_decay_iters, gamma=0.1) 66 | elif config.lr_policy == 'plateau': 67 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 68 | elif config.lr_policy == 'cosine': 69 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.n_iter, eta_min=0) 70 | else: 71 | return NotImplementedError('learning rate policy [%s] is not implemented', config.lr_policy) 72 | return scheduler 73 | 74 | def update_learning_rate(scheduler, optimizer): 75 | scheduler.step() 76 | lr = optimizer.param_groups[0]['lr'] 77 | print('learning rate = %.7f' % lr) 78 | 79 | class Oct_Conv_aftup(nn.Module): 80 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, pad_type, alpha_in, alpha_out): 81 | super(Oct_Conv_aftup, self).__init__() 82 | lf_in = int(in_channels*alpha_in) 83 | lf_out = int(out_channels*alpha_out) 84 | hf_in = in_channels - lf_in 85 | hf_out = out_channels - lf_out 86 | 87 | self.conv_h = nn.Conv2d(in_channels=hf_in, out_channels=hf_out, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type) 88 | self.conv_l = nn.Conv2d(in_channels=lf_in, out_channels=lf_out, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type) 89 | 90 | def forward(self, x): 91 | hf, lf = x 92 | hf = self.conv_h(hf) 93 | lf = self.conv_l(lf) 94 | return hf, lf 95 | 96 | class Oct_conv_reLU(nn.ReLU): 97 | def forward(self, x): 98 | hf, lf = x 99 | hf = super(Oct_conv_reLU, self).forward(hf) 100 | lf = super(Oct_conv_reLU, self).forward(lf) 101 | return hf, lf 102 | 103 | class Oct_conv_lreLU(nn.LeakyReLU): 104 | def forward(self, x): 105 | hf, lf = x 106 | hf = super(Oct_conv_lreLU, self).forward(hf) 107 | lf = super(Oct_conv_lreLU, self).forward(lf) 108 | return hf, lf 109 | 110 | class Oct_conv_up(nn.Upsample): 111 | def forward(self, x): 112 | hf, lf = x 113 | hf = super(Oct_conv_up, 
self).forward(hf) 114 | lf = super(Oct_conv_up, self).forward(lf) 115 | return hf, lf 116 | 117 | 118 | ############## Encoder ############## 119 | class OctConv(nn.Module): 120 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 121 | padding=0, groups=1, pad_type='reflect', alpha_in=0.5, alpha_out=0.5, type='normal', freq_ratio = [1, 1]): 122 | super(OctConv, self).__init__() 123 | self.kernel_size = kernel_size 124 | self.stride = stride 125 | self.type = type 126 | self.alpha_in = alpha_in 127 | self.alpha_out = alpha_out 128 | self.freq_ratio = freq_ratio 129 | 130 | hf_ch_in = int(in_channels * (1 - self.alpha_in)) 131 | hf_ch_out = int(out_channels * (1 -self. alpha_out)) 132 | lf_ch_in = in_channels - hf_ch_in 133 | lf_ch_out = out_channels - hf_ch_out 134 | 135 | self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2) 136 | self.upsample = nn.Upsample(scale_factor=2) 137 | 138 | self.is_dw = groups == in_channels 139 | 140 | if type == 'first': 141 | self.convh = nn.Conv2d(in_channels, hf_ch_out, kernel_size=kernel_size, 142 | stride=stride, padding=padding, padding_mode=pad_type, bias = False) 143 | self.convl = nn.Conv2d(in_channels, lf_ch_out, 144 | kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type, bias=False) 145 | elif type == 'last': 146 | self.convh = nn.Conv2d(hf_ch_in, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type, bias=False) 147 | self.convl = nn.Conv2d(lf_ch_in, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type, bias=False) 148 | else: 149 | self.L2L = nn.Conv2d( 150 | lf_ch_in, lf_ch_out, 151 | kernel_size=kernel_size, stride=stride, padding=padding, groups=math.ceil(alpha_in * groups), padding_mode=pad_type, bias=False 152 | ) 153 | if self.is_dw: 154 | self.L2H = None 155 | self.H2L = None 156 | else: 157 | self.L2H = nn.Conv2d( 158 | lf_ch_in, hf_ch_out, 159 | kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, padding_mode=pad_type, bias=False 160 | ) 161 | self.H2L = nn.Conv2d( 162 | hf_ch_in, lf_ch_out, 163 | kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, padding_mode=pad_type, bias=False 164 | ) 165 | self.H2H = nn.Conv2d( 166 | hf_ch_in, hf_ch_out, 167 | kernel_size=kernel_size, stride=stride, padding=padding, groups=math.ceil(groups - alpha_in * groups), padding_mode=pad_type, bias=False 168 | ) 169 | 170 | def forward(self, x): 171 | if self.type == 'first': 172 | hf = self.convh(x) 173 | lf = self.avg_pool(x) 174 | lf = self.convl(lf) 175 | return hf, lf 176 | elif self.type == 'last': 177 | hf, lf = x 178 | out_h = self.convh(hf) 179 | out_l = self.convl(self.upsample(lf)) 180 | output = out_h * self.freq_ratio[0] + out_l * self.freq_ratio[1] 181 | return output, out_h, out_l 182 | else: 183 | hf, lf = x 184 | if self.is_dw: 185 | hf, lf = self.H2H(hf), self.L2L(lf) 186 | else: 187 | hf, lf = self.H2H(hf) + self.L2H(self.upsample(lf)), self.L2L(lf) + self.H2L(self.avg_pool(hf)) 188 | return hf, lf 189 | 190 | 191 | ############## Decoder ############## 192 | class AdaOctConv(nn.Module): 193 | def __init__(self, in_channels, out_channels, group_div, style_channels, kernel_size, 194 | stride, padding, oct_groups, alpha_in, alpha_out, type='normal'): 195 | super(AdaOctConv, self).__init__() 196 | self.in_channels = in_channels 197 | self.alpha_in = alpha_in 198 | self.alpha_out = alpha_out 199 | self.type = type 200 | 201 | h_in = int(in_channels * (1 - self.alpha_in)) 202 | l_in 
= in_channels - h_in 203 | 204 | n_groups_h = h_in // group_div 205 | n_groups_l = l_in // group_div 206 | 207 | style_channels_h = int(style_channels * (1 - self.alpha_in)) 208 | style_channels_l = int(style_channels - style_channels_h) 209 | 210 | kernel_size_h = kernel_size[0] 211 | kernel_size_l = kernel_size[1] 212 | kernel_size_A = kernel_size[2] 213 | 214 | self.kernelPredictor_h = KernelPredictor(in_channels=h_in, 215 | out_channels=h_in, 216 | n_groups=n_groups_h, 217 | style_channels=style_channels_h, 218 | kernel_size=kernel_size_h) 219 | self.kernelPredictor_l = KernelPredictor(in_channels=l_in, 220 | out_channels=l_in, 221 | n_groups=n_groups_l, 222 | style_channels=style_channels_l, 223 | kernel_size=kernel_size_l) 224 | 225 | self.AdaConv_h = AdaConv2d(in_channels=h_in, out_channels=h_in, n_groups=n_groups_h) 226 | self.AdaConv_l = AdaConv2d(in_channels=l_in, out_channels=l_in, n_groups=n_groups_l) 227 | 228 | self.OctConv = OctConv(in_channels=in_channels, 229 | out_channels=out_channels, 230 | kernel_size=kernel_size_A, stride=stride, padding=padding, groups=oct_groups, 231 | alpha_in=alpha_in, alpha_out=alpha_out, type=type) 232 | 233 | self.relu = Oct_conv_lreLU() 234 | 235 | def forward(self, content, style, cond='train'): 236 | c_hf, c_lf = content 237 | s_hf, s_lf = style 238 | h_w_spatial, h_w_pointwise, h_bias = self.kernelPredictor_h(s_hf) 239 | l_w_spatial, l_w_pointwise, l_bias = self.kernelPredictor_l(s_lf) 240 | 241 | if cond == 'train': 242 | output_h = self.AdaConv_h(c_hf, h_w_spatial, h_w_pointwise, h_bias) 243 | output_l = self.AdaConv_l(c_lf, l_w_spatial, l_w_pointwise, l_bias) 244 | output = output_h, output_l 245 | 246 | output = self.relu(output) 247 | 248 | output = self.OctConv(output) 249 | if self.type != 'last': 250 | output = self.relu(output) 251 | return output 252 | 253 | if cond == 'test': 254 | output_h = self.AdaConv_h(c_hf, h_w_spatial, h_w_pointwise, h_bias) 255 | output_l = self.AdaConv_l(c_lf, l_w_spatial, l_w_pointwise, l_bias) 256 | output = output_h, output_l 257 | output = self.relu(output) 258 | output = self.OctConv(output) 259 | if self.type != 'last': 260 | output = self.relu(output) 261 | return output 262 | 263 | class KernelPredictor(nn.Module): 264 | def __init__(self, in_channels, out_channels, n_groups, style_channels, kernel_size): 265 | super(KernelPredictor, self).__init__() 266 | 267 | self.in_channels = in_channels 268 | self.out_channels = out_channels 269 | self.n_groups = n_groups 270 | self.w_channels = style_channels 271 | self.kernel_size = kernel_size 272 | 273 | padding = (kernel_size - 1) / 2 274 | self.spatial = nn.Conv2d(style_channels, 275 | in_channels * out_channels // n_groups, 276 | kernel_size=kernel_size, 277 | padding=(math.ceil(padding), math.ceil(padding)), 278 | padding_mode='reflect') 279 | self.pointwise = nn.Sequential( 280 | nn.AdaptiveAvgPool2d((1, 1)), 281 | nn.Conv2d(style_channels, 282 | out_channels * out_channels // n_groups, 283 | kernel_size=1) 284 | ) 285 | self.bias = nn.Sequential( 286 | nn.AdaptiveAvgPool2d((1, 1)), 287 | nn.Conv2d(style_channels, 288 | out_channels, 289 | kernel_size=1) 290 | ) 291 | 292 | def forward(self, w): 293 | w_spatial = self.spatial(w) 294 | w_spatial = w_spatial.reshape(len(w), 295 | self.out_channels, 296 | self.in_channels // self.n_groups, 297 | self.kernel_size, self.kernel_size) 298 | 299 | w_pointwise = self.pointwise(w) 300 | w_pointwise = w_pointwise.reshape(len(w), 301 | self.out_channels, 302 | self.out_channels // self.n_groups, 303 | 1, 1) 
304 | bias = self.bias(w) 305 | bias = bias.reshape(len(w), self.out_channels) 306 | return w_spatial, w_pointwise, bias 307 | 308 | class AdaConv2d(nn.Module): 309 | def __init__(self, in_channels, out_channels, kernel_size=3, n_groups=None): 310 | super(AdaConv2d, self).__init__() 311 | self.n_groups = in_channels if n_groups is None else n_groups 312 | self.in_channels = in_channels 313 | self.out_channels = out_channels 314 | 315 | padding = (kernel_size - 1) / 2 316 | self.conv = nn.Conv2d(in_channels=in_channels, 317 | out_channels=out_channels, 318 | kernel_size=(kernel_size, kernel_size), 319 | padding=(math.ceil(padding), math.floor(padding)), 320 | padding_mode='reflect') 321 | 322 | def forward(self, x, w_spatial, w_pointwise, bias): 323 | assert len(x) == len(w_spatial) == len(w_pointwise) == len(bias) 324 | x = F.instance_norm(x) 325 | 326 | ys = [] 327 | for i in range(len(x)): 328 | y = self.forward_single(x[i:i+1], w_spatial[i], w_pointwise[i], bias[i]) 329 | ys.append(y) 330 | ys = torch.cat(ys, dim=0) 331 | 332 | ys = self.conv(ys) 333 | return ys 334 | 335 | def forward_single(self, x, w_spatial, w_pointwise, bias): 336 | assert w_spatial.size(-1) == w_spatial.size(-2) 337 | padding = (w_spatial.size(-1) - 1) / 2 338 | pad = (math.ceil(padding), math.floor(padding), math.ceil(padding), math.floor(padding)) 339 | 340 | x = F.pad(x, pad=pad, mode='reflect') 341 | x = F.conv2d(x, w_spatial, groups=self.n_groups) 342 | x = F.conv2d(x, w_pointwise, groups=self.n_groups, bias=bias) 343 | return x -------------------------------------------------------------------------------- /Video_NST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "9a6eab3c-20b0-414e-91b5-4bab2b625f02", 6 | "metadata": {}, 7 | "source": [ 8 | "# Video to Image\n", 9 | "before style transfer using AesFA" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "id": "ad2a36df-e065-4739-abb8-3f8bc9755ddb", 16 | "metadata": {}, 17 | "outputs": [ 18 | { 19 | "name": "stdout", 20 | "output_type": "stream", 21 | "text": [ 22 | "pexels-kelly-7165765-3840x2160-24fps\n", 23 | "319 24.0\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "import os\n", 29 | "import cv2\n", 30 | "import numpy as np\n", 31 | "import matplotlib.pyplot as plt\n", 32 | "from PIL import Image\n", 33 | "import glob\n", 34 | "\n", 35 | "resolution = '2048'\n", 36 | "content_dir = './video/{}/'.format(resolution)\n", 37 | "contents_in = sorted(glob.glob(content_dir+'*.mp4'))\n", 38 | "content_img = './video_img/{}/content/'.format(resolution)\n", 39 | "if not os.path.isdir(content_img):\n", 40 | " os.makedirs(content_img)\n", 41 | "\n", 42 | "contents = []\n", 43 | "for content in contents_in:\n", 44 | " content_name = content.split('/')[-1].split('.')[0]\n", 45 | " contents.append(content_name)\n", 46 | " print(content_name)\n", 47 | " \n", 48 | " # Video Capture Sanity Check\n", 49 | " cap = cv2.VideoCapture(content)\n", 50 | " length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n", 51 | " fps = cap.get(cv2.CAP_PROP_FPS)\n", 52 | " print(length, fps)\n", 53 | "\n", 54 | " i = 0\n", 55 | " while(cap.isOpened()):\n", 56 | " ret, frame = cap.read()\n", 57 | " try:\n", 58 | " frame = frame[:,:,::-1]\n", 59 | " iframe = Image.fromarray(frame)\n", 60 | "\n", 61 | " iframe.save(content_img+'{:s}_{:d}.jpg'.format(content_name, i+1))\n", 62 | " i += 1\n", 63 | " except:\n", 64 | " break\n", 65 | "\n", 66 | " cap.release()" 67 | ] 
68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "53bb10ab-08f2-4a2e-b8e6-4caa5ab1e065", 72 | "metadata": {}, 73 | "source": [ 74 | "# Transfer the style of each frame using AesFA model" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 2, 80 | "id": "ab7255c1-1256-4c2b-a29d-a9438ebd4c74", 81 | "metadata": { 82 | "scrolled": true, 83 | "tags": [] 84 | }, 85 | "outputs": [ 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "Version: main\n", 91 | "cuda:0\n", 92 | "checkpoint: ./ckpt/main/main.pth\n", 93 | "content dir: ./video_img/2048/content/\n", 94 | "style dir: /global/cfs/cdirs/m3898/GANBERT/st_to_diff/Code/LabServer/ResolutionDataset/2048/style\n", 95 | "# of contents: 319\n", 96 | "# of styles: 2\n", 97 | "1 1 1 1.4959888458251953\n", 98 | "2 1 2 0.013611078262329102\n", 99 | "3 2 1 0.013701677322387695\n", 100 | "4 2 2 0.013676643371582031\n", 101 | "5 3 1 0.013656854629516602\n", 102 | "6 3 2 0.013533353805541992\n", 103 | "7 4 1 0.01367640495300293\n", 104 | "8 4 2 0.013755083084106445\n", 105 | "9 5 1 0.013662338256835938\n", 106 | "10 5 2 0.01371622085571289\n", 107 | "11 6 1 0.01367497444152832\n", 108 | "12 6 2 0.013765335083007812\n", 109 | "13 7 1 0.011701345443725586\n", 110 | "14 7 2 0.011647224426269531\n", 111 | "15 8 1 0.01172637939453125\n", 112 | "16 8 2 0.01176309585571289\n", 113 | "17 9 1 0.011693477630615234\n", 114 | "18 9 2 0.011683225631713867\n", 115 | "19 10 1 0.012634038925170898\n", 116 | "20 10 2 0.011624336242675781\n", 117 | "21 11 1 0.011622428894042969\n", 118 | "22 11 2 0.01162266731262207\n", 119 | "23 12 1 0.01264333724975586\n", 120 | "24 12 2 0.012636184692382812\n", 121 | "25 13 1 0.011599302291870117\n", 122 | "26 13 2 0.011915445327758789\n", 123 | "27 14 1 0.01216435432434082\n", 124 | "28 14 2 0.0116424560546875\n", 125 | "29 15 1 0.011826276779174805\n", 126 | "30 15 2 0.011800765991210938\n", 127 | "31 16 1 0.011651277542114258\n", 128 | "32 16 2 0.011873722076416016\n", 129 | "33 17 1 0.013638734817504883\n", 130 | "34 17 2 0.011655330657958984\n", 131 | "35 18 1 0.011660337448120117\n", 132 | "36 18 2 0.014058828353881836\n", 133 | "37 19 1 0.012962102890014648\n", 134 | "38 19 2 0.01383519172668457\n", 135 | "39 20 1 0.012880802154541016\n", 136 | "40 20 2 0.013935565948486328\n", 137 | "41 21 1 0.013086557388305664\n", 138 | "42 21 2 0.013714790344238281\n", 139 | "43 22 1 0.012937784194946289\n", 140 | "44 22 2 0.013326644897460938\n", 141 | "45 23 1 0.01256418228149414\n", 142 | "46 23 2 0.012811660766601562\n", 143 | "47 24 1 0.012848854064941406\n", 144 | "48 24 2 0.013879776000976562\n", 145 | "49 25 1 0.012457847595214844\n", 146 | "50 25 2 0.012771844863891602\n", 147 | "51 26 1 0.01279306411743164\n", 148 | "52 26 2 0.01271200180053711\n", 149 | "53 27 1 0.012722969055175781\n", 150 | "54 27 2 0.012816190719604492\n", 151 | "55 28 1 0.013971090316772461\n", 152 | "56 28 2 0.013326644897460938\n", 153 | "57 29 1 0.014523029327392578\n", 154 | "58 29 2 0.012859106063842773\n", 155 | "59 30 1 0.0128631591796875\n", 156 | "60 30 2 0.012863874435424805\n", 157 | "61 31 1 0.012889623641967773\n", 158 | "62 31 2 0.01306605339050293\n", 159 | "63 32 1 0.013204574584960938\n", 160 | "64 32 2 0.012650012969970703\n", 161 | "65 33 1 0.014211177825927734\n", 162 | "66 33 2 0.012797355651855469\n", 163 | "67 34 1 0.012365102767944336\n", 164 | "68 34 2 0.012729883193969727\n", 165 | "69 35 1 0.014531135559082031\n", 166 | "70 35 2 0.012779951095581055\n", 167 | "71 36 1 
0.012744903564453125\n", 168 | "72 36 2 0.013289928436279297\n", 169 | "73 37 1 0.013209104537963867\n", 170 | "74 37 2 0.013160467147827148\n", 171 | "75 38 1 0.012948989868164062\n", 172 | "76 38 2 0.013440370559692383\n", 173 | "77 39 1 0.013657569885253906\n", 174 | "78 39 2 0.012427091598510742\n", 175 | "79 40 1 0.012775897979736328\n", 176 | "80 40 2 0.012731313705444336\n", 177 | "81 41 1 0.01245737075805664\n", 178 | "82 41 2 0.01445913314819336\n", 179 | "83 42 1 0.013427972793579102\n", 180 | "84 42 2 0.014007568359375\n", 181 | "85 43 1 0.013984203338623047\n", 182 | "86 43 2 0.012977361679077148\n", 183 | "87 44 1 0.0128021240234375\n", 184 | "88 44 2 0.01381373405456543\n", 185 | "89 45 1 0.012957572937011719\n", 186 | "90 45 2 0.013097286224365234\n", 187 | "91 46 1 0.014750003814697266\n", 188 | "92 46 2 0.014478206634521484\n", 189 | "93 47 1 0.014056205749511719\n", 190 | "94 47 2 0.012886524200439453\n", 191 | "95 48 1 0.013549327850341797\n", 192 | "96 48 2 0.012481451034545898\n", 193 | "97 49 1 0.014037847518920898\n", 194 | "98 49 2 0.013414621353149414\n", 195 | "99 50 1 0.012603282928466797\n", 196 | "100 50 2 0.012481689453125\n", 197 | "101 51 1 0.013272285461425781\n", 198 | "102 51 2 0.013521909713745117\n", 199 | "103 52 1 0.0137481689453125\n", 200 | "104 52 2 0.013808965682983398\n", 201 | "105 53 1 0.013106107711791992\n", 202 | "106 53 2 0.013852834701538086\n", 203 | "107 54 1 0.01243138313293457\n", 204 | "108 54 2 0.01333761215209961\n", 205 | "109 55 1 0.013401269912719727\n", 206 | "110 55 2 0.012766599655151367\n", 207 | "111 56 1 0.01387476921081543\n", 208 | "112 56 2 0.013069391250610352\n", 209 | "113 57 1 0.014247655868530273\n", 210 | "114 57 2 0.012739419937133789\n", 211 | "115 58 1 0.013568401336669922\n", 212 | "116 58 2 0.01391911506652832\n", 213 | "117 59 1 0.012873172760009766\n", 214 | "118 59 2 0.013228178024291992\n", 215 | "119 60 1 0.012822389602661133\n", 216 | "120 60 2 0.01289820671081543\n", 217 | "121 61 1 0.012432336807250977\n", 218 | "122 61 2 0.012836456298828125\n", 219 | "123 62 1 0.014320850372314453\n", 220 | "124 62 2 0.014084339141845703\n", 221 | "125 63 1 0.01263284683227539\n", 222 | "126 63 2 0.012930154800415039\n", 223 | "127 64 1 0.013004064559936523\n", 224 | "128 64 2 0.012860536575317383\n", 225 | "129 65 1 0.01366734504699707\n", 226 | "130 65 2 0.01270437240600586\n", 227 | "131 66 1 0.012862682342529297\n", 228 | "132 66 2 0.013658761978149414\n", 229 | "133 67 1 0.012963533401489258\n", 230 | "134 67 2 0.013154983520507812\n", 231 | "135 68 1 0.013437271118164062\n", 232 | "136 68 2 0.01375889778137207\n", 233 | "137 69 1 0.012767553329467773\n", 234 | "138 69 2 0.01370859146118164\n", 235 | "139 70 1 0.01272273063659668\n", 236 | "140 70 2 0.012549161911010742\n", 237 | "141 71 1 0.013580799102783203\n", 238 | "142 71 2 0.013698577880859375\n", 239 | "143 72 1 0.013996362686157227\n", 240 | "144 72 2 0.012699127197265625\n", 241 | "145 73 1 0.01391291618347168\n", 242 | "146 73 2 0.012805461883544922\n", 243 | "147 74 1 0.01383829116821289\n", 244 | "148 74 2 0.013287544250488281\n", 245 | "149 75 1 0.013069868087768555\n", 246 | "150 75 2 0.013500213623046875\n", 247 | "151 76 1 0.014073848724365234\n", 248 | "152 76 2 0.01279592514038086\n", 249 | "153 77 1 0.013386726379394531\n", 250 | "154 77 2 0.012765884399414062\n", 251 | "155 78 1 0.013403177261352539\n", 252 | "156 78 2 0.01393437385559082\n", 253 | "157 79 1 0.013809919357299805\n", 254 | "158 79 2 0.013591766357421875\n", 255 | "159 80 1 
0.013782501220703125\n", 256 | "160 80 2 0.012714147567749023\n", 257 | "161 81 1 0.01299595832824707\n", 258 | "162 81 2 0.01310110092163086\n", 259 | "163 82 1 0.013424873352050781\n", 260 | "164 82 2 0.012754678726196289\n", 261 | "165 83 1 0.012630701065063477\n", 262 | "166 83 2 0.012861967086791992\n", 263 | "167 84 1 0.01343226432800293\n", 264 | "168 84 2 0.01407480239868164\n", 265 | "169 85 1 0.013016462326049805\n", 266 | "170 85 2 0.012706518173217773\n", 267 | "171 86 1 0.013083696365356445\n", 268 | "172 86 2 0.013735294342041016\n", 269 | "173 87 1 0.012756824493408203\n", 270 | "174 87 2 0.012722969055175781\n", 271 | "175 88 1 0.01279902458190918\n", 272 | "176 88 2 0.013587474822998047\n", 273 | "177 89 1 0.012699604034423828\n", 274 | "178 89 2 0.01271200180053711\n", 275 | "179 90 1 0.012832880020141602\n", 276 | "180 90 2 0.014787435531616211\n", 277 | "181 91 1 0.012573957443237305\n", 278 | "182 91 2 0.013875246047973633\n", 279 | "183 92 1 0.012664318084716797\n", 280 | "184 92 2 0.012704133987426758\n", 281 | "185 93 1 0.013759374618530273\n", 282 | "186 93 2 0.01273798942565918\n", 283 | "187 94 1 0.012641429901123047\n", 284 | "188 94 2 0.0140838623046875\n", 285 | "189 95 1 0.012718439102172852\n", 286 | "190 95 2 0.012366533279418945\n", 287 | "191 96 1 0.013098716735839844\n", 288 | "192 96 2 0.013202428817749023\n", 289 | "193 97 1 0.012679576873779297\n", 290 | "194 97 2 0.012736082077026367\n", 291 | "195 98 1 0.013821840286254883\n", 292 | "196 98 2 0.012673616409301758\n", 293 | "197 99 1 0.012502670288085938\n", 294 | "198 99 2 0.012480020523071289\n", 295 | "199 100 1 0.012747049331665039\n", 296 | "200 100 2 0.012721538543701172\n", 297 | "201 101 1 0.013068914413452148\n", 298 | "202 101 2 0.013926029205322266\n", 299 | "203 102 1 0.014308691024780273\n", 300 | "204 102 2 0.013094663619995117\n", 301 | "205 103 1 0.012408018112182617\n", 302 | "206 103 2 0.012737035751342773\n", 303 | "207 104 1 0.014155387878417969\n", 304 | "208 104 2 0.01417684555053711\n", 305 | "209 105 1 0.012717962265014648\n", 306 | "210 105 2 0.013577461242675781\n", 307 | "211 106 1 0.01384592056274414\n", 308 | "212 106 2 0.012499094009399414\n", 309 | "213 107 1 0.013324737548828125\n", 310 | "214 107 2 0.01272726058959961\n", 311 | "215 108 1 0.012765169143676758\n", 312 | "216 108 2 0.012412786483764648\n", 313 | "217 109 1 0.013402938842773438\n", 314 | "218 109 2 0.012748479843139648\n", 315 | "219 110 1 0.012696266174316406\n", 316 | "220 110 2 0.014709711074829102\n", 317 | "221 111 1 0.012678861618041992\n", 318 | "222 111 2 0.012707948684692383\n", 319 | "223 112 1 0.012372016906738281\n", 320 | "224 112 2 0.01387166976928711\n", 321 | "225 113 1 0.014742136001586914\n", 322 | "226 113 2 0.013881921768188477\n", 323 | "227 114 1 0.013998746871948242\n", 324 | "228 114 2 0.012818336486816406\n", 325 | "229 115 1 0.01278829574584961\n", 326 | "230 115 2 0.01302647590637207\n", 327 | "231 116 1 0.013860464096069336\n", 328 | "232 116 2 0.01271963119506836\n", 329 | "233 117 1 0.012736797332763672\n", 330 | "234 117 2 0.01385498046875\n", 331 | "235 118 1 0.013054370880126953\n", 332 | "236 118 2 0.01247715950012207\n", 333 | "237 119 1 0.012817621231079102\n", 334 | "238 119 2 0.012902498245239258\n", 335 | "239 120 1 0.01287531852722168\n", 336 | "240 120 2 0.012738943099975586\n", 337 | "241 121 1 0.013455867767333984\n", 338 | "242 121 2 0.012953519821166992\n", 339 | "243 122 1 0.013499021530151367\n", 340 | "244 122 2 0.014836549758911133\n", 341 | "245 123 1 
0.013926029205322266\n", 342 | "246 123 2 0.014399528503417969\n", 343 | "247 124 1 0.014725923538208008\n", 344 | "248 124 2 0.013016462326049805\n", 345 | "249 125 1 0.014930009841918945\n", 346 | "250 125 2 0.012398481369018555\n", 347 | "251 126 1 0.013311147689819336\n", 348 | "252 126 2 0.01300668716430664\n", 349 | "253 127 1 0.014466047286987305\n", 350 | "254 127 2 0.013182640075683594\n", 351 | "255 128 1 0.01296687126159668\n", 352 | "256 128 2 0.012848615646362305\n", 353 | "257 129 1 0.012396097183227539\n", 354 | "258 129 2 0.0128021240234375\n", 355 | "259 130 1 0.012939929962158203\n", 356 | "260 130 2 0.013190269470214844\n", 357 | "261 131 1 0.013485431671142578\n", 358 | "262 131 2 0.01271677017211914\n", 359 | "263 132 1 0.013487100601196289\n", 360 | "264 132 2 0.012562990188598633\n", 361 | "265 133 1 0.013659238815307617\n", 362 | "266 133 2 0.014175176620483398\n", 363 | "267 134 1 0.013893365859985352\n", 364 | "268 134 2 0.013002395629882812\n", 365 | "269 135 1 0.014646768569946289\n", 366 | "270 135 2 0.013591289520263672\n", 367 | "271 136 1 0.014624595642089844\n", 368 | "272 136 2 0.014317989349365234\n", 369 | "273 137 1 0.013198375701904297\n", 370 | "274 137 2 0.01444697380065918\n", 371 | "275 138 1 0.013878345489501953\n", 372 | "276 138 2 0.013666152954101562\n", 373 | "277 139 1 0.01478433609008789\n", 374 | "278 139 2 0.012974262237548828\n", 375 | "279 140 1 0.013012170791625977\n", 376 | "280 140 2 0.01274728775024414\n", 377 | "281 141 1 0.013678789138793945\n", 378 | "282 141 2 0.012810468673706055\n", 379 | "283 142 1 0.013793706893920898\n", 380 | "284 142 2 0.013660907745361328\n", 381 | "285 143 1 0.013870716094970703\n", 382 | "286 143 2 0.01301121711730957\n", 383 | "287 144 1 0.012352943420410156\n", 384 | "288 144 2 0.012561798095703125\n", 385 | "289 145 1 0.01263737678527832\n", 386 | "290 145 2 0.013193130493164062\n", 387 | "291 146 1 0.011633634567260742\n", 388 | "292 146 2 0.011667728424072266\n", 389 | "293 147 1 0.011678695678710938\n", 390 | "294 147 2 0.012796878814697266\n", 391 | "295 148 1 0.011621952056884766\n", 392 | "296 148 2 0.011636018753051758\n", 393 | "297 149 1 0.011415243148803711\n", 394 | "298 149 2 0.011688232421875\n", 395 | "299 150 1 0.011734485626220703\n", 396 | "300 150 2 0.012699365615844727\n", 397 | "301 151 1 0.011561155319213867\n", 398 | "302 151 2 0.011710882186889648\n", 399 | "303 152 1 0.012788534164428711\n", 400 | "304 152 2 0.012428522109985352\n", 401 | "305 153 1 0.011623144149780273\n", 402 | "306 153 2 0.012603521347045898\n", 403 | "307 154 1 0.011607885360717773\n", 404 | "308 154 2 0.011417388916015625\n", 405 | "309 155 1 0.011678218841552734\n", 406 | "310 155 2 0.01161646842956543\n", 407 | "311 156 1 0.012554168701171875\n", 408 | "312 156 2 0.011397123336791992\n", 409 | "313 157 1 0.012323379516601562\n", 410 | "314 157 2 0.011717081069946289\n", 411 | "315 158 1 0.01142120361328125\n", 412 | "316 158 2 0.011664152145385742\n", 413 | "317 159 1 0.012273788452148438\n", 414 | "318 159 2 0.01161336898803711\n", 415 | "319 160 1 0.011413097381591797\n", 416 | "320 160 2 0.011538505554199219\n", 417 | "321 161 1 0.011590003967285156\n", 418 | "322 161 2 0.011747360229492188\n", 419 | "323 162 1 0.011584281921386719\n", 420 | "324 162 2 0.011697769165039062\n", 421 | "325 163 1 0.011633157730102539\n", 422 | "326 163 2 0.011493921279907227\n", 423 | "327 164 1 0.012678146362304688\n", 424 | "328 164 2 0.011621475219726562\n", 425 | "329 165 1 0.0115966796875\n", 426 | "330 165 2 
0.011447429656982422\n", 427 | "331 166 1 0.011586427688598633\n", 428 | "332 166 2 0.011658906936645508\n", 429 | "333 167 1 0.011457681655883789\n", 430 | "334 167 2 0.011510610580444336\n", 431 | "335 168 1 0.01169729232788086\n", 432 | "336 168 2 0.0126495361328125\n", 433 | "337 169 1 0.011425018310546875\n", 434 | "338 169 2 0.01159524917602539\n", 435 | "339 170 1 0.011730194091796875\n", 436 | "340 170 2 0.011640071868896484\n", 437 | "341 171 1 0.014062881469726562\n", 438 | "342 171 2 0.011739492416381836\n", 439 | "343 172 1 0.011718034744262695\n", 440 | "344 172 2 0.01143193244934082\n", 441 | "345 173 1 0.01160883903503418\n", 442 | "346 173 2 0.011601448059082031\n", 443 | "347 174 1 0.011661767959594727\n", 444 | "348 174 2 0.011547088623046875\n", 445 | "349 175 1 0.012003183364868164\n", 446 | "350 175 2 0.012264490127563477\n", 447 | "351 176 1 0.011504173278808594\n", 448 | "352 176 2 0.011589527130126953\n", 449 | "353 177 1 0.011622428894042969\n", 450 | "354 177 2 0.011635065078735352\n", 451 | "355 178 1 0.01145315170288086\n", 452 | "356 178 2 0.01169729232788086\n", 453 | "357 179 1 0.013379812240600586\n", 454 | "358 179 2 0.011698722839355469\n", 455 | "359 180 1 0.011440277099609375\n", 456 | "360 180 2 0.01191568374633789\n", 457 | "361 181 1 0.011662483215332031\n", 458 | "362 181 2 0.013228654861450195\n", 459 | "363 182 1 0.01308441162109375\n", 460 | "364 182 2 0.012153148651123047\n", 461 | "365 183 1 0.011551856994628906\n", 462 | "366 183 2 0.011779546737670898\n", 463 | "367 184 1 0.011803627014160156\n", 464 | "368 184 2 0.011717796325683594\n", 465 | "369 185 1 0.0117950439453125\n", 466 | "370 185 2 0.011774063110351562\n", 467 | "371 186 1 0.013192892074584961\n", 468 | "372 186 2 0.011783599853515625\n", 469 | "373 187 1 0.011571168899536133\n", 470 | "374 187 2 0.011755228042602539\n", 471 | "375 188 1 0.01375722885131836\n", 472 | "376 188 2 0.012132883071899414\n", 473 | "377 189 1 0.01233530044555664\n", 474 | "378 189 2 0.012766599655151367\n", 475 | "379 190 1 0.012897968292236328\n", 476 | "380 190 2 0.012633562088012695\n", 477 | "381 191 1 0.013985157012939453\n", 478 | "382 191 2 0.013800621032714844\n", 479 | "383 192 1 0.012768268585205078\n", 480 | "384 192 2 0.013072967529296875\n", 481 | "385 193 1 0.01387929916381836\n", 482 | "386 193 2 0.013388633728027344\n", 483 | "387 194 1 0.012714624404907227\n", 484 | "388 194 2 0.013583660125732422\n", 485 | "389 195 1 0.01267385482788086\n", 486 | "390 195 2 0.014597177505493164\n", 487 | "391 196 1 0.012612581253051758\n", 488 | "392 196 2 0.01373600959777832\n", 489 | "393 197 1 0.013222694396972656\n", 490 | "394 197 2 0.01295018196105957\n", 491 | "395 198 1 0.012334823608398438\n", 492 | "396 198 2 0.01252889633178711\n", 493 | "397 199 1 0.012554407119750977\n", 494 | "398 199 2 0.013105630874633789\n", 495 | "399 200 1 0.014058351516723633\n", 496 | "400 200 2 0.013690948486328125\n", 497 | "401 201 1 0.013262271881103516\n", 498 | "402 201 2 0.01384878158569336\n", 499 | "403 202 1 0.01252293586730957\n", 500 | "404 202 2 0.012486696243286133\n", 501 | "405 203 1 0.01377558708190918\n", 502 | "406 203 2 0.01277303695678711\n", 503 | "407 204 1 0.012504339218139648\n", 504 | "408 204 2 0.013715028762817383\n", 505 | "409 205 1 0.01278066635131836\n", 506 | "410 205 2 0.012386083602905273\n", 507 | "411 206 1 0.013184547424316406\n", 508 | "412 206 2 0.013869047164916992\n", 509 | "413 207 1 0.012421369552612305\n", 510 | "414 207 2 0.013211965560913086\n", 511 | "415 208 1 
0.012650251388549805\n", 512 | "416 208 2 0.012481451034545898\n", 513 | "417 209 1 0.013153314590454102\n", 514 | "418 209 2 0.01270294189453125\n", 515 | "419 210 1 0.013456583023071289\n", 516 | "420 210 2 0.012352228164672852\n", 517 | "421 211 1 0.013653278350830078\n", 518 | "422 211 2 0.013286590576171875\n", 519 | "423 212 1 0.012918472290039062\n", 520 | "424 212 2 0.01227712631225586\n", 521 | "425 213 1 0.012325286865234375\n", 522 | "426 213 2 0.01367640495300293\n", 523 | "427 214 1 0.012329339981079102\n", 524 | "428 214 2 0.013483762741088867\n", 525 | "429 215 1 0.012603044509887695\n", 526 | "430 215 2 0.013625860214233398\n", 527 | "431 216 1 0.012730121612548828\n", 528 | "432 216 2 0.013671398162841797\n", 529 | "433 217 1 0.012735366821289062\n", 530 | "434 217 2 0.012464523315429688\n", 531 | "435 218 1 0.012424230575561523\n", 532 | "436 218 2 0.012888431549072266\n", 533 | "437 219 1 0.01347208023071289\n", 534 | "438 219 2 0.012525320053100586\n", 535 | "439 220 1 0.01267242431640625\n", 536 | "440 220 2 0.012814760208129883\n", 537 | "441 221 1 0.012642383575439453\n", 538 | "442 221 2 0.012613534927368164\n", 539 | "443 222 1 0.012290239334106445\n", 540 | "444 222 2 0.014106512069702148\n", 541 | "445 223 1 0.01321864128112793\n", 542 | "446 223 2 0.014102697372436523\n", 543 | "447 224 1 0.01352834701538086\n", 544 | "448 224 2 0.012572050094604492\n", 545 | "449 225 1 0.012751102447509766\n", 546 | "450 225 2 0.012769460678100586\n", 547 | "451 226 1 0.014033317565917969\n", 548 | "452 226 2 0.012693643569946289\n", 549 | "453 227 1 0.013306140899658203\n", 550 | "454 227 2 0.01329493522644043\n", 551 | "455 228 1 0.013216018676757812\n", 552 | "456 228 2 0.012681722640991211\n", 553 | "457 229 1 0.013617753982543945\n", 554 | "458 229 2 0.012444257736206055\n", 555 | "459 230 1 0.01358795166015625\n", 556 | "460 230 2 0.01271367073059082\n", 557 | "461 231 1 0.012509346008300781\n", 558 | "462 231 2 0.012388467788696289\n", 559 | "463 232 1 0.01296854019165039\n", 560 | "464 232 2 0.012692689895629883\n", 561 | "465 233 1 0.012335062026977539\n", 562 | "466 233 2 0.012636899948120117\n", 563 | "467 234 1 0.01267552375793457\n", 564 | "468 234 2 0.013209342956542969\n", 565 | "469 235 1 0.012700080871582031\n", 566 | "470 235 2 0.012656927108764648\n", 567 | "471 236 1 0.013702154159545898\n", 568 | "472 236 2 0.012688636779785156\n", 569 | "473 237 1 0.013486623764038086\n", 570 | "474 237 2 0.013573169708251953\n", 571 | "475 238 1 0.012482643127441406\n", 572 | "476 238 2 0.014265775680541992\n", 573 | "477 239 1 0.013628005981445312\n", 574 | "478 239 2 0.012498617172241211\n", 575 | "479 240 1 0.012382984161376953\n", 576 | "480 240 2 0.01279759407043457\n", 577 | "481 241 1 0.013809919357299805\n", 578 | "482 241 2 0.012789011001586914\n", 579 | "483 242 1 0.013376474380493164\n", 580 | "484 242 2 0.01320338249206543\n", 581 | "485 243 1 0.014653444290161133\n", 582 | "486 243 2 0.01357579231262207\n", 583 | "487 244 1 0.012531757354736328\n", 584 | "488 244 2 0.012547969818115234\n", 585 | "489 245 1 0.012688398361206055\n", 586 | "490 245 2 0.012520313262939453\n", 587 | "491 246 1 0.012963056564331055\n", 588 | "492 246 2 0.012600421905517578\n", 589 | "493 247 1 0.012735605239868164\n", 590 | "494 247 2 0.012371301651000977\n", 591 | "495 248 1 0.014453887939453125\n", 592 | "496 248 2 0.013877391815185547\n", 593 | "497 249 1 0.014386415481567383\n", 594 | "498 249 2 0.014112710952758789\n", 595 | "499 250 1 0.013568401336669922\n", 596 | "500 250 2 
0.013689279556274414\n", 597 | "501 251 1 0.014236688613891602\n", 598 | "502 251 2 0.015862703323364258\n", 599 | "503 252 1 0.014400720596313477\n", 600 | "504 252 2 0.013496637344360352\n", 601 | "505 253 1 0.013967514038085938\n", 602 | "506 253 2 0.014106273651123047\n", 603 | "507 254 1 0.013255119323730469\n", 604 | "508 254 2 0.014400720596313477\n", 605 | "509 255 1 0.014082193374633789\n", 606 | "510 255 2 0.013763904571533203\n", 607 | "511 256 1 0.013281583786010742\n", 608 | "512 256 2 0.013522624969482422\n", 609 | "513 257 1 0.01388859748840332\n", 610 | "514 257 2 0.01385354995727539\n", 611 | "515 258 1 0.013643264770507812\n", 612 | "516 258 2 0.0128936767578125\n", 613 | "517 259 1 0.013870000839233398\n", 614 | "518 259 2 0.01346588134765625\n", 615 | "519 260 1 0.014710426330566406\n", 616 | "520 260 2 0.014383077621459961\n", 617 | "521 261 1 0.013254642486572266\n", 618 | "522 261 2 0.01389002799987793\n", 619 | "523 262 1 0.01373910903930664\n", 620 | "524 262 2 0.013704299926757812\n", 621 | "525 263 1 0.014248371124267578\n", 622 | "526 263 2 0.013776302337646484\n", 623 | "527 264 1 0.013773679733276367\n", 624 | "528 264 2 0.013849496841430664\n", 625 | "529 265 1 0.014349937438964844\n", 626 | "530 265 2 0.013779878616333008\n", 627 | "531 266 1 0.01332998275756836\n", 628 | "532 266 2 0.013013362884521484\n", 629 | "533 267 1 0.013605356216430664\n", 630 | "534 267 2 0.013622045516967773\n", 631 | "535 268 1 0.013462305068969727\n", 632 | "536 268 2 0.013475656509399414\n", 633 | "537 269 1 0.012336969375610352\n", 634 | "538 269 2 0.014376401901245117\n", 635 | "539 270 1 0.013368844985961914\n", 636 | "540 270 2 0.013869285583496094\n", 637 | "541 271 1 0.013324499130249023\n", 638 | "542 271 2 0.012337923049926758\n", 639 | "543 272 1 0.014369487762451172\n", 640 | "544 272 2 0.013405799865722656\n", 641 | "545 273 1 0.013122081756591797\n", 642 | "546 273 2 0.012341499328613281\n", 643 | "547 274 1 0.01433420181274414\n", 644 | "548 274 2 0.012491226196289062\n", 645 | "549 275 1 0.012465476989746094\n", 646 | "550 275 2 0.012725591659545898\n", 647 | "551 276 1 0.013387441635131836\n", 648 | "552 276 2 0.012306451797485352\n", 649 | "553 277 1 0.014266490936279297\n", 650 | "554 277 2 0.01361083984375\n", 651 | "555 278 1 0.012292623519897461\n", 652 | "556 278 2 0.013612747192382812\n", 653 | "557 279 1 0.012541055679321289\n", 654 | "558 279 2 0.013263702392578125\n", 655 | "559 280 1 0.013273954391479492\n", 656 | "560 280 2 0.012352466583251953\n", 657 | "561 281 1 0.013643264770507812\n", 658 | "562 281 2 0.013596057891845703\n", 659 | "563 282 1 0.01385188102722168\n", 660 | "564 282 2 0.012816429138183594\n", 661 | "565 283 1 0.012507200241088867\n", 662 | "566 283 2 0.014425516128540039\n", 663 | "567 284 1 0.012351512908935547\n", 664 | "568 284 2 0.013317108154296875\n", 665 | "569 285 1 0.012427806854248047\n", 666 | "570 285 2 0.013004541397094727\n", 667 | "571 286 1 0.01266026496887207\n", 668 | "572 286 2 0.012664794921875\n", 669 | "573 287 1 0.013560295104980469\n", 670 | "574 287 2 0.012282848358154297\n", 671 | "575 288 1 0.013694524765014648\n", 672 | "576 288 2 0.012948274612426758\n", 673 | "577 289 1 0.012575626373291016\n", 674 | "578 289 2 0.012510538101196289\n", 675 | "579 290 1 0.013562202453613281\n", 676 | "580 290 2 0.013659238815307617\n", 677 | "581 291 1 0.012853860855102539\n", 678 | "582 291 2 0.013521671295166016\n", 679 | "583 292 1 0.012647151947021484\n", 680 | "584 292 2 0.012308835983276367\n", 681 | "585 293 1 
0.013398885726928711\n", 682 | "586 293 2 0.01263737678527832\n", 683 | "587 294 1 0.012757062911987305\n", 684 | "588 294 2 0.013249874114990234\n", 685 | "589 295 1 0.012322425842285156\n", 686 | "590 295 2 0.012548446655273438\n", 687 | "591 296 1 0.012465715408325195\n", 688 | "592 296 2 0.012528657913208008\n", 689 | "593 297 1 0.012717485427856445\n", 690 | "594 297 2 0.01473093032836914\n", 691 | "595 298 1 0.013164758682250977\n", 692 | "596 298 2 0.012622356414794922\n", 693 | "597 299 1 0.013996124267578125\n", 694 | "598 299 2 0.012646198272705078\n", 695 | "599 300 1 0.012280702590942383\n", 696 | "600 300 2 0.013857841491699219\n", 697 | "601 301 1 0.0142364501953125\n", 698 | "602 301 2 0.012766838073730469\n", 699 | "603 302 1 0.013649463653564453\n", 700 | "604 302 2 0.014155864715576172\n", 701 | "605 303 1 0.012651443481445312\n", 702 | "606 303 2 0.013751506805419922\n", 703 | "607 304 1 0.01323556900024414\n", 704 | "608 304 2 0.012627363204956055\n", 705 | "609 305 1 0.013291358947753906\n", 706 | "610 305 2 0.01236104965209961\n", 707 | "611 306 1 0.012605905532836914\n", 708 | "612 306 2 0.012676000595092773\n", 709 | "613 307 1 0.01229405403137207\n", 710 | "614 307 2 0.012366294860839844\n", 711 | "615 308 1 0.013654947280883789\n", 712 | "616 308 2 0.012632369995117188\n", 713 | "617 309 1 0.012373685836791992\n", 714 | "618 309 2 0.012389183044433594\n", 715 | "619 310 1 0.012661457061767578\n", 716 | "620 310 2 0.012536287307739258\n", 717 | "621 311 1 0.013274908065795898\n", 718 | "622 311 2 0.014044523239135742\n", 719 | "623 312 1 0.012743473052978516\n", 720 | "624 312 2 0.014609813690185547\n", 721 | "625 313 1 0.013411521911621094\n", 722 | "626 313 2 0.012664794921875\n", 723 | "627 314 1 0.012385129928588867\n", 724 | "628 314 2 0.01350259780883789\n", 725 | "629 315 1 0.012861251831054688\n", 726 | "630 315 2 0.013253927230834961\n", 727 | "631 316 1 0.01340794563293457\n", 728 | "632 316 2 0.013985157012939453\n", 729 | "633 317 1 0.013353824615478516\n", 730 | "634 317 2 0.012727022171020508\n", 731 | "635 318 1 0.01324009895324707\n", 732 | "636 318 2 0.012718915939331055\n", 733 | "637 319 1 0.0127105712890625\n", 734 | "638 319 2 0.012707710266113281\n", 735 | "[Ours] Total images: 638 Avg Testing time: 0.015323991312128624\n" 736 | ] 737 | } 738 | ], 739 | "source": [ 740 | "# Before testing, you have to edit Config.py file as you want.\n", 741 | "# content_dir would be like './video_img/2048/content/'\n", 742 | "# img_dir would be like './output/video_output/2048'\n", 743 | "import test_video\n", 744 | "test_video.main()" 745 | ] 746 | }, 747 | { 748 | "cell_type": "markdown", 749 | "id": "9fa62836-3cc3-49f0-ae9f-d8c007c1483a", 750 | "metadata": {}, 751 | "source": [ 752 | "# Image to Video - Content Video\n", 753 | "After style transfer" 754 | ] 755 | }, 756 | { 757 | "cell_type": "code", 758 | "execution_count": 3, 759 | "id": "147d187f-1f05-4f74-808b-9a8d3dd2e284", 760 | "metadata": {}, 761 | "outputs": [ 762 | { 763 | "name": "stdout", 764 | "output_type": "stream", 765 | "text": [ 766 | "319 24.0\n" 767 | ] 768 | }, 769 | { 770 | "name": "stderr", 771 | "output_type": "stream", 772 | "text": [ 773 | "OpenCV: FFMPEG: tag 0x5634504d/'MP4V' is not supported with codec id 12 and format 'mp4 / MP4 (MPEG-4 Part 14)'\n", 774 | "OpenCV: FFMPEG: fallback to use tag 0x7634706d/'mp4v'\n" 775 | ] 776 | } 777 | ], 778 | "source": [ 779 | "import cv2\n", 780 | "import numpy as np\n", 781 | "import glob\n", 782 | "from Config import Config\n", 783 | "\n", 
784 | "config = Config()\n", 785 | "out_dir = config.img_dir\n", 786 | "\n", 787 | "for content_n in contents:\n", 788 | " length = len(glob.glob(out_dir+'/content/'+content_n+'*_content*'))\n", 789 | " cap = cv2.VideoCapture(content_dir+content_n+'.mp4')\n", 790 | " fps = cap.get(cv2.CAP_PROP_FPS)\n", 791 | " print(length, fps)\n", 792 | " \n", 793 | " img_array = []\n", 794 | " for idx in range(length):\n", 795 | " filename = out_dir+'/content/'+content_n+'_'+str(idx+1)+'_content.jpg'\n", 796 | " img = cv2.imread(filename)\n", 797 | " img_array.append(img)\n", 798 | " height, width, layers = img.shape\n", 799 | " size = (width,height)\n", 800 | "\n", 801 | " out = cv2.VideoWriter(out_dir+'/CONTENT_'+content_n+'.mp4', cv2.VideoWriter_fourcc(*'MP4V'), fps, size)\n", 802 | "\n", 803 | " for i in range(len(img_array)):\n", 804 | " out.write(img_array[i])\n", 805 | " out.release()" 806 | ] 807 | }, 808 | { 809 | "cell_type": "markdown", 810 | "id": "0f777fe8-b10a-4b2a-94cc-ef6f8e8dabb1", 811 | "metadata": {}, 812 | "source": [ 813 | "# Image to Video - Style Transfered Video" 814 | ] 815 | }, 816 | { 817 | "cell_type": "code", 818 | "execution_count": 4, 819 | "id": "2fe49999-ffee-45b7-b0a4-4a008739046f", 820 | "metadata": {}, 821 | "outputs": [ 822 | { 823 | "name": "stdout", 824 | "output_type": "stream", 825 | "text": [ 826 | "pexels-kelly-7165765-3840x2160-24fps 319 24.0\n" 827 | ] 828 | } 829 | ], 830 | "source": [ 831 | "style_imgs = sorted(glob.glob('./video_img/'+resolution+'/style/*.jpg'))\n", 832 | "\n", 833 | "for content in contents:\n", 834 | " length = len(glob.glob(out_dir+'/content/'+content+'*_content*'))\n", 835 | " cap = cv2.VideoCapture(content_dir+content+'.mp4')\n", 836 | " fps = cap.get(cv2.CAP_PROP_FPS)\n", 837 | " print(content, length, fps)\n", 838 | " \n", 839 | " for style_dir in style_imgs:\n", 840 | " style = style_dir.split('/')[-1].split('.')[0]\n", 841 | " img_array = []\n", 842 | "\n", 843 | " for idx in range(length):\n", 844 | " filename = out_dir+'/stylized/'+content+'_'+str(idx+1)+'_stylized_'+style+'.jpg'\n", 845 | " img = cv2.imread(filename)\n", 846 | " img_array.append(img)\n", 847 | " height, width, layers = img.shape\n", 848 | " size = (width,height)\n", 849 | " print(\"-\", style, len(img_array))\n", 850 | "\n", 851 | " out = cv2.VideoWriter(out_dir+'/STYLE_'+content+'_'+style+'.mp4', cv2.VideoWriter_fourcc(*'MP4V'), fps, size)\n", 852 | "\n", 853 | " for i in range(len(img_array)):\n", 854 | " out.write(img_array[i])\n", 855 | " out.release()" 856 | ] 857 | }, 858 | { 859 | "cell_type": "code", 860 | "execution_count": null, 861 | "id": "db973afc-7937-4216-b80a-a2df475a15f1", 862 | "metadata": {}, 863 | "outputs": [], 864 | "source": [] 865 | } 866 | ], 867 | "metadata": { 868 | "kernelspec": { 869 | "display_name": "pytorch-1.13.1", 870 | "language": "python", 871 | "name": "pytorch-1.13.1" 872 | }, 873 | "language_info": { 874 | "codemirror_mode": { 875 | "name": "ipython", 876 | "version": 3 877 | }, 878 | "file_extension": ".py", 879 | "mimetype": "text/x-python", 880 | "name": "python", 881 | "nbconvert_exporter": "python", 882 | "pygments_lexer": "ipython3", 883 | "version": "3.9.15" 884 | } 885 | }, 886 | "nbformat": 4, 887 | "nbformat_minor": 5 888 | } 889 | --------------------------------------------------------------------------------