├── DATASET.md ├── LICENSE ├── README.md ├── data ├── fashionWOMENBlouses_Shirtsid0000635004_1front.png ├── fashionWOMENBlouses_Shirtsid0000635004_1front_iuv.png ├── fashionWOMENBlouses_Shirtsid0000635004_1front_sil.png ├── fashionWOMENDressesid0000262902_1front_iuv.png ├── fashionWOMENDressesid0000262902_3back.png ├── fashionWOMENDressesid0000262902_3back_iuv.png ├── fashionWOMENDressesid0000262902_3back_sil.png ├── fashionWOMENSkirtsid0000177102_1front.png ├── fashionWOMENSkirtsid0000177102_1front_iuv.png └── fashionWOMENSkirtsid0000177102_1front_sil.png ├── dataset.py ├── distributed.py ├── garment_transfer.py ├── inference.py ├── model.py ├── op ├── __init__.py ├── conv2d_gradfix.py ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu ├── requirements.txt ├── sphereface.py ├── test.py ├── train.py └── util ├── __init__.py ├── complete_coor.py ├── coordinate_completion_model.py ├── dp2coor.py ├── generate_fashion_datasets.py └── pickle2png.py /DATASET.md: -------------------------------------------------------------------------------- 1 | # Pose with Style: human reposing with pose-guided StyleGAN2 2 | 3 | 4 | ## Dataset and Downloads 5 | 1. Download images: 6 | 1. Download `img_highres.zip` from [In-shop Clothes Retrieval Benchmark](https://drive.google.com/drive/folders/0B7EVK8r0v71pYkd5TzBiclMzR00?resourcekey=0-fsjVShvqXP2517KnwaZ0zw). You will need to follow the [download instructions](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html) to unzip the file. Unzip file in `DATASET/DeepFashion_highres/img_highres` 7 | 8 | 2. Download the [train/test data](https://drive.google.com/drive/folders/1BX3Bxh8KG01yKWViRY0WTyDWbJHju-SL): **train.lst**, **test.lst**, and **fashion-pairs-test.csv**. Put in `DATASET/DeepFashion_highres/tools`. Note: because not all training images had their densepose detected we used a slightly modified training pairs file [**fashion-pairs-test.csv**](https://drive.google.com/file/d/1Uxpz8yBJ53XPkJ3O2GFP3nnbZWJQYlbv/view?usp=sharing). 9 | 10 | 3. Split the train/test dataset using: 11 | ``` 12 | python util/generate_fashion_datasets.py --dataroot DATASET/DeepFashion_highres 13 | ``` 14 | This will save the train/test images in `DeepFashion_highres/train` and `DeepFashion_highres/test`. 15 | 16 | 2. Compute [densepose](https://github.com/facebookresearch/detectron2/tree/master/projects/DensePose): 17 | 1. Install detectron2 following their [installation instructions](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md). 18 | 19 | 2. Use [apply net](https://github.com/facebookresearch/detectron2/blob/master/projects/DensePose/doc/TOOL_APPLY_NET.md) from [densepose](https://github.com/facebookresearch/detectron2/tree/master/projects/DensePose) and save the train/test results to a pickle file. 20 | Make sure you download [densepose_rcnn_R_101_FPN_DL_s1x.pkl](https://github.com/facebookresearch/detectron2/blob/master/projects/DensePose/doc/DENSEPOSE_IUV.md#ModelZoo). 21 | 22 | 3. Copy `util/pickle2png.py` into your detectron2 DensePose project directory. Using the DensePose environment, convert the pickle file to densepose png images and save the results in `DATASET/densepose` directory, using: 23 | ``` 24 | python pickle2png.py --pickle_file train_output.pkl --save_path DATASET/densepose/train 25 | ``` 26 | 27 | 3. Compute [human foreground mask](https://github.com/Engineering-Course/CIHP_PGN). Save results in `silhouette` directory. 
Or you can download our computed silhouettes for the [training set](https://drive.google.com/file/d/1xXJGi5zkkTC2iIAUloylq6DlgqzAUrcw/view?usp=sharing) and [testing set](https://drive.google.com/file/d/1QdGgnBossIxsOrY8fUkJJYh1NbExIpyX/view?usp=sharing). 28 | 29 | 4. Compute UV space coordinates: 30 | 1. Compute UV space partial coordinates in the resolution 512x512. 31 | 1. Download the [UV space - 2D look up map](https://drive.google.com/file/d/1JLQ5bGl7YU-BwmdSc-DySy5Ya6FQIJBy/view?usp=sharing) and save it in `util` folder. 32 | 2. Compute partial coordinates: 33 | ``` 34 | python util/dp2coor.py --image_file DATASET/DeepFashion_highres/tools/train.lst --dp_path DATASET/densepose/train --save_path DATASET/partial_coordinates/train 35 | ``` 36 | 37 | 2. Complete the UV space coordinates offline, for faster training. 38 | 1. Download the pretrained coordinate completion model from [here](https://drive.google.com/file/d/1Tck_NzJ4ifT76csEShOtlRK7HpfjFhHP/view?usp=sharing). 39 | 2. Complete the partial coordinates. 40 | ``` 41 | python util/complete_coor.py --dataroot DATASET/DeepFashion_highres --coordinates_path DATASET/partial_coordinates --image_file DATASET/DeepFashion_highres/tools/train.lst --phase train --save_path DATASET/complete_coordinates --pretrained_model /path/to/CCM_epoch50.pt 42 | ``` 43 | 44 | 5. Download the following in `DATASET/resources`, to apply Face Identity loss: 45 | 1. download the pre-computed [required transformation (T)](https://drive.google.com/file/d/1r5ODZr1ewZk95Mdsmv-mGdW6lit3MNRc/view?usp=sharing) to align and crop the face. 46 | 2. Download [sphereface net pretrained model](https://drive.google.com/file/d/1p_cBfPZwwhWsWDXdKJ3n9VTp0_qsXZF0/view?usp=sharing). 47 | 48 | Note: we provide the DeepFashion train/test split of [StylePoseGAN](https://people.mpi-inf.mpg.de/~ksarkar/styleposegan/) [Sarkar et al. 2021]: [train pairs](https://drive.google.com/file/d/1ZaoQmUS92zHtqCWvsyunAaORlosLDVMm/view?usp=sharing), and [test pairs](https://drive.google.com/file/d/125EK9Y2QFMYMf8WV2_BoXIwLKz0uNxqa/view?usp=sharing). 49 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Badour AlBahar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pose with Style: Detail-Preserving Pose-Guided Image Synthesis with Conditional StyleGAN 2 | ### [[Paper](https://pose-with-style.github.io/asset/paper.pdf)] [[Project Website](https://pose-with-style.github.io/)] [[Output results](https://pose-with-style.github.io/results.html)] 3 | 4 | Official PyTorch implementation of **Pose with Style: Detail-Preserving Pose-Guided Image Synthesis with Conditional StyleGAN**. Please contact Badour AlBahar (badour@vt.edu) if you have any questions. 5 | 6 |

7 | 8 |

9 | 10 | 11 | ## Requirements 12 | ``` 13 | conda create -n pws python=3.8 14 | conda activate pws 15 | conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch 16 | pip install -r requirements.txt 17 | ``` 18 | Install OpenCV using `conda install -c conda-forge opencv` or `pip install opencv-python`. 19 | If you would like to use [wandb](https://wandb.ai/site), install it using `pip install wandb`. 20 | 21 | ## Download pretrained models 22 | You can download the pretrained model [here](https://drive.google.com/file/d/1IcIJMTHA-_-_qcjRIrnSUnQf4n8GmtPd/view?usp=sharing), and the pretrained coordinate completion model [here](https://drive.google.com/file/d/1Tck_NzJ4ifT76csEShOtlRK7HpfjFhHP/view?usp=sharing). 23 | 24 | Note: we also provide the pretrained model trained on the [StylePoseGAN](https://people.mpi-inf.mpg.de/~ksarkar/styleposegan/) [Sarkar et al. 2021] DeepFashion train/test split [here](https://drive.google.com/file/d/1DpOQ1Z7JEls3kCYaHMg1CpgjxbTJB6hg/view?usp=sharing). We also provide this split's pretrained coordinate completion model [here](https://drive.google.com/file/d/1oz-g0T5HLW-hVgeqYC55eI10DA-59Yp3/view?usp=sharing). 25 | 26 | ## Reposing 27 | Download the [UV space - 2D look up map](https://drive.google.com/file/d/1JLQ5bGl7YU-BwmdSc-DySy5Ya6FQIJBy/view?usp=sharing) and save it in the `util` folder. 28 | 29 | We provide sample data in the `data` directory. The output will be saved in the `data/output` directory. 30 | ``` 31 | python inference.py --input_path ./data --CCM_pretrained_model path/to/CCM_epoch50.pt --pretrained_model path/to/posewithstyle.pt 32 | ``` 33 | 34 | To repose your own images, put the input image (input_name+'.png'), dense pose (input_name+'_iuv.png'), and silhouette (input_name+'_sil.png'), as well as the target dense pose (target_name+'_iuv.png'), in the `data` directory. 35 | ``` 36 | python inference.py --input_path ./data --input_name fashionWOMENDressesid0000262902_3back --target_name fashionWOMENDressesid0000262902_1front --CCM_pretrained_model path/to/CCM_epoch50.pt --pretrained_model path/to/posewithstyle.pt 37 | ``` 38 | 39 | ## Garment transfer 40 | Download the [UV space - 2D look up map](https://drive.google.com/file/d/1JLQ5bGl7YU-BwmdSc-DySy5Ya6FQIJBy/view?usp=sharing) and the [UV space body part segmentation](https://drive.google.com/file/d/179zoQVVrEgFbkEmHwLHU1AAqPktH-guU/view?usp=sharing). Save both in the `util` folder. 41 | The UV space body part segmentation provides a generic segmentation of the human body. Alternatively, you can specify your own mask of the region you want to transfer. 42 | 43 | We provide sample data in the `data` directory. The output will be saved in the `data/output` directory. 44 | ``` 45 | python garment_transfer.py --input_path ./data --CCM_pretrained_model path/to/CCM_epoch50.pt --pretrained_model path/to/posewithstyle.pt --part upper_body 46 | ``` 47 | 48 | To use your own images, put the input image (input_name+'.png'), dense pose (input_name+'_iuv.png'), and silhouette (input_name+'_sil.png'), as well as the garment source image (target_name+'.png'), its dense pose (target_name+'_iuv.png'), and its silhouette (target_name+'_sil.png'), in the `data` directory. You can specify the part to transfer using `--part` as `upper_body`, `lower_body`, `full_body`, or `face`. The output, as well as the transferred part (shown in red), will be saved in the `data/output` directory. 
49 | ``` 50 | python garment_transfer.py --input_path ./data --input_name fashionWOMENSkirtsid0000177102_1front --target_name fashionWOMENBlouses_Shirtsid0000635004_1front --CCM_pretrained_model path/to/CCM_epoch50.pt --pretrained_model path/to/posewithstyle.pt --part upper_body 51 | ``` 52 | 53 | ## DeepFashion Dataset 54 | To train or test, you must download and process the dataset. Please follow instructions in [Dataset and Downloads](https://github.com/BadourAlBahar/pose-with-style/blob/main/DATASET.md). 55 | 56 | You should have the following downloaded in your `DATASET` folder: 57 | ``` 58 | DATASET/DeepFashion_highres 59 | - train 60 | - test 61 | - tools 62 | - train.lst 63 | - test.lst 64 | - fashion-pairs-train.csv 65 | - fashion-pairs-test.csv 66 | 67 | DATASET/densepose 68 | - train 69 | - test 70 | 71 | DATASET/silhouette 72 | - train 73 | - test 74 | 75 | DATASET/partial_coordinates 76 | - train 77 | - test 78 | 79 | DATASET/complete_coordinates 80 | - train 81 | - test 82 | 83 | DATASET/resources 84 | - train_face_T.pickle 85 | - sphere20a_20171020.pth 86 | ``` 87 | 88 | 89 | ## Training 90 | Step 1: First, train the reposing model by focusing on generating the foreground. 91 | We set the batch size to 1 and train for 50 epochs. This training process takes around 7 days on 8 NVIDIA 2080 Ti GPUs. 92 | ``` 93 | python -m torch.distributed.launch --nproc_per_node=8 --master_port XXXX train.py --batch 1 /path/to/DATASET --name exp_name_step1 --size 512 --faceloss --epoch 50 94 | ``` 95 | The checkpoints will be saved in `checkpoint/exp_name`. 96 | 97 | Step 2: Then, finetune the model by training on the entire image (only masking the padded boundary). 98 | We set the batch size to 8 and train for 10 epochs. This training process takes less than 2 days on 2 A100 GPUs. 99 | ``` 100 | python -m torch.distributed.launch --nproc_per_node=2 --master_port XXXX train.py --batch 8 /path/to/DATASET --name exp_name_step2 --size 512 --faceloss --epoch 10 --ckpt /path/to/step1/pretrained/model --finetune 101 | ``` 102 | 103 | ## Testing 104 | To test the reposing model and generate the reposing results: 105 | ``` 106 | python test.py /path/to/DATASET --pretrained_model /path/to/step2/pretrained/model --size 512 --save_path /path/to/save/output 107 | ``` 108 | Output images will be saved in `--save_path`. 109 | 110 | You can find our reposing output images [here](https://pose-with-style.github.io/results.html). 111 | 112 | ## Evaluation 113 | We follow the same evaluation code as [Global-Flow-Local-Attention](https://github.com/RenYurui/Global-Flow-Local-Attention/blob/master/PERSON_IMAGE_GENERATION.md#evaluation). 114 | 115 | 116 | ## Bibtex 117 | Please consider citing our work if you find it useful for your research: 118 | 119 | @article{albahar2021pose, 120 | title = {Pose with {S}tyle: {D}etail-Preserving Pose-Guided Image Synthesis with Conditional StyleGAN}, 121 | author = {AlBahar, Badour and Lu, Jingwan and Yang, Jimei and Shu, Zhixin and Shechtman, Eli and Huang, Jia-Bin}, 122 | journal = {ACM Transactions on Graphics}, 123 | year = {2021} 124 | } 125 | 126 | 127 | ## Acknowledgments 128 | This code is heavily borrowed from [Rosinality: StyleGAN 2 in PyTorch](https://github.com/rosinality/stylegan2-pytorch). 
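As a closing note on the DeepFashion Dataset section above: before launching the training or testing commands, it can help to confirm that the `DATASET` folder actually matches the listed layout. Below is a minimal sketch of such a check (a hypothetical helper, not a file in this repository; the expected paths simply mirror the listing above).
```
# check_dataset.py (hypothetical helper, not part of the repository):
# verify that the DATASET folder matches the layout listed in the README.
import os
import sys

EXPECTED = [
    "DeepFashion_highres/train",
    "DeepFashion_highres/test",
    "DeepFashion_highres/tools/train.lst",
    "DeepFashion_highres/tools/test.lst",
    "DeepFashion_highres/tools/fashion-pairs-train.csv",
    "DeepFashion_highres/tools/fashion-pairs-test.csv",
    "densepose/train", "densepose/test",
    "silhouette/train", "silhouette/test",
    "partial_coordinates/train", "partial_coordinates/test",
    "complete_coordinates/train", "complete_coordinates/test",
    "resources/train_face_T.pickle",
    "resources/sphere20a_20171020.pth",
]

def check_dataset(root):
    # report anything that is missing and return True only if all paths exist
    missing = [p for p in EXPECTED if not os.path.exists(os.path.join(root, p))]
    for p in missing:
        print("missing:", os.path.join(root, p))
    return not missing

if __name__ == "__main__":
    root = sys.argv[1] if len(sys.argv) > 1 else "DATASET"
    sys.exit(0 if check_dataset(root) else 1)
```
For example, `python check_dataset.py /path/to/DATASET` exits non-zero and lists whatever is still missing.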
129 | -------------------------------------------------------------------------------- /data/fashionWOMENBlouses_Shirtsid0000635004_1front.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENBlouses_Shirtsid0000635004_1front.png -------------------------------------------------------------------------------- /data/fashionWOMENBlouses_Shirtsid0000635004_1front_iuv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENBlouses_Shirtsid0000635004_1front_iuv.png -------------------------------------------------------------------------------- /data/fashionWOMENBlouses_Shirtsid0000635004_1front_sil.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENBlouses_Shirtsid0000635004_1front_sil.png -------------------------------------------------------------------------------- /data/fashionWOMENDressesid0000262902_1front_iuv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENDressesid0000262902_1front_iuv.png -------------------------------------------------------------------------------- /data/fashionWOMENDressesid0000262902_3back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENDressesid0000262902_3back.png -------------------------------------------------------------------------------- /data/fashionWOMENDressesid0000262902_3back_iuv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENDressesid0000262902_3back_iuv.png -------------------------------------------------------------------------------- /data/fashionWOMENDressesid0000262902_3back_sil.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENDressesid0000262902_3back_sil.png -------------------------------------------------------------------------------- /data/fashionWOMENSkirtsid0000177102_1front.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENSkirtsid0000177102_1front.png -------------------------------------------------------------------------------- /data/fashionWOMENSkirtsid0000177102_1front_iuv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENSkirtsid0000177102_1front_iuv.png -------------------------------------------------------------------------------- /data/fashionWOMENSkirtsid0000177102_1front_sil.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/data/fashionWOMENSkirtsid0000177102_1front_sil.png -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torch.utils.data import Dataset 3 | import torchvision.transforms as transforms 4 | import os 5 | import pandas as pd 6 | import numpy as np 7 | import torch 8 | import random 9 | import pickle 10 | 11 | 12 | class DeepFashionDataset(Dataset): 13 | def __init__(self, path, phase, size): 14 | self.phase = phase # train or test 15 | self.size = size # 256 or 512 FOR 174x256 or 348x512 16 | 17 | # set root directories 18 | self.image_root = os.path.join(path, 'DeepFashion_highres', phase) 19 | self.densepose_root = os.path.join(path, 'densepose', phase) 20 | self.parsing_root = os.path.join(path, 'silhouette', phase) 21 | # path to pairs of data 22 | pairs_csv_path = os.path.join(path, 'DeepFashion_highres', 'tools', 'fashion-pairs-%s.csv'%phase) 23 | 24 | # uv space 25 | self.uv_root = os.path.join(path, 'complete_coordinates', phase) 26 | 27 | # initialize the pairs of data 28 | self.init_pairs(pairs_csv_path) 29 | self.data_size = len(self.pairs) 30 | print('%s data pairs (#=%d)...'%(phase, self.data_size)) 31 | 32 | if phase == 'train': 33 | # get dictionary of image name and transfrom to detect and align the face 34 | with open(os.path.join(path, 'resources', 'train_face_T.pickle'), 'rb') as handle: 35 | self.faceTransform = pickle.load(handle) 36 | 37 | self.transform = transforms.Compose([transforms.ToTensor(), 38 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 39 | 40 | 41 | def init_pairs(self, pairs_csv_path): 42 | pairs_file = pd.read_csv(pairs_csv_path) 43 | self.pairs = [] 44 | self.sources = {} 45 | print('Loading data pairs ...') 46 | for i in range(len(pairs_file)): 47 | pair = [pairs_file.iloc[i]['from'], pairs_file.iloc[i]['to']] 48 | self.pairs.append(pair) 49 | print('Loading data pairs finished ...') 50 | 51 | 52 | def __len__(self): 53 | return self.data_size 54 | 55 | 56 | def resize_height_PIL(self, x, height=512): 57 | w, h = x.size 58 | width = int(height * w / h) 59 | return x.resize((width, height), Image.NEAREST) #Image.ANTIALIAS 60 | 61 | 62 | def resize_PIL(self, x, height=512, width=348, type=Image.NEAREST): 63 | return x.resize((width, height), type) 64 | 65 | 66 | def tensors2square(self, im, pose, sil): 67 | width = im.shape[2] 68 | diff = self.size - width 69 | if self.phase == 'train': 70 | left = random.randint(0, diff) 71 | right = diff - left 72 | else: # when testing put in the center 73 | left = int((self.size-width)/2) 74 | right = diff - left 75 | im = torch.nn.functional.pad(input=im, pad=(right, left, 0, 0), mode='constant', value=0) 76 | pose = torch.nn.functional.pad(input=pose, pad=(right, left, 0, 0), mode='constant', value=0) 77 | sil = torch.nn.functional.pad(input=sil, pad=(right, left, 0, 0), mode='constant', value=0) 78 | return im, pose, sil, left, right 79 | 80 | 81 | def __getitem__(self, index): 82 | # get current pair 83 | im1_name, im2_name = self.pairs[index] 84 | 85 | # get path to dataset 86 | input_image_path = os.path.join(self.image_root, im1_name) 87 | target_image_path = os.path.join(self.image_root, im2_name) 88 | # dense pose 89 | input_densepose_path = 
os.path.join(self.densepose_root, im1_name.split('.')[0]+'_iuv.png') 90 | target_densepose_path = os.path.join(self.densepose_root, im2_name.split('.')[0]+'_iuv.png') 91 | # silhouette 92 | input_sil_path = os.path.join(self.parsing_root, im1_name.split('.')[0]+'_sil.png') 93 | target_sil_path = os.path.join(self.parsing_root, im2_name.split('.')[0]+'_sil.png') 94 | # uv space 95 | complete_coor_path = os.path.join(self.uv_root, im1_name.split('.')[0]+'_uv_coor.npy') 96 | 97 | # read data 98 | # get original size of data -> for augmentation 99 | input_image_pil = Image.open(input_image_path).convert('RGB') 100 | orig_w, orig_h = input_image_pil.size 101 | if self.phase == 'test': 102 | # set target height and target width 103 | if self.size == 512: 104 | target_h = 512 105 | target_w = 348 106 | if self.size == 256: 107 | target_h = 256 108 | target_w = 174 109 | # images 110 | input_image = self.resize_PIL(input_image_pil, height=target_h, width=target_w, type=Image.ANTIALIAS) 111 | target_image = self.resize_PIL(Image.open(target_image_path).convert('RGB'), height=target_h, width=target_w, type=Image.ANTIALIAS) 112 | # dense pose 113 | input_densepose = np.array(self.resize_PIL(Image.open(input_densepose_path), height=target_h, width=target_w)) 114 | target_densepose = np.array(self.resize_PIL(Image.open(target_densepose_path), height=target_h, width=target_w)) 115 | # silhouette 116 | silhouette1 = np.array(self.resize_PIL(Image.open(input_sil_path), height=target_h, width=target_w))/255 117 | silhouette2 = np.array(self.resize_PIL(Image.open(target_sil_path), height=target_h, width=target_w))/255 118 | # union with densepose mask for a more accurate mask 119 | silhouette1 = 1-((1-silhouette1) * (input_densepose[:, :, 0] == 0).astype('float')) 120 | 121 | else: 122 | input_image = self.resize_height_PIL(input_image_pil, self.size) 123 | target_image = self.resize_height_PIL(Image.open(target_image_path).convert('RGB'), self.size) 124 | # dense pose 125 | input_densepose = np.array(self.resize_height_PIL(Image.open(input_densepose_path), self.size)) 126 | target_densepose = np.array(self.resize_height_PIL(Image.open(target_densepose_path), self.size)) 127 | # silhouette 128 | silhouette1 = np.array(self.resize_height_PIL(Image.open(input_sil_path), self.size))/255 129 | silhouette2 = np.array(self.resize_height_PIL(Image.open(target_sil_path), self.size))/255 130 | # union with densepose masks 131 | silhouette1 = 1-((1-silhouette1) * (input_densepose[:, :, 0] == 0).astype('float')) 132 | silhouette2 = 1-((1-silhouette2) * (target_densepose[:, :, 0] == 0).astype('float')) 133 | 134 | # read uv-space data 135 | complete_coor = np.load(complete_coor_path) 136 | 137 | # Transform 138 | input_image = self.transform(input_image) 139 | target_image = self.transform(target_image) 140 | # Dense Pose 141 | input_densepose = torch.from_numpy(input_densepose).permute(2, 0, 1) 142 | target_densepose = torch.from_numpy(target_densepose).permute(2, 0, 1) 143 | # silhouette 144 | silhouette1 = torch.from_numpy(silhouette1).float().unsqueeze(0) # from h,w to c,h,w 145 | silhouette2 = torch.from_numpy(silhouette2).float().unsqueeze(0) # from h,w to c,h,w 146 | 147 | # put into a square 148 | input_image, input_densepose, silhouette1, Sleft, Sright = self.tensors2square(input_image, input_densepose, silhouette1) 149 | target_image, target_densepose, silhouette2, Tleft, Tright = self.tensors2square(target_image, target_densepose, silhouette2) 150 | 151 | if self.phase == 'train': 152 | # remove loaded 
center shift and add augmentation shift 153 | loaded_shift = int((orig_h-orig_w)/2) 154 | complete_coor = ((complete_coor+1)/2)*(orig_h-1) # [-1, 1] to [0, orig_h] 155 | complete_coor[:,:,0] = complete_coor[:,:,0] - loaded_shift # remove center shift 156 | complete_coor = ((2*complete_coor/(orig_h-1))-1) # [0, orig_h] (no shift in w) to [-1, 1] 157 | complete_coor = ((complete_coor+1)/2) * (self.size-1) # [-1, 1] to [0, size] (no shift in w) 158 | complete_coor[:,:,0] = complete_coor[:,:,0] + Sright # add augmentation shift to w 159 | complete_coor = ((2*complete_coor/(self.size-1))-1) # [0, size] (with shift in w) to [-1,1] 160 | # to tensor 161 | complete_coor = torch.from_numpy(complete_coor).float().permute(2, 0, 1) 162 | else: 163 | # might have hxw inconsistencies since dp is of different sizes.. fixing this.. 164 | loaded_shift = int((orig_h-orig_w)/2) 165 | complete_coor = ((complete_coor+1)/2)*(orig_h-1) # [-1, 1] to [0, orig_h] 166 | complete_coor[:,:,0] = complete_coor[:,:,0] - loaded_shift # remove center shift 167 | # before: width complete_coor[:,:,0] 0-orig_w-1 168 | # and height complete_coor[:,:,1] 0-orig_h-1 169 | complete_coor[:,:,0] = (complete_coor[:,:,0]/(orig_w-1))*(target_w-1) 170 | complete_coor[:,:,1] = (complete_coor[:,:,1]/(orig_h-1))*(target_h-1) 171 | complete_coor[:,:,0] = complete_coor[:,:,0] + Sright # add center shift to w 172 | complete_coor = ((2*complete_coor/(self.size-1))-1) # [0, size] (with shift in w) to [-1,1] 173 | # to tensor 174 | complete_coor = torch.from_numpy(complete_coor).float().permute(2, 0, 1) 175 | 176 | # either source or target pass 1:5 177 | if self.phase == 'train': 178 | choice = random.randint(0, 6) 179 | if choice == 0: 180 | # source pass 181 | target_im = input_image 182 | target_p = input_densepose 183 | target_sil = silhouette1 184 | target_image_name = im1_name 185 | target_left_pad = Sleft 186 | target_right_pad = Sright 187 | else: 188 | # target pass 189 | target_im = target_image 190 | target_p = target_densepose 191 | target_sil = silhouette2 192 | target_image_name = im2_name 193 | target_left_pad = Tleft 194 | target_right_pad = Tright 195 | else: 196 | target_im = target_image 197 | target_p = target_densepose 198 | target_sil = silhouette2 199 | target_image_name = im2_name 200 | target_left_pad = Tleft 201 | target_right_pad = Tright 202 | 203 | # Get the face transfrom 204 | if self.phase == 'train': 205 | if target_image_name in self.faceTransform.keys(): 206 | FT = torch.from_numpy(self.faceTransform[target_image_name]).float() 207 | else: # no face detected 208 | FT = torch.zeros((3,3)) 209 | 210 | # return data 211 | if self.phase == 'train': 212 | return {'input_image':input_image, 'target_image':target_im, 213 | 'target_sil': target_sil, 214 | 'target_pose':target_p, 215 | 'TargetFaceTransform': FT, 'target_left_pad':torch.tensor(target_left_pad), 'target_right_pad':torch.tensor(target_right_pad), 216 | 'input_sil': silhouette1, 'complete_coor':complete_coor, 217 | } 218 | 219 | if self.phase == 'test': 220 | save_name = im1_name.split('.')[0] + '_2_' + im2_name.split('.')[0] + '_vis.png' 221 | return {'input_image':input_image, 'target_image':target_im, 222 | 'target_sil': target_sil, 223 | 'target_pose':target_p, 224 | 'target_left_pad':torch.tensor(target_left_pad), 'target_right_pad':torch.tensor(target_right_pad), 225 | 'input_sil': silhouette1, 'complete_coor':complete_coor, 226 | 'save_name':save_name, 227 | } 228 | -------------------------------------------------------------------------------- 
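A minimal usage sketch for `DeepFashionDataset` above (hypothetical, not a file in this repository), assuming the preprocessed `DATASET` layout from DATASET.md is in place (including `resources/train_face_T.pickle` for the train phase) and `size=512`; the dictionary keys follow `__getitem__`:
```
# Hypothetical usage sketch, not part of the repository: iterate DeepFashionDataset
# on its own, outside of train.py. "/path/to/DATASET" is a placeholder.
from torch.utils.data import DataLoader

from dataset import DeepFashionDataset

if __name__ == "__main__":
    # path is the DATASET root, phase is 'train' or 'test', size is 256 or 512
    dataset = DeepFashionDataset(path="/path/to/DATASET", phase="train", size=512)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

    batch = next(iter(loader))
    # keys returned by __getitem__ in the train phase
    print(batch["input_image"].shape)    # source RGB in [-1, 1], padded to 512x512
    print(batch["target_pose"].shape)    # target DensePose IUV, padded to 512x512
    print(batch["target_sil"].shape)     # target silhouette, 1 channel
    print(batch["complete_coor"].shape)  # completed UV-space coordinates, 2 channels
```
Note that during training `__getitem__` randomly performs either a source or a target pass, so repeated draws of the same index can return different target tensors.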
/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def reduce_sum(tensor): 45 | if not dist.is_available(): 46 | return tensor 47 | 48 | if not dist.is_initialized(): 49 | return tensor 50 | 51 | tensor = tensor.clone() 52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 53 | 54 | return tensor 55 | 56 | 57 | def gather_grad(params): 58 | world_size = get_world_size() 59 | 60 | if world_size == 1: 61 | return 62 | 63 | for param in params: 64 | if param.grad is not None: 65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 66 | param.grad.data.div_(world_size) 67 | 68 | 69 | def all_gather(data): 70 | world_size = get_world_size() 71 | 72 | if world_size == 1: 73 | return [data] 74 | 75 | buffer = pickle.dumps(data) 76 | storage = torch.ByteStorage.from_buffer(buffer) 77 | tensor = torch.ByteTensor(storage).to('cuda') 78 | 79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 81 | dist.all_gather(size_list, local_size) 82 | size_list = [int(size.item()) for size in size_list] 83 | max_size = max(size_list) 84 | 85 | tensor_list = [] 86 | for _ in size_list: 87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 88 | 89 | if local_size != max_size: 90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 91 | tensor = torch.cat((tensor, padding), 0) 92 | 93 | dist.all_gather(tensor_list, tensor) 94 | 95 | data_list = [] 96 | 97 | for size, tensor in zip(size_list, tensor_list): 98 | buffer = tensor.cpu().numpy().tobytes()[:size] 99 | data_list.append(pickle.loads(buffer)) 100 | 101 | return data_list 102 | 103 | 104 | def reduce_loss_dict(loss_dict): 105 | world_size = get_world_size() 106 | 107 | if world_size < 2: 108 | return loss_dict 109 | 110 | with torch.no_grad(): 111 | keys = [] 112 | losses = [] 113 | 114 | for k in sorted(loss_dict.keys()): 115 | keys.append(k) 116 | losses.append(loss_dict[k]) 117 | 118 | losses = torch.stack(losses, 0) 119 | dist.reduce(losses, dst=0) 120 | 121 | if dist.get_rank() == 0: 122 | losses /= world_size 123 | 124 | reduced_losses = {k: v for k, v in zip(keys, losses)} 125 | 126 | return reduced_losses 127 | -------------------------------------------------------------------------------- /garment_transfer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from torchvision import utils 5 | from tqdm import tqdm 6 | from torch.utils import data 7 | import numpy as np 8 | import random 9 | from PIL import Image 10 | import torchvision.transforms as transforms 11 | from dataset import DeepFashionDataset 12 | from model import Generator 13 | from util.dp2coor import 
getSymXYcoordinates 14 | from util.coordinate_completion_model import define_G as define_CCM 15 | 16 | 17 | def tensors2square(im, pose, sil): 18 | width = im.shape[2] 19 | diff = args.size - width 20 | left = int((args.size-width)/2) 21 | right = diff - left 22 | im = torch.nn.functional.pad(input=im, pad=(right, left, 0, 0), mode='constant', value=0) 23 | pose = torch.nn.functional.pad(input=pose, pad=(right, left, 0, 0), mode='constant', value=0) 24 | sil = torch.nn.functional.pad(input=sil, pad=(right, left, 0, 0), mode='constant', value=0) 25 | return im, pose, sil 26 | 27 | def tensor2square(x): 28 | width = x.shape[2] 29 | diff = args.size - width 30 | left = int((args.size-width)/2) 31 | right = diff - left 32 | x = torch.nn.functional.pad(input=x, pad=(right, left, 0, 0), mode='constant', value=0) 33 | return x 34 | 35 | def generate(args, g_ema, device, mean_latent): 36 | with torch.no_grad(): 37 | g_ema.eval() 38 | 39 | path = args.input_path 40 | input_name = args.input_name 41 | target_name = args.target_name 42 | part = args.part 43 | 44 | # input 45 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 46 | input_image = Image.open(os.path.join(path, input_name+'.png')).convert('RGB') 47 | iw, ih = input_image.size 48 | input_image = transform(input_image).float().to(device) 49 | input_pose = np.array(Image.open(os.path.join(path, input_name+'_iuv.png'))) 50 | input_sil = np.array(Image.open(os.path.join(path, input_name+'_sil.png')))/255 51 | # get partial coordinates from dense pose 52 | dp_uv_lookup_256_np = np.load('util/dp_uv_lookup_256.npy') 53 | input_uv_coor, input_uv_mask, input_uv_symm_mask = getSymXYcoordinates(input_pose, resolution = 512) 54 | # union sil with densepose masks 55 | input_sil = 1-((1-input_sil) * (input_pose[:, :, 0] == 0).astype('float')) 56 | input_sil = torch.from_numpy(input_sil).float().unsqueeze(0) 57 | input_pose = torch.from_numpy(input_pose).permute(2, 0, 1) 58 | 59 | # target 60 | target_image = Image.open(os.path.join(path, target_name+'.png')).convert('RGB') 61 | tw, th = target_image.size 62 | target_image = transform(target_image).float().to(device) 63 | target_pose = np.array(Image.open(os.path.join(path, target_name+'_iuv.png'))) 64 | target_sil = np.array(Image.open(os.path.join(path, target_name+'_sil.png')))/255 65 | # get partial coordinates from dense pose 66 | target_uv_coor, target_uv_mask, target_uv_symm_mask = getSymXYcoordinates(target_pose, resolution = 512) 67 | # union sil with densepose masks 68 | target_sil = 1-((1-target_sil) * (target_pose[:, :, 0] == 0).astype('float')) 69 | target_sil = torch.from_numpy(target_sil).float().unsqueeze(0) 70 | target_pose = torch.from_numpy(target_pose).permute(2, 0, 1) 71 | 72 | # convert to square by centering 73 | input_image, input_pose, input_sil = tensors2square(input_image, input_pose, input_sil) 74 | target_image, target_pose, target_sil = tensors2square(target_image, target_pose, target_sil) 75 | 76 | # add batch dimension 77 | input_image = input_image.unsqueeze(0).float().to(device) 78 | input_pose = input_pose.unsqueeze(0).float().to(device) 79 | input_sil = input_sil.unsqueeze(0).float().to(device) 80 | target_image = target_image.unsqueeze(0).float().to(device) 81 | target_pose = target_pose.unsqueeze(0).float().to(device) 82 | target_sil = target_sil.unsqueeze(0).float().to(device) 83 | 84 | # complete partial coordinates 85 | coor_completion_generator = define_CCM().cuda() 86 | CCM_checkpoint = 
torch.load(args.CCM_pretrained_model) 87 | coor_completion_generator.load_state_dict(CCM_checkpoint["g"]) 88 | coor_completion_generator.eval() 89 | for param in coor_completion_generator.parameters(): 90 | coor_completion_generator.requires_grad = False 91 | 92 | # uv coor preprocessing (put image in center) 93 | # input 94 | ishift = int((ih-iw)/2) # center shift 95 | input_uv_coor[:,:,0] = input_uv_coor[:,:,0] + ishift # put in center 96 | input_uv_coor = ((2*input_uv_coor/(ih-1))-1) 97 | input_uv_coor = input_uv_coor*np.expand_dims(input_uv_mask,2) + (-10*(1-np.expand_dims(input_uv_mask,2))) 98 | # target 99 | tshift = int((th-tw)/2) # center shift 100 | target_uv_coor[:,:,0] = target_uv_coor[:,:,0] + tshift # put in center 101 | target_uv_coor = ((2*target_uv_coor/(th-1))-1) 102 | target_uv_coor = target_uv_coor*np.expand_dims(target_uv_mask,2) + (-10*(1-np.expand_dims(target_uv_mask,2))) 103 | 104 | # coordinate completion 105 | # input 106 | uv_coor_pytorch = torch.from_numpy(input_uv_coor).float().permute(2, 0, 1).unsqueeze(0) # from h,w,c to 1,c,h,w 107 | uv_mask_pytorch = torch.from_numpy(input_uv_mask).unsqueeze(0).unsqueeze(0).float() #1xchw 108 | with torch.no_grad(): 109 | coor_completion_generator.eval() 110 | input_complete_coor = coor_completion_generator(uv_coor_pytorch.cuda(), uv_mask_pytorch.cuda()) 111 | # target 112 | uv_coor_pytorch = torch.from_numpy(target_uv_coor).float().permute(2, 0, 1).unsqueeze(0) # from h,w,c to 1,c,h,w 113 | uv_mask_pytorch = torch.from_numpy(target_uv_mask).unsqueeze(0).unsqueeze(0).float() #1xchw 114 | with torch.no_grad(): 115 | coor_completion_generator.eval() 116 | target_complete_coor = coor_completion_generator(uv_coor_pytorch.cuda(), uv_mask_pytorch.cuda()) 117 | 118 | 119 | # garment transfer 120 | appearance = torch.cat([input_image, input_sil, input_complete_coor, target_image, target_sil, target_complete_coor], 1) 121 | output, part_mask = g_ema(appearance=appearance, pose=input_pose) 122 | 123 | # visualize the transfered part 124 | zeros = torch.zeros(part_mask.shape).to(part_mask) 125 | ones255 = torch.ones(part_mask.shape).to(part_mask)*255 126 | part_red = torch.cat([part_mask*255, zeros, zeros], 1) 127 | part_img = part_red * ((input_pose[:, 0, :, :] != 0)) + torch.cat([ones255, ones255, ones255], 1)*(1-part_mask) 128 | 129 | utils.save_image( 130 | output[:, :, :, int(ishift):args.size-int(ishift)], 131 | os.path.join(args.save_path, input_name+'_and_'+target_name+'_'+args.part+'_vis.png'), 132 | nrow=1, 133 | normalize=True, 134 | range=(-1, 1), 135 | ) 136 | utils.save_image( 137 | part_img[:, :, :, int(ishift):args.size-int(ishift)], 138 | os.path.join(args.save_path, input_name+'_and_'+target_name+'_'+args.part+'.png'), 139 | nrow=1, 140 | normalize=True, 141 | range=(0, 255), 142 | ) 143 | 144 | 145 | 146 | if __name__ == "__main__": 147 | device = "cuda" 148 | 149 | parser = argparse.ArgumentParser(description="inference") 150 | 151 | parser.add_argument("--input_path", type=str, help="path to the input dataset") 152 | parser.add_argument("--input_name", type=str, default="fashionWOMENSkirtsid0000177102_1front", help="input file name") 153 | parser.add_argument("--target_name", type=str, default="fashionWOMENBlouses_Shirtsid0000635004_1front", help="target file name") 154 | parser.add_argument("--part", type=str, default="upper_body", help="body part to transfer upper_body, lower_body, full_body, and face") 155 | parser.add_argument("--size", type=int, default=512, help="output image size of the generator") 156 | 
parser.add_argument("--truncation", type=float, default=1, help="truncation ratio") 157 | parser.add_argument("--truncation_mean", type=int, default=4096, help="number of vectors to calculate mean for the truncation") 158 | parser.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier of the generator. config-f = 2, else = 1") 159 | parser.add_argument("--pretrained_model", type=str, default="posewithstyle.pt", help="pose with style pretrained model") 160 | parser.add_argument("--CCM_pretrained_model", type=str, default="CCM_epoch50.pt", help="pretrained coordinate completion model") 161 | parser.add_argument("--save_path", type=str, default="./data/output", help="path to save output .data/output") 162 | 163 | args = parser.parse_args() 164 | 165 | args.latent = 2048 166 | args.n_mlp = 8 167 | 168 | if not os.path.exists(args.save_path): 169 | os.makedirs(args.save_path) 170 | 171 | g_ema = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier, garment_transfer=True, part=args.part).to(device) 172 | checkpoint = torch.load(args.pretrained_model) 173 | g_ema.load_state_dict(checkpoint["g_ema"]) 174 | 175 | if args.truncation < 1: 176 | with torch.no_grad(): 177 | mean_latent = g_ema.mean_latent(args.truncation_mean) 178 | else: 179 | mean_latent = None 180 | 181 | generate(args, g_ema, device, mean_latent) 182 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from torchvision import utils 5 | from tqdm import tqdm 6 | from torch.utils import data 7 | import numpy as np 8 | import random 9 | from PIL import Image 10 | import torchvision.transforms as transforms 11 | from dataset import DeepFashionDataset 12 | from model import Generator 13 | from util.dp2coor import getSymXYcoordinates 14 | from util.coordinate_completion_model import define_G as define_CCM 15 | 16 | def tensors2square(im, pose, sil): 17 | width = im.shape[2] 18 | diff = args.size - width 19 | left = int((args.size-width)/2) 20 | right = diff - left 21 | im = torch.nn.functional.pad(input=im, pad=(right, left, 0, 0), mode='constant', value=0) 22 | pose = torch.nn.functional.pad(input=pose, pad=(right, left, 0, 0), mode='constant', value=0) 23 | sil = torch.nn.functional.pad(input=sil, pad=(right, left, 0, 0), mode='constant', value=0) 24 | return im, pose, sil 25 | 26 | def tensor2square(x): 27 | width = x.shape[2] 28 | diff = args.size - width 29 | left = int((args.size-width)/2) 30 | right = diff - left 31 | x = torch.nn.functional.pad(input=x, pad=(right, left, 0, 0), mode='constant', value=0) 32 | return x 33 | 34 | def generate(args, g_ema, device, mean_latent): 35 | with torch.no_grad(): 36 | g_ema.eval() 37 | 38 | path = args.input_path 39 | input_name = args.input_name 40 | pose_name = args.target_name 41 | 42 | # input 43 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 44 | input_image = Image.open(os.path.join(path, input_name+'.png')).convert('RGB') 45 | w, h = input_image.size 46 | input_image = transform(input_image).float().to(device) 47 | 48 | input_pose = np.array(Image.open(os.path.join(path, input_name+'_iuv.png'))) 49 | input_sil = np.array(Image.open(os.path.join(path, input_name+'_sil.png')))/255 50 | 51 | # get partial coordinates from dense pose 52 | dp_uv_lookup_256_np = 
np.load('util/dp_uv_lookup_256.npy') 53 | uv_coor, uv_mask, uv_symm_mask = getSymXYcoordinates(input_pose, resolution = 512) 54 | 55 | # union sil with densepose masks 56 | input_sil = 1-((1-input_sil) * (input_pose[:, :, 0] == 0).astype('float')) 57 | 58 | input_sil = torch.from_numpy(input_sil).float().unsqueeze(0) 59 | input_pose = torch.from_numpy(input_pose).permute(2, 0, 1) 60 | 61 | # target 62 | target_pose = np.array(Image.open(os.path.join(path, pose_name+'_iuv.png'))) 63 | target_pose = torch.from_numpy(target_pose).permute(2, 0, 1) 64 | 65 | # convert to square by centering 66 | input_image, input_pose, input_sil = tensors2square(input_image, input_pose, input_sil) 67 | target_pose = tensor2square(target_pose) 68 | 69 | # add batch dimension 70 | input_image = input_image.unsqueeze(0).float().to(device) 71 | input_pose = input_pose.unsqueeze(0).float().to(device) 72 | input_sil = input_sil.unsqueeze(0).float().to(device) 73 | target_pose = target_pose.unsqueeze(0).float().to(device) 74 | 75 | # complete partial coordinates 76 | coor_completion_generator = define_CCM().cuda() 77 | CCM_checkpoint = torch.load(args.CCM_pretrained_model) 78 | coor_completion_generator.load_state_dict(CCM_checkpoint["g"]) 79 | coor_completion_generator.eval() 80 | for param in coor_completion_generator.parameters(): 81 | coor_completion_generator.requires_grad = False 82 | 83 | # uv coor preprocessing (put image in center) 84 | shift = int((h-w)/2) # center shift 85 | uv_coor[:,:,0] = uv_coor[:,:,0] + shift # put in center 86 | uv_coor = ((2*uv_coor/(h-1))-1) 87 | uv_coor = uv_coor*np.expand_dims(uv_mask,2) + (-10*(1-np.expand_dims(uv_mask,2))) 88 | 89 | # coordinate completion 90 | uv_coor_pytorch = torch.from_numpy(uv_coor).float().permute(2, 0, 1).unsqueeze(0) # from h,w,c to 1,c,h,w 91 | uv_mask_pytorch = torch.from_numpy(uv_mask).unsqueeze(0).unsqueeze(0).float() #1xchw 92 | with torch.no_grad(): 93 | coor_completion_generator.eval() 94 | complete_coor = coor_completion_generator(uv_coor_pytorch.cuda(), uv_mask_pytorch.cuda()) 95 | 96 | # reposing 97 | appearance = torch.cat([input_image, input_sil, complete_coor], 1) 98 | output, _ = g_ema(appearance=appearance, pose=target_pose) 99 | 100 | utils.save_image( 101 | output[:, :, :, int(shift):args.size-int(shift)], 102 | os.path.join(args.save_path, input_name+'_2_'+pose_name+'_vis.png'), 103 | nrow=1, 104 | normalize=True, 105 | range=(-1, 1), 106 | ) 107 | 108 | 109 | 110 | if __name__ == "__main__": 111 | device = "cuda" 112 | 113 | parser = argparse.ArgumentParser(description="inference") 114 | 115 | parser.add_argument("--input_path", type=str, help="path to the input dataset") 116 | parser.add_argument("--input_name", type=str, default="fashionWOMENDressesid0000262902_3back", help="input file name") 117 | parser.add_argument("--target_name", type=str, default="fashionWOMENDressesid0000262902_1front", help="target file name") 118 | parser.add_argument("--size", type=int, default=512, help="output image size of the generator") 119 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio") 120 | parser.add_argument("--truncation_mean", type=int, default=4096, help="number of vectors to calculate mean for the truncation") 121 | parser.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier of the generator. 
config-f = 2, else = 1") 122 | parser.add_argument("--pretrained_model", type=str, default="posewithstyle.pt", help="pose with style pretrained model") 123 | parser.add_argument("--CCM_pretrained_model", type=str, default="CCM_epoch50.pt", help="pretrained coordinate completion model") 124 | parser.add_argument("--save_path", type=str, default="./data/output", help="path to save output .data/output") 125 | 126 | args = parser.parse_args() 127 | 128 | args.latent = 2048 129 | args.n_mlp = 8 130 | 131 | if not os.path.exists(args.save_path): 132 | os.makedirs(args.save_path) 133 | 134 | g_ema = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device) 135 | checkpoint = torch.load(args.pretrained_model) 136 | g_ema.load_state_dict(checkpoint["g_ema"]) 137 | 138 | if args.truncation < 1: 139 | with torch.no_grad(): 140 | mean_latent = g_ema.mean_latent(args.truncation_mean) 141 | else: 142 | mean_latent = None 143 | 144 | generate(args, g_ema, device, mean_latent) 145 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import functools 4 | import operator 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.autograd import Function 10 | 11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix 12 | 13 | 14 | class PixelNorm(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def forward(self, input): 19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 20 | 21 | 22 | def make_kernel(k): 23 | k = torch.tensor(k, dtype=torch.float32) 24 | 25 | if k.ndim == 1: 26 | k = k[None, :] * k[:, None] 27 | 28 | k /= k.sum() 29 | 30 | return k 31 | 32 | 33 | class Upsample(nn.Module): 34 | def __init__(self, kernel, factor=2): 35 | super().__init__() 36 | 37 | self.factor = factor 38 | kernel = make_kernel(kernel) * (factor ** 2) 39 | self.register_buffer("kernel", kernel) 40 | 41 | p = kernel.shape[0] - factor 42 | 43 | pad0 = (p + 1) // 2 + factor - 1 44 | pad1 = p // 2 45 | 46 | self.pad = (pad0, pad1) 47 | 48 | def forward(self, input): 49 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 50 | 51 | return out 52 | 53 | 54 | class Downsample(nn.Module): 55 | def __init__(self, kernel, factor=2): 56 | super().__init__() 57 | 58 | self.factor = factor 59 | kernel = make_kernel(kernel) 60 | self.register_buffer("kernel", kernel) 61 | 62 | p = kernel.shape[0] - factor 63 | 64 | pad0 = (p + 1) // 2 65 | pad1 = p // 2 66 | 67 | self.pad = (pad0, pad1) 68 | 69 | def forward(self, input): 70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 71 | 72 | return out 73 | 74 | 75 | class Blur(nn.Module): 76 | def __init__(self, kernel, pad, upsample_factor=1): 77 | super().__init__() 78 | 79 | kernel = make_kernel(kernel) 80 | 81 | if upsample_factor > 1: 82 | kernel = kernel * (upsample_factor ** 2) 83 | 84 | self.register_buffer("kernel", kernel) 85 | 86 | self.pad = pad 87 | 88 | def forward(self, input): 89 | out = upfirdn2d(input, self.kernel, pad=self.pad) 90 | 91 | return out 92 | 93 | 94 | class EqualConv2d(nn.Module): 95 | def __init__( 96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 97 | ): 98 | super().__init__() 99 | 100 | self.weight = nn.Parameter( 101 | torch.randn(out_channel, 
in_channel, kernel_size, kernel_size) 102 | ) 103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 104 | 105 | self.stride = stride 106 | self.padding = padding 107 | 108 | if bias: 109 | self.bias = nn.Parameter(torch.zeros(out_channel)) 110 | 111 | else: 112 | self.bias = None 113 | 114 | def forward(self, input): 115 | out = conv2d_gradfix.conv2d( 116 | input, 117 | self.weight * self.scale, 118 | bias=self.bias, 119 | stride=self.stride, 120 | padding=self.padding, 121 | ) 122 | 123 | return out 124 | 125 | def __repr__(self): 126 | return ( 127 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 128 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 129 | ) 130 | 131 | 132 | class EqualLinear(nn.Module): 133 | def __init__( 134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 135 | ): 136 | super().__init__() 137 | 138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 139 | 140 | if bias: 141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 142 | 143 | else: 144 | self.bias = None 145 | 146 | self.activation = activation 147 | 148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 149 | self.lr_mul = lr_mul 150 | 151 | def forward(self, input): 152 | if self.activation: 153 | out = F.linear(input, self.weight * self.scale) 154 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 155 | 156 | else: 157 | out = F.linear( 158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 159 | ) 160 | 161 | return out 162 | 163 | def __repr__(self): 164 | return ( 165 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 166 | ) 167 | 168 | 169 | class ModulatedConv2d(nn.Module): 170 | def __init__( 171 | self, 172 | in_channel, 173 | out_channel, 174 | kernel_size, 175 | style_dim, 176 | demodulate=True, 177 | upsample=False, 178 | downsample=False, 179 | blur_kernel=[1, 3, 3, 1], 180 | fused=True, 181 | ): 182 | super().__init__() 183 | 184 | self.eps = 1e-8 185 | self.kernel_size = kernel_size 186 | self.in_channel = in_channel 187 | self.out_channel = out_channel 188 | self.upsample = upsample 189 | self.downsample = downsample 190 | 191 | if upsample: 192 | factor = 2 193 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 194 | pad0 = (p + 1) // 2 + factor - 1 195 | pad1 = p // 2 + 1 196 | 197 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 198 | 199 | if downsample: 200 | factor = 2 201 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 202 | pad0 = (p + 1) // 2 203 | pad1 = p // 2 204 | 205 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 206 | 207 | fan_in = in_channel * kernel_size ** 2 208 | self.scale = 1 / math.sqrt(fan_in) 209 | self.padding = kernel_size // 2 210 | 211 | self.weight = nn.Parameter( 212 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 213 | ) 214 | 215 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 216 | 217 | self.demodulate = demodulate 218 | self.fused = fused 219 | 220 | def __repr__(self): 221 | return ( 222 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " 223 | f"upsample={self.upsample}, downsample={self.downsample})" 224 | ) 225 | 226 | def forward(self, input, style): 227 | batch, in_channel, height, width = input.shape 228 | 229 | if not self.fused: 230 | weight = self.scale * self.weight.squeeze(0) 231 | style = self.modulation(style) 232 | 233 | if self.demodulate: 234 | w = 
weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1) 235 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt() 236 | 237 | input = input * style.reshape(batch, in_channel, 1, 1) 238 | 239 | if self.upsample: 240 | weight = weight.transpose(0, 1) 241 | out = conv2d_gradfix.conv_transpose2d( 242 | input, weight, padding=0, stride=2 243 | ) 244 | out = self.blur(out) 245 | 246 | elif self.downsample: 247 | input = self.blur(input) 248 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) 249 | 250 | else: 251 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding) 252 | 253 | if self.demodulate: 254 | out = out * dcoefs.view(batch, -1, 1, 1) 255 | 256 | return out 257 | 258 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 259 | weight = self.scale * self.weight * style 260 | 261 | if self.demodulate: 262 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 263 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 264 | 265 | weight = weight.view( 266 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 267 | ) 268 | 269 | if self.upsample: 270 | input = input.view(1, batch * in_channel, height, width) 271 | weight = weight.view( 272 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 273 | ) 274 | weight = weight.transpose(1, 2).reshape( 275 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 276 | ) 277 | out = conv2d_gradfix.conv_transpose2d( 278 | input, weight, padding=0, stride=2, groups=batch 279 | ) 280 | _, _, height, width = out.shape 281 | out = out.view(batch, self.out_channel, height, width) 282 | out = self.blur(out) 283 | 284 | elif self.downsample: 285 | input = self.blur(input) 286 | _, _, height, width = input.shape 287 | input = input.view(1, batch * in_channel, height, width) 288 | out = conv2d_gradfix.conv2d( 289 | input, weight, padding=0, stride=2, groups=batch 290 | ) 291 | _, _, height, width = out.shape 292 | out = out.view(batch, self.out_channel, height, width) 293 | 294 | else: 295 | input = input.view(1, batch * in_channel, height, width) 296 | out = conv2d_gradfix.conv2d( 297 | input, weight, padding=self.padding, groups=batch 298 | ) 299 | _, _, height, width = out.shape 300 | out = out.view(batch, self.out_channel, height, width) 301 | 302 | return out 303 | 304 | 305 | class SpatiallyModulatedConv2d(nn.Module): 306 | def __init__( 307 | self, 308 | in_channel, 309 | out_channel, 310 | kernel_size, 311 | upsample=False, 312 | downsample=False, 313 | blur_kernel=[1, 3, 3, 1], 314 | ): 315 | super().__init__() 316 | 317 | self.eps = 1e-8 318 | self.kernel_size = kernel_size 319 | self.in_channel = in_channel 320 | self.out_channel = out_channel 321 | self.upsample = upsample 322 | self.downsample = downsample 323 | 324 | if upsample: 325 | factor = 2 326 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 327 | pad0 = (p + 1) // 2 + factor - 1 328 | pad1 = p // 2 + 1 329 | 330 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 331 | 332 | if downsample: 333 | factor = 2 334 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 335 | pad0 = (p + 1) // 2 336 | pad1 = p // 2 337 | 338 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 339 | 340 | fan_in = in_channel * kernel_size ** 2 341 | self.scale = 1 / math.sqrt(fan_in) 342 | self.padding = kernel_size // 2 343 | 344 | self.weight = nn.Parameter( 345 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 346 | ) 347 | 348 | self.gamma = 
nn.Sequential(*[EqualConv2d(in_channel, 128, kernel_size=1), nn.ReLU(True), EqualConv2d(128, in_channel, kernel_size=1)]) 349 | self.beta = nn.Sequential(*[EqualConv2d(in_channel, 128, kernel_size=1), nn.ReLU(True), EqualConv2d(128, in_channel, kernel_size=1)]) 350 | 351 | def calc_mean_std(self, feat, eps=1e-5): 352 | size = feat.size() 353 | assert (len(size) == 4) 354 | N, C = size[:2] 355 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 356 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 357 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 358 | return feat_mean, feat_std 359 | 360 | def modulate(self, x, gamma, beta): 361 | return gamma * x + beta 362 | 363 | def normalize(self, x): 364 | mean, std = self.calc_mean_std(x) 365 | mean = mean.expand_as(x) 366 | std = std.expand_as(x) 367 | return (x-mean)/std 368 | 369 | def forward(self, input, style): 370 | batch, in_channel, height, width = input.shape 371 | 372 | weight = self.scale * self.weight.squeeze(0) 373 | 374 | gamma = self.gamma(style) 375 | beta = self.beta(style) 376 | 377 | input = self.modulate(input, gamma, beta) 378 | 379 | if self.upsample: 380 | weight = weight.transpose(0, 1) 381 | out = conv2d_gradfix.conv_transpose2d( 382 | input, weight, padding=0, stride=2 383 | ) 384 | out = self.blur(out) 385 | elif self.downsample: 386 | input = self.blur(input) 387 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) 388 | else: 389 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding) 390 | 391 | out = self.normalize(out) 392 | 393 | return out 394 | 395 | 396 | class NoiseInjection(nn.Module): 397 | def __init__(self): 398 | super().__init__() 399 | 400 | self.weight = nn.Parameter(torch.zeros(1)) 401 | 402 | def forward(self, image, noise=None): 403 | if noise is None: 404 | batch, _, height, width = image.shape 405 | noise = image.new_empty(batch, 1, height, width).normal_() 406 | 407 | return image + self.weight * noise 408 | 409 | 410 | class ConstantInput(nn.Module): 411 | def __init__(self, channel, size=4): 412 | super().__init__() 413 | 414 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 415 | 416 | def forward(self, input): 417 | batch = input.shape[0] 418 | out = self.input.repeat(batch, 1, 1, 1) 419 | 420 | return out 421 | 422 | 423 | class StyledConv(nn.Module): 424 | def __init__( 425 | self, 426 | in_channel, 427 | out_channel, 428 | kernel_size, 429 | style_dim, 430 | upsample=False, 431 | blur_kernel=[1, 3, 3, 1], 432 | demodulate=True, 433 | spatial=False, 434 | ): 435 | super().__init__() 436 | 437 | if spatial: 438 | self.conv = SpatiallyModulatedConv2d( 439 | in_channel, 440 | out_channel, 441 | kernel_size, 442 | upsample=upsample, 443 | blur_kernel=blur_kernel, 444 | ) 445 | else: 446 | self.conv = ModulatedConv2d( 447 | in_channel, 448 | out_channel, 449 | kernel_size, 450 | style_dim, 451 | upsample=upsample, 452 | blur_kernel=blur_kernel, 453 | demodulate=demodulate, 454 | ) 455 | 456 | self.noise = NoiseInjection() 457 | self.activate = FusedLeakyReLU(out_channel) 458 | 459 | def forward(self, input, style, noise=None): 460 | out = self.conv(input, style) 461 | out = self.noise(out, noise=noise) 462 | out = self.activate(out) 463 | 464 | return out 465 | 466 | 467 | class ToRGB(nn.Module): 468 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1], spatial=False): 469 | super().__init__() 470 | 471 | if upsample: 472 | self.upsample = Upsample(blur_kernel) 473 | 474 | if spatial: 475 | self.conv = 
SpatiallyModulatedConv2d(in_channel, 3, 1) 476 | else: 477 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 478 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 479 | 480 | def forward(self, input, style, skip=None): 481 | out = self.conv(input, style) 482 | out = out + self.bias 483 | 484 | if skip is not None: 485 | skip = self.upsample(skip) 486 | 487 | out = out + skip 488 | 489 | return out 490 | 491 | 492 | class PoseEncoder(nn.Module): 493 | def __init__(self, ngf=64, blur_kernel=[1, 3, 3, 1], size=256): 494 | super().__init__() 495 | self.size = size 496 | convs = [ConvLayer(3, ngf, 1)] 497 | convs.append(ResBlock(ngf, ngf*2, blur_kernel)) 498 | convs.append(ResBlock(ngf*2, ngf*4, blur_kernel)) 499 | convs.append(ResBlock(ngf*4, ngf*8, blur_kernel)) 500 | convs.append(ResBlock(ngf*8, ngf*8, blur_kernel)) 501 | if self.size == 512: 502 | convs.append(ResBlock(ngf*8, ngf*8, blur_kernel)) 503 | if self.size == 1024: 504 | convs.append(ResBlock(ngf*8, ngf*8, blur_kernel)) 505 | convs.append(ResBlock(ngf*8, ngf*8, blur_kernel)) 506 | 507 | self.convs = nn.Sequential(*convs) 508 | 509 | def forward(self, input): 510 | out = self.convs(input) 511 | return out 512 | 513 | 514 | class SpatialAppearanceEncoder(nn.Module): 515 | def __init__(self, ngf=64, blur_kernel=[1, 3, 3, 1], size=256): 516 | super().__init__() 517 | self.size = size 518 | self.dp_uv_lookup_256_np = np.load('util/dp_uv_lookup_256.npy') 519 | input_nc = 4 # source RGB and sil 520 | 521 | self.conv1 = ConvLayer(input_nc, ngf, 1) # ngf 256 256 522 | self.conv2 = ResBlock(ngf, ngf*2, blur_kernel) # 2ngf 128 128 523 | self.conv3 = ResBlock(ngf*2, ngf*4, blur_kernel) # 4ngf 64 64 524 | self.conv4 = ResBlock(ngf*4, ngf*8, blur_kernel) # 8ngf 32 32 525 | self.conv5 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16 526 | if self.size == 512: 527 | self.conv6 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16 - starting from ngf 512 512 528 | if self.size == 1024: 529 | self.conv6 = ResBlock(ngf*8, ngf*8, blur_kernel) 530 | self.conv7 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16 - starting from ngf 1024 0124 531 | 532 | self.conv11 = EqualConv2d(ngf+1, ngf*8, 1) 533 | self.conv21 = EqualConv2d(ngf*2+1, ngf*8, 1) 534 | self.conv31 = EqualConv2d(ngf*4+1, ngf*8, 1) 535 | self.conv41 = EqualConv2d(ngf*8+1, ngf*8, 1) 536 | self.conv51 = EqualConv2d(ngf*8+1, ngf*8, 1) 537 | if self.size == 512: 538 | self.conv61 = EqualConv2d(ngf*8+1, ngf*8, 1) 539 | if self.size == 1024: 540 | self.conv61 = EqualConv2d(ngf*8+1, ngf*8, 1) 541 | self.conv71 = EqualConv2d(ngf*8+1, ngf*8, 1) 542 | 543 | if self.size == 1024: 544 | self.conv13 = EqualConv2d(ngf*8, int(ngf/2), 3, padding=1) 545 | self.conv23 = EqualConv2d(ngf*8, ngf*1, 3, padding=1) 546 | self.conv33 = EqualConv2d(ngf*8, ngf*2, 3, padding=1) 547 | self.conv43 = EqualConv2d(ngf*8, ngf*4, 3, padding=1) 548 | self.conv53 = EqualConv2d(ngf*8, ngf*8, 3, padding=1) 549 | self.conv63 = EqualConv2d(ngf*8, ngf*8, 3, padding=1) 550 | elif self.size == 512: 551 | self.conv13 = EqualConv2d(ngf*8, ngf*1, 3, padding=1) 552 | self.conv23 = EqualConv2d(ngf*8, ngf*2, 3, padding=1) 553 | self.conv33 = EqualConv2d(ngf*8, ngf*4, 3, padding=1) 554 | self.conv43 = EqualConv2d(ngf*8, ngf*8, 3, padding=1) 555 | self.conv53 = EqualConv2d(ngf*8, ngf*8, 3, padding=1) 556 | else: 557 | self.conv13 = EqualConv2d(ngf*8, ngf*2, 3, padding=1) 558 | self.conv23 = EqualConv2d(ngf*8, ngf*4, 3, padding=1) 559 | self.conv33 = EqualConv2d(ngf*8, ngf*8, 3, padding=1) 560 | self.conv43 = 
EqualConv2d(ngf*8, ngf*8, 3, padding=1) 561 | 562 | self.up = nn.Upsample(scale_factor=2) 563 | 564 | def uv2target(self, dp, tex, tex_mask=None): 565 | # uv2img 566 | b, _, h, w = dp.shape 567 | ch = tex.shape[1] 568 | UV_MAPs = [] 569 | for idx in range(b): 570 | iuv = dp[idx] 571 | point_pos = iuv[0, :, :] > 0 572 | iuv_raw_i = iuv[0][point_pos] 573 | iuv_raw_u = iuv[1][point_pos] 574 | iuv_raw_v = iuv[2][point_pos] 575 | iuv_raw = torch.stack([iuv_raw_i,iuv_raw_u,iuv_raw_v], 0) 576 | i = iuv_raw[0, :] - 1 ## dp_uv_lookup_256_np does not contain BG class 577 | u = iuv_raw[1, :] 578 | v = iuv_raw[2, :] 579 | i = i.cpu().numpy() 580 | u = u.cpu().numpy() 581 | v = v.cpu().numpy() 582 | uv_smpl = self.dp_uv_lookup_256_np[i, v, u] 583 | uv_map = torch.zeros((2,h,w)).to(tex).float() 584 | ## being normalize [0,1] to [-1,1] for the grid sample of Pytorch 585 | u_map = uv_smpl[:, 0] * 2 - 1 586 | v_map = (1 - uv_smpl[:, 1]) * 2 - 1 587 | uv_map[0][point_pos] = torch.from_numpy(u_map).to(tex).float() 588 | uv_map[1][point_pos] = torch.from_numpy(v_map).to(tex).float() 589 | UV_MAPs.append(uv_map) 590 | uv_map = torch.stack(UV_MAPs, 0) 591 | # warping 592 | # before warping validate sizes 593 | _, _, h_x, w_x = tex.shape 594 | _, _, h_t, w_t = uv_map.shape 595 | if h_t != h_x or w_t != w_x: 596 | #https://github.com/iPERDance/iPERCore/blob/4a010f781a4fb90dd29a516472e4aadf41ed1609/iPERCore/models/networks/generators/lwb_avg_resunet.py#L55 597 | uv_map = torch.nn.functional.interpolate(uv_map, size=(h_x, w_x), mode='bilinear', align_corners=True) 598 | uv_map = uv_map.permute(0, 2, 3, 1) 599 | warped_image = torch.nn.functional.grid_sample(tex.float(), uv_map.float()) 600 | if tex_mask is not None: 601 | warped_mask = torch.nn.functional.grid_sample(tex_mask.float(), uv_map.float()) 602 | final_warped = warped_image * warped_mask 603 | return final_warped, warped_mask 604 | else: 605 | return warped_image 606 | 607 | def forward(self, input, pose): 608 | coor = input[:,4:,:,:] 609 | input = input[:,:4,:,:] 610 | 611 | x1 = self.conv1(input) 612 | x2 = self.conv2(x1) 613 | x3 = self.conv3(x2) 614 | x4 = self.conv4(x3) 615 | x5 = self.conv5(x4) 616 | if self.size == 512: 617 | x6 = self.conv6(x5) 618 | if self.size == 1024: 619 | x6 = self.conv6(x5) 620 | x7 = self.conv7(x6) 621 | 622 | # warp- get flow 623 | pose_mask = 1-(pose[:,0, :, :] == 0).float().unsqueeze(1) 624 | flow = self.uv2target(pose.int(), coor) 625 | # warp- resize flow 626 | f1 = torch.nn.functional.interpolate(flow, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True) 627 | f2 = torch.nn.functional.interpolate(flow, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True) 628 | f3 = torch.nn.functional.interpolate(flow, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True) 629 | f4 = torch.nn.functional.interpolate(flow, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True) 630 | f5 = torch.nn.functional.interpolate(flow, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True) 631 | if self.size == 512: 632 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 633 | if self.size == 1024: 634 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 635 | f7 = torch.nn.functional.interpolate(flow, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True) 636 | # warp- now warp 637 | x1 = torch.nn.functional.grid_sample(x1, 
f1.permute(0,2,3,1)) 638 | x2 = torch.nn.functional.grid_sample(x2, f2.permute(0,2,3,1)) 639 | x3 = torch.nn.functional.grid_sample(x3, f3.permute(0,2,3,1)) 640 | x4 = torch.nn.functional.grid_sample(x4, f4.permute(0,2,3,1)) 641 | x5 = torch.nn.functional.grid_sample(x5, f5.permute(0,2,3,1)) 642 | if self.size == 512: 643 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1)) 644 | if self.size == 1024: 645 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1)) 646 | x7 = torch.nn.functional.grid_sample(x7, f7.permute(0,2,3,1)) 647 | 648 | # mask features 649 | p1 = torch.nn.functional.interpolate(pose_mask, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True) 650 | p2 = torch.nn.functional.interpolate(pose_mask, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True) 651 | p3 = torch.nn.functional.interpolate(pose_mask, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True) 652 | p4 = torch.nn.functional.interpolate(pose_mask, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True) 653 | p5 = torch.nn.functional.interpolate(pose_mask, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True) 654 | if self.size == 512: 655 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 656 | if self.size == 1024: 657 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 658 | p7 = torch.nn.functional.interpolate(pose_mask, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True) 659 | 660 | x1 = x1 * p1 661 | x2 = x2 * p2 662 | x3 = x3 * p3 663 | x4 = x4 * p4 664 | x5 = x5 * p5 665 | if self.size == 512: 666 | x6 = x6 * p6 667 | if self.size == 1024: 668 | x6 = x6 * p6 669 | x7 = x7 * p7 670 | 671 | # fpn 672 | if self.size == 1024: 673 | F7 = self.conv71(torch.cat([x7,p7], 1)) 674 | f6 = self.up(F7)+self.conv61(torch.cat([x6,p6], 1)) 675 | F6 = self.conv63(f6) 676 | f5 = self.up(F6)+self.conv51(torch.cat([x5,p5], 1)) 677 | F5 = self.conv53(f5) 678 | f4 = self.up(f5)+self.conv41(torch.cat([x4,p4], 1)) 679 | F4 = self.conv43(f4) 680 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1)) 681 | F3 = self.conv33(f3) 682 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1)) 683 | F2 = self.conv23(f2) 684 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1)) 685 | F1 = self.conv13(f1) 686 | elif self.size == 512: 687 | F6 = self.conv61(torch.cat([x6,p6], 1)) 688 | f5 = self.up(F6)+self.conv51(torch.cat([x5,p5], 1)) 689 | F5 = self.conv53(f5) 690 | f4 = self.up(f5)+self.conv41(torch.cat([x4,p4], 1)) 691 | F4 = self.conv43(f4) 692 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1)) 693 | F3 = self.conv33(f3) 694 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1)) 695 | F2 = self.conv23(f2) 696 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1)) 697 | F1 = self.conv13(f1) 698 | else: 699 | F5 = self.conv51(torch.cat([x5,p5], 1)) 700 | f4 = self.up(F5)+self.conv41(torch.cat([x4,p4], 1)) 701 | F4 = self.conv43(f4) 702 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1)) 703 | F3 = self.conv33(f3) 704 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1)) 705 | F2 = self.conv23(f2) 706 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1)) 707 | F1 = self.conv13(f1) 708 | 709 | if self.size == 1024: 710 | return [F7, F6, F5, F4, F3, F2, F1] 711 | elif self.size == 512: 712 | return [F6, F5, F4, F3, F2, F1] 713 | else: 714 | return [F5, F4, F3, F2, F1] 715 | 716 | 717 | class 
Generator(nn.Module): 718 | def __init__( 719 | self, 720 | size, 721 | style_dim, 722 | n_mlp, 723 | channel_multiplier=2, 724 | blur_kernel=[1, 3, 3, 1], 725 | lr_mlp=0.01, 726 | garment_transfer=False, 727 | part='upper_body', 728 | ): 729 | super().__init__() 730 | 731 | self.garment_transfer = garment_transfer 732 | self.size = size 733 | self.style_dim = style_dim 734 | 735 | if self.garment_transfer: 736 | self.appearance_encoder = GarmentTransferSpatialAppearanceEncoder(size=size, part=part) 737 | else: 738 | self.appearance_encoder = SpatialAppearanceEncoder(size=size) 739 | self.pose_encoder = PoseEncoder(size=size) 740 | 741 | # StyleGAN 742 | self.channels = { 743 | 16: 512, 744 | 32: 512, 745 | 64: 256 * channel_multiplier, 746 | 128: 128 * channel_multiplier, 747 | 256: 64 * channel_multiplier, 748 | 512: 32 * channel_multiplier, 749 | 1024: 16 * channel_multiplier, 750 | } 751 | 752 | self.conv1 = StyledConv( 753 | self.channels[16], self.channels[16], 3, style_dim, blur_kernel=blur_kernel, spatial=True 754 | ) 755 | self.to_rgb1 = ToRGB(self.channels[16], style_dim, upsample=False, spatial=True) 756 | 757 | self.log_size = int(math.log(size, 2)) 758 | self.num_layers = (self.log_size - 4) * 2 + 1 759 | 760 | self.convs = nn.ModuleList() 761 | self.upsamples = nn.ModuleList() 762 | self.to_rgbs = nn.ModuleList() 763 | self.noises = nn.Module() 764 | 765 | in_channel = self.channels[16] 766 | 767 | for i in range(5, self.log_size + 1): 768 | out_channel = self.channels[2 ** i] 769 | 770 | self.convs.append( 771 | StyledConv( 772 | in_channel, 773 | out_channel, 774 | 3, 775 | style_dim, 776 | upsample=True, 777 | blur_kernel=blur_kernel, 778 | spatial=True, 779 | ) 780 | ) 781 | 782 | self.convs.append( 783 | StyledConv( 784 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel, spatial=True, 785 | ) 786 | ) 787 | 788 | self.to_rgbs.append(ToRGB(out_channel, style_dim, spatial=True)) 789 | 790 | in_channel = out_channel 791 | 792 | self.n_latent = self.log_size * 2 - 2 793 | 794 | 795 | def make_noise(self): 796 | device = self.input.input.device 797 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 798 | for i in range(3, self.log_size + 1): 799 | for _ in range(2): 800 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 801 | return noises 802 | 803 | def mean_latent(self, n_latent): 804 | latent_in = torch.randn( 805 | n_latent, self.style_dim, device=self.input.input.device 806 | ) 807 | latent = self.style(latent_in).mean(0, keepdim=True) 808 | return latent 809 | 810 | def get_latent(self, input): 811 | return self.style(input) 812 | 813 | def forward( 814 | self, 815 | pose, 816 | appearance, 817 | styles=None, 818 | return_latents=False, 819 | inject_index=None, 820 | truncation=1, 821 | truncation_latent=None, 822 | input_is_latent=False, 823 | noise=None, 824 | randomize_noise=True, 825 | ): 826 | 827 | if self.garment_transfer: 828 | styles, part_mask = self.appearance_encoder(appearance, pose) 829 | else: 830 | styles = self.appearance_encoder(appearance, pose) 831 | 832 | if noise is None: 833 | if randomize_noise: 834 | noise = [None] * self.num_layers 835 | else: 836 | noise = [ 837 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 838 | ] 839 | 840 | latent = [styles[0], styles[0]] 841 | if self.size == 1024: 842 | length = 6 843 | elif self.size == 512: 844 | length = 5 845 | else: 846 | length = 4 847 | for i in range(length): 848 | latent += [styles[i+1],styles[i+1]] 849 | 850 | out = 
self.pose_encoder(pose) 851 | out = self.conv1(out, latent[0], noise=noise[0]) 852 | skip = self.to_rgb1(out, latent[1]) 853 | 854 | i = 1 855 | for conv1, conv2, noise1, noise2, to_rgb in zip( 856 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs#, self.to_silhouettes 857 | ): 858 | out = conv1(out, latent[i], noise=noise1) 859 | out = conv2(out, latent[i + 1], noise=noise2) 860 | skip = to_rgb(out, latent[i + 2], skip) 861 | i += 2 862 | 863 | image = skip 864 | 865 | if self.garment_transfer: 866 | return image, part_mask 867 | else: 868 | if return_latents: 869 | return image, latent 870 | else: 871 | return image, None 872 | 873 | 874 | class ConvLayer(nn.Sequential): 875 | def __init__( 876 | self, 877 | in_channel, 878 | out_channel, 879 | kernel_size, 880 | downsample=False, 881 | blur_kernel=[1, 3, 3, 1], 882 | bias=True, 883 | activate=True, 884 | ): 885 | layers = [] 886 | 887 | if downsample: 888 | factor = 2 889 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 890 | pad0 = (p + 1) // 2 891 | pad1 = p // 2 892 | 893 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 894 | 895 | stride = 2 896 | self.padding = 0 897 | 898 | else: 899 | stride = 1 900 | self.padding = kernel_size // 2 901 | 902 | layers.append( 903 | EqualConv2d( 904 | in_channel, 905 | out_channel, 906 | kernel_size, 907 | padding=self.padding, 908 | stride=stride, 909 | bias=bias and not activate, 910 | ) 911 | ) 912 | 913 | if activate: 914 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 915 | 916 | super().__init__(*layers) 917 | 918 | 919 | class ResBlock(nn.Module): 920 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 921 | super().__init__() 922 | 923 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 924 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 925 | 926 | self.skip = ConvLayer( 927 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 928 | ) 929 | 930 | def forward(self, input): 931 | out = self.conv1(input) 932 | out = self.conv2(out) 933 | 934 | skip = self.skip(input) 935 | out = (out + skip) / math.sqrt(2) 936 | 937 | return out 938 | 939 | 940 | class Discriminator(nn.Module): 941 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 942 | super().__init__() 943 | 944 | channels = { 945 | 4: 512, 946 | 8: 512, 947 | 16: 512, 948 | 32: 512, 949 | 64: 256 * channel_multiplier, 950 | 128: 128 * channel_multiplier, 951 | 256: 64 * channel_multiplier, 952 | 512: 32 * channel_multiplier, 953 | 1024: 16 * channel_multiplier, 954 | } 955 | 956 | convs = [ConvLayer(6, channels[size], 1)] 957 | 958 | log_size = int(math.log(size, 2)) 959 | 960 | in_channel = channels[size] 961 | 962 | for i in range(log_size, 2, -1): 963 | out_channel = channels[2 ** (i - 1)] 964 | 965 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 966 | 967 | in_channel = out_channel 968 | 969 | self.convs = nn.Sequential(*convs) 970 | 971 | self.stddev_group = 4 972 | self.stddev_feat = 1 973 | 974 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 975 | self.final_linear = nn.Sequential( 976 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), 977 | EqualLinear(channels[4], 1), 978 | ) 979 | 980 | def forward(self, input, pose): 981 | input = torch.cat([input, pose], 1) 982 | out = self.convs(input) 983 | 984 | batch, channel, height, width = out.shape 985 | group = min(batch, self.stddev_group) 986 | stddev = out.view( 987 | group, -1, self.stddev_feat, 
channel // self.stddev_feat, height, width 988 | ) 989 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 990 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 991 | stddev = stddev.repeat(group, 1, height, width) 992 | out = torch.cat([out, stddev], 1) 993 | 994 | out = self.final_conv(out) 995 | 996 | out = out.view(batch, -1) 997 | out = self.final_linear(out) 998 | 999 | return out 1000 | 1001 | 1002 | class VGGLoss(nn.Module): 1003 | def __init__(self, device): 1004 | super(VGGLoss, self).__init__() 1005 | self.vgg = Vgg19().to(device) 1006 | self.criterion = nn.L1Loss() 1007 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 1008 | 1009 | def forward(self, x, y): 1010 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 1011 | loss = 0 1012 | for i in range(len(x_vgg)): 1013 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 1014 | return loss 1015 | 1016 | from torchvision import models 1017 | class Vgg19(torch.nn.Module): 1018 | def __init__(self, requires_grad=False): 1019 | super(Vgg19, self).__init__() 1020 | vgg_pretrained_features = models.vgg19(pretrained=True).features 1021 | self.slice1 = torch.nn.Sequential() 1022 | self.slice2 = torch.nn.Sequential() 1023 | self.slice3 = torch.nn.Sequential() 1024 | self.slice4 = torch.nn.Sequential() 1025 | self.slice5 = torch.nn.Sequential() 1026 | for x in range(2): 1027 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 1028 | for x in range(2, 7): 1029 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 1030 | for x in range(7, 12): 1031 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 1032 | for x in range(12, 21): 1033 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 1034 | for x in range(21, 30): 1035 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 1036 | if not requires_grad: 1037 | for param in self.parameters(): 1038 | param.requires_grad = False 1039 | 1040 | def forward(self, X): 1041 | h_relu1 = self.slice1(X) 1042 | h_relu2 = self.slice2(h_relu1) 1043 | h_relu3 = self.slice3(h_relu2) 1044 | h_relu4 = self.slice4(h_relu3) 1045 | h_relu5 = self.slice5(h_relu4) 1046 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 1047 | return out 1048 | 1049 | 1050 | ## Garment transfer 1051 | class GarmentTransferSpatialAppearanceEncoder(nn.Module): 1052 | def __init__(self, ngf=64, blur_kernel=[1, 3, 3, 1], size=256, part='upper_body'): 1053 | super().__init__() 1054 | self.size = size 1055 | self.part = part 1056 | self.dp_uv_lookup_256_np = np.load('util/dp_uv_lookup_256.npy') 1057 | self.uv_parts = torch.from_numpy(np.load('util/uv_space_parts.npy')).unsqueeze(0).unsqueeze(0) 1058 | 1059 | input_nc = 4 # source RGB and sil 1060 | 1061 | self.conv1 = ConvLayer(input_nc, ngf, 1) # ngf 256 256 1062 | self.conv2 = ResBlock(ngf, ngf*2, blur_kernel) # 2ngf 128 128 1063 | self.conv3 = ResBlock(ngf*2, ngf*4, blur_kernel) # 4ngf 64 64 1064 | self.conv4 = ResBlock(ngf*4, ngf*8, blur_kernel) # 8ngf 32 32 1065 | self.conv5 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16 1066 | if self.size == 512: 1067 | self.conv6 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16 - starting from ngf 512 512 1068 | if self.size == 1024: 1069 | self.conv6 = ResBlock(ngf*8, ngf*8, blur_kernel) 1070 | self.conv7 = ResBlock(ngf*8, ngf*8, blur_kernel) # 8ngf 16 16 - starting from ngf 1024 0124 1071 | 1072 | self.conv11 = EqualConv2d(ngf+1, ngf*8, 1) 1073 | self.conv21 = EqualConv2d(ngf*2+1, ngf*8, 1) 1074 | self.conv31 = EqualConv2d(ngf*4+1, ngf*8, 1) 1075 | self.conv41 = 
EqualConv2d(ngf*8+1, ngf*8, 1) 1076 | self.conv51 = EqualConv2d(ngf*8+1, ngf*8, 1) 1077 | if self.size == 512: 1078 | self.conv61 = EqualConv2d(ngf*8+1, ngf*8, 1) 1079 | if self.size == 1024: 1080 | self.conv61 = EqualConv2d(ngf*8+1, ngf*8, 1) 1081 | self.conv71 = EqualConv2d(ngf*8+1, ngf*8, 1) 1082 | 1083 | if self.size == 1024: 1084 | self.conv13 = EqualConv2d(ngf*8, int(ngf/2), 3, padding=1) 1085 | self.conv23 = EqualConv2d(ngf*8, ngf*1, 3, padding=1) 1086 | self.conv33 = EqualConv2d(ngf*8, ngf*2, 3, padding=1) 1087 | self.conv43 = EqualConv2d(ngf*8, ngf*4, 3, padding=1) 1088 | self.conv53 = EqualConv2d(ngf*8, ngf*8, 3, padding=1) 1089 | self.conv63 = EqualConv2d(ngf*8, ngf*8, 3, padding=1) 1090 | elif self.size == 512: 1091 | self.conv13 = EqualConv2d(ngf*8, ngf*1, 3, padding=1) 1092 | self.conv23 = EqualConv2d(ngf*8, ngf*2, 3, padding=1) 1093 | self.conv33 = EqualConv2d(ngf*8, ngf*4, 3, padding=1) 1094 | self.conv43 = EqualConv2d(ngf*8, ngf*8, 3, padding=1) 1095 | self.conv53 = EqualConv2d(ngf*8, ngf*8, 3, padding=1) 1096 | else: 1097 | self.conv13 = EqualConv2d(ngf*8, ngf*2, 3, padding=1) 1098 | self.conv23 = EqualConv2d(ngf*8, ngf*4, 3, padding=1) 1099 | self.conv33 = EqualConv2d(ngf*8, ngf*8, 3, padding=1) 1100 | self.conv43 = EqualConv2d(ngf*8, ngf*8, 3, padding=1) 1101 | 1102 | self.up = nn.Upsample(scale_factor=2) 1103 | 1104 | def uv2target(self, dp, tex, tex_mask=None): 1105 | # uv2img 1106 | b, _, h, w = dp.shape 1107 | ch = tex.shape[1] 1108 | UV_MAPs = [] 1109 | for idx in range(b): 1110 | iuv = dp[idx] 1111 | point_pos = iuv[0, :, :] > 0 1112 | iuv_raw_i = iuv[0][point_pos] 1113 | iuv_raw_u = iuv[1][point_pos] 1114 | iuv_raw_v = iuv[2][point_pos] 1115 | iuv_raw = torch.stack([iuv_raw_i,iuv_raw_u,iuv_raw_v], 0) 1116 | i = iuv_raw[0, :] - 1 ## dp_uv_lookup_256_np does not contain BG class 1117 | u = iuv_raw[1, :] 1118 | v = iuv_raw[2, :] 1119 | i = i.cpu().numpy() 1120 | u = u.cpu().numpy() 1121 | v = v.cpu().numpy() 1122 | uv_smpl = self.dp_uv_lookup_256_np[i, v, u] 1123 | uv_map = torch.zeros((2,h,w)).to(tex).float() 1124 | ## being normalize [0,1] to [-1,1] for the grid sample of Pytorch 1125 | u_map = uv_smpl[:, 0] * 2 - 1 1126 | v_map = (1 - uv_smpl[:, 1]) * 2 - 1 1127 | uv_map[0][point_pos] = torch.from_numpy(u_map).to(tex).float() 1128 | uv_map[1][point_pos] = torch.from_numpy(v_map).to(tex).float() 1129 | UV_MAPs.append(uv_map) 1130 | uv_map = torch.stack(UV_MAPs, 0) 1131 | # warping 1132 | # before warping validate sizes 1133 | _, _, h_x, w_x = tex.shape 1134 | _, _, h_t, w_t = uv_map.shape 1135 | if h_t != h_x or w_t != w_x: 1136 | #https://github.com/iPERDance/iPERCore/blob/4a010f781a4fb90dd29a516472e4aadf41ed1609/iPERCore/models/networks/generators/lwb_avg_resunet.py#L55 1137 | uv_map = torch.nn.functional.interpolate(uv_map, size=(h_x, w_x), mode='bilinear', align_corners=True) 1138 | uv_map = uv_map.permute(0, 2, 3, 1) 1139 | warped_image = torch.nn.functional.grid_sample(tex.float(), uv_map.float()) 1140 | if tex_mask is not None: 1141 | warped_mask = torch.nn.functional.grid_sample(tex_mask.float(), uv_map.float()) 1142 | final_warped = warped_image * warped_mask 1143 | return final_warped, warped_mask 1144 | else: 1145 | return warped_image 1146 | 1147 | def forward(self, input, pose): 1148 | coor = input[:,4:6,:,:] 1149 | target_coor = input[:,10:,:,:] 1150 | target_input = input[:,6:10,:,:] 1151 | input = input[:,:4,:,:] 1152 | 1153 | # input 1154 | x1 = self.conv1(input) 1155 | x2 = self.conv2(x1) 1156 | x3 = self.conv3(x2) 1157 | x4 = self.conv4(x3) 1158 
| x5 = self.conv5(x4) 1159 | if self.size == 512: 1160 | x6 = self.conv6(x5) 1161 | if self.size == 1024: 1162 | x6 = self.conv6(x5) 1163 | x7 = self.conv7(x6) 1164 | # warp- get flow 1165 | pose_mask = 1-(pose[:,0, :, :] == 0).float().unsqueeze(1) 1166 | flow = self.uv2target(pose.int(), coor) 1167 | # warp- resize flow 1168 | f1 = torch.nn.functional.interpolate(flow, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True) 1169 | f2 = torch.nn.functional.interpolate(flow, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True) 1170 | f3 = torch.nn.functional.interpolate(flow, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True) 1171 | f4 = torch.nn.functional.interpolate(flow, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True) 1172 | f5 = torch.nn.functional.interpolate(flow, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True) 1173 | if self.size == 512: 1174 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 1175 | if self.size == 1024: 1176 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 1177 | f7 = torch.nn.functional.interpolate(flow, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True) 1178 | # warp- now warp 1179 | x1 = torch.nn.functional.grid_sample(x1, f1.permute(0,2,3,1)) 1180 | x2 = torch.nn.functional.grid_sample(x2, f2.permute(0,2,3,1)) 1181 | x3 = torch.nn.functional.grid_sample(x3, f3.permute(0,2,3,1)) 1182 | x4 = torch.nn.functional.grid_sample(x4, f4.permute(0,2,3,1)) 1183 | x5 = torch.nn.functional.grid_sample(x5, f5.permute(0,2,3,1)) 1184 | if self.size == 512: 1185 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1)) 1186 | if self.size == 1024: 1187 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1)) 1188 | x7 = torch.nn.functional.grid_sample(x7, f7.permute(0,2,3,1)) 1189 | # mask features 1190 | p1 = torch.nn.functional.interpolate(pose_mask, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True) 1191 | p2 = torch.nn.functional.interpolate(pose_mask, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True) 1192 | p3 = torch.nn.functional.interpolate(pose_mask, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True) 1193 | p4 = torch.nn.functional.interpolate(pose_mask, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True) 1194 | p5 = torch.nn.functional.interpolate(pose_mask, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True) 1195 | if self.size == 512: 1196 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 1197 | if self.size == 1024: 1198 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 1199 | p7 = torch.nn.functional.interpolate(pose_mask, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True) 1200 | x1 = x1 * p1 1201 | x2 = x2 * p2 1202 | x3 = x3 * p3 1203 | x4 = x4 * p4 1204 | x5 = x5 * p5 1205 | if self.size == 512: 1206 | x6 = x6 * p6 1207 | if self.size == 1024: 1208 | x6 = x6 * p6 1209 | x7 = x7 * p7 1210 | 1211 | input_x1 = x1 1212 | input_x2 = x2 1213 | input_x3 = x3 1214 | input_x4 = x4 1215 | input_x5 = x5 1216 | if self.size == 512: 1217 | input_x6 = x6 1218 | if self.size == 1024: 1219 | input_x6 = x6 1220 | input_x7 = x7 1221 | 1222 | # target 1223 | x1 = 
self.conv1(target_input) 1224 | x2 = self.conv2(x1) 1225 | x3 = self.conv3(x2) 1226 | x4 = self.conv4(x3) 1227 | x5 = self.conv5(x4) 1228 | if self.size == 512: 1229 | x6 = self.conv6(x5) 1230 | if self.size == 1024: 1231 | x6 = self.conv6(x5) 1232 | x7 = self.conv7(x6) 1233 | # warp- get flow 1234 | pose_mask = 1-(pose[:,0, :, :] == 0).float().unsqueeze(1) 1235 | flow = self.uv2target(pose.int(), target_coor) 1236 | # warp- resize flow 1237 | f1 = torch.nn.functional.interpolate(flow, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True) 1238 | f2 = torch.nn.functional.interpolate(flow, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True) 1239 | f3 = torch.nn.functional.interpolate(flow, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True) 1240 | f4 = torch.nn.functional.interpolate(flow, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True) 1241 | f5 = torch.nn.functional.interpolate(flow, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True) 1242 | if self.size == 512: 1243 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 1244 | if self.size == 1024: 1245 | f6 = torch.nn.functional.interpolate(flow, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 1246 | f7 = torch.nn.functional.interpolate(flow, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True) 1247 | # warp- now warp 1248 | x1 = torch.nn.functional.grid_sample(x1, f1.permute(0,2,3,1)) 1249 | x2 = torch.nn.functional.grid_sample(x2, f2.permute(0,2,3,1)) 1250 | x3 = torch.nn.functional.grid_sample(x3, f3.permute(0,2,3,1)) 1251 | x4 = torch.nn.functional.grid_sample(x4, f4.permute(0,2,3,1)) 1252 | x5 = torch.nn.functional.grid_sample(x5, f5.permute(0,2,3,1)) 1253 | if self.size == 512: 1254 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1)) 1255 | if self.size == 1024: 1256 | x6 = torch.nn.functional.grid_sample(x6, f6.permute(0,2,3,1)) 1257 | x7 = torch.nn.functional.grid_sample(x7, f7.permute(0,2,3,1)) 1258 | # mask features 1259 | p1 = torch.nn.functional.interpolate(pose_mask, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True) 1260 | p2 = torch.nn.functional.interpolate(pose_mask, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True) 1261 | p3 = torch.nn.functional.interpolate(pose_mask, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True) 1262 | p4 = torch.nn.functional.interpolate(pose_mask, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True) 1263 | p5 = torch.nn.functional.interpolate(pose_mask, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True) 1264 | if self.size == 512: 1265 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 1266 | if self.size == 1024: 1267 | p6 = torch.nn.functional.interpolate(pose_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 1268 | p7 = torch.nn.functional.interpolate(pose_mask, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True) 1269 | x1 = x1 * p1 1270 | x2 = x2 * p2 1271 | x3 = x3 * p3 1272 | x4 = x4 * p4 1273 | x5 = x5 * p5 1274 | if self.size == 512: 1275 | x6 = x6 * p6 1276 | if self.size == 1024: 1277 | x6 = x6 * p6 1278 | x7 = x7 * p7 1279 | 1280 | target_x1 = x1 1281 | target_x2 = x2 1282 | target_x3 = x3 1283 | target_x4 = x4 1284 | target_x5 = x5 1285 | if self.size == 512: 1286 | target_x6 = x6 
1287 | if self.size == 1024: 1288 | target_x6 = x6 1289 | target_x7 = x7 1290 | 1291 | # body part to transfer 1292 | target_parts = torch.round(self.uv2target(pose.int(), self.uv_parts.to(pose))).int() 1293 | if self.part == 'face': 1294 | face_mask_1 = target_parts==23-1 1295 | face_mask_2 = target_parts==24-1 1296 | face_mask = (face_mask_1+face_mask_2).float() 1297 | part_mask = face_mask 1298 | elif self.part == 'lower_body': 1299 | # LOWER CLOTHES: 7,9=UpperLegRight 8,10=UpperLegLeft 11,13=LowerLegRight 12,14=LowerLegLeft 1300 | lower_mask_1 = target_parts==7-1 1301 | lower_mask_2 = target_parts==9-1 1302 | lower_mask_3 = target_parts==8-1 1303 | lower_mask_4 = target_parts==10-1 1304 | lower_mask_5 = target_parts==11-1 1305 | lower_mask_6 = target_parts==13-1 1306 | lower_mask_7 = target_parts==12-1 1307 | lower_mask_8 = target_parts==14-1 1308 | lower_mask = (lower_mask_1+lower_mask_2+lower_mask_3+lower_mask_4+lower_mask_5+lower_mask_6+lower_mask_7+lower_mask_8).float() 1309 | part_mask = lower_mask 1310 | elif self.part == 'upper_body': 1311 | # UPPER CLOTHES: 1,2=Torso 15,17=UpperArmLeft 16,18=UpperArmRight 19,21=LowerArmLeft 20,22=LowerArmRight 1312 | upper_mask_1 = target_parts==1-1 1313 | upper_mask_2 = target_parts==2-1 1314 | upper_mask_3 = target_parts==15-1 1315 | upper_mask_4 = target_parts==17-1 1316 | upper_mask_5 = target_parts==16-1 1317 | upper_mask_6 = target_parts==18-1 1318 | upper_mask_7 = target_parts==19-1 1319 | upper_mask_8 = target_parts==21-1 1320 | upper_mask_9 = target_parts==20-1 1321 | upper_mask_10 = target_parts==22-1 1322 | upper_mask = (upper_mask_1+upper_mask_2+upper_mask_3+upper_mask_4+upper_mask_5+upper_mask_6+upper_mask_7+upper_mask_8+upper_mask_9+upper_mask_10).float() 1323 | part_mask = upper_mask 1324 | else: # full body 1325 | upper_mask_1 = target_parts==1-1 1326 | upper_mask_2 = target_parts==2-1 1327 | upper_mask_3 = target_parts==15-1 1328 | upper_mask_4 = target_parts==17-1 1329 | upper_mask_5 = target_parts==16-1 1330 | upper_mask_6 = target_parts==18-1 1331 | upper_mask_7 = target_parts==19-1 1332 | upper_mask_8 = target_parts==21-1 1333 | upper_mask_9 = target_parts==20-1 1334 | upper_mask_10 = target_parts==22-1 1335 | upper_mask = (upper_mask_1+upper_mask_2+upper_mask_3+upper_mask_4+upper_mask_5+upper_mask_6+upper_mask_7+upper_mask_8+upper_mask_9+upper_mask_10).float() 1336 | lower_mask_1 = target_parts==7-1 1337 | lower_mask_2 = target_parts==9-1 1338 | lower_mask_3 = target_parts==8-1 1339 | lower_mask_4 = target_parts==10-1 1340 | lower_mask_5 = target_parts==11-1 1341 | lower_mask_6 = target_parts==13-1 1342 | lower_mask_7 = target_parts==12-1 1343 | lower_mask_8 = target_parts==14-1 1344 | lower_mask = (lower_mask_1+lower_mask_2+lower_mask_3+lower_mask_4+lower_mask_5+lower_mask_6+lower_mask_7+lower_mask_8).float() 1345 | part_mask = upper_mask + lower_mask 1346 | 1347 | part_mask1 = torch.nn.functional.interpolate(part_mask, size=(x1.shape[2], x1.shape[3]), mode='bilinear', align_corners=True) 1348 | part_mask2 = torch.nn.functional.interpolate(part_mask, size=(x2.shape[2], x2.shape[3]), mode='bilinear', align_corners=True) 1349 | part_mask3 = torch.nn.functional.interpolate(part_mask, size=(x3.shape[2], x3.shape[3]), mode='bilinear', align_corners=True) 1350 | part_mask4 = torch.nn.functional.interpolate(part_mask, size=(x4.shape[2], x4.shape[3]), mode='bilinear', align_corners=True) 1351 | part_mask5 = torch.nn.functional.interpolate(part_mask, size=(x5.shape[2], x5.shape[3]), mode='bilinear', align_corners=True) 1352 | if 
self.size == 512: 1353 | part_mask6 = torch.nn.functional.interpolate(part_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 1354 | if self.size == 1024: 1355 | part_mask6 = torch.nn.functional.interpolate(part_mask, size=(x6.shape[2], x6.shape[3]), mode='bilinear', align_corners=True) 1356 | part_mask7 = torch.nn.functional.interpolate(part_mask, size=(x7.shape[2], x7.shape[3]), mode='bilinear', align_corners=True) 1357 | 1358 | x1 = target_x1*part_mask1 + input_x1*(1-part_mask1) 1359 | x2 = target_x2*part_mask2 + input_x2*(1-part_mask2) 1360 | x3 = target_x3*part_mask3 + input_x3*(1-part_mask3) 1361 | x4 = target_x4*part_mask4 + input_x4*(1-part_mask4) 1362 | x5 = target_x5*part_mask5 + input_x5*(1-part_mask5) 1363 | if self.size == 512: 1364 | x6 = target_x6*part_mask6 + input_x6*(1-part_mask6) 1365 | if self.size == 1024: 1366 | x6 = target_x6*part_mask6 + input_x6*(1-part_mask6) 1367 | x7 = target_x7*part_mask7 + input_x7*(1-part_mask7) 1368 | 1369 | # fpn 1370 | if self.size == 1024: 1371 | F7 = self.conv71(torch.cat([x7,p7], 1)) 1372 | f6 = self.up(F7)+self.conv61(torch.cat([x6,p6], 1)) 1373 | F6 = self.conv63(f6) 1374 | f5 = self.up(F6)+self.conv51(torch.cat([x5,p5], 1)) 1375 | F5 = self.conv53(f5) 1376 | f4 = self.up(f5)+self.conv41(torch.cat([x4,p4], 1)) 1377 | F4 = self.conv43(f4) 1378 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1)) 1379 | F3 = self.conv33(f3) 1380 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1)) 1381 | F2 = self.conv23(f2) 1382 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1)) 1383 | F1 = self.conv13(f1) 1384 | elif self.size == 512: 1385 | F6 = self.conv61(torch.cat([x6,p6], 1)) 1386 | f5 = self.up(F6)+self.conv51(torch.cat([x5,p5], 1)) 1387 | F5 = self.conv53(f5) 1388 | f4 = self.up(f5)+self.conv41(torch.cat([x4,p4], 1)) 1389 | F4 = self.conv43(f4) 1390 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1)) 1391 | F3 = self.conv33(f3) 1392 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1)) 1393 | F2 = self.conv23(f2) 1394 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1)) 1395 | F1 = self.conv13(f1) 1396 | else: 1397 | F5 = self.conv51(torch.cat([x5,p5], 1)) 1398 | f4 = self.up(F5)+self.conv41(torch.cat([x4,p4], 1)) 1399 | F4 = self.conv43(f4) 1400 | f3 = self.up(f4)+self.conv31(torch.cat([x3,p3], 1)) 1401 | F3 = self.conv33(f3) 1402 | f2 = self.up(f3)+self.conv21(torch.cat([x2,p2], 1)) 1403 | F2 = self.conv23(f2) 1404 | f1 = self.up(f2)+self.conv11(torch.cat([x1,p1], 1)) 1405 | F1 = self.conv13(f1) 1406 | 1407 | if self.size == 1024: 1408 | return [F7, F6, F5, F4, F3, F2, F1], part_mask 1409 | elif self.size == 512: 1410 | return [F6, F5, F4, F3, F2, F1], part_mask 1411 | else: 1412 | return [F5, F4, F3, F2, F1], part_mask 1413 | -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = 
weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | output_padding=0, 52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p 
= calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | "aten::cudnn_convolution_backward_weight" 182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 | 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None, None 54 | 55 | 56 | class 
FusedLeakyReLUFunction(Function): 57 | @staticmethod 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | 61 | ctx.bias = bias is not None 62 | 63 | if bias is None: 64 | bias = empty 65 | 66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 67 | ctx.save_for_backward(out) 68 | ctx.negative_slope = negative_slope 69 | ctx.scale = scale 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | out, = ctx.saved_tensors 76 | 77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 79 | ) 80 | 81 | if not ctx.bias: 82 | grad_bias = None 83 | 84 | return grad_input, grad_bias, None, None 85 | 86 | 87 | class FusedLeakyReLU(nn.Module): 88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 89 | super().__init__() 90 | 91 | if bias: 92 | self.bias = nn.Parameter(torch.zeros(channel)) 93 | 94 | else: 95 | self.bias = None 96 | 97 | self.negative_slope = negative_slope 98 | self.scale = scale 99 | 100 | def forward(self, input): 101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 102 | 103 | 104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 105 | if input.device.type == "cpu": 106 | if bias is not None: 107 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 108 | return ( 109 | F.leaky_relu( 110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 111 | ) 112 | * scale 113 | ) 114 | 115 | else: 116 | return F.leaky_relu(input, negative_slope=0.2) * scale 117 | 118 | else: 119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 120 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } 
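
The pybind11 binding above is compiled on the fly by the Python wrapper that follows (`op/upfirdn2d.py`), which exposes a single `upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0))` entry point used throughout `model.py` for blurring and resampling. As a rough, CPU-only sketch of what the op computes (upsample by zero-insertion, pad, FIR-filter, then downsample), the snippet below re-implements the semantics in plain PyTorch; the function name `upfirdn2d_reference`, the tensor shapes, the padding values, and the kernel normalization are illustrative assumptions, not code from this repository.
```
# Minimal reference sketch: approximates the semantics of the compiled
# upfirdn2d op in plain PyTorch (assumption -- not the repo's CUDA kernel).
import torch
import torch.nn.functional as F


def upfirdn2d_reference(x, kernel, up=1, down=1, pad=(0, 0)):
    """x: (N, C, H, W) tensor, kernel: (kh, kw) FIR filter."""
    n, c, h, w = x.shape
    kh, kw = kernel.shape
    pad0, pad1 = pad

    # 1) upsample by inserting zeros between samples
    out = x.new_zeros(n, c, h * up, w * up)
    out[:, :, ::up, ::up] = x

    # 2) pad both axes, mirroring how pad=(p0, p1) expands to (p0, p1, p0, p1)
    out = F.pad(out, [pad0, pad1, pad0, pad1])

    # 3) true convolution with the FIR kernel: flip it, then run conv2d per channel
    weight = torch.flip(kernel, [0, 1]).view(1, 1, kh, kw).to(out)
    out = F.conv2d(out.reshape(n * c, 1, out.shape[2], out.shape[3]), weight)
    out = out.reshape(n, c, out.shape[2], out.shape[3])

    # 4) downsample by keeping every `down`-th sample
    return out[:, :, ::down, ::down]


if __name__ == "__main__":
    k = torch.tensor([1.0, 3.0, 3.0, 1.0])            # the repo's default blur taps
    kernel = (k[:, None] * k[None, :]) / (k.sum() ** 2)  # normalization assumed here
    x = torch.randn(2, 3, 16, 16)                     # illustrative shapes
    y = upfirdn2d_reference(x, kernel, up=2, down=1, pad=(2, 1))
    print(y.shape)  # (2, 3, 32, 32)
```
The output size follows the same formula as the native fallback in `op/upfirdn2d.py`: `out_h = (in_h * up + pad0 + pad1 - kernel_h + down) // down`, and likewise for the width.
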
-------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = 
out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 | pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * 
p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = 
pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | 
pandas 3 | scipy 4 | opencv-python 5 | -------------------------------------------------------------------------------- /sphereface.py: -------------------------------------------------------------------------------- 1 | # source: https://raw.githubusercontent.com/clcarwin/sphereface_pytorch/master/net_sphere.py 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch.nn import Parameter 7 | import math 8 | 9 | def myphi(x,m): 10 | x = x * m 11 | return 1-x**2/math.factorial(2)+x**4/math.factorial(4)-x**6/math.factorial(6) + \ 12 | x**8/math.factorial(8) - x**9/math.factorial(9) 13 | 14 | class AngleLinear(nn.Module): 15 | def __init__(self, in_features, out_features, m = 4, phiflag=True): 16 | super(AngleLinear, self).__init__() 17 | self.in_features = in_features 18 | self.out_features = out_features 19 | self.weight = Parameter(torch.Tensor(in_features,out_features)) 20 | self.weight.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5) 21 | self.phiflag = phiflag 22 | self.m = m 23 | self.mlambda = [ 24 | lambda x: x**0, 25 | lambda x: x**1, 26 | lambda x: 2*x**2-1, 27 | lambda x: 4*x**3-3*x, 28 | lambda x: 8*x**4-8*x**2+1, 29 | lambda x: 16*x**5-20*x**3+5*x 30 | ] 31 | 32 | def forward(self, input): 33 | x = input # size=(B,F) F is feature len 34 | w = self.weight # size=(F,Classnum) F=in_features Classnum=out_features 35 | 36 | ww = w.renorm(2,1,1e-5).mul(1e5) 37 | xlen = x.pow(2).sum(1).pow(0.5) # size=B 38 | wlen = ww.pow(2).sum(0).pow(0.5) # size=Classnum 39 | 40 | cos_theta = x.mm(ww) # size=(B,Classnum) 41 | cos_theta = cos_theta / xlen.view(-1,1) / wlen.view(1,-1) 42 | cos_theta = cos_theta.clamp(-1,1) 43 | 44 | if self.phiflag: 45 | cos_m_theta = self.mlambda[self.m](cos_theta) 46 | theta = Variable(cos_theta.data.acos()) 47 | k = (self.m*theta/3.14159265).floor() 48 | n_one = k*0.0 - 1 49 | phi_theta = (n_one**k) * cos_m_theta - 2*k 50 | else: 51 | theta = cos_theta.acos() 52 | phi_theta = myphi(theta,self.m) 53 | phi_theta = phi_theta.clamp(-1*self.m,1) 54 | 55 | cos_theta = cos_theta * xlen.view(-1,1) 56 | phi_theta = phi_theta * xlen.view(-1,1) 57 | output = (cos_theta,phi_theta) 58 | return output # size=(B,Classnum,2) 59 | 60 | 61 | class AngleLoss(nn.Module): 62 | def __init__(self, gamma=0): 63 | super(AngleLoss, self).__init__() 64 | self.gamma = gamma 65 | self.it = 0 66 | self.LambdaMin = 5.0 67 | self.LambdaMax = 1500.0 68 | self.lamb = 1500.0 69 | 70 | def forward(self, input, target): 71 | self.it += 1 72 | cos_theta,phi_theta = input 73 | target = target.view(-1,1) #size=(B,1) 74 | 75 | index = cos_theta.data * 0.0 #size=(B,Classnum) 76 | index.scatter_(1,target.data.view(-1,1),1) 77 | index = index.byte() 78 | index = Variable(index) 79 | 80 | self.lamb = max(self.LambdaMin,self.LambdaMax/(1+0.1*self.it )) 81 | output = cos_theta * 1.0 #size=(B,Classnum) 82 | output[index] -= cos_theta[index]*(1.0+0)/(1+self.lamb) 83 | output[index] += phi_theta[index]*(1.0+0)/(1+self.lamb) 84 | 85 | logpt = F.log_softmax(output) 86 | logpt = logpt.gather(1,target) 87 | logpt = logpt.view(-1) 88 | pt = Variable(logpt.data.exp()) 89 | 90 | loss = -1 * (1-pt)**self.gamma * logpt 91 | loss = loss.mean() 92 | 93 | return loss 94 | 95 | 96 | class sphere20a(nn.Module): 97 | def __init__(self,classnum=10574,feature=False): 98 | super(sphere20a, self).__init__() 99 | self.classnum = classnum 100 | self.feature = feature 101 | #input = B*3*112*96 102 | self.conv1_1 = nn.Conv2d(3,64,3,2,1) #=>B*64*56*48 103 | 
self.relu1_1 = nn.PReLU(64) 104 | self.conv1_2 = nn.Conv2d(64,64,3,1,1) 105 | self.relu1_2 = nn.PReLU(64) 106 | self.conv1_3 = nn.Conv2d(64,64,3,1,1) 107 | self.relu1_3 = nn.PReLU(64) 108 | 109 | self.conv2_1 = nn.Conv2d(64,128,3,2,1) #=>B*128*28*24 110 | self.relu2_1 = nn.PReLU(128) 111 | self.conv2_2 = nn.Conv2d(128,128,3,1,1) 112 | self.relu2_2 = nn.PReLU(128) 113 | self.conv2_3 = nn.Conv2d(128,128,3,1,1) 114 | self.relu2_3 = nn.PReLU(128) 115 | 116 | self.conv2_4 = nn.Conv2d(128,128,3,1,1) #=>B*128*28*24 117 | self.relu2_4 = nn.PReLU(128) 118 | self.conv2_5 = nn.Conv2d(128,128,3,1,1) 119 | self.relu2_5 = nn.PReLU(128) 120 | 121 | 122 | self.conv3_1 = nn.Conv2d(128,256,3,2,1) #=>B*256*14*12 123 | self.relu3_1 = nn.PReLU(256) 124 | self.conv3_2 = nn.Conv2d(256,256,3,1,1) 125 | self.relu3_2 = nn.PReLU(256) 126 | self.conv3_3 = nn.Conv2d(256,256,3,1,1) 127 | self.relu3_3 = nn.PReLU(256) 128 | 129 | self.conv3_4 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12 130 | self.relu3_4 = nn.PReLU(256) 131 | self.conv3_5 = nn.Conv2d(256,256,3,1,1) 132 | self.relu3_5 = nn.PReLU(256) 133 | 134 | self.conv3_6 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12 135 | self.relu3_6 = nn.PReLU(256) 136 | self.conv3_7 = nn.Conv2d(256,256,3,1,1) 137 | self.relu3_7 = nn.PReLU(256) 138 | 139 | self.conv3_8 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12 140 | self.relu3_8 = nn.PReLU(256) 141 | self.conv3_9 = nn.Conv2d(256,256,3,1,1) 142 | self.relu3_9 = nn.PReLU(256) 143 | 144 | self.conv4_1 = nn.Conv2d(256,512,3,2,1) #=>B*512*7*6 145 | self.relu4_1 = nn.PReLU(512) 146 | self.conv4_2 = nn.Conv2d(512,512,3,1,1) 147 | self.relu4_2 = nn.PReLU(512) 148 | self.conv4_3 = nn.Conv2d(512,512,3,1,1) 149 | self.relu4_3 = nn.PReLU(512) 150 | 151 | self.fc5 = nn.Linear(512*7*6,512) 152 | self.fc6 = AngleLinear(512,self.classnum) 153 | 154 | 155 | def forward(self, x): 156 | x = self.relu1_1(self.conv1_1(x)) 157 | x = x + self.relu1_3(self.conv1_3(self.relu1_2(self.conv1_2(x)))) 158 | 159 | x = self.relu2_1(self.conv2_1(x)) 160 | x = x + self.relu2_3(self.conv2_3(self.relu2_2(self.conv2_2(x)))) 161 | x = x + self.relu2_5(self.conv2_5(self.relu2_4(self.conv2_4(x)))) 162 | 163 | x = self.relu3_1(self.conv3_1(x)) 164 | x = x + self.relu3_3(self.conv3_3(self.relu3_2(self.conv3_2(x)))) 165 | x = x + self.relu3_5(self.conv3_5(self.relu3_4(self.conv3_4(x)))) 166 | x = x + self.relu3_7(self.conv3_7(self.relu3_6(self.conv3_6(x)))) 167 | x = x + self.relu3_9(self.conv3_9(self.relu3_8(self.conv3_8(x)))) 168 | 169 | x = self.relu4_1(self.conv4_1(x)) 170 | x = x + self.relu4_3(self.conv4_3(self.relu4_2(self.conv4_2(x)))) 171 | 172 | x = x.view(x.size(0),-1) 173 | x = self.fc5(x) 174 | if self.feature: return x 175 | 176 | x = self.fc6(x) 177 | return x 178 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from torchvision import utils 5 | from model import Generator 6 | from tqdm import tqdm 7 | from torch.utils import data 8 | import numpy as np 9 | from dataset import DeepFashionDataset 10 | 11 | 12 | def sample_data(loader): 13 | while True: 14 | for batch in loader: 15 | yield batch 16 | 17 | def generate(args, g_ema, device, mean_latent, loader): 18 | loader = sample_data(loader) 19 | with torch.no_grad(): 20 | g_ema.eval() 21 | for i in tqdm(range(args.pics)): 22 | data = next(loader) 23 | 24 | input_image = data['input_image'].float().to(device) 25 | real_img = 
data['target_image'].float().to(device) 26 | pose = data['target_pose'].float().to(device) 27 | sil = data['target_sil'].float().to(device) 28 | 29 | source_sil = data['input_sil'].float().to(device) 30 | complete_coor = data['complete_coor'].float().to(device) 31 | 32 | if args.size == 256: 33 | complete_coor = torch.nn.functional.interpolate(complete_coor, size=(256,256), mode='bilinear') 34 | 35 | appearance = torch.cat([input_image, source_sil, complete_coor], 1) 36 | 37 | sample, _ = g_ema(appearance=appearance, pose=pose) 38 | 39 | RP = data['target_right_pad'] 40 | LP = data['target_left_pad'] 41 | 42 | utils.save_image( 43 | sample[:, :, :, int(RP[0].item()):args.size-int(LP[0].item())], 44 | os.path.join(args.save_path, data['save_name'][0]), 45 | nrow=1, 46 | normalize=True, 47 | range=(-1, 1), 48 | ) 49 | 50 | 51 | if __name__ == "__main__": 52 | device = "cuda" 53 | 54 | parser = argparse.ArgumentParser(description="Generate reposing results") 55 | 56 | parser.add_argument("path", type=str, help="path to dataset") 57 | parser.add_argument("--size", type=int, default=512, help="output image size of the generator") 58 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio") 59 | parser.add_argument("--truncation_mean", type=int, default=4096, help="number of vectors to calculate mean for the truncation") 60 | parser.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier of the generator. config-f = 2, else = 1") 61 | parser.add_argument("--pretrained_model", type=str, default="posewithstyle.pt", help="pose with style pretrained model") 62 | parser.add_argument("--save_path", type=str, default="output", help="path to save output .data/output") 63 | 64 | args = parser.parse_args() 65 | 66 | args.latent = 2048 67 | args.n_mlp = 8 68 | 69 | if not os.path.exists(args.save_path): 70 | os.makedirs(args.save_path) 71 | 72 | g_ema = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device) 73 | checkpoint = torch.load(args.pretrained_model) 74 | g_ema.load_state_dict(checkpoint["g_ema"]) 75 | 76 | if args.truncation < 1: 77 | with torch.no_grad(): 78 | mean_latent = g_ema.mean_latent(args.truncation_mean) 79 | else: 80 | mean_latent = None 81 | 82 | dataset = DeepFashionDataset(args.path, 'test', args.size) 83 | loader = data.DataLoader( 84 | dataset, 85 | batch_size=1, 86 | sampler=data.SequentialSampler(dataset), 87 | drop_last=False, 88 | ) 89 | 90 | print ('Testing %d images...'%len(dataset)) 91 | args.pics = len(dataset) 92 | 93 | generate(args, g_ema, device, mean_latent, loader) 94 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import random 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn, autograd, optim 9 | from torch.nn import functional as F 10 | from torch.utils import data 11 | import torch.distributed as dist 12 | from torchvision import transforms, utils 13 | from tqdm import tqdm 14 | import time 15 | 16 | from dataset import DeepFashionDataset 17 | from model import Generator, Discriminator, VGGLoss 18 | 19 | try: 20 | import wandb 21 | except ImportError: 22 | wandb = None 23 | 24 | 25 | from distributed import ( 26 | get_rank, 27 | synchronize, 28 | reduce_loss_dict, 29 | reduce_sum, 30 | get_world_size, 31 | ) 32 | from op import conv2d_gradfix 33 | 34 | 35 | class 
AverageMeter(object): 36 | """Computes and stores the average and current value""" 37 | def __init__(self): 38 | self.reset() 39 | 40 | def reset(self): 41 | self.val = 0 42 | self.avg = 0 43 | self.sum = 0 44 | self.count = 0 45 | 46 | def update(self, val, n=1): 47 | self.val = val 48 | self.sum += val * n 49 | self.count += n 50 | self.avg = self.sum / self.count 51 | 52 | 53 | def data_sampler(dataset, shuffle, distributed): 54 | if distributed: 55 | return data.distributed.DistributedSampler(dataset) 56 | if shuffle: 57 | return data.RandomSampler(dataset) 58 | else: 59 | return data.SequentialSampler(dataset) 60 | 61 | 62 | def requires_grad(model, flag=True): 63 | for p in model.parameters(): 64 | p.requires_grad = flag 65 | 66 | 67 | def accumulate(model1, model2, decay=0.999): 68 | par1 = dict(model1.named_parameters()) 69 | par2 = dict(model2.named_parameters()) 70 | for k in par1.keys(): 71 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 72 | 73 | 74 | def sample_data(loader): 75 | while True: 76 | for batch in loader: 77 | yield batch 78 | 79 | 80 | def d_logistic_loss(real_pred, fake_pred): 81 | real_loss = F.softplus(-real_pred) 82 | fake_loss = F.softplus(fake_pred) 83 | return real_loss.mean() + fake_loss.mean() 84 | 85 | 86 | def d_r1_loss(real_pred, real_img): 87 | with conv2d_gradfix.no_weight_gradients(): 88 | grad_real, = autograd.grad( 89 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 90 | ) 91 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 92 | return grad_penalty 93 | 94 | 95 | def g_nonsaturating_loss(fake_pred): 96 | loss = F.softplus(-fake_pred).mean() 97 | return loss 98 | 99 | 100 | def set_grad_none(model, targets): 101 | for n, p in model.named_parameters(): 102 | if n in targets: 103 | p.grad = None 104 | 105 | 106 | def getFace(images, FT, LP, RP): 107 | """ 108 | images: are images where we want to get the faces 109 | FT: transform to get the aligned face 110 | LP: left pad added to the imgae 111 | RP: right pad added to the image 112 | """ 113 | faces = [] 114 | b, h, w, c = images.shape 115 | for b in range(images.shape[0]): 116 | if not (abs(FT[b]).sum() == 0): # all 3x3 elements are zero 117 | # only apply the loss to image with detected faces 118 | # need to do this per image because images are of different shape 119 | current_im = images[b][:, :, int(RP[b].item()):w-int(LP[b].item())].unsqueeze(0) 120 | theta = FT[b].unsqueeze(0)[:, :2] #bx2x3 121 | grid = torch.nn.functional.affine_grid(theta, (1, 3, 112, 96)) 122 | current_face = torch.nn.functional.grid_sample(current_im, grid) 123 | faces.append(current_face) 124 | if len(faces) == 0: 125 | return None 126 | return torch.cat(faces, 0) 127 | 128 | 129 | def train(args, loader, sampler, generator, discriminator, g_optim, d_optim, g_ema, device): 130 | pbar = range(args.epoch) 131 | 132 | if get_rank() == 0: 133 | pbar = tqdm(pbar, initial=args.start_epoch, dynamic_ncols=True, smoothing=0.01) 134 | pbar.set_description('Epoch Counter') 135 | 136 | d_loss_val = 0 137 | r1_loss = torch.tensor(0.0, device=device) 138 | g_loss_val = 0 139 | g_L1_loss_val = 0 140 | g_vgg_loss_val = 0 141 | g_l1 = torch.tensor(0.0, device=device) 142 | g_vgg = torch.tensor(0.0, device=device) 143 | g_cos = torch.tensor(0.0, device=device) 144 | loss_dict = {} 145 | 146 | criterionL1 = torch.nn.L1Loss() 147 | criterionVGG = VGGLoss(device).to(device) 148 | if args.faceloss: 149 | criterionCOS = nn.CosineSimilarity() 150 | 151 | if args.distributed: 152 | 
g_module = generator.module 153 | d_module = discriminator.module 154 | else: 155 | g_module = generator 156 | d_module = discriminator 157 | 158 | accum = 0.5 ** (32 / (10 * 1000)) 159 | 160 | for idx in pbar: 161 | epoch = idx + args.start_epoch 162 | 163 | if epoch > args.epoch: 164 | print("Done!") 165 | break 166 | 167 | if args.distributed: 168 | sampler.set_epoch(epoch) 169 | 170 | batch_time = AverageMeter() 171 | 172 | ##################################### 173 | ############ START EPOCH ############ 174 | ##################################### 175 | for i, data in enumerate(loader): 176 | batch_start_time = time.time() 177 | 178 | input_image = data['input_image'].float().to(device) 179 | real_img = data['target_image'].float().to(device) 180 | pose = data['target_pose'].float().to(device) 181 | sil = data['target_sil'].float().to(device) 182 | 183 | LeftPad = data['target_left_pad'].float().to(device) 184 | RightPad = data['target_right_pad'].float().to(device) 185 | 186 | if args.faceloss: 187 | FT = data['TargetFaceTransform'].float().to(device) 188 | real_face = getFace(real_img, FT, LeftPad, RightPad) 189 | 190 | if args.finetune: 191 | # only mask padding 192 | sil = torch.zeros((sil.shape)).float().to(device) 193 | for b in range(sil.shape[0]): 194 | w = sil.shape[3] 195 | sil[b][:, :, int(RightPad[b].item()):w-int(LeftPad[b].item())] = 1 # mask out the padding 196 | # else only focus on the foreground - initial step of training 197 | 198 | real_img = real_img * sil 199 | 200 | # appearance = human foregound + fg mask (pass coor for warping) 201 | source_sil = data['input_sil'].float().to(device) 202 | complete_coor = data['complete_coor'].float().to(device) 203 | if args.size == 256: 204 | complete_coor = torch.nn.functional.interpolate(complete_coor, size=(256, 256), mode='bilinear') 205 | if args.finetune: 206 | appearance = torch.cat([input_image, source_sil, complete_coor], 1) 207 | else: 208 | appearance = torch.cat([input_image * source_sil, source_sil, complete_coor], 1) 209 | 210 | 211 | ############ Optimize Discriminator ############ 212 | requires_grad(generator, False) 213 | requires_grad(discriminator, True) 214 | 215 | fake_img, _ = generator(appearance=appearance, pose=pose) 216 | fake_img = fake_img * sil 217 | 218 | fake_pred = discriminator(fake_img, pose=pose) 219 | real_pred = discriminator(real_img, pose=pose) 220 | d_loss = d_logistic_loss(real_pred, fake_pred) 221 | 222 | loss_dict["d"] = d_loss 223 | loss_dict["real_score"] = real_pred.mean() 224 | loss_dict["fake_score"] = fake_pred.mean() 225 | 226 | discriminator.zero_grad() 227 | d_loss.backward() 228 | d_optim.step() 229 | 230 | 231 | d_regularize = i % args.d_reg_every == 0 232 | 233 | if d_regularize: 234 | real_img.requires_grad = True 235 | 236 | real_pred = discriminator(real_img, pose=pose) 237 | r1_loss = d_r1_loss(real_pred, real_img) 238 | 239 | discriminator.zero_grad() 240 | (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() 241 | 242 | d_optim.step() 243 | 244 | loss_dict["r1"] = r1_loss 245 | 246 | 247 | ############## Optimize Generator ############## 248 | requires_grad(generator, True) 249 | requires_grad(discriminator, False) 250 | 251 | fake_img, _ = generator(appearance=appearance, pose=pose) 252 | fake_img = fake_img * sil 253 | 254 | fake_pred = discriminator(fake_img, pose=pose) 255 | g_loss = g_nonsaturating_loss(fake_pred) 256 | 257 | loss_dict["g"] = g_loss 258 | 259 | ## reconstruction loss: L1 and VGG loss + face identity loss 260 | g_l1 = 
criterionL1(fake_img, real_img) 261 | g_loss += g_l1 262 | g_vgg = criterionVGG(fake_img, real_img) 263 | g_loss += g_vgg 264 | 265 | loss_dict["g_L1"] = g_l1 266 | loss_dict["g_vgg"] = g_vgg 267 | 268 | if args.faceloss and (real_face is not None): 269 | fake_face = getFace(fake_img, FT, LeftPad, RightPad) 270 | features_real_face = sphereface_net(real_face) 271 | features_fake_face = sphereface_net(fake_face) 272 | g_cos = 1. - criterionCOS(features_real_face, features_fake_face).mean() 273 | g_loss += g_cos 274 | 275 | loss_dict["g_cos"] = g_cos 276 | 277 | generator.zero_grad() 278 | g_loss.backward() 279 | g_optim.step() 280 | 281 | 282 | ############ Optimization Done ############ 283 | accumulate(g_ema, g_module, accum) 284 | 285 | loss_reduced = reduce_loss_dict(loss_dict) 286 | 287 | d_loss_val = loss_reduced["d"].mean().item() 288 | g_loss_val = loss_reduced["g"].mean().item() 289 | g_L1_loss_val = loss_reduced["g_L1"].mean().item() 290 | g_cos_loss_val = loss_reduced["g_cos"].mean().item() 291 | g_vgg_loss_val = loss_reduced["g_vgg"].mean().item() 292 | r1_val = loss_reduced["r1"].mean().item() 293 | real_score_val = loss_reduced["real_score"].mean().item() 294 | fake_score_val = loss_reduced["fake_score"].mean().item() 295 | 296 | batch_time.update(time.time() - batch_start_time) 297 | 298 | if i % 100 == 0: 299 | print('Epoch: [{0}/{1}] Iter: [{2}/{3}]\t' 300 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(epoch, args.epoch, i, len(loader), batch_time=batch_time) 301 | + 302 | f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; g_L1: {g_L1_loss_val:.4f}; g_vgg: {g_vgg_loss_val:.4f}; g_cos: {g_cos_loss_val:.4f}; r1: {r1_val:.4f}; " 303 | ) 304 | 305 | if get_rank() == 0: 306 | if wandb and args.wandb: 307 | wandb.log( 308 | { 309 | "Generator": g_loss_val, 310 | "Discriminator": d_loss_val, 311 | "R1": r1_val, 312 | "Real Score": real_score_val, 313 | "Fake Score": fake_score_val, 314 | "Generator_L1": g_L1_loss_val, 315 | "Generator_vgg": g_vgg_loss_val, 316 | "Generator_facecos": g_cos_loss_val, 317 | } 318 | ) 319 | 320 | if i % 5000 == 0: 321 | with torch.no_grad(): 322 | g_ema.eval() 323 | sample, _ = g_ema(appearance=appearance[:args.n_sample], pose=pose[:args.n_sample]) 324 | sample = sample * sil 325 | utils.save_image( 326 | sample, 327 | os.path.join('sample', args.name, f"epoch_{str(epoch)}_iter_{str(i)}.png"), 328 | nrow=int(args.n_sample ** 0.5), 329 | normalize=True, 330 | range=(-1, 1), 331 | ) 332 | 333 | if i % 5000 == 0: 334 | torch.save( 335 | { 336 | "g": g_module.state_dict(), 337 | "d": d_module.state_dict(), 338 | "g_ema": g_ema.state_dict(), 339 | "g_optim": g_optim.state_dict(), 340 | "d_optim": d_optim.state_dict(), 341 | "args": args, 342 | }, 343 | os.path.join('checkpoint', args.name, f"epoch_{str(epoch)}_iter_{str(i)}.pt"), 344 | ) 345 | 346 | ################################### 347 | ############ END EPOCH ############ 348 | ################################### 349 | if get_rank() == 0: 350 | torch.save( 351 | { 352 | "g": g_module.state_dict(), 353 | "d": d_module.state_dict(), 354 | "g_ema": g_ema.state_dict(), 355 | "g_optim": g_optim.state_dict(), 356 | "d_optim": d_optim.state_dict(), 357 | "args": args, 358 | }, 359 | os.path.join('checkpoint', args.name, f"epoch_{str(epoch)}.pt"), 360 | ) 361 | 362 | 363 | if __name__ == "__main__": 364 | device = "cuda" 365 | 366 | parser = argparse.ArgumentParser(description="Pose with Style trainer") 367 | 368 | parser.add_argument("path", type=str, help="path to the lmdb dataset") 369 | 
parser.add_argument("--name", type=str, help="name of experiment") 370 | parser.add_argument("--epoch", type=int, default=50, help="total training epochs") 371 | parser.add_argument("--batch", type=int, default=4, help="batch sizes for each gpus") 372 | parser.add_argument("--workers", type=int, default=4, help="batch sizes for each gpus") 373 | parser.add_argument("--n_sample", type=int, default=4, help="number of the samples generated during training") 374 | parser.add_argument("--size", type=int, default=512, help="image sizes for the model") 375 | parser.add_argument("--r1", type=float, default=10, help="weight of the r1 regularization") 376 | parser.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier factor for the model. config-f = 2, else = 1") 377 | parser.add_argument( 378 | "--d_reg_every", 379 | type=int, 380 | default=16, 381 | help="interval of the applying r1 regularization", 382 | ) 383 | parser.add_argument( 384 | "--g_reg_every", 385 | type=int, 386 | default=4, 387 | help="interval of the applying path length regularization", 388 | ) 389 | parser.add_argument("--ckpt", type=str, default=None, help="path to the checkpoints to resume training") 390 | parser.add_argument("--lr", type=float, default=0.002, help="learning rate") 391 | parser.add_argument("--wandb", action="store_true", help="use weights and biases logging") 392 | parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training") 393 | parser.add_argument("--faceloss", action="store_true", help="add face loss when faces are detected") 394 | parser.add_argument("--finetune", action="store_true", help="finetune to handle background- second step of training.") 395 | 396 | 397 | args = parser.parse_args() 398 | 399 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 400 | args.distributed = n_gpu > 1 401 | 402 | if args.distributed: 403 | print ('Distributed Training Mode.') 404 | torch.cuda.set_device(args.local_rank) 405 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 406 | synchronize() 407 | 408 | if get_rank() == 0: 409 | if not os.path.exists(os.path.join('checkpoint', args.name)): 410 | os.makedirs(os.path.join('checkpoint', args.name)) 411 | if not os.path.exists(os.path.join('sample', args.name)): 412 | os.makedirs(os.path.join('sample', args.name)) 413 | 414 | args.latent = 2048 415 | args.n_mlp = 8 416 | 417 | args.start_epoch = 0 418 | 419 | if args.finetune and (args.ckpt is None): 420 | print ('to finetune the model, please specify --ckpt.') 421 | import sys 422 | sys.exit() 423 | 424 | # define models 425 | generator = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device) 426 | discriminator = Discriminator(args.size, channel_multiplier=args.channel_multiplier).to(device) 427 | g_ema = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device) 428 | g_ema.eval() 429 | accumulate(g_ema, generator, 0) 430 | 431 | if args.faceloss: 432 | import sphereface 433 | sphereface_net = getattr(sphereface, 'sphere20a')() 434 | sphereface_net.load_state_dict(torch.load(os.path.join(args.path, 'resources', 'sphere20a_20171020.pth'))) 435 | sphereface_net.to(device) 436 | sphereface_net.eval() 437 | sphereface_net.feature = True 438 | 439 | 440 | g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) 441 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) 442 | 443 | g_optim = optim.Adam( 444 | 
generator.parameters(), 445 | lr=args.lr * g_reg_ratio, 446 | betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), 447 | ) 448 | d_optim = optim.Adam( 449 | discriminator.parameters(), 450 | lr=args.lr * d_reg_ratio, 451 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), 452 | ) 453 | 454 | if args.ckpt is not None: 455 | print("load model:", args.ckpt) 456 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 457 | 458 | try: 459 | ckpt_name = os.path.basename(args.ckpt) 460 | args.start_epoch = int(os.path.splitext(ckpt_name)[0].split('_')[1])+1 # asuming saving as epoch_1_iter_1000.pt or epoch_1.pt 461 | except ValueError: 462 | pass 463 | 464 | generator.load_state_dict(ckpt["g"]) 465 | discriminator.load_state_dict(ckpt["d"]) 466 | g_ema.load_state_dict(ckpt["g_ema"]) 467 | 468 | g_optim.load_state_dict(ckpt["g_optim"]) 469 | d_optim.load_state_dict(ckpt["d_optim"]) 470 | 471 | if args.distributed: 472 | generator = nn.parallel.DistributedDataParallel( 473 | generator, 474 | device_ids=[args.local_rank], 475 | output_device=args.local_rank, 476 | broadcast_buffers=False, 477 | ) 478 | discriminator = nn.parallel.DistributedDataParallel( 479 | discriminator, 480 | device_ids=[args.local_rank], 481 | output_device=args.local_rank, 482 | broadcast_buffers=False, 483 | ) 484 | 485 | dataset = DeepFashionDataset(args.path, 'train', args.size) 486 | sampler = data_sampler(dataset, shuffle=True, distributed=args.distributed) 487 | loader = data.DataLoader( 488 | dataset, 489 | batch_size=args.batch, 490 | sampler=sampler, 491 | drop_last=True, 492 | pin_memory=True, 493 | num_workers=args.workers, 494 | shuffle=False, 495 | ) 496 | 497 | if get_rank() == 0 and (wandb is not None) and args.wandb: 498 | wandb.init(project=args.name) 499 | 500 | train(args, loader, sampler, generator, discriminator, g_optim, d_optim, g_ema, device) 501 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BadourAlBahar/pose-with-style/f344f3bcad621cb2faf8b2cb83a1c38975871072/util/__init__.py -------------------------------------------------------------------------------- /util/complete_coor.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | import torch 5 | from coordinate_completion_model import define_G as define_CCM 6 | 7 | def pad_PIL(pil_img, top, right, bottom, left, color=(0, 0, 0)): 8 | width, height = pil_img.size 9 | new_width = width + right + left 10 | new_height = height + top + bottom 11 | result = Image.new(pil_img.mode, (new_width, new_height), color) 12 | result.paste(pil_img, (left, top)) 13 | return result 14 | 15 | def save_image(image_numpy, image_path): 16 | image_pil = Image.fromarray(image_numpy.astype(np.uint8)) 17 | image_pil.save(image_path) 18 | 19 | import argparse 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dataroot', type=str, help="path to DeepFashion dataset") 22 | parser.add_argument('--coordinates_path', type=str, help="path to partial coordinates dataset") 23 | parser.add_argument('--image_file', type=str, help="path to image file to process. 
ex: ./train.lst") 24 | parser.add_argument('--phase', type=str, help="train or test") 25 | parser.add_argument('--save_path', type=str, help="path to save results") 26 | parser.add_argument('--pretrained_model', type=str, help="path to save results") 27 | 28 | args = parser.parse_args() 29 | 30 | images_file = args.image_file 31 | phase = args.phase 32 | 33 | if not os.path.exists(os.path.join(args.save_path, phase)): 34 | os.makedirs(os.path.join(args.save_path, phase)) 35 | 36 | images = [] 37 | f = open(images_file, 'r') 38 | for lines in f: 39 | lines = lines.strip() 40 | images.append(lines) 41 | 42 | coor_completion_generator = define_CCM().cuda() 43 | CCM_checkpoint = torch.load(args.pretrained_model) 44 | coor_completion_generator.load_state_dict(CCM_checkpoint["g"]) 45 | coor_completion_generator.eval() 46 | for param in coor_completion_generator.parameters(): 47 | coor_completion_generator.requires_grad = False 48 | 49 | for i in range(len(images)): 50 | print ('%d/%d'%(i, len(images))) 51 | im_name = images[i] 52 | 53 | # get image 54 | path = os.path.join(args.dataroot, phase, im_name) 55 | im = Image.open(path) 56 | w, h = im.size 57 | 58 | # get uv coordinates 59 | uvcoor_root = os.path.join(args.coordinates_path, phase) 60 | uv_coor_path = os.path.join(uvcoor_root, im_name.split('.')[0]+'_uv_coor.npy') 61 | uv_mask_path = os.path.join(uvcoor_root, im_name.split('.')[0]+'_uv_mask.png') 62 | uv_symm_mask_path = os.path.join(uvcoor_root, im_name.split('.')[0]+'_uv_symm_mask.png') 63 | 64 | if (os.path.exists(uv_coor_path)): 65 | # read high-resolution coordinates 66 | uv_coor = np.load(uv_coor_path) 67 | uv_mask = np.array(Image.open(uv_mask_path))/255 68 | uv_symm_mask = np.array(Image.open(uv_symm_mask_path))/255 69 | 70 | # uv coor 71 | shift = int((h-w)/2) 72 | uv_coor[:,:,0] = uv_coor[:,:,0] + shift # put in center 73 | uv_coor = ((2*uv_coor/(h-1))-1) 74 | uv_coor = uv_coor*np.expand_dims(uv_mask,2) + (-10*(1-np.expand_dims(uv_mask,2))) 75 | 76 | x1 = shift 77 | x2 = h-(w+x1) 78 | im = pad_PIL(im, 0, x2, 0, x1, color=(0, 0, 0)) 79 | 80 | ## coordinate completion 81 | uv_coor_pytorch = torch.from_numpy(uv_coor).float().permute(2, 0, 1).unsqueeze(0) # from h,w,c to 1,c,h,w 82 | uv_mask_pytorch = torch.from_numpy(uv_mask).unsqueeze(0).unsqueeze(0).float() #1xchw 83 | with torch.no_grad(): 84 | coor_completion_generator.eval() 85 | complete_coor = coor_completion_generator(uv_coor_pytorch.cuda(), uv_mask_pytorch.cuda()) 86 | uv_coor = complete_coor[0].permute(1,2,0).data.cpu().numpy() 87 | uv_confidence = np.stack([uv_mask-uv_symm_mask, uv_symm_mask, 1-uv_mask], 2) 88 | 89 | im = torch.from_numpy(np.array(im)).permute(2, 0, 1).unsqueeze(0).float() 90 | rgb_uv = torch.nn.functional.grid_sample(im.cuda(), complete_coor.permute(0,2,3,1).cuda()) 91 | rgb_uv = rgb_uv[0].permute(1,2,0).data.cpu().numpy() 92 | 93 | # saving 94 | save_image(rgb_uv, os.path.join(args.save_path, phase, im_name.split('.jpg')[0]+'.png')) 95 | np.save(os.path.join(args.save_path, phase, '%s_uv_coor.npy'%(im_name.split('.')[0])), uv_coor) 96 | save_image(uv_confidence*255, os.path.join(args.save_path, phase, '%s_conf.png'%(im_name.split('.')[0]))) 97 | -------------------------------------------------------------------------------- /util/coordinate_completion_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | def init_weights(net, init_type='normal', init_gain=0.02): 6 | """Initialize 
network weights. 7 | 8 | Parameters: 9 | net (network) -- network to be initialized 10 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 11 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 12 | 13 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 14 | work better for some applications. Feel free to try yourself. 15 | """ 16 | def init_func(m): # define the initialization function 17 | classname = m.__class__.__name__ 18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 19 | if init_type == 'normal': 20 | init.normal_(m.weight.data, 0.0, init_gain) 21 | elif init_type == 'xavier': 22 | init.xavier_normal_(m.weight.data, gain=init_gain) 23 | elif init_type == 'kaiming': 24 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 25 | elif init_type == 'orthogonal': 26 | init.orthogonal_(m.weight.data, gain=init_gain) 27 | else: 28 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 29 | if hasattr(m, 'bias') and m.bias is not None: 30 | init.constant_(m.bias.data, 0.0) 31 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 32 | init.normal_(m.weight.data, 1.0, init_gain) 33 | init.constant_(m.bias.data, 0.0) 34 | 35 | print('initialize network with %s' % init_type) 36 | net.apply(init_func) # apply the initialization function 37 | return net 38 | 39 | def define_G(init_type='normal', init_gain=0.02): 40 | net = CoordinateCompletion() 41 | return init_weights(net, init_type, init_gain) 42 | 43 | class CoordinateCompletion(nn.Module): 44 | def __init__(self): 45 | super(CoordinateCompletion, self).__init__() 46 | self.generator = CoorGenerator(input_nc=2+1, output_nc=2, tanh=True) 47 | 48 | def forward(self, coor_xy, UV_texture_mask): 49 | complete_coor = self.generator(torch.cat((coor_xy, UV_texture_mask), 1)) 50 | return complete_coor 51 | 52 | class CoorGenerator(nn.Module): 53 | def __init__(self, input_nc, output_nc, ngf=32, batch_norm=True, spectral_norm=False, tanh=True): 54 | super(CoorGenerator, self).__init__() 55 | 56 | block = GatedConv2dWithActivation 57 | activation = nn.ELU(inplace=True) 58 | 59 | model = [block(input_nc, ngf, kernel_size=5, stride=1, padding=2, 60 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation),\ 61 | block(ngf, ngf*2, kernel_size=3, stride=2, padding=1, 62 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation),\ 63 | block(ngf*2, ngf*2, kernel_size=3, stride=1, padding=1, 64 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation),\ 65 | block(ngf*2, ngf*4, kernel_size=3, stride=2, padding=1, 66 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 67 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1, 68 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 69 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1, 70 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 71 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1+1, dilation=2, 72 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 73 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1+3, dilation=4, 74 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 75 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1+7, 
dilation=8, 76 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 77 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1+7, dilation=8, 78 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 79 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1, dilation=1, 80 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 81 | block(ngf*4, ngf*4, kernel_size=3, stride=1, padding=1, dilation=1, 82 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 83 | nn.Upsample(scale_factor=2, mode='bilinear'),\ 84 | block(ngf*4, ngf*2, kernel_size=3, stride=1, padding=1, dilation=1, 85 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 86 | nn.Upsample(scale_factor=2, mode='bilinear'),\ 87 | block(ngf*2, ngf, kernel_size=3, stride=1, padding=1, dilation=1, 88 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 89 | block(ngf, int(ngf/2), kernel_size=3, stride=1, padding=1, dilation=1, 90 | batch_norm=batch_norm, spectral_norm=spectral_norm, activation=activation), \ 91 | nn.Conv2d(int(ngf/2), output_nc, kernel_size=3, stride=1, padding=1, dilation=1)] 92 | if tanh: 93 | model += [ nn.Tanh()] 94 | 95 | self.model = nn.Sequential(*model) 96 | 97 | def forward(self, input): 98 | out = self.model(input) 99 | return out 100 | 101 | class GatedConv2dWithActivation(torch.nn.Module): 102 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,batch_norm=True, spectral_norm=False, activation=torch.nn.ELU(inplace=True)): 103 | super(GatedConv2dWithActivation, self).__init__() 104 | self.batch_norm = batch_norm 105 | 106 | self.pad = nn.ReflectionPad2d(padding) 107 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 108 | self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 109 | if spectral_norm: 110 | self.conv2d = torch.nn.utils.spectral_norm(self.conv2d) 111 | self.mask_conv2d = torch.nn.utils.spectral_norm(self.mask_conv2d) 112 | self.activation = activation 113 | self.sigmoid = torch.nn.Sigmoid() 114 | 115 | if self.batch_norm: 116 | self.batch_norm2d = torch.nn.BatchNorm2d(out_channels) 117 | 118 | def forward(self, input): 119 | padded_in = self.pad(input) 120 | x = self.conv2d(padded_in) 121 | mask = self.mask_conv2d(padded_in) 122 | gated_mask = self.sigmoid(mask) 123 | if self.batch_norm: 124 | x = self.batch_norm2d(x) 125 | if self.activation is not None: 126 | x = self.activation(x) 127 | x = x * gated_mask 128 | return x 129 | -------------------------------------------------------------------------------- /util/dp2coor.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | from scipy.interpolate import griddata 5 | import cv2 6 | import argparse 7 | 8 | def getSymXYcoordinates(iuv, resolution = 256): 9 | xy, xyMask = getXYcoor(iuv, resolution = resolution) 10 | f_xy, f_xyMask = getXYcoor(flip_iuv(np.copy(iuv)), resolution = resolution) 11 | f_xyMask = np.clip(f_xyMask-xyMask, a_min=0, a_max=1) 12 | # combine actual + symmetric 13 | combined_texture = xy*np.expand_dims(xyMask,2) + f_xy*np.expand_dims(f_xyMask,2) 14 | combined_mask = np.clip(xyMask+f_xyMask, a_min=0, a_max=1) 15 | return combined_texture, combined_mask, f_xyMask 16 | 17 | def flip_iuv(iuv): 18 | 
POINT_LABEL_SYMMETRIES = [ 0, 1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, 23] 19 | i = iuv[:,:,0] 20 | u = iuv[:,:,1] 21 | v = iuv[:,:,2] 22 | i_old = np.copy(i) 23 | for part in range(24): 24 | if (part + 1) in i_old: 25 | annot_indices_i = i_old == (part + 1) 26 | if POINT_LABEL_SYMMETRIES[part + 1] != part + 1: 27 | i[annot_indices_i] = POINT_LABEL_SYMMETRIES[part + 1] 28 | if part == 22 or part == 23 or part == 2 or part == 3 : #head and hands 29 | u[annot_indices_i] = 255-u[annot_indices_i] 30 | if part == 0 or part == 1: # torso 31 | v[annot_indices_i] = 255-v[annot_indices_i] 32 | return np.stack([i,u,v],2) 33 | 34 | def getXYcoor(iuv, resolution = 256): 35 | x, y, u, v = mapper(iuv, resolution) 36 | # A meshgrid of pixel coordinates 37 | nx, ny = resolution, resolution 38 | X, Y = np.meshgrid(np.arange(0, nx, 1), np.arange(0, ny, 1)) 39 | ## get x,y coordinates 40 | uv_y = griddata((v, u), y, (Y, X), method='linear') 41 | uv_y_ = griddata((v, u), y, (Y, X), method='nearest') 42 | uv_y[np.isnan(uv_y)] = uv_y_[np.isnan(uv_y)] 43 | uv_x = griddata((v, u), x, (Y, X), method='linear') 44 | uv_x_ = griddata((v, u), x, (Y, X), method='nearest') 45 | uv_x[np.isnan(uv_x)] = uv_x_[np.isnan(uv_x)] 46 | # get mask 47 | uv_mask = np.zeros((ny,nx)) 48 | uv_mask[np.ceil(v).astype(int),np.ceil(u).astype(int)]=1 49 | uv_mask[np.floor(v).astype(int),np.floor(u).astype(int)]=1 50 | uv_mask[np.ceil(v).astype(int),np.floor(u).astype(int)]=1 51 | uv_mask[np.floor(v).astype(int),np.ceil(u).astype(int)]=1 52 | kernel = np.ones((3,3),np.uint8) 53 | uv_mask_d = cv2.dilate(uv_mask,kernel,iterations = 1) 54 | # update 55 | coor_x = uv_x * uv_mask_d 56 | coor_y = uv_y * uv_mask_d 57 | coor_xy = np.stack([coor_x, coor_y], 2) 58 | return coor_xy, uv_mask_d 59 | 60 | def mapper(iuv, resolution=256): 61 | dp_uv_lookup_256_np = np.load('util/dp_uv_lookup_256.npy') 62 | H, W, _ = iuv.shape 63 | iuv_raw = iuv[iuv[:, :, 0] > 0] 64 | x = np.linspace(0, W-1, W).astype(np.int) 65 | y = np.linspace(0, H-1, H).astype(np.int) 66 | xx, yy = np.meshgrid(x, y) 67 | xx_rgb = xx[iuv[:, :, 0] > 0] 68 | yy_rgb = yy[iuv[:, :, 0] > 0] 69 | # modify i to start from 0... 0-23 70 | i = iuv_raw[:, 0] - 1 71 | u = iuv_raw[:, 1] 72 | v = iuv_raw[:, 2] 73 | uv_smpl = dp_uv_lookup_256_np[ 74 | i.astype(np.int), 75 | v.astype(np.int), 76 | u.astype(np.int) 77 | ] 78 | u_f = uv_smpl[:, 0] * (resolution - 1) 79 | v_f = (1 - uv_smpl[:, 1]) * (resolution - 1) 80 | return xx_rgb, yy_rgb, u_f, v_f 81 | 82 | 83 | if __name__ == "__main__": 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--image_file', type=str, help="path to image file to process. 
ex: ./train.lst") 86 | parser.add_argument("--save_path", type=str, help="path to save the uv data") 87 | parser.add_argument("--dp_path", type=str, help="path to densepose data") 88 | args = parser.parse_args() 89 | 90 | if not os.path.exists(args.save_path): 91 | os.makedirs(args.save_path) 92 | 93 | images = [] 94 | f = open(args.image_file, 'r') 95 | for lines in f: 96 | lines = lines.strip() 97 | images.append(lines) 98 | 99 | for i in range(len(images)): 100 | im_name = images[i] 101 | print ('%d/%d'%(i+1, len(images))) 102 | 103 | dp = os.path.join(args.dp_path, im_name.split('.')[0]+'_iuv.png') 104 | 105 | iuv = np.array(Image.open(dp)) 106 | h, w, _ = iuv.shape 107 | if np.sum(iuv[:,:,0]==0)==(h*w): 108 | print ('no human: invalid image %d: %s'%(i, im_name)) 109 | else: 110 | uv_coor, uv_mask, uv_symm_mask = getSymXYcoordinates(iuv, resolution = 512) 111 | np.save(os.path.join(args.save_path, '%s_uv_coor.npy'%(im_name.split('.')[0])), uv_coor) 112 | mask_im = Image.fromarray((uv_mask*255).astype(np.uint8)) 113 | mask_im.save(os.path.join(args.save_path, im_name.split('.')[0]+'_uv_mask.png')) 114 | mask_im = Image.fromarray((uv_symm_mask*255).astype(np.uint8)) 115 | mask_im.save(os.path.join(args.save_path, im_name.split('.')[0]+'_uv_symm_mask.png')) 116 | -------------------------------------------------------------------------------- /util/generate_fashion_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from PIL import Image 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--dataroot', type=str, help="path to dataset") 8 | args = parser.parse_args() 9 | 10 | IMG_EXTENSIONS = [ 11 | '.jpg', '.JPG', '.jpeg', '.JPEG', 12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 13 | ] 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def make_dataset(dir): 19 | images = [] 20 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 21 | 22 | train_root = os.path.join(dir, 'train') 23 | if not os.path.exists(train_root): 24 | os.mkdir(train_root) 25 | 26 | test_root = os.path.join(dir, 'test') 27 | if not os.path.exists(test_root): 28 | os.mkdir(test_root) 29 | 30 | train_images = [] 31 | train_f = open(os.path.join(dir, 'tools', 'train.lst'), 'r') 32 | for lines in train_f: 33 | lines = lines.strip() 34 | if lines.endswith('.jpg'): 35 | train_images.append(lines) 36 | 37 | test_images = [] 38 | test_f = open(os.path.join(dir, 'tools', 'test.lst'), 'r') 39 | for lines in test_f: 40 | lines = lines.strip() 41 | if lines.endswith('.jpg'): 42 | test_images.append(lines) 43 | 44 | 45 | for root, _, fnames in sorted(os.walk(os.path.join(dir, 'img_highres'))): 46 | for fname in fnames: 47 | if is_image_file(fname): 48 | path = os.path.join(root, fname) 49 | path_names = path.split('/') 50 | print(path_names) 51 | 52 | path_names = path_names[4:] 53 | del path_names[1] 54 | path_names[0] = 'fashion' 55 | path_names[3] = path_names[3].replace('_', '') 56 | path_names[4] = path_names[4].split('_')[0] + "_" + "".join(path_names[4].split('_')[1:]) 57 | path_names = "".join(path_names) 58 | 59 | if path_names in train_images: 60 | shutil.copy(path, os.path.join(train_root, path_names)) 61 | print('saving -- %s'%os.path.join(train_root, path_names)) 62 | elif path_names in test_images: 63 | shutil.copy(path, os.path.join(test_root, path_names)) 64 | print('saving -- %s'%os.path.join(train_root, 
path_names)) 65 | 66 | make_dataset(args.dataroot) 67 | -------------------------------------------------------------------------------- /util/pickle2png.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from PIL import Image 4 | import argparse 5 | import os 6 | 7 | 8 | def save_image(image_numpy, image_path): 9 | image_pil = Image.fromarray(image_numpy.astype(np.uint8)) 10 | image_pil.save(image_path) 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--pickle_file', type=str, help="path to pickle file") 14 | parser.add_argument("--save_path", type=str, help="path to save the png images") 15 | args = parser.parse_args() 16 | 17 | # READING 18 | f = open(args.pickle_file, 'rb') 19 | data = pickle.load(f) 20 | data_size = len(data) 21 | print ('Will process %d images'%data_size) 22 | 23 | # save path 24 | if not os.path.exists(args.save_path): 25 | os.makedirs(args.save_path) 26 | 27 | for img_id in range(data_size): 28 | # assuming we always have 1 person 29 | name = data[img_id]['file_name'] 30 | iuv_image_name = name.split('/')[-1].split('.')[0]+ '_iuv.png' 31 | iuv_name = os.path.join(args.save_path, iuv_image_name) 32 | size = np.array(Image.open(name).convert('RGB')).shape 33 | wrapped_iuv = np.zeros(size) 34 | 35 | print ('Processing %d/%d: %s'%(img_id+1, data_size, iuv_image_name)) 36 | num_instances = len(data[img_id]['scores']) 37 | if num_instances == 0: 38 | print ('%s has no person.'%iuv_image_name) 39 | # file_object.write(iuv_image_name+'\n')  # file_object is not defined in this script; log skipped images here if needed 40 | else: 41 | # get results - process first detected human 42 | instance_id = 0 43 | # process highest score detected human 44 | # instance_id = data[img_id]['scores'].numpy().tolist().index(max(data[img_id]['scores'].numpy().tolist())) 45 | pred_densepose_result = data[img_id]['pred_densepose'][instance_id] 46 | bbox_xyxy = data[img_id]['pred_boxes_XYXY'][instance_id] 47 | i = pred_densepose_result.labels 48 | uv = pred_densepose_result.uv * 255 49 | iuv = np.concatenate((np.expand_dims(i, 0), uv), axis=0) 50 | # 3xhxw to hxwx3 51 | iuv_arr = np.transpose(iuv, (1, 2, 0)) 52 | # wrap iuv to size of image 53 | wrapped_iuv[int(bbox_xyxy[1]):iuv_arr.shape[0]+int(bbox_xyxy[1]), int(bbox_xyxy[0]):iuv_arr.shape[1]+int(bbox_xyxy[0]), :] = iuv_arr 54 | save_image(wrapped_iuv, iuv_name) 55 | --------------------------------------------------------------------------------
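Note on `op/upfirdn2d.py` (above): the module exposes a single fused upsample–FIR–downsample entry point, `upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0))`, dispatching to the compiled CUDA kernel for GPU tensors and to `upfirdn2d_native` on CPU. Below is a minimal usage sketch, not part of the repository, showing the 2x-upsample configuration that StyleGAN2-style blur layers typically use; the tensor shapes, the [1, 3, 3, 1] kernel, and the CUDA device are illustrative assumptions, and the import assumes the extension in `op/` builds successfully at import time.

```python
import torch
from op.upfirdn2d import upfirdn2d

x = torch.randn(1, 3, 64, 64, device="cuda")           # NCHW feature map
k = torch.tensor([1.0, 3.0, 3.0, 1.0], device="cuda")  # separable binomial taps
k = k[None, :] * k[:, None]                             # 4x4 blur kernel
k = k / k.sum()                                         # normalize

# StyleGAN2 upsample convention: p = k_w - factor = 2, so pad = ((p + 1) // 2 + factor - 1, p // 2) = (2, 1),
# and the kernel is scaled by factor**2 to compensate for the zero-insertion of upsampling.
out = upfirdn2d(x, k * 4, up=2, down=1, pad=(2, 1))
print(out.shape)  # torch.Size([1, 3, 128, 128])
```

The output size follows the same formula the Python wrapper and the CUDA kernel both compute, `out = (in * up + pad0 + pad1 - k + down) // down`, so with `up=2`, `down=1`, `pad=(2, 1)` and a 4-tap kernel a 64x64 input becomes 128x128.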