├── .gitignore ├── LICENSE ├── README.md ├── assets ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── dataset_logo.png ├── dataset_overview.png ├── results.png └── ui.png ├── configs ├── index_pred_net.yml ├── parsing_gen.yml ├── parsing_token.yml ├── sample_from_parsing.yml ├── sample_from_pose.yml ├── sampler.yml ├── vqvae_bottom.yml └── vqvae_top.yml ├── data ├── __init__.py ├── mask_dataset.py ├── parsing_generation_segm_attr_dataset.py ├── pose_attr_dataset.py └── segm_attr_dataset.py ├── environment └── text2human_env.yaml ├── models ├── __init__.py ├── archs │ ├── __init__.py │ ├── fcn_arch.py │ ├── shape_attr_embedding_arch.py │ ├── transformer_arch.py │ ├── unet_arch.py │ └── vqgan_arch.py ├── hierarchy_inference_model.py ├── hierarchy_vqgan_model.py ├── losses │ ├── __init__.py │ ├── accuracy.py │ ├── cross_entropy_loss.py │ ├── segmentation_loss.py │ └── vqgan_loss.py ├── parsing_gen_model.py ├── sample_model.py ├── transformer_model.py └── vqgan_model.py ├── sample_from_parsing.py ├── sample_from_pose.py ├── train_index_prediction.py ├── train_parsing_gen.py ├── train_parsing_token.py ├── train_sampler.py ├── train_vqvae.py ├── ui ├── __init__.py ├── color_blocks │ ├── class_bag.png │ ├── class_belt.png │ ├── class_bg.png │ ├── class_dress.png │ ├── class_earstuds.png │ ├── class_eyeglass.png │ ├── class_face.png │ ├── class_footwear.png │ ├── class_glove.png │ ├── class_hair.png │ ├── class_headwear.png │ ├── class_leggings.png │ ├── class_necklace.png │ ├── class_neckwear.png │ ├── class_outer.png │ ├── class_pants.png │ ├── class_ring.png │ ├── class_rompers.png │ ├── class_skin.png │ ├── class_skirt.png │ ├── class_socks.png │ ├── class_tie.png │ ├── class_top.png │ └── class_wrist.png ├── icons │ ├── icon_palette.png │ └── icon_title.png ├── mouse_event.py └── ui.py ├── ui_demo.py ├── ui_util ├── __init__.py └── config.py └── utils ├── __init__.py ├── language_utils.py ├── logger.py ├── options.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .cache/ 3 | datasets/* 4 | experiments/* 5 | tb_logger/* 6 | results/* 7 | *.png 8 | *.txt 9 | *.pth -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2022 S-Lab 4 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 8 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 9 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text2Human - Official PyTorch Implementation 2 | 3 | 4 | 5 | This repository provides the official PyTorch implementation for the following paper: 6 | 7 | **Text2Human: Text-Driven Controllable Human Image Generation**
8 | [Yuming Jiang](https://yumingj.github.io/), [Shuai Yang](https://williamyang1991.github.io/), [Haonan Qiu](http://haonanqiu.com/), [Wayne Wu](https://wywu.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) and [Ziwei Liu](https://liuziwei7.github.io/)
9 | In ACM Transactions on Graphics (Proceedings of SIGGRAPH), 2022. 10 | 11 | From [MMLab@NTU](https://www.mmlab-ntu.com/index.html) affiliated with S-Lab, Nanyang Technological University and SenseTime Research. 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
The lady wears a short-sleeve T-shirt with pure color pattern, and a short and denim skirt. The man wears a long and floral shirt, and long pants with the pure color pattern. A lady is wearing a sleeveless pure-color shirt and long jeans. The man wears a short-sleeve T-shirt with the pure color pattern and a short pants with the pure color pattern.
27 | 28 | [**[Project Page]**](https://yumingj.github.io/projects/Text2Human.html) | [**[Paper]**](https://arxiv.org/pdf/2205.15996.pdf) | [**[Dataset]**](https://github.com/yumingj/DeepFashion-MultiModal) | [**[Demo Video]**](https://youtu.be/yKh4VORA_E0) | [**[Gradio Web Demo]**](https://huggingface.co/spaces/CVPR/Text2Human) 29 | 30 | 31 | ## Updates 32 | 33 | - [09/2022] :fire::fire::fire:**We have released a high-quality 3D human generative model [EVA3D](https://hongfz16.github.io/projects/EVA3D.html)!**:fire::fire::fire: 34 | - [07/2022] Release the model trained on [SHHQ dataset](https://stylegan-human.github.io/)! 35 | - [07/2022] Try out the web demo of [drawings-to-human](https://huggingface.co/spaces/CVPR/drawings-to-human)! [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/CVPR/drawings-to-human). 36 | - [06/2022] Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/CVPR/Text2Human) 37 | - [05/2022] Paper and demo video are released. 38 | - [05/2022] Code is released. 39 | - [05/2022] This website is created. 40 | 41 | 42 | ## Installation 43 | **Clone this repo:** 44 | ```bash 45 | git clone https://github.com/yumingj/Text2Human.git 46 | cd Text2Human 47 | ``` 48 | **Dependencies:** 49 | 50 | All dependencies for defining the environment are provided in `environment/text2human_env.yaml`. 51 | We recommend using [Anaconda](https://docs.anaconda.com/anaconda/install/) to manage the python environment: 52 | ```bash 53 | conda env create -f ./environment/text2human_env.yaml 54 | conda activate text2human 55 | pip install mmcv-full==1.2.1 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html 56 | pip install mmsegmentation==0.9.0 57 | conda install -c huggingface tokenizers=0.9.4 58 | conda install -c huggingface transformers=4.0.0 59 | conda install -c conda-forge sentence-transformers=2.0.0 60 | ``` 61 | 62 | If it doesn't work, you may need to install the following packages on your own: 63 | - Python 3.6 64 | - PyTorch 1.7.1 65 | - CUDA 10.1 66 | - [sentence-transformers](https://huggingface.co/sentence-transformers) 2.0.0 67 | - [tokenizers](https://pypi.org/project/tokenizers/) 0.9.4 68 | - [transformers](https://huggingface.co/docs/transformers/installation) 4.0.0 69 | 70 | ## (1) Dataset Preparation 71 | 72 | In this work, we contribute a large-scale high-quality dataset with rich multi-modal annotations named [DeepFashion-MultiModal](https://github.com/yumingj/DeepFashion-MultiModal) Dataset. 73 | Here we pre-processed the raw annotations of the original dataset for the task of text-driven controllable human image generation. The pre-processing pipeline consists of: 74 | - align the human body in the center of the images according to the human pose 75 | - fuse the clothing color and clothing fabric annotations into one texture annotation 76 | - do some annotation cleaning and image filtering 77 | - split the whole dataset into the training set and testing set 78 | 79 | You can download our processed dataset from this [Google Drive](https://drive.google.com/file/d/1KIoFfRZNQVn6RV_wTxG2wZmY8f2T_84B/view?usp=sharing). If you want to access the raw annotations, please refer to the [DeepFashion-MultiModal](https://github.com/yumingj/DeepFashion-MultiModal) Dataset. 
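The shape and texture annotation files referenced in the directory layout below are plain whitespace-separated text: each row holds an image filename followed by integer labels, which is exactly how the dataset classes under `data/` read them. A minimal reading sketch under that assumption (the helper `read_shape_ann` is ours for illustration, not part of the repo):
```python
# Minimal sketch: parse a shape annotation file such as
# ./datasets/shape_ann/train_ann_file.txt, whose rows look like
# "<image_name> <attr_0> <attr_1> ...", mirroring how
# data/parsing_generation_segm_attr_dataset.py reads it.
def read_shape_ann(path):
    entries = []
    with open(path, 'r') as f:
        for row in f:
            fields = row.split()
            if not fields:
                continue
            entries.append((fields[0], [int(v) for v in fields[1:]]))
    return entries

# Texture files (e.g. ./datasets/texture_ann/train/upper_fused.txt) follow the
# same idea with a single integer texture label per image.
```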
80 | 81 | After downloading the dataset, unzip the file and put them under the dataset folder with the following structure: 82 | ``` 83 | ./datasets 84 | ├── train_images 85 | ├── xxx.png 86 | ... 87 | ├── xxx.png 88 | └── xxx.png 89 | ├── test_images 90 | % the same structure as in train_images 91 | ├── densepose 92 | % the same structure as in train_images 93 | ├── segm 94 | % the same structure as in train_images 95 | ├── shape_ann 96 | ├── test_ann_file.txt 97 | ├── train_ann_file.txt 98 | └── val_ann_file.txt 99 | └── texture_ann 100 | ├── test 101 | ├── lower_fused.txt 102 | ├── outer_fused.txt 103 | └── upper_fused.txt 104 | ├── train 105 | % the same files as in test 106 | └── val 107 | % the same files as in test 108 | ``` 109 | 110 | ## (2) Sampling 111 | 112 | ### HuggingFace Demo 113 | [Full Web Demo](https://huggingface.co/spaces/CVPR/Text2Human)[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/CVPR/Text2Human) 114 | 115 | [Drawing-to-human](https://huggingface.co/spaces/CVPR/drawings-to-human)[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/CVPR/drawings-to-human) 116 | 117 | 118 | ### Colab 119 | [Unofficial Demo](https://colab.research.google.com/drive/1AVwbqLwMp_Gz3KTCgBTtnGVtXIlCZDPk#scrollTo=wMeHXDu11ebH) implemented by [@neverix](https://github.com/neverix). 120 | 121 | 122 | ### Pretrained Models 123 | 124 | Pretrained models can be downloaded from the model zoo. Unzip the file and put them under the pretrained_models folder with the following structure: 125 | ``` 126 | pretrained_models 127 | ├── index_pred_net.pth 128 | ├── parsing_gen.pth 129 | ├── parsing_token.pth 130 | ├── sampler.pth 131 | ├── vqvae_bottom.pth 132 | └── vqvae_top.pth 133 | ``` 134 | 135 | #### Model Zoo 136 | | Model | Dataset | Annotations | 137 | | :--- | :---- | :---- | 138 | | [Standard Model](https://drive.google.com/file/d/1VyI8_AbPwAUaZJPaPba8zxsFIWumlDen/view?usp=sharing) | [DeepFashion-Multimodal](https://github.com/yumingj/DeepFashion-MultiModal) | Follow the dataset preparation in Step(1) | 139 | | [Extended Model](https://drive.google.com/file/d/1hK1Yu2PA03UuDhewu_sC-WGiEO7h-j6G/view?usp=sharing) | [SHHQ](https://stylegan-human.github.io/) | Replace the annotations with the following ones: [densepose](https://drive.google.com/file/d/1nWRAdjoBqAjrFGaxtNClsvh6JHB8NOnn/view?usp=sharing), [segm](https://drive.google.com/file/d/1tz5i4zjPfYn1fWc5nGCflajQBrdruGj2/view?usp=sharing), [shape](https://drive.google.com/file/d/1Cqo62ffCKuiCiAPyIQ_HT7JqpLQQvCGu/view?usp=sharing), [texture](https://drive.google.com/file/d/1xyFyHGvlp-Qly7t8TT_IlSX7YJlikyy3/view?usp=sharing) | 140 | 141 | **Remark**: For fair research comparisons, it is suggested to use the standard model. 142 | 143 | ### Generation from Paring Maps 144 | You can generate images from given parsing maps and pre-defined texture annotations: 145 | ```python 146 | python sample_from_parsing.py -opt ./configs/sample_from_parsing.yml 147 | ``` 148 | The results are saved in the folder `./results/sampling_from_parsing`. 149 | 150 | ### Generation from Poses 151 | You can generate images from given human poses and pre-defined clothing shape and texture annotations: 152 | ```python 153 | python sample_from_pose.py -opt ./configs/sample_from_pose.yml 154 | ``` 155 | 156 | **Remarks**: The above two scripts generate images without language interactions. 
If you want to generate images using texts, you can use the notebook or our user interface. 157 | 158 | ### User Interface 159 | 160 | ```python 161 | python ui_demo.py 162 | ``` 163 | 164 | 165 | The descriptions for shapes should follow the following format: 166 | ``` 167 | , , , , , ... 168 | 169 | Note: The outer clothing type and accessories can be omitted. 170 | 171 | Examples: 172 | man, sleeveless T-shirt, long pants 173 | woman, short-sleeve T-shirt, short jeans 174 | ``` 175 | 176 | The descriptions for textures should follow the following format: 177 | ``` 178 | , , 179 | 180 | Note: Currently, we only support 5 types of textures, i.e., pure color, stripe/spline, plaid/lattice, 181 | floral, denim. Your inputs should be restricted to these textures. 182 | ``` 183 | 184 | ## (3) Training Text2Human 185 | 186 | ### Stage I: Pose to Parsing 187 | Train the parsing generation network. If you want to skip the training of this network, you can download our pretrained model from [here](https://drive.google.com/file/d/1MNyFLGqIQcOMg_HhgwCmKqdwfQSjeg_6/view?usp=sharing). 188 | ```python 189 | python train_parsing_gen.py -opt ./configs/parsing_gen.yml 190 | ``` 191 | 192 | ### Stage II: Parsing to Human 193 | 194 | **Step 1: Train the top level of the hierarchical VQVAE.** 195 | We provide our pretrained model [here](https://drive.google.com/file/d/1TwypUg85gPFJtMwBLUjVS66FKR3oaTz8/view?usp=sharing). This model is trained by: 196 | ```python 197 | python train_vqvae.py -opt ./configs/vqvae_top.yml 198 | ``` 199 | 200 | **Step 2: Train the bottom level of the hierarchical VQVAE.** 201 | We provide our pretrained model [here](https://drive.google.com/file/d/15hzbY-RG-ILgzUqqGC0qMzlS4OayPdRH/view?usp=sharing). This model is trained by: 202 | ```python 203 | python train_vqvae.py -opt ./configs/vqvae_bottom.yml 204 | ``` 205 | 206 | **Stage 3 & 4: Train the sampler with mixture-of-experts.** To train the sampler, we first need to train a model to tokenize the parsing maps. You can access our pretrained parsing maps [here](https://drive.google.com/file/d/1GLHoOeCP6sMao1-R63ahJMJF7-J00uir/view?usp=sharing). 207 | ```python 208 | python train_parsing_token.py -opt ./configs/parsing_token.yml 209 | ``` 210 | 211 | With the parsing tokenization model, the sampler is trained by: 212 | ```python 213 | python train_sampler.py -opt ./configs/sampler.yml 214 | ``` 215 | Our pretrained sampler is provided [here](https://drive.google.com/file/d/1OQO_kG2fK7eKiG1VJH1OL782X71UQAmS/view?usp=sharing). 216 | 217 | **Stage 5: Train the index prediction network.** 218 | We provide our pretrained index prediction network [here](https://drive.google.com/file/d/1rqhkQD-JGd7YBeIfDvMV-vjfbNHpIhYm/view?usp=sharing). It is trained by: 219 | ```python 220 | python train_index_prediction.py -opt ./configs/index_pred_net.yml 221 | ``` 222 | 223 | 224 | **Remarks**: In the config files, we use the path to our models as the required pretrained models. If you want to train the models from scratch, please replace the path to your own one. We set the numbers of the training epochs as large numbers and you can choose the best epoch for each model. For your reference, our pretrained parsing generation network is trained for 50 epochs, top-level VQVAE is trained for 135 epochs, bottom-level VQVAE is trained for 70 epochs, parsing tokenization network is trained for 20 epochs, sampler is trained for 95 epochs, and the index prediction network is trained for 70 epochs. 
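The pretrained-model paths mentioned in the remark above are ordinary keys in the YAML configs (for example `top_vae_path`, `bot_vae_path`, `pretrained_sampler`). A minimal sketch of overriding one of them before launching Stage II, Step 2; this uses plain PyYAML instead of the repo's own option parsing in `utils/options.py`, and the checkpoint path is a placeholder:
```python
import yaml

# Illustration only: the training scripts normally read configs via the -opt flag
# through utils/options.py; plain PyYAML is used here just to show which key changes.
with open('./configs/vqvae_bottom.yml', 'r') as f:
    cfg = yaml.safe_load(f)

# Use your own top-level checkpoint instead of ./pretrained_models/vqvae_top.pth
# (the path below is a placeholder, not a file shipped with the repo).
cfg['top_vae_path'] = './experiments/vqvae_top/models/epoch_best.pth'

with open('./configs/vqvae_bottom_custom.yml', 'w') as f:
    yaml.safe_dump(cfg, f)
```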
225 | 226 | ## (4) Results 227 | 228 | Please visit our [Project Page](https://yumingj.github.io/projects/Text2Human.html#results) to view more results.
229 | You can select the attributes to customize the desired human images. 230 | [ 231 | ](https://yumingj.github.io/projects/Text2Human.html#results) 232 | 233 | ## DeepFashion-MultiModal Dataset 234 | 235 | 236 | 237 | In this work, we also propose **DeepFashion-MultiModal**, a large-scale high-quality human dataset with rich multi-modal annotations. It has the following properties: 238 | 1. It contains 44,096 high-resolution human images, including 12,701 full body human images. 239 | 2. For each full body image, we **manually annotate** the human parsing labels of 24 classes. 240 | 3. For each full body image, we **manually annotate** the keypoints. 241 | 4. We extract DensePose for each human image. 242 | 5. Each image is **manually annotated** with attributes for both clothes shapes and textures. 243 | 6. We provide a textual description for each image. 244 | 245 | 246 | 247 | Please refer to [this repo](https://github.com/yumingj/DeepFashion-MultiModal) for more details about our proposed dataset. 248 | 249 | ## Citation 250 | 251 | If you find this work useful for your research, please consider citing our paper: 252 | 253 | ```bibtex 254 | @article{jiang2022text2human, 255 | title={Text2Human: Text-Driven Controllable Human Image Generation}, 256 | author={Jiang, Yuming and Yang, Shuai and Qiu, Haonan and Wu, Wayne and Loy, Chen Change and Liu, Ziwei}, 257 | journal={ACM Transactions on Graphics (TOG)}, 258 | volume={41}, 259 | number={4}, 260 | articleno={162}, 261 | pages={1--11}, 262 | year={2022}, 263 | publisher={ACM New York, NY, USA}, 264 | doi={10.1145/3528223.3530104}, 265 | } 266 | ``` 267 | 268 | ## Acknowledgments 269 | 270 | Part of the code is borrowed from [unleashing-transformers](https://github.com/samb-t/unleashing-transformers), [taming-transformers](https://github.com/CompVis/taming-transformers) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation).
271 | -------------------------------------------------------------------------------- /assets/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/1.png -------------------------------------------------------------------------------- /assets/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/2.png -------------------------------------------------------------------------------- /assets/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/3.png -------------------------------------------------------------------------------- /assets/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/4.png -------------------------------------------------------------------------------- /assets/dataset_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/dataset_logo.png -------------------------------------------------------------------------------- /assets/dataset_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/dataset_overview.png -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/results.png -------------------------------------------------------------------------------- /assets/ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/ui.png -------------------------------------------------------------------------------- /configs/index_pred_net.yml: -------------------------------------------------------------------------------- 1 | name: index_prediction_network 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | 6 | # dataset configs 7 | batch_size: 4 8 | num_workers: 4 9 | train_img_dir: ./datasets/train_images 10 | test_img_dir: ./datasets/test_images 11 | segm_dir: ./datasets/segm 12 | pose_dir: ./datasets/densepose 13 | train_ann_file: ./datasets/texture_ann/train 14 | val_ann_file: ./datasets/texture_ann/val 15 | test_ann_file: ./datasets/texture_ann/test 16 | downsample_factor: 2 17 | 18 | model_type: VQGANTextureAwareSpatialHierarchyInferenceModel 19 | # network configs 20 | embed_dim: 256 21 | n_embed: 1024 22 | codebook_spatial_size: 2 23 | 24 | # bottom level vqvae 25 | bot_n_embed: 512 26 | bot_double_z: false 27 | bot_z_channels: 256 28 | bot_resolution: 512 29 | bot_in_channels: 3 30 | bot_out_ch: 3 31 | bot_ch: 128 32 | bot_ch_mult: [1, 1, 2, 4] 33 | bot_num_res_blocks: 2 34 | bot_attn_resolutions: [64] 35 | bot_dropout: 0.0 36 | bot_vae_path: ./pretrained_models/vqvae_bottom.pth 37 | 38 
| # top level vqgan 39 | top_double_z: false 40 | top_z_channels: 256 41 | top_resolution: 512 42 | top_in_channels: 3 43 | top_out_ch: 3 44 | top_ch: 128 45 | top_ch_mult: [1, 1, 2, 2, 4] 46 | top_num_res_blocks: 2 47 | top_attn_resolutions: [32] 48 | top_dropout: 0.0 49 | top_vae_path: ./pretrained_models/vqvae_top.pth 50 | 51 | # unet configs 52 | encoder_in_channels: 256 53 | fc_in_channels: 64 54 | fc_in_index: 4 55 | fc_channels: 64 56 | fc_num_convs: 1 57 | fc_concat_input: False 58 | fc_dropout_ratio: 0.1 59 | fc_num_classes: 512 60 | fc_align_corners: False 61 | 62 | disc_layers: 3 63 | disc_weight_max: 1 64 | disc_start_step: 30001 65 | n_channels: 3 66 | ndf: 64 67 | nf: 128 68 | perceptual_weight: 1.0 69 | 70 | num_segm_classes: 24 71 | 72 | # training configs 73 | val_freq: 5 74 | print_freq: 100 75 | weight_decay: 0 76 | manual_seed: 2021 77 | num_epochs: 100 78 | lr: !!float 1.0e-04 79 | lr_decay: step 80 | gamma: 1.0 81 | step: 50 82 | optimizer: Adam 83 | loss_function: cross_entropy 84 | 85 | -------------------------------------------------------------------------------- /configs/parsing_gen.yml: -------------------------------------------------------------------------------- 1 | name: parsing_generation 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | 6 | # dataset configs 7 | batch_size: 8 8 | num_workers: 4 9 | segm_dir: ./datasets/segm 10 | pose_dir: ./datasets/densepose 11 | train_ann_file: ./datasets/shape_ann/train_ann_file.txt 12 | val_ann_file: ./datasets/shape_ann/val_ann_file.txt 13 | test_ann_file: ./datasets/shape_ann/test_ann_file.txt 14 | downsample_factor: 2 15 | 16 | model_type: ParsingGenModel 17 | # network configs 18 | embedder_dim: 8 19 | embedder_out_dim: 128 20 | attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2] 21 | encoder_in_channels: 1 22 | fc_in_channels: 64 23 | fc_in_index: 4 24 | fc_channels: 64 25 | fc_num_convs: 1 26 | fc_concat_input: False 27 | fc_dropout_ratio: 0.1 28 | fc_num_classes: 24 29 | fc_align_corners: False 30 | 31 | # training configs 32 | val_freq: 5 33 | print_freq: 100 34 | weight_decay: 0 35 | manual_seed: 2021 36 | num_epochs: 100 37 | lr: !!float 1e-4 38 | lr_decay: step 39 | gamma: 0.1 40 | step: 50 41 | -------------------------------------------------------------------------------- /configs/parsing_token.yml: -------------------------------------------------------------------------------- 1 | name: parsing_tokenization 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | 6 | # dataset configs 7 | batch_size: 4 8 | num_workers: 4 9 | train_img_dir: ./datasets/train_images 10 | test_img_dir: ./datasets/test_images 11 | segm_dir: ./datasets/segm 12 | pose_dir: ./datasets/densepose 13 | train_ann_file: ./datasets/texture_ann/train 14 | val_ann_file: ./datasets/texture_ann/val 15 | test_ann_file: ./datasets/texture_ann/test 16 | downsample_factor: 2 17 | 18 | model_type: VQSegmentationModel 19 | # network configs 20 | embed_dim: 32 21 | n_embed: 1024 22 | image_key: "segmentation" 23 | n_labels: 24 24 | double_z: false 25 | z_channels: 32 26 | resolution: 512 27 | in_channels: 24 28 | out_ch: 24 29 | ch: 64 30 | ch_mult: [1, 1, 2, 2, 4] 31 | num_res_blocks: 1 32 | attn_resolutions: [16] 33 | dropout: 0.0 34 | 35 | num_segm_classes: 24 36 | 37 | 38 | # training configs 39 | val_freq: 5 40 | print_freq: 100 41 | weight_decay: 0 42 | manual_seed: 2021 43 | num_epochs: 100 44 | lr: !!float 4.5e-05 45 | lr_decay: step 46 | gamma: 0.1 47 | step: 50 48 | 
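For reference, the token grid used by the sampler configs below follows directly from these encoder settings: `ch_mult: [1, 1, 2, 2, 4]` gives four stride-2 stages, i.e. a 16x spatial reduction, so the `latent_shape: [32, 16]` and `block_size: 512` declared there correspond to a 512x256 encoder input (the data loaders additionally halve the stored images via `downsample_factor: 2`). A small sketch of that arithmetic, as an illustration rather than repo code:
```python
# Token-grid arithmetic implied by the configs (illustration only).
ch_mult = [1, 1, 2, 2, 4]           # encoder channel multipliers (parsing_token.yml / vqvae_top.yml)
num_downsamples = len(ch_mult) - 1  # four stride-2 stages
factor = 2 ** num_downsamples       # 16x spatial reduction

input_h, input_w = 512, 256                            # encoder input size
latent_shape = (input_h // factor, input_w // factor)  # (32, 16)
block_size = latent_shape[0] * latent_shape[1]         # 512, as in the sampler configs
print(latent_shape, block_size)
```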
-------------------------------------------------------------------------------- /configs/sample_from_parsing.yml: -------------------------------------------------------------------------------- 1 | name: sample_from_parsing 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | 6 | # dataset configs 7 | batch_size: 4 8 | num_workers: 4 9 | test_img_dir: ./datasets/test_images 10 | segm_dir: ./datasets/segm 11 | pose_dir: ./datasets/densepose 12 | test_ann_file: ./datasets/texture_ann/test 13 | downsample_factor: 2 14 | 15 | model_type: SampleFromParsingModel 16 | # network configs 17 | embed_dim: 256 18 | n_embed: 1024 19 | codebook_spatial_size: 2 20 | 21 | # bottom level vqvae 22 | bot_n_embed: 512 23 | bot_codebook_spatial_size: 2 24 | bot_double_z: false 25 | bot_z_channels: 256 26 | bot_resolution: 512 27 | bot_in_channels: 3 28 | bot_out_ch: 3 29 | bot_ch: 128 30 | bot_ch_mult: [1, 1, 2, 4] 31 | bot_num_res_blocks: 2 32 | bot_attn_resolutions: [64] 33 | bot_dropout: 0.0 34 | bot_vae_path: ./pretrained_models/vqvae_bottom.pth 35 | 36 | # top level vqgan 37 | top_double_z: false 38 | top_z_channels: 256 39 | top_resolution: 512 40 | top_in_channels: 3 41 | top_out_ch: 3 42 | top_ch: 128 43 | top_ch_mult: [1, 1, 2, 2, 4] 44 | top_num_res_blocks: 2 45 | top_attn_resolutions: [32] 46 | top_dropout: 0.0 47 | top_vae_path: ./pretrained_models/vqvae_top.pth 48 | 49 | # unet configs 50 | index_pred_encoder_in_channels: 256 51 | index_pred_fc_in_channels: 64 52 | index_pred_fc_in_index: 4 53 | index_pred_fc_channels: 64 54 | index_pred_fc_num_convs: 1 55 | index_pred_fc_concat_input: False 56 | index_pred_fc_dropout_ratio: 0.1 57 | index_pred_fc_num_classes: 512 58 | index_pred_fc_align_corners: False 59 | pretrained_index_network: ./pretrained_models/index_pred_net.pth 60 | 61 | # segmentation tokenization 62 | segm_double_z: false 63 | segm_z_channels: 32 64 | segm_resolution: 512 65 | segm_in_channels: 24 66 | segm_out_ch: 24 67 | segm_ch: 64 68 | segm_ch_mult: [1, 1, 2, 2, 4] 69 | segm_num_res_blocks: 1 70 | segm_attn_resolutions: [16] 71 | segm_dropout: 0.0 72 | segm_num_segm_classes: 24 73 | segm_n_embed: 1024 74 | segm_embed_dim: 32 75 | segm_token_path: ./pretrained_models/parsing_token.pth 76 | 77 | # sampler configs 78 | codebook_size: 18432 79 | segm_codebook_size: 1024 80 | texture_codebook_size: 18 81 | bert_n_emb: 512 82 | bert_n_layers: 24 83 | bert_n_head: 8 84 | block_size: 512 # 32 x 16 85 | latent_shape: [32, 16] 86 | embd_pdrop: 0.0 87 | resid_pdrop: 0.0 88 | attn_pdrop: 0.0 89 | num_head: 18 90 | pretrained_sampler: ./pretrained_models/sampler.pth 91 | 92 | manual_seed: 2021 93 | sample_steps: 256 94 | -------------------------------------------------------------------------------- /configs/sample_from_pose.yml: -------------------------------------------------------------------------------- 1 | name: sample_from_pose 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | 6 | # dataset configs 7 | batch_size: 4 8 | num_workers: 4 9 | pose_dir: ./datasets/densepose 10 | texture_ann_file: ./datasets/texture_ann/test 11 | shape_ann_path: ./datasets/shape_ann/test_ann_file.txt 12 | downsample_factor: 2 13 | 14 | model_type: SampleFromPoseModel 15 | # network configs 16 | embed_dim: 256 17 | n_embed: 1024 18 | codebook_spatial_size: 2 19 | 20 | # bottom level vqgan 21 | bot_n_embed: 512 22 | bot_codebook_spatial_size: 2 23 | bot_double_z: false 24 | bot_z_channels: 256 25 | bot_resolution: 512 26 | bot_in_channels: 3 27 | 
bot_out_ch: 3 28 | bot_ch: 128 29 | bot_ch_mult: [1, 1, 2, 4] 30 | bot_num_res_blocks: 2 31 | bot_attn_resolutions: [64] 32 | bot_dropout: 0.0 33 | bot_vae_path: ./pretrained_models/vqvae_bottom.pth 34 | 35 | # top level vqgan 36 | top_double_z: false 37 | top_z_channels: 256 38 | top_resolution: 512 39 | top_in_channels: 3 40 | top_out_ch: 3 41 | top_ch: 128 42 | top_ch_mult: [1, 1, 2, 2, 4] 43 | top_num_res_blocks: 2 44 | top_attn_resolutions: [32] 45 | top_dropout: 0.0 46 | top_vae_path: ./pretrained_models/vqvae_top.pth 47 | 48 | # unet configs 49 | index_pred_encoder_in_channels: 256 50 | index_pred_fc_in_channels: 64 51 | index_pred_fc_in_index: 4 52 | index_pred_fc_channels: 64 53 | index_pred_fc_num_convs: 1 54 | index_pred_fc_concat_input: False 55 | index_pred_fc_dropout_ratio: 0.1 56 | index_pred_fc_num_classes: 512 57 | index_pred_fc_align_corners: False 58 | pretrained_index_network: ./pretrained_models/index_pred_net.pth 59 | 60 | # segmentation tokenization 61 | segm_double_z: false 62 | segm_z_channels: 32 63 | segm_resolution: 512 64 | segm_in_channels: 24 65 | segm_out_ch: 24 66 | segm_ch: 64 67 | segm_ch_mult: [1, 1, 2, 2, 4] 68 | segm_num_res_blocks: 1 69 | segm_attn_resolutions: [16] 70 | segm_dropout: 0.0 71 | segm_num_segm_classes: 24 72 | segm_n_embed: 1024 73 | segm_embed_dim: 32 74 | segm_token_path: ./pretrained_models/parsing_token.pth 75 | 76 | # sampler configs 77 | codebook_size: 18432 78 | segm_codebook_size: 1024 79 | texture_codebook_size: 18 80 | bert_n_emb: 512 81 | bert_n_layers: 24 82 | bert_n_head: 8 83 | block_size: 512 # 32 x 16 84 | latent_shape: [32, 16] 85 | embd_pdrop: 0.0 86 | resid_pdrop: 0.0 87 | attn_pdrop: 0.0 88 | num_head: 18 89 | pretrained_sampler: ./pretrained_models/sampler.pth 90 | 91 | # shape network configs 92 | shape_embedder_dim: 8 93 | shape_embedder_out_dim: 128 94 | shape_attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2] 95 | shape_encoder_in_channels: 1 96 | shape_fc_in_channels: 64 97 | shape_fc_in_index: 4 98 | shape_fc_channels: 64 99 | shape_fc_num_convs: 1 100 | shape_fc_concat_input: False 101 | shape_fc_dropout_ratio: 0.1 102 | shape_fc_num_classes: 24 103 | shape_fc_align_corners: False 104 | pretrained_parsing_gen: ./pretrained_models/parsing_gen.pth 105 | 106 | manual_seed: 2021 107 | sample_steps: 256 108 | -------------------------------------------------------------------------------- /configs/sampler.yml: -------------------------------------------------------------------------------- 1 | name: sampler 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | 6 | # dataset configs 7 | batch_size: 4 8 | num_workers: 1 9 | train_img_dir: ./datasets/train_images 10 | test_img_dir: ./datasets/test_images 11 | segm_dir: ./datasets/segm 12 | pose_dir: ./datasets/densepose 13 | train_ann_file: ./datasets/texture_ann/train 14 | val_ann_file: ./datasets/texture_ann/val 15 | test_ann_file: ./datasets/texture_ann/test 16 | downsample_factor: 2 17 | 18 | # pretrained models 19 | img_ae_path: ./pretrained_models/vqvae_top.pth 20 | segm_ae_path: ./pretrained_models/parsing_token.pth 21 | 22 | model_type: TransformerTextureAwareModel 23 | # network configs 24 | 25 | # image autoencoder 26 | img_embed_dim: 256 27 | img_n_embed: 1024 28 | img_double_z: false 29 | img_z_channels: 256 30 | img_resolution: 512 31 | img_in_channels: 3 32 | img_out_ch: 3 33 | img_ch: 128 34 | img_ch_mult: [1, 1, 2, 2, 4] 35 | img_num_res_blocks: 2 36 | img_attn_resolutions: [32] 37 | img_dropout: 0.0 38 | 39 | # 
segmentation tokenization 40 | segm_double_z: false 41 | segm_z_channels: 32 42 | segm_resolution: 512 43 | segm_in_channels: 24 44 | segm_out_ch: 24 45 | segm_ch: 64 46 | segm_ch_mult: [1, 1, 2, 2, 4] 47 | segm_num_res_blocks: 1 48 | segm_attn_resolutions: [16] 49 | segm_dropout: 0.0 50 | segm_num_segm_classes: 24 51 | segm_n_embed: 1024 52 | segm_embed_dim: 32 53 | 54 | # sampler configs 55 | codebook_size: 18432 56 | segm_codebook_size: 1024 57 | texture_codebook_size: 18 58 | bert_n_emb: 512 59 | bert_n_layers: 24 60 | bert_n_head: 8 61 | block_size: 512 # 32 x 16 62 | latent_shape: [32, 16] 63 | embd_pdrop: 0.0 64 | resid_pdrop: 0.0 65 | attn_pdrop: 0.0 66 | num_head: 18 67 | 68 | # loss configs 69 | loss_type: reweighted_elbo 70 | mask_schedule: random 71 | 72 | sample_steps: 256 73 | 74 | # training configs 75 | val_freq: 5 76 | print_freq: 100 77 | weight_decay: 0 78 | manual_seed: 2021 79 | num_epochs: 100 80 | lr: !!float 1e-4 81 | lr_decay: step 82 | gamma: 1.0 83 | step: 50 84 | -------------------------------------------------------------------------------- /configs/vqvae_bottom.yml: -------------------------------------------------------------------------------- 1 | name: vqvae_bottom 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | 6 | # dataset configs 7 | batch_size: 4 8 | num_workers: 4 9 | train_img_dir: ./datasets/train_images 10 | test_img_dir: ./datasets/test_images 11 | segm_dir: ./datasets/segm 12 | pose_dir: ./datasets/densepose 13 | train_ann_file: ./datasets/texture_ann/train 14 | val_ann_file: ./datasets/texture_ann/val 15 | test_ann_file: ./datasets/texture_ann/test 16 | downsample_factor: 2 17 | 18 | model_type: HierarchyVQSpatialTextureAwareModel 19 | # network configs 20 | embed_dim: 256 21 | n_embed: 1024 22 | codebook_spatial_size: 2 23 | 24 | # bottom level vqvae 25 | bot_n_embed: 512 26 | bot_double_z: false 27 | bot_z_channels: 256 28 | bot_resolution: 512 29 | bot_in_channels: 3 30 | bot_out_ch: 3 31 | bot_ch: 128 32 | bot_ch_mult: [1, 1, 2, 4] 33 | bot_num_res_blocks: 2 34 | bot_attn_resolutions: [64] 35 | bot_dropout: 0.0 36 | 37 | # top level vqgan 38 | top_double_z: false 39 | top_z_channels: 256 40 | top_resolution: 512 41 | top_in_channels: 3 42 | top_out_ch: 3 43 | top_ch: 128 44 | top_ch_mult: [1, 1, 2, 2, 4] 45 | top_num_res_blocks: 2 46 | top_attn_resolutions: [32] 47 | top_dropout: 0.0 48 | top_vae_path: ./pretrained_models/vqvae_top.pth 49 | 50 | fix_decoder: false 51 | 52 | disc_layers: 3 53 | disc_weight_max: 1 54 | disc_start_step: 1 55 | n_channels: 3 56 | ndf: 64 57 | nf: 128 58 | perceptual_weight: 1.0 59 | 60 | num_segm_classes: 24 61 | 62 | # training configs 63 | val_freq: 5 64 | print_freq: 100 65 | weight_decay: 0 66 | manual_seed: 2021 67 | num_epochs: 1000 68 | lr: !!float 1.0e-04 69 | lr_decay: step 70 | gamma: 1.0 71 | step: 50 72 | 73 | -------------------------------------------------------------------------------- /configs/vqvae_top.yml: -------------------------------------------------------------------------------- 1 | name: vqvae_top 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | 6 | # dataset configs 7 | batch_size: 4 8 | num_workers: 4 9 | train_img_dir: ./datasets/train_images 10 | test_img_dir: ./datasets/test_images 11 | segm_dir: ./datasets/segm 12 | pose_dir: ./datasets/densepose 13 | train_ann_file: ./datasets/texture_ann/train 14 | val_ann_file: ./datasets/texture_ann/val 15 | test_ann_file: ./datasets/texture_ann/test 16 | downsample_factor: 2 17 | 18 | 
model_type: VQImageSegmTextureModel 19 | # network configs 20 | embed_dim: 256 21 | n_embed: 1024 22 | double_z: false 23 | z_channels: 256 24 | resolution: 512 25 | in_channels: 3 26 | out_ch: 3 27 | ch: 128 28 | ch_mult: [1, 1, 2, 2, 4] 29 | num_res_blocks: 2 30 | attn_resolutions: [32] 31 | dropout: 0.0 32 | 33 | disc_layers: 3 34 | disc_weight_max: 1 35 | disc_start_step: 30001 36 | n_channels: 3 37 | ndf: 64 38 | nf: 128 39 | perceptual_weight: 1.0 40 | 41 | num_segm_classes: 24 42 | 43 | 44 | # training configs 45 | val_freq: 5 46 | print_freq: 100 47 | weight_decay: 0 48 | manual_seed: 2021 49 | num_epochs: 1000 50 | lr: !!float 1.0e-04 51 | lr_decay: step 52 | gamma: 1.0 53 | step: 50 54 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/data/__init__.py -------------------------------------------------------------------------------- /data/mask_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data as data 8 | from PIL import Image 9 | 10 | 11 | class MaskDataset(data.Dataset): 12 | 13 | def __init__(self, segm_dir, ann_dir, downsample_factor=2, xflip=False): 14 | 15 | self._segm_path = segm_dir 16 | self._image_fnames = [] 17 | 18 | self.downsample_factor = downsample_factor 19 | self.xflip = xflip 20 | 21 | # load attributes 22 | assert os.path.exists(f'{ann_dir}/upper_fused.txt') 23 | for idx, row in enumerate( 24 | open(os.path.join(f'{ann_dir}/upper_fused.txt'), 'r')): 25 | annotations = row.split() 26 | self._image_fnames.append(annotations[0]) 27 | 28 | def _open_file(self, path_prefix, fname): 29 | return open(os.path.join(path_prefix, fname), 'rb') 30 | 31 | def _load_segm(self, raw_idx): 32 | fname = self._image_fnames[raw_idx] 33 | fname = f'{fname[:-4]}_segm.png' 34 | with self._open_file(self._segm_path, fname) as f: 35 | segm = Image.open(f) 36 | if self.downsample_factor != 1: 37 | width, height = segm.size 38 | width = width // self.downsample_factor 39 | height = height // self.downsample_factor 40 | segm = segm.resize( 41 | size=(width, height), resample=Image.NEAREST) 42 | segm = np.array(segm) 43 | # segm = segm[:, :, np.newaxis].transpose(2, 0, 1) 44 | return segm.astype(np.float32) 45 | 46 | def __getitem__(self, index): 47 | segm = self._load_segm(index) 48 | 49 | if self.xflip and random.random() > 0.5: 50 | segm = segm[:, ::-1].copy() 51 | 52 | segm = torch.from_numpy(segm).long() 53 | 54 | return_dict = {'segm': segm, 'img_name': self._image_fnames[index]} 55 | 56 | return return_dict 57 | 58 | def __len__(self): 59 | return len(self._image_fnames) 60 | -------------------------------------------------------------------------------- /data/parsing_generation_segm_attr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data 7 | from PIL import Image 8 | 9 | 10 | class ParsingGenerationDeepFashionAttrSegmDataset(data.Dataset): 11 | 12 | def __init__(self, segm_dir, pose_dir, ann_file, downsample_factor=2): 13 | self._densepose_path = pose_dir 14 | self._segm_path = segm_dir 15 | self._image_fnames = [] 16 | self.attrs = [] 17 | 18 | 
self.downsample_factor = downsample_factor 19 | 20 | # training, ground-truth available 21 | assert os.path.exists(ann_file) 22 | for row in open(os.path.join(ann_file), 'r'): 23 | annotations = row.split() 24 | self._image_fnames.append(annotations[0]) 25 | self.attrs.append([int(i) for i in annotations[1:]]) 26 | 27 | def _open_file(self, path_prefix, fname): 28 | return open(os.path.join(path_prefix, fname), 'rb') 29 | 30 | def _load_densepose(self, raw_idx): 31 | fname = self._image_fnames[raw_idx] 32 | fname = f'{fname[:-4]}_densepose.png' 33 | with self._open_file(self._densepose_path, fname) as f: 34 | densepose = Image.open(f) 35 | if self.downsample_factor != 1: 36 | width, height = densepose.size 37 | width = width // self.downsample_factor 38 | height = height // self.downsample_factor 39 | densepose = densepose.resize( 40 | size=(width, height), resample=Image.NEAREST) 41 | # channel-wise IUV order, [3, H, W] 42 | densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1) 43 | return densepose.astype(np.float32) 44 | 45 | def _load_segm(self, raw_idx): 46 | fname = self._image_fnames[raw_idx] 47 | fname = f'{fname[:-4]}_segm.png' 48 | with self._open_file(self._segm_path, fname) as f: 49 | segm = Image.open(f) 50 | if self.downsample_factor != 1: 51 | width, height = segm.size 52 | width = width // self.downsample_factor 53 | height = height // self.downsample_factor 54 | segm = segm.resize( 55 | size=(width, height), resample=Image.NEAREST) 56 | segm = np.array(segm) 57 | return segm.astype(np.float32) 58 | 59 | def __getitem__(self, index): 60 | pose = self._load_densepose(index) 61 | segm = self._load_segm(index) 62 | attr = self.attrs[index] 63 | 64 | pose = torch.from_numpy(pose) 65 | segm = torch.LongTensor(segm) 66 | attr = torch.LongTensor(attr) 67 | 68 | pose = pose / 12. 
- 1 69 | 70 | return_dict = { 71 | 'densepose': pose, 72 | 'segm': segm, 73 | 'attr': attr, 74 | 'img_name': self._image_fnames[index] 75 | } 76 | 77 | return return_dict 78 | 79 | def __len__(self): 80 | return len(self._image_fnames) 81 | -------------------------------------------------------------------------------- /data/pose_attr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data as data 8 | from PIL import Image 9 | 10 | 11 | class DeepFashionAttrPoseDataset(data.Dataset): 12 | 13 | def __init__(self, 14 | pose_dir, 15 | texture_ann_dir, 16 | shape_ann_path, 17 | downsample_factor=2, 18 | xflip=False): 19 | self._densepose_path = pose_dir 20 | self._image_fnames_target = [] 21 | self._image_fnames = [] 22 | self.upper_fused_attrs = [] 23 | self.lower_fused_attrs = [] 24 | self.outer_fused_attrs = [] 25 | self.shape_attrs = [] 26 | 27 | self.downsample_factor = downsample_factor 28 | self.xflip = xflip 29 | 30 | # load attributes 31 | assert os.path.exists(f'{texture_ann_dir}/upper_fused.txt') 32 | for idx, row in enumerate( 33 | open(os.path.join(f'{texture_ann_dir}/upper_fused.txt'), 'r')): 34 | annotations = row.split() 35 | self._image_fnames_target.append(annotations[0]) 36 | self._image_fnames.append(f'{annotations[0].split(".")[0]}.png') 37 | self.upper_fused_attrs.append(int(annotations[1])) 38 | 39 | assert len(self._image_fnames_target) == len(self.upper_fused_attrs) 40 | 41 | assert os.path.exists(f'{texture_ann_dir}/lower_fused.txt') 42 | for idx, row in enumerate( 43 | open(os.path.join(f'{texture_ann_dir}/lower_fused.txt'), 'r')): 44 | annotations = row.split() 45 | assert self._image_fnames_target[idx] == annotations[0] 46 | self.lower_fused_attrs.append(int(annotations[1])) 47 | 48 | assert len(self._image_fnames_target) == len(self.lower_fused_attrs) 49 | 50 | assert os.path.exists(f'{texture_ann_dir}/outer_fused.txt') 51 | for idx, row in enumerate( 52 | open(os.path.join(f'{texture_ann_dir}/outer_fused.txt'), 'r')): 53 | annotations = row.split() 54 | assert self._image_fnames_target[idx] == annotations[0] 55 | self.outer_fused_attrs.append(int(annotations[1])) 56 | 57 | assert len(self._image_fnames_target) == len(self.outer_fused_attrs) 58 | 59 | assert os.path.exists(shape_ann_path) 60 | for idx, row in enumerate(open(os.path.join(shape_ann_path), 'r')): 61 | annotations = row.split() 62 | assert self._image_fnames_target[idx] == annotations[0] 63 | self.shape_attrs.append([int(i) for i in annotations[1:]]) 64 | 65 | def _open_file(self, path_prefix, fname): 66 | return open(os.path.join(path_prefix, fname), 'rb') 67 | 68 | def _load_densepose(self, raw_idx): 69 | fname = self._image_fnames[raw_idx] 70 | fname = f'{fname[:-4]}_densepose.png' 71 | with self._open_file(self._densepose_path, fname) as f: 72 | densepose = Image.open(f) 73 | if self.downsample_factor != 1: 74 | width, height = densepose.size 75 | width = width // self.downsample_factor 76 | height = height // self.downsample_factor 77 | densepose = densepose.resize( 78 | size=(width, height), resample=Image.NEAREST) 79 | # channel-wise IUV order, [3, H, W] 80 | densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1) 81 | return densepose.astype(np.float32) 82 | 83 | def __getitem__(self, index): 84 | pose = self._load_densepose(index) 85 | shape_attr = self.shape_attrs[index] 86 | shape_attr = torch.LongTensor(shape_attr) 87 | 88 | 
if self.xflip and random.random() > 0.5: 89 | pose = pose[:, :, ::-1].copy() 90 | 91 | upper_fused_attr = self.upper_fused_attrs[index] 92 | lower_fused_attr = self.lower_fused_attrs[index] 93 | outer_fused_attr = self.outer_fused_attrs[index] 94 | 95 | pose = pose / 12. - 1 96 | 97 | return_dict = { 98 | 'densepose': pose, 99 | 'img_name': self._image_fnames_target[index], 100 | 'shape_attr': shape_attr, 101 | 'upper_fused_attr': upper_fused_attr, 102 | 'lower_fused_attr': lower_fused_attr, 103 | 'outer_fused_attr': outer_fused_attr, 104 | } 105 | 106 | return return_dict 107 | 108 | def __len__(self): 109 | return len(self._image_fnames) 110 | -------------------------------------------------------------------------------- /data/segm_attr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data as data 8 | from PIL import Image 9 | 10 | 11 | class DeepFashionAttrSegmDataset(data.Dataset): 12 | 13 | def __init__(self, 14 | img_dir, 15 | segm_dir, 16 | pose_dir, 17 | ann_dir, 18 | downsample_factor=2, 19 | xflip=False): 20 | self._img_path = img_dir 21 | self._densepose_path = pose_dir 22 | self._segm_path = segm_dir 23 | self._image_fnames = [] 24 | self.upper_fused_attrs = [] 25 | self.lower_fused_attrs = [] 26 | self.outer_fused_attrs = [] 27 | 28 | self.downsample_factor = downsample_factor 29 | self.xflip = xflip 30 | 31 | # load attributes 32 | assert os.path.exists(f'{ann_dir}/upper_fused.txt') 33 | for idx, row in enumerate( 34 | open(os.path.join(f'{ann_dir}/upper_fused.txt'), 'r')): 35 | annotations = row.split() 36 | self._image_fnames.append(annotations[0]) 37 | # assert self._image_fnames[idx] == annotations[0] 38 | self.upper_fused_attrs.append(int(annotations[1])) 39 | 40 | assert len(self._image_fnames) == len(self.upper_fused_attrs) 41 | 42 | assert os.path.exists(f'{ann_dir}/lower_fused.txt') 43 | for idx, row in enumerate( 44 | open(os.path.join(f'{ann_dir}/lower_fused.txt'), 'r')): 45 | annotations = row.split() 46 | assert self._image_fnames[idx] == annotations[0] 47 | self.lower_fused_attrs.append(int(annotations[1])) 48 | 49 | assert len(self._image_fnames) == len(self.lower_fused_attrs) 50 | 51 | assert os.path.exists(f'{ann_dir}/outer_fused.txt') 52 | for idx, row in enumerate( 53 | open(os.path.join(f'{ann_dir}/outer_fused.txt'), 'r')): 54 | annotations = row.split() 55 | assert self._image_fnames[idx] == annotations[0] 56 | self.outer_fused_attrs.append(int(annotations[1])) 57 | 58 | assert len(self._image_fnames) == len(self.outer_fused_attrs) 59 | 60 | # remove the overlapping item between upper cls and lower cls 61 | # cls 21 can appear with upper clothes 62 | # cls 4 can appear with lower clothes 63 | self.upper_cls = [1., 4.] 64 | self.lower_cls = [3., 5., 21.] 65 | self.outer_cls = [2.] 66 | self.other_cls = [ 67 | 11., 18., 7., 8., 9., 10., 12., 16., 17., 19., 20., 22., 23., 15., 68 | 14., 13., 0., 6. 
69 | ] 70 | 71 | def _open_file(self, path_prefix, fname): 72 | return open(os.path.join(path_prefix, fname), 'rb') 73 | 74 | def _load_raw_image(self, raw_idx): 75 | fname = self._image_fnames[raw_idx] 76 | with self._open_file(self._img_path, fname) as f: 77 | image = Image.open(f) 78 | if self.downsample_factor != 1: 79 | width, height = image.size 80 | width = width // self.downsample_factor 81 | height = height // self.downsample_factor 82 | image = image.resize( 83 | size=(width, height), resample=Image.LANCZOS) 84 | image = np.array(image) 85 | if image.ndim == 2: 86 | image = image[:, :, np.newaxis] # HW => HWC 87 | image = image.transpose(2, 0, 1) # HWC => CHW 88 | return image 89 | 90 | def _load_densepose(self, raw_idx): 91 | fname = self._image_fnames[raw_idx] 92 | fname = f'{fname[:-4]}_densepose.png' 93 | with self._open_file(self._densepose_path, fname) as f: 94 | densepose = Image.open(f) 95 | if self.downsample_factor != 1: 96 | width, height = densepose.size 97 | width = width // self.downsample_factor 98 | height = height // self.downsample_factor 99 | densepose = densepose.resize( 100 | size=(width, height), resample=Image.NEAREST) 101 | # channel-wise IUV order, [3, H, W] 102 | densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1) 103 | return densepose.astype(np.float32) 104 | 105 | def _load_segm(self, raw_idx): 106 | fname = self._image_fnames[raw_idx] 107 | fname = f'{fname[:-4]}_segm.png' 108 | with self._open_file(self._segm_path, fname) as f: 109 | segm = Image.open(f) 110 | if self.downsample_factor != 1: 111 | width, height = segm.size 112 | width = width // self.downsample_factor 113 | height = height // self.downsample_factor 114 | segm = segm.resize( 115 | size=(width, height), resample=Image.NEAREST) 116 | segm = np.array(segm) 117 | segm = segm[:, :, np.newaxis].transpose(2, 0, 1) 118 | return segm.astype(np.float32) 119 | 120 | def __getitem__(self, index): 121 | image = self._load_raw_image(index) 122 | pose = self._load_densepose(index) 123 | segm = self._load_segm(index) 124 | 125 | if self.xflip and random.random() > 0.5: 126 | assert image.ndim == 3 # CHW 127 | image = image[:, :, ::-1].copy() 128 | pose = pose[:, :, ::-1].copy() 129 | segm = segm[:, :, ::-1].copy() 130 | 131 | image = torch.from_numpy(image) 132 | segm = torch.from_numpy(segm) 133 | 134 | upper_fused_attr = self.upper_fused_attrs[index] 135 | lower_fused_attr = self.lower_fused_attrs[index] 136 | outer_fused_attr = self.outer_fused_attrs[index] 137 | 138 | # mask 0: denotes the common codebook, 139 | # mask (attr + 1): denotes the texture-specific codebook 140 | mask = torch.zeros_like(segm) 141 | if upper_fused_attr != 17: 142 | for cls in self.upper_cls: 143 | mask[segm == cls] = upper_fused_attr + 1 144 | 145 | if lower_fused_attr != 17: 146 | for cls in self.lower_cls: 147 | mask[segm == cls] = lower_fused_attr + 1 148 | 149 | if outer_fused_attr != 17: 150 | for cls in self.outer_cls: 151 | mask[segm == cls] = outer_fused_attr + 1 152 | 153 | pose = pose / 12. 
- 1 154 | image = image / 127.5 - 1 155 | 156 | return_dict = { 157 | 'image': image, 158 | 'densepose': pose, 159 | 'segm': segm, 160 | 'texture_mask': mask, 161 | 'img_name': self._image_fnames[index] 162 | } 163 | 164 | return return_dict 165 | 166 | def __len__(self): 167 | return len(self._image_fnames) 168 | -------------------------------------------------------------------------------- /environment/text2human_env.yaml: -------------------------------------------------------------------------------- 1 | name: text2human 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - astroid=2.5=py36h06a4308_1 9 | - blas=1.0=mkl 10 | - brotlipy=0.7.0=py36h7b6447c_1000 11 | - ca-certificates=2021.10.26=h06a4308_2 12 | - certifi=2021.5.30=py36h06a4308_0 13 | - cffi=1.14.3=py36he30daa8_0 14 | - chardet=3.0.4=py36_1003 15 | - click=8.0.3=pyhd3eb1b0_0 16 | - cryptography=3.1.1=py36h1ba5d50_0 17 | - cudatoolkit=10.1.243=h6bb024c_0 18 | - dataclasses=0.8=pyh4f3eec9_6 19 | - dbus=1.13.18=hb2f20db_0 20 | - expat=2.2.10=he6710b0_2 21 | - filelock=3.4.0=pyhd3eb1b0_0 22 | - fontconfig=2.13.0=h9420a91_0 23 | - freetype=2.10.4=h5ab3b9f_0 24 | - glib=2.56.2=hd408876_0 25 | - gst-plugins-base=1.14.0=hbbd80ab_1 26 | - gstreamer=1.14.0=hb453b48_1 27 | - icu=58.2=he6710b0_3 28 | - idna=2.10=py_0 29 | - importlib-metadata=4.8.1=py36h06a4308_0 30 | - importlib_metadata=4.8.1=hd3eb1b0_0 31 | - intel-openmp=2020.2=254 32 | - isort=5.7.0=pyhd3eb1b0_0 33 | - joblib=1.0.1=pyhd3eb1b0_0 34 | - jpeg=9b=habf39ab_1 35 | - lazy-object-proxy=1.5.2=py36h27cfd23_0 36 | - lcms2=2.11=h396b838_0 37 | - ld_impl_linux-64=2.33.1=h53a641e_7 38 | - libffi=3.3=he6710b0_2 39 | - libgcc-ng=9.1.0=hdf63c60_0 40 | - libpng=1.6.37=hbc83047_0 41 | - libprotobuf=3.17.2=h4ff587b_1 42 | - libstdcxx-ng=9.1.0=hdf63c60_0 43 | - libtiff=4.2.0=h3942068_0 44 | - libuuid=1.0.3=h1bed415_2 45 | - libuv=1.40.0=h7b6447c_0 46 | - libwebp-base=1.2.0=h27cfd23_0 47 | - libxcb=1.14=h7b6447c_0 48 | - libxml2=2.9.10=hb55368b_3 49 | - lz4-c=1.9.3=h2531618_0 50 | - mccabe=0.6.1=py36_1 51 | - mkl=2020.2=256 52 | - mkl-service=2.3.0=py36he8ac12f_0 53 | - mkl_fft=1.3.0=py36h54f3939_0 54 | - mkl_random=1.1.1=py36h0573a6f_0 55 | - ncurses=6.2=he6710b0_1 56 | - ninja=1.10.2=h5e70eb0_2 57 | - numpy=1.19.2=py36h54aff64_0 58 | - numpy-base=1.19.2=py36hfa32c7d_0 59 | - olefile=0.46=py36_0 60 | - openssl=1.1.1m=h7f8727e_0 61 | - packaging=21.3=pyhd3eb1b0_0 62 | - pcre=8.44=he6710b0_0 63 | - pillow=8.1.2=py36he98fc37_0 64 | - pip=21.0.1=py36h06a4308_0 65 | - protobuf=3.17.2=py36h295c915_0 66 | - pycparser=2.20=py_2 67 | - pylint=2.7.2=py36h06a4308_1 68 | - pyopenssl=19.1.0=py_1 69 | - pyqt=5.9.2=py36h05f1152_2 70 | - pysocks=1.7.1=py36_0 71 | - python=3.6.13=hdb3f193_0 72 | - pytorch=1.7.1=py3.6_cuda10.1.243_cudnn7.6.3_0 73 | - qt=5.9.7=h5867ecd_1 74 | - readline=8.1=h27cfd23_0 75 | - regex=2021.8.3=py36h7f8727e_0 76 | - requests=2.24.0=py_0 77 | - setuptools=52.0.0=py36h06a4308_0 78 | - sip=4.19.8=py36hf484d3e_0 79 | - six=1.15.0=py36h06a4308_0 80 | - sqlite=3.35.2=hdfb4753_0 81 | - tk=8.6.10=hbc83047_0 82 | - toml=0.10.2=pyhd3eb1b0_0 83 | - torchvision=0.8.2=py36_cu101 84 | - tqdm=4.62.3=pyhd3eb1b0_1 85 | - typed-ast=1.4.2=py36h27cfd23_1 86 | - typing-extensions=3.10.0.2=hd3eb1b0_0 87 | - typing_extensions=3.10.0.2=pyh06a4308_0 88 | - urllib3=1.25.11=py_0 89 | - wheel=0.36.2=pyhd3eb1b0_0 90 | - wrapt=1.12.1=py36h7b6447c_1 91 | - xz=5.2.5=h7b6447c_0 92 | - yaml=0.2.5=h7b6447c_0 93 | - zipp=3.6.0=pyhd3eb1b0_0 94 | - 
zlib=1.2.11=h7b6447c_3 95 | - zstd=1.4.5=h9ceee32_0 96 | - pip: 97 | - addict==2.4.0 98 | - cycler==0.11.0 99 | - einops==0.4.0 100 | - kiwisolver==1.3.1 101 | - matplotlib==3.3.4 102 | - huggingface-hub==0.4.0 103 | - nltk==3.6.7 104 | - opencv-python==4.5.5.62 105 | - pyparsing==3.0.7 106 | - python-dateutil==2.8.2 107 | - pyyaml==6.0 108 | - scikit-learn==0.24.2 109 | - scipy==1.5.4 110 | - sentencepiece==0.1.96 111 | - terminaltables==3.1.10 112 | - threadpoolctl==3.0.0 113 | - yapf==0.32.0 114 | - lpips==0.1.4 115 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import importlib 3 | import logging 4 | import os.path as osp 5 | 6 | # automatically scan and import model modules 7 | # scan all the files under the 'models' folder and collect files ending with 8 | # '_model.py' 9 | model_folder = osp.dirname(osp.abspath(__file__)) 10 | model_filenames = [ 11 | osp.splitext(osp.basename(v))[0] 12 | for v in glob.glob(f'{model_folder}/*_model.py') 13 | ] 14 | # import all the model modules 15 | _model_modules = [ 16 | importlib.import_module(f'models.{file_name}') 17 | for file_name in model_filenames 18 | ] 19 | 20 | 21 | def create_model(opt): 22 | """Create model. 23 | 24 | Args: 25 | opt (dict): Configuration. It constains: 26 | model_type (str): Model type. 27 | """ 28 | model_type = opt['model_type'] 29 | 30 | # dynamically instantiation 31 | for module in _model_modules: 32 | model_cls = getattr(module, model_type, None) 33 | if model_cls is not None: 34 | break 35 | if model_cls is None: 36 | raise ValueError(f'Model {model_type} is not found.') 37 | 38 | model = model_cls(opt) 39 | 40 | logger = logging.getLogger('base') 41 | logger.info(f'Model [{model.__class__.__name__}] is created.') 42 | return model 43 | -------------------------------------------------------------------------------- /models/archs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/models/archs/__init__.py -------------------------------------------------------------------------------- /models/archs/shape_attr_embedding_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class ShapeAttrEmbedding(nn.Module): 7 | 8 | def __init__(self, dim, out_dim, cls_num_list): 9 | super(ShapeAttrEmbedding, self).__init__() 10 | 11 | for idx, cls_num in enumerate(cls_num_list): 12 | setattr( 13 | self, f'attr_{idx}', 14 | nn.Sequential( 15 | nn.Linear(cls_num, dim), nn.LeakyReLU(), 16 | nn.Linear(dim, dim))) 17 | self.cls_num_list = cls_num_list 18 | self.attr_num = len(cls_num_list) 19 | self.fusion = nn.Sequential( 20 | nn.Linear(dim * self.attr_num, out_dim), nn.LeakyReLU(), 21 | nn.Linear(out_dim, out_dim)) 22 | 23 | def forward(self, attr): 24 | attr_embedding_list = [] 25 | for idx in range(self.attr_num): 26 | attr_embed_fc = getattr(self, f'attr_{idx}') 27 | attr_embedding_list.append( 28 | attr_embed_fc( 29 | F.one_hot( 30 | attr[:, idx], 31 | num_classes=self.cls_num_list[idx]).to(torch.float32))) 32 | attr_embedding = torch.cat(attr_embedding_list, dim=1) 33 | attr_embedding = self.fusion(attr_embedding) 34 | 35 | return attr_embedding 36 | 
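A usage sketch for `ShapeAttrEmbedding` above, with `dim`, `out_dim` and the per-attribute class counts taken from `configs/parsing_gen.yml` (`embedder_dim`, `embedder_out_dim`, `attr_class_num`); the attribute values themselves are arbitrary placeholders, and the wiring into `ParsingGenModel` is not shown:
```python
import torch

from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding

# Class counts per shape attribute, as listed in configs/parsing_gen.yml (attr_class_num).
attr_class_num = [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
embedder = ShapeAttrEmbedding(dim=8, out_dim=128, cls_num_list=attr_class_num)

# One sample: an integer class index per attribute, each below its class count
# (the particular values here are placeholders).
attr = torch.LongTensor([[0, 1, 2, 0, 3, 1, 0, 4, 2, 1, 0, 0, 1, 0, 1]])

embedding = embedder(attr)  # one-hot per attribute -> per-attribute MLP -> fused MLP
print(embedding.shape)      # torch.Size([1, 128])
```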
-------------------------------------------------------------------------------- /models/archs/transformer_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class CausalSelfAttention(nn.Module): 10 | """ 11 | A vanilla multi-head masked self-attention layer with a projection at the end. 12 | It is possible to use torch.nn.MultiheadAttention here but I am including an 13 | explicit implementation here to show that there is nothing too scary here. 14 | """ 15 | 16 | def __init__(self, bert_n_emb, bert_n_head, attn_pdrop, resid_pdrop, 17 | latent_shape, sampler): 18 | super().__init__() 19 | assert bert_n_emb % bert_n_head == 0 20 | # key, query, value projections for all heads 21 | self.key = nn.Linear(bert_n_emb, bert_n_emb) 22 | self.query = nn.Linear(bert_n_emb, bert_n_emb) 23 | self.value = nn.Linear(bert_n_emb, bert_n_emb) 24 | # regularization 25 | self.attn_drop = nn.Dropout(attn_pdrop) 26 | self.resid_drop = nn.Dropout(resid_pdrop) 27 | # output projection 28 | self.proj = nn.Linear(bert_n_emb, bert_n_emb) 29 | self.n_head = bert_n_head 30 | self.causal = True if sampler == 'autoregressive' else False 31 | if self.causal: 32 | block_size = np.prod(latent_shape) 33 | mask = torch.tril(torch.ones(block_size, block_size)) 34 | self.register_buffer("mask", mask.view(1, 1, block_size, 35 | block_size)) 36 | 37 | def forward(self, x, layer_past=None): 38 | B, T, C = x.size() 39 | 40 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 41 | k = self.key(x).view(B, T, self.n_head, 42 | C // self.n_head).transpose(1, 43 | 2) # (B, nh, T, hs) 44 | q = self.query(x).view(B, T, self.n_head, 45 | C // self.n_head).transpose(1, 46 | 2) # (B, nh, T, hs) 47 | v = self.value(x).view(B, T, self.n_head, 48 | C // self.n_head).transpose(1, 49 | 2) # (B, nh, T, hs) 50 | 51 | present = torch.stack((k, v)) 52 | if self.causal and layer_past is not None: 53 | past_key, past_value = layer_past 54 | k = torch.cat((past_key, k), dim=-2) 55 | v = torch.cat((past_value, v), dim=-2) 56 | 57 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 58 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 59 | 60 | if self.causal and layer_past is None: 61 | att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) 62 | 63 | att = F.softmax(att, dim=-1) 64 | att = self.attn_drop(att) 65 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 66 | # re-assemble all head outputs side by side 67 | y = y.transpose(1, 2).contiguous().view(B, T, C) 68 | 69 | # output projection 70 | y = self.resid_drop(self.proj(y)) 71 | return y, present 72 | 73 | 74 | class Block(nn.Module): 75 | """ an unassuming Transformer block """ 76 | 77 | def __init__(self, bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop, 78 | latent_shape, sampler): 79 | super().__init__() 80 | self.ln1 = nn.LayerNorm(bert_n_emb) 81 | self.ln2 = nn.LayerNorm(bert_n_emb) 82 | self.attn = CausalSelfAttention(bert_n_emb, bert_n_head, attn_pdrop, 83 | resid_pdrop, latent_shape, sampler) 84 | self.mlp = nn.Sequential( 85 | nn.Linear(bert_n_emb, 4 * bert_n_emb), 86 | nn.GELU(), # nice 87 | nn.Linear(4 * bert_n_emb, bert_n_emb), 88 | nn.Dropout(resid_pdrop), 89 | ) 90 | 91 | def forward(self, x, layer_past=None, return_present=False): 92 | 93 | attn, present = self.attn(self.ln1(x), 
layer_past) 94 | x = x + attn 95 | x = x + self.mlp(self.ln2(x)) 96 | 97 | if layer_past is not None or return_present: 98 | return x, present 99 | return x 100 | 101 | 102 | class Transformer(nn.Module): 103 | """ the full GPT language model, with a context size of block_size """ 104 | 105 | def __init__(self, 106 | codebook_size, 107 | segm_codebook_size, 108 | bert_n_emb, 109 | bert_n_layers, 110 | bert_n_head, 111 | block_size, 112 | latent_shape, 113 | embd_pdrop, 114 | resid_pdrop, 115 | attn_pdrop, 116 | sampler='absorbing'): 117 | super().__init__() 118 | 119 | self.vocab_size = codebook_size + 1 120 | self.n_embd = bert_n_emb 121 | self.block_size = block_size 122 | self.n_layers = bert_n_layers 123 | self.codebook_size = codebook_size 124 | self.segm_codebook_size = segm_codebook_size 125 | self.causal = sampler == 'autoregressive' 126 | if self.causal: 127 | self.vocab_size = codebook_size 128 | 129 | self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd) 130 | self.pos_emb = nn.Parameter( 131 | torch.zeros(1, self.block_size, self.n_embd)) 132 | self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd) 133 | self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd)) 134 | self.drop = nn.Dropout(embd_pdrop) 135 | 136 | # transformer 137 | self.blocks = nn.Sequential(*[ 138 | Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop, 139 | latent_shape, sampler) for _ in range(self.n_layers) 140 | ]) 141 | # decoder head 142 | self.ln_f = nn.LayerNorm(self.n_embd) 143 | self.head = nn.Linear(self.n_embd, self.codebook_size, bias=False) 144 | 145 | def get_block_size(self): 146 | return self.block_size 147 | 148 | def _init_weights(self, module): 149 | if isinstance(module, (nn.Linear, nn.Embedding)): 150 | module.weight.data.normal_(mean=0.0, std=0.02) 151 | if isinstance(module, nn.Linear) and module.bias is not None: 152 | module.bias.data.zero_() 153 | elif isinstance(module, nn.LayerNorm): 154 | module.bias.data.zero_() 155 | module.weight.data.fill_(1.0) 156 | 157 | def forward(self, idx, segm_tokens, t=None): 158 | # each index maps to a (learnable) vector 159 | token_embeddings = self.tok_emb(idx) 160 | 161 | segm_embeddings = self.segm_emb(segm_tokens) 162 | 163 | if self.causal: 164 | token_embeddings = torch.cat((self.start_tok.repeat( 165 | token_embeddings.size(0), 1, 1), token_embeddings), 166 | dim=1) 167 | 168 | t = token_embeddings.shape[1] 169 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
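        # Editor's note (added comment): `self.pos_emb` is a fixed learnable table of
        # shape (1, block_size, n_embd), so only `block_size` position embeddings are
        # available. Any input whose token count (plus the prepended start token in the
        # autoregressive case) exceeds `block_size` cannot be positioned, which is what
        # the assert above guards against.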
170 | # each position maps to a (learnable) vector 171 | 172 | position_embeddings = self.pos_emb[:, :t, :] 173 | 174 | x = token_embeddings + position_embeddings + segm_embeddings 175 | x = self.drop(x) 176 | for block in self.blocks: 177 | x = block(x) 178 | x = self.ln_f(x) 179 | logits = self.head(x) 180 | 181 | return logits 182 | 183 | 184 | class TransformerMultiHead(nn.Module): 185 | """ the full GPT language model, with a context size of block_size """ 186 | 187 | def __init__(self, 188 | codebook_size, 189 | segm_codebook_size, 190 | texture_codebook_size, 191 | bert_n_emb, 192 | bert_n_layers, 193 | bert_n_head, 194 | block_size, 195 | latent_shape, 196 | embd_pdrop, 197 | resid_pdrop, 198 | attn_pdrop, 199 | num_head, 200 | sampler='absorbing'): 201 | super().__init__() 202 | 203 | self.vocab_size = codebook_size + 1 204 | self.n_embd = bert_n_emb 205 | self.block_size = block_size 206 | self.n_layers = bert_n_layers 207 | self.codebook_size = codebook_size 208 | self.segm_codebook_size = segm_codebook_size 209 | self.texture_codebook_size = texture_codebook_size 210 | self.causal = sampler == 'autoregressive' 211 | if self.causal: 212 | self.vocab_size = codebook_size 213 | 214 | self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd) 215 | self.pos_emb = nn.Parameter( 216 | torch.zeros(1, self.block_size, self.n_embd)) 217 | self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd) 218 | self.texture_emb = nn.Embedding(self.texture_codebook_size, 219 | self.n_embd) 220 | self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd)) 221 | self.drop = nn.Dropout(embd_pdrop) 222 | 223 | # transformer 224 | self.blocks = nn.Sequential(*[ 225 | Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop, 226 | latent_shape, sampler) for _ in range(self.n_layers) 227 | ]) 228 | # decoder head 229 | self.num_head = num_head 230 | self.head_class_num = codebook_size // self.num_head 231 | self.ln_f = nn.LayerNorm(self.n_embd) 232 | self.head_list = nn.ModuleList([ 233 | nn.Linear(self.n_embd, self.head_class_num, bias=False) 234 | for _ in range(self.num_head) 235 | ]) 236 | 237 | def get_block_size(self): 238 | return self.block_size 239 | 240 | def _init_weights(self, module): 241 | if isinstance(module, (nn.Linear, nn.Embedding)): 242 | module.weight.data.normal_(mean=0.0, std=0.02) 243 | if isinstance(module, nn.Linear) and module.bias is not None: 244 | module.bias.data.zero_() 245 | elif isinstance(module, nn.LayerNorm): 246 | module.bias.data.zero_() 247 | module.weight.data.fill_(1.0) 248 | 249 | def forward(self, idx, segm_tokens, texture_tokens, t=None): 250 | # each index maps to a (learnable) vector 251 | token_embeddings = self.tok_emb(idx) 252 | segm_embeddings = self.segm_emb(segm_tokens) 253 | texture_embeddings = self.texture_emb(texture_tokens) 254 | 255 | if self.causal: 256 | token_embeddings = torch.cat((self.start_tok.repeat( 257 | token_embeddings.size(0), 1, 1), token_embeddings), 258 | dim=1) 259 | 260 | t = token_embeddings.shape[1] 261 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
262 | # each position maps to a (learnable) vector 263 | 264 | position_embeddings = self.pos_emb[:, :t, :] 265 | 266 | x = token_embeddings + position_embeddings + segm_embeddings + texture_embeddings 267 | x = self.drop(x) 268 | for block in self.blocks: 269 | x = block(x) 270 | x = self.ln_f(x) 271 | logits_list = [self.head_list[i](x) for i in range(self.num_head)] 272 | 273 | return logits_list 274 | -------------------------------------------------------------------------------- /models/hierarchy_inference_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torchvision.utils import save_image 8 | 9 | from models.archs.fcn_arch import MultiHeadFCNHead 10 | from models.archs.unet_arch import UNet 11 | from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder, 12 | VectorQuantizerSpatialTextureAware, 13 | VectorQuantizerTexture) 14 | from models.losses.accuracy import accuracy 15 | from models.losses.cross_entropy_loss import CrossEntropyLoss 16 | 17 | logger = logging.getLogger('base') 18 | 19 | 20 | class VQGANTextureAwareSpatialHierarchyInferenceModel(): 21 | 22 | def __init__(self, opt): 23 | self.opt = opt 24 | self.device = torch.device('cuda') 25 | self.is_train = opt['is_train'] 26 | 27 | self.top_encoder = Encoder( 28 | ch=opt['top_ch'], 29 | num_res_blocks=opt['top_num_res_blocks'], 30 | attn_resolutions=opt['top_attn_resolutions'], 31 | ch_mult=opt['top_ch_mult'], 32 | in_channels=opt['top_in_channels'], 33 | resolution=opt['top_resolution'], 34 | z_channels=opt['top_z_channels'], 35 | double_z=opt['top_double_z'], 36 | dropout=opt['top_dropout']).to(self.device) 37 | self.decoder = Decoder( 38 | in_channels=opt['top_in_channels'], 39 | resolution=opt['top_resolution'], 40 | z_channels=opt['top_z_channels'], 41 | ch=opt['top_ch'], 42 | out_ch=opt['top_out_ch'], 43 | num_res_blocks=opt['top_num_res_blocks'], 44 | attn_resolutions=opt['top_attn_resolutions'], 45 | ch_mult=opt['top_ch_mult'], 46 | dropout=opt['top_dropout'], 47 | resamp_with_conv=True, 48 | give_pre_end=False).to(self.device) 49 | self.top_quantize = VectorQuantizerTexture( 50 | 1024, opt['embed_dim'], beta=0.25).to(self.device) 51 | self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"], 52 | opt['embed_dim'], 53 | 1).to(self.device) 54 | self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'], 55 | opt["top_z_channels"], 56 | 1).to(self.device) 57 | self.load_top_pretrain_models() 58 | 59 | self.bot_encoder = Encoder( 60 | ch=opt['bot_ch'], 61 | num_res_blocks=opt['bot_num_res_blocks'], 62 | attn_resolutions=opt['bot_attn_resolutions'], 63 | ch_mult=opt['bot_ch_mult'], 64 | in_channels=opt['bot_in_channels'], 65 | resolution=opt['bot_resolution'], 66 | z_channels=opt['bot_z_channels'], 67 | double_z=opt['bot_double_z'], 68 | dropout=opt['bot_dropout']).to(self.device) 69 | self.bot_decoder_res = DecoderRes( 70 | in_channels=opt['bot_in_channels'], 71 | resolution=opt['bot_resolution'], 72 | z_channels=opt['bot_z_channels'], 73 | ch=opt['bot_ch'], 74 | num_res_blocks=opt['bot_num_res_blocks'], 75 | ch_mult=opt['bot_ch_mult'], 76 | dropout=opt['bot_dropout'], 77 | give_pre_end=False).to(self.device) 78 | self.bot_quantize = VectorQuantizerSpatialTextureAware( 79 | opt['bot_n_embed'], 80 | opt['embed_dim'], 81 | beta=0.25, 82 | spatial_size=opt['codebook_spatial_size']).to(self.device) 83 | self.bot_quant_conv = 
torch.nn.Conv2d(opt["bot_z_channels"], 84 | opt['embed_dim'], 85 | 1).to(self.device) 86 | self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'], 87 | opt["bot_z_channels"], 88 | 1).to(self.device) 89 | 90 | self.load_bot_pretrain_network() 91 | 92 | self.guidance_encoder = UNet( 93 | in_channels=opt['encoder_in_channels']).to(self.device) 94 | self.index_decoder = MultiHeadFCNHead( 95 | in_channels=opt['fc_in_channels'], 96 | in_index=opt['fc_in_index'], 97 | channels=opt['fc_channels'], 98 | num_convs=opt['fc_num_convs'], 99 | concat_input=opt['fc_concat_input'], 100 | dropout_ratio=opt['fc_dropout_ratio'], 101 | num_classes=opt['fc_num_classes'], 102 | align_corners=opt['fc_align_corners'], 103 | num_head=18).to(self.device) 104 | 105 | self.init_training_settings() 106 | 107 | def init_training_settings(self): 108 | optim_params = [] 109 | for v in self.guidance_encoder.parameters(): 110 | if v.requires_grad: 111 | optim_params.append(v) 112 | for v in self.index_decoder.parameters(): 113 | if v.requires_grad: 114 | optim_params.append(v) 115 | # set up optimizers 116 | if self.opt['optimizer'] == 'Adam': 117 | self.optimizer = torch.optim.Adam( 118 | optim_params, 119 | self.opt['lr'], 120 | weight_decay=self.opt['weight_decay']) 121 | elif self.opt['optimizer'] == 'SGD': 122 | self.optimizer = torch.optim.SGD( 123 | optim_params, 124 | self.opt['lr'], 125 | momentum=self.opt['momentum'], 126 | weight_decay=self.opt['weight_decay']) 127 | self.log_dict = OrderedDict() 128 | if self.opt['loss_function'] == 'cross_entropy': 129 | self.loss_func = CrossEntropyLoss().to(self.device) 130 | 131 | def load_top_pretrain_models(self): 132 | # load pretrained vqgan for segmentation mask 133 | top_vae_checkpoint = torch.load(self.opt['top_vae_path']) 134 | self.top_encoder.load_state_dict( 135 | top_vae_checkpoint['encoder'], strict=True) 136 | self.decoder.load_state_dict( 137 | top_vae_checkpoint['decoder'], strict=True) 138 | self.top_quantize.load_state_dict( 139 | top_vae_checkpoint['quantize'], strict=True) 140 | self.top_quant_conv.load_state_dict( 141 | top_vae_checkpoint['quant_conv'], strict=True) 142 | self.top_post_quant_conv.load_state_dict( 143 | top_vae_checkpoint['post_quant_conv'], strict=True) 144 | self.top_encoder.eval() 145 | self.top_quantize.eval() 146 | self.top_quant_conv.eval() 147 | self.top_post_quant_conv.eval() 148 | 149 | def load_bot_pretrain_network(self): 150 | checkpoint = torch.load(self.opt['bot_vae_path']) 151 | self.bot_encoder.load_state_dict( 152 | checkpoint['bot_encoder'], strict=True) 153 | self.bot_decoder_res.load_state_dict( 154 | checkpoint['bot_decoder_res'], strict=True) 155 | self.decoder.load_state_dict(checkpoint['decoder'], strict=True) 156 | self.bot_quantize.load_state_dict( 157 | checkpoint['bot_quantize'], strict=True) 158 | self.bot_quant_conv.load_state_dict( 159 | checkpoint['bot_quant_conv'], strict=True) 160 | self.bot_post_quant_conv.load_state_dict( 161 | checkpoint['bot_post_quant_conv'], strict=True) 162 | 163 | self.bot_encoder.eval() 164 | self.bot_decoder_res.eval() 165 | self.decoder.eval() 166 | self.bot_quantize.eval() 167 | self.bot_quant_conv.eval() 168 | self.bot_post_quant_conv.eval() 169 | 170 | def top_encode(self, x, mask): 171 | h = self.top_encoder(x) 172 | h = self.top_quant_conv(h) 173 | quant, _, _ = self.top_quantize(h, mask) 174 | quant = self.top_post_quant_conv(quant) 175 | 176 | return quant, quant 177 | 178 | def feed_data(self, data): 179 | self.image = data['image'].to(self.device) 180 | 
self.texture_mask = data['texture_mask'].float().to(self.device) 181 | self.get_gt_indices() 182 | 183 | self.texture_tokens = F.interpolate( 184 | self.texture_mask, size=(32, 16), 185 | mode='nearest').view(self.image.size(0), -1).long() 186 | 187 | def bot_encode(self, x, mask): 188 | h = self.bot_encoder(x) 189 | h = self.bot_quant_conv(h) 190 | _, _, (_, _, indices_list) = self.bot_quantize(h, mask) 191 | 192 | return indices_list 193 | 194 | def get_gt_indices(self): 195 | self.quant_t, self.feature_t = self.top_encode(self.image, 196 | self.texture_mask) 197 | self.gt_indices_list = self.bot_encode(self.image, self.texture_mask) 198 | 199 | def index_to_image(self, index_bottom_list, texture_mask): 200 | quant_b = self.bot_quantize.get_codebook_entry( 201 | index_bottom_list, texture_mask, 202 | (index_bottom_list[0].size(0), index_bottom_list[0].size(1), 203 | index_bottom_list[0].size(2), 204 | self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2) 205 | quant_b = self.bot_post_quant_conv(quant_b) 206 | bot_dec_res = self.bot_decoder_res(quant_b) 207 | 208 | dec = self.decoder(self.quant_t, bot_h=bot_dec_res) 209 | 210 | return dec 211 | 212 | def get_vis(self, pred_img_index, rec_img_index, texture_mask, save_path): 213 | rec_img = self.index_to_image(rec_img_index, texture_mask) 214 | pred_img = self.index_to_image(pred_img_index, texture_mask) 215 | 216 | base_img = self.decoder(self.quant_t) 217 | img_cat = torch.cat([ 218 | self.image, 219 | rec_img, 220 | base_img, 221 | pred_img, 222 | ], dim=3).detach() 223 | img_cat = ((img_cat + 1) / 2) 224 | img_cat = img_cat.clamp_(0, 1) 225 | save_image(img_cat, save_path, nrow=1, padding=4) 226 | 227 | def optimize_parameters(self): 228 | self.guidance_encoder.train() 229 | self.index_decoder.train() 230 | 231 | self.feature_enc = self.guidance_encoder(self.feature_t) 232 | self.memory_logits_list = self.index_decoder(self.feature_enc) 233 | 234 | loss = 0 235 | for i in range(18): 236 | loss += self.loss_func( 237 | self.memory_logits_list[i], 238 | self.gt_indices_list[i], 239 | ignore_index=-1) 240 | 241 | self.optimizer.zero_grad() 242 | loss.backward() 243 | self.optimizer.step() 244 | 245 | self.log_dict['loss_total'] = loss 246 | 247 | def inference(self, data_loader, save_dir): 248 | self.guidance_encoder.eval() 249 | self.index_decoder.eval() 250 | 251 | acc = 0 252 | num = 0 253 | 254 | for _, data in enumerate(data_loader): 255 | self.feed_data(data) 256 | img_name = data['img_name'] 257 | 258 | num += self.image.size(0) 259 | 260 | texture_mask_flatten = self.texture_tokens.view(-1) 261 | min_encodings_indices_list = [ 262 | torch.full( 263 | texture_mask_flatten.size(), 264 | fill_value=-1, 265 | dtype=torch.long, 266 | device=texture_mask_flatten.device) for _ in range(18) 267 | ] 268 | with torch.no_grad(): 269 | self.feature_enc = self.guidance_encoder(self.feature_t) 270 | memory_logits_list = self.index_decoder(self.feature_enc) 271 | # memory_indices_pred = memory_logits.argmax(dim=1) 272 | batch_acc = 0 273 | for codebook_idx, memory_logits in enumerate(memory_logits_list): 274 | region_of_interest = texture_mask_flatten == codebook_idx 275 | if torch.sum(region_of_interest) > 0: 276 | memory_indices_pred = memory_logits.argmax(dim=1).view(-1) 277 | batch_acc += torch.sum( 278 | memory_indices_pred[region_of_interest] == 279 | self.gt_indices_list[codebook_idx].view( 280 | -1)[region_of_interest]) 281 | memory_indices_pred = memory_indices_pred 282 | min_encodings_indices_list[codebook_idx][ 283 | region_of_interest] 
= memory_indices_pred[ 284 | region_of_interest] 285 | min_encodings_indices_return_list = [ 286 | min_encodings_indices.view(self.gt_indices_list[0].size()) 287 | for min_encodings_indices in min_encodings_indices_list 288 | ] 289 | batch_acc = batch_acc / self.gt_indices_list[codebook_idx].numel( 290 | ) * self.image.size(0) 291 | acc += batch_acc 292 | self.get_vis(min_encodings_indices_return_list, 293 | self.gt_indices_list, self.texture_mask, 294 | f'{save_dir}/{img_name[0]}') 295 | 296 | self.guidance_encoder.train() 297 | self.index_decoder.train() 298 | return (acc / num).item() 299 | 300 | def load_network(self): 301 | checkpoint = torch.load(self.opt['pretrained_models']) 302 | self.guidance_encoder.load_state_dict( 303 | checkpoint['guidance_encoder'], strict=True) 304 | self.guidance_encoder.eval() 305 | 306 | self.index_decoder.load_state_dict( 307 | checkpoint['index_decoder'], strict=True) 308 | self.index_decoder.eval() 309 | 310 | def save_network(self, save_path): 311 | """Save networks. 312 | 313 | Args: 314 | net (nn.Module): Network to be saved. 315 | net_label (str): Network label. 316 | current_iter (int): Current iter number. 317 | """ 318 | 319 | save_dict = {} 320 | save_dict['guidance_encoder'] = self.guidance_encoder.state_dict() 321 | save_dict['index_decoder'] = self.index_decoder.state_dict() 322 | 323 | torch.save(save_dict, save_path) 324 | 325 | def update_learning_rate(self, epoch): 326 | """Update learning rate. 327 | 328 | Args: 329 | current_iter (int): Current iteration. 330 | warmup_iter (int): Warmup iter numbers. -1 for no warmup. 331 | Default: -1. 332 | """ 333 | lr = self.optimizer.param_groups[0]['lr'] 334 | 335 | if self.opt['lr_decay'] == 'step': 336 | lr = self.opt['lr'] * ( 337 | self.opt['gamma']**(epoch // self.opt['step'])) 338 | elif self.opt['lr_decay'] == 'cos': 339 | lr = self.opt['lr'] * ( 340 | 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2 341 | elif self.opt['lr_decay'] == 'linear': 342 | lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs']) 343 | elif self.opt['lr_decay'] == 'linear2exp': 344 | if epoch < self.opt['turning_point'] + 1: 345 | # learning rate decay as 95% 346 | # at the turning point (1 / 95% = 1.0526) 347 | lr = self.opt['lr'] * ( 348 | 1 - epoch / int(self.opt['turning_point'] * 1.0526)) 349 | else: 350 | lr *= self.opt['gamma'] 351 | elif self.opt['lr_decay'] == 'schedule': 352 | if epoch in self.opt['schedule']: 353 | lr *= self.opt['gamma'] 354 | else: 355 | raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay'])) 356 | # set learning rate 357 | for param_group in self.optimizer.param_groups: 358 | param_group['lr'] = lr 359 | 360 | return lr 361 | 362 | def get_current_log(self): 363 | return self.log_dict 364 | -------------------------------------------------------------------------------- /models/hierarchy_vqgan_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from collections import OrderedDict 4 | 5 | sys.path.append('..') 6 | import lpips 7 | import torch 8 | import torch.nn.functional as F 9 | from torchvision.utils import save_image 10 | 11 | from models.archs.vqgan_arch import (Decoder, DecoderRes, Discriminator, 12 | Encoder, 13 | VectorQuantizerSpatialTextureAware, 14 | VectorQuantizerTexture) 15 | from models.losses.vqgan_loss import (DiffAugment, adopt_weight, 16 | calculate_adaptive_weight, hinge_d_loss) 17 | 18 | 19 | class HierarchyVQSpatialTextureAwareModel(): 20 | 21 | def 
__init__(self, opt): 22 | self.opt = opt 23 | self.device = torch.device('cuda') 24 | self.top_encoder = Encoder( 25 | ch=opt['top_ch'], 26 | num_res_blocks=opt['top_num_res_blocks'], 27 | attn_resolutions=opt['top_attn_resolutions'], 28 | ch_mult=opt['top_ch_mult'], 29 | in_channels=opt['top_in_channels'], 30 | resolution=opt['top_resolution'], 31 | z_channels=opt['top_z_channels'], 32 | double_z=opt['top_double_z'], 33 | dropout=opt['top_dropout']).to(self.device) 34 | self.decoder = Decoder( 35 | in_channels=opt['top_in_channels'], 36 | resolution=opt['top_resolution'], 37 | z_channels=opt['top_z_channels'], 38 | ch=opt['top_ch'], 39 | out_ch=opt['top_out_ch'], 40 | num_res_blocks=opt['top_num_res_blocks'], 41 | attn_resolutions=opt['top_attn_resolutions'], 42 | ch_mult=opt['top_ch_mult'], 43 | dropout=opt['top_dropout'], 44 | resamp_with_conv=True, 45 | give_pre_end=False).to(self.device) 46 | self.top_quantize = VectorQuantizerTexture( 47 | 1024, opt['embed_dim'], beta=0.25).to(self.device) 48 | self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"], 49 | opt['embed_dim'], 50 | 1).to(self.device) 51 | self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'], 52 | opt["top_z_channels"], 53 | 1).to(self.device) 54 | self.load_top_pretrain_models() 55 | 56 | self.bot_encoder = Encoder( 57 | ch=opt['bot_ch'], 58 | num_res_blocks=opt['bot_num_res_blocks'], 59 | attn_resolutions=opt['bot_attn_resolutions'], 60 | ch_mult=opt['bot_ch_mult'], 61 | in_channels=opt['bot_in_channels'], 62 | resolution=opt['bot_resolution'], 63 | z_channels=opt['bot_z_channels'], 64 | double_z=opt['bot_double_z'], 65 | dropout=opt['bot_dropout']).to(self.device) 66 | self.bot_decoder_res = DecoderRes( 67 | in_channels=opt['bot_in_channels'], 68 | resolution=opt['bot_resolution'], 69 | z_channels=opt['bot_z_channels'], 70 | ch=opt['bot_ch'], 71 | num_res_blocks=opt['bot_num_res_blocks'], 72 | ch_mult=opt['bot_ch_mult'], 73 | dropout=opt['bot_dropout'], 74 | give_pre_end=False).to(self.device) 75 | self.bot_quantize = VectorQuantizerSpatialTextureAware( 76 | opt['bot_n_embed'], 77 | opt['embed_dim'], 78 | beta=0.25, 79 | spatial_size=opt['codebook_spatial_size']).to(self.device) 80 | self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"], 81 | opt['embed_dim'], 82 | 1).to(self.device) 83 | self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'], 84 | opt["bot_z_channels"], 85 | 1).to(self.device) 86 | 87 | self.disc = Discriminator( 88 | opt['n_channels'], opt['ndf'], 89 | n_layers=opt['disc_layers']).to(self.device) 90 | self.perceptual = lpips.LPIPS(net="vgg").to(self.device) 91 | self.perceptual_weight = opt['perceptual_weight'] 92 | self.disc_start_step = opt['disc_start_step'] 93 | self.disc_weight_max = opt['disc_weight_max'] 94 | self.diff_aug = opt['diff_aug'] 95 | self.policy = "color,translation" 96 | 97 | self.load_discriminator_models() 98 | 99 | self.disc.train() 100 | 101 | self.fix_decoder = opt['fix_decoder'] 102 | 103 | self.init_training_settings() 104 | 105 | def load_top_pretrain_models(self): 106 | # load pretrained vqgan for segmentation mask 107 | top_vae_checkpoint = torch.load(self.opt['top_vae_path']) 108 | self.top_encoder.load_state_dict( 109 | top_vae_checkpoint['encoder'], strict=True) 110 | self.decoder.load_state_dict( 111 | top_vae_checkpoint['decoder'], strict=True) 112 | self.top_quantize.load_state_dict( 113 | top_vae_checkpoint['quantize'], strict=True) 114 | self.top_quant_conv.load_state_dict( 115 | top_vae_checkpoint['quant_conv'], strict=True) 116 | 
self.top_post_quant_conv.load_state_dict( 117 | top_vae_checkpoint['post_quant_conv'], strict=True) 118 | self.top_encoder.eval() 119 | self.top_quantize.eval() 120 | self.top_quant_conv.eval() 121 | self.top_post_quant_conv.eval() 122 | 123 | def init_training_settings(self): 124 | self.log_dict = OrderedDict() 125 | self.configure_optimizers() 126 | 127 | def configure_optimizers(self): 128 | optim_params = [] 129 | for v in self.bot_encoder.parameters(): 130 | if v.requires_grad: 131 | optim_params.append(v) 132 | for v in self.bot_decoder_res.parameters(): 133 | if v.requires_grad: 134 | optim_params.append(v) 135 | for v in self.bot_quantize.parameters(): 136 | if v.requires_grad: 137 | optim_params.append(v) 138 | for v in self.bot_quant_conv.parameters(): 139 | if v.requires_grad: 140 | optim_params.append(v) 141 | for v in self.bot_post_quant_conv.parameters(): 142 | if v.requires_grad: 143 | optim_params.append(v) 144 | if not self.fix_decoder: 145 | for name, v in self.decoder.named_parameters(): 146 | if v.requires_grad: 147 | if 'up.0' in name: 148 | optim_params.append(v) 149 | if 'up.1' in name: 150 | optim_params.append(v) 151 | if 'up.2' in name: 152 | optim_params.append(v) 153 | if 'up.3' in name: 154 | optim_params.append(v) 155 | 156 | self.optimizer = torch.optim.Adam(optim_params, lr=self.opt['lr']) 157 | 158 | self.disc_optimizer = torch.optim.Adam( 159 | self.disc.parameters(), lr=self.opt['lr']) 160 | 161 | def load_discriminator_models(self): 162 | # load pretrained vqgan for segmentation mask 163 | top_vae_checkpoint = torch.load(self.opt['top_vae_path']) 164 | self.disc.load_state_dict( 165 | top_vae_checkpoint['discriminator'], strict=True) 166 | 167 | def save_network(self, save_path): 168 | """Save networks. 169 | """ 170 | 171 | save_dict = {} 172 | save_dict['bot_encoder'] = self.bot_encoder.state_dict() 173 | save_dict['bot_decoder_res'] = self.bot_decoder_res.state_dict() 174 | save_dict['decoder'] = self.decoder.state_dict() 175 | save_dict['bot_quantize'] = self.bot_quantize.state_dict() 176 | save_dict['bot_quant_conv'] = self.bot_quant_conv.state_dict() 177 | save_dict['bot_post_quant_conv'] = self.bot_post_quant_conv.state_dict( 178 | ) 179 | save_dict['discriminator'] = self.disc.state_dict() 180 | torch.save(save_dict, save_path) 181 | 182 | def load_network(self): 183 | checkpoint = torch.load(self.opt['pretrained_models']) 184 | self.bot_encoder.load_state_dict( 185 | checkpoint['bot_encoder'], strict=True) 186 | self.bot_decoder_res.load_state_dict( 187 | checkpoint['bot_decoder_res'], strict=True) 188 | self.decoder.load_state_dict(checkpoint['decoder'], strict=True) 189 | self.bot_quantize.load_state_dict( 190 | checkpoint['bot_quantize'], strict=True) 191 | self.bot_quant_conv.load_state_dict( 192 | checkpoint['bot_quant_conv'], strict=True) 193 | self.bot_post_quant_conv.load_state_dict( 194 | checkpoint['bot_post_quant_conv'], strict=True) 195 | 196 | def optimize_parameters(self, data, step): 197 | self.bot_encoder.train() 198 | self.bot_decoder_res.train() 199 | if not self.fix_decoder: 200 | self.decoder.train() 201 | self.bot_quantize.train() 202 | self.bot_quant_conv.train() 203 | self.bot_post_quant_conv.train() 204 | 205 | loss, d_loss = self.training_step(data, step) 206 | self.optimizer.zero_grad() 207 | loss.backward() 208 | self.optimizer.step() 209 | 210 | if step > self.disc_start_step: 211 | self.disc_optimizer.zero_grad() 212 | d_loss.backward() 213 | self.disc_optimizer.step() 214 | 215 | def top_encode(self, x, mask): 216 | 
h = self.top_encoder(x) 217 | h = self.top_quant_conv(h) 218 | quant, _, _ = self.top_quantize(h, mask) 219 | quant = self.top_post_quant_conv(quant) 220 | return quant 221 | 222 | def bot_encode(self, x, mask): 223 | h = self.bot_encoder(x) 224 | h = self.bot_quant_conv(h) 225 | quant, emb_loss, info = self.bot_quantize(h, mask) 226 | quant = self.bot_post_quant_conv(quant) 227 | bot_dec_res = self.bot_decoder_res(quant) 228 | return bot_dec_res, emb_loss, info 229 | 230 | def decode(self, quant_top, bot_dec_res): 231 | dec = self.decoder(quant_top, bot_h=bot_dec_res) 232 | return dec 233 | 234 | def forward_step(self, input, mask): 235 | with torch.no_grad(): 236 | quant_top = self.top_encode(input, mask) 237 | bot_dec_res, diff, _ = self.bot_encode(input, mask) 238 | dec = self.decode(quant_top, bot_dec_res) 239 | return dec, diff 240 | 241 | def feed_data(self, data): 242 | x = data['image'].float().to(self.device) 243 | mask = data['texture_mask'].float().to(self.device) 244 | 245 | return x, mask 246 | 247 | def training_step(self, data, step): 248 | x, mask = self.feed_data(data) 249 | xrec, codebook_loss = self.forward_step(x, mask) 250 | 251 | # get recon/perceptual loss 252 | recon_loss = torch.abs(x.contiguous() - xrec.contiguous()) 253 | p_loss = self.perceptual(x.contiguous(), xrec.contiguous()) 254 | nll_loss = recon_loss + self.perceptual_weight * p_loss 255 | nll_loss = torch.mean(nll_loss) 256 | 257 | # augment for input to discriminator 258 | if self.diff_aug: 259 | xrec = DiffAugment(xrec, policy=self.policy) 260 | 261 | # update generator 262 | logits_fake = self.disc(xrec) 263 | g_loss = -torch.mean(logits_fake) 264 | last_layer = self.decoder.conv_out.weight 265 | d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer, 266 | self.disc_weight_max) 267 | d_weight *= adopt_weight(1, step, self.disc_start_step) 268 | loss = nll_loss + d_weight * g_loss + codebook_loss 269 | 270 | self.log_dict["loss"] = loss 271 | self.log_dict["l1"] = recon_loss.mean().item() 272 | self.log_dict["perceptual"] = p_loss.mean().item() 273 | self.log_dict["nll_loss"] = nll_loss.item() 274 | self.log_dict["g_loss"] = g_loss.item() 275 | self.log_dict["d_weight"] = d_weight 276 | self.log_dict["codebook_loss"] = codebook_loss.item() 277 | 278 | if step > self.disc_start_step: 279 | if self.diff_aug: 280 | logits_real = self.disc( 281 | DiffAugment(x.contiguous().detach(), policy=self.policy)) 282 | else: 283 | logits_real = self.disc(x.contiguous().detach()) 284 | logits_fake = self.disc(xrec.contiguous().detach( 285 | )) # detach so that generator isn"t also updated 286 | d_loss = hinge_d_loss(logits_real, logits_fake) 287 | self.log_dict["d_loss"] = d_loss 288 | else: 289 | d_loss = None 290 | 291 | return loss, d_loss 292 | 293 | @torch.no_grad() 294 | def inference(self, data_loader, save_dir): 295 | self.bot_encoder.eval() 296 | self.bot_decoder_res.eval() 297 | self.decoder.eval() 298 | self.bot_quantize.eval() 299 | self.bot_quant_conv.eval() 300 | self.bot_post_quant_conv.eval() 301 | 302 | loss_total = 0 303 | num = 0 304 | 305 | for _, data in enumerate(data_loader): 306 | img_name = data['img_name'][0] 307 | x, mask = self.feed_data(data) 308 | xrec, _ = self.forward_step(x, mask) 309 | 310 | recon_loss = torch.abs(x.contiguous() - xrec.contiguous()) 311 | p_loss = self.perceptual(x.contiguous(), xrec.contiguous()) 312 | nll_loss = recon_loss + self.perceptual_weight * p_loss 313 | nll_loss = torch.mean(nll_loss) 314 | loss_total += nll_loss 315 | 316 | num += x.size(0) 
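            # Editor's note (added comment): the colorization branch below only fires for
            # inputs with more than 3 channels (e.g. one-hot segmentation logits); for the
            # 3-channel RGB images reconstructed by this model it is skipped. Note that
            # `self.to_rgb` is not defined in this class as listed, so the branch assumes
            # such a helper exists if it is ever reached.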
317 | 318 | if x.shape[1] > 3: 319 | # colorize with random projection 320 | assert xrec.shape[1] > 3 321 | # convert logits to indices 322 | xrec = torch.argmax(xrec, dim=1, keepdim=True) 323 | xrec = F.one_hot(xrec, num_classes=x.shape[1]) 324 | xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float() 325 | x = self.to_rgb(x) 326 | xrec = self.to_rgb(xrec) 327 | 328 | img_cat = torch.cat([x, xrec], dim=3).detach() 329 | img_cat = ((img_cat + 1) / 2) 330 | img_cat = img_cat.clamp_(0, 1) 331 | save_image( 332 | img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4) 333 | 334 | return (loss_total / num).item() 335 | 336 | def get_current_log(self): 337 | return self.log_dict 338 | 339 | def update_learning_rate(self, epoch): 340 | """Update learning rate. 341 | 342 | Args: 343 | current_iter (int): Current iteration. 344 | warmup_iter (int): Warmup iter numbers. -1 for no warmup. 345 | Default: -1. 346 | """ 347 | lr = self.optimizer.param_groups[0]['lr'] 348 | 349 | if self.opt['lr_decay'] == 'step': 350 | lr = self.opt['lr'] * ( 351 | self.opt['gamma']**(epoch // self.opt['step'])) 352 | elif self.opt['lr_decay'] == 'cos': 353 | lr = self.opt['lr'] * ( 354 | 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2 355 | elif self.opt['lr_decay'] == 'linear': 356 | lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs']) 357 | elif self.opt['lr_decay'] == 'linear2exp': 358 | if epoch < self.opt['turning_point'] + 1: 359 | # learning rate decay as 95% 360 | # at the turning point (1 / 95% = 1.0526) 361 | lr = self.opt['lr'] * ( 362 | 1 - epoch / int(self.opt['turning_point'] * 1.0526)) 363 | else: 364 | lr *= self.opt['gamma'] 365 | elif self.opt['lr_decay'] == 'schedule': 366 | if epoch in self.opt['schedule']: 367 | lr *= self.opt['gamma'] 368 | else: 369 | raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay'])) 370 | # set learning rate 371 | for param_group in self.optimizer.param_groups: 372 | param_group['lr'] = lr 373 | 374 | return lr 375 | -------------------------------------------------------------------------------- /models/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/models/losses/__init__.py -------------------------------------------------------------------------------- /models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | def accuracy(pred, target, topk=1, thresh=None): 2 | """Calculate accuracy according to the prediction and target. 3 | 4 | Args: 5 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 6 | target (torch.Tensor): The target of each prediction, shape (N, , ...) 7 | topk (int | tuple[int], optional): If the predictions in ``topk`` 8 | matches the target, the predictions will be regarded as 9 | correct ones. Defaults to 1. 10 | thresh (float, optional): If not None, predictions with scores under 11 | this threshold are considered incorrect. Default to None. 12 | 13 | Returns: 14 | float | tuple[float]: If the input ``topk`` is a single integer, 15 | the function will return a single float as accuracy. If 16 | ``topk`` is a tuple containing multiple integers, the 17 | function will return a tuple containing accuracies of 18 | each ``topk`` number. 
19 | """ 20 | assert isinstance(topk, (int, tuple)) 21 | if isinstance(topk, int): 22 | topk = (topk, ) 23 | return_single = True 24 | else: 25 | return_single = False 26 | 27 | maxk = max(topk) 28 | if pred.size(0) == 0: 29 | accu = [pred.new_tensor(0.) for i in range(len(topk))] 30 | return accu[0] if return_single else accu 31 | assert pred.ndim == target.ndim + 1 32 | assert pred.size(0) == target.size(0) 33 | assert maxk <= pred.size(1), \ 34 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 35 | pred_value, pred_label = pred.topk(maxk, dim=1) 36 | # transpose to shape (maxk, N, ...) 37 | pred_label = pred_label.transpose(0, 1) 38 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) 39 | if thresh is not None: 40 | # Only prediction values larger than thresh are counted as correct 41 | correct = correct & (pred_value > thresh).t() 42 | res = [] 43 | for k in topk: 44 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 45 | res.append(correct_k.mul_(100.0 / target.numel())) 46 | return res[0] if return_single else res 47 | -------------------------------------------------------------------------------- /models/losses/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def reduce_loss(loss, reduction): 7 | """Reduce loss as specified. 8 | 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 11 | reduction (str): Options are "none", "mean" and "sum". 12 | 13 | Return: 14 | Tensor: Reduced loss tensor. 15 | """ 16 | reduction_enum = F._Reduction.get_enum(reduction) 17 | # none: 0, elementwise_mean:1, sum: 2 18 | if reduction_enum == 0: 19 | return loss 20 | elif reduction_enum == 1: 21 | return loss.mean() 22 | elif reduction_enum == 2: 23 | return loss.sum() 24 | 25 | 26 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 27 | """Apply element-wise weight and reduce loss. 28 | 29 | Args: 30 | loss (Tensor): Element-wise loss. 31 | weight (Tensor): Element-wise weights. 32 | reduction (str): Same as built-in losses of PyTorch. 33 | avg_factor (float): Avarage factor when computing the mean of losses. 34 | 35 | Returns: 36 | Tensor: Processed loss values. 37 | """ 38 | # if weight is specified, apply element-wise weight 39 | if weight is not None: 40 | assert weight.dim() == loss.dim() 41 | if weight.dim() > 1: 42 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 43 | loss = loss * weight 44 | 45 | # if avg_factor is not specified, just reduce the loss 46 | if avg_factor is None: 47 | loss = reduce_loss(loss, reduction) 48 | else: 49 | # if reduction is mean, then average the loss by avg_factor 50 | if reduction == 'mean': 51 | loss = loss.sum() / avg_factor 52 | # if reduction is 'none', then do nothing, otherwise raise an error 53 | elif reduction != 'none': 54 | raise ValueError('avg_factor can not be used with reduction="sum"') 55 | return loss 56 | 57 | 58 | def cross_entropy(pred, 59 | label, 60 | weight=None, 61 | class_weight=None, 62 | reduction='mean', 63 | avg_factor=None, 64 | ignore_index=-100): 65 | """The wrapper function for :func:`F.cross_entropy`""" 66 | # class_weight is a manual rescaling weight given to each class. 
67 | # If given, has to be a Tensor of size C element-wise losses 68 | loss = F.cross_entropy( 69 | pred, 70 | label, 71 | weight=class_weight, 72 | reduction='none', 73 | ignore_index=ignore_index) 74 | 75 | # apply weights and do the reduction 76 | if weight is not None: 77 | weight = weight.float() 78 | loss = weight_reduce_loss( 79 | loss, weight=weight, reduction=reduction, avg_factor=avg_factor) 80 | 81 | return loss 82 | 83 | 84 | def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): 85 | """Expand onehot labels to match the size of prediction.""" 86 | bin_labels = labels.new_zeros(target_shape) 87 | valid_mask = (labels >= 0) & (labels != ignore_index) 88 | inds = torch.nonzero(valid_mask, as_tuple=True) 89 | 90 | if inds[0].numel() > 0: 91 | if labels.dim() == 3: 92 | bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 93 | else: 94 | bin_labels[inds[0], labels[valid_mask]] = 1 95 | 96 | valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() 97 | if label_weights is None: 98 | bin_label_weights = valid_mask 99 | else: 100 | bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) 101 | bin_label_weights *= valid_mask 102 | 103 | return bin_labels, bin_label_weights 104 | 105 | 106 | def binary_cross_entropy(pred, 107 | label, 108 | weight=None, 109 | reduction='mean', 110 | avg_factor=None, 111 | class_weight=None, 112 | ignore_index=255): 113 | """Calculate the binary CrossEntropy loss. 114 | 115 | Args: 116 | pred (torch.Tensor): The prediction with shape (N, 1). 117 | label (torch.Tensor): The learning label of the prediction. 118 | weight (torch.Tensor, optional): Sample-wise loss weight. 119 | reduction (str, optional): The method used to reduce the loss. 120 | Options are "none", "mean" and "sum". 121 | avg_factor (int, optional): Average factor that is used to average 122 | the loss. Defaults to None. 123 | class_weight (list[float], optional): The weight for each class. 124 | ignore_index (int | None): The label index to be ignored. Default: 255 125 | 126 | Returns: 127 | torch.Tensor: The calculated loss 128 | """ 129 | if pred.dim() != label.dim(): 130 | assert (pred.dim() == 2 and label.dim() == 1) or ( 131 | pred.dim() == 4 and label.dim() == 3), \ 132 | 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \ 133 | 'H, W], label shape [N, H, W] are supported' 134 | label, weight = _expand_onehot_labels(label, weight, pred.shape, 135 | ignore_index) 136 | 137 | # weighted element-wise losses 138 | if weight is not None: 139 | weight = weight.float() 140 | loss = F.binary_cross_entropy_with_logits( 141 | pred, label.float(), pos_weight=class_weight, reduction='none') 142 | # do the reduction for the weighted loss 143 | loss = weight_reduce_loss( 144 | loss, weight, reduction=reduction, avg_factor=avg_factor) 145 | 146 | return loss 147 | 148 | 149 | def mask_cross_entropy(pred, 150 | target, 151 | label, 152 | reduction='mean', 153 | avg_factor=None, 154 | class_weight=None, 155 | ignore_index=None): 156 | """Calculate the CrossEntropy loss for masks. 157 | 158 | Args: 159 | pred (torch.Tensor): The prediction with shape (N, C), C is the number 160 | of classes. 161 | target (torch.Tensor): The learning label of the prediction. 162 | label (torch.Tensor): ``label`` indicates the class label of the mask' 163 | corresponding object. This will be used to select the mask in the 164 | of the class which the object belongs to when the mask prediction 165 | if not class-agnostic. 
166 | reduction (str, optional): The method used to reduce the loss. 167 | Options are "none", "mean" and "sum". 168 | avg_factor (int, optional): Average factor that is used to average 169 | the loss. Defaults to None. 170 | class_weight (list[float], optional): The weight for each class. 171 | ignore_index (None): Placeholder, to be consistent with other loss. 172 | Default: None. 173 | 174 | Returns: 175 | torch.Tensor: The calculated loss 176 | """ 177 | assert ignore_index is None, 'BCE loss does not support ignore_index' 178 | # TODO: handle these two reserved arguments 179 | assert reduction == 'mean' and avg_factor is None 180 | num_rois = pred.size()[0] 181 | inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) 182 | pred_slice = pred[inds, label].squeeze(1) 183 | return F.binary_cross_entropy_with_logits( 184 | pred_slice, target, weight=class_weight, reduction='mean')[None] 185 | 186 | 187 | class CrossEntropyLoss(nn.Module): 188 | """CrossEntropyLoss. 189 | 190 | Args: 191 | use_sigmoid (bool, optional): Whether the prediction uses sigmoid 192 | of softmax. Defaults to False. 193 | use_mask (bool, optional): Whether to use mask cross entropy loss. 194 | Defaults to False. 195 | reduction (str, optional): . Defaults to 'mean'. 196 | Options are "none", "mean" and "sum". 197 | class_weight (list[float], optional): Weight of each class. 198 | Defaults to None. 199 | loss_weight (float, optional): Weight of the loss. Defaults to 1.0. 200 | """ 201 | 202 | def __init__(self, 203 | use_sigmoid=False, 204 | use_mask=False, 205 | reduction='mean', 206 | class_weight=None, 207 | loss_weight=1.0): 208 | super(CrossEntropyLoss, self).__init__() 209 | assert (use_sigmoid is False) or (use_mask is False) 210 | self.use_sigmoid = use_sigmoid 211 | self.use_mask = use_mask 212 | self.reduction = reduction 213 | self.loss_weight = loss_weight 214 | self.class_weight = class_weight 215 | 216 | if self.use_sigmoid: 217 | self.cls_criterion = binary_cross_entropy 218 | elif self.use_mask: 219 | self.cls_criterion = mask_cross_entropy 220 | else: 221 | self.cls_criterion = cross_entropy 222 | 223 | def forward(self, 224 | cls_score, 225 | label, 226 | weight=None, 227 | avg_factor=None, 228 | reduction_override=None, 229 | **kwargs): 230 | """Forward function.""" 231 | assert reduction_override in (None, 'none', 'mean', 'sum') 232 | reduction = ( 233 | reduction_override if reduction_override else self.reduction) 234 | if self.class_weight is not None: 235 | class_weight = cls_score.new_tensor(self.class_weight) 236 | else: 237 | class_weight = None 238 | loss_cls = self.loss_weight * self.cls_criterion( 239 | cls_score, 240 | label, 241 | weight, 242 | class_weight=class_weight, 243 | reduction=reduction, 244 | avg_factor=avg_factor, 245 | **kwargs) 246 | return loss_cls 247 | -------------------------------------------------------------------------------- /models/losses/segmentation_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | 7 | def forward(self, prediction, target): 8 | loss = F.binary_cross_entropy_with_logits(prediction, target) 9 | return loss, {} 10 | 11 | 12 | class BCELossWithQuant(nn.Module): 13 | 14 | def __init__(self, codebook_weight=1.): 15 | super().__init__() 16 | self.codebook_weight = codebook_weight 17 | 18 | def forward(self, qloss, target, prediction, split): 19 | bce_loss = 
F.binary_cross_entropy_with_logits(prediction, target) 20 | loss = bce_loss + self.codebook_weight * qloss 21 | return loss, { 22 | "{}/total_loss".format(split): loss.clone().detach().mean(), 23 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 24 | "{}/quant_loss".format(split): qloss.detach().mean() 25 | } 26 | -------------------------------------------------------------------------------- /models/losses/vqgan_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max): 6 | recon_grads = torch.autograd.grad( 7 | recon_loss, last_layer, retain_graph=True)[0] 8 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 9 | 10 | d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4) 11 | d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach() 12 | return d_weight 13 | 14 | 15 | def adopt_weight(weight, global_step, threshold=0, value=0.): 16 | if global_step < threshold: 17 | weight = value 18 | return weight 19 | 20 | 21 | @torch.jit.script 22 | def hinge_d_loss(logits_real, logits_fake): 23 | loss_real = torch.mean(F.relu(1. - logits_real)) 24 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 25 | d_loss = 0.5 * (loss_real + loss_fake) 26 | return d_loss 27 | 28 | 29 | def DiffAugment(x, policy='', channels_first=True): 30 | if policy: 31 | if not channels_first: 32 | x = x.permute(0, 3, 1, 2) 33 | for p in policy.split(','): 34 | for f in AUGMENT_FNS[p]: 35 | x = f(x) 36 | if not channels_first: 37 | x = x.permute(0, 2, 3, 1) 38 | x = x.contiguous() 39 | return x 40 | 41 | 42 | def rand_brightness(x): 43 | x = x + ( 44 | torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 45 | return x 46 | 47 | 48 | def rand_saturation(x): 49 | x_mean = x.mean(dim=1, keepdim=True) 50 | x = (x - x_mean) * (torch.rand( 51 | x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 52 | return x 53 | 54 | 55 | def rand_contrast(x): 56 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 57 | x = (x - x_mean) * (torch.rand( 58 | x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 59 | return x 60 | 61 | 62 | def rand_translation(x, ratio=0.125): 63 | shift_x, shift_y = int(x.size(2) * ratio + 64 | 0.5), int(x.size(3) * ratio + 0.5) 65 | translation_x = torch.randint( 66 | -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 67 | translation_y = torch.randint( 68 | -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 69 | grid_batch, grid_x, grid_y = torch.meshgrid( 70 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 71 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 72 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 73 | ) 74 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 75 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 76 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 77 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, 78 | grid_y].permute(0, 3, 1, 2) 79 | return x 80 | 81 | 82 | def rand_cutout(x, ratio=0.5): 83 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 84 | offset_x = torch.randint( 85 | 0, 86 | x.size(2) + (1 - cutout_size[0] % 2), 87 | size=[x.size(0), 1, 1], 88 | device=x.device) 89 | offset_y = torch.randint( 90 | 0, 91 | x.size(3) + (1 - cutout_size[1] % 2), 92 | size=[x.size(0), 1, 1], 93 | 
device=x.device) 94 | grid_batch, grid_x, grid_y = torch.meshgrid( 95 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 96 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 97 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 98 | ) 99 | grid_x = torch.clamp( 100 | grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 101 | grid_y = torch.clamp( 102 | grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 103 | mask = torch.ones( 104 | x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 105 | mask[grid_batch, grid_x, grid_y] = 0 106 | x = x * mask.unsqueeze(1) 107 | return x 108 | 109 | 110 | AUGMENT_FNS = { 111 | 'color': [rand_brightness, rand_saturation, rand_contrast], 112 | 'translation': [rand_translation], 113 | 'cutout': [rand_cutout], 114 | } 115 | -------------------------------------------------------------------------------- /models/parsing_gen_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from collections import OrderedDict 4 | 5 | import mmcv 6 | import numpy as np 7 | import torch 8 | from torchvision.utils import save_image 9 | 10 | from models.archs.fcn_arch import FCNHead 11 | from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding 12 | from models.archs.unet_arch import ShapeUNet 13 | from models.losses.accuracy import accuracy 14 | from models.losses.cross_entropy_loss import CrossEntropyLoss 15 | 16 | logger = logging.getLogger('base') 17 | 18 | 19 | class ParsingGenModel(): 20 | """Paring Generation model. 21 | """ 22 | 23 | def __init__(self, opt): 24 | self.opt = opt 25 | self.device = torch.device('cuda') 26 | self.is_train = opt['is_train'] 27 | 28 | self.attr_embedder = ShapeAttrEmbedding( 29 | dim=opt['embedder_dim'], 30 | out_dim=opt['embedder_out_dim'], 31 | cls_num_list=opt['attr_class_num']).to(self.device) 32 | self.parsing_encoder = ShapeUNet( 33 | in_channels=opt['encoder_in_channels']).to(self.device) 34 | self.parsing_decoder = FCNHead( 35 | in_channels=opt['fc_in_channels'], 36 | in_index=opt['fc_in_index'], 37 | channels=opt['fc_channels'], 38 | num_convs=opt['fc_num_convs'], 39 | concat_input=opt['fc_concat_input'], 40 | dropout_ratio=opt['fc_dropout_ratio'], 41 | num_classes=opt['fc_num_classes'], 42 | align_corners=opt['fc_align_corners'], 43 | ).to(self.device) 44 | 45 | self.init_training_settings() 46 | 47 | self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220], 48 | [250, 235, 215], [255, 250, 205], [211, 211, 211], 49 | [70, 130, 180], [127, 255, 212], [0, 100, 0], 50 | [50, 205, 50], [255, 255, 0], [245, 222, 179], 51 | [255, 140, 0], [255, 0, 0], [16, 78, 139], 52 | [144, 238, 144], [50, 205, 174], [50, 155, 250], 53 | [160, 140, 88], [213, 140, 88], [90, 140, 90], 54 | [185, 210, 205], [130, 165, 180], [225, 141, 151]] 55 | 56 | def init_training_settings(self): 57 | optim_params = [] 58 | for v in self.attr_embedder.parameters(): 59 | if v.requires_grad: 60 | optim_params.append(v) 61 | for v in self.parsing_encoder.parameters(): 62 | if v.requires_grad: 63 | optim_params.append(v) 64 | for v in self.parsing_decoder.parameters(): 65 | if v.requires_grad: 66 | optim_params.append(v) 67 | # set up optimizers 68 | self.optimizer = torch.optim.Adam( 69 | optim_params, 70 | self.opt['lr'], 71 | weight_decay=self.opt['weight_decay']) 72 | self.log_dict = OrderedDict() 73 | self.entropy_loss = CrossEntropyLoss().to(self.device) 74 | 75 | def feed_data(self, 
data): 76 | self.pose = data['densepose'].to(self.device) 77 | self.attr = data['attr'].to(self.device) 78 | self.segm = data['segm'].to(self.device) 79 | 80 | def optimize_parameters(self): 81 | self.attr_embedder.train() 82 | self.parsing_encoder.train() 83 | self.parsing_decoder.train() 84 | 85 | self.attr_embedding = self.attr_embedder(self.attr) 86 | self.pose_enc = self.parsing_encoder(self.pose, self.attr_embedding) 87 | self.seg_logits = self.parsing_decoder(self.pose_enc) 88 | 89 | loss = self.entropy_loss(self.seg_logits, self.segm) 90 | 91 | self.optimizer.zero_grad() 92 | loss.backward() 93 | self.optimizer.step() 94 | 95 | self.log_dict['loss_total'] = loss 96 | 97 | def get_vis(self, save_path): 98 | img_cat = torch.cat([ 99 | self.pose, 100 | self.segm, 101 | ], dim=3).detach() 102 | img_cat = ((img_cat + 1) / 2) 103 | 104 | img_cat = img_cat.clamp_(0, 1) 105 | 106 | save_image(img_cat, save_path, nrow=1, padding=4) 107 | 108 | def inference(self, data_loader, save_dir): 109 | self.attr_embedder.eval() 110 | self.parsing_encoder.eval() 111 | self.parsing_decoder.eval() 112 | 113 | acc = 0 114 | num = 0 115 | 116 | for _, data in enumerate(data_loader): 117 | pose = data['densepose'].to(self.device) 118 | attr = data['attr'].to(self.device) 119 | segm = data['segm'].to(self.device) 120 | img_name = data['img_name'] 121 | 122 | num += pose.size(0) 123 | with torch.no_grad(): 124 | attr_embedding = self.attr_embedder(attr) 125 | pose_enc = self.parsing_encoder(pose, attr_embedding) 126 | seg_logits = self.parsing_decoder(pose_enc) 127 | seg_pred = seg_logits.argmax(dim=1) 128 | acc += accuracy(seg_logits, segm) 129 | palette_label = self.palette_result(segm.cpu().numpy()) 130 | palette_pred = self.palette_result(seg_pred.cpu().numpy()) 131 | pose_numpy = ((pose[0] + 1) / 2. * 255.).expand( 132 | 3, 133 | pose[0].size(1), 134 | pose[0].size(2), 135 | ).cpu().numpy().clip(0, 255).astype(np.uint8).transpose(1, 2, 0) 136 | concat_result = np.concatenate( 137 | (pose_numpy, palette_pred, palette_label), axis=1) 138 | mmcv.imwrite(concat_result, f'{save_dir}/{img_name[0]}') 139 | 140 | self.attr_embedder.train() 141 | self.parsing_encoder.train() 142 | self.parsing_decoder.train() 143 | return (acc / num).item() 144 | 145 | def get_current_log(self): 146 | return self.log_dict 147 | 148 | def update_learning_rate(self, epoch): 149 | """Update learning rate. 150 | 151 | Args: 152 | current_iter (int): Current iteration. 153 | warmup_iter (int): Warmup iter numbers. -1 for no warmup. 154 | Default: -1. 
155 | """ 156 | lr = self.optimizer.param_groups[0]['lr'] 157 | 158 | if self.opt['lr_decay'] == 'step': 159 | lr = self.opt['lr'] * ( 160 | self.opt['gamma']**(epoch // self.opt['step'])) 161 | elif self.opt['lr_decay'] == 'cos': 162 | lr = self.opt['lr'] * ( 163 | 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2 164 | elif self.opt['lr_decay'] == 'linear': 165 | lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs']) 166 | elif self.opt['lr_decay'] == 'linear2exp': 167 | if epoch < self.opt['turning_point'] + 1: 168 | # learning rate decay as 95% 169 | # at the turning point (1 / 95% = 1.0526) 170 | lr = self.opt['lr'] * ( 171 | 1 - epoch / int(self.opt['turning_point'] * 1.0526)) 172 | else: 173 | lr *= self.opt['gamma'] 174 | elif self.opt['lr_decay'] == 'schedule': 175 | if epoch in self.opt['schedule']: 176 | lr *= self.opt['gamma'] 177 | else: 178 | raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay'])) 179 | # set learning rate 180 | for param_group in self.optimizer.param_groups: 181 | param_group['lr'] = lr 182 | 183 | return lr 184 | 185 | def save_network(self, save_path): 186 | """Save networks. 187 | """ 188 | 189 | save_dict = {} 190 | save_dict['embedder'] = self.attr_embedder.state_dict() 191 | save_dict['encoder'] = self.parsing_encoder.state_dict() 192 | save_dict['decoder'] = self.parsing_decoder.state_dict() 193 | 194 | torch.save(save_dict, save_path) 195 | 196 | def load_network(self): 197 | checkpoint = torch.load(self.opt['pretrained_parsing_gen']) 198 | 199 | self.attr_embedder.load_state_dict(checkpoint['embedder'], strict=True) 200 | self.attr_embedder.eval() 201 | 202 | self.parsing_encoder.load_state_dict( 203 | checkpoint['encoder'], strict=True) 204 | self.parsing_encoder.eval() 205 | 206 | self.parsing_decoder.load_state_dict( 207 | checkpoint['decoder'], strict=True) 208 | self.parsing_decoder.eval() 209 | 210 | def palette_result(self, result): 211 | seg = result[0] 212 | palette = np.array(self.palette) 213 | assert palette.shape[1] == 3 214 | assert len(palette.shape) == 2 215 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) 216 | for label, color in enumerate(palette): 217 | color_seg[seg == label, :] = color 218 | # convert to BGR 219 | color_seg = color_seg[..., ::-1] 220 | return color_seg 221 | -------------------------------------------------------------------------------- /sample_from_parsing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os.path as osp 4 | import random 5 | 6 | import torch 7 | 8 | from data.segm_attr_dataset import DeepFashionAttrSegmDataset 9 | from models import create_model 10 | from utils.logger import get_root_logger 11 | from utils.options import dict2str, dict_to_nonedict, parse 12 | from utils.util import make_exp_dirs, set_random_seed 13 | 14 | 15 | def main(): 16 | # options 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 19 | args = parser.parse_args() 20 | opt = parse(args.opt, is_train=False) 21 | 22 | # mkdir and loggers 23 | make_exp_dirs(opt) 24 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}.log") 25 | logger = get_root_logger( 26 | logger_name='base', log_level=logging.INFO, log_file=log_file) 27 | logger.info(dict2str(opt)) 28 | 29 | # convert to NoneDict, which returns None for missing keys 30 | opt = dict_to_nonedict(opt) 31 | 32 | # random seed 33 | seed = opt['manual_seed'] 34 | if seed 
is None: 35 | seed = random.randint(1, 10000) 36 | logger.info(f'Random seed: {seed}') 37 | set_random_seed(seed) 38 | 39 | test_dataset = DeepFashionAttrSegmDataset( 40 | img_dir=opt['test_img_dir'], 41 | segm_dir=opt['segm_dir'], 42 | pose_dir=opt['pose_dir'], 43 | ann_dir=opt['test_ann_file']) 44 | test_loader = torch.utils.data.DataLoader( 45 | dataset=test_dataset, batch_size=4, shuffle=False) 46 | logger.info(f'Number of test set: {len(test_dataset)}.') 47 | 48 | model = create_model(opt) 49 | _ = model.inference(test_loader, opt['path']['results_root']) 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /sample_from_pose.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os.path as osp 4 | import random 5 | 6 | import torch 7 | 8 | from data.pose_attr_dataset import DeepFashionAttrPoseDataset 9 | from models import create_model 10 | from utils.logger import get_root_logger 11 | from utils.options import dict2str, dict_to_nonedict, parse 12 | from utils.util import make_exp_dirs, set_random_seed 13 | 14 | 15 | def main(): 16 | # options 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 19 | args = parser.parse_args() 20 | opt = parse(args.opt, is_train=False) 21 | 22 | # mkdir and loggers 23 | make_exp_dirs(opt) 24 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}.log") 25 | logger = get_root_logger( 26 | logger_name='base', log_level=logging.INFO, log_file=log_file) 27 | logger.info(dict2str(opt)) 28 | 29 | # convert to NoneDict, which returns None for missing keys 30 | opt = dict_to_nonedict(opt) 31 | 32 | # random seed 33 | seed = opt['manual_seed'] 34 | if seed is None: 35 | seed = random.randint(1, 10000) 36 | logger.info(f'Random seed: {seed}') 37 | set_random_seed(seed) 38 | 39 | test_dataset = DeepFashionAttrPoseDataset( 40 | pose_dir=opt['pose_dir'], 41 | texture_ann_dir=opt['texture_ann_file'], 42 | shape_ann_path=opt['shape_ann_path']) 43 | test_loader = torch.utils.data.DataLoader( 44 | dataset=test_dataset, batch_size=4, shuffle=False) 45 | logger.info(f'Number of test set: {len(test_dataset)}.') 46 | 47 | model = create_model(opt) 48 | _ = model.inference(test_loader, opt['path']['results_root']) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /train_index_prediction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import os.path as osp 5 | import random 6 | import time 7 | 8 | import torch 9 | 10 | from data.segm_attr_dataset import DeepFashionAttrSegmDataset 11 | from models import create_model 12 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger 13 | from utils.options import dict2str, dict_to_nonedict, parse 14 | from utils.util import make_exp_dirs 15 | 16 | 17 | def main(): 18 | # options 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 21 | args = parser.parse_args() 22 | opt = parse(args.opt, is_train=True) 23 | 24 | # mkdir and loggers 25 | make_exp_dirs(opt) 26 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") 27 | logger = get_root_logger( 28 | logger_name='base', log_level=logging.INFO, log_file=log_file) 29 | 
logger.info(dict2str(opt)) 30 | # initialize tensorboard logger 31 | tb_logger = None 32 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 33 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name']) 34 | 35 | # convert to NoneDict, which returns None for missing keys 36 | opt = dict_to_nonedict(opt) 37 | 38 | # set up data loader 39 | train_dataset = DeepFashionAttrSegmDataset( 40 | img_dir=opt['train_img_dir'], 41 | segm_dir=opt['segm_dir'], 42 | pose_dir=opt['pose_dir'], 43 | ann_dir=opt['train_ann_file'], 44 | xflip=True) 45 | train_loader = torch.utils.data.DataLoader( 46 | dataset=train_dataset, 47 | batch_size=opt['batch_size'], 48 | shuffle=True, 49 | num_workers=opt['num_workers'], 50 | drop_last=True) 51 | logger.info(f'Number of train set: {len(train_dataset)}.') 52 | opt['max_iters'] = opt['num_epochs'] * len( 53 | train_dataset) // opt['batch_size'] 54 | 55 | val_dataset = DeepFashionAttrSegmDataset( 56 | img_dir=opt['train_img_dir'], 57 | segm_dir=opt['segm_dir'], 58 | pose_dir=opt['pose_dir'], 59 | ann_dir=opt['val_ann_file']) 60 | val_loader = torch.utils.data.DataLoader( 61 | dataset=val_dataset, batch_size=1, shuffle=False) 62 | logger.info(f'Number of val set: {len(val_dataset)}.') 63 | 64 | test_dataset = DeepFashionAttrSegmDataset( 65 | img_dir=opt['test_img_dir'], 66 | segm_dir=opt['segm_dir'], 67 | pose_dir=opt['pose_dir'], 68 | ann_dir=opt['test_ann_file']) 69 | test_loader = torch.utils.data.DataLoader( 70 | dataset=test_dataset, batch_size=1, shuffle=False) 71 | logger.info(f'Number of test set: {len(test_dataset)}.') 72 | 73 | current_iter = 0 74 | best_epoch = None 75 | best_acc = 0 76 | 77 | model = create_model(opt) 78 | 79 | data_time, iter_time = 0, 0 80 | current_iter = 0 81 | 82 | # create message logger (formatted outputs) 83 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 84 | 85 | for epoch in range(opt['num_epochs']): 86 | lr = model.update_learning_rate(epoch) 87 | 88 | for _, batch_data in enumerate(train_loader): 89 | data_time = time.time() - data_time 90 | 91 | current_iter += 1 92 | 93 | model.feed_data(batch_data) 94 | model.optimize_parameters() 95 | 96 | iter_time = time.time() - iter_time 97 | if current_iter % opt['print_freq'] == 0: 98 | log_vars = {'epoch': epoch, 'iter': current_iter} 99 | log_vars.update({'lrs': [lr]}) 100 | log_vars.update({'time': iter_time, 'data_time': data_time}) 101 | log_vars.update(model.get_current_log()) 102 | msg_logger(log_vars) 103 | 104 | data_time = time.time() 105 | iter_time = time.time() 106 | 107 | if epoch % opt['val_freq'] == 0: 108 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa 109 | os.makedirs(save_dir, exist_ok=opt['debug']) 110 | val_acc = model.inference(val_loader, save_dir) 111 | 112 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa 113 | os.makedirs(save_dir, exist_ok=opt['debug']) 114 | test_acc = model.inference(test_loader, save_dir) 115 | 116 | logger.info( 117 | f'Epoch: {epoch}, val_acc: {val_acc: .4f}, test_acc: {test_acc: .4f}.' 
118 | ) 119 | 120 | if test_acc > best_acc: 121 | best_epoch = epoch 122 | best_acc = test_acc 123 | 124 | logger.info(f'Best epoch: {best_epoch}, ' 125 | f'Best test acc: {best_acc: .4f}.') 126 | 127 | # save model 128 | model.save_network( 129 | f'{opt["path"]["models"]}/models_epoch{epoch}.pth') 130 | 131 | 132 | if __name__ == '__main__': 133 | main() 134 | -------------------------------------------------------------------------------- /train_parsing_gen.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import os.path as osp 5 | import random 6 | import time 7 | 8 | import torch 9 | 10 | from data.parsing_generation_segm_attr_dataset import \ 11 | ParsingGenerationDeepFashionAttrSegmDataset 12 | from models import create_model 13 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger 14 | from utils.options import dict2str, dict_to_nonedict, parse 15 | from utils.util import make_exp_dirs 16 | 17 | 18 | def main(): 19 | # options 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 22 | args = parser.parse_args() 23 | opt = parse(args.opt, is_train=True) 24 | 25 | # mkdir and loggers 26 | make_exp_dirs(opt) 27 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") 28 | logger = get_root_logger( 29 | logger_name='base', log_level=logging.INFO, log_file=log_file) 30 | logger.info(dict2str(opt)) 31 | # initialize tensorboard logger 32 | tb_logger = None 33 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 34 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name']) 35 | 36 | # convert to NoneDict, which returns None for missing keys 37 | opt = dict_to_nonedict(opt) 38 | 39 | # set up data loader 40 | train_dataset = ParsingGenerationDeepFashionAttrSegmDataset( 41 | segm_dir=opt['segm_dir'], 42 | pose_dir=opt['pose_dir'], 43 | ann_file=opt['train_ann_file']) 44 | train_loader = torch.utils.data.DataLoader( 45 | dataset=train_dataset, 46 | batch_size=opt['batch_size'], 47 | shuffle=True, 48 | num_workers=opt['num_workers'], 49 | drop_last=True) 50 | logger.info(f'Number of train set: {len(train_dataset)}.') 51 | opt['max_iters'] = opt['num_epochs'] * len( 52 | train_dataset) // opt['batch_size'] 53 | 54 | val_dataset = ParsingGenerationDeepFashionAttrSegmDataset( 55 | segm_dir=opt['segm_dir'], 56 | pose_dir=opt['pose_dir'], 57 | ann_file=opt['val_ann_file']) 58 | val_loader = torch.utils.data.DataLoader( 59 | dataset=val_dataset, 60 | batch_size=1, 61 | shuffle=False, 62 | num_workers=opt['num_workers']) 63 | logger.info(f'Number of val set: {len(val_dataset)}.') 64 | 65 | test_dataset = ParsingGenerationDeepFashionAttrSegmDataset( 66 | segm_dir=opt['segm_dir'], 67 | pose_dir=opt['pose_dir'], 68 | ann_file=opt['test_ann_file']) 69 | test_loader = torch.utils.data.DataLoader( 70 | dataset=test_dataset, 71 | batch_size=1, 72 | shuffle=False, 73 | num_workers=opt['num_workers']) 74 | logger.info(f'Number of test set: {len(test_dataset)}.') 75 | 76 | current_iter = 0 77 | best_epoch = None 78 | best_acc = 0 79 | 80 | model = create_model(opt) 81 | 82 | data_time, iter_time = 0, 0 83 | current_iter = 0 84 | 85 | # create message logger (formatted outputs) 86 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 87 | 88 | for epoch in range(opt['num_epochs']): 89 | lr = model.update_learning_rate(epoch) 90 | 91 | for _, batch_data in enumerate(train_loader): 92 | data_time = time.time() 
- data_time 93 | 94 | current_iter += 1 95 | 96 | model.feed_data(batch_data) 97 | model.optimize_parameters() 98 | 99 | iter_time = time.time() - iter_time 100 | if current_iter % opt['print_freq'] == 0: 101 | log_vars = {'epoch': epoch, 'iter': current_iter} 102 | log_vars.update({'lrs': [lr]}) 103 | log_vars.update({'time': iter_time, 'data_time': data_time}) 104 | log_vars.update(model.get_current_log()) 105 | msg_logger(log_vars) 106 | 107 | data_time = time.time() 108 | iter_time = time.time() 109 | 110 | if epoch % opt['val_freq'] == 0: 111 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' 112 | os.makedirs(save_dir, exist_ok=opt['debug']) 113 | val_acc = model.inference(val_loader, save_dir) 114 | 115 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' 116 | os.makedirs(save_dir, exist_ok=opt['debug']) 117 | test_acc = model.inference(test_loader, save_dir) 118 | 119 | logger.info(f'Epoch: {epoch}, ' 120 | f'val_acc: {val_acc: .4f}, ' 121 | f'test_acc: {test_acc: .4f}.') 122 | 123 | if test_acc > best_acc: 124 | best_epoch = epoch 125 | best_acc = test_acc 126 | 127 | logger.info(f'Best epoch: {best_epoch}, ' 128 | f'Best test acc: {best_acc: .4f}.') 129 | 130 | # save model 131 | model.save_network( 132 | f'{opt["path"]["models"]}/parsing_generation_epoch{epoch}.pth') 133 | 134 | 135 | if __name__ == '__main__': 136 | main() 137 | -------------------------------------------------------------------------------- /train_parsing_token.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import os.path as osp 5 | import random 6 | import time 7 | 8 | import torch 9 | 10 | from data.mask_dataset import MaskDataset 11 | from models import create_model 12 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger 13 | from utils.options import dict2str, dict_to_nonedict, parse 14 | from utils.util import make_exp_dirs 15 | 16 | 17 | def main(): 18 | # options 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 21 | args = parser.parse_args() 22 | opt = parse(args.opt, is_train=True) 23 | 24 | # mkdir and loggers 25 | make_exp_dirs(opt) 26 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") 27 | logger = get_root_logger( 28 | logger_name='base', log_level=logging.INFO, log_file=log_file) 29 | logger.info(dict2str(opt)) 30 | # initialize tensorboard logger 31 | tb_logger = None 32 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 33 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name']) 34 | 35 | # convert to NoneDict, which returns None for missing keys 36 | opt = dict_to_nonedict(opt) 37 | 38 | # set up data loader 39 | train_dataset = MaskDataset( 40 | segm_dir=opt['segm_dir'], ann_dir=opt['train_ann_file'], xflip=True) 41 | train_loader = torch.utils.data.DataLoader( 42 | dataset=train_dataset, 43 | batch_size=opt['batch_size'], 44 | shuffle=True, 45 | num_workers=opt['num_workers'], 46 | persistent_workers=True, 47 | drop_last=True) 48 | logger.info(f'Number of train set: {len(train_dataset)}.') 49 | opt['max_iters'] = opt['num_epochs'] * len( 50 | train_dataset) // opt['batch_size'] 51 | 52 | val_dataset = MaskDataset( 53 | segm_dir=opt['segm_dir'], ann_dir=opt['val_ann_file']) 54 | val_loader = torch.utils.data.DataLoader( 55 | dataset=val_dataset, batch_size=1, shuffle=False) 56 | logger.info(f'Number of val set: {len(val_dataset)}.') 57 
| 58 | test_dataset = MaskDataset( 59 | segm_dir=opt['segm_dir'], ann_dir=opt['test_ann_file']) 60 | test_loader = torch.utils.data.DataLoader( 61 | dataset=test_dataset, batch_size=1, shuffle=False) 62 | logger.info(f'Number of test set: {len(test_dataset)}.') 63 | 64 | current_iter = 0 65 | best_epoch = None 66 | best_loss = 100000 67 | 68 | model = create_model(opt) 69 | 70 | data_time, iter_time = 0, 0 71 | current_iter = 0 72 | 73 | # create message logger (formatted outputs) 74 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 75 | 76 | for epoch in range(opt['num_epochs']): 77 | lr = model.update_learning_rate(epoch) 78 | 79 | for _, batch_data in enumerate(train_loader): 80 | data_time = time.time() - data_time 81 | 82 | current_iter += 1 83 | 84 | model.optimize_parameters(batch_data, current_iter) 85 | 86 | iter_time = time.time() - iter_time 87 | if current_iter % opt['print_freq'] == 0: 88 | log_vars = {'epoch': epoch, 'iter': current_iter} 89 | log_vars.update({'lrs': [lr]}) 90 | log_vars.update({'time': iter_time, 'data_time': data_time}) 91 | log_vars.update(model.get_current_log()) 92 | msg_logger(log_vars) 93 | 94 | data_time = time.time() 95 | iter_time = time.time() 96 | 97 | if epoch % opt['val_freq'] == 0: 98 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa 99 | os.makedirs(save_dir, exist_ok=opt['debug']) 100 | val_loss_total, _, _ = model.inference(val_loader, save_dir) 101 | 102 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa 103 | os.makedirs(save_dir, exist_ok=opt['debug']) 104 | test_loss_total, _, _ = model.inference(test_loader, save_dir) 105 | 106 | logger.info(f'Epoch: {epoch}, ' 107 | f'val_loss_total: {val_loss_total}, ' 108 | f'test_loss_total: {test_loss_total}.') 109 | 110 | if test_loss_total < best_loss: 111 | best_epoch = epoch 112 | best_loss = test_loss_total 113 | 114 | logger.info(f'Best epoch: {best_epoch}, ' 115 | f'Best test loss: {best_loss: .4f}.') 116 | 117 | # save model 118 | model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth') 119 | 120 | 121 | if __name__ == '__main__': 122 | main() 123 | -------------------------------------------------------------------------------- /train_sampler.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import os.path as osp 5 | import random 6 | import time 7 | 8 | import torch 9 | 10 | from data.segm_attr_dataset import DeepFashionAttrSegmDataset 11 | from models import create_model 12 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger 13 | from utils.options import dict2str, dict_to_nonedict, parse 14 | from utils.util import make_exp_dirs 15 | 16 | 17 | def main(): 18 | # options 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 21 | args = parser.parse_args() 22 | opt = parse(args.opt, is_train=True) 23 | 24 | # mkdir and loggers 25 | make_exp_dirs(opt) 26 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") 27 | logger = get_root_logger( 28 | logger_name='base', log_level=logging.INFO, log_file=log_file) 29 | logger.info(dict2str(opt)) 30 | # initialize tensorboard logger 31 | tb_logger = None 32 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 33 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name']) 34 | 35 | # convert to NoneDict, which returns None for missing keys 36 | opt = dict_to_nonedict(opt) 37 | 
38 | # set up data loader 39 | train_dataset = DeepFashionAttrSegmDataset( 40 | img_dir=opt['train_img_dir'], 41 | segm_dir=opt['segm_dir'], 42 | pose_dir=opt['pose_dir'], 43 | ann_dir=opt['train_ann_file'], 44 | xflip=True) 45 | train_loader = torch.utils.data.DataLoader( 46 | dataset=train_dataset, 47 | batch_size=opt['batch_size'], 48 | shuffle=True, 49 | num_workers=opt['num_workers'], 50 | persistent_workers=True, 51 | drop_last=True) 52 | logger.info(f'Number of train set: {len(train_dataset)}.') 53 | opt['max_iters'] = opt['num_epochs'] * len( 54 | train_dataset) // opt['batch_size'] 55 | 56 | val_dataset = DeepFashionAttrSegmDataset( 57 | img_dir=opt['train_img_dir'], 58 | segm_dir=opt['segm_dir'], 59 | pose_dir=opt['pose_dir'], 60 | ann_dir=opt['val_ann_file']) 61 | val_loader = torch.utils.data.DataLoader( 62 | dataset=val_dataset, batch_size=opt['batch_size'], shuffle=False) 63 | logger.info(f'Number of val set: {len(val_dataset)}.') 64 | 65 | test_dataset = DeepFashionAttrSegmDataset( 66 | img_dir=opt['test_img_dir'], 67 | segm_dir=opt['segm_dir'], 68 | pose_dir=opt['pose_dir'], 69 | ann_dir=opt['test_ann_file']) 70 | test_loader = torch.utils.data.DataLoader( 71 | dataset=test_dataset, batch_size=opt['batch_size'], shuffle=False) 72 | logger.info(f'Number of test set: {len(test_dataset)}.') 73 | 74 | current_iter = 0 75 | 76 | model = create_model(opt) 77 | 78 | data_time, iter_time = 0, 0 79 | current_iter = 0 80 | 81 | # create message logger (formatted outputs) 82 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 83 | 84 | for epoch in range(opt['num_epochs']): 85 | lr = model.update_learning_rate(epoch, current_iter) 86 | 87 | for _, batch_data in enumerate(train_loader): 88 | data_time = time.time() - data_time 89 | 90 | current_iter += 1 91 | 92 | model.feed_data(batch_data) 93 | model.optimize_parameters() 94 | 95 | iter_time = time.time() - iter_time 96 | if current_iter % opt['print_freq'] == 0: 97 | log_vars = {'epoch': epoch, 'iter': current_iter} 98 | log_vars.update({'lrs': [lr]}) 99 | log_vars.update({'time': iter_time, 'data_time': data_time}) 100 | log_vars.update(model.get_current_log()) 101 | msg_logger(log_vars) 102 | 103 | data_time = time.time() 104 | iter_time = time.time() 105 | 106 | if epoch % opt['val_freq'] == 0 and epoch != 0: 107 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa 108 | os.makedirs(save_dir, exist_ok=opt['debug']) 109 | model.inference(val_loader, save_dir) 110 | 111 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa 112 | os.makedirs(save_dir, exist_ok=opt['debug']) 113 | model.inference(test_loader, save_dir) 114 | 115 | # save model 116 | model.save_network( 117 | model._denoise_fn, 118 | f'{opt["path"]["models"]}/sampler_epoch{epoch}.pth') 119 | 120 | 121 | if __name__ == '__main__': 122 | main() 123 | -------------------------------------------------------------------------------- /train_vqvae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import os.path as osp 5 | import random 6 | import time 7 | 8 | import torch 9 | 10 | from data.segm_attr_dataset import DeepFashionAttrSegmDataset 11 | from models import create_model 12 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger 13 | from utils.options import dict2str, dict_to_nonedict, parse 14 | from utils.util import make_exp_dirs 15 | 16 | 17 | def main(): 18 | # options 19 | parser = 
argparse.ArgumentParser() 20 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 21 | args = parser.parse_args() 22 | opt = parse(args.opt, is_train=True) 23 | 24 | # mkdir and loggers 25 | make_exp_dirs(opt) 26 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") 27 | logger = get_root_logger( 28 | logger_name='base', log_level=logging.INFO, log_file=log_file) 29 | logger.info(dict2str(opt)) 30 | # initialize tensorboard logger 31 | tb_logger = None 32 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 33 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name']) 34 | 35 | # convert to NoneDict, which returns None for missing keys 36 | opt = dict_to_nonedict(opt) 37 | 38 | # set up data loader 39 | train_dataset = DeepFashionAttrSegmDataset( 40 | img_dir=opt['train_img_dir'], 41 | segm_dir=opt['segm_dir'], 42 | pose_dir=opt['pose_dir'], 43 | ann_dir=opt['train_ann_file'], 44 | xflip=True) 45 | train_loader = torch.utils.data.DataLoader( 46 | dataset=train_dataset, 47 | batch_size=opt['batch_size'], 48 | shuffle=True, 49 | num_workers=opt['num_workers'], 50 | persistent_workers=True, 51 | drop_last=True) 52 | logger.info(f'Number of train set: {len(train_dataset)}.') 53 | opt['max_iters'] = opt['num_epochs'] * len( 54 | train_dataset) // opt['batch_size'] 55 | 56 | val_dataset = DeepFashionAttrSegmDataset( 57 | img_dir=opt['train_img_dir'], 58 | segm_dir=opt['segm_dir'], 59 | pose_dir=opt['pose_dir'], 60 | ann_dir=opt['val_ann_file']) 61 | val_loader = torch.utils.data.DataLoader( 62 | dataset=val_dataset, batch_size=1, shuffle=False) 63 | logger.info(f'Number of val set: {len(val_dataset)}.') 64 | 65 | test_dataset = DeepFashionAttrSegmDataset( 66 | img_dir=opt['test_img_dir'], 67 | segm_dir=opt['segm_dir'], 68 | pose_dir=opt['pose_dir'], 69 | ann_dir=opt['test_ann_file']) 70 | test_loader = torch.utils.data.DataLoader( 71 | dataset=test_dataset, batch_size=1, shuffle=False) 72 | logger.info(f'Number of test set: {len(test_dataset)}.') 73 | 74 | current_iter = 0 75 | best_epoch = None 76 | best_loss = 100000 77 | 78 | model = create_model(opt) 79 | 80 | data_time, iter_time = 0, 0 81 | current_iter = 0 82 | 83 | # create message logger (formatted outputs) 84 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 85 | 86 | for epoch in range(opt['num_epochs']): 87 | lr = model.update_learning_rate(epoch) 88 | 89 | for _, batch_data in enumerate(train_loader): 90 | data_time = time.time() - data_time 91 | 92 | current_iter += 1 93 | 94 | model.optimize_parameters(batch_data, current_iter) 95 | 96 | iter_time = time.time() - iter_time 97 | if current_iter % opt['print_freq'] == 0: 98 | log_vars = {'epoch': epoch, 'iter': current_iter} 99 | log_vars.update({'lrs': [lr]}) 100 | log_vars.update({'time': iter_time, 'data_time': data_time}) 101 | log_vars.update(model.get_current_log()) 102 | msg_logger(log_vars) 103 | 104 | data_time = time.time() 105 | iter_time = time.time() 106 | 107 | if epoch % opt['val_freq'] == 0: 108 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa 109 | os.makedirs(save_dir, exist_ok=opt['debug']) 110 | val_loss_total = model.inference(val_loader, save_dir) 111 | 112 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa 113 | os.makedirs(save_dir, exist_ok=opt['debug']) 114 | test_loss_total = model.inference(test_loader, save_dir) 115 | 116 | logger.info(f'Epoch: {epoch}, ' 117 | f'val_loss_total: {val_loss_total}, ' 118 | f'test_loss_total: 
{test_loss_total}.') 119 | 120 | if test_loss_total < best_loss: 121 | best_epoch = epoch 122 | best_loss = test_loss_total 123 | 124 | logger.info(f'Best epoch: {best_epoch}, ' 125 | f'Best test loss: {best_loss: .4f}.') 126 | 127 | # save model 128 | model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth') 129 | 130 | 131 | if __name__ == '__main__': 132 | main() 133 | -------------------------------------------------------------------------------- /ui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/__init__.py -------------------------------------------------------------------------------- /ui/color_blocks/class_bag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_bag.png -------------------------------------------------------------------------------- /ui/color_blocks/class_belt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_belt.png -------------------------------------------------------------------------------- /ui/color_blocks/class_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_bg.png -------------------------------------------------------------------------------- /ui/color_blocks/class_dress.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_dress.png -------------------------------------------------------------------------------- /ui/color_blocks/class_earstuds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_earstuds.png -------------------------------------------------------------------------------- /ui/color_blocks/class_eyeglass.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_eyeglass.png -------------------------------------------------------------------------------- /ui/color_blocks/class_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_face.png -------------------------------------------------------------------------------- /ui/color_blocks/class_footwear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_footwear.png -------------------------------------------------------------------------------- /ui/color_blocks/class_glove.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_glove.png -------------------------------------------------------------------------------- /ui/color_blocks/class_hair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_hair.png -------------------------------------------------------------------------------- /ui/color_blocks/class_headwear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_headwear.png -------------------------------------------------------------------------------- /ui/color_blocks/class_leggings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_leggings.png -------------------------------------------------------------------------------- /ui/color_blocks/class_necklace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_necklace.png -------------------------------------------------------------------------------- /ui/color_blocks/class_neckwear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_neckwear.png -------------------------------------------------------------------------------- /ui/color_blocks/class_outer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_outer.png -------------------------------------------------------------------------------- /ui/color_blocks/class_pants.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_pants.png -------------------------------------------------------------------------------- /ui/color_blocks/class_ring.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_ring.png -------------------------------------------------------------------------------- /ui/color_blocks/class_rompers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_rompers.png -------------------------------------------------------------------------------- /ui/color_blocks/class_skin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_skin.png -------------------------------------------------------------------------------- /ui/color_blocks/class_skirt.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_skirt.png -------------------------------------------------------------------------------- /ui/color_blocks/class_socks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_socks.png -------------------------------------------------------------------------------- /ui/color_blocks/class_tie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_tie.png -------------------------------------------------------------------------------- /ui/color_blocks/class_top.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_top.png -------------------------------------------------------------------------------- /ui/color_blocks/class_wrist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_wrist.png -------------------------------------------------------------------------------- /ui/icons/icon_palette.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/icons/icon_palette.png -------------------------------------------------------------------------------- /ui/icons/icon_title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/icons/icon_title.png -------------------------------------------------------------------------------- /ui/mouse_event.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from PyQt5.QtCore import * 5 | from PyQt5.QtGui import * 6 | from PyQt5.QtWidgets import * 7 | 8 | color_list = [ 9 | QColor(0, 0, 0), 10 | QColor(255, 250, 250), 11 | QColor(220, 220, 220), 12 | QColor(250, 235, 215), 13 | QColor(255, 250, 205), 14 | QColor(211, 211, 211), 15 | QColor(70, 130, 180), 16 | QColor(127, 255, 212), 17 | QColor(0, 100, 0), 18 | QColor(50, 205, 50), 19 | QColor(255, 255, 0), 20 | QColor(245, 222, 179), 21 | QColor(255, 140, 0), 22 | QColor(255, 0, 0), 23 | QColor(16, 78, 139), 24 | QColor(144, 238, 144), 25 | QColor(50, 205, 174), 26 | QColor(50, 155, 250), 27 | QColor(160, 140, 88), 28 | QColor(213, 140, 88), 29 | QColor(90, 140, 90), 30 | QColor(185, 210, 205), 31 | QColor(130, 165, 180), 32 | QColor(225, 141, 151) 33 | ] 34 | 35 | 36 | class GraphicsScene(QGraphicsScene): 37 | 38 | def __init__(self, mode, size, parent=None): 39 | QGraphicsScene.__init__(self, parent) 40 | self.mode = mode 41 | self.size = size 42 | self.mouse_clicked = False 43 | self.prev_pt = None 44 | 45 | # self.masked_image = None 46 | 47 | # save the points 48 | self.mask_points = [] 49 | for i in range(len(color_list)): 50 | self.mask_points.append([]) 51 | 52 | # save the size of 
points 53 | self.size_points = [] 54 | for i in range(len(color_list)): 55 | self.size_points.append([]) 56 | 57 | # save the history of edit 58 | self.history = [] 59 | 60 | def reset(self): 61 | # save the points 62 | self.mask_points = [] 63 | for i in range(len(color_list)): 64 | self.mask_points.append([]) 65 | # save the size of points 66 | self.size_points = [] 67 | for i in range(len(color_list)): 68 | self.size_points.append([]) 69 | # save the history of edit 70 | self.history = [] 71 | 72 | self.mode = 0 73 | self.prev_pt = None 74 | 75 | def mousePressEvent(self, event): 76 | self.mouse_clicked = True 77 | 78 | def mouseReleaseEvent(self, event): 79 | self.prev_pt = None 80 | self.mouse_clicked = False 81 | 82 | def mouseMoveEvent(self, event): # drawing 83 | if self.mouse_clicked: 84 | if self.prev_pt: 85 | self.drawMask(self.prev_pt, event.scenePos(), 86 | color_list[self.mode], self.size) 87 | pts = {} 88 | pts['prev'] = (int(self.prev_pt.x()), int(self.prev_pt.y())) 89 | pts['curr'] = (int(event.scenePos().x()), 90 | int(event.scenePos().y())) 91 | 92 | self.size_points[self.mode].append(self.size) 93 | self.mask_points[self.mode].append(pts) 94 | self.history.append(self.mode) 95 | self.prev_pt = event.scenePos() 96 | else: 97 | self.prev_pt = event.scenePos() 98 | 99 | def drawMask(self, prev_pt, curr_pt, color, size): 100 | lineItem = QGraphicsLineItem(QLineF(prev_pt, curr_pt)) 101 | lineItem.setPen(QPen(color, size, Qt.SolidLine)) # rect 102 | self.addItem(lineItem) 103 | 104 | def erase_prev_pt(self): 105 | self.prev_pt = None 106 | 107 | def reset_items(self): 108 | for i in range(len(self.items())): 109 | item = self.items()[0] 110 | self.removeItem(item) 111 | 112 | def undo(self): 113 | if len(self.items()) > 1: 114 | if len(self.items()) >= 9: 115 | for i in range(8): 116 | item = self.items()[0] 117 | self.removeItem(item) 118 | if self.history[-1] == self.mode: 119 | self.mask_points[self.mode].pop() 120 | self.size_points[self.mode].pop() 121 | self.history.pop() 122 | else: 123 | for i in range(len(self.items()) - 1): 124 | item = self.items()[0] 125 | self.removeItem(item) 126 | if self.history[-1] == self.mode: 127 | self.mask_points[self.mode].pop() 128 | self.size_points[self.mode].pop() 129 | self.history.pop() 130 | -------------------------------------------------------------------------------- /ui/ui.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtCore, QtGui, QtWidgets 2 | from PyQt5.QtCore import * 3 | from PyQt5.QtGui import * 4 | from PyQt5.QtWidgets import * 5 | 6 | 7 | class Ui_Form(object): 8 | 9 | def setupUi(self, Form): 10 | Form.setObjectName("Form") 11 | Form.resize(1250, 670) 12 | 13 | self.pushButton_2 = QtWidgets.QPushButton(Form) 14 | self.pushButton_2.setGeometry(QtCore.QRect(20, 60, 97, 27)) 15 | self.pushButton_2.setObjectName("pushButton_2") 16 | 17 | self.pushButton_6 = QtWidgets.QPushButton(Form) 18 | self.pushButton_6.setGeometry(QtCore.QRect(20, 100, 97, 27)) 19 | self.pushButton_6.setObjectName("pushButton_6") 20 | 21 | # Generate Parsing 22 | self.pushButton_0 = QtWidgets.QPushButton(Form) 23 | self.pushButton_0.setGeometry(QtCore.QRect(126, 60, 150, 27)) 24 | self.pushButton_0.setObjectName("pushButton_0") 25 | 26 | # Generate Human 27 | self.pushButton_1 = QtWidgets.QPushButton(Form) 28 | self.pushButton_1.setGeometry(QtCore.QRect(126, 100, 150, 27)) 29 | self.pushButton_1.setObjectName("pushButton_1") 30 | 31 | # shape text box 32 | self.label_heading_1 = 
QtWidgets.QLabel(Form) 33 | self.label_heading_1.setText('Describe the shape.') 34 | self.label_heading_1.setObjectName("label_heading_1") 35 | self.label_heading_1.setGeometry(QtCore.QRect(320, 20, 200, 20)) 36 | 37 | self.message_box_1 = QtWidgets.QLineEdit(Form) 38 | self.message_box_1.setGeometry(QtCore.QRect(320, 50, 256, 80)) 39 | self.message_box_1.setObjectName("message_box_1") 40 | self.message_box_1.setAlignment(Qt.AlignTop) 41 | 42 | # texture text box 43 | self.label_heading_2 = QtWidgets.QLabel(Form) 44 | self.label_heading_2.setText('Describe the textures.') 45 | self.label_heading_2.setObjectName("label_heading_2") 46 | self.label_heading_2.setGeometry(QtCore.QRect(620, 20, 200, 20)) 47 | 48 | self.message_box_2 = QtWidgets.QLineEdit(Form) 49 | self.message_box_2.setGeometry(QtCore.QRect(620, 50, 256, 80)) 50 | self.message_box_2.setObjectName("message_box_2") 51 | self.message_box_2.setAlignment(Qt.AlignTop) 52 | 53 | # title icon 54 | self.title_icon = QtWidgets.QLabel(Form) 55 | self.title_icon.setGeometry(QtCore.QRect(30, 10, 200, 50)) 56 | self.title_icon.setPixmap( 57 | QtGui.QPixmap('./ui/icons/icon_title.png').scaledToWidth(200)) 58 | 59 | # palette icon 60 | self.palette_icon = QtWidgets.QLabel(Form) 61 | self.palette_icon.setGeometry(QtCore.QRect(950, 10, 256, 128)) 62 | self.palette_icon.setPixmap( 63 | QtGui.QPixmap('./ui/icons/icon_palette.png').scaledToWidth(256)) 64 | 65 | # top 66 | self.pushButton_8 = QtWidgets.QPushButton(' top', Form) 67 | self.pushButton_8.setGeometry(QtCore.QRect(940, 120, 120, 27)) 68 | self.pushButton_8.setObjectName("pushButton_8") 69 | self.pushButton_8.setStyleSheet( 70 | "text-align: left; padding-left: 10px;") 71 | self.pushButton_8.setIcon(QIcon('./ui/color_blocks/class_top.png')) 72 | # skin 73 | self.pushButton_9 = QtWidgets.QPushButton(' skin', Form) 74 | self.pushButton_9.setGeometry(QtCore.QRect(940, 165, 120, 27)) 75 | self.pushButton_9.setObjectName("pushButton_9") 76 | self.pushButton_9.setStyleSheet( 77 | "text-align: left; padding-left: 10px;") 78 | self.pushButton_9.setIcon(QIcon('./ui/color_blocks/class_skin.png')) 79 | # outer 80 | self.pushButton_10 = QtWidgets.QPushButton(' outer', Form) 81 | self.pushButton_10.setGeometry(QtCore.QRect(940, 210, 120, 27)) 82 | self.pushButton_10.setObjectName("pushButton_10") 83 | self.pushButton_10.setStyleSheet( 84 | "text-align: left; padding-left: 10px;") 85 | self.pushButton_10.setIcon(QIcon('./ui/color_blocks/class_outer.png')) 86 | # face 87 | self.pushButton_11 = QtWidgets.QPushButton(' face', Form) 88 | self.pushButton_11.setGeometry(QtCore.QRect(940, 255, 120, 27)) 89 | self.pushButton_11.setObjectName("pushButton_11") 90 | self.pushButton_11.setStyleSheet( 91 | "text-align: left; padding-left: 10px;") 92 | self.pushButton_11.setIcon(QIcon('./ui/color_blocks/class_face.png')) 93 | # skirt 94 | self.pushButton_12 = QtWidgets.QPushButton(' skirt', Form) 95 | self.pushButton_12.setGeometry(QtCore.QRect(940, 300, 120, 27)) 96 | self.pushButton_12.setObjectName("pushButton_12") 97 | self.pushButton_12.setStyleSheet( 98 | "text-align: left; padding-left: 10px;") 99 | self.pushButton_12.setIcon(QIcon('./ui/color_blocks/class_skirt.png')) 100 | # hair 101 | self.pushButton_13 = QtWidgets.QPushButton(' hair', Form) 102 | self.pushButton_13.setGeometry(QtCore.QRect(940, 345, 120, 27)) 103 | self.pushButton_13.setObjectName("pushButton_13") 104 | self.pushButton_13.setStyleSheet( 105 | "text-align: left; padding-left: 10px;") 106 | 
self.pushButton_13.setIcon(QIcon('./ui/color_blocks/class_hair.png')) 107 | # dress 108 | self.pushButton_14 = QtWidgets.QPushButton(' dress', Form) 109 | self.pushButton_14.setGeometry(QtCore.QRect(940, 390, 120, 27)) 110 | self.pushButton_14.setObjectName("pushButton_14") 111 | self.pushButton_14.setStyleSheet( 112 | "text-align: left; padding-left: 10px;") 113 | self.pushButton_14.setIcon(QIcon('./ui/color_blocks/class_dress.png')) 114 | # headwear 115 | self.pushButton_15 = QtWidgets.QPushButton(' headwear', Form) 116 | self.pushButton_15.setGeometry(QtCore.QRect(940, 435, 120, 27)) 117 | self.pushButton_15.setObjectName("pushButton_15") 118 | self.pushButton_15.setStyleSheet( 119 | "text-align: left; padding-left: 10px;") 120 | self.pushButton_15.setIcon( 121 | QIcon('./ui/color_blocks/class_headwear.png')) 122 | # pants 123 | self.pushButton_16 = QtWidgets.QPushButton(' pants', Form) 124 | self.pushButton_16.setGeometry(QtCore.QRect(940, 480, 120, 27)) 125 | self.pushButton_16.setObjectName("pushButton_16") 126 | self.pushButton_16.setStyleSheet( 127 | "text-align: left; padding-left: 10px;") 128 | self.pushButton_16.setIcon(QIcon('./ui/color_blocks/class_pants.png')) 129 | # eyeglasses 130 | self.pushButton_17 = QtWidgets.QPushButton(' eyeglass', Form) 131 | self.pushButton_17.setGeometry(QtCore.QRect(940, 525, 120, 27)) 132 | self.pushButton_17.setObjectName("pushButton_17") 133 | self.pushButton_17.setStyleSheet( 134 | "text-align: left; padding-left: 10px;") 135 | self.pushButton_17.setIcon( 136 | QIcon('./ui/color_blocks/class_eyeglass.png')) 137 | # rompers 138 | self.pushButton_18 = QtWidgets.QPushButton(' rompers', Form) 139 | self.pushButton_18.setGeometry(QtCore.QRect(940, 570, 120, 27)) 140 | self.pushButton_18.setObjectName("pushButton_18") 141 | self.pushButton_18.setStyleSheet( 142 | "text-align: left; padding-left: 10px;") 143 | self.pushButton_18.setIcon( 144 | QIcon('./ui/color_blocks/class_rompers.png')) 145 | # footwear 146 | self.pushButton_19 = QtWidgets.QPushButton(' footwear', Form) 147 | self.pushButton_19.setGeometry(QtCore.QRect(940, 615, 120, 27)) 148 | self.pushButton_19.setObjectName("pushButton_19") 149 | self.pushButton_19.setStyleSheet( 150 | "text-align: left; padding-left: 10px;") 151 | self.pushButton_19.setIcon( 152 | QIcon('./ui/color_blocks/class_footwear.png')) 153 | 154 | # leggings 155 | self.pushButton_20 = QtWidgets.QPushButton(' leggings', Form) 156 | self.pushButton_20.setGeometry(QtCore.QRect(1100, 120, 120, 27)) 157 | self.pushButton_20.setObjectName("pushButton_10") 158 | self.pushButton_20.setStyleSheet( 159 | "text-align: left; padding-left: 10px;") 160 | self.pushButton_20.setIcon( 161 | QIcon('./ui/color_blocks/class_leggings.png')) 162 | 163 | # ring 164 | self.pushButton_21 = QtWidgets.QPushButton(' ring', Form) 165 | self.pushButton_21.setGeometry(QtCore.QRect(1100, 165, 120, 27)) 166 | self.pushButton_21.setObjectName("pushButton_2`0`") 167 | self.pushButton_21.setStyleSheet( 168 | "text-align: left; padding-left: 10px;") 169 | self.pushButton_21.setIcon(QIcon('./ui/color_blocks/class_ring.png')) 170 | 171 | # belt 172 | self.pushButton_22 = QtWidgets.QPushButton(' belt', Form) 173 | self.pushButton_22.setGeometry(QtCore.QRect(1100, 210, 120, 27)) 174 | self.pushButton_22.setObjectName("pushButton_2`0`") 175 | self.pushButton_22.setStyleSheet( 176 | "text-align: left; padding-left: 10px;") 177 | self.pushButton_22.setIcon(QIcon('./ui/color_blocks/class_belt.png')) 178 | 179 | # neckwear 180 | self.pushButton_23 = 
QtWidgets.QPushButton(' neckwear', Form) 181 | self.pushButton_23.setGeometry(QtCore.QRect(1100, 255, 120, 27)) 182 | self.pushButton_23.setObjectName("pushButton_2`0`") 183 | self.pushButton_23.setStyleSheet( 184 | "text-align: left; padding-left: 10px;") 185 | self.pushButton_23.setIcon( 186 | QIcon('./ui/color_blocks/class_neckwear.png')) 187 | 188 | # wrist 189 | self.pushButton_24 = QtWidgets.QPushButton(' wrist', Form) 190 | self.pushButton_24.setGeometry(QtCore.QRect(1100, 300, 120, 27)) 191 | self.pushButton_24.setObjectName("pushButton_2`0`") 192 | self.pushButton_24.setStyleSheet( 193 | "text-align: left; padding-left: 10px;") 194 | self.pushButton_24.setIcon(QIcon('./ui/color_blocks/class_wrist.png')) 195 | 196 | # socks 197 | self.pushButton_25 = QtWidgets.QPushButton(' socks', Form) 198 | self.pushButton_25.setGeometry(QtCore.QRect(1100, 345, 120, 27)) 199 | self.pushButton_25.setObjectName("pushButton_2`0`") 200 | self.pushButton_25.setStyleSheet( 201 | "text-align: left; padding-left: 10px;") 202 | self.pushButton_25.setIcon(QIcon('./ui/color_blocks/class_socks.png')) 203 | 204 | # tie 205 | self.pushButton_26 = QtWidgets.QPushButton(' tie', Form) 206 | self.pushButton_26.setGeometry(QtCore.QRect(1100, 390, 120, 27)) 207 | self.pushButton_26.setObjectName("pushButton_2`0`") 208 | self.pushButton_26.setStyleSheet( 209 | "text-align: left; padding-left: 10px;") 210 | self.pushButton_26.setIcon(QIcon('./ui/color_blocks/class_tie.png')) 211 | 212 | # earstuds 213 | self.pushButton_27 = QtWidgets.QPushButton(' necklace', Form) 214 | self.pushButton_27.setGeometry(QtCore.QRect(1100, 435, 120, 27)) 215 | self.pushButton_27.setObjectName("pushButton_2`0`") 216 | self.pushButton_27.setStyleSheet( 217 | "text-align: left; padding-left: 10px;") 218 | self.pushButton_27.setIcon( 219 | QIcon('./ui/color_blocks/class_necklace.png')) 220 | 221 | # necklace 222 | self.pushButton_28 = QtWidgets.QPushButton(' earstuds', Form) 223 | self.pushButton_28.setGeometry(QtCore.QRect(1100, 480, 120, 27)) 224 | self.pushButton_28.setObjectName("pushButton_2`0`") 225 | self.pushButton_28.setStyleSheet( 226 | "text-align: left; padding-left: 10px;") 227 | self.pushButton_28.setIcon( 228 | QIcon('./ui/color_blocks/class_earstuds.png')) 229 | 230 | # bag 231 | self.pushButton_29 = QtWidgets.QPushButton(' bag', Form) 232 | self.pushButton_29.setGeometry(QtCore.QRect(1100, 525, 120, 27)) 233 | self.pushButton_29.setObjectName("pushButton_2`0`") 234 | self.pushButton_29.setStyleSheet( 235 | "text-align: left; padding-left: 10px;") 236 | self.pushButton_29.setIcon(QIcon('./ui/color_blocks/class_bag.png')) 237 | 238 | # glove 239 | self.pushButton_30 = QtWidgets.QPushButton(' glove', Form) 240 | self.pushButton_30.setGeometry(QtCore.QRect(1100, 570, 120, 27)) 241 | self.pushButton_30.setObjectName("pushButton_2`0`") 242 | self.pushButton_30.setStyleSheet( 243 | "text-align: left; padding-left: 10px;") 244 | self.pushButton_30.setIcon(QIcon('./ui/color_blocks/class_glove.png')) 245 | 246 | # background 247 | self.pushButton_31 = QtWidgets.QPushButton(' background', Form) 248 | self.pushButton_31.setGeometry(QtCore.QRect(1100, 615, 120, 27)) 249 | self.pushButton_31.setObjectName("pushButton_2`0`") 250 | self.pushButton_31.setStyleSheet( 251 | "text-align: left; padding-left: 10px;") 252 | self.pushButton_31.setIcon(QIcon('./ui/color_blocks/class_bg.png')) 253 | 254 | self.graphicsView = QtWidgets.QGraphicsView(Form) 255 | self.graphicsView.setGeometry(QtCore.QRect(20, 140, 256, 512)) 256 | 
self.graphicsView.setObjectName("graphicsView") 257 | self.graphicsView_2 = QtWidgets.QGraphicsView(Form) 258 | self.graphicsView_2.setGeometry(QtCore.QRect(320, 140, 256, 512)) 259 | self.graphicsView_2.setObjectName("graphicsView_2") 260 | self.graphicsView_3 = QtWidgets.QGraphicsView(Form) 261 | self.graphicsView_3.setGeometry(QtCore.QRect(620, 140, 256, 512)) 262 | self.graphicsView_3.setObjectName("graphicsView_3") 263 | 264 | self.retranslateUi(Form) 265 | self.pushButton_2.clicked.connect(Form.open_densepose) 266 | self.pushButton_6.clicked.connect(Form.save_img) 267 | self.pushButton_8.clicked.connect(Form.top_mode) 268 | self.pushButton_9.clicked.connect(Form.skin_mode) 269 | self.pushButton_10.clicked.connect(Form.outer_mode) 270 | self.pushButton_11.clicked.connect(Form.face_mode) 271 | self.pushButton_12.clicked.connect(Form.skirt_mode) 272 | self.pushButton_13.clicked.connect(Form.hair_mode) 273 | self.pushButton_14.clicked.connect(Form.dress_mode) 274 | self.pushButton_15.clicked.connect(Form.headwear_mode) 275 | self.pushButton_16.clicked.connect(Form.pants_mode) 276 | self.pushButton_17.clicked.connect(Form.eyeglass_mode) 277 | self.pushButton_18.clicked.connect(Form.rompers_mode) 278 | self.pushButton_19.clicked.connect(Form.footwear_mode) 279 | self.pushButton_20.clicked.connect(Form.leggings_mode) 280 | self.pushButton_21.clicked.connect(Form.ring_mode) 281 | self.pushButton_22.clicked.connect(Form.belt_mode) 282 | self.pushButton_23.clicked.connect(Form.neckwear_mode) 283 | self.pushButton_24.clicked.connect(Form.wrist_mode) 284 | self.pushButton_25.clicked.connect(Form.socks_mode) 285 | self.pushButton_26.clicked.connect(Form.tie_mode) 286 | self.pushButton_27.clicked.connect(Form.earstuds_mode) 287 | self.pushButton_28.clicked.connect(Form.necklace_mode) 288 | self.pushButton_29.clicked.connect(Form.bag_mode) 289 | self.pushButton_30.clicked.connect(Form.glove_mode) 290 | self.pushButton_31.clicked.connect(Form.background_mode) 291 | self.pushButton_0.clicked.connect(Form.generate_parsing) 292 | self.pushButton_1.clicked.connect(Form.generate_human) 293 | 294 | QtCore.QMetaObject.connectSlotsByName(Form) 295 | 296 | def retranslateUi(self, Form): 297 | _translate = QtCore.QCoreApplication.translate 298 | Form.setWindowTitle(_translate("Form", "Text2Human")) 299 | self.pushButton_2.setText(_translate("Form", "Load Pose")) 300 | self.pushButton_6.setText(_translate("Form", "Save Image")) 301 | 302 | self.pushButton_0.setText(_translate("Form", "Generate Parsing")) 303 | self.pushButton_1.setText(_translate("Form", "Generate Human")) 304 | 305 | 306 | if __name__ == "__main__": 307 | import sys 308 | app = QtWidgets.QApplication(sys.argv) 309 | Form = QtWidgets.QWidget() 310 | ui = Ui_Form() 311 | ui.setupUi(Form) 312 | Form.show() 313 | sys.exit(app.exec_()) 314 | -------------------------------------------------------------------------------- /ui_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from PyQt5.QtCore import * 8 | from PyQt5.QtGui import * 9 | from PyQt5.QtWidgets import * 10 | 11 | from models.sample_model import SampleFromPoseModel 12 | from ui.mouse_event import GraphicsScene 13 | from ui.ui import Ui_Form 14 | from utils.language_utils import (generate_shape_attributes, 15 | generate_texture_attributes) 16 | from utils.options import dict_to_nonedict, parse 17 | 18 | color_list = [(0, 0, 0), (255, 250, 250), 
(220, 220, 220), (250, 235, 215), 19 | (255, 250, 205), (211, 211, 211), (70, 130, 180), 20 | (127, 255, 212), (0, 100, 0), (50, 205, 50), (255, 255, 0), 21 | (245, 222, 179), (255, 140, 0), (255, 0, 0), (16, 78, 139), 22 | (144, 238, 144), (50, 205, 174), (50, 155, 250), (160, 140, 88), 23 | (213, 140, 88), (90, 140, 90), (185, 210, 205), (130, 165, 180), 24 | (225, 141, 151)] 25 | 26 | 27 | class Ex(QWidget, Ui_Form): 28 | 29 | def __init__(self, opt): 30 | super(Ex, self).__init__() 31 | self.setupUi(self) 32 | self.show() 33 | 34 | self.output_img = None 35 | 36 | self.mat_img = None 37 | 38 | self.mode = 0 39 | self.size = 6 40 | self.mask = None 41 | self.mask_m = None 42 | self.img = None 43 | 44 | # about UI 45 | self.mouse_clicked = False 46 | self.scene = QGraphicsScene() 47 | self.graphicsView.setScene(self.scene) 48 | self.graphicsView.setAlignment(Qt.AlignTop | Qt.AlignLeft) 49 | self.graphicsView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) 50 | self.graphicsView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) 51 | 52 | self.ref_scene = GraphicsScene(self.mode, self.size) 53 | self.graphicsView_2.setScene(self.ref_scene) 54 | self.graphicsView_2.setAlignment(Qt.AlignTop | Qt.AlignLeft) 55 | self.graphicsView_2.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) 56 | self.graphicsView_2.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) 57 | 58 | self.result_scene = QGraphicsScene() 59 | self.graphicsView_3.setScene(self.result_scene) 60 | self.graphicsView_3.setAlignment(Qt.AlignTop | Qt.AlignLeft) 61 | self.graphicsView_3.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) 62 | self.graphicsView_3.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) 63 | 64 | self.dlg = QColorDialog(self.graphicsView) 65 | self.color = None 66 | 67 | self.sample_model = SampleFromPoseModel(opt) 68 | 69 | def open_densepose(self): 70 | fileName, _ = QFileDialog.getOpenFileName(self, "Open File", 71 | QDir.currentPath()) 72 | if fileName: 73 | image = QPixmap(fileName) 74 | mat_img = Image.open(fileName) 75 | self.pose_img = mat_img.copy() 76 | if image.isNull(): 77 | QMessageBox.information(self, "Image Viewer", 78 | "Cannot load %s." % fileName) 79 | return 80 | image = image.scaled(self.graphicsView.size(), 81 | Qt.IgnoreAspectRatio) 82 | 83 | if len(self.scene.items()) > 0: 84 | self.scene.removeItem(self.scene.items()[-1]) 85 | self.scene.addPixmap(image) 86 | 87 | self.ref_scene.clear() 88 | self.result_scene.clear() 89 | 90 | # load pose to model 91 | self.pose_img = np.array( 92 | self.pose_img.resize( 93 | size=(256, 512), 94 | resample=Image.LANCZOS))[:, :, 2:].transpose( 95 | 2, 0, 1).astype(np.float32) 96 | self.pose_img = self.pose_img / 12. 
- 1 97 | 98 | self.pose_img = torch.from_numpy(self.pose_img).unsqueeze(1) 99 | 100 | self.sample_model.feed_pose_data(self.pose_img) 101 | 102 | def generate_parsing(self): 103 | self.ref_scene.reset_items() 104 | self.ref_scene.reset() 105 | 106 | shape_texts = self.message_box_1.text() 107 | 108 | shape_attributes = generate_shape_attributes(shape_texts) 109 | shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0) 110 | self.sample_model.feed_shape_attributes(shape_attributes) 111 | 112 | self.sample_model.generate_parsing_map() 113 | self.sample_model.generate_quantized_segm() 114 | 115 | self.colored_segm = self.sample_model.palette_result( 116 | self.sample_model.segm[0].cpu()) 117 | 118 | self.mask_m = cv2.cvtColor( 119 | cv2.cvtColor(self.colored_segm, cv2.COLOR_RGB2BGR), 120 | cv2.COLOR_BGR2RGB) 121 | 122 | qim = QImage(self.colored_segm.data.tobytes(), 123 | self.colored_segm.shape[1], self.colored_segm.shape[0], 124 | QImage.Format_RGB888) 125 | 126 | image = QPixmap.fromImage(qim) 127 | 128 | image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio) 129 | 130 | if len(self.ref_scene.items()) > 0: 131 | self.ref_scene.removeItem(self.ref_scene.items()[-1]) 132 | self.ref_scene.addPixmap(image) 133 | 134 | self.result_scene.clear() 135 | 136 | def generate_human(self): 137 | for i in range(24): 138 | self.mask_m = self.make_mask(self.mask_m, 139 | self.ref_scene.mask_points[i], 140 | self.ref_scene.size_points[i], 141 | color_list[i]) 142 | 143 | seg_map = np.full(self.mask_m.shape[:-1], -1) 144 | 145 | # convert rgb to num 146 | for index, color in enumerate(color_list): 147 | seg_map[np.sum(self.mask_m == color, axis=2) == 3] = index 148 | assert (seg_map != -1).all() 149 | 150 | self.sample_model.segm = torch.from_numpy(seg_map).unsqueeze( 151 | 0).unsqueeze(0).to(self.sample_model.device) 152 | self.sample_model.generate_quantized_segm() 153 | 154 | texture_texts = self.message_box_2.text() 155 | texture_attributes = generate_texture_attributes(texture_texts) 156 | 157 | texture_attributes = torch.LongTensor(texture_attributes) 158 | 159 | self.sample_model.feed_texture_attributes(texture_attributes) 160 | 161 | self.sample_model.generate_texture_map() 162 | result = self.sample_model.sample_and_refine() 163 | result = result.permute(0, 2, 3, 1) 164 | result = result.detach().cpu().numpy() 165 | result = result * 255 166 | 167 | result = np.asarray(result[0, :, :, :], dtype=np.uint8) 168 | 169 | self.output_img = result 170 | 171 | qim = QImage(result.data.tobytes(), result.shape[1], result.shape[0], 172 | QImage.Format_RGB888) 173 | image = QPixmap.fromImage(qim) 174 | 175 | image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio) 176 | 177 | if len(self.result_scene.items()) > 0: 178 | self.result_scene.removeItem(self.result_scene.items()[-1]) 179 | self.result_scene.addPixmap(image) 180 | 181 | def top_mode(self): 182 | self.ref_scene.mode = 1 183 | 184 | def skin_mode(self): 185 | self.ref_scene.mode = 15 186 | 187 | def outer_mode(self): 188 | self.ref_scene.mode = 2 189 | 190 | def face_mode(self): 191 | self.ref_scene.mode = 14 192 | 193 | def skirt_mode(self): 194 | self.ref_scene.mode = 3 195 | 196 | def hair_mode(self): 197 | self.ref_scene.mode = 13 198 | 199 | def dress_mode(self): 200 | self.ref_scene.mode = 4 201 | 202 | def headwear_mode(self): 203 | self.ref_scene.mode = 7 204 | 205 | def pants_mode(self): 206 | self.ref_scene.mode = 5 207 | 208 | def eyeglass_mode(self): 209 | self.ref_scene.mode = 8 210 | 211 | def 
rompers_mode(self): 212 | self.ref_scene.mode = 21 213 | 214 | def footwear_mode(self): 215 | self.ref_scene.mode = 11 216 | 217 | def leggings_mode(self): 218 | self.ref_scene.mode = 6 219 | 220 | def ring_mode(self): 221 | self.ref_scene.mode = 16 222 | 223 | def belt_mode(self): 224 | self.ref_scene.mode = 10 225 | 226 | def neckwear_mode(self): 227 | self.ref_scene.mode = 9 228 | 229 | def wrist_mode(self): 230 | self.ref_scene.mode = 17 231 | 232 | def socks_mode(self): 233 | self.ref_scene.mode = 18 234 | 235 | def tie_mode(self): 236 | self.ref_scene.mode = 23 237 | 238 | def earstuds_mode(self): 239 | self.ref_scene.mode = 22 240 | 241 | def necklace_mode(self): 242 | self.ref_scene.mode = 20 243 | 244 | def bag_mode(self): 245 | self.ref_scene.mode = 12 246 | 247 | def glove_mode(self): 248 | self.ref_scene.mode = 19 249 | 250 | def background_mode(self): 251 | self.ref_scene.mode = 0 252 | 253 | def make_mask(self, mask, pts, sizes, color): 254 | if len(pts) > 0: 255 | for idx, pt in enumerate(pts): 256 | cv2.line(mask, pt['prev'], pt['curr'], color, sizes[idx]) 257 | return mask 258 | 259 | def save_img(self): 260 | if self.output_img is not None: 261 | fileName, _ = QFileDialog.getSaveFileName(self, "Save File", 262 | QDir.currentPath()) 263 | cv2.imwrite(fileName + '.png', self.output_img[:, :, ::-1]) 264 | 265 | def undo(self): 266 | self.scene.undo() 267 | 268 | def clear(self): 269 | 270 | self.ref_scene.reset_items() 271 | self.ref_scene.reset() 272 | 273 | self.ref_scene.clear() 274 | 275 | self.result_scene.clear() 276 | 277 | 278 | if __name__ == '__main__': 279 | 280 | app = QApplication(sys.argv) 281 | opt = './configs/sample_from_pose.yml' 282 | opt = parse(opt, is_train=False) 283 | opt = dict_to_nonedict(opt) 284 | ex = Ex(opt) 285 | sys.exit(app.exec_()) 286 | -------------------------------------------------------------------------------- /ui_util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui_util/__init__.py -------------------------------------------------------------------------------- /ui_util/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | import yaml 6 | 7 | logger = logging.getLogger() 8 | 9 | class Config(object): 10 | def __init__(self, filename=None): 11 | assert os.path.exists(filename), "ERROR: Config File doesn't exist."
12 | try: 13 | with open(filename, 'r') as f: 14 | self._cfg_dict = yaml.safe_load(f) 15 | # parent of IOError, OSError *and* WindowsError where available 16 | except EnvironmentError: 17 | logger.error('Please check the config file "%s".', filename) 18 | logger.info(' APP CONFIG '.center(80, '-')) 19 | logger.info(''.center(80, '-')) 20 | 21 | def __getattr__(self, name): 22 | value = self._cfg_dict[name] 23 | if isinstance(value, dict): 24 | value = DictAsMember(value) 25 | return value 26 | 27 | 28 | class DictAsMember(dict): 29 | """Dict subclass that exposes nested config entries as attributes.""" 30 | 31 | def __getattr__(self, name): 32 | value = self[name] 33 | if isinstance(value, dict): 34 | value = DictAsMember(value) 35 | return value 36 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/utils/__init__.py -------------------------------------------------------------------------------- /utils/language_utils.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch 3 | from sentence_transformers import SentenceTransformer, util 4 | 5 | 6 | 7 | 8 | # predefined shape text 9 | upper_length_text = [ 10 | 'sleeveless', 'without sleeves', 'sleeves have been cut off', 'tank top', 11 | 'tank shirt', 'muscle shirt', 'short-sleeve', 'short sleeves', 12 | 'with short sleeves', 'medium-sleeve', 'medium sleeves', 13 | 'with medium sleeves', 'sleeves reach elbow', 'long-sleeve', 14 | 'long sleeves', 'with long sleeves' 15 | ] 16 | upper_length_attr = { 17 | 'sleeveless': 0, 18 | 'without sleeves': 0, 19 | 'sleeves have been cut off': 0, 20 | 'tank top': 0, 21 | 'tank shirt': 0, 22 | 'muscle shirt': 0, 23 | 'short-sleeve': 1, 24 | 'with short sleeves': 1, 25 | 'short sleeves': 1, 26 | 'medium-sleeve': 2, 27 | 'with medium sleeves': 2, 28 | 'medium sleeves': 2, 29 | 'sleeves reach elbow': 2, 30 | 'long-sleeve': 3, 31 | 'long sleeves': 3, 32 | 'with long sleeves': 3 33 | } 34 | lower_length_text = [ 35 | 'three-point', 'medium', 'short', 'covering knee', 'cropped', 36 | 'three-quarter', 'long', 'slack', 'of long length' 37 | ] 38 | lower_length_attr = { 39 | 'three-point': 0, 40 | 'medium': 1, 41 | 'covering knee': 1, 42 | 'short': 1, 43 | 'cropped': 2, 44 | 'three-quarter': 2, 45 | 'long': 3, 46 | 'slack': 3, 47 | 'of long length': 3 48 | } 49 | socks_length_text = [ 50 | 'socks', 'stocking', 'pantyhose', 'leggings', 'sheer hosiery' 51 | ] 52 | socks_length_attr = { 53 | 'socks': 0, 54 | 'stocking': 1, 55 | 'pantyhose': 1, 56 | 'leggings': 1, 57 | 'sheer hosiery': 1 58 | } 59 | hat_text = ['hat', 'cap', 'chapeau'] 60 | eyeglasses_text = ['sunglasses'] 61 | belt_text = ['belt', 'with a dress tied around the waist'] 62 | outer_shape_text = [ 63 | 'with outer clothing open', 'with outer clothing unzipped', 64 | 'covering inner clothes', 'with outer clothing zipped' 65 | ] 66 | outer_shape_attr = { 67 | 'with outer clothing open': 0, 68 | 'with outer clothing unzipped': 0, 69 | 'covering inner clothes': 1, 70 | 'with outer clothing zipped': 1 71 | } 72 | 73 | upper_types = [ 74 | 'T-shirt', 'shirt', 'sweater', 'hoodie', 'tops', 'blouse', 'Basic Tee' 75 | ] 76 | outer_types = [ 77 | 'jacket', 'outer clothing', 'coat', 'overcoat', 'blazer', 'outerwear', 78 | 'duffle', 'cardigan' 79 | ] 80 | skirt_types = ['skirt'] 81 | dress_types = ['dress'] 82 | pant_types = ['jeans', 'pants', 'trousers'] 83 | rompers_types = ['rompers', 'bodysuit', 'jumpsuit'] 84 | 85 | attr_names_list = [ 86 | 'gender', 'hair length', '0 upper clothing 
length', 87 | '1 lower clothing length', '2 socks', '3 hat', '4 eyeglasses', '5 belt', 88 | '6 opening of outer clothing', '7 upper clothes', '8 outer clothing', 89 | '9 skirt', '10 dress', '11 pants', '12 rompers' 90 | ] 91 | 92 | 93 | def generate_shape_attributes(user_shape_texts): 94 | model = SentenceTransformer('all-MiniLM-L6-v2') 95 | parsed_texts = user_shape_texts.split(',') 96 | 97 | text_num = len(parsed_texts) 98 | 99 | human_attr = [0, 0] 100 | attr = [1, 3, 0, 0, 0, 3, 1, 1, 0, 0, 0, 0, 0] 101 | 102 | changed = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 103 | for text_id, text in enumerate(parsed_texts): 104 | user_embeddings = model.encode(text) 105 | if ('man' in text) and (text_id == 0): 106 | human_attr[0] = 0 107 | human_attr[1] = 0 108 | 109 | if ('woman' in text or 'lady' in text) and (text_id == 0): 110 | human_attr[0] = 1 111 | human_attr[1] = 2 112 | 113 | if (not changed[0]) and (text_id == 1): 114 | # upper length 115 | predefined_embeddings = model.encode(upper_length_text) 116 | similarities = util.dot_score(user_embeddings, 117 | predefined_embeddings) 118 | arg_idx = torch.argmax(similarities).item() 119 | attr[0] = upper_length_attr[upper_length_text[arg_idx]] 120 | changed[0] = 1 121 | 122 | if (not changed[1]) and ((text_num == 2 and text_id == 1) or 123 | (text_num > 2 and text_id == 2)): 124 | # lower length 125 | predefined_embeddings = model.encode(lower_length_text) 126 | similarities = util.dot_score(user_embeddings, 127 | predefined_embeddings) 128 | arg_idx = torch.argmax(similarities).item() 129 | attr[1] = lower_length_attr[lower_length_text[arg_idx]] 130 | changed[1] = 1 131 | 132 | if (not changed[2]) and (text_id > 2): 133 | # socks length 134 | predefined_embeddings = model.encode(socks_length_text) 135 | similarities = util.dot_score(user_embeddings, 136 | predefined_embeddings) 137 | arg_idx = torch.argmax(similarities).item() 138 | if similarities[0][arg_idx] > 0.7: 139 | attr[2] = arg_idx + 1 140 | changed[2] = 1 141 | 142 | if (not changed[3]) and (text_id > 2): 143 | # hat 144 | predefined_embeddings = model.encode(hat_text) 145 | similarities = util.dot_score(user_embeddings, 146 | predefined_embeddings) 147 | if similarities[0][0] > 0.7: 148 | attr[3] = 1 149 | changed[3] = 1 150 | 151 | if (not changed[4]) and (text_id > 2): 152 | # glasses 153 | predefined_embeddings = model.encode(eyeglasses_text) 154 | similarities = util.dot_score(user_embeddings, 155 | predefined_embeddings) 156 | arg_idx = torch.argmax(similarities).item() 157 | if similarities[0][arg_idx] > 0.7: 158 | attr[4] = arg_idx + 1 159 | changed[4] = 1 160 | 161 | if (not changed[5]) and (text_id > 2): 162 | # belt 163 | predefined_embeddings = model.encode(belt_text) 164 | similarities = util.dot_score(user_embeddings, 165 | predefined_embeddings) 166 | arg_idx = torch.argmax(similarities).item() 167 | if similarities[0][arg_idx] > 0.7: 168 | attr[5] = arg_idx + 1 169 | changed[5] = 1 170 | 171 | if (not changed[6]) and (text_id == 3): 172 | # outer coverage 173 | predefined_embeddings = model.encode(outer_shape_text) 174 | similarities = util.dot_score(user_embeddings, 175 | predefined_embeddings) 176 | arg_idx = torch.argmax(similarities).item() 177 | if similarities[0][arg_idx] > 0.7: 178 | attr[6] = arg_idx 179 | changed[6] = 1 180 | 181 | if (not changed[10]) and (text_num == 2 and text_id == 1): 182 | # dress_types 183 | predefined_embeddings = model.encode(dress_types) 184 | similarities = util.dot_score(user_embeddings, 185 | predefined_embeddings) 186 | 
similarity_skirt = util.dot_score(user_embeddings, 187 | model.encode(skirt_types)) 188 | if similarities[0][0] > 0.5 and similarities[0][ 189 | 0] > similarity_skirt[0][0]: 190 | attr[10] = 1 191 | attr[7] = 0 192 | attr[8] = 0 193 | attr[9] = 0 194 | attr[11] = 0 195 | attr[12] = 0 196 | 197 | changed[0] = 1 198 | changed[10] = 1 199 | changed[7] = 1 200 | changed[8] = 1 201 | changed[9] = 1 202 | changed[11] = 1 203 | changed[12] = 1 204 | 205 | if (not changed[12]) and (text_num == 2 and text_id == 1): 206 | # rompers_types 207 | predefined_embeddings = model.encode(rompers_types) 208 | similarities = util.dot_score(user_embeddings, 209 | predefined_embeddings) 210 | max_similarity = torch.max(similarities).item() 211 | if max_similarity > 0.6: 212 | attr[12] = 1 213 | attr[7] = 0 214 | attr[8] = 0 215 | attr[9] = 0 216 | attr[10] = 0 217 | attr[11] = 0 218 | 219 | changed[12] = 1 220 | changed[7] = 1 221 | changed[8] = 1 222 | changed[9] = 1 223 | changed[10] = 1 224 | changed[11] = 1 225 | 226 | if (not changed[7]) and (text_num > 2 and text_id == 1): 227 | # upper_types 228 | predefined_embeddings = model.encode(upper_types) 229 | similarities = util.dot_score(user_embeddings, 230 | predefined_embeddings) 231 | max_similarity = torch.max(similarities).item() 232 | if max_similarity > 0.6: 233 | attr[7] = 1 234 | changed[7] = 1 235 | 236 | if (not changed[8]) and (text_id == 3): 237 | # outer_types 238 | predefined_embeddings = model.encode(outer_types) 239 | similarities = util.dot_score(user_embeddings, 240 | predefined_embeddings) 241 | arg_idx = torch.argmax(similarities).item() 242 | if similarities[0][arg_idx] > 0.7: 243 | attr[6] = outer_shape_attr[outer_shape_text[arg_idx]] 244 | attr[8] = 1 245 | changed[8] = 1 246 | 247 | if (not changed[9]) and (text_num > 2 and text_id == 2): 248 | # skirt_types 249 | predefined_embeddings = model.encode(skirt_types) 250 | similarity_skirt = util.dot_score(user_embeddings, 251 | predefined_embeddings) 252 | similarity_dress = util.dot_score(user_embeddings, 253 | model.encode(dress_types)) 254 | if similarity_skirt[0][0] > 0.7 and similarity_skirt[0][ 255 | 0] > similarity_dress[0][0]: 256 | attr[9] = 1 257 | attr[10] = 0 258 | changed[9] = 1 259 | changed[10] = 1 260 | 261 | if (not changed[11]) and (text_num > 2 and text_id == 2): 262 | # pant_types 263 | predefined_embeddings = model.encode(pant_types) 264 | similarities = util.dot_score(user_embeddings, 265 | predefined_embeddings) 266 | max_similarity = torch.max(similarities).item() 267 | if max_similarity > 0.6: 268 | attr[11] = 1 269 | attr[9] = 0 270 | attr[10] = 0 271 | attr[12] = 0 272 | changed[11] = 1 273 | changed[9] = 1 274 | changed[10] = 1 275 | changed[12] = 1 276 | 277 | return human_attr + attr 278 | 279 | 280 | def generate_texture_attributes(user_text): 281 | parsed_texts = user_text.split(',') 282 | 283 | attr = [] 284 | for text in parsed_texts: 285 | if ('pure color' in text) or ('solid color' in text): 286 | attr.append(4) 287 | elif ('spline' in text) or ('stripe' in text): 288 | attr.append(3) 289 | elif ('plaid' in text) or ('lattice' in text): 290 | attr.append(5) 291 | elif 'floral' in text: 292 | attr.append(1) 293 | elif 'denim' in text: 294 | attr.append(0) 295 | else: 296 | attr.append(17) 297 | 298 | if len(attr) == 1: 299 | attr.append(attr[0]) 300 | attr.append(17) 301 | 302 | if len(attr) == 2: 303 | attr.append(17) 304 | 305 | return attr 306 | 307 | 308 | if __name__ == "__main__": 309 | user_request = input('Enter your request: ') 310 | while 
user_request != '\\q': 311 | attr = generate_shape_attributes(user_request) 312 | print(attr) 313 | for attr_name, attr_value in zip(attr_names_list, attr): 314 | print(attr_name, attr_value) 315 | user_request = input('Enter your request: ') 316 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | 6 | class MessageLogger(): 7 | """Message logger for printing. 8 | 9 | Args: 10 | opt (dict): Config. It contains the following keys: 11 | name (str): Exp name. 12 | logger (dict): Contains 'print_freq' (str) for logger interval. 13 | train (dict): Contains 'niter' (int) for total iters. 14 | use_tb_logger (bool): Use tensorboard logger. 15 | start_iter (int): Start iter. Default: 1. 16 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 17 | """ 18 | 19 | def __init__(self, opt, start_iter=1, tb_logger=None): 20 | self.exp_name = opt['name'] 21 | self.interval = opt['print_freq'] 22 | self.start_iter = start_iter 23 | self.max_iters = opt['max_iters'] 24 | self.use_tb_logger = opt['use_tb_logger'] 25 | self.tb_logger = tb_logger 26 | self.start_time = time.time() 27 | self.logger = get_root_logger() 28 | 29 | def __call__(self, log_vars): 30 | """Format logging message. 31 | 32 | Args: 33 | log_vars (dict): It contains the following keys: 34 | epoch (int): Epoch number. 35 | iter (int): Current iter. 36 | lrs (list): List for learning rates. 37 | 38 | time (float): Iter time. 39 | data_time (float): Data time for each iter. 40 | """ 41 | # epoch, iter, learning rates 42 | epoch = log_vars.pop('epoch') 43 | current_iter = log_vars.pop('iter') 44 | lrs = log_vars.pop('lrs') 45 | 46 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' 47 | f'iter:{current_iter:8,d}, lr:(') 48 | for v in lrs: 49 | message += f'{v:.3e},' 50 | message += ')] ' 51 | 52 | # time and estimated time 53 | if 'time' in log_vars.keys(): 54 | iter_time = log_vars.pop('time') 55 | data_time = log_vars.pop('data_time') 56 | 57 | total_time = time.time() - self.start_time 58 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 59 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 60 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 61 | message += f'[eta: {eta_str}, ' 62 | message += f'time: {iter_time:.3f}, data_time: {data_time:.3f}] ' 63 | 64 | # other items, especially losses 65 | for k, v in log_vars.items(): 66 | message += f'{k}: {v:.4e} ' 67 | # tensorboard logger 68 | if self.use_tb_logger and 'debug' not in self.exp_name: 69 | self.tb_logger.add_scalar(k, v, current_iter) 70 | 71 | self.logger.info(message) 72 | 73 | 74 | def init_tb_logger(log_dir): 75 | from torch.utils.tensorboard import SummaryWriter 76 | tb_logger = SummaryWriter(log_dir=log_dir) 77 | return tb_logger 78 | 79 | 80 | def get_root_logger(logger_name='base', log_level=logging.INFO, log_file=None): 81 | """Get the root logger. 82 | 83 | The logger will be initialized if it has not been initialized. By default a 84 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 85 | also be added. 86 | 87 | Args: 88 | logger_name (str): root logger name. Default: base. 89 | log_file (str | None): The log filename. If specified, a FileHandler 90 | will be added to the root logger. 91 | log_level (int): The root logger level. 
Note that only the process of 92 | rank 0 is affected, while other processes will set the level to 93 | "Error" and be silent most of the time. 94 | 95 | Returns: 96 | logging.Logger: The root logger. 97 | """ 98 | logger = logging.getLogger(logger_name) 99 | # if the logger has been initialized, just return it 100 | if logger.hasHandlers(): 101 | return logger 102 | 103 | format_str = '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s' 104 | logging.basicConfig(format=format_str, level=log_level) 105 | 106 | if log_file is not None: 107 | file_handler = logging.FileHandler(log_file, 'w') 108 | file_handler.setFormatter(logging.Formatter(format_str)) 109 | file_handler.setLevel(log_level) 110 | logger.addHandler(file_handler) 111 | 112 | return logger 113 | -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from collections import OrderedDict 4 | 5 | import yaml 6 | 7 | 8 | def ordered_yaml(): 9 | """Support OrderedDict for yaml. 10 | 11 | Returns: 12 | yaml Loader and Dumper. 13 | """ 14 | try: 15 | from yaml import CDumper as Dumper 16 | from yaml import CLoader as Loader 17 | except ImportError: 18 | from yaml import Dumper, Loader 19 | 20 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 21 | 22 | def dict_representer(dumper, data): 23 | return dumper.represent_dict(data.items()) 24 | 25 | def dict_constructor(loader, node): 26 | return OrderedDict(loader.construct_pairs(node)) 27 | 28 | Dumper.add_representer(OrderedDict, dict_representer) 29 | Loader.add_constructor(_mapping_tag, dict_constructor) 30 | return Loader, Dumper 31 | 32 | 33 | def parse(opt_path, is_train=True): 34 | """Parse option file. 35 | 36 | Args: 37 | opt_path (str): Option file path. 38 | is_train (str): Indicate whether in training or not. Default: True. 39 | 40 | Returns: 41 | (dict): Options. 42 | """ 43 | with open(opt_path, mode='r') as f: 44 | Loader, _ = ordered_yaml() 45 | opt = yaml.load(f, Loader=Loader) 46 | 47 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 48 | if opt.get('set_CUDA_VISIBLE_DEVICES', None): 49 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 50 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list, flush=True) 51 | else: 52 | print('gpu_list: ', gpu_list, flush=True) 53 | 54 | opt['is_train'] = is_train 55 | 56 | # paths 57 | opt['path'] = {} 58 | opt['path']['root'] = osp.abspath( 59 | osp.join(__file__, osp.pardir, osp.pardir)) 60 | if is_train: 61 | experiments_root = osp.join(opt['path']['root'], 'experiments', 62 | opt['name']) 63 | opt['path']['experiments_root'] = experiments_root 64 | opt['path']['models'] = osp.join(experiments_root, 'models') 65 | opt['path']['log'] = experiments_root 66 | opt['path']['visualization'] = osp.join(experiments_root, 67 | 'visualization') 68 | 69 | # change some options for debug mode 70 | if 'debug' in opt['name']: 71 | opt['debug'] = True 72 | opt['val_freq'] = 1 73 | opt['print_freq'] = 1 74 | opt['save_checkpoint_freq'] = 1 75 | else: # test 76 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 77 | opt['path']['results_root'] = results_root 78 | opt['path']['log'] = results_root 79 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 80 | 81 | return opt 82 | 83 | 84 | def dict2str(opt, indent_level=1): 85 | """dict to string for printing options. 86 | 87 | Args: 88 | opt (dict): Option dict. 
89 | indent_level (int): Indent level. Default: 1. 90 | 91 | Return: 92 | (str): Option string for printing. 93 | """ 94 | msg = '' 95 | for k, v in opt.items(): 96 | if isinstance(v, dict): 97 | msg += ' ' * (indent_level * 2) + k + ':[\n' 98 | msg += dict2str(v, indent_level + 1) 99 | msg += ' ' * (indent_level * 2) + ']\n' 100 | else: 101 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 102 | return msg 103 | 104 | 105 | class NoneDict(dict): 106 | """None dict. It will return none if key is not in the dict.""" 107 | 108 | def __missing__(self, key): 109 | return None 110 | 111 | 112 | def dict_to_nonedict(opt): 113 | """Convert to NoneDict, which returns None for missing keys. 114 | 115 | Args: 116 | opt (dict): Option dict. 117 | 118 | Returns: 119 | (dict): NoneDict for options. 120 | """ 121 | if isinstance(opt, dict): 122 | new_opt = dict() 123 | for key, sub_opt in opt.items(): 124 | new_opt[key] = dict_to_nonedict(sub_opt) 125 | return NoneDict(**new_opt) 126 | elif isinstance(opt, list): 127 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 128 | else: 129 | return opt 130 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import sys 5 | import time 6 | from shutil import get_terminal_size 7 | 8 | import numpy as np 9 | import torch 10 | 11 | logger = logging.getLogger('base') 12 | 13 | 14 | def make_exp_dirs(opt): 15 | """Make dirs for experiments.""" 16 | path_opt = opt['path'].copy() 17 | if opt['is_train']: 18 | overwrite = True if 'debug' in opt['name'] else False 19 | os.makedirs(path_opt.pop('experiments_root'), exist_ok=overwrite) 20 | os.makedirs(path_opt.pop('models'), exist_ok=overwrite) 21 | else: 22 | os.makedirs(path_opt.pop('results_root')) 23 | 24 | 25 | def set_random_seed(seed): 26 | """Set random seeds.""" 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed(seed) 31 | torch.cuda.manual_seed_all(seed) 32 | 33 | 34 | class ProgressBar(object): 35 | """A progress bar which can print the progress. 
36 | 37 | Modified from: 38 | https://github.com/hellock/cvbase/blob/master/cvbase/progress.py 39 | """ 40 | 41 | def __init__(self, task_num=0, bar_width=50, start=True): 42 | self.task_num = task_num 43 | max_bar_width = self._get_max_bar_width() 44 | self.bar_width = ( 45 | bar_width if bar_width <= max_bar_width else max_bar_width) 46 | self.completed = 0 47 | if start: 48 | self.start() 49 | 50 | def _get_max_bar_width(self): 51 | terminal_width, _ = get_terminal_size() 52 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) 53 | if max_bar_width < 10: 54 | print(f'terminal width is too small ({terminal_width}), ' 55 | 'please consider widen the terminal for better ' 56 | 'progressbar visualization') 57 | max_bar_width = 10 58 | return max_bar_width 59 | 60 | def start(self): 61 | if self.task_num > 0: 62 | sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, " 63 | f'elapsed: 0s, ETA:\nStart...\n') 64 | else: 65 | sys.stdout.write('completed: 0, elapsed: 0s') 66 | sys.stdout.flush() 67 | self.start_time = time.time() 68 | 69 | def update(self, msg='In progress...'): 70 | self.completed += 1 71 | elapsed = time.time() - self.start_time 72 | fps = self.completed / elapsed 73 | if self.task_num > 0: 74 | percentage = self.completed / float(self.task_num) 75 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 76 | mark_width = int(self.bar_width * percentage) 77 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) 78 | sys.stdout.write('\033[2F') # cursor up 2 lines 79 | sys.stdout.write( 80 | '\033[J' 81 | ) # clean the output (remove extra chars since last display) 82 | sys.stdout.write( 83 | f'[{bar_chars}] {self.completed}/{self.task_num}, ' 84 | f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' 85 | f'ETA: {eta:5}s\n{msg}\n') 86 | else: 87 | sys.stdout.write( 88 | f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s, ' 89 | f'{fps:.1f} tasks/s') 90 | sys.stdout.flush() 91 | 92 | 93 | class AverageMeter(object): 94 | """ 95 | Computes and stores the average and current value 96 | Imported from 97 | https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 98 | """ 99 | 100 | def __init__(self): 101 | self.reset() 102 | 103 | def reset(self): 104 | self.val = 0 105 | self.avg = 0 # running average = running sum / running count 106 | self.sum = 0 # running sum 107 | self.count = 0 # running count 108 | 109 | def update(self, val, n=1): 110 | # n = batch_size 111 | 112 | # val = batch accuracy for an attribute 113 | # self.val = val 114 | 115 | # sum = 100 * accumulative correct predictions for this attribute 116 | self.sum += val * n 117 | 118 | # count = total samples so far 119 | self.count += n 120 | 121 | # avg = 100 * avg accuracy for this attribute 122 | # for all the batches so far 123 | self.avg = self.sum / self.count 124 | --------------------------------------------------------------------------------
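For readers who want to drive the models without the Qt interface, the sketch below strings together the same SampleFromPoseModel calls that ui_demo.py issues from its button callbacks (open_densepose, generate_parsing, generate_human), in the same order and with the same pose preprocessing. It is only an illustrative outline, not part of the repository: the densepose image path and the two prompt strings are placeholders, and the config path is the one ui_demo.py itself loads.

import numpy as np
import torch
from PIL import Image

from models.sample_model import SampleFromPoseModel
from utils.language_utils import (generate_shape_attributes,
                                  generate_texture_attributes)
from utils.options import dict_to_nonedict, parse

# Build the sampler exactly as ui_demo.py does.
opt = parse('./configs/sample_from_pose.yml', is_train=False)
opt = dict_to_nonedict(opt)
model = SampleFromPoseModel(opt)

# 1) Pose: same preprocessing as Ex.open_densepose
#    ('densepose.png' is a placeholder path for a densepose visualization).
pose = Image.open('densepose.png')
pose = np.array(pose.resize(size=(256, 512), resample=Image.LANCZOS))[:, :, 2:]
pose = pose.transpose(2, 0, 1).astype(np.float32) / 12. - 1
model.feed_pose_data(torch.from_numpy(pose).unsqueeze(1))

# 2) Shape text -> parsing map (as in Ex.generate_parsing).
shape_attr = generate_shape_attributes('a lady, short-sleeve T-shirt, long pants')
model.feed_shape_attributes(torch.LongTensor(shape_attr).unsqueeze(0))
model.generate_parsing_map()
model.generate_quantized_segm()

# 3) Texture text -> final image (as in Ex.generate_human, minus manual mask edits).
texture_attr = generate_texture_attributes('pure color, denim')
model.feed_texture_attributes(torch.LongTensor(texture_attr))
model.generate_texture_map()
result = model.sample_and_refine()

# Convert to a uint8 HWC image, as ui_demo.py does before display and saving.
image = (result.permute(0, 2, 3, 1)[0].detach().cpu().numpy() * 255).astype(np.uint8)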