├── .DS_Store
├── LICENSE
├── README.md
├── TestArtistic.py
├── TestPhotoReal.py
├── TestVideo.py
├── Train.py
├── TrainSPN.py
├── data
│   ├── .DS_Store
│   ├── content
│   │   ├── .DS_Store
│   │   ├── 1.jpg
│   │   └── chicago.png
│   ├── photo_real
│   │   ├── .DS_Store
│   │   ├── content
│   │   │   ├── .DS_Store
│   │   │   └── images
│   │   │       ├── in16.png
│   │   │       ├── in25.png
│   │   │       ├── in26.png
│   │   │       ├── in29.png
│   │   │       ├── in3.png
│   │   │       ├── in39.png
│   │   │       ├── in53.png
│   │   │       └── in7.png
│   │   ├── contentSeg
│   │   │   ├── in16.png
│   │   │   ├── in25.png
│   │   │   ├── in26.png
│   │   │   ├── in29.png
│   │   │   ├── in3.png
│   │   │   ├── in39.png
│   │   │   ├── in53.png
│   │   │   └── in7.png
│   │   ├── style
│   │   │   └── images
│   │   │       ├── in16.png
│   │   │       ├── in25.png
│   │   │       ├── in26.png
│   │   │       ├── in29.png
│   │   │       ├── in3.png
│   │   │       ├── in39.png
│   │   │       ├── in53.png
│   │   │       └── in7.png
│   │   └── styleSeg
│   │       ├── in16.png
│   │       ├── in25.png
│   │       ├── in26.png
│   │       ├── in29.png
│   │       ├── in3.png
│   │       ├── in39.png
│   │       ├── in53.png
│   │       └── in7.png
│   ├── style
│   │   ├── .DS_Store
│   │   ├── 27.jpg
│   │   ├── 3314.jpg
│   │   ├── antimonocromatismo.jpg
│   │   ├── in2.jpg
│   │   ├── picasso_self_portrait.jpg
│   │   └── sketch.jpg
│   └── videos
│       ├── .DS_Store
│       └── content
│           ├── .DS_Store
│           └── mountain_2
│               ├── .DS_Store
│               ├── frame_0001.png
│               ├── frame_0002.png
│               ├── frame_0003.png
│               ├── frame_0004.png
│               ├── frame_0005.png
│               ├── frame_0006.png
│               ├── frame_0007.png
│               ├── frame_0008.png
│               ├── frame_0009.png
│               ├── frame_0010.png
│               ├── frame_0011.png
│               ├── frame_0012.png
│               ├── frame_0013.png
│               ├── frame_0014.png
│               ├── frame_0015.png
│               ├── frame_0016.png
│               ├── frame_0017.png
│               ├── frame_0018.png
│               ├── frame_0019.png
│               ├── frame_0020.png
│               ├── frame_0021.png
│               ├── frame_0022.png
│               ├── frame_0023.png
│               ├── frame_0024.png
│               ├── frame_0025.png
│               ├── frame_0026.png
│               ├── frame_0027.png
│               ├── frame_0028.png
│               ├── frame_0029.png
│               ├── frame_0030.png
│               ├── frame_0031.png
│               ├── frame_0032.png
│               ├── frame_0033.png
│               ├── frame_0034.png
│               ├── frame_0035.png
│               ├── frame_0036.png
│               ├── frame_0037.png
│               ├── frame_0038.png
│               ├── frame_0039.png
│               ├── frame_0040.png
│               ├── frame_0041.png
│               ├── frame_0042.png
│               ├── frame_0043.png
│               ├── frame_0044.png
│               ├── frame_0045.png
│               ├── frame_0046.png
│               ├── frame_0047.png
│               ├── frame_0048.png
│               ├── frame_0049.png
│               └── frame_0050.png
├── doc
│   ├── .DS_Store
│   └── images
│       ├── .DS_Store
│       ├── chicago_27.png
│       ├── chicago_paste.png
│       ├── content.gif
│       ├── in5_result.png
│       ├── photo_content.png
│       └── test.gif
├── libs
│   ├── .DS_Store
│   ├── Criterion.py
│   ├── Loader.py
│   ├── LoaderPhotoReal.py
│   ├── Matrix.py
│   ├── MatrixTest.py
│   ├── SPN.py
│   ├── __init__.py
│   ├── models.py
│   ├── pytorch_spn
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── _ext
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   └── __init__.cpython-36.pyc
│   │   │   └── gaterecurrent2dnoind
│   │   │       ├── __init__.py
│   │   │       ├── __pycache__
│   │   │       │   └── __init__.cpython-36.pyc
│   │   │       └── _gaterecurrent2dnoind.so
│   │   ├── build.py
│   │   ├── functions
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   └── gaterecurrent2dnoind.cpython-36.pyc
│   │   │   ├── gaterecurrent2dnoind.py
│   │   │   └── gaterecurrent2dnoind.pyc
│   │   ├── left_right_demo.py
│   │   ├── make.sh
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   └── gaterecurrent2dnoind.cpython-36.pyc
│   │   │   ├── gaterecurrent2dnoind.py
│   │   │   └── gaterecurrent2dnoind.pyc
│   │   └── src
│   │       ├── .DS_Store
│   │       ├── cuda
│   │       │   ├── gaterecurrent2dnoind_kernel.cu
│   │       │   ├── gaterecurrent2dnoind_kernel.cu.o
│   │       │   └── gaterecurrent2dnoind_kernel.h
│   │       ├── gaterecurrent2dnoind_cuda.c
│   │       └── gaterecurrent2dnoind_cuda.h
│   ├── smooth_filter.py
│   └── utils.py
└── real-time-demo.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 2-Clause License

Copyright (c) 2018, SunshineAtNoon
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Learning Linear Transformations for Fast Image and Video Style Transfer
**[[Paper]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Li_Learning_Linear_Transformations_for_Fast_Image_and_Video_Style_Transfer_CVPR_2019_paper.pdf)** **[[Project Page]](https://sites.google.com/view/linear-style-transfer-cvpr19/)**


## Prerequisites
- [Pytorch](http://pytorch.org/)
- [torchvision](https://github.com/pytorch/vision)
- [opencv](https://opencv.org/) for video generation

**All code tested on Ubuntu 16.04, pytorch 0.4.1, and opencv 3.4.2**

## Style Transfer
- Clone from github: `git clone https://github.com/sunshineatnoon/LinearStyleTransfer`
- Download the pre-trained models from [google drive](https://drive.google.com/file/d/1H9T5rfXGlGCUh04DGkpkMFbVnmscJAbs/view?usp=sharing).
- Uncompress them into the root folder:
```
cd LinearStyleTransfer
unzip models.zip
rm models.zip
```

#### Artistic style transfer
```
python TestArtistic.py
```
or conduct style transfer on relu3_1 features:
```
python TestArtistic.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31
```

#### Photo-realistic style transfer
For photo-realistic style transfer, we first need to compile the [pytorch_spn](https://github.com/Liusifei/pytorch_spn) repository.
```
cd libs/pytorch_spn
sh make.sh
cd ../..
```
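A quick way to confirm the build succeeded is to import the wrapper module. This is a hedged check: the module path follows the `libs/pytorch_spn/modules` layout shown in the tree above, and the class name is taken from the upstream pytorch_spn project, so adjust it if your copy differs.
```python
# Hedged sanity check: succeeds only if make.sh built _gaterecurrent2dnoind.so
from libs.pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind
print('pytorch_spn extension loaded')
```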
Then:
```
python TestPhotoReal.py
```
Note: images with the `_filtered.png` postfix have been filtered by the SPN after style transfer; images with the `_smooth.png` postfix have additionally been post-processed by a [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py).

#### Video style transfer
```
python TestVideo.py
```

#### Real-time video demo
```
python real-time-demo.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31
```

## Model Training
### Data Preparation
- MSCOCO
```
wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip
```
- WikiArt
  - Either manually download it from [kaggle](https://www.kaggle.com/c/painter-by-numbers).
  - Or install [kaggle-cli](https://github.com/floydwch/kaggle-cli) and download it by running:
```
kg download -u <USERNAME> -p <PASSWORD> -c painter-by-numbers -f train.zip
```

### Training
#### Train a style transfer model
To train a model that transfers relu4_1 features, run:
```
python Train.py --vgg_dir models/vgg_r41.pth --decoder_dir models/dec_r41.pth --layer r41 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR
```
or train a model that transfers relu3_1 features:
```
python Train.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --layer r31 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR
```
Key hyper-parameters:
- `style_layers`: the VGG layers on which the style loss is computed.
- `style_weight`: a larger style weight leads to a stronger style in the transferred images.

Intermediate results and model weights will be stored in `OUTPUT_DIR`.

#### Train an SPN model to cancel distortions for photo-realistic style transfer
Run:
```
python TrainSPN.py --contentPath PATH_TO_MSCOCO
```

### Acknowledgement
- We use the [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py) by [LouieYang](https://github.com/LouieYang) in the photo-realistic style transfer.

### Citation
```
@inproceedings{li2018learning,
    author = {Li, Xueting and Liu, Sifei and Kautz, Jan and Yang, Ming-Hsuan},
    title = {Learning Linear Transformations for Fast Image and Video Style Transfer},
    booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
    year = {2019}
}
```
--------------------------------------------------------------------------------
/TestArtistic.py:
--------------------------------------------------------------------------------
import os
import torch
import argparse
from libs.Loader import Dataset
from libs.Matrix import MulLayer
import torchvision.utils as vutils
import torch.backends.cudnn as cudnn
from libs.utils import print_options
from libs.models import encoder3,encoder4, encoder5
from libs.models import decoder3,decoder4, decoder5

parser = argparse.ArgumentParser()
parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
                    help='pre-trained encoder path')
parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
                    help='pre-trained decoder path')
parser.add_argument("--matrixPath", default='models/r41.pth',
                    help='pre-trained model path')
parser.add_argument("--stylePath", default="data/style/",
                    help='path to style image')
parser.add_argument("--contentPath", default="data/content/",
                    help='path to frames')
parser.add_argument("--outf", default="Artistic/",
                    help='path to transferred images')
parser.add_argument("--batchSize", type=int,default=1,
                    help='batch size')
parser.add_argument('--loadSize', type=int, default=256,
                    help='scale image size')
parser.add_argument('--fineSize', type=int, default=256,
                    help='crop image size')
parser.add_argument("--layer", default="r41",
                    help='which features to transfer, either r31 or r41')

################# PREPARATIONS #################
opt = parser.parse_args()
opt.cuda = torch.cuda.is_available()
print_options(opt)

os.makedirs(opt.outf,exist_ok=True)
cudnn.benchmark = True

################# DATA #################
content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize,test=True)
content_loader = torch.utils.data.DataLoader(dataset=content_dataset,
                                             batch_size = opt.batchSize,
                                             shuffle = False,
                                             num_workers = 1)
style_dataset = Dataset(opt.stylePath,opt.loadSize,opt.fineSize,test=True)
style_loader = torch.utils.data.DataLoader(dataset=style_dataset,
                                           batch_size = opt.batchSize,
                                           shuffle = False,
                                           num_workers = 1)

################# MODEL #################
if(opt.layer == 'r31'):
    vgg = encoder3()
    dec = decoder3()
elif(opt.layer == 'r41'):
    vgg = encoder4()
    dec = decoder4()
matrix = MulLayer(opt.layer)
vgg.load_state_dict(torch.load(opt.vgg_dir))
dec.load_state_dict(torch.load(opt.decoder_dir))
matrix.load_state_dict(torch.load(opt.matrixPath))

################# GLOBAL VARIABLE #################
contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)

################# GPU #################
if(opt.cuda):
    vgg.cuda()
    dec.cuda()
    matrix.cuda()
    contentV = contentV.cuda()
    styleV = styleV.cuda()

# stylize every (content, style) pair in the two folders
for ci,(content,contentName) in enumerate(content_loader):
    contentName = contentName[0]
    contentV.resize_(content.size()).copy_(content)
    for sj,(style,styleName) in enumerate(style_loader):
        styleName = styleName[0]
        styleV.resize_(style.size()).copy_(style)

        # forward
        with torch.no_grad():
            sF = vgg(styleV)
            cF = vgg(contentV)

            if(opt.layer == 'r41'):
                feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
            else:
                feature,transmatrix = matrix(cF,sF)
            transfer = dec(feature)

        transfer = transfer.clamp(0,1)
        vutils.save_image(transfer,'%s/%s_%s.png'%(opt.outf,contentName,styleName),normalize=True,scale_each=True,nrow=opt.batchSize)
        print('Transferred image saved at %s%s_%s.png'%(opt.outf,contentName,styleName))
--------------------------------------------------------------------------------
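The script above wires the encoder, the learned linear transformation (`MulLayer`), and the decoder together. Condensed into one function, the inference path looks roughly like this; it is a sketch under the same r41/r31 conventions, not an additional API of the repository:
```python
# Minimal sketch of the stylization pipeline in TestArtistic.py.
import torch

def stylize(content, style, vgg, dec, matrix, layer='r41'):
    """content/style: (1,3,H,W) tensors in [0,1]; returns the stylized image."""
    with torch.no_grad():
        cF, sF = vgg(content), vgg(style)
        if layer == 'r41':
            # encoder4 returns a dict of feature maps keyed by layer name
            feature, _ = matrix(cF[layer], sF[layer])
        else:
            # encoder3 returns the relu3_1 feature map directly
            feature, _ = matrix(cF, sF)
        return dec(feature).clamp(0, 1)
```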
/TestPhotoReal.py:
--------------------------------------------------------------------------------
import os
import cv2
import time
import torch
import argparse
import numpy as np
from PIL import Image
from libs.SPN import SPN
import torchvision.utils as vutils
from libs.utils import print_options
from libs.MatrixTest import MulLayer
import torch.backends.cudnn as cudnn
from libs.LoaderPhotoReal import Dataset
from libs.models import encoder3,encoder4
from libs.models import decoder3,decoder4
import torchvision.transforms as transforms
from libs.smooth_filter import smooth_filter

parser = argparse.ArgumentParser()
parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
                    help='pre-trained encoder path')
parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
                    help='pre-trained decoder path')
parser.add_argument("--matrixPath", default='models/r41.pth',
                    help='pre-trained model path')
parser.add_argument("--stylePath", default="data/photo_real/style/images/",
                    help='path to style image')
parser.add_argument("--styleSegPath", default="data/photo_real/styleSeg/",
                    help='path to style image masks')
parser.add_argument("--contentPath", default="data/photo_real/content/images/",
                    help='path to content image')
parser.add_argument("--contentSegPath", default="data/photo_real/contentSeg/",
                    help='path to content image masks')
parser.add_argument("--outf", default="PhotoReal/",
                    help='path to save output images')
parser.add_argument("--batchSize", type=int,default=1,
                    help='batch size')
parser.add_argument('--fineSize', type=int, default=512,
                    help='image size')
parser.add_argument("--layer", default="r41",
                    help='features of which layer to transform, either r31 or r41')
parser.add_argument("--spn_dir", default='models/r41_spn.pth',
                    help='path to pretrained SPN model')

################# PREPARATIONS #################
opt = parser.parse_args()
opt.cuda = torch.cuda.is_available()
print_options(opt)

os.makedirs(opt.outf, exist_ok=True)

cudnn.benchmark = True

################# DATA #################
dataset = Dataset(opt.contentPath,opt.stylePath,opt.contentSegPath,opt.styleSegPath,opt.fineSize)
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=1,
                                     shuffle=False)

################# MODEL #################
if(opt.layer == 'r31'):
    vgg = encoder3()
    dec = decoder3()
elif(opt.layer == 'r41'):
    vgg = encoder4()
    dec = decoder4()
matrix = MulLayer(opt.layer)
vgg.load_state_dict(torch.load(opt.vgg_dir))
dec.load_state_dict(torch.load(opt.decoder_dir))
matrix.load_state_dict(torch.load(opt.matrixPath))
spn = SPN()
spn.load_state_dict(torch.load(opt.spn_dir))

################# GLOBAL VARIABLE #################
contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
whitenV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)

################# GPU #################
if(opt.cuda):
    vgg.cuda()
    dec.cuda()
    spn.cuda()
    matrix.cuda()
    contentV = contentV.cuda()
    styleV = styleV.cuda()
    whitenV = whitenV.cuda()

for i,(contentImg,styleImg,whitenImg,cmasks,smasks,imname) in enumerate(loader):
    imname = imname[0]
    contentV.resize_(contentImg.size()).copy_(contentImg)
    styleV.resize_(styleImg.size()).copy_(styleImg)
    whitenV.resize_(whitenImg.size()).copy_(whitenImg)

    # forward (no gradients needed at test time)
    with torch.no_grad():
        sF = vgg(styleV)
        cF = vgg(contentV)
        if(opt.layer == 'r41'):
            feature = matrix(cF[opt.layer],sF[opt.layer],cmasks,smasks)
        else:
            feature = matrix(cF,sF,cmasks,smasks)
        transfer = dec(feature)
        filtered = spn(transfer,whitenV)
    vutils.save_image(transfer,os.path.join(opt.outf,'%s_transfer.png'%(imname.split('.')[0])))

    filtered = filtered.clamp(0,1)
    filtered = filtered.cpu()
    vutils.save_image(filtered,'%s/%s_filtered.png'%(opt.outf,imname.split('.')[0]))
    out_img = filtered.squeeze(0).mul(255).clamp(0,255).byte().permute(1,2,0).cpu().numpy()
    content = contentImg.squeeze(0).mul(255).clamp(0,255).byte().permute(1,2,0).cpu().numpy()
    content = content.copy()
    out_img = out_img.copy()
    smoothed = smooth_filter(out_img, content, f_radius=15, f_edge=1e-1)
    smoothed.save('%s/%s_smooth.png'%(opt.outf,imname.split('.')[0]))
    print('Transferred image saved at %s%s_transfer.png, filtered image saved at %s%s_filtered.png'
          %(opt.outf,imname.split('.')[0],opt.outf,imname.split('.')[0]))
--------------------------------------------------------------------------------
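For reference, the smooth filter can also be applied to a pair of images on disk. The sketch below mirrors the call made above; the file names are illustrative, and both images are assumed to have identical dimensions:
```python
# Hedged sketch: re-running the smooth filter outside TestPhotoReal.py.
import numpy as np
from PIL import Image
from libs.smooth_filter import smooth_filter

# Both inputs are HxWx3 uint8 arrays of the same size (illustrative paths).
stylized = np.array(Image.open('PhotoReal/in3_filtered.png').convert('RGB'))
content  = np.array(Image.open('data/photo_real/content/images/in3.png').convert('RGB'))
smoothed = smooth_filter(stylized, content, f_radius=15, f_edge=1e-1)  # returns a PIL image
smoothed.save('PhotoReal/in3_smooth.png')
```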
/TestVideo.py:
--------------------------------------------------------------------------------
import os
import torch
import argparse
from PIL import Image
from libs.Loader import Dataset
from libs.Matrix import MulLayer
import torch.backends.cudnn as cudnn
from libs.models import encoder3,encoder4
from libs.models import decoder3,decoder4
import torchvision.transforms as transforms
from libs.utils import makeVideo, print_options

parser = argparse.ArgumentParser()
parser.add_argument("--vgg_dir", default='models/vgg_r31.pth',
                    help='pre-trained encoder path')
parser.add_argument("--decoder_dir", default='models/dec_r31.pth',
                    help='pre-trained decoder path')
parser.add_argument("--matrix_dir", default="models/r31.pth",
                    help='path to pre-trained model')
parser.add_argument("--style", default="data/style/in2.jpg",
                    help='path to style image')
parser.add_argument("--content_dir", default="data/videos/content/mountain_2/",
                    help='path to video frames')
parser.add_argument('--loadSize', type=int, default=512,
                    help='scale image size')
parser.add_argument('--fineSize', type=int, default=512,
                    help='crop image size')
parser.add_argument("--name",default="transferred_video",
                    help="name of generated video")
parser.add_argument("--layer",default="r31",
                    help="features of which layer to transform")
parser.add_argument("--outf",default="videos",
                    help="output folder")

################# PREPARATIONS #################
opt = parser.parse_args()
opt.cuda = torch.cuda.is_available()
print_options(opt)

os.makedirs(opt.outf,exist_ok=True)
cudnn.benchmark = True

################# DATA #################
def loadImg(imgPath):
    img = Image.open(imgPath).convert('RGB')
    transform = transforms.Compose([
                transforms.Resize(opt.fineSize),  # Scale is deprecated; Resize is the equivalent op
                transforms.ToTensor()])
    return transform(img)
styleV = loadImg(opt.style).unsqueeze(0)

content_dataset = Dataset(opt.content_dir,
                          loadSize = opt.loadSize,
                          fineSize = opt.fineSize,
                          test = True,
                          video = True)
content_loader = torch.utils.data.DataLoader(dataset = content_dataset,
                                             batch_size = 1,
                                             shuffle = False)

################# MODEL #################
if(opt.layer == 'r31'):
    vgg = encoder3()
    dec = decoder3()
elif(opt.layer == 'r41'):
    vgg = encoder4()
    dec = decoder4()
matrix = MulLayer(layer=opt.layer)
vgg.load_state_dict(torch.load(opt.vgg_dir))
dec.load_state_dict(torch.load(opt.decoder_dir))
matrix.load_state_dict(torch.load(opt.matrix_dir))

################# GLOBAL VARIABLE #################
contentV = torch.Tensor(1,3,opt.fineSize,opt.fineSize)

################# GPU #################
if(opt.cuda):
    vgg.cuda()
    dec.cuda()
    matrix.cuda()

    styleV = styleV.cuda()
    contentV = contentV.cuda()

result_frames = []
contents = []
style = styleV.squeeze(0).cpu().numpy()
# style features are computed once and reused for every frame
with torch.no_grad():
    sF = vgg(styleV)

for i,(content,contentName) in enumerate(content_loader):
    print('Transfer frame %d...'%i)
    contentName = contentName[0]
    contentV.resize_(content.size()).copy_(content)
    contents.append(content.squeeze(0).float().numpy())
    # forward
    with torch.no_grad():
        cF = vgg(contentV)

        if(opt.layer == 'r41'):
            feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
        else:
            feature,transmatrix = matrix(cF,sF)
        transfer = dec(feature)

    transfer = transfer.clamp(0,1)
    result_frames.append(transfer.squeeze(0).cpu().numpy())

makeVideo(contents,style,result_frames,opt.outf)
--------------------------------------------------------------------------------
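`makeVideo` (from `libs/utils.py`, whose internals are not shown in this dump) assembles the stylized frames into a video. With opencv installed, a minimal frame-to-video step of the kind it needs looks like the sketch below; the helper name is hypothetical, and it assumes HxWx3 uint8 BGR frames, so the CHW float arrays collected above would first need scaling and transposing:
```python
# Hedged sketch of writing frames to a video file with OpenCV.
import cv2

def write_video(frames, path, fps=25):
    """frames: list of HxWx3 uint8 BGR arrays, all the same size."""
    h, w = frames[0].shape[:2]
    out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    for f in frames:
        out.write(f)
    out.release()
```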
generated video") 30 | parser.add_argument("--layer",default="r31", 31 | help="features of which layer to transform") 32 | parser.add_argument("--outf",default="videos", 33 | help="output folder") 34 | 35 | ################# PREPARATIONS ################# 36 | opt = parser.parse_args() 37 | opt.cuda = torch.cuda.is_available() 38 | print_options(opt) 39 | 40 | os.makedirs(opt.outf,exist_ok=True) 41 | cudnn.benchmark = True 42 | 43 | ################# DATA ################# 44 | def loadImg(imgPath): 45 | img = Image.open(imgPath).convert('RGB') 46 | transform = transforms.Compose([ 47 | transforms.Scale(opt.fineSize), 48 | transforms.ToTensor()]) 49 | return transform(img) 50 | styleV = loadImg(opt.style).unsqueeze(0) 51 | 52 | content_dataset = Dataset(opt.content_dir, 53 | loadSize = opt.loadSize, 54 | fineSize = opt.fineSize, 55 | test = True, 56 | video = True) 57 | content_loader = torch.utils.data.DataLoader(dataset = content_dataset, 58 | batch_size = 1, 59 | shuffle = False) 60 | 61 | ################# MODEL ################# 62 | if(opt.layer == 'r31'): 63 | vgg = encoder3() 64 | dec = decoder3() 65 | elif(opt.layer == 'r41'): 66 | vgg = encoder4() 67 | dec = decoder4() 68 | matrix = MulLayer(layer=opt.layer) 69 | vgg.load_state_dict(torch.load(opt.vgg_dir)) 70 | dec.load_state_dict(torch.load(opt.decoder_dir)) 71 | matrix.load_state_dict(torch.load(opt.matrix_dir)) 72 | 73 | ################# GLOBAL VARIABLE ################# 74 | contentV = torch.Tensor(1,3,opt.fineSize,opt.fineSize) 75 | 76 | ################# GPU ################# 77 | if(opt.cuda): 78 | vgg.cuda() 79 | dec.cuda() 80 | matrix.cuda() 81 | 82 | styleV = styleV.cuda() 83 | contentV = contentV.cuda() 84 | 85 | result_frames = [] 86 | contents = [] 87 | style = styleV.squeeze(0).cpu().numpy() 88 | sF = vgg(styleV) 89 | 90 | for i,(content,contentName) in enumerate(content_loader): 91 | print('Transfer frame %d...'%i) 92 | contentName = contentName[0] 93 | contentV.resize_(content.size()).copy_(content) 94 | contents.append(content.squeeze(0).float().numpy()) 95 | # forward 96 | with torch.no_grad(): 97 | cF = vgg(contentV) 98 | 99 | if(opt.layer == 'r41'): 100 | feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer]) 101 | else: 102 | feature,transmatrix = matrix(cF,sF) 103 | transfer = dec(feature) 104 | 105 | transfer = transfer.clamp(0,1) 106 | result_frames.append(transfer.squeeze(0).cpu().numpy()) 107 | 108 | makeVideo(contents,style,result_frames,opt.outf) 109 | -------------------------------------------------------------------------------- /Train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from libs.Loader import Dataset 7 | from libs.Matrix import MulLayer 8 | import torchvision.utils as vutils 9 | import torch.backends.cudnn as cudnn 10 | from libs.utils import print_options 11 | from libs.Criterion import LossCriterion 12 | from libs.models import encoder3,encoder4 13 | from libs.models import decoder3,decoder4 14 | from libs.models import encoder5 as loss_network 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--vgg_dir", default='models/vgg_r41.pth', 18 | help='pre-trained encoder path') 19 | parser.add_argument("--loss_network_dir", default='models/vgg_r51.pth', 20 | help='used for loss network') 21 | parser.add_argument("--decoder_dir", default='models/dec_r41.pth', 22 | help='pre-trained decoder path') 23 | 
parser.add_argument("--stylePath", default="/home/xtli/DATA/wikiArt/train/images/", 24 | help='path to wikiArt dataset') 25 | parser.add_argument("--contentPath", default="/home/xtli/DATA/MSCOCO/train2014/images/", 26 | help='path to MSCOCO dataset') 27 | parser.add_argument("--outf", default="trainingOutput/", 28 | help='folder to output images and model checkpoints') 29 | parser.add_argument("--content_layers", default="r41", 30 | help='layers for content') 31 | parser.add_argument("--style_layers", default="r11,r21,r31,r41", 32 | help='layers for style') 33 | parser.add_argument("--batchSize", type=int,default=8, 34 | help='batch size') 35 | parser.add_argument("--niter", type=int,default=100000, 36 | help='iterations to train the model') 37 | parser.add_argument('--loadSize', type=int, default=300, 38 | help='scale image size') 39 | parser.add_argument('--fineSize', type=int, default=256, 40 | help='crop image size') 41 | parser.add_argument("--lr", type=float, default=1e-4, 42 | help='learning rate') 43 | parser.add_argument("--content_weight", type=float, default=1.0, 44 | help='content loss weight') 45 | parser.add_argument("--style_weight", type=float, default=0.02, 46 | help='style loss weight') 47 | parser.add_argument("--log_interval", type=int, default=500, 48 | help='log interval') 49 | parser.add_argument("--gpu_id", type=int, default=0, 50 | help='which gpu to use') 51 | parser.add_argument("--save_interval", type=int, default=5000, 52 | help='checkpoint save interval') 53 | parser.add_argument("--layer", default="r41", 54 | help='which features to transfer, either r31 or r41') 55 | 56 | ################# PREPARATIONS ################# 57 | opt = parser.parse_args() 58 | opt.content_layers = opt.content_layers.split(',') 59 | opt.style_layers = opt.style_layers.split(',') 60 | opt.cuda = torch.cuda.is_available() 61 | if(opt.cuda): 62 | torch.cuda.set_device(opt.gpu_id) 63 | 64 | os.makedirs(opt.outf,exist_ok=True) 65 | cudnn.benchmark = True 66 | print_options(opt) 67 | 68 | ################# DATA ################# 69 | content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize) 70 | content_loader_ = torch.utils.data.DataLoader(dataset = content_dataset, 71 | batch_size = opt.batchSize, 72 | shuffle = True, 73 | num_workers = 1, 74 | drop_last = True) 75 | content_loader = iter(content_loader_) 76 | style_dataset = Dataset(opt.stylePath,opt.loadSize,opt.fineSize) 77 | style_loader_ = torch.utils.data.DataLoader(dataset = style_dataset, 78 | batch_size = opt.batchSize, 79 | shuffle = True, 80 | num_workers = 1, 81 | drop_last = True) 82 | style_loader = iter(style_loader_) 83 | 84 | ################# MODEL ################# 85 | vgg5 = loss_network() 86 | if(opt.layer == 'r31'): 87 | matrix = MulLayer('r31') 88 | vgg = encoder3() 89 | dec = decoder3() 90 | elif(opt.layer == 'r41'): 91 | matrix = MulLayer('r41') 92 | vgg = encoder4() 93 | dec = decoder4() 94 | vgg.load_state_dict(torch.load(opt.vgg_dir)) 95 | dec.load_state_dict(torch.load(opt.decoder_dir)) 96 | vgg5.load_state_dict(torch.load(opt.loss_network_dir)) 97 | 98 | for param in vgg.parameters(): 99 | param.requires_grad = False 100 | for param in vgg5.parameters(): 101 | param.requires_grad = False 102 | for param in dec.parameters(): 103 | param.requires_grad = False 104 | 105 | ################# LOSS & OPTIMIZER ################# 106 | criterion = LossCriterion(opt.style_layers, 107 | opt.content_layers, 108 | opt.style_weight, 109 | opt.content_weight) 110 | optimizer = optim.Adam(matrix.parameters(), 
################# TRAINING #################
def adjust_learning_rate(optimizer, iteration):
    """Inverse decay: lr = lr0 / (1 + 1e-5 * iteration)."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = opt.lr / (1+iteration*1e-5)

for iteration in range(1,opt.niter+1):
    optimizer.zero_grad()
    try:
        content,_ = next(content_loader)
    except IOError:
        content,_ = next(content_loader)
    except StopIteration:
        # restart the loader once an epoch is exhausted
        content_loader = iter(content_loader_)
        content,_ = next(content_loader)
    except:
        continue

    try:
        style,_ = next(style_loader)
    except IOError:
        style,_ = next(style_loader)
    except StopIteration:
        style_loader = iter(style_loader_)
        style,_ = next(style_loader)
    except:
        continue

    contentV.resize_(content.size()).copy_(content)
    styleV.resize_(style.size()).copy_(style)

    # forward
    sF = vgg(styleV)
    cF = vgg(contentV)

    if(opt.layer == 'r41'):
        feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
    else:
        feature,transmatrix = matrix(cF,sF)
    transfer = dec(feature)

    sF_loss = vgg5(styleV)
    cF_loss = vgg5(contentV)
    tF = vgg5(transfer)
    loss,styleLoss,contentLoss = criterion(tF,sF_loss,cF_loss)

    # backward & optimization
    loss.backward()
    optimizer.step()
    print('Iteration: [%d/%d] Loss: %.4f contentLoss: %.4f styleLoss: %.4f Learning Rate is %.6f'%
          (iteration,opt.niter,loss,contentLoss,styleLoss,optimizer.param_groups[0]['lr']))

    adjust_learning_rate(optimizer,iteration)

    if((iteration) % opt.log_interval == 0):
        transfer = transfer.clamp(0,1)
        concat = torch.cat((content,style,transfer.cpu()),dim=0)
        vutils.save_image(concat,'%s/%d.png'%(opt.outf,iteration),normalize=True,scale_each=True,nrow=opt.batchSize)

    if(iteration > 0 and (iteration) % opt.save_interval == 0):
        torch.save(matrix.state_dict(), '%s/%s.pth' % (opt.outf,opt.layer))
--------------------------------------------------------------------------------
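`LossCriterion` (libs/Criterion.py) is not reproduced in this dump. As a rough orientation, a standard content-plus-Gram-matrix style objective of the kind configured above looks like the sketch below; this is not the repository's exact code, and it assumes features are exposed as dicts keyed by layer name, as the encoders above do:
```python
# Hedged sketch of a content + Gram-style loss over several VGG layers.
import torch

def gram(f):
    """f: (B,C,H,W) feature map -> normalized (B,C,C) Gram matrix."""
    b, c, h, w = f.size()
    F = f.view(b, c, h * w)
    return torch.bmm(F, F.transpose(1, 2)) / (c * h * w)

def style_content_loss(tF, sF, cF, style_layers, content_layers, sw, cw):
    style = sum(torch.mean((gram(tF[l]) - gram(sF[l])) ** 2) for l in style_layers)
    content = sum(torch.mean((tF[l] - cF[l]) ** 2) for l in content_layers)
    return cw * content + sw * style
```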
default="/home/xtli/DATA/MSCOCO/train2014/images/", 25 | help='path to MSCOCO dataset') 26 | parser.add_argument("--outf", default="trainingSPNOutput/", 27 | help='folder to output images and model checkpoints') 28 | parser.add_argument("--layer", default="r41", 29 | help='layers for content') 30 | parser.add_argument("--batchSize", type=int,default=8, 31 | help='batch size') 32 | parser.add_argument("--niter", type=int,default=100000, 33 | help='iterations to train the model') 34 | parser.add_argument('--loadSize', type=int, default=512, 35 | help='scale image size') 36 | parser.add_argument('--fineSize', type=int, default=256, 37 | help='crop image size') 38 | parser.add_argument("--lr", type=float, default=1e-3, 39 | help='learning rate') 40 | parser.add_argument("--log_interval", type=int, default=500, 41 | help='log interval') 42 | parser.add_argument("--save_interval", type=int, default=5000, 43 | help='checkpoint save interval') 44 | parser.add_argument("--spn_num", type=int, default=1, 45 | help='number of spn filters') 46 | 47 | ################# PREPARATIONS ################# 48 | opt = parser.parse_args() 49 | opt.cuda = torch.cuda.is_available() 50 | print_options(opt) 51 | 52 | 53 | os.makedirs(opt.outf, exist_ok = True) 54 | 55 | cudnn.benchmark = True 56 | 57 | ################# DATA ################# 58 | content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize) 59 | content_loader_ = torch.utils.data.DataLoader(dataset=content_dataset, 60 | batch_size = opt.batchSize, 61 | shuffle = True, 62 | num_workers = 4, 63 | drop_last = True) 64 | content_loader = iter(content_loader_) 65 | 66 | ################# MODEL ################# 67 | spn = SPN(spn=opt.spn_num) 68 | if(opt.layer == 'r31'): 69 | vgg = encoder3() 70 | dec = decoder3() 71 | elif(opt.layer == 'r41'): 72 | vgg = encoder4() 73 | dec = decoder4() 74 | vgg.load_state_dict(torch.load(opt.vgg_dir)) 75 | dec.load_state_dict(torch.load(opt.decoder_dir)) 76 | 77 | for param in vgg.parameters(): 78 | param.requires_grad = False 79 | for param in dec.parameters(): 80 | param.requires_grad = False 81 | 82 | ################# LOSS & OPTIMIZER ################# 83 | criterion = nn.MSELoss(size_average=False) 84 | #optimizer_spn = optim.SGD(spn.parameters(), opt.lr) 85 | optimizer_spn = optim.Adam(spn.parameters(), opt.lr) 86 | 87 | ################# GLOBAL VARIABLE ################# 88 | contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize) 89 | 90 | ################# GPU ################# 91 | if(opt.cuda): 92 | vgg.cuda() 93 | dec.cuda() 94 | spn.cuda() 95 | contentV = contentV.cuda() 96 | 97 | ################# TRAINING ################# 98 | def adjust_learning_rate(optimizer, iteration): 99 | for param_group in optimizer.param_groups: 100 | param_group['lr'] = opt.lr / (1+iteration*1e-5) 101 | 102 | spn.train() 103 | for iteration in range(1,opt.niter+1): 104 | optimizer_spn.zero_grad() 105 | try: 106 | content,_ = content_loader.next() 107 | except IOError: 108 | content,_ = content_loader.next() 109 | except StopIteration: 110 | content_loader = iter(content_loader_) 111 | content,_ = content_loader.next() 112 | except: 113 | continue 114 | 115 | contentV.resize_(content.size()).copy_(content) 116 | 117 | # forward 118 | cF = vgg(contentV) 119 | transfer = dec(cF['r41']) 120 | 121 | 122 | propagated = spn(transfer,contentV) 123 | loss = criterion(propagated,contentV) 124 | 125 | # backward & optimization 126 | loss.backward() 127 | #nn.utils.clip_grad_norm(spn.parameters(), 1000) 128 | 
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/.DS_Store
--------------------------------------------------------------------------------
/data/content/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/content/.DS_Store
--------------------------------------------------------------------------------
/data/content/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/content/1.jpg
--------------------------------------------------------------------------------
/data/content/chicago.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/content/chicago.png
--------------------------------------------------------------------------------
/data/photo_real/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/.DS_Store
--------------------------------------------------------------------------------
/data/photo_real/content/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/.DS_Store
--------------------------------------------------------------------------------
/data/photo_real/content/images/in16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in16.png
--------------------------------------------------------------------------------
/data/photo_real/content/images/in25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in25.png
--------------------------------------------------------------------------------
/data/photo_real/content/images/in26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in26.png -------------------------------------------------------------------------------- /data/photo_real/content/images/in29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in29.png -------------------------------------------------------------------------------- /data/photo_real/content/images/in3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in3.png -------------------------------------------------------------------------------- /data/photo_real/content/images/in39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in39.png -------------------------------------------------------------------------------- /data/photo_real/content/images/in53.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in53.png -------------------------------------------------------------------------------- /data/photo_real/content/images/in7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in7.png -------------------------------------------------------------------------------- /data/photo_real/contentSeg/in16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in16.png -------------------------------------------------------------------------------- /data/photo_real/contentSeg/in25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in25.png -------------------------------------------------------------------------------- /data/photo_real/contentSeg/in26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in26.png -------------------------------------------------------------------------------- /data/photo_real/contentSeg/in29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in29.png -------------------------------------------------------------------------------- /data/photo_real/contentSeg/in3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in3.png -------------------------------------------------------------------------------- /data/photo_real/contentSeg/in39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in39.png -------------------------------------------------------------------------------- /data/photo_real/contentSeg/in53.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in53.png -------------------------------------------------------------------------------- /data/photo_real/contentSeg/in7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in7.png -------------------------------------------------------------------------------- /data/photo_real/style/images/in16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in16.png -------------------------------------------------------------------------------- /data/photo_real/style/images/in25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in25.png -------------------------------------------------------------------------------- /data/photo_real/style/images/in26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in26.png -------------------------------------------------------------------------------- /data/photo_real/style/images/in29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in29.png -------------------------------------------------------------------------------- /data/photo_real/style/images/in3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in3.png -------------------------------------------------------------------------------- /data/photo_real/style/images/in39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in39.png -------------------------------------------------------------------------------- /data/photo_real/style/images/in53.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in53.png -------------------------------------------------------------------------------- /data/photo_real/style/images/in7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in7.png -------------------------------------------------------------------------------- /data/photo_real/styleSeg/in16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in16.png -------------------------------------------------------------------------------- /data/photo_real/styleSeg/in25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in25.png -------------------------------------------------------------------------------- /data/photo_real/styleSeg/in26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in26.png -------------------------------------------------------------------------------- /data/photo_real/styleSeg/in29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in29.png -------------------------------------------------------------------------------- /data/photo_real/styleSeg/in3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in3.png -------------------------------------------------------------------------------- /data/photo_real/styleSeg/in39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in39.png -------------------------------------------------------------------------------- /data/photo_real/styleSeg/in53.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in53.png -------------------------------------------------------------------------------- /data/photo_real/styleSeg/in7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in7.png -------------------------------------------------------------------------------- /data/style/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/.DS_Store 
-------------------------------------------------------------------------------- /data/style/27.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/27.jpg -------------------------------------------------------------------------------- /data/style/3314.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/3314.jpg -------------------------------------------------------------------------------- /data/style/antimonocromatismo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/antimonocromatismo.jpg -------------------------------------------------------------------------------- /data/style/in2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/in2.jpg -------------------------------------------------------------------------------- /data/style/picasso_self_portrait.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/picasso_self_portrait.jpg -------------------------------------------------------------------------------- /data/style/sketch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/sketch.jpg -------------------------------------------------------------------------------- /data/videos/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/.DS_Store -------------------------------------------------------------------------------- /data/videos/content/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/.DS_Store -------------------------------------------------------------------------------- /data/videos/content/mountain_2/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/.DS_Store -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0001.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0002.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0002.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0003.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0004.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0005.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0006.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0007.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0008.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0009.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0010.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0011.png -------------------------------------------------------------------------------- 
/data/videos/content/mountain_2/frame_0012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0012.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0013.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0013.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0014.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0015.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0016.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0017.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0018.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0019.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0020.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0021.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0021.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0022.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0022.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0023.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0023.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0024.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0025.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0026.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0026.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0027.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0027.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0028.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0028.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0029.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0029.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0030.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0030.png -------------------------------------------------------------------------------- 
/data/videos/content/mountain_2/frame_0031.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0031.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0032.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0032.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0033.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0033.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0034.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0034.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0035.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0035.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0036.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0036.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0037.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0037.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0038.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0038.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0039.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0039.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0040.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0040.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0041.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0041.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0042.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0042.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0043.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0043.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0044.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0044.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0045.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0045.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0046.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0046.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0047.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0047.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0048.png -------------------------------------------------------------------------------- /data/videos/content/mountain_2/frame_0049.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0049.png -------------------------------------------------------------------------------- 
/data/videos/content/mountain_2/frame_0050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0050.png -------------------------------------------------------------------------------- /doc/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/.DS_Store -------------------------------------------------------------------------------- /doc/images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/.DS_Store -------------------------------------------------------------------------------- /doc/images/chicago_27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/chicago_27.png -------------------------------------------------------------------------------- /doc/images/chicago_paste.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/chicago_paste.png -------------------------------------------------------------------------------- /doc/images/content.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/content.gif -------------------------------------------------------------------------------- /doc/images/in5_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/in5_result.png -------------------------------------------------------------------------------- /doc/images/photo_content.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/photo_content.png -------------------------------------------------------------------------------- /doc/images/test.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/test.gif -------------------------------------------------------------------------------- /libs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/.DS_Store -------------------------------------------------------------------------------- /libs/Criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class styleLoss(nn.Module): 5 | def forward(self,input,target): 6 | ib,ic,ih,iw = input.size() 7 | iF = input.view(ib,ic,-1) 8 | iMean = torch.mean(iF,dim=2) 9 | 
iCov = GramMatrix()(input) 10 | 11 | tb,tc,th,tw = target.size() 12 | tF = target.view(tb,tc,-1) 13 | tMean = torch.mean(tF,dim=2) 14 | tCov = GramMatrix()(target) 15 | 16 | loss = nn.MSELoss(size_average=False)(iMean,tMean) + nn.MSELoss(size_average=False)(iCov,tCov) 17 | return loss/tb 18 | 19 | class GramMatrix(nn.Module): 20 | def forward(self,input): 21 | b, c, h, w = input.size() 22 | f = input.view(b,c,h*w) # bxcx(hxw) 23 | # torch.bmm(batch1, batch2, out=None) # 24 | # batch1: bxmxp, batch2: bxpxn -> bxmxn # 25 | G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc 26 | return G.div_(c*h*w) 27 | 28 | class LossCriterion(nn.Module): 29 | def __init__(self,style_layers,content_layers,style_weight,content_weight): 30 | super(LossCriterion,self).__init__() 31 | 32 | self.style_layers = style_layers 33 | self.content_layers = content_layers 34 | self.style_weight = style_weight 35 | self.content_weight = content_weight 36 | 37 | self.styleLosses = [styleLoss()] * len(style_layers) 38 | self.contentLosses = [nn.MSELoss()] * len(content_layers) 39 | 40 | def forward(self,tF,sF,cF): 41 | # content loss 42 | totalContentLoss = 0 43 | for i,layer in enumerate(self.content_layers): 44 | cf_i = cF[layer] 45 | cf_i = cf_i.detach() 46 | tf_i = tF[layer] 47 | loss_i = self.contentLosses[i] 48 | totalContentLoss += loss_i(tf_i,cf_i) 49 | totalContentLoss = totalContentLoss * self.content_weight 50 | 51 | # style loss 52 | totalStyleLoss = 0 53 | for i,layer in enumerate(self.style_layers): 54 | sf_i = sF[layer] 55 | sf_i = sf_i.detach() 56 | tf_i = tF[layer] 57 | loss_i = self.styleLosses[i] 58 | totalStyleLoss += loss_i(tf_i,sf_i) 59 | totalStyleLoss = totalStyleLoss * self.style_weight 60 | loss = totalStyleLoss + totalContentLoss 61 | 62 | return loss,totalStyleLoss,totalContentLoss 63 | -------------------------------------------------------------------------------- /libs/Loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | 6 | def is_image_file(filename): 7 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) 8 | 9 | def default_loader(path): 10 | return Image.open(path).convert('RGB') 11 | 12 | class Dataset(data.Dataset): 13 | def __init__(self,dataPath,loadSize,fineSize,test=False,video=False): 14 | super(Dataset,self).__init__() 15 | self.dataPath = dataPath 16 | self.image_list = [x for x in os.listdir(dataPath) if is_image_file(x)] 17 | self.image_list = sorted(self.image_list) 18 | if(video): 19 | self.image_list = sorted(self.image_list) 20 | if not test: 21 | self.transform = transforms.Compose([ 22 | transforms.Resize(fineSize), 23 | transforms.RandomCrop(fineSize), 24 | transforms.RandomHorizontalFlip(), 25 | transforms.ToTensor()]) 26 | else: 27 | self.transform = transforms.Compose([ 28 | transforms.Resize(fineSize), 29 | transforms.ToTensor()]) 30 | 31 | self.test = test 32 | 33 | def __getitem__(self,index): 34 | dataPath = os.path.join(self.dataPath,self.image_list[index]) 35 | 36 | Img = default_loader(dataPath) 37 | ImgA = self.transform(Img) 38 | 39 | imgName = self.image_list[index] 40 | imgName = imgName.split('.')[0] 41 | return ImgA,imgName 42 | 43 | def __len__(self): 44 | return len(self.image_list) 45 | -------------------------------------------------------------------------------- /libs/LoaderPhotoReal.py: 
-------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torchvision.transforms as transforms 3 | import torchvision.utils as vutils 4 | import torch.utils.data as data 5 | from os import listdir 6 | from os.path import join 7 | import numpy as np 8 | import torch 9 | import os 10 | import torch.nn as nn 11 | from torch.autograd import Variable 12 | import numpy as np 13 | from libs.utils import whiten 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) 17 | 18 | def default_loader(path,fineSize): 19 | img = Image.open(path).convert('RGB') 20 | w,h = img.size 21 | if(w < h): 22 | neww = fineSize 23 | newh = h * neww / w 24 | newh = int(newh / 8) * 8 25 | else: 26 | newh = fineSize 27 | neww = w * newh / h 28 | neww = int(neww / 8) * 8 29 | img = img.resize((neww,newh)) 30 | return img 31 | 32 | def MaskHelper(seg,color): 33 | # green 34 | mask = torch.Tensor() 35 | if(color == 'green'): 36 | mask = torch.lt(seg[0],0.1) 37 | mask = torch.mul(mask,torch.gt(seg[1],1-0.1)) 38 | mask = torch.mul(mask,torch.lt(seg[2],0.1)) 39 | elif(color == 'black'): 40 | mask = torch.lt(seg[0], 0.1) 41 | mask = torch.mul(mask,torch.lt(seg[1], 0.1)) 42 | mask = torch.mul(mask,torch.lt(seg[2], 0.1)) 43 | elif(color == 'white'): 44 | mask = torch.gt(seg[0], 1-0.1) 45 | mask = torch.mul(mask,torch.gt(seg[1], 1-0.1)) 46 | mask = torch.mul(mask,torch.gt(seg[2], 1-0.1)) 47 | elif(color == 'red'): 48 | mask = torch.gt(seg[0], 1-0.1) 49 | mask = torch.mul(mask,torch.lt(seg[1], 0.1)) 50 | mask = torch.mul(mask,torch.lt(seg[2], 0.1)) 51 | elif(color == 'blue'): 52 | mask = torch.lt(seg[0], 0.1) 53 | mask = torch.mul(mask,torch.lt(seg[1], 0.1)) 54 | mask = torch.mul(mask,torch.gt(seg[2], 1-0.1)) 55 | elif(color == 'yellow'): 56 | mask = torch.gt(seg[0], 1-0.1) 57 | mask = torch.mul(mask,torch.gt(seg[1], 1-0.1)) 58 | mask = torch.mul(mask,torch.lt(seg[2], 0.1)) 59 | elif(color == 'grey'): 60 | mask = torch.lt(seg[0], 0.1) 61 | mask = torch.mul(mask,torch.lt(seg[1], 0.1)) 62 | mask = torch.mul(mask,torch.lt(seg[2], 0.1)) 63 | elif(color == 'lightblue'): 64 | mask = torch.lt(seg[0], 0.1) 65 | mask = torch.mul(mask,torch.gt(seg[1], 1-0.1)) 66 | mask = torch.mul(mask,torch.gt(seg[2], 1-0.1)) 67 | elif(color == 'purple'): 68 | mask = torch.gt(seg[0], 1-0.1) 69 | mask = torch.mul(mask,torch.lt(seg[1], 0.1)) 70 | mask = torch.mul(mask,torch.gt(seg[2], 1-0.1)) 71 | else: 72 | print('MaskHelper(): color not recognized, color = ' + color) 73 | return mask.float() 74 | 75 | def ExtractMask(Seg): 76 | # Given segmentation for content and style, we get a list of segmentation for each color 77 | ''' 78 | Test Code: 79 | content_masks,style_masks = ExtractMask(contentSegImg,styleSegImg) 80 | for i,mask in enumerate(content_masks): 81 | vutils.save_image(mask,'samples/content_%d.png' % (i),normalize=True) 82 | for i,mask in enumerate(style_masks): 83 | vutils.save_image(mask,'samples/style_%d.png' % (i),normalize=True) 84 | ''' 85 | color_codes = ['blue', 'green', 'black', 'white', 'red', 'yellow', 'grey', 'lightblue', 'purple'] 86 | masks = [] 87 | for color in color_codes: 88 | mask = MaskHelper(Seg,color) 89 | masks.append(mask) 90 | return masks 91 | 92 | def calculate_size(h,w,fineSize): 93 | if(h > w): 94 | newh = fineSize 95 | neww = int(w * 1.0 * newh / h) 96 | else: 97 | neww = fineSize 98 | newh = int(h * 1.0 * neww / w) 99 | newh = (newh // 8) * 8 100 | neww = (neww // 8) * 8 101 | return 
neww, newh 102 | 103 | class Dataset(data.Dataset): 104 | def __init__(self,contentPath,stylePath,contentSegPath,styleSegPath,fineSize): 105 | super(Dataset,self).__init__() 106 | self.contentPath = contentPath 107 | self.image_list = [x for x in listdir(contentPath) if is_image_file(x)] 108 | self.stylePath = stylePath 109 | self.contentSegPath = contentSegPath 110 | self.styleSegPath = styleSegPath 111 | self.fineSize = fineSize 112 | 113 | def __getitem__(self,index): 114 | contentImgPath = os.path.join(self.contentPath,self.image_list[index]) 115 | styleImgPath = os.path.join(self.stylePath,self.image_list[index]) 116 | contentImg = default_loader(contentImgPath,self.fineSize) 117 | styleImg = default_loader(styleImgPath,self.fineSize) 118 | 119 | try: 120 | contentSegImgPath = os.path.join(self.contentSegPath,self.image_list[index]) 121 | contentSegImg = default_loader(contentSegImgPath,self.fineSize) 122 | except : 123 | print('no mask provided, fake a whole black one') 124 | contentSegImg = Image.new('RGB', (contentImg.size)) 125 | 126 | try: 127 | styleSegImgPath = os.path.join(self.styleSegPath,self.image_list[index]) 128 | styleSegImg = default_loader(styleSegImgPath,self.fineSize) 129 | except : 130 | print('no mask provided, fake a whole black one') 131 | styleSegImg = Image.new('RGB', (styleImg.size)) 132 | 133 | 134 | hs, ws = styleImg.size 135 | newhs, newws = calculate_size(hs,ws,self.fineSize) 136 | 137 | transform = transforms.Compose([ 138 | transforms.Resize((newhs, newws)), 139 | transforms.ToTensor()]) 140 | # Turning segmentation images into masks 141 | styleSegImg = transform(styleSegImg) 142 | styleImgArbi = transform(styleImg) 143 | 144 | hc, wc = contentImg.size 145 | newhc, newwc = calculate_size(hc,wc,self.fineSize) 146 | 147 | transform = transforms.Compose([ 148 | transforms.Resize((newhc, newwc)), 149 | transforms.ToTensor()]) 150 | contentSegImg = transform(contentSegImg) 151 | contentImgArbi = transform(contentImg) 152 | 153 | content_masks = ExtractMask(contentSegImg) 154 | style_masks = ExtractMask(styleSegImg) 155 | 156 | ImgW = whiten(contentImgArbi.view(3,-1).double()) 157 | ImgW = ImgW.view(contentImgArbi.size()).float() 158 | 159 | return contentImgArbi.squeeze(0),styleImgArbi.squeeze(0),ImgW,content_masks,style_masks,self.image_list[index] 160 | 161 | def __len__(self): 162 | return len(self.image_list) 163 | -------------------------------------------------------------------------------- /libs/Matrix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class CNN(nn.Module): 5 | def __init__(self,layer,matrixSize=32): 6 | super(CNN,self).__init__() 7 | if(layer == 'r31'): 8 | # 256x64x64 9 | self.convs = nn.Sequential(nn.Conv2d(256,128,3,1,1), 10 | nn.ReLU(inplace=True), 11 | nn.Conv2d(128,64,3,1,1), 12 | nn.ReLU(inplace=True), 13 | nn.Conv2d(64,matrixSize,3,1,1)) 14 | elif(layer == 'r41'): 15 | # 512x32x32 16 | self.convs = nn.Sequential(nn.Conv2d(512,256,3,1,1), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(256,128,3,1,1), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(128,matrixSize,3,1,1)) 21 | 22 | # 32x8x8 23 | self.fc = nn.Linear(matrixSize*matrixSize,matrixSize*matrixSize) 24 | #self.fc = nn.Linear(32*64,256*256) 25 | 26 | def forward(self,x): 27 | out = self.convs(x) 28 | # 32x8x8 29 | b,c,h,w = out.size() 30 | out = out.view(b,c,-1) 31 | # 32x64 32 | out = torch.bmm(out,out.transpose(1,2)).div(h*w) 33 | # 32x32 34 | out = out.view(out.size(0),-1) 35 | return 
self.fc(out) 36 | 37 | class MulLayer(nn.Module): 38 | def __init__(self,layer,matrixSize=32): 39 | super(MulLayer,self).__init__() 40 | self.snet = CNN(layer,matrixSize) 41 | self.cnet = CNN(layer,matrixSize) 42 | self.matrixSize = matrixSize 43 | 44 | if(layer == 'r41'): 45 | self.compress = nn.Conv2d(512,matrixSize,1,1,0) 46 | self.unzip = nn.Conv2d(matrixSize,512,1,1,0) 47 | elif(layer == 'r31'): 48 | self.compress = nn.Conv2d(256,matrixSize,1,1,0) 49 | self.unzip = nn.Conv2d(matrixSize,256,1,1,0) 50 | self.transmatrix = None 51 | 52 | def forward(self,cF,sF,trans=True): 53 | cFBK = cF.clone() 54 | cb,cc,ch,cw = cF.size() 55 | cFF = cF.view(cb,cc,-1) 56 | cMean = torch.mean(cFF,dim=2,keepdim=True) 57 | cMean = cMean.unsqueeze(3) 58 | cMean = cMean.expand_as(cF) 59 | cF = cF - cMean 60 | 61 | sb,sc,sh,sw = sF.size() 62 | sFF = sF.view(sb,sc,-1) 63 | sMean = torch.mean(sFF,dim=2,keepdim=True) 64 | sMean = sMean.unsqueeze(3) 65 | sMeanC = sMean.expand_as(cF) 66 | sMeanS = sMean.expand_as(sF) 67 | sF = sF - sMeanS 68 | 69 | 70 | compress_content = self.compress(cF) 71 | b,c,h,w = compress_content.size() 72 | compress_content = compress_content.view(b,c,-1) 73 | 74 | if(trans): 75 | cMatrix = self.cnet(cF) 76 | sMatrix = self.snet(sF) 77 | 78 | sMatrix = sMatrix.view(sMatrix.size(0),self.matrixSize,self.matrixSize) 79 | cMatrix = cMatrix.view(cMatrix.size(0),self.matrixSize,self.matrixSize) 80 | transmatrix = torch.bmm(sMatrix,cMatrix) 81 | transfeature = torch.bmm(transmatrix,compress_content).view(b,c,h,w) 82 | out = self.unzip(transfeature.view(b,c,h,w)) 83 | out = out + sMeanC 84 | return out, transmatrix 85 | else: 86 | out = self.unzip(compress_content.view(b,c,h,w)) 87 | out = out + cMean 88 | return out 89 | -------------------------------------------------------------------------------- /libs/MatrixTest.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import cv2 6 | from torch.autograd import Variable 7 | import torchvision.utils as vutils 8 | 9 | 10 | class CNN(nn.Module): 11 | def __init__(self,layer,matrixSize=32): 12 | super(CNN,self).__init__() 13 | # 256x64x64 14 | if(layer == 'r31'): 15 | self.convs = nn.Sequential(nn.Conv2d(256,128,3,1,1), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(128,64,3,1,1), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(64,matrixSize,3,1,1)) 20 | elif(layer == 'r41'): 21 | # 512x32x32 22 | self.convs = nn.Sequential(nn.Conv2d(512,256,3,1,1), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(256,128,3,1,1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(128,matrixSize,3,1,1)) 27 | self.fc = nn.Linear(32*32,32*32) 28 | 29 | def forward(self,x,masks,style=False): 30 | color_code_number = 9 31 | xb,xc,xh,xw = x.size() 32 | x = x.view(xc,-1) 33 | feature_sub_mean = x.clone() 34 | for i in range(color_code_number): 35 | mask = masks[i].clone().squeeze(0) 36 | mask = cv2.resize(mask.numpy(),(xw,xh),interpolation=cv2.INTER_NEAREST) 37 | mask = torch.FloatTensor(mask) 38 | mask = mask.long() 39 | if(torch.sum(mask) >= 10): 40 | mask = mask.view(-1) 41 | 42 | # dilation here 43 | """ 44 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(5,5)) 45 | mask = mask.cpu().numpy() 46 | mask = cv2.dilate(mask.astype(np.float32), kernel) 47 | mask = torch.from_numpy(mask) 48 | mask = mask.squeeze() 49 | """ 50 | 51 | fgmask = (mask>0).nonzero().squeeze(1) 52 | fgmask = fgmask.cuda() 53 | selectFeature = torch.index_select(x,1,fgmask) # 32x96 54 
| # subtract mean 55 | f_mean = torch.mean(selectFeature,1) 56 | f_mean = f_mean.unsqueeze(1).expand_as(selectFeature) 57 | selectFeature = selectFeature - f_mean 58 | feature_sub_mean.index_copy_(1,fgmask,selectFeature) 59 | 60 | feature = self.convs(feature_sub_mean.view(xb,xc,xh,xw)) 61 | # 32x16x16 62 | b,c,h,w = feature.size() 63 | transMatrices = {} 64 | feature = feature.view(c,-1) 65 | 66 | for i in range(color_code_number): 67 | mask = masks[i].clone().squeeze(0) 68 | mask = cv2.resize(mask.numpy(),(w,h),interpolation=cv2.INTER_NEAREST) 69 | mask = torch.FloatTensor(mask) 70 | mask = mask.long() 71 | if(torch.sum(mask) >= 10): 72 | mask = mask.view(-1) 73 | fgmask = Variable((mask==1).nonzero().squeeze(1)) 74 | fgmask = fgmask.cuda() 75 | selectFeature = torch.index_select(feature,1,fgmask) # 32x96 76 | tc,tN = selectFeature.size() 77 | 78 | covMatrix = torch.mm(selectFeature,selectFeature.transpose(0,1)).div(tN) 79 | transmatrix = self.fc(covMatrix.view(-1)) 80 | transMatrices[i] = transmatrix 81 | return transMatrices,feature_sub_mean 82 | 83 | class MulLayer(nn.Module): 84 | def __init__(self,layer,matrixSize=32): 85 | super(MulLayer,self).__init__() 86 | self.snet = CNN(layer) 87 | self.cnet = CNN(layer) 88 | self.matrixSize = matrixSize 89 | 90 | if(layer == 'r41'): 91 | self.compress = nn.Conv2d(512,matrixSize,1,1,0) 92 | self.unzip = nn.Conv2d(matrixSize,512,1,1,0) 93 | elif(layer == 'r31'): 94 | self.compress = nn.Conv2d(256,matrixSize,1,1,0) 95 | self.unzip = nn.Conv2d(matrixSize,256,1,1,0) 96 | 97 | def forward(self,cF,sF,cmasks,smasks): 98 | 99 | sb,sc,sh,sw = sF.size() 100 | 101 | sMatrices,sF_sub_mean = self.snet(sF,smasks,style=True) 102 | cMatrices,cF_sub_mean = self.cnet(cF,cmasks,style=False) 103 | 104 | compress_content = self.compress(cF_sub_mean.view(cF.size())) 105 | cb,cc,ch,cw = compress_content.size() 106 | compress_content = compress_content.view(cc,-1) 107 | transfeature = compress_content.clone() 108 | color_code_number = 9 109 | finalSMean = Variable(torch.zeros(cF.size()).cuda(0)) 110 | finalSMean = finalSMean.view(sc,-1) 111 | for i in range(color_code_number): 112 | cmask = cmasks[i].clone().squeeze(0) 113 | smask = smasks[i].clone().squeeze(0) 114 | 115 | cmask = cv2.resize(cmask.numpy(),(cw,ch),interpolation=cv2.INTER_NEAREST) 116 | cmask = torch.FloatTensor(cmask) 117 | cmask = cmask.long() 118 | smask = cv2.resize(smask.numpy(),(sw,sh),interpolation=cv2.INTER_NEAREST) 119 | smask = torch.FloatTensor(smask) 120 | smask = smask.long() 121 | if(torch.sum(cmask) >= 10 and torch.sum(smask) >= 10 122 | and (i in sMatrices) and (i in cMatrices)): 123 | cmask = cmask.view(-1) 124 | fgcmask = Variable((cmask==1).nonzero().squeeze(1)) 125 | fgcmask = fgcmask.cuda() 126 | 127 | smask = smask.view(-1) 128 | fgsmask = Variable((smask==1).nonzero().squeeze(1)) 129 | fgsmask = fgsmask.cuda() 130 | 131 | sFF = sF.view(sc,-1) 132 | sFF_select = torch.index_select(sFF,1,fgsmask) 133 | sMean = torch.mean(sFF_select,dim=1,keepdim=True) 134 | sMean = sMean.view(1,sc,1,1) 135 | sMean = sMean.expand_as(cF) 136 | 137 | sMatrix = sMatrices[i] 138 | cMatrix = cMatrices[i] 139 | 140 | sMatrix = sMatrix.view(self.matrixSize,self.matrixSize) 141 | cMatrix = cMatrix.view(self.matrixSize,self.matrixSize) 142 | 143 | transmatrix = torch.mm(sMatrix,cMatrix) # (C*C) 144 | 145 | compress_content_select = torch.index_select(compress_content,1,fgcmask) 146 | 147 | transfeatureFG = torch.mm(transmatrix,compress_content_select) 148 | transfeature.index_copy_(1,fgcmask,transfeatureFG) 
149 | 150 | sMean = sMean.contiguous() 151 | sMean_select = torch.index_select(sMean.view(sc,-1),1,fgcmask) 152 | finalSMean.index_copy_(1,fgcmask,sMean_select) 153 | out = self.unzip(transfeature.view(cb,cc,ch,cw)) 154 | return out + finalSMean.view(out.size()) 155 | -------------------------------------------------------------------------------- /libs/SPN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models import vgg16 4 | from torch.autograd import Variable 5 | from collections import OrderedDict 6 | import torch.nn.functional as F 7 | import sys 8 | sys.path.append('../') 9 | from libs.pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind 10 | 11 | class spn_block(nn.Module): 12 | def __init__(self, horizontal, reverse): 13 | super(spn_block, self).__init__() 14 | self.propagator = GateRecurrent2dnoind(horizontal,reverse) 15 | 16 | def forward(self,x,G1,G2,G3): 17 | sum_abs = G1.abs() + G2.abs() + G3.abs() 18 | sum_abs.data[sum_abs.data == 0] = 1e-6 19 | mask_need_norm = sum_abs.ge(1) 20 | mask_need_norm = mask_need_norm.float() 21 | G1_norm = torch.div(G1, sum_abs) 22 | G2_norm = torch.div(G2, sum_abs) 23 | G3_norm = torch.div(G3, sum_abs) 24 | 25 | G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm 26 | G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm 27 | G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm 28 | 29 | return self.propagator(x,G1,G2,G3) 30 | 31 | class VGG(nn.Module): 32 | def __init__(self,nf): 33 | super(VGG,self).__init__() 34 | self.conv1 = nn.Conv2d(3,nf,3,padding = 1) 35 | # 256 x 256 36 | self.pool1 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1) 37 | self.conv2 = nn.Conv2d(nf,nf*2,3,padding = 1) 38 | # 128 x 128 39 | self.pool2 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1) 40 | self.conv3 = nn.Conv2d(nf*2,nf*4,3,padding = 1) 41 | # 64 x 64 42 | self.pool3 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1) 43 | # 32 x 32 44 | self.conv4 = nn.Conv2d(nf*4,nf*8,3,padding = 1) 45 | 46 | def forward(self,x): 47 | output = {} 48 | output['conv1'] = self.conv1(x) 49 | x = F.relu(output['conv1']) 50 | x = self.pool1(x) 51 | output['conv2'] = self.conv2(x) 52 | # 128 x 128 53 | x = F.relu(output['conv2']) 54 | x = self.pool2(x) 55 | output['conv3'] = self.conv3(x) 56 | # 64 x 64 57 | x = F.relu(output['conv3']) 58 | output['pool3'] = self.pool3(x) 59 | # 32 x 32 60 | output['conv4'] = self.conv4(output['pool3']) 61 | return output 62 | 63 | class Decoder(nn.Module): 64 | def __init__(self,nf=32,spn=1): 65 | super(Decoder,self).__init__() 66 | # 32 x 32 67 | self.layer0 = nn.Conv2d(nf*8,nf*4,1,1,0) # edge_conv5 68 | self.layer1 = nn.Upsample(scale_factor=2,mode='bilinear') 69 | self.layer2 = nn.Sequential(nn.Conv2d(nf*4,nf*4,3,1,1), # edge_conv8 70 | nn.ELU(inplace=True)) 71 | # 64 x 64 72 | self.layer3 = nn.Upsample(scale_factor=2,mode='bilinear') 73 | self.layer4 = nn.Sequential(nn.Conv2d(nf*4,nf*2,3,1,1), # edge_conv8 74 | nn.ELU(inplace=True)) 75 | # 128 x 128 76 | self.layer5 = nn.Upsample(scale_factor=2,mode='bilinear') 77 | self.layer6 = nn.Sequential(nn.Conv2d(nf*2,nf,3,1,1), # edge_conv8 78 | nn.ELU(inplace=True)) 79 | if(spn == 1): 80 | self.layer7 = nn.Conv2d(nf,nf*12,3,1,1) 81 | else: 82 | self.layer7 = nn.Conv2d(nf,nf*24,3,1,1) 83 | self.spn = spn 84 | # 256 x 256 85 | 86 | def forward(self,encode_feature): 87 | output = {} 88 | output['0'] = 
self.layer0(encode_feature['conv4']) 89 | output['1'] = self.layer1(output['0']) 90 | 91 | output['2'] = self.layer2(output['1']) 92 | output['2res'] = output['2'] + encode_feature['conv3'] 93 | # 64 x 64 94 | 95 | output['3'] = self.layer3(output['2res']) 96 | output['4'] = self.layer4(output['3']) 97 | output['4res'] = output['4'] + encode_feature['conv2'] 98 | # 128 x 128 99 | 100 | output['5'] = self.layer5(output['4res']) 101 | output['6'] = self.layer6(output['5']) 102 | output['6res'] = output['6'] + encode_feature['conv1'] 103 | 104 | output['7'] = self.layer7(output['6res']) 105 | 106 | return output['7'] 107 | 108 | 109 | class SPN(nn.Module): 110 | def __init__(self,nf=32,spn=1): 111 | super(SPN,self).__init__() 112 | # conv for mask 113 | self.mask_conv = nn.Conv2d(3,nf,3,1,1) 114 | 115 | # guidance network 116 | self.encoder = VGG(nf) 117 | self.decoder = Decoder(nf,spn) 118 | 119 | # spn blocks 120 | self.left_right = spn_block(True,False) 121 | self.right_left = spn_block(True,True) 122 | self.top_down = spn_block(False, False) 123 | self.down_top = spn_block(False,True) 124 | 125 | # post upsample 126 | self.post = nn.Conv2d(nf,3,3,1,1) 127 | self.nf = nf 128 | 129 | def forward(self,x,rgb): 130 | # feature for mask 131 | X = self.mask_conv(x) 132 | 133 | # guidance 134 | features = self.encoder(rgb) 135 | guide = self.decoder(features) 136 | 137 | G = torch.split(guide,self.nf,1) 138 | out1 = self.left_right(X,G[0],G[1],G[2]) 139 | out2 = self.right_left(X,G[3],G[4],G[5]) 140 | out3 = self.top_down(X,G[6],G[7],G[8]) 141 | out4 = self.down_top(X,G[9],G[10],G[11]) 142 | 143 | out = torch.max(out1,out2) 144 | out = torch.max(out,out3) 145 | out = torch.max(out,out4) 146 | 147 | return self.post(out) 148 | 149 | if __name__ == '__main__': 150 | spn = SPN() 151 | spn = spn.cuda() 152 | for i in range(100): 153 | x = Variable(torch.Tensor(1,3,256,256)).cuda() 154 | rgb = Variable(torch.Tensor(1,3,256,256)).cuda() 155 | output = spn(x,rgb) 156 | print(output.size()) 157 | -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/__init__.py -------------------------------------------------------------------------------- /libs/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class encoder3(nn.Module): 5 | def __init__(self): 6 | super(encoder3,self).__init__() 7 | # vgg 8 | # 224 x 224 9 | self.conv1 = nn.Conv2d(3,3,1,1,0) 10 | self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) 11 | # 226 x 226 12 | 13 | self.conv2 = nn.Conv2d(3,64,3,1,0) 14 | self.relu2 = nn.ReLU(inplace=True) 15 | # 224 x 224 16 | 17 | self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) 18 | self.conv3 = nn.Conv2d(64,64,3,1,0) 19 | self.relu3 = nn.ReLU(inplace=True) 20 | # 224 x 224 21 | 22 | self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True) 23 | # 112 x 112 24 | 25 | self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1)) 26 | self.conv4 = nn.Conv2d(64,128,3,1,0) 27 | self.relu4 = nn.ReLU(inplace=True) 28 | # 112 x 112 29 | 30 | self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) 31 | self.conv5 = nn.Conv2d(128,128,3,1,0) 32 | self.relu5 = nn.ReLU(inplace=True) 33 | # 112 x 112 34 | 35 | self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True) 36 | # 56 x 56 37 | 
38 | self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) 39 | self.conv6 = nn.Conv2d(128,256,3,1,0) 40 | self.relu6 = nn.ReLU(inplace=True) 41 | # 56 x 56 42 | def forward(self,x): 43 | out = self.conv1(x) 44 | out = self.reflecPad1(out) 45 | out = self.conv2(out) 46 | out = self.relu2(out) 47 | out = self.reflecPad3(out) 48 | out = self.conv3(out) 49 | pool1 = self.relu3(out) 50 | out,pool_idx = self.maxPool(pool1) 51 | out = self.reflecPad4(out) 52 | out = self.conv4(out) 53 | out = self.relu4(out) 54 | out = self.reflecPad5(out) 55 | out = self.conv5(out) 56 | pool2 = self.relu5(out) 57 | out,pool_idx2 = self.maxPool2(pool2) 58 | out = self.reflecPad6(out) 59 | out = self.conv6(out) 60 | out = self.relu6(out) 61 | return out 62 | 63 | class decoder3(nn.Module): 64 | def __init__(self): 65 | super(decoder3,self).__init__() 66 | # decoder 67 | self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) 68 | self.conv7 = nn.Conv2d(256,128,3,1,0) 69 | self.relu7 = nn.ReLU(inplace=True) 70 | # 56 x 56 71 | 72 | self.unpool = nn.UpsamplingNearest2d(scale_factor=2) 73 | # 112 x 112 74 | 75 | self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1)) 76 | self.conv8 = nn.Conv2d(128,128,3,1,0) 77 | self.relu8 = nn.ReLU(inplace=True) 78 | # 112 x 112 79 | 80 | self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1)) 81 | self.conv9 = nn.Conv2d(128,64,3,1,0) 82 | self.relu9 = nn.ReLU(inplace=True) 83 | 84 | self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) 85 | # 224 x 224 86 | 87 | self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1)) 88 | self.conv10 = nn.Conv2d(64,64,3,1,0) 89 | self.relu10 = nn.ReLU(inplace=True) 90 | 91 | self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) 92 | self.conv11 = nn.Conv2d(64,3,3,1,0) 93 | 94 | def forward(self,x): 95 | output = {} 96 | out = self.reflecPad7(x) 97 | out = self.conv7(out) 98 | out = self.relu7(out) 99 | out = self.unpool(out) 100 | out = self.reflecPad8(out) 101 | out = self.conv8(out) 102 | out = self.relu8(out) 103 | out = self.reflecPad9(out) 104 | out = self.conv9(out) 105 | out_relu9 = self.relu9(out) 106 | out = self.unpool2(out_relu9) 107 | out = self.reflecPad10(out) 108 | out = self.conv10(out) 109 | out = self.relu10(out) 110 | out = self.reflecPad11(out) 111 | out = self.conv11(out) 112 | return out 113 | 114 | class encoder4(nn.Module): 115 | def __init__(self): 116 | super(encoder4,self).__init__() 117 | # vgg 118 | # 224 x 224 119 | self.conv1 = nn.Conv2d(3,3,1,1,0) 120 | self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) 121 | # 226 x 226 122 | 123 | self.conv2 = nn.Conv2d(3,64,3,1,0) 124 | self.relu2 = nn.ReLU(inplace=True) 125 | # 224 x 224 126 | 127 | self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) 128 | self.conv3 = nn.Conv2d(64,64,3,1,0) 129 | self.relu3 = nn.ReLU(inplace=True) 130 | # 224 x 224 131 | 132 | self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2) 133 | # 112 x 112 134 | 135 | self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1)) 136 | self.conv4 = nn.Conv2d(64,128,3,1,0) 137 | self.relu4 = nn.ReLU(inplace=True) 138 | # 112 x 112 139 | 140 | self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) 141 | self.conv5 = nn.Conv2d(128,128,3,1,0) 142 | self.relu5 = nn.ReLU(inplace=True) 143 | # 112 x 112 144 | 145 | self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2) 146 | # 56 x 56 147 | 148 | self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) 149 | self.conv6 = nn.Conv2d(128,256,3,1,0) 150 | self.relu6 = nn.ReLU(inplace=True) 151 | # 56 x 56 152 | 153 | self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) 154 | self.conv7 = nn.Conv2d(256,256,3,1,0) 155 | self.relu7 = 
nn.ReLU(inplace=True) 156 | # 56 x 56 157 | 158 | self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1)) 159 | self.conv8 = nn.Conv2d(256,256,3,1,0) 160 | self.relu8 = nn.ReLU(inplace=True) 161 | # 56 x 56 162 | 163 | self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1)) 164 | self.conv9 = nn.Conv2d(256,256,3,1,0) 165 | self.relu9 = nn.ReLU(inplace=True) 166 | # 56 x 56 167 | 168 | self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2) 169 | # 28 x 28 170 | 171 | self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1)) 172 | self.conv10 = nn.Conv2d(256,512,3,1,0) 173 | self.relu10 = nn.ReLU(inplace=True) 174 | # 28 x 28 175 | def forward(self,x,sF=None,matrix11=None,matrix21=None,matrix31=None): 176 | output = {} 177 | out = self.conv1(x) 178 | out = self.reflecPad1(out) 179 | out = self.conv2(out) 180 | output['r11'] = self.relu2(out) 181 | out = self.reflecPad7(output['r11']) 182 | 183 | out = self.conv3(out) 184 | output['r12'] = self.relu3(out) 185 | 186 | output['p1'] = self.maxPool(output['r12']) 187 | out = self.reflecPad4(output['p1']) 188 | out = self.conv4(out) 189 | output['r21'] = self.relu4(out) 190 | out = self.reflecPad7(output['r21']) 191 | 192 | out = self.conv5(out) 193 | output['r22'] = self.relu5(out) 194 | 195 | output['p2'] = self.maxPool2(output['r22']) 196 | out = self.reflecPad6(output['p2']) 197 | out = self.conv6(out) 198 | output['r31'] = self.relu6(out) 199 | if(matrix31 is not None): 200 | feature3,transmatrix3 = matrix31(output['r31'],sF['r31']) 201 | out = self.reflecPad7(feature3) 202 | else: 203 | out = self.reflecPad7(output['r31']) 204 | out = self.conv7(out) 205 | output['r32'] = self.relu7(out) 206 | 207 | out = self.reflecPad8(output['r32']) 208 | out = self.conv8(out) 209 | output['r33'] = self.relu8(out) 210 | 211 | out = self.reflecPad9(output['r33']) 212 | out = self.conv9(out) 213 | output['r34'] = self.relu9(out) 214 | 215 | output['p3'] = self.maxPool3(output['r34']) 216 | out = self.reflecPad10(output['p3']) 217 | out = self.conv10(out) 218 | output['r41'] = self.relu10(out) 219 | 220 | return output 221 | 222 | class decoder4(nn.Module): 223 | def __init__(self): 224 | super(decoder4,self).__init__() 225 | # decoder 226 | self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) 227 | self.conv11 = nn.Conv2d(512,256,3,1,0) 228 | self.relu11 = nn.ReLU(inplace=True) 229 | # 28 x 28 230 | 231 | self.unpool = nn.UpsamplingNearest2d(scale_factor=2) 232 | # 56 x 56 233 | 234 | self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1)) 235 | self.conv12 = nn.Conv2d(256,256,3,1,0) 236 | self.relu12 = nn.ReLU(inplace=True) 237 | # 56 x 56 238 | 239 | self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1)) 240 | self.conv13 = nn.Conv2d(256,256,3,1,0) 241 | self.relu13 = nn.ReLU(inplace=True) 242 | # 56 x 56 243 | 244 | self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1)) 245 | self.conv14 = nn.Conv2d(256,256,3,1,0) 246 | self.relu14 = nn.ReLU(inplace=True) 247 | # 56 x 56 248 | 249 | self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1)) 250 | self.conv15 = nn.Conv2d(256,128,3,1,0) 251 | self.relu15 = nn.ReLU(inplace=True) 252 | # 56 x 56 253 | 254 | self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) 255 | # 112 x 112 256 | 257 | self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1)) 258 | self.conv16 = nn.Conv2d(128,128,3,1,0) 259 | self.relu16 = nn.ReLU(inplace=True) 260 | # 112 x 112 261 | 262 | self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1)) 263 | self.conv17 = nn.Conv2d(128,64,3,1,0) 264 | self.relu17 = nn.ReLU(inplace=True) 265 | # 112 x 112 266 | 267 | self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2) 
268 | # 224 x 224 269 | 270 | self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1)) 271 | self.conv18 = nn.Conv2d(64,64,3,1,0) 272 | self.relu18 = nn.ReLU(inplace=True) 273 | # 224 x 224 274 | 275 | self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1)) 276 | self.conv19 = nn.Conv2d(64,3,3,1,0) 277 | 278 | def forward(self,x): 279 | # decoder 280 | out = self.reflecPad11(x) 281 | out = self.conv11(out) 282 | out = self.relu11(out) 283 | out = self.unpool(out) 284 | out = self.reflecPad12(out) 285 | out = self.conv12(out) 286 | 287 | out = self.relu12(out) 288 | out = self.reflecPad13(out) 289 | out = self.conv13(out) 290 | out = self.relu13(out) 291 | out = self.reflecPad14(out) 292 | out = self.conv14(out) 293 | out = self.relu14(out) 294 | out = self.reflecPad15(out) 295 | out = self.conv15(out) 296 | out = self.relu15(out) 297 | out = self.unpool2(out) 298 | out = self.reflecPad16(out) 299 | out = self.conv16(out) 300 | out = self.relu16(out) 301 | out = self.reflecPad17(out) 302 | out = self.conv17(out) 303 | out = self.relu17(out) 304 | out = self.unpool3(out) 305 | out = self.reflecPad18(out) 306 | out = self.conv18(out) 307 | out = self.relu18(out) 308 | out = self.reflecPad19(out) 309 | out = self.conv19(out) 310 | return out 311 | 312 | class decoder4(nn.Module): 313 | def __init__(self): 314 | super(decoder4,self).__init__() 315 | # decoder 316 | self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) 317 | self.conv11 = nn.Conv2d(512,256,3,1,0) 318 | self.relu11 = nn.ReLU(inplace=True) 319 | # 28 x 28 320 | 321 | self.unpool = nn.UpsamplingNearest2d(scale_factor=2) 322 | # 56 x 56 323 | 324 | self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1)) 325 | self.conv12 = nn.Conv2d(256,256,3,1,0) 326 | self.relu12 = nn.ReLU(inplace=True) 327 | # 56 x 56 328 | 329 | self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1)) 330 | self.conv13 = nn.Conv2d(256,256,3,1,0) 331 | self.relu13 = nn.ReLU(inplace=True) 332 | # 56 x 56 333 | 334 | self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1)) 335 | self.conv14 = nn.Conv2d(256,256,3,1,0) 336 | self.relu14 = nn.ReLU(inplace=True) 337 | # 56 x 56 338 | 339 | self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1)) 340 | self.conv15 = nn.Conv2d(256,128,3,1,0) 341 | self.relu15 = nn.ReLU(inplace=True) 342 | # 56 x 56 343 | 344 | self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) 345 | # 112 x 112 346 | 347 | self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1)) 348 | self.conv16 = nn.Conv2d(128,128,3,1,0) 349 | self.relu16 = nn.ReLU(inplace=True) 350 | # 112 x 112 351 | 352 | self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1)) 353 | self.conv17 = nn.Conv2d(128,64,3,1,0) 354 | self.relu17 = nn.ReLU(inplace=True) 355 | # 112 x 112 356 | 357 | self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2) 358 | # 224 x 224 359 | 360 | self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1)) 361 | self.conv18 = nn.Conv2d(64,64,3,1,0) 362 | self.relu18 = nn.ReLU(inplace=True) 363 | # 224 x 224 364 | 365 | self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1)) 366 | self.conv19 = nn.Conv2d(64,3,3,1,0) 367 | 368 | def forward(self,x): 369 | # decoder 370 | out = self.reflecPad11(x) 371 | out = self.conv11(out) 372 | out = self.relu11(out) 373 | out = self.unpool(out) 374 | out = self.reflecPad12(out) 375 | out = self.conv12(out) 376 | 377 | out = self.relu12(out) 378 | out = self.reflecPad13(out) 379 | out = self.conv13(out) 380 | out = self.relu13(out) 381 | out = self.reflecPad14(out) 382 | out = self.conv14(out) 383 | out = self.relu14(out) 384 | out = self.reflecPad15(out) 385 | out = self.conv15(out) 386 | out = 
self.relu15(out) 387 | out = self.unpool2(out) 388 | out = self.reflecPad16(out) 389 | out = self.conv16(out) 390 | out = self.relu16(out) 391 | out = self.reflecPad17(out) 392 | out = self.conv17(out) 393 | out = self.relu17(out) 394 | out = self.unpool3(out) 395 | out = self.reflecPad18(out) 396 | out = self.conv18(out) 397 | out = self.relu18(out) 398 | out = self.reflecPad19(out) 399 | out = self.conv19(out) 400 | return out 401 | 402 | class encoder5(nn.Module): 403 | def __init__(self): 404 | super(encoder5,self).__init__() 405 | # vgg 406 | # 224 x 224 407 | self.conv1 = nn.Conv2d(3,3,1,1,0) 408 | self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) 409 | # 226 x 226 410 | 411 | self.conv2 = nn.Conv2d(3,64,3,1,0) 412 | self.relu2 = nn.ReLU(inplace=True) 413 | # 224 x 224 414 | 415 | self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) 416 | self.conv3 = nn.Conv2d(64,64,3,1,0) 417 | self.relu3 = nn.ReLU(inplace=True) 418 | # 224 x 224 419 | 420 | self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2) 421 | # 112 x 112 422 | 423 | self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1)) 424 | self.conv4 = nn.Conv2d(64,128,3,1,0) 425 | self.relu4 = nn.ReLU(inplace=True) 426 | # 112 x 112 427 | 428 | self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) 429 | self.conv5 = nn.Conv2d(128,128,3,1,0) 430 | self.relu5 = nn.ReLU(inplace=True) 431 | # 112 x 112 432 | 433 | self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2) 434 | # 56 x 56 435 | 436 | self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) 437 | self.conv6 = nn.Conv2d(128,256,3,1,0) 438 | self.relu6 = nn.ReLU(inplace=True) 439 | # 56 x 56 440 | 441 | self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) 442 | self.conv7 = nn.Conv2d(256,256,3,1,0) 443 | self.relu7 = nn.ReLU(inplace=True) 444 | # 56 x 56 445 | 446 | self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1)) 447 | self.conv8 = nn.Conv2d(256,256,3,1,0) 448 | self.relu8 = nn.ReLU(inplace=True) 449 | # 56 x 56 450 | 451 | self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1)) 452 | self.conv9 = nn.Conv2d(256,256,3,1,0) 453 | self.relu9 = nn.ReLU(inplace=True) 454 | # 56 x 56 455 | 456 | self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2) 457 | # 28 x 28 458 | 459 | self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1)) 460 | self.conv10 = nn.Conv2d(256,512,3,1,0) 461 | self.relu10 = nn.ReLU(inplace=True) 462 | 463 | self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) 464 | self.conv11 = nn.Conv2d(512,512,3,1,0) 465 | self.relu11 = nn.ReLU(inplace=True) 466 | 467 | self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1)) 468 | self.conv12 = nn.Conv2d(512,512,3,1,0) 469 | self.relu12 = nn.ReLU(inplace=True) 470 | 471 | self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1)) 472 | self.conv13 = nn.Conv2d(512,512,3,1,0) 473 | self.relu13 = nn.ReLU(inplace=True) 474 | 475 | self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2) 476 | self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1)) 477 | self.conv14 = nn.Conv2d(512,512,3,1,0) 478 | self.relu14 = nn.ReLU(inplace=True) 479 | 480 | def forward(self,x,sF=None,contentV256=None,styleV256=None,matrix11=None,matrix21=None,matrix31=None): 481 | output = {} 482 | out = self.conv1(x) 483 | out = self.reflecPad1(out) 484 | out = self.conv2(out) 485 | output['r11'] = self.relu2(out) 486 | out = self.reflecPad7(output['r11']) 487 | 488 | #out = self.reflecPad3(output['r11']) 489 | out = self.conv3(out) 490 | output['r12'] = self.relu3(out) 491 | 492 | output['p1'] = self.maxPool(output['r12']) 493 | out = self.reflecPad4(output['p1']) 494 | out = self.conv4(out) 495 | output['r21'] = self.relu4(out) 496 | out = 
self.reflecPad7(output['r21']) 497 | 498 | #out = self.reflecPad5(output['r21']) 499 | out = self.conv5(out) 500 | output['r22'] = self.relu5(out) 501 | 502 | output['p2'] = self.maxPool2(output['r22']) 503 | out = self.reflecPad6(output['p2']) 504 | out = self.conv6(out) 505 | output['r31'] = self.relu6(out) 506 | if(styleV256 is not None): 507 | feature = matrix31(output['r31'],sF['r31'],contentV256,styleV256) 508 | out = self.reflecPad7(feature) 509 | else: 510 | out = self.reflecPad7(output['r31']) 511 | out = self.conv7(out) 512 | output['r32'] = self.relu7(out) 513 | 514 | out = self.reflecPad8(output['r32']) 515 | out = self.conv8(out) 516 | output['r33'] = self.relu8(out) 517 | 518 | out = self.reflecPad9(output['r33']) 519 | out = self.conv9(out) 520 | output['r34'] = self.relu9(out) 521 | 522 | output['p3'] = self.maxPool3(output['r34']) 523 | out = self.reflecPad10(output['p3']) 524 | out = self.conv10(out) 525 | output['r41'] = self.relu10(out) 526 | 527 | out = self.reflecPad11(output['r41']) 528 | out = self.conv11(out) 529 | output['r42'] = self.relu11(out) 530 | 531 | out = self.reflecPad12(output['r42']) 532 | out = self.conv12(out) 533 | output['r43'] = self.relu12(out) 534 | 535 | out = self.reflecPad13(output['r43']) 536 | out = self.conv13(out) 537 | output['r44'] = self.relu13(out) 538 | 539 | output['p4'] = self.maxPool4(output['r44']) 540 | 541 | out = self.reflecPad14(output['p4']) 542 | out = self.conv14(out) 543 | output['r51'] = self.relu14(out) 544 | return output 545 | 546 | class decoder5(nn.Module): 547 | def __init__(self): 548 | super(decoder5,self).__init__() 549 | 550 | # decoder 551 | self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1)) 552 | self.conv15 = nn.Conv2d(512,512,3,1,0) 553 | self.relu15 = nn.ReLU(inplace=True) 554 | 555 | self.unpool = nn.UpsamplingNearest2d(scale_factor=2) 556 | # 28 x 28 557 | 558 | self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1)) 559 | self.conv16 = nn.Conv2d(512,512,3,1,0) 560 | self.relu16 = nn.ReLU(inplace=True) 561 | # 28 x 28 562 | 563 | self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1)) 564 | self.conv17 = nn.Conv2d(512,512,3,1,0) 565 | self.relu17 = nn.ReLU(inplace=True) 566 | # 28 x 28 567 | 568 | self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1)) 569 | self.conv18 = nn.Conv2d(512,512,3,1,0) 570 | self.relu18 = nn.ReLU(inplace=True) 571 | # 28 x 28 572 | 573 | self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1)) 574 | self.conv19 = nn.Conv2d(512,256,3,1,0) 575 | self.relu19 = nn.ReLU(inplace=True) 576 | # 28 x 28 577 | 578 | self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) 579 | # 56 x 56 580 | 581 | self.reflecPad20 = nn.ReflectionPad2d((1,1,1,1)) 582 | self.conv20 = nn.Conv2d(256,256,3,1,0) 583 | self.relu20 = nn.ReLU(inplace=True) 584 | # 56 x 56 585 | 586 | self.reflecPad21 = nn.ReflectionPad2d((1,1,1,1)) 587 | self.conv21 = nn.Conv2d(256,256,3,1,0) 588 | self.relu21 = nn.ReLU(inplace=True) 589 | 590 | self.reflecPad22 = nn.ReflectionPad2d((1,1,1,1)) 591 | self.conv22 = nn.Conv2d(256,256,3,1,0) 592 | self.relu22 = nn.ReLU(inplace=True) 593 | 594 | self.reflecPad23 = nn.ReflectionPad2d((1,1,1,1)) 595 | self.conv23 = nn.Conv2d(256,128,3,1,0) 596 | self.relu23 = nn.ReLU(inplace=True) 597 | 598 | self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2) 599 | # 112 X 112 600 | 601 | self.reflecPad24 = nn.ReflectionPad2d((1,1,1,1)) 602 | self.conv24 = nn.Conv2d(128,128,3,1,0) 603 | self.relu24 = nn.ReLU(inplace=True) 604 | 605 | self.reflecPad25 = nn.ReflectionPad2d((1,1,1,1)) 606 | self.conv25 = nn.Conv2d(128,64,3,1,0) 607 | 
self.relu25 = nn.ReLU(inplace=True) 608 | 609 | self.unpool4 = nn.UpsamplingNearest2d(scale_factor=2) 610 | 611 | self.reflecPad26 = nn.ReflectionPad2d((1,1,1,1)) 612 | self.conv26 = nn.Conv2d(64,64,3,1,0) 613 | self.relu26 = nn.ReLU(inplace=True) 614 | 615 | self.reflecPad27 = nn.ReflectionPad2d((1,1,1,1)) 616 | self.conv27 = nn.Conv2d(64,3,3,1,0) 617 | 618 | def forward(self,x): 619 | # decoder 620 | out = self.reflecPad15(x) 621 | out = self.conv15(out) 622 | out = self.relu15(out) 623 | out = self.unpool(out) 624 | out = self.reflecPad16(out) 625 | out = self.conv16(out) 626 | out = self.relu16(out) 627 | out = self.reflecPad17(out) 628 | out = self.conv17(out) 629 | out = self.relu17(out) 630 | out = self.reflecPad18(out) 631 | out = self.conv18(out) 632 | out = self.relu18(out) 633 | out = self.reflecPad19(out) 634 | out = self.conv19(out) 635 | out = self.relu19(out) 636 | out = self.unpool2(out) 637 | out = self.reflecPad20(out) 638 | out = self.conv20(out) 639 | out = self.relu20(out) 640 | out = self.reflecPad21(out) 641 | out = self.conv21(out) 642 | out = self.relu21(out) 643 | out = self.reflecPad22(out) 644 | out = self.conv22(out) 645 | out = self.relu22(out) 646 | out = self.reflecPad23(out) 647 | out = self.conv23(out) 648 | out = self.relu23(out) 649 | out = self.unpool3(out) 650 | out = self.reflecPad24(out) 651 | out = self.conv24(out) 652 | out = self.relu24(out) 653 | out = self.reflecPad25(out) 654 | out = self.conv25(out) 655 | out = self.relu25(out) 656 | out = self.unpool4(out) 657 | out = self.reflecPad26(out) 658 | out = self.conv26(out) 659 | out = self.relu26(out) 660 | out = self.reflecPad27(out) 661 | out = self.conv27(out) 662 | return out 663 | -------------------------------------------------------------------------------- /libs/pytorch_spn/README.md: -------------------------------------------------------------------------------- 1 | # pytorch_spn 2 | To build, install [pytorch](https://github.com/pytorch) and run: 3 | 4 | $ sh make.sh 5 | 6 | See left_right_demo.py for usage: 7 | 8 | $ mv left_right_demo.py ../ 9 | 10 | $ python left_right_demo.py 11 | 12 | The original code (Caffe) and models will be released [HERE](https://github.com/Liusifei/caffe-spn.git).
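As a sanity check after building, the propagator can be driven directly. The sketch below is an editor's illustration, not repo code: it assumes the CUDA extension compiled successfully, mirrors the gate normalization that `libs/SPN.py` applies before each call, and uses illustrative tensor shapes.

```
# Minimal smoke test (editor's sketch, not repo code): assumes `sh make.sh`
# succeeded and a CUDA device is available; all shapes are illustrative.
import torch
from torch.autograd import Variable
from libs.pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind

propagator = GateRecurrent2dnoind(True, False)      # horizontal, left-to-right

X  = Variable(torch.randn(1, 32, 64, 64).cuda())    # features to be filtered
G1 = Variable(torch.randn(1, 32, 64, 64).cuda())    # three neighbour gates
G2 = Variable(torch.randn(1, 32, 64, 64).cuda())
G3 = Variable(torch.randn(1, 32, 64, 64).cuda())

# Stability constraint used in libs/SPN.py: wherever |G1|+|G2|+|G3| >= 1,
# rescale the gates so their absolute sum is normalized to 1.
sum_abs = G1.abs() + G2.abs() + G3.abs()
sum_abs.data[sum_abs.data == 0] = 1e-6
need = sum_abs.ge(1).float()
G1 = (1 - need) * G1 + need * G1 / sum_abs
G2 = (1 - need) * G2 + need * G2 / sum_abs
G3 = (1 - need) * G3 + need * G3 / sum_abs

out = propagator(X, G1, G2, G3)                     # same size as X
print(out.size())
```

The four sweep directions in `SPN` are just four such propagators built with different `(horizontal, reverse)` flags, merged by an element-wise max.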
13 | -------------------------------------------------------------------------------- /libs/pytorch_spn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/__init__.py -------------------------------------------------------------------------------- /libs/pytorch_spn/_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/_ext/__init__.py -------------------------------------------------------------------------------- /libs/pytorch_spn/_ext/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/_ext/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._gaterecurrent2dnoind import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /libs/pytorch_spn/_ext/gaterecurrent2dnoind/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /libs/pytorch_spn/_ext/gaterecurrent2dnoind/_gaterecurrent2dnoind.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/_ext/gaterecurrent2dnoind/_gaterecurrent2dnoind.so -------------------------------------------------------------------------------- /libs/pytorch_spn/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.ffi import create_extension 4 | 5 | this_file = os.path.dirname(__file__) 6 | 7 | sources = [] 8 | headers = [] 9 | defines = [] 10 | with_cuda = False 11 | 12 | if torch.cuda.is_available(): 13 | print('Including CUDA code.') 14 | sources += ['src/gaterecurrent2dnoind_cuda.c'] 15 | headers += ['src/gaterecurrent2dnoind_cuda.h'] 16 | defines += [('WITH_CUDA', None)] 17 | with_cuda = True 18 | 19 | this_file = os.path.dirname(os.path.realpath(__file__)) 20 | extra_objects = ['src/cuda/gaterecurrent2dnoind_kernel.cu.o'] 21 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 22 | 23 | ffi = create_extension( 24 | '_ext.gaterecurrent2dnoind', 25 | headers=headers, 26 | sources=sources, 27 | define_macros=defines, 28 | relative_to=__file__, 29 | 
with_cuda=with_cuda, 30 | extra_objects=extra_objects 31 | ) 32 | 33 | if __name__ == '__main__': 34 | ffi.build() 35 | -------------------------------------------------------------------------------- /libs/pytorch_spn/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/functions/__init__.py -------------------------------------------------------------------------------- /libs/pytorch_spn/functions/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/functions/__init__.pyc -------------------------------------------------------------------------------- /libs/pytorch_spn/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /libs/pytorch_spn/functions/__pycache__/gaterecurrent2dnoind.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/functions/__pycache__/gaterecurrent2dnoind.cpython-36.pyc -------------------------------------------------------------------------------- /libs/pytorch_spn/functions/gaterecurrent2dnoind.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from .._ext import gaterecurrent2dnoind as gaterecurrent2d 4 | 5 | class GateRecurrent2dnoindFunction(Function): 6 | def __init__(self, horizontal_, reverse_): 7 | self.horizontal = horizontal_ 8 | self.reverse = reverse_ 9 | 10 | def forward(self, X, G1, G2, G3): 11 | num, channels, height, width = X.size() 12 | output = torch.zeros(num, channels, height, width) 13 | 14 | if not X.is_cuda: 15 | print("cpu version is not ready at this time") 16 | return 0 17 | else: 18 | output = output.cuda() 19 | gaterecurrent2d.gaterecurrent2dnoind_forward_cuda(self.horizontal,self.reverse, X, G1, G2, G3, output) 20 | 21 | self.X = X 22 | self.G1 = G1 23 | self.G2 = G2 24 | self.G3 = G3 25 | self.output = output 26 | self.hiddensize = X.size() 27 | return output 28 | 29 | def backward(self, grad_output): 30 | assert(self.hiddensize is not None and grad_output.is_cuda) 31 | num, channels, height, width = self.hiddensize 32 | 33 | grad_X = torch.zeros(num, channels, height, width).cuda() 34 | grad_G1 = torch.zeros(num, channels, height, width).cuda() 35 | grad_G2 = torch.zeros(num, channels, height, width).cuda() 36 | grad_G3 = torch.zeros(num, channels, height, width).cuda() 37 | 38 | gaterecurrent2d.gaterecurrent2dnoind_backward_cuda(self.horizontal, self.reverse, self.output, grad_output, self.X, self.G1, self.G2, self.G3, grad_X, grad_G1, grad_G2, grad_G3) 39 | 40 | del self.hiddensize 41 | del self.G1 42 | del self.G2 43 | del self.G3 44 | del self.output 45 | del self.X 46 | 47 | return grad_X, grad_G1, grad_G2, grad_G3 48 | -------------------------------------------------------------------------------- 
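
The class above uses the legacy stateful Function API (constructor arguments plus instance forward/backward), which only works with the old torch.utils.ffi builds. On PyTorch >= 0.4 the same op would be wrapped with static methods instead; the following is a hedged sketch of such a port (class and variable names are ours, and it still assumes the compiled _ext module produced by build.py):

```python
import torch
from torch.autograd import Function
from .._ext import gaterecurrent2dnoind as gaterecurrent2d  # compiled extension from build.py

class GateRecurrent2dnoindFn(Function):
    """Staticmethod-style wrapper; invoke as
    GateRecurrent2dnoindFn.apply(X, G1, G2, G3, horizontal, reverse)."""

    @staticmethod
    def forward(ctx, X, G1, G2, G3, horizontal, reverse):
        output = torch.zeros_like(X)
        gaterecurrent2d.gaterecurrent2dnoind_forward_cuda(
            horizontal, reverse, X, G1, G2, G3, output)
        ctx.save_for_backward(X, G1, G2, G3, output)
        ctx.horizontal, ctx.reverse = horizontal, reverse
        return output

    @staticmethod
    def backward(ctx, grad_output):
        X, G1, G2, G3, output = ctx.saved_tensors
        grad_X, grad_G1 = torch.zeros_like(X), torch.zeros_like(G1)
        grad_G2, grad_G3 = torch.zeros_like(G2), torch.zeros_like(G3)
        gaterecurrent2d.gaterecurrent2dnoind_backward_cuda(
            ctx.horizontal, ctx.reverse, output, grad_output.contiguous(),
            X, G1, G2, G3, grad_X, grad_G1, grad_G2, grad_G3)
        return grad_X, grad_G1, grad_G2, grad_G3, None, None  # no grads for the flags
```

--------------------------------------------------------------------------------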
/libs/pytorch_spn/functions/gaterecurrent2dnoind.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/functions/gaterecurrent2dnoind.pyc -------------------------------------------------------------------------------- /libs/pytorch_spn/left_right_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example of left->right propagation 3 | 4 | Other direction settings: 5 | left->right: Propagator = GateRecurrent2dnoind(True,False) 6 | right->left: Propagator = GateRecurrent2dnoind(True,True) 7 | top->bottom: Propagator = GateRecurrent2dnoind(False,False) 8 | bottom->top: Propagator = GateRecurrent2dnoind(False,True) 9 | 10 | X: any signal/feature map to be filtered 11 | G1~G3: three coefficient maps (e.g., left-top, left-center, left-bottom) 12 | 13 | Note: 14 | 1. G1~G3 constitute the affinity; they can be a bunch of output maps produced by any CNN, given any useful known information (e.g., RGB images) 15 | 2. for any pixel i, |G1(i)| + |G2(i)| + |G3(i)| <= 1 is a sufficient condition for model stability (see paper) 16 | """ 17 | import torch 18 | from torch.autograd import Variable 19 | from pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind 20 | 21 | Propagator = GateRecurrent2dnoind(True,False) 22 | 23 | X = Variable(torch.randn(1,3,10,10)) 24 | G1 = Variable(torch.randn(1,3,10,10)) 25 | G2 = Variable(torch.randn(1,3,10,10)) 26 | G3 = Variable(torch.randn(1,3,10,10)) 27 | 28 | sum_abs = G1.abs() + G2.abs() + G3.abs() 29 | mask_need_norm = sum_abs.ge(1) 30 | mask_need_norm = mask_need_norm.float() 31 | G1_norm = torch.div(G1, sum_abs) 32 | G2_norm = torch.div(G2, sum_abs) 33 | G3_norm = torch.div(G3, sum_abs) 34 | 35 | G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm 36 | G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm 37 | G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm 38 | 39 | X = X.cuda() 40 | G1 = G1.cuda() 41 | G2 = G2.cuda() 42 | G3 = G3.cuda() 43 | 44 | output = Propagator.forward(X,G1,G2,G3) 45 | print(X) 46 | print(output) 47 | --------------------------------------------------------------------------------
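
The gate normalisation in the demo is worth factoring out when the propagator is reused elsewhere. A small helper along these lines (the function name is ours) enforces the sufficient stability condition |G1(i)| + |G2(i)| + |G3(i)| <= 1 only where it is violated, and also sidesteps the 0/0 division the inline version hits when all three gates are exactly zero:

```python
import torch

def normalize_gates(G1, G2, G3):
    """Rescale the gate maps only at pixels where |G1|+|G2|+|G3| >= 1."""
    sum_abs = G1.abs() + G2.abs() + G3.abs()
    scale = torch.clamp(sum_abs, min=1.0)  # 1 where already stable, sum_abs otherwise
    return G1 / scale, G2 / scale, G3 / scale
```

--------------------------------------------------------------------------------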
/libs/pytorch_spn/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_PATH=/usr/local/cuda/ 4 | 5 | cd src/cuda/ 6 | echo "Compiling gaterecurrent2dnoind layer kernels by nvcc..." 7 | nvcc -c -o gaterecurrent2dnoind_kernel.cu.o gaterecurrent2dnoind_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52 8 | cd ../../ 9 | python build.py 10 | -------------------------------------------------------------------------------- /libs/pytorch_spn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /libs/pytorch_spn/modules/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/modules/__init__.pyc -------------------------------------------------------------------------------- /libs/pytorch_spn/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /libs/pytorch_spn/modules/__pycache__/gaterecurrent2dnoind.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/modules/__pycache__/gaterecurrent2dnoind.cpython-36.pyc -------------------------------------------------------------------------------- /libs/pytorch_spn/modules/gaterecurrent2dnoind.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ..functions.gaterecurrent2dnoind import GateRecurrent2dnoindFunction 3 | 4 | class GateRecurrent2dnoind(nn.Module): 5 | """docstring for .""" 6 | def __init__(self, horizontal_, reverse_): 7 | super(GateRecurrent2dnoind, self).__init__() 8 | self.horizontal = horizontal_ 9 | self.reverse = reverse_ 10 | 11 | def forward(self, X, G1, G2, G3): 12 | return GateRecurrent2dnoindFunction(self.horizontal, self.reverse)(X, G1, G2, G3) 13 | -------------------------------------------------------------------------------- /libs/pytorch_spn/modules/gaterecurrent2dnoind.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/modules/gaterecurrent2dnoind.pyc -------------------------------------------------------------------------------- /libs/pytorch_spn/src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/src/.DS_Store -------------------------------------------------------------------------------- /libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | #include <stdio.h> 6 | #include <math.h> 7 | #include <float.h> 8 | #include "gaterecurrent2dnoind_kernel.h" 9 | 10 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 11 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 12 | i += blockDim.x * gridDim.x) 13 | 14 | __device__ void get_gate_idx_sf(int h1, int w1, int h2, int w2, int * out, int horizontal, int reverse) 15 | {
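// The gate linking pixels (h1,w1) and (h2,w2) is stored at the downstream
// pixel of the scan (the larger w for left->right, the smaller w for
// right->left, and likewise in h for the vertical passes), so the forward
// and backward kernels below address the same gate memory.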
16 | if(horizontal && ! reverse) // left -> right 17 | { 18 | if(w1>w2) 19 | { 20 | out[0]=h1; 21 | out[1]=w1; 22 | } 23 | else 24 | { 25 | out[0]=h2; 26 | out[1]=w2; 27 | } 28 | } 29 | if(horizontal && reverse) // right -> left 30 | { 31 | if(w1<w2) 32 | { 33 | out[0]=h1; 34 | out[1]=w1; 35 | } 36 | else 37 | { 38 | out[0]=h2; 39 | out[1]=w2; 40 | } 41 | } 42 | if(!horizontal && !reverse) // top -> bottom 43 | { 44 | if(h1>h2) 45 | { 46 | out[0]=h1; 47 | out[1]=w1; 48 | } 49 | else 50 | { 51 | out[0]=h2; 52 | out[1]=w2; 53 | } 54 | } 55 | if(!horizontal && reverse) // bottom -> top 56 | { 57 | if(h1<h2) 58 | { 59 | out[0]=h1; 60 | out[1]=w1; 61 | } 62 | else 63 | { 64 | out[0]=h2; 65 | out[1]=w2; 66 | } 67 | } 68 | } 69 | 70 | 71 | __device__ float get_data_sf(float * data, int num, int channels,int height, int width,int n,int c,int h,int w) 72 | { 73 | if(h<0 || h >=height) 74 | return 0; 75 | if(w<0 || w >= width) 76 | return 0; 77 | 78 | return data[n*channels*height*width + c * height*width + h * width + w]; 79 | } 80 | 81 | __device__ void set_data_sf(float * data, int num, int channels,int height, int width,int n,int c,int h,int w, float v) 82 | { 83 | if(h<0 || h >=height) 84 | return ; 85 | if(w<0 || w >= width) 86 | return ; 87 | 88 | data[n*channels*height*width + c * height*width + h * width + w]=v; 89 | } 90 | 91 | __device__ float get_gate_sf(float * data, int num, int channels,int height, int width,int n,int c,int h1,int w1,int h2,int w2,int horizontal,int reverse) 92 | { 93 | if(h1<0 || h1 >=height) 94 | return 0; 95 | if(w1<0 || w1 >= width) 96 | return 0; 97 | if(h2<0 || h2 >=height) 98 | return 0; 99 | if(w2<0 || w2 >= width) 100 | return 0; 101 | int idx[2]; 102 | 103 | get_gate_idx_sf(h1,w1,h2,w2, idx,horizontal, reverse); 104 | 105 | int h = idx[0]; 106 | int w = idx[1]; 107 | 108 | return data[n*channels*height*width + c * height*width + h * width + w]; 109 | } 110 | 111 | __device__ void set_gate_sf(float * data, int num, int channels,int height, int width,int n,int c,int h1,int w1,int h2,int w2,int horizontal,int reverse, float v) 112 | { 113 | if(h1<0 || h1 >=height) 114 | return ; 115 | if(w1<0 || w1 >= width) 116 | return ; 117 | if(h2<0 || h2 >=height) 118 | return ; 119 | if(w2<0 || w2 >= width) 120 | return ; 121 | int idx[2]; 122 | 123 | get_gate_idx_sf(h1,w1,h2,w2, idx,horizontal, reverse); 124 | 125 | int h = idx[0]; 126 | int w = idx[1]; 127 | 128 | data[n*channels*height*width + c * height*width + h * width + w]=v; 129 | } 130 |
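// Each forward_* kernel below advances the scan by a single step: one launch
// fills one column (for the horizontal passes) or one row (for the vertical
// passes), parallelised over num * channels * height (or width) positions,
// while the host loops over T because h(T) depends on h(T-1).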
131 | // we do not use set_gate_add_sf(...) in the caffe implementation 132 | // avoid using atomicAdd 133 | 134 | __global__ void forward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, int horizontal, int reverse) { 135 | CUDA_1D_KERNEL_LOOP(index, count) { 136 | 137 | int hc_count = height * channels; 138 | 139 | int n,c,h,w; 140 | int temp=index; 141 | w = T; 142 | n = temp / hc_count; 143 | temp = temp % hc_count; 144 | c = temp / height; 145 | temp = temp % height; 146 | h = temp; 147 | 148 | 149 | float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w); 150 | 151 | float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse); 152 | float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1); 153 | float h1_minus1 = g_data_1 * h_minus1_data_1; 154 | 155 | float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w-1,horizontal,reverse); 156 | float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h,w-1); 157 | float h2_minus1 = g_data_2 * h_minus1_data_2; 158 | 159 | float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse); 160 | float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1); 161 | float h3_minus1 = g_data_3 * h_minus1_data_3; 162 | 163 | float h_hype = h1_minus1 + h2_minus1 + h3_minus1; 164 | float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data; 165 | 166 | float h_data = x_hype + h_hype; 167 | 168 | set_data_sf(H,num,channels,height,width,n,c,h,w,h_data); 169 | 170 | } 171 | } 172 | 173 | __global__ void forward_one_col_right_left( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) { 174 | CUDA_1D_KERNEL_LOOP(index, count) { 175 | 176 | int hc_count = height * channels; 177 | int n,c,h,w; 178 | int temp=index; 179 | w = T; 180 | n = temp / hc_count; 181 | temp = temp % hc_count; 182 | c = temp / height; 183 | temp = temp % height; 184 | h = temp; 185 | 186 | float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w); 187 | 188 | float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse); 189 | float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1); 190 | float h1_minus1 = g_data_1 * h_minus1_data_1; 191 | 192 | float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w+1,horizontal,reverse); 193 | float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h,w+1); 194 | float h2_minus1 = g_data_2 * h_minus1_data_2; 195 | 196 | float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse); 197 | float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1); 198 | float h3_minus1 = g_data_3 * h_minus1_data_3; 199 | 200 | float h_hype = h1_minus1 + h2_minus1 + h3_minus1; 201 | float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data; 202 | 203 | float h_data = x_hype + h_hype; 204 | 205 | set_data_sf(H,num,channels,height,width,n,c,h,w,h_data); 206 | 207 | } 208 | } 209 | 210 | __global__ void forward_one_row_top_bottom( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) { 211 | CUDA_1D_KERNEL_LOOP(index, count) { 212 | 213 | int wc_count = width * channels; 214 | 215 | int n,c,h,w; 216 | int temp=index; 217 | h = T; 218 | n = temp / wc_count; 219 | temp = temp % wc_count; 220
| c = temp / width; 221 | temp = temp % width; 222 | w = temp; 223 | 224 | 225 | float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w); 226 | 227 | 228 | float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse); 229 | float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1); 230 | float h1_minus1 = g_data_1 * h_minus1_data_1; 231 | 232 | float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h-1,w,horizontal,reverse); 233 | float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h-1,w); 234 | float h2_minus1 = g_data_2 * h_minus1_data_2; 235 | 236 | float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse); 237 | float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1); 238 | float h3_minus1 = g_data_3 * h_minus1_data_3; 239 | 240 | float h_hype = h1_minus1 + h2_minus1 + h3_minus1; 241 | float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data; 242 | 243 | float h_data = x_hype + h_hype; 244 | 245 | set_data_sf(H,num,channels,height,width,n,c,h,w,h_data); 246 | 247 | } 248 | } 249 | 250 | __global__ void forward_one_row_bottom_top( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) { 251 | CUDA_1D_KERNEL_LOOP(index, count) { 252 | 253 | int wc_count = width * channels; 254 | 255 | int n,c,h,w; 256 | int temp=index; 257 | h = T; 258 | n = temp / wc_count; 259 | temp = temp % wc_count; 260 | c = temp / width; 261 | temp = temp % width; 262 | w = temp; 263 | 264 | 265 | float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w); 266 | 267 | 268 | float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse); 269 | float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1); 270 | float h1_minus1 = g_data_1 * h_minus1_data_1; 271 | 272 | 273 | float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h+1,w,horizontal,reverse); 274 | float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h+1,w); 275 | float h2_minus1 = g_data_2 * h_minus1_data_2; 276 | 277 | float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse); 278 | float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1); 279 | float h3_minus1 = g_data_3 * h_minus1_data_3; 280 | 281 | float h_hype = h1_minus1 + h2_minus1 + h3_minus1; 282 | float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data; 283 | 284 | float h_data = x_hype + h_hype; 285 | 286 | set_data_sf(H,num,channels,height,width,n,c,h,w,h_data); 287 | 288 | } 289 | } 290 | 291 | 292 | __global__ void backward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) { 293 | CUDA_1D_KERNEL_LOOP(index, count) { 294 | 295 | int hc_count = height * channels; 296 | 297 | int n,c,h,w; 298 | int temp=index; 299 | 300 | w = T; 301 | n = temp / hc_count; 302 | temp = temp % hc_count; 303 | c = temp / height; 304 | temp = temp % height; 305 | h = temp; 306 | 307 | float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w); 308 | 309 | //h(t)_diff = top(t)_diff 310 | float h_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w); 311 | 312 | //h(t)_diff += h(t+1)_diff * g(t+1) if t>>(count, t, num_, channels_, height_, width_, 
X, G1, G2, G3, H, horizontal_, reverse_); 547 | 548 | err = cudaGetLastError(); 549 | if(cudaSuccess != err) 550 | { 551 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 552 | exit( -1 ); 553 | } 554 | } 555 | return 1; 556 | } 557 | 558 | int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream) 559 | { 560 | int count = height_ * channels_ * num_; 561 | int kThreadsPerBlock = 1024; 562 | cudaError_t err; 563 | 564 | for(int t = width_ - 1; t >= 0; t--) { 565 | forward_one_col_right_left<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_); 566 | 567 | err = cudaGetLastError(); 568 | if(cudaSuccess != err) 569 | { 570 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 571 | exit( -1 ); 572 | } 573 | } 574 | return 1; 575 | } 576 | 577 | int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream) 578 | { 579 | int count = width_ * channels_ * num_; 580 | int kThreadsPerBlock = 1024; 581 | cudaError_t err; 582 | 583 | for(int t=0; t< height_; t++) { 584 | forward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_); 585 | 586 | err = cudaGetLastError(); 587 | if(cudaSuccess != err) 588 | { 589 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 590 | exit( -1 ); 591 | } 592 | } 593 | return 1; 594 | } 595 | 596 | int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream) 597 | { 598 | int count = width_ * channels_ * num_; 599 | int kThreadsPerBlock = 1024; 600 | cudaError_t err; 601 | 602 | for(int t = height_-1; t >= 0; t--) { 603 | forward_one_row_bottom_top<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_); 604 | 605 | err = cudaGetLastError(); 606 | if(cudaSuccess != err) 607 | { 608 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 609 | exit( -1 ); 610 | } 611 | } 612 | return 1; 613 | } 614 | 615 | int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream) 616 | { 617 | int count = height_ * channels_ * num_; 618 | int kThreadsPerBlock = 1024; 619 | cudaError_t err; 620 | 621 | for(int t = width_ -1; t>=0; t--) 622 | { 623 | backward_one_col_left_right<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_); 624 | 625 | err = cudaGetLastError(); 626 | if(cudaSuccess != err) 627 | { 628 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 629 | exit( -1 ); 630 | } 631 | } 632 | return 1; 633 | } 634 | 635 | int Backward_right_left(int num_, int 
channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream) 636 | { 637 | int count = height_ * channels_ * num_; 638 | int kThreadsPerBlock = 1024; 639 | cudaError_t err; 640 | 641 | for(int t = 0; t < width_; t++) 642 | { 643 | backward_one_col_right_left<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_); 644 | 645 | err = cudaGetLastError(); 646 | if(cudaSuccess != err) 647 | { 648 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 649 | exit( -1 ); 650 | } 651 | } 652 | return 1; 653 | } 654 | 655 | int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream) 656 | { 657 | int count = width_ * channels_ * num_; 658 | int kThreadsPerBlock = 1024; 659 | cudaError_t err; 660 | 661 | for(int t = height_-1; t>=0; t--) 662 | { 663 | backward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_); 664 | 665 | err = cudaGetLastError(); 666 | if(cudaSuccess != err) 667 | { 668 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 669 | exit( -1 ); 670 | } 671 | } 672 | return 1; 673 | } 674 | 675 | int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream) 676 | { 677 | int count = width_ * channels_ * num_; 678 | int kThreadsPerBlock = 1024; 679 | cudaError_t err; 680 | 681 | for(int t = 0; t < height_; t++) 682 | { 683 | backward_one_row_bottom_top<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_); 684 | 685 | err = cudaGetLastError(); 686 | if(cudaSuccess != err) 687 | { 688 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 689 | exit( -1 ); 690 | } 691 | } 692 | return 1; 693 | } 694 | 695 | #ifdef __cplusplus 696 | } 697 | #endif 698 | -------------------------------------------------------------------------------- /libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o -------------------------------------------------------------------------------- /libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _GATERECURRENT2DNOIND_KERNEL 2 | #define _GATERECURRENT2DNOIND_KERNEL 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | int Forward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream); 9 | 10 | int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_,
int reverse_, cudaStream_t stream); 11 | 12 | int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream); 13 | 14 | int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream); 15 | 16 | int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream); 17 | 18 | int Backward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream); 19 | 20 | int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream); 21 | 22 | int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream); 23 | 24 | #ifdef __cplusplus 25 | } 26 | #endif 27 | 28 | #endif 29 | -------------------------------------------------------------------------------- /libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c: -------------------------------------------------------------------------------- 1 | // gaterecurrent2dnoind_cuda.c 2 | #include <THC/THC.h> 3 | #include <math.h> 4 | #include "gaterecurrent2dnoind_cuda.h" 5 | #include "cuda/gaterecurrent2dnoind_kernel.h" 6 | 7 | // typedef bool boolean; 8 | 9 | // this symbol will be resolved automatically from PyTorch libs 10 | extern THCState *state; 11 | 12 | int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output) 13 | { 14 | // Grab the input tensor to flat 15 | float * X_data = THCudaTensor_data(state, X); 16 | float * G1_data = THCudaTensor_data(state, G1); 17 | float * G2_data = THCudaTensor_data(state, G2); 18 | float * G3_data = THCudaTensor_data(state, G3); 19 | float * H_data = THCudaTensor_data(state, output); 20 | 21 | // dimensions 22 | int num_ = THCudaTensor_size(state, X, 0); 23 | int channels_ = THCudaTensor_size(state, X, 1); 24 | int height_ = THCudaTensor_size(state, X, 2); 25 | int width_ = THCudaTensor_size(state, X, 3); 26 | 27 | cudaStream_t stream = THCState_getCurrentStream(state); 28 | 29 | if(horizontal_ && !reverse_) // left to right 30 | { 31 | //const int count = height_ * channels_ * num_; 32 | Forward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream); 33 | } 34 | else if(horizontal_ && reverse_) // right to left 35 | { 36 | Forward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream); 37 | } 38 | else if(!horizontal_ && !reverse_) // top to bottom 39 | { 40 | Forward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream); 41 | } 42 |
else 43 | { 44 | Forward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream); 45 | } 46 | 47 | return 1; 48 | } 49 | 50 | int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_grad, THCudaTensor * G1_grad, THCudaTensor * G2_grad, THCudaTensor * G3_grad) 51 | { 52 | //Grab the input tensor to flat 53 | float * X_data = THCudaTensor_data(state, X); 54 | float * G1_data = THCudaTensor_data(state, G1); 55 | float * G2_data = THCudaTensor_data(state, G2); 56 | float * G3_data = THCudaTensor_data(state, G3); 57 | float * H_data = THCudaTensor_data(state, top); 58 | 59 | float * H_diff = THCudaTensor_data(state, top_grad); 60 | 61 | float * X_diff = THCudaTensor_data(state, X_grad); 62 | float * G1_diff = THCudaTensor_data(state, G1_grad); 63 | float * G2_diff = THCudaTensor_data(state, G2_grad); 64 | float * G3_diff = THCudaTensor_data(state, G3_grad); 65 | 66 | // dimensions 67 | int num_ = THCudaTensor_size(state, X, 0); 68 | int channels_ = THCudaTensor_size(state, X, 1); 69 | int height_ = THCudaTensor_size(state, X, 2); 70 | int width_ = THCudaTensor_size(state, X, 3); 71 | 72 | cudaStream_t stream = THCState_getCurrentStream(state); 73 | 74 | if(horizontal_ && ! reverse_) //left to right 75 | { 76 | Backward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream); 77 | } 78 | else if(horizontal_ && reverse_) //right to left 79 | { 80 | Backward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream); 81 | } 82 | else if(!horizontal_ && !reverse_) //top to bottom 83 | { 84 | Backward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream); 85 | } 86 | else { 87 | Backward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream); 88 | } 89 | 90 | return 1; 91 | } 92 | -------------------------------------------------------------------------------- /libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h: -------------------------------------------------------------------------------- 1 | 2 | // #include 3 | // gaterecurrent2dnoind_cuda.h 4 | int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output); 5 | 6 | int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_diff, THCudaTensor * G1_diff, THCudaTensor * G2_diff, THCudaTensor * G3_diff); 7 | -------------------------------------------------------------------------------- /libs/smooth_filter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code cc from https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py 3 | """ 4 | src = ''' 5 | #include "/usr/local/cuda/include/math_functions.h" 6 | #define TB 256 7 | #define EPS 1e-7 8 | 9 | __device__ bool InverseMat4x4(double m_in[4][4], 
double inv_out[4][4]) { 10 | double m[16], inv[16]; 11 | for (int i = 0; i < 4; i++) { 12 | for (int j = 0; j < 4; j++) { 13 | m[i * 4 + j] = m_in[i][j]; 14 | } 15 | } 16 | 17 | inv[0] = m[5] * m[10] * m[15] - 18 | m[5] * m[11] * m[14] - 19 | m[9] * m[6] * m[15] + 20 | m[9] * m[7] * m[14] + 21 | m[13] * m[6] * m[11] - 22 | m[13] * m[7] * m[10]; 23 | 24 | inv[4] = -m[4] * m[10] * m[15] + 25 | m[4] * m[11] * m[14] + 26 | m[8] * m[6] * m[15] - 27 | m[8] * m[7] * m[14] - 28 | m[12] * m[6] * m[11] + 29 | m[12] * m[7] * m[10]; 30 | 31 | inv[8] = m[4] * m[9] * m[15] - 32 | m[4] * m[11] * m[13] - 33 | m[8] * m[5] * m[15] + 34 | m[8] * m[7] * m[13] + 35 | m[12] * m[5] * m[11] - 36 | m[12] * m[7] * m[9]; 37 | 38 | inv[12] = -m[4] * m[9] * m[14] + 39 | m[4] * m[10] * m[13] + 40 | m[8] * m[5] * m[14] - 41 | m[8] * m[6] * m[13] - 42 | m[12] * m[5] * m[10] + 43 | m[12] * m[6] * m[9]; 44 | 45 | inv[1] = -m[1] * m[10] * m[15] + 46 | m[1] * m[11] * m[14] + 47 | m[9] * m[2] * m[15] - 48 | m[9] * m[3] * m[14] - 49 | m[13] * m[2] * m[11] + 50 | m[13] * m[3] * m[10]; 51 | 52 | inv[5] = m[0] * m[10] * m[15] - 53 | m[0] * m[11] * m[14] - 54 | m[8] * m[2] * m[15] + 55 | m[8] * m[3] * m[14] + 56 | m[12] * m[2] * m[11] - 57 | m[12] * m[3] * m[10]; 58 | 59 | inv[9] = -m[0] * m[9] * m[15] + 60 | m[0] * m[11] * m[13] + 61 | m[8] * m[1] * m[15] - 62 | m[8] * m[3] * m[13] - 63 | m[12] * m[1] * m[11] + 64 | m[12] * m[3] * m[9]; 65 | 66 | inv[13] = m[0] * m[9] * m[14] - 67 | m[0] * m[10] * m[13] - 68 | m[8] * m[1] * m[14] + 69 | m[8] * m[2] * m[13] + 70 | m[12] * m[1] * m[10] - 71 | m[12] * m[2] * m[9]; 72 | 73 | inv[2] = m[1] * m[6] * m[15] - 74 | m[1] * m[7] * m[14] - 75 | m[5] * m[2] * m[15] + 76 | m[5] * m[3] * m[14] + 77 | m[13] * m[2] * m[7] - 78 | m[13] * m[3] * m[6]; 79 | 80 | inv[6] = -m[0] * m[6] * m[15] + 81 | m[0] * m[7] * m[14] + 82 | m[4] * m[2] * m[15] - 83 | m[4] * m[3] * m[14] - 84 | m[12] * m[2] * m[7] + 85 | m[12] * m[3] * m[6]; 86 | 87 | inv[10] = m[0] * m[5] * m[15] - 88 | m[0] * m[7] * m[13] - 89 | m[4] * m[1] * m[15] + 90 | m[4] * m[3] * m[13] + 91 | m[12] * m[1] * m[7] - 92 | m[12] * m[3] * m[5]; 93 | 94 | inv[14] = -m[0] * m[5] * m[14] + 95 | m[0] * m[6] * m[13] + 96 | m[4] * m[1] * m[14] - 97 | m[4] * m[2] * m[13] - 98 | m[12] * m[1] * m[6] + 99 | m[12] * m[2] * m[5]; 100 | 101 | inv[3] = -m[1] * m[6] * m[11] + 102 | m[1] * m[7] * m[10] + 103 | m[5] * m[2] * m[11] - 104 | m[5] * m[3] * m[10] - 105 | m[9] * m[2] * m[7] + 106 | m[9] * m[3] * m[6]; 107 | 108 | inv[7] = m[0] * m[6] * m[11] - 109 | m[0] * m[7] * m[10] - 110 | m[4] * m[2] * m[11] + 111 | m[4] * m[3] * m[10] + 112 | m[8] * m[2] * m[7] - 113 | m[8] * m[3] * m[6]; 114 | 115 | inv[11] = -m[0] * m[5] * m[11] + 116 | m[0] * m[7] * m[9] + 117 | m[4] * m[1] * m[11] - 118 | m[4] * m[3] * m[9] - 119 | m[8] * m[1] * m[7] + 120 | m[8] * m[3] * m[5]; 121 | 122 | inv[15] = m[0] * m[5] * m[10] - 123 | m[0] * m[6] * m[9] - 124 | m[4] * m[1] * m[10] + 125 | m[4] * m[2] * m[9] + 126 | m[8] * m[1] * m[6] - 127 | m[8] * m[2] * m[5]; 128 | 129 | double det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12]; 130 | 131 | if (abs(det) < 1e-9) { 132 | return false; 133 | } 134 | 135 | 136 | det = 1.0 / det; 137 | 138 | for (int i = 0; i < 4; i++) { 139 | for (int j = 0; j < 4; j++) { 140 | inv_out[i][j] = inv[i * 4 + j] * det; 141 | } 142 | } 143 | 144 | return true; 145 | } 146 | 147 | extern "C" 148 | __global__ void best_local_affine_kernel( 149 | float *output, float *input, float *affine_model, 150 | int h, int w, float epsilon, int 
kernel_radius 151 | ) 152 | { 153 | int size = h * w; 154 | int id = blockIdx.x * blockDim.x + threadIdx.x; 155 | 156 | if (id < size) { 157 | int x = id % w, y = id / w; 158 | 159 | double Mt_M[4][4] = {}; // 4x4 160 | double invMt_M[4][4] = {}; 161 | double Mt_S[3][4] = {}; // RGB -> 1x4 162 | double A[3][4] = {}; 163 | for (int i = 0; i < 4; i++) 164 | for (int j = 0; j < 4; j++) { 165 | Mt_M[i][j] = 0, invMt_M[i][j] = 0; 166 | if (i != 3) { 167 | Mt_S[i][j] = 0, A[i][j] = 0; 168 | if (i == j) 169 | Mt_M[i][j] = 1e-3; 170 | } 171 | } 172 | 173 | for (int dy = -kernel_radius; dy <= kernel_radius; dy++) { 174 | for (int dx = -kernel_radius; dx <= kernel_radius; dx++) { 175 | 176 | int xx = x + dx, yy = y + dy; 177 | int id2 = yy * w + xx; 178 | 179 | if (0 <= xx && xx < w && 0 <= yy && yy < h) { 180 | 181 | Mt_M[0][0] += input[id2 + 2*size] * input[id2 + 2*size]; 182 | Mt_M[0][1] += input[id2 + 2*size] * input[id2 + size]; 183 | Mt_M[0][2] += input[id2 + 2*size] * input[id2]; 184 | Mt_M[0][3] += input[id2 + 2*size]; 185 | 186 | Mt_M[1][0] += input[id2 + size] * input[id2 + 2*size]; 187 | Mt_M[1][1] += input[id2 + size] * input[id2 + size]; 188 | Mt_M[1][2] += input[id2 + size] * input[id2]; 189 | Mt_M[1][3] += input[id2 + size]; 190 | 191 | Mt_M[2][0] += input[id2] * input[id2 + 2*size]; 192 | Mt_M[2][1] += input[id2] * input[id2 + size]; 193 | Mt_M[2][2] += input[id2] * input[id2]; 194 | Mt_M[2][3] += input[id2]; 195 | 196 | Mt_M[3][0] += input[id2 + 2*size]; 197 | Mt_M[3][1] += input[id2 + size]; 198 | Mt_M[3][2] += input[id2]; 199 | Mt_M[3][3] += 1; 200 | 201 | Mt_S[0][0] += input[id2 + 2*size] * output[id2 + 2*size]; 202 | Mt_S[0][1] += input[id2 + size] * output[id2 + 2*size]; 203 | Mt_S[0][2] += input[id2] * output[id2 + 2*size]; 204 | Mt_S[0][3] += output[id2 + 2*size]; 205 | 206 | Mt_S[1][0] += input[id2 + 2*size] * output[id2 + size]; 207 | Mt_S[1][1] += input[id2 + size] * output[id2 + size]; 208 | Mt_S[1][2] += input[id2] * output[id2 + size]; 209 | Mt_S[1][3] += output[id2 + size]; 210 | 211 | Mt_S[2][0] += input[id2 + 2*size] * output[id2]; 212 | Mt_S[2][1] += input[id2 + size] * output[id2]; 213 | Mt_S[2][2] += input[id2] * output[id2]; 214 | Mt_S[2][3] += output[id2]; 215 | } 216 | } 217 | } 218 | 219 | bool success = InverseMat4x4(Mt_M, invMt_M); 220 | 221 | for (int i = 0; i < 3; i++) { 222 | for (int j = 0; j < 4; j++) { 223 | for (int k = 0; k < 4; k++) { 224 | A[i][j] += invMt_M[j][k] * Mt_S[i][k]; 225 | } 226 | } 227 | } 228 | 229 | for (int i = 0; i < 3; i++) { 230 | for (int j = 0; j < 4; j++) { 231 | int affine_id = i * 4 + j; 232 | affine_model[12 * id + affine_id] = A[i][j]; 233 | } 234 | } 235 | } 236 | return ; 237 | } 238 | 239 | extern "C" 240 | __global__ void bilateral_smooth_kernel( 241 | float *affine_model, float *filtered_affine_model, float *guide, 242 | int h, int w, int kernel_radius, float sigma1, float sigma2 243 | ) 244 | { 245 | int id = blockIdx.x * blockDim.x + threadIdx.x; 246 | int size = h * w; 247 | if (id < size) { 248 | int x = id % w; 249 | int y = id / w; 250 | 251 | double sum_affine[12] = {}; 252 | double sum_weight = 0; 253 | for (int dx = -kernel_radius; dx <= kernel_radius; dx++) { 254 | for (int dy = -kernel_radius; dy <= kernel_radius; dy++) { 255 | int yy = y + dy, xx = x + dx; 256 | int id2 = yy * w + xx; 257 | if (0 <= xx && xx < w && 0 <= yy && yy < h) { 258 | float color_diff1 = guide[yy*w + xx] - guide[y*w + x]; 259 | float color_diff2 = guide[yy*w + xx + size] - guide[y*w + x + size]; 260 | float color_diff3 = guide[yy*w + 
xx + 2*size] - guide[y*w + x + 2*size]; 261 | float color_diff_sqr = 262 | (color_diff1*color_diff1 + color_diff2*color_diff2 + color_diff3*color_diff3) / 3; 263 | 264 | float v1 = exp(-(dx * dx + dy * dy) / (2 * sigma1 * sigma1)); 265 | float v2 = exp(-(color_diff_sqr) / (2 * sigma2 * sigma2)); 266 | float weight = v1 * v2; 267 | 268 | for (int i = 0; i < 3; i++) { 269 | for (int j = 0; j < 4; j++) { 270 | int affine_id = i * 4 + j; 271 | sum_affine[affine_id] += weight * affine_model[id2*12 + affine_id]; 272 | } 273 | } 274 | sum_weight += weight; 275 | } 276 | } 277 | } 278 | 279 | for (int i = 0; i < 3; i++) { 280 | for (int j = 0; j < 4; j++) { 281 | int affine_id = i * 4 + j; 282 | filtered_affine_model[id*12 + affine_id] = sum_affine[affine_id] / sum_weight; 283 | } 284 | } 285 | } 286 | return ; 287 | } 288 | 289 | 290 | extern "C" 291 | __global__ void reconstruction_best_kernel( 292 | float *input, float *filtered_affine_model, float *filtered_best_output, 293 | int h, int w 294 | ) 295 | { 296 | int id = blockIdx.x * blockDim.x + threadIdx.x; 297 | int size = h * w; 298 | if (id < size) { 299 | double out1 = 300 | input[id + 2*size] * filtered_affine_model[id*12 + 0] + // A[0][0] + 301 | input[id + size] * filtered_affine_model[id*12 + 1] + // A[0][1] + 302 | input[id] * filtered_affine_model[id*12 + 2] + // A[0][2] + 303 | filtered_affine_model[id*12 + 3]; //A[0][3]; 304 | double out2 = 305 | input[id + 2*size] * filtered_affine_model[id*12 + 4] + //A[1][0] + 306 | input[id + size] * filtered_affine_model[id*12 + 5] + //A[1][1] + 307 | input[id] * filtered_affine_model[id*12 + 6] + //A[1][2] + 308 | filtered_affine_model[id*12 + 7]; //A[1][3]; 309 | double out3 = 310 | input[id + 2*size] * filtered_affine_model[id*12 + 8] + //A[2][0] + 311 | input[id + size] * filtered_affine_model[id*12 + 9] + //A[2][1] + 312 | input[id] * filtered_affine_model[id*12 + 10] + //A[2][2] + 313 | filtered_affine_model[id*12 + 11]; // A[2][3]; 314 | 315 | filtered_best_output[id] = out1; 316 | filtered_best_output[id + size] = out2; 317 | filtered_best_output[id + 2*size] = out3; 318 | } 319 | return ; 320 | } 321 | ''' 322 | 323 | import cv2 324 | import torch 325 | import numpy as np 326 | from PIL import Image 327 | from cupy.cuda import function 328 | from pynvrtc.compiler import Program 329 | from collections import namedtuple 330 | 331 | 332 | def smooth_local_affine(output_cpu, input_cpu, epsilon, patch, h, w, f_r, f_e): 333 | # program = Program(src.encode('utf-8'), 'best_local_affine_kernel.cu'.encode('utf-8')) 334 | # ptx = program.compile(['-I/usr/local/cuda/include'.encode('utf-8')]) 335 | program = Program(src, 'best_local_affine_kernel.cu') 336 | ptx = program.compile(['-I/usr/local/cuda/include']) 337 | m = function.Module() 338 | m.load(bytes(ptx.encode())) 339 | 340 | _reconstruction_best_kernel = m.get_function('reconstruction_best_kernel') 341 | _bilateral_smooth_kernel = m.get_function('bilateral_smooth_kernel') 342 | _best_local_affine_kernel = m.get_function('best_local_affine_kernel') 343 | Stream = namedtuple('Stream', ['ptr']) 344 | s = Stream(ptr=torch.cuda.current_stream().cuda_stream) 345 | 346 | filter_radius = f_r 347 | sigma1 = filter_radius / 3 348 | sigma2 = f_e 349 | radius = (patch - 1) / 2 350 | 351 | filtered_best_output = torch.zeros(np.shape(input_cpu)).cuda() 352 | affine_model = torch.zeros((h * w, 12)).cuda() 353 | filtered_affine_model =torch.zeros((h * w, 12)).cuda() 354 | 355 | input_ = torch.from_numpy(input_cpu).cuda() 356 | output_ = 
torch.from_numpy(output_cpu).cuda() 357 | _best_local_affine_kernel( 358 | grid=(int((h * w) / 256 + 1), 1), 359 | block=(256, 1, 1), 360 | args=[output_.data_ptr(), input_.data_ptr(), affine_model.data_ptr(), 361 | np.int32(h), np.int32(w), np.float32(epsilon), np.int32(radius)], stream=s 362 | ) 363 | 364 | _bilateral_smooth_kernel( 365 | grid=(int((h * w) / 256 + 1), 1), 366 | block=(256, 1, 1), 367 | args=[affine_model.data_ptr(), filtered_affine_model.data_ptr(), input_.data_ptr(), np.int32(h), np.int32(w), np.int32(f_r), np.float32(sigma1), np.float32(sigma2)], stream=s 368 | ) 369 | 370 | _reconstruction_best_kernel( 371 | grid=(int((h * w) / 256 + 1), 1), 372 | block=(256, 1, 1), 373 | args=[input_.data_ptr(), filtered_affine_model.data_ptr(), filtered_best_output.data_ptr(), 374 | np.int32(h), np.int32(w)], stream=s 375 | ) 376 | numpy_filtered_best_output = filtered_best_output.cpu().numpy() 377 | return numpy_filtered_best_output 378 | 379 | 380 | def smooth_filter(initImg, contentImg, f_radius=15,f_edge=1e-1): 381 | ''' 382 | :param initImg: intermediate output. Either image path or PIL Image 383 | :param contentImg: content image output. Either path or PIL Image 384 | :return: stylized output image. PIL Image 385 | ''' 386 | if type(initImg) == str: 387 | initImg = Image.open(initImg).convert("RGB") 388 | best_image_bgr = np.array(initImg, dtype=np.float32) 389 | bW, bH, bC = best_image_bgr.shape 390 | best_image_bgr = best_image_bgr[:, :, ::-1] 391 | best_image_bgr = best_image_bgr.transpose((2, 0, 1)) 392 | 393 | if type(contentImg) == str: 394 | contentImg = Image.open(contentImg).convert("RGB") 395 | content_input = contentImg.resize((bH,bW)) 396 | else: 397 | content_input = cv2.resize(contentImg,(bH,bW)) 398 | content_input = np.array(content_input, dtype=np.float32) 399 | content_input = content_input[:, :, ::-1] 400 | content_input = content_input.transpose((2, 0, 1)) 401 | input_ = np.ascontiguousarray(content_input, dtype=np.float32) / 255. 402 | _, H, W = np.shape(input_) 403 | output_ = np.ascontiguousarray(best_image_bgr, dtype=np.float32) / 255. 
404 | best_ = smooth_local_affine(output_, input_, 1e-7, 3, H, W, f_radius, f_edge) 405 | best_ = best_.transpose(1, 2, 0) 406 | result = Image.fromarray(np.uint8(np.clip(best_ * 255., 0, 255.))) 407 | return result 408 | -------------------------------------------------------------------------------- /libs/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import cv2 4 | import time 5 | import torch 6 | import scipy.misc 7 | import numpy as np 8 | import scipy.sparse 9 | from PIL import Image 10 | import scipy.sparse.linalg 11 | from cv2.ximgproc import jointBilateralFilter 12 | from torch.utils.serialization import load_lua 13 | from numpy.lib.stride_tricks import as_strided 14 | 15 | def whiten(cF): 16 | cFSize = cF.size() 17 | c_mean = torch.mean(cF,1) # c x (h x w) 18 | c_mean = c_mean.unsqueeze(1).expand_as(cF) 19 | cF = cF - c_mean 20 | 21 | contentConv = torch.mm(cF,cF.t()).div(cFSize[1]-1) + torch.eye(cFSize[0]).double() 22 | c_u,c_e,c_v = torch.svd(contentConv,some=False) 23 | 24 | k_c = cFSize[0] 25 | for i in range(cFSize[0]): 26 | if c_e[i] < 0.00001: 27 | k_c = i 28 | break 29 | 30 | c_d = (c_e[0:k_c]).pow(-0.5) 31 | step1 = torch.mm(c_v[:,0:k_c],torch.diag(c_d)) 32 | step2 = torch.mm(step1,(c_v[:,0:k_c].t())) 33 | whiten_cF = torch.mm(step2,cF) 34 | return whiten_cF 35 | 36 | def numpy2cv2(cont,style,prop,width,height): 37 | cont = cont.transpose((1,2,0)) 38 | cont = cont[...,::-1] 39 | cont = cont * 255 40 | cont = cv2.resize(cont,(width,height)) 41 | #cv2.resize(iimg,(width,height)) 42 | style = style.transpose((1,2,0)) 43 | style = style[...,::-1] 44 | style = style * 255 45 | style = cv2.resize(style,(width,height)) 46 | 47 | prop = prop.transpose((1,2,0)) 48 | prop = prop[...,::-1] 49 | prop = prop * 255 50 | prop = cv2.resize(prop,(width,height)) 51 | 52 | #return np.concatenate((cont,np.concatenate((style,prop),axis=1)),axis=1) 53 | return prop,cont 54 | 55 | def makeVideo(content,style,props,outf): 56 | print('Stacking transferred frames back into a video...') 57 | layers,height,width = content[0].shape 58 | fourcc = cv2.VideoWriter_fourcc(*'MJPG') 59 | video = cv2.VideoWriter(os.path.join(outf,'transfer.avi'),fourcc,10.0,(width,height)) 60 | ori_video = cv2.VideoWriter(os.path.join(outf,'content.avi'),fourcc,10.0,(width,height)) 61 | for j in range(len(content)): 62 | prop,cont = numpy2cv2(content[j],style,props[j],width,height) 63 | cv2.imwrite('prop.png',prop) 64 | cv2.imwrite('content.png',cont) 65 | # TODO: this is ugly, fix this 66 | imgj = cv2.imread('prop.png') 67 | imgc = cv2.imread('content.png') 68 | 69 | video.write(imgj) 70 | ori_video.write(imgc) 71 | # RGB or BGR, yuck 72 | video.release() 73 | ori_video.release() 74 | os.remove('prop.png') 75 | os.remove('content.png') 76 | print('Transferred video saved at %s.'%outf) 77 | 78 | def print_options(opt): 79 | message = '' 80 | message += '----------------- Options ---------------\n' 81 | for k, v in sorted(vars(opt).items()): 82 | comment = '' 83 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 84 | message += '----------------- End -------------------' 85 | print(message) 86 | 87 | # save to the disk 88 | expr_dir = os.path.join(opt.outf) 89 | os.makedirs(expr_dir,exist_ok=True) 90 | file_name = os.path.join(expr_dir, 'opt.txt') 91 | with open(file_name, 'wt') as opt_file: 92 | opt_file.write(message) 93 | opt_file.write('\n') 94 |
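
As a quick illustration of whiten() (an illustrative snippet, not part of the repo): it decorrelates the channels of a flattened c x (h*w) feature map, which can be checked by inspecting the covariance of the result:

```python
import torch
from libs.utils import whiten

# c x (h*w) feature map; double() matches torch.eye(...).double() inside whiten,
# and the scale keeps the +I regularisation term inside whiten negligible
feat = 10 * torch.randn(64, 32 * 32).double()
wf = whiten(feat)
cov = torch.mm(wf, wf.t()) / (feat.size(1) - 1)
print(cov.diag()[:4])   # roughly 1: channels now have ~unit variance
print(cov[0, 1:4])      # roughly 0: cross-channel correlations removed
```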
-------------------------------------------------------------------------------- /real-time-demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from PIL import Image 7 | from libs.Loader import Dataset 8 | from libs.Matrix import MulLayer 9 | from libs.utils import makeVideo 10 | import torch.backends.cudnn as cudnn 11 | from libs.models import encoder3,encoder4 12 | from libs.models import decoder3,decoder4 13 | import torchvision.transforms as transforms 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--vgg_dir", default='models/vgg_r31.pth', 17 | help='pre-trained encoder path') 18 | parser.add_argument("--decoder_dir", default='models/dec_r31.pth', 19 | help='pre-trained decoder path') 20 | parser.add_argument("--style", default="data/style/in2.jpg", 21 | help='path to style image') 22 | parser.add_argument("--matrixPath", default="models/r31.pth", 23 | help='path to pre-trained model') 24 | parser.add_argument('--fineSize', type=int, default=256, 25 | help='crop image size') 26 | parser.add_argument("--name",default="transferred_video", 27 | help="name of generated video") 28 | parser.add_argument("--layer",default="r31", 29 | help="features of which layer to transfer") 30 | parser.add_argument("--outf",default="real_time_demo_output", 31 | help="output folder") 32 | 33 | ################# PREPARATIONS ################# 34 | opt = parser.parse_args() 35 | opt.cuda = torch.cuda.is_available() 36 | print(opt) 37 | os.makedirs(opt.outf,exist_ok=True) 38 | cudnn.benchmark = True 39 | 40 | ################# DATA ################# 41 | def loadImg(imgPath): 42 | img = Image.open(imgPath).convert('RGB') 43 | transform = transforms.Compose([ 44 | transforms.Scale(opt.fineSize), 45 | transforms.ToTensor()]) 46 | return transform(img) 47 | style = loadImg(opt.style).unsqueeze(0) 48 | 49 | ################# MODEL ################# 50 | if(opt.layer == 'r31'): 51 | matrix = MulLayer(layer='r31') 52 | vgg = encoder3() 53 | dec = decoder3() 54 | elif(opt.layer == 'r41'): 55 | matrix = MulLayer(layer='r41') 56 | vgg = encoder4() 57 | dec = decoder4() 58 | vgg.load_state_dict(torch.load(opt.vgg_dir)) 59 | dec.load_state_dict(torch.load(opt.decoder_dir)) 60 | matrix.load_state_dict(torch.load(opt.matrixPath)) 61 | for param in vgg.parameters(): 62 | param.requires_grad = False 63 | for param in dec.parameters(): 64 | param.requires_grad = False 65 | for param in matrix.parameters(): 66 | param.requires_grad = False 67 | 68 | ################# GLOBAL VARIABLE ################# 69 | content = torch.Tensor(1,3,opt.fineSize,opt.fineSize) 70 | 71 | ################# GPU ################# 72 | if(opt.cuda): 73 | vgg.cuda() 74 | dec.cuda() 75 | matrix.cuda() 76 | 77 | style = style.cuda() 78 | content = content.cuda() 79 | 80 | totalTime = 0 81 | imageCounter = 0 82 | result_frames = [] 83 | contents = [] 84 | styles = [] 85 | cap = cv2.VideoCapture(0) 86 | cap.set(3,256) 87 | cap.set(4,512) 88 | fourcc = cv2.VideoWriter_fourcc(*'MJPG') 89 | out = cv2.VideoWriter(os.path.join(opt.outf,opt.name+'.avi'),fourcc,20.0,(512,256)) 90 | 91 | with torch.no_grad(): 92 | sF = vgg(style) 93 | 94 | while(True): 95 | ret,frame = cap.read() 96 | frame = cv2.resize(frame,(512,256),interpolation=cv2.INTER_CUBIC) 97 | frame = frame.transpose((2,0,1)) 98 | frame = frame[::-1,:,:] 99 | frame = frame/255.0 100 | frame = torch.from_numpy(frame.copy()).unsqueeze(0) 101 |
content.data.resize_(frame.size()).copy_(frame) 102 | with torch.no_grad(): 103 | cF = vgg(content) 104 | if(opt.layer == 'r41'): 105 | feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer]) 106 | else: 107 | feature,transmatrix = matrix(cF,sF) 108 | transfer = dec(feature) 109 | transfer = transfer.clamp(0,1).squeeze(0).data.cpu().numpy() 110 | transfer = transfer.transpose((1,2,0)) 111 | transfer = transfer[...,::-1] 112 | out.write(np.uint8(transfer*255)) 113 | cv2.imshow('frame',transfer) 114 | if cv2.waitKey(1) & 0xFF == ord('q'): 115 | break 116 | 117 | # When everything is done, release the capture 118 | out.release() 119 | cap.release() 120 | cv2.destroyAllWindows() 121 | --------------------------------------------------------------------------------
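
The webcam demo is launched from the command line like the other entry points; for example, with the pre-trained models from the README unzipped into models/ (press q in the preview window to stop and finalize the saved .avi):

```
python real-time-demo.py --style data/style/picasso_self_portrait.jpg --layer r31
```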