├── LICENSE
├── README.md
├── TestArtistic.py
├── TestPhotoReal.py
├── TestVideo.py
├── Train.py
├── TrainSPN.py
├── data
│   ├── content
│   │   ├── 1.jpg
│   │   └── chicago.png
│   ├── photo_real
│   │   ├── content
│   │   │   └── images
│   │   │       ├── in16.png
│   │   │       ├── in25.png
│   │   │       ├── in26.png
│   │   │       ├── in29.png
│   │   │       ├── in3.png
│   │   │       ├── in39.png
│   │   │       ├── in53.png
│   │   │       └── in7.png
│   │   ├── contentSeg
│   │   │   ├── in16.png
│   │   │   ├── in25.png
│   │   │   ├── in26.png
│   │   │   ├── in29.png
│   │   │   ├── in3.png
│   │   │   ├── in39.png
│   │   │   ├── in53.png
│   │   │   └── in7.png
│   │   ├── style
│   │   │   └── images
│   │   │       ├── in16.png
│   │   │       ├── in25.png
│   │   │       ├── in26.png
│   │   │       ├── in29.png
│   │   │       ├── in3.png
│   │   │       ├── in39.png
│   │   │       ├── in53.png
│   │   │       └── in7.png
│   │   └── styleSeg
│   │       ├── in16.png
│   │       ├── in25.png
│   │       ├── in26.png
│   │       ├── in29.png
│   │       ├── in3.png
│   │       ├── in39.png
│   │       ├── in53.png
│   │       └── in7.png
│   ├── style
│   │   ├── 27.jpg
│   │   ├── 3314.jpg
│   │   ├── antimonocromatismo.jpg
│   │   ├── in2.jpg
│   │   ├── picasso_self_portrait.jpg
│   │   └── sketch.jpg
│   └── videos
│       └── content
│           └── mountain_2
│               ├── frame_0001.png
│               ├── frame_0002.png
│               ├── frame_0003.png
│               ├── frame_0004.png
│               ├── frame_0005.png
│               ├── frame_0006.png
│               ├── frame_0007.png
│               ├── frame_0008.png
│               ├── frame_0009.png
│               ├── frame_0010.png
│               ├── frame_0011.png
│               ├── frame_0012.png
│               ├── frame_0013.png
│               ├── frame_0014.png
│               ├── frame_0015.png
│               ├── frame_0016.png
│               ├── frame_0017.png
│               ├── frame_0018.png
│               ├── frame_0019.png
│               ├── frame_0020.png
│               ├── frame_0021.png
│               ├── frame_0022.png
│               ├── frame_0023.png
│               ├── frame_0024.png
│               ├── frame_0025.png
│               ├── frame_0026.png
│               ├── frame_0027.png
│               ├── frame_0028.png
│               ├── frame_0029.png
│               ├── frame_0030.png
│               ├── frame_0031.png
│               ├── frame_0032.png
│               ├── frame_0033.png
│               ├── frame_0034.png
│               ├── frame_0035.png
│               ├── frame_0036.png
│               ├── frame_0037.png
│               ├── frame_0038.png
│               ├── frame_0039.png
│               ├── frame_0040.png
│               ├── frame_0041.png
│               ├── frame_0042.png
│               ├── frame_0043.png
│               ├── frame_0044.png
│               ├── frame_0045.png
│               ├── frame_0046.png
│               ├── frame_0047.png
│               ├── frame_0048.png
│               ├── frame_0049.png
│               └── frame_0050.png
├── doc
│   └── images
│       ├── chicago_27.png
│       ├── chicago_paste.png
│       ├── content.gif
│       ├── in5_result.png
│       ├── photo_content.png
│       └── test.gif
├── libs
│   ├── Criterion.py
│   ├── Loader.py
│   ├── LoaderPhotoReal.py
│   ├── Matrix.py
│   ├── MatrixTest.py
│   ├── SPN.py
│   ├── __init__.py
│   ├── models.py
│   ├── pytorch_spn
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── _ext
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   └── __init__.cpython-36.pyc
│   │   │   └── gaterecurrent2dnoind
│   │   │       ├── __init__.py
│   │   │       ├── __pycache__
│   │   │       │   └── __init__.cpython-36.pyc
│   │   │       └── _gaterecurrent2dnoind.so
│   │   ├── build.py
│   │   ├── functions
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   └── gaterecurrent2dnoind.cpython-36.pyc
│   │   │   ├── gaterecurrent2dnoind.py
│   │   │   └── gaterecurrent2dnoind.pyc
│   │   ├── left_right_demo.py
│   │   ├── make.sh
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   └── gaterecurrent2dnoind.cpython-36.pyc
│   │   │   ├── gaterecurrent2dnoind.py
│   │   │   └── gaterecurrent2dnoind.pyc
│   │   └── src
│   │       ├── cuda
│   │       │   ├── gaterecurrent2dnoind_kernel.cu
│   │       │   ├── gaterecurrent2dnoind_kernel.cu.o
│   │       │   └── gaterecurrent2dnoind_kernel.h
│   │       ├── gaterecurrent2dnoind_cuda.c
│   │       └── gaterecurrent2dnoind_cuda.h
│   ├── smooth_filter.py
│   └── utils.py
└── real-time-demo.py
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2018, SunshineAtNoon
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Learning Linear Transformations for Fast Image and Video Style Transfer
2 | **[[Paper]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Li_Learning_Linear_Transformations_for_Fast_Image_and_Video_Style_Transfer_CVPR_2019_paper.pdf)** **[[Project Page]](https://sites.google.com/view/linear-style-transfer-cvpr19/)**
3 |
4 | ![artistic style transfer examples](doc/images/chicago_paste.png)
5 | ![photo-realistic style transfer examples](doc/images/photo_content.png)
6 |
7 | ## Prerequisites
8 | - [Pytorch](http://pytorch.org/)
9 | - [torchvision](https://github.com/pytorch/vision)
10 | - [opencv](https://opencv.org/) for video generation
11 |
12 | **All code tested on Ubuntu 16.04, pytorch 0.4.1, and opencv 3.4.2**
13 |
14 | ## Style Transfer
15 | - Clone from github: `git clone https://github.com/sunshineatnoon/LinearStyleTransfer`
16 | - Download pre-trained models from [google drive](https://drive.google.com/file/d/1H9T5rfXGlGCUh04DGkpkMFbVnmscJAbs/view?usp=sharing).
17 | - Uncompress it into the repository root:
18 | ```
19 | cd LinearStyleTransfer
20 | unzip models.zip
21 | rm models.zip
22 | ```
23 |
24 | #### Artistic style transfer
25 | ```
26 | python TestArtistic.py
27 | ```
28 | or run style transfer on relu3_1 features:
29 | ```
30 | python TestArtistic.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31
31 | ```
32 |
33 | #### Photo-realistic style transfer
34 | For photo-realistic style transfer, first compile the [pytorch_spn](https://github.com/Liusifei/pytorch_spn) repository:
35 | ```
36 | cd libs/pytorch_spn
37 | sh make.sh
38 | cd ../..
39 | ```
40 | Then:
41 | ```
42 | python TestPhotoReal.py
43 | ```
44 | Note: images with the `_filtered.png` suffix have been refined by the SPN after style transfer; images with the `_smooth.png` suffix have additionally been post-processed by a [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py).
45 |
46 | #### Video style transfer
47 | ```
48 | python TestVideo.py
49 | ```
50 |
51 | #### Real-time video demo
52 | ```
53 | python real-time-demo.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31
54 | ```
55 |
56 | ## Model Training
57 | ### Data Preparation
58 | - MSCOCO
59 | ```
60 | wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip
61 | ```
62 | - WikiArt
63 | - Either manually download from [kaggle](https://www.kaggle.com/c/painter-by-numbers).
64 | - Or install [kaggle-cli](https://github.com/floydwch/kaggle-cli) and download by running:
65 | ```
66 | kg download -u <username> -p <password> -c painter-by-numbers -f train.zip
67 | ```
68 |
69 | ### Training
70 | #### Train a style transfer model
71 | To train a model that transfers relu4_1 features, run:
72 | ```
73 | python Train.py --vgg_dir models/vgg_r41.pth --decoder_dir models/dec_r41.pth --layer r41 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR
74 | ```
75 | or train a model that transfers relu3_1 features:
76 | ```
77 | python Train.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --layer r31 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR
78 | ```
79 | Key hyper-parameters:
80 | - style_layers: the VGG layers on which the style loss is computed.
81 | - style_weight: a larger style weight yields stronger stylization in the transferred images.
82 |
83 | Intermediate results and model weights are saved in `OUTPUT_DIR`.
84 |
85 | #### Train an SPN model to remove distortions for photo-realistic style transfer
86 | Run:
87 | ```
88 | python TrainSPN.py --contentPath PATH_TO_MSCOCO
89 | ```
90 |
91 | ### Acknowledgement
92 | - We use the [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py) by [LouieYang](https://github.com/LouieYang) in the photo-realistic style transfer.
93 |
94 | ### Citation
95 | ```
96 | @inproceedings{li2018learning,
97 | author = {Li, Xueting and Liu, Sifei and Kautz, Jan and Yang, Ming-Hsuan},
98 | title = {Learning Linear Transformations for Fast Arbitrary Style Transfer},
99 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
100 | year = {2019}
101 | }
102 | ```
103 |
--------------------------------------------------------------------------------
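The README describes the method only at a high level. The transformation at its core is linear in the VGG features: a matrix T maps the centered content feature onto the style statistics, and the paper's contribution is to predict T with a light-weight network (see `libs/Matrix.py`) instead of computing it in closed form. As a rough, self-contained sketch of that algebra, here is the closed-form (WCT-style) baseline that the learned T replaces; this is illustrative only, written against current PyTorch rather than the pytorch 0.4.1 the repo targets:

```python
import torch

def linear_transfer(cF: torch.Tensor, sF: torch.Tensor) -> torch.Tensor:
    """cF, sF: (C, H, W) content/style feature maps from the encoder.
    Returns the transferred feature map T @ (cF - mean) + style mean."""
    C, H, W = cF.shape
    c = cF.reshape(C, -1)                          # (C, N)
    s = sF.reshape(C, -1)
    c_mean, s_mean = c.mean(1, keepdim=True), s.mean(1, keepdim=True)
    c, s = c - c_mean, s - s_mean
    cov_c = c @ c.t() / (c.shape[1] - 1) + 1e-5 * torch.eye(C)
    cov_s = s @ s.t() / (s.shape[1] - 1) + 1e-5 * torch.eye(C)
    # closed-form T = cov_s^(1/2) @ cov_c^(-1/2) via eigendecomposition;
    # the paper learns a data-dependent T instead of computing this
    ec, vc = torch.linalg.eigh(cov_c)
    es, vs = torch.linalg.eigh(cov_s)
    whiten = vc @ torch.diag(ec.clamp_min(1e-8).rsqrt()) @ vc.t()
    color = vs @ torch.diag(es.clamp_min(0.0).sqrt()) @ vs.t()
    T = color @ whiten
    return (T @ c + s_mean).reshape(C, H, W)
```

Decoding this transferred feature with the matching decoder yields the stylized image; avoiding the eigendecomposition at test time is what makes the learned version fast.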
/TestArtistic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from libs.Loader import Dataset
5 | from libs.Matrix import MulLayer
6 | import torchvision.utils as vutils
7 | import torch.backends.cudnn as cudnn
8 | from libs.utils import print_options
9 | from libs.models import encoder3,encoder4, encoder5
10 | from libs.models import decoder3,decoder4, decoder5
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
14 | help='pre-trained encoder path')
15 | parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
16 | help='pre-trained decoder path')
17 | parser.add_argument("--matrixPath", default='models/r41.pth',
18 | help='pre-trained model path')
19 | parser.add_argument("--stylePath", default="data/style/",
20 | help='path to style image')
21 | parser.add_argument("--contentPath", default="data/content/",
22 | help='path to frames')
23 | parser.add_argument("--outf", default="Artistic/",
24 | help='path to transferred images')
25 | parser.add_argument("--batchSize", type=int,default=1,
26 | help='batch size')
27 | parser.add_argument('--loadSize', type=int, default=256,
28 | help='scale image size')
29 | parser.add_argument('--fineSize', type=int, default=256,
30 | help='crop image size')
31 | parser.add_argument("--layer", default="r41",
32 | help='which features to transfer, either r31 or r41')
33 |
34 | ################# PREPARATIONS #################
35 | opt = parser.parse_args()
36 | opt.cuda = torch.cuda.is_available()
37 | print_options(opt)
38 |
39 | os.makedirs(opt.outf,exist_ok=True)
40 | cudnn.benchmark = True
41 |
42 | ################# DATA #################
43 | content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize,test=True)
44 | content_loader = torch.utils.data.DataLoader(dataset=content_dataset,
45 | batch_size = opt.batchSize,
46 | shuffle = False,
47 | num_workers = 1)
48 | style_dataset = Dataset(opt.stylePath,opt.loadSize,opt.fineSize,test=True)
49 | style_loader = torch.utils.data.DataLoader(dataset=style_dataset,
50 | batch_size = opt.batchSize,
51 | shuffle = False,
52 | num_workers = 1)
53 |
54 | ################# MODEL #################
55 | if(opt.layer == 'r31'):
56 | vgg = encoder3()
57 | dec = decoder3()
58 | elif(opt.layer == 'r41'):
59 | vgg = encoder4()
60 | dec = decoder4()
61 | matrix = MulLayer(opt.layer)
62 | vgg.load_state_dict(torch.load(opt.vgg_dir))
63 | dec.load_state_dict(torch.load(opt.decoder_dir))
64 | matrix.load_state_dict(torch.load(opt.matrixPath))
65 |
66 | ################# GLOBAL VARIABLE #################
67 | contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
68 | styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
69 |
70 | ################# GPU #################
71 | if(opt.cuda):
72 | vgg.cuda()
73 | dec.cuda()
74 | matrix.cuda()
75 | contentV = contentV.cuda()
76 | styleV = styleV.cuda()
77 |
78 | for ci,(content,contentName) in enumerate(content_loader):
79 | contentName = contentName[0]
80 | contentV.resize_(content.size()).copy_(content)
81 | for sj,(style,styleName) in enumerate(style_loader):
82 | styleName = styleName[0]
83 | styleV.resize_(style.size()).copy_(style)
84 |
85 | # forward
86 | with torch.no_grad():
87 | sF = vgg(styleV)
88 | cF = vgg(contentV)
89 |
90 | if(opt.layer == 'r41'):
91 | feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
92 | else:
93 | feature,transmatrix = matrix(cF,sF)
94 | transfer = dec(feature)
95 |
96 | transfer = transfer.clamp(0,1)
97 | vutils.save_image(transfer,'%s/%s_%s.png'%(opt.outf,contentName,styleName),normalize=True,scale_each=True,nrow=opt.batchSize)
98 | print('Transferred image saved at %s%s_%s.png'%(opt.outf,contentName,styleName))
99 |
--------------------------------------------------------------------------------
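Stripped of the Dataset/DataLoader machinery, TestArtistic.py reduces to one encoder/transform/decoder pass per image pair. A condensed sketch for a single pair, reusing the repo's own modules (the weight and image paths below assume the models.zip layout from the README; add `.cuda()` calls as in the script for GPU use):

```python
import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from libs.models import encoder4, decoder4
from libs.Matrix import MulLayer

prep = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
content = prep(Image.open('data/content/chicago.png').convert('RGB')).unsqueeze(0)
style = prep(Image.open('data/style/27.jpg').convert('RGB')).unsqueeze(0)

vgg, dec, matrix = encoder4(), decoder4(), MulLayer('r41')
vgg.load_state_dict(torch.load('models/vgg_r41.pth'))
dec.load_state_dict(torch.load('models/dec_r41.pth'))
matrix.load_state_dict(torch.load('models/r41.pth'))

with torch.no_grad():
    cF, sF = vgg(content), vgg(style)          # dicts of layer features
    feature, _ = matrix(cF['r41'], sF['r41'])  # returns (feature, transmatrix)
    out = dec(feature).clamp(0, 1)
vutils.save_image(out, 'stylized.png')
```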
/TestPhotoReal.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import time
4 | import torch
5 | import argparse
6 | import numpy as np
7 | from PIL import Image
8 | from libs.SPN import SPN
9 | import torchvision.utils as vutils
10 | from libs.utils import print_options
11 | from libs.MatrixTest import MulLayer
12 | import torch.backends.cudnn as cudnn
13 | from libs.LoaderPhotoReal import Dataset
14 | from libs.models import encoder3,encoder4
15 | from libs.models import decoder3,decoder4
16 | import torchvision.transforms as transforms
17 | from libs.smooth_filter import smooth_filter
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
21 | help='pre-trained encoder path')
22 | parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
23 | help='pre-trained decoder path')
24 | parser.add_argument("--matrixPath", default='models/r41.pth',
25 | help='pre-trained model path')
26 | parser.add_argument("--stylePath", default="data/photo_real/style/images/",
27 | help='path to style image')
28 | parser.add_argument("--styleSegPath", default="data/photo_real/styleSeg/",
29 | help='path to style image masks')
30 | parser.add_argument("--contentPath", default="data/photo_real/content/images/",
31 | help='path to content image')
32 | parser.add_argument("--contentSegPath", default="data/photo_real/contentSeg/",
33 | help='path to content image masks')
34 | parser.add_argument("--outf", default="PhotoReal/",
35 | help='path to save output images')
36 | parser.add_argument("--batchSize", type=int,default=1,
37 | help='batch size')
38 | parser.add_argument('--fineSize', type=int, default=512,
39 | help='image size')
40 | parser.add_argument("--layer", default="r41",
41 | help='features of which layer to transform, either r31 or r41')
42 | parser.add_argument("--spn_dir", default='models/r41_spn.pth',
43 | help='path to pretrained SPN model')
44 |
45 | ################# PREPARATIONS #################
46 | opt = parser.parse_args()
47 | opt.cuda = torch.cuda.is_available()
48 | print_options(opt)
49 |
50 | os.makedirs(opt.outf, exist_ok=True)
51 |
52 | cudnn.benchmark = True
53 |
54 | ################# DATA #################
55 | dataset = Dataset(opt.contentPath,opt.stylePath,opt.contentSegPath,opt.styleSegPath,opt.fineSize)
56 | loader = torch.utils.data.DataLoader(dataset=dataset,
57 | batch_size=1,
58 | shuffle=False)
59 |
60 | ################# MODEL #################
61 | if(opt.layer == 'r31'):
62 | vgg = encoder3()
63 | dec = decoder3()
64 | elif(opt.layer == 'r41'):
65 | vgg = encoder4()
66 | dec = decoder4()
67 | matrix = MulLayer(opt.layer)
68 | vgg.load_state_dict(torch.load(opt.vgg_dir))
69 | dec.load_state_dict(torch.load(opt.decoder_dir))
70 | matrix.load_state_dict(torch.load(opt.matrixPath))
71 | spn = SPN()
72 | spn.load_state_dict(torch.load(opt.spn_dir))
73 |
74 | ################# GLOBAL VARIABLE #################
75 | contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
76 | styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
77 | whitenV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
78 |
79 | ################# GPU #################
80 | if(opt.cuda):
81 | vgg.cuda()
82 | dec.cuda()
83 | spn.cuda()
84 | matrix.cuda()
85 | contentV = contentV.cuda()
86 | styleV = styleV.cuda()
87 | whitenV = whitenV.cuda()
88 |
89 | for i,(contentImg,styleImg,whitenImg,cmasks,smasks,imname) in enumerate(loader):
90 | imname = imname[0]
91 | contentV.resize_(contentImg.size()).copy_(contentImg)
92 | styleV.resize_(styleImg.size()).copy_(styleImg)
93 | whitenV.resize_(whitenImg.size()).copy_(whitenImg)
94 |
95 | # forward
96 | sF = vgg(styleV)
97 | cF = vgg(contentV)
98 |
99 | with torch.no_grad():
100 | if(opt.layer == 'r41'):
101 | feature = matrix(cF[opt.layer],sF[opt.layer],cmasks,smasks)
102 | else:
103 | feature = matrix(cF,sF,cmasks,smasks)
104 | transfer = dec(feature)
105 | filtered = spn(transfer,whitenV)
106 | vutils.save_image(transfer,os.path.join(opt.outf,'%s_transfer.png'%(imname.split('.')[0])))
107 |
108 | filtered = filtered.clamp(0,1)
109 | filtered = filtered.cpu()
110 | vutils.save_image(filtered,'%s/%s_filtered.png'%(opt.outf,imname.split('.')[0]))
111 | out_img = filtered.squeeze(0).mul(255).clamp(0,255).byte().permute(1,2,0).cpu().numpy()
112 | content = contentImg.squeeze(0).mul(255).clamp(0,255).byte().permute(1,2,0).cpu().numpy()
113 | content = content.copy()
114 | out_img = out_img.copy()
115 | smoothed = smooth_filter(out_img, content, f_radius=15, f_edge=1e-1)
116 | smoothed.save('%s/%s_smooth.png'%(opt.outf,imname.split('.')[0]))
117 |     print('Transferred image saved at %s%s_transfer.png, filtered image saved at %s%s_filtered.png' \
118 |           %(opt.outf,imname.split('.')[0],opt.outf,imname.split('.')[0]))
119 |
--------------------------------------------------------------------------------
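The post-processing order in TestPhotoReal.py is: decode the transferred feature, refine with the SPN against the whitened content, then apply the local-affine smooth filter. The tensor-to-image conversion feeding `smooth_filter` recurs twice and can be factored into a helper; a sketch, with the `smooth_filter` signature taken from the call above:

```python
import torch

def to_uint8_hwc(t: torch.Tensor):
    """(1, 3, H, W) float tensor in [0, 1] -> contiguous H x W x 3 uint8
    numpy array, the layout smooth_filter() expects."""
    return (t.squeeze(0).mul(255).clamp(0, 255)
             .byte().permute(1, 2, 0).cpu().numpy().copy())

# usage, mirroring the loop body above:
# smoothed = smooth_filter(to_uint8_hwc(filtered), to_uint8_hwc(contentImg),
#                          f_radius=15, f_edge=1e-1)
```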
/TestVideo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from PIL import Image
5 | from libs.Loader import Dataset
6 | from libs.Matrix import MulLayer
7 | import torch.backends.cudnn as cudnn
8 | from libs.models import encoder3,encoder4
9 | from libs.models import decoder3,decoder4
10 | import torchvision.transforms as transforms
11 | from libs.utils import makeVideo, print_options
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--vgg_dir", default='models/vgg_r31.pth',
15 | help='pre-trained encoder path')
16 | parser.add_argument("--decoder_dir", default='models/dec_r31.pth',
17 | help='pre-trained decoder path')
18 | parser.add_argument("--matrix_dir", default="models/r31.pth",
19 | help='path to pre-trained model')
20 | parser.add_argument("--style", default="data/style/in2.jpg",
21 | help='path to style image')
22 | parser.add_argument("--content_dir", default="data/videos/content/mountain_2/",
23 | help='path to video frames')
24 | parser.add_argument('--loadSize', type=int, default=512,
25 | help='scale image size')
26 | parser.add_argument('--fineSize', type=int, default=512,
27 | help='crop image size')
28 | parser.add_argument("--name",default="transferred_video",
29 | help="name of generated video")
30 | parser.add_argument("--layer",default="r31",
31 | help="features of which layer to transform")
32 | parser.add_argument("--outf",default="videos",
33 | help="output folder")
34 |
35 | ################# PREPARATIONS #################
36 | opt = parser.parse_args()
37 | opt.cuda = torch.cuda.is_available()
38 | print_options(opt)
39 |
40 | os.makedirs(opt.outf,exist_ok=True)
41 | cudnn.benchmark = True
42 |
43 | ################# DATA #################
44 | def loadImg(imgPath):
45 | img = Image.open(imgPath).convert('RGB')
46 | transform = transforms.Compose([
47 |         transforms.Resize(opt.fineSize),
48 | transforms.ToTensor()])
49 | return transform(img)
50 | styleV = loadImg(opt.style).unsqueeze(0)
51 |
52 | content_dataset = Dataset(opt.content_dir,
53 | loadSize = opt.loadSize,
54 | fineSize = opt.fineSize,
55 | test = True,
56 | video = True)
57 | content_loader = torch.utils.data.DataLoader(dataset = content_dataset,
58 | batch_size = 1,
59 | shuffle = False)
60 |
61 | ################# MODEL #################
62 | if(opt.layer == 'r31'):
63 | vgg = encoder3()
64 | dec = decoder3()
65 | elif(opt.layer == 'r41'):
66 | vgg = encoder4()
67 | dec = decoder4()
68 | matrix = MulLayer(layer=opt.layer)
69 | vgg.load_state_dict(torch.load(opt.vgg_dir))
70 | dec.load_state_dict(torch.load(opt.decoder_dir))
71 | matrix.load_state_dict(torch.load(opt.matrix_dir))
72 |
73 | ################# GLOBAL VARIABLE #################
74 | contentV = torch.Tensor(1,3,opt.fineSize,opt.fineSize)
75 |
76 | ################# GPU #################
77 | if(opt.cuda):
78 | vgg.cuda()
79 | dec.cuda()
80 | matrix.cuda()
81 |
82 | styleV = styleV.cuda()
83 | contentV = contentV.cuda()
84 |
85 | result_frames = []
86 | contents = []
87 | style = styleV.squeeze(0).cpu().numpy()
88 | sF = vgg(styleV)
89 |
90 | for i,(content,contentName) in enumerate(content_loader):
91 | print('Transfer frame %d...'%i)
92 | contentName = contentName[0]
93 | contentV.resize_(content.size()).copy_(content)
94 | contents.append(content.squeeze(0).float().numpy())
95 | # forward
96 | with torch.no_grad():
97 | cF = vgg(contentV)
98 |
99 | if(opt.layer == 'r41'):
100 | feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
101 | else:
102 | feature,transmatrix = matrix(cF,sF)
103 | transfer = dec(feature)
104 |
105 | transfer = transfer.clamp(0,1)
106 | result_frames.append(transfer.squeeze(0).cpu().numpy())
107 |
108 | makeVideo(contents,style,result_frames,opt.outf)
109 |
--------------------------------------------------------------------------------
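`makeVideo` lives in `libs/utils.py` and is not shown in this dump. For reference, a generic OpenCV writer for the `(3, H, W)` float frames collected above could look like the sketch below; this is not the repo's implementation, and codec availability depends on the local OpenCV build:

```python
import cv2
import numpy as np

def write_video(frames, path='transferred_video.mp4', fps=24):
    """frames: list of (3, H, W) float32 arrays in [0, 1], RGB channel order."""
    h, w = frames[0].shape[1], frames[0].shape[2]
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    for f in frames:
        rgb = (np.transpose(f, (1, 2, 0)) * 255).astype(np.uint8)
        writer.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))  # OpenCV expects BGR
    writer.release()
```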
/Train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from libs.Loader import Dataset
7 | from libs.Matrix import MulLayer
8 | import torchvision.utils as vutils
9 | import torch.backends.cudnn as cudnn
10 | from libs.utils import print_options
11 | from libs.Criterion import LossCriterion
12 | from libs.models import encoder3,encoder4
13 | from libs.models import decoder3,decoder4
14 | from libs.models import encoder5 as loss_network
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
18 | help='pre-trained encoder path')
19 | parser.add_argument("--loss_network_dir", default='models/vgg_r51.pth',
20 | help='used for loss network')
21 | parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
22 | help='pre-trained decoder path')
23 | parser.add_argument("--stylePath", default="/home/xtli/DATA/wikiArt/train/images/",
24 | help='path to wikiArt dataset')
25 | parser.add_argument("--contentPath", default="/home/xtli/DATA/MSCOCO/train2014/images/",
26 | help='path to MSCOCO dataset')
27 | parser.add_argument("--outf", default="trainingOutput/",
28 | help='folder to output images and model checkpoints')
29 | parser.add_argument("--content_layers", default="r41",
30 | help='layers for content')
31 | parser.add_argument("--style_layers", default="r11,r21,r31,r41",
32 | help='layers for style')
33 | parser.add_argument("--batchSize", type=int,default=8,
34 | help='batch size')
35 | parser.add_argument("--niter", type=int,default=100000,
36 | help='iterations to train the model')
37 | parser.add_argument('--loadSize', type=int, default=300,
38 | help='scale image size')
39 | parser.add_argument('--fineSize', type=int, default=256,
40 | help='crop image size')
41 | parser.add_argument("--lr", type=float, default=1e-4,
42 | help='learning rate')
43 | parser.add_argument("--content_weight", type=float, default=1.0,
44 | help='content loss weight')
45 | parser.add_argument("--style_weight", type=float, default=0.02,
46 | help='style loss weight')
47 | parser.add_argument("--log_interval", type=int, default=500,
48 | help='log interval')
49 | parser.add_argument("--gpu_id", type=int, default=0,
50 | help='which gpu to use')
51 | parser.add_argument("--save_interval", type=int, default=5000,
52 | help='checkpoint save interval')
53 | parser.add_argument("--layer", default="r41",
54 | help='which features to transfer, either r31 or r41')
55 |
56 | ################# PREPARATIONS #################
57 | opt = parser.parse_args()
58 | opt.content_layers = opt.content_layers.split(',')
59 | opt.style_layers = opt.style_layers.split(',')
60 | opt.cuda = torch.cuda.is_available()
61 | if(opt.cuda):
62 | torch.cuda.set_device(opt.gpu_id)
63 |
64 | os.makedirs(opt.outf,exist_ok=True)
65 | cudnn.benchmark = True
66 | print_options(opt)
67 |
68 | ################# DATA #################
69 | content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize)
70 | content_loader_ = torch.utils.data.DataLoader(dataset = content_dataset,
71 | batch_size = opt.batchSize,
72 | shuffle = True,
73 | num_workers = 1,
74 | drop_last = True)
75 | content_loader = iter(content_loader_)
76 | style_dataset = Dataset(opt.stylePath,opt.loadSize,opt.fineSize)
77 | style_loader_ = torch.utils.data.DataLoader(dataset = style_dataset,
78 | batch_size = opt.batchSize,
79 | shuffle = True,
80 | num_workers = 1,
81 | drop_last = True)
82 | style_loader = iter(style_loader_)
83 |
84 | ################# MODEL #################
85 | vgg5 = loss_network()
86 | if(opt.layer == 'r31'):
87 | matrix = MulLayer('r31')
88 | vgg = encoder3()
89 | dec = decoder3()
90 | elif(opt.layer == 'r41'):
91 | matrix = MulLayer('r41')
92 | vgg = encoder4()
93 | dec = decoder4()
94 | vgg.load_state_dict(torch.load(opt.vgg_dir))
95 | dec.load_state_dict(torch.load(opt.decoder_dir))
96 | vgg5.load_state_dict(torch.load(opt.loss_network_dir))
97 |
98 | for param in vgg.parameters():
99 | param.requires_grad = False
100 | for param in vgg5.parameters():
101 | param.requires_grad = False
102 | for param in dec.parameters():
103 | param.requires_grad = False
104 |
105 | ################# LOSS & OPTIMIZER #################
106 | criterion = LossCriterion(opt.style_layers,
107 | opt.content_layers,
108 | opt.style_weight,
109 | opt.content_weight)
110 | optimizer = optim.Adam(matrix.parameters(), opt.lr)
111 |
112 | ################# GLOBAL VARIABLE #################
113 | contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
114 | styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
115 |
116 | ################# GPU #################
117 | if(opt.cuda):
118 | vgg.cuda()
119 | dec.cuda()
120 | vgg5.cuda()
121 | matrix.cuda()
122 | contentV = contentV.cuda()
123 | styleV = styleV.cuda()
124 |
125 | ################# TRAINING #################
126 | def adjust_learning_rate(optimizer, iteration):
127 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
128 | for param_group in optimizer.param_groups:
129 | param_group['lr'] = opt.lr / (1+iteration*1e-5)
130 |
131 | for iteration in range(1,opt.niter+1):
132 | optimizer.zero_grad()
133 |     try:
134 |         content,_ = next(content_loader)
135 |     except IOError:
136 |         content,_ = next(content_loader)
137 |     except StopIteration:
138 |         content_loader = iter(content_loader_)
139 |         content,_ = next(content_loader)
140 |     except:
141 |         continue
142 |
143 |     try:
144 |         style,_ = next(style_loader)
145 |     except IOError:
146 |         style,_ = next(style_loader)
147 |     except StopIteration:
148 |         style_loader = iter(style_loader_)
149 |         style,_ = next(style_loader)
150 |     except:
151 |         continue
152 |
153 | contentV.resize_(content.size()).copy_(content)
154 | styleV.resize_(style.size()).copy_(style)
155 |
156 | # forward
157 | sF = vgg(styleV)
158 | cF = vgg(contentV)
159 |
160 | if(opt.layer == 'r41'):
161 | feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
162 | else:
163 | feature,transmatrix = matrix(cF,sF)
164 | transfer = dec(feature)
165 |
166 | sF_loss = vgg5(styleV)
167 | cF_loss = vgg5(contentV)
168 | tF = vgg5(transfer)
169 | loss,styleLoss,contentLoss = criterion(tF,sF_loss,cF_loss)
170 |
171 | # backward & optimization
172 | loss.backward()
173 | optimizer.step()
174 |     print('Iteration: [%d/%d] Loss: %.4f contentLoss: %.4f styleLoss: %.4f Learning Rate: %.6f'%
175 |           (iteration,opt.niter,loss,contentLoss,styleLoss,optimizer.param_groups[0]['lr']))
176 |
177 | adjust_learning_rate(optimizer,iteration)
178 |
179 | if((iteration) % opt.log_interval == 0):
180 | transfer = transfer.clamp(0,1)
181 | concat = torch.cat((content,style,transfer.cpu()),dim=0)
182 | vutils.save_image(concat,'%s/%d.png'%(opt.outf,iteration),normalize=True,scale_each=True,nrow=opt.batchSize)
183 |
184 | if(iteration > 0 and (iteration) % opt.save_interval == 0):
185 | torch.save(matrix.state_dict(), '%s/%s.pth' % (opt.outf,opt.layer))
186 |
--------------------------------------------------------------------------------
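The schedule in `adjust_learning_rate` is inverse-time decay, lr = lr0 / (1 + 1e-5 * iteration). With the default `--lr 1e-4` it decays gently over the default 100k iterations:

```python
lr0 = 1e-4
for it in (0, 10000, 50000, 100000):
    print(it, lr0 / (1 + it * 1e-5))
# 0      -> 1.00e-4
# 10000  -> ~9.09e-5  (lr0 / 1.1)
# 50000  -> ~6.67e-5  (lr0 / 1.5)
# 100000 -> 5.00e-5   (lr0 / 2.0)
```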
/TrainSPN.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import argparse
4 |
5 | from libs.SPN import SPN
6 | from libs.Loader import Dataset
7 | from libs.models import encoder3,encoder4
8 | from libs.models import decoder3,decoder4
9 | from libs.utils import print_options
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.optim as optim
14 | import torch.nn.functional as F
15 | import torchvision.utils as vutils
16 | import torch.backends.cudnn as cudnn
17 | import torchvision.transforms as transforms
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
21 | help='pre-trained encoder path')
22 | parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
23 | help='pre-trained decoder path')
24 | parser.add_argument("--contentPath", default="/home/xtli/DATA/MSCOCO/train2014/images/",
25 | help='path to MSCOCO dataset')
26 | parser.add_argument("--outf", default="trainingSPNOutput/",
27 | help='folder to output images and model checkpoints')
28 | parser.add_argument("--layer", default="r41",
29 | help='layers for content')
30 | parser.add_argument("--batchSize", type=int,default=8,
31 | help='batch size')
32 | parser.add_argument("--niter", type=int,default=100000,
33 | help='iterations to train the model')
34 | parser.add_argument('--loadSize', type=int, default=512,
35 | help='scale image size')
36 | parser.add_argument('--fineSize', type=int, default=256,
37 | help='crop image size')
38 | parser.add_argument("--lr", type=float, default=1e-3,
39 | help='learning rate')
40 | parser.add_argument("--log_interval", type=int, default=500,
41 | help='log interval')
42 | parser.add_argument("--save_interval", type=int, default=5000,
43 | help='checkpoint save interval')
44 | parser.add_argument("--spn_num", type=int, default=1,
45 | help='number of spn filters')
46 |
47 | ################# PREPARATIONS #################
48 | opt = parser.parse_args()
49 | opt.cuda = torch.cuda.is_available()
50 | print_options(opt)
51 |
52 |
53 | os.makedirs(opt.outf, exist_ok = True)
54 |
55 | cudnn.benchmark = True
56 |
57 | ################# DATA #################
58 | content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize)
59 | content_loader_ = torch.utils.data.DataLoader(dataset=content_dataset,
60 | batch_size = opt.batchSize,
61 | shuffle = True,
62 | num_workers = 4,
63 | drop_last = True)
64 | content_loader = iter(content_loader_)
65 |
66 | ################# MODEL #################
67 | spn = SPN(spn=opt.spn_num)
68 | if(opt.layer == 'r31'):
69 | vgg = encoder3()
70 | dec = decoder3()
71 | elif(opt.layer == 'r41'):
72 | vgg = encoder4()
73 | dec = decoder4()
74 | vgg.load_state_dict(torch.load(opt.vgg_dir))
75 | dec.load_state_dict(torch.load(opt.decoder_dir))
76 |
77 | for param in vgg.parameters():
78 | param.requires_grad = False
79 | for param in dec.parameters():
80 | param.requires_grad = False
81 |
82 | ################# LOSS & OPTIMIZER #################
83 | criterion = nn.MSELoss(size_average=False)
84 | #optimizer_spn = optim.SGD(spn.parameters(), opt.lr)
85 | optimizer_spn = optim.Adam(spn.parameters(), opt.lr)
86 |
87 | ################# GLOBAL VARIABLE #################
88 | contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
89 |
90 | ################# GPU #################
91 | if(opt.cuda):
92 | vgg.cuda()
93 | dec.cuda()
94 | spn.cuda()
95 | contentV = contentV.cuda()
96 |
97 | ################# TRAINING #################
98 | def adjust_learning_rate(optimizer, iteration):
99 | for param_group in optimizer.param_groups:
100 | param_group['lr'] = opt.lr / (1+iteration*1e-5)
101 |
102 | spn.train()
103 | for iteration in range(1,opt.niter+1):
104 | optimizer_spn.zero_grad()
105 |     try:
106 |         content,_ = next(content_loader)
107 |     except IOError:
108 |         content,_ = next(content_loader)
109 |     except StopIteration:
110 |         content_loader = iter(content_loader_)
111 |         content,_ = next(content_loader)
112 |     except:
113 |         continue
114 |
115 | contentV.resize_(content.size()).copy_(content)
116 |
117 | # forward
118 | cF = vgg(contentV)
119 |     transfer = dec(cF[opt.layer] if opt.layer == 'r41' else cF)
120 |
121 |
122 | propagated = spn(transfer,contentV)
123 | loss = criterion(propagated,contentV)
124 |
125 | # backward & optimization
126 | loss.backward()
127 | #nn.utils.clip_grad_norm(spn.parameters(), 1000)
128 | optimizer_spn.step()
129 |     print('Iteration: [%d/%d] Loss: %.4f Learning Rate: %.6f'
130 |           %(iteration,opt.niter,loss,optimizer_spn.param_groups[0]['lr']))
131 |
132 | adjust_learning_rate(optimizer_spn,iteration)
133 |
134 | if((iteration) % opt.log_interval == 0):
135 | transfer = transfer.clamp(0,1)
136 | propagated = propagated.clamp(0,1)
137 | vutils.save_image(transfer,'%s/%d_transfer.png'%(opt.outf,iteration))
138 | vutils.save_image(propagated,'%s/%d_propagated.png'%(opt.outf,iteration))
139 |
140 | if(iteration > 0 and (iteration) % opt.save_interval == 0):
141 | torch.save(spn.state_dict(), '%s/%s_spn.pth' % (opt.outf,opt.layer))
142 |
--------------------------------------------------------------------------------
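One portability note on the loss above: `nn.MSELoss(size_average=False)` is the pytorch 0.4.x spelling of a sum-reduced MSE, matching the version the README targets. On current PyTorch that argument is deprecated; the equivalent is:

```python
import torch.nn as nn

# sum over all elements instead of averaging, as size_average=False did
criterion = nn.MSELoss(reduction='sum')
```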
/data/content/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/content/1.jpg
--------------------------------------------------------------------------------
/data/content/chicago.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/content/chicago.png
--------------------------------------------------------------------------------
/data/photo_real/content/images/in16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in16.png
--------------------------------------------------------------------------------
/data/photo_real/content/images/in25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in25.png
--------------------------------------------------------------------------------
/data/photo_real/content/images/in26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in26.png
--------------------------------------------------------------------------------
/data/photo_real/content/images/in29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in29.png
--------------------------------------------------------------------------------
/data/photo_real/content/images/in3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in3.png
--------------------------------------------------------------------------------
/data/photo_real/content/images/in39.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in39.png
--------------------------------------------------------------------------------
/data/photo_real/content/images/in53.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in53.png
--------------------------------------------------------------------------------
/data/photo_real/content/images/in7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/content/images/in7.png
--------------------------------------------------------------------------------
/data/photo_real/contentSeg/in16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in16.png
--------------------------------------------------------------------------------
/data/photo_real/contentSeg/in25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in25.png
--------------------------------------------------------------------------------
/data/photo_real/contentSeg/in26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in26.png
--------------------------------------------------------------------------------
/data/photo_real/contentSeg/in29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in29.png
--------------------------------------------------------------------------------
/data/photo_real/contentSeg/in3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in3.png
--------------------------------------------------------------------------------
/data/photo_real/contentSeg/in39.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in39.png
--------------------------------------------------------------------------------
/data/photo_real/contentSeg/in53.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in53.png
--------------------------------------------------------------------------------
/data/photo_real/contentSeg/in7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/contentSeg/in7.png
--------------------------------------------------------------------------------
/data/photo_real/style/images/in16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in16.png
--------------------------------------------------------------------------------
/data/photo_real/style/images/in25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in25.png
--------------------------------------------------------------------------------
/data/photo_real/style/images/in26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in26.png
--------------------------------------------------------------------------------
/data/photo_real/style/images/in29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in29.png
--------------------------------------------------------------------------------
/data/photo_real/style/images/in3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in3.png
--------------------------------------------------------------------------------
/data/photo_real/style/images/in39.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in39.png
--------------------------------------------------------------------------------
/data/photo_real/style/images/in53.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in53.png
--------------------------------------------------------------------------------
/data/photo_real/style/images/in7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/style/images/in7.png
--------------------------------------------------------------------------------
/data/photo_real/styleSeg/in16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in16.png
--------------------------------------------------------------------------------
/data/photo_real/styleSeg/in25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in25.png
--------------------------------------------------------------------------------
/data/photo_real/styleSeg/in26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in26.png
--------------------------------------------------------------------------------
/data/photo_real/styleSeg/in29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in29.png
--------------------------------------------------------------------------------
/data/photo_real/styleSeg/in3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in3.png
--------------------------------------------------------------------------------
/data/photo_real/styleSeg/in39.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in39.png
--------------------------------------------------------------------------------
/data/photo_real/styleSeg/in53.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in53.png
--------------------------------------------------------------------------------
/data/photo_real/styleSeg/in7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/photo_real/styleSeg/in7.png
--------------------------------------------------------------------------------
/data/style/27.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/27.jpg
--------------------------------------------------------------------------------
/data/style/3314.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/3314.jpg
--------------------------------------------------------------------------------
/data/style/antimonocromatismo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/antimonocromatismo.jpg
--------------------------------------------------------------------------------
/data/style/in2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/in2.jpg
--------------------------------------------------------------------------------
/data/style/picasso_self_portrait.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/picasso_self_portrait.jpg
--------------------------------------------------------------------------------
/data/style/sketch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/style/sketch.jpg
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0001.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0002.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0003.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0004.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0005.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0006.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0007.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0008.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0009.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0010.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0011.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0011.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0012.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0012.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0013.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0013.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0014.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0014.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0015.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0015.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0016.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0017.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0018.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0018.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0019.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0019.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0020.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0021.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0021.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0022.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0022.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0023.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0023.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0024.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0025.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0026.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0026.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0027.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0027.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0028.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0028.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0029.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0029.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0030.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0030.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0031.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0031.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0032.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0032.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0033.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0033.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0034.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0034.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0035.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0035.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0036.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0036.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0037.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0037.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0038.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0038.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0039.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0039.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0040.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0040.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0041.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0041.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0042.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0042.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0043.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0043.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0044.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0044.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0045.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0045.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0046.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0046.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0047.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0047.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0048.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0048.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0049.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0049.png
--------------------------------------------------------------------------------
/data/videos/content/mountain_2/frame_0050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/data/videos/content/mountain_2/frame_0050.png
--------------------------------------------------------------------------------
/doc/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/.DS_Store
--------------------------------------------------------------------------------
/doc/images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/.DS_Store
--------------------------------------------------------------------------------
/doc/images/chicago_27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/chicago_27.png
--------------------------------------------------------------------------------
/doc/images/chicago_paste.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/chicago_paste.png
--------------------------------------------------------------------------------
/doc/images/content.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/content.gif
--------------------------------------------------------------------------------
/doc/images/in5_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/in5_result.png
--------------------------------------------------------------------------------
/doc/images/photo_content.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/photo_content.png
--------------------------------------------------------------------------------
/doc/images/test.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/doc/images/test.gif
--------------------------------------------------------------------------------
/libs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/.DS_Store
--------------------------------------------------------------------------------
/libs/Criterion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class styleLoss(nn.Module):
5 | def forward(self,input,target):
6 | ib,ic,ih,iw = input.size()
7 | iF = input.view(ib,ic,-1)
8 | iMean = torch.mean(iF,dim=2)
9 | iCov = GramMatrix()(input)
10 |
11 | tb,tc,th,tw = target.size()
12 | tF = target.view(tb,tc,-1)
13 | tMean = torch.mean(tF,dim=2)
14 | tCov = GramMatrix()(target)
15 |
16 |         loss = nn.MSELoss(reduction='sum')(iMean,tMean) + nn.MSELoss(reduction='sum')(iCov,tCov)
17 | return loss/tb
18 |
19 | class GramMatrix(nn.Module):
20 | def forward(self,input):
21 | b, c, h, w = input.size()
22 | f = input.view(b,c,h*w) # bxcx(hxw)
23 | # torch.bmm(batch1, batch2, out=None) #
24 | # batch1: bxmxp, batch2: bxpxn -> bxmxn #
25 | G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
26 | return G.div_(c*h*w)
27 |
28 | class LossCriterion(nn.Module):
29 | def __init__(self,style_layers,content_layers,style_weight,content_weight):
30 | super(LossCriterion,self).__init__()
31 |
32 | self.style_layers = style_layers
33 | self.content_layers = content_layers
34 | self.style_weight = style_weight
35 | self.content_weight = content_weight
36 |
37 | self.styleLosses = [styleLoss()] * len(style_layers)
38 | self.contentLosses = [nn.MSELoss()] * len(content_layers)
39 |
40 | def forward(self,tF,sF,cF):
41 | # content loss
42 | totalContentLoss = 0
43 | for i,layer in enumerate(self.content_layers):
44 | cf_i = cF[layer]
45 | cf_i = cf_i.detach()
46 | tf_i = tF[layer]
47 | loss_i = self.contentLosses[i]
48 | totalContentLoss += loss_i(tf_i,cf_i)
49 | totalContentLoss = totalContentLoss * self.content_weight
50 |
51 | # style loss
52 | totalStyleLoss = 0
53 | for i,layer in enumerate(self.style_layers):
54 | sf_i = sF[layer]
55 | sf_i = sf_i.detach()
56 | tf_i = tF[layer]
57 | loss_i = self.styleLosses[i]
58 | totalStyleLoss += loss_i(tf_i,sf_i)
59 | totalStyleLoss = totalStyleLoss * self.style_weight
60 | loss = totalStyleLoss + totalContentLoss
61 |
62 | return loss,totalStyleLoss,totalContentLoss
63 |
--------------------------------------------------------------------------------
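A minimal usage sketch for the criterion above (not part of the repository; the random feature dicts merely stand in for encoder activations keyed 'r11', 'r21', ..., as produced by libs/models.py):

    import torch
    from libs.Criterion import LossCriterion

    criterion = LossCriterion(style_layers=['r11','r21','r31'],
                              content_layers=['r41'],
                              style_weight=0.02, content_weight=1.0)
    chans = {'r11': 64, 'r21': 128, 'r31': 256, 'r41': 512}
    tF = {k: torch.randn(1, c, 32, 32, requires_grad=True) for k, c in chans.items()}
    sF = {k: torch.randn(1, c, 32, 32) for k, c in chans.items()}
    cF = {k: torch.randn(1, c, 32, 32) for k, c in chans.items()}
    loss, styleLoss, contentLoss = criterion(tF, sF, cF)
    loss.backward()  # gradients flow only into tF; sF and cF are detached inside

--------------------------------------------------------------------------------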
/libs/Loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 |
6 | def is_image_file(filename):
7 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
8 |
9 | def default_loader(path):
10 | return Image.open(path).convert('RGB')
11 |
12 | class Dataset(data.Dataset):
13 | def __init__(self,dataPath,loadSize,fineSize,test=False,video=False):
14 | super(Dataset,self).__init__()
15 | self.dataPath = dataPath
16 | self.image_list = [x for x in os.listdir(dataPath) if is_image_file(x)]
17 |         self.image_list = sorted(self.image_list)
18 |         # the sort above already fixes frame order, so the video flag
19 |         # needs no extra handling here
20 | if not test:
21 | self.transform = transforms.Compose([
22 | transforms.Resize(fineSize),
23 | transforms.RandomCrop(fineSize),
24 | transforms.RandomHorizontalFlip(),
25 | transforms.ToTensor()])
26 | else:
27 | self.transform = transforms.Compose([
28 | transforms.Resize(fineSize),
29 | transforms.ToTensor()])
30 |
31 | self.test = test
32 |
33 | def __getitem__(self,index):
34 | dataPath = os.path.join(self.dataPath,self.image_list[index])
35 |
36 | Img = default_loader(dataPath)
37 | ImgA = self.transform(Img)
38 |
39 | imgName = self.image_list[index]
40 | imgName = imgName.split('.')[0]
41 | return ImgA,imgName
42 |
43 | def __len__(self):
44 | return len(self.image_list)
45 |
--------------------------------------------------------------------------------
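A minimal usage sketch (not part of the repository; the path points at the bundled sample images). In test mode the transform only resizes the shorter side to fineSize, so images keep their aspect ratio and a batch size of 1 avoids shape mismatches:

    import torch.utils.data as data
    from libs.Loader import Dataset

    # note: loadSize is accepted but unused by this Dataset
    dataset = Dataset('data/content/', loadSize=300, fineSize=256, test=True)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    for img, name in loader:
        print(name[0], img.shape)  # e.g. chicago torch.Size([1, 3, 256, W])

--------------------------------------------------------------------------------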
/libs/LoaderPhotoReal.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torchvision.transforms as transforms
3 | import torchvision.utils as vutils
4 | import torch.utils.data as data
5 | from os import listdir
6 | from os.path import join
7 | import numpy as np
8 | import torch
9 | import os
10 | import torch.nn as nn
11 | from torch.autograd import Variable
12 |
13 | from libs.utils import whiten
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
17 |
18 | def default_loader(path,fineSize):
19 | img = Image.open(path).convert('RGB')
20 | w,h = img.size
21 | if(w < h):
22 | neww = fineSize
23 | newh = h * neww / w
24 | newh = int(newh / 8) * 8
25 | else:
26 | newh = fineSize
27 | neww = w * newh / h
28 | neww = int(neww / 8) * 8
29 | img = img.resize((neww,newh))
30 | return img
31 |
32 | def MaskHelper(seg,color):
33 |     # build one boolean mask for the requested segmentation color
34 | mask = torch.Tensor()
35 | if(color == 'green'):
36 | mask = torch.lt(seg[0],0.1)
37 | mask = torch.mul(mask,torch.gt(seg[1],1-0.1))
38 | mask = torch.mul(mask,torch.lt(seg[2],0.1))
39 | elif(color == 'black'):
40 | mask = torch.lt(seg[0], 0.1)
41 | mask = torch.mul(mask,torch.lt(seg[1], 0.1))
42 | mask = torch.mul(mask,torch.lt(seg[2], 0.1))
43 | elif(color == 'white'):
44 | mask = torch.gt(seg[0], 1-0.1)
45 | mask = torch.mul(mask,torch.gt(seg[1], 1-0.1))
46 | mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
47 | elif(color == 'red'):
48 | mask = torch.gt(seg[0], 1-0.1)
49 | mask = torch.mul(mask,torch.lt(seg[1], 0.1))
50 | mask = torch.mul(mask,torch.lt(seg[2], 0.1))
51 | elif(color == 'blue'):
52 | mask = torch.lt(seg[0], 0.1)
53 | mask = torch.mul(mask,torch.lt(seg[1], 0.1))
54 | mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
55 | elif(color == 'yellow'):
56 | mask = torch.gt(seg[0], 1-0.1)
57 | mask = torch.mul(mask,torch.gt(seg[1], 1-0.1))
58 | mask = torch.mul(mask,torch.lt(seg[2], 0.1))
59 |     elif(color == 'grey'):  # note: same thresholds as 'black'; mid-grey pixels never match
60 | mask = torch.lt(seg[0], 0.1)
61 | mask = torch.mul(mask,torch.lt(seg[1], 0.1))
62 | mask = torch.mul(mask,torch.lt(seg[2], 0.1))
63 | elif(color == 'lightblue'):
64 | mask = torch.lt(seg[0], 0.1)
65 | mask = torch.mul(mask,torch.gt(seg[1], 1-0.1))
66 | mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
67 | elif(color == 'purple'):
68 | mask = torch.gt(seg[0], 1-0.1)
69 | mask = torch.mul(mask,torch.lt(seg[1], 0.1))
70 | mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
71 | else:
72 | print('MaskHelper(): color not recognized, color = ' + color)
73 | return mask.float()
74 |
75 | def ExtractMask(Seg):
76 |     # Given a segmentation image, return one mask per supported color
77 | '''
78 | Test Code:
79 |     content_masks = ExtractMask(contentSegImg); style_masks = ExtractMask(styleSegImg)
80 | for i,mask in enumerate(content_masks):
81 | vutils.save_image(mask,'samples/content_%d.png' % (i),normalize=True)
82 | for i,mask in enumerate(style_masks):
83 | vutils.save_image(mask,'samples/style_%d.png' % (i),normalize=True)
84 | '''
85 | color_codes = ['blue', 'green', 'black', 'white', 'red', 'yellow', 'grey', 'lightblue', 'purple']
86 | masks = []
87 | for color in color_codes:
88 | mask = MaskHelper(Seg,color)
89 | masks.append(mask)
90 | return masks
91 |
92 | def calculate_size(h,w,fineSize):
93 | if(h > w):
94 | newh = fineSize
95 | neww = int(w * 1.0 * newh / h)
96 | else:
97 | neww = fineSize
98 | newh = int(h * 1.0 * neww / w)
99 | newh = (newh // 8) * 8
100 | neww = (neww // 8) * 8
101 | return neww, newh
102 |
103 | class Dataset(data.Dataset):
104 | def __init__(self,contentPath,stylePath,contentSegPath,styleSegPath,fineSize):
105 | super(Dataset,self).__init__()
106 | self.contentPath = contentPath
107 | self.image_list = [x for x in listdir(contentPath) if is_image_file(x)]
108 | self.stylePath = stylePath
109 | self.contentSegPath = contentSegPath
110 | self.styleSegPath = styleSegPath
111 | self.fineSize = fineSize
112 |
113 | def __getitem__(self,index):
114 | contentImgPath = os.path.join(self.contentPath,self.image_list[index])
115 | styleImgPath = os.path.join(self.stylePath,self.image_list[index])
116 | contentImg = default_loader(contentImgPath,self.fineSize)
117 | styleImg = default_loader(styleImgPath,self.fineSize)
118 |
119 | try:
120 | contentSegImgPath = os.path.join(self.contentSegPath,self.image_list[index])
121 | contentSegImg = default_loader(contentSegImgPath,self.fineSize)
122 |         except Exception:
123 |             print('no content mask provided, using an all-black placeholder')
124 | contentSegImg = Image.new('RGB', (contentImg.size))
125 |
126 | try:
127 | styleSegImgPath = os.path.join(self.styleSegPath,self.image_list[index])
128 | styleSegImg = default_loader(styleSegImgPath,self.fineSize)
129 |         except Exception:
130 |             print('no style mask provided, using an all-black placeholder')
131 | styleSegImg = Image.new('RGB', (styleImg.size))
132 |
133 |
134 | hs, ws = styleImg.size
135 | newhs, newws = calculate_size(hs,ws,self.fineSize)
136 |
137 | transform = transforms.Compose([
138 | transforms.Resize((newhs, newws)),
139 | transforms.ToTensor()])
140 | # Turning segmentation images into masks
141 | styleSegImg = transform(styleSegImg)
142 | styleImgArbi = transform(styleImg)
143 |
144 | hc, wc = contentImg.size
145 | newhc, newwc = calculate_size(hc,wc,self.fineSize)
146 |
147 | transform = transforms.Compose([
148 | transforms.Resize((newhc, newwc)),
149 | transforms.ToTensor()])
150 | contentSegImg = transform(contentSegImg)
151 | contentImgArbi = transform(contentImg)
152 |
153 | content_masks = ExtractMask(contentSegImg)
154 | style_masks = ExtractMask(styleSegImg)
155 |
156 | ImgW = whiten(contentImgArbi.view(3,-1).double())
157 | ImgW = ImgW.view(contentImgArbi.size()).float()
158 |
159 | return contentImgArbi.squeeze(0),styleImgArbi.squeeze(0),ImgW,content_masks,style_masks,self.image_list[index]
160 |
161 | def __len__(self):
162 | return len(self.image_list)
163 |
--------------------------------------------------------------------------------
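A quick sanity check for the color-mask extraction above (not part of the repository): a solid green segmentation should light up only the 'green' entry among the nine returned masks.

    import torch
    from libs.LoaderPhotoReal import ExtractMask

    seg = torch.zeros(3, 8, 8)
    seg[1] = 1.0                            # pure green segmentation image
    masks = ExtractMask(seg)                # color order as in color_codes
    print([float(m.sum()) for m in masks])  # nonzero only at index 1 ('green')

--------------------------------------------------------------------------------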
/libs/Matrix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class CNN(nn.Module):
5 | def __init__(self,layer,matrixSize=32):
6 | super(CNN,self).__init__()
7 | if(layer == 'r31'):
8 | # 256x64x64
9 | self.convs = nn.Sequential(nn.Conv2d(256,128,3,1,1),
10 | nn.ReLU(inplace=True),
11 | nn.Conv2d(128,64,3,1,1),
12 | nn.ReLU(inplace=True),
13 | nn.Conv2d(64,matrixSize,3,1,1))
14 | elif(layer == 'r41'):
15 | # 512x32x32
16 | self.convs = nn.Sequential(nn.Conv2d(512,256,3,1,1),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(256,128,3,1,1),
19 | nn.ReLU(inplace=True),
20 | nn.Conv2d(128,matrixSize,3,1,1))
21 |
22 | # 32x8x8
23 | self.fc = nn.Linear(matrixSize*matrixSize,matrixSize*matrixSize)
24 | #self.fc = nn.Linear(32*64,256*256)
25 |
26 | def forward(self,x):
27 | out = self.convs(x)
28 | # 32x8x8
29 | b,c,h,w = out.size()
30 | out = out.view(b,c,-1)
31 | # 32x64
32 | out = torch.bmm(out,out.transpose(1,2)).div(h*w)
33 | # 32x32
34 | out = out.view(out.size(0),-1)
35 | return self.fc(out)
36 |
37 | class MulLayer(nn.Module):
38 | def __init__(self,layer,matrixSize=32):
39 | super(MulLayer,self).__init__()
40 | self.snet = CNN(layer,matrixSize)
41 | self.cnet = CNN(layer,matrixSize)
42 | self.matrixSize = matrixSize
43 |
44 | if(layer == 'r41'):
45 | self.compress = nn.Conv2d(512,matrixSize,1,1,0)
46 | self.unzip = nn.Conv2d(matrixSize,512,1,1,0)
47 | elif(layer == 'r31'):
48 | self.compress = nn.Conv2d(256,matrixSize,1,1,0)
49 | self.unzip = nn.Conv2d(matrixSize,256,1,1,0)
50 | self.transmatrix = None
51 |
52 | def forward(self,cF,sF,trans=True):
53 | cFBK = cF.clone()
54 | cb,cc,ch,cw = cF.size()
55 | cFF = cF.view(cb,cc,-1)
56 | cMean = torch.mean(cFF,dim=2,keepdim=True)
57 | cMean = cMean.unsqueeze(3)
58 | cMean = cMean.expand_as(cF)
59 | cF = cF - cMean
60 |
61 | sb,sc,sh,sw = sF.size()
62 | sFF = sF.view(sb,sc,-1)
63 | sMean = torch.mean(sFF,dim=2,keepdim=True)
64 | sMean = sMean.unsqueeze(3)
65 | sMeanC = sMean.expand_as(cF)
66 | sMeanS = sMean.expand_as(sF)
67 | sF = sF - sMeanS
68 |
69 |
70 | compress_content = self.compress(cF)
71 | b,c,h,w = compress_content.size()
72 | compress_content = compress_content.view(b,c,-1)
73 |
74 | if(trans):
75 | cMatrix = self.cnet(cF)
76 | sMatrix = self.snet(sF)
77 |
78 | sMatrix = sMatrix.view(sMatrix.size(0),self.matrixSize,self.matrixSize)
79 | cMatrix = cMatrix.view(cMatrix.size(0),self.matrixSize,self.matrixSize)
80 | transmatrix = torch.bmm(sMatrix,cMatrix)
81 | transfeature = torch.bmm(transmatrix,compress_content).view(b,c,h,w)
82 |             out = self.unzip(transfeature)
83 | out = out + sMeanC
84 | return out, transmatrix
85 | else:
86 | out = self.unzip(compress_content.view(b,c,h,w))
87 | out = out + cMean
88 | return out
89 |
--------------------------------------------------------------------------------
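A minimal usage sketch (not part of the repository): cF and sF stand for relu3_1 activations of a content and a style image. The layer predicts a 32x32 transformation matrix and applies it in the compressed feature space:

    import torch
    from libs.Matrix import MulLayer

    matrix = MulLayer('r31')
    cF = torch.randn(1, 256, 64, 64)   # content features at relu3_1
    sF = torch.randn(1, 256, 64, 64)   # style features at relu3_1
    out, T = matrix(cF, sF)
    print(out.shape, T.shape)  # torch.Size([1, 256, 64, 64]) torch.Size([1, 32, 32])

--------------------------------------------------------------------------------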
/libs/MatrixTest.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import cv2
6 | from torch.autograd import Variable
7 | import torchvision.utils as vutils
8 |
9 |
10 | class CNN(nn.Module):
11 | def __init__(self,layer,matrixSize=32):
12 | super(CNN,self).__init__()
13 | # 256x64x64
14 | if(layer == 'r31'):
15 | self.convs = nn.Sequential(nn.Conv2d(256,128,3,1,1),
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(128,64,3,1,1),
18 | nn.ReLU(inplace=True),
19 | nn.Conv2d(64,matrixSize,3,1,1))
20 | elif(layer == 'r41'):
21 | # 512x32x32
22 | self.convs = nn.Sequential(nn.Conv2d(512,256,3,1,1),
23 | nn.ReLU(inplace=True),
24 | nn.Conv2d(256,128,3,1,1),
25 | nn.ReLU(inplace=True),
26 | nn.Conv2d(128,matrixSize,3,1,1))
27 | self.fc = nn.Linear(32*32,32*32)
28 |
29 | def forward(self,x,masks,style=False):
30 | color_code_number = 9
31 | xb,xc,xh,xw = x.size()
32 | x = x.view(xc,-1)
33 | feature_sub_mean = x.clone()
34 | for i in range(color_code_number):
35 | mask = masks[i].clone().squeeze(0)
36 | mask = cv2.resize(mask.numpy(),(xw,xh),interpolation=cv2.INTER_NEAREST)
37 | mask = torch.FloatTensor(mask)
38 | mask = mask.long()
39 | if(torch.sum(mask) >= 10):
40 | mask = mask.view(-1)
41 |
42 | # dilation here
43 | """
44 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(5,5))
45 | mask = mask.cpu().numpy()
46 | mask = cv2.dilate(mask.astype(np.float32), kernel)
47 | mask = torch.from_numpy(mask)
48 | mask = mask.squeeze()
49 | """
50 |
51 | fgmask = (mask>0).nonzero().squeeze(1)
52 | fgmask = fgmask.cuda()
53 | selectFeature = torch.index_select(x,1,fgmask) # 32x96
54 | # subtract mean
55 | f_mean = torch.mean(selectFeature,1)
56 | f_mean = f_mean.unsqueeze(1).expand_as(selectFeature)
57 | selectFeature = selectFeature - f_mean
58 | feature_sub_mean.index_copy_(1,fgmask,selectFeature)
59 |
60 | feature = self.convs(feature_sub_mean.view(xb,xc,xh,xw))
61 | # 32x16x16
62 | b,c,h,w = feature.size()
63 | transMatrices = {}
64 | feature = feature.view(c,-1)
65 |
66 | for i in range(color_code_number):
67 | mask = masks[i].clone().squeeze(0)
68 | mask = cv2.resize(mask.numpy(),(w,h),interpolation=cv2.INTER_NEAREST)
69 | mask = torch.FloatTensor(mask)
70 | mask = mask.long()
71 | if(torch.sum(mask) >= 10):
72 | mask = mask.view(-1)
73 | fgmask = Variable((mask==1).nonzero().squeeze(1))
74 | fgmask = fgmask.cuda()
75 | selectFeature = torch.index_select(feature,1,fgmask) # 32x96
76 | tc,tN = selectFeature.size()
77 |
78 | covMatrix = torch.mm(selectFeature,selectFeature.transpose(0,1)).div(tN)
79 | transmatrix = self.fc(covMatrix.view(-1))
80 | transMatrices[i] = transmatrix
81 | return transMatrices,feature_sub_mean
82 |
83 | class MulLayer(nn.Module):
84 | def __init__(self,layer,matrixSize=32):
85 | super(MulLayer,self).__init__()
86 | self.snet = CNN(layer)
87 | self.cnet = CNN(layer)
88 | self.matrixSize = matrixSize
89 |
90 | if(layer == 'r41'):
91 | self.compress = nn.Conv2d(512,matrixSize,1,1,0)
92 | self.unzip = nn.Conv2d(matrixSize,512,1,1,0)
93 | elif(layer == 'r31'):
94 | self.compress = nn.Conv2d(256,matrixSize,1,1,0)
95 | self.unzip = nn.Conv2d(matrixSize,256,1,1,0)
96 |
97 | def forward(self,cF,sF,cmasks,smasks):
98 |
99 | sb,sc,sh,sw = sF.size()
100 |
101 | sMatrices,sF_sub_mean = self.snet(sF,smasks,style=True)
102 | cMatrices,cF_sub_mean = self.cnet(cF,cmasks,style=False)
103 |
104 | compress_content = self.compress(cF_sub_mean.view(cF.size()))
105 | cb,cc,ch,cw = compress_content.size()
106 | compress_content = compress_content.view(cc,-1)
107 | transfeature = compress_content.clone()
108 | color_code_number = 9
109 | finalSMean = Variable(torch.zeros(cF.size()).cuda(0))
110 | finalSMean = finalSMean.view(sc,-1)
111 | for i in range(color_code_number):
112 | cmask = cmasks[i].clone().squeeze(0)
113 | smask = smasks[i].clone().squeeze(0)
114 |
115 | cmask = cv2.resize(cmask.numpy(),(cw,ch),interpolation=cv2.INTER_NEAREST)
116 | cmask = torch.FloatTensor(cmask)
117 | cmask = cmask.long()
118 | smask = cv2.resize(smask.numpy(),(sw,sh),interpolation=cv2.INTER_NEAREST)
119 | smask = torch.FloatTensor(smask)
120 | smask = smask.long()
121 | if(torch.sum(cmask) >= 10 and torch.sum(smask) >= 10
122 | and (i in sMatrices) and (i in cMatrices)):
123 | cmask = cmask.view(-1)
124 | fgcmask = Variable((cmask==1).nonzero().squeeze(1))
125 | fgcmask = fgcmask.cuda()
126 |
127 | smask = smask.view(-1)
128 | fgsmask = Variable((smask==1).nonzero().squeeze(1))
129 | fgsmask = fgsmask.cuda()
130 |
131 | sFF = sF.view(sc,-1)
132 | sFF_select = torch.index_select(sFF,1,fgsmask)
133 | sMean = torch.mean(sFF_select,dim=1,keepdim=True)
134 | sMean = sMean.view(1,sc,1,1)
135 | sMean = sMean.expand_as(cF)
136 |
137 | sMatrix = sMatrices[i]
138 | cMatrix = cMatrices[i]
139 |
140 | sMatrix = sMatrix.view(self.matrixSize,self.matrixSize)
141 | cMatrix = cMatrix.view(self.matrixSize,self.matrixSize)
142 |
143 | transmatrix = torch.mm(sMatrix,cMatrix) # (C*C)
144 |
145 | compress_content_select = torch.index_select(compress_content,1,fgcmask)
146 |
147 | transfeatureFG = torch.mm(transmatrix,compress_content_select)
148 | transfeature.index_copy_(1,fgcmask,transfeatureFG)
149 |
150 | sMean = sMean.contiguous()
151 | sMean_select = torch.index_select(sMean.view(sc,-1),1,fgcmask)
152 | finalSMean.index_copy_(1,fgcmask,sMean_select)
153 | out = self.unzip(transfeature.view(cb,cc,ch,cw))
154 | return out + finalSMean.view(out.size())
155 |
--------------------------------------------------------------------------------
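A minimal usage sketch (not part of the repository; a GPU is required because the module moves mask indices to CUDA internally). The nine masks per image come from libs/LoaderPhotoReal.ExtractMask, one per segmentation color; here a single all-ones mask treats the whole image as one region:

    import torch
    from libs.MatrixTest import MulLayer

    matrix = MulLayer('r31').cuda()
    cF = torch.randn(1, 256, 64, 64).cuda()
    sF = torch.randn(1, 256, 64, 64).cuda()
    cmasks = [torch.ones(1, 64, 64)] + [torch.zeros(1, 64, 64)] * 8
    smasks = [torch.ones(1, 64, 64)] + [torch.zeros(1, 64, 64)] * 8
    out = matrix(cF, sF, cmasks, smasks)   # (1, 256, 64, 64), styled per region

--------------------------------------------------------------------------------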
/libs/SPN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models import vgg16
4 | from torch.autograd import Variable
5 | from collections import OrderedDict
6 | import torch.nn.functional as F
7 | import sys
8 | sys.path.append('../')
9 | from libs.pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind
10 |
11 | class spn_block(nn.Module):
12 | def __init__(self, horizontal, reverse):
13 | super(spn_block, self).__init__()
14 | self.propagator = GateRecurrent2dnoind(horizontal,reverse)
15 |
16 | def forward(self,x,G1,G2,G3):
17 | sum_abs = G1.abs() + G2.abs() + G3.abs()
18 | sum_abs.data[sum_abs.data == 0] = 1e-6
19 | mask_need_norm = sum_abs.ge(1)
20 | mask_need_norm = mask_need_norm.float()
21 | G1_norm = torch.div(G1, sum_abs)
22 | G2_norm = torch.div(G2, sum_abs)
23 | G3_norm = torch.div(G3, sum_abs)
24 |
25 | G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm
26 | G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
27 | G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm
28 |
29 | return self.propagator(x,G1,G2,G3)
30 |
31 | class VGG(nn.Module):
32 | def __init__(self,nf):
33 | super(VGG,self).__init__()
34 | self.conv1 = nn.Conv2d(3,nf,3,padding = 1)
35 | # 256 x 256
36 | self.pool1 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1)
37 | self.conv2 = nn.Conv2d(nf,nf*2,3,padding = 1)
38 | # 128 x 128
39 | self.pool2 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1)
40 | self.conv3 = nn.Conv2d(nf*2,nf*4,3,padding = 1)
41 | # 64 x 64
42 | self.pool3 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1)
43 | # 32 x 32
44 | self.conv4 = nn.Conv2d(nf*4,nf*8,3,padding = 1)
45 |
46 | def forward(self,x):
47 | output = {}
48 | output['conv1'] = self.conv1(x)
49 | x = F.relu(output['conv1'])
50 | x = self.pool1(x)
51 | output['conv2'] = self.conv2(x)
52 | # 128 x 128
53 | x = F.relu(output['conv2'])
54 | x = self.pool2(x)
55 | output['conv3'] = self.conv3(x)
56 | # 64 x 64
57 | x = F.relu(output['conv3'])
58 | output['pool3'] = self.pool3(x)
59 | # 32 x 32
60 | output['conv4'] = self.conv4(output['pool3'])
61 | return output
62 |
63 | class Decoder(nn.Module):
64 | def __init__(self,nf=32,spn=1):
65 | super(Decoder,self).__init__()
66 | # 32 x 32
67 | self.layer0 = nn.Conv2d(nf*8,nf*4,1,1,0) # edge_conv5
68 | self.layer1 = nn.Upsample(scale_factor=2,mode='bilinear')
69 | self.layer2 = nn.Sequential(nn.Conv2d(nf*4,nf*4,3,1,1), # edge_conv8
70 | nn.ELU(inplace=True))
71 | # 64 x 64
72 | self.layer3 = nn.Upsample(scale_factor=2,mode='bilinear')
73 | self.layer4 = nn.Sequential(nn.Conv2d(nf*4,nf*2,3,1,1), # edge_conv8
74 | nn.ELU(inplace=True))
75 | # 128 x 128
76 | self.layer5 = nn.Upsample(scale_factor=2,mode='bilinear')
77 | self.layer6 = nn.Sequential(nn.Conv2d(nf*2,nf,3,1,1), # edge_conv8
78 | nn.ELU(inplace=True))
79 | if(spn == 1):
80 | self.layer7 = nn.Conv2d(nf,nf*12,3,1,1)
81 | else:
82 | self.layer7 = nn.Conv2d(nf,nf*24,3,1,1)
83 | self.spn = spn
84 | # 256 x 256
85 |
86 | def forward(self,encode_feature):
87 | output = {}
88 | output['0'] = self.layer0(encode_feature['conv4'])
89 | output['1'] = self.layer1(output['0'])
90 |
91 | output['2'] = self.layer2(output['1'])
92 | output['2res'] = output['2'] + encode_feature['conv3']
93 | # 64 x 64
94 |
95 | output['3'] = self.layer3(output['2res'])
96 | output['4'] = self.layer4(output['3'])
97 | output['4res'] = output['4'] + encode_feature['conv2']
98 | # 128 x 128
99 |
100 | output['5'] = self.layer5(output['4res'])
101 | output['6'] = self.layer6(output['5'])
102 | output['6res'] = output['6'] + encode_feature['conv1']
103 |
104 | output['7'] = self.layer7(output['6res'])
105 |
106 | return output['7']
107 |
108 |
109 | class SPN(nn.Module):
110 | def __init__(self,nf=32,spn=1):
111 | super(SPN,self).__init__()
112 | # conv for mask
113 | self.mask_conv = nn.Conv2d(3,nf,3,1,1)
114 |
115 | # guidance network
116 | self.encoder = VGG(nf)
117 | self.decoder = Decoder(nf,spn)
118 |
119 | # spn blocks
120 | self.left_right = spn_block(True,False)
121 | self.right_left = spn_block(True,True)
122 | self.top_down = spn_block(False, False)
123 | self.down_top = spn_block(False,True)
124 |
125 | # post upsample
126 | self.post = nn.Conv2d(nf,3,3,1,1)
127 | self.nf = nf
128 |
129 | def forward(self,x,rgb):
130 | # feature for mask
131 | X = self.mask_conv(x)
132 |
133 | # guidance
134 | features = self.encoder(rgb)
135 | guide = self.decoder(features)
136 |
137 | G = torch.split(guide,self.nf,1)
138 | out1 = self.left_right(X,G[0],G[1],G[2])
139 | out2 = self.right_left(X,G[3],G[4],G[5])
140 | out3 = self.top_down(X,G[6],G[7],G[8])
141 | out4 = self.down_top(X,G[9],G[10],G[11])
142 |
143 | out = torch.max(out1,out2)
144 | out = torch.max(out,out3)
145 | out = torch.max(out,out4)
146 |
147 | return self.post(out)
148 |
149 | if __name__ == '__main__':
150 | spn = SPN()
151 | spn = spn.cuda()
152 | for i in range(100):
153 |         x = Variable(torch.randn(1,3,256,256)).cuda()
154 |         rgb = Variable(torch.randn(1,3,256,256)).cuda()
155 | output = spn(x,rgb)
156 | print(output.size())
157 |
--------------------------------------------------------------------------------
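The gate handling in spn_block above enforces the stability condition of spatial propagation networks: wherever the raw gates would sum past one in absolute value, they are rescaled so that

    |G1| + |G2| + |G3| <= 1

which keeps the recurrent propagation from amplifying activations; elsewhere the gates pass through unchanged.

--------------------------------------------------------------------------------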
/libs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/__init__.py
--------------------------------------------------------------------------------
/libs/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class encoder3(nn.Module):
5 | def __init__(self):
6 | super(encoder3,self).__init__()
7 | # vgg
8 | # 224 x 224
9 | self.conv1 = nn.Conv2d(3,3,1,1,0)
10 | self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
11 | # 226 x 226
12 |
13 | self.conv2 = nn.Conv2d(3,64,3,1,0)
14 | self.relu2 = nn.ReLU(inplace=True)
15 | # 224 x 224
16 |
17 | self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
18 | self.conv3 = nn.Conv2d(64,64,3,1,0)
19 | self.relu3 = nn.ReLU(inplace=True)
20 | # 224 x 224
21 |
22 | self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
23 | # 112 x 112
24 |
25 | self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
26 | self.conv4 = nn.Conv2d(64,128,3,1,0)
27 | self.relu4 = nn.ReLU(inplace=True)
28 | # 112 x 112
29 |
30 | self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
31 | self.conv5 = nn.Conv2d(128,128,3,1,0)
32 | self.relu5 = nn.ReLU(inplace=True)
33 | # 112 x 112
34 |
35 | self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
36 | # 56 x 56
37 |
38 | self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
39 | self.conv6 = nn.Conv2d(128,256,3,1,0)
40 | self.relu6 = nn.ReLU(inplace=True)
41 | # 56 x 56
42 | def forward(self,x):
43 | out = self.conv1(x)
44 | out = self.reflecPad1(out)
45 | out = self.conv2(out)
46 | out = self.relu2(out)
47 | out = self.reflecPad3(out)
48 | out = self.conv3(out)
49 | pool1 = self.relu3(out)
50 | out,pool_idx = self.maxPool(pool1)
51 | out = self.reflecPad4(out)
52 | out = self.conv4(out)
53 | out = self.relu4(out)
54 | out = self.reflecPad5(out)
55 | out = self.conv5(out)
56 | pool2 = self.relu5(out)
57 | out,pool_idx2 = self.maxPool2(pool2)
58 | out = self.reflecPad6(out)
59 | out = self.conv6(out)
60 | out = self.relu6(out)
61 | return out
62 |
63 | class decoder3(nn.Module):
64 | def __init__(self):
65 | super(decoder3,self).__init__()
66 | # decoder
67 | self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
68 | self.conv7 = nn.Conv2d(256,128,3,1,0)
69 | self.relu7 = nn.ReLU(inplace=True)
70 | # 56 x 56
71 |
72 | self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
73 | # 112 x 112
74 |
75 | self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
76 | self.conv8 = nn.Conv2d(128,128,3,1,0)
77 | self.relu8 = nn.ReLU(inplace=True)
78 | # 112 x 112
79 |
80 | self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
81 | self.conv9 = nn.Conv2d(128,64,3,1,0)
82 | self.relu9 = nn.ReLU(inplace=True)
83 |
84 | self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
85 | # 224 x 224
86 |
87 | self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
88 | self.conv10 = nn.Conv2d(64,64,3,1,0)
89 | self.relu10 = nn.ReLU(inplace=True)
90 |
91 | self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
92 | self.conv11 = nn.Conv2d(64,3,3,1,0)
93 |
94 | def forward(self,x):
95 | output = {}
96 | out = self.reflecPad7(x)
97 | out = self.conv7(out)
98 | out = self.relu7(out)
99 | out = self.unpool(out)
100 | out = self.reflecPad8(out)
101 | out = self.conv8(out)
102 | out = self.relu8(out)
103 | out = self.reflecPad9(out)
104 | out = self.conv9(out)
105 | out_relu9 = self.relu9(out)
106 | out = self.unpool2(out_relu9)
107 | out = self.reflecPad10(out)
108 | out = self.conv10(out)
109 | out = self.relu10(out)
110 | out = self.reflecPad11(out)
111 | out = self.conv11(out)
112 | return out
113 |
114 | class encoder4(nn.Module):
115 | def __init__(self):
116 | super(encoder4,self).__init__()
117 | # vgg
118 | # 224 x 224
119 | self.conv1 = nn.Conv2d(3,3,1,1,0)
120 | self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
121 | # 226 x 226
122 |
123 | self.conv2 = nn.Conv2d(3,64,3,1,0)
124 | self.relu2 = nn.ReLU(inplace=True)
125 | # 224 x 224
126 |
127 | self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
128 | self.conv3 = nn.Conv2d(64,64,3,1,0)
129 | self.relu3 = nn.ReLU(inplace=True)
130 | # 224 x 224
131 |
132 | self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
133 | # 112 x 112
134 |
135 | self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
136 | self.conv4 = nn.Conv2d(64,128,3,1,0)
137 | self.relu4 = nn.ReLU(inplace=True)
138 | # 112 x 112
139 |
140 | self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
141 | self.conv5 = nn.Conv2d(128,128,3,1,0)
142 | self.relu5 = nn.ReLU(inplace=True)
143 | # 112 x 112
144 |
145 | self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
146 | # 56 x 56
147 |
148 | self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
149 | self.conv6 = nn.Conv2d(128,256,3,1,0)
150 | self.relu6 = nn.ReLU(inplace=True)
151 | # 56 x 56
152 |
153 | self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
154 | self.conv7 = nn.Conv2d(256,256,3,1,0)
155 | self.relu7 = nn.ReLU(inplace=True)
156 | # 56 x 56
157 |
158 | self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
159 | self.conv8 = nn.Conv2d(256,256,3,1,0)
160 | self.relu8 = nn.ReLU(inplace=True)
161 | # 56 x 56
162 |
163 | self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
164 | self.conv9 = nn.Conv2d(256,256,3,1,0)
165 | self.relu9 = nn.ReLU(inplace=True)
166 | # 56 x 56
167 |
168 | self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
169 | # 28 x 28
170 |
171 | self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
172 | self.conv10 = nn.Conv2d(256,512,3,1,0)
173 | self.relu10 = nn.ReLU(inplace=True)
174 | # 28 x 28
175 | def forward(self,x,sF=None,matrix11=None,matrix21=None,matrix31=None):
176 | output = {}
177 | out = self.conv1(x)
178 | out = self.reflecPad1(out)
179 | out = self.conv2(out)
180 | output['r11'] = self.relu2(out)
181 |         out = self.reflecPad7(output['r11'])  # reuses reflecPad7; all pads here are (1,1,1,1)
182 |
183 | out = self.conv3(out)
184 | output['r12'] = self.relu3(out)
185 |
186 | output['p1'] = self.maxPool(output['r12'])
187 | out = self.reflecPad4(output['p1'])
188 | out = self.conv4(out)
189 | output['r21'] = self.relu4(out)
190 | out = self.reflecPad7(output['r21'])
191 |
192 | out = self.conv5(out)
193 | output['r22'] = self.relu5(out)
194 |
195 | output['p2'] = self.maxPool2(output['r22'])
196 | out = self.reflecPad6(output['p2'])
197 | out = self.conv6(out)
198 | output['r31'] = self.relu6(out)
199 | if(matrix31 is not None):
200 | feature3,transmatrix3 = matrix31(output['r31'],sF['r31'])
201 | out = self.reflecPad7(feature3)
202 | else:
203 | out = self.reflecPad7(output['r31'])
204 | out = self.conv7(out)
205 | output['r32'] = self.relu7(out)
206 |
207 | out = self.reflecPad8(output['r32'])
208 | out = self.conv8(out)
209 | output['r33'] = self.relu8(out)
210 |
211 | out = self.reflecPad9(output['r33'])
212 | out = self.conv9(out)
213 | output['r34'] = self.relu9(out)
214 |
215 | output['p3'] = self.maxPool3(output['r34'])
216 | out = self.reflecPad10(output['p3'])
217 | out = self.conv10(out)
218 | output['r41'] = self.relu10(out)
219 |
220 | return output
221 |
222 | class decoder4(nn.Module):
223 | def __init__(self):
224 | super(decoder4,self).__init__()
225 | # decoder
226 | self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
227 | self.conv11 = nn.Conv2d(512,256,3,1,0)
228 | self.relu11 = nn.ReLU(inplace=True)
229 | # 28 x 28
230 |
231 | self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
232 | # 56 x 56
233 |
234 | self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
235 | self.conv12 = nn.Conv2d(256,256,3,1,0)
236 | self.relu12 = nn.ReLU(inplace=True)
237 | # 56 x 56
238 |
239 | self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
240 | self.conv13 = nn.Conv2d(256,256,3,1,0)
241 | self.relu13 = nn.ReLU(inplace=True)
242 | # 56 x 56
243 |
244 | self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
245 | self.conv14 = nn.Conv2d(256,256,3,1,0)
246 | self.relu14 = nn.ReLU(inplace=True)
247 | # 56 x 56
248 |
249 | self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
250 | self.conv15 = nn.Conv2d(256,128,3,1,0)
251 | self.relu15 = nn.ReLU(inplace=True)
252 | # 56 x 56
253 |
254 | self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
255 | # 112 x 112
256 |
257 | self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
258 | self.conv16 = nn.Conv2d(128,128,3,1,0)
259 | self.relu16 = nn.ReLU(inplace=True)
260 | # 112 x 112
261 |
262 | self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
263 | self.conv17 = nn.Conv2d(128,64,3,1,0)
264 | self.relu17 = nn.ReLU(inplace=True)
265 | # 112 x 112
266 |
267 | self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
268 | # 224 x 224
269 |
270 | self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
271 | self.conv18 = nn.Conv2d(64,64,3,1,0)
272 | self.relu18 = nn.ReLU(inplace=True)
273 | # 224 x 224
274 |
275 | self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
276 | self.conv19 = nn.Conv2d(64,3,3,1,0)
277 |
278 | def forward(self,x):
279 | # decoder
280 | out = self.reflecPad11(x)
281 | out = self.conv11(out)
282 | out = self.relu11(out)
283 | out = self.unpool(out)
284 | out = self.reflecPad12(out)
285 | out = self.conv12(out)
286 |
287 | out = self.relu12(out)
288 | out = self.reflecPad13(out)
289 | out = self.conv13(out)
290 | out = self.relu13(out)
291 | out = self.reflecPad14(out)
292 | out = self.conv14(out)
293 | out = self.relu14(out)
294 | out = self.reflecPad15(out)
295 | out = self.conv15(out)
296 | out = self.relu15(out)
297 | out = self.unpool2(out)
298 | out = self.reflecPad16(out)
299 | out = self.conv16(out)
300 | out = self.relu16(out)
301 | out = self.reflecPad17(out)
302 | out = self.conv17(out)
303 | out = self.relu17(out)
304 | out = self.unpool3(out)
305 | out = self.reflecPad18(out)
306 | out = self.conv18(out)
307 | out = self.relu18(out)
308 | out = self.reflecPad19(out)
309 | out = self.conv19(out)
310 | return out
311 |
402 | class encoder5(nn.Module):
403 | def __init__(self):
404 | super(encoder5,self).__init__()
405 | # vgg
406 | # 224 x 224
407 | self.conv1 = nn.Conv2d(3,3,1,1,0)
408 | self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
409 | # 226 x 226
410 |
411 | self.conv2 = nn.Conv2d(3,64,3,1,0)
412 | self.relu2 = nn.ReLU(inplace=True)
413 | # 224 x 224
414 |
415 | self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
416 | self.conv3 = nn.Conv2d(64,64,3,1,0)
417 | self.relu3 = nn.ReLU(inplace=True)
418 | # 224 x 224
419 |
420 | self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
421 | # 112 x 112
422 |
423 | self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
424 | self.conv4 = nn.Conv2d(64,128,3,1,0)
425 | self.relu4 = nn.ReLU(inplace=True)
426 | # 112 x 112
427 |
428 | self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
429 | self.conv5 = nn.Conv2d(128,128,3,1,0)
430 | self.relu5 = nn.ReLU(inplace=True)
431 | # 112 x 112
432 |
433 | self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
434 | # 56 x 56
435 |
436 | self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
437 | self.conv6 = nn.Conv2d(128,256,3,1,0)
438 | self.relu6 = nn.ReLU(inplace=True)
439 | # 56 x 56
440 |
441 | self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
442 | self.conv7 = nn.Conv2d(256,256,3,1,0)
443 | self.relu7 = nn.ReLU(inplace=True)
444 | # 56 x 56
445 |
446 | self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
447 | self.conv8 = nn.Conv2d(256,256,3,1,0)
448 | self.relu8 = nn.ReLU(inplace=True)
449 | # 56 x 56
450 |
451 | self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
452 | self.conv9 = nn.Conv2d(256,256,3,1,0)
453 | self.relu9 = nn.ReLU(inplace=True)
454 | # 56 x 56
455 |
456 | self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
457 | # 28 x 28
458 |
459 | self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
460 | self.conv10 = nn.Conv2d(256,512,3,1,0)
461 | self.relu10 = nn.ReLU(inplace=True)
462 |
463 | self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
464 | self.conv11 = nn.Conv2d(512,512,3,1,0)
465 | self.relu11 = nn.ReLU(inplace=True)
466 |
467 | self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
468 | self.conv12 = nn.Conv2d(512,512,3,1,0)
469 | self.relu12 = nn.ReLU(inplace=True)
470 |
471 | self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
472 | self.conv13 = nn.Conv2d(512,512,3,1,0)
473 | self.relu13 = nn.ReLU(inplace=True)
474 |
475 | self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2)
476 | self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
477 | self.conv14 = nn.Conv2d(512,512,3,1,0)
478 | self.relu14 = nn.ReLU(inplace=True)
479 |
480 | def forward(self,x,sF=None,contentV256=None,styleV256=None,matrix11=None,matrix21=None,matrix31=None):
481 | output = {}
482 | out = self.conv1(x)
483 | out = self.reflecPad1(out)
484 | out = self.conv2(out)
485 | output['r11'] = self.relu2(out)
486 | out = self.reflecPad7(output['r11'])
487 |
488 | #out = self.reflecPad3(output['r11'])
489 | out = self.conv3(out)
490 | output['r12'] = self.relu3(out)
491 |
492 | output['p1'] = self.maxPool(output['r12'])
493 | out = self.reflecPad4(output['p1'])
494 | out = self.conv4(out)
495 | output['r21'] = self.relu4(out)
496 | out = self.reflecPad7(output['r21'])
497 |
498 | #out = self.reflecPad5(output['r21'])
499 | out = self.conv5(out)
500 | output['r22'] = self.relu5(out)
501 |
502 | output['p2'] = self.maxPool2(output['r22'])
503 | out = self.reflecPad6(output['p2'])
504 | out = self.conv6(out)
505 | output['r31'] = self.relu6(out)
506 | if(styleV256 is not None):
507 | feature = matrix31(output['r31'],sF['r31'],contentV256,styleV256)
508 | out = self.reflecPad7(feature)
509 | else:
510 | out = self.reflecPad7(output['r31'])
511 | out = self.conv7(out)
512 | output['r32'] = self.relu7(out)
513 |
514 | out = self.reflecPad8(output['r32'])
515 | out = self.conv8(out)
516 | output['r33'] = self.relu8(out)
517 |
518 | out = self.reflecPad9(output['r33'])
519 | out = self.conv9(out)
520 | output['r34'] = self.relu9(out)
521 |
522 | output['p3'] = self.maxPool3(output['r34'])
523 | out = self.reflecPad10(output['p3'])
524 | out = self.conv10(out)
525 | output['r41'] = self.relu10(out)
526 |
527 | out = self.reflecPad11(output['r41'])
528 | out = self.conv11(out)
529 | output['r42'] = self.relu11(out)
530 |
531 | out = self.reflecPad12(output['r42'])
532 | out = self.conv12(out)
533 | output['r43'] = self.relu12(out)
534 |
535 | out = self.reflecPad13(output['r43'])
536 | out = self.conv13(out)
537 | output['r44'] = self.relu13(out)
538 |
539 | output['p4'] = self.maxPool4(output['r44'])
540 |
541 | out = self.reflecPad14(output['p4'])
542 | out = self.conv14(out)
543 | output['r51'] = self.relu14(out)
544 | return output
545 |
546 | class decoder5(nn.Module):
547 | def __init__(self):
548 | super(decoder5,self).__init__()
549 |
550 | # decoder
551 | self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
552 | self.conv15 = nn.Conv2d(512,512,3,1,0)
553 | self.relu15 = nn.ReLU(inplace=True)
554 |
555 | self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
556 | # 28 x 28
557 |
558 | self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
559 | self.conv16 = nn.Conv2d(512,512,3,1,0)
560 | self.relu16 = nn.ReLU(inplace=True)
561 | # 28 x 28
562 |
563 | self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
564 | self.conv17 = nn.Conv2d(512,512,3,1,0)
565 | self.relu17 = nn.ReLU(inplace=True)
566 | # 28 x 28
567 |
568 | self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
569 | self.conv18 = nn.Conv2d(512,512,3,1,0)
570 | self.relu18 = nn.ReLU(inplace=True)
571 | # 28 x 28
572 |
573 | self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
574 | self.conv19 = nn.Conv2d(512,256,3,1,0)
575 | self.relu19 = nn.ReLU(inplace=True)
576 | # 28 x 28
577 |
578 | self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
579 | # 56 x 56
580 |
581 | self.reflecPad20 = nn.ReflectionPad2d((1,1,1,1))
582 | self.conv20 = nn.Conv2d(256,256,3,1,0)
583 | self.relu20 = nn.ReLU(inplace=True)
584 | # 56 x 56
585 |
586 | self.reflecPad21 = nn.ReflectionPad2d((1,1,1,1))
587 | self.conv21 = nn.Conv2d(256,256,3,1,0)
588 | self.relu21 = nn.ReLU(inplace=True)
589 |
590 | self.reflecPad22 = nn.ReflectionPad2d((1,1,1,1))
591 | self.conv22 = nn.Conv2d(256,256,3,1,0)
592 | self.relu22 = nn.ReLU(inplace=True)
593 |
594 | self.reflecPad23 = nn.ReflectionPad2d((1,1,1,1))
595 | self.conv23 = nn.Conv2d(256,128,3,1,0)
596 | self.relu23 = nn.ReLU(inplace=True)
597 |
598 | self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
599 | # 112 X 112
600 |
601 | self.reflecPad24 = nn.ReflectionPad2d((1,1,1,1))
602 | self.conv24 = nn.Conv2d(128,128,3,1,0)
603 | self.relu24 = nn.ReLU(inplace=True)
604 |
605 | self.reflecPad25 = nn.ReflectionPad2d((1,1,1,1))
606 | self.conv25 = nn.Conv2d(128,64,3,1,0)
607 | self.relu25 = nn.ReLU(inplace=True)
608 |
609 | self.unpool4 = nn.UpsamplingNearest2d(scale_factor=2)
610 |
611 | self.reflecPad26 = nn.ReflectionPad2d((1,1,1,1))
612 | self.conv26 = nn.Conv2d(64,64,3,1,0)
613 | self.relu26 = nn.ReLU(inplace=True)
614 |
615 | self.reflecPad27 = nn.ReflectionPad2d((1,1,1,1))
616 | self.conv27 = nn.Conv2d(64,3,3,1,0)
617 |
618 | def forward(self,x):
619 | # decoder
620 | out = self.reflecPad15(x)
621 | out = self.conv15(out)
622 | out = self.relu15(out)
623 | out = self.unpool(out)
624 | out = self.reflecPad16(out)
625 | out = self.conv16(out)
626 | out = self.relu16(out)
627 | out = self.reflecPad17(out)
628 | out = self.conv17(out)
629 | out = self.relu17(out)
630 | out = self.reflecPad18(out)
631 | out = self.conv18(out)
632 | out = self.relu18(out)
633 | out = self.reflecPad19(out)
634 | out = self.conv19(out)
635 | out = self.relu19(out)
636 | out = self.unpool2(out)
637 | out = self.reflecPad20(out)
638 | out = self.conv20(out)
639 | out = self.relu20(out)
640 | out = self.reflecPad21(out)
641 | out = self.conv21(out)
642 | out = self.relu21(out)
643 | out = self.reflecPad22(out)
644 | out = self.conv22(out)
645 | out = self.relu22(out)
646 | out = self.reflecPad23(out)
647 | out = self.conv23(out)
648 | out = self.relu23(out)
649 | out = self.unpool3(out)
650 | out = self.reflecPad24(out)
651 | out = self.conv24(out)
652 | out = self.relu24(out)
653 | out = self.reflecPad25(out)
654 | out = self.conv25(out)
655 | out = self.relu25(out)
656 | out = self.unpool4(out)
657 | out = self.reflecPad26(out)
658 | out = self.conv26(out)
659 | out = self.relu26(out)
660 | out = self.reflecPad27(out)
661 | out = self.conv27(out)
662 | return out
663 |
--------------------------------------------------------------------------------
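A quick shape check for the encoder/decoder pair above (a minimal sketch, not part of the repo; it assumes libs/models.py is importable from the repo root): the encoder mirrors VGG-19 up to relu5_1 and halves the resolution four times, and decoder5 undoes this with four nearest-neighbor 2x upsamplings before the final 64-to-3 convolution.

    import torch
    from libs.models import encoder5, decoder5

    enc, dec = encoder5(), decoder5()
    x = torch.randn(1, 3, 224, 224)     # dummy content image
    feats = enc(x)                      # dict of relu outputs: r11 ... r51
    print(feats['r51'].shape)           # torch.Size([1, 512, 14, 14]) -- 224 / 2**4
    out = dec(feats['r51'])             # four 2x upsamplings: 14 -> 224
    print(out.shape)                    # torch.Size([1, 3, 224, 224])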
/libs/pytorch_spn/README.md:
--------------------------------------------------------------------------------
1 | # pytorch_spn
2 | To build, install [pytorch](https://github.com/pytorch) and run:
3 |
4 | $ sh make.sh
5 |
6 | See left_right_demo.py for usage:
7 |
8 | $ mv left_right_demo.py ../
9 |
10 | $ python left_right_demo.py
11 |
12 | The original code (Caffe) and models will be released [HERE](https://github.com/Liusifei/caffe-spn.git).
13 |
--------------------------------------------------------------------------------
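For quick reference, the call pattern from left_right_demo.py condenses to a few lines (a sketch assuming the extension was built with make.sh and a CUDA device is available; the gate maps are drawn so that |G1|+|G2|+|G3| < 1, the stability condition noted in the demo):

    import torch
    from pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind

    prop = GateRecurrent2dnoind(True, False)              # left -> right
    X = torch.randn(1, 3, 10, 10).cuda()                  # signal to filter
    G1, G2, G3 = (torch.rand(1, 3, 10, 10).cuda() / 3 for _ in range(3))
    H = prop(X, G1, G2, G3)                               # propagated output

On the pre-0.4 PyTorch this FFI build targets, wrap the tensors in torch.autograd.Variable first, as the demo does.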
/libs/pytorch_spn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/__init__.py
--------------------------------------------------------------------------------
/libs/pytorch_spn/_ext/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/_ext/__init__.py
--------------------------------------------------------------------------------
/libs/pytorch_spn/_ext/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/_ext/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.ffi import _wrap_function
3 | from ._gaterecurrent2dnoind import lib as _lib, ffi as _ffi
4 |
5 | __all__ = []
6 | def _import_symbols(locals):
7 | for symbol in dir(_lib):
8 | fn = getattr(_lib, symbol)
9 | if callable(fn):
10 | locals[symbol] = _wrap_function(fn, _ffi)
11 | else:
12 | locals[symbol] = fn
13 | __all__.append(symbol)
14 |
15 | _import_symbols(locals())
16 |
--------------------------------------------------------------------------------
/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/libs/pytorch_spn/_ext/gaterecurrent2dnoind/_gaterecurrent2dnoind.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/_ext/gaterecurrent2dnoind/_gaterecurrent2dnoind.so
--------------------------------------------------------------------------------
/libs/pytorch_spn/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.ffi import create_extension
4 |
5 | this_file = os.path.dirname(__file__)
6 |
7 | sources = []
8 | headers = []
9 | defines = []
10 | with_cuda = False
11 |
12 | if torch.cuda.is_available():
13 | print('Including CUDA code.')
14 | sources += ['src/gaterecurrent2dnoind_cuda.c']
15 | headers += ['src/gaterecurrent2dnoind_cuda.h']
16 | defines += [('WITH_CUDA', None)]
17 | with_cuda = True
18 |
19 | this_file = os.path.dirname(os.path.realpath(__file__))
20 | extra_objects = ['src/cuda/gaterecurrent2dnoind_kernel.cu.o']
21 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
22 |
23 | ffi = create_extension(
24 | '_ext.gaterecurrent2dnoind',
25 | headers=headers,
26 | sources=sources,
27 | define_macros=defines,
28 | relative_to=__file__,
29 | with_cuda=with_cuda,
30 | extra_objects=extra_objects
31 | )
32 |
33 | if __name__ == '__main__':
34 | ffi.build()
35 |
--------------------------------------------------------------------------------
/libs/pytorch_spn/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/functions/__init__.py
--------------------------------------------------------------------------------
/libs/pytorch_spn/functions/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/functions/__init__.pyc
--------------------------------------------------------------------------------
/libs/pytorch_spn/functions/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/functions/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/libs/pytorch_spn/functions/__pycache__/gaterecurrent2dnoind.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/functions/__pycache__/gaterecurrent2dnoind.cpython-36.pyc
--------------------------------------------------------------------------------
/libs/pytorch_spn/functions/gaterecurrent2dnoind.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from .._ext import gaterecurrent2dnoind as gaterecurrent2d
4 |
5 | class GateRecurrent2dnoindFunction(Function):
6 | def __init__(self, horizontal_, reverse_):
7 | self.horizontal = horizontal_
8 | self.reverse = reverse_
9 |
10 | def forward(self, X, G1, G2, G3):
11 | num, channels, height, width = X.size()
12 | output = torch.zeros(num, channels, height, width)
13 |
14 | if not X.is_cuda:
15 | print("cpu version is not ready at this time")
16 | return 0
17 | else:
18 | output = output.cuda()
19 | gaterecurrent2d.gaterecurrent2dnoind_forward_cuda(self.horizontal,self.reverse, X, G1, G2, G3, output)
20 |
21 | self.X = X
22 | self.G1 = G1
23 | self.G2 = G2
24 | self.G3 = G3
25 | self.output = output
26 | self.hiddensize = X.size()
27 | return output
28 |
29 | def backward(self, grad_output):
30 | assert(self.hiddensize is not None and grad_output.is_cuda)
31 | num, channels, height, width = self.hiddensize
32 |
33 | grad_X = torch.zeros(num, channels, height, width).cuda()
34 | grad_G1 = torch.zeros(num, channels, height, width).cuda()
35 | grad_G2 = torch.zeros(num, channels, height, width).cuda()
36 | grad_G3 = torch.zeros(num, channels, height, width).cuda()
37 |
38 | gaterecurrent2d.gaterecurrent2dnoind_backward_cuda(self.horizontal, self.reverse, self.output, grad_output, self.X, self.G1, self.G2, self.G3, grad_X, grad_G1, grad_G2, grad_G3)
39 |
40 | del self.hiddensize
41 | del self.G1
42 | del self.G2
43 | del self.G3
44 | del self.output
45 | del self.X
46 |
47 | return grad_X, grad_G1, grad_G2, grad_G3
48 |
--------------------------------------------------------------------------------
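The Function above uses the legacy stateful autograd API (state stored on self, no ctx), which only works on pre-1.0 PyTorch. For orientation, a sketch of the same logic in the modern static-method style, assuming the same _ext entry points were still loadable:

    import torch
    from torch.autograd import Function
    from .._ext import gaterecurrent2dnoind as gaterecurrent2d

    class GateRecurrent2dnoindFn(Function):
        @staticmethod
        def forward(ctx, X, G1, G2, G3, horizontal, reverse):
            output = torch.zeros_like(X)                  # CUDA-only, as above
            gaterecurrent2d.gaterecurrent2dnoind_forward_cuda(
                horizontal, reverse, X, G1, G2, G3, output)
            ctx.save_for_backward(X, G1, G2, G3, output)
            ctx.horizontal, ctx.reverse = horizontal, reverse
            return output

        @staticmethod
        def backward(ctx, grad_output):
            X, G1, G2, G3, output = ctx.saved_tensors
            grad_X, grad_G1, grad_G2, grad_G3 = (torch.zeros_like(X) for _ in range(4))
            gaterecurrent2d.gaterecurrent2dnoind_backward_cuda(
                ctx.horizontal, ctx.reverse, output, grad_output.contiguous(),
                X, G1, G2, G3, grad_X, grad_G1, grad_G2, grad_G3)
            return grad_X, grad_G1, grad_G2, grad_G3, None, None  # no grads for the flags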
/libs/pytorch_spn/functions/gaterecurrent2dnoind.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/functions/gaterecurrent2dnoind.pyc
--------------------------------------------------------------------------------
/libs/pytorch_spn/left_right_demo.py:
--------------------------------------------------------------------------------
1 | """
2 | An example of left->right propagation
3 |
4 | Other direction settings:
5 | left->right: Propagator = GateRecurrent2dnoind(True,False)
6 | right->left: Propagator = GateRecurrent2dnoind(True,True)
7 | top->bottom: Propagator = GateRecurrent2dnoind(False,False)
8 | bottom->top: Propagator = GateRecurrent2dnoind(False,True)
9 |
10 | X: any signal/feature map to be filtered
11 | G1~G3: three coefficient maps (e.g., left-top, left-center, left-bottom)
12 |
13 | Note:
14 | 1. G1~G3 constitute the affinity; they can be a bunch of output maps coming from any CNN, with the input of any useful known information (e.g., RGB images)
15 | 2. for any pixel i, |G1(i)| + |G2(i)| + |G3(i)| <= 1 is a sufficient condition for model stability (see paper)
16 | """
17 | import torch
18 | from torch.autograd import Variable
19 | from pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind
20 |
21 | Propagator = GateRecurrent2dnoind(True,False)
22 |
23 | X = Variable(torch.randn(1,3,10,10))
24 | G1 = Variable(torch.randn(1,3,10,10))
25 | G2 = Variable(torch.randn(1,3,10,10))
26 | G3 = Variable(torch.randn(1,3,10,10))
27 |
28 | sum_abs = G1.abs() + G2.abs() + G3.abs()
29 | mask_need_norm = sum_abs.ge(1)
30 | mask_need_norm = mask_need_norm.float()
31 | G1_norm = torch.div(G1, sum_abs)
32 | G2_norm = torch.div(G2, sum_abs)
33 | G3_norm = torch.div(G3, sum_abs)
34 |
35 | G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm
36 | G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
37 | G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm
38 |
39 | X = X.cuda()
40 | G1 = G1.cuda()
41 | G2 = G2.cuda()
42 | G3 = G3.cuda()
43 |
44 | output = Propagator.forward(X,G1,G2,G3)
45 | print(X)
46 | print(output)
47 |
--------------------------------------------------------------------------------
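Written out, the masked renormalization in the demo enforces the docstring's stability condition per pixel i and gate k:

    \tilde{G}_k(i) = \frac{G_k(i)}{\max\left(1,\ |G_1(i)| + |G_2(i)| + |G_3(i)|\right)}

i.e., gates whose absolute sum is already below 1 pass through unchanged (the mask is 0 there), and the rest are rescaled so the sum of magnitudes becomes exactly 1.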
/libs/pytorch_spn/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CUDA_PATH=/usr/local/cuda/
4 |
5 | cd src/cuda/
6 | echo "Compiling gaterecurrent2dnoind layer kernels by nvcc..."
7 | nvcc -c -o gaterecurrent2dnoind_kernel.cu.o gaterecurrent2dnoind_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
8 | cd ../../
9 | python build.py
10 |
--------------------------------------------------------------------------------
/libs/pytorch_spn/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/libs/pytorch_spn/modules/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/modules/__init__.pyc
--------------------------------------------------------------------------------
/libs/pytorch_spn/modules/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/modules/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/libs/pytorch_spn/modules/__pycache__/gaterecurrent2dnoind.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/modules/__pycache__/gaterecurrent2dnoind.cpython-36.pyc
--------------------------------------------------------------------------------
/libs/pytorch_spn/modules/gaterecurrent2dnoind.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from ..functions.gaterecurrent2dnoind import GateRecurrent2dnoindFunction
3 |
4 | class GateRecurrent2dnoind(nn.Module):
5 | """docstring for ."""
6 | def __init__(self, horizontal_, reverse_):
7 | super(GateRecurrent2dnoind, self).__init__()
8 | self.horizontal = horizontal_
9 | self.reverse = reverse_
10 |
11 | def forward(self, X, G1, G2, G3):
12 | return GateRecurrent2dnoindFunction(self.horizontal, self.reverse)(X, G1, G2, G3)
13 |
--------------------------------------------------------------------------------
/libs/pytorch_spn/modules/gaterecurrent2dnoind.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/modules/gaterecurrent2dnoind.pyc
--------------------------------------------------------------------------------
/libs/pytorch_spn/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/src/.DS_Store
--------------------------------------------------------------------------------
/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu:
--------------------------------------------------------------------------------
1 | #ifdef __cplusplus
2 | extern "C" {
3 | #endif
4 |
5 | #include <stdio.h>
6 | #include <math.h>
7 | #include <float.h>
8 | #include "gaterecurrent2dnoind_kernel.h"
9 |
10 | #define CUDA_1D_KERNEL_LOOP(i, n) \
11 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
12 | i += blockDim.x * gridDim.x)
13 |
14 | __device__ void get_gate_idx_sf(int h1, int w1, int h2, int w2, int * out, int horizontal, int reverse)
15 | {
16 | if(horizontal && ! reverse) // left -> right
17 | {
18 | if(w1>w2)
19 | {
20 | out[0]=h1;
21 | out[1]=w1;
22 | }
23 | else
24 | {
25 | out[0]=h2;
26 | out[1]=w2;
27 | }
28 | }
29 | if(horizontal && reverse) // right -> left
30 | {
31 | if(w1<w2)
32 | {
33 | out[0]=h1;
34 | out[1]=w1;
35 | }
36 | else
37 | {
38 | out[0]=h2;
39 | out[1]=w2;
40 | }
41 | }
42 | if(!horizontal && !reverse) // top -> bottom
43 | {
44 | if(h1>h2)
45 | {
46 | out[0]=h1;
47 | out[1]=w1;
48 | }
49 | else
50 | {
51 | out[0]=h2;
52 | out[1]=w2;
53 | }
54 | }
55 | if(!horizontal && reverse) // bottom -> top
56 | {
57 | if(h1<h2)
58 | {
59 | out[0]=h1;
60 | out[1]=w1;
61 | }
62 | else
63 | {
64 | out[0]=h2;
65 | out[1]=w2;
66 | }
67 | }
68 | }
69 |
70 |
71 | __device__ float get_data_sf(float * data, int num, int channels,int height, int width,int n,int c,int h,int w)
72 | {
73 | if(h<0 || h >=height)
74 | return 0;
75 | if(w<0 || w >= width)
76 | return 0;
77 |
78 | return data[n*channels*height*width + c * height*width + h * width + w];
79 | }
80 |
81 | __device__ void set_data_sf(float * data, int num, int channels,int height, int width,int n,int c,int h,int w, float v)
82 | {
83 | if(h<0 || h >=height)
84 | return ;
85 | if(w<0 || w >= width)
86 | return ;
87 |
88 | data[n*channels*height*width + c * height*width + h * width + w]=v;
89 | }
90 |
91 | __device__ float get_gate_sf(float * data, int num, int channels,int height, int width,int n,int c,int h1,int w1,int h2,int w2,int horizontal,int reverse)
92 | {
93 | if(h1<0 || h1 >=height)
94 | return 0;
95 | if(w1<0 || w1 >= width)
96 | return 0;
97 | if(h2<0 || h2 >=height)
98 | return 0;
99 | if(w2<0 || w2 >= width)
100 | return 0;
101 | int idx[2];
102 |
103 | get_gate_idx_sf(h1,w1,h2,w2, idx,horizontal, reverse);
104 |
105 | int h = idx[0];
106 | int w = idx[1];
107 |
108 | return data[n*channels*height*width + c * height*width + h * width + w];
109 | }
110 |
111 | __device__ void set_gate_sf(float * data, int num, int channels,int height, int width,int n,int c,int h1,int w1,int h2,int w2,int horizontal,int reverse, float v)
112 | {
113 | if(h1<0 || h1 >=height)
114 | return ;
115 | if(w1<0 || w1 >= width)
116 | return ;
117 | if(h2<0 || h2 >=height)
118 | return ;
119 | if(w2<0 || w2 >= width)
120 | return ;
121 | int idx[2];
122 |
123 | get_gate_idx_sf(h1,w1,h2,w2, idx,horizontal, reverse);
124 |
125 | int h = idx[0];
126 | int w = idx[1];
127 |
128 | data[n*channels*height*width + c * height*width + h * width + w]=v;
129 | }
130 |
131 | // we do not use set_gate_add_sf(...) in the caffe implementation
132 | // avoid using atomicAdd
133 |
134 | __global__ void forward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, int horizontal, int reverse) {
135 | CUDA_1D_KERNEL_LOOP(index, count) {
136 |
137 | int hc_count = height * channels;
138 |
139 | int n,c,h,w;
140 | int temp=index;
141 | w = T;
142 | n = temp / hc_count;
143 | temp = temp % hc_count;
144 | c = temp / height;
145 | temp = temp % height;
146 | h = temp;
147 |
148 |
149 | float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
150 |
151 | float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
152 | float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1);
153 | float h1_minus1 = g_data_1 * h_minus1_data_1;
154 |
155 | float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w-1,horizontal,reverse);
156 | float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h,w-1);
157 | float h2_minus1 = g_data_2 * h_minus1_data_2;
158 |
159 | float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
160 | float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1);
161 | float h3_minus1 = g_data_3 * h_minus1_data_3;
162 |
163 | float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
164 | float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
165 |
166 | float h_data = x_hype + h_hype;
167 |
168 | set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
169 |
170 | }
171 | }
172 |
173 | __global__ void forward_one_col_right_left( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
174 | CUDA_1D_KERNEL_LOOP(index, count) {
175 |
176 | int hc_count = height * channels;
177 | int n,c,h,w;
178 | int temp=index;
179 | w = T;
180 | n = temp / hc_count;
181 | temp = temp % hc_count;
182 | c = temp / height;
183 | temp = temp % height;
184 | h = temp;
185 |
186 | float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
187 |
188 | float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
189 | float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1);
190 | float h1_minus1 = g_data_1 * h_minus1_data_1;
191 |
192 | float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w+1,horizontal,reverse);
193 | float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h,w+1);
194 | float h2_minus1 = g_data_2 * h_minus1_data_2;
195 |
196 | float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
197 | float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1);
198 | float h3_minus1 = g_data_3 * h_minus1_data_3;
199 |
200 | float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
201 | float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
202 |
203 | float h_data = x_hype + h_hype;
204 |
205 | set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
206 |
207 | }
208 | }
209 |
210 | __global__ void forward_one_row_top_bottom( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
211 | CUDA_1D_KERNEL_LOOP(index, count) {
212 |
213 | int wc_count = width * channels;
214 |
215 | int n,c,h,w;
216 | int temp=index;
217 | h = T;
218 | n = temp / wc_count;
219 | temp = temp % wc_count;
220 | c = temp / width;
221 | temp = temp % width;
222 | w = temp;
223 |
224 |
225 | float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
226 |
227 |
228 | float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
229 | float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1);
230 | float h1_minus1 = g_data_1 * h_minus1_data_1;
231 |
232 | float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h-1,w,horizontal,reverse);
233 | float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h-1,w);
234 | float h2_minus1 = g_data_2 * h_minus1_data_2;
235 |
236 | float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
237 | float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1);
238 | float h3_minus1 = g_data_3 * h_minus1_data_3;
239 |
240 | float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
241 | float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
242 |
243 | float h_data = x_hype + h_hype;
244 |
245 | set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
246 |
247 | }
248 | }
249 |
250 | __global__ void forward_one_row_bottom_top( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
251 | CUDA_1D_KERNEL_LOOP(index, count) {
252 |
253 | int wc_count = width * channels;
254 |
255 | int n,c,h,w;
256 | int temp=index;
257 | h = T;
258 | n = temp / wc_count;
259 | temp = temp % wc_count;
260 | c = temp / width;
261 | temp = temp % width;
262 | w = temp;
263 |
264 |
265 | float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
266 |
267 |
268 | float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
269 | float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1);
270 | float h1_minus1 = g_data_1 * h_minus1_data_1;
271 |
272 |
273 | float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h+1,w,horizontal,reverse);
274 | float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h+1,w);
275 | float h2_minus1 = g_data_2 * h_minus1_data_2;
276 |
277 | float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
278 | float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1);
279 | float h3_minus1 = g_data_3 * h_minus1_data_3;
280 |
281 | float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
282 | float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
283 |
284 | float h_data = x_hype + h_hype;
285 |
286 | set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
287 |
288 | }
289 | }
290 |
291 |
292 | __global__ void backward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) {
293 | CUDA_1D_KERNEL_LOOP(index, count) {
294 |
295 | int hc_count = height * channels;
296 |
297 | int n,c,h,w;
298 | int temp=index;
299 |
300 | w = T;
301 | n = temp / hc_count;
302 | temp = temp % hc_count;
303 | c = temp / height;
304 | temp = temp % height;
305 | h = temp;
306 |
307 | float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
308 |
309 | //h(t)_diff = top(t)_diff
310 | float h_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w);
311 |
312 | //h(t)_diff += h(t+1)_diff * g(t+1) if t<T
539 | int Forward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
540 | {
541 | int count = height_ * channels_ * num_;
542 | int kThreadsPerBlock = 1024;
543 | cudaError_t err;
544 |
545 | for(int t = 0; t < width_; t++) {
546 | forward_one_col_left_right<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
547 |
548 | err = cudaGetLastError();
549 | if(cudaSuccess != err)
550 | {
551 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
552 | exit( -1 );
553 | }
554 | }
555 | return 1;
556 | }
557 |
558 | int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
559 | {
560 | int count = height_ * channels_ * num_;
561 | int kThreadsPerBlock = 1024;
562 | cudaError_t err;
563 |
564 | for(int t = width_ - 1; t >= 0; t--) {
565 | forward_one_col_right_left<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
566 |
567 | err = cudaGetLastError();
568 | if(cudaSuccess != err)
569 | {
570 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
571 | exit( -1 );
572 | }
573 | }
574 | return 1;
575 | }
576 |
577 | int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
578 | {
579 | int count = width_ * channels_ * num_;
580 | int kThreadsPerBlock = 1024;
581 | cudaError_t err;
582 |
583 | for(int t=0; t< height_; t++) {
584 | forward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
585 |
586 | err = cudaGetLastError();
587 | if(cudaSuccess != err)
588 | {
589 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
590 | exit( -1 );
591 | }
592 | }
593 | return 1;
594 | }
595 |
596 | int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
597 | {
598 | int count = width_ * channels_ * num_;
599 | int kThreadsPerBlock = 1024;
600 | cudaError_t err;
601 |
602 | for(int t = height_-1; t >= 0; t--) {
603 | forward_one_row_bottom_top<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
604 |
605 | err = cudaGetLastError();
606 | if(cudaSuccess != err)
607 | {
608 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
609 | exit( -1 );
610 | }
611 | }
612 | return 1;
613 | }
614 |
615 | int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
616 | {
617 | int count = height_ * channels_ * num_;
618 | int kThreadsPerBlock = 1024;
619 | cudaError_t err;
620 |
621 | for(int t = width_ -1; t>=0; t--)
622 | {
623 | backward_one_col_left_right<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
624 |
625 | err = cudaGetLastError();
626 | if(cudaSuccess != err)
627 | {
628 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
629 | exit( -1 );
630 | }
631 | }
632 | return 1;
633 | }
634 |
635 | int Backward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
636 | {
637 | int count = height_ * channels_ * num_;
638 | int kThreadsPerBlock = 1024;
639 | cudaError_t err;
640 |
641 | for(int t = 0; t<width_; t++)
642 | {
643 | backward_one_col_right_left<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
644 |
645 | err = cudaGetLastError();
646 | if(cudaSuccess != err)
647 | {
648 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
649 | exit( -1 );
650 | }
651 | }
652 | return 1;
653 | }
654 |
655 | int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
656 | {
657 | int count = width_ * channels_ * num_;
658 | int kThreadsPerBlock = 1024;
659 | cudaError_t err;
660 |
661 | for(int t = height_-1; t>=0; t--)
662 | {
663 | backward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
664 |
665 | err = cudaGetLastError();
666 | if(cudaSuccess != err)
667 | {
668 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
669 | exit( -1 );
670 | }
671 | }
672 | return 1;
673 | }
674 |
675 | int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
676 | {
677 | int count = width_ * channels_ * num_;
678 | int kThreadsPerBlock = 1024;
679 | cudaError_t err;
680 |
681 | for(int t = 0; t<height_; t++)
682 | {
683 | backward_one_row_bottom_top<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
684 |
685 | err = cudaGetLastError();
686 | if(cudaSuccess != err)
687 | {
688 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
689 | exit( -1 );
690 | }
691 | }
692 | return 1;
693 | }
694 |
695 | #ifdef __cplusplus
696 | }
697 | #endif
698 |
--------------------------------------------------------------------------------
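All four forward kernels above apply the same three-neighbor recurrence, one column (or row) t at a time; for the left-to-right case it reads

    h_t(i) = \Big(1 - \sum_{k=1}^{3} g_k(i,t)\Big)\, x_t(i) \;+\; g_1(i,t)\, h_{t-1}(i-1) \;+\; g_2(i,t)\, h_{t-1}(i) \;+\; g_3(i,t)\, h_{t-1}(i+1)

where i indexes pixels within the current column. This is why the host-side Forward_*/Backward_* functions launch one kernel per column or row: each step depends on the previous one, and only the pixels within a step are parallel.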
/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunshineatnoon/LinearStyleTransfer/188e8e5b2ad9dbb00bed6317eb1280ecde48c657/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o
--------------------------------------------------------------------------------
/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _GATERECURRENT2DNOIND_KERNEL
2 | #define _GATERECURRENT2DNOIND_KERNEL
3 |
4 | #ifdef __cplusplus
5 | extern "C" {
6 | #endif
7 |
8 | int Forward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
9 |
10 | int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
11 |
12 | int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
13 |
14 | int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
15 |
16 | int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
17 |
18 | int Backward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
19 |
20 | int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
21 |
22 | int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
23 |
24 | #ifdef __cplusplus
25 | }
26 | #endif
27 |
28 | #endif
29 |
--------------------------------------------------------------------------------
/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c:
--------------------------------------------------------------------------------
1 | // gaterecurrent2dnoind_cuda.c
2 | #include <THC/THC.h>
3 | #include <math.h>
4 | #include "gaterecurrent2dnoind_cuda.h"
5 | #include "cuda/gaterecurrent2dnoind_kernel.h"
6 |
7 | // typedef bool boolean;
8 |
9 | // this symbol will be resolved automatically from PyTorch libs
10 | extern THCState *state;
11 |
12 | int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output)
13 | {
14 | // Grab the input tensor to flat
15 | float * X_data = THCudaTensor_data(state, X);
16 | float * G1_data = THCudaTensor_data(state, G1);
17 | float * G2_data = THCudaTensor_data(state, G2);
18 | float * G3_data = THCudaTensor_data(state, G3);
19 | float * H_data = THCudaTensor_data(state, output);
20 |
21 | // dimensions
22 | int num_ = THCudaTensor_size(state, X, 0);
23 | int channels_ = THCudaTensor_size(state, X, 1);
24 | int height_ = THCudaTensor_size(state, X, 2);
25 | int width_ = THCudaTensor_size(state, X, 3);
26 |
27 | cudaStream_t stream = THCState_getCurrentStream(state);
28 |
29 | if(horizontal_ && !reverse_) // left to right
30 | {
31 | //const int count = height_ * channels_ * num_;
32 | Forward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
33 | }
34 | else if(horizontal_ && reverse_) // right to left
35 | {
36 | Forward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
37 | }
38 | else if(!horizontal_ && !reverse_) // top to bottom
39 | {
40 | Forward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
41 | }
42 | else
43 | {
44 | Forward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
45 | }
46 |
47 | return 1;
48 | }
49 |
50 | int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_grad, THCudaTensor * G1_grad, THCudaTensor * G2_grad, THCudaTensor * G3_grad)
51 | {
52 | //Grab the input tensor to flat
53 | float * X_data = THCudaTensor_data(state, X);
54 | float * G1_data = THCudaTensor_data(state, G1);
55 | float * G2_data = THCudaTensor_data(state, G2);
56 | float * G3_data = THCudaTensor_data(state, G3);
57 | float * H_data = THCudaTensor_data(state, top);
58 |
59 | float * H_diff = THCudaTensor_data(state, top_grad);
60 |
61 | float * X_diff = THCudaTensor_data(state, X_grad);
62 | float * G1_diff = THCudaTensor_data(state, G1_grad);
63 | float * G2_diff = THCudaTensor_data(state, G2_grad);
64 | float * G3_diff = THCudaTensor_data(state, G3_grad);
65 |
66 | // dimensions
67 | int num_ = THCudaTensor_size(state, X, 0);
68 | int channels_ = THCudaTensor_size(state, X, 1);
69 | int height_ = THCudaTensor_size(state, X, 2);
70 | int width_ = THCudaTensor_size(state, X, 3);
71 |
72 | cudaStream_t stream = THCState_getCurrentStream(state);
73 |
74 | if(horizontal_ && ! reverse_) //left to right
75 | {
76 | Backward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
77 | }
78 | else if(horizontal_ && reverse_) //right to left
79 | {
80 | Backward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
81 | }
82 | else if(!horizontal_ && !reverse_) //top to bottom
83 | {
84 | Backward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
85 | }
86 | else {
87 | Backward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
88 | }
89 |
90 | return 1;
91 | }
92 |
--------------------------------------------------------------------------------
/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h:
--------------------------------------------------------------------------------
1 |
2 | // #include
3 | // gaterecurrent2dnoind_cuda.h
4 | int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output);
5 |
6 | int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_diff, THCudaTensor * G1_diff, THCudaTensor * G2_diff, THCudaTensor * G3_diff);
7 |
--------------------------------------------------------------------------------
/libs/smooth_filter.py:
--------------------------------------------------------------------------------
1 | """
2 | Code copied from https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py
3 | """
4 | src = '''
5 | #include "/usr/local/cuda/include/math_functions.h"
6 | #define TB 256
7 | #define EPS 1e-7
8 |
9 | __device__ bool InverseMat4x4(double m_in[4][4], double inv_out[4][4]) {
10 | double m[16], inv[16];
11 | for (int i = 0; i < 4; i++) {
12 | for (int j = 0; j < 4; j++) {
13 | m[i * 4 + j] = m_in[i][j];
14 | }
15 | }
16 |
17 | inv[0] = m[5] * m[10] * m[15] -
18 | m[5] * m[11] * m[14] -
19 | m[9] * m[6] * m[15] +
20 | m[9] * m[7] * m[14] +
21 | m[13] * m[6] * m[11] -
22 | m[13] * m[7] * m[10];
23 |
24 | inv[4] = -m[4] * m[10] * m[15] +
25 | m[4] * m[11] * m[14] +
26 | m[8] * m[6] * m[15] -
27 | m[8] * m[7] * m[14] -
28 | m[12] * m[6] * m[11] +
29 | m[12] * m[7] * m[10];
30 |
31 | inv[8] = m[4] * m[9] * m[15] -
32 | m[4] * m[11] * m[13] -
33 | m[8] * m[5] * m[15] +
34 | m[8] * m[7] * m[13] +
35 | m[12] * m[5] * m[11] -
36 | m[12] * m[7] * m[9];
37 |
38 | inv[12] = -m[4] * m[9] * m[14] +
39 | m[4] * m[10] * m[13] +
40 | m[8] * m[5] * m[14] -
41 | m[8] * m[6] * m[13] -
42 | m[12] * m[5] * m[10] +
43 | m[12] * m[6] * m[9];
44 |
45 | inv[1] = -m[1] * m[10] * m[15] +
46 | m[1] * m[11] * m[14] +
47 | m[9] * m[2] * m[15] -
48 | m[9] * m[3] * m[14] -
49 | m[13] * m[2] * m[11] +
50 | m[13] * m[3] * m[10];
51 |
52 | inv[5] = m[0] * m[10] * m[15] -
53 | m[0] * m[11] * m[14] -
54 | m[8] * m[2] * m[15] +
55 | m[8] * m[3] * m[14] +
56 | m[12] * m[2] * m[11] -
57 | m[12] * m[3] * m[10];
58 |
59 | inv[9] = -m[0] * m[9] * m[15] +
60 | m[0] * m[11] * m[13] +
61 | m[8] * m[1] * m[15] -
62 | m[8] * m[3] * m[13] -
63 | m[12] * m[1] * m[11] +
64 | m[12] * m[3] * m[9];
65 |
66 | inv[13] = m[0] * m[9] * m[14] -
67 | m[0] * m[10] * m[13] -
68 | m[8] * m[1] * m[14] +
69 | m[8] * m[2] * m[13] +
70 | m[12] * m[1] * m[10] -
71 | m[12] * m[2] * m[9];
72 |
73 | inv[2] = m[1] * m[6] * m[15] -
74 | m[1] * m[7] * m[14] -
75 | m[5] * m[2] * m[15] +
76 | m[5] * m[3] * m[14] +
77 | m[13] * m[2] * m[7] -
78 | m[13] * m[3] * m[6];
79 |
80 | inv[6] = -m[0] * m[6] * m[15] +
81 | m[0] * m[7] * m[14] +
82 | m[4] * m[2] * m[15] -
83 | m[4] * m[3] * m[14] -
84 | m[12] * m[2] * m[7] +
85 | m[12] * m[3] * m[6];
86 |
87 | inv[10] = m[0] * m[5] * m[15] -
88 | m[0] * m[7] * m[13] -
89 | m[4] * m[1] * m[15] +
90 | m[4] * m[3] * m[13] +
91 | m[12] * m[1] * m[7] -
92 | m[12] * m[3] * m[5];
93 |
94 | inv[14] = -m[0] * m[5] * m[14] +
95 | m[0] * m[6] * m[13] +
96 | m[4] * m[1] * m[14] -
97 | m[4] * m[2] * m[13] -
98 | m[12] * m[1] * m[6] +
99 | m[12] * m[2] * m[5];
100 |
101 | inv[3] = -m[1] * m[6] * m[11] +
102 | m[1] * m[7] * m[10] +
103 | m[5] * m[2] * m[11] -
104 | m[5] * m[3] * m[10] -
105 | m[9] * m[2] * m[7] +
106 | m[9] * m[3] * m[6];
107 |
108 | inv[7] = m[0] * m[6] * m[11] -
109 | m[0] * m[7] * m[10] -
110 | m[4] * m[2] * m[11] +
111 | m[4] * m[3] * m[10] +
112 | m[8] * m[2] * m[7] -
113 | m[8] * m[3] * m[6];
114 |
115 | inv[11] = -m[0] * m[5] * m[11] +
116 | m[0] * m[7] * m[9] +
117 | m[4] * m[1] * m[11] -
118 | m[4] * m[3] * m[9] -
119 | m[8] * m[1] * m[7] +
120 | m[8] * m[3] * m[5];
121 |
122 | inv[15] = m[0] * m[5] * m[10] -
123 | m[0] * m[6] * m[9] -
124 | m[4] * m[1] * m[10] +
125 | m[4] * m[2] * m[9] +
126 | m[8] * m[1] * m[6] -
127 | m[8] * m[2] * m[5];
128 |
129 | double det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12];
130 |
131 | if (abs(det) < 1e-9) {
132 | return false;
133 | }
134 |
135 |
136 | det = 1.0 / det;
137 |
138 | for (int i = 0; i < 4; i++) {
139 | for (int j = 0; j < 4; j++) {
140 | inv_out[i][j] = inv[i * 4 + j] * det;
141 | }
142 | }
143 |
144 | return true;
145 | }
146 |
147 | extern "C"
148 | __global__ void best_local_affine_kernel(
149 | float *output, float *input, float *affine_model,
150 | int h, int w, float epsilon, int kernel_radius
151 | )
152 | {
153 | int size = h * w;
154 | int id = blockIdx.x * blockDim.x + threadIdx.x;
155 |
156 | if (id < size) {
157 | int x = id % w, y = id / w;
158 |
159 | double Mt_M[4][4] = {}; // 4x4
160 | double invMt_M[4][4] = {};
161 | double Mt_S[3][4] = {}; // RGB -> 1x4
162 | double A[3][4] = {};
163 | for (int i = 0; i < 4; i++)
164 | for (int j = 0; j < 4; j++) {
165 | Mt_M[i][j] = 0, invMt_M[i][j] = 0;
166 | if (i != 3) {
167 | Mt_S[i][j] = 0, A[i][j] = 0;
168 | if (i == j)
169 | Mt_M[i][j] = 1e-3;
170 | }
171 | }
172 |
173 | for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
174 | for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
175 |
176 | int xx = x + dx, yy = y + dy;
177 | int id2 = yy * w + xx;
178 |
179 | if (0 <= xx && xx < w && 0 <= yy && yy < h) {
180 |
181 | Mt_M[0][0] += input[id2 + 2*size] * input[id2 + 2*size];
182 | Mt_M[0][1] += input[id2 + 2*size] * input[id2 + size];
183 | Mt_M[0][2] += input[id2 + 2*size] * input[id2];
184 | Mt_M[0][3] += input[id2 + 2*size];
185 |
186 | Mt_M[1][0] += input[id2 + size] * input[id2 + 2*size];
187 | Mt_M[1][1] += input[id2 + size] * input[id2 + size];
188 | Mt_M[1][2] += input[id2 + size] * input[id2];
189 | Mt_M[1][3] += input[id2 + size];
190 |
191 | Mt_M[2][0] += input[id2] * input[id2 + 2*size];
192 | Mt_M[2][1] += input[id2] * input[id2 + size];
193 | Mt_M[2][2] += input[id2] * input[id2];
194 | Mt_M[2][3] += input[id2];
195 |
196 | Mt_M[3][0] += input[id2 + 2*size];
197 | Mt_M[3][1] += input[id2 + size];
198 | Mt_M[3][2] += input[id2];
199 | Mt_M[3][3] += 1;
200 |
201 | Mt_S[0][0] += input[id2 + 2*size] * output[id2 + 2*size];
202 | Mt_S[0][1] += input[id2 + size] * output[id2 + 2*size];
203 | Mt_S[0][2] += input[id2] * output[id2 + 2*size];
204 | Mt_S[0][3] += output[id2 + 2*size];
205 |
206 | Mt_S[1][0] += input[id2 + 2*size] * output[id2 + size];
207 | Mt_S[1][1] += input[id2 + size] * output[id2 + size];
208 | Mt_S[1][2] += input[id2] * output[id2 + size];
209 | Mt_S[1][3] += output[id2 + size];
210 |
211 | Mt_S[2][0] += input[id2 + 2*size] * output[id2];
212 | Mt_S[2][1] += input[id2 + size] * output[id2];
213 | Mt_S[2][2] += input[id2] * output[id2];
214 | Mt_S[2][3] += output[id2];
215 | }
216 | }
217 | }
218 |
219 | bool success = InverseMat4x4(Mt_M, invMt_M);
220 |
221 | for (int i = 0; i < 3; i++) {
222 | for (int j = 0; j < 4; j++) {
223 | for (int k = 0; k < 4; k++) {
224 | A[i][j] += invMt_M[j][k] * Mt_S[i][k];
225 | }
226 | }
227 | }
228 |
229 | for (int i = 0; i < 3; i++) {
230 | for (int j = 0; j < 4; j++) {
231 | int affine_id = i * 4 + j;
232 | affine_model[12 * id + affine_id] = A[i][j];
233 | }
234 | }
235 | }
236 | return ;
237 | }
238 |
239 | extern "C"
240 | __global__ void bilateral_smooth_kernel(
241 | float *affine_model, float *filtered_affine_model, float *guide,
242 | int h, int w, int kernel_radius, float sigma1, float sigma2
243 | )
244 | {
245 | int id = blockIdx.x * blockDim.x + threadIdx.x;
246 | int size = h * w;
247 | if (id < size) {
248 | int x = id % w;
249 | int y = id / w;
250 |
251 | double sum_affine[12] = {};
252 | double sum_weight = 0;
253 | for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
254 | for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
255 | int yy = y + dy, xx = x + dx;
256 | int id2 = yy * w + xx;
257 | if (0 <= xx && xx < w && 0 <= yy && yy < h) {
258 | float color_diff1 = guide[yy*w + xx] - guide[y*w + x];
259 | float color_diff2 = guide[yy*w + xx + size] - guide[y*w + x + size];
260 | float color_diff3 = guide[yy*w + xx + 2*size] - guide[y*w + x + 2*size];
261 | float color_diff_sqr =
262 | (color_diff1*color_diff1 + color_diff2*color_diff2 + color_diff3*color_diff3) / 3;
263 |
264 | float v1 = exp(-(dx * dx + dy * dy) / (2 * sigma1 * sigma1));
265 | float v2 = exp(-(color_diff_sqr) / (2 * sigma2 * sigma2));
266 | float weight = v1 * v2;
267 |
268 | for (int i = 0; i < 3; i++) {
269 | for (int j = 0; j < 4; j++) {
270 | int affine_id = i * 4 + j;
271 | sum_affine[affine_id] += weight * affine_model[id2*12 + affine_id];
272 | }
273 | }
274 | sum_weight += weight;
275 | }
276 | }
277 | }
278 |
279 | for (int i = 0; i < 3; i++) {
280 | for (int j = 0; j < 4; j++) {
281 | int affine_id = i * 4 + j;
282 | filtered_affine_model[id*12 + affine_id] = sum_affine[affine_id] / sum_weight;
283 | }
284 | }
285 | }
286 | return ;
287 | }
288 |
289 |
290 | extern "C"
291 | __global__ void reconstruction_best_kernel(
292 | float *input, float *filtered_affine_model, float *filtered_best_output,
293 | int h, int w
294 | )
295 | {
296 | int id = blockIdx.x * blockDim.x + threadIdx.x;
297 | int size = h * w;
298 | if (id < size) {
299 | double out1 =
300 | input[id + 2*size] * filtered_affine_model[id*12 + 0] + // A[0][0] +
301 | input[id + size] * filtered_affine_model[id*12 + 1] + // A[0][1] +
302 | input[id] * filtered_affine_model[id*12 + 2] + // A[0][2] +
303 | filtered_affine_model[id*12 + 3]; //A[0][3];
304 | double out2 =
305 | input[id + 2*size] * filtered_affine_model[id*12 + 4] + //A[1][0] +
306 | input[id + size] * filtered_affine_model[id*12 + 5] + //A[1][1] +
307 | input[id] * filtered_affine_model[id*12 + 6] + //A[1][2] +
308 | filtered_affine_model[id*12 + 7]; //A[1][3];
309 | double out3 =
310 | input[id + 2*size] * filtered_affine_model[id*12 + 8] + //A[2][0] +
311 | input[id + size] * filtered_affine_model[id*12 + 9] + //A[2][1] +
312 | input[id] * filtered_affine_model[id*12 + 10] + //A[2][2] +
313 | filtered_affine_model[id*12 + 11]; // A[2][3];
314 |
315 | filtered_best_output[id] = out1;
316 | filtered_best_output[id + size] = out2;
317 | filtered_best_output[id + 2*size] = out3;
318 | }
319 | return ;
320 | }
321 | '''
322 |
323 | import cv2
324 | import torch
325 | import numpy as np
326 | from PIL import Image
327 | from cupy.cuda import function
328 | from pynvrtc.compiler import Program
329 | from collections import namedtuple
330 |
331 |
332 | def smooth_local_affine(output_cpu, input_cpu, epsilon, patch, h, w, f_r, f_e):
333 | # program = Program(src.encode('utf-8'), 'best_local_affine_kernel.cu'.encode('utf-8'))
334 | # ptx = program.compile(['-I/usr/local/cuda/include'.encode('utf-8')])
335 | program = Program(src, 'best_local_affine_kernel.cu')
336 | ptx = program.compile(['-I/usr/local/cuda/include'])
337 | m = function.Module()
338 | m.load(bytes(ptx.encode()))
339 |
340 | _reconstruction_best_kernel = m.get_function('reconstruction_best_kernel')
341 | _bilateral_smooth_kernel = m.get_function('bilateral_smooth_kernel')
342 | _best_local_affine_kernel = m.get_function('best_local_affine_kernel')
343 | Stream = namedtuple('Stream', ['ptr'])
344 | s = Stream(ptr=torch.cuda.current_stream().cuda_stream)
345 |
346 | filter_radius = f_r
347 | sigma1 = filter_radius / 3
348 | sigma2 = f_e
349 | radius = (patch - 1) / 2
350 |
351 | filtered_best_output = torch.zeros(np.shape(input_cpu)).cuda()
352 | affine_model = torch.zeros((h * w, 12)).cuda()
353 | filtered_affine_model = torch.zeros((h * w, 12)).cuda()
354 |
355 | input_ = torch.from_numpy(input_cpu).cuda()
356 | output_ = torch.from_numpy(output_cpu).cuda()
357 | _best_local_affine_kernel(
358 | grid=(int((h * w) / 256 + 1), 1),
359 | block=(256, 1, 1),
360 | args=[output_.data_ptr(), input_.data_ptr(), affine_model.data_ptr(),
361 | np.int32(h), np.int32(w), np.float32(epsilon), np.int32(radius)], stream=s
362 | )
363 |
364 | _bilateral_smooth_kernel(
365 | grid=(int((h * w) / 256 + 1), 1),
366 | block=(256, 1, 1),
367 | args=[affine_model.data_ptr(), filtered_affine_model.data_ptr(), input_.data_ptr(), np.int32(h), np.int32(w), np.int32(f_r), np.float32(sigma1), np.float32(sigma2)], stream=s
368 | )
369 |
370 | _reconstruction_best_kernel(
371 | grid=(int((h * w) / 256 + 1), 1),
372 | block=(256, 1, 1),
373 | args=[input_.data_ptr(), filtered_affine_model.data_ptr(), filtered_best_output.data_ptr(),
374 | np.int32(h), np.int32(w)], stream=s
375 | )
376 | numpy_filtered_best_output = filtered_best_output.cpu().numpy()
377 | return numpy_filtered_best_output
378 |
379 |
380 | def smooth_filter(initImg, contentImg, f_radius=15,f_edge=1e-1):
381 | '''
382 | :param initImg: intermediate output. Either image path or PIL Image
383 | :param contentImg: content image output. Either path or PIL Image
384 | :return: stylized output image. PIL Image
385 | '''
386 | if type(initImg) == str:
387 | initImg = Image.open(initImg).convert("RGB")
388 | best_image_bgr = np.array(initImg, dtype=np.float32)
389 | bW, bH, bC = best_image_bgr.shape
390 | best_image_bgr = best_image_bgr[:, :, ::-1]
391 | best_image_bgr = best_image_bgr.transpose((2, 0, 1))
392 |
393 | if type(contentImg) == str:
394 | contentImg = Image.open(contentImg).convert("RGB")
395 | content_input = contentImg.resize((bH,bW))
396 | else:
397 | content_input = cv2.resize(contentImg,(bH,bW))
398 | content_input = np.array(content_input, dtype=np.float32)
399 | content_input = content_input[:, :, ::-1]
400 | content_input = content_input.transpose((2, 0, 1))
401 | input_ = np.ascontiguousarray(content_input, dtype=np.float32) / 255.
402 | _, H, W = np.shape(input_)
403 | output_ = np.ascontiguousarray(best_image_bgr, dtype=np.float32) / 255.
404 | best_ = smooth_local_affine(output_, input_, 1e-7, 3, H, W, f_radius, f_edge)
405 | best_ = best_.transpose(1, 2, 0)
406 | result = Image.fromarray(np.uint8(np.clip(best_ * 255., 0, 255.)))
407 | return result
408 |
--------------------------------------------------------------------------------
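For orientation, best_local_affine_kernel accumulates the normal equations of a per-pixel least-squares problem: for each output channel c it fits the affine map that best predicts the stylized output S from the content input I over the local window N(i), with a small diagonal regularizer (the 1e-3 seeding of Mt_M) keeping the 4x4 system invertible:

    A_c(i) = \left(M^\top M + \epsilon I\right)^{-1} M^\top s_c, \qquad M = \big[\, I(j)^\top \ \ 1 \,\big]_{j \in N(i)}, \quad \epsilon = 10^{-3}

bilateral_smooth_kernel then smooths the 12 affine coefficients per pixel with a joint bilateral filter guided by the content image, and reconstruction_best_kernel applies the smoothed affine model back to the input.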
/libs/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import cv2
4 | import time
5 | import torch
6 | import scipy.misc
7 | import numpy as np
8 | import scipy.sparse
9 | from PIL import Image
10 | import scipy.sparse.linalg
11 | from cv2.ximgproc import jointBilateralFilter
12 | from torch.utils.serialization import load_lua
13 | from numpy.lib.stride_tricks import as_strided
14 |
15 | def whiten(cF):
16 | cFSize = cF.size()
17 | c_mean = torch.mean(cF,1) # c x (h x w)
18 | c_mean = c_mean.unsqueeze(1).expand_as(cF)
19 | cF = cF - c_mean
20 |
21 | contentConv = torch.mm(cF,cF.t()).div(cFSize[1]-1) + torch.eye(cFSize[0]).double()
22 | c_u,c_e,c_v = torch.svd(contentConv,some=False)
23 |
24 | k_c = cFSize[0]
25 | for i in range(cFSize[0]):
26 | if c_e[i] < 0.00001:
27 | k_c = i
28 | break
29 |
30 | c_d = (c_e[0:k_c]).pow(-0.5)
31 | step1 = torch.mm(c_v[:,0:k_c],torch.diag(c_d))
32 | step2 = torch.mm(step1,(c_v[:,0:k_c].t()))
33 | whiten_cF = torch.mm(step2,cF)
34 | return whiten_cF
35 |
36 | def numpy2cv2(cont,style,prop,width,height):
37 | cont = cont.transpose((1,2,0))
38 | cont = cont[...,::-1]
39 | cont = cont * 255
40 | cont = cv2.resize(cont,(width,height))
41 | #cv2.resize(iimg,(width,height))
42 | style = style.transpose((1,2,0))
43 | style = style[...,::-1]
44 | style = style * 255
45 | style = cv2.resize(style,(width,height))
46 |
47 | prop = prop.transpose((1,2,0))
48 | prop = prop[...,::-1]
49 | prop = prop * 255
50 | prop = cv2.resize(prop,(width,height))
51 |
52 | #return np.concatenate((cont,np.concatenate((style,prop),axis=1)),axis=1)
53 | return prop,cont
54 |
55 | def makeVideo(content,style,props,outf):
56 | print('Stacking transferred frames back into a video...')
57 | layers,height,width = content[0].shape
58 | fourcc = cv2.VideoWriter_fourcc(*'MJPG')
59 | video = cv2.VideoWriter(os.path.join(outf,'transfer.avi'),fourcc,10.0,(width,height))
60 | ori_video = cv2.VideoWriter(os.path.join(outf,'content.avi'),fourcc,10.0,(width,height))
61 | for j in range(len(content)):
62 | prop,cont = numpy2cv2(content[j],style,props[j],width,height)
63 | cv2.imwrite('prop.png',prop)
64 | cv2.imwrite('content.png',cont)
65 | # TODO: this is ugly, fix this
66 | imgj = cv2.imread('prop.png')
67 | imgc = cv2.imread('content.png')
68 |
69 | video.write(imgj)
70 | ori_video.write(imgc)
71 | # RGB or BGR, yuck
72 | video.release()
73 | ori_video.release()
74 | os.remove('prop.png')
75 | os.remove('content.png')
76 | print('Transferred video saved at %s.'%outf)
77 |
78 | def print_options(opt):
79 | message = ''
80 | message += '----------------- Options ---------------\n'
81 | for k, v in sorted(vars(opt).items()):
82 | comment = ''
83 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
84 | message += '----------------- End -------------------'
85 | print(message)
86 |
87 | # save to the disk
88 | expr_dir = os.path.join(opt.outf)
89 | os.makedirs(expr_dir,exist_ok=True)
90 | file_name = os.path.join(expr_dir, 'opt.txt')
91 | with open(file_name, 'wt') as opt_file:
92 | opt_file.write(message)
93 | opt_file.write('\n')
94 |
--------------------------------------------------------------------------------
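The whiten helper above is the usual ZCA-style whitening from the WCT line of work: center the features, eigendecompose the regularized covariance (the added identity and the SVD in the code), and rescale by the inverse square root of the kept eigenvalues:

    \hat{F} = E_k \Lambda_k^{-1/2} E_k^\top (F - \mu), \qquad C = \frac{(F-\mu)(F-\mu)^\top}{N-1} + I

where N is the number of spatial positions and E_k, \Lambda_k keep only the eigenvalues above the 1e-5 cutoff in the loop.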
/real-time-demo.py:
--------------------------------------------------------------------------------
import os
import cv2
import torch
import argparse
import numpy as np
from PIL import Image
from libs.Matrix import MulLayer
import torch.backends.cudnn as cudnn
from libs.models import encoder3, encoder4
from libs.models import decoder3, decoder4
import torchvision.transforms as transforms

parser = argparse.ArgumentParser()
parser.add_argument("--vgg_dir", default='models/vgg_r31.pth',
                    help='pre-trained encoder path')
parser.add_argument("--decoder_dir", default='models/dec_r31.pth',
                    help='pre-trained decoder path')
parser.add_argument("--style", default="data/style/in2.jpg",
                    help='path to style image')
parser.add_argument("--matrixPath", default="models/r31.pth",
                    help='path to pre-trained transformation module')
parser.add_argument('--fineSize', type=int, default=256,
                    help='crop image size')
parser.add_argument("--name", default="transferred_video",
                    help="name of generated video")
parser.add_argument("--layer", default="r31",
                    help="features of which layer to transfer")
parser.add_argument("--outf", default="real_time_demo_output",
                    help="output folder")

################# PREPARATIONS #################
opt = parser.parse_args()
opt.cuda = torch.cuda.is_available()
print(opt)
os.makedirs(opt.outf, exist_ok=True)
cudnn.benchmark = True

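# Usage sketch (not part of the original script; assumes the pre-trained
# weights sit under ./models and a webcam is attached; the r41 file names
# below are hypothetical):
#
#   python real-time-demo.py --layer r31 --style data/style/in2.jpg
#   python real-time-demo.py --layer r41 --vgg_dir models/vgg_r41.pth \
#       --decoder_dir models/dec_r41.pth --matrixPath models/r41.pth
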
################# DATA #################
def loadImg(imgPath):
    img = Image.open(imgPath).convert('RGB')
    transform = transforms.Compose([
        # transforms.Scale was removed from torchvision;
        # Resize is the equivalent replacement.
        transforms.Resize(opt.fineSize),
        transforms.ToTensor()])
    return transform(img)
style = loadImg(opt.style).unsqueeze(0)

################# MODEL #################
if opt.layer == 'r31':
    matrix = MulLayer(layer='r31')
    vgg = encoder3()
    dec = decoder3()
elif opt.layer == 'r41':
    matrix = MulLayer(layer='r41')
    vgg = encoder4()
    dec = decoder4()
else:
    raise ValueError('unsupported layer: %s (expected r31 or r41)' % opt.layer)
vgg.load_state_dict(torch.load(opt.vgg_dir))
dec.load_state_dict(torch.load(opt.decoder_dir))
matrix.load_state_dict(torch.load(opt.matrixPath))
# Inference only: freeze all parameters.
for param in vgg.parameters():
    param.requires_grad = False
for param in dec.parameters():
    param.requires_grad = False
for param in matrix.parameters():
    param.requires_grad = False

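# Note (inferred from the branch in the capture loop below): encoder4
# returns a dict of intermediate features keyed by layer name, so the r41
# path indexes into it with cF['r41'], while encoder3 returns the r31
# feature tensor directly. A hypothetical way to unify both cases:
#
#   f = vgg(x)
#   f = f[opt.layer] if opt.layer == 'r41' else f
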
################# GLOBAL VARIABLE #################
content = torch.Tensor(1, 3, opt.fineSize, opt.fineSize)

################# GPU #################
if opt.cuda:
    vgg.cuda()
    dec.cuda()
    matrix.cuda()

    style = style.cuda()
    content = content.cuda()

cap = cv2.VideoCapture(0)
# Hint the camera driver towards the 512x256 resolution used below
# (frames are resized explicitly in any case).
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 512)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 256)
fourcc = cv2.VideoWriter_fourcc(*'MJPG')
out = cv2.VideoWriter(os.path.join(opt.outf, opt.name + '.avi'), fourcc, 20.0, (512, 256))

# The style features only need to be computed once; each captured frame
# then just runs through the encoder, transformation module and decoder.
with torch.no_grad():
    sF = vgg(style)

while True:
    ret, frame = cap.read()
    if not ret:  # camera read failed
        break
    frame = cv2.resize(frame, (512, 256), interpolation=cv2.INTER_CUBIC)
    frame = frame.transpose((2, 0, 1))  # H x W x C -> C x H x W
    frame = frame[::-1, :, :]           # BGR -> RGB
    frame = frame / 255.0
    frame = torch.from_numpy(frame.copy()).unsqueeze(0)
    content.data.resize_(frame.size()).copy_(frame)
    with torch.no_grad():
        cF = vgg(content)
        if opt.layer == 'r41':
            feature, transmatrix = matrix(cF[opt.layer], sF[opt.layer])
        else:
            feature, transmatrix = matrix(cF, sF)
        transfer = dec(feature)
    transfer = transfer.clamp(0, 1).squeeze(0).data.cpu().numpy()
    transfer = transfer.transpose((1, 2, 0))
    transfer = transfer[..., ::-1]  # RGB -> BGR for OpenCV
    out.write(np.uint8(transfer * 255))
    cv2.imshow('frame', transfer)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# When everything is done, release the writer and the capture.
out.release()
cap.release()
cv2.destroyAllWindows()

121 |
--------------------------------------------------------------------------------