├── DATASET.md
├── LICENSE
├── README.md
├── data
│   ├── fashionWOMENBlouses_Shirtsid0000635004_1front.png
│   ├── fashionWOMENBlouses_Shirtsid0000635004_1front_iuv.png
│   ├── fashionWOMENBlouses_Shirtsid0000635004_1front_sil.png
│   ├── fashionWOMENDressesid0000262902_1front_iuv.png
│   ├── fashionWOMENDressesid0000262902_3back.png
│   ├── fashionWOMENDressesid0000262902_3back_iuv.png
│   ├── fashionWOMENDressesid0000262902_3back_sil.png
│   ├── fashionWOMENSkirtsid0000177102_1front.png
│   ├── fashionWOMENSkirtsid0000177102_1front_iuv.png
│   └── fashionWOMENSkirtsid0000177102_1front_sil.png
├── dataset.py
├── distributed.py
├── garment_transfer.py
├── inference.py
├── model.py
├── op
│   ├── __init__.py
│   ├── conv2d_gradfix.py
│   ├── fused_act.py
│   ├── fused_bias_act.cpp
│   ├── fused_bias_act_kernel.cu
│   ├── upfirdn2d.cpp
│   ├── upfirdn2d.py
│   └── upfirdn2d_kernel.cu
├── requirements.txt
├── sphereface.py
├── test.py
├── train.py
└── util
    ├── __init__.py
    ├── complete_coor.py
    ├── coordinate_completion_model.py
    ├── dp2coor.py
    ├── generate_fashion_datasets.py
    └── pickle2png.py
/DATASET.md:
--------------------------------------------------------------------------------
1 | # Pose with Style: human reposing with pose-guided StyleGAN2
2 |
3 |
4 | ## Dataset and Downloads
5 | 1. Download images:
6 | 1. Download `img_highres.zip` from the [In-shop Clothes Retrieval Benchmark](https://drive.google.com/drive/folders/0B7EVK8r0v71pYkd5TzBiclMzR00?resourcekey=0-fsjVShvqXP2517KnwaZ0zw). You will need to follow the [download instructions](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html) to unzip the file. Unzip it into `DATASET/DeepFashion_highres/img_highres`.
7 |
8 | 2. Download the [train/test data](https://drive.google.com/drive/folders/1BX3Bxh8KG01yKWViRY0WTyDWbJHju-SL): **train.lst**, **test.lst**, and **fashion-pairs-test.csv**. Put them in `DATASET/DeepFashion_highres/tools`. Note: because not all training images had their densepose detected, we use a slightly modified training pairs file: [**fashion-pairs-train.csv**](https://drive.google.com/file/d/1Uxpz8yBJ53XPkJ3O2GFP3nnbZWJQYlbv/view?usp=sharing).
9 |
10 | 3. Split the train/test dataset using:
11 | ```
12 | python util/generate_fashion_datasets.py --dataroot DATASET/DeepFashion_highres
13 | ```
14 | This will save the train/test images in `DeepFashion_highres/train` and `DeepFashion_highres/test`.
15 |
16 | 2. Compute [densepose](https://github.com/facebookresearch/detectron2/tree/master/projects/DensePose):
17 | 1. Install detectron2 following their [installation instructions](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md).
18 |
19 | 2. Use [apply net](https://github.com/facebookresearch/detectron2/blob/master/projects/DensePose/doc/TOOL_APPLY_NET.md) from [densepose](https://github.com/facebookresearch/detectron2/tree/master/projects/DensePose) and save the train/test results to a pickle file.
20 | Make sure you download [densepose_rcnn_R_101_FPN_DL_s1x.pkl](https://github.com/facebookresearch/detectron2/blob/master/projects/DensePose/doc/DENSEPOSE_IUV.md#ModelZoo).
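For reference, a typical `apply_net.py` dump call, run from the DensePose project directory, looks like the following (treat this as a sketch; the exact config path, input specification, and flags may differ across detectron2/DensePose versions):
```
python apply_net.py dump configs/densepose_rcnn_R_101_FPN_DL_s1x.yaml densepose_rcnn_R_101_FPN_DL_s1x.pkl DATASET/DeepFashion_highres/train --output train_output.pkl -v
```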
21 |
22 | 3. Copy `util/pickle2png.py` into your detectron2 DensePose project directory. Using the DensePose environment, convert the pickle file to densepose png images and save the results in the `DATASET/densepose` directory:
23 | ```
24 | python pickle2png.py --pickle_file train_output.pkl --save_path DATASET/densepose/train
25 | ```
26 |
27 | 3. Compute the [human foreground mask](https://github.com/Engineering-Course/CIHP_PGN). Save the results in the `DATASET/silhouette` directory. Alternatively, you can download our precomputed silhouettes for the [training set](https://drive.google.com/file/d/1xXJGi5zkkTC2iIAUloylq6DlgqzAUrcw/view?usp=sharing) and [testing set](https://drive.google.com/file/d/1QdGgnBossIxsOrY8fUkJJYh1NbExIpyX/view?usp=sharing).
28 |
29 | 4. Compute UV space coordinates:
30 | 1. Compute the partial UV-space coordinates at a resolution of 512x512.
31 | 1. Download the [UV space - 2D look up map](https://drive.google.com/file/d/1JLQ5bGl7YU-BwmdSc-DySy5Ya6FQIJBy/view?usp=sharing) and save it in the `util` folder.
32 | 2. Compute partial coordinates:
33 | ```
34 | python util/dp2coor.py --image_file DATASET/DeepFashion_highres/tools/train.lst --dp_path DATASET/densepose/train --save_path DATASET/partial_coordinates/train
35 | ```
36 |
37 | 2. Complete the UV space coordinates offline, for faster training.
38 | 1. Download the pretrained coordinate completion model from [here](https://drive.google.com/file/d/1Tck_NzJ4ifT76csEShOtlRK7HpfjFhHP/view?usp=sharing).
39 | 2. Complete the partial coordinates.
40 | ```
41 | python util/complete_coor.py --dataroot DATASET/DeepFashion_highres --coordinates_path DATASET/partial_coordinates --image_file DATASET/DeepFashion_highres/tools/train.lst --phase train --save_path DATASET/complete_coordinates --pretrained_model /path/to/CCM_epoch50.pt
42 | ```
43 |
44 | 5. Download the following into `DATASET/resources` to apply the face identity loss:
45 | 1. Download the pre-computed [required transformation (T)](https://drive.google.com/file/d/1r5ODZr1ewZk95Mdsmv-mGdW6lit3MNRc/view?usp=sharing) to align and crop the face.
46 | 2. Download the [sphereface net pretrained model](https://drive.google.com/file/d/1p_cBfPZwwhWsWDXdKJ3n9VTp0_qsXZF0/view?usp=sharing).
47 |
48 | Note: we provide the DeepFashion train/test split of [StylePoseGAN](https://people.mpi-inf.mpg.de/~ksarkar/styleposegan/) [Sarkar et al. 2021]: [train pairs](https://drive.google.com/file/d/1ZaoQmUS92zHtqCWvsyunAaORlosLDVMm/view?usp=sharing), and [test pairs](https://drive.google.com/file/d/125EK9Y2QFMYMf8WV2_BoXIwLKz0uNxqa/view?usp=sharing).
49 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Badour AlBahar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pose with Style: Detail-Preserving Pose-Guided Image Synthesis with Conditional StyleGAN
2 | ### [[Paper](https://pose-with-style.github.io/asset/paper.pdf)] [[Project Website](https://pose-with-style.github.io/)] [[Output results](https://pose-with-style.github.io/results.html)]
3 |
4 | Official PyTorch implementation for **Pose with Style: Detail-Preserving Pose-Guided Image Synthesis with Conditional StyleGAN**. Please contact Badour AlBahar (badour@vt.edu) if you have any questions.
5 |
6 |
7 |
8 |
9 |
10 |
11 | ## Requirements
12 | ```
13 | conda create -n pws python=3.8
14 | conda activate pws
15 | conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
16 | pip install -r requirements.txt
17 | ```
18 | Install OpenCV using `conda install -c conda-forge opencv` or `pip install opencv-python`.
19 | If you would like to use [wandb](https://wandb.ai/site), install it using `pip install wandb`.
20 |
21 | ## Download pretrained models
22 | You can download the pretrained model [here](https://drive.google.com/file/d/1IcIJMTHA-_-_qcjRIrnSUnQf4n8GmtPd/view?usp=sharing), and the pretrained coordinate completion model [here](https://drive.google.com/file/d/1Tck_NzJ4ifT76csEShOtlRK7HpfjFhHP/view?usp=sharing).
23 |
24 | Note: we also provide the pretrained model trained on [StylePoseGAN](https://people.mpi-inf.mpg.de/~ksarkar/styleposegan/) [Sarkar et al. 2021] DeepFashion train/test split [here](https://drive.google.com/file/d/1DpOQ1Z7JEls3kCYaHMg1CpgjxbTJB6hg/view?usp=sharing). We also provide this split's pretrained coordinate completion model [here](https://drive.google.com/file/d/1oz-g0T5HLW-hVgeqYC55eI10DA-59Yp3/view?usp=sharing).
25 |
26 | ## Reposing
27 | Download the [UV space - 2D look up map](https://drive.google.com/file/d/1JLQ5bGl7YU-BwmdSc-DySy5Ya6FQIJBy/view?usp=sharing) and save it in the `util` folder.
28 |
29 | We provide sample data in the `data` directory. The output will be saved in the `data/output` directory.
30 | ```
31 | python inference.py --input_path ./data --CCM_pretrained_model path/to/CCM_epoch50.pt --pretrained_model path/to/posewithstyle.pt
32 | ```
33 |
34 | To repose your own images, put the input image (input_name+'.png'), dense pose (input_name+'_iuv.png'), and silhouette (input_name+'_sil.png'), as well as the target dense pose (target_name+'_iuv.png'), in the `data` directory.
35 | ```
36 | python inference.py --input_path ./data --input_name fashionWOMENDressesid0000262902_3back --target_name fashionWOMENDressesid0000262902_1front --CCM_pretrained_model path/to/CCM_epoch50.pt --pretrained_model path/to/posewithstyle.pt
37 | ```
38 |
39 | ## Garment transfer
40 | Download the [UV space - 2D look up map](https://drive.google.com/file/d/1JLQ5bGl7YU-BwmdSc-DySy5Ya6FQIJBy/view?usp=sharing) and the [UV space body part segmentation](https://drive.google.com/file/d/179zoQVVrEgFbkEmHwLHU1AAqPktH-guU/view?usp=sharing). Save both in the `util` folder.
41 | The UV space body part segmentation will provide a generic segmentation of the human body. Alternatively, you can specify your own mask of the region you want to transfer.
42 |
43 | We provide sample data in the `data` directory. The output will be saved in the `data/output` directory.
44 | ```
45 | python garment_transfer.py --input_path ./data --CCM_pretrained_model path/to/CCM_epoch50.pt --pretrained_model path/to/posewithstyle.pt --part upper_body
46 | ```
47 |
48 | To use your own images, put the input image (input_name+'.png'), dense pose (input_name+'_iuv.png'), and silhouette (input_name+'_sil.png'), as well as the garment source (target) image (target_name+'.png'), its dense pose (target_name+'_iuv.png'), and silhouette (target_name+'_sil.png'), in the `data` directory. You can specify the part to be transferred using `--part` as `upper_body`, `lower_body`, `full_body`, or `face`. The output, as well as the transferred part (shown in red), will be saved in the `data/output` directory.
49 | ```
50 | python garment_transfer.py --input_path ./data --input_name fashionWOMENSkirtsid0000177102_1front --target_name fashionWOMENBlouses_Shirtsid0000635004_1front --CCM_pretrained_model path/to/CCM_epoch50.pt --pretrained_model path/to/posewithstyle.pt --part upper_body
51 | ```
52 |
53 | ## DeepFashion Dataset
54 | To train or test, you must download and process the dataset. Please follow the instructions in [Dataset and Downloads](https://github.com/BadourAlBahar/pose-with-style/blob/main/DATASET.md).
55 |
56 | You should have the following downloaded in your `DATASET` folder:
57 | ```
58 | DATASET/DeepFashion_highres
59 | - train
60 | - test
61 | - tools
62 | - train.lst
63 | - test.lst
64 | - fashion-pairs-train.csv
65 | - fashion-pairs-test.csv
66 |
67 | DATASET/densepose
68 | - train
69 | - test
70 |
71 | DATASET/silhouette
72 | - train
73 | - test
74 |
75 | DATASET/partial_coordinates
76 | - train
77 | - test
78 |
79 | DATASET/complete_coordinates
80 | - train
81 | - test
82 |
83 | DATASET/resources
84 | - train_face_T.pickle
85 | - sphere20a_20171020.pth
86 | ```
87 |
88 |
89 | ## Training
90 | Step 1: First, train the reposing model by focusing on generating the foreground.
91 | We set the batch size to 1 and train for 50 epochs. This training process takes around 7 days on 8 NVIDIA 2080 Ti GPUs.
92 | ```
93 | python -m torch.distributed.launch --nproc_per_node=8 --master_port XXXX train.py --batch 1 /path/to/DATASET --name exp_name_step1 --size 512 --faceloss --epoch 50
94 | ```
95 | The checkpoints will be saved in `checkpoint/exp_name`.
96 |
97 | Step 2: Then, finetune the model by training on the entire image (only masking the padded boundary).
98 | We set the batch size to 8 and train for 10 epochs. This training process takes less than 2 days on 2 A100 GPUs.
99 | ```
100 | python -m torch.distributed.launch --nproc_per_node=2 --master_port XXXX train.py --batch 8 /path/to/DATASET --name exp_name_step2 --size 512 --faceloss --epoch 10 --ckpt /path/to/step1/pretrained/model --finetune
101 | ```
102 |
103 | ## Testing
104 | To test the reposing model and generate the reposing results:
105 | ```
106 | python test.py /path/to/DATASET --pretrained_model /path/to/step2/pretrained/model --size 512 --save_path /path/to/save/output
107 | ```
108 | Output images will be saved in `--save_path`.
109 |
110 | You can find our reposing output images [here](https://pose-with-style.github.io/results.html).
111 |
112 | ## Evaluation
113 | We follow the same evaluation code as [Global-Flow-Local-Attention](https://github.com/RenYurui/Global-Flow-Local-Attention/blob/master/PERSON_IMAGE_GENERATION.md#evaluation).
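As a quick sanity check of your generated images, FID can also be computed with the generic `pytorch-fid` package (this is not the exact evaluation protocol of the paper, which uses the scripts linked above):
```
pip install pytorch-fid
python -m pytorch_fid /path/to/DeepFashion_highres/test /path/to/save/output
```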
114 |
115 |
116 | ## Bibtex
117 | Please consider citing our work if you find it useful for your research:
118 |
119 | @article{albahar2021pose,
120 | title = {Pose with {S}tyle: {D}etail-Preserving Pose-Guided Image Synthesis with Conditional StyleGAN},
121 | author = {AlBahar, Badour and Lu, Jingwan and Yang, Jimei and Shu, Zhixin and Shechtman, Eli and Huang, Jia-Bin},
122 | journal = {ACM Transactions on Graphics},
123 | year = {2021}
124 | }
125 |
126 |
127 | ## Acknowledgments
128 | This code is heavily borrowed from [Rosinality: StyleGAN 2 in PyTorch](https://github.com/rosinality/stylegan2-pytorch).
129 |
--------------------------------------------------------------------------------
/data/fashionWOMENBlouses_Shirtsid0000635004_1front.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENBlouses_Shirtsid0000635004_1front.png
--------------------------------------------------------------------------------
/data/fashionWOMENBlouses_Shirtsid0000635004_1front_iuv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENBlouses_Shirtsid0000635004_1front_iuv.png
--------------------------------------------------------------------------------
/data/fashionWOMENBlouses_Shirtsid0000635004_1front_sil.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENBlouses_Shirtsid0000635004_1front_sil.png
--------------------------------------------------------------------------------
/data/fashionWOMENDressesid0000262902_1front_iuv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENDressesid0000262902_1front_iuv.png
--------------------------------------------------------------------------------
/data/fashionWOMENDressesid0000262902_3back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENDressesid0000262902_3back.png
--------------------------------------------------------------------------------
/data/fashionWOMENDressesid0000262902_3back_iuv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENDressesid0000262902_3back_iuv.png
--------------------------------------------------------------------------------
/data/fashionWOMENDressesid0000262902_3back_sil.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENDressesid0000262902_3back_sil.png
--------------------------------------------------------------------------------
/data/fashionWOMENSkirtsid0000177102_1front.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENSkirtsid0000177102_1front.png
--------------------------------------------------------------------------------
/data/fashionWOMENSkirtsid0000177102_1front_iuv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENSkirtsid0000177102_1front_iuv.png
--------------------------------------------------------------------------------
/data/fashionWOMENSkirtsid0000177102_1front_sil.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENSkirtsid0000177102_1front_sil.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from torch.utils.data import Dataset
3 | import torchvision.transforms as transforms
4 | import os
5 | import pandas as pd
6 | import numpy as np
7 | import torch
8 | import random
9 | import pickle
10 |
11 |
12 | class DeepFashionDataset(Dataset):
13 | def __init__(self, path, phase, size):
14 | self.phase = phase # train or test
15 | self.size = size # 256 or 512 FOR 174x256 or 348x512
16 |
17 | # set root directories
18 | self.image_root = os.path.join(path, 'DeepFashion_highres', phase)
19 | self.densepose_root = os.path.join(path, 'densepose', phase)
20 | self.parsing_root = os.path.join(path, 'silhouette', phase)
21 | # path to pairs of data
22 | pairs_csv_path = os.path.join(path, 'DeepFashion_highres', 'tools', 'fashion-pairs-%s.csv'%phase)
23 |
24 | # uv space
25 | self.uv_root = os.path.join(path, 'complete_coordinates', phase)
26 |
27 | # initialize the pairs of data
28 | self.init_pairs(pairs_csv_path)
29 | self.data_size = len(self.pairs)
30 | print('%s data pairs (#=%d)...'%(phase, self.data_size))
31 |
32 | if phase == 'train':
33 |             # get dictionary mapping image name to the transform used to detect and align the face
34 | with open(os.path.join(path, 'resources', 'train_face_T.pickle'), 'rb') as handle:
35 | self.faceTransform = pickle.load(handle)
36 |
37 | self.transform = transforms.Compose([transforms.ToTensor(),
38 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
39 |
40 |
41 | def init_pairs(self, pairs_csv_path):
42 | pairs_file = pd.read_csv(pairs_csv_path)
43 | self.pairs = []
44 | self.sources = {}
45 | print('Loading data pairs ...')
46 | for i in range(len(pairs_file)):
47 | pair = [pairs_file.iloc[i]['from'], pairs_file.iloc[i]['to']]
48 | self.pairs.append(pair)
49 | print('Loading data pairs finished ...')
50 |
51 |
52 | def __len__(self):
53 | return self.data_size
54 |
55 |
56 | def resize_height_PIL(self, x, height=512):
57 | w, h = x.size
58 | width = int(height * w / h)
59 | return x.resize((width, height), Image.NEAREST) #Image.ANTIALIAS
60 |
61 |
62 | def resize_PIL(self, x, height=512, width=348, type=Image.NEAREST):
63 | return x.resize((width, height), type)
64 |
65 |
66 | def tensors2square(self, im, pose, sil):
67 | width = im.shape[2]
68 | diff = self.size - width
69 | if self.phase == 'train':
70 | left = random.randint(0, diff)
71 | right = diff - left
72 | else: # when testing put in the center
73 | left = int((self.size-width)/2)
74 | right = diff - left
75 | im = torch.nn.functional.pad(input=im, pad=(right, left, 0, 0), mode='constant', value=0)
76 | pose = torch.nn.functional.pad(input=pose, pad=(right, left, 0, 0), mode='constant', value=0)
77 | sil = torch.nn.functional.pad(input=sil, pad=(right, left, 0, 0), mode='constant', value=0)
78 | return im, pose, sil, left, right
79 |
80 |
81 | def __getitem__(self, index):
82 | # get current pair
83 | im1_name, im2_name = self.pairs[index]
84 |
85 | # get path to dataset
86 | input_image_path = os.path.join(self.image_root, im1_name)
87 | target_image_path = os.path.join(self.image_root, im2_name)
88 | # dense pose
89 | input_densepose_path = os.path.join(self.densepose_root, im1_name.split('.')[0]+'_iuv.png')
90 | target_densepose_path = os.path.join(self.densepose_root, im2_name.split('.')[0]+'_iuv.png')
91 | # silhouette
92 | input_sil_path = os.path.join(self.parsing_root, im1_name.split('.')[0]+'_sil.png')
93 | target_sil_path = os.path.join(self.parsing_root, im2_name.split('.')[0]+'_sil.png')
94 | # uv space
95 | complete_coor_path = os.path.join(self.uv_root, im1_name.split('.')[0]+'_uv_coor.npy')
96 |
97 | # read data
98 | # get original size of data -> for augmentation
99 | input_image_pil = Image.open(input_image_path).convert('RGB')
100 | orig_w, orig_h = input_image_pil.size
101 | if self.phase == 'test':
102 | # set target height and target width
103 | if self.size == 512:
104 | target_h = 512
105 | target_w = 348
106 | if self.size == 256:
107 | target_h = 256
108 | target_w = 174
109 | # images
110 | input_image = self.resize_PIL(input_image_pil, height=target_h, width=target_w, type=Image.ANTIALIAS)
111 | target_image = self.resize_PIL(Image.open(target_image_path).convert('RGB'), height=target_h, width=target_w, type=Image.ANTIALIAS)
112 | # dense pose
113 | input_densepose = np.array(self.resize_PIL(Image.open(input_densepose_path), height=target_h, width=target_w))
114 | target_densepose = np.array(self.resize_PIL(Image.open(target_densepose_path), height=target_h, width=target_w))
115 | # silhouette
116 | silhouette1 = np.array(self.resize_PIL(Image.open(input_sil_path), height=target_h, width=target_w))/255
117 | silhouette2 = np.array(self.resize_PIL(Image.open(target_sil_path), height=target_h, width=target_w))/255
118 | # union with densepose mask for a more accurate mask
119 | silhouette1 = 1-((1-silhouette1) * (input_densepose[:, :, 0] == 0).astype('float'))
120 |
121 | else:
122 | input_image = self.resize_height_PIL(input_image_pil, self.size)
123 | target_image = self.resize_height_PIL(Image.open(target_image_path).convert('RGB'), self.size)
124 | # dense pose
125 | input_densepose = np.array(self.resize_height_PIL(Image.open(input_densepose_path), self.size))
126 | target_densepose = np.array(self.resize_height_PIL(Image.open(target_densepose_path), self.size))
127 | # silhouette
128 | silhouette1 = np.array(self.resize_height_PIL(Image.open(input_sil_path), self.size))/255
129 | silhouette2 = np.array(self.resize_height_PIL(Image.open(target_sil_path), self.size))/255
130 | # union with densepose masks
131 | silhouette1 = 1-((1-silhouette1) * (input_densepose[:, :, 0] == 0).astype('float'))
132 | silhouette2 = 1-((1-silhouette2) * (target_densepose[:, :, 0] == 0).astype('float'))
133 |
134 | # read uv-space data
135 | complete_coor = np.load(complete_coor_path)
136 |
137 | # Transform
138 | input_image = self.transform(input_image)
139 | target_image = self.transform(target_image)
140 | # Dense Pose
141 | input_densepose = torch.from_numpy(input_densepose).permute(2, 0, 1)
142 | target_densepose = torch.from_numpy(target_densepose).permute(2, 0, 1)
143 | # silhouette
144 | silhouette1 = torch.from_numpy(silhouette1).float().unsqueeze(0) # from h,w to c,h,w
145 | silhouette2 = torch.from_numpy(silhouette2).float().unsqueeze(0) # from h,w to c,h,w
146 |
147 | # put into a square
148 | input_image, input_densepose, silhouette1, Sleft, Sright = self.tensors2square(input_image, input_densepose, silhouette1)
149 | target_image, target_densepose, silhouette2, Tleft, Tright = self.tensors2square(target_image, target_densepose, silhouette2)
150 |
151 | if self.phase == 'train':
152 | # remove loaded center shift and add augmentation shift
153 | loaded_shift = int((orig_h-orig_w)/2)
154 | complete_coor = ((complete_coor+1)/2)*(orig_h-1) # [-1, 1] to [0, orig_h]
155 | complete_coor[:,:,0] = complete_coor[:,:,0] - loaded_shift # remove center shift
156 | complete_coor = ((2*complete_coor/(orig_h-1))-1) # [0, orig_h] (no shift in w) to [-1, 1]
157 | complete_coor = ((complete_coor+1)/2) * (self.size-1) # [-1, 1] to [0, size] (no shift in w)
158 | complete_coor[:,:,0] = complete_coor[:,:,0] + Sright # add augmentation shift to w
159 | complete_coor = ((2*complete_coor/(self.size-1))-1) # [0, size] (with shift in w) to [-1,1]
160 | # to tensor
161 | complete_coor = torch.from_numpy(complete_coor).float().permute(2, 0, 1)
162 | else:
163 |             # the densepose maps can come in slightly different sizes, so rescale the coordinates to the target resolution
164 | loaded_shift = int((orig_h-orig_w)/2)
165 | complete_coor = ((complete_coor+1)/2)*(orig_h-1) # [-1, 1] to [0, orig_h]
166 | complete_coor[:,:,0] = complete_coor[:,:,0] - loaded_shift # remove center shift
167 | # before: width complete_coor[:,:,0] 0-orig_w-1
168 | # and height complete_coor[:,:,1] 0-orig_h-1
169 | complete_coor[:,:,0] = (complete_coor[:,:,0]/(orig_w-1))*(target_w-1)
170 | complete_coor[:,:,1] = (complete_coor[:,:,1]/(orig_h-1))*(target_h-1)
171 | complete_coor[:,:,0] = complete_coor[:,:,0] + Sright # add center shift to w
172 | complete_coor = ((2*complete_coor/(self.size-1))-1) # [0, size] (with shift in w) to [-1,1]
173 | # to tensor
174 | complete_coor = torch.from_numpy(complete_coor).float().permute(2, 0, 1)
175 |
176 |         # randomly choose a source pass or a target pass (source with probability 1/7)
177 | if self.phase == 'train':
178 | choice = random.randint(0, 6)
179 | if choice == 0:
180 | # source pass
181 | target_im = input_image
182 | target_p = input_densepose
183 | target_sil = silhouette1
184 | target_image_name = im1_name
185 | target_left_pad = Sleft
186 | target_right_pad = Sright
187 | else:
188 | # target pass
189 | target_im = target_image
190 | target_p = target_densepose
191 | target_sil = silhouette2
192 | target_image_name = im2_name
193 | target_left_pad = Tleft
194 | target_right_pad = Tright
195 | else:
196 | target_im = target_image
197 | target_p = target_densepose
198 | target_sil = silhouette2
199 | target_image_name = im2_name
200 | target_left_pad = Tleft
201 | target_right_pad = Tright
202 |
203 |         # Get the face transform
204 | if self.phase == 'train':
205 | if target_image_name in self.faceTransform.keys():
206 | FT = torch.from_numpy(self.faceTransform[target_image_name]).float()
207 | else: # no face detected
208 | FT = torch.zeros((3,3))
209 |
210 | # return data
211 | if self.phase == 'train':
212 | return {'input_image':input_image, 'target_image':target_im,
213 | 'target_sil': target_sil,
214 | 'target_pose':target_p,
215 | 'TargetFaceTransform': FT, 'target_left_pad':torch.tensor(target_left_pad), 'target_right_pad':torch.tensor(target_right_pad),
216 | 'input_sil': silhouette1, 'complete_coor':complete_coor,
217 | }
218 |
219 | if self.phase == 'test':
220 | save_name = im1_name.split('.')[0] + '_2_' + im2_name.split('.')[0] + '_vis.png'
221 | return {'input_image':input_image, 'target_image':target_im,
222 | 'target_sil': target_sil,
223 | 'target_pose':target_p,
224 | 'target_left_pad':torch.tensor(target_left_pad), 'target_right_pad':torch.tensor(target_right_pad),
225 | 'input_sil': silhouette1, 'complete_coor':complete_coor,
226 | 'save_name':save_name,
227 | }
228 |
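
# --- Usage sketch (illustrative addition, not part of the original file) ---
# Minimal example of wrapping the dataset in a DataLoader, assuming DATASET has
# been prepared as described in DATASET.md ('/path/to/DATASET' is a placeholder).
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = DeepFashionDataset('/path/to/DATASET', phase='test', size=512)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)
    sample = next(iter(loader))
    # Each item is a dict of tensors; the test phase also includes 'save_name'.
    print(sample['input_image'].shape, sample['target_pose'].shape)  # typically (1, 3, 512, 512) each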
--------------------------------------------------------------------------------
/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pickle
3 |
4 | import torch
5 | from torch import distributed as dist
6 | from torch.utils.data.sampler import Sampler
7 |
8 |
9 | def get_rank():
10 | if not dist.is_available():
11 | return 0
12 |
13 | if not dist.is_initialized():
14 | return 0
15 |
16 | return dist.get_rank()
17 |
18 |
19 | def synchronize():
20 | if not dist.is_available():
21 | return
22 |
23 | if not dist.is_initialized():
24 | return
25 |
26 | world_size = dist.get_world_size()
27 |
28 | if world_size == 1:
29 | return
30 |
31 | dist.barrier()
32 |
33 |
34 | def get_world_size():
35 | if not dist.is_available():
36 | return 1
37 |
38 | if not dist.is_initialized():
39 | return 1
40 |
41 | return dist.get_world_size()
42 |
43 |
44 | def reduce_sum(tensor):
45 | if not dist.is_available():
46 | return tensor
47 |
48 | if not dist.is_initialized():
49 | return tensor
50 |
51 | tensor = tensor.clone()
52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
53 |
54 | return tensor
55 |
56 |
57 | def gather_grad(params):
58 | world_size = get_world_size()
59 |
60 | if world_size == 1:
61 | return
62 |
63 | for param in params:
64 | if param.grad is not None:
65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
66 | param.grad.data.div_(world_size)
67 |
68 |
69 | def all_gather(data):
70 | world_size = get_world_size()
71 |
72 | if world_size == 1:
73 | return [data]
74 |
75 | buffer = pickle.dumps(data)
76 | storage = torch.ByteStorage.from_buffer(buffer)
77 | tensor = torch.ByteTensor(storage).to('cuda')
78 |
79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda')
80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
81 | dist.all_gather(size_list, local_size)
82 | size_list = [int(size.item()) for size in size_list]
83 | max_size = max(size_list)
84 |
85 | tensor_list = []
86 | for _ in size_list:
87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
88 |
89 | if local_size != max_size:
90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
91 | tensor = torch.cat((tensor, padding), 0)
92 |
93 | dist.all_gather(tensor_list, tensor)
94 |
95 | data_list = []
96 |
97 | for size, tensor in zip(size_list, tensor_list):
98 | buffer = tensor.cpu().numpy().tobytes()[:size]
99 | data_list.append(pickle.loads(buffer))
100 |
101 | return data_list
102 |
103 |
104 | def reduce_loss_dict(loss_dict):
105 | world_size = get_world_size()
106 |
107 | if world_size < 2:
108 | return loss_dict
109 |
110 | with torch.no_grad():
111 | keys = []
112 | losses = []
113 |
114 | for k in sorted(loss_dict.keys()):
115 | keys.append(k)
116 | losses.append(loss_dict[k])
117 |
118 | losses = torch.stack(losses, 0)
119 | dist.reduce(losses, dst=0)
120 |
121 | if dist.get_rank() == 0:
122 | losses /= world_size
123 |
124 | reduced_losses = {k: v for k, v in zip(keys, losses)}
125 |
126 | return reduced_losses
127 |
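
# --- Usage sketch (illustrative addition, not part of the original file) ---
# These helpers fall back to sensible defaults when torch.distributed is not
# initialized, so a plain single-process run works without a launcher.
if __name__ == '__main__':
    losses = {'g': torch.tensor(1.0), 'd': torch.tensor(2.0)}
    print(get_world_size(), get_rank())  # 1 0 without an initialized process group
    print(reduce_loss_dict(losses))      # returned unchanged when world_size < 2
    synchronize()                        # no-op without an initialized process group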
--------------------------------------------------------------------------------
/garment_transfer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from torchvision import utils
5 | from tqdm import tqdm
6 | from torch.utils import data
7 | import numpy as np
8 | import random
9 | from PIL import Image
10 | import torchvision.transforms as transforms
11 | from dataset import DeepFashionDataset
12 | from model import Generator
13 | from util.dp2coor import getSymXYcoordinates
14 | from util.coordinate_completion_model import define_G as define_CCM
15 |
16 |
17 | def tensors2square(im, pose, sil):
18 | width = im.shape[2]
19 | diff = args.size - width
20 | left = int((args.size-width)/2)
21 | right = diff - left
22 | im = torch.nn.functional.pad(input=im, pad=(right, left, 0, 0), mode='constant', value=0)
23 | pose = torch.nn.functional.pad(input=pose, pad=(right, left, 0, 0), mode='constant', value=0)
24 | sil = torch.nn.functional.pad(input=sil, pad=(right, left, 0, 0), mode='constant', value=0)
25 | return im, pose, sil
26 |
27 | def tensor2square(x):
28 | width = x.shape[2]
29 | diff = args.size - width
30 | left = int((args.size-width)/2)
31 | right = diff - left
32 | x = torch.nn.functional.pad(input=x, pad=(right, left, 0, 0), mode='constant', value=0)
33 | return x
34 |
35 | def generate(args, g_ema, device, mean_latent):
36 | with torch.no_grad():
37 | g_ema.eval()
38 |
39 | path = args.input_path
40 | input_name = args.input_name
41 | target_name = args.target_name
42 | part = args.part
43 |
44 | # input
45 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
46 | input_image = Image.open(os.path.join(path, input_name+'.png')).convert('RGB')
47 | iw, ih = input_image.size
48 | input_image = transform(input_image).float().to(device)
49 | input_pose = np.array(Image.open(os.path.join(path, input_name+'_iuv.png')))
50 | input_sil = np.array(Image.open(os.path.join(path, input_name+'_sil.png')))/255
51 | # get partial coordinates from dense pose
52 | dp_uv_lookup_256_np = np.load('util/dp_uv_lookup_256.npy')
53 | input_uv_coor, input_uv_mask, input_uv_symm_mask = getSymXYcoordinates(input_pose, resolution = 512)
54 | # union sil with densepose masks
55 | input_sil = 1-((1-input_sil) * (input_pose[:, :, 0] == 0).astype('float'))
56 | input_sil = torch.from_numpy(input_sil).float().unsqueeze(0)
57 | input_pose = torch.from_numpy(input_pose).permute(2, 0, 1)
58 |
59 | # target
60 | target_image = Image.open(os.path.join(path, target_name+'.png')).convert('RGB')
61 | tw, th = target_image.size
62 | target_image = transform(target_image).float().to(device)
63 | target_pose = np.array(Image.open(os.path.join(path, target_name+'_iuv.png')))
64 | target_sil = np.array(Image.open(os.path.join(path, target_name+'_sil.png')))/255
65 | # get partial coordinates from dense pose
66 | target_uv_coor, target_uv_mask, target_uv_symm_mask = getSymXYcoordinates(target_pose, resolution = 512)
67 | # union sil with densepose masks
68 | target_sil = 1-((1-target_sil) * (target_pose[:, :, 0] == 0).astype('float'))
69 | target_sil = torch.from_numpy(target_sil).float().unsqueeze(0)
70 | target_pose = torch.from_numpy(target_pose).permute(2, 0, 1)
71 |
72 | # convert to square by centering
73 | input_image, input_pose, input_sil = tensors2square(input_image, input_pose, input_sil)
74 | target_image, target_pose, target_sil = tensors2square(target_image, target_pose, target_sil)
75 |
76 | # add batch dimension
77 | input_image = input_image.unsqueeze(0).float().to(device)
78 | input_pose = input_pose.unsqueeze(0).float().to(device)
79 | input_sil = input_sil.unsqueeze(0).float().to(device)
80 | target_image = target_image.unsqueeze(0).float().to(device)
81 | target_pose = target_pose.unsqueeze(0).float().to(device)
82 | target_sil = target_sil.unsqueeze(0).float().to(device)
83 |
84 | # complete partial coordinates
85 | coor_completion_generator = define_CCM().cuda()
86 | CCM_checkpoint = torch.load(args.CCM_pretrained_model)
87 | coor_completion_generator.load_state_dict(CCM_checkpoint["g"])
88 | coor_completion_generator.eval()
89 | for param in coor_completion_generator.parameters():
90 |             param.requires_grad = False  # freeze the coordinate completion model
91 |
92 | # uv coor preprocessing (put image in center)
93 | # input
94 | ishift = int((ih-iw)/2) # center shift
95 | input_uv_coor[:,:,0] = input_uv_coor[:,:,0] + ishift # put in center
96 | input_uv_coor = ((2*input_uv_coor/(ih-1))-1)
97 | input_uv_coor = input_uv_coor*np.expand_dims(input_uv_mask,2) + (-10*(1-np.expand_dims(input_uv_mask,2)))
98 | # target
99 | tshift = int((th-tw)/2) # center shift
100 | target_uv_coor[:,:,0] = target_uv_coor[:,:,0] + tshift # put in center
101 | target_uv_coor = ((2*target_uv_coor/(th-1))-1)
102 | target_uv_coor = target_uv_coor*np.expand_dims(target_uv_mask,2) + (-10*(1-np.expand_dims(target_uv_mask,2)))
103 |
104 | # coordinate completion
105 | # input
106 | uv_coor_pytorch = torch.from_numpy(input_uv_coor).float().permute(2, 0, 1).unsqueeze(0) # from h,w,c to 1,c,h,w
107 | uv_mask_pytorch = torch.from_numpy(input_uv_mask).unsqueeze(0).unsqueeze(0).float() #1xchw
108 | with torch.no_grad():
109 | coor_completion_generator.eval()
110 | input_complete_coor = coor_completion_generator(uv_coor_pytorch.cuda(), uv_mask_pytorch.cuda())
111 | # target
112 | uv_coor_pytorch = torch.from_numpy(target_uv_coor).float().permute(2, 0, 1).unsqueeze(0) # from h,w,c to 1,c,h,w
113 | uv_mask_pytorch = torch.from_numpy(target_uv_mask).unsqueeze(0).unsqueeze(0).float() #1xchw
114 | with torch.no_grad():
115 | coor_completion_generator.eval()
116 | target_complete_coor = coor_completion_generator(uv_coor_pytorch.cuda(), uv_mask_pytorch.cuda())
117 |
118 |
119 | # garment transfer
120 | appearance = torch.cat([input_image, input_sil, input_complete_coor, target_image, target_sil, target_complete_coor], 1)
121 | output, part_mask = g_ema(appearance=appearance, pose=input_pose)
122 |
123 |         # visualize the transferred part
124 | zeros = torch.zeros(part_mask.shape).to(part_mask)
125 | ones255 = torch.ones(part_mask.shape).to(part_mask)*255
126 | part_red = torch.cat([part_mask*255, zeros, zeros], 1)
127 | part_img = part_red * ((input_pose[:, 0, :, :] != 0)) + torch.cat([ones255, ones255, ones255], 1)*(1-part_mask)
128 |
129 | utils.save_image(
130 | output[:, :, :, int(ishift):args.size-int(ishift)],
131 | os.path.join(args.save_path, input_name+'_and_'+target_name+'_'+args.part+'_vis.png'),
132 | nrow=1,
133 | normalize=True,
134 | range=(-1, 1),
135 | )
136 | utils.save_image(
137 | part_img[:, :, :, int(ishift):args.size-int(ishift)],
138 | os.path.join(args.save_path, input_name+'_and_'+target_name+'_'+args.part+'.png'),
139 | nrow=1,
140 | normalize=True,
141 | range=(0, 255),
142 | )
143 |
144 |
145 |
146 | if __name__ == "__main__":
147 | device = "cuda"
148 |
149 | parser = argparse.ArgumentParser(description="inference")
150 |
151 | parser.add_argument("--input_path", type=str, help="path to the input dataset")
152 | parser.add_argument("--input_name", type=str, default="fashionWOMENSkirtsid0000177102_1front", help="input file name")
153 | parser.add_argument("--target_name", type=str, default="fashionWOMENBlouses_Shirtsid0000635004_1front", help="target file name")
154 |     parser.add_argument("--part", type=str, default="upper_body", help="body part to transfer: upper_body, lower_body, full_body, or face")
155 | parser.add_argument("--size", type=int, default=512, help="output image size of the generator")
156 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
157 | parser.add_argument("--truncation_mean", type=int, default=4096, help="number of vectors to calculate mean for the truncation")
158 | parser.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier of the generator. config-f = 2, else = 1")
159 | parser.add_argument("--pretrained_model", type=str, default="posewithstyle.pt", help="pose with style pretrained model")
160 | parser.add_argument("--CCM_pretrained_model", type=str, default="CCM_epoch50.pt", help="pretrained coordinate completion model")
161 |     parser.add_argument("--save_path", type=str, default="./data/output", help="path to save the output (default: ./data/output)")
162 |
163 | args = parser.parse_args()
164 |
165 | args.latent = 2048
166 | args.n_mlp = 8
167 |
168 | if not os.path.exists(args.save_path):
169 | os.makedirs(args.save_path)
170 |
171 | g_ema = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier, garment_transfer=True, part=args.part).to(device)
172 | checkpoint = torch.load(args.pretrained_model)
173 | g_ema.load_state_dict(checkpoint["g_ema"])
174 |
175 | if args.truncation < 1:
176 | with torch.no_grad():
177 | mean_latent = g_ema.mean_latent(args.truncation_mean)
178 | else:
179 | mean_latent = None
180 |
181 | generate(args, g_ema, device, mean_latent)
182 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from torchvision import utils
5 | from tqdm import tqdm
6 | from torch.utils import data
7 | import numpy as np
8 | import random
9 | from PIL import Image
10 | import torchvision.transforms as transforms
11 | from dataset import DeepFashionDataset
12 | from model import Generator
13 | from util.dp2coor import getSymXYcoordinates
14 | from util.coordinate_completion_model import define_G as define_CCM
15 |
16 | def tensors2square(im, pose, sil):
17 | width = im.shape[2]
18 | diff = args.size - width
19 | left = int((args.size-width)/2)
20 | right = diff - left
21 | im = torch.nn.functional.pad(input=im, pad=(right, left, 0, 0), mode='constant', value=0)
22 | pose = torch.nn.functional.pad(input=pose, pad=(right, left, 0, 0), mode='constant', value=0)
23 | sil = torch.nn.functional.pad(input=sil, pad=(right, left, 0, 0), mode='constant', value=0)
24 | return im, pose, sil
25 |
26 | def tensor2square(x):
27 | width = x.shape[2]
28 | diff = args.size - width
29 | left = int((args.size-width)/2)
30 | right = diff - left
31 | x = torch.nn.functional.pad(input=x, pad=(right, left, 0, 0), mode='constant', value=0)
32 | return x
33 |
34 | def generate(args, g_ema, device, mean_latent):
35 | with torch.no_grad():
36 | g_ema.eval()
37 |
38 | path = args.input_path
39 | input_name = args.input_name
40 | pose_name = args.target_name
41 |
42 | # input
43 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
44 | input_image = Image.open(os.path.join(path, input_name+'.png')).convert('RGB')
45 | w, h = input_image.size
46 | input_image = transform(input_image).float().to(device)
47 |
48 | input_pose = np.array(Image.open(os.path.join(path, input_name+'_iuv.png')))
49 | input_sil = np.array(Image.open(os.path.join(path, input_name+'_sil.png')))/255
50 |
51 | # get partial coordinates from dense pose
52 | dp_uv_lookup_256_np = np.load('util/dp_uv_lookup_256.npy')
53 | uv_coor, uv_mask, uv_symm_mask = getSymXYcoordinates(input_pose, resolution = 512)
54 |
55 | # union sil with densepose masks
56 | input_sil = 1-((1-input_sil) * (input_pose[:, :, 0] == 0).astype('float'))
57 |
58 | input_sil = torch.from_numpy(input_sil).float().unsqueeze(0)
59 | input_pose = torch.from_numpy(input_pose).permute(2, 0, 1)
60 |
61 | # target
62 | target_pose = np.array(Image.open(os.path.join(path, pose_name+'_iuv.png')))
63 | target_pose = torch.from_numpy(target_pose).permute(2, 0, 1)
64 |
65 | # convert to square by centering
66 | input_image, input_pose, input_sil = tensors2square(input_image, input_pose, input_sil)
67 | target_pose = tensor2square(target_pose)
68 |
69 | # add batch dimension
70 | input_image = input_image.unsqueeze(0).float().to(device)
71 | input_pose = input_pose.unsqueeze(0).float().to(device)
72 | input_sil = input_sil.unsqueeze(0).float().to(device)
73 | target_pose = target_pose.unsqueeze(0).float().to(device)
74 |
75 | # complete partial coordinates
76 | coor_completion_generator = define_CCM().cuda()
77 | CCM_checkpoint = torch.load(args.CCM_pretrained_model)
78 | coor_completion_generator.load_state_dict(CCM_checkpoint["g"])
79 | coor_completion_generator.eval()
80 | for param in coor_completion_generator.parameters():
81 |             param.requires_grad = False  # freeze the coordinate completion model
82 |
83 | # uv coor preprocessing (put image in center)
84 | shift = int((h-w)/2) # center shift
85 | uv_coor[:,:,0] = uv_coor[:,:,0] + shift # put in center
86 | uv_coor = ((2*uv_coor/(h-1))-1)
87 | uv_coor = uv_coor*np.expand_dims(uv_mask,2) + (-10*(1-np.expand_dims(uv_mask,2)))
88 |
89 | # coordinate completion
90 | uv_coor_pytorch = torch.from_numpy(uv_coor).float().permute(2, 0, 1).unsqueeze(0) # from h,w,c to 1,c,h,w
91 | uv_mask_pytorch = torch.from_numpy(uv_mask).unsqueeze(0).unsqueeze(0).float() #1xchw
92 | with torch.no_grad():
93 | coor_completion_generator.eval()
94 | complete_coor = coor_completion_generator(uv_coor_pytorch.cuda(), uv_mask_pytorch.cuda())
95 |
96 | # reposing
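        # The appearance input to the generator is a 6-channel tensor:
        # RGB image (3) + silhouette (1) + completed UV coordinates (2).
        # (Descriptive comment added for clarity.)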
97 | appearance = torch.cat([input_image, input_sil, complete_coor], 1)
98 | output, _ = g_ema(appearance=appearance, pose=target_pose)
99 |
100 | utils.save_image(
101 | output[:, :, :, int(shift):args.size-int(shift)],
102 | os.path.join(args.save_path, input_name+'_2_'+pose_name+'_vis.png'),
103 | nrow=1,
104 | normalize=True,
105 | range=(-1, 1),
106 | )
107 |
108 |
109 |
110 | if __name__ == "__main__":
111 | device = "cuda"
112 |
113 | parser = argparse.ArgumentParser(description="inference")
114 |
115 | parser.add_argument("--input_path", type=str, help="path to the input dataset")
116 | parser.add_argument("--input_name", type=str, default="fashionWOMENDressesid0000262902_3back", help="input file name")
117 | parser.add_argument("--target_name", type=str, default="fashionWOMENDressesid0000262902_1front", help="target file name")
118 | parser.add_argument("--size", type=int, default=512, help="output image size of the generator")
119 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
120 | parser.add_argument("--truncation_mean", type=int, default=4096, help="number of vectors to calculate mean for the truncation")
121 | parser.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier of the generator. config-f = 2, else = 1")
122 | parser.add_argument("--pretrained_model", type=str, default="posewithstyle.pt", help="pose with style pretrained model")
123 | parser.add_argument("--CCM_pretrained_model", type=str, default="CCM_epoch50.pt", help="pretrained coordinate completion model")
124 |     parser.add_argument("--save_path", type=str, default="./data/output", help="path to save the output (default: ./data/output)")
125 |
126 | args = parser.parse_args()
127 |
128 | args.latent = 2048
129 | args.n_mlp = 8
130 |
131 | if not os.path.exists(args.save_path):
132 | os.makedirs(args.save_path)
133 |
134 | g_ema = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device)
135 | checkpoint = torch.load(args.pretrained_model)
136 | g_ema.load_state_dict(checkpoint["g_ema"])
137 |
138 | if args.truncation < 1:
139 | with torch.no_grad():
140 | mean_latent = g_ema.mean_latent(args.truncation_mean)
141 | else:
142 | mean_latent = None
143 |
144 | generate(args, g_ema, device, mean_latent)
145 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import functools
4 | import operator
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 | from torch.autograd import Function
10 |
11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
12 |
13 |
14 | class PixelNorm(nn.Module):
15 | def __init__(self):
16 | super().__init__()
17 |
18 | def forward(self, input):
19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
20 |
21 |
22 | def make_kernel(k):
23 | k = torch.tensor(k, dtype=torch.float32)
24 |
25 | if k.ndim == 1:
26 | k = k[None, :] * k[:, None]
27 |
28 | k /= k.sum()
29 |
30 | return k
31 |
32 |
33 | class Upsample(nn.Module):
34 | def __init__(self, kernel, factor=2):
35 | super().__init__()
36 |
37 | self.factor = factor
38 | kernel = make_kernel(kernel) * (factor ** 2)
39 | self.register_buffer("kernel", kernel)
40 |
41 | p = kernel.shape[0] - factor
42 |
43 | pad0 = (p + 1) // 2 + factor - 1
44 | pad1 = p // 2
45 |
46 | self.pad = (pad0, pad1)
47 |
48 | def forward(self, input):
49 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
50 |
51 | return out
52 |
53 |
54 | class Downsample(nn.Module):
55 | def __init__(self, kernel, factor=2):
56 | super().__init__()
57 |
58 | self.factor = factor
59 | kernel = make_kernel(kernel)
60 | self.register_buffer("kernel", kernel)
61 |
62 | p = kernel.shape[0] - factor
63 |
64 | pad0 = (p + 1) // 2
65 | pad1 = p // 2
66 |
67 | self.pad = (pad0, pad1)
68 |
69 | def forward(self, input):
70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
71 |
72 | return out
73 |
74 |
75 | class Blur(nn.Module):
76 | def __init__(self, kernel, pad, upsample_factor=1):
77 | super().__init__()
78 |
79 | kernel = make_kernel(kernel)
80 |
81 | if upsample_factor > 1:
82 | kernel = kernel * (upsample_factor ** 2)
83 |
84 | self.register_buffer("kernel", kernel)
85 |
86 | self.pad = pad
87 |
88 | def forward(self, input):
89 | out = upfirdn2d(input, self.kernel, pad=self.pad)
90 |
91 | return out
92 |
93 |
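# EqualConv2d and EqualLinear (below) implement StyleGAN2's equalized learning
# rate: weights are initialized from N(0, 1) and rescaled at runtime by
# 1/sqrt(fan_in) (times lr_mul for the linear layer), so every layer sees a
# comparable effective learning rate. (Descriptive comment added for clarity.)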
94 | class EqualConv2d(nn.Module):
95 | def __init__(
96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
97 | ):
98 | super().__init__()
99 |
100 | self.weight = nn.Parameter(
101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size)
102 | )
103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
104 |
105 | self.stride = stride
106 | self.padding = padding
107 |
108 | if bias:
109 | self.bias = nn.Parameter(torch.zeros(out_channel))
110 |
111 | else:
112 | self.bias = None
113 |
114 | def forward(self, input):
115 | out = conv2d_gradfix.conv2d(
116 | input,
117 | self.weight * self.scale,
118 | bias=self.bias,
119 | stride=self.stride,
120 | padding=self.padding,
121 | )
122 |
123 | return out
124 |
125 | def __repr__(self):
126 | return (
127 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
128 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
129 | )
130 |
131 |
132 | class EqualLinear(nn.Module):
133 | def __init__(
134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
135 | ):
136 | super().__init__()
137 |
138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
139 |
140 | if bias:
141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
142 |
143 | else:
144 | self.bias = None
145 |
146 | self.activation = activation
147 |
148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
149 | self.lr_mul = lr_mul
150 |
151 | def forward(self, input):
152 | if self.activation:
153 | out = F.linear(input, self.weight * self.scale)
154 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
155 |
156 | else:
157 | out = F.linear(
158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul
159 | )
160 |
161 | return out
162 |
163 | def __repr__(self):
164 | return (
165 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
166 | )
167 |
168 |
169 | class ModulatedConv2d(nn.Module):
170 | def __init__(
171 | self,
172 | in_channel,
173 | out_channel,
174 | kernel_size,
175 | style_dim,
176 | demodulate=True,
177 | upsample=False,
178 | downsample=False,
179 | blur_kernel=[1, 3, 3, 1],
180 | fused=True,
181 | ):
182 | super().__init__()
183 |
184 | self.eps = 1e-8
185 | self.kernel_size = kernel_size
186 | self.in_channel = in_channel
187 | self.out_channel = out_channel
188 | self.upsample = upsample
189 | self.downsample = downsample
190 |
191 | if upsample:
192 | factor = 2
193 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
194 | pad0 = (p + 1) // 2 + factor - 1
195 | pad1 = p // 2 + 1
196 |
197 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
198 |
199 | if downsample:
200 | factor = 2
201 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
202 | pad0 = (p + 1) // 2
203 | pad1 = p // 2
204 |
205 | self.blur = Blur(blur_kernel, pad=(pad0, pad1))
206 |
207 | fan_in = in_channel * kernel_size ** 2
208 | self.scale = 1 / math.sqrt(fan_in)
209 | self.padding = kernel_size // 2
210 |
211 | self.weight = nn.Parameter(
212 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
213 | )
214 |
215 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
216 |
217 | self.demodulate = demodulate
218 | self.fused = fused
219 |
220 | def __repr__(self):
221 | return (
222 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
223 | f"upsample={self.upsample}, downsample={self.downsample})"
224 | )
225 |
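    # Weight (de)modulation as in StyleGAN2: the style vector is mapped to a
    # per-input-channel scale that modulates the conv weights; with demodulate
    # enabled, each output filter is renormalized by rsqrt(sum of squared
    # weights) to keep activations near unit variance. The fused path folds
    # this into a grouped convolution over the batch; the non-fused path scales
    # the activations instead. (Descriptive comment added for clarity.)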
226 | def forward(self, input, style):
227 | batch, in_channel, height, width = input.shape
228 |
229 | if not self.fused:
230 | weight = self.scale * self.weight.squeeze(0)
231 | style = self.modulation(style)
232 |
233 | if self.demodulate:
234 | w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
235 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
236 |
237 | input = input * style.reshape(batch, in_channel, 1, 1)
238 |
239 | if self.upsample:
240 | weight = weight.transpose(0, 1)
241 | out = conv2d_gradfix.conv_transpose2d(
242 | input, weight, padding=0, stride=2
243 | )
244 | out = self.blur(out)
245 |
246 | elif self.downsample:
247 | input = self.blur(input)
248 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
249 |
250 | else:
251 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
252 |
253 | if self.demodulate:
254 | out = out * dcoefs.view(batch, -1, 1, 1)
255 |
256 | return out
257 |
258 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
259 | weight = self.scale * self.weight * style
260 |
261 | if self.demodulate:
262 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
263 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
264 |
265 | weight = weight.view(
266 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
267 | )
268 |
269 | if self.upsample:
270 | input = input.view(1, batch * in_channel, height, width)
271 | weight = weight.view(
272 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
273 | )
274 | weight = weight.transpose(1, 2).reshape(
275 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
276 | )
277 | out = conv2d_gradfix.conv_transpose2d(
278 | input, weight, padding=0, stride=2, groups=batch
279 | )
280 | _, _, height, width = out.shape
281 | out = out.view(batch, self.out_channel, height, width)
282 | out = self.blur(out)
283 |
284 | elif self.downsample:
285 | input = self.blur(input)
286 | _, _, height, width = input.shape
287 | input = input.view(1, batch * in_channel, height, width)
288 | out = conv2d_gradfix.conv2d(
289 | input, weight, padding=0, stride=2, groups=batch
290 | )
291 | _, _, height, width = out.shape
292 | out = out.view(batch, self.out_channel, height, width)
293 |
294 | else:
295 | input = input.view(1, batch * in_channel, height, width)
296 | out = conv2d_gradfix.conv2d(
297 | input, weight, padding=self.padding, groups=batch
298 | )
299 | _, _, height, width = out.shape
300 | out = out.view(batch, self.out_channel, height, width)
301 |
302 | return out
303 |
304 |
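# SpatiallyModulatedConv2d: instead of modulating the convolution weights with a
# per-channel style vector (as ModulatedConv2d does above), this layer predicts
# per-pixel gamma and beta maps from a spatial style tensor using two small
# 1x1-conv heads, applies gamma * x + beta to the input, convolves, and then
# normalizes the output with its per-channel mean and std (instance-norm style).
# (Descriptive comment added for clarity.)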
305 | class SpatiallyModulatedConv2d(nn.Module):
306 | def __init__(
307 | self,
308 | in_channel,
309 | out_channel,
310 | kernel_size,
311 | upsample=False,
312 | downsample=False,
313 | blur_kernel=[1, 3, 3, 1],
314 | ):
315 | super().__init__()
316 |
317 | self.eps = 1e-8
318 | self.kernel_size = kernel_size
319 | self.in_channel = in_channel
320 | self.out_channel = out_channel
321 | self.upsample = upsample
322 | self.downsample = downsample
323 |
324 | if upsample:
325 | factor = 2
326 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
327 | pad0 = (p + 1) // 2 + factor - 1
328 | pad1 = p // 2 + 1
329 |
330 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
331 |
332 | if downsample:
333 | factor = 2
334 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
335 | pad0 = (p + 1) // 2
336 | pad1 = p // 2
337 |
338 | self.blur = Blur(blur_kernel, pad=(pad0, pad1))
339 |
340 | fan_in = in_channel * kernel_size ** 2
341 | self.scale = 1 / math.sqrt(fan_in)
342 | self.padding = kernel_size // 2
343 |
344 | self.weight = nn.Parameter(
345 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
346 | )
347 |
348 | self.gamma = nn.Sequential(*[EqualConv2d(in_channel, 128, kernel_size=1), nn.ReLU(True), EqualConv2d(128, in_channel, kernel_size=1)])
349 | self.beta = nn.Sequential(*[EqualConv2d(in_channel, 128, kernel_size=1), nn.ReLU(True), EqualConv2d(128, in_channel, kernel_size=1)])
350 |
351 | def calc_mean_std(self, feat, eps=1e-5):
352 | size = feat.size()
353 | assert (len(size) == 4)
354 | N, C = size[:2]
355 | feat_var = feat.view(N, C, -1).var(dim=2) + eps
356 | feat_std = feat_var.sqrt().view(N, C, 1, 1)
357 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
358 | return feat_mean, feat_std
359 |
360 | def modulate(self, x, gamma, beta):
361 | return gamma * x + beta
362 |
363 | def normalize(self, x):
364 | mean, std = self.calc_mean_std(x)
365 | mean = mean.expand_as(x)
366 | std = std.expand_as(x)
367 | return (x-mean)/std
368 |
369 | def forward(self, input, style):
370 | batch, in_channel, height, width = input.shape
371 |
372 | weight = self.scale * self.weight.squeeze(0)
373 |
374 | gamma = self.gamma(style)
375 | beta = self.beta(style)
376 |
377 | input = self.modulate(input, gamma, beta)
378 |
379 | if self.upsample:
380 | weight = weight.transpose(0, 1)
381 | out = conv2d_gradfix.conv_transpose2d(
382 | input, weight, padding=0, stride=2
383 | )
384 | out = self.blur(out)
385 | elif self.downsample:
386 | input = self.blur(input)
387 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
388 | else:
389 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
390 |
391 | out = self.normalize(out)
392 |
393 | return out
394 |
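# Hedged illustration (not part of the original file): SpatiallyModulatedConv2d swaps
# StyleGAN2's per-channel style modulation for per-pixel gamma/beta maps predicted from the
# warped appearance features, then re-normalizes (instance-norm style) after the convolution.
# The sketch below reproduces only the modulate/normalize math, with the conv in between
# omitted; `_spatial_modulation_demo` is a made-up name for illustration.
def _spatial_modulation_demo(x, gamma, beta, eps=1e-5):
    x = gamma * x + beta                                   # per-pixel modulation
    n, c = x.shape[:2]
    flat = x.view(n, c, -1)
    mean = flat.mean(dim=2).view(n, c, 1, 1)
    std = (flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    return (x - mean) / std                                # matches calc_mean_std/normalize above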
395 |
396 | class NoiseInjection(nn.Module):
397 | def __init__(self):
398 | super().__init__()
399 |
400 | self.weight = nn.Parameter(torch.zeros(1))
401 |
402 | def forward(self, image, noise=None):
403 | if noise is None:
404 | batch, _, height, width = image.shape
405 | noise = image.new_empty(batch, 1, height, width).normal_()
406 |
407 | return image + self.weight * noise
408 |
409 |
410 | class ConstantInput(nn.Module):
411 | def __init__(self, channel, size=4):
412 | super().__init__()
413 |
414 | self.input = nn.Parameter(torch.randn(1, channel, size, size))
415 |
416 | def forward(self, input):
417 | batch = input.shape[0]
418 | out = self.input.repeat(batch, 1, 1, 1)
419 |
420 | return out
421 |
422 |
423 | class StyledConv(nn.Module):
424 | def __init__(
425 | self,
426 | in_channel,
427 | out_channel,
428 | kernel_size,
429 | style_dim,
430 | upsample=False,
431 | blur_kernel=[1, 3, 3, 1],
432 | demodulate=True,
433 | spatial=False,
434 | ):
435 | super().__init__()
436 |
437 | if spatial:
438 | self.conv = SpatiallyModulatedConv2d(
439 | in_channel,
440 | out_channel,
441 | kernel_size,
442 | upsample=upsample,
443 | blur_kernel=blur_kernel,
444 | )
445 | else:
446 | self.conv = ModulatedConv2d(
447 | in_channel,
448 | out_channel,
449 | kernel_size,
450 | style_dim,
451 | upsample=upsample,
452 | blur_kernel=blur_kernel,
453 | demodulate=demodulate,
454 | )
455 |
456 | self.noise = NoiseInjection()
457 | self.activate = FusedLeakyReLU(out_channel)
458 |
459 | def forward(self, input, style, noise=None):
460 | out = self.conv(input, style)
461 | out = self.noise(out, noise=noise)
462 | out = self.activate(out)
463 |
464 | return out
465 |
466 |
467 | class ToRGB(nn.Module):
468 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1], spatial=False):
469 | super().__init__()
470 |
471 | if upsample:
472 | self.upsample = Upsample(blur_kernel)
473 |
474 | if spatial:
475 | self.conv = SpatiallyModulatedConv2d(in_channel, 3, 1)
476 | else:
477 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
478 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
479 |
480 | def forward(self, input, style, skip=None):
481 | out = self.conv(input, style)
482 | out = out + self.bias
483 |
484 | if skip is not None:
485 | skip = self.upsample(skip)
486 |
487 | out = out + skip
488 |
489 | return out
490 |
491 |
492 | class PoseEncoder(nn.Module):
493 | def __init__(self, ngf=64, blur_kernel=[1, 3, 3, 1], size=256):
494 | super().__init__()
495 | self.size = size
496 | convs = [ConvLayer(3, ngf, 1)]
497 | convs.append(ResBlock(ngf, ngf*2, blur_kernel))
498 | convs.append(ResBlock(ngf*2, ngf*4, blur_kernel))
499 | convs.append(ResBlock(ngf*4, ngf*8, blur_kernel))
500 | convs.append(ResBlock(ngf*8, ngf*8, blur_kernel))
501 | if self.size == 512:
502 | convs.append(ResBlock(ngf*8, ngf*8, blur_kernel))
503 | if self.size == 1024:
504 | convs.append(ResBlock(ngf*8, ngf*8, blur_kernel))
505 | convs.append(ResBlock(ngf*8, ngf*8, blur_kernel))
506 |
507 | self.convs = nn.Sequential(*convs)
508 |
509 | def forward(self, input):
510 | out = self.convs(input)
511 | return out
512 |
513 |
514 | class SpatialAppearanceEncoder(nn.Module):
515 | def __init__(self, ngf=64, blur_kernel=[1, 3, 3, 1], size=256):
516 | super().__init__()
517 | self.size = size
518 | self.dp_uv_lookup_256_np = np.load('util/dp_uv_lookup_256.npy')
519 | input_nc = 4 # source RGB and sil
520 |
521 | self.conv1 = ConvLayer(input_nc, ngf, 1) # ngf 256 256
522 | self.conv2 = ResBlock(ngf, ngf*2, blur_kernel) # 2ngf 128 128
523 | self.conv3 = ResBlock(ngf*2, ngf*4, blur_kernel) # 4ngf 64 64
524 | self.conv4 = ResBlock(ngf*4, ngf*8, blur_kernel) # 8ngf 32 32
525 | self.conv5 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16
526 | if self.size == 512:
527 | self.conv6 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16 - starting from ngf 512 512
528 | if self.size == 1024:
529 | self.conv6 = ResBlock(ngf*8, ngf*8, blur_kernel)
530 |             self.conv7 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16 - starting from ngf 1024 1024
531 |
532 | self.conv11 = EqualConv2d(ngf+1, ngf*8, 1)
533 | self.conv21 = EqualConv2d(ngf*2+1, ngf*8, 1)
534 | self.conv31 = EqualConv2d(ngf*4+1, ngf*8, 1)
535 | self.conv41 = EqualConv2d(ngf*8+1, ngf*8, 1)
536 | self.conv51 = EqualConv2d(ngf*8+1, ngf*8, 1)
537 | if self.size == 512:
538 | self.conv61 = EqualConv2d(ngf*8+1, ngf*8, 1)
539 | if self.size == 1024:
540 | self.conv61 = EqualConv2d(ngf*8+1, ngf*8, 1)
541 | self.conv71 = EqualConv2d(ngf*8+1, ngf*8, 1)
542 |
543 | if self.size == 1024:
544 | self.conv13 = EqualConv2d(ngf*8, int(ngf/2), 3, padding=1)
545 | self.conv23 = EqualConv2d(ngf*8, ngf*1, 3, padding=1)
546 | self.conv33 = EqualConv2d(ngf*8, ngf*2, 3, padding=1)
547 | self.conv43 = EqualConv2d(ngf*8, ngf*4, 3, padding=1)
548 | self.conv53 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
549 | self.conv63 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
550 | elif self.size == 512:
551 | self.conv13 = EqualConv2d(ngf*8, ngf*1, 3, padding=1)
552 | self.conv23 = EqualConv2d(ngf*8, ngf*2, 3, padding=1)
553 | self.conv33 = EqualConv2d(ngf*8, ngf*4, 3, padding=1)
554 | self.conv43 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
555 | self.conv53 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
556 | else:
557 | self.conv13 = EqualConv2d(ngf*8, ngf*2, 3, padding=1)
558 | self.conv23 = EqualConv2d(ngf*8, ngf*4, 3, padding=1)
559 | self.conv33 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
560 | self.conv43 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
561 |
562 | self.up = nn.Upsample(scale_factor=2)
563 |
564 | def uv2target(self, dp, tex, tex_mask=None):
565 | # uv2img
566 | b, _, h, w = dp.shape
567 | ch = tex.shape[1]
568 | UV_MAPs = []
569 | for idx in range(b):
570 | iuv = dp[idx]
571 | point_pos = iuv[0, :, :] > 0
572 | iuv_raw_i = iuv[0][point_pos]
573 | iuv_raw_u = iuv[1][point_pos]
574 | iuv_raw_v = iuv[2][point_pos]
575 | iuv_raw = torch.stack([iuv_raw_i,iuv_raw_u,iuv_raw_v], 0)
576 | i = iuv_raw[0, :] - 1 ## dp_uv_lookup_256_np does not contain BG class
577 | u = iuv_raw[1, :]
578 | v = iuv_raw[2, :]
579 | i = i.cpu().numpy()
580 | u = u.cpu().numpy()
581 | v = v.cpu().numpy()
582 | uv_smpl = self.dp_uv_lookup_256_np[i, v, u]
583 | uv_map = torch.zeros((2,h,w)).to(tex).float()
584 |             ## normalize [0,1] to [-1,1] for PyTorch's grid_sample
585 | u_map = uv_smpl[:, 0] * 2 - 1
586 | v_map = (1 - uv_smpl[:, 1]) * 2 - 1
587 | uv_map[0][point_pos] = torch.from_numpy(u_map).to(tex).float()
588 | uv_map[1][point_pos] = torch.from_numpy(v_map).to(tex).float()
589 | UV_MAPs.append(uv_map)
590 | uv_map = torch.stack(UV_MAPs, 0)
591 | # warping
592 | # before warping validate sizes
593 | _, _, h_x, w_x = tex.shape
594 | _, _, h_t, w_t = uv_map.shape
595 | if h_t != h_x or w_t != w_x:
596 | #https://github.com/iPERDance/iPERCore/blob/4a010f781a4fb90dd29a516472e4aadf41ed1609/iPERCore/models/networks/generators/lwb_avg_resunet.py#L55
597 | uv_map = torch.nn.functional.interpolate(uv_map, size=(h_x, w_x), mode='bilinear', align_corners=True)
598 | uv_map = uv_map.permute(0, 2, 3, 1)
599 | warped_image = torch.nn.functional.grid_sample(tex.float(), uv_map.float())
600 | if tex_mask is not None:
601 | warped_mask = torch.nn.functional.grid_sample(tex_mask.float(), uv_map.float())
602 | final_warped = warped_image * warped_mask
603 | return final_warped, warped_mask
604 | else:
605 | return warped_image
606 |
607 | def forward(self, input, pose):
608 | coor = input[:,4:,:,:]
609 | input = input[:,:4,:,:]
610 |
611 | x1 = self.conv1(input)
612 | x2 = self.conv2(x1)
613 | x3 = self.conv3(x2)
614 | x4 = self.conv4(x3)
615 | x5 = self.conv5(x4)
616 | if self.size == 512:
617 | x6 = self.conv6(x5)
618 | if self.size == 1024:
619 | x6 = self.conv6(x5)
620 | x7 = self.conv7(x6)
621 |
622 | # warp- get flow
623 | pose_mask = 1-(pose[:,0, :, :] == 0).float().unsqueeze(1)
624 | flow = self.uv2target(pose.int(), coor)
625 | # warp- resize flow
626 | f1 = torch.nn.functional.interpolate(flow, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True)
627 | f2 = torch.nn.functional.interpolate(flow, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True)
628 | f3 = torch.nn.functional.interpolate(flow, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True)
629 | f4 = torch.nn.functional.interpolate(flow, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True)
630 | f5 = torch.nn.functional.interpolate(flow, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True)
631 | if self.size == 512:
632 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
633 | if self.size == 1024:
634 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
635 | f7 = torch.nn.functional.interpolate(flow, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True)
636 | # warp- now warp
637 | x1 = torch.nn.functional.grid_sample(x1, f1.permute(0,2,3,1))
638 | x2 = torch.nn.functional.grid_sample(x2, f2.permute(0,2,3,1))
639 | x3 = torch.nn.functional.grid_sample(x3, f3.permute(0,2,3,1))
640 | x4 = torch.nn.functional.grid_sample(x4, f4.permute(0,2,3,1))
641 | x5 = torch.nn.functional.grid_sample(x5, f5.permute(0,2,3,1))
642 | if self.size == 512:
643 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1))
644 | if self.size == 1024:
645 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1))
646 | x7 = torch.nn.functional.grid_sample(x7, f7.permute(0,2,3,1))
647 |
648 | # mask features
649 | p1 = torch.nn.functional.interpolate(pose_mask, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True)
650 | p2 = torch.nn.functional.interpolate(pose_mask, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True)
651 | p3 = torch.nn.functional.interpolate(pose_mask, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True)
652 | p4 = torch.nn.functional.interpolate(pose_mask, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True)
653 | p5 = torch.nn.functional.interpolate(pose_mask, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True)
654 | if self.size == 512:
655 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
656 | if self.size == 1024:
657 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
658 | p7 = torch.nn.functional.interpolate(pose_mask, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True)
659 |
660 | x1 = x1 * p1
661 | x2 = x2 * p2
662 | x3 = x3 * p3
663 | x4 = x4 * p4
664 | x5 = x5 * p5
665 | if self.size == 512:
666 | x6 = x6 * p6
667 | if self.size == 1024:
668 | x6 = x6 * p6
669 | x7 = x7 * p7
670 |
671 | # fpn
672 | if self.size == 1024:
673 | F7 = self.conv71(torch.cat([x7,p7], 1))
674 | f6 = self.up(F7)+self.conv61(torch.cat([x6,p6], 1))
675 | F6 = self.conv63(f6)
676 | f5 = self.up(F6)+self.conv51(torch.cat([x5,p5], 1))
677 | F5 = self.conv53(f5)
678 | f4 = self.up(f5)+self.conv41(torch.cat([x4,p4], 1))
679 | F4 = self.conv43(f4)
680 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1))
681 | F3 = self.conv33(f3)
682 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1))
683 | F2 = self.conv23(f2)
684 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1))
685 | F1 = self.conv13(f1)
686 | elif self.size == 512:
687 | F6 = self.conv61(torch.cat([x6,p6], 1))
688 | f5 = self.up(F6)+self.conv51(torch.cat([x5,p5], 1))
689 | F5 = self.conv53(f5)
690 | f4 = self.up(f5)+self.conv41(torch.cat([x4,p4], 1))
691 | F4 = self.conv43(f4)
692 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1))
693 | F3 = self.conv33(f3)
694 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1))
695 | F2 = self.conv23(f2)
696 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1))
697 | F1 = self.conv13(f1)
698 | else:
699 | F5 = self.conv51(torch.cat([x5,p5], 1))
700 | f4 = self.up(F5)+self.conv41(torch.cat([x4,p4], 1))
701 | F4 = self.conv43(f4)
702 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1))
703 | F3 = self.conv33(f3)
704 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1))
705 | F2 = self.conv23(f2)
706 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1))
707 | F1 = self.conv13(f1)
708 |
709 | if self.size == 1024:
710 | return [F7, F6, F5, F4, F3, F2, F1]
711 | elif self.size == 512:
712 | return [F6, F5, F4, F3, F2, F1]
713 | else:
714 | return [F5, F4, F3, F2, F1]
715 |
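# Hedged illustration (not part of the original file): uv2target above converts a DensePose
# IUV map into a sampling grid via dp_uv_lookup_256.npy and then warps UV-space features onto
# the target pose with grid_sample. The sketch below shows only the final warping step with a
# ready-made grid; `_warp_with_grid_demo` and the tensor shapes are assumptions for illustration.
def _warp_with_grid_demo():
    import torch
    import torch.nn.functional as F

    tex = torch.randn(1, 3, 256, 256)          # UV-space texture / appearance features
    grid = torch.rand(1, 128, 128, 2) * 2 - 1  # per-pixel UV coordinates in [-1, 1], (N, H, W, 2)
    # each output pixel samples tex at the location stored in the grid (bilinear by default)
    return F.grid_sample(tex, grid, align_corners=False)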
716 |
717 | class Generator(nn.Module):
718 | def __init__(
719 | self,
720 | size,
721 | style_dim,
722 | n_mlp,
723 | channel_multiplier=2,
724 | blur_kernel=[1, 3, 3, 1],
725 | lr_mlp=0.01,
726 | garment_transfer=False,
727 | part='upper_body',
728 | ):
729 | super().__init__()
730 |
731 | self.garment_transfer = garment_transfer
732 | self.size = size
733 | self.style_dim = style_dim
734 |
735 | if self.garment_transfer:
736 | self.appearance_encoder = GarmentTransferSpatialAppearanceEncoder(size=size, part=part)
737 | else:
738 | self.appearance_encoder = SpatialAppearanceEncoder(size=size)
739 | self.pose_encoder = PoseEncoder(size=size)
740 |
741 | # StyleGAN
742 | self.channels = {
743 | 16: 512,
744 | 32: 512,
745 | 64: 256 * channel_multiplier,
746 | 128: 128 * channel_multiplier,
747 | 256: 64 * channel_multiplier,
748 | 512: 32 * channel_multiplier,
749 | 1024: 16 * channel_multiplier,
750 | }
751 |
752 | self.conv1 = StyledConv(
753 | self.channels[16], self.channels[16], 3, style_dim, blur_kernel=blur_kernel, spatial=True
754 | )
755 | self.to_rgb1 = ToRGB(self.channels[16], style_dim, upsample=False, spatial=True)
756 |
757 | self.log_size = int(math.log(size, 2))
758 | self.num_layers = (self.log_size - 4) * 2 + 1
759 |
760 | self.convs = nn.ModuleList()
761 | self.upsamples = nn.ModuleList()
762 | self.to_rgbs = nn.ModuleList()
763 | self.noises = nn.Module()
764 |
765 | in_channel = self.channels[16]
766 |
767 | for i in range(5, self.log_size + 1):
768 | out_channel = self.channels[2 ** i]
769 |
770 | self.convs.append(
771 | StyledConv(
772 | in_channel,
773 | out_channel,
774 | 3,
775 | style_dim,
776 | upsample=True,
777 | blur_kernel=blur_kernel,
778 | spatial=True,
779 | )
780 | )
781 |
782 | self.convs.append(
783 | StyledConv(
784 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel, spatial=True,
785 | )
786 | )
787 |
788 | self.to_rgbs.append(ToRGB(out_channel, style_dim, spatial=True))
789 |
790 | in_channel = out_channel
791 |
792 | self.n_latent = self.log_size * 2 - 2
793 |
794 |
795 | def make_noise(self):
796 | device = self.input.input.device
797 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
798 | for i in range(3, self.log_size + 1):
799 | for _ in range(2):
800 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
801 | return noises
802 |
803 | def mean_latent(self, n_latent):
804 | latent_in = torch.randn(
805 | n_latent, self.style_dim, device=self.input.input.device
806 | )
807 | latent = self.style(latent_in).mean(0, keepdim=True)
808 | return latent
809 |
810 | def get_latent(self, input):
811 | return self.style(input)
812 |
813 | def forward(
814 | self,
815 | pose,
816 | appearance,
817 | styles=None,
818 | return_latents=False,
819 | inject_index=None,
820 | truncation=1,
821 | truncation_latent=None,
822 | input_is_latent=False,
823 | noise=None,
824 | randomize_noise=True,
825 | ):
826 |
827 | if self.garment_transfer:
828 | styles, part_mask = self.appearance_encoder(appearance, pose)
829 | else:
830 | styles = self.appearance_encoder(appearance, pose)
831 |
832 | if noise is None:
833 | if randomize_noise:
834 | noise = [None] * self.num_layers
835 | else:
836 | noise = [
837 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
838 | ]
839 |
840 | latent = [styles[0], styles[0]]
841 | if self.size == 1024:
842 | length = 6
843 | elif self.size == 512:
844 | length = 5
845 | else:
846 | length = 4
847 | for i in range(length):
848 | latent += [styles[i+1],styles[i+1]]
849 |
850 | out = self.pose_encoder(pose)
851 | out = self.conv1(out, latent[0], noise=noise[0])
852 | skip = self.to_rgb1(out, latent[1])
853 |
854 | i = 1
855 | for conv1, conv2, noise1, noise2, to_rgb in zip(
856 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs#, self.to_silhouettes
857 | ):
858 | out = conv1(out, latent[i], noise=noise1)
859 | out = conv2(out, latent[i + 1], noise=noise2)
860 | skip = to_rgb(out, latent[i + 2], skip)
861 | i += 2
862 |
863 | image = skip
864 |
865 | if self.garment_transfer:
866 | return image, part_mask
867 | else:
868 | if return_latents:
869 | return image, latent
870 | else:
871 | return image, None
872 |
873 |
874 | class ConvLayer(nn.Sequential):
875 | def __init__(
876 | self,
877 | in_channel,
878 | out_channel,
879 | kernel_size,
880 | downsample=False,
881 | blur_kernel=[1, 3, 3, 1],
882 | bias=True,
883 | activate=True,
884 | ):
885 | layers = []
886 |
887 | if downsample:
888 | factor = 2
889 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
890 | pad0 = (p + 1) // 2
891 | pad1 = p // 2
892 |
893 | layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
894 |
895 | stride = 2
896 | self.padding = 0
897 |
898 | else:
899 | stride = 1
900 | self.padding = kernel_size // 2
901 |
902 | layers.append(
903 | EqualConv2d(
904 | in_channel,
905 | out_channel,
906 | kernel_size,
907 | padding=self.padding,
908 | stride=stride,
909 | bias=bias and not activate,
910 | )
911 | )
912 |
913 | if activate:
914 | layers.append(FusedLeakyReLU(out_channel, bias=bias))
915 |
916 | super().__init__(*layers)
917 |
918 |
919 | class ResBlock(nn.Module):
920 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
921 | super().__init__()
922 |
923 | self.conv1 = ConvLayer(in_channel, in_channel, 3)
924 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
925 |
926 | self.skip = ConvLayer(
927 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False
928 | )
929 |
930 | def forward(self, input):
931 | out = self.conv1(input)
932 | out = self.conv2(out)
933 |
934 | skip = self.skip(input)
935 | out = (out + skip) / math.sqrt(2)
936 |
937 | return out
938 |
939 |
940 | class Discriminator(nn.Module):
941 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
942 | super().__init__()
943 |
944 | channels = {
945 | 4: 512,
946 | 8: 512,
947 | 16: 512,
948 | 32: 512,
949 | 64: 256 * channel_multiplier,
950 | 128: 128 * channel_multiplier,
951 | 256: 64 * channel_multiplier,
952 | 512: 32 * channel_multiplier,
953 | 1024: 16 * channel_multiplier,
954 | }
955 |
956 | convs = [ConvLayer(6, channels[size], 1)]
957 |
958 | log_size = int(math.log(size, 2))
959 |
960 | in_channel = channels[size]
961 |
962 | for i in range(log_size, 2, -1):
963 | out_channel = channels[2 ** (i - 1)]
964 |
965 | convs.append(ResBlock(in_channel, out_channel, blur_kernel))
966 |
967 | in_channel = out_channel
968 |
969 | self.convs = nn.Sequential(*convs)
970 |
971 | self.stddev_group = 4
972 | self.stddev_feat = 1
973 |
974 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
975 | self.final_linear = nn.Sequential(
976 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
977 | EqualLinear(channels[4], 1),
978 | )
979 |
980 | def forward(self, input, pose):
981 | input = torch.cat([input, pose], 1)
982 | out = self.convs(input)
983 |
984 | batch, channel, height, width = out.shape
985 | group = min(batch, self.stddev_group)
986 | stddev = out.view(
987 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
988 | )
989 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
990 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
991 | stddev = stddev.repeat(group, 1, height, width)
992 | out = torch.cat([out, stddev], 1)
993 |
994 | out = self.final_conv(out)
995 |
996 | out = out.view(batch, -1)
997 | out = self.final_linear(out)
998 |
999 | return out
1000 |
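# Hedged illustration (not part of the original file): before its final conv the discriminator
# appends a minibatch standard-deviation channel (the stddev_group/stddev_feat logic above).
# The compact sketch below computes the same statistic for stddev_feat == 1 and assumes the
# batch size is divisible by the group size; `_minibatch_stddev_demo` is an illustrative name.
def _minibatch_stddev_demo(feat, group_size=4, eps=1e-8):
    import torch

    n, c, h, w = feat.shape
    g = min(n, group_size)
    y = feat.view(g, -1, 1, c, h, w)                     # split the batch into groups
    y = torch.sqrt(y.var(dim=0, unbiased=False) + eps)   # stddev over each group
    y = y.mean(dim=[2, 3, 4], keepdim=True)              # average over features and space
    y = y.squeeze(2).repeat(g, 1, h, w)                  # one constant map per sample
    return torch.cat([feat, y], dim=1)                   # append as an extra channel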
1001 |
1002 | class VGGLoss(nn.Module):
1003 | def __init__(self, device):
1004 | super(VGGLoss, self).__init__()
1005 | self.vgg = Vgg19().to(device)
1006 | self.criterion = nn.L1Loss()
1007 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
1008 |
1009 | def forward(self, x, y):
1010 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
1011 | loss = 0
1012 | for i in range(len(x_vgg)):
1013 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
1014 | return loss
1015 |
1016 | from torchvision import models
1017 | class Vgg19(torch.nn.Module):
1018 | def __init__(self, requires_grad=False):
1019 | super(Vgg19, self).__init__()
1020 | vgg_pretrained_features = models.vgg19(pretrained=True).features
1021 | self.slice1 = torch.nn.Sequential()
1022 | self.slice2 = torch.nn.Sequential()
1023 | self.slice3 = torch.nn.Sequential()
1024 | self.slice4 = torch.nn.Sequential()
1025 | self.slice5 = torch.nn.Sequential()
1026 | for x in range(2):
1027 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
1028 | for x in range(2, 7):
1029 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
1030 | for x in range(7, 12):
1031 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
1032 | for x in range(12, 21):
1033 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
1034 | for x in range(21, 30):
1035 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
1036 | if not requires_grad:
1037 | for param in self.parameters():
1038 | param.requires_grad = False
1039 |
1040 | def forward(self, X):
1041 | h_relu1 = self.slice1(X)
1042 | h_relu2 = self.slice2(h_relu1)
1043 | h_relu3 = self.slice3(h_relu2)
1044 | h_relu4 = self.slice4(h_relu3)
1045 | h_relu5 = self.slice5(h_relu4)
1046 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
1047 | return out
1048 |
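# Hedged illustration (not part of the original file): VGGLoss compares two images at five
# VGG19 feature depths with fixed weights (1/32 ... 1). Minimal usage sketch below; the shapes
# and CPU device are assumptions, and constructing VGGLoss loads torchvision's pretrained VGG19.
def _vgg_loss_demo():
    import torch

    criterion = VGGLoss(torch.device("cpu"))
    fake = torch.randn(1, 3, 256, 256)
    real = torch.randn(1, 3, 256, 256)
    return criterion(fake, real)   # weighted sum of L1 distances between feature maps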
1049 |
1050 | ## Garment transfer
1051 | class GarmentTransferSpatialAppearanceEncoder(nn.Module):
1052 | def __init__(self, ngf=64, blur_kernel=[1, 3, 3, 1], size=256, part='upper_body'):
1053 | super().__init__()
1054 | self.size = size
1055 | self.part = part
1056 | self.dp_uv_lookup_256_np = np.load('util/dp_uv_lookup_256.npy')
1057 | self.uv_parts = torch.from_numpy(np.load('util/uv_space_parts.npy')).unsqueeze(0).unsqueeze(0)
1058 |
1059 | input_nc = 4 # source RGB and sil
1060 |
1061 | self.conv1 = ConvLayer(input_nc, ngf, 1) # ngf 256 256
1062 | self.conv2 = ResBlock(ngf, ngf*2, blur_kernel) # 2ngf 128 128
1063 | self.conv3 = ResBlock(ngf*2, ngf*4, blur_kernel) # 4ngf 64 64
1064 | self.conv4 = ResBlock(ngf*4, ngf*8, blur_kernel) # 8ngf 32 32
1065 | self.conv5 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16
1066 | if self.size == 512:
1067 | self.conv6 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16 - starting from ngf 512 512
1068 | if self.size == 1024:
1069 | self.conv6 = ResBlock(ngf*8, ngf*8, blur_kernel)
1070 |             self.conv7 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16 - starting from ngf 1024 1024
1071 |
1072 | self.conv11 = EqualConv2d(ngf+1, ngf*8, 1)
1073 | self.conv21 = EqualConv2d(ngf*2+1, ngf*8, 1)
1074 | self.conv31 = EqualConv2d(ngf*4+1, ngf*8, 1)
1075 | self.conv41 = EqualConv2d(ngf*8+1, ngf*8, 1)
1076 | self.conv51 = EqualConv2d(ngf*8+1, ngf*8, 1)
1077 | if self.size == 512:
1078 | self.conv61 = EqualConv2d(ngf*8+1, ngf*8, 1)
1079 | if self.size == 1024:
1080 | self.conv61 = EqualConv2d(ngf*8+1, ngf*8, 1)
1081 | self.conv71 = EqualConv2d(ngf*8+1, ngf*8, 1)
1082 |
1083 | if self.size == 1024:
1084 | self.conv13 = EqualConv2d(ngf*8, int(ngf/2), 3, padding=1)
1085 | self.conv23 = EqualConv2d(ngf*8, ngf*1, 3, padding=1)
1086 | self.conv33 = EqualConv2d(ngf*8, ngf*2, 3, padding=1)
1087 | self.conv43 = EqualConv2d(ngf*8, ngf*4, 3, padding=1)
1088 | self.conv53 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
1089 | self.conv63 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
1090 | elif self.size == 512:
1091 | self.conv13 = EqualConv2d(ngf*8, ngf*1, 3, padding=1)
1092 | self.conv23 = EqualConv2d(ngf*8, ngf*2, 3, padding=1)
1093 | self.conv33 = EqualConv2d(ngf*8, ngf*4, 3, padding=1)
1094 | self.conv43 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
1095 | self.conv53 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
1096 | else:
1097 | self.conv13 = EqualConv2d(ngf*8, ngf*2, 3, padding=1)
1098 | self.conv23 = EqualConv2d(ngf*8, ngf*4, 3, padding=1)
1099 | self.conv33 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
1100 | self.conv43 = EqualConv2d(ngf*8, ngf*8, 3, padding=1)
1101 |
1102 | self.up = nn.Upsample(scale_factor=2)
1103 |
1104 | def uv2target(self, dp, tex, tex_mask=None):
1105 | # uv2img
1106 | b, _, h, w = dp.shape
1107 | ch = tex.shape[1]
1108 | UV_MAPs = []
1109 | for idx in range(b):
1110 | iuv = dp[idx]
1111 | point_pos = iuv[0, :, :] > 0
1112 | iuv_raw_i = iuv[0][point_pos]
1113 | iuv_raw_u = iuv[1][point_pos]
1114 | iuv_raw_v = iuv[2][point_pos]
1115 | iuv_raw = torch.stack([iuv_raw_i,iuv_raw_u,iuv_raw_v], 0)
1116 | i = iuv_raw[0, :] - 1 ## dp_uv_lookup_256_np does not contain BG class
1117 | u = iuv_raw[1, :]
1118 | v = iuv_raw[2, :]
1119 | i = i.cpu().numpy()
1120 | u = u.cpu().numpy()
1121 | v = v.cpu().numpy()
1122 | uv_smpl = self.dp_uv_lookup_256_np[i, v, u]
1123 | uv_map = torch.zeros((2,h,w)).to(tex).float()
1124 |             ## normalize [0,1] to [-1,1] for PyTorch's grid_sample
1125 | u_map = uv_smpl[:, 0] * 2 - 1
1126 | v_map = (1 - uv_smpl[:, 1]) * 2 - 1
1127 | uv_map[0][point_pos] = torch.from_numpy(u_map).to(tex).float()
1128 | uv_map[1][point_pos] = torch.from_numpy(v_map).to(tex).float()
1129 | UV_MAPs.append(uv_map)
1130 | uv_map = torch.stack(UV_MAPs, 0)
1131 | # warping
1132 | # before warping validate sizes
1133 | _, _, h_x, w_x = tex.shape
1134 | _, _, h_t, w_t = uv_map.shape
1135 | if h_t != h_x or w_t != w_x:
1136 | #https://github.com/iPERDance/iPERCore/blob/4a010f781a4fb90dd29a516472e4aadf41ed1609/iPERCore/models/networks/generators/lwb_avg_resunet.py#L55
1137 | uv_map = torch.nn.functional.interpolate(uv_map, size=(h_x, w_x), mode='bilinear', align_corners=True)
1138 | uv_map = uv_map.permute(0, 2, 3, 1)
1139 | warped_image = torch.nn.functional.grid_sample(tex.float(), uv_map.float())
1140 | if tex_mask is not None:
1141 | warped_mask = torch.nn.functional.grid_sample(tex_mask.float(), uv_map.float())
1142 | final_warped = warped_image * warped_mask
1143 | return final_warped, warped_mask
1144 | else:
1145 | return warped_image
1146 |
1147 | def forward(self, input, pose):
1148 | coor = input[:,4:6,:,:]
1149 | target_coor = input[:,10:,:,:]
1150 | target_input = input[:,6:10,:,:]
1151 | input = input[:,:4,:,:]
1152 |
1153 | # input
1154 | x1 = self.conv1(input)
1155 | x2 = self.conv2(x1)
1156 | x3 = self.conv3(x2)
1157 | x4 = self.conv4(x3)
1158 | x5 = self.conv5(x4)
1159 | if self.size == 512:
1160 | x6 = self.conv6(x5)
1161 | if self.size == 1024:
1162 | x6 = self.conv6(x5)
1163 | x7 = self.conv7(x6)
1164 | # warp- get flow
1165 | pose_mask = 1-(pose[:,0, :, :] == 0).float().unsqueeze(1)
1166 | flow = self.uv2target(pose.int(), coor)
1167 | # warp- resize flow
1168 | f1 = torch.nn.functional.interpolate(flow, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True)
1169 | f2 = torch.nn.functional.interpolate(flow, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True)
1170 | f3 = torch.nn.functional.interpolate(flow, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True)
1171 | f4 = torch.nn.functional.interpolate(flow, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True)
1172 | f5 = torch.nn.functional.interpolate(flow, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True)
1173 | if self.size == 512:
1174 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
1175 | if self.size == 1024:
1176 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
1177 | f7 = torch.nn.functional.interpolate(flow, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True)
1178 | # warp- now warp
1179 | x1 = torch.nn.functional.grid_sample(x1, f1.permute(0,2,3,1))
1180 | x2 = torch.nn.functional.grid_sample(x2, f2.permute(0,2,3,1))
1181 | x3 = torch.nn.functional.grid_sample(x3, f3.permute(0,2,3,1))
1182 | x4 = torch.nn.functional.grid_sample(x4, f4.permute(0,2,3,1))
1183 | x5 = torch.nn.functional.grid_sample(x5, f5.permute(0,2,3,1))
1184 | if self.size == 512:
1185 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1))
1186 | if self.size == 1024:
1187 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1))
1188 | x7 = torch.nn.functional.grid_sample(x7, f7.permute(0,2,3,1))
1189 | # mask features
1190 | p1 = torch.nn.functional.interpolate(pose_mask, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True)
1191 | p2 = torch.nn.functional.interpolate(pose_mask, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True)
1192 | p3 = torch.nn.functional.interpolate(pose_mask, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True)
1193 | p4 = torch.nn.functional.interpolate(pose_mask, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True)
1194 | p5 = torch.nn.functional.interpolate(pose_mask, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True)
1195 | if self.size == 512:
1196 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
1197 | if self.size == 1024:
1198 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
1199 | p7 = torch.nn.functional.interpolate(pose_mask, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True)
1200 | x1 = x1 * p1
1201 | x2 = x2 * p2
1202 | x3 = x3 * p3
1203 | x4 = x4 * p4
1204 | x5 = x5 * p5
1205 | if self.size == 512:
1206 | x6 = x6 * p6
1207 | if self.size == 1024:
1208 | x6 = x6 * p6
1209 | x7 = x7 * p7
1210 |
1211 | input_x1 = x1
1212 | input_x2 = x2
1213 | input_x3 = x3
1214 | input_x4 = x4
1215 | input_x5 = x5
1216 | if self.size == 512:
1217 | input_x6 = x6
1218 | if self.size == 1024:
1219 | input_x6 = x6
1220 | input_x7 = x7
1221 |
1222 | # target
1223 | x1 = self.conv1(target_input)
1224 | x2 = self.conv2(x1)
1225 | x3 = self.conv3(x2)
1226 | x4 = self.conv4(x3)
1227 | x5 = self.conv5(x4)
1228 | if self.size == 512:
1229 | x6 = self.conv6(x5)
1230 | if self.size == 1024:
1231 | x6 = self.conv6(x5)
1232 | x7 = self.conv7(x6)
1233 | # warp- get flow
1234 | pose_mask = 1-(pose[:,0, :, :] == 0).float().unsqueeze(1)
1235 | flow = self.uv2target(pose.int(), target_coor)
1236 | # warp- resize flow
1237 | f1 = torch.nn.functional.interpolate(flow, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True)
1238 | f2 = torch.nn.functional.interpolate(flow, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True)
1239 | f3 = torch.nn.functional.interpolate(flow, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True)
1240 | f4 = torch.nn.functional.interpolate(flow, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True)
1241 | f5 = torch.nn.functional.interpolate(flow, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True)
1242 | if self.size == 512:
1243 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
1244 | if self.size == 1024:
1245 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
1246 | f7 = torch.nn.functional.interpolate(flow, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True)
1247 | # warp- now warp
1248 | x1 = torch.nn.functional.grid_sample(x1, f1.permute(0,2,3,1))
1249 | x2 = torch.nn.functional.grid_sample(x2, f2.permute(0,2,3,1))
1250 | x3 = torch.nn.functional.grid_sample(x3, f3.permute(0,2,3,1))
1251 | x4 = torch.nn.functional.grid_sample(x4, f4.permute(0,2,3,1))
1252 | x5 = torch.nn.functional.grid_sample(x5, f5.permute(0,2,3,1))
1253 | if self.size == 512:
1254 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1))
1255 | if self.size == 1024:
1256 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1))
1257 | x7 = torch.nn.functional.grid_sample(x7, f7.permute(0,2,3,1))
1258 | # mask features
1259 | p1 = torch.nn.functional.interpolate(pose_mask, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True)
1260 | p2 = torch.nn.functional.interpolate(pose_mask, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True)
1261 | p3 = torch.nn.functional.interpolate(pose_mask, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True)
1262 | p4 = torch.nn.functional.interpolate(pose_mask, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True)
1263 | p5 = torch.nn.functional.interpolate(pose_mask, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True)
1264 | if self.size == 512:
1265 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
1266 | if self.size == 1024:
1267 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
1268 | p7 = torch.nn.functional.interpolate(pose_mask, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True)
1269 | x1 = x1 * p1
1270 | x2 = x2 * p2
1271 | x3 = x3 * p3
1272 | x4 = x4 * p4
1273 | x5 = x5 * p5
1274 | if self.size == 512:
1275 | x6 = x6 * p6
1276 | if self.size == 1024:
1277 | x6 = x6 * p6
1278 | x7 = x7 * p7
1279 |
1280 | target_x1 = x1
1281 | target_x2 = x2
1282 | target_x3 = x3
1283 | target_x4 = x4
1284 | target_x5 = x5
1285 | if self.size == 512:
1286 | target_x6 = x6
1287 | if self.size == 1024:
1288 | target_x6 = x6
1289 | target_x7 = x7
1290 |
1291 | # body part to transfer
1292 | target_parts = torch.round(self.uv2target(pose.int(), self.uv_parts.to(pose))).int()
1293 | if self.part == 'face':
1294 | face_mask_1 = target_parts==23-1
1295 | face_mask_2 = target_parts==24-1
1296 | face_mask = (face_mask_1+face_mask_2).float()
1297 | part_mask = face_mask
1298 | elif self.part == 'lower_body':
1299 | # LOWER CLOTHES: 7,9=UpperLegRight 8,10=UpperLegLeft 11,13=LowerLegRight 12,14=LowerLegLeft
1300 | lower_mask_1 = target_parts==7-1
1301 | lower_mask_2 = target_parts==9-1
1302 | lower_mask_3 = target_parts==8-1
1303 | lower_mask_4 = target_parts==10-1
1304 | lower_mask_5 = target_parts==11-1
1305 | lower_mask_6 = target_parts==13-1
1306 | lower_mask_7 = target_parts==12-1
1307 | lower_mask_8 = target_parts==14-1
1308 | lower_mask = (lower_mask_1+lower_mask_2+lower_mask_3+lower_mask_4+lower_mask_5+lower_mask_6+lower_mask_7+lower_mask_8).float()
1309 | part_mask = lower_mask
1310 | elif self.part == 'upper_body':
1311 | # UPPER CLOTHES: 1,2=Torso 15,17=UpperArmLeft 16,18=UpperArmRight 19,21=LowerArmLeft 20,22=LowerArmRight
1312 | upper_mask_1 = target_parts==1-1
1313 | upper_mask_2 = target_parts==2-1
1314 | upper_mask_3 = target_parts==15-1
1315 | upper_mask_4 = target_parts==17-1
1316 | upper_mask_5 = target_parts==16-1
1317 | upper_mask_6 = target_parts==18-1
1318 | upper_mask_7 = target_parts==19-1
1319 | upper_mask_8 = target_parts==21-1
1320 | upper_mask_9 = target_parts==20-1
1321 | upper_mask_10 = target_parts==22-1
1322 | upper_mask = (upper_mask_1+upper_mask_2+upper_mask_3+upper_mask_4+upper_mask_5+upper_mask_6+upper_mask_7+upper_mask_8+upper_mask_9+upper_mask_10).float()
1323 | part_mask = upper_mask
1324 | else: # full body
1325 | upper_mask_1 = target_parts==1-1
1326 | upper_mask_2 = target_parts==2-1
1327 | upper_mask_3 = target_parts==15-1
1328 | upper_mask_4 = target_parts==17-1
1329 | upper_mask_5 = target_parts==16-1
1330 | upper_mask_6 = target_parts==18-1
1331 | upper_mask_7 = target_parts==19-1
1332 | upper_mask_8 = target_parts==21-1
1333 | upper_mask_9 = target_parts==20-1
1334 | upper_mask_10 = target_parts==22-1
1335 | upper_mask = (upper_mask_1+upper_mask_2+upper_mask_3+upper_mask_4+upper_mask_5+upper_mask_6+upper_mask_7+upper_mask_8+upper_mask_9+upper_mask_10).float()
1336 | lower_mask_1 = target_parts==7-1
1337 | lower_mask_2 = target_parts==9-1
1338 | lower_mask_3 = target_parts==8-1
1339 | lower_mask_4 = target_parts==10-1
1340 | lower_mask_5 = target_parts==11-1
1341 | lower_mask_6 = target_parts==13-1
1342 | lower_mask_7 = target_parts==12-1
1343 | lower_mask_8 = target_parts==14-1
1344 | lower_mask = (lower_mask_1+lower_mask_2+lower_mask_3+lower_mask_4+lower_mask_5+lower_mask_6+lower_mask_7+lower_mask_8).float()
1345 | part_mask = upper_mask + lower_mask
1346 |
1347 | part_mask1 = torch.nn.functional.interpolate(part_mask, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True)
1348 | part_mask2 = torch.nn.functional.interpolate(part_mask, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True)
1349 | part_mask3 = torch.nn.functional.interpolate(part_mask, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True)
1350 | part_mask4 = torch.nn.functional.interpolate(part_mask, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True)
1351 | part_mask5 = torch.nn.functional.interpolate(part_mask, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True)
1352 | if self.size == 512:
1353 | part_mask6 = torch.nn.functional.interpolate(part_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
1354 | if self.size == 1024:
1355 | part_mask6 = torch.nn.functional.interpolate(part_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True)
1356 | part_mask7 = torch.nn.functional.interpolate(part_mask, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True)
1357 |
1358 | x1 = target_x1*part_mask1 + input_x1*(1-part_mask1)
1359 | x2 = target_x2*part_mask2 + input_x2*(1-part_mask2)
1360 | x3 = target_x3*part_mask3 + input_x3*(1-part_mask3)
1361 | x4 = target_x4*part_mask4 + input_x4*(1-part_mask4)
1362 | x5 = target_x5*part_mask5 + input_x5*(1-part_mask5)
1363 | if self.size == 512:
1364 | x6 = target_x6*part_mask6 + input_x6*(1-part_mask6)
1365 | if self.size == 1024:
1366 | x6 = target_x6*part_mask6 + input_x6*(1-part_mask6)
1367 | x7 = target_x7*part_mask7 + input_x7*(1-part_mask7)
1368 |
1369 | # fpn
1370 | if self.size == 1024:
1371 | F7 = self.conv71(torch.cat([x7,p7], 1))
1372 | f6 = self.up(F7)+self.conv61(torch.cat([x6,p6], 1))
1373 | F6 = self.conv63(f6)
1374 | f5 = self.up(F6)+self.conv51(torch.cat([x5,p5], 1))
1375 | F5 = self.conv53(f5)
1376 | f4 = self.up(f5)+self.conv41(torch.cat([x4,p4], 1))
1377 | F4 = self.conv43(f4)
1378 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1))
1379 | F3 = self.conv33(f3)
1380 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1))
1381 | F2 = self.conv23(f2)
1382 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1))
1383 | F1 = self.conv13(f1)
1384 | elif self.size == 512:
1385 | F6 = self.conv61(torch.cat([x6,p6], 1))
1386 | f5 = self.up(F6)+self.conv51(torch.cat([x5,p5], 1))
1387 | F5 = self.conv53(f5)
1388 | f4 = self.up(f5)+self.conv41(torch.cat([x4,p4], 1))
1389 | F4 = self.conv43(f4)
1390 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1))
1391 | F3 = self.conv33(f3)
1392 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1))
1393 | F2 = self.conv23(f2)
1394 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1))
1395 | F1 = self.conv13(f1)
1396 | else:
1397 | F5 = self.conv51(torch.cat([x5,p5], 1))
1398 | f4 = self.up(F5)+self.conv41(torch.cat([x4,p4], 1))
1399 | F4 = self.conv43(f4)
1400 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1))
1401 | F3 = self.conv33(f3)
1402 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1))
1403 | F2 = self.conv23(f2)
1404 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1))
1405 | F1 = self.conv13(f1)
1406 |
1407 | if self.size == 1024:
1408 | return [F7, F6, F5, F4, F3, F2, F1], part_mask
1409 | elif self.size == 512:
1410 | return [F6, F5, F4, F3, F2, F1], part_mask
1411 | else:
1412 | return [F5, F4, F3, F2, F1], part_mask
1413 |
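# Hedged illustration (not part of the original file): the garment-transfer encoder selects
# which warped features come from the target person by building a binary DensePose part mask
# (torso/arms for 'upper_body', legs for 'lower_body', head for 'face'). A compact version of
# that mask construction is sketched below; `_part_mask_demo` and the 'upper_body' index list
# (taken from the comments above, shifted by -1 as in the code) are for illustration only.
def _part_mask_demo(target_parts):
    import torch

    mask = torch.zeros_like(target_parts, dtype=torch.float32)
    for part_id in [1, 2, 15, 16, 17, 18, 19, 20, 21, 22]:   # upper-clothes DensePose parts
        mask = mask + (target_parts == part_id - 1).float()
    return mask.clamp(0, 1)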
--------------------------------------------------------------------------------
/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/op/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import warnings
3 |
4 | import torch
5 | from torch import autograd
6 | from torch.nn import functional as F
7 |
8 | enabled = True
9 | weight_gradients_disabled = False
10 |
11 |
12 | @contextlib.contextmanager
13 | def no_weight_gradients():
14 | global weight_gradients_disabled
15 |
16 | old = weight_gradients_disabled
17 | weight_gradients_disabled = True
18 | yield
19 | weight_gradients_disabled = old
20 |
21 |
22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23 | if could_use_op(input):
24 | return conv2d_gradfix(
25 | transpose=False,
26 | weight_shape=weight.shape,
27 | stride=stride,
28 | padding=padding,
29 | output_padding=0,
30 | dilation=dilation,
31 | groups=groups,
32 | ).apply(input, weight, bias)
33 |
34 | return F.conv2d(
35 | input=input,
36 | weight=weight,
37 | bias=bias,
38 | stride=stride,
39 | padding=padding,
40 | dilation=dilation,
41 | groups=groups,
42 | )
43 |
44 |
45 | def conv_transpose2d(
46 | input,
47 | weight,
48 | bias=None,
49 | stride=1,
50 | padding=0,
51 | output_padding=0,
52 | groups=1,
53 | dilation=1,
54 | ):
55 | if could_use_op(input):
56 | return conv2d_gradfix(
57 | transpose=True,
58 | weight_shape=weight.shape,
59 | stride=stride,
60 | padding=padding,
61 | output_padding=output_padding,
62 | groups=groups,
63 | dilation=dilation,
64 | ).apply(input, weight, bias)
65 |
66 | return F.conv_transpose2d(
67 | input=input,
68 | weight=weight,
69 | bias=bias,
70 | stride=stride,
71 | padding=padding,
72 | output_padding=output_padding,
73 | dilation=dilation,
74 | groups=groups,
75 | )
76 |
77 |
78 | def could_use_op(input):
79 | if (not enabled) or (not torch.backends.cudnn.enabled):
80 | return False
81 |
82 | if input.device.type != "cuda":
83 | return False
84 |
85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86 | return True
87 |
88 | warnings.warn(
89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90 | )
91 |
92 | return False
93 |
94 |
95 | def ensure_tuple(xs, ndim):
96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97 |
98 | return xs
99 |
100 |
101 | conv2d_gradfix_cache = dict()
102 |
103 |
104 | def conv2d_gradfix(
105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups
106 | ):
107 | ndim = 2
108 | weight_shape = tuple(weight_shape)
109 | stride = ensure_tuple(stride, ndim)
110 | padding = ensure_tuple(padding, ndim)
111 | output_padding = ensure_tuple(output_padding, ndim)
112 | dilation = ensure_tuple(dilation, ndim)
113 |
114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115 | if key in conv2d_gradfix_cache:
116 | return conv2d_gradfix_cache[key]
117 |
118 | common_kwargs = dict(
119 | stride=stride, padding=padding, dilation=dilation, groups=groups
120 | )
121 |
122 | def calc_output_padding(input_shape, output_shape):
123 | if transpose:
124 | return [0, 0]
125 |
126 | return [
127 | input_shape[i + 2]
128 | - (output_shape[i + 2] - 1) * stride[i]
129 | - (1 - 2 * padding[i])
130 | - dilation[i] * (weight_shape[i + 2] - 1)
131 | for i in range(ndim)
132 | ]
133 |
134 | class Conv2d(autograd.Function):
135 | @staticmethod
136 | def forward(ctx, input, weight, bias):
137 | if not transpose:
138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139 |
140 | else:
141 | out = F.conv_transpose2d(
142 | input=input,
143 | weight=weight,
144 | bias=bias,
145 | output_padding=output_padding,
146 | **common_kwargs,
147 | )
148 |
149 | ctx.save_for_backward(input, weight)
150 |
151 | return out
152 |
153 | @staticmethod
154 | def backward(ctx, grad_output):
155 | input, weight = ctx.saved_tensors
156 | grad_input, grad_weight, grad_bias = None, None, None
157 |
158 | if ctx.needs_input_grad[0]:
159 | p = calc_output_padding(
160 | input_shape=input.shape, output_shape=grad_output.shape
161 | )
162 | grad_input = conv2d_gradfix(
163 | transpose=(not transpose),
164 | weight_shape=weight_shape,
165 | output_padding=p,
166 | **common_kwargs,
167 | ).apply(grad_output, weight, None)
168 |
169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
171 |
172 | if ctx.needs_input_grad[2]:
173 | grad_bias = grad_output.sum((0, 2, 3))
174 |
175 | return grad_input, grad_weight, grad_bias
176 |
177 | class Conv2dGradWeight(autograd.Function):
178 | @staticmethod
179 | def forward(ctx, grad_output, input):
180 | op = torch._C._jit_get_operation(
181 | "aten::cudnn_convolution_backward_weight"
182 | if not transpose
183 | else "aten::cudnn_convolution_transpose_backward_weight"
184 | )
185 | flags = [
186 | torch.backends.cudnn.benchmark,
187 | torch.backends.cudnn.deterministic,
188 | torch.backends.cudnn.allow_tf32,
189 | ]
190 | grad_weight = op(
191 | weight_shape,
192 | grad_output,
193 | input,
194 | padding,
195 | stride,
196 | dilation,
197 | groups,
198 | *flags,
199 | )
200 | ctx.save_for_backward(grad_output, input)
201 |
202 | return grad_weight
203 |
204 | @staticmethod
205 | def backward(ctx, grad_grad_weight):
206 | grad_output, input = ctx.saved_tensors
207 | grad_grad_output, grad_grad_input = None, None
208 |
209 | if ctx.needs_input_grad[0]:
210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211 |
212 | if ctx.needs_input_grad[1]:
213 | p = calc_output_padding(
214 | input_shape=input.shape, output_shape=grad_output.shape
215 | )
216 | grad_grad_input = conv2d_gradfix(
217 | transpose=(not transpose),
218 | weight_shape=weight_shape,
219 | output_padding=p,
220 | **common_kwargs,
221 | ).apply(grad_output, grad_grad_weight, None)
222 |
223 | return grad_grad_output, grad_grad_input
224 |
225 | conv2d_gradfix_cache[key] = Conv2d
226 |
227 | return Conv2d
228 |
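# Hedged illustration (not part of the original file): no_weight_gradients() is meant to wrap
# higher-order gradient computations (e.g. StyleGAN2-style path-length regularization) so the
# custom op skips weight gradients while still propagating gradients to the input. Sketch only;
# the shapes are assumptions, and on CPU the plain F.conv2d fallback is used.
def _no_weight_gradients_demo(weight):
    x = torch.randn(1, weight.shape[1], 8, 8, requires_grad=True)
    y = conv2d(x, weight, padding=weight.shape[2] // 2)
    with no_weight_gradients():
        (grad_x,) = autograd.grad(y.sum(), x)
    return grad_x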
--------------------------------------------------------------------------------
/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | fused = load(
12 | "fused",
13 | sources=[
14 | os.path.join(module_path, "fused_bias_act.cpp"),
15 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class FusedLeakyReLUFunctionBackward(Function):
21 | @staticmethod
22 | def forward(ctx, grad_output, out, bias, negative_slope, scale):
23 | ctx.save_for_backward(out)
24 | ctx.negative_slope = negative_slope
25 | ctx.scale = scale
26 |
27 | empty = grad_output.new_empty(0)
28 |
29 | grad_input = fused.fused_bias_act(
30 | grad_output, empty, out, 3, 1, negative_slope, scale
31 | )
32 |
33 | dim = [0]
34 |
35 | if grad_input.ndim > 2:
36 | dim += list(range(2, grad_input.ndim))
37 |
38 | if bias:
39 | grad_bias = grad_input.sum(dim).detach()
40 |
41 | else:
42 | grad_bias = empty
43 |
44 | return grad_input, grad_bias
45 |
46 | @staticmethod
47 | def backward(ctx, gradgrad_input, gradgrad_bias):
48 | out, = ctx.saved_tensors
49 | gradgrad_out = fused.fused_bias_act(
50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
51 | )
52 |
53 | return gradgrad_out, None, None, None, None
54 |
55 |
56 | class FusedLeakyReLUFunction(Function):
57 | @staticmethod
58 | def forward(ctx, input, bias, negative_slope, scale):
59 | empty = input.new_empty(0)
60 |
61 | ctx.bias = bias is not None
62 |
63 | if bias is None:
64 | bias = empty
65 |
66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
67 | ctx.save_for_backward(out)
68 | ctx.negative_slope = negative_slope
69 | ctx.scale = scale
70 |
71 | return out
72 |
73 | @staticmethod
74 | def backward(ctx, grad_output):
75 | out, = ctx.saved_tensors
76 |
77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
79 | )
80 |
81 | if not ctx.bias:
82 | grad_bias = None
83 |
84 | return grad_input, grad_bias, None, None
85 |
86 |
87 | class FusedLeakyReLU(nn.Module):
88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
89 | super().__init__()
90 |
91 | if bias:
92 | self.bias = nn.Parameter(torch.zeros(channel))
93 |
94 | else:
95 | self.bias = None
96 |
97 | self.negative_slope = negative_slope
98 | self.scale = scale
99 |
100 | def forward(self, input):
101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
102 |
103 |
104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
105 | if input.device.type == "cpu":
106 | if bias is not None:
107 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
108 | return (
109 | F.leaky_relu(
110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
111 | )
112 | * scale
113 | )
114 |
115 | else:
116 | return F.leaky_relu(input, negative_slope=0.2) * scale
117 |
118 | else:
119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
120 |
--------------------------------------------------------------------------------
/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 | 
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 | 
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 | 
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 |             y.data_ptr<scalar_t>(),
82 |             x.data_ptr<scalar_t>(),
83 |             b.data_ptr<scalar_t>(),
84 |             ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | from collections import abc
2 | import os
3 |
4 | import torch
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | upfirdn2d_op = load(
12 | "upfirdn2d",
13 | sources=[
14 | os.path.join(module_path, "upfirdn2d.cpp"),
15 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class UpFirDn2dBackward(Function):
21 | @staticmethod
22 | def forward(
23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24 | ):
25 |
26 | up_x, up_y = up
27 | down_x, down_y = down
28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29 |
30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31 |
32 | grad_input = upfirdn2d_op.upfirdn2d(
33 | grad_output,
34 | grad_kernel,
35 | down_x,
36 | down_y,
37 | up_x,
38 | up_y,
39 | g_pad_x0,
40 | g_pad_x1,
41 | g_pad_y0,
42 | g_pad_y1,
43 | )
44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45 |
46 | ctx.save_for_backward(kernel)
47 |
48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
49 |
50 | ctx.up_x = up_x
51 | ctx.up_y = up_y
52 | ctx.down_x = down_x
53 | ctx.down_y = down_y
54 | ctx.pad_x0 = pad_x0
55 | ctx.pad_x1 = pad_x1
56 | ctx.pad_y0 = pad_y0
57 | ctx.pad_y1 = pad_y1
58 | ctx.in_size = in_size
59 | ctx.out_size = out_size
60 |
61 | return grad_input
62 |
63 | @staticmethod
64 | def backward(ctx, gradgrad_input):
65 | kernel, = ctx.saved_tensors
66 |
67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68 |
69 | gradgrad_out = upfirdn2d_op.upfirdn2d(
70 | gradgrad_input,
71 | kernel,
72 | ctx.up_x,
73 | ctx.up_y,
74 | ctx.down_x,
75 | ctx.down_y,
76 | ctx.pad_x0,
77 | ctx.pad_x1,
78 | ctx.pad_y0,
79 | ctx.pad_y1,
80 | )
81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82 | gradgrad_out = gradgrad_out.view(
83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84 | )
85 |
86 | return gradgrad_out, None, None, None, None, None, None, None, None
87 |
88 |
89 | class UpFirDn2d(Function):
90 | @staticmethod
91 | def forward(ctx, input, kernel, up, down, pad):
92 | up_x, up_y = up
93 | down_x, down_y = down
94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
95 |
96 | kernel_h, kernel_w = kernel.shape
97 | batch, channel, in_h, in_w = input.shape
98 | ctx.in_size = input.shape
99 |
100 | input = input.reshape(-1, in_h, in_w, 1)
101 |
102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103 |
104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106 | ctx.out_size = (out_h, out_w)
107 |
108 | ctx.up = (up_x, up_y)
109 | ctx.down = (down_x, down_y)
110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111 |
112 | g_pad_x0 = kernel_w - pad_x0 - 1
113 | g_pad_y0 = kernel_h - pad_y0 - 1
114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116 |
117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118 |
119 | out = upfirdn2d_op.upfirdn2d(
120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121 | )
122 | # out = out.view(major, out_h, out_w, minor)
123 | out = out.view(-1, channel, out_h, out_w)
124 |
125 | return out
126 |
127 | @staticmethod
128 | def backward(ctx, grad_output):
129 | kernel, grad_kernel = ctx.saved_tensors
130 |
131 | grad_input = None
132 |
133 | if ctx.needs_input_grad[0]:
134 | grad_input = UpFirDn2dBackward.apply(
135 | grad_output,
136 | kernel,
137 | grad_kernel,
138 | ctx.up,
139 | ctx.down,
140 | ctx.pad,
141 | ctx.g_pad,
142 | ctx.in_size,
143 | ctx.out_size,
144 | )
145 |
146 | return grad_input, None, None, None, None
147 |
148 |
149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150 | if not isinstance(up, abc.Iterable):
151 | up = (up, up)
152 |
153 | if not isinstance(down, abc.Iterable):
154 | down = (down, down)
155 |
156 | if len(pad) == 2:
157 | pad = (pad[0], pad[1], pad[0], pad[1])
158 |
159 | if input.device.type == "cpu":
160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161 |
162 | else:
163 | out = UpFirDn2d.apply(input, kernel, up, down, pad)
164 |
165 | return out
166 |
167 |
168 | def upfirdn2d_native(
169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170 | ):
171 | _, channel, in_h, in_w = input.shape
172 | input = input.reshape(-1, in_h, in_w, 1)
173 |
174 | _, in_h, in_w, minor = input.shape
175 | kernel_h, kernel_w = kernel.shape
176 |
177 | out = input.view(-1, in_h, 1, in_w, 1, minor)
178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180 |
181 | out = F.pad(
182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183 | )
184 | out = out[
185 | :,
186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188 | :,
189 | ]
190 |
191 | out = out.permute(0, 3, 1, 2)
192 | out = out.reshape(
193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194 | )
195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196 | out = F.conv2d(out, w)
197 | out = out.reshape(
198 | -1,
199 | minor,
200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202 | )
203 | out = out.permute(0, 2, 3, 1)
204 | out = out[:, ::down_y, ::down_x, :]
205 |
206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208 |
209 | return out.view(-1, channel, out_h, out_w)
210 |
--------------------------------------------------------------------------------
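The CPU fallback `upfirdn2d_native` above spells out what the CUDA kernel computes: zero-insertion upsampling, padding, a 2D FIR filter, then strided downsampling. A self-contained sketch of the same computation (it avoids the JIT-compiled extension; the [1, 3, 3, 1] binomial kernel and the up=2, pad=(2, 1) setting are illustrative assumptions, not values from this file):

```
import torch
import torch.nn.functional as F

k1d = torch.tensor([1., 3., 3., 1.])
kernel = torch.outer(k1d, k1d)
kernel = kernel / kernel.sum()

x = torch.randn(1, 3, 32, 32)
up, down, pad = 2, 1, (2, 1)
in_h, in_w = x.shape[2:]
kh, kw = kernel.shape

# Output size formula used in UpFirDn2d.forward and upfirdn2d_native.
out_h = (in_h * up + pad[0] + pad[1] - kh + down) // down
out_w = (in_w * up + pad[0] + pad[1] - kw + down) // down

# Same steps as upfirdn2d_native: zero-insert upsample, pad, FIR filter, stride.
b, c = x.shape[:2]
up_x = torch.zeros(b, c, in_h * up, in_w * up)
up_x[:, :, ::up, ::up] = x
up_x = F.pad(up_x, [pad[0], pad[1], pad[0], pad[1]])
w = torch.flip(kernel, [0, 1]).repeat(c, 1, 1, 1)
out = F.conv2d(up_x, w, groups=c)[:, :, ::down, ::down]
print(out.shape, (out_h, out_w))  # torch.Size([1, 3, 64, 64]) (64, 64)
```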
/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAApplyUtils.cuh>
12 | #include <ATen/cuda/CUDAContext.h>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 | int c = a / b;
19 |
20 | if (c * b > a) {
21 | c--;
22 | }
23 |
24 | return c;
25 | }
26 |
27 | struct UpFirDn2DKernelParams {
28 | int up_x;
29 | int up_y;
30 | int down_x;
31 | int down_y;
32 | int pad_x0;
33 | int pad_x1;
34 | int pad_y0;
35 | int pad_y1;
36 |
37 | int major_dim;
38 | int in_h;
39 | int in_w;
40 | int minor_dim;
41 | int kernel_h;
42 | int kernel_w;
43 | int out_h;
44 | int out_w;
45 | int loop_major;
46 | int loop_x;
47 | };
48 |
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 | const scalar_t *kernel,
52 | const UpFirDn2DKernelParams p) {
53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 | int out_y = minor_idx / p.minor_dim;
55 | minor_idx -= out_y * p.minor_dim;
56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 | int major_idx_base = blockIdx.z * p.loop_major;
58 |
59 | if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 | major_idx_base >= p.major_dim) {
61 | return;
62 | }
63 |
64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 |
69 | for (int loop_major = 0, major_idx = major_idx_base;
70 | loop_major < p.loop_major && major_idx < p.major_dim;
71 | loop_major++, major_idx++) {
72 | for (int loop_x = 0, out_x = out_x_base;
73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 |
79 | const scalar_t *x_p =
80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 | minor_idx];
82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 | int x_px = p.minor_dim;
84 | int k_px = -p.up_x;
85 | int x_py = p.in_w * p.minor_dim;
86 | int k_py = -p.up_y * p.kernel_w;
87 |
88 | scalar_t v = 0.0f;
89 |
90 | for (int y = 0; y < h; y++) {
91 | for (int x = 0; x < w; x++) {
92 | v += static_cast(*x_p) * static_cast(*k_p);
93 | x_p += x_px;
94 | k_p += k_px;
95 | }
96 |
97 | x_p += x_py - w * x_px;
98 | k_p += k_py - w * k_px;
99 | }
100 |
101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 | minor_idx] = v;
103 | }
104 | }
105 | }
106 |
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 | int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 | const scalar_t *kernel,
111 | const UpFirDn2DKernelParams p) {
112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 |
115 | __shared__ volatile float sk[kernel_h][kernel_w];
116 | __shared__ volatile float sx[tile_in_h][tile_in_w];
117 |
118 | int minor_idx = blockIdx.x;
119 | int tile_out_y = minor_idx / p.minor_dim;
120 | minor_idx -= tile_out_y * p.minor_dim;
121 | tile_out_y *= tile_out_h;
122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 | int major_idx_base = blockIdx.z * p.loop_major;
124 |
125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126 | major_idx_base >= p.major_dim) {
127 | return;
128 | }
129 |
130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131 | tap_idx += blockDim.x) {
132 | int ky = tap_idx / kernel_w;
133 | int kx = tap_idx - ky * kernel_w;
134 | scalar_t v = 0.0;
135 |
136 | if (kx < p.kernel_w & ky < p.kernel_h) {
137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138 | }
139 |
140 | sk[ky][kx] = v;
141 | }
142 |
143 | for (int loop_major = 0, major_idx = major_idx_base;
144 | loop_major < p.loop_major & major_idx < p.major_dim;
145 | loop_major++, major_idx++) {
146 | for (int loop_x = 0, tile_out_x = tile_out_x_base;
147 | loop_x < p.loop_x & tile_out_x < p.out_w;
148 | loop_x++, tile_out_x += tile_out_w) {
149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151 | int tile_in_x = floor_div(tile_mid_x, up_x);
152 | int tile_in_y = floor_div(tile_mid_y, up_y);
153 |
154 | __syncthreads();
155 |
156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157 | in_idx += blockDim.x) {
158 | int rel_in_y = in_idx / tile_in_w;
159 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
160 | int in_x = rel_in_x + tile_in_x;
161 | int in_y = rel_in_y + tile_in_y;
162 |
163 | scalar_t v = 0.0;
164 |
165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167 | p.minor_dim +
168 | minor_idx];
169 | }
170 |
171 | sx[rel_in_y][rel_in_x] = v;
172 | }
173 |
174 | __syncthreads();
175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176 | out_idx += blockDim.x) {
177 | int rel_out_y = out_idx / tile_out_w;
178 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
179 | int out_x = rel_out_x + tile_out_x;
180 | int out_y = rel_out_y + tile_out_y;
181 |
182 | int mid_x = tile_mid_x + rel_out_x * down_x;
183 | int mid_y = tile_mid_y + rel_out_y * down_y;
184 | int in_x = floor_div(mid_x, up_x);
185 | int in_y = floor_div(mid_y, up_y);
186 | int rel_in_x = in_x - tile_in_x;
187 | int rel_in_y = in_y - tile_in_y;
188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190 |
191 | scalar_t v = 0.0;
192 |
193 | #pragma unroll
194 | for (int y = 0; y < kernel_h / up_y; y++)
195 | #pragma unroll
196 | for (int x = 0; x < kernel_w / up_x; x++)
197 | v += sx[rel_in_y + y][rel_in_x + x] *
198 | sk[kernel_y + y * up_y][kernel_x + x * up_x];
199 |
200 | if (out_x < p.out_w & out_y < p.out_h) {
201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202 | minor_idx] = v;
203 | }
204 | }
205 | }
206 | }
207 | }
208 |
209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210 | const torch::Tensor &kernel, int up_x, int up_y,
211 | int down_x, int down_y, int pad_x0, int pad_x1,
212 | int pad_y0, int pad_y1) {
213 | int curDevice = -1;
214 | cudaGetDevice(&curDevice);
215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
216 |
217 | UpFirDn2DKernelParams p;
218 |
219 | auto x = input.contiguous();
220 | auto k = kernel.contiguous();
221 |
222 | p.major_dim = x.size(0);
223 | p.in_h = x.size(1);
224 | p.in_w = x.size(2);
225 | p.minor_dim = x.size(3);
226 | p.kernel_h = k.size(0);
227 | p.kernel_w = k.size(1);
228 | p.up_x = up_x;
229 | p.up_y = up_y;
230 | p.down_x = down_x;
231 | p.down_y = down_y;
232 | p.pad_x0 = pad_x0;
233 | p.pad_x1 = pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 | case 1:
314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 | x.data_ptr<scalar_t>(),
317 | k.data_ptr<scalar_t>(), p);
318 |
319 | break;
320 |
321 | case 2:
322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 | x.data_ptr<scalar_t>(),
325 | k.data_ptr<scalar_t>(), p);
326 |
327 | break;
328 |
329 | case 3:
330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 | x.data_ptr<scalar_t>(),
333 | k.data_ptr<scalar_t>(), p);
334 |
335 | break;
336 |
337 | case 4:
338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 | x.data_ptr<scalar_t>(),
341 | k.data_ptr<scalar_t>(), p);
342 |
343 | break;
344 |
345 | case 5:
346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 | x.data_ptr<scalar_t>(),
349 | k.data_ptr<scalar_t>(), p);
350 |
351 | break;
352 |
353 | case 6:
354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 | x.data_ptr<scalar_t>(),
357 | k.data_ptr<scalar_t>(), p);
358 |
359 | break;
360 |
361 | default:
362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 | k.data_ptr<scalar_t>(), p);
365 | }
366 | });
367 |
368 | return out;
369 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | pandas
3 | scipy
4 | opencv-python
5 |
--------------------------------------------------------------------------------
/sphereface.py:
--------------------------------------------------------------------------------
1 | # source: https://raw.githubusercontent.com/clcarwin/sphereface_pytorch/master/net_sphere.py
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 | from torch.nn import Parameter
7 | import math
8 |
9 | def myphi(x,m):
10 | x = x * m
11 | return 1-x**2/math.factorial(2)+x**4/math.factorial(4)-x**6/math.factorial(6) + \
12 | x**8/math.factorial(8) - x**9/math.factorial(9)
13 |
14 | class AngleLinear(nn.Module):
15 | def __init__(self, in_features, out_features, m = 4, phiflag=True):
16 | super(AngleLinear, self).__init__()
17 | self.in_features = in_features
18 | self.out_features = out_features
19 | self.weight = Parameter(torch.Tensor(in_features,out_features))
20 | self.weight.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
21 | self.phiflag = phiflag
22 | self.m = m
23 | self.mlambda = [
24 | lambda x: x**0,
25 | lambda x: x**1,
26 | lambda x: 2*x**2-1,
27 | lambda x: 4*x**3-3*x,
28 | lambda x: 8*x**4-8*x**2+1,
29 | lambda x: 16*x**5-20*x**3+5*x
30 | ]
31 |
32 | def forward(self, input):
33 | x = input # size=(B,F) F is feature len
34 | w = self.weight # size=(F,Classnum) F=in_features Classnum=out_features
35 |
36 | ww = w.renorm(2,1,1e-5).mul(1e5)
37 | xlen = x.pow(2).sum(1).pow(0.5) # size=B
38 | wlen = ww.pow(2).sum(0).pow(0.5) # size=Classnum
39 |
40 | cos_theta = x.mm(ww) # size=(B,Classnum)
41 | cos_theta = cos_theta / xlen.view(-1,1) / wlen.view(1,-1)
42 | cos_theta = cos_theta.clamp(-1,1)
43 |
44 | if self.phiflag:
45 | cos_m_theta = self.mlambda[self.m](cos_theta)
46 | theta = Variable(cos_theta.data.acos())
47 | k = (self.m*theta/3.14159265).floor()
48 | n_one = k*0.0 - 1
49 | phi_theta = (n_one**k) * cos_m_theta - 2*k
50 | else:
51 | theta = cos_theta.acos()
52 | phi_theta = myphi(theta,self.m)
53 | phi_theta = phi_theta.clamp(-1*self.m,1)
54 |
55 | cos_theta = cos_theta * xlen.view(-1,1)
56 | phi_theta = phi_theta * xlen.view(-1,1)
57 | output = (cos_theta,phi_theta)
58 | return output # size=(B,Classnum,2)
59 |
60 |
61 | class AngleLoss(nn.Module):
62 | def __init__(self, gamma=0):
63 | super(AngleLoss, self).__init__()
64 | self.gamma = gamma
65 | self.it = 0
66 | self.LambdaMin = 5.0
67 | self.LambdaMax = 1500.0
68 | self.lamb = 1500.0
69 |
70 | def forward(self, input, target):
71 | self.it += 1
72 | cos_theta,phi_theta = input
73 | target = target.view(-1,1) #size=(B,1)
74 |
75 | index = cos_theta.data * 0.0 #size=(B,Classnum)
76 | index.scatter_(1,target.data.view(-1,1),1)
77 | index = index.byte()
78 | index = Variable(index)
79 |
80 | self.lamb = max(self.LambdaMin,self.LambdaMax/(1+0.1*self.it ))
81 | output = cos_theta * 1.0 #size=(B,Classnum)
82 | output[index] -= cos_theta[index]*(1.0+0)/(1+self.lamb)
83 | output[index] += phi_theta[index]*(1.0+0)/(1+self.lamb)
84 |
85 | logpt = F.log_softmax(output)
86 | logpt = logpt.gather(1,target)
87 | logpt = logpt.view(-1)
88 | pt = Variable(logpt.data.exp())
89 |
90 | loss = -1 * (1-pt)**self.gamma * logpt
91 | loss = loss.mean()
92 |
93 | return loss
94 |
95 |
96 | class sphere20a(nn.Module):
97 | def __init__(self,classnum=10574,feature=False):
98 | super(sphere20a, self).__init__()
99 | self.classnum = classnum
100 | self.feature = feature
101 | #input = B*3*112*96
102 | self.conv1_1 = nn.Conv2d(3,64,3,2,1) #=>B*64*56*48
103 | self.relu1_1 = nn.PReLU(64)
104 | self.conv1_2 = nn.Conv2d(64,64,3,1,1)
105 | self.relu1_2 = nn.PReLU(64)
106 | self.conv1_3 = nn.Conv2d(64,64,3,1,1)
107 | self.relu1_3 = nn.PReLU(64)
108 |
109 | self.conv2_1 = nn.Conv2d(64,128,3,2,1) #=>B*128*28*24
110 | self.relu2_1 = nn.PReLU(128)
111 | self.conv2_2 = nn.Conv2d(128,128,3,1,1)
112 | self.relu2_2 = nn.PReLU(128)
113 | self.conv2_3 = nn.Conv2d(128,128,3,1,1)
114 | self.relu2_3 = nn.PReLU(128)
115 |
116 | self.conv2_4 = nn.Conv2d(128,128,3,1,1) #=>B*128*28*24
117 | self.relu2_4 = nn.PReLU(128)
118 | self.conv2_5 = nn.Conv2d(128,128,3,1,1)
119 | self.relu2_5 = nn.PReLU(128)
120 |
121 |
122 | self.conv3_1 = nn.Conv2d(128,256,3,2,1) #=>B*256*14*12
123 | self.relu3_1 = nn.PReLU(256)
124 | self.conv3_2 = nn.Conv2d(256,256,3,1,1)
125 | self.relu3_2 = nn.PReLU(256)
126 | self.conv3_3 = nn.Conv2d(256,256,3,1,1)
127 | self.relu3_3 = nn.PReLU(256)
128 |
129 | self.conv3_4 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12
130 | self.relu3_4 = nn.PReLU(256)
131 | self.conv3_5 = nn.Conv2d(256,256,3,1,1)
132 | self.relu3_5 = nn.PReLU(256)
133 |
134 | self.conv3_6 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12
135 | self.relu3_6 = nn.PReLU(256)
136 | self.conv3_7 = nn.Conv2d(256,256,3,1,1)
137 | self.relu3_7 = nn.PReLU(256)
138 |
139 | self.conv3_8 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12
140 | self.relu3_8 = nn.PReLU(256)
141 | self.conv3_9 = nn.Conv2d(256,256,3,1,1)
142 | self.relu3_9 = nn.PReLU(256)
143 |
144 | self.conv4_1 = nn.Conv2d(256,512,3,2,1) #=>B*512*7*6
145 | self.relu4_1 = nn.PReLU(512)
146 | self.conv4_2 = nn.Conv2d(512,512,3,1,1)
147 | self.relu4_2 = nn.PReLU(512)
148 | self.conv4_3 = nn.Conv2d(512,512,3,1,1)
149 | self.relu4_3 = nn.PReLU(512)
150 |
151 | self.fc5 = nn.Linear(512*7*6,512)
152 | self.fc6 = AngleLinear(512,self.classnum)
153 |
154 |
155 | def forward(self, x):
156 | x = self.relu1_1(self.conv1_1(x))
157 | x = x + self.relu1_3(self.conv1_3(self.relu1_2(self.conv1_2(x))))
158 |
159 | x = self.relu2_1(self.conv2_1(x))
160 | x = x + self.relu2_3(self.conv2_3(self.relu2_2(self.conv2_2(x))))
161 | x = x + self.relu2_5(self.conv2_5(self.relu2_4(self.conv2_4(x))))
162 |
163 | x = self.relu3_1(self.conv3_1(x))
164 | x = x + self.relu3_3(self.conv3_3(self.relu3_2(self.conv3_2(x))))
165 | x = x + self.relu3_5(self.conv3_5(self.relu3_4(self.conv3_4(x))))
166 | x = x + self.relu3_7(self.conv3_7(self.relu3_6(self.conv3_6(x))))
167 | x = x + self.relu3_9(self.conv3_9(self.relu3_8(self.conv3_8(x))))
168 |
169 | x = self.relu4_1(self.conv4_1(x))
170 | x = x + self.relu4_3(self.conv4_3(self.relu4_2(self.conv4_2(x))))
171 |
172 | x = x.view(x.size(0),-1)
173 | x = self.fc5(x)
174 | if self.feature: return x
175 |
176 | x = self.fc6(x)
177 | return x
178 |
--------------------------------------------------------------------------------
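In train.py this network is used purely as a feature extractor for the optional face identity loss: 112x96 face crops are embedded with `sphere20a(feature=True)` and compared with cosine similarity. A minimal sketch of that loss (random weights here; training loads `sphere20a_20171020.pth` from the dataset's `resources` folder, and the repo root is assumed to be on PYTHONPATH):

```
import torch
import torch.nn as nn
from sphereface import sphere20a

net = sphere20a(feature=True)  # return the 512-D fc5 embedding instead of class logits
# net.load_state_dict(torch.load('DATASET/resources/sphere20a_20171020.pth'))  # as in train.py
net.eval()

cos = nn.CosineSimilarity()
real_face = torch.randn(2, 3, 112, 96)  # aligned face crops produced by getFace() in train.py
fake_face = torch.randn(2, 3, 112, 96)
with torch.no_grad():
    g_cos = 1.0 - cos(net(real_face), net(fake_face)).mean()
print(g_cos.item())
```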
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from torchvision import utils
5 | from model import Generator
6 | from tqdm import tqdm
7 | from torch.utils import data
8 | import numpy as np
9 | from dataset import DeepFashionDataset
10 |
11 |
12 | def sample_data(loader):
13 | while True:
14 | for batch in loader:
15 | yield batch
16 |
17 | def generate(args, g_ema, device, mean_latent, loader):
18 | loader = sample_data(loader)
19 | with torch.no_grad():
20 | g_ema.eval()
21 | for i in tqdm(range(args.pics)):
22 | data = next(loader)
23 |
24 | input_image = data['input_image'].float().to(device)
25 | real_img = data['target_image'].float().to(device)
26 | pose = data['target_pose'].float().to(device)
27 | sil = data['target_sil'].float().to(device)
28 |
29 | source_sil = data['input_sil'].float().to(device)
30 | complete_coor = data['complete_coor'].float().to(device)
31 |
32 | if args.size == 256:
33 | complete_coor = torch.nn.functional.interpolate(complete_coor, size=(256,256), mode='bilinear')
34 |
35 | appearance = torch.cat([input_image, source_sil, complete_coor], 1)
36 |
37 | sample, _ = g_ema(appearance=appearance, pose=pose)
38 |
39 | RP = data['target_right_pad']
40 | LP = data['target_left_pad']
41 |
42 | utils.save_image(
43 | sample[:, :, :, int(RP[0].item()):args.size-int(LP[0].item())],
44 | os.path.join(args.save_path, data['save_name'][0]),
45 | nrow=1,
46 | normalize=True,
47 | range=(-1, 1),
48 | )
49 |
50 |
51 | if __name__ == "__main__":
52 | device = "cuda"
53 |
54 | parser = argparse.ArgumentParser(description="Generate reposing results")
55 |
56 | parser.add_argument("path", type=str, help="path to dataset")
57 | parser.add_argument("--size", type=int, default=512, help="output image size of the generator")
58 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
59 | parser.add_argument("--truncation_mean", type=int, default=4096, help="number of vectors to calculate mean for the truncation")
60 | parser.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier of the generator. config-f = 2, else = 1")
61 | parser.add_argument("--pretrained_model", type=str, default="posewithstyle.pt", help="pose with style pretrained model")
62 | parser.add_argument("--save_path", type=str, default="output", help="path to save the output images")
63 |
64 | args = parser.parse_args()
65 |
66 | args.latent = 2048
67 | args.n_mlp = 8
68 |
69 | if not os.path.exists(args.save_path):
70 | os.makedirs(args.save_path)
71 |
72 | g_ema = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device)
73 | checkpoint = torch.load(args.pretrained_model)
74 | g_ema.load_state_dict(checkpoint["g_ema"])
75 |
76 | if args.truncation < 1:
77 | with torch.no_grad():
78 | mean_latent = g_ema.mean_latent(args.truncation_mean)
79 | else:
80 | mean_latent = None
81 |
82 | dataset = DeepFashionDataset(args.path, 'test', args.size)
83 | loader = data.DataLoader(
84 | dataset,
85 | batch_size=1,
86 | sampler=data.SequentialSampler(dataset),
87 | drop_last=False,
88 | )
89 |
90 | print ('Testing %d images...'%len(dataset))
91 | args.pics = len(dataset)
92 |
93 | generate(args, g_ema, device, mean_latent, loader)
94 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import random
4 | import os
5 |
6 | import numpy as np
7 | import torch
8 | from torch import nn, autograd, optim
9 | from torch.nn import functional as F
10 | from torch.utils import data
11 | import torch.distributed as dist
12 | from torchvision import transforms, utils
13 | from tqdm import tqdm
14 | import time
15 |
16 | from dataset import DeepFashionDataset
17 | from model import Generator, Discriminator, VGGLoss
18 |
19 | try:
20 | import wandb
21 | except ImportError:
22 | wandb = None
23 |
24 |
25 | from distributed import (
26 | get_rank,
27 | synchronize,
28 | reduce_loss_dict,
29 | reduce_sum,
30 | get_world_size,
31 | )
32 | from op import conv2d_gradfix
33 |
34 |
35 | class AverageMeter(object):
36 | """Computes and stores the average and current value"""
37 | def __init__(self):
38 | self.reset()
39 |
40 | def reset(self):
41 | self.val = 0
42 | self.avg = 0
43 | self.sum = 0
44 | self.count = 0
45 |
46 | def update(self, val, n=1):
47 | self.val = val
48 | self.sum += val * n
49 | self.count += n
50 | self.avg = self.sum / self.count
51 |
52 |
53 | def data_sampler(dataset, shuffle, distributed):
54 | if distributed:
55 | return data.distributed.DistributedSampler(dataset)
56 | if shuffle:
57 | return data.RandomSampler(dataset)
58 | else:
59 | return data.SequentialSampler(dataset)
60 |
61 |
62 | def requires_grad(model, flag=True):
63 | for p in model.parameters():
64 | p.requires_grad = flag
65 |
66 |
67 | def accumulate(model1, model2, decay=0.999):
68 | par1 = dict(model1.named_parameters())
69 | par2 = dict(model2.named_parameters())
70 | for k in par1.keys():
71 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
72 |
73 |
74 | def sample_data(loader):
75 | while True:
76 | for batch in loader:
77 | yield batch
78 |
79 |
80 | def d_logistic_loss(real_pred, fake_pred):
81 | real_loss = F.softplus(-real_pred)
82 | fake_loss = F.softplus(fake_pred)
83 | return real_loss.mean() + fake_loss.mean()
84 |
85 |
86 | def d_r1_loss(real_pred, real_img):
87 | with conv2d_gradfix.no_weight_gradients():
88 | grad_real, = autograd.grad(
89 | outputs=real_pred.sum(), inputs=real_img, create_graph=True
90 | )
91 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
92 | return grad_penalty
93 |
94 |
95 | def g_nonsaturating_loss(fake_pred):
96 | loss = F.softplus(-fake_pred).mean()
97 | return loss
98 |
99 |
100 | def set_grad_none(model, targets):
101 | for n, p in model.named_parameters():
102 | if n in targets:
103 | p.grad = None
104 |
105 |
106 | def getFace(images, FT, LP, RP):
107 | """
108 | images: are images where we want to get the faces
109 | FT: transform to get the aligned face
110 | LP: left pad added to the image
111 | RP: right pad added to the image
112 | """
113 | faces = []
114 | b, c, h, w = images.shape
115 | for b in range(images.shape[0]):
116 | if not (abs(FT[b]).sum() == 0): # all 3x3 elements are zero
117 | # only apply the loss to image with detected faces
118 | # need to do this per image because images are of different shape
119 | current_im = images[b][:, :, int(RP[b].item()):w-int(LP[b].item())].unsqueeze(0)
120 | theta = FT[b].unsqueeze(0)[:, :2] #bx2x3
121 | grid = torch.nn.functional.affine_grid(theta, (1, 3, 112, 96))
122 | current_face = torch.nn.functional.grid_sample(current_im, grid)
123 | faces.append(current_face)
124 | if len(faces) == 0:
125 | return None
126 | return torch.cat(faces, 0)
127 |
128 |
129 | def train(args, loader, sampler, generator, discriminator, g_optim, d_optim, g_ema, device):
130 | pbar = range(args.epoch)
131 |
132 | if get_rank() == 0:
133 | pbar = tqdm(pbar, initial=args.start_epoch, dynamic_ncols=True, smoothing=0.01)
134 | pbar.set_description('Epoch Counter')
135 |
136 | d_loss_val = 0
137 | r1_loss = torch.tensor(0.0, device=device)
138 | g_loss_val = 0
139 | g_L1_loss_val = 0
140 | g_vgg_loss_val = 0
141 | g_l1 = torch.tensor(0.0, device=device)
142 | g_vgg = torch.tensor(0.0, device=device)
143 | g_cos = torch.tensor(0.0, device=device)
144 | loss_dict = {}
145 |
146 | criterionL1 = torch.nn.L1Loss()
147 | criterionVGG = VGGLoss(device).to(device)
148 | if args.faceloss:
149 | criterionCOS = nn.CosineSimilarity()
150 |
151 | if args.distributed:
152 | g_module = generator.module
153 | d_module = discriminator.module
154 | else:
155 | g_module = generator
156 | d_module = discriminator
157 |
158 | accum = 0.5 ** (32 / (10 * 1000))
159 |
160 | for idx in pbar:
161 | epoch = idx + args.start_epoch
162 |
163 | if epoch > args.epoch:
164 | print("Done!")
165 | break
166 |
167 | if args.distributed:
168 | sampler.set_epoch(epoch)
169 |
170 | batch_time = AverageMeter()
171 |
172 | #####################################
173 | ############ START EPOCH ############
174 | #####################################
175 | for i, data in enumerate(loader):
176 | batch_start_time = time.time()
177 |
178 | input_image = data['input_image'].float().to(device)
179 | real_img = data['target_image'].float().to(device)
180 | pose = data['target_pose'].float().to(device)
181 | sil = data['target_sil'].float().to(device)
182 |
183 | LeftPad = data['target_left_pad'].float().to(device)
184 | RightPad = data['target_right_pad'].float().to(device)
185 |
186 | if args.faceloss:
187 | FT = data['TargetFaceTransform'].float().to(device)
188 | real_face = getFace(real_img, FT, LeftPad, RightPad)
189 |
190 | if args.finetune:
191 | # only mask padding
192 | sil = torch.zeros((sil.shape)).float().to(device)
193 | for b in range(sil.shape[0]):
194 | w = sil.shape[3]
195 | sil[b][:, :, int(RightPad[b].item()):w-int(LeftPad[b].item())] = 1 # mask out the padding
196 | # else only focus on the foreground - initial step of training
197 |
198 | real_img = real_img * sil
199 |
200 | # appearance = human foregound + fg mask (pass coor for warping)
201 | source_sil = data['input_sil'].float().to(device)
202 | complete_coor = data['complete_coor'].float().to(device)
203 | if args.size == 256:
204 | complete_coor = torch.nn.functional.interpolate(complete_coor, size=(256, 256), mode='bilinear')
205 | if args.finetune:
206 | appearance = torch.cat([input_image, source_sil, complete_coor], 1)
207 | else:
208 | appearance = torch.cat([input_image * source_sil, source_sil, complete_coor], 1)
209 |
210 |
211 | ############ Optimize Discriminator ############
212 | requires_grad(generator, False)
213 | requires_grad(discriminator, True)
214 |
215 | fake_img, _ = generator(appearance=appearance, pose=pose)
216 | fake_img = fake_img * sil
217 |
218 | fake_pred = discriminator(fake_img, pose=pose)
219 | real_pred = discriminator(real_img, pose=pose)
220 | d_loss = d_logistic_loss(real_pred, fake_pred)
221 |
222 | loss_dict["d"] = d_loss
223 | loss_dict["real_score"] = real_pred.mean()
224 | loss_dict["fake_score"] = fake_pred.mean()
225 |
226 | discriminator.zero_grad()
227 | d_loss.backward()
228 | d_optim.step()
229 |
230 |
231 | d_regularize = i % args.d_reg_every == 0
232 |
233 | if d_regularize:
234 | real_img.requires_grad = True
235 |
236 | real_pred = discriminator(real_img, pose=pose)
237 | r1_loss = d_r1_loss(real_pred, real_img)
238 |
239 | discriminator.zero_grad()
240 | (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
241 |
242 | d_optim.step()
243 |
244 | loss_dict["r1"] = r1_loss
245 |
246 |
247 | ############## Optimize Generator ##############
248 | requires_grad(generator, True)
249 | requires_grad(discriminator, False)
250 |
251 | fake_img, _ = generator(appearance=appearance, pose=pose)
252 | fake_img = fake_img * sil
253 |
254 | fake_pred = discriminator(fake_img, pose=pose)
255 | g_loss = g_nonsaturating_loss(fake_pred)
256 |
257 | loss_dict["g"] = g_loss
258 |
259 | ## reconstruction loss: L1 and VGG loss + face identity loss
260 | g_l1 = criterionL1(fake_img, real_img)
261 | g_loss += g_l1
262 | g_vgg = criterionVGG(fake_img, real_img)
263 | g_loss += g_vgg
264 |
265 | loss_dict["g_L1"] = g_l1
266 | loss_dict["g_vgg"] = g_vgg
267 |
268 | if args.faceloss and (real_face is not None):
269 | fake_face = getFace(fake_img, FT, LeftPad, RightPad)
270 | features_real_face = sphereface_net(real_face)
271 | features_fake_face = sphereface_net(fake_face)
272 | g_cos = 1. - criterionCOS(features_real_face, features_fake_face).mean()
273 | g_loss += g_cos
274 |
275 | loss_dict["g_cos"] = g_cos
276 |
277 | generator.zero_grad()
278 | g_loss.backward()
279 | g_optim.step()
280 |
281 |
282 | ############ Optimization Done ############
283 | accumulate(g_ema, g_module, accum)
284 |
285 | loss_reduced = reduce_loss_dict(loss_dict)
286 |
287 | d_loss_val = loss_reduced["d"].mean().item()
288 | g_loss_val = loss_reduced["g"].mean().item()
289 | g_L1_loss_val = loss_reduced["g_L1"].mean().item()
290 | g_cos_loss_val = loss_reduced["g_cos"].mean().item()
291 | g_vgg_loss_val = loss_reduced["g_vgg"].mean().item()
292 | r1_val = loss_reduced["r1"].mean().item()
293 | real_score_val = loss_reduced["real_score"].mean().item()
294 | fake_score_val = loss_reduced["fake_score"].mean().item()
295 |
296 | batch_time.update(time.time() - batch_start_time)
297 |
298 | if i % 100 == 0:
299 | print('Epoch: [{0}/{1}] Iter: [{2}/{3}]\t'
300 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(epoch, args.epoch, i, len(loader), batch_time=batch_time)
301 | +
302 | f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; g_L1: {g_L1_loss_val:.4f}; g_vgg: {g_vgg_loss_val:.4f}; g_cos: {g_cos_loss_val:.4f}; r1: {r1_val:.4f}; "
303 | )
304 |
305 | if get_rank() == 0:
306 | if wandb and args.wandb:
307 | wandb.log(
308 | {
309 | "Generator": g_loss_val,
310 | "Discriminator": d_loss_val,
311 | "R1": r1_val,
312 | "Real Score": real_score_val,
313 | "Fake Score": fake_score_val,
314 | "Generator_L1": g_L1_loss_val,
315 | "Generator_vgg": g_vgg_loss_val,
316 | "Generator_facecos": g_cos_loss_val,
317 | }
318 | )
319 |
320 | if i % 5000 == 0:
321 | with torch.no_grad():
322 | g_ema.eval()
323 | sample, _ = g_ema(appearance=appearance[:args.n_sample], pose=pose[:args.n_sample])
324 | sample = sample * sil
325 | utils.save_image(
326 | sample,
327 | os.path.join('sample', args.name, f"epoch_{str(epoch)}_iter_{str(i)}.png"),
328 | nrow=int(args.n_sample ** 0.5),
329 | normalize=True,
330 | range=(-1, 1),
331 | )
332 |
333 | if i % 5000 == 0:
334 | torch.save(
335 | {
336 | "g": g_module.state_dict(),
337 | "d": d_module.state_dict(),
338 | "g_ema": g_ema.state_dict(),
339 | "g_optim": g_optim.state_dict(),
340 | "d_optim": d_optim.state_dict(),
341 | "args": args,
342 | },
343 | os.path.join('checkpoint', args.name, f"epoch_{str(epoch)}_iter_{str(i)}.pt"),
344 | )
345 |
346 | ###################################
347 | ############ END EPOCH ############
348 | ###################################
349 | if get_rank() == 0:
350 | torch.save(
351 | {
352 | "g": g_module.state_dict(),
353 | "d": d_module.state_dict(),
354 | "g_ema": g_ema.state_dict(),
355 | "g_optim": g_optim.state_dict(),
356 | "d_optim": d_optim.state_dict(),
357 | "args": args,
358 | },
359 | os.path.join('checkpoint', args.name, f"epoch_{str(epoch)}.pt"),
360 | )
361 |
362 |
363 | if __name__ == "__main__":
364 | device = "cuda"
365 |
366 | parser = argparse.ArgumentParser(description="Pose with Style trainer")
367 |
368 | parser.add_argument("path", type=str, help="path to the lmdb dataset")
369 | parser.add_argument("--name", type=str, help="name of experiment")
370 | parser.add_argument("--epoch", type=int, default=50, help="total training epochs")
371 | parser.add_argument("--batch", type=int, default=4, help="batch size per GPU")
372 | parser.add_argument("--workers", type=int, default=4, help="number of data loader workers")
373 | parser.add_argument("--n_sample", type=int, default=4, help="number of the samples generated during training")
374 | parser.add_argument("--size", type=int, default=512, help="image sizes for the model")
375 | parser.add_argument("--r1", type=float, default=10, help="weight of the r1 regularization")
376 | parser.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier factor for the model. config-f = 2, else = 1")
377 | parser.add_argument(
378 | "--d_reg_every",
379 | type=int,
380 | default=16,
381 | help="interval of applying the r1 regularization",
382 | )
383 | parser.add_argument(
384 | "--g_reg_every",
385 | type=int,
386 | default=4,
387 | help="interval of applying the path length regularization",
388 | )
389 | parser.add_argument("--ckpt", type=str, default=None, help="path to the checkpoints to resume training")
390 | parser.add_argument("--lr", type=float, default=0.002, help="learning rate")
391 | parser.add_argument("--wandb", action="store_true", help="use weights and biases logging")
392 | parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training")
393 | parser.add_argument("--faceloss", action="store_true", help="add face loss when faces are detected")
394 | parser.add_argument("--finetune", action="store_true", help="finetune to handle background- second step of training.")
395 |
396 |
397 | args = parser.parse_args()
398 |
399 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
400 | args.distributed = n_gpu > 1
401 |
402 | if args.distributed:
403 | print ('Distributed Training Mode.')
404 | torch.cuda.set_device(args.local_rank)
405 | torch.distributed.init_process_group(backend="nccl", init_method="env://")
406 | synchronize()
407 |
408 | if get_rank() == 0:
409 | if not os.path.exists(os.path.join('checkpoint', args.name)):
410 | os.makedirs(os.path.join('checkpoint', args.name))
411 | if not os.path.exists(os.path.join('sample', args.name)):
412 | os.makedirs(os.path.join('sample', args.name))
413 |
414 | args.latent = 2048
415 | args.n_mlp = 8
416 |
417 | args.start_epoch = 0
418 |
419 | if args.finetune and (args.ckpt is None):
420 | print ('to finetune the model, please specify --ckpt.')
421 | import sys
422 | sys.exit()
423 |
424 | # define models
425 | generator = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device)
426 | discriminator = Discriminator(args.size, channel_multiplier=args.channel_multiplier).to(device)
427 | g_ema = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device)
428 | g_ema.eval()
429 | accumulate(g_ema, generator, 0)
430 |
431 | if args.faceloss:
432 | import sphereface
433 | sphereface_net = getattr(sphereface, 'sphere20a')()
434 | sphereface_net.load_state_dict(torch.load(os.path.join(args.path, 'resources', 'sphere20a_20171020.pth')))
435 | sphereface_net.to(device)
436 | sphereface_net.eval()
437 | sphereface_net.feature = True
438 |
439 |
440 | g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
441 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
442 |
443 | g_optim = optim.Adam(
444 | generator.parameters(),
445 | lr=args.lr * g_reg_ratio,
446 | betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
447 | )
448 | d_optim = optim.Adam(
449 | discriminator.parameters(),
450 | lr=args.lr * d_reg_ratio,
451 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
452 | )
453 |
454 | if args.ckpt is not None:
455 | print("load model:", args.ckpt)
456 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
457 |
458 | try:
459 | ckpt_name = os.path.basename(args.ckpt)
460 | args.start_epoch = int(os.path.splitext(ckpt_name)[0].split('_')[1])+1 # assuming checkpoints are saved as epoch_1_iter_1000.pt or epoch_1.pt
461 | except ValueError:
462 | pass
463 |
464 | generator.load_state_dict(ckpt["g"])
465 | discriminator.load_state_dict(ckpt["d"])
466 | g_ema.load_state_dict(ckpt["g_ema"])
467 |
468 | g_optim.load_state_dict(ckpt["g_optim"])
469 | d_optim.load_state_dict(ckpt["d_optim"])
470 |
471 | if args.distributed:
472 | generator = nn.parallel.DistributedDataParallel(
473 | generator,
474 | device_ids=[args.local_rank],
475 | output_device=args.local_rank,
476 | broadcast_buffers=False,
477 | )
478 | discriminator = nn.parallel.DistributedDataParallel(
479 | discriminator,
480 | device_ids=[args.local_rank],
481 | output_device=args.local_rank,
482 | broadcast_buffers=False,
483 | )
484 |
485 | dataset = DeepFashionDataset(args.path, 'train', args.size)
486 | sampler = data_sampler(dataset, shuffle=True, distributed=args.distributed)
487 | loader = data.DataLoader(
488 | dataset,
489 | batch_size=args.batch,
490 | sampler=sampler,
491 | drop_last=True,
492 | pin_memory=True,
493 | num_workers=args.workers,
494 | shuffle=False,
495 | )
496 |
497 | if get_rank() == 0 and (wandb is not None) and args.wandb:
498 | wandb.init(project=args.name)
499 |
500 | train(args, loader, sampler, generator, discriminator, g_optim, d_optim, g_ema, device)
501 |
--------------------------------------------------------------------------------
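The `accumulate` helper above keeps `g_ema` as an exponential moving average of the generator weights; `accumulate(g_ema, generator, 0)` at startup simply copies them. A small standalone illustration with throwaway `nn.Linear` modules in place of the generators:

```
import torch
from torch import nn

def accumulate(model1, model2, decay=0.999):
    # model1 <- decay * model1 + (1 - decay) * model2 (same update as in train.py)
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)

g, g_ema = nn.Linear(4, 4), nn.Linear(4, 4)
accumulate(g_ema, g, 0)                    # decay=0: initialise g_ema with g's weights
accum = 0.5 ** (32 / (10 * 1000))          # ~0.9978, the decay used in train()
accumulate(g_ema, g, accum)                # per-iteration EMA update
print(torch.allclose(g_ema.weight, g.weight))  # True: g has not moved yet
```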
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/util/__init__.py
--------------------------------------------------------------------------------
/util/complete_coor.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | import torch
5 | from coordinate_completion_model import define_G as define_CCM
6 |
7 | def pad_PIL(pil_img, top, right, bottom, left, color=(0, 0, 0)):
8 | width, height = pil_img.size
9 | new_width = width + right + left
10 | new_height = height + top + bottom
11 | result = Image.new(pil_img.mode, (new_width, new_height), color)
12 | result.paste(pil_img, (left, top))
13 | return result
14 |
15 | def save_image(image_numpy, image_path):
16 | image_pil = Image.fromarray(image_numpy.astype(np.uint8))
17 | image_pil.save(image_path)
18 |
19 | import argparse
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--dataroot', type=str, help="path to DeepFashion dataset")
22 | parser.add_argument('--coordinates_path', type=str, help="path to partial coordinates dataset")
23 | parser.add_argument('--image_file', type=str, help="path to image file to process. ex: ./train.lst")
24 | parser.add_argument('--phase', type=str, help="train or test")
25 | parser.add_argument('--save_path', type=str, help="path to save results")
26 | parser.add_argument('--pretrained_model', type=str, help="path to the pretrained coordinate completion model")
27 |
28 | args = parser.parse_args()
29 |
30 | images_file = args.image_file
31 | phase = args.phase
32 |
33 | if not os.path.exists(os.path.join(args.save_path, phase)):
34 | os.makedirs(os.path.join(args.save_path, phase))
35 |
36 | images = []
37 | f = open(images_file, 'r')
38 | for lines in f:
39 | lines = lines.strip()
40 | images.append(lines)
41 |
42 | coor_completion_generator = define_CCM().cuda()
43 | CCM_checkpoint = torch.load(args.pretrained_model)
44 | coor_completion_generator.load_state_dict(CCM_checkpoint["g"])
45 | coor_completion_generator.eval()
46 | for param in coor_completion_generator.parameters():
47 | coor_completion_generator.requires_grad = False
48 |
49 | for i in range(len(images)):
50 | print ('%d/%d'%(i, len(images)))
51 | im_name = images[i]
52 |
53 | # get image
54 | path = os.path.join(args.dataroot, phase, im_name)
55 | im = Image.open(path)
56 | w, h = im.size
57 |
58 | # get uv coordinates
59 | uvcoor_root = os.path.join(args.coordinates_path, phase)
60 | uv_coor_path = os.path.join(uvcoor_root, im_name.split('.')[0]+'_uv_coor.npy')
61 | uv_mask_path = os.path.join(uvcoor_root, im_name.split('.')[0]+'_uv_mask.png')
62 | uv_symm_mask_path = os.path.join(uvcoor_root, im_name.split('.')[0]+'_uv_symm_mask.png')
63 |
64 | if (os.path.exists(uv_coor_path)):
65 | # read high-resolution coordinates
66 | uv_coor = np.load(uv_coor_path)
67 | uv_mask = np.array(Image.open(uv_mask_path))/255
68 | uv_symm_mask = np.array(Image.open(uv_symm_mask_path))/255
69 |
70 | # uv coor
71 | shift = int((h-w)/2)
72 | uv_coor[:,:,0] = uv_coor[:,:,0] + shift # put in center
73 | uv_coor = ((2*uv_coor/(h-1))-1)
74 | uv_coor = uv_coor*np.expand_dims(uv_mask,2) + (-10*(1-np.expand_dims(uv_mask,2)))
75 |
76 | x1 = shift
77 | x2 = h-(w+x1)
78 | im = pad_PIL(im, 0, x2, 0, x1, color=(0, 0, 0))
79 |
80 | ## coordinate completion
81 | uv_coor_pytorch = torch.from_numpy(uv_coor).float().permute(2, 0, 1).unsqueeze(0) # from h,w,c to 1,c,h,w
82 | uv_mask_pytorch = torch.from_numpy(uv_mask).unsqueeze(0).unsqueeze(0).float() #1xchw
83 | with torch.no_grad():
84 | coor_completion_generator.eval()
85 | complete_coor = coor_completion_generator(uv_coor_pytorch.cuda(), uv_mask_pytorch.cuda())
86 | uv_coor = complete_coor[0].permute(1,2,0).data.cpu().numpy()
87 | uv_confidence = np.stack([uv_mask-uv_symm_mask, uv_symm_mask, 1-uv_mask], 2)
88 |
89 | im = torch.from_numpy(np.array(im)).permute(2, 0, 1).unsqueeze(0).float()
90 | rgb_uv = torch.nn.functional.grid_sample(im.cuda(), complete_coor.permute(0,2,3,1).cuda())
91 | rgb_uv = rgb_uv[0].permute(1,2,0).data.cpu().numpy()
92 |
93 | # saving
94 | save_image(rgb_uv, os.path.join(args.save_path, phase, im_name.split('.jpg')[0]+'.png'))
95 | np.save(os.path.join(args.save_path, phase, '%s_uv_coor.npy'%(im_name.split('.')[0])), uv_coor)
96 | save_image(uv_confidence*255, os.path.join(args.save_path, phase, '%s_conf.png'%(im_name.split('.')[0])))
97 |
--------------------------------------------------------------------------------
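The `2*uv_coor/(h-1) - 1` step above rescales pixel coordinates of the padded square image into the [-1, 1] range that `grid_sample` expects, so the completed coordinate map can pull RGB values into UV space. A small sketch of that convention (align_corners=True is assumed here so that -1 and 1 land on pixel 0 and pixel h-1, matching the formula):

```
import torch
import torch.nn.functional as F

h = w = 8
im = torch.arange(h * w, dtype=torch.float32).view(1, 1, h, w)

# Build an identity (x, y) coordinate map and normalise it exactly like the script.
ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
coor = torch.stack([xs, ys], dim=-1).float()   # (h, w, 2); grid_sample wants (x, y) last
coor = 2 * coor / (h - 1) - 1

sampled = F.grid_sample(im, coor.unsqueeze(0), align_corners=True)
print(torch.allclose(sampled, im))  # True: the identity map reproduces the image
```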
/util/coordinate_completion_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 |
5 | def init_weights(net, init_type='normal', init_gain=0.02):
6 | """Initialize network weights.
7 |
8 | Parameters:
9 | net (network) -- network to be initialized
10 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
11 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
12 |
13 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
14 | work better for some applications. Feel free to try yourself.
15 | """
16 | def init_func(m): # define the initialization function
17 | classname = m.__class__.__name__
18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
19 | if init_type == 'normal':
20 | init.normal_(m.weight.data, 0.0, init_gain)
21 | elif init_type == 'xavier':
22 | init.xavier_normal_(m.weight.data, gain=init_gain)
23 | elif init_type == 'kaiming':
24 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
25 | elif init_type == 'orthogonal':
26 | init.orthogonal_(m.weight.data, gain=init_gain)
27 | else:
28 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
29 | if hasattr(m, 'bias') and m.bias is not None:
30 | init.constant_(m.bias.data, 0.0)
31 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
32 | init.normal_(m.weight.data, 1.0, init_gain)
33 | init.constant_(m.bias.data, 0.0)
34 |
35 | print('initialize network with %s' % init_type)
36 | net.apply(init_func) # apply the initialization function
37 | return net
38 |
39 | def define_G(init_type='normal', init_gain=0.02):
40 | net = CoordinateCompletion()
41 | return init_weights(net, init_type, init_gain)
42 |
43 | class CoordinateCompletion(nn.Module):
44 | def __init__(self):
45 | super(CoordinateCompletion, self).__init__()
46 | self.generator = CoorGenerator(input_nc=2+1, output_nc=2, tanh=True)
47 |
48 | def forward(self, coor_xy, UV_texture_mask):
49 | complete_coor = self.generator(torch.cat((coor_xy, UV_texture_mask), 1))
50 | return complete_coor
51 |
52 | class CoorGenerator(nn.Module):
53 | def __init__(self, input_nc, output_nc, ngf=32, batch_norm=True, spectral_norm=False, tanh=True):
54 | super(CoorGenerator, self).__init__()
55 |
56 | block = GatedConv2dWithActivation
57 | activation = nn.ELU(inplace=True)
58 |
59 | model = [block(input_nc, ngf, kernel_size=5, stride=1, padding=2,
60 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation),\
61 | block(ngf, ngf*2, kernel_size=3, stride=2, padding=1,
62 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation),\
63 | block(ngf*2, ngf*2, kernel_size=3, stride=1, padding=1,
64 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation),\
65 | block(ngf*2, ngf*4, kernel_size=3, stride=2, padding=1,
66 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
67 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1,
68 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
69 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1,
70 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
71 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1+1, dilation=2,
72 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
73 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1+3, dilation=4,
74 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
75 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1+7, dilation=8,
76 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
77 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1+7, dilation=8,
78 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
79 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1, dilation=1,
80 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
81 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1, dilation=1,
82 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
83 | nn.Upsample(scale_factor=2, mode='bilinear'),\
84 | block(ngf*4, ngf*2, kernel_size=3, stride=1, padding=1, dilation=1,
85 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
86 | nn.Upsample(scale_factor=2, mode='bilinear'),\
87 | block(ngf*2, ngf, kernel_size=3, stride=1, padding=1, dilation=1,
88 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
89 | block(ngf, int(ngf/2), kernel_size=3, stride=1, padding=1, dilation=1,
90 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \
91 | nn.Conv2d(int(ngf/2), output_nc, kernel_size=3, stride=1, padding=1, dilation=1)]
92 | if tanh:
93 | model += [ nn.Tanh()]
94 |
95 | self.model = nn.Sequential(*model)
96 |
97 | def forward(self, input):
98 | out = self.model(input)
99 | return out
100 |
101 | class GatedConv2dWithActivation(torch.nn.Module):
102 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,batch_norm=True, spectral_norm=False, activation=torch.nn.ELU(inplace=True)):
103 | super(GatedConv2dWithActivation, self).__init__()
104 | self.batch_norm = batch_norm
105 |
106 | self.pad = nn.ReflectionPad2d(padding)
107 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
108 | self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
109 | if spectral_norm:
110 | self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
111 | self.mask_conv2d = torch.nn.utils.spectral_norm(self.mask_conv2d)
112 | self.activation = activation
113 | self.sigmoid = torch.nn.Sigmoid()
114 |
115 | if self.batch_norm:
116 | self.batch_norm2d = torch.nn.BatchNorm2d(out_channels)
117 |
118 | def forward(self, input):
119 | padded_in = self.pad(input)
120 | x = self.conv2d(padded_in)
121 | mask = self.mask_conv2d(padded_in)
122 | gated_mask = self.sigmoid(mask)
123 | if self.batch_norm:
124 | x = self.batch_norm2d(x)
125 | if self.activation is not None:
126 | x = self.activation(x)
127 | x = x * gated_mask
128 | return x
129 |
--------------------------------------------------------------------------------
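A quick smoke test of the completion network's interface as used by util/complete_coor.py: a partial 2-channel (x, y) coordinate map plus a 1-channel visibility mask in, a completed 2-channel map in [-1, 1] out. Run from the repo root so `util` is importable; the 64x64 size and random weights are just for the example (complete_coor.py loads the pretrained checkpoint's "g" weights instead):

```
import torch
from util.coordinate_completion_model import define_G

net = define_G()
net.eval()

coor_xy = torch.rand(1, 2, 64, 64) * 2 - 1          # partial UV -> image coordinates
uv_mask = (torch.rand(1, 1, 64, 64) > 0.5).float()  # 1 where the coordinate is known
with torch.no_grad():
    complete_coor = net(coor_xy, uv_mask)
print(complete_coor.shape)  # torch.Size([1, 2, 64, 64])
```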
/util/dp2coor.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | from scipy.interpolate import griddata
5 | import cv2
6 | import argparse
7 |
8 | def getSymXYcoordinates(iuv, resolution = 256):
9 | xy, xyMask = getXYcoor(iuv, resolution = resolution)
10 | f_xy, f_xyMask = getXYcoor(flip_iuv(np.copy(iuv)), resolution = resolution)
11 | f_xyMask = np.clip(f_xyMask-xyMask, a_min=0, a_max=1)
12 | # combine actual + symmetric
13 | combined_texture = xy*np.expand_dims(xyMask,2) + f_xy*np.expand_dims(f_xyMask,2)
14 | combined_mask = np.clip(xyMask+f_xyMask, a_min=0, a_max=1)
15 | return combined_texture, combined_mask, f_xyMask
16 |
17 | def flip_iuv(iuv):
18 | POINT_LABEL_SYMMETRIES = [ 0, 1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, 23]
19 | i = iuv[:,:,0]
20 | u = iuv[:,:,1]
21 | v = iuv[:,:,2]
22 | i_old = np.copy(i)
23 | for part in range(24):
24 | if (part + 1) in i_old:
25 | annot_indices_i = i_old == (part + 1)
26 | if POINT_LABEL_SYMMETRIES[part + 1] != part + 1:
27 | i[annot_indices_i] = POINT_LABEL_SYMMETRIES[part + 1]
28 | if part == 22 or part == 23 or part == 2 or part == 3 : #head and hands
29 | u[annot_indices_i] = 255-u[annot_indices_i]
30 | if part == 0 or part == 1: # torso
31 | v[annot_indices_i] = 255-v[annot_indices_i]
32 | return np.stack([i,u,v],2)
33 |
34 | def getXYcoor(iuv, resolution = 256):
35 | x, y, u, v = mapper(iuv, resolution)
36 | # A meshgrid of pixel coordinates
37 | nx, ny = resolution, resolution
38 | X, Y = np.meshgrid(np.arange(0, nx, 1), np.arange(0, ny, 1))
39 | ## get x,y coordinates
40 | uv_y = griddata((v, u), y, (Y, X), method='linear')
41 | uv_y_ = griddata((v, u), y, (Y, X), method='nearest')
42 | uv_y[np.isnan(uv_y)] = uv_y_[np.isnan(uv_y)]
43 | uv_x = griddata((v, u), x, (Y, X), method='linear')
44 | uv_x_ = griddata((v, u), x, (Y, X), method='nearest')
45 | uv_x[np.isnan(uv_x)] = uv_x_[np.isnan(uv_x)]
46 | # get mask
47 | uv_mask = np.zeros((ny,nx))
48 | uv_mask[np.ceil(v).astype(int),np.ceil(u).astype(int)]=1
49 | uv_mask[np.floor(v).astype(int),np.floor(u).astype(int)]=1
50 | uv_mask[np.ceil(v).astype(int),np.floor(u).astype(int)]=1
51 | uv_mask[np.floor(v).astype(int),np.ceil(u).astype(int)]=1
52 | kernel = np.ones((3,3),np.uint8)
53 | uv_mask_d = cv2.dilate(uv_mask,kernel,iterations = 1)
54 | # update
55 | coor_x = uv_x * uv_mask_d
56 | coor_y = uv_y * uv_mask_d
57 | coor_xy = np.stack([coor_x, coor_y], 2)
58 | return coor_xy, uv_mask_d
59 |
60 | def mapper(iuv, resolution=256):
61 | dp_uv_lookup_256_np = np.load('util/dp_uv_lookup_256.npy')
62 | H, W, _ = iuv.shape
63 | iuv_raw = iuv[iuv[:, :, 0] > 0]
64 | x = np.linspace(0, W-1, W).astype(int)
65 | y = np.linspace(0, H-1, H).astype(int)
66 | xx, yy = np.meshgrid(x, y)
67 | xx_rgb = xx[iuv[:, :, 0] > 0]
68 | yy_rgb = yy[iuv[:, :, 0] > 0]
69 | # modify i to start from 0... 0-23
70 | i = iuv_raw[:, 0] - 1
71 | u = iuv_raw[:, 1]
72 | v = iuv_raw[:, 2]
73 | uv_smpl = dp_uv_lookup_256_np[
74 | i.astype(int),
75 | v.astype(int),
76 | u.astype(int)
77 | ]
78 | u_f = uv_smpl[:, 0] * (resolution - 1)
79 | v_f = (1 - uv_smpl[:, 1]) * (resolution - 1)
80 | return xx_rgb, yy_rgb, u_f, v_f
81 |
82 |
83 | if __name__ == "__main__":
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument('--image_file', type=str, help="path to the list file of image names to process, e.g. ./train.lst")
86 | parser.add_argument("--save_path", type=str, help="path to save the uv data")
87 | parser.add_argument("--dp_path", type=str, help="path to densepose data")
88 | args = parser.parse_args()
89 |
90 | if not os.path.exists(args.save_path):
91 | os.makedirs(args.save_path)
92 |
93 | images = []
94 | f = open(args.image_file, 'r')
95 | for lines in f:
96 | lines = lines.strip()
97 | images.append(lines)
98 |
99 | for i in range(len(images)):
100 | im_name = images[i]
101 | print ('%d/%d'%(i+1, len(images)))
102 |
103 | dp = os.path.join(args.dp_path, im_name.split('.')[0]+'_iuv.png')
104 |
105 | iuv = np.array(Image.open(dp))
106 | h, w, _ = iuv.shape
107 | if np.sum(iuv[:,:,0]==0)==(h*w):
108 | print ('no human: invalid image %d: %s'%(i, im_name))
109 | else:
110 | uv_coor, uv_mask, uv_symm_mask = getSymXYcoordinates(iuv, resolution = 512)
111 | np.save(os.path.join(args.save_path, '%s_uv_coor.npy'%(im_name.split('.')[0])), uv_coor)
112 | mask_im = Image.fromarray((uv_mask*255).astype(np.uint8))
113 | mask_im.save(os.path.join(args.save_path, im_name.split('.')[0]+'_uv_mask.png'))
114 | mask_im = Image.fromarray((uv_symm_mask*255).astype(np.uint8))
115 | mask_im.save(os.path.join(args.save_path, im_name.split('.')[0]+'_uv_symm_mask.png'))
116 |
--------------------------------------------------------------------------------
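`getSymXYcoordinates` returns a per-pixel map from UV space to image coordinates, a validity mask, and a mask marking the symmetrically completed region. A minimal usage sketch on a single DensePose IUV image; the input and output paths are illustrative, and `resolution=512` matches the `__main__` block above:
```python
import numpy as np
from PIL import Image
from util.dp2coor import getSymXYcoordinates

# Illustrative input: any *_iuv.png produced by util/pickle2png.py.
# Note: mapper() loads 'util/dp_uv_lookup_256.npy', so run this from the repository root.
iuv = np.array(Image.open('DATASET/densepose/train/example_iuv.png'))
uv_coor, uv_mask, uv_symm_mask = getSymXYcoordinates(iuv, resolution=512)
print(uv_coor.shape, uv_mask.shape)   # (512, 512, 2) (512, 512)
np.save('example_uv_coor.npy', uv_coor)
Image.fromarray((uv_mask * 255).astype(np.uint8)).save('example_uv_mask.png')
```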
/util/generate_fashion_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from PIL import Image
4 | import argparse
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--dataroot', type=str, help="path to dataset")
8 | args = parser.parse_args()
9 |
10 | IMG_EXTENSIONS = [
11 | '.jpg', '.JPG', '.jpeg', '.JPEG',
12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
13 | ]
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def make_dataset(dir):
19 | images = []
20 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
21 |
22 | train_root = os.path.join(dir, 'train')
23 | if not os.path.exists(train_root):
24 | os.mkdir(train_root)
25 |
26 | test_root = os.path.join(dir, 'test')
27 | if not os.path.exists(test_root):
28 | os.mkdir(test_root)
29 |
30 | train_images = []
31 | train_f = open(os.path.join(dir, 'tools', 'train.lst'), 'r')
32 | for lines in train_f:
33 | lines = lines.strip()
34 | if lines.endswith('.jpg'):
35 | train_images.append(lines)
36 |
37 | test_images = []
38 | test_f = open(os.path.join(dir, 'tools', 'test.lst'), 'r')
39 | for lines in test_f:
40 | lines = lines.strip()
41 | if lines.endswith('.jpg'):
42 | test_images.append(lines)
43 |
44 |
45 | for root, _, fnames in sorted(os.walk(os.path.join(dir, 'img_highres'))):
46 | for fname in fnames:
47 | if is_image_file(fname):
48 | path = os.path.join(root, fname)
49 | path_names = path.split('/')
50 | print(path_names)
51 |
52 | path_names = path_names[4:]
53 | del path_names[1]
54 | path_names[0] = 'fashion'
55 | path_names[3] = path_names[3].replace('_', '')
56 | path_names[4] = path_names[4].split('_')[0] + "_" + "".join(path_names[4].split('_')[1:])
57 | path_names = "".join(path_names)
58 |
59 | if path_names in train_images:
60 | shutil.copy(path, os.path.join(train_root, path_names))
61 | print('saving -- %s'%os.path.join(train_root, path_names))
62 | elif path_names in test_images:
63 | shutil.copy(path, os.path.join(test_root, path_names))
64 | print('saving -- %s'%os.path.join(test_root, path_names))
65 |
66 | make_dataset(args.dataroot)
67 |
--------------------------------------------------------------------------------
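The renaming loop above flattens the DeepFashion directory layout into single file names matching the entries of `train.lst`/`test.lst`. A rough sketch of the mapping under an assumed path layout (the hard-coded slice offsets in `make_dataset()` depend on how deep `--dataroot` sits on disk, but the resulting flat name follows this pattern):
```python
# Hypothetical input path; only the last four components matter for the flat name.
path = 'DATASET/DeepFashion_highres/img_highres/WOMEN/Blouses_Shirts/id_00001234/02_1_front.jpg'
gender, category, pid, fname = path.split('/')[-4:]
flat = ('fashion' + gender + category + pid.replace('_', '')
        + fname.split('_')[0] + '_' + ''.join(fname.split('_')[1:]))
print(flat)  # fashionWOMENBlouses_Shirtsid0000123402_1front.jpg
```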
/util/pickle2png.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | from PIL import Image
4 | import argparse
5 | import os
6 |
7 |
8 | def save_image(image_numpy, image_path):
9 | image_pil = Image.fromarray(image_numpy.astype(np.uint8))
10 | image_pil.save(image_path)
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--pickle_file', type=str, help="path to pickle file")
14 | parser.add_argument("--save_path", type=str, help="path to save the png images")
15 | args = parser.parse_args()
16 |
17 | # READING
18 | f = open(args.pickle_file, 'rb')
19 | data = pickle.load(f)
20 | data_size = len(data)
21 | print ('Will process %d images'%data_size)
22 |
23 | # save path
24 | if not os.path.exists(args.save_path):
25 | os.makedirs(args.save_path)
26 |
27 | for img_id in range(data_size):
28 | # assuming we always have 1 person
29 | name = data[img_id]['file_name']
30 | iuv_image_name = name.split('/')[-1].split('.')[0]+ '_iuv.png'
31 | iuv_name = os.path.join(args.save_path, iuv_image_name)
32 | size = np.array(Image.open(name).convert('RGB')).shape
33 | wrapped_iuv = np.zeros(size)
34 |
35 | print ('Processing %d/%d: %s'%(img_id+1, data_size, iuv_image_name))
36 | num_instances = len(data[img_id]['scores'])
37 | if num_instances == 0:
38 | print ('%s has no person.'%iuv_image_name)
39 | # no IUV image is written for images without a detected person
40 | else:
41 | # get results - process first detected human
42 | instance_id = 0
43 | # process highest score detected human
44 | # instance_id = data[img_id]['scores'].numpy().tolist().index(max(data[img_id]['scores'].numpy().tolist()))
45 | pred_densepose_result = data[img_id]['pred_densepose'][instance_id]
46 | bbox_xyxy = data[img_id]['pred_boxes_XYXY'][instance_id]
47 | i = pred_densepose_result.labels
48 | uv = pred_densepose_result.uv * 255
49 | iuv = np.concatenate((np.expand_dims(i, 0), uv), axis=0)
50 | # 3xhxw to hxwx3
51 | iuv_arr = np.transpose(iuv, (1, 2, 0))
52 | # wrap iuv to size of image
53 | wrapped_iuv[int(bbox_xyxy[1]):iuv_arr.shape[0]+int(bbox_xyxy[1]), int(bbox_xyxy[0]):iuv_arr.shape[1]+int(bbox_xyxy[0]), :] = iuv_arr
54 | save_image(wrapped_iuv, iuv_name)
55 |
--------------------------------------------------------------------------------
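Each saved `*_iuv.png` stores the DensePose part label in its first channel and the U, V coordinates (scaled to 0-255) in the other two, padded to the source image size via the detection box. A short sketch of reading one back; the file path is illustrative:
```python
import numpy as np
from PIL import Image

iuv = np.array(Image.open('DATASET/densepose/train/example_iuv.png'))  # illustrative path
part = iuv[:, :, 0]        # 0 = background, 1..24 = DensePose body-part label
u = iuv[:, :, 1] / 255.0   # U texture coordinate, rescaled to [0, 1]
v = iuv[:, :, 2] / 255.0   # V texture coordinate, rescaled to [0, 1]
print('pixels belonging to the detected person:', int((part > 0).sum()))
```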